--------------------------------------------------------------------------------
-- |
-- Module      :  System.Random.Shuffle
-- Copyright   :  (C) Frank Staals
-- License     :  see the LICENSE file
-- Maintainer  :  Frank Staals
--
-- Implements the Fisher–Yates shuffle.
--
--------------------------------------------------------------------------------
module System.Random.Shuffle(shuffle) where

import           Control.Monad
import           Control.Monad.Random.Class
import qualified Data.Foldable as F
import           Data.Util
import qualified Data.Vector as V
import qualified Data.Vector.Mutable as MV

--------------------------------------------------------------------------------

-- | Fisher–Yates shuffle, which shuffles a list/foldable uniformly at random.
--
-- The input is first converted into a vector. All swap positions are then
-- drawn from the random monad up front (see 'rands'), and finally applied to
-- the vector one after another using 'V.modify'.
--
-- running time: \(O(n)\).
shuffle :: (Foldable f, MonadRandom m) => f a -> m (V.Vector a)
shuffle = withLength . V.fromList . F.toList
  where
    -- Draw the swap pairs for a vector of length n (first components run
    -- from n-1 down to 1) and apply them to that vector.
    withLength v = let n = V.length v
                   in flip withRands v <$> rands (n - 1)

    -- Apply the given swaps, in order, to the vector: every pair (SP i j)
    -- exchanges the elements at positions i and j.
    withRands rs = V.modify $ \v ->
                     forM_ rs $ \(SP i j) -> MV.swap v i j


-- | Generate a list of indices in decreasing order, coupled with a random
-- value in the range [0,i].
--
-- For an argument @n@ the indices run from @n@ down to @1@, so the result is
-- the empty list whenever @n < 1@. Each pair @SP i j@ holds the index @i@
-- together with the value @j@ drawn by @getRandomR (0,i)@.
rands   :: MonadRandom m => Int -> m [SP Int Int]
rands n = mapM (\i -> SP i <$> getRandomR (0, i)) [n, n - 1 .. 1]