{-|
Module      : PostgresWebsockets.Middleware
Description : PostgresWebsockets WAI middleware, add functionality to any WAI application.

Allows websocket connections that communicate with the database through LISTEN/NOTIFY channels.
-}
{-# LANGUAGE DeriveGeneric #-}

module PostgresWebsockets.Middleware
  ( postgresWsMiddleware
  ) where

import Protolude hiding (toS)
import Protolude.Conv
import Data.Time.Clock (UTCTime)
import Data.Time.Clock.POSIX (utcTimeToPOSIXSeconds, posixSecondsToUTCTime)
import Control.Concurrent.AlarmClock (destroyAlarmClock, newAlarmClock, setAlarm)
import qualified Hasql.Notifications as H
import qualified Hasql.Pool as H
import qualified Network.Wai as Wai
import qualified Network.Wai.Handler.WebSockets as WS
import qualified Network.WebSockets as WS

import qualified Data.Aeson as A
import qualified Data.Text as T
import qualified Data.ByteString.Lazy as BL
import qualified Data.HashMap.Strict as M

import PostgresWebsockets.Broadcast (onMessage)
import PostgresWebsockets.Claims ( ConnectionInfo, validateClaims )
import PostgresWebsockets.Context ( Context(..) )
import PostgresWebsockets.Config (AppConfig(..))
import qualified PostgresWebsockets.Broadcast as B


-- | What caused a 'Message' to be published to the database.
data Event
  = WebsocketMessage
    -- ^ data received from a client over its websocket
  | ConnectionOpen
    -- ^ a websocket connection was accepted (published on the meta channel)
  deriving (Show, Eq, Generic)

-- | Envelope published to the database for each websocket event; it is
-- serialised to JSON before being sent with NOTIFY.
data Message = Message
  { claims  :: A.Object
    -- ^ validated JWT claims (gains @message_delivered_at@ when timestamped)
  , event   :: Event
    -- ^ what triggered this message
  , payload :: Text
    -- ^ text sent by the client, or for 'ConnectionOpen' the
    -- comma-separated list of the connection's channels
  , channel :: Text
    -- ^ channel the message is addressed to
  } deriving (Show, Eq, Generic)

-- | Generic JSON encodings using aeson's default options ('toEncoding'
-- keeps its class default, which goes through 'toJSON').
instance A.ToJSON Event where
  toJSON = A.genericToJSON A.defaultOptions

instance A.ToJSON Message where
  toJSON = A.genericToJSON A.defaultOptions

-- | Build a WAI middleware from a 'Context' (application configuration,
-- clock, Hasql connection pool and LISTEN/NOTIFY multiplexer).
--
-- Websocket upgrade requests are served by 'wsApp'; every other request is
-- handed unchanged to the wrapped application by 'WS.websocketsOr'.
postgresWsMiddleware :: Context -> Wai.Middleware
postgresWsMiddleware ctx =
  WS.websocketsOr WS.defaultConnectionOptions (wsApp ctx)

-- private functions
-- | Websocket close code sent to a client once its token's @exp@ claim is
-- reached. RFC 6455 (section 7.4.2) reserves 3000-3999 for libraries,
-- frameworks and applications, so this does not clash with protocol codes.
jwtExpirationStatusCode :: Word16
jwtExpirationStatusCode = 3001

