module Matterhorn.Connection where

import           Prelude ()
import           Matterhorn.Prelude

import           Control.Concurrent ( forkIO, threadDelay, killThread )
import qualified Control.Concurrent.STM as STM
import           Control.Exception ( SomeException, catch, AsyncException(..), throwIO )
import qualified Control.Exception as E
import qualified Data.HashMap.Strict as HM
import           Data.Int (Int64)
import           Data.Semigroup ( Max(..) )
import qualified Data.Text as T
import           Data.Time ( UTCTime(..), secondsToDiffTime, getCurrentTime
                           , diffUTCTime )
import           Data.Time.Calendar ( Day(..) )
import           Lens.Micro.Platform ( (.=) )

import           Network.Mattermost.Types ( ChannelId )
import qualified Network.Mattermost.WebSocket as WS

import           Matterhorn.Constants
import           Matterhorn.Types


-- | Start (or restart) the websocket connection thread.
--
-- If a previous websocket thread is recorded in 'crWebsocketThreadId'
-- it is killed first. A new thread is then forked that opens a
-- websocket for the current session, forwards every parse error,
-- incoming event, and action response to the event queue, announces
-- 'WebsocketConnect', and then runs 'processWebsocketActions' to send
-- queued outgoing actions. The new thread's 'ThreadId' is stored back
-- in 'crWebsocketThreadId'.
--
-- Exceptions escaping the websocket thread are dispatched to exactly
-- one handler, tried in this order:
--
-- * 'ignoreThreadKilled' for 'AsyncException's;
-- * 'handleTimeout' (reconnect after 1 second) for websocket timeouts;
-- * 'handleError' (reconnect after 5 seconds) for anything else.
--
-- The handlers are installed with 'E.catches' rather than nested
-- 'catch' calls on purpose: with nested 'catch', an exception raised
-- /inside/ an inner handler (e.g. a 'ThreadKilled' delivered while
-- 'handleTimeout' sleeps before reconnecting) would be caught again by
-- the outer catch-all handler, which would then request a reconnect
-- from a thread that was deliberately killed.
connectWebsockets :: MH ()
connectWebsockets = do
  logger <- mhGetIOLogger

  -- If we have an old websocket thread, kill it.
  mOldTid <- use (csResources.crWebsocketThreadId)
  case mOldTid of
      Nothing -> return ()
      Just oldTid -> liftIO $ do
          logger LogWebsocket "Terminating previous websocket thread"
          killThread oldTid

  st <- use id
  session <- getSession

  tid <- liftIO $ do
    let eventQueue = st^.csResources.crEventQueue
        -- Route everything the websocket produces onto the event queue.
        shunt (Left msg) = writeBChan eventQueue (WebsocketParseError msg)
        shunt (Right (Right e)) = writeBChan eventQueue (WSEvent e)
        shunt (Right (Left e)) = writeBChan eventQueue (WSActionResponse e)
        runWS = WS.mmWithWebSocket session shunt $ \ws -> do
                  writeBChan eventQueue WebsocketConnect
                  processWebsocketActions st ws 1 HM.empty
    logger LogWebsocket "Starting new websocket thread"
    -- Exactly one of these handlers runs; an exception thrown by a
    -- handler is not re-dispatched to the others.
    forkIO $ runWS `E.catches` [ E.Handler ignoreThreadKilled
                               , E.Handler (handleTimeout logger 1 st)
                               , E.Handler (handleError logger 5 st)
                               ]

  csResources.crWebsocketThreadId .= Just tid

-- | Exception handler for the websocket thread.
--
-- 'ThreadKilled' (the exception 'killThread' delivers) is absorbed so
-- the thread simply ends; every other 'AsyncException' is re-thrown
-- unchanged.
ignoreThreadKilled :: AsyncException -> IO ()
ignoreThreadKilled e
    | e == ThreadKilled = return ()
    | otherwise         = throwIO e

