{-# LANGUAGE BlockArguments #-}
{-# LANGUAGE DeriveAnyClass #-}
{-# LANGUAGE LambdaCase #-}
{-# LANGUAGE NumericUnderscores #-}
{-# LANGUAGE ScopedTypeVariables #-}

module Hercules.Agent.Socket
  ( withReliableSocket,
    checkVersion',
    Socket (..),
    syncIO,
    SocketConfig (..),
  )
where

import Control.Concurrent.STM.TBQueue (TBQueue, flushTBQueue, newTBQueue, writeTBQueue)
import Control.Concurrent.STM.TChan (TChan, writeTChan)
import Control.Concurrent.STM.TVar (TVar, modifyTVar, readTVar, writeTVar)
import Control.Monad.IO.Unlift
import qualified Data.Aeson as A
import Data.DList (DList, fromList)
import Data.List (dropWhileEnd, splitAt)
import Data.Semigroup
import Data.Time (NominalDiffTime, addUTCTime, diffUTCTime, getCurrentTime)
import Data.Time.Extras
import Hercules.API.Agent.LifeCycle.ServiceInfo (ServiceInfo)
import qualified Hercules.API.Agent.LifeCycle.ServiceInfo as ServiceInfo
import Hercules.API.Agent.Socket.Frame (Frame)
import qualified Hercules.API.Agent.Socket.Frame as Frame
import Hercules.Agent.STM (atomically, newTChanIO, newTVarIO)
import Katip (KatipContext, Severity (..), katipAddContext, katipAddNamespace, logLocM, sl)
import Network.URI (URI, uriAuthority, uriPath, uriPort, uriQuery, uriRegName, uriScheme)
import Network.WebSockets (Connection, runClientWith)
import qualified Network.WebSockets as WS
import Protolude hiding (atomically, handle, race, race_)
import qualified UnliftIO
import UnliftIO.Async (race, race_)
import UnliftIO.Exception (handle)
import UnliftIO.STM (readTVarIO)
import UnliftIO.Timeout (timeout)
import Wuss (runSecureClientWith)

-- | Handle to a reliable connection, as constructed by 'withReliableSocket'.
--
-- @r@ is the type of payloads received from the service; @w@ is the type of
-- payloads sent to it.
data Socket r w = Socket
  { -- | Assigns the next sequence number to a payload and enqueues it for
    -- sending. Retries (blocks) while the outgoing queue is full.
    write :: w -> STM (),
    -- | Payloads received from the service.
    serviceChan :: TChan r,
    -- | The outer transaction captures the number of the latest payload
    -- written so far; the transaction it returns retries until the service
    -- has acknowledged that number.
    sync :: STM (STM ())
  }

-- | Blocks until the service has acknowledged every payload that was given a
-- sequence number (via 'write') before this call.
syncIO :: Socket r w -> IO ()
syncIO socket = do
  -- First transaction: snapshot the latest issued number.
  waitForAcks <- atomically (sync socket)
  -- Second transaction: retry until that number has been acknowledged.
  atomically waitForAcks

-- | Parameters to start 'withReliableSocket'.
data SocketConfig ap sp m = SocketConfig
  { -- | Produces the hello payload. It is sent as a 'Frame.Oob' frame during
    -- every handshake, after the service's first frame has been accepted.
    makeHello :: m ap,
    -- | Validates the payload of the first ('Frame.Oob') frame the service
    -- sends on each connection. A @Left@ message is sent back to the service
    -- as a 'Frame.Exception' and aborts that connection attempt.
    checkVersion :: sp -> m (Either Text ()),
    -- | Base URL of the service. Its scheme must be @http:@ or @https:@, and
    -- it must contain a host.
    baseURL :: URI,
    -- | Path appended to the path of 'baseURL', joined by a single slash.
    path :: Text,
    -- | Token sent as @Authorization: Bearer \<token\>@.
    token :: ByteString
  }

-- | Lowest service version accepted by 'checkVersion''. Versions are
-- compared lexicographically as @(major, minor)@ tuples.
requiredServiceVersion :: (Int, Int)
requiredServiceVersion =
  ( 2, -- major
    0 -- minor
  )

-- | Sixty seconds. It serves two purposes:
--
-- * how long a receive may wait for the next frame, and
-- * how long after sending a message its acknowledgement may take before
--   the connection is dropped and re-established.
ackTimeout :: NominalDiffTime
ackTimeout = 60

-- | Runs @f@ with a 'Socket' whose messages are delivered to the service
-- reliably.
--
-- Payloads passed to 'write' are numbered consecutively from 0 and put on an
-- outgoing queue bounded to 100 frames. The connection itself is managed by
-- 'runReliableSocket', which is raced against @f@. That loop never returns
-- normally, so the result is @f@'s result unless an exception propagates.
withReliableSocket :: (A.FromJSON sp, A.ToJSON ap, MonadIO m, MonadUnliftIO m, KatipContext m) => SocketConfig ap sp m -> (Socket sp ap -> m a) -> m a
withReliableSocket socketConfig f = do
  outgoing <- atomically $ newTBQueue 100
  nextSeqNo <- newTVarIO 0
  incoming <- newTChanIO
  highestAcked <- newTVarIO (-1)
  let -- Reserve the next sequence number and wrap the payload in a message frame.
      numbered payload = do
        seqNo <- readTVar nextSeqNo
        writeTVar nextSeqNo (seqNo + 1)
        pure (Frame.Msg {n = seqNo, p = payload})
      enqueue payload = numbered payload >>= writeTBQueue outgoing
      -- Snapshot the last number handed out; the inner transaction retries
      -- until the service has acknowledged it.
      awaitAcks = do
        lastIssued <- (\next -> next - 1) <$> readTVar nextSeqNo
        pure do
          acked <- readTVar highestAcked
          guard (acked >= lastIssued)
      socket = Socket {write = enqueue, serviceChan = incoming, sync = awaitAcks}
  either identity identity
    <$> race (runReliableSocket socketConfig outgoing incoming highestAcked) (f socket)

-- | Accepts a 'ServiceInfo' whose version is at least 'requiredServiceVersion'.
--
-- Its type matches the 'checkVersion' field of 'SocketConfig' when the
-- service's first frame carries a 'ServiceInfo'. On mismatch, the returned
-- message names both the required and the reported version.
checkVersion' :: Applicative m => ServiceInfo -> m (Either Text ())
checkVersion' si
  | reported < requiredServiceVersion =
      pure . Left $
        "Expected service version "
          <> show requiredServiceVersion
          <> " or later, but the service reports "
          <> show reported
  | otherwise = pure (Right ())
  where
    reported = ServiceInfo.version si

-- | Keeps a connection to the service open forever, reconnecting whenever a
-- connection attempt or an established connection fails.
--
-- * @socketConfig@: URL, credentials and handshake callbacks.
-- * @writeQueue@: outgoing frames.
--   * Frames other than 'Frame.Oob' are remembered as unacknowledged.
--   * 'Frame.Msg' frames stay remembered until the service acknowledges
--     their number; the others are dropped at the next acknowledgement.
--   * Everything still remembered is re-sent after each handshake.
-- * @serviceMessageChan@: receives the payload of each incoming 'Frame.Msg'
--   whose number exceeds the last one delivered, and of each 'Frame.Oob'
--   received after the handshake.
-- * @highestAcked@: highest message number acknowledged by the service.
--
-- Each connection runs a reader, a writer and an ack watchdog concurrently.
-- The watchdog ends the connection when an acknowledgement is more than
-- 'ackTimeout' late. Exceptions escaping a connection are logged, followed by
-- a 10 second pause before the next attempt.
runReliableSocket :: forall ap sp m. (A.ToJSON ap, A.FromJSON sp, MonadUnliftIO m, KatipContext m) => SocketConfig ap sp m -> TBQueue (Frame ap ap) -> TChan sp -> TVar Integer -> forall a. m a
runReliableSocket socketConfig writeQueue serviceMessageChan highestAcked = katipAddNamespace "Socket" do
  -- (number, send time) of the highest message number in the last sent batch.
  expectedAck <- newTVarIO Nothing
  (unacked :: TVar (DList (Frame Void ap))) <- newTVarIO mempty
  -- Number of the last service message forwarded to serviceMessageChan.
  (lastServiceN :: TVar Integer) <- newTVarIO (-1)
  let katipExceptionContext e =
        katipAddContext (sl "message" (displayException e))
          . katipAddContext (sl "exception" (show e :: [Char]))
      -- Logs why a connection ended, then pauses before the next attempt.
      logWarningPause :: SomeException -> m ()
      logWarningPause e = do
        let (severity, message)
              | Just WS.ConnectionClosed <- fromException e =
                  (InfoS, "Socket closed. Reconnecting.")
              | Just (WS.ParseException "not enough bytes") <- fromException e =
                  (InfoS, "Socket closed prematurely. Reconnecting.")
              | otherwise =
                  (WarningS, "Recovering from exception in socket handler. Reconnecting.")
        katipExceptionContext e $ logLocM severity message
        liftIO $ threadDelay 10_000_000
      setExpectedAckForMsgs :: [Frame ap ap] -> m ()
      setExpectedAckForMsgs msgs =
        msgs
          & foldMap (\case Frame.Msg {n = n} -> Just $ Max n; _ -> mempty)
          & traverse_ (\(Max n) -> setExpectedAck n)
      -- Sends frames ordered by message number (frames without a number
      -- first), in batches of at most 100.
      send :: Connection -> [Frame ap ap] -> m ()
      send conn = sendSorted . sortBy (compare `on` msgN)
        where
          sendRaw :: [Frame ap ap] -> m ()
          sendRaw msgs = do
            liftIO $ WS.sendDataMessages conn (WS.Binary . A.encode <$> msgs)
            setExpectedAckForMsgs msgs
          sendSorted :: [Frame ap ap] -> m ()
          sendSorted [] = pass
          sendSorted msgs = do
            let (msgsNow, msgsLater) = Data.List.splitAt 100 msgs
            sendRaw msgsNow
            sendSorted msgsLater
      -- Receives and decodes one frame; fails if none arrives within 'ackTimeout'.
      recv :: Connection -> m (Frame sp sp)
      recv conn =
        withTimeout ackTimeout (FatalError "Hercules.Agent.Socket.recv timed out") $
          liftIO (A.eitherDecode <$> WS.receiveData conn) >>= \case
            Left e -> liftIO $ throwIO (FatalError $ "Error decoding service message: " <> toS e)
            Right r -> pure r
      -- Version check, hello, initial ack, then re-send of remembered frames.
      -- Adds the "Handshake" log namespace itself.
      handshake :: Connection -> IO () -> m ()
      handshake conn removeHandshakeTimeout = katipAddNamespace "Handshake" do
        siMsg <- recv conn
        case siMsg of
          Frame.Oob {o = o'} ->
            checkVersion socketConfig o' >>= \case
              Left e -> do
                send conn [Frame.Exception e]
                throwIO $ FatalError "It looks like you're running a development version of hercules-ci-agent that is not yet supported on hercules-ci.com. Please use the stable branch or a tag."
              Right _ -> pass
          _ -> throwIO $ FatalError "Unexpected message. This is either a bug or you might need to update your agent."
        hello <- makeHello socketConfig
        send conn [Frame.Oob hello]
        ackMsg <- recv conn
        case ackMsg of
          Frame.Ack {n = n} -> cleanAcknowledged n
          _ -> throwIO $ FatalError "Expected acknowledgement. This is either a bug or you might need to update your agent."
        liftIO removeHandshakeTimeout
        sendUnacked conn
      sendUnacked :: Connection -> m ()
      sendUnacked conn = do
        unackedNow <- readTVarIO unacked
        send conn $ Frame.mapOob absurd <$> toList unackedNow
      cleanAcknowledged newAck = atomically do
        unacked0 <- readTVar unacked
        writeTVar unacked $
          unacked0
            & toList
            & filter
              ( \case
                  Frame.Msg {n = n} -> n > newAck
                  Frame.Oob x -> absurd x
                  Frame.Ack {} -> False
                  Frame.Exception {} -> False
              )
            & fromList
        previousHighest <- readTVar highestAcked
        -- Force the maximum now. A lazy modifyTVar would accumulate a chain
        -- of 'max' thunks until something reads highestAcked.
        writeTVar highestAcked $! max newAck previousHighest
      -- TODO (performance) IntMap?

      readThread conn = katipAddNamespace "Reader" do
        forever do
          msg <- recv conn
          case msg of
            Frame.Msg {p = pl, n = n} -> atomically do
              lastN <- readTVar lastServiceN
              -- Forward only messages newer than the last one forwarded
              -- (duplicates are skipped), but acknowledge every one.
              when (n > lastN) do
                writeTChan serviceMessageChan pl
                writeTVar lastServiceN n
              writeTBQueue writeQueue (Frame.Ack {n = n})
            Frame.Ack {n = n} ->
              cleanAcknowledged n
            Frame.Oob o -> atomically do
              writeTChan serviceMessageChan o
            Frame.Exception e -> katipAddContext (sl "message" e) $ logLocM WarningS "Service exception"
      writeThread conn = katipAddNamespace "Writer" do
        forever do
          msgs <- atomically do
            -- TODO: make unacked bounded
            allMessages <- flushTBQueue writeQueue
            when (null allMessages) retry
            modifyTVar unacked (<> fromList (allMessages >>= Frame.removeOob))
            pure allMessages
          send conn msgs
      setExpectedAck :: Integer -> m ()
      setExpectedAck n = do
        now <- liftIO getCurrentTime
        atomically do
          writeTVar expectedAck $ Just (n, now)
      noAckCleanupThread = noAckCleanupThread' (-1)
      noAckCleanupThread' confirmedLastTime = do
        (expectedN, sendTime) <- atomically do
          readTVar expectedAck
            >>= \case
              Nothing -> retry
              Just (expectedN, _) | expectedN <= confirmedLastTime -> retry
              Just a -> pure a
        let expectedArrival = ackTimeout `addUTCTime` sendTime
        liftIO do
          now <- getCurrentTime
          let waitTime = expectedArrival `diffUTCTime` now
          delayNominalDiffTime waitTime
        currentHighestAck <- atomically do
          readTVar highestAcked
        if expectedN > currentHighestAck
          then do
            katipAddContext (sl "expectedAck" expectedN <> sl "highestAck" currentHighestAck) do
              logLocM Katip.DebugS "Did not receive ack in time. Will reconnect."
            -- terminate other threads via race_
            pass
          else noAckCleanupThread' expectedN
  forever do
    handle logWarningPause $
      withCancelableTimeout handshakeTimeoutMicroseconds HandshakeTimeout \removeTimeout ->
        withConnection' socketConfig \conn -> do
          -- No extra namespace here: 'handshake' adds "Handshake" itself.
          handshake conn removeTimeout
          readThread conn `race_` writeThread conn `race_` noAckCleanupThread

-- | Time allowed for a handshake (30 seconds), in microseconds as expected by
-- 'withCancelableTimeout'.
handshakeTimeoutMicroseconds :: Int
handshakeTimeoutMicroseconds = 30 * 1000000

-- | Thrown to the connecting thread by 'withCancelableTimeout' when a
-- handshake does not finish within 'handshakeTimeoutMicroseconds'.
data HandshakeTimeout = HandshakeTimeout
  deriving (Show)

-- All methods use their defaults, exactly as DeriveAnyClass would generate.
instance Exception HandshakeTimeout

-- | Runs @cont@ while a timer thread counts down @delay@ microseconds. If the
-- timer expires first, @exc@ is thrown to the thread that called
-- 'withCancelableTimeout'.
--
-- @cont@ receives an action that cancels the timer. The timer runs under
-- 'UnliftIO.withAsync', so it is also cancelled when @cont@ returns or throws.
withCancelableTimeout :: (Exception e, MonadUnliftIO m) => Int -> e -> (IO () -> m a) -> m a
withCancelableTimeout delay exc cont = do
  caller <- liftIO myThreadId
  let alarm = liftIO $ threadDelay delay >> throwTo caller exc
  UnliftIO.withAsync alarm \timer -> cont (cancel timer)

-- | Sequence number of a 'Frame.Msg'. Every other frame kind yields
-- 'Nothing', including 'Frame.Ack', whose @n@ field is ignored here.
msgN :: Frame o a -> Maybe Integer
msgN = \case
  Frame.Msg {n = seqNo} -> Just seqNo
  _ -> Nothing

-- | Opens a WebSocket to the path of 'baseURL' joined with 'path' (query
-- string kept) and runs the callback on the connection.
--
-- Details:
--
-- * The request carries @Authorization: Bearer \<token\>@.
-- * @https:@ uses TLS; @http:@ does not.
-- * The port is the one in the URL, or 80 / 443 by default.
-- * Panics on any other scheme, or when the URL has no authority.
withConnection' :: (MonadUnliftIO m) => SocketConfig any0 any1 m -> (Connection -> m a) -> m a
withConnection' socketConfig useConnection =
  withRunInIO \runInIO -> connect \conn -> runInIO (useConnection conn)
  where
    base = baseURL socketConfig
    url = base {uriPath = uriPath base `slash` toS (path socketConfig)}
    authority = uriAuthority url
    host = maybe (panic "Hercules.Agent.Socket: url has no regname") uriRegName authority
    (secure, defaultPort)
      | uriScheme url == "http:" = (False, 80)
      | uriScheme url == "https:" = (True, 443)
      | otherwise = panic "Hercules.Agent.Socket: invalid uri scheme"
    -- uriPort includes the leading ':' (e.g. ":8080").
    port = fromMaybe defaultPort (authority >>= readMaybe . dropWhile (== ':') . uriPort)
    resource = uriPath url <> uriQuery url
    opts = WS.defaultConnectionOptions
    headers = [("Authorization", "Bearer " <> token socketConfig)]
    connect
      | secure = runSecureClientWith host (fromIntegral port) resource opts headers
      | otherwise = runClientWith host port resource opts headers

-- | Joins two path fragments with exactly one @/@ between them. Trailing
-- slashes of the left side and leading slashes of the right side are dropped
-- first.
slash :: [Char] -> [Char] -> [Char]
slash left right = trimEnd left <> "/" <> trimStart right
  where
    trimEnd = dropWhileEnd isSlash
    trimStart = dropWhile isSlash
    isSlash c = c == '/'

-- | Runs @action@, throwing @e@ if it does not finish within @t@.
--
-- A non-positive @t@ throws @e@ immediately, without running @action@.
-- @t@ is rounded up to whole microseconds.
withTimeout :: (Exception e, MonadIO m, MonadUnliftIO m) => NominalDiffTime -> e -> m a -> m a
withTimeout t e action
  | t <= 0 = throwIO e
  | otherwise = timeout micros action >>= maybe (throwIO e) pure
  where
    micros = ceiling (t * 1_000_000)