module Grin.Noodle where

-- various routines for manipulating and exploring grin code.

import Control.Monad.Writer
import qualified Data.Set as Set

import C.Prims
import Debug.Trace
import Grin.Grin
import Options(flint)
import StringTable.Atom(Atom())
import Support.CanType
import Support.FreeVars
import Support.Tickle
import Util.GMap
import Util.Gen
import Util.HasSize
import Util.SetLike

-- | Push the lambda @lam@ into every tail position of an expression.
--
-- Rather than producing @te :>>= lam@ directly, the lambda is pushed into the
-- alternatives of a 'Case', the lambdas of a 'MkCont', the continuation of a
-- bind, and both the body and the local function bodies of a normal 'Let'.
-- A tail call to a function bound by an enclosing normal 'Let' is kept, with
-- its result type replaced by the type of @lam@'s body (that function's own
-- tails are rewritten instead); an 'Error' is retyped the same way. Anything
-- else, including a non-normal 'Let', is simply followed by @:>>= lam@.
--
-- When 'flint' is enabled it is an error to push @lam@ under a binder that
-- binds one of @lam@'s free variables.
modifyTail :: Lam -> Exp -> Exp
modifyTail lam@(_ :-> lb) te = f (sempty :: GSet Atom) te where
    lamFV = freeVars lam :: GSet Var
    resTy = getType lb
    -- @lf@ is the set of local functions whose tails are being rewritten.
    f lf e | False && trace ("modifyTail: " ++ show (lf,e)) False = undefined
    f _ (Error s _) = Error s resTy
    f lf (Case x ls) = Case x (map (g lf) ls)
    f _ lt@Let { expIsNormal = False } = lt :>>= lam
    f lf lt@Let { expDefs = defs, expBody = body, expIsNormal = True } = updateLetProps lt { expBody = f nlf body, expDefs = defs' } where
        nlf = lf `union` fromList (map funcDefName defs)
        defs' = [ updateFuncDefProps d { funcDefBody = g nlf (funcDefBody d) } | d <- defs ]
    -- the MkCont's own lambda is named innerLam so it does not shadow @lam@
    f lf mk@MkCont { expLam = innerLam, expCont = cont } = mk { expLam = g lf innerLam, expCont = g lf cont }
    -- go through 'g' so the flint capture check also covers bind binders
    f lf (e1 :>>= b) = e1 :>>= g lf b
    f lf (App a as _) | a `member` lf = App a as resTy
    f _ e = e :>>= lam
    g _ (p :-> _) | flint && not (isEmpty $ intersection (freeVars p) lamFV) = error "modifyTail: lam floated inside bad scope"
    g lf (p :-> e) = p :-> f lf e

instance Tickleable Exp Lam where
    tickleM = mapBodyM
instance Tickleable Exp Exp where
    tickleM = mapExpExp
instance Tickleable Val Exp where
    tickleM = mapExpVal
instance Tickleable Val Val where
    tickleM = mapValVal
    tickleM_ = mapValVal_
-- Visit the body of every top-level function of a program, ignoring names.
instance Tickleable Lam Grin where
    tickleM f = mapGrinFuncsM (const f)
instance Tickleable Lam FuncDef where
    tickleM = funcDefBody_uM
-- Visit a function definition's name and body together; the derived
-- properties are recomputed with 'updateFuncDefProps' afterwards.
instance Tickleable (Atom,Lam) FuncDef where
    tickleM f fd@FuncDef { funcDefName = n, funcDefBody = b } = do
        -- the do-block is indented past 'tickleM' so it parses without
        -- relying on the NondecreasingIndentation extension
        (n', b') <- f (n, b)
        return $ updateFuncDefProps fd { funcDefBody = b', funcDefName = n' }

-- | Run a monadic transformation over the body of a lambda; the argument
-- pattern is kept as is.
mapBodyM :: Monad m => (Exp -> m Exp) -> Lam -> m Lam
mapBodyM fn (pat :-> body) = liftM (pat :->) (fn body)

-- | Run a monadic transformation over the 'Val's stored directly in an
-- expression.
--
-- The values visited are:
--
-- * the arguments of 'App', 'BaseOp', 'Return' and 'Prim';
-- * the value and count of an 'Alloc';
-- * the scrutinee of a 'Case'.
--
-- Nested expressions (for example 'Case' alternatives) are not visited, and
-- every other expression is returned unchanged.
mapExpVal :: Monad m => (Val -> m Val) -> Exp -> m Exp
mapExpVal fn ex = case ex of
    App a vs t -> do
        vs' <- mapM fn vs
        return (App a vs' t)
    BaseOp op vs -> liftM (BaseOp op) (mapM fn vs)
    Return vs -> liftM Return (mapM fn vs)
    Prim p vs t -> do
        vs' <- mapM fn vs
        return (Prim p vs' t)
    Alloc { expValue = val, expCount = cnt } -> do
        val' <- fn val
        cnt' <- fn cnt
        return ex { expValue = val', expCount = cnt' }
    Case scrut alts -> do
        scrut' <- fn scrut
        return (Case scrut' alts)
    _ -> return ex

-- | Run a monadic transformation over the immediate sub-values of a 'Val'.
--
-- The sub-values are:
--
-- * the fields of a 'NodeC';
-- * both components of an 'Index';
-- * the payload of a 'Const';
-- * the arguments of a 'ValPrim'.
--
-- Values without sub-values are returned unchanged.
mapValVal :: Monad m => (Val -> m Val) -> Val -> m Val
mapValVal fn val = case val of
    NodeC tag vs -> liftM (NodeC tag) (mapM fn vs)
    Index a b -> do
        a' <- fn a
        b' <- fn b
        return (Index a' b')
    Const v -> liftM Const (fn v)
    ValPrim p vs ty -> do
        vs' <- mapM fn vs
        return (ValPrim p vs' ty)
    _ -> return val

-- | Run a monadic action on each immediate sub-value of a 'Val' and discard
-- the results.
--
-- The sub-values are the fields of a 'NodeC', both components of an 'Index',
-- the payload of a 'Const' and the arguments of a 'ValPrim'. Any other value
-- does nothing.
mapValVal_ :: Monad m => (Val -> m a) -> Val -> m ()
mapValVal_ fn val = case val of
    NodeC _ vs -> mapM_ fn vs
    Index a b -> do
        _ <- fn a
        _ <- fn b
        return ()
    Const v -> fn v >> return ()
    ValPrim _ vs _ -> mapM_ fn vs
    _ -> return ()

-- | Run a monadic transformation over every 'Lam' stored directly in an
-- expression, without recursing further.
--
-- The lambdas visited are:
--
-- * the continuation of a bind;
-- * each 'Case' alternative;
-- * each local function body of a 'Let';
-- * the lambdas of 'NewRegion' and 'MkCont'.
--
-- Other expressions are returned unchanged. A rebuilt 'Let' has its
-- properties recomputed with 'updateLetProps'.
mapExpLam :: Monad m => (Lam -> m Lam) -> Exp -> m Exp
mapExpLam fn ex = case ex of
    a :>>= lam -> liftM (a :>>=) (fn lam)
    Case scrut alts -> liftM (Case scrut) (mapM fn alts)
    Let { expDefs = defs } -> do
        defs' <- mapM updateDef defs
        return $ updateLetProps ex { expDefs = defs' }
    NewRegion { expLam = body } -> do
        body' <- fn body
        return ex { expLam = body' }
    MkCont { expCont = cont, expLam = body } -> do
        cont' <- fn cont
        body' <- fn body
        return ex { expCont = cont', expLam = body' }
    _ -> return ex
  where
    -- transform one local definition and refresh its derived properties
    updateDef d = do
        body' <- fn (funcDefBody d)
        return $ updateFuncDefProps d { funcDefBody = body' }

-- | Run a monadic transformation over the immediate sub-expressions of an
-- expression, without recursing further.
--
-- The sub-expressions are:
--
-- * both sides of a bind;
-- * the body and local function bodies of a 'Let';
-- * the body of 'GcRoots';
-- * the bodies of the lambdas reached through 'mapExpLam'.
mapExpExp :: Monad m => (Exp -> m Exp) -> Exp -> m Exp
mapExpExp fn ex = case ex of
    a :>>= lam -> do
        a' <- fn a
        lam' <- onBody lam
        return (a' :>>= lam')
    Let { expBody = body } -> do
        body' <- fn body
        lt <- mapExpLam onBody ex { expBody = body' }
        return (updateLetProps lt)
    GcRoots vs body -> liftM (GcRoots vs) (fn body)
    _ -> mapExpLam onBody ex
  where
    -- apply the transformation to a lambda's body, keeping its pattern
    onBody (pat :-> body) = liftM (pat :->) (fn body)

-- | Run a monadic transformation over the body expression of each function
-- definition. Argument patterns are kept, and derived properties are
-- recomputed with 'updateFuncDefProps'.
mapFBodies :: Monad m => (Exp -> m Exp) -> [FuncDef] -> m [FuncDef]
mapFBodies fn = mapM step where
    step fd = case funcDefBody fd of
        args :-> body -> do
            body' <- fn body
            return $ updateFuncDefProps fd { funcDefBody = args :-> body' }

-- | Run a monadic transformation over the whole 'Lam' body of a function
-- definition, recomputing derived properties with 'updateFuncDefProps'.
funcDefBody_uM :: Monad m => (Lam -> m Lam) -> FuncDef -> m FuncDef
funcDefBody_uM fn fd@FuncDef { funcDefBody = body } =
    liftM (\body' -> updateFuncDefProps fd { funcDefBody = body' }) (fn body)

-- | Replace the top-level function definitions of a 'Grin' program. Unlike
-- 'setGrinFunctions', this stores the given 'FuncDef's directly.
grinFunctions_s :: [FuncDef] -> Grin -> Grin
grinFunctions_s newFuncs grin = grin { grinFunctions = newFuncs }

--------------------------
-- examining and reporting
--------------------------

-- | Find the node tags an expression can return in tail position.
--
-- Succeeds with the tags of all 'NodeC' values returned in tail position,
-- provided every tail does one of the following:
--
-- * returns a single manifest node;
-- * raises an 'Error';
-- * tail-calls a function bound by an enclosing normal 'Let' (the bodies of
--   such functions are inspected instead).
--
-- Any other tail fails with @"not manifest node"@.
isManifestNode :: Monad m => Exp -> m [Atom]
isManifestNode e = f (sempty :: GSet Atom) e where
    -- @lf@ holds the local functions whose bodies are inspected directly.
    f lf _ | False && trace ("isManifestNode: " ++ show lf) False = undefined
    f _ (Return [NodeC t _]) = return [t]
    f _ Error {} = return []
    f lf (App a _ _) | a `member` lf = return []
    f lf Let { expBody = body, expIsNormal = False } = f lf body
    f lf Let { expBody = body, expDefs = defs, expIsNormal = True } = do
        let nlf = lf `union` fromList (map funcDefName defs)
        xs <- mapM (f nlf . lamExp . funcDefBody) defs
        b <- f nlf body
        return (concat (b:xs))
    f lf (Case _ ls) = do
        cs <- Prelude.mapM (f lf) [ alt | _ :-> alt <- ls ]
        return $ concat cs
    -- keep the local-function set when descending into a bind's continuation;
    -- restarting from an empty set would reject tail calls to local functions
    f lf (_ :>>= _ :-> rest) = f lf rest
    f _ _ = fail "not manifest node"

-- | Is a Val constant?
--
-- The constant values are:
--
-- * literals, 'Const' values and 'ValPrim's;
-- * nodes and 'Index' values whose components are all constant;
-- * variables numbered below 'v0' (presumably the range reserved for global
--   variables; confirm against Grin.Grin).
--
-- Everything else is not constant.
valIsConstant :: Val -> Bool
valIsConstant val = case val of
    NodeC _ fields -> all valIsConstant fields
    Index base off -> valIsConstant base && valIsConstant off
    Var v _ -> v < v0
    Lit {} -> True
    Const {} -> True
    ValPrim {} -> True
    _ -> False

-- | NOPs ('Promote' and 'Demote' base operations) will not produce any code
-- at run-time, so we can tail-call through them.
isNop :: Exp -> Bool
isNop ex = case ex of
    BaseOp op _ -> case op of
        Promote -> True
        Demote -> True
        _ -> False
    _ -> False

-- | Is the expression safe to drop when its result is unused?
--
-- The omittable expressions are:
--
-- * the base operations listed in 'omittableOp';
-- * 'Alloc' and 'Return';
-- * primitives that 'primIsCheap' accepts;
-- * a 'Case' whose alternatives are all omittable;
-- * a 'Let' whose body is omittable;
-- * a bind whose two sides are both omittable.
isOmittable :: Exp -> Bool
isOmittable ex = case ex of
    BaseOp op _ -> omittableOp op
    Alloc {} -> True
    Return {} -> True
    Prim { expPrimitive = aprim } -> primIsCheap aprim
    Case _ alts -> and [ isOmittable body | _ :-> body <- alts ]
    Let { expBody = body } -> isOmittable body
    e1 :>>= _ :-> e2 -> isOmittable e1 && isOmittable e2
    _ -> False
  where
    omittableOp op = case op of
        Promote -> True
        Demote -> True
        PeekVal -> True
        ReadRegister -> True
        NewRegister -> True
        -- omittable because if we don't use the returned gc context, then we don't need to push to begin with
        GcPush -> True
        StoreNode _ -> True
        _ -> False

-- | A more permissive variant of 'isOmittable'.
--
-- 'Overwrite', 'PokeVal' and 'WriteRegister' base operations are also
-- treated as omittable, including inside binds and 'Case' alternatives.
-- Every other expression defers to 'isOmittable'.
isErrOmittable :: Exp -> Bool
isErrOmittable ex = case ex of
    BaseOp Overwrite _ -> True
    BaseOp PokeVal _ -> True
    BaseOp WriteRegister _ -> True
    e1 :>>= _ :-> e2 -> isErrOmittable e1 && isErrOmittable e2
    Case _ alts -> and [ isErrOmittable body | _ :-> body <- alts ]
    _ -> isOmittable ex

-- | Collect the functions called by an expression, split into
-- @(tail called, non tail called)@.
--
-- A call followed only by a NOP (see 'isNop') that passes its result
-- straight through still counts as a tail call. For a nested 'Let', the
-- precomputed 'expFuncCalls' are used.
collectFuncs :: Exp -> (Set.Set Atom,Set.Set Atom)
collectFuncs ex = runWriter (cfunc ex) where
    -- the result of 'cfunc' holds tail calls; the Writer log holds non-tail calls
    clfunc (_ :-> r) = cfunc r
    cfunc e | False && trace ("collectFuncs: " ++ show e) False = undefined
    cfunc (e :>>= v :-> op@(BaseOp _ v')) | isNop op && v == v' = cfunc e
    cfunc (e :>>= y) = do
        -- calls in the first half of a bind are not in tail position
        xs <- cfunc e
        tell xs
        clfunc y
    cfunc (App a _ _) = return (singleton a)
    cfunc (Case _ as) = do
        rs <- mapM clfunc as
        return (mconcat rs)
    cfunc Let { expFuncCalls = (tailCalls,nonTailCalls) } = do
        tell nonTailCalls
        return tailCalls
    cfunc Error {} = return mempty
    cfunc Prim {} = return mempty
    cfunc Return {} = return mempty
    cfunc BaseOp {} = return mempty
    cfunc Alloc {} = return mempty
    cfunc GcRoots { expBody = b } = cfunc b
    cfunc NewRegion { expLam = l } = clfunc l
    cfunc MkCont { expCont = l1, expLam = l2 } = do
        a <- clfunc l1
        b <- clfunc l2
        return (a `mappend` b)
    cfunc x = error $ "Grin.Noodle.collectFuncs: unknown: " ++ show x

-- | Build a 'Let' that binds the given local function definitions around a
-- body.
--
-- The derived fields ('expNonNormal', 'expIsNormal', 'expFuncCalls') start
-- out undefined and are filled in by 'updateLetProps'. With no definitions,
-- the result is just the body.
grinLet :: [FuncDef] -> Exp -> Exp
grinLet defs body = updateLetProps skeleton where
    skeleton = Let
        { expDefs = defs
        , expBody = body
        , expInfo = mempty
        , expNonNormal = undefined
        , expIsNormal = undefined
        , expFuncCalls = undefined
        }

-- | Recompute the derived properties of a 'Let' from its body and local
-- definitions.
--
-- * A 'Let' with no definitions collapses to its body.
-- * 'expFuncCalls' holds the (tail, non-tail) calls made by the body and the
--   definitions, excluding the let's own functions.
-- * 'expNonNormal' is the set of the let's own functions that are called in
--   non-tail position. The let is normal ('expIsNormal') when that set is
--   empty.
--
-- Any other expression is returned unchanged.
updateLetProps :: Exp -> Exp
updateLetProps lt = case lt of
    Let { expDefs = [], expBody = body } -> body
    Let { expBody = body, expDefs = defs } ->
        let localNames = fromList (map funcDefName defs)
            (tailCalls, nonTailCalls) = mconcatMap collectFuncs (body : map (lamExp . funcDefBody) defs)
            notNormal = nonTailCalls `intersection` localNames
        in lt { expFuncCalls = (tailCalls \\ localNames, nonTailCalls \\ localNames)
              , expNonNormal = notNormal
              , expIsNormal = isEmpty notNormal
              }
    _ -> lt

-- | One way an expression can return, as reported by 'getReturnInfo'.
-- Constructor order is significant for the derived 'Ord' instance.
data ReturnInfo
    = ReturnNode (Maybe Atom,[Ty])  -- ^ a node: its tag (when known) and field types
    | ReturnConst Val               -- ^ a constant value
    | ReturnCalls Atom              -- ^ a tail call to the named function
    | ReturnOther                   -- ^ anything not otherwise classified
    | ReturnError                   -- ^ an 'Error'
    deriving(Eq,Ord)

-- | Classify every tail of an expression with a 'ReturnInfo'.
--
-- Descends through 'Case' alternatives, bind continuations and 'Let' bodies.
-- For a normal 'Let', the bodies of its local functions are classified as
-- well, and tail calls to those local functions are not reported.
getReturnInfo :: Exp -> [ReturnInfo]
getReturnInfo e = execWriter (go (sempty :: GSet Atom) e) where
    emit x = tell [x]
    -- @locals@ holds the functions bound by enclosing normal 'Let's.
    go _ (Return [NodeC t as]) = emit (ReturnNode (Just t, map getType as))
    go _ (Return [z]) | valIsConstant z = emit (ReturnConst z)
    go _ Error {} = emit ReturnError
    go locals (Case _ alts) = Prelude.mapM_ (go locals) [ body | _ :-> body <- alts ]
    go locals (_ :>>= _ :-> body) = go locals body
    go locals Let { expBody = body, expIsNormal = False } = go locals body
    go locals (App a _ _) | a `member` locals = return ()
    go locals Let { expBody = body, expDefs = defs, expIsNormal = True } = do
        let locals' = locals `union` fromList (map funcDefName defs)
        mapM_ (go locals' . lamExp . funcDefBody) defs
        go locals' body
    go _ (App a _ _) = emit (ReturnCalls a)
    go _ _ = emit ReturnOther

-- | Run a monadic transformation over the body of every top-level function
-- of a 'Grin' program. The transformation also receives the function's name.
mapGrinFuncsM :: Monad m => (Atom -> Lam -> m Lam) -> Grin -> m Grin
mapGrinFuncsM f grin = do
    newFuncs <- mapM visit (grinFunctions grin)
    return (setGrinFunctions newFuncs grin)
  where
    visit fd = do
        let name = funcDefName fd
        body <- f name (funcDefBody fd)
        return (name, body)