{-# language BangPatterns #-}
{-# language BinaryLiterals #-}
{-# language BlockArguments #-}
{-# language DeriveAnyClass #-}
{-# language DerivingStrategies #-}
{-# language MagicHash #-}
{-# language MultiWayIf #-}
{-# language NumericUnderscores #-}
{-# language UnboxedTuples #-}
{-# language UnliftedFFITypes #-}

-- | Compress a contiguous sequence of bytes into an LZ4 frame
-- containing a single block.
module Lz4.Frame
  ( -- * Compression
    compressHighlyU
  ) where

import Lz4.Internal (requiredBufferSize,c_hs_compress_HC)

import Control.Monad.ST (runST)
import Control.Monad.ST.Run (runByteArrayST)
import Data.Bits ((.|.))
import Data.Bytes.Types (Bytes(Bytes))
import Data.Int (Int32)
import Data.Primitive (MutableByteArray(..),ByteArray(..))
import Data.Word (Word8,Word32)
import GHC.Exts (ByteArray#,MutableByteArray#)
import GHC.IO (unsafeIOToST)
import GHC.ST (ST(ST))

import qualified Control.Exception
import qualified Data.Primitive as PM
import qualified GHC.Exts as Exts
import qualified Data.Primitive.ByteArray.LittleEndian as LE

-- | Use HC compression to produce an LZ4 frame. All optional fields
-- (checksums, content sizes, and dictionary IDs) are omitted.
--
-- The frame descriptor advertises the smallest block maximum size
-- (64KiB, 256KiB, 1MiB, or 4MiB) that can hold the entire input, so any
-- input of at most 4MiB becomes a single block. Larger input is split
-- into consecutive, independently compressed 4MiB blocks, because the
-- frame format forbids a block larger than the advertised maximum.
--
-- For the same reason, a block whose compressed form would exceed the
-- block maximum size (possible for incompressible input whose length is
-- close to that maximum) is stored uncompressed instead.
compressHighlyU ::
     Int -- ^ Compression level (Use 9 if uncertain)
  -> Bytes -- ^ Bytes to compress
  -> ByteArray
compressHighlyU !lvl (Bytes src@(ByteArray arr) off len) = runST do
  let (blockMax, bd, hc) = frameDescriptor len
      -- 7-byte frame header, then the blocks, then the 4-byte end mark.
      maxSz = 7 + blocksCapacity blockMax len + 4
  dst@(MutableByteArray dst# ) <- PM.newByteArray maxSz
  -- First 4 bytes: magic number 0x184D2204, little endian
  PM.writeByteArray dst 0 (0x04 :: Word8)
  PM.writeByteArray dst 1 (0x22 :: Word8)
  PM.writeByteArray dst 2 (0x4D :: Word8)
  PM.writeByteArray dst 3 (0x18 :: Word8)
  -- Next 3 bytes: frame descriptor. FLG is version 01 with the
  -- block-independence bit set; BD and HC depend on the block maximum.
  PM.writeByteArray dst 4 (0b0110_0000 :: Word8)
  PM.writeByteArray dst 5 bd
  PM.writeByteArray dst 6 hc
  -- Write the block whose 4-byte size header lives at hdrOff, then any
  -- remaining blocks. Returns the offset just past the last block.
  -- Always writes at least one block (so empty input still yields one
  -- block), which is exactly what blocksCapacity reserves room for.
  let writeBlocks !srcOff !hdrOff !remaining = do
        let chunk = min blockMax remaining
            dataOff = hdrOff + 4
        -- Capacity is the worst-case compressed size of this chunk, the
        -- amount blocksCapacity reserved at dataOff.
        -- NOTE(review): assumes c_hs_compress_HC counts its capacity
        -- argument from the destination offset; if it ever reports 0
        -- (failure), the uncompressed path below still yields a valid block.
        compressed <- unsafeIOToST
          (c_hs_compress_HC arr srcOff dst# dataOff chunk
            (requiredBufferSize chunk) lvl)
        blockSz <-
          if compressed > 0 && compressed <= blockMax
            then do
              LE.writeUnalignedByteArray dst hdrOff
                (fromIntegral compressed :: Word32)
              pure compressed
            else do
              -- The highest bit of the block size marks the block as
              -- uncompressed; chunk never exceeds blockMax.
              PM.copyByteArray dst dataOff src srcOff chunk
              LE.writeUnalignedByteArray dst hdrOff
                (fromIntegral chunk .|. 0x8000_0000 :: Word32)
              pure chunk
        let nextOff = dataOff + blockSz
        if remaining <= blockMax
          then pure nextOff
          else writeBlocks (srcOff + chunk) nextOff (remaining - chunk)
  endOff <- writeBlocks off 7 len
  -- End mark: a block size of zero.
  LE.writeUnalignedByteArray dst endOff (0 :: Word32)
  PM.shrinkMutableByteArray dst (endOff + 4)
  PM.unsafeFreezeByteArray dst

-- | Choose the block maximum size for an input of the given length: the
-- smallest supported size that holds the whole input, capped at 4MiB.
-- Also returns the matching BD byte and the header checksum (HC) byte
-- precomputed for a descriptor whose FLG byte is @0b0110_0000@.
frameDescriptor :: Int -> (Int, Word8, Word8)
frameDescriptor n
  | n <= 65_536 = (65_536, 0b0100_0000, 0x82)
  | n <= 262_144 = (262_144, 0b0101_0000, 0xFB)
  | n <= 1_048_576 = (1_048_576, 0b0110_0000, 0x51)
  | otherwise = (4_194_304, 0b0111_0000, 0x73)

-- | Bytes to reserve for the blocks 'compressHighlyU' writes when @len@
-- input bytes are split into chunks of at most @blockMax@ bytes: per
-- block, a 4-byte size header plus 'requiredBufferSize' of its chunk.
-- Mirrors the writing loop, so it always counts at least one block.
blocksCapacity :: Int -> Int -> Int
blocksCapacity blockMax = go 0
  where
    go !acc remaining =
      let chunk = min blockMax remaining
          acc' = acc + 4 + requiredBufferSize chunk
       in if remaining <= blockMax then acc' else go acc' (remaining - chunk)