{-| PostgRESTWS Middleware, composing this allows postgrest to create
    websockets connections that will communicate with the database through
    LISTEN/NOTIFY channels.
-}
{-# LANGUAGE DeriveGeneric #-}

module PostgRESTWS
  ( postgrestWsMiddleware
  -- * Re-exports
  , newHasqlBroadcaster
  , newHasqlBroadcasterOrError
  ) where

import           Protolude
import qualified Network.Wai                    as Wai
import qualified Network.Wai.Handler.WebSockets as WS
import qualified Network.WebSockets             as WS

import qualified Hasql.Pool                     as H

import qualified Data.Text.Encoding.Error       as T

import           Data.Time.Clock.POSIX          (POSIXTime)
import qualified Data.Aeson                     as A
import qualified Data.ByteString                as BS
import qualified Data.ByteString.Lazy           as BL

import           PostgRESTWS.Claims
import           PostgRESTWS.Database
import           PostgRESTWS.Broadcast          (Multiplexer, onMessage, readTChan)
import           PostgRESTWS.HasqlBroadcast     (newHasqlBroadcaster, newHasqlBroadcasterOrError)
import qualified PostgRESTWS.Broadcast          as B

-- | Envelope that is JSON-encoded and handed to 'notifyPool' for every
-- message received from a websocket client: the claims of the sender plus
-- the text the client sent.
data Message = Message
  { claims  :: A.Object
  , payload :: Text
  } deriving (Show, Eq, Generic)

-- Encoding is derived generically from the record definition above.
instance A.ToJSON Message

-- | Given a secret, a function to fetch the system time, a Hasql Pool and a
-- Multiplexer this will give you a WAI middleware.
--
-- Websocket requests are served by 'wsApp'; every other request is passed on
-- to the wrapped 'Wai.Application'.
postgrestWsMiddleware :: ByteString -> IO POSIXTime -> H.Pool -> Multiplexer -> Wai.Application -> Wai.Application
postgrestWsMiddleware secret getTime pool multi =
  WS.websocketsOr WS.defaultConnectionOptions (wsApp secret getTime pool multi)

-- private functions

-- | Websocket server application.
--
-- Reads the JWT from the request path, validates it against the secret and
-- the current time, then either rejects the pending connection with the
-- validation error or accepts it and starts the read and/or write sessions
-- allowed by the token's mode.
--
-- when the websocket is closed a ConnectionClosed Exception is triggered
-- this kills all children and frees resources for us
wsApp :: ByteString -> IO POSIXTime -> H.Pool -> Multiplexer -> WS.ServerApp
wsApp secret getTime pqCon multi pendingConn =
  getTime >>= forkSessionsWhenTokenIsValid . validateClaims secret jwtToken
  where
    forkSessionsWhenTokenIsValid = either rejectRequest forkSessions

    -- "rw" grants both directions.
    hasRead m = m == ("r" :: ByteString) || m == ("rw" :: ByteString)
    hasWrite m = m == ("w" :: ByteString) || m == ("rw" :: ByteString)

    rejectRequest = WS.rejectRequest pendingConn . encodeUtf8

    -- the first char in path is '/' the rest is the token
    --
    -- The path comes straight from the client, so it is decoded leniently
    -- (the same way 'notifySession' decodes client payloads): malformed
    -- UTF-8 then produces a token that goes through 'validateClaims' and the
    -- normal 'rejectRequest' path instead of raising a decoding exception
    -- from pure code.
    jwtToken =
      decodeUtf8With T.lenientDecode
        $ BS.drop 1
        $ WS.requestPath
        $ WS.pendingRequest pendingConn

    forkSessions (channel, mode, validClaims) = do
      -- role claim defaults to anon if not specified in jwt
      -- We should accept only after verifying JWT
      conn <- WS.acceptRequest pendingConn
      -- Fork a pinging thread to ensure browser connections stay alive
      WS.forkPingThread conn 30

      -- Relay every message published on the channel to this client.
      when (hasRead mode) $
        onMessage multi channel
          (\ch -> forever $ atomically (readTChan ch) >>= WS.sendTextData conn . B.payload)

      -- Writers block here until the receive loop terminates.
      -- NOTE(review): for sessions without write access the MVar is created
      -- full, so 'takeMVar' returns immediately and this handler returns
      -- right after registering the listener - confirm that this does not
      -- end read-only sessions prematurely.
      notifySessionFinished <-
        if hasWrite mode
          then forkAndWait $ forever $ notifySession channel validClaims pqCon conn
          else newMVar ()
      takeMVar notifySessionFinished

-- | Receive a single message from the websocket and pass it, wrapped in a
-- 'Message' together with the sender's claims, to 'notifyPool' for the given
-- channel. The result of 'notifyPool' is discarded.
--
-- Having both channel and claims as parameters seem redundant
-- But it allows the function to ignore the claims structure and the source
-- of the channel, so all claims decoding can be coded in the caller
notifySession :: BS.ByteString -> A.Object -> H.Pool -> WS.Connection -> IO ()
notifySession channel claimsToSend pool wsCon =
  WS.receiveData wsCon >>= (void . send . jsonMsg)
  where
    send = notifyPool pool channel
    -- we need to decode the bytestring to re-encode valid JSON for the notification
    jsonMsg = BL.toStrict . A.encode . Message claimsToSend . decodeUtf8With T.lenientDecode

-- | Run an action in a new thread and return an 'MVar' that is filled once
-- the action finishes, whether it returned normally or threw an exception.
forkAndWait :: IO () -> IO (MVar ())
forkAndWait io = do
  mvar <- newEmptyMVar
  void $ forkFinally io (\_ -> putMVar mvar ())
  return mvar