{-# LANGUAGE FlexibleContexts    #-}
{-# LANGUAGE ScopedTypeVariables #-}

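-- | This module provides functions to perform shuffles on mutable vectors.
-- The plain shuffles are uniform amongst all permutations and use the minimal
-- number of transpositions; variants are also provided that produce maximal
-- cycles (Sattolo's algorithm) and derangements.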

module Mutable.Shuffle where

import           Control.Monad.Primitive
import           Control.Monad.Random    (MonadRandom (..))
import           Data.Vector.Mutable
import           Prelude                 hiding (length, read, tail)
import           System.Random           (RandomGen)
import qualified System.Random           as SR


-- |
-- Perform a shuffle on a mutable vector with a given random generator, returning a new random generator.
--
-- This uses the Fisher--Yates--Knuth algorithm: at each step an element of
-- the not-yet-fixed suffix is chosen uniformly at random, swapped to the
-- front of that suffix, and the suffix is shrunk by one.
--
-- Empty and single-element vectors are left untouched and the generator is
-- returned unchanged.
shuffle
  :: forall m a g
  . ( PrimMonad m
    , RandomGen g
    )
  => MVector (PrimState m) a -> g -> m g
{-# INLINABLE shuffle #-}
shuffle mutV gen = go mutV gen (length mutV - 1)
  where
    -- @go v g maxInd@ shuffles the suffix @v@, whose last valid index is
    -- @maxInd@, threading the generator @g@ through each draw.
    go :: MVector (PrimState m) a -> g -> Int -> m g
    {-# INLINE go #-}
    go v g maxInd
      -- At most one element left: nothing to permute.  Using @<=@ instead of
      -- matching on 0 also covers the empty vector, where @maxInd@ is -1.
      | maxInd <= 0 = pure g
      | otherwise   = do
          let (ind, newGen) = SR.randomR (0, maxInd) g
          swap v 0 ind
          go (tail v) newGen (maxInd - 1)



-- |
-- Perform a shuffle on a mutable vector in a monad which has a source of randomness.
--
-- This uses the Fisher--Yates--Knuth algorithm: at each step an element of
-- the not-yet-fixed suffix is chosen uniformly at random, swapped to the
-- front of that suffix, and the suffix is shrunk by one.
--
-- Empty and single-element vectors are left untouched.
shuffleM
  :: forall m a
  . ( PrimMonad m
    , MonadRandom m
    )
  => MVector (PrimState m) a  -> m ()
{-# INLINABLE shuffleM #-}
shuffleM mutV = go mutV (length mutV - 1)
  where
    -- @go v maxInd@ shuffles the suffix @v@, whose last valid index is
    -- @maxInd@.
    go :: MVector (PrimState m) a -> Int -> m ()
    {-# INLINE go #-}
    go v maxInd
      -- At most one element left: nothing to permute.  Using @<=@ instead of
      -- matching on 0 also covers the empty vector, where @maxInd@ is -1.
      | maxInd <= 0 = pure ()
      | otherwise   = do
          ind <- getRandomR (0, maxInd)
          swap v 0 ind
          go (tail v) (maxInd - 1)

{-# SPECIALISE shuffleM :: MVector RealWorld a -> IO () #-}

-- |
-- Shuffle the first @k@ elements of a vector, leaving the remaining elements
-- in place.
--
-- The prefix is shuffled with the Fisher--Yates--Knuth algorithm.  Calls
-- 'error' if @k@ is negative or greater than the length of the vector;
-- @k = 0@ and @k = 1@ leave the vector unchanged.
shuffleK
  :: forall m a
  . ( PrimMonad m
    , MonadRandom m
    )
  => Int -> MVector (PrimState m) a  -> m ()
{-# INLINABLE shuffleK #-}
shuffleK numberOfShuffles mutV
  | numberOfShuffles < 0
  = error "Cannot pass negative value to ShuffleK"
  | numberOfShuffles > length mutV
  = error "Cannot pass value greater than the length of the vector  to ShuffleK"
  | otherwise
  -- Validation happens once here rather than on every step of the loop.
  = go mutV (numberOfShuffles - 1)
  where
    -- @go v maxInd@ shuffles the first @maxInd + 1@ elements of @v@.
    go :: MVector (PrimState m) a -> Int -> m ()
    {-# INLINE go #-}
    go v maxInd
      -- At most one element of the prefix left.  Using @<=@ also covers
      -- @numberOfShuffles == 0@, where @maxInd@ starts at -1.
      | maxInd <= 0 = pure ()
      | otherwise   = do
          ind <- getRandomR (0, maxInd)
          swap v 0 ind
          go (tail v) (maxInd - 1)



-- |
-- Perform a shuffle on a mutable vector wherein the shuffled indices form a maximal cycle.
--
-- This uses the Sattolo algorithm.  It is Fisher--Yates with one change:
-- the element swapped to the front of the remaining suffix is drawn from
-- indices @1..maxInd@, so it is never the front element itself.
--
-- Empty and single-element vectors are left untouched and the generator is
-- returned unchanged.
maximalCycle
  :: forall m a g
  . ( PrimMonad m
    , RandomGen g
    )
  => MVector (PrimState m) a -> g -> m g
{-# INLINABLE maximalCycle #-}
maximalCycle mutV gen = go mutV gen (length mutV - 1)
  where
    -- @go v g maxInd@ processes the suffix @v@, whose last valid index is
    -- @maxInd@, threading the generator @g@ through each draw.
    go :: MVector (PrimState m) a -> g -> Int -> m g
    {-# INLINE go #-}
    go v g maxInd
      -- At most one element left.  Using @<=@ instead of matching on 0 also
      -- covers the empty vector, where @maxInd@ is -1 and @randomR (1, -1)@
      -- would otherwise produce an out-of-bounds swap.
      | maxInd <= 0 = pure g
      | otherwise   = do
          -- Drawing from [1, maxInd] rather than [0, maxInd] is what
          -- distinguishes Sattolo from Fisher--Yates.
          let (ind, newGen) = SR.randomR (1, maxInd) g
          swap v 0 ind
          go (tail v) newGen (maxInd - 1)


-- |
-- Perform a shuffle on a mutable vector wherein the shuffled indices form a maximal cycle
-- in a monad with a source of randomness.
--
-- This uses the Sattolo algorithm.  It is Fisher--Yates with one change:
-- the element swapped to the front of the remaining suffix is drawn from
-- indices @1..maxInd@, so it is never the front element itself.
--
-- Empty and single-element vectors are left untouched.
maximalCycleM
  :: forall m a
  . ( PrimMonad m
    , MonadRandom m
    )
  => MVector (PrimState m) a  -> m ()
{-# INLINABLE maximalCycleM #-}
maximalCycleM mutV = go mutV (length mutV - 1)
  where
    -- @go v maxInd@ processes the suffix @v@, whose last valid index is
    -- @maxInd@.
    go :: MVector (PrimState m) a -> Int -> m ()
    {-# INLINE go #-}
    go v maxInd
      -- At most one element left.  Using @<=@ instead of matching on 0 also
      -- covers the empty vector, where @maxInd@ is -1.
      | maxInd <= 0 = pure ()
      | otherwise   = do
          -- Drawing from [1, maxInd] rather than [0, maxInd] is what
          -- distinguishes Sattolo from Fisher--Yates.
          ind <- getRandomR (1, maxInd)
          swap v 0 ind
          go (tail v) (maxInd - 1)

{-# SPECIALISE maximalCycleM :: MVector RealWorld a -> IO () #-}



-- |
-- Perform a [derangement](https://en.wikipedia.org/wiki/Derangement) on a mutable vector with a given random generator, returning a new random generator.
--
-- __Note:__ It is assumed the input vector consists of distinct values.
--
-- This uses the "early refusal" algorithm.  A Fisher--Yates shuffle runs
-- position by position.  As soon as a position would receive the value it
-- originally held, the vector is restored from a copy and the shuffle starts
-- again from the first position.
--
-- An empty vector is left untouched and the generator is returned unchanged.
-- A single-element vector has no derangement, so 'error' is called rather
-- than restarting forever.
derangement
  :: forall m a g
  . ( PrimMonad m
    , RandomGen g
    , Eq a
    )
  => MVector (PrimState m) a -> g -> m g
{-# INLINABLE derangement #-}
derangement mutV gen
  | n == 0    = pure gen
  | n == 1    = error "derangement: a vector of length 1 has no derangement"
  | otherwise = do
      -- Pristine copy: used both to detect fixed points and to restart.
      mutV_copy <- clone mutV
      go mutV_copy mutV gen 0 (n - 1)
  where
    n :: Int
    n = length mutV

    -- @go c v g currInd maxInd@: @c@ is the untouched copy of the input, @v@
    -- is the unfixed suffix of 'mutV' beginning at absolute index @currInd@,
    -- and @maxInd@ is the last valid index of @v@.
    go :: MVector (PrimState m) a -> MVector (PrimState m) a -> g -> Int -> Int -> m g
    {-# INLINE go #-}
    -- Only one slot left: its value is forced, so just check it is not the
    -- original occupant.
    go c v g lastInd 0 = do
      v_last_old <- read c lastInd
      v_last_new <- read v 0
      if v_last_old == v_last_new
        then restart c g
        else pure g
    go c v oldGen currInd maxInd = do
      let (swapInd, newGen) = SR.randomR (0, maxInd) oldGen
      v_old <- read c currInd
      v_ind <- read v swapInd
      -- Refuse early if position currInd would keep its original value.
      if v_old == v_ind
        then restart c newGen
        else do
          swap v 0 swapInd
          go c (tail v) newGen (currInd + 1) (maxInd - 1)

    -- Undo the partial shuffle by copying the pristine values back into
    -- 'mutV', then start again from the first position.
    restart :: MVector (PrimState m) a -> g -> m g
    restart c g = do
      unsafeCopy mutV c
      go c mutV g 0 (n - 1)


-- |
-- Perform a [derangement](https://en.wikipedia.org/wiki/Derangement) on a mutable vector in a monad which has a source of randomness.
--
-- __Note:__ It is assumed the input vector consists of distinct values.
--
-- This uses the "early refusal" algorithm.  A Fisher--Yates shuffle runs
-- position by position.  As soon as a position would receive the value it
-- originally held, the vector is restored from a copy and the shuffle starts
-- again from the first position.
--
-- An empty vector is left untouched.  A single-element vector has no
-- derangement, so 'error' is called rather than restarting forever.
derangementM
  :: forall m a
  . ( PrimMonad m
    , MonadRandom m
    , Eq a
    )
  => MVector (PrimState m) a -> m ()
{-# INLINABLE derangementM #-}
derangementM mutV
  | n == 0    = pure ()
  | n == 1    = error "derangementM: a vector of length 1 has no derangement"
  | otherwise = do
      -- Pristine copy: used both to detect fixed points and to restart.
      mutV_copy <- clone mutV
      go mutV_copy mutV 0 (n - 1)
  where
    n :: Int
    n = length mutV

    -- @go c v currInd maxInd@: @c@ is the untouched copy of the input, @v@
    -- is the unfixed suffix of 'mutV' beginning at absolute index @currInd@,
    -- and @maxInd@ is the last valid index of @v@.
    go :: MVector (PrimState m) a -> MVector (PrimState m) a -> Int -> Int -> m ()
    {-# INLINE go #-}
    -- Only one slot left: its value is forced, so just check it is not the
    -- original occupant.
    go c v lastInd 0 = do
      v_last_old <- read c lastInd
      v_last_new <- read v 0
      if v_last_old == v_last_new
        then restart c
        else pure ()
    go c v currInd maxInd = do
      swapInd <- getRandomR (0, maxInd)
      v_old <- read c currInd
      v_ind <- read v swapInd
      -- Refuse early if position currInd would keep its original value.
      if v_old == v_ind
        then restart c
        else do
          swap v 0 swapInd
          go c (tail v) (currInd + 1) (maxInd - 1)

    -- Undo the partial shuffle by copying the pristine values back into
    -- 'mutV', then start again from the first position.
    restart :: MVector (PrimState m) a -> m ()
    restart c = do
      unsafeCopy mutV c
      go c mutV 0 (n - 1)