-- | gRPC-style base64-encoding
--
-- The gRPC specification mandates standard Base64-encoding for binary headers
-- <https://datatracker.ietf.org/doc/html/rfc4648#section-4>, /but/ without
-- padding.
module Network.GRPC.Spec.Serialization.Base64 (
    encodeBase64
  , decodeBase64
  ) where

import Control.Monad
import Data.ByteString qualified as BS.Strict
import Data.ByteString qualified as Strict (ByteString)
import Data.ByteString.Base64 qualified as BS.Strict.B64
import Data.ByteString.Char8 qualified as BS.Strict.Char8

{-------------------------------------------------------------------------------
  Top-level
-------------------------------------------------------------------------------}

-- | Encode a strict bytestring as unpadded Base64.
--
-- The input is first encoded with the standard (padded) Base64 encoder, after
-- which any trailing @=@ padding characters are stripped from the result.
encodeBase64 :: Strict.ByteString -> Strict.ByteString
encodeBase64 = removePadding . BS.Strict.B64.encode

-- | Decode a Base64-encoded strict bytestring that may lack padding.
--
-- The padding required by the standard decoder is restored first; the result
-- is then handed to the standard Base64 decoder. Input whose length is already
-- a multiple of four is passed to the decoder unchanged.
--
-- Returns 'Left' with an error message if the length of the input cannot
-- correspond to any Base64 encoding, or if the decoder itself rejects it.
decodeBase64 :: Strict.ByteString -> Either String Strict.ByteString
decodeBase64 = BS.Strict.B64.decode <=< addPadding

{-------------------------------------------------------------------------------
  Internal: Adding and removing padding

  In Base64 encoding, every group of three bytes in the input is represented as
  4 bytes in the output. We therefore have three possibilities:

  * The input has size @3n@. No padding bytes are added.
  * The input has size @3n + 1@. Two padding bytes are added.
  * The input has size @3n + 2@. One padding byte is added.

  Standard base64 encoding (including padding) therefore /always/ has size @4m@.
-------------------------------------------------------------------------------}

-- | Strip the trailing @=@ padding from a padded Base64 encoding.
--
-- The second-to-last byte is inspected first: if it is @=@, the last two bytes
-- are dropped. Otherwise, if the last byte is @=@, only that byte is dropped.
-- In all other cases (including the empty bytestring) the input is returned
-- unchanged.
--
-- Every index is guarded by a length check, so the function is total even for
-- inputs shorter than the 4 bytes that a non-empty padded encoding would have.
removePadding :: Strict.ByteString -> Strict.ByteString
removePadding bs
  -- Two padding bytes
  | hasPadAt 2 = BS.Strict.take (len - 2) bs
  -- One padding byte
  | hasPadAt 1 = BS.Strict.take (len - 1) bs
  -- No padding (this includes the empty bytestring)
  | otherwise  = bs
  where
    len :: Int
    len = BS.Strict.length bs

    -- Is the byte @offset@ positions from the end a padding character?
    -- The length check comes first so that 'index' is never out of bounds.
    hasPadAt :: Int -> Bool
    hasPadAt offset =
         len >= offset
      && BS.Strict.Char8.index bs (len - offset) == '='

-- | Restore the @=@ padding that the standard Base64 decoder expects.
--
-- The amount of padding is determined by the input length modulo 4:
--
-- * Remainder 1: no Base64 encoding has this length (it would need three
--   padding bytes), so the input is rejected with 'Left'.
-- * Remainder 2: two padding bytes are appended.
-- * Remainder 3: one padding byte is appended.
-- * Remainder 0: the input is returned unchanged (this includes the empty
--   bytestring).
addPadding :: Strict.ByteString -> Either String Strict.ByteString
addPadding bs =
    case len `mod` 4 of
      -- Three padding bytes (i.e., invalid input)
      1 -> Left $ "Invalid length of unpadded base64-encoded string " ++ show len
      -- Two padding bytes
      2 -> Right $ bs <> BS.Strict.Char8.pack "=="
      -- One padding byte
      3 -> Right $ bs <> BS.Strict.Char8.pack "="
      -- No padding (this includes the empty string)
      _ -> Right bs
  where
    len :: Int
    len = BS.Strict.length bs