{-# language DataKinds #-}
{-# language RankNTypes #-}
{-# language MagicHash #-}
{-# language TypeApplications #-}
{-# language KindSignatures #-}
{-# language ScopedTypeVariables #-}
{-# language PatternSynonyms #-}
{-# language GADTs #-}

-- | Encode data in the protobuf wire format. Under the hood, this uses
-- catenable byte sequences, meaning that a large number of small
-- @ByteArray@s are allocated during encoding. That makes this a simple but
-- low-performance solution.
-- To prevent incorrect use, @Builder@ is indexed by a type of kind @Value@.
-- This distinguishes encodings of primitives from encodings of sequences of
-- key-value pairs. The functions 'pair' and 'message' are used to move
-- between these two domains: 'pair' tags a primitive with a field number,
-- producing a key-value pair, and 'message' wraps a sequence of pairs as a
-- length-delimited primitive.
module Protobuf.Builder
  ( Builder(..)
  , WireType(..)
  , Value(..)
  , run
    -- * Variable Length
  , variableWord8
  , variableWord16
  , variableWord32
  , variableWord64
  , sint64
  , int64
  , sint32
  , int32
    -- * Fixed Length 32
  , fixed32
    -- * Fixed Length 64
  , fixed64
  , double
    -- * Messages
  , message
  , pair
    -- * Length-Delimited 
  , shortText
  , shortByteString
  ) where

import Control.Monad.ST.Run (runByteArrayST)
import Data.Word (Word8,Word16,Word32,Word64)
import Data.Int (Int32,Int64)
import GHC.Exts (Proxy#,proxy#)
import GHC.Float (castDoubleToWord64)
import Data.Bits ((.|.),unsafeShiftL)
import Data.Builder.Catenable.Bytes (pattern (:<))
import Data.Bytes (Bytes)
import Data.Word.Zigzag (toZigzag32,toZigzag64)
import Data.Text.Short (ShortText)
import Data.ByteString.Short (ShortByteString)

import qualified Data.Bytes as Bytes
import qualified Data.Bytes.Chunks as Chunks
import qualified Data.Bytes.Builder.Bounded as Bounded
import qualified Data.Builder.Catenable.Bytes as Builder
import qualified Data.Primitive as PM
import qualified Data.Kind as GHC
import qualified Arithmetic.Nat as Nat
import qualified Data.ByteString.Short as SBS
import qualified Data.Text.Short as TS

-- | An encoder for protobuf data, indexed by the kind of 'Value' it
-- produces. The data constructor is exported, but building one by hand is
-- unsafe: the cached byte count is trusted as-is when length prefixes are
-- written (for example by 'message'), so it must match the encoded bytes
-- exactly.
--
-- Only @Builder 'Pairs@ has @Semigroup@ and @Monoid@ instances. This is
-- an important part of the interface.
data Builder :: Value -> GHC.Type where
  Builder
    :: !Int             -- total number of bytes this builder produces
    -> !Builder.Builder -- the underlying catenable byte sequence
    -> Builder v

-- | Flatten a builder into one contiguous byte sequence. The cached length
-- is not consulted; the chunks are simply concatenated.
run :: Builder v -> Bytes
run builder = case builder of
  Builder _ chunks -> Chunks.concat (Builder.run chunks)

-- | The four protobuf wire types supported by this module. The numeric tag
-- for each one comes from the 'HasWireTypeNumber' instances.
data WireType
  = BitsFixed32  -- ^ 32-bit fixed width (wire type 5)
  | BitsFixed64  -- ^ 64-bit fixed width (wire type 1)
  | BitsVariable -- ^ varint (wire type 0)
  | Bytes        -- ^ length-delimited (wire type 2)

-- | The index of a 'Builder'. It is either a single primitive, tagged with
-- its wire type, or a collection of already-encoded key-value pairs.
data Value
  = Primitive WireType -- ^ a single value of the given wire type
  | Pairs              -- ^ zero or more encoded key-value pairs

-- | Encode a protobuf @sint32@. The value is zigzag-encoded first, so
-- small-magnitude negatives stay short, and the result is written as a
-- varint.
sint32 :: Int32 -> Builder ('Primitive 'BitsVariable)
sint32 = variableWord32 . toZigzag32

-- | Encode a protobuf @int32@ as a varint.
--
-- The protobuf spec says negative @int32@ values are sign-extended to
-- 64 bits, so they always occupy ten bytes. They are therefore encoded
-- exactly like 'int64'. Converting through 'Word32', as was done before,
-- zero-extends instead and produces a non-canonical 5-byte encoding.
int32 :: Int32 -> Builder ('Primitive 'BitsVariable)
int32 w = int64 (fromIntegral @Int32 @Int64 w)

-- | Encode a protobuf @sint64@. The value is zigzag-encoded and then
-- written as a varint.
sint64 :: Int64 -> Builder ('Primitive 'BitsVariable)
sint64 = variableWord64 . toZigzag64

-- | Encode a protobuf @int64@. The two's-complement bits are reinterpreted
-- as unsigned and written as a varint, so a negative number always takes
-- the maximum ten bytes.
int64 :: Int64 -> Builder ('Primitive 'BitsVariable)
int64 = variableWord64 . fromIntegral

-- | Encode an unsigned 8-bit value as a varint by widening it to 64 bits.
variableWord8 :: Word8 -> Builder ('Primitive 'BitsVariable)
variableWord8 = variableWord64 . fromIntegral

-- | Encode an unsigned 16-bit value as a varint by widening it to 64 bits.
variableWord16 :: Word16 -> Builder ('Primitive 'BitsVariable)
variableWord16 = variableWord64 . fromIntegral

-- | Encode an unsigned 32-bit value (a protobuf @uint32@) as a varint by
-- widening it to 64 bits.
variableWord32 :: Word32 -> Builder ('Primitive 'BitsVariable)
variableWord32 = variableWord64 . fromIntegral

-- | Encode an unsigned 64-bit value (a protobuf @uint64@) as an LEB128
-- varint. The encoding is between one and ten bytes long.
variableWord64 :: Word64 -> Builder ('Primitive 'BitsVariable)
variableWord64 w =
  Builder (PM.sizeofByteArray encoded) (Bytes.fromByteArray encoded :< Builder.Empty)
  where
    -- The bounded builder's size bound (10) is fixed by word64LEB128.
    encoded = Bounded.run Nat.constant (Bounded.word64LEB128 w)

-- | Encode a protobuf @fixed32@: exactly four bytes, little-endian.
fixed32 :: Word32 -> Builder ('Primitive 'BitsFixed32)
fixed32 w =
  Builder (PM.sizeofByteArray le) (Bytes.fromByteArray le :< Builder.Empty)
  where
    le = Bounded.run Nat.constant (Bounded.word32LE w)

-- | Encode a protobuf @fixed64@: exactly eight bytes, little-endian.
fixed64 :: Word64 -> Builder ('Primitive 'BitsFixed64)
fixed64 w =
  let le = Bounded.run Nat.constant (Bounded.word64LE w)
      size = PM.sizeofByteArray le
   in Builder size (Bytes.fromByteArray le :< Builder.Empty)

-- | Encode a protobuf @double@: the IEEE 754 bit pattern of the value,
-- written as eight little-endian bytes.
--
-- The previous implementation wrote the 'Double' into a byte array with
-- 'PM.writeByteArray', which uses host byte order and is therefore wrong on
-- big-endian machines. Going through 'fixed64' makes the byte order
-- explicit.
double :: Double -> Builder ('Primitive 'BitsFixed64)
double d = fixed64 (castDoubleToWord64 d)

-- | Attach a field number to a primitive value, producing a single encoded
-- key-value pair. The key written on the wire is
-- @(fieldNumber << 3) .|. wireType@, encoded as an LEB128 varint and placed
-- before the value's bytes.
pair :: forall (ty :: WireType). HasWireTypeNumber ty
  => Word32 -- ^ key
  -> Builder ('Primitive ty) -- ^ value
  -> Builder 'Pairs
{-# inline pair #-}
pair fieldNumber (Builder payloadLen payload) =
  Builder
    (payloadLen + PM.sizeofByteArray tag)
    (Builder.Cons (Bytes.fromByteArray tag) payload)
  where
    -- Widen before shifting so that no high bits of the field number are
    -- lost.
    tagWord :: Word64
    tagWord =
      unsafeShiftL (fromIntegral @Word32 @Word64 fieldNumber) 3
        .|. fromIntegral @Word8 @Word64 (wireTypeNumber (proxy# @ty))
    tag = Bounded.run Nat.constant (Bounded.word64LEB128 tagWord)

-- | Wrap a sequence of encoded pairs as a length-delimited value, suitable
-- for use as an embedded message. The cached byte count of the pairs
-- becomes a varint length prefix.
message :: Builder 'Pairs -> Builder ('Primitive 'Bytes)
message (Builder bodyLen body) =
  Builder
    (bodyLen + PM.sizeofByteArray prefix)
    (Builder.Cons (Bytes.fromByteArray prefix) body)
  where
    prefix = Bounded.run Nat.constant (Bounded.word64LEB128 (fromIntegral bodyLen))

-- | Encode text as a length-delimited value. The text's underlying bytes
-- are used as-is.
shortText :: ShortText -> Builder ('Primitive 'Bytes)
shortText = shortByteString . TS.toShortByteString

-- | Encode raw bytes as a length-delimited value: a varint byte count
-- followed by the bytes themselves.
shortByteString :: ShortByteString -> Builder ('Primitive 'Bytes)
shortByteString sbs =
  Builder
    (n + PM.sizeofByteArray prefix)
    (Bytes.fromByteArray prefix :< Bytes.fromShortByteString sbs :< Builder.Empty)
  where
    n = SBS.length sbs
    prefix = Bounded.run Nat.constant (Bounded.word64LEB128 (fromIntegral n))

-- | The empty collection of key-value pairs, which encodes to zero bytes.
instance (v ~ 'Pairs) => Monoid (Builder v) where
  mempty =
    let noBytes = 0
     in Builder noBytes Builder.Empty

-- | Concatenate two collections of key-value pairs. The cached lengths are
-- summed and the underlying byte sequences are appended.
instance (v ~ 'Pairs) => Semigroup (Builder v) where
  (<>) (Builder lenL chunksL) (Builder lenR chunksR) =
    Builder (lenL + lenR) (chunksL <> chunksR)

-- | Maps a promoted 'WireType' to the numeric wire-type tag that 'pair'
-- places in the low three bits of a key.
class HasWireTypeNumber (t :: WireType) where
  -- | The tag for @t@. The proxy carries only the type; its value is ignored.
  wireTypeNumber :: Proxy# t -> Word8

-- | 32-bit fixed-width values use wire type 5.
instance HasWireTypeNumber 'BitsFixed32 where
  wireTypeNumber = \_ -> 5

-- | 64-bit fixed-width values use wire type 1.
instance HasWireTypeNumber 'BitsFixed64 where
  wireTypeNumber = \_ -> 1

-- | Varints use wire type 0.
instance HasWireTypeNumber 'BitsVariable where
  wireTypeNumber = \_ -> 0

-- | Length-delimited values use wire type 2.
instance HasWireTypeNumber 'Bytes where
  wireTypeNumber = \_ -> 2