{-# LANGUAGE ScopedTypeVariables #-}
{-# LANGUAGE DataKinds #-}
{-# LANGUAGE RankNTypes #-}
{-# LANGUAGE LambdaCase #-}
{-# LANGUAGE OverloadedStrings #-}
{-# LANGUAGE TypeFamilies #-}
{-# LANGUAGE CPP #-}
module Network.GRPC.Server.Handlers.Trans where

import           Control.Concurrent.Async (concurrently)
import           Control.Monad (void, (>=>))
import           Control.Monad.IO.Class
import           Data.Binary.Get (pushChunk, Decoder(..))
import qualified Data.ByteString.Char8 as ByteString
import           Data.ByteString.Char8 (ByteString)
import           Data.ByteString.Lazy (toStrict)
import           Network.GRPC.HTTP2.Encoding
import           Network.GRPC.HTTP2.Types (path, GRPCStatus(..), GRPCStatusCode(..))
#if MIN_VERSION_wai(3,2,2)
import           Network.Wai (Request, getRequestBodyChunk, strictRequestBody)
#else
import           Network.Wai (Request, requestBody, strictRequestBody)
#endif

#if MIN_VERSION_base(4,11,0)
#else
import Data.Monoid ((<>))
#endif

import Network.GRPC.Server.Wai (WaiHandler, ServiceHandler(..), closeEarly)

#if !MIN_VERSION_wai(3,2,2)
-- | Compatibility shim for wai older than 3.2.2, where only 'requestBody' is
-- imported (see the CPP-guarded import above). This module uses it to fetch the
-- next chunk of the request body; an empty result is treated as the end of the
-- body by 'handleRequestChunksLoop'.
getRequestBodyChunk :: Request -> IO ByteString
getRequestBodyChunk = requestBody
#endif

-- | Handler for a unary RPC: given the incoming WAI 'Request' and the single
-- decoded input message, produce the single output message in monad @m@.
type UnaryHandler m i o =
    Request -> i -> m o

-- | Handler type for server-streaming RPCs.
--
-- An implementation is expected to:
--
-- * read the single input message, and
-- * return an initial state together with a 'ServerStream', whose
--   state-passing action the server calls repeatedly to fetch the next output
--   message for the client (returning 'Nothing' stops the stream).
type ServerStreamHandler m i o a =
    Request -> i -> m (a, ServerStream m o a)

-- | State-passing producer of server-streamed output messages.
newtype ServerStream m o a = ServerStream
    { serverStreamNext :: a -> m (Maybe (a, o))
      -- ^ Given the current state, return the next state and the message to
      -- send, or 'Nothing' to stop sending.
    }

-- | Handler type for client-streaming RPCs.
--
-- An implementation is expected to acknowledge a new client stream by
-- returning an initial state together with a 'ClientStream' holding two
-- functions:
--
-- * a state-passing handler called for each message the client sends, and
-- * a finalizer called with the final state once the client ends its stream,
--   which produces the single reply.
type ClientStreamHandler m i o a =
    Request -> m (a, ClientStream m i o a)

-- | Callbacks driving a client-streaming RPC.
data ClientStream m i o a = ClientStream
    { clientStreamHandler   :: a -> i -> m a
      -- ^ Fold one incoming message into the state.
    , clientStreamFinalizer :: a -> m o
      -- ^ Produce the reply once the client has finished sending.
    }

-- | Handler type for bidirectional-streaming RPCs.
--
-- An implementation is expected to acknowledge a new bidirectional stream by
-- returning an initial state together with a 'BiDiStream': a state-passing
-- function that returns one 'BiDiStep' at a time. A step may be to:
--
-- * stop immediately ('Abort');
-- * wait for input ('WaitInput'), handling it with a callback, or with a
--   finalizer if the client closes its side of the stream; both may change
--   the state;
-- * write an output message and continue with a new state ('WriteOutput').
--
-- There is no way to stop locally (that would mean sending HTTP2 trailers) and
-- keep receiving messages from the client.
type BiDiStreamHandler m i o a =
    Request -> m (a, BiDiStream m i o a)

-- | A single step of a bidirectional stream.
data BiDiStep m i o a
    = Abort
      -- ^ Stop the stream.
    | WaitInput !(a -> i -> m a) !(a -> m a)
      -- ^ Wait for one client message: the first function handles a message,
      -- the second runs when the client closes its side of the stream.
    | WriteOutput !a o
      -- ^ Send a message to the client and continue with the given state.

-- | State-passing driver of a bidirectional stream.
newtype BiDiStream m i o a = BiDiStream
    { bidirNextStep :: a -> m (BiDiStep m i o a)
    }

-- | Build a 'ServiceHandler' for a unary RPC, keyed by the RPC's 'path'.
--
-- The first argument runs the handler's monad @m@ in 'IO'.
unary
  :: (MonadIO m, GRPCInput r i, GRPCOutput r o)
  => (forall x. m x -> IO x)
  -> r
  -> UnaryHandler m i o
  -> ServiceHandler
unary f rpc handler = ServiceHandler rpcPath waiHandler
  where
    rpcPath    = path rpc
    waiHandler = handleUnary f rpc handler

-- | Build a 'ServiceHandler' for a server-streaming RPC, keyed by the RPC's
-- 'path'.
--
-- The first argument runs the handler's monad @m@ in 'IO'.
serverStream
  :: (MonadIO m, GRPCInput r i, GRPCOutput r o)
  => (forall x. m x -> IO x)
  -> r
  -> ServerStreamHandler m i o a
  -> ServiceHandler
serverStream f rpc handler = ServiceHandler rpcPath waiHandler
  where
    rpcPath    = path rpc
    waiHandler = handleServerStream f rpc handler

-- | Build a 'ServiceHandler' for a client-streaming RPC, keyed by the RPC's
-- 'path'.
--
-- The first argument runs the handler's monad @m@ in 'IO'.
clientStream
  :: (MonadIO m, GRPCInput r i, GRPCOutput r o)
  => (forall x. m x -> IO x)
  -> r
  -> ClientStreamHandler m i o a
  -> ServiceHandler
clientStream f rpc handler = ServiceHandler rpcPath waiHandler
  where
    rpcPath    = path rpc
    waiHandler = handleClientStream f rpc handler

-- | Build a 'ServiceHandler' for a bidirectional-streaming RPC, keyed by the
-- RPC's 'path'.
--
-- The first argument runs the handler's monad @m@ in 'IO'.
bidiStream
  :: (MonadIO m, GRPCInput r i, GRPCOutput r o)
  => (forall x. m x -> IO x)
  -> r
  -> BiDiStreamHandler m i o a
  -> ServiceHandler
bidiStream f rpc handler = ServiceHandler rpcPath waiHandler
  where
    rpcPath    = path rpc
    waiHandler = handleBiDiStream f rpc handler

-- | Construct a handler for a general-streaming RPC, in which messages from
-- the client and messages to the client are processed concurrently (see
-- 'handleGeneralStream'). The handler is keyed by the RPC's 'path'.
--
-- The first argument runs the handler's monad @m@ in 'IO'.
generalStream
  :: (MonadIO m, GRPCInput r i, GRPCOutput r o)
  => (forall x. m x -> IO x)
  -> r
  -> GeneralStreamHandler m i o a b
  -> ServiceHandler
generalStream f rpc handler =
    ServiceHandler (path rpc) (handleGeneralStream f rpc handler)

-- | Handle unary RPCs.
--
-- The whole request body is read at once and exactly one input message is
-- decoded from it. Trailing bytes are rejected by 'errorOnLeftOver', and a
-- body that ends before a full message is decoded is rejected via
-- 'closeEarly' with @INVALID_ARGUMENT@. The handler's result is encoded,
-- written and flushed.
handleUnary
  :: (MonadIO m, GRPCInput r i, GRPCOutput r o)
  => (forall x. m x -> IO x)
  -> r
  -> UnaryHandler m i o
  -> WaiHandler
handleUnary f rpc handler decoding encoding req write flush =
    f (handleRequestChunksLoop decoder onMessage onEof readBody)
  where
    decoder   = decodeInput rpc (_getDecodingCompression decoding)
    readBody  = fmap toStrict (strictRequestBody req)
    onEof     = closeEarly (GRPCStatus INVALID_ARGUMENT "early end of request body")
    onMessage = errorOnLeftOver (handler req >=> liftIO . respond)
    respond msg = do
        write (encodeOutput rpc (_getEncodingCompression encoding) msg)
        flush

-- | Handle server-streaming RPCs.
--
-- The whole request body is read at once and exactly one input message is
-- decoded from it. Trailing bytes are rejected by 'errorOnLeftOver', and a
-- truncated body is rejected via 'closeEarly' with @INVALID_ARGUMENT@. The
-- messages produced by the handler's 'ServerStream' are then encoded,
-- written and flushed one by one until it returns 'Nothing'.
handleServerStream
  :: (MonadIO m, GRPCInput r i, GRPCOutput r o)
  => (forall x. m x -> IO x)
  -> r
  -> ServerStreamHandler m i o a
  -> WaiHandler
handleServerStream f rpc handler decoding encoding req write flush =
    f (handleRequestChunksLoop decoder onMessage onEof readBody)
  where
    decoder   = decodeInput rpc (_getDecodingCompression decoding)
    readBody  = fmap toStrict (strictRequestBody req)
    onEof     = closeEarly (GRPCStatus INVALID_ARGUMENT "early end of request body")
    onMessage = errorOnLeftOver (handler req >=> streamOut)

    -- Pull messages from the handler's stream until it signals the end.
    streamOut (initial, outStream) = pump initial
      where
        pump st = do
            next <- serverStreamNext outStream st
            case next of
                Nothing         -> pure ()
                Just (st', msg) -> do
                    liftIO (write (encodeOutput rpc (_getEncodingCompression encoding) msg))
                    liftIO flush
                    pump st'

-- | Handle client-streaming RPCs.
--
-- Request-body chunks are decoded incrementally, and each decoded message is
-- folded into the handler's state. Bytes left over after one message are fed
-- to a fresh decoder for the next. When the body ends, the finalizer's result
-- is encoded, written and flushed as the single reply.
handleClientStream
  :: forall m r i o a.
     (MonadIO m, GRPCInput r i, GRPCOutput r o)
  => (forall x. m x -> IO x)
  -> r
  -> ClientStreamHandler m i o a
  -> WaiHandler
handleClientStream f rpc handler0 decoding encoding req write flush =
    f $ do
        (initial, cStream) <- handler0 req
        consume cStream initial freshDecoder
  where
    -- Decoder for a single message that has not been fed any bytes yet.
    freshDecoder :: Decoder (Either String i)
    freshDecoder = decodeInput rpc (_getDecodingCompression decoding)

    -- Decode messages one at a time, threading the handler state through.
    consume :: ClientStream m i o a -> a -> Decoder (Either String i) -> m ()
    consume cStream st decoder =
        handleRequestChunksLoop decoder onMessage onEof (getRequestBodyChunk req)
      where
        onMessage leftover msg = do
            st' <- clientStreamHandler cStream st msg
            -- Bytes beyond this message belong to the next one.
            consume cStream st' (pushChunk freshDecoder leftover)
        onEof = clientStreamFinalizer cStream st >>= respond

    respond :: o -> m ()
    respond msg = do
        liftIO $ write (encodeOutput rpc (_getEncodingCompression encoding) msg)
        liftIO flush

-- | Handle bidirectional-streaming RPCs.
--
-- Drives the handler's state machine. Each 'bidirNextStep' result does one of
-- three things:
--
-- * aborts;
-- * waits for at most one client message (or runs the finalizer at the end
--   of the body);
-- * writes one reply.
--
-- Request-body bytes that have been read but not yet consumed by the decoder
-- are carried from step to step. Messages that arrive in the same chunk are
-- therefore not lost when the handler writes output in between.
handleBiDiStream
  :: forall m r i o a.
     (MonadIO m, GRPCInput r i, GRPCOutput r o)
  => (forall x. m x -> IO x)
  -> r
  -> BiDiStreamHandler m i o a
  -> WaiHandler
handleBiDiStream f rpc handler0 decoding encoding req write flush =
    f $ handler0 req >>= go ""
  where
    nextChunk = getRequestBodyChunk req
    reply msg = write (encodeOutput rpc (_getEncodingCompression encoding) msg) >> flush
    -- 'chunk' holds request-body bytes already read but not yet decoded.
    go :: ByteString -> (a, BiDiStream m i o a) -> m ()
    go chunk (v0, bStream) = do
        let cont dat v1 = go dat (v1, bStream)
        step <- bidirNextStep bStream v0
        case step of
            WaitInput handleMsg handleEof ->
                handleRequestChunksLoop (flip pushChunk chunk
                                         $ decodeInput rpc
                                         $ _getDecodingCompression decoding)
                                        -- carry the decoder's leftover bytes to the next step
                                        (\dat msg -> handleMsg v0 msg >>= cont dat)
                                        -- the body has ended: nothing is left to carry
                                        (handleEof v0 >>= cont "")
                                        nextChunk
            WriteOutput v1 msg -> do
                liftIO $ reply msg
                -- Writing consumes no input: keep any buffered bytes (they may
                -- hold further client messages) instead of discarding them.
                cont chunk v1
            Abort -> return ()

-- | Handler type for general streams, where incoming and outgoing messages are
-- processed concurrently.
--
-- An implementation returns an initial state and an 'IncomingStream' for the
-- client-to-server direction, and an initial state and an 'OutgoingStream'
-- for the server-to-client direction.
type GeneralStreamHandler m i o a b =
    Request -> m (a, IncomingStream m i a, b, OutgoingStream m o b)

-- | Pair of handlers for reacting to incoming messages.
data IncomingStream m i a = IncomingStream
    { incomingStreamHandler   :: a -> i -> m a
      -- ^ Fold one incoming message into the incoming-side state.
    , incomingStreamFinalizer :: a -> m ()
      -- ^ Called with the final state once the request body ends.
    }

-- | Handler deciding on the next message (if any) to send to the client.
newtype OutgoingStream m o a = OutgoingStream
    { outgoingStreamNext :: a -> m (Maybe (a, o))
      -- ^ Return the next state and message, or 'Nothing' to stop sending.
    }

-- | Handle general streams, in which two threads run concurrently:
--
-- * one reads and decodes messages from the client, feeds them to the
--   'IncomingStream', and runs its finalizer when the body ends;
-- * one pulls messages from the 'OutgoingStream' and sends them to the client
--   until it returns 'Nothing'.
--
-- The call returns once both sides have finished; their final states are
-- discarded.
--
-- NOTE(review): each side is run through its own call to @f@. Whether monadic
-- context is shared between them depends on @f@.
handleGeneralStream
  :: forall m r i o a b.
     (MonadIO m, GRPCInput r i, GRPCOutput r o)
  => (forall x. m x -> IO x)
  -> r
  -> GeneralStreamHandler m i o a b
  -> WaiHandler
handleGeneralStream f rpc handler0 decoding encoding req write flush =
    void (f (handler0 req >>= runBoth))
  where
    freshDecoder = decodeInput rpc (_getDecodingCompression decoding)
    readChunk    = getRequestBodyChunk req
    respond msg  = do
        write (encodeOutput rpc (_getEncodingCompression encoding) msg)
        flush

    -- Run the reading side and the writing side on two threads.
    runBoth :: (a, IncomingStream m i a, b, OutgoingStream m o b) -> m (a, b)
    runBoth (inState, inStream, outState, outStream) =
        liftIO $ concurrently (f (receive freshDecoder inState inStream))
                              (f (emit outState outStream))

    -- Writing side: send messages until the outgoing stream is exhausted.
    emit :: b -> OutgoingStream m o b -> m b
    emit st outStream = do
        next <- outgoingStreamNext outStream st
        case next of
            Nothing         -> pure st
            Just (st', msg) -> do
                liftIO (respond msg)
                emit st' outStream

    -- Reading side: decode one message at a time; leftover bytes seed the
    -- decoder for the next message.
    receive :: Decoder (Either String i) -> a -> IncomingStream m i a -> m a
    receive decoder st inStream =
        handleRequestChunksLoop decoder onMessage onEof readChunk
      where
        onMessage leftover msg = do
            st' <- incomingStreamHandler inStream st msg
            receive (pushChunk freshDecoder leftover) st' inStream
        onEof = incomingStreamFinalizer inStream st >> pure st


-- | Feed request-body chunks to an incremental message decoder until it
-- produces a result.
--
-- * A decoded message is passed to the message handler together with the
--   bytes the decoder did not consume.
-- * A decoding failure aborts the call through 'closeEarly' with
--   @INVALID_ARGUMENT@ and a message describing the error.
-- * While the decoder needs more input, chunks are fetched with the supplied
--   action. An empty chunk is taken as the end of the request body, and the
--   end-of-stream handler runs even if part of a message was already
--   buffered.
handleRequestChunksLoop
  :: (MonadIO m)
  => Decoder (Either String a)
  -- ^ Message decoder.
  -> (ByteString -> a -> m b)
  -- ^ Handler for a single message.
  -- The ByteString corresponds to leftover data.
  -> m b
  -- ^ Handler for handling end-of-streams.
  -> IO ByteString
  -- ^ Action to retrieve the next chunk.
  -> m b
{-# INLINEABLE handleRequestChunksLoop #-}
handleRequestChunksLoop decoder handleMsg handleEof nextChunk = step decoder
  where
    step = \case
        Done leftover _ (Right msg) -> handleMsg leftover msg
        Done _ _ (Left err)         -> reject "done-error: " err
        Fail _ _ err                -> reject "fail-error: " err
        partial@(Partial _)         -> do
            chunk <- liftIO nextChunk
            if ByteString.null chunk
                then handleEof
                else step (pushChunk partial chunk)

    reject prefix err =
        closeEarly (GRPCStatus INVALID_ARGUMENT (ByteString.pack (prefix ++ err)))

-- | Wrap a message handler so that bytes left over after decoding the message
-- are treated as an error.
--
-- Intended for RPCs whose request carries exactly one message (it is used by
-- the unary and server-streaming handlers in this module). If the decoder left
-- unparsed data behind, the wrapped handler is not run. Instead the call is
-- aborted via 'closeEarly' with @INVALID_ARGUMENT@, and the left-over bytes
-- are included in the status message.
errorOnLeftOver :: MonadIO m => (a -> m b) -> ByteString -> a -> m b
errorOnLeftOver f rest
  | ByteString.null rest = f
  -- No stray debug output here: the error is reported to the client through
  -- the gRPC status, not printed to the application's stdout.
  | otherwise            = const $
      closeEarly (GRPCStatus INVALID_ARGUMENT ("left-overs: " <> rest))