--------------------------------------------------------------------------------
{-# LANGUAGE OverloadedStrings #-}
{-# LANGUAGE RecordWildCards   #-}
{-# LANGUAGE TupleSections     #-}
module Network.WebSockets.Extensions.PermessageDeflate
    ( defaultPermessageDeflate
    , PermessageDeflate(..)
    , negotiateDeflate

      -- * Considered internal
    , makeMessageInflater
    , makeMessageDeflater
    ) where


--------------------------------------------------------------------------------
import           Control.Applicative                       ((<$>))
import           Control.Exception                         (throwIO)
import           Control.Monad                             (foldM, unless)
import qualified Data.ByteString                           as B
import qualified Data.ByteString.Char8                     as B8
import qualified Data.ByteString.Lazy                      as BL
import qualified Data.ByteString.Lazy.Char8                as BL8
import qualified Data.ByteString.Lazy.Internal             as BL
import           Data.Int                                  (Int64)
import           Data.Monoid
import qualified Data.Streaming.Zlib                       as Zlib
import           Network.WebSockets.Connection.Options
import           Network.WebSockets.Extensions
import           Network.WebSockets.Extensions.Description
import           Network.WebSockets.Http
import           Network.WebSockets.Types
import           Prelude
import           Text.Read                                 (readMaybe)


--------------------------------------------------------------------------------
-- | Render a 'PermessageDeflate' configuration as the 'ExtensionDescription'
-- that goes into the @Sec-WebSocket-Extensions@ header.
--
-- The two @*_no_context_takeover@ parameters are emitted (without a value)
-- only when the corresponding flag is set; the two @*_max_window_bits@
-- parameters are emitted only when they differ from 15.
toExtensionDescription :: PermessageDeflate -> ExtensionDescription
toExtensionDescription pmd = ExtensionDescription
    { extName   = "permessage-deflate"
    , extParams = concat
        [ flag       "server_no_context_takeover" (serverNoContextTakeover pmd)
        , flag       "client_no_context_takeover" (clientNoContextTakeover pmd)
        , windowBits "server_max_window_bits"     (serverMaxWindowBits pmd)
        , windowBits "client_max_window_bits"     (clientMaxWindowBits pmd)
        ]
    }
  where
    -- A valueless parameter that is present only when the flag is set.
    flag name enabled = [(name, Nothing) | enabled]

    -- A window size parameter, left out when it has the value 15.
    windowBits name bits = [(name, Just (B8.pack (show bits))) | bits /= 15]


--------------------------------------------------------------------------------
-- | Build the response headers announcing the given 'PermessageDeflate'
-- configuration: a single @Sec-WebSocket-Extensions@ header carrying the
-- encoded extension description.
toHeaders :: PermessageDeflate -> Headers
toHeaders pmd = [("Sec-WebSocket-Extensions", headerValue)]
  where
    headerValue = encodeExtensionDescriptions [toExtensionDescription pmd]


--------------------------------------------------------------------------------
-- | Negotiate message compression against the extensions offered by the peer.
--
-- The offered extensions are scanned in order and the first one named
-- @x-webkit-deflate-frame@ or @permessage-deflate@ is accepted, provided a
-- 'PermessageDeflate' configuration was supplied.  For @permessage-deflate@
-- the offered parameters are folded into the configuration with 'setParam',
-- and a malformed parameter makes the negotiation fail with 'Left'.  When
-- nothing matches (or no configuration was supplied) no headers are produced
-- and the resulting 'Extension' runs without a compression configuration.
negotiateDeflate
    :: SizeLimit -> Maybe PermessageDeflate -> NegotiateExtension
negotiateDeflate messageLimit pmd0 exts0 = do
    (headers, pmd1) <- selectOffer exts0 pmd0
    return Extension
        { extHeaders = headers
        , extParse   = inflatingParser pmd1
        , extWrite   = deflatingWriter pmd1
        }
  where
    -- Wrap the raw parser so that every parsed message goes through the
    -- inflater; an absent message ('Nothing') is passed along untouched.
    inflatingParser negotiated parseRaw = do
        inflate <- makeMessageInflater messageLimit negotiated
        return $ parseRaw >>= maybe (return Nothing) (fmap Just . inflate)

    -- Wrap the raw writer so that every outgoing message goes through the
    -- deflater before being written.
    deflatingWriter negotiated writeRaw = do
        deflate <- makeMessageDeflater negotiated
        return $ \msgs -> writeRaw =<< mapM deflate msgs

    -- Pick the first usable offer, yielding the response headers and the
    -- configuration that was agreed on.
    selectOffer
        :: ExtensionDescriptions
        -> Maybe PermessageDeflate
        -> Either String (Headers, Maybe PermessageDeflate)
    selectOffer (ext : rest) configured@(Just x)
        | name == "x-webkit-deflate-frame" = Right
            ([("Sec-WebSocket-Extensions", "x-webkit-deflate-frame")], Just x)
        | name == "permessage-deflate" = do
            x' <- foldM setParam x (extParams ext)
            Right (toHeaders x', Just x')
        | otherwise = selectOffer rest configured
      where
        name = extName ext
    selectOffer _ _ = Right ([], Nothing)


--------------------------------------------------------------------------------
-- | Apply a single offered extension parameter to a 'PermessageDeflate'
-- configuration.
--
-- The @*_no_context_takeover@ parameters switch the matching flag on,
-- whatever value they carry.  The @*_max_window_bits@ parameters set the
-- matching window size: 15 when no value is given, otherwise the value as
-- validated by 'parseWindow' (whose error is returned as 'Left').  Any other
-- parameter leaves the configuration unchanged.
setParam
    :: PermessageDeflate -> ExtensionParam -> Either String PermessageDeflate
setParam pmd (name, value) = case name of
    "server_no_context_takeover" ->
        Right pmd {serverNoContextTakeover = True}
    "client_no_context_takeover" ->
        Right pmd {clientNoContextTakeover = True}
    "server_max_window_bits"     ->
        (\w -> pmd {serverMaxWindowBits = w}) <$> windowBits
    "client_max_window_bits"     ->
        (\w -> pmd {clientMaxWindowBits = w}) <$> windowBits
    _                            ->
        Right pmd
  where
    -- The requested window size; a parameter without a value means 15.
    windowBits :: Either String Int
    windowBits = maybe (Right 15) parseWindow value


--------------------------------------------------------------------------------
-- | Parse a @*_max_window_bits@ value.
--
-- Succeeds with the number when the text reads as an 'Int' between 8 and 15
-- (inclusive); otherwise returns a 'Left' describing whether the value could
-- not be parsed or was out of bounds.
parseWindow :: B.ByteString -> Either String Int
parseWindow raw = maybe unparsable inBounds (readMaybe (B8.unpack raw))
  where
    unparsable :: Either String Int
    unparsable = Left $ "Can't parse window: " ++ show raw

    inBounds :: Int -> Either String Int
    inBounds w
        | w < 8 || w > 15 = Left $ "Window out of bounds: " ++ show w
        | otherwise       = Right w


--------------------------------------------------------------------------------
-- | Clamp a @window_bits@ value to the range 9..15: anything below 9 becomes
-- 9 and anything above 15 becomes 15.
--
-- In particular, a requested window of 8 is replaced by 9.
--
-- Related issues:
-- - https://github.com/haskell/zlib/issues/11
-- - https://github.com/madler/zlib/issues/94
--
-- Quote from zlib manual:
--
-- For the current implementation of deflate(), a windowBits value of 8 (a
-- window size of 256 bytes) is not supported. As a result, a request for 8 will
-- result in 9 (a 512-byte window). In that case, providing 8 to inflateInit2()
-- will result in an error when the zlib header with 9 is checked against the
-- initialization of inflate(). The remedy is to not use 8 with deflateInit2()
-- with this initialization, or at least in that case use 9 with inflateInit2().
fixWindowBits :: Int -> Int
fixWindowBits = max 9 . min 15


--------------------------------------------------------------------------------
-- | The four-byte trailer @00 00 ff ff@.  It is stripped from the end of
-- deflated payloads (see 'maybeStrip') and appended again before inflating.
appTailL :: BL.ByteString
appTailL = BL.pack trailer
  where
    trailer = [0x00, 0x00, 0xff, 0xff]


--------------------------------------------------------------------------------
-- | Drop the 'appTailL' trailer from the end of a payload when it is present;
-- payloads that do not end with it are returned unchanged.
maybeStrip :: BL.ByteString -> BL.ByteString
maybeStrip payload
    | appTailL `BL.isSuffixOf` payload = BL.take keep payload
    | otherwise                        = payload
  where
    -- Number of bytes that remain once the trailer is removed.
    keep = BL.length payload - BL.length appTailL


--------------------------------------------------------------------------------
-- | Message handler used when no compression was negotiated.
--
-- A data message with any of its three reserved bits set is refused by
-- throwing @CloseRequest 1002@; every other message is returned unchanged.
rejectExtensions :: Message -> IO Message
rejectExtensions msg = case msg of
    DataMessage rsv1 rsv2 rsv3 _
        | or [rsv1, rsv2, rsv3] -> throwIO $ CloseRequest 1002 "Protocol Error"
    _                           -> return msg


--------------------------------------------------------------------------------
-- | Build the function that compresses outgoing messages.
--
-- Without a 'PermessageDeflate' configuration the result is
-- 'rejectExtensions'.  Otherwise, text and binary data messages that have no
-- reserved bit set get their payload deflated and their first reserved bit
-- set; all other messages are returned unchanged.
--
-- When @serverNoContextTakeover@ is set, a fresh zlib compressor is created
-- for every message.  Otherwise one compressor is created up front and reused
-- for all messages.
makeMessageDeflater
    :: Maybe PermessageDeflate -> IO (Message -> IO Message)
makeMessageDeflater Nothing = return rejectExtensions
makeMessageDeflater (Just pmd)
    | serverNoContextTakeover pmd =
        return $ \msg -> newCompressor pmd >>= \z -> compressMessage z msg
    | otherwise =
        compressMessage <$> newCompressor pmd
  where
    ----------------------------------------------------------------------------
    -- Create a compressor with the configured compression level; the
    -- (corrected) server window size is handed to zlib negated.
    newCompressor :: PermessageDeflate -> IO Zlib.Deflate
    newCompressor PermessageDeflate {..} = Zlib.initDeflate
        pdCompressionLevel
        (Zlib.WindowBits (negate (fixWindowBits serverMaxWindowBits)))


    ----------------------------------------------------------------------------
    -- Compress the payload of an eligible data message and flag it as
    -- compressed; anything else passes through as-is.
    compressMessage :: Zlib.Deflate -> Message -> IO Message
    compressMessage z msg = case msg of
        DataMessage False False False (Text payload _) -> do
            payload' <- compressBody z payload
            return (DataMessage True False False (Text payload' Nothing))
        DataMessage False False False (Binary payload) -> do
            payload' <- compressBody z payload
            return (DataMessage True False False (Binary payload'))
        _ -> return msg


    ----------------------------------------------------------------------------
    -- Feed the payload to zlib chunk by chunk, flush, and strip the trailer
    -- from the concatenated output.
    compressBody :: Zlib.Deflate -> BL.ByteString -> IO BL.ByteString
    compressBody z payload = do
        pieces  <- mapM feed (BL.toChunks payload)
        flushed <- dePopper (Zlib.flushDeflate z)
        return (maybeStrip (mconcat pieces <> flushed))
      where
        feed c = Zlib.feedDeflate z c >>= dePopper


--------------------------------------------------------------------------------
-- | Run a zlib 'Zlib.Popper' until it reports that it is done, collecting
-- every chunk it yields into one lazy 'BL.ByteString'.
--
-- A zlib error is turned into a thrown @CloseRequest 1002@ whose reason is the
-- shown error.
dePopper :: Zlib.Popper -> IO BL.ByteString
dePopper popper = do
    res <- popper
    case res of
        Zlib.PRNext c  -> BL.chunk c <$> dePopper popper
        Zlib.PRDone    -> return BL.empty
        Zlib.PRError e -> throwIO $ CloseRequest 1002 (BL8.pack (show e))


--------------------------------------------------------------------------------
-- | Build the function that decompresses incoming messages.
--
-- Without a 'PermessageDeflate' configuration the result is
-- 'rejectExtensions'.  Otherwise, text and binary data messages that have
-- their first reserved bit set get their payload inflated and that bit
-- cleared; all other messages are returned unchanged.
--
-- When @clientNoContextTakeover@ is set, a fresh zlib inflater is created for
-- every message.  Otherwise one inflater is created up front and reused for
-- all messages.
--
-- The inflated size is checked against the given 'SizeLimit' after every
-- chunk of output that zlib yields, so an oversized message is refused with a
-- 'ParseException' as soon as the limit is crossed rather than after a whole
-- input chunk has been expanded in memory.  zlib errors are reported by
-- throwing @CloseRequest 1002@.
makeMessageInflater
    :: SizeLimit -> Maybe PermessageDeflate
    -> IO (Message -> IO Message)
makeMessageInflater _ Nothing = return rejectExtensions
makeMessageInflater messageLimit (Just pmd)
    | clientNoContextTakeover pmd =
        return $ \msg -> do
            ptr <- initInflate pmd
            inflateMessageWith (inflateBody ptr) msg
    | otherwise = do
        ptr <- initInflate pmd
        return $ \msg ->
            inflateMessageWith (inflateBody ptr) msg
  where
    ----------------------------------------------------------------------------
    -- Create an inflater; the (corrected) client window size is handed to
    -- zlib negated.
    initInflate :: PermessageDeflate -> IO Zlib.Inflate
    initInflate PermessageDeflate {..} =
        Zlib.initInflate
            (Zlib.WindowBits (- (fixWindowBits clientMaxWindowBits)))


    ----------------------------------------------------------------------------
    -- Inflate the payload of a compressed data message and clear its first
    -- reserved bit; the other two reserved bits are passed through.
    inflateMessageWith
        :: (BL.ByteString -> IO BL.ByteString)
        -> Message -> IO Message
    inflateMessageWith inflater (DataMessage True a b (Text x _)) = do
        x' <- inflater x
        return (DataMessage False a b (Text x' Nothing))
    inflateMessageWith inflater (DataMessage True a b (Binary x)) = do
        x' <- inflater x
        return (DataMessage False a b (Binary x'))
    inflateMessageWith _ x = return x


    ----------------------------------------------------------------------------
    -- Inflate a payload.  The trailer that the sender stripped ('appTailL') is
    -- appended again before the data is fed to zlib chunk by chunk.
    inflateBody :: Zlib.Inflate -> BL.ByteString -> IO BL.ByteString
    inflateBody ptr =
        go 0 . BL.toChunks . (<> appTailL)
      where
        -- The first argument is the number of inflated bytes produced so far.
        go :: Int64 -> [B.ByteString] -> IO BL.ByteString
        go size0 []       = do
            chunk <- Zlib.flushInflate ptr
            checkSize (fromIntegral (B.length chunk) + size0)
            return (BL.fromStrict chunk)
        go size0 (c : cs) = do
            popper         <- Zlib.feedInflate ptr c
            (size1, chunk) <- drain size0 popper
            (chunk <>) <$> go size1 cs

        -- Pull all output for the chunk that was just fed, enforcing the size
        -- limit after each piece of output instead of only once at the end.
        -- This keeps a small but highly compressible input chunk from being
        -- expanded far beyond the limit before the check runs.
        drain :: Int64 -> Zlib.Popper -> IO (Int64, BL.ByteString)
        drain size0 popper = popper >>= \res -> case res of
            Zlib.PRDone    -> return (size0, BL.empty)
            Zlib.PRNext c  -> do
                let size1 = size0 + fromIntegral (B.length c)
                checkSize size1
                (size2, rest) <- drain size1 popper
                return (size2, BL.chunk c rest)
            Zlib.PRError x -> throwIO $ CloseRequest 1002 (BL8.pack (show x))


    ----------------------------------------------------------------------------
    -- Throw a 'ParseException' when the given inflated size is over the limit.
    checkSize :: Int64 -> IO ()
    checkSize size = unless (atMostSizeLimit size messageLimit) $ throwIO $
        ParseException $ "Message of size " ++ show size ++ " exceeded limit"