module Language.HERMIT.Primitive.Local.Case
(
caseExternals
, caseFloatApp
, caseFloatArg
, caseFloatCase
, caseFloatLet
, caseFloat
, caseReduce
, caseSplit
, caseSplitInline
)
where
import GhcPlugins
import Data.List
import Control.Arrow
import Language.HERMIT.Core
import Language.HERMIT.Monad
import Language.HERMIT.Kure
import Language.HERMIT.GHC
import Language.HERMIT.External
import Language.HERMIT.Primitive.Common
import Language.HERMIT.Primitive.GHC
import Language.HERMIT.Primitive.Inline
import Language.HERMIT.Primitive.AlphaConversion
import qualified Language.Haskell.TH as TH
-- | The user-facing commands provided by this module. Each pairs a rewrite
-- with its help text and tags.
caseExternals :: [External]
caseExternals =
    [ external "case-float-app" (promoteExprR caseFloatApp :: RewriteH Core)
        [ "(case ec of alt -> e) v ==> case ec of alt -> e v" ] .+ Commute .+ Shallow .+ Bash
    , external "case-float-arg" (promoteExprR caseFloatArg :: RewriteH Core)
        [ "f (case s of alt -> e) ==> case s of alt -> f e" ] .+ Commute .+ Shallow .+ PreCondition
    , external "case-float-case" (promoteExprR caseFloatCase :: RewriteH Core)
        [ "case (case ec of alt1 -> e1) of alta -> ea ==> case ec of alt1 -> case e1 of alta -> ea" ] .+ Commute .+ Eval .+ Bash
    , external "case-float-let" (promoteExprR caseFloatLet :: RewriteH Core)
        [ "let v = case ec of alt1 -> e1 in e ==> case ec of alt1 -> let v = e1 in e" ] .+ Commute .+ Shallow .+ Bash
    , external "case-float" (promoteExprR caseFloat :: RewriteH Core)
        [ "Float a Case whatever the context." ] .+ Commute .+ Shallow .+ PreCondition
    , external "case-reduce" (promoteExprR caseReduce :: RewriteH Core)
        [ "case-of-known-constructor"
        , "case C v1..vn of C w1..wn -> e ==> e[v1/w1..vn/wn]" ] .+ Shallow .+ Eval .+ Bash
    , external "case-split" (promoteExprR . caseSplit :: TH.Name -> RewriteH Core)
        [ "case-split 'x"
        , "e ==> case x of C1 vs -> e; C2 vs -> e, where x is free in e" ]
    , external "case-split-inline" (caseSplitInline :: TH.Name -> RewriteH Core)
        [ "Like case-split, but additionally inlines the matched constructor "
        , "applications for all occurrences of the named variable." ]
    ]
-- | Float a case expression out of the function position of an application:
--
-- > (case ec of alt -> e) v  ==>  case ec of alt -> e v
--
-- Capture avoidance:
--
-- * Any alternative whose binders intersect the free variables of @v@ is
--   alpha-renamed first ('alphaAlt').
-- * If the case binder is free in @v@, it is renamed first
--   ('alphaCaseBinder').
caseFloatApp :: RewriteH CoreExpr
caseFloatApp = prefixFailMsg "Case floating from App function failed: " $
  do
    captures    <- appT caseAltVarsT freeVarsT (flip (map . intersect))
    wildCapture <- appT caseWildVarT freeVarsT elem
    appT ((if not wildCapture then idR else alphaCaseBinder Nothing)
          >>> caseAllR idR (\i -> if null (captures !! i) then idR else alphaAlt)
         )
         idR
         (\ scrut@(Case s b _ alts) v ->
            -- The new case has the type of the whole application. Deriving it
            -- from the case itself (rather than from one of its alternatives)
            -- keeps this total even for a case with no alternatives.
            let newTy = exprType (App scrut v)
            in  Case s b newTy [ (c, ids, App f v) | (c, ids, f) <- alts ])
-- | Float a case expression out of the argument position of an application:
--
-- > f (case s of alt -> e)  ==>  case s of alt -> f e
--
-- Capture avoidance:
--
-- * Any alternative whose binders intersect the free variables of @f@ is
--   alpha-renamed first ('alphaAlt').
-- * If the case binder is free in @f@, it is renamed first
--   ('alphaCaseBinder').
caseFloatArg :: RewriteH CoreExpr
caseFloatArg = prefixFailMsg "Case floating from App argument failed: " $
  do
    captures    <- appT freeVarsT caseAltVarsT (map . intersect)
    wildCapture <- appT freeVarsT caseWildVarT (flip elem)
    appT idR
         ((if not wildCapture then idR else alphaCaseBinder Nothing)
          >>> caseAllR idR (\i -> if null (captures !! i) then idR else alphaAlt)
         )
         (\ f scrut@(Case s b _ alts) ->
            -- The new case has the type of the whole application. Deriving it
            -- from the case itself (rather than from one of its alternatives)
            -- keeps this total even for a case with no alternatives.
            let newTy = exprType (App f scrut)
            in  Case s b newTy [ (c, ids, App f e) | (c, ids, e) <- alts ])
