{-# language BangPatterns #-}
{-# language DataKinds #-}
{-# language DuplicateRecordFields #-}
{-# language LambdaCase #-}
{-# language MultiWayIf #-}
{-# language NamedFieldPuns #-}
{-# language NumericUnderscores #-}
{-# language TypeApplications #-}

module Asn.Ber.Encode
  ( encode
  ) where

import Prelude hiding (length)

import Asn.Ber (Value(..),Contents(..),Class(..))
import Asn.Oid (Oid(..))
import Control.Monad.ST (runST)
import Data.Bits ((.&.),(.|.),unsafeShiftL,unsafeShiftR,bit,testBit)
import Data.Bytes.Types (Bytes(Bytes))
import Data.ByteString.Short.Internal (ShortByteString(SBS))
import Data.Foldable (foldMap',foldlM)
import Data.Int (Int64)
import Data.Primitive (SmallArray,PrimArray)
import Data.Primitive.ByteArray (byteArrayFromList,ByteArray(ByteArray))
import Data.Word (Word8,Word32)

import qualified Data.Primitive as Prim
import qualified Data.Primitive.Contiguous as C
import qualified Data.Bytes as Bytes
import qualified Data.Bytes.Builder.Bounded as BB
import qualified Data.Bytes.Types
import qualified Data.Text.Short as TS
import qualified Chronos
import qualified Arithmetic.Nat as Nat

-- | A tree of byte chunks. Every node carries the total size of the bytes
-- beneath it, so the size of the final output is available without walking
-- the tree.
data Encoder
  -- A single chunk of already-encoded bytes.
  = Leaf {-# UNPACK #-} !Bytes
  -- A concatenation of sub-encoders.
  | Node
      { _length :: !Int
        -- ^ Sum of the lengths of all '_children'.
      , _children :: !(SmallArray Encoder)
        -- ^ Sub-encoders, emitted from left to right.
      }
  deriving (Show)

-- | Number of bytes the encoder emits: the chunk size for a 'Leaf', the
-- cached total for a 'Node'.
length :: Encoder -> Int
length = \case
  Leaf chunk -> Bytes.length chunk
  Node{_length = total} -> total

-- | Concatenation. An operand that emits zero bytes is dropped instead of
-- being stored in the tree; otherwise both operands become the children of
-- a fresh two-element 'Node'.
instance Semigroup Encoder where
  lhs <> rhs
    | lenL == 0 = rhs
    | lenR == 0 = lhs
    | otherwise = Node
        { _length = lenL + lenR
        , _children = C.doubleton lhs rhs
        }
    where
    lenL = length lhs
    lenR = length rhs

-- | The empty encoder is a 'Node' of length zero with no children.
instance Monoid Encoder where
  mempty = Node
    { _length = 0
    , _children = mempty
    }

-- | Concatenate three encoders into a single three-child 'Node'. Unlike
-- '<>', operands of length zero are kept as children.
append3 :: Encoder -> Encoder -> Encoder -> Encoder
append3 x y z = Node
  { _length = total
  , _children = C.tripleton x y z
  }
  where
  total = length x + length y + length z

-- | Encoder that emits exactly one octet.
word8 :: Word8 -> Encoder
word8 octet = Leaf (Bytes.singleton octet)

-- | Encoder that emits the given chunk verbatim.
singleton :: Bytes -> Encoder
singleton chunk = Leaf chunk

-- | Flatten an encoder into one contiguous 'Bytes'.
--
-- A 'Leaf' is returned as is. For a 'Node', a buffer of the recorded length
-- is allocated and every leaf is copied into it from left to right. The
-- offset reached after the last copy is compared with the recorded length;
-- a mismatch means the cached lengths are inconsistent and aborts with an
-- error.
run :: Encoder -> Bytes
run = \case
  Leaf chunk -> chunk
  Node{_length = total, _children = topLevel} -> runST $ do
    buf <- Prim.newByteArray total
    -- Copies everything below the given encoder starting at @off@ and
    -- returns the offset just past the last byte written.
    let write !off = \case
          Leaf chunk -> do
            Bytes.unsafeCopy buf off chunk
            pure (Bytes.length chunk + off)
          Node{_children = kids} -> foldlM write off kids
    end <- foldlM write 0 topLevel
    if end == total
      then do
        frozen <- Prim.unsafeFreezeByteArray buf
        pure Bytes{array = frozen, offset = 0, length = total}
      else errorWithoutStackTrace "Asn.Ber.Encode.run: implementation mistake"

-- | Serialise a 'Value' (identifier, length and contents octets) to bytes.
encode :: Value -> Bytes
encode value = run (encodeValue value)

-- | Build the encoder for one value: header, then the length of the
-- contents, then the contents themselves. The contents are encoded first
-- because their size is needed for the length octets.
encodeValue :: Value -> Encoder
encodeValue value@Value{contents} = append3 header lengthOctets body
  where
  body = encodeContents contents
  header = valueHeader value
  lengthOctets = encodeLength (length body)

-- | Identifier octets of a value.
--
-- The leading octet packs the class into bits 7-6, the constructed flag
-- into bit 5 and the tag number (capped at 31) into bits 4-0. Tag numbers
-- of 31 and above are followed by the tag number in base 128.
valueHeader :: Value -> Encoder
valueHeader Value{tagClass,tagNumber,contents} = word8 leadingOctet <> highTagNumber
  where
  leadingOctet :: Word8
  leadingOctet = classBits .|. constructedBit .|. lowTagBits

  classBits :: Word8
  classBits = classIndex `unsafeShiftL` 6
    where
    classIndex :: Word8
    classIndex = case tagClass of
      Universal -> 0
      Application -> 1
      ContextSpecific -> 2
      Private -> 3

  -- Only constructed contents set the primitive/constructed bit.
  constructedBit :: Word8
  constructedBit = case contents of
    Constructed _ -> bit 5
    _ -> 0x00

  lowTagBits :: Word8
  lowTagBits = fromIntegral @Word32 @Word8 (min tagNumber 31)

  highTagNumber :: Encoder
  highTagNumber
    | tagNumber >= 31 = base128 (fromIntegral @Word32 @Int64 tagNumber)
      -- FIXME tag numbers are unsigned; this is only correct if base128
      -- emits minimal unsigned groups for non-negative input
    | otherwise = mempty

-- | Length octets for contents of @n@ bytes.
--
-- Lengths below 128 use the short form (a single octet). Larger lengths use
-- the long form: a header octet with bit 7 set whose low bits give the
-- number of length octets, followed by the length as a big-endian unsigned
-- integer in the fewest possible octets.
--
-- The length is treated as unsigned. Encoding it with the signed 'base256'
-- would prepend a superfluous @0x00@ whenever the top octet has bit 7 set
-- (for example 128 would become @82 00 80@ instead of @81 80@).
--
-- @n@ is expected to be non-negative.
encodeLength :: Int -> Encoder
encodeLength n
  | n < 128 = word8 (fromIntegral @Int @Word8 n)
  | otherwise = lenHeader <> len
  where
  n64 :: Int64
  n64 = fromIntegral @Int @Int64 n

  -- Big-endian octets of the length with leading zero octets removed.
  -- Non-empty on this branch because n >= 128.
  octets :: [Word8]
  octets = dropWhile (== 0x00)
    [ fromIntegral @Int64 @Word8 (0xFF .&. (n64 `unsafeShiftR` sh))
    | sh <- [56,48..0]
    ]

  len :: Encoder
  len = singleton (Bytes.fromByteArray (byteArrayFromList octets))

  lenHeader :: Encoder
  lenHeader = word8 (bit 7 .|. fromIntegral @Int @Word8 (length len))

-- | Contents octets for each kind of 'Contents'.
--
-- An 'ObjectIdentifier' with fewer than two components is rejected with
-- 'error'.
encodeContents :: Contents -> Encoder
encodeContents = \case
  Integer n -> base256 n
  Boolean True -> word8 0xFF
  Boolean False -> word8 0x00
  UtcTime epochSeconds -> utcTimeContents epochSeconds
  OctetString bs -> bytes bs
  -- The number of padding bits precedes the bit string payload.
  BitString padBits bs -> word8 padBits <> bytes bs
  Null -> mempty
  ObjectIdentifier (Oid arr)
    | Prim.sizeofPrimArray arr >= 2 -> objectIdentifier arr
    | otherwise -> error "Object Identifier must have at least two components"
  Utf8String str -> utf8String str
  PrintableString str -> printableString str
  Constructed arr -> constructed arr
  Unresolved raw -> bytes raw

-- | Render seconds since the epoch as the 13 ASCII characters
-- @YYMMDDhhmmssZ@.
utcTimeContents :: Int64 -> Encoder
utcTimeContents epochSeconds =
  case Chronos.timeToDatetime (Chronos.Time (epochSeconds * 1_000_000_000)) of
    Chronos.Datetime
      { datetimeDate = Chronos.Date
          { dateYear = Chronos.Year year
          , dateMonth = Chronos.Month month
          , dateDay = Chronos.DayOfMonth day
          }
      , datetimeTime = Chronos.TimeOfDay
          { timeOfDayHour = hour
          , timeOfDayMinute = minute
          , timeOfDayNanoseconds = nanoseconds
          }
      } ->
        let seconds = fromIntegral @Int64 @Int (quot nanoseconds 1_000_000_000)
         in Leaf $ Bytes.fromByteArray $ BB.run Nat.constant $
              encodeTwoDigit (rem year 100)
              `BB.append` encodeTwoDigit (month + 1) -- one is added to the month to get the calendar number
              `BB.append` encodeTwoDigit day
              `BB.append` encodeTwoDigit hour
              `BB.append` encodeTwoDigit minute
              `BB.append` encodeTwoDigit seconds
              `BB.append` BB.ascii 'Z'

-- | Two ASCII decimal digits: the tens digit, then the ones digit. The
-- argument is expected to lie in the range 0-99.
encodeTwoDigit :: Int -> BB.Builder 2
encodeTwoDigit !n = digit tens `BB.append` digit ones
  where
  tens = quot n 10
  ones = rem n 10

  -- 0x30 is ASCII '0'.
  digit :: Int -> BB.Builder 1
  digit d = BB.word8 (fromIntegral @Int @Word8 (0x30 + d))

------------------ Content Encoders ------------------

-- | Big-endian base-128 encoding: seven payload bits per octet, with bit 7
-- set on every octet except the last.
--
-- Non-negative inputs are encoded as unsigned numbers in the fewest possible
-- octets. Both callers (high tag numbers and object identifier components)
-- pass unsigned quantities, and for those a leading @0x80@ octet is not
-- allowed. The previous sign-aware loop emitted one whenever bit 6 of the
-- most significant group was set (for example 64 became @80 40@ instead of
-- @40@).
--
-- Negative inputs keep the earlier sign-aware behaviour unchanged.
base128 :: Int64 -> Encoder
base128 n0
  | n0 >= 0 = stop (unsigned [low7 n0] (n0 `unsafeShiftR` 7))
  | otherwise = signed False (0 :: Int) [] n0
  where
  low7 :: Int64 -> Word8
  low7 x = fromIntegral @Int64 @Word8 (x .&. 0x7F)

  -- Prepend continuation groups, most significant first, until nothing of
  -- the (non-negative) remainder is left.
  unsigned :: [Word8] -> Int64 -> [Word8]
  unsigned acc 0 = acc
  unsigned acc x = unsigned ((0x80 .|. low7 x) : acc) (x `unsafeShiftR` 7)

  -- Original sign-aware loop: stops once the remaining bits are pure sign
  -- extension of the group emitted last.
  signed :: Bool -> Int -> [Word8] -> Int64 -> Encoder
  signed !lastNeg !size acc n =
    let content = low7 n
        rest = n `unsafeShiftR` 7
        thisNeg = testBit content 6
        atEnd = (content == 0 && rest == 0 && not lastNeg)
              || (content == 0x7F && rest == (-1) && lastNeg)
     in if size /= 0 && atEnd
        then stop acc
        else
          let content' = (if size == 0 then 0 else 0x80) .|. content
           in signed thisNeg (size + 1) (content' : acc) rest

  stop :: [Word8] -> Encoder
  stop = singleton . Bytes.fromByteArray . byteArrayFromList

-- | Big-endian two's-complement encoding of a signed integer in the fewest
-- octets that still preserve the sign.
base256 :: Int64 -> Encoder
base256 n = singleton (Bytes.fromByteArray (byteArrayFromList trimmed))
  where
  -- All eight octets of the input, most significant first.
  octets :: [Word8]
  octets = map octetAt [56,48..0]

  octetAt :: Int -> Word8
  octetAt sh = fromIntegral @Int64 @Word8 (0xFF .&. (n `unsafeShiftR` sh))

  trimmed :: [Word8]
  trimmed
    | n < 0 = dropPadding 0xFF (`testBit` 7)
    | otherwise = dropPadding 0x00 (not . (`testBit` 7))

  -- Strip leading padding octets, then put one back if the remainder is
  -- empty or its first octet no longer carries the right sign bit.
  dropPadding :: Word8 -> (Word8 -> Bool) -> [Word8]
  dropPadding pad signOk = case dropWhile (== pad) octets of
    rest@(hd:_) | signOk hd -> rest
    rest -> pad : rest

-- | Emit raw bytes unchanged.
bytes :: Bytes -> Encoder
bytes payload = singleton payload

-- | Contents octets of an object identifier.
--
-- The first two components @a@ and @b@ are merged into the single
-- subidentifier @40 * a + b@. Every subidentifier, including that first
-- one, is written as an unsigned base-128 number in the fewest possible
-- octets.
--
-- Earlier the merged value was truncated to one octet, which corrupted
-- identifiers such as @2.999@. The remaining components went through a
-- sign-aware encoder that could prepend a forbidden @0x80@ octet.
--
-- The array must hold at least two components; indices 0 and 1 are read
-- without a bounds check here.
objectIdentifier :: PrimArray Word32 -> Encoder
objectIdentifier arr = subIdentifier firstTwo <> mconcat restComps
  where
  -- Widen to Int64 so that 40 * a + b cannot wrap around in Word32.
  component :: Int -> Int64
  component i = fromIntegral @Word32 @Int64 (Prim.indexPrimArray arr i)

  firstTwo :: Int64
  firstTwo = 40 * component 0 + component 1

  restComps :: [Encoder]
  restComps =
    [ subIdentifier (component i)
    | i <- [2 .. Prim.sizeofPrimArray arr - 1]
    ]

  -- Unsigned base 128: seven bits per octet, bit 7 set on all but the last.
  subIdentifier :: Int64 -> Encoder
  subIdentifier n =
    singleton . Bytes.fromByteArray . byteArrayFromList
      $ groups [low7 n] (n `unsafeShiftR` 7)

  low7 :: Int64 -> Word8
  low7 x = fromIntegral @Int64 @Word8 (x .&. 0x7F)

  -- Prepend continuation groups until the non-negative remainder is zero.
  groups :: [Word8] -> Int64 -> [Word8]
  groups acc 0 = acc
  groups acc x = groups ((0x80 .|. low7 x) : acc) (x `unsafeShiftR` 7)

-- | Contents octets of a UTF8String: the text's bytes, unchanged.
utf8String :: TS.ShortText -> Encoder
utf8String str = singleton (shortTextToBytes str)

-- | Contents octets of a PrintableString: the text's bytes, unchanged.
--
-- No validation is performed. UTF-8 is backwards-compatible with ASCII, so
-- the output is only correct if the input really is printable ASCII.
printableString :: TS.ShortText -> Encoder
printableString str = singleton (shortTextToBytes str)

-- | Contents octets of a constructed value: the encodings of its members,
-- concatenated in order.
constructed :: SmallArray Value -> Encoder
constructed members = foldMap' encodeValue members

-- | View the bytes of a 'TS.ShortText' as 'Bytes'. The text is already
-- UTF-8 encoded, so its backing array is only re-wrapped.
shortTextToBytes :: TS.ShortText -> Bytes
shortTextToBytes str =
  case TS.toShortByteString str of
    SBS raw -> Bytes.fromByteArray (ByteArray raw)