module Network.Transport.TCP (
createTransport
, TCPParameters(..)
, defaultTCPParameters
, createTransportExposeInternals
, TransportInternals(..)
, EndPointId
, encodeEndPointAddress
, decodeEndPointAddress
, ControlHeader(..)
, ConnectionRequestResponse(..)
, firstNonReservedConnectionId
, socketToEndPoint
) where
import Prelude hiding
( mapM_
#if ! MIN_VERSION_base(4,6,0)
, catch
#endif
)
import Network.Transport
import Network.Transport.TCP.Internal ( forkServer
, recvWithLength
, recvInt32
, tryCloseSocket
)
import Network.Transport.Internal ( encodeInt32
, decodeInt32
, prependLength
, mapIOException
, tryIO
, tryToEnum
, void
, timeoutMaybe
, asyncWhenCancelled
)
import qualified Network.Socket as N ( HostName
, ServiceName
, Socket
, getAddrInfo
, socket
, addrFamily
, addrAddress
, SocketType(Stream)
, defaultProtocol
, setSocketOption
, SocketOption(ReuseAddr)
, connect
, sOMAXCONN
, AddrInfo
)
import Network.Socket.ByteString (sendMany)
import Control.Concurrent (forkIO, ThreadId, killThread, myThreadId)
import Control.Concurrent.Chan (Chan, newChan, readChan, writeChan)
import Control.Concurrent.MVar ( MVar
, newMVar
, modifyMVar
, modifyMVar_
, readMVar
, takeMVar
, putMVar
, newEmptyMVar
, withMVar
)
import Control.Category ((>>>))
import Control.Applicative ((<$>))
import Control.Monad (when, unless)
import Control.Exception ( IOException
, SomeException
, AsyncException
, handle
, throw
, throwIO
, try
, bracketOnError
, mask
, onException
, fromException
)
import Data.IORef (IORef, newIORef, writeIORef, readIORef)
import Data.ByteString (ByteString)
import qualified Data.ByteString as BS (concat)
import qualified Data.ByteString.Char8 as BSC (pack, unpack)
import Data.Int (Int32)
import Data.IntMap (IntMap)
import qualified Data.IntMap as IntMap (empty)
import Data.IntSet (IntSet)
import qualified Data.IntSet as IntSet ( empty
, insert
, elems
, singleton
, null
, delete
, member
)
import Data.Map (Map)
import qualified Data.Map as Map (empty)
import Data.Accessor (Accessor, accessor, (^.), (^=), (^:))
import qualified Data.Accessor.Container as DAC (mapMaybe, intMapMaybe)
import Data.Foldable (forM_, mapM_)
-- | Internal representation of a TCP transport.
data TCPTransport = TCPTransport
  { transportHost   :: N.HostName
    -- ^ Host given at creation; used to start the server ('forkServer') and
    -- to build the addresses of local endpoints.
  , transportPort   :: N.ServiceName
    -- ^ Port given at creation; used in the same two places as the host.
  , transportState  :: MVar TransportState
    -- ^ Mutable transport state (running or closed).
  , transportParams :: TCPParameters
    -- ^ Parameters the transport was created with.
  }
-- | State of the transport as a whole. 'apiCloseTransport' moves it to
-- 'TransportClosed'.
data TransportState =
    -- | The transport is running.
    TransportValid ValidTransportState
    -- | The transport has been closed.
  | TransportClosed
-- | State of a running transport.
data ValidTransportState = ValidTransportState
  { _localEndPoints :: Map EndPointAddress LocalEndPoint
    -- ^ Local endpoints registered with the transport, keyed by address.
    -- Incoming connection requests are routed to endpoints via this map.
  , _nextEndPointId :: EndPointId
    -- ^ Id assigned to the next endpoint created by 'createLocalEndPoint'.
  }
-- | An endpoint created on this transport.
data LocalEndPoint = LocalEndPoint
  { localAddress :: EndPointAddress
    -- ^ Address of this endpoint.
  , localChannel :: Chan Event
    -- ^ Events for this endpoint; the endpoint's 'receive' reads from here.
  , localState   :: MVar LocalEndPointState
    -- ^ Mutable endpoint state (open or closed).
  }
-- | State of a local endpoint.
data LocalEndPointState =
    -- | The endpoint is open.
    LocalEndPointValid ValidLocalEndPointState
    -- | The endpoint has been closed (see 'apiCloseEndPoint').
  | LocalEndPointClosed
-- | State of an open local endpoint.
data ValidLocalEndPointState = ValidLocalEndPointState
  { _nextConnectionId :: !ConnectionId
    -- ^ Next lightweight connection id to hand out ('getNextConnectionId').
    -- Header values below 'firstNonReservedConnectionId' are reserved for
    -- control messages, so ids start there.
  , _localConnections :: Map EndPointAddress RemoteEndPoint
    -- ^ Remote endpoints we have, or are setting up, a socket connection
    -- with, keyed by their address.
  , _nextRemoteId     :: !Int
    -- ^ Unique id for the next 'RemoteEndPoint' record. It lets
    -- 'removeRemoteEndPoint' tell apart two records for the same address.
  }
-- | A remote endpoint as seen from one of our local endpoints.
data RemoteEndPoint = RemoteEndPoint
  { remoteAddress :: EndPointAddress
    -- ^ Address of the remote endpoint.
  , remoteState   :: MVar RemoteState
    -- ^ State of the (heavyweight) socket connection to it.
  , remoteId      :: Int
    -- ^ Unique per local endpoint (taken from '_nextRemoteId').
  }
-- | Which side initiated the socket connection to a remote endpoint.
data RequestedBy =
    -- | We connected to them (outgoing 'connect').
    RequestedByUs
    -- | They connected to us (incoming connection request).
  | RequestedByThem
  deriving (Eq, Show)
-- | State of the socket connection to a remote endpoint.
data RemoteState =
    -- | Connecting failed, or the remote endpoint does not exist.
    RemoteEndPointInvalid (TransportError ConnectErrorCode)
    -- | The connection is being set up. The 'MVar' is filled (by
    -- 'resolveInit' or 'apiCloseEndPoint') once set-up ends. Threads that
    -- must wait for that call 'readMVar' on it.
  | RemoteEndPointInit (MVar ()) RequestedBy
    -- | The socket is connected and usable.
  | RemoteEndPointValid ValidRemoteEndPointState
    -- | We sent @CloseSocket@ (see 'closeIfUnused') and are waiting for the
    -- peer. The 'MVar' is filled when this closing phase ends.
  | RemoteEndPointClosing (MVar ()) ValidRemoteEndPointState
    -- | The socket was closed in an orderly fashion.
  | RemoteEndPointClosed
    -- | The socket failed with the given error, or our endpoint was closed.
  | RemoteEndPointFailed IOException
-- | State of a connected remote endpoint.
data ValidRemoteEndPointState = ValidRemoteEndPointState
  { _remoteOutgoing      :: !Int
    -- ^ Number of our outgoing lightweight connections on this socket.
    -- It is incremented in 'findRemoteEndPoint' and decremented in 'apiClose'.
  , _remoteIncoming      :: IntSet
    -- ^ Ids of the peer's lightweight connections to us on this socket.
  , remoteSocket         :: N.Socket
    -- ^ The underlying socket.
  , sendOn               :: [ByteString] -> IO ()
    -- ^ Sends the given chunks; every visible construction site uses
    -- @sendMany@ on 'remoteSocket'.
  , _pendingCtrlRequests :: IntMap (MVar (Either IOException [ByteString]))
    -- ^ Outstanding control requests, keyed by request id. Each 'MVar' is
    -- filled with the reply, or with the error that broke the connection.
  , _nextCtrlRequestId   :: !ControlRequestId
    -- ^ Id for the next control request.
  }
-- | Local endpoint ids; encoded into endpoint addresses.
type EndPointId       = Int32
-- | Ids used to match control requests with their responses.
type ControlRequestId = Int32
-- | A local endpoint together with one of its remote endpoints.
type EndPointPair     = (LocalEndPoint, RemoteEndPoint)
-- | Control messages. They share the message-header slot with connection
-- ids: a header value below 'firstNonReservedConnectionId' is decoded as a
-- 'ControlHeader' using 'tryToEnum'.
--
-- NOTE(review): encoding ('encodeInt32') and decoding ('tryToEnum')
-- presumably go through the derived 'Enum' instance. If so, the constructor
-- order defines the wire values and must not change.
data ControlHeader =
    -- | Ask the peer for a new lightweight connection id.
    RequestConnectionId
    -- | Close the lightweight connection whose id follows.
  | CloseConnection
    -- | Reply to a control request. It is followed by the request id and a
    -- length-prefixed payload.
  | ControlResponse
    -- | Ask the peer to close the socket, or confirm that we do.
  | CloseSocket
  deriving (Enum, Bounded, Show)