-- | Float a case expression out of the scrutinee position of another case:
--
-- > case (case ec of alt1 -> e1) of alta -> ea  ==>  case ec of alt1 -> case e1 of alta -> ea
--
-- After floating, the outer alternatives sit underneath the inner case's
-- binder and alternative binders. Any of those binders that occur free in
-- the outer alternatives are therefore alpha-renamed beforehand. The floated
-- case takes the type of the original outer case.
caseFloatCase :: RewriteH CoreExpr
caseFloatCase = prefixFailMsg "Case floating from Case failed: " $
  do
    altClashes  <- caseT caseAltVarsT (const altFreeVarsExclWildT) innerAltClashes
    bndrClashes <- caseT caseWildVarT (const altFreeVarsExclWildT) innerBndrClashes
    caseT (renameInner altClashes bndrClashes) (const idR) floatInner
  where
    -- Free variables of the outer alternatives, as reported by
    -- 'altFreeVarsExclWildT' when given the outer case binder.
    outerFrees outerBndr fvFns = concatMap ($ outerBndr) fvFns

    -- Per inner alternative: which of its binders would capture an outer free variable.
    innerAltClashes innerAltVars outerBndr _ fvFns =
      map (intersect (outerFrees outerBndr fvFns)) innerAltVars

    -- Would the inner case binder capture an outer free variable?
    innerBndrClashes innerBndr outerBndr _ fvFns =
      innerBndr `elem` outerFrees outerBndr fvFns

    -- Rename only the inner binders/alternatives that actually clash.
    renameInner clashes bndrClash =
          (if bndrClash then alphaCaseBinder Nothing else idR)
      >>> caseAllR idR (\ n -> if null (clashes !! n) then idR else alphaAlt)

    -- Push the outer case into every inner alternative.
    floatInner (Case s1 b1 _ alts1) b2 ty alts2 =
      Case s1 b1 ty [ (c1, ids1, Case e1 b2 ty alts2) | (c1, ids1, e1) <- alts1 ]
-- | Float a case expression out of the right-hand side of a non-recursive let:
--
-- > let v = case ec of alt1 -> e1 in e  ==>  case ec of alt1 -> let v = e1 in e
--
-- After floating, the let body @e@ sits underneath the case binder and the
-- alternative binders. Any of those that occur free in @e@ (or coincide with
-- @v@) are alpha-renamed first so they cannot capture variables of @e@.
caseFloatLet :: RewriteH CoreExpr
caseFloatLet = prefixFailMsg "Case floating from Let failed: " $
  do
    captures    <- letNonRecT caseAltVarsT freeVarsT
                     (\ letVar altVars bodyFrees -> map (intersect (letVar : bodyFrees)) altVars)
    wildCapture <- letNonRecT caseWildVarT freeVarsT
                     (\ letVar wild bodyFrees -> wild `elem` (letVar : bodyFrees))
    let rhsAction = (if not wildCapture then idR else alphaCaseBinder Nothing)
                    >>> caseAllR idR (\i -> if null (captures !! i) then idR else alphaAlt)
    letT (nonRecR rhsAction) idR $ \ (NonRec v (Case s b _ alts)) e ->
      -- The floated case now yields the value of the let body, so its type
      -- is the body's type, not the type of the let-bound variable.
      Case s b (exprType e) [ (con, ids, Let (NonRec v ec) e) | (con, ids, ec) <- alts ]
-- | Float a case expression out of whichever supported context it sits in.
-- The rewrites tried in order are: application function, application
-- argument, case scrutinee, let right-hand side. The first that succeeds
-- wins.
caseFloat :: RewriteH CoreExpr
caseFloat =
  setFailMsg "Unsuitable expression for Case floating." $
    foldr1 (<+) [ caseFloatApp
                , caseFloatArg
                , caseFloatCase
                , caseFloatLet
                ]
