--------------------------------------------------------------------------------
{-# LANGUAGE LambdaCase        #-}
{-# LANGUAGE OverloadedStrings #-}
{-# LANGUAGE RecordWildCards   #-}
{-# LANGUAGE TupleSections     #-}
module Network.WebSockets.Extensions.PermessageDeflate
    ( defaultPermessageDeflate
    , PermessageDeflate(..)
    , negotiateDeflate
    ) where


--------------------------------------------------------------------------------
import           Control.Applicative                        ((<$>))
import           Control.Exception                          (throwIO)
import           Control.Monad                              (foldM)
import qualified Data.ByteString                            as B
import qualified Data.ByteString.Char8                      as B8
import qualified Data.ByteString.Lazy                       as BL
import qualified Data.ByteString.Lazy.Char8                 as BL8
import qualified Data.ByteString.Lazy.Internal              as BL
import           Data.Monoid
import qualified Data.Streaming.Zlib                        as Zlib
import           Network.WebSockets.Extensions
import           Network.WebSockets.Extensions.Description
import           Network.WebSockets.Http
import           Network.WebSockets.Types
import           Text.Read                                  (readMaybe)
import           Prelude


--------------------------------------------------------------------------------
-- | Settings for the "permessage-deflate" extension.
--
-- The extension defines four parameters that let each endpoint bound the
-- resources a connection may consume:
--
-- - "server_no_context_takeover"
-- - "client_no_context_takeover"
-- - "server_max_window_bits"
-- - "client_max_window_bits"
--
-- In addition, 'pdCompressionLevel' selects the zlib level used when
-- compressing outgoing messages.
data PermessageDeflate = PermessageDeflate
    { serverNoContextTakeover :: Bool
      -- ^ When set, outgoing messages are compressed with a fresh deflate
      -- stream each time instead of one shared stream.
    , clientNoContextTakeover :: Bool
      -- ^ When set, incoming messages are decompressed with a fresh inflate
      -- stream each time instead of one shared stream.
    , serverMaxWindowBits     :: Int
      -- ^ Window size (in bits) used for the deflate side; 15 is the value
      -- left out of the negotiated header.
    , clientMaxWindowBits     :: Int
      -- ^ Window size (in bits) used for the inflate side; 15 is the value
      -- left out of the negotiated header.
    , pdCompressionLevel      :: Int
      -- ^ Level handed to zlib when creating the deflate stream.
    } deriving (Eq, Show)


--------------------------------------------------------------------------------
-- | Default settings: context takeover in both directions, 15-bit windows on
-- both sides, and compression level 8.
defaultPermessageDeflate :: PermessageDeflate
defaultPermessageDeflate = PermessageDeflate
    { serverNoContextTakeover = False
    , clientNoContextTakeover = False
    , serverMaxWindowBits     = 15
    , clientMaxWindowBits     = 15
    , pdCompressionLevel      = 8
    }


--------------------------------------------------------------------------------
-- | Convert the parameters to an 'ExtensionDescription' that we can put in a
-- @Sec-WebSocket-Extensions@ header.
toExtensionDescription :: PermessageDeflate -> ExtensionDescription
toExtensionDescription PermessageDeflate {..} = ExtensionDescription
    { extName   = "permessage-deflate"
      -- Flags are emitted only when enabled; window sizes only when they
      -- differ from the default of 15.
    , extParams =
         [("server_no_context_takeover", Nothing) | serverNoContextTakeover] ++
         [("client_no_context_takeover", Nothing) | clientNoContextTakeover] ++
         [ ("server_max_window_bits", param serverMaxWindowBits)
         | serverMaxWindowBits /= 15
         ] ++
         [ ("client_max_window_bits", param clientMaxWindowBits)
         | clientMaxWindowBits /= 15
         ]
    }
  where
    param = Just . B8.pack . show


--------------------------------------------------------------------------------
-- | Render the negotiated parameters as a @Sec-WebSocket-Extensions@
-- response header.
toHeaders :: PermessageDeflate -> Headers
toHeaders pmd =
    [ ( "Sec-WebSocket-Extensions"
      , encodeExtensionDescriptions [toExtensionDescription pmd]
      )
    ]


--------------------------------------------------------------------------------
-- | Negotiate compression against the extensions offered by the peer.
--
-- With 'Nothing' (compression disabled) no header is produced and messages
-- are passed through 'makeMessageInflater' / 'makeMessageDeflater' with no
-- configuration. Otherwise the first offer named "x-webkit-deflate-frame"
-- or "permessage-deflate" is accepted; for "permessage-deflate" the offered
-- parameters are folded into the server configuration with 'setParam', and
-- an invalid parameter value makes the whole negotiation return 'Left'.
negotiateDeflate :: Maybe PermessageDeflate -> NegotiateExtension
negotiateDeflate pmd0 exts0 = do
    (headers, pmd1) <- negotiateDeflateOpts exts0 pmd0
    return Extension
        { extHeaders = headers
        , extParse   = \parseRaw -> do
            inflate <- makeMessageInflater pmd1
            return $ do
                msg <- parseRaw
                case msg of
                    Nothing -> return Nothing
                    Just m  -> fmap Just (inflate m)
        , extWrite   = \writeRaw -> do
            deflate <- makeMessageDeflater pmd1
            return $ \msgs -> mapM deflate msgs >>= writeRaw
        }
  where
    negotiateDeflateOpts
        :: ExtensionDescriptions
        -> Maybe PermessageDeflate
        -> Either String (Headers, Maybe PermessageDeflate)
    negotiateDeflateOpts (ext : _) (Just x)
        | extName ext == "x-webkit-deflate-frame" =
            Right ([("Sec-WebSocket-Extensions", "x-webkit-deflate-frame")], Just x)

    negotiateDeflateOpts (ext : _) (Just x)
        | extName ext == "permessage-deflate" = do
            x' <- foldM setParam (withClientWindowRule ext x) (extParams ext)
            Right (toHeaders x', Just x')

    negotiateDeflateOpts (_ : exts) (Just x) = negotiateDeflateOpts exts (Just x)

    negotiateDeflateOpts _ _ = Right ([], Nothing)

    -- RFC 7692 section 7.1.2.2: if the offer does not contain
    -- "client_max_window_bits", the response MUST NOT contain it either, and
    -- the client is free to compress with a window of up to 15 bits. Reset
    -- the configured value to 15 so that the parameter is omitted from the
    -- response header and the inflater accepts the client's full window.
    withClientWindowRule
        :: ExtensionDescription -> PermessageDeflate -> PermessageDeflate
    withClientWindowRule ext x
        | any ((== "client_max_window_bits") . fst) (extParams ext) = x
        | otherwise = x {clientMaxWindowBits = 15}


--------------------------------------------------------------------------------
-- | Fold one offered parameter into the configuration. A window parameter
-- without a value resets that window to 15; unknown parameters are ignored.
setParam
    :: PermessageDeflate -> ExtensionParam -> Either String PermessageDeflate
setParam pmd ("server_no_context_takeover", _) =
    Right pmd {serverNoContextTakeover = True}
setParam pmd ("client_no_context_takeover", _) =
    Right pmd {clientNoContextTakeover = True}
setParam pmd ("server_max_window_bits", Nothing) =
    Right pmd {serverMaxWindowBits = 15}
setParam pmd ("server_max_window_bits", Just param) = do
    w <- parseWindow param
    Right pmd {serverMaxWindowBits = w}
setParam pmd ("client_max_window_bits", Nothing) =
    Right pmd {clientMaxWindowBits = 15}
setParam pmd ("client_max_window_bits", Just param) = do
    w <- parseWindow param
    Right pmd {clientMaxWindowBits = w}
setParam pmd (_, _) = Right pmd


--------------------------------------------------------------------------------
-- | Parse a window-bits parameter value.
--
-- Only plain ASCII decimal digits are accepted (no sign, whitespace,
-- parentheses or hexadecimal), and the number is parsed as an 'Integer'
-- before the 8..15 range check so that huge inputs cannot wrap around into
-- the valid range when narrowed to 'Int'.
parseWindow :: B.ByteString -> Either String Int
parseWindow bs8
    | B.null bs8 || not (B8.all isDecDigit bs8) = cantParse
    | otherwise = case (readMaybe (B8.unpack bs8) :: Maybe Integer) of
        Just w
            | w >= 8 && w <= 15 -> Right (fromInteger w)
            | otherwise         -> Left $ "Window out of bounds: " ++ show w
        Nothing -> cantParse
  where
    cantParse    = Left $ "Can't parse window: " ++ show bs8
    isDecDigit c = c >= '0' && c <= '9'


--------------------------------------------------------------------------------
-- | Clamp a window_bits value to the range 9..15 before handing it to zlib.
-- In particular, a window_bits value of 8 must be raised to 9.
--
-- Related issues:
--
-- - https://github.com/haskell/zlib/issues/11
-- - https://github.com/madler/zlib/issues/94
--
-- Quote from zlib manual:
--
-- For the current implementation of deflate(), a windowBits value of 8 (a
-- window size of 256 bytes) is not supported. As a result, a request for 8 will
-- result in 9 (a 512-byte window). In that case, providing 8 to inflateInit2()
-- will result in an error when the zlib header with 9 is checked against the
-- initialization of inflate(). The remedy is to not use 8 with deflateInit2()
-- with this initialization, or at least in that case use 9 with inflateInit2().
fixWindowBits :: Int -> Int
fixWindowBits n
    | n < 9     = 9   -- zlib cannot raw-deflate with an 8-bit window
    | n > 15    = 15  -- 15 is the largest window zlib supports
    | otherwise = n


--------------------------------------------------------------------------------
-- | The four octets (00 00 ff ff) that end a zlib sync flush. The deflater
-- strips them from outgoing payloads and the inflater re-appends them to
-- incoming payloads before decompressing.
appTailL :: BL.ByteString
appTailL = BL.pack [0x00, 0x00, 0xff, 0xff]


--------------------------------------------------------------------------------
-- | Drop a trailing 'appTailL' if present; otherwise return the input as is.
maybeStrip :: BL.ByteString -> BL.ByteString
maybeStrip x
    | appTailL `BL.isSuffixOf` x = BL.take (BL.length x - 4) x
    | otherwise                  = x


--------------------------------------------------------------------------------
-- | Message filter used when compression was not negotiated: a data message
-- with any RSV bit set is rejected with a 1002 'CloseRequest'; every other
-- message is returned unchanged.
rejectExtensions :: Message -> IO Message
rejectExtensions (DataMessage rsv1 rsv2 rsv3 _)
    | rsv1 || rsv2 || rsv3 = throwIO $ CloseRequest 1002 "Protocol Error"
rejectExtensions x = return x


--------------------------------------------------------------------------------
-- | Build the transformer applied to outgoing messages.
--
-- Without a configuration, messages only go through 'rejectExtensions'.
-- With one, text and binary data messages whose RSV bits are all clear are
-- deflated and marked with RSV1; all other messages pass through untouched.
makeMessageDeflater :: Maybe PermessageDeflate -> IO (Message -> IO Message)
makeMessageDeflater Nothing = return rejectExtensions
makeMessageDeflater (Just pmd)
    | serverNoContextTakeover pmd =
        -- No context takeover: each compressed message gets its own fresh
        -- deflate stream. The stream is created only for messages that are
        -- actually compressed, so control messages do not allocate one.
        return $ deflateMessageWith (deflateBody <$> initDeflate pmd)
    | otherwise = do
        -- Context takeover: a single stream is shared by every message.
        ptr <- initDeflate pmd
        return $ deflateMessageWith (return (deflateBody ptr))
  where
    ----------------------------------------------------------------------------
    initDeflate :: PermessageDeflate -> IO Zlib.Deflate
    initDeflate PermessageDeflate {..} = Zlib.initDeflate
        pdCompressionLevel
        -- A negative window size selects raw deflate (no zlib header).
        (Zlib.WindowBits (- (fixWindowBits serverMaxWindowBits)))

    ----------------------------------------------------------------------------
    -- The first argument yields the body compressor; it is only run for
    -- messages that are going to be compressed.
    deflateMessageWith
        :: IO (BL.ByteString -> IO BL.ByteString) -> Message -> IO Message
    deflateMessageWith acquire (DataMessage False False False (Text x _)) = do
        deflater <- acquire
        x'       <- deflater x
        return (DataMessage True False False (Text x' Nothing))
    deflateMessageWith acquire (DataMessage False False False (Binary x)) = do
        deflater <- acquire
        x'       <- deflater x
        return (DataMessage True False False (Binary x'))
    deflateMessageWith _ x = return x

    ----------------------------------------------------------------------------
    -- Feed every chunk, finish with a sync flush, then strip the flush tail.
    deflateBody :: Zlib.Deflate -> BL.ByteString -> IO BL.ByteString
    deflateBody ptr = fmap maybeStrip . go . BL.toChunks
      where
        go []       = dePopper (Zlib.flushDeflate ptr)
        go (c : cs) = do
            bl <- Zlib.feedDeflate ptr c >>= dePopper
            (bl <>) <$> go cs


--------------------------------------------------------------------------------
-- | Drain a zlib 'Zlib.Popper' into a lazy 'BL.ByteString'. A zlib error is
-- reported as a 1002 'CloseRequest' carrying the shown error.
dePopper :: Zlib.Popper -> IO BL.ByteString
dePopper p = p >>= \case
    Zlib.PRDone    -> return BL.empty
    Zlib.PRNext c  -> BL.chunk c <$> dePopper p
    Zlib.PRError x -> throwIO $ CloseRequest 1002 (BL8.pack (show x))


--------------------------------------------------------------------------------
-- | Build the transformer applied to incoming messages.
--
-- Without a configuration, messages only go through 'rejectExtensions'.
-- With one, text and binary data messages with RSV1 set are inflated and
-- RSV1 is cleared (RSV2/RSV3 are preserved); all other messages pass through
-- untouched.
makeMessageInflater :: Maybe PermessageDeflate -> IO (Message -> IO Message)
makeMessageInflater Nothing = return rejectExtensions
makeMessageInflater (Just pmd)
    | clientNoContextTakeover pmd =
        -- No context takeover: each compressed message gets its own fresh
        -- inflate stream, created only when a message must be inflated.
        return $ inflateMessageWith (inflateBody <$> initInflate pmd)
    | otherwise = do
        -- Context takeover: a single stream is shared by every message.
        ptr <- initInflate pmd
        return $ inflateMessageWith (return (inflateBody ptr))
  where
    ----------------------------------------------------------------------------
    initInflate :: PermessageDeflate -> IO Zlib.Inflate
    initInflate PermessageDeflate {..} =
        -- A negative window size selects raw inflate (no zlib header).
        Zlib.initInflate (Zlib.WindowBits (- (fixWindowBits clientMaxWindowBits)))

    ----------------------------------------------------------------------------
    -- The first argument yields the body decompressor; it is only run for
    -- messages that are going to be decompressed.
    inflateMessageWith
        :: IO (BL.ByteString -> IO BL.ByteString) -> Message -> IO Message
    inflateMessageWith acquire (DataMessage True a b (Text x _)) = do
        inflater <- acquire
        x'       <- inflater x
        return (DataMessage False a b (Text x' Nothing))
    inflateMessageWith acquire (DataMessage True a b (Binary x)) = do
        inflater <- acquire
        x'       <- inflater x
        return (DataMessage False a b (Binary x'))
    inflateMessageWith _ x = return x

    ----------------------------------------------------------------------------
    -- Re-append the stripped flush tail, feed every chunk, then collect
    -- whatever output zlib still buffers.
    inflateBody :: Zlib.Inflate -> BL.ByteString -> IO BL.ByteString
    inflateBody ptr = go . BL.toChunks . (<> appTailL)
      where
        go []       = BL.fromStrict <$> Zlib.flushInflate ptr
        go (c : cs) = do
            bl <- Zlib.feedInflate ptr c >>= dePopper
            (bl <>) <$> go cs