module Network.TLS.Crypto.DH (
    -- * DH types
    DHParams,
    DHPublic,
    DHPrivate,
    DHKey,

    -- * DH methods
    dhPublic,
    dhPrivate,
    dhParams,
    dhParamsGetP,
    dhParamsGetG,
    dhParamsGetBits,
    dhGenerateKeyPair,
    dhGetShared,
    dhValid,
    dhUnwrap,
    dhUnwrapPublic,
) where

import Crypto.Number.Basic (numBits)
import qualified Crypto.PubKey.DH as DH
import qualified Data.ByteArray as B
import Network.TLS.RNG

-- | Public value of a Diffie-Hellman key pair (as produced by 'dhGenerateKeyPair').
type DHPublic = DH.PublicNumber
-- | Private value of a Diffie-Hellman key pair (as produced by 'dhGenerateKeyPair').
type DHPrivate = DH.PrivateNumber
-- | Diffie-Hellman group parameters: prime @p@, generator @g@ and the bit
-- size of @p@ (see 'dhParams').
type DHParams = DH.Params
-- | Shared secret computed by 'dhGetShared'.
type DHKey = DH.SharedKey

-- | Wrap a raw integer as a Diffie-Hellman public value.
dhPublic :: Integer -> DHPublic
dhPublic y = DH.PublicNumber y

-- | Wrap a raw integer as a Diffie-Hellman private value.
dhPrivate :: Integer -> DHPrivate
dhPrivate x = DH.PrivateNumber x

-- | Build Diffie-Hellman group parameters from a prime modulus @p@ and a
-- generator @g@. The recorded bit size is derived from @p@.
dhParams :: Integer -> Integer -> DHParams
dhParams prime generator = DH.Params prime generator primeBits
  where
    primeBits = numBits prime

-- | Generate a fresh private value for the given group and pair it with the
-- public value derived from it.
dhGenerateKeyPair :: MonadRandom r => DHParams -> r (DHPrivate, DHPublic)
dhGenerateKeyPair params = pairWithPublic <$> DH.generatePrivate params
  where
    pairWithPublic secret = (secret, DH.calculatePublic params secret)

-- | Compute the Diffie-Hellman shared secret from our private value and the
-- peer's public value.
--
-- Leading zero bytes are stripped from the result, as required for the
-- DH(E) pre-main secret in SSL/TLS before version 1.3.
dhGetShared :: DHParams -> DHPrivate -> DHPublic -> DHKey
dhGetShared params priv pub =
    case DH.getShared params priv pub of
        DH.SharedKey raw ->
            let (_leadingZeros, rest) = B.span (== 0) raw
             in DH.SharedKey rest

-- | Check that a group element is not in the 2-element subgroup { 1, p - 1 }
-- (and lies strictly between 1 and p - 1).
-- See RFC 7919 section 3 and NIST SP 56A rev 2 section 5.6.2.3.1.
-- This verification is enough when using a safe prime.
dhValid :: DHParams -> Integer -> Bool
dhValid (DH.Params prime _ _) y
    | y <= 1 = False
    | otherwise = y < prime - 1

-- | Flatten group parameters and a public value into @[p, g, y]@.
dhUnwrap :: DHParams -> DHPublic -> [Integer]
dhUnwrap params public =
    case (params, public) of
        (DH.Params prime generator _, DH.PublicNumber y) ->
            prime : generator : [y]

-- | The prime modulus @p@ of the group parameters.
dhParamsGetP :: DHParams -> Integer
dhParamsGetP params =
    case params of
        DH.Params prime _ _ -> prime

-- | The generator @g@ of the group parameters.
dhParamsGetG :: DHParams -> Integer
dhParamsGetG params =
    case params of
        DH.Params _ generator _ -> generator

-- | The bit size recorded in the group parameters.
dhParamsGetBits :: DHParams -> Int
dhParamsGetBits params =
    case params of
        DH.Params _ _ bitSize -> bitSize

-- | Extract the raw integer from a Diffie-Hellman public value.
dhUnwrapPublic :: DHPublic -> Integer
dhUnwrapPublic public =
    case public of
        DH.PublicNumber y -> y