module Yhc.Core.RecursiveLet( isCoreLetRec, removeRecursiveLet, reduceRecursiveLet ) where import Yhc.Core.Type import Yhc.Core.Uniplate import Yhc.Core.FreeVar import Yhc.Core.UniqueName import Control.Monad import Control.Monad.State import Data.List -- | Remove recursive lets -- -- Let's are rearranged so a variable is not used in the defining block removeRecursiveLet :: Core -> Core removeRecursiveLet = uniqueFuncsSplit (remRecLet True) -- | Reduce the number of recursive lets, but splitting lets -- which have recursive bindings, but can be linearised reduceRecursiveLet :: Core -> Core reduceRecursiveLet = uniqueFuncsSplit (remRecLet False) remRecLet :: Monad m => Bool -> m CoreFuncName -> (CoreFunc -> m ()) -> CoreExpr -> m CoreExpr remRecLet always newFunc addFunc = f where f (CoreLet [] x) = f x -- handle the variables which are mixed up, but not actually recursive -- let a = b; b = 1 in a f (CoreLet binds x) | not (null free) = do free2 <- mapM (\(a,b) -> liftM ((,) a) $ f b) free locked2 <- f (CoreLet locked x) return $ CoreLet free2 locked2 where defined = map fst binds (locked,free) = partition (isLocked . snd) binds isLocked = any (`elem` defined) . collectFreeVars -- handle the truely recursive ones -- let xs = x:xs in xs f (CoreLet binds x) | always = do names <- replicateM (length binds) newFunc let binds2 = zip lhs (map (\x -> CoreApp (CoreFun x) (map CoreVar vars)) names) newfuncs <- zipWithM (g (zip lhs names) binds2) names rhs mapM_ addFunc newfuncs x2 <- f x return $ CoreLet binds2 x2 where (lhs,rhs) = unzip binds vars = nub (concatMap collectFreeVars rhs) \\ lhs g mapping binds2 name rhs = do let free = collectFreeVars rhs binds3 = filter ((`elem` free) . fst) binds2 body <- f $ CoreLet binds3 rhs return $ CoreFunc name vars body f x = descendM f x -- | Is a CoreLet recursive, i.e. do any of the introduced variables (LHS of bind) -- also show up in the RHS of bind. -- -- Returns False if the expression is not a CoreLet. isCoreLetRec :: CoreExpr -> Bool isCoreLetRec (CoreLet bind xs) = not $ null $ map fst bind `intersect` concatMap (collectFreeVars . snd) bind isCoreLetRec x = False