-- | A channel carries RPCs to a remote server.
module Network.RPCA.Channel
  ( ChannelStatus(..)
  , ErrorCode(..)
  , Channel
  , networkChannel
  , rpc
  , rpcAsync
  ) where

import Control.Concurrent.STM
import GHC.Conc
import Control.Timeout
import Control.Monad (when)
import Text.Printf (printf)
import Data.Word
import Data.Maybe (isJust, fromJust)
import Network.Socket hiding (send, sendTo, recv, recvFrom)
import qualified Data.ByteString as BS
import qualified Data.Map as Map
import qualified Data.Sequence as Seq

import Codec.Libevent.Class
import Network.RPCA.Structs
import Network.RPCA.Util
import qualified Network.RPCA.Connection as C

-- | These are the various states that a channel can be in.
data ChannelStatus
  = Down        -- ^ channel is down. Requests will be enqueued
  | Lame        -- ^ remote end has signaled that it's shutting down
  | Connecting  -- ^ a connection is being attempted.
                --   Requests will be enqueued
  | Up          -- ^ channel is ready. Requests will be sent to the transport
  deriving (Eq, Show)

-- | Common reply handling for 'rpc' and 'rpcAsync'.
--
-- Converts the raw reply into the caller's result:
--
-- * a non-'ErrNone' reply code is passed on as @Left code@;
-- * otherwise the payload is deserialised, yielding @Right value@ or
--   @Left ErrReplyPayloadParseFailed@ if it cannot be parsed.
genericRPCCallback :: (TaggedStructure a)
                   => (Either ErrorCode a -> IO ())  -- ^ continuation receiving the result
                   -> Rpcreply                       -- ^ reply header
                   -> BS.ByteString                  -- ^ reply payload
                   -> IO ()
genericRPCCallback cont reply payload =
  let replyCode = toEnum $ fromIntegral $ rpcreply_reply_code reply
  in if replyCode == ErrNone
       then case deserialise payload of
              Left _  -> cont $ Left ErrReplyPayloadParseFailed
              Right x -> cont $ Right x
       else cont $ Left replyCode

-- | Perform an asynchronous RPC call. Returns as soon as the request has been
-- handed to the channel; the callback is run later with the result.
rpcAsync :: (Channel c, TaggedStructure a, TaggedStructure b)
         => c                              -- ^ the channel
         -> String                         -- ^ the method name
         -> a                              -- ^ the request arguments
         -> (Either ErrorCode b -> IO ())  -- ^ callback
         -> Float                          -- ^ timeout in seconds
         -> IO ()
rpcAsync channel method request cb timeout =
  nqueue channel
         (rpcrequestEmpty { rpcrequest_method = method })
         (serialise request)
         (Just timeout)
         (genericRPCCallback cb)

-- | Perform a synchronous RPC. Blocks the calling thread until the reply
-- arrives, the timeout expires, or the channel reports a failure.
rpc :: (Channel c, TaggedStructure a, TaggedStructure b)
    => c       -- ^ the channel
    -> String  -- ^ the method name
    -> a       -- ^ request arguments
    -> Float   -- ^ timeout in seconds
    -> IO (Either ErrorCode b)
rpc channel method request timeout = do
  result <- atomically newEmptyTMVar
  -- The timeout used to be dropped here (Nothing was passed), so a reply that
  -- never arrived would block this call until the transport failed.
  nqueue channel
         (rpcrequestEmpty { rpcrequest_method = method })
         (serialise request)
         (Just timeout)
         (genericRPCCallback (atomically . putTMVar result))
  atomically (readTMVar result)

-- | This is a channel over which RPCs can be submitted.
class Channel c where
  -- | Add an RPC to the outbound queue.
  nqueue :: c                                     -- ^ the channel
         -> Rpcrequest                            -- ^ the request - the id and service are filled in by the channel
         -> BS.ByteString                         -- ^ payload
         -> Maybe Float                           -- ^ timeout in seconds
         -> (Rpcreply -> BS.ByteString -> IO ())  -- ^ callback
         -> IO ()
  -- | Get the current status of the channel.
  channelStatus :: c -> TVar ChannelStatus

-- | A simple network channel which tries to maintain a connection to a
-- hostname:port pair.
data NetworkChannel = NetworkChannel
  { ncid :: TVar Word32
    -- ^ the next RPC id
  , ncservice :: String
    -- ^ the target service name
  , nchost :: String
    -- ^ the target host
  , ncport :: Int
    -- ^ the target port number
  , ncoutq :: TVar (Seq.Seq (BS.ByteString, BS.ByteString))
    -- ^ serialised (request header, payload) pairs waiting for a connection
  , ncdispatch :: TVar (Map.Map Word32 (Rpcreply -> BS.ByteString -> IO ()))
    -- ^ the dispatch table for incoming replies. Maps the RPC id to the
    -- handler function
  , nctimeouts :: TVar (Map.Map Word32 TimeoutTag)
    -- ^ maps RPC id to the timeout tag that we can use to cancel the timeout
    -- when the reply comes in
  , ncstatus :: TVar ChannelStatus
    -- ^ the status of this channel
  , ncdead :: TVar Bool
    -- ^ if set, the channel is to be shutdown
  , ncconn :: TVar C.Connection
    -- ^ the current Connection; a new one is stored on every reconnection
    -- attempt
  }
