-- | Utility functions for TCP sockets
module Network.Transport.TCP.Internal
  ( ControlHeader(..)
  , encodeControlHeader
  , decodeControlHeader
  , ConnectionRequestResponse(..)
  , encodeConnectionRequestResponse
  , decodeConnectionRequestResponse
  , forkServer
  , recvWithLength
  , recvExact
  , recvWord32
  , encodeWord32
  , tryCloseSocket
  , tryShutdownSocketBoth
  , decodeSockAddr
  , EndPointId
  , encodeEndPointAddress
  , decodeEndPointAddress
  , ProtocolVersion
  , currentProtocolVersion
  ) where

#if ! MIN_VERSION_base(4,6,0)
import Prelude hiding (catch)
#endif

import Network.Transport.Internal
  ( decodeWord32
  , encodeWord32
  , void
  , tryIO
  , forkIOWithUnmask
  )

import Network.Transport ( EndPointAddress(..) )

#ifdef USE_MOCK_NETWORK
import qualified Network.Transport.TCP.Mock.Socket as N
#else
import qualified Network.Socket as N
#endif
  ( HostName
  , ServiceName
  , Socket
  , SocketType(Stream)
  , SocketOption(ReuseAddr)
  , getAddrInfo
  , defaultHints
  , socket
  , bindSocket
  , listen
  , addrFamily
  , addrAddress
  , defaultProtocol
  , setSocketOption
  , accept
  , sClose
  , socketPort
  , shutdown
  , ShutdownCmd(ShutdownBoth)
  , SockAddr(..)
  , inet_ntoa
  )

#ifdef USE_MOCK_NETWORK
import qualified Network.Transport.TCP.Mock.Socket.ByteString as NBS (recv)
#else
import qualified Network.Socket.ByteString as NBS (recv)
#endif

import Data.Word (Word32)

import Control.Monad (forever, when)
import Control.Exception (SomeException, catch, bracketOnError, throwIO, mask_)
import Control.Concurrent (ThreadId, forkIO)
import Control.Concurrent.MVar
  ( MVar
  , newEmptyMVar
  , putMVar
  , readMVar
  )
import Control.Monad (forever, when)
import Control.Exception
  ( SomeException
  , catch
  , bracketOnError
  , throwIO
  , mask_
  , mask
  , finally
  , onException
  )

import Control.Applicative ((<$>), (<*>))
import Data.Word (Word32)
import Data.ByteString (ByteString)
import qualified Data.ByteString as BS (length, concat, null)
import Data.ByteString.Lazy.Internal (smallChunkSize)
import qualified Data.ByteString.Char8 as BSC (unpack, pack)

-- | Identifier of an endpoint, local to the transport that created it.
type EndPointId = Word32

-- | Version number of the network-transport-tcp protocol. A version number
-- is the first piece of data sent when a new heavyweight connection is
-- established.
type ProtocolVersion = Word32

-- | The protocol version spoken by this implementation.
currentProtocolVersion :: ProtocolVersion
currentProtocolVersion = 0

-- | Control headers: the control messages exchanged with a remote endpoint.
data ControlHeader
  = CreatedNewConnection
    -- ^ Tell the remote endpoint that we created a new connection
  | CloseConnection
    -- ^ Tell the remote endpoint we will no longer be using a connection
  | CloseSocket
    -- ^ Request to close the connection (see module description)
  | CloseEndPoint
    -- ^ Sent by an endpoint when it is closed.
  | ProbeSocket
    -- ^ Message sent to probe a socket
  | ProbeSocketAck
    -- ^ Acknowledgement of the ProbeSocket message
  deriving Show

-- | Parse a 'ControlHeader' from its wire value. Values outside the range
-- produced by 'encodeControlHeader' yield 'Nothing'.
decodeControlHeader :: Word32 -> Maybe ControlHeader
decodeControlHeader w32 = lookup w32 (zip [0 ..] headersInWireOrder)
  where
    -- The header at position i of this list is encoded on the wire as i.
    headersInWireOrder =
      [ CreatedNewConnection
      , CloseConnection
      , CloseSocket
      , CloseEndPoint
      , ProbeSocket
      , ProbeSocketAck
      ]

-- | Wire value of a 'ControlHeader'; the inverse of 'decodeControlHeader'.
encodeControlHeader :: ControlHeader -> Word32
encodeControlHeader CreatedNewConnection = 0
encodeControlHeader CloseConnection      = 1
encodeControlHeader CloseSocket          = 2
encodeControlHeader CloseEndPoint        = 3
encodeControlHeader ProbeSocket          = 4
encodeControlHeader ProbeSocketAck       = 5

