{-# LANGUAGE BangPatterns        #-}
{-# LANGUAGE FlexibleInstances   #-}
{-# LANGUAGE MultiWayIf          #-}
{-# LANGUAGE OverloadedStrings   #-}
{-# LANGUAGE RankNTypes          #-}
{-# LANGUAGE ScopedTypeVariables #-}

-- | An internal module that contains common decoding functionality
-- that is shared between Lazy and Strict decoders, as well as
-- generic 'Get' monad helpers.
module Data.Avro.Internal.Get
where

import qualified Codec.Compression.Zlib     as Z
import           Control.Monad              (replicateM, when)
import           Data.Binary.Get            (Get)
import qualified Data.Binary.Get            as G
import           Data.Binary.IEEE754        as IEEE
import           Data.Bits
import           Data.ByteString            (ByteString)
import qualified Data.ByteString.Lazy       as BL
import qualified Data.ByteString.Lazy.Char8 as BC
import           Data.Int
import qualified Data.Map                   as Map
import           Data.Maybe
import           Data.Monoid                ((<>))
import qualified Data.Set                   as Set
import           Data.Text                  (Text)
import qualified Data.Text                  as Text
import qualified Data.Text.Encoding         as Text
import qualified Data.Vector                as V
import           Prelude                    as P

import Data.Avro.Internal.DecodeRaw

-- | Read a single byte and interpret it as a boolean.
--
-- The result is 'True' only when the byte equals @0x01@; every other byte
-- value yields 'False'. The comparison is forced (via '$!') before the
-- value is returned, so no thunk holding the byte is left behind.
getBoolean :: Get Bool
getBoolean = do
  w <- G.getWord8
  return $! (w == 0x01)

-- | Get a 32-bit int (zigzag encoded, max of 5 bytes).
--
-- This is 'getZigZag' specialised to 'Int32'.
getInt :: Get Int32
getInt = getZigZag

-- | Get a 64-bit int (zigzag encoded, max of 10 bytes).
--
-- This is 'getZigZag' specialised to 'Int64'.
getLong :: Get Int64
getLong = getZigZag

-- | Get a zigzag encoded integral value, consuming bytes till the msb is 0.
--
-- The actual decoding is delegated to 'decodeRaw' from
-- "Data.Avro.Internal.DecodeRaw"; this function only fixes the constraint
-- set under which it is used in this module.
getZigZag :: (Bits i, Integral i, DecodeRaw i) => Get i
getZigZag = decodeRaw

-- | Read a length-prefixed strict 'ByteString'.
--
-- The length is read first with 'getLong', converted to 'Int' with
-- 'fromIntegral', and then that many bytes are read with
-- 'G.getByteString'.
getBytes :: Get ByteString
getBytes = do
  len <- getLong
  G.getByteString (fromIntegral len)

-- | Read a length-prefixed lazy 'BL.ByteString'.
--
-- The length is read first with 'getLong' and then that many bytes are
-- read with 'G.getLazyByteString'. Both sides use 'Int64', so the length
-- is passed through without conversion.
getBytesLazy :: Get BL.ByteString
getBytesLazy = getLong >>= G.getLazyByteString

-- | Read a length-prefixed UTF-8 string.
--
-- The raw bytes are obtained with 'getBytes' and decoded with
-- 'Text.decodeUtf8''. If the bytes are not valid UTF-8 the decoder fails
-- (via 'fail') with the 'show'n decoding exception as its message.
getString :: Get Text
getString = do
  bytes <- getBytes
  either (fail . show) return (Text.decodeUtf8' bytes)

-- | Read a 32-bit float: a little-endian 'Word32' is read and its bit
-- pattern is reinterpreted as a 'Float' with 'IEEE.wordToFloat'.
--
-- Bit layout, a la Java:
--
--  Bit 31 (the bit that is selected by the mask 0x80000000) represents the
--  sign of the floating-point number. Bits 30-23 (the bits that are
--  selected by the mask 0x7f800000) represent the exponent. Bits 22-0 (the
--  bits that are selected by the mask 0x007fffff) represent the
--  significand (sometimes called the mantissa) of the floating-point
--  number.
--
--  If the argument is positive infinity, the result is 0x7f800000.
--
--  If the argument is negative infinity, the result is 0xff800000.
--
--  If the argument is NaN, the result is 0x7fc00000.
getFloat :: Get Float
getFloat = IEEE.wordToFloat <$> G.getWord32le

-- | Read a 64-bit double: a little-endian 'Word64' is read and its bit
-- pattern is reinterpreted as a 'Double' with 'IEEE.wordToDouble'.
--
-- Bit layout, as in Java:
--
--  Bit 63 (the bit that is selected by the mask 0x8000000000000000L)
--  represents the sign of the floating-point number. Bits 62-52 (the bits
--  that are selected by the mask 0x7ff0000000000000L) represent the
--  exponent. Bits 51-0 (the bits that are selected by the mask
--  0x000fffffffffffffL) represent the significand (sometimes called the
--  mantissa) of the floating-point number.
--
--  If the argument is positive infinity, the result is
--  0x7ff0000000000000L.
--
--  If the argument is negative infinity, the result is
--  0xfff0000000000000L.
--
--  If the argument is NaN, the result is 0x7ff8000000000000L.
getDouble :: Get Double
getDouble = IEEE.wordToDouble <$> G.getWord64le

-- | Avro encodes arrays and maps as a series of blocks. Each block
-- starts with a count of the elements in the block. A series of
-- blocks is always terminated with an empty block (encoded as a 0).
--
-- The given element decoder is run once per item, and the items of all
-- blocks are concatenated in the order they were read. Each decoded item
-- is forced to weak head normal form as soon as it is read.
decodeBlocks :: Get a -> Get [a]
decodeBlocks element = do
  count <- getLong
  if count == 0
    then return []
    else do
      -- Negative counts are followed by the number of *bytes* in the
      -- block; that size is read and discarded, and the block then holds
      -- @abs count@ items.
      when (count < 0) $ () <$ getLong
      items <- replicateM (fromIntegral (abs count)) strictElement
      rest  <- decodeBlocks element
      pure (items <> rest)
  where
    -- Run the element decoder and force its result before returning it.
    strictElement = element >>= (pure $!)

-- | Safe-ish 'fromIntegral': convert between bounded integral types,
-- checking that the value fits in the target type.
--
-- The argument is widened to 'Integer' and compared against the target
-- type's 'minBound' and 'maxBound'. An in-range value is converted with
-- 'fromIntegral' and returned in @m@. An out-of-range value calls 'error'
-- with the message @"Integral overflow."@ instead of failing inside @m@,
-- which is why this is only "safe-ish".
sFromIntegral :: forall a b m. (Monad m, Bounded a, Bounded b, Integral a, Integral b) => a -> m b
sFromIntegral a
  | aI > upper || aI < lower = error "Integral overflow."
  | otherwise                = return (fromIntegral a)
  where
    -- All comparisons happen in 'Integer' so they cannot themselves wrap.
    aI, lower, upper :: Integer
    aI    = fromIntegral a
    lower = fromIntegral (minBound :: b)
    upper = fromIntegral (maxBound :: b)