{- 
    Copyright 2009-2010 Mario Blazevic

    This file is part of the Streaming Component Combinators (SCC) project.

    The SCC project is free software: you can redistribute it and/or modify it under the terms of the GNU General Public
    License as published by the Free Software Foundation, either version 3 of the License, or (at your option) any later
    version.

    SCC is distributed in the hope that it will be useful, but WITHOUT ANY WARRANTY; without even the implied warranty
    of MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE.  See the GNU General Public License for more details.

    You should have received a copy of the GNU General Public License along with SCC.  If not, see
    <http://www.gnu.org/licenses/>.
-}

-- | This module defines the 'Coroutine' monad transformer.
-- 
-- A 'Coroutine' monadic computation can 'suspend' its execution at any time, returning to its invoker. The returned
-- coroutine suspension contains the continuation of the coroutine embedded in a functor. Here is an example of a
-- coroutine that suspends computation in the 'IO' monad using the functor 'Yield':
-- 
-- @
-- producer = do yield 1
--               lift (putStrLn \"Produced one, next is four.\")
--               yield 4
--               return \"Finished\"
-- @
-- 
-- A suspended 'Coroutine' computation can be resumed. The easiest way to run a coroutine is by using the 'pogoStick'
-- function, which keeps resuming the coroutine in trampolined style until it completes. Here is an example of
-- 'pogoStick' applied to the /producer/ above:
-- 
-- @
-- printProduce :: Show x => Coroutine (Yield x) IO r -> IO r
-- printProduce producer = pogoStick (\\(Yield x cont) -> lift (print x) >> cont) producer
-- @
-- 
-- Multiple coroutines can also be run concurrently, and this module provides two ways to do so. The function 'seesaw'
-- runs two interleaved computations. Alternatively, the steps of different coroutines can be woven together into a
-- single coroutine using the function 'couple', which can then be executed by 'pogoStick'.
-- 
-- Coroutines can be run from within another coroutine. In this case, the nested coroutines would normally suspend to
-- their invoker. Another option is to allow a nested coroutine to suspend both itself and its invoker at once. In this
-- case, the two suspension functors should be grouped into an 'EitherFunctor'. To run nested coroutines of this kind,
-- use functions 'pogoStickNested', 'seesawNested', and 'coupleNested'.
-- 
-- For other uses of trampoline-style coroutines, see
-- 
-- > Trampolined Style - Ganz, S. E., Friedman, D. P., Wand, M., ACM SIGPLAN Notices, 1999, vol. 34, no. 9, pages 18-27
-- 
-- and
-- 
-- > The Essence of Multitasking - William L. Harrison, Proceedings of the 11th International Conference on Algebraic
-- > Methodology and Software Technology, volume 4019 of Lecture Notes in Computer Science, 2006

{-# LANGUAGE ScopedTypeVariables, Rank2Types, MultiParamTypeClasses, TypeFamilies, EmptyDataDecls,
             FlexibleInstances, OverlappingInstances, UndecidableInstances
 #-}

module Control.Concurrent.Coroutine
   (
    -- * Coroutine definition
    Coroutine,
    suspend,
    -- * Useful classes
    ParallelizableMonad(..), AncestorFunctor,
    -- * Running Coroutine computations
    runCoroutine, pogoStick, pogoStickNested, seesaw, seesawNested, SeesawResolver(..),
    -- * Suspension functors
    Yield(Yield), Await(Await), Naught,
    yield, await,
    -- * Nested and coupled Coroutine computations
    nest, couple, coupleNested,
    local, out, liftOut,
    EitherFunctor(LeftF, RightF), NestedFunctor (NestedFunctor), SomeFunctor(..)
   )
where

import Control.Applicative (Applicative(..))
import Control.Concurrent (forkIO)
import Control.Concurrent.MVar (newEmptyMVar, putMVar, takeMVar)
import Control.Exception (SomeException, throwIO, try)
import Control.Monad (liftM, liftM2, when)
import Control.Monad.Identity
import Control.Monad.Trans (MonadTrans(..), MonadIO(..))
import Control.Parallel (par, pseq)

-- | Monads in which two independent computations may be evaluated in parallel.
class Monad m => ParallelizableMonad m where
   -- | Run two monadic computations, possibly in parallel, and pass both results to the continuation.
   -- The default implementation simply runs them one after the other.
   bindM2 :: (a -> b -> m c) -> m a -> m b -> m c
   bindM2 f ma mb = ma >>= \a -> mb >>= \b -> f a b

-- | Monads whose result value can be extracted directly, such as 'Identity' or 'Maybe', can implement 'bindM2'
-- with 'par': one result is sparked while the other is being evaluated.
instance ParallelizableMonad Identity where
   bindM2 f ma mb = left `par` (right `pseq` (left `pseq` f left right))
      where left = runIdentity ma
            right = runIdentity mb

instance ParallelizableMonad Maybe where
   -- Spark the first computation, force the second, then continue only if both succeeded.
   bindM2 f ma mb = ma `par` (mb `pseq` both ma mb)
      where both (Just a) (Just b) = f a b
            both _ _ = Nothing

-- | IO is parallelizable by `forkIO`: each computation runs in its own thread and the caller waits for both.
-- An exception raised by either computation is caught in its thread and re-thrown in the calling thread once both
-- threads have reported back. Without this, the failed thread would never fill its 'MVar' and the caller would block
-- forever.
instance ParallelizableMonad IO where
   bindM2 f ma mb = do va <- newEmptyMVar
                       vb <- newEmptyMVar
                       _ <- forkIO (try ma >>= putMVar va)
                       _ <- forkIO (try mb >>= putMVar vb)
                       ra <- takeMVar va
                       rb <- takeMVar vb
                       a <- either rethrow return ra
                       b <- either rethrow return rb
                       f a b
      where rethrow :: SomeException -> IO x
            rethrow = throwIO

-- | A computation in the base monad @m@ that can suspend itself, wrapping its continuation in the functor @s@,
-- and that eventually finishes with a result of type @r@.
newtype Coroutine s m r = Coroutine
   { -- | Run the next step of a `Coroutine` computation.
     resume :: m (CoroutineState s m r)
   }

-- | The outcome of running one step of a 'Coroutine'.
data CoroutineState s m r =
   -- | Coroutine computation is finished with final value /r/.
   Done r
   -- | Computation is suspended, its remainder is embedded in the functor /s/. The field is strict; the bang is
   -- written in prefix position so it is unambiguously a strictness annotation, not a type operator.
 | Suspend !(s (Coroutine s m r))

-- | Mapping over a 'Coroutine' transforms its final result; suspensions are carried through unchanged.
instance (Functor s, Monad m) => Functor (Coroutine s m) where
   fmap = liftM

-- | Sequential application: the first coroutine runs to completion, including any suspensions, before the second
-- one starts. Required as a superclass of 'Monad' since the Applicative-Monad Proposal (GHC 7.10).
instance (Functor s, Monad m) => Applicative (Coroutine s m) where
   pure x = Coroutine (return (Done x))
   mf <*> mx = mf >>= \f -> mx >>= \x -> return (f x)

-- | Binding continues directly after a finished step. After a suspended step, the continuation is pushed inside
-- the suspension functor.
instance (Functor s, Monad m) => Monad (Coroutine s m) where
   return = pure
   t >>= f = Coroutine (resume t >>= apply)
      where apply (Done x) = resume (f x)
            apply (Suspend s) = return (Suspend (fmap (>>= f) s))

-- | Steps of two coroutines are run with the base monad's 'bindM2'. If exactly one of them is suspended, the
-- suspension is propagated. If both are, the second's suspension is propagated and the first is re-suspended.
instance (Functor s, ParallelizableMonad m) => ParallelizableMonad (Coroutine s m) where
   bindM2 f t1 t2 = Coroutine (bindM2 merge (resume t1) (resume t2))
      where merge (Done x) (Done y) = resume (f x y)
            merge (Suspend s1) (Done y) = return (Suspend (fmap (\c -> c >>= \x -> f x y) s1))
            merge (Done x) (Suspend s2) = return (Suspend (fmap (\c -> c >>= f x) s2))
            merge (Suspend s1) (Suspend s2) = return (Suspend (fmap (bindM2 f (suspend s1)) s2))

-- | A lifted base-monad action becomes a coroutine that finishes in a single step without suspending.
instance Functor s => MonadTrans (Coroutine s) where
   lift action = Coroutine (liftM Done action)

-- | 'IO' actions are lifted into the base monad first, and from there into the coroutine.
instance (Functor s, MonadIO m) => MonadIO (Coroutine s m) where
   liftIO action = lift (liftIO action)

-- | The 'Yield' functor is equivalent to (,) but more descriptive. It pairs the yielded value with the
-- continuation, and mapping affects only the continuation.
data Yield x y = Yield x y
instance Functor (Yield x) where
   fmap f (Yield value rest) = Yield value (f rest)

-- | The 'Await' functor is equivalent to (->) but more descriptive. It wraps a function from the awaited value to
-- the continuation, and mapping post-composes onto that function. The field is strict; the bang is written in
-- prefix position so it is unambiguously a strictness annotation.
data Await x y = Await !(x -> y)
instance Functor (Await x) where
   fmap f (Await g) = Await (f . g)

-- | The 'Naught' functor has no constructors, so no value of type @Naught x@ can be built. It is used as the
-- suspension functor of coroutines that cannot suspend, such as those accepted by 'runCoroutine'.
data Naught x
instance Functor Naught where
   -- Unreachable in practice, because there is no 'Naught' value to map over. The explicit message replaces a bare
   -- 'undefined' and the unused binding.
   fmap _ _ = error "Control.Concurrent.Coroutine: fmap applied to the uninhabited Naught functor"

-- | Combines two alternative functors into one, holding a value under one or the other. Used for nested coroutines.
data EitherFunctor l r x = LeftF (l x) | RightF (r x)
instance (Functor l, Functor r) => Functor (EitherFunctor l r) where
   fmap f e = case e of
      LeftF inner -> LeftF (fmap f inner)
      RightF inner -> RightF (fmap f inner)

-- | Composes two functors: the outer functor @l@ contains values under the inner functor @r@.
newtype NestedFunctor l r x = NestedFunctor (l (r x))
instance (Functor l, Functor r) => Functor (NestedFunctor l r) where
   fmap f (NestedFunctor outer) = NestedFunctor (fmap (fmap f) outer)

-- | Combines two functors into one, holding a value under the first, under the second, or under both nested
-- together. Used for coupled coroutines.
data SomeFunctor l r x = LeftSome (l x) | RightSome (r x) | Both (NestedFunctor l r x)
instance (Functor l, Functor r) => Functor (SomeFunctor l r) where
   fmap f some = case some of
      LeftSome inner -> LeftSome (fmap f inner)
      RightSome inner -> RightSome (fmap f inner)
      Both nested -> Both (fmap f nested)

-- | Suspend the current 'Coroutine', handing the given suspension (which embeds the continuation) to its invoker.
suspend :: (Monad m, Functor s) => s (Coroutine s m x) -> Coroutine s m x
suspend = Coroutine . return . Suspend

-- | Suspend, yielding the given value; resuming continues with @()@.
yield :: forall m x. Monad m => x -> Coroutine (Yield x) m ()
yield x = suspend $ Yield x $ return ()

-- | Suspend until a value is provided; the coroutine then finishes with that value.
await :: forall m x. Monad m => Coroutine (Await x) m x
await = suspend (Await continue)
   where continue :: x -> Coroutine (Await x) m x
         continue = return

-- | Run a 'Coroutine' that cannot suspend (its suspension functor is 'Naught') in the base monad.
runCoroutine :: Monad m => Coroutine Naught m x -> m x
runCoroutine = pogoStick impossible
   where impossible = error "runCoroutine can run only a non-suspending coroutine!"

-- | Run a 'Coroutine' to completion in the base monad. Each time it suspends, the given function turns the suspension
-- into the resumption to run next.
pogoStick :: (Functor s, Monad m) => (s (Coroutine s m x) -> Coroutine s m x) -> Coroutine s m x -> m x
pogoStick reveal = loop
   where loop t = resume t >>= step
         step (Done result) = return result
         step (Suspend c) = loop (reveal c)

-- | Run a nested 'Coroutine' that can suspend both itself and the current 'Coroutine'. Inner suspensions ('RightF')
-- are resolved by the given function. Outer suspensions ('LeftF') become suspensions of the resulting coroutine.
pogoStickNested :: (Functor s1, Functor s2, Monad m) =>
                   (s2 (Coroutine (EitherFunctor s1 s2) m x) -> Coroutine (EitherFunctor s1 s2) m x)
                   -> Coroutine (EitherFunctor s1 s2) m x -> Coroutine s1 m x
pogoStickNested reveal = go
   where go t = Coroutine{resume= resume t >>= step}
         step (Done result) = return (Done result)
         step (Suspend (LeftF s)) = return (Suspend (fmap go s))
         step (Suspend (RightF c)) = resume (go (reveal c))

-- | Pairs every value inside the first functor with every value inside the second, producing a 'NestedFunctor' of
-- pairs.
nest :: (Functor a, Functor b) => a x -> b y -> NestedFunctor a b (x, y)
nest a b = NestedFunctor (fmap (\x -> fmap (\y -> (x, y)) b) a)

-- | Weaves two coroutines into one that finishes with both results. The first argument runs one step of each
-- coroutine. Suspensions of either or both are combined into a single 'SomeFunctor' suspension.
couple :: (Monad m, Functor s1, Functor s2) =>
          (forall x y r. (x -> y -> m r) -> m x -> m y -> m r)
       -> Coroutine s1 m x -> Coroutine s2 m y -> Coroutine (SomeFunctor s1 s2) m (x, y)
couple runPair t1 t2 = Coroutine{resume= runPair step (resume t1) (resume t2)}
   where again c1 c2 = couple runPair c1 c2
         step (Done x) (Done y) = return (Done (x, y))
         step (Suspend s1) (Suspend s2) = return (Suspend (fmap (uncurry again) (Both (nest s1 s2))))
         step (Done x) (Suspend s2) = return (Suspend (fmap (again (return x)) (RightSome s2)))
         step (Suspend s1) (Done y) = return (Suspend (fmap (\c1 -> again c1 (return y)) (LeftSome s1)))

-- | Weaves two nested coroutines into one. Inner suspensions (tagged 'RightF') of either or both coroutines are
-- combined into a single 'SomeFunctor' suspension. Outer suspensions (tagged 'LeftF') of the shared functor @s0@
-- are passed out one at a time, while the other coroutine's pending step is kept for the next resumption.
coupleNested :: (Monad m, Functor s0, Functor s1, Functor s2) =>
                (forall x y r. (x -> y -> m r) -> m x -> m y -> m r)
             -> Coroutine (EitherFunctor s0 s1) m x -> Coroutine (EitherFunctor s0 s2) m y
             -> Coroutine (EitherFunctor s0 (SomeFunctor s1 s2)) m (x, y)
coupleNested runPair = coupleNested' where
   coupleNested' t1 t2 = Coroutine{resume= runPair (\ st1 st2 -> return (proceed st1 st2)) (resume t1) (resume t2)}
   -- Both coroutines have finished.
   proceed (Done x) (Done y) = Done (x, y)
   -- Inner suspensions only: combine them into one 'SomeFunctor' suspension.
   proceed (Suspend (RightF s)) (Done y) = Suspend $ RightF $ fmap (flip coupleNested' (return y)) (LeftSome s)
   proceed (Done x) (Suspend (RightF s)) = Suspend $ RightF $ fmap (coupleNested' (return x)) (RightSome s)
   proceed (Suspend (RightF s1)) (Suspend (RightF s2)) =
      Suspend $ RightF $ fmap (uncurry coupleNested') (Both $ nest s1 s2)
   -- Outer suspensions only: propagate one of them; the other coroutine keeps its current state.
   proceed (Suspend (LeftF s)) (Done y) = Suspend $ LeftF $ fmap (flip coupleNested' (return y)) s
   proceed (Done x) (Suspend (LeftF s)) = Suspend $ LeftF $ fmap (coupleNested' (return x)) s
   proceed (Suspend (LeftF s1)) (Suspend (LeftF s2)) = Suspend $ LeftF $ fmap (coupleNested' $ suspend $ LeftF s1) s2
   -- One outer and one inner suspension: propagate the outer one. The inner suspension is replayed unchanged the next
   -- time the coupled coroutine is resumed. Without these two equations the match is non-exhaustive and fails at
   -- runtime.
   proceed (Suspend (LeftF s1)) st2@(Suspend (RightF _)) =
      Suspend $ LeftF $ fmap (flip coupleNested' (Coroutine $ return st2)) s1
   proceed st1@(Suspend (RightF _)) (Suspend (LeftF s2)) =
      Suspend $ LeftF $ fmap (coupleNested' (Coroutine $ return st1)) s2

-- | A record of resolver functions, one for each way in which a pair of coroutines can be suspended.
data SeesawResolver s1 s2 = SeesawResolver
   { resumeLeft  :: forall t. s1 t -> t    -- ^ resolves the left suspension functor into the resumption it contains
   , resumeRight :: forall t. s2 t -> t    -- ^ resolves the right suspension into its resumption
   , -- | invoked when both coroutines are suspended; resolves both suspensions or either one
     resumeAny   :: forall t1 t2 r.
                    (t1 -> r)       --  ^ continuation to resume only the left suspended coroutine
                 -> (t2 -> r)       --  ^ continuation to resume the right coroutine only
                 -> (t1 -> t2 -> r) --  ^ continuation to resume both coroutines
                 -> s1 t1           --  ^ left suspension
                 -> s2 t2           --  ^ right suspension
                 -> r
   }

-- | Runs two coroutines side by side until both finish, returning both results. The first argument runs one step of
-- each coroutine. The 'SeesawResolver' decides how to resume when one or both of them are suspended.
seesaw :: (Monad m, Functor s1, Functor s2) =>
          (forall x y r. (x -> y -> m r) -> m x -> m y -> m r)
       -> SeesawResolver s1 s2
       -> Coroutine s1 m x -> Coroutine s2 m y -> m (x, y)
seesaw runPair resolver = go
   where go left right = runPair step (resume left) (resume right)
         step (Done x) (Done y) = return (x, y)
         step (Done x) (Suspend s2) = go (return x) (resumeRight resolver s2)
         step (Suspend s1) (Done y) = go (resumeLeft resolver s1) (return y)
         step (Suspend s1) (Suspend s2) =
            resumeAny resolver (\left -> go left (suspend s2)) (go (suspend s1)) go s1 s2

-- | Like 'seesaw', but for nested coroutines that can suspend the current coroutine as well as themselves. An outer
-- suspension ('LeftF') of either side becomes a suspension of the result, and the other side's step is replayed
-- afterwards. Inner suspensions ('RightF') are resolved by the 'SeesawResolver'.
seesawNested :: (Monad m, Functor s0, Functor s1, Functor s2) =>
                (forall x y r. (x -> y -> m r) -> m x -> m y -> m r)
             -> SeesawResolver s1 s2
             -> Coroutine (EitherFunctor s0 s1) m x -> Coroutine (EitherFunctor s0 s2) m y -> Coroutine s0 m (x, y)
seesawNested runPair resolver = go
   where go left right = Coroutine{resume= bounce left right}
         bounce left right = runPair step (resume left) (resume right)
         step (Suspend (LeftF s1)) st2 = return (Suspend (fmap (\c -> go c (Coroutine (return st2))) s1))
         step st1 (Suspend (LeftF s2)) = return (Suspend (fmap (go (Coroutine (return st1))) s2))
         step (Done x) (Done y) = return (Done (x, y))
         step st1@(Done _) (Suspend (RightF s2)) = resume (resumeRight resolver s2) >>= step st1
         step (Suspend (RightF s1)) st2@(Done _) = resume (resumeLeft resolver s1) >>= \st1 -> step st1 st2
         step st1@(Suspend (RightF s1)) st2@(Suspend (RightF s2)) =
            resumeAny resolver (\c -> resume c >>= \st -> step st st2) (\c -> resume c >>= step st1) bounce s1 s2

-- | Converts a coroutine into a nested one whose suspensions are all tagged with 'RightF'.
local :: forall m l r x. (Functor r, Monad m) => Coroutine r m x -> Coroutine (EitherFunctor l r) m x
local = Coroutine . liftM inject . resume
   where inject :: CoroutineState r m x -> CoroutineState (EitherFunctor l r) m x
         inject state = case state of
            Done v -> Done v
            Suspend susp -> Suspend (RightF (fmap local susp))

-- | Converts a coroutine into one that can contain nested coroutines; its own suspensions are tagged with 'LeftF'.
out :: forall m l r x. (Functor l, Monad m) => Coroutine l m x -> Coroutine (EitherFunctor l r) m x
out = Coroutine . liftM inject . resume
   where inject :: CoroutineState l m x -> CoroutineState (EitherFunctor l r) m x
         inject state = case state of
            Done v -> Done v
            Suspend susp -> Suspend (LeftF (fmap out susp))

-- | Relates a functor @a@ to a descendant functor @d@ into which its values can be embedded, for example @d@ built
-- by wrapping @a@ in 'EitherFunctor' layers.
class (Functor a, Functor d) => AncestorFunctor a d where
   -- | Embed a value of the ancestor functor into the descendant functor.
   liftFunctor :: a x
               -> d x

-- | Every functor is its own ancestor; the embedding is the identity.
instance Functor a => AncestorFunctor a a where
   liftFunctor = id
-- | If @a@ is an ancestor of @d'@, it is also an ancestor of @EitherFunctor d' s@: embed into @d'@ first, then tag
-- the result with 'LeftF'. This head overlaps the instance above (the module enables OverlappingInstances). The type
-- annotation on the inner 'liftFunctor' refers to @a@ and @d'@ from the instance head, which ScopedTypeVariables
-- brings into scope.
instance (Functor a, Functor d', Functor d, d ~ EitherFunctor d' s, AncestorFunctor a d') => AncestorFunctor a d where
   liftFunctor = LeftF . (liftFunctor :: a x -> d' x)

-- | Like 'out', but embeds the coroutine's suspensions into any descendant functor via 'liftFunctor', so it can lift
-- across several nesting levels.
liftOut :: forall m a d x. (Monad m, Functor a, AncestorFunctor a d) => Coroutine a m x -> Coroutine d m x
liftOut = Coroutine . liftM inject . resume
   where inject :: CoroutineState a m x -> CoroutineState d m x
         inject state = case state of
            Done v -> Done v
            Suspend susp -> Suspend (liftFunctor (fmap liftOut susp))