{- 
    Copyright 2008 Mario Blazevic

    This file is part of the Streaming Component Combinators (SCC) project.

    The SCC project is free software: you can redistribute it and/or modify it under the terms of the GNU General Public
    License as published by the Free Software Foundation, either version 3 of the License, or (at your option) any later
    version.

    SCC is distributed in the hope that it will be useful, but WITHOUT ANY WARRANTY; without even the implied warranty
    of MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE.  See the GNU General Public License for more details.

    You should have received a copy of the GNU General Public License along with SCC.  If not, see
    <http://www.gnu.org/licenses/>.
-}

-- | Module "Control.Concurrent.SCC.Foundation" defines the pipe computations and their basic building blocks.

{-# LANGUAGE ScopedTypeVariables, Rank2Types, PatternGuards, ExistentialQuantification #-}

module Control.Concurrent.SCC.Foundation
   (-- * Types
    Pipe, Source, Sink, Consumer, Producer,
    -- * Flow-control functions
    pipe, pipeD, get, getSuccess, canPut, put,
    liftPipe, runPipes,
    -- * Utility functions
    cond, whenNull, pour, tee, getList, putList, consumeAndSuppress)
where

import Control.Applicative (Applicative(..))
import Control.Exception (assert)
import Control.Monad (ap, liftM, when)
import Data.Maybe (maybe)
import Data.Typeable (Typeable, cast)

import Debug.Trace (trace)

-- | 'Pipe' represents the type of monadic computations that can be split into co-routining computations using function
-- 'pipe'. The /context/ type parameter delimits the scope of the computation.
--
-- A 'Pipe' is run ('proceed') with a context value, the pid it runs under and the current logical clock. It yields a
-- 'PipeState': either 'Done' with the final clock and result, or 'Suspend' with the operations it is waiting on.
--
-- The declaration deliberately carries no @Monad m =>@ datatype context. Such contexts require the deprecated
-- @DatatypeContexts@ extension, and every function that builds or runs a 'Pipe' already demands @Monad m@ itself.
newtype Pipe context m r = Pipe {proceed :: context -> Int -> Integer -> m (PipeState context m r)}

-- | Outcome of running a 'Pipe' one step: suspended on a list of channel operations, or finished with the final clock
-- value and the result.
data PipeState context m r = Suspend context [Suspension context m r]
                           | Done Integer r

-- | A pending channel operation.
data Suspension context m r = Suspension {pid :: Int,                  -- pid of the channel end operated on
                                          clock :: Integer,            -- logical clock when the operation was issued
                                          description :: String,       -- human-readable label, used by 'show'
                                          continuation :: SuspendedContinuation context m r}

-- | What a suspended computation is waiting for, together with the continuation to resume it with.
data SuspendedContinuation context m r = forall x. Typeable x => Get (Maybe x -> Pipe context m r)
                                         -- ^ resumed with 'Nothing' once the source is exhausted
                                       | forall x. Typeable x => Put x (Bool -> Pipe context m r)
                                         -- ^ resumed with whether the value was accepted
                                       | CanPut (Bool -> Pipe context m r)
                                         -- ^ resumed with whether the sink still accepts values

-- | A 'Source' is the read-only end of a 'Pipe' communication channel.
--
-- The 'Int' field is the pid identifying this channel end, and the 'String' field is a description shown in trace and
-- 'Suspension' messages. The type parameter /x/ does not occur in any field; it only records the type of the values
-- flowing through the channel. 'pipeD' stores 'undefined' in the /context/ field, so that field must never be forced.
data Source context x = Source context Int String
-- | A 'Sink' is the write-only end of a 'Pipe' communication channel.
--
-- Its fields have the same meaning as those of 'Source'.
data Sink   context x = Sink   context Int String

-- | A computation that consumes values from a 'Source' is called 'Consumer'.
-- It must work for every context type @c@ (note the rank-2 @forall@); 'pipeD' applies it to a 'Source' it builds itself.
type Consumer m x r = forall c. Source c x -> Pipe c m r
-- | A computation that produces values and puts them into a 'Sink' is called 'Producer'.
-- Like 'Consumer', it must work for every context type @c@; 'pipeD' supplies the 'Sink'.
type Producer m x r = forall c. Sink c x -> Pipe c m r

-- | Function 'liftPipe' embeds a computation of the underlying monad into a 'Pipe'. The embedded computation never
-- suspends; it finishes at the clock value it was started with.
liftPipe :: forall context m r. Monad m => m r -> Pipe context m r
liftPipe mr = Pipe run
   where run _ _ clock = mr >>= return . Done clock

-- | Function 'runPipes' runs the given computation involving pipes and returns the final result.
-- The rank-2 /context/ type parameter ensures that no suspended computation can escape its scope.
--
-- The computation is started with pid 1 and clock 0. If it ends still suspended on operations that no enclosing 'pipe'
-- resolved, 'runPipes' fails with an 'error' that lists those pending operations, rather than with an anonymous
-- pattern-match failure.
runPipes :: forall m r. Monad m => (forall context. Pipe context m r) -> m r
runPipes c = proceed c undefined 1 0 >>= finish
   where finish (Done _ r) = return r
         finish (Suspend _ pending) = error ("runPipes: computation suspended on unresolved operations " ++ show pending)

-- | 'Functor' instance for 'Pipe', derived from its 'Monad' instance. GHC 7.10 and later require 'Functor' and
-- 'Applicative' instances as superclasses of 'Monad'; older compilers accept them as ordinary extra instances.
instance Monad m => Functor (Pipe context m) where
   fmap = liftM

-- | 'Applicative' instance for 'Pipe'. A pure value completes immediately, at the current clock, without suspending.
instance Monad m => Applicative (Pipe context m) where
   pure r = Pipe (\_ _ clock-> return (Done clock r))
   (<*>) = ap

-- | Sequencing of 'Pipe' computations. If the first computation finishes, the continuation runs under the same pid with
-- the clock advanced by one tick. If it suspends, the continuation is attached to every pending suspension, so that it
-- runs once that suspension is resumed.
instance Monad m => Monad (Pipe context m) where
   return = pure
   Pipe p >>= f = Pipe (\context pid clock-> p context pid clock >>= apply f context pid)
      where apply :: forall r1 r2. (r1 -> Pipe context m r2) -> context -> Int -> PipeState context m r1 -> m (PipeState context m r2)
            apply f context pid (Done clock r) = proceed (f r) context pid (succ clock)
            apply f _ _ (Suspend context suspensions) = return $ Suspend context (map suspendApplied suspensions)
               where suspendApplied s@Suspension{description= desc, continuation= k}
                        = s{description= "applied " ++ desc, continuation= bindAfter k}
                     -- Post-compose the bind with whatever continuation the suspension holds.
                     bindAfter (Get cont) = Get ((f =<<) . cont)
                     bindAfter (Put x cont) = Put x ((f =<<) . cont)
                     bindAfter (CanPut cont) = CanPut ((f =<<) . cont)

-- | Renders a suspension as its operation kind in parentheses, then its description, then @ -> @ and its pid.
instance Show (Suspension context m r) where
   show s = tag (continuation s) ++ description s ++ " -> " ++ show (pid s)
      where tag Put{} = "(Put)"
            tag CanPut{} = "(CanPut)"
            tag Get{} = "(Get)"

-- | The 'pipe' function splits the computation into two concurrent parts, /producer/ and /consumer/. The /producer/ is
-- given a 'Sink' to put values into, and the /consumer/ a 'Source' to get those values from. Once producer and consumer
-- both complete, 'pipe' returns their paired results. It is 'pipeD' with an empty description.
pipe :: forall context x m r1 r2. Monad m => Producer m x r1 -> Consumer m x r2 -> Pipe context m (r1, r2)
pipe producer consumer = pipeD "" producer consumer

-- | The 'pipeD' function is the same as 'pipe', with an additional description argument. The description, followed by
-- a colon and the pid of the enclosing computation, labels both channel ends and appears in the trace messages.
pipeD :: forall context x m r1 r2. Monad m => String -> Producer m x r1 -> Consumer m x r2 -> Pipe context m (r1, r2)
pipeD description producer consumer = Pipe start
   where start _ pid clock =
            assert (track (indent pid ++ "pipe " ++ label)) $
            do producerState <- proceed (producer (Sink inner sinkPid label)) inner sinkPid clock
               consumerState <- proceed (consumer (Source inner sourcePid label)) inner sourcePid clock
               reduce inner sinkPid producerState sourcePid consumerState
            where -- The two channel ends get pids 2*pid and 2*pid+1, so nested channels always have larger pids.
                  sinkPid = 2*pid
                  sourcePid = 2*pid+1
                  label = description ++ ':' : show pid
                  -- The context is only passed along, never inspected, so 'undefined' suffices.
                  inner = undefined

-- | Function 'reduce' runs a producer and a consumer created by 'pipeD' against each other. /pid1/ is the pid of the
-- producer's 'Sink' and /pid2/ the pid of the consumer's 'Source'. Only the head of each side's 'Suspend' list is
-- examined when deciding what to do next.
--
-- Operations on this pair's own channel are answered here. Suspensions with a smaller pid are forwarded outward as a
-- 'Suspend' of the combined computation, each wrapped with 'delay' so that reduction continues once it is resumed.
-- Because 'pipeD' gives child channels pids 2*pid and 2*pid+1, such suspensions presumably belong to channels of
-- enclosing pipes. The result is 'Done' with the paired results once both sides have finished.
reduce :: forall c m r1 r2. Monad m => c -> Int -> PipeState c m r1 -> Int -> PipeState c m r2 -> m (PipeState c m (r1, r2))
-- Both sides finished: complete with the later of the two clocks and the paired results.
reduce context pid1 (Done t1 r1) pid2 (Done t2 r2)
   = assert (track (indent pid1 ++ "Done " ++ show pid1 ++ " -> " ++ show pid2)) $
     return (Done (max t1 t2) (r1, r2))
-- Consumer finished, producer suspended. A pending put or canPut on this channel is answered with False. A suspension
-- with a smaller pid is forwarded outward. Any other head suspension is reported with 'error'.
reduce context pid1 (Suspend c1 ps@(Suspension{pid= pid1', clock= t, continuation= pCont} : _)) pid2 consumer@Done{}
   | pid1' == pid1, Put _ cont <- pCont
   = assert (track (indent pid1 ++ "Failed producer put " ++ show ps ++ " from " ++ show pid1)) $
     proceed (cont False) context pid1 t >>= \p'-> reduce context pid1 p' pid2 consumer
   | pid1' == pid1, CanPut cont <- pCont
   = assert (track (indent pid1 ++ "Finish producer " ++ show ps ++ " from " ++ show pid1)) $
     proceed (cont False) context pid1 t >>= \p'-> reduce context pid1 p' pid2 consumer
   | pid1' < pid1 = assert (track (indent pid1 ++ "Suspend producer " ++ show ps ++ " from " ++ show pid1)) $
                    return $ Suspend context $ map (delay (\ps'-> reduce context pid1 ps' pid2 consumer)) ps
   | otherwise = error (show pid1' ++ ">" ++ show pid1 ++ " | producer : " ++ show ps)
-- Producer finished, consumer suspended. A pending get on this channel is answered with Nothing (end of input). A
-- suspension with a smaller pid is forwarded outward. Any other head suspension is reported with 'error'.
reduce context pid1 producer@Done{} pid2 (Suspend c2 cs@(Suspension{pid= pid2', clock= t, continuation= cCont} : _))
   | pid2' == pid2, Get cont <- cCont
   = assert (track (indent pid1 ++ "Finish consumer " ++ show cs ++ " from " ++ show pid2)) $
     proceed (cont Nothing) context pid2 t >>= reduce context pid1 producer pid2
   | pid2' < pid2 = assert (track (indent pid1 ++ "Suspend consumer " ++ show cs ++ " from " ++ show pid2)) $
                    return $ Suspend context $ map (delay (reduce context pid1 producer pid2)) cs
   | otherwise = error (show pid2' ++ ">" ++ show pid2 ++ " | consumer : " ++ show cs)
-- Both suspended, with the consumer's head a get. When both heads are on this channel:
--   * canPut is answered with True and the consumer keeps waiting;
--   * put hands the value over: both sides resume at the later of the two clocks, the producer with True and the
--     consumer with 'cast' of the value.
-- If neither guard holds, Haskell falls through to the next (general) equation.
reduce context pid1 producer@(Suspend _ ps@(Suspension{pid= pid1', clock=t1, continuation= pc} : _))
               pid2 consumer@(Suspend _ cs@(Suspension{pid= pid2', clock=t2, continuation= Get cCont} : _))
   | pid1' == pid1 && pid2' == pid2, CanPut pCont <- pc
   = assert (track (indent pid1 ++ "CanPut Match at " ++ show pid1 ++ "/" ++ show pid2 ++ " : " ++ show ps ++ " -> " ++ show cs)) $
     proceed (pCont True) context pid1 t1 >>= \p'-> reduce context pid1 p' pid2 consumer
   | pid1' == pid1 && pid2' == pid2, Put x pCont <- pc
   = assert (track (indent pid1 ++ "Match at " ++ show pid1 ++ "/" ++ show pid2 ++ " : " ++ show ps ++ " -> " ++ show cs)) $
     let t' = max t1 t2
     in do p' <- assert (track "producer (") $ proceed (pCont True) context pid1 t'
           c' <- assert (track ") consumer (") $ proceed (cCont (cast x)) context pid2 t'
           assert (track ") combined ->") reduce context pid1 p' pid2 c'
-- General case for two suspended sides. Leading suspensions on this pair's own channel (pid1 for the producer, pid2 for
-- the consumer) are skipped, and the remaining suspensions of both sides are forwarded outward, interleaved by 'merge'.
-- The delayed continuations still capture the complete original producer and consumer states.
reduce context pid1 producer@(Suspend c1 ps) pid2 consumer@(Suspend c2 cs) = assert (track (indent pid1 ++ "Suspend producer & consumer, "
                                                                                            ++ show ps ++ " from " ++ show pid1 ++ " & "
                                                                                            ++ show cs ++ " from " ++ show pid2)) $
                                                                             keepSuspending ps cs
           -- Drop a leading producer suspension on this channel.
     where keepSuspending (Suspension{pid=pid1'} : pTail) cs | pid1' == pid1 = keepSuspending pTail cs
           -- Drop a leading consumer suspension on this channel.
           keepSuspending ps (Suspension{pid= pid2'} : cTail) | pid2' == pid2 = keepSuspending ps cTail
           -- Forward what remains; resuming a producer suspension re-reduces against the original consumer state, and
           -- vice versa.
           keepSuspending ps cs = assert (track (indent pid1 ++ "Suspend' producer & consumer, "
                                                 ++ show ps ++ " from " ++ show pid1 ++ " & "
                                                 ++ show cs ++ " from " ++ show pid2)) $
                                  return $ Suspend context $
                                         merge (map (\p-> delay (\p'-> reduce context pid1 p' pid2 consumer) p) ps)
                                               (map (delay (reduce context pid1 producer pid2)) cs)

-- | Function 'merge' interleaves two suspension lists. At each step it takes the head with the larger pid. When both
-- heads share a pid, it takes the left one only if its clock is strictly earlier; otherwise it takes the right one.
merge :: Monad m => [Suspension context m r] -> [Suspension context m r] -> [Suspension context m r]
merge left right =
   case (left, right) of
      ([], _) -> right
      (_, []) -> left
      (x : xs, y : ys)
         | takeLeft x y -> x : merge xs right
         | otherwise -> y : merge left ys
   where takeLeft Suspension{pid= p1, clock= t1} Suspension{pid= p2, clock= t2} = p1 > p2 || (p1 == p2 && t1 < t2)

-- | Function 'delay' extends a suspension so that, once it is resumed, the state the resumed pipe reaches is fed to the
-- monadic function /f/. The description gains a @delayed@ prefix via 'delay''.
delay :: Monad m => (PipeState context m r1 -> m (PipeState context m r2)) -> Suspension context m r1 -> Suspension context m r2
delay f = delay' andThen
   where andThen p = Pipe $ \ctx ownPid now-> proceed p ctx ownPid now >>= f

-- | Function 'delay'' post-composes /f/ onto the continuation of a suspension, whatever its operation kind, and prefixes
-- the description with @delayed @.
delay' :: Monad m => (Pipe context m r1 -> Pipe context m r2) -> Suspension context m r1 -> Suspension context m r2
delay' f s =
   case continuation s of
      Get cont -> relabel (Get (f . cont))
      Put x cont -> relabel (Put x (f . cont))
      CanPut cont -> relabel (CanPut (f . cont))
   where relabel k = s{description= "delayed " ++ description s, continuation= k}

-- | Indentation prefix for trace messages: one space per binary digit of the pid, so nested pipes, which have larger
-- pids, are indented further. Non-positive input yields no indentation. Previously a negative argument recursed
-- forever, because @(-1) `div` 2 == -1@.
indent :: Int -> String
indent n
   | n <= 0 = ""
   | otherwise = ' ' : indent (n `div` 2)

-- | Function 'get' requests the next value from the given 'Source'. The calling 'Pipe' computations suspend all the way
-- up to the 'pipe' invocation that created the source. The result is 'Nothing' once the source is exhausted, i.e. its
-- producer has finished.
get :: forall context context' x m r. (Monad m, Typeable x)
       => Source context' x -> Pipe context m (Maybe x)
get (Source _ pid desc) = assert (track (indent pid ++ "Get from " ++ desc ++ "@" ++ show pid)) (Pipe request)
   where request context _ clock =
            assert (track (indent pid ++ "Get<- " ++ desc ++ "@" ++ show pid ++ ":" ++ show clock)) $
            return (Suspend context [Suspension pid clock (label clock) (Get return)])
         label clock = "get from " ++ desc ++ "@" ++ show pid ++ ":" ++ show clock

-- | Function 'getSuccess' gets the next value from the /source/ argument and, if there is one, passes it to the success
-- continuation. Once the source is exhausted it returns @()@ without calling the continuation.
getSuccess :: forall context context' x m. (Monad m, Typeable x)
              => Source context' x
                 -> (x -> Pipe context m ()) -- ^ Success continuation
                 -> Pipe context m ()
getSuccess source succeed = do next <- get source
                               case next of
                                  Nothing -> return ()
                                  Just x -> succeed x

-- | Function 'put' offers a value to the given sink. The calling 'Pipe' computations suspend up to the 'pipe'
-- invocation that created the sink. The result tells whether the value was accepted: 'False' means the consumer has
-- already finished.
put :: forall context context' x m r. (Monad m, Typeable x) => Sink context' x -> x -> Pipe context m Bool
put (Sink _ pid desc) x = assert (track (indent pid ++ "Put into " ++ desc ++ "@" ++ show pid)) (Pipe offer)
   where offer context _ clock =
            assert (track (indent pid ++ "Put-> " ++ desc ++ "@" ++ show pid ++ ":" ++ show clock)) $
            return (Suspend context [Suspension pid clock (label clock) (Put x return)])
         label clock = "put into " ++ desc ++ "@" ++ show pid ++ ":" ++ show clock

-- | Function 'canPut' asks whether the argument sink still accepts values, i.e. whether a 'put' on it would succeed.
-- Like 'put', it suspends up to the 'pipe' invocation that created the sink.
canPut :: forall context context' x m r. (Monad m, Typeable x) => Sink context' x -> Pipe context m Bool
canPut (Sink _ pid desc) = assert (track (indent pid ++ "CanPut into " ++ desc ++ "@" ++ show pid)) (Pipe probe)
   where probe context _ clock =
            assert (track (indent pid ++ "CanPut-> " ++ desc ++ "@" ++ show pid ++ ":" ++ show clock)) $
            return (Suspend context [Suspension pid clock (label clock) (CanPut return)])
         label clock = "canPut into " ++ desc ++ "@" ++ show pid ++ ":" ++ show clock

-- | 'pour' copies data from the /source/ argument into the /sink/ argument for as long as the sink accepts values and
-- the source has any left. Before each value it checks the sink with 'canPut'.
pour :: forall c c1 c2 x m. (Monad m, Typeable x) => Source c1 x -> Sink c2 x -> Pipe c m ()
pour source sink = transfer
   where transfer = do open <- canPut sink
                       if open then getSuccess source forward else return ()
         forward x = put sink x >> transfer

-- | 'tee' is like 'pour', except that it copies every value from the /source/ argument into both /sink1/ and /sink2/.
-- It stops as soon as either sink stops accepting values or the source is exhausted.
tee :: (Monad m, Typeable x) => Source c1 x -> Sink c2 x -> Sink c3 x -> Pipe c m ()
tee source sink1 sink2 = distribute
   where distribute = do open1 <- canPut sink1
                         open2 <- canPut sink2
                         if open1 && open2 then getSuccess source forward else return ()
         forward x = put sink1 x >> put sink2 x >> distribute

-- | 'putList' puts the list elements into its /sink/ argument, in order, until one is refused. The result is the suffix
-- that was not accepted, starting with the refused element; it is empty if everything was accepted.
putList :: forall x c c1 m. (Monad m, Typeable x) => [x] -> Sink c1 x -> Pipe c m [x]
putList items sink = offer items
   where offer [] = return []
         offer remaining@(x : rest) = put sink x >>= \accepted-> if accepted then offer rest else return remaining

-- | 'getList' reads the /source/ argument until it is exhausted and returns all the values, in the order received.
getList :: forall x c c1 m. (Monad m, Typeable x) => Source c1 x -> Pipe c m [x]
getList source = do next <- get source
                    case next of
                       Nothing -> return []
                       Just x -> do rest <- getList source
                                    return (x : rest)

-- | 'consumeAndSuppress' drains the /source/ argument completely, discarding every value it yields.
consumeAndSuppress :: forall x c c1 m. (Monad m, Typeable x) => Source c1 x -> Pipe c m ()
consumeAndSuppress source = drain
   where drain = getSuccess source (const drain)

-- | A utility function wrapping if-then-else with the condition last, convenient for handling monadic truth values,
-- e.g. @action >>= cond whenTrue whenFalse@.
cond :: a -> a -> Bool -> a
cond whenTrue _ True = whenTrue
cond _ whenFalse False = whenFalse

-- | A utility function for monadic list results where an empty list means success. It runs /action/ only when /list/
-- is empty; otherwise it returns /list/ unchanged.
whenNull :: forall a m. Monad m => m [a] -> [a] -> m [a]
whenNull action [] = action
whenNull _ list = return list

-- | Debug hook used as the condition of every 'assert' in this module. It ignores its message and always returns 'True',
-- so those assertions never fail.
track :: String -> Bool
track = const True