{-# LANGUAGE BinaryLiterals #-}
{-# LANGUAGE OverloadedStrings #-}

module Network.QUIC.Crypto.Keys (
    defaultCipher
  , initialSecrets
  , clientInitialSecret
  , serverInitialSecret
  , aeadKey
  , initialVector
  , nextSecret
  , headerProtectionKey
  ) where

import Network.TLS hiding (Version)
import Network.TLS.Extra.Cipher
import Network.TLS.QUIC
import qualified UnliftIO.Exception as E

import Network.QUIC.Crypto.Types
import Network.QUIC.Imports
import Network.QUIC.Types

----------------------------------------------------------------

-- | Cipher suite whose hash drives the Initial-secret derivation
-- (@TLS_AES_128_GCM_SHA256@). @initialSecret'@ always uses this suite's
-- hash, independently of whatever cipher is negotiated later.
defaultCipher :: Cipher
defaultCipher = cipher_TLS13_AES128GCM_SHA256

----------------------------------------------------------------

-- | Salt fed to HKDF-Extract when deriving the Initial secrets of a QUIC
-- version.
--
-- Only 'Draft29', 'Version1' and 'Version2' have a salt. Any other version
-- yields a 'VersionIsUnknown' exception, raised via 'E.impureThrow' when the
-- result is evaluated.
initialSalt :: Version -> Salt
initialSalt ver = case ver of
    Draft29   -> "\xaf\xbf\xec\x28\x99\x93\xd2\x4c\x9e\x97\x86\xf1\x9c\x61\x11\xe0\x43\x90\xa8\x99"
    Version1  -> "\x38\x76\x2c\xf7\xf5\x59\x34\xb3\x4d\x17\x9a\xe6\xa4\xc8\x0c\xad\xcc\xbb\x7f\x0a"
    Version2  -> "\x0d\xed\xe3\xde\xf7\x00\xa6\xdb\x81\x93\x81\xbe\x6e\x26\x9d\xcb\xf9\xbd\x2e\xd9"
    Version v -> E.impureThrow (VersionIsUnknown v)

-- | Client and server Initial traffic secrets for a version, both derived
-- from the same connection ID.
initialSecrets :: Version -> CID -> TrafficSecrets InitialSecret
initialSecrets ver cid = (client, server)
  where
    client = clientInitialSecret ver cid
    server = serverInitialSecret ver cid

-- | Client Initial traffic secret: the version's Initial secret for this
-- connection ID, expanded with the @"client in"@ label.
clientInitialSecret :: Version -> CID -> ClientTrafficSecret InitialSecret
clientInitialSecret ver cid = ClientTrafficSecret (initialSecret ver cid label)
  where
    label = Label "client in"

-- | Server Initial traffic secret: the version's Initial secret for this
-- connection ID, expanded with the @"server in"@ label.
serverInitialSecret :: Version -> CID -> ServerTrafficSecret InitialSecret
serverInitialSecret ver cid = ServerTrafficSecret (initialSecret ver cid label)
  where
    label = Label "server in"

-- | Initial secret for a version and connection ID, expanded with a label.
--
-- Versions that have a salt ('Draft29', 'Version1', 'Version2') derive it
-- through @initialSecret'@. Every other version returns the placeholder
-- bytes @"not supported"@ instead of throwing.
initialSecret :: Version -> CID -> Label -> ByteString
initialSecret ver = case ver of
    Draft29  -> derived
    Version1 -> derived
    Version2 -> derived
    _        -> \_ _ -> "not supported"
  where
    -- Only demanded on the three salted branches above.
    derived = initialSecret' (initialSalt ver)

-- | Core Initial-secret derivation.
--
-- Runs HKDF-Extract over the connection ID with @salt@. It then applies
-- HKDF-Expand-Label with @label@ and an empty context. The output length
-- equals the hash digest size. The hash always comes from 'defaultCipher'.
initialSecret' :: ByteString -> CID -> Label -> ByteString
initialSecret' salt cid (Label label) =
    hkdfExpandLabel hash prk label "" (hashDigestSize hash)
  where
    hash = cipherHash defaultCipher
    prk  = hkdfExtract hash salt (fromCID cid)

-- | AEAD packet-protection key deriver for a version.
--
-- The secret is expanded via 'genKey' with one of these labels:
--
-- * @"quic key"@ for 'Draft29' and 'Version1';
-- * @"quicv2 key"@ for 'Version2';
-- * @"not supported"@ for any other version.
aeadKey :: Version -> Cipher -> Secret -> Key
aeadKey ver = case ver of
    Draft29  -> v1Key
    Version1 -> v1Key
    Version2 -> genKey (Label "quicv2 key")
    _        -> genKey (Label "not supported")
  where
    v1Key = genKey (Label "quic key")

-- | Header-protection key deriver for a version.
--
-- The secret is expanded via 'genKey' with one of these labels:
--
-- * @"quic hp"@ for 'Draft29' and 'Version1';
-- * @"quicv2 hp"@ for 'Version2';
-- * @"not supported"@ for any other version.
headerProtectionKey :: Version -> Cipher -> Secret -> Key
headerProtectionKey ver = case ver of
    Draft29  -> v1Hp
    Version1 -> v1Hp
    Version2 -> genKey (Label "quicv2 hp")
    _        -> genKey (Label "not supported")
  where
    v1Hp = genKey (Label "quic hp")

-- | Expands @secret@ with @label@ and an empty context into a key.
-- The key length is the cipher's bulk key size, and the cipher's hash is
-- used for the expansion.
genKey :: Label -> Cipher -> Secret -> Key
genKey (Label label) cipher (Secret secret) =
    Key (hkdfExpandLabel (cipherHash cipher) secret label "" keyLen)
  where
    keyLen = bulkKeySize (cipherBulk cipher)

-- | Packet-protection IV for a version, expanded from the secret with the
-- version's IV label ('ivLabel') and an empty context.
--
-- The IV length is the bulk cipher's implicit plus explicit IV size, but
-- never less than 8 bytes. This matches the TLS 1.3 rule
-- @iv_length = max(8, N_MIN)@ (RFC 8446, section 5.3).
initialVector :: Version -> Cipher -> Secret -> IV
initialVector ver cipher (Secret secret) =
    IV (hkdfExpandLabel (cipherHash cipher) secret (ivLabel ver) "" ivLen)
  where
    bulk  = cipherBulk cipher
    ivLen = max 8 (bulkIVSize bulk + bulkExplicitIV bulk)

-- | HKDF label used to derive the packet-protection IV for a version.
-- Versions other than 'Draft29', 'Version1' and 'Version2' get the
-- placeholder @"not supported"@.
ivLabel :: Version -> ByteString
ivLabel ver = case ver of
    Draft29  -> "quic iv"
    Version1 -> "quic iv"
    Version2 -> "quicv2 iv"
    _        -> "not supported"

-- | Next secret in the key-update chain.
--
-- The current secret is expanded with the version's key-update label
-- ('kuLabel') and an empty context. The output length equals the cipher's
-- hash digest size.
nextSecret :: Version -> Cipher -> Secret -> Secret
nextSecret ver cipher (Secret secN) =
    Secret (hkdfExpandLabel hash secN (kuLabel ver) "" (hashDigestSize hash))
  where
    hash = cipherHash cipher

-- | HKDF label used by 'nextSecret' for a version.
-- Versions other than 'Draft29', 'Version1' and 'Version2' get the
-- placeholder @"not supported"@.
kuLabel :: Version -> ByteString
kuLabel ver = case ver of
    Draft29  -> "quic ku"
    Version1 -> "quic ku"
    Version2 -> "quicv2 ku"
    _        -> "not supported"