-- | Case-of-known-constructor:
--
-- > case C v1..vn of b { C w1..wn -> e }  ==>  e[C v1..vn/b, v1/w1..vn/wn]
--
-- The rewrite works in two steps:
--
-- 1. Introduce nested non-recursive lets binding the case binder to the
--    scrutinee and each alternative binder to its constructor argument.
-- 2. Repeatedly apply 'letSubstR' to those lets.
--
-- If no alternative matches the constructor, the @DEFAULT@ alternative (when
-- present) is selected instead, binding only the case binder.
caseReduce :: RewriteH CoreExpr
caseReduce = letTransform >>> tryR (repeatR letSubstR)
  where
    letTransform = prefixFailMsg "Case reduction failed: " $
                   withPatFailMsg (wrongExprForm "Case e v t alts") $
      do Case s binder _ alts <- idR
         case isDataCon s of
           Nothing -> fail "head of scrutinee is not a data constructor."
           Just (dc, args) ->
             case [ (bs, rhs) | (DataAlt dc', bs, rhs) <- alts, dc == dc' ] of
               [(bs, e')] ->
                 let (tyArgs, valArgs) = span isTypeArg args
                     -- The alternative binds only the trailing (existential)
                     -- type arguments, so pair its type binders with those.
                     tyBndrs      = takeWhile isTyVar bs
                     existentials = reverse $ take (length tyBndrs) $ reverse tyArgs
                 in  return $ nestedLets e' $ (binder, s) : zip bs (existentials ++ valArgs)
               [] ->
                 -- No alternative names this constructor: the DEFAULT
                 -- alternative (if any) is the one that would be taken.
                 case [ rhs | (DEFAULT, _, rhs) <- alts ] of
                   [e'] -> return $ nestedLets e' [(binder, s)]
                   _    -> fail "no matching alternative."
               _ -> fail "more than one matching alternative."
-- | If the expression is a variable applied to (possibly zero) arguments, and
-- 'isDataConId_maybe' recognises that variable as a data constructor, return
-- the constructor together with the arguments; otherwise 'Nothing'.
isDataCon :: CoreExpr -> Maybe (DataCon, [CoreExpr])
isDataCon expr =
  case collectArgs expr of
    (Var i, args) -> fmap (\ dc -> (dc, args)) (isDataConId_maybe i)
    _             -> Nothing
-- | Wrap a body in a chain of non-recursive lets. The first pair in the list
-- becomes the outermost binding:
--
-- > nestedLets e [(b1,r1),(b2,r2)] == Let (NonRec b1 r1) (Let (NonRec b2 r2) e)
nestedLets :: CoreExpr -> [(Id, CoreExpr)] -> CoreExpr
nestedLets body []                  = body
nestedLets body ((b, rhs) : others) = Let (NonRec b rhs) (nestedLets body others)
-- | Split on a free variable:
--
-- > e  ==>  case x of x { C1 vs -> e; C2 vs -> e; ... }
--
-- Here @x@ is the first free identifier matching the given name, and there
-- is one alternative per data constructor of @x@'s type (with fresh binders
-- named a, b, c, ...). Fails when:
--
-- * the name is not free;
-- * the variable's type is not a type-constructor application;
-- * the type constructor has no visible data constructors.
caseSplit :: TH.Name -> RewriteH CoreExpr
caseSplit nm = do
    frees <- freeIdsT
    contextfreeT $ \ e ->
      case [ i | i <- frees, cmpTHName2Var nm i ] of
        [] -> fail "caseSplit: provided name is not free"
        (i:_) ->
          -- splitTyConApp would panic on e.g. a type variable; fail instead.
          case splitTyConApp_maybe (idType i) of
            Nothing -> fail "caseSplit: type of provided name is not a type constructor application"
            Just (tycon, tys) ->
              -- Without constructors the split would build a case with no alternatives.
              case tyConDataCons tycon of
                [] -> fail "caseSplit: type of provided name has no visible data constructors"
                dcs -> do
                  let aNms = map (:[]) $ cycle ['a'..'z']
                  dcsAndVars <- mapM (\ dc -> do
                                        as <- sequence [ newIdH a ty | (a, ty) <- zip aNms $ dataConInstArgTys dc tys ]
                                        return (dc, as))
                                     dcs
                  return $ Case (Var i) i (exprType e) [ (DataAlt dc, as, e) | (dc, as) <- dcsAndVars ]
-- | Like 'caseSplit' (promoted to 'Core'). The result is then traversed with
-- 'anybuR', applying 'inlineName' for the same name at expression nodes.
caseSplitInline :: TH.Name -> RewriteH Core
caseSplitInline nm = splitOnName >>> inlineEverywhere
  where
    splitOnName      = promoteR (caseSplit nm)
    inlineEverywhere = anybuR (promoteExprR (inlineName nm))