module Yhc.Core.Simplify(
coreSimplify, coreSimplifyExpr,
coreSimplifyCaseCon, coreSimplifyCaseCase, coreSimplifyCaseLet,
coreSimplifyExprUnique, coreSimplifyExprUniqueExt
) where
import Data.List
import Data.Maybe
import Control.Monad
import Yhc.Core.Internal.General
import Yhc.Core.Type
import Yhc.Core.Uniplate
import Yhc.Core.FreeVar3(duplicateExpr)
import Yhc.Core.FreeVar
import Yhc.Core.UniqueId
-- | Simplify every Core expression directly contained in a value.
--
--   The value is taken apart with 'uniplateExpr' into its expression
--   children and a rebuilding function; each child is run through
--   'coreSimplifyExpr' and the value is reassembled from the results.
coreSimplify :: UniplateExpr a => a -> a
coreSimplify x =
    case uniplateExpr x of
        (exprs, rebuild) -> rebuild (map coreSimplifyExpr exprs)
-- | Simplify one Core expression by running the local rewrite rules in @f@
--   over it with 'transformExpr'.
--
--   The clauses of @f@ are tried top to bottom, so their order matters:
--
--   * a @case@ whose scrutinee is a constructor application, another @case@
--     or a @let@ is delegated to 'coreSimplifyCaseCon',
--     'coreSimplifyCaseCase' and 'coreSimplifyCaseLet' respectively, and the
--     result is simplified again;
--
--   * an application whose head is a @case@, a nested application, a lambda
--     or a @let@ is restructured so that the head is exposed;
--
--   * @let@ bindings that are simple, or used at most once, are inlined.
--
--   Anything else is returned unchanged.
coreSimplifyExpr :: CoreExpr -> CoreExpr
coreSimplifyExpr = transformExpr f
    where
        -- Normalise a bare function scrutinee to a zero-argument application
        -- and try the rules again.
        f (CoreCase (CoreFun x) alts) = f (CoreCase (CoreApp (CoreFun x) []) alts)

        -- Case on a known constructor, on another case, or on a let: hand
        -- over to the dedicated rewrite and re-simplify whatever it returns.
        f o@(CoreCase on alts) | isCoreCon $ fst $ fromCoreApp on = transformExpr f $ coreSimplifyCaseCon o
        f o@(CoreCase (CoreCase _ _) _) = transformExpr f $ coreSimplifyCaseCase o
        f o@(CoreCase (CoreLet _ _) _) = transformExpr f $ coreSimplifyCaseLet o

        -- (case on of alts) args  ==>  case on of { p -> rhs args }
        -- The expression is first passed through 'uniqueExpr' (presumably so
        -- that pattern variables cannot capture variables of @args@ -- confirm
        -- against uniqueBoundVarsWithout).
        f orig@(CoreApp (CoreCase _ _) _) = f $ CoreCase on (map g alts)
            where
                CoreApp (CoreCase on alts) args = uniqueExpr orig
                g (lhs,rhs) = (lhs, f $ CoreApp rhs args)

        -- NOTE(review): unreachable. The unguarded @CoreCase (CoreLet _ _) _@
        -- clause above matches exactly the same expressions first.
        f (CoreCase (CoreLet bind on) alts) = f $ CoreLet bind (f $ CoreCase on alts)

        -- Inline let bindings.  A binding is inlined (@once@) when its name is
        -- not mentioned in the right-hand side of any binding of this let and
        -- it is either "simple" or counted at most once in the body by
        -- 'countFreeVar'; the remaining bindings (@many@) are kept.
        f (CoreLet bind x) = coreLet many (transformExpr f $ replaceFreeVars once x)
            where
                -- every variable mentioned in any right-hand side of this let
                bindVars = [i | CoreVar i <- concatMap (universeExpr . snd) bind]
                (once,many) = partition (uncurry isValid) bind

                isValid lhs rhs = lhs `notElem` bindVars && (isSimple rhs || countFreeVar lhs x <= 1)

                -- "simple" = a function or variable, possibly wrapped in
                -- position information or an empty application
                isSimple (CoreApp x []) = isSimple x
                isSimple (CoreFun x) = True
                isSimple (CorePos x y) = isSimple y
                isSimple (CoreVar x) = True
                isSimple (CoreApp (CorePos _ (CoreFun name)) args) = isSimple (CoreApp (CoreFun name) args)
                isSimple _ = False

        -- let binds in case on of alts  ==>  case on of { p -> let binds in rhs }
        -- NOTE(review): unreachable as written, because the unconditional
        -- @CoreLet@ clause above always matches first.  It is left in place
        -- rather than moved, since enabling it would change the simplifier's
        -- output; 'coreSimplifyExprUniqueExt' does place its corresponding
        -- rule before the general let rule.
        f (CoreLet binds (CoreCase on alts1))
            | disjoint (universeExprVar on) (map fst binds) = f $ CoreCase on (map g alts1)
            where g (lhs,rhs) = (lhs,f $ coreLet (filter ((`notElem` patVariables lhs) . fst) binds) $ f rhs)

        -- (x xs) ys  ==>  x (xs ++ ys)
        f (CoreApp (CoreApp x xs) ys) = f $ CoreApp x (xs++ys)

        -- Apply a lambda to its arguments by substitution.  Parameters that
        -- receive no argument are re-bound under fresh @v@ names (@bindnew@);
        -- arguments beyond the parameter list (@args2@) stay applied.
        f o@(CoreApp (CoreLam bind x) args) = transformExpr f $
                coreApp (coreLam bindnew (replaceFreeVars rep x)) args2
            where
                args2 = drop (length bind) args
                bind2 = drop (length args) bind
                bindnew = take (length bind2) (freeVars 'v' \\ collectAllVars o)
                rep = zip bind (args ++ map CoreVar bindnew)

        -- (let bind in xs) ys  ==>  let bind' in (xs' ys)
        -- The let-bound names are renamed to fresh @x@ names so that widening
        -- the scope of the let over @ys@ cannot capture variables of @ys@.
        f x@(CoreApp (CoreLet bind xs) ys) =
                CoreLet (zip fresh (map rep rhs)) (CoreApp (rep xs) ys)
            where
                (lhs,rhs) = unzip bind
                -- Map each OLD name to its fresh replacement.  The pairs were
                -- previously zipped the other way round (fresh -> old), which
                -- made the rename a no-op and left the body referring to
                -- names the new let no longer binds.
                rep = replaceFreeVars (zip lhs (map CoreVar fresh))
                fresh = freeVars 'x' \\ collectAllVars x

        f x = x
