{-# LANGUAGE CPP #-}
{-# LANGUAGE ExistentialQuantification #-}
{-# LANGUAGE RecordWildCards #-}
module Network.ONCRPC.Client
( ClientServer(..)
, makeClientServerPort
, Client
, openClient
, closeClient
, clientCall
, setClientAuth
, rpcCall
) where
import Control.Concurrent (ThreadId, forkIO, killThread, threadDelay)
import Control.Concurrent.MVar (MVar, newEmptyMVar, putMVar, takeMVar, modifyMVar, modifyMVar_, modifyMVarMasked)
import Control.Exception (bracketOnError, onException, throwIO)
import qualified Data.ByteString.Lazy as BSL
import qualified Data.IntMap.Strict as IntMap
import Data.Time.Clock (getCurrentTime, diffUTCTime)
import qualified Network.Socket as Net
import System.IO (hPutStrLn, stderr)
import System.IO.Error (catchIOError)
import System.Random (randomIO)
#ifdef BINDRESVPORT
import Control.Monad (when)
import Foreign.C.Types (CInt(CInt))
import Foreign.Ptr (Ptr, nullPtr)
import Network.Socket.Internal (throwSocketErrorIfMinus1Retry_)
#endif
import qualified Network.ONCRPC.XDR as XDR
import qualified Network.ONCRPC.Prot as RPC
import Network.ONCRPC.Types
import Network.ONCRPC.Auth
import Network.ONCRPC.Message
import Network.ONCRPC.Transport
-- | How to reach an RPC server: an explicit host plus a service name or
-- port, connected to with a stream socket (no portmapper lookup).
data ClientServer
  = ClientServerPort
    { clientServerHost :: Net.HostName -- ^ host name or address of the server
    , clientServerPort :: Net.ServiceName -- ^ service name or port number
#ifdef BINDRESVPORT
    , clientBindResvPort :: Bool -- ^ call @bindresvport@ on the local socket before connecting
#endif
    }
-- | Build a 'ClientServerPort' from a host and a service (or port).
-- When the reserved-port option is compiled in, it defaults to 'False'.
makeClientServerPort :: Net.HostName -> Net.ServiceName -> ClientServer
makeClientServerPort h p = ClientServerPort
  { clientServerHost = h
  , clientServerPort = p
#ifdef BINDRESVPORT
  , clientBindResvPort = False
#endif
  }
-- | An outstanding call.
--
-- The serialized message is kept so that it can be re-sent on a new
-- connection. The reply is delivered through the 'MVar'. The reply type is
-- hidden existentially; its 'XDR.XDR' instance is what lets the receive loop
-- decode the reply body.
data Request = forall a . XDR.XDR a => Request
  { requestBody :: BSL.ByteString -- ^ complete call message, header included
  , requestAction :: MVar (Reply a) -- ^ where the reply (or failure) is put
  }
-- | Connection state shared between callers and the background thread.
-- It is always accessed through the 'clientState' 'MVar'.
data State = State
  { stateSocket :: Maybe Net.Socket -- ^ current connection, if one is open
  , stateXID :: XID -- ^ transaction id to give the next call
  , stateRequests :: IntMap.IntMap Request -- ^ outstanding calls, keyed by XID
  }
-- | An RPC client.
--
-- A client holds the server address, the background thread that owns the
-- connection, the shared connection state, and the credential/verifier
-- that 'rpcCall' attaches to each call.
data Client = Client
  { clientServer :: ClientServer
  , clientThread :: ThreadId
  , clientState :: MVar State
  , clientCred, clientVerf :: Auth
  }
#ifdef BINDRESVPORT
-- | Binding to the C function @bindresvport(fd, sin)@.
-- 'clientConnect' calls it with a null address and wraps it in
-- @throwSocketErrorIfMinus1Retry_@, so a @-1@ result is treated as an error.
foreign import ccall unsafe "bindresvport" c_bindresvport :: CInt -> Ptr Net.SockAddr -> IO CInt
#endif
-- | Log a warning to stderr in the form
-- @Network.ONCRPC.Client: \<context\>: \<shown value\>@.
warnMsg :: Show e => String -> e -> IO ()
warnMsg m = hPutStrLn stderr . (++) ("Network.ONCRPC.Client: " ++ m ++ ": ") . show
-- | Receive loop for one connection.
--
-- For each reply message, the loop claims the outstanding request with the
-- same XID, decodes the reply body for it, and hands the result to the
-- waiting caller. Replies to unknown XIDs are logged and skipped.
--
-- The loop returns, after logging, when the socket is closed or a message is
-- not a decodable reply. Exceptions from the transport are not caught here.
-- A caller whose request has already been claimed is still always answered
-- with a 'ReplyFail' if the reply body never arrives.
clientRecv :: Client -> Net.Socket -> IO ()
clientRecv c sock = next transportStart where
  next ms = check msg =<< recvGetFirst sock XDR.xdrGet ms
  msg (Right (RPC.Rpc_msg x (RPC.Rpc_msg_body'REPLY b))) ms = do
    -- Remove the request atomically so each reply is delivered at most once.
    q <- modifyMVarMasked (clientState c) $ \s@State{ stateRequests = m } -> do
      let (q, m') = IntMap.updateLookupWithKey (const $ const Nothing) (fromIntegral x) m
      return (s{ stateRequests = m' }, q)
    case q of
      Nothing -> do
        warnMsg "ignoring response to unknown xid" x
        next ms
      Just (Request _ a) -> do
        -- The request is gone from the table, so neither a reconnect
        -- (resend) nor closeClient can answer it any more. This thread must
        -- fill the MVar on every path, or the caller blocks forever.
        rm <- recvGetNext sock (getReply b) ms
                `onException` putMVar a (ReplyFail "connection lost")
        case rm of
          Nothing -> do
            putMVar a $ ReplyFail "socket closed"
            warnMsg "socket closed" ()
          Just (r, ms') -> do
            putMVar a $ either ReplyFail id r
            next ms'
  msg e _ = warnMsg "couldn't decode reply msg" e
  check _ Nothing = warnMsg "socket closed" ()
  check f (Just (r, ms)) = f r ms
-- | Return the client's open connection, connecting first if there is none.
--
-- A new connection immediately re-sends every outstanding request, so calls
-- queued while disconnected (or lost with an earlier connection) go out on
-- it. If anything fails after the socket is created (bind, connect or
-- re-send), the socket is closed before the exception propagates, and the
-- stored state is left unchanged.
clientConnect :: Client -> IO Net.Socket
clientConnect c = modifyMVar (clientState c) $ conn (clientServer c) where
  conn :: ClientServer -> State -> IO (State, Net.Socket)
  conn _ s@State{ stateSocket = Just sock } = return (s, sock)
  conn ClientServerPort{..} s = do
    addr:_ <- Net.getAddrInfo
      (Just Net.defaultHints{ Net.addrSocketType = Net.Stream })
      (Just clientServerHost)
      (Just clientServerPort)
    -- bracketOnError closes the socket only if the setup below throws;
    -- on success, ownership passes to the stored state.
    sock <- bracketOnError
      (Net.socket (Net.addrFamily addr) (Net.addrSocketType addr) (Net.addrProtocol addr))
      Net.close
      $ \new -> do
#ifdef BINDRESVPORT
        when clientBindResvPort $
          throwSocketErrorIfMinus1Retry_ "bindresvport" $
            Net.withFdSocket new $ \fd ->
              c_bindresvport fd nullPtr
#endif
        Net.connect new (Net.addrAddress addr)
        resend new (stateRequests s)
        return new
    return (s{ stateSocket = Just sock }, sock)
  -- Send the stored message of every outstanding request.
  resend :: Foldable t => Net.Socket -> t Request -> IO ()
  resend sock = mapM_ $ sendTransport sock . requestBody
-- | Close the current connection, if any, and clear it from the state so
-- the next 'clientConnect' opens a fresh one. An 'IOError' from closing is
-- logged rather than rethrown. Outstanding requests are left in place.
clientDisconnect :: Client -> IO ()
clientDisconnect c = modifyMVar_ (clientState c) $ \s -> do
  catchIOError (mapM_ Net.close $ stateSocket s) (warnMsg "close")
  return s{ stateSocket = Nothing }
-- | Body of the background thread; loops forever.
--
-- Each iteration connects and runs the receive loop until it returns or
-- throws an 'IOError' (which is logged), then drops the connection and
-- sleeps before retrying. Non-'IOError' exceptions end the thread.
--
-- The sleep is @300 / (dt + 20)@ seconds, where @dt@ is how long the
-- attempt lasted. An immediate failure therefore waits 15s, and a
-- long-lived connection is retried almost at once.
clientMain :: Client -> IO ()
clientMain c = do
  t <- getCurrentTime
  catchIOError
    (clientConnect c >>= clientRecv c)
    (warnMsg "client")
  clientDisconnect c
  dt <- (`diffUTCTime` t) <$> getCurrentTime
  -- threadDelay takes microseconds; dt is in seconds.
  threadDelay $ ceiling $ 300000000 / (dt + 20)
  clientMain c
-- | Create a client for a server and start its background thread.
--
-- The connection is opened by that thread, not here. The first XID is
-- random. Calls use 'AuthNone' for credential and verifier until changed
-- with 'setClientAuth'.
openClient :: ClientServer -> IO Client
openClient srv = do
  s <- newEmptyMVar
  -- The thread id is not known until after forkIO. The copy handed to the
  -- thread keeps a placeholder, which clientMain (and what it calls) never reads.
  let c = Client
        { clientServer = srv
        , clientThread = error "clientThread"
        , clientState = s
        , clientCred = AuthNone
        , clientVerf = AuthNone
        }
  xid <- randomIO
  -- The thread blocks on the empty state MVar until it is filled below.
  tid <- forkIO $ clientMain c
  putMVar s State
    { stateSocket = Nothing
    , stateXID = xid
    , stateRequests = IntMap.empty
    }
  return c{ clientThread = tid }
-- | Return a copy of the client that uses the given credential and verifier
-- for subsequent 'rpcCall's. The copy shares the connection state and
-- background thread with the original.
setClientAuth :: Auth -> Auth -> Client -> Client
setClientAuth cred verf client = client
  { clientCred = cred
  , clientVerf = verf
  }
-- | Shut a client down.
--
-- This stops the background thread, closes the connection, and answers every
-- request still outstanding with @ReplyFail "closed"@. The state 'MVar' is
-- left empty afterwards, so the client (and any copy made with
-- 'setClientAuth') must not be used again: further calls would block.
closeClient :: Client -> IO ()
closeClient c = do
  killThread $ clientThread c
  clientDisconnect c
  s <- takeMVar $ clientState c
  mapM_ (\(Request _ a) -> putMVar a $ ReplyFail "closed") $ stateRequests s
-- | Send a call and block until its reply (or a failure) arrives.
--
-- The call takes the next XID and is recorded as outstanding before it is
-- sent. If no connection is open, or sending fails (only logged), the
-- request stays queued and is re-sent when a connection is opened.
--
-- If another outstanding request is already registered under the same XID,
-- it is replaced, and its caller gets @ReplyFail "no response"@.
clientCall :: (XDR.XDR a, XDR.XDR r) => Client -> Call a r -> IO (Reply r)
clientCall c a = do
  rv <- newEmptyMVar
  p <- modifyMVar (clientState c) $ \s -> do
    let x = stateXID s
        q = Request
          { requestBody = XDR.xdrSerializeLazy $ MsgCall x a
          , requestAction = rv
          }
        -- `const const` stores the new request and returns the old one.
        (p, m') = IntMap.insertLookupWithKey (const const) (fromIntegral x) q (stateRequests s)
    catchIOError
      (mapM_ (`sendTransport` requestBody q) $ stateSocket s)
      (warnMsg "sendTransport")
    return (s{ stateRequests = m', stateXID = x + 1 }, p)
  case p of
    Nothing -> return ()
    Just (Request _ v) -> putMVar v (ReplyFail "no response")
  takeMVar rv
-- | Call a procedure with the client's credential and verifier.
-- Returns the result, or throws (via 'throwIO') the 'ReplyException' that
-- 'replyResult' reports for a failed reply.
rpcCall :: (XDR.XDR a, XDR.XDR r) => Client -> Procedure a r -> a -> IO r
rpcCall c p a = either throwIO return . replyResult
  =<< clientCall c (Call p (clientCred c) (clientVerf c) a)