{-# LANGUAGE ConstraintKinds #-}
{-# LANGUAGE DeriveGeneric #-}
{-# LANGUAGE DerivingStrategies #-}
{-# LANGUAGE FlexibleContexts #-}
{-# LANGUAGE LambdaCase #-}
{-# LANGUAGE NamedFieldPuns #-}
{-# LANGUAGE OverloadedStrings #-}
{-# LANGUAGE RankNTypes #-}
{-# LANGUAGE ScopedTypeVariables #-}
{-# LANGUAGE StandaloneDeriving #-}
{-# LANGUAGE UndecidableInstances #-}

{- | Description: Manage connections to other peers. -}
module OM.Legion.Connection (
  JoinResponse(..),
  RuntimeState(..),
  EventConstraints,
  disconnect,
  peerMessagePort,
  sendPeer,
) where

import Control.Concurrent.Async (async)
import Control.Exception.Safe (MonadCatch, finally, tryAny)
import Control.Monad (void)
import Control.Monad.IO.Class (MonadIO(liftIO))
import Control.Monad.Logger.CallStack (LoggingT(runLoggingT),
  MonadLoggerIO(askLoggerIO), MonadLogger, logDebug, logInfo)
import Control.Monad.State (MonadState(get), modify)
import Data.Aeson (ToJSON)
import Data.Binary (Binary)
import Data.ByteString.Lazy (ByteString)
import Data.CRDT.EventFold (Event(Output, State), EventFold, EventId)
import Data.Conduit ((.|), ConduitT, awaitForever, runConduit, yield)
import Data.Default.Class (Default)
import Data.Map (Map)
import GHC.Generics (Generic)
import Network.Socket (PortNumber)
import OM.Fork (Responder)
import OM.Legion.MsgChan (Peer(unPeer), ClusterName, MessageId,
  PeerMessage, close, enqueueMsg, newMsgChan, stream)
import OM.Show (showt)
import OM.Socket (AddressDescription(AddressDescription), openEgress)
import System.Clock (TimeSpec)
import qualified Data.Map as Map


{- |
  A handle on the connection to a peer.

  It wraps a single action that hands a 'PeerMessage' to the connection
  and reports a 'Bool'. 'createConnection' builds it from 'enqueueMsg',
  and 'sendPeer' calls 'disconnect' when the action returns 'False'.
-}
newtype Connection e = Connection
  { _unConnection
      :: forall m.
         (MonadIO m, MonadLogger m, MonadState (RuntimeState e) m)
      => PeerMessage e
      -> m Bool
  }


{- |
  Create a connection to a peer.

  Allocates a fresh message channel and forks a background thread (via
  'async', whose handle is discarded). That thread streams the channel,
  logs each message, and writes it to an egress socket at
  @\<peer\>:'peerMessagePort'@. The returned 'Connection' enqueues onto
  that channel, and it is also recorded in 'rsConnections' under @peer@.

  The sender thread catches synchronous exceptions (via 'tryAny') and
  logs why it stopped. It always closes the channel on the way out.
-}
createConnection
  :: forall m e.
     ( EventConstraints e
     , MonadCatch m
     , MonadLoggerIO m
     , MonadState (RuntimeState e) m
     )
  => Peer
  -> m (Connection e)
createConnection peer = do
    logInfo ("Creating connection to: " <> showt peer)
    RuntimeState {rsSelf = self} <- get
    chan <- newMsgChan
    logger <- askLoggerIO
    let
      -- Drain the channel, logging each message, into the egress socket.
      pump =
        runConduit
          (stream self chan .| logMessageSend .| openEgress egressAddress)

      -- Say, at info level, why the pump stopped.
      reportExit = \case
        Left err ->
          logInfo $ "Disconnecting because of error: " <> showt err
        Right () ->
          logInfo "Disconnecting because source dried up."

      -- The channel is closed however the pump ends. Presumably this makes
      -- later 'enqueueMsg' calls report False, which 'sendPeer' acts on.
      sender = (tryAny pump >>= reportExit) `finally` close chan
    _ <- liftIO (async (runLoggingT sender logger))
    let
      conn :: Connection e
      conn = Connection (enqueueMsg chan)
    modify $ \st ->
      st {rsConnections = Map.insert peer conn (rsConnections st)}
    pure conn
  where
    {- Where the sender thread connects: "<peer>:<peerMessagePort>". -}
    egressAddress :: AddressDescription
    egressAddress =
      AddressDescription (unPeer peer <> ":" <> showt peerMessagePort)

    {- Pass-through conduit that debug-logs each outgoing item. -}
    logMessageSend
      :: forall w.
         (MonadLogger w)
      => ConduitT (Peer, PeerMessage e) (Peer, PeerMessage e) w ()
    logMessageSend = awaitForever $ \outbound -> do
      logDebug $
        "Sending Message to Peer (peer, msg): " <> showt (peer, outbound)
      yield outbound


{- |
  Shorthand for all the constraints needed for the event type. Mainly
  used so that documentation renders better.

  The constraints are grouped by what they apply to: the event itself,
  then its 'State', then its 'Output'.
-}
type EventConstraints e =
  ( Event Peer e
  , Binary e
  , Eq e
  , Show e
  , ToJSON e
  , Binary (State e)
  , Default (State e)
  , Show (State e)
  , ToJSON (State e)
  , Binary (Output e)
  , Eq (Output e)
  , Show (Output e)
  , ToJSON (Output e)
  )


{- | The Legionary runtime state. -}
data RuntimeState e = RuntimeState
  { rsSelf :: Peer
    {- ^
      The local peer. 'createConnection' passes it to 'stream' together
      with each new message channel (presumably as the sending peer).
    -}
  , rsClusterState :: EventFold ClusterName Peer e
    {- ^ The cluster state, held as an 'EventFold'. -}
  , rsConnections :: Map Peer (Connection e)
    {- ^
      Outbound connections keyed by peer. Entries are inserted by
      'createConnection' and removed by 'disconnect'.
    -}
  , rsWaiting :: Map (EventId Peer) (Responder (Output e))
    {- ^ Responders for the 'Output' of events, keyed by 'EventId'. -}
  , rsCalls :: Map MessageId (Responder ByteString)
    {- ^ Responders keyed by 'MessageId' (presumably outstanding calls). -}
  , rsBroadcalls
      :: Map
           MessageId
           ( Map Peer (Maybe ByteString)
           , Responder (Map Peer (Maybe ByteString))
           , TimeSpec
           )
    {- ^
      Per 'MessageId': the per-peer results gathered so far, the responder
      that receives such a map, and a 'TimeSpec' (presumably a deadline;
      confirm against the code that reads it).
    -}
  , rsNextId :: MessageId
    {- ^ Presumably the next 'MessageId' to allocate. -}
  , rsNotify :: EventFold ClusterName Peer e -> IO ()
    {- ^
      Callback taking an 'EventFold' (presumably invoked when the cluster
      state changes).
    -}
  , rsJoins :: Map (EventId Peer) (Responder (JoinResponse e))
    {- ^
      The infimum of the eventfold we send to a new participant must have
      moved past the participation event itself. In other words, the join
      must be totally consistent across the cluster. The reason is that we
      can't make the new participant responsible for applying events that
      occurred before it joined the cluster, because it has no way to
      ensure that it can collect all such events. Therefore, this field
      tracks the outstanding joins until they become consistent.
    -}
  }


{- |
  The response to a JoinRequest message.

  The 'Binary' instance below defines no methods, so it relies on the
  default implementations built from the derived 'Generic' representation.
-}
newtype JoinResponse e = JoinOk (EventFold ClusterName Peer e)
  deriving stock Generic

deriving stock instance (EventConstraints e) => Show (JoinResponse e)

instance (EventConstraints e) => Binary (JoinResponse e)


{- |
  The peer message port.

  'createConnection' connects to this port on every peer; the egress
  address is the peer's name, a colon, and this number.
-}
peerMessagePort :: PortNumber
peerMessagePort = 5288


{- |
  Disconnect the connection to a peer.

  This logs the disconnect and removes the peer's entry from
  'rsConnections'; nothing else. It does not close the underlying message
  channel. In this module, only the sender thread started by
  'createConnection' closes that channel, when its stream finishes or
  fails. The next 'sendPeer' to the peer will create a fresh connection.
-}
disconnect
  :: ( MonadLogger m
     , MonadState (RuntimeState e) m
     )
  => Peer
  -> m ()
disconnect peer = do
  logInfo ("Disconnecting: " <> showt peer)
  modify forgetPeer
  where
    {- Drop this peer's connection from the state. -}
    forgetPeer st =
      st {rsConnections = Map.delete peer (rsConnections st)}


{- |
  Send a peer message, creating a new connection if need be.

  The existing connection from 'rsConnections' is reused when present;
  otherwise one is made with 'createConnection'. If the connection's
  enqueue action returns 'False', the peer is passed to 'disconnect'. In
  that case the message is not retried.
-}
sendPeer
  :: forall m e.
     ( EventConstraints e
     , MonadCatch m
     , MonadLoggerIO m
     , MonadState (RuntimeState e) m
     )
  => PeerMessage e
  -> Peer
  -> m ()
sendPeer msg peer = do
    RuntimeState {rsConnections} <- get
    conn <- maybe (createConnection peer) pure (Map.lookup peer rsConnections)
    deliver conn
  where
    {-
      Hand the message to the connection. A 'False' result (presumably a
      closed channel) causes the connection to be forgotten.
    -}
    deliver :: Connection e -> m ()
    deliver (Connection enqueue) = do
      accepted <- enqueue msg
      if accepted
        then pure ()
        else disconnect peer