-- Copyright (c) 2016-present, Facebook, Inc.
-- All rights reserved.
--
-- This source code is licensed under the BSD-style license found in
-- the LICENSE file in the root directory of this source tree. An
-- additional grant of patent rights can be found in the PATENTS file
-- in the same directory.

-- |
-- Module      : Codec.Compression.Zstd.Internal
-- Copyright   : (c) 2016-present, Facebook, Inc. All rights reserved.
--
-- License     : BSD3
-- Maintainer  : bryano@fb.com
-- Stability   : experimental
-- Portability : GHC
--
-- A fast lossless compression algorithm, targeting real-time
-- compression scenarios at zlib-level and better compression ratios.

module Codec.Compression.Zstd.Internal
    (
      CCtx(..)
    , DCtx(..)
    , compressWith
    , decompressWith
    , decompressedSize
    , withCCtx
    , withDCtx
    , withDict
    , trainFromSamples
    , getDictID
    ) where

import Codec.Compression.Zstd.Types (Decompress(..), Dict(..))
import Control.Exception.Base (bracket)
import Data.ByteString.Internal (ByteString(..))
import Data.Word (Word, Word8)
import Foreign.C.Types (CInt, CSize)
import Foreign.Marshal.Array (withArray)
import Foreign.ForeignPtr (withForeignPtr)
import Foreign.Ptr (Ptr, castPtr, plusPtr)
import System.IO.Unsafe (unsafePerformIO)
import qualified Codec.Compression.Zstd.FFI as C
import qualified Data.ByteString as B
import qualified Data.ByteString.Internal as B

-- | Compress a strict 'ByteString' using the supplied low-level
-- compression function.
--
-- The 'String' names the calling function and is used to label any
-- error raised through 'bail' or 'handleError'.  The compressor
-- callback is given, in order: the destination pointer, the
-- destination capacity, the source pointer, the source length, and
-- the compression level.  Its result is checked by 'handleError'
-- before being used as the number of bytes written.
--
-- Compression levels below 1 or above 'C.maxCLevel' are rejected.
compressWith
    :: String
    -> (Ptr Word8 -> CSize -> Ptr Word8 -> CSize -> CInt -> IO CSize)
    -> Int
    -> ByteString
    -> IO ByteString
compressWith name compressor level (PS sfp off len)
  | level >= 1 && level <= C.maxCLevel = withForeignPtr sfp compressFrom
  | otherwise = bail name "unsupported compression level"
  where
    srcLen :: CSize
    srcLen = fromIntegral len

    -- Compress the payload that starts 'off' bytes past the base pointer.
    compressFrom base = do
      bound <- C.compressBound srcLen
      outFp <- B.mallocByteString (fromIntegral bound)
      withForeignPtr outFp $ \out -> do
        written <- compressor out bound (base `plusPtr` off) srcLen
                              (fromIntegral level)
        handleError written name $
          let outLen = fromIntegral written
              -- Hand back the freshly allocated buffer when the output
              -- is tiny or fills at least half of it; otherwise copy
              -- the output into an exactly sized buffer.
              keepBuffer = written < 128 || written >= bound `div` 2
          in if keepBuffer
               then return (PS outFp 0 outLen)
               else B.create outLen (\fresh -> B.memcpy fresh out outLen)

-- | Return the decompressed size of a compressed payload, as stored
-- in the payload's header.
--
-- The result is 'Nothing' when the size is reported as zero (either
-- not known, probably because the payload was compressed using a
-- streaming API, or empty), or when it is too large to fit in an
-- 'Int'.
--
-- /Note:/ this value should not be trusted, as it can be controlled
-- by an attacker.
decompressedSize :: ByteString -> Maybe Int
decompressedSize (PS fp off len) =
  unsafePerformIO $ withForeignPtr fp $ \base ->
    fmap classify
         (C.getDecompressedSize (base `plusPtr` off) (fromIntegral len))
  where
    -- Map the raw size reported by the library onto the public result.
    classify sz
      | sz == 0                             = Nothing
      | sz > fromIntegral (maxBound :: Int) = Nothing
      | otherwise                           = Just (fromIntegral sz)

-- | Decompress a strict 'ByteString' using the supplied low-level
-- decompression function.
--
-- The decompressor callback is given, in order: the destination
-- pointer, the destination capacity, the source pointer, and the
-- source length.
--
-- The destination size is taken from 'C.getDecompressedSize'.  A
-- reported size of zero yields 'Skip'; a size that does not fit in an
-- 'Int' yields an 'Error'.  Otherwise the payload is decompressed into
-- a freshly allocated buffer, and the callback's result is either
-- reported as an 'Error' (when 'C.isError' says so) or used as the
-- length of the returned 'Decompress' payload.
decompressWith :: (Ptr Word8 -> CSize -> Ptr Word8 -> CSize -> IO CSize)
               -> ByteString
               -> IO Decompress
decompressWith decompressor (PS sfp off len) =
  withForeignPtr sfp $ \base ->
    let src = base `plusPtr` off
    in C.getDecompressedSize src srcLen >>= inflate src
  where
    srcLen :: CSize
    srcLen = fromIntegral len

    -- Decide what to do based on the size reported for the payload.
    inflate src expected
      | expected == 0 = return Skip
      | expected > fromIntegral (maxBound :: Int) =
          return (Error "invalid compressed payload size")
      | otherwise = do
          outFp <- B.mallocByteString (fromIntegral expected)
          result <- withForeignPtr outFp $ \out ->
            decompressor out (fromIntegral expected) src srcLen
          return $
            if C.isError result
              then Error (C.getErrorName result)
              else Decompress (PS outFp 0 (fromIntegral result))

