module Language.SequentCore.SpecConstr (
plugin
) where
import Language.SequentCore.Plugin
import Language.SequentCore.Pretty ()
import Language.SequentCore.Syntax
import Language.SequentCore.Translate
import CoreMonad ( CoreM
, Plugin(installCoreToDos), defaultPlugin
, errorMsg
, reinitializeGlobals
, CoreToDo(CoreDoSpecConstr, CoreDoPasses, CoreDoPluginPass) )
import CoreUnfold ( couldBeSmallEnoughToInline )
import CoreSyn ( CoreRule )
import DynFlags ( DynFlags(specConstrThreshold), getDynFlags )
import FastString ( fsLit, mkFastString )
import Id ( Id, mkSysLocalM, idName, idInlineActivation )
import Name ( nameOccName, occNameString )
import Outputable hiding ((<>))
import Rules ( mkRule, addIdSpecialisations )
import Var ( Var, varType )
import VarEnv
import VarSet
import Control.Applicative ( (<$>), (<|>) )
import Control.Monad
import Data.List ( nubBy )
import Data.Monoid
-- | Plugin entry point: swaps GHC's built-in SpecConstr pass for the
-- Sequent Core implementation in this module.  When no SpecConstr pass can
-- be found in the pipeline, an error message is emitted and the pipeline is
-- returned unchanged.
plugin :: Plugin
plugin = defaultPlugin { installCoreToDos = install }
  where
    install _ todos = do
      reinitializeGlobals
      maybe (notFound todos) return (replace todos)

    notFound todos = do
      errorMsg (text "Could not find SpecConstr pass to replace")
      return todos

    -- Replace the first CoreDoSpecConstr encountered.  Inside a
    -- CoreDoPasses group the nested passes are searched before the passes
    -- that follow the group.  Nothing means no SpecConstr pass was found.
    replace :: [CoreToDo] -> Maybe [CoreToDo]
    replace [] = Nothing
    replace (CoreDoSpecConstr : rest) = Just (specConstrPass : rest)
    replace (grp@(CoreDoPasses inner) : rest)
      =   (\inner' -> CoreDoPasses inner' : rest) <$> replace inner
      <|> (grp :) <$> replace rest
    replace (todo : rest) = (todo :) <$> replace rest

    specConstrPass = CoreDoPluginPass "SeqSpecConstr" (sequentPass specModule)
-- | Environment threaded downwards through the specialisation traversal.
data ScEnv
  = SCE { sc_size      :: Maybe Int
          -- ^ Size threshold used by 'specialize'; 'Nothing' disables the
          --   size check.
        , sc_how_bound :: VarEnv HowBound
          -- ^ How each classified in-scope variable was bound.
        , sc_dflags    :: DynFlags
          -- ^ Compiler flags, passed on to the size check.
        }

-- The dflags are intentionally left out of the pretty-printed form.
instance Outputable ScEnv where
  ppr (SCE { sc_size = limit, sc_how_bound = bound })
    = sep [hang (text "SCE {") 2 fieldsDoc, char '}']
    where
      fieldsDoc = sep
        [ text "sc_size" <+> equals <+> maybe (text "(any)") int limit <+> comma
        , text "sc_how_bound" <+> equals <+> ppr bound
        ]
-- | Information gathered bottom-up from a piece of code: the calls it makes
-- to functions bound as 'SpecFun', and the 'SpecArg' variables it
-- scrutinises with a @case@.
data ScUsage = ScUsage Calls ArgUsage

instance Outputable ScUsage where
  ppr (ScUsage callEnv argsUsed) =
    hang (text "ScUsage") 2 (sep [ppr callEnv, ppr argsUsed])
-- | Recorded calls, keyed by the function being called.
type Calls = VarEnv [Call]

-- | The argument values of one saturated call.
type Call = [SeqCoreValue]

-- | How an in-scope variable was bound, as far as specialisation cares:
-- 'SpecFun' for a binder of a recursive group, 'SpecArg' for a
-- lambda-bound parameter.
data HowBound = SpecFun | SpecArg

instance Outputable HowBound where
  ppr how = text $ case how of
    SpecFun -> "SpecFun"
    SpecArg -> "SpecArg"

-- | The 'SpecArg' variables that are scrutinised by a @case@.
type ArgUsage = VarSet
-- | Run the specialiser over every top-level binding of the module.  Each
-- binding is processed independently with the initial environment, and the
-- usage information collected for it is discarded.
specModule :: [SeqCoreBind] -> CoreM [SeqCoreBind]
specModule binds = do
  dflags <- getDynFlags
  let env = initScEnv dflags
  forM binds $ \bind -> snd <$> specInBind env bind
