-- | Experimental support for coroutines.
module Effectful.Coroutine
  ( Coroutine
  , Status(..)
  , runCoroutine
  , yield
  ) where

import Control.Concurrent.MVar
import Control.Exception
import Data.Function
import qualified Control.Concurrent as C

import Effectful.Internal.Env
import Effectful.Internal.Has
import Effectful.Internal.Monad

data Coroutine o i = forall es r. Coroutine
  { ()
crInput         :: MVar (i, Env es)
  , ()
crState         :: MVar (State o r)
  , Coroutine o i -> Int
crCallerEnvSize :: Int
  }

data Status es o i r
  = Done r
  | Yielded o (i -> Eff es (Status es o i r))

runCoroutine :: Eff (Coroutine o i : es) r -> Eff es (Status es o i r)
runCoroutine :: Eff (Coroutine o i : es) r -> Eff es (Status es o i r)
runCoroutine (Eff Env (Coroutine o i : es) -> IO r
m) = (Env es -> IO (Status es o i r)) -> Eff es (Status es o i r)
forall (es :: [*]) a. (Env es -> IO a) -> Eff es a
impureEff ((Env es -> IO (Status es o i r)) -> Eff es (Status es o i r))
-> (Env es -> IO (Status es o i r)) -> Eff es (Status es o i r)
forall a b. (a -> b) -> a -> b
$ \Env es
es -> do
  Int
size    <- Env es -> IO Int
forall (es :: [*]). Env es -> IO Int
sizeEnv Env es
es
  MVar (i, Env es)
mvInput <- IO (MVar (i, Env es))
forall a. IO (MVar a)
newEmptyMVar
  MVar (State o r)
mvState <- IO (MVar (State o r))
forall a. IO (MVar a)
newEmptyMVar
  ((forall a. IO a -> IO a) -> IO (Status es o i r))
-> IO (Status es o i r)
forall b. ((forall a. IO a -> IO a) -> IO b) -> IO b
mask (((forall a. IO a -> IO a) -> IO (Status es o i r))
 -> IO (Status es o i r))
-> ((forall a. IO a -> IO a) -> IO (Status es o i r))
-> IO (Status es o i r)
forall a b. (a -> b) -> a -> b
$ \forall a. IO a -> IO a
restore -> do
    -- Create a worker thread and continue execution there.
    ThreadId
tid <- IO () -> IO ThreadId
C.forkIO (IO () -> IO ThreadId) -> IO () -> IO ThreadId
forall a b. (a -> b) -> a -> b
$ do
      let cr :: Coroutine o i
cr = MVar (i, Env es) -> MVar (State o r) -> Int -> Coroutine o i
forall o i (es :: [*]) r.
MVar (i, Env es) -> MVar (State o r) -> Int -> Coroutine o i
Coroutine MVar (i, Env es)
mvInput MVar (State o r)
mvState Int
size
      Either SomeException r
er <- IO r -> IO (Either SomeException r)
forall e a. Exception e => IO a -> IO (Either e a)
try (IO r -> IO (Either SomeException r))
-> IO r -> IO (Either SomeException r)
forall a b. (a -> b) -> a -> b
$ IO r -> IO r
forall a. IO a -> IO a
restore (IO r -> IO r)
-> (Env (Coroutine o i : es) -> IO r)
-> Env (Coroutine o i : es)
-> IO r
forall b c a. (b -> c) -> (a -> b) -> a -> c
. Env (Coroutine o i : es) -> IO r
m (Env (Coroutine o i : es) -> IO r)
-> IO (Env (Coroutine o i : es)) -> IO r
forall (m :: * -> *) a b. Monad m => (a -> m b) -> m a -> m b
=<< Coroutine o i -> Env es -> IO (Env (Coroutine o i : es))
forall e (es :: [*]).
HasCallStack =>
e -> Env es -> IO (Env (e : es))
unsafeConsEnv Coroutine o i
cr Env es
es
      MVar (State o r) -> State o r -> IO Bool
forall a. MVar a -> a -> IO Bool
tryPutMVar MVar (State o r)
mvState ((SomeException -> State o r)
-> (r -> State o r) -> Either SomeException r -> State o r
forall a c b. (a -> c) -> (b -> c) -> Either a b -> c
either SomeException -> State o r
forall o r. SomeException -> State o r
Failure r -> State o r
forall r o. r -> State o r
Success Either SomeException r
er) IO Bool -> (Bool -> IO ()) -> IO ()
forall (m :: * -> *) a b. Monad m => m a -> (a -> m b) -> m b
>>= \case
        Bool
False -> [Char] -> IO ()
forall a. HasCallStack => [Char] -> a
error [Char]
"unexpected"
        Bool
True  -> () -> IO ()
forall (f :: * -> *) a. Applicative f => a -> f a
pure ()
    (forall a. IO a -> IO a)
