-- | TCP implementation of the transport layer. 
-- 
-- The TCP implementation guarantees that only a single TCP connection (socket)
-- will be used between endpoints, provided that the addresses specified are
-- canonical. If /A/ connects to /B/ and reports its address as
-- @192.168.0.1:8080@ and /B/ subsequently tries to connect to /A/ as
-- @client1.local:http-alt@ then the transport layer will not realize that the
-- TCP connection can be reused. 
--
-- Applications that use the TCP transport should use
-- 'Network.Socket.withSocketsDo' in their main function for Windows
-- compatibility (see "Network.Socket").
module Network.Transport.TCP ( -- * Main API
                               createTransport
                             , TCPParameters(..)
                             , defaultTCPParameters
                               -- * Internals (exposed for unit tests) 
                             , createTransportExposeInternals 
                             , TransportInternals(..)
                             , EndPointId
                             , encodeEndPointAddress
                             , decodeEndPointAddress
                             , ControlHeader(..)
                             , ConnectionRequestResponse(..)
                             , firstNonReservedConnectionId
                             , socketToEndPoint 
                               -- * Design notes
                               -- $design
                             ) where

import Prelude hiding 
  ( mapM_
#if ! MIN_VERSION_base(4,6,0)
  , catch
#endif
  )
 
import Network.Transport
import Network.Transport.TCP.Internal ( forkServer
                                      , recvWithLength
                                      , recvInt32
                                      , tryCloseSocket
                                      )
import Network.Transport.Internal ( encodeInt32
                                  , decodeInt32
                                  , prependLength
                                  , mapIOException
                                  , tryIO
                                  , tryToEnum
                                  , void
                                  , timeoutMaybe
                                  , asyncWhenCancelled
                                  )
import qualified Network.Socket as N ( HostName
                                     , ServiceName
                                     , Socket
                                     , getAddrInfo
                                     , socket
                                     , addrFamily
                                     , addrAddress
                                     , SocketType(Stream)
                                     , defaultProtocol
                                     , setSocketOption
                                     , SocketOption(ReuseAddr) 
                                     , connect
                                     , sOMAXCONN
                                     , AddrInfo
                                     )
import Network.Socket.ByteString (sendMany)
import Control.Concurrent (forkIO, ThreadId, killThread, myThreadId)
import Control.Concurrent.Chan (Chan, newChan, readChan, writeChan)
import Control.Concurrent.MVar ( MVar
                               , newMVar
                               , modifyMVar
                               , modifyMVar_
                               , readMVar
                               , takeMVar
                               , putMVar
                               , newEmptyMVar
                               , withMVar
                               )
import Control.Category ((>>>))
import Control.Applicative ((<$>))
import Control.Monad (when, unless)
import Control.Exception ( IOException
                         , SomeException
                         , AsyncException
                         , handle
                         , throw
                         , throwIO
                         , try
                         , bracketOnError
                         , mask
                         , onException
                         , fromException
                         )
import Data.IORef (IORef, newIORef, writeIORef, readIORef)
import Data.ByteString (ByteString)
import qualified Data.ByteString as BS (concat)
import qualified Data.ByteString.Char8 as BSC (pack, unpack)
import Data.Int (Int32)
import Data.IntMap (IntMap)
import qualified Data.IntMap as IntMap (empty)
import Data.IntSet (IntSet)
import qualified Data.IntSet as IntSet ( empty
                                       , insert
                                       , elems
                                       , singleton
                                       , null
                                       , delete
                                       , member
                                       )
import Data.Map (Map)
import qualified Data.Map as Map (empty)
import Data.Accessor (Accessor, accessor, (^.), (^=), (^:)) 
import qualified Data.Accessor.Container as DAC (mapMaybe, intMapMaybe)
import Data.Foldable (forM_, mapM_)