-- | This is the "reading" thread for the network connection.
--
-- It marks the channel 'Connecting' and connects @sock@ to the channel's
-- host and port. It then marks the channel 'Up', flushes any requests queued
-- while it was not, and loops in 'readReply'. If it throws an exception the
-- connection is shut down and we end up in 'closeAction'.
readAction :: Socket -> NetworkChannel -> IO ()
readAction sock nc = do
  atomically $ writeTVar (ncstatus nc) Connecting
  hostaddr <- inet_addr (nchost nc)
  print "Connecting"
  connect sock (SockAddrInet (PortNum $ htons $ fromIntegral $ ncport nc) hostaddr)
  -- Going Up and flushing the backlog happen in one transaction. A request
  -- submitted concurrently through 'nqueue' therefore either lands in the
  -- backlog or is written after it.
  -- 'nqueue' appends to the back of the queue, so sending front to back
  -- preserves submission order.
  atomically $ do
    writeTVar (ncstatus nc) Up -- TODO: maybe Lame?
    backlog <- readTVar (ncoutq nc)
    writeTVar (ncoutq nc) Seq.empty
    conn <- readTVar (ncconn nc)
    mapM_ (C.writePacket conn) (seqToList backlog)
  readReply nc sock

-- | This also runs in the read thread of the connection and loops forever.
-- It reads replies from the network and hands each one to the callback
-- registered for its RPC id.
--
-- Raises (via 'error') on a reply header that cannot be parsed, or on a reply
-- for a known RPC that carries no rpc body. Either of these tears the
-- connection down.
readReply :: NetworkChannel -> Socket -> IO ()
readReply nc sock = do
  (header, payload) <- C.readPacket sock
  ibreply <- case inboundreplyDeserialiseBS header of
    Left _  -> error "Protocol error"
    Right r -> return r
  let rpcId = inboundreply_id ibreply
  -- Atomically claim the callback (removing it so that neither a timeout nor
  -- a transport failure can run it again). Also cancel and forget any
  -- timeout armed for this RPC.
  mcb <- atomically $ do
    timeouts <- readTVar (nctimeouts nc)
    case Map.lookup rpcId timeouts of
      Nothing  -> return ()
      Just tag -> do
        cancelTimeout tag
        writeTVar (nctimeouts nc) (Map.delete rpcId timeouts)
    dispatch <- readTVar (ncdispatch nc)
    case Map.lookup rpcId dispatch of
      Nothing -> return Nothing
      Just cb -> do
        writeTVar (ncdispatch nc) (Map.delete rpcId dispatch)
        return (Just cb)
  case mcb of
    Nothing -> do
      -- Already timed out or aborted, or an id we never issued.
      printf "Unknown RPC reply id:%d\n" (fromIntegral rpcId :: Int)
      readReply nc sock
    Just cb -> case inboundreply_rpc ibreply of
      Nothing -> do
        -- The callback has already been removed from the dispatch table, so
        -- closeAction will not see it. Fail it here, or its caller waits forever.
        abortRPC ErrTransportFailed cb
        error "No RPC reply in reply"
      Just rpcreply -> do
        cb rpcreply payload
        readReply nc sock

-- | This is called by the Connection when the connection fails.
--
-- It marks the channel 'Down' and fails every outstanding RPC with
-- 'ErrTransportFailed'. If the channel has been shut down ('ncdead') we stop
-- there, as the Connection will close the socket and kill the threads.
-- Otherwise we sleep for a second and retry with a fresh socket.
closeAction :: NetworkChannel -> IO ()
closeAction nc = do
  (dead, cbs) <- atomically $ do
    writeTVar (ncstatus nc) Down
    pending <- readTVar (ncdispatch nc)
    timeouts <- readTVar (nctimeouts nc)
    mapM_ cancelTimeout (Map.elems timeouts)
    writeTVar (ncdispatch nc) Map.empty
    writeTVar (nctimeouts nc) Map.empty
    -- Every request still waiting in the outbound queue has its callback in
    -- the dispatch table we are failing right now. Drop those requests too.
    -- Otherwise they would be transmitted after reconnecting, even though
    -- their callers have already been told the RPC failed.
    writeTVar (ncoutq nc) Seq.empty
    isDead <- readTVar (ncdead nc)
    return (isDead, Map.elems pending)
  -- fail all enqueued RPCs
  mapM_ (abortRPC ErrTransportFailed) cbs
  when (not dead) $ do
    threadDelay 1000000
    startConnection nc

