{-# LANGUAGE RankNTypes, ScopedTypeVariables, GADTs, TypeApplications #-}
module Network.RPC.Curryer.Client where
import Network.RPC.Curryer.Server
import Network.Socket as Socket
import qualified Streamly.Network.Inet.TCP as TCP
import Codec.Winery
import Control.Concurrent.Async
import qualified Data.UUID.V4 as UUIDBase
import qualified StmContainers.Map as STMMap
import Control.Concurrent.MVar
import Control.Exception (bracket_, finally)
import GHC.Conc
import Data.Time.Clock
import System.Timeout
import Control.Monad

-- | Pending synchronous requests, keyed by request 'UUID'. Each entry pairs the
-- 'MVar' that will receive the outcome (a 'ConnectionError' or the raw response
-- bytes) with the 'UTCTime' recorded when the request was registered.
type SyncMap = STMMap.Map UUID (MVar (Either ConnectionError BinaryMessage), UTCTime)

-- | Represents a remote connection to server.
data Connection = Connection
  { _conn_sockLock :: Locking Socket
    -- ^ The connection's socket behind a lock; outgoing envelopes are sent
    -- through it.
  , _conn_asyncThread :: Async ()
    -- ^ Background thread that processes messages arriving on the socket.
  , _conn_syncmap :: SyncMap
    -- ^ Requests that are still waiting for a response from the server.
  }

-- | A handler run on the client when the server sends it a request; useful for
-- asynchronous callbacks. The constructor hides the handler's argument type and
-- only requires that it has a 'Serialise' instance.
data ClientAsyncRequestHandler where
  ClientAsyncRequestHandler :: forall a. Serialise a => (a -> IO ()) -> ClientAsyncRequestHandler

-- | The asynchronous request handlers registered for a connection.
type ClientAsyncRequestHandlers = [ClientAsyncRequestHandler]

-- | Connects to a remote server with specific async callbacks registered.
--
-- Opens a TCP connection to the given address and port, then starts a
-- background thread ('clientAsync') that processes incoming messages: responses
-- are routed to the callers waiting on them and server-initiated requests are
-- offered to @asyncHandlers@. Release the returned 'Connection' with 'close'.
connect ::
  ClientAsyncRequestHandlers ->
  HostAddr ->
  PortNumber ->
  IO Connection
connect asyncHandlers hostAddr portNum = do
  sock <- TCP.connect hostAddr portNum
  syncmap <- STMMap.newIO
  -- The reader thread shares the sync map with callers so it can hand
  -- responses back to them.
  asyncThread <- async (clientAsync sock syncmap asyncHandlers)
  sockLock <- newLock sock
  pure $
    Connection
      { _conn_sockLock = sockLock,
        _conn_asyncThread = asyncThread,
        _conn_syncmap = syncmap
      }

-- | Close the connection and release all connection resources.
--
-- The socket is closed while holding the socket lock, and then the background
-- reader thread is cancelled. The cancellation runs even if closing the socket
-- throws, so the reader thread is never left behind; any exception from the
-- socket close is re-thrown afterwards.
close :: Connection -> IO ()
close conn =
  withLock (_conn_sockLock conn) Socket.close
    `finally` cancel (_conn_asyncThread conn)

-- | Body of the background thread that handles client-side incoming messages:
-- each envelope read from the socket is passed to 'clientEnvelopeHandler',
-- which either dispatches it to the thread waiting for that response or to an
-- asynchronous notification handler.
--
-- Note that the lock created here is a fresh one, distinct from the lock stored
-- in the 'Connection'; every visible equation of 'clientEnvelopeHandler'
-- ignores it.
clientAsync ::
  Socket ->
  SyncMap ->
  ClientAsyncRequestHandlers ->
  IO ()
clientAsync sock syncmap asyncHandlers = do
  lsock <- newLock sock
  drainSocketMessages sock (clientEnvelopeHandler asyncHandlers lsock syncmap)

-- | Deliver @val@ to whoever is waiting on request @msgId@.
--
-- The entry for @msgId@ is looked up and removed from the map in a single
-- transaction, then @val@ is put into the entry's 'MVar'. If there is no entry
-- for @msgId@ the value is silently dropped.
consumeResponse :: UUID -> STMMap.Map UUID (MVar a, b) -> a -> IO ()
consumeResponse msgId syncMap val = do
  -- Lookup and delete happen atomically, so a given entry can be claimed by at
  -- most one incoming message.
  waiter <-
    atomically $
      STMMap.lookup msgId syncMap <* STMMap.delete msgId syncMap
  -- Nothing registered under this ID: drop the message.
  forM_ waiter $ \(mVar, _) -> putMVar mVar val

-- | Handles an envelope received from the server.
--
-- * A 'RequestMessage' is a server-initiated (asynchronous) request: the
--   handlers are tried in list order and the first one whose argument type the
--   envelope decodes to is run; the rest are skipped. If none matches, the
--   request is dropped. The timeout carried by the request is ignored, but
--   perhaps that's proper for trusted servers; the server expects the client to
--   process all async requests.
-- * 'ResponseMessage', 'TimeoutResponseMessage' and 'ExceptionResponseMessage'
--   are delivered to the caller waiting on the envelope's message ID as
--   @Right payload@, @Left TimeoutError@ and @Left (ExceptionError msg)@
--   respectively.
--
-- If the payload of an 'ExceptionResponseMessage' cannot be deserialised, the
-- waiting caller still receives an 'ExceptionError' describing the decoding
-- failure. Throwing here instead would terminate the thread running this
-- handler and leave the caller waiting with nothing to wake it.
--
-- The 'Locking' 'Socket' argument is unused by every equation.
clientEnvelopeHandler ::
  ClientAsyncRequestHandlers
  -> Locking Socket
  -> SyncMap
  -> Envelope
  -> IO ()
clientEnvelopeHandler handlers _ _ envelope@(Envelope _ (RequestMessage _) _ _) =
  --should this run off on another green thread?
  foldM_ firstMatcher Nothing handlers
  where
    -- The accumulator is 'Nothing' until some handler has accepted the
    -- envelope; once it is 'Just', the remaining handlers are skipped.
    firstMatcher :: Maybe () -> ClientAsyncRequestHandler -> IO (Maybe ())
    firstMatcher Nothing (ClientAsyncRequestHandler (dispatchf :: a -> IO ())) =
      case openEnvelope envelope of
        Nothing -> pure Nothing
        Just decoded -> do
          dispatchf decoded
          pure (Just ())
    firstMatcher acc _ = pure acc
clientEnvelopeHandler _ _ syncMap (Envelope _ ResponseMessage msgId binaryMessage) =
  consumeResponse msgId syncMap (Right binaryMessage)
clientEnvelopeHandler _ _ syncMap (Envelope _ TimeoutResponseMessage msgId _) =
  consumeResponse msgId syncMap (Left TimeoutError)
clientEnvelopeHandler _ _ syncMap (Envelope _ ExceptionResponseMessage msgId excPayload) =
  consumeResponse msgId syncMap (Left (ExceptionError excStr))
  where
    -- Fall back to a description of the decoding failure rather than throwing.
    excStr =
      case msgDeserialise excPayload of
        Left err -> "failed to deserialise exception string: " <> show err
        Right str -> str
      
-- | Basic remote function call via data type and return value.
--
-- Equivalent to 'callTimeout' with no timeout: the call waits until a response
-- for the request arrives.
call :: (Serialise request, Serialise response) => Connection -> request -> IO (Either ConnectionError response)
call = callTimeout Nothing

-- | Send a request to the remote server and returns a response but with the
-- possibility of a timeout after n microseconds.
--
-- * @Nothing@ waits indefinitely. A negative timeout is treated the same way,
--   so that the value sent to the server and the local wait always agree.
-- * @Just n@ sends @n@ to the server with the request and waits locally for
--   @n + 100@ microseconds before giving up.
--
-- Returns @Left TimeoutError@ when the local wait expires, @Left err@ when the
-- reader thread delivers an error for this request, and @Right v@ when the
-- response payload deserialises to the expected type. A response payload that
-- fails to deserialise raises an exception via 'error'.
--
-- The request's entry in the connection's sync map is removed when this
-- function exits, including when sending fails or the caller is interrupted.
callTimeout :: (Serialise request, Serialise response) => Maybe Int -> Connection -> request -> IO (Either ConnectionError response)
callTimeout mTimeout conn msg = do
  requestID <- UUID <$> UUIDBase.nextRandom
  -- setup mvar to wait for response
  responseMVar <- newEmptyMVar
  now <- getCurrentTime
  let mVarMap = _conn_syncmap conn
      -- Negative timeouts carry no meaning; normalise them to "no timeout".
      effectiveTimeout = case mTimeout of
        Just tm | tm >= 0 -> Just tm
        _ -> Nothing
      -- 0 is the value sent when the caller requested no timeout.
      serverTimeout = maybe 0 fromIntegral effectiveTimeout
      -- Wait 100 microseconds longer than requested as a small allowance for
      -- network latency; a negative value makes 'timeout' wait indefinitely.
      -- NOTE(review): the original comment said "100 ms", but the unit here is
      -- microseconds. Confirm which was intended.
      localTimeout = maybe (-1) (+ 100) effectiveTimeout
      envelope =
        Envelope (fingerprint msg) (RequestMessage serverTimeout) requestID (msgSerialise msg)
      register = atomically $ STMMap.insert (responseMVar, now) requestID mVarMap
      unregister = atomically $ STMMap.delete requestID mVarMap
  -- bracket_ guarantees the map entry is removed on every exit path.
  mResponse <- bracket_ register unregister $ do
    sendEnvelope envelope (_conn_sockLock conn)
    timeout localTimeout (takeMVar responseMVar)
  case mResponse of
    --timeout
    Nothing -> pure (Left TimeoutError)
    Just (Left exc) -> pure (Left exc)
    Just (Right binmsg) ->
      case msgDeserialise binmsg of
        Left err -> error ("deserialise client error " <> show err)
        Right v -> pure (Right v)

-- | Call a remote function but do not expect a response from the server.
--
-- The request is sent with a fresh request ID and a timeout value of 0, and
-- nothing is registered to receive a reply. The result is always @Right ()@
-- once the envelope has been handed to 'sendEnvelope'.
asyncCall :: Serialise request => Connection -> request -> IO (Either ConnectionError ())
asyncCall conn msg = do
  requestID <- UUID <$> UUIDBase.nextRandom
  let envelope =
        Envelope (fingerprint msg) (RequestMessage 0) requestID (msgSerialise msg)
  sendEnvelope envelope (_conn_sockLock conn)
  pure (Right ())