{-# LANGUAGE BangPatterns #-}
{-# LANGUAGE FlexibleContexts #-}
{-# LANGUAGE DeriveDataTypeable #-}
{-# LANGUAGE DoAndIfThenElse #-}
{-# LANGUAGE RankNTypes #-}
{-# LANGUAGE ScopedTypeVariables #-}
{-# LANGUAGE TupleSections #-}
module DMCC.Session
( Session (..)
, ConnectionType (..)
, startSession
, stopSession
, defaultSessionOptions
, DMCCError (..)
, DMCCHandle (..)
, sendRequestSync
, sendRequestAsync
)
where
import DMCC.Prelude
import Control.Arrow ()
import Control.Concurrent.STM.TMVar (tryPutTMVar)
import Data.ByteString (ByteString)
import qualified Data.Map.Strict as Map
import qualified Data.Set as Set
import qualified Data.IntMap.Strict as IntMap
import Data.Text as T (Text, empty)
import Data.Typeable ()
import System.IO.Streams ( InputStream
, OutputStream
, ReadTooShortException
, write
)
import System.IO.Streams.Handle
import qualified System.IO.Streams.SSL as SSLStreams
import Network
import qualified Network.HTTP.Client as HTTP
import Network.Socket hiding (connect)
import OpenSSL
import qualified OpenSSL.Session as SSL
import DMCC.Types
import DMCC.XML.Request (Request)
import qualified DMCC.XML.Request as Rq
import DMCC.XML.Response (Response)
import qualified DMCC.XML.Response as Rs
import qualified DMCC.XML.Raw as Raw
import {-# SOURCE #-} DMCC.Agent
-- | How 'startSession' reaches the DMCC server.
data ConnectionType = Plain
                    -- ^ Plain TCP connection.
                    | TLS { caDir :: Maybe FilePath }
                    -- ^ OpenSSL connection that verifies the peer
                    -- certificate; 'caDir', when given, is installed as the
                    -- CA certificate directory of the SSL context.

-- | An open connection: the input stream, the output stream, and an action
-- that closes the underlying handle or SSL socket.
type ConnectionData = (InputStream ByteString, OutputStream ByteString, IO ())
-- | Connection-level state shared by the request functions and by the
-- background threads started in 'startSession'.
--
-- NOTE: 'startSession' builds this value positionally, so the field order
-- matters.
data DMCCHandle = DMCCHandle
  { connection :: TMVar ConnectionData
    -- ^ Current connection. Empty while a request is being written or while
    -- a reconnect is replacing it.
  , dmccSession :: TMVar (Text, Int)
    -- ^ Session ID and session duration reported by the server. Empty
    -- until a session has been (re)started; the public request functions
    -- wait on it before sending.
  , reconnect :: forall m
               . (MonadUnliftIO m, MonadLoggerIO m, MonadBaseControl IO m, MonadCatch m)
              => m ()
    -- ^ Replace the current connection with a new one and restart the
    -- DMCC session.
  , pingThread :: ThreadId
    -- ^ Thread that periodically resets the application session timer.
  , readThread :: ThreadId
    -- ^ Thread that reads responses from the connection.
  , procThread :: ThreadId
    -- ^ Thread that dispatches responses to waiters and agents.
  , invokeId :: TVar Int
    -- ^ Most recently allocated invoke ID.
  , syncResponses :: TVar (IntMap.IntMap (TMVar (Maybe Response)))
    -- ^ Response slots of pending synchronous requests, by invoke ID.
  , agentRequests :: TVar (IntMap.IntMap AgentId)
    -- ^ Agent that issued each outstanding request, by invoke ID.
  , sessionOptions :: SessionOptions
  }
-- | A running DMCC application session, as returned by 'startSession'.
data Session = Session
  { protocolVersion :: Text
    -- ^ Protocol version reported by the server when the session started.
  , dmccHandle :: DMCCHandle
    -- ^ Connection state and background threads.
  , webHook :: Maybe (HTTP.Request, HTTP.Manager)
    -- ^ Parsed web hook request and its HTTP manager, if a URL was given.
  , agents :: TVar (Map.Map AgentId Agent)
    -- ^ Agents known to this session, by agent ID.
  , agentLocks :: TVar (Set.Set AgentId)
  }
-- | Renders only the negotiated protocol version; the handle, web hook and
-- agent state are not shown.
instance Show Session where
  show sess =
    let ver = unpack (protocolVersion sess)
    in "Session{protocolVersion=" <> ver <> "}"
-- | Items sent from the read thread to the processing thread in
-- 'startSession' (paired with the invoke ID of the response).
--
-- NOTE(review): 'Timeout' and 'ReadError' are never constructed in this
-- module, and the processing thread ignores them.
data LoopEvent
  = DMCCRsp Response
    -- ^ A response decoded from the connection.
  | Timeout
  | ReadError
  deriving Show
-- | Errors thrown by this module.
data DMCCError = ApplicationSessionFailed
               -- ^ No positive @StartApplicationSession@ response could be
               -- obtained (including the retry with a fresh session ID
               -- after a failed resume).
  deriving (Show, Typeable)
instance Exception DMCCError
-- | Default 'SessionOptions'.
--
-- The values are passed positionally; see the definition of
-- 'SessionOptions' in "DMCC.Types" for which field each number fills.
-- NOTE(review): record syntax here would make the defaults
-- self-describing.
defaultSessionOptions :: SessionOptions
defaultSessionOptions = SessionOptions 1 120 24 5
-- | Connect to a DMCC server and start an application session.
--
-- Besides opening the connection and starting the session, this forks
-- three background threads whose IDs are stored in the returned
-- 'DMCCHandle':
--
-- * a read thread that decodes responses from the connection and passes
--   them to the processing thread, reconnecting on network errors;
--
-- * a processing thread that completes pending synchronous requests and
--   forwards event and CSTA error responses to the matching agent;
--
-- * a ping thread that periodically resets the application session timer.
--
-- Throws 'ApplicationSessionFailed' if no session can be started. Rethrows
-- the last network exception once connection retries are exhausted.
-- Propagates the exception from 'HTTP.parseUrlThrow' for a malformed web
-- hook URL.
--
-- NOTE(review): the three threads are forked before the initial connect;
-- if the connect or the session start throws, they are not killed.
startSession :: (MonadUnliftIO m, MonadLoggerIO m, MonadBaseControl IO m, MonadCatch m)
             => (String, PortNumber)
             -- ^ Server host name and port.
             -> ConnectionType
             -- ^ Plain TCP or TLS.
             -> Text
             -- ^ User name sent in @StartApplicationSession@.
             -> Text
             -- ^ Password sent in @StartApplicationSession@.
             -> Maybe String
             -- ^ Optional web hook URL.
             -> SessionOptions
             -> m Session
startSession (host, port) ct user pass whUrl sopts = do
  -- Shared state: pending sync responses and agent requests keyed by invoke
  -- ID, the invoke ID counter, and the initially empty connection and
  -- session slots.
  syncResponses <- newTVarIO IntMap.empty
  agentRequests <- newTVarIO IntMap.empty
  invoke <- newTVarIO 0
  conn <- newEmptyTMVarIO
  sess <- newEmptyTMVarIO
  let
    -- Open a new connection. After a network error it retries up to
    -- 'connectionRetryAttempts' more times, pausing 'connectionRetryDelay'
    -- seconds between attempts, and finally rethrows the last error.
    connect :: (MonadUnliftIO m, MonadBase IO m, MonadLoggerIO m, MonadCatch m) => m ConnectionData
    connect = connect1 (connectionRetryAttempts sopts)
      where
        connectExHandler
          :: (Exception e, Show e, MonadUnliftIO m, MonadLoggerIO m, MonadBase IO m, MonadCatch m)
          => Int -> e -> m ConnectionData
        connectExHandler attempts e = do
          logErrorN $ "Connection failed: " <> tshow e
          -- threadDelay takes microseconds, hence the scaling.
          if attempts > 0
            then threadDelay (connectionRetryDelay sopts * 1000000) >> connect1 (attempts - 1)
            else throwIO e
        connect1 attempts =
          handleNetwork (connectExHandler attempts) $
          liftIO $ case ct of
            Plain -> do
              h <- connectTo host (PortNumber $ fromIntegral port)
              hSetBuffering h NoBuffering
              is <- handleToInputStream h
              os <- handleToOutputStream h
              let cl = hClose h
              pure (is, os, cl)
            TLS caDir -> withOpenSSL $ do
              sslCtx <- SSL.context
              SSL.contextSetDefaultCiphers sslCtx
              -- Verify the server certificate, using the configured CA
              -- directory when one is given.
              SSL.contextSetVerificationMode sslCtx $
                SSL.VerifyPeer True True Nothing
              maybe (pure ()) (SSL.contextSetCADirectory sslCtx) caDir
              (is, os, ssl) <- SSLStreams.connect sslCtx host port
              let cl = do
                    SSL.shutdown ssl SSL.Unidirectional
                    maybe (pure ()) close $ SSL.sslSocket ssl
              pure (is, os, cl)
    -- Start a DMCC application session and a session monitor on the
    -- current connection. With @old = Just sid@ the request carries @sid@.
    -- If the server answers that request negatively, a fresh session is
    -- started and the monitors of the old session are transferred to it.
    -- Returns the (session ID, session duration) pair reported by the
    -- server, together with the negotiated protocol version.
    startDMCCSession :: (MonadUnliftIO m, MonadLoggerIO m, MonadBaseControl IO m, MonadCatch m)
                     => Maybe Text
                     -> m ((Text, Int), Text)
    startDMCCSession old = do
      let
        sendReq =
          sendRequestSyncRaw
            conn
            reconnect
            invoke
            syncResponses
            Nothing
        startReq sid =
          sendReq
            Rq.StartApplicationSession
              { applicationId = ""
              , requestedProtocolVersion = Rq.DMCC_6_2
              , userName = user
              , password = pass
              , sessionCleanupDelay = sessionDuration sopts
              , sessionID = fromMaybe T.empty sid
              , requestedSessionDuration = sessionDuration sopts
              }
        sessionMonitorReq proto =
          sendReq
            Rq.MonitorStart
              { acceptedProtocol = proto
              , monitorRq = Rq.Session
              }
      startRsp <- startReq old
      case (startRsp, old) of
        (Just Rs.StartApplicationSessionPosResponse{..}, _) -> do
          _ <- sessionMonitorReq actualProtocolVersion
          pure ((sessionID, actualSessionDuration), actualProtocolVersion)
        -- Resuming the old session was refused: start a new one and move
        -- the old session's monitors over to it.
        (Just Rs.StartApplicationSessionNegResponse, Just oldID) -> do
          startRsp' <- startReq Nothing
          case startRsp' of
            Just Rs.StartApplicationSessionPosResponse{..} -> do
              _ <- sessionMonitorReq actualProtocolVersion
              sendRequestAsyncRaw
                conn
                reconnect
                invoke
                Nothing
                Rq.TransferMonitorObjects
                  { fromSessionID = oldID
                  , toSessionID = sessionID
                  , acceptedProtocol = actualProtocolVersion
                  }
              pure ((sessionID, actualSessionDuration), actualProtocolVersion)
            _ -> throwIO ApplicationSessionFailed
        _ -> throwIO ApplicationSessionFailed
    -- Replace a broken connection. Steps: take the current session and
    -- connection, complete every pending synchronous request with
    -- 'Nothing', close the old connection (IOExceptions are only logged),
    -- connect again, and restart the DMCC session in a forked thread.
    -- The session slot stays empty until that thread finishes, so senders
    -- waiting on it block meanwhile. A failure in the forked thread is
    -- rethrown there only.
    -- NOTE(review): putTMVar would retry if a waiter's slot were already
    -- full; tryPutTMVar would be more defensive.
    reconnect :: (MonadUnliftIO m, MonadBaseControl IO m, MonadLoggerIO m, MonadCatch m) => m ()
    reconnect = do
      logWarnN "Attempting reconnection"
      (oldId, cl) <- atomically $ do
        (oldId, _) <- takeTMVar sess
        (_, _, cl) <- takeTMVar conn
        pure (oldId, cl)
      atomically $ do
        srs <- readTVar syncResponses
        mapM_ (`putTMVar` Nothing) $ IntMap.elems srs
        writeTVar syncResponses IntMap.empty
      handle
        (\(e :: IOException) -> logErrorN $ "Failed to close old connection: " <> tshow e) $
        liftIO cl
      connect >>= atomically . putTMVar conn
      logWarnN "Connection re-established"
      let
        shdl (Right ()) = pure ()
        shdl (Left e) = throwIO e
      void $ flip forkFinally shdl $ do
        (newSession, _) <- startDMCCSession (Just oldId)
        atomically $ putTMVar sess newSession
  -- The threads below are forked before the first connect; they wait on
  -- 'conn', 'sess' or 'msgChan' until those are populated.
  msgChan <- newTChanIO
  -- On a network read error: log it and reconnect. The next loop
  -- iteration reads from whatever connection is then in 'conn'.
  let readExHandler e = do
        logErrorN $ "Read error: " <> tshow e
        reconnect
  readThread <-
    forkIO $ forever $ do
      (istream, _, _) <- atomically $ readTMVar conn
      handleNetwork readExHandler $
        Raw.readResponse istream >>=
        atomically . writeTChan msgChan . first DMCCRsp
  agents <- newTVarIO Map.empty
  -- Dispatch each decoded response. First, complete the synchronous
  -- waiter registered under its invoke ID, if any. Then forward an event
  -- to the agent with the matching monitor ID, or a CSTA error to the
  -- agent that issued the request.
  procThread <- forkIO $ forever $ do
    (msg, invokeId) <- atomically $ readTChan msgChan
    ags <- readTVarIO agents
    case msg of
      DMCCRsp rsp -> do
        sync' <- atomically $ do
          srs <- readTVar syncResponses
          modifyTVar' syncResponses (IntMap.delete invokeId)
          pure $ IntMap.lookup invokeId srs
        case sync' of
          Just sync -> void $ atomically $ tryPutTMVar sync $ Just rsp
          Nothing -> pure ()
        ag' <- case rsp of
          Rs.EventResponse monId _ ->
            pure $ find (\a -> monId == monitorId a) $ Map.elems ags
          Rs.CSTAErrorCodeResponse _ -> do
            aid <- atomically $ do
              ars <- readTVar agentRequests
              modifyTVar' agentRequests (IntMap.delete invokeId)
              pure $ IntMap.lookup invokeId ars
            pure $ (`Map.lookup` ags) =<< aid
          _ -> pure Nothing
        case ag' of
          Just ag -> atomically $ writeTChan (rspChan ag) rsp
          Nothing -> pure ()
      _ -> pure ()
  -- Reset the application session timer, then sleep for half the session
  -- duration. threadDelay takes microseconds, so this assumes the duration
  -- is in seconds.
  pingThread <- forkIO $ forever $ do
    (sid, duration) <- atomically $ readTMVar sess
    sendRequestAsyncRaw conn reconnect invoke Nothing
      Rq.ResetApplicationSessionTimer
        { sessionId = sid
        , requestedSessionDuration = duration
        }
    threadDelay $ duration * 500 * 1000
  -- Positional construction: arguments follow the field order of
  -- 'DMCCHandle'.
  let h = DMCCHandle
        conn
        sess
        reconnect
        pingThread
        readThread
        procThread
        invoke
        syncResponses
        agentRequests
        sopts
  -- Initial connection and session; filling 'conn' and 'sess' releases the
  -- threads forked above.
  connect >>= atomically . putTMVar conn
  (newSession, actualProtocolVersion) <- startDMCCSession Nothing
  atomically $ putTMVar sess newSession
  wh <- case whUrl of
    Just url -> do
      mgr <- liftIO $ HTTP.newManager HTTP.defaultManagerSettings
      req <- HTTP.parseUrlThrow url
      pure $ Just (req, mgr)
    Nothing -> pure Nothing
  agLocks <- newTVarIO Set.empty
  pure $
    Session
      actualProtocolVersion
      h
      wh
      agents
      agLocks
-- | Shut a session down. Releases every known agent, asks the server to
-- stop the application session, kills the ping, processing and read
-- threads (in that order), and closes the connection.
stopSession :: (MonadUnliftIO m, MonadLoggerIO m, MonadBaseControl IO m, MonadCatch m)
            => Session -> m ()
stopSession session@Session{..} = do
  knownAgents <- readTVarIO agents
  (sid, _) <- atomically (readTMVar (dmccSession dmccHandle))
  forM_ (keys knownAgents) $ \aid ->
    releaseAgent (AgentHandle (aid, session))
  sendRequestAsync dmccHandle Nothing Rq.StopApplicationSession{sessionID = sid}
  forM_ [pingThread, procThread, readThread] $ \threadOf ->
    killThread (threadOf dmccHandle)
  (_, ostream, cleanup) <- atomically (readTMVar (connection dmccHandle))
  -- Signal end-of-stream on the output side, then run the close action.
  liftIO (write Nothing ostream >> cleanup)
-- | Send a request and wait for its response.
--
-- Blocks until a DMCC session is available. When an agent ID is given, the
-- request is recorded under its invoke ID, so that a CSTA error response
-- can be routed back to that agent. Returns 'Nothing' if the write fails or
-- the connection is reset before a response arrives.
sendRequestSync :: (MonadUnliftIO m, MonadLoggerIO m, MonadBaseControl IO m, MonadCatch m)
                => DMCCHandle
                -> Maybe AgentId
                -> Request
                -> m (Maybe Response)
sendRequestSync DMCCHandle{..} aid rq = do
  _ <- atomically (readTMVar dmccSession)
  sendRequestSyncRaw connection reconnect invokeId syncResponses tracking rq
  where
    tracking = (,) agentRequests <$> aid
-- | Write a request on the shared connection and wait for its response.
--
-- Allocates the next invoke ID (cycling through 0..9998). Registers a
-- response slot for it and, if given, records the issuing agent under it.
-- The connection is held exclusively while the request is written.
--
-- On a network error during the write:
--
-- * the connection is put back;
-- * the slot is completed with 'Nothing' and the bookkeeping entries are
--   removed;
-- * the reconnect action is run.
--
-- Returns the response delivered to the slot, or 'Nothing' if the write
-- failed or the connection was reset before a response arrived.
sendRequestSyncRaw :: (MonadUnliftIO m, MonadLoggerIO m, MonadBaseControl IO m, MonadCatch m)
                   => TMVar ConnectionData
                   -- ^ Shared connection slot.
                   -> m ()
                   -- ^ Reconnect action, run after a write error.
                   -> TVar Int
                   -- ^ Invoke ID counter.
                   -> TVar (IntMap.IntMap (TMVar (Maybe Response)))
                   -- ^ Response slots of pending synchronous requests.
                   -> Maybe (TVar (IntMap.IntMap AgentId), AgentId)
                   -- ^ Agent-request map and the agent issuing this request.
                   -> Request
                   -> m (Maybe Response)
sendRequestSyncRaw connection re invoke srs ar !rq = do
  (ix, var, c@(_, ostream, _)) <- atomically $ do
    modifyTVar' invoke $ (`mod` 9999) . (+1)
    ix <- readTVar invoke
    var <- newEmptyTMVar
    modifyTVar' srs (IntMap.insert ix var)
    case ar of
      Just (ars, a) -> modifyTVar' ars (IntMap.insert ix a)
      Nothing -> pure ()
    c <- takeTMVar connection
    pure (ix, var, c)
  let
    srHandler e = do
      logErrorN $ "Write error: " <> tshow e
      atomically $ do
        putTMVar connection c
        -- tryPutTMVar: should a response already occupy the slot, a
        -- blocking put would retry forever inside this transaction.
        void $ tryPutTMVar var Nothing
        modifyTVar' srs $ IntMap.delete ix
        case ar of
          Just (ars, _) -> modifyTVar' ars (IntMap.delete ix)
          Nothing -> pure ()
      re
  -- Release the connection only on success. On failure the handler has
  -- already put it back, and 're' may have replaced it. Putting 'c' again
  -- after the handler would block, or later reinstate the stale
  -- connection.
  handleNetwork srHandler $ do
    Raw.sendRequest ostream ix rq
    atomically $ putTMVar connection c
  atomically $ takeTMVar var
-- | Send a request without waiting for a response.
--
-- Blocks until a DMCC session is available. When an agent ID is given, the
-- request is recorded under its invoke ID, so that a CSTA error response
-- can be routed back to that agent.
sendRequestAsync :: (MonadUnliftIO m, MonadLoggerIO m, MonadBaseControl IO m, MonadCatch m)
                 => DMCCHandle
                 -> Maybe AgentId
                 -> Request
                 -> m ()
sendRequestAsync DMCCHandle{..} aid rq = do
  void $ atomically $ readTMVar dmccSession
  let tracking = fmap (agentRequests,) aid
  sendRequestAsyncRaw connection reconnect invokeId tracking rq
-- | Write a request on the shared connection without waiting for a
-- response.
--
-- Allocates the next invoke ID (cycling through 0..9998) and, if given,
-- records the issuing agent under it. The connection is held exclusively
-- while the request is written.
--
-- On a network error during the write:
--
-- * the connection is put back;
-- * the agent entry is removed;
-- * the reconnect action is run.
sendRequestAsyncRaw :: (MonadUnliftIO m, MonadLoggerIO m, MonadCatch m)
                    => TMVar ConnectionData
                    -- ^ Shared connection slot.
                    -> m ()
                    -- ^ Reconnect action, run after a write error.
                    -> TVar Int
                    -- ^ Invoke ID counter.
                    -> Maybe (TVar (IntMap.IntMap AgentId), AgentId)
                    -- ^ Agent-request map and the agent issuing this request.
                    -> Request
                    -> m ()
sendRequestAsyncRaw connection re invoke ar !rq = do
  (ix, c@(_, ostream, _)) <- atomically $ do
    modifyTVar' invoke $ (`mod` 9999) . (+1)
    ix <- readTVar invoke
    case ar of
      Just (ars, a) -> modifyTVar' ars (IntMap.insert ix a)
      Nothing -> pure ()
    c <- takeTMVar connection
    pure (ix, c)
  let
    srHandler e = do
      logErrorN $ "Write error: " <> tshow e
      atomically $ do
        putTMVar connection c
        case ar of
          Just (ars, _) -> modifyTVar' ars (IntMap.delete ix)
          Nothing -> pure ()
      re
  -- Release the connection only on success. On failure the handler has
  -- already put it back, and 're' may have replaced it. Putting 'c' again
  -- after the handler would block, or later reinstate the stale
  -- connection.
  handleNetwork srHandler $ do
    Raw.sendRequest ostream ix rq
    atomically $ putTMVar connection c
-- | Run an action, passing the network failures expected from a DMCC
-- connection to the given handler. These are short reads, 'IOException's,
-- and OpenSSL abrupt-termination or protocol errors. Any other exception
-- propagates unchanged.
handleNetwork :: forall a m. (MonadUnliftIO m, MonadLoggerIO m, MonadCatch m)
              => (forall e. (Exception e, Show e) => (e -> m a))
              -> m a
              -> m a
handleNetwork handler action =
  catches action
    -- Each entry instantiates the polymorphic handler at one exception type.
    [ Handler (handler :: ReadTooShortException -> m a)
    , Handler (handler :: IOException -> m a)
    , Handler (handler :: SSL.ConnectionAbruptlyTerminated -> m a)
    , Handler (handler :: SSL.ProtocolError -> m a)
    ]