-- |
-- Module      : System.IO.Streams.Lzma
-- Copyright   : © 2015 Herbert Valerio Riedel
-- License     : BSD3
--
-- Maintainer  : hvr@gnu.org
-- Stability   : experimental
-- Portability : portable
--
-- Simple IO-Streams interface for lzma/xz compression
--
module System.IO.Streams.Lzma
    ( -- * 'ByteString' decompression
      decompress
    , decompressWith
    , DecodeLzmaFlags(..)
    , defaultDecodeLzmaFlags

      -- * 'ByteString' compression
    , compress
    , compressWith
    , EncodeLzmaFlags(..)
    , LzmaCheck(..)
    , LzmaCompLevel(..)
    , defaultEncodeLzmaFlags

    ) where

import           Control.Exception
import           Control.Monad
import           Data.ByteString                  (ByteString)
import qualified Data.ByteString                  as BS
import           Data.IORef
import           LibLzma
import           System.IO.Streams       (InputStream, makeInputStream)
import qualified System.IO.Streams as Streams

-- | Decompress an lzma/xz-encoded input stream using
-- 'defaultDecodeLzmaFlags'.
--
-- Same as @'decompressWith' 'defaultDecodeLzmaFlags'@.
decompress :: InputStream ByteString -> IO (InputStream ByteString)
decompress compressed = decompressWith defaultDecodeLzmaFlags compressed

-- | Like 'decompress', but with caller-supplied 'DecodeLzmaFlags'.
--
-- The decoder is created immediately. If 'newDecodeLzmaStream' reports an
-- error, that error is thrown with 'throwIO'. Otherwise the returned stream
-- decodes chunks pulled from the given input stream as it is read.
decompressWith :: DecodeLzmaFlags -> InputStream ByteString -> IO (InputStream ByteString)
decompressWith flags ibs = do
    created <- newDecodeLzmaStream flags
    case created of
        Left  failure -> throwIO failure
        Right decoder -> wrapLzmaStream ibs decoder

-- | Compress an input stream to lzma/xz using 'defaultEncodeLzmaFlags'.
--
-- Same as @'compressWith' 'defaultEncodeLzmaFlags'@.
compress :: InputStream ByteString -> IO (InputStream ByteString)
compress plain = compressWith defaultEncodeLzmaFlags plain

-- | Like 'compress', but with caller-supplied 'EncodeLzmaFlags'.
--
-- The encoder is created immediately. If 'newEncodeLzmaStream' reports an
-- error, that error is thrown with 'throwIO'. Otherwise the returned stream
-- encodes chunks pulled from the given input stream as it is read.
compressWith :: EncodeLzmaFlags -> InputStream ByteString -> IO (InputStream ByteString)
compressWith flags ibs = do
    created <- newEncodeLzmaStream flags
    case created of
        Left  failure -> throwIO failure
        Right encoder -> wrapLzmaStream ibs encoder

-- | Drive an 'LzmaStream' (an encoder or a decoder) from an 'InputStream'
-- of input chunks, producing an 'InputStream' of output chunks.
--
-- Each read from the resulting stream pulls non-empty chunks from @ibs@ and
-- hands them to 'runLzmaStream'. It keeps going until 'runLzmaStream'
-- yields a non-empty output chunk or reports 'LZMA_STREAM_END'.
--
-- Input bytes that 'runLzmaStream' did not consume are pushed back onto
-- @ibs@ with 'Streams.unRead'. The unconsumed remainder therefore stays
-- available on @ibs@, including any bytes left over once 'LZMA_STREAM_END'
-- has been reached.
--
-- The first return code other than 'LZMA_OK' is latched in an 'IORef'.
-- After 'LZMA_STREAM_END', the final output chunk (if non-empty) is
-- returned and every later read yields 'Nothing'. Any other code is thrown
-- with 'throwIO' and re-thrown on every later read.
--
-- TODO: figure out sensible buffer-size
wrapLzmaStream :: InputStream ByteString -> LzmaStream -> IO (InputStream ByteString)
wrapLzmaStream ibs ls0 = do
    st <- newIORef (Right ls0)
    makeInputStream (go st)
  where
    -- The state is either the live codec ('Right') or its latched final
    -- return code ('Left').
    go st = readIORef st >>= either goLeft (goRight st)

    -- One read: step the codec until it produces output or stops. Each
    -- non-OK outcome latches the return code exactly once.
    goRight st ls = do
        (rc, obuf) <- getChunk >>= step ls
        case rc of
            LZMA_OK
                | BS.null obuf -> goRight st ls -- feed de/encoder some more
                | otherwise    -> return (Just obuf)
            LZMA_STREAM_END -> do
                writeIORef st (Left rc)
                return (if BS.null obuf then Nothing else Just obuf)
            _ -> do
                -- latch the error so later reads re-throw it via 'goLeft'
                writeIORef st (Left rc)
                throwIO rc

    -- Run a single 'runLzmaStream' step on the next input chunk. At end of
    -- input ('Nothing') the codec is called with an empty chunk and the flag
    -- set to 'True' (presumably "finish the stream" -- TODO confirm against
    -- LibLzma). Unconsumed input is pushed back onto @ibs@ so the next
    -- 'getChunk' sees it again.
    step ls Nothing = do
        (rc, _, obuf) <- runLzmaStream ls BS.empty True bUFSIZ
        return (rc, obuf)
    step ls (Just bs) = do
        (rc, consumed, obuf) <- runLzmaStream ls bs False bUFSIZ
        when (consumed < BS.length bs) $
            Streams.unRead (BS.drop consumed bs) ibs
        return (rc, obuf)

    -- Behaviour once the state is latched: a finished stream yields
    -- end-of-stream, while any other return code is re-thrown.
    goLeft err = case err of
        LZMA_STREAM_END -> return Nothing
        _               -> throwIO err

    -- Size argument passed to every 'runLzmaStream' call (presumably the
    -- maximum output chunk size; see the buffer-size TODO above).
    bUFSIZ = 32752

    -- Wrapper around 'Streams.read' on @ibs@ that skips empty chunks. It
    -- returns the first non-empty ByteString, or 'Nothing' at end of input.
    -- TODO: consider implementing flush semantics
    getChunk = do
        mbs <- Streams.read ibs
        case mbs of
            Just bs | BS.null bs -> getChunk
            _                    -> return mbs