{-# LANGUAGE CPP                 #-}
{-# LANGUAGE FlexibleContexts    #-}
{-# LANGUAGE FlexibleInstances   #-}
{-# LANGUAGE Rank2Types          #-}
{-# LANGUAGE ScopedTypeVariables #-}
{-# LANGUAGE TupleSections       #-}

module HaskellWorks.Data.Vector.AsVector8s
  ( AsVector8s(..)
  ) where

import Control.Monad.ST
import Data.Word
import Foreign.ForeignPtr

import qualified Data.ByteString              as BS
import qualified Data.ByteString.Internal     as BS
import qualified Data.ByteString.Lazy         as LBS
import qualified Data.Vector.Storable         as DVS
import qualified Data.Vector.Storable.Mutable as DVSM

#if !MIN_VERSION_base(4,13,0)
import Control.Applicative ((<$>)) -- Fix warning in ghc >= 9.2
#endif

-- | Values that can be split into a sequence of fixed-length storable
-- 'Word8' vectors.
class AsVector8s a where
  -- | @asVector8s n x@ splits the bytes of @x@ into vectors of exactly @n@
  -- bytes each.  When the input does not divide evenly, the final vector is
  -- still @n@ bytes long, with its unused tail filled with zeros.
  asVector8s
    :: Int                 -- ^ Chunk size @n@, in bytes.
    -> a                   -- ^ Value to split.
    -> [DVS.Vector Word8]  -- ^ Chunks of @n@ bytes; the last one is zero-padded.

-- | A lazy 'LBS.ByteString' is chunked by taking its strict chunks with
-- 'LBS.toChunks' and delegating to the @['BS.ByteString']@ instance.
instance AsVector8s LBS.ByteString where
  asVector8s n = asVector8s n . LBS.toChunks
  {-# INLINE asVector8s #-}

-- | A list of strict 'BS.ByteString's is chunked by 'bytestringsToVectors',
-- which consumes the strings in order as one stream of bytes.
instance AsVector8s [BS.ByteString] where
  asVector8s = bytestringsToVectors
  {-# INLINE asVector8s #-}

-- | Split a list of strict 'BS.ByteString's into storable vectors of exactly
-- @n@ bytes each.
--
-- Each vector is produced by 'buildOneVector', which copies bytes from
-- successive strings, so vector boundaries need not line up with string
-- boundaries.  The final vector is padded with zeros.  Vectors are produced
-- lazily, one per list cell.  A chunk size of zero or less yields no vectors.
bytestringsToVectors :: Int -> [BS.ByteString] -> [DVS.Vector Word8]
bytestringsToVectors n
  -- Non-positive sizes are rejected up front.  A size of zero already produced
  -- @[]@; a negative size would otherwise be passed on to an allocation that
  -- does not validate its length.
  | n <= 0    = const []
  | otherwise = go
  where
    go :: [BS.ByteString] -> [DVS.Vector Word8]
    go bs = case DVS.createT (buildOneVector n bs) of
      (rest, chunk)
        -- An empty vector from 'buildOneVector' means the input is exhausted.
        | DVS.null chunk -> []
        | otherwise      -> chunk : go rest
{-# INLINE bytestringsToVectors #-}

-- | Fill one freshly allocated vector of @n@ bytes from the front of the input.
--
-- Leading empty strings are skipped.  If nothing remains, the result is
-- @([], <empty vector>)@.  Otherwise bytes are copied from successive strings
-- until the vector is full.  A string that does not fit entirely is split, and
-- its unconsumed suffix heads the returned remainder.  If the input runs out
-- before the vector is full, the rest of the vector is set to zero.
buildOneVector :: forall s. Int -> [BS.ByteString] -> ST s ([BS.ByteString], DVS.MVector s Word8)
buildOneVector n ss = case dropWhile BS.null ss of
  [] -> ([],) <$> DVSM.new 0
  cs -> do
    -- Clamp so a negative size allocates an empty vector instead of being
    -- handed to 'DVSM.unsafeNew', which does not validate its length.
    v8 <- DVSM.unsafeNew (max 0 n)
    rs <- go cs v8
    return (rs, v8)
  where
    -- Copy from @ts@ into @v@ until @v@ is full or @ts@ is exhausted.
    -- Returns the input that was not consumed.
    go :: [BS.ByteString] -> DVSM.MVector s Word8 -> ST s [BS.ByteString]
    go ts v
      | DVSM.null v = return ts
      | otherwise = case ts of
          [] -> do
            -- Input exhausted: pad the unused tail with zeros.
            DVSM.set v 0
            return []
          (u:us)
            | BS.length u <= DVSM.length v -> do
                -- The whole string fits: copy it and continue after it.
                let (va, vb) = DVSM.splitAt (BS.length u) v
                DVSM.copy va (byteStringToVector8 u)
                go us vb
            | otherwise -> do
                -- Only a prefix fits: copy it and hand back the suffix.
                let (ua, ub) = BS.splitAt (DVSM.length v) u
                DVSM.copy v (byteStringToVector8 ua)
                return (ub : us)
{-# INLINE buildOneVector #-}

-- | View a strict 'BS.ByteString' as a mutable storable vector without
-- copying.  The vector is built directly over the ByteString's own
-- 'ForeignPtr', offset and length.
--
-- Because the buffer is shared with an immutable 'BS.ByteString', the result
-- must only be read from (in this module it is used only as the source of
-- 'DVSM.copy').  Writing through it would mutate the original ByteString.
byteStringToVector8 :: BS.ByteString -> DVSM.MVector s Word8
byteStringToVector8 bs = case BS.toForeignPtr bs of
  (fptr, off, len) -> DVSM.unsafeFromForeignPtr (castForeignPtr fptr) off len
{-# INLINE byteStringToVector8 #-}