module Language.HERMIT.Primitive.Local.Case
(
externals
, letFloatCase
, caseFloatApp
, caseFloatArg
, caseFloatCase
, caseFloatLet
, caseFloat
, caseReduce
, caseSplit
, caseSplitInline
)
where
import GhcPlugins
import Data.List
import Control.Arrow
import Control.Applicative
import Language.HERMIT.GHC
import Language.HERMIT.Kure
import Language.HERMIT.External
import Language.HERMIT.Monad
import Language.HERMIT.Primitive.Common
import Language.HERMIT.Primitive.GHC hiding (externals)
import Language.HERMIT.Primitive.Inline hiding (externals)
import Language.HERMIT.Primitive.AlphaConversion hiding (externals)
import qualified Language.Haskell.TH as TH
-- | The HERMIT shell commands provided by this module.  Each entry pairs a
-- command name with the rewrite it runs (promoted to work on 'Core'), its
-- help text, and the tags attached with '.+'.
externals :: [External]
externals =
    [ external "let-float-case" (promoteExprR letFloatCase :: RewriteH Core)
        [ "case (let v = ev in e) of ... ==> let v = ev in case e of ..." ] .+ Commute .+ Shallow .+ Eval .+ Bash
    , external "case-float-app" (promoteExprR caseFloatApp :: RewriteH Core)
        [ "(case ec of alt -> e) v ==> case ec of alt -> e v" ] .+ Commute .+ Shallow .+ Bash
    , external "case-float-arg" (promoteExprR caseFloatArg :: RewriteH Core)
        [ "f (case s of alt -> e) ==> case s of alt -> f e" ] .+ Commute .+ Shallow .+ PreCondition
    , external "case-float-case" (promoteExprR caseFloatCase :: RewriteH Core)
        [ "case (case ec of alt1 -> e1) of alta -> ea ==> case ec of alt1 -> case e1 of alta -> ea" ] .+ Commute .+ Eval .+ Bash
    , external "case-float-let" (promoteExprR caseFloatLet :: RewriteH Core)
        [ "let v = case ec of alt1 -> e1 in e ==> case ec of alt1 -> let v = e1 in e" ] .+ Commute .+ Shallow .+ Bash
    , external "case-float" (promoteExprR caseFloat :: RewriteH Core)
        [ "Float a Case whatever the context." ] .+ Commute .+ Shallow .+ PreCondition
    , external "case-reduce" (promoteExprR caseReduce :: RewriteH Core)
        [ "case-of-known-constructor"
        , "case C v1..vn of C w1..wn -> e ==> e[v1/w1..vn/wn]" ] .+ Shallow .+ Eval .+ Bash
    , external "case-split" (promoteExprR . caseSplit :: TH.Name -> RewriteH Core)
        [ "case-split 'x"
        , "e ==> case x of C1 vs -> e; C2 vs -> e, where x is free in e" ]
    , external "case-split-inline" (caseSplitInline :: TH.Name -> RewriteH Core)
        -- Help text fix: "occurances" -> "occurrences".
        [ "Like case-split, but additionally inlines the matched constructor "
        , "applications for all occurrences of the named variable." ]
    ]
-- | Float a let binding out of a case scrutinee:
--
-- > case (let v = ev in e) of alts  ==>  let v = ev in case e of alts
--
-- After floating, the let-bound variables also scope over the alternatives.
-- If any of them occurs free in the original case expression, the let is
-- alpha-renamed ('alphaLet') before it is floated.
letFloatCase :: RewriteH CoreExpr
letFloatCase = prefixFailMsg "Let floating from Case failed: " $ do
    letBound  <- caseT letVarsT (const (pure ())) (\ vs _ _ _ -> vs)
    caseFrees <- freeVarsT
    let clash   = not (null (caseFrees `intersect` letBound))
        scrutR  = if clash then alphaLet else idR
    caseT scrutR (const idR) $ \ (Let bnds body) bndr ty alts ->
        Let bnds (Case body bndr ty alts)
-- | Float a case out of the function position of an application:
--
-- > (case ec of alt -> e) v  ==>  case ec of alt -> e v
--
-- Floating moves @v@ under the case binder and each alternative's pattern
-- variables. Before floating, the binder and any alternative whose variables
-- clash with the free variables of @v@ are alpha-renamed.
--
-- The floated case gets the type of the whole application. That type comes
-- from the case's own type annotation, via 'exprType' of the application,
-- not from the first alternative. A case with no alternatives therefore no
-- longer crashes on 'head'.
caseFloatApp :: RewriteH CoreExpr
caseFloatApp = prefixFailMsg "Case floating from App function failed: " $
  do captures      <- appT caseAltVarsT freeVarsT (flip (map . intersect))
     binderCapture <- appT caseBinderVarT freeVarsT intersect
     appT ((if null binderCapture then idR else alphaCaseBinder Nothing)
           >>> caseAllR idR (\ i -> if null (captures !! i) then idR else alphaAlt)
          )
          idR
          (\ c@(Case s b _ alts) v ->
              Case s b (exprType (App c v))
                   [ (con, ids, App rhs v) | (con, ids, rhs) <- alts ])
