{-# LANGUAGE ScopedTypeVariables #-}
{-# LANGUAGE MultiWayIf #-}
{-# LANGUAGE RecordWildCards #-}
{-# LANGUAGE FlexibleInstances #-}
{-# LANGUAGE TupleSections #-}
{-# LANGUAGE OverloadedStrings #-}
module Data.Avro.Decode
( decodeAvro
, decodeContainer
, decodeContainerWith
, getAvroOf
, GetAvro(..)
) where
import Prelude as P
import Control.Monad (replicateM,when)
import qualified Codec.Compression.Zlib as Z
import qualified Data.Aeson as A
import qualified Data.Array as Array
import qualified Data.Binary.Get as G
import Data.Binary.Get (Get,runGetOrFail)
import Data.Binary.IEEE754 as IEEE
import Data.Bits
import qualified Data.ByteString.Lazy as BL
import Data.ByteString (ByteString)
import qualified Data.ByteString.Lazy.Char8 as BC
import Data.Int
import Data.List (foldl')
import qualified Data.List.NonEmpty as NE
import Data.Maybe
import qualified Data.Map as Map
import Data.Monoid ((<>))
import qualified Data.HashMap.Strict as HashMap
import qualified Data.Set as Set
import Data.Text (Text)
import qualified Data.Text as Text
import qualified Data.Text.Encoding as Text
import qualified Data.Vector as V
import Data.Avro.DecodeRaw
import Data.Avro.Zag
import Data.Avro.Schema as S
import qualified Data.Avro.Types as T
-- | Decode one Avro-encoded value from a lazy 'BL.ByteString' using the
-- given schema.
--
-- Returns @Left msg@ carrying the parser's error message when decoding
-- fails. Any input left over after the value has been read is discarded.
decodeAvro :: Schema -> BL.ByteString -> Either String (T.Value Type)
decodeAvro sch input =
  case runGetOrFail (getAvroOf sch) input of
    Left (_, _, err)    -> Left err
    Right (_, _, value) -> Right value
{-# INLINABLE decodeAvro #-}
-- | Decode an Avro object container, using the schema embedded in the
-- container itself to decode every value.
--
-- The result pairs that schema with the decoded values, grouped one inner
-- list per container block.
decodeContainer :: BL.ByteString -> Either String (Schema, [[T.Value Type]])
decodeContainer input = decodeContainerWith getAvroOf input
{-# INLINABLE decodeContainer #-}
-- | Decode an Avro object container with a caller-supplied value decoder.
--
-- The first argument turns the schema found in the container header into
-- the 'Get' parser used for each contained value. On failure the parser's
-- error message is returned in 'Left'; leftover input is discarded.
decodeContainerWith :: (Schema -> Get a)
                    -> BL.ByteString
                    -> Either String (Schema, [[a]])
decodeContainerWith schemaToGet =
  either (\(_, _, err) -> Left err) (\(_, _, result) -> Right result)
    . runGetOrFail (getContainerWith schemaToGet)
{-# INLINABLE decodeContainerWith #-}
-- | Information extracted from the header of an Avro object container.
data ContainerHeader = ContainerHeader
  { syncBytes :: !BL.ByteString
    -- ^ Sync marker read from the header; the marker that follows each data
    -- block is compared against it.
  , decompress :: BL.ByteString -> Get BL.ByteString
    -- ^ Action applied to the raw bytes of each block before the contained
    -- values are decoded (selected from the @avro.codec@ metadata entry).
  , containedSchema :: !Schema
    -- ^ Schema decoded from the @avro.schema@ metadata entry.
  }
-- | Length, in bytes, of a container sync marker. Used both when reading the
-- marker from the header and when reading the marker that follows a block.
nrSyncBytes :: Integral sb => sb
nrSyncBytes = 16
-- | Parses a container header: the 4-byte magic, the metadata map, and the
-- 16-byte sync marker, then resolves the codec and schema from the metadata.
--
-- Fails when the magic does not match, when the codec is not recognised,
-- when there is no @avro.schema@ entry, or when that entry is not a
-- decodable schema.
instance GetAvro ContainerHeader where
  getAvro = do
    magic <- getFixed magicLength
    when (BL.fromStrict magic /= expectedMagic) $
      fail "Invalid magic number at start of container."
    meta <- getMap :: Get (Map.Map Text BL.ByteString)
    -- The sync marker sits directly after the metadata, so it must be
    -- consumed before the metadata entries are interpreted.
    marker <- fmap BL.fromStrict (getFixed nrSyncBytes)
    inflate <- getCodec (Map.lookup "avro.codec" meta)
    sch <- maybe (fail "Invalid container object: no schema.")
                 parseSchema
                 (Map.lookup "avro.schema" meta)
    return ContainerHeader
      { syncBytes = marker
      , decompress = inflate
      , containedSchema = sch
      }
    where
      magicLength :: Int
      magicLength = 4

      -- The ASCII bytes "Obj" followed by the byte 1.
      expectedMagic :: BL.ByteString
      expectedMagic = BC.pack "Obj" <> BL.pack [1]

      -- Decode the JSON schema stored in the metadata.
      parseSchema :: BL.ByteString -> Get Schema
      parseSchema raw =
        either (\e -> fail ("Can not decode container schema: " <> e))
               return
               (A.eitherDecode' raw)
-- | Read exactly the given number of bytes as a strict 'ByteString'.
getFixed :: Int -> Get ByteString
getFixed len = G.getByteString len
-- | Parse an entire Avro object container: the header followed by every
-- data block up to the end of the input.
--
-- The supplied function converts the schema from the header into the parser
-- used for each value. The result pairs that schema with the decoded values,
-- one inner list per block. A container that has a header but no data blocks
-- yields an empty list of blocks.
--
-- Fails when a block cannot be decompressed or decoded, when a block's
-- object count does not fit in an 'Int', or when the marker after a block
-- differs from the header's sync marker.
getContainerWith :: (Schema -> Get a) -> Get (Schema, [[a]])
getContainerWith schemaToGet =
  do ContainerHeader {..} <- getAvro
     (containedSchema,) <$> getBlocks (schemaToGet containedSchema) syncBytes decompress
  where
    getBlocks :: Get a -> BL.ByteString -> (BL.ByteString -> Get BL.ByteString) -> Get [[a]]
    getBlocks getValue sync decompress =
      do -- Check for end of input *before* reading a block so that a
         -- header-only container is accepted instead of failing on a
         -- missing block.
         atEnd <- G.isEmpty
         if atEnd
           then return []
           else
             do nrObj <- sFromIntegral =<< getLong
                nrBytes <- getLong
                bytes <- decompress =<< G.getLazyByteString nrBytes
                -- Values are decoded from the (decompressed) block bytes in
                -- a nested parser run, independent of the outer stream.
                r <- case runGetOrFail (replicateM nrObj getValue) bytes of
                       Right (_,_,x) -> return x
                       Left (_,_,s) -> fail s
                marker <- G.getLazyByteString nrSyncBytes
                when (marker /= sync) (fail "Invalid marker, does not match sync bytes.")
                (r :) <$> getBlocks getValue sync decompress
-- | Select the block decompression action for a container from the value of
-- its @avro.codec@ metadata entry.
--
-- A missing entry and @"null"@ both yield an action that returns the bytes
-- unchanged. @"deflate"@ yields an action that runs 'Z.decompress' and turns
-- a decompression error into a 'fail'. Any other codec name fails
-- immediately.
getCodec :: Monad m => Maybe BL.ByteString -> m (BL.ByteString -> m BL.ByteString)
getCodec Nothing = return return
getCodec (Just name) =
  case name of
    "null"    -> return return
    "deflate" -> return inflate
    other     -> fail ("Unrecognized codec: " <> BC.unpack other)
  where
    inflate compressed = either (fail . show) return (Z.decompress compressed)
{-# INLINABLE getAvroOf #-}
-- | Build a parser for a value of the given schema.
--
-- Named type references are resolved through an environment built from the
-- top-level schema; a reference to a name that is not in it fails. Arrays
-- and maps are read block by block until a block with count zero. A negative
-- block count is followed by a byte-size long, which is read and skipped,
-- and the absolute value of the count is used.
getAvroOf :: Schema -> Get (T.Value Type)
getAvroOf ty0 = go ty0
  where
    env = S.buildTypeEnvironment envFail ty0
    envFail t = fail $ "Named type not in schema: " <> show t

    go :: Type -> Get (T.Value Type)
    go ty =
      case ty of
        Null -> return T.Null
        Boolean -> T.Boolean <$> getAvro
        Int -> T.Int <$> getAvro
        Long -> T.Long <$> getAvro
        Float -> T.Float <$> getAvro
        Double -> T.Double <$> getAvro
        Bytes -> T.Bytes <$> getAvro
        String -> T.String <$> getAvro
        Array t ->
          do vals <- getBlocksOf t
             return $ T.Array (V.fromList $ mconcat vals)
        Map t ->
          do kvs <- getKVBlocks t
             return $ T.Map (HashMap.fromList $ mconcat kvs)
        NamedType tn -> env tn >>= go
        Record {..} ->
          do let getField Field {..} = (fldName,) <$> go fldType
             T.Record ty . HashMap.fromList <$> mapM getField fields
        Enum {..} ->
          do val <- getLong
             -- An index with no matching symbol yields the empty symbol.
             let sym = fromMaybe "" (symbolLookup val)
             pure (T.Enum ty (fromIntegral val) sym)
        Union ts unionLookup ->
          do i <- getLong
             case unionLookup i of
               Nothing -> fail $ "Decoded Avro tag is outside the expected range for a Union. Tag: " <> show i <> " union of: " <> show (P.map typeName $ NE.toList ts)
               Just t -> T.Union ts t <$> go t
        Fixed {..} -> T.Fixed ty <$> G.getByteString (fromIntegral size)

    -- Read the item count of the next array/map block. A negative count
    -- means the count is immediately followed by the block's size in bytes;
    -- that size must be consumed, otherwise it would be misread as the
    -- first item of the block.
    getBlockCount :: Get Int64
    getBlockCount =
      do n <- getLong
         if n < 0
           then do _byteSize <- getLong
                   return (negate n)
           else return n

    -- Read all blocks of a map whose values have type @t@.
    getKVBlocks :: Type -> Get [[(Text,T.Value Type)]]
    getKVBlocks t =
      do blockLength <- getBlockCount
         if blockLength == 0
           then return []
           else do vs <- replicateM (fromIntegral blockLength) ((,) <$> getString <*> go t)
                   (vs:) <$> getKVBlocks t
    {-# INLINE getKVBlocks #-}

    -- Read all blocks of an array whose items have type @t@.
    getBlocksOf :: Type -> Get [[T.Value Type]]
    getBlocksOf t =
      do blockLength <- getBlockCount
         if blockLength == 0
           then return []
           else do vs <- replicateM (fromIntegral blockLength) (go t)
                   (vs:) <$> getBlocksOf t
    {-# INLINE getBlocksOf #-}
-- | Types that have a 'Get' parser for reading them from Avro-encoded input.
class GetAvro a where
  -- | Parser that reads one value of type @a@.
  getAvro :: Get a
-- | Maps keyed by 'Text'; delegates to 'getMap'.
instance GetAvro ty => GetAvro (Map.Map Text ty) where
  getAvro = getMap

-- | Delegates to 'getBoolean'.
instance GetAvro Bool where
  getAvro = getBoolean

-- | Delegates to 'getInt'.
instance GetAvro Int32 where
  getAvro = getInt

-- | Delegates to 'getLong'.
instance GetAvro Int64 where
  getAvro = getLong

-- | Length-prefixed bytes, converted to a lazy 'BL.ByteString'.
instance GetAvro BL.ByteString where
  getAvro = BL.fromStrict <$> getBytes

-- | Length-prefixed bytes as a strict 'ByteString'.
instance GetAvro ByteString where
  getAvro = getBytes

-- | Delegates to 'getString'.
instance GetAvro Text where
  getAvro = getString

-- | Delegates to 'getFloat'.
instance GetAvro Float where
  getAvro = getFloat

-- | Delegates to 'getDouble'.
instance GetAvro Double where
  getAvro = getDouble

-- | An Avro string unpacked into a 'String'.
--
-- NOTE(review): 'String' is @[Char]@, so this instance head overlaps with
-- the list instance below; confirm how use sites are expected to select
-- between the two.
instance GetAvro String where
  getAvro = Text.unpack <$> getString

-- | Lists are read as arrays via 'getArray'.
instance GetAvro a => GetAvro [a] where
  getAvro = getArray
-- | Reads a two-branch union tag followed by the payload: tag 0 is
-- 'Nothing', tag 1 is 'Just' a value, and any other tag is a parse failure.
instance GetAvro a => GetAvro (Maybe a) where
  getAvro = getLong >>= fromTag
    where
      fromTag 0 = return Nothing
      fromTag 1 = Just <$> getAvro
      fromTag n = fail $ "Invalid tag for expected {null,a} Avro union, received: " <> show n
-- | Reads a list and packs it into a zero-based 'Array.Array' whose upper
-- bound is the list length minus one.
instance GetAvro a => GetAvro (Array.Array Int a) where
  getAvro = toArray <$> getAvro
    where
      toArray elems = Array.listArray (0, length elems - 1) elems
-- | Reads a list and converts it to a 'V.Vector'.
instance GetAvro a => GetAvro (V.Vector a) where
  getAvro = fmap V.fromList getAvro
-- | Reads a list and collects its elements into a 'Set.Set'.
instance (GetAvro a, Ord a) => GetAvro (Set.Set a) where
  getAvro = fmap Set.fromList getAvro
-- | Read a single byte as a boolean: the byte @0x01@ is 'True', every other
-- byte value is 'False'.
getBoolean :: Get Bool
getBoolean = fmap (== 0x01) G.getWord8
-- | Read an Avro @int@ as an 'Int32', using the 'DecodeRaw' decoder directly
-- (the same decoder 'getZigZag' resolves to).
getInt :: Get Int32
getInt = decodeRaw
-- | Read an Avro @long@ as an 'Int64', using the 'DecodeRaw' decoder
-- directly (the same decoder 'getZigZag' resolves to).
getLong :: Get Int64
getLong = decodeRaw
-- | Read an integral value; a thin alias for 'decodeRaw'.
--
-- NOTE(review): the name suggests 'decodeRaw' performs the variable-length
-- zig-zag decoding itself; confirm in "Data.Avro.DecodeRaw".
getZigZag :: (Bits i, Integral i, DecodeRaw i) => Get i
getZigZag = decodeRaw
-- | Read a length-prefixed byte string: a @long@ length followed by that
-- many bytes.
--
-- Fails when the length prefix is negative. Without this check a negative
-- length would be treated as an empty byte string, and decoding would carry
-- on from a misaligned position in corrupt input.
getBytes :: Get ByteString
getBytes =
  do len <- getLong
     when (len < 0) $
       fail ("Negative length for Avro bytes/string: " <> show len)
     G.getByteString (fromIntegral len)
-- | Read a length-prefixed UTF-8 string as 'Text'.
--
-- Fails (inside the 'Get' parser) when the bytes are not valid UTF-8, so the
-- problem is reported through the decoder's normal error channel instead of
-- as an exception thrown from pure code.
getString :: Get Text
getString =
  do raw <- getBytes
     case Text.decodeUtf8' raw of
       Left err  -> fail ("Invalid UTF-8 in Avro string: " <> show err)
       Right txt -> return txt
-- | Read a 32-bit little-endian word and reinterpret its bits as a 'Float'.
getFloat :: Get Float
getFloat = fmap IEEE.wordToFloat G.getWord32le
-- | Read a 64-bit little-endian word and reinterpret its bits as a 'Double'.
getDouble :: Get Double
getDouble = fmap IEEE.wordToDouble G.getWord64le
-- | Read a block-encoded array into a list.
--
-- Blocks are read until one with count zero. A negative count is followed by
-- a @long@ (read and ignored here), and the absolute value of the count
-- gives the number of items in that block.
getArray :: GetAvro ty => Get [ty]
getArray =
  do count <- getLong
     case compare count 0 of
       EQ -> return []
       LT -> do _byteSize <- getLong
                blockThenRest (abs count)
       GT -> blockThenRest count
  where
    -- Read one block of the given size, then the remaining blocks.
    blockThenRest n =
      do items <- replicateM (fromIntegral n) getAvro
         rest <- getArray
         return (items <> rest)
-- | Read a block-encoded map with 'Text' keys.
--
-- Blocks are read until one with count zero. A negative count is followed by
-- a @long@ holding the block's size in bytes, which is read and ignored, and
-- the absolute value of the count is used; this matches how 'getArray'
-- treats negative counts. When a key appears in more than one block, the
-- entry from the later block wins.
getMap :: GetAvro ty => Get (Map.Map Text ty)
getMap = go Map.empty
  where
    go acc =
      do nr <- getLong
         if
           | nr == 0 -> return acc
           | nr < 0 ->
               -- Consume the byte-size field; otherwise it would be
               -- misread as the next block count.
               do _byteSize <- getLong
                  readBlock acc (negate nr)
           | otherwise -> readBlock acc nr
    -- Read one block of @n@ key/value pairs and continue with the rest.
    readBlock acc n =
      do m <- Map.fromList <$> replicateM (fromIntegral n) getKVs
         go (Map.union m acc)
    getKVs = (,) <$> getString <*> getAvro
-- | Range-checked integral conversion.
--
-- Converts the argument to the target type when it lies between the target's
-- 'minBound' and 'maxBound' (compared as 'Integer'); otherwise fails in the
-- surrounding monad with @"Integral overflow."@.
sFromIntegral :: forall a b m. (Monad m, Bounded a, Bounded b, Integral a, Integral b) => a -> m b
sFromIntegral a =
  if lowest <= wide && wide <= highest
    then return (fromIntegral a)
    else fail "Integral overflow."
  where
    wide, lowest, highest :: Integer
    wide    = toInteger a
    lowest  = toInteger (minBound :: b)
    highest = toInteger (maxBound :: b)