{-# LANGUAGE ScopedTypeVariables #-}

-- | This module provides functions to perform shuffles on immutable vectors.
-- The shuffling is uniform amongst all permutations and performs the minimal
-- number of transpositions.

module Immutable.Shuffle where

import           Control.Monad.Primitive
import           Control.Monad.Random    (MonadRandom (..))
import           Control.Monad.ST        (runST)
import           Data.Vector
import qualified Mutable.Shuffle         as MS
import           Prelude                 hiding (length, take)
import           System.Random           (RandomGen (..))


-- |
-- Perform a shuffle on an immutable vector with a given random generator,
-- returning the shuffled vector together with the updated generator.
--
-- The input vector is never modified: it is copied with 'thaw', the copy is
-- shuffled by 'MS.shuffle', and the copy is frozen again. Vectors of length
-- 0 or 1 are returned as-is, together with the unchanged generator.
--
-- This uses the Fisher--Yates--Knuth algorithm.
shuffle :: forall a g. RandomGen g => Vector a -> g -> (Vector a, g)
shuffle v g
  | length v <= 1 = (v, g)
  | otherwise     = runST $ do
      -- Work on a private mutable copy so the caller's vector is untouched.
      mutV   <- thaw v
      newGen <- MS.shuffle mutV g
      -- Safe: mutV is not written to after this point and does not escape.
      immutV <- unsafeFreeze mutV
      pure (immutV, newGen)


-- |
-- Perform a shuffle on an input immutable vector in a monad which has a
-- source of randomness, returning a new shuffled vector.
--
-- The input vector is copied before shuffling and is left unchanged.
-- Vectors of length 0 or 1 are returned as-is, without drawing randomness.
--
-- This uses the Fisher--Yates--Knuth algorithm.
shuffleM :: forall m a . (MonadRandom m, PrimMonad m) => Vector a -> m (Vector a)
shuffleM v
  | length v <= 1 = pure v
  | otherwise     = do
      -- Work on a private mutable copy so the caller's vector is untouched.
      mutV <- thaw v
      MS.shuffleM mutV
      -- Safe: the shuffled copy is not mutated again after freezing.
      unsafeFreeze mutV


-- |
-- Perform a shuffle on the first k elements of a vector in a monad which has a
-- source of randomness, returning a new vector.
--
-- The input vector is copied before shuffling and is left unchanged. When
-- @k <= 0@ (a non-positive @k@ is treated as zero), or when the vector has at
-- most one element, there is nothing to shuffle. In that case the input is
-- returned directly, without copying it or drawing randomness.
shuffleK :: forall m a . (MonadRandom m, PrimMonad m) => Int -> Vector a -> m (Vector a)
shuffleK k v
  | k <= 0 || length v <= 1 = pure v
  | otherwise               = do
      -- Work on a private mutable copy so the caller's vector is untouched.
      mutV <- thaw v
      MS.shuffleK k mutV
      -- Safe: the shuffled copy is not mutated again after freezing.
      unsafeFreeze mutV


-- |
-- Get a random sample of k elements without replacement from a vector.
--
-- The first k positions of a copy of the vector are shuffled with 'shuffleK',
-- and then the first k elements are taken. At most @length v@ elements are
-- returned, and the result is empty when @k <= 0@.
--
-- 'take' only creates a slice, which would keep the whole shuffled copy of
-- the input alive for as long as the (possibly much smaller) sample lives.
-- 'force' copies just the sampled elements so the large buffer can be
-- garbage collected.
sampleWithoutReplacement :: forall m a . (MonadRandom m, PrimMonad m) => Int -> Vector a -> m (Vector a)
{-# INLINEABLE sampleWithoutReplacement #-}
sampleWithoutReplacement k v = force . take k <$> shuffleK k v


-- |
-- Shuffle an immutable vector so that the shuffled indices form a maximal
-- cycle, using a given random generator. Returns the shuffled vector and a
-- new generator.
--
-- The input vector is not modified: a copy is made with 'thaw', the copy is
-- permuted by 'MS.maximalCycle', and the copy is frozen again. Vectors of
-- length 0 or 1 are returned as-is, together with the unchanged generator.
--
-- This uses the Sattolo algorithm.
maximalCycle :: forall a g. RandomGen g => Vector a -> g -> (Vector a, g)
maximalCycle v g
  | length v <= 1 = (v, g)
  | otherwise     = runST $ do
      -- Work on a private mutable copy so the caller's vector is untouched.
      mutV   <- thaw v
      newGen <- MS.maximalCycle mutV g
      -- Safe: mutV is not written to after this point and does not escape.
      immutV <- unsafeFreeze mutV
      pure (immutV, newGen)

-- |
-- Shuffle an immutable vector so that the shuffled indices form a maximal
-- cycle, in a monad with a source of randomness. Returns a new vector.
--
-- The input vector is copied before shuffling and is left unchanged.
-- Vectors of length 0 or 1 are returned as-is, without drawing randomness.
--
-- This uses the Sattolo algorithm.
maximalCycleM :: forall m a . (MonadRandom m, PrimMonad m) => Vector a -> m (Vector a)
maximalCycleM v
  | length v <= 1 = pure v
  | otherwise     = do
      -- Work on a private mutable copy so the caller's vector is untouched.
      mutV <- thaw v
      MS.maximalCycleM mutV
      -- Safe: the permuted copy is not mutated again after freezing.
      unsafeFreeze mutV


-- |
-- Perform a [derangement](https://en.wikipedia.org/wiki/Derangement) on an
-- immutable vector with a given random generator. Returns the deranged vector
-- and a new random generator.
--
-- __Note:__ It is assumed the input vector consists of distinct values.
--
-- The input vector is not modified: a copy is made with 'thaw', the copy is
-- permuted by 'MS.derangement', and the copy is frozen again. A vector of
-- length 0 or 1 is returned as-is, together with the unchanged generator.
-- A singleton has no derangement, so in that case the result is /not/ a
-- derangement.
--
-- This uses the "early refusal" algorithm.
derangement :: forall a g. (Eq a, RandomGen g) => Vector a -> g -> (Vector a, g)
derangement v g
  | length v <= 1 = (v, g)
  | otherwise     = runST $ do
      -- Work on a private mutable copy so the caller's vector is untouched.
      mutV   <- thaw v
      newGen <- MS.derangement mutV g
      -- Safe: mutV is not written to after this point and does not escape.
      immutV <- unsafeFreeze mutV
      pure (immutV, newGen)


-- |
-- Perform a [derangement](https://en.wikipedia.org/wiki/Derangement) on an
-- immutable vector in a monad which has a source of randomness. Returns a new
-- vector.
--
-- __Note:__ It is assumed the input vector consists of distinct values.
--
-- The input vector is copied before permuting and is left unchanged. A
-- vector of length 0 or 1 is returned as-is, without drawing randomness. A
-- singleton has no derangement, so in that case the result is /not/ a
-- derangement.
--
-- This uses the "early refusal" algorithm.
derangementM :: forall m a . (Eq a, MonadRandom m, PrimMonad m) => Vector a -> m (Vector a)
derangementM v
  | length v <= 1 = pure v
  | otherwise     = do
      -- Work on a private mutable copy so the caller's vector is untouched.
      mutV <- thaw v
      MS.derangementM mutV
      -- Safe: the deranged copy is not mutated again after freezing.
      unsafeFreeze mutV