-- | Float a case out of the argument position of an application:
--
-- > f (case s of alt -> e)  ==>  case s of alt -> f e
--
-- Floating moves @f@ under the case binder and each alternative's pattern
-- variables. Before floating, the binder and any alternative whose variables
-- clash with the free variables of @f@ are alpha-renamed.
--
-- The floated case gets the type of the whole application, computed with
-- 'exprType' on the application itself. This avoids inspecting the first
-- alternative, so a case with no alternatives no longer crashes on 'head'.
caseFloatArg :: RewriteH CoreExpr
caseFloatArg = prefixFailMsg "Case floating from App argument failed: " $
  do captures      <- appT freeVarsT caseAltVarsT (map . intersect)
     binderCapture <- appT freeVarsT caseBinderVarT intersect
     appT idR
          ((if null binderCapture then idR else alphaCaseBinder Nothing)
           >>> caseAllR idR (\ i -> if null (captures !! i) then idR else alphaAlt)
          )
          (\ f c@(Case s b _ alts) ->
              Case s b (exprType (App f c))
                   [ (con, ids, App f rhs) | (con, ids, rhs) <- alts ])
-- | Float a case out of the scrutinee of another case (case-of-case):
--
-- > case (case ec of alt1 -> e1) of alta -> ea
-- >   ==>  case ec of alt1 -> case e1 of alta -> ea
--
-- The outer alternatives move under the inner case's binder and pattern
-- variables. Those are alpha-renamed first if they clash with the free
-- variables of the outer alternatives. 'altFreeVarsT' yields, per outer
-- alternative, a function of the outer case binder.
caseFloatCase :: RewriteH CoreExpr
caseFloatCase = prefixFailMsg "Case floating from Case failed: " $
  do captures      <- caseT caseAltVarsT (const altFreeVarsT) $ \ vss bndr _ fs ->
                        map (intersect (concatMap ($ bndr) fs)) vss
     binderCapture <- caseT caseBinderVarT (const altFreeVarsT) $ \ innerBndr bndr _ fs ->
                        intersect (concatMap ($ bndr) fs) innerBndr
     caseT ((if null binderCapture then idR else alphaCaseBinder Nothing)
            >>> caseAllR idR (\ i -> if null (captures !! i) then idR else alphaAlt)
           )
           (const idR)
           -- Every floated alternative now returns the outer case, so the
           -- new case must be annotated with the outer result type ty2
           -- (previously ty1, the inner case's type: ill-typed Core when
           -- the two differ).
           (\ (Case s1 b1 _ty1 alts1) b2 ty2 alts2 ->
               Case s1 b1 ty2
                    [ (c1, ids1, Case e1 b2 ty2 alts2) | (c1, ids1, e1) <- alts1 ])
-- | Float a case out of the right-hand side of a non-recursive let:
--
-- > let v = case ec of alt1 -> e1 in e  ==>  case ec of alt1 -> let v = e1 in e
--
-- The let body @e@ moves inside the scope of the case binder and of every
-- alternative's pattern variables. Any of those that occur free in @e@ are
-- alpha-renamed first. The old check only tested whether @v@ was an
-- alternative variable and missed capture of @e@'s other free variables.
--
-- The floated case is annotated with the type of @e@, which is what each
-- alternative now returns. The old code reused the type of @v@, which is
-- ill-typed Core when the two differ.
--
-- NOTE(review): the let binding was lazy, but the floated case scrutinises
-- @ec@ unconditionally. This can make evaluation stricter when @e@ does not
-- demand @v@ — confirm that is acceptable for callers.
caseFloatLet :: RewriteH CoreExpr
caseFloatLet = prefixFailMsg "Case floating from Let failed: " $
  do captures      <- letNonRecT caseAltVarsT freeVarsT $ \ _ vss bodyFrees ->
                        map (intersect bodyFrees) vss
     binderCapture <- letNonRecT caseBinderVarT freeVarsT $ \ _ bs bodyFrees ->
                        intersect bodyFrees bs
     let rhsR = (if null binderCapture then idR else alphaCaseBinder Nothing)
                >>> caseAllR idR (\ i -> if null (captures !! i) then idR else alphaAlt)
     letT (nonRecR rhsR) idR $ \ (NonRec v (Case s b _ alts)) e ->
       Case s b (exprType e) [ (con, ids, Let (NonRec v ec) e) | (con, ids, ec) <- alts ]
-- | Float a case out of whatever surrounds it. It tries, in order:
-- the function of an application, the argument of an application, the
-- scrutinee of a case, and the right-hand side of a non-recursive let.
-- The first that succeeds is used; if none apply, it fails with a single
-- generic message.
caseFloat :: RewriteH CoreExpr
caseFloat = setFailMsg "Unsuitable expression for Case floating." floatAny
  where
    floatAny = caseFloatApp <+ caseFloatArg <+ caseFloatCase <+ caseFloatLet
