{-# LANGUAGE OverloadedStrings #-}
module Network.WebSockets.Protocol.Hybi10.Internal
    ( Hybi10_ (..)
    , encodeFrameHybi10
    ) where

import Control.Applicative (pure, (<$>))
import Control.Monad (liftM)
import Data.Bits ((.&.), (.|.))
import Data.Maybe (maybeToList)
import Data.Monoid (mempty, mappend, mconcat)
import System.Random (RandomGen)

import Data.Attoparsec (anyWord8)
import Data.Binary.Get (runGet, getWord16be, getWord64be)
import Data.ByteString (ByteString, intercalate)
import Data.ByteString.Char8 ()
import Data.Digest.Pure.SHA (bytestringDigest, sha1)
import Data.Enumerator ((=$))
import Data.Int (Int64)
import qualified Blaze.ByteString.Builder as B
import qualified Data.Attoparsec as A
import qualified Data.Attoparsec.Enumerator as A
import qualified Data.ByteString.Base64 as B64
import qualified Data.ByteString.Lazy as BL
import qualified Data.Enumerator as E
import qualified Data.Enumerator.List as EL

import Network.WebSockets.Handshake.Http
import Network.WebSockets.Protocol
import Network.WebSockets.Protocol.Hybi10.Demultiplex
import Network.WebSockets.Protocol.Hybi10.Mask
import Network.WebSockets.Types

import System.Entropy as R

-- | Tag type selecting the hybi10 protocol implementation. It is advertised
-- for the @Sec-WebSocket-Version@ values listed in 'headerVersions'.
data Hybi10_ = Hybi10_

-- | Each method delegates to the matching @...Hybi10@ function in this
-- module. The order of 'headerVersions' matters: the first entry is the
-- version a client sends.
instance Protocol Hybi10_ where
    implementations                 = [Hybi10_]
    version        Hybi10_          = "hybi10"
    headerVersions Hybi10_          = ["13", "8", "7"]
    createRequest  Hybi10_          = createRequestHybi10
    finishRequest  Hybi10_          = handshakeHybi10
    finishResponse Hybi10_          = finishResponseHybi10
    encodeMessages Hybi10_ needMask = EL.mapAccum (encodeMessageHybi10 needMask)
    decodeMessages Hybi10_          = decodeMessagesHybi10

-- Marker instances: hybi10 encodes both 'Text' and 'Binary' data messages
-- (see 'encodeMessageHybi10').
instance TextProtocol Hybi10_
instance BinaryProtocol Hybi10_

-- | Encode one message as a single unfragmented frame (FIN set, all RSV
-- bits clear).
--
-- When @needMask@ is 'True', a mask is drawn with 'randomMask'. Otherwise
-- no mask is used and the generator is returned untouched. The returned
-- generator lets callers thread it through 'EL.mapAccum'.
encodeMessageHybi10 :: RandomGen g => Bool -> g -> Message p -> (g, B.Builder)
encodeMessageHybi10 needMask gen msg =
    (gen', encodeFrameHybi10 frameMask (Frame True False False False ty body))
  where
    (frameMask, gen')
        | needMask  = randomMask gen
        | otherwise = (Nothing, gen)

    -- Each message constructor maps onto exactly one frame type.
    (ty, body) = case msg of
        ControlMessage (Close pl) -> (CloseFrame, pl)
        ControlMessage (Ping pl)  -> (PingFrame, pl)
        ControlMessage (Pong pl)  -> (PongFrame, pl)
        DataMessage (Text pl)     -> (TextFrame, pl)
        DataMessage (Binary pl)   -> (BinaryFrame, pl)

-- | Serialise a 'Frame' into its wire representation.
--
-- The output is laid out as:
--
-- * a byte holding FIN, the three RSV bits and the opcode;
-- * a byte holding the mask bit and a 7-bit length;
-- * an optional 16-bit (length marker 126) or 64-bit (marker 127)
--   big-endian extended length;
-- * the masking key bytes, if a mask is given;
-- * the payload passed through 'maskPayload'.
encodeFrameHybi10 :: Mask -> Frame -> B.Builder
encodeFrameHybi10 mask f = mconcat
    [ B.fromWord8 (finBit .|. rsvBits .|. opcodeOf (frameType f))
    , B.fromWord8 (maskBit .|. shortLen)
    , extLen
    , keyBytes
    , B.fromLazyByteString (maskPayload mask payload)
    ]
  where
    payload = framePayload f
    size    = BL.length payload

    -- Contribute bit pattern @b@ only when the flag is set.
    bitIf flagged b = if flagged then b else 0x00

    finBit  = bitIf (frameFin f) 0x80
    rsvBits = bitIf (frameRsv1 f) 0x40
          .|. bitIf (frameRsv2 f) 0x20
          .|. bitIf (frameRsv3 f) 0x10

    opcodeOf ContinuationFrame = 0x00
    opcodeOf TextFrame         = 0x01
    opcodeOf BinaryFrame       = 0x02
    opcodeOf CloseFrame        = 0x08
    opcodeOf PingFrame         = 0x09
    opcodeOf PongFrame         = 0x0a

    (maskBit, keyBytes) =
        maybe (0x00, mempty) (\k -> (0x80, B.fromByteString k)) mask

    -- Small lengths fit in the 7-bit field. Larger ones use a marker value
    -- in that field, followed by the real length.
    (shortLen, extLen)
        | size < 126     = (fromIntegral size, mempty)
        | size < 0x10000 = (126, B.fromWord16be (fromIntegral size))
        | otherwise      = (127, B.fromWord64be (fromIntegral size))

-- | Turn a raw byte stream into 'Message's.
--
-- Bytes are repeatedly parsed into frames with 'parseFrame' (driven by
-- 'A.iterParser'). The resulting frames are regrouped into messages by
-- 'demultiplexEnum'.
decodeMessagesHybi10 :: Monad m => E.Enumeratee ByteString (Message p) m a
decodeMessagesHybi10 step =
    E.sequence (A.iterParser parseFrame) =$ demultiplexEnum step

-- | Run frames through 'demultiplex'.
--
-- The demultiplexer state is threaded from one frame to the next, starting
-- from 'emptyDemultiplexState'. Each frame emits the message 'demultiplex'
-- returns, if any.
demultiplexEnum :: Monad m => E.Enumeratee Frame (Message p) m a
demultiplexEnum = EL.concatMapAccum feed emptyDemultiplexState
  where
    feed st frame = (st', maybe [] (: []) done)
      where
        (done, st') = demultiplex st frame

-- | Parse a single frame: header flags, opcode, payload length, optional
-- masking key, then the payload (unmasked via 'maskPayload').
--
-- The parser fails, rather than returning a frame that blows up when
-- inspected, in two cases:
--
-- * the opcode is unknown;
-- * a 64-bit extended length has its most significant bit set, which
--   RFC 6455 section 5.2 forbids. Such a length would wrap to a negative
--   'Int64' and would otherwise be read as an empty payload, leaving the
--   real payload bytes in the stream.
parseFrame :: A.Parser Frame
parseFrame = do
    byte0 <- anyWord8
    let fin    = byte0 .&. 0x80 == 0x80
        rsv1   = byte0 .&. 0x40 == 0x40
        rsv2   = byte0 .&. 0x20 == 0x20
        rsv3   = byte0 .&. 0x10 == 0x10

    -- Fail inside the parser so the problem surfaces as a parse error,
    -- instead of a lazy 'error' buried in the returned frame.
    ft <- case byte0 .&. 0x0f of
        0x00 -> return ContinuationFrame
        0x01 -> return TextFrame
        0x02 -> return BinaryFrame
        0x08 -> return CloseFrame
        0x09 -> return PingFrame
        0x0a -> return PongFrame
        op   -> fail $ "Unknown opcode: " ++ show op

    byte1 <- anyWord8
    let mask    = byte1 .&. 0x80 == 0x80
        lenflag = fromIntegral (byte1 .&. 0x7f)

    -- 126 and 127 are markers for a 16-bit or 64-bit extended length.
    len <- case lenflag of
        126 -> fromIntegral . runGet' getWord16be <$> A.take 2
        127 -> do
            len64 <- fromIntegral . runGet' getWord64be <$> A.take 8
            if len64 < 0
                then fail "Payload length has its most significant bit set"
                else return len64
        _   -> return lenflag

    masker <- maskPayload <$> if mask then Just <$> A.take 4 else pure Nothing

    chunks <- take64 len

    return $ Frame fin rsv1 rsv2 rsv3 ft (masker $ BL.fromChunks chunks)
  where
    -- Decode a big-endian integer from a strict chunk using 'Data.Binary'.
    runGet' g = runGet g . BL.fromChunks . return

    -- 'A.take' takes an 'Int', so a 64-bit length is read in pieces of at
    -- most @maxBound :: Int@ bytes. This matters where 'Int' is 32 bits.
    -- NOTE(review): the whole payload is collected before the frame is
    -- returned, and nothing here bounds the advertised length.
    take64 :: Int64 -> A.Parser [ByteString]
    take64 n
        | n <= 0    = return []
        | otherwise = do
            let n' = min intMax n
            chunk <- A.take (fromIntegral n')
            (chunk :) <$> take64 (n - n')
      where
        intMax :: Int64
        intMax = fromIntegral (maxBound :: Int)

-- | Server side of the opening handshake.
--
-- Reads the client's @Sec-WebSocket-Key@ and answers with a 101 response
-- whose @Sec-WebSocket-Accept@ header is the base64 encoding of
-- 'hashKeyHybi10' of that key.
handshakeHybi10 :: Monad m
                => RequestHttpPart
                -> E.Iteratee ByteString m Request
handshakeHybi10 reqHttp@(RequestHttpPart path h _) =
    liftM accept (getRequestHeader reqHttp "Sec-WebSocket-Key")
  where
    accept key = Request path h $
        response101 [("Sec-WebSocket-Accept", B64.encode (hashKeyHybi10 key))] ""

-- | Build the HTTP part of a client opening handshake.
--
-- The @Sec-WebSocket-Key@ is 16 bytes from 'getEntropy', base64-encoded.
-- The version sent is the first entry of 'headerVersions'. Optional
-- headers follow the fixed ones, in this order:
--
-- * @Sec-WebSocket-Protocol@: the sub-protocols joined with @", "@;
-- * @Origin@.
--
-- Each is omitted when its argument is 'Nothing'.
createRequestHybi10 :: ByteString
                    -> ByteString
                    -> Maybe ByteString
                    -> Maybe [ByteString]
                    -> Bool
                    -> IO RequestHttpPart
createRequestHybi10 hostname path origin protocols secure = do
    nonce <- getEntropy 16
    let key = B64.encode nonce
    return (RequestHttpPart path (mandatory key ++ extras) secure)
  where
    mandatory key =
        [ ("Host"                 , hostname     )
        , ("Connection"           , "Upgrade"    )
        , ("Upgrade"              , "websocket"  )
        , ("Sec-WebSocket-Key"    , key          )
        , ("Sec-WebSocket-Version", versionNumber)
        ]

    extras = headerIf "Sec-WebSocket-Protocol" (intercalate ", " <$> protocols)
          ++ headerIf "Origin" origin

    -- A single header when a value is present, none otherwise.
    headerIf name = maybe [] (\value -> [(name, value)])

    versionNumber = head (headerVersions Hybi10_)

-- | Client side: validate the server's handshake response.
--
-- A response is accepted only when both of these hold:
--
-- * the status code is 101;
-- * its @Sec-WebSocket-Accept@ header equals the base64 encoding of
--   'hashKeyHybi10' of the request's @Sec-WebSocket-Key@.
--
-- Any failure is raised as 'MalformedResponse' via 'E.throwError'.
--
-- NOTE(review): the status message (e.g. "Switching Protocols") and the
-- @Upgrade@/@Connection@ headers are not checked.
finishResponseHybi10 :: Monad m
                     => RequestHttpPart
                     -> ResponseHttpPart
                     -> E.Iteratee ByteString m ResponseBody
finishResponseHybi10 request response
    | responseHttpCode response /= 101 =
        reject "Wrong response status or message."
    | otherwise = do
        key    <- getRequestHeader  request  "Sec-WebSocket-Key"
        accept <- getResponseHeader response "Sec-WebSocket-Accept"
        if accept == B64.encode (hashKeyHybi10 key)
            then return (ResponseBody response "")
            else reject "Challenge and response hashes do not match."
  where
    reject msg = E.throwError (MalformedResponse response msg)


-- | Raw (not yet base64-encoded) SHA-1 digest of the key followed by the
-- fixed GUID "258EAFA5-E914-47DA-95CA-C5AB0DC85B11".
--
-- Callers base64-encode the result to form or check the
-- @Sec-WebSocket-Accept@ value.
hashKeyHybi10 :: ByteString -> ByteString
hashKeyHybi10 key = mconcat (BL.toChunks digest)
  where
    digest = bytestringDigest (sha1 (BL.fromChunks [key, guid]))
    guid   = "258EAFA5-E914-47DA-95CA-C5AB0DC85B11"