{- |
Module: Coroutine.hs
Description: Resumable coroutines implementation
Copyright: (c) 2014 Galois, Inc.

This is an implementation of coroutines in Ivory.  These coroutines:

  * may be suspended and resumed,
  * are parametrized over any memory-area type ('Area'),
  * cannot return values, but can receive values when they are resumed (though
this is an incidental detail of the current implementation),
  * for now, can have only one instance executing at once, given a particular
coroutine name and Ivory module.

-}
{-# LANGUAGE DataKinds           #-}
{-# LANGUAGE FlexibleContexts    #-}
{-# LANGUAGE RankNTypes          #-}
{-# LANGUAGE RecordWildCards     #-}
{-# LANGUAGE ScopedTypeVariables #-}
{-# LANGUAGE TypeFamilies        #-}
{-# LANGUAGE TypeOperators       #-}
{-# OPTIONS_GHC -fsimpl-tick-factor=200 #-}

module Ivory.Language.Coroutine (
  -- * Usage Notes
  -- $usageNotes

  -- * Implementation Notes
  -- $implNotes

  Coroutine(..), CoroutineBody(..), coroutine,
  ) where

import           Prelude                ()
import           Prelude.Compat

import           Control.Monad          (unless, when)
import           Control.Monad.Fix      (mfix)
import qualified Data.DList             as D
import qualified Data.IntMap            as IntMap
import qualified Data.Map               as Map
import           Ivory.Language.Area
import           Ivory.Language.Array
import           Ivory.Language.Effects
import           Ivory.Language.IBool
import           Ivory.Language.Module
import           Ivory.Language.Monad
import           Ivory.Language.Proc
import           Ivory.Language.Proxy
import           Ivory.Language.Ref
import qualified Ivory.Language.Syntax  as AST
import           Ivory.Language.Type
import qualified MonadLib

-- Optimizations TODO:
-- TODO: if every block yields, don't generate a loop
-- TODO: re-use continuation variables that have gone out of scope
-- TODO: full liveness analysis to maximize re-use
-- TODO: only extract a variable to the continuation if it is live across a suspend

{- $usageNotes

The coroutine itself is presented as a function which has one argument,
@yield@.  @yield@ suspends the coroutine's execution at the point of usage.

'coroutineRun' (given a first argument of 'false') then resumes the suspended
coroutine at that point, passing in its second argument as a value in the
process.

-}

{- $implNotes

This implementation handles a coroutine in the form of a single contiguous
block of Ivory code.  It turns this block into a single large C function,
breaking it up into an initialization portion and a series of branches inside
an infinite loop.  Each branch represents a resume and suspend point within
that block.

It was mentioned that @yield@ /suspends/ a coroutine for later resumption.
The state of the computation to be resumed later - that is, its
/continuation/ - is stored in a C struct, not visible to the outer Ivory code,
but denoted internally with a suffix of @_continuation@.  This C struct
contains a state variable and every local variable that is created.

At each @yield@, after collecting the current computation state into the
continuation, the C function breaks out of its loop and returns.

The @yield@ action is implemented as a sort of pseudo-function call which is
given a purposely-invalid name (see 'yieldName'). This Ivory call does not
become an actual function call in generated code, but rather, the code is
suspended right there.  (See the 'extractLocals' and 'addYield' functions.)

The pseudo-function call returns a 'ConstRef' to the coroutine's type.  Of
course, no actual function call exists, but the action itself still returns
something - hence, Ivory code is able to refer to the return value of a
@yield@.  Dereferences of it via 'deref' are turned into references to the
continuation struct.

As this @yield@ action is an Ivory effect, it can be passed at the Haskell
level, but it cannot be passed as a C value. (If one could hypothetically run
'procPtr' on it, then one could pass it around as a value, but any attempt to
call it outside of the same coroutine would result in either an invalid
function call, or a suspend & return of a different coroutine.)

-}

-- | Concrete representation of a coroutine; use 'coroutine' to produce one.
--
-- The type parameter @a@ is the memory-area type of the value handed to the
-- coroutine each time it is resumed.
data Coroutine a = Coroutine
  { -- | Name of the coroutine (must be a valid C identifier).  'coroutine'
    -- also uses it as the prefix of the generated continuation struct type
    -- (@\<name\>_continuation@) and of the global area holding it
    -- (@\<name\>_cont@).
    coroutineName :: String
    -- | Ivory effect for initializing a coroutine (if first argument is
    -- 'true'), or resuming it (if first argument is 'false').  The second
    -- argument is passed to the coroutine when resuming, and ignored when
    -- initializing.  The effect context must provide an allocation scope.
  , coroutineRun :: forall eff s s'. GetAlloc eff ~ 'Scope s' =>
                    IBool -> ConstRef s a -> Ivory eff ()
    -- | The components a 'Module' will need for this coroutine: the
    -- continuation struct definition and the global area that stores the
    -- continuation.
  , coroutineDef :: ModuleDef
  }

-- | The definition of a coroutine body, in the form of a function taking one
-- argument: an Ivory effect that is the @yield@ action which suspends the
-- coroutine, and sets the point at which 'coroutineRun' resumes it, passing a
-- value.
--
-- The @yield@ action produces a 'Ref' to an area of type @a@ (the value
-- supplied on resumption).  Its second 'Effects' component (@b@) is left
-- universally quantified, so the same @yield@ can be used at any
-- instantiation of that component within the body.  The body itself runs
-- with 'ProcEffects' and a @()@ return type, i.e. it cannot return a value.
newtype CoroutineBody a =
  CoroutineBody (forall s1 s2 .
                 (forall b .
                  Ivory ('Effects ('Returns ()) b ('Scope s2)) (Ref s1 a)) ->
                         Ivory (ProcEffects s2 ()) ())

-- | Smart constructor for a 'Coroutine'
--
-- The body is run once, at Haskell level, with a fake call to 'yieldName'
-- standing in for @yield@.  The resulting statements are split into labelled
-- basic blocks, and every local is moved into a continuation struct stored in
-- a single global area.  The three record fields of the result
-- ('coroutineName', 'coroutineRun', 'coroutineDef') are filled in from the
-- identically named bindings in the @where@ clause via @RecordWildCards@.
coroutine :: forall a. IvoryArea a =>
             String -- ^ Coroutine name (must be a valid C identifier)
             -> CoroutineBody a -- ^ Coroutine definition
             -> Coroutine a
coroutine name (CoroutineBody fromYield) = Coroutine { .. }
  where
  ((), CodeBlock { blockStmts = rawCode }) =
    runIvory $ fromYield $ call (proc yieldName $ body $ return ())
  -- The above call is a fake function call which code elsewhere extracts from
  -- the AST and replaces.

  -- Read-only environment for the translation.  'getCont' maps a field name
  -- to an expression selecting that field of the global continuation struct.
  -- 'getBreakLabel' is deliberately an error until a loop installs a real
  -- target with 'setBreakLabel'.
  params = CoroutineParams
           { getCont = AST.ExpLabel strTy $ AST.ExpAddrOfGlobal $
                       AST.areaSym contArea
           , getBreakLabel =
             error "Ivory.Language.Coroutine: no break label set, but breakOut called"
           }

  -- Translation starts with no rewrites, no labelled blocks, and the
  -- fresh-name counter at zero.
  initialState = CoroutineState
                 { rewrites = Map.empty
                 , labels = []
                 , derefs = 0
                 }

  -- Even the initial block needs a label, in case there's a yield or
  -- return before any control flow statements. Otherwise, the resulting
  -- resumeAt call will emit a 'break;' statement outside of the
  -- forever-loop that the state machine runs in, which is invalid C.
  initCode = makeLabel' =<< getBlock rawCode (resumeAt 0)
  -- Unwrap the 'CoroutineMonad' layers from the outside in: the label paired
  -- with the (unused) top-level statement writer output, then the collected
  -- continuation fields and resume handlers, then the final state holding
  -- every labelled block.
  (((initLabel, _), (localVars, resumes)), finalState) =
     MonadLib.runM initCode params initialState
  -- Initialization only records the entry label as the current state (no
  -- suspend); the dispatch loop in 'coroutineRun' then starts executing it.
  initBB = BasicBlock [] $ BranchTo False initLabel

  -- | Name of the continuation struct type:
  strName = name ++ "_continuation"

  -- | Definitions of the continuation struct type:
  strDef = AST.Struct strName $
           -- State variable:
           AST.Typed stateType stateName :
           D.toList localVars
  strTy = AST.TyStruct strName

  -- | Area for the continuation struct instance:
  contArea = AST.Area (name ++ "_cont") False strTy AST.InitZero

  coroutineName = name

  -- Turn a block label into a literal expression for comparison with, or
  -- storage into, the state variable.
  litLabel = AST.ExpLit . AST.LitInteger . fromIntegral

  -- Flatten a basic block back into statements: the block's own statements,
  -- followed by either a store of the next label into the continuation's
  -- state field (plus a 'break' out of the dispatch loop when suspending), or
  -- an if/else over the two nested successor blocks.
  genBB (BasicBlock pre term) = pre ++ case term of
    BranchTo suspend label -> (AST.Store stateType (getCont params stateName) $ litLabel label) : if suspend then [AST.Break] else []
    CondBranchTo cond tb fb -> [AST.IfTE cond (genBB tb) (genBB fb)]

  coroutineRun :: IBool -> ConstRef s a -> Ivory eff ()
  coroutineRun doInit arg = do
    -- When initializing, reset the state to the entry label first.
    ifte_ doInit (emits mempty { blockStmts = genBB initBB }) (return ())
    -- The dispatch loop: each iteration loads the current state from the
    -- continuation, then runs the one block whose label matches it.  The
    -- list-monad 'do' below produces one 'AST.IfTE' per surviving block.
    emit $ AST.Forever $ (AST.Deref stateType (AST.VarName stateName) $ getCont params stateName) : do
      -- Label 0 is a synthetic block that just suspends and stays at label 0;
      -- the remaining labels are numbered from 1 in creation order (the state
      -- keeps them most-recent-first, hence the 'reverse').
      (label, block) <- keepUsedBlocks initLabel $ zip [0..] $ map joinTerminators $ (BasicBlock [] $ BranchTo True 0) : reverse (labels finalState)
      let cond = AST.ExpOp (AST.ExpEq stateType)
                 [AST.ExpVar (AST.VarName stateName), litLabel label]
          -- Blocks that are resume points first copy the resume argument
          -- into the continuation; other blocks have no such prologue.
          fn :: AST.Expr -> [AST.Stmt]
          fn = Map.findWithDefault (const []) label resumes
          b' = fn (unwrapExpr arg) ++ genBB block
      return $ AST.IfTE cond b' []

  -- Register the continuation struct type and its global area with the
  -- enclosing module, using whatever visibility is currently in effect.
  coroutineDef = do
    visibility <- MonadLib.ask
    MonadLib.put $ mempty
      { AST.modStructs = visAcc visibility strDef
      , AST.modAreas = visAcc visibility contArea
      }

-- | This is used as the name of a pseudo-function call which marks the
-- 'yield' in a coroutine.  It is purposely not a valid C identifier so that
-- it can't collide with a real procedure.  'coroutine' plants a call to a
-- procedure of this name, and 'extractLocals' recognises calls to it and
-- hands them to 'addYield' instead of emitting a call.
yieldName :: String
yieldName = "+yield"

-- | Name of the variable used for the coroutine's current state.  The same
-- name is used both for the state field of the continuation struct and for
-- the local variable that field is loaded into at the top of the dispatch
-- loop.
stateName :: String
stateName = "state"

-- | The type used for the continuation's state, i.e. the type at which block
-- labels are stored into and compared against the state variable.
stateType :: AST.Type
stateType = AST.TyWord AST.Word32

-- | A straight-line sequence of statements followed by the 'Terminator' that
-- says where control goes next.
data BasicBlock = BasicBlock AST.Block Terminator

-- | A label identifying a 'BasicBlock'; it is also the value stored in the
-- continuation's state variable to select that block.
type Goto = Int

-- | How a 'BasicBlock' ends.
data Terminator
  = BranchTo Bool Goto
    -- ^ Continue at the given label.  When the flag is 'True' the coroutine
    -- suspends first (see 'resumeAt'), so the target runs on the next resume;
    -- when 'False' control moves on immediately (see 'goto').
  | CondBranchTo AST.Expr BasicBlock BasicBlock
    -- ^ Test the expression and continue with the first block if it holds,
    -- otherwise with the second.

-- | Simplify conditional terminators, bottom-up: when both arms of a
-- 'CondBranchTo' end in the very same unconditional branch (same suspend
-- flag, same target), the conditional is folded into an ordinary 'AST.IfTE'
-- statement appended to the block, and the shared branch becomes the block's
-- terminator.  Anything else is left as a (recursively simplified)
-- conditional branch.
joinTerminators :: BasicBlock -> BasicBlock
joinTerminators block = case block of
  BasicBlock pre (CondBranchTo cond thenBB elseBB) ->
    merge pre cond (joinTerminators thenBB) (joinTerminators elseBB)
  _ -> block
  where
  -- Both arms agree on where to go next: hoist the branch out.
  merge pre cond
        (BasicBlock thenStmts (BranchTo thenSuspend thenDst))
        (BasicBlock elseStmts (BranchTo elseSuspend elseDst))
    | thenSuspend == elseSuspend && thenDst == elseDst =
        BasicBlock (pre ++ [AST.IfTE cond thenStmts elseStmts])
                   (BranchTo thenSuspend thenDst)
  -- Otherwise keep the conditional branch, with simplified arms.
  merge pre cond thenBB' elseBB' =
    BasicBlock pre (CondBranchTo cond thenBB' elseBB')

-- | Inline the block labelled @inlineLabel@ into every other block that
-- jumps to it without suspending.
--
-- Each non-suspending @BranchTo False inlineLabel@ (including ones nested
-- inside conditional branches) is replaced by the statements and terminator
-- of the target block; conditionals that were touched are re-simplified with
-- 'joinTerminators'.  The entry for @inlineLabel@ itself is returned
-- unchanged, and suspending branches are never inlined.
--
-- The target block is looked up lazily, so a label that is absent from
-- @blocks@ only raises an error if some block actually branches to it.
doInline :: Goto -> [(Goto, BasicBlock)] -> [(Goto, BasicBlock)]
doInline inlineLabel blocks = map inlineInto blocks
  where
  -- Lazy pattern binding: only forced when an inlining site is found.
  BasicBlock newStmts tgt = case lookup inlineLabel blocks of
    Just found -> found
    Nothing -> error $
      "Ivory.Language.Coroutine: doInline: no block with label "
      ++ show inlineLabel

  -- Leave the inlined block's own entry alone; rewrite all the others.
  inlineInto entry@(label, bb)
    | label == inlineLabel = entry
    | otherwise = (label, inlineBlock bb)

  inlineBlock (BasicBlock b (BranchTo False dst))
    | dst == inlineLabel = BasicBlock (b ++ newStmts) tgt
  inlineBlock (BasicBlock b (CondBranchTo cond tb fb))
    = joinTerminators $ BasicBlock b $ CondBranchTo cond (inlineBlock tb) (inlineBlock fb)
  inlineBlock bb = bb

-- | Prune and compact the labelled blocks reachable from @root@.
--
-- A reference count is computed for every label reachable from @root@: one
-- per branch that targets it, plus one extra for @root@ itself and one extra
-- for every suspending branch (so resume points and the entry block always
-- end up with a count above one).  Then:
--
--   * labels with exactly one reference are inlined into their single
--     referrer via 'doInline', and
--   * only labels with more than one reference are kept in the result;
--     unreachable labels are dropped.
--
-- A branch to a label that is not present in @blocks@ is reported with an
-- explicit error.
keepUsedBlocks :: Goto -> [(Goto, BasicBlock)] -> [(Goto, BasicBlock)]
keepUsedBlocks root blocks = sweep $ snd $ MonadLib.runM (mark root >> ref root) IntMap.empty
  where
  mark :: Goto -> MonadLib.StateT (IntMap.IntMap Int) MonadLib.Id ()
  mark label = do
    -- Snapshot taken before counting this reference, so that the successors
    -- of a label are only traversed the first time it is reached.
    seen <- MonadLib.get
    ref label
    unless (label `IntMap.member` seen) $
      case lookup label blocks of
        Just b -> markBlock b
        Nothing -> error $
          "Ivory.Language.Coroutine: keepUsedBlocks: branch to unknown label "
          ++ show label
  -- Count one more reference to the label.
  ref label = MonadLib.sets_ $ IntMap.insertWith (+) label 1
  markBlock (BasicBlock _ (BranchTo suspend label)) = do
    mark label
    -- add an extra reference for yield targets
    when suspend $ ref label
  markBlock (BasicBlock _ (CondBranchTo _ tb fb)) = markBlock tb >> markBlock fb
  -- Inline every singly-referenced label, then keep the multiply-referenced
  -- ones.
  sweep used = filter (\ (label, _) -> IntMap.findWithDefault 0 label used > 1) $
    foldr doInline blocks [ label | (label, 1) <- IntMap.assocs used ]

-- | Read-only environment of the coroutine translation.
data CoroutineParams = CoroutineParams
  { -- | Given a field name, the expression that selects that field of the
    -- continuation struct.
    getCont       :: String -> AST.Expr
    -- | Where an 'AST.Break' in the code being translated should continue;
    -- installed for the body of a loop by 'setBreakLabel'.
  , getBreakLabel :: Terminator
  }

-- | Mutable state threaded through the coroutine translation.
data CoroutineState = CoroutineState
  { -- | For each variable that has been rewritten (see 'rewriteTo'), the
    -- action that emits any statements needed and yields the expression to
    -- use in its place.
    rewrites :: Map.Map AST.Var (CoroutineMonad AST.Expr)
    -- | Every labelled block created so far, most recent first (see
    -- 'makeLabel'').
  , labels   :: [BasicBlock]
    -- | Counter used by 'addLocal' to generate fresh names for the temporary
    -- variables that hold dereferenced continuation fields.
  , derefs   :: !Integer
  }

-- | For each resume point (keyed by its label), a function that, given the
-- expression for the value passed on resumption, produces the statements to
-- run before the block itself.
type CoroutineResume = Map.Map Goto (AST.Expr -> AST.Block)

-- | Output accumulated over the whole translation: the fields to add to the
-- continuation struct, and the resume handlers.
type CoroutineVars = (D.DList (AST.Typed String), CoroutineResume)

-- | The monad in which the translation runs.  From the outside in:
--
--   * a writer of the statements emitted for the block currently being built
--     (see 'stmt' and 'getBlock'),
--   * a reader of 'CoroutineParams',
--   * a writer of 'CoroutineVars', and
--   * the 'CoroutineState'.
type CoroutineMonad = MonadLib.WriterT
                      (D.DList AST.Stmt)
                      (MonadLib.ReaderT CoroutineParams
                       (MonadLib.WriterT CoroutineVars
                        (MonadLib.StateT CoroutineState MonadLib.Id)))

-- | Walk an 'AST.Stmt', update its variables as necessary to refer to the
-- continuation struct, and any time a new local is introduced, add it to the
-- continuation struct.
--
-- The second argument is the translation of everything that follows the
-- statement in its block (see the @foldr@ in 'getBlock').  It is run after
-- the statement's own code has been emitted, except for statements that end
-- control flow ('AST.ReturnVoid' and 'AST.Break'), which discard it.  The
-- result is the 'Terminator' of the basic block currently being built.
--
-- Statements whose shape violates an assumption of this translation (a
-- @yield@ call without a result or with arguments, a local that is not an
-- 'AST.VarName', an 'AST.AllocRef' of something other than a variable)
-- raise a descriptive error when the offending part is demanded.
extractLocals :: AST.Stmt
                 -> CoroutineMonad Terminator
                 -> CoroutineMonad Terminator
extractLocals (AST.IfTE cond tb fb) rest = do
  -- The code after the conditional is translated once, up front; both
  -- branches end by continuing to it.
  after <- makeLabel rest
  CondBranchTo <$> runUpdateExpr (updateExpr cond)
    <*> getBlock tb (return after)
    <*> getBlock fb (return after)
extractLocals (AST.Assert cond) rest =
  (AST.Assert <$> runUpdateExpr (updateExpr cond)) >>= stmt >> rest
extractLocals (AST.CompilerAssert cond) rest =
  (AST.CompilerAssert <$> runUpdateExpr (updateExpr cond)) >>= stmt >> rest
extractLocals (AST.Assume cond) rest =
  (AST.Assume <$> runUpdateExpr (updateExpr cond)) >>= stmt >> rest
extractLocals (AST.Return {}) _ =
  error "Ivory.Language.Coroutine: can't return a value from the coroutine body"
-- XXX: this discards any code after a return. is that OK?
-- Returning suspends at label 0, the synthetic block that only suspends
-- again.
extractLocals (AST.ReturnVoid) _ = resumeAt 0
extractLocals (AST.Deref ty var ex) rest =
  (AST.RefCopy ty <$> addLocal ty var <*> runUpdateExpr (updateExpr ex)) >>=
  stmt >> rest
  -- Note here that an 'AST.Deref' also emits a 'RefCopy' and another local.
extractLocals (AST.Store ty lhs rhs) rest =
  (runUpdateExpr $ AST.Store ty <$> updateExpr lhs <*> updateExpr rhs) >>=
  stmt >> rest
extractLocals (AST.Assign ty var ex) rest =
  -- An assignment becomes a store into the variable's continuation field.
  (AST.Store ty <$> addLocal ty var <*> runUpdateExpr (updateExpr ex)) >>=
  stmt >> rest
extractLocals (AST.Call ty mvar name args) rest
  -- 'yieldName' is the pseudo-function call which we handle specially at this
  -- point using 'addYield':
  | name == AST.NameSym yieldName = do
      -- XXX: yield takes no arguments and always returns something
      -- (Kept lazy: the check only fires if the variable is demanded.)
      let var = case (mvar, args) of
            (Just v, []) -> v
            _ -> error "Ivory.Language.Coroutine: yield must bind its result and take no arguments"
      addYield ty var rest
  | otherwise = do
      -- All other function calls pass through normally, but have their
      -- arguments run through 'updateTypedExpr' and have their results saved
      -- into the continuation:
      stmt =<< AST.Call ty mvar name <$>
        runUpdateExpr (mapM updateTypedExpr args)
      case mvar of
       Nothing -> return ()
       Just var -> do
         cont <- addLocal ty var
         stmt $ AST.Store ty cont $ AST.ExpVar var
      rest
extractLocals (AST.Local ty var initex) rest = do
  cont <- addLocal ty var
  let varStr = case var of
        AST.VarName s -> s
        _ -> error "Ivory.Language.Coroutine: extractLocals: local variable is not an AST.VarName"
  let ref = AST.VarName $ varStr ++ "_ref"
  initex' <- runUpdateExpr $ updateInit initex
  -- Build the initial value in a genuine local, then copy it into the
  -- continuation field through a temporary reference.
  stmts
    [ AST.Local ty var initex'
    , AST.AllocRef ty ref $ AST.NameVar var
    , AST.RefCopy ty cont $ AST.ExpVar ref
    ]
  rest
extractLocals (AST.RefCopy ty lhs rhs) rest =
  (runUpdateExpr $ AST.RefCopy ty <$> updateExpr lhs <*> updateExpr rhs) >>=
  stmt >> rest
extractLocals (AST.RefZero ty ref) rest =
  (runUpdateExpr $ AST.RefZero ty <$> updateExpr ref) >>=
  stmt >> rest
extractLocals (AST.AllocRef _ty refvar name) rest = do
  -- XXX: AFAICT, AllocRef can't have a NameSym argument.
  let var = case name of
        AST.NameVar v -> v
        _ -> error "Ivory.Language.Coroutine: extractLocals: AllocRef of something other than a variable"
  -- No statement is emitted: uses of the reference are rewritten to the
  -- address of the variable's continuation field.
  refvar `rewriteTo` contRef var
  rest
extractLocals (AST.Loop _ var initEx incr b) rest = do
  let ty = ivoryType (Proxy :: Proxy IxRep)
  -- The loop counter lives in the continuation like any other local.
  cont <- addLocal ty var
  stmt =<< AST.Store ty cont <$> runUpdateExpr (updateExpr initEx)
  after <- makeLabel rest
  -- 'mfix' ties the knot: the loop header's own terminator is what the end
  -- of the loop body branches back to.  It must not be forced inside.
  mfix $ \ loop -> makeLabel $ do
    let (condOp, incOp, limitEx) = case incr of
          AST.IncrTo ex -> (AST.ExpGt, AST.ExpAdd, ex)
          AST.DecrTo ex -> (AST.ExpLt, AST.ExpSub, ex)
    -- Exit test: leave the loop once the counter has passed the limit.
    cond <- runUpdateExpr $ updateExpr $
            AST.ExpOp (condOp False ty) [AST.ExpVar var, limitEx]
    CondBranchTo cond (BasicBlock [] after) <$> do
      setBreakLabel after $ getBlock b $ do
        -- Step the counter by one, then go round again.
        stmt =<< AST.Store ty cont <$>
          runUpdateExpr (updateExpr $ AST.ExpOp incOp
                         [AST.ExpVar var, AST.ExpLit (AST.LitInteger 1)])
        return loop
extractLocals (AST.Forever b) rest = do
  after <- makeLabel rest
  -- As for 'AST.Loop': the body ends by branching back to its own label, and
  -- a 'break' inside it continues after the loop.
  mfix $ \ loop -> makeLabel $ setBreakLabel after $
                   foldr extractLocals (return loop) b
-- XXX: this discards any code after a break. is that OK?
extractLocals (AST.Break) _ = MonadLib.asks getBreakLabel
extractLocals s@(AST.Comment{}) rest = stmt s >> rest

-- | Translate a list of statements into a 'BasicBlock'.  The statements are
-- threaded through 'extractLocals' from right to left, with @next@ supplying
-- the terminator for whatever follows the last one.  Statements emitted
-- during the translation are captured here rather than passed on to the
-- enclosing block.
getBlock :: AST.Block -> CoroutineMonad Terminator -> CoroutineMonad BasicBlock
getBlock b next =
  MonadLib.collect (foldr extractLocals next b) >>= \ (terminator, emitted) ->
    return (BasicBlock (D.toList emitted) terminator)

-- | Run a translation in a block of its own and return a terminator that
-- continues to it.  If the translation emitted nothing but comments, no new
-- label is created: the comments are emitted into the current block and the
-- translation's own terminator is returned directly.  Otherwise the block is
-- registered with 'makeLabel'' and the result is a non-suspending branch to
-- it.
makeLabel :: CoroutineMonad Terminator -> CoroutineMonad Terminator
makeLabel m = getBlock [] m >>= finish
  where
  finish block@(BasicBlock emitted term)
    | all commentOnly emitted = stmts emitted >> return term
    | otherwise = makeLabel' block >>= goto

  commentOnly s = case s of
    AST.Comment{} -> True
    _ -> False

-- | Register a block and return its label.  Blocks are prepended to the
-- state's list, and the label is the length of the list after insertion, so
-- labels count up from 1 in creation order.
makeLabel' :: BasicBlock -> CoroutineMonad Goto
makeLabel' block = MonadLib.sets register
  where
  register st = (length extended, st { labels = extended })
    where
    extended = block : labels st

-- | Terminator that continues at the given label without suspending.
goto :: Goto -> CoroutineMonad Terminator
goto label = return (BranchTo False label)

-- | Terminator that suspends the coroutine, arranging for the given label to
-- run when it is next resumed.
resumeAt :: Goto -> CoroutineMonad Terminator
resumeAt label = return (BranchTo True label)

-- | The expression selecting, in the continuation struct, the field named
-- after the given variable.
--
-- Only 'AST.VarName' variables are supported; any other kind of variable
-- raises a descriptive error when the field name is demanded.
contRef :: AST.Var -> CoroutineMonad AST.Expr
contRef var = MonadLib.asks getCont <*> pure varStr
  where
  -- Kept lazy, so the check only fires if the field name is actually used.
  varStr = case var of
    AST.VarName s -> s
    _ -> error "Ivory.Language.Coroutine: contRef: expected a named variable (AST.VarName)"

-- | Promote a local variable to a field of the continuation struct.
--
-- A field with the variable's name and the given type is added to the
-- struct, and a rewrite is registered so that each later use of the variable
-- first dereferences that field into a fresh temporary (@cont0@, @cont1@,
-- ...) and then refers to the temporary.  The result is the expression for
-- the field itself, for the caller to store the variable's value into.
--
-- Only 'AST.VarName' variables are supported; any other kind of variable
-- raises a descriptive error when the field name is demanded.
addLocal :: AST.Type -> AST.Var -> CoroutineMonad AST.Expr
addLocal ty var = do
  -- Kept lazy, so the check only fires if the field name is actually used.
  let varStr = case var of
        AST.VarName s -> s
        _ -> error "Ivory.Language.Coroutine: addLocal: expected a named variable (AST.VarName)"
  -- Record the new struct field (no resume handlers to add here).
  MonadLib.lift $ MonadLib.put (D.singleton $ AST.Typed ty varStr, mempty)
  cont <- contRef var
  var `rewriteTo` do
    -- Runs at each use site: pick a fresh temporary name and load the field.
    idx <- MonadLib.sets $ \state ->
                            (derefs state, state { derefs = derefs state + 1 })
    let var' = AST.VarName $ "cont" ++ show idx
    stmt $ AST.Deref ty var' cont
    return $ AST.ExpVar var'
  return cont

-- | Generate the code to turn a @yield@ pseudo-call to a suspend and resume.
--
-- The variable bound to the result of @yield@ gets a field in the
-- continuation struct (of the type the result reference points to), and uses
-- of the variable are rewritten to the address of that field.  The code
-- following the @yield@ is placed in a new labelled block; a resume handler
-- is registered for that label which copies the value passed on resumption
-- into the field; and the current block ends by suspending at that label.
--
-- The result type must be an 'AST.TyRef' and the variable an 'AST.VarName';
-- anything else raises a descriptive error when the offending part is
-- demanded.
addYield :: AST.Type -> AST.Var -> CoroutineMonad Terminator ->
            CoroutineMonad Terminator
addYield ty var rest = do
  -- Both kept lazy, so the checks only fire if the values are actually used.
  let derefTy = case ty of
        AST.TyRef t -> t
        _ -> error "Ivory.Language.Coroutine: addYield: yield result is not a reference type (AST.TyRef)"
      varStr = case var of
        AST.VarName s -> s
        _ -> error "Ivory.Language.Coroutine: addYield: expected a named variable (AST.VarName)"
  MonadLib.lift $ MonadLib.put
    (D.singleton $ AST.Typed derefTy varStr, mempty)
  cont <- contRef var
  var `rewriteTo` return cont
  -- Always a real label (never merged into the current block): this is where
  -- execution picks up on the next resume.
  after <- makeLabel' =<< getBlock [] rest
  let resume arg = [AST.RefCopy derefTy cont arg]
  MonadLib.lift $ MonadLib.put (mempty, Map.singleton after resume)
  resumeAt after

-- | Run a translation with the given terminator installed as the target of
-- any 'AST.Break' it encounters; the rest of the environment is unchanged.
setBreakLabel :: Terminator -> CoroutineMonad a -> CoroutineMonad a
setBreakLabel label m = MonadLib.ask >>= withBreakTarget
  where
  withBreakTarget params =
    MonadLib.local (params { getBreakLabel = label }) m

-- | Emit a single statement into the block currently being built.
stmt :: AST.Stmt -> CoroutineMonad ()
stmt = MonadLib.put . D.singleton

-- | Emit a sequence of statements, in order, into the block currently being
-- built.
stmts :: AST.Block -> CoroutineMonad ()
stmts block = MonadLib.put (D.fromList block)

-- | Register a rewrite for an 'AST.Var': from now on, each reference to the
-- variable is replaced by running the given monadic code (which may emit
-- statements) and using the expression it returns.  A later rewrite for the
-- same variable replaces an earlier one.  ('addLocal' relies on this so that
-- locals promoted to the global continuation are 'deref'ed before each use;
-- the 'deref' is repeated every time because an intervening @yield@ puts any
-- earlier 'deref' out of scope.)
rewriteTo :: AST.Var -- ^ Variable to replace
             -> CoroutineMonad AST.Expr -- ^ New value & statements
             -> CoroutineMonad ()
rewriteTo var repl = MonadLib.sets_ addRewrite
  where
  addRewrite st = st { rewrites = Map.insert var repl (rewrites st) }

-- | State monad transformer containing 'CoroutineMonad', plus a memoization
-- table recording which rewrites have already been evaluated for the current
-- statement.  The table maps each variable already encountered to the
-- expression that replaced it.  /(Since an Ivory statement cannot have side
-- effects - such as storing to the continuation or suspending with 'yield' -
-- we may safely dereference the continuation variable just once per
-- statement, regardless of how many subexpressions contain that variable.)/
type UpdateExpr a = MonadLib.StateT (Map.Map AST.Var AST.Expr) CoroutineMonad a

-- | Run memoized rewriting code (built from 'updateExpr' and friends) with a
-- fresh, empty memoization table, discarding the table afterwards.  /(Each
-- statement needs its own empty table, so use this on a single statement or
-- on a smaller piece of one; the latter may produce redundant 'deref's.)/
runUpdateExpr :: UpdateExpr a -> CoroutineMonad a
runUpdateExpr action = fst <$> MonadLib.runStateT Map.empty action

-- | Apply the registered variable rewrites throughout an 'AST.Expr',
-- consulting and extending the per-statement memoization table as it goes.
-- Expression forms that cannot contain a variable reference handled here are
-- returned unchanged.
updateExpr :: AST.Expr -> UpdateExpr AST.Expr
updateExpr expr = case expr of
  AST.ExpVar var -> do
    memo <- MonadLib.get
    case Map.lookup var memo of
      -- Already rewritten within this statement: reuse the result without
      -- running the rewrite action again.
      Just cached -> return cached
      -- First use within this statement: run the variable's rewrite action
      -- if it has one, and remember the outcome.  A variable with no rewrite
      -- was not declared in this local scope, so we'll just hope it was
      -- global or something and use it as-is.
      Nothing -> do
        replacement <- MonadLib.lift $ do
          pending <- rewrites <$> MonadLib.get
          Map.findWithDefault (return expr) var pending
        MonadLib.sets_ $ Map.insert var replacement
        return replacement
  AST.ExpLabel ty inner label -> do
    inner' <- updateExpr inner
    return $ AST.ExpLabel ty inner' label
  AST.ExpIndex arrTy arr ixTy ix -> do
    arr' <- updateExpr arr
    ix' <- updateExpr ix
    return $ AST.ExpIndex arrTy arr' ixTy ix'
  AST.ExpToIx inner bound -> do
    inner' <- updateExpr inner
    return $ AST.ExpToIx inner' bound
  AST.ExpSafeCast ty inner ->
    AST.ExpSafeCast ty <$> updateExpr inner
  AST.ExpOp op args ->
    AST.ExpOp op <$> traverse updateExpr args
  _ -> return expr

-- | The counterpart of 'updateExpr' for initializers: every expression found
-- anywhere inside the initializer is passed through 'updateExpr'.
updateInit :: AST.Init -> UpdateExpr AST.Init
updateInit initializer = case initializer of
  AST.InitZero -> return AST.InitZero
  -- A plain expression initializer: rewrite the expression.
  AST.InitExpr ty ex -> AST.InitExpr ty <$> updateExpr ex
  -- A struct initializer: rewrite the initializer of every field.
  AST.InitStruct fields -> AST.InitStruct <$> mapM updateField fields
  -- An array initializer: rewrite each element initializer in turn.
  AST.InitArray elems b -> flip AST.InitArray b <$> mapM updateInit elems
  where
  updateField (fieldName, fieldInit) = do
    fieldInit' <- updateInit fieldInit
    return (fieldName, fieldInit')

-- | The counterpart of 'updateExpr' for a typed expression: the expression
-- is rewritten and its type annotation is kept as it was.
updateTypedExpr :: AST.Typed AST.Expr -> UpdateExpr (AST.Typed AST.Expr)
updateTypedExpr (AST.Typed ty ex) = do
  ex' <- updateExpr ex
  return (AST.Typed ty ex')