module System.CircuitBreaker (
CircuitBreaker,
CircuitBreakerConf(..),
withBreaker,
CircutBreakerError(..),
HasCircuitConf(..),
CircuitState(..),
CircuitAction(..),
ErrorThreshold(..),
CBCondition(..),
initialBreakerState,
breakerTransitionGuard,
breakerTryPerformAction,
decrementErrorCount
) where
import Control.Monad (forever, void)
import Data.Maybe (fromMaybe, isNothing)
import Data.Proxy (Proxy(..))
import GHC.TypeLits (Symbol, KnownSymbol, symbolVal, Nat, KnownNat, natVal)
import Numeric.Natural (Natural)

import Control.Monad.IO.Unlift (MonadUnliftIO)
import Control.Monad.Reader (MonadReader, asks, ReaderT, ask)
import Control.Monad.Trans (liftIO)
import qualified Data.HashMap.Strict as M
import qualified Data.Text as T
import UnliftIO.Concurrent (forkIO, threadDelay)
import UnliftIO.Exception (bracketOnError, catchDeep)
import UnliftIO.MVar (MVar, modifyMVar, modifyMVar_, newMVar, putMVar, swapMVar, takeMVar)
-- | Shared registry of circuit breakers, keyed by breaker label.
--
-- The outer 'MVar' guards the map itself. Each breaker's state lives in its
-- own inner 'MVar', so the registry lock is only needed to look up or
-- register a breaker, not to update one.
newtype CircuitBreakerConf = CBConf
  { cbBreakers :: MVar (M.HashMap T.Text (MVar CircuitState))
  -- ^ Label -> per-breaker state cell.
  }
-- | Create an empty breaker registry.
--
-- Individual breakers are added lazily by 'withBreaker' the first time each
-- label is used.
initialBreakerState :: MonadUnliftIO m => m CircuitBreakerConf
initialBreakerState = CBConf <$> newMVar M.empty
-- | Phase of a single circuit breaker, as interpreted by
-- 'breakerTransitionGuard'.
data CBCondition
  = -- | Calls are run normally.
    Active
  | -- | One trial call is running after a trip; other calls are skipped.
    Testing
  | -- | The breaker has tripped; calls are skipped until the error count
    -- drops below the threshold.
    Waiting
  deriving (Show, Eq, Ord, Bounded, Enum)
-- | Contents of one breaker's state cell.
--
-- Both fields are strict. Updates such as @1 + errorCount s@ or
-- @errorCount s - 1@ are then evaluated when the record is built, instead of
-- being parked as thunks inside the 'MVar' until the next reader forces them.
data CircuitState = CircuitState
  { errorCount :: !Natural
  -- ^ Failures recorded by 'withBreaker'; 'decrementErrorCount' lowers it by
  -- one per call, never below zero.
  , currentState :: !CBCondition
  -- ^ Current phase of the breaker.
  }
  deriving (Show)
-- | Environments that carry a 'CircuitBreakerConf'.
--
-- 'withBreaker' uses this, via 'asks', to locate the shared breaker registry
-- inside whatever reader environment the caller runs in.
class HasCircuitConf env where
  -- | Extract the breaker registry from the environment.
  getCircuitState :: env -> CircuitBreakerConf

-- | A bare registry is its own environment.
instance HasCircuitConf CircuitBreakerConf where
  getCircuitState conf = conf
-- | Failure count at which a breaker moves to 'Waiting'.
--
-- A 'Waiting' breaker is retried once its error count falls back below
-- this value.
newtype ErrorThreshold = ET Natural

-- | Interval between error-count decrements, in microseconds.
--
-- 'reifyCircuitBreaker' builds it by scaling the type-level millisecond
-- value by 1000.
newtype DripFreq = DF Natural
-- | Read the type-level parameters of a 'CircuitBreaker' back as values.
--
-- Returns three things:
--
-- * the label as 'T.Text';
-- * the drip interval, converted from milliseconds to microseconds (the unit
--   'withBreaker' hands to 'threadDelay');
-- * the error threshold.
reifyCircuitBreaker ::
  forall label df et.
  (KnownSymbol label, KnownNat df, KnownNat et) =>
  CircuitBreaker label df et ->
  (T.Text, DripFreq, ErrorThreshold)
reifyCircuitBreaker _ =
  let name = T.pack (symbolVal (Proxy :: Proxy label))
      dripMicros = fromIntegral (natVal (Proxy :: Proxy df)) * 1000
      threshold = fromIntegral (natVal (Proxy :: Proxy et))
   in (name, DF dripMicros, ET threshold)
-- | Type-level description of a breaker.
--
-- It has no constructors. It only carries its parameters, which
-- 'reifyCircuitBreaker' reads back:
--
-- * @label@: key of the breaker in the shared registry.
-- * @drip@: milliseconds between error-count decrements.
-- * @threshold@: failure count at which the breaker moves to 'Waiting'.
data CircuitBreaker (label :: Symbol) (drip :: Nat) (threshold :: Nat)
-- | Errors reported by 'withBreaker'.
--
-- The misspelt type name is part of the exported API and is kept as-is.
-- NOTE(review): nothing in this chunk produces 'Failed'. Exceptions from the
-- wrapped action are re-thrown rather than returned, so confirm whether
-- 'Failed' is used elsewhere.
data CircutBreakerError
  = -- | Presumably: the wrapped action failed.
    Failed
  | -- | The named breaker refused to run the action.
    CircuitBreakerClosed T.Text
  deriving (Show, Eq, Ord)
