{-# LANGUAGE OverloadedStrings #-}
-- | Framing and handshake for the \"hybi10\" protocol family (header
-- versions 13, 8 and 7): encoding messages to frames, parsing frames from a
-- byte stream, and the client/server sides of the opening handshake.
module Network.WebSockets.Protocol.Hybi10.Internal
    ( Hybi10_ (..)
    , encodeFrameHybi10
    ) where

import Control.Applicative (pure, (<$>))
import Control.Monad (liftM)
import Data.Bits ((.&.), (.|.))
import Data.Maybe (maybeToList)
import Data.Monoid (mempty, mappend, mconcat)
import System.Random (RandomGen)

import Data.Attoparsec (anyWord8)
import Data.Binary.Get (runGet, getWord16be, getWord64be)
import Data.ByteString (ByteString, intercalate)
import Data.ByteString.Char8 ()
import Data.Digest.Pure.SHA (bytestringDigest, sha1)
import Data.Enumerator ((=$))
import Data.Int (Int64)
import qualified Blaze.ByteString.Builder as B
import qualified Data.Attoparsec as A
import qualified Data.Attoparsec.Enumerator as A
import qualified Data.ByteString.Base64 as B64
import qualified Data.ByteString.Lazy as BL
import qualified Data.Enumerator as E
import qualified Data.Enumerator.List as EL

import Network.WebSockets.Handshake.Http
import Network.WebSockets.Protocol
import Network.WebSockets.Protocol.Hybi10.Demultiplex
import Network.WebSockets.Protocol.Hybi10.Mask
import Network.WebSockets.Types

import System.Entropy as R

-- | Value-level tag selecting the hybi10 implementation of 'Protocol'.
data Hybi10_ = Hybi10_

instance Protocol Hybi10_ where
    version         Hybi10_   = "hybi10"
    headerVersions  Hybi10_   = ["13", "8", "7"]
    encodeMessages  Hybi10_ m = EL.mapAccum (encodeMessageHybi10 m)
    decodeMessages  Hybi10_   = decodeMessagesHybi10
    createRequest   Hybi10_   = createRequestHybi10
    finishRequest   Hybi10_   = handshakeHybi10
    finishResponse  Hybi10_   = finishResponseHybi10
    implementations           = [Hybi10_]

instance TextProtocol Hybi10_
instance BinaryProtocol Hybi10_

