-- | Coroutines layered over a base monad: the 'Coroutine' type and its
-- instances, functions that transform coroutines, functions that run them,
-- and combinators ('couple', 'merge', 'seesaw') that run several coroutines
-- together.
module Control.Monad.Coroutine
(
-- The coroutine type and its basic constructor
Coroutine(Coroutine, resume), CoroutineStepResult, suspend,
-- Transformations of the base monad and of the suspension functor
mapMonad, mapSuspension, mapFirstSuspension,
-- Running coroutines
Naught, runCoroutine, bounce, pogoStick, foldRun, seesaw, SeesawResolver(..), seesawSteps,
-- Binders for pairs of computations and functor combinators
PairBinder, sequentialBinder, parallelBinder, liftBinder, SomeFunctor(..), composePair,
-- Combining coroutines
couple, merge
)
where
import Control.Applicative (Applicative(..), (<$>), liftA2)
import Control.Monad (Monad(..), ap, liftM)
import Control.Monad.IO.Class (MonadIO(..))
import Control.Monad.Trans.Class (MonadTrans(..))
import Data.Either (partitionEithers)
import Data.Functor.Compose (Compose(..))
import Control.Monad.Parallel (MonadParallel(..))
-- | The outcome of a single coroutine step: either a suspension of functor
-- type @s@ that wraps the remainder of the coroutine ('Left'), or the final
-- result of type @r@ ('Right').
type CoroutineStepResult s m r = Either (s (Coroutine s m r)) r

-- | A computation over the base monad @m@ that may suspend itself, any number
-- of times, using the functor @s@ before it delivers a result of type @r@.
newtype Coroutine s m r = Coroutine
  { -- | The action in the base monad that performs one step and reports
    -- whether the coroutine suspended or finished.
    resume :: m (CoroutineStepResult s m r)
  }
-- | Maps a function over the final result of a coroutine.
instance (Functor s, Functor m) => Functor (Coroutine s m) where
  fmap f cort = Coroutine (fmap (either suspended finished) (resume cort))
    where
      -- The coroutine is done: transform its result.
      finished x = Right (f x)
      -- The coroutine is suspended: push the mapping into the coroutine held
      -- by the suspension so it applies once that coroutine finishes.
      suspended s = Left (fmap (fmap f) s)
-- | The applicative structure is derived from the 'Monad' instance: 'pure' is
-- 'return', and '<*>' runs the function coroutine first, then the argument
-- coroutine, and applies one result to the other.
instance (Functor s, Functor m, Monad m) => Applicative (Coroutine s m) where
  pure x = return x
  cf <*> cx = cf >>= \f -> cx >>= \x -> return (f x)
-- | Sequencing of coroutines. Binding runs the first coroutine one step in the
-- base monad; if it finished, the continuation is resumed at once, and if it
-- suspended, the bind is pushed inside the suspension so it applies to the
-- remainder of the coroutine.
instance (Functor s, Monad m) => Monad (Coroutine s m) where
  -- A coroutine that finishes immediately with the given value.
  return x = Coroutine (return (Right x))

  t >>= f = Coroutine (resume t >>= either suspended (resume . f))
    where
      suspended s = return (Left (fmap (>>= f) s))

  -- Same as '>>=' except that the result of the first coroutine is ignored.
  t >> next = Coroutine (resume t >>= either suspended (const (resume next)))
    where
      suspended s = return (Left (fmap (>> next) s))
-- | Two coroutines are bound in parallel by lifting the base monad's 'bindM2'
-- with 'liftBinder'.
instance (Functor s, MonadParallel m) => MonadParallel (Coroutine s m) where
  bindM2 f t1 t2 = liftBinder bindM2 f t1 t2
-- | Lifts an action of the base monad into a coroutine that runs the action
-- and finishes with its result without suspending.
instance Functor s => MonadTrans (Coroutine s) where
  lift action = Coroutine (liftM Right action)
-- | An 'IO' action is first lifted into the base monad and then into the
-- coroutine with 'lift'.
instance (Functor s, MonadIO m) => MonadIO (Coroutine s m) where
  liftIO io = lift (liftIO io)
