-- SPDX-FileCopyrightText: 2020 Tocqueville Group
--
-- SPDX-License-Identifier: LicenseRef-MIT-TQ

-- | Helper functions for encoding and decoding data
-- with 'Binary'.
module Util.Binary
  ( UnpackError (..)
  , ensureEnd
  , launchGet
  , TaggedDecoder(..)
  , (#:)
  , decodeBytesLike
  , decodeWithTag
  , getByteStringCopy
  , getRemainingByteStringCopy
  , unknownTag
  ) where

import Prelude hiding (EQ, Ordering(..), get)

import Data.Binary (Get)
import qualified Data.Binary.Get as Get
import qualified Data.ByteString as BS
import qualified Data.ByteString.Lazy as LBS
import qualified Data.List as List
import Fmt (Buildable, build, fmt, hexF, pretty, (+|), (+||), (|+), (||+))
import Text.Hex (encodeHex)

----------------------------------------------------------------------------
-- Helpers
----------------------------------------------------------------------------

-- | Error produced when decoding binary data fails. It carries a
-- textual description of the failure.
newtype UnpackError = UnpackError
  { unUnpackError :: Text
    -- ^ Description of what went wrong.
  }
  deriving stock (Show, Eq)

-- | Renders only the wrapped message text, without the constructor name.
instance Buildable UnpackError where
  build = build . unUnpackError

-- | 'displayException' shows the same text as the 'Buildable' rendering.
instance Exception UnpackError where
  displayException = fmt . build

-- | Succeed only if all input has been consumed. Otherwise fail with a
-- message giving the number of leftover bytes and their hex encoding
-- (rendered via 'show').
ensureEnd :: Get ()
ensureEnd = unlessM Get.isEmpty $
  Get.getRemainingLazyByteString >>= fail . unconsumedMsg
  where
    unconsumedMsg :: LByteString -> String
    unconsumedMsg rest =
      "Expected end of entry, unconsumed bytes (" +| length rest |+ "): "
        +|| encodeHex (LBS.toStrict rest) ||+ ""

-- | Run a decoder over lazy input, turning a decoding failure into an
-- 'UnpackError'. Any input left unconsumed by the decoder is ignored, as are
-- the byte offsets reported by 'Get.runGetOrFail'.
launchGet :: Get a -> LByteString -> Either UnpackError a
launchGet decoder bs = either onFailure onSuccess (Get.runGetOrFail decoder bs)
  where
    onFailure (_rest, _pos, msg) = Left (UnpackError (toText msg))
    onSuccess (_rest, _pos, value) = Right value

-- | Describes how 'decodeWithTag' should decode tag-dependent data: the
-- input is expected to start with the one-byte 'tdTag', and the bytes that
-- follow it are parsed with 'tdDecoder'.
data TaggedDecoder a = TaggedDecoder
  { tdTag :: Word8
    -- ^ Tag byte that selects this decoder.
  , tdDecoder :: Get a
    -- ^ Decoder run on the bytes following the tag.
  }

-- | Infix shorthand for the 'TaggedDecoder' constructor. Its fixity
-- (@infixr 0@) is the lowest possible, so the right-hand side can be any
-- operator expression, e.g. @0x01 #: Foo <$> Get.getWord8@, without
-- parentheses.
(#:) :: Word8 -> Get a -> TaggedDecoder a
tag #: decoder = TaggedDecoder tag decoder
infixr 0 #:

-- | Read the given number of bytes as a strict 'ByteString' backed by its
-- own buffer, so the result keeps no reference to the original serialized
-- input.
getByteStringCopy :: Int -> Get ByteString
getByteStringCopy n = BS.copy <$> Get.getByteString n

-- | Get all remaining input as a strict 'ByteString' that does not share
-- memory with the original serialized data.
--
-- Note that reading all remaining input may be expensive and is thus
-- discouraged; you can use this function only when you know that the amount
-- of data to be consumed is limited, e.g. within a 'decodeAsBytes' call.
-- Fails if more than 640000 bytes remain.
getRemainingByteStringCopy :: Get ByteString
getRemainingByteStringCopy = do
  lbs <- Get.getRemainingLazyByteString
  -- Guard against huge inputs, e.g. when a bad length was passed to
  -- 'Get.isolate'. The remainder has already been read at this point, so
  -- this bounds the size of the strict copy we build and return rather
  -- than the amount read.
  -- Normally this function is used only to decode primitives, 'Signature' in
  -- the worst case, so we could set a small limit, but since this is a hack
  -- anyway let's make sure it never obstructs our work.
  when (LBS.length lbs > maxRemainingLength) $
    fail "Too big length for an entity"
  -- 'LBS.toStrict' returns a single-chunk lazy string as-is, i.e. as a slice
  -- that keeps the whole original input buffer alive. 'BS.copy' guarantees a
  -- fresh buffer, matching the contract of 'getByteStringCopy'.
  pure $ BS.copy (LBS.toStrict lbs)
  where
    maxRemainingLength = 640000

-- | Fail the decoder with an "unknown tag" error of the form
-- @Unknown <desc> tag: 0x<tag in hex>@.
unknownTag :: String -> Word8 -> Get a
unknownTag desc tag = fail message
  where
    message :: String
    message = fmt $ mconcat ["Unknown ", build desc, " tag: 0x", hexF tag]

-- | Common decoder for data that starts with a one-byte tag selecting how
-- the rest of the input is decoded. The tag is read under the label
-- @"<what> tag"@. If no decoder matches, fail via 'unknownTag'. If several
-- match, the first one in the list is used.
decodeWithTag :: String -> [TaggedDecoder a] -> Get a
decodeWithTag what decoders = do
  tag <- Get.label (what <> " tag") Get.getWord8
  -- Decoder lists are usually short, so a linear scan is fine.
  maybe (unknownTag what tag) tdDecoder $
    List.find (\candidate -> tdTag candidate == tag) decoders

-- | Consume all remaining input (via 'getRemainingByteStringCopy') and
-- convert it with the given constructor. If the constructor returns
-- 'Left', fail with a message of the form @Wrong <what>: <err>@.
decodeBytesLike
  :: (Buildable err)
  => String -> (ByteString -> Either err a) -> Get a
decodeBytesLike what constructor =
  getRemainingByteStringCopy >>= either rejected pure . constructor
  where
    rejected err = fail $ "Wrong " +| what |+ ": " +| err |+ ""