module Control.Concurrent.Coroutine
(
Coroutine,
suspend,
ParallelizableMonad(..), AncestorFunctor,
runCoroutine, pogoStick, pogoStickNested, seesaw, seesawNested, SeesawResolver(..),
Yield(Yield), Await(Await), Naught,
yield, await,
nest, couple, coupleNested,
local, out, liftOut,
EitherFunctor(LeftF, RightF), NestedFunctor (NestedFunctor), SomeFunctor(..)
)
where
import Control.Applicative (Applicative(..))
import Control.Concurrent (forkIO)
import Control.Concurrent.MVar (newEmptyMVar, putMVar, takeMVar)
import Control.Exception (SomeException, throwIO, try)
import Control.Monad (liftM, liftM2, when)
import Control.Monad.Identity
import Control.Monad.Trans (MonadTrans(..), MonadIO(..))
import Control.Parallel (par, pseq)
-- | Monads in which two independent computations can be combined by a
-- single binary bind. Instances may override 'bindM2' to evaluate the two
-- arguments in parallel.
class Monad m => ParallelizableMonad m where
  -- | Runs both computations and passes their results to the continuation.
  -- The default implementation is purely sequential: the first computation
  -- is run to completion before the second one starts.
  bindM2 :: (a -> b -> m c) -> m a -> m b -> m c
  bindM2 f ma mb = ma >>= \a -> mb >>= \b -> f a b
-- | In the 'Identity' monad both arguments are pure values, so the first is
-- sparked with 'par' while the second is evaluated with 'pseq'.
instance ParallelizableMonad Identity where
  bindM2 f ma mb = x `par` (y `pseq` (x `pseq` f x y))
    where
      x = runIdentity ma
      y = runIdentity mb
-- | For 'Maybe' the first argument is sparked with 'par' while the second
-- is evaluated with 'pseq'; the continuation runs only if both are 'Just'.
instance ParallelizableMonad Maybe where
  bindM2 f ma mb = combine (ma `par` (mb `pseq` (ma, mb)))
    where
      combine (Just a, Just b) = f a b
      combine _ = Nothing
-- | In 'IO' each argument action is run on its own thread created by
-- 'forkIO'; the calling thread waits for both results (left first) and then
-- runs the continuation.
--
-- An exception raised by either action is captured on its worker thread and
-- rethrown in the calling thread. Without this the worker would die without
-- ever filling its 'MVar', leaving the caller blocked on 'takeMVar'. If both
-- actions fail, the exception of the left action is the one rethrown.
instance ParallelizableMonad IO where
  bindM2 f ma mb = do
      va <- spawn ma
      vb <- spawn mb
      a <- collect va
      b <- collect vb
      f a b
    where
      -- Fork the action and hand back the variable that will eventually hold
      -- either its result or the exception that terminated it.
      spawn action = do
        var <- newEmptyMVar
        _ <- forkIO (try action >>= putMVar var)
        return var
      -- Wait for a worker to finish, re-raising its exception if it failed.
      collect var = takeMVar var >>= either rethrow return
      -- Pins the exception type caught by 'try' to 'SomeException'.
      rethrow :: SomeException -> IO x
      rethrow = throwIO
-- | A computation in the base monad @m@ that either finishes with a result
-- of type @r@ or suspends, wrapping its continuation in the functor @s@.
newtype Coroutine s m r = Coroutine
  { resume :: m (CoroutineState s m r)
    -- ^ Runs the coroutine in the base monad up to its next state: either
    -- the final result or a suspension.
  }
-- | The state a 'Coroutine' reaches once it has been resumed.
data CoroutineState s m r
  = -- | The coroutine has finished with the final result @r@.
    Done r
  | -- | The coroutine is suspended; the rest of the computation is held
    -- inside the suspension functor @s@. The field is strict; the bang is
    -- written in prefix position so that it is read as a strictness
    -- annotation under whitespace-sensitive operator parsing as well.
    Suspend !(s (Coroutine s m r))
-- | Maps a function over the eventual result of a coroutine, leaving all of
-- its suspensions in place. Provided because 'Functor' is a superclass of
-- 'Monad' (through 'Applicative') on current compilers.
instance (Functor s, Monad m) => Functor (Coroutine s m) where
  fmap f t = Coroutine (liftM mapState (resume t))
    where
      mapState (Done x) = Done (f x)
      mapState (Suspend s) = Suspend (fmap (fmap f) s)

