module Grin.NodeAnalyze(nodeAnalyze) where
import Control.Monad.Identity hiding(join)
import Control.Monad.RWS hiding(join)
import Data.Maybe
import Text.Printf
import qualified Data.Map as Map
import qualified Data.Set as Set
import Grin.Grin hiding(V)
import Grin.Noodle
import Options
import StringTable.Atom
import Support.CanType
import Support.FreeVars
import Support.Tickle
import Util.Gen
import Util.SetLike
import Util.UnionSolve
import Util.UniqueMonad
import qualified Stats
-- | Evaluation state of a node-typed value.
--
-- 'WHNF' (already evaluated) is declared before 'Lazy' (may be a thunk), so the
-- derived 'Ord' has @WHNF < Lazy@; the 'Fixable' instance uses 'max' / 'min'
-- of this order as its join / meet.
data NodeType = WHNF | Lazy
    deriving (Eq, Ord, Show)
-- | Abstract value of a node: its evaluation state paired with the set of
-- constructor tags it may carry ('Top' when that set is unbounded).
data N
    = N !NodeType (Topped (Set.Set Atom))
    deriving (Eq)
-- | Renders as @<NodeType>-<tags>@, printing an unbounded tag set as @[?]@.
instance Show N where
    show (N evalState tags) = show evalState ++ "-" ++ tagText
      where
        tagText = case tags of
            Top    -> "[?]"
            Only s -> show (Set.toList s)
-- | Two-point lattice ordered by the derived 'Ord': 'WHNF' is bottom,
-- 'Lazy' is top.
instance Fixable NodeType where
    isBottom = (== WHNF)
    isTop    = (== Lazy)
    join     = max
    meet     = min
    eq       = (==)
    lte      = (<=)
-- | Product lattice: every operation is applied component-wise to the
-- evaluation state and the tag set.
instance Fixable N where
    isBottom (N s t) = isBottom s && isBottom t
    isTop    (N s t) = isTop s && isTop t
    join (N s1 t1) (N s2 t2) = N (join s1 s2) (join t1 t2)
    meet (N s1 t1) (N s2 t2) = N (meet s1 s2) (meet t1 t2)
    lte  (N s1 t1) (N s2 t2) = lte s1 s2 && lte t1 t2
    eq   (N s1 t1) (N s2 t2) = eq s1 s2 && eq t1 t2
    showFixable = show
-- | A constraint variable of the analysis, tagged with the GRIN type of the
-- thing it stands for. 'VIgnore' marks values that take no part in the
-- analysis.
data V = V Va Ty | VIgnore
    deriving (Eq, Ord)

-- | What a constraint variable refers to.
data Va
    = Vr !Var          -- a GRIN variable
    | Fa !Atom !Int    -- the i-th argument slot of a function
    | Fr !Atom !Int    -- the i-th return slot of a function
    deriving (Eq, Ord)

-- | Constraint variable for a GRIN variable.
vr :: Var -> Ty -> V
vr v = V (Vr v)

-- | Constraint variable for argument @i@ of function @n@.
fa :: Atom -> Int -> Ty -> V
fa n i = V (Fa n i)

-- | Constraint variable for return value @i@ of function @n@.
fr :: Atom -> Int -> Ty -> V
fr n i = V (Fr n i)
-- | Things that may or may not be node-typed; only node-typed values
-- ('TyNode' / 'TyINode') generate constraints.
class NodeLike a where
    isGood :: a -> Bool

instance NodeLike Ty where
    isGood ty = case ty of
        TyNode  -> True
        TyINode -> True
        _       -> False

instance NodeLike Val where
    isGood = isGood . getType

instance NodeLike V where
    isGood VIgnore  = False
    isGood (V _ ty) = isGood ty

-- | Known abstract values ('Right') always count; variables defer to 'V'.
instance NodeLike (Either V b) where
    isGood = either isGood (const True)
-- | Debug rendering: variables print as GRIN variables, argument slots as
-- @(fn,i)@, return slots as @(i,fn)@ and ignored values as @IGN@.
instance Show V where
    showsPrec _ cv = case cv of
        V (Vr var) ty -> shows (Var var ty)
        V (Fa fn i) _ -> shows (fn, i)
        V (Fr fn i) _ -> shows (i, fn)
        VIgnore       -> showString "IGN"
