{-# LANGUAGE BangPatterns #-}
{-# LANGUAGE LambdaCase #-}

{- | If you are interested in sub-arrays of 'MutableByteArray's (e.g. writing
quicksort), it would be grossly inefficient to make a copy of the sub-array.
On the other hand, it'd be really annoying to track limit indices by hand.

This module defines the 'MutableBytes' type, which exposes a standard array
interface for sub-arrays without copying and without manual index
manipulation. For immutable arrays, see "Data.Bytes".
-}
module Data.Bytes.Mutable
  ( -- * Types
    MutableBytes

    -- * Filtering
  , takeWhile
  , dropWhile

    -- * Unsafe Slicing
  , unsafeTake
  , unsafeDrop

    -- * Conversion
  , fromMutableByteArray
  ) where

import Prelude hiding (dropWhile, takeWhile)

import Control.Monad.Primitive (PrimMonad, PrimState)
import Data.Bytes.Types (MutableBytes (MutableBytes))
import Data.Primitive (MutableByteArray)
import Data.Word (Word8)

import qualified Data.Primitive as PM

{- | Take bytes while the predicate is true, aliasing the
argument array.

The predicate is applied to the bytes of the slice in order, starting at
the slice's first byte, and is not consulted again after the first byte it
rejects. The result keeps the argument's underlying array and offset; only
the length is replaced by the number of leading bytes that matched.
-}
takeWhile ::
  (PrimMonad m) =>
  (Word8 -> m Bool) ->
  MutableBytes (PrimState m) ->
  m (MutableBytes (PrimState m))
{-# INLINE takeWhile #-}
takeWhile k b = do
  -- Length of the longest prefix whose bytes all satisfy @k@.
  n <- countWhile k b
  pure (unsafeTake n b)

{- | Drop bytes while the predicate is true, aliasing the
argument array.

The predicate is applied to the bytes of the slice in order, starting at
the slice's first byte, and is not consulted again after the first byte it
rejects. The result keeps the argument's underlying array; its offset is
advanced past the matching prefix and its length is reduced by the same
amount.
-}
dropWhile ::
  (PrimMonad m) =>
  (Word8 -> m Bool) ->
  MutableBytes (PrimState m) ->
  m (MutableBytes (PrimState m))
{-# INLINE dropWhile #-}
dropWhile k b = do
  -- Length of the longest prefix whose bytes all satisfy @k@.
  n <- countWhile k b
  pure (unsafeDrop n b)

{- | Take the first @n@ bytes from the argument, aliasing it.

The result keeps the argument's underlying array and offset and uses @n@ as
its length. The argument's own length is ignored, so @n@ is not checked
against it; the caller is responsible for passing a valid @n@.
-}
unsafeTake :: Int -> MutableBytes s -> MutableBytes s
{-# INLINE unsafeTake #-}
unsafeTake n (MutableBytes arr off _) =
  MutableBytes arr off n

{- | Drop the first @n@ bytes from the argument, aliasing it.
The new length will be @len - n@.

The result keeps the argument's underlying array, with the offset advanced
by @n@. No check is made that @n@ is non-negative or that it does not
exceed the argument's length; the caller is responsible for passing a
valid @n@.
-}
unsafeDrop :: Int -> MutableBytes s -> MutableBytes s
{-# INLINE unsafeDrop #-}
unsafeDrop n (MutableBytes arr off len) =
  MutableBytes arr (off + n) (len - n)

{- | Create a slice of 'MutableBytes' that spans the entire
argument array. This aliases the argument.

The resulting slice starts at offset @0@ and its length is the size
reported by 'PM.getSizeofMutableByteArray' at the time of the call.
-}
fromMutableByteArray ::
  (PrimMonad m) =>
  MutableByteArray (PrimState m) ->
  m (MutableBytes (PrimState m))
{-# INLINE fromMutableByteArray #-}
fromMutableByteArray mba = do
  sz <- PM.getSizeofMutableByteArray mba
  pure (MutableBytes mba 0 sz)

-- Internal. This returns the number of bytes that match the
-- predicate until the first non-match occurs. If all bytes
-- match the predicate, this will return the length originally
-- provided. Bytes are read one at a time starting at the slice's
-- offset, and the predicate is run on each byte before the next
-- byte is read.
countWhile ::
  (PrimMonad m) =>
  (Word8 -> m Bool) ->
  MutableBytes (PrimState m) ->
  m Int
{-# INLINE countWhile #-}
countWhile k (MutableBytes arr off0 len0) = go off0 len0 0
 where
  -- @ix@ is the absolute index of the next byte to inspect, @remaining@ is
  -- how many bytes of the slice have not been inspected yet, and @matched@
  -- is how many bytes have satisfied the predicate so far. All three are
  -- kept strict so the loop does not accumulate thunks.
  go !ix !remaining !matched
    -- Slice exhausted (or empty to begin with): every byte matched.
    | remaining > 0 =
        PM.readByteArray arr ix >>= k >>= \case
          True -> go (ix + 1) (remaining - 1) (matched + 1)
          -- First rejected byte: stop without reading any further.
          False -> pure matched
    | otherwise = pure matched