{-# LANGUAGE DeriveGeneric #-}

{-| PostgRESTWS middleware. Composing it into a WAI application lets
postgrest accept websocket connections that communicate with the database
through LISTEN/NOTIFY channels.
-}
module PostgRESTWS
  ( postgrestWsMiddleware
    -- * Re-exports
  , newHasqlBroadcaster
  , newHasqlBroadcasterOrError
  ) where

import qualified Hasql.Pool                     as H
import qualified Network.Wai                    as Wai
import qualified Network.Wai.Handler.WebSockets as WS
import qualified Network.WebSockets             as WS
import           Protolude

import qualified Data.Aeson                     as A
import qualified Data.ByteString.Char8          as BS
import qualified Data.ByteString.Lazy           as BL
import qualified Data.HashMap.Strict            as M
import qualified Data.Text.Encoding.Error       as T
import           Data.Time.Clock.POSIX          (getPOSIXTime)

import           PostgRESTWS.Broadcast          (Multiplexer, onMessage)
import qualified PostgRESTWS.Broadcast          as B
import           PostgRESTWS.Claims
import           PostgRESTWS.Database
import           PostgRESTWS.HasqlBroadcast     (newHasqlBroadcaster,
                                                 newHasqlBroadcasterOrError)

-- | Envelope built for every message a client writes to its channel.
-- The Generic-derived 'A.ToJSON' instance encodes it as a JSON object with
-- @claims@ and @payload@ keys.
data Message = Message
  { claims  :: A.Object -- ^ claims of the sending session (plus delivery metadata)
  , payload :: Text     -- ^ the client's message, decoded as UTF-8
  } deriving (Show, Eq, Generic)

instance A.ToJSON Message

-- | Build a WAI middleware that serves websocket upgrade requests itself and
-- passes every other request through to the wrapped application
-- (via 'WS.websocketsOr').
postgrestWsMiddleware
  :: Maybe ByteString -- ^ optional audit channel; when set, every client message is also notified there
  -> ByteString       -- ^ secret handed to 'validateClaims' to check the JWT
  -> H.Pool           -- ^ database pool used to send notifications
  -> Multiplexer      -- ^ multiplexer used to subscribe readers to channel messages
  -> Wai.Application  -- ^ application handling non-websocket requests
  -> Wai.Application
postgrestWsMiddleware mAuditChannel secret pool multi =
  WS.websocketsOr WS.defaultConnectionOptions (wsApp mAuditChannel secret pool multi)
-- private functions

-- | Websocket server application behind 'postgrestWsMiddleware'.
--
-- The request path is either @/:jwt@ or @/:channel/:jwt@ (extra segments are
-- ignored). The token, and the channel when present, are checked with
-- 'validateClaims':
--
-- * on failure the handshake is rejected with the returned message;
-- * on success the connection is accepted, and depending on the granted mode:
--
--     * @r@ / @rw@: every multiplexer message on the channel is forwarded to
--       the client;
--     * @w@ / @rw@: every client message is wrapped with the claims and sent
--       with 'notifyPool' to the channel (and to the audit channel, if set).
--
-- NOTE(review): teardown relies on an exception reaching this thread, e.g. the
-- ConnectionClosed raised by 'WS.receiveData' and rethrown by 'notifySession'
-- in write mode. In read-only mode nothing visible here interrupts the final
-- 'takeMVar'; confirm how those sessions are released.
wsApp :: Maybe ByteString -> ByteString -> H.Pool -> Multiplexer -> WS.ServerApp
wsApp mAuditChannel secret pool multi pendingConn =
  validateClaims requestChannel secret (toS jwtToken)
    >>= either rejectRequest forkSessions
  where
    hasRead m = m == ("r" :: ByteString) || m == ("rw" :: ByteString)
    hasWrite m = m == ("w" :: ByteString) || m == ("rw" :: ByteString)

    rejectRequest = WS.rejectRequest pendingConn . encodeUtf8

    -- the URI has one of the two formats - /:jwt or /:channel/:jwt
    pathElements =
      BS.split '/' $ BS.drop 1 $ WS.requestPath $ WS.pendingRequest pendingConn

    -- One exhaustive match instead of repeated length checks.
    (requestChannel, jwtToken) = case pathElements of
      ch : tok : _ -> (Just ch, tok)
      [tok]        -> (Nothing, tok)
      []           -> (Nothing, "")

    forkSessions (channel, mode, validClaims) = do
        -- Only accept the handshake once the JWT has been validated.
        conn <- WS.acceptRequest pendingConn
        -- Fork a pinging thread to ensure browser connections stay alive
        WS.forkPingThread conn 30

        when (hasRead mode) $
          onMessage multi channel $ WS.sendTextData conn . B.payload

        -- Blocks for the lifetime of the session in write mode.
        when (hasWrite mode) $
          notifySession validClaims conn sendNotifications

        -- Park this thread so the accepted session is not closed on return.
        waitForever <- newEmptyMVar
        void $ takeMVar waitForever
      where
        -- Notify the session channel and, when configured, mirror the same
        -- message to the audit channel. Results of 'notifyPool' are discarded.
        sendNotifications = void . case mAuditChannel of
          Nothing -> notifyPool pool channel
          Just auditChannel -> \mesg ->
            notifyPool pool channel mesg >> notifyPool pool auditChannel mesg

-- | Relay every message received on the websocket to @send@, encoded as a
-- JSON 'Message' whose claims are @claimsToSend@ plus a
-- @message_delivered_at@ POSIX timestamp.
--
-- The channel is deliberately not a parameter: the caller bakes it into
-- @send@, so this function stays independent of how claims and channels are
-- decoded. Runs until receiving or sending throws; 'wait' rethrows that
-- exception to the caller.
notifySession :: A.Object -> WS.Connection -> (ByteString -> IO ()) -> IO ()
notifySession claimsToSend wsCon send =
  withAsync (forever relayData) wait
  where
    relayData = do
      -- Receive first, then stamp the time: the timestamp must describe when
      -- this message arrived, not when the previous relay finished.
      msg <- WS.receiveData wsCon
      cl <- claimsWithTime
      send (jsonMsg cl msg)

    -- we need to decode the bytestring to re-encode valid JSON for the notification
    jsonMsg :: M.HashMap Text A.Value -> ByteString -> ByteString
    jsonMsg cl = BL.toStrict . A.encode . Message cl . decodeUtf8With T.lenientDecode

    claimsWithTime :: IO (M.HashMap Text A.Value)
    claimsWithTime = do
      time <- getPOSIXTime
      return $ M.insert "message_delivered_at" (A.Number $ fromRational $ toRational time) claimsToSend