-- Version of the Yhc.Core.Strictness module, but based on
-- Core Annotations.

module Yhc.Core.StrictAnno (coreStrictAnno) where

import Yhc.Core.Type
import Yhc.Core.Prim
import Yhc.Core.Annotation
import Yhc.Core.AnnotatePrims

import qualified Data.Map as Map
import Data.List (foldl', intersect, nub, partition)


{-
ALGORITHM:

SCC PARTIAL SORT:
First sort the functions so that they occur in the childmost order:
x1 < x2 if x1 doesn't transitively call x2, and x2 does transitively call x1.
Being wrong is fine, but being better gives better results.

PRIM STRICTNESS:
Primitive operations get their strictness from their "Strictness"
annotation, if they have one.

BASE STRICTNESS:
If all paths case on a particular value, then the function is strict in
that value.
If the body calls onwards, then it is strict according to the callee.
-}


-- | Given the annotations and a Core program, return a lookup from function
-- name to one flag per argument: 'True' if the function is strict in that
-- argument, 'False' if it is not (or could not be shown to be).
-- An empty list means the strictness is unknown, e.g. for a name that is not
-- in the program, or a primitive without a @"Strictness"@ annotation.
coreStrictAnno :: CoreAnnotations -> Core -> (CoreFuncName -> [Bool])
coreStrictAnno anno core = \funcname -> Map.findWithDefault [] funcname table
    where
        table = mapStrictAnno anno (sccSort (coreFuncs core))


-- | Build the strictness table by walking the functions in list order.
--
-- An explicit @"Strictness"@ annotation always wins. Without one, a
-- primitive is left out of the table (its strictness stays unknown), and an
-- ordinary function is analysed locally using only the entries for the
-- functions processed /before/ it. Callees that come later in the list,
-- including the function itself on a recursive call, therefore contribute
-- no strictness information.
mapStrictAnno :: CoreAnnotations -> [CoreFunc] -> Map.Map CoreFuncName [Bool]
mapStrictAnno anno funcs = foldl' step Map.empty funcs
    where
        -- foldl' keeps the accumulated map evaluated at every step instead of
        -- building up a chain of pending inserts.
        step known (prim@CorePrim{coreFuncName = name}) =
            case annotatedStrictness prim of
                Just bs -> Map.insert name bs known
                Nothing -> known
        step known func@(CoreFunc name args body) =
            case annotatedStrictness func of
                Just bs -> Map.insert name bs known
                Nothing -> Map.insert name (map (`elem` strict) args) known
            where
                strict = strictVars known body

        -- Only a 'CoreStrictness' payload is meaningful here; any other
        -- annotation stored under the key is treated as "no annotation"
        -- rather than failing with a pattern-match error.
        annotatedStrictness f =
            case getAnnotation f "Strictness" anno of
                Just (CoreStrictness bs) -> Just bs
                _ -> Nothing

        -- Variables that are certainly evaluated whenever the expression is,
        -- given the strictness already recorded for callees in the table.
        strictVars :: Map.Map CoreFuncName [Bool] -> CoreExpr -> [String]
        strictVars known = go
            where
                go (CorePos _ x) = go x
                go (CoreVar x) = [x]
                -- The scrutinee is always evaluated, and then exactly one
                -- alternative runs, so whatever every alternative is strict
                -- in is also strict in the whole case expression.
                go (CoreCase x alts) =
                    nub $ go x ++ intersectList (map (go . snd) alts)
                -- A saturated call to a function with known strictness
                -- evaluates the arguments the callee is strict in.
                go (CoreApp (CoreFun x) xs)
                    | length xs == length res
                    = nub $ concatMap go $ map snd $ filter fst $ zip res xs
                    where res = Map.findWithDefault [] x known
                -- Any other application: only the function position is
                -- analysed.
                go (CoreApp x _) = go x
                go _ = []

        -- Intersection of all the lists; an empty family of lists gives [].
        intersectList :: Eq a => [[a]] -> [a]
        intersectList [] = []
        intersectList xs = foldr1 intersect xs


-- | Order the functions so that callees tend to be analysed before callers.
--
-- Currently this only moves every primitive in front of the ordinary
-- functions, keeping the relative order within each group; it does not
-- compute real strongly connected components.
sccSort :: [CoreFunc] -> [CoreFunc]
sccSort xs = prims ++ funcs
    where
        (prims, funcs) = partition isCorePrim xs