module Database.CQL.Protocol.Header ( Header (..) , HeaderType (..) , header , encodeHeader , decodeHeader -- ** Length , Length (..) , encodeLength , decodeLength -- ** StreamId , StreamId , mkStreamId , fromStreamId , encodeStreamId , decodeStreamId -- ** Flags , Flags , compress , customPayload , tracing , warning , isSet , encodeFlags , decodeFlags ) where import Control.Applicative import Data.Bits import Data.ByteString.Lazy (ByteString) import Data.Int import Data.Monoid hiding ((<>)) import Data.Semigroup import Data.Serialize import Data.Word import Database.CQL.Protocol.Codec import Database.CQL.Protocol.Types import Prelude -- | Protocol frame header. data Header = Header { headerType :: !HeaderType , version :: !Version , flags :: !Flags , streamId :: !StreamId , opCode :: !OpCode , bodyLength :: !Length } deriving Show data HeaderType = RqHeader -- ^ A request frame header. | RsHeader -- ^ A response frame header. deriving Show encodeHeader :: Version -> HeaderType -> Flags -> StreamId -> OpCode -> Length -> PutM () encodeHeader v t f i o l = do encodeByte $ case t of RqHeader -> mapVersion v RsHeader -> mapVersion v `setBit` 7 encodeFlags f encodeStreamId v i encodeOpCode o encodeLength l decodeHeader :: Version -> Get Header decodeHeader v = do b <- getWord8 Header (mapHeaderType b) <$> toVersion (b .&. 0x7F) <*> decodeFlags <*> decodeStreamId v <*> decodeOpCode <*> decodeLength mapHeaderType :: Word8 -> HeaderType mapHeaderType b = if b `testBit` 7 then RsHeader else RqHeader -- | Deserialise a frame header using the version specific decoding format. header :: Version -> ByteString -> Either String Header header v = runGetLazy (decodeHeader v) ------------------------------------------------------------------------------ -- Version mapVersion :: Version -> Word8 mapVersion V4 = 4 mapVersion V3 = 3 toVersion :: Word8 -> Get Version toVersion 3 = return V3 toVersion 4 = return V4 toVersion w = fail $ "decode-version: unknown: " ++ show w ------------------------------------------------------------------------------ -- Length -- | The type denoting a protocol frame length. newtype Length = Length { lengthRepr :: Int32 } deriving (Eq, Show) encodeLength :: Putter Length encodeLength (Length x) = encodeInt x decodeLength :: Get Length decodeLength = Length <$> decodeInt ------------------------------------------------------------------------------ -- StreamId -- | Streams allow multiplexing of requests over a single communication -- channel. The 'StreamId' correlates 'Request's with 'Response's. newtype StreamId = StreamId Int16 deriving (Eq, Show) -- | Create a StreamId from the given integral value. In version 2, -- a StreamId is an 'Int8' and in version 3 an 'Int16'. mkStreamId :: Integral i => i -> StreamId mkStreamId = StreamId . fromIntegral -- | Convert the stream ID to an integer. fromStreamId :: StreamId -> Int fromStreamId (StreamId i) = fromIntegral i encodeStreamId :: Version -> Putter StreamId encodeStreamId V4 (StreamId x) = encodeSignedShort (fromIntegral x) encodeStreamId V3 (StreamId x) = encodeSignedShort (fromIntegral x) decodeStreamId :: Version -> Get StreamId decodeStreamId V4 = StreamId <$> decodeSignedShort decodeStreamId V3 = StreamId <$> decodeSignedShort ------------------------------------------------------------------------------ -- Flags -- | Type representing header flags. Flags form a monoid and can be used -- as in @compress <> tracing <> mempty@. newtype Flags = Flags Word8 deriving (Eq, Show) instance Semigroup Flags where (Flags a) <> (Flags b) = Flags (a .|. b) instance Monoid Flags where mempty = Flags 0 mappend = (<>) encodeFlags :: Putter Flags encodeFlags (Flags x) = encodeByte x decodeFlags :: Get Flags decodeFlags = Flags <$> decodeByte -- | Compression flag. If set, the frame body is compressed. compress :: Flags compress = Flags 1 -- | Tracing flag. If a request support tracing and the tracing flag was set, -- the response to this request will have the tracing flag set and contain -- tracing information. tracing :: Flags tracing = Flags 2 customPayload :: Flags customPayload = Flags 4 warning :: Flags warning = Flags 8 -- | Check if a particular flag is present. isSet :: Flags -> Flags -> Bool isSet (Flags a) (Flags b) = a .&. b == a