{-# LANGUAGE BangPatterns #-}
{-# LANGUAGE MagicHash #-}
{-# LANGUAGE MultiParamTypeClasses #-}
{-# LANGUAGE MultiWayIf #-}
{-# LANGUAGE UnboxedTuples #-}
{-# OPTIONS_GHC -fno-warn-orphans #-}

module Codec.Compression.Zlib.OutputWindow (
  OutputWindow,
  emptyWindow,
  emitExcess,
  finalizeWindow,
  addByte,
  addChunk,
  addOldChunk,
) where

import Control.Monad (foldM)
import qualified Data.ByteString as S
import qualified Data.ByteString.Lazy as L
import qualified Data.ByteString.Short as SBS
import Data.ByteString.Short.Internal (ShortByteString (SBS))
import qualified Data.Primitive as Prim
import qualified Data.Vector.Primitive as V
import qualified Data.Vector.Primitive.Mutable as MV
import GHC.ST (ST (..))
import GHC.Word (Word8 (..))

-- | Number of bytes allocated for the window buffer by 'emptyWindow': 128 KiB.
windowSize :: Int
windowSize = 128 * 1024

{- | A mutable output buffer living in the 'ST' state thread @s@, together with
 the position at which the next byte will be written.
-}
data OutputWindow s = OutputWindow
  { owWindow :: {-# UNPACK #-} !(MV.MVector s Word8)
  -- ^ The backing byte buffer.
  , owNext :: {-# UNPACK #-} !Int
  -- ^ Index of the next free slot; bytes in @[0, owNext)@ have been written.
  }

{- | Allocate a fresh window of 'windowSize' bytes with the write position set
 to 0. The buffer comes from 'MV.new', so its contents are not initialised here.
-}
emptyWindow :: ST s (OutputWindow s)
emptyWindow = do
  window <- MV.new windowSize
  return (OutputWindow window 0)

{- | Number of bytes handed out by a single successful 'emitExcess' call
 (32 KiB). 'emitExcess' only fires once at least twice this many bytes are
 buffered.
-}
excessChunkSize :: Int
excessChunkSize = 32768

{- | Drain the oldest 'excessChunkSize' bytes from the window, if enough data
 has accumulated.

 Returns 'Nothing' while fewer than @2 * excessChunkSize@ bytes are buffered.
 Otherwise returns a copy of the first 'excessChunkSize' bytes together with
 a window in which the remaining bytes have been shifted down to index 0 and
 'owNext' has been reduced accordingly. At least 'excessChunkSize' bytes
 therefore stay in the window after a successful call (presumably so that
 'addOldChunk' can still reach back into them -- confirm against the caller).
-}
emitExcess :: OutputWindow s -> ST s (Maybe (S.ByteString, OutputWindow s))
emitExcess OutputWindow{owWindow = window, owNext = initialOffset}
  | initialOffset < excessChunkSize * 2 = return Nothing
  | otherwise = do
      -- 'V.freeze' copies, so the emitted bytes are unaffected by the move below.
      toEmit <- V.freeze $ MV.slice 0 excessChunkSize window
      let excessLength = initialOffset - excessChunkSize
      -- Need move as these can overlap!
      MV.move
        (MV.slice 0 excessLength window)
        (MV.slice excessChunkSize excessLength window)
      let ow' = OutputWindow window excessLength
      return (Just (SBS.fromShort $ toByteString toEmit, ow'))

{- | Extract everything written so far (indices @[0, owNext)@) as a strict
 'S.ByteString'.

 The buffer is frozen with 'V.unsafeFreeze', i.e. without copying, so the
 window must not be written to after this call.
-}
finalizeWindow :: OutputWindow s -> ST s S.ByteString
finalizeWindow ow = do
  -- safe as we're doing it at the end
  res <- V.unsafeFreeze (MV.slice 0 (owNext ow) (owWindow ow))
  pure $ SBS.fromShort $ toByteString res

-- -----------------------------------------------------------------------------

{- | Append a single byte at the current write position and advance 'owNext'
 by one. Both arguments are forced on entry. This function performs no
 capacity check of its own.
-}
addByte :: OutputWindow s -> Word8 -> ST s (OutputWindow s)
addByte !ow !b = do
  let offset = owNext ow
  MV.write (owWindow ow) offset b
  return ow{owNext = offset + 1}

{- | Append a lazy 'L.ByteString' to the window by copying each of its strict
 chunks in order with 'copyChunk', threading the updated window through.
-}
addChunk :: OutputWindow s -> L.ByteString -> ST s (OutputWindow s)
addChunk !ow !bs = foldM copyChunk ow (L.toChunks bs)

{- | Copy one strict 'S.ByteString' into the window at the current write
 position and advance 'owNext' by its length.
-}
copyChunk :: OutputWindow s -> S.ByteString -> ST s (OutputWindow s)
copyChunk ow sbstr = do
  -- safe as we're never going to look at this again
  ba <- V.unsafeThaw $ fromByteString $ SBS.toShort sbstr
  let offset = owNext ow
      len = MV.length ba
  MV.copy (MV.slice offset len (owWindow ow)) ba
  return ow{owNext = offset + len}

{- | Append @len@ bytes that start @dist@ bytes before the current write
 position, and return the updated window together with a copy of the bytes
 that were appended.
-}
addOldChunk :: OutputWindow s -> Int -> Int -> ST s (OutputWindow s, S.ByteString)
addOldChunk (OutputWindow window next) dist len = do
  let dest = MV.slice next len window
      src = MV.slice (next - dist) len window
  -- zlib can ask us to copy an "old" chunk that extends past our current offset.
  -- The intention is that we then start copying the "new" data we just copied into
  -- place. 'copyChunked' handles this for us.
  copyChunked dest src dist
  -- 'V.freeze' copies, so the returned bytes are independent of later writes.
  result <- V.freeze dest
  return (OutputWindow window (next + len), SBS.fromShort $ toByteString result)

{- | A copy function that copies the buffers sequentially in chunks no larger than
 the stated size. This allows us to handle the insane zlib behaviour.

 The number of bytes copied is the length of @src@. Chunks are copied front
 to back, so when @dest@ starts @chunkSize@ bytes after @src@ in the same
 buffer, each chunk reads bytes that an earlier chunk has already written.

 A non-positive @chunkSize@ with a non-empty @src@ is rejected with 'error':
 a chunk size of zero would otherwise never make progress and loop forever.
-}
copyChunked :: MV.MVector s Word8 -> MV.MVector s Word8 -> Int -> ST s ()
copyChunked dest src chunkSize = go 0 (MV.length src)
 where
  -- @go copied toCopy@: @copied@ bytes are done, @toCopy@ remain.
  go _ 0 = pure ()
  go !copied !toCopy
    | chunkSize <= 0 =
        error "Codec.Compression.Zlib.OutputWindow.copyChunked: chunk size must be positive"
    | otherwise = do
        let thisChunkSize = min toCopy chunkSize
        MV.copy
          (MV.slice copied thisChunkSize dest)
          (MV.slice copied thisChunkSize src)
        go (copied + thisChunkSize) (toCopy - thisChunkSize)

-- TODO: these are a bit questionable. Maybe we can just pass around Vector Word8 in the client code?

{- | View a 'SBS.ShortByteString' as a primitive vector of bytes. The
 underlying byte array is shared, not copied.
-}
fromByteString :: SBS.ShortByteString -> V.Vector Word8
fromByteString (SBS ba) =
  let byteLen = Prim.sizeofByteArray (Prim.ByteArray ba)
      sz = Prim.sizeOf (undefined :: Word8)
   in -- The vector length is an element count: bytes divided by element size.
      V.Vector 0 (byteLen `quot` sz) (Prim.ByteArray ba)

{- | Copy the bytes covered by a primitive vector into a fresh
 'SBS.ShortByteString'. The vector's element offset and length are converted
 to byte units before the underlying array is cloned.
-}
toByteString :: V.Vector Word8 -> SBS.ShortByteString
toByteString (V.Vector offset len ba) =
  let sz = Prim.sizeOf (undefined :: Word8)
      -- The bang forces the clone as soon as the result is demanded.
      !(Prim.ByteArray ba') = Prim.cloneByteArray ba (offset * sz) (len * sz)
   in SBS ba'