module VectorExtras.Immutable.FoldM.PrimMonad.Index where

import Control.Foldl
import Data.Vector.Generic
import qualified Data.Vector.Generic.Mutable as Mutable
import VectorExtras.Prelude hiding (length)

-- |
-- A monadic fold that consumes 'Int' indices and can be run in any
-- 'PrimMonad'.
type IndexPrimMonadFoldM result =
  forall monad. (PrimMonad monad) => FoldM monad Int result

-- |
-- Given the size of the vector, construct a fold, which produces a vector of
-- frequencies of each index. I.e., the counts of how often it appeared.
--
-- The produced vector has @amount@ elements. Every counter starts at
-- @'toEnum' 0@ and is advanced with 'succ' once for each occurrence of its
-- index in the input.
--
-- It is your responsibility to ensure that the indices are within the size of
-- the vector produced.
frequency :: (Vector vector count, Enum count) => Int -> IndexPrimMonadFoldM (vector count)
frequency amount = FoldM increment initialize finalize
  where
    -- Allocate the counters with an explicit starting value. A vector created
    -- without a fill value has no defined contents for boxed element types,
    -- so applying 'succ' to such a slot would fail once the count is forced.
    initialize = Mutable.replicate amount (toEnum 0)
    -- Advance the counter stored at the given index and pass the same
    -- mutable buffer on as the fold state.
    increment counters index = Mutable.modify counters succ index $> counters
    -- Convert the buffer to an immutable vector without copying. Per the
    -- vector documentation the mutable buffer must not be used afterwards.
    finalize = unsafeFreeze