-- | TCP implementation of the transport layer. 
-- 
-- The TCP implementation guarantees that only a single TCP connection (socket)
-- will be used between endpoints, provided that the addresses specified are
-- canonical. If /A/ connects to /B/ and reports its address as
-- @192.168.0.1:8080@ and /B/ subsequently tries to connect to /A/ as
-- @client1.local:http-alt@ then the transport layer will not realize that the
-- TCP connection can be reused. 
--
-- Applications that use the TCP transport should use
-- 'Network.Socket.withSocketsDo' in their main function for Windows
-- compatibility (see "Network.Socket").
module Network.Transport.TCP 
  ( -- * Main API
    createTransport
  , TCPParameters(..)
  , defaultTCPParameters
    -- * Internals (exposed for unit tests) 
  , createTransportExposeInternals 
  , TransportInternals(..)
  , EndPointId
  , encodeEndPointAddress
  , decodeEndPointAddress
  , ControlHeader(..)
  , ConnectionRequestResponse(..)
  , firstNonReservedLightweightConnectionId
  , firstNonReservedHeavyweightConnectionId
  , socketToEndPoint 
  , LightweightConnectionId
    -- * Design notes
    -- $design
  ) where

import Prelude hiding 
  ( mapM_
#if ! MIN_VERSION_base(4,6,0)
  , catch
#endif
  )
 
import Network.Transport
import Network.Transport.TCP.Internal 
  ( forkServer
  , recvWithLength
  , recvInt32
  , tryCloseSocket
  )
import Network.Transport.Internal 
  ( encodeInt32
  , prependLength
  , mapIOException
  , tryIO
  , tryToEnum
  , void
  , timeoutMaybe
  , asyncWhenCancelled
  )

#ifdef USE_MOCK_NETWORK
import qualified Network.Transport.TCP.Mock.Socket as N
#else
import qualified Network.Socket as N 
#endif
  ( HostName
  , ServiceName
  , Socket
  , getAddrInfo
  , socket
  , addrFamily
  , addrAddress
  , SocketType(Stream)
  , defaultProtocol
  , setSocketOption
  , SocketOption(ReuseAddr) 
  , connect
  , sOMAXCONN
  , AddrInfo
  )

#ifdef USE_MOCK_NETWORK
import Network.Transport.TCP.Mock.Socket.ByteString (sendMany)
#else
import Network.Socket.ByteString (sendMany)
#endif

import Control.Concurrent (forkIO, ThreadId, killThread, myThreadId)
import Control.Concurrent.Chan (Chan, newChan, readChan, writeChan)
import Control.Concurrent.MVar 
  ( MVar
  , newMVar
  , modifyMVar
  , modifyMVar_
  , readMVar
  , putMVar
  , newEmptyMVar
  , withMVar
  )
import Control.Category ((>>>))
import Control.Applicative ((<$>))
import Control.Monad (when, unless, join)
import Control.Exception 
  ( IOException
  , SomeException
  , AsyncException
  , handle
  , throw
  , throwIO
  , try
  , bracketOnError
  , fromException
  , catch
  )
import Data.IORef (IORef, newIORef, writeIORef, readIORef)
import Data.ByteString (ByteString)
import qualified Data.ByteString as BS (concat)
import qualified Data.ByteString.Char8 as BSC (pack, unpack)
import Data.Bits (shiftL, (.|.))
import Data.Word (Word32)
import Data.Set (Set)
import qualified Data.Set as Set 
  ( empty
  , insert
  , elems
  , singleton
  , null
  , delete
  , member
  )
import Data.Map (Map)
import qualified Data.Map as Map (empty)
import Data.Accessor (Accessor, accessor, (^.), (^=), (^:)) 
import qualified Data.Accessor.Container as DAC (mapMaybe)
import Data.Foldable (forM_, mapM_)

