-- | This monad is used to increase the speed of communication between two processes - if there is
-- latency. It works by using the non-deterministic part of the communication protocol to send
-- multiple requests to the output-channel, before processing the responses from the input-channel.
--
-- Considering the example
--
-- @
-- foo = splitRequests [bar, baz]
-- bar = do x <- request (GetSumOf 1 2)
--          liftM Sum request (GetSumOf x 3)
-- baz = liftM Sum request (GetSumOf 4 5)
-- @
--
-- running @foo@ in the @RequestMonad@:
--
-- @
-- runRequestMonad inputHandle outputHandle foo
-- @
--
-- will send both messages @GetSumOf 1 2@, @GetSumOf 4 5@, without having to wait for the repsonse
-- to the first request. The last request @GetSumOf 3 3@ will be send after the response for the
-- first message has arrived.
{-# LANGUAGE ExistentialQuantification #-}
{-# LANGUAGE ScopedTypeVariables #-}
{-# LANGUAGE FlexibleInstances #-}
{-# LANGUAGE OverloadedStrings #-}
{-# LANGUAGE GeneralizedNewtypeDeriving #-}
module Sync.MerkleTree.Util.RequestMonad
    ( RequestMonad
    , request
    , runRequestMonad
    , splitRequests
    ) where

import Control.Applicative(Applicative(..))
import Control.Concurrent(Chan, writeChan, readChan, newChan, forkIO)
import Control.Monad(ap,liftM,unless)
import Control.Monad.IO.Class(MonadIO(..))
import Data.ByteString(ByteString)
import Data.IORef(IORef,newIORef,modifyIORef,readIORef)
import Data.Monoid(Monoid, mempty, mappend)
import System.IO.Streams(InputStream, OutputStream)
import Sync.MerkleTree.Util.GetFromInputStream
import qualified Data.Bytes.Serial as SE
import qualified Data.Bytes.Put as P
import qualified System.IO.Streams as ST

data SplitState f b =
    forall a. (Monoid a) => SplitState [RequestMonadT f a] a (a -> RequestMonadT ByteString b)
data RequestState f b = forall a. (SE.Serial a) => RequestState f (a -> RequestMonadT ByteString b)
data LiftIOState b = forall a. LiftIOState (IO a) (a -> RequestMonadT ByteString b)

newtype RequestMonad b = RequestMonad { unReqMonad :: RequestMonadT ByteString b }
    deriving (Monad, Functor, Applicative, MonadIO)

data RequestMonadT f b
    = Split (SplitState f b)
    | Request (RequestState f b)
    | LiftIO (LiftIOState b)
    | Return b
    | Fail String

instance Functor (RequestMonadT ByteString) where
    fmap = liftM

instance Applicative (RequestMonadT ByteString) where
    pure  = return
    (<*>) = ap

instance Monad (RequestMonadT ByteString) where
    return = Return
    fail = Fail
    (>>=) = bindImpl

instance MonadIO (RequestMonadT ByteString) where
    liftIO x =  LiftIO $ LiftIOState x Return

bindImpl ::
    (RequestMonadT ByteString a)
    -> (a -> RequestMonadT ByteString b)
    -> (RequestMonadT ByteString b)
bindImpl f g =
    case f of
      Split (SplitState xs z cont) -> Split (SplitState xs z (\t -> bindImpl (cont t) g))
      Request (RequestState r cont) -> Request (RequestState r (\t -> bindImpl (cont t) g))
      LiftIO (LiftIOState op cont) -> LiftIO (LiftIOState op (\t -> bindImpl (cont t) g))
      Return x -> g x
      Fail s -> Fail s

request :: (SE.Serial a, SE.Serial b) => a -> RequestMonad b
request x = RequestMonad $ Request $ RequestState (P.runPutS $ SE.serialize x) Return

-- | Combine results in the monad non-deterministically
-- (it is required that the monoid is commutative)
splitRequests :: (Monoid a) => [RequestMonad a] -> RequestMonad a
splitRequests alts = RequestMonad $ Split $ SplitState (map unReqMonad alts) mempty Return

data SendQueue
    = SendQueue
    { sq_chan :: Chan (Maybe ByteString)
    , sq_sendIndex :: IORef Int
    }

queueRequests :: SendQueue -> (RequestMonadT ByteString b) -> IO (RequestMonadT Int b)
queueRequests sq root =
    case root of
      LiftIO (LiftIOState op cont) -> return $ LiftIO (LiftIOState op cont)
      Request (RequestState r c) ->
          do writeChan (sq_chan sq) (Just r)
             modifyIORef (sq_sendIndex sq) (+1)
             i <- readIORef (sq_sendIndex sq)
             return $ Request (RequestState i c)
      Split (SplitState xs z cont) ->
          do xs' <- mapM (queueRequests sq) xs
             return $ Split $ SplitState xs' z cont
      Return x -> return $ Return x
      Fail s -> return $ Fail s

-- | Run the provided request monad using the given communication channels
runRequestMonad ::
    InputStream ByteString
    -> OutputStream ByteString
    -> RequestMonad b
    -> IO b
runRequestMonad is os startMonad =
    do sendChan <- newChan
       recvIdx <- newIORef 0
       sendIdx <- newIORef 0
       _ <- forkIO $ writerThread os sendChan
       let sq = SendQueue { sq_chan = sendChan, sq_sendIndex = sendIdx }
           loop monad =
               do monad' <- receiverThread recvIdx sq is monad
                  case monad' of
                    Return x -> writeChan sendChan Nothing >> return x
                    Fail err -> fail err
                    _ -> loop monad'
       queueRequests sq (unReqMonad startMonad) >>= loop

writerThread :: OutputStream ByteString -> Chan (Maybe ByteString) -> IO ()
writerThread os chan = loop
    where
      loop =
          do mBs <- readChan chan
             ST.write mBs os
             ST.write (Just "") os
             maybe (return ()) (const loop) mBs


receiverThread ::
    IORef Int
    -> SendQueue
    -> InputStream ByteString
    -> RequestMonadT Int b
    -> IO (RequestMonadT Int b)
receiverThread recvIdx sq input root =
    case root of
      LiftIO (LiftIOState op cont) -> op >>= (queueRequests sq . cont)
      Request (RequestState i cont) ->
          do x <- getFromInputStream input
             modifyIORef recvIdx (+1)
             expected <- readIORef recvIdx
             unless (expected == i) $ fail ("Expected " ++ (show i) ++ " but got " ++ show expected)
             queueRequests sq $ cont x
      Split (SplitState xs z cont) -> loop cont z xs []
      Return x -> return $ Return x
      Fail err -> return $ Fail err
    where
      loop cont z [] [] = queueRequests sq $ cont z
      loop cont z [] r = return $ Split $ SplitState (reverse r) z cont
      loop cont z (x:xs) r =
          do x' <- receiverThread recvIdx sq input x
             case x' of
               Return x'' -> loop cont (z `mappend` x'') xs r
               Fail s -> return $ Fail s
               other -> loop cont z xs (other:r)