{-|
  Copyright  :  (C) 2012-2016, University of Twente,
                    2016-2017, Myrtle Software Ltd,
                    2017-2018, Google Inc.,
                    2021     , QBayLogic B.V.
  License    :  BSD2 (see the file LICENSE)
  Maintainer :  QBayLogic B.V. <devops@qbaylogic.com>
  The separating arguments transformation
-}

{-# LANGUAGE OverloadedStrings #-}

module Clash.Normalize.Transformations.SeparateArgs
  ( separateArguments
  ) where

import qualified Control.Lens as Lens
import Control.Monad.Writer (listen)
import qualified Data.List as List
import qualified Data.Monoid as Monoid
import GHC.Stack (HasCallStack)

import Clash.Core.HasType
import Clash.Core.Name (Name(..))
import Clash.Core.Subst (extendIdSubst, mkSubst, substTm)
import Clash.Core.Term (Term(..), collectArgsTicks, mkApps, mkLams, mkTicks)
import Clash.Core.Type (Type, mkPolyFunTy, splitFunForallTy)
import Clash.Core.TyCon (TyConMap)
import Clash.Core.Util (Projections (..), shouldSplit)
import Clash.Core.Var (Id, TyVar, Var (..), isGlobalId, mkLocalId)
import Clash.Core.VarEnv (extendInScopeSet, uniqAway)
import Clash.Normalize.Types (NormRewrite, NormalizeSession)
import Clash.Rewrite.Types (TransformContext(..), tcCache)
import Clash.Rewrite.Util (changed, mkDerivedName)

-- | Split apart (global) function arguments that contain types that we
-- want to separate off, e.g. Clocks. Works on both the definition side (i.e. the
-- lambda), and the call site (i.e. the application of the global variable). e.g.
-- turns
--
-- > f :: (Clock System, Reset System) -> Signal System Int
--
-- into
--
-- > f :: Clock System -> Reset System -> Signal System Int
--
-- At the call site the type annotated on the global variable reference is
-- rebuilt from the types of the split-off fields. Arguments and argument types
-- that are not matched up pairwise (partial or over-saturated applications)
-- are kept unchanged.
separateArguments :: HasCallStack => NormRewrite
-- Definition side: split the binder of a lambda (see 'separateLambda').
separateArguments ctx e0@(Lam b eb) = do
  tcm <- Lens.view tcCache
  case separateLambda tcm ctx b eb of
    Just e1 -> changed e1
    Nothing -> return e0

-- Call site: an application (possibly with ticks) of a global variable.
separateArguments (TransformContext is0 _) e@(collectArgsTicks -> (Var g, args, ticks))
  | isGlobalId g = do
  -- We ensure that both the type of the global variable reference is updated
  -- to take into account the changed arguments, and that we apply the global
  -- function with the split apart arguments.
  let (argTys0,resTy) = splitFunForallTy (coreTypeOf g)
  (concat -> args1, Monoid.getAny -> hasChanged)
    <- listen (mapM (uncurry splitArg) (zip argTys0 args))
  if hasChanged then
    let (argTys1,args2) = unzip args1
        -- 'zip' above truncates to the shorter of its two lists. For a partial
        -- application keep the argument types that were not applied, and for
        -- an application with more arguments than visible quantifiers/arrows
        -- keep the surplus arguments; otherwise both would be silently lost.
        argTys2 = argTys1 ++ drop (length args) argTys0
        args3 = args2 ++ drop (length argTys0) args
        gTy = mkPolyFunTy resTy argTys2
    in  return (mkApps (mkTicks (Var g {varType = gTy}) ticks) args3)
  else
    return e

 where
  -- Split a single argument
  splitArg
    :: Either TyVar Type
    -- The quantifier/function argument type of the global variable
    -> Either Term Type
    -- The applied type argument or term argument
    -> NormalizeSession [(Either TyVar Type,Either Term Type)]
  splitArg tv arg@(Right _)    = return [(tv,arg)]
  splitArg ty arg@(Left tmArg) = do
    tcm <- Lens.view tcCache
    let argTy = inferCoreTypeOf tcm tmArg
    case shouldSplit tcm argTy of
      Just (_,Projections projections,fieldTys) -> do
        tmArgs <- projections is0 tmArg
        -- Pair each projection with the type it is split into, as returned
        -- by 'shouldSplit'. These are the same types 'separateLambda' gives
        -- the new lambda binders. Without this, the rebuilt type of the global
        -- would repeat the original product type once per projection.
        changed (zipWith (\fieldTy tm -> (Right fieldTy, Left tm)) fieldTys tmArgs)
      _ ->
        return [(ty,arg)]

-- Anything else is left untouched.
separateArguments _ e = return e
{-# SCC separateArguments #-}

-- | Worker function of 'separateArguments'.
--
-- When 'shouldSplit' says the type of the lambda binder should be split, the
-- single lambda is replaced by lambdas over fresh binders, one per type
-- returned by 'shouldSplit'. In the body, the old binder is substituted by the
-- value that the returned constructor function builds from the fresh binders.
separateLambda
  :: TyConMap
  -> TransformContext
  -> Id
  -- ^ Lambda binder
  -> Term
  -- ^ Lambda body
  -> Maybe Term
  -- ^ If lambda is split up, this function returns a Just containing the new term
separateLambda tcm ctx@(TransformContext is0 _) b eb0 =
  rebuild <$> shouldSplit tcm (coreTypeOf b)
 where
  -- Build the new lambdas from the constructor function and the field types.
  rebuild (dc, _, fieldTys) = mkLams body1 freshBs
   where
    binderNm = mkDerivedName ctx (nameOcc (varName b))
    -- Fresh binders are threaded through the in-scope set so that each one
    -- is unique with respect to the context and to each other.
    (isFinal, freshBs) =
      List.mapAccumL freshen is0 [mkLocalId fieldTy binderNm | fieldTy <- fieldTys]
    rebuiltB = dc (map Var freshBs)
    body1 =
      substTm "separateArguments" (extendIdSubst (mkSubst isFinal) b rebuiltB) eb0

  -- Make a binder unique w.r.t. an in-scope set, and add it to that set.
  freshen inScope v =
    let v' = uniqAway inScope v
    in  (extendInScopeSet inScope v', v')
{-# SCC separateLambda #-}