module Matterhorn.Connection where

import           Prelude ()
import           Matterhorn.Prelude

import           Control.Concurrent ( forkIO, threadDelay, killThread )
import qualified Control.Concurrent.STM as STM
import           Control.Exception ( SomeException, catch, AsyncException(..), throwIO )
import qualified Data.HashMap.Strict as HM
import           Data.Int (Int64)
import           Data.Semigroup ( Max(..) )
import qualified Data.Text as T
import           Data.Time ( UTCTime(..), secondsToDiffTime, getCurrentTime
                           , diffUTCTime )
import           Data.Time.Calendar ( Day(..) )
import           Lens.Micro.Platform ( (.=) )

import           Network.Mattermost.Types ( ChannelId )
import qualified Network.Mattermost.WebSocket as WS

import           Matterhorn.Constants
import           Matterhorn.Types


-- | (Re)start the background thread that services the websocket
-- connection.
--
-- If a websocket thread id is currently recorded in
-- 'crWebsocketThreadId', that thread is killed first. A new thread is
-- then forked which opens the websocket, forwards everything it
-- receives to the application event queue, and sends queued outbound
-- actions via 'processWebsocketActions'. The new thread's id is stored
-- back into 'crWebsocketThreadId'.
connectWebsockets :: MH ()
connectWebsockets = do
  logger <- mhGetIOLogger

  -- If we have an old websocket thread, kill it.
  mOldTid <- use (csResources.crWebsocketThreadId)
  case mOldTid of
      Nothing -> return ()
      Just oldTid -> liftIO $ do
          logger LogWebsocket "Terminating previous websocket thread"
          killThread oldTid

  -- Snapshot the state: the forked thread works from this copy and does
  -- not observe later changes made in the MH monad.
  st <- use id
  session <- getSession

  tid <- liftIO $ do
    let eventQueue = st^.csResources.crEventQueue

        -- Translate whatever the websocket layer hands us into an
        -- application event and enqueue it.
        shunt (Left msg) = writeBChan eventQueue (WebsocketParseError msg)
        shunt (Right (Right e)) = writeBChan eventQueue (WSEvent e)
        shunt (Right (Left e)) = writeBChan eventQueue (WSActionResponse e)

        -- Announce the connection, then loop sending outbound actions,
        -- starting at sequence number 1 with no typing notifications
        -- recorded yet.
        runWS = WS.mmWithWebSocket session shunt $ \ws -> do
                  writeBChan eventQueue WebsocketConnect
                  processWebsocketActions st ws 1 HM.empty

    logger LogWebsocket "Starting new websocket thread"
    -- The handlers nest from the inside out: 'ThreadKilled' (our own
    -- 'killThread' above) ends the thread quietly, a websocket timeout
    -- schedules a reconnect after 1 second, and any other exception
    -- schedules a reconnect after 5 seconds.
    --
    -- NOTE(review): an 'AsyncException' other than 'ThreadKilled' is
    -- rethrown by 'ignoreThreadKilled' and is then caught by the
    -- outermost 'SomeException' handler, so it also leads to a
    -- reconnect -- confirm that this is intended.
    forkIO $ runWS `catch` ignoreThreadKilled
                   `catch` handleTimeout logger 1 st
                   `catch` handleError logger 5 st

  csResources.crWebsocketThreadId .= Just tid

-- | Exception handler that swallows 'ThreadKilled' and rethrows every
-- other 'AsyncException' unchanged.
ignoreThreadKilled :: AsyncException -> IO ()
ignoreThreadKilled ThreadKilled = return ()
ignoreThreadKilled e = throwIO e

-- | Take websocket actions from the websocket action channel in the
-- ChatState and send them to the server over the websocket.
--
-- Takes and propagates the action sequence number, which is
-- incremented after each action that is sent.
--
-- Keeps and propagates a map from channel id to the time the last
-- user_typing notification was sent, so that new user_typing actions
-- are throttled rather than sent every time one is queued.
--
-- This function loops by calling itself and has no terminating branch;
-- it ends only when an exception is raised in its thread.
processWebsocketActions :: ChatState -> WS.MMWebSocket -> Int64 -> HashMap ChannelId (Max UTCTime) -> IO ()
processWebsocketActions st ws s userTypingLastNotifTimeMap = do
  -- Block until the application queues another action.
  action <- STM.atomically $ STM.readTChan (st^.csResources.crWebsocketActionChan)
  if shouldSendAction action
    then do
      WS.mmSendWSAction (st^.csResources.crConn) ws $ convert action
      now <- getCurrentTime
      processWebsocketActions st ws (s + 1) $ userTypingLastNotifTimeMap' action now
    else
      -- Throttled: drop the action, keeping the sequence number and the
      -- notification times unchanged.
      processWebsocketActions st ws s userTypingLastNotifTimeMap
  where
    -- Build the wire-level action, stamping it with the current
    -- sequence number. The timestamp carried by the queued action is
    -- not transmitted.
    convert (UserTyping _ cId pId) = WS.UserTyping s cId pId

    -- Send only if enough time has passed between the last notification
    -- recorded for this channel and the timestamp carried by the
    -- action.
    shouldSendAction (UserTyping ts cId _) =
      diffUTCTime ts (userTypingLastNotifTime cId) >= (userTypingExpiryInterval / 2 - 0.5)

    -- Time of the last notification sent for the channel, or 'zeroTime'
    -- if none has been recorded.
    userTypingLastNotifTime cId = getMax $ HM.lookupDefault (Max zeroTime) cId userTypingLastNotifTimeMap

    -- Day zero of the Modified Julian calendar at midnight; stands in
    -- for "never notified".
    zeroTime = UTCTime (ModifiedJulianDay 0) (secondsToDiffTime 0)

    -- Record the send time for the action's channel. 'Max' keeps the
    -- later of the existing and the new time.
    userTypingLastNotifTimeMap' (UserTyping _ cId _) now =
      HM.insertWith (<>) cId (Max now) userTypingLastNotifTimeMap

-- | Exception handler for websocket timeouts.
--
-- Logs the exception under 'LogWebsocket' using the supplied logger and
-- then calls 'reconnectAfter' with the given delay in seconds.
handleTimeout :: (LogCategory -> Text -> IO ()) -> Int -> ChatState -> WS.MMWebSocketTimeoutException -> IO ()
handleTimeout logger seconds st e = do
    logger LogWebsocket $ T.pack $ "Websocket timeout exception: " <> show e
    reconnectAfter seconds st

-- | Catch-all exception handler for the websocket thread.
--
-- Logs the exception under 'LogWebsocket' using the supplied logger and
-- then calls 'reconnectAfter' with the given delay in seconds.
handleError :: (LogCategory -> Text -> IO ()) -> Int -> ChatState -> SomeException -> IO ()
handleError logger seconds st e = do
    logger LogWebsocket $ T.pack $ "Websocket error: " <> show e
    reconnectAfter seconds st

-- | Announce a websocket disconnect, wait, then ask for a refresh.
--
-- Writes 'WebsocketDisconnect' to the application event queue, sleeps
-- for the given number of seconds, and then writes
-- 'RefreshWebsocketEvent' to the same queue.
reconnectAfter :: Int -> ChatState -> IO ()
reconnectAfter seconds st = do
  let eventQueue = st^.csResources.crEventQueue
  writeBChan eventQueue WebsocketDisconnect
  -- 'threadDelay' takes microseconds.
  threadDelay (seconds * 1000 * 1000)
  writeBChan eventQueue RefreshWebsocketEvent