{-|
  Copyright   :  (C) 2012-2016, University of Twente,
                     2016     , Myrtle Software Ltd,
                     2017     , Google Inc.
  License     :  BSD2 (see the file LICENSE)
  Maintainer  :  Christiaan Baaij <christiaan.baaij@gmail.com>

  Turn CoreHW terms into normalized CoreHW Terms
-}

{-# LANGUAGE CPP               #-}
{-# LANGUAGE OverloadedStrings #-}
{-# LANGUAGE TupleSections     #-}
{-# LANGUAGE TemplateHaskell   #-}
{-# LANGUAGE ViewPatterns      #-}
{-# LANGUAGE BangPatterns      #-}

{-# OPTIONS_GHC -Wno-unused-imports #-}

module Clash.Normalize where

import Data.Either

import           Control.Concurrent.Supply        (Supply)
import           Control.Lens                     ((.=),(^.),_1,_4)
import qualified Control.Lens                     as Lens
import           Control.Monad.State.Strict       (State)
import           Data.Binary                      (encode)
import qualified Data.ByteString                  as BS
import qualified Data.ByteString.Lazy             as BL
import           Data.Either                      (partitionEithers)
import qualified Data.IntMap                      as IntMap
import           Data.IntMap.Strict               (IntMap)
import           Data.List
  (groupBy, intersect, mapAccumL, sortBy)
import qualified Data.Map                         as Map
import qualified Data.Maybe                       as Maybe
import qualified Data.Set                         as Set
import qualified Data.Set.Lens                    as Lens
import           Data.Semigroup                   ((<>))
import           Data.Text.Prettyprint.Doc        (vcat)
import           System.IO.Unsafe                 (unsafePerformIO)

import           BasicTypes                       (InlineSpec (..))
import           SrcLoc                           (SrcSpan,noSrcSpan)

import           Clash.Annotations.BitRepresentation.Internal
  (CustomReprs)
import           Clash.Core.Evaluator             (PrimEvaluator)
import           Clash.Core.FreeVars
  (freeLocalIds, globalIds, globalIdOccursIn, localIdDoesNotOccurIn)
import           Clash.Core.Pretty                (showPpr, ppr)
import           Clash.Core.Subst
  (deShadowTerm, extendGblSubstList, mkSubst, substTm)
import           Clash.Core.Term                  (Term (..), collectArgsTicks)
import           Clash.Core.Type                  (Type, splitCoreFunForallTy)
import           Clash.Core.TyCon
  (TyConMap, TyConName)
import           Clash.Core.Util                  (mkApps, mkTicks, termType)
import           Clash.Core.Var                   (Id, varName, varType)
import           Clash.Core.VarEnv
  (VarEnv, elemVarSet, eltsVarEnv, emptyInScopeSet, emptyVarEnv,
   extendVarEnv, lookupVarEnv, mapVarEnv, mapMaybeVarEnv, mkInScopeSet,
   mkVarEnv, mkVarSet, notElemVarEnv, notElemVarSet, nullVarEnv, unionVarEnv)
import           Clash.Driver.Types
  (BindingMap, ClashOpts (..), DebugLevel (..))
import           Clash.Netlist.Types
  (HWType (..), HWMap, FilteredHWType(..))
import           Clash.Netlist.Util
  (splitNormalized, coreTypeToHWType')
import           Clash.Normalize.Strategy
import           Clash.Normalize.Transformations
  (appProp, bindConstantVar, caseCon, flattenLet, reduceConst, topLet,
   reduceNonRepPrim, removeUnusedExpr)
import           Clash.Normalize.Types
import           Clash.Normalize.Util
import           Clash.Primitives.Types           (CompiledPrimMap)
import           Clash.Rewrite.Combinators        ((>->),(!->))
import           Clash.Rewrite.Types
  (RewriteEnv (..), RewriteState (..), bindings, curFun, dbgLevel, extra,
   tcCache, topEntities, typeTranslator, customReprs, RewriteStep (..))
import           Clash.Rewrite.Util
  (apply, isUntranslatableType, runRewrite, runRewriteSession)
import           Clash.Signal.Internal            (ResetKind (..))
import           Clash.Util

-- | Run a NormalizeSession in a given environment
--
-- Builds the 'RewriteEnv' and the initial 'RewriteState' (carrying a fresh
-- 'NormalizeState') from the arguments and executes the session with
-- 'runRewriteSession'.
runNormalization
  :: ClashOpts
  -- ^ Compiler options; supplies the debug level and the specialisation and
  -- inlining limits/strategy stored in the 'NormalizeState'
  -> Supply
  -- ^ UniqueSupply
  -> BindingMap
  -- ^ Global Binders
  -> (CustomReprs -> TyConMap -> Type ->
      State HWMap (Maybe (Either String FilteredHWType)))
  -- ^ Hardcoded Type -> HWType translator
  -> CustomReprs
  -- ^ Custom bit representations
  -> TyConMap
  -- ^ TyCon cache
  -> IntMap TyConName
  -- ^ Tuple TyCon cache
  -> PrimEvaluator
  -- ^ Hardcoded evaluator (delta-reduction)
  -> CompiledPrimMap
  -- ^ Primitive Definitions
  -> VarEnv Bool
  -- ^ Map telling whether a component is part of a recursive group
  -> [Id]
  -- ^ topEntities
  -> NormalizeSession a
  -- ^ NormalizeSession to run
  -> a
runNormalization opts supply globals typeTrans reprs tcm tupTcm eval primMap
                 rcsMap topEnts
  = runRewriteSession rwEnv rwState
  where
    -- Constructor arguments are positional; their order must match the
    -- RewriteEnv / RewriteState / NormalizeState constructors.
    rwEnv :: RewriteEnv
    rwEnv = RewriteEnv
      (opt_dbgLevel opts)
      typeTrans
      tcm
      tupTcm
      eval
      (mkVarSet topEnts)
      reprs

    rwState :: RewriteState NormalizeState
    rwState = RewriteState
      0                   -- Int counter, starts at zero
      globals
      supply
      -- placeholder for the current function; forcing it before it has been
      -- set is a compiler bug
      (error $ $(curLoc) ++ "Report as bug: no curFun", noSrcSpan)
      0                   -- Int counter, starts at zero
      (IntMap.empty, 0)   -- empty global heap
      normState

    normState :: NormalizeState
    normState = NormalizeState
      emptyVarEnv         -- BindingMap, initially empty
      Map.empty           -- Map (Id, Int, Either Term Type) Id
      emptyVarEnv         -- VarEnv Int
      (opt_specLimit opts)
      emptyVarEnv         -- VarEnv (VarEnv Int)
      (opt_inlineLimit opts)
      (opt_inlineFunctionLimit opts)
      (opt_inlineConstantLimit opts)
      primMap
      Map.empty           -- Map Text (Set Int)
      rcsMap
      (opt_newInlineStrat opts)
      (opt_ultra opts)


-- | Normalize the given global binders and, transitively, the global binders
-- they reference.
--
-- Each round runs 'normalize'' on every binder in the list, then recurses on
-- the binders those calls reported as still needing normalization. Recursion
-- stops at an empty list. The bindings produced by all rounds are combined
-- into a single 'BindingMap'.
normalize
  :: [Id]
  -> NormalizeSession BindingMap
normalize []  = return emptyVarEnv
normalize top = do
  results <- mapM normalize' top
  let (pending, doneHere) = unzip results
  doneLater <- normalize (concat pending)
  return (unionVarEnv (mkVarEnv doneHere) doneLater)

-- | Normalize a single global binder.
--
-- The binder is looked up in the global 'bindings'. It is an error for it to
-- be missing.
--
-- When its result type is translatable, the binding is normalized with
-- 'normalizeTopLvlBndr'. A trace is emitted if the normalized body still
-- refers to the binder itself.
--
-- Otherwise the binding is returned unchanged. A trace is emitted when the
-- debug level is at least 'DebugFinal'.
--
-- Besides the (possibly normalized) binding, the result lists the global
-- binders used by its body that still need normalizing. These are all used
-- binders except the binder itself, binders already in the normalized cache,
-- and top entities.
normalize'
  :: Id
  -> NormalizeSession ([Id],(Id,(Id,SrcSpan,InlineSpec,Term)))
normalize' nm = do
  exprM <- lookupVarEnv nm <$> Lens.use bindings
  case exprM of
    Nothing ->
      error $ $(curLoc) ++ "Expr belonging to bndr: " ++ nmS ++ " not found"
    Just (nm',sp,inl,tm) -> do
      tcm <- Lens.view tcCache
      let (_,resTy) = splitCoreFunForallTy tcm (varType nm')
      resTyRep <- not <$> isUntranslatableType False resTy
      if resTyRep
        then do
          tmNorm <- normalizeTopLvlBndr nm (nm',sp,inl,tm)
          let usedBndrs = Lens.toListOf globalIds (tmNorm ^. _4)
          -- The normalized body still mentions its own binder: report it.
          traceIf (nm `elem` usedBndrs)
                  (concat [ $(curLoc),"Expr belonging to bndr: ",nmS ," (:: "
                          , showPpr (varType (tmNorm ^. _1))
                          , ") remains recursive after normalization:\n"
                          , showPpr (tmNorm ^. _4) ])
                  (return ())
          toNormalize <- pendingBndrs usedBndrs
          return (toNormalize,(nm,tmNorm))
        else do
          toNormalize <- pendingBndrs (Lens.toListOf globalIds tm)
          lvl <- Lens.view dbgLevel
          traceIf (lvl >= DebugFinal)
                  (concat [$(curLoc), "Expr belonging to bndr: ", nmS, " (:: "
                          , showPpr (varType nm')
                          , ") has a non-representable return type."
                          , " Not normalising:\n", showPpr tm] )
                  (return (toNormalize,(nm,(nm',sp,inl,tm))))
  where
    nmS :: String
    nmS = showPpr (varName nm)

    -- Reduce the global binders used by nm's body to the ones that still
    -- need normalizing (shared by both branches above).
    pendingBndrs :: [Id] -> NormalizeSession [Id]
    pendingBndrs usedBndrs = do
      prevNorm <- mapVarEnv (Lens.view _1) <$> Lens.use (extra.normalized)
      topEnts  <- Lens.view topEntities
      -- nm itself counts as done so that self-references are not queued
      let done = extendVarEnv nm nm prevNorm
      return [ b | b <- usedBndrs
                 , b `notElemVarEnv` done
                 , b `notElemVarSet` topEnts ]

-- | Check whether the normalized bindings are non-recursive. Errors when one
-- of the components is recursive.
--
-- Only direct self-reference is detected: a binding whose own binder occurs
-- as a global identifier in its body. When there is no such binding, the input
-- map is returned unchanged. Otherwise the error lists each offending binder
-- together with its body.
checkNonRecursive
  :: BindingMap
  -- ^ List of normalized binders
  -> BindingMap
checkNonRecursive norm = case mapMaybeVarEnv selfRecursive norm of
  rcs | nullVarEnv rcs -> norm
      | otherwise      -> error $ $(curLoc)
          ++ "Callgraph after normalisation contains following recursive components: "
          -- separate binder and body so the two do not run together
          ++ show (vcat [ ppr a <> " = " <> ppr b | (a,b) <- eltsVarEnv rcs ])
 where
  selfRecursive (nm,_,_,tm)
    | nm `globalIdOccursIn` tm = Just (nm,tm)
    | otherwise                = Nothing


-- | Perform general \"clean up\" of the normalized (non-recursive) function
-- hierarchy. This includes:
--
--   * Inlining functions that simply \"wrap\" another function
--
-- The call tree rooted at the top entity is built with 'mkCallTree',
-- flattened with 'flattenCallTree', and converted back into a 'BindingMap'.
-- The result is rebuilt solely from that tree, so binders not reachable from
-- the top entity are presumably dropped (depends on 'callTreeToList'; verify).
-- When the top entity has no entry in the map, the map is returned unchanged.
cleanupGraph
  :: Id
  -> BindingMap
  -> NormalizeSession BindingMap
cleanupGraph topEntity norm = case mkCallTree [] norm topEntity of
  Nothing -> return norm
  Just ct -> do
    ctFlat <- flattenCallTree ct
    return (mkVarEnv (snd (callTreeToList [] ctFlat)))

-- | A tree of global binders following their use relation. Every node pairs
-- a binder with its entry from the 'BindingMap'.
data CallTree
  = CLeaf (Id,(Id,SrcSpan,InlineSpec,Term))
  -- ^ A binder with no sub-trees
  | CBranch (Id,(Id,SrcSpan,InlineSpec,Term)) [CallTree]
  -- ^ A binder together with sub-trees for (some of) the global binders its
  -- body uses

-- | Build the call tree rooted at a global binder.
--
-- Returns 'Nothing' when the root has no entry in the binding map. Used
-- global binders without an entry in the map are left out of the tree.
--
-- Binders on the path from the top (the @visited@ list), and the root itself,
-- are not expanded again, so recursive references terminate. A root whose body
-- uses no global binders becomes a 'CLeaf'; any other root becomes a
-- 'CBranch'.
mkCallTree
  :: [Id]
  -- ^ Visited
  -> BindingMap
  -- ^ Global binders
  -> Id
  -- ^ Root of the call graph
  -> Maybe CallTree
mkCallTree visited bindingMap root = do
  rootTm <- lookupVarEnv root bindingMap
  let used     = Set.toList (Lens.setOf globalIds (rootTm ^. _4))
      visited' = root : visited
      -- Exclude the root as well as its ancestors, so that a self-reference
      -- does not get expanded into a spurious child node.
      children = Maybe.mapMaybe (mkCallTree visited' bindingMap)
                                (filter (`notElem` visited') used)
  return $ case used of
    [] -> CLeaf   (root,rootTm)
    _  -> CBranch (root,rootTm) children

-- | Match leading arguments against a list of binders.
--
-- @stripArgs allIds ids args@ succeeds when @args@ starts with exactly the
-- term variables @ids@, in order. It then returns the remaining arguments,
-- provided none of them has a free local identifier that is in @allIds@.
-- Type arguments are never considered to mention a binder.
--
-- Returns 'Nothing' in any of these cases:
--
--   * an argument does not match its binder
--   * there are fewer arguments than binders
--   * a remaining term argument mentions one of @allIds@
stripArgs
  :: [Id]
  -> [Id]
  -> [Either Term Type]
  -> Maybe [Either Term Type]
stripArgs allIds [] args
  | any mentionsId args = Nothing
  | otherwise           = Just args
  where
    -- Stops at the first free local id found in allIds instead of building
    -- the full intersection.
    mentionsId = either (Lens.anyOf freeLocalIds (`elem` allIds)) (const False)
stripArgs allIds (id_:ids) (Left (Var nm):args)
  | id_ == nm = stripArgs allIds ids args
stripArgs _ _ _ = Nothing

-- | Decide how an (already flattened) child node is treated when its parent
-- is flattened by 'flattenCallTree'.
--
-- * @'Left' node@: the child stays a separate component. This happens for
--   children marked NOINLINE, for top entities, and for branches that are not
--   single-binding normal forms and are not cheap (per 'isCheapFunction'),
--   unless the new inline strategy is enabled.
-- * @'Right' ((nm, term), children)@: @term@ is substituted for @nm@ in the
--   parent, and @children@ are adopted by the parent. When the child's term
--   is a single let-binding around an application whose trailing arguments
--   are exactly its lambda binders, the eta-reduced application is used as
--   @term@.
--
-- Leaves that are neither NOINLINE nor top entities are always inlined.
flattenNode
  :: CallTree
  -> NormalizeSession (Either CallTree ((Id,Term),[CallTree]))
flattenNode node = case node of
  CLeaf (_,(_,_,NoInline,_)) -> return (Left node)
  CLeaf (nm,(_,_,_,e)) -> unlessTopEntity nm $ do
    tcm <- Lens.view tcCache
    case etaReduced tcm e of
      Just e' -> return (Right ((nm,e'),[]))
      Nothing -> return (Right ((nm,e),[]))
  CBranch (_,(_,_,NoInline,_)) _ -> return (Left node)
  CBranch (nm,(_,_,_,e)) us -> unlessTopEntity nm $ do
    tcm <- Lens.view tcCache
    case etaReduced tcm e of
      Just e' -> return (Right ((nm,e'),us))
      Nothing -> do
        newInlineStrat <- Lens.use (extra.newInlineStrategy)
        if newInlineStrat || isCheapFunction e
          then return (Right ((nm,e),us))
          else return (Left node)
  where
    -- Top entities are never inlined: keep the node as it is and only run
    -- the given action for other binders.
    unlessTopEntity
      :: Id
      -> NormalizeSession (Either CallTree r)
      -> NormalizeSession (Either CallTree r)
    unlessTopEntity nm action = do
      isTopEntity <- elemVarSet nm <$> Lens.view topEntities
      if isTopEntity then return (Left node) else action

    -- 'Just' the term to inline when 'splitNormalized' yields exactly one
    -- let-binding @bId = bExpr@. That term is the application in @bExpr@ with
    -- its trailing lambda-bound arguments stripped, when 'stripArgs' succeeds
    -- and @bId@ does not occur in @bExpr@; otherwise it is @e@ itself.
    -- 'Nothing' when @e@ does not have that single-binding shape.
    -- NOTE(review): the third component of 'splitNormalized' (presumably the
    -- result variable) is not checked here.
    etaReduced tcm e = case splitNormalized tcm e of
      Right (ids,[(bId,bExpr)],_) ->
        let (fun,args,ticks) = collectArgsTicks bExpr
        in  case stripArgs ids (reverse ids) (reverse args) of
              Just remainder | bId `localIdDoesNotOccurIn` bExpr ->
                Just (mkApps (mkTicks fun ticks) (reverse remainder))
              _ -> Just e
      _ -> Nothing

-- | Flatten a 'CallTree' bottom-up.
--
-- Leaves are returned unchanged. For a branch, all children are flattened
-- first. 'flattenNode' then decides per child whether it stays a separate
-- child ('Left') or whether its term is substituted into the branch's term
-- and its own children are adopted ('Right'). If anything was substituted,
-- the branch's term is simplified with @flatten@. Finally, if the branch is
-- not NOINLINE and its term is still cheap (per 'isCheapFunction'), every
-- non-NOINLINE child is substituted as well and the term is simplified again.
flattenCallTree
  :: CallTree
  -> NormalizeSession CallTree
flattenCallTree leaf@(CLeaf _) = return leaf
flattenCallTree (CBranch (nm,(nm',sp,inl,tm)) used) = do
  children       <- mapM flattenCallTree used
  (kept,inlined) <- partitionEithers <$> mapM flattenNode children
  let (toInline,adopted) = unzip inlined
  newExpr <-
    if null toInline
      then return tm
      else do
        let subst = extendGblSubstList (mkSubst emptyInScopeSet) toInline
            -- To have a cheap `appProp` transformation we need to
            -- deshadow, see also Note [AppProp no-shadow invariant]
            tm1   = deShadowTerm emptyInScopeSet
                      (substTm "flattenCallTree.flattenExpr" subst tm)
#ifdef HISTORY
        -- NB: When HISTORY is on, emit binary data holding the recorded rewrite steps
        let !_ = unsafePerformIO
               $ BS.appendFile "history.dat"
               $ BL.toStrict
               $ encode RewriteStep
                   { t_ctx    = []
                   , t_name   = "INLINE"
                   , t_bndrS  = showPpr (varName nm')
                   , t_before = tm
                   , t_after  = tm1
                   }
#endif
        rewriteExpr ("flattenExpr",flatten) (showPpr nm, tm1) (nm', sp)
  let allUsed = kept ++ concat adopted
  -- Inline all components when the resulting expression after flattening
  -- is still considered "cheap". This happens often at the topEntity, which
  -- wraps other functions and has some selectors and data-constructors.
  if inl == NoInline || not (isCheapFunction newExpr)
    then return (CBranch (nm,(nm',sp,inl,newExpr)) allUsed)
    else do
      let (cheapInline,remaining) = unzip (map goCheap allUsed)
          subst' = extendGblSubstList (mkSubst emptyInScopeSet)
                                      (Maybe.catMaybes cheapInline)
          -- To have a cheap `appProp` transformation we need to
          -- deshadow, see also Note [AppProp no-shadow invariant]
          tm1    = deShadowTerm emptyInScopeSet
                     (substTm "flattenCallTree.flattenCheap" subst' newExpr)
      newExpr' <- rewriteExpr ("flattenCheap",flatten) (showPpr nm, tm1) (nm', sp)
      return (CBranch (nm,(nm',sp,inl,newExpr')) (concat remaining))
  where
    -- Simplification applied after substituting children into a term.
    flatten :: NormRewrite
    flatten =
      innerMost (apply "appProp"          appProp >->
                 apply "bindConstantVar"  bindConstantVar >->
                 apply "caseCon"          caseCon >->
                 apply "reduceConst"      reduceConst >->
                 apply "reduceNonRepPrim" reduceNonRepPrim >->
                 apply "removeUnusedExpr" removeUnusedExpr >->
                 apply "flattenLet"       flattenLet) !->
      topdownSucR (apply "topLet" topLet)

    -- A NOINLINE child stays a child. Any other child yields its binding to
    -- substitute, and (for a branch) its own children are adopted.
    goCheap :: CallTree -> (Maybe (Id,Term),[CallTree])
    goCheap c = case c of
      CLeaf (nm2,(_,_,inl2,e))
        | inl2 == NoInline -> (Nothing, [c])
        | otherwise        -> (Just (nm2,e), [])
      CBranch (nm2,(_,_,inl2,e)) us
        | inl2 == NoInline -> (Nothing, [c])
        | otherwise        -> (Just (nm2,e), us)

-- | Collect the nodes of a 'CallTree' in pre-order: a node comes before its
-- children, and children are visited left to right.
--
-- A node whose root identifier is already in the visited list is skipped,
-- together with its entire subtree. The returned identifier list is the
-- input visited list with every newly collected root prepended.
callTreeToList
  :: [Id]
  -> CallTree
  -> ([Id],[(Id,(Id,SrcSpan,InlineSpec,Term))])
callTreeToList visited tree
  | nm `elem` visited = (visited, [])
  | otherwise = case tree of
      CLeaf _ -> (nm:visited, [(nm,bndr)])
      CBranch _ used ->
        let (visited',others) = mapAccumL callTreeToList (nm:visited) used
        in  (visited', (nm,bndr) : concat others)
  where
    (nm,bndr) = case tree of
      CLeaf root     -> root
      CBranch root _ -> root