{-# LANGUAGE LambdaCase #-}
{-# LANGUAGE ViewPatterns #-}
{-# LANGUAGE AllowAmbiguousTypes #-} -- Scheduler
{-# LANGUAGE TypeFamilies #-}
{-|

Definition of 'Scheduler' as a way to control application of rewrite rules.

The 'BackoffScheduler' is a scheduler which implements exponential rule backoff
and is used by default in 'Data.Equality.Saturation.equalitySaturation'.

-}
module Data.Equality.Saturation.Scheduler
    ( Scheduler(..), BackoffScheduler(..), defaultBackoffScheduler
    ) where

import qualified Data.IntMap.Strict as IM
import Data.Equality.Matching

-- | A 'Scheduler' determines whether a certain rewrite rule is banned from
-- being used based on statistics it defines and collects on applied rewrite
-- rules.
class Scheduler s where
    -- | Per-rule statistics collected by the scheduler. They are stored in an
    -- 'IM.IntMap' keyed by the index of the rewrite rule.
    data Stat s

    -- | Update the statistics of a rewrite rule after it has been matched.
    updateStats :: s                  -- ^ The scheduler itself
                -> Int                -- ^ Iteration we're in
                -> Int                -- ^ Index of rewrite rule we're updating
                -> Maybe (Stat s)     -- ^ Current stat for this rewrite rule (we already got it so no point in doing a lookup again)
                -> IM.IntMap (Stat s) -- ^ The current stats map
                -> [Match]            -- ^ The list of matches resulting from matching this rewrite rule
                -> IM.IntMap (Stat s) -- ^ The updated map with new stats

    -- | Decide, from its stats and the current iteration, whether a rewrite
    -- rule is banned.
    isBanned :: Int    -- ^ Iteration we're in
             -> Stat s -- ^ Stats for the rewrite rule
             -> Bool   -- ^ 'True' if the rule is banned, i.e. it should /not/ be applied

-- | A 'Scheduler' that implements exponential rule backoff.
--
-- For each rewrite, there exists a configurable initial match limit. If a rewrite
-- search yields more than this limit, then we ban this rule for a number of
-- iterations, double its limit, and double the time it will be banned next time.
--
-- This seems effective at preventing explosive rules like associativity from
-- taking an unfair amount of resources.
--
-- Originally in [egg](https://docs.rs/egg/0.6.0/egg/struct.BackoffScheduler.html)
data BackoffScheduler = BackoffScheduler
  { matchLimit :: {-# UNPACK #-} !Int
    -- ^ Initial match limit of a rule; the effective limit doubles with
    -- every ban the rule has received
  , banLength  :: {-# UNPACK #-} !Int
    -- ^ Initial number of iterations a rule stays banned; the effective
    -- length doubles with every ban the rule has received
  }

-- | The default 'BackoffScheduler'.
--
-- The match limit is set to @1000@ and the ban length is set to @10@.
defaultBackoffScheduler :: BackoffScheduler
defaultBackoffScheduler = BackoffScheduler
  { matchLimit = 1000
  , banLength  = 10
  }

instance Scheduler BackoffScheduler where
    -- Per-rule statistics for the backoff scheduler.
    data Stat BackoffScheduler =
      BSS { bannedUntil :: {-# UNPACK #-} !Int
            -- ^ The rule is banned in every iteration strictly before this one
          , timesBanned :: {-# UNPACK #-} !Int
            -- ^ How many times the rule has been banned so far
          } deriving Show

    -- Ban rule @rw@ when its search in iteration @i@ produced more matches
    -- than its current limit. Otherwise the stats map is returned unchanged.
    updateStats bos i rw currentStat stats matches
      | total_len > threshold = IM.alter updateBans rw stats
      | otherwise             = stats
      where
        -- Number of matches (one substitution each) the search produced, as
        -- in egg. Summing the sizes of the substitutions instead would scale
        -- the effective limit by the number of pattern variables, and a
        -- pattern binding no variables could never be banned.
        total_len :: Int
        total_len = length matches

        -- Number of times this rule has already been banned.
        bannedN :: Int
        bannedN = maybe 0 timesBanned currentStat

        -- Both the match limit and the ban length double with every ban.
        threshold, ban_length :: Int
        threshold  = matchLimit bos * 2 ^ bannedN
        ban_length = banLength  bos * 2 ^ bannedN

        -- Record a new ban ending at iteration @i + ban_length@ and bump the
        -- ban counter (starting at 1 for a rule without stats yet).
        updateBans = \case
          Nothing        -> Just (BSS (i + ban_length) 1)
          Just (BSS _ n) -> Just (BSS (i + ban_length) (n + 1))

    -- A rule is banned while the current iteration is before 'bannedUntil'.
    isBanned i s = i < bannedUntil s