{-# LANGUAGE BangPatterns, CPP #-} -- | Parse bits easily. Parsing can be done either in a monadic style, or more -- efficiently, using the 'Applicative' style. -- -- For the monadic style, write your parser as a 'BitGet' monad using the -- -- * 'getBool' -- -- * 'getWord8' -- -- * 'getWord16be' -- -- * 'getWord32be' -- -- * 'getWord64be' -- -- * 'getByteString' -- -- functions and run it with 'runBitGet'. -- -- For the applicative style, compose the fuctions -- -- * 'bool' -- -- * 'word8' -- -- * 'word16be' -- -- * 'word32be' -- -- * 'word64be' -- -- * 'byteString' -- -- to make a 'Block'. -- Use 'block' to turn it into the 'BitGet' monad to be able to run it with -- 'runBitGet'. module Data.Binary.Bits.Get ( -- * BitGet monad -- $bitget BitGet , runBitGet -- ** Get bytes , getBool , getWord8 , getWord16be , getWord32be , getWord64be -- * Blocks -- $blocks , Block , block -- ** Read in Blocks , bool , word8 , word16be , word32be , word64be , byteString , Data.Binary.Bits.Get.getByteString , Data.Binary.Bits.Get.getLazyByteString , Data.Binary.Bits.Get.isEmpty ) where import qualified Control.Monad.Fail as Fail import Data.Binary.Get as B ( Get, getLazyByteString, isEmpty ) import Data.Binary.Get.Internal as B ( get, put, ensureN ) import Data.ByteString as B import qualified Data.ByteString.Lazy as L import Data.ByteString.Unsafe import Data.Bits import Data.Word import Control.Applicative import Prelude as P -- $bitget -- Parse bits using a monad. -- -- @ --myBitParser :: 'Get' ('Word8', 'Word8') --myBitParser = 'runGetBit' parse4by4 -- --parse4by4 :: 'BitGet' ('Word8', 'Word8') --parse4by4 = do -- bits <- 'getWord8' 4 -- more <- 'getWord8' 4 -- return (bits,more) -- @ -- $blocks -- Parse more efficiently in blocks. Each block is read with only one boundry -- check (checking that there is enough input) as the size of the block can be -- calculated statically. This is somewhat limiting as you cannot make the -- parsing depend on the input being parsed. -- -- @ --data IPV6Header = IPV6Header { -- ipv6Version :: 'Word8' -- , ipv6TrafficClass :: 'Word8' -- , ipv6FlowLabel :: 'Word32 -- , ipv6PayloadLength :: 'Word16' -- , ipv6NextHeader :: 'Word8' -- , ipv6HopLimit :: 'Word8' -- , ipv6SourceAddress :: 'ByteString' -- , ipv6DestinationAddress :: 'ByteString' -- } -- -- ipv6headerblock = -- IPV6Header '<$>' 'word8' 4 -- '<*>' 'word8' 8 -- '<*>' 'word32be' 24 -- '<*>' 'word16be' 16 -- '<*>' 'word8' 8 -- '<*>' 'word8' 8 -- '<*>' 'byteString' 16 -- '<*>' 'byteString' 16 -- --ipv6Header :: 'Get' IPV6Header --ipv6Header = 'runBitGet' ('block' ipv6headerblock) -- @ data S = S {-# UNPACK #-} !ByteString -- Input {-# UNPACK #-} !Int -- Bit offset (0-7) deriving (Show) -- | A block that will be read with only one boundry check. Needs to know the -- number of bits in advance. data Block a = Block Int (S -> a) instance Functor Block where fmap f (Block i p) = Block i (\s -> f (p s)) instance Applicative Block where pure a = Block 0 (\_ -> a) (Block i p) <*> (Block j q) = Block (i+j) (\s -> p s $ q (incS i s)) (Block i _) *> (Block j q) = Block (i+j) (q . incS i) (Block i p) <* (Block j _) = Block (i+j) p -- | Get a block. Will be read with one single boundry check, and -- therefore requires a statically known number of bits. -- Build blocks using 'bool', 'word8', 'word16be', 'word32be', 'word64be', -- 'byteString' and 'Applicative'. block :: Block a -> BitGet a block (Block i p) = do ensureBits i s <- getState putState $! (incS i s) return $! p s incS :: Int -> S -> S incS o (S bs n) = let !o' = (n+o) !d = o' `shiftR` 3 !n' = o' .&. make_mask 3 in S (unsafeDrop d bs) n' -- | make_mask 3 = 00000111 make_mask :: (Bits a, Num a) => Int -> a make_mask n = (1 `shiftL` fromIntegral n) - 1 {-# SPECIALIZE make_mask :: Int -> Int #-} {-# SPECIALIZE make_mask :: Int -> Word #-} {-# SPECIALIZE make_mask :: Int -> Word8 #-} {-# SPECIALIZE make_mask :: Int -> Word16 #-} {-# SPECIALIZE make_mask :: Int -> Word32 #-} {-# SPECIALIZE make_mask :: Int -> Word64 #-} bit_offset :: Int -> Int bit_offset n = make_mask 3 .&. n byte_offset :: Int -> Int byte_offset n = n `shiftR` 3 readBool :: S -> Bool readBool (S bs n) = testBit (unsafeHead bs) (7-n) {-# INLINE readWord8 #-} readWord8 :: Int -> S -> Word8 readWord8 n (S bs o) -- no bits at all, return 0 | n == 0 = 0 -- all bits are in the same byte -- we just need to shift and mask them right | n <= 8 - o = let w = unsafeHead bs m = make_mask n w' = (w `shiftr_w8` (8 - o - n)) .&. m in w' -- the bits are in two different bytes -- make a word16 using both bytes, and then shift and mask | n <= 8 = let w = (fromIntegral (unsafeHead bs) `shiftl_w16` 8) .|. (fromIntegral (unsafeIndex bs 1)) m = make_mask n w' = (w `shiftr_w16` (16 - o - n)) .&. m in fromIntegral w' | otherwise = error "readWord8: tried to read more than 8 bits" {-# INLINE readWord16be #-} readWord16be :: Int -> S -> Word16 readWord16be n s@(S bs o) -- 8 or fewer bits, use readWord8 | n <= 8 = fromIntegral (readWord8 n s) -- handle 9 or more bits, stored in two bytes -- no offset, plain and simple 16 bytes | o == 0 && n == 16 = let msb = fromIntegral (unsafeHead bs) lsb = fromIntegral (unsafeIndex bs 1) w = (msb `shiftl_w16` 8) .|. lsb in w -- no offset, but not full 16 bytes | o == 0 = let msb = fromIntegral (unsafeHead bs) lsb = fromIntegral (unsafeIndex bs 1) w = (msb `shiftl_w16` (n-8)) .|. (lsb `shiftr_w16` (16-n)) in w -- with offset, and n=9-16 | n <= 16 = readWithOffset s shiftl_w16 shiftr_w16 n | otherwise = error "readWord16be: tried to read more than 16 bits" {-# INLINE readWord32be #-} readWord32be :: Int -> S -> Word32 readWord32be n s@(S _ o) -- 8 or fewer bits, use readWord8 | n <= 8 = fromIntegral (readWord8 n s) -- 16 or fewer bits, use readWord16be | n <= 16 = fromIntegral (readWord16be n s) | o == 0 = readWithoutOffset s shiftl_w32 shiftr_w32 n | n <= 32 = readWithOffset s shiftl_w32 shiftr_w32 n | otherwise = error "readWord32be: tried to read more than 32 bits" {-# INLINE readWord64be #-} readWord64be :: Int -> S -> Word64 readWord64be n s@(S _ o) -- 8 or fewer bits, use readWord8 | n <= 8 = fromIntegral (readWord8 n s) -- 16 or fewer bits, use readWord16be | n <= 16 = fromIntegral (readWord16be n s) | o == 0 = readWithoutOffset s shiftl_w64 shiftr_w64 n | n <= 64 = readWithOffset s shiftl_w64 shiftr_w64 n | otherwise = error "readWord64be: tried to read more than 64 bits" readByteString :: Int -> S -> ByteString readByteString n s@(S bs o) -- no offset, easy. | o == 0 = unsafeTake n bs -- offset. ugg. this is really naive and slow. but also pretty easy :) | otherwise = B.pack (P.map (readWord8 8) (P.take n (iterate (incS 8) s))) readWithoutOffset :: (Bits a, Num a) => S -> (a -> Int -> a) -> (a -> Int -> a) -> Int -> a readWithoutOffset (S bs o) shifterL shifterR n | o /= 0 = error "readWithoutOffset: there is an offset" | bit_offset n == 0 && byte_offset n <= 4 = let segs = byte_offset n bn 0 = fromIntegral (unsafeHead bs) bn x = (bn (x-1) `shifterL` 8) .|. fromIntegral (unsafeIndex bs x) in bn (segs-1) | n <= 64 = let segs = byte_offset n o' = bit_offset (n - 8 + o) bn 0 = fromIntegral (unsafeHead bs) bn x = (bn (x-1) `shifterL` 8) .|. fromIntegral (unsafeIndex bs x) msegs = bn (segs-1) `shifterL` o' lst = (fromIntegral (unsafeIndex bs segs)) `shifterR` (8 - o') w = msegs .|. lst in w | otherwise = error "readWithoutOffset: tried to read more than 64 bits" readWithOffset :: (Bits a, Num a) => S -> (a -> Int -> a) -> (a -> Int -> a) -> Int -> a readWithOffset (S bs o) shifterL shifterR n | n <= 64 = let bits_in_msb = 8 - o (n',top) = (n - bits_in_msb , (fromIntegral (unsafeHead bs) .&. make_mask bits_in_msb) `shifterL` n') segs = byte_offset n' bn 0 = 0 bn x = (bn (x-1) `shifterL` 8) .|. fromIntegral (unsafeIndex bs x) o' = bit_offset n' mseg = bn segs `shifterL` o' lst | o' > 0 = (fromIntegral (unsafeIndex bs (segs + 1))) `shifterR` (8 - o') | otherwise = 0 w = top .|. mseg .|. lst in w | otherwise = error "readWithOffset: tried to read more than 64 bits" -- | 'BitGet' is a monad, applicative and a functor. See 'runBitGet' -- for how to run it. newtype BitGet a = B { runState :: S -> Get (S,a) } instance Monad BitGet where return = pure (B f) >>= g = B $ \s -> do (s',a) <- f s runState (g a) s' #if !MIN_VERSION_GLASGOW_HASKELL(8, 8, 1, 0) fail = Fail.fail #endif instance Fail.MonadFail BitGet where fail str = B $ \(S inp n) -> putBackState inp n >> fail str instance Functor BitGet where fmap f m = m >>= \a -> return (f a) instance Applicative BitGet where pure x = B $ \s -> return (s,x) fm <*> m = fm >>= \f -> m >>= \v -> return (f v) -- | Run a 'BitGet' within the Binary packages 'Get' monad. If a byte has -- been partially consumed it will be discarded once 'runBitGet' is finished. runBitGet :: BitGet a -> Get a runBitGet bg = do s <- mkInitState ((S str' n),a) <- runState bg s putBackState str' n return a mkInitState :: Get S mkInitState = do str <- get put B.empty return (S str 0) putBackState :: B.ByteString -> Int -> Get () putBackState bs n = do remaining <- get put (B.drop (if n==0 then 0 else 1) bs `B.append` remaining) getState :: BitGet S getState = B $ \s -> return (s,s) putState :: S -> BitGet () putState s = B $ \_ -> return (s,()) -- | Make sure there are at least @n@ bits. ensureBits :: Int -> BitGet () ensureBits n = do (S bs o) <- getState if n <= (B.length bs * 8 - o) then return () else do let currentBits = B.length bs * 8 - o let byteCount = (n - currentBits + 7) `div` 8 B $ \_ -> do B.ensureN byteCount bs' <- B.get put B.empty return (S (bs`append`bs') o, ()) -- | Get 1 bit as a 'Bool'. getBool :: BitGet Bool getBool = block bool -- | Get @n@ bits as a 'Word8'. @n@ must be within @[0..8]@. getWord8 :: Int -> BitGet Word8 getWord8 n = block (word8 n) -- | Get @n@ bits as a 'Word16'. @n@ must be within @[0..16]@. getWord16be :: Int -> BitGet Word16 getWord16be n = block (word16be n) -- | Get @n@ bits as a 'Word32'. @n@ must be within @[0..32]@. getWord32be :: Int -> BitGet Word32 getWord32be n = block (word32be n) -- | Get @n@ bits as a 'Word64'. @n@ must be within @[0..64]@. getWord64be :: Int -> BitGet Word64 getWord64be n = block (word64be n) -- | Get @n@ bytes as a 'ByteString'. getByteString :: Int -> BitGet ByteString getByteString n = block (byteString n) -- | Get @n@ bytes as a lazy ByteString. getLazyByteString :: Int -> BitGet L.ByteString getLazyByteString n = do (S _ o) <- getState case o of 0 -> B $ \ (S bs o') -> do putBackState bs o' lbs <- B.getLazyByteString (fromIntegral n) return (S B.empty 0, lbs) _ -> L.fromChunks . (:[]) <$> Data.Binary.Bits.Get.getByteString n -- | Test whether all input has been consumed, i.e. there are no remaining -- undecoded bytes. isEmpty :: BitGet Bool isEmpty = B $ \ (S bs o) -> if B.null bs then B.isEmpty >>= \e -> return (S bs o, e) else return (S bs o, False) -- | Read a 1 bit 'Bool'. bool :: Block Bool bool = Block 1 readBool -- | Read @n@ bits as a 'Word8'. @n@ must be within @[0..8]@. word8 :: Int -> Block Word8 word8 n = Block n (readWord8 n) -- | Read @n@ bits as a 'Word16'. @n@ must be within @[0..16]@. word16be :: Int -> Block Word16 word16be n = Block n (readWord16be n) -- | Read @n@ bits as a 'Word32'. @n@ must be within @[0..32]@. word32be :: Int -> Block Word32 word32be n = Block n (readWord32be n) -- | Read @n@ bits as a 'Word64'. @n@ must be within @[0..64]@. word64be :: Int -> Block Word64 word64be n = Block n (readWord64be n) -- | Read @n@ bytes as a 'ByteString'. byteString :: Int -> Block ByteString byteString n | n > 0 = Block (n*8) (readByteString n) | otherwise = Block 0 (\_ -> B.empty) -- Unchecked shifts, from the package binary shiftl_w16 :: Word16 -> Int -> Word16 shiftl_w32 :: Word32 -> Int -> Word32 shiftl_w64 :: Word64 -> Int -> Word64 shiftr_w8 :: Word8 -> Int -> Word8 shiftr_w16 :: Word16 -> Int -> Word16 shiftr_w32 :: Word32 -> Int -> Word32 shiftr_w64 :: Word64 -> Int -> Word64 shiftl_w16 = unsafeShiftL shiftl_w32 = unsafeShiftL shiftl_w64 = unsafeShiftL shiftr_w8 = unsafeShiftR shiftr_w16 = unsafeShiftR shiftr_w32 = unsafeShiftR shiftr_w64 = unsafeShiftR