-- | Reply that endpoint /B/ sends to endpoint /A/ in answer to /A/'s
-- connection request.
data ConnectionRequestResponse
  = ConnectionRequestUnsupportedVersion
    -- ^ /B/ does not support the protocol version requested by /A/.
  | ConnectionRequestAccepted
    -- ^ /B/ accepts the connection
  | ConnectionRequestInvalid
    -- ^ /A/ requested an invalid endpoint
  | ConnectionRequestCrossed
    -- ^ /A/s request crossed with a request from /B/ (see protocols)
  | ConnectionRequestHostMismatch
    -- ^ /A/ gave an incorrect host (did not match the host that /B/ observed).
  deriving Show

-- | Parse a 'ConnectionRequestResponse' from its wire value. Unknown values
-- yield 'Nothing'.
decodeConnectionRequestResponse :: Word32 -> Maybe ConnectionRequestResponse
decodeConnectionRequestResponse w32
  | w32 == 0xFFFFFFFF = Just ConnectionRequestUnsupportedVersion
  | w32 == 0x00000000 = Just ConnectionRequestAccepted
  | w32 == 0x00000001 = Just ConnectionRequestInvalid
  | w32 == 0x00000002 = Just ConnectionRequestCrossed
  | w32 == 0x00000003 = Just ConnectionRequestHostMismatch
  | otherwise         = Nothing

-- | Wire value of a 'ConnectionRequestResponse'; the inverse of
-- 'decodeConnectionRequestResponse'.
encodeConnectionRequestResponse :: ConnectionRequestResponse -> Word32
encodeConnectionRequestResponse ConnectionRequestUnsupportedVersion = 0xFFFFFFFF
encodeConnectionRequestResponse ConnectionRequestAccepted           = 0x00000000
encodeConnectionRequestResponse ConnectionRequestInvalid            = 0x00000001
encodeConnectionRequestResponse ConnectionRequestCrossed            = 0x00000002
encodeConnectionRequestResponse ConnectionRequestHostMismatch       = 0x00000003

-- | Start a server at the specified address.
--
-- This sets up a server socket for the specified host and port. Exceptions
-- thrown during setup are not caught; if one is thrown after the server socket
-- has been created, that socket is closed before the exception propagates.
--
-- Once the socket is created we spawn a new thread which repeatedly accepts
-- incoming connections and executes the given request handler in another
-- thread for each one. An exception raised while accepting a connection --
-- which includes asynchronous exceptions, e.g. @ThreadKilled@, delivered to
-- the accepting thread while it waits in 'N.accept' -- is passed to the error
-- handler. If the error handler returns normally the thread keeps accepting.
-- If the error handler throws, the accepting thread closes the server socket,
-- calls the termination handler with that exception and exits. Threads
-- spawned for previously accepted connections are not killed.
--
-- The request handler is not responsible for closing the socket. It will be
-- closed once that handler returns, normally or by an exception. Take care to
-- ensure that the socket is not used after the handler returns, or you will
-- get undefined behavior (the file descriptor may be re-used).
--
-- The return value includes the port that was bound to. This is not always
-- the same port as that given in the argument. For example, binding to port 0
-- actually binds to a random port, selected by the OS.
forkServer :: N.HostName                     -- ^ Host
           -> N.ServiceName                  -- ^ Port
           -> Int                            -- ^ Backlog (maximum number of queued connections)
           -> Bool                           -- ^ Set ReuseAddr option?
           -> (SomeException -> IO ())       -- ^ Error handler. Called with an
                                             --   exception raised while
                                             --   accepting a connection
                                             --   (possibly asynchronous). If it
                                             --   returns, accepting continues.
           -> (SomeException -> IO ())       -- ^ Termination handler. Called,
                                             --   after the server socket has
                                             --   been closed, when the error
                                             --   handler throws an exception.
           -> (IO () -> (N.Socket, N.SockAddr) -> IO ())
                                             -- ^ Request handler. Gets an
                                             --   action which completes when
                                             --   the socket is closed.
           -> IO (N.ServiceName, ThreadId)
