{-# LANGUAGE BangPatterns #-}
{-# LANGUAGE LambdaCase #-}
{-# LANGUAGE OverloadedStrings #-}
{-# LANGUAGE RecordWildCards #-}
{-# LANGUAGE ViewPatterns #-}
module Clash.Normalize.Util
( isConstantArg
, shouldReduce
, alreadyInlined
, addNewInline
, specializeNorm
, isRecursiveBndr
, isClosed
, callGraph
, classifyFunction
, isCheapFunction
, isNonRecursiveGlobalVar
, canConstantSpec
, normalizeTopLvlBndr
, rewriteExpr
, removedTm
)
where
import Control.Lens ((&),(+~),(%=),(^.),_4,(.=))
import qualified Control.Lens as Lens
import qualified Data.List as List
import qualified Data.Map as Map
import qualified Data.HashMap.Strict as HashMapS
import Data.Text (Text)
import BasicTypes (InlineSpec)
import Clash.Annotations.Primitive (extractPrim)
import Clash.Core.FreeVars
(globalIds, hasLocalFreeVars, globalIdOccursIn)
import Clash.Core.Pretty (showPpr)
import Clash.Core.Subst (deShadowTerm)
import Clash.Core.Term
(Context, CoreContext(AppArg), PrimInfo (..), Term (..), WorkInfo (..),
collectArgs)
import Clash.Core.TyCon (TyConMap)
import Clash.Core.Type (Type, undefinedTy)
import Clash.Core.Util (isClockOrReset, isPolyFun, termType)
import Clash.Core.Var (Id, Var (..), isGlobalId)
import Clash.Core.VarEnv
(VarEnv, emptyInScopeSet, emptyVarEnv, extendVarEnv, extendVarEnvWith,
lookupVarEnv, unionVarEnvWith, unitVarEnv)
import Clash.Driver.Types (BindingMap, DebugLevel (..))
import {-# SOURCE #-} Clash.Normalize.Strategy (normalization)
import Clash.Normalize.Types
import Clash.Primitives.Util (constantArgs)
import Clash.Rewrite.Types
(RewriteMonad, bindings, curFun, dbgLevel, extra, tcCache)
import Clash.Rewrite.Util (runRewrite, specialise)
import Clash.Unique
import Clash.Util (SrcSpan, anyM, makeCachedU, traceIf)
-- | Check whether argument position @i@ of primitive @nm@ must be a constant.
--
-- The set of constant argument positions for a primitive is computed with
-- 'constantArgs' the first time it is requested and then stored in the
-- 'primitiveArgs' map of the normalisation state, so later queries are a plain
-- map lookup. Names that do not resolve to a primitive (per 'extractPrim') are
-- reported as 'False' and are not cached.
isConstantArg
  :: Text
  -> Int
  -> RewriteMonad NormalizeState Bool
isConstantArg nm i = do
  cachedArgs <- Map.lookup nm <$> Lens.use (extra.primitiveArgs)
  maybe computeAndCache (pure . elem i) cachedArgs
  where
    -- Cache miss: look the primitive up, derive its constant arguments and
    -- remember them for next time.
    computeAndCache = do
      primMap <- Lens.use (extra.primitives)
      case extractPrim =<< HashMapS.lookup nm primMap of
        Nothing -> pure False
        Just prim -> do
          let constArgs = constantArgs nm prim
          extra.primitiveArgs %= Map.insert nm constArgs
          pure (elem i constArgs)
-- | Decide whether the term at the given 'Context' should be reduced.
--
-- Returns 'True' when any enclosing context is an argument position of a
-- primitive application ('AppArg' carrying the primitive name and argument
-- index) that 'isConstantArg' reports as requiring a constant.
shouldReduce
  :: Context
  -> RewriteMonad NormalizeState Bool
shouldReduce = anyM $ \case
  AppArg (Just (primName, _, argIx)) -> isConstantArg primName argIx
  _ -> pure False
-- | Look up how many times function @f@ has been inlined into function @cf@.
--
-- Reads the two-level 'inlineHistory' map (outer key @cf@, inner key @f@).
-- Returns 'Nothing' when no count is recorded for the pair.
alreadyInlined
  :: Id
  -> Id
  -> NormalizeMonad (Maybe Int)
alreadyInlined f cf =
  (\history -> lookupVarEnv cf history >>= lookupVarEnv f)
    <$> Lens.use inlineHistory
-- | Record one more inlining of function @f@ into function @cf@.
--
-- When @cf@ has no history yet, it gets a fresh inner map @{f -> 1}@.
-- Otherwise the count for @f@ in the existing inner map is incremented, or
-- started at 1 if absent.
addNewInline
  :: Id
  -> Id
  -> NormalizeMonad ()
addNewInline f cf =
  inlineHistory %= extendVarEnvWith cf (unitVarEnv f 1) bumpExisting
  where
    -- Combining function: the fresh singleton is discarded in favour of
    -- incrementing the entry in the already-present inner map.
    bumpExisting _fresh existing = extendVarEnvWith f 1 (+) existing
-- | The generic 'specialise' rewrite, instantiated with the normaliser's
-- specialisation cache, history and limit lenses.
specializeNorm :: NormRewrite
specializeNorm =
  specialise
    specialisationCache
    specialisationHistory
    specialisationLimit
-- | A term counts as closed exactly when 'isPolyFun' does not hold for it
-- (under the given type-constructor map).
isClosed :: TyConMap
         -> Term
         -> Bool
isClosed tcm e = not (isPolyFun tcm e)
-- | Check whether the head of an application is a global, non-recursive
-- variable.
--
-- The term is first split with 'collectArgs'. When the head is a 'Var',
-- 'isRecursiveBndr' is always consulted (which may populate its cache), and
-- the result is 'True' only for a global id that is not recursive. Any other
-- head yields 'False'.
isNonRecursiveGlobalVar
  :: Term
  -> NormalizeSession Bool
isNonRecursiveGlobalVar e = case collectArgs e of
  (Var v, _) -> (\recursive -> isGlobalId v && not recursive) <$> isRecursiveBndr v
  _ -> return False
-- | Determine whether the global binder @f@ refers to itself in its body.
--
-- Answers are memoised in 'recursiveComponents'. On a cache miss the body
-- is fetched from 'bindings' and checked with 'globalIdOccursIn', and the
-- answer is stored. A binder without a binding is reported as non-recursive,
-- and that answer is not cached.
isRecursiveBndr
  :: Id
  -> NormalizeSession Bool
isRecursiveBndr f = do
  known <- lookupVarEnv f <$> Lens.use (extra.recursiveComponents)
  maybe inspectBody return known
  where
    inspectBody = do
      bndrMap <- Lens.use bindings
      case lookupVarEnv f bndrMap of
        Nothing -> return False
        Just (_, _, _, body) -> do
          let selfRef = f `globalIdOccursIn` body
          (extra.recursiveComponents) %= extendVarEnv f selfRef
          return selfRef
-- | Decide whether a term is constant enough to be used for constant
-- specialisation.
--
-- * If the term's type satisfies 'isClockOrReset', only an application headed
--   by the @Clash.Transformations.removedArg@ primitive is accepted.
-- * Otherwise, based on the head returned by 'collectArgs':
--
--     * data constructor or primitive: accepted when every term argument is
--       itself accepted (type arguments always are);
--     * lambda: accepted when it has no local free variables;
--     * variable: accepted when every term argument is accepted, the head is
--       a non-recursive global ('isNonRecursiveGlobalVar'), and it is not the
--       function currently being rewritten ('curFun');
--     * literal: always accepted;
--     * anything else: rejected.
canConstantSpec
  :: Term
  -> RewriteMonad NormalizeState Bool
canConstantSpec e = do
  tcm <- Lens.view tcCache
  let (hd, args) = collectArgs e
      -- All argument checks are run (no short-circuiting), matching the
      -- effects of the recursive calls exactly.
      allArgsConstant =
        and <$> mapM (either canConstantSpec (const (pure True))) args
  if isClockOrReset tcm (termType tcm e)
    then return $ case hd of
      Prim nm _ -> nm == "Clash.Transformations.removedArg"
      _ -> False
    else case hd of
      Data _ -> allArgsConstant
      Prim _ _ -> allArgsConstant
      Lam _ _ -> pure (not (hasLocalFreeVars e))
      Var f -> do
        (current, _) <- Lens.use curFun
        argsOk <- allArgsConstant
        nonRecGlobal <- isNonRecursiveGlobalVar e
        return (argsOk && nonRecGlobal && f /= current)
      Literal _ -> pure True
      _ -> pure False
-- | For each visited function, the functions it references together with the
-- number of references.
type CallGraph = VarEnv (VarEnv Word)

-- | Build the call graph reachable from the root binder @rt@.
--
-- Starting at @rt@, each binder present in @bndrs@ is visited at most once.
-- Its term (the fourth component of the binding) is scanned with 'globalIds',
-- counting every occurrence of each referenced global. Every referenced
-- binder is then visited in turn. Unknown binders (not in @bndrs@) get no
-- entry.
callGraph
  :: BindingMap
  -> Id
  -> CallGraph
callGraph bndrs rt = visit emptyVarEnv (varUniq rt)
  where
    visit graph u = case (lookupUniqMap u graph, lookupUniqMap u bndrs) of
      (Nothing, Just binding) ->
        let callees = Lens.foldMapByOf globalIds (unionVarEnvWith (+))
                        emptyVarEnv (`unitUniqMap` 1) (binding ^. _4)
        in List.foldl' visit (extendUniqMap u callees graph) (keysUniqMap callees)
      _ -> graph
-- | Count the hardware-relevant components of a function body.
--
-- Lambdas, type lambdas and ticks are looked through. For a @letrec@ only the
-- right-hand sides of the bindings are inspected (the body is not). Each
-- application headed by a primitive bumps the primitive count, and each one
-- headed by a variable bumps the function count. A @case@ with two or more
-- alternatives bumps the selection count. Arguments of applications and the
-- alternatives of a case are not descended into.
classifyFunction
  :: Term
  -> TermClassification
classifyFunction = walk (TermClassification 0 0 0)
  where
    walk !acc term = case term of
      Lam _ body -> walk acc body
      TyLam _ body -> walk acc body
      Letrec binds _ -> List.foldl' walk acc (map snd binds)
      App {} -> case fst (collectArgs term) of
        Prim {} -> acc & primitive +~ 1
        Var {} -> acc & function +~ 1
        _ -> acc
      Case _ _ (_:_:_) -> acc & selection +~ 1
      Tick _ body -> walk acc body
      _ -> acc
-- | Determine whether a function adds a lot of hardware or not.
--
-- Using the counts from 'classifyFunction', a function is considered cheap
-- when it contains at most one component in total, where a component is an
-- application of a function, an application of a primitive, or a
-- multi-alternative @case@ (selection). Two or more components make it
-- expensive.
--
-- NOTE(review): the previous guard chain reduced to
-- @_function <= 1 && _primitive <= 0 && _selection <= 0@, and the intent above
-- is inferred from its shape.
--
-- * The first guard matched whenever @_function@ was 0 or 1, so a body with
--   just a single primitive, or just a single selection, was classified as
--   expensive.
-- * The two remaining guards were only reachable with @_function >= 2@, where
--   they could only return 'False'.
isCheapFunction
  :: Term
  -> Bool
isCheapFunction tm = case classifyFunction tm of
  TermClassification {..} ->
    -- The counters start at 0 and are only ever incremented, so a total of at
    -- most one means at most one component of a single kind and none of the
    -- other kinds.
    _function + _primitive + _selection <= 1
-- | Normalise a single top-level binding.
--
-- The work is wrapped in 'makeCachedU' keyed on @nm@ and backed by
-- 'normalized' (presumably so each binder is normalised at most once —
-- confirm against 'makeCachedU'). The term is de-shadowed, rewritten with the
-- 'normalization' strategy via 'rewriteExpr', and returned with the binder's
-- type replaced by the type of the normalised term. 'curFun' is restored to
-- its previous value afterwards, because 'rewriteExpr' overwrites it.
normalizeTopLvlBndr
  :: Id
  -> (Id, SrcSpan, InlineSpec, Term)
  -> NormalizeSession (Id, SrcSpan, InlineSpec, Term)
normalizeTopLvlBndr nm (nm',sp,inl,tm) = makeCachedU nm (extra.normalized) $ do
  tcm <- Lens.view tcCache
  savedFun <- Lens.use curFun
  tmNorm <- rewriteExpr ("normalization", normalization)
                        (showPpr (varName nm), deShadowTerm emptyInScopeSet tm)
                        (nm', sp)
  curFun .= savedFun
  return (nm' {varType = termType tcm tmNorm}, sp, inl, tmNorm)
-- | Run a named normalisation rewrite over a named expression.
--
-- Sets 'curFun' to the given binder and source span, and then runs the rewrite
-- with 'runRewrite' starting from an empty in-scope set. When the debug level
-- is at least 'DebugFinal', the pretty-printed expression is traced before and
-- after the rewrite.
rewriteExpr :: (String,NormRewrite)
            -> (String,Term)
            -> (Id, SrcSpan)
            -> NormalizeSession Term
rewriteExpr (nrwS,nrw) (bndrS,expr) (nm, sp) = do
  curFun .= (nm, sp)
  lvl <- Lens.view dbgLevel
  let tracing = lvl >= DebugFinal
      -- Build the trace message for one stage; the label carries its own
      -- surrounding spaces.
      banner label t = bndrS ++ label ++ nrwS ++ ":\n\n" ++ showPpr t ++ "\n"
  rewritten <- runRewrite nrwS emptyInScopeSet nrw
                 (traceIf tracing (banner " before " expr) expr)
  traceIf tracing (banner " after " rewritten) (return rewritten)
-- | Placeholder term for an argument that a transformation has removed: the
-- @Clash.Transformations.removedArg@ primitive (marked 'WorkNever')
-- instantiated at the given type.
removedTm
  :: Type
  -> Term
removedTm ty =
  TyApp (Prim "Clash.Transformations.removedArg" (PrimInfo undefinedTy WorkNever)) ty