{-# LANGUAGE RecursiveDo #-}
{-# LANGUAGE TupleSections #-}
{-# LANGUAGE RankNTypes #-}
{-# LANGUAGE LambdaCase #-}
module Network.Transport.TCP
(
createTransport
, TCPAddr(..)
, defaultTCPAddr
, TCPAddrInfo(..)
, TCPParameters(..)
, defaultTCPParameters
, createTransportExposeInternals
, TransportInternals(..)
, EndPointId
, ControlHeader(..)
, ConnectionRequestResponse(..)
, firstNonReservedLightweightConnectionId
, firstNonReservedHeavyweightConnectionId
, socketToEndPoint
, LightweightConnectionId
, QDisc(..)
, simpleUnboundedQDisc
, simpleOnePlaceQDisc
) where
import Prelude hiding
( mapM_
#if ! MIN_VERSION_base(4,6,0)
, catch
#endif
)
import Network.Transport
import Network.Transport.TCP.Internal
( ControlHeader(..)
, encodeControlHeader
, decodeControlHeader
, ConnectionRequestResponse(..)
, encodeConnectionRequestResponse
, decodeConnectionRequestResponse
, forkServer
, recvWithLength
, recvExact
, recvWord32
, encodeWord32
, tryCloseSocket
, tryShutdownSocketBoth
, resolveSockAddr
, EndPointId
, encodeEndPointAddress
, decodeEndPointAddress
, currentProtocolVersion
, randomEndPointAddress
)
import Network.Transport.Internal
( prependLength
, mapIOException
, tryIO
, tryToEnum
, void
, timeoutMaybe
, asyncWhenCancelled
)
#ifdef USE_MOCK_NETWORK
import qualified Network.Transport.TCP.Mock.Socket as N
#else
import qualified Network.Socket as N
#endif
( HostName
, ServiceName
, Socket
, getAddrInfo
, maxListenQueue
, socket
, addrFamily
, addrAddress
, SocketType(Stream)
, defaultProtocol
, setSocketOption
, SocketOption(ReuseAddr, NoDelay, UserTimeout, KeepAlive)
, isSupportedSocketOption
, connect
, AddrInfo
, SockAddr(..)
)
#ifdef USE_MOCK_NETWORK
import Network.Transport.TCP.Mock.Socket.ByteString (sendMany)
#else
import Network.Socket.ByteString (sendMany)
#endif
import Control.Concurrent
( forkIO
, ThreadId
, killThread
, myThreadId
, threadDelay
, throwTo
)
import Control.Concurrent.Chan (Chan, newChan, readChan, writeChan)
import Control.Concurrent.MVar
( MVar
, newMVar
, modifyMVar
, modifyMVar_
, readMVar
, tryReadMVar
, takeMVar
, putMVar
, tryPutMVar
, newEmptyMVar
, withMVar
)
import Control.Concurrent.Async (async, wait)
import Control.Category ((>>>))
import Control.Applicative ((<$>))
import Control.Monad (when, unless, join, mplus, (<=<))
import Control.Exception
( IOException
, SomeException
, AsyncException
, handle
, throw
, throwIO
, try
, bracketOnError
, bracket
, fromException
, finally
, catch
, bracket
, mask
, mask_
)
import Data.IORef (IORef, newIORef, writeIORef, readIORef, writeIORef)
import Data.ByteString (ByteString)
import qualified Data.ByteString as BS (concat, length, null)
import qualified Data.ByteString.Char8 as BSC (pack, unpack)
import Data.Bits (shiftL, (.|.))
import Data.Maybe (isJust, isNothing, fromJust)
import Data.Word (Word32)
import Data.Set (Set)
import qualified Data.Set as Set
( empty
, insert
, elems
, singleton
, null
, delete
, member
)
import Data.Map (Map)
import qualified Data.Map as Map (empty)
import Data.Traversable (traverse)
import Data.Accessor (Accessor, accessor, (^.), (^=), (^:))
import qualified Data.Accessor.Container as DAC (mapMaybe)
import Data.Foldable (forM_, mapM_)
import qualified System.Timeout (timeout)
-- | Addressing information of an addressable transport.
--
-- The bind host/port describe where the listening socket lives; the
-- transport host/port are what 'tcpExternalAddress' produced from the port
-- that the server actually bound (see 'createTransportExposeInternals').
data TransportAddrInfo = TransportAddrInfo
  { transportHost     :: !N.HostName
    -- ^ externally advertised host
  , transportPort     :: !N.ServiceName
    -- ^ externally advertised port
  , transportBindHost :: !N.HostName
    -- ^ host the listening socket is bound to
  , transportBindPort :: !N.ServiceName
    -- ^ port the listening socket was actually bound to
  }
-- | Internal representation of a TCP transport.
data TCPTransport = TCPTransport
  { transportAddrInfo :: !(Maybe TransportAddrInfo)
    -- ^ 'Nothing' for an 'Unaddressable' transport (no server is forked)
  , transportState    :: !(MVar TransportState)
    -- ^ mutable transport state, guarded by the 'MVar'
  , transportParams   :: !TCPParameters
    -- ^ the parameters the transport was created with
  }
-- | A transport is either open (carrying its live state) or closed.
-- 'apiCloseTransport' moves it to 'TransportClosed'.
data TransportState =
    TransportValid !ValidTransportState
  | TransportClosed
-- | State of an open transport.
data ValidTransportState = ValidTransportState
  { _localEndPoints :: !(Map EndPointId LocalEndPoint)
    -- ^ endpoints created on this transport, keyed by endpoint id;
    --   all of them are closed by 'apiCloseTransport'
  , _nextEndPointId :: !EndPointId
    -- ^ presumably the id for the next endpoint (initialised to 0);
    --   allocation happens outside this part of the file -- confirm
  }
-- | An endpoint living on this transport.
data LocalEndPoint = LocalEndPoint
  { localAddress    :: !EndPointAddress
    -- ^ address reported by the endpoint's 'address' field
  , localEndPointId :: !EndPointId
  , localState      :: !(MVar LocalEndPointState)
  , localQueue      :: !(QDisc Event)
    -- ^ queue from which the endpoint's 'receive' dequeues events
  }
-- | A local endpoint is either open or closed ('apiCloseEndPoint' closes it).
data LocalEndPointState =
    LocalEndPointValid !ValidLocalEndPointState
  | LocalEndPointClosed
-- | State of an open local endpoint.
data ValidLocalEndPointState = ValidLocalEndPointState
  { _localNextConnOutId :: !LightweightConnectionId
    -- ^ NOTE(review): presumably the next lightweight id for connections
    --   to self; not used in this part of the file -- confirm
  , _nextConnInId       :: !HeavyweightConnectionId
    -- ^ presumably the id given to the next 'RemoteEndPoint' ('remoteId');
    --   assigned outside this part of the file -- confirm
  , _localConnections   :: !(Map EndPointAddress RemoteEndPoint)
    -- ^ remote endpoints known to this endpoint, keyed by remote address;
    --   'apiCloseEndPoint' shuts all of them down
  }
-- | A (heavyweight) connection from a local endpoint to a remote endpoint.
data RemoteEndPoint = RemoteEndPoint
  { remoteAddress   :: !EndPointAddress
  , remoteState     :: !(MVar RemoteState)
  , remoteId        :: !HeavyweightConnectionId
    -- ^ combined with lightweight ids to form 'ConnectionId's
  , remoteScheduled :: !(Chan (IO ()))
    -- ^ presumably the queue behind 'schedule' / 'runScheduledAction'
    --   (defined elsewhere) -- confirm
  }
-- | Which side initiated a heavyweight connection: 'createConnectionTo'
-- uses 'RequestedByUs', 'handleConnectionRequest' uses 'RequestedByThem'.
data RequestedBy = RequestedByUs | RequestedByThem
  deriving (Eq, Show)
-- | State of the heavyweight connection to a remote endpoint.
data RemoteState =
    RemoteEndPointInvalid !(TransportError ConnectErrorCode)
    -- ^ connection set-up failed; the error is rethrown to later users
  | RemoteEndPointInit !(MVar ()) !(MVar ()) !RequestedBy
    -- ^ set-up in progress: the first 'MVar' is filled when the state is
    --   resolved, the second when the request crossed with one from the
    --   peer; 'RequestedBy' records which side initiated
  | RemoteEndPointValid !ValidRemoteEndPointState
    -- ^ connection established
  | RemoteEndPointClosing !(MVar ()) !ValidRemoteEndPointState
    -- ^ entered by 'closeIfUnused' after scheduling a 'CloseSocket'
    --   message; the 'MVar' is filled once the closing is resolved
  | RemoteEndPointClosed
    -- ^ closed in an orderly way
  | RemoteEndPointFailed !IOException
    -- ^ the connection failed with the given exception
-- | Bookkeeping for an established heavyweight connection.
data ValidRemoteEndPointState = ValidRemoteEndPointState
  { _remoteOutgoing      :: !Int
    -- ^ count of our open outgoing lightweight connections
    --   ('apiClose' decrements it)
  , _remoteIncoming      :: !(Set LightweightConnectionId)
    -- ^ lightweight connections the peer has opened and not yet closed
  , _remoteLastIncoming  :: !LightweightConnectionId
    -- ^ last lightweight id the peer announced; echoed in 'CloseSocket'
  , _remoteNextConnOutId :: !LightweightConnectionId
    -- ^ id to use for our next outgoing lightweight connection
  , remoteSocket         :: !N.Socket
  , remoteProbing        :: Maybe (IO ())
    -- ^ when a probe is running, the action that stops it (lazy field)
  , remoteSendLock       :: !(MVar (Maybe SomeException))
    -- ^ presumably serialises writes in 'sendOn' (defined elsewhere) -- confirm
  , remoteSocketClosed   :: !(IO ())
    -- ^ for outgoing sockets this is 'readMVar' on the socket's closed-var,
    --   i.e. it blocks until the socket has been released
  }
-- | A local endpoint together with one of its remote peers.
type EndPointPair = (LocalEndPoint, RemoteEndPoint)

-- | Id of a lightweight connection within a heavyweight connection.
type LightweightConnectionId = Word32

-- | Id of a heavyweight (socket-level) connection.
type HeavyweightConnectionId = Word32
-- | Where an addressable transport listens and how it advertises itself.
data TCPAddrInfo = TCPAddrInfo
  { tcpBindHost        :: N.HostName
    -- ^ host to bind the listening socket to
  , tcpBindPort        :: N.ServiceName
    -- ^ port to bind the listening socket to
  , tcpExternalAddress :: N.ServiceName -> (N.HostName, N.ServiceName)
    -- ^ given the port actually bound, the (host, port) to advertise
  }
-- | Transport addressing: 'Addressable' starts a listening server as
-- described by the 'TCPAddrInfo'; 'Unaddressable' starts no server, so the
-- transport cannot accept incoming connections.
data TCPAddr = Addressable TCPAddrInfo | Unaddressable
-- | An addressable transport that binds to @host@:@port@ and advertises the
-- same host together with whatever port was actually bound.
defaultTCPAddr :: N.HostName -> N.ServiceName -> TCPAddr
defaultTCPAddr host port = Addressable TCPAddrInfo
  { tcpBindHost        = host
  , tcpBindPort        = port
  , tcpExternalAddress = \boundPort -> (host, boundPort)
  }
-- | Parameters for setting up the TCP transport.
data TCPParameters = TCPParameters
  { tcpBacklog :: Int
    -- ^ listen backlog handed to 'forkServer'
  , tcpReuseServerAddr :: Bool
    -- ^ reuse-address flag handed to 'forkServer' for the listening socket
  , tcpReuseClientAddr :: Bool
    -- ^ reuse-address flag handed to 'socketToEndPoint' for outgoing sockets
  , tcpNoDelay :: Bool
    -- ^ set 'N.NoDelay' on accepted sockets (also passed for outgoing ones)
  , tcpKeepAlive :: Bool
    -- ^ set 'N.KeepAlive' on accepted sockets (also passed for outgoing ones)
  , tcpUserTimeout :: Maybe Int
    -- ^ value for 'N.UserTimeout'; creating an addressable transport fails
    --   if the option is unsupported on this system
  , transportConnectTimeout :: Maybe Int
    -- ^ timeout in microseconds: bounds the incoming handshake, is the
    --   default for outgoing connects without a 'connectTimeout' hint, and
    --   bounds how long a socket probe waits
  , tcpNewQDisc :: forall t . IO (QDisc t)
    -- ^ creates the event queue of each new endpoint
  , tcpMaxAddressLength :: Word32
    -- ^ limit passed to 'recvWithLength' for the peer address in a handshake
  , tcpMaxReceiveLength :: Word32
    -- ^ limit passed to 'recvWithLength' for incoming messages
  , tcpCheckPeerHost :: Bool
    -- ^ reject peers whose claimed host matches neither the numeric nor the
    --   resolved host of the socket's peer address
  , tcpServerExceptionHandler :: SomeException -> IO ()
    -- ^ error handler handed to 'forkServer'
  }
-- | Internals of a transport, exposed by 'createTransportExposeInternals'.
data TransportInternals = TransportInternals
  { transportThread :: Maybe ThreadId
    -- ^ the server thread ('Nothing' for an 'Unaddressable' transport)
  , newEndPointInternal :: (forall t . Maybe (QDisc t))
                        -> IO (Either (TransportError NewEndPointErrorCode) EndPoint)
    -- ^ create an endpoint with the given queue, or with one from
    --   'tcpNewQDisc' when 'Nothing'
  , socketBetween :: EndPointAddress
                  -> EndPointAddress
                  -> IO N.Socket
    -- ^ implemented by 'internalSocketBetween' (defined elsewhere); presumably
    --   the socket between two endpoints, for testing -- confirm
  }
-- | Create a TCP transport. This is 'createTransportExposeInternals' with the
-- 'TransportInternals' dropped; IO errors during set-up are returned as
-- 'Left'.
createTransport
  :: TCPAddr
  -> TCPParameters
  -> IO (Either IOException Transport)
createTransport addr params =
    fmap (fmap fst) (createTransportExposeInternals addr params)
-- | Create a TCP transport and also return its 'TransportInternals'.
--
-- For an 'Unaddressable' address no server is started. For an 'Addressable'
-- one a server is forked with 'forkServer'; any 'IOException' during set-up
-- is returned as 'Left' (via 'tryIO').
createTransportExposeInternals
  :: TCPAddr
  -> TCPParameters
  -> IO (Either IOException (Transport, TransportInternals))
createTransportExposeInternals addr params = do
    state <- newMVar . TransportValid $ ValidTransportState
      { _localEndPoints = Map.empty
      , _nextEndPointId = 0
      }
    case addr of
      Unaddressable ->
        let transport = TCPTransport { transportState    = state
                                     , transportAddrInfo = Nothing
                                     , transportParams   = params
                                     }
        in  fmap Right (mkTransport transport Nothing)
      Addressable (TCPAddrInfo bindHost bindPort mkExternal) -> tryIO $ mdo
        -- Fail early instead of failing on the first accepted socket.
        when ( isJust (tcpUserTimeout params) &&
               not (N.isSupportedSocketOption N.UserTimeout)
             ) $
          throwIO $ userError $ "Network.Transport.TCP.createTransport: " ++
                                "the parameter tcpUserTimeout is unsupported " ++
                                "in this system."
        -- The port actually bound is only known once 'forkServer' returns,
        -- but the transport (which needs that port) must be passed to
        -- 'forkServer'. 'mdo' ties that knot: port' is the port returned by
        -- 'forkServer' below.
        (port', result) <- do
          let (externalHost, externalPort) = mkExternal port'
          let addrInfo = TransportAddrInfo { transportHost     = externalHost
                                           , transportPort     = externalPort
                                           , transportBindHost = bindHost
                                           , transportBindPort = port'
                                           }
          let transport = TCPTransport { transportState    = state
                                       , transportAddrInfo = Just addrInfo
                                       , transportParams   = params
                                       }
          -- Kill the server again if building the transport fails.
          bracketOnError (forkServer
                           bindHost
                           bindPort
                           (tcpBacklog params)
                           (tcpReuseServerAddr params)
                           (errorHandler transport)
                           (terminationHandler transport)
                           (handleConnectionRequest transport))
                         (\(_port', tid) -> killThread tid)
                         (\(port'', tid) -> (port'',) <$> mkTransport transport (Just tid))
        return result
  where
    -- Build the public 'Transport' and the 'TransportInternals' for it.
    mkTransport :: TCPTransport
                -> Maybe ThreadId
                -> IO (Transport, TransportInternals)
    mkTransport transport mtid = do
      return
        ( Transport
            { newEndPoint = do
                qdisc <- tcpNewQDisc params
                apiNewEndPoint transport qdisc
            , closeTransport = let evs = [ EndPointClosed ]
                               in  apiCloseTransport transport mtid evs
            }
        , TransportInternals
            { transportThread = mtid
            , socketBetween = internalSocketBetween transport
            , newEndPointInternal = \mqdisc -> case mqdisc of
                Just qdisc -> apiNewEndPoint transport qdisc
                Nothing -> do
                  qdisc <- tcpNewQDisc params
                  apiNewEndPoint transport qdisc
            }
        )

    -- Server errors are delegated to the user-supplied handler.
    errorHandler :: TCPTransport -> SomeException -> IO ()
    errorHandler _ = tcpServerExceptionHandler params

    -- When the server terminates, close the transport. Every endpoint gets
    -- an 'EventTransportFailed' error followed by an element that throws
    -- when forced (presumably so that later 'receive' calls fail -- confirm).
    -- 'Nothing' is passed so the server thread does not kill itself.
    terminationHandler :: TCPTransport -> SomeException -> IO ()
    terminationHandler transport ex = do
      let evs = [ ErrorEvent (TransportError EventTransportFailed (show ex))
                , throw $ userError "Transport closed"
                ]
      apiCloseTransport transport Nothing evs
-- | Default parameters: maximal listen backlog, address reuse on both
-- sides, no socket tuning options, no timeouts, an unbounded event queue,
-- no length limits, no peer host check, and server exceptions rethrown.
defaultTCPParameters :: TCPParameters
defaultTCPParameters = TCPParameters
  { -- listening / socket options
    tcpBacklog                = N.maxListenQueue
  , tcpReuseServerAddr        = True
  , tcpReuseClientAddr        = True
  , tcpNoDelay                = False
  , tcpKeepAlive              = False
  , tcpUserTimeout            = Nothing
    -- timeouts
  , transportConnectTimeout   = Nothing
    -- event queueing
  , tcpNewQDisc               = simpleUnboundedQDisc
    -- limits and checks
  , tcpMaxAddressLength       = maxBound
  , tcpMaxReceiveLength       = maxBound
  , tcpCheckPeerHost          = False
    -- error handling
  , tcpServerExceptionHandler = throwIO
  }
-- | Close the transport: mark it closed, close each of its endpoints,
-- enqueueing @evs@ on every endpoint, and finally kill the server thread if
-- one is given. Closing an already closed transport only kills the thread.
apiCloseTransport :: TCPTransport -> Maybe ThreadId -> [Event] -> IO ()
apiCloseTransport transport mTransportThread evs =
    asyncWhenCancelled return $ do
      mValid <- modifyMVar (transportState transport) $ \case
        TransportValid vst -> return (TransportClosed, Just vst)
        TransportClosed    -> return (TransportClosed, Nothing)
      case mValid of
        Just vst -> mapM_ (apiCloseEndPoint transport evs) (vst ^. localEndPoints)
        Nothing  -> return ()
      mapM_ killThread mTransportThread
-- | Create a new endpoint that uses @qdisc@ as its event queue.
--
-- Multicast is not supported: both multicast operations return an error.
apiNewEndPoint :: TCPTransport
               -> QDisc Event
               -> IO (Either (TransportError NewEndPointErrorCode) EndPoint)
apiNewEndPoint transport qdisc =
    try $ asyncWhenCancelled closeEndPoint $ do
      ourEndPoint <- createLocalEndPoint transport qdisc
      let closeEvents = [ EndPointClosed ]
      return EndPoint
        { receive               = qdiscDequeue (localQueue ourEndPoint)
        , address               = localAddress ourEndPoint
        , connect               = apiConnect transport ourEndPoint
        , closeEndPoint         = apiCloseEndPoint transport closeEvents ourEndPoint
        , newMulticastGroup     = return (Left unsupportedNewGroup)
        , resolveMulticastGroup = \_ -> return (Left unsupportedResolveGroup)
        }
  where
    unsupportedNewGroup =
      TransportError NewMulticastGroupUnsupported "Multicast not supported"
    unsupportedResolveGroup =
      TransportError ResolveMulticastGroupUnsupported "Multicast not supported"
-- | An event queue ("queueing discipline") of an endpoint.
data QDisc t = QDisc
  { qdiscDequeue :: IO t
    -- ^ take the next element; used as the endpoint's 'receive'
  , qdiscEnqueue :: EndPointAddress -> Event -> t -> IO ()
    -- ^ enqueue an element, given the peer address and the event it
    --   corresponds to (the simple disciplines below ignore both)
  }
-- | Enqueue an 'Event' on an event queue, using the event itself as the
-- queued element.
qdiscEnqueue' :: QDisc Event -> EndPointAddress -> Event -> IO ()
qdiscEnqueue' qdisc addr event =
    let enqueue = qdiscEnqueue qdisc addr
    in  enqueue event event
-- | A FIFO queue backed by an unbounded 'Chan': enqueueing never blocks and
-- the address and event arguments are ignored.
simpleUnboundedQDisc :: forall t . IO (QDisc t)
simpleUnboundedQDisc = do
    chan <- newChan
    return QDisc
      { qdiscDequeue = readChan chan
      , qdiscEnqueue = \_addr _event -> writeChan chan
      }
-- | A one-element queue backed by an 'MVar': enqueueing blocks until the
-- previous element has been dequeued. Address and event arguments are
-- ignored.
simpleOnePlaceQDisc :: forall t . IO (QDisc t)
simpleOnePlaceQDisc = do
    slot <- newEmptyMVar
    return QDisc
      { qdiscDequeue = takeMVar slot
      , qdiscEnqueue = \_addr _event -> putMVar slot
      }
-- | Open a lightweight connection from @ourEndPoint@ to @theirAddress@.
--
-- Connections to our own address go through 'connectToSelf'. Otherwise a
-- broken remote endpoint is reset first and a new lightweight connection is
-- created on the (possibly new) heavyweight connection. Errors are returned
-- as 'Left' via 'try'.
apiConnect :: TCPTransport
           -> LocalEndPoint
           -> EndPointAddress
           -> Reliability
           -> ConnectHints
           -> IO (Either (TransportError ConnectErrorCode) Connection)
apiConnect transport ourEndPoint theirAddress _reliability hints =
    try . asyncWhenCancelled close $
      if localAddress ourEndPoint == theirAddress
        then connectToSelf ourEndPoint
        else do
          resetIfBroken ourEndPoint theirAddress
          (theirEndPoint, connId) <-
            createConnectionTo transport ourEndPoint theirAddress hints
          -- An IORef is enough here: 'apiSend' and 'apiClose' read and
          -- write it only while holding the remote endpoint's state MVar.
          connAlive <- newIORef True
          return Connection
            { send  = apiSend  (ourEndPoint, theirEndPoint) connId connAlive
            , close = apiClose (ourEndPoint, theirEndPoint) connId connAlive
            }
-- | Close a lightweight connection.
--
-- If the connection is still alive and the heavyweight connection is valid,
-- schedule a 'CloseConnection' message and decrement the outgoing count.
-- In every case 'closeIfUnused' then runs. IO errors are ignored.
apiClose :: EndPointPair -> LightweightConnectionId -> IORef Bool -> IO ()
apiClose (ourEndPoint, theirEndPoint) connId connAlive =
    void . tryIO . asyncWhenCancelled return $
      withScheduledAction ourEndPoint (\sched -> markClosed sched)
        `finally` closeIfUnused (ourEndPoint, theirEndPoint)
  where
    markClosed sched = modifyMVar_ (remoteState theirEndPoint) $ \case
      RemoteEndPointValid vst -> do
        stillAlive <- readIORef connAlive
        if not stillAlive
          then return (RemoteEndPointValid vst)
          else do
            writeIORef connAlive False
            sched theirEndPoint $
              sendOn vst [ encodeWord32 (encodeControlHeader CloseConnection)
                         , encodeWord32 connId
                         ]
            return (RemoteEndPointValid ((remoteOutgoing ^: subtract 1) vst))
      st -> return st
-- | Send a message on a lightweight connection.
--
-- Only a valid heavyweight connection actually sends. A connection that was
-- closed locally yields 'SendClosed'; a closed or failed remote endpoint
-- yields 'SendFailed'. IO errors are mapped to 'SendFailed'.
apiSend :: EndPointPair
        -> LightweightConnectionId
        -> IORef Bool
        -> [ByteString]
        -> IO (Either (TransportError SendErrorCode) ())
apiSend (ourEndPoint, theirEndPoint) connId connAlive payload =
    try . mapIOException sendFailed $ withScheduledAction ourEndPoint $ \sched ->
      withMVar (remoteState theirEndPoint) $ \case
        RemoteEndPointInvalid _ ->
          violation "apiSend"
        RemoteEndPointInit {} ->
          violation "apiSend"
        RemoteEndPointValid vst ->
          ifAlive $ sched theirEndPoint $
            sendOn vst (encodeWord32 connId : prependLength payload)
        RemoteEndPointClosing _ _ ->
          -- 'closeIfUnused' only enters Closing when no connection is alive.
          ifAlive $ violation "apiSend RemoteEndPointClosing"
        RemoteEndPointClosed ->
          ifAlive $ throwIO $ TransportError SendFailed "Remote endpoint closed"
        RemoteEndPointFailed err ->
          ifAlive $ throwIO $ TransportError SendFailed (show err)
  where
    sendFailed = TransportError SendFailed . show

    violation msg = relyViolation (ourEndPoint, theirEndPoint) msg

    -- Run the action only if this lightweight connection is still open.
    ifAlive action = do
      alive <- readIORef connAlive
      if alive
        then action
        else throwIO $ TransportError SendClosed "Connection closed"
-- | Close a local endpoint.
--
-- The endpoint is removed from the transport and marked closed. If it was
-- open, every remote endpoint it knows is shut down, and each event in
-- @evs@ is enqueued on its queue.
apiCloseEndPoint :: TCPTransport
                 -> [Event]
                 -> LocalEndPoint
                 -> IO ()
apiCloseEndPoint transport evs ourEndPoint =
    asyncWhenCancelled return $ do
      removeLocalEndPoint transport ourEndPoint
      mValid <- modifyMVar (localState ourEndPoint) $ \case
        LocalEndPointValid vst -> return (LocalEndPointClosed, Just vst)
        LocalEndPointClosed    -> return (LocalEndPointClosed, Nothing)
      case mValid of
        Nothing  -> return ()
        Just vst -> do
          mapM_ closeRemote (vst ^. localConnections)
          mapM_ (qdiscEnqueue' (localQueue ourEndPoint) (localAddress ourEndPoint)) evs
  where
    -- Remote state left behind once we have closed the endpoint.
    failedByClose :: RemoteState
    failedByClose = RemoteEndPointFailed . userError $ "apiCloseEndPoint"

    -- Shut down one remote endpoint. A valid connection first gets a
    -- best-effort 'CloseEndPoint' message; the socket shutdown is scheduled.
    closeRemote :: RemoteEndPoint -> IO ()
    closeRemote theirEndPoint = withScheduledAction ourEndPoint $ \sched ->
      modifyMVar_ (remoteState theirEndPoint) $ \case
        st@(RemoteEndPointInvalid _) ->
          return st
        RemoteEndPointInit resolved _ _ -> do
          putMVar resolved ()
          return failedByClose
        RemoteEndPointValid vst -> do
          sched theirEndPoint $ do
            void $ tryIO $ sendOn vst
              [ encodeWord32 (encodeControlHeader CloseEndPoint) ]
            forM_ (remoteProbing vst) id
            tryShutdownSocketBoth (remoteSocket vst)
            remoteSocketClosed vst
          return failedByClose
        RemoteEndPointClosing resolved vst -> do
          forM_ (remoteProbing vst) id
          putMVar resolved ()
          sched theirEndPoint $ do
            tryShutdownSocketBoth (remoteSocket vst)
            remoteSocketClosed vst
          return failedByClose
        RemoteEndPointClosed ->
          return RemoteEndPointClosed
        RemoteEndPointFailed err ->
          return (RemoteEndPointFailed err)
-- | Handle a socket accepted by the server.
--
-- Socket options are applied per the transport parameters. Then the
-- handshake runs, bounded by 'transportConnectTimeout' (microseconds, via
-- 'System.Timeout.timeout'). On success the returned action (the message
-- loop) runs outside the timeout. Non-asynchronous exceptions are swallowed
-- by 'handleException'; 'AsyncException's are rethrown.
--
-- @socketClosed@ becomes the 'remoteSocketClosed' action of the accepted
-- connection.
handleConnectionRequest :: TCPTransport -> IO () -> (N.Socket, N.SockAddr) -> IO ()
handleConnectionRequest transport socketClosed (sock, sockAddr) = handle handleException $ do
    when (tcpNoDelay $ transportParams transport) $
      N.setSocketOption sock N.NoDelay 1
    when (tcpKeepAlive $ transportParams transport) $
      N.setSocketOption sock N.KeepAlive 1
    forM_ (tcpUserTimeout $ transportParams transport) $
      N.setSocketOption sock N.UserTimeout
    -- Every handshake starts with a protocol version and a handshake length.
    -- For an unsupported version we reply with the version we want (0),
    -- discard the handshake bytes and wait for the next attempt.
    let handleVersioned = do
          protocolVersion <- recvWord32 sock
          handshakeLength <- recvWord32 sock
          case protocolVersion of
            0x00000000 -> handleConnectionRequestV0 (sock, sockAddr)
            _ -> do
              sendMany sock [
                  encodeWord32 (encodeConnectionRequestResponse ConnectionRequestUnsupportedVersion)
                , encodeWord32 0x00000000
                ]
              _ <- recvExact sock handshakeLength
              handleVersioned
    let connTimeout = transportConnectTimeout (transportParams transport)
    outcome <- maybe (fmap Just) System.Timeout.timeout connTimeout handleVersioned
    case outcome of
      Nothing -> throwIO (userError "handleConnectionRequest: timed out")
      Just act -> forM_ act id
  where
    -- Swallow all exceptions except 'AsyncException's.
    handleException :: SomeException -> IO ()
    handleException ex = do
      rethrowIfAsync (fromException ex)

    rethrowIfAsync :: Maybe AsyncException -> IO ()
    rethrowIfAsync = mapM_ throwIO

    -- Version 0 handshake. Returns the action to run after the handshake,
    -- or 'Nothing' if the request was rejected because of a host mismatch.
    handleConnectionRequestV0 :: (N.Socket, N.SockAddr) -> IO (Maybe (IO ()))
    handleConnectionRequestV0 (sock, sockAddr) = do
        -- NOTE(review): actualPort is unused.
        (numericHost, resolvedHost, actualPort) <-
          resolveSockAddr sockAddr >>=
            maybe (throwIO (userError "handleConnectionRequest: invalid socket address")) return
        -- The peer sends the id of the endpoint it wants and its own
        -- address. An empty address gets replaced by a random one
        -- (presumably an unaddressable peer -- confirm).
        (ourEndPointId, theirAddress, mTheirHost) <- do
          ourEndPointId <- recvWord32 sock
          let maxAddressLength = tcpMaxAddressLength $ transportParams transport
          mTheirAddress <- BS.concat <$> recvWithLength maxAddressLength sock
          if BS.null mTheirAddress
            then do
              theirAddress <- randomEndPointAddress
              return (ourEndPointId, theirAddress, Nothing)
            else do
              let theirAddress = EndPointAddress mTheirAddress
              (theirHost, _, _)
                <- maybe (throwIO (userError "handleConnectionRequest: peer gave malformed address"))
                         return
                         (decodeEndPointAddress theirAddress)
              return (ourEndPointId, theirAddress, Just theirHost)
        -- Optionally require that the claimed host matches the socket's
        -- peer; on mismatch reply with all three hosts.
        let checkPeerHost = tcpCheckPeerHost (transportParams transport)
        continue <- case (mTheirHost, checkPeerHost) of
          (Just theirHost, True) -> do
            if theirHost == numericHost || theirHost == resolvedHost
              then return True
              else do
                sendMany sock $
                    encodeWord32 (encodeConnectionRequestResponse ConnectionRequestHostMismatch)
                  : (prependLength [BSC.pack theirHost] ++ prependLength [BSC.pack numericHost] ++ prependLength [BSC.pack resolvedHost])
                return False
          _ -> return True
        if continue
          then do
            ourEndPoint <- withMVar (transportState transport) $ \st -> case st of
              TransportValid vst ->
                case vst ^. localEndPointAt ourEndPointId of
                  Nothing -> do
                    sendMany sock [encodeWord32 (encodeConnectionRequestResponse ConnectionRequestInvalid)]
                    throwIO $ userError "handleConnectionRequest: Invalid endpoint"
                  Just ourEndPoint ->
                    return ourEndPoint
              TransportClosed ->
                throwIO $ userError "Transport closed"
            return (Just (go ourEndPoint theirAddress))
          else return Nothing
      where
        -- Register the remote endpoint. If one already exists (crossed
        -- requests) reply 'ConnectionRequestCrossed' and probe the existing
        -- socket; otherwise accept and run the message loop on this socket.
        go :: LocalEndPoint -> EndPointAddress -> IO ()
        go ourEndPoint theirAddress = handle handleException $ do
          resetIfBroken ourEndPoint theirAddress
          (theirEndPoint, isNew) <-
            findRemoteEndPoint ourEndPoint theirAddress RequestedByThem Nothing
          if not isNew
            then do
              void $ tryIO $ sendMany sock
                [encodeWord32 (encodeConnectionRequestResponse ConnectionRequestCrossed)]
              probeIfValid theirEndPoint
            else do
              sendLock <- newMVar Nothing
              let vst = ValidRemoteEndPointState
                          { remoteSocket        = sock
                          , remoteSocketClosed  = socketClosed
                          , remoteProbing       = Nothing
                          , remoteSendLock      = sendLock
                          , _remoteOutgoing     = 0
                          , _remoteIncoming     = Set.empty
                          , _remoteLastIncoming = 0
                          , _remoteNextConnOutId = firstNonReservedLightweightConnectionId
                          }
              sendMany sock [encodeWord32 (encodeConnectionRequestResponse ConnectionRequestAccepted)]
              resolveInit (ourEndPoint, theirEndPoint) (RemoteEndPointValid vst)
                `finally`
                handleIncomingMessages (transportParams transport) (ourEndPoint, theirEndPoint)

        -- If the existing connection is valid and not already probed, fork a
        -- thread that sends 'ProbeSocket' and then waits (bounded by the
        -- connect timeout, or forever when there is none). If the wait ends
        -- without the thread being killed through 'remoteProbing' (which
        -- 'stopProbing' does on 'ProbeSocketAck'), the socket is closed.
        probeIfValid :: RemoteEndPoint -> IO ()
        probeIfValid theirEndPoint = modifyMVar_ (remoteState theirEndPoint) $
          \st -> case st of
            RemoteEndPointValid
              vst@(ValidRemoteEndPointState { remoteProbing = Nothing }) -> do
                tid <- forkIO $ do
                  let params = transportParams transport
                  void $ tryIO $ System.Timeout.timeout
                    (maybe (-1) id $ transportConnectTimeout params) $ do
                      sendMany (remoteSocket vst)
                        [encodeWord32 (encodeControlHeader ProbeSocket)]
                      threadDelay maxBound
                  tryCloseSocket (remoteSocket vst)
                return $ RemoteEndPointValid
                  vst { remoteProbing = Just (killThread tid) }
            _ -> return st
-- | Receive loop for a heavyweight connection.
--
-- Reads lightweight connection ids from the socket. Ids at or above
-- 'firstNonReservedLightweightConnectionId' carry data messages; lower ids
-- are control headers. Events are enqueued on the local endpoint's queue.
-- IO errors end the loop through 'prematureExit', which marks the remote
-- endpoint as failed.
handleIncomingMessages :: TCPParameters -> EndPointPair -> IO ()
handleIncomingMessages params (ourEndPoint, theirEndPoint) =
    bracket acquire release act
  where
    -- Take the socket from the shared remote state. 'Left' means the remote
    -- endpoint was already closed or failed.
    acquire :: IO (Either IOError N.Socket)
    acquire = withMVar theirState $ \st -> case st of
      RemoteEndPointInvalid _ ->
        relyViolation (ourEndPoint, theirEndPoint)
          "handleIncomingMessages (invalid)"
      RemoteEndPointInit _ _ _ ->
        relyViolation (ourEndPoint, theirEndPoint)
          "handleIncomingMessages (init)"
      RemoteEndPointValid ep ->
        return . Right $ remoteSocket ep
      RemoteEndPointClosing _ ep ->
        return . Right $ remoteSocket ep
      RemoteEndPointClosed ->
        return . Left $ userError "handleIncomingMessages (already closed)"
      RemoteEndPointFailed _ ->
        return . Left $ userError "handleIncomingMessages (failed)"

    -- NOTE(review): if 'acquire' saw 'RemoteEndPointClosed' and the state is
    -- still Closed here, 'prematureExit' ends in its 'relyViolation' branch
    -- for that state -- confirm this cannot happen or is intended.
    release :: Either IOError N.Socket -> IO ()
    release (Left err) = prematureExit err
    release (Right _) = return ()

    act :: Either IOError N.Socket -> IO ()
    act (Left _) = return ()
    act (Right sock) = go sock `catch` prematureExit

    go :: N.Socket -> IO ()
    go sock = do
      lcid <- recvWord32 sock :: IO LightweightConnectionId
      if lcid >= firstNonReservedLightweightConnectionId
        then do
          readMessage sock lcid
          go sock
        else
          case decodeControlHeader lcid of
            Just CreatedNewConnection -> do
              recvWord32 sock >>= createdNewConnection
              go sock
            Just CloseConnection -> do
              recvWord32 sock >>= closeConnection
              go sock
            Just CloseSocket -> do
              didClose <- recvWord32 sock >>= closeSocket sock
              unless didClose $ go sock
            Just CloseEndPoint -> do
              -- The peer closed its endpoint: close all incoming connections,
              -- report lost outgoing ones, and forget the remote endpoint.
              let closeRemoteEndPoint vst = do
                    forM_ (remoteProbing vst) id
                    forM_ (Set.elems $ vst ^. remoteIncoming) $
                      qdiscEnqueue' ourQueue theirAddr . ConnectionClosed . connId
                    when (vst ^. remoteOutgoing > 0) $ do
                      let code = EventConnectionLost (remoteAddress theirEndPoint)
                      qdiscEnqueue' ourQueue theirAddr . ErrorEvent $
                        TransportError code "The remote endpoint was closed."
                    removeRemoteEndPoint (ourEndPoint, theirEndPoint)
              modifyMVar_ theirState $ \s -> case s of
                RemoteEndPointValid vst -> do
                  closeRemoteEndPoint vst
                  return RemoteEndPointClosed
                RemoteEndPointClosing resolved vst -> do
                  closeRemoteEndPoint vst
                  putMVar resolved ()
                  return RemoteEndPointClosed
                _ -> return s
            Just ProbeSocket -> do
              -- Reply from a separate thread so reading continues. A failed
              -- reply is ignored here (instead of escaping as an uncaught
              -- exception in the forked thread); this read loop's own error
              -- handling deals with a broken socket.
              void $ forkIO $ void $ tryIO $
                sendMany sock [encodeWord32 (encodeControlHeader ProbeSocketAck)]
              go sock
            Just ProbeSocketAck -> do
              stopProbing
              go sock
            Nothing ->
              throwIO $ userError "Invalid control request"

    -- The peer opened lightweight connection @lcid@.
    createdNewConnection :: LightweightConnectionId -> IO ()
    createdNewConnection lcid = do
      modifyMVar_ theirState $ \st -> do
        vst <- case st of
          RemoteEndPointInvalid _ ->
            relyViolation (ourEndPoint, theirEndPoint)
              "handleIncomingMessages:createNewConnection (invalid)"
          RemoteEndPointInit _ _ _ ->
            relyViolation (ourEndPoint, theirEndPoint)
              "handleIncomingMessages:createNewConnection (init)"
          RemoteEndPointValid vst ->
            return ( (remoteIncoming ^: Set.insert lcid)
                   $ (remoteLastIncoming ^= lcid)
                     vst
                   )
          RemoteEndPointClosing resolved vst -> do
            -- A new incoming connection cancels our pending close: signal
            -- the waiters and return to Valid with just this connection.
            putMVar resolved ()
            return ( (remoteIncoming ^= Set.singleton lcid)
                   . (remoteLastIncoming ^= lcid)
                   $ vst
                   )
          RemoteEndPointFailed err ->
            throwIO err
          RemoteEndPointClosed ->
            relyViolation (ourEndPoint, theirEndPoint)
              "createNewConnection (closed)"
        return (RemoteEndPointValid vst)
      qdiscEnqueue' ourQueue theirAddr (ConnectionOpened (connId lcid) ReliableOrdered theirAddr)

    -- The peer closed lightweight connection @lcid@.
    closeConnection :: LightweightConnectionId -> IO ()
    closeConnection lcid = do
      modifyMVar_ theirState $ \st -> case st of
        RemoteEndPointInvalid _ ->
          relyViolation (ourEndPoint, theirEndPoint) "closeConnection (invalid)"
        RemoteEndPointInit _ _ _ ->
          relyViolation (ourEndPoint, theirEndPoint) "closeConnection (init)"
        RemoteEndPointValid vst -> do
          unless (Set.member lcid (vst ^. remoteIncoming)) $
            throwIO $ userError "Invalid CloseConnection"
          return ( RemoteEndPointValid
                 . (remoteIncoming ^: Set.delete lcid)
                 $ vst
                 )
        RemoteEndPointClosing _ _ ->
          throwIO $ userError "Invalid CloseConnection request"
        RemoteEndPointFailed err ->
          throwIO err
        RemoteEndPointClosed ->
          relyViolation (ourEndPoint, theirEndPoint) "closeConnection (closed)"
      qdiscEnqueue' ourQueue theirAddr (ConnectionClosed (connId lcid))

    -- The peer asked to close the socket; @lastReceivedId@ is the last of
    -- our lightweight ids it has seen. The socket is only closed when we have
    -- no outgoing connections and the peer has seen all our ids. Returns
    -- whether it was closed (and hence whether the loop must stop).
    closeSocket :: N.Socket -> LightweightConnectionId -> IO Bool
    closeSocket sock lastReceivedId = do
      mAct <- modifyMVar theirState $ \st -> do
        case st of
          RemoteEndPointInvalid _ ->
            relyViolation (ourEndPoint, theirEndPoint)
              "handleIncomingMessages:closeSocket (invalid)"
          RemoteEndPointInit _ _ _ ->
            relyViolation (ourEndPoint, theirEndPoint)
              "handleIncomingMessages:closeSocket (init)"
          RemoteEndPointValid vst -> do
            forM_ (Set.elems $ vst ^. remoteIncoming) $
              qdiscEnqueue' ourQueue theirAddr . ConnectionClosed . connId
            let vst' = remoteIncoming ^= Set.empty $ vst
            if vst ^. remoteOutgoing > 0 || lastReceivedId /= lastSentId vst
              then
                return (RemoteEndPointValid vst', Nothing)
              else do
                forM_ (remoteProbing vst) id
                removeRemoteEndPoint (ourEndPoint, theirEndPoint)
                -- Confirm the close to the peer.
                act <- schedule theirEndPoint $ do
                  void $ tryIO $ sendOn vst'
                    [ encodeWord32 (encodeControlHeader CloseSocket)
                    , encodeWord32 (vst ^. remoteLastIncoming)
                    ]
                return (RemoteEndPointClosed, Just act)
          RemoteEndPointClosing resolved vst -> do
            if lastReceivedId /= lastSentId vst
              then do
                return (RemoteEndPointClosing resolved vst, Nothing)
              else do
                when (vst ^. remoteOutgoing > 0) $ do
                  let code = EventConnectionLost (remoteAddress theirEndPoint)
                  let msg = "socket closed prematurely by peer"
                  qdiscEnqueue' ourQueue theirAddr . ErrorEvent $ TransportError code msg
                forM_ (remoteProbing vst) id
                removeRemoteEndPoint (ourEndPoint, theirEndPoint)
                act <- schedule theirEndPoint $ return ()
                putMVar resolved ()
                return (RemoteEndPointClosed, Just act)
          RemoteEndPointFailed err ->
            throwIO err
          RemoteEndPointClosed ->
            relyViolation (ourEndPoint, theirEndPoint)
              "handleIncomingMessages:closeSocket (closed)"
      case mAct of
        Nothing -> return False
        Just act -> do
          runScheduledAction (ourEndPoint, theirEndPoint) act
          return True

    -- Read one length-prefixed message and deliver it.
    readMessage :: N.Socket -> LightweightConnectionId -> IO ()
    readMessage sock lcid =
      recvWithLength recvLimit sock >>=
        qdiscEnqueue' ourQueue theirAddr . Received (connId lcid)

    -- A probe was acknowledged: stop the probing thread.
    stopProbing :: IO ()
    stopProbing = modifyMVar_ theirState $ \st -> case st of
      RemoteEndPointValid
        vst@(ValidRemoteEndPointState { remoteProbing = Just stop }) -> do
          stop
          return $ RemoteEndPointValid vst { remoteProbing = Nothing }
      _ -> return st

    ourQueue = localQueue ourEndPoint
    ourState = localState ourEndPoint
    theirState = remoteState theirEndPoint
    theirAddr = remoteAddress theirEndPoint
    recvLimit = tcpMaxReceiveLength params

    -- The loop ended with an IO error: mark the remote endpoint failed and
    -- report the lost connection on the local queue.
    prematureExit :: IOException -> IO ()
    prematureExit err = do
      modifyMVar_ theirState $ \st ->
        case st of
          RemoteEndPointInvalid _ ->
            relyViolation (ourEndPoint, theirEndPoint)
              "handleIncomingMessages:prematureExit"
          RemoteEndPointInit _ _ _ ->
            relyViolation (ourEndPoint, theirEndPoint)
              "handleIncomingMessages:prematureExit"
          RemoteEndPointValid vst -> do
            forM_ (remoteProbing vst) id
            let code = EventConnectionLost (remoteAddress theirEndPoint)
            qdiscEnqueue' ourQueue theirAddr . ErrorEvent $ TransportError code (show err)
            return (RemoteEndPointFailed err)
          RemoteEndPointClosing resolved vst -> do
            forM_ (remoteProbing vst) id
            putMVar resolved ()
            return (RemoteEndPointFailed err)
          RemoteEndPointClosed ->
            relyViolation (ourEndPoint, theirEndPoint)
              "handleIncomingMessages:prematureExit"
          RemoteEndPointFailed err' -> do
            -- Only report if the local endpoint is still open.
            modifyMVar_ ourState $ \st' -> case st' of
              LocalEndPointClosed -> return st'
              LocalEndPointValid _ -> do
                let code = EventConnectionLost (remoteAddress theirEndPoint)
                    err = TransportError code (show err')
                qdiscEnqueue' ourQueue theirAddr (ErrorEvent err)
                return st'
            return (RemoteEndPointFailed err')

    connId :: LightweightConnectionId -> ConnectionId
    connId = createConnectionId (remoteId theirEndPoint)

    -- The last lightweight id we handed out (0 if none yet).
    lastSentId :: ValidRemoteEndPointState -> LightweightConnectionId
    lastSentId vst =
      if vst ^. remoteNextConnOutId == firstNonReservedLightweightConnectionId
        then 0
        else (vst ^. remoteNextConnOutId) - 1
-- | Create a new lightweight connection to @theirAddress@, setting up the
-- heavyweight connection first when needed.
--
-- The timeout comes from the 'connectTimeout' hint, falling back to
-- 'transportConnectTimeout' (microseconds). It is turned into a "timer"
-- action that blocks until the time has elapsed, and that action is passed
-- to 'findRemoteEndPoint'.
createConnectionTo
  :: TCPTransport
  -> LocalEndPoint
  -> EndPointAddress
  -> ConnectHints
  -> IO (RemoteEndPoint, LightweightConnectionId)
createConnectionTo transport ourEndPoint theirAddress hints = do
    -- NOTE(review): the timer thread is never cancelled; it keeps sleeping
    -- for the full timeout even after the connection succeeds.
    timer <- case connTimeout of
      Just t -> do
        mv <- newEmptyMVar
        _ <- forkIO $ threadDelay t >> putMVar mv ()
        return $ Just $ readMVar mv
      _ -> return Nothing
    go timer Nothing
  where
    params = transportParams transport
    connTimeout = connectTimeout hints `mplus` transportConnectTimeout params

    -- @mr@ is the outcome of our previous set-up attempt, if any.
    go timer mr = do
      -- After the lookup (or if it fails): if the previous attempt ended with
      -- 'ConnectionRequestCrossed' and the endpoint is still Init, resolve
      -- it, remove it and mark it closed.
      (theirEndPoint, isNew) <- mapIOException connectFailed
        (findRemoteEndPoint ourEndPoint theirAddress RequestedByUs timer)
        `finally` case mr of
          Just (theirEndPoint, ConnectionRequestCrossed) ->
            modifyMVar_ (remoteState theirEndPoint) $
              \rst -> case rst of
                RemoteEndPointInit resolved _ _ -> do
                  putMVar resolved ()
                  removeRemoteEndPoint (ourEndPoint, theirEndPoint)
                  return RemoteEndPointClosed
                _ -> return rst
          _ -> return ()
      if isNew
        then do
          -- We created the endpoint, so we must set it up, then look it up
          -- again.
          mr' <- handle (absorbAllExceptions Nothing) $
            setupRemoteEndPoint transport (ourEndPoint, theirEndPoint) connTimeout
          go timer (fmap ((,) theirEndPoint) mr')
        else do
          -- Announce a new lightweight connection on the existing socket.
          mapIOException connectFailed $ do
            act <- modifyMVar (remoteState theirEndPoint) $ \st -> case st of
              RemoteEndPointValid vst -> do
                let connId = vst ^. remoteNextConnOutId
                act <- schedule theirEndPoint $ do
                  sendOn vst [
                      encodeWord32 (encodeControlHeader CreatedNewConnection)
                    , encodeWord32 connId
                    ]
                  return connId
                return ( RemoteEndPointValid
                           $ remoteNextConnOutId ^= connId + 1
                           $ vst
                       , act
                       )
              RemoteEndPointInvalid err ->
                throwIO err
              RemoteEndPointFailed err ->
                throwIO err
              _ ->
                relyViolation (ourEndPoint, theirEndPoint) "createConnectionTo"
            connId <- runScheduledAction (ourEndPoint, theirEndPoint) act
            return (theirEndPoint, connId)

    connectFailed :: IOException -> TransportError ConnectErrorCode
    connectFailed = TransportError ConnectFailed . show

    -- NOTE(review): this swallows every exception, including asynchronous
    -- ones such as 'ThreadKilled'; confirm that is intended.
    absorbAllExceptions :: a -> SomeException -> IO a
    absorbAllExceptions a _ex =
      return a
-- | Open a socket to the remote endpoint and resolve its Init state from the
-- peer's handshake response.
--
-- On 'ConnectionRequestAccepted' the endpoint becomes Valid and a thread
-- running 'handleIncomingMessages' is forked. It releases the socket when
-- done. Every other response, and a failed socket set-up, resolves the
-- endpoint as Invalid or marks it crossed.
--
-- On every non-accepting branch the socket is now released via 'finally'.
-- Previously an exception there (e.g. 'throwIO' in the Crossed branch when
-- the state is Failed) leaked the socket and left its closed-var empty.
--
-- Returns the peer's response, or 'Nothing' if no socket could be set up.
setupRemoteEndPoint
  :: TCPTransport
  -> EndPointPair
  -> Maybe Int
  -> IO (Maybe ConnectionRequestResponse)
setupRemoteEndPoint transport (ourEndPoint, theirEndPoint) connTimeout = do
    -- An unaddressable transport has no address to tell the peer.
    let mOurAddress = const ourAddress <$> transportAddrInfo transport
    result <- socketToEndPoint mOurAddress
                               theirAddress
                               (tcpReuseClientAddr params)
                               (tcpNoDelay params)
                               (tcpKeepAlive params)
                               (tcpUserTimeout params)
                               connTimeout
    didAccept <- case result of
      Right (socketClosedVar, sock, ConnectionRequestAccepted) -> do
        sendLock <- newMVar Nothing
        let vst = ValidRemoteEndPointState
                    { remoteSocket        = sock
                    , remoteSocketClosed  = readMVar socketClosedVar
                    , remoteProbing       = Nothing
                    , remoteSendLock      = sendLock
                    , _remoteOutgoing     = 0
                    , _remoteIncoming     = Set.empty
                    , _remoteLastIncoming = 0
                    , _remoteNextConnOutId = firstNonReservedLightweightConnectionId
                    }
        resolveInit (ourEndPoint, theirEndPoint) (RemoteEndPointValid vst)
        return (Just (socketClosedVar, sock))
      Right (socketClosedVar, sock, ConnectionRequestUnsupportedVersion) -> do
        let err = connectFailed "setupRemoteEndPoint: unsupported version"
        resolveInit (ourEndPoint, theirEndPoint) (RemoteEndPointInvalid err)
          `finally` releaseSocket socketClosedVar sock
        return Nothing
      Right (socketClosedVar, sock, ConnectionRequestInvalid) -> do
        let err = invalidAddress "setupRemoteEndPoint: Invalid endpoint"
        resolveInit (ourEndPoint, theirEndPoint) (RemoteEndPointInvalid err)
          `finally` releaseSocket socketClosedVar sock
        return Nothing
      Right (socketClosedVar, sock, ConnectionRequestCrossed) -> do
        markCrossed `finally` releaseSocket socketClosedVar sock
        return Nothing
      Right (socketClosedVar, sock, ConnectionRequestHostMismatch) -> do
        -- The peer follows the response with the claimed, numeric and
        -- resolved hosts; failing to read them still yields an error.
        let handler :: SomeException -> IO (TransportError ConnectErrorCode)
            handler err = return (TransportError ConnectFailed (show err))
        err <- handle handler $ do
          claimedHost <- recvWithLength (tcpMaxReceiveLength params) sock
          actualNumericHost <- recvWithLength (tcpMaxReceiveLength params) sock
          actualResolvedHost <- recvWithLength (tcpMaxReceiveLength params) sock
          let reason = concat [
                  "setupRemoteEndPoint: Host mismatch"
                , ". Claimed: "
                , BSC.unpack (BS.concat claimedHost)
                , "; Numeric: "
                , BSC.unpack (BS.concat actualNumericHost)
                , "; Resolved: "
                , BSC.unpack (BS.concat actualResolvedHost)
                ]
          return (TransportError ConnectFailed reason)
        resolveInit (ourEndPoint, theirEndPoint) (RemoteEndPointInvalid err)
          `finally` releaseSocket socketClosedVar sock
        return Nothing
      Left err -> do
        resolveInit (ourEndPoint, theirEndPoint) (RemoteEndPointInvalid err)
        return Nothing
    forM_ didAccept $ \(socketClosed, sock) -> void $ forkIO $
      handleIncomingMessages params (ourEndPoint, theirEndPoint)
        `finally` releaseSocket socketClosed sock
    return $ either (const Nothing) (\(_, _, response) -> Just response) result
  where
    params = transportParams transport
    ourAddress = localAddress ourEndPoint
    theirAddress = remoteAddress theirEndPoint
    invalidAddress = TransportError ConnectNotFound
    connectFailed = TransportError ConnectFailed

    -- Close the socket and then signal that it is closed, even if closing
    -- throws.
    releaseSocket :: MVar () -> N.Socket -> IO ()
    releaseSocket closedVar sock = tryCloseSocket sock `finally` putMVar closedVar ()

    -- Our request crossed with one from the peer: signal the crossed MVar of
    -- the Init state (a Failed state rethrows its error).
    markCrossed :: IO ()
    markCrossed = withMVar (remoteState theirEndPoint) $ \st -> case st of
      RemoteEndPointInit _ crossed _ ->
        putMVar crossed ()
      RemoteEndPointFailed ex ->
        throwIO ex
      _ ->
        relyViolation (ourEndPoint, theirEndPoint) "setupRemoteEndPoint: Crossed"
-- | Close the heavyweight connection to a remote endpoint once it is idle.
--
-- If the remote endpoint is valid and we have neither outgoing nor incoming
-- lightweight connections to it, its state becomes 'RemoteEndPointClosing'.
-- A 'CloseSocket' control message, followed by the current
-- 'remoteLastIncoming' value, is then sent to it. In every other case the
-- state is left untouched and nothing is sent.
closeIfUnused :: EndPointPair -> IO ()
closeIfUnused (ourEndPoint, theirEndPoint) = do
    pending <- modifyMVar (remoteState theirEndPoint) transition
    -- The send is only *scheduled* while the remote state is locked; it is
    -- executed after the lock has been released.
    forM_ pending $ runScheduledAction (ourEndPoint, theirEndPoint)
  where
    transition (RemoteEndPointValid vst)
      | idle vst = do
          resolved <- newEmptyMVar
          act <- schedule theirEndPoint $
            sendOn vst [ encodeWord32 (encodeControlHeader CloseSocket)
                       , encodeWord32 (vst ^. remoteLastIncoming)
                       ]
          return (RemoteEndPointClosing resolved vst, Just act)
    transition st = return (st, Nothing)

    idle vst = vst ^. remoteOutgoing == 0 && Set.null (vst ^. remoteIncoming)
-- | Forget a broken connection so that the next attempt starts from scratch.
--
-- If our endpoint has an entry for @theirAddress@ whose state is
-- 'RemoteEndPointInvalid' or 'RemoteEndPointFailed', that entry is removed
-- via 'removeRemoteEndPoint'. Healthy or in-progress entries are left alone.
--
-- Throws a 'ConnectFailed' 'TransportError' if our endpoint is closed.
resetIfBroken :: LocalEndPoint -> EndPointAddress -> IO ()
resetIfBroken ourEndPoint theirAddress = do
    mTheirEndPoint <- withMVar (localState ourEndPoint) lookupRemote
    forM_ mTheirEndPoint $ \theirEndPoint ->
      withMVar (remoteState theirEndPoint) $ \remote ->
        when (isBroken remote) $
          removeRemoteEndPoint (ourEndPoint, theirEndPoint)
  where
    lookupRemote (LocalEndPointValid vst) =
      return (vst ^. localConnectionTo theirAddress)
    lookupRemote LocalEndPointClosed =
      throwIO $ TransportError ConnectFailed "Endpoint closed"

    isBroken (RemoteEndPointInvalid _) = True
    isBroken (RemoteEndPointFailed _)  = True
    isBroken _                         = False
-- | Open a connection from an endpoint to itself.
--
-- No socket is involved. A fresh lightweight id is drawn from the endpoint's
-- own counter and combined with 'heavyweightSelfConnectionId'. The
-- 'ConnectionOpened', 'Received' and 'ConnectionClosed' events are pushed
-- straight onto the endpoint's own queue.
--
-- Failing to allocate the id (an 'IOException', e.g. because the endpoint is
-- closed) is reported as a 'ConnectFailed' 'TransportError'.
connectToSelf :: LocalEndPoint
              -> IO Connection
connectToSelf ourEndPoint = do
    connAlive <- newIORef True
    lconnId   <- mapIOException connectFailed (getLocalNextConnOutId ourEndPoint)
    let connId = createConnectionId heavyweightSelfConnectionId lconnId
    enqueue (ConnectionOpened connId ReliableOrdered ourAddress)
    return Connection { send  = selfSend connAlive connId
                      , close = selfClose connAlive connId
                      }
  where
    -- Deliver a message to our own queue, unless the endpoint is closed
    -- ('SendFailed') or this connection has been closed ('SendClosed').
    selfSend :: IORef Bool
             -> ConnectionId
             -> [ByteString]
             -> IO (Either (TransportError SendErrorCode) ())
    selfSend connAlive connId msg = try . withMVar ourState $ \case
      LocalEndPointClosed ->
        throwIO $ TransportError SendFailed "Endpoint closed"
      LocalEndPointValid _ -> do
        alive <- readIORef connAlive
        -- Each chunk is forced to WHNF before the event is enqueued.
        if not alive
          then throwIO $ TransportError SendClosed "Connection closed"
          else foldr seq () msg `seq` enqueue (Received connId msg)

    -- Mark the connection closed and emit 'ConnectionClosed' exactly once.
    -- This is a no-op if the endpoint itself is already closed.
    selfClose :: IORef Bool -> ConnectionId -> IO ()
    selfClose connAlive connId = withMVar ourState $ \case
      LocalEndPointClosed ->
        return ()
      LocalEndPointValid _ -> do
        alive <- readIORef connAlive
        when alive $ do
          enqueue (ConnectionClosed connId)
          writeIORef connAlive False

    enqueue       = qdiscEnqueue' ourQueue ourAddress
    ourQueue      = localQueue ourEndPoint
    ourState      = localState ourEndPoint
    ourAddress    = localAddress ourEndPoint
    connectFailed = TransportError ConnectFailed . show
-- | Move a remote endpoint out of 'RemoteEndPointInit' into @newState@.
--
-- Fills the @resolved@ MVar and, if still empty, the @crossed@ MVar, waking
-- any threads blocked on them. When @newState@ is 'RemoteEndPointClosed', the
-- entry is also removed from our endpoint's connection map.
--
-- Rethrows the stored exception if the remote endpoint has already failed.
-- Any other prior state is reported as a RELY violation.
resolveInit :: EndPointPair -> RemoteState -> IO ()
resolveInit (ourEndPoint, theirEndPoint) newState =
    modifyMVar_ (remoteState theirEndPoint) $ \case
      RemoteEndPointInit resolved crossed _ -> do
        putMVar resolved ()
        _ <- tryPutMVar crossed ()
        removeIfClosed newState
        return newState
      RemoteEndPointFailed ex ->
        throwIO ex
      _ ->
        relyViolation (ourEndPoint, theirEndPoint) "resolveInit"
  where
    removeIfClosed RemoteEndPointClosed =
      removeRemoteEndPoint (ourEndPoint, theirEndPoint)
    removeIfClosed _ =
      return ()
-- | Allocate the next outgoing lightweight connection id of a local endpoint
-- and advance the endpoint's counter by one.
--
-- Throws a 'userError' if the endpoint is closed.
getLocalNextConnOutId :: LocalEndPoint -> IO LightweightConnectionId
getLocalNextConnOutId ourEndpoint =
    modifyMVar (localState ourEndpoint) bump
  where
    bump LocalEndPointClosed =
      throwIO $ userError "Local endpoint closed"
    bump (LocalEndPointValid vst) =
      let next = vst ^. localNextConnOutId
      in return (LocalEndPointValid ((localNextConnOutId ^= next + 1) vst), next)
-- | Allocate a new local endpoint and register it with the transport.
--
-- The endpoint starts valid, with no connections and with both
-- connection-id counters at their first non-reserved values. It is stored
-- under the transport's next free 'EndPointId', and that counter is
-- advanced.
--
-- Its address is encoded from the transport's host, port and the new id.
-- When the transport carries no address info, a random address is used
-- instead.
--
-- Throws a 'NewEndPointFailed' 'TransportError' if the transport is closed.
createLocalEndPoint :: TCPTransport
                    -> QDisc Event
                    -> IO LocalEndPoint
createLocalEndPoint transport qdisc = do
    state <- newMVar $ LocalEndPointValid ValidLocalEndPointState
      { _localNextConnOutId = firstNonReservedLightweightConnectionId
      , _localConnections   = Map.empty
      , _nextConnInId       = firstNonReservedHeavyweightConnectionId
      }
    modifyMVar (transportState transport) (register state)
  where
    register _ TransportClosed =
      throwIO (TransportError NewEndPointFailed "Transport closed")
    register state (TransportValid vst) = do
      let ix = vst ^. nextEndPointId
      addr <- maybe randomEndPointAddress
                    (\info -> return $ encodeEndPointAddress (transportHost info)
                                                             (transportPort info)
                                                             ix)
                    (transportAddrInfo transport)
      let endPoint = LocalEndPoint { localAddress    = addr
                                   , localEndPointId = ix
                                   , localQueue      = qdisc
                                   , localState      = state
                                   }
          vst'     = (localEndPointAt ix ^= Just endPoint)
                   . (nextEndPointId ^= ix + 1)
                   $ vst
      return (TransportValid vst', endPoint)
-- | Remove a remote endpoint from our endpoint's connection map.
--
-- The entry for its address is only removed if it still refers to this very
-- remote endpoint, compared by 'remoteId'. NOTE(review): presumably this
-- guards against deleting a newer entry registered for the same address;
-- confirm. Nothing happens if our endpoint is closed.
removeRemoteEndPoint :: EndPointPair -> IO ()
removeRemoteEndPoint (ourEndPoint, theirEndPoint) =
    modifyMVar_ (localState ourEndPoint) $ \case
      st@(LocalEndPointValid vst)
        | isThisEndPoint (vst ^. localConnectionTo theirAddress) ->
            return . LocalEndPointValid $ (localConnectionTo theirAddress ^= Nothing) vst
        | otherwise ->
            return st
      LocalEndPointClosed ->
        return LocalEndPointClosed
  where
    theirAddress   = remoteAddress theirEndPoint
    isThisEndPoint = maybe False ((== remoteId theirEndPoint) . remoteId)
-- | Unregister a local endpoint from the transport by clearing its slot in
-- the transport's endpoint map. Nothing happens if the transport is closed.
removeLocalEndPoint :: TCPTransport -> LocalEndPoint -> IO ()
removeLocalEndPoint transport ourEndPoint =
    modifyMVar_ (transportState transport) $ \case
      TransportClosed    -> return TransportClosed
      TransportValid vst -> return . TransportValid $ dropOurs vst
  where
    dropOurs = localEndPointAt (localEndPointId ourEndPoint) ^= Nothing
-- | Find the 'RemoteEndPoint' our endpoint uses for @theirAddress@, creating
-- one in 'RemoteEndPointInit' state (tagged with @findOrigin@) if none is
-- registered yet.
--
-- The 'Bool' in the result is 'True' in two cases:
--
--   * a fresh entry was created;
--   * they are requesting while our own request is still initialising, our
--     address compares greater than theirs, and the @crossed@ MVar is
--     already full.
--
-- Every other successful return yields 'False'.
--
-- For an existing entry, the remote state decides what happens:
--
--   * valid: returned. When the origin is 'RequestedByUs', its
--     'remoteOutgoing' count is incremented first;
--   * initialising (and not the crossed case above) or closing: wait on its
--     MVar, subject to the optional timer, then retry from the start;
--   * closed: retry from the start;
--   * invalid or failed: the stored error is thrown;
--   * a second request from them while an earlier request from them is still
--     initialising: throws @userError "Already connected"@.
--
-- Throws a 'userError' if our endpoint is closed.
--
-- When @mtimer@ is 'Just', each wait forks a helper thread that runs the
-- timer action. When that action returns, the helper throws a
-- 'ConnectTimeout' 'TransportError' to the waiting thread; the helper is
-- killed once the wait is over. NOTE(review): the timer action is
-- presumably a delay such as 'threadDelay'; confirm with callers.
findRemoteEndPoint
  :: LocalEndPoint
  -> EndPointAddress
  -> RequestedBy
  -> Maybe (IO ())
  -> IO (RemoteEndPoint, Bool)
findRemoteEndPoint ourEndPoint theirAddress findOrigin mtimer = go
  where
    go = do
      -- Step 1: look up the entry for their address, registering a new one
      -- if there is none.
      (theirEndPoint, isNew) <- modifyMVar ourState $ \st -> case st of
        LocalEndPointValid vst -> case vst ^. localConnectionTo theirAddress of
          Just theirEndPoint ->
            return (st, (theirEndPoint, False))
          Nothing -> do
            -- New entry: Init state with fresh resolved/crossed MVars, the
            -- next heavyweight id, and an empty scheduled-action channel.
            -- The heavyweight id counter is bumped.
            resolved <- newEmptyMVar
            crossed <- newEmptyMVar
            theirState <- newMVar (RemoteEndPointInit resolved crossed findOrigin)
            scheduled <- newChan
            let theirEndPoint = RemoteEndPoint
                  { remoteAddress = theirAddress
                  , remoteState = theirState
                  , remoteId = vst ^. nextConnInId
                  , remoteScheduled = scheduled
                  }
            return ( LocalEndPointValid
                   . (localConnectionTo theirAddress ^= Just theirEndPoint)
                   . (nextConnInId ^: (+ 1))
                   $ vst
                   , (theirEndPoint, True)
                   )
        LocalEndPointClosed ->
          throwIO $ userError "Local endpoint closed"
      if isNew
        then
          return (theirEndPoint, True)
        else do
          -- Step 2: the entry already existed. Snapshot its state, counting
          -- one more outgoing connection if it is valid and we are asking.
          let theirState = remoteState theirEndPoint
          snapshot <- modifyMVar theirState $ \st -> case st of
            RemoteEndPointValid vst ->
              case findOrigin of
                RequestedByUs -> do
                  let st' = RemoteEndPointValid
                          . (remoteOutgoing ^: (+ 1))
                          $ vst
                  return (st', st')
                RequestedByThem ->
                  return (st, st)
            _ ->
              return (st, st)
          -- Step 3: act on the snapshot. The remote state MVar is no longer
          -- held here, so the waits below do not block other users of it.
          case snapshot of
            RemoteEndPointInvalid err ->
              throwIO err
            RemoteEndPointInit resolved crossed initOrigin ->
              case (findOrigin, initOrigin) of
                (RequestedByUs, RequestedByUs) ->
                  readMVarTimeout mtimer resolved >> go
                (RequestedByUs, RequestedByThem) ->
                  readMVarTimeout mtimer resolved >> go
                (RequestedByThem, RequestedByUs) ->
                  -- Crossed requests. If our address is greater, wait for
                  -- the crossed signal and retry, or report True if it is
                  -- already there. Otherwise report False without waiting.
                  if ourAddress > theirAddress
                    then do
                      tryReadMVar crossed >>= \case
                        Nothing -> readMVarTimeout mtimer crossed >> go
                        _ -> return (theirEndPoint, True)
                    else
                      return (theirEndPoint, False)
                (RequestedByThem, RequestedByThem) ->
                  throwIO $ userError "Already connected"
            RemoteEndPointValid _ ->
              return (theirEndPoint, False)
            RemoteEndPointClosing resolved _ ->
              readMVarTimeout mtimer resolved >> go
            RemoteEndPointClosed ->
              go
            RemoteEndPointFailed err ->
              throwIO err
    ourState = localState ourEndPoint
    ourAddress = localAddress ourEndPoint
    -- Block on an MVar. With a timer, a helper thread runs the timer action
    -- and then throws ConnectTimeout to us; 'bracket' kills the helper once
    -- the read completes or is interrupted.
    readMVarTimeout Nothing mv = readMVar mv
    readMVarTimeout (Just timer) mv = do
      let connectTimedout = TransportError ConnectTimeout "Timed out"
      tid <- myThreadId
      bracket (forkIO $ timer >> throwTo tid connectTimedout) killThread $
        const $ readMVar mv
-- | Send a list of chunks on the remote endpoint's socket. Sends are
-- serialised by taking and putting back the endpoint's send lock.
--
-- The lock holds 'Nothing' while no send has failed. If 'sendMany' throws,
-- the exception is stored in the lock (as 'Just') and rethrown. Every later
-- call then skips the socket and throws a 'userError' mentioning that
-- earlier failure.
sendOn :: ValidRemoteEndPointState -> [ByteString] -> IO ()
sendOn vst bs = (wait =<<) $ async $
  -- The work runs in its own thread via 'async', and the caller 'wait's for
  -- it. NOTE(review): presumably this stops an asynchronous exception
  -- delivered to the caller from aborting 'sendMany' part-way through a
  -- message; confirm.
  mask $ \restore -> do
    let lock = remoteSendLock vst
    -- Taken with async exceptions masked, so that once we hold the lock it
    -- is put back on both the success path and the send-failure path below.
    maybeException <- takeMVar lock
    when (isNothing maybeException) $
      -- Only the send itself is interruptible; the handler runs masked. The
      -- type of exception caught is fixed by the type of 'remoteSendLock'
      -- (declared elsewhere).
      restore (sendMany (remoteSocket vst) bs) `catch` \ex -> do
        putMVar lock (Just ex)
        throwIO ex
    putMVar lock maybeException
    forM_ maybeException $ \e ->
      throwIO $ userError $ "sendOn failed earlier with: " ++ show e
-- | Result slot of an action queued with 'schedule'. It is filled with
-- 'Right' and the action's result, or with 'Left' and the exception the
-- action threw.
type Action a = MVar (Either SomeException a)
-- | Queue @act@ on the remote endpoint's scheduled-action channel without
-- running it.
--
-- The returned 'Action' is filled with the outcome when a thread takes the
-- job off the channel and runs it (as 'runScheduledAction' does). Any
-- exception thrown by the job is captured as 'Left' rather than propagated.
schedule :: RemoteEndPoint -> IO a -> IO (Action a)
schedule theirEndPoint act = do
    outcome <- newEmptyMVar
    let job = (act >>= putMVar outcome . Right) `catch` (putMVar outcome . Left)
    writeChan (remoteScheduled theirEndPoint) job
    return outcome
-- | Run the next job from the remote endpoint's scheduled-action channel,
-- then wait for the given 'Action' to be filled and return its result.
--
-- If the action failed with an 'IOException', the remote endpoint is marked
-- 'RemoteEndPointFailed'. If it was valid, its probing action (if any) is run
-- and its socket is shut down first. The original exception is then rethrown.
runScheduledAction :: EndPointPair -> Action a -> IO a
runScheduledAction (_, theirEndPoint) mvar = do
    join $ readChan (remoteScheduled theirEndPoint)
    outcome <- readMVar mvar
    either failWith return outcome
  where
    failWith e = do
      forM_ (fromException e) $ \ioe ->
        modifyMVar_ (remoteState theirEndPoint) $ \case
          RemoteEndPointValid vst -> handleIOException ioe vst
          _                       -> return (RemoteEndPointFailed ioe)
      throwIO e

    handleIOException :: IOException
                      -> ValidRemoteEndPointState
                      -> IO RemoteState
    handleIOException ex vst = do
      case remoteProbing vst of
        Just probeAction -> probeAction
        Nothing          -> return ()
      tryShutdownSocketBoth (remoteSocket vst)
      return (RemoteEndPointFailed ex)
-- | Give @f@ a function for scheduling one action on a remote endpoint, and
-- guarantee that the most recently scheduled action is run afterwards.
--
-- The scheduling function records the pair (endpoint, action) in an
-- 'IORef', with async exceptions masked. When @f@ returns or throws, the
-- recorded action (if any) is run with 'runScheduledAction'.
withScheduledAction :: LocalEndPoint -> ((RemoteEndPoint -> IO a -> IO ()) -> IO ()) -> IO ()
withScheduledAction ourEndPoint f =
    bracket (newIORef Nothing) runPending (f . scheduleInto)
  where
    scheduleInto ref rp g = mask_ $ do
      action <- schedule rp g
      writeIORef ref (Just (rp, action))

    runPending ref = do
      pending <- readIORef ref
      traverse (\(tp, a) -> runScheduledAction (ourEndPoint, tp) a) pending
-- | Open a TCP connection to the endpoint at @theirAddress@ and perform the
-- connection-request handshake.
--
-- Steps:
--
--   * resolve the address and create a socket;
--   * apply the requested socket options;
--   * connect, then send the protocol version followed by a length-prefixed
--     request carrying their endpoint id and our own address. With 'Nothing'
--     an encoded @0@ is sent where the address would go; presumably this is
--     an empty address, to be confirmed against the receiving side;
--   * read the reply with 'recvWord32' and decode it as a
--     'ConnectionRequestResponse'.
--
-- Connecting and the handshake together are bounded by the optional timeout.
--
-- Every failure is returned as 'Left':
--
--   * unparseable address, socket option error, handshake I/O error or
--     unexpected reply: 'ConnectFailed';
--   * name resolution failure (including an empty address list) or connect
--     failure: 'ConnectNotFound';
--   * socket creation failure: 'ConnectInsufficientResources';
--   * timeout: 'ConnectTimeout'.
--
-- Once created, the socket is closed if anything later fails. On success the
-- result holds a fresh, empty 'MVar', the open socket and the decoded
-- response.
socketToEndPoint :: Maybe EndPointAddress -- ^ Our address (if known)
                 -> EndPointAddress       -- ^ Their address
                 -> Bool                  -- ^ Set SO_REUSEADDR?
                 -> Bool                  -- ^ Set TCP_NODELAY?
                 -> Bool                  -- ^ Set SO_KEEPALIVE?
                 -> Maybe Int             -- ^ TCP_USER_TIMEOUT, if any
                 -> Maybe Int             -- ^ Timeout for connect + handshake
                 -> IO (Either (TransportError ConnectErrorCode)
                               (MVar (), N.Socket, ConnectionRequestResponse))
socketToEndPoint mOurAddress theirAddress reuseAddr noDelay keepAlive
                 mUserTimeout timeout =
  try $ do
    (host, port, theirEndPointId) <- case decodeEndPointAddress theirAddress of
      Nothing  -> throwIO (failed . userError $ "Could not parse")
      Just dec -> return dec
    addrs <- mapIOException invalidAddress $
      N.getAddrInfo Nothing (Just host) (Just port)
    -- Match explicitly instead of using a partial pattern. A pattern-match
    -- failure here would raise a plain IOError that escapes 'try' above.
    addr <- case addrs of
      a : _ -> return a
      []    -> throwIO (invalidAddress . userError $ "No addresses for " ++ host)
    bracketOnError (createSocket addr) tryCloseSocket $ \sock -> do
      when reuseAddr $
        mapIOException failed $ N.setSocketOption sock N.ReuseAddr 1
      when noDelay $
        mapIOException failed $ N.setSocketOption sock N.NoDelay 1
      when keepAlive $
        mapIOException failed $ N.setSocketOption sock N.KeepAlive 1
      forM_ mUserTimeout $
        mapIOException failed . N.setSocketOption sock N.UserTimeout
      response <- timeoutMaybe timeout timeoutError $ do
        mapIOException invalidAddress $
          N.connect sock (N.addrAddress addr)
        mapIOException failed $ do
          case mOurAddress of
            Just (EndPointAddress ourAddress) ->
              sendMany sock $
                  encodeWord32 currentProtocolVersion
                : prependLength (encodeWord32 theirEndPointId : prependLength [ourAddress])
            Nothing ->
              sendMany sock $
                  encodeWord32 currentProtocolVersion
                : prependLength [encodeWord32 theirEndPointId, encodeWord32 0]
          recvWord32 sock
      case decodeConnectionRequestResponse response of
        Nothing -> throwIO (failed . userError $ "Unexpected response")
        Just r  -> do
          socketClosedVar <- newEmptyMVar
          return (socketClosedVar, sock, r)
  where
    createSocket :: N.AddrInfo -> IO N.Socket
    createSocket addr = mapIOException insufficientResources $
      N.socket (N.addrFamily addr) N.Stream N.defaultProtocol

    invalidAddress        = TransportError ConnectNotFound . show
    insufficientResources = TransportError ConnectInsufficientResources . show
    failed                = TransportError ConnectFailed . show
    timeoutError          = TransportError ConnectTimeout "Timed out"
-- | Pack a heavyweight id (shifted into the upper 32 bits) and a lightweight
-- id (OR-ed into the low bits) into a single 'ConnectionId'.
-- NOTE(review): assumes the lightweight id fits in 32 bits; confirm against
-- the 'LightweightConnectionId' type.
createConnectionId :: HeavyweightConnectionId
                   -> LightweightConnectionId
                   -> ConnectionId
createConnectionId hcid lcid = high .|. low
  where
    high = fromIntegral hcid `shiftL` 32
    low  = fromIntegral lcid
-- | Return the socket stored for the connection from the local endpoint at
-- @ourAddress@ to the remote endpoint at @theirAddress@.
--
-- Throws a 'userError' when:
--
--   * the local address cannot be decoded;
--   * the transport or the local endpoint is closed;
--   * either endpoint is unknown;
--   * the remote endpoint is still initialising or already closed.
--
-- If the remote endpoint is invalid or failed, its stored error is rethrown.
internalSocketBetween :: TCPTransport    -- ^ Transport
                      -> EndPointAddress -- ^ Local endpoint
                      -> EndPointAddress -- ^ Remote endpoint
                      -> IO N.Socket
internalSocketBetween transport ourAddress theirAddress = do
    ourEndPointId <- maybe (throwIO $ userError "Malformed local EndPointAddress")
                           (\(_, _, eid) -> return eid)
                           (decodeEndPointAddress ourAddress)
    ourEndPoint <- withMVar (transportState transport) $ \case
      TransportClosed ->
        throwIO $ userError "Transport closed"
      TransportValid vst ->
        orThrow "Local endpoint not found" (vst ^. localEndPointAt ourEndPointId)
    theirEndPoint <- withMVar (localState ourEndPoint) $ \case
      LocalEndPointClosed ->
        throwIO $ userError "Local endpoint closed"
      LocalEndPointValid vst ->
        orThrow "Remote endpoint not found" (vst ^. localConnectionTo theirAddress)
    withMVar (remoteState theirEndPoint) $ \case
      RemoteEndPointInit {}       -> throwIO $ userError "Remote endpoint not yet initialized"
      RemoteEndPointValid vst     -> return (remoteSocket vst)
      RemoteEndPointClosing _ vst -> return (remoteSocket vst)
      RemoteEndPointClosed        -> throwIO $ userError "Remote endpoint closed"
      RemoteEndPointInvalid err   -> throwIO err
      RemoteEndPointFailed err    -> throwIO err
  where
    orThrow msg = maybe (throwIO $ userError msg) return
-- | Lightweight connection ids below this value are never handed out for
-- ordinary connections. NOTE(review): presumably the low range is kept free
-- so that ids cannot be confused with control headers sent on the same
-- socket; confirm against 'decodeControlHeader'.
firstNonReservedLightweightConnectionId :: LightweightConnectionId
firstNonReservedLightweightConnectionId = 0x400 -- 1024

-- | Heavyweight connection id used for an endpoint's connections to itself
-- (see 'connectToSelf').
heavyweightSelfConnectionId :: HeavyweightConnectionId
heavyweightSelfConnectionId = 0

-- | First heavyweight connection id handed out to remote endpoints.
firstNonReservedHeavyweightConnectionId :: HeavyweightConnectionId
firstNonReservedHeavyweightConnectionId = heavyweightSelfConnectionId + 1
-- Accessors for the mutable-state records, built with 'accessor'. Each one
-- pairs a record field's getter with a setter that replaces that field.

localEndPoints :: Accessor ValidTransportState (Map EndPointId LocalEndPoint)
localEndPoints = accessor _localEndPoints $ \new s -> s { _localEndPoints = new }

nextEndPointId :: Accessor ValidTransportState EndPointId
nextEndPointId = accessor _nextEndPointId $ \new s -> s { _nextEndPointId = new }

localNextConnOutId :: Accessor ValidLocalEndPointState LightweightConnectionId
localNextConnOutId = accessor _localNextConnOutId $ \new s -> s { _localNextConnOutId = new }

localConnections :: Accessor ValidLocalEndPointState (Map EndPointAddress RemoteEndPoint)
localConnections = accessor _localConnections $ \new s -> s { _localConnections = new }

nextConnInId :: Accessor ValidLocalEndPointState HeavyweightConnectionId
nextConnInId = accessor _nextConnInId $ \new s -> s { _nextConnInId = new }

remoteOutgoing :: Accessor ValidRemoteEndPointState Int
remoteOutgoing = accessor _remoteOutgoing $ \new s -> s { _remoteOutgoing = new }

remoteIncoming :: Accessor ValidRemoteEndPointState (Set LightweightConnectionId)
remoteIncoming = accessor _remoteIncoming $ \new s -> s { _remoteIncoming = new }

remoteLastIncoming :: Accessor ValidRemoteEndPointState LightweightConnectionId
remoteLastIncoming = accessor _remoteLastIncoming $ \new s -> s { _remoteLastIncoming = new }

remoteNextConnOutId :: Accessor ValidRemoteEndPointState LightweightConnectionId
remoteNextConnOutId = accessor _remoteNextConnOutId $ \new s -> s { _remoteNextConnOutId = new }

-- | The (optional) local endpoint stored under a given id.
localEndPointAt :: EndPointId -> Accessor ValidTransportState (Maybe LocalEndPoint)
localEndPointAt epid = localEndPoints >>> DAC.mapMaybe epid

-- | The (optional) remote endpoint stored under a given address.
localConnectionTo :: EndPointAddress -> Accessor ValidLocalEndPointState (Maybe RemoteEndPoint)
localConnectionTo remoteAddr = localConnections >>> DAC.mapMaybe remoteAddr
-- | Log a violated invariant for this endpoint pair via 'elog', then fail in
-- 'IO' with the same message (the given text plus @" RELY violation"@).
relyViolation :: EndPointPair -> String -> IO a
relyViolation pair str = elog pair msg >> fail msg
  where
    msg = str ++ " RELY violation"
-- | Print a diagnostic line to stdout in the form
-- @local/remote(remoteId)/threadId: msg@.
elog :: EndPointPair -> String -> IO ()
elog (ourEndPoint, theirEndPoint) msg = do
    tid <- myThreadId
    putStrLn $ concat
      [ show (localAddress ourEndPoint)
      , "/", show (remoteAddress theirEndPoint)
      , "(", show (remoteId theirEndPoint), ")"
      , "/", show tid
      , ": ", msg
      ]