{-# LANGUAGE CPP                 #-}
{-# LANGUAGE FlexibleContexts    #-}
{-# LANGUAGE FlexibleInstances   #-}
{-# LANGUAGE Rank2Types          #-}
{-# LANGUAGE ScopedTypeVariables #-}
{-# LANGUAGE TupleSections       #-}

module HaskellWorks.Data.Vector.AsVector8ns
  ( AsVector8ns(..)
  ) where

import Control.Monad.ST
import Data.Word
import Foreign.ForeignPtr
import HaskellWorks.Data.Vector.AsVector8

import qualified Data.ByteString              as BS
import qualified Data.ByteString.Internal     as BS
import qualified Data.ByteString.Lazy         as LBS
import qualified Data.Vector.Storable         as DVS
import qualified Data.Vector.Storable.Mutable as DVSM

#if !MIN_VERSION_base(4,13,0)
import Control.Applicative ((<$>)) -- Fix warning in ghc >= 9.2
#endif

class AsVector8ns a where
  -- | Represent the value as a list of storable 'Word8' vectors, using the
  -- first argument @n@ as the chunk size (it must be positive).
  --
  -- For the instances in this module every vector's length is a positive
  -- multiple of @n@: input that is already aligned may be emitted as a single
  -- vector spanning several chunks, while bytes that do not fill a whole
  -- multiple of @n@ are gathered (across input chunk boundaries) into vectors
  -- of exactly @n@ bytes, the last of which is padded with trailing zeros.
  asVector8ns :: Int -> a -> [DVS.Vector Word8]

instance AsVector8ns LBS.ByteString where
  -- Split the lazy ByteString into its strict chunks and delegate to the
  -- '[BS.ByteString]' instance.
  asVector8ns n = asVector8ns n . LBS.toChunks
  {-# INLINE asVector8ns #-}

instance AsVector8ns [BS.ByteString] where
  -- All re-chunking logic lives in 'bytestringsToVectors'.
  asVector8ns = bytestringsToVectors
  {-# INLINE asVector8ns #-}

-- | Re-chunk a list of strict 'BS.ByteString's into storable 'Word8' vectors
-- whose lengths are positive multiples of the chunk size @n@.
--
-- * A chunk whose length is already a multiple of @n@ is emitted whole via
--   'asVector8'.
-- * A longer chunk that is not a multiple of @n@ is split: its largest
--   multiple-of-@n@ prefix is emitted and the remainder is pushed back onto
--   the input.
-- * A non-empty chunk shorter than @n@ is coalesced with the chunks that
--   follow it into one vector of exactly @n@ bytes by 'buildOneVector',
--   zero-padded if the input runs out.
-- * Empty chunks are skipped.
--
-- Calls 'error' if @n <= 0@ and the input contains at least one chunk, since
-- no chunking is meaningful in that case.
bytestringsToVectors :: Int -> [BS.ByteString] -> [DVS.Vector Word8]
bytestringsToVectors n = go
  where go :: [BS.ByteString] -> [DVS.Vector Word8]
        go [] = []
        go bss@(cs:css)
          | n <= 0           = error ("HaskellWorks.Data.Vector.AsVector8ns.bytestringsToVectors: chunk size must be positive, got " ++ show n)
          | csz >= n, r == 0 = asVector8 cs : go css
          | csz >= n         = asVector8 (BS.take p cs) : go (BS.drop p cs : css)
          | csz > 0          = case DVS.createT (buildOneVector n bss) of
              -- buildOneVector hands back the unconsumed input with the filled vector.
              (dss, ws) -> if DVS.null ws then [] else ws : go dss
          | otherwise        = go css
          where csz = BS.length cs
                -- r: trailing bytes that do not fill a whole chunk;
                -- p: length of the largest multiple-of-n prefix (n > 0 here).
                r   = csz `rem` n
                p   = csz - r
{-# INLINE bytestringsToVectors #-}

-- | Fill one freshly allocated mutable vector of exactly @n@ bytes from the
-- front of the given chunks, returning the unconsumed input and the vector.
--
-- Leading empty chunks are skipped; if nothing remains, an empty vector is
-- returned together with @[]@.  Otherwise chunks are copied in order until the
-- vector is full.  A chunk that does not fit is split, and its remainder is
-- returned at the head of the leftover list.  If the input runs out first, the
-- unfilled tail of the vector is set to zero.
--
-- NOTE(review): @n@ is assumed positive; 'bytestringsToVectors' only calls
-- this when @n@ exceeds the length of a non-empty chunk.
buildOneVector :: forall s. Int -> [BS.ByteString] -> ST s ([BS.ByteString], DVS.MVector s Word8)
buildOneVector n ss = case dropWhile BS.null ss of
  [] -> ([],) <$> DVSM.new 0
  cs -> do
    -- Every byte of the buffer is written by 'go' (copied or zeroed), so the
    -- uninitialised allocation is never observed.
    buf <- DVSM.unsafeNew n
    rs  <- go cs buf
    return (rs, buf)
  where go :: [BS.ByteString] -> DVSM.MVector s Word8 -> ST s [BS.ByteString]
        go ts v
          | DVSM.null v = return ts
          | otherwise   = case ts of
              (u:us)
                | BS.length u <= DVSM.length v -> case DVSM.splitAt (BS.length u) v of
                    (va, vb) -> do
                      -- The whole chunk fits: copy it and continue into the rest.
                      DVSM.copy va (byteStringToVector8 u)
                      go us vb
                | otherwise -> case BS.splitAt (DVSM.length v) u of
                    (ua, ub) -> do
                      -- Only a prefix fits: fill the buffer and keep the remainder.
                      DVSM.copy v (byteStringToVector8 ua)
                      return (ub : us)
              [] -> do
                -- Input exhausted: zero-pad the remaining space.
                DVSM.set v 0
                return []
        {-# INLINE go #-}
{-# INLINE buildOneVector #-}

-- | View a strict 'BS.ByteString' as a mutable storable vector over the same
-- buffer (offset and length taken from 'BS.toForeignPtr').
--
-- The result aliases the ByteString's memory, so it must only be read.  In
-- this module it is used solely as the source argument of 'DVSM.copy'.
-- Writing through it would mutate a supposedly immutable ByteString.
byteStringToVector8 :: BS.ByteString -> DVSM.MVector s Word8
byteStringToVector8 bs = case BS.toForeignPtr bs of
  -- castForeignPtr is Word8 -> Word8 here, i.e. an identity re-typing.
  (fptr, off, len) -> DVSM.unsafeFromForeignPtr (castForeignPtr fptr) off len
{-# INLINE byteStringToVector8 #-}