module Control.Monad.Coroutine
(
Coroutine(Coroutine),
resume, suspend,
mapMonad, mapSuspension,
Naught, runCoroutine, pogoStick, foldRun, seesaw, SeesawResolver(..),
NestedFunctor (NestedFunctor), SomeFunctor(..), nest,
couple, merge
)
where
import Control.Applicative (Applicative(..))
import Control.Monad (ap, liftM, when)
import Control.Monad.Trans (MonadTrans(..), MonadIO(..))
import Data.Either (either, partitionEithers)
import Data.Traversable (Traversable, sequence)

import Control.Monad.Parallel
-- | A computation in the base monad @m@ that may suspend itself, handing control out through the
-- suspension functor @s@, and that eventually produces a result of type @r@.
newtype Coroutine s m r = Coroutine
   { resume :: m (Either (s (Coroutine s m r)) r)
     -- ^ Performs one step of the coroutine in the base monad. The outcome is either 'Left' with a
     -- suspension whose functor holds the rest of the coroutine, or 'Right' with the final result.
   }

-- | Shorthand for the value produced by a single 'resume' step.
type CoroutineStepResult s m r = Either (s (Coroutine s m r)) r
-- | Mapping over a coroutine transforms only its eventual result; every suspension is kept in place.
instance (Functor s, Monad m) => Functor (Coroutine s m) where
   fmap = liftM

-- | Required superclass of 'Monad' since the Applicative-Monad Proposal (GHC 7.10 / base 4.8).
-- 'pure' finishes immediately without suspending; '<*>' sequences the two coroutines as '>>=' does.
instance (Functor s, Monad m) => Applicative (Coroutine s m) where
   pure x = Coroutine (return (Right x))
   (<*>) = ap

-- | Sequencing resumes the first coroutine. If it finishes, its result is fed to the continuation,
-- which is resumed in turn. If it suspends, the combined coroutine suspends too, with the
-- continuation bound onto every coroutine held inside the suspension functor.
instance (Functor s, Monad m) => Monad (Coroutine s m) where
   return = pure
   t >>= f = Coroutine (resume t >>= apply)
      where apply (Right x) = resume (f x)
            -- push the rest of the computation inside the suspension, so it runs after resumption
            apply (Left s) = return (Left (fmap (>>= f) s))
-- | Resumes both coroutines with the base monad's own 'bindM2'. When both finish, @f@ is applied to
-- the two results. When only one suspends, the combined coroutine suspends with that suspension,
-- and the other side's result is kept for later. When both suspend, it suspends on the second
-- suspension and re-wraps the first one with 'suspend'.
instance (Functor s, MonadParallel m) => MonadParallel (Coroutine s m) where
   bindM2 f t1 t2 = Coroutine (bindM2 step (resume t1) (resume t2))
      where
         step (Right a) (Right b) = resume (f a b)
         step (Left sa) (Right b) = suspended (fmap (>>= \a -> f a b) sa)
         step (Right a) (Left sb) = suspended (fmap (>>= f a) sb)
         step (Left sa) (Left sb) = suspended (fmap (bindM2 f (suspend sa)) sb)
         suspended s = return (Left s)
-- | A lifted base-monad action becomes a coroutine that runs the action and finishes with its
-- result, never suspending.
instance Functor s => MonadTrans (Coroutine s) where
   lift action = Coroutine (action >>= return . Right)
-- | IO actions are lifted into the base monad first, then into the coroutine via 'lift'.
instance (Functor s, MonadIO m) => MonadIO (Coroutine s m) where
   liftIO io = lift (liftIO io)
-- | A functor with no constructors. A coroutine whose suspension functor is 'Naught' has no defined
-- value to suspend with, which is what 'runCoroutine' relies on.
data Naught x

-- | 'Naught' has no values, so 'fmap' can only ever be reached with a bottom argument. Fail with a
-- descriptive message instead of the anonymous "Prelude.undefined". The argument is deliberately
-- left unforced, so laziness matches the previous definition.
instance Functor Naught where
   fmap _ _ = error "Control.Monad.Coroutine: fmap applied to a Naught value, which cannot exist"
-- | Two functors stacked on top of each other: a value of type @l (r x)@ treated as one functor over @x@.
newtype NestedFunctor l r x = NestedFunctor (l (r x))