-- $design 
--
-- [Goals]
--
-- The TCP transport maps multiple logical connections between /A/ and /B/ (in
-- either direction) to a single TCP connection:
--
-- > +-------+                          +-------+
-- > | A     |==========================| B     |
-- > |       |>~~~~~~~~~~~~~~~~~~~~~~~~~|~~~\   |
-- > |   Q   |>~~~~~~~~~~~~~~~~~~~~~~~~~|~~~Q   |
-- > |   \~~~|~~~~~~~~~~~~~~~~~~~~~~~~~<|       |
-- > |       |==========================|       |
-- > +-------+                          +-------+
-- 
-- Ignoring the complications detailed below, the TCP connection is set up
-- when the first lightweight connection is created (in either direction), and
-- torn down when the last lightweight connection (in either direction) is
-- closed.
--
-- [Connecting]
--
-- Let /A/, /B/ be two endpoints without any connections. When /A/ wants to
-- connect to /B/, it locally records that it is trying to connect to /B/ and
-- sends a request to /B/. As part of the request /A/ sends its own endpoint
-- address to /B/ (so that /B/ can reuse the connection in the other direction).
--
-- When /B/ receives the connection request it first checks whether it has
-- already initiated a connection request to /A/. If not, it will acknowledge the
-- connection request by sending 'ConnectionRequestAccepted' to /A/ and record
-- that it has a TCP connection to /A/.
--
-- The tricky case arises when /A/ sends a connection request to /B/ and /B/
-- finds that it had already sent a connection request to /A/. In this case /B/
-- will accept the connection request from /A/ if /A/'s endpoint address is
-- smaller (lexicographically) than /B/'s, and reject it otherwise. If it rejects
-- it, it sends a 'ConnectionRequestCrossed' message to /A/. (The
-- lexicographical ordering is an arbitrary but convenient way to break the
-- tie.)
--
-- When it receives a 'ConnectionRequestCrossed' message the /A/ thread that
-- initiated the request just needs to wait until the /A/ thread that is dealing
-- with /B/'s connection request completes.
--
-- [Disconnecting]
-- 
-- The TCP connection is created as soon as the first logical connection from
-- /A/ to /B/ (or /B/ to /A/) is established. At this point a thread (@#@) is
-- spawned that listens for incoming connections from /B/:
--
-- > +-------+                          +-------+
-- > | A     |==========================| B     |
-- > |       |>~~~~~~~~~~~~~~~~~~~~~~~~~|~~~\   |
-- > |       |                          |   Q   |
-- > |      #|                          |       |
-- > |       |==========================|       |
-- > +-------+                          +-------+
--
-- The question is when the TCP connection can be closed again.  Conceptually,
-- we want to do reference counting: when there are no logical connections left
-- between /A/ and /B/ we want to close the socket (possibly after some
-- timeout). 
--
-- However, /A/ and /B/ need to agree that the refcount has reached zero.  It
-- might happen that /B/ sends a connection request over the existing socket at
-- the same time that /A/ closes its logical connection to /B/ and closes the
-- socket. This will cause a failure in /B/ (which will have to retry) which is
-- not caused by a network failure, which is unfortunate. (Note that the
-- connection request from /B/ might succeed even if /A/ closes the socket.) 
--
-- Instead, when /A/ is ready to close the socket it sends a 'CloseSocket'
-- request to /B/ and records that its connection to /B/ is closing. If /A/
-- receives a new connection request from /B/ after having sent the
-- 'CloseSocket' request it simply forgets that it sent a 'CloseSocket' request
-- and increments the reference count of the connection again.
-- 
-- When /B/ receives a 'CloseSocket' message and it too is ready to close the
-- connection, it will respond with a reciprocal 'CloseSocket' request to /A/
-- and then actually close the socket. /A/ meanwhile will not send any more
-- requests to /B/ after having sent a 'CloseSocket' request, and will actually
-- close its end of the socket only when receiving the 'CloseSocket' message
-- from /B/. (Since /A/ recorded that its connection to /B/ is in closing state
-- after sending a 'CloseSocket' request to /B/, it knows not to reciprocate
-- /B/'s reciprocal 'CloseSocket' message.)
-- 
-- If there is a concurrent thread in /A/ waiting to connect to /B/ after /A/
-- has sent a 'CloseSocket' request then this thread will block until /A/ knows
-- whether to reuse the old socket (if /B/ sends a new connection request
-- instead of acknowledging the 'CloseSocket') or to set up a new socket. 

--------------------------------------------------------------------------------
-- Internal datatypes                                                         --
--------------------------------------------------------------------------------

-- We use underscores for fields that we might update (using accessors)
--
-- All data types follow the same structure:
-- 
-- * A top-level data type describing static properties (TCPTransport,
--   LocalEndPoint, RemoteEndPoint)
-- * The 'static' properties include an MVar containing a data structure for
--   the dynamic properties (TransportState, LocalEndPointState,
--   RemoteEndPointState). The state could be invalid/valid/closed/etc.
-- * For the case of "valid" we use a third data structure to give more details
--   about the state (ValidTransportState, ValidLocalEndPointState,
--   ValidRemoteEndPointState).

-- | Static description of a TCP transport. The only mutable part is the
-- 'TransportState' stored in 'transportState'.
data TCPTransport = TCPTransport 
  { transportHost   :: !N.HostName
    -- ^ Host handed to 'forkServer' and used when encoding endpoint addresses
  , transportPort   :: !N.ServiceName
    -- ^ Port handed to 'forkServer' and used when encoding endpoint addresses
  , transportState  :: !(MVar TransportState)
    -- ^ Mutable state (valid or closed)
  , transportParams :: !TCPParameters
    -- ^ Parameters the transport was created with
  }

-- | Mutable state of a transport. 'apiCloseTransport' moves it from
-- 'TransportValid' to 'TransportClosed'.
data TransportState = 
    TransportValid !ValidTransportState
  | TransportClosed

-- | Details of a transport that has not been closed.
data ValidTransportState = ValidTransportState 
  { -- Local endpoints of this transport, keyed by their address
    _localEndPoints :: !(Map EndPointAddress LocalEndPoint)
    -- Presumably the ID to hand out to the next new endpoint (it starts at 0);
    -- the allocating code is not in this part of the file -- TODO confirm
  , _nextEndPointId :: !EndPointId 
  }

-- | A local endpoint. The only mutable part is 'localState'.
data LocalEndPoint = LocalEndPoint
  { localAddress :: !EndPointAddress
    -- ^ Address under which this endpoint is reachable
  , localChannel :: !(Chan Event)
    -- ^ Events for this endpoint; 'receive' reads from this channel
  , localState   :: !(MVar LocalEndPointState)
    -- ^ Mutable state (valid or closed)
  }

-- | Mutable state of a local endpoint. 'apiCloseEndPoint' moves a
-- 'LocalEndPointValid' endpoint to 'LocalEndPointClosed'.
data LocalEndPointState = 
    LocalEndPointValid !ValidLocalEndPointState
  | LocalEndPointClosed

-- | Details of a local endpoint that has not been closed.
data ValidLocalEndPointState = ValidLocalEndPointState 
  { -- Next available ID for an outgoing lightweight self-connection
    -- (see also remoteNextConnOutId)
    _localNextConnOutId :: !LightweightConnectionId
    -- Next available ID for an incoming heavyweight connection
  , _nextConnInId :: !HeavyweightConnectionId
    -- Currently active heavyweight connections (remote endpoints), keyed by
    -- remote address. 'apiCloseEndPoint' closes every entry.
    -- NOTE(review): this was documented as "outgoing" only, but
    -- handleConnectionRequest also registers incoming connections through
    -- findRemoteEndPoint, which presumably stores them here -- confirm.
  , _localConnections :: !(Map EndPointAddress RemoteEndPoint)
  }

-- REMOTE ENDPOINTS 
--
-- Remote endpoints (basically, TCP connections) have the following lifecycle:
--
--   Init  ---+---> Invalid
--            |
--            +-------------------------------\
--            |                               |
--            |       /----------\            |
--            |       |          |            |
--            |       v          |            v
--            +---> Valid ---> Closing ---> Closed
--            |       |          |            |
--            |       |          |            v
--            \-------+----------+--------> Failed
--
-- Init: There are two places where we create new remote endpoints: in
--   createConnectionTo (in response to an API 'connect' call) and in
--   handleConnectionRequest (when a remote node tries to connect to us).
--   'Init' carries an MVar () 'resolved' which concurrent threads can use to
--   wait for the remote endpoint to finish initialization. We record who
--   requested the connection (the local endpoint or the remote endpoint).
--
-- Invalid: We put the remote endpoint in invalid state only during
--   createConnectionTo when we fail to connect.
--
-- Valid: This is the "normal" state for a working remote endpoint.
-- 
-- Closing: When we detect that a remote endpoint is no longer used, we send a
--   CloseSocket request across the connection and put the remote endpoint in
--   closing state. As with Init, 'Closing' carries an MVar () 'resolved' which
--   concurrent threads can use to wait for the remote endpoint to either be
--   closed fully (if the communication partner responds with another
--   CloseSocket) or be put back in 'Valid' state if the remote endpoint denies
--   the request.
--
--   We also put the endpoint in Closed state, directly from Init, if our
--   outbound connection request crossed an inbound connection request and we
--   decide to keep the inbound (i.e., the remote endpoint sent us a
--   ConnectionRequestCrossed message).
--
-- Closed: The endpoint is put in Closed state after a successful garbage
--   collection.
--
-- Failed: If the connection to the remote endpoint is lost, or the local
--   endpoint (or the whole transport) is closed manually, the remote endpoint
--   is put in Failed state, and we record the reason.
--
-- Invariants for dealing with remote endpoints:
--
-- INV-SEND: Whenever we send data the remote endpoint must be locked (to avoid
--   interleaving bits of payload).
--
-- INV-CLOSE: Local endpoints should never point to remote endpoints in closed
--   state.  Whenever we put an endpoint in Closed state we remove that
--   endpoint from localConnections first, so that if a concurrent thread reads
--   the MVar, finds RemoteEndPointClosed, and then looks up the endpoint in
--   localConnections it is guaranteed to either find a different remote
--   endpoint, or else none at all (if we don't insist on this order, some
--   threads might start spinning).
-- 
-- INV-RESOLVE: We should only signal on 'resolved' while the remote endpoint is
--   locked, and the remote endpoint must be in Valid or Closed state once
--   unlocked. This guarantees that there will not be two threads attempting to
--   both signal on 'resolved'. 
--
-- INV-LOST: If a send or recv fails, or a socket is closed unexpectedly, we
--   first put the remote endpoint in Closed state, and then send a
--   EventConnectionLost event. This guarantees that we only send this event
--   once.  
--
-- INV-CLOSING: An endpoint in closing state is for all intents and purposes
--   closed; that is, we shouldn't do any 'send's on it (although 'recv' is
--   acceptable, of course -- as we are waiting for the remote endpoint to
--   confirm or deny the request).
--
-- INV-LOCK-ORDER: Remote endpoints must be locked before their local endpoints.
--   In other words: it is okay to call modifyMVar on a local endpoint inside a
--   modifyMVar on a remote endpoint, but not the other way around. In
--   particular, it is okay to call removeRemoteEndPoint inside
--   modifyRemoteState.

-- | A remote endpoint, i.e. one heavyweight (TCP) connection between a local
-- endpoint and a peer. See the lifecycle and invariants described above.
data RemoteEndPoint = RemoteEndPoint 
  { remoteAddress   :: !EndPointAddress
    -- ^ Address of the peer endpoint
  , remoteState     :: !(MVar RemoteState)
    -- ^ Lifecycle state; also serves as the lock for this remote endpoint
  , remoteId        :: !HeavyweightConnectionId
    -- ^ Heavyweight connection ID (presumably taken from '_nextConnInId' --
    -- TODO confirm)
  , remoteScheduled :: !(Chan (IO ()))
    -- ^ Presumably the queue used by 'schedule' / 'runScheduledAction'
    -- (their definitions are not in this part of the file -- TODO confirm)
  }

-- | Which side initiated a heavyweight connection. 'handleConnectionRequest'
-- uses 'RequestedByThem' for inbound requests; 'RequestedByUs' is presumably
-- used for outbound requests from createConnectionTo.
data RequestedBy = RequestedByUs | RequestedByThem 
  deriving (Eq, Show)

-- | Lifecycle state of a remote endpoint (see the lifecycle diagram above).
data RemoteState =
    -- | Invalid remote endpoint (for example, invalid address)
    RemoteEndPointInvalid !(TransportError ConnectErrorCode)
    -- | The remote endpoint is being initialized. The first MVar is
    -- 'resolved', on which concurrent threads can wait for initialization to
    -- finish ('apiCloseEndPoint' also signals it when force-closing). The role
    -- of the second MVar is not visible in this part of the file (TODO
    -- confirm). 'RequestedBy' records which side initiated the connection.
  | RemoteEndPointInit !(MVar ()) !(MVar ()) !RequestedBy 
    -- | "Normal" working endpoint
  | RemoteEndPointValid !ValidRemoteEndPointState 
    -- | The remote endpoint is being closed (garbage collected). The MVar is
    -- 'resolved', on which threads wait until the endpoint is either closed
    -- or put back in Valid state.
  | RemoteEndPointClosing !(MVar ()) !ValidRemoteEndPointState
    -- | The remote endpoint has been closed (garbage collected)
  | RemoteEndPointClosed
    -- | The remote endpoint has failed, or has been forcefully shutdown
    -- using a closeTransport or closeEndPoint API call
  | RemoteEndPointFailed !IOException

-- | Details of a remote endpoint in Valid (or Closing) state.
--
-- TODO: we might want to replace Set (here and elsewhere) by faster 
-- containers
--
-- TODO: we could get rid of 'remoteIncoming' (and maintain less state) if
-- we introduce a new event 'AllConnectionsClosed' 
data ValidRemoteEndPointState = ValidRemoteEndPointState 
  { _remoteOutgoing      :: !Int
    -- ^ Number of outgoing lightweight connections still open
    -- ('apiClose' decrements it)
  , _remoteIncoming      :: !(Set LightweightConnectionId)
    -- ^ Presumably the IDs of the incoming lightweight connections that are
    -- still open -- TODO confirm
  , _remoteMaxIncoming   :: !LightweightConnectionId
    -- ^ Presumably the highest incoming lightweight connection ID seen; it is
    -- sent along with 'CloseSocket'
  , _remoteNextConnOutId :: !LightweightConnectionId
    -- ^ Next ID for an outgoing lightweight connection; starts at
    -- 'firstNonReservedLightweightConnectionId'
  ,  remoteSocket        :: !N.Socket
    -- ^ The underlying TCP socket
  ,  remoteSendLock      :: !(MVar ())
    -- ^ Presumably serializes writes to the socket (see INV-SEND) -- TODO
    -- confirm
  }

-- | Local identifier for an endpoint within this transport. Together with
-- the transport's host and port it forms the endpoint's address (see
-- 'encodeEndPointAddress').
type EndPointId = Word32

-- | Pair of a local and a remote endpoint (for conciseness in signatures)
type EndPointPair = (LocalEndPoint, RemoteEndPoint)

-- | Lightweight connection ID (sender allocated)
--
-- A ConnectionId is the concatenation of a 'HeavyweightConnectionId' and a
-- 'LightweightConnectionId'. IDs below
-- 'firstNonReservedLightweightConnectionId' are reserved, and allocation of
-- outgoing IDs starts there.
type LightweightConnectionId = Word32

-- | Heavyweight connection ID (recipient allocated)
--
-- A ConnectionId is the concatenation of a 'HeavyweightConnectionId' and a
-- 'LightweightConnectionId'. See also 'firstNonReservedHeavyweightConnectionId'.
type HeavyweightConnectionId = Word32

-- | Control headers 
--
-- These are put on the wire with 'encodeInt32' (e.g. @encodeInt32 CloseSocket@),
-- which presumably sends the constructor's 'Enum' index. If so, reordering or
-- inserting constructors changes the wire format (TODO confirm).
data ControlHeader = 
    -- | Tell the remote endpoint that we created a new connection
    CreatedNewConnection 
    -- | Tell the remote endpoint we will no longer be using a connection
  | CloseConnection     
    -- | Request to close the connection (see module description)
  | CloseSocket         
  deriving (Enum, Bounded, Show)

-- | Response sent by /B/ to /A/ when /A/ tries to connect
--
-- Like 'ControlHeader', this is sent with 'encodeInt32', so constructor order
-- presumably determines the wire encoding (TODO confirm).
data ConnectionRequestResponse =
    -- | /B/ accepts the connection
    ConnectionRequestAccepted        
    -- | /A/ requested an invalid endpoint
  | ConnectionRequestInvalid 
    -- | /A/'s request crossed with a request from /B/ (see module description)
  | ConnectionRequestCrossed         
  deriving (Enum, Bounded, Show)

-- | Parameters for setting up the TCP transport.
--
-- Start from 'defaultTCPParameters' and override individual fields.
data TCPParameters = TCPParameters {
    -- | Backlog for 'listen' (passed to the server started by
    -- 'createTransport').
    -- Defaults to SOMAXCONN.
    tcpBacklog :: Int
    -- | Should we set SO_REUSEADDR on the server socket? 
    -- Defaults to True.
  , tcpReuseServerAddr :: Bool
    -- | Should we set SO_REUSEADDR on client sockets?
    -- Defaults to True.
  , tcpReuseClientAddr :: Bool
  }

-- | Internal functionality we expose for unit testing (returned by
-- 'createTransportExposeInternals').
data TransportInternals = TransportInternals 
  { -- | The ID of the thread that listens for new incoming connections
    -- (closeTransport kills this thread)
    transportThread :: ThreadId
    -- | Find the socket between a local and a remote endpoint
    -- (presumably local address first, then remote -- TODO confirm)
  , socketBetween :: EndPointAddress 
                  -> EndPointAddress 
                  -> IO N.Socket 
  }

--------------------------------------------------------------------------------
-- Top-level functionality                                                    --
--------------------------------------------------------------------------------

-- | Create a TCP transport that listens on the given host and port.
--
-- Returns 'Left' if the server could not be started. This is
-- 'createTransportExposeInternals' with the test-only internals dropped.
createTransport :: N.HostName 
                -> N.ServiceName 
                -> TCPParameters
                -> IO (Either IOException Transport)
createTransport host port params =
  fmap (fmap fst) (createTransportExposeInternals host port params)

-- | Like 'createTransport', but also return the 'TransportInternals'.
-- You should probably not use this function (used for unit testing only).
--
-- The server thread is started inside 'bracketOnError', so it is killed if
-- building the result fails. Any 'IOException' raised while starting the
-- server is returned as 'Left'.
createTransportExposeInternals 
  :: N.HostName 
  -> N.ServiceName 
  -> TCPParameters
  -> IO (Either IOException (Transport, TransportInternals)) 
createTransportExposeInternals host port params = do
    st <- newMVar $ TransportValid ValidTransportState
      { _localEndPoints = Map.empty
      , _nextEndPointId = 0
      }
    let tcp = TCPTransport { transportHost   = host
                           , transportPort   = port
                           , transportState  = st
                           , transportParams = params
                           }
        startServer = forkServer host
                                 port
                                 (tcpBacklog params)
                                 (tcpReuseServerAddr params)
                                 (onTerminate tcp)
                                 (handleConnectionRequest tcp)
    tryIO $ bracketOnError startServer killThread (return . build tcp)
  where
    -- Assemble the public transport record and the test-only internals
    build :: TCPTransport -> ThreadId -> (Transport, TransportInternals)
    build tcp tid =
      ( Transport
          { newEndPoint    = apiNewEndPoint tcp
          , closeTransport = apiCloseTransport tcp (Just tid) closedEvents
          }
      , TransportInternals
          { transportThread = tid
          , socketBetween   = internalSocketBetween tcp
          }
      )

    -- Events delivered to every endpoint when the transport is closed via
    -- the API
    closedEvents :: [Event]
    closedEvents = [ EndPointClosed
                   , throw $ userError "Transport closed"
                   ]

    -- Invoked when the server thread terminates: close the transport and
    -- report the failure to every endpoint
    onTerminate :: TCPTransport -> SomeException -> IO ()
    onTerminate tcp ex =
      apiCloseTransport tcp Nothing
        [ ErrorEvent (TransportError EventTransportFailed (show ex))
        , throw $ userError "Transport closed" 
        ]

-- | Default TCP parameters: a listen backlog of 'N.sOMAXCONN', with
-- @SO_REUSEADDR@ enabled on both the server socket and client sockets.
defaultTCPParameters :: TCPParameters
defaultTCPParameters = TCPParameters
  { tcpReuseServerAddr = True
  , tcpReuseClientAddr = True
  , tcpBacklog         = N.sOMAXCONN
  }

--------------------------------------------------------------------------------
-- API functions                                                              --
--------------------------------------------------------------------------------

-- | Close the transport.
--
-- Marks the transport as closed. If it was still valid, every local endpoint
-- it knows about is closed, with @evs@ used to report the closure. If a
-- transport thread is given, it is killed afterwards.
apiCloseTransport :: TCPTransport -> Maybe ThreadId -> [Event] -> IO ()
apiCloseTransport transport mTransportThread evs =
  asyncWhenCancelled return $ do
    mValid <- modifyMVar (transportState transport) markClosed
    case mValid of
      Nothing  -> return ()
      Just vst -> mapM_ (apiCloseEndPoint transport evs) (vst ^. localEndPoints)
    -- Killing the transport thread invokes the termination handler, which in
    -- turn calls apiCloseTransport again. By then the transport is already
    -- closed and no transport thread is passed, so that call returns
    -- immediately.
    maybe (return ()) killThread mTransportThread
  where
    -- Switch to the closed state, handing back the previous valid state (if any)
    markClosed st = return $ case st of
      TransportValid vst -> (TransportClosed, Just vst)
      TransportClosed    -> (TransportClosed, Nothing)
     
-- | Create a new endpoint on this transport.
--
-- Multicast is not supported by the TCP transport: both multicast operations
-- of the returned 'EndPoint' always return 'Left'.
apiNewEndPoint :: TCPTransport 
               -> IO (Either (TransportError NewEndPointErrorCode) EndPoint)
apiNewEndPoint transport =
  try . asyncWhenCancelled closeEndPoint $
    mkEndPoint <$> createLocalEndPoint transport
  where
    -- Wrap a freshly created local endpoint in the public API record
    mkEndPoint :: LocalEndPoint -> EndPoint
    mkEndPoint ourEndPoint = EndPoint
      { receive       = readChan (localChannel ourEndPoint)
      , address       = localAddress ourEndPoint
      , connect       = apiConnect (transportParams transport) ourEndPoint
      , closeEndPoint = apiCloseEndPoint transport closedEvents ourEndPoint
      , newMulticastGroup     = return (Left multicastUnsupported)
      , resolveMulticastGroup = \_ -> return (Left resolveUnsupported)
      }

    -- Events delivered when the endpoint is closed via the API
    closedEvents = [ EndPointClosed
                   , throw $ userError "Endpoint closed"
                   ]

    multicastUnsupported =
      TransportError NewMulticastGroupUnsupported "Multicast not supported"
    resolveUnsupported =
      TransportError ResolveMulticastGroupUnsupported "Multicast not supported"

-- | Connect to an endpoint.
--
-- Connecting to our own address creates a self-connection. Otherwise we set
-- up (or reuse) a heavyweight connection to the remote endpoint and allocate
-- a new lightweight connection on it.
apiConnect :: TCPParameters    -- ^ Parameters
           -> LocalEndPoint    -- ^ Local end point
           -> EndPointAddress  -- ^ Remote address
           -> Reliability      -- ^ Reliability (ignored)
           -> ConnectHints     -- ^ Hints
           -> IO (Either (TransportError ConnectErrorCode) Connection)
apiConnect params ourEndPoint theirAddress _reliability hints =
    try . asyncWhenCancelled close $ connectTo theirAddress
  where
    connectTo addr
      | localAddress ourEndPoint == addr = connectToSelf ourEndPoint
      | otherwise = do
          resetIfBroken ourEndPoint addr
          (theirEndPoint, connId) <-
            createConnectionTo params ourEndPoint addr hints
          -- An IORef suffices for connAlive: every read and write of it
          -- happens while the remoteState MVar is held, so a second lock
          -- would only add overhead.
          connAlive <- newIORef True
          let pair = (ourEndPoint, theirEndPoint)
          return Connection
            { send  = apiSend  pair connId connAlive
            , close = apiClose pair connId connAlive
            }

-- | Close a lightweight connection.
--
-- If the remote endpoint is Valid and the connection is still alive, the
-- connection is marked dead, a 'CloseConnection' message is scheduled and the
-- outgoing-connection count is decremented. In every other case the remote
-- state is left untouched. Either way, 'closeIfUnused' runs afterwards.
-- 'IOException's are ignored.
apiClose :: EndPointPair -> LightweightConnectionId -> IORef Bool -> IO ()
apiClose (ourEndPoint, theirEndPoint) connId connAlive =
  void . tryIO . asyncWhenCancelled return $ do
    mAct <- modifyMVar (remoteState theirEndPoint) closeConn
    mapM_ (runScheduledAction (ourEndPoint, theirEndPoint)) mAct
    closeIfUnused (ourEndPoint, theirEndPoint)
  where
    -- Runs with the remote endpoint locked; returns the new state and the
    -- scheduled send (if any)
    closeConn st@(RemoteEndPointValid vst) = do
      alive <- readIORef connAlive
      if not alive
        then return (st, Nothing)
        else do
          writeIORef connAlive False
          act <- schedule theirEndPoint $
            sendOn vst [encodeInt32 CloseConnection, encodeInt32 connId]
          let vst' = (remoteOutgoing ^: subtract 1) vst
          return (RemoteEndPointValid vst', Just act)
    closeConn st = return (st, Nothing)


-- | Send data across a lightweight connection.
--
-- The message (connection ID followed by the length-prefixed payload) is
-- scheduled while the remote endpoint is locked and sent after the lock is
-- released. A closed connection yields 'SendClosed'. A connection alive on a
-- failed remote endpoint yields 'SendFailed', and so does an 'IOException'.
-- A live connection on a remote endpoint in any other non-Valid state is a
-- rely violation.
apiSend :: EndPointPair             -- ^ Local and remote endpoint 
        -> LightweightConnectionId  -- ^ Connection ID
        -> IORef Bool               -- ^ Is the connection still alive?
        -> [ByteString]             -- ^ Payload
        -> IO (Either (TransportError SendErrorCode) ())
apiSend (ourEndPoint, theirEndPoint) connId connAlive payload =
    -- We don't need the overhead of asyncWhenCancelled here
    try . mapIOException sendFailed $ do
      act <- withMVar (remoteState theirEndPoint) scheduleSend
      runScheduledAction (ourEndPoint, theirEndPoint) act
  where
    -- Decide, with the remote endpoint locked, what sending means in the
    -- current state
    scheduleSend st = case st of
      RemoteEndPointInvalid _ ->
        violation
      RemoteEndPointInit _ _ _ ->
        violation
      RemoteEndPointValid vst ->
        ifAlive . schedule theirEndPoint $
          sendOn vst (encodeInt32 connId : prependLength payload)
      RemoteEndPointClosing _ _ ->
        ifAlive violation
      RemoteEndPointClosed ->
        ifAlive violation
      RemoteEndPointFailed err ->
        ifAlive . throwIO $ TransportError SendFailed (show err)

    -- Run the given action if the connection is still alive; otherwise
    -- report that the connection has been closed
    ifAlive onAlive = do
      alive <- readIORef connAlive
      if alive
        then onAlive
        else throwIO $ TransportError SendClosed "Connection closed"

    violation = relyViolation (ourEndPoint, theirEndPoint) "apiSend"

    sendFailed = TransportError SendFailed . show

-- | Force-close the endpoint.
--
-- The endpoint is first removed from the transport state and then marked
-- closed. If it was still valid, every remote endpoint in its connection map
-- is shut down and the events @evs@ are written to the endpoint's channel.
-- Closing an endpoint that is already closed only removes it from the
-- transport.
apiCloseEndPoint :: TCPTransport    -- ^ Transport 
                 -> [Event]         -- ^ Events used to report closure 
                 -> LocalEndPoint   -- ^ Local endpoint
                 -> IO ()
apiCloseEndPoint transport evs ourEndPoint =
  asyncWhenCancelled return $ do
    -- Remove the reference from the transport state
    removeLocalEndPoint transport ourEndPoint
    -- Close the local endpoint 
    mOurState <- modifyMVar (localState ourEndPoint) $ \st ->
      case st of
        LocalEndPointValid vst -> 
          return (LocalEndPointClosed, Just vst)
        LocalEndPointClosed ->
          return (LocalEndPointClosed, Nothing)
    forM_ mOurState $ \vst -> do
      forM_ (vst ^. localConnections) tryCloseRemoteSocket
      forM_ evs $ writeChan (localChannel ourEndPoint) 
  where
    -- Put a remote endpoint in Failed state and close its socket on a
    -- best-effort basis.
    --
    -- For a Valid remote endpoint we first try to close the connection
    -- nicely by sending 'CloseSocket' (send errors are ignored). In Init and
    -- Closing state we signal 'resolved' so that waiting threads wake up.
    -- Remote endpoints that are Invalid, Closed or Failed are left as they
    -- are.
    tryCloseRemoteSocket :: RemoteEndPoint -> IO () 
    tryCloseRemoteSocket theirEndPoint = do
      let closed = RemoteEndPointFailed . userError $ "apiCloseEndPoint"
      mAct <- modifyMVar (remoteState theirEndPoint) $ \st ->
        case st of
          RemoteEndPointInvalid _ -> 
            return (st, Nothing)
          RemoteEndPointInit resolved _ _ -> do
            putMVar resolved ()
            return (closed, Nothing)
          RemoteEndPointValid vst -> do 
            act <- schedule theirEndPoint $ do
              void . tryIO $ sendOn vst [ encodeInt32 CloseSocket
                                        , encodeInt32 (vst ^. remoteMaxIncoming)
                                        ]
              tryCloseSocket (remoteSocket vst)
            return (closed, Just act)
          RemoteEndPointClosing resolved vst -> do 
            putMVar resolved ()
            act <- schedule theirEndPoint $ tryCloseSocket (remoteSocket vst)
            return (closed, Just act)
          -- Already closed or failed: nothing left to shut down
          RemoteEndPointClosed ->
            return (st, Nothing)
          RemoteEndPointFailed _ ->
            return (st, Nothing)
      forM_ mAct $ runScheduledAction (ourEndPoint, theirEndPoint)
            

--------------------------------------------------------------------------------
-- Incoming requests                                                          --
--------------------------------------------------------------------------------

-- | Handle a connection request (that is, a remote endpoint that is trying to
-- establish a TCP connection with us)
--
-- The request consists of the ID of the local endpoint the peer wants
-- ('recvInt32'), followed by the peer's own endpoint address
-- ('recvWithLength'). We reply with one of:
--
-- * 'ConnectionRequestInvalid' if we have no local endpoint with the
--   requested address; the socket is then closed.
-- * 'ConnectionRequestCrossed' if 'findRemoteEndPoint' reports that the
--   remote endpoint is not new; the socket is then closed.
-- * 'ConnectionRequestAccepted' otherwise; the socket is then recorded in a
--   Valid remote endpoint and 'handleIncomingMessages' takes over.
--
-- 'handleConnectionRequest' runs in the context of the transport thread, which
-- can be killed asynchronously by 'closeTransport'. We fork a separate thread
-- as soon as we have located the local endpoint that the remote endpoint is
-- interested in. We cannot fork any sooner because then we have no way of
-- storing the thread ID and hence no way of killing the thread when we take
-- the transport down. We must be careful to close the socket when a (possibly
-- asynchronous, ThreadKilled) exception occurs. (If an exception escapes from
-- handleConnectionRequest the transport will be shut down.)
handleConnectionRequest :: TCPTransport -> N.Socket -> IO () 
handleConnectionRequest transport sock = handle handleException $ do 
    -- Read the request: which of our endpoints the peer wants, and its address
    ourEndPointId <- recvInt32 sock
    theirAddress  <- EndPointAddress . BS.concat <$> recvWithLength sock 
    let ourAddress = encodeEndPointAddress (transportHost transport) 
                                           (transportPort transport)
                                           ourEndPointId
    -- Resolve the requested local endpoint, or reject the request. The
    -- thrown userErrors are synchronous, so handleException closes the
    -- socket and does not rethrow them.
    ourEndPoint <- withMVar (transportState transport) $ \st -> case st of
      TransportValid vst ->
        case vst ^. localEndPointAt ourAddress of
          Nothing -> do
            sendMany sock [encodeInt32 ConnectionRequestInvalid]
            throwIO $ userError "handleConnectionRequest: Invalid endpoint"
          Just ourEndPoint ->
            return ourEndPoint
      TransportClosed -> 
        throwIO $ userError "Transport closed"
    -- Hand the rest of the work to a separate thread (see above for why we
    -- only fork at this point)
    void . forkIO $ go ourEndPoint theirAddress
  where
    -- Finish the handshake and, if the request is accepted, process the
    -- peer's incoming messages
    go :: LocalEndPoint -> EndPointAddress -> IO ()
    go ourEndPoint theirAddress = do 
      -- This runs in a thread that will never be killed
      mEndPoint <- handle ((>> return Nothing) . handleException) $ do 
        resetIfBroken ourEndPoint theirAddress
        (theirEndPoint, isNew) <- 
          findRemoteEndPoint ourEndPoint theirAddress RequestedByThem
        
        if not isNew 
          then do
            -- Crossed request: tell the peer (ignoring send errors) and drop
            -- this socket
            tryIO $ sendMany sock [encodeInt32 ConnectionRequestCrossed]
            tryCloseSocket sock
            return Nothing 
          else do
            sendLock <- newMVar ()
            let vst = ValidRemoteEndPointState 
                        {  remoteSocket        = sock
                        ,  remoteSendLock      = sendLock
                        , _remoteOutgoing      = 0
                        , _remoteIncoming      = Set.empty
                        , _remoteMaxIncoming   = 0
                        , _remoteNextConnOutId = firstNonReservedLightweightConnectionId 
                        }
            sendMany sock [encodeInt32 ConnectionRequestAccepted]
            -- NOTE(review): if the sendMany above throws, the handler closes
            -- the socket, but theirEndPoint appears to stay in Init state here
            -- (resolveInit is never reached). Confirm that threads waiting on
            -- its 'resolved' MVar are released elsewhere.
            resolveInit (ourEndPoint, theirEndPoint) (RemoteEndPointValid vst)
            return (Just theirEndPoint)
      -- If we left the scope of the exception handler with a return value of
      -- Nothing then the socket is already closed; otherwise, the socket has
      -- been recorded as part of the remote endpoint. Either way, we no longer
      -- have to worry about closing the socket on receiving an asynchronous
      -- exception from this point forward. 
      forM_ mEndPoint $ handleIncomingMessages . (,) ourEndPoint

    -- Close the socket, then rethrow the exception only if it is an
    -- asynchronous one (such as ThreadKilled); other exceptions are swallowed
    handleException :: SomeException -> IO ()
    handleException ex = do 
      tryCloseSocket sock 
      rethrowIfAsync (fromException ex)

    -- Rethrow the asynchronous exception, if there is one
    rethrowIfAsync :: Maybe AsyncException -> IO () 
    rethrowIfAsync = mapM_ throwIO 

-- | Handle requests from a remote endpoint.
-- 
-- Returns only if the remote party closes the socket or if an error occurs.
-- This runs in a thread that will never be killed.
--
-- The socket is taken from the remote endpoint's state (Valid or Closing);
-- if the remote endpoint is already Closed or Failed there is nothing to do.
--
-- Every frame starts with an Int32. Values at or above
-- 'firstNonReservedLightweightConnectionId' are lightweight connection IDs
-- and are followed by a length-prefixed payload (read by 'recvWithLength').
-- Smaller values are parsed as a 'ControlHeader', followed by one more Int32:
--
--   * 'CreatedNewConnection': the ID of the new incoming connection;
--   * 'CloseConnection': the ID of the connection being closed;
--   * 'CloseSocket': the last connection ID the remote endpoint received
--     from us.
--
-- Any 'IOException' raised by the loop (including those created with
-- 'userError') ends it and is handled by 'prematureExit'.
handleIncomingMessages :: EndPointPair -> IO () 
handleIncomingMessages (ourEndPoint, theirEndPoint) = do
    -- Fetch the socket to read from. A Closing endpoint still has a live
    -- socket (the remote side may still send us messages while our
    -- CloseSocket request is in flight); Closed/Failed have nothing to read.
    mSock <- withMVar theirState $ \st ->
      case st of
        RemoteEndPointInvalid _ ->
          relyViolation (ourEndPoint, theirEndPoint) 
            "handleIncomingMessages (invalid)" 
        RemoteEndPointInit _ _ _ ->
          relyViolation (ourEndPoint, theirEndPoint) 
            "handleIncomingMessages (init)"
        RemoteEndPointValid ep ->
          return . Just $ remoteSocket ep
        RemoteEndPointClosing _ ep ->
          return . Just $ remoteSocket ep
        RemoteEndPointClosed ->
          return Nothing
        RemoteEndPointFailed _ -> 
          return Nothing
    
    forM_ mSock $ \sock -> 
      tryIO (go sock) >>= either (prematureExit sock) return
  where
    -- Dispatch 
    --
    -- If a recv throws an exception this will be caught top-level and
    -- 'prematureExit' will be invoked. The same will happen if the remote
    -- endpoint is put into a Closed (or Closing) state by a concurrent thread
    -- (because a 'send' failed) -- the individual handlers below will throw a
    -- user exception which is then caught and handled the same way as an
    -- exception thrown by 'recv'.
    go :: N.Socket -> IO ()
    go sock = do
      lcid <- recvInt32 sock :: IO LightweightConnectionId
      if lcid >= firstNonReservedLightweightConnectionId 
        then do
          -- Data message on an existing lightweight connection
          readMessage sock lcid 
          go sock
        else 
          -- Control message; the header value itself is the reserved ID
          case tryToEnum (fromIntegral lcid) of
            Just CreatedNewConnection -> do
              recvInt32 sock >>= createdNewConnection 
              go sock 
            Just CloseConnection -> do
              recvInt32 sock >>= closeConnection 
              go sock
            Just CloseSocket -> do 
              -- Stop reading only if the socket was actually closed
              didClose <- recvInt32 sock >>= closeSocket sock 
              unless didClose $ go sock
            Nothing ->
              throwIO $ userError "Invalid control request"
        
    -- Create a new connection
    --
    -- Records @lcid@ as an incoming connection, sets 'remoteMaxIncoming' to
    -- it, and reports 'ConnectionOpened' on our channel.
    -- NOTE(review): remoteMaxIncoming is overwritten rather than maxed, which
    -- presumably relies on the remote side allocating IDs in increasing order.
    createdNewConnection :: LightweightConnectionId -> IO () 
    createdNewConnection lcid = do
      modifyMVar_ theirState $ \st -> do
        vst <- case st of
          RemoteEndPointInvalid _ ->
            relyViolation (ourEndPoint, theirEndPoint) 
              "handleIncomingMessages:createNewConnection (invalid)"
          RemoteEndPointInit _ _ _ ->
            relyViolation (ourEndPoint, theirEndPoint) 
              "handleIncomingMessages:createNewConnection (init)"
          RemoteEndPointValid vst ->
            return ( (remoteIncoming ^: Set.insert lcid)
                   $ (remoteMaxIncoming ^= lcid) 
                   vst
                   )
          RemoteEndPointClosing resolved vst -> do
            -- If the endpoint is in closing state that means we sent a
            -- CloseSocket request to the remote endpoint. If the remote
            -- endpoint replies that it created a new connection, it either
            -- ignored our request or it sent the request before it got ours.
            -- Either way, at this point we simply restore the endpoint to
            -- RemoteEndPointValid (waking up anyone waiting on 'resolved')
            putMVar resolved ()
            return ( (remoteIncoming ^= Set.singleton lcid)
                   . (remoteMaxIncoming ^= lcid)
                   $ vst
                   )
          RemoteEndPointFailed err -> 
            throwIO err
          RemoteEndPointClosed ->
            relyViolation (ourEndPoint, theirEndPoint) 
              "createNewConnection (closed)"
        return (RemoteEndPointValid vst)
      writeChan ourChannel (ConnectionOpened (connId lcid) ReliableOrdered theirAddr) 

    -- Close a connection 
    -- It is important that we verify that the connection is in fact open,
    -- because otherwise we should not decrement the reference count
    closeConnection :: LightweightConnectionId -> IO () 
    closeConnection lcid = do
      modifyMVar_ theirState $ \st -> case st of
        RemoteEndPointInvalid _ ->
          relyViolation (ourEndPoint, theirEndPoint) "closeConnection (invalid)"
        RemoteEndPointInit _ _ _ ->
          relyViolation (ourEndPoint, theirEndPoint) "closeConnection (init)"
        RemoteEndPointValid vst -> do 
          unless (Set.member lcid (vst ^. remoteIncoming)) $ 
            throwIO $ userError "Invalid CloseConnection"
          return ( RemoteEndPointValid 
                 . (remoteIncoming ^: Set.delete lcid) 
                 $ vst
                 )
        RemoteEndPointClosing _ _ ->
          -- If the remote endpoint is in Closing state then, as far as we
          -- are concerned, there are no incoming connections. This means
          -- that a CloseConnection request at this point is invalid.
          throwIO $ userError "Invalid CloseConnection request" 
        RemoteEndPointFailed err ->
          throwIO err
        RemoteEndPointClosed ->
          relyViolation (ourEndPoint, theirEndPoint) "closeConnection (closed)"
      writeChan ourChannel (ConnectionClosed (connId lcid))

    -- Close the socket (if we don't have any outgoing connections)
    --
    -- Returns True if the socket was closed (and the remote endpoint moved to
    -- Closed), in which case the dispatch loop stops; False if the request
    -- was ignored and we should keep reading.
    closeSocket :: N.Socket -> LightweightConnectionId -> IO Bool 
    closeSocket sock lastReceivedId = do
      mAct <- modifyMVar theirState $ \st -> do
        case st of
          RemoteEndPointInvalid _ ->
            relyViolation (ourEndPoint, theirEndPoint)
              "handleIncomingMessages:closeSocket (invalid)"
          RemoteEndPointInit _ _ _ ->
            relyViolation (ourEndPoint, theirEndPoint) 
              "handleIncomingMessages:closeSocket (init)"
          RemoteEndPointValid vst -> do
            -- We regard a CloseSocket message as an (optimized) way for the
            -- remote endpoint to indicate that all its connections to us are
            -- now properly closed
            forM_ (Set.elems $ vst ^. remoteIncoming) $ 
              writeChan ourChannel . ConnectionClosed . connId
            let vst' = remoteIncoming ^= Set.empty $ vst 
            -- If we still have outgoing connections then we ignore the
            -- CloseSocket request (we sent a CreatedNewConnection message to
            -- the remote endpoint, but it did not receive it before sending
            -- the CloseSocket request). Similarly, if lastReceivedId <
            -- lastSentId then we sent a CreatedNewConnection *AND* a
            -- CloseConnection message to the remote endpoint, *both of which*
            -- it did not yet receive before sending the CloseSocket request. 
            if vst' ^. remoteOutgoing > 0 || lastReceivedId < lastSentId vst 
              then
                return (RemoteEndPointValid vst', Nothing)
              else do 
                removeRemoteEndPoint (ourEndPoint, theirEndPoint)
                -- Attempt to reply (but don't insist)
                act <- schedule theirEndPoint $ do
                  tryIO $ sendOn vst' [ encodeInt32 CloseSocket
                                      , encodeInt32 (vst ^. remoteMaxIncoming)
                                      ]
                  tryCloseSocket sock 
                return (RemoteEndPointClosed, Just act)
          RemoteEndPointClosing resolved vst ->  do
            -- Like above, we need to check if there is a CreatedNewConnection
            -- message that we sent but that the remote endpoint has not yet 
            -- received. However, since we are in 'closing' state, the only 
            -- way this may happen is when we sent a CreatedNewConnection,
            -- CloseConnection, and CloseSocket message, none of which have
            -- yet been received. We leave the endpoint in closing state in
            -- that case.
            if lastReceivedId < lastSentId vst
              then do
                return (RemoteEndPointClosing resolved vst, Nothing)
              else do
                removeRemoteEndPoint (ourEndPoint, theirEndPoint)
                act <- schedule theirEndPoint $ tryCloseSocket sock 
                putMVar resolved ()
                return (RemoteEndPointClosed, Just act)
          RemoteEndPointFailed err ->
            throwIO err
          RemoteEndPointClosed ->
            relyViolation (ourEndPoint, theirEndPoint) 
              "handleIncomingMessages:closeSocket (closed)"
      -- The scheduled action (if any) is run outside the state lock
      case mAct of
        Nothing -> return False
        Just act -> do
          runScheduledAction (ourEndPoint, theirEndPoint) act
          return True
            
    -- Read a message and output it on the endPoint's channel. By rights we
    -- should verify that the connection ID is valid, but this is unnecessary
    -- overhead
    readMessage :: N.Socket -> LightweightConnectionId -> IO () 
    readMessage sock lcid = 
      recvWithLength sock >>= writeChan ourChannel . Received (connId lcid) 

    -- Arguments
    ourChannel  = localChannel ourEndPoint 
    theirState  = remoteState theirEndPoint
    theirAddr   = remoteAddress theirEndPoint

    -- Deal with a premature exit
    --
    -- Closes the socket and moves the remote endpoint to Failed. Only a Valid
    -- endpoint reports 'EventConnectionLost' on our channel; a Closing
    -- endpoint wakes up the threads waiting on its 'resolved' MVar; an
    -- endpoint that has already Failed keeps its original error.
    prematureExit :: N.Socket -> IOException -> IO ()
    prematureExit sock err = do
      tryCloseSocket sock
      modifyMVar_ theirState $ \st ->
        case st of
          RemoteEndPointInvalid _ ->
            relyViolation (ourEndPoint, theirEndPoint) 
              "handleIncomingMessages:prematureExit"
          RemoteEndPointInit _ _ _ ->
            relyViolation (ourEndPoint, theirEndPoint) 
              "handleIncomingMessages:prematureExit"
          RemoteEndPointValid _ -> do
            let code = EventConnectionLost (remoteAddress theirEndPoint) 
            writeChan ourChannel . ErrorEvent $ TransportError code (show err)
            return (RemoteEndPointFailed err)
          RemoteEndPointClosing resolved _ -> do
            putMVar resolved ()
            return (RemoteEndPointFailed err)
          RemoteEndPointClosed ->
            relyViolation (ourEndPoint, theirEndPoint) 
              "handleIncomingMessages:prematureExit"
          RemoteEndPointFailed err' ->
            return (RemoteEndPointFailed err')

    -- Construct a connection ID
    -- (combines the heavyweight 'remoteId' with the lightweight ID)
    connId :: LightweightConnectionId -> ConnectionId
    connId = createConnectionId (remoteId theirEndPoint)

    -- The ID of the last connection _we_ created (or 0 for none)
    lastSentId :: ValidRemoteEndPointState -> LightweightConnectionId
    lastSentId vst = 
      if vst ^. remoteNextConnOutId == firstNonReservedLightweightConnectionId
        then 0
        else (vst ^. remoteNextConnOutId) - 1

--------------------------------------------------------------------------------
-- Uninterruptible auxiliary functions                                        --
--                                                                            --
-- All these functions assume they are running in a thread which will never   --
-- be killed.
--------------------------------------------------------------------------------

-- | Create a (lightweight) connection to a remote endpoint
--
-- Looks up the heavyweight connection to @theirAddress@. If it had to be
-- created, a thread is forked to set it up and the lookup is retried (which
-- then waits for that setup to resolve). Once the remote endpoint exists, the
-- next outgoing lightweight connection ID is allocated and announced to the
-- remote side with a 'CreatedNewConnection' control message.
--
-- If the remote endpoint is in 'RemoteEndPointClosing' state then we will
-- block until that is resolved.
--
-- May throw a TransportError ConnectErrorCode exception.
createConnectionTo :: TCPParameters 
                    -> LocalEndPoint 
                    -> EndPointAddress 
                    -> ConnectHints
                    -> IO (RemoteEndPoint, LightweightConnectionId)
createConnectionTo params ourEndPoint theirAddress hints = loop
  where
    loop = do
      (theirEndPoint, isNew) <- mapIOException connectFailed $
        findRemoteEndPoint ourEndPoint theirAddress RequestedByUs
      if isNew
        then setupInBackground theirEndPoint >> loop
        else openLightweight theirEndPoint

    -- Set up the heavyweight connection in a separate thread; exceptions
    -- escaping that thread are deliberately discarded.
    setupInBackground theirEndPoint =
      void . forkIO . handle absorbAllExceptions $
        setupRemoteEndPoint params (ourEndPoint, theirEndPoint) hints

    -- 'findRemoteEndPoint' will have increased 'remoteOutgoing'.
    -- If the scheduled send fails, 'runScheduledAction' marks the remote
    -- endpoint as failed and rethrows; the IOException is then turned into a
    -- ConnectFailed error by 'mapIOException'.
    openLightweight theirEndPoint = mapIOException connectFailed $ do
      sendAct <- modifyMVar (remoteState theirEndPoint) $ \st -> case st of
        RemoteEndPointValid vst -> do
          let lcid = vst ^. remoteNextConnOutId
          pending <- schedule theirEndPoint $ do
            sendOn vst [encodeInt32 CreatedNewConnection, encodeInt32 lcid]
            return lcid
          return ( RemoteEndPointValid ((remoteNextConnOutId ^= lcid + 1) vst)
                 , pending
                 )
        -- Error cases
        RemoteEndPointInvalid err ->
          throwIO err
        RemoteEndPointFailed err ->
          throwIO err
        -- Algorithmic errors
        _ ->
          relyViolation (ourEndPoint, theirEndPoint) "createConnectionTo"
      lcid <- runScheduledAction (ourEndPoint, theirEndPoint) sendAct
      return (theirEndPoint, lcid)

    connectFailed :: IOException -> TransportError ConnectErrorCode
    connectFailed = TransportError ConnectFailed . show

    absorbAllExceptions :: SomeException -> IO ()
    absorbAllExceptions _ex = 
      return ()

-- | Set up a remote endpoint
--
-- Opens a socket to the remote endpoint and resolves its Init state
-- according to the reply:
--
--   * 'ConnectionRequestAccepted': the endpoint becomes Valid (recording the
--     socket) and this thread goes on to handle incoming messages;
--   * 'ConnectionRequestInvalid': the endpoint becomes Invalid and the socket
--     is closed;
--   * 'ConnectionRequestCrossed': the Init state's @crossed@ MVar is filled
--     and the socket is closed;
--   * connection failure ('Left'): the endpoint becomes Invalid.
--
-- Until the socket has been recorded in the remote endpoint state this
-- function is its only owner. If resolving the state throws (for instance
-- because the remote endpoint was marked Failed in the meantime), the socket
-- is therefore closed before the exception is rethrown, rather than leaked.
setupRemoteEndPoint :: TCPParameters -> EndPointPair -> ConnectHints -> IO () 
setupRemoteEndPoint params (ourEndPoint, theirEndPoint) hints = do
    result <- socketToEndPoint ourAddress 
                               theirAddress 
                               (tcpReuseClientAddr params)
                               (connectTimeout hints)
    didAccept <- case result of
      Right (sock, ConnectionRequestAccepted) -> closeOnError sock $ do
        sendLock <- newMVar () 
        let vst = ValidRemoteEndPointState 
                    {  remoteSocket        = sock
                    ,  remoteSendLock      = sendLock
                    , _remoteOutgoing      = 0 
                    , _remoteIncoming      = Set.empty
                    , _remoteMaxIncoming   = 0
                    , _remoteNextConnOutId = firstNonReservedLightweightConnectionId 
                    }
        resolveInit (ourEndPoint, theirEndPoint) (RemoteEndPointValid vst)
        return True
      Right (sock, ConnectionRequestInvalid) -> do
        let err = invalidAddress "setupRemoteEndPoint: Invalid endpoint"
        closeOnError sock $
          resolveInit (ourEndPoint, theirEndPoint) (RemoteEndPointInvalid err)
        tryCloseSocket sock
        return False
      Right (sock, ConnectionRequestCrossed) -> do
        closeOnError sock $
          withMVar (remoteState theirEndPoint) $ \st -> case st of
            RemoteEndPointInit _ crossed _ -> 
              putMVar crossed ()
            RemoteEndPointFailed ex ->
              throwIO ex
            _ ->
              relyViolation (ourEndPoint, theirEndPoint) "setupRemoteEndPoint: Crossed"
        tryCloseSocket sock
        return False
      Left err -> do 
        resolveInit (ourEndPoint, theirEndPoint) (RemoteEndPointInvalid err)
        return False

    when didAccept $ handleIncomingMessages (ourEndPoint, theirEndPoint) 
  where
    ourAddress      = localAddress ourEndPoint
    theirAddress    = remoteAddress theirEndPoint
    invalidAddress  = TransportError ConnectNotFound

    -- Run an action that may throw while we still own @sock@; on any
    -- exception close the socket first, then rethrow the same exception.
    closeOnError :: N.Socket -> IO a -> IO a
    closeOnError sock act =
      act `catch` \ex -> do
        tryCloseSocket sock
        throwIO (ex :: SomeException)

-- | Send a CloseSocket request if the remote endpoint is unused
--
-- "Unused" means a Valid remote endpoint with no outgoing and no incoming
-- lightweight connections. Such an endpoint is moved to Closing (with a
-- fresh, empty @resolved@ MVar), and a 'CloseSocket' message carrying
-- 'remoteMaxIncoming' is sent. Every other state is left unchanged.
closeIfUnused :: EndPointPair -> IO ()
closeIfUnused (ourEndPoint, theirEndPoint) = do
    mAct <- modifyMVar (remoteState theirEndPoint) startClosing
    mapM_ (runScheduledAction (ourEndPoint, theirEndPoint)) mAct
  where
    startClosing st@(RemoteEndPointValid vst)
      | vst ^. remoteOutgoing == 0 && Set.null (vst ^. remoteIncoming) = do
          resolved <- newEmptyMVar
          act <- schedule theirEndPoint $
            sendOn vst [ encodeInt32 CloseSocket
                       , encodeInt32 (vst ^. remoteMaxIncoming)
                       ]
          return (RemoteEndPointClosing resolved vst, Just act)
      | otherwise =
          return (st, Nothing)
    startClosing st =
      return (st, Nothing)
  
-- | Reset a remote endpoint if it is in a broken state
--
-- If the remote endpoint for @theirAddress@ is currently in
-- 'RemoteEndPointInvalid' or 'RemoteEndPointFailed' state, and
--
--   - a user calls the API function 'connect', or
--   - an inbound connection request comes in from this remote address,
--
-- we remove the remote endpoint first, so that a subsequent
-- 'findRemoteEndPoint' creates a fresh one.
--
-- Throws a TransportError ConnectFailed exception if the local endpoint is
-- closed.
resetIfBroken :: LocalEndPoint -> EndPointAddress -> IO ()
resetIfBroken ourEndPoint theirAddress = do
    mTheirEndPoint <- withMVar (localState ourEndPoint) lookupRemote
    forM_ mTheirEndPoint $ \theirEndPoint ->
      withMVar (remoteState theirEndPoint) $ \st ->
        when (isBroken st) $
          removeRemoteEndPoint (ourEndPoint, theirEndPoint)
  where
    lookupRemote (LocalEndPointValid vst) =
      return (vst ^. localConnectionTo theirAddress)
    lookupRemote LocalEndPointClosed =
      throwIO $ TransportError ConnectFailed "Endpoint closed"

    isBroken (RemoteEndPointInvalid _) = True
    isBroken (RemoteEndPointFailed _)  = True
    isBroken _                         = False

-- | Special case of 'apiConnect': connect an endpoint to itself
--
-- Reports 'ConnectionOpened' on our own channel and returns a 'Connection'
-- whose 'send' posts 'Received' events on that same channel. A per-connection
-- flag, read and written only while holding the local endpoint lock, records
-- whether 'close' has been called.
--
-- May throw a TransportError ConnectErrorCode (if the local endpoint is closed)
connectToSelf :: LocalEndPoint 
               -> IO Connection
connectToSelf ourEndPoint = do
    aliveRef <- newIORef True  -- Protected by the local endpoint lock
    lcid     <- mapIOException connectFailed $ getLocalNextConnOutId ourEndPoint
    let cid = createConnectionId heavyweightSelfConnectionId lcid
    writeChan ourChan $
      ConnectionOpened cid ReliableOrdered (localAddress ourEndPoint)
    return Connection
      { send  = selfSend aliveRef cid
      , close = selfClose aliveRef cid
      }
  where
    selfSend :: IORef Bool 
             -> ConnectionId 
             -> [ByteString] 
             -> IO (Either (TransportError SendErrorCode) ())
    selfSend aliveRef cid payload =
      try $ withMVar ourState $ \st -> case st of
        LocalEndPointClosed ->
          throwIO $ TransportError SendFailed "Endpoint closed"
        LocalEndPointValid _ -> do
          alive <- readIORef aliveRef
          unless alive $
            throwIO $ TransportError SendClosed "Connection closed"
          writeChan ourChan (Received cid payload)

    selfClose :: IORef Bool -> ConnectionId -> IO ()
    selfClose aliveRef cid =
      withMVar ourState $ \st -> case st of
        LocalEndPointClosed ->
          return ()
        LocalEndPointValid _ -> do
          alive <- readIORef aliveRef
          when alive $ do
            writeChan ourChan (ConnectionClosed cid)
            writeIORef aliveRef False

    ourChan  = localChannel ourEndPoint
    ourState = localState ourEndPoint
    connectFailed = TransportError ConnectFailed . show 

-- | Resolve an endpoint currently in 'Init' state
--
-- Fills the Init state's @resolved@ MVar (waking up threads waiting for the
-- outcome) and installs @newState@. If @newState@ is 'RemoteEndPointClosed',
-- the remote endpoint is also removed from the local endpoint. If the remote
-- endpoint has failed in the meantime, the stored exception is rethrown; any
-- other state is a rely violation.
resolveInit :: EndPointPair -> RemoteState -> IO ()
resolveInit (ourEndPoint, theirEndPoint) newState =
    modifyMVar_ (remoteState theirEndPoint) resolve
  where
    resolve (RemoteEndPointInit resolved _ _) = do
      putMVar resolved ()
      when (isClosed newState) $
        removeRemoteEndPoint (ourEndPoint, theirEndPoint)
      return newState
    resolve (RemoteEndPointFailed ex) =
      throwIO ex
    resolve _ =
      relyViolation (ourEndPoint, theirEndPoint) "resolveInit"

    isClosed RemoteEndPointClosed = True
    isClosed _                    = False

-- | Get the next outgoing self-connection ID
--
-- Under the local endpoint lock, returns the current 'localNextConnOutId'
-- and stores its successor.
--
-- Throws an IO exception when the endpoint is closed.
getLocalNextConnOutId :: LocalEndPoint -> IO LightweightConnectionId
getLocalNextConnOutId ourEndpoint =
    modifyMVar (localState ourEndpoint) takeNext
  where
    takeNext LocalEndPointClosed =
      throwIO $ userError "Local endpoint closed"
    takeNext (LocalEndPointValid vst) = do
      let next = vst ^. localNextConnOutId
      return ( LocalEndPointValid ((localNextConnOutId ^= next + 1) vst)
             , next
             )

-- | Create a new local endpoint
--
-- The endpoint gets a fresh channel and an empty, Valid local state. Its
-- address is built from the transport's host and port and the transport's
-- 'nextEndPointId' counter. The counter is then incremented and the endpoint
-- is registered under that address.
--
-- May throw a TransportError NewEndPointErrorCode exception if the transport
-- is closed.
createLocalEndPoint :: TCPTransport -> IO LocalEndPoint
createLocalEndPoint transport = do
    chan  <- newChan
    state <- newMVar . LocalEndPointValid $ ValidLocalEndPointState
      { _localNextConnOutId = firstNonReservedLightweightConnectionId
      , _localConnections   = Map.empty
      , _nextConnInId       = firstNonReservedHeavyweightConnectionId
      }
    modifyMVar (transportState transport) (register chan state)
  where
    register chan state st = case st of
      TransportClosed ->
        throwIO (TransportError NewEndPointFailed "Transport closed")
      TransportValid vst -> do
        let ix   = vst ^. nextEndPointId
            addr = encodeEndPointAddress (transportHost transport)
                                         (transportPort transport)
                                         ix
            newEndPoint = LocalEndPoint { localAddress = addr
                                        , localChannel = chan
                                        , localState   = state
                                        }
            vst' = (localEndPointAt addr ^= Just newEndPoint)
                 . (nextEndPointId ^= ix + 1)
                 $ vst
        return (TransportValid vst', newEndPoint)


-- | Remove reference to a remote endpoint from a local endpoint
--
-- The entry for the remote address is cleared only if it still refers to
-- this particular remote endpoint (same 'remoteId'). A newer endpoint for the
-- same address is left alone.
--
-- If the local endpoint is closed, do nothing
removeRemoteEndPoint :: EndPointPair -> IO ()
removeRemoteEndPoint (ourEndPoint, theirEndPoint) =
    modifyMVar_ (localState ourEndPoint) $ \st -> case st of
      LocalEndPointValid vst
        | Just current <- vst ^. localConnectionTo theirAddress
        , remoteId current == remoteId theirEndPoint ->
            return . LocalEndPointValid $
              (localConnectionTo theirAddress ^= Nothing) vst
      _ ->
        return st
  where
    theirAddress = remoteAddress theirEndPoint

-- | Remove reference to a local endpoint from the transport state
--
-- Clears the transport's entry for the endpoint's address.
--
-- Does nothing if the transport is closed
removeLocalEndPoint :: TCPTransport -> LocalEndPoint -> IO ()
removeLocalEndPoint transport ourEndPoint =
    modifyMVar_ (transportState transport) $ \st -> case st of
      TransportClosed ->
        return TransportClosed
      TransportValid vst ->
        return . TransportValid $ forget vst
  where
    forget = localEndPointAt (localAddress ourEndPoint) ^= Nothing

-- | Find a remote endpoint. If the remote endpoint does not yet exist we
-- create it in Init state. Returns if the endpoint was new. 
--
-- More precisely, the returned 'Bool' is 'True' when the caller is now
-- responsible for setting up the endpoint. This happens in two cases:
--
--   * it was just created here (in Init state, tagged with @findOrigin@);
--   * an inbound request ('RequestedByThem') found our own outbound request
--     still in Init state, and our address is the larger one. We first wait
--     on the Init state's @crossed@ MVar before returning.
--
-- For an endpoint that already exists:
--
--   * Valid: a 'RequestedByUs' lookup increments 'remoteOutgoing'; the result
--     is 'False' in both directions;
--   * Init (except the inbound cases above/below) and Closing: wait on the
--     @resolved@ MVar and retry;
--   * Closed: retry;
--   * Invalid and Failed: rethrow the stored error.
--
-- Throws a 'userError' if the local endpoint is closed, or if an inbound
-- request finds another inbound request still in Init state.
findRemoteEndPoint 
  :: LocalEndPoint
  -> EndPointAddress
  -> RequestedBy 
  -> IO (RemoteEndPoint, Bool) 
findRemoteEndPoint ourEndPoint theirAddress findOrigin = go
  where
    go = do
      -- Look up (or create) the remote endpoint under the local endpoint lock
      (theirEndPoint, isNew) <- modifyMVar ourState $ \st -> case st of
        LocalEndPointValid vst -> case vst ^. localConnectionTo theirAddress of
          Just theirEndPoint ->
            return (st, (theirEndPoint, False))
          Nothing -> do
            resolved   <- newEmptyMVar
            crossed    <- newEmptyMVar
            theirState <- newMVar (RemoteEndPointInit resolved crossed findOrigin)
            scheduled  <- newChan
            let theirEndPoint = RemoteEndPoint
                                  { remoteAddress   = theirAddress
                                  , remoteState     = theirState
                                  , remoteId        = vst ^. nextConnInId
                                  , remoteScheduled = scheduled 
                                  }
            return ( LocalEndPointValid 
                   . (localConnectionTo theirAddress ^= Just theirEndPoint) 
                   . (nextConnInId ^: (+ 1)) 
                   $ vst
                   , (theirEndPoint, True) 
                   )
        LocalEndPointClosed ->
          throwIO $ userError "Local endpoint closed"
      
      if isNew 
        then
          return (theirEndPoint, True)
        else do
          -- Existing endpoint: take a snapshot of its state (bumping the
          -- outgoing refcount if we are about to use a Valid endpoint)
          let theirState = remoteState theirEndPoint
          snapshot <- modifyMVar theirState $ \st -> case st of
            RemoteEndPointValid vst -> 
              case findOrigin of
                RequestedByUs -> do
                  let st' = RemoteEndPointValid 
                          . (remoteOutgoing ^: (+ 1)) 
                          $ vst 
                  return (st', st')
                RequestedByThem ->
                  return (st, st) 
            _ ->
              return (st, st)
          -- The snapshot may no longer be up to date at this point, but if we
          -- increased the refcount then it can only either be Valid or Failed 
          -- (after an explicit call to 'closeEndPoint' or 'closeTransport') 
          case snapshot of
            RemoteEndPointInvalid err ->
              throwIO err
            RemoteEndPointInit resolved crossed initOrigin ->
              case (findOrigin, initOrigin) of
                (RequestedByUs, RequestedByUs) -> 
                  readMVar resolved >> go 
                (RequestedByUs, RequestedByThem) -> 
                  readMVar resolved >> go
                (RequestedByThem, RequestedByUs) -> 
                  -- Crossed connection requests: the endpoint with the
                  -- larger address lets the inbound request take over
                  if ourAddress > theirAddress 
                    then do
                      -- Wait for the Crossed message
                      readMVar crossed 
                      return (theirEndPoint, True)
                    else
                      return (theirEndPoint, False)
                (RequestedByThem, RequestedByThem) -> 
                  throwIO $ userError "Already connected"
            RemoteEndPointValid _ ->
              -- We assume that the request crossed if we find the endpoint in
              -- Valid state. It is possible that this is really an invalid
              -- request, but only in the case of a broken client (we don't
              -- maintain enough history to be able to tell the difference).
              return (theirEndPoint, False)
            RemoteEndPointClosing resolved _ ->
              readMVar resolved >> go
            RemoteEndPointClosed ->
              go
            RemoteEndPointFailed err -> 
              throwIO err
          
    ourState   = localState ourEndPoint 
    ourAddress = localAddress ourEndPoint

-- | Send a payload over a heavyweight connection (thread safe)
--
-- All chunks are written with a single 'sendMany' while holding the
-- connection's 'remoteSendLock', so concurrent 'sendOn' calls on the same
-- connection cannot interleave their chunks.
sendOn :: ValidRemoteEndPointState -> [ByteString] -> IO ()
sendOn vst payload =
    withMVar sendLock $ \() -> sendMany sock payload
  where
    sendLock = remoteSendLock vst
    sock     = remoteSocket vst

--------------------------------------------------------------------------------
-- Scheduling actions                                                         --
--------------------------------------------------------------------------------

-- | Result slot of a scheduled action (see 'schedule'/'runScheduledAction').
--
-- It is filled by the wrapper that 'schedule' enqueues, with either the
-- action's result ('Right') or the exception it threw ('Left'). It is read by
-- 'runScheduledAction'.
type Action a = MVar (Either SomeException a)

-- | Schedule an action to be executed (see also 'runScheduledAction')
schedule :: RemoteEndPoint -> IO a -> IO (Action a)
schedule theirEndPoint act = do
  mvar <- newEmptyMVar
  writeChan (remoteScheduled theirEndPoint) $ 
    catch (act >>= putMVar mvar . Right) (putMVar mvar . Left)
  return mvar

-- | Run one scheduled action, then wait for the outcome of a particular one.
--
-- Every call to 'schedule' must be matched by exactly one call to
-- 'runScheduledAction', so that every queued action is eventually executed.
-- The action run by a given call is not necessarily the one it waits for. In
--
-- > do act <- schedule theirEndPoint p
-- >    runScheduledAction (ourEndPoint, theirEndPoint) act
--
-- 'runScheduledAction' runs whichever action is at the front of the queue
-- (possibly one scheduled by another thread). It then blocks until @p@ has
-- been run, either by this call or by another one.
--
-- If @p@ failed, its exception is rethrown here. If that exception is an
-- 'IOException', the remote endpoint's state is first set to
-- 'RemoteEndPointFailed'. If the endpoint was valid at that point, its
-- socket is also closed and an 'EventConnectionLost' error event is written
-- to our endpoint's channel.
runScheduledAction :: EndPointPair -> Action a -> IO a
runScheduledAction (ourEndPoint, theirEndPoint) mvar = do
    next <- readChan (remoteScheduled theirEndPoint)
    next
    outcome <- readMVar mvar
    either rethrow return outcome
  where
    -- Record an I/O failure on the remote endpoint (if applicable), then
    -- propagate the original exception unchanged.
    rethrow :: SomeException -> IO b
    rethrow err = do
      forM_ (fromException err) markFailed
      throwIO err

    markFailed :: IOException -> IO ()
    markFailed ioe = modifyMVar_ (remoteState theirEndPoint) $ \st ->
      case st of
        RemoteEndPointValid vst -> handleIOException ioe vst
        _                       -> return (RemoteEndPointFailed ioe)

    handleIOException :: IOException 
                      -> ValidRemoteEndPointState 
                      -> IO RemoteState 
    handleIOException ex vst = do
      tryCloseSocket (remoteSocket vst)
      writeChan (localChannel ourEndPoint) . ErrorEvent $
        TransportError (EventConnectionLost (remoteAddress theirEndPoint)) (show ex)
      return (RemoteEndPointFailed ex)

--------------------------------------------------------------------------------
-- "Stateless" (MVar free) functions                                          --
--------------------------------------------------------------------------------

-- | Establish a connection to a remote endpoint.
--
-- This decodes the host, port and endpoint ID from the remote address,
-- resolves the host and port, and opens a socket. It then connects and
-- performs the connection-request handshake: it sends the remote endpoint ID
-- followed by our own length-prefixed address, and reads back an integer that
-- is converted to a 'ConnectionRequestResponse'.
--
-- All anticipated failures are returned as 'Left':
--
-- * 'ConnectNotFound' when the address cannot be resolved (including an empty
--   lookup result) or when @connect@ fails;
-- * 'ConnectInsufficientResources' when the socket cannot be created;
-- * 'ConnectTimeout' when connect /plus handshake/ do not finish within the
--   given timeout;
-- * 'ConnectFailed' for an unparseable address, a failure to set socket
--   options, an I/O error during the handshake, or an unrecognised response.
--
-- The socket is closed on every failure path. On success it is handed to the
-- caller.
socketToEndPoint :: EndPointAddress -- ^ Our address 
                 -> EndPointAddress -- ^ Their address
                 -> Bool            -- ^ Use SO_REUSEADDR?
                 -> Maybe Int       -- ^ Timeout for connect and handshake
                 -> IO (Either (TransportError ConnectErrorCode) 
                               (N.Socket, ConnectionRequestResponse)) 
socketToEndPoint (EndPointAddress ourAddress) theirAddress reuseAddr timeout = 
  try $ do 
    (host, port, theirEndPointId) <- case decodeEndPointAddress theirAddress of 
      Nothing  -> throwIO (failed . userError $ "Could not parse")
      Just dec -> return dec
    addrs <- mapIOException invalidAddress $ 
      N.getAddrInfo Nothing (Just host) (Just port)
    -- Avoid a partial pattern here: a failed `addr:_ <-` match would raise a
    -- plain IOError that escapes the 'try' above instead of a TransportError.
    addr <- case addrs of
      []      -> throwIO (invalidAddress . userError $ "No address found")
      (a : _) -> return a
    bracketOnError (createSocket addr) tryCloseSocket $ \sock -> do
      when reuseAddr $ 
        mapIOException failed $ N.setSocketOption sock N.ReuseAddr 1
      -- The timeout covers the handshake as well as the TCP connect. A peer
      -- that accepts the connection but never replies would otherwise block
      -- this call indefinitely, even when the caller asked for a timeout.
      response <- timeoutMaybe timeout timeoutError $ do
        mapIOException invalidAddress $ 
          N.connect sock (N.addrAddress addr) 
        mapIOException failed $ do
          sendMany sock (encodeInt32 theirEndPointId : prependLength [ourAddress])
          recvInt32 sock
      case tryToEnum response of
        Nothing -> throwIO (failed . userError $ "Unexpected response")
        Just r  -> return (sock, r)
  where
    createSocket :: N.AddrInfo -> IO N.Socket
    createSocket addr = mapIOException insufficientResources $ 
      N.socket (N.addrFamily addr) N.Stream N.defaultProtocol

    -- Wrap an underlying IOException into the corresponding TransportError.
    invalidAddress        = TransportError ConnectNotFound . show 
    insufficientResources = TransportError ConnectInsufficientResources . show 
    failed                = TransportError ConnectFailed . show
    timeoutError          = TransportError ConnectTimeout "Timed out"
    timeoutError          = TransportError ConnectTimeout "Timed out"

-- | Build an endpoint address of the form @host:port:id@.
--
-- This is the inverse of 'decodeEndPointAddress'.
encodeEndPointAddress :: N.HostName 
                      -> N.ServiceName 
                      -> EndPointId 
                      -> EndPointAddress
encodeEndPointAddress host port ix =
  EndPointAddress (BSC.pack (concat [host, ":", port, ":", show ix]))

-- | Parse an endpoint address of the form @host:port:id@.
--
-- Only the last two colons are treated as separators, so the host part may
-- itself contain colons. The ID must parse completely with 'reads'; otherwise,
-- or when fewer than two colons are present, the result is 'Nothing'.
decodeEndPointAddress :: EndPointAddress 
                      -> Maybe (N.HostName, N.ServiceName, EndPointId)
decodeEndPointAddress (EndPointAddress raw) =
  case splitMaxFromEnd (== ':') 2 (BSC.unpack raw) of
    [host, port, idStr]
      | [(endPointId, "")] <- reads idStr -> Just (host, port, endPointId)
    _ -> Nothing

-- | Combine a heavyweight and a lightweight connection ID into one
-- 'ConnectionId'. The heavyweight ID is shifted into the bits above bit 32,
-- and the lightweight ID is OR-ed into the low bits.
createConnectionId :: HeavyweightConnectionId 
                   -> LightweightConnectionId 
                   -> ConnectionId
createConnectionId hcid lcid = high .|. low
  where
    high, low :: ConnectionId
    high = fromIntegral hcid `shiftL` 32
    low  = fromIntegral lcid

-- | @splitMaxFromEnd p n xs@ splits the list @xs@ at elements matching @p@.
-- It performs at most @n@ splits and so returns at most @n + 1@ segments.
-- Splits are taken counting from the /end/ of the list, so any separators
-- beyond the limit stay inside the first segment. Separators that are split
-- on are dropped.
--
-- > splitMaxFromEnd (== ':') 2 "ab:cd:ef:gh" == ["ab:cd", "ef", "gh"]
-- > splitMaxFromEnd (== ':') 2 "ab"          == ["ab"]
-- > splitMaxFromEnd (== ':') 0 "ab:cd"       == ["ab:cd"]
--
-- The result is never empty: an empty input yields @[[]]@. A negative @n@
-- places no limit on the number of splits.
splitMaxFromEnd :: (a -> Bool) -> Int -> [a] -> [[a]]
splitMaxFromEnd p = \n -> go [[]] n . reverse 
  where
    -- go :: [[a]] -> Int -> [a] -> [[a]]
    -- The head of the accumulator is the segment currently being built.
    -- Because we walk the reversed input, consing onto it restores the
    -- original element order.
    go accs         _ []     = accs 
    -- Split budget exhausted. This can only happen right after a split (or
    -- at the start), when the current segment is empty. Everything that
    -- remains becomes the first segment.
    go ([]  : accs) 0 xs     = reverse xs : accs
    go (acc : accs) n (x:xs) =
      if p x then go ([] : acc : accs) (n - 1) xs
             else go ((x : acc) : accs) n xs
    -- Unreachable: the accumulator always has at least one segment.
    go _ _ _ = error "Bug in splitMaxFromEnd"

--------------------------------------------------------------------------------
-- Functions from TransportInternals                                          --
--------------------------------------------------------------------------------

-- | Look up the socket currently used between a local endpoint and a remote
-- endpoint.
--
-- A socket is returned when the remote endpoint is valid or closing. In every
-- other case an exception is thrown: a 'userError' when the transport, local
-- endpoint or connection is closed, missing or not yet initialised, or the
-- error stored in an invalid or failed remote state.
internalSocketBetween :: TCPTransport    -- ^ Transport 
                      -> EndPointAddress -- ^ Local endpoint
                      -> EndPointAddress -- ^ Remote endpoint
                      -> IO N.Socket 
internalSocketBetween transport ourAddress theirAddress = do
    ourEndPoint   <- withMVar (transportState transport) lookupLocal
    theirEndPoint <- withMVar (localState ourEndPoint) lookupRemote
    withMVar (remoteState theirEndPoint) socketOf
  where
    orFail :: String -> Maybe a -> IO a
    orFail msg = maybe (throwIO $ userError msg) return

    lookupLocal TransportClosed =
      throwIO $ userError "Transport closed"
    lookupLocal (TransportValid vst) =
      orFail "Local endpoint not found" (vst ^. localEndPointAt ourAddress)

    lookupRemote LocalEndPointClosed =
      throwIO $ userError "Local endpoint closed"
    lookupRemote (LocalEndPointValid vst) =
      orFail "Remote endpoint not found" (vst ^. localConnectionTo theirAddress)

    socketOf st = case st of
      RemoteEndPointValid vst     -> return (remoteSocket vst)
      RemoteEndPointClosing _ vst -> return (remoteSocket vst)
      RemoteEndPointInit _ _ _    ->
        throwIO $ userError "Remote endpoint not yet initialized"
      RemoteEndPointClosed        ->
        throwIO $ userError "Remote endpoint closed"
      RemoteEndPointInvalid err   -> throwIO err
      RemoteEndPointFailed err    -> throwIO err

--------------------------------------------------------------------------------
-- Constants                                                                  --
--------------------------------------------------------------------------------

-- | Lightweight connection IDs below this value (2^10) are reserved for
-- control messages.
firstNonReservedLightweightConnectionId :: LightweightConnectionId
firstNonReservedLightweightConnectionId = 2 ^ (10 :: Int)

-- | Heavyweight connection ID used for an endpoint's connection to itself.
heavyweightSelfConnectionId :: HeavyweightConnectionId
heavyweightSelfConnectionId = 0

-- | Heavyweight connection IDs below this value are reserved for special
-- connections. Currently only the self-connection ID is reserved.
firstNonReservedHeavyweightConnectionId :: HeavyweightConnectionId
firstNonReservedHeavyweightConnectionId = heavyweightSelfConnectionId + 1

--------------------------------------------------------------------------------
-- Accessor definitions                                                       --
--------------------------------------------------------------------------------

-- | Local endpoints of a valid transport, keyed by address.
localEndPoints :: Accessor ValidTransportState (Map EndPointAddress LocalEndPoint)
localEndPoints = accessor _localEndPoints $ \new s -> s { _localEndPoints = new }

-- | Accessor for the '_nextEndPointId' field of a valid transport.
nextEndPointId :: Accessor ValidTransportState EndPointId
nextEndPointId = accessor _nextEndPointId $ \new s -> s { _nextEndPointId = new }

-- | Accessor for the '_localNextConnOutId' field of a valid local endpoint.
localNextConnOutId :: Accessor ValidLocalEndPointState LightweightConnectionId
localNextConnOutId = accessor _localNextConnOutId $ \new s -> s { _localNextConnOutId = new }

-- | Remote endpoints known to a valid local endpoint, keyed by address.
localConnections :: Accessor ValidLocalEndPointState (Map EndPointAddress RemoteEndPoint)
localConnections = accessor _localConnections $ \new s -> s { _localConnections = new }

-- | Accessor for the '_nextConnInId' field of a valid local endpoint.
nextConnInId :: Accessor ValidLocalEndPointState HeavyweightConnectionId 
nextConnInId = accessor _nextConnInId $ \new s -> s { _nextConnInId = new }

-- | Accessor for the '_remoteOutgoing' counter of a valid remote endpoint.
remoteOutgoing :: Accessor ValidRemoteEndPointState Int
remoteOutgoing = accessor _remoteOutgoing $ \new s -> s { _remoteOutgoing = new }

-- | Accessor for the set of '_remoteIncoming' lightweight connection IDs.
remoteIncoming :: Accessor ValidRemoteEndPointState (Set LightweightConnectionId)
remoteIncoming = accessor _remoteIncoming $ \new s -> s { _remoteIncoming = new }

-- | Accessor for the '_remoteMaxIncoming' field of a valid remote endpoint.
remoteMaxIncoming :: Accessor ValidRemoteEndPointState LightweightConnectionId
remoteMaxIncoming = accessor _remoteMaxIncoming $ \new s -> s { _remoteMaxIncoming = new }

-- | Accessor for the '_remoteNextConnOutId' field of a valid remote endpoint.
remoteNextConnOutId :: Accessor ValidRemoteEndPointState LightweightConnectionId
remoteNextConnOutId = accessor _remoteNextConnOutId $ \new s -> s { _remoteNextConnOutId = new }

-- | The local endpoint registered under the given address, viewed through
-- 'DAC.mapMaybe' as a 'Maybe'.
localEndPointAt :: EndPointAddress -> Accessor ValidTransportState (Maybe LocalEndPoint)
localEndPointAt key = localEndPoints >>> DAC.mapMaybe key

-- | The remote endpoint registered under the given address, viewed through
-- 'DAC.mapMaybe' as a 'Maybe'.
localConnectionTo :: EndPointAddress -> Accessor ValidLocalEndPointState (Maybe RemoteEndPoint)
localConnectionTo key = localConnections >>> DAC.mapMaybe key

-------------------------------------------------------------------------------
-- Debugging                                                                 --
-------------------------------------------------------------------------------

-- | Log a RELY violation for the given endpoint pair, then abort by calling
-- 'fail' with the same message.
relyViolation :: EndPointPair -> String -> IO a
relyViolation pair str = do
  elog pair message
  fail message
  where
    message = str ++ " RELY violation"

-- | Print a debug line to stdout. The line is prefixed with the local
-- address, the remote address and ID, and the current thread ID.
elog :: EndPointPair -> String -> IO ()
elog (ourEndPoint, theirEndPoint) msg = do
  tid <- myThreadId
  putStrLn $ concat
    [ show (localAddress ourEndPoint)
    , "/", show (remoteAddress theirEndPoint)
    , "(", show (remoteId theirEndPoint), ")"
    , "/", show tid
    , ": ", msg
    ]