{-# LANGUAGE BangPatterns #-}
{-# LANGUAGE DataKinds #-}
{-# LANGUAGE DuplicateRecordFields #-}
{-# LANGUAGE MagicHash #-}
{-# LANGUAGE TypeFamilies #-}

-- This needs to be in its own module to prevent a cyclic dependency
-- between UnliftedBytes and Data.Bytes.Types
module Data.Bytes.Internal
  ( Bytes (..)
  ) where

import Control.Monad.ST (runST)
import Control.Monad.ST.Run (runByteArrayST)
import Data.Bytes.Internal.Show (showsSlice)
import Data.Primitive (ByteArray (..))
import Data.Word (Word8)
import GHC.Exts (Int (I#), IsList (..), compareByteArrays#, isTrue#, sameMutableByteArray#, unsafeCoerce#)

import qualified Data.Foldable as F
import qualified Data.List as L
import qualified Data.Primitive as PM

-- | A slice of a 'ByteArray': the bytes of 'array' starting at index
-- 'offset' and continuing for 'length' bytes.
data Bytes = Bytes
  { array :: {-# UNPACK #-} !ByteArray
  -- ^ Backing storage that the slice views.
  , offset :: {-# UNPACK #-} !Int
  -- ^ Byte index in 'array' of the first byte of the slice.
  , length :: {-# UNPACK #-} !Int
  -- ^ Number of bytes in the slice.
  }

-- | List conversion treats a 'Bytes' as its sequence of 'Word8's.
instance IsList Bytes where
  type Item Bytes = Word8

  -- The count is forwarded unchanged to the 'ByteArray' instance of
  -- 'fromListN', and the resulting slice covers the new array from
  -- offset 0 for @n@ bytes.
  -- NOTE(review): @n@ is trusted to match the list's length; what happens
  -- on a mismatch is decided by the 'ByteArray' instance.
  fromListN n xs = Bytes (fromListN n xs) 0 n

  -- Measure the list once, then defer to 'fromListN'.
  fromList xs =
    let n = L.length xs
     in fromListN n xs

  toList (Bytes arr off len) = toListLoop off len arr

-- | Produce @len@ bytes of @arr@, starting at index @off@, as a list.
--
-- The list is built incrementally: each tail is a suspended recursive call.
-- This function does not itself check that the requested range lies inside
-- @arr@; callers supply a valid slice.
toListLoop :: Int -> Int -> ByteArray -> [Word8]
toListLoop !off !len !arr
  | len <= 0 = []
  | otherwise = PM.indexByteArray arr off : toListLoop (off + 1) (len - 1) arr

-- | Rendering is delegated entirely to 'showsSlice'; the precedence
-- argument is ignored.
instance Show Bytes where
  showsPrec _ (Bytes arr off len) = showsSlice arr off len

-- | Two slices are equal when they have the same length and the same bytes.
instance Eq Bytes where
  -- Length is checked first. Slices that view the same backing array at
  -- the same offset are then accepted without reading any bytes. Otherwise
  -- the contents are compared.
  Bytes arrA offA lenA == Bytes arrB offB lenB =
    lenA == lenB
      && ( (sameByteArray arrA arrB && offA == offB)
             || compareByteArrays arrA offA arrB offB lenA == EQ
         )

-- | Lexicographic byte order: the common prefix is compared byte-wise, and
-- if it is identical the shorter slice sorts first.
instance Ord Bytes where
  compare (Bytes arrA offA lenA) (Bytes arrB offB lenB)
    -- Same view of the same array: trivially equal, no bytes read.
    | sameByteArray arrA arrB && offA == offB && lenA == lenB = EQ
    | otherwise =
        case compareByteArrays arrA offA arrB offB (min lenA lenB) of
          EQ -> compare lenA lenB
          decided -> decided

-- | Concatenation copies both operands into one freshly allocated array.
instance Semigroup Bytes where
  -- A new array of exactly @lenA + lenB@ bytes is allocated. The first
  -- slice is copied to index 0 and the second immediately after it. The
  -- result is a slice at offset 0 spanning the whole new array.
  --
  -- Only the 'ByteArray' is produced inside 'runST'. The 'Bytes'
  -- constructor is applied outside it, so the result's constant offset and
  -- its length are visible without looking through the 'ST' computation.
  Bytes arrA offA lenA <> Bytes arrB offB lenB =
    Bytes combined 0 combinedLen
    where
      combinedLen :: Int
      !combinedLen = lenA + lenB

      combined :: ByteArray
      combined = runST $ do
        marr <- PM.newByteArray combinedLen
        PM.copyByteArray marr 0 arrA offA lenA
        PM.copyByteArray marr lenA arrB offB lenB
        -- marr is never touched again, so freezing in place is safe.
        PM.unsafeFreezeByteArray marr

-- | The identity is the zero-length slice. 'mconcat' concatenates with a
-- single allocation.
instance Monoid Bytes where
  -- Zero bytes, offset 0, over an empty array.
  mempty = Bytes mempty 0 0

  -- Empty and singleton lists are handled directly; a singleton is
  -- returned unchanged and keeps sharing its backing array. Otherwise the
  -- total length is computed first, one array of that size is allocated,
  -- and each slice is copied in list order at the running write position.
  mconcat [] = mempty
  mconcat [x] = x
  mconcat bs = Bytes joined 0 total
    where
      total :: Int
      !total = L.foldl' (+) 0 [len | Bytes _ _ len <- bs]

      joined :: ByteArray
      joined = runByteArrayST $ do
        dst <- PM.newByteArray total
        -- Copy one slice at @pos@ and return the next write position.
        let copyAt !pos (Bytes src off len) = do
              PM.copyByteArray dst pos src off len
              pure (pos + len)
        !_ <- F.foldlM copyAt 0 bs
        PM.unsafeFreezeByteArray dst

-- | Compare @n@ bytes of the first array, starting at its offset, with
-- @n@ bytes of the second array, starting at its offset.
--
-- 'compareByteArrays#' returns an 'Int#'. Only the sign of that result is
-- used: negative maps to 'LT', zero to 'EQ', positive to 'GT'.
compareByteArrays :: ByteArray -> Int -> ByteArray -> Int -> Int -> Ordering
{-# INLINE compareByteArrays #-}
compareByteArrays (ByteArray lhs#) (I# lhsOff#) (ByteArray rhs#) (I# rhsOff#) (I# count#) =
  let raw = I# (compareByteArrays# lhs# lhsOff# rhs# rhsOff# count#)
   in compare raw 0
0

-- | Test whether two 'ByteArray' values refer to the very same underlying
-- array object. This is an identity test, not a content comparison.
--
-- Both immutable arrays are coerced to 'MutableByteArray#' only so that
-- 'sameMutableByteArray#' can be applied to them; nothing is written
-- through the coerced values.
-- NOTE(review): GHC >= 9.4 provides 'sameByteArray#', which would make
-- the 'unsafeCoerce#' unnecessary if older compilers are dropped.
sameByteArray :: ByteArray -> ByteArray -> Bool
{-# INLINE sameByteArray #-}
sameByteArray (ByteArray lhs#) (ByteArray rhs#) =
  isTrue# (sameMutableByteArray# (unsafeCoerce# lhs#) (unsafeCoerce# rhs#))