{-# LANGUAGE OverloadedStrings #-}

module Network.HPACK.HeaderBlock.Integer (
    encodeI,
    encodeInteger,
    decodeI,
    decodeInteger,
) where

import Data.Array (Array, listArray)
import Data.Array.Base (unsafeAt)
import Network.ByteOrder

import Imports

-- $setup
-- >>> import qualified Data.ByteString as BS

-- | The largest value an N-bit prefix can hold, @2^N - 1@, for N = 1..8.
--   The array is indexed by N; the encoder and decoder in this module read
--   it with 'unsafeAt' at the zero-based offset @N - 1@.
powerArray :: Array Int Int
powerArray = listArray (1, 8) [2 ^ k - 1 | k <- [1 .. 8 :: Int]]

----------------------------------------------------------------

{-
if I < 2^N - 1, encode I on N bits
   else
       encode (2^N - 1) on N bits
       I = I - (2^N - 1)
       while I >= 128
            encode (I % 128 + 128) on 8 bits
            I = I / 128
       encode I on 8 bits
-}

-- | Encoding integer with a temporary buffer whose size is 4096.
--   No prefix is set: the first octet is written unchanged (the prefix
--   setter passed to 'encodeI' is 'id').
--
-- >>> BS.unpack <$> encodeInteger 5 10
-- [10]
-- >>> BS.unpack <$> encodeInteger 5 1337
-- [31,154,10]
-- >>> BS.unpack <$> encodeInteger 8 42
-- [42]
encodeInteger
    :: Int
    -- ^ N+
    -> Int
    -- ^ Target
    -> IO ByteString
encodeInteger n i = withWriteBuffer bufSize (\wbuf -> encodeI wbuf id n i)
  where
    -- Size of the scratch buffer handed to 'withWriteBuffer'.
    bufSize :: Int
    bufSize = 4096

-- Using write8 is faster than using internals directly.
--

{-# INLINEABLE encodeI #-}

-- | Integer encoding with a write buffer.
--
--   The first octet written (the N-bit prefix) is passed through the
--   supplied setter, presumably so the caller can fill in the bits that
--   share that octet with the prefix. If the target fits below @2^N - 1@
--   it is written in that single octet. Otherwise the prefix is saturated
--   to @2^N - 1@ and the remainder follows as 7-bit groups,
--   least-significant first, with the high bit set on every group except
--   the last.
encodeI
    :: WriteBuffer
    -> (Word8 -> Word8)
    -- ^ Setting prefix
    -> Int
    -- ^ N+
    -> Int
    -- ^ Target
    -> IO ()
encodeI wbuf set n i
    | i >= limit = do
        writeFirst limit
        continuation (i - limit)
    | otherwise = writeFirst i
  where
    -- 2^N - 1. NOTE: 'unsafeAt' does no bounds check, so N must be in 1..8.
    limit :: Int
    limit = powerArray `unsafeAt` (n - 1)

    -- Emit the prefix octet, letting the caller's setter adjust it.
    writeFirst :: Int -> IO ()
    writeFirst = write8 wbuf . set . fromIntegral

    -- Emit the remainder as 7-bit groups; the high bit marks "more follows".
    continuation :: Int -> IO ()
    continuation rest
        | rest >= 128 = do
            write8 wbuf (fromIntegral ((rest .&. 0x7f) + 128))
            continuation (rest `shiftR` 7)
        | otherwise = write8 wbuf (fromIntegral rest)

----------------------------------------------------------------

{-
decode I from the next N bits
   if I < 2^N - 1, return I
   else
       M = 0
       repeat
           B = next octet
           I = I + (B & 127) * 2^M
           M = M + 7
       while B & 128 == 128
       return I
-}

-- | Integer decoding. The first argument is N of prefix.
--   The tail octets are wrapped in a read buffer and decoding is delegated
--   to 'decodeI'.
--
-- >>> decodeInteger 5 10 $ BS.empty
-- 10
-- >>> decodeInteger 5 31 $ BS.pack [154,10]
-- 1337
-- >>> decodeInteger 8 42 $ BS.empty
-- 42
decodeInteger
    :: Int
    -- ^ N+
    -> Word8
    -- ^ The head of encoded integer whose prefix is already dropped
    -> ByteString
    -- ^ The tail of encoded integer
    -> IO Int
decodeInteger n w bs = withReadBuffer bs (decodeI n w)

{-# INLINEABLE decodeI #-}

-- | Integer decoding with a read buffer. The first argument is N of prefix.
--
--   If the N-bit prefix value is below @2^N - 1@ it is the integer itself
--   and nothing is read from the buffer. Otherwise continuation octets are
--   read, each contributing its low 7 bits (least-significant group first),
--   until an octet with the high bit clear ends the sequence.
--
--   RFC 7541 (Section 5.1) requires integers that exceed implementation
--   limits to be treated as decoding errors. Two cases raise an 'IOError'
--   instead of silently wrapping around:
--
--   * the value would not fit in a non-negative 'Int';
--   * the continuation sequence runs past the width of 'Int'.
decodeI
    :: Int
    -- ^ N+
    -> Word8
    -- ^ The head of encoded integer whose prefix is already dropped
    -> ReadBuffer
    -> IO Int
decodeI n w rbuf
    | i < p = return i
    | otherwise = decode 0 i
  where
    -- 2^N - 1. NOTE: 'unsafeAt' does no bounds check, so N must be in 1..8.
    p :: Int
    p = powerArray `unsafeAt` (n - 1)
    i :: Int
    i = fromIntegral w
    -- Beyond this shift an octet cannot contribute any bits to an 'Int',
    -- so another continuation octet means the encoding is too long.
    maxShift :: Int
    maxShift = finiteBitSize (0 :: Int) - 1
    decode :: Int -> Int -> IO Int
    decode m j
        | m > maxShift = overflow
        | otherwise = do
            b <- fromIntegral <$> read8 rbuf
            let payload, v, j' :: Int
                payload = b .&. 0x7f
                v = payload `shiftL` m
                j' = j + v
            -- Bits lost by the shift (including into the sign bit), or a sum
            -- that wrapped negative, means the value does not fit in an Int.
            if v `shiftR` m /= payload || j' < 0
                then overflow
                else
                    if b `testBit` 7
                        then decode (m + 7) j'
                        else return j'
    overflow :: IO Int
    overflow = ioError (userError "decodeI: integer exceeds Int range")