-- | Utility functions for TCP sockets
module Network.Transport.TCP.Internal
  ( ControlHeader(..)
  , encodeControlHeader
  , decodeControlHeader
  , ConnectionRequestResponse(..)
  , encodeConnectionRequestResponse
  , decodeConnectionRequestResponse
  , forkServer
  , recvWithLength
  , recvExact
  , recvWord32
  , encodeWord32
  , tryCloseSocket
  , tryShutdownSocketBoth
  , resolveSockAddr
  , EndPointId
  , encodeEndPointAddress
  , decodeEndPointAddress
  , randomEndPointAddress
  , ProtocolVersion
  , currentProtocolVersion
  ) where

#if ! MIN_VERSION_base(4,6,0)
import Prelude hiding (catch)
#endif

import Network.Transport.Internal
  ( decodeWord32
  , encodeWord32
  , void
  , tryIO
  , forkIOWithUnmask
  )

import Network.Transport ( EndPointAddress(..) )

#ifdef USE_MOCK_NETWORK
import qualified Network.Transport.TCP.Mock.Socket as N
#else
import qualified Network.Socket as N
#endif
  ( HostName
  , NameInfoFlag(NI_NUMERICHOST)
  , ServiceName
  , Socket
  , SocketType(Stream)
  , SocketOption(ReuseAddr)
  , getAddrInfo
  , defaultHints
  , socket
  , bind
  , listen
  , addrFamily
  , addrAddress
  , defaultProtocol
  , setSocketOption
  , accept
  , close
  , socketPort
  , shutdown
  , ShutdownCmd(ShutdownBoth)
  , SockAddr(..)
  , getNameInfo
  )

#ifdef USE_MOCK_NETWORK
import qualified Network.Transport.TCP.Mock.Socket.ByteString as NBS (recv)
#else
import qualified Network.Socket.ByteString as NBS (recv)
#endif

import Data.Word (Word32, Word64)

import Control.Monad (forever, when)
import Control.Exception (SomeException, catch, bracketOnError, throwIO, mask_)
import Control.Concurrent (ThreadId, forkIO)
import Control.Concurrent.MVar
  ( MVar
  , newEmptyMVar
  , putMVar
  , readMVar
  )
import Control.Monad (forever, when)
import Control.Exception
  ( SomeException
  , catch
  , bracketOnError
  , throwIO
  , mask_
  , mask
  , finally
  , onException
  )

import Control.Applicative ((<$>), (<*>))
import Data.Word (Word32)
import Data.ByteString (ByteString)
import qualified Data.ByteString as BS (length, concat, null)
import Data.ByteString.Lazy.Internal (smallChunkSize)
import Data.ByteString.Lazy (toStrict)
import qualified Data.ByteString.Char8 as BSC (unpack, pack)
import Data.ByteString.Lazy.Builder (word64BE, toLazyByteString)
import Data.Monoid ((<>))
import qualified Data.UUID as UUID
import qualified Data.UUID.V4 as UUID

-- | Identifier of an endpoint. It is only meaningful locally, within the
-- transport that allocated it.
type EndPointId = Word32

-- | Version number of the network-transport-tcp protocol. It is the very
-- first piece of data sent once a new heavyweight connection has been
-- established.
type ProtocolVersion = Word32

-- | The protocol version implemented by this module.
currentProtocolVersion :: ProtocolVersion
currentProtocolVersion = 0

-- | Control headers sent to the remote endpoint. Their numeric codes are
-- defined by 'encodeControlHeader' and 'decodeControlHeader'.
data ControlHeader
  = CreatedNewConnection
    -- ^ Tell the remote endpoint that we created a new connection
  | CloseConnection
    -- ^ Tell the remote endpoint we will no longer be using a connection
  | CloseSocket
    -- ^ Request to close the connection (see module description)
  | CloseEndPoint
    -- ^ Sent by an endpoint when it is closed.
  | ProbeSocket
    -- ^ Message sent to probe a socket
  | ProbeSocketAck
    -- ^ Acknowledgement of the 'ProbeSocket' message
  deriving Show

-- | Decode a control header from its numeric code. Returns 'Nothing' for
-- any code that does not correspond to a 'ControlHeader' constructor.
decodeControlHeader :: Word32 -> Maybe ControlHeader
decodeControlHeader 0 = Just CreatedNewConnection
decodeControlHeader 1 = Just CloseConnection
decodeControlHeader 2 = Just CloseSocket
decodeControlHeader 3 = Just CloseEndPoint
decodeControlHeader 4 = Just ProbeSocket
decodeControlHeader 5 = Just ProbeSocketAck
decodeControlHeader _ = Nothing

