{-# LANGUAGE FlexibleContexts    #-}
{-# LANGUAGE ScopedTypeVariables #-}

-- | This module provides functions to perform shuffles on mutable vectors.
-- The shuffling is uniform amongst all permutations and uses the minimal
-- number of transpositions.

module Mutable.Shuffle where

import           Control.Monad.Primitive
import           Control.Monad.Random    (MonadRandom (..))
import           Data.Vector.Mutable
import           Prelude                 hiding (length, read, tail)
import           System.Random           (RandomGen)
import qualified System.Random           as SR


-- |
-- Perform a shuffle on a mutable vector with a given random generator, returning a new random generator.
--
-- This uses the Fisher--Yates--Knuth algorithm: the first position is swapped
-- with a randomly chosen position of the vector (possibly itself), and the
-- procedure is then repeated on the 'tail' of the vector.
--
-- Vectors of length 0 or 1 are left untouched and the generator is returned
-- unchanged.
shuffle
  :: forall m a g
  . ( PrimMonad m
    , RandomGen g
    )
  => MVector (PrimState m) a -> g -> m g
{-# INLINABLE shuffle #-}
shuffle mutV gen = go mutV gen (length mutV - 1)
  where
    -- The 'Int' argument is the largest valid index of the vector passed
    -- alongside it.
    go :: MVector (PrimState m) a -> g -> Int -> m g
    {-# INLINE go #-}
    go v g maxInd
      -- @<= 0@ rather than @== 0@: an empty vector starts at @-1@, and
      -- 'swap' would otherwise fail its bounds check.
      | maxInd <= 0 = pure g
      | otherwise   =
          do
            let (ind, newGen) :: (Int, g) = SR.randomR (0, maxInd) g
            swap v 0 ind
            go (tail v) newGen (maxInd - 1)



-- |
-- Perform a shuffle on a mutable vector in a monad which has a source of randomness.
--
-- This uses the Fisher--Yates--Knuth algorithm: the first position is swapped
-- with a randomly chosen position of the vector (possibly itself), and the
-- procedure is then repeated on the 'tail' of the vector.
--
-- Vectors of length 0 or 1 are left untouched and no random numbers are drawn.
shuffleM
  :: forall m a
  . ( PrimMonad m
    , MonadRandom m
    )
  => MVector (PrimState m) a  -> m ()
{-# INLINABLE shuffleM #-}
shuffleM mutV = go mutV (length mutV - 1)
  where
    -- The 'Int' argument is the largest valid index of the vector passed
    -- alongside it.
    go :: MVector (PrimState m) a -> Int -> m ()
    {-# INLINE go #-}
    go v maxInd
      -- @<= 0@ rather than @== 0@: an empty vector starts at @-1@, and
      -- 'swap' would otherwise fail its bounds check.
      | maxInd <= 0 = pure ()
      | otherwise   =
          do
            ind <- getRandomR (0, maxInd)
            swap v 0 ind
            go (tail v) (maxInd - 1)

{-# SPECIALISE shuffleM :: MVector RealWorld a -> IO () #-}

-- |
-- Shuffle the first k elements of a vector.
--
-- The first @k@ elements are permuted amongst themselves with the same
-- swap-and-recurse scheme as 'shuffleM'; elements at index @k@ and beyond are
-- not touched. Passing @k = 0@ (or @k = 1@) leaves the vector unchanged.
--
-- Calls 'error' if @k@ is negative or greater than the length of the vector.
shuffleK
  :: forall m a
  . ( PrimMonad m
    , MonadRandom m
    )
  => Int -> MVector (PrimState m) a  -> m ()
{-# INLINABLE shuffleK #-}
shuffleK numberOfShuffles mutV
  -- Validate the argument once, up front, instead of on every recursive step.
  | numberOfShuffles < 0
    = error "Cannot pass negative value to ShuffleK"
  | numberOfShuffles > length mutV
    = error "Cannot pass value greater than the length of the vector  to ShuffleK"
  | otherwise
    = go mutV (numberOfShuffles - 1)
  where
    -- The 'Int' argument is the largest index (relative to the vector passed
    -- alongside it) that still belongs to the prefix being shuffled.
    go :: MVector (PrimState m) a -> Int -> m ()
    {-# INLINE go #-}
    go v maxInd
      -- @maxInd == -1@ happens for @k = 0@: nothing to shuffle.
      | maxInd <= 0 = pure ()
      | otherwise   =
          do
            ind <- getRandomR (0, maxInd)
            swap v 0 ind
            go (tail v) (maxInd - 1)



-- |
-- Perform a shuffle on a mutable vector wherein the shuffled indices form a maximal cycle.
--
-- This uses the Sattolo algorithm: the first position is swapped with a
-- randomly chosen /different/ position of the vector, and the procedure is
-- then repeated on the 'tail' of the vector.
--
-- Vectors of length 0 or 1 are left untouched and the generator is returned
-- unchanged.
maximalCycle
  :: forall m a g
  . ( PrimMonad m
    , RandomGen g
    )
  => MVector (PrimState m) a -> g -> m g
{-# INLINABLE maximalCycle #-}
maximalCycle mutV gen = go mutV gen (length mutV - 1)
  where
    -- The 'Int' argument is the largest valid index of the vector passed
    -- alongside it.
    go :: MVector (PrimState m) a -> g -> Int -> m g
    {-# INLINE go #-}
    go v g maxInd
      -- @<= 0@ rather than @== 0@: an empty vector starts at @-1@, and
      -- 'swap' would otherwise fail its bounds check.
      | maxInd <= 0 = pure g
      | otherwise   =
          do
            -- The lower bound is 1, so position 0 is never swapped with itself.
            let (ind, newGen) :: (Int, g) = SR.randomR (1, maxInd) g
            swap v 0 ind
            go (tail v) newGen (maxInd - 1)


-- |
-- Perform a shuffle on a mutable vector wherein the shuffled indices form a maximal cycle
-- in a monad with a source of randomness.
--
-- This uses the Sattolo algorithm: the first position is swapped with a
-- randomly chosen /different/ position of the vector, and the procedure is
-- then repeated on the 'tail' of the vector.
--
-- Vectors of length 0 or 1 are left untouched and no random numbers are drawn.
maximalCycleM
  :: forall m a
  . ( PrimMonad m
    , MonadRandom m
    )
  => MVector (PrimState m) a  -> m ()
{-# INLINABLE maximalCycleM #-}
maximalCycleM mutV = go mutV (length mutV - 1)
  where
    -- The 'Int' argument is the largest valid index of the vector passed
    -- alongside it.
    go :: MVector (PrimState m) a -> Int -> m ()
    {-# INLINE go #-}
    go v maxInd
      -- @<= 0@ rather than @== 0@: an empty vector starts at @-1@, and
      -- 'swap' would otherwise fail its bounds check.
      | maxInd <= 0 = pure ()
      | otherwise   =
          do
            -- The lower bound is 1, so position 0 is never swapped with itself.
            ind <- getRandomR (1, maxInd)
            swap v 0 ind
            go (tail v) (maxInd - 1)

{-# SPECIALISE maximalCycleM :: MVector RealWorld a -> IO () #-}



-- |
-- Perform a [derangement](https://en.wikipedia.org/wiki/Derangement) on a mutable vector with a given random generator, returning a new random generator.
--
-- __Note:__ It is assumed the input vector consists of distinct values.
-- Fixed points are detected by comparing /values/ with '==', so every attempt
-- is refused when all elements are equal and the function then does not
-- terminate.
--
-- __Note:__ Vectors of length 0 or 1 are returned unchanged together with the
-- unchanged generator. A one-element vector has no derangement, so in that
-- case the result is /not/ a derangement.
--
-- This uses the "early refusal" algorithm: a shuffle is built one position at
-- a time and, as soon as a position would receive its original value, the
-- vector is restored from a copy and the attempt starts over.
derangement
  :: forall m a g
  . ( PrimMonad m
    , RandomGen g
    , Eq a
    )
  => MVector (PrimState m) a -> g -> m g
{-# INLINABLE derangement #-}
derangement mutV gen
  -- Without this guard a one-element vector restarts forever and an empty
  -- vector is read out of bounds.
  | length mutV <= 1 = pure gen
  | otherwise        =
      do
        mutV_copy <- clone mutV
        go mutV_copy mutV gen 0 lastPos
  where
    -- Largest valid index of the full vector.
    lastPos :: Int
    lastPos = length mutV - 1

    -- Arguments: pristine copy of the input, remaining suffix still to be
    -- filled, current generator, index (into the copy) of the position being
    -- filled, and largest valid index of the suffix.
    go :: MVector (PrimState m) a -> MVector (PrimState m) a -> g -> Int -> Int -> m g
    {-# INLINE go #-}
    go c v g lastInd 0 =
      do
        -- Only one element is left; it stays where it is, so just check it.
        v_last_old <- read c lastInd
        v_last_new <- read v 0
        if v_last_old == v_last_new then
          restart c g
        else
          pure g
    go c v oldGen currInd maxInd =
      do
        let (swapInd, newGen) :: (Int, g) = SR.randomR (0, maxInd) oldGen
        v_old  <- read c currInd
        v_ind  <- read v swapInd
        if v_old == v_ind then
          restart c newGen
        else
          do
            swap v 0 swapInd
            go c (tail v) newGen (currInd + 1) (maxInd - 1)

    -- Undo the partial shuffle by copying the pristine copy back into the
    -- input vector, then begin a fresh attempt with the given generator.
    restart :: MVector (PrimState m) a -> g -> m g
    restart c g =
      do
        unsafeCopy mutV c
        go c mutV g 0 lastPos


-- |
-- Perform a [derangement](https://en.wikipedia.org/wiki/Derangement) on a mutable vector in a monad which has a source of randomness.
--
-- __Note:__ It is assumed the input vector consists of distinct values.
-- Fixed points are detected by comparing /values/ with '==', so every attempt
-- is refused when all elements are equal and the function then does not
-- terminate.
--
-- __Note:__ Vectors of length 0 or 1 are returned unchanged and no random
-- numbers are drawn. A one-element vector has no derangement, so in that case
-- the result is /not/ a derangement.
--
-- This uses the "early refusal" algorithm: a shuffle is built one position at
-- a time and, as soon as a position would receive its original value, the
-- vector is restored from a copy and the attempt starts over.
derangementM
  :: forall m a
  . ( PrimMonad m
    , MonadRandom m
    , Eq a
    )
  => MVector (PrimState m) a -> m ()
{-# INLINABLE derangementM #-}
derangementM mutV
  -- Without this guard a one-element vector restarts forever and an empty
  -- vector is read out of bounds.
  | length mutV <= 1 = pure ()
  | otherwise        =
      do
        mutV_copy <- clone mutV
        go mutV_copy mutV 0 lastPos
  where
    -- Largest valid index of the full vector.
    lastPos :: Int
    lastPos = length mutV - 1

    -- Arguments: pristine copy of the input, remaining suffix still to be
    -- filled, index (into the copy) of the position being filled, and largest
    -- valid index of the suffix.
    go :: MVector (PrimState m) a -> MVector (PrimState m) a -> Int -> Int -> m ()
    {-# INLINE go #-}
    go c v lastInd 0 =
      do
        -- Only one element is left; it stays where it is, so just check it.
        v_last_old <- read c lastInd
        v_last_new <- read v 0
        if v_last_old == v_last_new then
          restart c
        else
          pure ()
    go c v currInd maxInd =
      do
        swapInd :: Int <- getRandomR (0, maxInd)
        v_old  <- read c currInd
        v_ind  <- read v swapInd
        if v_old == v_ind then
          restart c
        else
          do
            swap v 0 swapInd
            go c (tail v) (currInd + 1) (maxInd - 1)

    -- Undo the partial shuffle by copying the pristine copy back into the
    -- input vector, then begin a fresh attempt.
    restart :: MVector (PrimState m) a -> m ()
    restart c =
      do
        unsafeCopy mutV c
        go c mutV 0 lastPos