module Codec.EBML.Stream (StreamReader, newStreamReader, StreamFrame (..), feedReader) where

import Control.Monad (void, when)
import Data.Binary.Get qualified as Get
import Data.ByteString qualified as BS
import Data.Text (Text)
import Data.Text qualified as Text

import Codec.EBML.Element
import Codec.EBML.Get
import Codec.EBML.Header
import Codec.EBML.Schema
import Codec.EBML.WebM qualified as WebM

-- | A valid frame that can be served: the stream initialization data together
-- with the start of the most recent media segment.
data StreamFrame = StreamFrame
    { -- | The initialization segments, to be provided before the first media segment.
      initialization :: BS.ByteString
    , -- | The beginning of the last media segment found in the input buffer.
      media :: BS.ByteString
    }

-- | Create a stream reader with 'newStreamReader', and decode media segments with 'feedReader'.
data StreamReader = StreamReader
    { acc :: [BS.ByteString]
    -- ^ Accumulate data in case the header is not completed in the first buffer.
    , consumed :: !Int
    -- ^ Keep track of the decoder position across multiple buffers.
    -- Strict: this counter is updated as @consumed + length chunk@ on every
    -- partial decode, and a lazy field would pile those additions up as a
    -- thunk chain that keeps every previous reader (and its chunk) alive.
    , header :: Maybe BS.ByteString
    -- ^ The stream initialization segments.
    , decoder :: Get.Decoder ()
    -- ^ The current decoder.
    }

-- | Compiled schemas used by every decoder in this module; only
-- 'schemaHeader' is registered.
streamSchema :: EBMLSchemas
streamSchema = compileSchemas knownSchemas
  where
    knownSchemas = [schemaHeader]

-- | Read the initialization frame: the EBML header element, then the first
-- Segment header and the segment children found before the first Cluster.
-- The children must decode as a WebM segment; the decoded document itself is
-- discarded.
getInitialization :: Get.Get ()
getInitialization = do
    -- The stream must open with the EBML header element.
    ebmlHeader <- getElement streamSchema
    when (ebmlHeader.header.eid /= ebmlMagicId) $
        fail ("Invalid magic: " <> show ebmlHeader.header)

    -- Next comes the header of the first segment...
    segmentHeader <- getElementHeader
    when (segmentHeader.eid /= segmentId) $
        fail ("Invalid segment: " <> show segmentHeader)

    -- ...whose children are read until the first cluster, then validated.
    segmentChildren <- getUntil streamSchema clusterId
    either (fail . Text.unpack) (const (pure ())) (WebM.decodeSegment segmentChildren)
  where
    -- Element IDs checked while reading the initialization data.
    ebmlMagicId = 0x1A45DFA3
    segmentId = 0x18538067
    clusterId = 0x1F43B675

-- | Read a cluster frame: an element header that must carry the Cluster ID
-- (0x1F43B675), followed by the cluster body (see 'getClusterBody').
getCluster :: Get.Get ()
getCluster = do
    clusterHeader <- getElementHeader
    if clusterHeader.eid == 0x1F43B675
        then getClusterBody
        else fail ("Invalid cluster: " <> show clusterHeader)

-- | Read the elements of a cluster body until the next Cluster ID
-- (0x1F43B675), and require the first of them to be the cluster timestamp
-- (element ID 0xE7).
getClusterBody :: Get.Get ()
getClusterBody =
    getUntil streamSchema 0x1F43B675 >>= \children ->
        case children of
            firstChild : _
                | firstChild.header.eid == 0xE7 -> pure ()
            _ -> fail "Cluster first element is not a timestamp"

-- | Resume decoding at an element boundary. 'feedReader' uses this decoder
-- after a previous decoder finished exactly at the end of its input buffer.
--
-- The next element is either a new Cluster (ID 0x1F43B675), which is read
-- with 'getCluster', or a left-over element of the current cluster, in which
-- case elements are read until the next cluster starts.
--
-- The element header is only peeked with 'Get.lookAhead'. Consuming it here
-- (as 'getElementHeader' does) would leave the left-over element's payload
-- unread, and 'getUntil' would then try to parse those payload bytes as EBML
-- elements.
getClusterRemaining :: Get.Get ()
getClusterRemaining = do
    elth <- Get.lookAhead getElementHeader
    if elth.eid == 0x1F43B675
        then -- This is in fact a new cluster: read its header and body.
            getCluster
        else -- Cluster left-over: read from this element until a new start.
            void (getUntil streamSchema 0x1F43B675)

