{-# LANGUAGE CPP #-}
{-# LANGUAGE OverloadedStrings #-}
{-# LANGUAGE TupleSections #-}
{-# LANGUAGE TemplateHaskell #-}
{-# LANGUAGE ViewPatterns #-}
{-# LANGUAGE BangPatterns #-}
{-# OPTIONS_GHC -Wno-unused-imports #-}
module Clash.Normalize where
import Data.Either
import Control.Concurrent.Supply (Supply)
import Control.Lens ((.=),(^.),_1,_4)
import qualified Control.Lens as Lens
import Control.Monad.State.Strict (State)
import Data.Binary (encode)
import qualified Data.ByteString as BS
import qualified Data.ByteString.Lazy as BL
import Data.Either (partitionEithers)
import qualified Data.IntMap as IntMap
import Data.IntMap.Strict (IntMap)
import Data.List
(groupBy, intersect, mapAccumL, sortBy)
import qualified Data.Map as Map
import qualified Data.Maybe as Maybe
import qualified Data.Set as Set
import qualified Data.Set.Lens as Lens
import Data.Semigroup ((<>))
import Data.Text.Prettyprint.Doc (vcat)
import System.IO.Unsafe (unsafePerformIO)
import BasicTypes (InlineSpec (..))
import SrcLoc (SrcSpan,noSrcSpan)
import Clash.Annotations.BitRepresentation.Internal
(CustomReprs)
import Clash.Core.Evaluator (PrimEvaluator)
import Clash.Core.FreeVars
(freeLocalIds, globalIds, globalIdOccursIn, localIdDoesNotOccurIn)
import Clash.Core.Pretty (showPpr, ppr)
import Clash.Core.Subst
(deShadowTerm, extendGblSubstList, mkSubst, substTm)
import Clash.Core.Term (Term (..), collectArgsTicks)
import Clash.Core.Type (Type, splitCoreFunForallTy)
import Clash.Core.TyCon
(TyConMap, TyConName)
import Clash.Core.Util (mkApps, mkTicks, termType)
import Clash.Core.Var (Id, varName, varType)
import Clash.Core.VarEnv
(VarEnv, elemVarSet, eltsVarEnv, emptyInScopeSet, emptyVarEnv,
extendVarEnv, lookupVarEnv, mapVarEnv, mapMaybeVarEnv, mkInScopeSet,
mkVarEnv, mkVarSet, notElemVarEnv, notElemVarSet, nullVarEnv, unionVarEnv)
import Clash.Driver.Types
(BindingMap, ClashOpts (..), DebugLevel (..))
import Clash.Netlist.Types
(HWType (..), HWMap, FilteredHWType(..))
import Clash.Netlist.Util
(splitNormalized, coreTypeToHWType')
import Clash.Normalize.Strategy
import Clash.Normalize.Transformations
(appProp, bindConstantVar, caseCon, flattenLet, reduceConst, topLet,
reduceNonRepPrim, removeUnusedExpr)
import Clash.Normalize.Types
import Clash.Normalize.Util
import Clash.Primitives.Types (CompiledPrimMap)
import Clash.Rewrite.Combinators ((>->),(!->))
import Clash.Rewrite.Types
(RewriteEnv (..), RewriteState (..), bindings, curFun, dbgLevel, extra,
tcCache, topEntities, typeTranslator, customReprs, RewriteStep (..))
import Clash.Rewrite.Util
(apply, isUntranslatableType, runRewrite, runRewriteSession)
import Clash.Signal.Internal (ResetKind (..))
import Clash.Util
-- | Run a normalisation session.
--
-- Assembles the initial 'RewriteEnv' and 'RewriteState' (whose @extra@
-- component is a freshly initialised 'NormalizeState') from the compiler
-- options and the supplied tables, then hands both to 'runRewriteSession'.
-- The current function is initialised to an error thunk, so reading it
-- before any binder has been selected reports a bug.
runNormalization
  :: ClashOpts
  -- ^ Options; the debug level and the specialisation/inlining limits and
  -- flags are read from here
  -> Supply
  -- ^ Unique supply, stored in the initial rewrite state
  -> BindingMap
  -- ^ Global binders, stored in the initial rewrite state
  -> (CustomReprs -> TyConMap -> Type ->
      State HWMap (Maybe (Either String FilteredHWType)))
  -- ^ Type translator, stored in the rewrite environment
  -> CustomReprs
  -- ^ Custom bit representations, stored in the rewrite environment
  -> TyConMap
  -- ^ Type constructor map, stored in the rewrite environment
  -> IntMap TyConName
  -- ^ Presumably the tuple type constructors (bound as @tupTcm@); passed to
  -- the rewrite environment unchanged
  -> PrimEvaluator
  -- ^ Evaluator, stored in the rewrite environment
  -> CompiledPrimMap
  -- ^ Primitive map, stored in the initial 'NormalizeState'
  -> VarEnv Bool
  -- ^ Passed unchanged into the initial 'NormalizeState'
  -> [Id]
  -- ^ Top entities; converted to a set for the rewrite environment
  -> NormalizeSession a
  -- ^ The session to run
  -> a
runNormalization opts supply globals typeTrans reprs tcm tupTcm eval primMap rcsMap topEnts =
  let env =
        RewriteEnv
          (opt_dbgLevel opts)
          typeTrans
          tcm
          tupTcm
          eval
          (mkVarSet topEnts)
          reprs
      normSt =
        NormalizeState
          emptyVarEnv
          Map.empty
          emptyVarEnv
          (opt_specLimit opts)
          emptyVarEnv
          (opt_inlineLimit opts)
          (opt_inlineFunctionLimit opts)
          (opt_inlineConstantLimit opts)
          primMap
          Map.empty
          rcsMap
          (opt_newInlineStrat opts)
          (opt_ultra opts)
      st =
        RewriteState
          0
          globals
          supply
          (error $ $(curLoc) ++ "Report as bug: no curFun", noSrcSpan)
          0
          (IntMap.empty, 0)
          normSt
  in  runRewriteSession env st
-- | Normalise the given binders and, transitively, every global binder they
-- reference. Returns all resulting bindings.
--
-- The work proceeds in rounds. Each round runs @normalize'@ on the current
-- work list, then recurses on the binders those calls reported as still
-- needing normalisation. Each round's results are merged with the results
-- of the deeper rounds using 'unionVarEnv'.
--
-- A binder that was already processed, in an earlier round or earlier in
-- the same round, is skipped. @normalize'@ only filters out binders that
-- are already in the @normalized@ cache (plus itself and the top entities).
-- Its non-representable-result branch never adds anything to that cache.
-- Without this check, a cycle of such binders (A uses B, B uses A) kept
-- re-adding each other to the work list and the recursion never
-- terminated. Duplicate entries in a work list were also reprocessed
-- needlessly.
normalize
  :: [Id]
  -> NormalizeSession BindingMap
normalize = go emptyVarEnv
  where
    go :: VarEnv Id -> [Id] -> NormalizeSession BindingMap
    go seen ids = case dedup seen ids of
      (_, []) -> return emptyVarEnv
      (seen', todo) -> do
        (new, topNormalized) <- unzip <$> mapM normalize' todo
        newNormalized <- go seen' (concat new)
        return (unionVarEnv (mkVarEnv topNormalized) newNormalized)

    -- Drop binders that were already visited. Keep the first occurrence of
    -- each new binder, in its original order, and mark it as visited.
    dedup :: VarEnv Id -> [Id] -> (VarEnv Id, [Id])
    dedup seen = fmap Maybe.catMaybes . mapAccumL visit seen

    visit :: VarEnv Id -> Id -> (VarEnv Id, Maybe Id)
    visit seen i
      | i `notElemVarEnv` seen = (extendVarEnv i i seen, Just i)
      | otherwise              = (seen, Nothing)
-- | Process a single global binder.
--
-- Looks the binder up in 'bindings' and calls 'error' if it is absent.
--
-- * If the result type of the binder (after splitting off all arrows and
--   foralls) is translatable, the binding is normalised with
--   'normalizeTopLvlBndr'. A trace is emitted when the normalised term
--   still refers to the binder itself.
-- * Otherwise the binding is returned as-is. A trace is emitted when the
--   debug level is at least 'DebugFinal'.
--
-- In both cases the first component of the result lists the global binders
-- used by the (normalised) term that still need processing. These are the
-- used binders that are not the binder itself, not already in the
-- @normalized@ cache, and not top entities.
normalize'
  :: Id
  -> NormalizeSession ([Id],(Id,(Id,SrcSpan,InlineSpec,Term)))
normalize' nm = do
  bndrs <- Lens.use bindings
  case lookupVarEnv nm bndrs of
    Nothing ->
      error $ $(curLoc) ++ "Expr belonging to bndr: " ++ nmS ++ " not found"
    Just binding@(nm',sp,inl,tm) -> do
      tcm <- Lens.view tcCache
      let resTy = snd (splitCoreFunForallTy tcm (varType nm'))
      nonRep <- isUntranslatableType False resTy
      if nonRep
        then do
          -- Untranslatable result type: keep the original term untouched.
          toNormalize <- pendingFrom (Lens.toListOf globalIds tm)
          lvl <- Lens.view dbgLevel
          traceIf (lvl >= DebugFinal)
                  (concat [$(curLoc), "Expr belonging to bndr: ", nmS, " (:: "
                          , showPpr (varType nm')
                          , ") has a non-representable return type."
                          , " Not normalising:\n", showPpr tm] )
                  (return (toNormalize,(nm,binding)))
        else do
          tmNorm <- normalizeTopLvlBndr nm binding
          let usedBndrs = Lens.toListOf globalIds (tmNorm ^. _4)
          traceIf (nm `elem` usedBndrs)
                  (concat [ $(curLoc),"Expr belonging to bndr: ",nmS ," (:: "
                          , showPpr (varType (tmNorm ^. _1))
                          , ") remains recursive after normalization:\n"
                          , showPpr (tmNorm ^. _4) ])
                  (return ())
          toNormalize <- pendingFrom usedBndrs
          return (toNormalize,(nm,tmNorm))
  where
    nmS = showPpr (varName nm)

    -- Keep the used binders that are neither this binder, nor already
    -- normalised, nor a top entity.
    pendingFrom :: [Id] -> NormalizeSession [Id]
    pendingFrom used = do
      prevNorm <- mapVarEnv (Lens.view _1) <$> Lens.use (extra.normalized)
      topEnts <- Lens.view topEntities
      let excluded = extendVarEnv nm nm prevNorm
      return [ u | u <- used
                 , u `notElemVarEnv` excluded
                 , u `notElemVarSet` topEnts ]
-- | Check that no binder refers to itself in its own term.
--
-- Returns the map unchanged when no binder does. Otherwise calls 'error',
-- listing every self-referencing binder together with its term.
checkNonRecursive
  :: BindingMap
  -> BindingMap
checkNonRecursive norm
  | nullVarEnv recursive = norm
  | otherwise =
      error $ $(curLoc) ++ "Callgraph after normalisation contains following recursive components: "
           ++ show (vcat [ ppr bndr <> ppr body
                         | (bndr,body) <- eltsVarEnv recursive
                         ])
  where
    recursive = mapMaybeVarEnv selfRef norm

    -- Keep the binder and its term only when the binder occurs in the term.
    selfRef (nm,_,_,tm)
      | nm `globalIdOccursIn` tm = Just (nm,tm)
      | otherwise = Nothing
-- | Flatten the call graph reachable from the given top entity.
--
-- Builds the call tree with 'mkCallTree', flattens it with
-- 'flattenCallTree', and returns the bindings that remain in the flattened
-- tree. Each binder is listed once (see 'callTreeToList'). When the top
-- entity has no binding in the map, the map is returned unchanged.
cleanupGraph
  :: Id
  -> BindingMap
  -> NormalizeSession BindingMap
cleanupGraph topEntity norm =
  case mkCallTree [] norm topEntity of
    Nothing -> return norm
    Just ct -> do
      flat <- flattenCallTree ct
      let (_, bndrs) = callTreeToList [] flat
      return (mkVarEnv bndrs)
-- | A tree of global binders rooted at some binder.
--
-- Every node pairs the binder name with its (binder, source span, inline
-- spec, term) entry from the 'BindingMap'.
data CallTree
  = CLeaf (Id,(Id,SrcSpan,InlineSpec,Term))
  -- ^ A node without children; 'mkCallTree' builds one for a binder whose
  -- term references no global binders
  | CBranch (Id,(Id,SrcSpan,InlineSpec,Term)) [CallTree]
  -- ^ A node together with the call trees of the binders it uses (the list
  -- may be empty)
-- | Build the call tree rooted at @root@.
--
-- Returns 'Nothing' when @root@ has no binding in the map. The children of
-- a node are built from the global binders referenced by its term that are
-- not in @visited@. Referenced binders without a binding are silently
-- dropped. A node whose term references no global binders at all becomes a
-- 'CLeaf'; every other node becomes a 'CBranch'.
--
-- NOTE(review): the children are filtered against @visited@ before @root@
-- is added to it. A self-referencing root therefore appears once more as
-- its own child, and that child has no further self-child.
mkCallTree
  :: [Id]
  -- ^ Binders already on the path from the overall root
  -> BindingMap
  -- ^ Global binders
  -> Id
  -- ^ Root of the call tree
  -> Maybe CallTree
mkCallTree visited bindingMap root = mkNode <$> lookupVarEnv root bindingMap
  where
    mkNode rootTm
      | null used = CLeaf (root,rootTm)
      | otherwise = CBranch (root,rootTm) children
      where
        used = Set.toList (Lens.setOf globalIds (rootTm ^. _4))
        unvisited = [ u | u <- used, u `notElem` visited ]
        children = Maybe.mapMaybe (mkCallTree (root:visited) bindingMap) unvisited
-- | @stripArgs allIds ids args@ removes a prefix of @args@ that consists of
-- exactly @Left (Var i)@ for each @i@ in @ids@, in order.
--
-- It succeeds, returning the remaining arguments, only when every id was
-- matched and no remaining term argument mentions any of @allIds@ as a free
-- local id. Type arguments are never inspected. In 'flattenNode' it is
-- called with @ids@ and @args@ reversed, so it effectively matches trailing
-- arguments.
stripArgs
  :: [Id]
  -> [Id]
  -> [Either Term Type]
  -> Maybe [Either Term Type]
stripArgs allIds = go
  where
    go :: [Id] -> [Either Term Type] -> Maybe [Either Term Type]
    go [] rest
      | any mentionsId rest = Nothing
      | otherwise = Just rest
    go (id_:ids) (Left (Var nm):rest)
      | id_ == nm = go ids rest
    -- Covers: ids left but no arguments, a mismatching variable, or a
    -- non-variable / type argument where a variable was expected.
    go _ _ = Nothing

    mentionsId arg = case arg of
      Left tm -> any (`elem` allIds) (Lens.toListOf freeLocalIds tm)
      Right _ -> False
-- | Decide whether a (child) call-tree node may be inlined into its parent.
--
-- 'Left' means the node is kept as a separate child. 'Right' carries the
-- binder name, the term to substitute for it, and the node's own children,
-- which the parent adopts.
--
-- * Nodes marked 'NoInline' and top entities are always kept.
-- * Otherwise, when 'splitNormalized' yields a single let-binding whose
--   application ends in exactly the lambda-bound ids (see 'stripArgs'), and
--   the binding does not refer to itself, those trailing arguments are
--   dropped (effectively an eta-reduction) and the result is inlined.
--   A term of that shape that cannot be reduced is inlined unchanged.
-- * For any other shape, a leaf is still inlined unchanged. A branch is
--   inlined only when the new inline strategy is enabled or the term is
--   cheap according to 'isCheapFunction'.
flattenNode
  :: CallTree
  -> NormalizeSession (Either CallTree ((Id,Term),[CallTree]))
flattenNode node = case node of
  CLeaf (_,(_,_,NoInline,_)) ->
    return (Left node)
  CLeaf (nm,(_,_,_,e)) ->
    flattenWith nm e [] (return (Right ((nm,e),[])))
  CBranch (_,(_,_,NoInline,_)) _ ->
    return (Left node)
  CBranch (nm,(_,_,_,e)) us ->
    flattenWith nm e us $ do
      newInlineStrat <- Lens.use (extra.newInlineStrategy)
      return $
        if newInlineStrat || isCheapFunction e
          then Right ((nm,e),us)
          else Left node
  where
    -- Shared logic for leaves and branches. The last argument is the
    -- action to run when the term does not split into a single binding.
    flattenWith
      :: Id
      -> Term
      -> [CallTree]
      -> NormalizeSession (Either CallTree ((Id,Term),[CallTree]))
      -> NormalizeSession (Either CallTree ((Id,Term),[CallTree]))
    flattenWith nm e us fallback = do
      topEnts <- Lens.view topEntities
      if nm `elemVarSet` topEnts
        then return (Left node)
        else do
          tcm <- Lens.view tcCache
          case splitNormalized tcm e of
            Right (ids,[(bId,bExpr)],_) -> do
              let (fun,args,ticks) = collectArgsTicks bExpr
                  flat = case stripArgs ids (reverse ids) (reverse args) of
                    Just remainder | bId `localIdDoesNotOccurIn` bExpr ->
                      mkApps (mkTicks fun ticks) (reverse remainder)
                    _ -> e
              return (Right ((nm,flat),us))
            _ -> fallback
-- | Flatten a call tree bottom-up by inlining children into their parents.
--
-- Leaves are returned unchanged. For a branch, the children are first
-- flattened recursively and then classified by 'flattenNode':
--
-- * 'Left' children stay children.
-- * 'Right' children are substituted into the parent term, and their own
--   children are adopted by the parent.
--
-- When something was substituted, the parent term is re-simplified with the
-- local @flatten@ strategy. If the parent is not 'NoInline' and the
-- resulting term is cheap according to 'isCheapFunction', every remaining
-- child that is not 'NoInline' is inlined as well (see @goCheap@), followed
-- by another @flatten@ pass.
flattenCallTree
  :: CallTree
  -> NormalizeSession CallTree
flattenCallTree c@(CLeaf _) = return c
flattenCallTree (CBranch (nm,(nm',sp,inl,tm)) used) = do
  -- Flatten the subtrees first so that inlining proceeds bottom-up.
  flattenedUsed <- mapM flattenCallTree used
  (newUsed,il_ct) <- partitionEithers <$> mapM flattenNode flattenedUsed
  -- toInline: (binder, term) pairs to substitute;
  -- il_used: the children of the inlined nodes, adopted by this node.
  let (toInline,il_used) = unzip il_ct
      subst = extendGblSubstList (mkSubst emptyInScopeSet) toInline
  newExpr <- case toInline of
    [] -> return tm
    _ -> do
      let tm1 = deShadowTerm emptyInScopeSet (substTm "flattenCallTree.flattenExpr" subst tm)
      -- Debug-only (compiled with -DHISTORY): append a binary-encoded
      -- RewriteStep describing this inlining step to history.dat.
      -- unsafePerformIO is used because this code does not run in IO; the
      -- bang pattern forces the write when this let is evaluated.
#ifdef HISTORY
      let !_ = unsafePerformIO
             $ BS.appendFile "history.dat"
             $ BL.toStrict
             $ encode RewriteStep
                 { t_ctx = []
                 , t_name = "INLINE"
                 , t_bndrS = showPpr (varName nm')
                 , t_before = tm
                 , t_after = tm1
                 }
#endif
      rewriteExpr ("flattenExpr",flatten) (showPpr nm, tm1) (nm', sp)
  let allUsed = newUsed ++ concat il_used
  -- If the flattened term is still cheap, also inline every remaining child
  -- that is not NoInline.
  if inl /= NoInline && isCheapFunction newExpr
    then do
      let (toInline',allUsed') = unzip (map goCheap allUsed)
          subst' = extendGblSubstList (mkSubst emptyInScopeSet)
                     (Maybe.catMaybes toInline')
      let tm1 = deShadowTerm emptyInScopeSet (substTm "flattenCallTree.flattenCheap" subst' newExpr)
      newExpr' <- rewriteExpr ("flattenCheap",flatten) (showPpr nm, tm1) (nm', sp)
      return (CBranch (nm,(nm',sp,inl,newExpr')) (concat allUsed'))
    else return (CBranch (nm,(nm',sp,inl,newExpr)) allUsed)
  where
    -- Simplification run after substitution: the listed local rewrites
    -- applied innermost, then topLet applied top-down, combined with
    -- the (!->) combinator from Clash.Rewrite.Combinators.
    flatten =
      innerMost (apply "appProp" appProp >->
                 apply "bindConstantVar" bindConstantVar >->
                 apply "caseCon" caseCon >->
                 apply "reduceConst" reduceConst >->
                 apply "reduceNonRepPrim" reduceNonRepPrim >->
                 apply "removeUnusedExpr" removeUnusedExpr >->
                 apply "flattenLet" flattenLet) !->
      topdownSucR (apply "topLet" topLet)

    -- For one remaining child, return the substitution entry (Nothing when
    -- the child is NoInline) together with the nodes the parent should keep
    -- as children. A NoInline child is kept itself; otherwise its own
    -- children are adopted.
    goCheap c@(CLeaf (nm2,(_,_,inl2,e)))
      | inl2 == NoInline = (Nothing ,[c])
      | otherwise = (Just (nm2,e),[])
    goCheap c@(CBranch (nm2,(_,_,inl2,e)) us)
      | inl2 == NoInline = (Nothing, [c])
      | otherwise = (Just (nm2,e),us)
-- | Flatten a call tree into a pre-order list of its bindings.
--
-- A node whose name is already in @visited@ is skipped together with its
-- whole subtree, so each binder is emitted at most once. The updated list
-- of visited names is returned alongside the bindings.
callTreeToList
  :: [Id]
  -> CallTree
  -> ([Id],[(Id,(Id,SrcSpan,InlineSpec,Term))])
callTreeToList visited tree = case tree of
  CLeaf (nm,bndr)
    | seen nm   -> (visited,[])
    | otherwise -> (nm:visited,[(nm,bndr)])
  CBranch (nm,bndr) used
    | seen nm   -> (visited,[])
    | otherwise ->
        -- Thread the visited list through the children from left to right.
        let (visited',others) = mapAccumL callTreeToList (nm:visited) used
        in  (visited',(nm,bndr) : concat others)
  where
    seen = (`elem` visited)