{-# LANGUAGE LambdaCase #-}
{-# LANGUAGE TemplateHaskell #-}

module Clash.Normalize.Transformations.Cast
  ( argCastSpec
  , caseCast
  , elimCastCast
  , letCast
  , splitCastWork
  ) where

import Control.Exception (throw)
import qualified Control.Lens as Lens
import Control.Monad.Writer (listen)
import qualified Data.Monoid as Monoid (Any(..))
import GHC.Stack (HasCallStack)

import Clash.Core.Name (nameOcc)
import Clash.Core.Pretty (showPpr)
import Clash.Core.Term (LetBinding, Term(..), collectArgs, stripTicks)
import Clash.Core.TermInfo (isCast)
import Clash.Core.Type (normalizeType)
import Clash.Core.Var (isGlobalId, varName)
import Clash.Core.VarEnv (InScopeSet)
import Clash.Debug (trace)
import Clash.Normalize.Transformations.Specialize (specialize)
import Clash.Normalize.Types (NormRewrite, NormalizeSession)
import Clash.Rewrite.Types
  (TransformContext(..), bindings, curFun, tcCache, workFreeBinders)
import Clash.Rewrite.Util (changed, mkDerivedName, mkTmBinderFor)
import Clash.Rewrite.WorkFree (isWorkFree)
import Clash.Util (ClashException(..), curLoc)

-- | Push a cast over an argument to a function into that function.
--
-- This is done by specializing on the casted argument.
-- Example:
--
-- @
--   y = f (cast a)
--     where f x = g x
-- @
--
-- transforms to:
--
-- @
--   y = f' a
--     where f' x' = (\\x -> g x) (cast x')
-- @
--
-- The raison d'être for this transformation is that we hope to end up with
-- an expression where two casts are "back-to-back", after which we can
-- eliminate them in 'elimCastCast'.
--
-- Only applications whose head (after 'collectArgs') is a global binder are
-- specialized. If the casted term is not work-free, the specialization still
-- happens, but a warning is emitted via 'trace', because the generated HDL
-- might then contain duplicate work.
argCastSpec :: HasCallStack => NormRewrite
argCastSpec ctx e@(App f (stripTicks -> Cast e' _ _))
  -- Don't specialise when the arguments are casts-of-casts: these
  -- casts-of-casts will be eliminated by 'elimCastCast' during the
  -- normalization of the "current" function. We thus prevent the unnecessary
  -- introduction of a specialized version of 'f'.
  | not (isCast e')
  -- We can only push casts into global binders
  , (Var g, _) <- collectArgs f
  , isGlobalId g = do
      bndrs <- Lens.use bindings
      isWorkFree workFreeBinders bndrs e' >>= \case
        True -> go
        False -> warn go
  where
    go = specialize ctx e
    warn = trace (unwords
      [ "WARNING:", $(curLoc), "specializing a function on a non work-free"
      , "cast. Generated HDL implementation might contain duplicate work."
      , "Please report this as a bug.", "\n\nExpression where this occured:"
      , "\n\n" ++ showPpr e
      ])
argCastSpec _ e = return e
{-# SCC argCastSpec #-}

-- | Push a cast over a case into its alternatives.
--
-- @
--   cast (case s of p_i -> e_i)   ==>   case s of p_i -> cast e_i
-- @
--
-- The type stored in a 'Case' is the type of its alternatives. Once the cast
-- has been pushed into every alternative, each alternative has the cast's
-- target type, so the rebuilt case is annotated with that type rather than
-- with the original (pre-cast) type. Keeping the old annotation would make
-- later type queries on the rewritten case report the wrong type.
caseCast :: HasCallStack => NormRewrite
caseCast _ (Cast (stripTicks -> Case subj _ alts) ty1 ty2) = do
  let alts' = map (\(p, e) -> (p, Cast e ty1 ty2)) alts
  changed (Case subj ty2 alts')
caseCast _ e = return e
{-# SCC caseCast #-}

-- | Eliminate two back-to-back casts when the type going in and the type
-- coming out are the same (a round-trip cast):
--
-- @
--   (cast :: b -> a) $ (cast :: a -> b) x   ==> x
-- @
--
-- All four cast types are first normalised with 'normalizeType', using the
-- 'TyConMap' from 'tcCache'. If the intermediate types of the two casts differ,
-- or the outer target type differs from the inner source type, a
-- 'ClashException' is thrown. It is located at the source span of the
-- function currently being normalised ('curFun').
--
-- NOTE(review): a well-typed chain @a -> b -> c@ with @a /= c@ also ends up
-- in the exception branch; confirm such chains cannot reach this rewrite.
elimCastCast :: HasCallStack => NormRewrite
elimCastCast _ nested@(Cast (stripTicks -> Cast inner tyIn tyMid) tyMid' tyOut) = do
  tcm <- Lens.view tcCache
  let norm      = normalizeType tcm
      midAgrees = norm tyMid == norm tyMid'
      roundTrip = norm tyIn == norm tyOut
  if midAgrees && roundTrip
    then changed inner
    else mismatch
  where
    -- Abort normalisation: the two casts cannot be collapsed into nothing.
    mismatch = do
      (fn, sp) <- Lens.use curFun
      let msg = $(curLoc) ++ showPpr fn
                  ++ ": Found 2 nested casts whose types don't line up:\n"
                  ++ showPpr nested
      throw (ClashException sp msg Nothing)
elimCastCast _ e = return e
{-# SCC elimCastCast #-}

-- | Push a cast over a let-expression into its body:
--
-- @
--   cast (let bs in body)   ==>   let bs in cast body
-- @
--
-- Ticks directly under the cast are looked through; any other term is
-- returned unchanged.
letCast :: HasCallStack => NormRewrite
letCast _ term = case term of
  Cast (stripTicks -> Let bnds inner) tyFrom tyTo ->
    changed (Let bnds (Cast inner tyFrom tyTo))
  _ ->
    return term
{-# SCC letCast #-}

-- | Make a cast work-free by splitting its work off to a separate binding:
--
-- @
-- let x = cast (f a b)
-- ==>
-- let x  = cast x'
--     x' = f a b
-- @
--
-- Only the bindings of a 'Letrec' are inspected. A binding whose
-- (tick-stripped) right-hand side is a cast of a variable, or a cast of
-- another cast, is kept as is. The rewrite reports a change only when at
-- least one binding was actually split; otherwise the original term is
-- returned untouched.
splitCastWork :: HasCallStack => NormRewrite
splitCastWork ctx@(TransformContext is0 _) original@(Letrec binds body) = do
  -- 'listen' observes whether any per-binding split called 'changed'.
  (groups, Monoid.Any anySplit) <- listen (traverse (splitBinding is0) binds)
  if anySplit
    then changed (Letrec (concat groups) body)
    else return original
  where
    splitBinding :: InScopeSet -> LetBinding -> NormalizeSession [LetBinding]
    splitBinding isN bnd@(v, rhs) = case stripTicks rhs of
      -- already work-free
      Cast Var{} _ _ -> return [bnd]
      -- casts will be eliminated
      Cast Cast{} _ _ -> return [bnd]
      Cast inner tyFrom tyTo -> do
        tcm <- Lens.view tcCache
        let derived = mkDerivedName ctx (nameOcc (varName v))
        v' <- mkTmBinderFor isN tcm derived inner
        changed [ (v', inner)
                , (v, Cast (Var v') tyFrom tyTo)
                ]
      _ -> return [bnd]
splitCastWork _ e = return e
{-# SCC splitCastWork #-}