{-# LANGUAGE FlexibleContexts    #-}
{-# LANGUAGE FlexibleInstances   #-}
{-# LANGUAGE ScopedTypeVariables #-}
{-# LANGUAGE StandaloneDeriving  #-}
{-# LANGUAGE TypeFamilies        #-}
--------------------------------------------------------------------------------
-- |
-- Module      :  Network.MQTT.Broker.Server
-- Copyright   :  (c) Lars Petersen 2016
-- License     :  MIT
--
-- Maintainer  :  info@lars-petersen.net
-- Stability   :  experimental
--------------------------------------------------------------------------------
module Network.MQTT.Broker.Server
  ( serveConnection
  , MQTT ()
  , MqttServerTransportStack (..)
  , SS.Server ( .. )
  , SS.ServerConfig ( .. )
  , SS.ServerConnection ( .. )
  , SS.ServerException ( .. )
  ) where

import           Control.Concurrent
import           Control.Concurrent.Async
import qualified Control.Exception                  as E
import           Control.Monad
import qualified Data.Binary.Get                    as SG
import qualified Data.ByteString                    as BS
import           Data.Int
import           Data.IORef
import           Data.Typeable
import qualified Network.Stack.Server               as SS
import qualified Network.WebSockets                 as WS
import qualified System.Log.Logger                  as Log
import qualified System.Socket                      as S

import           Network.MQTT.Message
import           Network.MQTT.Broker.Authentication
import qualified Network.MQTT.Broker.Internal       as Session
import qualified Network.MQTT.Broker                as Broker
import qualified Network.MQTT.Broker.Session        as Session

instance (Typeable transport) => E.Exception (SS.ServerException (MQTT transport))

data MQTT transport

class SS.ServerStack a => MqttServerTransportStack a where
  getConnectionRequest :: SS.ServerConnectionInfo a -> IO ConnectionRequest

instance (Typeable f, Typeable t, Typeable p, S.Family f, S.Protocol p, S.Type t, S.HasNameInfo f) => MqttServerTransportStack (S.Socket f t p) where
  getConnectionRequest (SS.SocketServerConnectionInfo addr) = do
    remoteAddr <- S.hostName <$> S.getNameInfo addr (S.niNumericHost `mappend` S.niNumericService)
    pure ConnectionRequest {
        requestClientIdentifier = ClientIdentifier mempty
      , requestSecure = False
      , requestCleanSession = True
      , requestCredentials = Nothing
      , requestHttp = Nothing
      , requestCertificateChain = Nothing
      , requestRemoteAddress = Just remoteAddr
      }

instance (SS.StreamServerStack a, MqttServerTransportStack a) => MqttServerTransportStack (SS.WebSocket a) where
  getConnectionRequest (SS.WebSocketServerConnectionInfo tci rh) = do
    req <- getConnectionRequest tci
    pure req {
        requestHttp = Just (WS.requestPath rh, WS.requestHeaders rh)
      }

instance (SS.StreamServerStack a, MqttServerTransportStack a) => MqttServerTransportStack (SS.TLS a) where
  getConnectionRequest (SS.TlsServerConnectionInfo tci mcc) = do
    req <- getConnectionRequest tci
    pure req {
        requestSecure = True
      , requestCertificateChain = mcc
      }

instance (SS.StreamServerStack transport) => SS.ServerStack (MQTT transport) where
  data Server (MQTT transport) = MqttServer
    { mqttTransportServer     :: SS.Server transport
    }
  data ServerConfig (MQTT transport) = MqttServerConfig
    { mqttTransportConfig     :: SS.ServerConfig transport
    }
  data ServerConnection (MQTT transport) = MqttServerConnection
    { mqttTransportConnection :: SS.ServerConnection transport
    , mqttTransportLeftover   :: MVar BS.ByteString
    }
  data ServerConnectionInfo (MQTT transport) = MqttServerConnectionInfo
    { mqttTransportServerConnectionInfo :: SS.ServerConnectionInfo transport
    }
  data ServerException (MQTT transport)
    = ProtocolViolation String
    | MessageTooLong
    | ConnectionRejected RejectReason
    | KeepAliveTimeoutException
    deriving (Eq, Ord, Show, Typeable)
  withServer config handle =
    SS.withServer (mqttTransportConfig config) $ \server->
      handle (MqttServer server)
  withConnection server handler =
    SS.withConnection (mqttTransportServer server) $ \connection info->
      flip handler (MqttServerConnectionInfo info) =<< MqttServerConnection
        <$> pure connection
        <*> newMVar mempty

-- TODO: eventually too strict with message size tracking
instance (SS.StreamServerStack transport) => SS.MessageServerStack (MQTT transport) where
  type ClientMessage (MQTT transport) = ClientPacket
  type ServerMessage (MQTT transport) = ServerPacket
  sendMessage connection =
    SS.sendStreamBuilder (mqttTransportConnection connection) 8192 . serverPacketBuilder
  sendMessages connection msgs =
    SS.sendStreamBuilder (mqttTransportConnection connection) 8192 $ foldl (\b m-> b `mappend` serverPacketBuilder m) mempty msgs
  receiveMessage connection maxMsgSize =
    modifyMVar (mqttTransportLeftover connection) (execute 0 . SG.pushChunk decode)
    where
      fetch  = SS.receiveStream (mqttTransportConnection connection) 4096
      decode = SG.runGetIncremental clientPacketParser
      execute received result
        | received > maxMsgSize = E.throwIO (MessageTooLong :: SS.ServerException (MQTT transport))
        | otherwise = case result of
            SG.Partial continuation -> do
              bs <- fetch
              if BS.null bs
                then execute received (continuation Nothing)
                else execute (received + fromIntegral (BS.length bs)) (continuation $ Just bs)
            SG.Fail _ _ failure ->
              E.throwIO (ProtocolViolation failure :: SS.ServerException (MQTT transport))
            SG.Done leftover' _ msg ->
              pure (leftover', msg)
  consumeMessages connection maxMsgSize consume =
    modifyMVar_ (mqttTransportLeftover connection) (execute 0 . SG.pushChunk decode)
    where
      fetch  = SS.receiveStream (mqttTransportConnection connection) 4096
      decode = SG.runGetIncremental clientPacketParser
      execute received result
        | received > maxMsgSize = E.throwIO (MessageTooLong :: SS.ServerException (MQTT transport))
        | otherwise = case result of
            SG.Partial continuation -> do
              bs <- fetch
              if BS.null bs
                then execute received (continuation Nothing)
                else execute (received + fromIntegral (BS.length bs)) (continuation $ Just bs)
            SG.Fail _ _ failure ->
              E.throwIO (ProtocolViolation failure :: SS.ServerException (MQTT transport))
            SG.Done leftover' _ msg -> do
              done <- consume msg
              if done
                then pure leftover'
                else execute 0 (SG.pushChunk decode leftover')

deriving instance Show (SS.ServerConnectionInfo transport) => Show (SS.ServerConnectionInfo (MQTT transport))

serveConnection :: forall transport auth. (SS.StreamServerStack transport, MqttServerTransportStack transport, Authenticator auth) => Broker.Broker auth -> SS.ServerConnection (MQTT transport) -> SS.ServerConnectionInfo (MQTT transport) -> IO ()
serveConnection broker conn connInfo = do
  recentActivity <- newIORef True
  req <- getConnectionRequest (mqttTransportServerConnectionInfo connInfo)
  msg <- SS.receiveMessage conn maxInitialPacketSize
  case msg of
    ClientConnectUnsupported -> do
      Log.warningM "Server.connection" $ "Connection from "
        ++ show (requestRemoteAddress req) ++ " rejected: UnacceptableProtocolVersion"
      void $ SS.sendMessage conn (ServerConnectionRejected UnacceptableProtocolVersion)
      -- Communication ends here gracefully. The caller shall close the connection.
    ClientConnect {} -> do
      let -- | This one is called when the authenticator decided to reject the request.
          sessionRejectHandler reason = do
            Log.warningM "Server.connection" $ "Connection rejected: " ++ show reason
            void $ SS.sendMessage conn (ServerConnectionRejected reason)
            -- Communication ends here gracefully. The caller shall close the connection.

          -- | This part is where the threads for a connection are created
          --   (one for input, one for output and one watchdog thread).
          sessionAcceptHandler session sessionPresent@(SessionPresent sp) = do
            principal <- Session.getPrincipal session
            Log.infoM "Server.connection" $ "Connection accepted: Associated "
              ++ show principal ++ (if sp then " with existing session "
              ++ show (Session.sessionIdentifier session) ++  "." else " with new session.")
            void $ SS.sendMessage conn (ServerConnectionAccepted sessionPresent)
            foldl1 race_
              [ handleInput recentActivity session
              , handleOutput session
              , keepAlive recentActivity (connectKeepAlive msg) session
              ] `E.catch` (\e-> do
                Log.warningM "Server.connection" $"Session " ++ show (Session.sessionIdentifier session)
                  ++ ": Connection terminated with exception: " ++ show (e :: E.SomeException)
                E.throwIO e
              )
            Log.infoM "Server.connection" $
              "Session " ++ show (Session.sessionIdentifier session) ++ ": Graceful disconnect."
          -- Extend the request object with information gathered from the connect packet.
          request = req {
              requestClientIdentifier = connectClientIdentifier msg
            , requestCleanSession     = cleanSession
            , requestCredentials      = connectCredentials msg
            }
            where
              CleanSession cleanSession = connectCleanSession msg
      Log.infoM "Server.connection" $ "Connection request: " ++ show request
      Broker.withSession broker request sessionRejectHandler sessionAcceptHandler
    _ -> pure () -- TODO: Don't parse not-CONN packets in the first place!
  where
    -- The size of the initial CONN packet shall somewhat be limited to a moderate size.
    -- This value is assumed to make no problems while still protecting
    -- the servers resources against exaustion attacks.
    maxInitialPacketSize :: Int64
    maxInitialPacketSize  = 65535
    -- The keep alive thread wakes up every `keepAlive/2` seconds.
    -- When it detects no recent activity, it sleeps one more full `keepAlive`
    -- interval and checks again. When it still finds no recent activity, it
    -- throws an exception.
    -- That way a timeout will be detected between 1.5 and 2 `keep alive`
    -- intervals after the last actual client activity.
    keepAlive :: IORef Bool -> KeepAliveInterval -> Session.Session auth -> IO ()
    keepAlive recentActivity (KeepAliveInterval interval) session = forever $ do
      writeIORef recentActivity False
      threadDelay regularInterval
      activity <- readIORef recentActivity
      unless activity $ do
        threadDelay regularInterval
        activity' <- readIORef recentActivity
        unless activity' $ do
          -- Alert state: The client must get active within the next interval.
          Log.warningM "Server.connection.keepAlive" $ "Session " ++ show (Session.sessionIdentifier session) ++ ": Client is overdue."
          threadDelay regularInterval
          activity'' <- readIORef recentActivity
          unless activity'' $ E.throwIO (KeepAliveTimeoutException :: SS.ServerException (MQTT transport))
      where
        regularInterval = fromIntegral interval *  500000
    -- | This thread is responsible for continuously processing input.
    --   It blocks on reading the input stream until input becomes available.
    --   Input is consumed no faster than it can be processed.
    --
    --   * Read packet from the input stream.
    --   * Note that there was activity (TODO: a timeout may occur when a packet
    --     takes too long to transmit due to its size).
    --   * Process and dispatch the message internally.
    --   * Repeat and eventually wait again.
    --   * Eventually throws `ServerException`s.
    handleInput :: IORef Bool -> Session.Session auth -> IO ()
    handleInput recentActivity session = do
      maxPacketSize <- fromIntegral . quotaMaxPacketSize . principalQuota <$> Session.getPrincipal session
      SS.consumeMessages conn maxPacketSize $ \packet-> do
        writeIORef recentActivity True
        --Log.debugM "Server.connection.handleInput" $ take 50 $ show packet
        case packet of
          ClientConnect {} ->
            E.throwIO (ProtocolViolation "Unexpected CONN packet." :: SS.ServerException (MQTT transport))
          ClientConnectUnsupported ->
            E.throwIO (ProtocolViolation "Unexpected CONN packet (of unsupported protocol version)." :: SS.ServerException (MQTT transport))
          ClientPublish pid dup msg -> do
            Session.processPublish session pid dup msg
            pure False
          ClientPublishAcknowledged pid -> do
            Session.processPublishAcknowledged session pid
            pure False
          ClientPublishReceived pid -> do
            Session.processPublishReceived session pid
            pure False
          ClientPublishRelease pid -> do
            Session.processPublishRelease session pid
            pure False
          ClientPublishComplete pid -> do
            Session.processPublishComplete session pid
            pure False
          ClientSubscribe pid filters -> do
            Session.subscribe session pid filters
            pure False
          ClientUnsubscribe pid filters -> do
            Session.unsubscribe session pid filters
            pure False
          ClientPingRequest -> do
            Session.enqueuePingResponse session
            pure False
          ClientDisconnect ->
            pure True
    -- | This thread is responsible for continuously transmitting to the client
    --   and reading from the output queue.
    --
    --   * It blocks on Session.waitPending until output gets available.
    --   * When output is available, it fetches a whole sequence of messages
    --     from the output queue.
    --   * It then uses the optimized SS.sendMessages operation which fills
    --     a whole chunk with as many messages as possible and sends the chunks
    --     each with a single system call. This is _very_ important for high
    --     throughput.
    --   * Afterwards, it repeats and eventually waits again.
    handleOutput :: Session.Session auth -> IO ()
    handleOutput session = forever $ do
      -- The `waitPending` operation is blocking until messages get available.
      Session.waitPending session
      msgs <- Session.dequeue session
      SS.sendMessages conn msgs