{-# LANGUAGE FlexibleContexts #-} {-# LANGUAGE LambdaCase #-} {-# LANGUAGE MultiWayIf #-} {-# LANGUAGE NoImplicitPrelude #-} {-# LANGUAGE ScopedTypeVariables #-} module HERMIT.Dictionary.Local.Let ( -- * Rewrites on Let Expressions externals -- ** Let Elimination , letNonRecSubstR , letNonRecSubstSafeR , letSubstR , letSubstSafeR , letElimR , letNonRecElimR , letRecElimR , progBindElimR , progBindNonRecElimR , progBindRecElimR -- ** Let Introduction , letIntroR , letNonRecIntroR , progNonRecIntroR , nonRecIntroR , letIntroUnfoldingR -- ** Let Floating Out , letFloatAppR , letFloatArgR , letFloatLetR , letFloatLamR , letFloatCaseR , letFloatCaseAltR , letFloatCastR , letFloatExprR , letFloatTopR -- ** Let Floating In , letFloatInR , letFloatInAppR , letFloatInCaseR , letFloatInLamR -- ** Miscallaneous , reorderNonRecLetsR , letTupleR , letToCaseR ) where import Control.Arrow import Control.Monad (ap, liftM, when) import Control.Monad.IO.Class import Data.List (intersect, partition) import Data.Monoid import HERMIT.Core import HERMIT.Context import HERMIT.Monad import HERMIT.Kure import HERMIT.External import HERMIT.GHC import HERMIT.Name import HERMIT.Utilities import HERMIT.Dictionary.Common import HERMIT.Dictionary.GHC hiding (externals) import HERMIT.Dictionary.Inline hiding (externals) import HERMIT.Dictionary.AlphaConversion hiding (externals) import HERMIT.Dictionary.Local.Bind hiding (externals) import Prelude.Compat hiding ((<$)) ------------------------------------------------------------------------------ -- | Externals relating to 'Let' expressions. externals :: [External] externals = [ external "let-subst" (promoteExprR letSubstR :: RewriteH LCore) [ "Let substitution: (let x = e1 in e2) ==> (e2[e1/x])" , "x must not be free in e1." ] .+ Deep .+ Eval , external "let-subst-safe" (promoteExprR letSubstSafeR :: RewriteH LCore) [ "Safe let substitution" , "let x = e1 in e2, safe to inline without duplicating work ==> e2[e1/x]," , "x must not be free in e1." ] .+ Deep .+ Eval , external "let-nonrec-subst-safe" (promoteExprR letNonRecSubstSafeR :: RewriteH LCore) [ "As let-subst-safe, but does not try to convert a recursive let into a non-recursive let first." ] .+ Deep .+ Eval -- , external "safe-let-subst-plus" (promoteExprR safeLetSubstPlusR :: RewriteH LCore) -- [ "Safe let substitution" -- , "let { x = e1, ... } in e2, " -- , " where safe to inline without duplicating work ==> e2[e1/x,...]," -- , "only matches non-recursive lets" ] .+ Deep .+ Eval , external "let-intro" (promoteExprR . letIntroR :: String -> RewriteH LCore) [ "e => (let v = e in v), name of v is provided" ] .+ Shallow .+ Introduce , external "let-intro-unfolding" (promoteExprR . letIntroUnfoldingR :: HermitName -> RewriteH LCore) [ "e => let f' = defn[f'/f] in e[f'/f], name of f is provided" ] , external "let-elim" (promoteExprR letElimR :: RewriteH LCore) [ "Remove an unused let binding." , "(let v = e1 in e2) ==> e2, if v is not free in e1 or e2." ] .+ Eval .+ Shallow -- , external "let-constructor-reuse" (promoteR $ not_defined "constructor-reuse" :: RewriteH LCore) -- [ "let v = C v1..vn in ... C v1..vn ... ==> let v = C v1..vn in ... v ..., fails otherwise" ] .+ Eval , external "let-float-app" (promoteExprR letFloatAppR :: RewriteH LCore) [ "(let v = ev in e) x ==> let v = ev in e x" ] .+ Commute .+ Shallow , external "let-float-arg" (promoteExprR letFloatArgR :: RewriteH LCore) [ "f (let v = ev in e) ==> let v = ev in f e" ] .+ Commute .+ Shallow , external "let-float-lam" (promoteExprR letFloatLamR :: RewriteH LCore) [ "The Full Laziness Transformation" , "(\\ v1 -> let v2 = e1 in e2) ==> let v2 = e1 in (\\ v1 -> e2), if v1 is not free in e2." , "If v1 = v2 then v1 will be alpha-renamed." ] .+ Commute .+ Shallow , external "let-float-let" (promoteExprR letFloatLetR :: RewriteH LCore) [ "let v = (let w = ew in ev) in e ==> let w = ew in let v = ev in e" ] .+ Commute .+ Shallow , external "let-float-case" (promoteExprR letFloatCaseR :: RewriteH LCore) [ "case (let v = ev in e) of ... ==> let v = ev in case e of ..." ] .+ Commute .+ Shallow .+ Eval , external "let-float-case-alt" (promoteExprR (letFloatCaseAltR Nothing) :: RewriteH LCore) [ "case s of { ... ; p -> let v = ev in e ; ... } " , "==> let v = ev in case s of { ... ; p -> e ; ... } " ] .+ Commute .+ Shallow .+ Eval , external "let-float-case-alt" (promoteExprR . letFloatCaseAltR . Just :: Int -> RewriteH LCore) [ "Float a let binding from specified alternative." , "case s of { ... ; p -> let v = ev in e ; ... } " , "==> let v = ev in case s of { ... ; p -> e ; ... } " ] .+ Commute .+ Shallow .+ Eval , external "let-float-cast" (promoteExprR letFloatCastR :: RewriteH LCore) [ "cast (let bnds in e) co ==> let bnds in cast e co" ] .+ Commute .+ Shallow , external "let-float-top" (promoteProgR letFloatTopR :: RewriteH LCore) [ "v = (let bds in e) : prog ==> bds : v = e : prog" ] .+ Commute .+ Shallow , external "let-float" (promoteProgR letFloatTopR <+ promoteExprR letFloatExprR :: RewriteH LCore) [ "Float a Let whatever the context." ] .+ Commute .+ Shallow -- Don't include in bash, as each sub-rewrite is tagged "Bash" already. , external "let-to-case" (promoteExprR letToCaseR :: RewriteH LCore) [ "let v = ev in e ==> case ev of v -> e" ] .+ Commute .+ Shallow .+ PreCondition -- , external "let-to-case-unbox" (promoteR $ not_defined "let-to-case-unbox" :: RewriteH LCore) -- [ "let v = ev in e ==> case ev of C v1..vn -> let v = C v1..vn in e" ] , external "let-float-in" (promoteExprR letFloatInR >+> anybuR (promoteExprR letElimR) :: RewriteH LCore) [ "Float-in a let if possible." ] .+ Commute .+ Shallow , external "let-float-in-app" ((promoteExprR letFloatInAppR >+> anybuR (promoteExprR letElimR)) :: RewriteH LCore) [ "let v = ev in f a ==> (let v = ev in f) (let v = ev in a)" ] .+ Commute .+ Shallow , external "let-float-in-case" ((promoteExprR letFloatInCaseR >+> anybuR (promoteExprR letElimR)) :: RewriteH LCore) [ "let v = ev in case s of p -> e ==> case (let v = ev in s) of p -> let v = ev in e" , "if v does not shadow a pattern binder in p" ] .+ Commute .+ Shallow , external "let-float-in-lam" ((promoteExprR letFloatInLamR >+> anybuR (promoteExprR letElimR)) :: RewriteH LCore) [ "let v = ev in \\ x -> e ==> \\ x -> let v = ev in e" , "if v does not shadow x" ] .+ Commute .+ Shallow , external "reorder-lets" (promoteExprR . reorderNonRecLetsR :: [String] -> RewriteH LCore) [ "Re-order a sequence of nested non-recursive let bindings." , "The argument list should contain the let-bound variables, in the desired order." ] , external "let-tuple" (promoteExprR . letTupleR :: String -> RewriteH LCore) [ "Combine nested non-recursive lets into case of a tuple." , "E.g. let {v1 = e1 ; v2 = e2 ; v3 = e3} in body ==> case (e1,e2,e3) of {(v1,v2,v3) -> body}" ] .+ Commute , external "prog-bind-elim" (promoteProgR progBindElimR :: RewriteH LCore) [ "Remove unused top-level binding(s)." , "prog-bind-nonrec-elim <+ prog-bind-rec-elim" ] .+ Eval .+ Shallow , external "prog-bind-nonrec-elim" (promoteProgR progBindNonRecElimR :: RewriteH LCore) [ "Remove unused top-level binding(s)." , "v = e : prog ==> prog, if v is not free in prog and not exported." ] .+ Eval .+ Shallow , external "prog-bind-rec-elim" (promoteProgR progBindRecElimR :: RewriteH LCore) [ "Remove unused top-level binding(s)." , "v+ = e+ : prog ==> v* = e* : prog, where v* is a subset of v+ consisting" , "of vs that are free in prog or e+, or exported." ] .+ Eval .+ Shallow ] ------------------------------------------------------------------------------------------- -- | (let x = e1 in e2) ==> (e2[e1/x]), (x must not be free in e1) letSubstR :: (AddBindings c, ExtendPath c Crumb, ReadPath c Crumb, MonadCatch m) => Rewrite c m CoreExpr letSubstR = letAllR (tryR recToNonrecR) idR >>> letNonRecSubstR -- | As 'letNonRecSubstSafeR', but attempting to convert a singleton recursive binding to a non-recursive binding first. letSubstSafeR :: (AddBindings c, ExtendPath c Crumb, ReadPath c Crumb, ReadBindings c, HasEmptyContext c, MonadCatch m) => Rewrite c m CoreExpr letSubstSafeR = letAllR (tryR recToNonrecR) idR >>> letNonRecSubstSafeR -- | @Let (NonRec v e) body@ ==> @body[e/v]@ letNonRecSubstR :: MonadCatch m => Rewrite c m CoreExpr letNonRecSubstR = prefixFailMsg "Let substitution failed: " $ withPatFailMsg (wrongExprForm "Let (NonRec v rhs) body") $ do Let (NonRec v rhs) body <- idR return (substCoreExpr v rhs body) {- TODO: This was written very early in the project by Andy. It was later modified somewhat by Neil, but without reassessing the heurisitc as a whole. It may need revisiting. Safe Subst Heuristic -------------------- Substitution is safe if (A) OR (B) OR (C). (A) The let-bound variable is a type or coercion. (B) The let-bound value is either: (i) a variable; (ii) a lambda; (iii) an application that requires more value arguments before it can perform any computation. (C) In the body, the let-bound variable must NOT occur: (i) more than once; (ii) inside a lambda. -} -- | Currently we always substitute types and coercions, and use a heuristic to decide whether to substitute expressions. -- This may need revisiting. letNonRecSubstSafeR :: forall c m. (AddBindings c, ExtendPath c Crumb, ReadPath c Crumb, ReadBindings c, HasEmptyContext c, MonadCatch m) => Rewrite c m CoreExpr letNonRecSubstSafeR = do Let (NonRec v _) _ <- idR when (isId v) $ guardMsgM (safeSubstT v) "safety criteria not met." letNonRecSubstR where safeSubstT :: Id -> Transform c m CoreExpr Bool safeSubstT i = letNonRecT mempty safeBindT (safeOccursT i) (\ () -> (||)) -- what about other Expr constructors, e.g Cast? safeBindT :: Transform c m CoreExpr Bool safeBindT = do c <- contextT arr $ \ e -> case e of Var {} -> True Lam {} -> True App {} -> case collectArgs e of (Var f,args) -> arityOf c f > length (filter (not . isTyCoArg) args) -- Neil: I've changed this to "not . isTyCoArg" rather than "not . isTypeArg". -- This may not be the right thing to do though. (other,args) -> case collectBinders other of (bds,_) -> length bds > length args _ -> False safeOccursT :: Id -> Transform c m CoreExpr Bool safeOccursT i = do depth <- varBindingDepthT i let occursHereT :: Transform c m Core () occursHereT = promoteExprT (exprIsOccurrenceOfT i depth >>> guardT) -- lamOccurrenceT can only fail if the expression is not a Lam -- return either 2 (occurrence) or 0 (no occurrence) lamOccurrenceT :: Transform c m CoreExpr (Sum Int) lamOccurrenceT = lamT mempty (mtryM (Sum 2 <$ extractT (onetdT occursHereT))) mappend occurrencesT :: Transform c m Core (Sum Int) occurrencesT = prunetdT (promoteExprT lamOccurrenceT <+ (Sum 1 <$ occursHereT)) extractT occurrencesT >>^ (getSum >>> (< 2)) (<$) :: Monad m => a -> m b -> m a a <$ mb = mb >> return a ------------------------------------------------------------------------------------------- letElimR :: (ExtendPath c Crumb, AddBindings c, MonadCatch m) => Rewrite c m CoreExpr letElimR = prefixFailMsg "Let elimination failed: " $ withPatFailMsg (wrongExprForm "Let binds expr") $ do Let bg _ <- idR case bg of NonRec{} -> letNonRecElimR Rec{} -> letRecElimR -- | Remove an unused non-recursive let binding. -- @let v = E1 in E2@ ==> @E2@, if @v@ is not free in @E2@ letNonRecElimR :: MonadCatch m => Rewrite c m CoreExpr letNonRecElimR = withPatFailMsg (wrongExprForm "Let (NonRec v e1) e2") $ do Let (NonRec v _) e <- idR guardMsg (v `notElemVarSet` freeVarsExpr e) "let-bound variable appears in the expression." return e -- | Remove all unused recursive let bindings in the current group. letRecElimR :: MonadCatch m => Rewrite c m CoreExpr letRecElimR = withPatFailMsg (wrongExprForm "Let (Rec v e1) e2") $ do Let (Rec bnds) body <- idR let bodyFrees = freeIdsExpr body bsAndFrees = map (second freeIdsExpr) bnds usedIds = chaseDependencies bodyFrees bsAndFrees bs = mkVarSet (map fst bsAndFrees) liveBinders = bs `intersectVarSet` usedIds if isEmptyVarSet liveBinders then return body else if bs `subVarSet` liveBinders then fail "no dead binders to eliminate." else return $ Let (Rec $ filter ((`elemVarSet` liveBinders) . fst) bnds) body progBindElimR :: MonadCatch m => Rewrite c m CoreProg progBindElimR = progBindNonRecElimR <+ progBindRecElimR progBindNonRecElimR :: MonadCatch m => Rewrite c m CoreProg progBindNonRecElimR = withPatFailMsg (wrongExprForm "ProgCons (NonRec v e1) e2") $ do ProgCons (NonRec v _) p <- idR guardMsg (v `notElemVarSet` freeVarsProg p) "variable appears in program body." guardMsg (not (isExportedId v)) "variable is exported." return p -- | Remove all unused bindings at the top level. progBindRecElimR :: MonadCatch m => Rewrite c m CoreProg progBindRecElimR = withPatFailMsg (wrongExprForm "ProgCons (Rec v e1) e2") $ do ProgCons (Rec bnds) p <- idR let pFrees = freeVarsProg p bsAndFrees = map (second freeIdsExpr) bnds usedIds = chaseDependencies pFrees bsAndFrees bs = mkVarSet (map fst bsAndFrees) liveBinders = (bs `intersectVarSet` usedIds) `unionVarSet` (filterVarSet isExportedId bs) if isEmptyVarSet liveBinders then return p else if bs `subVarSet` liveBinders then fail "no dead binders to eliminate." else return $ ProgCons (Rec $ filter ((`elemVarSet` liveBinders) . fst) bnds) p chaseDependencies :: VarSet -> [(Var,VarSet)] -> VarSet chaseDependencies usedIds bsAndFrees = case partition ((`elemVarSet` usedIds) . fst) bsAndFrees of ([],_) -> usedIds (used,unused) -> chaseDependencies (unionVarSets (usedIds : map snd used)) unused ------------------------------------------------------------------------------------------- -- | @let v = ev in e@ ==> @case ev of v -> e@ letToCaseR :: (ExtendPath c Crumb, ReadPath c Crumb, AddBindings c, ReadBindings c, MonadCatch m, MonadUnique m) => Rewrite c m CoreExpr letToCaseR = prefixFailMsg "Converting Let to Case failed: " $ withPatFailMsg (wrongExprForm "Let (NonRec v e1) e2") $ do Let (NonRec v ev) _ <- idR guardMsg (not $ isTyCoArg ev) "cannot case on a type or coercion." caseBndr <- extractT (cloneVarAvoidingT v Nothing [v]) letT mempty (replaceVarR v caseBndr) $ \ () e' -> Case ev caseBndr (varType v) [(DEFAULT, [], e')] ------------------------------------------------------------------------------------------- -- | @(let v = ev in e) x@ ==> @let v = ev in e x@ letFloatAppR :: (ExtendPath c Crumb, ReadPath c Crumb, AddBindings c, BoundVars c, MonadCatch m, MonadUnique m) => Rewrite c m CoreExpr letFloatAppR = prefixFailMsg "Let floating from App function failed: " $ withPatFailMsg (wrongExprForm "App (Let bnds body) e") $ do App (Let bnds body) e <- idR let vs = mkVarSet (bindVars bnds) `intersectVarSet` freeVarsExpr e if isEmptyVarSet vs then return $ Let bnds (App body e) else appAllR (alphaLetVarsR $ varSetElems vs) idR >>> letFloatAppR -- | @f (let v = ev in e)@ ==> @let v = ev in f e@ letFloatArgR :: (ExtendPath c Crumb, ReadPath c Crumb, AddBindings c, BoundVars c, MonadCatch m, MonadUnique m) => Rewrite c m CoreExpr letFloatArgR = prefixFailMsg "Let floating from App argument failed: " $ withPatFailMsg (wrongExprForm "App f (Let bnds body)") $ do App f (Let bnds body) <- idR let vs = mkVarSet (bindVars bnds) `intersectVarSet` freeVarsExpr f if isEmptyVarSet vs then return $ Let bnds (App f body) else appAllR idR (alphaLetVarsR $ varSetElems vs) >>> letFloatArgR -- | @let v = (let bds in e1) in e2@ ==> @let bds in let v = e1 in e2@ letFloatLetR :: (ExtendPath c Crumb, ReadPath c Crumb, AddBindings c, BoundVars c, MonadCatch m, MonadUnique m) => Rewrite c m CoreExpr letFloatLetR = prefixFailMsg "Let floating from Let failed: " $ withPatFailMsg (wrongExprForm "Let (NonRec v (Let bds e1)) e2") $ do Let (NonRec v (Let bds e1)) e2 <- idR let vs = mkVarSet (bindVars bds) `intersectVarSet` freeVarsExpr e2 if isEmptyVarSet vs then return $ Let bds (Let (NonRec v e1) e2) else letNonRecAllR idR (alphaLetVarsR $ varSetElems vs) idR >>> letFloatLetR -- | @(\ v -> let binds in e2)@ ==> @let binds in (\ v1 -> e2)@ -- Fails if @v@ occurs in the RHS of @binds@. -- If @v@ is shadowed in binds, then @v@ will be alpha-renamed. letFloatLamR :: (ExtendPath c Crumb, ReadPath c Crumb, AddBindings c, BoundVars c, MonadCatch m, MonadUnique m) => Rewrite c m CoreExpr letFloatLamR = prefixFailMsg "Let floating from Lam failed: " $ withPatFailMsg (wrongExprForm "Lam v1 (Let bds body)") $ do Lam v (Let binds body) <- idR let bs = bindVars binds fvs = freeVarsBind binds guardMsg (v `notElemVarSet` fvs) (unqualifiedName v ++ " occurs in the RHS of the let-bindings.") if v `elem` bs then alphaLamR Nothing >>> letFloatLamR else return $ Let binds (Lam v body) -- | @case (let bnds in e) of bndr alts@ ==> @let bnds in (case e of bndr alts)@ -- Fails if any variables bound in @bnds@ occurs in @alts@. letFloatCaseR :: (ExtendPath c Crumb, ReadPath c Crumb, AddBindings c, BoundVars c, MonadCatch m, MonadUnique m) => Rewrite c m CoreExpr letFloatCaseR = prefixFailMsg "Let floating from Case failed: " $ withPatFailMsg (wrongExprForm "Case (Let bnds e) w ty alts") $ do Case (Let bnds e) w ty alts <- idR let captures = mkVarSet (bindVars bnds) `intersectVarSet` delVarSet (unionVarSets $ map freeVarsAlt alts) w if isEmptyVarSet captures then return $ Let bnds (Case e w ty alts) else caseAllR (alphaLetVarsR $ varSetElems captures) idR idR (const idR) >>> letFloatCaseR -- | case e of w { ... ; p -> let b = rhs in body ; ... } ==> -- let b = rhs in case e of { ... ; p -> body ; ... } -- -- where no variable in `p` or `w` occurs freely in `rhs`, -- and where `b` does not capture a free variable in the overall case, -- and where `w` is not rebound in `b`. letFloatCaseAltR :: MonadCatch m => Maybe Int -> Rewrite c m CoreExpr letFloatCaseAltR maybeN = prefixFailMsg "Let float from case alternative failed: " $ withPatFailMsg (wrongExprForm "Case s w ty alts") $ do -- Perform the first safe let-floating out of a case alternative let letFloatOneAltM :: MonadCatch m => Id -> VarSet -> [CoreAlt] -> m (CoreBind,[CoreAlt]) letFloatOneAltM w fvs = go where go [] = fail "no lets can be safely floated from alternatives." go (alt:rest) = (do (bind,alt') <- letFloatAltM w fvs alt return (bind,alt':rest)) <+ liftM (second (alt :)) (go rest) -- (p -> let bnds in body) ==> (bnds, p -> body) letFloatAltM :: Monad m => Id -> VarSet -> CoreAlt -> m (CoreBind,CoreAlt) letFloatAltM w fvs (con, vs, Let bnds body) = do let bSet = mkVarSet (bindVars bnds) vSet = mkVarSet (w:vs) -- 'w' is not in 'fvs', but if it is rebound by 'b', doing this rewrite -- would cause it to bind things that were previously bound by 'b'. guardMsg (not (w `elemVarSet` bSet)) "floating would allow case binder to capture variables." -- no free vars in 'rhs' are bound by 'p' or 'w' guardMsg (isEmptyVarSet $ vSet `intersectVarSet` freeVarsBind bnds) "floating would cause variables in rhs to become unbound." -- no free vars in overall case are bound by 'b' guardMsg (isEmptyVarSet $ bSet `intersectVarSet` fvs) "floating would cause let binders to capture variables in case expression." return (bnds, (con, vs, body)) letFloatAltM _ _ _ = fail "no let expression on alternative right-hand side." Case e w ty alts <- idR fvs <- arr freeVarsExpr let l = length alts - 1 case maybeN of Just n | n < 0 || n > l -> fail $ "valid alternative indices: 0 to " ++ show l | otherwise -> do let (pre, alt:suf) = splitAt n alts (bnds,alt') <- letFloatAltM w fvs alt return $ Let bnds $ Case e w ty $ pre ++ (alt':suf) Nothing -> do (bnds,alts') <- letFloatOneAltM w fvs alts return $ Let bnds $ Case e w ty alts' -- | @cast (let bnds in e) co@ ==> @let bnds in cast e co@ letFloatCastR :: MonadCatch m => Rewrite c m CoreExpr letFloatCastR = prefixFailMsg "Let floating from Cast failed: " $ withPatFailMsg (wrongExprForm "Cast (Let bnds e) co") $ do Cast (Let bnds e) co <- idR return $ Let bnds (Cast e co) -- | Float a 'Let' through an expression, whatever the context. letFloatExprR :: (ExtendPath c Crumb, ReadPath c Crumb, AddBindings c, BoundVars c, MonadCatch m, MonadUnique m) => Rewrite c m CoreExpr letFloatExprR = setFailMsg "Unsuitable expression for Let floating." $ letFloatArgR <+ letFloatAppR <+ letFloatLetR <+ letFloatLamR <+ letFloatCaseR <+ letFloatCaseAltR Nothing <+ letFloatCastR -- | @'ProgCons' ('NonRec' v ('Let' bds e)) p@ ==> @'ProgCons' bds ('ProgCons' ('NonRec' v e) p)@ letFloatTopR :: (ExtendPath c Crumb, ReadPath c Crumb, AddBindings c, BoundVars c, MonadCatch m, MonadUnique m) => Rewrite c m CoreProg letFloatTopR = prefixFailMsg "Let floating to top level failed: " $ withPatFailMsg (wrongExprForm "NonRec v (Let bds e) `ProgCons` p") $ do ProgCons (NonRec v (Let bds e)) p <- idR let bs = bindVars bds guardMsg (all isId bs) "type and coercion bindings are not allowed at the top level." let vs = intersectVarSet (mkVarSet bs) (freeVarsProg p) if isEmptyVarSet vs then return $ ProgCons bds (ProgCons (NonRec v e) p) else consNonRecAllR idR (alphaLetVarsR $ varSetElems vs) idR >>> letFloatTopR ------------------------------------------------------------------------------------------- -- | Float in a 'Let' if possible. letFloatInR :: (AddBindings c, BoundVars c, ExtendPath c Crumb, ReadPath c Crumb, MonadCatch m, MonadUnique m) => Rewrite c m CoreExpr letFloatInR = letFloatInCaseR <+ letFloatInAppR <+ letFloatInLamR -- | @let v = ev in case s of p -> e@ ==> @case (let v = ev in s) of p -> let v = ev in e@, -- if @v@ does not shadow a pattern binder in @p@ letFloatInCaseR :: (AddBindings c, BoundVars c, ExtendPath c Crumb, ReadPath c Crumb, MonadCatch m, MonadUnique m) => Rewrite c m CoreExpr letFloatInCaseR = prefixFailMsg "Let floating in to case failed: " $ withPatFailMsg (wrongExprForm "Let bnds (Case s w ty alts)") $ do Let bnds (Case s w ty alts) <- idR let bs = bindVars bnds captured = bs `intersect` (w : concatMap altVars alts) guardMsg (null captured) "let bindings would capture case pattern bindings." let unbound = mkVarSet bs `intersectVarSet` (tyVarsOfType ty `unionVarSet` freeVarsVar w) guardMsg (isEmptyVarSet unbound) "type variables in case signature would become unbound." return (Case (Let bnds s) w ty alts) >>> caseAllR idR idR idR (\_ -> altAllR idR (\_ -> idR) (arr (Let bnds) >>> alphaLetR)) -- | @let v = ev in f a@ ==> @(let v = ev in f) (let v = ev in a)@ letFloatInAppR :: (AddBindings c, BoundVars c, ExtendPath c Crumb, ReadPath c Crumb, MonadCatch m, MonadUnique m) => Rewrite c m CoreExpr letFloatInAppR = prefixFailMsg "Let floating in to app failed: " $ withPatFailMsg (wrongExprForm "Let bnds (App e1 e2)") $ do Let bnds (App e1 e2) <- idR lhs <- return (Let bnds e1) >>> alphaLetR return $ App lhs (Let bnds e2) -- | @let v = ev in \ x -> e@ ==> @\x -> let v = ev in e@ -- if @v@ does not shadow @x@ letFloatInLamR :: (ExtendPath c Crumb, ReadPath c Crumb, AddBindings c, MonadCatch m) => Rewrite c m CoreExpr letFloatInLamR = prefixFailMsg "Let floating in to lambda failed: " $ withPatFailMsg (wrongExprForm "Let bnds (Lam v e)") $ do Let bnds (Lam v e) <- idR safe <- letT (arr bindVars) lamVarT $ flip notElem guardMsg safe "let bindings would capture lambda binding." return $ Lam v $ Let bnds e ------------------------------------------------------------------------------------------- -- | Re-order a sequence of nested non-recursive let bindings. -- The argument list should contain the let-bound variables, in the desired order. reorderNonRecLetsR :: MonadCatch m => [String] -> Rewrite c m CoreExpr reorderNonRecLetsR nms = prefixFailMsg "Reorder lets failed: " $ do guardMsg (notNull nms) "no names given." guardMsg (nodups nms) "duplicate names given." e <- idR (ves,x) <- setFailMsg "insufficient non-recursive lets." $ takeNonRecLets (length nms) e guardMsg (noneFreeIn ves) "some of the bound variables appear in the right-hand-sides." e' <- mkNonRecLets `liftM` mapM (lookupName ves) nms `ap` return x guardMsg (not $ exprSyntaxEq e e') "bindings already in specified order." return e' where takeNonRecLets :: Monad m => Int -> CoreExpr -> m ([(Var,CoreExpr)],CoreExpr) takeNonRecLets 0 x = return ([],x) takeNonRecLets n (Let (NonRec v1 e1) x) = first ((v1,e1):) `liftM` takeNonRecLets (n-1) x takeNonRecLets _ _ = fail "insufficient non-recursive lets." noneFreeIn :: [(Var,CoreExpr)] -> Bool noneFreeIn ves = let (vs,es) = unzip ves in all (`notElemVarSet` unionVarSets (map freeVarsExpr es)) vs lookupName :: Monad m => [(Var,CoreExpr)] -> String -> m (Var,CoreExpr) lookupName ves nm = case filter (cmpString2Var nm . fst) ves of [] -> fail $ "name " ++ nm ++ " not matched." [ve] -> return ve _ -> fail $ "multiple matches for " ++ nm ++ "." mkNonRecLets :: [(Var,CoreExpr)] -> CoreExpr -> CoreExpr mkNonRecLets [] x = x mkNonRecLets ((v,e):ves) x = Let (NonRec v e) (mkNonRecLets ves x) ------------------------------------------------------------------------------------------- -- | Combine nested non-recursive lets into case of a tuple. -- E.g. let {v1 = e1 ; v2 = e2 ; v3 = e3} in body ==> case (e1,e2,e3) of {(v1,v2,v3) -> body} letTupleR :: (MonadCatch m, MonadUnique m) => String -> Rewrite c m CoreExpr letTupleR nm = prefixFailMsg "Let-tuple failed: " $ do (bnds, body) <- arr collectLets let numBnds = length bnds guardMsg (numBnds > 1) "at least two non-recursive let bindings of identifiers required." let (vs, rhss) = unzip bnds -- check if tupling the bindings would cause unbound variables let frees = map freeVarsExpr (drop 1 rhss) used = unionVarSets $ zipWith intersectVarSet (map (mkVarSet . (`take` vs)) [1..]) frees if isEmptyVarSet used then let rhs = mkCoreTup rhss in constT $ do bndr <- newIdH nm (exprType rhs) return $ mkSmallTupleCase vs body bndr rhs else fail $ "the following bound variables are used in subsequent bindings: " ++ showVarSet used where -- we only collect identifiers (not type or coercion vars) because we intend to case on them. collectLets :: CoreExpr -> ([(Id, CoreExpr)],CoreExpr) collectLets (Let (NonRec v e) body) | isId v = first ((v,e):) (collectLets body) collectLets expr = ([],expr) ------------------------------------------------------------------------------------------- -- TODO: come up with a better naming scheme for these -- This code could be factored better. -- | @e@ ==> @let v = e in v@ letIntroR :: (MonadCatch m, MonadUnique m) => String -> Rewrite c m CoreExpr letIntroR nm = do e <- idR Let (NonRec v e') _ <- letNonRecIntroR nm e return $ Let (NonRec v e') (varToCoreExpr v) -- | @body@ ==> @let v = e in body@ letNonRecIntroR :: (MonadCatch m, MonadUnique m) => String -> CoreExpr -> Rewrite c m CoreExpr letNonRecIntroR nm e = prefixFailMsg "Let-introduction failed: " $ contextfreeT $ \ body -> do v <- newVarH nm $ exprKindOrType e return $ Let (NonRec v e) body -- This isn't a "Let", but it's serving the same role. Maybe create a Local/Prog module? -- | @prog@ ==> @'ProgCons' (v = e) prog@ progNonRecIntroR :: (MonadCatch m, MonadUnique m) => String -> CoreExpr -> Rewrite c m CoreProg progNonRecIntroR nm e = prefixFailMsg "Top-level binding introduction failed: " $ do guardMsg (not $ isTyCoArg e) "Top-level type or coercion definitions are prohibited." contextfreeT $ \ prog -> do i <- newIdH nm (exprType e) return $ ProgCons (NonRec i e) prog -- | nonRecIntroR nm e = 'letNonRecIntroR nm e' <+ 'progNonRecIntroR nm e' nonRecIntroR :: (MonadCatch m, MonadUnique m) => String -> CoreExpr -> Rewrite c m Core nonRecIntroR nm e = readerT $ \case ExprCore{} -> promoteExprR (letNonRecIntroR nm e) ProgCore{} -> promoteProgR (progNonRecIntroR nm e) _ -> fail "can only introduce non-recursive bindings at Program or Expression nodes." -- | Introduce a local definition for a (possibly imported) identifier. -- Rewrites occurences of the identifier to point to this new local definiton. letIntroUnfoldingR :: ( BoundVars c, ReadBindings c, HasDynFlags m, HasHermitMEnv m, LiftCoreM m , MonadCatch m, MonadIO m, MonadThings m, MonadUnique m ) => HermitName -> Rewrite c m CoreExpr letIntroUnfoldingR nm = do i <- findIdT nm (rhs,_) <- getUnfoldingT AllBinders <<< return i contextfreeT $ \ body -> do i' <- cloneVarH id i let subst = substCoreExpr i (varToCoreExpr i') bnd = if i `elemUFM` freeVarsExpr rhs then Rec [(i', subst rhs)] else NonRec i' rhs body' = subst body return $ mkCoreLet bnd body' -------------------------------------------------------------------------------------------