{-# LANGUAGE DeriveGeneric #-}
{-# LANGUAGE LambdaCase #-}
{-# LANGUAGE MultiParamTypeClasses #-}
{-# LANGUAGE NamedFieldPuns #-}
{-# LANGUAGE OverloadedStrings #-}
{-# LANGUAGE TemplateHaskell #-}
{- |
  This module is responsible for the runtime operation of the legion
  framework. This mostly means opening sockets and piping data around to the
  various connected pieces.
-}
module Network.Legion.Runtime (
  forkLegionary,
  StartupMode(..),
  Runtime,
  makeRequest,
  search,
) where

import Control.Concurrent (forkIO)
import Control.Concurrent.Chan (writeChan, newChan, Chan)
import Control.Concurrent.MVar (newEmptyMVar, takeMVar, putMVar)
import Control.Monad (void, forever, join)
import Control.Monad.Catch (catchAll, try, SomeException, throwM)
import Control.Monad.IO.Class (MonadIO, liftIO)
import Control.Monad.Logger (logWarn, logError, logInfo, LoggingT,
  MonadLoggerIO, runLoggingT, askLoggerIO, logDebug)
import Control.Monad.Trans.Class (lift)
import Control.Monad.Trans.State (StateT, runStateT, get, put, modify)
import Data.Binary (encode, Binary)
import Data.Conduit (Source, ($$), (=$=), yield, await, awaitForever,
  transPipe, ConduitM, runConduit, Sink)
import Data.Conduit.Network (sourceSocket)
import Data.Conduit.Serialization.Binary (conduitDecode)
import Data.Map (Map)
import Data.Set (Set)
import Data.String (IsString, fromString)
import Data.Text (pack)
import Data.Time (UTCTime, getCurrentTime)
import GHC.Generics (Generic)
import Network.Legion.Admin (runAdmin, AdminMessage(GetState, GetPart,
  Eject, GetIndex, GetDivergent, GetStates))
import Network.Legion.Application (LegionConstraints, Persistence,
  list, saveCluster)
import Network.Legion.BSockAddr (BSockAddr(BSockAddr))
import Network.Legion.ClusterState (ClusterPowerState)
import Network.Legion.Conduit (merge, chanToSink, chanToSource)
import Network.Legion.Distribution (Peer, newPeer)
import Network.Legion.Fork (forkC)
import Network.Legion.Index (IndexRecord(IndexRecord), irTag, irKey,
  SearchTag(SearchTag), indexEntries, Indexable)
import Network.Legion.LIO (LIO)
import Network.Legion.Lift (lift2, lift3)
import Network.Legion.PartitionKey (PartitionKey)
import Network.Legion.PartitionState (PartitionPowerState)
import Network.Legion.PowerState (Event)
import Network.Legion.Runtime.ConnectionManager (newConnectionManager,
  ConnectionManager, newPeers)
import Network.Legion.Runtime.PeerMessage (PeerMessage(PeerMessage),
  PeerMessagePayload(ForwardRequest, ForwardResponse, ClusterMerge,
  PartitionMerge, Search, SearchResponse, JoinNext, JoinNextResponse),
  MessageId, newSequence, nextMessageId, JoinNextResponse(Joined,
  JoinFinished))
import Network.Legion.Settings (RuntimeSettings(RuntimeSettings,
  adminHost, adminPort, peerBindAddr, joinBindAddr))
import Network.Legion.StateMachine (partitionMerge, clusterMerge,
  newNodeState, UserResponse(Forward, Respond), userRequest, eject,
  minimumCompleteServiceSet, joinNext, joinNextResponse)
import Network.Legion.StateMachine.Monad (NodeState, runSM, ClusterAction,
  SM, popActions, nsIndex)
import Network.Legion.UUID (getUUID)
import Network.Socket (Family(AF_INET, AF_INET6, AF_UNIX, AF_CAN),
  SocketOption(ReuseAddr), SocketType(Stream), accept, bind,
  defaultProtocol, listen, setSocketOption, socket, SockAddr(SockAddrInet,
  SockAddrInet6, SockAddrUnix, SockAddrCan), connect, getPeerName, Socket)
import Network.Socket.ByteString.Lazy (sendAll)
import System.IO (stderr, hPutStrLn)
import qualified Data.Conduit.List as CL
import qualified Data.Map as Map
import qualified Data.Set as Set
import qualified Network.Legion.ClusterState as C
import qualified Network.Legion.PowerState as PS
import qualified Network.Legion.Runtime.ConnectionManager as CM
import qualified Network.Legion.StateMachine as SM
import qualified Network.Legion.StateMachine.Monad as SMM


{- |
  Run the legion node framework program, with the given user definitions,
  framework settings, and request source. This function never returns
  (except maybe with an exception if something goes horribly wrong).
-}
runLegionary :: (LegionConstraints e o s)
  => Persistence e o s
    {- ^ The persistence layer used to back the legion framework. -}
  -> RuntimeSettings
    {- ^ Settings and configuration of the legion framework. -}
  -> StartupMode
  -> Source IO (RequestMsg e o)
    {- ^ A source of requests, together with a way to respond to the requests. -}
  -> LoggingT IO ()
    {-
      Don't expose 'LIO' here because 'LIO' is a strictly internal
      symbol. 'LoggingT IO' is what we expose to the world.
    -}

runLegionary
    persistence
    settings@RuntimeSettings {adminHost, adminPort}
    startupMode
    requestSource
  = do
    {- Start the various messages sources. -}
    peerS <- loggingC =<< startPeerListener settings
    adminS <- loggingC =<< runAdmin adminPort adminHost
    joinS <- loggingC (joinMsgSource settings)

    (self, nodeState, peers) <- makeNodeState persistence settings startupMode
    rts <- newRuntimeState self peers
    let
      messageSource = transPipe lift (
          (joinS =$= CL.map J) `merge`
          (peerS =$= CL.map P) `merge`
          (requestSource =$= CL.map R) `merge`
          (adminS =$= CL.map A)
        )
    void . runRTS persistence nodeState rts . runConduit $
      messageSource
      =$= messageSink
  where
    newRuntimeState :: (Binary e, Binary o, Binary s)
      => Peer
      -> Map Peer BSockAddr
      -> LoggingT IO (RuntimeState e o s)
    newRuntimeState self peers = do
      cm <- newConnectionManager peers
      firstMessageId <- newSequence
      return RuntimeState {
          forwarded = Map.empty,
          nextId = firstMessageId,
          cm,
          self,
          commClock = Map.empty,
          searches = Map.empty
        }

    {- |
      Turn an LIO-based conduit into an IO-based conduit, so that it
      will work with `merge`.
    -}
    loggingC :: ConduitM e o LIO r -> LIO (ConduitM e o IO r)
    loggingC c = do
      logging <- askLoggerIO
      return (transPipe (`runLoggingT` logging) c)


{- |
  This is how requests are packaged when they are sent to the legion framework
  for handling. It includes the request information itself, a partition key to
  which the request is directed, and a way for the framework to deliver the
  response to some interested party.
-}
data RequestMsg e o
  = Request PartitionKey e (o -> IO ())
  | SearchDispatch SearchTag (Maybe IndexRecord -> IO ())
instance (Show e) => Show (RequestMsg e o) where
  show (Request k e _) = "(Request " ++ show k ++ " " ++ show e ++ " _)"
  show (SearchDispatch s _) = "(SearchDispatch " ++ show s ++ " _)"
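
{-
  For example (this is what 'forkLegionary' does below), a caller can
  pair the request with an 'MVar' to obtain a synchronous interface:

  > responseVar <- newEmptyMVar
  > writeChan chan (Request key request (putMVar responseVar))
  > takeMVar responseVar
-}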


messageSink :: (LegionConstraints e o s)
  => Sink (RuntimeMessage e o s) (RTS e o s) ()
messageSink = awaitForever (\msg -> do
    $(logDebug) . pack $ "Received: " ++ show msg
    lift $ do
      case msg of
        P (PeerMessage source _ _) ->
          updateRecvClock source
        _ -> return ()
      handleMessage msg
      updatePeers
      clusterActions
  )


{- | Make progress on outstanding cluster actions. -}
clusterActions :: RTS e o s ()
clusterActions =
    mapM_ clusterAction =<< popActions
  where
    {- |
      Actually perform a cluster action as directed by the state
      machine.
    -}
    clusterAction
      :: ClusterAction e o s
      -> RTS e o s ()

    clusterAction (SMM.ClusterMerge peer ps) =
      void $ send peer (ClusterMerge ps)

    clusterAction (SMM.PartitionMerge peer key ps) =
      void $ send peer (PartitionMerge key ps)

    clusterAction (SMM.PartitionJoin peer keys) =
      void $ send peer (JoinNext keys)


{- |
  Make sure the connection manager knows about any new peers that have
  joined the cluster.
-}
updatePeers :: RTS e o s ()
updatePeers = do
  peers <- SM.getPeers
  RuntimeState {cm} <- lift get
  lift2 $ newPeers cm peers


{- |
  Handle an individual runtime message, accepting an initial runtime
  state and an initial node state, and producing an updated runtime
  state and node state.
-}
handleMessage :: (LegionConstraints e o s)
  => RuntimeMessage e o s
  -> RTS e o s ()

handleMessage {- Join Next Response -}
    (P (PeerMessage source _ (JoinNextResponse _messageId response)))
  =
    joinNextResponse source (toMaybe response)
  where
    toMaybe
      :: JoinNextResponse e o s
      -> Maybe (PartitionKey, PartitionPowerState e o s)
    toMaybe (Joined key partition) = Just (key, partition)
    toMaybe JoinFinished = Nothing

handleMessage {- Join Next -}
    (P (PeerMessage source messageId (JoinNext askKeys)))
  =
    joinNext source askKeys >>= \case
      Nothing -> void $
        send source (JoinNextResponse messageId JoinFinished)
      Just (gotKey, partition) -> void $
        send source (JoinNextResponse messageId (Joined gotKey partition))

handleMessage {- Partition Merge -}
    (P (PeerMessage _ _ (PartitionMerge key ps)))
  =
    partitionMerge key ps

handleMessage {- Cluster Merge -}
    (P (PeerMessage _ _ (ClusterMerge cs)))
  =
    clusterMerge cs

handleMessage {- Forward Request -}
    (P (msg@(PeerMessage source mid (ForwardRequest key request))))
  = do
    output <- userRequest key request
    case output of
      Respond response -> void $ send source (ForwardResponse mid response)
      Forward peer -> forward peer msg

handleMessage {- Forward Response -}
    (msg@(P (PeerMessage _ _ (ForwardResponse mid response))))
  = do
    rts <- lift get
    case lookupDelete mid (forwarded rts) of
      (Nothing, fwd) -> do
        $(logWarn) . pack $ "Unsolicited ForwardResponse: " ++ show msg
        (lift . put) rts {forwarded = fwd}
      (Just respond, fwd) -> do
        lift2 $ respond response
        (lift . put) rts {forwarded = fwd}

handleMessage {- User Request -}
    (R (Request key request respond))
  = do
    output <- userRequest key request
    case output of
      Respond response -> lift3 (respond response)
      Forward peer -> do
        messageId <- send peer (ForwardRequest key request)
        (lift . modify) $ \rts@RuntimeState {forwarded} -> rts {
            forwarded = Map.insert messageId (lift . respond) forwarded
          }

handleMessage {- Search Dispatch -}
    {-
      This is where we send the search request out to all the appropriate
      nodes in the cluster.
    -}
    (R (SearchDispatch searchTag respond))
  =
    Map.lookup searchTag . searches <$> lift get >>= \case
      Nothing -> do
        {-
          No identical search is currently being executed, kick off a
          new one.
        -}
        mcss <- minimumCompleteServiceSet
        mapM_ sendOne (Set.toList mcss)
        rts@RuntimeState {searches} <- lift get
        (lift . put) rts {
            searches = Map.insert
              searchTag
              (mcss, Nothing, [lift . respond])
              searches
          }
      Just (peers, best, responders) -> do
        {-
          A search for this tag is already in progress, just add the
          responder to the responder list.
        -}
        rts@RuntimeState {searches} <- lift get
        (lift . put) rts {
            searches = Map.insert
              searchTag
              (peers, best, (lift . respond):responders)
              searches
          }
  where
    sendOne :: Peer -> RTS e o s ()
    sendOne peer =
      void $ send peer (Search searchTag)

handleMessage {- Search Execution -}
    {- This is where we handle local search execution. -}
    (P (PeerMessage source _ (Search searchTag)))
  = do
    output <- SM.search searchTag
    void $ send source (SearchResponse searchTag output)

handleMessage {- Search Response -}
    {-
      This is where we gather all the responses from the various peers
      to which we dispatched search requests.
    -}
    (msg@(P (PeerMessage source _ (SearchResponse searchTag response))))
  =
    {- TODO: see if this function can't be made more elegant. -}
    Map.lookup searchTag . searches <$> lift get >>= \case
      Nothing ->
        {- There is no search happening. -}
        $(logWarn) . pack $ "Unsolicited SearchResponse: " ++ show msg
      Just (peers, best, responders) ->
        if source `Set.member` peers
          then
            let peers2 = Set.delete source peers
            in if null peers2
              then do
                {-
                  All peers have responded, go ahead and respond to
                  the client.
                -}
                lift2 $ mapM_ ($ bestOf best response) responders
                rts@RuntimeState {searches} <- lift get
                (lift . put) rts {searches = Map.delete searchTag searches}
              else do
                {- We are still waiting on some outstanding requests. -}
                rts@RuntimeState {searches} <- lift get
                (lift . put) rts {
                    searches = Map.insert
                      searchTag
                      (peers2, bestOf best response, responders)
                      searches
                  }
          else
            {-
              There is a search happening, but the peer that responded
              is not part of it.
            -}
            $(logWarn) . pack $ "Unsolicited SearchResponse: " ++ show msg
  where
    {- |
      Figure out which index record returned to us by the various peers
      is the most appropriate to return to the user. This is mostly like
      'min' but we can't use 'min' (or fancy applicative formulations)
      because we want to favor 'Just' instead of 'Nothing'.
    -}
    bestOf :: Maybe IndexRecord -> Maybe IndexRecord -> Maybe IndexRecord
    bestOf (Just a) (Just b) = Just (min a b)
    bestOf Nothing b = b
    bestOf a Nothing = a
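
    {-
      For example, @bestOf Nothing (Just r)@ is @Just r@, whereas the
      applicative formulation @min <$> Nothing <*> Just r@ would be
      'Nothing'.
    -}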

handleMessage {- Join Request -}
    (J (JoinRequest addy, respond))
  = do
    (peer, cluster) <- SM.join addy
    lift2 $ respond (JoinOk peer cluster)

handleMessage {- Admin Get State -}
    (A (GetState respond))
  =
    lift2 . respond =<< SMM.getNodeState

handleMessage {- Admin Get Partition -}
    (A (GetPart key respond))
  =
    lift2 . respond =<< SM.getPartition key

handleMessage {- Admin Eject Peer -}
    (A (Eject peer respond))
  = do
    eject peer
    lift2 $ respond ()

handleMessage {- Admin Get Index -}
    (A (GetIndex respond))
  =
    lift2 . respond =<< SMM.nsIndex <$> SMM.getNodeState

handleMessage {- Admin Get Divergent -}
    (A (GetDivergent respond))
  = do
    RuntimeState {commClock} <- lift get
    diverging <- divergentPeers . SMM.partitions <$> SMM.getNodeState
    lift2 . respond $ Map.fromAscList [
        (peer, r)
        | (peer, (_, r)) <- Map.toAscList commClock
        , peer `Set.member` diverging
      ]
  where
    divergentPeers :: Map PartitionKey (PartitionPowerState e o s) -> Set Peer
    divergentPeers =
      foldr Set.union Set.empty . fmap PS.divergent . Map.elems

handleMessage {- Admin Get States -}
    (A (GetStates respond))
  = do
    persistence <- SMM.getPersistence
    lift2 . respond . Map.fromList =<< runConduit (
        transPipe liftIO (list persistence)
        =$= CL.consume
      )


{- | This defines the various ways a node can be spun up. -}
data StartupMode
  = NewCluster
    {- ^ Indicates that we should bootstrap a new cluster at startup. -}
  | JoinCluster SockAddr
    {- ^ Indicates that the node should try to join an existing cluster. -}
  | Recover Peer ClusterPowerState
    {- ^
      Recover from a crash as the given peer, using the given cluster
      state.
    -}
  deriving (Show, Eq)
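
{-
  A usage sketch (the names here are illustrative): bootstrap the first
  node with 'NewCluster', then point a second node at the first node's
  join address.

  > rt1 <- forkLegionary persistence1 settings1 NewCluster
  > rt2 <- forkLegionary persistence2 settings2 (JoinCluster firstNodeJoinAddr)
-}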


{- |
  Construct a source of incoming peer messages. We have to start the
  peer listener before we spin up the cluster management, which
  is why this is an @LIO (Source LIO PeerMessage)@ instead of a
  @Source LIO PeerMessage@.
-}
startPeerListener :: (LegionConstraints e o s)
  => RuntimeSettings
  -> LIO (Source LIO (PeerMessage e o s))

startPeerListener RuntimeSettings {peerBindAddr} =
    catchAll (do
        (inputChan, so) <- lift $ do
          inputChan <- newChan
          so <- socket (fam peerBindAddr) Stream defaultProtocol
          setSocketOption so ReuseAddr 1
          bind so peerBindAddr
          listen so 5
          return (inputChan, so)
        forkC "peer socket acceptor" $ acceptLoop so inputChan
        return (chanToSource inputChan)
      ) (\err -> do
        $(logError) . pack
          $ "Couldn't start incomming peer message service, because of: "
          ++ show (err :: SomeException)
        throwM err
      )
  where
    acceptLoop :: (LegionConstraints e o s)
      => Socket
      -> Chan (PeerMessage e o s)
      -> LIO ()
    acceptLoop so inputChan =
        catchAll (
          forever $ do
            (conn, _) <- lift $ accept so
            remoteAddr <- lift $ getPeerName conn
            logging <- askLoggerIO
            let runSocket =
                  sourceSocket conn
                  =$= conduitDecode
                  $$ msgSink
            void
              . lift
              . forkIO
              . (`runLoggingT` logging)
              . logErrors remoteAddr
              $ runSocket
        ) (\err -> do
          $(logError) . pack
            $ "error in peer message accept loop: "
            ++ show (err :: SomeException)
          throwM err
        )
      where
        msgSink = chanToSink inputChan

        logErrors :: SockAddr -> LIO () -> LIO ()
        logErrors remoteAddr io = do
          result <- try io
          case result of
            Left err ->
              $(logWarn) . pack
                $ "Incomming peer connection (" ++ show remoteAddr
                ++ ") crashed because of: " ++ show (err :: SomeException)
            Right v -> return v


{- | Figure out how to construct the initial node state.  -}
makeNodeState :: (Event e o s, Indexable s)
  => Persistence e o s
  -> RuntimeSettings
  -> StartupMode
  -> LIO (Peer, NodeState e o s, Map Peer BSockAddr)

makeNodeState
    persistence
    settings@RuntimeSettings {peerBindAddr}
    NewCluster
  = do
    {- Build a brand new node state, for the first node in a cluster. -}
    verifyClearPersistence persistence
    self <- newPeer
    clusterId <- getUUID
    let
      cluster = C.new clusterId self peerBindAddr
    makeNodeState persistence settings (Recover self cluster)

makeNodeState
    persistence
    settings@RuntimeSettings {peerBindAddr}
    (JoinCluster addr)
  = do
    {-
      Join a cluster by either starting fresh, or recovering from a
      shutdown or crash.
    -}
    verifyClearPersistence persistence
    $(logInfo) "Trying to join an existing cluster."
    (self, cluster) <- joinCluster (JoinRequest (BSockAddr peerBindAddr))
    makeNodeState persistence settings (Recover self cluster)
  where
    joinCluster :: JoinRequest -> LIO (Peer, ClusterPowerState)
    joinCluster joinMsg = liftIO $ do
      so <- socket (fam addr) Stream defaultProtocol
      connect so addr
      sendAll so (encode joinMsg)
      {-
        using sourceSocket and conduitDecode is easier than building
        a receive/decode state loop, even though we only read a single
        response.
      -}
      sourceSocket so =$= conduitDecode $$ do
        response <- await
        case response of
          Nothing -> fail
            $ "Couldn't join a cluster because there was no response "
            ++ "to our join request!"
          Just (JoinOk self cps) ->
            return (self, cps)

makeNodeState persistence _ (Recover self cluster) = do
    {- Make sure to rebuild the index in the case of recovery. -}
    index <- runConduit . transPipe liftIO $
      list persistence
      =$= CL.fold addIndexRecords Set.empty
    let
      nodeState = (newNodeState self cluster) {nsIndex = index}
    liftIO $ saveCluster persistence self cluster
    return (self, nodeState, C.getPeers cluster)
  where
    addIndexRecords :: (Indexable s, Event e o s)
      => Set IndexRecord
      -> (PartitionKey, PartitionPowerState e o s)
      -> Set IndexRecord
    addIndexRecords index (key, partition) =
      let
        newRecords =
          Set.map
            (`IndexRecord` key)
            (indexEntries (PS.projectedValue partition))
      in Set.union index newRecords


{- |
  Helper for 'makeNodeState'. Verify that there is nothing in the
  persistence layer.
-}
verifyClearPersistence :: (MonadLoggerIO io) => Persistence e o s -> io ()
verifyClearPersistence persistence =
  liftIO (runConduit (list persistence =$= CL.head)) >>= \case
    Just _ -> do
      let
        msg :: (IsString a) => a
        msg = fromString
          $ "We are trying to start up a new peer, but the persistence "
          ++ "layer already has data in it. This is an invalid state. "
          ++ "New nodes must be started from a totally clean, empty state."
      $(logError) msg
      liftIO $ do
        hPutStrLn stderr msg
        putStrLn msg
        error msg
    Nothing ->
      return ()


{- | A source of cluster join request messages.  -}
joinMsgSource
  :: RuntimeSettings
  -> Source LIO (JoinRequest, JoinResponse -> LIO ())

joinMsgSource RuntimeSettings {joinBindAddr} = join . lift $
    catchAll (do
        (chan, so) <- lift $ do
          chan <- newChan
          so <- socket (fam joinBindAddr) Stream defaultProtocol
          setSocketOption so ReuseAddr 1
          bind so joinBindAddr
          listen so 5
          return (chan, so)
        forkC "join socket acceptor" $ acceptLoop so chan
        return (chanToSource chan)
      ) (\err -> do
        $(logError) . pack
          $ "Couldn't start join request service, because of: "
          ++ show (err :: SomeException)
        throwM err
      )
  where
    acceptLoop :: Socket -> Chan (JoinRequest, JoinResponse -> LIO ()) -> LIO ()
    acceptLoop so chan =
        catchAll (
          forever $ do
            (conn, _) <- lift $ accept so
            logging <- askLoggerIO
            (void . lift . forkIO . (`runLoggingT` logging) . logErrors) (
                sourceSocket conn
                =$= conduitDecode
                =$= attachResponder conn
                $$  chanToSink chan
              )
        ) (\err -> do
          $(logError) . pack
            $ "error in join request accept loop: "
            ++ show (err :: SomeException)
          throwM err
        )
      where
        logErrors :: LIO () -> LIO ()
        logErrors io = do
          result <- try io
          case result of
            Left err ->
              $(logWarn) . pack
                $ "Incomming join connection crashed because of: "
                ++ show (err :: SomeException)
            Right v -> return v

        attachResponder
          :: Socket
          -> ConduitM JoinRequest (JoinRequest, JoinResponse -> LIO ()) LIO ()
        attachResponder conn = awaitForever (\msg -> do
            mvar <- liftIO newEmptyMVar
            yield (msg, lift . putMVar mvar)
            response <- liftIO $ takeMVar mvar
            liftIO $ sendAll conn (encode response)
          )


{- | Guess the family of a `SockAddr`. -}
fam :: SockAddr -> Family
fam SockAddrInet {} = AF_INET
fam SockAddrInet6 {} = AF_INET6
fam SockAddrUnix {} = AF_UNIX
fam SockAddrCan {} = AF_CAN


{- |
  Forks the legion framework in a background thread, and returns a way to
  send user requests to it and retrieve the responses to those requests.

  - @__e__@ is the type of request your application will handle. @__e__@ stands
    for __"event"__.
  - @__o__@ is the type of response produced by your application. @__o__@ stands
    for __"output"__
  - @__s__@ is the type of state maintained by your application. More
    precisely, it is the type of the individual partitions that make up
    your global application state. @__s__@ stands for __"state"__.
-}
forkLegionary :: (LegionConstraints e o s, MonadLoggerIO io)
  => Persistence e o s
    {- ^ The persistence layer used to back the legion framework. -}
  -> RuntimeSettings
    {- ^ Settings and configuration of the legion framework. -}
  -> StartupMode
  -> io (Runtime e o)

forkLegionary persistence settings startupMode = do
  logging <- askLoggerIO
  liftIO . (`runLoggingT` logging) $ do
    chan <- liftIO newChan
    forkC "main legion thread" $
      runLegionary persistence settings startupMode (chanToSource chan)
    return Runtime {
        rtMakeRequest = \key request -> liftIO $ do
          responseVar <- newEmptyMVar
          writeChan chan (Request key request (putMVar responseVar))
          takeMVar responseVar,
        rtSearch =
          let
            findNext :: SearchTag -> IO (Maybe IndexRecord)
            findNext searchTag = do
              responseVar <- newEmptyMVar
              writeChan chan (SearchDispatch searchTag (putMVar responseVar))
              takeMVar responseVar
          in findNext

      }
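
{-
  A minimal end-to-end sketch, assuming a user-supplied 'Persistence'
  layer and 'RuntimeSettings'; @myPersistence@, @mySettings@, @someKey@,
  and @someEvent@ are illustrative:

  > import Control.Monad.Logger (runStderrLoggingT)
  >
  > main :: IO ()
  > main = runStderrLoggingT $ do
  >   runtime <- forkLegionary myPersistence mySettings NewCluster
  >   response <- makeRequest runtime someKey someEvent
  >   liftIO (print response)
-}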


{- |
  This type represents a handle to the runtime environment of your
  Legion application. This allows you to make requests and access the
  partition index.

  'Runtime' is an opaque structure. Use 'makeRequest' and 'search' to
  access it.
-}
data Runtime e o = Runtime {
    {- |
      Send an application request to the legion runtime, and get back
      a response.
    -}
    rtMakeRequest :: PartitionKey -> e -> IO o,

    {- | Query the index to find a set of partition keys.  -}
    rtSearch :: SearchTag -> IO (Maybe IndexRecord)
  }


{- | Send a user request to the legion runtime. -}
makeRequest :: (MonadIO io) => Runtime e o -> PartitionKey -> e -> io o
makeRequest rt key = liftIO . rtMakeRequest rt key


{- |
  Send a search request to the legion runtime. Returns results that are
  __strictly greater than__ the provided 'SearchTag'.
-}
search :: (MonadIO io) => Runtime e o -> SearchTag -> Source io IndexRecord
search rt tag =
  liftIO (rtSearch rt tag) >>= \case
    Nothing -> return ()
    Just record@IndexRecord {irTag, irKey} -> do
      yield record
      search rt (SearchTag irTag (Just irKey))
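
{-
  For example, to drain the entire index into a list (a sketch; @rt@ and
  @tag@ are assumed to be in scope):

  > records <- search rt (SearchTag tag Nothing) $$ CL.consume
-}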


{- | This is the type of message passed around in the runtime. -}
data RuntimeMessage e o s
  = P (PeerMessage e o s)
  | R (RequestMsg e o)
  | J (JoinRequest, JoinResponse -> LIO ())
  | A (AdminMessage e o s)
instance (Show e, Show o, Show s) => Show (RuntimeMessage e o s) where
  show (P m) = "(P " ++ show m ++ ")"
  show (R m) = "(R " ++ show m ++ ")"
  show (J (jr, _)) = "(J (" ++ show jr ++ ", _))"
  show (A a) = "(A (" ++ show a ++ "))"


{- |
  The runtime state.

  The 'searches' field is a little weird.

  It turns out that searches are deterministic over the parameters of
  'SearchTag' and cluster state. This should make sense, because everything in
  Haskell is deterministic given __all__ the parameters. Since the cluster
  state only changes over time, searches that happen "at the same time" and
  for the same 'SearchTag' can be considered identical. I don't think it is too
  much of a stretch to say that searches that have overlapping execution times
  can be considered to be happening "at the same time", so the search
  tag becomes the determining factor in the result of the search.

  This is a long-winded way of justifying the fact that, if we are currently
  executing a search and an identical search request arrives, then the second
  identical search is just piggy-backed on the results of the currently
  executing search. Whether this counts as a premature optimization hack or a
  beautifully elegant expression of platonic reality is left as an exercise for
  the reader. It does help simplify the code a little bit because we don't have
  to specify some kind of UUID to differentiate otherwise identical searches.
-}
data RuntimeState e o s = RuntimeState {
         self :: Peer,
    forwarded :: Map MessageId (o -> LIO ()),
       nextId :: MessageId,
           cm :: ConnectionManager e o s,
    commClock :: Map Peer (Maybe UTCTime, Maybe UTCTime),
                 {- ^ When did we last communicate with a peer. (sent, recv). -}
     searches :: Map
                   SearchTag
                   (Set Peer, Maybe IndexRecord, [Maybe IndexRecord -> LIO ()])
  }


{- | This is the type of a join request message. -}
newtype JoinRequest = JoinRequest BSockAddr
  deriving (Generic, Show)
instance Binary JoinRequest


{- | The response to a JoinRequest message. -}
data JoinResponse
  = JoinOk Peer ClusterPowerState
  deriving (Generic)
instance Binary JoinResponse


{- | Look up a key in a map, and also delete the key if it exists. -}
lookupDelete :: (Ord k) => k -> Map k v -> (Maybe v, Map k v)
lookupDelete = Map.updateLookupWithKey (const (const Nothing))
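
{-
  For example:

  > lookupDelete 2 (Map.fromList [(1, "a"), (2, "b")])
  >   == (Just "b", Map.fromList [(1, "a")])
-}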


{- | The runtime monad.  -}
type RTS e o s =
  SM e o s (
  StateT (RuntimeState e o s)
  LIO)


{- | Shorthand for running the RTS monad. -}
runRTS
  :: Persistence e o s
  -> NodeState e o s
  -> RuntimeState e o s
  -> RTS e o s a
  -> LIO (a, NodeState e o s, [ClusterAction e o s], RuntimeState e o s)
runRTS persistence ns rts =
    fmap flatten
    . (`runStateT` rts)
    . runSM persistence ns
  where
    flatten ((a, b, c), d) = (a, b, c, d)


{- |
  Send a peer message in the RTS monad, automatically taking care of
  necessary state updates.
-}
send :: Peer -> PeerMessagePayload e o s -> RTS e o s MessageId
send target payload = do
  rts@RuntimeState {cm, self, nextId} <- lift get
  (lift . put) rts {nextId = nextMessageId nextId}
  lift2 $ CM.send cm target (PeerMessage self nextId payload)
  return nextId


{- | Forward an existing message to another peer. -}
forward :: Peer -> PeerMessage e o s -> RTS e o s ()
forward target message = do
  RuntimeState {cm} <- lift get
  lift2 $ CM.send cm target message


{- | Update the time when we last received a message from a peer. -}
updateRecvClock :: Peer -> RTS e o s ()
updateRecvClock peer = do
  now <- liftIO getCurrentTime
  (lift . modify) (\rts@RuntimeState {commClock} ->
      let
        newCommClock = case Map.lookup peer commClock of
          Nothing -> Map.insert peer (Nothing, Just now) commClock
          Just (s, _) -> Map.insert peer (s, Just now) commClock
      in newCommClock `seq` rts {
          commClock = newCommClock
        }
    )