-- | Decision produced by 'breakerTransitionGuard' and consumed by
-- 'breakerTryPerformAction'.
data CircuitAction
  = -- | Do not run the action; report the breaker as closed.
    SkipClosed
  | -- | Run the action.
    Run
  deriving (Eq, Show)
-- | Run @action@ under the circuit breaker described by the type-level
-- 'CircuitBreaker'. Its parameters are the label, the drip interval in
-- milliseconds, and the error threshold.
--
-- The breaker for the label is taken from the shared registry. On first use
-- it is created in the 'Active' phase and given a background thread that
-- calls 'decrementErrorCount' once per drip interval.
--
-- Results:
--
-- * @Left (CircuitBreakerClosed label)@ when 'breakerTransitionGuard' says
--   to skip.
-- * @Right@ of the action's result otherwise.
--
-- If the action throws, the breaker's error count is incremented. The phase
-- becomes 'Waiting' once the count reaches the threshold, and 'Active'
-- otherwise. The exception is then re-thrown by 'bracketOnError'.
withBreaker ::
  ( KnownSymbol label
  , KnownNat df
  , KnownNat et
  , MonadUnliftIO m
  , MonadReader env m
  , HasCircuitConf env
  ) =>
  CircuitBreaker label df et ->
  m a ->
  m (Either CircutBreakerError a)
withBreaker breakerDefinition action = do
  registry <- cbBreakers <$> asks getCircuitState
  -- modifyMVar puts the old map back if anything inside throws. The registry
  -- therefore can never be left empty, which would block every later call.
  (bs, isNew) <- modifyMVar registry $ \breakers ->
    case M.lookup label breakers of
      Just existing -> pure (breakers, (existing, False))
      Nothing -> do
        fresh <- newMVar CircuitState {errorCount = 0, currentState = Active}
        pure (M.insert label fresh breakers, (fresh, True))
  -- Start the drip thread only for a newly registered breaker, after the
  -- registry lock has been released.
  if isNew then monitor bs else pure ()
  bracketOnError
    (breakerTransitionGuard bs (ET et))
    (onError bs)
    (breakerTryPerformAction label action bs)
  where
    (label, DF dripFreq, ET et) = reifyCircuitBreaker breakerDefinition

    -- Record a failure of an action that was actually run.
    onError bs Run = modifyMVar_ bs $ \st ->
      let failures = errorCount st + 1
       in pure $
            CircuitState
              { errorCount = failures
              , currentState = if failures >= et then Waiting else Active
              }
    -- A skipped call never ran the action, so there is nothing to record.
    -- Without this equation, an exception arriving here would become a
    -- pattern-match failure.
    onError _ SkipClosed = pure ()

    -- One drip thread per breaker; it is never stopped.
    -- NOTE(review): with a drip interval of 0 this loop never sleeps.
    monitor bs =
      void . forkIO . forever $ do
        threadDelay (fromIntegral dripFreq)
        decrementErrorCount bs
-- | Decide whether a call may proceed, updating the breaker's phase.
--
-- * 'Active': run.
-- * 'Waiting' with the error count below the threshold: switch to
--   'Testing' and run this call as a trial.
-- * 'Waiting' at or above the threshold, or 'Testing': skip.
--
-- The read-decide-write runs under 'modifyMVar'. If an exception arrives
-- mid-update, the original state is put back instead of leaving the
-- breaker's 'MVar' empty. This matters because the function is exported and
-- may be called outside a mask.
breakerTransitionGuard ::
  MonadUnliftIO m =>
  MVar CircuitState ->
  ErrorThreshold ->
  m CircuitAction
breakerTransitionGuard bs (ET et) =
  modifyMVar bs $ \cb -> pure $! decide cb
  where
    decide cb = case currentState cb of
      Active -> (cb, Run)
      Waiting | errorCount cb < et -> (cb {currentState = Testing}, Run)
      Waiting -> (cb, SkipClosed)
      Testing -> (cb, SkipClosed)
-- | Execute the action according to the guard's decision.
--
-- On 'SkipClosed' the action is not run, and
-- @Left (CircuitBreakerClosed label)@ is returned.
--
-- On 'Run' the action is executed and its result is returned in 'Right'.
-- Exceptions propagate unchanged; when called from 'withBreaker', its
-- 'bracketOnError' records them. After a success, a breaker in the
-- 'Testing' phase goes back to 'Active'. Any other phase is left untouched.
breakerTryPerformAction ::
  MonadUnliftIO m =>
  T.Text ->
  m a ->
  MVar CircuitState ->
  CircuitAction ->
  m (Either CircutBreakerError a)
breakerTryPerformAction label _ _ SkipClosed =
  pure . Left $ CircuitBreakerClosed label
breakerTryPerformAction _ action bs Run = do
  res <- action
  -- Only a successful trial clears 'Testing'. A slow call that started
  -- before the breaker tripped must not reset 'Waiting' while the error
  -- count is still at or above the threshold.
  modifyMVar_ bs $ \st ->
    pure $!
      if currentState st == Testing
        then st {currentState = Active}
        else st
  pure (Right res)
-- | Remove one recorded failure from a breaker, saturating at zero.
--
-- 'withBreaker' calls this periodically from the background thread it starts
-- for each newly registered breaker.
--
-- The update runs under 'modifyMVar_', so an exception cannot leave the
-- breaker's 'MVar' empty. The new state is forced here rather than stored
-- as a thunk.
decrementErrorCount ::
  MonadUnliftIO m =>
  MVar CircuitState ->
  m ()
decrementErrorCount breaker =
  modifyMVar_ breaker $ \st ->
    pure $! case errorCount st of
      0 -> st
      n -> st {errorCount = n - 1}