-- SPDX-FileCopyrightText: 2021 Oxhead Alpha
-- SPDX-License-Identifier: LicenseRef-MIT-OA

-- | This module contains helper functions to deal with encoding
-- and decoding of binary data.
module Morley.Util.Binary
  ( UnpackError (..)
  , ensureEnd
  , launchGet
  , TaggedDecoder
  , TaggedDecoderM(..)
  , (#:)
  , (##:)
  , decodeBytesLike
  , decodeWithTag
  , decodeWithTagM
  , getByteStringCopy
  , getRemainingByteStringCopy
  , unknownTag
  ) where

import Prelude hiding (EQ, Ordering(..), get)

import Data.Binary (Get)
import Data.Binary.Get qualified as Get
import Data.ByteString qualified as BS
import Data.ByteString.Lazy qualified as LBS
import Data.List qualified as List
import Fmt (Buildable, build, fmt, hexF, pretty, (+|), (|+))
import Text.Hex (encodeHex)

----------------------------------------------------------------------------
-- Helpers
----------------------------------------------------------------------------

-- | Any decoding error.
--
-- Wraps the human-readable message describing why decoding failed.
newtype UnpackError = UnpackError { unUnpackError :: Text }
  deriving stock (Show, Eq)

-- | Renders the wrapped error message as is.
instance Buildable UnpackError where
  build (UnpackError msg) = build msg

-- | 'displayException' goes through the 'Buildable' instance, so the
-- message is shown without the constructor wrapping that 'Show' would add.
instance Exception UnpackError where
  displayException = pretty

-- | Check that the whole input has been consumed.
--
-- Does nothing when no input is left. Otherwise reads the remaining bytes
-- and fails with a message that reports how many bytes were left over and
-- their hex encoding.
ensureEnd :: Get ()
ensureEnd =
  unlessM Get.isEmpty $ do
    remainder <- Get.getRemainingLazyByteString
    fail $ "Expected end of entry, unconsumed bytes \
           \(" +| length remainder |+ "): \""
           +| encodeHex (LBS.toStrict remainder) |+ "\""

-- | Run a decoder over the given lazy bytestring.
--
-- On failure the decoder's error message is wrapped into 'UnpackError'.
-- In both outcomes the unconsumed input and the consumed offset are
-- discarded, so trailing bytes are not an error by themselves; call
-- 'ensureEnd' at the end of the decoder to reject them.
launchGet :: Get a -> LByteString -> Either UnpackError a
launchGet decoder bs =
  case Get.runGetOrFail decoder bs of
    Left (_remainder, _offset, err) -> Left . UnpackError $ toText err
    Right (_remainder, _offset, res) -> Right res

-- | Specialization of 'TaggedDecoderM' to 'IdentityT' transformer.
--
-- This is the decoder type produced by '#:' and accepted by 'decodeWithTag'.
type TaggedDecoder a = TaggedDecoderM IdentityT a

-- | Describes how 'decodeWithTagM' should decode tag-dependent data.
-- We expect bytes of such structure: 'tdTag' followed by a bytestring
-- which will be parsed with 'tdDecoder'.
data TaggedDecoderM t a = TaggedDecoder
  { tdTag :: Word8
    -- ^ Tag byte that selects this decoder.
  , tdDecoder :: t Get a
    -- ^ Decoder for the data that follows the tag.
  }

-- | Alias for v'TaggedDecoder' constructor specialized to 'Get'.
--
-- The plain 'Get' decoder is lifted into 'IdentityT' to fit 'TaggedDecoder'.
(#:) :: Word8 -> Get a -> TaggedDecoder a
(#:) t = TaggedDecoder t . lift
infixr 0 #:

-- | Alias for v'TaggedDecoder' constructor.
--
-- Pairs a tag byte with a decoder that already lives in the transformer @t@.
(##:) :: Word8 -> t Get a -> TaggedDecoderM t a
(##:) = TaggedDecoder
infixr 0 ##:

-- | Get a bytestring of the given length leaving no references to the
-- original data in serialized form.
--
-- The bytes are read with 'Get.getByteString' and then passed through
-- 'BS.copy'.
getByteStringCopy :: Int -> Get ByteString
getByteStringCopy = fmap BS.copy . Get.getByteString

-- | Get remaining available bytes.
--
-- Note that reading all remaining decoded input may be expensive and is thus
-- discouraged; you can use this function only when you know that the amount
-- of data to be consumed is limited, e.g. within a 'decodeBytesLike' call.
--
-- Fails with "Too big length for an entity" when more than 640000 bytes
-- remain.
getRemainingByteStringCopy :: Get ByteString
getRemainingByteStringCopy = do
  lbs <- Get.getRemainingLazyByteString
  -- Avoiding memory overflows in case bad length to 'Get.isolate' was provided.
  -- Normally this function is used only to decode primitives, 'Signature' in
  -- the worst case, so we could set little length, but since this is a hack
  -- anyway let's make sure it never obstructs our work.
  when (length lbs > 640000) $ fail "Too big length for an entity"
  return (LBS.toStrict lbs)

-- | Fail with "unknown tag" error.
--
-- The message has the form @Unknown \<desc\> tag: 0x\<tag in hex\>@, where
-- the first argument describes what was being decoded and the second one
-- is the unrecognized tag byte.
unknownTag :: String -> Word8 -> Get a
unknownTag desc tag =
  fail . fmt $ "Unknown " <> build desc <> " tag: 0x" <> hexF tag

-- | Common decoder for the case when packed data starts with a tag (1 byte)
-- that specifies how to decode remaining data.
--
-- This is a version of 'decodeWithTagM' specialized to naked 'Get' monad.
-- A tag that matches none of the given decoders is reported via 'unknownTag'
-- using the same description.
decodeWithTag :: String -> [TaggedDecoder a] -> Get a
decodeWithTag what decoders =
  runIdentityT $ decodeWithTagM what (lift . unknownTag what) decoders

-- | Common decoder for the case when packed data starts with a tag (1 byte)
-- that specifies how to decode remaining data.
--
-- This is a general version of 'decodeWithTag' that allows 'Get' to be wrapped
-- in a monad transformer.
--
-- Arguments, in order: a description of the decoded entity (used to label
-- the tag read as @\<what\> tag@), the action to run when no decoder matches
-- the tag that was read, and the list of candidate decoders. When several
-- decoders share a tag, the first one in the list is used.
decodeWithTagM
  :: (MonadTrans t, Monad (t Get))
  => String
  -> (Word8 -> t Get a)
  -> [TaggedDecoderM t a]
  -> t Get a
decodeWithTagM what unknownTagFail decoders = do
  tag <- lift $ Get.label (what <> " tag") Get.getWord8
  -- Number of decoders is usually small, so linear runtime lookup should be ok.
  case List.find ((tag ==) . tdTag) decoders of
    Nothing -> unknownTagFail tag
    Just decoder -> tdDecoder decoder

-- | Decode a value represented by all the remaining input bytes.
--
-- The remaining bytes are read with 'getRemainingByteStringCopy' (and are
-- therefore subject to its size limit) and handed to the given constructor.
-- If the constructor returns 'Left', decoding fails with the message
-- @Wrong \<what\>: \<error\>@, where @what@ is the first argument.
decodeBytesLike
  :: (Buildable err)
  => String -> (ByteString -> Either err a) -> Get a
decodeBytesLike what constructor = do
  bs <- getRemainingByteStringCopy
  case constructor bs of
    Left err -> fail $ "Wrong " +| what |+ ": " +| err |+ ""
    Right res -> pure res