{-# LANGUAGE ImportQualifiedPost #-}
{-# LANGUAGE LambdaCase #-}
{-# LANGUAGE NamedFieldPuns #-}
{-# LANGUAGE OverloadedStrings #-}
{-# LANGUAGE RecordWildCards #-}

-- | Serializers for the RON binary frame format
module RON.Binary.Serialize (
    serialize,
    serializeAtom,
    serializeString,
) where

import RON.Prelude

import Data.Binary qualified as Binary
import Data.Binary.Put (putDoublebe, runPut)
import Data.Bits (bit, shiftL, (.|.))
import Data.ByteString.Lazy (cons, fromStrict)
import Data.ByteString.Lazy qualified as BSL
import Data.Text.Encoding (encodeUtf8)
import Data.ZigZag (zzEncode)

import RON.Binary.Types (Desc (..), Size, descIsOp)
import RON.Types (
    Atom (AFloat, AInteger, AString, AUuid),
    ClosedOp (..),
    Op (..),
    UUID (UUID),
    WireChunk (Closed, Query, Value),
    WireFrame,
    WireReducedChunk (..),
 )
import RON.Util.Word (Word4, b0000, leastSignificant4, safeCast)

-- | Serialize a frame.
--
-- The output starts with the @"RON2"@ magic. Each serialized chunk follows,
-- preceded by a length header encoded as a 'Size'. Every header except the
-- last one has bit 31 set, meaning another chunk follows. An empty frame
-- produces a single zero-length header.
--
-- Fails if any chunk fails to serialize, or if a chunk is @2^31@ bytes or
-- longer.
serialize :: WireFrame -> Either String ByteStringL
serialize chunks = do
    body <- joinChunks =<< traverse serializeChunk chunks
    pure $ "RON2" <> body
  where
    -- Build the length header for one chunk; @more@ sets the continuation bit.
    lengthHeader :: Bool -> Int64 -> Either String ByteStringL
    lengthHeader more len
        | len >= bit 31 = Left $ "chunk size is too big: " ++ show len
        | more = Right $ Binary.encode (size .|. bit 31)
        | otherwise = Right $ Binary.encode size
      where
        size = fromIntegral len :: Size

    -- Prefix every chunk with its header; only the last header lacks the
    -- continuation bit.
    joinChunks :: [ByteStringL] -> Either String ByteStringL
    joinChunks = \case
        [] -> lengthHeader False 0
        [final] -> do
            header <- lengthHeader False (BSL.length final)
            pure $ header <> final
        chunk : rest -> do
            header <- lengthHeader True (BSL.length chunk)
            remainder <- joinChunks rest
            pure $ header <> chunk <> remainder

-- | Serialize a single chunk of a frame.
--
-- Closed ops are written with the 'DOpClosed' descriptor. Query and value
-- chunks are both reduced chunks; the flag passed to
-- 'serializeReducedChunk' selects which header descriptor is used.
serializeChunk :: WireChunk -> Either String ByteStringL
serializeChunk chunk = case chunk of
    Closed closedOp -> serializeClosedOp DOpClosed closedOp
    Query reduced -> serializeReducedChunk True reduced
    Value reduced -> serializeReducedChunk False reduced

-- | Serialize a closed op under the descriptor @desc@.
--
-- The op body is made of two parts, in this order:
--
-- * the four key UUIDs (reducer, object, op id, ref), each wrapped in its
--   own key descriptor;
-- * the serialized payload atoms.
--
-- The whole body is then wrapped with @desc@ by 'serializeWithDesc'.
serializeClosedOp :: Desc -> ClosedOp -> Either String ByteStringL
serializeClosedOp desc ClosedOp{reducerId, objectId, op} = do
    keyBytes <-
        traverse
            keyField
            [ (DUuidReducer, reducerId)
            , (DUuidObject, objectId)
            , (DUuidOp, opId)
            , (DUuidRef, refId)
            ]
    atomBytes <- traverse serializeAtom payload
    serializeWithDesc desc . fold $ keyBytes ++ atomBytes
  where
    Op{opId, refId, payload} = op

    -- Wrap one key UUID in its key-specific descriptor.
    keyField (keyDesc, uuid) = serializeWithDesc keyDesc (serializeUuid uuid)

-- | Serialize an op taken from a reduced chunk body.
--
-- The op is combined with the given reducer and object ids into a
-- 'ClosedOp', which is then serialized under descriptor @d@.
serializeReducedOp :: Desc -> UUID -> UUID -> Op -> Either String ByteStringL
serializeReducedOp d reducerId objectId op = serializeClosedOp d closed
  where
    closed = ClosedOp{reducerId, objectId, op}

-- | Serialize a 'UUID' by applying 'Binary.encode' to each of its two
-- components, in order.
serializeUuid :: UUID -> ByteStringL
serializeUuid (UUID u1 u2) = fold [Binary.encode u1, Binary.encode u2]

-- | Encode a descriptor as a 'Word4'. The value is 'leastSignificant4'
-- applied to the descriptor's 'Enum' index.
encodeDesc :: Desc -> Word4
encodeDesc d = leastSignificant4 (fromEnum d)

-- | Prepend serialized bytes with a descriptor byte and, for strings, an
-- extended length field.
--
-- The descriptor byte holds the descriptor code in its high nibble and a
-- short length in its low nibble. The nibble is chosen as follows:
--
-- * Op descriptors: the length nibble is always zero.
-- * String atoms of 1..15 bytes: the length goes in the nibble.
-- * Empty strings, and strings of 16 bytes up to @2^31 - 1@ bytes: a zero
--   nibble, followed by an extended length. The extended length is a single
--   'Word8' below 128 bytes, otherwise a 'Word32' with bit 31 set.
-- * Other descriptors with a 0..15-byte body: the length goes in the nibble.
-- * Other descriptors with a 16-byte body: a zero nibble.
--
-- Returns 'Left' for strings of @2^31@ bytes or more. Also returns 'Left'
-- for non-op, non-string bodies longer than 16 bytes; that message names
-- the body length and the descriptor's 'Enum' index.
serializeWithDesc ::
    Desc ->
    -- | body
    ByteStringL ->
    Either String ByteStringL
serializeWithDesc d body = do
    (lengthDesc, lengthExtended) <- lengthFields
    -- High nibble: descriptor code; low nibble: short length (or zero).
    let descByte = safeCast (encodeDesc d) `shiftL` 4 .|. safeCast lengthDesc
    pure $ descByte `cons` lengthExtended <> body
  where
    len = BSL.length body
    lengthFields = case d of
        DAtomString
            -- A string with a zero nibble is always followed by an explicit
            -- extended length, including the empty string.
            | len == 0 -> Right (b0000, mkLengthExtended)
            | len < 16 -> Right (leastSignificant4 len, BSL.empty)
            | len < bit 31 -> Right (b0000, mkLengthExtended)
            | otherwise -> Left "String is too long"
        _
            | descIsOp d -> Right (b0000, BSL.empty)
            | len < 16 -> Right (leastSignificant4 len, BSL.empty)
            -- 16-byte bodies are written with a zero length nibble.
            | len == 16 -> Right (b0000, BSL.empty)
            -- Reachable when a caller passes an oversized body; report
            -- enough detail to locate the offending value.
            | otherwise ->
                Left $
                    "serializeWithDesc: unexpected body length "
                        ++ show len
                        ++ " for descriptor #"
                        ++ show (fromEnum d)
    -- Extended length: one byte below 128, otherwise a 'Word32' with bit 31
    -- set.
    mkLengthExtended
        | len < 128 = Binary.encode (fromIntegral len :: Word8)
        | otherwise = Binary.encode (fromIntegral len .|. bit 31 :: Word32)

-- | Serialize an 'Atom', wrapped in its matching atom descriptor.
--
-- How each atom body is produced:
--
-- * Floats: 'serializeFloat'.
-- * Integers: zigzag-encoded to a 'Word64', then written with
--   'Binary.encode'.
-- * Strings: 'serializeString' (UTF-8).
-- * UUIDs: 'serializeUuid'.
serializeAtom :: Atom -> Either String ByteStringL
serializeAtom atom = serializeWithDesc desc atomBody
  where
    (desc, atomBody) = case atom of
        AFloat f -> (DAtomFloat, serializeFloat f)
        AInteger i -> (DAtomInteger, Binary.encode (zigzag i))
        AString s -> (DAtomString, serializeString s)
        AUuid u -> (DAtomUuid, serializeUuid u)

    {-# INLINE zigzag #-}
    zigzag :: Int64 -> Word64
    zigzag = zzEncode

-- | Serialize a float atom body as a big-endian double, using
-- 'putDoublebe'.
serializeFloat :: Double -> ByteStringL
serializeFloat x = runPut (putDoublebe x)

-- | Serialize a reduced chunk: its header op followed by all body ops.
--
-- The header is written as a closed op. It uses 'DOpQueryHeader' when
-- @isQuery@ is set, and 'DOpHeader' otherwise. Each body op is written with
-- 'DOpReduced', using the reducer and object ids taken from the header.
serializeReducedChunk :: Bool -> WireReducedChunk -> Either String ByteStringL
serializeReducedChunk isQuery WireReducedChunk{wrcHeader, wrcBody} = do
    headerBytes <- serializeClosedOp headerDesc wrcHeader
    bodyBytes <- foldMapA serializeBodyOp wrcBody
    pure $ headerBytes <> bodyBytes
  where
    headerDesc
        | isQuery = DOpQueryHeader
        | otherwise = DOpHeader
    ClosedOp{reducerId, objectId} = wrcHeader
    serializeBodyOp = serializeReducedOp DOpReduced reducerId objectId

-- | Serialize a string atom body as its UTF-8 bytes, returned as a lazy
-- bytestring.
serializeString :: Text -> ByteStringL
serializeString text = fromStrict (encodeUtf8 text)

-- | An effectful 'foldMap'. Runs @f@ on every element in order, stopping at
-- the first failure of the 'Applicative', and combines the results with the
-- 'Monoid'.
foldMapA :: (Applicative f, Foldable t, Monoid b) => (a -> f b) -> t a -> f b
foldMapA f xs = fold <$> traverse f (toList xs)