-- | A functor with no constructors. A coroutine whose suspension functor is
-- 'Naught' has no way to build a suspension, which is what lets
-- 'runCoroutine' run it straight to its result.
data Naught x

-- | There is no value of type @Naught x@ to map over, so 'fmap' can only be
-- reached with a bottom argument. It fails with a message that names the
-- cause instead of a bare 'undefined'.
instance Functor Naught where
  fmap _ _ = error "Control.Monad.Coroutine: fmap applied to a Naught value, which cannot exist"
-- | Combines two functors into one that holds a value of the left functor,
-- a value of the right functor, or both of them composed.
data SomeFunctor l r x
  = LeftSome (l x)
  | RightSome (r x)
  | Both (Compose l r x)

-- | Mapping is delegated to whichever functor the value carries.
instance (Functor l, Functor r) => Functor (SomeFunctor l r) where
  fmap f some = case some of
    LeftSome l -> LeftSome (fmap f l)
    RightSome r -> RightSome (fmap f r)
    Both lr -> Both (fmap f lr)
-- | Pairs every value in the first functor with every value in the second,
-- nesting the second functor inside the first as a 'Compose'.
composePair :: (Functor a, Functor b) => a x -> b y -> Compose a b (x, y)
composePair outer inner = Compose (fmap pairWithInner outer)
  where
    -- For one value of the outer functor, tag each inner value with it.
    pairWithInner x = fmap (\y -> (x, y)) inner
