{-# LANGUAGE DeriveGeneric #-}
{-# LANGUAGE LambdaCase #-}
{-# LANGUAGE NamedFieldPuns #-}
{-# LANGUAGE TemplateHaskell #-}
{- |
  This module manages connections to other nodes in the cluster.
-}
module Network.Legion.Runtime.ConnectionManager (
  ConnectionManager,
  newConnectionManager,
  send,
  forward,
  newPeers,
) where

import Prelude hiding (lookup)

import Control.Concurrent (Chan, writeChan, newChan, readChan,
  newEmptyMVar, putMVar, takeMVar)
import Control.Exception (SomeException, bracketOnError)
import Control.Monad (void)
import Control.Monad.Catch (MonadCatch, try)
import Control.Monad.IO.Class (MonadIO, liftIO)
import Control.Monad.Logger (logInfo, logWarn, MonadLoggerIO)
import Control.Monad.Trans.Class (lift)
import Data.Binary (Binary, encode)
import Data.ByteString.Lazy (ByteString)
import Data.Conduit (Sink, runConduit, (.|), await)
import Data.Map (toList, insert, empty, Map, lookup)
import Data.Text (pack)
import Network.Legion.BSockAddr (BSockAddr(BSockAddr))
import Network.Legion.Conduit (chanToSource)
import Network.Legion.Distribution (Peer)
import Network.Legion.Fork (forkC, ForkM)
import Network.Legion.Runtime.PeerMessage (PeerMessage(PeerMessage),
  MessageId, PeerMessagePayload, source, messageId, payload,
  nextMessageId, newSequence)
import Network.Legion.SocketUtil (fam)
import Network.Socket (SockAddr, Socket, socket, SocketType(Stream),
  defaultProtocol, connect, close, SockAddr)
import Network.Socket.ByteString.Lazy (sendAll)

{- |
  A handle on the connection manager. It wraps the channel used to send
  requests to the connection manager thread.
-}
newtype ConnectionManager e o s = C (Chan (Message e o s))

{- | The handle is opaque, so it always renders as a fixed placeholder. -}
instance Show (ConnectionManager e o s) where
  showsPrec _ _ = showString "ConnectionManager"


{- |
  Create a new connection manager.

  This forks a "connection manager thread" that owns the per-peer
  connections and the message id sequence. It then registers every peer in
  @initPeers@ with that thread and returns a handle for talking to it.
  @self@ is stamped as the 'source' of every message created by 'send'.
-}
newConnectionManager :: (
      Binary e,
      Binary o,
      Binary s,
      ForkM m,
      MonadCatch m,
      MonadLoggerIO m
    )
  => Peer
  -> Map Peer BSockAddr
  -> m (ConnectionManager e o s)
newConnectionManager self initPeers = do
    chan <- liftIO newChan
    nextId <- newSequence
    forkC "connection manager thread" $
      manager chan S {
          nextId,
          connections = empty
        }
    let cm = C chan
    newPeers cm initPeers
    return cm
  where
    {- | Feed every message written to the channel into 'handle'. -}
    manager :: (
          Binary e,
          Binary o,
          Binary s,
          ForkM m,
          MonadCatch m,
          MonadLoggerIO m
        )
      => Chan (Message e o s)
      -> State e o s
      -> m ()
    manager chan state =
      runConduit (chanToSource chan .| handle state)

    {- |
      Process one message, then loop with the (possibly updated) state.

      Every message branch must recurse into 'handle'. Returning instead
      ends the sink, which stops the manager thread and silently discards
      every later message.
    -}
    handle :: (
          Binary e,
          Binary o,
          Binary s,
          ForkM m,
          MonadCatch m,
          MonadLoggerIO m
        )
      => State e o s
      -> Sink (Message e o s) m ()
    handle s@S {connections, nextId} =
      await >>= \case
        Nothing -> return ()
        Just (NewPeer peer addr) ->
          -- Only build a connection for peers we don't already know.
          handle =<< case lookup peer connections of
            Nothing -> do
              conn <- lift (connection addr)
              return s {
                  connections = insert peer conn connections
                }
            Just _ ->
              return s
        Just (Send peer payload respond) -> do
          case lookup peer connections of
            Nothing -> $(logWarn) . pack $ "unknown peer: " ++ show peer
            Just conn -> liftIO $
              writeChan conn PeerMessage {
                source = self,
                messageId = nextId,
                payload
              }
          -- The caller is told the id, and the sequence advances, even
          -- when the peer is unknown and nothing was sent.
          liftIO (respond nextId)
          handle s {nextId = nextMessageId nextId}
        Just (Forward peer msg) -> do
          case lookup peer connections of
            Nothing -> $(logWarn) . pack $ "unknown peer: " ++ show peer
            Just conn -> liftIO $ writeChan conn msg
          -- Keep processing. Without this the manager thread would exit
          -- after the first forwarded message.
          handle s


{- |
  Build a new connection to the peer at @addr@.

  Returns a channel. A forked worker thread reads each 'PeerMessage' from
  it, 'encode's the message and writes it to a socket. The socket is opened
  on the first message and reopened after a failed send. Messages that
  still cannot be sent are logged and dropped.
-}
connection :: (
      Binary e,
      Binary o,
      Binary s,
      ForkM m,
      MonadCatch m,
      MonadIO m,
      MonadLoggerIO m
    )
  => SockAddr
  -> m (Chan (PeerMessage e o s))
connection addr = do
    chan <- liftIO newChan
    forkC ("connection to: " ++ show addr) $
      handle chan Nothing
    return chan
  where
    {- |
      Worker loop. Send the next message, then continue with whatever
      socket 'sendWithRetry' reports as usable.
    -}
    handle :: (
          Binary e,
          Binary o,
          Binary s,
          MonadCatch m,
          MonadLoggerIO m
        )
      => Chan (PeerMessage e o s)
      -> Maybe Socket
      -> m ()
    handle chan so =
      liftIO (readChan chan) >>= sendWithRetry so . encode >>= handle chan

    {- | Open a socket. -}
    openSocket :: IO Socket
    openSocket =
      {-
        Make sure to close the socket if an error happens during
        connection. Otherwise we could easily run out of file
        descriptors when we rapidly try to send thousands of messages
        to the same peer, which can happen when one object is a
        hotspot.
      -}
      bracketOnError
        (socket (fam addr) Stream defaultProtocol)
        close
        (\so -> connect so addr >> return so)

    {- |
      Close a socket we have already given up on. This is deliberately
      best-effort: an error while closing a dead socket carries no useful
      information, so it is ignored.
    -}
    closeQuietly :: Socket -> IO ()
    closeQuietly so =
      void (try (close so) :: IO (Either SomeException ()))

    {- |
      Try to send the payload over the socket. If that fails, create a new
      socket and retry the send once. Return the socket that is known to
      work, or 'Nothing' if there is none. A socket on which a send has
      failed is closed and never returned.
    -}
    sendWithRetry :: (MonadCatch m, MonadLoggerIO m)
      => Maybe Socket
      -> ByteString
      -> m (Maybe Socket)
    sendWithRetry Nothing payload =
      try (liftIO openSocket) >>= \case
        Left err -> do
          $(logWarn) . pack
            $ "Can't connect to: " ++ show addr ++ ". Dropping message on "
            ++ "the floor. The error was: "
            ++ show (err :: SomeException)
          return Nothing
        Right so ->
          try (liftIO (sendAll so payload)) >>= \case
            Left err -> do
              $(logWarn) . pack
                $ "An error happend when trying to send a payload over a socket "
                ++ "to the address: " ++ show addr ++ ". The error was: "
                ++ show (err :: SomeException) ++ ". This is the last straw, we "
                ++ "are not retrying. The message is being dropped on the floor."
              {-
                The fresh socket is already broken. Close it now instead of
                handing it back: otherwise its descriptor would stay open
                until some later message happened to trip over it.
              -}
              liftIO (closeQuietly so)
              return Nothing
            Right _ ->
              return (Just so)
    sendWithRetry (Just so) payload =
      try (liftIO (sendAll so payload)) >>= \case
        Left err -> do
          $(logInfo) . pack
            $ "Socket to " ++ show addr ++ " died. Retrying on a new "
            ++ "socket. The error was: " ++ show (err :: SomeException)
          liftIO (closeQuietly so)
          sendWithRetry Nothing payload
        Right _ ->
          return (Just so)


{- |
  Send a message to a peer.

  The connection manager wraps the payload in a 'PeerMessage' and assigns it
  a 'MessageId'. This call blocks until the id has been assigned, then
  returns it.
-}
send :: (MonadIO m)
  => ConnectionManager e o s
  -> Peer
  -> PeerMessagePayload e o s
  -> m MessageId
send (C chan) peer payload = liftIO $ do
  reply <- newEmptyMVar
  writeChan chan (Send peer payload (putMVar reply))
  takeMVar reply


{- |
  Forward an already-built message to a peer without modifying it. Unlike
  'send', this does not wait for the connection manager.
-}
forward :: (MonadIO m)
  => ConnectionManager e o s
  -> Peer
  -> PeerMessage e o s
  -> m ()
forward (C chan) peer msg =
  liftIO (writeChan chan (Forward peer msg))


{- |
  Tell the connection manager about a new peer and the address to reach it
  at. The request is only queued; this does not wait for it to be handled.
-}
newPeer :: (MonadIO io)
  => ConnectionManager e o s
  -> Peer
  -> SockAddr
  -> io ()
newPeer (C chan) peer addr =
  liftIO . writeChan chan $ NewPeer peer addr


{- |
  Tell the connection manager about all the peers known to the cluster
  state. Each entry is registered with 'newPeer', in key order.
-}
newPeers :: (MonadIO io)
  => ConnectionManager e o s
  -> Map Peer BSockAddr
  -> io ()
newPeers cm =
  mapM_ (\(peer, BSockAddr addy) -> newPeer cm peer addy) . toList


{- |
  The internal state of the connection manager thread.
-}
data State e o s = S
  { nextId :: MessageId
    -- ^ The id given to the next message built for a 'Send' request.
  , connections :: Map Peer (Chan (PeerMessage e o s))
    -- ^ The outbound message channel for each known peer.
  }


{- |
  The types of messages that the ConnectionManager understands.
-}
data Message e o s
  = -- | Register a peer and the address to reach it at. This is ignored if
    --   the peer is already known.
    NewPeer Peer SockAddr
  | -- | Relay an already-built message, unchanged, to a peer.
    Forward Peer (PeerMessage e o s)
  | -- | Wrap a payload in a new message for a peer. The callback receives
    --   the assigned 'MessageId'. It is called even when the peer is
    --   unknown and nothing is sent.
    Send Peer (PeerMessagePayload e o s) (MessageId -> IO ())