{-| Copyright : (C) 2012-2016, University of Twente, 2016-2017, Myrtle Software Ltd, 2017-2018, Google Inc. License : BSD2 (see the file LICENSE) Maintainer : Christiaan Baaij Transformations of the Normalization process -} {-# LANGUAGE CPP #-} {-# LANGUAGE FlexibleContexts #-} {-# LANGUAGE LambdaCase #-} {-# LANGUAGE MagicHash #-} {-# LANGUAGE MultiWayIf #-} {-# LANGUAGE NamedFieldPuns #-} {-# LANGUAGE OverloadedStrings #-} {-# LANGUAGE TemplateHaskell #-} module Clash.Normalize.Transformations ( caseLet , caseCon , caseCase , caseElemNonReachable , elemExistentials , inlineNonRep , inlineOrLiftNonRep , typeSpec , nonRepSpec , etaExpansionTL , nonRepANF , bindConstantVar , constantSpec , makeANF , deadCode , topLet , recToLetRec , inlineWorkFree , inlineHO , inlineSmall , simpleCSE , reduceConst , reduceNonRepPrim , caseFlat , disjointExpressionConsolidation , removeUnusedExpr , inlineCleanup , flattenLet , splitCastWork , inlineCast , caseCast , letCast , eliminateCastCast , argCastSpec , etaExpandSyn , appPropFast , separateArguments , separateLambda , xOptimize ) where import Control.Exception (throw) import Control.Lens (_2) import qualified Control.Lens as Lens import qualified Control.Monad as Monad import Control.Monad.State (StateT (..), modify) import Control.Monad.State.Strict (evalState) import Control.Monad.Writer (lift, listen) import Control.Monad.Trans.Except (runExcept) import Data.Coerce (coerce) import qualified Data.Either as Either import qualified Data.HashMap.Lazy as HashMap import qualified Data.HashMap.Strict as HashMapS import Data.List ((\\)) import qualified Data.List as List import qualified Data.List.Extra as List import qualified Data.Maybe as Maybe import qualified Data.Monoid as Monoid import qualified Data.Primitive.ByteArray as BA import qualified Data.Text as Text import qualified Data.Vector.Primitive as PV import GHC.Integer.GMP.Internals (Integer (..), BigNat (..)) import BasicTypes (InlineSpec (..)) import 
Clash.Annotations.Primitive (extractPrim) import Clash.Core.DataCon (DataCon (..)) import Clash.Core.EqSolver import Clash.Core.Name (Name (..), NameSort (..), mkUnsafeSystemName, nameOcc) import Clash.Core.FreeVars (localIdOccursIn, localIdsDoNotOccurIn, freeLocalIds, termFreeTyVars, typeFreeVars, localVarsDoNotOccurIn, localIdDoesNotOccurIn, countFreeOccurances) import Clash.Core.Literal (Literal (..)) import Clash.Core.Pretty (showPpr) import Clash.Core.Subst import Clash.Core.Term import Clash.Core.TermInfo import Clash.Core.Type (Type (..), TypeView (..), applyFunTy, isPolyFunCoreTy, isClassTy, normalizeType, splitFunForallTy, splitFunTy, tyView, mkPolyFunTy, coreView, LitTy (..), coreView1) import Clash.Core.TyCon (TyConMap, tyConDataCons) import Clash.Core.Util ( isSignalType, mkVec, tyNatSize, undefinedTm, shouldSplit, inverseTopSortLetBindings) import Clash.Core.Var (Id, TyVar, Var (..), isGlobalId, isLocalId, mkLocalId) import Clash.Core.VarEnv (InScopeSet, VarEnv, VarSet, elemVarSet, emptyVarEnv, extendInScopeSet, extendInScopeSetList, lookupVarEnv, notElemVarSet, unionVarEnvWith, unionInScope, unitVarEnv, unitVarSet, mkVarSet, mkInScopeSet, uniqAway, elemInScopeSet, elemVarEnv, foldlWithUniqueVarEnv', lookupVarEnvDirectly, extendVarEnv, unionVarEnv, eltsVarEnv, mkVarEnv, eltsVarSet) import Clash.Debug import Clash.Driver.Types (Binding(..), DebugLevel (..)) import Clash.Netlist.BlackBox.Types (Element(Err)) import Clash.Netlist.BlackBox.Util (getUsedArguments) import Clash.Netlist.Types (BlackBox(..), HWType (..), FilteredHWType(..)) import Clash.Netlist.Util (coreTypeToHWType, representableType, splitNormalized, bindsExistentials) import Clash.Normalize.DEC import Clash.Normalize.PrimitiveReductions import Clash.Normalize.Types import Clash.Normalize.Util import Clash.Primitives.Types (Primitive(..), TemplateKind(TExpr), CompiledPrimMap, UsedArguments(..)) import Clash.Rewrite.Combinators import Clash.Rewrite.Types import Clash.Rewrite.Util import 
Clash.Unique (Unique, lookupUniqMap)
import Clash.Util

-- | Inline or lift let-bindings whose type is not representable.
--
-- Only 'Letrec' expressions are considered. A binder is selected by
-- @nonRepTest@ when 'representableType' (queried with the current type
-- translator, custom representations and 'TyConMap') reports its type as not
-- representable; type-variable binders are never selected. The selection
-- predicate and @inlineTest@ are handed to 'inlineOrLiftBinders', which —
-- judging by its name — inlines the selected binders for which @inlineTest@
-- holds and lifts the others (NOTE(review): 'inlineOrLiftBinders' is defined
-- elsewhere; confirm).
--
-- See "[Note] join points and void wrappers" directly after this definition.
inlineOrLiftNonRep :: HasCallStack => NormRewrite
inlineOrLiftNonRep ctx eLet@(Letrec _ body) =
    inlineOrLiftBinders nonRepTest inlineTest ctx eLet
  where
    -- Number of free occurrences of each variable in the let-body.
    bodyFreeOccs = countFreeOccurances body

    nonRepTest :: (Id, Term) -> RewriteMonad extra Bool
    nonRepTest (Id {varType = ty}, _)
      = not <$> (representableType <$> Lens.view typeTranslator
                                   <*> Lens.view customReprs
                                   <*> pure False
                                   <*> Lens.view tcCache
                                   <*> pure ty)
    nonRepTest _ = return False

    inlineTest :: Term -> (Id, Term) -> Bool
    inlineTest e (id_, e') =
      -- We do __NOT__ inline:
      not $ or
        [ -- 1. recursive let-binders
          -- id_ `localIdOccursIn` e' -- <= already checked in inlineOrLiftBinders
          -- 2. join points (which are not void-wrappers)
          isJoinPointIn id_ e && not (isVoidWrapper e')
          -- 3. binders that are used more than once in the body, because
          --    it makes CSE a whole lot more difficult.
          --
          -- XXX: Check whether we can extend this to the binders as well
        , maybe False (>1) (lookupVarEnv id_ bodyFreeOccs)
        ]

inlineOrLiftNonRep _ e = return e
{-# SCC inlineOrLiftNonRep #-}

{- [Note] join points and void wrappers
Join points are functions that only occur in tail-call positions within an
expression, and only when they occur in a tail-call position more than once.

Normally bindNonRep binds/inlines all non-recursive local functions. However,
doing so for join points would significantly increase compilation time, so we
avoid it. The only exception to this rule is so-called void wrappers. Void
wrappers are functions of the form:

> \(w :: Void) -> f a b c

i.e. a wrapper around the function 'f' where the argument 'w' is not used. We
do bind/inline these join-points because these void-wrappers interfere with
the 'disjoint expression consolidation' (DEC) and 'common sub-expression
elimination' (CSE) transformation, sometimes resulting in circuits that are
twice as big as they'd need to be.
-}

-- | Specialize functions on their type.
--
-- Fires on @f args \@ty@ when the head @f@ is a variable, @ty@ has no free
-- type variables, and @args@ (the arguments already applied to @f@) contain
-- no type arguments; the whole expression is then handed to
-- 'specializeNorm'.
typeSpec :: HasCallStack => NormRewrite
typeSpec ctx e@(TyApp e1 ty)
  | (Var {},  args) <- collectArgs e1
  , null $ Lens.toListOf typeFreeVars ty
  , (_, []) <- Either.partitionEithers args
  = specializeNorm ctx e

typeSpec _ e = return e
{-# SCC typeSpec #-}

-- | Specialize functions on their non-representable argument.
--
-- Fires on @f args e2@ when the head @f@ is a variable, @args@ contain no
-- type arguments, @e2@ has no free type variables, the type of @e2@ is not
-- representable, and @e2@ is not a local variable reference.
nonRepSpec :: HasCallStack => NormRewrite
nonRepSpec ctx e@(App e1 e2)
  | (Var {}, args) <- collectArgs e1
  , (_, [])     <- Either.partitionEithers args
  , null $ Lens.toListOf termFreeTyVars e2
  = do tcm <- Lens.view tcCache
       let e2Ty = termType tcm e2
       let localVar = isLocalVar e2
       nonRepE2 <- not <$> (representableType <$> Lens.view typeTranslator
                                              <*> Lens.view customReprs
                                              <*> pure False
                                              <*> Lens.view tcCache
                                              <*> pure e2Ty)
       if nonRepE2 && not localVar
         then do
           e2' <- inlineInternalSpecialisationArgument e2
           specializeNorm ctx (App e1 e2')
         else return e
  where
    -- | If the argument on which we're specialising is an internal function,
    -- one created by the compiler, then inline that function before we
    -- specialise.
    --
    -- We need to do this because otherwise the specialisation history won't
    -- recognize the new specialisation argument as something the function has
    -- already been specialized on.
    --
    -- "Internal" means the binder's name has 'nameSort' 'Internal'. After
    -- inlining, 'appPropFast' is run top-down over the result, and the
    -- Writer output of that run is discarded by @censor (const mempty)@.
    inlineInternalSpecialisationArgument
      :: Term
      -> NormalizeSession Term
    inlineInternalSpecialisationArgument app
      | (Var f,fArgs,ticks) <- collectArgsTicks app
      = do
        fTmM <- lookupVarEnv f <$> Lens.use bindings
        case fTmM of
          Just b
            | nameSort (varName (bindingId b)) == Internal
            -> censor (const mempty)
                      (topdownR appPropFast ctx
                        (mkApps (mkTicks (bindingTerm b) ticks) fArgs))
          _ -> return app
      | otherwise = return app

nonRepSpec _ e = return e
{-# SCC nonRepSpec #-}

-- | Lift the let-bindings out of the subject of a Case-decomposition.
--
-- Ticks that wrapped the let-expression are re-applied both to every
-- lifted binding's RHS and to the new case subject.
caseLet :: HasCallStack => NormRewrite
caseLet (TransformContext is0 _) (Case (collectTicks -> (Letrec xes e,ticks)) ty alts) = do
  -- Note [CaseLet deshadow]
  -- Imagine
  --
  -- @
  -- case (let x = u in e) of {p -> a}
  -- @
  --
  -- where `a` has a free variable named `x`.
  --
  -- Simply transforming the above to:
  --
  -- @
  -- let x = u in case e of {p -> a}
  -- @
  --
  -- would be very bad, because now the let-binding captures the free x variable
  -- in a.
  --
  -- We must therefore rename `x` so that it doesn't capture the free variables
  -- in the alternative:
  --
  -- @
  -- let x1 = u[x:=x1] in case e[x:=x1] of {p -> a}
  -- @
  --
  -- It is safe to over-approximate the free variables in `a` by simply taking
  -- the current InScopeSet.
  let (xes1,e1) = deshadowLetExpr is0 xes e
  changed (Letrec (map (second (`mkTicks` ticks)) xes1)
                  (Case (mkTicks e1 ticks) ty alts))

caseLet _ e = return e
{-# SCC caseLet #-}

-- | Remove non-reachable alternatives. For example, consider:
--
--    data STy ty where
--      SInt :: Int -> STy Int
--      SBool :: Bool -> STy Bool
--
--    f :: STy ty -> ty
--    f (SInt b) = b + 1
--    f (SBool True) = False
--    f (SBool False) = True
--    {-# NOINLINE f #-}
--
--    g :: STy Int -> Int
--    g = f
--
-- @f@ is always specialized on @STy Int@.
-- The SBool alternatives are therefore unreachable. Additional information
-- can be found at: https://github.com/clash-lang/clash-compiler/pull/465
--
-- Alternatives are classified with 'isAbsurdAlt'. When none is absurd the
-- case is returned untouched; otherwise the absurd ones are dropped, the
-- pruned case is offered to 'caseOneAlt', and the result is marked as changed.
caseElemNonReachable :: HasCallStack => NormRewrite
caseElemNonReachable _ case0@(Case scrut altsTy alts0) = do
  tcm <- Lens.view tcCache
  let (unreachable, reachable) = List.partition (isAbsurdAlt tcm) alts0
  if null unreachable
    then return case0
    else caseOneAlt (Case scrut altsTy reachable) >>= changed
caseElemNonReachable _ e = return e
{-# SCC caseElemNonReachable #-}

-- | Tries to eliminate existentials by using heuristics to determine what the
-- existential should be. For example, consider Vec:
--
--    data Vec :: Nat -> Type -> Type where
--      Nil  :: Vec 0 a
--      Cons :: a -> Vec n a -> Vec (n + 1) a
--
-- Thus, 'null' (annotated with existentials) could look like:
--
--    null :: forall n . Vec n Bool -> Bool
--    null v =
--      case v of
--        Nil  {n ~ 0}                                     -> True
--        Cons {n1:Nat} {n~n1+1} (x :: a) (xs :: Vec n1 a) -> False
--
-- When it's applied to a vector of length 5, this becomes:
--
--    null :: Vec 5 Bool -> Bool
--    null v =
--      case v of
--        Nil  {5 ~ 0}                                     -> True
--        Cons {n1:Nat} {5~n1+1} (x :: a) (xs :: Vec n1 a) -> False
--
-- This function solves 'n1' and replaces every occurrence with its solution.
-- Only a very limited set of equations is currently recognized: just
-- additions (as in the example above) will be solved.
elemExistentials :: HasCallStack => NormRewrite
elemExistentials (TransformContext is0 _) (Case scrut altsTy alts0) = do
  tcm <- Lens.view tcCache

  alts1 <- mapM (go is0 tcm) alts0
  -- Afterwards, give 'caseOneAlt' a chance to collapse the rewritten case.
  caseOneAlt (Case scrut altsTy alts1)

 where
  -- Eliminate free type variables if possible. When equations were solved,
  -- the rewritten alternative is fed back into 'go' until nothing more can
  -- be solved.
  go :: InScopeSet -> TyConMap -> (Pat, Term) -> NormalizeSession (Pat, Term)
  go is2 tcm alt@(DataPat dc exts0 xs0, term0) =
    case solveNonAbsurds tcm (altEqs tcm alt) of
      -- No equations solved:
      [] -> return alt
      -- One or more equations solved:
      sols ->
        changed =<< go is2 tcm (DataPat dc exts1 xs1, term1)
       where
        -- Substitute solution in existentials and applied types
        is3   = extendInScopeSetList is2 exts0
        xs1   = map (substTyInVar (extendTvSubstList (mkSubst is3) sols)) xs0
        exts1 = substInExistentialsList is2 exts0 sols

        -- Substitute solution in term.
        is4   = extendInScopeSetList is3 xs1
        subst = extendTvSubstList (mkSubst is4) sols
        term1 = substTm "Replacing tyVar due to solved eq" subst term0

  go _ _ alt = return alt

elemExistentials _ e = return e
{-# SCC elemExistentials #-}

-- | Move a Case-decomposition from the subject of a Case-decomposition to the
-- alternatives. Only fires when the type of the inner case is not
-- representable.
caseCase :: HasCallStack => NormRewrite
caseCase (TransformContext is0 _) e@(Case (stripTicks -> Case scrut alts1Ty alts1) alts2Ty alts2)
  = do
    ty1Rep <- representableType <$> Lens.view typeTranslator
                                <*> Lens.view customReprs
                                <*> pure False
                                <*> Lens.view tcCache
                                <*> pure alts1Ty
    if not ty1Rep
      -- Deshadow to prevent accidental capture of free variables of inner
      -- case. Imagine:
      --
      --   case (case a of {x -> x}) of {_ -> x}
      --
      -- 'x' is introduced by the inner 'case' and used (as a free variable)
      -- in the outer one. The goal of 'caseCase' is to rewrite cases such
      -- that their subjects aren't cases. This is achieved by 'pushing' the
      -- outer case to all the alternatives of the inner one. Naively doing so
      -- in this example would cause an accidental capture:
      --
      --   case a of {x -> case x of {_ -> x}}
      --
      -- Suddenly, the 'x' in the alternative of the inner case statement
      -- refers to the one introduced by the outer one, instead of being a
      -- free variable. To prevent this, we deshadow the alternatives of the
      -- original inner case. We now end up with:
      --
      --   case a of {x1 -> case x1 of {_ -> x}}
      --
      then let newAlts = map (second (\altE -> Case altE alts2Ty alts2))
                             (map (deShadowAlt is0) alts1)
           in  changed $ Case scrut alts2Ty newAlts
      else return e

caseCase _ e = return e
{-# SCC caseCase #-}

-- | Inline function with a non-representable result if it's the subject
-- of a Case-decomposition.
--
-- Only global functions are inlined. The number of times a function has
-- already been inlined into the current function is tracked; when it exceeds
-- the inline limit (and the subject type is not a class type) a warning is
-- traced and the expression is left untouched.
inlineNonRep :: HasCallStack => NormRewrite
inlineNonRep _ e@(Case scrut altsTy alts)
  | (Var f, args,ticks) <- collectArgsTicks scrut
  , isGlobalId f
  = do
    (cf,_)    <- Lens.use curFun
    isInlined <- zoomExtra (alreadyInlined f cf)
    limit     <- Lens.use (extra.inlineLimit)
    tcm       <- Lens.view tcCache
    let
      scrutTy = termType tcm scrut
      noException = not (exception tcm scrutTy)
    if noException && (Maybe.fromMaybe 0 isInlined) > limit
      then
        trace (concat [ $(curLoc) ++ "InlineNonRep: " ++ showPpr (varName f)
                      ," already inlined " ++ show limit ++ " times in:"
                      , showPpr (varName cf)
                      , "\nType of the subject is: " ++ showPpr scrutTy
                      , "\nFunction " ++ showPpr (varName cf)
                      , " will not reach a normal form, and compilation"
                      , " might fail."
                      , "\nRun with '-fclash-inline-limit=N' to increase"
                      , " the inlining limit to N."
                      ])
              (return e)
      else do
        bodyMaybe   <- lookupVarEnv f <$> Lens.use bindings
        nonRepScrut <- not <$> (representableType <$> Lens.view typeTranslator
                                                  <*> Lens.view customReprs
                                                  <*> pure False
                                                  <*> Lens.view tcCache
                                                  <*> pure scrutTy)
        case (nonRepScrut, bodyMaybe) of
          (True,Just b) -> do
            -- Class-typed subjects are exempt from the inline counter.
            Monad.when noException (zoomExtra (addNewInline f cf))

            let scrutBody0 = mkTicks (bindingTerm b) (mkInlineTick f : ticks)
            let scrutBody1 = mkApps scrutBody0 args

            changed $ Case scrutBody1 altsTy alts

          _ -> return e
  where
    exception = isClassTy

inlineNonRep _ e = return e
{-# SCC inlineNonRep #-}

-- | Specialize a Case-decomposition (replace by the RHS of an alternative) if
-- the subject is (an application of) a DataCon; or if there is only a single
-- alternative that doesn't reference variables bound by the pattern.
--
-- Note [CaseCon deshadow]
--
-- Imagine:
--
-- @
-- case D (f a b) (g x y) of
--   D a b -> h a
-- @
--
-- rewriting this to:
--
-- @
-- let a = f a b
-- in  h a
-- @
--
-- is very bad because the newly introduced let-binding now captures the free
-- variable 'a' in 'f a b'.
--
-- instead we must rewrite to:
--
-- @
-- let a1 = f a b
-- in  h a1
-- @
caseCon :: HasCallStack => NormRewrite
caseCon ctx@(TransformContext is0 _) e@(Case subj ty alts) = do
  tcm <- Lens.view tcCache
  case collectArgsTicks subj of
    -- The subject is an applied data constructor
    (Data dc, args, ticks) ->
      case List.find (equalCon . fst) alts of
        Just (DataPat _ tvs xs, altE) -> do
          let is1 = extendInScopeSetList (extendInScopeSetList is0 tvs) xs
          let fvs = Lens.foldMapOf freeLocalIds unitVarSet altE
              (binds,_) = List.partition ((`elemVarSet` fvs) . fst)
                        $ zip xs (Either.lefts args)
              binds1 = map (second (`mkTicks` ticks)) binds
              altE1 =
                case binds1 of
                  [] -> altE
                  _ ->
                    -- See Note [CaseCon deshadow]
                    let ((is3,substIds),binds2) = List.mapAccumL newBinder (is1,[]) binds1
                        subst = extendIdSubstList (mkSubst is3) substIds
                        body  = substTm "caseCon0" subst altE
                    in  case Maybe.catMaybes binds2 of
                          []     -> body
                          binds3 -> Letrec binds3 body
          -- Use the original inScopeSet 'is0' here, not the extended inScopeSet
          -- 'is1', otherwise we'd make the "caseCon1" substitution substitute
          -- free variables that were shadowed by the pattern!
          let subst = extendTvSubstList (mkSubst is0)
                    $ zip tvs (drop (length (dcUnivTyVars dc)) (Either.rights args))
          changed (substTm "caseCon1" subst altE1)
        _ ->
          case alts of
            -- In Core, default patterns always come first, so we match against
            -- that if there is one, and we couldn't match with any of the data
            -- patterns.
            ((DefaultPat,altE):_) -> changed altE
            _ -> changed (undefinedTm ty)
     where
      -- Check whether the pattern matches the data constructor
      equalCon (DataPat dcPat _ _) = dcTag dc == dcTag dcPat
      equalCon _ = False

      -- Decide whether the applied arguments of the data constructor should
      -- be let-bound, or substituted into the alternative. We decide this
      -- based on whether the argument has the potential to make the circuit
      -- larger than needed if we were to duplicate that argument.
      newBinder (isN0,substN) (x,arg)
        | isWorkFree arg
        = ((isN0,(x,arg):substN),Nothing)
        | otherwise
        = let x'   = uniqAway isN0 x
              isN1 = extendInScopeSet isN0 x'
          in  ((isN1,(x,Var x'):substN),Just (x',arg))

    -- The subject is a literal
    (Literal l,_,_) ->
      case List.find (equalLit . fst) alts of
        Just (LitPat _,altE) -> changed altE
        _ -> matchLiteralContructor e l alts
     where
      equalLit (LitPat l') = l == l'
      equalLit _ = False

    -- The subject is an applied primitive
    (Prim _,_,_) ->
      -- We try to reduce the applied primitive to WHNF
      whnfRW True ctx subj $ \ctx1 subj1 -> case collectArgsTicks subj1 of
        -- WHNF of subject is a literal, try `caseCon` with that
        (Literal l,_,_) -> caseCon ctx1 (Case (Literal l) ty alts)
        -- WHNF of subject is a data-constructor, try `caseCon` with that
        (Data _,_,_) -> caseCon ctx1 (Case subj1 ty alts)
#if MIN_VERSION_ghc(8,2,2)
        -- WHNF of subject is _|_, in the form of `absentError`: that means that
        -- the entire case-expression evaluates to _|_
        (Prim pInfo,_:msgOrCallStack:_,ticks)
          | primName pInfo == "Control.Exception.Base.absentError" ->
            let e1 = mkApps (mkTicks (Prim pInfo) ticks) [Right ty,msgOrCallStack]
            in  changed e1
#endif
        -- WHNF of subject is _|_, in the form of `absentError`, `patError`,
        -- or `undefined`: that means the entire case-expression is _|_
        (Prim pInfo,repTy:_:msgOrCallStack:_,ticks)
          | primName pInfo `elem` ["Control.Exception.Base.patError"
#if !MIN_VERSION_ghc(8,2,2)
                                  ,"Control.Exception.Base.absentError"
#endif
                                  ,"GHC.Err.undefined"] ->
            let e1 = mkApps (mkTicks (Prim pInfo) ticks) [repTy,Right ty,msgOrCallStack]
            in  changed e1
        -- WHNF of subject is _|_, in the form of our internal _|_-values: that
        -- means the entire case-expression is _|_
        (Prim pInfo,[_],ticks)
          | primName pInfo `elem` [ "Clash.Transformations.undefined"
                                  , "Clash.GHC.Evaluator.undefined"
                                  , "EmptyCase"] ->
            let e1 = mkApps (mkTicks (Prim pInfo) ticks) [Right ty]
            in  changed e1
        -- WHNF of subject is none of the above, so either a variable
        -- reference, or a primitive for which the evaluator doesn't have any
        -- evaluation rules.
        _ -> do
          let subjTy = termType tcm subj
          tran <- Lens.view typeTranslator
          reprs <- Lens.view customReprs
          case (`evalState` HashMapS.empty) (coreTypeToHWType tran reprs tcm subjTy) of
            Right (FilteredHWType (Void (Just hty)) _areVoids)
              | hty `elem` [BitVector 0, Unsigned 0, Signed 0, Index 1]
              -- If we know that the type of the subject is zero-bits wide and
              -- one of the Clash number types. Then the only valid alternative
              -- is the one that can match on the literal "0", so try 'caseCon'
              -- with that.
              -> caseCon ctx1 (Case (Literal (IntegerLiteral 0)) ty alts)
            _ -> do
              let ret = caseOneAlt e
              -- Otherwise check whether the entire case-expression has a
              -- single alternative, and pick that one.
              lvl <- Lens.view dbgLevel
              if lvl > DebugNone then do
                let subjIsConst = isConstant subj
                -- In debug mode we always report missing evaluation rules for
                -- the primitive evaluator
                traceIf (lvl > DebugNone && subjIsConst)
                  ("Irreducible constant as case subject: " ++ showPpr subj ++
                   "\nCan be reduced to: " ++ showPpr subj1)
                  ret
              else
                ret

    -- The subject is a variable
    (Var v, [], _) | isNum0 (varType v) ->
      -- If we know that the type of the subject is zero-bits wide and
      -- one of the Clash number types. Then the only valid alternative is
      -- the one that can match on the literal "0", so try 'caseCon' with
      -- that.
      caseCon ctx (Case (Literal (IntegerLiteral 0)) ty alts)
     where
      isNum0 (tyView -> TyConApp (nameOcc -> tcNm) [arg])
        | tcNm `elem`
          ["Clash.Sized.Internal.BitVector.BitVector"
          ,"Clash.Sized.Internal.Unsigned.Unsigned"
          ,"Clash.Sized.Internal.Signed.Signed"
          ]
        = isLitX 0 arg
        | tcNm == "Clash.Sized.Internal.Index.Index"
        = isLitX 1 arg
      isNum0 (coreView1 tcm -> Just t) = isNum0 t
      isNum0 _ = False

      isLitX n (LitTy (NumTy m)) = n == m
      isLitX n (coreView1 tcm -> Just t) = isLitX n t
      isLitX _ _ = False

    -- Otherwise check whether the entire case-expression has a single
    -- alternative, and pick that one.
    _ -> caseOneAlt e

caseCon _ e = return e
{-# SCC caseCon #-}

{- [Note: Name re-creation]
The names of heap bound variables are safely generated with mkUniqSystemId in
Clash.Core.Evaluator.newLetBinding. But only their uniqs end up in the heap,
not the complete names. So we use mkUnsafeSystemName to recreate the same Name.
-}

-- | Select the alternative of a case whose subject is an 'IntegerLiteral' or
-- 'NaturalLiteral' but whose alternatives match on constructors (by tag) or
-- on literal patterns.
--
-- Alternatives are searched in reverse order, so a 'DefaultPat' (which comes
-- first in Core) is only taken once nothing else matched. For Integer:
-- tag 1 matches values in @[-2^63, 2^63)@ and binds an 'IntLiteral'; tag 2
-- matches values @>= 2^63@ and binds the 'Jp#' limbs as a
-- 'ByteArrayLiteral'; tag 3 matches values @< -2^63@ and binds the 'Jn#'
-- limbs. For Natural: tag 1 matches @[0, 2^64)@ and binds a 'WordLiteral';
-- tag 2 matches values @>= 2^64@ and binds the 'Jp#' limbs. Only binders that
-- occur free in the chosen RHS are let-bound. Calls 'error' when no
-- alternative matches.
matchLiteralContructor
  :: Term
  -> Literal
  -> [(Pat,Term)]
  -> NormalizeSession Term
matchLiteralContructor c (IntegerLiteral l) alts = go (reverse alts)
 where
  go [(DefaultPat,e)] = changed e
  go ((DataPat dc [] xs,e):alts')
    | dcTag dc == 1
    , l >= ((-2)^(63::Int)) &&  l < 2^(63::Int)
    = let fvs       = Lens.foldMapOf freeLocalIds unitVarSet e
          (binds,_) = List.partition ((`elemVarSet` fvs) . fst)
                    $ zip xs [Literal (IntLiteral l)]
          e' = case binds of
                 [] -> e
                 _  -> Letrec binds e
      in  changed e'
    | dcTag dc == 2
    , l >= 2^(63::Int)
    = let !(Jp# !(BN# ba)) = l
          ba'       = BA.ByteArray ba
          bv        = PV.Vector 0 (BA.sizeofByteArray ba') ba'
          fvs       = Lens.foldMapOf freeLocalIds unitVarSet e
          (binds,_) = List.partition ((`elemVarSet` fvs) . fst)
                    $ zip xs [Literal (ByteArrayLiteral bv)]
          e' = case binds of
                 [] -> e
                 _  -> Letrec binds e
      in  changed e'
    | dcTag dc == 3
    , l < ((-2)^(63::Int))
    = let !(Jn# !(BN# ba)) = l
          ba'       = BA.ByteArray ba
          bv        = PV.Vector 0 (BA.sizeofByteArray ba') ba'
          fvs       = Lens.foldMapOf freeLocalIds unitVarSet e
          (binds,_) = List.partition ((`elemVarSet` fvs) . fst)
                    $ zip xs [Literal (ByteArrayLiteral bv)]
          e' = case binds of
                 [] -> e
                 _  -> Letrec binds e
      in  changed e'
    | otherwise
    = go alts'
  go ((LitPat l', e):alts')
    | IntegerLiteral l == l'
    = changed e
    | otherwise
    = go alts'
  go _ = error $ $(curLoc) ++ "Report as bug: caseCon error: " ++ showPpr c

matchLiteralContructor c (NaturalLiteral l) alts = go (reverse alts)
 where
  go [(DefaultPat,e)] = changed e
  go ((DataPat dc [] xs,e):alts')
    | dcTag dc == 1
    , l >= 0 && l < 2^(64::Int)
    = let fvs       = Lens.foldMapOf freeLocalIds unitVarSet e
          (binds,_) = List.partition ((`elemVarSet` fvs) . fst)
                    $ zip xs [Literal (WordLiteral l)]
          e' = case binds of
                 [] -> e
                 _  -> Letrec binds e
      in  changed e'
    | dcTag dc == 2
    , l >= 2^(64::Int)
    = let !(Jp# !(BN# ba)) = l
          ba'       = BA.ByteArray ba
          bv        = PV.Vector 0 (BA.sizeofByteArray ba') ba'
          fvs       = Lens.foldMapOf freeLocalIds unitVarSet e
          (binds,_) = List.partition ((`elemVarSet` fvs) . fst)
                    $ zip xs [Literal (ByteArrayLiteral bv)]
          e' = case binds of
                 [] -> e
                 _  -> Letrec binds e
      in  changed e'
    | otherwise
    = go alts'
  go ((LitPat l', e):alts')
    | NaturalLiteral l == l'
    = changed e
    | otherwise
    = go alts'
  go _ = error $ $(curLoc) ++ "Report as bug: caseCon error: " ++ showPpr c

matchLiteralContructor _ _ ((DefaultPat,e):_) = changed e
matchLiteralContructor c _ _ =
  error $ $(curLoc) ++ "Report as bug: caseCon error: " ++ showPpr c
{-# SCC matchLiteralContructor #-}

-- | Eliminate a case-expression whose result does not depend on which
-- alternative is taken:
--
-- * a case with a single alternative is replaced by that alternative's RHS,
--   provided the RHS does not mention any variable bound by the pattern;
--
-- * a case with two or more alternatives that all have the same RHS is
--   replaced by that RHS, provided the RHS does not mention any variable
--   bound by any of the patterns. Without this check, alternatives such as
--   @A x -> x@ and @B x -> x@ (both patterns binding a variable with the same
--   unique) compare equal, and removing the case would leave @x@ unbound.
caseOneAlt :: Term -> RewriteMonad extra Term
caseOneAlt e@(Case _ _ [(pat,altE)]) = case pat of
  DefaultPat -> changed altE
  LitPat _ -> changed altE
  DataPat _ tvs xs
    | (coerce tvs ++ coerce xs) `localVarsDoNotOccurIn` altE
    -> changed altE
    | otherwise
    -> return e

caseOneAlt (Case _ _ alts@((_,alt):_:_))
  | all ((== alt) . snd) (tail alts)
  , all (patVarsDoNotOccurIn alt . fst) alts
  = changed alt
  where
    -- Variables bound by a pattern must not occur in the shared RHS,
    -- otherwise they would escape their scope once the case is removed.
    patVarsDoNotOccurIn altE (DataPat _ tvs xs) =
      (coerce tvs ++ coerce xs) `localVarsDoNotOccurIn` altE
    patVarsDoNotOccurIn _ _ = True

caseOneAlt e = return e
{-# SCC caseOneAlt #-}

-- | Bring an application of a DataCon or Primitive in ANF, when the argument
-- is considered non-representable.
--
-- A non-representable let-argument has its bindings floated over the
-- application; case, lambda and type-lambda arguments are handled by
-- 'specializeNorm'.
nonRepANF :: HasCallStack => NormRewrite
nonRepANF ctx@(TransformContext is0 _) e@(App appConPrim arg)
  | (conPrim, _) <- collectArgs e
  , isCon conPrim || isPrim conPrim
  = do
    untranslatable <- isUntranslatable False arg
    case (untranslatable,stripTicks arg) of
      (True,Letrec binds body) ->
        -- This is a situation similar to Note [CaseLet deshadow]
        let (binds1,body1) = deshadowLetExpr is0 binds body
        in  changed (Letrec binds1 (App appConPrim body1))
      (True,Case {})  -> specializeNorm ctx e
      (True,Lam {})   -> specializeNorm ctx e
      (True,TyLam {}) -> specializeNorm ctx e
      _               -> return e

nonRepANF _ e = return e
{-# SCC nonRepANF #-}

-- | Ensure that top-level lambdas eventually bind a let-expression of which
-- the body is a variable-reference.
--
-- Only fires when every enclosing context is a lambda body or a tick, and
-- never on untranslatable expressions. The new binder is named "result".
topLet :: HasCallStack => NormRewrite
topLet (TransformContext is0 ctx) e
  | all (\c -> isLambdaBodyCtx c || isTickCtx c) ctx && not (isLet e) && not (isTick e)
  = do
    untranslatable <- isUntranslatable False e
    if untranslatable
      then return e
      else do
        tcm <- Lens.view tcCache
        argId <- mkTmBinderFor is0 tcm (mkUnsafeSystemName "result" 0) e
        changed (Letrec [(argId, e)] (Var argId))
  where
    isTick Tick{} = True
    isTick _ = False

topLet (TransformContext is0 ctx) e@(Letrec binds body)
  | all (\c -> isLambdaBodyCtx c || isTickCtx c) ctx
  = do
    let localVar = isLocalVar body
    untranslatable <- isUntranslatable False body
    if localVar || untranslatable
      then return e
      else do
        tcm <- Lens.view tcCache
        let is2 = extendInScopeSetList is0 (map fst binds)
        argId <- mkTmBinderFor is2 tcm (mkUnsafeSystemName "result" 0) body
        changed (Letrec (binds ++ [(argId,body)]) (Var argId))

topLet _ e = return e
{-# SCC topLet #-}

-- Misc rewrites

-- | Remove unused let-bindings.
--
-- Starting from the free local variables of the body, bindings are collected
-- transitively through the free variables of their RHSs; every binding not
-- reached is dropped. When nothing is reached the let is replaced by its body.
deadCode :: HasCallStack => NormRewrite
deadCode _ e@(Letrec binds body) = do
    let bodyFVs = Lens.foldMapOf freeLocalIds unitVarSet body
        used    = List.foldl' collectUsed emptyVarEnv (eltsVarSet bodyFVs)
    case eltsVarEnv used of
      [] -> changed body
      qqL | not (List.equalLength qqL binds)
          -> changed (Letrec qqL body)
          | otherwise
          -> return e
  where
    bindsEnv = mkVarEnv (map (\(x,e0) -> (x,(x,e0))) binds)

    collectUsed env v =
      if v `elemVarEnv` env then
        env
      else
        case lookupVarEnv v bindsEnv of
          Just (x,e0) ->
            let eFVs = Lens.foldMapOf freeLocalIds unitVarSet e0
            in  List.foldl' collectUsed
                            (extendVarEnv x (x,e0) env)
                            (eltsVarSet eFVs)
          Nothing -> env

deadCode _ e = return e
{-# SCC deadCode #-}

-- | Replace arguments that a primitive does not use by a
-- @Clash.Transformations.removedArg@ placeholder, and simplify a few other
-- expressions whose parts are unused (see the individual equations).
removeUnusedExpr :: HasCallStack => NormRewrite
removeUnusedExpr _ e@(collectArgsTicks -> (p@(Prim pInfo),args,ticks)) = do
  bbM <- HashMap.lookup (primName pInfo) <$> Lens.use (extra.primitives)
  let usedArgs0 =
        case Monad.join (extractPrim <$> bbM) of
          Just (BlackBoxHaskell{usedArguments}) ->
            case usedArguments of
              UsedArguments used -> Just used
              IgnoredArguments ignored -> Just ([0..length args - 1] \\ ignored)
          Just (BlackBox pNm _ _ _ _ _ _ _ _ inc r ri templ) -> Just $ if
            | isFromInt pNm -> [0,1,2]
            | primName pInfo `elem` [ "Clash.Annotations.BitRepresentation.Deriving.dontApplyInHDL"
                                    , "Clash.Sized.Vector.splitAt"
                                    ] -> [0,1]
            | otherwise -> concat [ maybe [] getUsedArguments r
                                  , maybe [] getUsedArguments ri
                                  , getUsedArguments templ
                                  , concatMap (getUsedArguments . snd) inc ]
          _ ->
            Nothing

  case usedArgs0 of
    Nothing ->
      return e
    Just usedArgs1 -> do
      tcm <- Lens.view tcCache
      (args1, Monoid.getAny -> hasChanged) <- listen (go tcm 0 usedArgs1 args)
      if hasChanged then
        return (mkApps (mkTicks p ticks) args1)
      else
        return e

  where
    arity = length . Either.rights . fst $ splitFunForallTy (primType pInfo)

    -- Walk the arguments, counting term arguments only; a term argument whose
    -- index is below the arity and not in the used list is replaced.
    go _ _ _ [] = return []
    go tcm !n used (Right ty:args') = do
      args'' <- go tcm n used args'
      return (Right ty : args'')
    go tcm !n used (Left tm : args') = do
      args'' <- go tcm (n+1) used args'
      case tm of
        TyApp (Prim p0) _
          | primName p0 == "Clash.Transformations.removedArg"
          -> return (Left tm : args'')
        _ -> do
          let ty = termType tcm tm
              p' = removedTm ty
          if n < arity && n `notElem` used
             then changed (Left p' : args'')
             else return (Left tm : args'')

removeUnusedExpr _ e@(Case _ _ [(DataPat _ [] xs,altExpr)]) =
  if xs `localIdsDoNotOccurIn` altExpr
     then changed altExpr
     else return e

-- Replace any expression that creates a Vector of size 0 within the application
-- of the Cons constructor, by the Nil constructor.
removeUnusedExpr _ e@(collectArgsTicks -> (Data dc, [_,Right aTy,Right nTy,_,Left a,Left nil],ticks))
  | nameOcc (dcName dc) == "Clash.Sized.Vector.Cons"
  = do
    tcm <- Lens.view tcCache
    case runExcept (tyNatSize tcm nTy) of
      Right 0
        | (con, _) <- collectArgs nil
        , not (isCon con)
        -> let eTy = termType tcm e
               (TyConApp vecTcNm _) = tyView eTy
               (Just vecTc) = lookupUniqMap vecTcNm tcm
               [nilCon,consCon] = tyConDataCons vecTc
               v = mkTicks (mkVec nilCon consCon aTy 1 [a]) ticks
           in  changed v
      _ -> return e

removeUnusedExpr _ e = return e
{-# SCC removeUnusedExpr #-}

-- | Inline let-bindings when the RHS is either a local variable reference or
-- is constant (except clock or reset generators).
--
-- Constant RHSs are additionally limited by 'inlineConstantLimit' on their
-- term size; a limit of 0 means "no limit".
bindConstantVar :: HasCallStack => NormRewrite
bindConstantVar = inlineBinders test
  where
    test _ (i,stripTicks -> e) = case isLocalVar e of
      -- Don't inline `let x = x in x`, it throws us in an infinite loop
      True -> return (i `localIdDoesNotOccurIn` e)
      _    -> isWorkFreeIsh e >>= \case
        True -> Lens.use (extra.inlineConstantLimit) >>= \case
          0 -> return True
          n -> return (termSize e <= n)
        _ -> return False
{-# SCC bindConstantVar #-}

-- | Push a cast over a case into its alternatives.
--
-- @
-- cast (case x of p -> e) ==> case x of p -> cast e
-- @
caseCast :: HasCallStack => NormRewrite
caseCast _ (Cast (stripTicks -> Case subj ty alts) ty1 ty2) =
  changed (Case subj ty [ (pat, Cast rhs ty1 ty2) | (pat, rhs) <- alts ])
caseCast _ e = return e
{-# SCC caseCast #-}

-- | Push a cast over a Letrec into its body:
--
-- @
-- cast (let binds in body) ==> let binds in cast body
-- @
letCast :: HasCallStack => NormRewrite
letCast _ (Cast (stripTicks -> Letrec binds body) ty1 ty2) =
  changed (Letrec binds (Cast body ty1 ty2))
letCast _ e = return e
{-# SCC letCast #-}

-- | Push a cast over an argument to a function into that function, by
-- specializing the function on the cast argument. Example:
--
-- @
-- y = f (cast a)
--   where f x = g x
-- @
--
-- transforms to:
--
-- @
-- y = f' a
--   where f' x' = (\x -> g x) (cast x')
-- @
--
-- The raison d'être for this transformation is that we hope to end up with an
-- expression where two casts are "back-to-back", after which we can eliminate
-- them in 'eliminateCastCast'. Specializing on a cast of a non work-free
-- expression still happens, but a warning is traced first.
argCastSpec :: HasCallStack => NormRewrite
argCastSpec ctx e@(App _ (stripTicks -> Cast e' _ _))
  | isWorkFree e' = specialise
  | otherwise     = trace warning specialise
  where
    specialise = specializeNorm ctx e
    warning = unwords
      [ "WARNING:", $(curLoc), "specializing a function on a non work-free"
      , "cast. Generated HDL implementation might contain duplicate work."
      , "Please report this as a bug.", "\n\nExpression where this occured:"
      , "\n\n" ++ showPpr e
      ]
argCastSpec _ e = return e
{-# SCC argCastSpec #-}

-- | Only inline casts that just contain a 'Var', because these are guaranteed
-- work-free. These are the result of the 'splitCastWork' transformation.
inlineCast :: HasCallStack => NormRewrite
inlineCast = inlineBinders isCastOfVar
  where
    isCastOfVar _ (_, Cast (stripTicks -> Var {}) _ _) = return True
    isCastOfVar _ _                                    = return False
{-# SCC inlineCast #-}

-- | Eliminate two back-to-back casts when, after type normalisation, the
-- inner cast's result type equals the outer cast's source type and the type
-- going in equals the type coming out:
--
-- @
-- (cast :: b -> a) $ (cast :: a -> b) x ==> x
-- @
--
-- When the types do not line up a 'ClashException' is thrown.
eliminateCastCast :: HasCallStack => NormRewrite
eliminateCastCast _ c@(Cast (stripTicks -> Cast e tyA tyB) tyB' tyC) = do
  tcm <- Lens.view tcCache
  let norm = normalizeType tcm
  if norm tyB == norm tyB' && norm tyA == norm tyC
    then changed e
    else do
      (nm, sp) <- Lens.use curFun
      throw (ClashException
               sp
               ($(curLoc) ++ showPpr nm
                 ++ ": Found 2 nested casts whose types don't line up:\n"
                 ++ showPpr c)
               Nothing)
eliminateCastCast _ e = return e
{-# SCC eliminateCastCast #-}

-- | Make a cast work-free by splitting its work off into a separate binding:
--
-- @
-- let x = cast (f a b)
-- ==>
-- let x  = cast x'
--     x' = f a b
-- @
--
-- Casts of a variable (already work-free) and casts of casts (handled by
-- 'eliminateCastCast') are left alone.
splitCastWork :: HasCallStack => NormRewrite
splitCastWork ctx@(TransformContext is0 _) original@(Letrec binds body) = do
  (splitBinds, Monoid.getAny -> anyChanged) <-
    listen (traverse (splitBinding is0) binds)
  if anyChanged
    then changed (Letrec (concat splitBinds) body)
    else return original
  where
    splitBinding :: InScopeSet -> LetBinding -> RewriteMonad extra [LetBinding]
    splitBinding isN bind@(nm, rhs) = case stripTicks rhs of
      Cast (Var {}) _ _  -> return [bind] -- already work-free
      Cast (Cast {}) _ _ -> return [bind] -- casts will be eliminated
      Cast inner ty1 ty2 -> do
        tcm <- Lens.view tcCache
        nm' <- mkTmBinderFor isN tcm
                 (mkDerivedName ctx (nameOcc (varName nm))) inner
        changed [ (nm', inner)
                , (nm, Cast (Var nm') ty1 ty2)
                ]
      _ -> return [bind]
splitCastWork _ e = return e
{-# SCC splitCastWork #-}

-- | Inline work-free functions, i.e.
fully applied functions that evaluate to -- a constant inlineWorkFree :: HasCallStack => NormRewrite inlineWorkFree _ e@(collectArgsTicks -> (Var f,args@(_:_),ticks)) = do tcm <- Lens.view tcCache let eTy = termType tcm e argsHaveWork <- or <$> mapM (either expressionHasWork (const (pure False))) args untranslatable <- isUntranslatableType True eTy let isSignal = isSignalType tcm eTy let lv = isLocalId f if untranslatable || isSignal || argsHaveWork || lv then return e else do bndrs <- Lens.use bindings case lookupVarEnv f bndrs of -- Don't inline recursive expressions Just b -> do isRecBndr <- isRecursiveBndr f if isRecBndr then return e else do let tm = mkTicks (bindingTerm b) (mkInlineTick f : ticks) changed $ mkApps tm args _ -> return e where -- an expression is has work when it contains free local variables, -- or has a Signal type, i.e. it does not evaluate to a work-free -- constant. expressionHasWork e' = do let fvIds = Lens.toListOf freeLocalIds e' tcm <- Lens.view tcCache let e'Ty = termType tcm e' isSignal = isSignalType tcm e'Ty return (not (null fvIds) || isSignal) inlineWorkFree _ e@(Var f) = do tcm <- Lens.view tcCache let fTy = varType f closed = not (isPolyFunCoreTy tcm fTy) isSignal = isSignalType tcm fTy untranslatable <- isUntranslatableType True fTy topEnts <- Lens.view topEntities let gv = isGlobalId f if closed && f `notElemVarSet` topEnts && not untranslatable && not isSignal && gv then do bndrs <- Lens.use bindings case lookupVarEnv f bndrs of -- Don't inline recursive expressions Just top -> do isRecBndr <- isRecursiveBndr f if isRecBndr then return e else do let topB = bindingTerm top sizeLimit <- Lens.use (extra.inlineWFCacheLimit) -- caching only worth it from a certain size onwards, otherwise -- the caching mechanism itself brings more of an overhead. 
if termSize topB < sizeLimit then changed topB else do b <- normalizeTopLvlBndr False f top changed (bindingTerm b) _ -> return e else return e inlineWorkFree _ e = return e {-# SCC inlineWorkFree #-} -- | Inline small functions inlineSmall :: HasCallStack => NormRewrite inlineSmall _ e@(collectArgsTicks -> (Var f,args,ticks)) = do untranslatable <- isUntranslatable True e topEnts <- Lens.view topEntities let lv = isLocalId f if untranslatable || f `elemVarSet` topEnts || lv then return e else do bndrs <- Lens.use bindings sizeLimit <- Lens.use (extra.inlineFunctionLimit) case lookupVarEnv f bndrs of -- Don't inline recursive expressions Just b -> do isRecBndr <- isRecursiveBndr f if not isRecBndr && bindingSpec b /= NoInline && termSize (bindingTerm b) < sizeLimit then do let tm = mkTicks (bindingTerm b) (mkInlineTick f : ticks) changed $ mkApps tm args else return e _ -> return e inlineSmall _ e = return e {-# SCC inlineSmall #-} -- | Specialise functions on arguments which are constant, except when they -- are clock, reset generators. constantSpec :: HasCallStack => NormRewrite constantSpec ctx@(TransformContext is0 tfCtx) e@(App e1 e2) | (Var {}, args) <- collectArgs e1 , (_, []) <- Either.partitionEithers args , null $ Lens.toListOf termFreeTyVars e2 = do specInfo<- constantSpecInfo ctx e2 if csrFoundConstant specInfo then let newBindings = csrNewBindings specInfo in if null newBindings then -- Whole of e2 is constant specializeNorm ctx (App e1 e2) else do -- Parts of e2 are constant let is1 = extendInScopeSetList is0 (fst <$> csrNewBindings specInfo) Letrec newBindings <$> specializeNorm (TransformContext is1 tfCtx) (App e1 (csrNewTerm specInfo)) else -- e2 has no constant parts return e constantSpec _ e = return e {-# SCC constantSpec #-} -- Experimental -- | Propagate arguments of application inwards; except for 'Lam' where the -- argument becomes let-bound. 
'appPropFast' tries to propagate as many arguments -- as possible, down as many levels as possible; and should be called in a -- top-down traversal. -- -- The idea is that this reduces the number of traversals, which hopefully leads -- to shorter compile times. -- -- Note [AppProp no shadowing] -- -- Case 1. -- -- Imagine: -- -- @ -- (case x of -- D a b -> h a) (f x y) -- @ -- -- rewriting this to: -- -- @ -- let b = f x y -- in case x of -- D a b -> h a b -- @ -- -- is very bad because 'b' in 'h a b' is now bound by the pattern instead of the -- newly introduced let-binding -- -- instead me must deshadow w.r.t. the new variable and rewrite to: -- -- @ -- let b = f x y -- in case x of -- D a b1 -> h a b -- @ -- -- Case 2. -- -- Imagine -- -- @ -- (\x -> e) u -- @ -- -- where @u@ has a free variable named @x@, rewriting this to: -- -- @ -- let x = u -- in e -- @ -- -- would be very bad, because the let-binding suddenly captures the free -- variable in @u@. To prevent this from happening we over-approximate and check -- whether @x@ is in the current InScopeSet, and deshadow if that's the case, -- i.e. we then rewrite to: -- -- let x1 = u -- in e [x:=x1] -- -- Case 3. -- -- The same for: -- -- @ -- (let x = w in e) u -- @ -- -- where @u@ again has a free variable @x@, rewriting this to: -- -- @ -- let x = w in (e u) -- @ -- -- would be bad because the let-binding now captures the free variable in @u@. -- -- To prevent this from happening, we unconditionally deshadow the function part -- of the application w.r.t. the free variables in the argument part of the -- application. It is okay to over-approximate in this case and deshadow w.r.t -- the current InScopeSet. 
appPropFast :: HasCallStack => NormRewrite
appPropFast ctx@(TransformContext is _) = \case
  -- Only (type-)applications are interesting; the function part is deshadowed
  -- w.r.t. the current InScopeSet first (see Note [AppProp no shadowing]).
  e@App {}
    | let (fun,args,ticks) = collectArgsTicks e
    -> go is (deShadowTerm is fun) args ticks
  e@TyApp {}
    | let (fun,args,ticks) = collectArgsTicks e
    -> go is (deShadowTerm is fun) args ticks
  e          -> return e
 where
  -- Push the pending arguments (and ticks) as far into the function part as
  -- possible.
  go
    :: InScopeSet
    -> Term
    -> [Either Term Type]
    -> [TickInfo]
    -> NormalizeSession Term
  -- Nested application: merge its arguments/ticks in front of the pending ones
  go is0 (collectArgsTicks -> (fun,args0@(_:_),ticks0)) args1 ticks1 =
    go is0 fun (args0 ++ args1) (ticks0 ++ ticks1)

  -- Lambda: substitute work-free or variable arguments, let-bind the rest
  go is0 (Lam v e) (Left arg:args) ticks = do
    setChanged
    if isWorkFree arg || isVar arg
      then do
        let subst = extendIdSubst (mkSubst is0) v arg
        (`mkTicks` ticks) <$> go is0 (substTm "appPropFast.AppLam" subst e) args []
      else do
        let is1 = extendInScopeSet is0 v
        Letrec [(v, arg)] <$> go is1 (deShadowTerm is1 e) args ticks

  -- Let: push the arguments into the body
  go is0 (Letrec vs e) args@(_:_) ticks = do
    setChanged
    let vbs = map fst vs
        is1 = extendInScopeSetList is0 vbs
    -- XXX: 'vs' should already be deshadowed w.r.t. 'is0'
    Letrec vs <$> go is1 e args ticks

  -- Type lambda: substitute the type argument
  go is0 (TyLam tv e) (Right t:args) ticks = do
    setChanged
    let subst = extendTvSubst (mkSubst is0) tv t
    (`mkTicks` ticks) <$> go is0 (substTm "appPropFast.TyAppTyLam" subst e) args []

  -- Case: push the arguments into every alternative. Arguments that are not
  -- work-free get let-bound once (via 'goCaseArg') so that they are not
  -- duplicated across alternatives; the alternatives are then deshadowed
  -- w.r.t. those new binders (Note [AppProp no shadowing], case 1).
  go is0 (Case scrut ty0 alts) args0@(_:_) ticks = do
    setChanged
    let isA1 = unionInScope
                 is0
                 ((mkInScopeSet . mkVarSet . concatMap (patVars . fst)) alts)
    (ty1,vs,args1) <- goCaseArg isA1 ty0 [] args0
    case vs of
      [] -> (`mkTicks` ticks) . Case scrut ty1 <$> mapM (goAlt is0 args1) alts
      _  -> do
        let vbs   = map fst vs
            is1   = extendInScopeSetList is0 vbs
            alts1 = map (deShadowAlt is1) alts
        Letrec vs . (`mkTicks` ticks) . Case scrut ty1 <$> mapM (goAlt is1 args1) alts1

  -- Tick: accumulate it and keep going
  go is0 (Tick sp e) args ticks = do
    setChanged
    go is0 e args (sp:ticks)

  -- Anything else: re-apply the pending ticks and arguments
  go _ fun args ticks = return (mkApps (mkTicks fun ticks) args)

  -- Push the arguments into a single alternative, with the pattern's binders in scope
  goAlt is0 args0 (p,e) = do
    let (tvs,ids) = patIds p
        is1       = extendInScopeSetList (extendInScopeSetList is0 tvs) ids
    (p,) <$> go is1 e args0 []

  -- Walk the arguments destined for a case-expression, computing the new case
  -- result type and let-binding term arguments that are neither work-free nor
  -- variables. Returns (result type, new let-bindings, rewritten arguments).
  goCaseArg isA ty0 ls0 (Right t:args0) = do
    tcm <- Lens.view tcCache
    let ty1 = piResultTy tcm ty0 t
    (ty2,ls1,args1) <- goCaseArg isA ty1 ls0 args0
    return (ty2,ls1,Right t:args1)

  goCaseArg isA0 ty0 ls0 (Left arg:args0) = do
    tcm <- Lens.view tcCache
    let argTy = termType tcm arg
        ty1   = applyFunTy tcm ty0 argTy
    case isWorkFree arg || isVar arg of
      True -> do
        (ty2,ls1,args1) <- goCaseArg isA0 ty1 ls0 args0
        return (ty2,ls1,Left arg:args1)
      False -> do
        boundArg <- mkTmBinderFor isA0 tcm (mkDerivedName ctx "app_arg") arg
        let isA1 = extendInScopeSet isA0 boundArg
        (ty2,ls1,args1) <- goCaseArg isA1 ty1 ls0 args0
        return (ty2,(boundArg,arg):ls1,Left (Var boundArg):args1)

  goCaseArg _ ty ls [] = return (ty,ls,[])
{-# SCC appPropFast #-}

-- | Flatten ridiculous case-statements generated by GHC
--
-- For case-statements in haskell of the form:
--
-- @
-- f :: Unsigned 4 -> Unsigned 4
-- f x = case x of
--   0 -> 3
--   1 -> 2
--   2 -> 1
--   3 -> 0
-- @
--
-- GHC generates Core that looks like:
--
-- @
-- f = \(x :: Unsigned 4) -> case x == fromInteger 3 of
--                             False -> case x == fromInteger 2 of
--                               False -> case x == fromInteger 1 of
--                                 False -> case x == fromInteger 0 of
--                                   False -> error "incomplete case"
--                                   True  -> fromInteger 3
--                                 True -> fromInteger 2
--                               True -> fromInteger 1
--                             True -> fromInteger 0
-- @
--
-- Which would result in a priority decoder circuit where a normal decoder
-- circuit was desired.
--
-- This transformation transforms the above Core to the saner:
--
-- @
-- f = \(x :: Unsigned 4) -> case x of
--        _ -> error "incomplete case"
--        0 -> fromInteger 3
--        1 -> fromInteger 2
--        2 -> fromInteger 1
--        3 -> fromInteger 0
-- @
caseFlat :: HasCallStack => NormRewrite
caseFlat _ e@(Case (collectEqArgs -> Just (scrut',_)) ty _) =
  case collectFlat scrut' e of
    -- 'collectFlat' always puts the 'DefaultPat' alternative last; move it to
    -- the front as in the example above.
    Just alts'@(_:_) -> changed (Case scrut' ty (last alts' : init alts'))
    _ -> return e
caseFlat _ e = return e
{-# SCC caseFlat #-}

-- | Collect the alternatives of a chain of nested @scrut == literal@ case
-- statements, all comparing against the same subject @scrut@.
--
-- Returns 'Nothing' when the term is not such an equality test. Otherwise it
-- returns a list with one 'LitPat' alternative per compared literal, ending in
-- a single 'DefaultPat' alternative.
collectFlat :: Term -> Term -> Maybe [(Pat,Term)]
collectFlat scrut (Case (collectEqArgs -> Just (scrut', val)) _ty [lAlt,rAlt])
  | scrut' == scrut
  = case collectArgs val of
      (Prim p,args') | isFromInt (primName p) ->
        lastArg args' >>= go
      (Data dc,args') | nameOcc (dcName dc) == "GHC.Types.I#" ->
        lastArg args' >>= go
      _ -> Nothing
  where
    -- Total replacement for 'last': the literal is the final argument
    lastArg xs = Maybe.listToMaybe (reverse xs)

    go (Left (Literal i)) = case (lAlt,rAlt) of
      ((pl,el),(pr,er))
        -- Left alternative is the False branch: the True branch (er) is taken
        -- for literal 'i', and the False branch may continue the chain.
        | isFalseDcPat pl || isTrueDcPat pr ->
           case collectFlat scrut el of
             Just alts' -> Just ((LitPat i, er) : alts')
             Nothing    -> Just [(LitPat i, er)
                                ,(DefaultPat, el)
                                ]
        | otherwise ->
           case collectFlat scrut er of
             Just alts' -> Just ((LitPat i, el) : alts')
             Nothing    -> Just [(LitPat i, el)
                                ,(DefaultPat, er)
                                ]
    go _ = Nothing

    isFalseDcPat (DataPat p _ _)
      = ((== "GHC.Types.False") . nameOcc . dcName) p
    isFalseDcPat _ = False

    isTrueDcPat (DataPat p _ _)
      = ((== "GHC.Types.True") . nameOcc . dcName) p
    isTrueDcPat _ = False

collectFlat _ _ = Nothing
{-# SCC collectFlat #-}

-- | Recognise a saturated application of one of the supported equality
-- primitives and return its two compared operands. Ticks around the
-- application are re-attached to the first operand.
--
-- Returns 'Nothing' for any other term, including an equality primitive whose
-- argument list does not have the expected shape (previously such a term made
-- the irrefutable pattern bindings fail at runtime).
collectEqArgs :: Term -> Maybe (Term,Term)
collectEqArgs (collectArgsTicks -> (Prim p, args, ticks))
  | nm == "Clash.Sized.Internal.BitVector.eq#"
  , [_,_,Left scrut,Left val] <- args
  = Just (mkTicks scrut ticks,val)
  | nm `elem` [ "Clash.Sized.Internal.Index.eq#"
              , "Clash.Sized.Internal.Signed.eq#"
              , "Clash.Sized.Internal.Unsigned.eq#"
              ]
  , [_,Left scrut,Left val] <- args
  = Just (mkTicks scrut ticks,val)
  | nm == "Clash.Transformations.eqInt"
  , [Left scrut,Left val] <- args
  = Just (mkTicks scrut ticks,val)
  where
    nm = primName p
collectEqArgs _ = Nothing

type NormRewriteW = Transform (StateT ([LetBinding],InScopeSet) (RewriteMonad NormalizeState))

-- | See Note [ANF InScopeSet]
tellBinders :: Monad m => [LetBinding] -> StateT ([LetBinding],InScopeSet) m ()
tellBinders bs = modify ((bs ++) *** (`extendInScopeSetList` (map fst bs)))

-- | See Note [ANF InScopeSet]; only extends the inscopeset
notifyBinders :: Monad m => [LetBinding] -> StateT ([LetBinding],InScopeSet) m ()
notifyBinders bs = modify (second (`extendInScopeSetList` (map fst bs)))

-- | Is the given type IO-like
isSimIOTy
  :: TyConMap
  -> Type
  -- ^ Type to check for IO-likeness
  -> Bool
isSimIOTy tcm ty = case tyView (coreView tcm ty) of
  TyConApp tcNm args
    | nameOcc tcNm == "Clash.Explicit.SimIO.SimIO"
    -> True
    | nameOcc tcNm == "GHC.Prim.(#,#)"
    , [_,_,st,_] <- args
    -> isStateTokenTy tcm st
  FunTy _ res -> isSimIOTy tcm res
  _ -> False

-- | Is the given type the state token
isStateTokenTy
  :: TyConMap
  -> Type
  -- ^ Type to check for state tokenness
  -> Bool
isStateTokenTy tcm ty = case tyView (coreView tcm ty) of
  TyConApp tcNm _ -> nameOcc tcNm == "GHC.Prim.State#"
  _ -> False

-- | Turn an expression into a modified ANF-form. As opposed to standard ANF,
-- constants do not become let-bound.
makeANF :: HasCallStack => NormRewrite
-- Descend through term lambdas, extending the InScopeSet and the context
makeANF (TransformContext is0 ctx) (Lam bndr e) = do
  e' <- makeANF (TransformContext (extendInScopeSet is0 bndr)
                                  (LamBody bndr:ctx))
                e
  return (Lam bndr e')

-- Type lambdas are returned unchanged
makeANF _ e@(TyLam {}) = return e

makeANF ctx@(TransformContext is0 _) e0
  = do
    -- We need to freshen all binders in `e` because we're shuffling them around
    -- into a single let-binder, because even when binders don't shadow, they
    -- don't have to be unique within an expression. And so lifting them all
    -- to a single let-binder will cause issues when they're not unique.
    --
    -- We cannot make freshening part of collectANF, because when we generate
    -- new binders, we need to make sure those names do not conflict with _any_
    -- of the existing binders in the expression.
    --
    -- See also Note [ANF InScopeSet]
    let (is2,e1) = freshenTm is0 e0
    -- 'collectANF' runs bottom-up and accumulates the new let-bindings in the
    -- state; 'listen' reports whether anything changed along the way.
    ((e2,(bndrs,_)),Monoid.getAny -> hasChanged) <-
      listen (runStateT (bottomupR collectANF ctx e1) ([],is2))
    case bndrs of
      -- No new bindings: only return the freshened term if something changed
      [] -> if hasChanged then return e2 else return e0
      _  -> do
        let (e3,ticks) = collectTicks e2
            (srcTicks,nmTicks) = partitionTicks ticks
        -- Ensure that `AppendName` ticks still scope over the entire expression
        changed (mkTicks (Letrec bndrs (mkTicks e3 srcTicks)) nmTicks)
{-# SCC makeANF #-}

-- | Note [ANF InScopeSet]
--
-- The InScopeSet contains:
--
--    1. All the free variables of the expression we are traversing
--
--    2. All the bound variables of the expression we are traversing
--
--    3. The newly created let-bindings as we recurse back up the traversal
--
-- All of these are needed to create let-bindings that
--
--    * Do not shadow
--    * Are not shadowed
--    * Nor conflict with each other (i.e. have the same unique)
--
-- Initially we start with the local InScopeSet and add the global variables:
--
-- @
-- is1 <- unionInScope is0 <$> Lens.use globalInScope
-- @
--
-- Which gives us the (superset of) free variables of the expression. Then
-- we call 'freshenTm'
--
-- @
-- let (is2,e1) = freshenTm is1 e0
-- @
--
-- Which extends the InScopeSet with all the bound variables in 'e1', the
-- version of 'e0' where all binders are unique (not just deshadowed).
--
-- So we start out with an InScopeSet that satisfies points 1 and 2, now every
-- time we create a new binder we must add it to the InScopeSet to satisfy
-- point 3.
--
-- NOTE(review): 'makeANF' currently passes its incoming InScopeSet straight to
-- 'freshenTm' without the 'globalInScope' union shown above; presumably the
-- caller's InScopeSet already covers the globals -- confirm.
--
-- Note [ANF no let-bind]
--
-- Do not let-bind:
--
-- 1. Arguments with an untranslatable type: untranslatable expressions
--    should be propagated down as far as possible
--
-- 2. Local variables or constants: they don't add any work, so no reason
--    to let-bind to enable sharing
--
-- 3. IO actions, the translation of IO actions to sequential HDL constructs
--    depends on IO actions to be propagated down as far as possible.
collectANF :: HasCallStack => NormRewriteW
-- Applications headed by a constructor, primitive or variable: let-bind the
-- last argument unless Note [ANF no let-bind] applies.
collectANF ctx e@(App appf arg)
  | (conVarPrim, _) <- collectArgs e
  , isCon conVarPrim || isPrim conVarPrim || isVar conVarPrim
  = do
    untranslatable <- lift (isUntranslatable False arg)
    let localVar = isLocalVar arg
    constantNoCR <- lift (isConstantNotClockReset arg)
    -- See Note [ANF no let-bind]
    case (untranslatable,localVar || constantNoCR, isSimBind conVarPrim,arg) of
      (False,False,False,_) -> do
        tcm <- Lens.view tcCache
        -- See Note [ANF InScopeSet]
        is1 <- Lens.use _2
        argId <- lift (mkTmBinderFor is1 tcm (mkDerivedName ctx "app_arg") arg)
        -- See Note [ANF InScopeSet]
        tellBinders [(argId,arg)]
        return (App appf (Var argId))
      -- Untranslatable let-expression argument: float its bindings out and
      -- keep only its body as the argument
      (True,False,_,Letrec binds body) -> do
        tellBinders binds
        return (App appf body)
      _ -> return e
  where
    isSimBind (Prim p) = primName p == "Clash.Explicit.SimIO.bindSimIO#"
    isSimBind _ = False

-- Let-expressions: float the bindings out. Unless the body is a local
-- variable, untranslatable, or SimIO-typed, the body itself is also bound to a
-- fresh "result" binder and replaced by a reference to it.
collectANF _ (Letrec binds body) = do
  tcm <- Lens.view tcCache
  let isSimIO = isSimIOTy tcm (termType tcm body)
  untranslatable <- lift (isUntranslatable False body)
  let localVar = isLocalVar body
  -- See Note [ANF no let-bind]
  if localVar || untranslatable || isSimIO
    then do
      tellBinders binds
      return body
    else do
      -- See Note [ANF InScopeSet]
      is1 <- Lens.use _2
      argId <- lift (mkTmBinderFor is1 tcm (mkUnsafeSystemName "result" 0) body)
      -- See Note [ANF InScopeSet]
      tellBinders [(argId,body)]
      tellBinders binds
      return (Var argId)

-- TODO: The code below special-cases ANF for the ':-' constructor for the
-- 'Signal' type. The 'Signal' type is essentially treated as a "transparent"
-- type by the Clash compiler, so observing its constructor leads to all kinds
-- of problems. In this case that "Clash.Rewrite.Util.mkSelectorCase" will
-- try to project the LHS and RHS of the ':-' constructor, however,
-- 'mkSelectorCase' uses 'coreView1' to find the "real" data-constructor.
-- 'coreView1' however looks through the 'Signal' type, and hence 'mkSelector'
-- finds the data constructors for the element type of Signal. This resulted in
-- error #24 (https://github.com/christiaanb/clash2/issues/24), where we
-- try to get the first field out of the 'Vec's 'Nil' constructor.
--
-- Ultimately we should stop treating Signal as a "transparent" type and deal
-- with the handling of the Signal type, and the involved co-recursive
-- functions, properly. At the moment, Clash cannot deal with this recursive
-- type and the recursive functions involved, hence the need for special-casing
-- code. After everything is done properly, we should remove the two lines below.
collectANF _ e@(Case _ _ [(DataPat dc _ _,_)])
  | nameOcc (dcName dc) == "Clash.Signal.Internal.:-"
  = return e

-- Case-expressions: let-bind the subject (unless it is a local variable or a
-- constant) and the alternatives' right-hand sides (subject to Note [ANF
-- no let-bind]). A single-alternative data pattern is replaced by selector
-- bindings for its fields.
collectANF ctx (Case subj ty alts) = do
    let localVar = isLocalVar subj
    let isConstantSubj = isConstant subj

    (subj',subjBinders) <- if localVar || isConstantSubj
      then return (subj,[])
      else do
        tcm <- Lens.view tcCache
        -- See Note [ANF InScopeSet]
        is1 <- Lens.use _2
        argId <- lift (mkTmBinderFor is1 tcm (mkDerivedName ctx "case_scrut") subj)
        -- See Note [ANF InScopeSet]; the binding itself is only emitted
        -- (via 'tellBinders') after the alternatives are processed.
        notifyBinders [(argId,subj)]
        return (Var argId,[(argId,subj)])

    tcm <- Lens.view tcCache
    let isSimIOAlt = isSimIOTy tcm ty

    alts' <- mapM (doAlt isSimIOAlt subj') alts
    tellBinders subjBinders

    case alts' of
      -- Single alternative whose pattern binds no type variables: drop the
      -- case when the pattern variables are unused (or it is a SimIO case)
      [(DataPat _ [] xs,altExpr)]
        | xs `localIdsDoNotOccurIn` altExpr || isSimIOAlt
        -> return altExpr
      _ -> return (Case subj' ty alts')
  where
    doAlt
      :: Bool -> Term -> (Pat,Term)
      -> StateT ([LetBinding],InScopeSet) (RewriteMonad NormalizeState)
                (Pat,Term)
    doAlt isSimIOAlt subj' alt@(DataPat dc exts xs,altExpr) | not (bindsExistentials exts xs) = do
      let lv = isLocalVar altExpr
      -- One selector binding per pattern variable
      patSels <- Monad.zipWithM (doPatBndr subj' dc) xs [0..]
      let altExprIsConstant = isConstant altExpr
      let usesXs (Var n) = any (== n) xs
          usesXs _       = False
      -- See [ANF no let-bind]
      if or [isSimIOAlt, lv && (not (usesXs altExpr) || length alts == 1), altExprIsConstant]
        then do
          -- See Note [ANF InScopeSet]
          tellBinders patSels
          return alt
        else do
          tcm <- Lens.view tcCache
          -- See Note [ANF InScopeSet]
          is1 <- Lens.use _2
          altId <- lift (mkTmBinderFor is1 tcm (mkDerivedName ctx "case_alt") altExpr)
          -- See Note [ANF InScopeSet]
          tellBinders (patSels ++ [(altId,altExpr)])
          return (DataPat dc exts xs,Var altId)
    -- Data patterns that bind existentials are left untouched
    doAlt _ _ alt@(DataPat {}, _) = return alt
    doAlt isSimIOAlt _ alt@(pat,altExpr) = do
      let lv = isLocalVar altExpr
      let altExprIsConstant = isConstant altExpr
      -- See [ANF no let-bind]
      if isSimIOAlt || lv || altExprIsConstant
        then return alt
        else do
          tcm <- Lens.view tcCache
          -- See Note [ANF InScopeSet]
          is1 <- Lens.use _2
          altId <- lift (mkTmBinderFor is1 tcm (mkDerivedName ctx "case_alt") altExpr)
          tellBinders [(altId,altExpr)]
          return (pat,Var altId)

    -- Bind pattern variable 'pId' to a selector that projects field 'i' of
    -- constructor 'dc' out of the (possibly let-bound) subject.
    doPatBndr
      :: Term -> DataCon -> Id -> Int
      -> StateT ([LetBinding],InScopeSet) (RewriteMonad NormalizeState)
                LetBinding
    doPatBndr subj' dc pId i
      = do
        tcm <- Lens.view tcCache
        -- See Note [ANF InScopeSet]
        is1 <- Lens.use _2
        patExpr <- lift (mkSelectorCase ($(curLoc) ++ "doPatBndr") is1 tcm subj' (dcTag dc) i)
        -- No need to 'tellBinders' here because 'pId' is already in the ANF
        -- InScopeSet.
        --
        -- See also Note [ANF InScopeSet]
        return (pId,patExpr)

collectANF _ e = return e
{-# SCC collectANF #-}

-- | Eta-expand top-level lambda's (DON'T use in a traversal!)
etaExpansionTL :: HasCallStack => NormRewrite
-- Existing lambda: keep it and continue underneath with its binder in scope
etaExpansionTL (TransformContext scope path) (Lam lamId lamBody) =
  Lam lamId <$> etaExpansionTL underLam lamBody
 where
  underLam = TransformContext (extendInScopeSet scope lamId) (LamBody lamId : path)

-- Let-expression: eta-expand the body, then hoist any resulting leading
-- lambdas out of the let so they end up at the top of the term.
etaExpansionTL (TransformContext scope path) (Letrec letBinds letBody) = do
  letBody' <- etaExpansionTL underLet letBody
  case peelLams letBody' of
    ([], _)         -> return (Letrec letBinds letBody')
    (lamIds, inner) -> changed (mkLams (Letrec letBinds inner) lamIds)
 where
  letIds   = map fst letBinds
  underLet = TransformContext (extendInScopeSetList scope letIds) (LetBody letIds : path)

  -- Split a term into its leading lambda binders and the remaining body
  peelLams :: Term -> ([Id], Term)
  peelLams (Lam i t) = let (ids, t') = peelLams t in (i : ids, t')
  peelLams t         = ([], t)

-- Any other term of function type: apply it to a fresh argument, recurse,
-- and abstract over that argument.
etaExpansionTL (TransformContext scope path) e = do
  tcm <- Lens.view tcCache
  if not (isFun tcm e)
    then return e
    else do
      let argTy = case splitFunTy tcm (termType tcm e) of
            Just (ty, _) -> ty
            Nothing      -> error $ $(curLoc) ++ "etaExpansion splitFunTy"
      argId <- mkInternalVar scope "arg" argTy
      let underNewLam = TransformContext (extendInScopeSet scope argId) (LamBody argId : path)
      expanded <- etaExpansionTL underNewLam (App e (Var argId))
      changed (Lam argId expanded)
{-# SCC etaExpansionTL #-}

-- | Eta-expand functions with a Synthesize annotation, needed to allow such
-- functions to appear as arguments to higher-order primitives.
etaExpandSyn :: HasCallStack => NormRewrite
etaExpandSyn (TransformContext is0 ctx) e@(collectArgs -> (Var f, _)) = do
  topEnts <- Lens.view topEntities
  tcm <- Lens.view tcCache
  let isTopEnt = f `elemVarSet` topEnts
      -- Are we (modulo ticks) in the function position of an application?
      isAppFunCtx =
        \case
          AppFun:_ -> True
          TickC _:c -> isAppFunCtx c
          _ -> False
      argTyM = fmap fst (splitFunTy tcm (termType tcm e))
  case argTyM of
    Just argTy | isTopEnt && not (isAppFunCtx ctx) -> do
      newId <- mkInternalVar is0 "arg" argTy
      changed (Lam newId (App e (Var newId)))
    _ -> return e

etaExpandSyn _ e = return e
{-# SCC etaExpandSyn #-}

-- | Does the type look like a class constraint, judged by the name of its
-- type constructor (a constraint tuple, or a class dictionary constructor)?
isClassConstraint :: Type -> Bool
isClassConstraint (tyView -> TyConApp nm0 _) =
  if
    -- Constraint tuple:
    | "GHC.Classes.(%" `Text.isInfixOf` nm1 -> True
    -- Constraint class:
    | "C:" `Text.isInfixOf` nm2 -> True
    | otherwise -> False
 where
  nm1 = nameOcc nm0
  nm2 = snd (Text.breakOnEnd "." nm1)

isClassConstraint _ = False

-- | Turn a normalized recursive function, where the recursive calls only pass
-- along the unchanged original arguments, into a let-recursive function. This
-- means that all recursive calls are replaced by the same variable reference as
-- found in the body of the top-level let-expression.
recToLetRec :: HasCallStack => NormRewrite
recToLetRec (TransformContext is0 []) e = do
  (fn,_) <- Lens.use curFun
  tcm <- Lens.view tcCache
  case splitNormalized tcm e of
    Right (args,bndrs,res) -> do
      let args' = map Var args
          (toInline,others) = List.partition (eqApp tcm fn args' . snd) bndrs
          resV = Var res
      case (toInline,others) of
        (_:_,_:_) -> do
          let is1 = extendInScopeSetList is0 (args ++ map fst bndrs)
          let substsInline = extendIdSubstList (mkSubst is1)
                           $ map (second (const resV)) toInline
              others' = map (second (substTm "recToLetRec" substsInline))
                            others
          changed $ mkLams (Letrec others' resV) args
        _ -> return e
    _ -> return e
  where
    -- This checks whether things are semantically equal. For example, say we
    -- have:
    --
    --   x :: (a, (b, c))
    --
    -- and
    --
    --   y :: (a, (b, c))
    --
    -- If we can determine that 'y' is constructed solely using the
    -- corresponding fields in 'x', then we can say they are semantically
    -- equal. The algorithm below keeps track of what (sub)field it is
    -- constructing, and checks if the field-expression projects the
    -- corresponding (sub)field from the target variable.
    --
    -- TODO: See [Note: Breaks on constants and predetermined equality]
    eqApp tcm v args (collectArgs . stripTicks -> (Var v',args'))
      | isGlobalId v'
      , v == v'
      , let args2 = Either.lefts args'
      , length args == length args2
      = and (zipWith (eqArg tcm) args args2)
    eqApp _ _ _ _ = False

    eqArg _ v1 v2@(stripTicks -> Var {})
      = v1 == v2
    eqArg tcm v1 v2@(collectArgs . stripTicks -> (Data _, args'))
      | let t1 = termType tcm v1
      , let t2 = termType tcm v2
      , t1 == t2
      = if isClassConstraint t1 then
          -- Class constraints are equal if their types are equal, so we can
          -- take a shortcut here.
          True
        else
          -- Check whether all arguments to the data constructor are projections
          --
          and (zipWith (eqDat v1) (map pure [0..]) (Either.lefts args'))
    eqArg _ _ _
      = False

    -- Recursively check whether a term /e/ is semantically equal to some variable /v/.
    -- Currently it can only assert equality when /e/ is syntactically equal
    -- to /v/, or is constructed out of projections of /v/, importantly:
    --
    -- [Note: Breaks on constants and predetermined equality]
    -- This function currently breaks if:
    --
    --   * One or more subfields are constants. Constants might have been
    --   inlined for the construction, instead of being a projection of the
    --   target variable.
    --
    --   * One or more subfields are determined to be equal and one is simply
    --   swapped / replaced by the other. For example, say we have
    --   `x :: (a, a)`. If GHC determines that both elements of the tuple will
    --   always be the same, it might replace the (semantically equal to 'x')
    --   construction of `y` with `(fst x, fst x)`.
    --
    eqDat :: Term -> [Int] -> Term -> Bool
    eqDat v fTrace (collectArgs . stripTicks -> (Data _, args)) =
      and (zipWith (eqDat v) (map (:fTrace) [0..]) (Either.lefts args))
    eqDat v1 fTrace v2 =
      case stripProjection (reverse fTrace) v1 v2 of
        Just [] -> True
        _ -> False

    stripProjection :: [Int] -> Term -> Term -> Maybe [Int]
    stripProjection fTrace0 vTarget0 (Case v _ [(DataPat _ _ xs, r)]) = do
      -- Get projection made in subject of case:
      fTrace1 <- stripProjection fTrace0 vTarget0 v

      -- Extract projection of this case statement. Subsequent calls to
      -- 'stripProjection' will check if new target is actually used.
      (n, fTrace2) <- List.uncons fTrace1
      vTarget1 <- List.indexMaybe xs n

      stripProjection fTrace2 (Var vTarget1) r

    stripProjection fTrace (Var sTarget) (Var s) =
      if sTarget == s then Just fTrace else Nothing

    stripProjection _fTrace _vTarget _v =
      Nothing

recToLetRec _ e = return e
{-# SCC recToLetRec #-}

-- | Inline a function with functional arguments
--
-- The number of times @f@ may be inlined into the current function is capped
-- by 'inlineLimit'; once exceeded, the expression is returned unchanged (with
-- a trace message when debugging is enabled).
inlineHO :: HasCallStack => NormRewrite
inlineHO _ e@(App _ _)
  | (Var f, args, ticks) <- collectArgsTicks e
  = do
    tcm <- Lens.view tcCache
    let hasPolyFunArgs = any (either (isPolyFun tcm) (const False)) args
    if hasPolyFunArgs
      then do
        (cf,_) <- Lens.use curFun
        isInlined <- zoomExtra (alreadyInlined f cf)
        limit <- Lens.use (extra.inlineLimit)
        if Maybe.fromMaybe 0 isInlined > limit
          then do
            lvl <- Lens.view dbgLevel
            traceIf (lvl > DebugNone) ($(curLoc) ++ "InlineHO: " ++ show f ++ " already inlined " ++ show limit ++ " times in:" ++ show cf) (return e)
          else do
            bodyMaybe <- lookupVarEnv f <$> Lens.use bindings
            case bodyMaybe of
              Just b -> do
                zoomExtra (addNewInline f cf)
                changed (mkApps (mkTicks (bindingTerm b) ticks) args)
              _ -> return e
      else return e

inlineHO _ e = return e
{-# SCC inlineHO #-}

-- | Simplified CSE, only works on let-bindings, does an inverse topological
-- sort of the let-bindings and then works from top to bottom
--
-- XXX: Check whether inverse top-sort followed by single traversal removes as
-- many binders as the previous "apply-until-fixpoint" approach in the presence
-- of recursive groups in the let-bindings. If not but just for checking whether
-- changes to transformation affect the eventual size of the circuit, it would
-- be really helpful if we tracked circuit size in the regression/test suite.
-- On the two examples that were tested, Reducer and PipelinesViaFolds, this new
-- version of CSE removed the same amount of let-binders.
simpleCSE :: HasCallStack => NormRewrite
simpleCSE (TransformContext is0 _) (inverseTopSortLetBindings -> Letrec bndrs body) = do
  let is1 = extendInScopeSetList is0 (map fst bndrs)
  (subst,bndrs1) <- reduceBinders (mkSubst is1) [] bndrs
  -- TODO: check whether a substitution over the body is enough, the reason I'm
  -- doing a substitution over the binders as well is that I don't know in
  -- what order a recursive group shows up in an inverse topological sort.
  -- Depending on the order and forgetting to apply the substitution over the
  -- let-bindings might lead to the introduction of free variables.
  --
  -- NB: don't apply the substitution to the entire let-expression, as that
  -- would rename the let-bindings because they've been added to the InScopeSet
  -- of the substitution.
  let bndrs2 = map (second (substTm "simpleCSE.bndrs" subst)) bndrs1
      body1  = substTm "simpleCSE.body" subst body
  return (Letrec bndrs2 body1)

simpleCSE _ e = return e
{-# SCC simpleCSE #-}

-- | XXX: is given inverse topologically sorted binders, but returns
-- topologically sorted binders
--
-- A binder whose (substituted) right-hand side equals that of an already
-- processed binder, and that carries no 'NoDeDup' tick, is dropped and mapped
-- to the earlier binder in the returned substitution.
--
-- TODO: check further speed improvements:
--
-- 1. Store the processed binders in a `Map Expr LetBinding`:
--    * Trades O(1) `cons` and O(n)*aeqTerm `find` for:
--    * O(log n)*aeqTerm `insert` and O(log n)*aeqTerm `lookup`
-- 2. Store the processed binders in a `AEQTrie Expr LetBinding`
--    * Trades O(1) `cons` and O(n)*aeqTerm `find` for:
--    * O(e) `insert` and O(e) `lookup`
reduceBinders
  :: Subst
  -> [LetBinding]
  -> [LetBinding]
  -> RewriteMonad NormalizeState (Subst, [LetBinding])
reduceBinders !subst processed [] = return (subst,processed)
reduceBinders !subst processed ((i,substTm "reduceBinders" subst -> e):rest)
  | (_,_,ticks) <- collectArgsTicks e
  , NoDeDup `notElem` ticks
  , Just (i1,_) <- List.find ((== e) . snd) processed
  = do
    let subst1 = extendIdSubst subst i (Var i1)
    setChanged
    reduceBinders subst1 processed rest
  | otherwise
  = reduceBinders subst ((i,e):processed) rest
{-# SCC reduceBinders #-}

-- | Evaluate a primitive application to WHNF with 'whnfRW'; only report a
-- change when the result is no longer an application of the same primitive.
reduceConst :: HasCallStack => NormRewrite
reduceConst ctx e@(App _ _)
  | (Prim p0, _) <- collectArgs e
  = whnfRW False ctx e $ \_ctx1 e1 -> case e1 of
      (collectArgs -> (Prim p1, _)) | primName p0 == primName p1 -> return e
      _ -> changed e1

reduceConst _ e = return e
{-# SCC reduceConst #-}

-- | Replace primitives by their "definition" if they would lead to let-bindings
-- with a non-representable type when a function is in ANF. This happens for
-- example when Clash.Sized.Vector.map consumes or produces a vector of
-- non-representable elements.
--
-- Basically what this transformation does is replace a primitive by the completely
-- unrolled recursive definition that it represents. e.g.
--
-- > zipWith ($) (xs :: Vec 2 (Int -> Int)) (ys :: Vec 2 Int)
--
-- is replaced by:
--
-- > let (x0  :: (Int -> Int))       = case xs  of (:>) _ x xr -> x
-- >     (xr0 :: Vec 1 (Int -> Int)) = case xs  of (:>) _ x xr -> xr
-- >     (x1  :: (Int -> Int))       = case xr0 of (:>) _ x xr -> x
-- >     (y0  :: Int)                = case ys  of (:>) _ y yr -> y
-- >     (yr0 :: Vec 1 Int)          = case ys  of (:>) _ y yr -> yr
-- >     (y1  :: Int)                = case yr0 of (:>) _ y yr -> y
-- > in  (($) x0 y0 :> ($) x1 y1 :> Nil)
--
-- Independently of which primitive is applied, an application of a primitive
-- whose result type is @Vec 0 a@ is replaced by the empty vector (@Nil@).
--
-- Currently, it only handles the following functions:
--
-- * Clash.Sized.Vector.zipWith
-- * Clash.Sized.Vector.map
-- * Clash.Sized.Vector.traverse#
-- * Clash.Sized.Vector.fold
-- * Clash.Sized.Vector.foldr
-- * Clash.Sized.Vector.dfold
-- * Clash.Sized.Vector.(++)
-- * Clash.Sized.Vector.head
-- * Clash.Sized.Vector.tail
-- * Clash.Sized.Vector.last
-- * Clash.Sized.Vector.init
-- * Clash.Sized.Vector.unconcat
-- * Clash.Sized.Vector.transpose
-- * Clash.Sized.Vector.replicate
-- * Clash.Sized.Vector.replace_int
-- * Clash.Sized.Vector.index_int
-- * Clash.Sized.Vector.imap
-- * Clash.Sized.Vector.dtfold
-- * Clash.Sized.Vector.reverse (only when 'normalizeUltra' is enabled)
-- * Clash.Sized.RTree.tdfold
-- * Clash.Sized.RTree.treplicate
-- * Clash.Sized.Internal.BitVector.split#
-- * Clash.Sized.Internal.BitVector.eq#
reduceNonRepPrim :: HasCallStack => NormRewrite
reduceNonRepPrim c@(TransformContext is0 ctx) e@(App _ _)
  | (Prim p, args, ticks) <- collectArgsTicks e
  = do
    tcm <- Lens.view tcCache
    ultra <- Lens.use (extra.normalizeUltra)
    let eTy = termType tcm e
    case tyView eTy of
      (TyConApp vecTcNm@(nameOcc -> "Clash.Sized.Vector.Vec")
                [runExcept . tyNatSize tcm -> Right 0, aTy]) -> do
        let (Just vecTc)     = lookupUniqMap vecTcNm tcm
            [nilCon,consCon] = tyConDataCons vecTc
            nilE             = mkVec nilCon consCon aTy 0 []
        changed (mkTicks nilE ticks)
      tv -> let argLen = length args in case primName p of
        "Clash.Sized.Vector.zipWith" | argLen == 7 -> do
          let [lhsElTy,rhsElty,resElTy,nTy] = Either.rights args
          case runExcept (tyNatSize tcm nTy) of
            Right n -> do
              shouldReduce1 <- List.orM [ pure (ultra || n < 2)
                                        , shouldReduce ctx
                                        , List.anyM isUntranslatableType_not_poly
                                                    [lhsElTy,rhsElty,resElTy] ]
              if shouldReduce1
                 then let [fun,lhsArg,rhsArg] = Either.lefts args
                      in  (`mkTicks` ticks) <$>
                          reduceZipWith c n lhsElTy rhsElty resElTy fun lhsArg rhsArg
                 else return e
            _ -> return e

        "Clash.Sized.Vector.map" | argLen == 5 -> do
          let [argElTy,resElTy,nTy] = Either.rights args
          case runExcept (tyNatSize tcm nTy) of
            Right n -> do
              shouldReduce1 <- List.orM [ pure (ultra || n < 2)
                                        , shouldReduce ctx
                                        , List.anyM isUntranslatableType_not_poly
                                                    [argElTy,resElTy] ]
              if shouldReduce1
                 then let [fun,arg] = Either.lefts args
                      in  (`mkTicks` ticks) <$> reduceMap c n argElTy resElTy fun arg
                 else return e
            _ -> return e

        "Clash.Sized.Vector.traverse#" | argLen == 7 ->
          let [aTy,fTy,bTy,nTy] = Either.rights args
          in  case runExcept (tyNatSize tcm nTy) of
                Right n ->
                  let [dict,fun,arg] = Either.lefts args
                  in  (`mkTicks` ticks) <$> reduceTraverse c n aTy fTy bTy dict fun arg
                _ -> return e

        "Clash.Sized.Vector.fold" | argLen == 4 -> do
          let [aTy,nTy] = Either.rights args
          case runExcept (tyNatSize tcm nTy) of
            Right n -> do
              shouldReduce1 <- List.orM [ pure (ultra || n == 0)
                                        , shouldReduce ctx
                                        , isUntranslatableType_not_poly aTy ]
              if shouldReduce1
                 then let [fun,arg] = Either.lefts args
                      in  (`mkTicks` ticks) <$> reduceFold c (n + 1) aTy fun arg
                 else return e
            _ -> return e

        "Clash.Sized.Vector.foldr" | argLen == 6 ->
          let [aTy,bTy,nTy] = Either.rights args
          in  case runExcept (tyNatSize tcm nTy) of
                Right n -> do
                  shouldReduce1 <- List.orM [ pure ultra
                                            , shouldReduce ctx
                                            , List.anyM isUntranslatableType_not_poly
                                                        [aTy,bTy] ]
                  if shouldReduce1
                     then let [fun,start,arg] = Either.lefts args
                          in  (`mkTicks` ticks) <$> reduceFoldr c n aTy fun start arg
                     else return e
                _ -> return e

        "Clash.Sized.Vector.dfold" | argLen == 8 ->
          let ([_kn,_motive,fun,start,arg],[_mTy,nTy,aTy]) =
                Either.partitionEithers args
          in  case runExcept (tyNatSize tcm nTy) of
                Right n -> (`mkTicks` ticks) <$> reduceDFold is0 n aTy fun start arg
                _ -> return e

        "Clash.Sized.Vector.++" | argLen == 5 ->
          let [nTy,aTy,mTy] = Either.rights args
              [lArg,rArg]   = Either.lefts args
          in  case (runExcept (tyNatSize tcm nTy), runExcept (tyNatSize tcm mTy)) of
                (Right n, Right m)
                  -- Appending to/from an empty vector yields the other operand.
                  -- Re-apply the ticks of the original application, like all
                  -- the other reductions in this transformation do.
                  | n == 0 -> changed (mkTicks rArg ticks)
                  | m == 0 -> changed (mkTicks lArg ticks)
                  | otherwise -> do
                    shouldReduce1 <- List.orM [ shouldReduce ctx
                                              , isUntranslatableType_not_poly aTy ]
                    if shouldReduce1
                       then (`mkTicks` ticks) <$> reduceAppend is0 n m aTy lArg rArg
                       else return e
                _ -> return e

        "Clash.Sized.Vector.head" | argLen == 3 -> do
          let [nTy,aTy] = Either.rights args
              [vArg]    = Either.lefts args
          case runExcept (tyNatSize tcm nTy) of
            Right n -> do
              shouldReduce1 <- List.orM [ shouldReduce ctx
                                        , isUntranslatableType_not_poly aTy ]
              if shouldReduce1
                 then (`mkTicks` ticks) <$> reduceHead is0 (n+1) aTy vArg
                 else return e
            _ -> return e

        "Clash.Sized.Vector.tail" | argLen == 3 -> do
          let [nTy,aTy] = Either.rights args
              [vArg]    = Either.lefts args
          case runExcept (tyNatSize tcm nTy) of
            Right n -> do
              shouldReduce1 <- List.orM [ shouldReduce ctx
                                        , isUntranslatableType_not_poly aTy ]
              if shouldReduce1
                 then (`mkTicks` ticks) <$> reduceTail is0 (n+1) aTy vArg
                 else return e
            _ -> return e

        "Clash.Sized.Vector.last" | argLen == 3 -> do
          let [nTy,aTy] = Either.rights args
              [vArg]    = Either.lefts args
          case runExcept (tyNatSize tcm nTy) of
            Right n -> do
              shouldReduce1 <- List.orM [ shouldReduce ctx
                                        , isUntranslatableType_not_poly aTy ]
              if shouldReduce1
                 then (`mkTicks` ticks) <$> reduceLast is0 (n+1) aTy vArg
                 else return e
            _ -> return e

        "Clash.Sized.Vector.init" | argLen == 3 -> do
          let [nTy,aTy] = Either.rights args
              [vArg]    = Either.lefts args
          case runExcept (tyNatSize tcm nTy) of
            Right n -> do
              shouldReduce1 <- List.orM [ shouldReduce ctx
                                        , isUntranslatableType_not_poly aTy ]
              if shouldReduce1
                 then (`mkTicks` ticks) <$> reduceInit is0 (n+1) aTy vArg
                 else return e
            _ -> return e

        "Clash.Sized.Vector.unconcat" | argLen == 6 -> do
          let ([_knN,_sm,arg],[mTy,nTy,aTy]) = Either.partitionEithers args
          case (runExcept (tyNatSize tcm nTy), runExcept (tyNatSize tcm mTy)) of
            (Right n, Right 0) -> (`mkTicks` ticks) <$> reduceUnconcat n 0 aTy arg
            _ -> return e

        "Clash.Sized.Vector.transpose" | argLen == 5 -> do
          let ([_knN,arg],[mTy,nTy,aTy]) = Either.partitionEithers args
          case (runExcept (tyNatSize tcm nTy), runExcept (tyNatSize tcm mTy)) of
            (Right n, Right 0) -> (`mkTicks` ticks) <$> reduceTranspose n 0 aTy arg
            _ -> return e

        "Clash.Sized.Vector.replicate" | argLen == 4 -> do
          let ([_sArg,vArg],[nTy,aTy]) = Either.partitionEithers args
          case runExcept (tyNatSize tcm nTy) of
            Right n -> do
              shouldReduce1 <- List.orM [ shouldReduce ctx
                                        , isUntranslatableType_not_poly aTy ]
              if shouldReduce1
                 then (`mkTicks` ticks) <$> reduceReplicate n aTy eTy vArg
                 else return e
            _ -> return e

        -- replace_int :: KnownNat n => Vec n a -> Int -> a -> Vec n a
        "Clash.Sized.Vector.replace_int" | argLen == 6 -> do
          let ([_knArg,vArg,iArg,aArg],[nTy,aTy]) = Either.partitionEithers args
          case runExcept (tyNatSize tcm nTy) of
            Right n -> do
              shouldReduce1 <- List.orM [ pure ultra
                                        , shouldReduce ctx
                                        , isUntranslatableType_not_poly aTy ]
              if shouldReduce1
                 then (`mkTicks` ticks) <$>
                      reduceReplace_int is0 n aTy eTy vArg iArg aArg
                 else return e
            _ -> return e

        "Clash.Sized.Vector.index_int" | argLen == 5 -> do
          let ([_knArg,vArg,iArg],[nTy,aTy]) = Either.partitionEithers args
          case runExcept (tyNatSize tcm nTy) of
            Right n -> do
              shouldReduce1 <- List.orM [ pure ultra
                                        , shouldReduce ctx
                                        , isUntranslatableType_not_poly aTy ]
              if shouldReduce1
                 then (`mkTicks` ticks) <$> reduceIndex_int is0 n aTy vArg iArg
                 else return e
            _ -> return e

        "Clash.Sized.Vector.imap" | argLen == 6 -> do
          let [nTy,argElTy,resElTy] = Either.rights args
          case runExcept (tyNatSize tcm nTy) of
            Right n -> do
              shouldReduce1 <- List.orM [ pure (ultra || n < 2)
                                        , shouldReduce ctx
                                        , List.anyM isUntranslatableType_not_poly
                                                    [argElTy,resElTy] ]
              if shouldReduce1
                 then let [_,fun,arg] = Either.lefts args
                      in  (`mkTicks` ticks) <$> reduceImap c n argElTy resElTy fun arg
                 else return e
            _ -> return e

        "Clash.Sized.Vector.dtfold" | argLen == 8 ->
          let ([_kn,_motive,lrFun,brFun,arg],[_mTy,nTy,aTy]) =
                Either.partitionEithers args
          in  case runExcept (tyNatSize tcm nTy) of
                Right n -> (`mkTicks` ticks) <$> reduceDTFold is0 n aTy lrFun brFun arg
                _ -> return e

        "Clash.Sized.Vector.reverse"
          | ultra
          , ([vArg],[nTy,aTy]) <- Either.partitionEithers args
          , Right n <- runExcept (tyNatSize tcm nTy)
          -> (`mkTicks` ticks) <$> reduceReverse is0 n aTy vArg

        "Clash.Sized.RTree.tdfold" | argLen == 8 ->
          let ([_kn,_motive,lrFun,brFun,arg],[_mTy,nTy,aTy]) =
                Either.partitionEithers args
          in  case runExcept (tyNatSize tcm nTy) of
                Right n -> (`mkTicks` ticks) <$> reduceTFold is0 n aTy lrFun brFun arg
                _ -> return e

        "Clash.Sized.RTree.treplicate" | argLen == 4 -> do
          let ([_sArg,vArg],[nTy,aTy]) = Either.partitionEithers args
          case runExcept (tyNatSize tcm nTy) of
            Right n -> do
              shouldReduce1 <- List.orM [ shouldReduce ctx
                                        , isUntranslatableType False aTy ]
              if shouldReduce1
                 then (`mkTicks` ticks) <$> reduceTReplicate n aTy eTy vArg
                 else return e
            _ -> return e

        "Clash.Sized.Internal.BitVector.split#" | argLen == 4 -> do
          let ([_knArg,bvArg],[nTy,mTy]) = Either.partitionEithers args
          case (runExcept (tyNatSize tcm nTy), runExcept (tyNatSize tcm mTy), tv) of
            (Right n, Right m, TyConApp tupTcNm [lTy,rTy])
              | n == 0 -> do
                let (Just tupTc) = lookupUniqMap tupTcNm tcm
                    [tupDc]      = tyConDataCons tupTc
                    tup          = mkApps (Data tupDc)
                                     [Right lTy
                                     ,Right rTy
                                     ,Left  bvArg
                                     ,Left  (removedTm rTy)
                                     ]
                changed (mkTicks tup ticks)
              | m == 0 -> do
                let (Just tupTc) = lookupUniqMap tupTcNm tcm
                    [tupDc]      = tyConDataCons tupTc
                    tup          = mkApps (Data tupDc)
                                     [Right lTy
                                     ,Right rTy
                                     ,Left  (removedTm lTy)
                                     ,Left  bvArg
                                     ]
                changed (mkTicks tup ticks)
            _ -> return e

        "Clash.Sized.Internal.BitVector.eq#"
          | ([_,_],[nTy]) <- Either.partitionEithers args
          , Right 0 <- runExcept (tyNatSize tcm nTy)
          , TyConApp boolTcNm [] <- tv
          -> let (Just boolTc)     = lookupUniqMap boolTcNm tcm
                 [_falseDc,trueDc] = tyConDataCons boolTc
             in  changed (mkTicks (Data trueDc) ticks)

        _ -> return e
  where
    -- True when the type is untranslatable AND has no free type variables,
    -- i.e. we only force reduction for monomorphic untranslatable types.
    isUntranslatableType_not_poly t = do
      u <- isUntranslatableType False t
      if u
         then return (null $ Lens.toListOf typeFreeVars t)
         else return False

reduceNonRepPrim _ e = return e
{-# SCC reduceNonRepPrim #-}

-- | This transformation lifts applications of global binders out of
-- alternatives of case-statements.
--
-- e.g. It converts:
--
-- @
-- case x of
--   A -> f 3 y
--   B -> f x x
--   C -> h x
-- @
--
-- into:
--
-- @
-- let f_arg0 = case x of {A -> 3; B -> x}
--     f_arg1 = case x of {A -> y; B -> x}
--     f_out  = f f_arg0 f_arg1
-- in  case x of
--       A -> f_out
--       B -> f_out
--       C -> h x
-- @
disjointExpressionConsolidation :: HasCallStack => NormRewrite
disjointExpressionConsolidation ctx@(TransformContext is0 _) e@(Case _scrut _ty _alts@(_:_:_)) = do
    (_,collected) <- collectGlobals is0 [] [] e
    let disJoint = filter (isDisjoint . snd . snd) collected
    if null disJoint
       then return e
       else do
         exprs <- mapM (mkDisjointGroup is0) disJoint
         tcm   <- Lens.view tcCache
         lids  <- Monad.zipWithM (mkFunOut is0 tcm) disJoint exprs
         let substitution = zip (map fst disJoint) (map Var lids)
             subsMatrix   = l2m substitution
         (exprs',_) <- unzip <$> Monad.zipWithM
                         (\s (e',seen) -> collectGlobals is0 s seen e')
                         subsMatrix
                         exprs
         (e',_) <- collectGlobals is0 substitution [] e
         let lb = Letrec (zip lids exprs') e'
         lb' <- bottomupR deadCode ctx lb
         changed lb'
  where
    -- Create the binder for the lifted-out application, named after the
    -- function/primitive being applied (suffixed with "Out").
    mkFunOut isN tcm (fun,_) (e',_) = do
      let ty   = termType tcm e'
          nm   = case collectArgs fun of
                   (Var v,_)  -> nameOcc (varName v)
                   (Prim p,_) -> primName p
                   _          -> "complex_expression_"
          nm'' = last (Text.splitOn "." nm) `Text.append` "Out"
      mkInternalVar isN nm'' ty

    -- For every element, the list of all the other elements.
    l2m = go []
      where
        go _  []     = []
        go xs (y:ys) = (xs ++ ys) : go (xs ++ [y]) ys

disjointExpressionConsolidation _ e = return e
{-# SCC disjointExpressionConsolidation #-}

-- | Given a function in the desired normal form, inline the let-bindings that
-- are not referenced directly by the body of the let-expression (i.e. only by
-- other let-bindings) and that bind one of the following.
--
-- For binders with a non-'User' name:
--
--   * a primitive that will be translated to an HDL expression (as opposed to
--     a HDL declaration), if it is referenced only once
--   * a projection case-expression (1 alternative)
--   * a data constructor
--   * I/O actions (@bindSimIO#@, or a multi-alternative case of type @SimIO@)
--
-- For binders with a 'User' name:
--
--   * the @openFile@, @fgetc@ and @feof@ SimIO primitives, if referenced only
--     once
--   * I/O actions (@bindSimIO#@, or a multi-alternative case of type @SimIO@)
--   * projections out of @BV@ / @Bit@, or out of @GHC.Classes@ constraint
--     tuples
--
-- Binders that turn out to be (part of) a recursive group are kept.
inlineCleanup :: HasCallStack => NormRewrite
inlineCleanup (TransformContext is0 _) (Letrec binds body) = do
  prims <- Lens.use (extra.primitives)
  -- For all let-bindings, count the number of times they are referenced by
  -- the other let-bindings. Some kinds of bindings are only inlined when they
  -- are referenced only once, because otherwise we would lose sharing.
  let is1       = extendInScopeSetList is0 (map fst binds)
      bindsFvs  = map (\(v,e) -> (v,((v,e),countFreeOccurances e))) binds
      allOccs   = List.foldl' (unionVarEnvWith (+)) emptyVarEnv
                $ map (snd.snd) bindsFvs
      bodyFVs   = Lens.foldMapOf freeLocalIds unitVarSet body
      (il,keep) = List.partition (isInteresting allOccs prims bodyFVs)
                                 bindsFvs
      keep'     = inlineBndrsCleanup is1 (mkVarEnv il) emptyVarEnv
                $ map snd keep

  if | null il    -> return (Letrec binds body)
     | null keep' -> changed body
     | otherwise  -> changed (Letrec keep' body)
  where
    -- Determine whether a let-binding is interesting to inline
    isInteresting
      :: VarEnv Int
      -> CompiledPrimMap
      -> VarSet
      -> (Id,((Id, Term), VarEnv Int))
      -> Bool
    isInteresting allOccs prims bodyFVs (id_,((_,(fst.collectArgs) -> tm),_))
      | nameSort (varName id_) /= User
      , id_ `notElemVarSet` bodyFVs
      = case tm of
          Prim pInfo
            | let nm = primName pInfo
            , Just (extractPrim -> Just p@(BlackBox {})) <- HashMap.lookup nm prims
            , TExpr <- kind p
            , Just occ <- lookupVarEnv id_ allOccs
            , occ < 2
            -> True
            | otherwise
            -> primName pInfo `elem` ["Clash.Explicit.SimIO.bindSimIO#"]
          Case _ _ [_] -> True
          Data _ -> True
          Case _ aTy (_:_:_)
            | TyConApp (nameOcc -> "Clash.Explicit.SimIO.SimIO") _ <- tyView aTy
            -> True
          _ -> False
      | id_ `notElemVarSet` bodyFVs
      = case tm of
          Prim pInfo
            | primName pInfo `elem` [ "Clash.Explicit.SimIO.openFile"
                                    , "Clash.Explicit.SimIO.fgetc"
                                    , "Clash.Explicit.SimIO.feof"
                                    ]
            , Just occ <- lookupVarEnv id_ allOccs
            , occ < 2
            -> True
            | otherwise
            -> primName pInfo `elem` ["Clash.Explicit.SimIO.bindSimIO#"]
          Case _ _ [(DataPat dcE _ _,_)]
            -> let nm = (nameOcc (dcName dcE))
               in -- Inlines WW projection that exposes internals of the BitVector types
                  nm == "Clash.Sized.Internal.BitVector.BV"  ||
                  nm == "Clash.Sized.Internal.BitVector.Bit" ||
                  -- Inlines projections out of constraint-tuples (e.g. HiddenClockReset)
                  "GHC.Classes" `Text.isPrefixOf` nm
          Case _ aTy (_:_:_)
            | TyConApp (nameOcc -> "Clash.Explicit.SimIO.SimIO") _ <- tyView aTy
            -> True
          _ -> False

    isInteresting _ _ _ _ = False

inlineCleanup _ e = return e
{-# SCC inlineCleanup #-}

-- | Mark to track progress of 'reduceBindersCleanup'
data Mark = Temp | Done | Rec

-- | Used by 'inlineCleanup' to inline binders that we want to inline into the
-- binders that we want to keep.
inlineBndrsCleanup
  :: InScopeSet
  -- ^ Current InScopeSet
  -> VarEnv ((Id,Term),VarEnv Int)
  -- ^ Original let-binders with their free variables (+ #occurrences), that we
  -- want to inline
  -> VarEnv ((Id,Term),VarEnv Int,Mark)
  -- ^ Processed let-binders with their free variables and a tag to mark the
  -- progress:
  -- * Temp: Currently being processed; meeting it again means a recursive cycle
  -- * Done: Processed, non-recursive
  -- * Rec:  Processed, recursive
  -> [((Id,Term),VarEnv Int)]
  -- ^ The let-binders with their free variables (+ #occurrences), that we want
  -- to keep
  -> [(Id,Term)]
inlineBndrsCleanup isN origInl = go
 where
  go doneInl [] =
    -- If some of the let-binders that we wanted to inline turn out to be
    -- recursive, then we have to keep those around as well, as we weren't able
    -- to inline them.
    [ (v,e) | ((v,e),_,Rec) <- eltsVarEnv doneInl ]
  go !doneInl (((v,e),eFVs):il) =
    let (sM,_,doneInl1) = foldlWithUniqueVarEnv'
                            (reduceBindersCleanup isN origInl)
                            (Nothing, emptyVarEnv, doneInl)
                            eFVs
        e1 = case sM of
               Nothing -> e
               Just s  -> substTm "inlineBndrsCleanup" s e
    in  (v,e1):go doneInl1 il
{-# SCC inlineBndrsCleanup #-}

-- | Used (transitively) by 'inlineCleanup' to inline to-inline let-binders into
-- the other to-inline let-binders.
reduceBindersCleanup
  :: InScopeSet
  -- ^ Current InScopeSet
  -> VarEnv ((Id,Term),VarEnv Int)
  -- ^ Original let-binders with their free variables (+ #occurrences)
  -> (Maybe Subst,VarEnv Int,VarEnv ((Id,Term),VarEnv Int,Mark))
  -- ^ Accumulated:
  --
  -- 1. (Maybe) the substitution built up so far
  -- 2. The free variables of the range of the substitution
  -- 3. Processed let-binders with their free variables and a tag to mark
  --    the progress:
  --    * Temp: Currently being processed; meeting it again means a cycle
  --    * Done: Processed, non-recursive
  --    * Rec:  Processed, recursive
  -> Unique
  -- ^ The unique of the let-binding that we want to simplify
  -> Int
  -- ^ Ignore, artifact of 'foldlWithUniqueVarEnv'
  -> (Maybe Subst,VarEnv Int,VarEnv ((Id,Term),VarEnv Int,Mark))
  -- ^ Same as the third argument
reduceBindersCleanup isN origInl (!substM,!substFVs,!doneInl) u _ =
  case lookupVarEnvDirectly u doneInl of
    Nothing -> case lookupVarEnvDirectly u origInl of
      Nothing ->
        -- let-binding not found, cannot extend the substitution
        (substM,substFVs,doneInl)
      Just ((v,e),eFVs) ->
        -- Simplify the transitive dependencies
        let (sM,substFVsE,doneInl1) =
              foldlWithUniqueVarEnv'
                (reduceBindersCleanup isN origInl)
                ( Nothing
                -- It's okay/needed to over-approximate the free variables of
                -- the range of the new substitution by including the free
                -- variables of the original let-binder, because this set of
                -- free variables is only used to check whether let-binding will
                -- become self-recursive after applying the substitution.
                --
                -- That is, it was already self-recursive, or becomes
                -- self-recursive after applying the substitution because it was
                -- part of a recursive group. And we do not want to inline
                -- recursive binders.
                , eFVs
                -- Temporarily extend the processing environment with the
                -- let-binding so we don't end up in a loop in case there is a
                -- recursive group.
                , extendVarEnv v ((v,e),eFVs,Temp) doneInl)
                eFVs

            e1 = case sM of
                   Nothing -> e
                   Just s  -> substTm "reduceBindersCleanup" s e
        in  if v `elemVarEnv` substFVsE then
              -- We cannot inline recursive let-bindings, so we do not extend
              -- the substitution environment.
              ( substM
              , substFVs
              -- And we explicitly mark the let-binding as recursive in the
              -- processing environment. So that it will be kept around at the
              -- end of 'inlineCleanup'
              , extendVarEnv v ((v,e1),substFVsE,Rec) doneInl1
              )
            else
              -- Extend the substitution
              ( Just (extendIdSubst (Maybe.fromMaybe (mkSubst isN) substM) v e1)
              , unionVarEnv substFVsE substFVs
              -- Mark the let-binding as fully "reduced", so we don't repeat
              -- this process when we encounter it again.
              , extendVarEnv v ((v,e1),substFVsE,Done) doneInl1
              )
    -- It's already been processed, just extend the substitution environment
    Just ((v,e),eFVs,Done) ->
      ( Just (extendIdSubst (Maybe.fromMaybe (mkSubst isN) substM) v e)
      , unionVarEnv eFVs substFVs
      , doneInl
      )
    -- It's either recursive (Rec), or part of a recursive group (Temp) where we
    -- originally entered a different part of the cycle. Regardless, we do not
    -- extend the substitution environment.
    Just _ ->
      ( substM
      , substFVs
      , doneInl
      )
{-# SCC reduceBindersCleanup #-}

-- | Flattens letrecs after `inlineCleanup`
--
-- `inlineCleanup` sometimes exposes additional possibilities for `caseCon`,
-- which then introduces let-bindings in what should be ANF. This transformation
-- flattens those nested let-bindings again.
--
-- NB: must only be called in the cleaning up phase.
flattenLet :: HasCallStack => NormRewrite
flattenLet (TransformContext is0 _) (Letrec binds body) = do
  let is1 = extendInScopeSetList is0 (map fst binds)
      bodyOccs = Lens.foldMapByOf
                   freeLocalIds (unionVarEnvWith (+))
                   emptyVarEnv (`unitVarEnv` (1 :: Int))
                   body
  (is2,binds1) <- second concat <$> List.mapAccumLM go is1 binds
  case binds1 of
    -- inline binders into the body when there's only a single binder, and only
    -- if that binder doesn't perform any work or is only used once in the body
    [(id1,e1)] | Just occ <- lookupVarEnv id1 bodyOccs, isWorkFree e1 || occ < 2 ->
      if id1 `localIdOccursIn` e1
         -- Except when the binder is recursive!
         then return (Letrec binds1 body)
         else let subst = extendIdSubst (mkSubst is2) id1 e1
              in  changed (substTm "flattenLet" subst body)
    _ -> return (Letrec binds1 body)
  where
    -- Flatten a single let-binding whose right-hand side is itself a
    -- let-expression (possibly wrapped in ticks) into a list of bindings.
    go :: InScopeSet -> LetBinding -> NormalizeSession (InScopeSet,[LetBinding])
    go isN (id1,collectTicks -> (Letrec binds1 body1,ticks)) = do
      let bs1 = map fst binds1
      let (binds2,body2,isN1) =
            -- We need to deshadow because we're merging nested let-expressions
            -- into a single let-expression: and within a let-expression, the
            -- bindings are not allowed to shadow each-other. Of course, we
            -- only need to deshadow if any shadowing is happening in the
            -- first place.
            --
            -- This is much better than blindly calling freshenTm, and saves
            -- almost 30% run-time of the normalization phase on some examples.
            if any (`elemInScopeSet` isN) bs1 then
              let Letrec bindsN bodyN = deShadowTerm isN (Letrec binds1 body1)
              in  (bindsN,bodyN,extendInScopeSetList isN (map fst bindsN))
            else
              (binds1,body1,extendInScopeSetList isN bs1)
      let bodyOccs = Lens.foldMapByOf
                       freeLocalIds (unionVarEnvWith (+))
                       emptyVarEnv (`unitVarEnv` (1 :: Int))
                       body2
          (srcTicks,nmTicks) = partitionTicks ticks
      -- Distribute the name ticks of the let-expression over all the bindings
      (isN1,) . map (second (`mkTicks` nmTicks)) <$> case binds2 of
        -- inline binders into the body when there's only a single binder, and
        -- only if that binder doesn't perform any work or is only used once in
        -- the body
        [(id2,e2)] | Just occ <- lookupVarEnv id2 bodyOccs, isWorkFree e2 || occ < 2 ->
          if id2 `localIdOccursIn` e2
             -- Except when the binder is recursive! We still keep the source
             -- ticks on the body, like the other branches do, so that source
             -- location information is not lost.
             then changed [(id2,e2),(id1,mkTicks body2 srcTicks)]
             else let subst = extendIdSubst (mkSubst isN1) id2 e2
                  in  changed [(id1
                                -- Only apply srcTicks to the body
                               ,mkTicks (substTm "flattenLetGo" subst body2)
                                        srcTicks)]
        bs -> changed (bs ++ [(id1
                               -- Only apply srcTicks to the body
                              ,mkTicks body2 srcTicks)])
    go isN b = return (isN,[b])

flattenLet _ e = return e
{-# SCC flattenLet #-}

-- | Worker function of 'separateArguments'.
separateLambda
  :: TyConMap
  -> TransformContext
  -> Id
  -- ^ Lambda binder
  -> Term
  -- ^ Lambda body
  -> Maybe Term
  -- ^ If lambda is split up, this function returns a Just containing the new term
separateLambda tcm ctx@(TransformContext is0 _) b eb0 =
  case shouldSplit tcm (varType b) of
    Just (dc,argTys@(_:_:_)) ->
      let
        nm = mkDerivedName ctx (nameOcc (varName b))
        bs0 = map (`mkLocalId` nm) argTys
        (is1, bs1) = List.mapAccumL newBinder is0 bs0
        -- Replace the original binder by the data constructor applied to the
        -- new, separated binders.
        subst = extendIdSubst (mkSubst is1) b (mkApps dc (map (Left . Var) bs1))
        eb1 = substTm "separateArguments" subst eb0
      in
        Just (mkLams eb1 bs1)
    _ ->
      Nothing
 where
  -- Make a binder unique w.r.t. the in-scope set, and add it to that set
  newBinder isN0 x =
    let
      x' = uniqAway isN0 x
      isN1 = extendInScopeSet isN0 x'
    in
      (isN1, x')
{-# SCC separateLambda #-}

-- | Split apart (global) function arguments that contain types that we
-- want to separate off, e.g. Clocks. Works on both the definition side (i.e. the
-- lambda), and the call site (i.e. the application of the global variable). e.g.
-- turns
--
-- > f :: (Clock System, Reset System) -> Signal System Int
--
-- into
--
-- > f :: Clock System -> Reset System -> Signal System Int
separateArguments :: HasCallStack => NormRewrite
separateArguments ctx e0@(Lam b eb) = do
  tcm <- Lens.view tcCache
  case separateLambda tcm ctx b eb of
    Just e1 -> changed e1
    Nothing -> return e0

separateArguments (TransformContext is0 _) e@(collectArgsTicks -> (Var g, args, ticks))
  | isGlobalId g = do
    -- We ensure that both the type of the global variable reference is updated
    -- to take into account the changed arguments, and that we apply the global
    -- function with the split apart arguments.
    let (argTys0,resTy) = splitFunForallTy (varType g)
    (concat -> args1, Monoid.getAny -> hasChanged)
      <- listen (mapM (uncurry splitArg) (zip argTys0 args))
    if hasChanged
      then
        let (argTys1,args2) = unzip args1
            gTy = mkPolyFunTy resTy argTys1
        in  return (mkApps (mkTicks (Var g {varType = gTy}) ticks) args2)
      else
        return e
  where
    -- Split a single argument
    splitArg
      :: Either TyVar Type
      -- The quantifier/function argument type of the global variable
      -> Either Term Type
      -- The applied type argument or term argument
      -> NormalizeSession [(Either TyVar Type,Either Term Type)]
    splitArg tv arg@(Right _) = return [(tv,arg)]
    splitArg ty arg@(Left tmArg) = do
      tcm <- Lens.view tcCache
      let argTy = termType tcm tmArg
      case shouldSplit tcm argTy of
        Just (_,argTys@(_:_:_)) -> do
          tmArgs <- mapM (mkSelectorCase ($(curLoc) ++ "splitArg") is0 tcm tmArg 1)
                         [0..length argTys - 1]
          changed (map ((ty,) . Left) tmArgs)
        _ ->
          return [(ty,arg)]

separateArguments _ e = return e
{-# SCC separateArguments #-}

-- | Remove all undefined alternatives from case expressions, replacing them
-- with the value of another defined alternative. If there is one defined
-- alternative, the entire expression is replaced with that alternative. If
-- there are no defined alternatives, the entire expression is replaced with
-- a call to 'errorX'. An alternative is considered undefined when
-- 'isPrimError' holds for its right-hand side. This transformation only does
-- something when 'aggressiveXOpt' is enabled.
--
-- e.g. It converts
--
--     case x of
--       D1 a -> f a
--       D2   -> undefined
--       D3   -> undefined
--
-- to
--
--     let subj = x
--         a    = case subj of
--                  D1 a -> field0
--      in f a
--
-- where fieldN is an internal variable referring to the nth argument of a
-- data constructor.
--
xOptimize :: HasCallStack => NormRewrite
xOptimize (TransformContext is0 _) e@(Case subj ty alts) = do
  runXOpt <- Lens.view aggressiveXOpt

  if runXOpt
    then do
      -- (alternatives that are errors, alternatives that are defined)
      defPart <- List.partitionM (isPrimError . snd) alts

      case defPart of
        ([], _)    -> return e
        (_, [])    -> changed (Prim (PrimInfo "Clash.XException.errorX" ty WorkConstant))
        (_, [alt]) -> xOptimizeSingle is0 subj alt
        (_, defs)  -> xOptimizeMany is0 subj ty defs
    else
      return e

xOptimize _ e = return e
{-# SCC xOptimize #-}

-- Return an expression equivalent to the alternative given. When only one
-- alternative is defined the result of this function is used to replace the
-- case expression.
--
xOptimizeSingle :: InScopeSet -> Term -> Alt -> NormalizeSession Term
xOptimizeSingle is subj (DataPat dc tvs vars, expr) = do
  tcm    <- Lens.view tcCache
  subjId <- mkInternalVar is "subj" (termType tcm subj)

  let fieldTys = fmap varType vars
  lets <- Monad.zipWithM (mkFieldSelector is subjId dc tvs fieldTys) vars [0..]

  changed (Letrec ((subjId, subj) : lets) expr)

xOptimizeSingle _ _ (_, expr) = changed expr

-- Given a list of alternatives which are defined, create a new case
-- expression which only ever returns a defined value.
--
xOptimizeMany
  :: HasCallStack
  => InScopeSet
  -> Term
  -> Type
  -> [Alt]
  -> NormalizeSession Term
xOptimizeMany is subj ty defs@(d:ds)
  | isAnyDefault defs = changed (Case subj ty defs)
  | otherwise = do
      -- Turn the first defined alternative into the default alternative
      newAlt <- xOptimizeSingle is subj d
      changed (Case subj ty $ ds <> [(DefaultPat, newAlt)])
 where
  isAnyDefault :: [Alt] -> Bool
  isAnyDefault = any ((== DefaultPat) . fst)

xOptimizeMany _ _ _ [] =
  error $ $(curLoc) ++ "Report as bug: xOptimizeMany error: No defined alternatives"

-- | Create a let-binding for @nm@ that projects the @index@-th field out of
-- the subject variable using a single-alternative case-expression.
mkFieldSelector
  :: MonadUnique m
  => InScopeSet
  -> Id
  -- ^ subject id
  -> DataCon
  -> [TyVar]
  -> [Type]
  -- ^ concrete types of fields
  -> Id
  -> Int
  -> m LetBinding
mkFieldSelector is0 subj dc tvs fieldTys nm index = do
  fields <- mapM (mkInternalVar is0 "field") fieldTys
  let alt = (DataPat dc tvs fields, Var $ fields !! index)
  return (nm, Case (Var subj) (fieldTys !! index) [alt])

-- Check whether a term is really a black box primitive representing an error.
-- Such values are undefined and are removed in X Optimization.
--
isPrimError :: Term -> NormalizeSession Bool
isPrimError (collectArgs -> (Prim pInfo, _)) = do
  prim <- Lens.use (extra . primitives . Lens.at (primName pInfo))

  case prim >>= extractPrim of
    Just p  -> return (isErr p)
    Nothing -> return False
 where
  isErr BlackBox{template=(BBTemplate [Err _])} = True
  isErr _ = False

isPrimError _ = return False