{-# LANGUAGE DerivingVia, DeriveGeneric, RankNTypes, ScopedTypeVariables, MultiParamTypeClasses, OverloadedStrings, GeneralizedNewtypeDeriving, CPP, ExistentialQuantification, StandaloneDeriving, GADTs #-}
{-# OPTIONS_GHC -fno-warn-orphans #-}
{- HLINT ignore "Use lambda-case" -}
-- | Server side of the curryer RPC protocol: envelope wire format, request
-- dispatch to user handlers, and socket helpers.
module Network.RPC.Curryer.Server where
import qualified Streamly.Prelude as S
import Streamly.Network.Socket as SSock
import Network.Socket as Socket
import Network.Socket.ByteString as Socket
import Streamly.Internal.Data.Parser as P hiding (concatMap)
import Codec.Winery
import Codec.Winery.Internal (varInt, decodeVarInt, getBytes)
import GHC.Generics
import GHC.Fingerprint
import Data.Typeable
import Control.Concurrent.MVar (MVar, newMVar, withMVar)
import Control.Exception
import Data.Function ((&))
import Data.Word
import qualified Data.ByteString as BS
import qualified Data.ByteString.Lazy as BSL
import qualified Data.ByteString.FastBuilder as BB
import Streamly.Data.Fold as FL hiding (foldr)
import qualified Streamly.Internal.Data.Stream.IsStream as P
import qualified Data.Binary as B
import qualified Data.UUID as UUIDBase
import qualified Data.UUID.V4 as UUIDBase
import Control.Monad
import Data.Functor
import Control.Applicative
import qualified Network.RPC.Curryer.StreamlyAdditions as SA
import Data.Hashable
import System.Timeout
import qualified Network.ByteOrder as BO
-- for toArrayS conversion (streamly array -> ByteString, see 'fromArray')
import qualified Data.ByteString.Internal as BSI
import qualified Streamly.Internal.Data.Array.Foreign.Type as Arr
import qualified Streamly.Internal.Data.Array.Foreign.Mut.Type as ArrT
import GHC.ForeignPtr (ForeignPtr(ForeignPtr))
import GHC.Ptr (minusPtr, Ptr(..))
-- Byte-level tracing switch: turn the line below into a real CPP
-- "#define CURRYER_SHOW_BYTES 1" to import Debug.Trace and make
-- 'traceBytes' print; left as a Haskell comment it keeps tracing disabled.
--define CURRYER_SHOW_BYTES 1
#if CURRYER_SHOW_BYTES == 1
import Debug.Trace
#endif

-- | Debugging hook for inspecting raw bytes.
--
-- When compiled with @CURRYER_SHOW_BYTES == 1@ it emits the label, the byte
-- count and the bytes themselves via 'traceShowM'; otherwise it does nothing.
traceBytes :: Applicative f => String -> BS.ByteString -> f ()
#if CURRYER_SHOW_BYTES == 1
traceBytes msg bs = traceShowM (msg, BS.length bs, bs)
#else
traceBytes _ _ = pure ()
#endif

-- a level of indirection to be able to switch between
-- serialising with and without the winery schema

-- | Encode a message payload for the wire. The payload is written without
-- its winery schema (see 'serialiseOnly').
msgSerialise :: Serialise a => a -> BS.ByteString
msgSerialise payload = serialiseOnly payload

-- | Decode a payload produced by 'msgSerialise', i.e. one that carries no
-- embedded schema (see 'deserialiseOnly').
msgDeserialise :: forall s. Serialise s => BS.ByteString -> Either WineryException s
msgDeserialise encoded = deserialiseOnly encoded

-- | A value paired with an 'MVar' used as a mutex, so that actions on the
-- value can be serialised through 'withLock'.
data Locking a = Locking (MVar ()) a

-- | Wrap a value together with a fresh, unlocked mutex.
newLock :: a -> IO (Locking a)
newLock x = flip Locking x <$> newMVar ()

-- | Run an action on the wrapped value while holding the mutex. The mutex is
-- released even if the action throws, because 'withMVar' is used.
withLock :: Locking a -> (a -> IO b) -> IO b
withLock (Locking mutex value) action = withMVar mutex (const (action value))

-- | Access the wrapped value without taking the mutex.
lockless :: Locking a -> a
lockless (Locking _ wrapped) = wrapped

type Timeout = Word32

-- | Serialised payload bytes carried inside an 'Envelope'.
type BinaryMessage = BS.ByteString

-- | The unit of communication on the wire. The fingerprint of the payload's
-- data type is what the receiver uses to decide how to dispatch the message.
--
-- TODO: add another envelope type for unencoded binary messages as an easy
-- optimisation for in-process communication.
data Envelope = Envelope
  { envFingerprint :: !Fingerprint   -- ^ type fingerprint of the encoded payload
  , envMessageType :: !MessageType
  , envMsgId       :: !UUID          -- ^ responses reuse the id of the request they answer
  , envPayload     :: !BinaryMessage
  }
  deriving (Generic, Show)

-- | A timeout in microseconds; the server treats 0 as "no timeout".
type TimeoutMicroseconds = Int

deriving instance Generic Fingerprint
deriving via (WineryVariant Fingerprint) instance Serialise Fingerprint

-- | Internal type used to mark envelope types.
data MessageType
  = RequestMessage TimeoutMicroseconds
  | ResponseMessage
  | TimeoutResponseMessage
  | ExceptionResponseMessage
  deriving (Generic, Show)
  deriving Serialise via (WineryVariant MessageType)

-- | A list of `RequestHandler`s.
type RequestHandlers serverState = [RequestHandler serverState]

-- | Data types for server-side request handlers, in synchronous (client waits for return value) and asynchronous (client does not wait for return value) forms.
data RequestHandler serverState where
  -- | create a request handler with a response
  RequestHandler :: forall a b serverState. (Serialise a, Serialise b) => (ConnectionState serverState -> a -> IO b) -> RequestHandler serverState
  -- | create an asynchronous request handler where the client does not expect nor await a response
  AsyncRequestHandler :: forall a serverState. Serialise a => (ConnectionState serverState -> a -> IO ()) -> RequestHandler serverState

-- | Server state sent in via `serve` and passed to `RequestHandler`s.
data ConnectionState a = ConnectionState
  { connectionServerState :: a
  , connectionSocket      :: Locking Socket
  }

-- | Used by server-side request handlers to send additional messages to the
-- client, outside of the normal request-response flow (e.g. asynchronous
-- responses). The locking socket is available as 'connectionSocket' in the
-- 'ConnectionState' passed to every request handler.
--
-- The message is sent as a 'RequestMessage' with timeout 0 under a freshly
-- generated random message id.
sendMessage :: Serialise a => Locking Socket -> a -> IO ()
sendMessage lockSock msg =
  UUIDBase.nextRandom >>= \rawId ->
    sendEnvelope
      (Envelope (fingerprint msg) (RequestMessage 0) (UUID rawId) (msgSerialise msg))
      lockSock

-- | Message identifier. Wrapped in a local newtype so that the 'Serialise'
-- instance below is not an orphan instance.
newtype UUID = UUID { _unUUID :: UUIDBase.UUID }
  deriving (Show, Eq, B.Binary, Hashable)

-- | Encoded as a varint length prefix followed by the 'B.Binary' encoding of
-- the UUID, described by the schema tag @"Data.UUID"@ over 'SBytes'.
instance Serialise UUID where
  schemaGen _ = pure (STag (TagStr "Data.UUID") SBytes)
  toBuilder uuid = varInt (BS.length encoded) <> BB.byteString encoded
    where
      encoded = BSL.toStrict (B.encode uuid)
  {-# INLINE toBuilder #-}
  extractor = mkExtractor fromSchema
    where
      fromSchema (STag (TagStr "Data.UUID") SBytes) = pure fromTerm
      fromSchema other = error $ "invalid schema element " <> show other
      fromTerm (TBytes raw) = B.decode (BSL.fromStrict raw)
      fromTerm badTerm = throw (InvalidTerm badTerm)
  decodeCurrent = do
    len <- decodeVarInt
    raw <- getBytes len
    pure (B.decode (BSL.fromStrict raw))

-- | Errors from remote calls.
data ConnectionError
  = CodecError String
    -- ^ 'show' of a 'WineryException' from the exception initiator, which cannot otherwise be transmitted over the wire due to its dependencies on 'TypeRep's
  | TimeoutError
  | ExceptionError String
  deriving (Generic, Show, Eq)
  deriving Serialise via (WineryVariant ConnectionError)

-- | Thrown by a server-side request handler to make the server answer with a
-- timeout response instead of a value (see 'serverEnvelopeHandler').
data TimeoutException = TimeoutException
  deriving Show

instance Exception TimeoutException

-- | An IPv4 host address as four octets.
type HostAddr = (Word8, Word8, Word8, Word8)

-- | The wildcard address @0.0.0.0@ and the loopback address @127.0.0.1@.
allHostAddrs, localHostAddr :: HostAddr
allHostAddrs = (0,0,0,0)
localHostAddr = (127,0,0,1)

-- | Parse the envelope's message-type tag: one byte (0-3). Request messages
-- (tag 0) are followed by a 32-bit network-order timeout.
msgTypeP :: Parser IO Word8 MessageType
msgTypeP =
      (P.satisfy (== 0) *> (RequestMessage . fromIntegral <$> word32P))
  <|> (P.satisfy (== 1) $> ResponseMessage)
  <|> (P.satisfy (== 2) $> TimeoutResponseMessage)
  <|> (P.satisfy (== 3) $> ExceptionResponseMessage)

-- | Parse one 'Envelope' from the wire: fingerprint, message type, message id,
-- and finally the payload, which is length-prefixed by a 32-bit unsigned
-- length in network byte order.
envelopeP :: Parser IO Word8 Envelope
envelopeP = do
  let lenPrefixedByteStringP = do
        c <- fromIntegral <$> word32P
        -- streamly can't handle takeEQ 0, so an empty payload is special-cased
        if c == 0
          then pure BS.empty
          else fromArray <$> P.takeEQ c (Arr.writeN c)
  Envelope <$> fingerprintP <*> msgTypeP <*> uuidP <*> lenPrefixedByteStringP

-- | Serialise an 'Envelope' into the wire format read back by 'envelopeP'.
-- All multi-byte integers are written in network byte order.
--
-- Per-message overhead: fingerprint (16 bytes) + message type (1 byte, plus a
-- 4-byte timeout for request messages) + message id UUID (16 bytes) + payload
-- length prefix (4 bytes) = 41 bytes per request message, 37 bytes for all
-- others.
encodeEnvelope :: Envelope -> BS.ByteString
encodeEnvelope (Envelope (Fingerprint fp1 fp2) msgType msgId bs) =
  completeMessage
  where
    completeMessage = fingerprintBs <> msgTypeBs <> msgIdBs <> lenPrefixedBs
    fingerprintBs = BO.bytestring64 fp1 <> BO.bytestring64 fp2
    msgTypeBs = case msgType of
      RequestMessage timeoutms ->
        BS.singleton 0 <> BO.bytestring32 (fromIntegral timeoutms)
      ResponseMessage -> BS.singleton 1
      TimeoutResponseMessage -> BS.singleton 2
      ExceptionResponseMessage -> BS.singleton 3
    msgIdBs = case UUIDBase.toWords (_unUUID msgId) of
      (u1, u2, u3, u4) -> foldr ((<>) . BO.bytestring32) BS.empty [u1, u2, u3, u4]
    lenPrefixedBs = BO.bytestring32 payloadLen <> bs
    payloadLen = fromIntegral (BS.length bs)

-- | Parse a 'Fingerprint' written as two 64-bit words in network byte order.
fingerprintP :: Parser IO Word8 Fingerprint
fingerprintP = Fingerprint <$> word64P <*> word64P

-- | Parse a 64-bit unsigned integer from network byte order.
word64P :: Parser IO Word8 Word64
word64P = BO.word64 . BS.pack <$> P.takeEQ 8 FL.toList

-- | Parse a 32-bit unsigned integer from network byte order.
word32P :: Parser IO Word8 Word32
word32P = BO.word32 . BS.pack <$> P.takeEQ 4 FL.toList

-- | Parse a message id. The UUID is encoded as four 32-bit words because of
-- its convenient 32-bit tuple encoding ('UUIDBase.toWords' /
-- 'UUIDBase.fromWords').
uuidP :: Parser IO Word8 UUID
uuidP = do
  u1 <- word32P
  u2 <- word32P
  u3 <- word32P
  UUID . UUIDBase.fromWords u1 u2 u3 <$> word32P

type NewConnectionHandler msg = IO (Maybe msg)

type NewMessageHandler req resp = req -> IO resp

-- | Listen for new connections and handle requests which are passed the
-- server state @s@. The @MVar SockAddr@ can optionally be used to know when
-- the server is ready for processing requests. Each accepted socket gets its
-- own lock and is drained by 'drainSocketMessages'. Returns 'True' once the
-- accept stream has been fully drained.
serve :: RequestHandlers s -> s -> HostAddr -> PortNumber -> Maybe (MVar SockAddr) -> IO Bool
serve userMsgHandlers serverState hostaddr port mSockLock = do
  let handleSock sock = do
        lockingSocket <- newLock sock
        drainSocketMessages sock (serverEnvelopeHandler lockingSocket userMsgHandlers serverState)
  S.fromSerial (S.unfold (SA.acceptOnAddrWith [(ReuseAddr,1)] mSockLock) (hostaddr, port))
    & S.fromParallel . S.mapM (forSocketM handleSock)
    & S.drain
  pure True

-- | Decode the envelope's payload as type @s@ when the envelope's fingerprint
-- matches the type fingerprint of @s@. Returns 'Nothing' on a fingerprint
-- mismatch or when decoding fails.
openEnvelope :: forall s. (Serialise s, Typeable s) => Envelope -> Maybe s
openEnvelope (Envelope eprint _ _ bytes)
  | eprint == fingerprint (undefined :: s) = either (const Nothing) Just (msgDeserialise bytes)
  | otherwise = Nothing

-- | Use winery to decode only the data structure, skipping the schema: the
-- decoder is built from the schema of the expected type @s@.
deserialiseOnly :: forall s. Serialise s => BS.ByteString -> Either WineryException s
deserialiseOnly bytes = do
  dec <- getDecoder (schema (Proxy :: Proxy s))
  pure (evalDecoder dec bytes)

-- | If the envelope decodes as the handler's argument type @a@, pair the
-- handler with the decoded argument.
matchEnvelope :: forall a b s. (Serialise a, Serialise b, Typeable b) =>
                 Envelope ->
                 (ConnectionState s -> a -> IO b) ->
                 Maybe (ConnectionState s -> a -> IO b, a)
matchEnvelope envelope dispatchf =
  case openEnvelope envelope :: Maybe a of
    Nothing -> Nothing
    Just decoded -> Just (dispatchf, decoded)

-- | Called by `serve` (via 'drainSocketMessages') to process a single
-- incoming envelope.
--
-- Response, timeout-response and exception-response envelopes are ignored.
-- For a request envelope, the handlers are tried in order. The first one
-- whose argument type matches the envelope's fingerprint, and whose payload
-- decodes, is run:
--
-- * a 'RequestHandler' runs under the request's timeout in microseconds
--   (0 means no timeout). Its result is sent back as a 'ResponseMessage'.
--   If it times out or throws 'TimeoutException', a 'TimeoutResponseMessage'
--   is sent instead.
-- * an 'AsyncRequestHandler' is run and nothing is sent back.
--
-- If handling throws any other synchronous exception, its 'show'n form is
-- sent back as an 'ExceptionResponseMessage' with the request's message id.
-- Asynchronous exceptions (those wrapped in 'SomeAsyncException', e.g.
-- 'ThreadKilled') are re-thrown rather than reported to the client. If no
-- handler matches, nothing is sent.
serverEnvelopeHandler ::
  Locking Socket ->
  RequestHandlers s ->
  s ->
  Envelope ->
  IO ()
serverEnvelopeHandler _ _ _ (Envelope _ TimeoutResponseMessage _ _) = pure ()
serverEnvelopeHandler _ _ _ (Envelope _ ExceptionResponseMessage _ _) = pure ()
serverEnvelopeHandler _ _ _ (Envelope _ ResponseMessage _ _) = pure ()
serverEnvelopeHandler sockLock msgHandlers serverState envelope@(Envelope _ (RequestMessage timeoutms) msgId _) = do
  let runTimeout :: IO b -> IO (Maybe b)
      runTimeout m =
        if timeoutms == 0
          then (Just <$> m) `catch` timeoutExcHandler
          else (timeout (fromIntegral timeoutms) m) `catch` timeoutExcHandler
      -- a server-side handler may throw 'TimeoutException'; it is caught here
      -- and becomes a timeout response, just like an expired timeout
      timeoutExcHandler :: TimeoutException -> IO (Maybe b)
      timeoutExcHandler _ = pure Nothing
      sState = ConnectionState { connectionServerState = serverState,
                                 connectionSocket = sockLock }
      -- find the first matching handler; once one has run, the accumulator is
      -- Just () and the remaining handlers are skipped
      firstMatcher (RequestHandler msghandler) Nothing =
        case matchEnvelope envelope msghandler of
          Nothing -> pure Nothing
          Just (dispatchf, decoded) -> do
            -- exceptions other than TimeoutException are caught by the 'try' below
            mResponse <- runTimeout (dispatchf sState decoded)
            let envelopeResponse =
                  case mResponse of
                    Just response -> Envelope (fingerprint response) ResponseMessage msgId (msgSerialise response)
                    Nothing -> Envelope (fingerprint TimeoutError) TimeoutResponseMessage msgId BS.empty
            sendEnvelope envelopeResponse sockLock
            pure (Just ())
      firstMatcher (AsyncRequestHandler msghandler) Nothing =
        case matchEnvelope envelope msghandler of
          Nothing -> pure Nothing
          Just (dispatchf, decoded) -> do
            _ <- dispatchf sState decoded
            pure (Just ())
      firstMatcher _ acc = pure acc
  eExc <- try (foldM_ (flip firstMatcher) Nothing msgHandlers) :: IO (Either SomeException ())
  case eExc of
    Left exc
      -- asynchronous exceptions (thread cancellation, user interrupt, ...) are
      -- not handler failures: propagate them instead of replying to the client
      | Just (_ :: SomeAsyncException) <- fromException exc -> throwIO exc
      | otherwise ->
          let env = Envelope (fingerprint (show exc)) ExceptionResponseMessage msgId (msgSerialise (show exc))
          in sendEnvelope env sockLock
    Right () -> pure ()

-- | Callback run for every envelope parsed from a socket.
type EnvelopeHandler = Envelope -> IO ()

-- | Parse envelopes from the socket's byte stream and run the handler on each
-- one. The stream is run with 'S.fromAsync', so handler invocations may
-- overlap (presumably why sends go through the socket lock — confirm against
-- streamly's async 'S.mapM' semantics).
drainSocketMessages :: Socket -> EnvelopeHandler -> IO ()
drainSocketMessages sock envelopeHandler =
  S.unfold SSock.read sock
    & P.parseMany envelopeP
    & S.mapM envelopeHandler
    & S.fromAsync
    & S.drain

-- | Encode an envelope with 'encodeEnvelope' (integers are already in network
-- byte order) and write it to the socket while holding the socket lock.
sendEnvelope :: Envelope -> Locking Socket -> IO ()
sendEnvelope envelope sockLock = do
  let envelopebytes = encodeEnvelope envelope
  -- Socket.sendAll calls send() in a loop until all bytes are sent, so the
  -- lock is needed to keep messages larger than PIPE_BUF from interleaving
  withLock sockLock $ \socket' ->
    Socket.sendAll socket' envelopebytes

-- | The fingerprint of a value's type. Only the type is inspected ('typeOf'
-- does not evaluate its argument), so @undefined :: t@ is a valid input.
fingerprint :: Typeable a => a -> Fingerprint
fingerprint = typeRepFingerprint . typeOf

-- | View a streamly byte array as a strict 'BS.ByteString' without copying:
-- the resulting ByteString shares the array's underlying buffer.
fromArray :: Arr.Array Word8 -> BSI.ByteString
fromArray arr
  | aLen == 0 = mempty
  | otherwise = BSI.PS aStartFPtr 0 aLen
  where
    aStart = Arr.arrStart arr
    aEnd = Arr.aEnd arr
    aStartFPtr = case aStart of
      Ptr addr -> ForeignPtr addr (ArrT.arrayToFptrContents (Arr.arrContents arr))
    aLen = aEnd `minusPtr` aStart