{-# LANGUAGE RankNTypes, ScopedTypeVariables, GADTs #-}
module Network.RPC.Curryer.Client where
import Control.Concurrent.MVar
import Control.Exception (finally)
import Control.Monad
import GHC.Conc
import System.Timeout

import Codec.Winery
import Control.Concurrent.Async
import Data.Time.Clock
import qualified Data.UUID.V4 as UUIDBase
import Network.Socket as Socket
import qualified StmContainers.Map as STMMap
import qualified Streamly.Network.Inet.TCP as TCP

import Network.RPC.Curryer.Server

-- | Synchronous calls still waiting for a reply, keyed by request id. Each
-- entry pairs the 'MVar' the calling thread blocks on with the time the
-- entry was registered.
type SyncMap =
  STMMap.Map UUID (MVar (Either ConnectionError BinaryMessage), UTCTime)

-- | Represents a remote connection to server.
data Connection = Connection
  { _conn_sockLock :: Locking Socket
  -- ^ lock-wrapped socket used to send envelopes and to close the connection
  , _conn_asyncThread :: Async ()
  -- ^ background thread that receives and dispatches incoming messages
  , _conn_syncmap :: SyncMap
  -- ^ pending synchronous calls awaiting a response
  }

-- | A callback run on the client when the server sends it a request, e.g. an
-- asynchronous notification. The payload type is hidden; an incoming request
-- is routed to a handler whose payload type the message decodes as.
data ClientAsyncRequestHandler where
  ClientAsyncRequestHandler ::
    forall a.
    Serialise a =>
    (a -> IO ()) ->
    ClientAsyncRequestHandler

-- | Client-side callbacks, tried in list order by 'clientEnvelopeHandler'.
type ClientAsyncRequestHandlers = [ClientAsyncRequestHandler]

-- | Open a TCP connection to a server, registering the given handlers for
-- requests that the server initiates.
--
-- The socket gets the @NoDelay@ option. A background thread ('clientAsync')
-- is started to receive incoming messages for the lifetime of the connection.
connect ::
  ClientAsyncRequestHandlers ->
  HostAddr ->
  PortNumber ->
  IO Connection
connect asyncHandlers hostAddr portNum = do
  sock <- TCP.connect hostAddr portNum
  Socket.setSocketOption sock NoDelay 1
  pending <- STMMap.newIO
  -- fields in declaration order: socket lock, receive thread, pending calls
  Connection
    <$> newLock sock
    <*> async (clientAsync sock pending asyncHandlers)
    <*> pure pending

-- | Close the connection and release all connection resources.
--
-- The socket is closed while holding the connection's socket lock, and the
-- background receive thread is then cancelled.
--
-- NOTE(review): the pending-call map is not touched here, so a 'call' made
-- without a timeout that is still waiting for its reply keeps blocking after
-- the connection is closed. Confirm whether that is intended.
close :: Connection -> IO ()
close conn = do
  withLock (_conn_sockLock conn) Socket.close
  cancel (_conn_asyncThread conn)

-- | Body of the connection's background thread. It reads messages from the
-- socket via 'drainSocketMessages' and passes each to 'clientEnvelopeHandler'.
-- That handler either wakes the thread waiting on a response or runs an
-- asynchronous notification handler.
--
-- NOTE(review): this wraps the socket in a fresh lock, distinct from the one
-- stored in 'Connection'; 'clientEnvelopeHandler' ignores its lock argument.
clientAsync ::
  Socket ->
  SyncMap ->
  ClientAsyncRequestHandlers ->
  IO ()
clientAsync sock syncmap asyncHandlers =
  newLock sock >>= \lockedSock ->
    drainSocketMessages sock (clientEnvelopeHandler asyncHandlers lockedSock syncmap)

-- | Deliver a value to the thread waiting on the given request id.
--
-- The pending entry is looked up and removed in one STM transaction, so a
-- registered entry is filled at most once. A value whose id has no pending
-- entry (for example, the caller already gave up) is silently dropped.
consumeResponse :: UUID -> STMMap.Map UUID (MVar a, b) -> a -> IO ()
consumeResponse msgId syncMap val = do
  claimed <- atomically claimEntry
  forM_ claimed $ \(waiter, _) -> putMVar waiter val
  where
    claimEntry = do
      entry <- STMMap.lookup msgId syncMap
      STMMap.delete msgId syncMap
      pure entry

-- | Handle one envelope received from the server; runs on the connection's
-- receive thread (see 'clientAsync').
--
-- * 'RequestMessage': the server is invoking a client-side callback. Handlers
--   are tried in list order, and the first one whose payload type the
--   envelope decodes as is run. If none match, the message is dropped. The
--   request's timeout field is ignored on the client side, since the server
--   expects the client to process all asynchronous requests.
-- * 'ResponseMessage': the payload is handed to the waiting caller.
-- * 'TimeoutResponseMessage': the waiting caller receives 'TimeoutError'.
-- * 'ExceptionResponseMessage': the waiting caller receives 'ExceptionError'
--   with the server's exception text.
--
-- If the exception text cannot be decoded, the caller still receives an
-- 'ExceptionError' that describes the decode failure. The handler no longer
-- calls 'error': that would kill the receive thread and leave callers
-- waiting without a timeout blocked forever.
--
-- The 'Locking' socket argument is currently unused.
clientEnvelopeHandler ::
  ClientAsyncRequestHandlers
  -> Locking Socket
  -> SyncMap
  -> Envelope
  -> IO ()
clientEnvelopeHandler handlers _ _ envelope@(Envelope _ (RequestMessage _) _ _) =
  -- NOTE(review): handlers run inline here, so a slow or throwing handler
  -- presumably affects the receive loop. Consider forking if that matters.
  foldM_ firstMatcher Nothing handlers
  where
    -- Once a handler has matched (the accumulator is 'Just'), skip the rest.
    firstMatcher :: Maybe () -> ClientAsyncRequestHandler -> IO (Maybe ())
    firstMatcher Nothing (ClientAsyncRequestHandler (dispatchf :: a -> IO ())) =
      case openEnvelope envelope of
        Nothing -> pure Nothing
        Just decoded -> do
          dispatchf decoded
          pure (Just ())
    firstMatcher acc _ = pure acc
clientEnvelopeHandler _ _ syncMap (Envelope _ ResponseMessage msgId binaryMessage) =
  consumeResponse msgId syncMap (Right binaryMessage)
clientEnvelopeHandler _ _ syncMap (Envelope _ TimeoutResponseMessage msgId _) =
  consumeResponse msgId syncMap (Left TimeoutError)
clientEnvelopeHandler _ _ syncMap (Envelope _ ExceptionResponseMessage msgId excPayload) =
  consumeResponse msgId syncMap (Left (ExceptionError excStr))
  where
    -- Report an undecodable payload to the caller instead of crashing.
    excStr = case msgDeserialise excPayload of
      Left err -> "failed to deserialise exception string" <> show err
      Right str -> str
      
-- | Basic remote function call via data type and return value.
--
-- This is 'callTimeout' with no timeout: the caller waits for the response
-- indefinitely.
call :: (Serialise request, Serialise response) => Connection -> request -> IO (Either ConnectionError response)
call conn request = callTimeout Nothing conn request

-- | Send a request to the remote server and wait for its response. With
-- @Just n@, give up after @n@ microseconds plus a latency allowance.
--
-- * @Nothing@, or a negative timeout, waits indefinitely and sends the server
--   a timeout of 0. This is exactly the request that 'call' makes.
-- * @Just n@ with @n >= 0@ forwards @n@ to the server. It waits locally for
--   @n@ microseconds plus 100 ms for network latency, and returns
--   @Left TimeoutError@ if nothing arrives in time.
--
-- The pending-call entry is always removed from the connection's map, even if
-- sending fails or the calling thread is interrupted while waiting.
--
-- Calls 'error' (raising in the calling thread) if the response payload
-- cannot be deserialised as @response@.
callTimeout :: (Serialise request, Serialise response) => Maybe Int -> Connection -> request -> IO (Either ConnectionError response)
callTimeout mTimeout conn msg = do
  requestID <- UUID <$> UUIDBase.nextRandom
  let mVarMap = _conn_syncmap conn
      -- extra local wait beyond the server-side timeout: 100 ms in microseconds
      latencyAllowanceMicros :: Int
      latencyAllowanceMicros = 100000
      -- (timeout forwarded to the server, local wait given to 'timeout').
      -- A negative local wait makes 'timeout' block indefinitely.
      (serverTimeout, localWait) = case mTimeout of
        Just tm | tm >= 0 -> (tm, tm + latencyAllowanceMicros)
        _ -> (0, -1)
      envelope =
        Envelope (fingerprint msg) (RequestMessage serverTimeout) requestID (msgSerialise msg)
  -- register an MVar for the receive thread to fill with the response
  responseMVar <- newEmptyMVar
  now <- getCurrentTime
  atomically $ STMMap.insert (responseMVar, now) requestID mVarMap
  let sendAndWait = do
        sendEnvelope envelope (_conn_sockLock conn)
        timeout localWait (takeMVar responseMVar)
      unregister = atomically $ STMMap.delete requestID mVarMap
  mResponse <- sendAndWait `finally` unregister
  case mResponse of
    Nothing -> pure (Left TimeoutError)
    Just (Left exc) -> pure (Left exc)
    Just (Right binmsg) ->
      case msgDeserialise binmsg of
        Left err -> error ("deserialise client error " <> show err)
        Right v -> pure (Right v)

-- | Call a remote function but do not expect a response from the server.
--
-- The request carries a timeout field of 0 and no pending entry is
-- registered, so any reply the server might send is dropped by the receive
-- thread. Returns @Right ()@ once the envelope has been sent.
asyncCall :: Serialise request => Connection -> request -> IO (Either ConnectionError ())
asyncCall conn msg = do
  requestID <- UUID <$> UUIDBase.nextRandom
  sendEnvelope
    (Envelope (fingerprint msg) (RequestMessage 0) requestID (msgSerialise msg))
    (_conn_sockLock conn)
  pure (Right ())