{-# LANGUAGE DeriveGeneric #-}
{-# LANGUAGE GeneralizedNewtypeDeriving #-}
{-# LANGUAGE LambdaCase #-}
{-# LANGUAGE MultiParamTypeClasses #-}
{-# LANGUAGE NamedFieldPuns #-}
{-# LANGUAGE OverloadedStrings #-}
{-# LANGUAGE StandaloneDeriving #-}
{-# LANGUAGE TemplateHaskell #-}
{- |
  This module contains the state machine implementation of a legion node.
-}
module Network.Legion.StateMachine (
  stateMachine,
  LInput(..),
  LOutput(..),
  JoinRequest(..),
  JoinResponse(..),
  AdminMessage(..),
  NodeState,
  Forwarded(..),
  PeerMessage(..),
  PeerMessagePayload(..),
  MessageId,
  next,
  newNodeState,
) where

import Prelude hiding (lookup)

import Control.Exception (throw)
import Control.Monad (unless)
import Control.Monad.Catch (try, SomeException, MonadCatch, MonadThrow)
import Control.Monad.IO.Class (MonadIO, liftIO)
import Control.Monad.Logger (logDebug, logWarn, logError, logInfo,
  MonadLogger)
import Control.Monad.Trans.Class (MonadTrans, lift)
import Control.Monad.Trans.State (StateT, runStateT, get, put)
import Data.Binary (Binary)
import Data.Conduit (Source, Conduit, ($$), await, awaitForever,
  transPipe, ConduitM, yield, ($=))
import Data.Default.Class (Default)
import Data.Map (Map, insert, lookup)
import Data.Maybe (fromMaybe)
import Data.Set (member, minView, (\\))
import Data.Text (pack)
import Data.Time.Clock (getCurrentTime)
import Data.UUID (UUID)
import Data.Word (Word64)
import GHC.Generics (Generic)
import Network.Legion.Application (Legionary, LegionConstraints,
  Persistence(getState, saveState, list), Legionary(Legionary,
  persistence, handleRequest), RequestMsg)
import Network.Legion.BSockAddr (BSockAddr)
import Network.Legion.ClusterState (claimParticipation, ClusterPropState,
  getPeers, getDistribution, ClusterPowerState)
import Network.Legion.Distribution (rebalanceAction, RebalanceAction(
  Invite), Peer, newPeer)
import Network.Legion.KeySet (union, KeySet)
import Network.Legion.LIO (LIO)
import Network.Legion.PartitionKey (PartitionKey)
import Network.Legion.PartitionState (PartitionPowerState, PartitionPropState)
import Network.Legion.PowerState (ApplyDelta)
import Network.Legion.UUID (getUUID)
import qualified Data.Conduit.List as CL
import qualified Data.Map as Map
import qualified Data.Set as Set
import qualified Network.Legion.ClusterState as C
import qualified Network.Legion.KeySet as KS
import qualified Network.Legion.PartitionState as P


{- |
  The main legionary state machine, expressed as a conduit.

  For every input message it runs one complete step (message handling,
  heartbeat, migration, propagation, rebalancing and state logging)
  starting from the given node state, and then carries on with the node
  state that the step produced.
-}
stateMachine :: (LegionConstraints i o s)
  => Legionary i o s
  -> NodeState i o s
  -> Conduit (LInput i o s) LIO (LOutput i o s)
stateMachine l n = awaitForever $ \msg ->
    runStateMT n (step msg) >>= stateMachine l
  where
    {- One complete pass of the state machine for a single input. -}
    step msg =
      handleMessage l msg
      >> heartbeat
      >> migrate l
      >> propagate
      >> rebalance l
      >> (getS >>= lift . logNodeState)


{- |
  Process a single input to the state machine.

  Peer messages are only handled when their source is a participant of
  the current cluster state; user requests are forwarded to the smallest
  peer in the owner set of the request key; join requests and admin
  messages are handed to their dedicated handlers.
-}
handleMessage :: (LegionConstraints i o s)
  => Legionary i o s
  -> LInput i o s
  -> StateM i o s ()

handleMessage l msg = do
  NodeState {cluster} <- getS
  $(logDebug) . pack $ "Receiving: " ++ show msg
  case msg of
    P peerMsg@PeerMessage {source}
      {- Only peers that participate in the cluster are listened to. -}
      | source `member` C.allParticipants cluster ->
          handlePeerMessage l peerMsg
      | otherwise ->
          $(logWarn) . pack
            $ "Dropping message from unknown peer: " ++ show source
    R ((key, request), respond) ->
      case fst <$> minView (C.findPartition key cluster) of
        Just peer ->
          forward peer key request respond
        Nothing ->
          $(logError) . pack
            $ "Keyspace does not contain key: " ++ show key ++ ". This "
            ++ "is a very bad thing and probably means there is a bug, "
            ++ "or else this node has not joined a cluster yet."
    J m -> handleJoinRequest m
    A m -> getS >>= lift . handleAdminMessage l m


{- | Handles one incoming message from a peer. -}
handlePeerMessage :: (LegionConstraints i o s)
  => Legionary i o s
  -> PeerMessage i o s
  -> StateM i o s ()

{-
  PartitionMerge: a peer is sharing its view of a partition's state.

  The local propagation state for the key is taken from the in-memory
  `propStates` when present and otherwise loaded through the persistence
  layer. If the merge succeeds, the result is persisted (as `Nothing`
  when this node is not participating in the partition) and the in-memory
  map is updated. A merge failure is logged and otherwise ignored.
-}
handlePeerMessage -- PartitionMerge
    Legionary {
        persistence
      }
    msg@PeerMessage {
        source,
        payload = PartitionMerge key ps
      }
  = do
    nodeState@NodeState {self, propStates, cluster} <- getS
    propState <- lift $ maybe
      (getStateL persistence self cluster key)
      return
      (lookup key propStates)
    let
      owners = C.findPartition key cluster
    case P.mergeEither source ps propState of
      Left err ->
        $(logWarn) . pack
          $ "Can't apply incomming partition action message "
          ++ show msg ++ " because of: " ++ show err
      Right newPropState -> do
        $(logDebug) "Saving because of PartitionMerge"
        lift $ saveStateL persistence key (
            if P.participating newPropState
              then Just (P.getPowerState newPropState)
              else Nothing
          )
        putS nodeState {
            {-
              A propagation state that equals a freshly created one for
              this key is not kept in memory.
            -}
            propStates = if newPropState == P.new key self owners
              then Map.delete key propStates
              else insert key newPropState propStates
          }

{-
  ForwardRequest: a peer has forwarded a user request to this node.

  When this node is one of the owners of the key (per `C.findPartition`),
  the request is applied to the partition and the result is sent back to
  the message's `source` as a `ForwardResponse` carrying the same
  `messageId`. Otherwise the original message is passed along, unchanged,
  to the smallest peer in the owner set.
-}
handlePeerMessage -- ForwardRequest
    Legionary {handleRequest, persistence}
    msg@PeerMessage {
        payload = ForwardRequest key request,
        source,
        messageId
      }
  = do
    ns@NodeState {self, cluster, propStates} <- getS
    let owners = C.findPartition key cluster
    if self `member` owners
      then do
        let
          {- Reply to the peer that forwarded the request to us. -}
          respond = send source . ForwardResponse messageId

        -- TODO
        --   - figure out some slick concurrency here, by maintaining
        --       a map of keys that are currently being accessed or
        --       something
        --
        {-
          If the block below throws, the exception is caught by `try`
          and the response payload sent back is `throw e`, i.e. a value
          that raises the exception when it is evaluated.

          NOTE(review): `response` is returned unevaluated, so an
          exception raised by evaluating `handleRequest` itself is
          presumably not caught by this `try` -- confirm.
        -}
        either (respond . rethrow) respond =<< try (do 
            {-
              NOTE(review): this always reads through the persistence
              layer, whereas the PartitionMerge handler consults the
              in-memory `propStates` first -- confirm this is intended.
            -}
            prop <- lift $ getStateL persistence self cluster key
            let response = handleRequest key request (P.ask prop)
                newProp = P.delta request prop
            $(logDebug) "Saving because of ForwardRequest"
            lift $ saveStateL persistence key (Just (P.getPowerState newProp))
            $(logInfo) . pack
              $ "Handling user request: " ++ show request
            $(logDebug) . pack
              $ "Request details request: " ++ show prop ++ " ++ "
              ++ show request ++ " --> " ++ show (response, newProp)
            putS ns {propStates = insert key newProp propStates}
            return response
          )
      else
        {-
          we don't own the key after all, someone was wrong to forward
          us this request.

          The message is re-sent as-is, so it keeps the `source` and
          `messageId` of the peer that originally forwarded it.
        -}
        case minView owners of
          Nothing -> $(logError) . pack
            $ "Can't find any owners for the key: " ++ show key
          Just (peer, _) ->
            emit (Send peer msg)
  where
    {- |
      rethrow is just a reification of `throw`.
    -}
    rethrow :: SomeException -> a
    rethrow = throw

{-
  ForwardResponse: a peer is answering a request that this node forwarded
  earlier. The pending responder registered under the message id (if
  there is one) is invoked with the response, and the id is then removed
  from the set of forwarded requests either way.
-}
handlePeerMessage -- ForwardResponse
    Legionary {}
    msg@PeerMessage {payload = ForwardResponse messageId response}
  = do
    nodeState@NodeState {forwarded = F pending} <- getS
    maybe unexpected (\respond -> lift (respond response))
      (lookup messageId pending)
    putS nodeState {forwarded = F (Map.delete messageId pending)}
  where
    {- Nothing is waiting on this message id; just complain about it. -}
    unexpected = $(logWarn) . pack
      $  "This peer received a response for a forwarded request that it "
      ++ "didn't send. The only time you might expect to see this is if "
      ++ "this peer recently crashed and was immediately restarted. If "
      ++ "you are seeing this in other circumstances then probably "
      ++ "something is very wrong at the network level. The message was: "
      ++ show msg

{-
  ClusterMerge: a peer is sharing its view of the cluster state.

  On a successful merge the cluster state is replaced by the merged
  state, and the key set returned alongside it is added to the pending
  `migration` set. A merge failure is logged and otherwise ignored.
-}
handlePeerMessage -- ClusterMerge
    Legionary {}
    msg@PeerMessage {
        source,
        payload = ClusterMerge ps
      }
  = do
    nodeState@NodeState {migration, cluster} <- getS
    case C.mergeEither source ps cluster of
      Left err ->
        $(logWarn) . pack
          $ "Can't apply incomming cluster action message "
          ++ show msg ++ " because of: " ++ show err
      Right (newCluster, newMigration) ->
        putS nodeState {
            migration = migration `union` newMigration,
            cluster = newCluster
          }


{- |
  Admit a new peer into the cluster: allocate a peer identity for it,
  add it to the cluster state, announce the resulting peer map, reply to
  the requester with its identity and the new cluster power state, and
  finally store the new cluster state.
-}
handleJoinRequest
  :: (JoinRequest, JoinResponse -> LIO ())
  -> StateM i o s ()

handleJoinRequest (JoinRequest addr, reply) = do
  joiner <- lift newPeer
  current <- getS
  let grown = C.joinCluster joiner addr (cluster current)
  emit (NewPeers (getPeers grown))
  lift (reply (JoinOk joiner (C.getPowerState grown)))
  putS current {cluster = grown}


{- |
  Answer a query from the admin service: either hand back the whole
  node state, or look up one partition in the persistence layer and hand
  back whatever was found.
-}
handleAdminMessage
  :: Legionary i o s
  -> AdminMessage i o s
  -> NodeState i o s
  -> LIO ()
handleAdminMessage _ (GetState reply) nodeState =
  reply nodeState
handleAdminMessage Legionary {persistence} (GetPart key reply) _ =
  reply =<< lift (getState persistence key)


{- |
  Stamp the cluster propagation state and every in-memory partition
  propagation state with the current time.
-}
heartbeat :: StateM i o s ()
heartbeat = do
  now <- liftIO getCurrentTime
  ns@NodeState {cluster, propStates} <- getS
  putS ns {
      cluster = C.heartbeat now cluster,
      propStates = Map.map (P.heartbeat now) propStates
    }


{- |
  Migrate partitions based on new cluster state information.

  Every persisted partition whose key is in the pending `migration` set
  gets the peers that own the key according to the cluster state, but do
  not yet participate in the partition, added as participants. The
  pending set is emptied afterwards.

  TODO: this migration algorithm is super naive. It just goes ahead
  and migrates everything in one pass, which is going to be terrible
  for performance.

  Note that "migrate" in this context does not mean "transfer data".
  It means adding a participating peer to a partition, which causes the
  data to be transferred in the normal course of propagation.
-}
migrate :: (LegionConstraints i o s) => Legionary i o s -> StateM i o s ()
migrate Legionary{persistence} = do
    ns@NodeState {migration} <- getS
    unless (KS.null migration) $ do
      migrated <- lift (
          listL persistence
          $= CL.filter (\(key, _) -> key `KS.member` migration)
          $$ accum ns {migration = KS.empty}
        )
      putS migrated
  where
    {- Fold the stream of partitions into the node state. -}
    accum ns@NodeState {self, cluster, propStates} =
        await >>= maybe (return ns) step
      where
        step (key, ps) = do
            $(logDebug) . pack $ "Migrating: " ++ show key
            lift (saveStateL persistence key (Just (P.getPowerState newProp)))
            accum ns {propStates = Map.insert key newProp propStates}
          where
            origProp = fromMaybe (P.initProp self ps) (lookup key propStates)
            newPeers_ =
              C.findPartition key cluster \\ P.projParticipants origProp
            {- This 'P.participate' is where the magic happens. -}
            newProp = foldr P.participate origProp (Set.toList newPeers_)


{- |
  Handle all cluster and partition state propagation actions, and store
  the resulting cluster and partition propagation states in the node
  state.

  A `ClusterMerge` is sent to every peer named by the cluster actions,
  and a `PartitionMerge` to every peer named by the actions of each
  in-memory partition state. Partition states that are indistinguishable
  from a freshly initialized one are dropped from memory afterwards.
-}
propagate :: (LegionConstraints i o s) => StateM i o s ()
propagate = do
    NodeState {cluster, propStates, self} <- getS
    let (peers, ps, cluster2) = C.actions cluster
    $(logDebug) . pack $ "Cluster Actions: " ++ show (peers, ps)
    mapM_ (doClusterAction ps) (Set.toList peers)
    propStates2 <- mapM doPartitionActions (Map.toList propStates)
    {-
      Re-read the node state before writing to it. Every `send` above
      advanced `nextId`; writing back the state captured at the top of
      this function would rewind `nextId` and cause message ids to be
      handed out more than once.
    -}
    ns <- getS
    putS ns {
        cluster = cluster2,
        propStates = Map.fromAscList [
            (k, p)
            | (k, p) <- propStates2
            , p /= P.initProp self (P.getPowerState p)
          ]
      }
  where
    {- Share the cluster power state with one peer. -}
    doClusterAction ps peer =
      send peer (ClusterMerge ps)

    {-
      Share one partition's power state with every peer that its
      propagation state says needs it, and return the updated
      propagation state.
    -}
    doPartitionActions (key, propState) = do
        let (peers, ps, propState2) = P.actions propState
        mapM_ (perform ps) (Set.toList peers)
        return (key, propState2)
      where
        perform ps peer =
          send peer (PartitionMerge key ps)


{- |
  Work out whether this node has a rebalancing action to take, given the
  set of known peers and the current key distribution, and apply it to
  the cluster state if so. An `Invite` makes this node claim
  participation in the invited key set.
-}
rebalance :: (LegionConstraints i o s) => Legionary i o s -> StateM i o s ()
rebalance _ = do
  ns@NodeState {self, cluster} <- getS
  let
    action =
      rebalanceAction
        self
        (Map.keysSet (getPeers cluster))
        (getDistribution cluster)
    rebalanced = case action of
      Just (Invite ks) -> claimParticipation self ks cluster
      Nothing -> cluster
  $(logDebug) . pack $ "The rebalance action is: " ++ show action
  putS ns {cluster = rebalanced}


{- |
  This is the type of input accepted by the legionary state machine.

  - `P`: a message received from another peer.
  - `R`: a user request; the state machine reads it as a key and a
    request value, paired with the action used to deliver the response.
  - `J`: a join request, paired with the action used to deliver the
    join response.
  - `A`: a message from the admin service.
-}
data LInput i o s
  = P (PeerMessage i o s)
  | R (RequestMsg i o)
  | J (JoinRequest, JoinResponse -> LIO ())
  | A (AdminMessage i o s)

{-
  Render an input for logging. The response callbacks cannot be shown,
  so they are rendered as an underscore.
-}
instance (Show i, Show o, Show s) => Show (LInput i o s) where
  show = \case
    P m -> concat ["(P ", show m, ")"]
    R ((p, i), _) -> concat ["(R ((", show p, ", ", show i, "), _))"]
    J (jr, _) -> concat ["(J (", show jr, ", _))"]
    A a -> concat ["(A (", show a, "))"]


{- |
  This is the type of output produced by the legionary state machine.

  - `Send`: a message addressed to the given peer.
  - `NewPeers`: a map from peers to their addresses; it is emitted with
    the peers of the new cluster state when a join request is handled.
-}
data LOutput i o s
  = Send Peer (PeerMessage i o s)
  | NewPeers (Map Peer BSockAddr)


{- | Write the `Show` rendering of the node state to the debug log. -}
logNodeState :: (LegionConstraints i o s) => NodeState i o s -> LIO ()
logNodeState ns =
  $(logDebug) (pack ("The current node state is: " ++ show ns))


{- |
  Like `getState`, but in LIO, and never comes back empty: when nothing
  is persisted for the key, a new propagation state is created whose
  default participants are the owners of the key according to the
  cluster state.
-}
getStateL :: (ApplyDelta i s, Default s)
  => Persistence i s
  -> Peer
  -> ClusterPropState
  -> PartitionKey
  -> LIO (PartitionPropState i s)

getStateL p self cluster key = do
  stored <- lift (getState p key)
  return $ case stored of
    Nothing -> P.new key self (C.findPartition key cluster)
    Just powerState -> P.initProp self powerState


{- | `saveState`, lifted into LIO. -}
saveStateL
  :: Persistence i s
  -> PartitionKey
  -> Maybe (PartitionPowerState i s)
  -> LIO ()
saveStateL p k powerState = lift (saveState p k powerState)


{- | `list`, with the resulting source lifted into LIO. -}
listL :: Persistence i s -> Source LIO (PartitionKey, PartitionPowerState i s)
listL = transPipe lift . list


{- |
  This is the type of a join request message. It carries a socket
  address, which the join handler registers in the cluster state as the
  address of the newly created peer.
-}
data JoinRequest = JoinRequest BSockAddr
  deriving (Generic, Show)
instance Binary JoinRequest


{- |
  The response to a `JoinRequest` message.

  - `JoinOk`: the join handler fills this with the peer identity it
    allocated for the requester and the power state of the cluster after
    the join.
  - `JoinRejected`: carries a string, presumably the reason for the
    rejection (it is never constructed in this module -- confirm).
-}
data JoinResponse
  = JoinOk Peer ClusterPowerState
  | JoinRejected String
  deriving (Generic)
instance Binary JoinResponse


{- |
  The type of messages sent by the admin service. Each message carries
  the action used to deliver the answer.

  - `GetState`: ask for the whole node state.
  - `GetPart`: ask for the persisted power state of one partition; the
    answer is `Nothing` when the persistence layer has nothing for the
    key.
-}
data AdminMessage i o s
  = GetState (NodeState i o s -> LIO ())
  | GetPart PartitionKey (Maybe (PartitionPowerState i s) -> LIO ())

{-
  Render an admin message for logging. The reply callbacks cannot be
  shown, so they are rendered as an underscore.
-}
instance Show (AdminMessage i o s) where
  show = \case
    GetState _ -> "(GetState _)"
    GetPart k _ -> "(GetPart " ++ show k ++ " _)"


{- | Defines the local state of a node in the cluster.  -}
data NodeState i o s = NodeState {
             {- The identity of this node; used as the source of sent messages. -}
             self :: Peer,
          {- This node's view of the cluster, with propagation bookkeeping. -}
          cluster :: ClusterPropState,
        {- Responders for forwarded requests that are awaiting a response. -}
        forwarded :: Forwarded o,
       {- In-memory partition propagation states, by partition key. -}
       propStates :: Map PartitionKey (PartitionPropState i s),
        {- Keys whose partitions are waiting to be processed by `migrate`. -}
        migration :: KeySet,
           {- The id that will be given to the next message this node creates. -}
           nextId :: MessageId
  }
  deriving (Show)


{- |
  A set of forwarded messages: for the id of each request that was
  forwarded to a peer, the action that delivers the eventual response.
-}
newtype Forwarded o = F {unF :: Map MessageId (o -> LIO ())}
{- The response actions cannot be shown, so only the message ids are. -}
instance Show (Forwarded o) where
  show = show . Map.keys . unF


{- |
  The type of messages sent to us from other peers.
-}
data PeerMessage i o s = PeerMessage {
       {- The peer that created the message. -}
       source :: Peer,
    {- The id assigned to the message by the peer that created it. -}
    messageId :: MessageId,
      {- The content of the message. -}
      payload :: PeerMessagePayload i o s
  }
  deriving (Generic, Show)
instance (Binary i, Binary o, Binary s) => Binary (PeerMessage i o s)


{- |
  The data contained within a peer message.

  - `PartitionMerge`: the power state of one partition, to be merged
    into the receiver's state for that partition key.
  - `ForwardRequest`: a user request for the given partition key,
    forwarded to a peer for handling.
  - `ForwardResponse`: the response to a forwarded request; the message
    id is the id of the `ForwardRequest` message being answered.
  - `ClusterMerge`: the cluster power state, to be merged into the
    receiver's cluster state.

  When we get around to implementing durability and data replication,
  the sustained inability to confirm that a node has received one of
  these messages should result in the ejection of that node from the
  cluster and the blacklisting of that node so that it can never re-join.
-}
data PeerMessagePayload i o s
  = PartitionMerge PartitionKey (PartitionPowerState i s)
  | ForwardRequest PartitionKey i
  | ForwardResponse MessageId o
  | ClusterMerge ClusterPowerState
  deriving (Generic, Show)
instance (Binary i, Binary o, Binary s) => Binary (PeerMessagePayload i o s)


{- |
  The identifier of a peer message: the UUID of the sequence the id
  belongs to, and the position of the id within that sequence. A new
  sequence starts at position 0 and `next` moves to the following
  position.
-}
data MessageId = M UUID Word64 deriving (Generic, Show, Eq, Ord)
instance Binary MessageId


{- |
  Produce the message id that follows the given one in its sequence:
  the sequence UUID is kept and the counter goes up by one. `succ` would
  be the usual name for this, but an `Enum` instance for `MessageId`
  really isn't appropriate.
-}
next :: MessageId -> MessageId
next (M uuid counter) = M uuid (counter + 1)


{- |
  Start a fresh sequence of message ids: a newly generated UUID paired
  with a counter of zero.
-}
newSequence ::  LIO MessageId
newSequence = lift (getUUID >>= \sid -> return (M sid 0))


{- |
  Build the initial state of a node from its identity and cluster state:
  a fresh message id sequence, nothing forwarded, no partition states in
  memory and nothing waiting to be migrated.
-}
newNodeState :: Peer -> ClusterPropState -> LIO (NodeState i o s)
newNodeState self cluster = initial <$> newSequence
  where
    initial nextId = NodeState {
        self,
        cluster,
        forwarded = F Map.empty,
        propStates = Map.empty,
        migration = KS.empty,
        nextId
      }


{- |
  Emit a message for the given peer, carrying the given payload, with
  this node as its source and the next available message id; then advance
  the message id in the node state.
-}
send :: Peer -> PeerMessagePayload i o s -> StateM i o s ()
send peer payload = do
  ns <- getS
  let msgId = nextId ns
  emit . Send peer $ PeerMessage {
      source = self ns,
      messageId = msgId,
      payload
    }
  putS ns {nextId = next msgId}


{- |
  Hand a user request over to a peer. A `ForwardRequest` message with
  the next available message id is emitted for the peer, and the response
  action is filed under that id so that it can be run when the matching
  `ForwardResponse` arrives. The message id is advanced afterwards.
-}
forward
  :: Peer
  -> PartitionKey
  -> i
  -> (o -> IO ())
  -> StateM i o s ()
forward peer key request respond = do
  ns <- getS
  let
    msgId = nextId ns
    outgoing = PeerMessage {
        source = self ns,
        messageId = msgId,
        payload = ForwardRequest key request
      }
  emit (Send peer outgoing)
  putS ns {
      nextId = next msgId,
      forwarded = F (insert msgId (lift . respond) (unF (forwarded ns)))
    }


{- |
  The monad in which the internals of the state machine run. This is really
  just a conduit, but we wrap it because we only want to allow `yield`, which
  we have re-named `emit`.

  Concretely it is a `StateT` holding the `NodeState`, layered over a
  conduit that consumes `LInput` and produces `LOutput` in the base
  monad `m`.
-}
newtype StateMT i o s m r = StateMT {
    unStateMT ::
      StateT
        (NodeState i o s)
        (ConduitM (LInput i o s) (LOutput i o s) m)
        r
  } deriving (
    Functor, Applicative, Monad, MonadLogger, MonadCatch,
    MonadThrow, MonadIO
  )
{-
  Lifting from the base monad has to cross two layers: the conduit and
  then the `StateT`.
-}
instance MonadTrans (StateMT i o s) where
  lift = StateMT . lift . lift


{- |
  The state machine monad, in LIO: `StateMT` with `LIO` fixed as the
  base monad.
-}
type StateM i o s r = StateMT i o s LIO r


{- |
  Run a state machine action as a plain conduit, starting from the given
  node state. The action's own result is unit; what comes back is the
  node state the action finished with.
-}
runStateMT
  :: NodeState i o s
  -> StateMT i o s m ()
  -> ConduitM (LInput i o s) (LOutput i o s) m (NodeState i o s)
runStateMT ns machine = snd <$> runStateT (unStateMT machine) ns


{- |
  Produce one output value from the state machine, by yielding it on the
  underlying conduit.
-}
emit :: LOutput i o s -> StateM i o s ()
emit output = StateMT (lift (yield output))


{- |
  Get the current node state from the underlying `StateT` layer.
-}
getS :: StateMT i o s m (NodeState i o s)
getS = StateMT get


{- |
  Replace the node state held by the underlying `StateT` layer.
-}
putS :: NodeState i o s -> StateMT i o s m ()
putS ns = StateMT (put ns)