{-# LANGUAGE ScopedTypeVariables #-}
{-# LANGUAGE CPP #-}

-----------------------------------------------------------------------------
-- |
-- Module      :  Distribution.Client.GZipUtils
-- Copyright   :  (c) Dmitry Astapov 2010
-- License     :  BSD-like
--
-- Maintainer  :  cabal-devel@gmail.com
-- Stability   :  provisional
-- Portability :  portable
--
-- Provides convenience functions for working with files that may or may not
-- be compressed (gzip or zlib format).
-----------------------------------------------------------------------------
module Distribution.Client.GZipUtils (
    maybeDecompress,
  ) where

import Prelude ()
import Distribution.Client.Compat.Prelude

import Codec.Compression.Zlib.Internal
import Data.ByteString.Lazy.Internal as BS (ByteString(Empty, Chunk))

#ifndef MIN_VERSION_zlib
#define MIN_VERSION_zlib(x,y,z) 1
#endif

#if MIN_VERSION_zlib(0,6,0)
import Control.Exception (throw)
import Control.Monad.ST.Lazy (ST, runST)
import qualified Data.ByteString as Strict
#endif

-- | Attempts to decompress the given bytes, under the assumption that a
-- decompression error occurring before any output has been produced means
-- that the input was never compressed in the first place; in that case the
-- input is returned unchanged. Callers should still sanity-check the result,
-- since uncompressed garbage passes straight through as well.
--
-- Once at least one chunk of decompressed output has been produced, a later
-- decompression error is treated as a real error and raised as a pure
-- exception (a 'DecompressError' with zlib >= 0.6, an 'error' call with older
-- zlib) when the affected part of the result is forced.
--
-- This is to deal with HTTP proxies that lie to us and transparently
-- decompress without removing the content-encoding header. See:
-- <https://github.com/haskell/cabal/issues/678>
--
maybeDecompress :: ByteString -> ByteString
#if MIN_VERSION_zlib(0,6,0)
maybeDecompress bytes = runST (go (const (return bytes)) bytes decompressor)
  where
    decompressor :: DecompressStream (ST s)
    decompressor = decompressST gzipOrZlibFormat defaultDecompressParams

    -- Drive the incremental decompressor, feeding it the remaining input
    -- @cs@ one strict chunk at a time.
    --
    -- The first argument decides what a decompression error turns into.
    -- It starts out as "return the original input unchanged": an error at
    -- the very beginning of the stream probably means that the stream is not
    -- compressed. As soon as any output has been produced it is switched to
    -- 'throw', so that later errors surface as (pure) exceptions.
    -- TODO: alternatively, we might consider looking for the two magic bytes
    -- at the beginning of the gzip header.  (not an option for zlib, though.)
    -- TODO: We could (and should) avoid these pure exceptions.
    go :: Monad m
       => (DecompressError -> m ByteString)
       -> ByteString
       -> DecompressStream m
       -> m ByteString
    go _     cs (DecompressOutputAvailable bs k) =
      liftM (Chunk bs) $ go throw cs =<< k
    go _     _  (DecompressStreamEnd _bs)        = return Empty
    go onErr _  (DecompressStreamError err)      = onErr err
    go onErr cs (DecompressInputRequired k)      = go onErr rest =<< k chunk
      where
        (chunk, rest) = uncons cs

    -- Split off the next strict chunk of input. Once the lazy input is
    -- exhausted we hand the decompressor an empty chunk, which zlib's
    -- incremental API interprets as end of input.
    uncons :: ByteString -> (Strict.ByteString, ByteString)
    uncons Empty          = (Strict.empty, Empty)
    uncons (Chunk c rest) = (c, rest)
#else
maybeDecompress bytes = foldStream $ decompressWithErrors gzipOrZlibFormat defaultDecompressParams bytes
  where
    -- DataError at the beginning of the stream probably means that stream is not compressed.
    -- Returning it as-is.
    -- TODO: alternatively, we might consider looking for the two magic bytes
    -- at the beginning of the gzip header.
    foldStream (StreamError _ _) = bytes
    foldStream somethingElse = doFold somethingElse

    -- After the first element, errors are fatal.
    doFold StreamEnd               = BS.Empty
    doFold (StreamChunk bs stream) = BS.Chunk bs (doFold stream)
    doFold (StreamError _ msg)  = error $ "Codec.Compression.Zlib: " ++ msg
#endif