-- A fast, straightforward points-to analysis,
-- meant to determine which nodes are always in WHNF
-- and to find evals or applies that always
-- apply to a known value.

module Grin.NodeAnalyze(nodeAnalyze) where

import Control.Monad.Identity hiding(join)
import Control.Monad.RWS hiding(join)
import Data.Maybe
import Text.Printf
import qualified Data.Map as Map
import qualified Data.Set as Set

import Grin.Grin hiding(V)
import Grin.Noodle
import Options
import StringTable.Atom
import Support.CanType
import Support.FreeVars
import Support.Tickle
import Util.Gen
import Util.SetLike
import Util.UnionSolve
import Util.UniqueMonad
import qualified Stats

-- | How evaluated a node is known to be.
--
-- The derived 'Ord' instance depends on the constructor order
-- (@WHNF < Lazy@); the 'Fixable' instance for this type uses that order
-- for 'join' ('max'), 'meet' ('min') and 'lte'.
data NodeType
    = WHNF       -- ^ guaranteed to be a WHNF
    | Lazy       -- ^ a suspension, a WHNF, or an indirection to a WHNF
    deriving(Eq,Ord,Show)

-- | The value the analysis computes for each variable: the evaluation
-- state of a node together with the set of tags ('Atom's) it may carry.
-- The tag component is 'Top' when no particular set is known.
data N = N !NodeType (Topped (Set.Set Atom))
    deriving(Eq)

-- | Renders as @<node type>-<tags>@, where an unknown tag set is shown
-- as @[?]@ and a known one as the list of its elements.
instance Show N where
    show (N nodeType tags) = show nodeType ++ "-" ++ rendered where
        rendered = case tags of
            Top -> "[?]"
            Only s -> show (Set.toList s)

-- | 'NodeType' forms a two-point lattice with 'WHNF' at the bottom and
-- 'Lazy' at the top; joins and meets follow the derived ordering.
instance Fixable NodeType where
    isBottom WHNF = True
    isBottom Lazy = False
    isTop Lazy = True
    isTop WHNF = False
    join = max
    meet = min
    eq a b = a == b
    lte = (<=)

-- | Component-wise lattice on 'N': every operation is applied to the
-- node type and to the tag set independently and the results combined.
instance Fixable N where
    isBottom (N nt tags) = isBottom nt && isBottom tags
    isTop (N nt tags) = isTop nt && isTop tags
    join (N a as) (N b bs) = N (a `join` b) (as `join` bs)
    meet (N a as) (N b bs) = N (a `meet` b) (as `meet` bs)
    lte (N a as) (N b bs) = a `lte` b && as `lte` bs
    eq (N a as) (N b bs) = a `eq` b && as `eq` bs
    showFixable = show

-- | A constraint variable: a location ('Va') tagged with its Grin type,
-- or 'VIgnore' for values the analysis does not track
-- ('isGood' is False for 'VIgnore', so constraints mentioning it are dropped).
data V = V Va Ty | VIgnore
    deriving(Eq,Ord)

-- | The locations the analysis assigns values to.
data Va =
    -- | an ordinary Grin variable
    Vr !Var
    -- | the i-th argument slot of the named function
    | Fa !Atom !Int
    -- | the i-th return slot of the named function
    | Fr !Atom !Int
    deriving(Eq,Ord)

-- | Constraint variable for a Grin variable of the given type.
vr :: Var -> Ty -> V
vr = V . Vr

-- | Constraint variable for argument slot @i@ of function @n@.
fa :: Atom -> Int -> Ty -> V
fa n i = V (Fa n i)

-- | Constraint variable for return slot @i@ of function @n@.
fr :: Atom -> Int -> Ty -> V
fr n i = V (Fr n i)

-- | Things that can be asked whether they are of interest to the analysis.
-- The instances in this module answer True for values of type
-- 'TyNode' or 'TyINode' (and for the 'Right' side of an 'Either').
class NodeLike a where
    -- | True when constraints mentioning this value should be kept.
    isGood :: a -> Bool

-- | Only the two node types are tracked by the analysis.
instance NodeLike Ty where
    isGood ty = case ty of
        TyNode -> True
        TyINode -> True
        _ -> False

-- | A value is tracked exactly when its type is.
instance NodeLike Val where
    isGood = isGood . getType

