{-# language BangPatterns #-}
{-# language BlockArguments #-}
{-# language DeriveAnyClass #-}
{-# language DerivingStrategies #-}
{-# language MagicHash #-}
{-# language UnboxedTuples #-}
{-# language UnliftedFFITypes #-}

-- | Compress a contiguous sequence of bytes into a single LZ4 block, and
-- decompress such blocks. These functions do not perform any framing.
module Lz4.Block
  ( -- * Compression
    compress
  , compressU
  , compressHighly
  , compressHighlyU
    -- * Decompression
  , decompress
  , decompressU
    -- * Unsafe Compression
  , compressInto
    -- * Computing buffer size
  , requiredBufferSize
  ) where

import Lz4.Internal (requiredBufferSize,c_hs_compress_HC)

import Control.Monad.ST (runST)
import Control.Monad.ST.Run (runByteArrayST)
import Data.Bytes.Types (Bytes(Bytes))
import Data.Primitive (MutableByteArray(..),ByteArray(..))
import GHC.Exts (ByteArray#,MutableByteArray#)
import GHC.IO (unsafeIOToST)
import GHC.ST (ST(ST))

import qualified Control.Exception
import qualified Data.Primitive as PM
import qualified GHC.Exts as Exts

-- | Compress bytes with LZ4's high-compression (HC) algorithm. This is
-- slower than 'compress' but produces better compression. Raising the
-- compression level improves the ratio at the cost of speed.
-- Behavior is undefined for byte sequences longer than
-- 2,113,929,216 bytes. This calls @LZ4_compress_HC@.
compressHighly ::
     Int -- ^ Compression level (Use 9 if uncertain)
  -> Bytes -- ^ Bytes to compress
  -> Bytes
compressHighly !lvl (Bytes (ByteArray src) srcOff srcLen) = runST do
  -- Allocate the worst-case size up front so the C routine always has
  -- room, then trim the array down to the length it reports.
  out@(MutableByteArray out#) <- PM.newByteArray bound
  written <- unsafeIOToST
    (c_hs_compress_HC src srcOff out# 0 srcLen bound lvl)
  PM.shrinkMutableByteArray out written
  frozen <- PM.unsafeFreezeByteArray out
  pure (Bytes frozen 0 written)
  where
    bound = requiredBufferSize srcLen

-- | Like 'compressHighly', but returns the compressed bytes as an
-- unsliced 'ByteArray' whose size is exactly the compressed length.
compressHighlyU ::
     Int -- ^ Compression level (Use 9 if uncertain)
  -> Bytes -- ^ Bytes to compress
  -> ByteArray
compressHighlyU !lvl (Bytes (ByteArray src) srcOff srcLen) = runST do
  let capacity = requiredBufferSize srcLen
  buf@(MutableByteArray buf#) <- PM.newByteArray capacity
  -- The C call's result is the final length; shrink the buffer to it.
  unsafeIOToST (c_hs_compress_HC src srcOff buf# 0 srcLen capacity lvl)
    >>= PM.shrinkMutableByteArray buf
  PM.unsafeFreezeByteArray buf

-- | Compress bytes using LZ4. A higher acceleration factor increases
-- speed but decreases compression. Behavior is undefined for byte
-- sequences longer than 2,113,929,216 bytes. This calls the C helper
-- @hs_compress_fast@, passing the acceleration factor through to it.
compress ::
     Int -- ^ Acceleration Factor (Use 1 if uncertain)
  -> Bytes -- ^ Bytes to compress
  -> Bytes
compress !lvl (Bytes (ByteArray src) srcOff srcLen) = runST do
  let bound = requiredBufferSize srcLen
  -- Worst-case allocation; trimmed below to the length actually written.
  out@(MutableByteArray out#) <- PM.newByteArray bound
  -- NOTE(review): LZ4_compress_default takes no acceleration argument, so
  -- the hs_compress_fast shim presumably wraps LZ4_compress_fast;
  -- confirm against the C sources.
  written <- unsafeIOToST
    (c_hs_compress_fast src srcOff out# 0 srcLen bound lvl)
  PM.shrinkMutableByteArray out written
  frozen <- PM.unsafeFreezeByteArray out
  pure (Bytes frozen 0 written)

-- | Compress bytes using LZ4, pasting the compressed bytes into the
-- mutable byte array at the specified offset.
--
-- Precondition: There must be at least
-- @'requiredBufferSize' (Bytes.length src)@ bytes available starting
-- from the offset in the destination buffer. This is checked against
-- both the caller-supplied count of remaining bytes and the actual size
-- of the destination array. If either is too small, or if the offset is
-- negative, this function throws @Lz4BufferTooSmall@ instead of writing.
compressInto ::
     Int -- ^ Acceleration Factor (Use 1 if uncertain)
  -> Bytes -- ^ Bytes to compress
  -> MutableByteArray s -- ^ Destination buffer
  -> Int -- ^ Offset into destination buffer
  -> Int -- ^ Bytes remaining in destination buffer
  -> ST s Int -- ^ Next available offset in destination buffer
compressInto !lvl (Bytes (ByteArray arr) off len)
    dst@(MutableByteArray dst#) !doff !dlen = do
  let maxSz = requiredBufferSize len
  -- Do not trust dlen alone: the unsafe C call is told it may write
  -- maxSz bytes starting at doff, so the real array must have that much
  -- room as well. Otherwise a wrong dlen or doff corrupts memory.
  physicalSz <- PM.getSizeofMutableByteArray dst
  if doff < 0 || dlen < maxSz || physicalSz - doff < maxSz
    then unsafeIOToST (Control.Exception.throwIO Lz4BufferTooSmall)
    else do
      actualSz <- unsafeIOToST
        (c_hs_compress_fast arr off dst# doff len maxSz lvl)
      pure (doff + actualSz)

-- | Like 'compress', but returns the compressed bytes as an unsliced
-- 'ByteArray' whose size is exactly the compressed length.
compressU ::
     Int -- ^ Acceleration Factor (Use 1 if uncertain)
  -> Bytes -- ^ Bytes to compress
  -> ByteArray
compressU !lvl (Bytes (ByteArray src) srcOff srcLen) = runByteArrayST do
  buf@(MutableByteArray buf#) <- PM.newByteArray capacity
  n <- unsafeIOToST (c_hs_compress_fast src srcOff buf# 0 srcLen capacity lvl)
  -- Drop the unused tail of the worst-case allocation before freezing.
  PM.shrinkMutableByteArray buf n
  PM.unsafeFreezeByteArray buf
  where
    capacity = requiredBufferSize srcLen

-- | Decompress a byte sequence. Yields 'Nothing' when the decompressed
-- result does not have exactly the expected length.
decompress ::
     Int -- ^ Expected length of decompressed bytes
  -> Bytes -- ^ Compressed bytes
  -> Maybe Bytes
decompress !dstSz !b = wrap <$> decompressU dstSz b
  where
    -- On success the unsliced result is exactly dstSz bytes long.
    wrap arr = Bytes arr 0 dstSz

-- | Variant of 'decompress' with an unsliced result.
--
-- Returns 'Nothing' if the expected length is negative, or if the C
-- routine @hs_decompress_safe@ reports a decompressed length other than
-- the expected one.
decompressU ::
     Int -- ^ Expected length of decompressed bytes
  -> Bytes -- ^ Compressed bytes
  -> Maybe ByteArray
decompressU !dstSz (Bytes (ByteArray arr) off len)
  -- No decompression can produce a negative length, and asking the
  -- allocator for a negatively sized array is invalid, so reject it
  -- before allocating anything.
  | dstSz < 0 = Nothing
  | otherwise = runST do
      dst@(MutableByteArray dst#) <- PM.newByteArray dstSz
      actualSz <- unsafeIOToST
        (c_hs_decompress_safe arr off dst# 0 len dstSz)
      -- Accept only an exact match. A shorter result (or, presumably, a
      -- negative error code from the C side) means the input did not
      -- decompress to the expected length.
      if actualSz == dstSz
        then Just <$> PM.unsafeFreezeByteArray dst
        else pure Nothing

-- Raw binding to the C helper @hs_compress_fast@. Callers in this module
-- treat the returned value as the number of compressed bytes written at
-- the destination offset.
--
-- The source is an unlifted ByteArray# that need not be pinned. That is
-- acceptable only because the call is imported @unsafe@: GHC does not run
-- the garbage collector during an unsafe foreign call, so the array
-- cannot move while C reads from it.
foreign import ccall unsafe "hs_compress_fast"
  c_hs_compress_fast ::
       ByteArray# -- Source
    -> Int       -- Source offset
    -> MutableByteArray# s -- Destination
    -> Int       -- Destination offset
    -> Int       -- Input size
    -> Int       -- Destination capacity
    -> Int       -- Acceleration factor
    -> IO Int    -- Result length

-- Raw binding to the C helper @hs_decompress_safe@. 'decompressU' treats
-- the returned value as the decompressed length and accepts the result
-- only when it equals the expected length.
--
-- As with the compression binding, the unpinned ByteArray# arguments are
-- acceptable only because this is an @unsafe@ call, during which GHC does
-- not run the garbage collector.
foreign import ccall unsafe "hs_decompress_safe"
  c_hs_decompress_safe ::
       ByteArray# -- Source
    -> Int       -- Source offset
    -> MutableByteArray# s -- Destination
    -> Int       -- Destination offset
    -> Int       -- Input size
    -> Int       -- Destination capacity
    -> IO Int    -- Result length

-- | Exception thrown by 'compressInto' when the destination buffer does
-- not have room for a worst-case compressed output.
data Lz4BufferTooSmall = Lz4BufferTooSmall
  deriving stock (Show, Eq)
  deriving anyclass (Control.Exception.Exception)

-- foreign import capi "lz4.h value sizeof(LZ4_stream_t)" lz4StreamSz :: Int
-- 
-- allocateLz4StreamT :: ST s (MutableByteArray s)
-- allocateLz4StreamT = PM.newPinnedByteArray lz4StreamSz