{-# LANGUAGE ScopedTypeVariables #-}
{-# LANGUAGE DataKinds #-}
{-# LANGUAGE RankNTypes #-}
{-# LANGUAGE LambdaCase #-}
{-# LANGUAGE OverloadedStrings #-}
{-# LANGUAGE TypeFamilies #-}
{-# LANGUAGE CPP #-}
module Network.GRPC.Server.Handlers where
import Control.Concurrent.Async (concurrently)
import Control.Monad (void, (>=>))
import Data.Binary.Get (pushChunk, Decoder(..))
import qualified Data.ByteString.Char8 as ByteString
import Data.ByteString.Char8 (ByteString)
import Data.ByteString.Lazy (toStrict)
import Network.GRPC.HTTP2.Encoding
import Network.GRPC.HTTP2.Types (path, GRPCStatus(..), GRPCStatusCode(..))
#if MIN_VERSION_wai(3,2,2)
import Network.Wai (Request, getRequestBodyChunk, strictRequestBody)
#else
import Network.Wai (Request, requestBody, strictRequestBody)
#endif
#if MIN_VERSION_base(4,11,0)
#else
import Data.Monoid ((<>))
#endif
import Network.GRPC.Server.Wai (WaiHandler, ServiceHandler(..), closeEarly)
#if !MIN_VERSION_wai(3,2,2)
-- | Compatibility shim for wai versions older than 3.2.2 (see the CPP guard).
-- Those versions are built against 'requestBody', so this defines the newer
-- name used throughout this module in terms of it.
getRequestBodyChunk :: Request -> IO ByteString
getRequestBodyChunk = requestBody
#endif
-- | Handler for a unary RPC: given the WAI request and the single decoded
-- input message, produce the single output message.
type UnaryHandler i o = Request -> (i -> IO o)
-- | Handler for a server-streaming RPC: given the WAI request and the single
-- decoded input message, produce an initial state and the stream that
-- generates the output messages from that state.
type ServerStreamHandler i o a =
  Request -> i -> IO (a, ServerStream o a)

-- | Generator of output messages for a server-streaming RPC.
--
-- 'serverStreamNext' is called repeatedly with the current state:
-- @Just (state', msg)@ sends @msg@ and continues with @state'@, while
-- @Nothing@ ends the stream.
newtype ServerStream o a = ServerStream
  { serverStreamNext :: a -> IO (Maybe (a, o))
  }
-- | Handler for a client-streaming RPC: given the WAI request, produce an
-- initial state and the callbacks that consume the incoming messages.
type ClientStreamHandler i o a =
  Request -> IO (a, ClientStream i o a)

-- | Callbacks for a client-streaming RPC.
data ClientStream i o a = ClientStream
  { clientStreamHandler :: a -> i -> IO a
    -- ^ Fold one decoded input message into the state.
  , clientStreamFinalizer :: a -> IO o
    -- ^ Called when the request body ends; its result is sent as the reply.
  }
-- | Handler for a bidirectional-streaming RPC: given the WAI request, produce
-- an initial state and the state machine that drives the exchange.
type BiDiStreamHandler i o a =
  Request -> IO (a, BiDiStream i o a)

-- | One step of a bidirectional stream.
data BiDiStep i o a
  = Abort
    -- ^ Stop driving the stream.
  | WaitInput
      !(a -> i -> IO a)
      !(a -> IO a)
    -- ^ Wait for the next input message. The first callback receives the
    -- decoded message; the second runs if the request body ends instead.
  | WriteOutput !a o
    -- ^ Send the given output message, then continue with the given state.

-- | State machine for a bidirectional stream: 'bidirNextStep' maps the
-- current state to the next step to perform.
newtype BiDiStream i o a = BiDiStream
  { bidirNextStep :: a -> IO (BiDiStep i o a)
  }
-- | Register a unary RPC: requests on the path of @rpc@ are served by
-- 'handleUnary' with the given handler.
unary
  :: (GRPCInput r i, GRPCOutput r o)
  => r
  -> UnaryHandler i o
  -> ServiceHandler
unary rpc handler = ServiceHandler routePath waiHandler
  where
    routePath = path rpc
    waiHandler = handleUnary rpc handler
-- | Register a server-streaming RPC: requests on the path of @rpc@ are served
-- by 'handleServerStream' with the given handler.
serverStream
  :: (GRPCInput r i, GRPCOutput r o)
  => r
  -> ServerStreamHandler i o a
  -> ServiceHandler
serverStream rpc handler =
  let waiHandler = handleServerStream rpc handler
  in ServiceHandler (path rpc) waiHandler
-- | Register a client-streaming RPC: requests on the path of @rpc@ are served
-- by 'handleClientStream' with the given handler.
clientStream
  :: (GRPCInput r i, GRPCOutput r o)
  => r
  -> ClientStreamHandler i o a
  -> ServiceHandler
clientStream rpc handler =
  ServiceHandler (path rpc) $ handleClientStream rpc handler
-- | Register a bidirectional-streaming RPC: requests on the path of @rpc@ are
-- served by 'handleBiDiStream' with the given handler.
bidiStream
  :: (GRPCInput r i, GRPCOutput r o)
  => r
  -> BiDiStreamHandler i o a
  -> ServiceHandler
bidiStream rpc handler =
  let routePath = path rpc
  in ServiceHandler routePath $ handleBiDiStream rpc handler
-- | Register an RPC whose incoming and outgoing streams run independently:
-- requests on the path of @rpc@ are served by 'handleGeneralStream'.
generalStream
  :: (GRPCInput r i, GRPCOutput r o)
  => r
  -> GeneralStreamHandler i o a b
  -> ServiceHandler
generalStream rpc handler = ServiceHandler (path rpc) waiHandler
  where
    waiHandler = handleGeneralStream rpc handler
-- | WAI handler for a unary RPC.
--
-- The whole request body is read with 'strictRequestBody' and decoded as a
-- single input message.
--
-- * If bytes remain after that message, @INVALID_ARGUMENT@ is reported via
--   'closeEarly' (see 'errorOnLeftOver').
-- * If the body ends before a complete message, @INVALID_ARGUMENT@
--   ("early end of request body") is reported via 'closeEarly'.
-- * Otherwise the handler's result is encoded, written and flushed.
handleUnary
  :: (GRPCInput r i, GRPCOutput r o)
  => r
  -> UnaryHandler i o
  -> WaiHandler
handleUnary rpc handler decoding encoding req write flush =
    handleRequestChunksLoop initialDecoder onMessage onEndOfBody readWholeBody
  where
    initialDecoder = decodeInput rpc (_getDecodingCompression decoding)
    -- The first read returns the whole body. A further read is expected to
    -- yield an empty chunk, which the loop treats as end of body.
    readWholeBody = fmap toStrict (strictRequestBody req)
    onMessage = errorOnLeftOver (handler req >=> sendReply)
    onEndOfBody = closeEarly (GRPCStatus INVALID_ARGUMENT "early end of request body")
    sendReply msg = do
        write (encodeOutput rpc (_getEncodingCompression encoding) msg)
        flush
-- | WAI handler for a server-streaming RPC.
--
-- The single input message is decoded exactly as in 'handleUnary', with the
-- same @INVALID_ARGUMENT@ rejections. The handler then returns an initial
-- state and a 'ServerStream'. Every message it yields is encoded, written and
-- flushed, in order, until it returns 'Nothing'.
handleServerStream
  :: (GRPCInput r i, GRPCOutput r o)
  => r
  -> ServerStreamHandler i o a
  -> WaiHandler
handleServerStream rpc handler decoding encoding req write flush =
    handleRequestChunksLoop initialDecoder onMessage onEndOfBody readWholeBody
  where
    initialDecoder = decodeInput rpc (_getDecodingCompression decoding)
    readWholeBody = toStrict <$> strictRequestBody req
    onMessage = errorOnLeftOver $ \input ->
        handler req input >>= uncurry streamReplies
    onEndOfBody = closeEarly (GRPCStatus INVALID_ARGUMENT "early end of request body")
    sendReply msg = write (encodeOutput rpc (_getEncodingCompression encoding) msg) >> flush
    -- Pull messages from the stream until it is exhausted.
    streamReplies seed sStream = drain seed
      where
        drain st =
            serverStreamNext sStream st
                >>= maybe (pure ()) (\(st', msg) -> sendReply msg >> drain st')
-- | WAI handler for a client-streaming RPC.
--
-- The handler provides an initial state and a 'ClientStream'.
--
-- * The request body is decoded message by message. Each message is folded
--   into the state with 'clientStreamHandler'.
-- * Bytes left over after one message are decoded before more of the body is
--   read.
-- * When the body ends, 'clientStreamFinalizer' produces the single reply,
--   which is encoded, written and flushed.
--
-- NOTE(review): 'handleRequestChunksLoop' treats an empty chunk as end of
-- body even when part of a message is buffered, so a truncated trailing
-- message also leads to the finalizer being called.
handleClientStream
  :: (GRPCInput r i, GRPCOutput r o)
  => r
  -> ClientStreamHandler i o a
  -> WaiHandler
handleClientStream rpc handler0 decoding encoding req write flush = do
    (initialState, cStream) <- handler0 req
    let consume decoder acc =
          handleRequestChunksLoop decoder
            (\rest msg -> do
                acc' <- clientStreamHandler cStream acc msg
                consume (pushChunk freshDecoder rest) acc')
            (clientStreamFinalizer cStream acc >>= sendReply)
            nextChunk
    consume freshDecoder initialState
  where
    freshDecoder = decodeInput rpc (_getDecodingCompression decoding)
    nextChunk = getRequestBodyChunk req
    sendReply msg = write (encodeOutput rpc (_getEncodingCompression encoding) msg) >> flush
-- | WAI handler for a bidirectional-streaming RPC.
--
-- The handler's 'BiDiStream' is driven as a state machine: each step is
-- obtained from 'bidirNextStep' with the current state.
--
-- * 'WaitInput' decodes the next message from the buffered bytes plus further
--   body chunks and passes it to the step's message callback. If the body ends
--   first, the step's end-of-input callback runs instead.
-- * 'WriteOutput' encodes, writes and flushes a message.
-- * 'Abort' stops driving the stream.
--
-- Bytes left over after a decoded message are carried into the next step,
-- including across 'WriteOutput' steps. This means messages that arrive in
-- the same body chunk are not lost.
handleBiDiStream
  :: (GRPCInput r i, GRPCOutput r o)
  => r
  -> BiDiStreamHandler i o a
  -> WaiHandler
handleBiDiStream rpc handler0 decoding encoding req write flush =
    handler0 req >>= go ""
  where
    nextChunk = getRequestBodyChunk req
    reply msg = write (encodeOutput rpc (_getEncodingCompression encoding) msg) >> flush
    -- @chunk@ holds bytes already read from the body but not yet decoded.
    go chunk (v0, bStream) = do
        let cont dat v1 = go dat (v1, bStream)
        step <- bidirNextStep bStream v0
        case step of
            WaitInput handleMsg handleEof ->
                handleRequestChunksLoop
                    (pushChunk (decodeInput rpc (_getDecodingCompression decoding)) chunk)
                    (\dat msg -> handleMsg v0 msg >>= cont dat)
                    (handleEof v0 >>= cont "")
                    nextChunk
            WriteOutput v1 msg -> do
                reply msg
                -- Keep the undecoded bytes: dropping them here would lose
                -- (or misalign) messages that arrived in the same chunk.
                cont chunk v1
            Abort -> return ()
-- | Handler for an RPC whose incoming and outgoing streams run independently.
-- Given the WAI request, it produces an initial incoming state with its
-- consumer and an initial outgoing state with its generator.
type GeneralStreamHandler i o a b =
  Request -> IO (a, IncomingStream i a, b, OutgoingStream o b)

-- | Consumer side of a general stream.
data IncomingStream i a = IncomingStream
  { incomingStreamHandler :: a -> i -> IO a
    -- ^ Fold one decoded input message into the state.
  , incomingStreamFinalizer :: a -> IO ()
    -- ^ Called when the request body ends.
  }

-- | Producer side of a general stream: @Just (state', msg)@ sends @msg@ and
-- continues with @state'@, while @Nothing@ ends the outgoing side.
newtype OutgoingStream o a = OutgoingStream
  { outgoingStreamNext :: a -> IO (Maybe (a, o))
  }
-- | WAI handler for an RPC whose incoming and outgoing streams run
-- independently. The two sides are run with 'concurrently' and both final
-- states are discarded.
--
-- * Incoming: the body is decoded message by message and folded with
--   'incomingStreamHandler'. Left-over bytes are decoded before more body is
--   read. At end of body, 'incomingStreamFinalizer' is called.
-- * Outgoing: messages from 'outgoingStreamNext' are encoded, written and
--   flushed until it returns 'Nothing'.
handleGeneralStream
  :: (GRPCInput r i, GRPCOutput r o)
  => r
  -> GeneralStreamHandler i o a b
  -> WaiHandler
handleGeneralStream rpc handler0 decoding encoding req write flush = do
    (inState, inStream, outState, outStream) <- handler0 req
    void $ concurrently
        (consume freshDecoder inState inStream)
        (produce outState outStream)
  where
    freshDecoder = decodeInput rpc $ _getDecodingCompression decoding
    nextChunk = getRequestBodyChunk req
    sendReply msg = write (encodeOutput rpc (_getEncodingCompression encoding) msg) >> flush
    produce st oStream =
        outgoingStreamNext oStream st
            >>= maybe (pure st) (\(st', msg) -> sendReply msg >> produce st' oStream)
    consume decoder st iStream =
        handleRequestChunksLoop decoder
            (\rest msg -> do
                st' <- incomingStreamHandler iStream st msg
                consume (pushChunk freshDecoder rest) st' iStream)
            (incomingStreamFinalizer iStream st >> pure st)
            nextChunk
-- | Drive an incremental decoder over the request body until it yields one
-- message.
--
-- * On a successful decode, @handleMsg@ receives the unconsumed bytes that
--   followed the message, plus the decoded value.
-- * A decoder that finishes with 'Left', or fails, is reported as
--   @INVALID_ARGUMENT@ via 'closeEarly'. The message is prefixed with
--   @done-error:@ or @fail-error:@ respectively.
-- * While the decoder needs more input, chunks are pulled with @nextChunk@.
--   An empty chunk is treated as end of body and @handleEof@ runs.
--
-- NOTE(review): an empty chunk triggers @handleEof@ even if part of a message
-- has already been buffered. A truncated trailing message is therefore not
-- distinguished from a clean end of input.
handleRequestChunksLoop
  :: Decoder (Either String a)
  -> (ByteString -> a -> IO b)
  -> IO b
  -> IO ByteString
  -> IO b
{-# INLINEABLE handleRequestChunksLoop #-}
handleRequestChunksLoop decoder handleMsg handleEof nextChunk = step decoder
  where
    step = \case
      Done leftover _ (Right value) -> handleMsg leftover value
      Done _ _ (Left reason) -> reject "done-error: " reason
      Fail _ _ reason -> reject "fail-error: " reason
      pending@(Partial _) -> do
        chunk <- nextChunk
        if ByteString.null chunk
          then handleEof
          else step (pushChunk pending chunk)
    reject prefix reason =
      closeEarly (GRPCStatus INVALID_ARGUMENT (ByteString.pack (prefix ++ reason)))
-- | Wrap a message callback for 'handleRequestChunksLoop' so that it rejects
-- input with trailing bytes.
--
-- When the bytes left over after the decoded message are empty, the wrapped
-- callback @f@ runs on the message. Otherwise @f@ is not called, and
-- @INVALID_ARGUMENT@ is reported via 'closeEarly' with the left-over bytes
-- appended to the status message.
--
-- The debug @putStrLn@ that used to precede the rejection is removed. It
-- wrote to the server's stdout on every malformed request.
--
-- NOTE(review): the left-over bytes are client-controlled and are echoed
-- verbatim into the status message; consider truncating or escaping them.
errorOnLeftOver :: (a -> IO b) -> ByteString -> a -> IO b
errorOnLeftOver f rest
  | ByteString.null rest = f
  | otherwise = const $ closeEarly (GRPCStatus INVALID_ARGUMENT ("left-overs: " <> rest))