-- | Numeric code of a control header; the inverse of 'decodeControlHeader'.
encodeControlHeader :: ControlHeader -> Word32
encodeControlHeader CreatedNewConnection = 0
encodeControlHeader CloseConnection      = 1
encodeControlHeader CloseSocket          = 2
encodeControlHeader CloseEndPoint        = 3
encodeControlHeader ProbeSocket          = 4
encodeControlHeader ProbeSocketAck       = 5

-- | Response sent by /B/ to /A/ when /A/ tries to connect. Numeric codes are
-- defined by 'encodeConnectionRequestResponse' and
-- 'decodeConnectionRequestResponse'.
data ConnectionRequestResponse
  = ConnectionRequestUnsupportedVersion
    -- ^ /B/ does not support the protocol version requested by /A/.
  | ConnectionRequestAccepted
    -- ^ /B/ accepts the connection
  | ConnectionRequestInvalid
    -- ^ /A/ requested an invalid endpoint
  | ConnectionRequestCrossed
    -- ^ /A/'s request crossed with a request from /B/ (see protocols)
  | ConnectionRequestHostMismatch
    -- ^ /A/ gave an incorrect host (did not match the host that /B/ observed).
  deriving Show

-- | Decode a connection request response from its numeric code. Unknown
-- codes yield 'Nothing'. Note that "unsupported version" uses the all-ones
-- code rather than the next small integer.
decodeConnectionRequestResponse :: Word32 -> Maybe ConnectionRequestResponse
decodeConnectionRequestResponse 0xFFFFFFFF = Just ConnectionRequestUnsupportedVersion
decodeConnectionRequestResponse 0x00000000 = Just ConnectionRequestAccepted
decodeConnectionRequestResponse 0x00000001 = Just ConnectionRequestInvalid
decodeConnectionRequestResponse 0x00000002 = Just ConnectionRequestCrossed
decodeConnectionRequestResponse 0x00000003 = Just ConnectionRequestHostMismatch
decodeConnectionRequestResponse _          = Nothing

-- | Numeric code of a connection request response; the inverse of
-- 'decodeConnectionRequestResponse'.
encodeConnectionRequestResponse :: ConnectionRequestResponse -> Word32
encodeConnectionRequestResponse ConnectionRequestUnsupportedVersion = 0xFFFFFFFF
encodeConnectionRequestResponse ConnectionRequestAccepted           = 0x00000000
encodeConnectionRequestResponse ConnectionRequestInvalid            = 0x00000001
encodeConnectionRequestResponse ConnectionRequestCrossed            = 0x00000002
encodeConnectionRequestResponse ConnectionRequestHostMismatch       = 0x00000003

-- | Generate an 'EndPointAddress' which does not encode a
-- host/port/endpointid: its contents are the ASCII text of a fresh random
-- (version 4) UUID.
--
-- Such addresses are used for unreachable endpoints, and for ephemeral
-- addresses when such endpoints establish new heavyweight connections.
randomEndPointAddress :: IO EndPointAddress
randomEndPointAddress = EndPointAddress . UUID.toASCIIBytes <$> UUID.nextRandom

-- | Start a server at the specified address.
--
-- This sets up a server socket for the specified host and port. Exceptions
-- thrown during setup are not caught.
--
-- Once the socket is created we spawn a new thread which repeatedly accepts
-- incoming connections and executes the given request handler in another
-- thread. An exception raised while accepting a connection is passed to the
-- error handler, after which the accepting thread keeps accepting. Because
-- that handler receives any 'SomeException', this includes asynchronous
-- exceptions delivered during an accept attempt (for instance when the
-- thread is killed while waiting in 'N.accept'); the error handler must
-- rethrow such an exception if the server is meant to stop. If an exception
-- escapes the accept loop (typically because the error handler threw), the
-- accepting thread closes the server socket, calls the termination handler
-- with that exception, and exits. Threads spawned for previously accepted
-- connections are not killed.
--
-- The request handler is not responsible for closing the socket. It will be
-- closed once that handler returns (or throws). Take care to ensure that the
-- socket is not used after the handler returns, or you will get undefined
-- behavior (the file descriptor may be re-used).
--
-- The return value includes the port that was bound to, and the 'ThreadId'
-- of the accepting thread. The port is not always the same as the one given
-- in the argument. For example, binding to port 0 actually binds to a random
-- port, selected by the OS.
forkServer :: N.HostName                     -- ^ Host
           -> N.ServiceName                  -- ^ Port
           -> Int                            -- ^ Backlog (maximum number of queued connections)
           -> Bool                           -- ^ Set ReuseAddr option?
           -> (SomeException -> IO ())       -- ^ Error handler. Called with an
                                             --   exception raised when
                                             --   accepting a connection.
           -> (SomeException -> IO ())       -- ^ Termination handler. Called
                                             --   with the exception that
                                             --   stopped the accepting thread,
                                             --   e.g. one thrown by the error
                                             --   handler.
           -> (IO () -> (N.Socket, N.SockAddr) -> IO ())
                                             -- ^ Request handler. Gets an
                                             --   action which completes when
                                             --   the socket is closed, and the
                                             --   socket/address pair returned
                                             --   by 'N.accept'.
           -> IO (N.ServiceName, ThreadId)
