-- |ONC RPC Client interface.
-- Handles RPC client protocol layer.
-- Clients are fully thread-safe, allowing multiple outstanding requests, and automatically reconnect on error.
-- Currently error messages are just written to stderr.

{-# LANGUAGE CPP #-}
{-# LANGUAGE ExistentialQuantification #-}
{-# LANGUAGE RecordWildCards #-}
module Network.ONCRPC.Client
  ( ClientServer(..)
  , makeClientServerPort
  , Client
  , openClient
  , closeClient
  , clientCall
  , setClientAuth
  , rpcCall
  ) where

import           Control.Concurrent (ThreadId, forkIO, killThread, threadDelay)
import           Control.Concurrent.MVar (MVar, newEmptyMVar, putMVar, takeMVar, modifyMVar, modifyMVar_, modifyMVarMasked)
import           Control.Exception (bracketOnError, throwIO)
import qualified Data.ByteString.Lazy as BSL
import qualified Data.IntMap.Strict as IntMap
import           Data.Time.Clock (getCurrentTime, diffUTCTime)
import qualified Network.Socket as Net
import           System.IO (hPutStrLn, stderr)
import           System.IO.Error (catchIOError)
import           System.Random (randomIO)

#ifdef BINDRESVPORT
import           Control.Monad (when)
import           Foreign.C.Types (CInt(CInt))
import           Foreign.Ptr (Ptr, nullPtr)
import           Network.Socket.Internal (throwSocketErrorIfMinus1Retry_)
#endif

import qualified Network.ONCRPC.XDR as XDR
import qualified Network.ONCRPC.Prot as RPC
import           Network.ONCRPC.Types
import           Network.ONCRPC.Auth
import           Network.ONCRPC.Message
import           Network.ONCRPC.Transport

-- |How to connect to an RPC server.
-- Currently only TCP connections to pre-defined ports are supported (no portmap).
data ClientServer
  = -- |A known service by host/port, currently only TCP.
    ClientServerPort
    { clientServerHost :: Net.HostName
      -- ^Host name or IP address of server
    , clientServerPort :: Net.ServiceName
      -- ^Service name (not portmap) or port number
#ifdef BINDRESVPORT
    , clientBindResvPort :: Bool
      -- ^If set, the local end of each new connection is bound with
      -- @bindresvport@ before connecting
#endif
    }

-- |Describe a server reachable at the given host and service name / port.
-- When built with BINDRESVPORT, reserved-port binding is left disabled.
makeClientServerPort :: Net.HostName -> Net.ServiceName -> ClientServer
makeClientServerPort host port = ClientServerPort
  { clientServerPort = port
  , clientServerHost = host
#ifdef BINDRESVPORT
  , clientBindResvPort = False
#endif
  }

-- |A call that has been issued and is waiting for its reply.
-- The reply type is hidden (existential); only its 'XDR.XDR' instance is kept,
-- which is what is needed to decode the reply body.
data Request = forall a . XDR.XDR a => Request
  { requestBody :: BSL.ByteString
    -- ^Serialized call message, kept for retransmits
  , requestAction :: MVar (Reply a)
    -- ^Where the reply (or a 'ReplyFail') is delivered to the waiting caller
  }

-- |Connection state shared between a client's callers and its background thread.
data State = State
  { stateSocket :: Maybe Net.Socket
    -- ^Current connection, if any
  , stateXID :: XID
    -- ^XID to assign to the next call
  , stateRequests :: IntMap.IntMap Request
    -- ^Calls still awaiting a reply, keyed by XID
  }

-- |An RPC Client.
data Client = Client
  { clientServer :: ClientServer
    -- ^Server to (re)connect to
  , clientThread :: ThreadId
    -- ^Background thread running the connect/receive loop
  , clientState :: MVar State
    -- ^Shared connection state; left empty for good by 'closeClient'
  , clientCred :: Auth
    -- ^Credentials used by 'rpcCall'
  , clientVerf :: Auth
    -- ^Verifier used by 'rpcCall'
  }

-- Binding to the C library's @bindresvport@, which binds a socket to a
-- reserved local port. It is only compiled with BINDRESVPORT, and is called
-- from 'clientConnect' (with a null address) when 'clientBindResvPort' is set.
-- It returns -1 on failure, which the caller turns into an exception.
#ifdef BINDRESVPORT
foreign import ccall unsafe "bindresvport" c_bindresvport :: CInt -> Ptr Net.SockAddr -> IO CInt
#endif

-- |Write a warning line to stderr: the module name, a context label, and
-- the 'show'n value.
warnMsg :: Show e => String -> e -> IO ()
warnMsg m e = hPutStrLn stderr $ concat ["Network.ONCRPC.Client: ", m, ": ", show e]

-- |Receive loop for one connected socket.
-- Reads messages from the socket and hands each decoded reply to the pending
-- 'Request' with the matching XID, waking its caller.
-- It returns, after logging a warning to stderr, when the socket is closed or
-- a message is not a decodable reply. Exceptions from the transport propagate.
--
-- A request is removed from the pending map only after its reply body has
-- been read. If the connection fails part-way through a reply, or this thread
-- is killed there, the request stays pending. It is then retransmitted on
-- reconnect or failed by 'closeClient', rather than leaving its caller blocked
-- forever.
clientRecv :: Client -> Net.Socket -> IO ()
clientRecv c sock = next transportStart where
  -- Read the header of the next message and dispatch on it.
  next ms = do
    hdr <- recvGetFirst sock XDR.xdrGet ms
    case hdr of
      Nothing -> warnMsg "socket closed" () -- return
      Just (m, ms') -> msg m ms'
  msg (Right (RPC.Rpc_msg x (RPC.Rpc_msg_body'REPLY b))) ms = do
    let key = fromIntegral x
    -- Only look the request up here; it is removed once its body is read.
    q <- modifyMVar (clientState c) $ \s ->
      return (s, IntMap.lookup key (stateRequests s))
    case q of
      Nothing -> do
        warnMsg "ignoring response to unknown xid" x
        next ms -- ignore
      Just (Request _ a) -> do
        body <- recvGetNext sock (getReply b) ms
        case body of
          Nothing -> warnMsg "socket closed" () -- return; request stays pending
          Just (r, ms') -> do
            -- Remove and complete the request in one masked step, so it cannot
            -- be dropped without its caller being woken. The putMVar does not
            -- block: only this loop fills the MVar of a request still in the
            -- map (barring XID wraparound).
            modifyMVarMasked (clientState c) $ \s -> do
              putMVar a $ either ReplyFail id r
              return (s{ stateRequests = IntMap.delete key (stateRequests s) }, ())
            next ms'
  msg e _ = warnMsg "couldn't decode reply msg" e -- return

-- |Return the client's current socket, connecting first if there is none.
-- A fresh connection goes to the first address 'Net.getAddrInfo' returns for
-- the server. Every pending request is retransmitted on it before it is
-- stored in the client state.
-- The state lock is held throughout, so 'clientCall' cannot register or send
-- requests in the middle of this.
-- On any exception the new socket is closed and the state is left unchanged.
clientConnect :: Client -> IO Net.Socket
clientConnect c = modifyMVar (clientState c) $ conn (clientServer c) where
  conn _ s@State{ stateSocket = Just sock } = return (s, sock)
  conn ClientServerPort{..} s = do
    addr:_ <- Net.getAddrInfo
      (Just Net.defaultHints{ Net.addrSocketType = Net.Stream })
      (Just clientServerHost)
      (Just clientServerPort)
    -- Close the socket if setting it up fails, including via an asynchronous
    -- exception such as the 'killThread' in 'closeClient'. Otherwise the file
    -- descriptor would leak on every failed (re)connect attempt.
    bracketOnError
      (Net.socket (Net.addrFamily addr) (Net.addrSocketType addr) (Net.addrProtocol addr))
      Net.close
      $ \sock -> do
#ifdef BINDRESVPORT
        when clientBindResvPort $
          throwSocketErrorIfMinus1Retry_ "bindresvport" $
            Net.withFdSocket sock $ \fd ->
              c_bindresvport fd nullPtr
#endif
        Net.connect sock (Net.addrAddress addr)
        resend sock (stateRequests s)
        return (s{ stateSocket = Just sock }, sock)
  -- Retransmit the serialized call of every pending request.
  resend sock = mapM_ $ sendTransport sock . requestBody

-- |Close the client's socket, if it has one, and record that it is disconnected.
-- An 'IOError' raised while closing is logged to stderr rather than rethrown.
clientDisconnect :: Client -> IO ()
clientDisconnect c = modifyMVar_ (clientState c) $ \s -> do
  case stateSocket s of
    Nothing -> return ()
    Just sock -> Net.close sock `catchIOError` warnMsg "close"
  return s{ stateSocket = Nothing }

-- |Body of the client's background thread. Forever: connect, run the receive
-- loop until it ends, disconnect, then wait before trying again.
-- 'IOError's from connecting or receiving are logged and lead to a reconnect.
-- Any other exception, such as the 'killThread' from 'closeClient', ends the
-- thread.
-- The wait is 300 / (dt + 20) seconds, where dt is how long the previous
-- attempt lasted. It is 15 s after an immediate failure and shrinks towards
-- zero after a long-lived connection.
clientMain :: Client -> IO ()
clientMain c = do
  t <- getCurrentTime
  catchIOError
    (clientConnect c >>= clientRecv c)
    (warnMsg "client")
  clientDisconnect c
  now <- getCurrentTime
  -- getCurrentTime is the wall clock, which can jump backwards. Clamp so a
  -- negative dt cannot yield a huge delay, a negative delay, or a division
  -- by zero (dt = -20).
  let dt = max 0 (now `diffUTCTime` t)
  threadDelay $ ceiling $ 300000000 / (dt + 20)
  clientMain c

-- |Create a new RPC client to the given server.
-- This client must be destroyed with 'closeClient' when done.
openClient :: ClientServer -> IO Client
openClient srv = do
  stVar <- newEmptyMVar
  -- The background thread never reads 'clientThread', so its copy of the
  -- client can carry a placeholder there.
  let base = Client
        { clientServer = srv
        , clientThread = error "clientThread"
        , clientState = stVar
        , clientCred = AuthNone
        , clientVerf = AuthNone
        }
  firstXid <- randomIO
  -- The worker blocks on the still-empty state MVar until it is filled below.
  worker <- forkIO $ clientMain base
  putMVar stVar $ State
    { stateSocket = Nothing
    , stateXID = firstXid
    , stateRequests = IntMap.empty
    }
  return base{ clientThread = worker }

-- |Set the credentials and verifier to use when calling 'rpcCall' on a client.
-- Only the returned 'Client' value changes; the connection and pending
-- requests are shared. You can therefore safely use different sets of
-- credentials with the same underlying connection.
-- By default, both are set to 'AuthNone'.
setClientAuth :: Auth -> Auth -> Client -> Client
setClientAuth cred verf c = c{ clientVerf = verf, clientCred = cred }

-- |Destroy an RPC client and close its underlying network connection.
-- Any outstanding requests return 'ReplyFail', and any further attempt to use
-- the 'Client' may hang indefinitely.
closeClient :: Client -> IO ()
closeClient c = do
  -- Stop the background thread first so it no longer touches the state.
  killThread $ clientThread c
  clientDisconnect c
  -- Take the state and never put it back: the client is unusable from here on.
  st <- takeMVar $ clientState c
  mapM_ failPending $ stateRequests st
  where
    failPending :: Request -> IO ()
    failPending (Request _ v) = putMVar v $ ReplyFail "closed"

-- |Send a call message using an open client, and wait for a reply, returning
-- 'ReplyFail' on protocol error.
-- The call is registered under a fresh XID and sent at once if the client is
-- connected. A send failure is only logged. The call stays queued and is
-- retransmitted on every reconnect until a response is received.
clientCall :: (XDR.XDR a, XDR.XDR r) => Client -> Call a r -> IO (Reply r)
clientCall c call = do
  rv <- newEmptyMVar
  displaced <- modifyMVar (clientState c) $ \s -> do
    let xid = stateXID s
        req = Request
          { requestBody = XDR.xdrSerializeLazy $ MsgCall xid call
          , requestAction = rv
          }
        (old, reqs) = IntMap.insertLookupWithKey (\_ new _ -> new) (fromIntegral xid) req (stateRequests s)
    case stateSocket s of
      Nothing -> return ()
      Just sock -> sendTransport sock (requestBody req) `catchIOError` warnMsg "sendTransport"
    return (s{ stateRequests = reqs, stateXID = xid + 1 }, old)
  -- An older call with the same XID (only possible after wraparound) is
  -- failed rather than left waiting.
  case displaced of
    Nothing -> return ()
    Just (Request _ v) -> putMVar v (ReplyFail "no response")
  takeMVar rv

-- |Make an RPC request.
-- It waits for a response, retrying as necessary. On any failure it throws
-- the 'Network.ONCRPC.Exception.RPCException' 'ReplyException'.
-- This uses the credentials set by 'setClientAuth'.
-- If you need to retrieve the auth verifier, use 'clientCall'.
rpcCall :: (XDR.XDR a, XDR.XDR r) => Client -> Procedure a r -> a -> IO r
rpcCall c p a = do
  reply <- clientCall c $ Call p (clientCred c) (clientVerf c) a
  either throwIO return $ replyResult reply