{-|
Module      : PostgresWebsockets.Middleware
Description : PostgresWebsockets WAI middleware, add functionality to any WAI application.

Allows websocket connections that communicate with the database through LISTEN/NOTIFY channels.
-}
{-# LANGUAGE DeriveGeneric #-}

module PostgresWebsockets.Middleware
  ( postgresWsMiddleware
  ) where

import qualified Hasql.Pool                     as H
import qualified Hasql.Notifications            as H
import qualified Network.Wai                    as Wai
import qualified Network.Wai.Handler.WebSockets as WS
import qualified Network.WebSockets             as WS
import           Protolude

import qualified Data.Aeson                     as A
import qualified Data.ByteString.Char8          as BS
import qualified Data.ByteString.Lazy           as BL
import qualified Data.HashMap.Strict            as M
import qualified Data.Text.Encoding.Error       as T
import           Data.Time.Clock (UTCTime)
import           Data.Time.Clock.POSIX          (utcTimeToPOSIXSeconds, posixSecondsToUTCTime)
import           Control.Concurrent.AlarmClock (destroyAlarmClock, newAlarmClock, setAlarm)
import           PostgresWebsockets.Broadcast          (Multiplexer, onMessage)
import qualified PostgresWebsockets.Broadcast          as B
import           PostgresWebsockets.Claims

-- | Notification body built by 'notifySession': one value per websocket
-- message per target channel. It is serialised with 'A.encode' and handed to
-- the database notification sender, so the field names below are also the
-- keys of the JSON notification; renaming them changes the wire format.
data Message = Message
  { claims  :: A.Object
    -- ^ JWT claims, extended by 'notifySession' with @channel@ and
    -- @message_delivered_at@ keys
  , channel :: Text
    -- ^ channel this message is addressed to
  , payload :: Text
    -- ^ websocket message as received, decoded leniently as UTF-8
  } deriving (Show, Eq, Generic)

-- | JSON encoding via the 'Generic'-based default methods.
instance A.ToJSON Message

-- | Build a WAI middleware that serves websocket requests with the
-- LISTEN/NOTIFY bridge and passes every other request to the wrapped
-- application unchanged.
--
-- Arguments, in order:
--
--   * an action returning the current time (used for JWT validation and
--     message timestamps);
--   * the database channel that messages written by clients are sent to;
--   * the secret used to validate the client's JWT;
--   * the Hasql connection pool used to send those notifications;
--   * the 'Multiplexer' that reading clients subscribe to.
postgresWsMiddleware :: IO UTCTime -> Text -> ByteString -> H.Pool -> Multiplexer -> Wai.Application -> Wai.Application
postgresWsMiddleware getTime dbChannel secret pool multi =
  WS.websocketsOr WS.defaultConnectionOptions (wsApp getTime dbChannel secret pool multi)

-- private functions
-- | Websocket close code sent (with the reason "JWT expired") when the
-- time in the token's @exp@ claim is reached. It lies in the 3000-3999
-- range that RFC 6455 (section 7.4.2) reserves for libraries, frameworks
-- and applications.
jwtExpirationStatusCode :: Word16
jwtExpirationStatusCode = 3001

-- | Websocket application served by 'postgresWsMiddleware'.
--
-- The JWT, and optionally a channel, are read from the request path and
-- checked with 'validateClaims' against the current time. An invalid request
-- is rejected during the handshake with the validation message. A valid one
-- is accepted and, depending on the granted mode, subscribed to its channels
-- for reading and/or relayed to the database channel for writing.
--
-- When the websocket is closed a ConnectionClosed exception is triggered;
-- this kills all children and frees resources for us.
wsApp :: IO UTCTime -> Text -> ByteString -> H.Pool -> Multiplexer -> WS.ServerApp
wsApp getTime dbChannel secret pool multi pendingConn =
  getTime >>= validateClaims requestChannel secret (toS jwtToken) >>= either rejectRequest forkSessions
  where
    -- a mode allows reading when it is "r" or "rw", writing when "w" or "rw"
    hasRead m = m == ("r" :: ByteString) || m == ("rw" :: ByteString)
    hasWrite m = m == ("w" :: ByteString) || m == ("rw" :: ByteString)

    rejectRequest :: Text -> IO ()
    rejectRequest msg = do
      putErrLn $ "Rejecting Request: " <> msg
      WS.rejectRequest pendingConn (toS msg)

    -- The URI has one of two formats: /:jwt or /:channel/:jwt.
    -- Segments after the second are ignored; an empty path gives an empty token.
    pathElements = BS.split '/' $ BS.drop 1 $ WS.requestPath $ WS.pendingRequest pendingConn
    (requestChannel, jwtToken) =
      case pathElements of
        chan : token : _ -> (Just chan, token)
        [token] -> (Nothing, token)
        [] -> (Nothing, "")

    forkSessions :: ConnectionInfo -> IO ()
    forkSessions (chs, mode, validClaims) = do
      -- We should accept only after verifying JWT
      conn <- WS.acceptRequest pendingConn
      -- Fork a pinging thread to ensure browser connections stay alive
      WS.withPingThread conn 30 (pure ()) $
        withExpiration conn validClaims $ do
          when (hasRead mode) $
            forM_ chs $ flip (onMessage multi) $ WS.sendTextData conn . B.payload

          -- notifySession only returns by throwing, so for writers this blocks
          when (hasWrite mode) $
            let sendNotifications = void . H.notifyPool pool dbChannel . toS
            in notifySession validClaims conn getTime sendNotifications chs

          -- keep the session (and the subscriptions above) alive until an
          -- exception ends it.
          -- NOTE(review): this MVar is referenced only by this thread; confirm
          -- the RTS cannot raise BlockedIndefinitelyOnMVar here.
          waitForever <- newEmptyMVar
          void $ takeMVar waitForever

    -- Run a session. If the claims carry a numeric "exp", close the
    -- connection with 'jwtExpirationStatusCode' at that time. The alarm clock
    -- is destroyed when the session ends (normally or by exception), so its
    -- thread does not outlive the connection and never fires on a connection
    -- that has already been closed.
    withExpiration :: WS.Connection -> A.Object -> IO a -> IO a
    withExpiration conn claimsObj session =
      case M.lookup "exp" claimsObj of
        Just (A.Number expClaim) ->
          bracket
            (newAlarmClock $ const (WS.sendCloseCode conn jwtExpirationStatusCode ("JWT expired" :: ByteString)))
            destroyAlarmClock
            $ \connectionExpirer -> do
                setAlarm connectionExpirer (posixSecondsToUTCTime $ realToFrac expClaim)
                session
        _ -> session

-- | Forward every message received on the websocket to each of the given
-- channels through @send@, wrapped as a JSON-encoded 'Message'.
--
-- The claims sent with each message are @claimsToSend@ plus a @channel@ key
-- (the target channel) and a @message_delivered_at@ key (POSIX seconds read
-- from @getTime@ once per target channel). The payload is the received
-- message decoded leniently as UTF-8, then re-encoded as valid JSON.
--
-- Taking the channels and the claims as separate parameters may look
-- redundant, but it keeps this function ignorant of the claims structure and
-- of where the channels came from, so all claims decoding stays in the caller.
--
-- The relay loop only ends when receiving from the websocket throws; 'wait'
-- re-raises that exception here.
notifySession :: A.Object
              -> WS.Connection
              -> IO UTCTime
              -> (ByteString -> IO ())
              -> [ByteString]
              -> IO ()
notifySession claimsToSend wsCon getTime send chs =
  withAsync relayLoop wait
  where
    relayLoop = forever $ do
      incoming <- WS.receiveData wsCon
      mapM_ (forwardTo incoming . toS) chs

    forwardTo incoming targetChannel = do
      now <- getTime
      let stampedClaims =
            M.insert "message_delivered_at" (A.Number . realToFrac $ utcTimeToPOSIXSeconds now)
              . M.insert "channel" (A.String targetChannel)
              $ claimsToSend
          body = decodeUtf8With T.lenientDecode incoming
      send . BL.toStrict . A.encode $ Message stampedClaims targetChannel body