module Yhc.Core.Strictness(coreStrictness) where
import Data.List (foldl', intersect, nub, partition)

import qualified Data.Map as Map

import Yhc.Core.Prim
import Yhc.Core.Type
-- | Strictness oracle for a whole Core program.
--
-- @coreStrictness core name@ yields one flag per argument of @name@, where
-- 'True' marks an argument the analysis found to be strict. Names that have
-- no entry in the table yield @[]@. The table is built once per application
-- of 'coreStrictness' to a program and is shared by every lookup made
-- through the returned function.
coreStrictness :: Core -> (CoreFuncName -> [Bool])
coreStrictness core = lookupFlags
  where
    table = mapStrictness (sccSort (coreFuncs core))
    lookupFlags name = Map.findWithDefault [] name table
-- | Build the strictness table. Each function gets one 'Bool' per argument,
-- and 'True' means the argument is certainly evaluated whenever the
-- function's result is evaluated.
--
-- Functions are analysed in list order, and each one is checked against the
-- table built from the functions before it. A call to a function that has
-- no entry yet (including a recursive call to the function being analysed)
-- contributes no strictness. The result therefore errs towards 'False'.
--
-- Primitives take their flags from 'primStrict' when 'corePrimMaybe'
-- recognises them. Unrecognised primitives get no entry at all.
mapStrictness :: [CoreFunc] -> Map.Map CoreFuncName [Bool]
mapStrictness funcs = foldl' step Map.empty funcs
  where
    -- foldl' keeps the accumulated map evaluated at each step instead of
    -- building a chain of pending inserts as long as the function list.
    step mp (CorePrim{coreFuncName = name}) =
      case corePrimMaybe name of
        Nothing -> mp
        Just p -> Map.insert name (primStrict p) mp
    step mp (CoreFunc name args body) =
      Map.insert name (map (`elem` strict) args) mp
      where
        strict = strictVars mp body

    -- Variables certainly evaluated when the expression is evaluated, given
    -- the strictness already recorded in the table.
    strictVars :: Map.Map CoreFuncName [Bool] -> CoreExpr -> [String]
    strictVars mp = go
      where
        go (CorePos _ x) = go x
        go (CoreVar x) = [x]
        -- The scrutinee is always evaluated, and then exactly one
        -- alternative is. A variable strict in every alternative is
        -- therefore strict in the whole case, whatever the scrutinee is.
        go (CoreCase x alts) =
          nub $ go x ++ intersectList (map (go . snd) alts)
        -- A call is strict in the arguments its callee is strict in.
        -- Arguments beyond the callee's known flags do not stop the callee
        -- from running, so over-saturated calls count as well.
        -- Under-saturated calls fall through to the next equation.
        go (CoreApp (CoreFun x) xs)
          | length xs >= length res =
              nub $ concatMap go [arg | (True, arg) <- zip res xs]
          where
            res = Map.findWithDefault [] x mp
        go (CoreApp x _) = go x
        go _ = []
-- | Elements of the first list that occur in every other list, in the first
-- list's order and with its duplicates kept. An empty list of lists yields
-- @[]@, and a single list is returned unchanged.
intersectList :: Eq a => [[a]] -> [a]
intersectList [] = []
intersectList (firstList : others) =
  filter (\candidate -> all (candidate `elem`) others) firstList
-- | Put every primitive ahead of every non-primitive function. Relative
-- order within each group is kept.
--
-- NOTE(review): despite its name, this does no call-graph (SCC) ordering.
-- 'mapStrictness' only resolves calls to functions that appear earlier in
-- its input, so a genuine callee-first ordering would let the analysis find
-- more strict arguments.
sccSort :: [CoreFunc] -> [CoreFunc]
sccSort xs = filter isCorePrim xs ++ filter (not . isCorePrim) xs