{-# LANGUAGE BangPatterns #-}
{-# LANGUAGE AllowAmbiguousTypes #-}
{-# LANGUAGE TypeApplications #-}
{-# LANGUAGE ScopedTypeVariables #-}
{-# OPTIONS_GHC -Wno-dodgy-imports #-}
{-# OPTIONS_GHC -O2 #-}
-- {-# OPTIONS_GHC -ddump-simpl -dsuppress-unfoldings -dsuppress-idinfo -dsuppress-module-prefixes -ddump-to-file #-}

-- |
-- Module      : Data.Serialize.LEB128
-- Description : LEB128 encoding
-- License     : MIT
-- Maintainer  : Joachim Breitner
--
-- This module implements encoding and decoding of 'Natural' and 'Integer'
-- values according to LEB128 and SLEB128. See
-- https://en.wikipedia.org/wiki/LEB128 for a specification.
--
-- The module provides conversion to and from strict bytestrings.
--
-- Additionally, to integrate these into your own parsers and serializers, you
-- can use the interfaces based on 'B.Builder' as well as @cereal@'s 'G.Get'
-- and 'P.Put' monad.
--
-- The decoders will fail if the input is not in canonical representation,
-- i.e. longer than necessary.
-- Use "Data.Serialize.LEB128.Lenient" if you need to accept such non-canonical input.
--
-- This code is inspired by Andreas Klebinger's LEB128 implementation in GHC.
module Data.Serialize.LEB128
    (
    -- * The class of encodable and decodable types
      LEB128
    , SLEB128
    -- * Bytestring-based interface
    , toLEB128
    , fromLEB128
    , toSLEB128
    , fromSLEB128
    -- * Builder interface
    , buildLEB128
    , buildSLEB128
    -- * Cereal interface
    , getLEB128
    , getSLEB128
    , putLEB128
    , putSLEB128
    ) where

import qualified Data.ByteString as BS
import qualified Data.ByteString.Lazy as BSL
import qualified Data.ByteString.Builder as B
import qualified Data.ByteString.Builder.Extra as B
import qualified Data.Serialize.Get as G
import qualified Data.Serialize.Put as P
import Numeric.Natural
import Control.Applicative
import Control.Monad
import Data.Bits
import Data.Word
import Data.Int
import Data.Maybe
import Data.Monoid ((<>))
import Prelude hiding ((<>))

-- | The class of unsigned number types that can be LEB128-encoded.
--
-- The class has no methods of its own; it only marks which types the
-- LEB128 functions accept, and provides the 'Bits', 'Num' and 'Integral'
-- operations through its superclasses.
class (Bits a, Num a, Integral a) => LEB128 a
instance LEB128 Word8
instance LEB128 Word16
instance LEB128 Word32
instance LEB128 Word64
instance LEB128 Word
instance LEB128 Natural

-- | The class of signed number types that can be SLEB128-encoded.
--
-- The class has no methods of its own; it only marks which types the
-- SLEB128 functions accept, and provides the 'Bits', 'Num' and 'Integral'
-- operations through its superclasses.
class (Bits a, Num a, Integral a) => SLEB128 a
instance SLEB128 Int8
instance SLEB128 Int16
instance SLEB128 Int32
instance SLEB128 Int64
instance SLEB128 Int
instance SLEB128 Integer

-- | LEB128-encodes a natural number to a strict bytestring.
--
-- The value is rendered with 'buildLEB128'; the resulting lazy bytestring is
-- then converted into a strict one.
toLEB128 :: LEB128 a => a -> BS.ByteString
toLEB128 =
    BSL.toStrict
      -- Run the builder with 'B.safeStrategy' 32 32 and nothing appended
      -- after the generated output.
      . B.toLazyByteStringWith (B.safeStrategy 32 32) BSL.empty
      . buildLEB128

{-# SPECIALIZE toLEB128 :: Natural -> BS.ByteString #-}
{-# SPECIALIZE toLEB128 :: Word -> BS.ByteString #-}
{-# SPECIALIZE toLEB128 :: Word8 -> BS.ByteString #-}
{-# SPECIALIZE toLEB128 :: Word16 -> BS.ByteString #-}
{-# SPECIALIZE toLEB128 :: Word32 -> BS.ByteString #-}
{-# SPECIALIZE toLEB128 :: Word64 -> BS.ByteString #-}

-- | SLEB128-encodes an integer to a strict bytestring.
--
-- The value is rendered with 'buildSLEB128'; the resulting lazy bytestring is
-- then converted into a strict one.
toSLEB128 :: SLEB128 a => a -> BS.ByteString
toSLEB128 =
    BSL.toStrict
      -- Run the builder with 'B.safeStrategy' 32 32 and nothing appended
      -- after the generated output.
      . B.toLazyByteStringWith (B.safeStrategy 32 32) BSL.empty
      . buildSLEB128

{-# SPECIALIZE toSLEB128 :: Integer -> BS.ByteString #-}
{-# SPECIALIZE toSLEB128 :: Int -> BS.ByteString #-}
{-# SPECIALIZE toSLEB128 :: Int8 -> BS.ByteString #-}
{-# SPECIALIZE toSLEB128 :: Int16 -> BS.ByteString #-}
{-# SPECIALIZE toSLEB128 :: Int32 -> BS.ByteString #-}
{-# SPECIALIZE toSLEB128 :: Int64 -> BS.ByteString #-}

-- | LEB128-encodes a natural number via a builder.
--
-- The value is emitted seven bits at a time, least significant group first.
-- Every byte except the last one has bit 7 set.
buildLEB128 :: LEB128 a => a -> B.Builder
buildLEB128 = go
  where
    go i
      -- The remaining value fits into seven bits: emit it as the final byte.
      | i <= 127
      = B.word8 (fromIntegral i :: Word8)
      | otherwise =
        -- bit 7 (8th bit) indicates more to come.
        B.word8 (setBit (fromIntegral i) 7) <> go (i `unsafeShiftR` 7)

{-# SPECIALIZE buildLEB128 :: Natural -> B.Builder #-}
{-# SPECIALIZE buildLEB128 :: Word -> B.Builder #-}
{-# SPECIALIZE buildLEB128 :: Word8 -> B.Builder #-}
{-# SPECIALIZE buildLEB128 :: Word16 -> B.Builder #-}
{-# SPECIALIZE buildLEB128 :: Word32 -> B.Builder #-}
{-# SPECIALIZE buildLEB128 :: Word64 -> B.Builder #-}

-- This gets inlined for the specialized variants.
--
-- Reports whether the type @a@ has a fixed bit width, i.e. whether
-- 'bitSizeMaybe' yields a 'Just' for it. The type is chosen by type
-- application (@isFinite \@a@); 'zeroBits' merely serves as a value of
-- that type to hand to 'bitSizeMaybe'.
isFinite :: forall a. Bits a => Bool
isFinite = isJust (bitSizeMaybe (zeroBits :: a))

-- | SLEB128-encodes an integer via a builder.
--
-- The value is emitted seven bits at a time, least significant group first.
-- Encoding stops once the remaining value lies in the range @[-64, 64)@.
buildSLEB128 :: SLEB128 a => a -> B.Builder
buildSLEB128 val
  | val >= -64 && val < 64 = stopByte
  | otherwise = goByte <> buildSLEB128 (shiftR val 7)
  where
    -- Final byte: low bits of the value with bit 7 cleared.
    stopByte = B.word8 (fromIntegral $ clearBit val 7)
    -- Continuation byte: low bits of the value with bit 7 set.
    goByte = B.word8 (fromIntegral $ setBit val 7)

{-# SPECIALIZE buildSLEB128 :: Integer -> B.Builder #-}
{-# SPECIALIZE buildSLEB128 :: Int -> B.Builder #-}
{-# SPECIALIZE buildSLEB128 :: Int8 -> B.Builder #-}
{-# SPECIALIZE buildSLEB128 :: Int16 -> B.Builder #-}
{-# SPECIALIZE buildSLEB128 :: Int32 -> B.Builder #-}
{-# SPECIALIZE buildSLEB128 :: Int64 -> B.Builder #-}

-- | LEB128-encodes a natural number in @cereal@'s 'P.Put' monad.
--
-- This is 'buildLEB128' followed by 'P.putBuilder'.
putLEB128 :: LEB128 a => P.Putter a
putLEB128 = P.putBuilder . buildLEB128
{-# INLINE putLEB128 #-}

-- | SLEB128-encodes an integer in @cereal@'s 'P.Put' monad.
--
-- This is 'buildSLEB128' followed by 'P.putBuilder'.
putSLEB128 :: SLEB128 a => P.Putter a
putSLEB128 = P.putBuilder . buildSLEB128
{-# INLINE putSLEB128 #-}

-- | LEB128-decodes a natural number from a strict bytestring.
--
-- Runs 'getLEB128' over the input. Returns 'Left' with an error message if
-- decoding fails or if bytes remain after the encoded number.
fromLEB128 :: LEB128 a => BS.ByteString -> Either String a
fromLEB128 = runComplete getLEB128
{-# INLINE fromLEB128 #-}

-- | SLEB128-decodes an integer from a strict bytestring.
--
-- Runs 'getSLEB128' over the input. Returns 'Left' with an error message if
-- decoding fails or if bytes remain after the encoded number.
fromSLEB128 :: SLEB128 a => BS.ByteString -> Either String a
fromSLEB128 = runComplete getSLEB128
{-# INLINE fromSLEB128 #-}

-- Runs the parser @p@ on @bs@ and insists that it consumes the whole input:
-- a parse failure is passed through, and leftover bytes after a successful
-- parse are reported as @"extra bytes in input"@.
runComplete :: G.Get a -> BS.ByteString -> Either String a
runComplete p bs = do
    -- 'G.runGetState' also hands back the unconsumed remainder of the input.
    (x, r) <- G.runGetState p bs 0
    unless (BS.null r) $ Left "extra bytes in input"
    return x

-- | LEB128-decodes a natural number via @cereal@.
--
-- Fails with
--
-- * @"short encoding"@ if the input ends before the final byte,
-- * @"overflow"@ if the encoded number does not fit into a fixed-width @a@,
-- * @"overlong encoding"@ if a multi-byte encoding ends in a @0x00@ byte.
getLEB128 :: forall a. LEB128 a => G.Get a
getLEB128 = G.label "LEB128" $ go 0 0
  where
    -- @go shift w@: @w@ is the value accumulated so far, @shift@ is the bit
    -- position at which the payload of the next byte is placed.
    go :: Int -> a -> G.Get a
    go !shift !w = do
      byte <- G.getWord8 <|> fail "short encoding"
      -- The low seven bits carry the payload; bit 7 is the continuation flag.
      let !byteVal = fromIntegral (clearBit byte 7)
      -- For fixed-width types, give up once the shift exceeds the bit width.
      case bitSizeMaybe w of
          Just bs | shift > bs -> fail "overflow"
          _ -> return ()
      -- For fixed-width types, the payload must survive being shifted into
      -- place and back; otherwise some of its bits do not fit.
      when (isFinite @a) $
        unless (byteVal `unsafeShiftL` shift `unsafeShiftR` shift == byteVal) $
          fail "overflow"
      let !val = w .|. (byteVal `unsafeShiftL` shift)
      let !shift' = shift + 7
      if hasMore byte
        then go shift' val
        else do
          -- A trailing zero byte after the first byte adds nothing to the
          -- value, so the encoding is longer than necessary.
          when (byte == 0x00 && shift > 0)
            $ fail "overlong encoding"
          return $! val

    -- Bit 7 set means another byte follows.
    hasMore :: Word8 -> Bool
    hasMore b = testBit b 7

{-# SPECIALIZE getLEB128 :: G.Get Natural #-}
{-# SPECIALIZE getLEB128 :: G.Get Word #-}
{-# SPECIALIZE getLEB128 :: G.Get Word8 #-}
{-# SPECIALIZE getLEB128 :: G.Get Word16 #-}
{-# SPECIALIZE getLEB128 :: G.Get Word32 #-}
{-# SPECIALIZE getLEB128 :: G.Get Word64 #-}

-- | SLEB128-decodes an integer via @cereal@.
--
-- Fails with
--
-- * @"short encoding"@ if the input ends before the final byte,
-- * @"overflow"@ if the encoded number does not fit into a fixed-width @a@,
-- * @"overlong encoding"@ if the final byte of a multi-byte encoding only
--   repeats the sign already given by the previous byte.
getSLEB128 :: forall a. SLEB128 a => G.Get a
getSLEB128 = G.label "SLEB128" $ go 0 0 0
  where
    -- @go prev shift w@: @prev@ is the previously read byte, @w@ is the value
    -- accumulated so far, and @shift@ is the bit position at which the
    -- payload of the next byte is placed.
    go :: Word8 -> Int -> a -> G.Get a
    go !prev !shift !w = do
        byte <- G.getWord8 <|> fail "short encoding"
        -- The low seven bits carry the payload; bit 7 is the continuation flag.
        let !byteVal = fromIntegral (clearBit byte 7)
        -- For fixed-width types, give up once the shift exceeds the bit width.
        case bitSizeMaybe w of
            Just bs | shift > bs -> fail "overflow"
            _ -> return ()
        -- For fixed-width types, the payload must survive being shifted into
        -- place and back. The result is masked to seven bits before the
        -- comparison.
        when (isFinite @a) $
          unless ((byteVal `unsafeShiftL` shift `unsafeShiftR` shift) .&. 0x7f == byteVal) $
            fail "overflow"
        let !val = w .|. (byteVal `unsafeShiftL` shift)
        let !shift' = shift + 7
        if hasMore byte
            then go byte shift' val
            else if signed byte
              then do
                -- A final 0x7f after a byte whose sign bit is already set
                -- carries no extra information.
                when (byte == 0x7f && signed prev && shift > 0)
                  $ fail "overlong encoding"
                -- Negative number: subtract the bit just above the payload
                -- read so far.
                return $! val - bit shift'
              else do
                -- A final 0x00 after a byte whose sign bit is clear carries
                -- no extra information.
                when (byte == 0x00 && not (signed prev) && shift > 0)
                  $ fail "overlong encoding"
                return $! val

    -- Bit 7 set means another byte follows.
    hasMore :: Word8 -> Bool
    hasMore b = testBit b 7

    -- Bit 6 is the sign bit of a byte's seven-bit payload.
    signed :: Word8 -> Bool
    signed b = testBit b 6

{-# SPECIALIZE getSLEB128 :: G.Get Integer #-}
{-# SPECIALIZE getSLEB128 :: G.Get Int #-}
{-# SPECIALIZE getSLEB128 :: G.Get Int8 #-}
{-# SPECIALIZE getSLEB128 :: G.Get Int16 #-}
{-# SPECIALIZE getSLEB128 :: G.Get Int32 #-}
{-# SPECIALIZE getSLEB128 :: G.Get Int64 #-}