{-# LANGUAGE BangPatterns    #-}
{-# LANGUAGE MagicHash       #-}
{-# LANGUAGE TemplateHaskell #-}
{-# LANGUAGE UnboxedTuples   #-}
{-# OPTIONS_HADDOCK hide #-}
-- |
-- Module      : Data.ByteString.Short.Extra
-- Copyright   : [2017..2020] The Accelerate Team
-- License     : BSD3
--
-- Maintainer  : Trevor L. McDonell <trevor.mcdonell@gmail.com>
-- Stability   : experimental
-- Portability : non-portable (GHC extensions)
--

module Data.ByteString.Short.Extra (

  ShortByteString,
  take,
  takeWhile,
  liftSBS,

) where

import Data.ByteString.Short                                        ( ShortByteString )
import qualified Data.ByteString.Short                              as BS
import qualified Data.ByteString.Short.Internal                     as BI

import Language.Haskell.TH                                          ( Q, TExp )
import qualified Language.Haskell.TH                                as TH
import qualified Language.Haskell.TH.Syntax                         as TH

import System.IO.Unsafe
import Prelude                                                      hiding ( take, takeWhile )

import GHC.ST
import GHC.Exts
import GHC.Word


-- | /O(n)/ @'take' n@ applied to the ShortByteString @xs@, returns the prefix
-- of @xs@ of length @n@ as a new ShortByteString, or @xs@ itself if
-- @n >= 'length' xs@. A non-positive @n@ yields the empty ShortByteString.
--
{-# INLINEABLE take #-}
take :: Int -> ShortByteString -> ShortByteString
take n xs
  -- Without this guard a negative n would be passed straight to the
  -- allocation and copy primitives below as a byte count.
  | n <= 0            = BS.empty
  | n >= BS.length xs = xs
  | otherwise         = runST $ do
      mba <- newByteArray n
      copyByteArray (asBA xs) 0 mba 0 n
      -- mba is not used after this point, so freezing in place is safe
      asSBS <$> unsafeFreezeByteArray mba

-- | 'takeWhile', applied to a predicate @p@ and a ShortByteString @xs@,
-- returns the longest prefix (possibly empty) of @xs@ whose bytes all
-- satisfy @p@.
--
{-# INLINEABLE takeWhile #-}
takeWhile :: (Word8 -> Bool) -> ShortByteString -> ShortByteString
takeWhile p xs = take prefixLen xs
  where
    -- index of the first byte rejected by p, or the full length if none is
    prefixLen = findIndexOrEnd (\w -> not (p w)) xs


-- | Return the index of the first byte that satisfies the predicate, or the
-- length of the string when no byte does.
--
{-# INLINEABLE findIndexOrEnd #-}
findIndexOrEnd :: (Word8 -> Bool) -> ShortByteString -> Int
findIndexOrEnd p xs = scan 0
  where
    !arr = asBA xs
    !len = BS.length xs

    -- Advance while still in bounds and the current byte fails the
    -- predicate; otherwise the current position is the answer.
    scan !k
      | k < len, not (p (indexWord8Array arr k)) = scan (k + 1)
      | otherwise                                = k


-- | Lift a ShortByteString into a typed Template Haskell splice.
--
-- The bytes are embedded in the generated code as a primitive string
-- literal. The spliced expression passes that literal's address, together
-- with the original length, to 'BI.createFromPtr' under 'unsafePerformIO'.
--
liftSBS :: ShortByteString -> Q (TExp ShortByteString)
liftSBS sbs = [|| unsafePerformIO $ BI.createFromPtr $$(addr) size ||]
  where
    size = BS.length sbs
    -- address of a primitive string literal holding the raw bytes
    addr = TH.unsafeTExpCoerce [| Ptr $(TH.litE (TH.StringPrimL (BS.unpack sbs))) |]

------------------------------------------------------------------------
-- Internal utils

-- | View a 'ShortByteString' as a 'BA'. No copy is made: the same underlying
-- 'ByteArray#' is simply rewrapped.
asBA :: ShortByteString -> BA
asBA sbs = case sbs of
  BI.SBS arr# -> BA# arr#

-- | View a 'BA' as a 'ShortByteString'. No copy is made: the same underlying
-- 'ByteArray#' is simply rewrapped.
asSBS :: BA -> ShortByteString
asSBS arr = case arr of
  BA# arr# -> BI.SBS arr#


------------------------------------------------------------------------
-- Primop wrappers

-- | Boxed wrapper around an unlifted, immutable 'ByteArray#', so it can be
-- returned from 'ST' actions and passed around as an ordinary value.
data BA    = BA# ByteArray#
-- | Boxed wrapper around a 'MutableByteArray#' belonging to the state
-- thread @s@ (the same @s@ as the 'ST' computation that allocated it).
data MBA s = MBA# (MutableByteArray# s)

-- | Read the byte at the given index of a 'BA'. This wrapper performs no
-- bounds check of its own; callers must keep the index in range.
indexWord8Array :: BA -> Int -> Word8
indexWord8Array (BA# arr#) (I# ix#) =
  case indexWord8Array# arr# ix# of
    w# -> W8# w#

-- | Allocate a mutable byte array of the requested size (in bytes) by
-- wrapping the 'newByteArray#' primop in 'ST'.
newByteArray :: Int -> ST s (MBA s)
newByteArray (I# n#) = ST alloc
  where
    alloc s0 = case newByteArray# n# s0 of
      (# s1, marr# #) -> (# s1, MBA# marr# #)

-- | Freeze a mutable byte array in place, without copying, by wrapping the
-- 'unsafeFreezeByteArray#' primop. The 'MBA' must not be written to
-- afterwards.
unsafeFreezeByteArray :: MBA s -> ST s BA
unsafeFreezeByteArray (MBA# marr#) = ST freeze
  where
    freeze s0 = case unsafeFreezeByteArray# marr# s0 of
      (# s1, arr# #) -> (# s1, BA# arr# #)

-- | @copyByteArray src srcOff dst dstOff n@ copies @n@ bytes from the
-- immutable array @src@, starting at byte offset @srcOff@, into the mutable
-- array @dst@, starting at byte offset @dstOff@ (a thin wrapper over
-- 'copyByteArray#').
copyByteArray :: BA -> Int -> MBA s -> Int -> Int -> ST s ()
copyByteArray (BA# src#) (I# srcOff#) (MBA# dst#) (I# dstOff#) (I# n#) = ST copy
  where
    copy s0 = case copyByteArray# src# srcOff# dst# dstOff# n# s0 of
      s1 -> (# s1, () #)