--------------------------------------------------------------------------------
{-# LANGUAGE OverloadedStrings #-}
{-# LANGUAGE RecordWildCards   #-}
{-# LANGUAGE TupleSections     #-}
module Network.WebSockets.Extensions.PermessageDeflate
    ( defaultPermessageDeflate
    , PermessageDeflate(..)
    , negotiateDeflate

      -- * Considered internal
    , makeMessageInflater
    , makeMessageDeflater
    ) where


--------------------------------------------------------------------------------
import           Control.Applicative                       ((<$>))
import           Control.Exception                         (throwIO)
import           Control.Monad                             (foldM, unless)
import qualified Data.ByteString                           as B
import qualified Data.ByteString.Char8                     as B8
import qualified Data.ByteString.Lazy                      as BL
import qualified Data.ByteString.Lazy.Char8                as BL8
import qualified Data.ByteString.Lazy.Internal             as BL
import           Data.Int                                  (Int64)
import           Data.Monoid
import qualified Data.Streaming.Zlib                       as Zlib
import           Network.WebSockets.Connection.Options
import           Network.WebSockets.Extensions
import           Network.WebSockets.Extensions.Description
import           Network.WebSockets.Http
import           Network.WebSockets.Types
import           Prelude
import           Text.Read                                 (readMaybe)


--------------------------------------------------------------------------------
-- | Convert the parameters to an 'ExtensionDescription' that we can put in a
-- 'Sec-WebSocket-Extensions' header.
--
-- The two no-context-takeover flags are emitted (without a value) only when
-- they are set, and each @*_max_window_bits@ parameter only when it differs
-- from 15. An all-default configuration therefore encodes as a bare
-- @permessage-deflate@.
toExtensionDescription :: PermessageDeflate -> ExtensionDescription
toExtensionDescription pmd = ExtensionDescription
    { extName   = "permessage-deflate"
    , extParams = concat
        [ flag "server_no_context_takeover" (serverNoContextTakeover pmd)
        , flag "client_no_context_takeover" (clientNoContextTakeover pmd)
        , bits "server_max_window_bits"     (serverMaxWindowBits pmd)
        , bits "client_max_window_bits"     (clientMaxWindowBits pmd)
        ]
    }
  where
    -- A value-less parameter, present only when the flag is on.
    flag :: B.ByteString -> Bool -> [(B.ByteString, Maybe B.ByteString)]
    flag key enabled = [(key, Nothing) | enabled]

    -- A decimal window-bits parameter, omitted when it is the default 15.
    bits :: B.ByteString -> Int -> [(B.ByteString, Maybe B.ByteString)]
    bits key n
        | n == 15   = []
        | otherwise = [(key, Just (B8.pack (show n)))]


--------------------------------------------------------------------------------
-- | The single response header announcing the negotiated permessage-deflate
-- parameters, encoded via 'toExtensionDescription'.
toHeaders :: PermessageDeflate -> Headers
toHeaders pmd = [(headerName, headerValue)]
  where
    headerName  = "Sec-WebSocket-Extensions"
    headerValue = encodeExtensionDescriptions [toExtensionDescription pmd]


--------------------------------------------------------------------------------
-- | Negotiate permessage-deflate against the extensions offered by a client.
--
-- Offers are scanned in order and the first recognised one
-- (@x-webkit-deflate-frame@ or @permessage-deflate@) wins. A
-- @permessage-deflate@ offer has its parameters folded into our settings
-- with 'setParam'; a malformed parameter makes negotiation return 'Left'.
-- When compression is disabled locally ('Nothing') or nothing recognised is
-- offered, no header is added and the hooks are built with 'Nothing'.
--
-- The resulting 'Extension' inflates every parsed message and deflates every
-- written message using fresh (de)compressor state per connection.
negotiateDeflate
    :: SizeLimit -> Maybe PermessageDeflate -> NegotiateExtension
negotiateDeflate messageLimit pmd0 exts0 = do
    (headers, pmd1) <- pickOffer exts0 pmd0
    return Extension
        { extHeaders = headers
        , extParse   = \parseRaw -> do
            inflate <- makeMessageInflater messageLimit pmd1
            -- Absent messages stay absent; present ones are inflated.
            return (parseRaw >>= maybe (return Nothing) (fmap Just . inflate))
        , extWrite   = \writeRaw -> do
            deflate <- makeMessageDeflater pmd1
            return (\msgs -> mapM deflate msgs >>= writeRaw)
        }
  where
    -- Pick the first supported offer and compute the response headers plus
    -- the effective settings for it.
    pickOffer
        :: ExtensionDescriptions
        -> Maybe PermessageDeflate
        -> Either String (Headers, Maybe PermessageDeflate)
    pickOffer _  Nothing = Right ([], Nothing)
    pickOffer [] _       = Right ([], Nothing)
    pickOffer (ext : rest) (Just cfg) = case extName ext of
        "x-webkit-deflate-frame" ->
            Right ( [("Sec-WebSocket-Extensions", "x-webkit-deflate-frame")]
                  , Just cfg
                  )
        "permessage-deflate" -> do
            cfg' <- foldM setParam cfg (extParams ext)
            Right (toHeaders cfg', Just cfg')
        _ -> pickOffer rest (Just cfg)


--------------------------------------------------------------------------------
-- | Apply one parameter from a client's permessage-deflate offer to our
-- settings.
--
-- The no-context-takeover flags are switched on whatever their value. A
-- window-bits parameter without a value means 15; with a value it is parsed
-- by 'parseWindow', whose 'Left' is propagated. Unrecognised parameters are
-- ignored.
setParam
    :: PermessageDeflate -> ExtensionParam -> Either String PermessageDeflate
setParam pmd (name, value) = case name of
    "server_no_context_takeover" -> Right pmd {serverNoContextTakeover = True}
    "client_no_context_takeover" -> Right pmd {clientNoContextTakeover = True}
    "server_max_window_bits"     ->
        (\w -> pmd {serverMaxWindowBits = w}) <$> requestedWindow
    "client_max_window_bits"     ->
        (\w -> pmd {clientMaxWindowBits = w}) <$> requestedWindow
    _                            -> Right pmd
  where
    requestedWindow = maybe (Right 15) parseWindow value


--------------------------------------------------------------------------------
-- | Parse the value of a @*_max_window_bits@ parameter.
--
-- Only values in the range [8, 15] are accepted. Anything else yields a
-- 'Left' describing the problem.
--
-- The text is read as an unbounded 'Integer' and range-checked /before/
-- being narrowed to 'Int'. Reading straight into 'Int' silently wraps on
-- overflow, so a peer-supplied value such as @18446744073709551626@
-- (2^64 + 10) would otherwise be accepted as a 10-bit window on 64-bit hosts.
parseWindow :: B.ByteString -> Either String Int
parseWindow bs8 = case (readMaybe (B8.unpack bs8) :: Maybe Integer) of
    Just w
        | w >= 8 && w <= 15 -> Right (fromInteger w)
        | otherwise         -> Left $ "Window out of bounds: " ++ show w
    Nothing -> Left $ "Can't parse window: " ++ show bs8


--------------------------------------------------------------------------------
-- | Clamp a window-bits value into [9, 15]: anything below 9 (in practice an
-- 8 accepted during negotiation) becomes 9, and anything above 15 becomes 15.
--
-- The lower bound exists because zlib cannot really use 8.
--
-- Related issues:
-- - https://github.com/haskell/zlib/issues/11
-- - https://github.com/madler/zlib/issues/94
--
-- Quote from zlib manual:
--
-- For the current implementation of deflate(), a windowBits value of 8 (a
-- window size of 256 bytes) is not supported. As a result, a request for 8 will
-- result in 9 (a 512-byte window). In that case, providing 8 to inflateInit2()
-- will result in an error when the zlib header with 9 is checked against the
-- initialization of inflate(). The remedy is to not use 8 with deflateInit2()
-- with this initialization, or at least in that case use 9 with inflateInit2().
fixWindowBits :: Int -> Int
fixWindowBits = max 9 . min 15


--------------------------------------------------------------------------------
-- | The four trailer bytes @0x00 0x00 0xff 0xff@.
--
-- Outgoing compressed payloads have them stripped ('maybeStrip'). Incoming
-- payloads get them appended again before inflation, as described in
-- RFC 7692 section 7.2.
appTailL :: BL.ByteString
appTailL = BL.replicate 2 0x00 `BL.append` BL.replicate 2 0xff


--------------------------------------------------------------------------------
-- | Remove the 4-byte 'appTailL' trailer from the end of a compressed payload
-- if it is there; otherwise return the payload unchanged.
maybeStrip :: BL.ByteString -> BL.ByteString
maybeStrip body =
    if appTailL `BL.isSuffixOf` body
        then BL.take (BL.length body - 4) body
        else body


--------------------------------------------------------------------------------
-- | Used when no compression was negotiated.
--
-- A data message with any of RSV1/RSV2/RSV3 set fails with
-- @CloseRequest 1002 "Protocol Error"@. Every other message is returned
-- unchanged.
rejectExtensions :: Message -> IO Message
rejectExtensions msg = case msg of
    DataMessage r1 r2 r3 _
        | or [r1, r2, r3] ->
            throwIO $ CloseRequest 1002 "Protocol Error"
    _ -> return msg


--------------------------------------------------------------------------------
-- | Build the function that compresses outgoing messages.
--
-- With no negotiated settings this is just 'rejectExtensions'. Otherwise,
-- data messages with all RSV bits clear are deflated and marked with RSV1,
-- and everything else passes through untouched. When
-- @serverNoContextTakeover@ is set, a new compressor is created for every
-- message. Otherwise a single compressor is created up front and reused for
-- all messages.
makeMessageDeflater
    :: Maybe PermessageDeflate -> IO (Message -> IO Message)
makeMessageDeflater Nothing = return rejectExtensions
makeMessageDeflater (Just pmd)
    | serverNoContextTakeover pmd =
        return $ \msg -> initDeflate pmd >>= \ptr -> compressWith ptr msg
    | otherwise =
        compressWith <$> initDeflate pmd
  where
    ----------------------------------------------------------------------------
    compressWith :: Zlib.Deflate -> Message -> IO Message
    compressWith ptr = deflateMessageWith (deflateBody ptr)

    ----------------------------------------------------------------------------
    -- A negative window-bits value asks zlib for a raw deflate stream with no
    -- zlib header (deflateInit2 convention).
    initDeflate :: PermessageDeflate -> IO Zlib.Deflate
    initDeflate cfg =
        Zlib.initDeflate
            (pdCompressionLevel cfg)
            (Zlib.WindowBits (negate (fixWindowBits (serverMaxWindowBits cfg))))

    ----------------------------------------------------------------------------
    -- Compress the payload of an uncompressed data message and set RSV1. A
    -- text message's cached decoded text no longer matches the payload, so it
    -- is dropped.
    deflateMessageWith
        :: (BL.ByteString -> IO BL.ByteString)
        -> Message -> IO Message
    deflateMessageWith deflater msg = case msg of
        DataMessage False False False (Text x _) -> do
            x' <- deflater x
            return (DataMessage True False False (Text x' Nothing))
        DataMessage False False False (Binary x) ->
            DataMessage True False False . Binary <$> deflater x
        _ -> return msg

    ----------------------------------------------------------------------------
    -- Feed every chunk, flush, then strip the trailing 00 00 ff ff.
    deflateBody :: Zlib.Deflate -> BL.ByteString -> IO BL.ByteString
    deflateBody ptr body = maybeStrip <$> feedAll (BL.toChunks body)
      where
        feedAll :: [B.ByteString] -> IO BL.ByteString
        feedAll [] = dePopper (Zlib.flushDeflate ptr)
        feedAll (c : cs) = do
            out  <- dePopper =<< Zlib.feedDeflate ptr c
            rest <- feedAll cs
            return (out <> rest)


--------------------------------------------------------------------------------
-- | Pop every pending output chunk from a zlib 'Zlib.Popper' and join them
-- into a lazy bytestring.
--
-- A zlib error fails the connection with a 'CloseRequest' of code 1002 whose
-- reason is the 'show'n error.
dePopper :: Zlib.Popper -> IO BL.ByteString
dePopper p = do
    res <- p
    case res of
        Zlib.PRDone    -> return BL.empty
        Zlib.PRNext c  -> do
            rest <- dePopper p
            return (BL.chunk c rest)
        Zlib.PRError x ->
            throwIO $ CloseRequest 1002 (BL8.pack (show x))


--------------------------------------------------------------------------------
-- | Build the function that decompresses incoming messages.
--
-- With no negotiated settings this is 'rejectExtensions'. Otherwise, data
-- messages with RSV1 set are inflated and RSV1 is cleared (RSV2/RSV3 are
-- preserved), and everything else passes through untouched. When
-- @clientNoContextTakeover@ is set, a new decompressor is created for every
-- message. Otherwise a single decompressor is created up front and shared by
-- all messages.
--
-- The decompressed size of each message is checked against @messageLimit@
-- while output is being produced. Exceeding the limit throws a
-- 'ParseException'. zlib errors fail the connection with a 'CloseRequest'
-- of code 1002.
makeMessageInflater
    :: SizeLimit -> Maybe PermessageDeflate
    -> IO (Message -> IO Message)
makeMessageInflater _ Nothing = return rejectExtensions
makeMessageInflater messageLimit (Just pmd)
    | clientNoContextTakeover pmd =
        return $ \msg -> do
            ptr <- initInflate pmd
            inflateMessageWith (inflateBody ptr) msg
    | otherwise = do
        ptr <- initInflate pmd
        return $ \msg ->
            inflateMessageWith (inflateBody ptr) msg
  where
    ----------------------------------------------------------------------------
    -- A negative window-bits value asks zlib for a raw deflate stream with no
    -- zlib header (inflateInit2 convention).
    initInflate :: PermessageDeflate -> IO Zlib.Inflate
    initInflate PermessageDeflate {..} =
        Zlib.initInflate
            (Zlib.WindowBits (- (fixWindowBits clientMaxWindowBits)))

    ----------------------------------------------------------------------------
    -- Inflate the payload of an RSV1-marked data message and clear RSV1. A
    -- text message's cached decoded text is dropped because it described the
    -- compressed payload.
    inflateMessageWith
        :: (BL.ByteString -> IO BL.ByteString)
        -> Message -> IO Message
    inflateMessageWith inflater (DataMessage True a b (Text x _)) = do
        x' <- inflater x
        return (DataMessage False a b (Text x' Nothing))
    inflateMessageWith inflater (DataMessage True a b (Binary x)) = do
        x' <- inflater x
        return (DataMessage False a b (Binary x'))
    inflateMessageWith _ x = return x

    ----------------------------------------------------------------------------
    -- Re-append the 00 00 ff ff trailer the sender stripped, then inflate
    -- chunk by chunk while tracking the running decompressed size.
    inflateBody :: Zlib.Inflate -> BL.ByteString -> IO BL.ByteString
    inflateBody ptr =
        go 0 . BL.toChunks . (<> appTailL)
      where
        go :: Int64 -> [B.ByteString] -> IO BL.ByteString
        go size0 [] = do
            chunk <- Zlib.flushInflate ptr
            checkSize (fromIntegral (B.length chunk) + size0)
            return (BL.fromStrict chunk)
        go size0 (c : cs) = do
            (size1, chunk) <- drain size0 =<< Zlib.feedInflate ptr c
            (chunk <>) <$> go size1 cs

        -- Pop output one piece at a time and check the running total after
        -- each piece. Checking only after a whole input chunk is drained
        -- would let a small, highly compressible chunk expand fully in memory
        -- before the limit is enforced. zlib errors are reported as a
        -- CloseRequest 1002, matching 'dePopper'.
        drain :: Int64 -> Zlib.Popper -> IO (Int64, BL.ByteString)
        drain size popper = popper >>= \res -> case res of
            Zlib.PRDone    -> return (size, BL.empty)
            Zlib.PRNext c  -> do
                let size' = size + fromIntegral (B.length c)
                checkSize size'
                fmap (BL.chunk c) <$> drain size' popper
            Zlib.PRError x ->
                throwIO $ CloseRequest 1002 (BL8.pack (show x))

    ----------------------------------------------------------------------------
    checkSize :: Int64 -> IO ()
    checkSize size = unless (atMostSizeLimit size messageLimit) $
        throwIO $ ParseException $
            "Message of size " ++ show size ++ " exceeded limit"