-- | Constraint-generation monad: reads the program's 'TyEnv' and writes
-- constraints over 'N' values keyed by 'V'. The 'Int' state (initialised
-- to 1 by 'runM') is not exposed by any derived instance.
newtype M a = M (RWS TyEnv (C N V) Int a)
    deriving (Monad, Functor, MonadWriter (C N V))

-- | Run a constraint generator against a program's type environment and
-- return only the accumulated constraints.
runM :: Grin -> M a -> C N V
runM grin (M action) = constraints
  where
    (_, _, constraints) = runRWS action (grinTypeEnv grin) 1
-- | Run the node analysis over a whole GRIN program.
--
-- Steps:
--
-- 1. Give let-bound functions unique names ('renameUniqueGrin').
-- 2. Generate constraints for every function, every CAF (pinned to 'top')
--    and the synthetic @\@initcafs@ routine, then solve them.
-- 3. Use the solution to rewrite expressions ('fixupfs'), collecting stats.
-- 4. Retype function signatures ('fixupFuncs' / 'transformFuncs').
nodeAnalyze :: Grin -> IO Grin
nodeAnalyze inputGrin = do
    let grin = renameUniqueGrin inputGrin
        cafConstraint (v, _) = tell $ Right top `equals` Left (vr v TyINode)
        constraints = runM grin $ do
            mapM_ doFunc (grinFuncs grin)
            mapM_ cafConstraint (grinCafs grin)
            doFunc (toAtom "@initcafs", [] :-> initCafs grin)
    (rm, res) <- solve (const (return ())) constraints
    let cmap = Map.map (fromJust . flip Map.lookup res) rm
        plan = fixupFuncs (grinSuspFunctions grin) (grinPartFunctions grin) cmap
    (ticked, stats) <- Stats.runStatT $ tickleM (fixupfs cmap (grinTypeEnv grin)) grin
    return $ transformFuncs plan ticked { grinStats = stats `mappend` grinStats ticked }
-- | Where an expression's results go during constraint generation.
-- @Todo True vs@ makes each variable equal to its result; @Todo False vs@
-- only requires each variable to be at least its result (used for case
-- alternatives); 'TodoNothing' discards the results.
data Todo
    = Todo !Bool [V]
    | TodoNothing
-- | Build an expression that overwrites each CAF's 'TyINode' variable with
-- its initial node, in 'grinCafs' order, and then returns nothing.
initCafs grin = foldr overwrite (Return []) (grinCafs grin)
  where
    overwrite (v, node) rest =
        BaseOp Overwrite [Var v TyINode, node] :>>= [] :-> rest
