-- | TCP implementation of the transport layer. -- -- The TCP implementation guarantees that only a single TCP connection (socket) -- will be used between endpoints, provided that the addresses specified are -- canonical. If /A/ connects to /B/ and reports its address as -- @192.168.0.1:8080@ and /B/ subsequently connects tries to connect to /A/ as -- @client1.local:http-alt@ then the transport layer will not realize that the -- TCP connection can be reused. -- -- Applications that use the TCP transport should use -- 'Network.Socket.withSocketsDo' in their main function for Windows -- compatibility (see "Network.Socket"). module Network.Transport.TCP ( -- * Main API createTransport , TCPParameters(..) , defaultTCPParameters -- * Internals (exposed for unit tests) , createTransportExposeInternals , TransportInternals(..) , EndPointId , encodeEndPointAddress , decodeEndPointAddress , ControlHeader(..) , ConnectionRequestResponse(..) , firstNonReservedLightweightConnectionId , firstNonReservedHeavyweightConnectionId , socketToEndPoint , LightweightConnectionId -- * Design notes -- $design ) where import Prelude hiding ( mapM_ #if ! 
MIN_VERSION_base(4,6,0) , catch #endif ) import Network.Transport import Network.Transport.TCP.Internal ( forkServer , recvWithLength , recvInt32 , tryCloseSocket ) import Network.Transport.Internal ( encodeInt32 , prependLength , mapIOException , tryIO , tryToEnum , void , timeoutMaybe , asyncWhenCancelled ) #ifdef USE_MOCK_NETWORK import qualified Network.Transport.TCP.Mock.Socket as N #else import qualified Network.Socket as N #endif ( HostName , ServiceName , Socket , getAddrInfo , socket , addrFamily , addrAddress , SocketType(Stream) , defaultProtocol , setSocketOption , SocketOption(ReuseAddr) , connect , sOMAXCONN , AddrInfo ) #ifdef USE_MOCK_NETWORK import Network.Transport.TCP.Mock.Socket.ByteString (sendMany) #else import Network.Socket.ByteString (sendMany) #endif import Control.Concurrent (forkIO, ThreadId, killThread, myThreadId) import Control.Concurrent.Chan (Chan, newChan, readChan, writeChan) import Control.Concurrent.MVar ( MVar , newMVar , modifyMVar , modifyMVar_ , readMVar , putMVar , newEmptyMVar , withMVar ) import Control.Category ((>>>)) import Control.Applicative ((<$>)) import Control.Monad (when, unless, join) import Control.Exception ( IOException , SomeException , AsyncException , handle , throw , throwIO , try , bracketOnError , fromException , catch ) import Data.IORef (IORef, newIORef, writeIORef, readIORef) import Data.ByteString (ByteString) import qualified Data.ByteString as BS (concat) import qualified Data.ByteString.Char8 as BSC (pack, unpack) import Data.Bits (shiftL, (.|.)) import Data.Word (Word32) import Data.Set (Set) import qualified Data.Set as Set ( empty , insert , elems , singleton , null , delete , member ) import Data.Map (Map) import qualified Data.Map as Map (empty) import Data.Accessor (Accessor, accessor, (^.), (^=), (^:)) import qualified Data.Accessor.Container as DAC (mapMaybe) import Data.Foldable (forM_, mapM_) -- $design -- -- [Goals] -- -- The TCP transport maps multiple logical connections between /A/ 
-- and /B/ (in either direction) to a single TCP connection:
--
-- > +-------+                          +-------+
-- > | A     |==========================| B     |
-- > |       |>~~~~~~~~~~~~~~~~~~~~~~~~~|~~~\   |
-- > |   Q   |>~~~~~~~~~~~~~~~~~~~~~~~~~|~~~Q   |
-- > |   \~~~|~~~~~~~~~~~~~~~~~~~~~~~~~<|       |
-- > |       |==========================|       |
-- > +-------+                          +-------+
--
-- Ignoring the complications detailed below, the TCP connection is set up
-- when the first lightweight connection is created (in either direction), and
-- torn down when the last lightweight connection (in either direction) is
-- closed.
--
-- [Connecting]
--
-- Let /A/, /B/ be two endpoints without any connections. When /A/ wants to
-- connect to /B/, it locally records that it is trying to connect to /B/ and
-- sends a request to /B/. As part of the request /A/ sends its own endpoint
-- address to /B/ (so that /B/ can reuse the connection in the other direction).
--
-- When /B/ receives the connection request it first checks whether it has
-- already initiated a connection request to /A/. If not, it will acknowledge
-- the connection request by sending 'ConnectionRequestAccepted' to /A/ and
-- record that it has a TCP connection to /A/.
--
-- The tricky case arises when /A/ sends a connection request to /B/ and /B/
-- finds that it had already sent a connection request to /A/. In this case /B/
-- will accept the connection request from /A/ if /A/'s endpoint address is
-- smaller (lexicographically) than /B/'s, and reject it otherwise. If it
-- rejects it, it sends a 'ConnectionRequestCrossed' message to /A/. (The
-- lexicographical ordering is an arbitrary but convenient way to break the
-- tie.)
--
-- When it receives a 'ConnectionRequestCrossed' message, the /A/ thread that
-- initiated the request just needs to wait until the /A/ thread that is dealing
-- with /B/'s connection request completes.
--
-- [Disconnecting]
--
-- The TCP connection is created as soon as the first logical connection from
-- /A/ to /B/ (or /B/ to /A/) is established.
-- At this point a thread (@#@) is
-- spawned that listens for incoming connections from /B/:
--
-- > +-------+                          +-------+
-- > | A     |==========================| B     |
-- > |       |>~~~~~~~~~~~~~~~~~~~~~~~~~|~~~\   |
-- > |       |                          |   Q   |
-- > |      #|                          |       |
-- > |       |==========================|       |
-- > +-------+                          +-------+
--
-- The question is when the TCP connection can be closed again. Conceptually,
-- we want to do reference counting: when there are no logical connections left
-- between /A/ and /B/ we want to close the socket (possibly after some
-- timeout).
--
-- However, /A/ and /B/ need to agree that the refcount has reached zero. It
-- might happen that /B/ sends a connection request over the existing socket at
-- the same time that /A/ closes its logical connection to /B/ and closes the
-- socket. This would cause a failure in /B/ (which would then have to retry)
-- that is not caused by a network failure, which is unfortunate. (Note that
-- the connection request from /B/ might succeed even if /A/ closes the socket.)
--
-- Instead, when /A/ is ready to close the socket it sends a 'CloseSocket'
-- request to /B/ and records that its connection to /B/ is closing. If /A/
-- receives a new connection request from /B/ after having sent the
-- 'CloseSocket' request, it simply forgets that it sent a 'CloseSocket' request
-- and increments the reference count of the connection again.
--
-- When /B/ receives a 'CloseSocket' message and it too is ready to close the
-- connection, it will respond with a reciprocal 'CloseSocket' request to /A/
-- and then actually close the socket. /A/ meanwhile will not send any more
-- requests to /B/ after having sent a 'CloseSocket' request, and will actually
-- close its end of the socket only when receiving the 'CloseSocket' message
-- from /B/. (Since /A/ recorded that its connection to /B/ is in closing state
-- after sending a 'CloseSocket' request to /B/, it knows not to respond to /B/'s
-- reciprocal 'CloseSocket' message.)
--
-- If there is a concurrent thread in /A/ waiting to connect to /B/ after /A/
-- has sent a 'CloseSocket' request, then this thread will block until /A/
-- knows whether to reuse the old socket (if /B/ sends a new connection request
-- instead of acknowledging the 'CloseSocket') or to set up a new socket.

--------------------------------------------------------------------------------
-- Internal datatypes                                                         --
--------------------------------------------------------------------------------

-- We use underscores for fields that we might update (using accessors)
--
-- All data types follow the same structure:
--
-- * A top-level data type describing static properties (TCPTransport,
--   LocalEndPoint, RemoteEndPoint)
-- * The 'static' properties include an MVar containing a data structure for
--   the dynamic properties (TransportState, LocalEndPointState,
--   RemoteState). The state could be invalid/valid/closed/etc.
-- * For the case of "valid" we use a third data structure to give more
--   details about the state (ValidTransportState, ValidLocalEndPointState,
--   ValidRemoteEndPointState).
-- | Static description of the transport: the host and port it was created
-- with, its parameters, and an 'MVar' holding its dynamic state.
data TCPTransport = TCPTransport
  { transportHost   :: !N.HostName
  , transportPort   :: !N.ServiceName
  , transportState  :: !(MVar TransportState)
  , transportParams :: !TCPParameters
  }

-- | Dynamic state of the transport ('apiCloseTransport' moves it to
-- 'TransportClosed').
data TransportState =
    TransportValid !ValidTransportState
  | TransportClosed

-- | State of a transport that has not been closed.
data ValidTransportState = ValidTransportState
  { -- Local endpoints registered with this transport, keyed by their address
    _localEndPoints :: !(Map EndPointAddress LocalEndPoint)
    -- NOTE(review): presumably the ID given to the next endpoint created;
    -- the allocation code is not in this section
  , _nextEndPointId :: !EndPointId
  }

-- | A local endpoint: its address, the channel on which its events are
-- delivered (the endpoint's 'receive' reads from it), and its dynamic state.
data LocalEndPoint = LocalEndPoint
  { localAddress :: !EndPointAddress
  , localChannel :: !(Chan Event)
  , localState   :: !(MVar LocalEndPointState)
  }

-- | Dynamic state of a local endpoint ('apiCloseEndPoint' moves it to
-- 'LocalEndPointClosed').
data LocalEndPointState =
    LocalEndPointValid !ValidLocalEndPointState
  | LocalEndPointClosed

-- | State of a local endpoint that has not been closed.
data ValidLocalEndPointState = ValidLocalEndPointState
  { -- Next available ID for an outgoing lightweight self-connection
    -- (see also remoteNextConnOutId)
    _localNextConnOutId :: !LightweightConnectionId
    -- Next available ID for an incoming heavyweight connection
  , _nextConnInId :: !HeavyweightConnectionId
    -- Currently active outgoing heavyweight connections
    -- NOTE(review): remote endpoints set up for *incoming* requests
    -- (findRemoteEndPoint ... RequestedByThem) may also be stored here, since
    -- apiCloseEndPoint closes every entry; confirm whether "outgoing" is exact
  , _localConnections :: !(Map EndPointAddress RemoteEndPoint)
  }

-- REMOTE ENDPOINTS
--
-- Remote endpoints (basically, TCP connections) have the following lifecycle:
--
--   Init ---+---> Invalid
--           |
--           +-------------------------------\
--           |                               |
--           |       /----------\            |
--           |       |          |            |
--           |       v          |            v
--           +---> Valid ---> Closing ---> Closed
--           |       |          |            |
--           |       |          |            v
--           \-------+----------+--------> Failed
--
-- Init: There are two places where we create new remote endpoints: in
-- createConnectionTo (in response to an API 'connect' call) and in
-- handleConnectionRequest (when a remote node tries to connect to us).
-- 'Init' carries an MVar () 'resolved' which concurrent threads can use to
-- wait for the remote endpoint to finish initialization. We record who
-- requested the connection (the local endpoint or the remote endpoint).
--
-- Invalid: We put the remote endpoint in invalid state only during
-- createConnectionTo when we fail to connect.
--
-- Valid: This is the "normal" state for a working remote endpoint.
--
-- Closing: When we detect that a remote endpoint is no longer used, we send a
-- CloseSocket request across the connection and put the remote endpoint in
-- closing state. As with Init, 'Closing' carries an MVar () 'resolved' which
-- concurrent threads can use to wait for the remote endpoint to either be
-- closed fully (if the communication partner responds with another
-- CloseSocket) or be put back in 'Valid' state if the remote endpoint denies
-- the request.
--
-- We also put the endpoint in Closed state, directly from Init, if our
-- outbound connection request crossed an inbound connection request and we
-- decide to keep the inbound (i.e., the remote endpoint sent us a
-- ConnectionRequestCrossed message).
--
-- Closed: The endpoint is put in Closed state after a successful garbage
-- collection.
--
-- Failed: If the connection to the remote endpoint is lost, or the local
-- endpoint (or the whole transport) is closed manually, the remote endpoint is
-- put in Failed state, and we record the reason.
--
-- Invariants for dealing with remote endpoints:
--
-- INV-SEND: Whenever we send data the remote endpoint must be locked (to avoid
-- interleaving bits of payload).
--
-- INV-CLOSE: Local endpoints should never point to remote endpoints in closed
-- state. Whenever we put an endpoint in Closed state we remove that
-- endpoint from localConnections first, so that if a concurrent thread reads
-- the MVar, finds RemoteEndPointClosed, and then looks up the endpoint in
-- localConnections it is guaranteed to either find a different remote
-- endpoint, or else none at all (if we don't insist on this order some
-- threads might start spinning).
--
-- INV-RESOLVE: We should only signal on 'resolved' while the remote endpoint is
-- locked, and the remote endpoint must be in Valid or Closed state once
-- unlocked.
-- This guarantees that there will not be two threads attempting to
-- both signal on 'resolved'.
--
-- INV-LOST: If a send or recv fails, or a socket is closed unexpectedly, we
-- first put the remote endpoint in Closed state, and then send an
-- EventConnectionLost event. This guarantees that we only send this event
-- once.
--
-- INV-CLOSING: An endpoint in closing state is for all intents and purposes
-- closed; that is, we shouldn't do any 'send's on it (although 'recv' is
-- acceptable, of course -- as we are waiting for the remote endpoint to
-- confirm or deny the request).
--
-- INV-LOCK-ORDER: Remote endpoint must be locked before their local endpoints.
-- In other words: it is okay to call modifyMVar on a local endpoint inside a
-- modifyMVar on a remote endpoint, but not the other way around. In
-- particular, it is okay to call removeRemoteEndPoint inside
-- modifyRemoteState.

-- | A remote endpoint as seen from one of our local endpoints (in practice,
-- one heavyweight TCP connection).
data RemoteEndPoint = RemoteEndPoint
  { remoteAddress   :: !EndPointAddress
  , remoteState     :: !(MVar RemoteState)
  , remoteId        :: !HeavyweightConnectionId
    -- NOTE(review): presumably the queue behind 'schedule' /
    -- 'runScheduledAction'; neither definition is in this section
  , remoteScheduled :: !(Chan (IO ()))
  }

-- | Which side initiated the connection (recorded in 'RemoteEndPointInit').
data RequestedBy = RequestedByUs | RequestedByThem
  deriving (Eq, Show)

-- | Dynamic state of a remote endpoint (see the REMOTE ENDPOINTS lifecycle
-- notes).
data RemoteState =
    -- | Invalid remote endpoint (for example, invalid address)
    RemoteEndPointInvalid !(TransportError ConnectErrorCode)
    -- | The remote endpoint is being initialized
    --
    -- The first 'MVar' is the 'resolved' signal. NOTE(review): the role of
    -- the second 'MVar' is not visible in this section.
  | RemoteEndPointInit !(MVar ()) !(MVar ()) !RequestedBy
    -- | "Normal" working endpoint
  | RemoteEndPointValid !ValidRemoteEndPointState
    -- | The remote endpoint is being closed (garbage collected)
  | RemoteEndPointClosing !(MVar ()) !ValidRemoteEndPointState
    -- | The remote endpoint has been closed (garbage collected)
  | RemoteEndPointClosed
    -- | The remote endpoint has failed, or has been forcefully shutdown
    -- using a closeTransport or closeEndPoint API call
  | RemoteEndPointFailed !IOException

-- TODO: we might want to replace Set (here and elsewhere) by faster
-- containers
--
-- TODO: we could get rid of 'remoteIncoming' (and maintain less state) if
-- we introduce a new event 'AllConnectionsClosed'
data ValidRemoteEndPointState = ValidRemoteEndPointState
  { -- Number of our outgoing lightweight connections on this heavyweight
    -- connection ('apiClose' decrements it)
    _remoteOutgoing      :: !Int
    -- IDs of the lightweight connections the remote endpoint has open to us
  , _remoteIncoming      :: !(Set LightweightConnectionId)
    -- Most recently announced incoming lightweight connection ID; sent back
    -- along with our 'CloseSocket' requests
  , _remoteMaxIncoming   :: !LightweightConnectionId
    -- Next available ID for an outgoing lightweight connection
  , _remoteNextConnOutId :: !LightweightConnectionId
    -- The underlying TCP socket
  , remoteSocket         :: !N.Socket
    -- NOTE(review): presumably serialises writes on 'remoteSocket'
    -- (cf. INV-SEND); its users are not in this section
  , remoteSendLock       :: !(MVar ())
  }

-- | Local identifier for an endpoint within this transport
type EndPointId = Word32

-- | Pair of a local and a remote endpoint (for conciseness in signatures)
type EndPointPair = (LocalEndPoint, RemoteEndPoint)

-- | Lightweight connection ID (sender allocated)
--
-- A ConnectionId is the concatenation of a 'HeavyweightConnectionId' and a
-- 'LightweightConnectionId'.
type LightweightConnectionId = Word32

-- | Heavyweight connection ID (recipient allocated)
--
-- A ConnectionId is the concatenation of a 'HeavyweightConnectionId' and a
-- 'LightweightConnectionId'.
type HeavyweightConnectionId = Word32

-- | Control headers
--
-- Incoming headers are decoded with 'tryToEnum', so the constructor order is
-- part of the wire protocol: do not reorder.
data ControlHeader =
    -- | Tell the remote endpoint that we created a new connection
    CreatedNewConnection
    -- | Tell the remote endpoint we will no longer be using a connection
  | CloseConnection
    -- | Request to close the connection (see module description)
  | CloseSocket
  deriving (Enum, Bounded, Show)

-- | Response sent by /B/ to /A/ when /A/ tries to connect
--
-- Sent with 'encodeInt32'; presumably encoded by 'Enum' index, so keep the
-- constructor order stable.
data ConnectionRequestResponse =
    -- | /B/ accepts the connection
    ConnectionRequestAccepted
    -- | /A/ requested an invalid endpoint
  | ConnectionRequestInvalid
    -- | /A/'s request crossed with a request from /B/ (see protocols)
  | ConnectionRequestCrossed
  deriving (Enum, Bounded, Show)

-- | Parameters for setting up the TCP transport
data TCPParameters = TCPParameters {
    -- | Backlog for 'listen'.
    -- Defaults to SOMAXCONN.
    tcpBacklog :: Int
    -- | Should we set SO_REUSEADDR on the server socket?
    -- Defaults to True.
  , tcpReuseServerAddr :: Bool
    -- | Should we set SO_REUSEADDR on client sockets?
    -- Defaults to True.
  , tcpReuseClientAddr :: Bool
  }

-- | Internal functionality we expose for unit testing
data TransportInternals = TransportInternals
  { -- | The ID of the thread that listens for new incoming connections
    transportThread :: ThreadId
    -- | Find the socket between a local and a remote endpoint
  , socketBetween   :: EndPointAddress
                    -> EndPointAddress
                    -> IO N.Socket
  }

--------------------------------------------------------------------------------
-- Top-level functionality                                                    --
--------------------------------------------------------------------------------

-- | Create a TCP transport
--
-- Same as 'createTransportExposeInternals', minus the 'TransportInternals'.
createTransport :: N.HostName
                -> N.ServiceName
                -> TCPParameters
                -> IO (Either IOException Transport)
createTransport host port params =
  fmap fst <$> createTransportExposeInternals host port params

-- | You should probably not use this function (used for unit testing only)
--
-- Starts the server thread with 'forkServer', which hands incoming sockets to
-- 'handleConnectionRequest'. If building the result fails, the server thread
-- is killed again ('bracketOnError'); 'IOException's are returned as 'Left'
-- ('tryIO').
createTransportExposeInternals
  :: N.HostName
  -> N.ServiceName
  -> TCPParameters
  -> IO (Either IOException (Transport, TransportInternals))
createTransportExposeInternals host port params = do
    state <- newMVar . TransportValid $ ValidTransportState
      { _localEndPoints = Map.empty
      , _nextEndPointId = 0
      }
    let transport = TCPTransport
          { transportState  = state
          , transportHost   = host
          , transportPort   = port
          , transportParams = params
          }
    tryIO $ bracketOnError
      (forkServer host port
                  (tcpBacklog params)
                  (tcpReuseServerAddr params)
                  (terminationHandler transport)
                  (handleConnectionRequest transport))
      killThread
      (mkTransport transport)
  where
    mkTransport :: TCPTransport
                -> ThreadId
                -> IO (Transport, TransportInternals)
    mkTransport transport tid = return
      ( Transport
          { newEndPoint    = apiNewEndPoint transport
          , closeTransport =
              let evs = [ EndPointClosed
                        , throw $ userError "Transport closed"
                        ]
              in apiCloseTransport transport (Just tid) evs
          }
      , TransportInternals
          { transportThread = tid
          , socketBetween   = internalSocketBetween transport
          }
      )

    -- Passed to 'forkServer'; per the note in 'apiCloseTransport' it runs when
    -- the server thread terminates, closing the transport and reporting the
    -- exception as an 'ErrorEvent'.
    terminationHandler :: TCPTransport -> SomeException -> IO ()
    terminationHandler transport ex = do
      let evs = [ ErrorEvent (TransportError EventTransportFailed (show ex))
                , throw $ userError "Transport closed"
                ]
      apiCloseTransport transport Nothing evs

-- | Default TCP parameters
defaultTCPParameters :: TCPParameters
defaultTCPParameters = TCPParameters {
    tcpBacklog         = N.sOMAXCONN
  , tcpReuseServerAddr = True
  , tcpReuseClientAddr = True
  }

--------------------------------------------------------------------------------
-- API functions                                                              --
--------------------------------------------------------------------------------

-- | Close the transport
--
-- Marks the transport closed and force-closes every local endpoint that was
-- registered with it (each one gets @evs@ on its channel). If a server thread
-- is given it is killed last. On an already-closed transport only the
-- (optional) kill happens.
apiCloseTransport :: TCPTransport -> Maybe ThreadId -> [Event] -> IO ()
apiCloseTransport transport mTransportThread evs =
  asyncWhenCancelled return $ do
    mTSt <- modifyMVar (transportState transport) $ \st -> case st of
      TransportValid vst -> return (TransportClosed, Just vst)
      TransportClosed    -> return (TransportClosed, Nothing)
    forM_ mTSt $ mapM_ (apiCloseEndPoint transport evs) . (^. localEndPoints)
    -- This will invoke the termination handler, which in turn will call
    -- apiCloseTransport again, but then the transport will already be closed
    -- and we won't be passed a transport thread, so we terminate immediately
    forM_ mTransportThread killThread

-- | Create a new endpoint
--
-- Multicast is not supported: both multicast operations of the returned
-- 'EndPoint' return a "Multicast not supported" 'TransportError'.
apiNewEndPoint :: TCPTransport
               -> IO (Either (TransportError NewEndPointErrorCode) EndPoint)
apiNewEndPoint transport =
  try . asyncWhenCancelled closeEndPoint $ do
    ourEndPoint <- createLocalEndPoint transport
    return EndPoint
      { receive       = readChan (localChannel ourEndPoint)
      , address       = localAddress ourEndPoint
      , connect       = apiConnect (transportParams transport) ourEndPoint
      , closeEndPoint =
          let evs = [ EndPointClosed
                    , throw $ userError "Endpoint closed"
                    ]
          in apiCloseEndPoint transport evs ourEndPoint
      , newMulticastGroup     = return . Left $ newMulticastGroupError
      , resolveMulticastGroup = return . Left . const resolveMulticastGroupError
      }
  where
    newMulticastGroupError =
      TransportError NewMulticastGroupUnsupported "Multicast not supported"
    resolveMulticastGroupError =
      TransportError ResolveMulticastGroupUnsupported "Multicast not supported"

-- | Connect to an endpoint
--
-- Connecting to our own address is delegated to 'connectToSelf'; otherwise
-- the heavyweight connection and lightweight ID come from
-- 'createConnectionTo' (after 'resetIfBroken').
apiConnect :: TCPParameters    -- ^ Parameters
           -> LocalEndPoint    -- ^ Local end point
           -> EndPointAddress  -- ^ Remote address
           -> Reliability      -- ^ Reliability (ignored)
           -> ConnectHints     -- ^ Hints
           -> IO (Either (TransportError ConnectErrorCode) Connection)
apiConnect params ourEndPoint theirAddress _reliability hints =
  try . asyncWhenCancelled close $
    if localAddress ourEndPoint == theirAddress
      then connectToSelf ourEndPoint
      else do
        resetIfBroken ourEndPoint theirAddress
        (theirEndPoint, connId) <-
          createConnectionTo params ourEndPoint theirAddress hints
        -- connAlive can be an IORef rather than an MVar because it is protected
        -- by the remoteState MVar. We don't need the overhead of locking twice.
        connAlive <- newIORef True
        return Connection
          { send  = apiSend  (ourEndPoint, theirEndPoint) connId connAlive
          , close = apiClose (ourEndPoint, theirEndPoint) connId connAlive
          }

-- | Close a connection
--
-- If the remote endpoint is Valid and the connection is still alive, the
-- connection is marked dead, a 'CloseConnection' message is scheduled, and
-- the outgoing count is decremented. In all other states nothing is sent.
-- 'closeIfUnused' is called in every case. IO errors are ignored
-- (@void . tryIO@).
apiClose :: EndPointPair -> LightweightConnectionId -> IORef Bool -> IO ()
apiClose (ourEndPoint, theirEndPoint) connId connAlive =
  void . tryIO . asyncWhenCancelled return $ do
    mAct <- modifyMVar (remoteState theirEndPoint) $ \st -> case st of
      RemoteEndPointValid vst -> do
        alive <- readIORef connAlive
        if alive
          then do
            writeIORef connAlive False
            act <- schedule theirEndPoint $
              sendOn vst [encodeInt32 CloseConnection, encodeInt32 connId]
            return ( RemoteEndPointValid . (remoteOutgoing ^: subtract 1) $ vst
                   , Just act
                   )
          else
            return (RemoteEndPointValid vst, Nothing)
      _ ->
        return (st, Nothing)
    forM_ mAct $ runScheduledAction (ourEndPoint, theirEndPoint)
    closeIfUnused (ourEndPoint, theirEndPoint)

-- | Send data across a connection
--
-- The send is 'schedule'd while the remote state is held ('withMVar'), and
-- the resulting action is run by 'runScheduledAction' after the lock is
-- released. A connection we already closed yields 'SendClosed'. A failed
-- remote endpoint yields 'SendFailed', and so does any 'IOException'
-- (via 'mapIOException').
apiSend :: EndPointPair             -- ^ Local and remote endpoint
        -> LightweightConnectionId  -- ^ Connection ID
        -> IORef Bool               -- ^ Is the connection still alive?
        -> [ByteString]             -- ^ Payload
        -> IO (Either (TransportError SendErrorCode) ())
apiSend (ourEndPoint, theirEndPoint) connId connAlive payload =
    -- We don't need the overhead of asyncWhenCancelled here
    try . mapIOException sendFailed $ do
      act <- withMVar (remoteState theirEndPoint) $ \st -> case st of
        RemoteEndPointInvalid _ ->
          relyViolation (ourEndPoint, theirEndPoint) "apiSend"
        RemoteEndPointInit _ _ _ ->
          relyViolation (ourEndPoint, theirEndPoint) "apiSend"
        RemoteEndPointValid vst -> do
          alive <- readIORef connAlive
          if alive
            then schedule theirEndPoint $
                   sendOn vst (encodeInt32 connId : prependLength payload)
            else throwIO $ TransportError SendClosed "Connection closed"
        RemoteEndPointClosing _ _ -> do
          alive <- readIORef connAlive
          if alive
            then relyViolation (ourEndPoint, theirEndPoint) "apiSend"
            else throwIO $ TransportError SendClosed "Connection closed"
        RemoteEndPointClosed -> do
          alive <- readIORef connAlive
          if alive
            then relyViolation (ourEndPoint, theirEndPoint) "apiSend"
            else throwIO $ TransportError SendClosed "Connection closed"
        RemoteEndPointFailed err -> do
          alive <- readIORef connAlive
          if alive
            then throwIO $ TransportError SendFailed (show err)
            else throwIO $ TransportError SendClosed "Connection closed"
      runScheduledAction (ourEndPoint, theirEndPoint) act
  where
    sendFailed = TransportError SendFailed . show

-- | Force-close the endpoint
--
-- Unregisters the endpoint from the transport and marks it closed. Then it
-- tries to close every heavyweight connection in its 'localConnections' and
-- writes @evs@ to its channel. For an already-closed endpoint only the
-- unregistration is repeated.
apiCloseEndPoint :: TCPTransport    -- ^ Transport
                 -> [Event]         -- ^ Events used to report closure
                 -> LocalEndPoint   -- ^ Local endpoint
                 -> IO ()
apiCloseEndPoint transport evs ourEndPoint =
  asyncWhenCancelled return $ do
    -- Remove the reference from the transport state
    removeLocalEndPoint transport ourEndPoint
    -- Close the local endpoint
    mOurState <- modifyMVar (localState ourEndPoint) $ \st -> case st of
      LocalEndPointValid vst ->
        return (LocalEndPointClosed, Just vst)
      LocalEndPointClosed ->
        return (LocalEndPointClosed, Nothing)
    forM_ mOurState $ \vst -> do
      forM_ (vst ^. localConnections) tryCloseRemoteSocket
      forM_ evs $ writeChan (localChannel ourEndPoint)
  where
    -- Put the remote endpoint in Failed state and close its socket (if any).
    -- A Valid connection is first sent a 'CloseSocket' request (errors
    -- ignored). Threads waiting on 'resolved' (Init/Closing) are woken up.
    -- Invalid, Closed and Failed endpoints are left unchanged.
    tryCloseRemoteSocket :: RemoteEndPoint -> IO ()
    tryCloseRemoteSocket theirEndPoint = do
      -- We make an attempt to close the connection nicely
      -- (by sending a CloseSocket first)
      let closed = RemoteEndPointFailed . userError $ "apiCloseEndPoint"
      mAct <- modifyMVar (remoteState theirEndPoint) $ \st -> case st of
        RemoteEndPointInvalid _ ->
          return (st, Nothing)
        RemoteEndPointInit resolved _ _ -> do
          putMVar resolved ()
          return (closed, Nothing)
        RemoteEndPointValid vst -> do
          act <- schedule theirEndPoint $ do
            tryIO $ sendOn vst [ encodeInt32 CloseSocket
                               , encodeInt32 (vst ^. remoteMaxIncoming)
                               ]
            tryCloseSocket (remoteSocket vst)
          return (closed, Just act)
        RemoteEndPointClosing resolved vst -> do
          putMVar resolved ()
          act <- schedule theirEndPoint $ tryCloseSocket (remoteSocket vst)
          return (closed, Just act)
        RemoteEndPointClosed ->
          return (st, Nothing)
        RemoteEndPointFailed _ ->
          return (st, Nothing)
      forM_ mAct $ runScheduledAction (ourEndPoint, theirEndPoint)

--------------------------------------------------------------------------------
-- Incoming requests                                                          --
--------------------------------------------------------------------------------

-- | Handle a connection request (that is, a remote endpoint that is trying to
-- establish a TCP connection with us)
--
-- 'handleConnectionRequest' runs in the context of the transport thread, which
-- can be killed asynchronously by 'closeTransport'. We fork a separate thread
-- as soon as we have located the local endpoint that the remote endpoint is
-- interested in. We cannot fork any sooner because then we have no way of
-- storing the thread ID and hence no way of killing the thread when we take
-- the transport down. We must be careful to close the socket when a (possibly
-- asynchronous, ThreadKilled) exception occurs.
(If an exception escapes from -- handleConnectionRequest the transport will be shut down.) handleConnectionRequest :: TCPTransport -> N.Socket -> IO () handleConnectionRequest transport sock = handle handleException $ do ourEndPointId <- recvInt32 sock theirAddress <- EndPointAddress . BS.concat <$> recvWithLength sock let ourAddress = encodeEndPointAddress (transportHost transport) (transportPort transport) ourEndPointId ourEndPoint <- withMVar (transportState transport) $ \st -> case st of TransportValid vst -> case vst ^. localEndPointAt ourAddress of Nothing -> do sendMany sock [encodeInt32 ConnectionRequestInvalid] throwIO $ userError "handleConnectionRequest: Invalid endpoint" Just ourEndPoint -> return ourEndPoint TransportClosed -> throwIO $ userError "Transport closed" void . forkIO $ go ourEndPoint theirAddress where go :: LocalEndPoint -> EndPointAddress -> IO () go ourEndPoint theirAddress = do -- This runs in a thread that will never be killed mEndPoint <- handle ((>> return Nothing) . handleException) $ do resetIfBroken ourEndPoint theirAddress (theirEndPoint, isNew) <- findRemoteEndPoint ourEndPoint theirAddress RequestedByThem if not isNew then do tryIO $ sendMany sock [encodeInt32 ConnectionRequestCrossed] tryCloseSocket sock return Nothing else do sendLock <- newMVar () let vst = ValidRemoteEndPointState { remoteSocket = sock , remoteSendLock = sendLock , _remoteOutgoing = 0 , _remoteIncoming = Set.empty , _remoteMaxIncoming = 0 , _remoteNextConnOutId = firstNonReservedLightweightConnectionId } sendMany sock [encodeInt32 ConnectionRequestAccepted] resolveInit (ourEndPoint, theirEndPoint) (RemoteEndPointValid vst) return (Just theirEndPoint) -- If we left the scope of the exception handler with a return value of -- Nothing then the socket is already closed; otherwise, the socket has -- been recorded as part of the remote endpoint. 
Either way, we no longer -- have to worry about closing the socket on receiving an asynchronous -- exception from this point forward. forM_ mEndPoint $ handleIncomingMessages . (,) ourEndPoint handleException :: SomeException -> IO () handleException ex = do tryCloseSocket sock rethrowIfAsync (fromException ex) rethrowIfAsync :: Maybe AsyncException -> IO () rethrowIfAsync = mapM_ throwIO -- | Handle requests from a remote endpoint. -- -- Returns only if the remote party closes the socket or if an error occurs. -- This runs in a thread that will never be killed. handleIncomingMessages :: EndPointPair -> IO () handleIncomingMessages (ourEndPoint, theirEndPoint) = do mSock <- withMVar theirState $ \st -> case st of RemoteEndPointInvalid _ -> relyViolation (ourEndPoint, theirEndPoint) "handleIncomingMessages (invalid)" RemoteEndPointInit _ _ _ -> relyViolation (ourEndPoint, theirEndPoint) "handleIncomingMessages (init)" RemoteEndPointValid ep -> return . Just $ remoteSocket ep RemoteEndPointClosing _ ep -> return . Just $ remoteSocket ep RemoteEndPointClosed -> return Nothing RemoteEndPointFailed _ -> return Nothing forM_ mSock $ \sock -> tryIO (go sock) >>= either (prematureExit sock) return where -- Dispatch -- -- If a recv throws an exception this will be caught top-level and -- 'prematureExit' will be invoked. The same will happen if the remote -- endpoint is put into a Closed (or Closing) state by a concurrent thread -- (because a 'send' failed) -- the individual handlers below will throw a -- user exception which is then caught and handled the same way as an -- exception thrown by 'recv'. 
-- Remaining local definitions of the enclosing incoming-message handler
-- (its @where@ clause); they close over @ourEndPoint@ and @theirEndPoint@.

    -- Dispatch loop over one socket.
    --
    -- Reads a 32-bit header. Values at or above
    -- 'firstNonReservedLightweightConnectionId' introduce a data message for
    -- that lightweight connection; smaller values must decode to a control
    -- header, each of which is followed by one more 32-bit argument.
    -- The loop stops once 'closeSocket' reports that the socket was closed.
    go :: N.Socket -> IO ()
    go sock = do
      lcid <- recvInt32 sock :: IO LightweightConnectionId
      if lcid >= firstNonReservedLightweightConnectionId
        then do
          readMessage sock lcid
          go sock
        else
          case tryToEnum (fromIntegral lcid) of
            Just CreatedNewConnection -> do
              recvInt32 sock >>= createdNewConnection
              go sock
            Just CloseConnection -> do
              recvInt32 sock >>= closeConnection
              go sock
            Just CloseSocket -> do
              didClose <- recvInt32 sock >>= closeSocket sock
              unless didClose $ go sock
            Nothing ->
              throwIO $ userError "Invalid control request"

    -- Create a new connection: the remote endpoint opened lightweight
    -- connection @lcid@ to us. Record it as incoming (and as the highest
    -- incoming ID seen so far) and emit 'ConnectionOpened'.
    createdNewConnection :: LightweightConnectionId -> IO ()
    createdNewConnection lcid = do
      modifyMVar_ theirState $ \st -> do
        vst <- case st of
          RemoteEndPointInvalid _ ->
            relyViolation (ourEndPoint, theirEndPoint)
              "handleIncomingMessages:createNewConnection (invalid)"
          RemoteEndPointInit _ _ _ ->
            relyViolation (ourEndPoint, theirEndPoint)
              "handleIncomingMessages:createNewConnection (init)"
          RemoteEndPointValid vst ->
            return ( (remoteIncoming ^: Set.insert lcid)
                   $ (remoteMaxIncoming ^= lcid) vst
                   )
          RemoteEndPointClosing resolved vst -> do
            -- If the endpoint is in closing state that means we sent a
            -- CloseSocket request to the remote endpoint. If the remote
            -- endpoint replies that it created a new connection, it either
            -- ignored our request or it sent the request before it got ours.
            -- Either way, at this point we simply restore the endpoint to
            -- RemoteEndPointValid (releasing anyone blocked on 'resolved').
            putMVar resolved ()
            return ( (remoteIncoming ^= Set.singleton lcid)
                   . (remoteMaxIncoming ^= lcid)
                   $ vst
                   )
          RemoteEndPointFailed err ->
            throwIO err
          RemoteEndPointClosed ->
            relyViolation (ourEndPoint, theirEndPoint)
              "createNewConnection (closed)"
        return (RemoteEndPointValid vst)
      writeChan ourChannel (ConnectionOpened (connId lcid) ReliableOrdered theirAddr)

    -- Close incoming lightweight connection @lcid@ and emit
    -- 'ConnectionClosed'.
    --
    -- It is important that we verify that the connection is in fact open,
    -- because otherwise we should not decrement the reference count.
    closeConnection :: LightweightConnectionId -> IO ()
    closeConnection lcid = do
      modifyMVar_ theirState $ \st -> case st of
        RemoteEndPointInvalid _ ->
          relyViolation (ourEndPoint, theirEndPoint) "closeConnection (invalid)"
        RemoteEndPointInit _ _ _ ->
          relyViolation (ourEndPoint, theirEndPoint) "closeConnection (init)"
        RemoteEndPointValid vst -> do
          unless (Set.member lcid (vst ^. remoteIncoming)) $
            throwIO $ userError "Invalid CloseConnection"
          return ( RemoteEndPointValid
                 . (remoteIncoming ^: Set.delete lcid)
                 $ vst
                 )
        RemoteEndPointClosing _ _ ->
          -- If the remote endpoint is in Closing state, that means that, as
          -- far as we are concerned, there are no incoming connections. This
          -- means that a CloseConnection request at this point is invalid.
          throwIO $ userError "Invalid CloseConnection request"
        RemoteEndPointFailed err ->
          throwIO err
        RemoteEndPointClosed ->
          relyViolation (ourEndPoint, theirEndPoint) "closeConnection (closed)"
      writeChan ourChannel (ConnectionClosed (connId lcid))

    -- Close the socket (if we don't have any outgoing connections).
    --
    -- @lastReceivedId@ is the argument of the remote CloseSocket request;
    -- 'closeIfUnused' shows the sender fills it with its 'remoteMaxIncoming',
    -- i.e. the last connection ID it had received from us.
    -- Returns True if the socket was closed (the endpoint is now Closed), and
    -- False if the request was ignored.
    closeSocket :: N.Socket -> LightweightConnectionId -> IO Bool
    closeSocket sock lastReceivedId = do
      mAct <- modifyMVar theirState $ \st -> do
        case st of
          RemoteEndPointInvalid _ ->
            relyViolation (ourEndPoint, theirEndPoint)
              "handleIncomingMessages:closeSocket (invalid)"
          RemoteEndPointInit _ _ _ ->
            relyViolation (ourEndPoint, theirEndPoint)
              "handleIncomingMessages:closeSocket (init)"
          RemoteEndPointValid vst -> do
            -- We regard a CloseSocket message as an (optimized) way for the
            -- remote endpoint to indicate that all its connections to us are
            -- now properly closed
            forM_ (Set.elems $ vst ^. remoteIncoming) $
              writeChan ourChannel . ConnectionClosed . connId
            let vst' = remoteIncoming ^= Set.empty $ vst
            -- If we still have outgoing connections then we ignore the
            -- CloseSocket request (we sent a ConnectionCreated message to the
            -- remote endpoint, but it did not receive it before sending the
            -- CloseSocket request). Similarly, if lastReceivedId < lastSentId
            -- then we sent a ConnectionCreated *AND* a ConnectionClosed
            -- message to the remote endpoint, *both of which* it did not yet
            -- receive before sending the CloseSocket request.
            if vst' ^. remoteOutgoing > 0 || lastReceivedId < lastSentId vst
              then
                return (RemoteEndPointValid vst', Nothing)
              else do
                removeRemoteEndPoint (ourEndPoint, theirEndPoint)
                -- Attempt to reply (but don't insist)
                act <- schedule theirEndPoint $ do
                  tryIO $ sendOn vst' [ encodeInt32 CloseSocket
                                      , encodeInt32 (vst ^. remoteMaxIncoming)
                                      ]
                  tryCloseSocket sock
                return (RemoteEndPointClosed, Just act)
          RemoteEndPointClosing resolved vst -> do
            -- Like above, we need to check if there is a ConnectionCreated
            -- message that we sent but that the remote endpoint has not yet
            -- received. However, since we are in 'closing' state, the only
            -- way this may happen is when we sent a ConnectionCreated,
            -- ConnectionClosed, and CloseSocket message, none of which have
            -- yet been received. We leave the endpoint in closing state in
            -- that case.
            if lastReceivedId < lastSentId vst
              then do
                return (RemoteEndPointClosing resolved vst, Nothing)
              else do
                removeRemoteEndPoint (ourEndPoint, theirEndPoint)
                act <- schedule theirEndPoint $ tryCloseSocket sock
                putMVar resolved ()
                return (RemoteEndPointClosed, Just act)
          RemoteEndPointFailed err ->
            throwIO err
          RemoteEndPointClosed ->
            relyViolation (ourEndPoint, theirEndPoint)
              "handleIncomingMessages:closeSocket (closed)"
      case mAct of
        Nothing -> return False
        Just act -> do
          runScheduledAction (ourEndPoint, theirEndPoint) act
          return True

    -- Read a message and output it on the endPoint's channel. By rights we
    -- should verify that the connection ID is valid, but this is unnecessary
    -- overhead.
    readMessage :: N.Socket -> LightweightConnectionId -> IO ()
    readMessage sock lcid =
      recvWithLength sock >>= writeChan ourChannel . Received (connId lcid)

    -- Arguments
    ourChannel = localChannel ourEndPoint
    theirState = remoteState theirEndPoint
    theirAddr  = remoteAddress theirEndPoint

    -- Deal with a premature exit: close the socket and move the remote
    -- endpoint to Failed. Only a Valid endpoint reports
    -- 'EventConnectionLost'; a Closing endpoint releases its 'resolved'
    -- waiters instead.
    prematureExit :: N.Socket -> IOException -> IO ()
    prematureExit sock err = do
      tryCloseSocket sock
      modifyMVar_ theirState $ \st ->
        case st of
          RemoteEndPointInvalid _ ->
            relyViolation (ourEndPoint, theirEndPoint)
              "handleIncomingMessages:prematureExit"
          RemoteEndPointInit _ _ _ ->
            relyViolation (ourEndPoint, theirEndPoint)
              "handleIncomingMessages:prematureExit"
          RemoteEndPointValid _ -> do
            let code = EventConnectionLost (remoteAddress theirEndPoint)
            writeChan ourChannel . ErrorEvent $ TransportError code (show err)
            return (RemoteEndPointFailed err)
          RemoteEndPointClosing resolved _ -> do
            putMVar resolved ()
            return (RemoteEndPointFailed err)
          RemoteEndPointClosed ->
            relyViolation (ourEndPoint, theirEndPoint)
              "handleIncomingMessages:prematureExit"
          RemoteEndPointFailed err' ->
            return (RemoteEndPointFailed err')

    -- Construct a connection ID
    connId :: LightweightConnectionId -> ConnectionId
    connId = createConnectionId (remoteId theirEndPoint)

    -- The ID of the last connection _we_ created (or 0 for none)
    lastSentId :: ValidRemoteEndPointState -> LightweightConnectionId
    lastSentId vst =
      if vst ^. remoteNextConnOutId == firstNonReservedLightweightConnectionId
        then 0
        else (vst ^. remoteNextConnOutId) - 1

--------------------------------------------------------------------------------
-- Uninterruptible auxiliary functions                                        --
--                                                                            --
-- All these functions assume they are running in a thread which will never  --
-- be killed.                                                                 --
--------------------------------------------------------------------------------

-- | Create a connection to a remote endpoint
--
-- If the remote endpoint is new, it is set up in a separate thread (whose
-- exceptions are discarded) and we retry. If the remote endpoint is in
-- 'RemoteEndPointClosing' state then we will block until that is resolved.
-- Otherwise we allocate the next outgoing lightweight connection ID and send
-- a CreatedNewConnection message for it.
--
-- May throw a TransportError ConnectErrorCode exception.
createConnectionTo :: TCPParameters
                   -> LocalEndPoint
                   -> EndPointAddress
                   -> ConnectHints
                   -> IO (RemoteEndPoint, LightweightConnectionId)
createConnectionTo params ourEndPoint theirAddress hints = go
  where
    go = do
      (theirEndPoint, isNew) <- mapIOException connectFailed $
        findRemoteEndPoint ourEndPoint theirAddress RequestedByUs

      if isNew
        then do
          _ <- forkIO . handle absorbAllExceptions $
            setupRemoteEndPoint params (ourEndPoint, theirEndPoint) hints
          go
        else do
          -- 'findRemoteEndPoint' will have increased 'remoteOutgoing'
          mapIOException connectFailed $ do
            act <- modifyMVar (remoteState theirEndPoint) $ \st -> case st of
              RemoteEndPointValid vst -> do
                let connId = vst ^. remoteNextConnOutId
                act <- schedule theirEndPoint $ do
                  sendOn vst [encodeInt32 CreatedNewConnection, encodeInt32 connId]
                  return connId
                return ( RemoteEndPointValid
                       $ remoteNextConnOutId ^= connId + 1
                       $ vst
                       , act
                       )
              -- Error cases
              RemoteEndPointInvalid err ->
                throwIO err
              RemoteEndPointFailed err ->
                throwIO err
              -- Algorithmic errors
              _ ->
                relyViolation (ourEndPoint, theirEndPoint) "createConnectionTo"
            -- TODO: deal with exception case?
            connId <- runScheduledAction (ourEndPoint, theirEndPoint) act
            return (theirEndPoint, connId)

    connectFailed :: IOException -> TransportError ConnectErrorCode
    connectFailed = TransportError ConnectFailed . show

    absorbAllExceptions :: SomeException -> IO ()
    absorbAllExceptions _ex =
      return ()

-- | Set up a remote endpoint
--
-- Connects to the remote endpoint and resolves its Init state according to
-- the response:
--
-- * Accepted: the endpoint becomes Valid and this thread goes on to run
--   'handleIncomingMessages' for it.
--
-- * Invalid, or a connection error: the endpoint becomes Invalid.
--
-- * Crossed: the endpoint's @crossed@ MVar is filled.
--
-- The new socket is closed whenever it is not handed over to a Valid
-- endpoint state, including when updating the remote state throws (for
-- instance because the endpoint was concurrently marked as Failed).
setupRemoteEndPoint :: TCPParameters -> EndPointPair -> ConnectHints -> IO ()
setupRemoteEndPoint params (ourEndPoint, theirEndPoint) hints = do
    result <- socketToEndPoint ourAddress
                               theirAddress
                               (tcpReuseClientAddr params)
                               (connectTimeout hints)
    didAccept <- case result of
      Right (sock, ConnectionRequestAccepted) -> do
        sendLock <- newMVar ()
        let vst = ValidRemoteEndPointState
                    { remoteSocket         = sock
                    , remoteSendLock       = sendLock
                    , _remoteOutgoing      = 0
                    , _remoteIncoming      = Set.empty
                    , _remoteMaxIncoming   = 0
                    , _remoteNextConnOutId = firstNonReservedLightweightConnectionId
                    }
        -- If the state cannot be resolved the socket is never installed,
        -- so it must not outlive the exception
        closeOnError sock $
          resolveInit (ourEndPoint, theirEndPoint) (RemoteEndPointValid vst)
        return True
      Right (sock, ConnectionRequestInvalid) -> do
        let err = invalidAddress "setupRemoteEndPoint: Invalid endpoint"
        closeOnError sock $
          resolveInit (ourEndPoint, theirEndPoint) (RemoteEndPointInvalid err)
        tryCloseSocket sock
        return False
      Right (sock, ConnectionRequestCrossed) -> do
        closeOnError sock $
          withMVar (remoteState theirEndPoint) $ \st -> case st of
            RemoteEndPointInit _ crossed _ ->
              putMVar crossed ()
            RemoteEndPointFailed ex ->
              throwIO ex
            _ ->
              relyViolation (ourEndPoint, theirEndPoint) "setupRemoteEndPoint: Crossed"
        tryCloseSocket sock
        return False
      Left err -> do
        resolveInit (ourEndPoint, theirEndPoint) (RemoteEndPointInvalid err)
        return False

    when didAccept $ handleIncomingMessages (ourEndPoint, theirEndPoint)
  where
    ourAddress     = localAddress ourEndPoint
    theirAddress   = remoteAddress theirEndPoint
    invalidAddress = TransportError ConnectNotFound

    -- Run an action; if it throws, close the socket before re-raising.
    closeOnError :: N.Socket -> IO a -> IO a
    closeOnError sock act =
      bracketOnError (return ()) (\() -> tryCloseSocket sock) (\() -> act)

-- | Send a CloseSocket request if the remote endpoint is unused
--
-- "Unused" means no outgoing and no incoming lightweight connections; the
-- endpoint then moves to Closing. Any other state is left untouched.
closeIfUnused :: EndPointPair -> IO ()
closeIfUnused (ourEndPoint, theirEndPoint) = do
  mAct <- modifyMVar (remoteState theirEndPoint) $ \st -> case st of
    RemoteEndPointValid vst ->
      if vst ^. remoteOutgoing == 0 && Set.null (vst ^. remoteIncoming)
        then do
          resolved <- newEmptyMVar
          act <- schedule theirEndPoint $
            sendOn vst [ encodeInt32 CloseSocket
                       , encodeInt32 (vst ^. remoteMaxIncoming)
                       ]
          return (RemoteEndPointClosing resolved vst, Just act)
        else
          return (RemoteEndPointValid vst, Nothing)
    _ ->
      return (st, Nothing)
  forM_ mAct $ runScheduledAction (ourEndPoint, theirEndPoint)

-- | Reset a remote endpoint if it is in Invalid or Failed state
--
-- If the remote endpoint is currently in a broken (Invalid or Failed) state,
-- and
--
-- - a user calls the API function 'connect', or
-- - an inbound connection request comes in from this remote address,
--
-- we remove the remote endpoint first.
--
-- Throws a TransportError ConnectFailed exception if the local endpoint is
-- closed.
resetIfBroken :: LocalEndPoint -> EndPointAddress -> IO ()
resetIfBroken ourEndPoint theirAddress = do
  mTheirEndPoint <- withMVar (localState ourEndPoint) $ \st -> case st of
    LocalEndPointValid vst ->
      return (vst ^. localConnectionTo theirAddress)
    LocalEndPointClosed ->
      throwIO $ TransportError ConnectFailed "Endpoint closed"
  forM_ mTheirEndPoint $ \theirEndPoint ->
    withMVar (remoteState theirEndPoint) $ \st -> case st of
      RemoteEndPointInvalid _ ->
        removeRemoteEndPoint (ourEndPoint, theirEndPoint)
      RemoteEndPointFailed _ ->
        removeRemoteEndPoint (ourEndPoint, theirEndPoint)
      _ ->
        return ()

-- | Special case of 'apiConnect': connect an endpoint to itself
--
-- Messages sent on the returned connection are written directly to the
-- local endpoint's channel; no socket is involved.
--
-- May throw a TransportError ConnectErrorCode (if the local endpoint is closed)
connectToSelf :: LocalEndPoint -> IO Connection
connectToSelf ourEndPoint = do
    connAlive <- newIORef True  -- Protected by the local endpoint lock
    lconnId   <- mapIOException connectFailed $ getLocalNextConnOutId ourEndPoint
    let connId = createConnectionId heavyweightSelfConnectionId lconnId
    writeChan ourChan $
      ConnectionOpened connId ReliableOrdered (localAddress ourEndPoint)
    return Connection
      { send  = selfSend connAlive connId
      , close = selfClose connAlive connId
      }
  where
    selfSend :: IORef Bool
             -> ConnectionId
             -> [ByteString]
             -> IO (Either (TransportError SendErrorCode) ())
    selfSend connAlive connId msg =
      try . withMVar ourState $ \st -> case st of
        LocalEndPointValid _ -> do
          alive <- readIORef connAlive
          if alive
            then writeChan ourChan (Received connId msg)
            else throwIO $ TransportError SendClosed "Connection closed"
        LocalEndPointClosed ->
          throwIO $ TransportError SendFailed "Endpoint closed"

    -- Idempotent: only the first call emits 'ConnectionClosed'
    selfClose :: IORef Bool -> ConnectionId -> IO ()
    selfClose connAlive connId =
      withMVar ourState $ \st -> case st of
        LocalEndPointValid _ -> do
          alive <- readIORef connAlive
          when alive $ do
            writeChan ourChan (ConnectionClosed connId)
            writeIORef connAlive False
        LocalEndPointClosed ->
          return ()

    ourChan  = localChannel ourEndPoint
    ourState = localState ourEndPoint
    connectFailed = TransportError ConnectFailed . show

-- | Resolve an endpoint currently in 'Init' state
--
-- Releases the Init state's @resolved@ waiters and installs @newState@. If
-- the new state is Closed, the endpoint is also removed from the local
-- endpoint. Rethrows the stored error if the endpoint has already Failed.
resolveInit :: EndPointPair -> RemoteState -> IO ()
resolveInit (ourEndPoint, theirEndPoint) newState =
  modifyMVar_ (remoteState theirEndPoint) $ \st -> case st of
    RemoteEndPointInit resolved _ _ -> do
      putMVar resolved ()
      case newState of
        RemoteEndPointClosed ->
          removeRemoteEndPoint (ourEndPoint, theirEndPoint)
        _ ->
          return ()
      return newState
    RemoteEndPointFailed ex ->
      throwIO ex
    _ ->
      relyViolation (ourEndPoint, theirEndPoint) "resolveInit"

-- | Get the next outgoing self-connection ID
--
-- Throws an IO exception when the endpoint is closed.
getLocalNextConnOutId :: LocalEndPoint -> IO LightweightConnectionId
getLocalNextConnOutId ourEndpoint =
  modifyMVar (localState ourEndpoint) $ \st -> case st of
    LocalEndPointValid vst -> do
      let connId = vst ^. localNextConnOutId
      return ( LocalEndPointValid
             . (localNextConnOutId ^= connId + 1)
             $ vst
             , connId)
    LocalEndPointClosed ->
      throwIO $ userError "Local endpoint closed"

-- | Create a new local endpoint
--
-- The endpoint gets the transport's next endpoint ID, and its address is
-- built from the transport's host, port and that ID.
--
-- May throw a TransportError NewEndPointErrorCode exception if the transport
-- is closed.
createLocalEndPoint :: TCPTransport -> IO LocalEndPoint
createLocalEndPoint transport = do
    chan  <- newChan
    state <- newMVar . LocalEndPointValid $ ValidLocalEndPointState
      { _localNextConnOutId = firstNonReservedLightweightConnectionId
      , _localConnections   = Map.empty
      , _nextConnInId       = firstNonReservedHeavyweightConnectionId
      }
    modifyMVar (transportState transport) $ \st -> case st of
      TransportValid vst -> do
        let ix   = vst ^. nextEndPointId
        let addr = encodeEndPointAddress (transportHost transport)
                                         (transportPort transport)
                                         ix
        let localEndPoint = LocalEndPoint { localAddress = addr
                                          , localChannel = chan
                                          , localState   = state
                                          }
        return ( TransportValid
               . (localEndPointAt addr ^= Just localEndPoint)
               . (nextEndPointId ^= ix + 1)
               $ vst
               , localEndPoint
               )
      TransportClosed ->
        throwIO (TransportError NewEndPointFailed "Transport closed")

-- | Remove reference to a remote endpoint from a local endpoint
--
-- The entry is only removed if it still refers to this very remote endpoint
-- (same 'remoteId'); a newer endpoint for the same address is left alone.
-- If the local endpoint is closed, do nothing.
removeRemoteEndPoint :: EndPointPair -> IO ()
removeRemoteEndPoint (ourEndPoint, theirEndPoint) =
    modifyMVar_ ourState $ \st -> case st of
      LocalEndPointValid vst ->
        case vst ^. localConnectionTo theirAddress of
          Nothing ->
            return st
          Just remoteEndPoint' ->
            if remoteId remoteEndPoint' == remoteId theirEndPoint
              then return
                ( LocalEndPointValid
                . (localConnectionTo theirAddress ^= Nothing)
                $ vst
                )
              else return st
      LocalEndPointClosed ->
        return LocalEndPointClosed
  where
    ourState     = localState ourEndPoint
    theirAddress = remoteAddress theirEndPoint

-- | Remove reference to a local endpoint from the transport state
--
-- Does nothing if the transport is closed
removeLocalEndPoint :: TCPTransport -> LocalEndPoint -> IO ()
removeLocalEndPoint transport ourEndPoint =
  modifyMVar_ (transportState transport) $ \st -> case st of
    TransportValid vst ->
      return ( TransportValid
             . (localEndPointAt (localAddress ourEndPoint) ^= Nothing)
             $ vst
             )
    TransportClosed ->
      return TransportClosed

-- | Find a remote endpoint. If the remote endpoint does not yet exist we
-- create it in Init state.
--
-- The returned Bool is True when the endpoint was newly created. It is also
-- True in the crossed case: an incoming request finds our own pending
-- outgoing request and our address is the larger one, so we first wait for
-- the Crossed message. For an existing endpoint requested by us, the
-- 'remoteOutgoing' count is incremented. Closing and Init endpoints are
-- waited for and the lookup is retried.
findRemoteEndPoint
  :: LocalEndPoint
  -> EndPointAddress
  -> RequestedBy
  -> IO (RemoteEndPoint, Bool)
findRemoteEndPoint ourEndPoint theirAddress findOrigin = go
  where
    go = do
      (theirEndPoint, isNew) <- modifyMVar ourState $ \st -> case st of
        LocalEndPointValid vst ->
          case vst ^. localConnectionTo theirAddress of
            Just theirEndPoint ->
              return (st, (theirEndPoint, False))
            Nothing -> do
              resolved   <- newEmptyMVar
              crossed    <- newEmptyMVar
              theirState <- newMVar (RemoteEndPointInit resolved crossed findOrigin)
              scheduled  <- newChan
              let theirEndPoint = RemoteEndPoint
                    { remoteAddress   = theirAddress
                    , remoteState     = theirState
                    , remoteId        = vst ^. nextConnInId
                    , remoteScheduled = scheduled
                    }
              return ( LocalEndPointValid
                     . (localConnectionTo theirAddress ^= Just theirEndPoint)
                     . (nextConnInId ^: (+ 1))
                     $ vst
                     , (theirEndPoint, True)
                     )
        LocalEndPointClosed ->
          throwIO $ userError "Local endpoint closed"

      if isNew
        then
          return (theirEndPoint, True)
        else do
          let theirState = remoteState theirEndPoint
          snapshot <- modifyMVar theirState $ \st -> case st of
            RemoteEndPointValid vst ->
              case findOrigin of
                RequestedByUs -> do
                  let st' = RemoteEndPointValid
                          . (remoteOutgoing ^: (+ 1))
                          $ vst
                  return (st', st')
                RequestedByThem ->
                  return (st, st)
            _ ->
              return (st, st)
          -- The snapshot may no longer be up to date at this point, but if we
          -- increased the refcount then it can only either be Valid or Failed
          -- (after an explicit call to 'closeEndPoint' or 'closeTransport')
          case snapshot of
            RemoteEndPointInvalid err ->
              throwIO err
            RemoteEndPointInit resolved crossed initOrigin ->
              case (findOrigin, initOrigin) of
                (RequestedByUs, RequestedByUs) ->
                  readMVar resolved >> go
                (RequestedByUs, RequestedByThem) ->
                  readMVar resolved >> go
                (RequestedByThem, RequestedByUs) ->
                  if ourAddress > theirAddress
                    then do
                      -- Wait for the Crossed message
                      readMVar crossed
                      return (theirEndPoint, True)
                    else
                      return (theirEndPoint, False)
                (RequestedByThem, RequestedByThem) ->
                  throwIO $ userError "Already connected"
            RemoteEndPointValid _ ->
              -- We assume that the request crossed if we find the endpoint in
              -- Valid state. It is possible that this is really an invalid
              -- request, but only in the case of a broken client (we don't
              -- maintain enough history to be able to tell the difference).
              return (theirEndPoint, False)
            RemoteEndPointClosing resolved _ ->
              readMVar resolved >> go
            RemoteEndPointClosed ->
              go
            RemoteEndPointFailed err ->
              throwIO err

    ourState   = localState ourEndPoint
    ourAddress = localAddress ourEndPoint

-- | Send a payload over a heavyweight connection (thread safe: sends are
-- serialized by the endpoint's send lock)
sendOn :: ValidRemoteEndPointState -> [ByteString] -> IO ()
sendOn vst bs = withMVar (remoteSendLock vst) $ \() ->
  sendMany (remoteSocket vst) bs

--------------------------------------------------------------------------------
-- Scheduling actions                                                         --
--------------------------------------------------------------------------------

-- | See 'schedule'/'runScheduledAction'
type Action a = MVar (Either SomeException a)

-- | Schedule an action to be executed (see also 'runScheduledAction')
--
-- The action is queued on the remote endpoint's schedule; its result, or
-- the exception it throws, is stored in the returned MVar.
schedule :: RemoteEndPoint -> IO a -> IO (Action a)
schedule theirEndPoint act = do
  mvar <- newEmptyMVar
  writeChan (remoteScheduled theirEndPoint) $
    catch (act >>= putMVar mvar . Right) (putMVar mvar . Left)
  return mvar

-- | Run a scheduled action. Every call to 'schedule' should be paired with a
-- call to 'runScheduledAction' so that every scheduled action is run. Note
-- however that there is no guarantee that in
--
-- > do act <- schedule theirEndPoint p
-- >    runScheduledAction (ourEndPoint, theirEndPoint) act
--
-- 'runScheduledAction' will run @p@ (it might run some other scheduled action).
-- However, it will then wait until @p@ is executed (by this call to
-- 'runScheduledAction' or by another).
runScheduledAction :: EndPointPair -> Action a -> IO a
runScheduledAction (ourEndPoint, theirEndPoint) mvar = do
    -- Run whichever action is next in the queue (not necessarily ours) ...
    join $ readChan (remoteScheduled theirEndPoint)
    -- ... then wait for *our* action's outcome
    ma <- readMVar mvar
    case ma of
      Right a -> return a
      Left e -> do
        -- An IOException marks the remote endpoint as Failed before the
        -- exception is re-raised to the caller
        forM_ (fromException e) $ \ioe ->
          modifyMVar_ (remoteState theirEndPoint) $ \st -> case st of
            RemoteEndPointValid vst -> handleIOException ioe vst
            _ -> return (RemoteEndPointFailed ioe)
        throwIO e
  where
    -- Close the socket and report 'EventConnectionLost' on the local channel
    handleIOException :: IOException -> ValidRemoteEndPointState -> IO RemoteState
    handleIOException ex vst = do
      tryCloseSocket (remoteSocket vst)
      let code = EventConnectionLost (remoteAddress theirEndPoint)
          err  = TransportError code (show ex)
      writeChan (localChannel ourEndPoint) $ ErrorEvent err
      return (RemoteEndPointFailed ex)

--------------------------------------------------------------------------------
-- "Stateless" (MVar free) functions                                          --
--------------------------------------------------------------------------------

-- | Establish a connection to a remote endpoint
--
-- Resolves the remote address, connects (optionally with @SO_REUSEADDR@ and
-- a timeout), sends the handshake (their endpoint ID followed by our
-- length-prefixed address) and reads the response code.
--
-- Failures are reported as 'Left' with a 'TransportError' 'ConnectErrorCode';
-- the socket is closed if an exception occurs after it was created.
socketToEndPoint :: EndPointAddress -- ^ Our address
                 -> EndPointAddress -- ^ Their address
                 -> Bool            -- ^ Use SO_REUSEADDR?
                 -> Maybe Int       -- ^ Timeout for connect
                 -> IO (Either (TransportError ConnectErrorCode)
                               (N.Socket, ConnectionRequestResponse))
socketToEndPoint (EndPointAddress ourAddress) theirAddress reuseAddr timeout =
  try $ do
    (host, port, theirEndPointId) <- case decodeEndPointAddress theirAddress of
      Nothing  -> throwIO (failed . userError $ "Could not parse")
      Just dec -> return dec
    addrs <- mapIOException invalidAddress $
      N.getAddrInfo Nothing (Just host) (Just port)
    -- Total match: an empty result must surface as a TransportError (caught
    -- by 'try' above), not as an escaping pattern-match failure
    addr <- case addrs of
      a : _ -> return a
      []    -> throwIO (invalidAddress . userError $ "No address found")
    bracketOnError (createSocket addr) tryCloseSocket $ \sock -> do
      when reuseAddr $
        mapIOException failed $ N.setSocketOption sock N.ReuseAddr 1
      mapIOException invalidAddress $
        timeoutMaybe timeout timeoutError $
          N.connect sock (N.addrAddress addr)
      response <- mapIOException failed $ do
        sendMany sock (encodeInt32 theirEndPointId : prependLength [ourAddress])
        recvInt32 sock
      case tryToEnum response of
        Nothing -> throwIO (failed . userError $ "Unexpected response")
        Just r  -> return (sock, r)
  where
    createSocket :: N.AddrInfo -> IO N.Socket
    createSocket addr = mapIOException insufficientResources $
      N.socket (N.addrFamily addr) N.Stream N.defaultProtocol

    invalidAddress        = TransportError ConnectNotFound . show
    insufficientResources = TransportError ConnectInsufficientResources . show
    failed                = TransportError ConnectFailed . show
    timeoutError          = TransportError ConnectTimeout "Timed out"

-- | Encode end point address as @host:port:ix@
encodeEndPointAddress :: N.HostName
                      -> N.ServiceName
                      -> EndPointId
                      -> EndPointAddress
encodeEndPointAddress host port ix = EndPointAddress . BSC.pack $
  host ++ ":" ++ port ++ ":" ++ show ix

-- | Decode end point address
--
-- Splits at the last two colons only, so the host part may itself contain
-- colons. Returns 'Nothing' if there are fewer than three parts or the
-- endpoint ID does not parse.
decodeEndPointAddress :: EndPointAddress
                      -> Maybe (N.HostName, N.ServiceName, EndPointId)
decodeEndPointAddress (EndPointAddress bs) =
  case splitMaxFromEnd (== ':') 2 $ BSC.unpack bs of
    [host, port, endPointIdStr] ->
      case reads endPointIdStr of
        [(endPointId, "")] -> Just (host, port, endPointId)
        _                  -> Nothing
    _ ->
      Nothing

-- | Construct a ConnectionId: the heavyweight ID in the high 32 bits, the
-- lightweight ID in the low bits (assumes the lightweight ID fits in 32 bits)
createConnectionId :: HeavyweightConnectionId
                   -> LightweightConnectionId
                   -> ConnectionId
createConnectionId hcid lcid =
  (fromIntegral hcid `shiftL` 32) .|. fromIntegral lcid

-- | @splitMaxFromEnd p n xs@ splits list @xs@ at elements matching @p@,
-- performing at most @n@ splits (so returning at most @n + 1@ segments),
-- counting from the /end/
--
-- > splitMaxFromEnd (== ':') 2 "ab:cd:ef:gh" == ["ab:cd", "ef", "gh"]
splitMaxFromEnd :: (a -> Bool) -> Int -> [a] -> [[a]]
splitMaxFromEnd p = \n -> go [[]] n . reverse
  where
    -- go :: [[a]] -> Int -> [a] -> [[a]]
    go accs         _ []     = accs
    go ([]  : accs) 0 xs     = reverse xs : accs
    go (acc : accs) n (x:xs) =
      if p x then go ([] : acc : accs) (n - 1) xs
             else go ((x : acc) : accs) n xs
    go _ _ _ = error "Bug in splitMaxFromEnd"

--------------------------------------------------------------------------------
-- Functions from TransportInternals                                          --
--------------------------------------------------------------------------------

-- Find a socket between two endpoints
--
-- Throws an IO exception if the socket could not be found (for an Invalid or
-- Failed remote endpoint, its stored error is thrown instead).
internalSocketBetween :: TCPTransport    -- ^ Transport
                      -> EndPointAddress -- ^ Local endpoint
                      -> EndPointAddress -- ^ Remote endpoint
                      -> IO N.Socket
internalSocketBetween transport ourAddress theirAddress = do
  ourEndPoint <- withMVar (transportState transport) $ \st -> case st of
    TransportClosed ->
      throwIO $ userError "Transport closed"
    TransportValid vst ->
      case vst ^. localEndPointAt ourAddress of
        Nothing -> throwIO $ userError "Local endpoint not found"
        Just ep -> return ep
  theirEndPoint <- withMVar (localState ourEndPoint) $ \st -> case st of
    LocalEndPointClosed ->
      throwIO $ userError "Local endpoint closed"
    LocalEndPointValid vst ->
      case vst ^. localConnectionTo theirAddress of
        Nothing -> throwIO $ userError "Remote endpoint not found"
        Just ep -> return ep
  withMVar (remoteState theirEndPoint) $ \st -> case st of
    RemoteEndPointInit _ _ _ ->
      throwIO $ userError "Remote endpoint not yet initialized"
    RemoteEndPointValid vst ->
      return $ remoteSocket vst
    RemoteEndPointClosing _ vst ->
      return $ remoteSocket vst
    RemoteEndPointClosed ->
      throwIO $ userError "Remote endpoint closed"
    RemoteEndPointInvalid err ->
      throwIO err
    RemoteEndPointFailed err ->
      throwIO err

--------------------------------------------------------------------------------
-- Constants                                                                  --
--------------------------------------------------------------------------------

-- | We reserve a bunch of connection IDs for control messages
firstNonReservedLightweightConnectionId :: LightweightConnectionId
firstNonReservedLightweightConnectionId = 1024

-- | Self-connection
heavyweightSelfConnectionId :: HeavyweightConnectionId
heavyweightSelfConnectionId = 0

-- | We reserve some connection IDs for special heavyweight connections
firstNonReservedHeavyweightConnectionId :: HeavyweightConnectionId
firstNonReservedHeavyweightConnectionId = 1

--------------------------------------------------------------------------------
-- Accessor definitions                                                       --
--------------------------------------------------------------------------------

localEndPoints :: Accessor ValidTransportState (Map EndPointAddress LocalEndPoint)
localEndPoints = accessor _localEndPoints (\es st -> st { _localEndPoints = es })

nextEndPointId :: Accessor ValidTransportState EndPointId
nextEndPointId = accessor _nextEndPointId (\eid st -> st { _nextEndPointId = eid })

localNextConnOutId :: Accessor ValidLocalEndPointState LightweightConnectionId
localNextConnOutId = accessor _localNextConnOutId (\cix st -> st { _localNextConnOutId = cix })

localConnections :: Accessor ValidLocalEndPointState (Map EndPointAddress RemoteEndPoint)
localConnections = accessor _localConnections (\es st -> st { _localConnections = es })

nextConnInId :: Accessor ValidLocalEndPointState HeavyweightConnectionId
nextConnInId = accessor _nextConnInId (\rid st -> st { _nextConnInId = rid })

remoteOutgoing :: Accessor ValidRemoteEndPointState Int
remoteOutgoing = accessor _remoteOutgoing (\cs conn -> conn { _remoteOutgoing = cs })

remoteIncoming :: Accessor ValidRemoteEndPointState (Set LightweightConnectionId)
remoteIncoming = accessor _remoteIncoming (\cs conn -> conn { _remoteIncoming = cs })

remoteMaxIncoming :: Accessor ValidRemoteEndPointState LightweightConnectionId
remoteMaxIncoming = accessor _remoteMaxIncoming (\lcid st -> st { _remoteMaxIncoming = lcid })

remoteNextConnOutId :: Accessor ValidRemoteEndPointState LightweightConnectionId
remoteNextConnOutId = accessor _remoteNextConnOutId (\cix st -> st { _remoteNextConnOutId = cix })

-- | The local endpoint with the given address (if any)
localEndPointAt :: EndPointAddress -> Accessor ValidTransportState (Maybe LocalEndPoint)
localEndPointAt addr = localEndPoints >>> DAC.mapMaybe addr

-- | The remote endpoint with the given address (if any)
localConnectionTo :: EndPointAddress -> Accessor ValidLocalEndPointState (Maybe RemoteEndPoint)
localConnectionTo addr = localConnections >>> DAC.mapMaybe addr

-------------------------------------------------------------------------------
-- Debugging                                                                 --
-------------------------------------------------------------------------------

-- | Log a violated internal assumption ("rely" condition), then fail in IO
-- with the same message
relyViolation :: EndPointPair -> String -> IO a
relyViolation (ourEndPoint, theirEndPoint) str = do
  elog (ourEndPoint, theirEndPoint) (str ++ " RELY violation")
  fail (str ++ " RELY violation")

-- | Print a debug message to stdout, prefixed with the local address, the
-- remote address and ID, and the current thread ID
elog :: EndPointPair -> String -> IO ()
elog (ourEndPoint, theirEndPoint) msg = do
  tid <- myThreadId
  putStrLn  $  show (localAddress ourEndPoint)
    ++ "/"  ++ show (remoteAddress theirEndPoint)
    ++ "("  ++ show (remoteId theirEndPoint) ++ ")"
    ++ "/"  ++ show tid
    ++ ": " ++ msg