{-# LANGUAGE LambdaCase #-}
{-# LANGUAGE OverloadedStrings #-}
{-# LANGUAGE TemplateHaskell #-}
module Clash.Normalize.Transformations.EtaExpand
  ( etaExpandSyn
  , etaExpansionTL
  ) where
import qualified Control.Lens as Lens
import qualified Data.Maybe as Maybe
import GHC.Stack (HasCallStack)
import Clash.Core.HasType
import Clash.Core.Term (Bind(..), CoreContext(..), Term(..), collectArgs, mkLams)
import Clash.Core.TermInfo (isFun)
import Clash.Core.Type (splitFunTy)
import Clash.Core.Util (mkInternalVar)
import Clash.Core.Var (Id)
import Clash.Core.VarEnv (elemVarSet, extendInScopeSet, extendInScopeSetList)
import Clash.Normalize.Types (NormRewrite)
import Clash.Rewrite.Types (TransformContext(..), tcCache, topEntities)
import Clash.Rewrite.Util (changed)
import Clash.Util (curLoc)
-- | Eta-expand an application whose head is a top-entity variable.
--
-- The term @e@ is rewritten to @Lam arg (App e (Var arg))@, using a fresh
-- binder @arg@, and the rewrite is reported via 'changed' when all of these
-- hold:
--
--   * 'collectArgs' shows that @e@ is headed by a variable @f@;
--   * @f@ is an element of the 'topEntities' set;
--   * @e@ is not itself in function position of an application, i.e. the
--     innermost context frame is not 'AppFun' (looking through 'TickC'
--     frames);
--   * the inferred type of @e@ is a function type ('splitFunTy' succeeds).
--
-- Every other term is returned unchanged.
etaExpandSyn :: HasCallStack => NormRewrite
etaExpandSyn (TransformContext is0 ctx) e
  | (Var f, _) <- collectArgs e
  = do
      topEnts <- Lens.view topEntities
      if f `elemVarSet` topEnts && not (isAppFunCtx ctx)
        then do
          tcm <- Lens.view tcCache
          -- Type inference is only done once the cheap checks above have
          -- passed; previously it ran for every variable-headed application.
          case splitFunTy tcm (inferCoreTypeOf tcm e) of
            Just (argTy, _) -> do
              newId <- mkInternalVar is0 "arg" argTy
              changed (Lam newId (App e (Var newId)))
            Nothing -> return e
        else return e
  where
    -- True when the term sits in the function position of an application,
    -- ignoring any tick frames wrapped around it.
    isAppFunCtx (AppFun : _) = True
    isAppFunCtx (TickC _ : c) = isAppFunCtx c
    isAppFunCtx _ = False
etaExpandSyn _ e = return e
{-# SCC etaExpandSyn #-}
-- | Split a term into its leading chain of term-level lambda binders and the
-- remaining body.
--
-- @stripLambda (Lam x (Lam y body))@ is @([x, y], body)@. A term that is not a
-- 'Lam' is returned unchanged, paired with an empty binder list. Only 'Lam'
-- is peeled off; any other constructor stops the traversal.
stripLambda :: Term -> ([Id], Term)
stripLambda (Lam bndr e) =
  let (bndrs, e') = stripLambda e
   in (bndr : bndrs, e')
stripLambda e = ([], e)
-- | Eta-expand a term until it no longer has a function type, first
-- descending through leading lambdas and let-bodies.
--
--   * @Lam bndr e@: recurse into @e@, with @bndr@ added to the in-scope set
--     and a 'LamBody' frame pushed onto the context, then rebuild the lambda.
--   * @Let b e@ ('NonRec' or 'Rec'): recurse into the body, with the let
--     binder(s) in scope and a 'LetBody' frame pushed. If the rewritten body
--     starts with lambdas, they are floated outwards over the let (see
--     'floatLamsOverLet') and the rewrite is reported via 'changed'.
--   * any other term for which 'isFun' holds: apply it to a fresh binder
--     @arg@, recurse on that application (so curried functions get one
--     binder per argument), wrap the result in @Lam arg@ and report 'changed'.
--   * anything else is returned unchanged.
etaExpansionTL :: HasCallStack => NormRewrite
etaExpansionTL (TransformContext is0 ctx) (Lam bndr e) = do
  let ctx' = TransformContext (extendInScopeSet is0 bndr) (LamBody bndr : ctx)
  Lam bndr <$> etaExpansionTL ctx' e
etaExpansionTL (TransformContext is0 ctx) (Let b@(NonRec i _) e) = do
  let ctx' = TransformContext (extendInScopeSet is0 i) (LetBody [i] : ctx)
  e' <- etaExpansionTL ctx' e
  maybe (return (Let b e')) changed (floatLamsOverLet b e')
etaExpansionTL (TransformContext is0 ctx) (Let b@(Rec xes) e) = do
  let bndrs = map fst xes
      ctx' =
        TransformContext (extendInScopeSetList is0 bndrs) (LetBody bndrs : ctx)
  e' <- etaExpansionTL ctx' e
  maybe (return (Let b e')) changed (floatLamsOverLet b e')
etaExpansionTL (TransformContext is0 ctx) e = do
  tcm <- Lens.view tcCache
  if isFun tcm e
    then do
      -- 'isFun' has just reported a function type, so 'splitFunTy' is
      -- expected to succeed here; the 'error' guards that assumption.
      let argTy =
            fst
              ( Maybe.fromMaybe
                  (error $ $(curLoc) ++ "etaExpansion splitFunTy")
                  (splitFunTy tcm (inferCoreTypeOf tcm e))
              )
      newId <- mkInternalVar is0 "arg" argTy
      let ctx' =
            TransformContext (extendInScopeSet is0 newId) (LamBody newId : ctx)
      e' <- etaExpansionTL ctx' (App e (Var newId))
      changed (Lam newId e')
    else return e
{-# SCC etaExpansionTL #-}

-- | Float the leading lambdas of a let-body outwards over the let binding:
-- @Let b (Lam x1 (... (Lam xn body)))@ becomes
-- @Lam x1 (... (Lam xn (Let b body)))@.
--
-- Returns 'Nothing' when the body does not start with a lambda, so the caller
-- can tell whether anything changed.
floatLamsOverLet :: Bind Term -> Term -> Maybe Term
floatLamsOverLet b body =
  case stripLambda body of
    ([], _) -> Nothing
    (bs, body') -> Just (mkLams (Let b body') bs)