-- | Generate the node-analysis constraints for one function (or let-bound
-- definition) and everything nested inside its body.
--
-- Each return slot gets a constraint variable ('fr'). Each parameter is tied
-- to its argument slot ('fa'). The body is then walked by 'fn' / 'gn'; the
-- final result flows into the return slots. Functions defined by a nested
-- 'Let' are processed recursively.
doFunc :: (Atom,Lam) -> M ()
doFunc (name,arg :-> body) = ans where
    ans :: M ()
    ans = do
        let rts = getType body
        forMn_ rts $ \ (t,i) -> dVar (fr name i t) t
        -- each formal parameter is equal to the function's argument slot
        forMn_ arg $ \ (~(Var v vt),i) -> do
            dVar (vr v vt) vt
            tell $ cAnnotate "FunArg" $ Left (fa name i vt) `equals` Left (vr v vt)
        fn (Todo True [ fr name i t | i <- naturals | t <- rts ]) body

    -- A 'TyNode' value is bounded above by "evaluated, any tag".
    dVar v TyNode = do
        tell $ Left v `islte` Right (N WHNF Top)
    dVar _ _ = return ()

    -- Variables bound by patterns that are not tracked precisely are pinned:
    -- to "evaluated, any tag" for 'TyNode', to 'top' otherwise.
    zVar s v TyNode = tell $ cAnnotate ("zVar - tynode " ++ s) $ Left (vr v TyNode) `equals` Right (N WHNF Top)
    zVar s v t = tell $ cAnnotate ("zVar - inode " ++ s) $ Left (vr v t) `equals` Right top

    -- Walk a chain of binds; the last expression delivers its result to @ret@.
    fn :: Todo -> Exp -> M ()
    fn ret body = f body where
        f (x :>>= [Var v vt] :-> rest) = do
            dVar (vr v vt) vt
            gn (Todo True [vr v vt]) x
            f rest
        f (x :>>= vs@(_:_:_) :-> rest) = do
            vs' <- forM vs $ \ (Var v vt) -> do
                dVar (vr v vt) vt
                return $ vr v vt
            gn (if all (== VIgnore) vs' then TodoNothing else Todo True vs') x
            f rest
        f (x :>>= v :-> rest) = do
            forM_ (Set.toList $ freeVars v) $ \ (v,vt) -> zVar "Bind" v vt
            gn TodoNothing x
            f rest
        f body = gn ret body

    -- Relate a result variable to a computed value according to the 'Todo'.
    isfn _ x y | not (isGood x) = mempty
    isfn (Todo True _) x y = cAnnotate "isfn True" $ Left x `equals` y
    isfn (Todo False _) x y = cAnnotate "isfn False" $ Left x `isgte` y
    isfn TodoNothing x y = mempty

    -- Wrappers that emit no constraint unless both sides are node-like.
    equals x y | isGood x && isGood y = Util.UnionSolve.equals x y
               | otherwise = mempty
    isgte x y | isGood x && isGood y = Util.UnionSolve.isgte x y
              | otherwise = mempty
    islte x y | isGood x && isGood y = Util.UnionSolve.islte x y
              | otherwise = mempty

    -- Generate constraints for a single (non-bind) expression whose results
    -- flow to @ret@.
    gn ret head = f head where
        fl ret (v :-> body) = do
            forM_ (Set.toList $ freeVars v) $ \ (v,vt) -> zVar "Alt" v vt
            fn ret body
        -- results we know nothing about: WHNF for 'TyNode', 'top' otherwise
        dunno ty = do
            dres [Right (if TyNode == t then N WHNF Top else top) | t <- ty ]
        dres res = do
            case ret of
                Todo b vs | length res /= length vs -> error "lengths don't match!"
                Todo b vs -> forM_ (zip vs res) $ \ (v,r) -> tell (isfn ret v r)
                _ -> return ()
        f (_ :>>= _) = error $ "Grin.NodeAnalyze: :>>="
        f (Case v as)
            | Todo _ n <- ret = mapM_ (fl (Todo False n)) as
            | TodoNothing <- ret = mapM_ (fl TodoNothing) as
        f (BaseOp Eval [x]) = do
            dres [Right (N WHNF Top)]
        f (BaseOp (Apply ty) xs) = do
            mapM_ convertVal xs
            dunno ty
        f (App { expFunction = fn, expArgs = vs, expType = ty }) = do
            vs' <- mapM convertVal vs
            forMn_ (zip vs vs') $ \ ((tv,v),i) -> do
                tell $ v `islte` Left (fa fn i (getType tv))
            dres [Left $ fr fn i t | i <- [ 0 .. ] | t <- ty ]
        f (Call { expValue = Item fn _, expArgs = vs, expType = ty }) = do
            vs' <- mapM convertVal vs
            forMn_ (zip vs vs') $ \ ((tv,v),i) -> do
                tell $ v `islte` Left (fa fn i (getType tv))
            dres [Left $ fr fn i t | i <- [ 0 .. ] | t <- ty ]
        f (Return x) = do mapM convertVal x >>= dres
        f (BaseOp (StoreNode _) w) = do mapM convertVal w >>= dres
        f (BaseOp Promote [w]) = do
            ww <- convertVal w
            tell $ ww `islte` Right (N WHNF Top)
            dres [ww]
        f (BaseOp Demote [w]) = do
            ww <- convertVal w
            tell $ ww `islte` Right (N WHNF Top)
            dres [ww]
        f Error {} = return ()
        f Prim { expArgs = as, expType = ty } = mapM_ convertVal as >> dunno ty
        f Alloc { expValue = v } | getType v == TyNode = do
            v' <- convertVal v
            dres [v']
        f Alloc { expValue = v } | getType v == tyINode = do
            convertVal v
            dunno [TyPtr tyINode]
        f NewRegion { expLam = _ :-> body } = fn ret body
        f (BaseOp Overwrite [Var vname ty,v]) | ty == TyINode = do
            v' <- convertVal v
            tell $ Left (vr vname ty) `isgte` v'
            dres []
        f e@(BaseOp Overwrite vs) = do mapM_ convertVal vs >> dunno (getType e)
        f e@(BaseOp PokeVal vs) = do mapM_ convertVal vs >> dunno (getType e)
        f e@(BaseOp PeekVal vs) = do mapM_ convertVal vs >> dunno (getType e)
        f Let { expDefs = ds, expBody = e } = do
            mapM_ doFunc (map (\x -> (funcDefName x, funcDefBody x)) ds)
            fn ret e
        f exp = error $ "NodeAnalyze.f: " ++ show exp

    -- Translate a GRIN value into a constraint variable ('Left') or a known
    -- abstract value ('Right'). Node constructors whose tag maps to a
    -- function ('tagUnfunction') also constrain that function's argument slots.
    convertVal (Const n@(NodeC _ _)) = convertVal n
    convertVal (Const _) = return $ Right (N WHNF Top)
    convertVal (NodeC t vs) = case tagUnfunction t of
        Nothing -> do
            mapM_ convertVal vs
            return $ Right (N WHNF (Only $ Set.singleton t))
        Just (n,fn) -> do
            vs' <- mapM convertVal vs
            forMn_ (zip vs vs') $ \ ((vt,v),i) -> do
                tell $ v `islte` Left (fa fn i (getType vt))
            -- The n argument slots following the supplied fields are forced to
            -- 'top'. This was garbled as `n 1`; it must be `n - 1`, so that
            -- exactly n slots are visited (none when n == 0).
            forM_ [0 .. n - 1 ] $ \i -> do
                tell $ Right top `islte` Left (fa fn (length vs + i) TyINode)
            -- n == 0 is treated as a Lazy node, otherwise the node is WHNF
            return $ Right (N (if n == 0 then Lazy else WHNF) (Only $ Set.singleton t))
    convertVal (Var v t) = return $ Left (vr v t)
    convertVal v | isGood v = return $ Right (N Lazy Top)
    convertVal Lit {} = return $ Left VIgnore
    convertVal ValPrim {} = return $ Left VIgnore
    convertVal Index {} = return $ Left VIgnore
    convertVal Item {} = return $ Left VIgnore
    convertVal ValUnknown {} = return $ Left VIgnore
    convertVal v = error $ "convertVal " ++ show v
-- | Smallest abstract node value: evaluated, with no possible tags.
bottom = N WHNF (Only Set.empty)

-- | Largest abstract node value: possibly lazy, with any tag.
top = N Lazy Top
-- | Plan for one argument or return slot, as consumed by 'transformFuncs'.
data WhatToDo
    = WhatDelete                              -- drop the slot
    | WhatUnchanged                           -- keep the slot as it is
    | WhatConstant Val                        -- drop the slot; its value is this constant
    | WhatSubs Ty (Val -> Exp) (Val -> Exp)   -- retype the slot: new type,
                                              -- old->new and new->old conversions
-- | Rewrite function signatures, and the code that calls them, according to
-- the planner @fn@.
--
-- @fn name argTys mretTys@ returns optional per-slot plans ('WhatToDo') for
-- the arguments and for the return values ('Nothing' = leave that side
-- alone). Each function definition (top-level and let-bound) has its
-- parameters dropped or retyped, with conversion code at the start of its
-- body and around its results. Every 'App', and every @StoreNode False@ of a
-- node whose plan is looked up under @tagFlipFunction a@, is rewritten to
-- match. Finally the type environment is re-extended from the new functions.
transformFuncs :: (Atom -> [Ty] -> Maybe [Ty] -> (Maybe [WhatToDo],Maybe [WhatToDo])) -> Grin -> Grin
transformFuncs fn grin = grin'' where
grin'' = grin' { grinTypeEnv = extendTyEnv (grinFunctions grin') (grinTypeEnv grin') }
grin' = setGrinFunctions (nfs $ grinFuncs grin) grin
nfs ds = map fs ds
fs (n,l@(ps :-> e)) = (n,f (fn n (map getType ps) (Just $ getType e)) l)
-- no plan for this function: only rewrite the call sites inside its body
f (Nothing,Nothing) (p :-> e) = p :-> j e
-- apply the argument plan to the parameters and the return plan to the results
f (Just ats,rts') (p :-> e) = p' :-> e' where
rts = maybe (map (const WhatUnchanged) (getType e)) id rts'
-- new parameter list: deleted/constant slots disappear, substituted ones get the new type
p' = concatMap f (zip p ats) where
f (v,WhatUnchanged) = [v]
f (_,WhatDelete) = []
f (_,WhatConstant _) = []
f (Var v _,WhatSubs nty _ _) = [Var v nty]
f _ = error "NodeAnalyze.transformFuncs: f bad."
e' = g (zip p ats) (j e)
-- rebind constant parameters and convert substituted ones back to their old type
g ((_,WhatUnchanged):xs) e = g xs e
g ((_,WhatDelete):xs) e = g xs e
g ((vr,WhatConstant c):xs) e = Return [c] :>>= [vr] :-> g xs e
g ((Var v vt,WhatSubs nt _ ft):xs) e = ft (Var v nt) :>>= [Var v vt] :-> g xs e
-- run the body, binding its results to v1.., then convert them with h using
-- the variables numbered after those
g [] e = e :>>= rvs :-> h (zip rvs rts) (drop (length (getType e)) [v1 .. ]) [] where
rvs = zipWith Var [v1 .. ] (getType e)
g _ _ = error "NodeAnalyze.transformFuncs: g bad."
-- build the new result list: keep, drop, or convert each result to its new type
h ((r,WhatUnchanged):xs) vs rs = h xs vs (r:rs)
h ((r,WhatDelete):xs) vs rs = h xs vs rs
h ((r,WhatConstant _):xs) vs rs = h xs vs rs
h ((r,WhatSubs nty tt _):xs) (v:vs) rs = tt r :>>= [Var v nty] :-> h xs vs (Var v nty:rs)
h [] _ rs = Return (reverse rs)
h _ _ _ = error "NodeAnalyze.transformFuncs: h bad."
f _ _ = error "NodeAnalyze.transformFuncs: f bad."
-- j rewrites call sites and stores inside an expression.
-- Stored node: convert its fields with the argument plan of @tagFlipFunction a@.
j app@(BaseOp (StoreNode False) [NodeC a xs]) = res where
res = if isNothing ats' then app else e'
ats = maybe (repeat WhatUnchanged) id ats'
(ats',_) = fn (tagFlipFunction a) (map getType xs) Nothing
lvars = zipWith Var [ v1 .. ] (map getType xs)
e' = Return xs :>>= lvars :-> f (zip lvars ats) []
f ((v,WhatUnchanged):xs) rs = f xs (v:rs)
f ((_,WhatDelete):xs) rs = f xs rs
f ((_,WhatConstant _):xs) rs = f xs rs
f ((Var v oty,WhatSubs nty tt _):xs) rs = tt (Var v oty) :>>= [Var v nty] :-> f xs (Var v nty:rs)
f [] rs = BaseOp (StoreNode False) [NodeC a (reverse rs)]
f _ _ = error "NodeAnalyze.transformFuncs: f bad."
-- Function call: convert the arguments, call with the new return types ts',
-- then rebuild the original results (deleted -> ValUnknown, constant -> c,
-- substituted -> converted back).
-- NOTE(review): lvars and rvars both number their variables from v1 — confirm
-- this reuse is intended.
j app@(App a xs ts) = res where
res = if isNothing ats' && isNothing rts' then app else e'
ats = maybe (repeat WhatUnchanged) id ats'
rts = maybe (repeat WhatUnchanged) id rts'
(ats',rts') = fn a (map getType xs) (Just ts)
lvars = zipWith Var [ v1 .. ] (map getType xs)
e' = Return xs :>>= lvars :-> f (zip lvars ats) []
f ((v,WhatUnchanged):xs) rs = f xs (v:rs)
f ((_,WhatDelete):xs) rs = f xs rs
f ((_,WhatConstant _):xs) rs = f xs rs
f ((Var v oty,WhatSubs nty tt _):xs) rs = tt (Var v oty) :>>= [Var v nty] :-> f xs (Var v nty:rs)
f [] rs = App a (reverse rs) ts' :>>= rvars :-> g (zip rvars' rts) rvars []
f _ _ = error "NodeAnalyze.transformFuncs: f bad."
g [] [] rs = Return (reverse rs)
g ((_,WhatUnchanged):xs) (n:ns) rs = g xs ns (n:rs)
g ((v,WhatDelete):xs) vs rs = Return [ValUnknown (getType v)] :>>= [v] :-> g xs vs (v:rs)
g ((v,WhatConstant c):xs) vs rs = Return [c] :>>= [v] :-> g xs vs (v:rs)
g ((v,WhatSubs _ _ ft):xs) (n:ns) rs = ft n :>>= [v] :-> g xs ns (v:rs)
g _ _ _ = error "NodeAnalyze.transformFuncs: g bad."
-- rvars: results of the new call; rvars': the original results, numbered after rvars
rvars = zipWith Var [ v1 .. ] ts'
rvars' = zipWith Var (drop (length rvars) [ v1 .. ]) ts
ts' = concatMap g (zip ts rts) where
g (t,WhatUnchanged) = [t]
g (t,WhatConstant _) = []
g (t,WhatDelete) = []
g (t,WhatSubs nty _ _) = [nty]
-- let-bound functions are transformed like top-level ones
j Let { expDefs = ds, expBody = e } = grinLet [ updateFuncDefProps d { funcDefBody = snd $ fs (funcDefName d, funcDefBody d) } | d <- ds ] (j e)
j e = runIdentity $ mapExpExp (return . j) e
-- | Planner for 'transformFuncs', driven by the solved analysis @cmap@.
--
-- A 'TyINode' argument or return slot is retyped to 'TyNode' (converting with
-- @promote@ / @demote@) when its solution is WHNF or has no lower bound.
-- Other slots stay unchanged. Functions in @pfuncs@ get no plan at all.
-- Functions in @sfuncs@ get an argument plan only.
fixupFuncs sfuncs pfuncs cmap = decide
  where
    decide fname argTys mretTys
        | fname `Set.member` pfuncs = (Nothing, Nothing)
        | fname `Set.member` sfuncs = (Just argPlan, Nothing)
        | otherwise                 = (Just argPlan, fmap (plan fr) mretTys)
      where
        argPlan = plan fa argTys
        plan slot tys =
            [ bool pnode WhatUnchanged (knownWHNF slot fname (ty, i))
            | (ty, i) <- zip tys [0 ..] ]
    knownWHNF slot fname (ty, i) = case (ty, Map.lookup (slot fname i ty) cmap) of
        (TyINode, Just (ResultJust _ (N WHNF _)))                    -> True
        (TyINode, Just ResultBounded { resultLB = Just (N WHNF _) }) -> True
        (TyINode, Just ResultBounded { resultLB = Nothing })         -> True
        _                                                            -> False
    pnode = WhatSubs TyNode (\v -> BaseOp Promote [v]) (\v -> BaseOp Demote [v])
-- | Rewrite a lambda using the solved analysis @cmap@.
--
-- * @eval x@ becomes @promote x@ when @x@ is known to be in WHNF.
--
-- * @apply@ of a value known to be exactly one 'TagPApp' tag is inlined.
--   When the partial application needs one more argument, it becomes a
--   direct 'App' of the function. Otherwise a new partial-application node
--   with one fewer missing argument is stored.
--
-- Every rewrite ticks a "Stats" counter and is logged when 'verbose' is set.
fixupfs cmap tyEnv l = tickleM f (l::Lam) where
    -- Solved lower bound of a variable. Fails for non-variables, for
    -- variables below 'v0' and for variables missing from @cmap@.
    lupVar (Var v t) = case Map.lookup (vr v t) cmap of
        _ | v < v0 -> fail "nocafyet"
        Just (ResultJust _ lb) -> return lb
        Just ResultBounded { resultLB = Just lb } -> return lb
        Just ResultBounded { resultLB = Nothing } -> return bottom
        _ -> fail "lupVar"
    lupVar _ = fail "lupVar2"
    pstuff x arg n@(N w t) = liftIO $ when verbose (printf "-- %s %s %s\n" x (show arg) (show n))
    f a@(BaseOp Eval [arg]) | Just n <- lupVar arg = case n of
        N WHNF _ -> do
            pstuff "eval" arg n
            Stats.mtick (toAtom "Optimize.NodeAnalyze.eval-promote")
            return (BaseOp Promote [arg])
        _ -> return a
    f a@(BaseOp (Apply ty) (papp:args)) | Just nn <- lupVar papp = case nn of
        N WHNF tset | Only set <- tset, [sv] <- Set.toList set, TagPApp n fn <- tagInfo sv, Just (ts,_) <- findArgsType tyEnv sv -> do
            pstuff "apply" papp nn
            case (n,args) of
                -- one argument missing: unpack the node and call fn directly
                (1,[arg]) -> do
                    Stats.mtick (toAtom "Optimize.NodeAnalyze.apply-inline")
                    let va = Var v1 (getType arg)
                        vars = zipWith Var [ v2 .. ] ts
                    return $ Return [arg,papp] :>>= [va,NodeC sv vars] :-> App fn (vars ++ [va]) ty
                (1,[]) -> do
                    Stats.mtick (toAtom "Optimize.NodeAnalyze.apply-inline")
                    let vars = zipWith Var [ v2 .. ] ts
                    return $ Return [papp] :>>= [NodeC sv vars] :-> App fn vars ty
                -- more missing: store a partial application with one fewer.
                -- These were garbled as `(pn 1)`, i.e. an Int applied to 1;
                -- the count must be `pn - 1`.
                (pn,[arg]) -> do
                    Stats.mtick (toAtom "Optimize.NodeAnalyze.apply-inline")
                    let va = Var v1 (getType arg)
                        vars = zipWith Var [ v2 .. ] ts
                    return $ Return [arg,papp] :>>= [va,NodeC sv vars] :-> dstore (NodeC (partialTag fn (pn - 1)) (vars ++ [va]))
                (pn,[]) -> do
                    Stats.mtick (toAtom "Optimize.NodeAnalyze.apply-inline")
                    let vars = zipWith Var [ v2 .. ] ts
                    return $ Return [papp] :>>= [NodeC sv vars] :-> dstore (NodeC (partialTag fn (pn - 1)) vars)
                _ -> return a
        _ -> return a
    f e = mapExpExp f e
    dstore x = BaseOp (StoreNode True) [x]
-- | Give every function defined by a 'Let' a name not used by any other
-- function in the program. 'App' and 'Call' references in the scope of the
-- 'Let' are rewritten to the new names.
--
-- Reader: map from original to new names for the enclosing 'Let's (starts
-- empty). State: the set of names already taken, seeded with the top-level
-- function names. Fresh numbers come from 'execUniqT', starting at 1.
renameUniqueGrin :: Grin -> Grin
renameUniqueGrin grin = res where
(res,()) = evalRWS (execUniqT 1 ans) ( mempty :: Map.Map Atom Atom) (fromList [ x | (x,_) <- grinFuncs grin ] :: Set.Set Atom)
ans = do tickleM f grin
f (l :-> b) = g b >>= return . (l :->)
-- rename references to let-bound functions currently in scope
g a@App { expFunction = fn } = do
m <- lift ask
case mlookup fn m of
Just fn' -> return a { expFunction = fn' }
_ -> return a
g a@Call { expValue = Item fn t } = do
m <- lift ask
case mlookup fn m of
Just fn' -> return a { expValue = Item fn' t }
_ -> return a
-- pick a new name for each definition, then process the Let with the new
-- mappings added in front of (and so taking precedence over) the outer ones
g (e@Let { expDefs = defs }) = do
(defs',rs) <- liftM unzip $ flip mapM defs $ \d -> do
(nn,rs) <- newName (funcDefName d)
return (d { funcDefName = nn },rs)
local (fromList rs `mappend`) $ mapExpExp g e { expDefs = defs' }
g b = mapExpExp g b
-- Reserve a name: keep @a@ if it is still free. Otherwise try
-- @show a ++ "-" ++ show uniq@ with successive unique numbers until one is
-- not in the taken set. Returns the new name and the (old, new) mapping.
newName a = do
m <- lift get
case member a m of
False -> do lift $ modify (insert a); return (a,(a,a))
True -> do
let cfname = do
uniq <- newUniq
let fname = toAtom $ show a ++ "-" ++ show uniq
if fname `member` (m :: Set.Set Atom) then cfname else return fname
nn <- cfname
lift $ modify (insert nn)
return (nn,(a,nn))
-- | Choose between two values: @bool x y True = x@ and @bool x y False = y@.
-- The argument order is the reverse of "Data.Bool"'s @bool@.
bool x _ True  = x
bool _ y False = y