-- | Open a fresh TCP socket (with Nagle's algorithm disabled) and wrap it in a
-- new Connection whose close handler is 'closeAction'. The Connection is
-- stored in 'ncconn' and its threads are started running 'readAction'.
startConnection :: NetworkChannel -> IO ()
startConnection nc = do
  sock <- socket AF_INET Stream 0
  setSocketOption sock NoDelay 1
  conn <- atomically $ do
    conn <- C.new sock (closeAction nc)
    writeTVar (ncconn nc) conn
    return conn
  C.forkThreads conn (readAction sock nc)

-- | Create a new network channel and start connecting it in the background.
networkChannel :: String  -- ^ service name
               -> String  -- ^ hostname
               -> Int     -- ^ port number
               -> IO NetworkChannel
networkChannel service host port = do
  nc <- atomically $ do
    idVar       <- newTVar 0
    dispatchVar <- newTVar Map.empty
    timeoutsVar <- newTVar Map.empty
    statusVar   <- newTVar Down
    outqVar     <- newTVar Seq.empty
    deadVar     <- newTVar False
    -- Filled in by startConnection below, before any thread can read it.
    connVar     <- newTVar (error "Network.RPCA.Channel: connection read before it was created")
    return (NetworkChannel idVar service host port outqVar dispatchVar
                           timeoutsVar statusVar deadVar connVar)
  startConnection nc
  return nc

-- | Run the given callback with a constructed Rpcreply object which has the
-- correct error code for the given error, and an empty payload.
abortRPC :: ErrorCode                              -- ^ the error code to give
         -> (Rpcreply -> BS.ByteString -> IO ())   -- ^ the callback
         -> IO ()
abortRPC err cb = do
  let rpcreply = rpcreplyEmpty { rpcreply_reply_code = fromIntegral $ fromEnum err }
  cb rpcreply BS.empty

-- | Called when the timeout armed for an RPC expires.
--
-- It forgets the RPC's timeout tag and removes its callback from the dispatch
-- table. If the callback was still there (no reply or transport failure got to
-- it first), it is failed with 'ErrTimeout'.
handleTimeout :: NetworkChannel  -- ^ the channel which contains the request
              -> Word32          -- ^ the id which has timed out
              -> IO ()
handleTimeout nc rpcId = do
  -- Read, look up and write explicitly rather than relying on the return
  -- value of updateTVar (old vs. new map), which decides whether the lookup
  -- can ever succeed.
  mcb <- atomically $ do
    timeouts <- readTVar (nctimeouts nc)
    writeTVar (nctimeouts nc) (Map.delete rpcId timeouts)
    dispatch <- readTVar (ncdispatch nc)
    writeTVar (ncdispatch nc) (Map.delete rpcId dispatch)
    return (Map.lookup rpcId dispatch)
  maybe (return ()) (abortRPC ErrTimeout) mcb

instance Channel NetworkChannel where
  channelStatus = ncstatus

  nqueue nc rpcreq payload mtimeout cb = do
    -- Get an ID number for this RPC
    rpcId <- atomically $ updateTVar ((+) 1) $ ncid nc
    -- addTimeoutAtomic does its IO preparation now and returns an STM action.
    -- That action arms the timeout and yields the tag; it is run below, in the
    -- same transaction that registers the callback.
    marm <- case mtimeout of
      Nothing   -> return Nothing
      Just secs -> fmap Just (addTimeoutAtomic secs (handleTimeout nc rpcId))
    atomically $ do
      updateTVar (Map.insert rpcId cb) $ ncdispatch nc
      -- Record the tag in case the reply arrives and we want to cancel it.
      case marm of
        Nothing  -> return ()
        Just arm -> do
          tag <- arm
          updateTVar (Map.insert rpcId tag) $ nctimeouts nc
          return ()
      -- build the outgoing request
      let rpcreq' = rpcreq { rpcrequest_service = ncservice nc }
          obreq   = outboundrequestEmpty { outboundrequest_id  = rpcId
                                         , outboundrequest_rpc = Just rpcreq' }
          packet  = (outboundrequestSerialiseBS obreq, payload)
      -- Depending on the channel status we either enqueue the request here or
      -- in the Connection. (If the Connection is still connecting etc, then
      -- enqueuing in the Connection would result in it getting dropped if the
      -- connection fails.) Queued requests are appended so that readAction
      -- sends them in submission order.
      status <- readTVar (ncstatus nc)
      case status of
        Up -> do
          conn <- readTVar (ncconn nc)
          C.writePacket conn packet
        _ -> do
          updateTVar (Seq.|> packet) $ ncoutq nc
          return ()
      return ()