{-# LANGUAGE TemplateHaskell #-}
module Grin.SSimplify(simplify,explicitRecurse) where

import Control.Monad.Identity
import Control.Monad.Reader
import Control.Monad.State
import Control.Monad.Writer
import Data.Maybe
import qualified Data.IntMap as IM
import qualified Data.IntSet as IS
import qualified Data.Map as Map
import qualified Data.Set as Set
import Data.DeriveTH

import Grin.Grin
import Grin.Noodle
import Stats(mtick)
import StringTable.Atom
import Support.CanType
import Support.FreeVars
import Support.Tickle
import Util.GMap
import Util.Gen
import Util.HasSize
import Util.RWS
import Util.SetLike
import qualified Stats

-- This pass puts Grin into a normal form and, in addition, carries out some straightforward
-- simplifications.
--
-- The normalized form has the following properties:
--
--   * :>>= only appears in trailing position
--   * Return [v0 .. vn] for n > 1 only appears in trailing position
--   * all variables and function names are unique in their scope.

-- Read-only environment carried through the simplifier.
data SEnv = SEnv
    { envSubst :: IM.IntMap Val            -- renaming substitution, keyed by variable number
    , envCSE   :: Map.Map Exp (Atom,Exp)   -- known expression -> (statistic to tick, replacement)
    , envPapp  :: IM.IntMap (Atom,[Val])   -- variable number -> (tag, arguments) of a function-tagged node stored in it
    -- , envPush  :: IM.IntMap Exp
    }

-- Simplifier state: the variable numbers already handed out by newVarName.
newtype SState = SState
    { usedVars :: IS.IntSet
    }

-- Output collected while simplifying.
data SCol = SCol
    { colStats    :: Stats.Stat   -- statistics ticks recorded by the pass
    , colFreeVars :: GSet Var     -- variables reported through tellFV
    }

-- Template Haskell splice (Data.DeriveTH) generating the Monoid instance for SCol.
$(derive makeMonoid ''SCol)

{-
data ExpInfo = ExpInfo {
    expFreeVars :: GSet Var,
    expUnboxing :: UnboxingResult,
    expType     :: [Ty]
    }
-}

-- The simplifier monad: reads an SEnv, writes an SCol and keeps an SState.
newtype S a = S (RWS SEnv SCol SState a)
    deriving (Functor, Monad, MonadReader SEnv, MonadWriter SCol, MonadState SState)

-- Statistics ticks are emitted through the colStats component of the writer output.
instance Stats.MonadStats S where
    mtickStat stat = S $ tell mempty { colStats = stat }
    mticks' count name = S $ tell mempty { colStats = Stats.singleStat count name }

-- Record the free variables of the argument in the colFreeVars output.
tellFV x = tell (mempty { colFreeVars = freeVars x })

-- Simplify every function of a Grin program, starting from an empty
-- environment and an empty set of used variables.  The statistics gathered by
-- the pass are appended to the program's existing grinStats.
simplify :: Grin -> IO Grin
simplify grin = return grin { grinFunctions = fs, grinStats = grinStats grin `mappend` stats }
  where
    S fun = simpFuncs (grinFunctions grin)
    (fs, _, SCol { colStats = stats }) = runRWS fun mempty SState { usedVars = mempty }

-- Simplify the body of each function definition in turn and refresh the
-- definition's derived properties via updateFuncDefProps.
simpFuncs :: [FuncDef] -> S [FuncDef]
simpFuncs = mapM simpOne where
    simpOne def@FuncDef { funcDefBody = body } =
        liftM (\newBody -> updateFuncDefProps def { funcDefBody = newBody }) (simpLam body)

simpLam :: Lam -> S Lam
simpLam (ps :-> e) = do
    (ps,env') <- renamePattern ps
    let f col = col { colFreeVars = colFreeVars col \\ freeVars ps }
    (e,col) <- censor f $ listen $ local (env' `mappend`) $ simpExp e
    ps <- mapM (zeroVars (`member` colFreeVars col)) ps
    return (ps :-> e)

-- Wrap a value in a StoreNode base operation with its flag set to True.
dstore node = BaseOp (StoreNode True) [node]

-- Final rewrite applied to a single expression: simpBind runs it on the bound
-- expression and simpExp runs it on the expression in tail position.
-- The first matching case applies:
--
--   * an Apply whose function is a variable recorded in envPapp,
--   * a Case whose alternatives can all be unboxed the same way,
--   * a Case whose last alternative is itself a Case on the same value,
--   * otherwise a lookup in the CSE table.
simpDone :: Exp -> S Exp
simpDone e = do
    pmap <- asks envPapp
    case e of
        -- Apply of a variable known to hold the node (tl,gs): depending on the
        -- count cl returned by tagUnfunction, either call fn directly or
        -- build the node with tag `partialTag fn (cl - 1)`.
        -- NOTE(review): presumably cl is the number of arguments still
        -- missing from the partial application -- confirm against tagUnfunction.
        (BaseOp (Apply ty) (Var (V vn) _:fs)) | Just (tl,gs) <- IM.lookup vn pmap -> do
            (cl,fn) <- tagUnfunction tl
            mtick $ if cl == 1 then "Simplify.Apply.Papp.{" ++ show tl  else ("Simplify.Apply.App.{" ++ show fn)
            return $ if cl == 1
                then App fn (gs ++ fs) ty
                else dstore (NodeC (partialTag fn (cl - 1)) (gs ++ fs))
        -- All alternatives agree on an unboxing (unboxTypes gives Just ts):
        -- the Case is edited by unboxModify to return the unboxed values, and
        -- the original result is rebuilt from them by unboxRet.
        (Case v ls) | isJust utypes -> ans where
            -- ts is only demanded when the guard has established the Just
            utypes@(~(Just ts)) = unboxTypes ur
            ur = foldr1 combineUnboxing [ getUnboxing e | _ :-> e <- ls ]
            ans = do
                mtick "Grin.Simplify.Unbox.case-return"
                -- here v1 is the top-level name, not a pattern variable
                let vs = zipWith Var [v1 ..] ts
                return (unboxModify ur (Case v ls) :>>= vs :->  (unboxRet ur vs))
        -- The last alternative binds a single value v1' and immediately
        -- scrutinises either v1' or the outer scrutinee again: splice the
        -- inner alternatives into the outer Case, rebinding v1' in each.
        (Case v1 ls) | [v1'] :-> Case v2 ls' <- last ls, v1' == v2 || v1 == v2 -> do
            let f (p :-> e) = p :-> Return [v1] :>>= [v1'] :-> e
            mtick "Grin.Simplify.case-merge"
            return $ Case v1 (init ls ++ map f ls')
        --(e :>>= p :-> Return p') | p == p' -> do
        --    mtick "Grin.Simplify.tail-return-omit"
        --    return e
        -- Otherwise replace the expression if the CSE table knows it; n is the
        -- statistic to tick and e' the replacement, whose free variables are
        -- recorded.
        _ -> do
            cmap <- asks envCSE
            case Map.lookup e cmap of
                Just (n,e') -> do mtick n; tellFV e'; return e'
                Nothing -> return e

-- simpBind p e cont builds the binding `e :>>= p :-> z`, where z is the result
-- of running cont (the already-scheduled simplification of the continuation).
-- Depending on the shape of e, cont is run with extra CSE-table entries
-- (and, for stored function nodes, an extra envPapp entry) so that later
-- occurrences of equivalent expressions can be replaced by a Return.
simpBind :: [Val] -> Exp -> S Exp -> S Exp
simpBind p e cont = f p e where
    -- Run cont with the pairs in xs added to envCSE (Map.union is left-biased,
    -- so the new entries win), then finish e with simpDone.  The continuation
    -- is run first so its collected free variables are known when deciding
    -- whether the binding can be dropped altogether.
    cse name xs = do
        (z,col) <- listen $ local (\s -> s { envCSE = Map.fromList [ (x,(toAtom name,y)) | (x,y) <- xs] `Map.union` envCSE s }) cont
        e <- simpDone e
        if isOmittable e && isEmpty (freeVars p `intersection` colFreeVars col) then do
            mtick "Simplify.Omit.Bind"
            return z
         else return $ e :>>= (p :-> z)
    -- As cse, but additionally maps the original e itself to `Return p`.
    cse' name xs = cse name ((e,Return p):xs)
    -- after `eval v` bound to p, a later `promote v` also yields p
    f p app@(BaseOp Eval [v]) =  cse' "Simplify.CSE.eval" [(BaseOp Promote [v],Return p)]
    -- after `promote v` bound to p, a later `eval v` also yields p
    f p (BaseOp Promote [v@Var {}]) =  cse' "Simplify.CSE.promote" [(gEval v,Return p)]
    -- after `demote v` bound to p, promoting or evaluating p yields v again
    f [p] (BaseOp Demote [v@Var {}]) =  cse' "Simplify.CSE.demote" [(BaseOp Promote [p],Return [v]),(gEval p,Return [v])]
    -- storing a node without unknown arguments (see isHoly) into a variable
    f [p@(Var (V vn) _)] (BaseOp (StoreNode isD) [v@(NodeC t vs)]) | not (isHoly v) = case (isD,tagUnfunction t,tagIsWHNF t) of
        (True,Nothing,_) -> cse' "Simplify.CSE.return-node" []
        -- function-tagged node: remember (tag, arguments) for the Apply rewrite in simpDone
        (True,Just (n,fn),_) -> local (\s -> s { envPapp = IM.insert vn (t,vs) (envPapp s) }) $ cse' "Simplify.CSE.return-node-func" []
        --(False,_,True)  -> local (\s -> s { envPush = IM.insert vn (Store v) (envPush s) }) $ cse "Simplify.CSE.store-whnf" []
        --(False,_,False) -> cse' "Simplify.CSE.store" []
        _ -> cse' "Simplify.CSE.store" []
--    f [p@(Var (V vn) _)] (Return [v@(NodeC t vs)]) | not (isHoly v) = case tagUnfunction t of
--        Nothing -> cse "Simplify.CSE.return-node" [(Return [p],Return [v]),(Store p,Store v)]
--        Just (n,fn) -> local (\s -> s { envPapp = IM.insert vn (t,vs) (envPapp s) }) $ cse' "Simplify.CSE.return-node" [(Return [p],Return [v]),(Store p,Store v)]
--    f [p@(Var (V vn) _)] (Store v@(NodeC t vs)) | not (isHoly v) = case tagIsWHNF t of
--        True -> local (\s -> s { envPush = IM.insert vn (Store v) (envPush s) }) $ cse "Simplify.CSE.store-whnf" [(BaseOp Promote [p],Return [v]),(gEval p,Return [v])]
--        False -> cse' "Simplify.CSE.store" []
    -- anything else: no CSE entries are added, not even for e itself
    f _ _ = cse "Simplify.CSE.NOT" []

-- Extend the renaming substitution so that the given variable maps to the given value.
extEnv :: Var -> Val -> SEnv -> SEnv
extEnv (V key) val env = env { envSubst = IM.insert key val (envSubst env) }

-- Simplify an expression and bring it into the normal form described at the
-- top of this file.
--
-- The worker `f e rs` simplifies e followed by the pending continuations rs.
-- Each entry (env, pattern, body) was pushed when descending into the left
-- operand of a :>>= and carries the environment that was current at that
-- point; the head of rs is the innermost continuation.  The equations are
-- tried in order, so the special cases must stay ahead of the general ones.
simpExp :: Exp -> S Exp
simpExp e = f e [] where
    -- `e :>>= p :-> Return p` is just e
    f (e :>>= p :-> Return p') rs | p == p' = do
        mtick "Grin.Simplify.tail-return-omit"
        f e rs
    -- a bind: remember the continuation (with the current environment) and
    -- descend into the bound expression
    f  (a :>>= (v :-> b)) xs = do
        env <- ask
        f a ((env,v,b):xs)

    -- simple transforms
    f (BaseOp Promote [Const x]) rs = do
        mtick "Grin.Simplify.fetch-const"
        f (Return [x]) rs
--    f (Store x) rs | valIsNF x = do
--        mtick "Grin.Simplify.store-normalform"
--        f (Return [Const x]) rs
    f (BaseOp Eval [Const n]) rs = do
        mtick "Grin.Simplify.eval-const"
        f (Return [n]) rs
    -- an Error followed by continuations: drop them all, giving the error the
    -- type of the last (outermost) continuation body
    f (Error s t) rs@(_:_) = do
        mtick "Grin.Simplify.error-discard"
        let (_,_,b) = last rs
        f (Error s (getType b)) []
    -- returning a constant, an unknown value or a variable into a single
    -- variable: substitute it in the continuation instead of binding it
    f (Return [v]) ((senv,[Var vn _],b):rs) | valIsConstant v = do
        mtick "Grin.Simplify.Subst.const"
        fbind vn v senv b rs
    f (Return [v@ValUnknown {}]) ((senv,[Var vn _],b):rs) = do
        mtick "Grin.Simplify.Subst.unknown"
        fbind vn v senv b rs
    f (Return [v@Var {}]) ((senv,[Var vn _],b):rs) = do
        mtick "Grin.Simplify.Subst.var"
        fbind vn v senv b rs
--    f a@(Return [NodeC t xs]) ((senv,[NodeC t' ys],b):rs) | t == t' = do
--        mtick "Grin.Simplify.Assign.node-node"
--        dtup xs ys senv b rs
    f (Return []) ((senv,[],b):rs) = do
        mtick "Grin.Simplify.Assign.unit-unit"
        dtup [] [] senv b rs
    -- returning two or more values: split into one binding per value
    f a@(Return (xs@(_:_:_))) ((senv,ys,b):rs) = do
        mtick "Grin.Simplify.Assign.tuple-tuple"
        dtup xs ys senv b rs
    -- a Case on a variable with a single alternative is just a bind
    f (Case v@Var {} [l]) rs = do
        f (Return [v] :>>= l) rs
--    f e@(Case v ls) rs | isJust utypes  = ans where
--        utypes@(~(Just ts)) = unboxTypes ur
--        ur = foldr1 combineUnboxing [ getUnboxing e | _ :-> e <- ls ]
--        ans = do
--            mtick "Grin.Simplify.Unbox.case-return"
--            let vs = zipWith Var [v1 ..] ts
--            f (unboxModify ur (Case v ls) :>>= vs :-> Return (unboxRet ur vs)) rs
    -- general bind: simplify the bound expression in the current environment,
    -- rename the pattern, and simplify the continuation in its saved
    -- environment extended with the renaming
    f a ((senv,p,b):xs) = do
        a <- g a
        (p,env') <- renamePattern p
        let env'' = env' `mappend` senv
        local (const env'') $ simpBind p a (f b xs)
    -- no continuation left: simplify and apply the final rewrites
    f x [] = do
        e <- g x
        simpDone e
    -- substitute v (after applying the current substitution to it) for the
    -- variable vn while simplifying the continuation b in its saved environment
    fbind vn v senv b rs = do
        v' <- applySubst v
        local (\_ -> extEnv vn v' senv) $ f b rs

    -- bind the values xs to the patterns ys pairwise, emitting one
    -- `Return [x] :>>= [y] :-> ...` per pair in front of the simplified
    -- continuation; both lists must have the same length
    dtup xs ys senv b rs | sameLength xs ys = do
        xs <- mapM applySubst xs
        (ys,env') <- renamePattern ys
        let env'' = env' `mappend` senv
        z <- local (const env'') $ f b rs
        -- every pair is wrapped in Just, so none is filtered out below
        ts <- mapM (return . Just) [([y],Return [x]) | x <- xs | y <- ys ]
        let h [] = z
            h ((p,v):rs) = v :>>= p :-> h rs
        return $ h [ (p,v) |  Just (p,v) <- ts]
    dtup _ _ _ _ _ = error "dtup: attempt to bind unequal lists"
    -- g simplifies the inside of a single (non-bind) expression
    g (Case v as) = do
        v <- applySubst v
        as <- mapM simpLam as
        return $ Case v as
    g  lt@Let { expDefs = defs, expBody = body } = do
        body <- f body []
        defs <- simpFuncs defs
        let dnames = fromList $ map funcDefName defs :: GSet Atom
            -- True when e mentions none of the functions defined by this let
            isInvalid e = isEmpty (freeVars e `intersection` dnames)
        case body of
            -- the head of the body does not use the local functions: hoist it
            -- out in front of the let
            e :>>= l :-> r | isInvalid e -> do
                mtick "Simplify.simplify.let-shrink-head"
                return $ e :>>= l :-> updateLetProps lt { expBody = r, expDefs = defs }
            -- the tail of the body does not use the local functions: move it
            -- behind the let
            e :>>= l :-> r | isInvalid r -> do
                mtick "Simplify.simplify.let-shrink-tail"
                return (updateLetProps lt { expBody = e, expDefs = defs } :>>= l :-> r)
            -- the body is a call to a local function that none of the
            -- definitions' bodies mention: use that function's lambda as the body
            App f as ts | f `elem` map funcDefName defs, f `Set.notMember` freeVars (map funcDefBody defs) -> do
                mtick "Simplify.simplify.let-inline-body"
                let [fbody] = [ funcDefBody fd | fd <- defs, funcDefName fd == f]
                return $ updateLetProps lt { expDefs = defs, expBody = Return as :>>= fbody }
            _ -> return $ updateLetProps lt { expBody = body, expDefs = defs }
    -- everything else only has the substitution applied to its values
    g x = applySubstE x

-- Apply the current renaming substitution to the values of an expression
-- (as visited by mapExpVal).
applySubstE :: Exp -> S Exp
applySubstE = mapExpVal applySubst

-- Apply the renaming substitution from the environment to a value.  Each
-- variable is replaced by its mapping when there is one; the free variables of
-- whatever ends up in its place are recorded with tellFV.
applySubst x = go x where
    go var@(Var (V n) _) = do
        subst <- asks envSubst
        let result = case IM.lookup n subst of
                Just replacement -> replacement
                Nothing -> var
        tellFV result
        return result
    go other = mapValVal go other

-- Replace every variable inside a value that fails the given predicate by v0,
-- ticking a statistic for each replacement.  Variables that already are v0
-- are left alone.
zeroVars keep val = go val where
    go var@(Var v ty)
        | keep v || v == v0 = return var
        | otherwise = do
            mtick ("Simplify.ZeroVar.{" ++ show var)
            return (Var v0 ty)
    go other = mapValVal go other

-- Give every variable in a pattern a new name (see newVarName).  Returns the
-- renamed pattern together with an environment whose substitution maps each
-- old variable number to its replacement.
renamePattern :: [Val] ->  S ([Val],SEnv)
renamePattern = runWriterT . mapM rename where
    rename :: Val -> WriterT SEnv S Val
    rename (Var old@(V key) ty) = do
        fresh <- lift (newVarName old)
        let replacement = Var fresh ty
        tell mempty { envSubst = IM.singleton key replacement }
        return replacement
    rename other = mapValVal rename other

-- Pick a variable number that has not been used yet and mark it as used.
-- Variable 0 is always returned unchanged and is never recorded.  A requested
-- number that is still free is kept; otherwise it is advanced in steps of
-- (1 + number of used variables) until a free one is found.
newVarName :: Var -> S Var
newVarName (V 0) = return (V 0)
newVarName (V wanted) = do
    used <- gets usedVars
    let stride = 1 + IS.size used
        probe n = if n `IS.member` used then probe (n + stride) else n
        fresh = probe wanted
    modify $ \st -> st { usedVars = IS.insert fresh used }
    return (V fresh)

-- A node is "holy" when at least one of its arguments satisfies isValUnknown;
-- anything that is not a NodeC is never holy.
isHoly (NodeC _ args) = any isValUnknown args
isHoly _ = False

-- Summary of what an expression returns in tail position, as computed by
-- getUnboxing and merged across case alternatives by combineUnboxing.
data UnboxingResult
    -- an Error, with the types it was given
    = UnErr [Ty]
    -- a StoreNode: its flag, the node tag, and one Unbox per node argument
    | UnStore !Bool !Atom [Unbox]
    -- a Demote of the described value
    | UnDemote Unbox
    -- a Return (or anything not recognised), one Unbox per returned value
    | UnReturn [Unbox]
    -- a call: the function names, the argument types, and the result types
    | UnTail (Set.Set Atom) [Ty] [Ty]

-- Description of a single returned value.
data Unbox
    = UnConst Val                -- a known constant value
    | UnUnknown Ty               -- some value of the given type
    | UnBaseOp BaseOp [Unbox]
    deriving (Eq, Ord)

-- True exactly for the UnUnknown constructor.
isUnUnknown u = case u of
    UnUnknown {} -> True
    _ -> False

-- The type of the expression an UnboxingResult describes, i.e. what that
-- expression returns (not the types that remain after unboxing).
instance CanType UnboxingResult where
    type TypeOf UnboxingResult = [Ty]
    getType (UnErr tys) = tys
    getType (UnReturn us) = map getType us
    getType (UnStore b _ _) = [bool b tyDNode tyINode]
    getType (UnDemote _) = [tyINode]
    -- UnTail holds the call's argument types second and its result types
    -- third (see getUnboxing).  The type of the call is its result types; the
    -- previous definition returned the argument types here, which disagreed
    -- with every other constructor and with how getUnboxing turns an UnTail
    -- into an unknown return.
    getType (UnTail _ _ tys) = tys

-- The type of a single described value.  Only UnConst and UnUnknown are
-- supported; any other constructor is an error.
instance CanType Unbox where
    type TypeOf Unbox = Ty
    getType u = case u of
        UnConst v -> getType v
        UnUnknown t -> t
        _ -> error "getType: bad."

-- Rebuild the original result of an unboxed expression from the values vs it
-- now returns.  Constant slots are filled in from the UnboxingResult, unknown
-- slots consume the values of vs in order, and all of vs must be consumed.
-- NOTE(review): for UnErr the result is `Return []` whatever types the error
-- carried -- confirm this is intended.
unboxRet :: UnboxingResult -> [Val] -> Exp
unboxRet ur vs = case ur of
    UnReturn xs -> Return (rebuilt xs)
    UnStore b c xs -> BaseOp (StoreNode b) [NodeC c (rebuilt xs)]
    UnDemote u -> let [u'] = rebuilt [u] in BaseOp Demote [u']
    -- a tail call to exactly one function: perform the call on the values
    UnTail fs _ tys | [fn] <- Set.toList fs -> App fn vs tys
    UnErr {} -> Return []
    _ -> Return vs
  where
    -- fill the slots from vs, insisting that no value is left over
    rebuilt slots = let (filled, []) = weave slots vs in filled
    -- weave returns the filled slots and the values not yet consumed
    weave [] rest = ([], rest)
    weave (UnUnknown _ : slots) (v : rest) = let (out, left) = weave slots rest in (v : out, left)
    weave (UnConst c : slots) rest = let (out, left) = weave slots rest in (c : out, left)
    weave _ _ = error "SSimplify.unboxRet: bad."

-- The types of the values an expression returns once it has been unboxed
-- according to the given result, or Nothing when there is nothing to unbox.
unboxTypes :: UnboxingResult -> Maybe [Ty]
unboxTypes ur = case ur of
    -- a tail call is only unboxed when exactly one function is involved
    UnTail fs argTys _
        | Set.size fs == 1 -> Just argTys
        | otherwise -> Nothing
    UnErr [] -> Nothing
    UnErr (_:_) -> Just []
    -- a return is only worth unboxing if at least one slot is not unknown
    UnReturn us
        | all isUnUnknown us -> Nothing
        | otherwise -> Just (concatMap slot us)
    UnStore _ _ fields -> Just (concatMap slot fields)
    UnDemote _ -> Just [tyDNode]
  where
    -- unknown slots are still returned; constant slots disappear
    slot (UnUnknown t) = [t]
    slot (UnConst {}) = []
    slot _ = error "SSimplify.unboxTypes: bad."

-- Edit the tail positions of an expression (via editTail) so that it returns
-- the unboxed values described by the given result instead of the original
-- value.  The values it then returns are what unboxRet expects as input.
unboxModify :: UnboxingResult -> Exp -> Exp
unboxModify ur = f ur where
    -- lazy pattern: only forced by the branches that use nty
    Just nty = unboxTypes ur
    f UnErr {} = id
    -- single tail-called function: return the call's arguments instead of calling
    f (UnTail a tys _) | [f] <- Set.toList a = runIdentity . editTail tys (mApp f)
    f (UnReturn us) | all isUnUnknown us = id
    f (UnReturn xs) = runIdentity . editTail nty (g xs)
    f (UnStore _ _ us) =runIdentity . editTail nty (z us)
    f (UnDemote _) =runIdentity . editTail nty y
    f _ = error "SSimplify.unboxModify: bad1."
    -- a Return keeps only the values in unknown slots
    g xs (Return ys) = return $ Return (concat $ zipWith h xs ys)
    g _ _ = error "SSimplify.unboxModify: bad2."
    -- h keeps the value of an unknown slot and drops that of a constant slot
    h (UnUnknown _) y = [y]
    h (UnConst {}) _ = []
    h _ _ = error "SSimplify.unboxModify: bad3."
    -- a StoreNode returns the node arguments in unknown slots
    z xs (BaseOp (StoreNode _) [NodeC _ ts]) = return . Return . concat $ zipWith h xs ts
    z _ _ = error "SSimplify.unboxModify: bad4."
    -- a Demote returns its argument; returning a Const returns its contents
    y (BaseOp Demote [x]) = return $ Return [x]
    y (Return [Const v]) = return $ Return [v]
    y _ = error "SSimplify.unboxModify: bad5."
    -- a call to the expected function returns its arguments
    mApp f (App f' as tys) | f == f' = return $ Return as
    mApp f e  = error $ "mApp: " ++ show (f,e)

-- Merge the unboxing results of two alternatives into one that is valid for
-- both.  The equations are tried in order; whenever the two sides do not fit
-- together the result falls back to an all-unknown return.
combineUnboxing :: UnboxingResult -> UnboxingResult -> UnboxingResult
combineUnboxing ub1 ub2 = merge ub1 ub2 where
    -- an error imposes no constraint on the other side
    merge UnErr {} other = other
    merge other UnErr {} = other
    merge (UnTail fs1 args1 rets1) (UnTail fs2 args2 rets2)
        | rets1 == rets2, args1 == args2 = UnTail (fs1 `union` fs2) args1 rets1
    merge (UnReturn xs) (UnReturn ys) = UnReturn (zipWith slot xs ys)
    -- stores of the same tag merge slot by slot, keeping the first side's flag
    merge (UnStore b1 tag1 xs1) (UnStore _ tag2 xs2)
        | tag1 == tag2 = UnStore b1 tag1 (zipWith slot xs1 xs2)
        | otherwise = UnReturn [UnUnknown (bool b1 tyDNode tyINode)]
    merge (UnDemote u1) (UnDemote u2) = UnDemote (slot u1 u2)
    merge (UnDemote _) (UnReturn [UnConst (Const _)]) = UnDemote (UnUnknown tyDNode)
    merge (UnReturn [UnConst (Const _)]) (UnDemote _) = UnDemote (UnUnknown tyDNode)
    merge other _ = UnReturn (map UnUnknown (getType other))
    -- two slots stay constant only when both hold the same constant
    slot (UnConst v1) (UnConst v2)
        | v1 == v2 = UnConst v1
        | otherwise = UnUnknown (getType v1)
    slot u _ = UnUnknown (getType u)

-- Describe what an expression returns in tail position.
getUnboxing :: Exp -> UnboxingResult
getUnboxing e = go e where
    go (Return rs) = UnReturn (map slot rs)
    go (BaseOp (StoreNode b) [NodeC c xs]) = UnStore b c (map slot xs)
    go (BaseOp Demote [v]) = UnDemote (slot v)
    go (Error _ tys) = UnErr tys
    go (App fn vs ts) = UnTail (singleton fn) (getType vs) ts
    go (Case _ ls) = foldr1 combineUnboxing [ go body | _ :-> body <- ls ]
    -- a non-normal let: look through to the body, but a tail call to one of
    -- the let's own functions is treated as an unknown return
    go Let { expDefs = defs, expBody = body, expIsNormal = False } = case go body of
        UnTail fs _ ntys | not $ Set.null (fs `Set.intersection` Set.fromList (map funcDefName defs)) -> UnReturn (map UnUnknown ntys)
        other -> other
    -- a bind returns whatever its continuation returns
    go (_ :>>= _ :-> rest) = go rest
    go other = UnReturn (map UnUnknown $ getType other)
    -- constants are remembered, everything else is only known by type
    slot v | valIsConstant v = UnConst v
           | otherwise = UnUnknown (getType v)

-- editTail nty mt te rewrites the tail positions of te so that the expression
-- has the new type nty:
--
--   * an Error in tail position is given type nty,
--   * a call in tail position to a function defined by an enclosing normal
--     let is given result type nty,
--   * every other tail expression is passed to mt.
--
-- The set lf threaded through f holds the names defined by the enclosing
-- lets whose expIsNormal is True.
editTail :: Monad m => [Ty] -> (Exp -> m Exp) -> Exp -> m Exp
editTail nty mt te = f (sempty :: GSet Atom) te where
    f _ (Error s ty) = return $ Error s nty
    -- every alternative of a Case is in tail position
    f lf (Case x ls) = return (Case x) `ap` mapM (g lf) ls
    -- non-normal let: only the body is edited
    f lf lt@Let {expIsNormal = False, expBody = body } = do
        body <- f lf body
        return $ updateLetProps lt { expBody = body }
    -- normal let: its function names join lf and the sub-expressions visited
    -- by mapExpExp are edited
    f lf lt@Let {expDefs = defs, expIsNormal = True } = do
        let nlf = lf `union` fromList (map funcDefName defs)
        mapExpExp (f nlf) lt
    f lf lt@MkCont {expLam = lam, expCont = cont } = do
        a <- g lf lam
        b <- g lf cont
        return $ lt { expLam = a, expCont = b }
    -- only the continuation of a bind is in tail position
    f lf (e1 :>>= p :-> e2) = do
        e2 <- f lf e2
        return $ e1 :>>= p :-> e2
    f lf e@(App a as t) | a `member` lf = return $ App a as nty
    f lf e = mt e
    -- edit the body of a lambda
    g lf (p :-> e) = do e <- f lf e; return $ p :-> e

-- Choose between two values: the first when the condition holds, the second
-- otherwise.  Note that the condition comes first.
bool True x _ = x
bool False _ y = y

-- Find top-level functions that call themselves and turn the recursion into a
-- local definition, allowing it to be compiled to a direct loop.
--
-- A function that does not mention its own name is left untouched.  Otherwise
-- its body becomes a let defining a local copy (named "bR" followed by the
-- original name) in which the self-calls target the local copy, and the let's
-- body calls that copy with the original parameters.

explicitRecurse
    :: Grin
    -> IO Grin
explicitRecurse = mapGrinFuncsM loopify where
    loopify name lam@(params :-> body)
        | name `notMember` (freeVars lam :: GSet Atom) = return lam
        | otherwise = return $ params :-> grinLet [createFuncDef True localName (params :-> redirect body) ] (App localName params (getType body))
      where
        localName = toAtom $ "bR" ++ fromAtom name
        -- retarget calls to the function itself; recurse everywhere else
        redirect (App n rs t) | n == name = App localName rs t
        redirect e = tickle redirect e

-- Template Haskell splice (Data.DeriveTH) generating the Monoid instance for SEnv.
-- NOTE(review): SEnv's Monoid instance is used earlier in this file (e.g. the
-- `mempty` passed to runRWS in simplify and the WriterT SEnv in renamePattern).
-- Top-level splices split a module into declaration groups, so confirm that
-- the instance is visible to those uses with the compiler in use; the SCol
-- splice sits directly after its data type instead.
$(derive makeMonoid ''SEnv)