{-# LANGUAGE DeriveDataTypeable        #-}
{-# LANGUAGE ExistentialQuantification #-}
{-# LANGUAGE OverloadedStrings         #-}
{-# LANGUAGE PackageImports            #-}
{-# LANGUAGE ScopedTypeVariables       #-}

module Snap.Util.GZip
( withCompression
, withCompression' ) where

import   qualified Codec.Compression.GZip as GZip
import   qualified Codec.Compression.Zlib as Zlib
import             Control.Concurrent
import             Control.Applicative hiding (many)
import             Control.Exception
import             Control.Monad
import             Control.Monad.Trans
import             Data.Attoparsec.Char8 hiding (Done)
import   qualified Data.ByteString.Char8 as S
import   qualified Data.ByteString.Lazy.Char8 as L
import             Data.ByteString.Char8 (ByteString)
import             Data.Maybe
import   qualified Data.Set as Set
import             Data.Set (Set)
import             Data.Typeable
import             Prelude hiding (catch, takeWhile)

------------------------------------------------------------------------------
import             Snap.Internal.Debug
import             Snap.Internal.Parsing
import             Snap.Iteratee
import             Snap.Types


------------------------------------------------------------------------------
-- | Runs a 'Snap' web handler, compressing its output when possible.
--
-- Compression is applied when the client's @Accept-Encoding@ header
-- indicates support for one of the compression schemes this module can
-- produce, and the @Content-Type@ of the response is one of the following
-- types:
--
--   * @application/x-javascript@
--
--   * @text/css@
--
--   * @text/html@
--
--   * @text/javascript@
--
--   * @text/plain@
--
--   * @text/xml@
--
--   * @application/x-font-truetype@
--
-- In that case the handler's output stream is compressed,
-- @Content-Encoding@ is set in the output headers, and @Content-Length@ is
-- cleared if it was set. (We can't process the stream in O(1) space if the
-- length is known beforehand.)
--
-- The wrapped handler is always run to completion; afterwards the
-- 'Response' held in the 'Snap' monad state is handed to 'finishWith' so
-- that no further processing takes place.
--
withCompression :: MonadSnap m
                => m a   -- ^ the web handler to run
                -> m ()
withCompression action = withCompression' compressibleMimeTypes action


------------------------------------------------------------------------------
-- | The same as 'withCompression', with control over which MIME types to
-- compress.
--
-- The wrapped handler is always run to completion and the resulting
-- response is passed to 'finishWith'. The body is compressed only when all
-- of the following hold:
--
--   * the response has no @Content-Encoding@ yet;
--
--   * the media type of its @Content-Type@ (the part before any @;@
--     parameters such as @charset@) is a member of the given set;
--
--   * the request's @Accept-Encoding@ lists @gzip@, @x-gzip@ or @deflate@.
--
-- @gzip@ is preferred over @x-gzip@, and both over @deflate@, regardless of
-- the order in which the client lists them. An @Accept-Encoding@ header
-- that cannot be parsed causes an exception to be thrown.
withCompression' :: MonadSnap m
                 => Set ByteString
                    -- ^ set of compressible MIME types
                 -> m a
                    -- ^ the web handler to run
                 -> m ()
withCompression' mimeTable action = do
    _    <- action
    resp <- getResponse

    -- If a content-encoding is already set, do nothing. This prevents
    -- "withCompression $ withCompression m" from ruining your day.
    when (isNothing $ getHeader "Content-Encoding" resp) $ do
       let mbCt = chop <$> getHeader "Content-Type" resp

       debug $ "withCompression', content-type is " ++ show mbCt

       case mbCt of
         (Just ct) -> when (Set.member ct mimeTable) chkAcceptEncoding
         _         -> return $! ()


    getResponse >>= finishWith

  where
    -- Strip media-type parameters, so that e.g. "text/html; charset=utf-8"
    -- is looked up in the MIME table as "text/html".
    chop = S.takeWhile (\ch -> ch /= ';' && ch /= ' ' && ch /= '\t')

    chkAcceptEncoding = do
        req <- getRequest
        debug $ "checking accept-encoding"
        let mbAcc = getHeader "Accept-Encoding" req
        debug $ "accept-encoding is " ++ show mbAcc
        let s = fromMaybe "" mbAcc

        types <- liftIO $ parseAcceptEncoding s

        chooseType types

    -- 'compressCompression' emits Zlib.compress output, i.e. the zlib
    -- format. In HTTP that is the "deflate" content-coding; "compress"
    -- denotes the LZW format of the Unix compress utility, which we cannot
    -- produce, so it is never advertised.
    chooseType types
        | "gzip"    `elem` types = gzipCompression "gzip"
        | "x-gzip"  `elem` types = gzipCompression "x-gzip"
        | "deflate" `elem` types = compressCompression "deflate"
        | otherwise              = return $! ()


------------------------------------------------------------------------------
-- Private helpers (not exported from this module)
------------------------------------------------------------------------------


------------------------------------------------------------------------------
-- | MIME types that 'withCompression' compresses by default.
compressibleMimeTypes :: Set ByteString
compressibleMimeTypes = Set.fromList (applicationTypes ++ textTypes)
  where
    applicationTypes, textTypes :: [ByteString]

    applicationTypes = [ "application/x-font-truetype"
                       , "application/x-javascript"
                       ]

    textTypes = [ "text/css"
                , "text/html"
                , "text/javascript"
                , "text/plain"
                , "text/xml"
                ]




------------------------------------------------------------------------------
-- | Rewrites the current response so that its body is gzip-compressed
-- (via 'gcompress') and labelled with the given content-coding name.
--
-- Any @Content-Length@ is cleared, since the compressed length is not
-- known before the body has been streamed.
gzipCompression :: MonadSnap m => ByteString -> m ()
gzipCompression ce = modifyResponse f
  where
    -- Vary is appended to with 'addHeader' rather than overwritten with
    -- 'setHeader', so values the handler already set (e.g. Cookie or
    -- Origin) survive; caches need all of them to key responses correctly.
    f = setHeader "Content-Encoding" ce .
        addHeader "Vary" "Accept-Encoding" .
        clearContentLength .
        modifyResponseBody gcompress


------------------------------------------------------------------------------
-- | Rewrites the current response so that its body is compressed with
-- 'ccompress' (zlib format) and labelled with the given content-coding
-- name.
--
-- Any @Content-Length@ is cleared, since the compressed length is not
-- known before the body has been streamed.
compressCompression :: MonadSnap m => ByteString -> m ()
compressCompression ce = modifyResponse f
  where
    -- Vary is appended to with 'addHeader' rather than overwritten with
    -- 'setHeader', so values the handler already set (e.g. Cookie or
    -- Origin) survive; caches need all of them to key responses correctly.
    f = setHeader "Content-Encoding" ce .
        addHeader "Vary" "Accept-Encoding" .
        clearContentLength .
        modifyResponseBody ccompress


------------------------------------------------------------------------------
-- FIXME: use zlib-bindings
-- | Transforms an enumerator so that its output is gzip-compressed, using
-- 'compressEnumerator' with 'GZip.compress'.
gcompress :: forall a . Enumerator ByteString IO a
          -> Enumerator ByteString IO  a
gcompress enum = compressEnumerator GZip.compress enum


------------------------------------------------------------------------------
-- | Transforms an enumerator so that its output is compressed into the
-- zlib format, using 'compressEnumerator' with 'Zlib.compress'.
ccompress :: forall a . Enumerator ByteString IO a
          -> Enumerator ByteString IO a
ccompress enum = compressEnumerator Zlib.compress enum


------------------------------------------------------------------------------
-- | Wraps an enumerator so that everything it produces is passed through
-- @compFunc@ before it reaches the downstream step.
--
-- Compression runs on a separate thread (started with 'forkIO') that
-- communicates over two channels:
--
--   * @readEnd@: uncompressed input, sent as @Just chunk@ and terminated by
--     @Nothing@; the compressor thread reads from it.
--
--   * @writeEnd@: compressed output chunks followed by 'EOF', or a @Left@
--     carrying an exception raised while compressing; the compressor thread
--     writes to it.
--
-- NOTE(review): the compressor thread is only killed on the 'EOF' path of
-- @f@. If the downstream step yields or errors early, it is not killed
-- here. Confirm that it is reclaimed some other way.
compressEnumerator :: forall a .
                      (L.ByteString -> L.ByteString)
                   -> Enumerator ByteString IO a
                   -> Enumerator ByteString IO a
compressEnumerator compFunc enum origStep = do
    writeEnd <- liftIO $ newChan
    readEnd  <- liftIO $ newChan
    tid      <- liftIO $ forkIO $ threadProc readEnd writeEnd

    enum (f readEnd writeEnd tid origStep)

  where
    --------------------------------------------------------------------------
    -- True for the end-of-stream marker, False for a batch of chunks.
    streamFinished :: Stream ByteString -> Bool
    streamFinished EOF        = True
    streamFinished (Chunks _) = False


    --------------------------------------------------------------------------
    -- Feeds the step every compressed chunk that is already waiting on
    -- writeEnd, without blocking for more; stops as soon as the channel is
    -- empty. An exception reported by the compressor thread is re-raised
    -- with throwError.
    -- NOTE(review): isEmptyChan is deprecated in newer releases of base;
    -- the FIXME on gcompress suggests zlib-bindings as the replacement.
    consumeSomeOutput :: Chan (Either SomeException (Stream ByteString))
                      -> Step ByteString IO a
                      -> Iteratee ByteString IO (Step ByteString IO a)
    consumeSomeOutput writeEnd step = do
        e <- lift $ isEmptyChan writeEnd
        if e
          then return step
          else do
            ech <- lift $ readChan writeEnd
            either throwError
                   (\ch -> do
                        step' <- checkDone (\k -> lift $ runIteratee $ k ch)
                                           step
                        consumeSomeOutput writeEnd step')
                   ech

    --------------------------------------------------------------------------
    -- Blocks on writeEnd, feeding each message to the step until the EOF
    -- written by the compressor thread has been passed on; then returns the
    -- resulting step. An exception from the compressor thread is re-raised
    -- with throwError.
    consumeRest :: Chan (Either SomeException (Stream ByteString))
                -> Step ByteString IO a
                -> Iteratee ByteString IO a
    consumeRest writeEnd step = do
        ech <- lift $ readChan writeEnd
        either throwError
               (\ch -> do
                   step' <- checkDone (\k -> lift $ runIteratee $ k ch) step
                   if (streamFinished ch)
                      then returnI step'
                      else consumeRest writeEnd step')
               ech

    --------------------------------------------------------------------------
    -- The step handed to the wrapped enumerator.
    --   Error:    propagated unchanged.
    --   Yield:    the downstream result is reported, with EOF as the
    --             leftover; the downstream's own leftover is dropped.
    --   Continue: on EOF, signal end of input (Nothing), drain all remaining
    --             compressed output and kill the compressor thread; on
    --             Chunks, forward each chunk to the compressor, deliver
    --             whatever output is already available, and keep going.
    f _ _ _ (Error e) = Error e
    f _ _ _ (Yield x _) = Yield x EOF
    f readEnd writeEnd tid st@(Continue k) = Continue $ \ch ->
        case ch of
          EOF -> do
            lift $ writeChan readEnd Nothing
            x <- consumeRest writeEnd st
            lift $ killThread tid
            return x

          (Chunks xs) -> do
            mapM_ (lift . writeChan readEnd . Just) xs
            step' <- consumeSomeOutput writeEnd (Continue k)
            returnI $ f readEnd writeEnd tid step'


    --------------------------------------------------------------------------
    -- Body of the compressor thread. It lazily reads readEnd with
    -- getChanContents (up to the first Nothing), runs compFunc over the
    -- result, and writes each output chunk to writeEnd followed by EOF. Any
    -- exception raised along the way is forwarded as a Left instead.
    threadProc :: Chan (Maybe ByteString)
               -> Chan (Either SomeException (Stream ByteString))
               -> IO ()
    threadProc readEnd writeEnd = do
        stream <- getChanContents readEnd

        let bs = L.fromChunks $ streamToChunks stream
        let output = L.toChunks $ compFunc bs

        runIt output `catch` \(e::SomeException) ->
            writeChan writeEnd $ Left e

      where
        -- writes each compressed chunk, then the end-of-stream marker
        runIt (x:xs) = do
            writeChan writeEnd (toChunk x) >> runIt xs

        runIt []     = do
            writeChan writeEnd $ Right EOF

    --------------------------------------------------------------------------
    -- Unwraps the Maybe-tagged channel contents into plain chunks, stopping
    -- at the first Nothing (the end-of-input marker).
    streamToChunks []            = []
    streamToChunks (Nothing:_)   = []
    streamToChunks ((Just x):xs) = x:(streamToChunks xs)


    --------------------------------------------------------------------------
    -- Wraps a single output chunk as a successful channel message.
    toChunk = Right . Chunks . (:[])


------------------------------------------------------------------------------
-- | Parses an @Accept-Encoding@ header value into the content-codings it
-- names, in the order the client sent them.
--
-- We don't bother with most quality values; choosing among the codings is
-- left to the caller. The one exception is a qvalue of zero, which the
-- client uses to mark a coding as explicitly "not acceptable" (RFC 2616
-- sections 3.9 and 14.3). Such codings are dropped from the result.
acceptParser :: Parser [ByteString]
acceptParser = do
    xs <- option [] $ (:[]) <$> encoding
    ys <- many (char ',' *> encoding)
    endOfInput
    return [ name | (name, acceptable) <- xs ++ ys, acceptable ]
  where
    -- One list element, with optional surrounding whitespace.
    encoding = skipSpace *> c <* skipSpace

    -- A coding and whether it is acceptable; a missing qvalue means q=1.
    c = do
        x  <- coding
        ok <- option True qvalue
        return (x, ok)

    -- ";q=<number>", returning False when the number is zero. The "q" is
    -- matched case-insensitively, as HTTP literals are case-insensitive.
    qvalue = do
        skipSpace
        _ <- char ';'
        skipSpace
        _ <- char 'q' <|> char 'Q'
        skipSpace
        _ <- char '='
        nonZeroFloat

    coding = string "*" <|> takeWhile isCodingChar

    isCodingChar ch = isDigit ch || isAlpha_ascii ch || ch == '-' || ch == '_'

    -- Accepts digits, optionally followed by '.' and more digits, and
    -- reports whether any of those digits is non-zero.
    nonZeroFloat = do
        intPart  <- nonZeroDigits
        fracPart <- option False (char '.' *> nonZeroDigits)
        return (intPart || fracPart)

    -- Consumes a run of digits; True iff at least one of them is not '0'.
    nonZeroDigits = do
        _ <- takeWhile (== '0')
        option False (True <$ takeWhile1 isDigit)


------------------------------------------------------------------------------
-- | Thrown by 'parseAcceptEncoding' when an @Accept-Encoding@ header
-- cannot be parsed.
data BadAcceptEncodingException = BadAcceptEncodingException
   deriving (Typeable)


------------------------------------------------------------------------------
instance Show BadAcceptEncodingException where
    showsPrec _ BadAcceptEncodingException =
        showString "bad 'accept-encoding' header"


------------------------------------------------------------------------------
instance Exception BadAcceptEncodingException


------------------------------------------------------------------------------
-- | Runs 'acceptParser' over an @Accept-Encoding@ header value, throwing
-- 'BadAcceptEncodingException' if the value cannot be fully parsed.
parseAcceptEncoding :: ByteString -> IO [ByteString]
parseAcceptEncoding s =
    either (const $ throwIO BadAcceptEncodingException) return $
        fullyParse s acceptParser