forkServer host port backlog reuseAddr errorHandler terminationHandler requestHandler = do
    -- Resolve the specified address. By specification, getAddrInfo will never
    -- return an empty list (but will throw an exception instead) and will return
    -- the "best" address first, whatever that means
    addr:_ <- N.getAddrInfo (Just N.defaultHints) (Just host) (Just port)
    bracketOnError (N.socket (N.addrFamily addr) N.Stream N.defaultProtocol)
                   tryCloseSocket $ \sock -> do
      when reuseAddr $ N.setSocketOption sock N.ReuseAddr 1
      N.bind sock (N.addrAddress addr)
      N.listen sock backlog

      -- Close an accepted socket, then fill the synchronizing MVar so that
      -- the action handed to the request handler completes. The MVar is
      -- filled even if closing the socket throws.
      let release :: ((N.Socket, N.SockAddr), MVar ()) -> IO ()
          release ((conn, _), socketClosed) =
            N.close conn `finally` putMVar socketClosed ()

      -- Run the request handler in a new thread. 'act' is only called from
      -- inside 'mask' (see 'acceptRequest'), and a thread created by 'forkIO'
      -- inherits its parent's masking state, so the 'finally' below is in
      -- place before any asynchronous exception can reach the new thread.
      -- Only the request handler itself runs with the masking state restored.
      -- This guarantees the accepted socket is released however the thread
      -- ends.
      let act restore (conn, connAddr) = do
            socketClosed <- newEmptyMVar
            void $ forkIO $
              restore (requestHandler (readMVar socketClosed) (conn, connAddr))
              `finally`
              release ((conn, connAddr), socketClosed)

      let acceptRequest :: IO ()
          acceptRequest = mask $ \restore -> do
            -- Async exceptions are masked so that, if accept does give a
            -- socket, we always hand it to 'act' (and therefore to a thread
            -- that will close it) before any such exception is raised.
            (conn, connAddr) <- N.accept sock
            -- Looks like 'act' will never throw an exception, but to be
            -- safe we'll close the accepted socket if it does.
            let handler :: SomeException -> IO ()
                handler _ = N.close conn
            catch (act restore (conn, connAddr)) handler

      -- We start listening for incoming requests in a separate thread. When
      -- that thread is killed, we close the server socket and the termination
      -- handler is run. We have to make sure that the exception handler is
      -- installed /before/ any asynchronous exception occurs. So we mask_, then
      -- fork (the child thread inherits the masked state from the parent), then
      -- unmask only inside the catch.
      (,) <$> fmap show (N.socketPort sock) <*>
        (mask_ $ forkIOWithUnmask $ \unmask ->
          catch (unmask (forever (catch acceptRequest errorHandler))) $ \ex -> do
            tryCloseSocket sock
            terminationHandler ex)

-- | Receive a length-prefixed payload.
--
-- A 'Word32' length is read first. If it is greater than @limit@, an
-- 'IOError' ('userError') is thrown. Otherwise exactly that many bytes are
-- read and returned as a list of chunks.
recvWithLength :: Word32 -> N.Socket -> IO [ByteString]
recvWithLength limit sock = do
  payloadLen <- recvWord32 sock
  if payloadLen > limit
    then throwIO (userError "recvWithLength: limit exceeded")
    else recvExact sock payloadLen

-- | Receive a 32-bit unsigned integer: read exactly four bytes from the
-- socket and decode them with 'decodeWord32'.
recvWord32 :: N.Socket -> IO Word32
recvWord32 sock = decodeWord32 . BS.concat <$> recvExact sock 4

-- | Close a socket. Any I/O exception raised while closing is discarded.
tryCloseSocket :: N.Socket -> IO ()
tryCloseSocket = void . tryIO . N.close

-- | Shut down both directions (send and receive) of a socket. Any I/O
-- exception raised by the shutdown is discarded.
tryShutdownSocketBoth :: N.Socket -> IO ()
tryShutdownSocketBoth = void . tryIO . flip N.shutdown N.ShutdownBoth

-- | Read exactly the requested number of bytes from a socket.
--
-- Data is requested in pieces of at most 'smallChunkSize' bytes. The chunks
-- are returned in the order they were received. Throws an I/O exception
-- ('userError') if the socket closes, i.e. a receive yields an empty chunk,
-- before all bytes have arrived.
recvExact :: N.Socket        -- ^ Socket to read from
          -> Word32          -- ^ Number of bytes to read
          -> IO [ByteString] -- ^ Data read