-- | Reply the accepting side sends to an initial connection request.
--
-- NOTE(review): the reply is decoded with 'tryToEnum', so the constructor
-- order presumably fixes the wire values. Do not reorder.
data ConnectionRequestResponse =
    -- | The connection was accepted.
    ConnectionRequestAccepted
    -- | The requested endpoint does not exist.
  | ConnectionRequestInvalid
    -- | Both sides tried to connect at the same time; this socket is dropped.
  | ConnectionRequestCrossed
  deriving (Enum, Bounded, Show)
-- | Parameters for setting up the TCP transport.
data TCPParameters = TCPParameters {
    -- | Backlog passed to 'forkServer' for the listening socket.
    tcpBacklog         :: Int
    -- | Passed to 'forkServer'; presumably enables address reuse on the
    -- server socket.
  , tcpReuseServerAddr :: Bool
    -- | Whether outgoing sockets set 'N.ReuseAddr' before connecting.
  , tcpReuseClientAddr :: Bool
  }
-- | Internals of a TCP transport, returned by
-- 'createTransportExposeInternals'.
data TransportInternals = TransportInternals
  {
    -- | Thread returned by 'forkServer' (killed by 'closeTransport').
    transportThread :: ThreadId
    -- | Obtain the socket between two endpoints. Implemented by
    -- 'internalSocketBetween', which is defined outside this section.
  , socketBetween   :: EndPointAddress
                    -> EndPointAddress
                    -> IO N.Socket
  }
-- | Create a TCP transport listening on the given host and port.
--
-- Same as 'createTransportExposeInternals', except that the
-- 'TransportInternals' are discarded.
createTransport :: N.HostName
                -> N.ServiceName
                -> TCPParameters
                -> IO (Either IOException Transport)
createTransport host port params = do
  result <- createTransportExposeInternals host port params
  return (fst <$> result)
-- | Create a TCP transport and also return its 'TransportInternals'.
--
-- The server thread is started with 'forkServer'. If building the result
-- fails, the thread is killed again. 'IOException's are returned as 'Left'.
-- When the server terminates with an exception, the transport is closed and
-- every endpoint receives an 'EventTransportFailed' error event.
createTransportExposeInternals
  :: N.HostName
  -> N.ServiceName
  -> TCPParameters
  -> IO (Either IOException (Transport, TransportInternals))
createTransportExposeInternals host port params = do
    stateVar <- newMVar $ TransportValid ValidTransportState
      { _localEndPoints = Map.empty
      , _nextEndPointId = 0
      }
    let tcp = TCPTransport { transportHost   = host
                           , transportPort   = port
                           , transportState  = stateVar
                           , transportParams = params
                           }
        startServer = forkServer host
                                 port
                                 (tcpBacklog params)
                                 (tcpReuseServerAddr params)
                                 (terminationHandler tcp)
                                 (handleConnectionRequest tcp)
    tryIO $ bracketOnError startServer killThread (mkTransport tcp)
  where
    -- Package the public 'Transport' and its internals for the server thread.
    mkTransport :: TCPTransport
                -> ThreadId
                -> IO (Transport, TransportInternals)
    mkTransport tcp serverThread =
        return (transport, internals)
      where
        transport = Transport
          { newEndPoint    = apiNewEndPoint tcp
          , closeTransport = apiCloseTransport tcp (Just serverThread) closeEvents
          }
        internals = TransportInternals
          { transportThread = serverThread
          , socketBetween   = internalSocketBetween tcp
          }

    -- Events delivered to each endpoint on an orderly transport close. The
    -- second one throws when read, so reading past the close fails.
    closeEvents :: [Event]
    closeEvents = [ EndPointClosed
                  , throw $ userError "Transport closed"
                  ]

    -- Called when the server thread dies with an exception.
    terminationHandler :: TCPTransport -> SomeException -> IO ()
    terminationHandler tcp ex =
      apiCloseTransport tcp Nothing
        [ ErrorEvent (TransportError EventTransportFailed (show ex))
        , throw $ userError "Transport closed"
        ]
-- | Default 'TCPParameters': a listen backlog of 'N.sOMAXCONN', with address
-- reuse enabled for both the server socket and outgoing client sockets.
defaultTCPParameters :: TCPParameters
defaultTCPParameters = TCPParameters
  { tcpReuseServerAddr = True
  , tcpReuseClientAddr = True
  , tcpBacklog         = N.sOMAXCONN
  }
-- | Close the transport.
--
-- The transport is marked 'TransportClosed'. If it was still running, every
-- registered local endpoint is closed via 'apiCloseEndPoint' with @evs@.
-- Finally the given thread, if any, is killed.
apiCloseTransport :: TCPTransport -> Maybe ThreadId -> [Event] -> IO ()
apiCloseTransport transport mTransportThread evs =
    asyncWhenCancelled return $ do
      mOldState <- modifyMVar (transportState transport) takeValid
      case mOldState of
        Nothing  -> return ()
        Just vst -> mapM_ (apiCloseEndPoint transport evs) (vst ^. localEndPoints)
      maybe (return ()) killThread mTransportThread
  where
    -- Mark the transport closed, returning the old state if it was running.
    takeValid :: TransportState -> IO (TransportState, Maybe ValidTransportState)
    takeValid (TransportValid vst) = return (TransportClosed, Just vst)
    takeValid TransportClosed      = return (TransportClosed, Nothing)
-- | Create a new endpoint on this transport.
--
-- A @TransportError NewEndPointErrorCode@ thrown during creation (for
-- example because the transport is closed) is returned as 'Left'.
-- Multicast is not supported.
apiNewEndPoint :: TCPTransport
               -> IO (Either (TransportError NewEndPointErrorCode) EndPoint)
apiNewEndPoint transport =
    try . asyncWhenCancelled closeEndPoint $
      mkEndPoint <$> createLocalEndPoint transport
  where
    -- Wrap the internal endpoint in the public 'EndPoint' interface.
    mkEndPoint :: LocalEndPoint -> EndPoint
    mkEndPoint ourEndPoint = EndPoint
      { receive               = readChan (localChannel ourEndPoint)
      , address               = localAddress ourEndPoint
      , connect               = apiConnect (transportParams transport) ourEndPoint
      , closeEndPoint         = apiCloseEndPoint transport endPointClosedEvents ourEndPoint
      , newMulticastGroup     = return (Left newMulticastGroupError)
      , resolveMulticastGroup = \_ -> return (Left resolveMulticastGroupError)
      }

    -- Events delivered when the endpoint is closed by the user. The second
    -- one throws when read.
    endPointClosedEvents :: [Event]
    endPointClosedEvents = [ EndPointClosed
                           , throw $ userError "Endpoint closed"
                           ]

    newMulticastGroupError =
      TransportError NewMulticastGroupUnsupported "Multicast not supported"
    resolveMulticastGroupError =
      TransportError ResolveMulticastGroupUnsupported "Multicast not supported"
-- | Open a lightweight connection from our endpoint to @theirAddress@.
--
-- Connecting to ourselves is short-circuited through 'connectToSelf'.
-- Otherwise a stale (invalid or failed) remote record is dropped first, and
-- a connection id is requested from the remote endpoint. A
-- @TransportError ConnectErrorCode@ is returned as 'Left'.
apiConnect :: TCPParameters
           -> LocalEndPoint
           -> EndPointAddress
           -> Reliability
           -> ConnectHints
           -> IO (Either (TransportError ConnectErrorCode) Connection)
apiConnect params ourEndPoint theirAddress _reliability hints =
    try . asyncWhenCancelled close $ establish
  where
    establish :: IO Connection
    establish
      | localAddress ourEndPoint == theirAddress = connectToSelf ourEndPoint
      | otherwise = do
          resetIfBroken ourEndPoint theirAddress
          (theirEndPoint, connId) <-
            requestConnectionTo params ourEndPoint theirAddress hints
          let pair = (ourEndPoint, theirEndPoint)
          connAlive <- newIORef True
          return Connection
            { send  = apiSend pair connId connAlive
            , close = apiClose pair connId connAlive
            }
-- | Close a lightweight connection to a remote endpoint.
--
-- If the remote endpoint is valid and the connection is still alive, the
-- connection is marked dead, a @CloseConnection@ message is sent, and the
-- remote endpoint's outgoing connection count is decremented. In every
-- other case the remote state is left unchanged. Afterwards
-- 'closeIfUnused' may start closing the socket. 'IOException's are
-- swallowed, so closing is best effort and closing twice is harmless.
apiClose :: EndPointPair -> ConnectionId -> IORef Bool -> IO ()
apiClose (ourEndPoint, theirEndPoint) connId connAlive =
    void . tryIO . asyncWhenCancelled return $ do
      modifyRemoteState_ (ourEndPoint, theirEndPoint) remoteStateIdentity
        { caseValid = \vst -> do
            alive <- readIORef connAlive
            if alive
              then do
                writeIORef connAlive False
                sendOn vst [encodeInt32 CloseConnection, encodeInt32 connId]
                -- One fewer outgoing connection now uses this socket.
                return ( RemoteEndPointValid
                       . (remoteOutgoing ^: subtract 1)
                       $ vst
                       )
              else
                return (RemoteEndPointValid vst)
        }
      closeIfUnused (ourEndPoint, theirEndPoint)
