module Yhc.Core.RecursiveLet(
isCoreLetRec,
removeRecursiveLet,
reduceRecursiveLet
) where
import Yhc.Core.Type
import Yhc.Core.Uniplate
import Yhc.Core.FreeVar
import Yhc.Core.UniqueName
import Control.Monad
import Control.Monad.State
import Data.List
-- | Remove every recursive let from a program.
--
-- Runs 'remRecLet' in its lifting mode over the program via
-- 'uniqueFuncsSplit' (presumably that helper supplies fresh function names
-- and collects the functions the rewrite emits; see "Yhc.Core.UniqueName").
-- Bindings that do not depend on a recursive binding are hoisted into an outer
-- non-recursive let.  Groups that remain recursive are lifted out into new
-- functions.
removeRecursiveLet :: Core -> Core
removeRecursiveLet core = uniqueFuncsSplit (remRecLet True) core
-- | Shrink recursive lets without lifting anything out.
--
-- Runs 'remRecLet' in its non-lifting mode over the program via
-- 'uniqueFuncsSplit'.  Bindings that do not depend on a recursive binding are
-- hoisted into an outer non-recursive let.  Genuinely recursive groups are
-- left in place, so no new functions are created.
reduceRecursiveLet :: Core -> Core
reduceRecursiveLet core = uniqueFuncsSplit (remRecLet False) core
-- | The rewrite shared by 'removeRecursiveLet' and 'reduceRecursiveLet'.
--
-- Each 'CoreLet' met during the traversal is handled as follows:
--
-- * A let with no bindings is replaced by its (rewritten) body.
--
-- * Bindings whose right-hand side mentions none of the names bound by this
--   let are moved into a new outer let.  The remaining bindings stay in an
--   inner let, which is rewritten again.  Repeating this peels off every
--   binding that does not depend, directly or via other bindings of the same
--   let, on a recursive binding.
--
-- * If every remaining binding mentions one of the let's own names and the
--   first argument is 'True', each binding is lifted into a new function:
--   - its name comes from the second argument;
--   - it is handed to the third argument;
--   - its parameters are the variables free in the group, excluding the
--     group's own names.
--   The let is then replaced by one binding each original name to a call of
--   its lifted function.
--
-- * Otherwise (first argument 'False') the recursive let is kept, and the
--   rewrite continues into its immediate sub-expressions via 'descendM'.
remRecLet :: Monad m
          => Bool                -- ^ lift recursive groups into new functions?
          -> m CoreFuncName      -- ^ yields a name for each lifted function
                                 --   (presumably fresh; supplied by the caller)
          -> (CoreFunc -> m ())  -- ^ receives each newly lifted function
          -> CoreExpr
          -> m CoreExpr
remRecLet always newFunc addFunc = f
  where
    f (CoreLet [] body) = f body

    -- Hoist bindings that do not mention any name bound by this let.
    f (CoreLet binds body) | not (null free) = do
        free2 <- mapM (\(name, rhs) -> liftM ((,) name) (f rhs)) free
        locked2 <- f (CoreLet locked body)
        return $ CoreLet free2 locked2
      where
        defined = map fst binds
        (locked, free) = partition (isLocked . snd) binds
        isLocked = any (`elem` defined) . collectFreeVars

    -- Every binding is part of (or depends on) a recursive group: lift them.
    f (CoreLet binds body) | always = do
        names <- replicateM (length binds) newFunc
        let binds2 = zip lhs (map callOf names)
        newfuncs <- zipWithM (lift binds2) names rhss
        mapM_ addFunc newfuncs
        body2 <- f body
        return $ CoreLet binds2 body2
      where
        (lhs, rhss) = unzip binds
        -- Parameters of every lifted function: variables free in the group,
        -- minus the group's own names.
        vars = nub (concatMap collectFreeVars rhss) \\ lhs
        callOf fn = CoreApp (CoreFun fn) (map CoreVar vars)
        -- Build one lifted function.  Its body re-binds, as calls, only those
        -- group members its right-hand side mentions.  Those bindings do not
        -- mention the group's names, so 'f' hoists them when it rewrites the
        -- new let.
        lift binds2 name rhs = do
            let used = collectFreeVars rhs
                needed = filter ((`elem` used) . fst) binds2
            rhs2 <- f (CoreLet needed rhs)
            return $ CoreFunc name vars rhs2

    f x = descendM f x
-- | Is the expression a recursive let?
--
-- A 'CoreLet' counts as recursive when some name it binds occurs free in the
-- right-hand side of any of its bindings (including its own).  Only the
-- outermost constructor is inspected.  Lets nested inside the expression,
-- including inside the let's own body, are not examined, and any non-let
-- expression yields 'False'.
isCoreLetRec :: CoreExpr -> Bool
isCoreLetRec (CoreLet binds _) = any (`elem` used) bound
  where
    bound = map fst binds
    used  = concatMap (collectFreeVars . snd) binds
isCoreLetRec _ = False