{-# LANGUAGE OverloadedStrings #-}
{-# LANGUAGE ScopedTypeVariables #-}
{-# LANGUAGE DeriveDataTypeable #-}

module Network.Mattermost.WebSocket
( MMWebSocket
, MMWebSocketTimeoutException
, mmWithWebSocket
, mmCloseWebSocket
, mmSendWSAction
, mmGetConnectionHealth
, module Network.Mattermost.WebSocket.Types
) where

import           Control.Concurrent (ThreadId, forkIO, killThread, myThreadId, threadDelay)
import qualified Control.Concurrent.STM.TQueue as Queue
import           Control.Exception (Exception, SomeException, catch, evaluate, finally, throwIO, throwTo, try)
import           Control.Monad (forever)
import           Control.Monad.STM (atomically)
import           Data.Aeson (toJSON)
import qualified Data.ByteString.Char8 as B
import           Data.ByteString.Lazy (toStrict)
import           Data.IORef
import           Data.Monoid ((<>))
import qualified Data.Text as T
import           Data.Time.Clock (NominalDiffTime, UTCTime, diffUTCTime, getCurrentTime)
import           Data.Typeable ( Typeable )
import           Network.Connection ( Connection
                                    , connectionClose
                                    , connectionGet
                                    , connectionPut
                                    )
import qualified Network.WebSockets as WS
import           Network.WebSockets.Stream (Stream, makeStream)

import           Network.Mattermost.Util
import           Network.Mattermost.Types.Base
import           Network.Mattermost.Types.Internal
import           Network.Mattermost.Types
import           Network.Mattermost.WebSocket.Types


-- | Wrap a "Network.Connection" 'Connection' as a websockets 'Stream'.
--
-- Each read asks 'connectionGet' for at most 1024 bytes. An empty chunk is
-- handed to the stream as 'Nothing' (end of input), anything else as
-- 'Just' the chunk.
--
-- Writing 'Nothing' closes the connection. Writing 'Just' a lazy chunk sends
-- it with 'connectionPut' after converting it to a strict 'ByteString'.
connectionToStream :: Connection -> IO Stream
connectionToStream con = makeStream readChunk writeChunk
  where
    readChunk :: IO (Maybe B.ByteString)
    readChunk = do
      chunk <- connectionGet con 1024
      return (if B.null chunk then Nothing else Just chunk)

    writeChunk = maybe (connectionClose con) (connectionPut con . toStrict)

-- | Handle to an open Mattermost websocket: the underlying websockets
-- connection, plus a reference holding the most recently measured
-- ping-to-pong round-trip time (read it with 'mmGetConnectionHealth').
--
-- Both fields are strict: they are always filled with already-constructed
-- values, so there is no reason to let thunks be stored in the handle.
data MMWebSocket = MMWS !WS.Connection !(IORef NominalDiffTime)

-- | Raised asynchronously (via 'throwTo') in the thread that called
-- 'mmWithWebSocket' when a ping has not been answered by a pong within the
-- watchdog's waiting period.
data MMWebSocketTimeoutException = MMWebSocketTimeoutException
  deriving (Show, Typeable)

-- Default 'toException' / 'fromException', so the exception can be caught
-- either directly or through 'SomeException'.
instance Exception MMWebSocketTimeoutException

-- | A timestamp recorded when a ping is sent or a pong is received. The
-- watchdog in 'createPingPongTimeouts' pairs these up to measure the
-- round-trip time. A 'newtype' because it only wraps a single value.
newtype PEvent = P UTCTime

-- | Build the ping/pong bookkeeping for a websocket connection.
--
-- Arguments: the thread to notify on timeout, the reference to store the
-- round-trip time in, the number of seconds to wait for a pong after each
-- ping, and a logging function.
--
-- Returns @(onPing, onPong, watchdogThreadId)@. @onPing@ and @onPong@ log
-- the corresponding event and queue the current time.
--
-- The forked watchdog thread repeatedly:
--
-- 1. waits for a queued ping;
-- 2. sleeps for the given number of seconds;
-- 3. if no pong is queued, throws 'MMWebSocketTimeoutException' to the
--    given thread and exits;
-- 4. otherwise dequeues one pong, stores @pong time - ping time@ in the
--    reference, and starts over.
--
-- NOTE(review): pongs are matched to pings purely by queue order, so an
-- unsolicited or duplicate pong would be paired with a later ping.
createPingPongTimeouts :: ThreadId
                       -> IORef NominalDiffTime
                       -> Int
                       -> (LogEventType -> IO ())
                       -> IO (IO (), IO (), ThreadId)
createPingPongTimeouts pId health n doLog = do
  pings <- Queue.newTQueueIO
  pongs <- Queue.newTQueueIO
  -- stamp: take the time, log the event, then queue the timestamp.
  -- watchdog: see the numbered steps above.
  let stamp evType queue = do
        now <- getCurrentTime
        doLog evType
        atomically (Queue.writeTQueue queue (P now))

      watchdog = do
        P sentAt <- atomically (Queue.readTQueue pings)
        threadDelay (n * 1000 * 1000)
        noPong <- atomically (Queue.isEmptyTQueue pongs)
        if noPong
          then throwTo pId MMWebSocketTimeoutException
          else do
            P answeredAt <- atomically (Queue.readTQueue pongs)
            atomicWriteIORef health (answeredAt `diffUTCTime` sentAt)
            watchdog
  watchdogId <- forkIO watchdog
  return (stamp WebSocketPing pings, stamp WebSocketPong pongs, watchdogId)

-- | Send a websocket close message with an empty payload on the connection
-- held by the handle.
mmCloseWebSocket :: MMWebSocket -> IO ()
mmCloseWebSocket ws = case ws of
  MMWS conn _ -> WS.sendClose conn B.empty

-- | Read the most recently recorded ping-to-pong round-trip time.
--
-- The ping/pong watchdog overwrites this value after every answered ping.
-- Until the first pong it still holds the value the handle was created with.
mmGetConnectionHealth :: MMWebSocket -> IO NominalDiffTime
mmGetConnectionHealth ws = case ws of
  MMWS _ healthRef -> readIORef healthRef

-- | Ping @conn@ every 10 seconds, forever.
--
-- Before each ping, @onPingAction@ is run. Each ping's payload is a decimal
-- sequence number, starting at 0 and increasing by one per ping. The loop
-- never returns normally; it ends only when an exception (for example from
-- 'WS.sendPing') escapes it.
pingThread :: IO () -> WS.Connection -> IO ()
pingThread onPingAction conn = mapM_ pingOnce (iterate (+ 1) (0 :: Int))
  where
    pingOnce :: Int -> IO ()
    pingOnce seqNo = do
      threadDelay (10 * 1000 * 1000)
      onPingAction
      WS.sendPing conn (B.pack (show seqNo))

-- | Open the Mattermost websocket for a session and run @body@ with a handle
-- to it.
--
-- The handshake goes to the path built from @"/websocket"@ and carries an
-- @Authorization: Bearer <token>@ header taken from the session.
--
-- While @body@ runs, two helper threads are active:
--
-- * a receive loop that decodes every incoming data message as either a
--   'WebsocketEvent' or a 'WebsocketActionResponse', logs it, and passes it
--   to @recv@. If a message decodes as neither, both parse errors are logged
--   and the loop stops.
-- * a ping loop that sends a ping every 10 seconds. A watchdog throws
--   'MMWebSocketTimeoutException' to the calling thread when no pong has
--   been seen 8 seconds after a ping.
--
-- If @body@ throws, the exception is forwarded to both helper threads and
-- then re-raised. On every exit path, including a normal return, the helper
-- threads and the watchdog are killed. Nothing therefore keeps pinging,
-- reading, or throwing timeouts at the caller after this function returns.
--
-- NOTE(review): the network connection is not explicitly closed here;
-- confirm whether callers or the websockets library are expected to close it.
mmWithWebSocket :: Session
                -> (Either String (Either WebsocketActionResponse WebsocketEvent) -> IO ())
                -> (MMWebSocket -> IO ())
                -> IO ()
mmWithWebSocket (Session cd (Token tk)) recv body = do
  con <- mkConnection (cdConnectionCtx cd) (cdHostname cd) (cdPort cd) (cdConnectionType cd)
  stream <- connectionToStream con
  health <- newIORef 0
  myId <- myThreadId
  let doLog = runLogger cd "websocket"
  -- Keep the watchdog's ThreadId so it can be stopped on exit. Otherwise a
  -- ping queued just before shutdown lets it throw a timeout at this thread
  -- after we have already returned (e.g. into a reconnect attempt).
  (onPing, onPong, watchdogId) <- createPingPongTimeouts myId health 8 doLog
  let action c = do
        pId <- forkIO (pingThread onPing c `catch` cleanup)
        mId <- forkIO $ flip catch cleanup $ forever $ do
          result :: Either SomeException WS.DataMessage <-
            try $ do
              msg <- WS.receiveDataMessage c
              msg `seq` return msg
          val <- case result of
            Left e -> do
              doLog $ WebSocketResponse $ Right $
                "Got exception on receiveDataMessage: " <> show e
              throwIO e
            Right dataMsg -> do
              -- The message could be either a websocket event or an action
              -- response. Those have different Haskell types, so we attempt
              -- to parse each in turn.
              evResult <- try $ evaluate $ WS.fromDataMessage dataMsg
              case evResult of
                Right wev -> return $ Right $ Right wev
                Left (e1 :: SomeException) -> do
                  respResult <- try $ evaluate $ WS.fromDataMessage dataMsg
                  case respResult of
                    Right actionResp -> return $ Right $ Left actionResp
                    Left (e2 :: SomeException) -> do
                      doLog $ WebSocketResponse $ Left $
                        "Failed to parse (exceptions following): " <> show dataMsg
                      doLog $ WebSocketResponse $ Left $
                        "Failed to parse as a websocket event: " <> show e1
                      doLog $ WebSocketResponse $ Left $
                        "Failed to parse as a websocket action response: " <> show e2
                      -- Log both exceptions, but throw the second. We don't
                      -- know which one is the *right* one, so the best we can
                      -- do is log both and throw one of them.
                      throwIO e2
          doLog $ WebSocketResponse $ case val of
            Left s -> Left s
            Right (Left v) -> Right $ show v
            Right (Right v) -> Right $ show v
          recv val
        -- Forward body's exceptions to the helpers, and stop the helpers on
        -- every exit path. Previously they were left running after a normal
        -- return.
        (body (MMWS c health) `catch` propagate [mId, pId])
          `finally` mapM_ killThread [pId, mId]
  path <- buildPath cd "/websocket"
  WS.runClientWithStream stream
                         (T.unpack $ cdHostname cd)
                         (T.unpack path)
                         WS.defaultConnectionOptions { WS.connectionOnPong = onPong }
                         [ ("Authorization", "Bearer " <> B.pack tk) ]
                         action
    `finally` killThread watchdogId
  where
    -- Helper threads end silently on any exception.
    cleanup :: SomeException -> IO ()
    cleanup _ = return ()

    -- Deliver the exception to each helper thread, then re-raise it here.
    propagate :: [ThreadId] -> SomeException -> IO ()
    propagate ts e = do
      sequence_ [ throwTo t e | t <- ts ]
      throwIO e

-- | Log an outgoing websocket action (as a JSON 'WebSocketRequest' under the
-- "websocket" logger), then send it over the connection with
-- 'WS.sendTextData'.
mmSendWSAction :: ConnectionData -> MMWebSocket -> WebsocketAction -> IO ()
mmSendWSAction cd (MMWS conn _) wsAction =
  logRequest >> WS.sendTextData conn wsAction
  where
    logRequest = runLogger cd "websocket" (WebSocketRequest (toJSON wsAction))