-- | Send a message on a lightweight connection.
--
-- The message is framed as the connection id followed by the
-- length-prefixed payload. Sending on a connection that was closed yields
-- 'SendClosed'. IO errors, and sending while the remote endpoint is in the
-- failed state, yield 'SendFailed'. Sending on a live connection whose
-- remote endpoint is in an unexpected state is a RELY violation.
apiSend :: EndPointPair
        -> ConnectionId
        -> IORef Bool
        -> [ByteString]
        -> IO (Either (TransportError SendErrorCode) ())
apiSend (ourEndPoint, theirEndPoint) connId connAlive payload =
    try . mapIOException sendFailed $
      withRemoteState pair RemoteStatePatternMatch
        { caseInvalid = \_ -> violation
        , caseInit    = \_ _ -> violation
        , caseValid   = \vst -> ifAlive $
            sendOn vst (encodeInt32 connId : prependLength payload)
        , caseClosing = \_ _ -> ifAlive violation
        , caseClosed  = ifAlive violation
        , caseFailed  = \err -> ifAlive $
            throwIO $ TransportError SendFailed (show err)
        }
  where
    pair = (ourEndPoint, theirEndPoint)

    violation :: IO ()
    violation = relyViolation pair "apiSend"

    -- Run the action if the connection is still alive; otherwise report it
    -- as closed.
    ifAlive :: IO () -> IO ()
    ifAlive action = do
      alive <- readIORef connAlive
      if alive
        then action
        else throwIO $ TransportError SendClosed "Connection closed"

    sendFailed :: IOException -> TransportError SendErrorCode
    sendFailed = TransportError SendFailed . show
-- | Close a local endpoint.
--
-- The endpoint is first removed from the transport and then marked
-- 'LocalEndPointClosed'. The following steps run only if it was still open.
-- Every remote endpoint in its table is moved to 'RemoteEndPointFailed':
-- its socket is closed (after a best-effort @CloseSocket@ if it was valid),
-- and threads waiting on an Init or Closing state are woken. Then @evs@ are
-- written to the endpoint's event channel. Closing twice therefore emits the
-- events only once.
apiCloseEndPoint :: TCPTransport
                 -> [Event]
                 -> LocalEndPoint
                 -> IO ()
apiCloseEndPoint transport evs ourEndPoint =
    asyncWhenCancelled return $ do
      removeLocalEndPoint transport ourEndPoint
      mOurState <- modifyMVar (localState ourEndPoint) markClosed
      case mOurState of
        Nothing  -> return ()
        Just vst -> do
          mapM_ tryCloseRemoteSocket (vst ^. localConnections)
          mapM_ (writeChan (localChannel ourEndPoint)) evs
  where
    -- Close the endpoint, returning its previous state if it was open.
    markClosed :: LocalEndPointState
               -> IO (LocalEndPointState, Maybe ValidLocalEndPointState)
    markClosed (LocalEndPointValid vst) = return (LocalEndPointClosed, Just vst)
    markClosed LocalEndPointClosed      = return (LocalEndPointClosed, Nothing)

    -- Force a remote endpoint into the failed state, releasing its socket and
    -- any waiting threads. Invalid, Closed and Failed states are kept as is.
    tryCloseRemoteSocket :: RemoteEndPoint -> IO ()
    tryCloseRemoteSocket theirEndPoint =
        modifyMVar_ (remoteState theirEndPoint) closeRemote
      where
        failed :: RemoteState
        failed = RemoteEndPointFailed . userError $ "apiCloseEndPoint"

        closeRemote :: RemoteState -> IO RemoteState
        closeRemote st = case st of
          RemoteEndPointInit resolved _ -> do
            putMVar resolved ()
            return failed
          RemoteEndPointValid conn -> do
            void . tryIO $ sendOn conn [encodeInt32 CloseSocket]
            tryCloseSocket (remoteSocket conn)
            return failed
          RemoteEndPointClosing resolved conn -> do
            putMVar resolved ()
            tryCloseSocket (remoteSocket conn)
            return failed
          _ ->
            return st
-- | One handler per 'RemoteState' constructor. It is used with
-- 'modifyRemoteState', 'modifyRemoteState_' and 'withRemoteState', which
-- run the handler matching the current state.
data RemoteStatePatternMatch a = RemoteStatePatternMatch
  { caseInvalid :: TransportError ConnectErrorCode -> IO a
    -- ^ Handler for 'RemoteEndPointInvalid'.
  , caseInit    :: MVar () -> RequestedBy -> IO a
    -- ^ Handler for 'RemoteEndPointInit'.
  , caseValid   :: ValidRemoteEndPointState -> IO a
    -- ^ Handler for 'RemoteEndPointValid'.
  , caseClosing :: MVar () -> ValidRemoteEndPointState -> IO a
    -- ^ Handler for 'RemoteEndPointClosing'.
  , caseClosed  :: IO a
    -- ^ Handler for 'RemoteEndPointClosed'.
  , caseFailed  :: IOException -> IO a
    -- ^ Handler for 'RemoteEndPointFailed'.
  }
-- | A pattern match that leaves every state unchanged. It is meant to be
-- overridden for specific cases with record update syntax, e.g.
-- @remoteStateIdentity { caseValid = ... }@.
remoteStateIdentity :: RemoteStatePatternMatch RemoteState
remoteStateIdentity = RemoteStatePatternMatch
  { caseInvalid = \err             -> return (RemoteEndPointInvalid err)
  , caseInit    = \resolved origin -> return (RemoteEndPointInit resolved origin)
  , caseValid   = \vst             -> return (RemoteEndPointValid vst)
  , caseClosing = \resolved vst    -> return (RemoteEndPointClosing resolved vst)
  , caseClosed  = return RemoteEndPointClosed
  , caseFailed  = \err             -> return (RemoteEndPointFailed err)
  }
-- | Run the handler from @match@ that corresponds to the current state of
-- the remote endpoint. The state is replaced with the one the handler
-- returns, and the handler's second component is returned.
--
-- The state 'MVar' is taken under 'mask', and the handler runs with async
-- exceptions restored. If a handler throws, the original state is put back
-- and the exception is rethrown. The exception is an 'IOException' raised by
-- 'caseValid', which is treated as a broken connection: the socket is
-- closed, the state becomes 'RemoteEndPointFailed', and an
-- 'EventConnectionLost' error event is written to our endpoint's channel
-- before the exception is rethrown.
modifyRemoteState :: EndPointPair
                  -> RemoteStatePatternMatch (RemoteState, a)
                  -> IO a
