-- Copyright (c) 2016-present, Facebook, Inc.
-- All rights reserved.
--
-- This source code is licensed under the BSD-style license found in
-- the LICENSE file in the root directory of this source tree. An
-- additional grant of patent rights can be found in the PATENTS file
-- in the same directory.

{-# LANGUAGE MultiWayIf #-}

-- |
-- Module      : Codec.Compression.Zstd.Lazy
-- Copyright   : (c) 2016-present, Facebook, Inc. All rights reserved.
--
-- License     : BSD3
-- Maintainer  : bryano@fb.com
-- Stability   : experimental
-- Portability : GHC
--
-- Lazy compression and decompression support for zstd.  Under the
-- hood, these are implemented using the streaming APIs.

module Codec.Compression.Zstd.Lazy
    (
      compress
    , decompress
    , S.maxCLevel
    ) where

import Data.ByteString.Lazy.Internal as L
import System.IO.Unsafe (unsafeInterleaveIO, unsafePerformIO)
import qualified Codec.Compression.Zstd.Streaming as S
import qualified Data.ByteString as B
import qualified Data.ByteString.Lazy as L

-- | Compress a payload.  The input is consumed lazily, and the
-- compressed result is generated lazily, one chunk at a time as it is
-- demanded.
--
-- /Note:/ if any error occurs, compression will fail part-way through
-- with a call to 'error'.
compress :: Int
         -- ^ Compression level. Must be >= 1 and <= 'S.maxCLevel'.
         -> ByteString
         -- ^ Payload to compress.  This will be consumed lazily.
         -> ByteString
compress level = lazy (S.compress level)

-- | Decompress a payload.  The input is consumed lazily, and the
-- decompressed result is generated lazily, one chunk at a time as it
-- is demanded.
--
-- /Note:/ if any error occurs, decompression will fail part-way
-- through with a call to 'error'.
decompress :: ByteString
           -- ^ Compressed payload.  This will be consumed lazily.
           -> ByteString
decompress = lazy S.decompress

-- | Drive a zstd streaming state machine over a lazy input and expose
-- its output as a lazy 'ByteString'.
--
-- Starting from the state returned by @start@, the loop reacts to each
-- 'S.Result' as follows:
--
-- * 'S.Consume': feed the next strict chunk of the input.  Once the
--   input is exhausted, feed 'B.empty' instead (presumably signalling
--   end of input to the stream — confirm against
--   "Codec.Compression.Zstd.Streaming").
--
-- * 'S.Produce': emit the output chunk and defer the rest of the stream
--   with 'unsafeInterleaveIO', so later output is only computed when it
--   is demanded.
--
-- * 'S.Done': emit the final output chunk; accepted only once all input
--   has been consumed.
--
-- * 'S.Error': abort with a call to 'error'.
--
-- Any other combination of input and state (e.g. 'S.Done' while input
-- remains) is also reported with 'error'.
--
-- NOTE(review): running the loop under 'unsafePerformIO' assumes the
-- streaming API is deterministic and its state is private to each
-- @start@ action — confirm before reusing this helper elsewhere.
lazy :: IO S.Result
     -- ^ Action creating the initial stream state.
     -> ByteString
     -- ^ Input, consumed one chunk at a time.
     -> ByteString
lazy start b0 = unsafePerformIO (start >>= step b0)
  where
    step :: ByteString -> S.Result -> IO ByteString
    step _ (S.Error who what) =
      error (who ++ ": " ++ what)
    step input (S.Produce out next) = do
      -- Defer the remainder so output is generated on demand.
      rest <- unsafeInterleaveIO (next >>= step input)
      return (L.chunk out rest)
    step (Chunk c cs) (S.Consume feed) =
      feed c >>= step cs
    step Empty (S.Consume feed) =
      -- Input exhausted: keep answering Consume with an empty chunk.
      feed B.empty >>= step Empty
    step Empty (S.Done out) =
      return (L.chunk out Empty)
    step input state =
      error $ "unpossible! " ++
              show (L.length input) ++ " bytes of input left, " ++
              show state ++ " stream state"