{-# LANGUAGE OverloadedStrings #-}

module WebsocketServer (
  ServerState,
  acceptConnection,
  processUpdates
) where

import Control.Concurrent (modifyMVar_, readMVar)
import Control.Concurrent.STM (atomically)
import Control.Concurrent.STM.TBQueue (readTBQueue)
import Control.Exception (SomeAsyncException, SomeException, finally, fromException, catch, throwIO)
import Control.Monad (forever)
import Data.Aeson (Value)
import Data.Text (Text)
import Data.UUID
import System.Random (randomIO)

import qualified Data.Aeson as Aeson
import qualified Data.ByteString.Lazy as LBS
import qualified Data.Time.Clock.POSIX as Clock
import qualified Network.WebSockets as WS
import qualified Network.HTTP.Types.Header as HttpHeader
import qualified Network.HTTP.Types.URI as Uri

import Config (Config (..))
import Core (Core (..), ServerState, Updated (..), getCurrentValue, withCoreMetrics)
import Store (Path)
import AccessControl (AccessMode(..))
import JwtMiddleware (AuthResult (..), isRequestAuthorized, errorResponseBody)

import qualified Metrics
import qualified Subscription

-- | Generate a random 'UUID'. 'handleClient' uses it as the identifier
-- under which a client is registered in the subscription tree.
newUUID :: IO UUID
newUUID = randomIO

-- | Send the updated data to all subscribers to the path.
--
-- Sending is best-effort per client: a synchronous exception raised while
-- sending to one connection is discarded so that the remaining subscribers
-- still receive the update. Asynchronous exceptions are rethrown.
broadcast :: [Text] -> Value -> ServerState -> IO ()
broadcast =
  let
    -- Encode the value as JSON and send it as a text message on one connection.
    send :: WS.Connection -> Value -> IO ()
    send conn value =
      WS.sendTextData conn (Aeson.encode value) `catch` sendFailed

    sendFailed :: SomeException -> IO ()
    sendFailed exc
      -- Rethrow async exceptions, they are meant for inter-thread communication
      -- (e.g. ThreadKilled) and we don't expect them at this point.
      | Just asyncExc <- fromException exc = throwIO (asyncExc :: SomeAsyncException)
      -- We want to catch all other errors in order to prevent them from
      -- bubbling up and disrupting the broadcasts to other clients.
      | otherwise = pure ()
  in
    Subscription.broadcast send

-- | Called for each new client that connects.
--
-- The pending request is first authorized. A rejected request is answered
-- with a 401 response whose JSON body describes the authorization error;
-- an accepted request is upgraded to a websocket connection and handed to
-- 'handleClient', subscribed to the path taken from the request.
acceptConnection :: Core -> WS.PendingConnection -> IO ()
acceptConnection core pending = do
  -- printRequest pending
  -- TODO: Validate the path and headers of the pending request
  authResult <- authorizePendingConnection core pending
  case authResult of
    AuthRejected err ->
      WS.rejectRequestWith pending $ WS.RejectRequest
        { WS.rejectCode = 401
        , WS.rejectMessage = "Unauthorized"
        , WS.rejectHeaders = [(HttpHeader.hContentType, "application/json")]
        , WS.rejectBody = LBS.toStrict $ errorResponseBody err
        }
    AuthAccepted -> do
      -- Only the path segments are needed here; the query string is dropped.
      let path = fst . Uri.decodePath . WS.requestPath $ WS.pendingRequest pending
      connection <- WS.acceptRequest pending
      -- Fork a pinging thread, for each client, to keep idle connections open and to detect
      -- closed connections. Sends a ping message every 30 seconds.
      -- Note: The thread dies silently if the connection crashes or is closed.
      WS.withPingThread connection 30 (pure ()) $ handleClient connection path core

-- * Authorization

-- | Decide whether a pending websocket request may be accepted.
--
-- When JWT authorization is enabled in the configuration, the request's
-- headers, query string and path are checked with 'isRequestAuthorized'
-- for 'ModeRead' access at the current time. When it is disabled, every
-- request is accepted.
authorizePendingConnection :: Core -> WS.PendingConnection -> IO AuthResult
authorizePendingConnection core conn
  | configEnableJwtAuth (coreConfig core) = do
      now <- Clock.getPOSIXTime
      let req = WS.pendingRequest conn
          (path, query) = Uri.decodePath $ WS.requestPath req
          headers = WS.requestHeaders req
      pure $ isRequestAuthorized headers query now (configJwtSecret (coreConfig core)) path ModeRead
  | otherwise = pure AuthAccepted

-- * Client handling

-- | Serve a single accepted websocket client subscribed to @path@.
--
-- The client is registered in the subscription tree under a fresh UUID,
-- receives the current value at @path@ once, and is then kept open by
-- 'keepTalking'. On exit, normal or exceptional, the client is unsubscribed
-- again.
handleClient :: WS.Connection -> Path -> Core -> IO ()
handleClient conn path core = do
  uuid <- newUUID
  let
    state = coreClients core

    -- Register the client and count it in the subscriber metric.
    onConnect :: IO ()
    onConnect = do
      modifyMVar_ state (pure . Subscription.subscribe path uuid conn)
      withCoreMetrics core Metrics.incrementSubscribers

    -- Undo 'onConnect': remove the client and decrement the metric.
    onDisconnect :: IO ()
    onDisconnect = do
      modifyMVar_ state (pure . Subscription.unsubscribe path uuid)
      withCoreMetrics core Metrics.decrementSubscribers

    -- Send the JSON encoding of the value currently stored at the path.
    sendInitialValue :: IO ()
    sendInitialValue = do
      currentValue <- getCurrentValue core path
      WS.sendTextData conn (Aeson.encode currentValue)

    -- simply ignore connection errors, otherwise, warp handles the exception
    -- and sends a 500 response in the middle of a websocket connection, and
    -- that violates the websocket protocol.
    -- Note that subscribers are still properly removed by the finally below
    handleConnectionError :: WS.ConnectionException -> IO ()
    handleConnectionError _ = pure ()
  -- Put the client in the subscription tree and keep the connection open.
  -- Remove it when the connection is closed.
  finally (onConnect >> sendInitialValue >> keepTalking conn) onDisconnect
    `catch` handleConnectionError

-- | Keep the client's thread alive by receiving messages forever.
--
-- We don't send any messages here; sending is done by the update
-- loop; it finds the client in the set of subscriptions. But we do
-- need to keep the thread running, otherwise the connection will be
-- closed. So we go into an infinite loop here. Received data messages
-- are discarded.
keepTalking :: WS.Connection -> IO ()
keepTalking conn =
  -- Note: WS.receiveDataMessage will handle control messages automatically and e.g.
  -- do the closing handshake of the websocket protocol correctly
  forever (WS.receiveDataMessage conn)

-- | Loop that reads updates from the core's update queue and broadcasts
-- each new value to all subscribers of the updated path.
--
-- The loop blocks while the queue is empty and returns once a 'Nothing'
-- is read from it.
processUpdates :: Core -> IO ()
processUpdates core = go
  where
    go :: IO ()
    go = do
      maybeUpdate <- atomically $ readTBQueue (coreUpdates core)
      case maybeUpdate of
        Just (Updated path value) -> do
          -- Take a snapshot of the subscribers; the MVar is not held while sending.
          clients <- readMVar (coreClients core)
          broadcast path value clients
          go
        -- Stop the loop when we receive a Nothing.
        Nothing -> pure ()