-- $design 
--
-- [Goals]
--
-- The TCP transport maps multiple logical connections between /A/ and /B/ (in
-- either direction) to a single TCP connection:
--
-- > +-------+                          +-------+
-- > | A     |==========================| B     |
-- > |       |>~~~~~~~~~~~~~~~~~~~~~~~~~|~~~\   |
-- > |   Q   |>~~~~~~~~~~~~~~~~~~~~~~~~~|~~~Q   |
-- > |   \~~~|~~~~~~~~~~~~~~~~~~~~~~~~~<|       |
-- > |       |==========================|       |
-- > +-------+                          +-------+
-- 
-- Ignoring the complications detailed below, the TCP connection is set up
-- when the first lightweight connection is created (in either direction), and
-- torn down when the last lightweight connection (in either direction) is
-- closed.
--
-- [Connecting]
--
-- Let /A/, /B/ be two endpoints without any connections. When /A/ wants to
-- connect to /B/, it locally records that it is trying to connect to /B/ and
-- sends a request to /B/. As part of the request /A/ sends its own endpoint
-- address to /B/ (so that /B/ can reuse the connection in the other direction).
--
-- When /B/ receives the connection request it first checks if it did not
-- already initiate a connection request to /A/. If not it will acknowledge the
-- connection request by sending 'ConnectionRequestAccepted' to /A/ and record
-- that it has a TCP connection to /A/.
--
-- The tricky case arises when /A/ sends a connection request to /B/ and /B/
-- finds that it had already sent a connection request to /A/. In this case /B/
-- will accept the connection request from /A/ if /A/s endpoint address is
-- smaller (lexicographically) than /B/s, and reject it otherwise. If it rejects
-- it, it sends a 'ConnectionRequestCrossed' message to /A/. (The
-- lexicographical ordering is an arbitrary but convenient way to break the
-- tie.)
--
-- When it receives a 'ConnectionRequestCrossed' message the /A/ thread that
-- initiated the request just needs to wait until the /A/ thread that is dealing
-- with /B/'s connection request completes.
--
-- [Disconnecting]
-- 
-- The TCP connection is created as soon as the first logical connection from
-- /A/ to /B/ (or /B/ to /A/) is established. At this point a thread (@#@) is
-- spawned that listens for incoming connections from /B/:
--
-- > +-------+                          +-------+
-- > | A     |==========================| B     |
-- > |       |>~~~~~~~~~~~~~~~~~~~~~~~~~|~~~\   |
-- > |       |                          |   Q   |
-- > |      #|                          |       |
-- > |       |==========================|       |
-- > +-------+                          +-------+
--
-- The question is when the TCP connection can be closed again.  Conceptually,
-- we want to do reference counting: when there are no logical connections left
-- between /A/ and /B/ we want to close the socket (possibly after some
-- timeout). 
--
-- However, /A/ and /B/ need to agree that the refcount has reached zero.  It
-- might happen that /B/ sends a connection request over the existing socket at
-- the same time that /A/ closes its logical connection to /B/ and closes the
-- socket. This will cause a failure in /B/ (which will have to retry) which is
-- not caused by a network failure, which is unfortunate. (Note that the
-- connection request from /B/ might succeed even if /A/ closes the socket.) 
--
-- Instead, when /A/ is ready to close the socket it sends a 'CloseSocket'
-- request to /B/ and records that its connection to /B/ is closing. If /A/
-- receives a new connection request from /B/ after having sent the
-- 'CloseSocket' request it simply forgets that it sent a 'CloseSocket' request
-- and increments the reference count of the connection again.
-- 
-- When /B/ receives a 'CloseSocket' message and it too is ready to close the
-- connection, it will respond with a reciprocal 'CloseSocket' request to /A/
-- and then actually close the socket. /A/ meanwhile will not send any more
-- requests to /B/ after having sent a 'CloseSocket' request, and will actually
-- close its end of the socket only when receiving the 'CloseSocket' message
-- from /B/. (Since /A/ recorded that its connection to /B/ is in closing state
-- after sending a 'CloseSocket' request to /B/, it knows not to reciprocate
-- /B/'s reciprocal 'CloseSocket' message.)
-- 
-- If there is a concurrent thread in /A/ waiting to connect to /B/ after /A/
-- has sent a 'CloseSocket' request then this thread will block until /A/ knows
-- whether to reuse the old socket (if /B/ sends a new connection request
-- instead of acknowledging the 'CloseSocket') or to set up a new socket. 

--------------------------------------------------------------------------------
-- Internal datatypes                                                         --
--------------------------------------------------------------------------------

-- We use underscores for fields that we might update (using accessors)
--
-- All data types follow the same structure:
-- 
-- * A top-level data type describing static properties (TCPTransport,
--   LocalEndPoint, RemoteEndPoint)
-- * The 'static' properties include an MVar containing a data structure for
--   the dynamic properties (TransportState, LocalEndPointState,
--   RemoteEndPointState). The state could be invalid/valid/closed,/etc.  
-- * For the case of "valid" we use third data structure to give more details
--   about the state (ValidTransportState, ValidLocalEndPointState,
--   ValidRemoteEndPointState).

-- | Static description of a TCP transport instance.
--
-- The host, port and parameters are the values given when the transport was
-- created; everything that changes afterwards lives inside 'transportState'.
data TCPTransport = TCPTransport 
  { transportHost   :: N.HostName
    -- ^ Host name the transport was created with
  , transportPort   :: N.ServiceName
    -- ^ Service name (port) the transport was created with
  , transportState  :: MVar TransportState 
    -- ^ Mutable state of the transport (valid or closed)
  , transportParams :: TCPParameters
    -- ^ TCP parameters the transport was created with
  }

-- | Dynamic state of a 'TCPTransport'
data TransportState = 
    -- | The transport is operational; carries the detailed state
    TransportValid ValidTransportState
    -- | The transport has been closed (see 'apiCloseTransport')
  | TransportClosed

-- | State of a transport that has not been closed.
--
-- Field names start with an underscore because they are updated through
-- accessors.
data ValidTransportState = ValidTransportState 
  { _localEndPoints :: Map EndPointAddress LocalEndPoint 
    -- ^ Local endpoints of this transport, indexed by their address
  , _nextEndPointId :: EndPointId 
    -- ^ Starts at 0 for a new transport; presumably the ID for the next
    -- local endpoint (the allocation code is not in this part of the file)
  }

-- | Static description of a local endpoint
data LocalEndPoint = LocalEndPoint
  { localAddress :: EndPointAddress
    -- ^ Address of this endpoint (exposed as 'address' in the API)
  , localChannel :: Chan Event
    -- ^ Channel of events for this endpoint; 'receive' reads from it
  , localState   :: MVar LocalEndPointState 
    -- ^ Mutable state of the endpoint (valid or closed)
  }

-- | Dynamic state of a 'LocalEndPoint'
data LocalEndPointState = 
    -- | The endpoint is operational; carries the detailed state
    LocalEndPointValid ValidLocalEndPointState
    -- | The endpoint has been closed (see 'apiCloseEndPoint')
  | LocalEndPointClosed

-- | State of a local endpoint that has not been closed.
--
-- Field names start with an underscore because they are updated through
-- accessors.
data ValidLocalEndPointState = ValidLocalEndPointState 
  { _nextConnectionId :: !ConnectionId 
    -- ^ Presumably the next connection ID to hand out (the allocation code
    -- is not in this part of the file)
  , _localConnections :: Map EndPointAddress RemoteEndPoint 
    -- ^ Remote endpoints known to this local endpoint, indexed by address
  , _nextRemoteId     :: !Int
    -- ^ Presumably the source of 'remoteId' for the next remote endpoint
    -- (the allocation code is not in this part of the file)
  }

-- REMOTE ENDPOINTS 
--
-- Remote endpoints (basically, TCP connections) have the following lifecycle:
--
--   Init  ---+---> Invalid
--            |
--            +-------------------------------\
--            |                               |
--            |       /----------\            |
--            |       |          |            |
--            |       v          |            v
--            +---> Valid ---> Closing ---> Closed
--            |       |          |            |
--            |       |          |            v
--            \-------+----------+--------> Failed
--
-- Init: There are two places where we create new remote endpoints: in
--   requestConnectionTo (in response to an API 'connect' call) and in
--   handleConnectionRequest (when a remote node tries to connect to us).
--   'Init' carries an MVar () 'resolved' which concurrent threads can use to
--   wait for the remote endpoint to finish initialization. We record who
--   requested the connection (the local endpoint or the remote endpoint).
--
-- Invalid: We put the remote endpoint in invalid state only during
--   requestConnectionTo when we fail to connect.
--
-- Valid: This is the "normal" state for a working remote endpoint.
-- 
-- Closing: When we detect that a remote endpoint is no longer used, we send a
--   CloseSocket request across the connection and put the remote endpoint in
--   closing state. As with Init, 'Closing' carries an MVar () 'resolved' which
--   concurrent threads can use to wait for the remote endpoint to either be
--   closed fully (if the communication partner responds with another
--   CloseSocket) or be put back in 'Valid' state if the remote endpoint denies
--   the request.
--
--   We also put the endpoint in Closed state, directly from Init, if our
--   outbound connection request crossed an inbound connection request and we
--   decide to keep the inbound (i.e., the remote endpoint sent us a
--   ConnectionRequestCrossed message).
--
-- Closed: The endpoint is put in Closed state after a successful garbage
--   collection.
--
-- Failed: If the connection to the remote endpoint is lost, or the local
-- endpoint (or the whole transport) is closed manually, the remote endpoint is
-- put in Failed state, and we record the reason.
--
-- Invariants for dealing with remote endpoints:
--
-- INV-SEND: Whenever we send data the remote endpoint must be locked (to avoid
--   interleaving bits of payload).
--
-- INV-CLOSE: Local endpoints should never point to remote endpoint in closed
--   state.  Whenever we put an endpoint in Closed state we remove that
--   endpoint from localConnections first, so that if a concurrent thread reads
--   the MVar, finds RemoteEndPointClosed, and then looks up the endpoint in
--   localConnections it is guaranteed to either find a different remote
--   endpoint, or else none at all (if we don't insist in this order some
--   threads might start spinning).
-- 
-- INV-RESOLVE: We should only signal on 'resolved' while the remote endpoint is
--   locked, and the remote endpoint must be in Valid or Closed state once
--   unlocked. This guarantees that there will not be two threads attempting to
--   both signal on 'resolved'. 
--
-- INV-LOST: If a send or recv fails, or a socket is closed unexpectedly, we
--   first put the remote endpoint in Failed state, and then send a
--   EventConnectionLost event. This guarantees that we only send this event
--   once.  
--
-- INV-CLOSING: An endpoint in closing state is for all intents and purposes
--   closed; that is, we shouldn't do any 'send's on it (although 'recv' is
--   acceptable, of course -- as we are waiting for the remote endpoint to
--   confirm or deny the request).
--
-- INV-LOCK-ORDER: Remote endpoint must be locked before their local endpoints.
--   In other words: it is okay to call modifyMVar on a local endpoint inside a
--   modifyMVar on a remote endpoint, but not the other way around. In
--   particular, it is okay to call removeRemoteEndPoint inside
--   modifyRemoteState.

-- | Static description of a remote endpoint (basically, a TCP connection)
data RemoteEndPoint = RemoteEndPoint 
  { remoteAddress :: EndPointAddress
    -- ^ Address of the remote endpoint
  , remoteState   :: MVar RemoteState
    -- ^ Mutable state; taking this MVar is what "locking the remote
    -- endpoint" means in the invariants
  , remoteId      :: Int
    -- ^ NOTE(review): purpose not visible in this part of the file;
    -- presumably distinguishes successive remote endpoints
  }

-- | Who requested the connection for a remote endpoint that is still being
-- initialized: the local endpoint or the remote endpoint
data RequestedBy = RequestedByUs | RequestedByThem 
  deriving (Eq, Show)

-- | Dynamic state of a 'RemoteEndPoint'.
--
-- The @MVar ()@ carried by 'RemoteEndPointInit' and 'RemoteEndPointClosing'
-- is the \"resolved\" signal that concurrent threads can wait on.
data RemoteState =
    -- | Invalid remote endpoint (for example, invalid address)
    RemoteEndPointInvalid (TransportError ConnectErrorCode)
    -- | The remote endpoint is being initialized; also records who requested
    -- the connection
  | RemoteEndPointInit (MVar ()) RequestedBy 
    -- | "Normal" working endpoint
  | RemoteEndPointValid ValidRemoteEndPointState 
    -- | The remote endpoint is being closed (garbage collected)
  | RemoteEndPointClosing (MVar ()) ValidRemoteEndPointState
    -- | The remote endpoint has been closed (garbage collected)
  | RemoteEndPointClosed
    -- | The remote endpoint has failed, or has been forcefully shutdown
    -- using a closeTransport or closeEndPoint API call
  | RemoteEndPointFailed IOException

-- | State of a remote endpoint with a working socket.
--
-- Fields starting with an underscore are updated through accessors.
data ValidRemoteEndPointState = ValidRemoteEndPointState 
  { _remoteOutgoing      :: !Int
    -- ^ Counter of outgoing connections; 'apiClose' decrements it when a
    -- live connection is closed
  , _remoteIncoming      :: IntSet
    -- ^ Incoming connections; reported in 'EventConnectionLost' when the
    -- endpoint fails
  ,  remoteSocket        :: N.Socket
    -- ^ The socket to the remote endpoint
  ,  sendOn              :: [ByteString] -> IO ()
    -- ^ Action used to send a list of chunks to the remote endpoint
  , _pendingCtrlRequests :: IntMap (MVar (Either IOException [ByteString]))
    -- ^ One MVar per outstanding control request, to be filled with either
    -- an error or the reply; presumably keyed by 'ControlRequestId'
  , _nextCtrlRequestId   :: !ControlRequestId 
    -- ^ Presumably the ID for the next control request (the allocation code
    -- is not in this part of the file)
  }

-- | Local identifier for an endpoint within this transport (a plain alias
-- for 'Int32', so it offers no type safety over other 'Int32' values)
type EndPointId       = Int32

-- | Control request ID
-- 
-- Control requests are asynchronous; the request ID makes it possible to match
-- requests and replies. This is a plain alias for 'Int32'.
type ControlRequestId = Int32

-- | Pair of a local and a remote endpoint, local endpoint first (for
-- conciseness in signatures)
type EndPointPair     = (LocalEndPoint, RemoteEndPoint)

-- | Control headers 
--
-- These are put on the wire with 'encodeInt32' (see for instance 'apiClose'),
-- so the order of the constructors determines their encoding.
data ControlHeader = 
    -- | Request a new connection ID from the remote endpoint
    RequestConnectionId 
    -- | Tell the remote endpoint we will no longer be using a connection
  | CloseConnection     
    -- | Respond to a control request /from/ the remote endpoint
  | ControlResponse     
    -- | Request to close the connection (see module description)
  | CloseSocket         
  deriving (Enum, Bounded, Show)

-- | Response sent by /B/ to /A/ when /A/ tries to connect
--
-- See the \"Connecting\" section of the design notes for the protocol.
data ConnectionRequestResponse =
    -- | /B/ accepts the connection
    ConnectionRequestAccepted        
    -- | /A/ requested an invalid endpoint
  | ConnectionRequestInvalid 
    -- | /A/s request crossed with a request from /B/ (see protocols)
  | ConnectionRequestCrossed         
  deriving (Enum, Bounded, Show)

-- | Parameters for setting up the TCP transport
--
-- Use 'defaultTCPParameters' and override individual fields as needed.
data TCPParameters = TCPParameters {
    -- | Backlog for 'listen'.
    -- Defaults to SOMAXCONN.
    tcpBacklog :: Int
    -- | Should we set SO_REUSEADDR on the server socket? 
    -- Defaults to True.
  , tcpReuseServerAddr :: Bool
    -- | Should we set SO_REUSEADDR on client sockets?
    -- Defaults to True.
  , tcpReuseClientAddr :: Bool
  }

-- | Internal functionality we expose for unit testing
--
-- Obtained together with the 'Transport' from
-- 'createTransportExposeInternals'.
data TransportInternals = TransportInternals 
  { -- | The ID of the thread that listens for new incoming connections
    transportThread :: ThreadId
    -- | Find the socket between a local and a remote endpoint
    -- (arguments: local address, then remote address)
  , socketBetween :: EndPointAddress 
                  -> EndPointAddress 
                  -> IO N.Socket 
  }

--------------------------------------------------------------------------------
-- Top-level functionality                                                    --
--------------------------------------------------------------------------------

-- | Create a TCP transport for the given host and service name (port).
--
-- This is 'createTransportExposeInternals' with the 'TransportInternals'
-- dropped from the result; a failure to set up the transport is returned as
-- @Left@.
createTransport :: N.HostName 
                -> N.ServiceName 
                -> TCPParameters
                -> IO (Either IOException Transport)
createTransport host port params = do
  created <- createTransportExposeInternals host port params
  return $ case created of
    Left err            -> Left err
    Right withInternals -> Right (fst withInternals)

-- | Create a TCP transport and also return its 'TransportInternals'.
--
-- You should probably not use this function (used for unit testing only).
--
-- The transport starts out valid, without endpoints. A server thread is
-- forked to accept incoming connection requests; if building the result
-- fails after the fork the thread is killed again, and any 'IOException' is
-- returned as @Left@.
createTransportExposeInternals 
  :: N.HostName 
  -> N.ServiceName 
  -> TCPParameters
  -> IO (Either IOException (Transport, TransportInternals)) 
createTransportExposeInternals host port params = do 
    stateVar <- newMVar (TransportValid freshState)
    let transport = TCPTransport 
          { transportHost   = host
          , transportPort   = port
          , transportState  = stateVar
          , transportParams = params
          }
        startServer = forkServer host 
                                 port 
                                 (tcpBacklog params) 
                                 (tcpReuseServerAddr params)
                                 (onServerExit transport) 
                                 (handleConnectionRequest transport)
    tryIO $ bracketOnError startServer killThread (wrapTransport transport)
  where
    -- State of a brand new transport
    freshState :: ValidTransportState
    freshState = ValidTransportState 
      { _localEndPoints = Map.empty 
      , _nextEndPointId = 0 
      }

    -- A deliberately undefined event: it is queued after the final real
    -- event, so that evaluating anything received afterwards throws
    poison :: Event
    poison = throw $ userError "Transport closed"

    -- Package up the public API and the internals once the server thread
    -- is running
    wrapTransport :: TCPTransport 
                  -> ThreadId 
                  -> IO (Transport, TransportInternals)
    wrapTransport transport tid = return (api, internals)
      where
        api = Transport 
          { newEndPoint    = apiNewEndPoint transport 
          , closeTransport = 
              apiCloseTransport transport (Just tid) [EndPointClosed, poison]
          } 
        internals = TransportInternals 
          { transportThread = tid
          , socketBetween   = internalSocketBetween transport
          }

    -- Invoked with the exception that terminated the server thread; closes
    -- the transport and reports the failure to every endpoint
    onServerExit :: TCPTransport -> SomeException -> IO ()
    onServerExit transport ex = 
      apiCloseTransport transport Nothing 
        [ ErrorEvent (TransportError EventTransportFailed (show ex))
        , poison
        ]

-- | Default TCP parameters: a backlog of @SOMAXCONN@ and @SO_REUSEADDR@
-- enabled on both the server socket and client sockets
defaultTCPParameters :: TCPParameters
defaultTCPParameters = 
  TCPParameters 
    { tcpReuseClientAddr = True
    , tcpReuseServerAddr = True
    , tcpBacklog         = N.sOMAXCONN
    }

--------------------------------------------------------------------------------
-- API functions                                                              --
--------------------------------------------------------------------------------

-- | Close the transport.
--
-- Marks the transport as closed, closes every local endpoint it still had
-- (reporting the given events on each), and finally kills the transport
-- thread if one is supplied. Calling this on an already closed transport
-- closes no endpoints.
apiCloseTransport :: TCPTransport -> Maybe ThreadId -> [Event] -> IO ()
apiCloseTransport transport mTransportThread evs = 
  asyncWhenCancelled return $ do
    previous <- modifyMVar (transportState transport) $ \st -> case st of
      TransportClosed    -> return (TransportClosed, Nothing)
      TransportValid vst -> return (TransportClosed, Just vst)
    forM_ previous $ \vst -> 
      forM_ (vst ^. localEndPoints) (apiCloseEndPoint transport evs)
    -- Killing the transport thread invokes the termination handler, which in
    -- turn calls apiCloseTransport again; by then the transport is already
    -- closed and no transport thread is passed, so that call terminates
    -- immediately
    case mTransportThread of
      Nothing  -> return ()
      Just tid -> killThread tid
     
-- | Create a new endpoint on the given transport.
--
-- The returned 'EndPoint' reads its events from the local endpoint's
-- channel. Multicast is not supported: both multicast operations
-- immediately return an error.
apiNewEndPoint :: TCPTransport 
               -> IO (Either (TransportError NewEndPointErrorCode) EndPoint)
apiNewEndPoint transport = 
  try . asyncWhenCancelled closeEndPoint $ do
    ourEndPoint <- createLocalEndPoint transport
    -- The second event is deliberately undefined, so that evaluating
    -- anything received after EndPointClosed throws
    let closedEvents = [ EndPointClosed
                       , throw $ userError "Endpoint closed" 
                       ]
    return EndPoint 
      { receive               = readChan (localChannel ourEndPoint)
      , address               = localAddress ourEndPoint
      , connect               = apiConnect (transportParams transport) ourEndPoint 
      , closeEndPoint         = apiCloseEndPoint transport closedEvents ourEndPoint
      , newMulticastGroup     = return (Left noNewGroup)
      , resolveMulticastGroup = \_ -> return (Left noResolveGroup)
      }
  where
    noNewGroup = 
      TransportError NewMulticastGroupUnsupported "Multicast not supported" 
    noResolveGroup = 
      TransportError ResolveMulticastGroupUnsupported "Multicast not supported" 

-- | Connect to an endpoint.
--
-- Connecting to our own address is handled by 'connectToSelf'. Otherwise a
-- connection ID is requested from the remote endpoint and wrapped up as a
-- 'Connection'.
apiConnect :: TCPParameters    -- ^ Parameters
           -> LocalEndPoint    -- ^ Local end point
           -> EndPointAddress  -- ^ Remote address
           -> Reliability      -- ^ Reliability (ignored)
           -> ConnectHints     -- ^ Hints
           -> IO (Either (TransportError ConnectErrorCode) Connection)
apiConnect params ourEndPoint theirAddress _reliability hints =
    try . asyncWhenCancelled close $ establish
  where
    establish
      | localAddress ourEndPoint == theirAddress = 
          connectToSelf ourEndPoint  
      | otherwise = do
          resetIfBroken ourEndPoint theirAddress
          (theirEndPoint, connId) <- 
            requestConnectionTo params ourEndPoint theirAddress hints
          -- An IORef suffices for the liveness flag (no MVar needed): it is
          -- only accessed while the remoteState MVar is held, so a second
          -- lock would be pure overhead.
          connAlive <- newIORef True
          let endPoints = (ourEndPoint, theirEndPoint)
          return Connection 
            { send  = apiSend  endPoints connId connAlive 
            , close = apiClose endPoints connId connAlive 
            }

-- | Close a connection.
--
-- If the remote endpoint is in valid state and the connection is still
-- alive, the connection is marked dead, a 'CloseConnection' message is sent
-- and the count of outgoing connections is decremented. In every other case
-- the remote state is left as it is. Afterwards the remote endpoint is
-- closed if it is no longer used. I/O exceptions are swallowed.
apiClose :: EndPointPair -> ConnectionId -> IORef Bool -> IO ()
apiClose (ourEndPoint, theirEndPoint) connId connAlive = 
    void . tryIO . asyncWhenCancelled return $ do 
      modifyRemoteState_ endPoints remoteStateIdentity 
        { caseValid = releaseConnection }
      closeIfUnused endPoints
  where
    endPoints = (ourEndPoint, theirEndPoint)

    releaseConnection vst = do
      alive <- readIORef connAlive
      if not alive 
        then return (RemoteEndPointValid vst)
        else do
          -- Mark the connection dead before sending, so that it stays dead
          -- even if the send fails
          writeIORef connAlive False
          sendOn vst [encodeInt32 CloseConnection, encodeInt32 connId] 
          return . RemoteEndPointValid $ (remoteOutgoing ^: subtract 1) vst

-- | Send data across a connection.
--
-- The payload is only sent when the remote endpoint is in valid state and
-- the connection is still alive. A connection that was already closed
-- always yields 'SendClosed'; a live connection on a failed endpoint yields
-- 'SendFailed'. Any other state with a live connection is a rely violation.
apiSend :: EndPointPair  -- ^ Local and remote endpoint 
        -> ConnectionId  -- ^ Connection ID (supplied by remote endpoint)
        -> IORef Bool    -- ^ Is the connection still alive?
        -> [ByteString]  -- ^ Payload
        -> IO (Either (TransportError SendErrorCode) ())
apiSend (ourEndPoint, theirEndPoint) connId connAlive payload =  
    -- We don't need the overhead of asyncWhenCancelled here
    try . mapIOException sendFailed $ 
      withRemoteState endPoints RemoteStatePatternMatch
        { caseInvalid = \_   -> violation
        , caseInit    = \_ _ -> violation
        , caseValid   = \vst -> 
            unlessClosed $ 
              sendOn vst (encodeInt32 connId : prependLength payload)
        , caseClosing = \_ _ -> unlessClosed violation
        , caseClosed  = unlessClosed violation
        , caseFailed  = \err -> 
            unlessClosed . throwIO $ TransportError SendFailed (show err) 
        }
  where
    endPoints = (ourEndPoint, theirEndPoint)

    sendFailed = TransportError SendFailed . show

    violation = relyViolation endPoints "apiSend"

    -- Run the given action if the connection is still alive, and report
    -- SendClosed otherwise
    unlessClosed whenAlive = do
      alive <- readIORef connAlive
      if alive 
        then whenAlive
        else throwIO $ TransportError SendClosed "Connection closed"

-- | Force-close the endpoint.
--
-- Removes the endpoint from the transport and marks it closed. If it was
-- still valid, every remote endpoint it knew about is shut down and the
-- given events are then written to the endpoint's channel.
apiCloseEndPoint :: TCPTransport    -- ^ Transport 
                 -> [Event]         -- ^ Events used to report closure 
                 -> LocalEndPoint   -- ^ Local endpoint
                 -> IO ()
apiCloseEndPoint transport evs ourEndPoint =
  asyncWhenCancelled return $ do
    -- Drop the transport's reference to this endpoint first
    removeLocalEndPoint transport ourEndPoint
    -- Then flip the endpoint to closed, remembering its previous state
    previous <- modifyMVar (localState ourEndPoint) $ \st ->
      case st of
        LocalEndPointClosed ->
          return (LocalEndPointClosed, Nothing)
        LocalEndPointValid vst -> 
          return (LocalEndPointClosed, Just vst)
    case previous of
      Nothing  -> return ()
      Just vst -> do
        mapM_ shutDownRemote (vst ^. localConnections)
        mapM_ (writeChan (localChannel ourEndPoint)) evs
  where
    -- State in which we leave every remote endpoint that was still in use
    forcedShut :: RemoteState
    forcedShut = RemoteEndPointFailed . userError $ "apiCloseEndPoint"

    -- Shut down one remote endpoint. For a valid endpoint we try to be
    -- polite and send a CloseSocket message before closing the socket;
    -- threads waiting on 'resolved' are released.
    shutDownRemote :: RemoteEndPoint -> IO () 
    shutDownRemote theirEndPoint = 
      modifyMVar_ (remoteState theirEndPoint) $ \st ->
        case st of
          RemoteEndPointValid vst -> do 
            -- Best effort only: a failing send is ignored
            _ <- tryIO $ sendOn vst [encodeInt32 CloseSocket]
            tryCloseSocket (remoteSocket vst)
            return forcedShut 
          RemoteEndPointClosing resolved vst -> do 
            putMVar resolved ()
            tryCloseSocket (remoteSocket vst)
            return forcedShut 
          RemoteEndPointInit resolved _ -> do
            putMVar resolved ()
            return forcedShut 
          -- Nothing to shut down in the remaining states
          RemoteEndPointInvalid _ -> return st
          RemoteEndPointClosed    -> return st
          RemoteEndPointFailed _  -> return st

--------------------------------------------------------------------------------
-- As soon as a remote connection fails, we want to notify our endpoint       --
-- and put it into a closed state. Since this may happen in many places, we   --
-- provide some abstractions.                                                 --
--------------------------------------------------------------------------------

-- | One handler per constructor of 'RemoteState', each receiving that
-- constructor's arguments.
--
-- Used by 'modifyRemoteState' and friends to dispatch on the state of a
-- remote endpoint. 'remoteStateIdentity' provides defaults, so callers can
-- override only the cases they care about with record update syntax.
data RemoteStatePatternMatch a = RemoteStatePatternMatch 
  { caseInvalid :: TransportError ConnectErrorCode -> IO a
    -- ^ Handler for 'RemoteEndPointInvalid'
  , caseInit    :: MVar () -> RequestedBy -> IO a
    -- ^ Handler for 'RemoteEndPointInit'
  , caseValid   :: ValidRemoteEndPointState -> IO a
    -- ^ Handler for 'RemoteEndPointValid'
  , caseClosing :: MVar () -> ValidRemoteEndPointState -> IO a
    -- ^ Handler for 'RemoteEndPointClosing'
  , caseClosed  :: IO a
    -- ^ Handler for 'RemoteEndPointClosed'
  , caseFailed  :: IOException -> IO a
    -- ^ Handler for 'RemoteEndPointFailed'
  }

-- | Pattern match that rebuilds whatever state it is given, unchanged.
--
-- Intended as a base for record update: override only the cases that
-- should actually change the state.
remoteStateIdentity :: RemoteStatePatternMatch RemoteState
remoteStateIdentity =
  RemoteStatePatternMatch 
    { caseInvalid = \err          -> return (RemoteEndPointInvalid err)
    , caseInit    = \resolved by  -> return (RemoteEndPointInit resolved by)
    , caseValid   = \vst          -> return (RemoteEndPointValid vst)
    , caseClosing = \resolved vst -> return (RemoteEndPointClosing resolved vst)
    , caseClosed  = return RemoteEndPointClosed
    , caseFailed  = \err          -> return (RemoteEndPointFailed err)
    }

-- | Like modifyMVar, but specialised to remote endpoints: if an I/O
-- exception occurs while handling an endpoint in valid state, the endpoint
-- is not restored to its original value but closed down instead.
--
-- In that case the socket is closed, the endpoint is put in failed state and
-- an 'EventConnectionLost' error is written to the local endpoint's channel.
-- For any other exception, and for every other state, the original state is
-- put back. The exception is rethrown in all cases.
modifyRemoteState :: EndPointPair 
                  -> RemoteStatePatternMatch (RemoteState, a) 
                  -> IO a
modifyRemoteState (ourEndPoint, theirEndPoint) match = 
    mask $ \restore -> do
      st <- takeMVar theirState
      let -- Handler for a state in which we are not supposed to do any I/O
          -- on the endpoint: whatever goes wrong, the original state is
          -- put back untouched
          keepOnError handler = do
            (st', a) <- restore handler `onException` putMVar theirState st
            putMVar theirState st'
            return a
      case st of
        RemoteEndPointValid vst -> do
          outcome <- try . restore $ caseValid match vst
          case outcome of
            Right (st', a) -> do
              putMVar theirState st'
              return a
            Left ex -> do
              -- Only I/O exceptions mean that the connection is broken
              maybe (putMVar theirState st) 
                    (\ioEx -> markFailed ioEx vst) 
                    (fromException ex)
              throwIO ex
        RemoteEndPointInit resolved origin -> 
          keepOnError (caseInit match resolved origin)
        RemoteEndPointClosing resolved vst -> 
          keepOnError (caseClosing match resolved vst)
        RemoteEndPointInvalid err -> 
          keepOnError (caseInvalid match err)
        RemoteEndPointClosed -> 
          keepOnError (caseClosed match)
        RemoteEndPointFailed err -> 
          keepOnError (caseFailed match err)
  where
    theirState :: MVar RemoteState
    theirState = remoteState theirEndPoint

    -- Close the socket, record the failure, and only then report the lost
    -- connection (with all incoming connections) to the local endpoint
    markFailed :: IOException -> ValidRemoteEndPointState -> IO () 
    markFailed ex vst = do
      tryCloseSocket (remoteSocket vst)
      putMVar theirState (RemoteEndPointFailed ex)
      let lostIds = IntSet.elems (vst ^. remoteIncoming)
          lost    = EventConnectionLost (Just $ remoteAddress theirEndPoint) lostIds
      writeChan (localChannel ourEndPoint) . ErrorEvent $ 
        TransportError lost (show ex)

-- | Variant of 'modifyRemoteState' for handlers that only compute the next
-- state and have no extra result to hand back to the caller
modifyRemoteState_ :: EndPointPair
                   -> RemoteStatePatternMatch RemoteState
                   -> IO ()
modifyRemoteState_ endPoints@(_, _) match =
    modifyRemoteState endPoints RemoteStatePatternMatch
      { caseInvalid = \err             -> noResult (caseInvalid match err)
      , caseInit    = \resolved origin -> noResult (caseInit match resolved origin)
      , caseValid   = \vst             -> noResult (caseValid match vst)
      , caseClosing = \resolved vst    -> noResult (caseClosing match resolved vst)
      , caseClosed  = noResult (caseClosed match)
      , caseFailed  = \err             -> noResult (caseFailed match err)
      }
  where
    -- Pair the new state with a unit result, which is the shape
    -- 'modifyRemoteState' expects
    noResult :: IO a -> IO (a, ())
    noResult = fmap (\st -> (st, ()))

-- | Inspect the remote state without changing it: each handler only produces
-- a result, and the state it was called with is written back unmodified.
-- Exceptions are treated exactly as in 'modifyRemoteState'.
withRemoteState :: EndPointPair
                -> RemoteStatePatternMatch a
                -> IO a
withRemoteState endPoints@(_, _) match =
    modifyRemoteState endPoints RemoteStatePatternMatch
      { caseInvalid = \err ->
          unchanged (RemoteEndPointInvalid err) (caseInvalid match err)
      , caseInit = \resolved origin ->
          unchanged (RemoteEndPointInit resolved origin)
                    (caseInit match resolved origin)
      , caseValid = \vst ->
          unchanged (RemoteEndPointValid vst) (caseValid match vst)
      , caseClosing = \resolved vst ->
          unchanged (RemoteEndPointClosing resolved vst)
                    (caseClosing match resolved vst)
      , caseClosed =
          unchanged RemoteEndPointClosed (caseClosed match)
      , caseFailed = \err ->
          unchanged (RemoteEndPointFailed err) (caseFailed match err)
      }
  where
    -- Run the handler and pair its result with the (unchanged) state
    unchanged :: RemoteState -> IO a -> IO (RemoteState, a)
    unchanged st act = fmap ((,) st) act

--------------------------------------------------------------------------------
-- Incoming requests                                                          --
--------------------------------------------------------------------------------

-- | Handle a connection request (that is, a remote endpoint that is trying to
-- establish a TCP connection with us)
--
-- 'handleConnectionRequest' runs in the context of the transport thread, which
-- can be killed asynchronously by 'closeTransport'. We fork a separate thread
-- as soon as we have located the local endpoint that the remote endpoint is
-- interested in. We cannot fork any sooner because then we have no way of
-- storing the thread ID and hence no way of killing the thread when we take
-- the transport down. We must be careful to close the socket when a (possibly
-- asynchronous, ThreadKilled) exception occurs. (If an exception escapes from
-- handleConnectionRequest the transport will be shut down.)
--
-- The request consists of the ID of the local endpoint the peer wants to talk
-- to, followed by the peer's own (length-prefixed) endpoint address. We reply
-- with one of the 'ConnectionRequestResponse' values.
handleConnectionRequest :: TCPTransport -> N.Socket -> IO () 
handleConnectionRequest transport sock = handle handleException $ do 
    -- Read the request header
    ourEndPointId <- recvInt32 sock
    theirAddress  <- EndPointAddress . BS.concat <$> recvWithLength sock 
    let ourAddress = encodeEndPointAddress (transportHost transport) 
                                           (transportPort transport)
                                           ourEndPointId
    -- Look up the requested local endpoint. Both failure cases throw a user
    -- error, which is caught by 'handleException' (closing the socket)
    ourEndPoint <- withMVar (transportState transport) $ \st -> case st of
      TransportValid vst ->
        case vst ^. localEndPointAt ourAddress of
          Nothing -> do
            sendMany sock [encodeInt32 ConnectionRequestInvalid]
            throwIO $ userError "handleConnectionRequest: Invalid endpoint"
          Just ourEndPoint ->
            return ourEndPoint
      TransportClosed -> 
        throwIO $ userError "Transport closed"
    -- From here on the work is done in a separate thread
    void . forkIO $ go ourEndPoint theirAddress
  where
    go :: LocalEndPoint -> EndPointAddress -> IO ()
    go ourEndPoint theirAddress = do 
      -- This runs in a thread that will never be killed
      mEndPoint <- handle ((>> return Nothing) . handleException) $ do 
        resetIfBroken ourEndPoint theirAddress
        (theirEndPoint, isNew) <- 
          findRemoteEndPoint ourEndPoint theirAddress RequestedByThem
        
        if not isNew 
          then do
            -- A remote endpoint for this address already exists: tell the
            -- peer its request crossed and drop this socket (the reply is
            -- best effort; a failure to send is ignored)
            tryIO $ sendMany sock [encodeInt32 ConnectionRequestCrossed]
            tryCloseSocket sock
            return Nothing 
          else do
            -- Fresh remote endpoint: accept the request and record the socket
            -- as part of the endpoint's state
            let vst = ValidRemoteEndPointState 
                        {  remoteSocket        = sock
                        , _remoteOutgoing      = 0
                        , _remoteIncoming      = IntSet.empty
                        , sendOn               = sendMany sock
                        , _pendingCtrlRequests = IntMap.empty
                        , _nextCtrlRequestId   = 0
                        }
            sendMany sock [encodeInt32 ConnectionRequestAccepted]
            resolveInit (ourEndPoint, theirEndPoint) (RemoteEndPointValid vst)
            return (Just theirEndPoint)
      -- If we left the scope of the exception handler with a return value of
      -- Nothing then the socket is already closed; otherwise, the socket has
      -- been recorded as part of the remote endpoint. Either way, we no longer
      -- have to worry about closing the socket on receiving an asynchronous
      -- exception from this point forward. 
      forM_ mEndPoint $ handleIncomingMessages . (,) ourEndPoint

    -- Close the socket on any exception. Asynchronous exceptions are rethrown
    -- after the socket is closed; all other exceptions are absorbed here.
    handleException :: SomeException -> IO ()
    handleException ex = do 
      tryCloseSocket sock 
      rethrowIfAsync (fromException ex)

    -- Rethrow the exception if (and only if) it is an 'AsyncException'
    rethrowIfAsync :: Maybe AsyncException -> IO () 
    rethrowIfAsync = mapM_ throwIO 

-- | Handle requests from a remote endpoint.
--
-- Returns only if the remote party closes the socket or if an error occurs.
-- This runs in a thread that will never be killed.
--
-- The remote endpoint is expected to be resolved already: finding it in
-- Invalid or Init state is a rely violation. If it is already Closed or
-- Failed there is nothing to listen on and we return immediately.
handleIncomingMessages :: EndPointPair -> IO ()
handleIncomingMessages (ourEndPoint, theirEndPoint) = do
    -- Find the socket to listen on (if any)
    mSock <- withMVar theirState $ \st ->
      case st of
        RemoteEndPointInvalid _ ->
          relyViolation (ourEndPoint, theirEndPoint)
            "handleIncomingMessages (invalid)"
        RemoteEndPointInit _ _ ->
          relyViolation (ourEndPoint, theirEndPoint)
            "handleIncomingMessages (init)"
        RemoteEndPointValid ep ->
          return . Just $ remoteSocket ep
        RemoteEndPointClosing _ ep ->
          return . Just $ remoteSocket ep
        RemoteEndPointClosed ->
          return Nothing
        RemoteEndPointFailed _ ->
          return Nothing

    -- Run the dispatch loop; any I/O exception ends up in 'prematureExit'
    forM_ mSock $ \sock ->
      tryIO (go sock) >>= either (prematureExit sock) return
  where
    -- Dispatch
    --
    -- If a recv throws an exception this will be caught top-level and
    -- 'prematureExit' will be invoked. The same will happen if the remote
    -- endpoint is put into a Closed (or Closing) state by a concurrent thread
    -- (because a 'send' failed) -- the individual handlers below will throw a
    -- user exception which is then caught and handled the same way as an
    -- exception thrown by 'recv'.
    --
    -- Every message starts with a 32-bit header. Values from
    -- 'firstNonReservedConnectionId' upwards are connection IDs announcing a
    -- payload; smaller values are decoded as a 'ControlHeader'.
    go :: N.Socket -> IO ()
    go sock = do
      connId <- recvInt32 sock
      if connId >= firstNonReservedConnectionId
        then do
          readMessage sock connId
          go sock
        else
          case tryToEnum (fromIntegral connId) of
            Just RequestConnectionId -> do
              recvInt32 sock >>= createNewConnection
              go sock
            Just ControlResponse -> do
              recvInt32 sock >>= readControlResponse sock
              go sock
            Just CloseConnection -> do
              recvInt32 sock >>= closeConnection
              go sock
            Just CloseSocket -> do
              -- Keep listening if we did not agree to close the socket
              didClose <- closeSocket sock
              unless didClose $ go sock
            Nothing ->
              throwIO $ userError "Invalid control request"

    -- Create a new connection
    --
    -- Allocates a fresh connection ID, records it as an incoming connection,
    -- sends the ID back as the response to control request 'reqId', and
    -- finally reports 'ConnectionOpened' on our channel.
    createNewConnection :: ControlRequestId -> IO ()
    createNewConnection reqId = do
      newId <- getNextConnectionId ourEndPoint
      modifyMVar_ theirState $ \st -> do
        -- 'afterReply' is run once the reply has been sent successfully
        (vst, afterReply) <- case st of
          RemoteEndPointInvalid _ ->
            relyViolation (ourEndPoint, theirEndPoint)
              "handleIncomingMessages:createNewConnection (invalid)"
          RemoteEndPointInit _ _ ->
            relyViolation (ourEndPoint, theirEndPoint)
              "handleIncomingMessages:createNewConnection (init)"
          RemoteEndPointValid vst ->
            return ( (remoteIncoming ^: IntSet.insert newId) vst
                   , return ()
                   )
          RemoteEndPointClosing resolved vst ->
            -- If the endpoint is in closing state that means we sent a
            -- CloseSocket request to the remote endpoint. If the remote
            -- endpoint replies with the request to create a new connection, it
            -- either ignored our request or it sent the request before it got
            -- ours.  Either way, at this point we simply restore the endpoint
            -- to RemoteEndPointValid.
            --
            -- We must not signal 'resolved' before the reply has been sent:
            -- if the send fails, modifyMVar_ restores the Closing state, and
            -- 'prematureExit' will then want to fill 'resolved' itself (which
            -- would block forever if we had already filled it).
            return ( (remoteIncoming ^= IntSet.singleton newId) vst
                   , putMVar resolved ()
                   )
          RemoteEndPointFailed err ->
            throwIO err
          RemoteEndPointClosed ->
            relyViolation (ourEndPoint, theirEndPoint)
              "createNewConnection (closed)"
        sendOn vst ( encodeInt32 ControlResponse
                   : encodeInt32 reqId
                   : prependLength [encodeInt32 newId]
                   )
        afterReply
        return (RemoteEndPointValid vst)
      writeChan ourChannel (ConnectionOpened newId ReliableOrdered theirAddr)

    -- Read a control response
    --
    -- Reads the (length-prefixed) response body, removes the matching pending
    -- request from the remote state and hands the body to whoever is waiting
    -- for it. Throws a user error if there is no such pending request.
    readControlResponse :: N.Socket -> ControlRequestId -> IO ()
    readControlResponse sock reqId = do
      response <- recvWithLength sock
      mmvar    <- modifyMVar theirState $ \st -> case st of
        RemoteEndPointInvalid _ ->
          relyViolation (ourEndPoint, theirEndPoint)
            "readControlResponse (invalid)"
        RemoteEndPointInit _ _ ->
          relyViolation (ourEndPoint, theirEndPoint)
            "readControlResponse (init)"
        RemoteEndPointValid vst ->
          return ( RemoteEndPointValid
                 . (pendingCtrlRequestsAt reqId ^= Nothing)
                 $ vst
                 , vst ^. pendingCtrlRequestsAt reqId
                 )
        RemoteEndPointClosing _ _ ->
          throwIO $ userError "Invalid control response"
        RemoteEndPointFailed err ->
          throwIO err
        RemoteEndPointClosed ->
          relyViolation (ourEndPoint, theirEndPoint)
            "readControlResponse (closed)"
      case mmvar of
        Nothing ->
          throwIO $ userError "Invalid request ID"
        Just mvar ->
          putMVar mvar (Right response)

    -- Close a connection
    -- It is important that we verify that the connection is in fact open,
    -- because otherwise we should not decrement the reference count
    closeConnection :: ConnectionId -> IO ()
    closeConnection cid = do
      modifyMVar_ theirState $ \st -> case st of
        RemoteEndPointInvalid _ ->
          relyViolation (ourEndPoint, theirEndPoint) "closeConnection (invalid)"
        RemoteEndPointInit _ _ ->
          relyViolation (ourEndPoint, theirEndPoint) "closeConnection (init)"
        RemoteEndPointValid vst -> do
          unless (IntSet.member cid (vst ^. remoteIncoming)) $
            throwIO $ userError "Invalid CloseConnection"
          return ( RemoteEndPointValid
                 . (remoteIncoming ^: IntSet.delete cid)
                 $ vst
                 )
        RemoteEndPointClosing _ _ ->
          -- If the remote endpoint is in Closing state, that means that as
          -- far as we are concerned there are no incoming connections. This
          -- means that a CloseConnection request at this point is invalid.
          throwIO $ userError "Invalid CloseConnection request"
        RemoteEndPointFailed err ->
          throwIO err
        RemoteEndPointClosed ->
          relyViolation (ourEndPoint, theirEndPoint) "closeConnection (closed)"
      writeChan ourChannel (ConnectionClosed cid)
      -- That may have been the last connection on this socket
      closeIfUnused (ourEndPoint, theirEndPoint)

    -- Close the socket (if we don't have any outgoing connections)
    --
    -- Returns True if the socket was closed, False if we still need it.
    closeSocket :: N.Socket -> IO Bool
    closeSocket sock =
      modifyMVar theirState $ \st ->
        case st of
          RemoteEndPointInvalid _ ->
            relyViolation (ourEndPoint, theirEndPoint)
              "handleIncomingMessages:closeSocket (invalid)"
          RemoteEndPointInit _ _ ->
            relyViolation (ourEndPoint, theirEndPoint)
              "handleIncomingMessages:closeSocket (init)"
          RemoteEndPointValid vst -> do
            -- We regard a CloseSocket message as an (optimized) way for the
            -- remote endpoint to indicate that all its connections to us are
            -- now properly closed
            forM_ (IntSet.elems $ vst ^. remoteIncoming) $
              writeChan ourChannel . ConnectionClosed
            let vst' = remoteIncoming ^= IntSet.empty $ vst
            -- Check if we agree that the connection should be closed
            if vst' ^. remoteOutgoing == 0
              then do
                removeRemoteEndPoint (ourEndPoint, theirEndPoint)
                -- Attempt to reply (but don't insist)
                tryIO $ sendOn vst' [encodeInt32 CloseSocket]
                tryCloseSocket sock
                return (RemoteEndPointClosed, True)
              else
                return (RemoteEndPointValid vst', False)
          RemoteEndPointClosing resolved  _ -> do
            -- We were already closing: remove the endpoint, close the socket
            -- and signal that the Closing state has been resolved
            removeRemoteEndPoint (ourEndPoint, theirEndPoint)
            tryCloseSocket sock
            putMVar resolved ()
            return (RemoteEndPointClosed, True)
          RemoteEndPointFailed err ->
            throwIO err
          RemoteEndPointClosed ->
            relyViolation (ourEndPoint, theirEndPoint)
              "handleIncomingMessages:closeSocket (closed)"

    -- Read a message and output it on the endPoint's channel. By rights we
    -- should verify that the connection ID is valid, but this is unnecessary
    -- overhead
    readMessage :: N.Socket -> ConnectionId -> IO ()
    readMessage sock connId =
      recvWithLength sock >>= writeChan ourChannel . Received connId

    -- Arguments
    ourChannel  = localChannel ourEndPoint
    theirState  = remoteState theirEndPoint
    theirAddr   = remoteAddress theirEndPoint

    -- Deal with a premature exit
    --
    -- Closes the socket and moves the remote endpoint to Failed state. If the
    -- endpoint was still Valid we additionally report the lost connections on
    -- our channel and fail all pending control requests.
    prematureExit :: N.Socket -> IOException -> IO ()
    prematureExit sock err = do
      tryCloseSocket sock
      modifyMVar_ theirState $ \st ->
        case st of
          RemoteEndPointInvalid _ ->
            relyViolation (ourEndPoint, theirEndPoint)
              "handleIncomingMessages:prematureExit"
          RemoteEndPointInit _ _ ->
            relyViolation (ourEndPoint, theirEndPoint)
              "handleIncomingMessages:prematureExit"
          RemoteEndPointValid vst -> do
            let code = EventConnectionLost
                         (Just $ remoteAddress theirEndPoint)
                         (IntSet.elems $ vst ^. remoteIncoming)
            writeChan ourChannel . ErrorEvent $ TransportError code (show err)
            forM_ (vst ^. pendingCtrlRequests) $ flip putMVar (Left err)
            return (RemoteEndPointFailed err)
          RemoteEndPointClosing resolved _ -> do
            putMVar resolved ()
            return (RemoteEndPointFailed err)
          RemoteEndPointClosed ->
            relyViolation (ourEndPoint, theirEndPoint)
              "handleIncomingMessages:prematureExit"
          RemoteEndPointFailed err' ->
            -- Already failed; keep the original error
            return (RemoteEndPointFailed err')

--------------------------------------------------------------------------------
-- Uninterruptable auxiliary functions                                        --
--                                                                            --
-- All these functions assume they are running in a thread which will never   --
-- be killed.                                                                 --
--------------------------------------------------------------------------------

-- | Ask a remote endpoint for a new connection ID
--
-- If no remote endpoint exists yet for the given address, one is created and
-- set up in a separate thread, after which we try again. Otherwise a
-- 'RequestConnectionId' control request is sent and we block until the remote
-- endpoint replies with the ID. Looking up the remote endpoint may itself
-- block while it is being initialised or is in 'RemoteEndPointClosing' state.
--
-- I/O exceptions are rethrown as TransportError ConnectFailed.
requestConnectionTo :: TCPParameters
                    -> LocalEndPoint
                    -> EndPointAddress
                    -> ConnectHints
                    -> IO (RemoteEndPoint, ConnectionId)
requestConnectionTo params ourEndPoint theirAddress hints = attempt
  where
    attempt = do
      (theirEndPoint, isNew) <- mapIOException connectFailed $
        findRemoteEndPoint ourEndPoint theirAddress RequestedByUs
      let endPoints = (ourEndPoint, theirEndPoint)

      if not isNew
        then do
          -- The endpoint already exists: ask it for a connection ID
          reply <- mapIOException connectFailed $
            doRemoteRequest endPoints RequestConnectionId
          return (theirEndPoint, decodeInt32 (BS.concat reply))
        else do
          -- We created the endpoint, so we are responsible for setting it up.
          -- Failures are recorded in the endpoint's state, which the retry
          -- below will observe; exceptions in the forked thread are dropped.
          _ <- forkIO $
            handle absorbAllExceptions
                   (setupRemoteEndPoint params endPoints hints)
          attempt

    connectFailed :: IOException -> TransportError ConnectErrorCode
    connectFailed ex = TransportError ConnectFailed (show ex)

    absorbAllExceptions :: SomeException -> IO ()
    absorbAllExceptions _ = return ()

-- | Set up a remote endpoint
--
-- Opens a socket to the remote endpoint, sends the connection request and
-- resolves the endpoint (which must be in Init state) according to the reply:
--
-- * Accepted: the endpoint becomes Valid and this thread goes on to handle
--   incoming messages on the new socket.
--
-- * Invalid: the endpoint becomes Invalid and the socket is closed.
--
-- * Crossed: the endpoint becomes Closed and the socket is closed.
--
-- * Connecting failed: the endpoint becomes Invalid with the reported error.
--
-- 'resolveInit' may throw (for instance when the endpoint has been moved to
-- Failed state in the meantime). In that case the socket has not been recorded
-- anywhere, so we close it here before the exception propagates.
setupRemoteEndPoint :: TCPParameters -> EndPointPair -> ConnectHints -> IO ()
setupRemoteEndPoint params (ourEndPoint, theirEndPoint) hints = do
    result <- socketToEndPoint ourAddress
                               theirAddress
                               (tcpReuseClientAddr params)
                               (connectTimeout hints)
    didAccept <- case result of
      Right (sock, ConnectionRequestAccepted) -> do
        let vst = ValidRemoteEndPointState
                    {  remoteSocket        = sock
                    , _remoteOutgoing      = 0
                    , _remoteIncoming      = IntSet.empty
                    ,  sendOn              = sendMany sock
                    , _pendingCtrlRequests = IntMap.empty
                    , _nextCtrlRequestId   = 0
                    }
        -- On success the socket is owned by the remote endpoint's state
        resolveInit (ourEndPoint, theirEndPoint) (RemoteEndPointValid vst)
          `onException` tryCloseSocket sock
        return True
      Right (sock, ConnectionRequestInvalid) -> do
        let err = invalidAddress "setupRemoteEndPoint: Invalid endpoint"
        resolveAndClose sock (RemoteEndPointInvalid err)
        return False
      Right (sock, ConnectionRequestCrossed) -> do
        resolveAndClose sock RemoteEndPointClosed
        return False
      Left err -> do
        resolveInit (ourEndPoint, theirEndPoint) (RemoteEndPointInvalid err)
        return False

    when didAccept $ handleIncomingMessages (ourEndPoint, theirEndPoint)
  where
    ourAddress      = localAddress ourEndPoint
    theirAddress    = remoteAddress theirEndPoint
    invalidAddress  = TransportError ConnectNotFound

    -- Resolve the endpoint to a state that does not need the socket, and
    -- close the socket whether or not resolving succeeded
    resolveAndClose :: N.Socket -> RemoteState -> IO ()
    resolveAndClose sock st = do
      resolveInit (ourEndPoint, theirEndPoint) st
        `onException` tryCloseSocket sock
      tryCloseSocket sock

-- | Send a control request to a remote endpoint and wait for the reply
--
-- The request is tagged with the endpoint's next control request ID and the
-- reply MVar is registered under that ID before we block on it. The remote
-- endpoint must be in Valid state: Init, Closing and Closed are rely
-- violations, while Invalid and Failed rethrow the error stored in the state.
--
-- May throw IO (user) exception if the local or the remote endpoint is closed,
-- if the send fails, or if the remote endpoint fails before it replies.
doRemoteRequest :: EndPointPair -> ControlHeader -> IO [ByteString]
doRemoteRequest (ourEndPoint, theirEndPoint) header = do
    replyMVar <- newEmptyMVar
    modifyRemoteState_ endPoints RemoteStatePatternMatch
      { caseInvalid = throwIO
      , caseInit    = \_ _ ->
          relyViolation endPoints "doRemoteRequest (init)"
      , caseValid   = \vst -> do
          let reqId = vst ^. nextCtrlRequestId
              vst'  = (nextCtrlRequestId ^: (+ 1))
                    . (pendingCtrlRequestsAt reqId ^= Just replyMVar)
                    $ vst
          sendOn vst [encodeInt32 header, encodeInt32 reqId]
          return (RemoteEndPointValid vst')
      , caseClosing = \_ _ ->
          relyViolation endPoints "doRemoteRequest (closing)"
      , caseClosed  =
          relyViolation endPoints "doRemoteRequest (closed)"
      , caseFailed  = throwIO
      }
    -- Block until the reply (or a failure notification) arrives
    takeMVar replyMVar >>= either throwIO return
  where
    endPoints = (ourEndPoint, theirEndPoint)

-- | Start closing the socket to a remote endpoint that is no longer in use
--
-- A Valid endpoint without outgoing and without incoming connections is sent
-- a CloseSocket request and moved to Closing state. Endpoints in any other
-- state, and Valid endpoints that are still in use, are left as they are.
closeIfUnused :: EndPointPair -> IO ()
closeIfUnused (ourEndPoint, theirEndPoint) =
    modifyRemoteState_ (ourEndPoint, theirEndPoint) remoteStateIdentity
      { caseValid = onValid }
  where
    onValid vst
      | inUse vst = return (RemoteEndPointValid vst)
      | otherwise = do
          sendOn vst [encodeInt32 CloseSocket]
          resolved <- newEmptyMVar
          return (RemoteEndPointClosing resolved vst)

    -- Are there any connections (in either direction) left on this endpoint?
    inUse vst =  vst ^. remoteOutgoing /= 0
              || not (IntSet.null (vst ^. remoteIncoming))

-- | Forget about a remote endpoint that is in Invalid or Failed state
--
-- This is called before we look up the remote endpoint for a given address
-- (for instance when an inbound connection request comes in from that
-- address), so that a broken endpoint is removed and a fresh one can be
-- created in its place. Endpoints in any other state are left alone, and
-- nothing happens if we do not know the address at all.
--
-- Throws a TransportError ConnectFailed exception if the local endpoint is
-- closed.
resetIfBroken :: LocalEndPoint -> EndPointAddress -> IO ()
resetIfBroken ourEndPoint theirAddress = do
    mTheirEndPoint <- withMVar (localState ourEndPoint) lookupRemote
    forM_ mTheirEndPoint $ \theirEndPoint ->
      withMVar (remoteState theirEndPoint) $ \st ->
        when (isBroken st) $
          removeRemoteEndPoint (ourEndPoint, theirEndPoint)
  where
    -- The remote endpoint currently registered for this address (if any)
    lookupRemote (LocalEndPointValid vst) =
      return (vst ^. localConnectionTo theirAddress)
    lookupRemote LocalEndPointClosed =
      throwIO $ TransportError ConnectFailed "Endpoint closed"

    isBroken (RemoteEndPointInvalid _) = True
    isBroken (RemoteEndPointFailed _)  = True
    isBroken _                         = False

-- | Connect a local endpoint to itself (special case of 'apiConnect')
--
-- No socket is involved: a connection ID is allocated, 'ConnectionOpened' is
-- posted on the endpoint's own channel, and the returned 'Connection' writes
-- 'Received' and 'ConnectionClosed' events straight to that channel.
--
-- May throw a TransportError ConnectErrorCode (if the local endpoint is closed)
connectToSelf :: LocalEndPoint
               -> IO Connection
connectToSelf ourEndPoint = do
    connAlive <- newIORef True  -- Protected by the local endpoint lock
    connId    <- mapIOException connectFailed $ getNextConnectionId ourEndPoint
    writeChan ourChan $
      ConnectionOpened connId ReliableOrdered (localAddress ourEndPoint)
    return Connection
      { send  = selfSend connAlive connId
      , close = selfClose connAlive connId
      }
  where
    -- Deliver a message to ourselves, provided that both the endpoint and
    -- the connection are still open
    selfSend :: IORef Bool
             -> ConnectionId
             -> [ByteString]
             -> IO (Either (TransportError SendErrorCode) ())
    selfSend connAlive connId msg =
      try . withMVar ourState $ \st -> case st of
        LocalEndPointClosed ->
          throwIO $ TransportError SendFailed "Endpoint closed"
        LocalEndPointValid _ -> do
          alive <- readIORef connAlive
          unless alive $
            throwIO $ TransportError SendClosed "Connection closed"
          writeChan ourChan (Received connId msg)

    -- Close the connection; closing it a second time, or closing it after
    -- the endpoint itself was closed, does nothing
    selfClose :: IORef Bool -> ConnectionId -> IO ()
    selfClose connAlive connId =
      withMVar ourState $ \st -> case st of
        LocalEndPointClosed ->
          return ()
        LocalEndPointValid _ -> do
          alive <- readIORef connAlive
          when alive $ do
            writeChan ourChan (ConnectionClosed connId)
            writeIORef connAlive False

    ourChan  = localChannel ourEndPoint
    ourState = localState ourEndPoint
    connectFailed = TransportError ConnectFailed . show

-- | Move a remote endpoint out of 'Init' state
--
-- Signals everybody waiting for the endpoint to be resolved and installs the
-- given state. If the new state is Closed, the remote endpoint is also
-- removed from the local endpoint.
--
-- If the endpoint has failed in the meantime the stored exception is
-- rethrown; finding it in any other state is a rely violation.
resolveInit :: EndPointPair -> RemoteState -> IO ()
resolveInit (ourEndPoint, theirEndPoint) newState =
    modifyMVar_ (remoteState theirEndPoint) resolve
  where
    resolve (RemoteEndPointInit resolved _) = do
      putMVar resolved ()
      forgetIfClosed newState
      return newState
    resolve (RemoteEndPointFailed ex) =
      throwIO ex
    resolve _ =
      relyViolation (ourEndPoint, theirEndPoint) "resolveInit"

    -- A closed remote endpoint should no longer be reachable from the local
    -- endpoint
    forgetIfClosed RemoteEndPointClosed =
      removeRemoteEndPoint (ourEndPoint, theirEndPoint)
    forgetIfClosed _ =
      return ()

-- | Allocate a fresh connection ID on a local endpoint
--
-- Returns the endpoint's current "next connection ID" and increments the
-- stored value by one.
--
-- Throws an IO exception when the endpoint is closed.
getNextConnectionId :: LocalEndPoint -> IO ConnectionId
getNextConnectionId ourEndpoint =
    modifyMVar (localState ourEndpoint) allocate
  where
    allocate LocalEndPointClosed =
      throwIO $ userError "Local endpoint closed"
    allocate (LocalEndPointValid vst) =
      let connId = vst ^. nextConnectionId
          vst'   = (nextConnectionId ^= connId + 1) vst
      in return (LocalEndPointValid vst', connId)

-- | Create a new local endpoint and register it with the transport
--
-- The endpoint gets a fresh event channel and an empty state; its address is
-- built from the transport's host and port and the transport's next endpoint
-- ID (which is incremented).
--
-- May throw a TransportError NewEndPointErrorCode exception if the transport
-- is closed.
createLocalEndPoint :: TCPTransport -> IO LocalEndPoint
createLocalEndPoint transport = do
    chan  <- newChan
    state <- newMVar . LocalEndPointValid $ ValidLocalEndPointState
      { _nextConnectionId    = firstNonReservedConnectionId
      , _localConnections    = Map.empty
      , _nextRemoteId        = 0
      }
    modifyMVar (transportState transport) (register chan state)
  where
    register _ _ TransportClosed =
      throwIO (TransportError NewEndPointFailed "Transport closed")
    register chan state (TransportValid vst) = do
      let ix            = vst ^. nextEndPointId
          addr          = encodeEndPointAddress (transportHost transport)
                                                (transportPort transport)
                                                ix
          localEndPoint = LocalEndPoint { localAddress  = addr
                                        , localChannel  = chan
                                        , localState    = state
                                        }
          vst'          = (localEndPointAt addr ^= Just localEndPoint)
                        . (nextEndPointId ^= ix + 1)
                        $ vst
      return (TransportValid vst', localEndPoint)


-- | Make a local endpoint forget about a remote endpoint
--
-- The entry for the remote address is only removed if it still refers to this
-- very remote endpoint (same 'remoteId'); if the address has since been
-- taken over by a different remote endpoint, that one is left in place.
--
-- If the local endpoint is closed, do nothing
removeRemoteEndPoint :: EndPointPair -> IO ()
removeRemoteEndPoint (ourEndPoint, theirEndPoint) =
    modifyMVar_ (localState ourEndPoint) $ \st -> case st of
      LocalEndPointClosed ->
        return LocalEndPointClosed
      LocalEndPointValid vst ->
        case vst ^. localConnectionTo theirAddress of
          Just current | remoteId current == remoteId theirEndPoint ->
            return . LocalEndPointValid $
              (localConnectionTo theirAddress ^= Nothing) vst
          _ ->
            return st
  where
    theirAddress = remoteAddress theirEndPoint

-- | Unregister a local endpoint from the transport, so that its address no
-- longer resolves to it
--
-- Does nothing if the transport is closed
removeLocalEndPoint :: TCPTransport -> LocalEndPoint -> IO ()
removeLocalEndPoint transport ourEndPoint =
    modifyMVar_ (transportState transport) unregister
  where
    unregister TransportClosed =
      return TransportClosed
    unregister (TransportValid vst) =
      return . TransportValid $ (localEndPointAt ourAddress ^= Nothing) vst

    ourAddress = localAddress ourEndPoint

-- | Find a remote endpoint. If the remote endpoint does not yet exist we
-- create it in Init state. 
--
-- Returns the remote endpoint together with a flag that is True if this call
-- created the endpoint (in which case the caller is expected to resolve it)
-- and False if the endpoint already existed.
--
-- If the endpoint already exists and is Valid, and the request originates from
-- us ('RequestedByUs'), its count of outgoing connections is incremented.
--
-- This may block: if the existing endpoint is still being initialised or is
-- closing we wait for that to be resolved and then start over.
--
-- Throws an IO (user) exception if the local endpoint is closed or if a remote
-- request comes in for an endpoint that the remote side is already setting up;
-- rethrows the stored error if the remote endpoint is Invalid or Failed.
findRemoteEndPoint 
  :: LocalEndPoint
  -> EndPointAddress
  -> RequestedBy 
  -> IO (RemoteEndPoint, Bool) 
findRemoteEndPoint ourEndPoint theirAddress findOrigin = go
  where
    go = do
      -- Step 1: look up the endpoint in the local state, creating it (in Init
      -- state, with a fresh remote ID) if it is not there
      (theirEndPoint, isNew) <- modifyMVar ourState $ \st -> case st of
        LocalEndPointValid vst -> case vst ^. localConnectionTo theirAddress of
          Just theirEndPoint ->
            return (st, (theirEndPoint, False))
          Nothing -> do
            resolved <- newEmptyMVar
            theirState <- newMVar (RemoteEndPointInit resolved findOrigin)
            let theirEndPoint = RemoteEndPoint
                                  { remoteAddress = theirAddress
                                  , remoteState   = theirState
                                  , remoteId      = vst ^. nextRemoteId
                                  }
            return ( LocalEndPointValid 
                   . (localConnectionTo theirAddress ^= Just theirEndPoint) 
                   . (nextRemoteId ^: (+ 1)) 
                   $ vst
                   , (theirEndPoint, True) 
                   )
        LocalEndPointClosed ->
          throwIO $ userError "Local endpoint closed"
      
      if isNew 
        then
          return (theirEndPoint, True)
        else do
          -- Step 2: the endpoint already existed. Take a snapshot of its
          -- state, claiming an outgoing connection right away if it is Valid
          -- and the request is ours
          let theirState = remoteState theirEndPoint
          snapshot <- modifyMVar theirState $ \st -> case st of
            RemoteEndPointValid vst -> 
              case findOrigin of
                RequestedByUs -> do
                  let st' = RemoteEndPointValid 
                          . (remoteOutgoing ^: (+ 1)) 
                          $ vst 
                  return (st', st')
                RequestedByThem ->
                  return (st, st) 
            _ ->
              return (st, st)
          -- The snapshot may no longer be up to date at this point, but if we
          -- increased the refcount then it can only either be Valid or Failed 
          -- (after an explicit call to 'closeEndPoint' or 'closeTransport') 
          case snapshot of
            RemoteEndPointInvalid err ->
              throwIO err
            RemoteEndPointInit resolved initOrigin ->
              -- Somebody is still setting up this endpoint; what we do depends
              -- on who asked for the endpoint now and who created it
              case (findOrigin, initOrigin) of
                (RequestedByUs, RequestedByUs) -> 
                  readMVar resolved >> go 
                (RequestedByUs, RequestedByThem) -> 
                  readMVar resolved >> go
                (RequestedByThem, RequestedByUs) -> 
                  -- Both sides are connecting to each other at the same time;
                  -- the comparison of the two addresses decides whether we
                  -- wait or report the endpoint as already existing
                  if ourAddress > theirAddress 
                    then
                      -- Wait for the Crossed message
                      readMVar resolved >> go 
                    else
                      return (theirEndPoint, False)
                (RequestedByThem, RequestedByThem) -> 
                  throwIO $ userError "Already connected"
            RemoteEndPointValid _ ->
              -- We assume that the request crossed if we find the endpoint in
              -- Valid state. It is possible that this is really an invalid
              -- request, but only in the case of a broken client (we don't
              -- maintain enough history to be able to tell the difference).
              return (theirEndPoint, False)
            RemoteEndPointClosing resolved _ ->
              -- Wait until the endpoint has left Closing state, then retry
              readMVar resolved >> go
            RemoteEndPointClosed ->
              -- Start over from the lookup in the local state
              go
            RemoteEndPointFailed err -> 
              throwIO err
          
    ourState   = localState ourEndPoint 
    ourAddress = localAddress ourEndPoint

--------------------------------------------------------------------------------
-- "Stateless" (MVar free) functions                                          --
--------------------------------------------------------------------------------

-- | Establish a connection to a remote endpoint
--
-- Resolves the host and port encoded in the remote address, connects a new
-- TCP socket to it, sends the remote endpoint ID followed by our own
-- length-prefixed address, and reads back the remote side's response.
--
-- 'TransportError's raised along the way are caught and returned as 'Left'
-- rather than thrown:
--
-- * 'ConnectFailed': the remote address could not be parsed, setting the
--   socket option failed, the handshake could not be sent or received, or
--   the response was not a valid 'ConnectionRequestResponse'
--
-- * 'ConnectNotFound': address resolution failed or yielded no addresses, or
--   'N.connect' failed
--
-- * 'ConnectInsufficientResources': the socket could not be created
--
-- * 'ConnectTimeout': the connect timed out (as reported by 'timeoutMaybe')
--
-- On success the socket is returned still open. If an exception occurs after
-- the socket was created, the socket is handed to 'tryCloseSocket' before the
-- exception propagates.
socketToEndPoint :: EndPointAddress -- ^ Our address 
                 -> EndPointAddress -- ^ Their address
                 -> Bool            -- ^ Use SO_REUSEADDR?
                 -> Maybe Int       -- ^ Timeout for connect 
                 -> IO (Either (TransportError ConnectErrorCode) 
                               (N.Socket, ConnectionRequestResponse)) 
socketToEndPoint (EndPointAddress ourAddress) theirAddress reuseAddr timeout = 
  try $ do 
    (host, port, theirEndPointId) <- case decodeEndPointAddress theirAddress of 
      Nothing  -> throwIO (failed . userError $ "Could not parse")
      Just dec -> return dec
    addrs <- mapIOException invalidAddress $ 
      N.getAddrInfo Nothing (Just host) (Just port)
    -- Match the result explicitly: a failing @addr:_ <-@ bind would call
    -- 'fail' outside of 'mapIOException', raising a plain 'IOException' that
    -- escapes the enclosing 'try' instead of being reported as a
    -- 'TransportError'.
    addr <- case addrs of
      []    -> throwIO (invalidAddress . userError $ "No address found")
      (a:_) -> return a
    bracketOnError (createSocket addr) tryCloseSocket $ \sock -> do
      when reuseAddr $ 
        mapIOException failed $ N.setSocketOption sock N.ReuseAddr 1
      mapIOException invalidAddress $ 
        timeoutMaybe timeout timeoutError $ 
          N.connect sock (N.addrAddress addr) 
      -- Handshake: their endpoint ID, then our address prefixed by its length
      response <- mapIOException failed $ do
        sendMany sock (encodeInt32 theirEndPointId : prependLength [ourAddress])
        recvInt32 sock
      case tryToEnum response of
        Nothing -> throwIO (failed . userError $ "Unexpected response")
        Just r  -> return (sock, r)
  where
    createSocket :: N.AddrInfo -> IO N.Socket
    createSocket addr = mapIOException insufficientResources $ 
      N.socket (N.addrFamily addr) N.Stream N.defaultProtocol

    -- Each of these wraps the 'show'n 'IOException' in a 'TransportError'
    invalidAddress        = TransportError ConnectNotFound . show 
    insufficientResources = TransportError ConnectInsufficientResources . show 
    failed                = TransportError ConnectFailed . show
    timeoutError          = TransportError ConnectTimeout "Timed out"

-- | Encode end point address
--
-- The result is the packed string @host:port:endPointId@, where the endpoint
-- ID is rendered with 'show'.
encodeEndPointAddress :: N.HostName 
                      -> N.ServiceName 
                      -> EndPointId 
                      -> EndPointAddress
encodeEndPointAddress host port ix = EndPointAddress (BSC.pack rendered)
  where
    rendered = host ++ (':' : port) ++ (':' : show ix)

-- | Decode end point address
--
-- Expects the form @host:port:endPointId@. The string is split at the last
-- two colons only, so the host part may itself contain colons. Returns
-- 'Nothing' when fewer than three fields are found or when the last field
-- cannot be read in full as an 'EndPointId'.
decodeEndPointAddress :: EndPointAddress 
                      -> Maybe (N.HostName, N.ServiceName, EndPointId)
decodeEndPointAddress (EndPointAddress raw) =
    fromFields (splitMaxFromEnd (== ':') 2 (BSC.unpack raw))
  where
    -- Accept exactly three fields whose last one parses with no leftovers
    fromFields [host, port, ixStr]
      | [(ix, "")] <- reads ixStr = Just (host, port, ix)
    fromFields _ = Nothing

-- | @splitMaxFromEnd p n xs@ splits list @xs@ at elements matching @p@,
-- performing at most @n@ splits -- counting from the /end/ -- and therefore
-- returning at most @n + 1@ segments. The elements at which the list is split
-- are dropped; any further matches stay in the first segment.
-- 
-- > splitMaxFromEnd (== ':') 2 "ab:cd:ef:gh" == ["ab:cd", "ef", "gh"]
splitMaxFromEnd :: (a -> Bool) -> Int -> [a] -> [[a]]
splitMaxFromEnd isSep maxSplits = collect maxSplits [] . reverse
  where
    -- Walks the reversed input; @segs@ holds the segments completed so far,
    -- already in their final order.
    collect 0 segs rest = reverse rest : segs
    collect k segs rest =
      case break isSep rest of
        (seg, [])        -> reverse seg : segs
        (seg, _ : rest') -> collect (k - 1) (reverse seg : segs) rest'

--------------------------------------------------------------------------------
-- Functions from TransportInternals                                          --
--------------------------------------------------------------------------------

-- | Find a socket between two endpoints
--
-- Looks up the local endpoint in the transport state, then the remote
-- endpoint in that local endpoint's state, and finally returns the socket
-- held by the remote endpoint when it is in the valid or closing state.
--
-- Throws an exception if the socket could not be found: a 'userError' when a
-- lookup fails or a state has no socket, or the error stored in the remote
-- endpoint's state when it is invalid or failed.
internalSocketBetween :: TCPTransport    -- ^ Transport 
                      -> EndPointAddress -- ^ Local endpoint
                      -> EndPointAddress -- ^ Remote endpoint
                      -> IO N.Socket 
internalSocketBetween transport ourAddress theirAddress = do
    ourEndPoint   <- withMVar (transportState transport) findLocal
    theirEndPoint <- withMVar (localState ourEndPoint) findRemote
    withMVar (remoteState theirEndPoint) findSocket
  where
    -- Local endpoint registered at our address in the transport
    findLocal TransportClosed =
      unavailable "Transport closed"
    findLocal (TransportValid vst) =
      maybe (unavailable "Local endpoint not found") return
            (vst ^. localEndPointAt ourAddress)

    -- Remote endpoint that our local endpoint has a connection to
    findRemote LocalEndPointClosed =
      unavailable "Local endpoint closed"
    findRemote (LocalEndPointValid vst) =
      maybe (unavailable "Remote endpoint not found") return
            (vst ^. localConnectionTo theirAddress)

    -- Only the valid and closing states carry a socket
    findSocket (RemoteEndPointInit _ _) =
      unavailable "Remote endpoint not yet initialized"
    findSocket (RemoteEndPointValid vst) =
      return (remoteSocket vst)
    findSocket (RemoteEndPointClosing _ vst) =
      return (remoteSocket vst)
    findSocket RemoteEndPointClosed =
      unavailable "Remote endpoint closed"
    findSocket (RemoteEndPointInvalid err) =
      throwIO err
    findSocket (RemoteEndPointFailed err) =
      throwIO err

    unavailable :: String -> IO a
    unavailable = throwIO . userError

--------------------------------------------------------------------------------
-- Constants                                                                  --
--------------------------------------------------------------------------------

-- | We reserve a bunch of connection IDs for control messages
--
-- This is the first 'ConnectionId' that lies outside that reserved range.
firstNonReservedConnectionId :: ConnectionId
firstNonReservedConnectionId = 1024

--------------------------------------------------------------------------------
-- Accessor definitions                                                       --
--------------------------------------------------------------------------------

-- | Accessor for the @_localEndPoints@ field: the local endpoints of a valid
-- transport, keyed by their address.
localEndPoints :: Accessor ValidTransportState (Map EndPointAddress LocalEndPoint)
localEndPoints = accessor _localEndPoints update
  where
    update endPoints vst = vst { _localEndPoints = endPoints }

-- | Accessor for the @_nextEndPointId@ field of a valid transport state.
nextEndPointId :: Accessor ValidTransportState EndPointId
nextEndPointId = accessor _nextEndPointId update
  where
    update endPointId vst = vst { _nextEndPointId = endPointId }

-- | Accessor for the @_nextConnectionId@ field of a valid local endpoint
-- state.
nextConnectionId :: Accessor ValidLocalEndPointState ConnectionId
nextConnectionId = accessor _nextConnectionId update
  where
    update connId vst = vst { _nextConnectionId = connId }

-- | Accessor for the @_localConnections@ field: the remote endpoints known to
-- a valid local endpoint, keyed by their address.
localConnections :: Accessor ValidLocalEndPointState (Map EndPointAddress RemoteEndPoint)
localConnections = accessor _localConnections update
  where
    update remotes vst = vst { _localConnections = remotes }

-- | Accessor for the @_nextRemoteId@ field of a valid local endpoint state.
nextRemoteId :: Accessor ValidLocalEndPointState Int
nextRemoteId = accessor _nextRemoteId update
  where
    update remoteIx vst = vst { _nextRemoteId = remoteIx }

-- | Accessor for the @_remoteOutgoing@ counter of a valid remote endpoint
-- state.
remoteOutgoing :: Accessor ValidRemoteEndPointState Int
remoteOutgoing = accessor _remoteOutgoing update
  where
    update count vst = vst { _remoteOutgoing = count }

-- | Accessor for the @_remoteIncoming@ set of a valid remote endpoint state.
remoteIncoming :: Accessor ValidRemoteEndPointState IntSet
remoteIncoming = accessor _remoteIncoming update
  where
    update incoming vst = vst { _remoteIncoming = incoming }

-- | Accessor for the @_pendingCtrlRequests@ field: for each pending control
-- request, the 'MVar' that holds either an 'IOException' or the reply.
pendingCtrlRequests :: Accessor ValidRemoteEndPointState (IntMap (MVar (Either IOException [ByteString])))
pendingCtrlRequests = accessor _pendingCtrlRequests update
  where
    update pending vst = vst { _pendingCtrlRequests = pending }

-- | Accessor for the @_nextCtrlRequestId@ field of a valid remote endpoint
-- state.
nextCtrlRequestId :: Accessor ValidRemoteEndPointState ControlRequestId 
nextCtrlRequestId = accessor _nextCtrlRequestId update
  where
    update reqId vst = vst { _nextCtrlRequestId = reqId }

-- | Accessor for the local endpoint registered at the given address, if any
-- (composition of 'localEndPoints' with a map lookup).
localEndPointAt :: EndPointAddress -> Accessor ValidTransportState (Maybe LocalEndPoint)
localEndPointAt addr = localEndPoints >>> atAddr
  where
    atAddr = DAC.mapMaybe addr

-- | Accessor for the pending control request with the given ID, if any
-- (composition of 'pendingCtrlRequests' with an 'IntMap' lookup). The ID is
-- converted with 'fromIntegral' to serve as the map key.
pendingCtrlRequestsAt :: ControlRequestId -> Accessor ValidRemoteEndPointState (Maybe (MVar (Either IOException [ByteString])))
pendingCtrlRequestsAt ix = pendingCtrlRequests >>> atKey
  where
    atKey = DAC.intMapMaybe (fromIntegral ix)

-- | Accessor for the remote endpoint at the given address, if any
-- (composition of 'localConnections' with a map lookup).
localConnectionTo :: EndPointAddress 
                  -> Accessor ValidLocalEndPointState (Maybe RemoteEndPoint)
localConnectionTo addr = localConnections >>> atAddr
  where
    atAddr = DAC.mapMaybe addr

-------------------------------------------------------------------------------
-- Debugging                                                                 --
-------------------------------------------------------------------------------

-- | Report a RELY violation: log the given text followed by
-- @" RELY violation"@ via 'elog', then 'fail' with the same message.
relyViolation :: EndPointPair -> String -> IO a
relyViolation endPoints str = do
    elog endPoints message
    fail message
  where
    message = str ++ " RELY violation"

-- | Print a debugging message to stdout, prefixed with the local address,
-- the remote address and ID, and the ID of the calling thread:
--
-- > <local>/<remote>(<remote id>)/<thread id>: <message>
elog :: EndPointPair -> String -> IO ()
elog (ourEndPoint, theirEndPoint) msg = do
  tid <- myThreadId
  let local  = show (localAddress ourEndPoint)
      remote = show (remoteAddress theirEndPoint)
      remIx  = show (remoteId theirEndPoint)
      prefix = local ++ "/" ++ remote ++ "(" ++ remIx ++ ")" ++ "/" ++ show tid
  putStrLn (prefix ++ ": " ++ msg)