module E.LambdaLift(lambdaLift,staticArgumentTransform)  where

import Control.Monad.Reader
import Control.Monad.Writer
import Data.IORef
import Data.Maybe
import Text.Printf

import Doc.PPrint
import E.Annotate
import E.E
import E.Inline
import E.Program
import E.Subst
import E.Traverse
import E.TypeCheck
import E.Values
import Fixer.Fixer
import Fixer.Supply
import GenUtil
import Name.Id
import Name.Name
import Options (verbose)
import Stats(mtick,runStatM,runStatT)
import StringTable.Atom
import Support.CanType
import Support.FreeVars
import Util.Graph as G
import Util.HasSize
import Util.SetLike hiding(Value)
import Util.UniqueMonad

-- | Derive a new 'Val' identifier from an existing one.  The new name is the
-- pair @(mn, 'f' : s)@ where @s@ is the shown name behind the id when
-- 'fromId' yields one, and the shown id itself otherwise.
annotateId mn x = toId (toName Val (mn,'f':suffix)) where
    suffix = maybe (show x) show (fromId x)

-- | Static argument transformation: rewrite simple self-recursive functions
-- into a non-recursive wrapper around a local recursive worker.
--
-- This is exactly the opposite of lambda lifting, but it is a big win if the
-- function ends up inlined, and it is conducive to other optimizations.  In
-- particular, the type arguments can almost always be transformed away from
-- the recursive inner function.
--
-- Beware: this has potentially exponential behavior.