-- | Compression context: a thin wrapper around a pointer to the
-- foreign 'C.CCtx' structure.
newtype CCtx = CCtx
    { getCCtx :: Ptr C.CCtx
      -- ^ The wrapped raw context pointer.
    }

-- | Allocate a compression context, hand it to an action that may
-- reuse it as many times as it needs, and free the context once the
-- action finishes.  Allocation and release are paired with 'bracket'.
withCCtx :: (CCtx -> IO a) -> IO a
withCCtx act = bracket acquire release act
  where
    -- The allocation result is checked by 'C.checkAlloc' before wrapping.
    acquire = do
      raw <- C.checkAlloc "withCCtx" C.createCCtx
      return (CCtx raw)
    release (CCtx raw) = C.freeCCtx raw

-- | Decompression context: a thin wrapper around a pointer to the
-- foreign 'C.DCtx' structure.
newtype DCtx = DCtx
    { getDCtx :: Ptr C.DCtx
      -- ^ The wrapped raw context pointer.
    }

-- | Allocate a decompression context, hand it to an action that may
-- reuse it as many times as it needs, and free the context once the
-- action finishes.  Allocation and release are paired with 'bracket'.
withDCtx :: (DCtx -> IO a) -> IO a
withDCtx act = bracket acquire release act
  where
    -- The allocation result is checked by 'C.checkAlloc' before wrapping.
    acquire = do
      raw <- C.checkAlloc "withDCtx" C.createDCtx
      return (DCtx raw)
    release (DCtx raw) = C.freeDCtx raw

-- | Run an action with a pointer to the first byte of a dictionary
-- and the dictionary's length in bytes.  The pointer is only
-- obtained inside 'withForeignPtr', so it should not be retained
-- after the action returns.
withDict :: Dict -> (Ptr dict -> CSize -> IO a) -> IO a
withDict (Dict (PS fp off len)) act = withForeignPtr fp invoke
  where
    -- Skip to the dictionary's starting offset within the buffer.
    invoke base = act (base `plusPtr` off) (fromIntegral len)

-- | Create and train a compression dictionary from a collection of
-- samples.
--
-- To create a well-trained dictionary, here are some useful
-- guidelines to keep in mind:
--
-- * A reasonable dictionary size is in the region of 100 KB.  (Trying
--   to specify a dictionary size of less than a few hundred bytes will
--   probably fail.)
--
-- * To train the dictionary well, it is best to supply a few thousand
--   training samples.
--
-- * The combined size of all training samples should be 100 or more
--   times larger than the size of the dictionary.
--
-- Returns 'Left' with the library's error name if training fails.
trainFromSamples :: Int
                 -- ^ Maximum size of the compression dictionary to
                 -- create. The actual dictionary returned may be
                 -- smaller.
                 -> [ByteString]
                 -- ^ Samples to train with.
                 -> Either String Dict
trainFromSamples capacity samples = unsafePerformIO $
  withArray (map B.length samples) $ \sizes -> do
    dfp <- B.mallocByteString capacity
    -- 'B.concat' may return one of its inputs unchanged rather than a
    -- fresh copy (e.g. when only one sample is non-empty), so the
    -- joined buffer can start at a non-zero offset into its
    -- 'ForeignPtr'.  That offset must be applied to the pointer below.
    let PS sfp soff _ = B.concat samples
    withForeignPtr dfp $ \dict ->
      withForeignPtr sfp $ \sampBase -> do
        -- NOTE(review): 'castPtr' reinterprets the array of 'Int'
        -- sample sizes as 'CSize'; this relies on both types having
        -- the same width on the target platform.
        dsz <- C.trainFromBuffer
               dict (fromIntegral capacity)
               (sampBase `plusPtr` soff) (castPtr sizes)
               (fromIntegral (length samples))
        if C.isError dsz
          then return (Left (C.getErrorName dsz))
          else fmap (Right . Dict) $ do
            let size = fromIntegral dsz
            -- Hand back the freshly allocated buffer when the
            -- dictionary is tiny or fills at least half of it;
            -- otherwise copy it into an exactly sized buffer.
            if size < 128 || size >= capacity `div` 2
              then return (PS dfp 0 size)
              else B.create size $ \p -> B.memcpy p dict size

-- | Return the identifier for the given dictionary, or 'Nothing' if
-- not a valid dictionary (signalled by the underlying call reporting
-- an identifier of zero).
getDictID :: Dict -> Maybe Word
getDictID dict = unsafePerformIO $ do
  rawId <- withDict dict C.getDictID
  -- Force the zero check before leaving the 'IO' action.
  return $! toMaybe rawId
  where
    toMaybe 0 = Nothing
    toMaybe n = Just (fromIntegral n)

-- | Inspect a size-or-error code returned by the library.  If
-- 'C.isError' flags it as an error, abort via 'bail', labelling the
-- message with the given function name and the library's error name;
-- otherwise run the supplied action.
handleError :: CSize -> String -> IO a -> IO a
handleError sizeOrError func act =
  if C.isError sizeOrError
    then bail func (C.getErrorName sizeOrError)
    else act

-- | Abort with 'error', using a message of the form
-- @Codec.Compression.Zstd.\<func\>: \<str\>@.
bail :: String -> String -> a
bail func str = error (concat ["Codec.Compression.Zstd.", func, ": ", str])