-- | Maps the function through both layers, the outer @l@ and then the inner @r@.
instance (Functor l, Functor r) => Functor (NestedFunctor l r) where
   fmap f (NestedFunctor outer) = NestedFunctor (fmap (fmap f) outer)
-- | Either functor @l@, functor @r@, or both nested together. 'couple' uses it to report which of
-- its two coroutines suspended: only the left ('LeftSome'), only the right ('RightSome') or both
-- ('Both').
data SomeFunctor l r x
   = LeftSome (l x)
   | RightSome (r x)
   | Both (NestedFunctor l r x)

-- | Maps over whichever alternative is present.
instance (Functor l, Functor r) => Functor (SomeFunctor l r) where
   fmap f functor = case functor of
      LeftSome l -> LeftSome (fmap f l)
      RightSome r -> RightSome (fmap f r)
      Both lr -> Both (fmap f lr)
-- | Places the second functor inside every position of the first, pairing each value of the first
-- with each value of the second.
nest :: (Functor a, Functor b) => a x -> b y -> NestedFunctor a b (x, y)
nest outer inner = NestedFunctor (fmap pairWithInner outer)
   where pairWithInner x = fmap (\y -> (x, y)) inner
-- | A coroutine whose first step immediately suspends with the given suspension functor; the only
-- base-monad action it performs is 'return'.
suspend :: (Monad m, Functor s) => s (Coroutine s m x) -> Coroutine s m x
suspend = Coroutine . return . Left
-- | Moves a coroutine to a different base monad. The polymorphic function @f@ is applied to every
-- step, including the steps reached through the continuations held in each suspension.
mapMonad :: forall s m m' x. (Functor s, Monad m, Monad m') =>
            (forall x. m x -> m' x) -> Coroutine s m x -> Coroutine s m' x
mapMonad f cort = Coroutine (liftM (either (Left . fmap (mapMonad f)) Right) (f (resume cort)))
-- | Moves a coroutine to a different suspension functor. Each suspension first has its
-- continuations converted recursively, then is translated by @f@.
mapSuspension :: forall s s' m x. (Functor s, Monad m) => (forall x. s x -> s' x) -> Coroutine s m x -> Coroutine s' m x
mapSuspension f cort = Coroutine (liftM (either (Left . f . fmap (mapSuspension f)) Right) (resume cort))
-- | Runs a coroutine whose suspension functor is 'Naught' to completion in the base monad. If it
-- suspends anyway, the run fails with an error.
runCoroutine :: Monad m => Coroutine Naught m x -> m x
runCoroutine = pogoStick impossible
   where impossible _ = error "runCoroutine can run only a non-suspending coroutine!"
-- | Runs a coroutine to completion. Every time it suspends, @reveal@ turns the suspension into the
-- coroutine to continue with.
pogoStick :: Monad m => (s (Coroutine s m x) -> Coroutine s m x) -> Coroutine s m x -> m x
pogoStick reveal t = do
   step <- resume t
   case step of
      Right result -> return result
      Left c -> pogoStick reveal (reveal c)
