{-|
  Copyright  :  (C) 2012-2016, University of Twente,
                    2016     , Myrtle Software Ltd,
                    2017     , Google Inc.,
                    2021-2022, QBayLogic B.V.
  License    :  BSD2 (see the file LICENSE)
  Maintainer :  QBayLogic B.V. <devops@qbaylogic.com>

  Utilities for rewriting: e.g. inlining, specialisation, etc.
-}

{-# LANGUAGE CPP #-}
{-# LANGUAGE FlexibleContexts #-}
{-# LANGUAGE LambdaCase #-}
{-# LANGUAGE MultiWayIf #-}
{-# LANGUAGE NondecreasingIndentation #-}
{-# LANGUAGE QuasiQuotes #-}
{-# LANGUAGE OverloadedStrings #-}
{-# LANGUAGE RankNTypes #-}
{-# LANGUAGE TemplateHaskell #-}

module Clash.Rewrite.Util
  ( module Clash.Rewrite.Util
  , module Clash.Rewrite.WorkFree
  ) where

import           Control.Concurrent.Supply   (splitSupply)
import           Control.DeepSeq
import           Control.Exception           (throw)
import           Control.Lens ((%=), (+=), (^.))
import qualified Control.Lens                as Lens
import qualified Control.Monad               as Monad
import qualified Control.Monad.State.Strict  as State
#if MIN_VERSION_transformers(0,5,6)
import qualified Control.Monad.Trans.RWS.CPS as RWS
#else
import qualified Control.Monad.Trans.RWS.Strict as RWS
#endif
import qualified Control.Monad.Writer        as Writer
import           Data.Bifunctor              (second)
import           Data.Coerce                 (coerce)
import           Data.Functor.Const          (Const (..))
import qualified Data.HashMap.Strict         as HashMap
import           Data.List                   (group, partition, sort, sortOn)
import qualified Data.List                   as List
import qualified Data.List.Extra             as List
import           Data.List.Extra             (partitionM)
import           Data.Maybe
import qualified Data.Monoid                 as Monoid
import qualified Data.Set                    as Set
import qualified Data.Set.Lens               as Lens
import           Data.Text                   (Text)
import qualified Data.Text                   as Text
import           System.IO.Unsafe            (unsafePerformIO)
import           Data.Binary                 (encode)
import qualified Data.ByteString             as BS
import qualified Data.ByteString.Lazy        as BL

#if MIN_VERSION_ghc(9,0,0)
import           GHC.Types.Basic             (InlineSpec (..))
#else
import           BasicTypes                  (InlineSpec (..))
#endif

import           Clash.Core.Evaluator.Types  (PureHeap, whnf')
import           Clash.Core.FreeVars
  (freeLocalVars, termFreeVars', freeLocalIds, globalIdOccursIn)
import           Clash.Core.HasFreeVars      (elemFreeVars, notElemFreeVars)
import           Clash.Core.HasType
import           Clash.Core.Name
import           Clash.Core.Pretty           (showPpr)
import           Clash.Core.Subst
  (substTmEnv, aeqTerm, aeqType, extendIdSubst, mkSubst, substTm, eqTerm)
import           Clash.Core.Term
import           Clash.Core.TyCon            (TyConMap)
import           Clash.Core.Type             (Type (..), normalizeType)
import           Clash.Core.Var
  (Id, IdScope (..), TyVar, Var (..), mkGlobalId, mkLocalId, mkTyVar)
import           Clash.Core.VarEnv
  (InScopeSet, extendInScopeSet, extendInScopeSetList, mkInScopeSet,
   uniqAway, uniqAway', mapVarEnv, eltsVarEnv, unitVarSet, emptyVarEnv,
   mkVarEnv, eltsVarSet, elemVarEnv, lookupVarEnv, extendVarEnv, elemVarSet)
import           Clash.Debug
import           Clash.Driver.Types
  (TransformationInfo(..), DebugOpts(..), BindingMap, Binding(..), IsPrim(..),
  ClashEnv(..), ClashOpts(..), hasDebugInfo, isDebugging)
import           Clash.Netlist.Util          (representableType)
import           Clash.Pretty                (clashPretty, showDoc)
import           Clash.Rewrite.Types
import           Clash.Rewrite.WorkFree
import           Clash.Unique
import           Clash.Util
import           Clash.Util.Eq               (fastEqBy)
import qualified Clash.Util.Interpolate as I

-- | Lift an action working in the '_extra' state to the 'RewriteMonad'
--
-- The pure state action is run on the '_extra' field of the current
-- 'RewriteState'; the resulting extra state is written back into that field.
-- The reader environment is ignored and the writer output is 'mempty', so
-- lifting an action never by itself signals that a term has changed.
zoomExtra :: State.State extra a -> RewriteMonad extra a
zoomExtra m = R . RWS.rwsT $ \_ s ->
  let (a, st') = State.runState m (_extra s)
   in pure (a, s { _extra = st' }, mempty)

-- | Some transformations might erroneously introduce shadowing. For example,
-- a transformation might result in:
--
-- @
--   let a = ...
--       b = ...
--       a = ...
-- @
--
-- where the last 'a' shadows the first, while Clash assumes that this can't
-- happen. This function finds those constructs and returns a list of the
-- duplicates found: every inner list holds two or more equal binders.
--
-- Duplicates are looked for among the binders of a single recursive
-- let-group and among the term binders of a single constructor pattern.
--
-- NB: the traversal descends into lambda/let bodies, applications, the
-- scrutinee and alternatives of a case, but it does not descend into the
-- right-hand sides of let-bindings.
findAccidentialShadows :: Term -> [[Id]]
findAccidentialShadows =
  \case
    Var {}      -> []
    Data {}     -> []
    Literal {}  -> []
    Prim {}     -> []
    Lam _ t     -> findAccidentialShadows t
    TyLam _ t   -> findAccidentialShadows t
    App t1 t2   -> concatMap findAccidentialShadows [t1, t2]
    TyApp t _   -> findAccidentialShadows t
    Cast t _ _  -> findAccidentialShadows t
    Tick _ t    -> findAccidentialShadows t
    Case t _ as ->
      concatMap (findInPat . fst) as ++
        concatMap findAccidentialShadows (t : map snd as)
    Let NonRec{} t -> findAccidentialShadows t
    Let (Rec bs) t -> findDups (map fst bs) ++ findAccidentialShadows t

 where
  -- Only constructor patterns bind term variables that could clash.
  findInPat :: Pat -> [[Id]]
  findInPat (LitPat _)        = []
  findInPat DefaultPat        = []
  findInPat (DataPat _ _ ids) = findDups ids

  -- Sort so that equal binders become adjacent, group them, and keep only
  -- the groups with more than one member.
  findDups :: [Id] -> [[Id]]
  findDups ids = filter ((1 <) . length) (group (sort ids))


-- | Record if a transformation is successfully applied
--
-- Runs the given rewrite and inspects the 'Monoid.Any' it wrote. When the
-- rewrite reported a change, 'transformCounter' is incremented and, if a
-- history file is configured in the debug options, a binary-encoded
-- 'RewriteStep' is appended to that file. When debugging is enabled the
-- result is additionally passed through 'applyDebug'; otherwise the rewritten
-- term is returned as is.
apply
  :: String
  -- ^ Name of the transformation
  -> Rewrite extra
  -- ^ Transformation to be applied
  -> Rewrite extra
apply = \s rewrite ctx expr0 -> do
  opts <- Lens.view debugOpts
  traceIf (hasDebugInfo TryName s opts) ("Trying: " <> s) (pure ())

  (!expr1,anyChanged) <- Writer.listen (rewrite ctx expr0)
  let hasChanged = Monoid.getAny anyChanged
  Monad.when hasChanged (transformCounter += 1)

  -- NB: When -fclash-debug-history is on, emit binary data holding the recorded rewrite steps
  case dbg_historyFile opts of
    -- Matching on the 'Maybe' directly avoids the partial 'fromJust'.
    Just histFile | hasChanged -> do
      (curBndr, _) <- Lens.use curFun
      -- The bang pattern forces the 'unsafePerformIO' thunk, so the step is
      -- appended to the history file as part of this action.
      let !_ = unsafePerformIO
             $ BS.appendFile histFile
             $ BL.toStrict
             $ encode RewriteStep
                 { t_ctx    = tfContext ctx
                 , t_name   = s
                 , t_bndrS  = showPpr (varName curBndr)
                 , t_before = expr0
                 , t_after  = expr1
                 }
      return ()
    _ -> return ()

  if isDebugging opts
    then applyDebug s expr0 hasChanged expr1
    else return expr1
{-# INLINE apply #-}

-- | Debug-mode companion of 'apply': emits the configured traces for a
-- transformation attempt and, when invariant checking is on, verifies the
-- result of the rewrite.
--
-- Nothing is traced or checked until 'transformCounter' has passed the
-- configured @from@ offset (default 0). Once more than @limit@
-- transformations (default 'maxBound') have been counted past that offset,
-- an error is raised.
--
-- With invariant checking enabled, an error is raised when a changed term
-- has free local variables the old term did not have, when it contains
-- accidental shadowing (see 'findAccidentialShadows'), or when the term
-- differs from the original although no change was signalled. A change of
-- the (normalized) type is only traced, not raised as an error.
applyDebug
  :: String
  -- ^ Name of the transformation
  -> Term
  -- ^ Original expression
  -> Bool
  -- ^ Whether the rewrite indicated change
  -> Term
  -- ^ New expression
  -> RewriteMonad extra Term
applyDebug name exprOld hasChanged exprNew = do
  nTrans <- Lens.use transformCounter
  opts <- Lens.view debugOpts

  let from = fromMaybe 0 (dbg_transformationsFrom opts)
  let limit = fromMaybe maxBound (dbg_transformationsLimit opts)

  -- The @nTrans <= from@ guard must come first: these are 'Word's, so
  -- @nTrans - from@ wraps around to a huge number when @nTrans < from@, which
  -- would spuriously trip the limit check whenever both options are set.
  if | nTrans <= from ->
         pure exprNew
     | nTrans - from > limit ->
         error "-fclash-debug-transformations-limit exceeded"
     | otherwise ->
         go opts
 where
  go opts = traceIf (hasDebugInfo TryTerm name opts) ("Tried: " ++ name ++ " on:\n" ++ before) $ do
    -- 'go' is only entered when the counter is larger than @from@ (hence at
    -- least 1), so taking the predecessor cannot underflow.
    nTrans <- pred <$> Lens.use transformCounter

    Monad.when (dbg_countTransformations opts && hasChanged) $ do
      -- Insert 1 for a transformation seen for the first time, otherwise
      -- bump the count already stored for it.
      transformCounters %= HashMap.insertWith (const succ) (Text.pack name) 1

    Monad.when (dbg_invariants opts && hasChanged) $ do
      tcm                  <- Lens.view tcCache
      let beforeTy          = inferCoreTypeOf tcm exprOld
          beforeFV          = Lens.setOf freeLocalVars exprOld
          afterTy           = inferCoreTypeOf tcm exprNew
          afterFV           = Lens.setOf freeLocalVars exprNew
          newFV             = not (afterFV `Set.isSubsetOf` beforeFV)
          accidentalShadows = findAccidentialShadows exprNew

      Monad.when newFV $
              error ( concat [ $(curLoc)
                             , "Error when applying rewrite ", name
                             , " to:\n" , before
                             , "\nResult:\n" ++ after ++ "\n"
                             , "It introduces free variables."
                             , "\nBefore: " ++ showPpr (Set.toList beforeFV)
                             , "\nAfter: " ++ showPpr (Set.toList afterFV)
                             ]
                    )
      Monad.when (not (null accidentalShadows)) $
        error ( concat [ $(curLoc)
                       , "Error when applying rewrite ", name
                       , " to:\n" , before
                       , "\nResult:\n" ++ after ++ "\n"
                       , "It accidentally creates shadowing let/case-bindings:\n"
                       , " ", showPpr accidentalShadows, "\n"
                       , "This usually means that a transformation did not extend "
                       , "or incorrectly extended its InScopeSet before applying a "
                       , "substitution."
                       ])

      -- TODO This check should be an error instead of a trace, however this is
      -- currently very fragile as Clash doesn't keep casts in core. This should
      -- be changed when #1064 is merged.
      Monad.when (hasDebugInfo AppliedTerm name opts && not (normalizeType tcm beforeTy `aeqType` normalizeType tcm afterTy)) $
        traceM ( concat [ $(curLoc)
                       , "Error when applying rewrite ", name
                       , " to:\n" , before
                       , "\nResult:\n" ++ after ++ "\n"
                       , "Changes type from:\n", showPpr beforeTy
                       , "\nto:\n", showPpr afterTy
                       ]
              )

    -- A rewrite that altered the term must have signalled so.
    let exprNotEqual = not (fastEqBy eqTerm exprOld exprNew)
    Monad.when (dbg_invariants opts && not hasChanged && exprNotEqual) $
      error $ $(curLoc) ++ "Expression changed without notice(" ++ name ++  "): before"
                        ++ before ++ "\nafter:\n" ++ after

    traceIf (hasDebugInfo AppliedName name opts && hasChanged) (name <> " {" <> show nTrans <> "}") $
      traceIf (hasDebugInfo AppliedTerm name opts && hasChanged) ("Changes when applying rewrite to:\n"
                        ++ before ++ "\nResult:\n" ++ after ++ "\n") $
        traceIf (hasDebugInfo TryTerm name opts && not hasChanged) ("No changes when applying rewrite "
                          ++ name ++ " to:\n" ++ after ++ "\n") $
          return exprNew
   where
    -- Pretty-printed terms, shared by all messages above.
    before = showPpr exprOld
    after  = showPpr exprNew

-- | Perform a transformation on a Term
--
-- The rewrite is run through 'apply' in a 'TransformContext' made of the
-- given in-scope set and an empty context.
runRewrite
  :: String
  -- ^ Name of the transformation
  -> InScopeSet
  -- ^ Variables in scope for the term
  -> Rewrite extra
  -- ^ Transformation to perform
  -> Term
  -- ^ Term to transform
  -> RewriteMonad extra Term
runRewrite name is rewrite expr = apply name rewrite (TransformContext is []) expr

-- | Evaluate a RewriteSession to its inner monad.
--
-- The final rewrite state and the writer output are discarded; only the
-- result is returned. Depending on the debug options in the environment, the
-- per-transformation counters and/or the total number of applied
-- transformations are traced first.
runRewriteSession :: RewriteEnv
                  -> RewriteState extra
                  -> RewriteMonad extra a
                  -> IO a
runRewriteSession r s m = do
  (a, s', _) <- runR m r s
  traceIf (dbg_countTransformations dbgOpts)
    ("Clash: Transformations:\n" ++ Text.unpack (showCounters (s' ^. transformCounters))) $
    traceIf (None < dbg_transformationInfo dbgOpts)
      ("Clash: Applied " ++ show (s' ^. transformCounter) ++ " transformations")
      pure a
  where
    -- Debug options of this session, taken from the environment.
    dbgOpts = opt_debug (envOpts (_clashEnv r))

    -- One "name: count" line per transformation, sorted by ascending count.
    showCounters =
      Text.unlines
        . map (\(nm,cnt) -> nm <> ": " <> Text.pack (show cnt))
        . sortOn snd
        . HashMap.toList

-- | Notify that a transformation has changed the expression
--
-- This writes @'Monoid.Any' True@ to the writer output, which is what
-- 'apply' inspects to decide whether a rewrite changed its term.
setChanged :: RewriteMonad extra ()
setChanged = Writer.tell (Monoid.Any True)

-- | Identity function that additionally notifies that a transformation has
-- changed the expression
--
-- The value is returned untouched after @'Monoid.Any' True@ has been written
-- to the writer output.
changed :: a -> RewriteMonad extra a
changed val = do
  Writer.tell (Monoid.Any True)
  return val

-- | Find the binder of the first 'LetBinding' element in the given context,
-- searching from the head of the list. Yields 'Nothing' when the context
-- holds no 'LetBinding' at all.
closestLetBinder :: Context -> Maybe Id
closestLetBinder [] = Nothing
closestLetBinder (LetBinding id_ _:_) = Just id_
closestLetBinder (_:ctx)              = closestLetBinder ctx

-- | Make a term name from the given suffix.
--
-- When the context has a closest let-binder (see 'closestLetBinder'), the
-- result is that binder's name with an underscore and the suffix appended.
-- Otherwise an internal name is made from the suffix alone, using 0 as the
-- second argument of 'mkUnsafeInternalName'.
mkDerivedName :: TransformContext -> OccName -> TmName
mkDerivedName (TransformContext _ ctx) sf = case closestLetBinder ctx of
  Just id_ -> appendToName (varName id_) ('_' `Text.cons` sf)
  _ -> mkUnsafeInternalName sf 0

-- | Make a new binder and variable reference for a term
--
-- This is 'mkBinderFor' specialised to terms. The 'error' branch guards the
-- 'Right' ('TyVar') result of 'mkBinderFor', which 'mkBinderFor' only produces
-- when it is handed a type rather than a term.
mkTmBinderFor
  :: (MonadUnique m)
  => InScopeSet
  -> TyConMap -- ^ TyCon cache
  -> Name a -- ^ Name of the new binder
  -> Term -- ^ Term to bind
  -> m Id
mkTmBinderFor is tcm name e =
  either id (error "mkTmBinderFor: Result is a TyVar")
    <$> mkBinderFor is tcm name (Left e)

-- | Make a new binder and variable reference for either a term or a type
--
-- The given name is first cloned with respect to the in-scope set. Then:
--
-- * for a term ('Left'), a local 'Id' is created whose type is the inferred
--   type of that term;
-- * for a type ('Right'), a 'TyVar' is created whose kind is the inferred
--   kind of that type.
mkBinderFor
  :: (MonadUnique m)
  => InScopeSet
  -> TyConMap -- ^ TyCon cache
  -> Name a -- ^ Name of the new binder
  -> Either Term Type -- ^ Type or Term to bind
  -> m (Either Id TyVar)
mkBinderFor is tcm name tmOrTy = do
  -- Cloning the name is the same for terms and types, so do it once up front
  name' <- cloneNameWithInScopeSet is name
  return $ case tmOrTy of
    Left term ->
      let ty = inferCoreTypeOf tcm term
       in Left (mkLocalId ty (coerce name'))
    Right ty ->
      let ki = inferCoreKindOf tcm ty
       in Right (mkTyVar ki (coerce name'))

-- | Inline the binders in a let-binding that have a certain property
--
-- * Non-recursive let: the binder is substituted into the body when the
--   property holds /and/ the binder occurs free in the body; the let is then
--   replaced by the substituted body. Otherwise the expression is returned
--   unchanged.
-- * Recursive let: the bindings are partitioned by the property and the
--   selected ones are handed to 'substituteBinders'. Selected bindings that
--   turn out to be self-recursive are kept, together with the bindings that
--   were not selected. When no bindings remain, only the body is returned.
-- * Any other expression is returned unchanged.
--
-- Every rewritten result is passed through 'changed'.
inlineBinders
  :: (Term -> LetBinding -> RewriteMonad extra Bool)
  -- ^ Property test; applied to the complete let-expression and one of its
  -- bindings
  -> Rewrite extra
inlineBinders condition (TransformContext inScope0 _) expr@(Let (NonRec i x) res) = do
  inline <- condition expr (i, x)

  if inline && elemFreeVars i res then
    let inScope1 = extendInScopeSet inScope0 i
        subst    = extendIdSubst (mkSubst inScope1) i x
     in changed (substTm "inlineBinders" subst res)
  else
    return expr

inlineBinders condition (TransformContext inScope0 _) expr@(Let (Rec xes) res) = do
  (toInline, toKeep) <- partitionM (condition expr) xes
  case toInline of
    [] -> return expr
    _  -> do
      let inScope1 = extendInScopeSetList inScope0 (map fst xes)
          (toInlRec, (toKeep1, res1)) =
            substituteBinders inScope1 toInline toKeep res
      case toInlRec ++ toKeep1 of
        -- Nothing left to bind: drop the let altogether
        []   -> changed res1
        xes1 -> changed (Letrec xes1 res1)

inlineBinders _ _ e = return e

-- | Determine whether a binder is a join-point created for a complex case
-- expression.
--
-- A join-point is when a local function only occurs in tail-call positions,
-- and when it does, more than once.
--
-- Implemented in terms of 'tailCalls': the result is 'True' exactly when
-- 'tailCalls' reports a count that is greater than one.
isJoinPointIn :: Id   -- ^ 'Id' of the local binder
              -> Term -- ^ Expression in which the binder is bound
              -> Bool
isJoinPointIn id_ e = maybe False (> 1) (tailCalls id_ e)

-- | Count the number of (only) tail calls of a function in an expression.
-- 'Nothing' indicates that the function was used in a non-tail call position.
--
-- Per construct:
--
-- * A variable counts as one call when it is the function, zero otherwise.
-- * Lambdas, type lambdas and type applications are looked through.
-- * In an application the argument must not mention the function in any
--   counted position; the count is then that of the function position.
-- * In a case-expression the scrutinee must have a count of zero; the counts
--   of all alternatives are added up.
-- * Let-expressions are described at the corresponding alternative below.
-- * Every other term has a count of zero.
tailCalls :: Id   -- ^ Function to check
          -> Term -- ^ Expression to check it in
          -> Maybe Int
tailCalls id_ = \case
  Var nm
    | id_ == nm -> Just 1
    | otherwise -> Just 0
  Lam _ e   -> tailCalls id_ e
  TyLam _ e -> tailCalls id_ e
  App l r   ->
    case tailCalls id_ r of
      Just 0 -> tailCalls id_ l
      _      -> Nothing
  TyApp l _ -> tailCalls id_ l
  Letrec bs e ->
    -- The function must only be tail-called (if at all) in every binding
    case traverse (tailCalls id_ . snd) bs of
      Nothing -> Nothing
      Just bsCounts
        -- None of the bindings mention the function: only the body matters
        | all (== 0) bsCounts
        -> tailCalls id_ e
        -- Some binding calls the function: this is only accepted when every
        -- let-binder is itself only used in tail-call positions of the body
        | all (isJust . (`tailCalls` e) . fst) bs
        -> (sum bsCounts +) <$> tailCalls id_ e
        | otherwise
        -> Nothing
  Case scrut _ alts ->
    case tailCalls id_ scrut of
      Just 0 -> sum <$> traverse (tailCalls id_ . snd) alts
      _      -> Nothing
  _ -> Just 0

-- | Determines whether a function has the following shape:
--
-- > \(w :: Void) -> f a b c
--
-- i.e. is a wrapper around a (partially) applied function 'f', where the
-- introduced argument 'w' is not used by 'f'
--
-- Concretely: the term must be a lambda whose body, after collecting
-- arguments, has a variable in function position, and the lambda-bound
-- variable must not occur free in that body.
isVoidWrapper :: Term -> Bool
isVoidWrapper (Lam bndr e@(collectArgs -> (Var _, _))) =
  bndr `notElemFreeVars` e
isVoidWrapper _ = False

-- | Inline the first set of binders into the second set of binders and into
-- the body of the original let expression.
--
-- The substitution is built incrementally: every binding that is to be
-- inlined first has the substitution built so far applied to its own
-- right-hand side. When the resulting right-hand side still mentions its own
-- binder, the binding is self-recursive; it is then not added to the
-- substitution but returned separately.
substituteBinders
  :: InScopeSet
  -> [LetBinding]
  -- ^ Let-binders to substitute
  -> [LetBinding]
  -- ^ Let-binders where substitution takes place
  -> Term
  -- ^ Body where substitution takes place
  -> ([LetBinding],([LetBinding],Term))
  -- ^
  -- 1. Let-bindings that we wanted to substitute, but turned out to be recursive
  -- 2.1 Let-binders where substitution took place
  -- 2.2 Body where substitution took place
substituteBinders inScope toInline toKeep body =
  let (subst, toInlRec) = go (mkSubst inScope) [] toInline
  in  ( map (second (substTm "substToInlRec" subst)) toInlRec
      , ( map (second (substTm "substToKeep" subst)) toKeep
        , substTm "substBody" subst body
        )
      )
 where
  go subst inlRec [] = (subst, inlRec)
  go !subst !inlRec ((x, e) : toInl) =
    let -- Right-hand side with all earlier inlinable binders substituted
        e1      = substTm "substInl" subst e
        -- Substitution containing only the current binder ...
        substE  = extendIdSubst (mkSubst inScope) x e1
        -- ... which is applied to the terms already in the substitution
        subst1  = subst { substTmEnv = mapVarEnv (substTm "substSubst" substE)
                                                 (substTmEnv subst) }
        subst2  = extendIdSubst subst1 x e1
    in  if x `elemFreeVars` e1 then
          -- Self-recursive: keep it as a binding, do not extend the
          -- substitution
          go subst ((x, e1) : inlRec) toInl
        else
          go subst2 inlRec toInl

-- | Lift the first set of binders to the level of global bindings, and substitute
-- these lifted bindings into the second set of binders and the body of the
-- original let expression.
--
-- Every binding to lift first has the substitution built so far applied to
-- its right-hand side and is then passed to 'liftBinding'. The right-hand
-- side returned by 'liftBinding' is what gets substituted for the binder.
--
-- Throws a 'ClashException' (an internal error) when the right-hand side
-- returned by 'liftBinding' still mentions the binder it belongs to.
liftAndSubsituteBinders
  :: InScopeSet
  -> [LetBinding]
  -- ^ Let-binders to lift, and substitute the lifted result
  -> [LetBinding]
  -- ^ Let-binders where substitution takes place
  -> Term
  -- ^ Body where substitution takes place
  -> RewriteMonad extra ([LetBinding],Term)
liftAndSubsituteBinders inScope toLift toKeep body = do
  subst <- go (mkSubst inScope) toLift
  pure ( map (second (substTm "liftToKeep" subst)) toKeep
       , substTm "keepBody" subst body
       )
 where
  go subst [] = pure subst
  go !subst ((x, e) : inl) = do
    let e1 = substTm "liftInl" subst e
    (_, e2) <- liftBinding (x, e1)
    let -- Substitution containing only the current binder ...
        substE = extendIdSubst (mkSubst inScope) x e2
        -- ... which is applied to the terms already in the substitution
        subst1 = subst { substTmEnv = mapVarEnv (substTm "liftSubst" substE)
                                                (substTmEnv subst) }
        subst2 = extendIdSubst subst1 x e2
    if x `elemFreeVars` e2 then do
      (_,sp) <- Lens.use curFun
      throw (ClashException sp [I.i|
        Internal error: inlineOrLiftBInders failed on:

        #{showPpr (x,e)}

        creating a self-recursive let-binding:

        #{showPpr (x,e2)}

        given the already built subtitution:

        #{showDoc (clashPretty (substTmEnv subst))}
      |] Nothing)
    else
      go subst2 inl

-- | Test whether the given name is one of the @fromInteger@ primitives of
-- the sized types listed below (@BitVector@, @Index@, @Signed@ and
-- @Unsigned@).
isFromInt :: Text -> Bool
isFromInt nm = nm `elem`
  [ "Clash.Sized.Internal.BitVector.fromInteger##"
  , "Clash.Sized.Internal.BitVector.fromInteger#"
  , "Clash.Sized.Internal.Index.fromInteger#"
  , "Clash.Sized.Internal.Signed.fromInteger#"
  , "Clash.Sized.Internal.Unsigned.fromInteger#"
  ]

-- | Remove the let-bindings that have a certain property from a
-- let-expression, by either inlining them or by lifting them to global
-- bindings.
--
-- Only let-expressions matching the 'Letrec' pattern are rewritten, and only
-- when at least one binding has the property; everything else is returned
-- unchanged. Every rewritten result is passed through 'changed'.
inlineOrLiftBinders
  :: (LetBinding -> RewriteMonad extra Bool)
  -- ^ Property test
  -> (Term -> LetBinding -> Bool)
  -- ^ Test whether to lift or inline
  --
  -- * True: inline
  -- * False: lift
  -> Rewrite extra
inlineOrLiftBinders condition inlineOrLift (TransformContext inScope0 _) e@(Letrec bndrs body) = do
  (toReplace, toKeep) <- partitionM condition bndrs
  case toReplace of
    [] -> return e
    _  -> do
      let inScope1 = extendInScopeSetList inScope0 (map fst bndrs)
      let (toInline, toLift) = partition (inlineOrLift e) toReplace
      -- We first substitute the binders that we can inline into both the
      -- binders that we intend to lift, the other binders, and the body
      let (toLiftExtra, (toReplace1, body1)) =
            substituteBinders inScope1 toInline (toLift ++ toKeep) body
          -- 'substituteBinders' was given @toLift ++ toKeep@, so the first
          -- @length toLift@ results are the binders that are to be lifted
          (toLift1, toKeep1) = splitAt (length toLift) toReplace1
      -- We then substitute the lifted binders in the other binders and the body.
      -- Binders that could not be inlined because they turned out to be
      -- recursive ('toLiftExtra') are lifted as well.
      (toKeep2, body2) <- liftAndSubsituteBinders inScope1
                            (toLiftExtra ++ toLift1)
                            toKeep1 body1
      case toKeep2 of
        -- Nothing left to bind: drop the let altogether
        [] -> changed body2
        _  -> changed (Letrec toKeep2 body2)

inlineOrLiftBinders _ _ _ e = return e

-- | Create a global function for a Let-binding and return a Let-binding where
-- the RHS is a reference to the new global function applied to the free
-- variables of the original RHS
liftBinding :: LetBinding
            -> RewriteMonad extra LetBinding
liftBinding :: (Id, Term) -> RewriteMonad extra (Id, Term)
liftBinding (var :: Id
var@Id {varName :: forall a. Var a -> Name a
varName = Name Term
idName} ,Term
e) = do
  -- Get all local FVs, excluding the 'idName' from the let-binding
  let unitFV :: Var a -> Const (UniqSet TyVar,UniqSet Id) (Var a)
      unitFV :: Var a -> Const (UniqSet TyVar, UniqSet Id) (Var a)
unitFV v :: Var a
v@(Id {})    = (UniqSet TyVar, UniqSet Id)
-> Const (UniqSet TyVar, UniqSet Id) (Var a)
forall k a (b :: k). a -> Const a b
Const (UniqSet TyVar
forall a. UniqSet a
emptyUniqSet,Id -> UniqSet Id
forall a. Uniquable a => a -> UniqSet a
unitUniqSet (Var a -> Id
coerce Var a
v))
      unitFV v :: Var a
v@(TyVar {}) = (UniqSet TyVar, UniqSet Id)
-> Const (UniqSet TyVar, UniqSet Id) (Var a)
forall k a (b :: k). a -> Const a b
Const (TyVar -> UniqSet TyVar
forall a. Uniquable a => a -> UniqSet a
unitUniqSet (Var a -> TyVar
coerce Var a
v),UniqSet Id
forall a. UniqSet a
emptyUniqSet)

      interesting :: Var a -> Bool
      interesting :: Var a -> Bool
interesting Id {idScope :: forall a. Var a -> IdScope
idScope = IdScope
GlobalId} = Bool
False
      interesting v :: Var a
v@(Id {idScope :: forall a. Var a -> IdScope
idScope = IdScope
LocalId}) = Var a -> Int
forall a. Var a -> Int
varUniq Var a
v Int -> Int -> Bool
forall a. Eq a => a -> a -> Bool
/= Id -> Int
forall a. Var a -> Int
varUniq Id
var
      interesting Var a
_ = Bool
True

      (UniqSet TyVar
boundFTVsSet,UniqSet Id
boundFVsSet) =
        Const (UniqSet TyVar, UniqSet Id) (Var Any)
-> (UniqSet TyVar, UniqSet Id)
forall a k (b :: k). Const a b -> a
getConst (Getting
  (Const (UniqSet TyVar, UniqSet Id) (Var Any)) Term (Var Any)
-> (Var Any -> Const (UniqSet TyVar, UniqSet Id) (Var Any))
-> Term
-> Const (UniqSet TyVar, UniqSet Id) (Var Any)
forall r s a. Getting r s a -> (a -> r) -> s -> r
Lens.foldMapOf ((forall b. Var b -> Bool)
-> Getting
     (Const (UniqSet TyVar, UniqSet Id) (Var Any)) Term (Var Any)
forall (f :: Type -> Type) a.
(Contravariant f, Applicative f) =>
(forall b. Var b -> Bool) -> (Var a -> f (Var a)) -> Term -> f Term
termFreeVars' forall b. Var b -> Bool
interesting) Var Any -> Const (UniqSet TyVar, UniqSet Id) (Var Any)
forall a. Var a -> Const (UniqSet TyVar, UniqSet Id) (Var a)
unitFV Term
e)
      boundFTVs :: [TyVar]
boundFTVs = UniqSet TyVar -> [TyVar]
forall a. UniqSet a -> [a]
eltsUniqSet UniqSet TyVar
boundFTVsSet
      boundFVs :: [Id]
boundFVs  = UniqSet Id -> [Id]
forall a. UniqSet a -> [a]
eltsUniqSet UniqSet Id
boundFVsSet

  -- Make a new global ID
  TyConMap
tcm       <- Getting TyConMap RewriteEnv TyConMap -> RewriteMonad extra TyConMap
forall s (m :: Type -> Type) a.
MonadReader s m =>
Getting a s a -> m a
Lens.view Getting TyConMap RewriteEnv TyConMap
Getter RewriteEnv TyConMap
tcCache
  let newBodyTy :: Type
newBodyTy = TyConMap -> Term -> Type
forall a. InferType a => TyConMap -> a -> Type
inferCoreTypeOf TyConMap
tcm (Term -> Type) -> Term -> Type
forall a b. (a -> b) -> a -> b
$ Term -> [TyVar] -> Term
mkTyLams (Term -> [Id] -> Term
mkLams Term
e [Id]
boundFVs) [TyVar]
boundFTVs
  (Id
cf,SrcSpan
sp)   <- Getting (Id, SrcSpan) (RewriteState extra) (Id, SrcSpan)
-> RewriteMonad extra (Id, SrcSpan)
forall s (m :: Type -> Type) a.
MonadState s m =>
Getting a s a -> m a
Lens.use Getting (Id, SrcSpan) (RewriteState extra) (Id, SrcSpan)
forall extra. Lens' (RewriteState extra) (Id, SrcSpan)
curFun
  BindingMap
binders <- Getting BindingMap (RewriteState extra) BindingMap
-> RewriteMonad extra BindingMap
forall s (m :: Type -> Type) a.
MonadState s m =>
Getting a s a -> m a
Lens.use Getting BindingMap (RewriteState extra) BindingMap
forall extra. Lens' (RewriteState extra) BindingMap
bindings
  Name Term
newBodyNm <-
    BindingMap -> Name Term -> RewriteMonad extra (Name Term)
forall (m :: Type -> Type) a.
MonadUnique m =>
BindingMap -> Name a -> m (Name a)
cloneNameWithBindingMap
      BindingMap
binders
      (Name Term -> Text -> Name Term
forall a. Name a -> Text -> Name a
appendToName (Id -> Name Term
forall a. Var a -> Name a
varName Id
cf) (Text
"_" Text -> Text -> Text
`Text.append` Name Term -> Text
forall a. Name a -> Text
nameOcc Name Term
idName))
  let newBodyId :: Id
newBodyId = Type -> Name Term -> Id
mkGlobalId Type
newBodyTy Name Term
newBodyNm {nameSort :: NameSort
nameSort = NameSort
Internal}

  -- Make a new expression, consisting of the the lifted function applied to
  -- its free variables
  let newExpr :: Term
newExpr = Term -> [Term] -> Term
mkTmApps
                  (Term -> [Type] -> Term
mkTyApps (Id -> Term
Var Id
newBodyId)
                            ((TyVar -> Type) -> [TyVar] -> [Type]
forall a b. (a -> b) -> [a] -> [b]
map TyVar -> Type
VarTy [TyVar]
boundFTVs))
                  ((Id -> Term) -> [Id] -> [Term]
forall a b. (a -> b) -> [a] -> [b]
map Id -> Term
Var [Id]
boundFVs)
      inScope0 :: InScopeSet
inScope0 = VarSet -> InScopeSet
mkInScopeSet (UniqSet Id -> VarSet
coerce UniqSet Id
boundFVsSet)
      inScope1 :: InScopeSet
inScope1 = InScopeSet -> [Id] -> InScopeSet
forall a. InScopeSet -> [Var a] -> InScopeSet
extendInScopeSetList InScopeSet
inScope0 [Id
var,Id
newBodyId]
  let subst :: Subst
subst    = Subst -> Id -> Term -> Subst
extendIdSubst (InScopeSet -> Subst
mkSubst InScopeSet
inScope1) Id
var Term
newExpr
      -- Substitute the recursive calls by the new expression
      e' :: Term
e' = HasCallStack => Doc () -> Subst -> Term -> Term
Doc () -> Subst -> Term -> Term
substTm Doc ()
"liftBinding" Subst
subst Term
e
      -- Create a new body that abstracts over the free variables
      newBody :: Term
newBody = Term -> [TyVar] -> Term
mkTyLams (Term -> [Id] -> Term
mkLams Term
e' [Id]
boundFVs) [TyVar]
boundFTVs

  -- Check if an alpha-equivalent global binder already exists
  [Binding Term]
aeqExisting <- (BindingMap -> [Binding Term]
forall a. UniqMap a -> [a]
eltsUniqMap (BindingMap -> [Binding Term])
-> (BindingMap -> BindingMap) -> BindingMap -> [Binding Term]
forall b c a. (b -> c) -> (a -> b) -> a -> c
. (Binding Term -> Bool) -> BindingMap -> BindingMap
forall b. (b -> Bool) -> UniqMap b -> UniqMap b
filterUniqMap ((Term -> Term -> Bool
`aeqTerm` Term
newBody) (Term -> Bool) -> (Binding Term -> Term) -> Binding Term -> Bool
forall b c a. (b -> c) -> (a -> b) -> a -> c
. Binding Term -> Term
forall a. Binding a -> a
bindingTerm)) (BindingMap -> [Binding Term])
-> RewriteMonad extra BindingMap
-> RewriteMonad extra [Binding Term]
forall (f :: Type -> Type) a b. Functor f => (a -> b) -> f a -> f b
<$> Getting BindingMap (RewriteState extra) BindingMap
-> RewriteMonad extra BindingMap
forall s (m :: Type -> Type) a.
MonadState s m =>
Getting a s a -> m a
Lens.use Getting BindingMap (RewriteState extra) BindingMap
forall extra. Lens' (RewriteState extra) BindingMap
bindings
  case [Binding Term]
aeqExisting of
    -- If it doesn't, create a new binder
    [] -> do -- Add the created function to the list of global bindings
             let r :: Bool
r = Id
newBodyId Id -> Term -> Bool
`globalIdOccursIn` Term
newBody
             (BindingMap -> Identity BindingMap)
-> RewriteState extra -> Identity (RewriteState extra)
forall extra. Lens' (RewriteState extra) BindingMap
bindings ((BindingMap -> Identity BindingMap)
 -> RewriteState extra -> Identity (RewriteState extra))
-> (BindingMap -> BindingMap) -> RewriteMonad extra ()
forall s (m :: Type -> Type) a b.
MonadState s m =>
ASetter s s a b -> (a -> b) -> m ()
%= Name Term -> Binding Term -> BindingMap -> BindingMap
forall a b. Uniquable a => a -> b -> UniqMap b -> UniqMap b
extendUniqMap Name Term
newBodyNm
                                    -- We mark this function as internal so that
                                    -- it can be inlined at the very end of
                                    -- the normalisation pipeline as part of the
                                    -- flattening pass. We don't inline
                                    -- right away because we are lifting this
                                    -- function at this moment for a reason!
                                    -- (termination, CSE and DEC opportunities,
                                    -- etc.)
                                    (Id
-> SrcSpan -> InlineSpec -> IsPrim -> Term -> Bool -> Binding Term
forall a.
Id -> SrcSpan -> InlineSpec -> IsPrim -> a -> Bool -> Binding a
Binding Id
newBodyId SrcSpan
sp InlineSpec
NoUserInline IsPrim
IsFun Term
newBody Bool
r)
             -- Return the new binder
             (Id, Term) -> RewriteMonad extra (Id, Term)
forall (m :: Type -> Type) a. Monad m => a -> m a
return (Id
var, Term
newExpr)
    -- If it does, use the existing binder
    (Binding Term
b:[Binding Term]
_) ->
      let newExpr' :: Term
newExpr' = Term -> [Term] -> Term
mkTmApps
                      (Term -> [Type] -> Term
mkTyApps (Id -> Term
Var (Id -> Term) -> Id -> Term
forall a b. (a -> b) -> a -> b
$ Binding Term -> Id
forall a. Binding a -> Id
bindingId Binding Term
b)
                                ((TyVar -> Type) -> [TyVar] -> [Type]
forall a b. (a -> b) -> [a] -> [b]
map TyVar -> Type
VarTy [TyVar]
boundFTVs))
                      ((Id -> Term) -> [Id] -> [Term]
forall a b. (a -> b) -> [a] -> [b]
map Id -> Term
Var [Id]
boundFVs)
      in  (Id, Term) -> RewriteMonad extra (Id, Term)
forall (m :: Type -> Type) a. Monad m => a -> m a
return (Id
var, Term
newExpr')

liftBinding (Id, Term)
_ = String -> RewriteMonad extra (Id, Term)
forall a. HasCallStack => String -> a
error (String -> RewriteMonad extra (Id, Term))
-> String -> RewriteMonad extra (Id, Term)
forall a b. (a -> b) -> a -> b
$ $(String
curLoc) String -> String -> String
forall a. [a] -> [a] -> [a]
++ String
"liftBinding: invalid core, expr bound to tyvar"

-- | Make a global function for a name-term tuple
--
-- The type of the term is inferred, the given name is cloned against the
-- current global bindings, and the resulting binding is registered through
-- 'addGlobalBind'.
mkFunction
  :: TmName
  -- ^ Name of the function
  -> SrcSpan
  -- ^ Source span handed on to 'addGlobalBind'
  -> InlineSpec
  -- ^ Inline specification handed on to 'addGlobalBind'
  -> Term
  -- ^ Term bound to the function
  -> RewriteMonad extra Id
  -- ^ Name with a proper unique and the type of the function
mkFunction bndrNm sp inl body = do
  -- Type of the new global, inferred from the term it is bound to
  bodyTy <- (`inferCoreTypeOf` body) <$> Lens.view tcCache
  -- Pick a name that is not yet used by any global binder
  bodyNm <- Lens.use bindings >>= (`cloneNameWithBindingMap` bndrNm)
  addGlobalBind bodyNm bodyTy sp inl body
  pure (mkGlobalId bodyTy bodyNm)

-- | Add a function to the set of global binders
--
-- A global 'Id' is built from the given name and type, and the binding map is
-- extended under that name with a 'Binding' marked 'IsFun'. The final field of
-- the 'Binding' records whether the new global id occurs in its own body.
addGlobalBind
  :: TmName
  -> Type
  -> SrcSpan
  -> InlineSpec
  -> Term
  -> RewriteMonad extra ()
addGlobalBind vNm ty sp inl body =
  -- The type and body are fully evaluated ('deepseq') as part of extending the
  -- binding map.
  -- NOTE(review): there are no parentheses here, so the grouping depends on
  -- the fixity of `deepseq` relative to '%='; confirm it forces where intended.
  (ty,body) `deepseq` bindings %= extendUniqMap vNm newBinding
 where
  vId          = mkGlobalId ty vNm
  occursInBody = vId `globalIdOccursIn` body
  newBinding   = Binding vId sp inl IsFun body occursInBody

-- | Create a new name out of the given name, but with another unique. Resulting
-- unique is guaranteed to not be in the given InScopeSet.
--
-- A fresh unique is drawn from the monad, set on the name, and the result is
-- then moved away from the in-scope set with 'uniqAway'.
cloneNameWithInScopeSet
  :: (MonadUnique m)
  => InScopeSet
  -> Name a
  -> m (Name a)
cloneNameWithInScopeSet is nm =
  uniqAway is . setUnique nm <$> getUniqueM

-- | Create a new name out of the given name, but with another unique. Resulting
-- unique is guaranteed to not be in the given BindingMap.
--
-- A fresh unique is drawn from the monad and used both as the unique set on
-- the name and as the starting point handed to 'uniqAway''.
cloneNameWithBindingMap
  :: (MonadUnique m)
  => BindingMap
  -> Name a
  -> m (Name a)
cloneNameWithBindingMap binders nm = freshen <$> getUniqueM
 where
  -- A unique is taken when it is already a key of the binding map
  isTaken u = u `elemUniqMapDirectly` binders
  freshen u = uniqAway' isTaken u (setUnique nm u)

{-# INLINE isUntranslatable #-}
-- | Determine if a term cannot be represented in hardware
--
-- The type of the term is inferred and checked with 'representableType',
-- using the type translator and custom representations from the environment;
-- the outcome of that check is negated.
isUntranslatable
  :: Bool
  -- ^ String representable
  -> Term
  -> RewriteMonad extra Bool
isUntranslatable stringRepresentable tm = do
  tcm       <- Lens.view tcCache
  translate <- Lens.view typeTranslator
  reprs     <- Lens.view customReprs
  let tmTy = inferCoreTypeOf tcm tm
  pure (not (representableType translate reprs stringRepresentable tcm tmTy))

{-# INLINE isUntranslatableType #-}
-- | Determine if a type cannot be represented in hardware
--
-- The type is checked with 'representableType', using the type translator,
-- custom representations and type constructor map from the environment; the
-- outcome of that check is negated.
isUntranslatableType
  :: Bool
  -- ^ String representable
  -> Type
  -> RewriteMonad extra Bool
isUntranslatableType stringRepresentable ty = do
  translate <- Lens.view typeTranslator
  reprs     <- Lens.view customReprs
  tcm       <- Lens.view tcCache
  pure (not (representableType translate reprs stringRepresentable tcm ty))

-- | Normalize the types mentioned at the top of a term.
--
-- For a 'Cast' both types are normalized and the term being cast is processed
-- recursively; for a 'Var' the identifier is normalized with 'normalizeId'.
-- Every other term is returned unchanged.
normalizeTermTypes :: TyConMap -> Term -> Term
normalizeTermTypes tcm (Cast inner tyA tyB) =
  Cast (normalizeTermTypes tcm inner)
       (normalizeType tcm tyA)
       (normalizeType tcm tyB)
normalizeTermTypes tcm (Var v) =
  Var (normalizeId tcm v)
-- TODO other terms?
normalizeTermTypes _ other =
  other

-- | Normalize the type stored in an identifier.
--
-- Only variables built with the 'Id' constructor are changed; anything else
-- is returned as is.
normalizeId :: TyConMap -> Id -> Id
normalizeId tcm var = case var of
  Id {} -> var {varType = normalizeType tcm (varType var)}
  _     -> var

-- | Evaluate an expression to weak-head normal form (WHNF), and apply a
-- transformation on the expression in WHNF.
--
-- The evaluator's resulting global heap is written back to the rewrite state,
-- and the rewrite is applied through 'bindPureHeap' together with the pure
-- heap returned by the evaluator.
whnfRW
  :: Bool
  -- ^ Whether the expression we're reducing to WHNF is the subject of a
  -- case expression.
  -> TransformContext
  -- ^ Context of the expression; its in-scope set is given to the evaluator
  -> Term
  -- ^ Expression to evaluate
  -> Rewrite extra
  -- ^ Transformation to apply on the expression in WHNF
  -> RewriteMonad extra Term
whnfRW isSubj ctx@(TransformContext inScope _) term rewrite = do
  tcm     <- Lens.view tcCache
  globals <- Lens.use bindings
  eval    <- Lens.view evaluator
  supply  <- Lens.use uniqSupply
  -- One half of the supply goes to the evaluator, the other half is stored
  -- back in the state.
  let (evalSupply,keptSupply) = splitSupply supply
  uniqSupply Lens..= keptSupply
  primHeap0 <- Lens.use globalHeap

  let evaluated =
        whnf' eval globals tcm primHeap0 evalSupply inScope isSubj term
  case evaluated of
    -- The bang forces the new global heap before it is stored in the state
    (!primHeap1,pureHeap,whnfTerm) -> do
      globalHeap Lens..= primHeap1
      bindPureHeap tcm pureHeap rewrite ctx whnfTerm
{-# SCC whnfRW #-}

-- | Binds variables on the PureHeap over the result of the rewrite
--
-- The rewrite is run with the heap variables added to the in-scope set and a
-- 'LetBody' entry for them pushed onto the context. To prevent unnecessary
-- rewrites the heap is only let-bound when the rewrite changed something and
-- the heap is not empty; otherwise the result of the rewrite is returned as is.
bindPureHeap
  :: TyConMap
  -> PureHeap
  -> Rewrite extra
  -> Rewrite extra
bindPureHeap tcm heap rw outerCtx@(TransformContext inScope0 hist) e = do
  (rewritten, changed) <- second Monoid.getAny <$> Writer.listen (rw innerCtx e)
  if not changed || null heapBinders
    then
      return rewritten
    else do
      -- The evaluator results are post-processed with two operations:
      --
      --   1. Inline work free binders. We've seen cases in the wild† where the
      --      evaluator (or rather, 'bindPureHeap') would let-bind work-free
      --      binders that were crucial for eliminating case constructs. If
      --      these case constructs were used in a self-referential (but
      --      terminating) manner, Clash would get stuck in an infinite loop.
      --      The proper solution would be to use 'isWorkFree', instead of
      --      'isWorkFreeIsh', in 'bindConstantVar' such that these work free
      --      constructs would get inlined again. However, this incurs a great
      --      performance penalty so we opt to prevent the evaluator from
      --      introducing this situation in the first place.
      --
      --      I'd like to stress that this is not a proper solution though, as
      --      GHC might produce a similar situation. We plan on properly solving
      --      this by eliminating the current lift/bind/eval strategy, instead
      --      replacing it by a partial evaluator‡.
      --
      --   2. Remove any unused let-bindings. Similar to (1), we risk Clash
      --      getting stuck in an infinite loop if we don't remove unused
      --      (eliminated by evaluation!) binders.
      --
      -- † https://github.com/clash-lang/clash-compiler/pull/1354#issuecomment-635430374
      -- ‡ https://www.microsoft.com/en-us/research/wp-content/uploads/2016/07/supercomp-by-eval.pdf
      globals <- Lens.use bindings
      inlined <-
        inlineBinders (isWorkFreeBinder globals) outerCtx
                      (Letrec heapBinders rewritten)
      pure $ case inlined of
        -- 'removeUnusedBinders' yields Nothing when every binder is used, in
        -- which case the let-expression is kept as it is.
        Let binds body -> fromMaybe inlined (removeUnusedBinders binds body)
        _              -> inlined
  where
    -- Let-bindings re-created from the entries of the pure heap
    heapBinders = map toLetBinding (toListUniqMap heap)
    heapIds     = map fst heapBinders

    -- Context in which the rewrite itself runs: heap variables are in scope
    -- and the rewrite is positioned in the body of a let that binds them.
    innerCtx =
      TransformContext (extendInScopeSetList inScope0 heapIds)
                       (LetBody heapIds : hist)

    toLetBinding :: (Unique,Term) -> LetBinding
    toLetBinding (uniq,term) =
      ( mkLocalId (inferCoreTypeOf tcm term)
                  (mkUnsafeSystemName "x" uniq) -- See [Note: Name re-creation]
      , term
      )

    -- Inline a binder when its right-hand side, stripped of ticks, is work free
    isWorkFreeBinder globals _ (_, rhs) =
      isWorkFree workFreeBinders globals (stripTicks rhs)

-- | Remove unused binders in given let-binding. Returns /Nothing/ if no unused
-- binders were found.
--
-- For a non-recursive binding the binder is unused when it is not among the
-- free local ids of the body. For a recursive group a binder counts as used
-- when it is reachable from the free local ids of the body, either directly
-- or through the right-hand sides of other used binders.
removeUnusedBinders
  :: Bind Term
  -> Term
  -> Maybe Term
removeUnusedBinders (NonRec i _) body
  | i `elemVarSet` bodyFVs = Nothing
  | otherwise              = Just body
 where
  bodyFVs = Lens.foldMapOf freeLocalIds unitVarSet body

removeUnusedBinders (Rec binds) body
  -- Nothing is used at all: drop the whole let
  | null reachable                   = Just body
  -- Everything is used: report that nothing was removed
  | List.equalLength reachable binds = Nothing
  -- Keep only the used binders
  | otherwise                        = Just (Letrec reachable body)
 where
  reachable = eltsVarEnv used
  bodyFVs   = Lens.foldMapOf freeLocalIds unitVarSet body
  used      = List.foldl' collectUsed emptyVarEnv (eltsVarSet bodyFVs)
  -- Every binder of the group, keyed by its own id
  bindsEnv  = mkVarEnv [ (x,(x,rhs)) | (x,rhs) <- binds ]

  -- Add the binder for @v@ (if it belongs to this group and is not yet
  -- collected) and then everything its right-hand side refers to.
  collectUsed env v
    | v `elemVarEnv` env
    = env
    | Just (x,rhs) <- lookupVarEnv v bindsEnv
    = List.foldl' collectUsed
                  (extendVarEnv x (x,rhs) env)
                  (eltsVarSet (Lens.foldMapOf freeLocalIds unitVarSet rhs))
    | otherwise
    = env