-> Env es
-> Int
-> ThreadId
-> MVar (i, Env es)
-> MVar (State o r)
-> IO (Status es o i r)
forall (es :: [*]) i o r.
(forall a. IO a -> IO a)
-> Env es
-> Int
-> ThreadId
-> MVar (i, Env es)
-> MVar (State o r)
-> IO (Status es o i r)
waitForStatus forall a. IO a -> IO a
restore Env es
es Int
size ThreadId
tid MVar (i, Env es)
mvInput MVar (State o r)
mvState

yield :: Coroutine o i :> es => o -> Eff es i
yield :: o -> Eff es i
yield o
o = (Env es -> IO i) -> Eff es i
forall (es :: [*]) a. (Env es -> IO a) -> Eff es a
impureEff ((Env es -> IO i) -> Eff es i) -> (Env es -> IO i) -> Eff es i
forall a b. (a -> b) -> a -> b
$ \Env es
es -> ((forall a. IO a -> IO a) -> IO i) -> IO i
forall b. ((forall a. IO a -> IO a) -> IO b) -> IO b
mask (((forall a. IO a -> IO a) -> IO i) -> IO i)
-> ((forall a. IO a -> IO a) -> IO i) -> IO i
forall a b. (a -> b) -> a -> b
$ \forall a. IO a -> IO a
restore -> do
  Coroutine{Int
MVar (i, Env es)
MVar (State o r)
crCallerEnvSize :: Int
crState :: MVar (State o r)
crInput :: MVar (i, Env es)
crCallerEnvSize :: forall o i. Coroutine o i -> Int
crState :: ()
crInput :: ()
..} <- Env es -> IO (Coroutine o i)
forall e (es :: [*]). (HasCallStack, e :> es) => Env es -> IO e
getEnv Env es
es
  Int
size <- Env es -> IO Int
forall (es :: [*]). Env es -> IO Int
sizeEnv Env es
es
  -- Save local part of the environment as the caller will discard it.
  Env Any
localEs <- Int -> Env es -> IO (Env Any)
forall (es0 :: [*]) (es :: [*]).
HasCallStack =>
Int -> Env es0 -> IO (Env es)
takeLastEnv (Int
size Int -> Int -> Int
forall a. Num a => a -> a -> a
- Int
crCallerEnvSize) Env es
es
  -- Pass control to the caller.
  MVar (State o r) -> State o r -> IO Bool
forall a. MVar a -> a -> IO Bool
tryPutMVar MVar (State o r)
crState (o -> State o r
forall o r. o -> State o r
Yield o
o) IO Bool -> (Bool -> IO i) -> IO i
forall (m :: * -> *) a b. Monad m => m a -> (a -> m b) -> m b
>>= \case
    Bool
False -> [Char] -> IO i
forall a. HasCallStack => [Char] -> a
error [Char]
"unexpected"
    Bool
True  -> do
      (i
i, Env es
callerEs) <- IO (i, Env es) -> IO (i, Env es)
forall a. IO a -> IO a
restore (IO (i, Env es) -> IO (i, Env es))
-> IO (i, Env es) -> IO (i, Env es)
forall a b. (a -> b) -> a -> b
$ MVar (i, Env es) -> IO (i, Env es)
forall a. MVar a -> IO a
takeMVar MVar (i, Env es)
crInput
      -- The caller resumed, reconstruct the local environment. The environment
      -- needs to be replaced since the one we just got might be completely
      -- different to what we had before suspending the computation, e.g. if the
      -- computation was resumed in a different thread.
      Env es -> Env es -> IO ()
forall (es :: [*]). HasCallStack => Env es -> Env es -> IO ()
unsafeReplaceEnv Env es
es (Env es -> IO ()) -> IO (Env es) -> IO ()
forall (m :: * -> *) a b. Monad m => (a -> m b) -> m a -> m b
=<< Env es -> Env Any -> IO (Env es)
forall (es0 :: [*]) (es1 :: [*]) (es :: [*]).
HasCallStack =>
Env es0 -> Env es1 -> IO (Env es)
unsafeAppendEnv Env es
callerEs Env Any
localEs
      i -> IO i
forall (f :: * -> *) a. Applicative f => a -> f a
pure i
i

----------------------------------------
-- Internal

data State o r where
  Failure :: SomeException -> State o r
  Success :: r             -> State o r
  Yield   :: o             -> State o r

waitForStatus
  :: (forall a. IO a -> IO a)
  -> Env es
  -> Int
  -> C.ThreadId
  -> MVar (i, Env es)
  -> MVar (State o r)
  -> IO (Status es o i r)
waitForStatus :: (forall a. IO a -> IO a)
-> Env es
-> Int
-> ThreadId
-> MVar (i, Env es)
-> MVar (State o r)
-> IO (Status es o i r)
waitForStatus forall a. IO a -> IO a
restore0 Env es
es0 Int
size0 ThreadId
tid MVar (i, Env es)
mvInput MVar (State o r)
mvState = (IO (Status es o i r) -> IO (Status es o i r))
-> IO (Status es o i r)
forall a. (a -> a) -> a
fix ((IO (Status es o i r) -> IO (Status es o i r))
 -> IO (Status es o i r))