-- | Resolve a @case@ whose scrutinee is a known constructor.
--
--   A bare constructor is first rewritten as an application to no arguments.
--   For a constructor application, the alternatives are scanned in order and
--   the first usable one wins: a 'PatCon' alternative for the same
--   constructor yields its right-hand side with the pattern variables
--   replaced by the constructor's fields, and a 'PatDefault' alternative
--   yields its right-hand side as is.  If no alternative is usable, or the
--   expression has any other shape, it is returned unchanged.
coreSimplifyCaseCon :: CoreExpr -> CoreExpr
coreSimplifyCaseCon expr = case expr of
    CoreCase (CoreCon con) alts ->
        coreSimplifyCaseCon (CoreCase (CoreApp (CoreCon con) []) alts)
    CoreCase (CoreApp (CoreCon con) fields) alts ->
        fromMaybe expr (listToMaybe (mapMaybe (select con fields) alts))
    _ -> expr
    where
        -- Right-hand side chosen by one alternative, if it applies.
        select con fields (PatCon name vars, body)
            | name == con = Just (replaceFreeVars (zip vars fields) body)
            | otherwise = Nothing
        select _ _ (PatDefault, body) = Just body
        select _ _ _ = Nothing
-- | Case-of-case rewrite:
--
--   > case (case on of { p -> rhs }) of alts2
--   >   ==>
--   > case on of { p -> case rhs of alts2 }
--
--   The variables of each inner 'PatCon' pattern are renamed to @v@-prefixed
--   names taken from 'freeVars' minus everything 'collectAllVars' reports for
--   the whole expression; other patterns are kept as they are.  Expressions
--   of any other shape are returned unchanged.
coreSimplifyCaseCase :: CoreExpr -> CoreExpr
coreSimplifyCaseCase whole@(CoreCase (CoreCase scrut innerAlts) outerAlts) =
    CoreCase scrut [pushInto alt | alt <- innerAlts]
    where
        -- supply of names not already used in the expression
        unused = freeVars 'v' \\ collectAllVars whole

        -- wrap one inner alternative's body in the outer case
        pushInto (PatCon ctor fields, body) = (PatCon ctor renamed, CoreCase body' outerAlts)
            where
                renamed = take (length fields) unused
                body' = replaceFreeVars (zip fields (map CoreVar renamed)) body
        pushInto (pat, body) = (pat, CoreCase body outerAlts)
coreSimplifyCaseCase other = other
-- | Float a @let@ out of a case scrutinee:
--
--   > case (let bs in x) of alts   ==>   let bs' in case x' of alts
--
--   Each let-bound name is replaced, in the binding right-hand sides and in
--   @x@, by a @v@-prefixed name from 'freeVars' that 'collectAllVars' does
--   not report for the original expression, so the widened scope of the let
--   cannot clash with names used in @alts@.
--
--   Any expression that is not a case over a let is returned unchanged, in
--   line with 'coreSimplifyCaseCon' and 'coreSimplifyCaseCase' (previously
--   such input failed with a pattern-match error).
coreSimplifyCaseLet :: CoreExpr -> CoreExpr
coreSimplifyCaseLet o@(CoreCase (CoreLet bind x) alts) =
    CoreLet (zipWith f newvars bind) (CoreCase (rep x) alts)
    where
        newvars = freeVars 'v' \\ collectAllVars o
        -- old let-bound name -> its fresh replacement
        rep = replaceFreeVars $ zip (map fst bind) (map CoreVar newvars)
        -- rebuild one binding under its fresh name
        f new (_,rhs) = (new, rep rhs)
coreSimplifyCaseLet x = x
-- | Run 'uniqueBoundVarsWithout' on an expression, passing it every variable
--   name that 'collectAllVars' finds in that same expression (presumably the
--   names the renaming must avoid -- confirm against Yhc.Core.FreeVar).
uniqueExpr :: CoreExpr -> CoreExpr
uniqueExpr expr = uniqueBoundVarsWithout taken expr
    where taken = collectAllVars expr
-- | Infinite supply of candidate variable names: the given character
--   followed by 1, 2, 3, ...  (for example @freeVars 'v'@ starts
--   @"v1", "v2", "v3"@).
freeVars :: Char -> [String]
freeVars c = map ((c :) . show) [1 :: Integer ..]
-- | 'coreSimplifyExprUniqueExt' with no additional rules: the extension
--   ignores the rewriter it is handed and returns the expression unchanged.
coreSimplifyExprUnique :: UniqueIdM m => CoreExpr -> m CoreExpr
coreSimplifyExprUnique = coreSimplifyExprUniqueExt noExtension
    where
        noExtension _ expr = return expr
-- | Monadic simplifier parameterised by an extension rule.
--
--   The result is @transformM f@, where @f@ tries the built-in rewrite rules
--   below from top to bottom and, when none of them matches, hands the
--   expression to @ext@ together with @f@ itself (@ext f x@).  Rules rebuild
--   nodes through the local helpers, which run @f@ again on every node they
--   construct.
--
--   Several rules copy sub-expressions with 'duplicateExpr' (presumably
--   giving the copies fresh bound-variable names from the 'UniqueIdM' monad
--   -- confirm against Yhc.Core.FreeVar3).
coreSimplifyExprUniqueExt :: UniqueIdM m => (
(CoreExpr -> m CoreExpr) ->
(CoreExpr -> m CoreExpr)
) -> CoreExpr -> m CoreExpr
coreSimplifyExprUniqueExt ext = fs
where
-- whole-expression traversal built from the single-node rewriter f
fs = transformM f
-- Helpers that build a node and immediately run f on it.  The @__@ variants
-- take two plain arguments; the @_'@ variants take a monadic second argument
-- and @coreApp'_@ takes a monadic first argument.
coreCase__ x y = f $ CoreCase x y ; coreCase_' x y = f . CoreCase x =<< y
coreLet__ x y = f $ CoreLet x y ; coreLet_' x y = f . CoreLet x =<< y
coreLam__ x y = f $ CoreLam x y ; coreLam_' x y = f . CoreLam x =<< y
coreApp__ x y = f $ CoreApp x y ; coreApp'_ x y = f . flip CoreApp y =<< x
-- Drop position wrappers and empty applications, lets and lambdas.
f (CorePos _ x ) = return x
f (CoreApp x []) = return x
f (CoreLet [] x) = return x
f (CoreLam [] x) = return x
-- Case on a constructor application: take the first usable alternative.
-- A default alternative gives its right-hand side unchanged; an alternative
-- for the same constructor binds its pattern variables to the constructor's
-- fields with a let.
f (CoreCase on alts) | isCoreCon con && not (null matches) = head matches
where
(con,fields) = fromCoreApp on
matches = mapMaybe g alts
g (PatDefault,rhs) = Just $ return rhs
g (PatCon x xs, rhs) | x == fromCoreCon con = Just $ coreLet__ (zip xs fields) rhs
g _ = Nothing
-- Case of case: push the outer alternatives into every inner alternative.
f (CoreCase (CoreCase on alts1) alts2) =
coreCase_' on (mapM g alts1)
where
g (lhs,rhs) = do
-- alts2 is copied once per inner alternative by wrapping it in a case with a
-- dummy literal scrutinee, duplicating that, and unwrapping the result
CoreCase _ alts22 <- duplicateExpr $ CoreCase (CoreLit $ CoreInt 0) alts2
rhs2 <- coreCase__ rhs alts22
return (lhs,rhs2)
-- Case of let: float the let outside the case.
f (CoreCase (CoreLet bind x) alts) =
coreLet_' bind (coreCase__ x alts)
-- (x xs) ys  ==>  x (xs ++ ys)
f (CoreApp (CoreApp x xs) ys) = coreApp__ x (xs++ys)
-- Application of a let: float the let outside the application.
f (CoreApp (CoreLet bind xs) ys) =
coreLet_' bind (coreApp__ xs ys)
-- Application of a case: push the arguments into every alternative, each
-- alternative receiving its own duplicated copy of the arguments.
f (CoreApp (CoreCase on alts) args) = coreCase_' on (mapM g alts)
where
g (lhs,rhs) = do
args2 <- mapM duplicateExpr args
rhs2 <- coreApp__ rhs args2
return (lhs,rhs2)
-- Application of a lambda: the first m parameters are let-bound to the first
-- m arguments; leftover parameters stay under a lambda and leftover
-- arguments stay applied.
f (CoreApp (CoreLam bind x) args) =
coreApp'_ (coreLam_' bind2 (coreLet__ (zip bind1 args1) x)) args2
where
m = min (length bind) (length args)
(bind1,bind2) = splitAt m bind
(args1,args2) = splitAt m args
-- Let around a case whose scrutinee uses none of the let-bound names: push
-- the let into every alternative and duplicate the resulting right-hand side.
f (CoreLet bind (CoreCase on alts))
| disjoint (collectFreeVars on) (map fst bind)
= coreCase_' on (mapM g alts)
where
g (lhs,rhs) = do
rhs2 <- coreLet__ bind rhs
rhs3 <- duplicateExpr rhs2
return (lhs,rhs3)
-- Some binding's right-hand side is itself a let: each right-hand side is
-- split with fromCoreLet into bindings b and a body x; all the b's become an
-- outer let and the (name, body) pairs an inner let.
-- NOTE(review): presumably fromCoreLet returns no bindings for a non-let
-- expression -- confirm against Yhc.Core.Type.
f (CoreLet bind x) | any (isCoreLet . snd) bind =
coreLet_' (concat bs) $ coreLet__ vs_xs x
where
(vs_xs,bs) = unzip [((v,x),b) | (v,rhs) <- bind, let (b,x) = fromCoreLet rhs]
-- Inline bindings: a binding is inlined when its name is not mentioned in
-- any right-hand side of this let and it is either simple or counted at most
-- once in the body by countFreeVar.  The rule fires only when at least one
-- binding qualifies; the body is re-simplified after substitution.
f (CoreLet bind x) | not $ null once = coreLet_' many (fs $ replaceFreeVars once x)
where
-- every variable mentioned in any right-hand side of this let
bindVars = [i | CoreVar i <- concatMap (universe . snd) bind]
(once,many) = partition (uncurry isValid) bind
isValid lhs rhs = lhs `notElem` bindVars && (isSimple rhs || countFreeVar lhs x <= 1)
-- simple = a function, a variable, or a literal accepted by isCoreLitSmall
isSimple x = isCoreFun x || isCoreVar x || (isCoreLit x && isCoreLitSmall (fromCoreLit x))
-- No built-in rule applies: defer to the caller-supplied extension.
f x = ext f x