{-# LANGUAGE ConstraintKinds #-}
{-# LANGUAGE DeriveGeneric #-}
{-# LANGUAGE DerivingStrategies #-}
{-# LANGUAGE FlexibleContexts #-}
{-# LANGUAGE LambdaCase #-}
{-# LANGUAGE NamedFieldPuns #-}
{-# LANGUAGE OverloadedStrings #-}
{-# LANGUAGE RankNTypes #-}
{-# LANGUAGE ScopedTypeVariables #-}
{-# LANGUAGE StandaloneDeriving #-}
{-# LANGUAGE UndecidableInstances #-}

{- | Description: Manage connections to other peers. -}
module OM.Legion.Connection (
  JoinResponse(..),
  RuntimeState(..),
  EventConstraints,
  disconnect,
  peerMessagePort,
  sendPeer,
) where

import Control.Concurrent.Async (async)
import Control.Exception.Safe (finally, tryAny)
import Control.Monad (void)
import Control.Monad.IO.Class (MonadIO(liftIO))
import Control.Monad.Logger.CallStack (LoggingT(runLoggingT),
  MonadLoggerIO(askLoggerIO), MonadLogger, logDebug, logInfo)
import Control.Monad.State (MonadState(get), modify)
import Data.Aeson (ToJSON)
import Data.Binary (Binary)
import Data.ByteString.Lazy (ByteString)
import Data.CRDT.EventFold (Event(Output, State), EventFold, EventId)
import Data.Default.Class (Default)
import Data.Function (($), (&), (.))
import Data.Map (Map)
import GHC.Generics (Generic)
import Network.Socket (PortNumber)
import OM.Fork (Responder)
import OM.Legion.MsgChan (Peer(unPeer), ClusterName, MessageId,
  PeerMessage, close, enqueueMsg, newMsgChan, stream)
import OM.Show (showt)
import OM.Socket (AddressDescription(AddressDescription), openEgress)
import Prelude (Applicative(pure), Bool(False, True), Either(Left, Right),
  Maybe(Just, Nothing), Monad((>>=)), Semigroup((<>)), Eq, IO, Show)
import System.Clock (TimeSpec)
import qualified Data.Map as Map
import qualified Streaming.Prelude as Stream


{- |
  A handle on the connection to a peer.

  The handle wraps a single rank-2 action that takes a 'PeerMessage'
  and yields a 'Bool'. In this module it is built by 'createConnection'
  from @enqueueMsg msgChan@, and 'sendPeer' treats a 'False' result as
  the signal to drop the connection via 'disconnect'.
-}
newtype Connection e = Connection
  { _unConnection
      :: forall m.
         ( MonadIO m
         , MonadLogger m
         , MonadState (RuntimeState e) m
         )
      => PeerMessage e
      -> m Bool
  }


{- |
  Create a connection to a peer.

  Steps, in order:

  1. Log the attempt and read 'rsSelf' from the current state.
  2. Allocate a fresh message channel with 'newMsgChan'.
  3. Fork a background thread (via 'async') that pipes
     @stream rsSelf msgChan@ through 'openEgress' to the address
     @\<unPeer peer\>:\<peerMessagePort\>@, logging every outgoing
     message at debug level. The thread runs in 'LoggingT' using the
     logger captured from the calling monad.
  4. Store the resulting 'Connection' under @peer@ in 'rsConnections'
     and return it.

  The 'async' handle is discarded, so nothing here waits on or cancels
  the background thread. When the pipeline ends, either normally or
  with a synchronous exception caught by 'tryAny', the reason is logged
  and the channel is closed in the 'finally' handler.
-}
createConnection
  :: forall m e.
     ( EventConstraints e
     , MonadLoggerIO m
     , MonadState (RuntimeState e) m
     )
  => Peer
  -> m (Connection e)
createConnection peer = do
    logInfo $ "Creating connection to: " <> showt peer
    RuntimeState {rsSelf} <- get
    msgChan <- newMsgChan
    logging <- askLoggerIO
    liftIO . void . async . (`runLoggingT` logging) $
      let
        {- The remote endpoint, rendered as "peer:port". -}
        addy :: AddressDescription
        addy =
          AddressDescription
            (
              unPeer peer
              <> ":" <> showt peerMessagePort
            )
      in
        finally
          (
            tryAny (
              stream rsSelf msgChan
              & Stream.mapM logMessageSend
              & openEgress addy
            ) >>= \case
              Left err ->
                logInfo $ "Disconnecting because of error: " <> showt err
              Right () -> logInfo "Disconnecting because source dried up."
          )
          {- Always close the channel once the egress pipeline is gone. -}
          (close msgChan)

    let
      conn :: Connection e
      conn = Connection (enqueueMsg msgChan)
    modify
      (\state -> state {
        rsConnections = Map.insert peer conn (rsConnections state)
      })
    pure conn
  where
    {- Log an outgoing stream element, then pass it through unchanged. -}
    logMessageSend
      :: forall w.
         (MonadLogger w)
      => (Peer, PeerMessage e)
      -> w (Peer, PeerMessage e)
    logMessageSend msg = do
      logDebug
        $ "Sending Message to Peer (peer, msg): "
        <> showt (peer, msg)
      pure msg


{- |
  A single name for the full set of class constraints this module
  requires of an event type @e@, its 'State', and its 'Output'. It
  exists mostly to keep rendered documentation readable.
-}
type EventConstraints e =
  ( Event Peer e
  , Binary e
  , Eq e
  , Show e
  , ToJSON e
  , Binary (State e)
  , Default (State e)
  , Show (State e)
  , ToJSON (State e)
  , Binary (Output e)
  , Eq (Output e)
  , Show (Output e)
  , ToJSON (Output e)
  )


{- |
  The Legionary runtime state.

  Within this module, 'rsSelf' is read by 'createConnection' when
  starting the outgoing message stream, and 'rsConnections' is the
  per-peer connection table maintained by 'createConnection',
  'disconnect', and 'sendPeer'. The remaining fields are not used in
  this module.
-}
data RuntimeState e = RuntimeState
  {         rsSelf :: Peer
  , rsClusterState :: EventFold ClusterName Peer e
  ,  rsConnections :: Map Peer (Connection e)
  ,      rsWaiting :: Map (EventId Peer) (Responder (Output e))
  ,        rsCalls :: Map MessageId (Responder ByteString)
  ,   rsBroadcalls :: Map
                        MessageId
                        (
                          Map Peer (Maybe ByteString),
                          Responder (Map Peer (Maybe ByteString)),
                          TimeSpec
                        )
  ,       rsNextId :: MessageId
  ,       rsNotify :: EventFold ClusterName Peer e -> IO ()
  ,        rsJoins :: Map
                        (EventId Peer)
                        (Responder (JoinResponse e))
                      {- ^
                        The infimum of the eventfold we send to a
                        new participant must have moved past the
                        participation event itself. In other words,
                        the join must be totally consistent across the
                        cluster. The reason is that we can't make the
                        new participant responsible for applying events
                        that occur before it joined the cluster, because
                        it has no way to ensure that it can collect all
                        such events.  Therefore, this field tracks the
                        outstanding joins until they become consistent.
                      -}
  }


{- |
  The response to a JoinRequest message. The only constructor, 'JoinOk',
  carries an 'EventFold'.

  'Show' is derived standalone so that the instance can carry the
  'EventConstraints' context. The 'Binary' instance has an empty body
  and so relies on the class's default methods (presumably the
  'Generic'-based ones, given the derived 'Generic' instance).
-}
newtype JoinResponse e
  = JoinOk (EventFold ClusterName Peer e)
  deriving stock (Generic)
deriving stock instance (EventConstraints e) => Show (JoinResponse e)
instance (EventConstraints e) => Binary (JoinResponse e)


{- |
  The peer message port. 'createConnection' appends this port to the
  peer's name to form the address it opens the egress stream to.
-}
peerMessagePort :: PortNumber
peerMessagePort = 5288


{- |
  Disconnect the connection to a peer.

  Logs the disconnect and removes the peer's entry from 'rsConnections'.
  Removing a peer that has no entry leaves the map unchanged. This
  function only edits the connection table; it performs no action on
  the connection's underlying message channel.
-}
disconnect
  :: ( MonadLogger m
     , MonadState (RuntimeState e) m
     )
  => Peer
  -> m ()
disconnect peer = do
  logInfo $ "Disconnecting: " <> showt peer
  modify (\state@RuntimeState {rsConnections} -> state {
    rsConnections = Map.delete peer rsConnections
  })


{- |
  Send a peer message, creating a new connection if need be.

  Looks the peer up in 'rsConnections'. If no connection exists, one is
  created with 'createConnection', which also registers it. The message
  is then handed to the connection's send action. If that action
  returns 'False', the peer is removed from the connection table via
  'disconnect'; the message is not retried here.
-}
sendPeer
  :: forall m e.
     ( EventConstraints e
     , MonadLoggerIO m
     , MonadState (RuntimeState e) m
     )
  => PeerMessage e
  -> Peer
  -> m ()
sendPeer msg peer = do
    RuntimeState {rsConnections} <- get
    case Map.lookup peer rsConnections of
      Nothing -> do
        conn <- createConnection peer
        sendTheMessage conn
      Just conn ->
        sendTheMessage conn
  where
    {- Run the connection's send action; drop the peer on 'False'. -}
    sendTheMessage :: Connection e -> m ()
    sendTheMessage (Connection conn) =
      conn msg >>= \case
        True -> pure ()
        False -> disconnect peer