-- | A constraint variable is tracked when its recorded type is;
-- 'VIgnore' never is.
instance NodeLike V where
    isGood var = case var of
        VIgnore -> False
        V _ ty -> isGood ty

-- | One side of a constraint: a variable on the 'Left' is checked,
-- anything on the 'Right' is always accepted.
instance NodeLike (Either V b) where
    isGood = either isGood (const True)

-- | Debug rendering: plain variables print as the corresponding Grin
-- 'Var', argument slots as @(name,index)@, return slots as @(index,name)@
-- and the ignored variable as @IGN@.
instance Show V where
    showsPrec _ VIgnore = showString "IGN"
    showsPrec _ (V loc ty) = case loc of
        Vr v -> shows (Var v ty)
        Fa name ix -> shows (name,ix)
        Fr name ix -> shows (ix,name)

-- | The constraint-generation monad: reads the program's 'TyEnv',
-- accumulates constraints (@C N V@) through the writer, and carries an
-- 'Int' state.
newtype M a = M (RWS TyEnv (C N V) Int a)
    deriving(Monad,Functor,MonadWriter (C N V))

-- | Run a constraint generator against the type environment of the given
-- program (initial state 1) and return only the constraints it emitted.
runM :: Grin -> M a -> C N V
runM grin (M action) = snd (evalRWS action (grinTypeEnv grin) 1)

