module PostgRESTWS
( postgrestWsMiddleware
, newHasqlBroadcaster
, newHasqlBroadcasterOrError
) where
import Protolude
import qualified Network.Wai as Wai
import qualified Network.Wai.Handler.WebSockets as WS
import qualified Network.WebSockets as WS
import qualified Hasql.Pool as H
import qualified Data.Text.Encoding.Error as T
import Data.Time.Clock.POSIX (POSIXTime)
import qualified Data.Aeson as A
import qualified Data.ByteString as BS
import qualified Data.ByteString.Lazy as BL
import PostgRESTWS.Claims
import PostgRESTWS.Database
import PostgRESTWS.Broadcast (Multiplexer, onMessage)
import PostgRESTWS.HasqlBroadcast (newHasqlBroadcaster, newHasqlBroadcasterOrError)
import qualified PostgRESTWS.Broadcast as B
-- | Envelope published for every message a write-enabled client sends:
-- the client's validated token claims paired with the message text.
data Message = Message
  { claims  :: A.Object -- ^ claims obtained from the client's validated token
  , payload :: Text     -- ^ message text as received from the client
  } deriving (Show, Eq, Generic)

-- | JSON encoding through the 'Generic'-based default methods: an object
-- whose keys are the record field names, @claims@ and @payload@.
instance A.ToJSON Message where
-- | WAI middleware that serves WebSocket upgrade requests with 'wsApp' and
-- passes every other request to the wrapped application.
--
-- Arguments, in order: an optional audit channel, the secret used to
-- validate client tokens, an action returning the current time, the
-- database pool, the notification 'Multiplexer', and the fallback
-- application for non-WebSocket requests.
postgrestWsMiddleware :: Maybe PgIdentifier -> ByteString -> IO POSIXTime -> H.Pool -> Multiplexer -> Wai.Application -> Wai.Application
postgrestWsMiddleware mAuditChannel secret getTime pool multi =
  -- Written with explicit arguments instead of a five-fold composition
  -- operator; 'websocketsOr' then takes the fallback application.
  WS.websocketsOr WS.defaultConnectionOptions
    (wsApp mAuditChannel secret getTime pool multi)
-- | WebSocket application serving a single client connection.
--
-- The client's token is the request path without its leading character
-- (normally the @/@). 'validateClaims' checks it with @secret@ and the
-- current time from @getTime@:
--
-- * on failure the handshake is rejected with the error text;
-- * on success the result is a @(channel, mode, claims)@ triple and the
--   connection is accepted. @mode@ then decides what is started:
--   @"r"@ or @"rw"@ forwards every 'Multiplexer' message for @channel@ to
--   the client. @"w"@ or @"rw"@ publishes every client message, wrapped
--   with the claims by 'notifySession', via 'notifyPool' on @channel@. When
--   @mAuditChannel@ is set, it also publishes on the audit channel.
wsApp :: Maybe PgIdentifier -> ByteString -> IO POSIXTime -> H.Pool -> Multiplexer -> WS.ServerApp
wsApp mAuditChannel secret getTime pool multi pendingConn =
  getTime >>= forkSessionsWhenTokenIsValid . validateClaims secret jwtToken
  where
    forkSessionsWhenTokenIsValid = either rejectRequest forkSessions

    hasRead m = m == ("r" :: ByteString) || m == ("rw" :: ByteString)
    hasWrite m = m == ("w" :: ByteString) || m == ("rw" :: ByteString)

    rejectRequest = WS.rejectRequest pendingConn . encodeUtf8

    -- The request path is client-controlled and need not be valid UTF-8.
    -- A strict 'decodeUtf8' would throw a UnicodeException out of this
    -- handler. Lenient decoding replaces invalid bytes with U+FFFD, so a
    -- malformed token fails validation and is rejected like any other.
    jwtToken =
      decodeUtf8With T.lenientDecode . BS.drop 1 . WS.requestPath $
        WS.pendingRequest pendingConn

    forkSessions (channel, mode, validClaims) = do
      conn <- WS.acceptRequest pendingConn
      -- Ping the client every 30 seconds from a separate thread.
      WS.forkPingThread conn 30
      when (hasRead mode) $
        onMessage multi channel $ WS.sendTextData conn . B.payload
      when (hasWrite mode) $ do
        let channelName = toPgIdentifier channel
            -- Publish on the channel and, if configured, mirror to the
            -- audit channel. Both publish results are discarded.
            sendNotifications = void . case mAuditChannel of
              Nothing -> notifyPool pool channelName
              Just auditChannel -> \mesg ->
                notifyPool pool channelName mesg
                  >> notifyPool pool auditChannel mesg
        notifySession validClaims conn sendNotifications
      -- Park this handler. The MVar is never filled, so the handler only
      -- leaves this point through an exception.
      -- NOTE(review): in read-only mode nothing in this function observes
      -- the client disconnecting; confirm how such handlers are torn down.
      waitForever <- newEmptyMVar
      void $ takeMVar waitForever
-- | Relay every message the client sends on @wsCon@ to @send@.
--
-- Each received frame is decoded as UTF-8, with invalid bytes replaced
-- rather than rejected. It is then wrapped in a 'Message' together with
-- @claimsToSend@, and handed to @send@ as strict JSON bytes.
--
-- The loop never finishes normally. It runs in a separate 'Async' that
-- this call waits on, so this function only ends by rethrowing whatever
-- exception stopped the loop (e.g. the connection being closed).
notifySession :: A.Object
              -> WS.Connection
              -> (ByteString -> IO ())
              -> IO ()
notifySession claimsToSend wsCon send =
  withAsync relayForever wait
  where
    relayForever = forever $ do
      frame <- WS.receiveData wsCon
      send (toJsonBytes frame)
    toJsonBytes =
      BL.toStrict . A.encode . Message claimsToSend . decodeUtf8With T.lenientDecode