{-# LANGUAGE DeriveGeneric #-}
{-# LANGUAGE GeneralizedNewtypeDeriving #-}
{-# LANGUAGE LambdaCase #-}
{-# LANGUAGE NamedFieldPuns #-}
{-# LANGUAGE OverloadedStrings #-}
{-# LANGUAGE TemplateHaskell #-}

{- | This module contains low-level state types. -}
module Network.Legion.Runtime.State (
  -- * Runtime State
  RuntimeState,
  makeRuntimeState,

  -- * Runtime Monads
  RuntimeT,
  runRuntimeT,
  runConcurrentT,

  -- * Runtime Monad Operations
  updateRecvClock,
  joinCluster,
  eject,
  getDivergent,
  userRequest,
  forwardedRequest,
  forwardResponse,
  clusterMerge,
  getCM,
  searchDispatch,
  search,
  searchResponse,
  joinNext,
  joinNextResponse,
  partitionMerge,

  -- * Debug Monad Operations
  debugIndex,
  debugRuntimeState,
  debugLocalPartitions,
  debugPartition,

  -- * Other Types
  StartupMode(..),
  JoinRequest(..),
  JoinResponse(..),
  UserResponse(..),
) where


import Control.Concurrent.STM (TVar, atomically, readTVar, newTVar,
  STM, writeTVar, modifyTVar)
import Control.Monad (unless, void)
import Control.Monad.Catch (throwM, MonadThrow, MonadCatch)
import Control.Monad.IO.Class (MonadIO, liftIO)
import Control.Monad.Logger (MonadLoggerIO, MonadLogger, logInfo,
  logError, logWarn, logDebug)
import Control.Monad.Trans.Class (lift, MonadTrans)
import Control.Monad.Trans.Reader (ReaderT, runReaderT, ask)
import Control.Monad.Trans.State (StateT, runStateT, modify, get, put)
import Data.Aeson (Value, toJSON, ToJSON)
import Data.Binary (encode, Binary)
import Data.Bool (bool)
import Data.Conduit ((.|), await, transPipe, runConduit)
import Data.Conduit.Network (sourceSocket)
import Data.Conduit.Serialization.Binary (conduitDecode)
import Data.Default.Class (Default)
import Data.Map (Map)
import Data.Maybe (catMaybes, fromMaybe)
import Data.Set (Set, (\\))
import Data.String (IsString, fromString)
import Data.Text (Text)
import Data.Time (UTCTime, getCurrentTime)
import GHC.Generics (Generic)
import Network.Legion.Application (Persistence, list, saveCluster,
  getState, saveState)
import Network.Legion.BSockAddr (BSockAddr(BSockAddr))
import Network.Legion.ClusterState (ClusterPowerState, RebalanceOrd,
  ClusterPowerStateT)
import Network.Legion.Distribution (Peer, newPeer)
import Network.Legion.Fork (ForkM, forkM)
import Network.Legion.Index (IndexRecord(IndexRecord),
  SearchTag(SearchTag), Indexable, indexEntries, stTag, stKey, irTag,
  irKey)
import Network.Legion.KeySet (KeySet)
import Network.Legion.PartitionKey (PartitionKey)
import Network.Legion.PartitionState (PartitionPowerState,
  PartitionPowerStateT)
import Network.Legion.PowerState (Event)
import Network.Legion.PowerState.Monad (PropAction(Send, DoNothing))
import Network.Legion.Runtime.ConnectionManager (ConnectionManager,
  newConnectionManager)
import Network.Legion.Runtime.PeerMessage (MessageId, newSequence,
  PeerMessagePayload(ClusterMerge, PartitionMerge, JoinNext,
  ForwardRequest, ForwardResponse, Search), PeerMessage(PeerMessage),
  source, messageId, payload)
import Network.Legion.Settings (RuntimeSettings(RuntimeSettings,
  peerBindAddr))
import Network.Legion.SocketUtil (fam)
import Network.Legion.UUID (getUUID)
import Network.Socket (SocketType(Stream), defaultProtocol, socket,
  SockAddr, connect)
import Network.Socket.ByteString.Lazy (sendAll)
import System.IO (stderr, hPutStrLn)
import qualified Data.Aeson as A
import qualified Data.ByteString.Lazy.Char8 as BSL8
import qualified Data.Conduit.List as CL
import qualified Data.Map as Map
import qualified Data.Set as Set
import qualified Data.Text as T
import qualified Network.Legion.ClusterState as C
import qualified Network.Legion.Distribution as D
import qualified Network.Legion.KeySet as KS
import qualified Network.Legion.PowerState as PS
import qualified Network.Legion.PowerState.Monad as PM
import qualified Network.Legion.Runtime.ConnectionManager as CM


{- | The state of the runtime system. -}
data RuntimeState e o s m = RuntimeState {
             self :: Peer,
          cluster :: ClusterPowerState,
       partitions :: Map PartitionKey (TVar (PartitionWorkerState e o s m)),
         rtsIndex :: Set IndexRecord,
            joins :: Map Peer KeySet,
                     {- ^ Outstanding joins. -}
    lastRebalance :: RebalanceOrd,
        forwarded :: Map MessageId (o -> m ()),
           nextId :: MessageId,
               cm :: ConnectionManager e o s,
        recvClock :: Map Peer (Maybe UTCTime),
                     {- ^ The time we last received a message from each peer. -}
         searches :: Map
                       SearchTag
                       (
                         Set Peer,
                         Maybe IndexRecord,
                         [Maybe IndexRecord -> m ()]
                       )
                     {- ^
                       A map of currently dispatched searches. The values
                       are the peers from which we are still expecting
                       a result, best result so far, and the responders
                       to which to send the eventual best result.

                       The 'searches' field is a little weird.

                       It turns out that searches are deterministic
                       over the parameters of 'SearchTag' and cluster
                       state. This should make sense, because everything
                       in Haskell is deterministic given __all__
                       the parameters. Since the cluster state only
                       changes over time, searches that happen "at the
                       same time" and for the same 'SearchTag' can be
                       considered identical. I don't think it is too
                       much of a stretch to say that searches that have
                       overlapping execution times can be considered
                       to be happening "at the same time", therefore
                       the search tag becomes the determining factor in the
                       result of the search.

                       This is a long-winded way of justifying the fact
                       that, if we are currently executing a search and
                       an identical search request arrives, then the
                       second identical search is just piggy-backed on the
                       results of the currently executing search. Whether
                       this counts as a premature optimization hack
                       or a beautifully elegant expression of platonic
                       reality is left as an exercise for the reader. It
                       does help simplify the code a little bit because
                       we don't have to specify some kind of UUID to
                       differentiate otherwise identical searches.
                     -}
  }
instance Show (RuntimeState e o s m) where
  show = BSL8.unpack . A.encode
instance ToJSON (RuntimeState e o s m) where
  toJSON _ = toJSON ("RuntimeState" :: Text)


{- | The state of an individual asynchronous partition worker. -}
data PartitionWorkerState e o s m = PWS {
             pwsCm :: ConnectionManager e o s,
            pwsKey :: PartitionKey,
           pwsSelf :: Peer,
    pwsPersistence :: Persistence e o s,
       pwsCacheVal :: Maybe (PartitionPowerState e o s),
       pwsJobQueue :: [(ClusterPowerState, PartitionPowerStateT e o s m ())]
  }


{- | The Runtime Monad Transformer. -}
newtype RuntimeT e o s m a = RuntimeT {
    unRuntimeT :: StateT (RuntimeState e o s m) (ReaderT (Persistence e o s) m) a
  }
  deriving (Functor, Applicative, Monad, MonadIO, MonadLogger, MonadThrow)
instance MonadTrans (RuntimeT e o s) where
  lift = RuntimeT . lift . lift


{- | Execute a 'RuntimeT'.  -}
runRuntimeT
  :: Persistence e o s
  -> RuntimeState e o s m
  -> RuntimeT e o s m a
  -> m (a, RuntimeState e o s m)
runRuntimeT persistence rts =
  (`runReaderT` persistence) . (`runStateT` rts) . unRuntimeT
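
{-
  A minimal sketch of running a single runtime operation (illustrative
  only, assuming 'persistence' and a 'RuntimeState' 'rts' previously
  built by 'makeRuntimeState', plus whatever constraints the operation
  itself requires):

  > (index, rts') <- runRuntimeT persistence rts debugIndex
-}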


{- | Initialize the runtime state. -}
makeRuntimeState :: (
      Binary e,
      Binary o,
      Binary s,
      Event e o s,
      ForkM m,
      Indexable s,
      MonadCatch m,
      MonadLoggerIO m
    )
  => Persistence e o s
  -> RuntimeSettings
  -> StartupMode
  -> m (RuntimeState e o s m)

makeRuntimeState
    persistence
    settings@RuntimeSettings {peerBindAddr}
    NewCluster
  = do
    {- Build a brand new node state, for the first node in a cluster. -}
    verifyClearPersistence persistence
    self <- newPeer
    clusterId <- getUUID
    let
      cluster = C.new clusterId self peerBindAddr
    makeRuntimeState persistence settings (Recover self cluster)

makeRuntimeState
    persistence
    settings@RuntimeSettings {peerBindAddr}
    (JoinCluster addr)
  = do
    {-
      Join an existing cluster by requesting membership from the peer
      at the given address, then finish initialization via the
      'Recover' code path.
    -}
    verifyClearPersistence persistence
    $(logInfo) "Trying to join an existing cluster."
    (self, cluster) <- requestJoin (JoinRequest (BSockAddr peerBindAddr))
    makeRuntimeState persistence settings (Recover self cluster)
  where
    requestJoin :: (MonadLoggerIO io)
      => JoinRequest
      -> io (Peer, ClusterPowerState)
    requestJoin joinMsg = liftIO $ do
      so <- socket (fam addr) Stream defaultProtocol
      connect so addr
      sendAll so (encode joinMsg)
      {-
        using sourceSocket and conduitDecode is easier than building
        a receive/decode state loop, even though we only read a single
        response.
      -}
      runConduit $ sourceSocket so .| conduitDecode .| do
        response <- await
        case response of
          Nothing -> fail
            $ "Couldn't join a cluster because there was no response "
            ++ "to our join request!"
          Just (JoinOk self cps) ->
            return (self, cps)

makeRuntimeState persistence _ (Recover self cluster) = do
    {- Make sure to rebuild the index in the case of recovery. -}
    rtsIndex <- runConduit . transPipe liftIO $
      list persistence
      .| CL.fold addIndexRecords Set.empty
    firstMessageId <- newSequence
    cm <- newConnectionManager self (C.getPeers cluster)
    liftIO $ saveCluster persistence self cluster
    return RuntimeState {
        self,
        cluster,
        partitions = Map.empty,
        rtsIndex,
        joins = Map.empty,
        lastRebalance = minBound,
        forwarded = Map.empty,
        nextId = firstMessageId,
        cm,
        recvClock = Map.empty,
        searches = Map.empty
      }
  where
    addIndexRecords :: (Indexable s, Event e o s)
      => Set IndexRecord
      -> (PartitionKey, PartitionPowerState e o s)
      -> Set IndexRecord
    addIndexRecords index (key, partition) =
      let
        newRecords =
          Set.map
            (`IndexRecord` key)
            (indexEntries (PS.projectedValue partition))
      in Set.union index newRecords


{- | This defines the various ways a node can be spun up. -}
data StartupMode
  = NewCluster
    {- ^ Indicates that we should bootstrap a new cluster at startup. -}
  | JoinCluster SockAddr
    {- ^ Indicates that the node should try to join an existing cluster. -}
  | Recover Peer ClusterPowerState
    {- ^
      Recover from a crash as the given peer, using the given cluster
      state.
    -}
  deriving (Show, Eq)
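
{-
  Illustrative only: how the three modes are typically used with
  'makeRuntimeState'. The 'persistence', 'settings', address, and saved
  values below are placeholders supplied by the embedding application,
  not names defined in this module.

  > makeRuntimeState persistence settings NewCluster
  > makeRuntimeState persistence settings (JoinCluster existingPeerAddr)
  > makeRuntimeState persistence settings (Recover savedSelf savedCluster)
-}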


{- | This is the type of a join request message. -}
newtype JoinRequest = JoinRequest BSockAddr
  deriving (Generic, Show)
instance Binary JoinRequest


{- | The response to a JoinRequest message -}
data JoinResponse
  = JoinOk Peer ClusterPowerState
  deriving (Generic)
instance Binary JoinResponse


{- |
  Helper for 'makeRuntimeState'. Verify that there is nothing in the
  persistence layer.
-}
verifyClearPersistence :: (MonadLoggerIO io) => Persistence e o s -> io ()
verifyClearPersistence persistence = 
  liftIO (runConduit (list persistence .| CL.head)) >>= \case
    Just _ -> do
      let
        msg :: (IsString a) => a
        msg = fromString
          $ "We are trying to start up a new peer, but the persistence "
          ++ "layer already has data in it. This is an invalid state. "
          ++ "New nodes must be started from a totally clean, empty state."
      $(logError) msg
      liftIO $ do
        hPutStrLn stderr msg
        putStrLn msg
        error msg
    Nothing ->
      return ()


{- | Update the time when we last received a message from a peer. -}
updateRecvClock :: (MonadIO m) => Peer -> RuntimeT e o s m ()
updateRecvClock peer = RuntimeT $ do
  now <- liftIO getCurrentTime
  modify (\rts@RuntimeState {recvClock} ->
      let
        newRecvClock = Map.insert peer (Just now) recvClock
      in newRecvClock `seq` rts {
          recvClock = newRecvClock
        }
    )


{- | Return the current index for debugging. -}
debugIndex :: (Monad m) => RuntimeT e o s m (Set IndexRecord)
debugIndex = RuntimeT $ rtsIndex <$> get


{- | Return the runtime state for debugging. -}
debugRuntimeState :: (Monad m) => RuntimeT e o s m Value
debugRuntimeState = toJSON <$> RuntimeT get


{- | Return all of the local partitions for debugging. -}
debugLocalPartitions :: (MonadIO m)
  => RuntimeT e o s m (Map PartitionKey (PartitionPowerState e o s))
debugLocalPartitions = do
  persistence <- RuntimeT (lift ask)
  Map.fromList <$> runConduit (
      transPipe liftIO (list persistence)
      .| CL.consume
    )


{- | Return a specific partition, for debugging. -}
debugPartition :: (MonadIO m)
  => PartitionKey
  -> RuntimeT e o s m (Maybe (PartitionPowerState e o s))
debugPartition key = RuntimeT $ do
  persistence <- lift ask
  liftIO (getState persistence key)


{- | Let a new peer join the cluster. -}
joinCluster :: (MonadIO m, MonadThrow m)
  => BSockAddr
  -> RuntimeT e o s m (Peer, ClusterPowerState)
joinCluster addr = do
  peer <- newPeer
  runClusterPowerStateT (C.joinCluster peer addr)
  cluster <- getCluster
  return (peer, cluster)


{- | Eject a peer from the cluster. -}
eject :: (MonadIO m, MonadThrow m) => Peer -> RuntimeT e o s m ()
eject peer = do
  {-
    We need to think very hard about the split-brain problem. In
    particular, we should consider the extreme case where
    the network just fails completely and every node believes that every
    other node should be or has been ejected. This would obviously be
    catastrophic in terms of data durability unless we have some way to
    reintegrate an ejected node. So, either we have to guarantee that
    such a situation can never happen, or else implement a reintegration
    strategy.  It might be acceptable for the reintegration strategy to
    be very costly if it is characterized as an extreme recovery scenario.

    Question: would a reintegration strategy become less costly if the
    "next state id" for a peer were global across all power states
    instead of local to each power state?
  -}
  runClusterPowerStateT (C.eject peer)
  {-
    'runClusterPowerStateT (C.eject peer)' will cause us to attempt to
    notify the peer that they have been ejected, but that notification
    is almost certainly going to go unacknowledged because the peer
    is probably down.
    
    This call to 'eject' was presumably invoked as a result of user
    action, and we must therefore trust the user to know that the peer
    is really down and not coming back. This "guarantee" allows us to
    acknowledge the ejection on the peer's behalf.

    This call will acknowledge the drop on behalf of the peer, and also
    remove that peer from the keyspace distribution map.
  -}
  runClusterPowerStateTAs peer (return ())


{- |
  Get the peers that the local node thinks are divergent, along with
  the time we last received a message from each of them.
-}
getDivergent :: (MonadIO m) => RuntimeT e o s m (Map Peer (Maybe UTCTime))
getDivergent = RuntimeT $ do
    RuntimeState {recvClock, partitions} <- get
    diverging <- lift . lift $ divergentPeers partitions
    return $ Map.fromAscList [
        (peer, r)
        | (peer, r) <- Map.toAscList recvClock
        , peer `Set.member` diverging
      ]
  where
    divergentPeers :: (MonadIO m)
      => Map PartitionKey (TVar (PartitionWorkerState e o s m))
      -> m (Set Peer)
    divergentPeers partitions = liftIO $
      foldr Set.union Set.empty . catMaybes <$> sequence [
          fmap PS.divergent . pwsCacheVal <$> atomically (readTVar tvar)
          | (_key, tvar) <- Map.toList partitions
        ]


{- | Handle a user request, and pass the response to the continuation. -}
userRequest :: (
      Default s,
      Eq e,
      Event e o s,
      ForkM m,
      Indexable s,
      MonadCatch m,
      MonadLoggerIO m,
      Show e,
      Show s
    )
  => PartitionKey
  -> e
  -> (o -> m ())
  -> RuntimeT e o s m ()
userRequest key request k = do
  RuntimeState {self, cm} <- RuntimeT get
  route key >>= \case
    p | p == self -> 
      runConcurrentT key (
          lift . k =<< PM.event request
        )
    p -> do
      messageId <- CM.send cm p (ForwardRequest key request)
      (RuntimeT . modify) (\rts@RuntimeState {forwarded} -> rts {
          forwarded = Map.insert messageId k forwarded
        })
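
{-
  Illustrative sketch only: turning the continuation-passing interface
  of 'userRequest' into a blocking call. The 'MVar' plumbing is an
  assumption of this example, not something this module provides. The
  continuation fires once the event has been applied locally, or when
  a forwarded response arrives (see 'forwardResponse').

  > responseVar <- liftIO newEmptyMVar
  > userRequest key request (liftIO . putMVar responseVar)
-}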


{- | Handle a forwarded request. -}
forwardedRequest :: (
      Default s,
      Eq e,
      Event e o s,
      ForkM m,
      MonadCatch m,
      MonadLoggerIO m
    )
  => Peer
  -> MessageId
  -> PartitionKey
  -> e
  -> RuntimeT e o s m ()
forwardedRequest source messageId key event = do
  RuntimeState {self, cm} <- RuntimeT get
  route key >>= \case
    p | p == self ->
      runConcurrentT key (do
          o <- PM.event event
          (void . lift) (CM.send cm source (ForwardResponse messageId o))
        )
    p ->
      {-
        No need to keep track of the forwarded message, we just need to
        reconstruct the original message and send it on its way.

        TODO: think about implementing cycle detection. Cycles should
        not exist, so if we detect one then that indicates a bug. So
        really this is more of an opportunity for bug detection.
      -}
      CM.forward cm p PeerMessage {
          source,
          messageId,
          payload = ForwardRequest key event
        }


{- | Find the route for a user request. -}
route :: (MonadLogger m)
  => PartitionKey
  -> RuntimeT e o s m Peer
route key = RuntimeT $ do
  RuntimeState {self, cluster} <- get
  let routes = C.findRoute key cluster
  if self `Set.member` routes
    then return self
    else case Set.toList routes of
      [] -> do
        let msg = "No routes for key: " ++ show key
        $(logError) . T.pack $ msg
        error msg
      peer:_ -> return peer


{- | Receive a response to a forwarded user request. -}
forwardResponse :: (MonadLoggerIO m, Show o)
  => MessageId
  -> o
  -> RuntimeT e o s m ()
forwardResponse forMessageId output = do
  rts@RuntimeState{forwarded} <- RuntimeT get
  let (r, fwd) = lookupAndDelete forMessageId forwarded
  RuntimeT $ put rts {forwarded = fwd}
  case r of
    Nothing ->
      $(logWarn) . T.pack
        $ "Received unexpected forward response: "
        ++ show (forMessageId, output)
    Just respond ->
      lift (respond output)


{- | Merge a foreign cluster power state. -}
clusterMerge :: (MonadIO m, MonadThrow m)
  => ClusterPowerState
  -> RuntimeT e o s m ()
clusterMerge cluster =
  runClusterPowerStateT (PM.merge cluster)


{- | Return the handle to the connection manager. -}
getCM :: (Monad m) => RuntimeT e o s m (ConnectionManager e o s)
getCM = RuntimeT $ cm <$> get


{- | Dispatch a distributed search request. -}
searchDispatch :: (MonadIO m)
  => SearchTag
  -> (Maybe IndexRecord -> m ())
  -> RuntimeT e o s m ()
searchDispatch searchTag k =
    Map.lookup searchTag . searches <$> RuntimeT get >>= \case
      Nothing -> do
        {-
          No identical search is currently being executed; kick off a
          new one.
        -}
        mcss <- minimumCompleteServiceSet
        mapM_ sendSearch (Set.toList mcss)
        (RuntimeT . modify) (\rts@RuntimeState {searches} -> rts {
            searches = Map.insert
              searchTag
              (mcss, Nothing, [k])
              searches
          })
      Just (peers, best, responders) ->
        {-
          A search for this tag is already in progress; just add the
          responder to the responder list.
        -}
        (RuntimeT . modify) (\rts@RuntimeState {searches} -> rts {
            searches = Map.insert
              searchTag
              (peers, best, k:responders)
              searches
          })
  where
    sendSearch :: (MonadIO m)
      => Peer
      -> RuntimeT e o s m ()
    sendSearch peer = do
      cm <- getCM
      void $ CM.send cm peer (Search searchTag)
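
{-
  Illustrative sketch only: registering a continuation for a dispatched
  search. 'someTag' and the 'MVar' plumbing are assumptions of this
  example. The continuation is invoked from 'searchResponse' once every
  dispatched peer has answered; if an identical search is already in
  flight, the continuation simply piggy-backs on it (see the 'searches'
  field of 'RuntimeState').

  > resultVar <- liftIO newEmptyMVar
  > searchDispatch
  >   SearchTag {stTag = someTag, stKey = Nothing}
  >   (liftIO . putMVar resultVar)
-}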


{- |
  Search the index, and return the first record that is __strictly
  greater than__ the provided search tag, if such a record exists.
-}
search :: (Monad m)
  => SearchTag
  -> RuntimeT e o s m (Maybe IndexRecord)

search SearchTag {stTag, stKey = Nothing} = RuntimeT $ do
  index <- rtsIndex <$> get
  return (Set.lookupGE IndexRecord {irTag = stTag, irKey = minBound} index)

search SearchTag {stTag, stKey = Just key} = RuntimeT $ do
  index <- rtsIndex <$> get
  return (Set.lookupGT IndexRecord {irTag = stTag, irKey = key} index)
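
{-
  For example (illustrative, assuming index records are ordered by tag
  and then by key, with 't' and 'k' as placeholders):

  > search SearchTag {stTag = t, stKey = Nothing} -- first record with tag >= t
  > search SearchTag {stTag = t, stKey = Just k}  -- first record after (t, k)
-}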


{- | Handle an incoming search response. -}
searchResponse :: (MonadLogger m)
  => Peer
  -> SearchTag
  -> Maybe IndexRecord
  -> RuntimeT e o s m ()

searchResponse source searchTag response =
    {- TODO: see if this function can't be made more elegant. -}
    Map.lookup searchTag . searches <$> RuntimeT get >>= \case
      Nothing ->
        {- There is no search happening. -}
        $(logWarn) . T.pack
          $ "Unsolicited SearchResponse: "
          ++ show (source, searchTag, response)
      Just (peers, best, responders) ->
        if source `Set.member` peers
          then
            let peers2 = Set.delete source peers
            in if null peers2
              then do
                {-
                  All peers have responded; go ahead and respond to
                  the client.
                -}
                lift $ mapM_ ($ bestOf best response) responders
                rts@RuntimeState {searches} <- RuntimeT get
                (RuntimeT . put) rts {searches = Map.delete searchTag searches}
              else do
                {- We are still waiting on some outstanding requests. -}
                rts@RuntimeState {searches} <- RuntimeT get
                (RuntimeT . put) rts {
                    searches = Map.insert
                      searchTag
                      (peers2, bestOf best response, responders)
                      searches
                  }
          else
            {-
              There is a search happening, but the peer that responded
              is not part of it.
            -}
            $(logWarn) . T.pack
              $ "Unsolicited SearchResponse: "
              ++ show (source, searchTag, response)
  where
    {- |
      Figure out which index record returned to us by the various peers
      is the most appropriate to return to the user. This is mostly like
      'min' but we can't use 'min' (or fancy applicative formulations)
      because we want to favor 'Just' instead of 'Nothing'.
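
      For example (illustrative):

      > bestOf (Just a) (Just b) == Just (min a b)
      > bestOf (Just a) Nothing  == Just a
      > bestOf Nothing  (Just b) == Just b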
    -}
    bestOf :: Maybe IndexRecord -> Maybe IndexRecord -> Maybe IndexRecord
    bestOf (Just a) (Just b) = Just (min a b)
    bestOf Nothing b = b
    bestOf a Nothing = a


{- |
  Allow a peer to participate in the replication of the partition with the
  __minimum__ key that is within the indicated partition key set. Calls
  the continuation with @Nothing@ if there is no such partition, or @Just
  (key, partition)@ where @key@ is the partition key that was joined
  and @partition@ is the resulting partition power state.
-}
joinNext :: (
      Default s,
      Eq e,
      Event e o s,
      ForkM m,
      MonadCatch m,
      MonadLoggerIO m
    )
  => Peer
  -> KeySet
  -> (Maybe (PartitionKey, PartitionPowerState e o s) -> m ())
  -> RuntimeT e o s m ()
joinNext peer askKeys k = do
  persistence <- RuntimeT (lift ask)
  (lift . runConduit) (
      transPipe liftIO (list persistence)
      .| CL.filter ((`KS.member` askKeys) . fst)
      .| CL.head
    ) >>= \case
      Nothing -> lift (k Nothing)
      Just (gotKey, _) ->
        runConcurrentT gotKey (do
            PM.participate peer
            PM.acknowledge
            partition <- PM.getPowerState
            lift (k (Just (gotKey, partition)))
          )


{- | Receive the result of a JoinNext request. -}
joinNextResponse :: (
      Default s,
      Eq e,
      Event e o s,
      ForkM m,
      MonadLoggerIO m,
      MonadCatch m,
      Show e,
      Show s
    )
  => Peer
  -> Maybe (PartitionKey, PartitionPowerState e o s)
  -> RuntimeT e o s m ()
joinNextResponse peer response = do
  RuntimeState {cluster, lastRebalance} <- RuntimeT get
  if lastRebalance > fst (C.nextAction cluster)
    then
      {- We are receiving messages from an old rebalance. Log and ignore. -}
      $(logWarn) . T.pack
        $ "Received an old join response: "
        ++ show (peer, response, cluster, lastRebalance)
    else do
      case response of
        Just (key, partition) -> do
          partitionMerge key partition
          RuntimeState {joins, cm} <- RuntimeT get
          case (KS.\\ KS.fromRange minBound key) <$> Map.lookup peer joins of
            Nothing ->
              {- An unexpected peer sent us this message. Ignore it. -}
              $(logWarn) . T.pack
                $ "Unexpected join next: " ++ show (peer, response)
            Just needsJoinSet -> do
              unless (KS.null needsJoinSet) (
                  void $ CM.send cm peer (JoinNext needsJoinSet)
                )
              (RuntimeT . modify) (\rts -> rts {
                  joins = Map.filter
                    (not . KS.null)
                    (Map.insert peer needsJoinSet joins)
                })
        Nothing ->
          (RuntimeT . modify) (\rts@RuntimeState {joins} -> rts {
              joins = Map.delete peer joins
            })
      Map.null . joins <$> RuntimeT get >>= bool
        (return ())
        (runClusterPowerStateT C.finishRebalance)


{- | Merge a foreign partition replica with the local partition replica. -}
partitionMerge :: (
      Default s,
      Eq e,
      Event e o s,
      ForkM m,
      MonadCatch m,
      MonadLoggerIO m
    )
  => PartitionKey
  -> PartitionPowerState e o s
  -> RuntimeT e o s m ()
partitionMerge key foreignPartition =
  runConcurrentT key (PM.merge foreignPartition)


{- | Get the current cluster state. -}
getCluster :: (Monad m) => RuntimeT e o s m ClusterPowerState
getCluster = RuntimeT $ cluster <$> get


{- |
  Run a 'ClusterPowerStateT', and perform any resulting propagation
  actions.
-}
runClusterPowerStateT :: (MonadIO m, MonadThrow m)
  => ClusterPowerStateT m a
  -> RuntimeT e o s m a
runClusterPowerStateT m = do
  as <- RuntimeT $ self <$> get
  runClusterPowerStateTAs as m


{- |
  Run a clusterstate-flavored 'PowerStateT' in the 'RuntimeT' monad,
  automatically acknowledging the resulting power state.

  Generalized to run as any peer, in order to support exceptional cases
  like 'eject'.
-}
runClusterPowerStateTAs :: (MonadIO m, MonadThrow m)
  => Peer {- ^ The peer to run as. -}
  -> ClusterPowerStateT m a
  -> RuntimeT e o s m a
runClusterPowerStateTAs as m = do
  RuntimeState {cluster, self} <- RuntimeT get
  persistence <- RuntimeT (lift ask)
  lift (PM.runPowerStateT as cluster (m <* PM.acknowledge)) >>= \case
    Left err -> throwM err
    Right (a, action, cluster2, _outputs) -> do
      RuntimeT (modify (\rts -> rts {cluster = cluster2}))
      liftIO (saveCluster persistence self cluster2)
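      {-
        A 'Send' action means the updated cluster state needs to be
        propagated: broadcast it to every other participant.
        'DoNothing' means no propagation is required.
      -}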
      case action of
        Send -> sequence_ [
            getCM >>= (\cm -> CM.send cm p (ClusterMerge cluster2))
            | p <- Set.toList (PS.allParticipants cluster2)
            , p /= self
          ]
        DoNothing -> return ()
      return a


{- |
  The type of response to a user request: either forward the request
  to another node, or respond directly.
-}
data UserResponse o
  = Forward Peer
  | Respond o


{- |
  Run a partition action in a background worker thread dedicated to
  the given partition key. Concurrent actions on the same key are
  queued and executed in order.
-}
runConcurrentT :: (
      Default s,
      Eq e,
      Event e o s,
      ForkM m,
      MonadCatch m,
      MonadLoggerIO m
    )
  => PartitionKey
  -> PartitionPowerStateT e o s m ()
  -> RuntimeT e o s m ()
runConcurrentT key action_ = do
    rts@RuntimeState {partitions} <- RuntimeT get
    persistence <- RuntimeT (lift ask)
    let job = (cluster rts, action_)
    case Map.lookup key partitions of
      Nothing -> do
        tvar <- liftIO (atomically (newTVar PWS {
            pwsCm = cm rts,
            pwsKey = key,
            pwsSelf = self rts,
            pwsPersistence = persistence,
            pwsCacheVal = Nothing,
            pwsJobQueue = []
          }))
        RuntimeT $ put rts {partitions = Map.insert key tvar partitions}
        lift =<< liftIO (atomically (queueAction tvar job))
      Just tvar ->
        lift =<< liftIO (atomically (queueAction tvar job))
  where
    {- |
      Put the partition action on the execution queue, and maybe also
      start the execution thread if there isn't already one running.
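
      The head of 'pwsJobQueue' is the job currently being executed,
      so a worker is running (or has just been forked) precisely when
      the queue is non-empty; we only fork a new worker when the queue
      was empty before this job was appended.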
    -}
    queueAction :: (
          Default s,
          Eq e,
          Event e o s,
          ForkM m,
          MonadLoggerIO m
        )
      => TVar (PartitionWorkerState e o s m)
      -> (ClusterPowerState, PartitionPowerStateT e o s m ())
      -> STM (m ())
    queueAction tvar job = do
      pws@PWS {pwsJobQueue} <- readTVar tvar
      let
        forkJobWorker = return (forkM (jobWorker tvar job))
      writeTVar tvar pws {pwsJobQueue = pwsJobQueue ++ [job]}
      if null pwsJobQueue
        then forkJobWorker
        else return (return ())

    jobWorker :: (MonadLoggerIO m, Default s, Event e o s, Eq e)
      => TVar (PartitionWorkerState e o s m)
      -> (ClusterPowerState, PartitionPowerStateT e o s m ())
      -> m ()
    jobWorker tvar job =
        doJob tvar job >> nextJob >>= maybe shutdown (jobWorker tvar)
      where
        nextJob = liftIO . atomically $ do
          pws@PWS {pwsCacheVal, pwsJobQueue} <- readTVar tvar
          case pwsJobQueue of
            _:next:more -> do
              {- Pop the completed job off the queue, and promote the next one. -}
              writeTVar tvar pws {pwsJobQueue = next:more}
              return (Just next)
            [_] -> do
              {-
                All jobs are complete. Pop the finished job off the
                queue, and drop the cached partition value unless the
                partition still has divergent peers (in which case,
                presumably, more merge traffic is on the way).
              -}
              writeTVar tvar $
                case Set.null . PS.divergent <$> pwsCacheVal of
                  Just False -> pws {pwsJobQueue = []}
                  _ -> pws {pwsCacheVal = Nothing, pwsJobQueue = []}
              return Nothing
            [] ->
              {- This shouldn't happen. See about using non-empty lists. -}
              return Nothing
        shutdown = return ()

    doJob :: (MonadLoggerIO m, Default s, Event e o s, Eq e)
      => TVar (PartitionWorkerState e o s m)
      -> (ClusterPowerState, PartitionPowerStateT e o s m ())
      -> m ()
    doJob tvar (cluster, action) = do
        $(logDebug) . T.pack $ "Starting job on " ++ show key
        PWS {
            pwsCm,
            pwsSelf,
            pwsPersistence,
            pwsCacheVal
          } <- liftIO (atomically (readTVar tvar))
        partition <- case pwsCacheVal of
          Nothing ->
            fromMaybe (PS.new key (C.findOwners key cluster))
              <$> liftIO (getState pwsPersistence key)
          Just partition -> return partition
        PM.runPowerStateT pwsSelf partition (
            action <* (removeObsolete >> PM.acknowledge)
          ) >>= \case
            Left err ->
              $(logError) . T.pack
                $ "Partition error: " ++ show (err, key)
            Right ((), propAction, newPartition, _outputs) -> do
              liftIO . atomically . modifyTVar tvar $ (\pws ->
                  pws {pwsCacheVal = Just newPartition}
                )
              liftIO (saveState pwsPersistence key (Just newPartition))
              case propAction of
                Send -> sequence_ [
                    CM.send pwsCm p (PartitionMerge key newPartition)
                    | p <- Set.toList (PS.allParticipants newPartition)
                    , p /= pwsSelf
                  ]
                DoNothing -> return ()
        $(logDebug) . T.pack $ "Finished job on " ++ show key
      where
        {- |
          Remove obsolete peers. Obsolete peers are peers that are no longer
          participating in the replication of this partition, due to a
          rebalance. Such peers are removed lazily here at read time.
        -}
        removeObsolete :: (Monad m, Event e o s, Eq e)
          => PartitionPowerStateT e o s m ()
        removeObsolete = do
          let owners = C.findOwners key cluster
          peers <- PS.projParticipants <$> PM.getPowerState
          let obsolete = peers \\ owners
          mapM_
            (\peer -> PM.disassociate peer >> PM.acknowledgeAs peer)
            (Set.toList obsolete)


{- | Lookup a key from a map, and also delete the key if it exists. -}
lookupAndDelete :: (Ord k) => k -> Map k v -> (Maybe v, Map k v)
lookupAndDelete = Map.updateLookupWithKey (const (const Nothing))
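
{-
  For example (illustrative):

  > lookupAndDelete 5 (Map.fromList [(3, "a"), (5, "b")])
  >   == (Just "b", Map.fromList [(3, "a")])
  > lookupAndDelete 7 (Map.fromList [(3, "a"), (5, "b")])
  >   == (Nothing, Map.fromList [(3, "a"), (5, "b")])
-}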


{- |
  Figure out the set of nodes to which search requests should be
  dispatched. "Minimum complete service set" means the minimum set
  of peers that, together, service the whole partition key space;
  thereby guaranteeing that if any particular partition is indexed,
  the corresponding index record will exist on one of these peers.
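
  For example (illustrative): if peer A holds replicas for the lower
  half of the keyspace while peers B and C each hold replicas for the
  upper half, then {A, B} and {A, C} are both complete service sets,
  and either one is a valid minimum complete service set.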

  Implementation considerations:

  There will usually be more than one solution for the MCSS. For now,
  we just compute a deterministic solution, but we should implement
  a random (or pseudo-random) solution in order to maximally balance
  cluster resources.

  Also, it is not clear that the minimum complete service set is even
  what we really want. MCSS will reduce overall network utilization,
  but it may actually increase latency. If we were to dispatch redundant
  requests to multiple nodes, we could continue with whichever request
  returns first, and ignore the slow responses. This is probably the
  best solution. We will call this "fastest competitive search".

  TODO: implement fastest competitive search.
-}
minimumCompleteServiceSet :: (Monad m) => RuntimeT e o s m (Set Peer)
minimumCompleteServiceSet = do
  RuntimeState {cluster} <- RuntimeT get
  return (D.minimumCompleteServiceSet (C.getDistribution cluster))