-- | Builds a coroutine whose first step is the given suspension; the
-- suspension carries the coroutine to continue with.
suspend :: (Monad m, Functor s) => s (Coroutine s m x) -> Coroutine s m x
suspend = Coroutine . return . Left
-- | Changes the base monad of a coroutine. The given transformation is applied
-- to the current step and, through every suspension, to all later steps.
mapMonad :: forall s m m' x. (Functor s, Monad m, Monad m') =>
            (forall y. m y -> m' y) -> Coroutine s m x -> Coroutine s m' x
mapMonad f cort = Coroutine (liftM (either suspended Right) (f (resume cort)))
  where
    -- Convert the coroutine held by a suspension as well.
    suspended s = Left (fmap (mapMonad f) s)
-- | Changes the suspension functor of a coroutine. Every suspension the
-- coroutine produces is converted with the given function, after the
-- coroutine it holds has been converted in the same way.
mapSuspension :: (Functor s, Monad m) => (forall y. s y -> s' y) -> Coroutine s m x -> Coroutine s' m x
mapSuspension f cort = Coroutine (liftM (either suspended Right) (resume cort))
  where
    suspended s = Left (f (fmap (mapSuspension f) s))
-- | Applies the given function to the next suspension of the coroutine only.
-- The coroutine held by that suspension, and a coroutine that finishes
-- without suspending, are left as they are.
mapFirstSuspension :: forall s s' m x. (Functor s, Monad m) =>
                      (forall y. s y -> s y) -> Coroutine s m x -> Coroutine s m x
mapFirstSuspension f cort = Coroutine (liftM (either (Left . f) Right) (resume cort))
-- | Runs a coroutine whose suspension functor is 'Naught' to its result in
-- the base monad. The handler given to 'pogoStick' is an 'error' call, so it
-- fails if a suspension is ever encountered.
runCoroutine :: Monad m => Coroutine Naught m x -> m x
runCoroutine c = pogoStick (error "runCoroutine can run only a non-suspending coroutine!") c
-- | Runs one step of the coroutine. If that step is a suspension, the first
-- argument decides which coroutine to continue with; if the coroutine has
-- finished, its result is returned.
bounce :: (Monad m, Functor s) => (s (Coroutine s m x) -> Coroutine s m x) -> Coroutine s m x -> Coroutine s m x
bounce spring c = do
  step <- lift (resume c)
  case step of
    Left suspension -> spring suspension
    Right result -> return result
-- | Runs a coroutine to completion in the base monad. Each time it suspends,
-- the first argument turns the suspension into the coroutine to continue
-- with, and the run goes on from there.
pogoStick :: Monad m => (s (Coroutine s m x) -> Coroutine s m x) -> Coroutine s m x -> m x
pogoStick spring c = do
  step <- resume c
  case step of
    Left suspension -> pogoStick spring (spring suspension)
    Right result -> return result
-- | Runs a coroutine to completion while threading an accumulator through its
-- suspensions. At every suspension the first argument receives the current
-- accumulator and the suspension, and returns the new accumulator together
-- with the coroutine to continue with. The result pairs the final accumulator
-- with the coroutine's result.
foldRun :: Monad m => (a -> s (Coroutine s m x) -> (a, Coroutine s m x)) -> a -> Coroutine s m x -> m (a, x)
foldRun f a c = resume c >>= either continue finish
  where
    finish result = return (a, result)
    continue suspension =
      let (a', c') = f a suspension
      in foldRun f a' c'
-- | Type of functions that run two computations in the monad @m@ and feed
-- both of their results to a binary continuation. 'sequentialBinder' and
-- 'parallelBinder' are the two binders defined in this module.
type PairBinder m = forall x y r. (x -> y -> m r) -> m x -> m y -> m r
-- | A 'PairBinder' that runs the first computation, then the second, and then
-- the continuation on their results.
sequentialBinder :: Monad m => PairBinder m
sequentialBinder f mx my = mx >>= \x -> my >>= \y -> f x y
-- | A 'PairBinder' that delegates to 'bindM2' of the 'MonadParallel' class.
parallelBinder :: MonadParallel m => PairBinder m
parallelBinder f mx my = bindM2 f mx my
-- | Lifts a 'PairBinder' of the base monad to a 'PairBinder' of coroutines.
-- The given binder runs one step of each coroutine; what happens next depends
-- on which of the two steps suspended.
liftBinder :: forall s m. (Functor s, Monad m) => PairBinder m -> PairBinder (Coroutine s m)
liftBinder binder f t1 t2 = Coroutine (binder step (resume t1) (resume t2))
  where
    step r1 r2 = case (r1, r2) of
      -- Both finished: continue with the combined computation.
      (Right x, Right y) -> resume (f x y)
      -- Only the first suspended: apply the continuation once it finishes.
      (Left s, Right y) -> return (Left (fmap (>>= \x -> f x y) s))
      -- Only the second suspended: likewise, with the first result in hand.
      (Right x, Left s) -> return (Left (fmap (>>= f x) s))
      -- Both suspended: surface the second suspension, and bind the first
      -- suspension again with whatever the second one resumes to.
      (Left s1, Left s2) -> return (Left (fmap (liftBinder binder f (suspend s1)) s2))
-- | Runs two coroutines side by side as a single coroutine that returns both
-- results. The given binder runs one step of each; the combined coroutine
-- then suspends with a 'SomeFunctor' recording which of the two suspended,
-- and finishes only when both have finished.
couple :: forall s1 s2 m x y. (Monad m, Functor s1, Functor s2) =>
          PairBinder m -> Coroutine s1 m x -> Coroutine s2 m y -> Coroutine (SomeFunctor s1 s2) m (x, y)
couple runPair t1 t2 = Coroutine (runPair step (resume t1) (resume t2))
  where
    step :: CoroutineStepResult s1 m x -> CoroutineStepResult s2 m y
         -> m (CoroutineStepResult (SomeFunctor s1 s2) m (x, y))
    step r1 r2 = case (r1, r2) of
      (Right x, Right y) -> return (Right (x, y))
      -- Both suspended: compose the suspensions and couple the two remainders.
      (Left left, Left right) ->
        return (Left (fmap (\(l, r) -> couple runPair l r) (Both (composePair left right))))
      -- Only the second suspended: couple its remainder with the finished
      -- first result.
      (Right x, Left right) ->
        return (Left (fmap (\r -> couple runPair (return x) r) (RightSome right)))
      -- Only the first suspended: the mirror image of the case above.
      (Left left, Right y) ->
        return (Left (fmap (\l -> couple runPair l (return y)) (LeftSome left)))
-- | Combines a list of coroutines that share a suspension functor into one
-- coroutine that returns the list of their results. The first argument runs
-- one step of every coroutine in the base monad; the second collapses the
-- suspensions of those that suspended into a single suspension. The combined
-- coroutine finishes once a round ends with no suspensions.
--
-- Results of coroutines that have already finished are placed ahead of the
-- coroutines that are still running, so the result list need not follow the
-- order of the input list.
merge :: forall s m x. (Monad m, Functor s) =>
         (forall y. [m y] -> m [y]) -> (forall y. [s y] -> s [y])
      -> [Coroutine s m x] -> Coroutine s m [x]
merge sequence1 sequence2 corts = Coroutine (liftM step (sequence1 (map resume corts)))
  where
    step :: [CoroutineStepResult s m x] -> CoroutineStepResult s m [x]
    step results
      | null suspended = Right finished
      | otherwise = Left (fmap continue (sequence2 suspended))
      where
        (suspended, finished) = partitionEithers results
        -- Finished results re-enter the next round as completed coroutines.
        continue resumed = merge sequence1 sequence2 (map return finished ++ resumed)
-- | A set of functions that tell 'seesaw' how to turn the suspensions of two
-- coroutines (of functor types @s1@ and @s2@) back into running coroutines.
data SeesawResolver s1 s2 s1' s2' = SeesawResolver {
-- | Turns a suspension of the first coroutine into the coroutine to continue
-- with. 'seesaw' uses it once the second coroutine has finished.
resumeLeft :: forall m t. (Monad m) => s1 (Coroutine s1' m t) -> Coroutine s1' m t,
-- | Turns a suspension of the second coroutine into the coroutine to continue
-- with. 'seesaw' uses it once the first coroutine has finished.
resumeRight :: forall m t. (Monad m) => s2 (Coroutine s2' m t) -> Coroutine s2' m t,
-- | Called by 'seesaw' when both coroutines are suspended. It receives a
-- continuation that takes the two coroutines to proceed with, followed by the
-- two suspensions, and produces the continuation's result type.
resumeBoth :: forall m t1 t2 r. (Monad m) =>
(Coroutine s1' m t1 -> Coroutine s2' m t2 -> r)
-> s1 (Coroutine s1' m t1)
-> s2 (Coroutine s2' m t2)
-> r
}
-- | Runs two coroutines together to completion and returns both results. The
-- binder runs one step of each coroutine per round, and the resolver decides
-- how to continue from the suspensions. Once one coroutine has finished, the
-- other is run to its end on its own with 'pogoStick'.
seesaw :: (Monad m, Functor s1, Functor s2) =>
          PairBinder m -> SeesawResolver s1 s2 s1 s2 -> Coroutine s1 m x -> Coroutine s2 m y -> m (x, y)
seesaw runPair resolver t1 t2 = seesawSteps runPair step t1 t2
  where
    step cont r1 r2 = case (r1, r2) of
      -- Both suspended: let the resolver pick the next pair and go on.
      (Left s1, Left s2) -> resumeBoth resolver cont s1 s2
      -- The first is done: drive the second alone to its result.
      (Right x, Left s2) ->
        liftM (\y -> (x, y)) (pogoStick (resumeRight resolver) (resumeRight resolver s2))
      -- The second is done: drive the first alone to its result.
      (Left s1, Right y) ->
        liftM (\x -> (x, y)) (pogoStick (resumeLeft resolver) (resumeLeft resolver s1))
      (Right x, Right y) -> return (x, y)
-- | The stepping loop behind 'seesaw'. Each round runs one step of both
-- coroutines with the given binder and hands the two step results to the
-- second argument, together with a continuation that starts the next round
-- on a pair of coroutines.
seesawSteps :: (Monad m, Functor s1, Functor s2) =>
               PairBinder m
            -> ((Coroutine s1 m x -> Coroutine s2 m y -> m (x, y))
                -> CoroutineStepResult s1 m x -> CoroutineStepResult s2 m y -> m (x, y))
            -> Coroutine s1 m x -> Coroutine s2 m y -> m (x, y)
seesawSteps runPair proceed =
  let go left right = runPair (proceed go) (resume left) (resume right)
  in go