{-# LANGUAGE DeriveAnyClass #-}
{-# LANGUAGE DeriveGeneric #-}
module Test.DejaFu.SCT.Internal.Weighted where
import Control.DeepSeq (NFData)
import Data.List.NonEmpty (toList)
import Data.Map.Strict (Map)
import qualified Data.Map.Strict as M
import GHC.Generics (Generic)
import System.Random (RandomGen, randomR)
import Test.DejaFu.Schedule (Scheduler(..))
import Test.DejaFu.Types
-- | State threaded through the weighted random scheduler between
-- scheduling decisions.
data RandSchedState g = RandSchedState
  { schedWeights :: Map ThreadId Int
    -- ^ The weight currently assigned to each thread.
  , schedGen :: g
    -- ^ Random generator used both to draw new weights and to pick
    -- which thread runs next.
  } deriving (Eq, Show, Generic, NFData)
-- | The starting scheduler state: no thread has been given a weight
-- yet, and the supplied generator seeds all later random choices.
initialRandSchedState :: g -> RandSchedState g
initialRandSchedState gen = RandSchedState
  { schedWeights = M.empty
  , schedGen = gen
  }
-- | A weighted random scheduler.
--
-- Each candidate thread handed to the scheduler is given an integer
-- weight. A thread that already has a weight in the incoming state
-- keeps it; any other candidate gets a fresh weight drawn with
-- @weightf@. A thread is then picked at random, with probability
-- proportional to its weight (assuming @weightf@ yields positive
-- weights).
--
-- NOTE(review): the outgoing weight map is rebuilt from the current
-- candidates only, so a thread that is not a candidate at this step
-- loses its weight and is re-weighted if it becomes a candidate again.
-- Confirm this is intended.
randSched :: RandomGen g => (g -> (Int, g)) -> Scheduler (RandSchedState g)
randSched weightf = Scheduler $ \_ threads s ->
  let
    -- The candidate threads, in the order they were supplied.
    tids = map fst (toList threads)

    -- Weights for exactly the candidate threads: reuse a weight from the
    -- previous state where there is one, otherwise draw a new one.
    (weights', g') = foldr assignWeight (M.empty, schedGen s) tids

    -- Irrefutable pattern: the accumulator pair from the rest of the
    -- fold is not matched until its components are actually needed.
    assignWeight tid ~(ws, g0) =
      let (w, g) = maybe (weightf g0) (\w0 -> (w0, g0))
                         (M.lookup tid (schedWeights s))
      in (M.insert tid w ws, g)

    -- weights' was built from tids alone, so every key is already a
    -- candidate; no further filtering is needed.
    enabled = M.toList weights'

    -- Draw a point in [0, total weight) and walk the weights until the
    -- running remainder falls inside one thread's share.
    (choice, g'') = randomR (0, sum (map snd enabled) - 1) g'
    pick idx ((x, f):xs)
      | idx < f = Just x
      | otherwise = pick (idx - f) xs
    pick _ [] = Nothing
  in (pick choice enabled, RandSchedState weights' g'')