module PrimitiveExtras.PrimArray
where

import PrimitiveExtras.Prelude hiding (replicateM, traverse_)
import PrimitiveExtras.Types
import qualified Data.Primitive as Primitive
import qualified Data.Serialize as Cereal
import qualified Data.Vector.Primitive as PrimitiveVector
import qualified Data.Vector.Unboxed as UnboxedVector


{-|
Construct an array of the given size in which the element at the given index
holds the given value and every other element is zero-filled (all of its
bytes are 0, which is numeric zero for the standard numeric 'Prim' instances).

Calls 'error' when the index lies outside @[0, size)@.
-}
oneHot :: Prim a => Int {-^ Size -} -> Int {-^ Index -} -> a -> PrimArray a
oneHot size index value
  | index < 0 || index >= size =
      error
        ( "PrimitiveExtras.PrimArray.oneHot: index "
            ++ show index
            ++ " out of bounds for size "
            ++ show size
        )
  | otherwise =
      runST $ do
        marr <- newPrimArray size
        -- 'newPrimArray' leaves its memory uninitialised, so clear every byte
        -- first; otherwise the "cold" positions would contain garbage.
        case marr of
          Primitive.MutablePrimArray rawArray ->
            Primitive.fillByteArray
              (Primitive.MutableByteArray rawArray)
              0
              (size * Primitive.sizeOf value)
              0
        writePrimArray marr index value
        unsafeFreezePrimArray marr

{-|
Allocate an array of @size@ elements and fill it by running @elementIO@ on
each index, from 0 upwards, one after another.
-}
generate :: Prim a => Int -> (Int -> IO a) -> IO (PrimArray a)
generate size elementIO = do
  mutable <- newPrimArray size
  let fillFrom position
        | position < size = do
            elementIO position >>= writePrimArray mutable position
            fillFrom (succ position)
        | otherwise = unsafeFreezePrimArray mutable
  fillFrom 0

{-|
Allocate an array of @size@ elements, filling every slot, in index order,
with the result of a separate run of @elementIO@.
-}
replicate :: Prim a => Int -> IO a -> IO (PrimArray a)
replicate size elementIO =
  -- The same as 'generate' with an action that ignores the index.
  generate size (const elementIO)

{-|
Run the given action @size@ times and collect the results, in execution
order, into a 'PrimArray'. A non-positive size yields an empty array without
running the action.

Works in any 'Monad'. The results are gathered in an immutable list and only
copied into a freshly allocated array (inside 'runST') after the last action
has run. As a result, no mutable array is shared between separate runs of the
returned action or between branches of a non-deterministic monad.
-}
replicateM :: (Monad m, Prim element) => Int -> m element -> m (PrimArray element)
replicateM size elementM =
  collect 0 []
  where
    -- Runs the remaining actions, consing each result onto a reversed list.
    -- Each element is forced as soon as it is produced, so no thunk chain
    -- builds up.
    collect count reversedElements =
      if count < size
        then do
          element <- elementM
          element `seq` collect (succ count) (element : reversedElements)
        else return $! build count reversedElements
    -- Copies a reversed list of exactly @count@ elements into a new array,
    -- filling it from the last slot backwards, so no 'reverse' is needed.
    build count reversedElements =
      runST $ do
        array <- newPrimArray count
        let fill _ [] = unsafeFreezePrimArray array
            fill position (element : rest) = do
              writePrimArray array position element
              fill (pred position) rest
        fill (pred count) reversedElements

{-|
Run an action on every element of the array, discarding the results.
An alias for 'traversePrimArray_'.
-}
traverse_ :: (Applicative f, Prim a) => (a -> f b) -> PrimArray a -> f ()
traverse_ = traversePrimArray_

{-|
Run @action@ on each index of the half-open range @[from, to)@, in ascending
order, passing the index together with the element stored there. The element
is forced before the action is applied.

This function does not check the indices against the array size; the caller
must keep the range within the array.
-}
traverseWithIndexInRange_ :: Prim a => PrimArray a -> Int -> Int -> (Int -> a -> IO ()) -> IO ()
traverseWithIndexInRange_ primArray from to action =
  go from
  where
    go position
      | position < to = do
          action position $! indexPrimArray primArray position
          go (succ position)
      | otherwise = return ()

{-|
Expose the elements of the array as an 'Unfold', implemented as a strict
left fold over the array.
-}
toElementsUnfold :: Prim prim => PrimArray prim -> Unfold prim
toElementsUnfold primArray =
  Unfold (\step initial -> foldlPrimArray' step initial primArray)

{-|
Expose the elements of the array as a monadic 'UnfoldM', implemented as a
strict monadic left fold over the array.
-}
toElementsUnfoldM :: (Monad m, Prim prim) => PrimArray prim -> UnfoldM m prim
toElementsUnfoldM primArray =
  UnfoldM (\step initial -> foldlPrimArrayM' step initial primArray)

{-|
Reinterpret the array as an untyped 'ByteArray' sharing the same underlying
memory (the wrapped primitive array is reused, not copied).
-}
toByteArray :: PrimArray a -> ByteArray
toByteArray primArray =
  case primArray of
    PrimArray rawBytes -> ByteArray rawBytes

{-|
View the whole array as a primitive vector: offset 0, length equal to the
array's element count, backed by the array's own bytes (no copy is made).
-}
toPrimitiveVector :: Prim a => PrimArray a -> PrimitiveVector.Vector a
toPrimitiveVector primArray =
  let elementCount = sizeofPrimArray primArray
      bytes = toByteArray primArray
   in PrimitiveVector.Vector 0 elementCount bytes

{-|
View the array as an unboxed vector by coercing the primitive vector produced
by 'toPrimitiveVector'.

NOTE(review): this relies on the unboxed vector representation for @a@ being
a newtype over the primitive vector of @a@. That holds for the plain numeric
element types, but nothing here checks it (there is no 'UnboxedVector.Unbox'
constraint), so confirm it for any other element type before use.
-}
toUnboxedVector :: Prim a => PrimArray a -> UnboxedVector.Vector a
toUnboxedVector = unsafeCoerce . toPrimitiveVector

{-|
Decode an array as written by 'cerealPut': a little-endian 64-bit signed
length followed by that many elements, each decoded with the given parser.

The length comes from untrusted input, so the decoder fails if it is negative
or does not fit in an 'Int'. Previously such a value was passed on unchecked,
and on 32-bit platforms 'fromIntegral' could silently truncate it.
-}
cerealGet :: Prim element => Cereal.Get element -> Cereal.Get (PrimArray element)
cerealGet element =
  do
    size64 <- Cereal.getInt64le
    if size64 < 0 || toInteger size64 > toInteger (maxBound :: Int)
      then fail ("PrimitiveExtras.PrimArray.cerealGet: invalid array size " ++ show size64)
      else replicateM (fromIntegral size64) element

{-|
Encode an array as its element count (a little-endian 64-bit signed integer)
followed by every element, encoded with the given putter.
-}
cerealPut :: Prim element => Cereal.Putter element -> Cereal.Putter (PrimArray element)
cerealPut element primArrayValue = do
  Cereal.putInt64le (fromIntegral (sizeofPrimArray primArrayValue))
  traversePrimArray_ element primArrayValue