-- | Like 'pogoStick', but threads an accumulator through the run. At each suspension @f@ receives the
-- current accumulator and the suspension, and returns the new accumulator together with the
-- coroutine to continue with. The final accumulator is returned paired with the coroutine's result.
foldRun :: Monad m => (a -> s (Coroutine s m x) -> (a, Coroutine s m x)) -> a -> Coroutine s m x -> m (a, x)
foldRun f a t = do
   step <- resume t
   case step of
      Right result -> return (a, result)
      -- a lazy let-pattern, matching the laziness of uncurry
      Left c -> let (a', next) = f a c in foldRun f a' next
-- | Runs two coroutines side by side. @runPair@ combines their individual steps in the base monad.
-- The combined coroutine suspends with 'LeftSome', 'RightSome' or 'Both', depending on which sides
-- suspended; a side that has already finished is carried along as @return@ of its result. Once both
-- have finished, it finishes with the pair of results.
couple :: forall s1 s2 m x y r. (Monad m, Functor s1, Functor s2) =>
          (forall x y r. (x -> y -> m r) -> m x -> m y -> m r)
       -> Coroutine s1 m x -> Coroutine s2 m y -> Coroutine (SomeFunctor s1 s2) m (x, y)
couple runPair t1 t2 = Coroutine{resume= runPair advance (resume t1) (resume t2)}
   where
      advance :: CoroutineStepResult s1 m x -> CoroutineStepResult s2 m y
              -> m (CoroutineStepResult (SomeFunctor s1 s2) m (x, y))
      advance (Right x) (Right y) = return (Right (x, y))
      advance (Right x) (Left s2) =
         return (Left (fmap (\t2' -> couple runPair (return x) t2') (RightSome s2)))
      advance (Left s1) (Right y) =
         return (Left (fmap (\t1' -> couple runPair t1' (return y)) (LeftSome s1)))
      advance (Left s1) (Left s2) =
         return (Left (fmap (uncurry (couple runPair)) (Both (nest s1 s2))))
-- | Weaves a list of coroutines with the same suspension functor into a single coroutine whose
-- result is the list of all their results.
--
-- Each step resumes every coroutine with @sequence1@. While at least one is still suspended, the
-- suspensions are combined into one with @sequence2@. Coroutines that have already finished are
-- carried along as @return@ of their result. Each result is kept at the position of the coroutine
-- that produced it (assuming @sequence1@ and @sequence2@ preserve the length and order of their
-- input lists). Previously, finished results were moved to the front.
merge :: forall s m x. (Monad m, Functor s) =>
         (forall x. [m x] -> m [x]) -> (forall x. [s x] -> s [x])
      -> [Coroutine s m x] -> Coroutine s m [x]
merge sequence1 sequence2 corts = Coroutine{resume= liftM step $ sequence1 (map resume corts)}
   where
      step :: [CoroutineStepResult s m x] -> CoroutineStepResult s m [x]
      step list = case partitionEithers list of
         ([], ends) -> Right ends
         (suspensions, _) -> Left $ fmap (merge sequence1 sequence2 . refill list) $ sequence2 suspensions
      -- Rebuilds the coroutine list in its original order. A finished entry becomes @return@ of its
      -- result; each suspended entry takes the next resumed coroutine. If the resumed list has a
      -- different length, nothing is lost: extra coroutines are appended, and missing ones are skipped.
      refill :: [CoroutineStepResult s m x] -> [Coroutine s m x] -> [Coroutine s m x]
      refill (Right r : rest) resumed = return r : refill rest resumed
      refill (Left _ : rest) (c : resumed) = c : refill rest resumed
      refill (Left _ : rest) [] = refill rest []
      refill [] resumed = resumed
-- | Policy used by 'seesaw' to choose how suspended coroutines are resumed.
data SeesawResolver s1 s2 = SeesawResolver
   { resumeLeft :: forall t. s1 t -> t
     -- ^ Extracts the continuation of the left coroutine when it is suspended and the right one
     -- has finished.
   , resumeRight :: forall t. s2 t -> t
     -- ^ Extracts the continuation of the right coroutine when it is suspended and the left one
     -- has finished.
   , resumeAny :: forall t1 t2 r. (t1 -> r) -> (t2 -> r) -> (t1 -> t2 -> r) -> s1 t1 -> s2 t2 -> r
     -- ^ Decides how to proceed when both coroutines are suspended. It receives three continuations:
     -- resume only the left one, resume only the right one, or resume both.
   }
-- | Runs two coroutines against each other until both finish, returning the pair of results.
-- @runPair@ combines their steps in the base monad. Whenever either side is suspended, the
-- 'SeesawResolver' decides which continuation to resume next. A finished side is carried along as
-- @return@ of its result, and a side that is not resumed is re-wrapped with 'suspend'.
seesaw :: (Monad m, Functor s1, Functor s2) =>
          (forall x y r. (x -> y -> m r) -> m x -> m y -> m r)
       -> SeesawResolver s1 s2
       -> Coroutine s1 m x -> Coroutine s2 m y -> m (x, y)
seesaw runPair resolver t1 t2 = loop t1 t2
   where
      loop left right = runPair step (resume left) (resume right)
      step (Right x) (Right y) = return (x, y)
      step (Left s1) (Right y) = loop (resumeLeft resolver s1) (return y)
      step (Right x) (Left s2) = loop (return x) (resumeRight resolver s2)
      step (Left s1) (Left s2) = resumeAny resolver
                                           (\left -> loop left (suspend s2))
                                           (\right -> loop (suspend s1) right)
                                           loop
                                           s1 s2