-- | Encode one message as a single, final (FIN set, RSV bits clear) frame.
--
-- When the first argument is 'True' a fresh mask is drawn from the generator
-- and applied to the payload; otherwise the frame is sent unmasked and the
-- generator is returned untouched. The (possibly advanced) generator is
-- returned alongside the builder so it can be threaded through a stream.
encodeMessageHybi10 :: RandomGen g => Bool -> g -> Message p -> (g, B.Builder)
encodeMessageHybi10 needMask gen msg = (gen', builder)
  where
    mkFrame      = Frame True False False False
    (mask, gen') = if needMask then randomMask gen else (Nothing, gen)
    builder      = encodeFrameHybi10 mask $ case msg of
        (ControlMessage (Close pl)) -> mkFrame CloseFrame pl
        (ControlMessage (Ping pl))  -> mkFrame PingFrame pl
        (ControlMessage (Pong pl))  -> mkFrame PongFrame pl
        (DataMessage (Text pl))     -> mkFrame TextFrame pl
        (DataMessage (Binary pl))   -> mkFrame BinaryFrame pl

-- | Encode a frame
--
-- Layout: one byte of flags and opcode, one byte holding the mask flag and
-- the 7-bit length indicator, an optional 16- or 64-bit big-endian extended
-- length, the mask bytes (only when a mask is given) and finally the payload
-- with the mask applied.
encodeFrameHybi10 :: Mask -> Frame -> B.Builder
encodeFrameHybi10 mask f = B.fromWord8 byte0 `mappend`
    B.fromWord8 byte1 `mappend` len `mappend` maskbytes `mappend`
    B.fromLazyByteString (maskPayload mask (framePayload f))
  where
    byte0  = fin .|. rsv1 .|. rsv2 .|. rsv3 .|. opcode
    fin    = if frameFin f  then 0x80 else 0x00
    rsv1   = if frameRsv1 f then 0x40 else 0x00
    rsv2   = if frameRsv2 f then 0x20 else 0x00
    rsv3   = if frameRsv3 f then 0x10 else 0x00
    opcode = case frameType f of
        ContinuationFrame -> 0x00
        TextFrame         -> 0x01
        BinaryFrame       -> 0x02
        CloseFrame        -> 0x08
        PingFrame         -> 0x09
        PongFrame         -> 0x0a

    (maskflag, maskbytes) = case mask of
        Nothing -> (0x00, mempty)
        Just m  -> (0x80, B.fromByteString m)

    byte1 = maskflag .|. lenflag
    len'  = BL.length (framePayload f)
    -- Short payloads fit in the 7-bit field itself; 126 and 127 announce a
    -- 16-bit and a 64-bit extended length respectively.
    (lenflag, len)
        | len' < 126     = (fromIntegral len', mempty)
        | len' < 0x10000 = (126, B.fromWord16be (fromIntegral len'))
        | otherwise      = (127, B.fromWord64be (fromIntegral len'))

-- | Turn a stream of raw bytes into a stream of messages: frames are parsed
-- repeatedly with 'parseFrame' and then fed through 'demultiplexEnum'.
decodeMessagesHybi10 :: Monad m => E.Enumeratee ByteString (Message p) m a
decodeMessagesHybi10 =
    (E.sequence (A.iterParser parseFrame) =$) . demultiplexEnum

-- | Turn a stream of frames into a stream of messages by threading the
-- demultiplexer state through 'demultiplex'. A frame yields zero or one
-- message.
demultiplexEnum :: Monad m => E.Enumeratee Frame (Message p) m a
demultiplexEnum = EL.concatMapAccum step emptyDemultiplexState
  where
    step s f = let (m, s') = demultiplex s f in (s', maybeToList m)

-- | Parse a frame
--
-- The parser fails (instead of raising an exception from pure code) when the
-- opcode is not one of the six known values, or when a 64-bit payload length
-- does not fit in a non-negative 'Int64'.
parseFrame :: A.Parser Frame
parseFrame = do
    byte0 <- anyWord8
    let fin    = byte0 .&. 0x80 == 0x80
        rsv1   = byte0 .&. 0x40 == 0x40
        rsv2   = byte0 .&. 0x20 == 0x20
        rsv3   = byte0 .&. 0x10 == 0x10
        opcode = byte0 .&. 0x0f

    -- Bound monadically so that an unknown opcode is reported through the
    -- parser's own failure channel at parse time, rather than as a lazy
    -- 'error' thunk that only blows up when the frame type is inspected.
    ft <- case opcode of
        0x00 -> return ContinuationFrame
        0x01 -> return TextFrame
        0x02 -> return BinaryFrame
        0x08 -> return CloseFrame
        0x09 -> return PingFrame
        0x0a -> return PongFrame
        _    -> fail "Unknown opcode"

    byte1 <- anyWord8
    let mask    = byte1 .&. 0x80 == 0x80
        lenflag = fromIntegral (byte1 .&. 0x7f)

    len <- case lenflag of
        126 -> fromIntegral . runGet' getWord16be <$> A.take 2
        127 -> do
            n <- fromIntegral . runGet' getWord64be <$> A.take 8
            -- A 64-bit length with its top bit set wraps to a negative
            -- Int64. take64 would treat that as an empty payload and the
            -- real payload bytes would then be misread as new frames.
            if n < 0
                then fail "Payload length does not fit in 63 bits"
                else return n
        _   -> return lenflag

    -- The 4 mask bytes are only present when the mask flag is set.
    masker <- maskPayload <$> if mask then Just <$> A.take 4 else pure Nothing

    chunks <- take64 len

    return $ Frame fin rsv1 rsv2 rsv3 ft (masker $ BL.fromChunks chunks)
  where
    -- Run a binary 'Get' decoder over a single strict chunk.
    runGet' g = runGet g . BL.fromChunks . return

    -- Read an Int64 number of bytes as a list of strict chunks, each at most
    -- (maxBound :: Int) bytes long since A.take only accepts an Int count.
    take64 :: Int64 -> A.Parser [ByteString]
    take64 n
        | n <= 0    = return []
        | otherwise = do
            let n' = min intMax n
            chunk <- A.take (fromIntegral n')
            (chunk :) <$> take64 (n - n')
      where
        intMax :: Int64
        intMax = fromIntegral (maxBound :: Int)

-- | Server side of the handshake: read the @Sec-WebSocket-Key@ request
-- header and build a 'Request' whose 101 response carries the matching
-- @Sec-WebSocket-Accept@ value (base64 of 'hashKeyHybi10' applied to the
-- key).
handshakeHybi10 :: Monad m
                => RequestHttpPart
                -> E.Iteratee ByteString m Request
handshakeHybi10 reqHttp@(RequestHttpPart path h _) = do
    key <- getRequestHeader reqHttp "Sec-WebSocket-Key"
    let hash = hashKeyHybi10 key
    let encoded = B64.encode hash
    return $ Request path h $ response101 [("Sec-WebSocket-Accept", encoded)] ""

-- | Client side of the handshake: build the upgrade request.
--
-- Arguments, in order: host name, request path, optional @Origin@ value,
-- optional list of subprotocols (joined with @\", \"@) and the \"secure\"
-- flag stored in the resulting 'RequestHttpPart'. The key is 16 bytes of
-- system entropy, base64 encoded.
createRequestHybi10 :: ByteString
                    -> ByteString
                    -> Maybe ByteString
                    -> Maybe [ByteString]
                    -> Bool
                    -> IO RequestHttpPart
createRequestHybi10 hostname path origin protocols secure = do
    key <- B64.encode `liftM` getEntropy 16
    return $ RequestHttpPart path (headers key) secure
  where
    headers key =
        [("Host"                   , hostname     )
        ,("Connection"             , "Upgrade"    )
        ,("Upgrade"                , "websocket"  )
        ,("Sec-WebSocket-Key"      , key          )
        ,("Sec-WebSocket-Version"  , versionNumber)
        ] ++ protocolHeader protocols ++ originHeader origin

    originHeader (Just o)    = [("Origin"                , o                  )]
    originHeader Nothing     = []

    protocolHeader (Just ps) = [("Sec-WebSocket-Protocol", intercalate ", " ps)]
    protocolHeader Nothing   = []

    -- 'head' is safe here: headerVersions for Hybi10_ is the literal
    -- non-empty list defined in the Protocol instance in this module.
    versionNumber = head . headerVersions $ Hybi10_

-- | Client side of the handshake: validate the server's response.
--
-- Fails with 'MalformedResponse' when the status code is not 101 or when the
-- @Sec-WebSocket-Accept@ header does not equal the value computed from the
-- @Sec-WebSocket-Key@ that was sent in the request.
finishResponseHybi10 :: Monad m
                     => RequestHttpPart
                     -> ResponseHttpPart
                     -> E.Iteratee ByteString m ResponseBody
finishResponseHybi10 request response = do
    -- The response message should be one of
    --
    -- - WebSocket Protocol Handshake
    -- - Switching Protocols
    --
    -- but only the status code is checked for now.
    if responseHttpCode response /= 101
        then throw "Wrong response status or message."
        else do
            key          <- getRequestHeader request "Sec-WebSocket-Key"
            responseHash <- getResponseHeader response "Sec-WebSocket-Accept"
            let challengeHash = B64.encode $ hashKeyHybi10 key
            if responseHash /= challengeHash
                then throw "Challenge and response hashes do not match."
                else return $ ResponseBody response ""
  where
    throw msg = E.throwError $ MalformedResponse response msg

-- | SHA-1 digest (raw bytes, not base64) of the key with the fixed handshake
-- GUID appended.
hashKeyHybi10 :: ByteString -> ByteString
hashKeyHybi10 key = unlazy $ bytestringDigest $ sha1 $ lazy $ key `mappend` guid
  where
    guid   = "258EAFA5-E914-47DA-95CA-C5AB0DC85B11"
    lazy   = BL.fromChunks . return
    unlazy = mconcat . BL.toChunks