module HERMIT.Dictionary.Local.Case
(
externals
, caseFloatAppR
, caseFloatArgR
, caseFloatArgLemmaR
, caseFloatCaseR
, caseFloatCastR
, caseFloatLetR
, caseFloatR
, caseFloatInR
, caseFloatInAppR
, caseFloatInArgsR
, caseReduceR
, caseReduceDataconR
, caseReduceLiteralR
, caseReduceUnfoldR
, casesForM
, caseExprsForM
, caseSplitR
, caseSplitInlineR
, caseInlineScrutineeR
, caseInlineAlternativeR
, caseMergeAltsR
, caseMergeAltsWithBinderR
, caseElimR
, caseElimInlineScrutineeR
, caseElimMergeAltsR
, caseIntroSeqR
, caseElimSeqR
) where
import Control.Arrow
import Control.Monad
import Control.Monad.IO.Class
import Data.List
import Data.Monoid
import HERMIT.Core
import HERMIT.Context
import HERMIT.External
import HERMIT.GHC
import HERMIT.Kure
import HERMIT.Lemma
import HERMIT.Monad
import HERMIT.Name
import HERMIT.ParserCore
import HERMIT.Utilities
import HERMIT.Dictionary.AlphaConversion hiding (externals)
import HERMIT.Dictionary.Common
import HERMIT.Dictionary.Fold hiding (externals)
import HERMIT.Dictionary.Inline hiding (externals)
import HERMIT.Dictionary.Undefined (verifyStrictT, buildStrictnessLemmaT)
import HERMIT.Dictionary.Unfold (unfoldR)
externals :: [External]
externals =
[ external "case-float-app" (promoteExprR caseFloatAppR :: RewriteH LCore)
[ "(case ec of alt -> e) v ==> case ec of alt -> e v" ] .+ Commute .+ Shallow
, external "case-float-arg" ((\ strict -> promoteExprR (caseFloatArg Nothing (Just strict))) :: RewriteH LCore -> RewriteH LCore)
[ "Given a proof that f is strict, then"
, "f (case s of alt -> e) ==> case s of alt -> f e" ] .+ Commute .+ Shallow
, external "case-float-arg" ((\ f strict -> promoteExprR (caseFloatArg (Just f) (Just strict))) :: CoreString -> RewriteH LCore -> RewriteH LCore)
[ "For a specified f, given a proof that f is strict, then"
, "f (case s of alt -> e) ==> case s of alt -> f e" ] .+ Commute .+ Shallow
, external "case-float-arg-unsafe" ((\ f -> promoteExprR (caseFloatArg (Just f) Nothing)) :: CoreString -> RewriteH LCore)
[ "For a specified f,"
, "f (case s of alt -> e) ==> case s of alt -> f e" ] .+ Commute .+ Shallow .+ PreCondition .+ Strictness
, external "case-float-arg-unsafe" (promoteExprR . caseFloatArgLemmaR UnsafeUsed :: LemmaName -> RewriteH LCore)
[ "f (case s of alt -> e) ==> case s of alt -> f e" ] .+ Commute .+ Shallow .+ PreCondition .+ Strictness .+ Unsafe
, external "case-float-arg-lemma" (promoteExprR . caseFloatArgLemmaR Obligation :: LemmaName -> RewriteH LCore)
[ "f (case s of alt -> e) ==> case s of alt -> f e"
, "Generates a lemma with given name for strictness side condition on f." ] .+ Commute .+ Shallow .+ PreCondition .+ Strictness
, external "case-float-case" (promoteExprR caseFloatCaseR :: RewriteH LCore)
[ "case (case ec of alt1 -> e1) of alta -> ea ==> case ec of alt1 -> case e1 of alta -> ea" ] .+ Commute .+ Eval
, external "case-float-cast" (promoteExprR caseFloatCastR :: RewriteH LCore)
[ "cast (case s of p -> e) co ==> case s of p -> cast e co" ] .+ Shallow .+ Commute
, external "case-float-let" (promoteExprR caseFloatLetR :: RewriteH LCore)
[ "let v = case ec of alt1 -> e1 in e ==> case ec of alt1 -> let v = e1 in e" ] .+ Commute .+ Shallow .+ Strictness
, external "case-float" (promoteExprR caseFloatR :: RewriteH LCore)
[ "case-float = case-float-app <+ case-float-case <+ case-float-let <+ case-float-cast" ] .+ Commute .+ Shallow .+ Strictness
, external "case-float-in" (promoteExprR caseFloatInR :: RewriteH LCore)
[ "Float in a Case whatever the context." ] .+ Commute .+ Shallow .+ PreCondition
, external "case-float-in-args" (promoteExprR caseFloatInArgsR :: RewriteH LCore)
[ "Float in a Case whose alternatives are parallel applications of the same function." ] .+ Commute .+ Shallow .+ PreCondition .+ Strictness
, external "case-reduce" (promoteExprR (caseReduceR True) :: RewriteH LCore)
[ "Case of Known Constructor"
, "case-reduce-datacon <+ case-reduce-literal" ] .+ Shallow .+ Eval
, external "case-reduce-datacon" (promoteExprR (caseReduceDataconR True) :: RewriteH LCore)
[ "Case of Known Constructor"
, "case C v1..vn of C w1..wn -> e ==> let { w1 = v1 ; .. ; wn = vn } in e" ] .+ Shallow .+ Eval
, external "case-reduce-literal" (promoteExprR (caseReduceLiteralR True) :: RewriteH LCore)
[ "Case of Known Constructor"
, "case L of L -> e ==> e" ] .+ Shallow .+ Eval
, external "case-reduce-unfold" (promoteExprR (caseReduceUnfoldR True) :: RewriteH LCore)
[ "Unfold the case scrutinee and then case-reduce." ] .+ Shallow .+ Eval .+ Context
, external "case-split" ((\nm -> findVarT (unOccurrenceName nm) >>= promoteExprR . caseSplitR . varToCoreExpr) :: OccurrenceName -> RewriteH LCore)
[ "case-split 'x"
, "e ==> case x of C1 vs -> e; C2 vs -> e, where x is free in e" ] .+ Deep .+ Strictness
, external "case-split" (parseCoreExprT >=> promoteR . caseSplitR :: CoreString -> RewriteH LCore)
[ "case-split [| expr |]"
, "e ==> case expr of C1 vs -> e; C2 vs -> e"] .+ Deep .+ Strictness
, external "case-split-inline" ((\nm -> findVarT (unOccurrenceName nm) >>= promoteExprR . caseSplitInlineR . varToCoreExpr) :: OccurrenceName -> RewriteH LCore)
[ "Like case-split, but additionally inlines the matched constructor "
, "applications for all occurances of the named variable." ] .+ Deep .+ Strictness
, external "case-split-inline" (parseCoreExprT >=> promoteExprR . caseSplitInlineR :: CoreString -> RewriteH LCore)
[ "Like case-split, but additionally inlines the matched constructor "
, "applications for all occurances of the case binder." ] .+ Deep .+ Strictness
, external "case-intro-seq" (promoteExprR . caseIntroSeqR . cmpString2Var :: String -> RewriteH LCore)
[ "Force evaluation of a variable by introducing a case."
, "case-intro-seq 'v is is equivalent to adding @(seq v)@ in the source code." ] .+ Shallow .+ Introduce .+ Strictness
, external "case-elim-seq" (promoteExprR caseElimSeqR :: RewriteH LCore)
[ "Eliminate a case that corresponds to a pointless seq." ] .+ Deep .+ Eval .+ Strictness
, external "case-inline-alternative" (promoteExprR caseInlineAlternativeR :: RewriteH LCore)
[ "Inline the case binder as the case-alternative pattern everywhere in the case alternatives." ] .+ Deep
, external "case-inline-scrutinee" (promoteExprR caseInlineScrutineeR :: RewriteH LCore)
[ "Inline the case binder as the case scrutinee everywhere in the case alternatives." ] .+ Deep
, external "case-merge-alts" (promoteExprR caseMergeAltsR :: RewriteH LCore)
[ "Merge all case alternatives into a single default case."
, "The RHS of each alternative must be the same."
, "case s of {pat1 -> e ; pat2 -> e ; ... ; patn -> e} ==> case s of {_ -> e}" ]
, external "case-merge-alts-with-binder" (promoteExprR caseMergeAltsWithBinderR :: RewriteH LCore)
[ "A cleverer version of 'mergeCaseAlts' that first attempts to"
, "abstract out any occurrences of the alternative pattern using the case binder." ] .+ Deep
, external "case-elim" (promoteExprR caseElimR :: RewriteH LCore)
[ "case s of w; C vs -> e ==> e if w and vs are not free in e" ] .+ Shallow .+ Strictness
, external "case-elim-inline-scrutinee" (promoteExprR caseElimInlineScrutineeR :: RewriteH LCore)
[ "Eliminate a case, inlining any occurrences of the case binder as the scrutinee." ] .+ Deep
, external "case-elim-merge-alts" (promoteExprR caseElimMergeAltsR :: RewriteH LCore)
[ "Eliminate a case, merging the case alternatives into a single default alternative",
"and inlining the case binder as the scrutinee (if possible)." ] .+ Deep
, external "case-fold-binder" (promoteExprR caseFoldBinderR :: RewriteH LCore)
[ "In the case alternatives, fold any occurrences of the case alt patterns to the case binder." ]
]
caseElimR :: MonadCatch m => Rewrite c m CoreExpr
caseElimR = prefixFailMsg "Case elimination failed: " $
withPatFailMsg (wrongExprForm "Case s bnd ty alts") $
do Case _ bnd _ alts <- idR
case alts of
[(_, vs, e)] -> do let fvs = freeVarsExpr e
guardMsg (isEmptyVarSet $ intersectVarSet (mkVarSet (bnd:vs)) fvs) "case binder or pattern binders free in RHS."
return e
_ -> fail "more than one case alternative."
caseFloatAppR :: (ExtendPath c Crumb, ReadPath c Crumb, AddBindings c, ReadBindings c, MonadCatch m, MonadUnique m)
=> Rewrite c m CoreExpr
caseFloatAppR = prefixFailMsg "Case floating from App function failed: " $
do
captures <- appT (liftM (map mkVarSet) caseAltVarsT) (arr freeVarsExpr) (flip (map . intersectVarSet))
bndrCapture <- appT caseBinderIdT (arr freeVarsExpr) elemVarSet
appT ((if not bndrCapture then idR else alphaCaseBinderR Nothing)
>>> caseAllR idR idR idR (\i -> if isEmptyVarSet (captures !! i) then idR else alphaAltR)
)
idR
(\(Case s b _ alts) v -> let newAlts = mapAlts (`App` v) alts
in Case s b (coreAltsType newAlts) newAlts)
caseFloatArg :: Maybe CoreString -> Maybe (RewriteH LCore) -> RewriteH CoreExpr
caseFloatArg mfstr mstrictCore = let mstrict = extractR <$> mstrictCore
in case mfstr of
Nothing -> caseFloatArgR Nothing mstrict
Just f_str -> do f <- parseCoreExprT f_str
caseFloatArgR (Just f) mstrict
caseFloatArgR :: ( ExtendPath c Crumb, ReadPath c Crumb, AddBindings c, BoundVars c
, HasDynFlags m, HasHermitMEnv m, HasHscEnv m, MonadCatch m, MonadIO m, MonadThings m, MonadUnique m )
=> Maybe CoreExpr -> Maybe (Rewrite c m CoreExpr)
-> Rewrite c m CoreExpr
caseFloatArgR mf mstrict = prefixFailMsg "Case floating from App argument failed: " $
withPatFailMsg "App f (Case s w ty alts)" $
do App f (Case s w _ alts) <- idR
whenJust (\ f' -> guardMsg (exprAlphaEq f f') "given function does not match current application.") mf
whenJust (verifyStrictT f) mstrict
let fvs = freeVarsExpr f
altCaptures = map (intersectVarSet fvs . mkVarSet . altVars) alts
bndrCapture = elemVarSet w fvs
if | bndrCapture -> appAllR idR (alphaCaseBinderR Nothing) >>> caseFloatArgR Nothing Nothing
| all isEmptyVarSet altCaptures -> let new_alts = mapAlts (App f) alts
in return $ Case s w (coreAltsType new_alts) new_alts
| otherwise -> appAllR idR (caseAllR idR idR idR (\ n -> let vs = varSetElems (altCaptures !! n)
in if null vs then idR else alphaAltVarsR vs
)
) >>> caseFloatArgR Nothing Nothing
caseFloatArgLemmaR :: ( AddBindings c, BoundVars c, ExtendPath c Crumb, ReadPath c Crumb, HasHermitMEnv m
, HasHscEnv m, HasDynFlags m, HasLemmas m, MonadCatch m, MonadIO m, MonadThings m, MonadUnique m )
=> Used -> LemmaName -> Rewrite c m CoreExpr
caseFloatArgLemmaR u nm = prefixFailMsg "Case floating from application argument failed: " $
withPatFailMsg "App f (Case s w ty alts)" $ do
App f (Case s w _ alts) <- idR
let fvs = freeVarsExpr f
altCaptures = map (intersectVarSet fvs . mkVarSet . altVars) alts
bndrCapture = elemVarSet w fvs
if | bndrCapture ->
appAllR idR (alphaCaseBinderR Nothing) >>> caseFloatArgR Nothing Nothing
| all isEmptyVarSet altCaptures -> do
let new_alts = mapAlts (App f) alts
buildStrictnessLemmaT u nm f
return $ Case s w (coreAltsType new_alts) new_alts
| otherwise ->
appAllR idR (caseAllR idR idR idR (\ n -> let vs = varSetElems (altCaptures !! n)
in if null vs then idR else alphaAltVarsR vs
)
) >>> caseFloatArgLemmaR u nm
caseFloatCaseR :: (ExtendPath c Crumb, ReadPath c Crumb, AddBindings c, ReadBindings c, MonadCatch m, MonadUnique m)
=> Rewrite c m CoreExpr
caseFloatCaseR = prefixFailMsg "Case floating from Case failed: " $
do
captures <- caseT (liftM (map mkVarSet) caseAltVarsT) idR mempty (const $ arr freeVarsAlt) (\ vss w () fvs -> map (intersectVarSet (delVarSet (unionVarSets fvs) w)) vss)
bndrCapture <- caseT caseBinderIdT idR mempty (const $ arr freeVarsAlt) (\ innerBndr w () fvs -> innerBndr `elemVarSet` (delVarSet (unionVarSets fvs) w))
caseT ((if not bndrCapture then idR else alphaCaseBinderR Nothing)
>>> caseAllR idR idR idR (\i -> if isEmptyVarSet (captures !! i) then idR else alphaAltR)
)
idR
idR
(const idR)
(\ (Case s1 b1 _ alts1) b2 ty alts2 -> Case s1 b1 ty $ mapAlts (\s -> Case s b2 ty alts2) alts1)
caseFloatLetR :: (ExtendPath c Crumb, ReadPath c Crumb, AddBindings c, ReadBindings c, MonadCatch m, MonadUnique m)
=> Rewrite c m CoreExpr
caseFloatLetR = prefixFailMsg "Case floating from Let failed: " $
do vs <- letNonRecT idR caseAltVarsT mempty (\ letVar caseVars () -> letVar `elem` concat caseVars)
let bdsAction = if not vs then idR else nonRecAllR idR alphaCaseR
letT bdsAction idR $ \ (NonRec v (Case s b _ alts)) e -> let newAlts = mapAlts (\ rhs -> Let (NonRec v rhs) e) alts
in Case s b (coreAltsType newAlts) newAlts
caseFloatCastR :: MonadCatch m => Rewrite c m CoreExpr
caseFloatCastR = prefixFailMsg "Case float from cast failed: " $
withPatFailMsg (wrongExprForm "Cast (Case s bnd ty alts) co") $
do Cast (Case s bnd _ alts) co <- idR
let alts' = mapAlts (flip Cast co) alts
return $ Case s bnd (coreAltsType alts') alts'
caseFloatR :: (ExtendPath c Crumb, ReadPath c Crumb, AddBindings c, ReadBindings c, MonadCatch m, MonadUnique m)
=> Rewrite c m CoreExpr
caseFloatR = setFailMsg "Unsuitable expression for Case floating." $
caseFloatAppR <+ caseFloatCaseR <+ caseFloatLetR <+ caseFloatCastR
caseFloatInR :: (ExtendPath c Crumb, ReadPath c Crumb, AddBindings c, MonadCatch m) => Rewrite c m CoreExpr
caseFloatInR = setFailMsg "Case floating in failed." $
caseFloatInAppR <+ caseFloatInArgsR
caseFloatInAppR :: Monad m => Rewrite c m CoreExpr
caseFloatInAppR = fail "caseFloatInApp: TODO"
caseFloatInArgsR :: (ExtendPath c Crumb, ReadPath c Crumb, AddBindings c, MonadCatch m) => Rewrite c m CoreExpr
caseFloatInArgsR = prefixFailMsg "Case floating into arguments failed: " $
withPatFailMsg (wrongExprForm "Case s v t alts") $
do Case s bndr _ty alts <- idR
(vss, fs, argss) <- caseT mempty mempty mempty (\ _ -> altT mempty (\ _ -> idR) callT $ \ () vs (fn, args) -> (vs, fn, args))
(\ () () () alts' -> unzip3 [ (bndr:vs, fn, args) | (vs,fn,args) <- alts' ])
guardMsg (equivalentBy exprAlphaEq fs) "alternatives are not parallel in function call."
let fvs = [ varSetElems $ unionVarSets $ map freeVarsExpr $ f:tyArgs
| (f,args) <- zip fs argss
, let tyArgs = takeWhile isTyCoArg args ]
guardMsg (all null $ zipWith intersect fvs vss) "function bound by case binders."
let argss' = transpose argss
guardMsg (all (equivalentBy exprAlphaEq) $ filter (isTyCoArg . head) argss') "function applied at different types."
return $ mkCoreApps (head fs) [ if isTyCoArg (head args)
then head args
else let alts' = [ (ac, vs, arg) | ((ac,vs,_),arg) <- zip alts args ]
in Case s bndr (coreAltsType alts') alts'
| args <- argss' ]
caseReduceUnfoldR :: ( ExtendPath c Crumb, ReadPath c Crumb, AddBindings c
, ReadBindings c, HasEmptyContext c, MonadCatch m, MonadUnique m )
=> Bool -> Rewrite c m CoreExpr
caseReduceUnfoldR subst = caseAllR unfoldR idR idR (const idR) >>> caseReduceR subst
caseReduceR :: (ExtendPath c Crumb, ReadPath c Crumb, AddBindings c, ReadBindings c, MonadCatch m, MonadUnique m)
=> Bool -> Rewrite c m CoreExpr
caseReduceR subst =
setFailMsg "Unsuitable expression for Case reduction." $
caseReduceDataconR subst <+ caseReduceLiteralR subst
caseReduceLiteralR :: MonadCatch m => Bool -> Rewrite c m CoreExpr
caseReduceLiteralR subst =
prefixFailMsg "Case reduction failed: " $
withPatFailMsg (wrongExprForm "Case (Lit l) v t alts") $ do
Case s bndr _ alts <- idR
let in_scope = mkInScopeSet (localFreeVarsExpr s)
case exprIsLiteral_maybe (in_scope, idUnfolding) s of
Nothing -> fail "scrutinee is not a literal."
Just l -> do guardMsg (not (litIsLifted l)) "cannot case-reduce lifted literals"
case findAlt (LitAlt l) alts of
Nothing -> fail "no matching alternative."
Just (_, _, rhs) -> return $ if subst
then substCoreExpr bndr (Lit l) rhs
else mkCoreLet (NonRec bndr (Lit l)) rhs
caseReduceDataconR :: forall c m. ( ExtendPath c Crumb, ReadPath c Crumb, AddBindings c
, ReadBindings c, MonadCatch m, MonadUnique m )
=> Bool -> Rewrite c m CoreExpr
caseReduceDataconR subst = prefixFailMsg "Case reduction failed: " $
withPatFailMsg (wrongExprForm "Case e v t alts") go
where
go :: Rewrite c m CoreExpr
go = do
Case e bndr _ alts <- idR
let in_scope = mkInScopeSet (localFreeVarsExpr e)
case exprIsConApp_maybe (in_scope, idUnfolding) e of
Nothing -> fail "head of scrutinee is not a data constructor."
Just (dc, univTys, es) -> case findAlt (DataAlt dc) alts of
Nothing -> fail "no matching alternative."
Just (dc', vs, rhs) ->
let fvss = map freeVarsExpr $ map Type univTys ++ es
shadows = [ v | (v,n) <- zip vs [1..], any (elemVarSet v) (drop n fvss) ]
in if | any (elemVarSet bndr) fvss -> alphaCaseBinderR Nothing >>> go
| null shadows -> do let binds = zip (bndr : vs) (e : es)
return $ if subst
then foldr (uncurry substCoreExpr) rhs binds
else flip mkCoreLets rhs $ map (uncurry NonRec) binds
| otherwise -> caseOneR (fail "scrutinee") (fail "binder") (fail "type") (\ _ -> acceptR (\ (dc'',_,_) -> dc'' == dc') >>> alphaAltVarsR shadows) >>> go
caseSplitR :: forall c m. ( AddBindings c, BoundVars c, ExtendPath c Crumb, HasEmptyContext c, ReadPath c Crumb
, MonadCatch m, MonadUnique m )
=> CoreExpr -> Rewrite c m CoreExpr
caseSplitR s = prefixFailMsg "case-split failed: " $ do
c <- contextT
guardMsg (all (inScope c) $ varSetElems $ freeVarsExpr s) "variables in desired scrutinee are unbound."
w <- constT $ newVarH "w" (exprType s)
let f = compileFold [Equality [] s (varToCoreExpr w)]
e' <- tryR $ withVarsInScope [w] $ extractR (anytdR (promoteR $ runFoldR f) :: Rewrite c m Core)
constT $ do
dcsAndBss <- casesForM s
let alts = [ (DataAlt dc, bs, e') | (dc,bs) <- dcsAndBss ]
guardMsg (not (null alts)) "no constructors for scrutinee of that type."
return $ Case s w (exprType e') alts
casesForM :: MonadUnique m => CoreExpr -> m [(DataCon, [Id])]
casesForM e = do
(tyCon, tys) <- splitTyConAppM (exprType e)
let aNms = map (:[]) $ cycle ['a'..'z']
forM (tyConDataCons tyCon) $ \ dc -> do
bs <- sequence [ newIdH a ty | (a,ty) <- zip aNms $ dataConInstArgTys dc tys ]
return (dc,bs)
caseExprsForM :: MonadUnique m => CoreExpr -> m [CoreExpr]
caseExprsForM e = do
(_, tys) <- splitTyConAppM (exprType e)
cases <- casesForM e
return [ mkDataConApp tys dc vs | (dc,vs) <- cases ]
caseIntroSeqR :: (MonadCatch m, MonadUnique m) => (Id -> Bool) -> Rewrite c m CoreExpr
caseIntroSeqR idPred = prefixFailMsg "case-intro-seq failed: " $
do i <- matchingFreeIdT idPred
contextfreeT $ \ e -> do guardMsg (not $ isTyCoArg e) "cannot case on a type or coercion."
w <- cloneVarH (++ "'") i
let e' = substCoreExpr i (Var w) e
alts = [(DEFAULT, [], e')]
return $ Case (Var i) w (coreAltsType alts) alts
matchingFreeIdT :: Monad m => (Id -> Bool) -> Transform c m CoreExpr Id
matchingFreeIdT idPred = do
fvs <- arr freeVarsExpr
case varSetElems (filterVarSet (\ v -> idPred v && isId v) fvs) of
[] -> fail "provided name is not a free identifier."
[i] -> return i
is -> fail ("provided name matches " ++ show (length is) ++ " free identifiers.")
caseSplitInlineR :: ( ExtendPath c Crumb, ReadPath c Crumb, AddBindings c
, ReadBindings c, HasEmptyContext c, MonadCatch m, MonadUnique m )
=> CoreExpr -> Rewrite c m CoreExpr
caseSplitInlineR s = caseSplitR s >>> caseInlineAlternativeR
caseInlineBinderR :: forall c m. ( ExtendPath c Crumb, ReadPath c Crumb, AddBindings c
, ReadBindings c, HasEmptyContext c, MonadCatch m, MonadUnique m )
=> CaseBinderInlineOption -> Rewrite c m CoreExpr
caseInlineBinderR opt =
do w <- caseBinderIdT
caseAllR idR idR idR $ \ _ -> setFailMsg "no inlinable occurrences." $
do depth <- varBindingDepthT w
extractR $ anybuR (promoteExprR (configurableInlineR (CaseBinderOnly opt) (varIsOccurrenceOfT w depth)) :: Rewrite c m Core)
caseInlineScrutineeR :: ( ExtendPath c Crumb, ReadPath c Crumb, AddBindings c
, ReadBindings c, HasEmptyContext c, MonadCatch m, MonadUnique m )
=> Rewrite c m CoreExpr
caseInlineScrutineeR = prefixFailMsg "case-inline-scrutinee failed: " $
caseInlineBinderR Scrutinee
caseInlineAlternativeR :: ( ExtendPath c Crumb, ReadPath c Crumb, AddBindings c
, ReadBindings c, HasEmptyContext c, MonadCatch m, MonadUnique m )
=> Rewrite c m CoreExpr
caseInlineAlternativeR = prefixFailMsg "case-inline-alternative failed: " $
caseInlineBinderR Alternative
caseMergeAltsR :: MonadCatch m => Rewrite c m CoreExpr
caseMergeAltsR = prefixFailMsg "merge-case-alts failed: " $
withPatFailMsg (wrongExprForm "Case e w ty alts") $
do Case e w ty alts <- idR
guardMsg (notNull alts) "zero alternatives cannot be merged."
let rhss = [ rhs | (_,_,rhs) <- alts ]
guardMsg (equivalentBy exprAlphaEq rhss) "right-hand sides are not all equal."
guardMsg (all altVarsUnused alts) "variables bound in case alt pattern appear free in alt right-hand side."
return $ Case e w ty [(DEFAULT,[],head rhss)]
caseFoldBinderR :: forall c m. ( ExtendPath c Crumb, ReadPath c Crumb, AddBindings c
, ReadBindings c, HasEmptyContext c, MonadCatch m, MonadUnique m )
=> Rewrite c m CoreExpr
caseFoldBinderR = prefixFailMsg "case-fold-binder failed: " $
caseAllR idR (arr (flip setIdOccInfo NoOccInfo)) idR (const idR) >>> (do
w <- caseBinderIdT
caseAllR idR idR idR $ \ _ -> do depth <- varBindingDepthT w
extractR $ anybuR (promoteExprR (foldVarR (Just depth) w) :: Rewrite c m Core))
caseMergeAltsWithBinderR :: ( ExtendPath c Crumb, ReadPath c Crumb, AddBindings c
, ReadBindings c, HasEmptyContext c, MonadCatch m, MonadUnique m )
=> Rewrite c m CoreExpr
caseMergeAltsWithBinderR =
prefixFailMsg "merge-case-alts-with-binder failed: " $
withPatFailMsg (wrongExprForm "Case e w ty alts") $
tryR caseFoldBinderR >>> caseMergeAltsR
caseElimInlineScrutineeR :: ( ExtendPath c Crumb, ReadPath c Crumb, AddBindings c
, ReadBindings c, HasEmptyContext c, MonadCatch m, MonadUnique m )
=> Rewrite c m CoreExpr
caseElimInlineScrutineeR = alphaCaseBinderR Nothing >>> tryR caseInlineScrutineeR >>> caseElimR
caseElimMergeAltsR :: ( ExtendPath c Crumb, ReadPath c Crumb, AddBindings c
, ReadBindings c, HasEmptyContext c, MonadCatch m, MonadUnique m )
=> Rewrite c m CoreExpr
caseElimMergeAltsR = tryR caseFoldBinderR >>> tryR caseMergeAltsR >>> caseElimInlineScrutineeR
caseElimSeqR :: ( ExtendPath c Crumb, ReadPath c Crumb, AddBindings c
, ReadBindings c, HasEmptyContext c, MonadCatch m, MonadUnique m )
=> Rewrite c m CoreExpr
caseElimSeqR = prefixFailMsg "case-elim-seq failed: " $
withPatFailMsg "not a seq case." $
do Case s w _ [(DEFAULT,[],rhs)] <- idR
let is = case s of
Var i -> [i,w]
_ -> [w]
if is `isForcedIn` rhs
then caseElimInlineScrutineeR
else fail "cannot be sure that this seq case is pointless. Use case-elim-inline-scrutinee if you want to proceed anyway."
isForcedIn :: [Id] -> CoreExpr -> Bool
isForcedIn is = \case
Var i -> i `elem` is
App f _ -> is `isForcedIn` f
Let _ body -> is `isForcedIn` body
Case s _ _ _ -> is `isForcedIn` s
Cast e _ -> is `isForcedIn` e
Tick _ e -> is `isForcedIn` e
_ -> False
altVarsUnused :: CoreAlt -> Bool
altVarsUnused (_,vs,rhs) = all (`notElemVarSet` freeVarsExpr rhs) vs