{- 
    Copyright 2010 Mario Blazevic

    This file is part of the Streaming Component Combinators (SCC) project.

    The SCC project is free software: you can redistribute it and/or modify it under the terms of the GNU General Public
    License as published by the Free Software Foundation, either version 3 of the License, or (at your option) any later
    version.

    SCC is distributed in the hope that it will be useful, but WITHOUT ANY WARRANTY; without even the implied warranty
    of MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE.  See the GNU General Public License for more details.

    You should have received a copy of the GNU General Public License along with SCC.  If not, see
    <http://www.gnu.org/licenses/>.
-}

-- | This module defines nestable suspension functors for use with the 'Coroutine' monad transformer, as well as
-- functions for running nested coroutines of this sort.
-- 
-- Coroutines can be run from within another coroutine. In this case, the nested coroutines always suspend to their
-- invoker. If a function from this module, such as 'pogoStickNested', is used to run a nested coroutine, the parent
-- coroutine can be automatically suspended as well. A single suspension can thus suspend an entire chain of nested
-- coroutines.
-- 
-- Nestable coroutines of this kind should group their suspension functors into an 'EitherFunctor'. You can adjust a
-- normal suspension, such as the one produced by 'yield', using functions 'mapSuspension' and 'liftOut'. To run nested
-- coroutines, use functions 'pogoStickNested', 'seesawNested', and 'coupleNested'.

