{-# LANGUAGE BangPatterns #-}

-- | Parse bits easily. Parsing can be done either in a monadic style, or more
-- efficiently, using the 'Applicative' style.
--
-- For the monadic style, write your parser as a 'BitGet' monad using the
--
--   * 'getBool'
--
--   * 'getWord8'
--
--   * 'getWord16be'
--
--   * 'getWord32be'
--
--   * 'getWord64be'
--
--   * 'getByteString'
--
-- functions and run it with 'runBitGet'.
--
-- For the applicative style, compose the functions
--
--   * 'bool'
--
--   * 'word8'
--
--   * 'word16be'
--
--   * 'word32be'
--
--   * 'word64be'
--
--   * 'byteString'
--
-- to make a 'Block'.
-- Use 'block' to turn it into the 'BitGet' monad to be able to run it with
-- 'runBitGet'.

module Data.Binary.Bits.Get
  ( BitGet
  , runBitGet

            -- ** Get bytes
  , getBool
  , getWord8
  , getWord16be
  , getWord32be
  , getWord64be

            -- * Blocks

            -- $blocks
  , Block
  , block

            -- ** Read in Blocks
  , bool
  , word8
  , word16be
  , word32be
  , word64be
  , byteString
  , Data.Binary.Bits.Get.getByteString
  , Data.Binary.Bits.Get.getLazyByteString
  , Data.Binary.Bits.Get.isEmpty
  ) where

import qualified Control.Monad.Fail as Fail

import Data.Binary.Get as B (Get, getLazyByteString, isEmpty)
import Data.Binary.Get.Internal as B (ensureN, get, put)

import Data.ByteString as B
import qualified Data.ByteString.Lazy as L
import Data.ByteString.Unsafe

import Control.Applicative as Appl
import Data.Bits
import Data.Word

import Prelude as P


-- $bitget
-- Parse bits using a monad.
--
-- @
--myBitParser :: 'Get' ('Word8', 'Word8')
--myBitParser = 'runBitGet' parse4by4
--
--parse4by4 :: 'BitGet' ('Word8', 'Word8')
--parse4by4 = do
--   bits <- 'getWord8' 4
--   more <- 'getWord8' 4
--   return (bits,more)
-- @

-- $blocks
-- Parse more efficiently in blocks. Each block is read with only one boundary
-- check (checking that there is enough input) as the size of the block can be
-- calculated statically. This is somewhat limiting as you cannot make the
-- parsing depend on the input being parsed.
--
-- @
--data IPV6Header = IPV6Header {
--     ipv6Version :: 'Word8'
--   , ipv6TrafficClass :: 'Word8'
--   , ipv6FlowLabel :: 'Word32'
--   , ipv6PayloadLength :: 'Word16'
--   , ipv6NextHeader :: 'Word8'
--   , ipv6HopLimit :: 'Word8'
--   , ipv6SourceAddress :: 'ByteString'
--   , ipv6DestinationAddress :: 'ByteString'
-- }
--
-- ipv6headerblock =
--         IPV6Header '<$>' 'word8' 4
--                    '<*>' 'word8' 8
--                    '<*>' 'word32be' 20
--                    '<*>' 'word16be' 16
--                    '<*>' 'word8' 8
--                    '<*>' 'word8' 8
--                    '<*>' 'byteString' 16
--                    '<*>' 'byteString' 16
--
--ipv6Header :: 'Get' IPV6Header
--ipv6Header = 'runBitGet' ('block' ipv6headerblock)
-- @

-- | Internal parser state: the input that has not been fully consumed yet,
-- together with how many bits (0-7) of its first byte are already used.
data S = S
  {-# UNPACK #-} !ByteString -- remaining input
  {-# UNPACK #-} !Int        -- bit offset into the first byte (0-7)
  deriving (Show)

-- | A group of reads performed with a single boundary check. The total
-- number of bits the group consumes has to be known before reading starts.
data Block a = Block
  Int      -- total number of bits consumed by the block
  (S -> a) -- reader applied to the state at the start of the block

instance Functor Block where
  -- Post-compose the reader; the number of bits consumed is unchanged.
  fmap f (Block bits reader) = Block bits (\st -> f (reader st))

instance Applicative Block where
  -- Consumes no bits and ignores the state.
  pure x = Block 0 (\_ -> x)

  -- The right-hand reader sees the state advanced past the left-hand
  -- block's bits; the bit counts add up.
  Block n readF <*> Block m readX =
    Block (n + m) (\st -> readF st (readX (incS n st)))

  -- Both blocks are consumed; only the right-hand result is kept.
  Block n _ *> Block m readY = Block (n + m) (\st -> readY (incS n st))

  -- Both blocks are consumed; only the left-hand result is kept.
  Block n readX <* Block m _ = Block (n + m) readX

-- | Run a 'Block' inside 'BitGet'. The whole block is covered by one
-- boundary check, which is why its size in bits must be known statically.
-- Build blocks from 'bool', 'word8', 'word16be', 'word32be', 'word64be',
-- 'byteString' and the 'Applicative' combinators.
block :: Block a -> BitGet a
block (Block bits reader) =
  ensureBits bits >> getState >>= \st -> do
    -- Advance past the block before handing back the (strict) result.
    putState $! incS bits st
    return $! reader st

-- | Advance the state by the given number of bits. Fully consumed bytes are
-- dropped from the input and the remainder becomes the new bit offset.
incS :: Int -> S -> S
incS bits (S input offset) = S (unsafeDrop whole input) partial
  where
    !absolute = offset + bits
    !whole = byteOffset absolute
    !partial = bitOffset absolute

-- | A value whose lowest @n@ bits are set and all others clear,
-- e.g. @makeMask 3 == 0b00000111@.
makeMask :: (Bits a, Num a) => Int -> a
makeMask n = (1 `shiftL` n) - 1
{-# SPECIALIZE makeMask :: Int -> Int #-}
{-# SPECIALIZE makeMask :: Int -> Word #-}
{-# SPECIALIZE makeMask :: Int -> Word8 #-}
{-# SPECIALIZE makeMask :: Int -> Word16 #-}
{-# SPECIALIZE makeMask :: Int -> Word32 #-}
{-# SPECIALIZE makeMask :: Int -> Word64 #-}

-- | Position within its byte of an absolute bit index (its low three bits).
bitOffset :: Int -> Int
bitOffset n = n .&. 7

-- | Number of whole bytes spanned by an absolute bit index (an arithmetic
-- right shift by three).
byteOffset :: Int -> Int
byteOffset = (`shiftR` 3)

-- | Read the bit at the current offset. Bits are numbered from the most
-- significant end of the byte, so offset 0 is bit 7.
readBool :: S -> Bool
readBool (S input offset) = unsafeHead input `testBit` (7 - offset)

-- | Read @n@ bits (0-8) starting at the current bit offset, right-aligned in
-- a 'Word8'. No bounds check is made; the bytes must already be present.
readWord8 :: Int -> S -> Word8
{-# INLINE readWord8 #-}
readWord8 n (S bs o)
  -- Nothing requested.
  | n == 0 = 0
  -- Every requested bit is in the first byte: shift down, then mask.
  | n <= 8 - o = (unsafeHead bs `shiftr_w8` (8 - o - n)) .&. makeMask n
  -- The bits straddle two bytes: join them into a Word16, shift, mask.
  | n <= 8 =
      let hi = fromIntegral (unsafeHead bs) :: Word16
          lo = fromIntegral (unsafeIndex bs 1)
          joined = (hi `shiftl_w16` 8) .|. lo
      in fromIntegral ((joined `shiftr_w16` (16 - o - n)) .&. makeMask n)
  | otherwise = error "readWord8: tried to read more than 8 bits"

-- | Read @n@ bits (0-16) starting at the current bit offset as a big-endian
-- 'Word16'. Calls 'error' when more than 16 bits are requested.
readWord16be :: Int -> S -> Word16
{-# INLINE readWord16be #-}
readWord16be n s@(S bs o)
  -- 8 or fewer bits: readWord8 handles it.
  | n <= 8 = fromIntegral (readWord8 n s)
  -- Must be checked before the aligned branch: that branch shifts by
  -- (16 - n), which is negative for n > 16 and undefined for unchecked shifts.
  | n > 16 = error "readWord16be: tried to read more than 16 bits"
  -- Byte-aligned, 9-16 bits: the whole first byte followed by the top
  -- (n - 8) bits of the second byte (all 8 when n == 16).
  | o == 0 =
      let msb = fromIntegral (unsafeHead bs)
          lsb = fromIntegral (unsafeIndex bs 1)
      in (msb `shiftl_w16` (n - 8)) .|. (lsb `shiftr_w16` (16 - n))
  -- Unaligned, 9-16 bits.
  | otherwise = readWithOffset s shiftl_w16 shiftr_w16 n

-- | Read @n@ bits (0-32) starting at the current bit offset as a big-endian
-- 'Word32'. Calls 'error' when more than 32 bits are requested.
readWord32be :: Int -> S -> Word32
{-# INLINE readWord32be #-}
readWord32be n s@(S _ o)
  -- 8 or fewer bits, use readWord8
  | n <= 8 = fromIntegral (readWord8 n s)
  -- 16 or fewer bits, use readWord16be
  | n <= 16 = fromIntegral (readWord16be n s)
  -- Checked before the aligned path: readWithoutOffset accepts up to 64 bits
  -- and would otherwise silently truncate an oversized read to 32 bits.
  | n > 32 = error "readWord32be: tried to read more than 32 bits"
  | o == 0 = readWithoutOffset s shiftl_w32 shiftr_w32 n
  | otherwise = readWithOffset s shiftl_w32 shiftr_w32 n


-- | Read @n@ bits (0-64) starting at the current bit offset as a big-endian
-- 'Word64'. Calls 'error' when more than 64 bits are requested.
readWord64be :: Int -> S -> Word64
{-# INLINE readWord64be #-}
readWord64be n s@(S _ o)
  -- 8 or fewer bits, use readWord8
  | n <= 8 = fromIntegral (readWord8 n s)
  -- 16 or fewer bits, use readWord16be
  | n <= 16 = fromIntegral (readWord16be n s)
  -- Checked up front so oversized reads report this function's error
  -- regardless of the current offset.
  | n > 64 = error "readWord64be: tried to read more than 64 bits"
  | o == 0 = readWithoutOffset s shiftl_w64 shiftr_w64 n
  | otherwise = readWithOffset s shiftl_w64 shiftr_w64 n


-- | Read @n@ bytes starting at the current bit offset.
readByteString :: Int -> S -> ByteString
readByteString n s@(S bs o)
  -- Byte-aligned: share the bytes with the input.
  | o == 0 = unsafeTake n bs
  -- Unaligned: every output byte is assembled from two input bytes. The
  -- bytes are written straight into the result buffer instead of going
  -- through an intermediate list and 'B.pack'.
  | otherwise =
      fst (B.unfoldrN n (\st -> Just (readWord8 8 st, incS 8 st)) s)

-- | Read @n@ bits (at most 64) as a big-endian value when the current bit
-- offset is 0, using the given left/right shifts for the result type.
-- Exactly @ceiling (n / 8)@ input bytes are touched: no byte beyond those
-- guaranteed by 'ensureBits' is read.
readWithoutOffset
  :: (Bits a, Num a) => S -> (a -> Int -> a) -> (a -> Int -> a) -> Int -> a
readWithoutOffset (S bs o) shifterL shifterR n
  | o /= 0
  = error "readWithoutOffset: there is an offset"
  | n > 64
  = error "readWithoutOffset: tried to read more than 64 bits"
  -- A whole number of bytes: nothing beyond them may be touched. (The
  -- previous version limited this branch to 4 bytes, so 40/48/56/64-bit
  -- reads indexed one byte past the ensured input.)
  | extraBits == 0
  = wholeBytes
  -- Whole bytes followed by the top 'extraBits' bits of the next byte.
  | otherwise
  = (wholeBytes `shifterL` extraBits)
      .|. (fromIntegral (unsafeIndex bs segs) `shifterR` (8 - extraBits))
  where
    segs = byteOffset n
    extraBits = bitOffset n
    -- Big-endian concatenation of the first 'segs' bytes (0 when segs == 0).
    wholeBytes =
      B.foldl' (\acc w -> (acc `shifterL` 8) .|. fromIntegral w) 0
               (unsafeTake segs bs)

-- | Read @n@ bits (at most 64) as a big-endian value starting part-way into
-- the first byte, using the given left/right shifts for the result type.
readWithOffset
  :: (Bits a, Num a) => S -> (a -> Int -> a) -> (a -> Int -> a) -> Int -> a
readWithOffset (S bs o) shifterL shifterR n
  | n > 64 = error "readWithOffset: tried to read more than 64 bits"
  | otherwise = top .|. middle .|. trailer
  where
    -- Unread bits remaining in the first byte.
    firstBits = 8 - o
    -- Bits that have to come from the following bytes.
    remaining = n - firstBits
    top = (fromIntegral (unsafeHead bs) .&. makeMask firstBits)
            `shifterL` remaining
    -- Whole bytes after the first one, then leftover bits of one more byte.
    fullBytes = byteOffset remaining
    extraBits = bitOffset remaining
    middle = gather fullBytes `shifterL` extraBits
    -- Big-endian concatenation of bytes 1 .. i.
    gather 0 = 0
    gather i = (gather (i - 1) `shifterL` 8) .|. fromIntegral (unsafeIndex bs i)
    trailer
      | extraBits > 0 =
          fromIntegral (unsafeIndex bs (fullBytes + 1)) `shifterR` (8 - extraBits)
      | otherwise = 0

-- | A parser that reads individual bits. It is a 'Monad', an 'Applicative'
-- and a 'Functor'; run it with 'runBitGet'.
--
-- $bitget
newtype BitGet a = B { runState :: S -> Get (S, a) }

instance Monad BitGet where
  -- Thread the bit state from the first computation into the second.
  -- 'return' is left to its default, 'pure'.
  m >>= k = B $ \s -> runState m s >>= \(s', x) -> runState (k x) s'

instance Fail.MonadFail BitGet where
  -- Give the unconsumed input back to the underlying 'Get' (via
  -- 'putBackState') and then fail there with the same message.
  fail msg = B $ \(S inp off) -> do
    putBackState inp off
    Fail.fail msg

instance Functor BitGet where
  -- Defined through the monad instance.
  fmap f m = m >>= return . f

instance Applicative BitGet where
  -- Leave the state untouched and yield the value.
  pure x = B (\s -> return (s, x))
  -- Sequence through the monad: the function first, then its argument.
  mf <*> mx = do
    f <- mf
    x <- mx
    return (f x)

instance Alternative BitGet where
  -- Fail in the underlying 'Get'.
  empty = B (\_ -> Appl.empty)
  -- Both alternatives start from the same bit state; choosing between them
  -- is delegated to the 'Alternative' instance of 'Get'.
  B p <|> B q = B (\s -> p s <|> q s)

-- | Run a 'BitGet' inside the binary package's 'Get' monad. Bits left over
-- from a partially consumed byte are discarded when 'runBitGet' finishes.
runBitGet :: BitGet a -> Get a
runBitGet bg = do
  (S rest off, result) <- mkInitState >>= runState bg
  putBackState rest off
  return result

-- Take everything currently buffered in 'Get' as the initial bit state (at
-- offset 0) and leave the 'Get' buffer empty.
mkInitState :: Get S
mkInitState = do
  buffered <- get
  put B.empty
  pure (S buffered 0)

-- Put unconsumed input back in front of whatever 'Get' currently buffers.
-- A partially consumed first byte (non-zero offset) is dropped.
putBackState :: B.ByteString -> Int -> Get ()
putBackState bs off = do
  buffered <- get
  let unread
        | off == 0 = bs
        | otherwise = B.drop 1 bs
  put (unread `B.append` buffered)

-- Read the current bit state without changing it.
getState :: BitGet S
getState = B (\s -> pure (s, s))

-- Replace the current bit state.
putState :: S -> BitGet ()
putState s = B (\_ -> pure (s, ()))

-- | Make sure at least @n@ bits are available in the bit state. If they are
-- not, more whole bytes are requested from the underlying 'Get' with
-- 'B.ensureN' and appended to the state.
ensureBits :: Int -> BitGet ()
ensureBits n = do
  S bs o <- getState
  let available = B.length bs * 8 - o
  if n <= available
    then return ()
    else B $ \_ -> do
      -- Just enough whole bytes to cover the shortfall.
      B.ensureN ((n - available + 7) `div` 8)
      more <- B.get
      put B.empty
      return (S (bs `B.append` more) o, ())

-- | Read a single bit as a 'Bool' (a one-bit 'block').
getBool :: BitGet Bool
getBool = block bool

-- | Read @n@ bits as a 'Word8'; @n@ must be in @[0..8]@.
getWord8 :: Int -> BitGet Word8
getWord8 = block . word8

-- | Read @n@ bits as a big-endian 'Word16'; @n@ must be in @[0..16]@.
getWord16be :: Int -> BitGet Word16
getWord16be = block . word16be

-- | Read @n@ bits as a big-endian 'Word32'; @n@ must be in @[0..32]@.
getWord32be :: Int -> BitGet Word32
getWord32be = block . word32be

-- | Read @n@ bits as a big-endian 'Word64'; @n@ must be in @[0..64]@.
getWord64be :: Int -> BitGet Word64
getWord64be = block . word64be

-- | Read @n@ bytes (starting at the current bit offset) as a 'ByteString'.
getByteString :: Int -> BitGet ByteString
getByteString = block . byteString

-- | Read @n@ bytes as a lazy ByteString. When byte-aligned, the input is
-- handed back to 'Get' and read there; otherwise the bytes are read as a
-- strict 'ByteString' and wrapped.
getLazyByteString :: Int -> BitGet L.ByteString
getLazyByteString n = do
  S _ off <- getState
  if off == 0
    then B $ \(S bs o) -> do
      putBackState bs o
      lbs <- B.getLazyByteString (fromIntegral n)
      return (S B.empty 0, lbs)
    else fmap (\strict -> L.fromChunks [strict])
              (Data.Binary.Bits.Get.getByteString n)

-- | Test whether all input has been consumed, i.e. there are no remaining
-- undecoded bytes. Only when the bit state holds no bytes is the underlying
-- 'Get' consulted.
isEmpty :: BitGet Bool
isEmpty = B $ \s@(S bs _) ->
  if B.null bs
    then (\e -> (s, e)) <$> B.isEmpty
    else return (s, False)

-- | A one-bit block read as a 'Bool'.
bool :: Block Bool
bool = Block 1 readBool

-- | An @n@-bit block read as a 'Word8'; @n@ must be in @[0..8]@.
word8 :: Int -> Block Word8
word8 n = Block n $ readWord8 n

-- | An @n@-bit block read as a big-endian 'Word16'; @n@ must be in
-- @[0..16]@.
word16be :: Int -> Block Word16
word16be n = Block n $ readWord16be n

-- | An @n@-bit block read as a big-endian 'Word32'; @n@ must be in
-- @[0..32]@.
word32be :: Int -> Block Word32
word32be n = Block n $ readWord32be n

-- | An @n@-bit block read as a big-endian 'Word64'; @n@ must be in
-- @[0..64]@.
word64be :: Int -> Block Word64
word64be n = Block n $ readWord64be n

-- | An @n@-byte block read as a 'ByteString'. A non-positive @n@ consumes
-- nothing and yields the empty string.
byteString :: Int -> Block ByteString
byteString n
  | n <= 0 = Block 0 (const B.empty)
  | otherwise = Block (8 * n) (readByteString n)

-- Unchecked shifts, as used in the binary package. Per Data.Bits,
-- 'unsafeShiftL' / 'unsafeShiftR' require a shift amount that is
-- non-negative and smaller than the width of the type.

shiftr_w8 :: Word8 -> Int -> Word8
shiftr_w8 = unsafeShiftR

shiftl_w16, shiftr_w16 :: Word16 -> Int -> Word16
shiftl_w16 = unsafeShiftL
shiftr_w16 = unsafeShiftR

shiftl_w32, shiftr_w32 :: Word32 -> Int -> Word32
shiftl_w32 = unsafeShiftL
shiftr_w32 = unsafeShiftR

shiftl_w64, shiftr_w64 :: Word64 -> Int -> Word64
shiftl_w64 = unsafeShiftL
shiftr_w64 = unsafeShiftR