-- | EBML core data decoder, see: https://matroska-org.github.io/libebml/specs.html
module Codec.EBML.Element where

import Data.Binary.Get (Get, getWord8)
import Data.Bits (Bits (shift, testBit, (.|.)), (.&.))
import Data.ByteString (ByteString)
import Data.Int (Int64)
import Data.Text (Text)
import Data.Word (Word32, Word64)

-- | EBML document structure: the sequence of top-level elements, including
-- the Header and Body Root.
--
-- Derives 'Eq' and 'Show' like every other type in this module, so whole
-- documents can be compared and printed.
newtype EBMLDocument = EBMLDocument [EBMLElement]
    deriving (Eq, Show)

-- | EBML element ID.
--
-- Wraps the raw encoded octets of the ID as read from the stream (big-endian,
-- VINT marker bit kept). The 'Num' instance lets IDs be written directly as
-- integer literals, e.g. @0x1A45DFA3 :: EBMLID@.
newtype EBMLID = EBMLID Word32
    deriving newtype (Eq, Ord, Num)
    deriving stock (Show)

-- | Header of an EBML element: its ID followed by its declared data size.
data EBMLElementHeader = EBMLElementHeader
    { eid :: EBMLID
    -- ^ Element ID.
    , size :: Maybe Word64
    -- ^ Declared data size; 'Nothing' for an unknown-sized element.
    }
    deriving (Show, Eq)

-- | A single EBML element: its header paired with its value.
data EBMLElement = EBMLElement
    { header :: EBMLElementHeader
    -- ^ Element ID and data size.
    , value :: EBMLValue
    -- ^ Element payload.
    }
    deriving (Show, Eq)

-- | Payload of an EBML element, one constructor per value kind.
data EBMLValue
    = EBMLRoot [EBMLElement]
    -- ^ Nested child elements.
    | EBMLSignedInteger Int64
    -- ^ Signed integer value.
    | EBMLUnsignedInteger Word64
    -- ^ Unsigned integer value.
    | EBMLFloat Double
    -- ^ Floating-point value.
    | EBMLText Text
    -- ^ Textual value.
    | EBMLDate Text
    -- ^ Date value, carried as 'Text'.
    -- NOTE(review): the textual format is not fixed by this type; confirm
    -- against the code that builds it.
    | EBMLBinary ByteString
    -- ^ Raw binary payload.
    deriving (Show, Eq)

-- | Read an element header: the element ID, then the element data size
-- ('Nothing' when the size is unknown).
getElementHeader :: Get EBMLElementHeader
getElementHeader = do
    elementId <- getElementID
    dataSize <- getMaybeDataSize
    pure EBMLElementHeader{eid = elementId, size = dataSize}

-- | Read an EBML element ID.
--
-- The highest set bit among the top four bits of the first octet selects
-- the ID width (1 to 4 octets). Unlike a data size, the marker bit is not
-- stripped: the returned 'EBMLID' holds the raw encoded octets, big-endian.
--
-- Fails with @"Invalid width: <octet>"@ when none of those four bits is set.
getElementID :: Get EBMLID
getElementID = getWord8 >>= decodeFrom
  where
    decodeFrom lead
        | lead `testBit` 7 = rest 0
        | lead `testBit` 6 = rest 1
        | lead `testBit` 5 = rest 2
        | lead `testBit` 4 = rest 3
        | otherwise = fail ("Invalid width: " <> show lead)
      where
        -- Append the remaining octets to the first one, marker bit included.
        rest extra = EBMLID <$> getVar extra (fromIntegral lead :: Word32)

-- | Read an element data size, returning 'Nothing' for an unknown-sized
-- element.
--
-- A data size whose VINT_DATA bits are all ones marks an unknown size. This
-- is checked for every VINT width. Both the common 8-octet form (VINT_DATA
-- @0xFFFFFFFFFFFFFF@) and shorter forms such as the single octet @0xFF@
-- (VINT_DATA @0x7F@) yield 'Nothing'.
getMaybeDataSize :: Get (Maybe Word64)
getMaybeDataSize = do
    (width, sz) <- getDataSizeWithWidth
    pure $
        if sz == allOnes width
            then Nothing
            else Just sz
  where
    -- Largest value a VINT of the given width can carry: 7 data bits per octet.
    allOnes :: Int -> Word64
    allOnes width = (1 `shift` (7 * width)) - 1

-- | Read an element data size (a VINT) and return its value with the marker
-- bit removed. The reserved unknown-size value is returned as-is; use
-- 'getMaybeDataSize' to detect it.
--
-- Fails when the first octet is zero. Such a VINT would be wider than
-- 8 octets, which this decoder does not support. Previously it was read as
-- size 0, which silently desynchronised the parse.
getDataSize :: Get Word64
getDataSize = snd <$> getDataSizeWithWidth

-- | Read a data size VINT, returning its width in octets (1 to 8) together
-- with its value (VINT_DATA, marker bit removed).
getDataSizeWithWidth :: Get (Int, Word64)
getDataSizeWithWidth = do
    b1 <- getWord8
    -- The position of the first set bit (the VINT marker) gives the width.
    case [w | w <- [1 .. 8], b1 `testBit` (8 - w)] of
        width : _ -> do
            -- Keep only the data bits below the marker in the first octet.
            let firstData :: Word64
                firstData = fromIntegral b1 .&. ((1 `shift` (8 - width)) - 1)
            sz <- getVar (width - 1) firstData
            pure (width, sz)
        [] -> fail ("Invalid data size width: " <> show b1)

-- | Read @n@ more octets, shifting each one into the low end of the
-- accumulator (big-endian), and return the result.
--
-- Used to finish decoding a VINT once its first octet has been consumed.
-- A non-positive @n@ reads nothing and returns the accumulator unchanged.
-- A negative count used to recurse until the input ran out.
getVar :: (Num a, Bits a) => Int -> a -> Get a
getVar n acc
    | n <= 0 = pure acc
    | otherwise = do
        b <- getWord8
        -- Force the accumulator at each step so no chain of thunks builds up.
        let acc' = (acc `shift` 8) .|. fromIntegral b
        acc' `seq` getVar (n - 1) acc'