module Flite.Strictify
( strictifyPrim
, strictifyPrimWithPVStack
, forceAndRebind
) where
import Flite.Syntax
import Flite.Traversals
import Flite.Descend
import Flite.CallGraph
import Data.List
import Flite.LambdaLift
-- | 'True' exactly when the expression is built with the 'Int' constructor.
isInt :: Exp -> Bool
isInt e = case e of
  Int _ -> True
  _     -> False
-- | Apply an expression to extra arguments without nesting applications.
--
-- * With no extra arguments the expression is returned untouched.
-- * If the expression is already an 'App', the extra arguments are appended
--   to its existing argument list.
-- * Otherwise a fresh 'App' node is built.
mkApp :: Exp -> [Exp] -> Exp
mkApp fn extra
  | null extra = fn
  | otherwise =
      case fn of
        App hd args -> App hd (args ++ extra)
        _           -> App fn extra
-- | Message passed to 'error' by the strictification passes in this module
-- when a unary or binary primitive appears without enough arguments (or as
-- a bare function reference).
primSatErrMsg :: String
primSatErrMsg = "Applications of primitives must be saturated"
-- | Re-order applications of primitives so that non-literal operands are
-- placed at the head of the application, lifted to programs with 'onExp'.
--
-- For a unary primitive @f@ applied to (rewritten) operand @a@:
--
-- * @a@ is an 'Int' literal:  @f a@
-- * otherwise:                @a f@
--
-- For a binary primitive @f@ applied to (rewritten) operands @a@ and @b@:
--
-- * neither is a literal:  @b (a f)@
-- * only @b@ is a literal: @a f b@
-- * only @a@ is a literal: @b (f a)@
-- * both are literals:     @f a b@
--
-- Any arguments beyond the primitive's operands are rewritten and
-- re-attached with 'mkApp'.  A unary/binary primitive that is applied to
-- too few arguments, or that appears as a bare 'Fun', aborts with
-- 'primSatErrMsg'.
strictifyPrim :: Prog -> Prog
strictifyPrim = onExp go
  where
    go expr = case expr of
      App (Fun op) (x : more)
        | isUnaryPrim op -> mkApp (unary op (go x)) (map go more)
      App (Fun op) (x : y : more)
        | isBinaryPrim op -> mkApp (binary op (go x) (go y)) (map go more)
      -- Reaching either of the next two alternatives with a primitive means
      -- the saturated cases above did not apply.
      App (Fun op) _
        | isPrimOp op -> error primSatErrMsg
      Fun op
        | isPrimOp op -> error primSatErrMsg
      _ -> descend go expr

    isPrimOp op = isUnaryPrim op || isBinaryPrim op

    -- Placement of a single, already rewritten operand.
    unary op x
      | isInt x   = App (Fun op) [x]
      | otherwise = App x [Fun op]

    -- Placement of two already rewritten operands; see the table above.
    binary op x y
      | isInt x =
          if isInt y
            then App (Fun op) [x, y]
            else App y [App (Fun op) [x]]
      | isInt y   = App x [Fun op, y]
      | otherwise = App y [App x [Fun op]]
-- | Variant of primitive strictification that always puts the operands in
-- front of the primitive, regardless of whether they are literals.
--
-- A unary primitive application @f a@ becomes @catApp [a, f]@ and a binary
-- application @f a b@ becomes @catApp [b, a, f]@ (operands are rewritten
-- first).  Surplus arguments are rewritten and re-attached with 'mkApp'.
-- A unary/binary primitive with too few arguments, or appearing as a bare
-- 'Fun', aborts with 'primSatErrMsg'.
strictifyPrimWithPVStack :: Prog -> Prog
strictifyPrimWithPVStack = onExp go
  where
    go expr = case expr of
      App (Fun op) (x : more)
        | isUnaryPrim op ->
            mkApp (catApp [go x, Fun op]) (map go more)
      App (Fun op) (x : y : more)
        | isBinaryPrim op ->
            mkApp (catApp [go y, go x, Fun op]) (map go more)
      App (Fun op) _
        | isUnaryPrim op || isBinaryPrim op -> error primSatErrMsg
      Fun op
        | isUnaryPrim op || isBinaryPrim op -> error primSatErrMsg
      _ -> descend go expr
-- | Concatenate a list of expressions into one flat application.
--
-- Every 'App' in the input contributes its head followed by its arguments
-- (one level only; nested applications inside those are left alone); any
-- other expression contributes itself.  The first element of the resulting
-- sequence becomes the head of the returned 'App' and the remainder its
-- arguments.
--
-- The input must be non-empty.  An empty list is reported with an explicit
-- error instead of surfacing later as an irrefutable-pattern failure from a
-- lazy binding.
catApp :: [Exp] -> Exp
catApp es =
  case concatMap contents es of
    x : xs -> App x xs
    []     -> error "catApp: cannot build an application from an empty list"
  where
    contents (App e args) = e : args
    contents e            = [e]