staticArgumentTransform :: Program -> Program
staticArgumentTransform prog = result where
    result = progCombinators_s (concat newCombs) prog { progStats = progStats prog `mappend` newStats }
    (newCombs,newStats) = runStatM $ mapM perGroup (programDecomposedCombs prog)

    -- a group flagged True that holds exactly one combinator is always
    -- considered for the transformation
    perGroup (True,[comb]) = do
        [(_,body')] <- perBinding True (Right [(combHead comb, combBody comb)])
        return [combBody_s body' comb]
    -- every other group only has its combinator bodies traversed
    perGroup (_,combs) = forM combs $ \ comb -> do
        body' <- walk (combBody comb)
        return (combBody_s body' comb)

    perBinding _ (Left (t,e)) = walkDefs [(t,e)]
    perBinding always (Right [(t,v@ELam {})]) | not (null selfCalls), always || nStatic > 0 = rewritten where
        workerId = annotateId "R@" (tvrIdent t)
        (body,params) = fromLam v
        -- the argument lists of every application of the function to be
        -- found inside its own definition
        selfCalls = execWriter (collect v) where
            collect e | (EVar fn,as) <- fromAp e, tvrIdent fn == tvrIdent t = tell [as] >> mapM_ collect as >> return e
            collect e = emapE collect e
        -- length of the shortest prefix of parameters that is passed along
        -- unchanged at all of the collected applications
        nStatic = minimum [ commonPrefix paramVars call | call <- selfCalls ] where
            paramVars = map EVar params
            commonPrefix (x:xs) (y:ys) | x == y = 1 + commonPrefix xs ys
            commonPrefix _ _ = 0
        (staticPs,dynamicPs) = splitAt nStatic params
        -- the worker only abstracts over the parameters that vary
        workerBody = foldr ELam (subst t selfRef body) dynamicPs
        -- what the old name is replaced by inside the worker: lambdas that
        -- discard the static arguments and then refer to the worker
        selfRef = foldr ELam (EVar worker) [ p { tvrIdent = emptyId } | p <- staticPs ]
        worker = tvr { tvrIdent = workerId, tvrType = getType workerBody }
        wrapper = foldr ELam (ELetRec [(worker,workerBody)] (foldl EAp (EVar worker) (map EVar dynamicPs))) params
        rewritten = do
            mtick $ "SimpleRecursive.{" ++ pprint t
            wrapper' <- walk wrapper
            return [(t,wrapper')]
    perBinding _ (Right ts) = walkDefs ts

    -- traverse the right hand side of each definition, keeping its binder
    walkDefs ts = mapM walkDef ts where
        walkDef (t,e) = walk e >>= return . (,) t

    -- local let groups are candidates too, but only when arguments can
    -- actually be dropped (always = False)
    walk elet@ELetRec { eDefs = ds } = do
        groups <- mapM (perBinding False) (decomposeDs ds)
        body' <- walk $ eBody elet
        return elet { eDefs = concat groups, eBody = body' }
    walk e = emapE walk e

-- | Reader environment used while lambda lifting a single combinator.
data S = S
    { funcName :: Name       -- ^ name from which 'newName' derives the names of lifted functions
    , topVars :: IdSet       -- ^ identifiers left out of the free variable sets computed when lifting
    , isStrict :: Bool       -- ^ context flag consulted when deciding whether to lift an expression
    , declEnv :: [(TVr,E)]   -- ^ bindings passed to the type inferencer and consulted by 'pLift'
    }

-- | Apply a function to the 'isStrict' field of the environment.
isStrict_u :: (Bool -> Bool) -> S -> S
isStrict_u f r = r { isStrict = f (isStrict r) }
-- | Apply a function to the 'topVars' field of the environment.
topVars_u :: (IdSet -> IdSet) -> S -> S
topVars_u f r = r { topVars = f (topVars r) }
-- | Overwrite the 'isStrict' field of the environment with the given value.
isStrict_s :: Bool -> S -> S
isStrict_s v = isStrict_u (\_ -> v)

{-
etaReduce :: E -> (E,Int)
etaReduce e = case f e 0 of
        (ELam {},_) -> (e,0)
        x -> x
    where
        f (ELam t (EAp x (EVar t'))) n | n `seq` True, t == t' && not (tvrIdent t `member` (freeVars x :: IdSet)) = f x (n + 1)
        f e n = (e,n)
-}

-- | We do not lift functions that only appear in saturated strict contexts,
-- as these functions will never have an escaping thunk or partial application
-- built and can be turned into local functions in grin.
--
-- Although grin is only able to take advantage of groups of possibly
-- mutually recursive local functions that only tail-call each other, we leave
-- all candidate functions local, as further grin transformations can expose
-- tail-calls that aren't evident in core.
--
-- A final lambda-lifting needs to be done in grin to get rid of these local
-- functions that cannot be turned into loops.
--
-- The result is the set of identifiers whose boolean value is still False
-- once the fixpoint of the generated rules has been found; 'lambdaLift'
-- treats the members of this set as functions that are not to be lifted.

calculateLiftees :: Program -> IO IdSet
calculateLiftees prog = do
    fixer <- newFixer
    sup <- newSupply fixer

    -- f v env e: generate rules for expression e.
    --   v   is the boolean value describing the context e appears in
    --   env maps the let-bound identifiers in scope to the number of lambdas
    --       at the head of their definition
    -- The order of the clauses matters: the application clauses with pattern
    -- guards have to be tried before the catch-all EAp clause at the end.
    let f v env ELetRec { eDefs = ds, eBody = e } = do
            let nenv = fromList [ (tvrIdent t,length (snd (fromLam e))) | (t,e) <- ds ]  `mappend` env
                nenv :: IdMap Int
                -- the body of a let-bound lambda is examined in the context
                -- given by the value belonging to its own binder
                g (t,e@ELam {}) = do
                    v <- supplyValue sup (tvrIdent t)
                    let (a,_as) = fromLam e
                    f v nenv a
                -- any other definition is examined in a context fixed to True
                g (t,e) = do
                    f (value True) nenv e
            mapM_ g ds
            f v nenv e
        f v env e@ESort {} = return ()
        f v env e@Unknown {} = return ()
        f v env e@EError {} = return ()
        -- a variable mentioned without any arguments has its value asserted
        f v env (EVar TVr { tvrIdent = vv }) = do
            nv <- supplyValue sup vv
            assert nv
        -- application of a let-bound identifier found in env: when at least as
        -- many arguments as recorded lambdas are supplied the context value
        -- implies the callee's value, otherwise the callee's value is asserted
        f v env e | (EVar TVr { tvrIdent = vv }, as@(_:_)) <- fromAp e, Just n <- mlookup vv env = do
            nv <- supplyValue sup vv
            if length as >= n then v `implies` nv else assert nv
            mapM_ (f (value True) env) as
        -- any other application: the arguments are examined in a True context,
        -- the head in the current one
        f v env e | (a, as@(_:_)) <- fromAp e = do
            mapM_ (f (value True) env) as
            f v env a
        f v env (ELit LitCons { litArgs = as }) = mapM_ (f (value True) env) as
        f v env ELit {} = return ()
        f v env (EPi TVr { tvrType = a } b) = f (value True) env a >> f (value True) env b
        f v env (EPrim _ as _) = mapM_ (f (value True) env) as
        -- scrutinee and branch bodies of a case keep the current context
        f v env ec@ECase {} = do
            f v env (eCaseScrutinee ec)
            mapM_ (f v env) (caseBodies ec)
        -- the body of a lambda met anywhere else is examined in a True context
        f v env (ELam _ e) = f (value True) env e
        -- every EAp is expected to have been handled by one of the two
        -- fromAp clauses above
        f _ _ EAp {} = error "this should not happen"
    -- the top level definitions are examined below their leading lambdas,
    -- starting from a False context and an empty environment
    mapM_ (f (value False) mempty) [ fst (fromLam e) | (_,e) <- programDs prog]

    findFixpoint Nothing {-"Liftees"-} fixer
    vs <- supplyReadValues sup
    -- keep the identifiers whose value remained False
    let nlset =  (fromList [ x | (x,False) <- vs])
    when verbose $ printf "%d lambdas not lifted\n" (size nlset)
    return nlset

-- | @premise \`implies\` conclusion@ registers a fixer rule stating that
-- @conclusion@ is a superset of @premise@.
implies :: Value Bool -> Value Bool -> IO ()
implies premise conclusion = addRule (conclusion `isSuperSetOf` premise)

-- | Register the rule that the constant @value True@ implies the given value.
assert :: Value Bool -> IO ()
assert target = implies (value True) target

-- | Lambda lift a whole program.
--
-- 'calculateLiftees' first determines the set of identifiers that are not to
-- be lifted (@noLift@).  Each combinator of the program is then processed by
-- @z@, which runs the traversal @f@ over its body and accumulates
--
--   * the rewritten combinator together with the newly created ones in @fc@,
--
--   * the renamings recorded by @globalName@ in @fm@,
--
--   * the statistics ticks in @statRef@.
--
-- Finally the collected combinators and statistics are put back into the
-- program and 'annotateProgram' is run with the collected renaming map.
lambdaLift ::  Program -> IO Program
lambdaLift prog@Program { progDataTable = dataTable, progCombinators = cs } = do
    noLift <- calculateLiftees prog
    -- identifiers of all the combinators already in the program
    let wp =  fromList [ combIdent x | x <- cs ] :: IdSet
    fc <- newIORef []
    fm <- newIORef mempty
    statRef <- newIORef mempty
    -- process one combinator: the traversal is a pure computation, only its
    -- results are stored into the IORefs
    let z comb  = do
            (n,as,v) <- return $ combTriple comb
            let ((v',(cs',rm)),stat) = runReader (runStatT $ execUniqT 1 $ runWriterT (f v)) S { funcName = mkFuncName (tvrIdent n), topVars = wp,isStrict = True, declEnv = [] }
            modifyIORef statRef (mappend stat)
            modifyIORef fc (\xs -> combTriple_s (n,as,v') comb:cs' ++ xs)
            modifyIORef fm (rm `mappend`)
        -- binders in noLift are never lifted; of the remaining ones only
        -- case expressions and lambdas are candidates
        shouldLift t _ | tvrIdent t `member` noLift = False
        shouldLift _ ECase {} = True
        shouldLift _ ELam {} = True
        shouldLift _ _ = False
        -- f is the main traversal: let expressions are decomposed and handed
        -- to h, anything else is either lifted on the spot or traversed by g
        f e@(ELetRec ds _)  = do
            let (ds',e') = decomposeLet e
            h ds' e' []
        -- NOTE(review): no `tvr` is bound locally in this clause, so the name
        -- has to come from one of the imported modules -- confirm that
        -- testing that binder against noLift is what is intended here.
        f e = do
            st <- asks isStrict
            if ((tvrIdent tvr `notMember` noLift && isELam e) || (shouldLift tvr e && not st)) then do
                (e,fvs'') <- pLift e
                doBigLift e fvs'' return
             else g e
        -- This ensures there are no 'orphaned type terms' when something is
        -- lifted out.  The problem occurs when a type is subsituted in some
        -- places and not others, the type as free variable will not be the
        -- same as its substituted instances if the variable is bound by a
        -- lambda, Although the program is still typesafe, it is no longer
        -- easily proven so, so we avoid the whole mess by subtituting known
        -- type variables within lifted expressions. This can not duplicate work
        -- since types are unpointed, but might change space usage slightly.
--        g ec@ECase { eCaseScrutinee = (EVar v), eCaseAlts = as, eCaseDefault = d} | sortKindLike (tvrType v) = do
--            True <- asks isStrict
--            d' <- fmapM f d
--            let z (Alt l e) = do
--                    e' <- local (declEnv_u ((v,followAliases dataTable $ patToLitEE l):)) $ f e
--                    return $ Alt l e'
--            as' <- mapM z as
--            return $ caseUpdate ec { eCaseAlts = as', eCaseDefault = d'}
        -- the body of a lambda that stays in place is traversed with the
        -- strict flag set
        g (ELam t e) = do
            e' <- local (isStrict_s True) (g e)
            return (ELam t e')
        g e = emapE' f e
        -- pLift returns the expression together with its free variables that
        -- are not in topVars, ordered by a reversed topological sort.  As ss
        -- is currently the empty list the substitution loop f does nothing.
        pLift e = do
            gs <- asks topVars
            ds <- asks declEnv
            let fvs = freeVars e
                fvs' = filter (not . (`member` gs) . tvrIdent) fvs
                --ss = filter (sortKindLike . tvrType) fvs'
                ss = []
                f [] e False = return (e,fvs'')
                f [] e True = pLift e
                f (s:ss) e x
                    | Just v <- lookup s ds = f ss (removeType s v e) True   -- TODO subst
                    | otherwise = f ss e x
                fvs'' = reverse $ topSort $ newGraph fvs' tvrIdent freeVars
            f ss e False
        -- h processes the decomposed definitions one by one: Left holds a
        -- single binding, Right a group of bindings handled together.  The
        -- second argument is the body of the let and the third accumulates
        -- the definitions that stay local.
        h (Left (t,e):ds) rest ds' | shouldLift t e = do
            (e,fvs'') <- pLift e
            case fvs'' of
                [] -> doLift t e (h ds rest ds')
                fs -> doBigLift e fs (\e'' -> h ds rest ((t,e''):ds'))
        -- a lambda that is not lifted stays where it is, its body is
        -- traversed with the strict flag set
        h (Left (t,e@ELam {}):ds) rest ds' = do
            let (a,as) = fromLam e
            a' <- local (isStrict_s True) (f a)
            h ds rest ((t,foldr ELam a' as):ds')

        h (Left (t,e):ds) rest ds'  = do
            let fvs =  freeVars e :: [Id]
            gs <- asks topVars
            let fvs' = filter (not . (`member` gs) ) fvs
            case fvs' of
                [] -> doLift t e (h ds rest ds')  -- We always lift CAFS to the top level for now. (GC?)
                _ ->  local (isStrict_s False) (f e) >>= \e'' -> h ds rest ((t,e''):ds')
        --h (Left (t,e):ds) e' ds' = local (isStrict_s False) (f e) >>= \e'' -> h ds e' ((t,e''):ds')
        -- a group containing at least one liftable definition: the free
        -- variables are taken over the whole group, leaving out the binders
        -- of the group itself and topVars
        h (Right rs:ds) rest ds' | any (uncurry shouldLift) rs  = do
            gs <- asks topVars
            let fvs =  freeVars (snds rs)--   (Set.fromList (map tvrIdent $ fsts rs) `Set.union` gs)
            let fvs' = filter (not . (`member` (fromList (map tvrIdent $ fsts rs) `mappend` gs) ) . tvrIdent) fvs
                fvs'' = reverse $ topSort $ newGraph fvs' tvrIdent freeVars
            case fvs'' of
                [] -> doLiftR rs (h ds rest ds')  -- We always lift CAFS to the top level for now. (GC?)
                fs -> doBigLiftR rs fs (\rs' -> h ds rest (rs' ++ ds'))
        -- a group that stays local as a whole
        h (Right rs:ds) e' ds'   = do
            rs' <- local (isStrict_s False) $ do
                flip mapM rs $ \te -> case te of
                    (t,e@ELam {}) -> do
                        let (a,as) = fromLam e
                        a' <- local (isStrict_s True) (f a)
                        return (t,foldr ELam a' as)
                    (t,e) -> do
                        e'' <- f e
                        return (t,e'')
            h ds e' (rs' ++ ds')
        -- all definitions done: traverse the body and rebuild the let from
        -- the definitions that stayed local
        h [] e ds = f e >>= return . eLetRec ds
        -- emit new combinators through the writer; the second component of
        -- the output is the renaming map
        tellCombinator c = tell ([combTriple_s c emptyComb],mempty)
        tellCombinators c = tell (map (`combTriple_s` emptyComb) c,mempty)
        -- lift a binding as it is: its identifier is added to topVars for
        -- the traversal of its body as well as for the continuation r
        doLift t e r = local (topVars_u (insert (tvrIdent t)) ) $ do
            --(e,tn) <- return $ etaReduce e
            let (e',ls) = fromLam e
            mtick (toAtom $ "E.LambdaLift.doLift." ++ typeLift e ++ "." ++ show (length ls))
            --mticks tn (toAtom $ "E.LambdaLift.doLift.etaReduce")
            e'' <- local (isStrict_s True) $ f e'
            t <- globalName t
            tellCombinator (t,ls,e'')
            r
        -- the same as doLift for every binding of a group
        doLiftR rs r = local (topVars_u (mappend (fromList (map (tvrIdent . fst) rs)) )) $ do
            flip mapM_ rs $ \ (t,e) -> do
                --(e,tn) <- return $ etaReduce e
                let (e',ls) = fromLam e
                mtick (toAtom $ "E.LambdaLift.doLiftR." ++ typeLift e ++ "." ++ show (length ls))
                --mticks tn (toAtom $ "E.LambdaLift.doLift.etaReduce")
                e'' <- local (isStrict_s True) $ f e'
                t <- globalName t
                tellCombinator (t,ls,e'')
            r
        -- a binder for whose identifier fromId yields no name is given a
        -- fresh one, and the renaming is recorded in the writer output
        globalName tvr | isNothing $ fromId (tvrIdent tvr) = do
            TVr { tvrIdent = t } <- newName Unknown
            let ntvr = tvr { tvrIdent = t }
            tell ([],msingleton (tvrIdent tvr) (Just $ EVar ntvr))
            return ntvr
        globalName tvr = return tvr
        -- fresh binder of type tt, named after funcName with '$' and a unique
        -- number appended
        newName tt = do
            un <-  newUniq
            n <- asks funcName
            return $ tVr (toId $ mapName (id,(++ ('$':show un))) n) tt
        -- lift an expression with free variables fs: the new combinator
        -- takes fs followed by the expression's own lambda-bound variables,
        -- and the continuation dr receives the new name applied to fs
        doBigLift e fs  dr = do
            mtick (toAtom $ "E.LambdaLift.doBigLift." ++ typeLift e ++ "." ++ show (length fs))
            ds <- asks declEnv
            let tt = typeInfer' dataTable ds (foldr ELam e fs)
            tvr <- newName tt
            let (e',ls) = fromLam e
            e'' <- local (isStrict_s True) $ f e'
            tellCombinator (tvr,fs ++ ls,e'')
            let e'' = foldl EAp (EVar tvr) (map EVar fs)
            dr e''
        -- group version of doBigLift: only the definitions shouldLift
        -- accepts are lifted.  The new combinators are emitted at the end,
        -- after substLet has been applied to their bodies with the rewritten
        -- group rs'.
        doBigLiftR rs fs dr = do
            ds <- asks declEnv
            rst <- flip mapM rs $ \ (t,e) -> do
                case shouldLift t e of
                    True -> do
                        mtick (toAtom $ "E.LambdaLift.doBigLiftR." ++ typeLift e ++ "." ++ show (length fs))
                        let tt = typeInfer' dataTable ds (foldr ELam e fs)
                        tvr <- newName tt
                        let (e',ls) = fromLam e
                        e'' <- local (isStrict_s True) $ f e'
                        --tell [(tvr,fs ++ ls,e'')]
                        let e''' = foldl EAp (EVar tvr) (map EVar fs)
                        return ((t,e'''),[(tvr,fs ++ ls,e'')])
                    False -> do
                        mtick (toAtom $ "E.LambdaLift.skipBigLiftR." ++ show (length fs))
                        return ((t,e),[])
            let (rs',ts) = unzip rst
            tellCombinators [ (t,ls,substLet rs' e) | (t,ls,e) <- concat ts]
            dr rs'

        -- name from which the names of lifted functions are derived: the
        -- name behind the combinator's identifier when it has one
        mkFuncName x = case fromId x of
            Just y -> y
            Nothing -> toName Val ("LL@",'f':show x)
    mapM_ z cs
    ncs <- readIORef fc
    nstat <- readIORef statRef
    nz <- readIORef fm
    annotateProgram nz (\_ nfo -> return nfo) (\_ nfo -> return nfo) (\_ nfo -> return nfo) prog { progCombinators =  ncs, progStats = progStats prog `mappend` nstat }

-- | Short tag describing the shape of an expression; used as a component of
-- the statistics tick names.
typeLift :: E -> String
typeLift e = case e of
    ECase {} -> "Case"
    ELam {} -> "Lambda"
    _ -> "Other"

-- | Replace the variable @tv@ by @replacement@ within @expr@; this simply
-- delegates to @subst'@.
removeType tv replacement expr = subst' tv replacement expr
{-
removeType t v e = ans where
    (b,ls) = fromLam e
    ans = foldr f (substLet [(t,v)] e) ls
    f tv@(TVr { tvrType = ty} ) e = ELam nt (subst tv (EVar nt) e) where nt = tv { tvrType = (subst t v ty) }
-}