{-# language BangPatterns #-}
{-# language LambdaCase #-}

module Data.Bytes.Mutable
  ( -- * Types
    MutableBytes
    -- * Filtering
  , takeWhile
  , dropWhile
    -- * Unsafe Slicing
  , unsafeTake
  , unsafeDrop
    -- * Conversion
  , fromMutableByteArray
  ) where

import Prelude hiding (takeWhile,dropWhile)

import Data.Bytes.Types (MutableBytes(MutableBytes))
import Data.Primitive (MutableByteArray)
import Data.Word (Word8)
import Control.Monad.Primitive (PrimMonad,PrimState)

import qualified Data.Primitive as PM

-- | Keep the longest prefix of the slice whose bytes all satisfy
-- the effectful predicate. The predicate is run on bytes from the
-- front, stopping at the first byte it rejects. The result shares
-- the argument's underlying array; no bytes are copied.
takeWhile :: PrimMonad m
  => (Word8 -> m Bool)
  -> MutableBytes (PrimState m)
  -> m (MutableBytes (PrimState m))
{-# inline takeWhile #-}
takeWhile p bytes = (\prefixLen -> unsafeTake prefixLen bytes)
  <$> countWhile p bytes

-- | Discard the longest prefix of the slice whose bytes all satisfy
-- the effectful predicate and return what remains. The predicate is
-- run on bytes from the front, stopping at the first byte it rejects.
-- The result shares the argument's underlying array; no bytes are
-- copied.
dropWhile :: PrimMonad m
  => (Word8 -> m Bool)
  -> MutableBytes (PrimState m)
  -> m (MutableBytes (PrimState m))
{-# inline dropWhile #-}
dropWhile p bytes = do
  prefixLen <- countWhile p bytes
  let rest = unsafeDrop prefixLen bytes
  pure rest

-- | Restrict the slice to its first @n@ bytes by replacing its
-- length with @n@ while keeping the offset. The result aliases the
-- argument. No bounds check is performed: callers must ensure
-- @0 <= n <= len@.
unsafeTake :: Int -> MutableBytes s -> MutableBytes s
{-# inline unsafeTake #-}
unsafeTake n bytes = case bytes of
  MutableBytes buf start _ -> MutableBytes buf start n

-- | Skip the first @n@ bytes of the slice by advancing its offset by
-- @n@ and shrinking its length to @len - n@. The result aliases the
-- argument. No bounds check is performed: callers must ensure
-- @0 <= n <= len@.
unsafeDrop :: Int -> MutableBytes s -> MutableBytes s
{-# inline unsafeDrop #-}
unsafeDrop n (MutableBytes buf start len) =
  let start' = start + n
      len' = len - n
   in MutableBytes buf start' len'

-- | View an entire 'MutableByteArray' as a 'MutableBytes' slice. The
-- slice starts at offset 0 and has the array's current size as its
-- length. It aliases the argument rather than copying it.
fromMutableByteArray :: PrimMonad m
  => MutableByteArray (PrimState m)
  -> m (MutableBytes (PrimState m))
{-# inline fromMutableByteArray #-}
fromMutableByteArray buf =
  MutableBytes buf 0 <$> PM.getSizeofMutableByteArray buf

-- Internal. Count the bytes at the front of the slice that satisfy
-- the predicate, stopping at the first byte that does not. The
-- predicate is not run on any byte after that one. If every byte
-- matches, the result is the slice's length. A non-positive length
-- yields 0.
countWhile :: PrimMonad m
  => (Word8 -> m Bool)
  -> MutableBytes (PrimState m)
  -> m Int
{-# inline countWhile #-}
countWhile p (MutableBytes buf start total) = loop 0
  where
  -- @i@ is both the number of matching bytes seen so far and the
  -- index (relative to @start@) of the next byte to examine.
  loop !i
    | i < total = do
        w <- PM.readByteArray buf (start + i)
        keep <- p w
        if keep then loop (i + 1) else pure i
    | otherwise = pure i