module Yhc.Core.StrictAnno (
coreStrictAnno) where
import Yhc.Core.Type
import Yhc.Core.Prim
import Yhc.Core.Annotation
import Yhc.Core.AnnotatePrims
import qualified Data.Map as Map
import Data.List(foldl', intersect, nub, partition)
-- | Compute per-argument strictness flags for the functions of a program.
--
-- The result is a lookup function from a function name to its list of flags.
-- A name that has no entry in the strictness table yields the empty list.
--
-- The table is bound outside the returned lookup function, so it can be
-- shared by every lookup made through the same @coreStrictAnno anno core@
-- value instead of being rebuilt per call.
coreStrictAnno :: CoreAnnotations -> Core -> (CoreFuncName -> [Bool])
coreStrictAnno anno core = lookupFlags
    where
        -- Primitives are placed first (see 'sccSort') before the table is built.
        strictnessTable = mapStrictAnno anno (sccSort (coreFuncs core))

        lookupFlags name = Map.findWithDefault [] name strictnessTable
-- | Build the strictness table for a list of functions.
--
-- The list is folded from left to right, so while one function is being
-- analysed the table only contains the entries of the functions that precede
-- it in the list.  Calls to functions that are not in the table yet
-- (recursive calls included) contribute no strictness information.
--
-- An explicit @"Strictness"@ annotation always wins.  A primitive without
-- such an annotation gets no entry at all; an ordinary function without one
-- gets one flag per argument, 'True' exactly when @strictVars@ reports that
-- argument name for the function body.
--
-- The fold is strict ('foldl'') so the accumulated map is forced at every
-- step instead of building a chain of suspended inserts.
mapStrictAnno :: CoreAnnotations -> [CoreFunc] -> Map.Map CoreFuncName [Bool]
mapStrictAnno anno funcs = foldl' f Map.empty funcs
    where
        -- Primitive: only an explicit annotation can give it an entry.
        f mp (prim@CorePrim{coreFuncName = name}) =
            case getAnnotation prim "Strictness" anno of
                Nothing -> mp
                Just (CoreStrictness bs) -> Map.insert name bs mp
        -- Ordinary function: annotation if present, otherwise inferred flags.
        f mp func@(CoreFunc name args body) =
            case getAnnotation func "Strictness" anno of
                Just (CoreStrictness bs) -> Map.insert name bs mp
                Nothing -> Map.insert name (map (`elem` strict) args) mp
            where
                strict = strictVars body

                -- Names of the variables the expression is found to be
                -- strict in.  Any expression form not listed below yields
                -- the empty list.
                strictVars :: CoreExpr -> [String]
                -- Look through the CorePos wrapper.
                strictVars (CorePos _ x) = strictVars x
                strictVars (CoreVar x) = [x]
                -- Case on a variable: the scrutinee itself, plus every
                -- variable reported by all of the alternatives.
                strictVars (CoreCase (CoreVar x) alts) =
                    nub $ x : intersectList (map (strictVars . snd) alts)
                -- Case on any other expression: only the scrutinee is
                -- analysed; the alternatives are ignored.
                strictVars (CoreCase x alts) = strictVars x
                -- Call to a named function whose recorded flag count equals
                -- the number of arguments: collect the variables of the
                -- arguments in positions flagged 'True'.  If the guard fails
                -- the next equation applies.
                strictVars (CoreApp (CoreFun x) xs)
                    | length xs == length res
                    = nub $ concatMap strictVars [arg | (True, arg) <- zip res xs]
                    where res = Map.findWithDefault [] x mp
                -- Any other application: only the applied expression.
                strictVars (CoreApp x xs) = strictVars x
                strictVars _ = []
-- | Intersection of a list of lists.
--
-- The result keeps the order (and any duplicates) of the first list and
-- contains only the elements that also occur in every later list.  An empty
-- list of lists yields the empty list.
intersectList :: Eq a => [[a]] -> [a]
intersectList [] = []
intersectList [only] = only
intersectList (first : rest) = first `intersect` intersectList rest
-- | Reorder functions so that every primitive comes before every
-- non-primitive, keeping the relative order within each group.
--
-- NOTE(review): despite the name, no strongly-connected-component ordering
-- is performed here; the functions are only split by 'isCorePrim'.
sccSort :: [CoreFunc] -> [CoreFunc]
sccSort = uncurry (++) . partition isCorePrim