forkServer host port backlog reuseAddr errorHandler terminationHandler requestHandler = do
    -- Resolve the specified address. By specification, getAddrInfo will never
    -- return an empty list (but will throw an exception instead) and will return
    -- the "best" address first, whatever that means
    addr:_ <- N.getAddrInfo (Just N.defaultHints) (Just host) (Just port)
    bracketOnError (N.socket (N.addrFamily addr) N.Stream N.defaultProtocol)
                   tryCloseSocket $ \sock -> do
      when reuseAddr $ N.setSocketOption sock N.ReuseAddr 1
      N.bindSocket sock (N.addrAddress addr)
      N.listen sock backlog

      -- Close the client socket, then fill the synchronizing MVar so that the
      -- "socket closed" action handed to the request handler completes. The
      -- MVar is filled even if closing the socket throws.
      let release :: ((N.Socket, N.SockAddr), MVar ()) -> IO ()
          release ((clientSock, _), socketClosed) =
            N.sClose clientSock `finally` putMVar socketClosed ()

      -- Run the request handler in its own thread. 'act' is called with
      -- asynchronous exceptions masked and 'forkIO' makes the child inherit
      -- that masking state, so the 'finally' below is in place before any
      -- asynchronous exception can be delivered; only the handler itself runs
      -- with the original masking state (via 'restore'). This guarantees the
      -- client socket is released even if the thread is interrupted before
      -- the handler starts.
      let act restore (clientSock, clientAddr) = do
            socketClosed <- newEmptyMVar
            void $ forkIO $
              restore (requestHandler (readMVar socketClosed) (clientSock, clientAddr))
                `finally` release ((clientSock, clientAddr), socketClosed)

      let acceptRequest :: IO ()
          acceptRequest = mask $ \restore -> do
            -- Async exceptions are masked so that, if accept does give a
            -- socket, we always hand it to 'act' (whose thread takes
            -- ownership and closes it) before any asynchronous exception is
            -- raised.
            (clientSock, clientAddr) <- N.accept sock
            -- 'act' is not expected to throw, but if it does then no handler
            -- thread owns the client socket, so close it here.
            let handler :: SomeException -> IO ()
                handler _ = N.sClose clientSock
            catch (act restore (clientSock, clientAddr)) handler

      -- We start listening for incoming requests in a separate thread. When
      -- an exception escapes the accept loop (i.e. the error handler throws),
      -- we close the server socket and the termination handler is run. We
      -- have to make sure that the exception handler is installed /before/
      -- any asynchronous exception occurs. So we mask_, then fork (the child
      -- thread inherits the masked state from the parent), then unmask only
      -- inside the catch.
      (,) <$> fmap show (N.socketPort sock) <*>
        (mask_ $ forkIOWithUnmask $ \unmask ->
          catch (unmask (forever (catch acceptRequest errorHandler))) $ \ex -> do
            tryCloseSocket sock
            terminationHandler ex)

-- | Receive a length-prefixed payload.
--
-- First a 'Word32' length is read with 'recvWord32'. If it exceeds @limit@ an
-- I/O error ('userError') is thrown without reading the payload; otherwise
-- exactly that many bytes are read with 'recvExact' and returned as chunks.
recvWithLength :: Word32 -> N.Socket -> IO [ByteString]
recvWithLength limit sock = do
  payloadLen <- recvWord32 sock
  if payloadLen > limit
    then throwIO (userError "recvWithLength: limit exceeded")
    else recvExact sock payloadLen

-- | Receive a 32-bit unsigned integer: read exactly four bytes and decode
-- them with 'decodeWord32'.
recvWord32 :: N.Socket -> IO Word32
recvWord32 sock = do
  chunks <- recvExact sock 4
  return (decodeWord32 (BS.concat chunks))

-- | Close a socket. Any 'IOException' raised while closing is discarded;
-- other exceptions propagate.
tryCloseSocket :: N.Socket -> IO ()
tryCloseSocket = void . tryIO . N.sClose

-- | Shut down both directions (sends and receives) of a socket. Any
-- 'IOException' raised by the shutdown is discarded; other exceptions
-- propagate.
tryShutdownSocketBoth :: N.Socket -> IO ()
tryShutdownSocketBoth sock =
  void (tryIO (N.shutdown sock N.ShutdownBoth))

-- | Read an exact number of bytes from a socket.
--
-- Each underlying receive asks for at most 'smallChunkSize' bytes. The chunks
-- are returned in the order they were received.
--
-- Throws an I/O exception ('userError') if the socket closes before the
-- specified number of bytes could be read.
recvExact :: N.Socket        -- ^ Socket to read from
          -> Word32          -- ^ Number of bytes to read
          -> IO [ByteString] -- ^ Data read
