{-# LANGUAGE ScopedTypeVariables #-}
module Immutable.Shuffle where
import Control.Monad.Primitive
import Control.Monad.Random (MonadRandom (..))
import Control.Monad.ST (runST)
import Data.Vector
import qualified Mutable.Shuffle as MS
import Prelude hiding (length, take)
import System.Random (RandomGen (..))
-- | Pure shuffle of an immutable vector, driven by an explicit generator.
--
-- The input is copied with 'thaw', so @v@ itself is never modified. The copy
-- is permuted in place by 'MS.shuffle' inside 'runST', and then frozen with
-- 'unsafeFreeze'. The result is returned together with the generator produced
-- by 'MS.shuffle', so callers can keep threading it.
--
-- Vectors of length 0 or 1 are returned as-is, paired with the /unchanged/
-- input generator.
--
-- NOTE(review): the distribution of the result is whatever 'MS.shuffle'
-- implements (presumably a uniform permutation) — confirm in
-- "Mutable.Shuffle".
shuffle :: forall a g. RandomGen g => Vector a -> g -> (Vector a, g)
shuffle v g
  | length v <= 1 = (v, g)
  | otherwise = runST $ do
      mutV <- thaw v
      newGen <- MS.shuffle mutV g
      -- 'mutV' is never used after this point, so freezing without a copy
      -- is safe.
      immutV <- unsafeFreeze mutV
      pure (immutV, newGen)
-- | Shuffle an immutable vector using the ambient 'MonadRandom' source.
--
-- Works on a 'thaw'ed copy, so @v@ is not modified. The copy is permuted in
-- place by 'MS.shuffleM' and then frozen with 'unsafeFreeze'. Vectors of
-- length 0 or 1 are returned unchanged, without calling 'MS.shuffleM'.
--
-- NOTE(review): the distribution of the result is whatever 'MS.shuffleM'
-- implements — confirm in "Mutable.Shuffle".
shuffleM :: forall m a. (MonadRandom m, PrimMonad m) => Vector a -> m (Vector a)
shuffleM v
  | length v <= 1 = pure v
  | otherwise = do
      mutV <- thaw v
      MS.shuffleM mutV
      -- 'mutV' is not used afterwards, so the copy-free freeze is safe.
      unsafeFreeze mutV
-- | Partially shuffle an immutable vector, delegating to @'MS.shuffleK' k@.
--
-- Works on a 'thaw'ed copy, so @v@ is not modified. The whole (same-length)
-- vector is returned after 'MS.shuffleK' has permuted it in place. Vectors of
-- length 0 or 1 are returned unchanged, whatever @k@ is.
--
-- NOTE(review): how @k@ restricts the permutation, and what happens when @k@
-- is negative or exceeds @length v@, is decided by 'MS.shuffleK'. Presumably
-- only the first @k@ positions are randomized — TODO confirm in
-- "Mutable.Shuffle".
shuffleK :: forall m a. (MonadRandom m, PrimMonad m) => Int -> Vector a -> m (Vector a)
shuffleK k v
  | length v <= 1 = pure v
  | otherwise = do
      mutV <- thaw v
      MS.shuffleK k mutV
      -- 'mutV' is not used afterwards, so the copy-free freeze is safe.
      unsafeFreeze mutV
-- | Draw @k@ elements of @v@ without replacement.
--
-- Implemented as the first @k@ elements of @'shuffleK' k v@. Because 'take'
-- returns the vector unchanged when it has fewer than @k@ elements, the
-- result has @min k (length v)@ elements for non-negative @k@.
--
-- 'take' slices without copying, so the returned sample shares (and keeps
-- alive) the full shuffled buffer. Apply 'Data.Vector.force' at the call
-- site if that matters.
--
-- NOTE(review): this is a uniform sample only if 'MS.shuffleK' uniformly
-- randomizes the first @k@ positions — TODO confirm in "Mutable.Shuffle".
sampleWithoutReplacement
  :: forall m a. (MonadRandom m, PrimMonad m)
  => Int -> Vector a -> m (Vector a)
{-# INLINEABLE sampleWithoutReplacement #-}
sampleWithoutReplacement k v = take k <$> shuffleK k v
-- | Pure counterpart of 'maximalCycleM', driven by an explicit generator.
--
-- The input is copied with 'thaw' (so @v@ is not modified) and permuted in
-- place by 'MS.maximalCycle' inside 'runST'. The result is frozen with
-- 'unsafeFreeze' and returned together with the generator produced by
-- 'MS.maximalCycle'.
--
-- Vectors of length 0 or 1 are returned as-is, paired with the /unchanged/
-- input generator.
--
-- NOTE(review): the name suggests a permutation consisting of a single cycle
-- (Sattolo-style) — confirm in "Mutable.Shuffle".
maximalCycle :: forall a g. RandomGen g => Vector a -> g -> (Vector a, g)
maximalCycle v g
  | length v <= 1 = (v, g)
  | otherwise = runST $ do
      mutV <- thaw v
      newGen <- MS.maximalCycle mutV g
      -- 'mutV' is never used after this point, so freezing without a copy
      -- is safe.
      immutV <- unsafeFreeze mutV
      pure (immutV, newGen)
-- | Permute an immutable vector with 'MS.maximalCycleM', using the ambient
-- 'MonadRandom' source.
--
-- Works on a 'thaw'ed copy, so @v@ is not modified. The copy is permuted in
-- place and then frozen with 'unsafeFreeze'. Vectors of length 0 or 1 are
-- returned unchanged.
--
-- NOTE(review): the name suggests a permutation consisting of a single cycle
-- (Sattolo-style) — confirm in "Mutable.Shuffle".
maximalCycleM :: forall m a. (MonadRandom m, PrimMonad m) => Vector a -> m (Vector a)
maximalCycleM v
  | length v <= 1 = pure v
  | otherwise = do
      mutV <- thaw v
      MS.maximalCycleM mutV
      -- 'mutV' is not used afterwards, so the copy-free freeze is safe.
      unsafeFreeze mutV
-- | Pure counterpart of 'derangementM', driven by an explicit generator.
--
-- The input is copied with 'thaw' (so @v@ is not modified) and permuted in
-- place by 'MS.derangement' inside 'runST'. The result is frozen with
-- 'unsafeFreeze' and returned together with the generator produced by
-- 'MS.derangement'. The @Eq a@ constraint is required by 'MS.derangement'
-- (presumably to compare elements — confirm).
--
-- Vectors of length 0 or 1 are returned as-is, paired with the /unchanged/
-- input generator.
--
-- NOTE(review): a 1-element vector has no derangement, yet it is returned
-- unchanged here. Callers that rely on "no element stays in place" must
-- ensure @length v >= 2@.
derangement :: forall a g. (Eq a, RandomGen g) => Vector a -> g -> (Vector a, g)
derangement v g
  | length v <= 1 = (v, g)
  | otherwise = runST $ do
      mutV <- thaw v
      newGen <- MS.derangement mutV g
      -- 'mutV' is never used after this point, so freezing without a copy
      -- is safe.
      immutV <- unsafeFreeze mutV
      pure (immutV, newGen)
-- | Permute an immutable vector with 'MS.derangementM', using the ambient
-- 'MonadRandom' source.
--
-- Works on a 'thaw'ed copy, so @v@ is not modified. The copy is permuted in
-- place and then frozen with 'unsafeFreeze'. The @Eq a@ constraint is
-- required by 'MS.derangementM' (presumably to compare elements — confirm).
-- Vectors of length 0 or 1 are returned unchanged.
--
-- NOTE(review): a 1-element vector has no derangement, yet it is returned
-- unchanged here. Callers that rely on "no element stays in place" must
-- ensure @length v >= 2@.
derangementM :: forall m a. (Eq a, MonadRandom m, PrimMonad m) => Vector a -> m (Vector a)
derangementM v
  | length v <= 1 = pure v
  | otherwise = do
      mutV <- thaw v
      MS.derangementM mutV
      -- 'mutV' is not used afterwards, so the copy-free freeze is safe.
      unsafeFreeze mutV