{-# LANGUAGE BangPatterns #-}
{-# LANGUAGE CPP #-}
{-# LANGUAGE FlexibleContexts #-}
{-# LANGUAGE DeriveDataTypeable #-}
{-# LANGUAGE DoAndIfThenElse #-}
{-# LANGUAGE RankNTypes #-}
{-# LANGUAGE ScopedTypeVariables #-}
{-# LANGUAGE TupleSections #-}
module DMCC.Session
  ( Session (..)
  , startSession
  , stopSession
  , defaultSessionOptions
  , DMCCError (..)
  , DMCCHandle (..)
  , sendRequestSync
  , sendRequestAsync
  )
where
import           DMCC.Prelude
import           Control.Arrow ()
import qualified Data.Map.Strict as Map
import qualified Data.Set as Set
import qualified Data.IntMap.Strict as IntMap
import           Data.Text as T (empty)
import           Data.Typeable ()
import           System.IO.Streams ( InputStream
                                   , OutputStream
                                   , ReadTooShortException
                                   , write
                                   )
import qualified System.IO.Streams.SSL as SSLStreams
import qualified Network.HTTP.Client as HTTP
import           Network.Socket hiding (connect)
import           OpenSSL
import qualified OpenSSL.Session as SSL
import           DMCC.Types
import           DMCC.XML.Request (Request)
import qualified DMCC.XML.Request as Rq
import           DMCC.XML.Response (Response)
import qualified DMCC.XML.Response as Rs
import qualified DMCC.XML.Raw as Raw
import {-# SOURCE #-} DMCC.Agent
-- | An open connection to the server: the input stream of bytes read from
-- it, the output stream used to write to it, and an action that closes it.
type ConnectionData = (InputStream ByteString, OutputStream ByteString, IO ())
-- | Low-level state shared by all requests of a DMCC session.
--
-- NOTE: 'startSession' builds this record positionally, so the field order
-- is significant.
data DMCCHandle = DMCCHandle
  { connection :: TMVar ConnectionData
  -- ^ Current connection. A writer takes it for the duration of a write and
  -- puts it back afterwards; 'reconnect' replaces it with a fresh one.
  , dmccSession :: TMVar (Text, Int)
  -- ^ Current DMCC session ID and the duration returned by the server.
  -- Empty while a session is being (re-)established; requests wait on it.
  , reconnect :: forall m
               . (MonadUnliftIO m, MonadLoggerIO m, MonadBaseControl IO m, MonadCatch m)
              => m ()
  -- ^ Replace the connection with a new one and restart the DMCC session,
  -- trying to transfer monitors from the old session.
  , pingThread :: ThreadId
  -- ^ Thread periodically sending @ResetApplicationSessionTimer@.
  , readThread :: ThreadId
  -- ^ Thread parsing responses from the connection.
  , procThread :: ThreadId
  -- ^ Thread dispatching parsed responses to waiters and agents.
  , invokeId :: TVar Int
  -- ^ Counter used to allocate invoke IDs for outgoing requests.
  , syncResponses :: TVar (IntMap.IntMap (TMVar (Maybe Response)))
  -- ^ Pending synchronous requests keyed by invoke ID.
  , agentRequests :: TVar (IntMap.IntMap AgentId)
  -- ^ Agent on whose behalf a request (by invoke ID) was made; used to route
  -- CSTA error responses to that agent.
  , sessionOptions :: SessionOptions
  }
-- | A running DMCC application session as returned by 'startSession'.
data Session = Session
  { protocolVersion :: Text
  -- ^ Protocol version accepted by the server.
  , dmccHandle :: DMCCHandle
  , webHook :: Maybe (HTTP.Request, HTTP.Manager)
  -- ^ Request template parsed from the web hook URL, plus the HTTP manager
  -- to send it with. 'Nothing' when no URL was given.
  , agents :: TVar (Map.Map AgentId Agent)
  -- ^ Agents controlled through this session; the processor thread looks
  -- agents up here to deliver events and errors.
  , agentLocks :: TVar (Set.Set AgentId)
  -- ^ NOTE(review): presumably used by "DMCC.Agent" to serialise agent
  -- creation/release; not used in this module — confirm there.
  }
instance Show Session where
  -- Only the negotiated protocol version is rendered.
  show Session{protocolVersion = ver} =
    "Session{protocolVersion=" <> unpack ver <> "}"
-- | Messages passed from the reader thread to the processor thread.
--
-- NOTE(review): only 'DMCCRsp' is constructed in this module; 'Timeout' and
-- 'ReadError' appear unused here.
data LoopEvent
  = DMCCRsp Response
  -- ^ A response parsed from the connection.
  | Timeout
  | ReadError
  deriving Show
-- | Errors thrown by this module.
data DMCCError = ApplicationSessionFailed
                 -- ^ @StartApplicationSession@ did not succeed: the server
                 -- returned no positive response (also after retrying without
                 -- the old session ID, when one was given).
                 deriving (Show, Typeable)
instance Exception DMCCError
-- | Default session options, given positionally as 1, 120, 24 and 5.
--
-- NOTE(review): the meaning of each value depends on the field order of
-- 'SessionOptions' in "DMCC.Types" — confirm there before changing them.
defaultSessionOptions :: SessionOptions
defaultSessionOptions = SessionOptions 1 120 24 5
-- | Connect to a DMCC server over TLS, start an application session and
-- spawn the background threads that serve it.
--
-- Three threads are forked before the initial connection is made:
--
-- * a reader that parses responses from the current connection and passes
--   them to the processor, reconnecting on network errors;
-- * a processor that completes pending synchronous requests and routes
--   events and CSTA errors to the owning agent's 'rspChan';
-- * a pinger that keeps the session alive with
--   @ResetApplicationSessionTimer@.
--
-- Throws 'ApplicationSessionFailed' when the server refuses to start a
-- session. Rethrows the last network exception when every connection
-- attempt fails.
startSession :: (MonadUnliftIO m, MonadLoggerIO m, MonadBaseControl IO m, MonadCatch m)
             => (String, PortNumber)
             -- ^ Host and port of the DMCC server.
             -> Maybe FilePath
             -- ^ Optional CA certificate directory used to verify the server.
             -> Text
             -- ^ User name.
             -> Text
             -- ^ Password.
             -> Maybe String
             -- ^ Optional web hook URL, parsed into an HTTP request.
             -> SessionOptions
             -> m Session
startSession (host, port) caDir user pass whUrl sopts = do
  syncResponses <- newTVarIO IntMap.empty
  agentRequests <- newTVarIO IntMap.empty
  invoke <- newTVarIO 0
  -- Current connection. Empty while a writer holds it or while a
  -- reconnection is in progress.
  conn <- newEmptyTMVarIO
  -- Current session ID and duration. Empty until a session is
  -- (re-)established; user requests wait on it.
  sess <- newEmptyTMVarIO
  let
    -- Open a TLS connection. After a network error, retry up to
    -- 'connectionRetryAttempts' more times, waiting 'connectionRetryDelay'
    -- seconds between attempts.
    connect :: (MonadUnliftIO m, MonadBase IO m, MonadLoggerIO m, MonadCatch m) => m ConnectionData
    connect = connect1 (connectionRetryAttempts sopts)
      where
        connectExHandler
          :: (Exception e, Show e, MonadUnliftIO m, MonadLoggerIO m, MonadBase IO m, MonadCatch m)
          => Int -> e -> m ConnectionData
        connectExHandler attempts e = do
          logErrorN $ "Connection failed: " <> tshow e
          if attempts > 0
             then threadDelay (connectionRetryDelay sopts * 1000000) >> connect1 (attempts - 1)
             else throwIO e
        connect1 attempts =
          handleNetwork (connectExHandler attempts) $
          liftIO $ withOpenSSL $ do
              sslCtx <- SSL.context
              SSL.contextSetDefaultCiphers sslCtx
              SSL.contextSetVerificationMode sslCtx $
                SSL.VerifyPeer True True Nothing
              maybe (pure ()) (SSL.contextSetCADirectory sslCtx) caDir
              (is, os, ssl) <- SSLStreams.connect sslCtx host port
              -- Closing: unidirectional TLS shutdown, then close the socket.
              let cl = do
                    SSL.shutdown ssl SSL.Unidirectional
                    maybe (pure ()) close $ SSL.sslSocket ssl
              pure (is, os, cl)
    -- Start an application session and monitor it. When an old session ID
    -- is given, first try to resume that session. If the server refuses,
    -- start a fresh session and ask the server to transfer monitors from
    -- the old one.
    -- Returns ((session ID, session duration), accepted protocol version).
    startDMCCSession :: (MonadUnliftIO m, MonadLoggerIO m, MonadBaseControl IO m, MonadCatch m)
                     => Maybe Text
                     -> m ((Text, Int), Text)
    startDMCCSession old = do
      let
        sendReq =
          sendRequestSyncRaw
          conn
          reconnect
          invoke
          syncResponses
          Nothing
        startReq sid =
          sendReq
          Rq.StartApplicationSession
          { applicationId = ""
          , requestedProtocolVersion = Rq.DMCC_6_2
          , userName = user
          , password = pass
          , sessionCleanupDelay = sessionDuration sopts
          , sessionID = fromMaybe T.empty sid
          , requestedSessionDuration = sessionDuration sopts
          }
        sessionMonitorReq proto =
          sendReq
          Rq.MonitorStart
          { acceptedProtocol = proto
          , monitorRq = Rq.Session
          }
      startRsp <- startReq old
      case (startRsp, old) of
        (Just Rs.StartApplicationSessionPosResponse{..}, _) -> do
          _ <- sessionMonitorReq actualProtocolVersion
          pure ((sessionID, actualSessionDuration), actualProtocolVersion)
        (Just Rs.StartApplicationSessionNegResponse, Just oldID) -> do
          -- Resuming the old session was refused: start a new one.
          startRsp' <- startReq Nothing
          case startRsp' of
            Just Rs.StartApplicationSessionPosResponse{..} -> do
              _ <- sessionMonitorReq actualProtocolVersion
              -- Carry monitors over from the old session.
              sendRequestAsyncRaw
                conn
                reconnect
                invoke
                Nothing
                Rq.TransferMonitorObjects
                { fromSessionID = oldID
                , toSessionID = sessionID
                , acceptedProtocol = actualProtocolVersion
                }
              pure ((sessionID, actualSessionDuration), actualProtocolVersion)
            _ -> throwIO ApplicationSessionFailed
        _ -> throwIO ApplicationSessionFailed
    -- Replace the connection and restart the DMCC session.
    reconnect :: (MonadUnliftIO m, MonadBaseControl IO m, MonadLoggerIO m, MonadCatch m) => m ()
    reconnect = do
      logWarnN "Attempting reconnection"
      -- Take the session and the connection together, so requests wait
      -- until both are restored.
      (oldId, cl) <- atomically $ do
        (oldId, _) <- takeTMVar sess
        (_, _, cl) <- takeTMVar conn
        pure (oldId, cl)
      -- Wake every waiting synchronous request with Nothing.
      atomically $ do
        srs <- readTVar syncResponses
        mapM_ (`putTMVar` Nothing) $ IntMap.elems srs
        writeTVar syncResponses IntMap.empty
      -- Closing a broken link can fail with TLS errors (e.g. SSL.shutdown
      -- raising ProtocolError), not only IOException. Log every network
      -- error instead of letting it escape, which would kill the caller
      -- (e.g. the reader thread) and leave conn/sess empty forever.
      handleNetwork
        (\e -> logErrorN $ "Failed to close old connection: " <> tshow e) $
        liftIO cl
      connect >>= atomically . putTMVar conn
      logWarnN "Connection re-established"
      let
        shdl (Right ()) = pure ()
        shdl (Left e) = throwIO e
      -- The session restart runs in its own thread. reconnect may be called
      -- from the reader thread, and the synchronous requests made by
      -- startDMCCSession need the reader thread to deliver their responses.
      void $ flip forkFinally shdl $ do
        (newSession, _) <- startDMCCSession (Just oldId)
        atomically $ putTMVar sess newSession
  msgChan <- newTChanIO
  -- On a read error, reconnect; the loop then reads from the new connection.
  let readExHandler e = do
        logErrorN $ "Read error: " <> tshow e
        reconnect
  readThread <-
    forkIO $ forever $ do
      (istream, _, _) <- atomically $ readTMVar conn
      handleNetwork readExHandler $
        Raw.readResponse istream >>=
        atomically . writeTChan msgChan . first DMCCRsp
  agents <- newTVarIO Map.empty
  -- Dispatch every response read from the connection.
  procThread <- forkIO $ forever $ do
    (msg, invokeId) <- atomically $ readTChan msgChan
    ags <- readTVarIO agents
    case msg of
      DMCCRsp rsp -> do
        -- Complete a pending synchronous request with this invoke ID, if any.
        sync' <- atomically $ do
          srs <- readTVar syncResponses
          modifyTVar' syncResponses (IntMap.delete invokeId)
          pure $ IntMap.lookup invokeId srs
        case sync' of
          Just sync -> void $ atomically $ tryPutTMVar sync $ Just rsp
          Nothing -> pure ()
        -- Route events by monitor ID, and CSTA errors by the agent that made
        -- the request, to that agent's response channel.
        ag' <- case rsp of
                 Rs.EventResponse monId _ ->
                   pure $ find (\a -> monId == monitorId a) $ Map.elems ags
                 Rs.CSTAErrorCodeResponse _ -> do
                   aid <- atomically $ do
                     ars <- readTVar agentRequests
                     modifyTVar' agentRequests (IntMap.delete invokeId)
                     pure $ IntMap.lookup invokeId ars
                   pure $ (`Map.lookup` ags) =<< aid
                 _ -> pure Nothing
        case ag' of
          Just ag -> atomically $ writeTChan (rspChan ag) rsp
          Nothing -> pure ()
      _ -> pure ()
  -- Keep the session alive: reset its timer, then sleep for
  -- duration * 0.5 s (half the session duration if it is in seconds).
  pingThread <- forkIO $ forever $ do
    (sid, duration) <- atomically $ readTMVar sess
    sendRequestAsyncRaw conn reconnect invoke Nothing
      Rq.ResetApplicationSessionTimer
      { sessionId = sid
      , requestedSessionDuration = duration
      }
    threadDelay $ duration * 500 * 1000
  let h = DMCCHandle
          conn
          sess
          reconnect
          pingThread
          readThread
          procThread
          invoke
          syncResponses
          agentRequests
          sopts
  -- Initial connection and session; the threads above wait for these.
  connect >>= atomically . putTMVar conn
  (newSession, actualProtocolVersion) <- startDMCCSession Nothing
  atomically $ putTMVar sess newSession
  wh <- case whUrl of
          Just url -> do
            mgr <- liftIO $ HTTP.newManager HTTP.defaultManagerSettings
            req <- HTTP.parseUrlThrow url
            pure $ Just (req, mgr)
          Nothing -> pure Nothing
  agLocks <- newTVarIO Set.empty
  pure $
    Session
    actualProtocolVersion
    h
    wh
    agents
    agLocks
-- | Shut a session down: release every agent, stop the application session
-- on the server, kill the background threads and close the connection.
stopSession :: (MonadUnliftIO m, MonadLoggerIO m, MonadBaseControl IO m, MonadCatch m)
            => Session -> m ()
stopSession session@Session{agents = agentsVar, dmccHandle = h} = do
  currentAgents <- readTVarIO agentsVar
  (sid, _) <- atomically $ readTMVar (dmccSession h)
  forM_ (keys currentAgents) $ \aid ->
    releaseAgent (AgentHandle (aid, session))
  sendRequestAsync h Nothing Rq.StopApplicationSession{sessionID = sid}
  -- Same order as before: pinger, processor, reader.
  mapM_ (killThread . ($ h)) [pingThread, procThread, readThread]
  (_, out, closeConn) <- atomically $ readTMVar (connection h)
  liftIO $ write Nothing out >> closeConn
-- | Send a request and wait for its response.
--
-- Waits until a DMCC session is established. Returns 'Nothing' if the
-- request could not be written, or if a reconnection happened before the
-- response arrived.
sendRequestSync :: (MonadUnliftIO m, MonadLoggerIO m, MonadBaseControl IO m, MonadCatch m)
                => DMCCHandle
                -> Maybe AgentId
                -- ^ Agent on whose behalf the request is made. It is recorded
                -- so that a CSTA error response is routed to that agent.
                -> Request
                -> m (Maybe Response)
sendRequestSync h aid rq = do
  _ <- atomically $ readTMVar $ dmccSession h
  sendRequestSyncRaw (connection h) (reconnect h) (invokeId h)
    (syncResponses h) (fmap (agentRequests h, ) aid) rq
-- | Write a request and wait for the processor thread to deliver its
-- response.
--
-- The request is registered in the pending-request table under a fresh
-- invoke ID before it is written. On a network error during the write, the
-- registration is undone, the caller gets 'Nothing' and the reconnection
-- action is run.
sendRequestSyncRaw :: (MonadUnliftIO m, MonadLoggerIO m, MonadBaseControl IO m, MonadCatch m)
                   => TMVar ConnectionData
                   -- ^ Connection, held exclusively for the duration of the write.
                   -> m ()
                   -- ^ Action run after a network error on write (reconnect).
                   -> TVar Int
                   -- ^ Invoke ID counter.
                   -> TVar (IntMap.IntMap (TMVar (Maybe Response)))
                   -- ^ Pending synchronous requests keyed by invoke ID.
                   -> Maybe (TVar (IntMap.IntMap AgentId), AgentId)
                   -- ^ When present, record the agent for this invoke ID in the
                   -- given table.
                   -> Request
                   -> m (Maybe Response)
sendRequestSyncRaw connection re invoke srs ar !rq = do
  (ix, var, c@(_, ostream, _)) <- atomically $ do
    modifyTVar' invoke $ (`mod` 9999) . (+1)
    ix <- readTVar invoke
    var <- newEmptyTMVar
    modifyTVar' srs (IntMap.insert ix var)
    case ar of
      Just (ars, a) -> modifyTVar' ars (IntMap.insert ix a)
      Nothing -> pure ()
    c <- takeTMVar connection
    pure (ix, var, c)
  let
    -- Returns False: the connection has already been handed back (and
    -- possibly replaced by re), so the caller must not put it back again.
    srHandler e = do
      logErrorN $ "Write error: " <> tshow e
      atomically $ do
        putTMVar connection c
        putTMVar var Nothing
        modifyTVar' srs $ IntMap.delete ix
        case ar of
          Just (ars, _) -> modifyTVar' ars (IntMap.delete ix)
          Nothing -> pure ()
      re
      pure False
  written <- handleNetwork srHandler $ Raw.sendRequest ostream ix rq >> pure True
  -- Return the connection only on success. Otherwise a second putTMVar
  -- would block on the already-full TMVar, or later reinstate the stale
  -- connection that reconnect has just closed.
  if written
    then atomically $ putTMVar connection c
    else pure ()
  atomically $ takeTMVar var
-- | Send a request without waiting for a response.
--
-- Waits until a DMCC session is established before writing.
sendRequestAsync :: (MonadUnliftIO m, MonadLoggerIO m, MonadBaseControl IO m, MonadCatch m)
                 => DMCCHandle
                 -> Maybe AgentId
                 -- ^ Agent on whose behalf the request is made. It is recorded
                 -- so that a CSTA error response is routed to that agent.
                 -> Request
                 -> m ()
sendRequestAsync h aid rq = do
  void $ atomically $ readTMVar (dmccSession h)
  let agentEntry = fmap (agentRequests h, ) aid
  sendRequestAsyncRaw (connection h) (reconnect h) (invokeId h) agentEntry rq
-- | Write a request under a fresh invoke ID without waiting for a response.
--
-- On a network error during the write, the agent registration (if any) is
-- undone and the reconnection action is run.
sendRequestAsyncRaw :: (MonadUnliftIO m, MonadLoggerIO m, MonadCatch m)
                    => TMVar ConnectionData
                    -- ^ Connection, held exclusively for the duration of the write.
                    -> m ()
                    -- ^ Action run after a network error on write (reconnect).
                    -> TVar Int
                    -- ^ Invoke ID counter.
                    -> Maybe (TVar (IntMap.IntMap AgentId), AgentId)
                    -- ^ When present, record the agent for this invoke ID in the
                    -- given table.
                    -> Request
                    -> m ()
sendRequestAsyncRaw connection re invoke ar !rq = do
  (ix, c@(_, ostream, _)) <- atomically $ do
    modifyTVar' invoke $ (`mod` 9999) . (+1)
    ix <- readTVar invoke
    case ar of
      Just (ars, a) -> modifyTVar' ars (IntMap.insert ix a)
      Nothing -> pure ()
    c <- takeTMVar connection
    pure (ix, c)
  let
    -- Returns False: the connection has already been handed back (and
    -- possibly replaced by re), so the caller must not put it back again.
    srHandler e = do
      logErrorN $ "Write error: " <> tshow e
      atomically $ do
        putTMVar connection c
        case ar of
          Just (ars, _) -> modifyTVar' ars (IntMap.delete ix)
          Nothing -> pure ()
      re
      pure False
  written <- handleNetwork srHandler $ Raw.sendRequest ostream ix rq >> pure True
  -- Return the connection only on success; see sendRequestSyncRaw.
  if written
    then atomically $ putTMVar connection c
    else pure ()
-- | Run an action, passing network-related exceptions to a handler: short
-- reads, IO errors, abrupt TLS termination and TLS protocol errors. Other
-- exceptions propagate unchanged.
handleNetwork :: forall a m. (MonadUnliftIO m, MonadLoggerIO m, MonadCatch m)
              => (forall e. (Exception e, Show e) => (e -> m a))
              -- ^ Handler, invoked with whichever network exception occurred.
              -> m a
              -- ^ Action to run.
              -> m a
handleNetwork handler action = catches action networkHandlers
  where
    -- Every exception type treated as a transport failure goes to the same
    -- handler.
    networkHandlers =
      [ Handler $ \(ex :: ReadTooShortException) -> handler ex
      , Handler $ \(ex :: IOException) -> handler ex
      , Handler $ \(ex :: SSL.ConnectionAbruptlyTerminated) -> handler ex
      , Handler $ \(ex :: SSL.ProtocolError) -> handler ex
      ]