-- | WAI middleware that serves WebSocket connections in front of a regular
-- WAI application.
--
-- A connecting client supplies a JWT in the request path. Once the token is
-- validated, the session may forward messages published on the session's
-- channel to the socket (read mode) and/or forward data received from the
-- socket, wrapped together with the token claims, to the database
-- notification helper (write mode).
--
-- The broadcaster constructors from "PostgRESTWS.HasqlBroadcast" are
-- re-exported for convenience.
module PostgRESTWS
( postgrestWsMiddleware
, newHasqlBroadcaster
, newHasqlBroadcasterOrError
) where
import Protolude
import qualified Network.Wai as Wai
import qualified Network.Wai.Handler.WebSockets as WS
import qualified Network.WebSockets as WS
import qualified Hasql.Pool as H
import qualified Data.Text.Encoding.Error as T
import Data.Time.Clock.POSIX (POSIXTime)
import qualified Data.Aeson as A
import qualified Data.ByteString as BS
import qualified Data.ByteString.Lazy as BL
import PostgRESTWS.Claims
import PostgRESTWS.Database
import PostgRESTWS.Broadcast (Multiplexer, onMessage, readTChan)
import PostgRESTWS.HasqlBroadcast (newHasqlBroadcaster, newHasqlBroadcasterOrError)
import qualified PostgRESTWS.Broadcast as B
-- | Envelope built for every frame a client writes to the socket; it is
-- JSON-encoded before being handed to the notification helper.
data Message = Message
  { claims :: A.Object
    -- ^ Claims carried by the session's validated token.
  , payload :: Text
    -- ^ Text of the frame received from the client.
  } deriving (Eq, Show, Generic)

-- | JSON encoding derived generically, so the object keys are the record
-- field names @claims@ and @payload@.
instance A.ToJSON Message
-- | Build a WAI middleware that hands WebSocket requests to 'wsApp' and lets
-- every other request fall through to the wrapped application.
--
-- Arguments, in order: the secret used to validate tokens, an action
-- yielding the current time, the database pool, and the multiplexer that
-- session channels are read from. All four are passed unchanged to 'wsApp'.
postgrestWsMiddleware :: ByteString -> IO POSIXTime -> H.Pool -> Multiplexer -> Wai.Application -> Wai.Application
postgrestWsMiddleware secret getTime pool multi =
  WS.websocketsOr WS.defaultConnectionOptions (wsApp secret getTime pool multi)
-- | WebSocket server application behind 'postgrestWsMiddleware'.
--
-- The token is taken from the request path with its first byte dropped
-- (presumably the leading @/@) and checked by 'validateClaims' against
-- @secret@ and the time returned by @getTime@. A failed validation rejects
-- the pending connection with the validation message; a successful one
-- yields the channel, the access mode and the claims for the session.
--
-- Modes are compared literally: @"r"@ enables reading, @"w"@ enables
-- writing and @"rw"@ enables both.
wsApp :: ByteString -> IO POSIXTime -> H.Pool -> Multiplexer -> WS.ServerApp
wsApp secret getTime pqCon multi pendingConn =
  getTime >>= forkSessionsWhenTokenIsValid . validateClaims secret jwtToken
  where
    forkSessionsWhenTokenIsValid = either rejectRequest forkSessions
    hasRead m = m == ("r" :: ByteString) || m == ("rw" :: ByteString)
    hasWrite m = m == ("w" :: ByteString) || m == ("rw" :: ByteString)
    rejectRequest = WS.rejectRequest pendingConn . encodeUtf8
    jwtToken = decodeUtf8 $ BS.drop 1 $ WS.requestPath $ WS.pendingRequest pendingConn
    forkSessions (channel, mode, validClaims) = do
      conn <- WS.acceptRequest pendingConn
      WS.forkPingThread conn 30
      -- Readers get every message published on the channel pushed to the
      -- socket as a text frame.
      when (hasRead mode) $
        onMessage multi channel $ \ch ->
          forever $ atomically (readTChan ch) >>= WS.sendTextData conn . B.payload
      let startSession
            -- Writers: forward every received frame to the database.
            | hasWrite mode =
                forkAndWait . forever $ notifySession channel validClaims pqCon conn
            -- Read-only: the handler must not return while the client is
            -- connected, otherwise the socket is closed right after the
            -- listener is registered. Draining (and discarding) inbound
            -- data frames keeps the handler alive until the receive fails,
            -- e.g. because the peer went away.
            | hasRead mode =
                forkAndWait . forever . void $ WS.receiveDataMessage conn
            -- Neither read nor write: nothing to wait for.
            | otherwise = newMVar ()
      -- Block until the session loop terminates.
      startSession >>= takeMVar
-- | Receive one data frame from the socket and forward it to the database.
--
-- The frame is decoded as UTF-8 leniently, wrapped in a 'Message' together
-- with @claimsToSend@, JSON-encoded and passed to 'notifyPool' for
-- @channel@. The result of 'notifyPool' is discarded.
notifySession :: BS.ByteString
              -> A.Object
              -> H.Pool
              -> WS.Connection
              -> IO ()
notifySession channel claimsToSend pool wsCon = do
  rawFrame <- WS.receiveData wsCon
  let body = decodeUtf8With T.lenientDecode rawFrame
      envelope = Message claimsToSend body
      encoded = BL.toStrict (A.encode envelope)
  void (notifyPool pool channel encoded)
-- | Run an action in a new thread and return an 'MVar' that is filled once
-- that thread finishes, whether it returned normally or ended with an
-- exception. The outcome itself is discarded.
forkAndWait :: IO () -> IO (MVar ())
forkAndWait action = do
  done <- newEmptyMVar
  _ <- forkFinally action (const (putMVar done ()))
  pure done