recvExact sock = loop []
  where
    -- @loop received remaining@: @received@ holds the chunks read so far,
    -- newest first; @remaining@ is the number of bytes still to read.
    loop :: [ByteString] -> Word32 -> IO [ByteString]
    loop received 0 = return (reverse received)
    loop received remaining = do
      chunk <- NBS.recv sock (min smallChunkSize (fromIntegral remaining))
      -- An empty read means the peer closed the connection.
      when (BS.null chunk) $
        throwIO (userError "recvExact: Socket closed")
      loop (chunk : received) (remaining - fromIntegral (BS.length chunk))

-- | Produce a host name and service name from a 'N.SockAddr'.
--
-- Only IPv4 addresses ('N.SockAddrInet') are handled: the host is rendered
-- with 'N.inet_ntoa' and the port with 'show'. Every other kind of address
-- yields 'Nothing'.
decodeSockAddr :: N.SockAddr -> IO (Maybe (N.HostName, N.ServiceName))
decodeSockAddr (N.SockAddrInet port host) =
  fmap (\hostString -> Just (hostString, show port)) (N.inet_ntoa host)
decodeSockAddr _ =
  return Nothing

-- | Encode an end point address as @host:port:endPointId@, with the endpoint
-- id rendered in decimal by 'show'.
encodeEndPointAddress :: N.HostName
                      -> N.ServiceName
                      -> EndPointId
                      -> EndPointAddress
encodeEndPointAddress host port ix =
  EndPointAddress (BSC.pack (concat [host, ":", port, ":", show ix]))

-- | Decode end point address
--
-- Inverse of 'encodeEndPointAddress': the address is split at its last two
-- colons into host, port and endpoint id. The host may itself contain colons
-- (e.g. an IPv6 literal), which is why splitting starts from the end.
--
-- Returns 'Nothing' if the address does not have three components, or if the
-- endpoint id is not a non-empty string of decimal digits whose value fits in
-- an 'EndPointId'.
decodeEndPointAddress :: EndPointAddress
                      -> Maybe (N.HostName, N.ServiceName, EndPointId)
decodeEndPointAddress (EndPointAddress bs) =
  case splitMaxFromEnd (== ':') 2 $ BSC.unpack bs of
    [host, port, endPointIdStr] ->
      fmap (\endPointId -> (host, port, endPointId))
           (parseEndPointId endPointIdStr)
    _ ->
      Nothing
  where
    -- Only plain decimal digits are accepted. 'reads' at type 'EndPointId'
    -- would also accept a minus sign, leading whitespace, parentheses and
    -- hex/octal literals, and would silently wrap out-of-range values (so
    -- "-1" or "4294967296" would decode to a valid-looking id). Parsing as
    -- 'Integer' and range-checking rejects those instead.
    parseEndPointId :: String -> Maybe EndPointId
    parseEndPointId str
      | null str || not (all (`elem` ['0' .. '9']) str) = Nothing
      | otherwise =
          case reads str :: [(Integer, String)] of
            [(n, "")] | n <= toInteger (maxBound :: EndPointId) ->
              Just (fromInteger n)
            _ -> Nothing

-- | @splitMaxFromEnd p n xs@ splits list @xs@ at elements matching @p@,
-- performing at most @n@ splits counting from the /end/, and so returning at
-- most @n + 1@ segments. The matching elements themselves are dropped. A
-- negative @n@ places no limit on the number of splits.
--
-- > splitMaxFromEnd (== ':') 2 "ab:cd:ef:gh" == ["ab:cd", "ef", "gh"]
splitMaxFromEnd :: (a -> Bool) -> Int -> [a] -> [[a]]
splitMaxFromEnd p = \n -> go [] [] n . reverse
  where
    -- @go cur done k rest@ walks the input backwards. @rest@ is the
    -- not-yet-visited prefix (in reverse order), @cur@ the segment currently
    -- being built (already in original order), @done@ the completed
    -- segments to its right, and @k@ the number of splits still allowed.
    -- Once no splits remain, everything left forms the leftmost segment.
    go cur done _ []       = cur : done
    go cur done 0 rest     = (reverse rest ++ cur) : done
    go cur done k (x : xs)
      | p x       = go [] (cur : done) (k - 1) xs
      | otherwise = go (x : cur) done k xs