{-|
Module      : PostgresWebsockets.Middleware
Description : PostgresWebsockets WAI middleware, add functionality to any WAI application.

Allows websocket connections that communicate with the database through LISTEN/NOTIFY channels.
-}
{-# LANGUAGE DeriveGeneric #-}

module PostgresWebsockets.Middleware
  ( postgresWsMiddleware
  ) where

import Protolude
import Data.Time.Clock (UTCTime)
import Data.Time.Clock.POSIX (utcTimeToPOSIXSeconds, posixSecondsToUTCTime)
import Control.Concurrent.AlarmClock (destroyAlarmClock, newAlarmClock, setAlarm)
import qualified Hasql.Notifications as H
import qualified Hasql.Pool as H
import qualified Network.Wai as Wai
import qualified Network.Wai.Handler.WebSockets as WS
import qualified Network.WebSockets as WS

import qualified Data.Aeson as A
import qualified Data.ByteString.Char8 as BS
import qualified Data.ByteString.Lazy as BL
import qualified Data.HashMap.Strict as M
import qualified Data.Text.Encoding.Error as T

import PostgresWebsockets.Broadcast (onMessage)
import PostgresWebsockets.Claims ( ConnectionInfo, validateClaims )
import PostgresWebsockets.Context ( Context(..) )
import PostgresWebsockets.Config (AppConfig(..))
import qualified PostgresWebsockets.Broadcast as B


-- | Kind of event carried by a 'Message' that is forwarded to the database.
data Event =
    WebsocketMessage -- ^ Data received from a websocket client and relayed to a channel.
  | ConnectionOpen -- ^ Sent on the configured meta channel (if any) when a session starts.
  deriving (Show, Eq, Generic)

-- | Envelope that is JSON-encoded and sent to the database as a notification.
--
-- The field order matters: the constructor is applied positionally elsewhere
-- in this module.
data Message = Message
  { claims  :: A.Object -- ^ Validated JWT claims; 'timestampMessage' adds @message_delivered_at@ to them.
  , event   :: Event -- ^ What kind of event this message represents.
  , payload :: Text -- ^ Data received from the client, or the comma-separated channel list for 'ConnectionOpen'.
  , channel :: Text -- ^ Channel the message is addressed to (the meta channel for 'ConnectionOpen').
  } deriving (Show, Eq, Generic)

-- Both instances have empty bodies, so they rely on the class defaults backed
-- by the derived Generic instances. 'sendToDatabase' uses them via 'A.encode'.
instance A.ToJSON Event
instance A.ToJSON Message

-- | Turn a 'Context' into a WAI middleware.
--
-- The websocket server application built from the context is handed to
-- 'WS.websocketsOr' together with the default connection options; the
-- application being wrapped is the fallback argument of that combinator.
postgresWsMiddleware :: Context -> Wai.Middleware
postgresWsMiddleware ctx =
  WS.websocketsOr WS.defaultConnectionOptions (wsApp ctx)

-- private functions
-- | Websocket close code sent to the client, together with the reason
-- @"JWT expired"@, when the expiration alarm of a session goes off.
jwtExpirationStatusCode :: Word16
jwtExpirationStatusCode = 3001

-- | Websocket server application run for every pending connection.
--
-- The request path carries the JWT and, optionally, a channel. The token is
-- checked with 'validateClaims' against the configured secret and the current
-- time: a failure rejects the request (and logs the reason to stderr), a
-- success accepts the connection and runs the session.
--
-- When the websocket is closed a ConnectionClosed exception is triggered;
-- this kills all children and frees resources for us.
wsApp :: Context -> WS.ServerApp
wsApp Context{..} pendingConn = do
  now <- ctxGetTime
  validation <- validateClaims requestChannel (configJwtSecret ctxConfig) (toS jwtToken) now
  either rejectRequest forkSessions validation
  where
    hasRead m = m == ("r" :: ByteString) || m == ("rw" :: ByteString)
    hasWrite m = m == ("w" :: ByteString) || m == ("rw" :: ByteString)

    rejectRequest :: Text -> IO ()
    rejectRequest msg = do
      putErrLn $ "Rejecting Request: " <> msg
      WS.rejectRequest pendingConn (toS msg)

    -- the URI has one of the two formats - /:jwt or /:channel/:jwt
    pathElements = BS.split '/' $ BS.drop 1 $ WS.requestPath $ WS.pendingRequest pendingConn
    -- With two or more path elements the first is the channel and the second
    -- the token; otherwise the only element (or "" if there is none) is the token.
    (requestChannel, jwtToken) =
      case pathElements of
        (ch : token : _) -> (Just ch, token)
        [token] -> (Nothing, token)
        [] -> (Nothing, "")

    forkSessions :: ConnectionInfo -> IO ()
    forkSessions (chs, mode, validClaims) = do
      -- We should accept only after verifying JWT
      conn <- WS.acceptRequest pendingConn
      -- Fork a pinging thread to ensure browser connections stay alive
      WS.withPingThread conn 30 (pure ()) $ withConnectionExpirer conn validClaims $ do
        let sendNotification msg channel = sendMessageWithTimestamp $ websocketMessageForChannel msg channel
            sendMessageToDatabase = sendToDatabase ctxPool (configListenChannel ctxConfig)
            sendMessageWithTimestamp = timestampMessage ctxGetTime >=> sendMessageToDatabase
            websocketMessageForChannel = Message validClaims WebsocketMessage
            connectionOpenMessage = Message validClaims ConnectionOpen

        -- Announce the new connection on the meta channel, when one is configured
        forM_ (configMetaChannel ctxConfig) $ \ch ->
          sendMessageWithTimestamp $ connectionOpenMessage (toS $ BS.intercalate "," chs) ch

        -- Readers get every multiplexer message of their channels pushed to the socket
        when (hasRead mode) $
          forM_ chs $ flip (onMessage ctxMulti) $ WS.sendTextData conn . B.payload

        -- Writers get the data they send relayed to the database
        when (hasWrite mode) $
          notifySession conn sendNotification chs

        -- Park this thread: nothing ever fills the MVar
        waitForever <- newEmptyMVar
        void $ takeMVar waitForever

-- | Run a session with an alarm that closes the websocket once the JWT expires.
--
-- When the claims hold a numeric @exp@, an alarm clock is set to that POSIX
-- time and sends a close frame with 'jwtExpirationStatusCode' when it goes off.
-- The alarm clock is destroyed as soon as the session action ends (normally
-- or through an exception), so it does not outlive the session. Without a
-- numeric @exp@ the session runs unchanged.
withConnectionExpirer :: WS.Connection -> A.Object -> IO a -> IO a
withConnectionExpirer conn validClaims session =
  case M.lookup "exp" validClaims of
    Just (A.Number expClaim) ->
      bracket (newAlarmClock $ const closeExpired) destroyAlarmClock $ \connectionExpirer -> do
        setAlarm connectionExpirer (posixSecondsToUTCTime $ realToFrac expClaim)
        session
    _ -> session
  where
    closeExpired = WS.sendCloseCode conn jwtExpirationStatusCode ("JWT expired" :: ByteString)

-- | Relay everything received on the websocket to every given channel.
--
-- Taking both the channels and a sending function may seem redundant, but it
-- lets this function ignore the claims structure and where the channels came
-- from, so all claims decoding can be done in the caller.
--
-- The relay loop runs in an async whose result is awaited, so this call only
-- returns by way of an exception.
notifySession :: WS.Connection -> (Text -> Text -> IO ()) -> [ByteString] -> IO ()
notifySession wsCon sendToChannel chs =
  withAsync relayLoop wait
  where
    relayLoop = forever $ do
      incoming <- WS.receiveData wsCon
      traverse_ (sendToChannel incoming . toS) chs

-- | JSON-encode a 'Message' and send it as a notification on @dbChannel@
-- through the given pool.
--
-- A failed notification is reported on stderr instead of being silently
-- discarded. It is still not fatal: the caller carries on either way.
sendToDatabase :: H.Pool -> Text -> Message -> IO ()
sendToDatabase pool dbChannel msg =
  H.notifyPool pool dbChannel (toS jsonMsg) >>= either reportFailure pure
  where
    jsonMsg = BL.toStrict $ A.encode msg
    reportFailure err =
      putErrLn $ "Failed to notify channel " <> dbChannel <> ": " <> show err

-- | Stamp a message with the current time.
--
-- The time produced by the given action is converted to POSIX seconds and
-- stored as a JSON number under the @message_delivered_at@ key of the
-- message claims; every other field is left untouched.
timestampMessage :: IO UTCTime -> Message -> IO Message
timestampMessage getTime msg = stamp <$> getTime
  where
    stamp now =
      msg { claims = M.insert "message_delivered_at" (deliveredAt now) (claims msg) }
    deliveredAt = A.Number . realToFrac . utcTimeToPOSIXSeconds