-- | Generate @_W@ wrapper functions and redirect references to them.
--
-- Wrappers are produced per declaration by 'makeWrapper' and then passed
-- through 'lambdaLift'.  Every original declaration is rewritten with
-- 'wrap', using the names of all resulting wrapper declarations, and the
-- wrappers are appended after the rewritten originals.
forceAndRebind :: Prog -> Prog
forceAndRebind prog = rewritten ++ wrappers
  where
    graph        = callReachableGraph prog
    wrappers     = lambdaLift (concatMap (makeWrapper graph) prog)
    wrapperNames = [funcName w | w <- wrappers]
    rewritten    = [wrap graph wrapperNames decl | decl <- prog]
-- | Rewrite the body of one function declaration with 'wrapExp', leaving
-- its name and arguments as they are.
wrap :: CallGraph -> [Id] -> Decl -> Decl
wrap cg ws (Func name params body) = Func name params body'
  where
    body' = wrapExp name cg ws body
-- | Redirect function references inside the body of function @f@ to their
-- wrappers.
--
-- A reference @Fun g@ becomes @Fun (g ++ "_W")@ when that wrapper name is
-- listed in @ws@ and @f@ is not among @reachable cg g@; otherwise it is
-- kept.  All other expressions are traversed with 'descend'.
wrapExp :: Id -> CallGraph -> [Id] -> Exp -> Exp
wrapExp f cg ws expr = go expr
  where
    go (Fun g)
      | hasWrapper && not reachesCaller = Fun wrapped
      | otherwise                       = Fun g
      where
        wrapped       = g ++ "_W"
        hasWrapper    = wrapped `elem` ws
        reachesCaller = f `elem` reachable cg g
    go e = descend go e
-- | Build the wrapper declaration for a function, if one is needed.
--
-- The wrapper is named @f ++ "_W"@, takes the same arguments, and has the
-- body produced by 'abstract'.  When 'abstract' leaves the body unchanged
-- no wrapper is produced and the result is empty.
makeWrapper :: CallGraph -> Decl -> [Decl]
makeWrapper cg (Func f args rhs) =
  if rhs == body
    then []
    else [Func (f ++ "_W") args body]
  where
    body = abstract f cg rhs
-- | Collect the variables that occur as operands of (possibly nested)
-- primitive applications.
--
-- A variable yields itself; an application whose head is a function
-- satisfying 'isPrimId' yields the needed variables of all its arguments;
-- every other expression yields nothing.  Duplicates are kept.
neededVars :: Exp -> [Id]
neededVars expr = case expr of
  Var name -> [name]
  App (Fun op) operands
    | isPrimId op -> concatMap neededVars operands
  _ -> []
-- | Rewrite case expressions on primitive applications inside function @f@.
--
-- A @Case@ whose subject is an application of a primitive (per 'isPrimId')
-- is rewritten when both of the following hold:
--
-- * the set of variables to force is non-empty, where that set is the
--   variables occurring more than once in the list formed by
--   'neededVars' of the subject followed by the free variables of the
--   alternative bodies that are also needed by the subject;
-- * @f@ is reachable (per the call graph) from some function called in the
--   subject or in an alternative body.
--
-- The result applies @force (reverse forced) (Lam (forced ++ others) ...)@
-- to the @others@ variables, where @others@ are the remaining distinct free
-- variables of the subject's arguments and the alternative bodies.  The
-- alternatives themselves are rewritten recursively.  All other expressions
-- are traversed with 'descend'.
--
-- NOTE(review): free variables are taken from the alternative bodies only
-- (the pattern component is ignored); confirm this is intended for
-- variables bound by the patterns.
abstract :: Id -> CallGraph -> Exp -> Exp
abstract f cg (Case subject@(App (Fun p) es) as)
  | isPrimId p
  , not (null forced)
  , isRecursive =
      App
        (force (reverse forced) (Lam (forced ++ others) (Case subject alts')))
        (map Var others)
  where
    bodies      = map snd as
    needed      = neededVars subject
    reused      = [v | body <- bodies, v <- freeVars body, v `elem` needed]
    forced      = dups (needed ++ reused)
    callees     = concatMap calls (subject : bodies)
    isRecursive = any (\g -> f `elem` reachable cg g) callees
    alts'       = [(pat, abstract f cg body) | (pat, body) <- as]
    others      =
      [ w
      | w <- nub (concatMap freeVars (es ++ bodies))
      , w `notElem` forced
      ]
abstract f cg e = descend (abstract f cg) e
-- | Nest an expression under a chain of variable applications.
--
-- @force [v1, v2, ..., vn] e@ builds @v1 (v2 (... (vn e)))@, each layer
-- being an 'App' with the variable as head and a single argument.  With no
-- variables the expression is returned unchanged.
force :: [Id] -> Exp -> Exp
force vars body = foldr (\v inner -> App (Var v) [inner]) body vars
-- | The elements that occur more than once in the list.
--
-- Each such element is reported exactly once, in order of first occurrence.
dups :: Eq a => [a] -> [a]
dups = go
  where
    go [] = []
    go (y : ys) =
      case partition (== y) ys of
        ([], _)     -> go ys          -- no later copy of y: not a duplicate
        (_, others) -> y : go others  -- report y once, drop its other copies