-- | Run the node analysis over a whole program and rewrite the program
-- using the results.
--
-- Steps:
--
--   1. rename the program with 'renameUniqueGrin';
--   2. collect constraints from every function, every CAF, and a synthetic
--      @\@initcafs@ function built by 'initCafs';
--   3. solve the constraints;
--   4. rewrite expressions with 'fixupfs' (collecting statistics) and then
--      function signatures / call sites with 'transformFuncs'.
--
-- The statistics gathered in step 4 are prepended to the 'grinStats' of
-- the returned program.
{-# NOINLINE nodeAnalyze #-}
nodeAnalyze :: Grin -> IO Grin
nodeAnalyze grin' = do
    let cs = runM grin $ do
            mapM_ doFunc (grinFuncs grin)
            mapM_ docaf (grinCafs grin)
            doFunc (toAtom "@initcafs",[] :-> initCafs grin)
        grin = renameUniqueGrin grin'
        -- every CAF variable is pinned to 'top'
        docaf (v,_) = tell $ Right top `equals` Left (V (Vr v) TyINode)
    -- For tracing, pass 'putStrLn' instead of the no-op logger:
    --(rm,res) <- solve putStrLn cs
    --mapM_ (\ (x,y) -> putStrLn $ show x ++ " -> " ++ show y) (Map.toList rm)
    --mapM_ print (Map.elems res)
    (rm,res) <- solve (const (return ())) cs
    -- 'rm' maps each variable to a key of 'res'; a missing key means the
    -- two maps disagree, so fail with a message that says so instead
    -- of the anonymous 'fromJust' error.
    let lookupResult k = fromMaybe
            (error "Grin.NodeAnalyze.nodeAnalyze: solver returned a representative with no result")
            (Map.lookup k res)
        cmap = Map.map lookupResult rm
    (fixedGrin,stats) <- Stats.runStatT $ tickleM (fixupfs cmap (grinTypeEnv grin)) grin
    return $ transformFuncs (fixupFuncs (grinSuspFunctions grin) (grinPartFunctions grin) cmap) fixedGrin { grinStats = stats `mappend` grinStats fixedGrin }

-- | What to do with the result(s) of an expression while generating
-- constraints in 'doFunc'.
--
-- @Todo b vs@: relate the results to the variables @vs@; when @b@ is True
-- the relation is an equality, when False the variable is only required to
-- be greater than or equal to the result (see @isfn@ in 'doFunc').
-- 'TodoNothing': the results are not recorded.
data Todo = Todo !Bool [V] | TodoNothing

-- | Build an expression that overwrites every CAF variable (as a
-- 'TyINode' variable) with its initial node, one after another, and
-- finally returns nothing.
initCafs :: Grin -> Exp
initCafs grin = foldr overwriteCaf (Return []) (grinCafs grin) where
    overwriteCaf (v,node) rest =
        BaseOp Overwrite [Var v TyINode, node] :>>= [] :-> rest

-- | Emit the constraints for one function definition.
--
-- Constraints are produced through the writer of 'M' ('tell'); nothing is
-- returned.  For a function @name@, the i-th parameter is tracked by the
-- variable @fa name i@ and the i-th return value by @fr name i@.
-- Constraints whose sides are not node-typed (see 'isGood') are silently
-- dropped by the local @equals@ / @isgte@ / @islte@ wrappers.
doFunc :: (Atom,Lam) -> M ()
doFunc (name,arg :-> body) = ans where
    ans :: M ()
    ans = do
        let rts = getType body
        -- bound each TyNode return slot by WHNF
        forMn_ rts $ \ (t,i) -> dVar (fr name i t) t
        -- bound each parameter and tie it to the function's argument slot
        forMn_ arg $ \ (~(Var v vt),i) -> do
            dVar (vr v vt) vt
            tell $ cAnnotate "FunArg" $ Left (fa name i vt) `equals` Left (vr v vt)
        -- the body's results are equal to the return slots
        fn (Todo True [ fr name i t | i <- naturals | t <- rts ]) body
    -- restrict values of TyNode type to be in WHNF
    dVar v TyNode = do
        tell $ Left v `islte` Right (N WHNF Top)
    dVar _ _ = return ()
    -- set concrete values for vars based on their type only
    -- should only be used in patterns
    zVar s v TyNode = tell $ cAnnotate ("zVar - tynode " ++ s) $ Left (vr v TyNode) `equals` Right (N WHNF Top)
    zVar s v t = tell $ cAnnotate ("zVar - inode " ++ s) $ Left (vr v t) `equals` Right top
    -- walk a chain of binds; @ret@ says what to do with the result of the
    -- final expression in the chain
    fn :: Todo -> Exp -> M ()
    fn ret body = f body where
        -- bind to a single variable: the head's result equals that variable
        f (x :>>= [Var v vt] :-> rest) = do
            dVar (vr v vt) vt
            gn (Todo True [vr v vt]) x
            f rest
        -- bind to two or more patterns (each is matched as a Var here)
        f (x :>>= vs@(_:_:_) :-> rest) = do
            vs' <- forM vs $ \ (Var v vt) -> do
                dVar (vr v vt) vt
                return $ vr v vt
            gn (if all (== VIgnore) vs' then TodoNothing else Todo True vs') x
            f rest
        -- any other binding pattern: its free variables get values from
        -- their types alone and the head's result is not recorded
        f (x :>>= v :-> rest) = do
            forM_ (Set.toList $ freeVars v) $ \ (v,vt) -> zVar "Bind" v vt
            gn TodoNothing x
            f rest
        f body = gn ret body
    -- constraint relating a result variable @x@ to a value @y@ according
    -- to the 'Todo' mode; nothing is emitted for untracked variables
    isfn _ x y | not (isGood x) = mempty
    isfn (Todo True  _) x y = cAnnotate "isfn True" $ Left x `equals` y
    isfn (Todo False _) x y = cAnnotate "isfn False" $ Left x `isgte` y
    --isfn (Todo _ _) x y = Left x `isgte` y
    isfn TodoNothing x y =  mempty
    -- guarded versions of the solver's constraint builders: a constraint is
    -- produced only when both sides are tracked
    equals x y | isGood x && isGood y = Util.UnionSolve.equals x y
               | otherwise = mempty
    isgte x y | isGood x && isGood y = Util.UnionSolve.isgte x y
              | otherwise = mempty
    islte x y | isGood x && isGood y = Util.UnionSolve.islte x y
              | otherwise = mempty
    -- handle a single (non-bind) expression
    gn ret head = f head where
        -- a case alternative: pattern variables get values from their types
        fl ret (v :-> body) = do
            forM_ (Set.toList $ freeVars v) $ \ (v,vt) -> zVar "Alt" v vt
            fn ret body
        -- results about which only the type is known
        dunno ty = do
            dres [Right (if TyNode == t then N WHNF Top else top) | t <- ty ]
        -- record the results of this expression against @ret@
        dres res = do
            case ret of
                Todo b vs | length res /= length vs -> error "lengths don't match!"
                Todo b vs -> forM_ (zip vs res) $ \ (v,r) -> tell (isfn ret v r)
                _ -> return ()
        f (_ :>>= _) = error $ "Grin.NodeAnalyze: :>>="
        -- each alternative only contributes a lower bound (Todo False),
        -- since several alternatives feed the same result variables
        f (Case v as)
            | Todo _ n <- ret = mapM_ (fl (Todo False n)) as
            | TodoNothing <- ret = mapM_ (fl TodoNothing) as
        f (BaseOp Eval [x]) = do
            dres [Right (N WHNF Top)]
        f (BaseOp (Apply ty) xs) = do
            mapM_ convertVal xs
            dunno ty
        -- calls: each actual argument flows into the callee's argument slot
        -- and the results come from the callee's return slots
        f (App { expFunction = fn, expArgs = vs, expType = ty }) = do
            vs' <- mapM convertVal vs
            forMn_ (zip vs vs') $ \ ((tv,v),i) -> do
                tell $ v `islte` Left (fa fn i (getType tv))
            dres [Left $ fr fn i t | i <- [ 0 .. ] | t <- ty ]
        f (Call { expValue = Item fn _, expArgs = vs, expType = ty }) = do
            vs' <- mapM convertVal vs
            forMn_ (zip vs vs') $ \ ((tv,v),i) ->  do
                tell $ v `islte` Left (fa fn i (getType tv))
            dres [Left $ fr fn i t | i <- [ 0 .. ] | t <- ty ]
        f (Return x) = do mapM convertVal x >>= dres
        f (BaseOp (StoreNode _) w) = do mapM convertVal w >>= dres
        -- Promote and Demote both require their operand to be WHNF and
        -- pass it through as the result
        f (BaseOp Promote [w]) = do
            ww <- convertVal w
            tell $ ww `islte` Right (N WHNF Top)
            dres [ww]
        f (BaseOp Demote [w]) = do
            ww <- convertVal w
            tell $ ww `islte` Right (N WHNF Top)
            dres [ww]
        f Error {} = return ()
        f Prim { expArgs = as, expType = ty } = mapM_ convertVal as >> dunno ty
        f Alloc { expValue = v } | getType v == TyNode = do
            v' <- convertVal v
            dres [v']
        f Alloc { expValue = v } | getType v == tyINode = do
            convertVal v
            dunno [TyPtr tyINode]
        f NewRegion { expLam = _ :-> body } = fn ret body
        -- overwriting a TyINode variable: the variable must be at least the
        -- value written into it
        f (BaseOp Overwrite [Var vname ty,v]) | ty == TyINode = do
            v' <- convertVal v
            tell $ Left (vr vname ty) `isgte` v'
            dres []
        f e@(BaseOp Overwrite vs) = do mapM_ convertVal vs >> dunno (getType e)
        f e@(BaseOp PokeVal vs) = do mapM_ convertVal vs >> dunno (getType e)
        f e@(BaseOp PeekVal vs) = do mapM_ convertVal vs  >> dunno (getType e)
        -- local definitions are analysed like top-level functions
        f Let { expDefs = ds, expBody = e } = do
            mapM_ doFunc (map (\x -> (funcDefName x, funcDefBody x)) ds)
            fn ret e
        f exp = error $ "NodeAnalyze.f: " ++ show exp

    -- translate a Grin value into one side of a constraint, emitting any
    -- constraints implied by the value itself
    convertVal (Const n@(NodeC _ _)) = convertVal n
    convertVal (Const _) = return $ Right (N WHNF Top)
    convertVal (NodeC t vs) = case tagUnfunction t of
        -- a tag with no associated function: a WHNF node with exactly this tag
        Nothing -> do
            mapM_ convertVal vs
            return $ Right (N WHNF (Only $ Set.singleton t))
        -- a tag that refers to function @fn@: the fields flow into @fn@'s
        -- argument slots, and the @n@ argument slots after the supplied
        -- fields are set to at least 'top'
        Just (n,fn) -> do
            vs' <- mapM convertVal vs
            forMn_ (zip vs vs') $ \ ((vt,v),i) -> do
                tell $ v `islte` Left (fa fn i (getType vt))
            forM_ [0 .. n - 1 ] $ \i -> do
               tell $ Right top `islte` Left (fa fn (length vs + i) TyINode)
            -- n == 0 yields Lazy, anything else WHNF
            return $ Right (N (if n == 0 then Lazy else WHNF) (Only $ Set.singleton t))
    convertVal (Var v t) = return $ Left (vr v t)
    -- any other node-typed value: nothing is known about it
    convertVal v | isGood v = return $ Right (N Lazy Top)
    convertVal Lit {} = return $ Left VIgnore
    convertVal ValPrim {} = return $ Left VIgnore
    convertVal Index {} = return $ Left VIgnore
    convertVal Item {} = return $ Left VIgnore
    convertVal ValUnknown {} = return $ Left VIgnore
    convertVal v = error $ "convertVal " ++ show v

-- | Least value: in WHNF with an empty set of possible tags.
bottom :: N
bottom = N WHNF (Only Set.empty)

-- | Greatest value: possibly lazy, with unknown tags.
top :: N
top = N Lazy Top

-- | An instruction for how 'transformFuncs' should treat one argument or
-- return position of a function.
data WhatToDo
    -- | drop the position
    = WhatDelete
    -- | keep the position as it is
    | WhatUnchanged
    -- | drop the position; where the value is still needed it is replaced
    -- by the given constant
    | WhatConstant Val
    -- | change the position's type to the given 'Ty'.  The first function
    -- builds the expression converting a value of the old type to the new
    -- one, the second converts back from the new type to the old one.
    | WhatSubs Ty (Val -> Exp) (Val -> Exp)

--isWhatUnchanged WhatUnchanged = True
--isWhatUnchanged _ = False

-- | Rewrite function signatures and the places that use them.
--
-- The first argument is a decision function: given a function name, the
-- types of its arguments and (optionally) the types of its results, it
-- returns what to do with each argument and each result position
-- ('Nothing' meaning "leave all of them alone").
--
-- Every function definition gets its parameter list and return values
-- adjusted, and every call ('App') and every @StoreNode False@ of a node
-- inside the bodies is adjusted to match.  The type environment of the
-- result is extended with the rewritten functions.
transformFuncs :: (Atom -> [Ty] -> Maybe [Ty] -> (Maybe [WhatToDo],Maybe [WhatToDo])) -> Grin -> Grin
transformFuncs fn grin = grin'' where
    grin'' =  grin' { grinTypeEnv = extendTyEnv (grinFunctions grin') (grinTypeEnv grin') }
    grin' = setGrinFunctions (nfs $ grinFuncs grin) grin
    nfs ds = map fs ds
    -- rewrite one definition according to the decision for its name
    fs (n,l@(ps :-> e)) = (n,f (fn n (map getType ps) (Just $ getType e)) l)
    -- nothing to change in the signature: only rewrite the uses in the body
    f (Nothing,Nothing) (p :-> e) = p :-> j e
    f (Just ats,rts') (p :-> e) = p' :-> e' where
        -- missing return information means every return is unchanged
        rts = maybe (map (const WhatUnchanged) (getType e)) id rts'
        -- the new parameter list
        p' = concatMap f (zip p ats) where
            f (v,WhatUnchanged) = [v]
            f (_,WhatDelete) = []
            f (_,WhatConstant _) = []
            f (Var v _,WhatSubs nty _ _) = [Var v nty]
            f _ = error "NodeAnalyze.transformFuncs: f bad."
        e' =  g (zip p ats) (j e)
        -- prefix the body with bindings that re-create the original
        -- parameters from the new ones
        g ((_,WhatUnchanged):xs) e = g xs e
        g ((_,WhatDelete):xs) e = g xs e
        g ((vr,WhatConstant c):xs) e = Return [c] :>>= [vr] :-> g xs e
        g ((Var v vt,WhatSubs nt _ ft):xs) e = ft (Var v nt) :>>= [Var v vt] :-> g xs e
        -- then bind the body's results to variables numbered from v1 and
        -- adjust them with 'h'; variables numbered after those are handed
        -- to 'h' for the converted results
        g [] e = e :>>= rvs :-> h (zip rvs rts) (drop (length (getType e)) [v1 .. ]) [] where
            rvs = zipWith Var [v1 .. ] (getType e)
        g _ _ = error "NodeAnalyze.transformFuncs: g bad."
        -- build the new return list (accumulated in reverse in @rs@)
        h ((r,WhatUnchanged):xs) vs rs = h xs vs (r:rs)
        h ((r,WhatDelete):xs) vs rs = h xs vs rs
        h ((r,WhatConstant _):xs) vs rs = h xs vs rs
        h ((r,WhatSubs nty tt _):xs) (v:vs) rs = tt r :>>= [Var v nty] :-> h xs vs (Var v nty:rs)
        h [] _ rs = Return (reverse rs)
        h _ _ _ = error "NodeAnalyze.transformFuncs: h bad."
    f _ _ = error "NodeAnalyze.transformFuncs: f bad."

    -- rewrite the uses of functions inside an expression

    -- storing a node: its fields are adjusted like the arguments of the
    -- function named by 'tagFlipFunction' of its tag
    j app@(BaseOp (StoreNode False) [NodeC a xs]) = res where
        res = if isNothing ats' then app else e'
        ats = maybe (repeat WhatUnchanged) id ats'
        (ats',_) = fn (tagFlipFunction a) (map getType xs) Nothing
        -- the fields are first bound to variables numbered from v1
        lvars = zipWith Var [ v1 .. ] (map getType xs)
        e' = Return xs :>>= lvars :-> f (zip lvars ats) []

        f ((v,WhatUnchanged):xs) rs = f xs (v:rs)
        f ((_,WhatDelete):xs) rs = f xs rs
        f ((_,WhatConstant _):xs) rs = f xs rs
        f ((Var v oty,WhatSubs nty tt _):xs) rs = tt (Var v oty) :>>= [Var v nty] :-> f xs (Var v nty:rs)
        f [] rs = BaseOp (StoreNode False) [NodeC a (reverse rs)]
        f _ _ = error "NodeAnalyze.transformFuncs: f bad."

    -- a direct call: adjust the arguments, call with the new result types,
    -- then re-create results of the original types
    j app@(App a xs ts) = res where
        res = if isNothing ats' && isNothing rts'  then app else e'
        ats = maybe (repeat WhatUnchanged) id ats'
        rts = maybe (repeat WhatUnchanged) id rts'
        (ats',rts') = fn a (map getType xs) (Just ts)
        lvars = zipWith Var [ v1 .. ] (map getType xs)
        e' = Return xs :>>= lvars :-> f (zip lvars ats) []

        -- build the new argument list (accumulated in reverse in @rs@)
        f ((v,WhatUnchanged):xs) rs = f xs (v:rs)
        f ((_,WhatDelete):xs) rs = f xs rs
        f ((_,WhatConstant _):xs) rs = f xs rs
        f ((Var v oty,WhatSubs nty tt _):xs) rs = tt (Var v oty) :>>= [Var v nty] :-> f xs (Var v nty:rs)
        f [] rs = App a (reverse rs) ts' :>>= rvars :-> g (zip rvars' rts) rvars []
        f _ _ = error "NodeAnalyze.transformFuncs: f bad."

        -- rebuild a result list of the original shape: deleted results
        -- become 'ValUnknown', constant ones the constant, substituted ones
        -- are converted back
        g [] [] rs = Return (reverse rs)
        g ((_,WhatUnchanged):xs) (n:ns) rs = g xs ns (n:rs)
        g ((v,WhatDelete):xs) vs rs = Return [ValUnknown (getType v)] :>>= [v] :-> g xs vs (v:rs)
        g ((v,WhatConstant c):xs) vs rs = Return [c] :>>= [v] :-> g xs vs (v:rs)
        g ((v,WhatSubs _ _ ft):xs) (n:ns) rs = ft n :>>= [v] :-> g xs ns (v:rs)
        g _ _ _ = error "NodeAnalyze.transformFuncs: g bad."

        -- rvars: what the rewritten call returns; rvars': variables of the
        -- original result types, numbered after rvars
        rvars = zipWith Var [ v1 .. ] ts'
        rvars' = zipWith Var (drop (length rvars) [ v1 .. ]) ts
        -- result types of the rewritten call
        ts' = concatMap g (zip ts rts) where
            g (t,WhatUnchanged) = [t]
            g (t,WhatConstant _) = []
            g (t,WhatDelete) = []
            g (t,WhatSubs nty _ _) = [nty]

    -- local definitions are rewritten like top-level ones
    j Let { expDefs = ds, expBody = e } =  grinLet [ updateFuncDefProps d { funcDefBody = snd $ fs (funcDefName d, funcDefBody d) } | d <- ds ] (j e)
    j e = runIdentity $ mapExpExp (return . j) e

-- | Build the decision function handed to 'transformFuncs' from the
-- solved analysis results.
--
-- For a function in @pfuncs@ nothing is changed.  For a function in
-- @sfuncs@ only the arguments are considered.  Otherwise both arguments
-- and (when return types are supplied) return values are considered.
-- A position is switched from 'TyINode' to 'TyNode' (via Promote/Demote)
-- when its slot in @cmap@ is known to be WHNF, or is bounded with no lower
-- bound; every other position is left unchanged.
fixupFuncs sfuncs pfuncs cmap = decide where
    decide fname argTys mretTys
        | fname `Set.member` pfuncs = (Nothing, Nothing)
        | fname `Set.member` sfuncs = (Just argActions, Nothing)
        | otherwise = (Just argActions, fmap retActions mretTys)
      where
        argActions = actionsFor fa argTys
        retActions retTys = actionsFor fr retTys
        -- one action per position, looked up through the given slot maker
        actionsFor slot tys =
            [ if knownWhnf slot fname ty ix then promoteNode else WhatUnchanged
            | (ty,ix) <- zip tys [0 ..] ]
    -- is the slot a TyINode whose solved value allows treating it as WHNF?
    knownWhnf slot fname ty ix = case (ty, Map.lookup (slot fname ix ty) cmap) of
        (TyINode, Just (ResultJust _ (N WHNF _))) -> True
        (TyINode, Just ResultBounded { resultLB = Just (N WHNF _) }) -> True
        (TyINode, Just ResultBounded { resultLB = Nothing }) -> True
        _ -> False
    promoteNode = WhatSubs TyNode (\v -> BaseOp Promote [v]) (\v -> BaseOp Demote [v])

-- | Rewrite the expressions of a lambda using the solved results in @cmap@.
--
-- Two rewrites are performed, each recording a statistic:
--
--   * an @Eval@ of a variable whose value is known to be WHNF becomes a
--     @Promote@ of that variable;
--   * an @Apply@ whose first operand is a variable known to be WHNF with
--     exactly one possible tag, that tag being a partial application
--     ('TagPApp') whose argument types are found in @tyEnv@, is replaced by
--     code that unpacks the node and either calls the function directly
--     (when the tag's count is 1) or stores the next partial-application
--     node (otherwise).  Only zero or one extra argument is handled.
--
-- Everything else is traversed with 'mapExpExp'.
fixupfs cmap tyEnv l = tickleM f (l::Lam) where
    -- the solved lower bound of a variable, in 'Maybe' (used via pattern
    -- guards below); a bounded result without a lower bound counts as 'bottom'
    lupVar (Var v t) =  case Map.lookup (vr v t) cmap of
        -- variables numbered below v0 are never looked up
        -- NOTE(review): the message suggests these are CAFs — confirm
        _ | v < v0 -> fail "nocafyet"
        Just (ResultJust _ lb) -> return lb
        Just ResultBounded { resultLB = Just lb } -> return lb
        Just ResultBounded { resultLB = Nothing } -> return bottom
        _ -> fail "lupVar"
    lupVar _ = fail "lupVar2"
    -- when 'verbose' is set, print what was rewritten and why
    pstuff x arg n@(N w t) = liftIO $ when verbose (printf "-- %s %s %s\n" x (show arg) (show n))
    f a@(BaseOp Eval [arg]) | Just n <- lupVar arg = case n of
        N WHNF _ -> do
            pstuff "eval" arg n
            Stats.mtick (toAtom "Optimize.NodeAnalyze.eval-promote")
            return (BaseOp Promote [arg])
        _ -> return a
    f a@(BaseOp (Apply ty) (papp:args)) | Just nn <- lupVar papp = case nn of
        N WHNF tset | Only set <- tset, [sv] <- Set.toList set, TagPApp n fn <- tagInfo sv, Just (ts,_) <- findArgsType tyEnv sv -> do
            pstuff "apply" papp nn
            -- v1 holds the extra argument (if any); the node's fields are
            -- bound to variables numbered from v2
            case (n,args) of
                -- count 1 with one extra argument: call fn directly
                (1,[arg]) -> do
                    Stats.mtick (toAtom "Optimize.NodeAnalyze.apply-inline")
                    let va = Var v1 (getType arg)
                        vars = zipWith Var [ v2 .. ] ts
                    return $ Return [arg,papp] :>>= [va,NodeC sv vars] :-> App fn (vars ++ [va]) ty
                -- count 1 with no extra argument: call fn with the fields
                (1,[]) -> do
                    Stats.mtick (toAtom "Optimize.NodeAnalyze.apply-inline")
                    let vars = zipWith Var [ v2 .. ] ts
                    return $ Return [papp] :>>= [NodeC sv vars] :-> App fn vars ty
                -- other counts: store the partial application with the
                -- count reduced by one
                (pn,[arg]) -> do
                    Stats.mtick (toAtom "Optimize.NodeAnalyze.apply-inline")
                    let va = Var v1 (getType arg)
                        vars = zipWith Var [ v2 .. ] ts
                    return $ Return [arg,papp] :>>= [va,NodeC sv vars] :-> dstore (NodeC (partialTag fn (pn - 1)) (vars ++ [va]))
                (pn,[]) -> do
                    Stats.mtick (toAtom "Optimize.NodeAnalyze.apply-inline")
                    let vars = zipWith Var [ v2 .. ] ts
                    return $ Return [papp] :>>= [NodeC sv vars] :-> dstore (NodeC (partialTag fn (pn - 1)) vars)
                -- more than one extra argument: leave the apply alone
                _ -> return a
        _ -> return a
    f e = mapExpExp f e

-- | Wrap a single value in a @StoreNode True@ operation.
dstore :: Val -> Exp
dstore node = BaseOp (StoreNode True) (node : [])

-- | Rename the functions defined in @Let@ expressions so that their names
-- do not clash with names already in use, and update 'App' / 'Call'
-- references to the renamed functions.
--
-- The reader environment maps original names to their replacements within
-- the current scope; the state is the set of names in use, initially the
-- names of the program's top-level functions.
renameUniqueGrin :: Grin -> Grin
renameUniqueGrin grin = res where
    (res,()) = evalRWS (execUniqT 1 ans) ( mempty :: Map.Map Atom Atom) (fromList [ x | (x,_) <- grinFuncs grin ] :: Set.Set Atom)
    ans = do tickleM f grin
    f (l :-> b) = g b >>= return . (l :->)
    -- references are redirected when the environment has a replacement
    g a@App  { expFunction = fn } = do
        m <- lift ask
        case mlookup fn m of
            Just fn' -> return a { expFunction = fn' }
            _ -> return a
    g a@Call { expValue = Item fn t } = do
        m <- lift ask
        case mlookup fn m of
            Just fn' -> return a { expValue = Item fn' t }
            _ -> return a
    -- each local definition gets a name from 'newName'; the rest of the
    -- expression is processed with the renamings added to the environment
    g (e@Let { expDefs = defs }) = do
        (defs',rs) <- liftM unzip $ flip mapM defs $ \d -> do
            (nn,rs) <- newName (funcDefName d)
            return (d { funcDefName = nn },rs)
        local (fromList rs `mappend`) $  mapExpExp g e { expDefs = defs' }
    g b = mapExpExp g b
    -- returns the name to use together with an (old,new) pair for the
    -- environment, and records the chosen name as used
    newName a = do
        m <- lift get
        case member a m of
            -- not in use yet: keep the name
            False -> do lift $ modify (insert a); return (a,(a,a))
            -- already in use: append "-<unique>" until the result is unused
            True -> do
            let cfname = do
                uniq <- newUniq
                let fname = toAtom $ show a  ++ "-" ++ show uniq
                if fname `member` (m :: Set.Set Atom) then cfname else return fname
            nn <- cfname
            lift $ modify (insert nn)
            return (nn,(a,nn))

-- | Select between two values: the first when the flag is True, the
-- second when it is False.  Note that the argument order is the opposite of
-- @Data.Bool.bool@.
bool :: a -> a -> Bool -> a
bool whenTrue whenFalse flag = case flag of
    True -> whenTrue
    False -> whenFalse