-- | Sequential application: the coroutine producing the function runs to
-- completion before the coroutine producing its argument is started.
instance (Functor s, Monad m) => Applicative (Coroutine s m) where
  pure x = Coroutine (return (Done x))
  tf <*> tx = tf >>= \f -> fmap f tx

-- | Binding a coroutine runs it to its next state. A finished coroutine
-- hands its result to the continuation; a suspended one stays suspended,
-- with the continuation attached to whatever it resumes into.
instance (Functor s, Monad m) => Monad (Coroutine s m) where
  return = pure
  t >>= f = Coroutine (resume t >>= step)
    where
      step (Done x) = resume (f x)
      step (Suspend s) = return (Suspend (fmap (>>= f) s))
-- | Two coroutines are combined by resuming both with the base monad's
-- 'bindM2'. If only one of them suspends, the other one's result is kept
-- and the continuation is attached to the suspended side. If both suspend,
-- the combined coroutine suspends on the second one's functor while the
-- first is re-suspended unchanged.
instance (Functor s, ParallelizableMonad m) => ParallelizableMonad (Coroutine s m) where
  bindM2 f t1 t2 = Coroutine (bindM2 merge (resume t1) (resume t2))
    where
      merge (Done x) (Done y) = resume (f x y)
      merge (Suspend s) (Done y) = return (Suspend (fmap (>>= \x -> f x y) s))
      merge (Done x) (Suspend s) = return (Suspend (fmap (>>= f x) s))
      merge (Suspend s1) (Suspend s2) = return (Suspend (fmap (bindM2 f (suspend s1)) s2))
-- | A lifted base-monad action becomes a coroutine that never suspends and
-- finishes with the action's result.
instance Functor s => MonadTrans (Coroutine s) where
  lift action = Coroutine (liftM Done action)
-- | 'IO' actions are lifted into the base monad first and from there into
-- the coroutine.
instance (Functor s, MonadIO m) => MonadIO (Coroutine s m) where
  liftIO action = lift (liftIO action)
-- | Suspension functor that carries a value of type @x@ alongside the
-- continuation @y@.
data Yield x y = Yield x y

-- | Maps over the continuation; the carried value is left untouched.
instance Functor (Yield x) where
  fmap f yielded = case yielded of
    Yield x rest -> Yield x (f rest)
-- | Suspension functor whose continuation is a function waiting for a value
-- of type @x@. The field is strict; the bang is written in prefix position
-- so that it is read as a strictness annotation under whitespace-sensitive
-- operator parsing as well.
data Await x y = Await !(x -> y)

-- | Maps over the result of the awaiting function.
instance Functor (Await x) where
  fmap f (Await g) = Await (f . g)
-- | Suspension functor with no constructors: a coroutine over 'Naught' has
-- no way of building a suspension.
data Naught x

-- | There is no 'Naught' value to map over, so this can only ever be given
-- bottom. The argument is forced first so that its own failure is the one
-- reported; the explicit message is a fallback diagnostic.
instance Functor Naught where
  fmap _ naught =
    naught `seq` error "Control.Concurrent.Coroutine: fmap applied to a value of the uninhabited functor Naught"
-- | Sum of two functors: a value comes from either the left functor @l@ or
-- the right functor @r@.
data EitherFunctor l r x
  = LeftF (l x)
  | RightF (r x)

-- | Maps inside whichever functor is present.
instance (Functor l, Functor r) => Functor (EitherFunctor l r) where
  fmap f choice = case choice of
    LeftF l -> LeftF (fmap f l)
    RightF r -> RightF (fmap f r)
-- | Composition of two functors: an @r@ layer nested inside an @l@ layer.
newtype NestedFunctor l r x = NestedFunctor (l (r x))

-- | Maps through both layers.
instance (Functor l, Functor r) => Functor (NestedFunctor l r) where
  fmap f (NestedFunctor layers) = NestedFunctor (fmap (fmap f) layers)
-- | A value from the left functor, from the right functor, or from both of
-- them nested together.
data SomeFunctor l r x
  = LeftSome (l x)
  | RightSome (r x)
  | Both (NestedFunctor l r x)

-- | Maps inside whichever alternative is present.
instance (Functor l, Functor r) => Functor (SomeFunctor l r) where
  fmap f some = case some of
    LeftSome l -> LeftSome (fmap f l)
    RightSome r -> RightSome (fmap f r)
    Both lr -> Both (fmap f lr)
