{-# LANGUAGE OverloadedStrings #-}
-- | Implementation of the \"hybi10\" WebSocket protocol variant: frame
-- encoding and parsing, plus the server- and client-side handshake helpers.
module Network.WebSockets.Protocol.Hybi10.Internal
    ( Hybi10_ (..)
    , encodeFrameHybi10
    ) where

import Control.Applicative (pure, (<$>))
import Control.Monad (liftM)
import Data.Bits ((.&.), (.|.))
import Data.Maybe (maybeToList)
import Data.Monoid (mempty, mappend, mconcat)

import Data.Attoparsec (anyWord8)
import Data.Binary.Get (runGet, getWord16be, getWord64be)
import Data.ByteString (ByteString, intercalate)
import Data.ByteString.Char8 ()
import Data.Digest.Pure.SHA (bytestringDigest, sha1)
import Data.Enumerator ((=$))
import Data.Int (Int64)
import qualified Blaze.ByteString.Builder as B
import qualified Data.Attoparsec as A
import qualified Data.Attoparsec.Enumerator as A
import qualified Data.ByteString.Base64 as B64
import qualified Data.ByteString.Lazy as BL
import qualified Data.Enumerator as E
import qualified Data.Enumerator.List as EL

import Network.WebSockets.Handshake.Http
import Network.WebSockets.Protocol
import Network.WebSockets.Protocol.Hybi10.Demultiplex
import Network.WebSockets.Protocol.Hybi10.Mask
import Network.WebSockets.Types
import System.Entropy as R

-- | Value-level tag selecting the hybi10 implementation of 'Protocol'.
data Hybi10_ = Hybi10_

instance Protocol Hybi10_ where
    version         Hybi10_ = "hybi10"
    headerVersions  Hybi10_ = ["13", "8", "7"]
    encodeMessages  Hybi10_ = EL.map encodeMessageHybi10
    decodeMessages  Hybi10_ = decodeMessagesHybi10
    createRequest   Hybi10_ = createRequestHybi10
    finishRequest   Hybi10_ = handshakeHybi10
    finishResponse  Hybi10_ = finishResponseHybi10
    implementations         = [Hybi10_]

instance TextProtocol Hybi10_
instance BinaryProtocol Hybi10_

-- | Encode a message as a single, final frame with all reserved bits
-- cleared.
encodeMessageHybi10 :: Message p -> B.Builder
encodeMessageHybi10 msg = builder
  where
    -- FIN set, RSV1-3 cleared; frame type and payload are still missing.
    mkFrame = Frame True False False False
    builder = encodeFrameHybi10 $ case msg of
        (ControlMessage (Close pl)) -> mkFrame CloseFrame pl
        (ControlMessage (Ping pl))  -> mkFrame PingFrame pl
        (ControlMessage (Pong pl))  -> mkFrame PongFrame pl
        (DataMessage (Text pl))     -> mkFrame TextFrame pl
        (DataMessage (Binary pl))   -> mkFrame BinaryFrame pl

-- | Encode a frame
--
-- The output consists of the flags\/opcode byte, the length byte, an
-- optional extended length (16 or 64 bits, big endian) and the payload.
-- The mask bit is never set, so the payload is emitted as-is.
encodeFrameHybi10 :: Frame -> B.Builder
encodeFrameHybi10 f = B.fromWord8 byte0 `mappend`
    B.fromWord8 byte1 `mappend` len `mappend`
    B.fromLazyByteString (framePayload f)
  where
    byte0  = fin .|. rsv1 .|. rsv2 .|. rsv3 .|. opcode
    fin    = if frameFin f  then 0x80 else 0x00
    rsv1   = if frameRsv1 f then 0x40 else 0x00
    rsv2   = if frameRsv2 f then 0x20 else 0x00
    rsv3   = if frameRsv3 f then 0x10 else 0x00
    opcode = case frameType f of
        ContinuationFrame -> 0x00
        TextFrame         -> 0x01
        BinaryFrame       -> 0x02
        CloseFrame        -> 0x08
        PingFrame         -> 0x09
        PongFrame         -> 0x0a

    -- No mask bit is or'ed in: the second byte is just the length flag.
    byte1 = lenflag
    len'  = BL.length (framePayload f)
    -- Short payloads fit in the flag itself; 126 and 127 announce a
    -- following 16-bit resp. 64-bit extended length.
    (lenflag, len)
        | len' < 126     = (fromIntegral len', mempty)
        | len' < 0x10000 = (126, B.fromWord16be (fromIntegral len'))
        | otherwise      = (127, B.fromWord64be (fromIntegral len'))

-- | Turn a stream of raw bytes into a stream of messages: repeatedly run
-- 'parseFrame' and feed the resulting frames through 'demultiplexEnum'.
decodeMessagesHybi10 :: Monad m => E.Enumeratee ByteString (Message p) m a
decodeMessagesHybi10 =
    (E.sequence (A.iterParser parseFrame) =$) . demultiplexEnum

-- | Thread the demultiplex state through a stream of frames, emitting a
-- message whenever 'demultiplex' produces one.
demultiplexEnum :: Monad m => E.Enumeratee Frame (Message p) m a
demultiplexEnum = EL.concatMapAccum step emptyDemultiplexState
  where
    step s f = let (m, s') = demultiplex s f in (s', maybeToList m)

-- | Parse a frame
--
-- An opcode outside the known set makes the parser fail with
-- @\"Unknown opcode\"@ rather than yielding a frame whose type is a latent
-- exception.
parseFrame :: A.Parser Frame
parseFrame = do
    byte0 <- anyWord8
    let fin    = byte0 .&. 0x80 == 0x80
        rsv1   = byte0 .&. 0x40 == 0x40
        rsv2   = byte0 .&. 0x20 == 0x20
        rsv3   = byte0 .&. 0x10 == 0x10
        opcode = byte0 .&. 0x0f

    -- Resolved inside the parser so that a bad opcode is reported as a
    -- parse failure at this point instead of a lazily deferred 'error'.
    ft <- case opcode of
        0x00 -> return ContinuationFrame
        0x01 -> return TextFrame
        0x02 -> return BinaryFrame
        0x08 -> return CloseFrame
        0x09 -> return PingFrame
        0x0a -> return PongFrame
        _    -> fail "Unknown opcode"

    byte1 <- anyWord8
    let mask    = byte1 .&. 0x80 == 0x80
        lenflag = fromIntegral (byte1 .&. 0x7f)

    -- NOTE(review): a 64-bit length with its top bit set converts to a
    -- negative Int64, for which 'take64' reads nothing.
    len <- case lenflag of
        126 -> fromIntegral . runGet' getWord16be <$> A.take 2
        127 -> fromIntegral . runGet' getWord64be <$> A.take 8
        _   -> return lenflag

    -- The 4-byte masking key is only present when the mask bit is set.
    masker <- maskPayload <$> if mask then Just <$> A.take 4 else pure Nothing

    chunks <- take64 len

    return $ Frame fin rsv1 rsv2 rsv3 ft (masker $ BL.fromChunks chunks)
  where
    -- Run a binary 'Get' decoder on a single strict chunk.
    runGet' g = runGet g . BL.fromChunks . return

    -- Read @n@ bytes as a list of strict chunks, each at most
    -- @maxBound :: Int@ bytes long; non-positive @n@ yields no chunks.
    take64 :: Int64 -> A.Parser [ByteString]
    take64 n
        | n <= 0    = return []
        | otherwise = do
            let n' = min intMax n
            chunk <- A.take (fromIntegral n')
            (chunk :) <$> take64 (n - n')
      where
        intMax :: Int64
        intMax = fromIntegral (maxBound :: Int)

-- | Server side of the handshake: read @Sec-WebSocket-Key@ from the
-- request and prepare a 101 response carrying the matching
-- @Sec-WebSocket-Accept@ header.
handshakeHybi10 :: Monad m
                => RequestHttpPart
                -> E.Iteratee ByteString m Request
handshakeHybi10 reqHttp@(RequestHttpPart path h _) = do
    key <- getRequestHeader reqHttp "Sec-WebSocket-Key"
    let hash    = hashKeyHybi10 key
    let encoded = B64.encode hash
    return $ Request path h $ response101 [("Sec-WebSocket-Accept", encoded)] ""

-- | Client side of the handshake: build the upgrade request. The key is
-- the Base64 encoding of 16 bytes obtained from 'getEntropy'; the
-- protocol and origin headers are only added when given.
createRequestHybi10 :: ByteString
                    -> ByteString
                    -> Maybe ByteString
                    -> Maybe [ByteString]
                    -> Bool
                    -> IO RequestHttpPart
createRequestHybi10 hostname path origin protocols secure = do
    key <- B64.encode `liftM` getEntropy 16
    return $ RequestHttpPart path (headers key) secure
  where
    headers key =
        [ ("Host"                  , hostname     )
        , ("Connection"            , "Upgrade"    )
        , ("Upgrade"               , "websocket"  )
        , ("Sec-WebSocket-Key"     , key          )
        , ("Sec-WebSocket-Version" , versionNumber)
        ] ++ protocolHeader protocols ++ originHeader origin

    originHeader (Just o) = [("Origin", o)]
    originHeader Nothing  = []

    protocolHeader (Just ps) = [("Sec-WebSocket-Protocol", intercalate ", " ps)]
    protocolHeader Nothing   = []

    -- First entry of 'headerVersions' for this protocol.
    versionNumber = head . headerVersions $ Hybi10_

-- | Validate the server's answer to a client request: the status code must
-- be 101 and @Sec-WebSocket-Accept@ must equal the hash derived from the
-- key that was sent. Any mismatch is raised as 'MalformedResponse'.
finishResponseHybi10 :: Monad m
                     => RequestHttpPart
                     -> ResponseHttpPart
                     -> E.Iteratee ByteString m ResponseBody
finishResponseHybi10 request response = do
    -- Response message should be one of
    --
    -- - WebSocket Protocol Handshake
    -- - Switching Protocols
    --
    -- But we don't check it for now
    if responseHttpCode response /= 101
        then throw "Wrong response status or message."
        else do
            key          <- getRequestHeader  request  "Sec-WebSocket-Key"
            responseHash <- getResponseHeader response "Sec-WebSocket-Accept"
            let challengeHash = B64.encode $ hashKeyHybi10 key
            if responseHash /= challengeHash
                then throw "Challenge and response hashes do not match."
                else return $ ResponseBody response ""
  where
    throw msg = E.throwError $ MalformedResponse response msg

-- | SHA-1 digest (raw bytes, not yet Base64 encoded) of the given key with
-- the fixed GUID appended.
hashKeyHybi10 :: ByteString -> ByteString
hashKeyHybi10 key = unlazy $ bytestringDigest $ sha1 $ lazy $ key `mappend` guid
  where
    guid   = "258EAFA5-E914-47DA-95CA-C5AB0DC85B11"
    lazy   = BL.fromChunks . return
    unlazy = mconcat . BL.toChunks