{-# LANGUAGE TypeSynonymInstances #-}
{-# LANGUAGE FlexibleInstances #-}
{-# LANGUAGE MultiParamTypeClasses #-}
{-# LANGUAGE GeneralizedNewtypeDeriving #-}
{-# LANGUAGE FlexibleContexts #-}
{-# LANGUAGE TypeFamilies #-}
{-# LANGUAGE RankNTypes #-}
{-# LANGUAGE RecordWildCards #-}
{-# LANGUAGE StandaloneDeriving #-}
{-# LANGUAGE UndecidableInstances #-}
{-# LANGUAGE ScopedTypeVariables #-}

module Examples.Raft.Socket.Node where

import Protolude hiding
  ( MVar, putMVar, takeMVar, newMVar, newEmptyMVar, readMVar
  , atomically, STM(..), Chan, newTVar, readTVar, writeTVar
  , newChan, writeChan, readChan
  , threadDelay, killThread, TVar(..)
  , catch, handle, takeWhile, takeWhile1, (<|>)
  )

import Control.Concurrent.Classy hiding (catch, ThreadId)
import Control.Monad.Fail
import Control.Monad.Catch
import Control.Monad.Trans.Class

import qualified Data.Map as Map
import qualified Data.Serialize as S
import qualified Network.Simple.TCP as NS
import Network.Simple.TCP

import Examples.Raft.Socket.Common

import Raft

--------------------------------------------------------------------------------
-- Network
--------------------------------------------------------------------------------

-- | Shared state a node needs to talk to its peers and clients over TCP.
-- It is the reader environment carried by 'RaftSocketT'.
data NodeSocketEnv v = NodeSocketEnv
  { nsSocket :: Socket
    -- ^ Socket on which 'acceptForkNode' accepts inbound connections.
  , nsPeers :: TVar (STM IO) (Map NodeId Socket)
    -- ^ Outbound sockets keyed by peer id; consulted and updated by 'sendRPC'.
  , nsMsgQueue :: TChan (STM IO) (RPCMessage v)
    -- ^ RPC messages decoded by 'acceptForkNode', drained by 'receiveRPC'.
  , nsClientReqQueue :: TChan (STM IO) (ClientRequest v)
    -- ^ Client requests decoded by 'acceptForkNode', drained by 'receiveClient'.
  }

-- | Monad transformer for the socket-backed Raft example: a newtype over
-- @ReaderT (NodeSocketEnv v) m@, so every action has access to the node's
-- 'NodeSocketEnv'. Unwrap it with 'unRaftSocketT' (or run it with
-- 'runRaftSocketT').
newtype RaftSocketT v m a = RaftSocketT { unRaftSocketT :: ReaderT (NodeSocketEnv v) m a }
  deriving (Functor, Applicative, Monad, MonadIO, MonadFail, MonadReader (NodeSocketEnv v), Alternative, MonadPlus, MonadTrans)

-- The exception and concurrency classes are derived standalone so that the
-- instance context can be stated explicitly as @MonadConc m@.
deriving instance MonadConc m => MonadThrow (RaftSocketT v m)
deriving instance MonadConc m => MonadCatch (RaftSocketT v m)
deriving instance MonadConc m => MonadMask (RaftSocketT v m)
deriving instance MonadConc m => MonadConc (RaftSocketT v m)

--------------------
-- Raft Instances --
--------------------

-- | Replies to a client by opening a connection to the address encoded in
-- its 'ClientId' and sending the serialized message over it.
instance (MonadIO m, MonadConc m, S.Serialize sm) => RaftSendClient (RaftSocketT v m) sm where
  sendClient (ClientId nid) msg =
    connect clientHost clientPort $ \(clientSock, _) ->
      send clientSock (S.encode msg)
    where
      -- The client id doubles as its network address.
      (clientHost, clientPort) = nidToHostPort (toS nid)

-- | Client requests are taken from the 'nsClientReqQueue' channel; the read
-- blocks until a request is available and never produces a 'Left'.
instance (MonadIO m, MonadConc m, S.Serialize v) => RaftRecvClient (RaftSocketT v m) v where
  type RaftRecvClientError (RaftSocketT v m) v = Text
  receiveClient = do
    requestQueue <- asks nsClientReqQueue
    request <- liftIO (atomically (readTChan requestQueue))
    pure (Right request)

-- | Sends an RPC to a peer, reusing the cached outbound socket when there is
-- one and connecting on demand otherwise.
--
-- * No cached socket: connect, send, and cache the new socket on success.
--   If the send fails, the fresh socket is closed instead of being leaked.
-- * Cached socket: send on it; on any exception 'retryConnection' is given
--   the chance to reconnect and resend.
--
-- Failures are swallowed (the message is simply not delivered); nothing is
-- reported to the caller.
instance (MonadIO m, MonadConc m, S.Serialize v, Show v) => RaftSendRPC (RaftSocketT v m) v where
  sendRPC nid msg = do
    tNodeSocketEnvPeers <- asks nsPeers
    nodeSocketPeers <- liftIO $ atomically $ readTVar tNodeSocketEnvPeers
    sockM <- liftIO $
        case Map.lookup nid nodeSocketPeers of
          Nothing -> handle (handleFailure tNodeSocketEnvPeers [nid] Nothing) $ do
            (sock, _) <- connectSock host port
            -- Once connected, a failing send must release the new socket.
            handle (handleFailure tNodeSocketEnvPeers [nid] (Just sock)) $ do
              NS.send sock (S.encode $ RPCMessageEvent msg)
              pure $ Just sock
          Just sock -> handle (retryConnection tNodeSocketEnvPeers nid (Just sock) msg) $ do
            NS.send sock (S.encode $ RPCMessageEvent msg)
            pure $ Just sock
    liftIO $ atomically $ case sockM of
      Nothing -> pure ()
      Just sock -> do
        -- Re-read inside the transaction: writing back the snapshot taken
        -- before the network I/O would discard any concurrent updates.
        currentPeers <- readTVar tNodeSocketEnvPeers
        writeTVar tNodeSocketEnvPeers (Map.insert nid sock currentPeers)
    where
      (host, port) = nidToHostPort nid

-- | RPC messages are taken from the 'nsMsgQueue' channel; the read blocks
-- until a message is available and never produces a 'Left'.
instance (MonadIO m, MonadConc m, Show v) => RaftRecvRPC (RaftSocketT v m) v where
  type RaftRecvRPCError (RaftSocketT v m) v = Text
  receiveRPC = do
    rpcQueue <- asks nsMsgQueue
    rpc <- liftIO (atomically (readTChan rpcQueue))
    pure (Right rpc)


-----------
-- Utils --
-----------

-- | Handles a connection failure by reconnecting to the peer and resending
-- the message once.
--
-- Intended as an exception handler for a send on a cached socket. The failed
-- socket is closed and its entry removed from the peer map before the
-- reconnect, so it is neither leaked nor reused by later sends. If the
-- resend on the new socket fails, the new socket is closed as well.
--
-- Returns @Just@ the new socket when the resend succeeded (the caller is
-- expected to cache it) and @Nothing@ otherwise, including when no socket
-- was supplied. Exceptions raised during the retry are swallowed.
retryConnection
  :: (S.Serialize v, MonadIO m, MonadConc m)
  => TVar (STM m) (Map NodeId Socket) -- ^ shared map of outbound peer sockets
  -> NodeId                           -- ^ peer the message is addressed to
  -> Maybe Socket                     -- ^ socket on which the send failed
  -> RPCMessage v                     -- ^ message to resend
  -> SomeException                    -- ^ exception raised by the failed send
  -> m (Maybe Socket)
retryConnection tNodeSocketEnvPeers nid sockM msg e = case sockM of
  Nothing -> pure Nothing
  Just staleSock ->
    handle (handleFailure tNodeSocketEnvPeers [nid] Nothing) $ do
      -- Release the broken socket and forget it before reconnecting.
      _ <- handleFailure tNodeSocketEnvPeers [nid] (Just staleSock) e
      (freshSock, _) <- connectSock host port
      -- A failing resend must not leak the socket we just opened.
      handle (handleFailure tNodeSocketEnvPeers [nid] (Just freshSock)) $ do
        NS.send freshSock (S.encode $ RPCMessageEvent msg)
        pure $ Just freshSock
  where
    (host, port) = nidToHostPort nid

-- | Cleans up after a failed socket operation.
--
-- When a socket is supplied it is closed and every id in the list is removed
-- from the peer map; with no socket nothing is touched. The result is always
-- @Nothing@, so the function can be used directly as an exception handler
-- for actions that return @Maybe Socket@. The exception itself is ignored.
handleFailure
  :: (MonadIO m, MonadConc m)
  => TVar (STM m) (Map NodeId Socket) -- ^ shared map of outbound peer sockets
  -> [NodeId]                         -- ^ peers whose entries should be dropped
  -> Maybe Socket                     -- ^ socket to close, if any
  -> SomeException                    -- ^ the failure being handled (unused)
  -> m (Maybe Socket)
handleFailure tNodeSocketEnvPeers nids sockM _e = case sockM of
  Nothing -> pure Nothing
  Just sock -> do
    closeSock sock
    -- Read and write in a single transaction, folding over all ids: writing
    -- once per id from the same snapshot would only remove the last one, and
    -- a snapshot taken in an earlier transaction could clobber concurrent
    -- updates.
    atomically $ do
      nodeSocketPeers <- readTVar tNodeSocketEnvPeers
      writeTVar tNodeSocketEnvPeers (foldr Map.delete nodeSocketPeers nids)
    pure Nothing


-- | Runs a 'RaftSocketT' action in the underlying monad, supplying the given
-- environment to the reader layer.
runRaftSocketT :: (MonadIO m, MonadConc m) => NodeSocketEnv v -> RaftSocketT v m a -> m a
runRaftSocketT nodeSocketEnv action =
  runReaderT (unRaftSocketT action) nodeSocketEnv

-- | Forks a thread that accepts inbound connections on 'nsSocket' forever.
--
-- Each accepted connection is served by its own handler, which repeatedly
-- reads up to 4096 bytes, decodes them as a single 'MessageEvent' and pushes
-- the payload onto the matching queue. A handler panics when the remote end
-- closes the connection or sends bytes that fail to decode; the accept loop
-- itself keeps running.
acceptForkNode
  :: forall v m. (S.Serialize v, MonadIO m, MonadConc m)
  => RaftSocketT v m ()
acceptForkNode = do
  NodeSocketEnv{..} <- ask
  let
    decodeEvent :: ByteString -> Either [Char] (MessageEvent v)
    decodeEvent = S.decode

    -- Route a decoded event to the queue for its kind.
    enqueue :: MessageEvent v -> IO ()
    enqueue event = case event of
      ClientRequestEvent req@(ClientRequest _ _) ->
        atomically $ writeTChan nsClientReqQueue req
      RPCMessageEvent rpc ->
        atomically $ writeTChan nsMsgQueue rpc

    -- Serve one inbound connection until it fails.
    serveConn :: Socket -> IO ()
    serveConn conn = forever $ do
      chunkM <- recv conn 4096
      case chunkM of
        Nothing -> panic "Socket was closed on the other end"
        Just chunk -> either (panic . toS) enqueue (decodeEvent chunk)
  void $ fork $ void $ forever $ acceptFork nsSocket (serveConn . fst)

-- | Binds a socket to the given host and port and puts it into listening
-- mode with a backlog of 2048 pending connections.
newSock :: HostName -> ServiceName -> IO Socket
newSock host port = do
  listener <- fst <$> bindSock (Host host) port
  listener <$ listenSock listener 2048