{-# LANGUAGE DeriveGeneric #-}
module PostgresWebsockets.Middleware
( postgresWsMiddleware
) where
import Protolude

import Control.Concurrent.AlarmClock (destroyAlarmClock, newAlarmClock, setAlarm)
import qualified Data.Aeson as A
import qualified Data.ByteString.Char8 as BS
import qualified Data.ByteString.Lazy as BL
import qualified Data.HashMap.Strict as M
import qualified Data.Text.Encoding.Error as T
import Data.Time.Clock (UTCTime)
import Data.Time.Clock.POSIX (posixSecondsToUTCTime, utcTimeToPOSIXSeconds)
import qualified Hasql.Notifications as H
import qualified Hasql.Pool as H
import qualified Network.Wai as Wai
import qualified Network.Wai.Handler.WebSockets as WS
import qualified Network.WebSockets as WS

import PostgresWebsockets.Broadcast (Multiplexer, onMessage)
import qualified PostgresWebsockets.Broadcast as B
import PostgresWebsockets.Claims
-- | A message written by a websocket client, in the shape that is
-- JSON-encoded and forwarded to the database notification channel.
data Message = Message
  { claims  :: A.Object
    -- ^ The sender's validated claims (extended with extra keys before sending)
  , channel :: Text
    -- ^ Name of the channel the message is addressed to
  , payload :: Text
    -- ^ The message body, decoded as text
  }
  deriving (Show, Eq, Generic)

-- | Generic encoding: a JSON object with the keys @claims@, @channel@ and
-- @payload@.
instance A.ToJSON Message
-- | Wrap a WAI application so that websocket upgrade requests are served by
-- 'wsApp' and every other request is passed on to the wrapped application.
--
-- Arguments, in order: an action yielding the current time, the database
-- channel that client writes are sent to, the JWT secret, the database
-- connection pool, and the broadcast multiplexer that delivers database
-- notifications to readers.
postgresWsMiddleware :: IO UTCTime -> Text -> ByteString -> H.Pool -> Multiplexer -> Wai.Application -> Wai.Application
postgresWsMiddleware getTime dbChannel secret pool multi =
  WS.websocketsOr WS.defaultConnectionOptions serverApp
  where
    serverApp = wsApp getTime dbChannel secret pool multi
-- | Websocket close code sent when the connection's JWT reaches its @exp@
-- time. RFC 6455 reserves the 3000-3999 range for libraries, frameworks and
-- applications.
jwtExpirationStatusCode :: Word16
jwtExpirationStatusCode = 3001
-- | The websocket server application installed by 'postgresWsMiddleware'.
--
-- The request path, without its leading @/@, is split on @/@:
--
--   * two or more segments: the first is the requested channel and the
--     second is the JWT;
--   * one segment: it is the JWT and no channel is requested;
--   * no segments: the token is empty.
--
-- The token and requested channel are checked by 'validateClaims' using
-- @secret@ and the time returned by @getTime@. If the check fails, the
-- pending connection is rejected with the error text. If it succeeds, the
-- connection is accepted and served by @forkSessions@.
wsApp :: IO UTCTime -> Text -> ByteString -> H.Pool -> Multiplexer -> WS.ServerApp
wsApp getTime dbChannel secret pool multi pendingConn =
  getTime >>= validateClaims requestChannel secret (toS jwtToken) >>= either rejectRequest forkSessions
  where
    -- Access modes carried by the validated token: "r", "w" or "rw".
    hasRead m = m == ("r" :: ByteString) || m == ("rw" :: ByteString)
    hasWrite m = m == ("w" :: ByteString) || m == ("rw" :: ByteString)

    rejectRequest :: Text -> IO ()
    rejectRequest msg = do
      putErrLn $ "Rejecting Request: " <> msg
      WS.rejectRequest pendingConn (toS msg)

    pathElements = BS.split '/' $ BS.drop 1 $ WS.requestPath $ WS.pendingRequest pendingConn

    -- "/<channel>/<jwt>[/...]" requests a channel; "/<jwt>" does not.
    (requestChannel, jwtToken) =
      case pathElements of
        channelName : token : _ -> (Just channelName, token)
        [token] -> (Nothing, token)
        [] -> (Nothing, "")

    forkSessions :: ConnectionInfo -> IO ()
    forkSessions (chs, mode, validClaims) = do
      -- Accept only after the JWT has been validated.
      conn <- WS.acceptRequest pendingConn
      WS.withPingThread conn 30 (pure ()) $
        withExpirationAlarm conn validClaims $ do
          -- Database -> websocket: forward broadcasts on every granted channel.
          when (hasRead mode) $
            forM_ chs $ flip (onMessage multi) $ WS.sendTextData conn . B.payload
          -- Websocket -> database: notifySession loops forever, so a writer
          -- session never proceeds past this statement normally.
          when (hasWrite mode) $
            let sendNotifications = void . H.notifyPool pool dbChannel . toS
            in notifySession validClaims conn getTime sendNotifications chs
          -- Read-only sessions park here while the subscriptions registered
          -- above deliver messages. NOTE(review): presumably the session ends
          -- when withPingThread cancels this action after the connection
          -- drops; confirm against the websockets version in use.
          waitForever <- newEmptyMVar
          void $ takeMVar waitForever

    -- Run @session@ with an alarm that closes @conn@ using
    -- 'jwtExpirationStatusCode' at the time in the numeric "exp" claim, when
    -- that claim is present. The alarm is destroyed when @session@ exits,
    -- normally or by exception. Without this, a connection closed before
    -- expiry leaves an alarm thread holding @conn@ until "exp" and then
    -- writing a close frame to a dead socket.
    withExpirationAlarm :: WS.Connection -> A.Object -> IO () -> IO ()
    withExpirationAlarm conn validClaims session =
      case M.lookup "exp" validClaims of
        Just (A.Number expClaim) ->
          bracket (newAlarmClock $ const closeExpired) destroyAlarmClock $ \alarm -> do
            setAlarm alarm (posixSecondsToUTCTime $ realToFrac expClaim)
            session
        _ -> session
      where
        closeExpired = WS.sendCloseCode conn jwtExpirationStatusCode ("JWT expired" :: ByteString)
-- | Relay every message the client sends on @wsCon@ to each channel in @chs@.
--
-- For each received message and each channel, the claims are extended with:
--
--   * a @"channel"@ key holding the channel name;
--   * a @"message_delivered_at"@ key holding the current POSIX time, read from
--     @getTime@ once per channel.
--
-- The result is wrapped with the message text in a 'Message', JSON-encoded and
-- passed to @send@. Message bytes are decoded as UTF-8 leniently, so invalid
-- sequences become replacement characters instead of errors.
--
-- The relay loop runs under 'forever' inside an 'Async'. 'wait' rethrows
-- whatever exception stops the loop, so this action only ends by throwing.
-- NOTE(review): presumably that exception is the connection closing.
notifySession :: A.Object
              -> WS.Connection
              -> IO UTCTime
              -> (ByteString -> IO ())
              -> [ByteString]
              -> IO ()
notifySession claimsToSend wsCon getTime send chs =
  withAsync (forever relayOnce) wait
  where
    relayOnce = do
      incoming <- WS.receiveData wsCon
      forM_ chs $ \rawChannel -> deliver incoming (toS rawChannel)

    deliver incoming ch = do
      deliveredAt <- utcTimeToPOSIXSeconds <$> getTime
      let withChannel = M.insert "channel" (A.String ch) claimsToSend
          stamped = M.insert "message_delivered_at" (A.Number $ realToFrac deliveredAt) withChannel
      send $ encodeMessage ch stamped incoming

    encodeMessage :: Text -> M.HashMap Text A.Value -> ByteString -> ByteString
    encodeMessage ch cl raw =
      BL.toStrict (A.encode (Message cl ch (decodeUtf8With T.lenientDecode raw)))