--------------------------------------------------------------------------------
{-# LANGUAGE OverloadedStrings #-}
{-# LANGUAGE RecordWildCards   #-}
{-# LANGUAGE TupleSections     #-}
module Network.WebSockets.Extensions.PermessageDeflate
    ( defaultPermessageDeflate
    , PermessageDeflate(..)
    , negotiateDeflate

      -- * Considered internal
    , makeMessageInflater
    , makeMessageDeflater
    ) where


--------------------------------------------------------------------------------
import           Control.Applicative                       ((<$>))
import           Control.Exception                         (throwIO)
import           Control.Monad                             (foldM, unless)
import qualified Data.ByteString                           as B
import qualified Data.ByteString.Char8                     as B8
import qualified Data.ByteString.Lazy                      as BL
import qualified Data.ByteString.Lazy.Char8                as BL8
import qualified Data.ByteString.Lazy.Internal             as BL
import           Data.Int                                  (Int64)
import           Data.Monoid
import qualified Data.Streaming.Zlib                       as Zlib
import           Network.WebSockets.Connection.Options
import           Network.WebSockets.Extensions
import           Network.WebSockets.Extensions.Description
import           Network.WebSockets.Http
import           Network.WebSockets.Types
import           Prelude
import           Text.Read                                 (readMaybe)


--------------------------------------------------------------------------------
-- | Convert the parameters to an 'ExtensionDescription' that we can put in a
-- 'Sec-WebSocket-Extensions' header.
--
-- The boolean flags are emitted as value-less parameters only when set.
-- Window sizes are emitted only when they differ from 15.
toExtensionDescription :: PermessageDeflate -> ExtensionDescription
toExtensionDescription PermessageDeflate {..} = ExtensionDescription
    { extName   = "permessage-deflate"
    , extParams =
         [("server_no_context_takeover", Nothing) | serverNoContextTakeover] ++
         [("client_no_context_takeover", Nothing) | clientNoContextTakeover] ++
         [("server_max_window_bits", param serverMaxWindowBits) | serverMaxWindowBits /= 15] ++
         [("client_max_window_bits", param clientMaxWindowBits) | clientMaxWindowBits /= 15]
    }
  where
    param = Just . B8.pack . show


--------------------------------------------------------------------------------
-- | The single @Sec-WebSocket-Extensions@ response header announcing the
-- accepted permessage-deflate configuration.
toHeaders :: PermessageDeflate -> Headers
toHeaders pmd =
    [ ( "Sec-WebSocket-Extensions"
      , encodeExtensionDescriptions [toExtensionDescription pmd]
      )
    ]


--------------------------------------------------------------------------------
-- | Negotiate compression against the extensions offered by the client.
--
-- The offers are scanned in order, and the first @x-webkit-deflate-frame@ or
-- @permessage-deflate@ entry is accepted:
--
-- * @x-webkit-deflate-frame@ keeps the server's settings unchanged.
-- * @permessage-deflate@ has its parameters folded into the server's
--   settings with 'setParam'.
--
-- If the server has compression disabled ('Nothing'), or nothing matching is
-- offered, no headers are produced and messages are neither compressed nor
-- decompressed.
--
-- A malformed parameter value on the chosen offer makes the whole
-- negotiation return 'Left'.
negotiateDeflate
    :: SizeLimit -> Maybe PermessageDeflate -> NegotiateExtension
negotiateDeflate messageLimit pmd0 exts0 = do
    (headers, pmd1) <- negotiateDeflateOpts exts0 pmd0
    return Extension
        { extHeaders = headers
        , extParse   = \parseRaw -> do
            inflate <- makeMessageInflater messageLimit pmd1
            return $ do
                msg <- parseRaw
                case msg of
                    Nothing -> return Nothing
                    Just m  -> fmap Just (inflate m)
        , extWrite   = \writeRaw -> do
            deflate <- makeMessageDeflater pmd1
            return $ \msgs ->
                mapM deflate msgs >>= writeRaw
        }
  where
    negotiateDeflateOpts
        :: ExtensionDescriptions
        -> Maybe PermessageDeflate
        -> Either String (Headers, Maybe PermessageDeflate)

    negotiateDeflateOpts (ext : _) (Just x)
        | extName ext == "x-webkit-deflate-frame" = Right
            ([("Sec-WebSocket-Extensions", "x-webkit-deflate-frame")], Just x)

    negotiateDeflateOpts (ext : _) (Just x)
        | extName ext == "permessage-deflate" = do
            x' <- foldM setParam x (extParams ext)
            Right (toHeaders x', Just x')

    negotiateDeflateOpts (_ : exts) (Just x) =
        negotiateDeflateOpts exts (Just x)

    negotiateDeflateOpts _ _ = Right ([], Nothing)


--------------------------------------------------------------------------------
-- | Apply one permessage-deflate offer parameter to the server's settings.
--
-- * A @*_max_window_bits@ parameter without a value resets that window to 15.
-- * A @*_max_window_bits@ parameter with a value is validated by
--   'parseWindow'.
-- * Unrecognised parameters are ignored.
--
-- NOTE(review): RFC 7692 §5 says such offers should be declined instead.
setParam
    :: PermessageDeflate -> ExtensionParam -> Either String PermessageDeflate
setParam pmd ("server_no_context_takeover", _) =
    Right pmd {serverNoContextTakeover = True}
setParam pmd ("client_no_context_takeover", _) =
    Right pmd {clientNoContextTakeover = True}
setParam pmd ("server_max_window_bits", Nothing) =
    Right pmd {serverMaxWindowBits = 15}
setParam pmd ("server_max_window_bits", Just param) = do
    w <- parseWindow param
    Right pmd {serverMaxWindowBits = w}
setParam pmd ("client_max_window_bits", Nothing) =
    Right pmd {clientMaxWindowBits = 15}
setParam pmd ("client_max_window_bits", Just param) = do
    w <- parseWindow param
    Right pmd {clientMaxWindowBits = w}
setParam pmd (_, _) = Right pmd


--------------------------------------------------------------------------------
-- | Parse a @*_max_window_bits@ value.
--
-- RFC 7692 requires a decimal integer without leading zeroes, between 8 and
-- 15 inclusive.
--
-- Anything that is not plain ASCII digits is rejected as unparseable. That
-- covers @0xA@, @010@, surrounding whitespace and parentheses, all of which
-- 'read' would otherwise accept.
--
-- The digits are parsed as an 'Integer' so they cannot overflow. An
-- oversized value such as @18446744073709551626@ is therefore reported as out
-- of bounds rather than wrapping around to a valid-looking 'Int'.
parseWindow :: B.ByteString -> Either String Int
parseWindow bs8 = case readMaybe (B8.unpack bs8) :: Maybe Integer of
    Just w
        | not (isDecimal bs8) -> cantParse
        | w >= 8 && w <= 15   -> Right (fromInteger w)
        | otherwise           -> Left $ "Window out of bounds: " ++ show w
    Nothing                   -> cantParse
  where
    cantParse = Left $ "Can't parse window: " ++ show bs8

    -- Non-empty, ASCII digits only, no leading zero.
    isDecimal s = case B8.uncons s of
        Just (c, _) -> c /= '0' && B8.all (\d -> d >= '0' && d <= '9') s
        Nothing     -> False


--------------------------------------------------------------------------------
-- | If the window_bits parameter is set to 8, we must set it to 9 instead.
--
-- Related issues:
-- - https://github.com/haskell/zlib/issues/11
-- - https://github.com/madler/zlib/issues/94
--
-- Quote from zlib manual:
--
-- For the current implementation of deflate(), a windowBits value of 8 (a
-- window size of 256 bytes) is not supported. As a result, a request for 8 will
-- result in 9 (a 512-byte window). In that case, providing 8 to inflateInit2()
-- will result in an error when the zlib header with 9 is checked against the
-- initialization of inflate(). The remedy is to not use 8 with deflateInit2()
-- with this initialization, or at least in that case use 9 with inflateInit2().
fixWindowBits :: Int -> Int
fixWindowBits n
    | n < 9     = 9
    | n > 15    = 15
    | otherwise = n


--------------------------------------------------------------------------------
-- | The four bytes @00 00 ff ff@.
--
-- 'maybeStrip' removes them from deflated output, and the inflater appends
-- them back before decompressing.
appTailL :: BL.ByteString
appTailL = BL.pack [0x00,0x00,0xff,0xff]


--------------------------------------------------------------------------------
-- | Drop a trailing 'appTailL' if the input ends with it; otherwise return the
-- input unchanged.
maybeStrip :: BL.ByteString -> BL.ByteString
maybeStrip x
    | appTailL `BL.isSuffixOf` x = BL.take (BL.length x - 4) x
maybeStrip x = x


--------------------------------------------------------------------------------
-- | Message transformer used when compression was not negotiated.
--
-- A data message with any reserved bit set is answered with a
-- @CloseRequest 1002@. Every other message is returned unchanged.
rejectExtensions :: Message -> IO Message
rejectExtensions (DataMessage rsv1 rsv2 rsv3 _) | rsv1 || rsv2 || rsv3 =
    throwIO $ CloseRequest 1002 "Protocol Error"
rejectExtensions x = return x


--------------------------------------------------------------------------------
-- | Build the transformer applied to outgoing messages.
--
-- With 'Nothing' this is 'rejectExtensions'. Otherwise, text and binary
-- messages with no reserved bits set are deflated and re-emitted with RSV1
-- set. All other messages pass through untouched.
--
-- The deflate stream is created in one of two ways:
--
-- * When 'serverNoContextTakeover' is set, a fresh stream is created for
--   every message.
-- * Otherwise, a single stream is created up front and shared by every call
--   of the returned function.
makeMessageDeflater
    :: Maybe PermessageDeflate -> IO (Message -> IO Message)
makeMessageDeflater Nothing = return rejectExtensions
makeMessageDeflater (Just pmd)
    | serverNoContextTakeover pmd =
        return $ \msg -> do
            ptr <- initDeflate pmd
            deflateMessageWith (deflateBody ptr) msg
    | otherwise = do
        ptr <- initDeflate pmd
        return $ \msg ->
            deflateMessageWith (deflateBody ptr) msg
  where
    ----------------------------------------------------------------------------
    -- Negative window bits select a raw deflate stream (no zlib header).
    initDeflate :: PermessageDeflate -> IO Zlib.Deflate
    initDeflate PermessageDeflate {..} =
        Zlib.initDeflate
            pdCompressionLevel
            (Zlib.WindowBits (- (fixWindowBits serverMaxWindowBits)))

    ----------------------------------------------------------------------------
    deflateMessageWith
        :: (BL.ByteString -> IO BL.ByteString)
        -> Message -> IO Message
    deflateMessageWith deflater (DataMessage False False False (Text x _)) = do
        x' <- deflater x
        return (DataMessage True False False (Text x' Nothing))
    deflateMessageWith deflater (DataMessage False False False (Binary x)) = do
        x' <- deflater x
        return (DataMessage True False False (Binary x'))
    deflateMessageWith _ x = return x

    ----------------------------------------------------------------------------
    -- Feed every chunk, flush, then drop the trailing 'appTailL'.
    deflateBody :: Zlib.Deflate -> BL.ByteString -> IO BL.ByteString
    deflateBody ptr = fmap maybeStrip . go . BL.toChunks
      where
        go [] =
            dePopper (Zlib.flushDeflate ptr)
        go (c : cs) = do
            chunk <- Zlib.feedDeflate ptr c >>= dePopper
            (chunk <>) <$> go cs


--------------------------------------------------------------------------------
-- | Run a popper until it reports 'Zlib.PRDone', collecting its output.
--
-- A zlib error is turned into a @CloseRequest 1002@ carrying the error text.
dePopper :: Zlib.Popper -> IO BL.ByteString
dePopper p = p >>= \res -> case res of
    Zlib.PRDone    -> return BL.empty
    Zlib.PRNext c  -> BL.chunk c <$> dePopper p
    Zlib.PRError x -> throwIO $ CloseRequest 1002 (BL8.pack (show x))


--------------------------------------------------------------------------------
-- | Build the transformer applied to incoming messages.
--
-- With 'Nothing' this is 'rejectExtensions'. Otherwise, text and binary
-- messages with RSV1 set are inflated and returned with RSV1 cleared. All
-- other messages pass through untouched.
--
-- The inflate stream is created in one of two ways:
--
-- * When 'clientNoContextTakeover' is set, a fresh stream is created for
--   every message.
-- * Otherwise, a single stream is shared by every call of the returned
--   function.
--
-- The decompressed size is checked against @messageLimit@ after every piece
-- of output zlib produces. If the limit is exceeded, a 'ParseException' is
-- thrown.
makeMessageInflater
    :: SizeLimit -> Maybe PermessageDeflate
    -> IO (Message -> IO Message)
makeMessageInflater _ Nothing = return rejectExtensions
makeMessageInflater messageLimit (Just pmd)
    | clientNoContextTakeover pmd =
        return $ \msg -> do
            ptr <- initInflate pmd
            inflateMessageWith (inflateBody ptr) msg
    | otherwise = do
        ptr <- initInflate pmd
        return $ \msg ->
            inflateMessageWith (inflateBody ptr) msg
  where
    ----------------------------------------------------------------------------
    -- Negative window bits select a raw inflate stream (no zlib header).
    initInflate :: PermessageDeflate -> IO Zlib.Inflate
    initInflate PermessageDeflate {..} =
        Zlib.initInflate
            (Zlib.WindowBits (- (fixWindowBits clientMaxWindowBits)))

    ----------------------------------------------------------------------------
    inflateMessageWith
        :: (BL.ByteString -> IO BL.ByteString)
        -> Message -> IO Message
    inflateMessageWith inflater (DataMessage True a b (Text x _)) = do
        x' <- inflater x
        return (DataMessage False a b (Text x' Nothing))
    inflateMessageWith inflater (DataMessage True a b (Binary x)) = do
        x' <- inflater x
        return (DataMessage False a b (Binary x'))
    inflateMessageWith _ x = return x

    ----------------------------------------------------------------------------
    -- Re-append 'appTailL', then inflate chunk by chunk.
    inflateBody :: Zlib.Inflate -> BL.ByteString -> IO BL.ByteString
    inflateBody ptr =
        feed 0 . BL.toChunks . (<> appTailL)
      where
        -- Hand the next compressed chunk to zlib, or flush once none are left.
        feed :: Int64 -> [B.ByteString] -> IO BL.ByteString
        feed size0 [] = do
            chunk <- Zlib.flushInflate ptr
            checkSize (fromIntegral (B.length chunk) + size0)
            return (BL.fromStrict chunk)
        feed size0 (c : cs) = do
            popper <- Zlib.feedInflate ptr c
            drain size0 popper cs

        -- Pull decompressed output piece by piece, enforcing the limit after
        -- each piece.
        --
        -- Draining the whole popper first, as 'dePopper' does, would let a
        -- small, highly compressible chunk expand to a huge buffer before any
        -- size check ran.
        drain
            :: Int64 -> Zlib.Popper -> [B.ByteString] -> IO BL.ByteString
        drain size0 popper cs = do
            res <- popper
            case res of
                Zlib.PRDone     -> feed size0 cs
                Zlib.PRNext out -> do
                    let size1 = size0 + fromIntegral (B.length out)
                    checkSize size1
                    BL.chunk out <$> drain size1 popper cs
                Zlib.PRError e  ->
                    throwIO $ CloseRequest 1002 (BL8.pack (show e))

    ----------------------------------------------------------------------------
    checkSize :: Int64 -> IO ()
    checkSize size = unless (atMostSizeLimit size messageLimit) $
        throwIO $ ParseException $
            "Message of size " ++ show size ++ " exceeded limit"