{-# LANGUAGE ScopedTypeVariables, Rank2Types, MultiParamTypeClasses, TypeFamilies,
             FlexibleContexts, FlexibleInstances, OverlappingInstances, UndecidableInstances
 #-}

module Control.Monad.Coroutine.Nested
   (
    pogoStickNested, coupleNested, seesawNested, 
    AncestorFunctor,
    liftOut
   )
where

import Control.Monad (join, liftM)
import Control.Monad.Trans.Class (lift)

import Control.Monad.Coroutine
import Control.Monad.Coroutine.SuspensionFunctors

-- | Run a nested 'Coroutine' that can suspend both itself and the current 'Coroutine'.
--
-- The nested coroutine suspends into an 'EitherFunctor' of the current coroutine's functor @s1@ and its own functor
-- @s2@. A 'LeftF' suspension is re-exposed as a suspension of the resulting coroutine. A 'RightF' suspension is
-- handed to @reveal@, whose result is resumed immediately without suspending the resulting coroutine.
pogoStickNested :: forall s1 s2 m x. (Functor s1, Functor s2, Monad m) => 
                   (s2 (Coroutine (EitherFunctor s1 s2) m x) -> Coroutine (EitherFunctor s1 s2) m x)
                   -> Coroutine (EitherFunctor s1 s2) m x -> Coroutine s1 m x
pogoStickNested reveal = bounce
   where
      bounce nested = Coroutine{resume= resume nested >>= dispatch}
      -- The nested coroutine returned: the result is passed through unchanged.
      dispatch (Right result) = return (Right result)
      -- Suspension addressed to the parent: expose it, continuing with 'bounce' once it is resumed.
      dispatch (Left (LeftF outer)) = return (Left (fmap bounce outer))
      -- Local suspension: @reveal@ chooses the continuation, which is resumed right away.
      dispatch (Left (RightF inner)) = resume (bounce (reveal inner))

-- | Weaves two nested coroutines into one.
--
-- Suspensions that either child addresses to the shared parent functor @s0@ ('LeftF') suspend the combined coroutine
-- right away. If both children suspend to the parent in the same step, the left child's suspension takes precedence,
-- the same policy 'seesawNested' uses. Suspensions local to the children ('RightF') are combined into a
-- 'SomeFunctor':
--
-- * 'LeftSome' when only the left child suspended;
--
-- * 'RightSome' when only the right child suspended;
--
-- * 'Both' when both did.
--
-- The combined coroutine returns the pair of results once both children have returned.
coupleNested :: forall s0 s1 s2 m x y. (Monad m, Functor s0, Functor s1, Functor s2) => 
                (forall x y r. (x -> y -> m r) -> m x -> m y -> m r)
             -> Coroutine (EitherFunctor s0 s1) m x -> Coroutine (EitherFunctor s0 s2) m y
             -> Coroutine (EitherFunctor s0 (SomeFunctor s1 s2)) m (x, y)
coupleNested runPair = coupleNested' where
   coupleNested' t1 t2 = Coroutine{resume= runPair (\ st1 st2 -> return (proceed st1 st2)) (resume t1) (resume t2)}
   -- Parent suspensions come first. The left child is checked before the right one, so it wins when both suspend
   -- to the parent in the same step. The other child's state is frozen with @Coroutine . return@ so that it is
   -- replayed, not re-run, when the combined coroutine is resumed.
   proceed (Left (LeftF s)) r = Left $ LeftF $ fmap (flip coupleNested' (Coroutine $ return r)) s
   proceed l (Left (LeftF s)) = Left $ LeftF $ fmap (coupleNested' (Coroutine $ return l)) s
   proceed (Right x) (Right y) = Right (x, y)
   -- Local suspensions: a finished child is kept as a pure 'return' of its result.
   proceed (Left (RightF s)) (Right y) = Left $ RightF $ fmap (flip coupleNested' (return y)) (LeftSome s)
   proceed (Right x) (Left (RightF s)) = Left $ RightF $ fmap (coupleNested' (return x)) (RightSome s)
   proceed (Left (RightF s1)) (Left (RightF s2)) =
      Left $ RightF $ fmap (uncurry coupleNested') (Both $ composePair s1 s2)

-- | Like 'seesaw', but for nested coroutines that may suspend the current (parent) coroutine in addition to
-- themselves.
--
-- A child suspension in the parent functor @s0@ ('LeftF') suspends the combined coroutine. When both children do this
-- in the same step, the left child's suspension wins. Local child suspensions ('RightF') are resumed through the
-- 'SeesawResolver' without suspending the parent. The result pairs both children's return values.
seesawNested :: (Monad m, Functor s0, Functor s1, Functor s2) =>
                (forall x y r. (x -> y -> m r) -> m x -> m y -> m r)
             -> SeesawResolver s1 s2
             -> Coroutine (EitherFunctor s0 s1) m x -> Coroutine (EitherFunctor s0 s2) m y -> Coroutine s0 m (x, y)
seesawNested runPair resolver = weave
   where
      -- One parent coroutine built from the current pair of children.
      weave left right = Coroutine{resume= step left right}
      -- Resume both children through runPair and react to the pair of states they reach.
      step left right = runPair react (resume left) (resume right)
      -- Left child addresses the parent: checked first, so it has precedence. The right state is frozen as-is.
      react (Left (LeftF parentL)) rightSt =
         return $ Left $ fmap (\ left' -> weave left' (Coroutine $ return rightSt)) parentL
      react leftSt (Left (LeftF parentR)) =
         return $ Left $ fmap (weave (Coroutine $ return leftSt)) parentR
      react (Right x) (Right y) = return $ Right (x, y)
      -- Only one child is locally suspended: resolve it and react again.
      react leftSt@(Right _) (Left (RightF localR)) =
         resume (resumeRight resolver localR) >>= react leftSt
      react (Left (RightF localL)) rightSt@(Right _) =
         resume (resumeLeft resolver localL) >>= \ leftSt -> react leftSt rightSt
      -- Both children are locally suspended: the resolver decides which of them to resume (or both, via 'step').
      react leftSt@(Left (RightF localL)) rightSt@(Left (RightF localR)) =
         resumeAny resolver
                   (\ left' -> resume left' >>= \ leftSt' -> react leftSt' rightSt)
                   (\ right' -> resume right' >>= react leftSt)
                   step
                   localL
                   localR

-- | Class of functors that can contain another functor, the associated 'Parent'.
class Functor c => ChildFunctor c where
   type Parent c :: * -> *
   -- | Embed a value of the parent functor into the child functor.
   wrap :: Parent c x -> c x

-- | An 'EitherFunctor' treats its left component as the parent; 'wrap' injects it with 'LeftF'.
instance (Functor p, Functor s) => ChildFunctor (EitherFunctor p s) where
   type Parent (EitherFunctor p s) = p
   wrap parentValue = LeftF parentValue

-- | Class of functors that can be lifted. @AncestorFunctor a d@ holds when values of functor @a@ can be embedded into
-- functor @d@. This is the case either because @d@ is @a@ itself, or because @a@ is reached by repeatedly taking the
-- 'Parent' of @d@.
class (Functor a, Functor d) => AncestorFunctor a d where
   -- | Convert the ancestor functor into its descendant. The descendant functor typically contains the ancestor.
   liftFunctor :: a x -> d x

-- Base case: every functor is its own ancestor, and lifting is the identity. This instance overlaps the one below
-- (the module enables OverlappingInstances), being the more specific match when both parameters are the same functor.
instance Functor a => AncestorFunctor a a where
   liftFunctor = id
-- Inductive case: to lift @a@ into a child functor @d@, first lift it into @d@'s parent @d'@, then 'wrap' into @d@.
instance (Functor a, ChildFunctor d, d' ~ Parent d, AncestorFunctor a d') => AncestorFunctor a d where
   -- The annotation (using the instance-head variables @a@ and @d'@, in scope via ScopedTypeVariables) fixes the
   -- intermediate functor of the recursive call to @d'@.
   liftFunctor = wrap . (liftFunctor :: a x -> d' x)

-- | Converts a coroutine into a descendant nested coroutine. Every suspension in the ancestor functor @a@ is
-- re-expressed in the descendant functor @d@ using 'liftFunctor'.
liftOut :: forall m a d x. (Monad m, Functor a, AncestorFunctor a d) => Coroutine a m x -> Coroutine d m x
liftOut = mapSuspension liftFunctor