{-# LANGUAGE DeriveGeneric #-}
{-# LANGUAGE ExplicitForAll #-}
{-# LANGUAGE GeneralizedNewtypeDeriving #-}
{-# LANGUAGE LambdaCase #-}
{-# LANGUAGE NamedFieldPuns #-}
{-# LANGUAGE OverloadedStrings #-}
{-# LANGUAGE TemplateHaskell #-}

module Network.Legion.Runtime (
  -- * User-facing interface.
  forkRuntime,
  makeRequest,
  search,
  Runtime,

  -- * Internal interface.
  eject,
  getDivergent,

  -- * Debugging interface.
  debugLocalPartitions,
  debugRuntimeState,
  debugPartition,
  debugIndex,

) where


import Control.Concurrent (writeChan, newChan, Chan)
import Control.Concurrent.MVar (newEmptyMVar, takeMVar, putMVar)
import Control.Monad (void, forever)
import Control.Monad.Catch (catchAll, try, SomeException, throwM,
  MonadCatch)
import Control.Monad.IO.Class (liftIO, MonadIO)
import Control.Monad.Logger (MonadLoggerIO, logError, logWarn, logDebug,
  askLoggerIO, runLoggingT)
import Control.Monad.Trans.Class (lift)
import Data.Aeson (Value)
import Data.Binary (encode)
import Data.Conduit ((.|), runConduit, awaitForever, Source, yield)
import Data.Conduit.Network (sourceSocket)
import Data.Conduit.Serialization.Binary (conduitDecode)
import Data.Map (Map)
import Data.Set (Set)
import Data.Time (UTCTime)
import Network.Legion.Application (LegionConstraints, Persistence)
import Network.Legion.Conduit (chanToSource)
import Network.Legion.Distribution (Peer)
import Network.Legion.Fork (forkC, forkL, ForkM)
import Network.Legion.Index (IndexRecord(IndexRecord),
  SearchTag(SearchTag), irTag, irKey)
import Network.Legion.LIO (LIO)
import Network.Legion.PartitionKey (PartitionKey)
import Network.Legion.PartitionState (PartitionPowerState)
import Network.Legion.Runtime.ConnectionManager (send)
import Network.Legion.Runtime.PeerMessage (PeerMessage(PeerMessage),
  PeerMessagePayload(PartitionMerge, ForwardRequest, ForwardResponse,
  ClusterMerge, Search, SearchResponse, JoinNext, JoinNextResponse),
  payload, source, messageId, JoinNextResponse(JoinFinished, Joined))
import Network.Legion.Runtime.State (makeRuntimeState, StartupMode,
  RuntimeT, JoinRequest(JoinRequest), JoinResponse(JoinOk), runRuntimeT,
  updateRecvClock, userRequest, forwardResponse, clusterMerge,
  getCM, searchResponse, joinNext, partitionMerge, joinNextResponse,
  forwardedRequest, searchDispatch)
import Network.Legion.Settings (RuntimeSettings(RuntimeSettings),
  peerBindAddr, joinBindAddr)
import Network.Legion.SocketUtil (fam)
import Network.Socket (SocketOption(ReuseAddr), SocketType(Stream),
  accept, bind, close, defaultProtocol, listen, setSocketOption, socket,
  SockAddr, getPeerName, Socket)
import Network.Socket.ByteString.Lazy (sendAll)
import qualified Data.Text as T
import qualified Network.Legion.Runtime.State as S


{- |
  A handle on the runtime environment of a running Legion application.
  Requests to the application and searches of the partition index both
  go through this handle.

  'Runtime' is opaque; interact with it through 'makeRequest' and
  'search'.
-}
newtype Runtime e o s = Runtime
  { unRuntime :: Chan (RuntimeMessage e o s)
    {- ^ The mailbox read by the runtime's main loop. -}
  }


{- |
  Start the Legion runtime on a background thread and return the
  'Runtime' handle used to talk to it. The background thread logs
  through the caller's logger.
-}
forkRuntime :: (LegionConstraints e o s, MonadLoggerIO m)
  => Persistence e o s
    {- ^ The persistence layer used to back the legion framework. -}
  -> RuntimeSettings
    {- ^ Settings and configuration of the legion framework. -}
  -> StartupMode
    {- ^ Startup mode used when building the initial runtime state. -}
  -> m (Runtime e o s)
forkRuntime persistence settings startupMode = do
  mailbox <- liftIO newChan
  let runtime = Runtime mailbox
  logger <- askLoggerIO
  liftIO $ runLoggingT
    (forkC "main legion thread"
      (executeRuntime persistence settings startupMode runtime))
    logger
  return runtime


{- |
  Send a user request for the partition identified by the given key to
  the legion runtime, and block until the runtime supplies the output.
-}
makeRequest :: (MonadIO m) => Runtime e o s -> PartitionKey -> e -> m o
makeRequest runtime key = call runtime . RMUserRequest key


{- |
  Send a search request to the legion runtime. Returns results that are
  __strictly greater than__ the provided 'SearchTag'.

  Records are fetched one at a time. After each record is yielded, the
  next search starts from that record's tag and key. The stream ends
  when the runtime reports no further record.
-}
search :: (MonadIO m) => Runtime e o s -> SearchTag -> Source m IndexRecord
search runtime tag =
    maybe (return ()) continueFrom =<< call runtime (RMUserSearch tag)
  where
    continueFrom record@IndexRecord {irTag, irKey} = do
      yield record
      search runtime (SearchTag irTag (Just irKey))


{- |
  Ask the runtime for a JSON 'Value' rendering of its state, for
  debugging. Blocks until the runtime answers.
-}
debugRuntimeState :: (MonadIO m)
  => Runtime e o s
  -> m Value
debugRuntimeState = (`call` RMDebugRuntimeState)


{- |
  Ask the runtime for the state of a single partition, for debugging.
  Yields 'Nothing' when the runtime has nothing for that key.
-}
debugPartition :: (MonadIO m)
  => Runtime e o s
  -> PartitionKey
  -> m (Maybe (PartitionPowerState e o s))
debugPartition runtime key = call runtime (RMDebugPartition key)


{- |
  Ask the runtime to eject the given peer, and block until the runtime
  acknowledges the request.
-}
eject :: (MonadIO m)
  => Runtime e o s
  -> Peer
  -> m ()
eject runtime peer = call runtime (RMEject peer)


{- |
  Ask the runtime for its full set of index records, for debugging.
  Blocks until the runtime answers.
-}
debugIndex :: (MonadIO m)
  => Runtime e o s
  -> m (Set IndexRecord)
debugIndex = flip call RMDebugIndex


{- |
  Ask the runtime which peers it considers divergent. Each peer is
  paired with an optional timestamp supplied by the runtime state.
-}
getDivergent :: (MonadIO m)
  => Runtime e o s
  -> m (Map Peer (Maybe UTCTime))
getDivergent = (`call` RMGetDivergent)


{- |
  Fetch every partition the local runtime manages, keyed by partition
  key, for debugging.
-}
debugLocalPartitions :: (MonadIO m)
  => Runtime e o s
  -> m (Map PartitionKey (PartitionPowerState e o s))
debugLocalPartitions = flip call RMDebugLocalPartitions


{- |
  Run the Legion runtime with the given persistence layer and settings.

  The peer and join listeners are started first. Then the initial
  runtime state is built, and every message arriving on the runtime's
  mailbox is logged at debug level and handled in order. This function
  never returns (except maybe with an exception if something goes
  horribly wrong).
-}
executeRuntime :: (
      ForkM m,
      LegionConstraints e o s,
      MonadCatch m,
      MonadLoggerIO m
    )
  => Persistence e o s
    {- ^ The persistence layer used to back the legion framework. -}
  -> RuntimeSettings
    {- ^ Settings and configuration of the legion framework. -}
  -> StartupMode
    {- ^ Startup mode used when building the initial runtime state. -}
  -> Runtime e o s
    {- ^ A source of requests, together with a way to respond to the requests. -}
  -> m ()
executeRuntime persistence settings startupMode runtime = do
  {- The network-facing message sources come up before the state does. -}
  startPeerListener settings runtime
  startJoinListener settings runtime
  initialState <- makeRuntimeState persistence settings startupMode
  let messages = chanToSource (unRuntime runtime)
  void . runRuntimeT persistence initialState . runConduit $
    messages .| awaitForever (\msg -> do
      $(logDebug) . T.pack $ "Receieved: " ++ show msg
      lift (handleRuntimeMessage runtime msg))


{- |
  Handle one message taken from the runtime's mailbox.

  Peer messages update the receive clock for their sender and are then
  dispatched by 'handlePeerMessage'. Every other message is answered
  through the 'Responder' it carries. For user requests and searches,
  the responder is handed to the state layer as a callback.
-}
handleRuntimeMessage :: (
      ForkM m,
      LegionConstraints e o s,
      MonadCatch m,
      MonadLoggerIO m
    )
  => Runtime e o s
     {- ^
       A handle on our own runtime, used to send messages back to
       ourselves.
     -}
  -> RuntimeMessage e o s
  -> RuntimeT e o s m ()
handleRuntimeMessage runtime = \case
  RMPeerMessage msg@(PeerMessage sender _ _) -> do
    updateRecvClock sender
    handlePeerMessage runtime msg
  RMJoinRequest (JoinRequest addr) responder -> do
    (peer, cluster) <- S.joinCluster addr
    respond responder (JoinOk peer cluster)
  RMDebugRuntimeState responder ->
    S.debugRuntimeState >>= respond responder
  RMDebugPartition key responder ->
    S.debugPartition key >>= respond responder
  RMEject peer responder ->
    S.eject peer >>= respond responder
  RMDebugIndex responder ->
    S.debugIndex >>= respond responder
  RMGetDivergent responder ->
    S.getDivergent >>= respond responder
  RMDebugLocalPartitions responder ->
    S.debugLocalPartitions >>= respond responder
  RMUserRequest key request responder ->
    userRequest key request (respond responder)
  RMUserSearch tag responder ->
    searchDispatch tag (respond responder)


{- |
  Dispatch a message received from another peer to the matching
  runtime-state operation, based on its payload.
-}
handlePeerMessage :: (
      ForkM m,
      LegionConstraints e o s,
      MonadCatch m,
      MonadLoggerIO m
    )
  => Runtime e o s
     {- ^
       A handle on our own runtime, used to send messages back to
       ourselves.
     -}
  -> PeerMessage e o s
  -> RuntimeT e o s m ()
handlePeerMessage
    _runtime
    PeerMessage {source = sender, messageId = msgId, payload = body}
  =
    case body of
      PartitionMerge key partition ->
        partitionMerge key partition
      ForwardRequest key event ->
        forwardedRequest sender msgId key event
      ForwardResponse forMessageId output ->
        forwardResponse forMessageId output
      ClusterMerge cluster ->
        clusterMerge cluster
      Search searchTag -> do
        {- Run the search locally and reply straight to the asking peer. -}
        searchResult <- S.search searchTag
        cm <- getCM
        void $ send cm sender (SearchResponse searchTag searchResult)
      SearchResponse searchTag record ->
        {- A peer answering a search that this node dispatched to it. -}
        searchResponse sender searchTag record
      JoinNext askKeys -> do
        cm <- getCM
        let reply = void . send cm sender . JoinNextResponse msgId
        joinNext sender askKeys $ \case
          Nothing -> reply JoinFinished
          Just (gotKey, partition) -> reply (Joined gotKey partition)
      JoinNextResponse _ response ->
        joinNextResponse sender (flattenJoin response)
  where
    {- Convert a join response into the 'Maybe' form 'joinNextResponse' takes. -}
    flattenJoin
      :: JoinNextResponse e o s
      -> Maybe (PartitionKey, PartitionPowerState e o s)
    flattenJoin (Joined key partition) = Just (key, partition)
    flattenJoin JoinFinished = Nothing


{- |
  A callback through which the runtime delivers its answer to a
  request. It shows as a fixed placeholder because functions have no
  meaningful 'Show' rendering.
-}
newtype Responder a = Responder
  { unResponder :: a -> IO ()
  }

instance Show (Responder a) where
  showsPrec _ _ = showString "Responder"


{- | Deliver a value through the given 'Responder', from any 'MonadIO' context. -}
respond :: (MonadIO m) => Responder a -> a -> m ()
respond responder value = liftIO (unResponder responder value)


{- |
  Send a message to the runtime and block until the runtime answers
  through the 'Responder' embedded in that message.
-}
call :: (MonadIO m)
  => Runtime e o s
  -> (Responder a -> RuntimeMessage e o s)
  -> m a
call runtime mkMessage = liftIO $ do
  answer <- newEmptyMVar
  cast runtime . mkMessage . Responder $ putMVar answer
  takeMVar answer


{- | Put a message on the runtime's mailbox without waiting for any reply. -}
cast :: Runtime e o s -> RuntimeMessage e o s -> IO ()
cast Runtime {unRuntime = mailbox} message = writeChan mailbox message


{- |
  Messages understood by the runtime's main loop. Every variant except
  'RMPeerMessage' carries a 'Responder' for the runtime's answer.
-}
data RuntimeMessage e o s
  = RMPeerMessage (PeerMessage e o s)
    {- ^ A message decoded from a peer connection. -}
  | RMJoinRequest JoinRequest (Responder JoinResponse)
    {- ^ A cluster join request received by the join listener. -}
  | RMDebugRuntimeState (Responder Value)
    {- ^ Request a JSON rendering of the runtime state. -}
  | RMDebugPartition
      PartitionKey
      (Responder (Maybe (PartitionPowerState e o s)))
    {- ^ Request the state of a single partition. -}
  | RMEject Peer (Responder ())
    {- ^ Request that a peer be ejected. -}
  | RMDebugIndex (Responder (Set IndexRecord))
    {- ^ Request the full set of index records. -}
  | RMGetDivergent (Responder (Map Peer (Maybe UTCTime)))
    {- ^ Request the divergent peers. -}
  | RMDebugLocalPartitions
      (Responder (Map PartitionKey (PartitionPowerState e o s)))
    {- ^ Request every locally managed partition. -}
  | RMUserRequest PartitionKey e (Responder o)
    {- ^ A user request against the partition with the given key. -}
  | RMUserSearch SearchTag (Responder (Maybe IndexRecord))
    {- ^ A user search for the next index record after the given tag. -}
  deriving Show


{- |
  Start the peer listener, which accepts peer messages from the network
  and sends them to the runtime.

  The listening socket is bound to 'peerBindAddr'. Each accepted
  connection is served on its own thread. That thread decodes a stream
  of 'PeerMessage's and forwards each one to the runtime as an
  'RMPeerMessage'. The connection is closed when that stream ends,
  whether cleanly or with an error; an error is logged as a warning.

  Failure to set up the listener, or an error in the accept loop itself,
  is logged and rethrown.
-}
startPeerListener :: (
      ForkM m,
      LegionConstraints e o s,
      MonadCatch m,
      MonadLoggerIO m
    )
  => RuntimeSettings
  -> Runtime e o s
  -> m ()

startPeerListener RuntimeSettings {peerBindAddr} runtime =
    catchAll (do
        so <- liftIO $ do
          so <- socket (fam peerBindAddr) Stream defaultProtocol
          {- Don't leak the listening socket if it cannot be set up. -}
          catchAll
            (do
              setSocketOption so ReuseAddr 1
              bind so peerBindAddr
              listen so 5
            )
            (\err -> close so >> throwM err)
          return so
        forkC "peer socket acceptor" $ acceptLoop so runtime
      ) (\err -> do
        $(logError) . T.pack
          $ "Couldn't start incomming peer message service, because of: "
          ++ show (err :: SomeException)
        throwM err
      )
  where
    acceptLoop :: (MonadLoggerIO m, LegionConstraints e o s, MonadCatch m)
      => Socket
      -> Runtime e o s
      -> m ()
    acceptLoop so runtime_ =
        catchAll (
          forever $ do
            (conn, _) <- liftIO $ accept so
            {-
              NOTE(review): if getPeerName throws here (e.g. the peer has
              already gone away), the exception escapes to the handler
              below, which rethrows and ends this accept loop. The address
              returned by 'accept' could be used instead.
            -}
            remoteAddr <- liftIO $ getPeerName conn
            void . forkL $ do
              logErrors remoteAddr $ runConduit (
                  sourceSocket conn
                  .| conduitDecode
                  .| awaitForever (liftIO . cast runtime_ . RMPeerMessage)
                )
              {-
                'sourceSocket' never closes the socket itself. 'logErrors'
                has already caught any failure, so this runs after both
                clean and failed streams and keeps descriptors from leaking.
              -}
              liftIO (close conn)
        ) (\err -> do
          $(logError) . T.pack
            $ "error in peer message accept loop: "
            ++ show (err :: SomeException)
          throwM err
        )
      where
        {- Run a connection handler, logging (not rethrowing) any failure. -}
        logErrors :: SockAddr -> LIO () -> LIO ()
        logErrors remoteAddr io = do
          result <- try io
          case result of
            Left err ->
              $(logWarn) . T.pack
                $ "Incomming peer connection (" ++ show remoteAddr
                ++ ") crashed because of: " ++ show (err :: SomeException)
            Right v -> return v


{- |
  Starts the join listener, which accepts cluster join requests from
  the network and sends them to the runtime.

  The listening socket is bound to 'joinBindAddr'. Each accepted
  connection is served on its own thread. Every decoded 'JoinRequest'
  is handed to the runtime, and the runtime's answer is encoded and
  written back on the same connection. The connection is closed when
  its stream ends, whether cleanly or with an error; an error is logged
  as a warning.

  Failure to set up the listener, or an error in the accept loop itself,
  is logged and rethrown.
-}
startJoinListener :: (MonadCatch m, MonadLoggerIO m, ForkM m)
  => RuntimeSettings
  -> Runtime e o s
  -> m ()

startJoinListener RuntimeSettings {joinBindAddr} runtime =
    catchAll (do
        so <- liftIO $ do
          so <- socket (fam joinBindAddr) Stream defaultProtocol
          {- Don't leak the listening socket if it cannot be set up. -}
          catchAll
            (do
              setSocketOption so ReuseAddr 1
              bind so joinBindAddr
              listen so 5
            )
            (\err -> close so >> throwM err)
          return so
        forkC "join socket acceptor" $ acceptLoop so
      ) (\err -> do
        $(logError) . T.pack
          $ "Couldn't start join request service, because of: "
          ++ show (err :: SomeException)
        throwM err
      )
  where
    acceptLoop :: (MonadCatch m, MonadLoggerIO m) => Socket -> m ()
    acceptLoop so =
        catchAll (
          forever $ do
            (conn, _) <- liftIO (accept so)
            void . forkL $ do
              logErrors . liftIO $ runConduit (
                  sourceSocket conn
                  .| conduitDecode
                  .| awaitForever (\req -> liftIO $
                      sendAll conn . encode
                        =<< call runtime (RMJoinRequest req)
                    )
                )
              {-
                'sourceSocket' never closes the socket itself. 'logErrors'
                has already caught any failure, so this runs after both
                clean and failed streams and keeps descriptors from leaking.
              -}
              liftIO (close conn)
        ) (\err -> do
          $(logError) . T.pack
            $ "error in join request accept loop: "
            ++ show (err :: SomeException)
          throwM err
        )
      where
        {- Run a connection handler, logging (not rethrowing) any failure. -}
        logErrors :: (MonadCatch m, MonadLoggerIO m) => m () -> m ()
        logErrors m = do
          result <- try m
          case result of
            Left err ->
              $(logWarn) . T.pack
                $ "Incomming join connection crashed because of: "
                ++ show (err :: SomeException)
            Right v -> return v