modifyRemoteState (ourEndPoint, theirEndPoint) match =
    mask $ \restore -> do
      st <- takeMVar theirState
      case st of
        RemoteEndPointValid vst -> do
          mResult <- try $ restore (caseValid match vst)
          case mResult of
            Right (st', a) -> do
              putMVar theirState st'
              return a
            Left ex -> do
              case fromException ex of
                -- IO error on a valid connection: handleIOException puts
                -- the Failed state back into the MVar.
                Just ioEx -> handleIOException ioEx vst
                -- Any other exception: restore the original state.
                Nothing   -> putMVar theirState st
              throwIO ex
        -- For all other states an exception simply restores the old state.
        RemoteEndPointInit resolved origin -> do
          (st', a) <- onException (restore $ caseInit match resolved origin)
                                  (putMVar theirState st)
          putMVar theirState st'
          return a
        RemoteEndPointClosing resolved vst -> do
          (st', a) <- onException (restore $ caseClosing match resolved vst)
                                  (putMVar theirState st)
          putMVar theirState st'
          return a
        RemoteEndPointInvalid err -> do
          (st', a) <- onException (restore $ caseInvalid match err)
                                  (putMVar theirState st)
          putMVar theirState st'
          return a
        RemoteEndPointClosed -> do
          (st', a) <- onException (restore $ caseClosed match)
                                  (putMVar theirState st)
          putMVar theirState st'
          return a
        RemoteEndPointFailed err -> do
          (st', a) <- onException (restore $ caseFailed match err)
                                  (putMVar theirState st)
          putMVar theirState st'
          return a
  where
    theirState :: MVar RemoteState
    theirState = remoteState theirEndPoint

    -- Mark the endpoint failed and report the lost incoming connections.
    -- NOTE(review): unlike prematureExit in handleIncomingMessages, this does
    -- not fail the MVars in _pendingCtrlRequests; confirm that waiters are
    -- released some other way.
    handleIOException :: IOException -> ValidRemoteEndPointState -> IO ()
    handleIOException ex vst = do
      tryCloseSocket (remoteSocket vst)
      putMVar theirState (RemoteEndPointFailed ex)
      let incoming = IntSet.elems $ vst ^. remoteIncoming
          code     = EventConnectionLost (Just $ remoteAddress theirEndPoint) incoming
          err      = TransportError code (show ex)
      writeChan (localChannel ourEndPoint) $ ErrorEvent err
-- | Variant of 'modifyRemoteState' for handlers that only compute the new
-- state and have no separate result. Exception behaviour is the same.
modifyRemoteState_ :: EndPointPair
                   -> RemoteStatePatternMatch RemoteState
                   -> IO ()
modifyRemoteState_ (ourEndPoint, theirEndPoint) match =
    modifyRemoteState (ourEndPoint, theirEndPoint) RemoteStatePatternMatch
      { caseInvalid = \err             -> withUnit (caseInvalid match err)
      , caseInit    = \resolved origin -> withUnit (caseInit match resolved origin)
      , caseValid   = \vst             -> withUnit (caseValid match vst)
      , caseClosing = \resolved vst    -> withUnit (caseClosing match resolved vst)
      , caseClosed  = withUnit (caseClosed match)
      , caseFailed  = \err             -> withUnit (caseFailed match err)
      }
  where
    -- Pair the handler's new state with a trivial result.
    withUnit :: IO a -> IO (a, ())
    withUnit = fmap (\a -> (a, ()))
-- | Run the handler for the current remote state and return its result,
-- putting the state back unchanged. Exceptions are handled as in
-- 'modifyRemoteState'; in particular, an 'IOException' from 'caseValid'
-- still moves the endpoint to the failed state.
withRemoteState :: EndPointPair
                -> RemoteStatePatternMatch a
                -> IO a
withRemoteState (ourEndPoint, theirEndPoint) match =
    modifyRemoteState (ourEndPoint, theirEndPoint) RemoteStatePatternMatch
      { caseInvalid = \err ->
          unchanged (RemoteEndPointInvalid err) (caseInvalid match err)
      , caseInit    = \resolved origin ->
          unchanged (RemoteEndPointInit resolved origin)
                    (caseInit match resolved origin)
      , caseValid   = \vst ->
          unchanged (RemoteEndPointValid vst) (caseValid match vst)
      , caseClosing = \resolved vst ->
          unchanged (RemoteEndPointClosing resolved vst)
                    (caseClosing match resolved vst)
      , caseClosed  =
          unchanged RemoteEndPointClosed (caseClosed match)
      , caseFailed  = \err ->
          unchanged (RemoteEndPointFailed err) (caseFailed match err)
      }
  where
    -- Run the handler and pair its result with the (unchanged) state.
    unchanged :: RemoteState -> IO b -> IO (RemoteState, b)
    unchanged st handler = do
      result <- handler
      return (st, result)
-- | Handle a new incoming TCP connection accepted by the server.
--
-- The peer first sends the id of the local endpoint it wants to reach, then
-- its own length-prefixed address. If no such local endpoint exists,
-- @ConnectionRequestInvalid@ is sent and the socket is closed. Otherwise the
-- rest of the handshake and the message loop run in a new thread (@go@).
--
-- Any exception closes the socket. Asynchronous exceptions are rethrown;
-- other exceptions are swallowed.
handleConnectionRequest :: TCPTransport -> N.Socket -> IO ()
handleConnectionRequest transport sock = handle handleException $ do
    ourEndPointId <- recvInt32 sock
    theirAddress  <- EndPointAddress . BS.concat <$> recvWithLength sock
    let ourAddress = encodeEndPointAddress (transportHost transport)
                                           (transportPort transport)
                                           ourEndPointId
    ourEndPoint <- withMVar (transportState transport) $ \st -> case st of
      TransportValid vst ->
        case vst ^. localEndPointAt ourAddress of
          Nothing -> do
            sendMany sock [encodeInt32 ConnectionRequestInvalid]
            throwIO $ userError "handleConnectionRequest: Invalid endpoint"
          Just ourEndPoint ->
            return ourEndPoint
      TransportClosed ->
        throwIO $ userError "Transport closed"
    void . forkIO $ go ourEndPoint theirAddress
  where
    -- Finish the handshake and, if accepted, process incoming messages.
    go :: LocalEndPoint -> EndPointAddress -> IO ()
    go ourEndPoint theirAddress = do
      mEndPoint <- handle ((>> return Nothing) . handleException) $ do
        resetIfBroken ourEndPoint theirAddress
        (theirEndPoint, isNew) <-
          findRemoteEndPoint ourEndPoint theirAddress RequestedByThem
        if not isNew
          then do
            -- We already have a connection (or one being set up) to this
            -- peer: report the crossed request and drop this socket.
            tryIO $ sendMany sock [encodeInt32 ConnectionRequestCrossed]
            tryCloseSocket sock
            return Nothing
          else do
            let vst = ValidRemoteEndPointState
                        { remoteSocket         = sock
                        , _remoteOutgoing      = 0
                        , _remoteIncoming      = IntSet.empty
                        , sendOn               = sendMany sock
                        , _pendingCtrlRequests = IntMap.empty
                        , _nextCtrlRequestId   = 0
                        }
            sendMany sock [encodeInt32 ConnectionRequestAccepted]
            resolveInit (ourEndPoint, theirEndPoint) (RemoteEndPointValid vst)
            return (Just theirEndPoint)
      -- Nothing: the socket has already been closed above (or by
      -- handleException). Just: the socket now belongs to the remote
      -- endpoint's state and its messages are processed here.
      forM_ mEndPoint $ handleIncomingMessages . (,) ourEndPoint

    -- Close the socket; rethrow only asynchronous exceptions.
    handleException :: SomeException -> IO ()
    handleException ex = do
      tryCloseSocket sock
      rethrowIfAsync (fromException ex)

    rethrowIfAsync :: Maybe AsyncException -> IO ()
    rethrowIfAsync = mapM_ throwIO
-- | Process messages arriving on the socket of a remote endpoint until the
-- socket is closed or fails.
--
-- Each message starts with an 'Int32'. Values at or above
-- 'firstNonReservedConnectionId' are connection ids, followed by a
-- length-prefixed payload that is delivered as a 'Received' event. Smaller
-- values are decoded as a 'ControlHeader'. An IO error (including an
-- unknown header) ends the loop via @prematureExit@. Nothing happens if the
-- endpoint is already closed or failed when this is called.
handleIncomingMessages :: EndPointPair -> IO ()
handleIncomingMessages (ourEndPoint, theirEndPoint) = do
    mSock <- withMVar theirState $ \st ->
      case st of
        RemoteEndPointInvalid _ ->
          relyViolation (ourEndPoint, theirEndPoint)
            "handleIncomingMessages (invalid)"
        RemoteEndPointInit _ _ ->
          relyViolation (ourEndPoint, theirEndPoint)
            "handleIncomingMessages (init)"
        RemoteEndPointValid ep ->
          return . Just $ remoteSocket ep
        RemoteEndPointClosing _ ep ->
          return . Just $ remoteSocket ep
        RemoteEndPointClosed ->
          return Nothing
        RemoteEndPointFailed _ ->
          return Nothing
    forM_ mSock $ \sock ->
      tryIO (go sock) >>= either (prematureExit sock) return
  where
    -- Read one message header, dispatch on it, and loop.
    go :: N.Socket -> IO ()
    go sock = do
      connId <- recvInt32 sock
      if connId >= firstNonReservedConnectionId
        then do
          readMessage sock connId
          go sock
        else
          case tryToEnum (fromIntegral connId) of
            Just RequestConnectionId -> do
              recvInt32 sock >>= createNewConnection
              go sock
            Just ControlResponse -> do
              recvInt32 sock >>= readControlResponse sock
              go sock
            Just CloseConnection -> do
              recvInt32 sock >>= closeConnection
              go sock
            Just CloseSocket -> do
              didClose <- closeSocket sock
              unless didClose $ go sock
            Nothing ->
              throwIO $ userError "Invalid control request"

    -- The peer asks for a new connection: allocate an id, record it as
    -- incoming, reply with a ControlResponse, and emit ConnectionOpened.
    createNewConnection :: ControlRequestId -> IO ()
    createNewConnection reqId = do
      newId <- getNextConnectionId ourEndPoint
      modifyMVar_ theirState $ \st -> do
        vst <- case st of
          RemoteEndPointInvalid _ ->
            relyViolation (ourEndPoint, theirEndPoint)
              "handleIncomingMessages:createNewConnection (invalid)"
          RemoteEndPointInit _ _ ->
            relyViolation (ourEndPoint, theirEndPoint)
              "handleIncomingMessages:createNewConnection (init)"
          RemoteEndPointValid vst ->
            return (remoteIncoming ^: IntSet.insert newId $ vst)
          RemoteEndPointClosing resolved vst -> do
            -- We asked to close the socket, but the peer opened a new
            -- connection: end the closing phase (waking waiters) and go
            -- back to Valid with just this incoming connection.
            putMVar resolved ()
            return (remoteIncoming ^= IntSet.singleton newId $ vst)
          RemoteEndPointFailed err ->
            throwIO err
          RemoteEndPointClosed ->
            relyViolation (ourEndPoint, theirEndPoint)
              "createNewConnection (closed)"
        sendOn vst ( encodeInt32 ControlResponse
                   : encodeInt32 reqId
                   : prependLength [encodeInt32 newId]
                   )
        return (RemoteEndPointValid vst)
      writeChan ourChannel (ConnectionOpened newId ReliableOrdered theirAddr)

    -- Hand the response to control request reqId to the thread waiting in
    -- doRemoteRequest, and remove the request from the pending table.
    readControlResponse :: N.Socket -> ControlRequestId -> IO ()
    readControlResponse sock reqId = do
      response <- recvWithLength sock
      mmvar <- modifyMVar theirState $ \st -> case st of
        RemoteEndPointInvalid _ ->
          relyViolation (ourEndPoint, theirEndPoint)
            "readControlResponse (invalid)"
        RemoteEndPointInit _ _ ->
          relyViolation (ourEndPoint, theirEndPoint)
            "readControlResponse (init)"
        RemoteEndPointValid vst ->
          return ( RemoteEndPointValid
                 . (pendingCtrlRequestsAt reqId ^= Nothing)
                 $ vst
                 , vst ^. pendingCtrlRequestsAt reqId
                 )
        RemoteEndPointClosing _ _ ->
          throwIO $ userError "Invalid control response"
        RemoteEndPointFailed err ->
          throwIO err
        RemoteEndPointClosed ->
          relyViolation (ourEndPoint, theirEndPoint)
            "readControlResponse (closed)"
      case mmvar of
        Nothing ->
          throwIO $ userError "Invalid request ID"
        Just mvar ->
          putMVar mvar (Right response)

    -- The peer closed one of its connections to us.
    closeConnection :: ConnectionId -> IO ()
    closeConnection cid = do
      modifyMVar_ theirState $ \st -> case st of
        RemoteEndPointInvalid _ ->
          relyViolation (ourEndPoint, theirEndPoint) "closeConnection (invalid)"
        RemoteEndPointInit _ _ ->
          relyViolation (ourEndPoint, theirEndPoint) "closeConnection (init)"
        RemoteEndPointValid vst -> do
          unless (IntSet.member cid (vst ^. remoteIncoming)) $
            throwIO $ userError "Invalid CloseConnection"
          return ( RemoteEndPointValid
                 . (remoteIncoming ^: IntSet.delete cid)
                 $ vst
                 )
        RemoteEndPointClosing _ _ ->
          -- In Closing state we have no incoming connections, so the
          -- request cannot refer to a known connection.
          throwIO $ userError "Invalid CloseConnection request"
        RemoteEndPointFailed err ->
          throwIO err
        RemoteEndPointClosed ->
          relyViolation (ourEndPoint, theirEndPoint) "closeConnection (closed)"
      writeChan ourChannel (ConnectionClosed cid)
      closeIfUnused (ourEndPoint, theirEndPoint)

    -- The peer sent CloseSocket. Returns True if the socket was closed,
    -- which ends the message loop.
    closeSocket :: N.Socket -> IO Bool
    closeSocket sock =
      modifyMVar theirState $ \st ->
        case st of
          RemoteEndPointInvalid _ ->
            relyViolation (ourEndPoint, theirEndPoint)
              "handleIncomingMessages:closeSocket (invalid)"
          RemoteEndPointInit _ _ ->
            relyViolation (ourEndPoint, theirEndPoint)
              "handleIncomingMessages:closeSocket (init)"
          RemoteEndPointValid vst -> do
            -- All of the peer's connections to us are now closed.
            forM_ (IntSet.elems $ vst ^. remoteIncoming) $
              writeChan ourChannel . ConnectionClosed
            let vst' = remoteIncoming ^= IntSet.empty $ vst
            -- Only close if we have no outgoing connections on it either.
            if vst' ^. remoteOutgoing == 0
              then do
                removeRemoteEndPoint (ourEndPoint, theirEndPoint)
                -- Best-effort reply; a send error is ignored.
                tryIO $ sendOn vst' [encodeInt32 CloseSocket]
                tryCloseSocket sock
                return (RemoteEndPointClosed, True)
              else
                return (RemoteEndPointValid vst', False)
          RemoteEndPointClosing resolved _ -> do
            -- We had asked to close as well: the close is now agreed.
            removeRemoteEndPoint (ourEndPoint, theirEndPoint)
            tryCloseSocket sock
            putMVar resolved ()
            return (RemoteEndPointClosed, True)
          RemoteEndPointFailed err ->
            throwIO err
          RemoteEndPointClosed ->
            relyViolation (ourEndPoint, theirEndPoint)
              "handleIncomingMessages:closeSocket (closed)"

    -- Deliver a length-prefixed payload on connection connId.
    readMessage :: N.Socket -> ConnectionId -> IO ()
    readMessage sock connId =
      recvWithLength sock >>= writeChan ourChannel . Received connId

    ourChannel = localChannel ourEndPoint
    theirState = remoteState theirEndPoint
    theirAddr  = remoteAddress theirEndPoint

    -- The loop failed with an IO error: close the socket and mark the
    -- endpoint failed. If it was valid, report EventConnectionLost and fail
    -- all pending control requests with the error.
    prematureExit :: N.Socket -> IOException -> IO ()
    prematureExit sock err = do
      tryCloseSocket sock
      modifyMVar_ theirState $ \st ->
        case st of
          RemoteEndPointInvalid _ ->
            relyViolation (ourEndPoint, theirEndPoint)
              "handleIncomingMessages:prematureExit"
          RemoteEndPointInit _ _ ->
            relyViolation (ourEndPoint, theirEndPoint)
              "handleIncomingMessages:prematureExit"
          RemoteEndPointValid vst -> do
            let code = EventConnectionLost
                         (Just $ remoteAddress theirEndPoint)
                         (IntSet.elems $ vst ^. remoteIncoming)
            writeChan ourChannel . ErrorEvent $ TransportError code (show err)
            forM_ (vst ^. pendingCtrlRequests) $ flip putMVar (Left err)
            return (RemoteEndPointFailed err)
          RemoteEndPointClosing resolved _ -> do
            putMVar resolved ()
            return (RemoteEndPointFailed err)
          RemoteEndPointClosed ->
            relyViolation (ourEndPoint, theirEndPoint)
              "handleIncomingMessages:prematureExit"
          RemoteEndPointFailed err' ->
            return (RemoteEndPointFailed err')
-- | Obtain a new lightweight connection id from a remote endpoint,
-- establishing the socket connection first if needed.
--
-- When no remote endpoint record exists yet, one is created and the socket
-- set-up is started in a background thread. We then try again, so that
-- 'findRemoteEndPoint' waits for that set-up to be resolved. Once an
-- existing endpoint is found, a @RequestConnectionId@ control request is
-- sent and the reply is decoded as the new connection id. IO errors are
-- reported as 'ConnectFailed'.
requestConnectionTo :: TCPParameters
                    -> LocalEndPoint
                    -> EndPointAddress
                    -> ConnectHints
                    -> IO (RemoteEndPoint, ConnectionId)
requestConnectionTo params ourEndPoint theirAddress hints = attempt
  where
    attempt :: IO (RemoteEndPoint, ConnectionId)
    attempt = do
      (theirEndPoint, isNew) <- mapIOException connectFailed $
        findRemoteEndPoint ourEndPoint theirAddress RequestedByUs
      if isNew
        then startSetup theirEndPoint >> attempt
        else requestId theirEndPoint

    -- Connect the socket in the background; its outcome is published
    -- through the endpoint's state, so exceptions here are ignored.
    startSetup :: RemoteEndPoint -> IO ()
    startSetup theirEndPoint =
      void . forkIO . handle absorbAllExceptions $
        setupRemoteEndPoint params (ourEndPoint, theirEndPoint) hints

    requestId :: RemoteEndPoint -> IO (RemoteEndPoint, ConnectionId)
    requestId theirEndPoint = do
      reply <- mapIOException connectFailed $
        doRemoteRequest (ourEndPoint, theirEndPoint) RequestConnectionId
      return (theirEndPoint, decodeInt32 (BS.concat reply))

    connectFailed :: IOException -> TransportError ConnectErrorCode
    connectFailed = TransportError ConnectFailed . show

    absorbAllExceptions :: SomeException -> IO ()
    absorbAllExceptions _ = return ()
-- | Connect the socket for a remote endpoint in Init state and resolve that
-- state from the peer's reply.
--
-- * Accepted: the endpoint becomes Valid, and its messages are then
--   processed in this thread.
-- * Invalid reply or connection error: the endpoint becomes Invalid.
-- * Crossed: the endpoint becomes Closed and this socket is dropped.
setupRemoteEndPoint :: TCPParameters -> EndPointPair -> ConnectHints -> IO ()
setupRemoteEndPoint params (ourEndPoint, theirEndPoint) hints = do
    result <- socketToEndPoint ourAddress
                               theirAddress
                               (tcpReuseClientAddr params)
                               (connectTimeout hints)
    accepted <- either connectionFailed connected result
    when accepted $ handleIncomingMessages pair
  where
    pair         = (ourEndPoint, theirEndPoint)
    ourAddress   = localAddress ourEndPoint
    theirAddress = remoteAddress theirEndPoint

    connectionFailed :: TransportError ConnectErrorCode -> IO Bool
    connectionFailed err = do
      resolveInit pair (RemoteEndPointInvalid err)
      return False

    connected :: (N.Socket, ConnectionRequestResponse) -> IO Bool
    connected (sock, response) = case response of
      ConnectionRequestAccepted -> do
        resolveInit pair . RemoteEndPointValid $ ValidRemoteEndPointState
          { remoteSocket         = sock
          , _remoteOutgoing      = 0
          , _remoteIncoming      = IntSet.empty
          , sendOn               = sendMany sock
          , _pendingCtrlRequests = IntMap.empty
          , _nextCtrlRequestId   = 0
          }
        return True
      ConnectionRequestInvalid -> do
        resolveInit pair . RemoteEndPointInvalid $
          TransportError ConnectNotFound "setupRemoteEndPoint: Invalid endpoint"
        tryCloseSocket sock
        return False
      ConnectionRequestCrossed -> do
        resolveInit pair RemoteEndPointClosed
        tryCloseSocket sock
        return False
-- | Send a control request to the remote endpoint and block until the
-- response arrives.
--
-- A fresh request id is allocated, and an 'MVar' is registered under it in
-- the pending-requests table. 'handleIncomingMessages' fills that 'MVar'
-- with the response, or with the 'IOException' if the connection is lost;
-- in the latter case the error is rethrown here. The stored error is also
-- thrown if the remote endpoint is Invalid or Failed.
doRemoteRequest :: EndPointPair -> ControlHeader -> IO [ByteString]
doRemoteRequest (ourEndPoint, theirEndPoint) header = do
    replyMVar <- newEmptyMVar
    modifyRemoteState_ pair RemoteStatePatternMatch
      { caseValid   = \vst -> do
          let reqId = vst ^. nextCtrlRequestId
          sendOn vst [encodeInt32 header, encodeInt32 reqId]
          return . RemoteEndPointValid
                 . (nextCtrlRequestId ^: (+ 1))
                 . (pendingCtrlRequestsAt reqId ^= Just replyMVar)
                 $ vst
      , caseInvalid = throwIO
      , caseFailed  = throwIO
      , caseInit    = \_ _ -> relyViolation pair "doRemoteRequest (init)"
      , caseClosing = \_ _ -> relyViolation pair "doRemoteRequest (closing)"
      , caseClosed  = relyViolation pair "doRemoteRequest (closed)"
      }
    takeMVar replyMVar >>= either throwIO return
  where
    pair = (ourEndPoint, theirEndPoint)
-- | If a valid remote endpoint has no outgoing and no incoming connections
-- left, send @CloseSocket@ and move it to the Closing state. All other
-- states are left alone.
closeIfUnused :: EndPointPair -> IO ()
closeIfUnused (ourEndPoint, theirEndPoint) =
    modifyRemoteState_ (ourEndPoint, theirEndPoint) remoteStateIdentity
      { caseValid = closeIfIdle }
  where
    closeIfIdle :: ValidRemoteEndPointState -> IO RemoteState
    closeIfIdle vst
      | stillInUse = return (RemoteEndPointValid vst)
      | otherwise  = do
          sendOn vst [encodeInt32 CloseSocket]
          resolved <- newEmptyMVar
          return (RemoteEndPointClosing resolved vst)
      where
        stillInUse = vst ^. remoteOutgoing /= 0
                  || not (IntSet.null (vst ^. remoteIncoming))
-- | If our endpoint's record for @theirAddress@ is Invalid or Failed, remove
-- it from the table so that a later lookup creates a fresh record. Throws
-- 'ConnectFailed' if our endpoint is closed.
resetIfBroken :: LocalEndPoint -> EndPointAddress -> IO ()
resetIfBroken ourEndPoint theirAddress = do
    mTheirEndPoint <- withMVar (localState ourEndPoint) $ \st -> case st of
      LocalEndPointValid vst ->
        return (vst ^. localConnectionTo theirAddress)
      LocalEndPointClosed ->
        throwIO $ TransportError ConnectFailed "Endpoint closed"
    case mTheirEndPoint of
      Nothing            -> return ()
      Just theirEndPoint ->
        withMVar (remoteState theirEndPoint) $ \st ->
          when (isBroken st) $
            removeRemoteEndPoint (ourEndPoint, theirEndPoint)
  where
    isBroken :: RemoteState -> Bool
    isBroken (RemoteEndPointInvalid _) = True
    isBroken (RemoteEndPointFailed _)  = True
    isBroken _                         = False
-- | Open a connection from an endpoint to itself.
--
-- No socket is involved: sending writes 'Received' events directly to the
-- endpoint's own channel, and opening and closing emit 'ConnectionOpened'
-- and 'ConnectionClosed' there.
connectToSelf :: LocalEndPoint
              -> IO Connection
connectToSelf ourEndPoint = do
    connAlive <- newIORef True
    connId    <- mapIOException connectFailed $ getNextConnectionId ourEndPoint
    writeChan ourChan $
      ConnectionOpened connId ReliableOrdered (localAddress ourEndPoint)
    return Connection
      { send  = selfSend connAlive connId
      , close = selfClose connAlive connId
      }
  where
    -- Deliver a message to ourselves while both endpoint and connection
    -- are alive.
    selfSend :: IORef Bool
             -> ConnectionId
             -> [ByteString]
             -> IO (Either (TransportError SendErrorCode) ())
    selfSend connAlive connId msg =
      try . withMVar ourState $ \st -> case st of
        LocalEndPointClosed ->
          throwIO $ TransportError SendFailed "Endpoint closed"
        LocalEndPointValid _ -> do
          alive <- readIORef connAlive
          unless alive $
            throwIO $ TransportError SendClosed "Connection closed"
          writeChan ourChan (Received connId msg)

    -- Close the connection once; later calls do nothing.
    selfClose :: IORef Bool -> ConnectionId -> IO ()
    selfClose connAlive connId =
      withMVar ourState $ \st -> case st of
        LocalEndPointClosed ->
          return ()
        LocalEndPointValid _ -> do
          alive <- readIORef connAlive
          when alive $ do
            writeChan ourChan (ConnectionClosed connId)
            writeIORef connAlive False

    ourChan  = localChannel ourEndPoint
    ourState = localState ourEndPoint

    connectFailed :: IOException -> TransportError ConnectErrorCode
    connectFailed = TransportError ConnectFailed . show
-- | Finish the set-up of a remote endpoint in Init state.
--
-- Threads waiting on its @resolved@ 'MVar' are woken and the state is
-- replaced by @newState@. If the new state is Closed, the record is also
-- removed from our endpoint's table. If the endpoint has already failed,
-- that error is rethrown; any other state is a RELY violation.
resolveInit :: EndPointPair -> RemoteState -> IO ()
resolveInit (ourEndPoint, theirEndPoint) newState =
    modifyMVar_ (remoteState theirEndPoint) $ \st -> case st of
      RemoteEndPointInit resolved _ -> do
        putMVar resolved ()
        when (isClosed newState) $
          removeRemoteEndPoint (ourEndPoint, theirEndPoint)
        return newState
      RemoteEndPointFailed ex ->
        throwIO ex
      _ ->
        relyViolation (ourEndPoint, theirEndPoint) "resolveInit"
  where
    isClosed :: RemoteState -> Bool
    isClosed RemoteEndPointClosed = True
    isClosed _                    = False
-- | Allocate the next lightweight connection id of a local endpoint.
-- Throws a 'userError' if the endpoint is closed.
getNextConnectionId :: LocalEndPoint -> IO ConnectionId
getNextConnectionId ourEndpoint =
    modifyMVar (localState ourEndpoint) allocate
  where
    allocate :: LocalEndPointState -> IO (LocalEndPointState, ConnectionId)
    allocate LocalEndPointClosed =
      throwIO $ userError "Local endpoint closed"
    allocate (LocalEndPointValid vst) = do
      let connId = vst ^. nextConnectionId
          vst'   = (nextConnectionId ^= connId + 1) vst
      return (LocalEndPointValid vst', connId)
-- | Create a new local endpoint and register it with the transport under a
-- fresh endpoint id. Throws 'NewEndPointFailed' if the transport is closed.
createLocalEndPoint :: TCPTransport -> IO LocalEndPoint
createLocalEndPoint transport = do
    events   <- newChan
    stateVar <- newMVar . LocalEndPointValid $ ValidLocalEndPointState
      { _nextConnectionId = firstNonReservedConnectionId
      , _localConnections = Map.empty
      , _nextRemoteId     = 0
      }
    modifyMVar (transportState transport) $ \st -> case st of
      TransportClosed ->
        throwIO (TransportError NewEndPointFailed "Transport closed")
      TransportValid vst -> do
        let ix   = vst ^. nextEndPointId
            addr = encodeEndPointAddress (transportHost transport)
                                         (transportPort transport)
                                         ix
            ep   = LocalEndPoint { localAddress = addr
                                 , localChannel = events
                                 , localState   = stateVar
                                 }
            vst' = (localEndPointAt addr ^= Just ep)
                 . (nextEndPointId ^= ix + 1)
                 $ vst
        return (TransportValid vst', ep)
-- | Remove a remote endpoint from our endpoint's table. This happens only if
-- the table still holds this very record (same 'remoteId'); a newer record
-- for the same address is left in place.
removeRemoteEndPoint :: EndPointPair -> IO ()
removeRemoteEndPoint (ourEndPoint, theirEndPoint) =
    modifyMVar_ (localState ourEndPoint) $ \st -> case st of
      LocalEndPointValid vst
        | Just known <- vst ^. localConnectionTo theirAddress
        , remoteId known == remoteId theirEndPoint ->
            return . LocalEndPointValid $
              (localConnectionTo theirAddress ^= Nothing) vst
      _ ->
        return st
  where
    theirAddress = remoteAddress theirEndPoint
-- | Unregister a local endpoint from the transport. This does nothing if the
-- transport is already closed.
removeLocalEndPoint :: TCPTransport -> LocalEndPoint -> IO ()
removeLocalEndPoint transport ourEndPoint =
    modifyMVar_ (transportState transport) unregister
  where
    unregister :: TransportState -> IO TransportState
    unregister TransportClosed      = return TransportClosed
    unregister (TransportValid vst) =
      return . TransportValid $
        (localEndPointAt (localAddress ourEndPoint) ^= Nothing) vst
-- | Find the remote endpoint record for @theirAddress@ in our endpoint's
-- table. If there is none, a new one is created in 'RemoteEndPointInit'
-- state, tagged with @findOrigin@.
--
-- Returns the record and whether it was newly created; a new record must
-- later be resolved (see 'resolveInit'). For an existing record the result
-- depends on its state:
--
-- * Valid: returned as existing. When 'RequestedByUs', its outgoing
--   connection count is incremented.
-- * Init or Closing: wait for it to be resolved and look again, except
--   for the crossed-request cases commented inline.
-- * Closed: look again.
-- * Invalid or Failed: the stored error is rethrown.
--
-- Throws a 'userError' if our endpoint is closed.
findRemoteEndPoint
  :: LocalEndPoint
  -> EndPointAddress
  -> RequestedBy
  -> IO (RemoteEndPoint, Bool)
findRemoteEndPoint ourEndPoint theirAddress findOrigin = go
  where
    go = do
      (theirEndPoint, isNew) <- modifyMVar ourState $ \st -> case st of
        LocalEndPointValid vst -> case vst ^. localConnectionTo theirAddress of
          Just theirEndPoint ->
            return (st, (theirEndPoint, False))
          Nothing -> do
            resolved <- newEmptyMVar
            theirState <- newMVar (RemoteEndPointInit resolved findOrigin)
            let theirEndPoint = RemoteEndPoint
                                  { remoteAddress = theirAddress
                                  , remoteState   = theirState
                                  , remoteId      = vst ^. nextRemoteId
                                  }
            return ( LocalEndPointValid
                   . (localConnectionTo theirAddress ^= Just theirEndPoint)
                   . (nextRemoteId ^: (+ 1))
                   $ vst
                   , (theirEndPoint, True)
                   )
        LocalEndPointClosed ->
          throwIO $ userError "Local endpoint closed"
      if isNew
        then
          return (theirEndPoint, True)
        else do
          let theirState = remoteState theirEndPoint
          -- Snapshot the remote state. If we will use an already valid
          -- socket for a new outgoing connection, count it right away.
          snapshot <- modifyMVar theirState $ \st -> case st of
            RemoteEndPointValid vst ->
              case findOrigin of
                RequestedByUs -> do
                  let st' = RemoteEndPointValid
                          . (remoteOutgoing ^: (+ 1))
                          $ vst
                  return (st', st')
                RequestedByThem ->
                  return (st, st)
            _ ->
              return (st, st)
          case snapshot of
            RemoteEndPointInvalid err ->
              throwIO err
            RemoteEndPointInit resolved initOrigin ->
              case (findOrigin, initOrigin) of
                (RequestedByUs, RequestedByUs) ->
                  readMVar resolved >> go
                (RequestedByUs, RequestedByThem) ->
                  readMVar resolved >> go
                (RequestedByThem, RequestedByUs) ->
                  -- Both sides are connecting at the same time; addresses
                  -- break the tie. If ours is greater, wait for our own
                  -- attempt to be resolved and look again. Otherwise report
                  -- the record as existing, so the caller
                  -- (handleConnectionRequest) replies ConnectionRequestCrossed.
                  if ourAddress > theirAddress
                    then
                      readMVar resolved >> go
                    else
                      return (theirEndPoint, False)
                (RequestedByThem, RequestedByThem) ->
                  throwIO $ userError "Already connected"
            RemoteEndPointValid _ ->
              -- For an incoming request this also makes the caller reply
              -- ConnectionRequestCrossed.
              return (theirEndPoint, False)
            RemoteEndPointClosing resolved _ ->
              readMVar resolved >> go
            RemoteEndPointClosed ->
              go
            RemoteEndPointFailed err ->
              throwIO err

    ourState   = localState ourEndPoint
    ourAddress = localAddress ourEndPoint
-- | Open a TCP connection to the endpoint at @theirAddress@ and perform the
-- connection handshake on behalf of the local endpoint @ourAddress@.
--
-- The handshake sends the remote endpoint id (encoded with 'encodeInt32'),
-- followed by our own address (length-prefixed). It then reads back a
-- single Int32, which must decode (via 'tryToEnum') to a
-- 'ConnectionRequestResponse'.
--
-- * @reuseAddr@: when 'True', 'N.ReuseAddr' is set on the socket before
--   connecting.
-- * @timeout@: passed to 'timeoutMaybe' around 'N.connect' only. The
--   handshake send/receive is not covered by it. The unit is presumably
--   microseconds; confirm against 'timeoutMaybe'.
--
-- Failures of type 'TransportError' 'ConnectErrorCode' are returned as
-- 'Left':
--
-- * 'ConnectNotFound' for resolution or connect failures, including an
--   empty resolution result.
-- * 'ConnectInsufficientResources' if the socket cannot be created.
-- * 'ConnectTimeout' when 'timeoutMaybe' gives up.
-- * 'ConnectFailed' for everything else: an unparsable address, socket
--   option errors, send/receive errors, or an unknown response.
--
-- If anything fails after the socket was created, the socket is closed. On
-- success, the open socket is handed to the caller.
socketToEndPoint :: EndPointAddress
                 -> EndPointAddress
                 -> Bool
                 -> Maybe Int
                 -> IO (Either (TransportError ConnectErrorCode)
                               (N.Socket, ConnectionRequestResponse))
socketToEndPoint (EndPointAddress ourAddress) theirAddress reuseAddr timeout =
  try $ do
    (host, port, theirEndPointId) <- case decodeEndPointAddress theirAddress of
      Nothing  -> throwIO (failed . userError $ "Could not parse")
      Just dec -> return dec
    -- Match the resolution result explicitly. A refutable @addr:_ <- ...@
    -- would raise a pattern-match IOError, which escapes 'try' because
    -- 'try' only catches 'TransportError' here.
    addrs <- mapIOException invalidAddress $
      N.getAddrInfo Nothing (Just host) (Just port)
    addr <- case addrs of
      firstAddr : _ -> return firstAddr
      []            -> throwIO (invalidAddress (userError "No addresses found"))
    bracketOnError (createSocket addr) tryCloseSocket $ \sock -> do
      when reuseAddr $
        mapIOException failed $ N.setSocketOption sock N.ReuseAddr 1
      mapIOException invalidAddress $
        timeoutMaybe timeout timeoutError $
          N.connect sock (N.addrAddress addr)
      response <- mapIOException failed $ do
        sendMany sock (encodeInt32 theirEndPointId : prependLength [ourAddress])
        recvInt32 sock
      case tryToEnum response of
        Nothing -> throwIO (failed . userError $ "Unexpected response")
        Just r  -> return (sock, r)
  where
    createSocket :: N.AddrInfo -> IO N.Socket
    createSocket addr = mapIOException insufficientResources $
      N.socket (N.addrFamily addr) N.Stream N.defaultProtocol

    -- Each helper turns an IOException into the matching TransportError.
    invalidAddress        = TransportError ConnectNotFound . show
    insufficientResources = TransportError ConnectInsufficientResources . show
    failed                = TransportError ConnectFailed . show
    timeoutError          = TransportError ConnectTimeout "Timed out"
-- | Build an endpoint address of the form @host:port:id@. The id is
-- rendered with 'show'.
encodeEndPointAddress :: N.HostName
                      -> N.ServiceName
                      -> EndPointId
                      -> EndPointAddress
encodeEndPointAddress host port ix =
  EndPointAddress (BSC.pack (concat [host, ":", port, ":", show ix]))
-- | Parse an address in the @host:port:id@ form produced by
-- 'encodeEndPointAddress'.
--
-- Only the last two colons separate fields (via 'splitMaxFromEnd'), so the
-- host may itself contain colons. Returns 'Nothing' unless there are exactly
-- three fields and the id field is consumed completely by 'reads'.
decodeEndPointAddress :: EndPointAddress
                      -> Maybe (N.HostName, N.ServiceName, EndPointId)
decodeEndPointAddress (EndPointAddress bs) =
  case splitMaxFromEnd (== ':') 2 (BSC.unpack bs) of
    [host, port, idStr]
      | [(endPointId, "")] <- reads idStr -> Just (host, port, endPointId)
    _ -> Nothing
-- | Split a list at elements satisfying @p@, working from the end of the
-- list and performing at most @n@ splits. Whatever remains once the split
-- budget is used up is returned, unsplit, as the first element.
--
-- This lets an address like @host:port:id@ be split on its last two colons
-- while colons inside @host@ stay intact:
--
-- > splitMaxFromEnd (== ':') 2 "::1:8080:3" == ["::1", "8080", "3"]
-- > splitMaxFromEnd (== ':') 2 "a:b:c:d"    == ["a:b", "c", "d"]
-- > splitMaxFromEnd (== ':') 2 "abc"        == ["abc"]
--
-- The result is never empty. A negative @n@ places no limit on the number
-- of splits.
splitMaxFromEnd :: (a -> Bool) -> Int -> [a] -> [[a]]
splitMaxFromEnd p n = go [[]] n . reverse
  where
    -- Walks the reversed input. The head of the accumulator is the segment
    -- being built; consing elements onto it undoes the reversal. @k@ is the
    -- remaining split budget.
    go accs _ [] = accs
    -- Budget exhausted (immediately after a split, or n == 0 from the
    -- start): the rest of the input is one final segment.
    go ([] : accs) 0 xs = reverse xs : accs
    go (acc : accs) k (x : xs)
      | p x       = go ([] : acc : accs) (k - 1) xs
      | otherwise = go ((x : acc) : accs) k xs
    -- Unreachable: the accumulator always has at least one segment.
    go _ _ _ = error "Bug in splitMaxFromEnd"
-- | Find the socket currently used between the local endpoint at
-- @ourAddress@ and the remote endpoint at @theirAddress@.
--
-- Throws a 'userError' if the transport or the local endpoint is closed, if
-- either endpoint cannot be found, or if the remote endpoint is still
-- initializing or already closed. An invalid or failed remote endpoint
-- rethrows its stored error. A closing remote endpoint still yields its
-- socket.
internalSocketBetween :: TCPTransport
                      -> EndPointAddress
                      -> EndPointAddress
                      -> IO N.Socket
internalSocketBetween transport ourAddress theirAddress = do
    ourEndPoint   <- withMVar (transportState transport) lookupLocal
    theirEndPoint <- withMVar (localState ourEndPoint) lookupRemote
    withMVar (remoteState theirEndPoint) socketOf
  where
    lookupLocal TransportClosed =
      throwIO $ userError "Transport closed"
    lookupLocal (TransportValid vst) =
      maybe (throwIO $ userError "Local endpoint not found") return
            (vst ^. localEndPointAt ourAddress)

    lookupRemote LocalEndPointClosed =
      throwIO $ userError "Local endpoint closed"
    lookupRemote (LocalEndPointValid vst) =
      maybe (throwIO $ userError "Remote endpoint not found") return
            (vst ^. localConnectionTo theirAddress)

    socketOf (RemoteEndPointInit _ _) =
      throwIO $ userError "Remote endpoint not yet initialized"
    socketOf (RemoteEndPointValid vst)     = return (remoteSocket vst)
    socketOf (RemoteEndPointClosing _ vst) = return (remoteSocket vst)
    socketOf RemoteEndPointClosed =
      throwIO $ userError "Remote endpoint closed"
    socketOf (RemoteEndPointInvalid err) = throwIO err
    socketOf (RemoteEndPointFailed err)  = throwIO err
-- | The first connection id that is not reserved (1024). Presumably ids below
-- this value are reserved for internal use; confirm at the allocation site.
firstNonReservedConnectionId :: ConnectionId
firstNonReservedConnectionId = 2 ^ (10 :: Int)
-- Record-field accessors. Each pairs the field selector (getter) with a
-- setter written as a record update, so the field can be read with '^.'
-- and updated with '^=' / '^:'.

-- | Accessor for the '_localEndPoints' field.
localEndPoints :: Accessor ValidTransportState (Map EndPointAddress LocalEndPoint)
localEndPoints = accessor _localEndPoints setter
  where
    setter endPoints st = st { _localEndPoints = endPoints }

-- | Accessor for the '_nextEndPointId' field.
nextEndPointId :: Accessor ValidTransportState EndPointId
nextEndPointId = accessor _nextEndPointId setter
  where
    setter endPointId st = st { _nextEndPointId = endPointId }

-- | Accessor for the '_nextConnectionId' field.
nextConnectionId :: Accessor ValidLocalEndPointState ConnectionId
nextConnectionId = accessor _nextConnectionId setter
  where
    setter connId st = st { _nextConnectionId = connId }

-- | Accessor for the '_localConnections' field.
localConnections :: Accessor ValidLocalEndPointState (Map EndPointAddress RemoteEndPoint)
localConnections = accessor _localConnections setter
  where
    setter connections st = st { _localConnections = connections }

-- | Accessor for the '_nextRemoteId' field.
nextRemoteId :: Accessor ValidLocalEndPointState Int
nextRemoteId = accessor _nextRemoteId setter
  where
    setter remoteIdent st = st { _nextRemoteId = remoteIdent }

-- | Accessor for the '_remoteOutgoing' field.
remoteOutgoing :: Accessor ValidRemoteEndPointState Int
remoteOutgoing = accessor _remoteOutgoing setter
  where
    setter outgoing conn = conn { _remoteOutgoing = outgoing }

-- | Accessor for the '_remoteIncoming' field.
remoteIncoming :: Accessor ValidRemoteEndPointState IntSet
remoteIncoming = accessor _remoteIncoming setter
  where
    setter incoming conn = conn { _remoteIncoming = incoming }

-- | Accessor for the '_pendingCtrlRequests' field.
pendingCtrlRequests :: Accessor ValidRemoteEndPointState (IntMap (MVar (Either IOException [ByteString])))
pendingCtrlRequests = accessor _pendingCtrlRequests setter
  where
    setter requests st = st { _pendingCtrlRequests = requests }

-- | Accessor for the '_nextCtrlRequestId' field.
nextCtrlRequestId :: Accessor ValidRemoteEndPointState ControlRequestId
nextCtrlRequestId = accessor _nextCtrlRequestId setter
  where
    setter requestId st = st { _nextCtrlRequestId = requestId }
-- | Accessor for the local endpoint registered at the given address
-- ('Nothing' when absent), composed from 'localEndPoints'.
localEndPointAt :: EndPointAddress -> Accessor ValidTransportState (Maybe LocalEndPoint)
localEndPointAt = (localEndPoints >>>) . DAC.mapMaybe

-- | Accessor for the pending control request with the given id ('Nothing'
-- when absent). The id is converted with 'fromIntegral' to an 'IntMap' key.
pendingCtrlRequestsAt :: ControlRequestId -> Accessor ValidRemoteEndPointState (Maybe (MVar (Either IOException [ByteString])))
pendingCtrlRequestsAt = (pendingCtrlRequests >>>) . DAC.intMapMaybe . fromIntegral

-- | Accessor for the remote endpoint connected at the given address
-- ('Nothing' when absent), composed from 'localConnections'.
localConnectionTo :: EndPointAddress
                  -> Accessor ValidLocalEndPointState (Maybe RemoteEndPoint)
localConnectionTo = (localConnections >>>) . DAC.mapMaybe
-- | Report a broken assumption about the peer. The message (@str@ suffixed
-- with " RELY violation") is logged with 'elog', and then the action fails
-- in 'IO' with that same message.
relyViolation :: EndPointPair -> String -> IO a
relyViolation (ourEndPoint, theirEndPoint) str =
    elog (ourEndPoint, theirEndPoint) msg >> fail msg
  where
    msg = str ++ " RELY violation"
-- | Print a diagnostic line to stdout. The line is prefixed with the local
-- address, the remote address, the remote endpoint's id (in parentheses),
-- and the current thread id.
elog :: EndPointPair -> String -> IO ()
elog (ourEndPoint, theirEndPoint) msg =
  myThreadId >>= \tid -> putStrLn $ concat
    [ show (localAddress ourEndPoint)
    , "/", show (remoteAddress theirEndPoint)
    , "(", show (remoteId theirEndPoint), ")"
    , "/", show tid
    , ": ", msg
    ]