-- | Builds a coroutine that suspends immediately, keeping the given
-- functor-wrapped continuation as its resumption.
suspend :: (Monad m, Functor s) => s (Coroutine s m x) -> Coroutine s m x
suspend = Coroutine . return . Suspend
-- | Suspends the coroutine, handing the given value out through a 'Yield'
-- suspension. Once resumed, the coroutine continues with @()@.
yield :: forall m x. Monad m => x -> Coroutine (Yield x) m ()
yield x = suspend (Yield x continuation)
  where
    continuation = return ()
-- | Suspends the coroutine on an 'Await' suspension. The value passed to
-- the suspension's function becomes the result of the coroutine.
await :: forall m x. Monad m => Coroutine (Await x) m x
await = suspend (Await (\received -> return received))
-- | Runs a coroutine over the 'Naught' functor to completion in the base
-- monad. The handler passed to 'pogoStick' is an 'error' call, so it fails
-- if a suspension is ever encountered.
runCoroutine :: Monad m => Coroutine Naught m x -> m x
runCoroutine = pogoStick unreachable
  where
    unreachable = error "runCoroutine can run only a non-suspending coroutine!"
-- | Runs a coroutine to completion in the base monad. Every time the
-- coroutine suspends, the @reveal@ argument turns the suspension back into
-- a coroutine, which is then resumed in turn.
pogoStick :: (Functor s, Monad m) => (s (Coroutine s m x) -> Coroutine s m x) -> Coroutine s m x -> m x
pogoStick reveal = bounce
  where
    bounce t = resume t >>= land
    land (Done result) = return result
    land (Suspend c) = bounce (reveal c)
-- | Eliminates the inner ('RightF') suspensions of a coroutine. Each one is
-- turned back into a coroutine by @reveal@ and resumed at once. Outer
-- ('LeftF') suspensions are passed through, so the result is a coroutine
-- over the outer functor alone.
pogoStickNested :: (Functor s1, Functor s2, Monad m) =>
                   (s2 (Coroutine (EitherFunctor s1 s2) m x) -> Coroutine (EitherFunctor s1 s2) m x)
                   -> Coroutine (EitherFunctor s1 s2) m x -> Coroutine s1 m x
pogoStickNested reveal = go
  where
    go t = Coroutine (resume t >>= step)
    step (Done result) = return (Done result)
    step (Suspend (LeftF s)) = return (Suspend (fmap go s))
    step (Suspend (RightF c)) = resume (go (reveal c))
-- | Nests the second functor inside the first, pairing each value of the
-- outer functor with each value of the inner one.
nest :: (Functor a, Functor b) => a x -> b y -> NestedFunctor a b (x, y)
nest a b = NestedFunctor (fmap pairWith a)
  where
    pairWith x = fmap (\y -> (x, y)) b
-- | Runs two coroutines side by side as a single coroutine that returns
-- both results as a pair. The @runPair@ argument decides how the two
-- resumptions are combined in the base monad. The combined coroutine
-- suspends with 'LeftSome', 'RightSome' or 'Both' depending on which of the
-- two coroutines suspended; a coroutine that has already finished is
-- carried along as a plain 'return' of its result.
couple :: (Monad m, Functor s1, Functor s2) =>
          (forall x y r. (x -> y -> m r) -> m x -> m y -> m r)
          -> Coroutine s1 m x -> Coroutine s2 m y -> Coroutine (SomeFunctor s1 s2) m (x, y)
couple runPair = both
  where
    both t1 t2 = Coroutine (runPair proceed (resume t1) (resume t2))
    proceed (Done x) (Done y) = return (Done (x, y))
    proceed (Suspend s1) (Suspend s2) = return (Suspend (Both (fmap (uncurry both) (nest s1 s2))))
    proceed (Done x) (Suspend s2) = return (Suspend (RightSome (fmap (both (return x)) s2)))
    proceed (Suspend s1) (Done y) = return (Suspend (LeftSome (fmap (\t1 -> both t1 (return y)) s1)))
