{- 
    Copyright 2008 Mario Blazevic

    This file is part of the Streaming Component Combinators (SCC) project.

    The SCC project is free software: you can redistribute it and/or modify it under the terms of the GNU General Public
    License as published by the Free Software Foundation, either version 3 of the License, or (at your option) any later
    version.

    SCC is distributed in the hope that it will be useful, but WITHOUT ANY WARRANTY; without even the implied warranty
    of MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE.  See the GNU General Public License for more details.

    You should have received a copy of the GNU General Public License along with SCC.  If not, see
    <http://www.gnu.org/licenses/>.
-}

-- | Module "Control.Concurrent.SCC.Foundation" defines the pipe computations and their basic building blocks.

{-# LANGUAGE ScopedTypeVariables, Rank2Types, PatternGuards, ExistentialQuantification #-}

module Control.Concurrent.SCC.Foundation
   (-- * Classes
    ParallelizableMonad (parallelize),
    -- * Types
    Pipe, Source, Sink,
    -- * Flow-control functions
    pipe, pipeD, pipeP, get, getSuccess, canPut, put,
    liftPipe, runPipes,
    -- * Utility functions
    cond, whenNull, pour, tee, getList, putList, consumeAndSuppress)
where

import Control.Applicative (Applicative (..))
import Control.Concurrent (forkIO)
import Control.Concurrent.MVar (newEmptyMVar, putMVar, takeMVar)
import Control.Exception (SomeException, assert, throwIO, try)
import Control.Monad (liftM, liftM2, when)
import Control.Monad.Identity
import Control.Parallel (par, pseq)
import Data.Maybe (maybe)
import Data.Typeable (Typeable, cast)

import Debug.Trace (trace)

-- | Monads whose computations can be paired up, with instances free to evaluate the two sides in parallel.
class Monad m => ParallelizableMonad m where
   -- | Runs both computations and pairs their results. The default simply sequences them, first then second.
   parallelize :: m a -> m b -> m (a, b)
   parallelize first second = do a <- first
                                 b <- second
                                 return (a, b)

-- | Pure pairing: the first result is sparked with 'par' while the second is evaluated with 'pseq'.
instance ParallelizableMonad Identity where
   parallelize ma mb = Identity (x `par` (y `pseq` (x, y)))
      where x = runIdentity ma
            y = runIdentity mb

-- | Sparks the first computation, evaluates the second, and succeeds only if both are 'Just'.
instance ParallelizableMonad Maybe where
   parallelize ma mb = ma `par` (mb `pseq` both ma mb)
      where both (Just a) (Just b) = Just (a, b)
            both _ _ = Nothing


-- | Runs the first computation on the calling thread and the second on a thread started with 'forkIO', then pairs
-- their results.
--
-- If the second computation throws, the exception is caught in the forked thread, handed back through the 'MVar',
-- and re-thrown in the caller. Previously such an exception left the 'MVar' empty, so the caller blocked and only
-- saw an unrelated "blocked indefinitely" error. An exception from the first computation propagates directly; the
-- forked thread is then left to finish on its own.
instance ParallelizableMonad IO where
   parallelize ma mb = do resultB <- newEmptyMVar
                          _ <- forkIO (tryAny mb >>= putMVar resultB)
                          a <- ma
                          b <- takeMVar resultB >>= either throwIO return
                          return (a, b)
      where tryAny :: IO x -> IO (Either SomeException x)
            tryAny = try

-- | 'Pipe' represents the type of monadic computations that can be split into co-routining computations using function
-- 'pipe'. The /context/ type parameter delimits the scope of the computation.
newtype Pipe context m r = Pipe
   { proceed :: PipeState context -> m (PipeRendezvous context m r) -- ^ run from a state until completion or suspension
   }

-- | State threaded through a 'Pipe' computation.
data PipeState context = PipeState
   { level :: Int     -- ^ level a 'pipe' started from this state gives its channels ('runPipes' starts at 1)
   , clock :: Integer -- ^ logical clock; '>>=' advances it by one each time the first computation completes
   }

-- | Result of running a 'Pipe' computation: either a list of pending suspensions (the head is the one 'reduce'
-- acts upon), or completion with the final clock value and the result.
data PipeRendezvous context m r
   = Suspend [Suspension context m r]
   | Done Integer r

-- | A computation suspended on a channel operation.
data Suspension context m r = Suspension
   { targetLevel  :: Int                               -- ^ level of the 'pipe' that created the addressed channel
   , state        :: PipeState context                 -- ^ state of the computation at the point of suspension
   , description  :: String                            -- ^ human-readable description, shown by the 'Show' instance
   , continuation :: SuspendedContinuation context m r -- ^ how the computation is resumed
   }

-- | The pending channel operation together with its continuation.
data SuspendedContinuation context m r
   = forall x. Typeable x => Get (Maybe x -> Pipe context m r)
     -- ^ waiting on 'get'; resumed with the value, or 'Nothing' once the producer has completed
   | forall x. Typeable x => Put x (Bool -> Pipe context m r)
     -- ^ waiting on 'put' of the given value; resumed with whether it was accepted
   | CanPut (Bool -> Pipe context m r)
     -- ^ waiting on 'canPut'; resumed with whether the sink accepts values

-- | A 'Source' is the read-only end of a 'Pipe' communication channel. It carries the level of the 'pipe' that
-- created it and a description used in diagnostic messages.
data Source ctx a = Source Int String
-- | A 'Sink' is the write-only end of a 'Pipe' communication channel, carrying the same level and description as
-- its matching 'Source'.
data Sink   ctx a = Sink   Int String

-- | A computation that consumes values from a 'Source' is called 'Consumer'.
type Consumer ctx mon a res = Source ctx a -> Pipe ctx mon res
-- | A computation that produces values and puts them into a 'Sink' is called 'Producer'.
type Producer ctx mon a res = Sink ctx a -> Pipe ctx mon res

-- | Function 'liftPipe' lifts a value of the underlying monad type into a 'Pipe' computation. The lifted computation
-- never suspends: it completes with the clock value it was started with.
liftPipe :: forall context m r. Monad m => m r -> Pipe context m r
liftPipe action = Pipe run
   where run st = action >>= \result-> return (Done (clock st) result)

-- | Function 'runPipes' runs the given computation involving pipes and returns the final result.
-- The /context/ argument ensures that no suspended computation can escape its scope.
--
-- The computation starts at pipe level 1 with the clock at 0. Suspensions are expected to be resolved by the 'pipe'
-- that created the addressed channel, so the top level should only ever see a completed computation. A leftover
-- suspension is reported as an explicit internal error rather than an anonymous pattern-match failure.
runPipes :: forall m r. Monad m => (forall context. Pipe context m r) -> m r
runPipes c = proceed c (PipeState 1 0) >>= \s-> case s of
                                                   Done _ r -> return r
                                                   Suspend ss -> error ("runPipes: unresolved suspension at top level: "
                                                                        ++ show ss)

-- | 'Functor' instance, required as a superclass of 'Applicative' and 'Monad' since GHC 7.10. 'fmap' is defined
-- through the 'Monad' instance, so it sequences (and advances the clock) exactly like 'liftM'.
instance Monad m => Functor (Pipe context m) where
   fmap = liftM

-- | 'Applicative' instance, required as a superclass of 'Monad' since GHC 7.10. 'pure' completes immediately at the
-- current clock, and '<*>' runs the function computation before the argument computation.
instance Monad m => Applicative (Pipe context m) where
   pure r = Pipe (\state-> return (Done (clock state) r))
   pf <*> px = pf >>= \f-> liftM f px

-- | Sequencing of 'Pipe' computations. When the first computation completes at clock /t/, the continuation runs with
-- the clock at /t+1/. When it suspends instead, every pending suspension has its continuation extended with the rest
-- of the computation.
instance Monad m => Monad (Pipe context m) where
   return = pure
   Pipe p >>= f = Pipe (\state-> p state >>= apply f state)
      where apply :: forall r1 r2. (r1 -> Pipe context m r2) -> PipeState context -> PipeRendezvous context m r1
                  -> m (PipeRendezvous context m r2)
            apply f state (Done t r) = proceed (f r) state{clock= succ t}
            apply f state (Suspend suspensions) = return $ Suspend (map suspendApplied suspensions)
               where suspendApplied s = postApply (>>= f) s{description= "applied " ++ description s}

-- | Transforms the continuation of a suspension with the given function, whichever channel operation it waits on.
-- All other fields of the suspension are kept.
postApply :: (Pipe context m r1 -> Pipe context m r2) -> Suspension context m r1 -> Suspension context m r2
postApply f s = s{continuation= wrap (continuation s)}
   where wrap (Get k) = Get (f . k)
         wrap (Put x k) = Put x (f . k)
         wrap (CanPut k) = CanPut (f . k)

-- | Pairs two 'Pipe' computations. Both are started from the same state, paired with the underlying monad's
-- 'parallelize', and their rendezvous results are then combined by @combine@.
instance ParallelizableMonad m => ParallelizableMonad (Pipe context m) where
   parallelize p1 p2 = Pipe (\state-> liftM combine $ parallelize (proceed p1 state) (proceed p2 state))
            -- combine: both done -> pair the results at the later of the two clocks. Only one side suspended -> keep
            -- its suspensions, raise their clocks to at least the finished side's clock, and pair the eventual
            -- result with the finished side's value. Both suspended -> merge both suspension lists; resuming either
            -- side re-pairs it (via 'parallelize') with the other side's unchanged rendezvous.
      where combine :: forall r1 r2. (PipeRendezvous context m r1, PipeRendezvous context m r2) -> PipeRendezvous context m (r1, r2)
            combine (Done c1 r1, Done c2 r2) = Done (max c1 c2) (r1, r2)
            combine (Suspend s1, Done c2 r2) = Suspend (map (adjustSuspension c2 (liftM $ flip (,) r2)) s1)
            combine (Done c1 r1, Suspend s2) = Suspend (map (adjustSuspension c1 (liftM $ (,) r1)) s2)
            combine (r1@(Suspend s1), r2@(Suspend s2)) = Suspend (merge (map (postApply (flip parallelize (rewrap r2))) s1)
                                                                        (map (postApply (parallelize (rewrap r1))) s2))
            -- Turns an already computed rendezvous back into a 'Pipe' that ignores the state it is given.
            rewrap :: PipeRendezvous context m r -> Pipe context m r
            rewrap r = Pipe $ const $ return $ r
            -- Raises the suspension's clock to at least c and post-applies f to its continuation.
            adjustSuspension :: Integer -> (Pipe context m r1 -> Pipe context m r2)
                             -> Suspension context m r1 -> Suspension context m r2
            adjustSuspension c f s = postApply f s{state= (state s) {clock= clock (state s) `max` c}}

-- | Renders a suspension for diagnostics: a tag naming the pending operation, the description, and the target level.
instance Show (Suspension context m r) where
   show s = operation (continuation s) ++ description s ++ " -> " ++ show (targetLevel s)
      where operation Put{} = "(Put)"
            operation CanPut{} = "(CanPut)"
            operation Get{} = "(Get)"

-- | The 'pipe' function splits the computation into two concurrent parts, /producer/ and /consumer/. The /producer/ is
-- given a 'Sink' to put values into, and /consumer/ a 'Source' to get those values from.  Once producer and consumer
-- both complete, 'pipe' returns their paired results. The two parts are sequenced in the underlying monad; see
-- 'pipeP' for a parallel variant.
pipe :: forall context x m r1 r2. Monad m => Producer context m x r1 -> Consumer context m x r2 -> Pipe context m (r1, r2)
pipe producer consumer = pipeD "" producer consumer

-- | The 'pipeD' function is the same as 'pipe', with an additional description argument that labels the created
-- channels in diagnostic messages.
pipeD :: forall c x m r1 r2. Monad m => String -> Producer c m x r1 -> Consumer c m x r2 -> Pipe c m (r1, r2)
pipeD label = pipePrim label (liftM2 (,))

-- | The 'pipeP' function is equivalent to 'pipe', except the /producer/ and /consumer/ are run in parallel if resources
-- allow. Pairing is delegated to the underlying monad's 'parallelize'.
pipeP :: forall c x m r1 r2. ParallelizableMonad m => Producer c m x r1 -> Consumer c m x r2 -> Pipe c m (r1, r2)
pipeP = pipePrim "" parallelize

-- | The 'pipePrim' function is the actual worker function of the 'pipe' family.
--
-- The producer receives a 'Sink' and the consumer a 'Source', both tagged with the current level and the description
-- suffixed by @:level@. Both sides start one level deeper with the current clock. They are run and paired with
-- /pairMonads/, and 'reduce' then resolves their rendezvous.
pipePrim :: forall c m x r1 r2. Monad m =>
            String -> (forall a b. m a -> m b -> m (a, b)) -> Producer c m x r1 -> Consumer c m x r2 -> Pipe c m (r1, r2)
pipePrim label pairMonads producer consumer = Pipe start
   where start (PipeState lvl t) =
            let label' = label ++ ':' : show lvl
                inner = PipeState (succ lvl) t
            in assert (track (indent lvl ++ "pipe " ++ label')) $
               pairMonads (proceed (producer (Sink lvl label')) inner) (proceed (consumer (Source lvl label')) inner)
                  >>= \(ps, cs)-> reduce pairMonads lvl ps cs

-- | Function 'reduce' resolves the rendezvous between the producer and the consumer of the 'pipe' at the given
-- /level/. It is given the current rendezvous of each side and does one of three things: completes the pipe,
-- resumes one or both sides, or propagates their pending suspensions outward. Propagated suspensions are wrapped
-- so that 'reduce' takes over again when they are resumed.
--
-- A suspension may target this level or an outer (lower) one; a higher target level is an internal error.
reduce :: forall c m r1 r2. Monad m =>
          (m (PipeRendezvous c m r1) -> m (PipeRendezvous c m r2) -> m (PipeRendezvous c m r1, PipeRendezvous c m r2))
             -> Int -> PipeRendezvous c m r1 -> PipeRendezvous c m r2 -> m (PipeRendezvous c m (r1, r2))
-- Both sides completed: the pipe completes with the paired results at the later of the two clocks.
reduce pairMonads level (Done t1 r1) (Done t2 r2)
   = assert (track (indent level ++ "Done " ++ show level ++ " -> " ++ show level)) $
     return (Done (max t1 t2) (r1, r2))
-- Consumer completed, producer suspended: a 'put' into this pipe fails and a 'canPut' is answered 'False'.
-- A suspension aimed at an outer pipe is propagated.
reduce pairMonads level (Suspend ps@(Suspension{targetLevel= l1, state= s1, continuation= pCont} : _)) consumer@Done{}
   | l1 == level, Put _ cont <- pCont
   = assert (track (indent level ++ "Failed producer put " ++ show ps ++ " from " ++ show level)) $
     proceed (cont False) s1 >>= \p'-> reduce pairMonads level p' consumer
   | l1 == level, CanPut cont <- pCont
   = assert (track (indent level ++ "Finish producer " ++ show ps ++ " from " ++ show level)) $
     proceed (cont False) s1 >>= \p'-> reduce pairMonads level p' consumer
   | l1 < level = assert (track (indent level ++ "Suspend producer " ++ show ps ++ " from " ++ show level)) $
                  return $ Suspend $ map (delay (\ps'-> reduce pairMonads level ps' consumer)) ps
   | otherwise = error (show l1 ++ ">" ++ show level ++ " | producer : " ++ show ps)
-- Producer completed, consumer suspended: a 'get' from this pipe receives 'Nothing'. A suspension aimed at an outer
-- pipe is propagated.
reduce pairMonads level producer@Done{} (Suspend cs@(Suspension{targetLevel= l2, state= s2, continuation= cCont} : _))
   | l2 == level, Get cont <- cCont
   = assert (track (indent level ++ "Finish consumer " ++ show cs ++ " from " ++ show level)) $
     proceed (cont Nothing) s2 >>= reduce pairMonads level producer
   | l2 < level
   = assert (track (indent level ++ "Suspend consumer " ++ show cs ++ " from " ++ show level)) $
     return $ Suspend $ map (delay (reduce pairMonads level producer)) cs
   | otherwise = error (show l2 ++ ">" ++ show level ++ " | consumer : " ++ show cs)
-- Both sides waiting on this pipe: a 'canPut' is answered 'True'. A 'put' hands its value to the consumer's 'get',
-- and both sides are resumed together with /pairMonads/, each with its clock advanced to the later of the two.
-- Both guards require BOTH suspensions to target this level. Otherwise a consumer blocked on a 'get' from an outer
-- source would wrongly receive the producer's value (or a spurious 'Nothing' when the types differ). Such cases
-- fall through to the final equation and are propagated.
reduce pairMonads level producer@(Suspend ps@(Suspension{targetLevel= l1, state= s1, continuation= pc} : _))
                        consumer@(Suspend cs@(Suspension{targetLevel= l2, state= s2, continuation= Get cCont} : _))
   | l1 == level && l2 == level, CanPut pCont <- pc
   = assert (track (indent level ++ "CanPut Match at " ++ show level ++ " : " ++ show ps ++ " -> " ++ show cs)) $
     proceed (pCont True) s1 >>= \p'-> reduce pairMonads level p' consumer
   | l1 == level && l2 == level, Put x pCont <- pc
   = assert (track (indent level ++ "Match at " ++ show level ++ " : " ++ show ps ++ " -> " ++ show cs)) $
     do (p', c') <- pairMonads (assert (track "producer (") $ proceed (pCont True) (synchronizeState s1 s2))
                               (assert (track ") consumer (") $ proceed (cCont (cast x)) (synchronizeState s2 s1))
        assert (track ") combined ->") reduce pairMonads level p' c'
-- Anything else: suspensions aimed at this pipe cannot progress until the other side moves, so they are dropped from
-- the head of each list. The remaining suspensions are propagated outward, each wrapped to re-run 'reduce' against
-- the other side's current rendezvous once it is resumed.
reduce pairMonads level producer@(Suspend ps) consumer@(Suspend cs) = assert (track (indent level ++ "Suspend producer & consumer, "
                                                                                     ++ show ps ++ " from " ++ show level ++ " & "
                                                                                     ++ show cs ++ " from " ++ show level)) $
                                                                                        keepSuspending ps cs
     where keepSuspending (Suspension{targetLevel=level'} : pTail) cs | level' == level = keepSuspending pTail cs
           keepSuspending ps (Suspension{targetLevel= level'} : cTail) | level' == level = keepSuspending ps cTail
           keepSuspending ps cs = assert (track (indent level ++ "Suspend' producer & consumer, "
                                                 ++ show ps ++ " from " ++ show level ++ " & "
                                                 ++ show cs ++ " from " ++ show level)) $
                                  return $ Suspend $
                                         merge (map (\p-> delay (\p'-> reduce pairMonads level p' consumer) p) ps)
                                               (map (delay (reduce pairMonads level producer)) cs)

-- | Merges two suspension lists. Suspensions with a higher 'targetLevel' come first; equal levels are ordered by the
-- earlier clock, and on equal clocks the head of the second list goes first.
merge :: [Suspension context m r] -> [Suspension context m r] -> [Suspension context m r]
merge left right = case (left, right) of
   ([], _) -> right
   (_, []) -> left
   (a@Suspension{targetLevel= la, state= PipeState _ ta} : rest1, b@Suspension{targetLevel= lb, state= PipeState _ tb} : rest2)
      | la > lb || (la == lb && ta < tb) -> a : merge rest1 right
      | otherwise -> b : merge left rest2

-- | Extends a suspension so that, once it is resumed, the next rendezvous of the resumed computation is handed to
-- the given function. The description is prefixed with @"delayed "@.
delay :: Monad m =>
         (PipeRendezvous context m r1 -> m (PipeRendezvous context m r2)) -> Suspension context m r1 -> Suspension context m r2
delay after = delay' (\(Pipe run)-> Pipe (\st-> run st >>= after))

-- | Like 'postApply', transforms the suspension's continuation with the given function. It also prefixes the
-- description with @"delayed "@.
delay' :: (Pipe context m r1 -> Pipe context m r2) -> Suspension context m r1 -> Suspension context m r2
delay' f s = case continuation s of
                Get cont -> relabel (Get (f . cont))
                Put x cont -> relabel (Put x (f . cont))
                CanPut cont -> relabel (CanPut (f . cont))
   where relabel c = s{description= "delayed " ++ description s, continuation= c}

-- | Keeps the level of the first state and advances its clock to the later of the two clocks.
synchronizeState :: PipeState context -> PipeState context -> PipeState context
synchronizeState mine other = case (mine, other) of
   (PipeState own t1, PipeState _ t2) -> PipeState own (max t1 t2)

-- | Indentation prefix for diagnostic messages: one space per halving step of the argument, i.e.
-- @floor (logBase 2 n) + 1@ spaces for positive /n/. Non-positive arguments yield no indentation. The original
-- version only stopped at exactly 0, so a negative argument produced an endless string.
indent :: Int -> String
indent n
   | n <= 0 = ""
   | otherwise = ' ' : indent (n `div` 2)

-- | Function 'get' tries to get a value from the given 'Source' argument. The intervening 'Pipe' computations suspend
-- all the way to the 'pipe' function invocation that created the source. The result of 'get' is 'Nothing' iff the
-- argument source is exhausted, i.e. its producer has completed.
get :: forall context x m r. (Monad m, Typeable x) => Source context x -> Pipe context m (Maybe x)
get (Source pid desc) = assert (track (indent pid ++ "Get from " ++ desc ++ "@" ++ show pid)) $ Pipe awaitValue
   where awaitValue st@(PipeState _ t) =
            assert (track (indent pid ++ "Get<- " ++ desc ++ "@" ++ show pid ++ ":" ++ show t)) $
            return (Suspend [Suspension pid st ("get from " ++ desc ++ "@" ++ show pid ++ ":" ++ show t) (Get return)])

-- | Function 'getSuccess' gets a value from the source and passes it to the success continuation. If the source is
-- exhausted, it completes without doing anything.
getSuccess :: forall context x m. (Monad m, Typeable x)
              => Source context x
                 -> (x -> Pipe context m ()) -- ^ Success continuation
                 -> Pipe context m ()
getSuccess source succeed = get source >>= \next-> case next of
                                                      Nothing -> return ()
                                                      Just x -> succeed x

-- | Function 'put' tries to put a value into the given sink. The intervening 'Pipe' computations suspend up to the
-- 'pipe' invocation that has created the argument sink. The result of 'put' indicates whether the operation succeeded.
put :: forall context x m r. (Monad m, Typeable x) => Sink context x -> x -> Pipe context m Bool
put (Sink pid desc) x = assert (track (indent pid ++ "Put into " ++ desc ++ "@" ++ show pid)) $ Pipe offer
   where offer st@(PipeState _ t) =
            assert (track (indent pid ++ "Put-> " ++ desc ++ "@" ++ show pid ++ ":" ++ show t)) $
            return (Suspend [Suspension pid st ("put into " ++ desc ++ "@" ++ show pid ++ ":" ++ show t) (Put x return)])

-- | Function 'canPut' checks if the argument sink accepts values, i.e., whether a 'put' operation would succeed on the
-- sink. Like 'put', it suspends up to the 'pipe' invocation that created the sink.
canPut :: forall context x m r. (Monad m, Typeable x) => Sink context x -> Pipe context m Bool
canPut (Sink pid desc) = assert (track (indent pid ++ "CanPut into " ++ desc ++ "@" ++ show pid)) $ Pipe ask
   where ask st@(PipeState _ t) =
            assert (track (indent pid ++ "CanPut-> " ++ desc ++ "@" ++ show pid ++ ":" ++ show t)) $
            return (Suspend [Suspension pid st ("canPut into " ++ desc ++ "@" ++ show pid ++ ":" ++ show t) (CanPut return)])

-- | 'pour' copies all data from the /source/ argument into the /sink/ argument, as long as there is anything to copy
-- and the sink accepts it. Before each transfer it checks 'canPut' on the sink.
pour :: forall c x m. (Monad m, Typeable x) => Source c x -> Sink c x -> Pipe c m ()
pour source sink = loop
   where loop = do open <- canPut sink
                   when open (getSuccess source (\x-> put sink x >> loop))

-- | 'tee' is similar to 'pour' except it distributes every input value from the /source/ argument into both /sink1/
-- and /sink2/. Once either sink stops accepting values, the remaining contents of the source are returned. The
-- result is @[]@ if the source is exhausted first.
tee :: (Monad m, Typeable x) => Source c x -> Sink c x -> Sink c x -> Pipe c m [x]
tee source sink1 sink2 = distribute
   where distribute = canPut sink1 >>= \open1->
                      canPut sink2 >>= \open2->
                      if open1 && open2 then get source >>= maybe (return []) forward
                                        else getList source
         forward x = put sink1 x >> put sink2 x >> distribute

-- | 'putList' puts entire list into its /sink/ argument, as long as the sink accepts it. The remainder that wasn't
-- accepted by the sink is the result value, starting with the first rejected element.
putList :: forall x c m. (Monad m, Typeable x) => [x] -> Sink c x -> Pipe c m [x]
putList [] _ = return []
putList items@(x : rest) sink = put sink x >>= \accepted-> if accepted then putList rest sink else return items

-- | 'getList' returns the list of all values generated by the source, reading until the source is exhausted.
getList :: forall x c m. (Monad m, Typeable x) => Source c x -> Pipe c m [x]
getList source = get source >>= collect
   where collect Nothing = return []
         collect (Just x) = liftM (x :) (getList source)

-- | 'consumeAndSuppress' consumes the entire source ignoring the values it generates.
consumeAndSuppress :: forall x c m. (Monad m, Typeable x) => Source c x -> Pipe c m ()
consumeAndSuppress source = drain
   where drain = getSuccess source (const drain)

-- | A utility function wrapping if-then-else, useful for handling monadic truth values: @cond x y test@ is /x/ when
-- /test/ holds and /y/ otherwise. The unused branch is not evaluated.
cond :: a -> a -> Bool -> a
cond onTrue onFalse test
   | test = onTrue
   | otherwise = onFalse

-- | A utility function, useful for handling monadic list values where empty list means success: runs /action/ when
-- /list/ is empty, and otherwise returns /list/ unchanged.
whenNull :: forall a m. Monad m => m [a] -> [a] -> m [a]
whenNull action list
   | null list = action
   | otherwise = return list

-- | Diagnostic hook used inside 'assert' throughout this module. It ignores its message (without evaluating it) and
-- always returns 'True', so the assertions never fire. NOTE(review): 'Debug.Trace.trace' is imported by this module,
-- presumably so this can be switched to @trace message True@ while debugging — confirm.
track :: String -> Bool
track _ = True