-- | Initialize a stream reader: no accumulated chunks, no bytes consumed, no
-- known header, and a decoder waiting for the initialization frame.
newStreamReader :: StreamReader
newStreamReader = StreamReader mempty 0 Nothing initDecoder
  where
    initDecoder = Get.runGetIncremental getInitialization

-- | Feed a chunk of data into a stream reader.
--
-- Returns either an error, or maybe a new 'StreamFrame' together with the
-- updated 'StreamReader' to use for the next call. When one chunk contains the
-- start of several clusters, the returned frame is built from the last one.
--
-- Feeding an empty 'BS.ByteString' signals the end of input: the current
-- decoder is finalized, and an error is returned if it failed, still needs
-- more data, or leaves undecoded bytes.
feedReader :: BS.ByteString -> StreamReader -> Either Text (Maybe StreamFrame, StreamReader)
feedReader = go Nothing
  where
    go :: Maybe StreamFrame -> BS.ByteString -> StreamReader -> Either Text (Maybe StreamFrame, StreamReader)
    -- This is the end: an empty chunk finalizes the current decoder.
    go Nothing "" sr = case Get.pushEndOfInput sr.decoder of
        Get.Fail _ _ s -> Left (Text.pack s)
        Get.Partial _ -> Left "Missing data"
        Get.Done "" _ _ -> Right (Nothing, sr)
        Get.Done{} -> Left "Left-over data"
    -- Feed the decoder
    go mFrame bs sr =
        case Get.pushChunk sr.decoder bs of
            Get.Fail _ _ s -> Left (Text.pack s)
            newDecoder@(Get.Partial _) ->
                let newSR = case sr.header of
                        -- The header is not complete yet: keep the chunk so it can be
                        -- reassembled, and track how many bytes were fed so far.
                        -- The position is forced so it does not build a thunk chain
                        -- that retains every previous reader and chunk.
                        Nothing ->
                            let pos = sr.consumed + BS.length bs
                             in pos `seq` StreamReader (bs : sr.acc) pos Nothing newDecoder
                        -- Once the header is known, neither the chunks nor the position
                        -- are read again, so they are not kept.
                        Just known -> StreamReader [] 0 (Just known) newDecoder
                 in Right (mFrame, newSR)
            Get.Done leftover offset _
                | BS.null leftover ->
                    -- We might have ended on an in-cluster element, use the remainingDecoder next time
                    Right (mFrame, newIR remainingDecoder)
                | otherwise ->
                    -- There might be a new frame after, keep on decoding
                    go newFrame leftover (newIR segmentDecoder)
              where
                -- The header is either the one already parsed, or every byte fed
                -- to the initialization decoder up to the point where it finished.
                newHeader :: BS.ByteString
                newHeader = case sr.header of
                    Just known -> known
                    Nothing ->
                        -- 'offset' counts all bytes consumed by this decoder; subtract
                        -- what earlier chunks contributed to get the position in 'bs'.
                        let currentPos = fromIntegral offset - sr.consumed
                         in mconcat (reverse (BS.take currentPos bs : sr.acc))
                -- The new frame starts after what was decoded.
                newFrame :: Maybe StreamFrame
                newFrame = Just (StreamFrame newHeader leftover)
                -- A fresh reader with a known header, running the given decoder.
                newIR :: Get.Decoder () -> StreamReader
                newIR = StreamReader [] 0 (Just newHeader)

    remainingDecoder :: Get.Decoder ()
    remainingDecoder = Get.runGetIncremental getClusterRemaining

    segmentDecoder :: Get.Decoder ()
    segmentDecoder = Get.runGetIncremental getCluster