module Network.QUIC.Simple.Stream
  ( MessageQueues
  , streamSerialise
  , streamCodec
  ) where

import Codec.Serialise (Serialise, serialise, deserialiseIncremental)
import Codec.Serialise qualified as IDecode (IDecode(..))
import Control.Concurrent.Async (Async, async, race_)
import Control.Concurrent.STM
import Control.Exception (finally, throwIO)
import Control.Monad.ST (stToIO)
import Data.ByteString qualified as BS
import Data.ByteString.Lazy qualified as BSL
import Data.IORef
import Network.QUIC qualified as QUIC

{- | The pair of bounded queues that stands in for a stream.

The first component accepts messages to be sent over the stream; the second
component yields messages that were received from it. This is the shape
returned by 'streamCodec' and 'streamSerialise'.
-}
type MessageQueues outgoing incoming = (TBQueue outgoing, TBQueue incoming)

{- | Wrap the stream with the CBOR codec for both incoming and outgoing messages.

Outgoing messages are encoded with 'serialise'. Incoming bytes are fed to an
incremental 'deserialiseIncremental' decoder chunk by chunk. Each complete
message is emitted as soon as it has been decoded. Bytes left over after a
message are handed back to the reader, which feeds them to a fresh decoder
before reading more data.

No extra framing is required since CBOR is self-delimiting.

A malformed message makes the decoder throw its failure. That terminates the
reader, and with it the whole worker and the stream.
-}
streamSerialise
  :: forall sendMsg recvMsg
  . (Serialise sendMsg, Serialise recvMsg)
  => QUIC.Stream
  -> IO (Async (), MessageQueues sendMsg recvMsg)
streamSerialise stream = do
  let newDecoder = stToIO (deserialiseIncremental @recvMsg)
  -- Decoder state carried between chunks. Only the reader thread touches it.
  state <- newDecoder >>= newIORef
  let
    -- Inspect the decoder right after it has been fed a chunk.
    settle decoder = case decoder of
      IDecode.Done leftovers _consumed msg -> do
        -- Message complete: start over for the next one and return the
        -- unconsumed tail of the chunk to the reader.
        newDecoder >>= writeIORef state
        pure (leftovers, Just msg)
      IDecode.Partial _ -> do
        -- The message continues in a later chunk. Keep the continuation
        -- and let the reader fetch more data; no empty chunk is fed in.
        writeIORef state decoder
        pure ("", Nothing)
      IDecode.Fail _leftovers _offset err ->
        -- crash the reader (thus the writer, the stream, etc.)
        throwIO err

    decode chunk = do
      decoder <- readIORef state
      case decoder of
        IDecode.Partial consume ->
          -- step decoder with exactly this chunk, then look at the outcome
          stToIO (consume (Just chunk)) >>= settle
        IDecode.Done leftovers _consumed msg -> do
          -- The decoder completed without asking for input, so this chunk
          -- belongs to the next message: keep it instead of dropping it.
          newDecoder >>= writeIORef state
          pure (leftovers <> chunk, Just msg)
        IDecode.Fail _leftovers _offset err ->
          throwIO err
  streamCodec serialise decode stream

{- | Wrap the stream with a codec to provide a 'TBQueue' interface to it.

Returns the worker thread together with the queues. Push messages into the
first queue to have them encoded and sent, and pop decoded messages from the
second one. Both queues are bounded at 1024 messages.

The worker runs a reader and a writer concurrently. When either of them
finishes or throws, the other one is cancelled and the stream is closed.

The reader receives the stream in chunks of up to 4096 bytes and passes them
to the decoder. The decoder returns the bytes it did not consume, plus an
optional complete message. Unconsumed bytes are handed back to it before any
more data is read from the stream. A decoder must therefore make progress on
every call, otherwise the reader keeps re-submitting the same leftovers.

The codec itself keeps no state between chunks. The decoder runs in 'IO',
though, so it can use external state and terminate the stream by throwing.

Once 'QUIC.recvStream' returns an empty chunk (the peer has finished its
side of the stream), the reader stops. This also stops the writer and closes
the stream; messages still waiting in the outgoing queue are not sent.
-}
streamCodec
  :: (sendMsg -> BSL.ByteString) -- ^ Encoder for outgoing messages
  -> (BS.ByteString -> IO (BS.ByteString, Maybe recvMsg)) -- ^ Decoder for incoming chunks
  -> QUIC.Stream
  -> IO (Async (), MessageQueues sendMsg recvMsg)
streamCodec encode decode stream = do
  readQ <- newTBQueueIO 1024
  writeQ <- newTBQueueIO 1024
  worker <- async $
    race_ (reader "" readQ) (writer writeQ) `finally` QUIC.closeStream stream
  pure (worker, (writeQ, readQ))
  where
    reader leftovers readQ = do
      chunk <-
        if BS.null leftovers then
          QUIC.recvStream stream 4096
        else
          pure leftovers
      if BS.null chunk then
        -- recvStream yields an empty chunk once FIN has been received and
        -- keeps doing so. Stop here instead of spinning on the finished
        -- stream; race_ then cancels the writer and the stream is closed.
        pure ()
      else do
        (leftovers', message_) <- decode chunk
        mapM_ (atomically . writeTBQueue readQ) message_
        reader leftovers' readQ

    writer writeQ = do
      message <- atomically $ readTBQueue writeQ
      -- send every chunk of the lazy encoding in a single call
      QUIC.sendStreamMany stream (BSL.toChunks (encode message))
      writer writeQ
