module Language.SequentCore.SpecConstr (
plugin
) where
import Language.SequentCore.Plugin
import Language.SequentCore.Pretty ()
import Language.SequentCore.Syntax
import Language.SequentCore.Translate
import CoreMonad ( CoreM
, Plugin(installCoreToDos), defaultPlugin
, errorMsg, putMsg
, reinitializeGlobals
, CoreToDo(CoreDoSpecConstr, CoreDoPasses, CoreDoPluginPass) )
import CoreUnfold ( couldBeSmallEnoughToInline )
import CoreSyn ( CoreRule )
import DynFlags ( DynFlags(specConstrThreshold), getDynFlags )
import FastString ( fsLit, mkFastString )
import Id ( Id, mkSysLocalM, idName, idInlineActivation )
import Name ( nameOccName, occNameString )
import Outputable hiding ((<>))
import Rules ( mkRule, addIdSpecialisations )
import Var ( Var, varType )
import VarEnv
import VarSet
import Control.Applicative ( (<$>), (<|>) )
import Control.Monad
import Data.List ( nubBy )
import Data.Monoid
-- | Debug switch. When 'True', 'specialize' prints each binding it visits
-- and whether it was skipped or processed.
tracing :: Bool
tracing = False
-- | Plugin entry point. It replaces GHC's built-in SpecConstr pass in the
-- Core-to-Core pipeline with a plugin pass ("SeqSpecConstr") that runs
-- 'specModule' via 'sequentPass'. If no SpecConstr pass is present, an error
-- message is emitted and the pipeline is left unchanged.
plugin :: Plugin
plugin = defaultPlugin { installCoreToDos = install }
  where
    install _ todos = do
        reinitializeGlobals
        maybe giveUp return (replace todos)
      where
        giveUp = do
          errorMsg (text "Could not find SpecConstr pass to replace")
          return todos

    -- Replace the first CoreDoSpecConstr found. Inside a CoreDoPasses group,
    -- the group's own passes are searched before the passes that follow it.
    -- Returns Nothing if there is no SpecConstr pass anywhere.
    replace [] = Nothing
    replace (CoreDoSpecConstr : rest) = Just (specConstrPass : rest)
    replace (grp@(CoreDoPasses inner) : rest)
      =   (\inner' -> CoreDoPasses inner' : rest) <$> replace inner
      <|> (grp :) <$> replace rest
    replace (other : rest) = (other :) <$> replace rest

    specConstrPass
      = CoreDoPluginPass "SeqSpecConstr" (sequentPass specModule)
-- | Environment threaded through the specialisation traversal.
data ScEnv
  = SCE { sc_size      :: Maybe Int
          -- ^ Size threshold for specialisation; 'Nothing' disables the size check.
        , sc_how_bound :: VarEnv HowBound
          -- ^ How each tracked variable was bound (recursive function or lambda argument).
        , sc_dflags    :: DynFlags
          -- ^ Flags of the module being compiled.
        }

instance Outputable ScEnv where
  ppr env = sep [ hang (text "SCE {") 2 fields, char '}' ]
    where
      sizeDoc = maybe (text "(any)") int (sc_size env)
      fields  = sep [ text "sc_size" <+> equals <+> sizeDoc <+> comma
                    , text "sc_how_bound" <+> equals <+> ppr (sc_how_bound env)
                    ]
-- | Information collected while traversing a piece of code.
data ScUsage
  = ScUsage Calls     -- ^ saturated calls to variables bound as 'SpecFun'
            ArgUsage  -- ^ 'SpecArg' variables that were scrutinised by a case

instance Outputable ScUsage where
  ppr (ScUsage cs us) = hang (text "ScUsage") 2 (sep [ppr cs, ppr us])

-- | For each function, the argument lists of the calls made to it.
type Calls = VarEnv [Call]

-- | The argument terms of one saturated call.
type Call = [SeqCoreTerm]
-- | How a tracked variable was bound: 'SpecFun' for binders of a recursive
-- group, 'SpecArg' for lambda binders.
data HowBound = SpecFun | SpecArg

