module Yhc.Core.Strictness(coreStrictness) where

import Yhc.Core.Type
import Yhc.Core.Prim

import qualified Data.Map as Map
import Data.List(foldl', intersect, nub, partition)

{-
ALGORITHM:

SCC PARTIAL SORT:
First sort the functions so that they occur in childmost-first order:
x1 < x2, if x1 doesn't transitively call x2, and x2 does transitively call x1.
Getting the order wrong is fine, but a better order gives better results.

PRIM STRICTNESS:
The strictness of the various primitive operations

BASE STRICTNESS:
If all paths case on a particular variable, then the function is strict in that variable.
If the function calls onwards, then strictness is based on the strictness of the callee.
-}


-- | Strictness lookup for the functions of a 'Core' program.
--   The list returned for a function name has one entry per argument:
--   True if the function is strict in that argument, False if it is not.
--   [] means the strictness is unknown.
coreStrictness :: Core -> (CoreFuncName -> [Bool])
coreStrictness core = lookupStrict
    where
        -- bound outside the lookup, so a partially applied
        -- @coreStrictness core@ can reuse the same table for every name
        table = mapStrictness (sccSort (coreFuncs core))

        lookupStrict name = Map.findWithDefault [] name table




-- | Build the strictness table for a list of functions.
--
--   The functions are processed in list order, and each one may use the
--   entries already recorded for the functions before it, so callees
--   should come before their callers for the best results.
--
--   A primitive is given the strictness reported by 'primStrict' when
--   'corePrimMaybe' recognises its name, and is left out of the table
--   otherwise. An ordinary function gets one 'Bool' per argument, True
--   when the argument is found to be forced by the body.
mapStrictness :: [CoreFunc] -> Map.Map CoreFuncName [Bool]
mapStrictness funcs = foldl' addFunc Map.empty funcs  -- strict fold: no chain of pending inserts
    where
        addFunc known (CorePrim{coreFuncName=name}) =
            maybe known (\p -> Map.insert name (primStrict p) known) (corePrimMaybe name)

        addFunc known (CoreFunc name args body) = Map.insert name (map (`elem` forced) args) known
            where
                forced = strictVars body

                -- variables that are evaluated whenever the expression is
                strictVars :: CoreExpr -> [String]
                strictVars (CorePos _ x) = strictVars x
                strictVars (CoreVar x) = [x]

                -- The scrutinee is always evaluated and then exactly one
                -- alternative is, so a variable forced by every alternative
                -- is forced by the case as a whole. This holds for any
                -- scrutinee, not only a bare variable.
                -- NOTE(review): assumes pattern-bound names never shadow a
                -- function argument - confirm
                strictVars (CoreCase scrut alts) =
                    nub $ strictVars scrut ++ intersectList (map (strictVars . snd) alts)

                -- a call to a function with known strictness, applied to
                -- exactly as many arguments as the table lists: only the
                -- arguments in strict positions contribute
                strictVars (CoreApp (CoreFun x) xs)
                    | length xs == length res
                    = nub $ concatMap strictVars [arg | (True, arg) <- zip res xs]
                    | otherwise = []
                    where res = Map.findWithDefault [] x known

                -- any other application evaluates the thing being applied
                strictVars (CoreApp x _) = strictVars x

                strictVars _ = []


-- | Intersection of all the given lists, associating to the right.
--   No lists at all gives [], and a single list is returned unchanged.
intersectList :: Eq a => [[a]] -> [a]
intersectList [] = []
intersectList [ys] = ys
intersectList (ys:rest) = ys `intersect` intersectList rest



-- approximate an SCC ordering: every primitive is moved ahead of every
-- other function, and the relative order inside each group is kept
sccSort :: [CoreFunc] -> [CoreFunc]
sccSort xs = filter isCorePrim xs ++ filter (not . isCorePrim) xs