{-# LANGUAGE DeriveAnyClass #-}
{-# LANGUAGE DeriveGeneric #-}

-- |
-- Module      : Test.DejaFu.SCT.Internal.Weighted
-- Copyright   : (c) 2015--2019 Michael Walker
-- License     : MIT
-- Maintainer  : Michael Walker <mike@barrucadu.co.uk>
-- Stability   : experimental
-- Portability : DeriveAnyClass, DeriveGeneric
--
-- Internal types and functions for SCT via weighted random
-- scheduling.  This module is NOT considered to form part of the
-- public interface of this library.
module Test.DejaFu.SCT.Internal.Weighted where

import           Control.DeepSeq      (NFData)
import           Data.List.NonEmpty   (toList)
import           Data.Map.Strict      (Map)
import qualified Data.Map.Strict      as M
import           GHC.Generics         (Generic)
import           System.Random        (RandomGen, randomR)

import           Test.DejaFu.Schedule (Scheduler(..))
import           Test.DejaFu.Types

-------------------------------------------------------------------------------
-- * Weighted random scheduler

-- | State carried by the weighted random scheduler from one scheduling
-- decision to the next.
data RandSchedState g = RandSchedState {
    -- | The thread weights, used when choosing which thread to run.
    schedWeights :: Map ThreadId Int,
    -- | The optional length bound.
    schedLengthBound :: Maybe LengthBound,
    -- | The random number generator.
    schedGen :: g
  } deriving (Eq, Show, Generic, NFData)

-- | Initial weighted random scheduler state: no thread weights have
-- been assigned yet, and the given optional length bound and random
-- generator are stored as-is.
initialRandSchedState :: Maybe LengthBound -> g -> RandSchedState g
initialRandSchedState bound gen = RandSchedState
  { schedWeights = M.empty
  , schedLengthBound = bound
  , schedGen = gen
  }

-- | Weighted random scheduler: assigns to each new thread a weight,
-- and makes a weighted random choice out of the runnable threads at
-- every step.
--
-- A weight is drawn (using @weightf@) only for a thread that has no
-- weight recorded in 'schedWeights' yet.  Once assigned, a thread keeps
-- its weight, including across steps where it is not runnable.
--
-- If the length bound in the state is @Just 0@, no thread is chosen and
-- the state is returned unchanged.  Otherwise a thread is chosen and
-- the bound, if present, is decremented.
randSched :: RandomGen g => (g -> (Int, g)) -> Scheduler (RandSchedState g)
randSched weightf = Scheduler $ \_ threads _ s ->
  let
    -- The runnable threads.
    tids :: [ThreadId]
    tids = map fst (toList threads)

    -- The weights, with any new threads added.  The fold starts from
    -- the previously known weights, so a thread that was blocked on an
    -- earlier step is not re-weighted as though it were new.
    (weights', g') = foldr assignWeight (schedWeights s, schedGen s) tids
    assignWeight tid ~(ws, g0)
      | tid `M.member` ws = (ws, g0)
      | otherwise = let (w, g) = weightf g0 in (M.insert tid w ws, g)

    -- The weights of just the runnable threads, in ascending 'ThreadId'
    -- order (the order 'pick' walks).  A map intersection is used
    -- instead of an 'elem' test per key, which was quadratic.
    enabled :: [(ThreadId, Int)]
    enabled =
      M.toList (M.intersection weights' (M.fromList [(tid, ()) | tid <- tids]))

    -- Draw an index in [0, total weight - 1], then select the thread
    -- whose cumulative-weight interval contains it.
    (choice, g'') = randomR (0, sum (map snd enabled) - 1) g'
    pick :: Int -> [(ThreadId, Int)] -> Maybe ThreadId
    pick idx ((x, f):xs)
      | idx < f = Just x
      | otherwise = pick (idx - f) xs
    pick _ [] = Nothing
  in case schedLengthBound s of
    Just 0 -> (Nothing, s)
    Just n -> (pick choice enabled, RandSchedState weights' (Just (n - 1)) g'')
    Nothing -> (pick choice enabled, RandSchedState weights' Nothing g'')