{-# LANGUAGE DeriveGeneric #-}
module PostgresWebsockets
( postgresWsMiddleware
, newHasqlBroadcaster
, newHasqlBroadcasterOrError
) where
import qualified Hasql.Pool as H
import qualified Hasql.Notifications as H
import qualified Network.Wai as Wai
import qualified Network.Wai.Handler.WebSockets as WS
import qualified Network.WebSockets as WS
import Protolude
import qualified Data.Aeson as A
import qualified Data.ByteString.Char8 as BS
import qualified Data.ByteString.Lazy as BL
import qualified Data.HashMap.Strict as M
import qualified Data.Text.Encoding.Error as T
import Data.Time.Clock.POSIX (getPOSIXTime)
import PostgresWebsockets.Broadcast (Multiplexer, onMessage)
import qualified PostgresWebsockets.Broadcast as B
import PostgresWebsockets.Claims
import PostgresWebsockets.HasqlBroadcast (newHasqlBroadcaster,
newHasqlBroadcasterOrError)
-- | JSON envelope built for each message a writer sends: the claims
-- attached to the session, the channel name, and the message text.
data Message = Message
  { claims  :: A.Object -- ^ claims attached to the message
  , channel :: Text     -- ^ channel the message belongs to
  , payload :: Text     -- ^ message body, decoded from the client's bytes
  } deriving (Eq, Generic, Show)

-- | Uses the 'Generic'-based default encoding, so the JSON keys are the
-- record field names.
instance A.ToJSON Message
-- | Wrap a WAI application so that WebSocket requests are handled by
-- 'wsApp', while all other requests are passed to the wrapped application
-- (the fallback argument of 'WS.websocketsOr').
--
-- The arguments are, in order:
--
-- * the database channel that client messages are forwarded to;
-- * the secret used by 'validateClaims';
-- * the Hasql pool used to send notifications;
-- * the 'Multiplexer' that delivers messages to reading clients.
--
-- The arguments are applied explicitly rather than through a four-fold
-- composition operator, so the data flow is visible at a glance.
postgresWsMiddleware :: Text -> ByteString -> H.Pool -> Multiplexer -> Wai.Application -> Wai.Application
postgresWsMiddleware dbChannel secret pool multi =
  WS.websocketsOr WS.defaultConnectionOptions (wsApp dbChannel secret pool multi)
-- | WebSocket handler used by 'postgresWsMiddleware'.
--
-- The request path, minus its first character, is split on @/@.
--
-- * Two or more segments: the first is the channel and the second is the
--   token.
-- * A single segment: it is the token and there is no channel.
-- * No segments: the token is empty.
--
-- The optional channel, @secret@ and token are passed to 'validateClaims'.
--
-- * On 'Left', the handshake is rejected with the UTF-8 encoded message.
-- * On 'Right' @(channel, mode, claims)@, the connection is accepted and a
--   ping thread is forked. Then:
--
--     * mode @"r"@ or @"rw"@ registers an 'onMessage' callback that sends
--       each multiplexed payload to the client;
--     * mode @"w"@ or @"rw"@ runs 'notifySession', relaying client
--       messages as notifications on @dbChannel@.
--
-- The handler never returns normally. In writable modes it blocks inside
-- 'notifySession'; otherwise it blocks on an MVar that is never filled.
wsApp :: Text -> ByteString -> H.Pool -> Multiplexer -> WS.ServerApp
wsApp dbChannel secret pool multi pendingConn =
  validateClaims requestChannel secret (toS jwtToken) >>= either rejectRequest forkSessions
  where
    hasRead, hasWrite :: ByteString -> Bool
    hasRead m = m `elem` ["r", "rw"]
    hasWrite m = m `elem` ["w", "rw"]

    rejectRequest = WS.rejectRequest pendingConn . encodeUtf8

    pathElements = BS.split '/' $ BS.drop 1 $ WS.requestPath $ WS.pendingRequest pendingConn

    -- A single total match replaces the former length-based guard pairs,
    -- which GHC could not prove exhaustive and which walked the list
    -- several times.
    (requestChannel, jwtToken) = case pathElements of
      chan : token : _ -> (Just chan, token)
      [token]          -> (Nothing, token)
      []               -> (Nothing, "")

    forkSessions (ch, mode, validClaims) = do
      conn <- WS.acceptRequest pendingConn
      WS.forkPingThread conn 30
      when (hasRead mode) $
        onMessage multi ch $ WS.sendTextData conn . B.payload
      when (hasWrite mode) $
        notifySession validClaims (toS ch) conn sendNotification
      -- Nothing ever fills this MVar; it keeps the handler from returning.
      waitForever <- newEmptyMVar
      void $ takeMVar waitForever

    -- NOTE(review): the result of 'H.notifyPool' is discarded, so a failed
    -- notification is dropped silently. Consider logging the error.
    sendNotification :: ByteString -> IO ()
    sendNotification = void . H.notifyPool pool dbChannel . toS
-- | Forward every message received on the WebSocket to @send@, encoded as a
-- JSON 'Message'.
--
-- Each message is built from three parts:
--
-- * @claimsToSend@, extended with a @"channel"@ key (set to @ch@) and a
--   @"message_delivered_at"@ key (the POSIX time when the message was
--   received);
-- * the channel @ch@;
-- * the received bytes, decoded leniently as UTF-8.
--
-- The relay loop runs in a child thread via 'withAsync' and is awaited with
-- 'wait'. It never finishes normally; it ends only through an exception,
-- which 'wait' rethrows to the caller.
notifySession :: A.Object
              -> Text
              -> WS.Connection
              -> (ByteString -> IO ())
              -> IO ()
notifySession claimsToSend ch wsCon send =
  withAsync (forever relayData) wait
  where
    relayData = jsonMsgWithTime >>= send

    -- Receive first, then read the clock. Reading the clock before the
    -- blocking 'WS.receiveData' stamps the message with the time the server
    -- started waiting, which can be arbitrarily earlier than its arrival.
    jsonMsgWithTime :: IO ByteString
    jsonMsgWithTime = do
      msg <- WS.receiveData wsCon
      cl <- claimsWithTime
      pure (jsonMsg cl msg)

    jsonMsg :: M.HashMap Text A.Value -> ByteString -> ByteString
    jsonMsg cl = BL.toStrict . A.encode . Message cl ch . decodeUtf8With T.lenientDecode

    claimsWithChannel = M.insert "channel" (A.String ch) claimsToSend

    claimsWithTime :: IO (M.HashMap Text A.Value)
    claimsWithTime = do
      time <- getPOSIXTime
      return $ M.insert "message_delivered_at" (A.Number $ fromRational $ toRational time) claimsWithChannel