-- | Serve one websocket connection.
--
-- The request path has one of two formats, @/:jwt@ or @/:channel/:jwt@. The
-- token (and optional channel) are checked by 'validateClaims' against the
-- configured JWT secret and the current time from 'ctxGetTime'. On failure
-- the handshake is rejected with the validation message; otherwise the
-- connection is accepted and, depending on the token's mode (@"r"@, @"w"@ or
-- @"rw"@), multiplexer messages are relayed to the client and/or client
-- messages are published to the database.
--
-- When the websocket is closed a ConnectionClosed exception is triggered.
-- Every helper thread started here is scoped ('WS.withPingThread', the
-- 'withAsync' in 'notifySession', and 'bracket' around the expiry alarm), so
-- that exception kills all children and frees their resources.
wsApp :: Context -> WS.ServerApp
wsApp Context{..} pendingConn =
  ctxGetTime
    >>= validateClaims requestChannel (configJwtSecret ctxConfig) (toS jwtToken)
    >>= either rejectRequest forkSessions
  where
    -- Access modes carried by the validated token.
    hasRead, hasWrite :: Text -> Bool
    hasRead m = m == "r" || m == "rw"
    hasWrite m = m == "w" || m == "rw"

    rejectRequest :: Text -> IO ()
    rejectRequest msg = do
      putErrLn $ "Rejecting Request: " <> msg
      WS.rejectRequest pendingConn (toS msg)

    -- The URI has one of the two formats - /:jwt or /:channel/:jwt
    pathElements :: [Text]
    pathElements =
      T.split (== '/') . T.drop 1 . toSL . WS.requestPath $ WS.pendingRequest pendingConn

    -- With two or more segments the first is the channel and the second the
    -- token; a single segment is the token alone, with no channel.
    requestChannel :: Maybe Text
    jwtToken :: Text
    (requestChannel, jwtToken) =
      case pathElements of
        (ch : tok : _) -> (Just ch, tok)
        [tok] -> (Nothing, tok)
        [] -> (Nothing, "")

    -- Run @session@ with an alarm that sends 'jwtExpirationStatusCode' to the
    -- client when the token's numeric @exp@ claim is reached. The alarm clock
    -- is destroyed when @session@ ends (normally or by exception), so its
    -- thread cannot outlive the connection or fire on a closed socket.
    withExpiryAlarm :: WS.Connection -> A.Object -> IO a -> IO a
    withExpiryAlarm conn validClaims session =
      case M.lookup "exp" validClaims of
        Just (A.Number expClaim) ->
          bracket (newAlarmClock closeExpired) destroyAlarmClock $ \expirer -> do
            setAlarm expirer (posixSecondsToUTCTime $ realToFrac expClaim)
            session
        _ -> session
      where
        closeExpired =
          const (WS.sendCloseCode conn jwtExpirationStatusCode ("JWT expired" :: ByteString))

    forkSessions :: ConnectionInfo -> IO ()
    forkSessions (chs, mode, validClaims) = do
      -- We should accept only after verifying JWT
      conn <- WS.acceptRequest pendingConn
      -- Fork a pinging thread to ensure browser connections stay alive
      WS.withPingThread conn 30 (pure ()) $
        withExpiryAlarm conn validClaims $ do
          let sendMessageToDatabase = sendToDatabase ctxPool (configListenChannel ctxConfig)
              sendMessageWithTimestamp = timestampMessage ctxGetTime >=> sendMessageToDatabase
              sendNotification msg chan =
                sendMessageWithTimestamp (Message validClaims WebsocketMessage msg chan)
              connectionOpenMessage = Message validClaims ConnectionOpen

          case configMetaChannel ctxConfig of
            Nothing -> pure ()
            Just metaChannel ->
              sendMessageWithTimestamp $
                connectionOpenMessage (T.intercalate "," chs) metaChannel

          when (hasRead mode) $
            forM_ chs $ flip (onMessage ctxMulti) $ WS.sendTextData conn . B.payload

          -- Blocks for as long as the client keeps sending data.
          when (hasWrite mode) $
            notifySession conn sendNotification chs

          -- Nothing ever fills this MVar, so the session only ends through an
          -- exception delivered to this thread.
          -- NOTE(review): no other thread references the MVar, so GHC may raise
          -- BlockedIndefinitelyOnMVar here; confirm read-only sessions are not
          -- dropped early.
          waitForever <- newEmptyMVar :: IO (MVar ())
          takeMVar waitForever

-- | Relay every message received on the websocket to each channel in @chs@
-- by calling @sendToChannel message channel@.
--
-- The callback hides how a payload becomes a database notification, so all
-- claims handling stays in the caller. The relay loop runs in a scoped
-- 'withAsync' and this function waits on it. Because the loop never finishes
-- normally, this only returns by rethrowing the loop's exception (presumably
-- a websockets ConnectionException once the client disconnects).
notifySession :: WS.Connection -> (Text -> Text -> IO ()) -> [Text] -> IO ()
notifySession wsCon sendToChannel chs =
  withAsync relayForever wait
  where
    relayForever :: IO ()
    relayForever = forever $ do
      incoming <- WS.receiveData wsCon
      forM_ chs $ \targetChannel -> sendToChannel incoming targetChannel

-- | Publish a message to the database by JSON-encoding it and issuing a
-- NOTIFY on @dbChannel@ through a connection from @pool@.
--
-- Errors reported by the pool are written to stderr rather than silently
-- discarded. They are deliberately not rethrown, so a database failure does
-- not tear down the websocket session that produced the message.
sendToDatabase :: H.Pool -> Text -> Message -> IO ()
sendToDatabase pool dbChannel msg =
  H.notifyPool pool dbChannel (toS jsonPayload) >>= either reportFailure pure
  where
    jsonPayload :: ByteString
    jsonPayload = BL.toStrict (A.encode msg)

    reportFailure err =
      putErrLn ("Error sending message to database channel " <> dbChannel <> ": " <> show err :: Text)

-- | Stamp a message with its delivery time.
--
-- Reads the clock once via @getTime@ and stores the result as POSIX seconds
-- (a JSON number) under @message_delivered_at@ in the message's claims,
-- replacing any value already present for that key.
timestampMessage :: IO UTCTime -> Message -> IO Message
timestampMessage getTime msg = withDeliveryTime <$> getTime
  where
    withDeliveryTime now =
      let seconds = realToFrac (utcTimeToPOSIXSeconds now)
       in msg {claims = M.insert "message_delivered_at" (A.Number seconds) (claims msg)}