--------------------------------------------------------------------------------
{-# LANGUAGE BangPatterns      #-}
{-# LANGUAGE OverloadedStrings #-}
-- | Hybi-13 (RFC 6455, @Sec-WebSocket-Version: 13@) protocol support:
-- opening-handshake helpers plus frame encoding and decoding.
module Network.WebSockets.Hybi13
    ( headerVersions
    , finishRequest
    , finishResponse
    , encodeMessages
    , decodeMessages
    , createRequest

      -- Internal (used for testing)
    , encodeFrame
    ) where


--------------------------------------------------------------------------------
import qualified Blaze.ByteString.Builder              as B
import           Control.Applicative                   (pure, (<$>))
import           Control.Exception                     (throw)
import           Control.Monad                         (liftM)
import           Data.Attoparsec                       (anyWord8)
import qualified Data.Attoparsec                       as A
import           Data.Binary.Get                       (getWord16be,
                                                        getWord64be, runGet)
import           Data.Bits                             ((.&.), (.|.))
import           Data.ByteString                       (ByteString)
import qualified Data.ByteString.Base64                as B64
import           Data.ByteString.Char8                 ()
import qualified Data.ByteString.Lazy                  as BL
import           Data.Digest.Pure.SHA                  (bytestringDigest, sha1)
import           Data.Int                              (Int64)
import           Data.IORef
import           Data.Monoid                           (mappend, mconcat,
                                                        mempty)
import           Data.Tuple                            (swap)
import           System.Entropy                        as R
import qualified System.IO.Streams                     as Streams
import qualified System.IO.Streams.Attoparsec          as Streams
import           System.Random                         (RandomGen, newStdGen)


--------------------------------------------------------------------------------
import           Network.WebSockets.Http
import           Network.WebSockets.Hybi13.Demultiplex
import           Network.WebSockets.Hybi13.Mask
import           Network.WebSockets.Types


--------------------------------------------------------------------------------
-- | Protocol version(s) for the @Sec-WebSocket-Version@ header.
-- 'createRequest' sends the first entry of this list.
headerVersions :: [ByteString]
headerVersions = ["13"]


--------------------------------------------------------------------------------
-- | Server side of the handshake: build the @101@ response to a client
-- request.
--
-- The client's @Sec-WebSocket-Key@ is answered with a
-- @Sec-WebSocket-Accept@ header holding the base64-encoded 'hashKey' of that
-- key.
finishRequest :: RequestHead -> Response
finishRequest reqHttp =
    let !key     = getRequestHeader reqHttp "Sec-WebSocket-Key"
        !hash    = hashKey key
        !encoded = B64.encode hash
    in response101 [("Sec-WebSocket-Accept", encoded)] ""


--------------------------------------------------------------------------------
-- | Client side of the handshake: validate the server's response to the
-- given request.
--
-- Throws 'MalformedResponse' when:
--
-- * the status code is not @101@, or
-- * @Sec-WebSocket-Accept@ differs from the base64 'hashKey' of the
--   request's @Sec-WebSocket-Key@.
finishResponse :: RequestHead
               -> ResponseHead
               -> Response
finishResponse request response
    -- The reason phrase should be one of
    --
    -- - WebSocket Protocol Handshake
    -- - Switching Protocols
    --
    -- but only the status code is checked for now.
    | responseCode response /= 101  = throw $ MalformedResponse response
        "Wrong response status or message."
    | responseHash /= challengeHash = throw $ MalformedResponse response
        "Challenge and response hashes do not match."
    | otherwise                     = Response response ""
  where
    key           = getRequestHeader  request  "Sec-WebSocket-Key"
    responseHash  = getResponseHeader response "Sec-WebSocket-Accept"
    challengeHash = B64.encode $ hashKey key


--------------------------------------------------------------------------------
-- | Encode one message as a single, final (FIN set, RSV bits clear) frame,
-- followed by a builder flush.
--
-- Client connections mask the frame with a mask drawn from the generator
-- via 'randomMask'. Server connections send unmasked frames and leave the
-- generator untouched. Returns the (possibly advanced) generator together
-- with the encoded bytes.
encodeMessage :: RandomGen g => ConnectionType -> g -> Message -> (g, B.Builder)
encodeMessage conType gen msg = (gen', builder `mappend` B.flush)
  where
    mkFrame      = Frame True False False False
    (mask, gen') = case conType of
        ServerConnection -> (Nothing, gen)
        ClientConnection -> randomMask gen
    builder      = encodeFrame mask $ case msg of
        (ControlMessage (Close pl)) -> mkFrame CloseFrame  pl
        (ControlMessage (Ping pl))  -> mkFrame PingFrame   pl
        (ControlMessage (Pong pl))  -> mkFrame PongFrame   pl
        (DataMessage (Text pl))     -> mkFrame TextFrame   pl
        (DataMessage (Binary pl))   -> mkFrame BinaryFrame pl


--------------------------------------------------------------------------------
-- | Wrap a builder output stream so that it accepts 'Message's.
--
-- Each message is written as one builder (see 'encodeMessage'). The random
-- generator used for client-side masking lives in an 'IORef' seeded by
-- 'newStdGen'. Writes go through 'Streams.lockingOutputStream'.
--
-- Writing end-of-stream ('Nothing') is ignored and is not forwarded to the
-- underlying stream.
encodeMessages :: ConnectionType -> Streams.OutputStream B.Builder
               -> IO (Streams.OutputStream Message)
encodeMessages conType bStream = do
    genRef <- newIORef =<< newStdGen
    Streams.lockingOutputStream =<< Streams.makeOutputStream (next genRef)
  where
    next :: RandomGen g => IORef g -> Maybe Message -> IO ()
    next _      Nothing    = return ()
    next genRef (Just msg) = do
        build <- atomicModifyIORef genRef $ \s -> encodeMessage conType s msg
        Streams.write (Just build) bStream


--------------------------------------------------------------------------------
-- | Serialise a single frame.
--
-- The frame is laid out as:
--
-- 1. byte 0: FIN, RSV1-3 and the opcode;
-- 2. byte 1: the mask flag and a 7-bit length, followed by a 16-bit (126)
--    or 64-bit (127) big-endian extended length when needed;
-- 3. the 4 mask bytes, if any;
-- 4. the payload, passed through 'maskPayload'.
encodeFrame :: Mask -> Frame -> B.Builder
encodeFrame mask f = B.fromWord8 byte0 `mappend`
    B.fromWord8 byte1 `mappend` len `mappend` maskbytes `mappend`
    B.fromLazyByteString (maskPayload mask (framePayload f))
  where
    byte0  = fin .|. rsv1 .|. rsv2 .|. rsv3 .|. opcode
    fin    = if frameFin f  then 0x80 else 0x00
    rsv1   = if frameRsv1 f then 0x40 else 0x00
    rsv2   = if frameRsv2 f then 0x20 else 0x00
    rsv3   = if frameRsv3 f then 0x10 else 0x00
    opcode = case frameType f of
        ContinuationFrame -> 0x00
        TextFrame         -> 0x01
        BinaryFrame       -> 0x02
        CloseFrame        -> 0x08
        PingFrame         -> 0x09
        PongFrame         -> 0x0a

    (maskflag, maskbytes) = case mask of
        Nothing -> (0x00, mempty)
        Just m  -> (0x80, B.fromByteString m)

    byte1 = maskflag .|. lenflag
    len'  = BL.length (framePayload f)
    -- Lengths below 126 fit in byte 1 directly. 126 and 127 are markers
    -- announcing a 16-bit or 64-bit extended length field.
    (lenflag, len)
        | len' < 126     = (fromIntegral len', mempty)
        | len' < 0x10000 = (126, B.fromWord16be (fromIntegral len'))
        | otherwise      = (127, B.fromWord64be (fromIntegral len'))


--------------------------------------------------------------------------------
-- | Turn a raw byte input stream into a stream of 'Message's.
--
-- Frames are parsed one at a time with 'parseFrame' and fed to
-- 'demultiplex'. Reading continues until 'demultiplex' yields a complete
-- message. Frame parse failures are raised by 'Streams.parseFromStream'.
decodeMessages :: Streams.InputStream ByteString
               -> IO (Streams.InputStream Message)
decodeMessages bsStream = do
    dmRef <- newIORef emptyDemultiplexState
    Streams.makeInputStream $ next dmRef
  where
    next dmRef = do
        frame <- Streams.parseFromStream parseFrame bsStream
        m     <- atomicModifyIORef dmRef $ \s -> swap $ demultiplex s frame
        maybe (next dmRef) (return . Just) m


--------------------------------------------------------------------------------
-- | Parse a single frame, unmasking its payload when a mask key is present.
--
-- The parser fails on two kinds of input:
--
-- * reserved or unknown opcodes;
-- * 64-bit payload lengths whose most significant bit is set, which
--   RFC 6455 section 5.2 forbids.
--
-- Reporting both as parse failures, instead of returning a frame that holds
-- a lazy 'error' or a negative length, keeps the byte stream from being
-- silently desynchronised.
parseFrame :: A.Parser Frame
parseFrame = do
    byte0 <- anyWord8
    let fin    = byte0 .&. 0x80 == 0x80
        rsv1   = byte0 .&. 0x40 == 0x40
        rsv2   = byte0 .&. 0x20 == 0x20
        rsv3   = byte0 .&. 0x10 == 0x10
        opcode = byte0 .&. 0x0f

    ft <- case opcode of
        0x00 -> return ContinuationFrame
        0x01 -> return TextFrame
        0x02 -> return BinaryFrame
        0x08 -> return CloseFrame
        0x09 -> return PingFrame
        0x0a -> return PongFrame
        _    -> fail $ "Unknown opcode: " ++ show opcode

    byte1 <- anyWord8
    let mask    = byte1 .&. 0x80 == 0x80
        lenflag = fromIntegral (byte1 .&. 0x7f)

    len <- case lenflag of
        126 -> fromIntegral . runGet' getWord16be <$> A.take 2
        127 -> do
            len64 <- runGet' getWord64be <$> A.take 8
            -- Without this check, a value with the top bit set would become
            -- a negative Int64. take64 would then read nothing and leave the
            -- payload bytes to be misread as the next frame header.
            if len64 > fromIntegral (maxBound :: Int64)
                then fail "Invalid frame length: most significant bit is set"
                else return (fromIntegral len64)
        _   -> return lenflag

    masker <- maskPayload <$> if mask then Just <$> A.take 4 else pure Nothing

    chunks <- take64 len
    return $ Frame fin rsv1 rsv2 rsv3 ft (masker $ BL.fromChunks chunks)
  where
    runGet' g = runGet g . BL.fromChunks . return

    -- 'A.take' takes an 'Int', so an 'Int64' length is consumed in chunks
    -- of at most 'maxBound :: Int' bytes.
    take64 :: Int64 -> A.Parser [ByteString]
    take64 n
        | n <= 0    = return []
        | otherwise = do
            let n' = min intMax n
            chunk <- A.take (fromIntegral n')
            (chunk :) <$> take64 (n - n')
      where
        intMax :: Int64
        intMax = fromIntegral (maxBound :: Int)


--------------------------------------------------------------------------------
-- | Compute the raw SHA-1 digest of the key concatenated with the RFC 6455
-- handshake GUID.
--
-- The result is not base64-encoded; callers apply 'B64.encode' themselves.
hashKey :: ByteString -> ByteString
hashKey key = unlazy $ bytestringDigest $ sha1 $ lazy $ key `mappend` guid
  where
    guid   = "258EAFA5-E914-47DA-95CA-C5AB0DC85B11"
    lazy   = BL.fromChunks . return
    unlazy = mconcat . BL.toChunks


--------------------------------------------------------------------------------
-- | Build a client handshake request for the given host and path.
--
-- The @Sec-WebSocket-Key@ is 16 bytes from 'getEntropy', base64-encoded.
-- The custom headers are appended after the standard upgrade headers.
createRequest :: ByteString
              -> ByteString
              -> Bool
              -> Headers
              -> IO RequestHead
createRequest hostname path secure customHeaders = do
    key <- B64.encode `liftM` getEntropy 16
    return $ RequestHead path (headers key ++ customHeaders) secure
  where
    headers key =
        [ ("Host"                  , hostname     )
        , ("Connection"            , "Upgrade"    )
        , ("Upgrade"               , "websocket"  )
        , ("Sec-WebSocket-Key"     , key          )
        , ("Sec-WebSocket-Version" , versionNumber)
        ]

    versionNumber = head headerVersions