{-# LANGUAGE FlexibleContexts           #-}
{-# LANGUAGE FlexibleInstances          #-}
{-# LANGUAGE GeneralizedNewtypeDeriving #-}
{-# LANGUAGE MultiParamTypeClasses      #-}
{-# LANGUAGE RecordWildCards            #-}
{-# LANGUAGE ScopedTypeVariables        #-}
{-# LANGUAGE UndecidableInstances       #-}
{-# OPTIONS_GHC -Wno-name-shadowing #-}

-- | A node runs sessions on top of a connection ('ConnT').
--
-- A receive loop ('startNodeT') reads packets and routes each one by packet
-- id to a waiting session, or hands it to a user supplied session handler
-- when no session is waiting for it. Outgoing requests ('request',
-- 'requestAndRetry') open a session, send a packet and wait for the reply.
module Metro.Node
  ( NodeEnv
  , NodeMode (..)
  , SessionMode (..)
  , NodeT
  , initEnv
  , withEnv
  , setNodeMode
  , setSessionMode
  , setDefaultSessionTimeout
  , setDefaultSessionTimeout1
  , runNodeT
  , startNodeT
  , startNodeT_
  , withSessionT
  , nodeState
  , stopNodeT
  , env
  , request
  , requestAndRetry
  , newSessionEnv
  , nextSessionId
  , runSessionT_
  , busy

  -- combine node env and conn env
  , NodeEnv1 (..)
  , initEnv1
  , runNodeT1
  , getEnv1
  , getTimer
  , getNodeId
  , getSessionSize
  , getSessionSize1
  ) where

import           Control.Monad              (forever, mzero, void, when)
import           Control.Monad.Reader.Class (MonadReader (ask), asks)
import           Control.Monad.Trans.Class  (MonadTrans (..))
import           Control.Monad.Trans.Maybe  (runMaybeT)
import           Control.Monad.Trans.Reader (ReaderT (..), runReaderT)
import           Data.Hashable
import           Data.Int                   (Int64)
import           Data.Maybe                 (fromMaybe, isJust, maybeToList)
import           Metro.Class                (GetPacketId, RecvPacket,
                                             SendPacket, SetPacketId,
                                             Transport, getPacketId)
import           Metro.Conn                 (ConnEnv, ConnT, FromConn (..),
                                             close, receive, runConnT)
import           Metro.IOHashMap            (IOHashMap, newIOHashMap)
import qualified Metro.IOHashMap            as HM (delete, elems, insert,
                                                   lookup, size)
import           Metro.Session              (SessionEnv (sessionId), SessionT,
                                             feed, isTimeout, runSessionT)
import qualified Metro.Session              as S (newSessionEnv, receive, send)
import           Metro.Utils                (getEpochTime)
import           System.Log.Logger          (errorM)
import           UnliftIO
import           UnliftIO.Concurrent        (threadDelay)

-- | How many sessions a node may hold at once.
--
-- * 'Single': at most one session, stored in 'nodeSession'.
-- * 'Multi': any number of sessions, stored in 'sessionList' keyed by id.
data NodeMode = Single | Multi
  deriving (Show, Eq)

-- | What happens to a session created for an unsolicited incoming packet.
--
-- * 'SingleAction': the handler only sees the packet that created it.
-- * 'MultiAction': the session is registered, so later packets with the
--   same id are fed to it while the handler runs.
data SessionMode = SingleAction | MultiAction
  deriving (Show, Eq)

-- | Node state shared by the receive loop, the session checker and requests.
data NodeEnv u nid k rpkt = NodeEnv
  { uEnv        :: u
  -- ^ user environment, returned by 'env'
  , nodeStatus  :: TVar Bool
  -- ^ 'True' while running; set to 'False' by 'stopNodeT'
  , nodeMode    :: NodeMode
  , sessionMode :: SessionMode
  , nodeSession :: TVar (Maybe (SessionEnv u nid k rpkt))
  -- ^ the single session slot; only used in 'Single' mode
  , sessionList :: IOHashMap k (SessionEnv u nid k rpkt)
  -- ^ sessions keyed by id; only written in 'Multi' mode
  , sessionGen  :: IO k
  -- ^ generator for ids of locally initiated sessions
  , nodeTimer   :: TVar Int64
  -- ^ 'getEpochTime' of the last received packet
  , nodeId      :: nid
  , sessTimeout :: TVar Int64
  -- ^ default session timeout passed to 'S.newSessionEnv'
  --   (unit follows Metro.Session; presumably seconds -- TODO confirm)
  , onNodeLeave :: TVar (Maybe (u -> IO ()))
  -- ^ not read within this module
  }

-- | A node environment paired with the connection environment it runs on.
data NodeEnv1 u nid k rpkt tp = NodeEnv1
  { nodeEnv :: NodeEnv u nid k rpkt
  , connEnv :: ConnEnv tp
  }

-- | Node monad: a reader over 'NodeEnv' stacked on the connection monad.
newtype NodeT u nid k rpkt tp m a = NodeT
  { unNodeT :: ReaderT (NodeEnv u nid k rpkt) (ConnT tp m) a
  }
  deriving
    ( Functor
    , Applicative
    , Monad
    , MonadIO
    , MonadReader (NodeEnv u nid k rpkt)
    )

instance MonadUnliftIO m => MonadUnliftIO (NodeT u nid k rpkt tp m) where
  withRunInIO inner = NodeT $
    ReaderT $ \r ->
      withRunInIO $ \run ->
        inner (run . runNodeT r)

instance MonadTrans (NodeT u nid k rpkt tp) where
  lift = NodeT . lift . lift

instance FromConn (NodeT u nid k rpkt) where
  fromConn = NodeT . lift

-- | Run a node action in the connection monad with the given environment.
runNodeT :: NodeEnv u nid k rpkt -> NodeT u nid k rpkt tp m a -> ConnT tp m a
runNodeT nEnv = flip runReaderT nEnv . unNodeT

-- | Run a node action down to the base monad using a combined environment.
runNodeT1 :: NodeEnv1 u nid k rpkt tp -> NodeT u nid k rpkt tp m a -> m a
runNodeT1 NodeEnv1 {..} = runConnT connEnv . runNodeT nodeEnv

-- | Create a fresh node environment.
--
-- Defaults: running status, 'Multi' node mode, 'SingleAction' session mode,
-- empty session slot and list, session timeout of 300, and the timer set to
-- the current 'getEpochTime'.
initEnv :: MonadIO m => u -> nid -> IO k -> m (NodeEnv u nid k rpkt)
initEnv uEnv nodeId sessionGen = do
  nodeStatus  <- newTVarIO True
  nodeSession <- newTVarIO Nothing
  sessionList <- newIOHashMap
  nodeTimer   <- newTVarIO =<< getEpochTime
  onNodeLeave <- newTVarIO Nothing
  sessTimeout <- newTVarIO 300
  pure NodeEnv
    { nodeMode    = Multi
    , sessionMode = SingleAction
    , ..
    }

-- | Run a node action with 'uEnv' replaced by the given value.
withEnv :: (Monad m) => u -> NodeT u nid k rpkt tp m a -> NodeT u nid k rpkt tp m a
withEnv u m = do
  env0 <- ask
  fromConn $ runNodeT (env0 {uEnv = u}) m

-- | Set the node mode of an environment.
setNodeMode :: NodeMode -> NodeEnv u nid k rpkt -> NodeEnv u nid k rpkt
setNodeMode mode nodeEnv = nodeEnv {nodeMode = mode}

-- | Set the session mode of an environment.
setSessionMode :: SessionMode -> NodeEnv u nid k rpkt -> NodeEnv u nid k rpkt
setSessionMode mode nodeEnv = nodeEnv {sessionMode = mode}

-- | Replace the default-timeout variable.
--
-- The 'TVar' itself is installed, so it can be shared between environments.
setDefaultSessionTimeout :: TVar Int64 -> NodeEnv u nid k rpkt -> NodeEnv u nid k rpkt
setDefaultSessionTimeout t nodeEnv = nodeEnv { sessTimeout = t }

-- | Write a new default session timeout into the combined environment.
setDefaultSessionTimeout1 :: MonadIO m => NodeEnv1 u nid k rpkt tp -> Int64 -> m ()
setDefaultSessionTimeout1 NodeEnv1 {..} = atomically . writeTVar (sessTimeout nodeEnv)

-- | Build a combined environment.
--
-- Creates a node environment with 'initEnv', post-processes it with
-- @mapEnv@ (e.g. 'setNodeMode'), and pairs it with the connection
-- environment.
initEnv1
  :: MonadIO m
  => (NodeEnv u nid k rpkt -> NodeEnv u nid k rpkt)
  -> ConnEnv tp -> u -> nid -> IO k -> m (NodeEnv1 u nid k rpkt tp)
initEnv1 mapEnv connEnv uEnv nid gen = do
  nodeEnv <- mapEnv <$> initEnv uEnv nid gen
  return NodeEnv1 {..}

-- | Capture the current node and connection environments.
getEnv1 :: (Monad m, Transport tp) => NodeT u nid k rpkt tp m (NodeEnv1 u nid k rpkt tp)
getEnv1 = do
  connEnv <- fromConn ask
  nodeEnv <- ask
  return NodeEnv1 {..}

-- | Run a session action on the node's connection.
runSessionT_ :: Monad m => SessionEnv u nid k rpkt -> SessionT u nid k rpkt tp m a -> NodeT u nid k rpkt tp m a
runSessionT_ aEnv = fromConn . runSessionT aEnv

-- | Run a session action inside a fresh, registered session.
--
-- The session gets a new id from 'nextSessionId'. Its timeout is the given
-- value, or the node default when 'Nothing'. The session is unregistered
-- when the action finishes or throws.
withSessionT
  :: (MonadUnliftIO m, Eq k, Hashable k)
  => Maybe Int64 -> SessionT u nid k rpkt tp m a -> NodeT u nid k rpkt tp m a
withSessionT sTout sessionT =
  bracket nextSessionId removeSession $ \sid -> do
    aEnv <- newSessionEnv sTout sid
    runSessionT_ aEnv sessionT

-- | Create a session with id @sid@ and register it with the node.
--
-- * 'Multi' mode: the session is inserted into 'sessionList'.
-- * 'Single' mode: blocks (STM retry) while another session occupies the
--   slot and the node is running. If the node has stopped, the session is
--   returned without being registered.
newSessionEnv
  :: (MonadIO m, Eq k, Hashable k)
  => Maybe Int64 -> k -> NodeT u nid k rpkt tp m (SessionEnv u nid k rpkt)
newSessionEnv sTout sid = do
  NodeEnv{..} <- ask
  dTout <- readTVarIO sessTimeout
  sEnv <- S.newSessionEnv uEnv nodeId sid (fromMaybe dTout sTout) []
  case nodeMode of
    Single -> atomically $ do
      sess <- readTVar nodeSession
      case sess of
        Nothing -> writeTVar nodeSession $ Just sEnv
        Just _  -> do
          state <- readTVar nodeStatus
          when state retrySTM
    Multi -> HM.insert sessionList sid sEnv
  return sEnv

-- | Draw a new session id from the node's generator.
nextSessionId :: MonadIO m => NodeT u nid k rpkt tp m k
nextSessionId = liftIO =<< asks sessionGen

-- | Unregister the session with id @mid@.
--
-- In 'Single' mode the slot is cleared only if it still holds the session
-- with that id, so an unregistered or already-replaced session can never
-- evict another request's session.
removeSession :: (MonadIO m, Eq k, Hashable k) => k -> NodeT u nid k rpkt tp m ()
removeSession mid = do
  NodeEnv{..} <- ask
  case nodeMode of
    Single -> atomically $ do
      cur <- readTVar nodeSession
      when (maybe False ((== mid) . sessionId) cur) $
        writeTVar nodeSession Nothing
    Multi -> HM.delete sessionList mid

-- | Whether the node can accept no further sessions right now.
--
-- Only a 'Single' node can be busy (its slot is occupied); a 'Multi' node
-- always reports 'False'.
busy :: MonadIO m => NodeT u nid k rpkt tp m Bool
busy = do
  NodeEnv{..} <- ask
  case nodeMode of
    Single -> isJust <$> readTVarIO nodeSession
    Multi  -> return False

-- | Run one 'mainLoop' step; any exception stops the node.
tryMainLoop
  :: (MonadUnliftIO m, Transport tp, RecvPacket rpkt, GetPacketId k rpkt, Eq k, Hashable k)
  => (rpkt -> m Bool) -> SessionT u nid k rpkt tp m () -> NodeT u nid k rpkt tp m ()
tryMainLoop preprocess sessionHandler = do
  r <- tryAny $ mainLoop preprocess sessionHandler
  case r of
    Left _  -> stopNodeT
    Right _ -> pure ()

-- | Handle one incoming packet.
--
-- Receives a packet, records the receive time in 'nodeTimer', and runs
-- @preprocess@. If that returns 'True', dispatches the packet to
-- 'tryDoFeed' in a detached 'async'.
mainLoop
  :: (MonadUnliftIO m, Transport tp, RecvPacket rpkt, GetPacketId k rpkt, Eq k, Hashable k)
  => (rpkt -> m Bool) -> SessionT u nid k rpkt tp m () -> NodeT u nid k rpkt tp m ()
mainLoop preprocess sessionHandler = do
  rpkt <- fromConn receive
  setTimer =<< getEpochTime
  r <- lift $ preprocess rpkt
  when r $ void . async $ tryDoFeed rpkt sessionHandler

-- | 'doFeed' that logs (rather than propagates) synchronous exceptions.
tryDoFeed
  :: (MonadUnliftIO m, Transport tp, GetPacketId k rpkt, Eq k, Hashable k)
  => rpkt -> SessionT u nid k rpkt tp m () -> NodeT u nid k rpkt tp m ()
tryDoFeed rpkt sessionHandler = do
  r <- tryAny $ doFeed rpkt sessionHandler
  case r of
    Left e  -> liftIO $ errorM "Metro.Node" $ "DoFeed Error: " ++ show e
    Right _ -> pure ()

-- | Route an incoming packet.
--
-- If a session is waiting for it (the slot in 'Single' mode, the entry for
-- its packet id in 'Multi' mode), the packet is fed to that session.
-- Otherwise a new session seeded with the packet runs @sessionHandler@. In
-- 'MultiAction' mode that session is registered for the handler's lifetime.
-- Registration and removal are paired by 'bracket_', so the entry cannot
-- leak.
doFeed
  :: (MonadUnliftIO m, GetPacketId k rpkt, Eq k, Hashable k)
  => rpkt -> SessionT u nid k rpkt tp m () -> NodeT u nid k rpkt tp m ()
doFeed rpkt sessionHandler = do
  NodeEnv{..} <- ask
  let sid = getPacketId rpkt
  existing <- case nodeMode of
    Single -> readTVarIO nodeSession
    Multi  -> HM.lookup sessionList sid
  case existing of
    Just aEnv -> runSessionT_ aEnv $ feed $ Just rpkt
    Nothing -> do
      dTout <- readTVarIO sessTimeout
      sEnv <- S.newSessionEnv uEnv nodeId sid dTout [Just rpkt]
      let register = case nodeMode of
            -- NOTE(review): Single mode overwrites the slot unconditionally,
            -- as before; a request registered concurrently would be replaced.
            Single -> atomically $ writeTVar nodeSession $ Just sEnv
            Multi  -> HM.insert sessionList sid sEnv
          handle = runSessionT_ sEnv sessionHandler
      if sessionMode == MultiAction
        then bracket_ register (removeSession sid) handle
        else handle

-- | Run the node's receive loop with no packet preprocessing.
startNodeT
  :: (MonadUnliftIO m, Transport tp, RecvPacket rpkt, GetPacketId k rpkt, Eq k, Hashable k)
  => SessionT u nid k rpkt tp m () -> NodeT u nid k rpkt tp m ()
startNodeT = startNodeT_ (const $ return True)

-- | Run the node's receive loop.
--
-- Each packet is passed through @preprocess@; only packets for which it
-- returns 'True' are routed. The loop runs until the node is stopped. The
-- session-timeout checker runs alongside and is cancelled even if the loop
-- is interrupted. After a normal stop, every remaining session is fed
-- 'Nothing'.
startNodeT_
  :: (MonadUnliftIO m, Transport tp, RecvPacket rpkt, GetPacketId k rpkt, Eq k, Hashable k)
  => (rpkt -> m Bool) -> SessionT u nid k rpkt tp m () -> NodeT u nid k rpkt tp m ()
startNodeT_ preprocess sessionHandler = do
  checker <- runCheckSessionState
  loop `finally` cancel checker
  doFeedError
  where
    -- 'mzero' ends the MaybeT 'forever' once the node is no longer alive.
    loop = void . runMaybeT . forever $ do
      alive <- lift nodeState
      if alive
        then lift $ tryMainLoop preprocess sessionHandler
        else mzero

-- | Whether the node is still running.
nodeState :: MonadIO m => NodeT u nid k rpkt tp m Bool
nodeState = readTVarIO =<< asks nodeStatus

-- | Feed 'Nothing' to every session still known to the node.
--
-- This covers the 'Single' mode slot as well as 'sessionList'.
doFeedError :: MonadIO m => NodeT u nid k rpkt tp m ()
doFeedError = do
  NodeEnv{..} <- ask
  single <- readTVarIO nodeSession
  listed <- HM.elems sessionList
  mapM_ go (maybeToList single ++ listed)
  where
    go :: MonadIO m => SessionEnv u nid k rpkt -> NodeT u nid k rpkt tp m ()
    go aEnv = runSessionT_ aEnv $ feed Nothing

-- | Mark the node as stopped and close its connection.
stopNodeT :: (MonadIO m, Transport tp) => NodeT u nid k rpkt tp m ()
stopNodeT = do
  st <- asks nodeStatus
  atomically $ writeTVar st False
  fromConn close

-- | The user environment of the node.
env :: Monad m => NodeT u nid k rpkt tp m u
env = asks uEnv

-- | Send a packet and wait for its reply, without resending.
request
  :: (MonadUnliftIO m, Transport tp, SendPacket spkt, SetPacketId k spkt, Eq k, Hashable k)
  => Maybe Int64 -> spkt -> NodeT u nid k rpkt tp m (Maybe rpkt)
request sTout = requestAndRetry sTout Nothing

-- | Send a packet in a new session and wait for its reply.
--
-- @sTout@ overrides the default session timeout. When @retryTout@ is
-- @Just t@, the packet is resent every @t@ seconds until a reply arrives.
-- The resender is scoped by 'withAsync', so it stops when the receive
-- returns or throws. Returns 'Nothing' immediately if the node is stopped.
requestAndRetry
  :: (MonadUnliftIO m, Transport tp, SendPacket spkt, SetPacketId k spkt, Eq k, Hashable k)
  => Maybe Int64 -> Maybe Int -> spkt -> NodeT u nid k rpkt tp m (Maybe rpkt)
requestAndRetry sTout retryTout spkt = do
  alive <- nodeState
  if not alive
    then return Nothing
    else withSessionT sTout $ do
      S.send spkt
      case retryTout of
        Nothing   -> S.receive
        Just tout ->
          withAsync
            (forever (threadDelay (tout * 1000 * 1000) >> S.send spkt))
            (const S.receive)

-- | 'getEpochTime' at which the last packet was received.
getTimer :: MonadIO m => NodeT u nid k rpkt tp m Int64
getTimer = readTVarIO =<< asks nodeTimer

-- | Overwrite the node timer.
setTimer :: MonadIO m => Int64 -> NodeT u nid k rpkt tp m ()
setTimer t = do
  v <- asks nodeTimer
  atomically $ writeTVar v t

-- | The node's id.
getNodeId :: Monad m => NodeT u nid k rpkt tp m nid
getNodeId = asks nodeId

-- | Start a background thread that expires timed-out sessions.
--
-- Every 10 seconds each session in 'sessionList' is checked with
-- 'isTimeout'; expired sessions are fed 'Nothing' and removed.
-- NOTE(review): the 'Single' mode slot ('nodeSession') is not checked here.
runCheckSessionState :: (MonadUnliftIO m, Eq k, Hashable k) => NodeT u nid k rpkt tp m (Async ())
runCheckSessionState = do
  sessList <- asks sessionList
  async . forever $ do
    threadDelay $ 1000 * 1000 * 10 -- 10 seconds
    mapM_ (checkAlive sessList) =<< HM.elems sessList
  where
    checkAlive
      :: (MonadUnliftIO m, Eq k, Hashable k)
      => IOHashMap k (SessionEnv u nid k rpkt)
      -> SessionEnv u nid k rpkt
      -> NodeT u nid k rpkt tp m ()
    checkAlive sessList sessEnv = runSessionT_ sessEnv $ do
      to <- isTimeout
      when to $ do
        feed Nothing
        HM.delete sessList (sessionId sessEnv)

-- | Number of sessions in 'sessionList'.
--
-- In 'Single' mode the slot session is not counted.
getSessionSize :: MonadIO m => NodeEnv u nid k rpkt -> m Int
getSessionSize NodeEnv {..} = HM.size sessionList

-- | 'getSessionSize' for a combined environment.
getSessionSize1 :: MonadIO m => NodeEnv1 u nid k rpkt tp -> m Int
getSessionSize1 NodeEnv1 {..} = getSessionSize nodeEnv