recvExact sock = loop []
  where
    loop :: [ByteString] -> Word32 -> IO [ByteString]
    loop chunks 0 = return (reverse chunks)
    loop chunks remaining = do
      chunk <- NBS.recv sock (min (fromIntegral remaining) smallChunkSize)
      when (BS.null chunk) $
        throwIO (userError "recvExact: Socket closed")
      loop (chunk : chunks) (remaining - fromIntegral (BS.length chunk))

-- | Get the numeric host, resolved host (via getNameInfo), and port from a
-- SockAddr. The numeric host is first, then resolved host (which may be the
-- same as the numeric host).
-- Will only give 'Just' for IPv4 addresses.
--
-- Calls 'error' if the name lookup returns an unexpected host/service
-- combination. Throws an 'IOError' if the numeric lookup returns no host.
resolveSockAddr :: N.SockAddr -> IO (Maybe (N.HostName, N.HostName, N.ServiceName))
resolveSockAddr sockAddr = case sockAddr of
  N.SockAddrInet port _ -> do
    -- Only the host is requested (the service flag is False), so the
    -- service part of the answer is expected to be 'Nothing'.
    (mResolvedHost, mResolvedPort) <- N.getNameInfo [] True False sockAddr
    case (mResolvedHost, mResolvedPort) of
      (Just resolvedHost, Nothing) -> do
        (mNumericHost, _) <- N.getNameInfo [N.NI_NUMERICHOST] True False sockAddr
        case mNumericHost of
          Just numericHost ->
            return $ Just (numericHost, resolvedHost, show port)
          -- Same exception type ('IOError') as the pattern-match failure
          -- this replaces, but with a message naming the function.
          Nothing -> ioError . userError $
            "resolveSockAddr: no numeric host for " ++ show sockAddr
      _ -> error $ concat [
          "resolveSockAddr: unexpected resolution "
        , show sockAddr
        , " -> "
        , show mResolvedHost
        , ", "
        , show mResolvedPort
        ]
  _ -> return Nothing

-- | Encode an end point address as the text @host:port:endpointid@, with the
-- endpoint id written in decimal (via 'show').
encodeEndPointAddress :: N.HostName
                      -> N.ServiceName
                      -> EndPointId
                      -> EndPointAddress
encodeEndPointAddress host port ix =
  EndPointAddress (BSC.pack (concat [host, ":", port, ":", show ix]))

-- | Decode an end point address of the form @host:port:endpointid@.
--
-- The text is split at the last two colons, so the host part may itself
-- contain colons. Returns 'Nothing' in three cases:
--
-- * there are fewer than two colons;
-- * the endpoint id does not parse as a number;
-- * the endpoint id lies outside the range of 'EndPointId'.
decodeEndPointAddress :: EndPointAddress
                      -> Maybe (N.HostName, N.ServiceName, EndPointId)
decodeEndPointAddress (EndPointAddress bs) =
  case splitMaxFromEnd (== ':') 2 $ BSC.unpack bs of
    [host, port, endPointIdStr] ->
      -- Parse via 'Integer' and range-check explicitly. Reading directly at
      -- the 'Word32' type silently wraps negative or too-large numbers
      -- (e.g. "-1" would decode as 4294967295).
      case reads endPointIdStr :: [(Integer, String)] of
        [(n, "")]
          | n >= 0 && n <= toInteger (maxBound :: EndPointId)
          -> Just (host, port, fromInteger n)
        _ -> Nothing
    _ ->
      Nothing

-- | @splitMaxFromEnd p n xs@ splits the list @xs@ at elements matching @p@.
-- It performs at most @n@ splits, counting from the /end/ of the list, so the
-- result has at most @n + 1@ segments. The matching elements themselves are
-- dropped. A negative @n@ places no limit on the number of splits.
--
-- > splitMaxFromEnd (== ':') 2 "ab:cd:ef:gh" == ["ab:cd", "ef", "gh"]
splitMaxFromEnd :: (a -> Bool) -> Int -> [a] -> [[a]]
splitMaxFromEnd isSep maxSplits = walk maxSplits [] [] . reverse
  where
    -- walk splitsLeft currentSegment finishedSegments reversedInput
    --
    -- The input is consumed back to front. Consing each element onto the
    -- current segment therefore restores the original element order.
    walk _ cur done []     = cur : done
    -- No splits left: the unconsumed prefix becomes the first segment.
    walk 0 [] done rest    = reverse rest : done
    -- Unreachable: the split count only reaches 0 right after a separator,
    -- when the current segment is empty.
    walk 0 _ _ _           = error "Bug in splitMaxFromEnd"
    walk k cur done (y : ys)
      | isSep y   = walk (k - 1) [] (cur : done) ys
      | otherwise = walk k (y : cur) done ys