{-# LANGUAGE DeriveGeneric #-}
{-# LANGUAGE GeneralizedNewtypeDeriving #-}
{-# LANGUAGE RankNTypes #-}
{-# LANGUAGE RecordWildCards #-}
{-# LANGUAGE ScopedTypeVariables #-}
-- | Distribute jobs from a 'QueueBackend' over TCP to remote worker nodes.
--
-- A work master ('launchWorkMaster') owns the queue and hands out one job at a
-- time to each authenticated worker node ('launchWorkNode'). Every frame on the
-- wire is a 'Message' tagged with an 'AppVersion'; master and nodes must agree
-- on that version and on the shared 'AuthToken'.
module Data.PowerQueue.Worker.Distributed
    ( AuthToken(..), AppVersion(..)
    , WorkMasterConfig(..), ServerErrorEvent(..), launchWorkMaster
    , WorkNodeConfig(..), ClientErrorEvent(..), launchWorkNode, launchReconnectingWorkNode
    )
where

import Control.Exception
import Control.Monad.Trans
import Data.Conduit
import Data.Conduit.Cereal
import Data.Conduit.Network
import Data.IORef
import Data.PowerQueue
import Data.String
import Data.Time.TimeSpan
import Data.Word
import GHC.Generics
import System.IO (hPrint, stderr)
import qualified Data.ByteString as BS
import qualified Data.Serialize as S
import qualified Data.Text as T
import qualified Data.Text.Encoding as T

-- | Shared secret a work node presents to the work master.
newtype AuthToken
    = AuthToken { unAuthToken :: T.Text }
    deriving (Show, Eq)

-- | Encoded as a length-prefixed UTF-8 'BS.ByteString'. Tokens arrive from
-- unauthenticated peers, so invalid UTF-8 makes decoding fail (a 'Left' from
-- 'S.decode') instead of throwing an exception from pure code later on.
instance S.Serialize AuthToken where
    put (AuthToken t) = S.put $ T.encodeUtf8 t
    get =
        do bytes <- S.get
           case T.decodeUtf8' bytes of
             Left err -> fail $ "AuthToken: invalid UTF-8: " ++ show err
             Right txt -> pure $ AuthToken txt

-- | Application version; master and nodes only talk to peers with the same one.
newtype AppVersion
    = AppVersion { unAppVersion :: Word64 }
    deriving (Show, Eq, Ord)

-- | Encoded as a little-endian 'Word64'.
instance S.Serialize AppVersion where
    put (AppVersion av) = S.putWord64le av
    get = AppVersion <$> S.getWord64le

-- | A single wire frame: the sender's 'AppVersion' followed by a
-- length-prefixed, cereal-encoded payload ('ClientPayload' or 'ServerPayload').
data Message
    = Message
    { m_version :: !AppVersion
    , m_payload :: !BS.ByteString
    }

instance S.Serialize Message where
    put msg =
        do S.put (m_version msg)
           S.putWord64le (fromIntegral $ BS.length $ m_payload msg)
           S.putByteString (m_payload msg)
    get =
        do vers <- S.get
           -- NOTE(review): this length is peer-controlled and not bounded here;
           -- consider capping it to limit buffering for unauthenticated peers.
           bsLen <- S.getWord64le
           bs <- S.getByteString (fromIntegral bsLen)
           pure $ Message vers bs

-- | Requests a work node sends to the work master.
data ClientPayload
    = CpAuth !AuthToken -- ^ authenticate this connection
    | CpDrain           -- ^ ask for the next job
    | CpCompleted       -- ^ the current job succeeded; confirm it
    | CpRollback        -- ^ the current job failed or wants a retry; roll it back
    deriving (Generic)

instance S.Serialize ClientPayload

-- | Replies the work master sends to a work node.
data ServerPayload j
    = SpJob !j    -- ^ a job to execute (reply to 'CpDrain')
    | SpHello     -- ^ authentication succeeded
    | SpBadAuth   -- ^ wrong token, or a request that requires authentication
    | SpBadState  -- ^ request does not fit the connection's current job state
    deriving (Generic)

instance S.Serialize j => S.Serialize (ServerPayload j)

-- | Work master configuration
data WorkMasterConfig
    = WorkMasterConfig
    { wmc_host :: !T.Text
    , wmc_port :: !Int
    , wmc_authToken :: !AuthToken
      -- ^ reject all clients that do not send this token. See 'wnc_authToken'
    , wmc_appVersion :: !AppVersion
      -- ^ required app version for all clients. See 'wnc_appVersion'
    , wmc_errorHook :: !(ServerErrorEvent -> IO ())
      -- ^ a (non-)critical error occurred. Useful for logging
    }

-- | Work master errors
data ServerErrorEvent
    = SeeClientDisconnect
      -- ^ the client closed the connection
    | SeeClientBadVersion !AppVersion
      -- ^ the client sent a message tagged with this (mismatching) version
    | SeeInvalidPayload !String
      -- ^ a message payload could not be decoded; carries the decoder's error
    deriving (Show, Eq)

-- | The job a client currently holds, together with the queue actions that
-- settle its transaction.
data PendingJob
    = PendingJob
    { pj_rollback :: !(IO ())
    , pj_confirm :: !(IO ())
    }

-- | Per-connection state of the work master.
data ServerCliState
    = ServerCliState
    { scs_isAuthed :: !Bool
      -- ^ whether the client has presented the correct 'AuthToken'
    , scs_pendingJob :: !(IORef (Maybe PendingJob))
      -- ^ job handed out to this client and not yet settled. Kept in an 'IORef'
      -- rather than in pure state so the connection's cleanup handler can still
      -- roll it back when the connection dies with an exception.
    }

-- | Launch a work master on the current thread. It distributes all work from
-- the given queue to worker nodes that connect via 'launchWorkNode'.
--
-- A connection must authenticate with 'wmc_authToken' before it may drain
-- jobs. A client holds at most one job at a time: asking for another one while
-- a job is outstanding is answered with 'SpBadState'. The job is confirmed on
-- 'CpCompleted' and rolled back on 'CpRollback'. It is also rolled back
-- whenever the connection ends while the job is still outstanding, whether
-- cleanly, on a protocol error, or through an exception.
launchWorkMaster :: forall j. S.Serialize j => WorkMasterConfig -> QueueBackend j -> IO ()
launchWorkMaster wmc QueueBackend{..} =
    runTCPServer tcpSettings serveClient
    where
      tcpSettings =
          serverSettings (wmc_port wmc) (fromString $ T.unpack $ wmc_host wmc)

      -- Run the protocol on one connection. The job the client still holds
      -- when the pipeline ends, for whatever reason, is rolled back.
      serveClient app =
          do pendingRef <- newIORef Nothing
             let initCliSt =
                     ServerCliState
                     { scs_isAuthed = False
                     , scs_pendingJob = pendingRef
                     }
             (appSource app .| conduitGet2 S.get .| handleMessage initCliSt .| conduitPut S.put $$ appSink app)
                 `finally` rollbackPending pendingRef

      -- Atomically take the outstanding job (if any) so it is settled at most once.
      takePending ref = atomicModifyIORef' ref (\pending -> (Nothing, pending))

      rollbackPending ref = takePending ref >>= mapM_ pj_rollback

      -- Report an error. Every caller stops processing afterwards, which ends
      -- the connection; 'serveClient' then rolls back any outstanding job.
      evt e = liftIO $ wmc_errorHook wmc e

      srvSend :: Monad m => ServerPayload j -> Conduit a m Message
      srvSend payload =
          yield Message
          { m_version = wmc_appVersion wmc
          , m_payload = S.encode payload
          }

      handleMessage cliSt =
          do mMsg <- await
             case mMsg of
               Nothing -> evt SeeClientDisconnect
               Just message
                   | m_version message /= wmc_appVersion wmc ->
                       evt $ SeeClientBadVersion (m_version message)
                   | otherwise ->
                       case S.decode (m_payload message) of
                         Left errMsg -> evt $ SeeInvalidPayload errMsg
                         Right cliPayload -> handleCliPayload cliSt cliPayload

      handleCliPayload cliSt cliPayload =
          case cliPayload of
            CpAuth tok ->
                do let authOk = tok == wmc_authToken wmc
                   srvSend $ if authOk then SpHello else SpBadAuth
                   handleMessage $ cliSt { scs_isAuthed = authOk }
            CpDrain
                | scs_isAuthed cliSt ->
                    do outstanding <- liftIO $ readIORef (scs_pendingJob cliSt)
                       case outstanding of
                         Just _ ->
                             -- Handing out a second job would overwrite (and so
                             -- orphan) the transaction of the first one.
                             srvSend SpBadState
                         Nothing ->
                             do (txId, job) <- liftIO $ qb_lift qb_dequeue
                                let pending =
                                        PendingJob
                                        { pj_rollback = qb_lift (qb_rollback txId)
                                        , pj_confirm = qb_lift (qb_confirm txId)
                                        }
                                -- Record the job before sending it, so that a
                                -- failing send still rolls it back.
                                liftIO $ writeIORef (scs_pendingJob cliSt) (Just pending)
                                srvSend $ SpJob job
                       handleMessage cliSt
            CpCompleted | scs_isAuthed cliSt -> settle cliSt pj_confirm
            CpRollback | scs_isAuthed cliSt -> settle cliSt pj_rollback
            _ ->
                do srvSend SpBadAuth
                   handleMessage cliSt

      -- Settle the outstanding job with the chosen action, or answer
      -- 'SpBadState' when the client holds no job.
      settle cliSt action =
          do outstanding <- liftIO $ takePending (scs_pendingJob cliSt)
             case outstanding of
               Nothing -> srvSend SpBadState
               Just pending -> liftIO $ action pending
             handleMessage cliSt

-- | Work node configuration
data WorkNodeConfig
    = WorkNodeConfig
    { wnc_hostMaster :: !T.Text
      -- ^ host where the work master is running. See 'wmc_host'
    , wnc_portMaster :: !Int
      -- ^ port of work master. See 'wmc_port'
    , wnc_authToken :: !AuthToken
      -- ^ the authentication token. MUST match the master's 'wmc_authToken'!
    , wnc_appVersion :: !AppVersion
      -- ^ the current app version. MUST match the master's 'wmc_appVersion'!
    , wnc_errorHook :: !(ClientErrorEvent -> IO ())
      -- ^ a (non-)critical error occurred. Useful for logging
    , wnc_readyHook :: !(IO ())
      -- ^ called once per connection, after successful authentication, when
      -- the node is ready for draining
    }

-- | Work node async errors
data ClientErrorEvent
    = CeeConnClosed
      -- ^ the connection to the master was closed
    | CeeBadAuthResponse
      -- ^ the master rejected the 'AuthToken' ('SpBadAuth')
    | CeeInvalidAuthResponse
      -- ^ unexpected reply to the authentication request
    | CeeInvalidDrainResponse
      -- ^ unexpected reply to a request for work
    | CeeServerBadVersion !AppVersion
      -- ^ the master sent a message tagged with this (mismatching) version
    | CeeInvalidPayload !String
      -- ^ a message payload could not be decoded; carries the decoder's error
    | CeeWorkerException !String
      -- ^ executing a job threw; carries the shown exception. The job is rolled back
    deriving (Show, Eq)

-- | Per-connection state of a work node.
data ClientState
    = ClientState
    { cs_authed :: !Bool
    }

-- | Launch a worker node on the current thread that connects to a work master
-- launched with 'launchWorkMaster'.
--
-- The node authenticates, then repeatedly drains and executes jobs. It returns
-- once the connection ends or after a protocol error was reported through
-- 'wnc_errorHook'. Synchronous exceptions thrown by a job are reported as
-- 'CeeWorkerException' and the job is rolled back. Asynchronous exceptions
-- (e.g. 'killThread') are rethrown.
launchWorkNode :: forall j. S.Serialize j => WorkNodeConfig -> QueueWorker j -> IO ()
launchWorkNode wnc QueueWorker{..} =
    runTCPClient tcpSettings $ \app ->
    appSource app .| conduitGet2 S.get .| handleMessage initCliSt .| conduitPut S.put $$ appSink app
    where
      tcpSettings =
          clientSettings (wnc_portMaster wnc) (T.encodeUtf8 $ wnc_hostMaster wnc)

      initCliSt = ClientState { cs_authed = False }

      evt _ e = liftIO $ wnc_errorHook wnc e

      cliSend :: Monad m => ClientPayload -> Conduit a m Message
      cliSend payload =
          yield Message
          { m_version = wnc_appVersion wnc
          , m_payload = S.encode payload
          }

      -- Wait for the next server message and hand its decoded payload to the
      -- continuation; on close, version mismatch or bad payload, report and stop.
      awaitSrv ::
          MonadIO m
          => ClientState
          -> (ServerPayload j -> Conduit Message m Message)
          -> Conduit Message m Message
      awaitSrv cliSt go =
          do msg <- await
             case msg of
               Nothing -> evt cliSt CeeConnClosed
               Just (Message msgVer msgBsl)
                   | msgVer == wnc_appVersion wnc ->
                       case S.decode msgBsl of
                         Left err -> evt cliSt $ CeeInvalidPayload err
                         Right ok -> go ok
                   | otherwise -> evt cliSt $ CeeServerBadVersion msgVer

      handleMessage cliSt
          | not (cs_authed cliSt) =
              do cliSend (CpAuth $ wnc_authToken wnc)
                 awaitSrv cliSt $ \msg ->
                     case msg of
                       SpHello ->
                           do liftIO $ wnc_readyHook wnc
                              handleMessage $ cliSt { cs_authed = True }
                       -- the master answers a wrong token with 'SpBadAuth'
                       SpBadAuth -> evt cliSt CeeBadAuthResponse
                       _ -> evt cliSt CeeInvalidAuthResponse
          | otherwise = workLoop cliSt

      workLoop cliSt =
          do cliSend CpDrain
             awaitSrv cliSt $ \msg ->
                 case msg of
                   SpJob job ->
                       do execRes <- liftIO $ try $ qw_execute job
                          case execRes of
                            Left (e :: SomeException)
                                -- Never swallow async exceptions (e.g. killThread):
                                -- rethrow them; the master rolls the job back once
                                -- the connection drops.
                                | Just (_ :: SomeAsyncException) <- fromException e ->
                                    liftIO $ throwIO e
                                | otherwise ->
                                    do evt cliSt $ CeeWorkerException (show e)
                                       cliSend CpRollback
                            Right res ->
                                case res of
                                  JOk -> cliSend CpCompleted
                                  JRetry -> cliSend CpRollback
                          workLoop cliSt
                   SpBadAuth -> evt cliSt CeeBadAuthResponse
                   _ -> evt cliSt CeeInvalidDrainResponse

-- | Run 'launchWorkNode' in a loop, reconnecting whenever the connection ends.
--
-- The delay starts at 200ms. It doubles after every attempt that never got
-- ready for draining, and resets to 200ms after one that did. The outcome of
-- each attempt is printed to stderr. Asynchronous exceptions (e.g.
-- 'killThread') stop the loop instead of triggering a reconnect.
launchReconnectingWorkNode ::
    forall j. S.Serialize j
    => WorkNodeConfig
    -> (TimeSpan -> IO ()) -- ^ callback: will retry in 'TimeSpan'. Useful for logging
    -> QueueWorker j
    -> IO ()
launchReconnectingWorkNode wnc retryCallback qw =
    reconnStep initialDelay
    where
      initialDelay = milliseconds 200

      reconnStep ts =
          do conGood <- newIORef False
             let wnc' = wnc { wnc_readyHook = writeIORef conGood True >> wnc_readyHook wnc }
             (e :: Either SomeException ()) <- try $ launchWorkNode wnc' qw
             case e of
               Left err
                   | Just (_ :: SomeAsyncException) <- fromException err -> throwIO err
               _ -> hPrint stderr e
             retryCallback ts
             sleepTS ts
             wasGood <- readIORef conGood
             reconnStep $ if wasGood then initialDelay else multiplyTS ts 2