instance Outputable HowBound where
  ppr hb = text $ case hb of
    SpecFun -> "SpecFun"
    SpecArg -> "SpecArg"

-- | The set of 'SpecArg' variables seen as the scrutinee of a case.
type ArgUsage = VarSet
-- | Run the specialiser over every top-level binding of a module. All
-- bindings share an environment built from the current 'DynFlags'; the
-- usage information produced for each top-level binding is discarded.
specModule :: [SeqCoreBind] -> CoreM [SeqCoreBind]
specModule binds
  = do
      dflags <- getDynFlags
      let env = initScEnv dflags
      results <- mapM (specInBind env) binds
      return (map snd results)
-- | The starting environment. No variables are tracked yet, and the size
-- threshold is taken from 'specConstrThreshold' of the given flags.
initScEnv :: DynFlags -> ScEnv
initScEnv dflags
  = SCE { sc_dflags    = dflags
        , sc_size      = specConstrThreshold dflags
        , sc_how_bound = emptyVarEnv
        }
-- | Usage that records no calls and no scrutinised arguments.
emptyScUsage :: ScUsage
emptyScUsage = ScUsage emptyVarEnv emptyVarSet

-- | Usages combine pointwise. Call lists for the same function are
-- concatenated (left operand's calls first), and argument sets are unioned.
instance Monoid ScUsage where
  mempty = emptyScUsage
  mappend (ScUsage c1 u1) (ScUsage c2 u2)
    = ScUsage (plusVarEnv_C (++) c1 c2) (unionVarSet u1 u2)
-- | Traverse a term. Recursive bindings found inside it are specialised, and
-- the calls and scrutinised arguments it contains are collected.
--
-- * Lambda binders are recorded as 'SpecArg', so that scrutinising them in
--   the body is noticed.
-- * Computations are traversed through their command.
-- * Constructor applications are traversed argument by argument.
-- * Any other term is returned unchanged with empty usage.
specInTerm :: ScEnv -> SeqCoreTerm -> CoreM (ScUsage, SeqCoreTerm)
specInTerm env (Lam xs kb c)
  = do
      (usage, c') <- specInCommand env' c
      return (usage, Lam xs kb c')
  where
    env' = env { sc_how_bound = extendVarEnvList hb (zip xs (repeat SpecArg)) }
    hb = sc_how_bound env
specInTerm env (Compute kb c)
  = do
      (usage, c') <- specInCommand env c
      return (usage, Compute kb c')
specInTerm env (Cons ctor args)
  = do
      -- Constructor arguments are terms in their own right and may contain
      -- computations, lambdas, calls or letrecs, so they must be traversed
      -- too; otherwise those calls are never recorded or specialised.
      (usages, args') <- mapAndUnzipM (specInTerm env) args
      return (mconcat usages, Cons ctor args')
specInTerm _ v
  = return (emptyScUsage, v)
-- | Traverse a continuation.
--
-- * Argument frames: the argument term and the rest of the continuation are
--   both traversed.
-- * Case frames: every alternative is traversed.
-- * Cast and tick frames: the wrapped continuation is traversed.
-- * Anything else (e.g. a return) is left alone.
specInCont :: ScEnv -> SeqCoreCont -> CoreM (ScUsage, SeqCoreCont)
specInCont env cont
  = case cont of
      App arg rest -> do
        (argUsage, arg') <- specInTerm env arg
        (restUsage, rest') <- specInCont env rest
        return (argUsage <> restUsage, App arg' rest')
      Case bndr alts -> do
        (altUsages, alts') <- mapAndUnzipM (specInAlt env) alts
        return (mconcat altUsages, Case bndr alts')
      Cast co rest -> under (Cast co) rest
      Tick ti rest -> under (Tick ti) rest
      _ -> return (emptyScUsage, cont)
  where
    -- Traverse the inner continuation and rebuild the wrapper around it.
    under wrapper inner
      = do
          (usage, inner') <- specInCont env inner
          return (usage, wrapper inner')
-- | Traverse the right-hand side of a case alternative. The constructor and
-- binders are kept as they are.
specInAlt :: ScEnv -> SeqCoreAlt -> CoreM (ScUsage, SeqCoreAlt)
specInAlt env (Alt con bndrs rhs)
  = rebuild <$> specInCommand env rhs
  where
    rebuild (usage, rhs') = (usage, Alt con bndrs rhs')
-- | 'specBind' without the updated environment, for callers that do not
-- thread it onward.
specInBind :: ScEnv -> SeqCoreBind -> CoreM (ScUsage, SeqCoreBind)
specInBind env bind
  = dropEnv <$> specBind env bind
  where
    dropEnv (usage, _, bind') = (usage, bind')
-- | Traverse a command. The let-bindings are processed left to right, each
-- one seeing the environment produced by the bindings before it. The final
-- cut is then processed in the resulting environment. Usages are combined
-- with the cut's usage first, followed by the bindings' usages in reverse
-- order.
specInCommand :: ScEnv -> SeqCoreCommand -> CoreM (ScUsage, SeqCoreCommand)
specInCommand env (Command { cmdLet = binds, cmdTerm = term, cmdCont = cont })
  = do
      (bodyEnv, revBinds, revUsages) <- foldM step (env, [], []) binds
      (cutUsage, term', cont') <- specInCut bodyEnv term cont
      return ( mconcat (cutUsage : revUsages)
             , Command { cmdLet  = reverse revBinds
                       , cmdTerm = term'
                       , cmdCont = cont' } )
  where
    -- Specialise one binding; accumulate its result and usage in reverse.
    step (curEnv, doneBinds, doneUsages) bind
      = do
          (usage, nextEnv, bind') <- specBind curEnv bind
          return (nextEnv, bind' : doneBinds, usage : doneUsages)
-- | Process a cut, i.e. a term paired with its continuation. The usage
-- represented by the cut itself is combined with the usages of the
-- traversed term and continuation.
specInCut :: ScEnv -> SeqCoreTerm -> SeqCoreCont
          -> CoreM (ScUsage, SeqCoreTerm, SeqCoreCont)
specInCut env v k
  = do
      (termUsage, v') <- specInTerm env v
      (contUsage, k') <- specInCont env k
      let cutUsage = usageFromCut env v k
      return (cutUsage <> termUsage <> contUsage, v', k')
-- | The usage a single cut represents on its own:
--
-- * a 'SpecArg' variable cut against a case marks that argument as
--   scrutinised;
-- * a saturated call to a 'SpecFun' variable records its arguments;
-- * anything else contributes nothing.
usageFromCut :: ScEnv -> SeqCoreTerm -> SeqCoreCont -> ScUsage
usageFromCut env term cont
  | Var x <- term
  , Case {} <- cont
  , Just SpecArg <- howBound x
  = ScUsage emptyVarEnv (unitVarSet x)
  | Var f <- term
  , Just SpecFun <- howBound f
  , Just (args, _) <- asSaturatedCall term cont
  = ScUsage (unitVarEnv f [args]) emptyVarSet
  | otherwise
  = emptyScUsage
  where
    howBound = lookupVarEnv (sc_how_bound env)
-- | Traverse a binding and return its usage, the environment for code in
-- its scope, and the rewritten binding.
--
-- A non-recursive binding leaves the environment unchanged. For a recursive
-- group:
--
-- 1. Every binder is marked 'SpecFun'.
-- 2. All right-hand sides are traversed in that extended environment.
-- 3. Each binding is specialised against the combined usage of the whole
--    group, which may add extra bindings to the group.
specBind :: ScEnv -> SeqCoreBind -> CoreM (ScUsage, ScEnv, SeqCoreBind)
specBind env (NonRec x v)
  = do
      (u, v') <- specInTerm env v
      return (u, env, NonRec x v')
specBind env (Rec pairs)
  = do
      let recFuns = mkVarEnv [ (f, SpecFun) | (f, _) <- pairs ]
          recEnv  = env { sc_how_bound = recFuns `plusVarEnv` sc_how_bound env }
      results <- forM pairs (specInTerm recEnv . snd)
      let groupUsage = mconcat (map fst results)
          pairs'     = zip (map fst pairs) (map snd results)
      groups <- mapM (specialize recEnv groupUsage) pairs'
      return (groupUsage, recEnv, Rec (concat groups))
-- | A call pattern: the fresh pattern variables on the left, and one argument
-- pattern per parameter on the right.
data CallPat = [Var] :-> [SeqCoreTerm]

instance Outputable CallPat where
  ppr (vars :-> args) = hsep [ppr vars, text ":->", ppr args]
-- | One specialisation of a function: the call pattern it serves, the id of
-- the specialised copy, and that copy's definition.
data Spec = Spec
  { spec_pat  :: CallPat
  , spec_id   :: Id
  , spec_defn :: SeqCoreTerm
  }

instance Outputable Spec where
  ppr (Spec { spec_pat = pat, spec_id = ident, spec_defn = defn })
    = sep [ text "specialization for" <+> parens (ppr pat)
          , text "id" <+> ppr ident
          , text "defn" <+> ppr defn
          ]
-- | The binding (specialised id, definition) introduced by a specialisation.
specToBinding :: Spec -> (Var, SeqCoreTerm)
specToBinding spec = (spec_id spec, spec_defn spec)
-- | Specialise one binding @(x, v)@ of a recursive group, given the combined
-- usage of the group.
--
-- Which calls are used: every recorded call to @x@ that passes a constructor
-- application ('Cons') in a position whose lambda binder is in @used@ (the
-- scrutinised arguments).
--
-- What is produced for each distinct pattern (compared with 'samePat'):
--
-- * a specialised copy of @x@'s body, bound to a fresh id;
-- * a rewrite rule attached to @x@ that redirects matching calls to that copy.
--
-- Result: the original binding with rules attached, followed by the
-- specialised bindings. The binding is returned unchanged when @v@ is not a
-- lambda, or when a size threshold is set and the body fails
-- 'couldBeSmallEnoughToInline'.
specialize :: ScEnv -> ScUsage -> (Var, SeqCoreTerm)
           -> CoreM [(Var, SeqCoreTerm)]
specialize env (ScUsage calls used) (x, v)
  -- Tracing hook: pprTrace prints, then yields False, so this branch is never
  -- taken and 'undefined' is never evaluated.
  | tracing
  , pprTrace "specialize" (ppr x <+> ppr v) False
  = undefined
  | skip
  = do
      when tracing $ putMsg $ text "specialize: skipping" <+> ppr x
      return [(x, v)]
  | otherwise
  = do
      when tracing $ putMsg $ text "specialize: PROCESSING" <+> ppr x
      specs <- mkSpecs
      let x' = addRulesForSpecs specs
      return $ (x', v) : map specToBinding specs
  where
    -- Skip when v has no lambda binders, or when a size threshold is set and
    -- the body (converted to Core) is not small enough.
    skip :: Bool
    skip | null binders
         = True
         | Just sz <- sc_size env
         , let coreExpr = commandToCoreExpr retId body
         , not $ couldBeSmallEnoughToInline (sc_dflags env) sz coreExpr
         = True
         | otherwise
         = False

    -- Lambda binders, continuation binder and body of v. For a non-lambda,
    -- retId and body are 'undefined'. They are never forced in that case,
    -- because 'skip' returns True on 'null binders' and the non-skip path is
    -- not taken.
    binders :: [Var]
    retId :: ContId
    body :: SeqCoreCommand
    (binders, retId, body)
      | Lam xs k body <- v = (xs, k, body)
      | otherwise          = ([], undefined, undefined)

    -- One Spec per distinct qualifying call pattern to x, or none if no calls
    -- to x were recorded.
    mkSpecs :: CoreM [Spec]
    mkSpecs
      | Just cs <- calls `lookupVarEnv` x
      = do
          pats <- mapM callToPat (filter shouldSpec cs)
          mapM specCall (nubBy samePat pats)
      | otherwise
      = return []

    -- A call qualifies if some argument is a constructor application whose
    -- corresponding binder is in the scrutinised set 'used'.
    shouldSpec :: Call -> Bool
    shouldSpec args
      = or $ zipWith qualifyingArg binders args
      where
        qualifyingArg x' (Cons {})
          = x' `elemVarSet` used
        qualifyingArg _ _
          = False

    -- Build the specialised copy. It abstracts over the pattern variables and
    -- the original continuation binder, and let-binds each original binder to
    -- its pattern value around the original body.
    specCall :: CallPat -> CoreM Spec
    specCall pat@(vars :-> vals)
      = do
          let v' = Lam vars retId $
                     addLets (zipWith NonRec binders vals) body
          x' <- mkSysLocalM (fsLit "scsc") (termType v')
          return $ Spec { spec_pat = pat, spec_id = x', spec_defn = v' }

    -- Turn a call's arguments (paired with binders; extras are dropped by
    -- zipWithM) into a pattern. The result has all fresh variables on the
    -- left and one argument pattern per binder on the right.
    callToPat :: Call -> CoreM CallPat
    callToPat args
      = do
          (varss, rhss) <- unzip `liftM` zipWithM argToSubpat binders args
          return $ concat varss :-> rhss

    -- Constructor argument: leading arguments satisfying 'isErasedTerm'
    -- (presumably type arguments) are kept verbatim, and each remaining
    -- argument becomes a fresh variable. Any other argument becomes a single
    -- fresh variable typed like its binder.
    -- NOTE(review): kept type arguments may mention type variables that are
    -- not among the rule binders, and a binder that is a type variable would
    -- get an Id of type 'varType var'. Confirm neither case can arise.
    argToSubpat :: Var -> SeqCoreTerm -> CoreM ([Var], SeqCoreTerm)
    argToSubpat _ (Cons ctor args)
      = do
          let (tyArgs, tmArgs) = span isErasedTerm args
          tmVars <- mapM (mkSysLocalM (fsLit "scsca") . termType) tmArgs
          let val = Cons ctor $ tyArgs ++ map Var tmVars
          return (tmVars, val)
    argToSubpat var _
      = do
          p <- mkSysLocalM (fsLit "scsca") (varType var)
          return ([p], Var p)

    -- Attach one rule per specialisation to x, numbered from 1.
    addRulesForSpecs :: [Spec] -> Var
    addRulesForSpecs specs
      = addIdSpecialisations x (zipWith ruleForSpec specs [1..])

    -- Rule "SC:<name of x><n>", using x's inline activation. The left-hand
    -- side is x applied to the pattern arguments; the right-hand side calls
    -- the specialised id with the pattern variables and returns to retId.
    ruleForSpec :: Spec -> Int -> CoreRule
    ruleForSpec (Spec { spec_pat = patVars :-> patArgs, spec_id = x' }) n
      = mkRule auto local name act fn bndrs args rhs
      where
        auto = True
        local = True
        name = mkFastString $ "SC:" ++ occNameString (nameOccName (idName x))
                                    ++ show n
        act = idInlineActivation x
        fn = idName x
        bndrs = patVars
        args = map termToCoreExpr patArgs
        -- The lambda's 'x' below is a pattern variable and shadows the outer
        -- function id 'x'.
        rhs = commandToCoreExpr retId $
                Command [] (Var x') (
                  foldr (\x k -> App (Var x) k) (Return retId) patVars)
infix 4 `samePat`

-- | Are two call patterns equal up to renaming of their pattern variables?
--
-- 'rnBndrs2' pairs the two binder lists element-wise and panics (via
-- 'Util.foldl2') when they differ in length. That happens routinely: a
-- constructor argument contributes one variable per term argument, so the
-- patterns for @f (Just x)@ and @f Nothing@ have different numbers of
-- variables. Patterns of different shape can never be alpha-equivalent, so
-- the lengths are compared first. Because '&&' short-circuits, the renaming
-- environment is only built when the lengths agree.
samePat :: CallPat -> CallPat -> Bool
(xs1 :-> cs1) `samePat` (xs2 :-> cs2)
  = length xs1 == length xs2
    && length cs1 == length cs2
    && aeqIn env cs1 cs2
  where
    env = rnBndrs2 (mkRnEnv2 emptyInScopeSet) xs1 xs2