{-# LANGUAGE BangPatterns        #-}
{-# LANGUAGE FlexibleInstances   #-}
{-# LANGUAGE MultiWayIf          #-}
{-# LANGUAGE OverloadedStrings   #-}
{-# LANGUAGE RankNTypes          #-}
{-# LANGUAGE ScopedTypeVariables #-}

-- | An internal module that contains common decoding functionality
-- that is shared between Lazy and Strict decoders, as well as
-- generic 'Get' monad helpers.
module Data.Avro.Internal.Get
where

import           Control.Monad              (replicateM)
import           Data.Binary.Get            (Get)
import qualified Data.Binary.Get            as G
import           Data.Binary.IEEE754        as IEEE
import           Data.Bits
import           Data.ByteString            (ByteString)
import qualified Data.ByteString.Lazy       as BL
import           Data.Int
import           Data.Text                  (Text)
import qualified Data.Text.Encoding         as Text
import           Prelude                    as P

import Data.Avro.Internal.DecodeRaw

-- | Decode an Avro @boolean@ from a single byte.
--
-- Only the byte @0x01@ decodes to 'True'; every other byte value
-- (including @0x00@) decodes to 'False'. The comparison is forced before
-- the result is handed back, so no thunk escapes the parser.
getBoolean :: Get Bool
getBoolean = G.getWord8 >>= \byte -> pure $! byte == 0x01

-- | Decode an Avro @int@: a zigzag-encoded variable-length integer read
-- into an 'Int32'.
--
-- A 32-bit value needs at most 5 bytes on the wire. Whether longer
-- encodings are rejected is decided by 'getZigZag' / 'decodeRaw', not here.
getInt :: Get Int32
getInt = getZigZag

-- | Decode an Avro @long@: a zigzag-encoded variable-length integer read
-- into an 'Int64'.
--
-- A 64-bit value needs at most 10 bytes on the wire. Whether longer
-- encodings are rejected is decided by 'getZigZag' / 'decodeRaw', not here.
getLong :: Get Int64
getLong = getZigZag

-- | Decode a zigzag-encoded integral value of any width that has a
-- 'DecodeRaw' instance.
--
-- All work is delegated to 'decodeRaw'. NOTE(review): presumably it reads
-- varint bytes until one has its most significant bit clear, then undoes
-- the zigzag mapping — confirm against "Data.Avro.Internal.DecodeRaw".
getZigZag :: (Bits i, Integral i, DecodeRaw i) => Get i
getZigZag = decodeRaw

-- | Decode Avro @bytes@: a 'getLong' length prefix followed by that many
-- raw bytes, returned as a strict 'ByteString'.
--
-- A negative length is malformed input and makes the parser 'fail'.
-- Without this check, 'G.getByteString' would silently treat it as an
-- empty string. A length that does not fit in an 'Int' (possible on 32-bit
-- platforms) also fails, instead of being truncated by 'fromIntegral'.
getBytes :: Get ByteString
getBytes = getLong >>= takeBytes
  where
    takeBytes len
      | len < 0 =
          fail ("getBytes: negative length " ++ show len)
      | toInteger len > toInteger (maxBound :: Int) =
          fail ("getBytes: length too large " ++ show len)
      | otherwise =
          G.getByteString (fromIntegral len)

-- | Decode Avro @bytes@ like 'getBytes', but return a lazy 'BL.ByteString'.
--
-- The 'getLong' length prefix is already an 'Int64', which is exactly what
-- 'G.getLazyByteString' takes, so no conversion is needed. A negative length
-- is malformed input and makes the parser 'fail'; without this check it
-- would silently decode as an empty string.
getBytesLazy :: Get BL.ByteString
getBytesLazy = getLong >>= takeBytes
  where
    takeBytes len
      | len < 0   = fail ("getBytesLazy: negative length " ++ show len)
      | otherwise = G.getLazyByteString len

-- | Decode an Avro @string@: a length-prefixed byte string (see 'getBytes')
-- that must be valid UTF-8.
--
-- Invalid UTF-8 makes the parser 'fail', using the 'show'n
-- 'UnicodeException' as the message.
getString :: Get Text
getString = getBytes >>= either (fail . show) pure . Text.decodeUtf8'

-- | Decode an Avro @float@: four bytes holding an IEEE-754 single-precision
-- value in little-endian byte order.
--
-- The 32-bit word uses the same layout Java's @Float.intBitsToFloat@
-- consumes: bit 31 is the sign, bits 30-23 the exponent, and bits 22-0 the
-- significand. For example, @0x7f800000@ decodes to positive infinity,
-- @0xff800000@ to negative infinity, and any NaN pattern (such as
-- @0x7fc00000@) to a NaN.
--
-- This uses binary's built-in 'G.getFloatle', which reinterprets the bits
-- directly, instead of going through data-binary-ieee754.
getFloat :: Get Float
getFloat = G.getFloatle

-- | Decode an Avro @double@: eight bytes holding an IEEE-754
-- double-precision value in little-endian byte order.
--
-- The 64-bit word uses the same layout Java's @Double.longBitsToDouble@
-- consumes: bit 63 is the sign, bits 62-52 the exponent, and bits 51-0 the
-- significand. For example, @0x7ff0000000000000@ decodes to positive
-- infinity, @0xfff0000000000000@ to negative infinity, and any NaN pattern
-- (such as @0x7ff8000000000000@) to a NaN.
--
-- This uses binary's built-in 'G.getDoublele', which reinterprets the bits
-- directly, instead of going through data-binary-ieee754.
getDouble :: Get Double
getDouble = G.getDoublele

-- | Decode the body of an Avro array or map, which is written as a series
-- of blocks, each decoded with the given element parser.
--
-- Every block starts with a 'getLong' item count:
--
-- * @0@ ends the series;
-- * a positive count @n@ is followed directly by @n@ items;
-- * a negative count is followed by the block's size in bytes (read and
--   discarded here) and then @abs count@ items.
--
-- Each item is forced to WHNF as soon as it is decoded. The items of all
-- blocks are returned as one list, in order.
decodeBlocks :: Get a -> Get [a]
decodeBlocks element = go
  where
    go = do
      count <- getLong
      if count == 0
        then pure []
        else do
          -- Only negative counts carry a byte-size field; consume and ignore it.
          _blockSize <- if count < 0 then getLong else pure 0
          items <- replicateM (fromIntegral (abs count)) strictElement
          rest  <- go
          pure (items <> rest)

    -- Run the element parser and force its result before returning it.
    strictElement = element >>= \x -> x `seq` pure x

-- | Convert between bounded integral types, checking that the value fits in
-- the target type.
--
-- If the value lies within @[minBound, maxBound]@ of the target type @b@,
-- the converted value is returned in @m@. Otherwise the returned action is
-- a pure 'error' (@"Integral overflow."@), which is thrown when it is
-- evaluated; it is not a monadic failure.
sFromIntegral :: forall a b m. (Monad m, Bounded a, Bounded b, Integral a, Integral b) => a -> m b
sFromIntegral a
  | lo <= wide && wide <= hi = pure converted
  | otherwise                = error "Integral overflow."
  where
    -- Compare in Integer so neither type's range can overflow the check.
    wide      = toInteger a
    converted = fromIntegral a
    lo        = toInteger (minBound `asTypeOf` converted)
    hi        = toInteger (maxBound `asTypeOf` converted)