-- | Initial environment: no variables are classified yet, and the size
-- threshold is taken from the 'specConstrThreshold' field of the flags.
initScEnv :: DynFlags -> ScEnv
initScEnv dflags =
  SCE { sc_dflags    = dflags
      , sc_size      = specConstrThreshold dflags
      , sc_how_bound = emptyVarEnv
      }
-- | Usage recording no calls and no scrutinised arguments.
emptyScUsage :: ScUsage
emptyScUsage = ScUsage noCalls noArgs
  where
    noCalls = emptyVarEnv
    noArgs  = emptyVarSet
-- | Usages combine pointwise: the call lists recorded for the same function
-- are concatenated (left before right), and the scrutinised-argument sets
-- are unioned.
instance Monoid ScUsage where
  mempty = emptyScUsage
  mappend (ScUsage callsL usedL) (ScUsage callsR usedR) =
    ScUsage (plusVarEnv_C (++) callsL callsR) (unionVarSet usedL usedR)
-- | Specialise inside a value.
--
-- * A lambda's binder is classified as 'SpecArg' while its body is
--   processed.
--
-- * A 'Compute' is processed through its command.
--
-- * Any other value is returned unchanged with empty usage.
specInValue :: ScEnv -> SeqCoreValue -> CoreM (ScUsage, SeqCoreValue)
specInValue env value = case value of
  Lam x c -> do
    let env' = env { sc_how_bound = extendVarEnv (sc_how_bound env) x SpecArg }
    (usage, c') <- specInCommand env' c
    return (usage, Lam x c')
  Compute c -> do
    (usage, c') <- specInCommand env c
    return (usage, mkCompute c')
  _ -> return (emptyScUsage, value)
-- | Specialise inside a continuation, combining the usages of every part.
-- Argument values, case alternatives and the tails of casts and ticks are
-- all traversed.  Any other continuation is returned unchanged with empty
-- usage.
specInCont :: ScEnv -> SeqCoreCont -> CoreM (ScUsage, SeqCoreCont)
specInCont env cont = case cont of
    App v k -> do
      (uv, v') <- specInValue env v
      (uk, k') <- specInCont env k
      return (uv <> uk, App v' k')
    Case x t alts k -> do
      (altUsages, alts') <- unzip <$> mapM (specInAlt env) alts
      (uk, k') <- specInCont env k
      return (mconcat altUsages <> uk, Case x t alts' k')
    Cast co k -> wrap (Cast co) k
    Tick ti k -> wrap (Tick ti) k
    _ -> return (emptyScUsage, cont)
  where
    -- Specialise the tail and rebuild the single-tail wrapper around it.
    wrap rebuild k = do
      (u, k') <- specInCont env k
      return (u, rebuild k')
-- | Specialise inside the right-hand side of a case alternative.  The
-- alternative's own binders are not added to the environment.
specInAlt :: ScEnv -> SeqCoreAlt -> CoreM (ScUsage, SeqCoreAlt)
specInAlt env (Alt con binders rhs) =
  fmap (Alt con binders) <$> specInCommand env rhs
-- | Specialise a binding, discarding the environment that 'specBind'
-- produces for the binding's scope.
specInBind :: ScEnv -> SeqCoreBind -> CoreM (ScUsage, SeqCoreBind)
specInBind env bind = dropEnv <$> specBind env bind
  where
    dropEnv (usage, _, bind') = (usage, bind')
-- | Specialise inside a command.  The let-bindings are processed left to
-- right, each in the environment produced by the one before it.  The cut of
-- value and continuation is then processed in the final environment.  The
-- returned usage combines all of these parts.
specInCommand :: ScEnv -> SeqCoreCommand -> CoreM (ScUsage, SeqCoreCommand)
specInCommand env0 (Command { cmdLet = binds, cmdValue = val, cmdCont = cont })
  = go env0 binds [] []
  where
    -- Specialised bindings and their usages are accumulated in reverse.
    go :: ScEnv -> [SeqCoreBind] -> [SeqCoreBind] -> [ScUsage]
       -> CoreM (ScUsage, SeqCoreCommand)
    go env [] doneBinds usages = do
      (cutUsage, val', cont') <- specInCut env val cont
      let cmd = Command { cmdLet   = reverse doneBinds
                        , cmdValue = val'
                        , cmdCont  = cont' }
      return (mconcat (cutUsage : usages), cmd)
    go env (b : bs) doneBinds usages = do
      (bindUsage, env', b') <- specBind env b
      go env' bs (b' : doneBinds) (bindUsage : usages)
-- | Specialise a cut (a value together with its continuation).  The usage
-- contributed by the cut itself ('usageFromCut') comes first.  It is
-- followed by the usages found inside the value and then the continuation.
specInCut :: ScEnv -> SeqCoreValue -> SeqCoreCont
          -> CoreM (ScUsage, SeqCoreValue, SeqCoreCont)
specInCut env v k = do
  (valUsage, v') <- specInValue env v
  (contUsage, k') <- specInCont env k
  let cutUsage = usageFromCut env v k
  return (cutUsage <> valUsage <> contUsage, v', k')
-- | The usage contributed directly by a cut:
--
-- * a 'SpecArg' variable cut against a @case@ continuation is marked used;
--
-- * a saturated call of a 'SpecFun' variable records the call's arguments;
--
-- * anything else contributes nothing.
usageFromCut :: ScEnv -> SeqCoreValue -> SeqCoreCont -> ScUsage
usageFromCut env v k = case v of
    Var x
      | Case {} <- k
      , Just SpecArg <- howBound x
      -> ScUsage emptyVarEnv (unitVarSet x)
      | Just SpecFun <- howBound x
      , Just (args, _) <- asSaturatedCall v k
      -> ScUsage (unitVarEnv x [args]) emptyVarSet
    _ -> emptyScUsage
  where
    howBound var = lookupVarEnv (sc_how_bound env) var
-- | Specialise a binding.  Returns its usage, the environment for code in
-- its scope, and the rewritten binding.
--
-- A non-recursive binding leaves the environment unchanged.  For a
-- recursive group:
--
-- * every binder is classified as 'SpecFun';
--
-- * the right-hand sides are processed in that extended environment;
--
-- * each binding is offered to 'specialize' with the usage combined over
--   the whole group.
--
-- The resulting specialised copies are placed in the same 'Rec'.
specBind :: ScEnv -> SeqCoreBind -> CoreM (ScUsage, ScEnv, SeqCoreBind)
specBind env (NonRec x v) = do
  (u, v') <- specInValue env v
  return (u, env, NonRec x v')
specBind env (Rec pairs) = do
  processed <- forM pairs $ \(x, v) -> do
    (u, v') <- specInValue recEnv v
    return (u, (x, v'))
  let groupUsage = mconcat (map fst processed)
  bindss <- mapM (specialize recEnv groupUsage . snd) processed
  return (groupUsage, recEnv, Rec (concat bindss))
  where
    recEnv   = env { sc_how_bound = recBound `plusVarEnv` sc_how_bound env }
    recBound = mkVarEnv [ (x, SpecFun) | (x, _) <- pairs ]
-- | A call pattern: fresh pattern variables, together with the argument
-- values (written in terms of those variables) that a call is matched
-- against.
data CallPat = [Var] :-> [SeqCoreValue]

instance Outputable CallPat where
  ppr (patVars :-> patVals) = ppr patVars <+> text ":->" <+> ppr patVals
-- | One specialisation of a function for a particular call pattern.
data Spec = Spec
  { spec_pat  :: CallPat       -- ^ The pattern this copy is specialised for.
  , spec_id   :: Id            -- ^ Binder of the specialised copy.
  , spec_defn :: SeqCoreValue  -- ^ Definition of the specialised copy.
  }

instance Outputable Spec where
  ppr (Spec { spec_pat = pat, spec_id = specId, spec_defn = defn }) =
    sep [ text "specialization for" <+> parens (ppr pat)
        , text "id" <+> ppr specId
        , text "defn" <+> ppr defn
        ]
-- | The binding pair that defines a specialisation's copy.
specToBinding :: Spec -> (Var, SeqCoreValue)
specToBinding spec = (spec_id spec, spec_defn spec)
-- | Offer one binding of a recursive group for specialisation.
--
-- The result is the original binding first, its binder carrying the new
-- rewrite rules, followed by one binding per specialised copy.  The binding
-- is returned untouched when either:
--
-- * the value has no leading lambdas, or
--
-- * a size threshold is set and 'couldBeSmallEnoughToInline' rejects the
--   body.
--
-- A recorded call is used only if some argument is a constructor
-- application passed for a parameter in the scrutinised set.  For each such
-- call:
--
-- * the call is turned into a pattern over fresh variables;
--
-- * patterns judged equal by 'samePat' are merged;
--
-- * each remaining pattern yields a copy of the function with the original
--   parameters let-bound to the pattern's values;
--
-- * a rule rewriting matching calls to that copy is attached to the
--   original binder.
specialize :: ScEnv -> ScUsage -> (Var, SeqCoreValue)
           -> CoreM [(Var, SeqCoreValue)]
specialize env (ScUsage calls used) (x, v)
  | noSpecialisation = return [(x, v)]
  | otherwise = do
      specs <- mkSpecs
      return ((addRulesForSpecs specs, v) : map specToBinding specs)
  where
    noSpecialisation :: Bool
    noSpecialisation = null binders || maybe False tooBig (sc_size env)

    tooBig :: Int -> Bool
    tooBig limit =
      not (couldBeSmallEnoughToInline (sc_dflags env) limit
                                      (commandToCoreExpr body))

    binders :: [Var]
    body :: SeqCoreCommand
    (binders, body) = collectLambdas v

    mkSpecs :: CoreM [Spec]
    mkSpecs = case lookupVarEnv calls x of
      Nothing -> return []
      Just cs -> do
        pats <- mapM callToPat (filter shouldSpec cs)
        mapM specCall (nubBy samePat pats)

    -- True when some constructor-application argument is passed for a
    -- parameter that appears in the scrutinised set.
    shouldSpec :: Call -> Bool
    shouldSpec args = or (zipWith qualifies binders args)
      where
        qualifies param (Cons {}) = param `elemVarSet` used
        qualifies _     _         = False

    -- Build the specialised copy: abstract over the pattern variables and
    -- let-bind the original parameters to the pattern's values.
    specCall :: CallPat -> CoreM Spec
    specCall pat@(patVars :-> patVals) = do
      let defn = lambdas patVars (addLets (zipWith NonRec binders patVals) body)
      specId <- mkSysLocalM (fsLit "scsc") (valueType defn)
      return (Spec { spec_pat = pat, spec_id = specId, spec_defn = defn })

    callToPat :: Call -> CoreM CallPat
    callToPat args = do
      subpats <- zipWithM argToSubpat binders args
      let (varss, vals) = unzip subpats
      return (concat varss :-> vals)

    -- A constructor application keeps its constructor and its leading
    -- erased arguments.  Each remaining argument is replaced by a fresh
    -- variable.  Any other argument becomes a single fresh variable with
    -- the parameter's type.
    -- NOTE(review): if 'binders' can contain type variables, the fallback
    -- case makes a term-level local whose type is the binder's kind.
    -- Confirm that type parameters are handled correctly here.
    argToSubpat :: Var -> SeqCoreValue -> CoreM ([Var], SeqCoreValue)
    argToSubpat _ (Cons ctor args) = do
      let (tyArgs, tmArgs) = span isErasedValue args
      tmVars <- mapM (mkSysLocalM (fsLit "scsca") . valueType) tmArgs
      return (tmVars, Cons ctor (tyArgs ++ map Var tmVars))
    argToSubpat param _ = do
      p <- mkSysLocalM (fsLit "scsca") (varType param)
      return ([p], Var p)

    addRulesForSpecs :: [Spec] -> Var
    addRulesForSpecs specs =
      addIdSpecialisations x (zipWith ruleForSpec specs [1 ..])

    -- The n-th rule, named "SC:<function name><n>".  It rewrites a call
    -- matching the pattern into the copy applied to the pattern variables.
    ruleForSpec :: Spec -> Int -> CoreRule
    ruleForSpec (Spec { spec_pat = patVars :-> patArgs, spec_id = specId }) n =
      mkRule True {- auto -} True {- local -} ruleName
             (idInlineActivation x) (idName x)
             patVars (map valueToCoreExpr patArgs) ruleRhs
      where
        ruleName = mkFastString ("SC:" ++ occNameString (nameOccName (idName x))
                                       ++ show n)
        ruleRhs  = commandToCoreExpr $
          Command [] (Var specId) (foldr (\p k -> App (Var p) k) Return patVars)
infix 4 `samePat`

-- | Whether two call patterns are the same up to a renaming of their
-- pattern variables.
--
-- Patterns with different numbers of pattern variables are never the same.
-- The count is compared first because 'rnBndrs2' pairs the two binder
-- lists element by element (via GHC's @foldl2@), which panics on lists of
-- unequal length.  Such pairs really occur.  For example, @f (Just a) y@
-- and @f (T a b) y@ give patterns with two and three variables.
samePat :: CallPat -> CallPat -> Bool
samePat (xs1 :-> cs1) (xs2 :-> cs2)
  = length xs1 == length xs2
    && aeqIn env cs1 cs2
  where
    -- Only forced when the lengths agree, thanks to the lazy (&&).
    env = rnBndrs2 (mkRnEnv2 emptyInScopeSet) xs1 xs2