{-# LANGUAGE BinaryLiterals #-}
{-# LANGUAGE OverloadedStrings #-}

module Network.QUIC.Crypto (
  -- * Payload encryption
    defaultCipher
  , initialSecrets
  , clientInitialSecret
  , serverInitialSecret
  , aeadKey
  , initialVector
  , nextSecret
  , headerProtectionKey
  , makeNonce
  , encryptPayload
  , encryptPayload'
  , decryptPayload
  , decryptPayload'
  -- * Header Protection
  , protectionMask
  , tagLength
  , sampleLength
  , bsXOR
--  , unprotectHeader
  -- * Types
  , PlainText
  , CipherText
  , Key(..)
  , IV(..)
  , CID
  , Secret(..)
  , AddDat(..)
  , Sample(..)
  , Mask(..)
  , Nonce(..)
  , Cipher
  , InitialSecret
  , TrafficSecrets
  , ClientTrafficSecret(..)
  , ServerTrafficSecret(..)
  -- * Misc
  , calculateIntegrityTag
  ) where

import Crypto.Cipher.AES
import qualified Crypto.Cipher.ChaCha as ChaCha
import qualified Crypto.Cipher.ChaChaPoly1305 as ChaChaPoly
import Crypto.Cipher.Types hiding (Cipher, IV)
import Crypto.Error (throwCryptoError, maybeCryptoError)
import qualified Crypto.MAC.Poly1305 as Poly1305
import qualified Data.ByteArray as Byte (convert, xor)
import qualified Data.ByteString as BS
import qualified Data.ByteString.Char8 as C8
import qualified Data.ByteString.Internal as BS
import qualified Data.ByteString.Short as Short
import Foreign.ForeignPtr (withForeignPtr)
import Foreign.Ptr (Ptr, plusPtr)
import Foreign.Storable (peek, poke)
import Network.TLS hiding (Version)
import Network.TLS.Extra.Cipher
import Network.TLS.QUIC
import qualified UnliftIO.Exception as E

import Network.QUIC.Imports
import Network.QUIC.Types

----------------------------------------------------------------

-- | Cipher suite used where no negotiated suite applies; 'initialSecret'
-- uses its hash to derive the Initial secrets.
defaultCipher :: Cipher
defaultCipher = cipher_TLS13_AES128GCM_SHA256

----------------------------------------------------------------

-- | Unprotected payload bytes.
type PlainText = ByteString

-- | AEAD output: ciphertext, possibly followed by its authentication tag.
type CipherText = ByteString

-- | HKDF-Extract salt (see 'initialSalt').
type Salt = ByteString

-- | Key material (AEAD packet-protection key or header-protection key).
newtype Key = Key ByteString
  deriving (Eq)

-- | Per-direction initialization vector (derived with label "quic iv").
newtype IV = IV ByteString
  deriving (Eq)

-- | A traffic secret from which keys and IVs are expanded.
newtype Secret = Secret ByteString
  deriving (Eq)

-- | Additional authenticated data (the packet header).
newtype AddDat = AddDat ByteString
  deriving (Eq)

-- | Ciphertext sample fed to header protection.
newtype Sample = Sample ByteString
  deriving (Eq)

-- | Header-protection mask produced from a 'Sample'.
newtype Mask = Mask ByteString
  deriving (Eq)

-- | HKDF-Expand-Label label, e.g. "quic key".
newtype Label = Label ByteString
  deriving (Eq)

-- | Per-packet AEAD nonce.
newtype Nonce = Nonce ByteString
  deriving (Eq)

-- Debug renderings: a type tag followed by the bytes passed through 'enc16'
-- (presumably hex encoding).
-- NOTE(review): 'Key' and 'Secret' values are rendered in full; avoid
-- 'show'-ing them into production logs.
instance Show Key    where show (Key x)    = "Key="    <> C8.unpack (enc16 x)
instance Show IV     where show (IV x)     = "IV="     <> C8.unpack (enc16 x)
instance Show Secret where show (Secret x) = "Secret=" <> C8.unpack (enc16 x)
instance Show AddDat where show (AddDat x) = "AddDat=" <> C8.unpack (enc16 x)
instance Show Sample where show (Sample x) = "Sample=" <> C8.unpack (enc16 x)
instance Show Mask   where show (Mask x)   = "Mask="   <> C8.unpack (enc16 x)
instance Show Label  where show (Label x)  = "Label="  <> C8.unpack (enc16 x)
instance Show Nonce  where show (Nonce x)  = "Nonce="  <> C8.unpack (enc16 x)

----------------------------------------------------------------

-- | Version-specific salt for HKDF-Extract when deriving Initial secrets.
-- Any version other than 'Draft29' and 'Version1' raises 'VersionIsUnknown'
-- (via 'E.impureThrow') when the result is forced.
initialSalt :: Version -> Salt
initialSalt ver = case ver of
    Draft29   -> "\xaf\xbf\xec\x28\x99\x93\xd2\x4c\x9e\x97\x86\xf1\x9c\x61\x11\xe0\x43\x90\xa8\x99"
    Version1  -> "\x38\x76\x2c\xf7\xf5\x59\x34\xb3\x4d\x17\x9a\xe6\xa4\xc8\x0c\xad\xcc\xbb\x7f\x0a"
    Version v -> E.impureThrow (VersionIsUnknown v)

-- | Empty phantom tag marking traffic secrets of the Initial level.
data InitialSecret

-- | The (client, server) pair of Initial secrets for a version and
-- connection ID.
initialSecrets :: Version -> CID -> TrafficSecrets InitialSecret
initialSecrets ver cid = (client, server)
  where
    client = clientInitialSecret ver cid
    server = serverInitialSecret ver cid

-- | The client's Initial secret, expanded with the label "client in".
clientInitialSecret :: Version -> CID -> ClientTrafficSecret InitialSecret
clientInitialSecret ver cid =
    ClientTrafficSecret (initialSecret (Label "client in") ver cid)

-- | The server's Initial secret, expanded with the label "server in".
serverInitialSecret :: Version -> CID -> ServerTrafficSecret InitialSecret
serverInitialSecret ver cid =
    ServerTrafficSecret (initialSecret (Label "server in") ver cid)

-- | Derive an Initial secret:
--
-- > HKDF-Expand-Label(HKDF-Extract(initialSalt ver, cid), label, "", hashLen)
--
-- using the hash of 'defaultCipher'. The two greasing versions get a fixed
-- placeholder byte string instead, so 'initialSalt' is never consulted for
-- them.
initialSecret :: Label -> Version -> CID -> ByteString
initialSecret _ GreasingVersion  _ = "greasing !!!!"
initialSecret _ GreasingVersion2 _ = "greasing !!!!"
initialSecret (Label lbl) ver cid =
    hkdfExpandLabel h prk lbl "" (hashDigestSize h)
  where
    h   = cipherHash defaultCipher
    prk = hkdfExtract h (initialSalt ver) (fromCID cid)

-- | Packet-protection key: the secret expanded with label "quic key".
aeadKey :: Cipher -> Secret -> Key
aeadKey cipher secret = genKey (Label "quic key") cipher secret

-- | Header-protection key: the secret expanded with label "quic hp".
headerProtectionKey :: Cipher -> Secret -> Key
headerProtectionKey cipher secret = genKey (Label "quic hp") cipher secret

-- | Expand @secret@ under @label@ (empty context) to the bulk key size of
-- the cipher, using the cipher's hash.
genKey :: Label -> Cipher -> Secret -> Key
genKey (Label label) cipher (Secret secret) =
    Key (hkdfExpandLabel h secret label "" len)
  where
    h   = cipherHash cipher
    len = bulkKeySize (cipherBulk cipher)

-- | Packet-protection IV: the secret expanded with label "quic iv". The
-- length is the bulk cipher's IV size plus explicit-IV size, never less
-- than 8 bytes.
initialVector :: Cipher -> Secret -> IV
initialVector cipher (Secret secret) =
    IV (hkdfExpandLabel h secret "quic iv" "" len)
  where
    h    = cipherHash cipher
    bulk = cipherBulk cipher
    len  = max 8 (bulkIVSize bulk + bulkExplicitIV bulk)

-- | Key update: derive the next-generation secret with label "quic ku",
-- keeping the length equal to the hash output size.
nextSecret :: Cipher -> Secret -> Secret
nextSecret cipher (Secret current) = Secret next
  where
    h    = cipherHash cipher
    next = hkdfExpandLabel h current "quic ku" "" (hashDigestSize h)

----------------------------------------------------------------

-- | Select the AEAD encryption function for a TLS 1.3 cipher suite. The
-- implemented suites return @[ciphertext, tag]@; AES-128-CCM and any other
-- suite call 'error'.
--
-- It would be nice to accept @[PlainText]@ and feed each chunk into the
-- AEAD context incrementally, but the chunks are not aligned to the
-- cipher's block size, so that is not possible.
cipherEncrypt :: Cipher -> Key -> Nonce -> PlainText -> AddDat -> [CipherText]
cipherEncrypt cipher
  | cipher == cipher_TLS13_AES128GCM_SHA256        = aes128gcmEncrypt
  | cipher == cipher_TLS13_AES128CCM_SHA256        = error "cipher_TLS13_AES128CCM_SHA256"
  | cipher == cipher_TLS13_AES256GCM_SHA384        = aes256gcmEncrypt
  | cipher == cipher_TLS13_CHACHA20POLY1305_SHA256 = chacha20poly1305Encrypt
  | otherwise                                      = error "cipherEncrypt"

-- | Select the AEAD decryption function for a TLS 1.3 cipher suite. The
-- input is ciphertext followed by its tag; 'Nothing' signals failure.
-- AES-128-CCM and any other suite call 'error'.
cipherDecrypt :: Cipher -> Key -> Nonce -> CipherText -> AddDat -> Maybe PlainText
cipherDecrypt cipher
  | cipher == cipher_TLS13_AES128GCM_SHA256        = aes128gcmDecrypt
  | cipher == cipher_TLS13_AES128CCM_SHA256        = error "cipher_TLS13_AES128CCM_SHA256"
  | cipher == cipher_TLS13_AES256GCM_SHA384        = aes256gcmDecrypt
  | cipher == cipher_TLS13_CHACHA20POLY1305_SHA256 = chacha20poly1305Decrypt
  | otherwise                                      = error "cipherDecrypt"

-- | AES-128-GCM encryption with a 16-byte tag, returning
-- @[ciphertext, tag]@. Throws (via 'throwCryptoError') if the key or nonce
-- is rejected.
--
-- IMPORTANT: the key schedule is bound with 'let' outside the lambda so
-- that a partially applied @aes128gcmEncrypt key@ memoizes it across calls.
aes128gcmEncrypt :: Key -> (Nonce -> PlainText -> AddDat -> [CipherText])
aes128gcmEncrypt (Key key) =
    let aes = throwCryptoError (cipherInit key) :: AES128
    in \(Nonce nonce) plaintext (AddDat ad) ->
        let aead = throwCryptoError (aeadInit AEAD_GCM aes nonce)
            (AuthTag rawTag, ciphertext) = aeadSimpleEncrypt aead ad plaintext 16
        in [ciphertext, Byte.convert rawTag]

-- | AES-128-GCM decryption. The input is the ciphertext followed by a
-- 16-byte tag. Returns 'Nothing' when the input is shorter than the tag or
-- when authentication fails.
--
-- The key schedule is bound with 'let' outside the lambda so that a
-- partially applied @aes128gcmDecrypt key@ shares it across packets.
aes128gcmDecrypt :: Key -> (Nonce -> CipherText -> AddDat -> Maybe PlainText)
aes128gcmDecrypt (Key key) =
    let aes = throwCryptoError (cipherInit key) :: AES128
    in \(Nonce nonce) ciphertag (AddDat ad) ->
      -- 'aeadSimpleDecrypt' computes the expected tag at the length of the
      -- tag it is given. Without this guard, a short input would be checked
      -- against a truncated (or, for empty input, zero-length) tag and could
      -- be accepted.
      if BS.length ciphertag < tagLen
        then Nothing
        else
          let aead = throwCryptoError $ aeadInit AEAD_GCM aes nonce
              (ciphertext, tag) = BS.splitAt (BS.length ciphertag - tagLen) ciphertag
              authtag = AuthTag (Byte.convert tag)
          in aeadSimpleDecrypt aead ad ciphertext authtag
  where
    tagLen = 16

-- | AES-256-GCM encryption with a 16-byte tag, returning
-- @[ciphertext, tag]@. Throws (via 'throwCryptoError') if the key or nonce
-- is rejected.
--
-- The key schedule is bound with 'let' outside the lambda so that a
-- partially applied @aes256gcmEncrypt key@ memoizes it across calls.
aes256gcmEncrypt :: Key -> (Nonce -> PlainText -> AddDat -> [CipherText])
aes256gcmEncrypt (Key key) =
    let aes = throwCryptoError (cipherInit key) :: AES256
    in \(Nonce nonce) plaintext (AddDat ad) ->
        let aead = throwCryptoError (aeadInit AEAD_GCM aes nonce)
            (AuthTag rawTag, ciphertext) = aeadSimpleEncrypt aead ad plaintext 16
        in [ciphertext, Byte.convert rawTag]

-- | AES-256-GCM decryption. The input is the ciphertext followed by a
-- 16-byte tag. Returns 'Nothing' when the input is shorter than the tag or
-- when authentication fails.
--
-- The key schedule is bound with 'let' outside the lambda so that a
-- partially applied @aes256gcmDecrypt key@ shares it across packets.
aes256gcmDecrypt :: Key -> (Nonce -> CipherText -> AddDat -> Maybe PlainText)
aes256gcmDecrypt (Key key) =
    let aes = throwCryptoError (cipherInit key) :: AES256
    in \(Nonce nonce) ciphertag (AddDat ad) ->
      -- 'aeadSimpleDecrypt' computes the expected tag at the length of the
      -- tag it is given. Without this guard, a short input would be checked
      -- against a truncated (or, for empty input, zero-length) tag and could
      -- be accepted.
      if BS.length ciphertag < tagLen
        then Nothing
        else
          let aead = throwCryptoError $ aeadInit AEAD_GCM aes nonce
              (ciphertext, tag) = BS.splitAt (BS.length ciphertag - tagLen) ciphertag
              authtag = AuthTag (Byte.convert tag)
          in aeadSimpleDecrypt aead ad ciphertext authtag
  where
    tagLen = 16

-- | ChaCha20-Poly1305 encryption, returning @[ciphertext, tag]@. The nonce
-- goes through 'ChaChaPoly.nonce12'. An invalid nonce or key throws (via
-- 'throwCryptoError').
chacha20poly1305Encrypt :: Key -> Nonce -> PlainText -> AddDat -> [CipherText]
chacha20poly1305Encrypt (Key key) (Nonce nonce) plaintext (AddDat ad) =
    [ciphertext, Byte.convert tag]
  where
    initial           = throwCryptoError $ ChaChaPoly.nonce12 nonce >>= ChaChaPoly.initialize key
    afterAAD          = ChaChaPoly.finalizeAAD (ChaChaPoly.appendAAD ad initial)
    (ciphertext, end) = ChaChaPoly.encrypt plaintext afterAAD
    Poly1305.Auth tag = ChaChaPoly.finalize end

-- | ChaCha20-Poly1305 decryption. The input is the ciphertext followed by
-- a 16-byte Poly1305 tag. Returns 'Nothing' if the input is shorter than
-- the tag, the nonce or key is rejected, or authentication fails.
chacha20poly1305Decrypt :: Key -> Nonce -> CipherText -> AddDat -> Maybe PlainText
chacha20poly1305Decrypt (Key key) (Nonce nonce) ciphertag (AddDat ad)
  | BS.length ciphertag < tagLen = Nothing
  | otherwise = do
      st <- maybeCryptoError (ChaChaPoly.nonce12 nonce >>= ChaChaPoly.initialize key)
      let st2 = ChaChaPoly.finalizeAAD (ChaChaPoly.appendAAD ad st)
          (ciphertext, tag) = BS.splitAt (BS.length ciphertag - tagLen) ciphertag
          (plaintext, st3) = ChaChaPoly.decrypt ciphertext st2
          expected = ChaChaPoly.finalize st3
      -- Compare as 'Poly1305.Auth' values: that Eq instance uses a
      -- constant-time comparison, unlike 'ByteString' equality, so
      -- verification time does not leak how many tag bytes matched.
      if Poly1305.Auth (Byte.convert tag) == expected
        then Just plaintext
        else Nothing
  where
    tagLen = 16

----------------------------------------------------------------

-- | Per-packet nonce: the IV XORed with the packet-number bytes. 'bsXORpad'
-- right-aligns the shorter packet number, as if left-padded with zeros.
makeNonce :: IV -> ByteString -> Nonce
makeNonce (IV iv) pn = Nonce (bsXORpad iv pn)

----------------------------------------------------------------

-- | Payload encryption. Apply it to a cipher, key and IV once, then reuse
-- the returned function for every packet: the per-key setup is bound
-- before the lambda so it can be shared. The nonce comes from the packet
-- number encoded by 'bytestring64' (presumably 8 bytes) and combined with
-- the IV. The header is the associated data.
encryptPayload :: Cipher -> Key -> IV
               -> (PlainText -> ByteString -> PacketNumber -> [CipherText])
encryptPayload cipher key iv =
    let encryptWith = cipherEncrypt cipher key
        nonceFor :: PacketNumber -> Nonce
        nonceFor = makeNonce iv . bytestring64 . fromIntegral
    in \plaintext header pn ->
         encryptWith (nonceFor pn) plaintext (AddDat header)

-- | Encrypt with an explicitly supplied nonce and associated data.
encryptPayload' :: Cipher -> Key -> Nonce -> PlainText -> AddDat -> [CipherText]
encryptPayload' c k n pt aad = cipherEncrypt c k n pt aad

----------------------------------------------------------------

-- | Payload decryption; the counterpart of 'encryptPayload'. Apply it to a
-- cipher, key and IV once and reuse the returned function, so the per-key
-- setup bound before the lambda is shared. Returns 'Nothing' when
-- decryption fails.
decryptPayload :: Cipher -> Key -> IV
               -> (CipherText -> ByteString -> PacketNumber -> Maybe PlainText)
decryptPayload cipher key iv =
    let decryptWith = cipherDecrypt cipher key
        nonceFor :: PacketNumber -> Nonce
        nonceFor = makeNonce iv . bytestring64 . fromIntegral
    in \ciphertext header pn ->
         decryptWith (nonceFor pn) ciphertext (AddDat header)

-- | Decrypt with an explicitly supplied nonce and associated data.
decryptPayload' :: Cipher -> Key -> Nonce -> CipherText -> AddDat -> Maybe PlainText
decryptPayload' c k n ct aad = cipherDecrypt c k n ct aad

----------------------------------------------------------------

-- | Header-protection mask function for a cipher and key. Bind the result
-- once per key so the cipher setup in 'cipherHeaderProtection' is shared.
protectionMask :: Cipher -> Key -> (Sample -> Mask)
protectionMask cipher key =
    let maskOf = cipherHeaderProtection cipher key
    in \sample -> maskOf sample

-- | Select the header-protection primitive for a TLS 1.3 cipher suite:
-- single-block AES-ECB for the GCM suites and ChaCha20 for
-- ChaCha20-Poly1305. AES-128-CCM and any other suite call 'error'.
cipherHeaderProtection :: Cipher -> Key -> (Sample -> Mask)
cipherHeaderProtection cipher key
  | cipher == cipher_TLS13_AES128GCM_SHA256        = aes128ecbEncrypt key
  | cipher == cipher_TLS13_AES128CCM_SHA256        = error "cipher_TLS13_AES128CCM_SHA256"
  | cipher == cipher_TLS13_AES256GCM_SHA384        = aes256ecbEncrypt key
  | cipher == cipher_TLS13_CHACHA20POLY1305_SHA256 = chachaEncrypt key
  | otherwise                                      = error "cipherHeaderProtection"

-- | AES-128-ECB encryption of the sample, used as the header-protection
-- mask. The key schedule is bound before the lambda so a partially applied
-- @aes128ecbEncrypt key@ reuses it.
aes128ecbEncrypt :: Key -> (Sample -> Mask)
aes128ecbEncrypt (Key key) =
    let aes = throwCryptoError (cipherInit key) :: AES128
    in \(Sample sample) -> Mask (ecbEncrypt aes sample)

-- | AES-256-ECB encryption of the sample, used as the header-protection
-- mask. The key schedule is bound before the lambda so a partially applied
-- @aes256ecbEncrypt key@ reuses it.
aes256ecbEncrypt :: Key -> (Sample -> Mask)
aes256ecbEncrypt (Key key) =
    let aes = throwCryptoError (cipherInit key) :: AES256
    in \(Sample sample) -> Mask (ecbEncrypt aes sample)

-- | ChaCha20 (20 rounds) header-protection mask. The first 4 bytes of the
-- sample are dropped and the remainder is the nonce. The mask is the first
-- 5 keystream bytes, obtained by XORing the keystream with five zero bytes.
--
-- NOTE(review): RFC 9001 §5.4.4 uses the first 4 sample bytes as the block
-- counter, but they are discarded here because 'ChaCha.initialize' offers
-- no way to set the counter (presumably it starts at 0; confirm against
-- cryptonite). The mask therefore matches the RFC only when those bytes
-- encode counter 0.
chachaEncrypt :: Key -> Sample -> Mask
chachaEncrypt (Key key) (Sample sample0) = Mask keystream
  where
    nonce          = BS.drop 4 sample0
    (keystream, _) = ChaCha.combine (ChaCha.initialize 20 key nonce) "\x0\x0\x0\x0\x0"

-- | AEAD authentication-tag length in bytes. It is 16 for every TLS 1.3
-- suite listed below; any other cipher calls 'error'.
tagLength :: Cipher -> Int
tagLength cipher
  | cipher `elem` tls13Suites = 16
  | otherwise                 = error "tagLength"
  where
    tls13Suites =
        [ cipher_TLS13_AES128GCM_SHA256
        , cipher_TLS13_AES128CCM_SHA256
        , cipher_TLS13_AES256GCM_SHA384
        , cipher_TLS13_CHACHA20POLY1305_SHA256
        ]

-- | Number of ciphertext bytes sampled for header protection.
--
-- Both the AES-based and the ChaCha20-based schemes take a 16-byte sample
-- (RFC 9001, Sections 5.4.3 and 5.4.4). Any other cipher is rejected
-- with 'error'.
sampleLength :: Cipher -> Int
sampleLength cipher
  | cipher `elem` knownCiphers = 16
  | otherwise                  = error "sampleLength"
  where
    knownCiphers =
      [ cipher_TLS13_AES128GCM_SHA256
      , cipher_TLS13_AES128CCM_SHA256
      , cipher_TLS13_AES256GCM_SHA384
      , cipher_TLS13_CHACHA20POLY1305_SHA256
      ]

-- | Byte-wise XOR of two byte strings, delegating to 'Byte.xor'.
--
-- Per the memory package's documentation of 'Data.ByteArray.xor', inputs
-- of unequal length yield a result as long as the shorter one.
bsXOR :: ByteString -> ByteString -> ByteString
bsXOR lhs rhs = lhs `Byte.xor` rhs

-- XOR the IV with the packet number, left-padded with zeros to the IV's length.
--             src0
-- IV          +IIIIIIIIIIIIIIIIII--------+
--                 diff          src1
-- PN          +000000000000000000+-------+
--             dst
-- Nonce       +IIIIIIIIIIIIIIIIII--------+
-- | XOR a packet number into the trailing bytes of an IV.
--
-- The packet number is treated as if it were left-padded with zero bytes
-- to the IV's length, and the result has the IV's length. Calls 'error'
-- if the packet number is longer than the IV.
bsXORpad :: ByteString -> ByteString -> ByteString
bsXORpad (PS ivFP ivOff ivLen) (PS pnFP pnOff pnLen)
  | ivLen < pnLen = error "bsXORpad"
  | otherwise     = BS.unsafeCreate ivLen fill
  where
    -- Number of leading IV bytes that face the zero padding.
    padLen :: Int
    padLen = ivLen - pnLen

    fill :: Ptr Word8 -> IO ()
    fill out =
      withForeignPtr ivFP $ \ivBase ->
        withForeignPtr pnFP $ \pnBase -> do
          let iv = ivBase `plusPtr` ivOff :: Ptr Word8
              pn = pnBase `plusPtr` pnOff :: Ptr Word8
          -- XOR with zero is the identity, so the padded prefix is a plain copy.
          BS.memcpy out iv padLen
          xorInto (out `plusPtr` padLen) (iv `plusPtr` padLen) pn pnLen

    -- Write (a[i] `xor` b[i]) to d[i] for the first n bytes.
    xorInto :: Ptr Word8 -> Ptr Word8 -> Ptr Word8 -> Int -> IO ()
    xorInto _ _ _ 0 = return ()
    xorInto d a b n = do
      x <- peek a
      y <- peek b
      poke d (x `xor` y)
      xorInto (d `plusPtr` 1) (a `plusPtr` 1) (b `plusPtr` 1) (n - 1)

{-
bsXORpad' :: ByteString -> ByteString -> ByteString
bsXORpad' iv pn = BS.pack $ zipWith xor ivl pnl
  where
    ivl = BS.unpack iv
    diff = BS.length iv - BS.length pn
    pnl = replicate diff 0 ++ BS.unpack pn
-}

----------------------------------------------------------------

-- | Compute the Retry Integrity Tag (RFC 9001, Section 5.8).
--
-- The pseudo-packet is built from three parts, in order:
--
-- * the length byte of the original destination connection ID;
-- * the ID bytes themselves;
-- * @pseudo0@. Presumably this is the Retry packet without its tag
--   (NOTE(review): confirm at call sites).
--
-- It is passed as additional data to 'aes128gcmEncrypt' with an empty
-- plaintext and a fixed per-version key and nonce. The returned pieces are
-- concatenated into the result.
calculateIntegrityTag :: Version -> CID -> ByteString -> ByteString
calculateIntegrityTag ver oCID pseudo0 =
    BS.concat (aes128gcmEncrypt key nonce "" (AddDat pseudo))
  where
    (cidBytes, cidLen) = unpackCID oCID
    pseudo = BS.singleton cidLen <> Short.fromShort cidBytes <> pseudo0
    -- Version 1 uses the RFC 9001 constants. Every other version falls back
    -- to the second pair (NOTE(review): these match the draft-29 values).
    (key, nonce)
      | ver == Version1 =
          ( Key "\xbe\x0c\x69\x0b\x9f\x66\x57\x5a\x1d\x76\x6b\x54\xe3\x68\xc8\x4e"
          , Nonce "\x46\x15\x99\xd3\x5d\x63\x2b\xf2\x23\x98\x25\xbb"
          )
      | otherwise =
          ( Key "\xcc\xce\x18\x7e\xd0\x9a\x09\xd0\x57\x28\x15\x5a\x6c\xb9\x6b\xe1"
          , Nonce "\xe5\x49\x30\xf9\x7f\x21\x36\xf0\x53\x0a\x8c\x1c"
          )