{-# LANGUAGE LambdaCase #-} {-# LANGUAGE ViewPatterns #-} {-# OPTIONS_GHC -fcontext-stack=21 #-} -- | Utility functions used by the normalisation transformations module CLaSH.Normalize.Util where import Control.Lens ((%=)) import qualified Control.Lens as Lens import qualified Data.Graph as Graph import Data.Graph.Inductive (Gr,LNode,lsuc,mkGraph,iDom) import Data.HashMap.Lazy (HashMap) import qualified Data.HashMap.Lazy as HashMap import qualified Data.Maybe as Maybe import qualified Data.Set as Set import Unbound.LocallyNameless (Fresh, bind, embed, rec) import CLaSH.Core.FreeVars (termFreeIds) import CLaSH.Core.Var (Var (Id)) import CLaSH.Core.Term (Term (..), TmName) import CLaSH.Core.Type (Type) import CLaSH.Core.TyCon (TyCon, TyConName) import CLaSH.Core.Util (collectArgs, isPolyFun) import CLaSH.Normalize.Types import CLaSH.Rewrite.Util (specialise) -- | Determine if a function is already inlined in the context of the 'NetlistMonad' alreadyInlined :: TmName -> NormalizeMonad (Maybe Int) alreadyInlined f = do cf <- Lens.use curFun inlinedHM <- Lens.use inlineHistory case HashMap.lookup cf inlinedHM of Nothing -> return Nothing Just inlined' -> return (HashMap.lookup f inlined') addNewInline :: TmName -> NormalizeMonad () addNewInline f = do cf <- Lens.use curFun inlineHistory %= HashMap.insertWith (\_ hm -> HashMap.insertWith (+) f 1 hm) cf (HashMap.singleton f 1) -- | Specialize under the Normalization Monad specializeNorm :: NormRewrite specializeNorm = specialise specialisationCache specialisationHistory specialisationLimit -- | Determine if a term is closed isClosed :: (Functor m, Fresh m) => HashMap TyConName TyCon -> Term -> m Bool isClosed tcm = fmap not . isPolyFun tcm -- | Determine if a term represents a constant isConstant :: Term -> Bool isConstant e = case collectArgs e of (Data _, args) -> all (either isConstant (const True)) args (Prim _ _, args) -> all (either isConstant (const True)) args (Literal _,_) -> True _ -> False -- | Create a call graph for a set of global binders, given a root callGraph :: [TmName] -- ^ List of functions that should not be inspected -> HashMap TmName (Type,Term) -- ^ Global binders -> TmName -- ^ Root of the call graph -> [(TmName,[TmName])] callGraph visited bindingMap root = node:other where rootTm = Maybe.fromMaybe (error $ show root ++ " is not a global binder") $ HashMap.lookup root bindingMap used = Set.toList $ termFreeIds (snd rootTm) node = (root,used) other = concatMap (callGraph (root:visited) bindingMap) (filter (`notElem` visited) used) -- | Determine the sets of recursive components given the edges of a callgraph recursiveComponents :: [(TmName,[TmName])] -- ^ [(calling function,[called function])] -> [[TmName]] recursiveComponents = Maybe.catMaybes . map (\case {Graph.CyclicSCC vs -> Just vs; _ -> Nothing}) . Graph.stronglyConnComp . map (\(n,es) -> (n,n,es)) lambdaDropPrep :: HashMap TmName (Type,Term) -> TmName -> HashMap TmName (Type,Term) lambdaDropPrep bndrs topEntity = bndrs' where depGraph = callGraph [] bndrs topEntity used = HashMap.fromList depGraph rcs = recursiveComponents depGraph dropped = map (lambdaDrop bndrs used) rcs bndrs' = foldr (\(k,v) b -> HashMap.insert k v b) bndrs dropped lambdaDrop :: HashMap TmName (Type,Term) -- ^ Original Binders -> HashMap TmName [TmName] -- ^ Dependency Graph -> [TmName] -- ^ Recursive block -> (TmName,(Type,Term)) -- ^ Lambda-dropped Binders lambdaDrop bndrs depGraph cyc@(root:_) = block where doms = dominator depGraph cyc block = blockSink bndrs doms (0,root) lambdaDrop _ _ [] = error "Can't lambdadrop empty cycle" dominator :: HashMap TmName [TmName] -- ^ Dependency Graph -> [TmName] -- ^ Recursive block -> Gr TmName TmName -- ^ Recursive block dominator dominator cfg cyc = mkGraph nodes (map (\(e,b) -> (b,e,nodesM HashMap.! e)) doms) where nodes = zip [0..] cyc nodesM = HashMap.fromList nodes nodesI = HashMap.fromList $ zip cyc [0..] cycEdges = HashMap.map ( map (nodesI HashMap.!) . filter (`elem` cyc) ) $ HashMap.filterWithKey (\k _ -> k `elem` cyc) cfg edges = concatMap (\(i,n) -> zip3 (repeat i) (cycEdges HashMap.! n) (repeat ()) ) nodes graph = mkGraph nodes edges :: Gr TmName () doms = iDom graph 0 blockSink :: HashMap TmName (Type,Term) -- ^ Original Binders -> Gr TmName TmName -- ^ Recursive block dominator -> LNode TmName -- ^ Recursive block dominator root -> (TmName,(Type,Term)) -- ^ Block sank binder blockSink bndrs doms (nId,tmName) = (tmName,(ty,newTm)) where (ty,tm) = bndrs HashMap.! tmName sucTm = lsuc doms nId tmS = map (blockSink bndrs doms) sucTm bnds = map (\(tN,(ty',tm')) -> (Id tN (embed ty'),embed tm')) tmS newTm = case sucTm of [] -> tm _ -> Letrec (bind (rec bnds) tm)