-- | Case-of-known-constructor.
--
-- > case C v1..vn of b { C w1..wn -> e }
-- >   ==>  let b = C v1..vn in let w1 = v1 in ... let wn = vn in e
--
-- The result is then cleaned up by applying 'letSubstR' repeatedly; a
-- failure there is ignored via 'tryR'.
--
-- Fails if:
--
-- * the expression is not a case;
-- * the scrutinee's head is not a data constructor; or
-- * there is not exactly one 'DataAlt' alternative for that constructor.
--
-- NOTE(review): the pattern binders are zipped with the *value* arguments
-- only (type arguments are dropped by 'isValArg'). This assumes the matched
-- alternative binds no type or evidence variables; existential/GADT
-- constructors would misalign — confirm.
--
-- NOTE(review): each generated let scopes over the lets after it. An
-- argument that mentions a variable identical to an earlier pattern binder
-- (or to @b@) would be captured — TODO confirm whether this can arise.
caseReduce :: RewriteH CoreExpr
caseReduce = letTransform >>> tryR (repeatR letSubstR)
  where letTransform = prefixFailMsg "Case reduction failed: " $
                       withPatFailMsg (wrongExprForm "Case e v t alts") $
                       -- A non-case expression fails this pattern bind;
                       -- withPatFailMsg supplies the error message.
                       do Case s binder _ alts <- idR
                          case isDataCon s of
                            Nothing -> fail "head of scrutinee is not a data constructor."
                            Just (dc, args) -> case [ (bs, rhs) | (DataAlt dc', bs, rhs) <- alts, dc == dc' ] of
                              -- Bind the case binder to the whole scrutinee, then
                              -- each pattern binder to its value argument.
                              [(bs,e')] -> let valArgs = filter isValArg args
                                           in return $ nestedLets e' $ (binder, s) : zip bs valArgs
                              [] -> fail "no matching alternative."
                              _ -> fail "more than one matching alternative."
-- | Recognise a (possibly applied) data constructor. If the head of the
-- application is a data constructor's Id, return that constructor together
-- with all the arguments it is applied to, type arguments included.
-- Otherwise return 'Nothing'.
isDataCon :: CoreExpr -> Maybe (DataCon, [CoreExpr])
isDataCon expr =
    case collectArgs expr of
      (Var v, args) -> fmap (\ dc -> (dc, args)) (isDataConId_maybe v)
      _             -> Nothing
-- | Wrap a body in a chain of non-recursive lets, first binding outermost:
--
-- > nestedLets e [(x1, r1), (x2, r2)]  ==  let x1 = r1 in let x2 = r2 in e
--
-- Each binding therefore scopes over every binding that follows it.
nestedLets :: CoreExpr -> [(Id, CoreExpr)] -> CoreExpr
nestedLets body []                = body
nestedLets body ((v, rhs) : rest) = Let (NonRec v rhs) (nestedLets body rest)
-- | @case-split 'x@: wrap the expression in a case on the free variable @x@.
-- There is one alternative per data constructor of @x@'s type, and each
-- alternative's body is the original expression:
--
-- > e  ==>  case x of x { C1 a.. -> e; ...; Cn a.. -> e }
--
-- The case binder is @x@ itself, so inside every alternative @x@ refers to
-- the scrutinised value. Fresh pattern variables are created with
-- 'newVarH', named from the cycle @a@, @b@, ..., @z@, @a@, ...
--
-- Fails (instead of panicking or producing a case with no alternatives) if:
--
-- * @x@ is not free in the expression;
-- * its type is not a type-constructor application; or
-- * that type constructor has no data constructors.
caseSplit :: TH.Name -> RewriteH CoreExpr
caseSplit nm = do
    frees <- freeIdsT
    contextfreeT $ \ e ->
      case [ i | i <- frees, cmpTHName2Id nm i ] of
        []    -> fail "caseSplit: provided name is not free"
        (i:_) ->
          -- splitTyConApp would panic on e.g. a bare type variable.
          case splitTyConApp_maybe (idType i) of
            Nothing -> fail "caseSplit: type of provided name is not a type constructor application"
            Just (tycon, tys) ->
              -- Function and primitive types have no data constructors;
              -- splitting on them would build an empty case.
              case tyConDataCons tycon of
                []  -> fail "caseSplit: type of provided name has no data constructors"
                dcs -> do
                  let aNms = map (:[]) $ cycle ['a'..'z']
                      freshVars dc = sequence [ newVarH a ty | (a, ty) <- zip aNms (dataConInstArgTys dc tys) ]
                  dcsAndVars <- mapM (\ dc -> do { as <- freshVars dc; return (dc, as) }) dcs
                  return $ Case (Var i) i (exprType e) [ (DataAlt dc, as, e) | (dc, as) <- dcsAndVars ]
-- | Run 'caseSplit' on the named variable, then apply
-- @inlineName nm@ wherever it succeeds, traversing bottom-up with
-- 'anybuR'.
caseSplitInline :: TH.Name -> RewriteH Core
caseSplitInline nm = splitR >>> inlineEverywhereR
  where
    splitR            = promoteR (caseSplit nm)
    inlineEverywhereR = anybuR (promoteExprR (inlineName nm))