-- | Like 'couple', but for two coroutines that share a common outer
-- suspension functor @s0@. Inner ('RightF') suspensions are combined into a
-- 'SomeFunctor', exactly as 'couple' does. An outer ('LeftF') suspension of
-- either coroutine becomes an outer suspension of the combined coroutine.
--
-- When exactly one side suspends on the outer functor while the other
-- suspends on its inner functor, the outer suspension is propagated and the
-- other side is re-suspended unchanged, to be met again after the outer
-- suspension is resumed. When both suspend on the outer functor, the second
-- one's suspension is propagated and the first is re-suspended.
coupleNested :: (Monad m, Functor s0, Functor s1, Functor s2) =>
                (forall x y r. (x -> y -> m r) -> m x -> m y -> m r)
                -> Coroutine (EitherFunctor s0 s1) m x -> Coroutine (EitherFunctor s0 s2) m y
                -> Coroutine (EitherFunctor s0 (SomeFunctor s1 s2)) m (x, y)
coupleNested runPair = coupleNested'
  where
    coupleNested' t1 t2 =
      Coroutine {resume = runPair (\st1 st2 -> return (proceed st1 st2)) (resume t1) (resume t2)}

    -- Both coroutines have finished.
    proceed (Done x) (Done y) = Done (x, y)
    -- Only inner suspensions are involved: combine them into a SomeFunctor.
    proceed (Suspend (RightF s)) (Done y) = Suspend $ RightF $ fmap (flip coupleNested' (return y)) (LeftSome s)
    proceed (Done x) (Suspend (RightF s)) = Suspend $ RightF $ fmap (coupleNested' (return x)) (RightSome s)
    proceed (Suspend (RightF s1)) (Suspend (RightF s2)) =
      Suspend $ RightF $ fmap (uncurry coupleNested') (Both $ nest s1 s2)
    -- One side suspends on the outer functor, the other has finished.
    proceed (Suspend (LeftF s)) (Done y) = Suspend $ LeftF $ fmap (flip coupleNested' (return y)) s
    proceed (Done x) (Suspend (LeftF s)) = Suspend $ LeftF $ fmap (coupleNested' (return x)) s
    -- One side suspends on the outer functor, the other on its inner one.
    -- These two equations close a gap in the pattern match that otherwise
    -- fails at runtime.
    proceed (Suspend (LeftF s1)) (Suspend (RightF s2)) =
      Suspend $ LeftF $ fmap (flip coupleNested' (suspend $ RightF s2)) s1
    proceed (Suspend (RightF s1)) (Suspend (LeftF s2)) =
      Suspend $ LeftF $ fmap (coupleNested' (suspend $ RightF s1)) s2
    -- Both sides suspend on the outer functor.
    proceed (Suspend (LeftF s1)) (Suspend (LeftF s2)) = Suspend $ LeftF $ fmap (coupleNested' $ suspend $ LeftF s1) s2
-- | Functions used by 'seesaw' and 'seesawNested' to resolve the
-- suspensions of the two coroutines they run.
data SeesawResolver s1 s2 = SeesawResolver
  { resumeLeft :: forall t. s1 t -> t
    -- ^ Extracts the continuation from a suspension of the first coroutine;
    -- applied when the second coroutine has already finished.
  , resumeRight :: forall t. s2 t -> t
    -- ^ Extracts the continuation from a suspension of the second coroutine;
    -- applied when the first coroutine has already finished.
  , resumeAny :: forall t1 t2 r.
                 (t1 -> r)
              -> (t2 -> r)
              -> (t1 -> t2 -> r)
              -> s1 t1
              -> s2 t2
              -> r
    -- ^ Applied when both coroutines are suspended. It is given three
    -- continuations (resume only the first, resume only the second, resume
    -- both) followed by the two suspensions.
  }
-- | Runs two coroutines against each other until both have finished,
-- returning the pair of their results in the base monad. The @runPair@
-- argument combines the two resumptions; the resolver decides how each
-- suspension is resumed. A finished coroutine is carried along as a plain
-- 'return' of its result while the other one keeps running.
seesaw :: (Monad m, Functor s1, Functor s2) =>
          (forall x y r. (x -> y -> m r) -> m x -> m y -> m r)
          -> SeesawResolver s1 s2
          -> Coroutine s1 m x -> Coroutine s2 m y -> m (x, y)
seesaw runPair resolver = go
  where
    go t1 t2 = runPair step (resume t1) (resume t2)
    step (Done x) (Done y) = return (x, y)
    step (Done x) (Suspend s2) = go (return x) (resumeRight resolver s2)
    step (Suspend s1) (Done y) = go (resumeLeft resolver s1) (return y)
    step (Suspend s1) (Suspend s2) =
      resumeAny
        resolver
        (\t1 -> go t1 (suspend s2))
        (\t2 -> go (suspend s1) t2)
        go
        s1
        s2
-- | Like 'seesaw', but for two coroutines that share a common outer
-- suspension functor @s0@. Inner ('RightF') suspensions are resolved with
-- the resolver. An outer ('LeftF') suspension of either coroutine becomes a
-- suspension of the resulting coroutine, with the other coroutine frozen in
-- the state it had reached.
seesawNested :: (Monad m, Functor s0, Functor s1, Functor s2) =>
                (forall x y r. (x -> y -> m r) -> m x -> m y -> m r)
                -> SeesawResolver s1 s2
                -> Coroutine (EitherFunctor s0 s1) m x -> Coroutine (EitherFunctor s0 s2) m y -> Coroutine s0 m (x, y)
seesawNested runPair resolver = teeter
  where
    teeter t1 t2 = Coroutine (bounce t1 t2)
    bounce t1 t2 = runPair step (resume t1) (resume t2)

    -- Outer suspensions are checked first and take priority.
    step (Suspend (LeftF s1)) state2 =
      return (Suspend (fmap (\t1 -> teeter t1 (Coroutine (return state2))) s1))
    step state1 (Suspend (LeftF s2)) =
      return (Suspend (fmap (teeter (Coroutine (return state1))) s2))
    step (Done x) (Done y) = return (Done (x, y))
    -- One side has finished: keep resuming the other through the resolver.
    step state1@(Done _) (Suspend (RightF s2)) =
      resume (resumeRight resolver s2) >>= step state1
    step (Suspend (RightF s1)) state2@(Done _) =
      resume (resumeLeft resolver s1) >>= \state1 -> step state1 state2
    -- Both are suspended on their inner functors: the resolver picks.
    step state1@(Suspend (RightF s1)) state2@(Suspend (RightF s2)) =
      resumeAny
        resolver
        (\t1 -> resume t1 >>= \state1' -> step state1' state2)
        (\t2 -> resume t2 >>= step state1)
        bounce
        s1
        s2
-- | Converts a coroutine into one over an 'EitherFunctor', placing every
-- one of its suspensions on the 'RightF' side.
local :: forall m l r x. (Functor r, Monad m) => Coroutine r m x -> Coroutine (EitherFunctor l r) m x
local t = Coroutine (liftM inject (resume t))
  where
    inject :: CoroutineState r m x -> CoroutineState (EitherFunctor l r) m x
    inject state = case state of
      Done x -> Done x
      Suspend r -> Suspend (RightF (fmap local r))
-- | Converts a coroutine into one over an 'EitherFunctor', placing every
-- one of its suspensions on the 'LeftF' side.
out :: forall m l r x. (Functor l, Monad m) => Coroutine l m x -> Coroutine (EitherFunctor l r) m x
out t = Coroutine (liftM inject (resume t))
  where
    inject :: CoroutineState l m x -> CoroutineState (EitherFunctor l r) m x
    inject state = case state of
      Done x -> Done x
      Suspend l -> Suspend (LeftF (fmap out l))
-- | Relates a functor @a@ to a functor @d@ into which values of @a@ can be
-- lifted.
class (Functor a, Functor d) => AncestorFunctor a d where
  -- | Converts a value of the ancestor functor into the descendant functor.
  liftFunctor :: a x -> d x

-- | Every functor is its own ancestor; lifting leaves the value unchanged.
instance Functor a => AncestorFunctor a a where
  liftFunctor functor = functor

-- | An ancestor of @d'@ is also an ancestor of any @EitherFunctor d' s@:
-- the value is lifted into @d'@ first and then wrapped in 'LeftF'.
instance (Functor a, Functor d', Functor d, d ~ EitherFunctor d' s, AncestorFunctor a d') => AncestorFunctor a d where
  liftFunctor = LeftF . viaParent
    where
      viaParent :: a x -> d' x
      viaParent = liftFunctor
-- | Converts a coroutine over the functor @a@ into one over a descendant
-- functor @d@, translating each suspension with 'liftFunctor'.
liftOut :: forall m a d x. (Monad m, Functor a, AncestorFunctor a d) => Coroutine a m x -> Coroutine d m x
liftOut t = Coroutine (liftM inject (resume t))
  where
    inject :: CoroutineState a m x -> CoroutineState d m x
    inject state = case state of
      Done x -> Done x
      Suspend a -> Suspend (liftFunctor (fmap liftOut a))