```{-# LANGUAGE FlexibleContexts    #-}
{-# LANGUAGE ScopedTypeVariables #-}

-- | This module provides functions to perform shuffles on mutable vectors.
-- The shuffling is uniform amongst all permutations and uses the minimal
-- number of transpositions.

module Mutable.Shuffle where

import           Data.Vector.Generic.Mutable
import           Prelude                     hiding (length, read, tail)
import           System.Random               (RandomGen)
import qualified System.Random               as SR

-- |
-- Shuffle a mutable vector in place using an explicitly threaded random
-- generator, and return the generator's final state.
--
-- This uses the Fisher--Yates--Knuth algorithm: for each index @i@ from
-- @0@ to @n-2@ (where @n@ is the vector's length) the element at @i@ is
-- swapped with one drawn uniformly from positions @i .. n-1@.
-- Vectors of length 0 or 1 are left untouched and the generator is
-- returned unchanged.
shuffle
  :: forall m a g v. (PrimMonad m, RandomGen g, MVector v a)
  => v (PrimState m) a -> g -> m g
{-# INLINABLE shuffle #-}
shuffle mutV gen = loop 0 gen
  where
    -- Index of the last element. The loop stops once it reaches it,
    -- because a single remaining element has nowhere else to go.
    lastIx :: Int
    lastIx = length mutV - 1

    -- Fix position @i@ by swapping in a uniformly chosen element from the
    -- not-yet-fixed suffix @i .. lastIx@.
    loop :: Int -> g -> m g
    {-# INLINE loop #-}
    loop i g
      | i >= lastIx = pure g
      | otherwise =
          case SR.randomR (0, lastIx - i) g of
            (offset, g') -> do
              swap mutV i (i + offset)
              loop (i + 1) g'

-- |
-- Shuffle a mutable vector in place, drawing randomness from the ambient
-- 'MonadRandom' monad.
--
-- This uses the Fisher--Yates--Knuth algorithm: for each index @i@ from
-- @0@ to @n-2@ the element at @i@ is swapped with one drawn uniformly from
-- positions @i .. n-1@. Vectors of length 0 or 1 are left untouched.
shuffleM
  :: forall m a v. (MonadRandom m, PrimMonad m, MVector v a)
  => v (PrimState m) a -> m ()
{-# INLINABLE shuffleM #-}
shuffleM mutV = loop 0
  where
    -- Index of the last element; nothing remains to do once we reach it.
    lastIx :: Int
    lastIx = length mutV - 1

    loop :: Int -> m ()
    {-# INLINE loop #-}
    loop i
      | i >= lastIx = pure ()
      | otherwise = do
          offset <- getRandomR (0, lastIx - i)
          swap mutV i (i + offset)
          loop (i + 1)

-- {-# SPECIALISE shuffleM :: MVector RealWorld a -> IO () #-}

-- |
-- Shuffle the first @numberOfShuffles@ elements of a vector in place,
-- leaving the remaining elements untouched.
--
-- The prefix is permuted uniformly with the Fisher--Yates--Knuth algorithm,
-- drawing randomness from the ambient 'MonadRandom' monad.
--
-- Passing @0@ is a no-op. Calls 'error' if @numberOfShuffles@ is negative
-- or greater than the length of the vector.
shuffleK
  :: forall m a v. (MonadRandom m, PrimMonad m, MVector v a)
  => Int -> v (PrimState m) a -> m ()
{-# INLINABLE shuffleK #-}
shuffleK numberOfShuffles mutV
  -- Validate the argument once, up front, rather than on every step.
  | numberOfShuffles < 0
  = error "Cannot pass negative value to ShuffleK"
  | numberOfShuffles > length mutV
  = error "Cannot pass value greater than the length of the vector  to ShuffleK"
  | otherwise
  = go mutV (numberOfShuffles - 1)
  where
    -- @go v maxInd@: @v@ is the unshuffled remainder of the prefix and
    -- @maxInd@ is the last prefix index within it. Zero or one remaining
    -- element needs no further swaps (covers @numberOfShuffles@ of 0 or 1).
    go :: v (PrimState m) a -> Int -> m ()
    {-# INLINE go #-}
    go _ maxInd | maxInd <= 0 = pure ()
    go v maxInd = do
      ind <- getRandomR (0, maxInd)
      swap v 0 ind
      go (tail v) (maxInd - 1)

-- |
-- Permute a mutable vector in place so that the shuffled indices form a
-- single cycle through every position. The random generator is threaded
-- explicitly and its final state is returned.
--
-- This uses the Sattolo algorithm. It is Fisher--Yates--Knuth except that
-- index @i@ is always swapped with a strictly later position
-- @i+1 .. n-1@, never with itself.
maximalCycle
  :: forall m a g v. (PrimMonad m, RandomGen g, MVector v a)
  => v (PrimState m) a -> g -> m g
{-# INLINABLE maximalCycle #-}
maximalCycle mutV gen = loop 0 gen
  where
    -- Index of the last element; the loop stops once it reaches it.
    lastIx :: Int
    lastIx = length mutV - 1

    loop :: Int -> g -> m g
    {-# INLINE loop #-}
    loop i g
      | i >= lastIx = pure g
      | otherwise = do
          -- Offset 0 is excluded, so no element can stay where it is.
          let (offset, g') :: (Int, g) = SR.randomR (1, lastIx - i) g
          swap mutV i (i + offset)
          loop (i + 1) g'

-- |
-- Permute a mutable vector in place so that the shuffled indices form a
-- single cycle through every position, drawing randomness from the ambient
-- 'MonadRandom' monad.
--
-- This uses the Sattolo algorithm: index @i@ is always swapped with a
-- strictly later position @i+1 .. n-1@, never with itself.
maximalCycleM
  :: forall m a v. (MonadRandom m, PrimMonad m, MVector v a)
  => v (PrimState m) a -> m ()
{-# INLINABLE maximalCycleM #-}
maximalCycleM mutV = loop 0
  where
    -- Index of the last element; the loop stops once it reaches it.
    lastIx :: Int
    lastIx = length mutV - 1

    loop :: Int -> m ()
    {-# INLINE loop #-}
    loop i
      | i >= lastIx = pure ()
      | otherwise = do
          -- Offset 0 is excluded, so no element can stay where it is.
          offset <- getRandomR (1, lastIx - i)
          swap mutV i (i + offset)
          loop (i + 1)

-- {-# SPECIALISE maximalCycleM :: MVector RealWorld a -> IO () #-}

-- |
-- Perform a [derangement](https://en.wikipedia.org/wiki/Derangement) on a
-- mutable vector in place with a given random generator, returning the new
-- random generator.
--
-- __Note:__ It is assumed the input vector consists of distinct values;
-- fixed points are detected by comparing values with '=='.
--
-- This uses the "early refusal" algorithm: a Fisher--Yates--Knuth shuffle
-- that, as soon as some position would receive its original value,
-- restores the vector from a saved copy and starts over.
--
-- A vector of length 1 has no derangement, so it is rejected with 'error'
-- instead of retrying forever. An empty vector is returned unchanged.
derangement
  :: forall m a g v. (PrimMonad m, RandomGen g, Eq a, MVector v a)
  => v (PrimState m) a -> g -> m g
{-# INLINABLE derangement #-}
derangement mutV gen
  | length mutV == 1
  = error "Cannot derange a vector of length 1"
  | otherwise = do
      -- Pristine copy, used both to detect fixed points and to restart.
      mutV_copy <- clone mutV
      go mutV_copy mutV gen 0 (length mutV - 1)
  where
    -- Undo a rejected attempt (copy the original contents back) and begin
    -- a fresh attempt with the given, not-yet-consumed generator.
    restart :: v (PrimState m) a -> g -> m g
    restart c g = do
      unsafeCopy mutV c
      go c mutV g 0 (length mutV - 1)

    -- @go c v g currInd maxInd@: @c@ holds the original contents, @v@ is
    -- the still-unfixed suffix of 'mutV' beginning at absolute index
    -- @currInd@, and @maxInd@ is the last valid index of @v@.
    go :: v (PrimState m) a -> v (PrimState m) a -> g -> Int -> Int -> m g
    {-# INLINE go #-}
    go _ _ g _ (-1) = pure g
    go c v g lastInd 0 = do
      -- One element left: no random choice is needed, only check that it
      -- did not end up in its original position.
      v_last_old <- read c lastInd
      v_last_new <- read v 0
      if v_last_old == v_last_new
        then restart c g
        else pure g
    go c v oldGen currInd maxInd = do
      let (swapInd, newGen) :: (Int, g) = SR.randomR (0, maxInd) oldGen
      v_old <- read c currInd
      v_ind <- read v swapInd
      if v_old == v_ind
        then restart c newGen
        else do
          swap v 0 swapInd
          go c (tail v) newGen (currInd + 1) (maxInd - 1)

-- |
-- Perform a [derangement](https://en.wikipedia.org/wiki/Derangement) on a
-- mutable vector in place, drawing randomness from the ambient
-- 'MonadRandom' monad.
--
-- __Note:__ It is assumed the input vector consists of distinct values;
-- fixed points are detected by comparing values with '=='.
--
-- This uses the "early refusal" algorithm: a Fisher--Yates--Knuth shuffle
-- that, as soon as some position would receive its original value,
-- restores the vector from a saved copy and starts over.
--
-- A vector of length 1 has no derangement, so it is rejected with 'error'
-- instead of retrying forever. An empty vector is left unchanged.
derangementM
  :: forall m a v. (MonadRandom m, PrimMonad m, Eq a, MVector v a)
  => v (PrimState m) a -> m ()
{-# INLINABLE derangementM #-}
derangementM mutV
  | length mutV == 1
  = error "Cannot derange a vector of length 1"
  | otherwise = do
      -- Pristine copy, used both to detect fixed points and to restart.
      mutV_copy <- clone mutV
      go mutV_copy mutV 0 (length mutV - 1)
  where
    -- Undo a rejected attempt (copy the original contents back) and begin
    -- a fresh attempt from index 0.
    restart :: v (PrimState m) a -> m ()
    restart c = do
      unsafeCopy mutV c
      go c mutV 0 (length mutV - 1)

    -- @go c v currInd maxInd@: @c@ holds the original contents, @v@ is the
    -- still-unfixed suffix of 'mutV' beginning at absolute index
    -- @currInd@, and @maxInd@ is the last valid index of @v@.
    go :: v (PrimState m) a -> v (PrimState m) a -> Int -> Int -> m ()
    {-# INLINE go #-}
    go _ _ _ (-1) = pure ()
    go c v lastInd 0 = do
      -- One element left: no random choice is needed, only check that it
      -- did not end up in its original position.
      v_last_old <- read c lastInd
      v_last_new <- read v 0
      if v_last_old == v_last_new
        then restart c
        else pure ()
    go c v currInd maxInd = do
      swapInd :: Int <- getRandomR (0, maxInd)
      v_old <- read c currInd
      v_ind <- read v swapInd
      if v_old == v_ind
        then restart c
        else do
          swap v 0 swapInd
          go c (tail v) (currInd + 1) (maxInd - 1)
```