{-# LANGUAGE DataKinds #-}
{-# LANGUAGE GADTs #-}
{-# LANGUAGE StandaloneDeriving #-}

-- | PASETO cryptographic keys.
module Crypto.Paseto.Keys
  ( -- * Symmetric keys
    SymmetricKey (..)
  , symmetricKeyToBytes
  , bytesToSymmetricKeyV3
  , bytesToSymmetricKeyV4
  , generateSymmetricKeyV3
  , generateSymmetricKeyV4

    -- * Asymmetric keys
    -- ** Signing keys
  , SigningKey (..)
  , signingKeyToBytes
  , bytesToSigningKeyV3
  , bytesToSigningKeyV4
  , generateSigningKeyV3
  , generateSigningKeyV4
    -- ** Verification keys
  , VerificationKey (..)
  , verificationKeyToBytes
  , bytesToVerificationKeyV3
  , bytesToVerificationKeyV4
  , fromSigningKey
  ) where

import qualified Crypto.Error as Crypto
import qualified Crypto.Paseto.Keys.V3 as V3
import Crypto.Paseto.Mode ( Version (..) )
import Crypto.Paseto.ScrubbedBytes
  ( ScrubbedBytes32 (..), generateScrubbedBytes32, mkScrubbedBytes32 )
import qualified Crypto.PubKey.Ed25519 as Crypto.Ed25519
import Data.ByteArray ( ScrubbedBytes, constEq )
import qualified Data.ByteArray as BA
import Data.ByteString ( ByteString )
import Prelude

------------------------------------------------------------------------------
-- Symmetric keys
------------------------------------------------------------------------------

-- | Symmetric key.
--
-- The type index @v@ records which PASETO protocol version the key belongs
-- to, so a version 3 key cannot be passed where a version 4 key is expected.
--
-- Note that this type's 'Eq' instance performs a constant-time equality
-- check.
data SymmetricKey v where
  -- | Version 3 symmetric key.
  SymmetricKeyV3 :: !ScrubbedBytes32 -> SymmetricKey V3

  -- | Version 4 symmetric key.
  SymmetricKeyV4 :: !ScrubbedBytes32 -> SymmetricKey V4

-- | Constant-time equality on the wrapped key bytes (via 'constEq').
--
-- Both arguments share the index @v@, so once the first constructor is
-- matched the second argument can only be the same constructor; the two
-- equations are therefore exhaustive.
instance Eq (SymmetricKey v) where
  SymmetricKeyV3 x == SymmetricKeyV3 y = x `constEq` y
  SymmetricKeyV4 x == SymmetricKeyV4 y = x `constEq` y

-- | Get the raw bytes associated with a symmetric key.
--
-- The 'ScrubbedBytes' wrapped by the key's 'ScrubbedBytes32' are returned
-- unchanged, regardless of the key's version.
symmetricKeyToBytes :: SymmetricKey v -> ScrubbedBytes
symmetricKeyToBytes k =
  case k of
    SymmetricKeyV3 (ScrubbedBytes32 bs) -> bs
    SymmetricKeyV4 (ScrubbedBytes32 bs) -> bs

-- | Construct a version 3 symmetric key from bytes.
--
-- If the provided byte string does not have a length of @32@ (@256@ bits),
-- 'Nothing' is returned.
bytesToSymmetricKeyV3 :: ScrubbedBytes -> Maybe (SymmetricKey V3)
bytesToSymmetricKeyV3 = fmap SymmetricKeyV3 . mkScrubbedBytes32

-- | Construct a version 4 symmetric key from bytes.
--
-- If the provided byte string does not have a length of @32@ (@256@ bits),
-- 'Nothing' is returned.
bytesToSymmetricKeyV4 :: ScrubbedBytes -> Maybe (SymmetricKey V4)
bytesToSymmetricKeyV4 = fmap SymmetricKeyV4 . mkScrubbedBytes32

-- | Randomly generate a version 3 symmetric key.
--
-- The key material is obtained from 'generateScrubbedBytes32'.
generateSymmetricKeyV3 :: IO (SymmetricKey V3)
generateSymmetricKeyV3 = SymmetricKeyV3 <$> generateScrubbedBytes32

-- | Randomly generate a version 4 symmetric key.
--
-- The key material is obtained from 'generateScrubbedBytes32'.
generateSymmetricKeyV4 :: IO (SymmetricKey V4)
generateSymmetricKeyV4 = SymmetricKeyV4 <$> generateScrubbedBytes32

------------------------------------------------------------------------------
-- Asymmetric keys
------------------------------------------------------------------------------

-- | Signing key (also known as a private\/secret key).
--
-- Note that this type's 'Eq' instance performs a constant-time equality
-- check.
data SigningKey v where
  -- | Version 3 signing key (a P-384 private key).
  SigningKeyV3 :: !V3.PrivateKeyP384 -> SigningKey V3

  -- | Version 4 signing key (an Ed25519 secret key).
  SigningKeyV4 :: !Crypto.Ed25519.SecretKey -> SigningKey V4

-- | Constant-time equality, performed with 'constEq' on the byte
-- encodings produced by 'signingKeyToBytes'.
instance Eq (SigningKey v) where
  x == y = signingKeyToBytes x `constEq` signingKeyToBytes y

-- | Get the raw bytes associated with a signing key.
--
-- Version 3 keys are serialized with 'V3.encodePrivateKeyP384'; version 4
-- keys are converted byte-for-byte from the Ed25519 secret key.
signingKeyToBytes :: SigningKey v -> ScrubbedBytes
signingKeyToBytes sk =
  case sk of
    SigningKeyV3 k -> V3.encodePrivateKeyP384 k
    SigningKeyV4 k -> BA.convert k

-- | Construct a version 3 signing key from bytes.
--
-- Decoding is delegated to 'V3.decodePrivateKeyP384'; any
-- 'V3.ScalarDecodingError' it reports is returned unchanged.
bytesToSigningKeyV3
  :: ScrubbedBytes
  -> Either V3.ScalarDecodingError (SigningKey V3)
bytesToSigningKeyV3 bs = SigningKeyV3 <$> V3.decodePrivateKeyP384 bs

-- | Construct a version 4 signing key from bytes.
--
-- Returns 'Nothing' if 'Crypto.Ed25519.secretKey' rejects the input.
bytesToSigningKeyV4 :: ScrubbedBytes -> Maybe (SigningKey V4)
bytesToSigningKeyV4 bs =
  SigningKeyV4 <$> Crypto.maybeCryptoError (Crypto.Ed25519.secretKey bs)

-- | Randomly generate a version 3 signing key.
--
-- The underlying P-384 private key comes from 'V3.generatePrivateKeyP384'.
generateSigningKeyV3 :: IO (SigningKey V3)
generateSigningKeyV3 = SigningKeyV3 <$> V3.generatePrivateKeyP384

-- | Randomly generate a version 4 signing key.
--
-- The underlying Ed25519 secret key comes from
-- 'Crypto.Ed25519.generateSecretKey', run in 'IO'.
generateSigningKeyV4 :: IO (SigningKey V4)
generateSigningKeyV4 = SigningKeyV4 <$> Crypto.Ed25519.generateSecretKey

-- | Verification key (also known as a public key).
--
-- Like the other key types, the index @v@ pins the PASETO version.
data VerificationKey v where
  -- | Version 3 verification key (a P-384 public key).
  VerificationKeyV3 :: !V3.PublicKeyP384 -> VerificationKey V3

  -- | Version 4 verification key (an Ed25519 public key).
  VerificationKeyV4 :: !Crypto.Ed25519.PublicKey -> VerificationKey V4

-- | Equality defers to the wrapped key type's own '=='; no explicit
-- constant-time comparison is performed here.
--
-- Both sides share the index @v@, so the two equations are exhaustive.
instance Eq (VerificationKey v) where
  VerificationKeyV3 a == VerificationKeyV3 b = a == b
  VerificationKeyV4 a == VerificationKeyV4 b = a == b

-- | Get the raw bytes associated with a verification key.
--
-- Version 3 keys are serialized with 'V3.encodePublicKeyP384'; version 4
-- keys are converted byte-for-byte from the Ed25519 public key.
verificationKeyToBytes :: VerificationKey v -> ByteString
verificationKeyToBytes vk =
  case vk of
    VerificationKeyV3 k -> V3.encodePublicKeyP384 k
    VerificationKeyV4 k -> BA.convert k

-- | Construct a version 3 verification key from bytes.
--
-- The input 'ByteString' is expected to be formatted as either a compressed
-- or uncompressed elliptic curve public key as defined by
-- [SEC 1](https://www.secg.org/sec1-v2.pdf) and
-- [RFC 5480 section 2.2](https://datatracker.ietf.org/doc/html/rfc5480#section-2.2).
--
-- Any 'V3.PublicKeyP384DecodingError' reported by 'V3.decodePublicKeyP384'
-- is returned unchanged.
bytesToVerificationKeyV3
  :: ByteString
  -> Either V3.PublicKeyP384DecodingError (VerificationKey V3)
bytesToVerificationKeyV3 bs = VerificationKeyV3 <$> V3.decodePublicKeyP384 bs

-- | Construct a version 4 verification key from bytes.
--
-- Returns 'Nothing' if 'Crypto.Ed25519.publicKey' rejects the input.
bytesToVerificationKeyV4 :: ByteString -> Maybe (VerificationKey V4)
bytesToVerificationKeyV4 bs =
  VerificationKeyV4 <$> Crypto.maybeCryptoError (Crypto.Ed25519.publicKey bs)

-- | Get the 'VerificationKey' which corresponds to a given 'SigningKey'.
--
-- The version index @v@ is preserved: a version 3 signing key yields a
-- version 3 verification key, and likewise for version 4.
fromSigningKey :: SigningKey v -> VerificationKey v
fromSigningKey sk =
  case sk of
    SigningKeyV3 k -> VerificationKeyV3 (V3.fromPrivateKeyP384 k)
    SigningKeyV4 k -> VerificationKeyV4 (Crypto.Ed25519.toPublic k)