{-# LANGUAGE BangPatterns #-}
{-# LANGUAGE CPP #-}
{-# LANGUAGE LambdaCase #-}
{-# LANGUAGE NondecreasingIndentation #-}
{-# LANGUAGE OverloadedStrings #-}
{-# LANGUAGE Rank2Types #-}
{-# LANGUAGE FlexibleContexts #-}
{-# LANGUAGE TemplateHaskell #-}
{-# LANGUAGE ViewPatterns #-}
{-# OPTIONS_GHC -Wno-unused-imports #-}
module Clash.Rewrite.Util where
import Control.DeepSeq
import Control.Exception (throw)
import Control.Lens
(Lens', (%=), (+=), (^.), _3, _4, _Left)
import qualified Control.Lens as Lens
import qualified Control.Monad as Monad
#if !MIN_VERSION_base(4,13,0)
import Control.Monad.Fail (MonadFail)
#endif
import qualified Control.Monad.State.Strict as State
import qualified Control.Monad.Writer as Writer
import Data.Bifunctor (bimap)
import Data.Coerce (coerce)
import Data.Functor.Const (Const (..))
import Data.List (group, sort)
import qualified Data.Map as Map
import Data.Maybe (catMaybes,isJust,mapMaybe)
import qualified Data.Monoid as Monoid
import qualified Data.Set as Set
import qualified Data.Set.Lens as Lens
import Data.Text (Text)
import qualified Data.Text as Text
#ifdef HISTORY
import Data.Binary (encode)
import qualified Data.ByteString as BS
import qualified Data.ByteString.Lazy as BL
import System.IO.Unsafe (unsafePerformIO)
#endif
import BasicTypes (InlineSpec (..))
import Clash.Core.DataCon (dcExtTyVars)
import Clash.Core.FreeVars
(freeLocalVars, hasLocalFreeVars, localIdDoesNotOccurIn, localIdOccursIn,
typeFreeVars, termFreeVars', freeIds)
import Clash.Core.Name
import Clash.Core.Pretty (showPpr)
import Clash.Core.Subst
(aeqTerm, aeqType, extendIdSubst, mkSubst, substTm)
import Clash.Core.Term
(LetBinding, Pat (..), Term (..), CoreContext (..), Context, PrimInfo (..),
TmName, WorkInfo (..), TickInfo, collectArgs, collectArgsTicks)
import Clash.Core.TyCon
(TyConMap, tyConDataCons)
import Clash.Core.Type (KindOrType, Type (..),
TypeView (..), coreView1,
normalizeType,
typeKind, tyView, isPolyFunTy)
import Clash.Core.Util
(isPolyFun, mkAbstraction, mkApps, mkLams, mkTicks,
mkTmApps, mkTyApps, mkTyLams, termType, dataConInstArgTysE, isClockOrReset)
import Clash.Core.Var
(Id, IdScope (..), TyVar, Var (..), isLocalId, mkGlobalId, mkLocalId, mkTyVar)
import Clash.Core.VarEnv
(InScopeSet, VarEnv, elemVarSet, extendInScopeSetList, mkInScopeSet,
notElemVarEnv, uniqAway)
import Clash.Driver.Types
(DebugLevel (..))
import Clash.Netlist.Util (representableType)
import Clash.Rewrite.Types
import Clash.Unique
import Clash.Util
-- | Run a 'State.State' computation over the @extra@ part of the rewrite
-- state. The updated @extra@ is written back into the 'RewriteState'; the
-- environment and the accumulated writer output pass through unchanged.
zoomExtra :: State.State extra a
          -> RewriteMonad extra a
zoomExtra m = R step
  where
    step _env st wr = case State.runState m (st ^. extra) of
      (res, extra') -> (res, st {_extra = extra'}, wr)
-- | Find binders that are bound more than once by the same let-expression, or
-- by the same case-alternative pattern.
--
-- Each inner list is one group of equal, duplicated binders. An empty result
-- means no such accidental shadowing was found. Binders of different
-- let-expressions, lambdas or alternatives are not compared with each other.
findAccidentialShadows :: Term -> [[Id]]
findAccidentialShadows term = case term of
    Var {}        -> []
    Data {}       -> []
    Literal {}    -> []
    Prim {}       -> []
    Lam _ body    -> findAccidentialShadows body
    TyLam _ body  -> findAccidentialShadows body
    App fun arg   -> findAccidentialShadows fun ++ findAccidentialShadows arg
    TyApp fun _   -> findAccidentialShadows fun
    Cast body _ _ -> findAccidentialShadows body
    Tick _ body   -> findAccidentialShadows body
    Case subj _ alts ->
      concatMap (patDups . fst) alts
        ++ findAccidentialShadows subj
        ++ concatMap (findAccidentialShadows . snd) alts
    Letrec binds body ->
      dups (map fst binds) ++ findAccidentialShadows body
  where
    -- Duplicates among the binders introduced by a single pattern
    patDups :: Pat -> [[Id]]
    patDups pat = case pat of
      DataPat _ _ bndrs -> dups bndrs
      LitPat _          -> []
      DefaultPat        -> []

    -- Groups of equal binders that occur more than once
    dups :: [Id] -> [[Id]]
    dups = filter (\grp -> length grp > 1) . group . sort
-- | Wrap a named rewrite so that every application of it is tracked.
--
-- The rewrite runs under 'Writer.listen' to observe whether it reported a
-- change through its 'Monoid.Any' writer output (see 'changed' and
-- 'setChanged'):
--
-- * if it did, 'transformCounter' is incremented and the rewritten
--   expression is returned;
-- * if it did not, the ORIGINAL expression is returned, even when the
--   rewrite produced a different term.
--
-- When compiled with the @HISTORY@ CPP flag, each changed step is also
-- appended, binary-encoded, to the file @history.dat@. When 'dbgLevel' is not
-- 'DebugNone', the result is routed through 'applyDebug' for checks and
-- tracing.
apply
  :: String
  -- ^ Name of the transformation (used for debug output and history)
  -> Rewrite extra
  -- ^ The transformation to apply
  -> Rewrite extra
apply = \s rewrite ctx expr0 -> do
  lvl <- Lens.view dbgLevel
  (expr1,anyChanged) <- Writer.listen (rewrite ctx expr0)
  let hasChanged = Monoid.getAny anyChanged
      !expr2 = if hasChanged then expr1 else expr0
  Monad.when hasChanged (transformCounter += 1)
#ifdef HISTORY
  -- NOTE: writes the history file from inside the rewrite monad via
  -- 'unsafePerformIO'; the bang on the binding is what makes the write run.
  Monad.when hasChanged $ do
    (curBndr, _) <- Lens.use curFun
    let !_ = unsafePerformIO
             $ BS.appendFile "history.dat"
             $ BL.toStrict
             $ encode RewriteStep
                 { t_ctx = tfContext ctx
                 , t_name = s
                 , t_bndrS = showPpr (varName curBndr)
                 , t_before = expr0
                 , t_after = expr1
                 }
    return ()
#endif
  -- Only run the debug checks when debugging is enabled.
  if lvl == DebugNone
    then return expr2
    else applyDebug lvl s expr0 hasChanged expr2
{-# INLINE apply #-}
-- | Debug wrapper around a single rewrite step, used by 'apply' when the
-- debug level is not 'DebugNone'.
--
-- When the rewrite reported a change, it checks that the step:
--
-- * did not introduce new free local variables ('error' otherwise);
-- * did not create duplicate let/case binders, as found by
--   'findAccidentialShadows' ('error' otherwise);
-- * did not change the type of the expression (traced at 'DebugAll').
--
-- At 'DebugApplied' and above it also errors when the expressions differ
-- while no change was reported. Finally it traces the rewrite name and/or
-- the before/after terms according to the debug level, and returns the new
-- expression.
applyDebug
  :: DebugLevel
  -- ^ Current debug level
  -> String
  -- ^ Name of the transformation
  -> Term
  -- ^ Expression before the rewrite
  -> Bool
  -- ^ Whether the rewrite reported a change
  -> Term
  -- ^ Expression after the rewrite
  -> RewriteMonad extra Term
applyDebug lvl name exprOld hasChanged exprNew =
  traceIf (lvl >= DebugAll) ("Trying: " ++ name ++ " on:\n" ++ before) $ do
    Monad.when (lvl > DebugNone && hasChanged) $ do
      tcm <- Lens.view tcCache
      let beforeTy          = termType tcm exprOld
          beforeFV          = Lens.setOf freeLocalVars exprOld
          afterTy           = termType tcm exprNew
          afterFV           = Lens.setOf freeLocalVars exprNew
          newFV             = not (afterFV `Set.isSubsetOf` beforeFV)
          accidentalShadows = findAccidentialShadows exprNew

      Monad.when newFV $
        error ( concat [ $(curLoc)
                       , "Error when applying rewrite ", name
                       , " to:\n" , before
                       , "\nResult:\n" ++ after ++ "\n"
                       , "It introduces free variables."
                       , "\nBefore: " ++ showPpr (Set.toList beforeFV)
                       , "\nAfter: " ++ showPpr (Set.toList afterFV)
                       ]
              )

      Monad.when (not (null accidentalShadows)) $
        error ( concat [ $(curLoc)
                       , "Error when applying rewrite ", name
                       , " to:\n" , before
                       , "\nResult:\n" ++ after ++ "\n"
                       , "It accidentally creates shadowing let/case-bindings:\n"
                       , " ", showPpr accidentalShadows, "\n"
                       , "This usually means that a transformation did not extend "
                       , "or incorrectly extended its InScopeSet before applying a "
                       , "substitution."
                       ])

      -- Report only rewrites whose result type is NOT alpha-equivalent to
      -- the original type; type-preserving rewrites are the normal case.
      traceIf (lvl >= DebugAll && not (beforeTy `aeqType` afterTy))
              ( concat [ $(curLoc)
                       , "Error when applying rewrite ", name
                       , " to:\n" , before
                       , "\nResult:\n" ++ after ++ "\n"
                       , "Changes type from:\n", showPpr beforeTy
                       , "\nto:\n", showPpr afterTy
                       ]
              ) (return ())

    Monad.when (lvl >= DebugApplied && not hasChanged && not (exprOld `aeqTerm` exprNew)) $
      error $ $(curLoc) ++ "Expression changed without notice(" ++ name ++ "): before:\n"
                        ++ before ++ "\nafter:\n" ++ after

    traceIf (lvl >= DebugName && hasChanged) name $
      traceIf (lvl >= DebugApplied && hasChanged) ("Changes when applying rewrite to:\n"
                        ++ before ++ "\nResult:\n" ++ after ++ "\n") $
        traceIf (lvl >= DebugAll && not hasChanged) ("No changes when applying rewrite "
                          ++ name ++ " to:\n" ++ after ++ "\n") $
          return exprNew
  where
    before = showPpr exprOld
    after  = showPpr exprNew
-- | Apply a single named rewrite to a term. The rewrite sees an empty
-- context and the given 'InScopeSet'.
runRewrite
  :: String
  -> InScopeSet
  -> Rewrite extra
  -> Term
  -> RewriteMonad extra Term
runRewrite name is rewrite = apply name rewrite topCtx
  where
    topCtx = TransformContext is []
-- | Run a rewrite computation with the given environment and initial state,
-- and return its result.
--
-- The number of transformations applied (the final 'transformCounter') is
-- always traced, independent of the debug level.
runRewriteSession :: RewriteEnv
                  -> RewriteState extra
                  -> RewriteMonad extra a
                  -> a
runRewriteSession r s m = traceIf True summary result
  where
    (result, finalState, _) = runR m r s
    summary = concat [ "Clash: Applied "
                     , show (finalState ^. transformCounter)
                     , " transformations"
                     ]
-- | Record that the current rewrite changed the expression.
setChanged :: RewriteMonad extra ()
setChanged = changed ()
-- | Return a value while recording that the current rewrite changed the
-- expression (see 'apply').
changed :: a -> RewriteMonad extra a
changed val = val <$ Writer.tell (Monoid.Any True)
-- | The binder of the first 'LetBinding' entry in the context, if any.
closestLetBinder :: Context -> Maybe Id
closestLetBinder ctx =
  case [ bndr | LetBinding bndr _ <- ctx ] of
    bndr : _ -> Just bndr
    []       -> Nothing
-- | Derive a name from the closest enclosing let-binder: that binder's name
-- with @_@ followed by the suffix appended. Without an enclosing let-binder,
-- an internal name made from the suffix alone (unique 0) is returned.
mkDerivedName :: TransformContext -> OccName -> TmName
mkDerivedName (TransformContext _ ctx) sf =
  maybe (mkUnsafeInternalName sf 0)
        (\bndr -> appendToName (varName bndr) ('_' `Text.cons` sf))
        (closestLetBinder ctx)
-- | Create a fresh term-level binder for the term @e@.
--
-- The binder's name is a clone of @name@ with a fresh unique, and its type is
-- the type of @e@. It is made unique with respect to the given 'InScopeSet'.
-- See 'mkBinderFor'.
mkTmBinderFor
  :: (Monad m, MonadUnique m, MonadFail m)
  => InScopeSet
  -> TyConMap
  -> Name a
  -> Term
  -> m Id
mkTmBinderFor is tcm name e = do
  -- 'mkBinderFor' always returns 'Left' for a 'Left' input, so this pattern
  -- does not fail in practice; 'MonadFail' is only required by the pattern.
  Left r <- mkBinderFor is tcm name (Left e)
  return r
-- | Create a fresh binder for a term or a type.
--
-- The name is a clone of @name@ with a fresh unique, and the result is made
-- unique with respect to the given 'InScopeSet':
--
-- * for a term, a local 'Id' whose type is the term's type;
-- * for a type, a 'TyVar' whose kind is the type's kind.
mkBinderFor
  :: (Monad m, MonadUnique m, MonadFail m)
  => InScopeSet
  -> TyConMap
  -> Name a
  -> Either Term Type
  -> m (Either Id TyVar)
mkBinderFor is tcm name tmOrTy = do
  name' <- cloneName name
  return $ case tmOrTy of
    Left term -> Left (uniqAway is (mkLocalId (termType tcm term) (coerce name')))
    Right ty  -> Right (uniqAway is (mkTyVar (typeKind tcm ty) (coerce name')))
-- | Create a fresh local 'Id' of the given type. Its internal name is built
-- from @name@ and a fresh unique, and the 'Id' is made unique with respect to
-- the given 'InScopeSet'.
mkInternalVar
  :: (Monad m, MonadUnique m)
  => InScopeSet
  -> OccName
  -> KindOrType
  -> m Id
mkInternalVar inScope name ty =
  fmap (uniqAway inScope . mkLocalId ty . mkUnsafeInternalName name) getUniqueM
-- | Inline the binders of a let-expression that satisfy @condition@.
--
-- The condition also receives the whole let-expression. Selected binders are
-- substituted into the remaining binders and the body (see
-- 'substituteBinders'). The let is dropped when no binders remain. When no
-- binder is selected, the expression is returned unchanged and no change is
-- reported. Non-let expressions are returned unchanged.
inlineBinders
  :: (Term -> LetBinding -> RewriteMonad extra Bool)
  -> Rewrite extra
inlineBinders condition (TransformContext inScope0 _) expr@(Letrec xes res) = do
  (toInline, toKeep) <- partitionM (condition expr) xes
  if null toInline
    then return expr
    else do
      let inScope1 = extendInScopeSetList inScope0 (map fst xes)
          (kept, body) = substituteBinders inScope1 toInline toKeep res
      changed (if null kept then body else Letrec kept body)
inlineBinders _ _ e = return e
-- | 'True' when 'tailCalls' reports more than one call to the given 'Id' in
-- the term (and 'False' when 'tailCalls' gives 'Nothing').
isJoinPointIn :: Id
              -> Term
              -> Bool
isJoinPointIn id_ e = maybe False (> 1) (tailCalls id_ e)
-- | Count the occurrences of @id_@ in a term, provided that every occurrence
-- is in a position this analysis accepts. 'Nothing' means @id_@ occurs
-- somewhere else.
--
-- * A variable counts 1 if it is @id_@, and 0 otherwise.
-- * In an application, the argument must not mention @id_@ (count 0); the
--   count of the function part is returned.
-- * In a case-expression, the scrutinee must not mention @id_@, and the
--   counts of all alternatives are summed.
-- * For a let-expression, see the comment on that branch.
-- * Any other term counts 0.
tailCalls :: Id
          -> Term
          -> Maybe Int
tailCalls id_ = \case
  Var nm | id_ == nm -> Just 1
         | otherwise -> Just 0
  Lam _ e -> tailCalls id_ e
  TyLam _ e -> tailCalls id_ e
  App l r -> case tailCalls id_ r of
               Just 0 -> tailCalls id_ l
               _      -> Nothing
  TyApp l _ -> tailCalls id_ l
  -- Every bound expression must itself be analysable. If none of them
  -- mentions @id_@, only the body counts. Otherwise the let-binders are in
  -- turn analysed in the body; each must be analysable there, and the calls
  -- in the bound expressions are added to those in the body.
  Letrec bs e ->
    let (bsIds,bsExprs) = unzip bs
        bsTls           = map (tailCalls id_) bsExprs
        bsIdsUsed       = mapMaybe (\(l,r) -> pure l <* r) (zip bsIds bsTls)
        bsIdsTls        = map (`tailCalls` e) bsIdsUsed
        bsCount         = pure . sum $ catMaybes bsTls
    in  case (all isJust bsTls) of
          False -> Nothing
          True  -> case (all (==0) $ catMaybes bsTls) of
            False -> case all isJust bsIdsTls of
              False -> Nothing
              True  -> (+) <$> bsCount <*> tailCalls id_ e
            True -> tailCalls id_ e
  Case scrut _ alts ->
    let scrutTl = tailCalls id_ scrut
        altsTl  = map (tailCalls id_ . snd) alts
    in  case scrutTl of
          Just 0 | all (/= Nothing) altsTl -> Just (sum (catMaybes altsTl))
          _ -> Nothing
  _ -> Just 0
-- | 'True' for a lambda whose body is an application headed by a variable,
-- where the lambda-bound variable does not occur in that body.
isVoidWrapper :: Term -> Bool
isVoidWrapper term = case term of
  Lam bndr body
    | (Var _, _) <- collectArgs body -> bndr `localIdDoesNotOccurIn` body
  _ -> False
-- | Substitute let-bindings, from left to right, into the bindings still to
-- be processed, the kept bindings and the body.
--
-- A binding whose right-hand side refers to its own binder cannot be
-- substituted. It is prepended to the kept bindings instead, and nothing is
-- substituted for it. Returns the final kept bindings and body.
substituteBinders
  :: InScopeSet
  -> [LetBinding]
  -- ^ Bindings to substitute
  -> [LetBinding]
  -- ^ Bindings to keep (they are substituted into)
  -> Term
  -- ^ Body (it is substituted into)
  -> ([LetBinding],Term)
substituteBinders _ [] kept body = (kept, body)
substituteBinders inScope ((bndr, val):pending) kept body
  | bndr `localIdOccursIn` val
  = substituteBinders inScope pending ((bndr, val):kept) body
  | otherwise
  = substituteBinders inScope
      (map (second (substTm "substituteBindersRest" subst)) pending)
      (map (second (substTm "substituteBindersOthers" subst)) kept)
      (substTm "substituteBindersRes" subst body)
  where
    subst = extendIdSubst (mkSubst inScope) bndr val
-- | Conservative syntactic check for whether a term is work free.
--
-- The term is split into its head and arguments with 'collectArgs':
--
-- * a local variable that is not a polymorphic function is work free, and
--   its arguments are not inspected;
-- * a literal is work free;
-- * a data constructor, lambda, let-expression, single-alternative case or
--   cast is work free when its parts and all its term arguments are;
-- * a primitive is judged by its 'WorkInfo';
-- * anything else (including multi-alternative cases) is not work free.
isWorkFree
  :: Term
  -> Bool
isWorkFree term = case fun of
    Var i -> isLocalId i && not (isPolyFunTy (varType i))
    Data {} -> argsWorkFree
    Literal {} -> True
    Prim _ pInfo -> primWorkFree (primWorkInfo pInfo)
    Lam _ body -> isWorkFree body && argsWorkFree
    TyLam _ body -> isWorkFree body && argsWorkFree
    Letrec binds body ->
      isWorkFree body && all (isWorkFree . snd) binds && argsWorkFree
    Case subj _ [(_, alt)] -> isWorkFree subj && isWorkFree alt && argsWorkFree
    Cast body _ _ -> isWorkFree body && argsWorkFree
    _ -> False
  where
    (fun, args) = collectArgs term
    -- Type arguments never cost work
    argsWorkFree = all (either isWorkFree (const True)) args
    argsConstant = all (either isConstant (const True)) args
    primWorkFree wi = case wi of
      WorkConstant -> True
      WorkNever    -> argsWorkFree
      WorkVariable -> argsConstant
      WorkAlways   -> False
-- | Whether the name is one of the @fromInteger#@ primitives of the
-- BitVector, Index, Signed or Unsigned types (or BitVector's
-- @fromInteger##@).
isFromInt :: Text -> Bool
isFromInt nm = nm `elem` fromIntegerPrims
  where
    fromIntegerPrims :: [Text]
    fromIntegerPrims =
      [ "Clash.Sized.Internal.BitVector.fromInteger##"
      , "Clash.Sized.Internal.BitVector.fromInteger#"
      , "Clash.Sized.Internal.Index.fromInteger#"
      , "Clash.Sized.Internal.Signed.fromInteger#"
      , "Clash.Sized.Internal.Unsigned.fromInteger#"
      ]
-- | Whether a term is a constant. Constants are:
--
-- * a data constructor or primitive whose term arguments are all constants
--   (type arguments are always accepted);
-- * a lambda-headed term without free local variables;
-- * a literal (with any arguments).
isConstant :: Term -> Bool
isConstant e = case fun of
    Data _    -> argsConstant
    Prim _ _  -> argsConstant
    Lam _ _   -> not (hasLocalFreeVars e)
    Literal _ -> True
    _         -> False
  where
    (fun, args) = collectArgs e
    argsConstant = all (either isConstant (const True)) args
-- | Like 'isConstant', but with special handling for clock and reset terms.
--
-- A term whose type is a clock or reset (per 'isClockOrReset') counts as
-- constant only when it is headed by the @Clash.Transformations.removedArg@
-- primitive.
isConstantNotClockReset
  :: Term
  -> RewriteMonad extra Bool
isConstantNotClockReset e = do
  tcm <- Lens.view tcCache
  pure $
    if isClockOrReset tcm (termType tcm e)
      then case collectArgs e of
             (Prim nm _, _) -> nm == "Clash.Transformations.removedArg"
             _              -> False
      else isConstant e
-- | Inline or lift the let-binders that satisfy @condition@.
--
-- The selected binders are split by @inlineOrLift@, which also receives the
-- whole let-expression:
--
-- * binders for which it answers 'True' are substituted into the remaining
--   binders and the body (see 'substituteBinders');
-- * the others are lifted to global functions with 'liftBinding'. The
--   resulting applications are then substituted as well.
--
-- The let is dropped when no binders remain. When no binder satisfies
-- @condition@, the expression is returned unchanged and no change is
-- reported. Non-let expressions are returned unchanged.
inlineOrLiftBinders
  :: (LetBinding -> RewriteMonad extra Bool)
  -- ^ Should the binder be inlined or lifted at all?
  -> (Term -> LetBinding -> RewriteMonad extra Bool)
  -- ^ 'True' to inline the binder, 'False' to lift it
  -> Rewrite extra
inlineOrLiftBinders condition inlineOrLift (TransformContext inScope0 _) expr@(Letrec xes res) = do
  (replace,others) <- partitionM condition xes
  case replace of
    [] -> return expr
    _  -> do
      let inScope1 = extendInScopeSetList inScope0 (map fst xes)
      (doInline,doLift) <- partitionM (inlineOrLift expr) replace
      -- Substitute the inlined binders into the to-be-lifted binders, the
      -- other binders and the body.
      let (others',res') = substituteBinders inScope1 doInline (doLift ++ others) res
          -- Recover the to-be-lifted binders by identity, not by position.
          -- 'substituteBinders' prepends any self-referential binder it could
          -- not inline to its output. Taking the first @length doLift@
          -- elements would then lift that binder and leave a to-be-lifted
          -- binder behind.
          liftIds  = Set.fromList (map fst doLift)
          isLifted = (`Set.member` liftIds) . fst
          doLift'  = filter isLifted others'
          others'' = filter (not . isLifted) others'
      doLift'' <- mapM liftBinding doLift'
      let (others3,res'') = substituteBinders inScope1 doLift'' others'' res'
          newExpr = case others3 of
                      [] -> res''
                      _  -> Letrec others3 res''
      changed newExpr
inlineOrLiftBinders _ _ _ e = return e
-- | Lift a let-binding to a new global function.
--
-- The new function abstracts over the free type variables and free local
-- term variables of the bound expression. Global ids and the binder itself
-- are ignored. Occurrences of the binder inside the expression are replaced
-- by a call to the new function.
--
-- If an alpha-equivalent global binding already exists, it is reused and
-- nothing is added to 'bindings'. Returns the original binder, now bound to
-- the lifted function applied to the free variables.
--
-- Calls 'error' when the binder is a type variable.
liftBinding :: LetBinding
            -> RewriteMonad extra LetBinding
liftBinding (var@Id {varName = idName} ,e) = do
  -- Collect the free type variables and free local term variables of 'e',
  -- skipping global ids and 'var' itself.
  let unitFV :: Var a -> Const (UniqSet TyVar,UniqSet Id) (Var a)
      unitFV v@(Id {})    = Const (emptyUniqSet,unitUniqSet (coerce v))
      unitFV v@(TyVar {}) = Const (unitUniqSet (coerce v),emptyUniqSet)

      interesting :: Var a -> Bool
      interesting Id {idScope = GlobalId} = False
      interesting v@(Id {idScope = LocalId}) = varUniq v /= varUniq var
      interesting _ = True

      (boundFTVsSet,boundFVsSet) =
        getConst (Lens.foldMapOf (termFreeVars' interesting) unitFV e)
      boundFTVs = eltsUniqSet boundFTVsSet
      boundFVs  = eltsUniqSet boundFVsSet

  -- Name and type of the new global function, derived from the current
  -- function's name and the binder's occurrence name.
  tcm <- Lens.view tcCache
  let newBodyTy = termType tcm $ mkTyLams (mkLams e boundFVs) boundFTVs
  (cf,sp) <- Lens.use curFun
  newBodyNm <- cloneName (appendToName (varName cf) ("_" `Text.append` nameOcc idName))
  -- NOTE: only the Id's name gets 'nameSort = Internal'; the key used in
  -- 'bindings' below is 'newBodyNm' itself (same unique).
  let newBodyId = mkGlobalId newBodyTy newBodyNm {nameSort = Internal}

  -- The call that replaces the binder: the new function applied to the free
  -- type variables and then the free term variables.
  let newExpr = mkTmApps
                  (mkTyApps (Var newBodyId)
                            (map VarTy boundFTVs))
                  (map Var boundFVs)
      inScope0 = mkInScopeSet (coerce boundFVsSet)
      inScope1 = extendInScopeSetList inScope0 [var,newBodyId]
  let subst = extendIdSubst (mkSubst inScope1) var newExpr
      -- Replace self-references of 'var' inside 'e' by the new call
      e' = substTm "liftBinding" subst e
      newBody = mkTyLams (mkLams e' boundFVs) boundFTVs

  -- Reuse an alpha-equivalent existing global binding if there is one
  aeqExisting <- (eltsUniqMap . filterUniqMap ((`aeqTerm` newBody) . (^. _4))) <$> Lens.use bindings
  case aeqExisting of
    [] -> do
      bindings %= extendUniqMap newBodyNm
                    (newBodyId
                    ,sp
#if MIN_VERSION_ghc(8,4,1)
                    ,NoUserInline
#else
                    ,EmptyInlineSpec
#endif
                    ,newBody)
      return (var, newExpr)
    ((k,_,_,_):_) ->
      let newExpr' = mkTmApps
                       (mkTyApps (Var k)
                                 (map VarTy boundFTVs))
                       (map Var boundFVs)
      in  return (var, newExpr')
liftBinding _ = error $ $(curLoc) ++ "liftBinding: invalid core, expr bound to tyvar"
-- | Add a new global function with the given body, and return its 'Id'.
--
-- The function's name is a clone of @bndrNm@ with a fresh unique, and its
-- type is the type of the body. The binding is registered via
-- 'addGlobalBind'.
mkFunction
  :: TmName
  -- ^ Name to clone for the new function
  -> SrcSpan
  -> InlineSpec
  -> Term
  -- ^ Body of the new function
  -> RewriteMonad extra Id
mkFunction bndrNm sp inl body = do
  bodyTy <- (`termType` body) <$> Lens.view tcCache
  bodyNm <- cloneName bndrNm
  addGlobalBind bodyNm bodyTy sp inl body
  pure (mkGlobalId bodyTy bodyNm)
-- | Add (or overwrite) a global binding in 'bindings' under the given name.
-- The type and body are fully evaluated ('deepseq') before the map is
-- updated.
addGlobalBind
  :: TmName
  -> Type
  -> SrcSpan
  -> InlineSpec
  -> Term
  -> RewriteMonad extra ()
addGlobalBind vNm ty sp inl body =
  deepseq (ty, body) $
    bindings %= extendUniqMap vNm (mkGlobalId ty vNm, sp, inl, body)
-- | Copy a name, giving the copy a fresh unique.
cloneName
  :: (Monad m, MonadUnique m)
  => Name a
  -> m (Name a)
cloneName nm = fmap (\u -> nm {nameUniq = u}) getUniqueM
{-# INLINE isUntranslatable #-}
-- | Whether the type of the term is NOT representable according to
-- 'representableType'. The check uses the type translator, custom
-- representations and 'TyConMap' from the environment.
isUntranslatable
  :: Bool
  -- ^ String considered representable
  -> Term
  -> RewriteMonad extra Bool
isUntranslatable stringRepresentable tm = do
  tcm <- Lens.view tcCache
  translator <- Lens.view typeTranslator
  reprs <- Lens.view customReprs
  pure (not (representableType translator reprs stringRepresentable tcm (termType tcm tm)))
{-# INLINE isUntranslatableType #-}
-- | Whether the type is NOT representable according to 'representableType'.
-- The check uses the type translator, custom representations and
-- 'TyConMap' from the environment.
isUntranslatableType
  :: Bool
  -- ^ String considered representable
  -> Type
  -> RewriteMonad extra Bool
isUntranslatableType stringRepresentable ty = do
  translator <- Lens.view typeTranslator
  reprs <- Lens.view customReprs
  tcm <- Lens.view tcCache
  pure (not (representableType translator reprs stringRepresentable tcm ty))
-- | Create a fresh local binder named @wild@ of the given type (see
-- 'mkInternalVar').
mkWildValBinder
  :: (Monad m, MonadUnique m)
  => InScopeSet
  -> Type
  -> m Id
mkWildValBinder is ty = mkInternalVar is "wild" ty
-- | Build a case-expression that selects one field of one data constructor
-- from the subject:
--
-- > case scrut of DC _ .. sel .. _ -> sel
--
-- @dcI@ is 1-based: constructor @dcI - 1@ of the type's data-constructor list
-- is used. @fieldI@ is 0-based. All non-selected fields are bound to fresh
-- @wild@ binders.
--
-- Calls 'error' (mentioning @caller@) when:
--
-- * the subject is not of a data type;
-- * an index is out of range;
-- * the constructor's field types cannot be instantiated.
mkSelectorCase
  :: HasCallStack
  => (Functor m, Monad m, MonadUnique m)
  => String
  -- ^ Name of the caller, used in error messages
  -> InScopeSet
  -> TyConMap
  -> Term
  -- ^ Subject of the case-expression
  -> Int
  -- ^ 1-based index of the data constructor
  -> Int
  -- ^ 0-based index of the field
  -> m Term
mkSelectorCase caller inScope tcm scrut dcI fieldI = go (termType tcm scrut)
  where
    go (coreView1 tcm -> Just ty') = go ty'
    go scrutTy@(tyView -> TyConApp tc args) =
      case tyConDataCons (lookupUniqMap' tcm tc) of
        [] -> cantCreate $(curLoc) ("TyCon has no DataCons: " ++ show tc ++ " " ++ showPpr tc) scrutTy
        dcs
          | dcI < 1 -> cantCreate $(curLoc) "DC index must be at least 1" scrutTy
          | dcI > length dcs -> cantCreate $(curLoc) "DC index exceeds max" scrutTy
          | otherwise -> do
              let dc = indexNote ($(curLoc) ++ "No DC with tag: " ++ show (dcI-1)) dcs (dcI-1)
              -- Avoid a partial pattern: report a failed instantiation
              -- through 'cantCreate' rather than an irrefutable-pattern crash.
              case dataConInstArgTysE inScope tcm dc args of
                Nothing ->
                  cantCreate $(curLoc) "Could not instantiate the DC's field types" scrutTy
                Just fieldTys
                  | fieldI < 0 -> cantCreate $(curLoc) "Field index must be non-negative" scrutTy
                  | fieldI >= length fieldTys -> cantCreate $(curLoc) "Field index exceed max" scrutTy
                  | otherwise -> do
                      wildBndrs <- mapM (mkWildValBinder inScope) fieldTys
                      let ty = indexNote ($(curLoc) ++ "No DC field#: " ++ show fieldI) fieldTys fieldI
                      selBndr <- mkInternalVar inScope "sel" ty
                      let bndrs  = take fieldI wildBndrs ++ [selBndr] ++ drop (fieldI+1) wildBndrs
                          pat    = DataPat dc (dcExtTyVars dc) bndrs
                          retVal = Case scrut ty [ (pat, Var selBndr) ]
                      return retVal
    go scrutTy = cantCreate $(curLoc) ("Type of subject is not a datatype: " ++ showPpr scrutTy) scrutTy

    cantCreate loc info scrutTy = error $ loc ++ "Can't create selector " ++ show (caller,dcI,fieldI) ++ " for: (" ++ showPpr scrut ++ " :: " ++ showPpr scrutTy ++ ")\nAdditional info: " ++ info
-- | Specialise an application on its last argument (a term or a type)
-- using specialise'. Terms that are not applications are returned
-- unchanged.
specialise :: Lens' extra (Map.Map (Id, Int, Either Term Type) Id) -- ^ Lens into previous specialisations
           -> Lens' extra (VarEnv Int) -- ^ Lens into the specialisation history
           -> Lens' extra Int -- ^ Lens into the specialisation limit
           -> Rewrite extra
specialise specMapLbl specHistLbl specLimitLbl ctx e =
  case e of
    App fun tmArg   -> specialiseOn fun (Left tmArg)
    TyApp fun tyArg -> specialiseOn fun (Right tyArg)
    _               -> return e
  where
    specialiseOn fun arg =
      specialise' specMapLbl specHistLbl specLimitLbl ctx e (collectArgsTicks fun) arg
-- | Specialise a function application on one extra argument.
--
-- When the head is a variable @f@:
--
-- * Top entities are never specialised; a message is traced instead.
--   NOTE(review): the trace condition is @lvl >= DebugNone@, which looks like
--   it always holds; confirm whether @>@ was intended.
-- * A previous specialisation recorded under
--   @(f, number of args, abstracted argument)@ is reused.
-- * Otherwise, if @f@ has a global binding, a new function is created. Its
--   body is @f@'s body applied to fresh binders for the existing arguments
--   and to the specialised-on argument, abstracted over those binders and
--   the argument's free variables. It is recorded in the specialisation map
--   and history. Exceeding the specialisation limit for @f@ throws a
--   'ClashException'.
-- * A variable without a global binding leaves the term unchanged.
--
-- When the head is not a variable and the argument is a term, the argument
-- is lifted into a global function. An alpha-equivalent existing binding is
-- reused if there is one. Otherwise the term is returned unchanged.
specialise' :: Lens' extra (Map.Map (Id, Int, Either Term Type) Id) -- ^ Lens into previous specialisations
            -> Lens' extra (VarEnv Int) -- ^ Lens into specialisation history
            -> Lens' extra Int -- ^ Lens into the specialisation limit
            -> TransformContext -- ^ Transformation context
            -> Term -- ^ Original term
            -> (Term, [Either Term Type], [TickInfo]) -- ^ Head, arguments and ticks of the function part
            -> Either Term Type -- ^ Argument to specialise on
            -> RewriteMonad extra Term
specialise' specMapLbl specHistLbl specLimitLbl (TransformContext is0 _) e (Var f, args, ticks) specArgIn = do
  lvl <- Lens.view dbgLevel
  -- Don't specialise top entities
  topEnts <- Lens.view topEntities
  if f `elemVarSet` topEnts
    then traceIf (lvl >= DebugNone) ("Not specialising TopEntity: " ++ showPpr (varName f)) (return e)
    else do
      tcm <- Lens.view tcCache
      let specArg = bimap (normalizeTermTypes tcm) (normalizeType tcm) specArgIn
          -- Binders and variable references for the free variables of 'specArg'
          (specBndrsIn,specVars) = specArgBndrsAndVars specArg
          argLen = length args
          specBndrs :: [Either Id TyVar]
          specBndrs = map (Lens.over _Left (normalizeId tcm)) specBndrsIn
          specAbs :: Either Term Type
          specAbs = either (Left . (`mkAbstraction` specBndrs)) (Right . id) specArg
      -- Has 'f' already been specialised on this (type-normalised) argument?
      specM <- Map.lookup (f,argLen,specAbs) <$> Lens.use (extra.specMapLbl)
      case specM of
        -- Reuse the previous specialisation
        Just f' ->
          traceIf (lvl >= DebugApplied)
            ("Using previous specialization of " ++ showPpr (varName f) ++ " on " ++
              (either showPpr showPpr) specAbs ++ ": " ++ showPpr (varName f')) $
            changed $ mkApps (mkTicks (Var f') ticks) (args ++ specVars)
        -- Create a new specialised function
        Nothing -> do
          bodyMaybe <- fmap (lookupUniqMap (varName f)) $ Lens.use bindings
          case bodyMaybe of
            Just (_,sp,inl,bodyTm) -> do
              -- Guard against endlessly specialising on a growing argument
              specHistM <- lookupUniqMap f <$> Lens.use (extra.specHistLbl)
              specLim <- Lens.use (extra . specLimitLbl)
              if maybe False (> specLim) specHistM
                then throw (ClashException
                             sp
                             (unlines [ "Hit specialisation limit " ++ show specLim ++ " on function `" ++ showPpr (varName f) ++ "'.\n"
                                      , "The function `" ++ showPpr f ++ "' is most likely recursive, and looks like it is being indefinitely specialized on a growing argument.\n"
                                      , "Body of `" ++ showPpr f ++ "':\n" ++ showPpr bodyTm ++ "\n"
                                      , "Argument (in position: " ++ show argLen ++ ") that triggered termination:\n" ++ (either showPpr showPpr) specArg
                                      , "Run with '-fclash-spec-limit=N' to increase the specialisation limit to N."
                                      ])
                             Nothing)
                else do
                  -- Name the fresh binders after the existing lambda binders
                  -- of the body first, then after generated "pTS<n>" names.
                  let existingNames = collectBndrsMinusApps bodyTm
                      newNames = [ mkUnsafeInternalName ("pTS" `Text.append` Text.pack (show n)) n
                                 | n <- [(0::Int)..]
                                 ]
                  -- Fresh binders (and references to them) for the existing arguments
                  (boundArgs,argVars) <- fmap (unzip . map (either (Left &&& Left . Var) (Right &&& Right . VarTy))) $
                                         Monad.zipWithM
                                           (mkBinderFor is0 tcm)
                                           (existingNames ++ newNames)
                                           args
                  -- When specialising on a polymorphic function 'g' with a
                  -- global binding, name the result after 'g', take g's
                  -- inline spec, and use g's body (applied to its arguments)
                  -- instead of a reference to 'g'.
                  (fId,inl',specArg') <- case specArg of
                    Left a@(collectArgsTicks -> (Var g,gArgs,_gTicks)) -> if isPolyFun tcm a
                        then do
                          gTmM <- fmap (lookupUniqMap (varName g)) $ Lens.use bindings
                          return (g,maybe inl (^. _3) gTmM, maybe specArg (Left . (`mkApps` gArgs) . (^. _4)) gTmM)
                        else return (f,inl,specArg)
                    _ -> return (f,inl,specArg)
                  -- Create the specialised function
                  let newBody = mkAbstraction (mkApps bodyTm (argVars ++ [specArg'])) (boundArgs ++ specBndrs)
                  newf <- mkFunction (varName fId) sp inl' newBody
                  -- Remember the specialisation
                  (extra.specHistLbl) %= extendUniqMapWith f 1 (+)
                  (extra.specMapLbl) %= Map.insert (f,argLen,specAbs) newf
                  -- Call the specialised function
                  let newExpr = mkApps (mkTicks (Var newf) ticks) (args ++ specVars)
                  newf `deepseq` changed newExpr
            Nothing -> return e
  where
    -- Names of the leading lambda / type-lambda binders of a term. When an
    -- application is encountered, the last binder collected from its
    -- function part is dropped, since that binder is already applied.
    collectBndrsMinusApps :: Term -> [Name a]
    collectBndrsMinusApps = reverse . go []
      where
        go bs (Lam v e') = go (coerce (varName v):bs) e'
        go bs (TyLam tv e') = go (coerce (varName tv):bs) e'
        go bs (App e' _) = case go [] e' of
          [] -> bs
          bs' -> init bs' ++ bs
        go bs (TyApp e' _) = case go [] e' of
          [] -> bs
          bs' -> init bs' ++ bs
        go bs _ = bs
specialise' _ _ _ _ctx _ (appE,args,ticks) (Left specArg) = do
  -- Binders and variable references for the free variables of 'specArg'
  let (specBndrs,specVars) = specArgBndrsAndVars (Left specArg)
      -- The argument, abstracted over its free variables
      newBody = mkAbstraction specArg specBndrs
  -- Reuse an alpha-equivalent existing global binding if there is one
  existing <- filterUniqMap ((`aeqTerm` newBody) . (^. _4)) <$> Lens.use bindings
  newf <- case eltsUniqMap existing of
    [] -> do (cf,sp) <- Lens.use curFun
             mkFunction (appendToName (varName cf) "_specF")
                        sp
#if MIN_VERSION_ghc(8,4,1)
                        NoUserInline
#else
                        EmptyInlineSpec
#endif
                        newBody
    ((k,_,_,_):_) -> return k
  -- Replace the argument by a call to the lifted function
  let newArg = Left $ mkApps (Var newf) specVars
  let newExpr = mkApps (mkTicks appE ticks) (args ++ [newArg])
  changed newExpr
specialise' _ _ _ _ e _ _ = return e
-- | Normalise the types in a term, shallowly.
--
-- This covers the variable type of a 'Var' (see 'normalizeId') and both
-- types of a 'Cast', recursing only into the casted term. Every other term
-- is returned unchanged.
normalizeTermTypes :: TyConMap -> Term -> Term
normalizeTermTypes tcm = go
  where
    go (Cast inner fromTy toTy) =
      Cast (go inner) (normalizeType tcm fromTy) (normalizeType tcm toTy)
    go (Var v) = Var (normalizeId tcm v)
    go other   = other
-- | Normalise the type of a term-level variable. Type variables are returned
-- unchanged.
normalizeId :: TyConMap -> Id -> Id
normalizeId tcm v = case v of
  Id {} -> v {varType = normalizeType tcm (varType v)}
  _     -> v
-- | Binders for, and references to, the free variables of a specialisation
-- argument. Type variables come first, then term variables.
--
-- For a term, its free local type and term variables are used. For a type,
-- only its free type variables are used.
specArgBndrsAndVars
  :: Either Term Type
  -> ([Either Id TyVar], [Either Term Type])
specArgBndrsAndVars specArg =
  ( map Right ftvs ++ map Left fvs
  , map (Right . VarTy) ftvs ++ map (Left . Var) fvs
  )
  where
    unitFV :: Var a -> Const (UniqSet TyVar,UniqSet Id) (Var a)
    unitFV v@(Id {})    = Const (emptyUniqSet,unitUniqSet (coerce v))
    unitFV v@(TyVar {}) = Const (unitUniqSet (coerce v),emptyUniqSet)

    (ftvs, fvs) = case specArg of
      Left tm  -> (eltsUniqSet *** eltsUniqSet) . getConst $
                    Lens.foldMapOf freeLocalVars unitFV tm
      Right ty -> (eltsUniqSet (Lens.foldMapOf typeFreeVars unitUniqSet ty), [] :: [Id])