-- | Take websocket actions from the websocket action channel in the
-- ChatState and send them to the server over the websocket.
--
-- Takes and propagates the action sequence number, which is attached to
-- each outgoing action and incremented only after a send.
--
-- Keeps and propagates a map of channel id to the time of the last
-- user_typing notification sent for that channel. A new user_typing
-- action is sent only if its timestamp is at least
-- @userTypingExpiryInterval / 2 - 0.5@ after the last recorded send for
-- its channel; otherwise it is dropped and the loop continues with the
-- sequence number and map unchanged.
--
-- This function loops indefinitely, blocking on the action channel
-- between actions.
processWebsocketActions :: ChatState -> WS.MMWebSocket -> Int64 -> HashMap ChannelId (Max UTCTime) -> IO ()
processWebsocketActions st ws s userTypingLastNotifTimeMap = do
  action <- STM.atomically $ STM.readTChan (st^.csResources.crWebsocketActionChan)
  case action of
    UserTyping ts cId pId
      | diffUTCTime ts (lastSentTo cId) >= throttleGap -> do
          WS.mmSendWSAction (st^.csResources.crConn) ws (WS.UserTyping s cId pId)
          sentAt <- getCurrentTime
          let sentMap = HM.insertWith (<>) cId (Max sentAt) userTypingLastNotifTimeMap
          processWebsocketActions st ws (s + 1) sentMap
      | otherwise ->
          processWebsocketActions st ws s userTypingLastNotifTimeMap
  where
    -- Minimum spacing between two user_typing sends for one channel.
    throttleGap = userTypingExpiryInterval / 2 - 0.5

    -- Time of the last recorded send for a channel, or 'longAgo' for a
    -- channel with no recorded send.
    lastSentTo cId =
      maybe longAgo getMax (HM.lookup cId userTypingLastNotifTimeMap)

    -- Modified Julian Day 0 (1858-11-17), midnight.
    longAgo = UTCTime (ModifiedJulianDay 0) (secondsToDiffTime 0)

-- | Handler for a 'WS.MMWebSocketTimeoutException' escaping the
-- websocket thread: log it under 'LogWebsocket', then 'reconnectAfter'
-- the given number of seconds.
handleTimeout :: (LogCategory -> Text -> IO ()) -> Int -> ChatState -> WS.MMWebSocketTimeoutException -> IO ()
handleTimeout logger seconds st e = do
    let msg = "Websocket timeout exception: " <> show e
    logger LogWebsocket (T.pack msg)
    reconnectAfter seconds st

-- | Catch-all handler for exceptions escaping the websocket thread.
--
-- Asynchronous exceptions (anything wrapped in 'E.SomeAsyncException',
-- e.g. the 'ThreadKilled' that 'killThread' delivers) are re-thrown
-- unchanged instead of being treated as connection errors. Swallowing
-- them would log a bogus error and schedule a reconnect from a thread
-- that was deliberately stopped.
--
-- Every other exception is logged under 'LogWebsocket' and followed by
-- 'reconnectAfter' the given number of seconds.
handleError :: (LogCategory -> Text -> IO ()) -> Int -> ChatState -> SomeException -> IO ()
handleError logger seconds st e =
    case E.fromException e :: Maybe E.SomeAsyncException of
        Just _ -> throwIO e
        Nothing -> do
            logger LogWebsocket $ T.pack $ "Websocket error: " <> show e
            reconnectAfter seconds st

-- | Post 'WebsocketDisconnect' to the event queue, wait the given number
-- of seconds, then post 'RefreshWebsocketEvent'.
reconnectAfter :: Int -> ChatState -> IO ()
reconnectAfter seconds st = do
  let queue       = st^.csResources.crEventQueue
      delayMicros = seconds * 1000 * 1000  -- threadDelay takes microseconds
  writeBChan queue WebsocketDisconnect
  threadDelay delayMicros
  writeBChan queue RefreshWebsocketEvent