module Test.DejaFu.SCT.Internal.Weighted where
import Control.DeepSeq (NFData(..))
import Data.List.NonEmpty (toList)
import Data.Map.Strict (Map)
import qualified Data.Map.Strict as M
import Data.Maybe (fromMaybe)
import System.Random (RandomGen, randomR)
import Test.DejaFu.Schedule (Scheduler(..))
import Test.DejaFu.Types
-- | State carried between decisions of the weighted random scheduler:
-- a weight per thread, and the random generator used both to draw new
-- weights and to make the weighted choice of which thread to run.
data RandSchedState g = RandSchedState
  { schedWeights :: Map ThreadId Int
  -- ^ Weight assigned to each thread; a thread with no entry is given
  -- a freshly-drawn weight when it is next runnable.
  , schedGen :: g
  -- ^ The random number generator.
  } deriving (Eq, Show)
-- | Deep evaluation: fully force the weight map first, then the
-- generator.
instance NFData g => NFData (RandSchedState g) where
  rnf (RandSchedState weights gen) = rnf weights `seq` rnf gen
-- | Build the scheduler's starting state from an optional map of
-- pre-assigned thread weights ('Nothing' means no thread has a weight
-- yet) and a random generator.
initialRandSchedState :: Maybe (Map ThreadId Int) -> g -> RandSchedState g
initialRandSchedState mweights gen = RandSchedState
  { schedWeights = fromMaybe M.empty mweights
  , schedGen = gen
  }
-- | Weighted random scheduler.
--
-- A thread that has no weight in the state is given one (drawn with the
-- supplied @weightf@) when it is runnable. At every step, a weighted
-- random choice is made among the runnable threads.
--
-- Weights of threads that are not runnable at this step are kept in the
-- returned state. A thread therefore keeps the same weight when it
-- blocks and later becomes runnable again, instead of being re-weighted.
randSched :: RandomGen g => (g -> (Int, g)) -> Scheduler (RandSchedState g)
randSched weightf = Scheduler $ \_ threads s ->
  let
    -- Walk the (thread, weight) list, subtracting each weight from the
    -- random index until the index falls inside a thread's weight.
    pick idx ((x, f):xs)
      | idx < f = Just x
      | otherwise = pick (idx - f) xs
    pick _ [] = Nothing

    -- A random index in [0, total weight of runnable threads).
    (choice, g'') = randomR (0, sum (map snd enabled) - 1) g'

    -- The runnable threads with their weights, in ascending key order
    -- (the order 'M.toList' produces).
    enabled = M.toList $ M.filterWithKey (\tid _ -> tid `elem` tids) weights'

    -- All previously known weights, plus fresh weights for runnable
    -- threads that do not have one yet. Starting from the old weights
    -- (not an empty map) keeps the weights of currently blocked threads.
    (weights', g') = foldr assignWeight (schedWeights s, schedGen s) tids
    assignWeight tid ~(ws, g0) =
      let (w, g) = maybe (weightf g0) (\w0 -> (w0, g0)) (M.lookup tid (schedWeights s))
      in (M.insert tid w ws, g)

    -- The runnable threads.
    tids = map fst (toList threads)
  in (pick choice enabled, RandSchedState weights' g'')