-> (IO (Status es o i r) -> IO (Status es o i r))
-> IO (Status es o i r)
forall a b. (a -> b) -> a -> b
$ \IO (Status es o i r)
loop -> do
  IO (State o r) -> IO (Either SomeException (State o r))
forall e a. Exception e => IO a -> IO (Either e a)
try @SomeException (IO (State o r) -> IO (State o r)
forall a. IO a -> IO a
restore0 (IO (State o r) -> IO (State o r))
-> IO (State o r) -> IO (State o r)
forall a b. (a -> b) -> a -> b
$ MVar (State o r) -> IO (State o r)
forall a. MVar a -> IO a
takeMVar MVar (State o r)
mvState) IO (Either SomeException (State o r))
-> (Either SomeException (State o r) -> IO (Status es o i r))
-> IO (Status es o i r)
forall (m :: * -> *) a b. Monad m => m a -> (a -> m b) -> m b
>>= \case
    Left SomeException
e            -> ThreadId -> SomeException -> IO ()
forall e. Exception e => ThreadId -> e -> IO ()
throwTo ThreadId
tid SomeException
e IO () -> IO (Status es o i r) -> IO (Status es o i r)
forall (m :: * -> *) a b. Monad m => m a -> m b -> m b
>> IO (Status es o i r)
loop
    Right (Failure SomeException
e) -> SomeException -> IO (Status es o i r)
forall e a. Exception e => e -> IO a
throwIO SomeException
e
    Right (Success r
r) -> r -> Status es o i r
forall (es :: [*]) o i r. r -> Status es o i r
Done r
r      Status es o i r -> IO (Env Any) -> IO (Status es o i r)
forall (f :: * -> *) a b. Functor f => a -> f b -> f a
<$ Int -> Env es -> IO (Env Any)
forall (es0 :: [*]) (es :: [*]).
HasCallStack =>
Int -> Env es0 -> IO (Env es)
unsafeTrimEnv Int
size0 Env es
es0
    Right (Yield o
o)   -> o -> (i -> Eff es (Status es o i r)) -> Status es o i r
forall (es :: [*]) o i r.
o -> (i -> Eff es (Status es o i r)) -> Status es o i r
Yielded o
o i -> Eff es (Status es o i r)
k Status es o i r -> IO (Env Any) -> IO (Status es o i r)
forall (f :: * -> *) a b. Functor f => a -> f b -> f a
<$ Int -> Env es -> IO (Env Any)
forall (es0 :: [*]) (es :: [*]).
HasCallStack =>
Int -> Env es0 -> IO (Env es)
unsafeTrimEnv Int
size0 Env es
es0
  where
    k :: i -> Eff es (Status es o i r)
k i
i = (Env es -> IO (Status es o i r)) -> Eff es (Status es o i r)
forall (es :: [*]) a. (Env es -> IO a) -> Eff es a
impureEff ((Env es -> IO (Status es o i r)) -> Eff es (Status es o i r))
-> (Env es -> IO (Status es o i r)) -> Eff es (Status es o i r)
forall a b. (a -> b) -> a -> b
$ \Env es
es -> ((forall a. IO a -> IO a) -> IO (Status es o i r))
-> IO (Status es o i r)
forall b. ((forall a. IO a -> IO a) -> IO b) -> IO b
mask (((forall a. IO a -> IO a) -> IO (Status es o i r))
 -> IO (Status es o i r))
-> ((forall a. IO a -> IO a) -> IO (Status es o i r))
-> IO (Status es o i r)
forall a b. (a -> b) -> a -> b
$ \forall a. IO a -> IO a
restore -> do
      Int
size <- Env es -> IO Int
forall (es :: [*]). Env es -> IO Int
sizeEnv Env es
es
      -- Resume suspended computation with the current environment.
      MVar (i, Env es) -> (i, Env es) -> IO Bool
forall a. MVar a -> a -> IO Bool
tryPutMVar MVar (i, Env es)
mvInput (i
i, Env es
es) IO Bool -> (Bool -> IO (Status es o i r)) -> IO (Status es o i r)
forall (m :: * -> *) a b. Monad m => m a -> (a -> m b) -> m b
>>= \case
        Bool
False -> [Char] -> IO (Status es o i r)
forall a. HasCallStack => [Char] -> a
error [Char]
"unexpected"
        Bool
True  -> (forall a. IO a -> IO a)
-> Env es
-> Int
-> ThreadId
-> MVar (i, Env es)
-> MVar (State o r)
-> IO (Status es o i r)
forall (es :: [*]) i o r.
(forall a. IO a -> IO a)
-> Env es
-> Int
-> ThreadId
-> MVar (i, Env es)
-> MVar (State o r)
-> IO (Status es o i r)
waitForStatus forall a. IO a -> IO a
restore Env es
es Int
size ThreadId
tid MVar (i, Env es)
mvInput MVar (State o r)
mvState