module Grin.Devolve(twiddleGrin,devolveTransform) where
import Control.Monad.Identity
import Control.Monad.RWS
import Data.Functor
import Data.IORef
import Data.Maybe
import qualified Data.Map as Map
import qualified Data.Set as Set
import Grin.Grin
import Grin.Noodle
import Options (verbose,fopts)
import Support.FreeVars
import Support.Transform
import Util.Gen
import Util.SetLike
import qualified FlagOpts as FO
-- | Transformation descriptor for the devolve pass: the default
-- 'transformParms' with 'devolveGrin' installed as the operation and
-- progress dumping controlled by 'verbose'.
devolveTransform = transformParms
    { transformCategory = "Devolve"
    , transformPass = "Grin"
    , transformOperation = devolveGrin
    , transformDumpProgress = verbose
    }
-- | Lift selected local functions out of their enclosing 'Let' expressions.
--
-- For every 'Let' reached while walking the function bodies of the program,
-- the definitions whose names are members of that let's @expNonNormal@ set
-- (presumably the ones that are not only tail-called -- confirm against
-- "Grin.Grin") are removed from the let and collected as top-level
-- functions.  Each lifted function receives extra trailing parameters for
-- the variables it needs from the enclosing scope, and every 'App' of a
-- lifted function is extended with the same variables as extra arguments.
--
-- The returned program has the lifted functions placed in front of the
-- rewritten original functions, its type environment extended with entries
-- built by 'createFuncDef' for the lifted functions, and its phase set to
-- 'PostDevolve'.
devolveGrin :: Grin -> IO Grin
devolveGrin grin = do
    -- Lifted definitions, as (name, lambda) pairs, accumulated over the walk.
    col <- newIORef []
    -- g rewrites the body of one (name, lambda) pair; f rewrites an expression.
    let g (n,l :-> r) = f r >>= \r -> return (n,l :-> r)
        f lt@Let { expDefs = defs, expBody = body, .. } = do
            -- iterZ computes, for each lifted definition, the set of extra
            -- parameters it needs: its own set plus the sets of every
            -- function it mentions.  The Bool records whether the current
            -- pass changed anything; passes repeat until one changes nothing.
            let iterZ :: Bool -> Map.Map Tag (Set.Set Val) -> [FuncDef] -> Map.Map Tag (Set.Set Val)
                iterZ b pmap (fd@FuncDef { funcDefName = name, funcDefBody = as :-> r }:fs) = iterZ (b || xs' /= xs) (Map.insert name xs pmap) fs where
                    xs = Set.unions $ xs':catMaybes [ Map.lookup t pmap | t <- Set.toList $ freeVars fd]
                    xs' = maybe Set.empty id (Map.lookup name pmap)
                -- Restart over the lifted definitions only.  Restarting over
                -- all of 'defs' would insert the definitions that stay in the
                -- let into the map; 'proc' would then append arguments to
                -- their call sites although 'z' never gives those
                -- definitions the matching extra parameters.
                iterZ True pmap [] = iterZ False pmap nndefs
                iterZ False pmap [] = pmap
                -- The definitions of this let that are going to be lifted.
                nndefs = [ fd | fd <- defs, funcDefName fd `Set.member` expNonNormal ]
                -- Seed: each lifted definition starts with the variables above
                -- v0 that are free in its own body.
                pmap = iterZ False (fromList [ (funcDefName fd, fromList [ Var x y | (x,y) <- Set.toList $ freeVars (funcDefBody fd), x > v0]) | fd <- nndefs ]) nndefs
                -- Left: lifted definitions (with widened parameter lists);
                -- Right: definitions that remain in the let.
                (nmaps,rmaps) = splitEither (map z defs)
                z fd@FuncDef { funcDefName = name, funcDefBody = as :-> r }
                    | name `Set.member` expNonNormal = Left ((name,(as ++ xs) :-> pr),xs)
                    | otherwise = Right fd { funcDefBody = as :-> pr }
                  where xs = maybe [] Set.toList $ Map.lookup name pmap
                        pr = runIdentity $ proc r
                -- Append the extra arguments to every call of a function in pmap.
                proc (App a as t) | Just xs <- Map.lookup a pmap = return (App a (as ++ Set.toList xs) t)
                proc e = mapExpExp proc e
            -- Lifted bodies may themselves contain lets, so process them too.
            nmaps <- mapM (g . fst) nmaps
            modifyIORef col (++ nmaps)
            updateLetProps <$> mapExpExp f lt { expDefs = rmaps, expBody = runIdentity $ proc body }
        f e = mapExpExp f e
    nf <- mapM g (grinFuncs grin)
    lf <- readIORef col
    let ntenv = extendTyEnv [ createFuncDef False x y | (x,y) <- lf ] (grinTypeEnv grin)
    return $  setGrinFunctions (lf ++ nf) grin { grinPhase = PostDevolve, grinTypeEnv = ntenv }
    
    
-- | Read-only environment carried by the 'R' monad.
data Env = Env
    { -- Variable renaming consulted by 'twiddleVal': a variable with an
      -- entry here is replaced by the entry's target.
      envMap   :: Map.Map Var Var
      -- Values that an enclosing expression has already been given as GC
      -- roots by 'twiddleExp'; they are not registered a second time.
    , envRoots :: Set.Set Val
      -- Set to v1 by 'runR'; not read anywhere in the visible code.
    , envVar   :: Var
    }
-- | The rewriting monad: reads an 'Env', writes the set of variables that
-- were mentioned, and carries no state.
newtype R a = R (RWS Env (Set.Set Var) () a)
    deriving ( Functor
             , Monad
             , MonadReader Env
             , MonadWriter (Set.Set Var)
             )
-- | Run an 'R' action in the initial environment (no roots, no renamings,
-- 'envVar' set to v1) and return its result, discarding the written set.
runR (R action) = fst (evalRWS action initialEnv ())
  where
    initialEnv = Env { envMap = mempty, envRoots = mempty, envVar = v1 }
-- | Things that can be rewritten inside the 'R' monad.  The default
-- implementation returns its argument unchanged.
class Twiddle a where
    twiddle :: a -> R a
    twiddle = return
-- | Expressions are rewritten by 'twiddleExp'.
instance Twiddle Exp where
    twiddle expr = twiddleExp expr
-- | Values are rewritten by 'twiddleVal'.
instance Twiddle Val where
    twiddle val = twiddleVal val
-- | Lists are rewritten element by element, left to right.
instance Twiddle a => Twiddle [a] where
    twiddle = mapM twiddle
-- | Rewrite an expression in the 'R' monad.
--
-- Sub-expressions, lambdas and values are rewritten with 'twiddle' /
-- 'twiddleVal'.  When the 'FO.Jgc' option is on, expressions that may
-- allocate are additionally wrapped in 'GcRoots' listing the node-typed
-- variables (above v0) that are live across them and are not already in
-- 'envRoots'.
twiddleExp e = go e where
    -- A bind: when the left side may allocate, root the variables free in
    -- the continuation (or in the whole bind when the left side is itself
    -- a "using" expression).
    go (x :>>= lam)
        | fopts FO.Jgc && isAllocing x = rooted (freeVars (if isUsing x then ([] :-> x :>>= lam) else lam)) both
        | otherwise = both
      where both = liftM2 (:>>=) (twiddle x) (twiddle lam)
    go l@Let {} = do
        defs' <- twiddle (expDefs l)
        body' <- twiddle (expBody l)
        return (updateLetProps l { expDefs = defs', expBody = body' })
    go (Case v as) = liftM2 Case (twiddle v) (twiddle as)
    -- Any other expression: a using-and-allocating one gets its own free
    -- variables rooted and only its values rewritten; everything else has
    -- its values and then its sub-expressions rewritten.
    go x
        | fopts FO.Jgc && isUsing x && isAllocing x = rooted (freeVars x) (mapExpVal twiddleVal x)
        | otherwise = mapExpVal twiddleVal x >>= mapExpExp twiddle

    -- Run an action with the not-yet-rooted node variables among fvs added
    -- to 'envRoots', and wrap its result in 'GcRoots' for exactly those.
    rooted fvs act = do
        known <- asks envRoots
        let fresh = Set.fromList [ Var v t | (v,t) <- Set.toList fvs, isNode t, v > v0 ] Set.\\ known
        local (\env -> env { envRoots = envRoots env `Set.union` fresh }) $ do
            inner <- act
            return (gcRoots (Set.toList fresh) inner)

    isUsing (BaseOp StoreNode {} _) = True
    isUsing Alloc {} = True
    isUsing _ = False

    isAllocing (BaseOp StoreNode {} _) = True
    isAllocing (BaseOp Eval {} _) = True
    isAllocing (Return [Var {}]) = False
    isAllocing (Return [NodeC {}]) = True
    isAllocing App {} = True
    isAllocing Call {} = True
    isAllocing Let {} = True
    isAllocing (Case _ as) = any isAllocing [ b | _ :-> b <- as]
    isAllocing Alloc {} = True
    isAllocing (e :>>= _ :-> y) = isAllocing e || isAllocing y
    isAllocing _ = False

    -- Only wrap when there is something to root.
    gcRoots xs body
        | null xs = body
        | otherwise = GcRoots xs body

    isNode ty = case ty of
        TyNode -> True
        TyINode -> True
        TyPtr TyNode -> True
        TyPtr TyINode -> True
        _ -> False
-- | Rewrite a lambda.  The body is rewritten first; variables bound by the
-- parameters are removed from the written set so they do not escape, and
-- any bound variable the body never mentioned is renamed to v0 in the
-- parameter list.  Rewriting the parameters contributes nothing to the
-- written set.
instance Twiddle Lam where
    twiddle (params :-> body) = do
        let bound = freeVars params
        -- 'used' is the body's full written set, captured before the
        -- bound variables are filtered out of the output.
        (body', used) <- censor (Set.filter (`notElem` bound)) (listen (twiddle body))
        let unusedToV0 = Map.fromList [ (v,v0) | v <- bound, not (v `Set.member` used) ]
        params' <- censor (const mempty) (local (\env -> env { envMap = unusedToV0 }) (twiddle params))
        return (params' :-> body')
-- | Rewrite every function of the program with 'twiddle' and store the
-- result back into the program.
twiddleGrin grin = grinFunctions_s (runR (twiddle (grinFunctions grin))) grin
-- | A function definition is rewritten by rewriting its body.
instance Twiddle FuncDef where
    twiddle fd = funcDefBody_uM twiddle fd
-- | Rewrite a value.  Each variable is looked up in 'envMap': if it has an
-- entry it is replaced by the target variable (keeping its type), otherwise
-- it is left alone.  Whichever variable results is added to the written
-- set.  Other values are traversed with 'mapValVal'.
twiddleVal x = go x where
    go var@(Var v ty) = asks envMap >>= \renames ->
        case Map.lookup v renames of
            Nothing -> record v var
            Just n -> record n (Var n ty)
    go other = mapValVal go other
    -- Note the variable as mentioned, then yield the given value.
    record name val = tell (Set.singleton name) >> return val