{-|
Copyright   : (C) 2020, QBayLogic B.V.
License     : BSD2 (see the file LICENSE)
Maintainer  : QBayLogic B.V. <devops@qbaylogic.com>

Check whether a term is work free or not. This is used by transformations /
evaluation to check whether it is possible to perform changes without
duplicating work in the result, e.g. inlining.
-}

{-# LANGUAGE OverloadedStrings #-}
{-# LANGUAGE RankNTypes #-}

module Clash.Rewrite.WorkFree
  ( isWorkFree
  , isWorkFreeClockOrResetOrEnable
  , isWorkFreeIsh
  , isConstant
  , isConstantNotClockReset
  ) where

import Control.Lens (Lens')
import Control.Monad.Extra (allM, andM, eitherM)
import Control.Monad.State.Class (MonadState)
import GHC.Stack (HasCallStack)

import Clash.Core.FreeVars
import Clash.Core.Pretty (showPpr)
import Clash.Core.Term
import Clash.Core.TermInfo
import Clash.Core.TyCon (TyConMap)
import Clash.Core.Type (isPolyFunTy)
import Clash.Core.Util
import Clash.Core.Var (Id, Var(..), isLocalId)
import Clash.Core.VarEnv (VarEnv, lookupVarEnv)
import Clash.Driver.Types (BindingMap, Binding(..))
import Clash.Util (makeCachedU)

-- | Determines whether a global binder is work free. The answer is memoised
-- via 'makeCachedU' in the cache the given lens points to.
--
-- A binder whose own body refers to the binder (i.e. a recursive binder) is
-- never considered work free. Otherwise the verdict is that of 'isWorkFree'
-- on the binder's body.
--
-- Errors if the binder cannot be found in the given 'BindingMap'.
isWorkFreeBinder
  :: (HasCallStack, MonadState s m)
  => Lens' s (VarEnv Bool)
  -> BindingMap
  -> Id
  -> m Bool
isWorkFreeBinder cache bndrs bndr =
  makeCachedU bndr cache $
    maybe
      (error ("isWorkFreeBinder: couldn't find binder: " ++ showPpr bndr))
      (checkBody . bindingTerm)
      (lookupVarEnv bndr bndrs)
 where
  -- Self-referencing bodies are rejected up front; everything else is
  -- delegated to the general term check.
  checkBody body
    | bndr `globalIdOccursIn` body = pure False
    | otherwise = isWorkFree cache bndrs body

{-# INLINABLE isWorkFree #-}
-- | Determine whether a term does any work, i.e. adds to the size of the
-- circuit. This function requires a cache (specified as a lens) to store the
-- result for querying work info of global binders.
--
-- A term is only considered work free when its head and all of its (term)
-- arguments are work free. For binding forms and case-expressions this
-- means every subterm must be work free: the body and bindings of a
-- @letrec@, the body of a (type) lambda, and the scrutinee and /all/
-- alternatives of a case-expression.
isWorkFree
  :: forall s m
   . (HasCallStack, MonadState s m)
  => Lens' s (VarEnv Bool)
  -> BindingMap
  -> Term
  -> m Bool
isWorkFree cache bndrs = go True
 where
  -- If we are in the outermost level of a term (i.e. not checking a subterm)
  -- then a term is work free if it simply refers to a local variable. This
  -- does not apply to subterms, as we do not want to count expressions like
  --
  --   f[LocalId] x[LocalId]
  --
  -- as being work free, as the term bound to f may introduce work.
  --
  go :: HasCallStack => Bool -> Term -> m Bool
  go isOutermost (collectArgs -> (fun, args)) =
    case fun of
      Var i
        -- We only allow polymorphic / function typed variables to be inlined
        -- if they are locally scoped, and the term is only a variable.
        --
        -- TODO This could be improved later by passing an InScopeSet to
        -- isWorkFree with all the local FVs of the term being checked. PE
        -- would need to be changed to know the FVs of global binders first.
        --
        | isPolyFunTy (varType i) ->
            pure (isLocalId i && isOutermost && null args)
        | isLocalId i ->
            pure True
        | otherwise ->
            andM [isWorkFreeBinder cache bndrs i, allM goArg args]

      Data _ -> allM goArg args
      Literal _ -> pure True
      Prim pr ->
        case primWorkInfo pr of
          -- We can ignore arguments because the primitive outputs a constant
          -- regardless of their values.
          WorkConstant -> pure True
          WorkNever -> allM goArg args
          WorkVariable -> pure (all isConstantArg args)
          WorkAlways -> pure False

      Lam _ e -> andM [go False e, allM goArg args]
      TyLam _ e -> andM [go False e, allM goArg args]
      Letrec bs e ->
        andM [go False e, allM (go False . snd) bs, allM goArg args]
      -- Every alternative has to be checked: an alternative that does work
      -- (e.g. a multiplication) would otherwise be duplicated when the whole
      -- case-expression is inlined. For a single-alternative case this is
      -- exactly "scrutinee and alternative are work free".
      Case s _ alts ->
        andM [go False s, allM (go False . snd) alts, allM goArg args]
      Cast e _ _ -> andM [go False e, allM goArg args]

      -- (Ty)App's and Ticks are removed by collectArgs
      Tick _ _ -> error "isWorkFree: unexpected Tick"
      App {}   -> error "isWorkFree: unexpected App"
      TyApp {} -> error "isWorkFree: unexpected TyApp"

  -- Term arguments are checked as subterms; type arguments never do work.
  goArg :: Either Term b -> m Bool
  goArg e = eitherM (go False) (pure . const True) (pure e)

  -- Arguments of 'WorkVariable' primitives must be constants; type
  -- arguments are always fine.
  isConstantArg :: Either Term b -> Bool
  isConstantArg = either isConstant (const True)

-- | Determine if a term represents a constant.
--
-- After stripping applications with 'collectArgs', a term is constant when
-- it is:
--
--   * a literal,
--   * a lambda without local free variables, or
--   * a data constructor or primitive whose term arguments are all constant
--     (type arguments are ignored).
--
-- Anything else is considered non-constant.
isConstant :: Term -> Bool
isConstant e =
  case collectArgs e of
    (Literal _, _)     -> True
    (Lam _ _, _)       -> not (hasLocalFreeVars e)
    (Data _, dcArgs)   -> all constantArg dcArgs
    (Prim _, primArgs) -> all constantArg primArgs
    _                  -> False
 where
  constantArg :: Either Term a -> Bool
  constantArg = either isConstant (const True)

-- | A variant of 'isConstant' for terms that may be clocks or resets.
--
-- When the type of the term is a clock or reset, the term only counts as
-- constant if its head is the @Clash.Transformations.removedArg@ primitive.
-- For all other types the verdict is that of 'isConstant'.
isConstantNotClockReset :: TyConMap -> Term -> Bool
isConstantNotClockReset tcm e
  | isClockOrReset tcm (termType tcm e) = headIsRemovedArg
  | otherwise = isConstant e
 where
  headIsRemovedArg =
    case fst (collectArgs e) of
      Prim pr -> primName pr == "Clash.Transformations.removedArg"
      _ -> False

-- TODO: Remove function after using WorkInfo in 'isWorkFreeIsh'
-- | Work-freeness check restricted to clock, reset and enable terms.
--
-- Returns 'Nothing' when the type of the term is neither a clock, reset nor
-- enable. Otherwise returns 'Just' whether the term is work free, which is
-- the case for:
--
--   * an application headed by the @Clash.Transformations.removedArg@
--     primitive,
--   * a variable applied to no arguments,
--   * a data constructor applied to exactly two arguments, the second being
--     a (tick-stripped) data constructor; this covers Enable True/False,
--   * a literal.
isWorkFreeClockOrResetOrEnable
  :: TyConMap
  -> Term
  -> Maybe Bool
isWorkFreeClockOrResetOrEnable tcm e
  | isClockOrReset tcm eTy || isEnable tcm eTy = Just (workFree (collectArgs e))
  | otherwise = Nothing
 where
  eTy = termType tcm e

  workFree (Prim p, _) = primName p == "Clash.Transformations.removedArg"
  workFree (Var _, []) = True
  -- For Enable True/False
  workFree (Data _, [_dom, Left (stripTicks -> Data _)]) = True
  workFree (Literal _, _) = True
  workFree _ = False

-- | A conservative version of 'isWorkFree'. It is used in 'bindConstantVar'
-- to determine whether an expression can be "bound" (locally inlined). While
-- binding workfree expressions won't result in extra work for the circuit, it
-- might very well cause extra work for Clash. In fact, using 'isWorkFree' in
-- 'bindConstantVar' makes Clash two orders of magnitude slower for some of our
-- test cases.
--
-- In effect, this function is a version of 'isConstant' that also considers
-- references to clocks, resets and enables constant. This allows us to bind
-- HiddenClock(ResetEnable) constructs, allowing Clash to constant-specialize
-- subconstants - most notably KnownDomain. Doing that enables Clash to
-- eliminate any case-constructs on it.
isWorkFreeIsh
  :: TyConMap
  -> Term
  -> Bool
isWorkFreeIsh tcm e
  | Just verdict <- isWorkFreeClockOrResetOrEnable tcm e = verdict
  | otherwise =
      case collectArgs e of
        (Data _, dcArgs) -> all workFreeIshArg dcArgs
        (Prim pInfo, primArgs) ->
          case primWorkInfo pInfo of
            -- Things like clock or reset generators always perform work
            WorkAlways   -> False
            WorkVariable -> all constantArg primArgs
            _            -> all workFreeIshArg primArgs
        (Lam _ _, _)   -> not (hasLocalFreeVars e)
        (Literal _, _) -> True
        _              -> False
 where
  -- Type arguments never do work; term arguments are checked recursively.
  workFreeIshArg :: Either Term a -> Bool
  workFreeIshArg = either (isWorkFreeIsh tcm) (const True)

  constantArg :: Either Term a -> Bool
  constantArg = either isConstant (const True)