module Database.CQL.Protocol.Header
    ( Header     (..)
    , HeaderType (..)
    , header
    , encodeHeader
    , decodeHeader
      
    , Length     (..)
    , encodeLength
    , decodeLength
      
    , StreamId
    , mkStreamId
    , fromStreamId
    , encodeStreamId
    , decodeStreamId
      
    , Flags
    , compress
    , tracing
    , isSet
    , encodeFlags
    , decodeFlags
    ) where
import Control.Applicative
import Data.Bits
import Data.ByteString.Lazy (ByteString)
import Data.Int
import Data.Monoid
import Data.Serialize
import Data.Word
import Database.CQL.Protocol.Codec
import Database.CQL.Protocol.Types
import Prelude
-- | A frame header.
--
-- Fields are listed in the order in which 'encodeHeader' writes them and
-- 'decodeHeader' reads them; the header type and version share the first
-- byte on the wire.
data Header = Header
    { headerType :: !HeaderType -- ^ request or response (bit 7 of the first byte)
    , version    :: !Version    -- ^ protocol version (low 7 bits of the first byte)
    , flags      :: !Flags      -- ^ header flag bits
    , streamId   :: !StreamId   -- ^ stream identifier
    , opCode     :: !OpCode     -- ^ operation code
    , bodyLength :: !Length     -- ^ length field, written last
    }
    deriving (Show)
-- | Whether a header belongs to a request or a response frame.
data HeaderType
    = RqHeader -- ^ request header: bit 7 of the first byte is clear
    | RsHeader -- ^ response header: bit 7 of the first byte is set
    deriving (Show)
-- | Serialise a frame header.
--
-- The first byte holds the protocol version number ('mapVersion'). Bit 7
-- is set for a response ('RsHeader') and left clear for a request
-- ('RqHeader'). After that come the flags, the stream id (its width
-- depends on the version), the opcode and the length, in that order.
encodeHeader :: Version -> HeaderType -> Flags -> StreamId -> OpCode -> Length -> PutM ()
encodeHeader v t f i o l =
    encodeByte (markDirection t (mapVersion v))
        >> encodeFlags f
        >> encodeStreamId v i
        >> encodeOpCode o
        >> encodeLength l
  where
    -- Responses are distinguished from requests by the top bit of the
    -- version byte.
    markDirection RqHeader = id
    markDirection RsHeader = (`setBit` 7)
-- | Parse a frame header.
--
-- It reads the combined type/version byte first, then the flags, stream
-- id, opcode and length. Parsing fails (through 'toVersion') when the low
-- 7 bits of the first byte are not a recognised version number.
--
-- NOTE(review): the stream id width is chosen from the caller-supplied
-- version @v@, not from the version byte just decoded. The two can
-- disagree; confirm callers only pass the version they expect back.
decodeHeader :: Version -> Get Header
decodeHeader v = do
    lead <- getWord8
    ver  <- toVersion (lead .&. 0x7F)
    fl   <- decodeFlags
    sid  <- decodeStreamId v
    op   <- decodeOpCode
    len  <- decodeLength
    return (Header (mapHeaderType lead) ver fl sid op len)
-- | Classify the first header byte: bit 7 set means a response header,
-- bit 7 clear means a request header.
mapHeaderType :: Word8 -> HeaderType
mapHeaderType byte
    | testBit byte 7 = RsHeader
    | otherwise      = RqHeader
-- | Run 'decodeHeader' for the given version over a lazy 'ByteString'.
-- Failure is reported as @Left@ with the parser's error message.
header :: Version -> ByteString -> Either String Header
header = runGetLazy . decodeHeader
-- | Wire number of a protocol version, as placed in the first header byte.
--
-- NOTE(review): only 'V2' and 'V3' are matched here; any other 'Version'
-- constructor would make this partial.
mapVersion :: Version -> Word8
mapVersion V2 = 2
mapVersion V3 = 3
-- | Inverse of 'mapVersion'. Maps a wire number back to a 'Version' and
-- fails in 'Get' for any value other than 2 or 3.
toVersion :: Word8 -> Get Version
toVersion w = case w of
    2 -> pure V2
    3 -> pure V3
    _ -> fail ("decode-version: unknown: " ++ show w)
