module HERMIT.Optimization.SYB where
import qualified Outputable (showSDocDebug)
import Control.Arrow
import Control.Monad
import Data.Generics (gshow)
import Data.List (intercalate, intersect, partition)
import GHC.Fingerprint (Fingerprint(..), fingerprintFingerprints)
import qualified Data.Map as Map
import HERMIT.Context
import HERMIT.Core
import HERMIT.Dictionary
import HERMIT.External
import HERMIT.GHC hiding (display)
import HERMIT.Kure
import HERMIT.Monad
import HERMIT.Optimize
import HERMIT.Plugin
import HERMIT.PrettyPrinter.Common
import HERMIT.Shell.Types
import qualified Language.Haskell.TH as TH
-- | Plugin entry point, built with 'optimize' from the plugin's option list.
--
-- Options equal to @"interactive"@ or @"interactive-only"@ are treated as
-- mode flags; every other option is treated as the name of a binding to
-- optimise.
--
-- NOTE(review): the original indentation of this block is not visible here,
-- so the nesting of the final 'unless' (inside or outside @phase 0@) cannot
-- be confirmed from this view -- verify against the formatted file.
plugin :: Plugin
plugin = optimize $ \ opts -> do
-- 'partition' puts the mode flags in opts' and the remaining options in targets.
let (opts', targets) = partition (`elem` ["interactive", "interactive-only"]) opts
phase 0 $ do
-- The automatic optimisation is skipped when "interactive-only" was given.
unless ("interactive-only" `elem` opts') $ do
forM_ targets $ \ t -> do
liftIO $ putStrLn $ "optimizing: " ++ t
-- Focus on the right-hand side of the binding whose name matches t.
at (promoteT $ rhsOfT $ cmpTHName2Var $ TH.mkName t) $ do
-- Repeat optSYB until it fails, then try to substitute away trivial letrec
-- bindings and try a final simplifyR; both clean-up steps are wrapped in tryR.
run $ promoteR $ repeatR optSYB >>> tryR (innermostR $ promoteExprR letrecSubstTrivialR) >>> tryR simplifyR
-- If either mode flag was given, start the interactive shell with the
-- externals from 'exts'.
unless (null opts') $ after Simplify $ interactive exts []
-- | Simplification pass: at any node (top-down, via 'anytdR') repeatedly
-- apply the first expression simplification that succeeds, emitting a
-- @SIMPLIFYING@ trace after each successful step.
optSimp :: RewriteH Core
optSimp = anytdR (repeatR simplifyStep)
  where
    -- One traced simplification step on a Core node.
    simplifyStep :: RewriteH Core
    simplifyStep = promoteExprR exprSimp >>> traceR "SIMPLIFYING"

    -- The candidate expression rewrites, tried in this order.
    exprSimp :: RewriteH CoreExpr
    exprSimp =
           rule "append"
        <+ rule "[]++"
        <+ castElimReflR
        <+ castElimSymPlusR
        <+ letElimR
        <+ letSubstType (TH.mkName "*")
        <+ letSubstType (TH.mkName "BOX")
        <+ evalFingerprintFingerprints
        <+ eqWordElim
        <+ letSubstTrivialR
        <+ caseReduceR
-- | Floating pass: at any node (bottom-up, via 'anybuR') apply the first
-- memo-let floating rewrite that succeeds, emitting a @FLOATING@ trace after
-- a successful float.
optFloat :: RewriteH Core
optFloat = anybuR (promoteExprR tracedFloat)
  where
    -- A single float followed by the trace message.
    tracedFloat :: RewriteH CoreExpr
    tracedFloat = anyFloat >>> traceR "FLOATING"

    -- The candidate floating rewrites, tried in this order.
    anyFloat :: RewriteH CoreExpr
    anyFloat =
           memoFloatMemoLet
        <+ memoFloatMemoBind
        <+ memoFloatApp
        <+ memoFloatArg
        <+ memoFloatLam
        <+ memoFloatLet
        <+ memoFloatBind
        <+ memoFloatRecBind
        <+ memoFloatCase
        <+ memoFloatCast
        <+ memoFloatAlt
-- | Memoisation pass, driven by 'smarttdR'.  An expression is a candidate
-- when 'eliminatesType' succeeds for one of the listed type names; a
-- candidate is then memoised with 'memoize' or, if that fails, forced with
-- 'forcePrims' (followed by a @FORCING@ trace).
optMemo :: RewriteH Core
optMemo = smarttdR $ promoteExprR (touchesSybType >> memoOrForce)
  where
    -- Succeeds when the expression eliminates one of the named types.
    touchesSybType :: TranslateH CoreExpr ()
    touchesSybType = foldl1 (<+) (map (eliminatesType . TH.mkName) sybTypeNames)

    -- Type names tried by 'touchesSybType', in this order.
    sybTypeNames :: [String]
    sybTypeNames =
        [ "Data"
        , "Typeable"
        , "Typeable1"
        , "TypeRep"
        , "TyCon"
        , "ID"
        , "Qr"
        , "Fingerprint"
        ]

    -- Memoise if possible, otherwise force the expression.
    memoOrForce :: RewriteH CoreExpr
    memoOrForce =
           bracketR "MEMOIZING" memoize
        <+ (forcePrims (map TH.mkName ["fingerprintFingerprints", "eqWord#"]) >>> traceR "FORCING")
-- | One step of the overall optimisation.  Tries, in order: folding one
-- stashed (memoised) definition somewhere in the tree, 'optSimp',
-- 'optFloat', and finally 'optMemo'.
optSYB :: RewriteH Core
optSYB = useMemoized <+ optSimp <+ optFloat <+ optMemo
  where
    -- Fold a stashed definition at one location (top-down), with bracketed tracing.
    useMemoized :: RewriteH Core
    useMemoized =
        onetdR (promoteExprR (bracketR "!!!!! USED MEMOIZED BINDING !!!!!" stashFoldAnyR))
-- | Shell externals exposed by this plugin.  Every entry is additionally
-- tagged 'Experiment' and 'TODO'.
--
-- The help strings describe what the corresponding rewrite in this module
-- does; several of them used to be copy-pasted from a neighbouring entry and
-- described a different transformation.
exts :: [External]
exts = map ((.+ Experiment) . (.+ TODO))
    [ external "let-subst-trivial" (promoteExprR letSubstTrivialR :: RewriteH Core)
        [ "Let substitution (but only if e1 is a variable or literal)"
        , "let x = e1 in e2 ==> e2[e1/x]"
        , "x must not be free in e1." ] .+ Deep
    , external "letrec-subst-trivial" (promoteExprR letrecSubstTrivialR :: RewriteH Core)
        [ "Letrec substitution (but only for a binding whose e1 is a variable or literal)"
        , "letrec { x = e1; bs } in e2 ==> letrec { bs[e1/x] } in e2[e1/x]"
        , "x must not be free in e1." ] .+ Deep
    , external "let-subst-type" (promoteExprR . letSubstType :: TH.Name -> RewriteH Core)
        [ "Let substitution (but only if the type of e1 contains the given type)"
        , "(\"let-subst-type '*\" is especially useful to eliminate bindings for types)"
        , "(let x = e1 in e2) ==> (e2[e1/x])"
        , "x must not be free in e1." ] .+ Deep
    , external "eval-eqWord" (promoteExprR eqWordElim :: RewriteH Core)
        [ "eqWord# e1 e2 ==> True if e1 and e2 are literals and equal"
        , "eqWord# e1 e2 ==> False if e1 and e2 are literals and not equal" ] .+ Shallow .+ Eval
    , external "eval-fingerprintFingerprints" (promoteExprR evalFingerprintFingerprints :: RewriteH Core)
        [ "replaces 'fingerprintFingerprints [f1,f2,...]' with its value."
        , "Requires that f1,f2,... are literals." ] .+ Shallow .+ Eval
    , external "eliminates-type" (promoteExprT . eliminatesType :: TH.Name -> TranslateH Core ())
        [ "succeeds if the expression is an application whose argument, a cast whose body,"
        , "or a case whose scrutinee has a type that mentions the given type" ] .+ Shallow
    , external "smart-td" (smarttdR)
        [ "apply a rewrite to the first expression at which it succeeds, searching top-down,"
        , "failing if it succeeds nowhere" ]
    , external "force" (promoteExprR force :: RewriteH Core)
        [ "force" ] .+ Eval
    , external "force" (promoteExprR . forcePrims :: [TH.Name] -> RewriteH Core)
        [ "force [Name]" ] .+ Eval
    , external "memoize" (promoteExprR memoize :: RewriteH Core)
        [ "TODO" ] .+ Eval .+ Introduce
    , external "opt-syb" optSYB [ "TODO" ] .+ Eval .+ Introduce
    , external "opt-simp" optSimp [ "TODO" ] .+ Eval .+ Introduce
    , external "opt-float" optFloat [ "TODO" ] .+ Eval .+ Introduce
    , external "opt-memo" optMemo [ "TODO" ] .+ Eval .+ Introduce
    , external "memo-float-app" (promoteExprR memoFloatApp :: RewriteH Core)
        [ "(let v = ev in e) x ==> let v = ev in e x" ] .+ Commute .+ Shallow
    , external "memo-float-arg" (promoteExprR memoFloatArg :: RewriteH Core)
        [ "f (let v = ev in e) ==> let v = ev in f e" ] .+ Commute .+ Shallow
    , external "memo-float-lam" (promoteExprR memoFloatLam :: RewriteH Core)
        [ "(\\ v1 -> let v2 = e1 in e2) ==> let v2 = e1 in (\\ v1 -> e2), if v1 is not free in e2."
        , "If v1 = v2 then v1 will be alpha-renamed."
        ] .+ Commute .+ Shallow
    , external "memo-float-memo-let" (promoteExprR memoFloatMemoLet :: RewriteH Core)
        [ "let rec bs1 in (let rec bs2 in e) ==> let rec (bs1 ++ bs2) in e, if all binders are memoized" ] .+ Commute .+ Shallow
    , external "memo-float-memo-bind" (promoteExprR memoFloatMemoBind :: RewriteH Core)
        [ "let rec v = (let rec w = ew in ev) in e ==> let rec { v = ev; w = ew } in e" ] .+ Commute .+ Shallow
    , external "memo-float-let" (promoteExprR memoFloatLet :: RewriteH Core)
        [ "let v = ev in (let rec w = ew in e) ==> let rec w = ew in let v = ev in e" ] .+ Commute .+ Shallow
    , external "memo-float-bind" (promoteExprR memoFloatBind :: RewriteH Core)
        [ "let v = (let w = ew in ev) in e ==> let w = ew in let v = ev in e" ] .+ Commute .+ Shallow
    , external "memo-float-rec-bind" (promoteExprR memoFloatRecBind :: RewriteH Core)
        [ "let v = (let w = ew in ev) in e ==> let w = ew in let v = ev in e" ] .+ Commute .+ Shallow
    , external "memo-float-case" (promoteExprR memoFloatCase :: RewriteH Core)
        [ "case (let v = ev in e) of ... ==> let v = ev in case e of ..." ] .+ Commute .+ Shallow .+ Eval
    , external "memo-float-cast" (promoteExprR memoFloatCast :: RewriteH Core)
        [ "cast (let v = ev in e) co ==> let v = ev in cast e co" ] .+ Commute .+ Shallow .+ Eval
    , external "memo-float-alt" (promoteExprR memoFloatAlt :: RewriteH Core)
        [ "case s of { p -> let v = ev in e } ==> let v = ev in case s of { p -> e }" ] .+ Commute .+ Shallow .+ Eval
    ]
-- | Substitute away a non-recursive let whose right-hand side is a variable
-- or a literal: @let x = e1 in e2 ==> e2[e1/x]@.
--
-- Fails (with the prefix @"Let substitution failed: "@) on any other
-- expression.  The prefix previously read "substition", a spelling error.
letSubstTrivialR :: RewriteH CoreExpr
letSubstTrivialR = prefixFailMsg "Let substitution failed: " $
    contextfreeT $ \ expr -> case expr of
        Let (NonRec b be@(Var _)) e -> return $ substCoreExpr b be e
        Let (NonRec b be@(Lit _)) e -> return $ substCoreExpr b be e
        _ -> fail "expression is not a trivial, non-recursive Let."
-- | Remove one trivial binding from a recursive let.  The first binding whose
-- right-hand side is a variable or a literal (as found by 'findTrivial') is
-- dropped from the group and substituted into the remaining right-hand sides
-- and into the body.
--
-- Fails (with the prefix @"Letrec substitution failed: "@) when the
-- expression is not a recursive let or has no trivial binding.  The prefix
-- previously read "substition", a spelling error.
letrecSubstTrivialR :: RewriteH CoreExpr
letrecSubstTrivialR = prefixFailMsg "Letrec substitution failed: " $
    contextfreeT $ \ expr -> case expr of
        Let (Rec bs) e
            | Just (b, be, bs', e') <- findTrivial e [] bs ->
                let subst = substCoreExpr b be
                in return (Let (Rec [ (v, subst rhs) | (v, rhs) <- bs' ]) (subst e'))
        _ -> fail "expression is not a trivial, recursive Let."
-- | Find the first binding whose right-hand side is a variable or a literal.
--
-- On success returns the binder, its right-hand side, all the other bindings
-- (the accumulator followed by the rest, in their original order) and the
-- first argument unchanged.  Returns 'Nothing' when no binding is trivial.
findTrivial :: CoreExpr
            -> [(Id, CoreExpr)]
            -> [(Id, CoreExpr)]
            -> Maybe (Id, CoreExpr, [(Id, CoreExpr)], CoreExpr)
findTrivial e seen pending =
    case break (isTrivial . snd) pending of
        (_, [])                 -> Nothing
        (skipped, (b, be) : later) -> Just (b, be, seen ++ skipped ++ later, e)
  where
    isTrivial (Var _) = True
    isTrivial (Lit _) = True
    isTrivial _       = False
-- | Let substitution ('letSubstR') restricted to non-recursive lets whose
-- right-hand side has a type or kind mentioning the given name (see
-- 'inType').  Fails if the expression is not such a let.
letSubstType :: TH.Name -> RewriteH CoreExpr
letSubstType ty = prefixFailMsg ("letSubstType '" ++ TH.nameBase ty ++ " failed: ") $ do
    Let (NonRec _ rhs) _ <- idR
    dynFlags <- constT getDynFlags
    let rhsTy = exprKindOrType rhs
    guardMsg (ty `inType` rhsTy) $ " not found in " ++ showPpr dynFlags rhsTy ++ "."
    letSubstR
-- | Does a type constructor matching the given name occur anywhere in the
-- type?
--
-- Type variables never match.  Any 'Type' constructor not listed explicitly
-- is treated as not containing the name; previously such a type caused a
-- pattern-match failure at run time.
inType :: TH.Name -> Type -> Bool
inType name = go
  where
    go (TyVarTy _)          = False
    go (AppTy t1 t2)        = go t1 || go t2
    go (TyConApp ctor args) = name `cmpTHName2Name` tyConName ctor || any go args
    go (FunTy t1 t2)        = go t1 || go t2
    go (ForAllTy _ t)       = go t
    -- Catch-all keeps the function total across GHC versions.
    go _                    = False
-- | Evaluate a call to @GHC.Prim.eqWord#@ on two literal arguments, producing
-- the integer literal 1 when the literals are equal and 0 otherwise.
-- Fails unless the expression is such a call with exactly two literals.
eqWordElim :: RewriteH CoreExpr
eqWordElim = do
    (Var _, [Lit l1, Lit l2]) <- callNameT $ TH.mkName "GHC.Prim.eqWord#"
    let flag = if l1 == l2 then 1 else 0
    -- mkIntLitInt takes the DynFlags as an extra argument after GHC 7.6.
#if __GLASGOW_HASKELL__ <= 706
    return $ mkIntLitInt flag
#else
    dflags <- constT getDynFlags
    return $ mkIntLitInt dflags flag
#endif
-- | Look the name up among the bindings recorded in the context and render
-- the type (or kind) of the unique match.  Fails when no binding or more
-- than one binding matches.
varInfo2 :: TH.Name -> TranslateH Core String
varInfo2 nm = translate $ \ c _ ->
    let matches = filter (cmpTHName2Var nm) (Map.keys (hermitBindings c))
    in case matches of
        []  -> fail "cannot find name."
        [i] -> do
            dynFlags <- getDynFlags
            return ("Type or Kind: " ++ showPpr dynFlags (exprKindOrType (Var i)))
        is  -> fail $ "multiple names match: " ++ intercalate ", " (map var2String is)
-- | Debugging aid: on a variable expression, trace the debug rendering of
-- the variable and return the expression unchanged.  Fails on any other
-- expression.
varInfo :: RewriteH CoreExpr
varInfo = do
    Var v <- idR
    dynFlags <- constT getDynFlags
    let rendered = Outputable.showSDocDebug dynFlags (ppr v)
    trace rendered $ return (Var v)
-- | Evaluate @fingerprintFingerprints [f1, f2, ...]@ at compile time when
-- every element is a constructor applied to two word literals, replacing the
-- call by the resulting @Fingerprint@ built from two @W64#@ literals.
--
-- NOTE(review): the list and element patterns only check the shape of the
-- applications; the heads are not compared against the actual nil, cons or
-- Fingerprint identifiers -- confirm that this is intended.
evalFingerprintFingerprints :: RewriteH CoreExpr
evalFingerprintFingerprints = do
    _ <- constT getDynFlags
    App (Var fn) args <- idR
    target <- findIdT (TH.mkName "fingerprintFingerprints")
    ctor <- findIdT (TH.mkName "Fingerprint")
    w64 <- findIdT (TH.mkName "GHC.Word.W64#")
    guardMsg (fn == target) (var2String fn ++ " does not match " ++ "fingerprintFingerprints")
    fingerprints <- getFingerprints args
    let Fingerprint w1 w2 = fingerprintFingerprints fingerprints
        -- Box a word as @W64# <literal>@.
        word w = Var w64 `App` Lit (MachWord (toInteger w))
    return (Var ctor `App` word w1 `App` word w2)
  where
    -- Decode a Core list of literal fingerprints; equation order matters.
    getFingerprints (App (Var _) _) = return []
    getFingerprints (Var _ `App` _ `App` f `App` fs) = do
        fp <- literalFingerprint f
        rest <- getFingerprints fs
        return (fp : rest)
    getFingerprints a = fail ("Non-list as argument:" ++ gshow a)

    -- Decode one element: a head applied to two word literals.
    literalFingerprint (Var _ `App` (Lit (MachWord w1)) `App` (Lit (MachWord w2))) =
        return (Fingerprint (fromIntegral w1) (fromIntegral w2))
    literalFingerprint f = fail ("Non-literal fingerprint as argument:" ++ gshow f)
-- | Succeeds on a recursive let with exactly one binding whose binder,
-- rendered with 'showPpr', names a stashed definition ('lookupDef').
-- Fails with @"not a memoized let"@ on any other Core node.
isMemoizedLet :: TranslateH Core ()
isMemoizedLet = idR >>= \ core -> case core of
    ExprCore (Let (Rec [(v, _)]) _) -> do
        flags <- constT getDynFlags
        _ <- constT $ lookupDef (showPpr flags v)
        return ()
    _ -> fail "not a memoized let"
-- | Inline a variable ('inlineR'), but only when its type or kind mentions
-- the given name (see 'inType').  Fails on non-variable expressions.
inlineType :: TH.Name -> RewriteH CoreExpr
inlineType ty = prefixFailMsg ("inlineType '" ++ TH.nameBase ty ++ " failed: ") $
    withPatFailMsg (wrongExprForm "Var v") $ do
        Var v <- idR
        dynFlags <- constT getDynFlags
        let varTy = exprKindOrType (Var v)
        guardMsg (ty `inType` varTy) $ " not found in " ++ showPpr dynFlags varTy ++ "."
        inlineR
-- | Apply the rewrite at the first expression where it succeeds, searching
-- top-down.  At each expression the rewrite itself is tried first; if it
-- fails, the children are searched in a fixed order (for example the body of
-- a let before its bindings) and the first successful child wins.
--
-- Fails when focused on a non-expression Core node, and when the rewrite
-- succeeds nowhere.  The failure message for literals now says @Lit@; it
-- previously said @Let@.
smarttdR :: RewriteH Core -> RewriteH Core
smarttdR r = modFailMsg ("smarttdR failed: " ++) go
  where
    go = r <+ go'
    go' = do
        ExprCore e <- idR
        case e of
            Var _ -> fail "smarttdR Var"
            Lit _ -> fail "smarttdR Lit"
            App _ _ -> pathR [App_Fun] go <+ pathR [App_Arg] go
            Lam _ _ -> pathR [Lam_Body] go
            Let (NonRec _ _) _ -> pathR [Let_Body] go <+ pathR [Let_Bind, NonRec_RHS] go
            Let (Rec bs) _ ->
                pathR [Let_Body] go
                    <+ foldr (<+) (fail "smarttdR Let Rec")
                             [ pathR [Let_Bind, Rec_Def i, Def_RHS] go
                             | (i, _) <- zip [0..] bs ]
            Case _ _ _ as ->
                pathR [Case_Scrutinee] go
                    <+ foldr (<+) (fail "smarttdR Case")
                             [ pathR [Case_Alt i, Alt_RHS] go
                             | (i, _) <- zip [0..] as ]
            Cast _ _ -> pathR [Cast_Expr] go
            Tick _ _ -> pathR [Tick_Expr] go
            Type _ -> fail "smarttdR Type"
            Coercion _ -> fail "smarttdR Coercion"
-- | Succeeds when the expression is an application whose argument, a cast
-- whose body, or a case whose scrutinee has a type mentioning the given name
-- (checked with 'inTypeT').
eliminatesType :: TH.Name -> TranslateH CoreExpr ()
eliminatesType ty = do
    _ <- constT getDynFlags
    prefixFailMsg ("eliminatesType: " ++ show ty ++ " ") $
        viaAppArg <+ viaCastBody <+ viaScrutinee
  where
    viaAppArg, viaCastBody, viaScrutinee :: TranslateH CoreExpr ()
    viaAppArg    = appT successT (inTypeT ty) const
    viaCastBody  = castT (inTypeT ty) successT (flip const)
    viaScrutinee = caseT (inTypeT ty) successT successT (const successT) (\ () _ _ _ -> ())
-- | Succeeds when the type or kind of the expression mentions the given name
-- (see 'inType'); otherwise fails with a message that shows the type.
inTypeT :: TH.Name -> TranslateH CoreExpr ()
inTypeT ty = constT getDynFlags >>= \ dynFlags ->
    contextfreeT $ \ e ->
        let t = exprKindOrType e
        in guardMsg (ty `inType` t) (" not found in " ++ showPpr dynFlags t ++ ".")
-- | Force an expression without treating any function as a primitive:
-- a shallow 'forceDeep' with an empty primitive list.
force :: RewriteH CoreExpr
force = forceDeep False []
-- | Force an expression, treating the named functions as primitives whose
-- arguments are forced as well.  Starts 'forceDeep' in shallow mode.
forcePrims :: [TH.Name] -> RewriteH CoreExpr
forcePrims = forceDeep False
-- | Perform one forcing step on an expression, trying the alternatives below
-- in order and taking the first that succeeds.
--
-- The 'Bool' says whether application arguments may be forced too; it is
-- switched on for the arguments of a call to one of the named primitives.
-- Fails with @"forceDeep failed"@ when no alternative applies.
forceDeep :: Bool -> [TH.Name] -> RewriteH CoreExpr
forceDeep deep prims =
       inlineR
    <+ betaReduceR
    <+ letFloatAppR
    <+ castFloatAppR
    -- Call to a primitive: force its argument in deep mode.
    <+ (appT (isPrimCall' prims) successT const >> appAllR idR inPrimArgs)
    -- Otherwise force the function position.
    <+ appAllR self idR
    -- In deep mode the argument position may be forced as well.
    <+ whenM (return deep) (appAllR idR self)
    <+ caseAllR self idR idR (const idR)
    <+ caseReduceR
    <+ letFloatCaseR
    <+ letAllR idR self
    <+ castAllR self idR
    <+ fail "forceDeep failed"
  where
    -- Recursive step with unchanged settings.
    self = forceDeep deep prims
    -- Recursive step used underneath a primitive call.
    inPrimArgs = forceDeep True prims
-- | Translate wrapper around 'isPrimCall': succeeds when the expression is a
-- call to one of the named primitives, fails with @"not a prim call"@
-- otherwise.
isPrimCall' :: [TH.Name] -> TranslateH CoreExpr ()
isPrimCall' prims =
    idR >>= \ e -> unless (isPrimCall prims e) (fail "not a prim call")
-- | Is the head of this (possibly nested) application a variable matching
-- one of the given names?  Anything other than a variable or an application
-- yields 'False'.
isPrimCall :: [TH.Name] -> CoreExpr -> Bool
isPrimCall prims expr = case expr of
    Var v   -> any (`cmpTHName2Var` v) prims
    App f _ -> isPrimCall prims f
    _       -> False
-- | Memoise a call.  The call is forced with 'force', a fresh identifier
-- named @memo_<function>@ is created at the type of the original expression,
-- the original expression is stashed under that identifier's printed name,
-- and the result is @letrec memo = <forced> in memo@.
--
-- Fails when the expression is not a call headed by a variable or when
-- forcing fails.
memoize :: RewriteH CoreExpr
memoize = prefixFailMsg "memoize failed: " $ do
    original <- idR
    (Var fn, _) <- callT
    forced <- force
    memoVar <- constT $ newIdH ("memo_" ++ getOccString fn) (exprType original)
    flags <- constT getDynFlags
    -- The stash keeps the unforced expression, keyed by the new binder's name.
    constT $ saveDef (showPpr flags memoVar) (Def memoVar original)
    --cleanupUnfold
    return (Let (Rec [(memoVar, forced)]) (Var memoVar))
-- | Float a memoised let out of the function position of an application:
-- @(let v = ev in e) x ==> let v = ev in e x@.
--
-- Every binder of the let must name a stashed definition ('lookupDef').
-- Previously that lookup was constructed as the pure result of 'appT' and
-- then discarded, so it never ran; it is now executed with 'constT', as in
-- 'memoFloatMemoLet'.  If a let binder occurs free in the argument the let
-- is alpha-renamed first to avoid capture.
memoFloatApp :: RewriteH CoreExpr
memoFloatApp = prefixFailMsg "Let floating from App function failed: " $ do
    flags <- constT getDynFlags
    letVars <- appT letVarsT idR const
    constT $ mapM_ (lookupDef . showPpr flags) letVars
    vs <- appT letVarsT (arr $ varSetElems . freeVarsExpr) intersect
    let letAction = if null vs then idR else alphaLetR
    appT letAction idR $ \ (Let bnds e) x -> Let bnds $ App e x
-- | Float a memoised let out of the argument position of an application:
-- @f (let v = ev in e) ==> let v = ev in f e@.
--
-- Every binder of the let must name a stashed definition ('lookupDef').
-- Previously that lookup was constructed as the pure result of 'appT' and
-- then discarded, so it never ran; it is now executed with 'constT', as in
-- 'memoFloatMemoLet'.  If a let binder occurs free in the function the let
-- is alpha-renamed first to avoid capture.
memoFloatArg :: RewriteH CoreExpr
memoFloatArg = prefixFailMsg "Let floating from App argument failed: " $ do
    flags <- constT getDynFlags
    letVars <- appT idR letVarsT (flip const)
    constT $ mapM_ (lookupDef . showPpr flags) letVars
    vs <- appT (arr $ varSetElems . freeVarsExpr) letVarsT intersect
    let letAction = if null vs then idR else alphaLetR
    appT idR letAction $ \ f (Let bnds e) -> Let bnds $ App f e
-- | Merge two directly nested recursive lets into one group:
-- @let rec bs1 in (let rec bs2 in e) ==> let rec (bs1 ++ bs2) in e@.
-- Every binder of both groups must name a stashed definition.
memoFloatMemoLet :: RewriteH CoreExpr
memoFloatMemoLet = prefixFailMsg "Let floating from Let failed: " $ do
    Let (Rec outerBinds) (Let (Rec innerBinds) body) <- idR
    flags <- constT getDynFlags
    let requireMemo bs = constT (mapM_ (lookupDef . showPpr flags . fst) bs)
    requireMemo outerBinds
    requireMemo innerBinds
    return (Let (Rec (outerBinds ++ innerBinds)) body)
-- | Float memoised bindings out of the body of a let:
-- @let v = ev in (let rec ws in e) ==> let rec ws' in let v = ev in ...@.
--
-- Only inner bindings that do not depend (directly or transitively, as
-- decided by 'filterBinds'') on the outer binders are floated; the others
-- stay in place.  Fails with @"cannot float"@ when nothing can move.
memoFloatLet :: RewriteH CoreExpr
memoFloatLet = prefixFailMsg "Let floating from Let failed: " $ do
    Let outerBind (Let (Rec innerBinds) body) <- idR
    outerVars <- letVarsT
    flags <- constT getDynFlags
    constT $ mapM_ (lookupDef . showPpr flags . fst) innerBinds
    case filterBinds' outerVars innerBinds of
        (_, []) -> fail "cannot float"
        (stuck, movable) ->
            let inner | null stuck = body
                      | otherwise  = Let (Rec stuck) body
            in return (Let (Rec movable) (Let outerBind inner))
-- | Flatten memoised recursive lets found on the right-hand sides of a
-- memoised recursive let into the enclosing group:
-- @let rec v = (let rec w = ew in ev) in e ==> let rec { v = ev; w = ew } in e@.
--
-- Fails with @"cannot float"@ when the flattened group has the same number
-- of bindings as the original.
memoFloatMemoBind :: RewriteH CoreExpr
memoFloatMemoBind = prefixFailMsg "Let floating from Let failed: " $ do
    Let (Rec outerBinds) body <- idR
    flags <- constT getDynFlags
    let requireMemo bs = constT (mapM_ (lookupDef . showPpr flags . fst) bs)
        -- Bindings contributed by one outer binding: either the unwrapped
        -- binding plus the inner group, or the binding unchanged.
        flatten (x, e) = hoisted <+ return [(x, e)]
          where
            hoisted = do
                Let (Rec innerBinds) innerBody <- return e
                requireMemo innerBinds
                return ((x, innerBody) : innerBinds)
    requireMemo outerBinds
    flattened <- liftM concat (mapM flatten outerBinds)
    if length flattened == length outerBinds
        then fail "cannot float"
        else return (Let (Rec flattened) body)
-- | Float memoised recursive lets out of the right-hand sides of a memoised
-- recursive let, placing them in a new group around it.
--
-- Inner bindings that depend on a binder of the outer group (as decided by
-- 'filterBinds'') stay where they are.  Fails with @"cannot float"@ when no
-- binding can move.
memoFloatRecBind :: RewriteH CoreExpr
memoFloatRecBind = prefixFailMsg "Let floating from Let failed: " $ do
    Let (Rec outerBinds) body <- idR
    flags <- constT getDynFlags
    let requireMemo bs = constT (mapM_ (lookupDef . showPpr flags . fst) bs)
        outerVars = map fst outerBinds
        -- For one outer binding: the bindings to float and the binding left behind.
        split (x, e) = hoisted <+ return ([], (x, e))
          where
            hoisted = do
                Let (Rec innerBinds) innerBody <- return e
                requireMemo innerBinds
                let (stuck, movable) = filterBinds' outerVars innerBinds
                    rhs | null stuck = innerBody
                        | otherwise  = Let (Rec stuck) innerBody
                return (movable, (x, rhs))
    requireMemo outerBinds
    (floated, kept) <- liftM unzip (mapM split outerBinds)
    let outer = concat floated
    if null outer
        then fail "cannot float"
        else return (Let (Rec outer) (Let (Rec kept) body))
-- | Float memoised bindings out of the right-hand side of a non-recursive
-- let: @let v = (let rec ws in ev) in e ==> let rec ws' in let v = ... in e@.
--
-- Inner bindings that depend on @v@ (as decided by 'filterBinds'') stay in
-- the right-hand side.  Fails with @"cannot float"@ when nothing can move.
memoFloatBind :: RewriteH CoreExpr
memoFloatBind = prefixFailMsg "Let floating from Let failed: " $ do
    Let (NonRec lhs (Let (Rec binds) rhs)) body <- idR
    flags <- constT getDynFlags
    constT $ mapM_ (lookupDef . showPpr flags . fst) binds
    case filterBinds' [lhs] binds of
        (_, []) -> fail "cannot float"
        (stuck, movable) ->
            let rhs' | null stuck = rhs
                     | otherwise  = Let (Rec stuck) rhs
            in return (Let (Rec movable) (Let (NonRec lhs rhs') body))
-- | Split bindings into those that depend on the given variables and those
-- that do not, returned as @(unfloatable, floatable)@.  Each binding is
-- paired with the free identifiers of its right-hand side and the work is
-- delegated to 'filterBinds'.
filterBinds' :: [Id] -> [(Id, CoreExpr)] -> ([(Id, CoreExpr)], [(Id, CoreExpr)])
filterBinds' vars binds = filterBinds vars [] annotated
  where
    annotated = [ (varSetElems (freeIdsExpr rhs), bnd) | bnd@(_, rhs) <- binds ]
-- | Worker for 'filterBinds''.  Repeatedly moves every candidate whose free
-- variables intersect the blocking set into the unfloatable list and adds
-- its binder to the blocking set, until a round moves nothing.
--
-- Arguments: the blocking variables, the bindings already found unfloatable,
-- and the remaining candidates paired with their free variables.
-- Result: @(unfloatable, floatable)@.
filterBinds :: [Id] -> [(Id, a)]
            -> [([Id], (Id, a))]
            -> ([(Id, a)], [(Id, a)])
filterBinds vars unfloatables floatables
    | null blocked = (unfloatables, map snd floatables)
    | otherwise =
        filterBinds
            (map (fst . snd) blocked ++ vars)
            (map snd blocked ++ unfloatables)
            stillFree
  where
    (blocked, stillFree) = partition dependsOnBlocked floatables
    dependsOnBlocked (fvs, _) = not (null (vars `intersect` fvs))
-- | Float memoised bindings out of the body of a lambda:
-- @\\ v1 -> let rec binds in e2 ==> let rec binds' in \\ v1 -> ...@.
--
-- If the lambda binder is also bound by the let, the lambda is alpha-renamed
-- and the rewrite retried.  Bindings that depend on the lambda binder (as
-- decided by 'filterBinds'') stay inside the lambda.  Fails when no binding
-- can float.
--
-- The pattern-failure message now names the @Rec@ form that is actually
-- matched; it previously described a @NonRec@ let.
memoFloatLam :: RewriteH CoreExpr
memoFloatLam = prefixFailMsg "Let floating from Lam failed: " $
    withPatFailMsg (wrongExprForm "Lam v1 (Let (Rec binds) e2)") $ do
        Lam v1 (Let (Rec binds) e2) <- idR
        flags <- constT getDynFlags
        constT $ mapM_ (lookupDef . showPpr flags . fst) binds
        if v1 `elem` map fst binds
            then alphaLamR Nothing >>> memoFloatLam
            else case filterBinds' [v1] binds of
                (_, []) -> fail "no floatable let binds"
                (unfloatable, floatable) ->
                    let inner | null unfloatable = e2
                              | otherwise        = Let (Rec unfloatable) e2
                    in return (Let (Rec floatable) (Lam v1 inner))
-- | Float a memoised let out of a case scrutinee:
-- @case (let v = ev in e) of ... ==> let v = ev in case e of ...@.
--
-- Every binder of the let must name a stashed definition ('lookupDef').
-- Previously that lookup was constructed as the pure result of 'caseT' and
-- then discarded, so it never ran; it is now executed with 'constT', as in
-- 'memoFloatMemoLet'.  Let binders that would capture the case binder or a
-- free variable of an alternative are alpha-renamed first.
memoFloatCase :: RewriteH CoreExpr
memoFloatCase = prefixFailMsg "Let floating from Case failed: " $ do
    flags <- constT getDynFlags
    letVars <- caseT letVarsT idR idR (const idR) (\ vs _ _ _ -> vs)
    constT $ mapM_ (lookupDef . showPpr flags) letVars
    captures <- caseT letVarsT idR idR
                      (\ _ -> arr $ varSetElems . freeVarsAlt)
                      (\ vs wild _ fs -> vs `intersect` concatMap (wild:) fs)
    caseT (if null captures then idR else alphaLetVarsR captures)
          idR idR
          (const idR)
          (\ (Let bnds e) wild ty alts -> Let bnds (Case e wild ty alts))
-- | Float a memoised let out of a cast:
-- @cast (let v = ev in e) co ==> let v = ev in cast e co@.
--
-- Every binder of the let must name a stashed definition ('lookupDef').
-- Previously that lookup was constructed as the pure result of 'castT' and
-- then discarded, so it never ran; it is now executed with 'constT', as in
-- 'memoFloatMemoLet'.
memoFloatCast :: RewriteH CoreExpr
memoFloatCast = prefixFailMsg "Let floating from Cast failed: " $ do
    flags <- constT getDynFlags
    letVars <- castT letVarsT idR const
    constT $ mapM_ (lookupDef . showPpr flags) letVars
    castT idR idR (\ (Let bnds e) co -> Let bnds (Cast e co))
-- | Float memoised bindings out of a case alternative:
-- @case s of { p -> let rec bs in e } ==> let rec bs' in case s of { p -> ... }@.
--
-- The alternatives are scanned left to right and the first one that yields
-- at least one floatable binding is used.  Bindings that depend on the
-- alternative's pattern variables (as decided by 'filterBinds'') stay in the
-- alternative.  Fails with @"no memo lets found"@ when no alternative
-- qualifies.
memoFloatAlt :: RewriteH CoreExpr
memoFloatAlt = prefixFailMsg "Let floating from Case failed: " $ do
    flags <- constT getDynFlags
    Case scr wild ty alts <- idR
    let -- Try to float out of one alternative, given the alternatives
        -- before and after it.
        floatFrom before after alt = do
            (con, args, Let (Rec binds) body) <- return alt
            constT $ mapM_ (lookupDef . showPpr flags . fst) binds
            case filterBinds' args binds of
                (_, []) -> fail "no floatable let binds"
                (unfloatable, floatable) ->
                    let rhs | null unfloatable = body
                            | otherwise        = Let (Rec unfloatable) body
                    in return (Let (Rec floatable)
                                   (Case scr wild ty (before ++ (con, args, rhs) : after)))
        -- Walk the alternatives until one of them succeeds.
        scan _ [] = fail "no memo lets found"
        scan before (alt : after) =
            floatFrom before after alt <+ scan (before ++ [alt]) after
    scan [] alts