module PostgRESTWS
( postgrestWsMiddleware
, newHasqlBroadcaster
, newHasqlBroadcasterOrError
) where
import qualified Hasql.Pool as H
import qualified Network.Wai as Wai
import qualified Network.Wai.Handler.WebSockets as WS
import qualified Network.WebSockets as WS
import Protolude
import qualified Data.Aeson as A
import qualified Data.ByteString as BS
import qualified Data.ByteString.Lazy as BL
import qualified Data.HashMap.Strict as M
import qualified Data.Text.Encoding.Error as T
import Data.Time.Clock.POSIX (POSIXTime)
import PostgRESTWS.Broadcast (Multiplexer, onMessage)
import qualified PostgRESTWS.Broadcast as B
import PostgRESTWS.Claims
import PostgRESTWS.Database
import PostgRESTWS.HasqlBroadcast (newHasqlBroadcaster,
newHasqlBroadcasterOrError)
-- | Envelope relayed for every message a write-enabled client sends: the
-- sender's claims (as extended by 'notifySession') and the message text.
data Message = Message
  { claims  :: A.Object -- ^ claims of the sending client
  , payload :: Text     -- ^ message body as received over the websocket
  } deriving (Eq, Generic, Show)

-- | Generic-derived encoding with default options, so the JSON object keys
-- are exactly the record field names: @claims@ and @payload@.
instance A.ToJSON Message
-- | WAI middleware that upgrades websocket requests to a PostgREST-WS
-- session (handled by 'wsApp') and passes every other request to the
-- wrapped 'Wai.Application' unchanged.
--
-- Arguments, in order:
--
--   * an optional audit channel that additionally receives every message
--     written by clients;
--   * the secret used to validate the token taken from the request path;
--   * an action returning the current time, used to stamp relayed messages;
--   * the database connection pool used to publish messages;
--   * the multiplexer whose listeners forward channel messages to readers.
postgrestWsMiddleware :: Maybe ByteString -> ByteString -> IO POSIXTime -> H.Pool -> Multiplexer -> Wai.Application -> Wai.Application
postgrestWsMiddleware mAuditChannel secret getTime pool multi =
  -- The arguments are forwarded explicitly rather than through a five-fold
  -- (.) composition so the data flow into 'wsApp' is visible.
  WS.websocketsOr WS.defaultConnectionOptions
    (wsApp mAuditChannel secret getTime pool multi)
-- | Websocket application for a single connection.
--
-- The token is the request path minus its first character (normally the
-- leading @/@). If 'validateClaims' fails, the upgrade is rejected with the
-- error text as the reason. Otherwise it yields a channel, a mode and the
-- claims, and the mode decides what the client may do:
--
--   * @r@ or @rw@: messages the multiplexer delivers on the channel are sent
--     to the client.
--   * @w@ or @rw@: messages sent by the client are published to the channel
--     (and to the audit channel, if configured) via 'notifySession'.
wsApp :: Maybe ByteString -> ByteString -> IO POSIXTime -> H.Pool -> Multiplexer -> WS.ServerApp
wsApp mAuditChannel secret getTime pool multi pendingConn =
  validateClaims secret (toS jwtToken) >>= either rejectRequest forkSessions
  where
    hasRead m = m == ("r" :: ByteString) || m == ("rw" :: ByteString)
    hasWrite m = m == ("w" :: ByteString) || m == ("rw" :: ByteString)
    rejectRequest = WS.rejectRequest pendingConn . encodeUtf8
    jwtToken = BS.drop 1 $ WS.requestPath $ WS.pendingRequest pendingConn
    notifySessionWithTime = notifySession getTime
    forkSessions (channel, mode, validClaims) = do
      conn <- WS.acceptRequest pendingConn
      -- NOTE(review): forkPingThread is deprecated in newer websockets
      -- releases in favour of withPingThread; confirm the pinned version.
      WS.forkPingThread conn 30
      when (hasRead mode) $
        onMessage multi channel $ WS.sendTextData conn . B.payload
      when (hasWrite mode) $
        -- NOTE(review): 'void' discards whatever notifyPool returns; if it
        -- reports failures through its result they are silently dropped.
        let sendNotifications = void . case mAuditChannel of
              Nothing -> notifyPool pool channel
              Just auditChannel -> \mesg ->
                notifyPool pool channel mesg >>
                notifyPool pool auditChannel mesg
        in notifySessionWithTime validClaims conn sendNotifications
      -- notifySession loops forever, so this point is only reached for
      -- sessions without write access. websockets answers close/ping
      -- control frames only while some thread is receiving, so keep reading
      -- (and discarding) client frames instead of blocking on an MVar
      -- nobody fills: when the client disconnects, receiveData throws and
      -- this handler ends instead of hanging forever.
      forever (WS.receiveData conn :: IO ByteString)
-- | Relay every data message received on @wsCon@ to @send@.
--
-- Each message is decoded as UTF-8 (invalid bytes are replaced rather than
-- rejected). It is then wrapped in a 'Message' together with @claimsToSend@,
-- extended with a @message_delivered_at@ entry holding the value returned by
-- @getTime@, and the result is JSON-encoded before being handed to @send@.
--
-- Loops until an exception is thrown (for example by 'WS.receiveData' when
-- the connection closes); that exception propagates to the caller.
notifySession :: IO POSIXTime
              -> A.Object
              -> WS.Connection
              -> (ByteString -> IO ())
              -> IO ()
notifySession getTime claimsToSend wsCon send =
  -- Run the loop directly: wrapping it in withAsync/wait only added a
  -- thread whose exception was rethrown here anyway.
  forever relayData
  where
    -- Receive first, then read the clock, so message_delivered_at is taken
    -- after the message arrived. The previous liftA2 form ran getTime
    -- before blocking on receiveData, stamping the time waiting started.
    relayData = do
      msg <- WS.receiveData wsCon
      cl <- claimsWithTime
      send (jsonMsg cl msg)
    jsonMsg :: M.HashMap Text A.Value -> ByteString -> ByteString
    jsonMsg cl = BL.toStrict . A.encode . Message cl . decodeUtf8With T.lenientDecode
    claimsWithTime :: IO (M.HashMap Text A.Value)
    claimsWithTime = do
      time <- getTime
      return $ M.insert "message_delivered_at" (A.Number $ fromRational $ toRational time) claimsToSend