-- SPDX-FileCopyrightText: 2022 Serokell <https://serokell.io>
-- SPDX-License-Identifier: MPL-2.0

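-- keygen's @32 <= n@ bound is a compile-time minimum-length check that its
-- body never uses, so GHC would otherwise report it as a redundant constraint.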
{-# OPTIONS_GHC -Wno-redundant-constraints #-}

module Crypto.BLST
  ( -- * Main functions
    keygen
  , skToPk
  , sign
  , verify
  , serializeSk
  , deserializeSk
  , serializePk
  , deserializePk
  , compressPk
  , decompressPk
  , serializeSignature
  , deserializeSignature
  , compressSignature
  , decompressSignature

    -- * Aggregate signatures
  , aggregateSignatures
  , aggregateVerify

    -- * Representation datatypes
  , SecretKey
  , PublicKey
  , Signature
  , B.BlstError(..)

    -- * Utility typeclasses
  , IsCurve
  , IsPoint
  , ToCurve
  , Demote

    -- * Data kinds
  , Curve(..)
  , B.EncodeMethod(..)

    -- * Typelevel byte sizes
  , ByteSize
  , SerializeOrCompress(..)

    -- * Misc helpers
  , noDST
  , byteSize
  ) where

import Control.Exception (catch, throwIO)
import Control.Monad (forM_)
import Data.ByteArray (ByteArrayAccess, Bytes, ScrubbedBytes)
import Data.ByteArray.Sized (SizedByteArray, unSizedByteArray)
import Data.Foldable (foldlM)
import Data.List.NonEmpty (NonEmpty(..))
import GHC.TypeNats (KnownNat, type (<=))
import System.IO.Unsafe (unsafePerformIO)

import Crypto.BLST.Internal.Bindings qualified as B
import Crypto.BLST.Internal.Classy
import Crypto.BLST.Internal.Demote
import Crypto.BLST.Internal.Types

-- | Generate a secret key from bytes.
--
-- The input is required (at the type level, via @32 <= n@) to be at least
-- 32 bytes long. The body never uses that constraint, which is why the module
-- disables @-Wredundant-constraints@.
--
-- NOTE(review): 'unsafePerformIO' assumes 'B.keygen' yields the same key for
-- the same input bytes — confirm against the bindings.
keygen :: (ByteArrayAccess ba, 32 <= n, KnownNat n) => SizedByteArray n ba -> SecretKey
keygen = SecretKey . unsafePerformIO . B.keygen . unSizedByteArray

-- | Convert a secret key to the corresponding public key on a given curve.
--
-- The curve is selected by the caller through the result type @PublicKey c@.
-- The public-key point is computed from the scalar and then converted to
-- affine form for storage.
skToPk :: forall c. IsCurve c => SecretKey -> PublicKey c
skToPk (SecretKey sk) = PublicKey $ unsafePerformIO $ skToPkPoint sk >>= toAffine

-- | Serialize public key (non-compressed form; see 'compressPk' for the
-- compressed one).
--
-- The output length is fixed at the type level by
-- @SerializedSize (CurveToPkPoint c)@.
serializePk
  :: forall c. IsCurve c
  => PublicKey c
  -> SizedByteArray (SerializedSize (CurveToPkPoint c)) Bytes
serializePk (PublicKey pk) = unsafePerformIO $ affSerialize pk

-- | Deserialize public key from the form produced by 'serializePk'.
--
-- On failure, returns 'Left' with the error reported by the underlying
-- decoder.
deserializePk
  :: forall c ba. (IsCurve c, ByteArrayAccess ba)
  => SizedByteArray (SerializedSize (CurveToPkPoint c)) ba
  -> Either B.BlstError (PublicKey c)
deserializePk bs = unsafePerformIO $ fmap PublicKey <$> deserialize bs

-- | Compress public key.
--
-- The output length is fixed at the type level by
-- @CompressedSize (CurveToPkPoint c)@; use 'decompressPk' to read it back.
compressPk
  :: forall c. IsCurve c
  => PublicKey c
  -> SizedByteArray (CompressedSize (CurveToPkPoint c)) Bytes
compressPk (PublicKey pk) = unsafePerformIO $ affCompress pk

-- | Decompress public key from the form produced by 'compressPk'.
--
-- On failure, returns 'Left' with the error reported by the underlying
-- decoder.
decompressPk
  :: forall c ba. (IsCurve c, ByteArrayAccess ba)
  => SizedByteArray (CompressedSize (CurveToPkPoint c)) ba
  -> Either B.BlstError (PublicKey c)
decompressPk bs = unsafePerformIO $ fmap PublicKey <$> uncompress bs

-- | Sign a single message.
--
-- The message is first mapped onto the signature curve with 'toCurve', using
-- the encoding method @m@ and the optional DST. It is then signed with the
-- secret key, and the resulting point is stored in affine form.
sign
  :: forall c m ba ba2. (ToCurve m c, ByteArrayAccess ba, ByteArrayAccess ba2)
  => SecretKey -- ^ Secret key
  -> ba -- ^ Message to sign
  -> Maybe ba2 -- ^ Optional domain separation tag
  -> Signature c m
sign (SecretKey sk) bytes dst = Signature $ unsafePerformIO $ do
  encMsg <- toCurve @m bytes dst
  signPk encMsg sk >>= toAffine

-- | Serialize message signature (non-compressed form; see
-- 'compressSignature' for the compressed one).
--
-- The output length is fixed at the type level by
-- @SerializedSize (CurveToMsgPoint c)@.
serializeSignature
  :: forall c m. IsCurve c
  => Signature c m
  -> SizedByteArray (SerializedSize (CurveToMsgPoint c)) Bytes
serializeSignature (Signature sig) = unsafePerformIO $ affSerialize sig

-- | Deserialize message signature from the form produced by
-- 'serializeSignature'.
--
-- The encoding method @m@ is not stored in the bytes; the caller picks it via
-- the result type. On failure, returns 'Left' with the decoder's error.
deserializeSignature
  :: forall c m ba. (IsCurve c, ByteArrayAccess ba)
  => SizedByteArray (SerializedSize (CurveToMsgPoint c)) ba
  -> Either B.BlstError (Signature c m)
deserializeSignature bs = unsafePerformIO $ fmap Signature <$> deserialize bs

-- | Serialize and compress message signature.
--
-- The output length is fixed at the type level by
-- @CompressedSize (CurveToMsgPoint c)@; use 'decompressSignature' to read it
-- back.
compressSignature
  :: forall c m. IsCurve c
  => Signature c m
  -> SizedByteArray (CompressedSize (CurveToMsgPoint c)) Bytes
compressSignature (Signature sig) = unsafePerformIO $ affCompress sig

-- | Decompress and deserialize message signature from the form produced by
-- 'compressSignature'.
--
-- The encoding method @m@ is chosen by the caller via the result type. On
-- failure, returns 'Left' with the decoder's error.
decompressSignature
  :: forall c m ba. (IsCurve c, ByteArrayAccess ba)
  => SizedByteArray (CompressedSize (CurveToMsgPoint c)) ba
  -> Either B.BlstError (Signature c m)
decompressSignature bs = unsafePerformIO $ fmap Signature <$> uncompress bs

-- | Verify message signature.
--
-- Returns the status code produced by the underlying verification routine.
-- NOTE(review): presumably 'B.BlstSuccess' means the signature is valid and
-- any other value describes the failure — confirm against the bindings.
verify
  :: forall c m ba ba2. (IsCurve c, Demote m, ByteArrayAccess ba, ByteArrayAccess ba2)
  => Signature c m -- ^ Signature
  -> PublicKey c -- ^ Public key of the signer
  -> ba -- ^ Message
  -> Maybe ba2 -- ^ Optional domain separation tag (must be the same as used for signing!)
  -> B.BlstError
verify (Signature sig) (PublicKey pk) bytes dst =
  unsafePerformIO $ coreVerifyPk pk sig meth bytes dst
  where
    -- Term-level value of the encoding method the signature type carries,
    -- so the message is mapped to the curve the same way as in 'sign'.
    meth = demote @m

-- | Convenience synonym for 'Nothing'. Do not use domain separation tag.
--
-- Because its type is concrete, passing it where a @Maybe ba2@ is expected
-- avoids leaving @ba2@ ambiguous, as a bare 'Nothing' would.
noDST :: Maybe Bytes
noDST = Nothing

-- | Serialize secret key.
--
-- Produces the little-endian encoding of the secret scalar. The result is
-- returned as 'ScrubbedBytes' rather than plain 'Bytes'.
serializeSk :: SecretKey -> SizedByteArray B.SkSerializeSize ScrubbedBytes
serializeSk (SecretKey sk) = unsafePerformIO $ B.lendianFromScalar sk

-- | Deserialize secret key from the little-endian scalar encoding produced by
-- 'serializeSk'.
deserializeSk :: ByteArrayAccess ba => SizedByteArray B.SkSerializeSize ba -> SecretKey
deserializeSk bs = SecretKey $ unsafePerformIO $ B.scalarFromLendian bs

-- | Aggregate multiple signatures.
--
-- The signatures are summed as curve points. The first one is lifted out of
-- affine form, each remaining signature is added to the running total, and
-- the total is converted back to affine form. A 'NonEmpty' input guarantees
-- there is always a starting point.
aggregateSignatures :: forall c m. IsCurve c => NonEmpty (Signature c m) -> Signature c m
aggregateSignatures (Signature x :| xs) = Signature . unsafePerformIO $ do
  start <- fromAffine x
  foldlM add start xs >>= toAffine
  where
    -- NOTE(review): judging by its name, 'addOrDoubleAffine' also covers the
    -- case where both operands are the same point — confirm in the bindings.
    add acc (Signature y) = addOrDoubleAffine acc y

-- | Aggregate signature verification.
--
-- Every (public key, message) pair is fed into a single pairing context. The
-- aggregate signature is attached to the first pair only.
--
-- * If any aggregation step reports a status other than 'B.BlstSuccess', that
--   status is returned as 'Left'.
-- * Otherwise the context is committed, and the outcome of the final pairing
--   check is returned as 'Right'.
aggregateVerify
  :: forall c m ba ba2. (IsCurve c, Demote m, ByteArrayAccess ba, ByteArrayAccess ba2)
  => NonEmpty (PublicKey c, ba) -- ^ Public keys with corresponding messages
  -> Signature c m -- ^ Aggregate signature
  -> Maybe ba2 -- ^ Optional domain separation tag (must be the same as used for signing!)
  -> Either B.BlstError Bool
aggregateVerify ((PublicKey pk1, msg1) :| xs) (Signature sig) dst =
  -- Non-success statuses are thrown by 'checkThrow' and turned into 'Left'
  -- here, so the first failing step aborts the whole verification.
  unsafePerformIO $ verifyAll `catch` \(err :: B.BlstError) -> pure (Left err)
  where
    verifyAll = do
      ctx <- B.pairingInit (demote @m) dst
      -- NOTE(review): the two 'True' flags are presumably the point
      -- group-membership checks — confirm against the bindings.
      checkThrow =<< pairingChkNAggrPk ctx pk1 True (Just sig) True msg1
      forM_ xs $ \(PublicKey pki, msgi) ->
        checkThrow =<< pairingChkNAggrPk ctx pki True Nothing True msgi
      B.pairingCommit ctx
      Right <$> B.pairingFinalVerify ctx

    -- Escalate any non-success status so the 'catch' above can report it.
    checkThrow :: B.BlstError -> IO ()
    checkThrow = \case
      B.BlstSuccess -> pure ()
      x -> throwIO x