-- | The length field of a frame header (the last field 'encodeHeader'
-- writes).
newtype Length = Length
    { lengthRepr :: Int32 -- ^ raw signed 32-bit value
    }
    deriving (Show, Eq)
-- | Write the wrapped 'Int32' of a 'Length' using 'encodeInt'.
encodeLength :: Putter Length
encodeLength len = case len of
    Length n -> encodeInt n
-- | Read a 'Length' by wrapping the value produced by 'decodeInt'.
decodeLength :: Get Length
decodeLength = fmap Length decodeInt
-- | Stream identifier carried in a frame header. It is always held as an
-- 'Int16', whatever width the protocol version uses on the wire.
newtype StreamId = StreamId Int16
    deriving (Show, Eq)
-- | Build a 'StreamId' from any integral value.
--
-- The conversion uses 'fromIntegral', so values outside the 'Int16' range
-- wrap around silently instead of being rejected.
mkStreamId :: Integral i => i -> StreamId
mkStreamId n = StreamId (fromIntegral n)
-- | Widen a 'StreamId' to an 'Int'.
fromStreamId :: StreamId -> Int
fromStreamId sid = case sid of
    StreamId n -> fromIntegral n
-- | Write a 'StreamId' in the width used by the given protocol version:
-- a signed byte for 'V2', a signed short for 'V3'.
--
-- NOTE(review): for 'V2' the 'Int16' is narrowed with 'fromIntegral'.
-- Presumably the target is an 8-bit value, so larger ids would be
-- truncated silently; confirm that callers keep V2 ids small.
encodeStreamId :: Version -> Putter StreamId
encodeStreamId V2 (StreamId sid) = encodeSignedByte (fromIntegral sid)
encodeStreamId V3 (StreamId sid) = encodeSignedShort (fromIntegral sid)
-- | Read a 'StreamId' in the width used by the given protocol version:
-- a signed byte (widened to 'Int16') for 'V2', a signed short for 'V3'.
decodeStreamId :: Version -> Get StreamId
decodeStreamId V2 = fmap (StreamId . fromIntegral) decodeSignedByte
decodeStreamId V3 = fmap StreamId decodeSignedShort
-- | A set of header flag bits packed into a single byte.
newtype Flags = Flags Word8
    deriving (Show, Eq)
-- | Combining two flag sets takes the union of their bits (bitwise OR).
--
-- Since base-4.11 (GHC 8.4), 'Semigroup' is a superclass of 'Monoid', so a
-- 'Monoid' instance on its own no longer compiles. The combining
-- operation lives here, and 'mappend' below delegates to it.
instance Semigroup Flags where
    Flags a <> Flags b = Flags (a .|. b)

-- | The empty flag set: no bits set.
instance Monoid Flags where
    mempty  = Flags 0
    mappend = (<>)
-- | Write the flag byte with 'encodeByte'.
encodeFlags :: Putter Flags
encodeFlags = \(Flags w) -> encodeByte w
-- | Read the flag byte with 'decodeByte'.
decodeFlags :: Get Flags
decodeFlags = fmap Flags decodeByte
-- | Flag with only bit 0 set (value @0x01@), named for body compression.
compress :: Flags
compress = Flags (bit 0)
-- | Flag with only bit 1 set (value @0x02@), named for request tracing.
tracing :: Flags
tracing = Flags (bit 1)
-- | @isSet wanted actual@ is 'True' when every bit of @wanted@ is also set
-- in @actual@. It is vacuously 'True' when @wanted@ has no bits set.
isSet :: Flags -> Flags -> Bool
isSet (Flags wanted) (Flags actual) = wanted == wanted .&. actual