{-# LANGUAGE CPP #-}

module Network.TLS.Credentials (
    Credential,
    Credentials (..),
    credentialLoadX509,
    credentialLoadX509FromMemory,
    credentialLoadX509Chain,
    credentialLoadX509ChainFromMemory,
    credentialsFindForSigning,
    credentialsFindForDecrypting,
    credentialsListSigningAlgorithms,
    credentialPublicPrivateKeys,
    credentialMatchesHashSignatures,
) where

import Data.X509
import Data.X509.File
import Data.X509.Memory
import Network.TLS.Crypto
import Network.TLS.Imports
import Network.TLS.X509

import qualified Data.X509 as X509
import qualified Network.TLS.Struct as TLS

-- | A certificate chain paired with the private key associated with it.
-- The helpers in this module read the leaf of the chain via
-- 'getCertificateChainLeaf'.
type Credential = (CertificateChain, PrivKey)

-- | An ordered collection of 'Credential's. The lookup functions in this
-- module return the first suitable credential in list order.
newtype Credentials = Credentials [Credential]
    deriving Show

-- | Combining two collections concatenates them, with the credentials of
-- the left operand first.
instance Semigroup Credentials where
    (<>) (Credentials lhs) (Credentials rhs) = Credentials (lhs ++ rhs)

-- | The identity collection holds no credentials.
instance Monoid Credentials where
    mempty = Credentials mempty
#if !(MIN_VERSION_base(4,11,0))
    -- Before base 4.11, Semigroup is not a superclass of Monoid, so
    -- 'mappend' has to be spelled out.
    mappend (Credentials lhs) (Credentials rhs) = Credentials (lhs ++ rhs)
#endif

-- | Try to build a credential from a public certificate and its associated
-- private key, both stored on the filesystem in PEM format.
--
-- This is 'credentialLoadX509Chain' with no extra chain certificates.
credentialLoadX509
    :: FilePath
    -- ^ public certificate (X.509 format)
    -> FilePath
    -- ^ private key associated
    -> IO (Either String Credential)
credentialLoadX509 certPath keyPath = credentialLoadX509Chain certPath [] keyPath

-- | Like 'credentialLoadX509', but the certificate and private key are
-- supplied as in-memory 'ByteString's instead of being read from files.
--
-- This is 'credentialLoadX509ChainFromMemory' with no extra chain
-- certificates.
credentialLoadX509FromMemory
    :: ByteString
    -- ^ public certificate (X.509 format)
    -> ByteString
    -- ^ private key associated
    -> Either String Credential
credentialLoadX509FromMemory certBytes keyBytes =
    credentialLoadX509ChainFromMemory certBytes [] keyBytes

-- | Similar to 'credentialLoadX509' but also allows specifying chain
-- certificates.
--
-- The resulting 'CertificateChain' contains the certificates read from the
-- public certificate file, followed by those of each chain file in order.
-- Only the first private key read from the key file is used.
--
-- Fails with:
--
-- * @Left "no keys found"@ when no private key could be read from the key
--   file (this check takes precedence);
-- * @Left "no certificates found"@ when no certificate could be read from
--   any of the certificate files. An empty chain has no leaf certificate,
--   and the signing/decryption helpers of this module read the leaf via
--   'getCertificateChainLeaf'. Reporting the problem here is clearer than
--   failing later during a handshake.
credentialLoadX509Chain
    :: FilePath
    -- ^ public certificate (X.509 format)
    -> [FilePath]
    -- ^ chain certificates (X.509 format)
    -> FilePath
    -- ^ private key associated
    -> IO (Either String Credential)
credentialLoadX509Chain certFile chainFiles privateFile = do
    x509 <- readSignedObject certFile
    chains <- mapM readSignedObject chainFiles
    keys <- readKeyFile privateFile
    return $ case (keys, concat (x509 : chains)) of
        ([], _) -> Left "no keys found"
        (_, []) -> Left "no certificates found"
        (k : _, certs) -> Right (CertificateChain certs, k)

-- | Similar to 'credentialLoadX509FromMemory' but also allows specifying
-- chain certificates.
--
-- The resulting 'CertificateChain' contains the certificates decoded from
-- the public certificate data, followed by those of each chain element in
-- order. Only the first private key decoded from the key data is used.
--
-- Fails with:
--
-- * @Left "no keys found"@ when no private key could be decoded (this check
--   takes precedence);
-- * @Left "no certificates found"@ when no certificate could be decoded at
--   all. An empty chain has no leaf certificate, and the signing/decryption
--   helpers of this module read the leaf via 'getCertificateChainLeaf'.
credentialLoadX509ChainFromMemory
    :: ByteString
    -> [ByteString]
    -> ByteString
    -> Either String Credential
credentialLoadX509ChainFromMemory certData chainData privateData =
    case (keys, concat (x509 : chains)) of
        ([], _) -> Left "no keys found"
        (_, []) -> Left "no certificates found"
        (k : _, certs) -> Right (CertificateChain certs, k)
  where
    x509 :: [SignedExact Certificate]
    x509 = readSignedObjectFromMemory certData

    chains :: [[SignedExact Certificate]]
    chains = map readSignedObjectFromMemory chainData

    keys :: [PrivKey]
    keys = readKeyFileFromMemory privateData

-- | The key-exchange signature algorithms offered by a set of credentials.
-- There is one entry for each credential that 'credentialCanSign' accepts,
-- in credential order; duplicates are kept.
credentialsListSigningAlgorithms :: Credentials -> [KeyExchangeSignatureAlg]
credentialsListSigningAlgorithms (Credentials creds) =
    [alg | Just alg <- map credentialCanSign creds]

-- | The first credential for which 'credentialCanSign' reports exactly the
-- requested key-exchange signature algorithm, if any.
credentialsFindForSigning
    :: KeyExchangeSignatureAlg -> Credentials -> Maybe Credential
credentialsFindForSigning kxsAlg (Credentials creds) =
    find supportsRequested creds
  where
    supportsRequested cred = credentialCanSign cred == Just kxsAlg

-- | The first credential that 'credentialCanDecrypt' accepts, if any.
credentialsFindForDecrypting :: Credentials -> Maybe Credential
credentialsFindForDecrypting (Credentials creds) =
    find (isJust . credentialCanDecrypt) creds

-- | Whether a credential can be used for key encipherment
-- (encryption/decryption).
--
-- Only RSA is assumed to support this, so the leaf certificate's public key
-- and the private key must both be RSA. In addition, if the leaf
-- certificate carries a keyUsage extension, it must include
-- keyEncipherment. A missing extension imposes no restriction.
--
-- The result is a @Maybe ()@ rather than a 'Bool'. This mirrors the shape
-- of 'credentialCanSign', in case more information needs to be returned in
-- the future.
credentialCanDecrypt :: Credential -> Maybe ()
credentialCanDecrypt (chain, priv)
    | rsaKeyPair && enciphermentAllowed = Just ()
    | otherwise = Nothing
  where
    leafCert = getCertificate (getCertificateChainLeaf chain)

    rsaKeyPair = case (certPubKey leafCert, priv) of
        (PubKeyRSA _, PrivKeyRSA _) -> True
        _ -> False

    -- Only consulted when the key pair is RSA ('&&' short-circuits).
    enciphermentAllowed =
        maybe True permitsEncipherment (extensionGet (certExtensions leafCert))

    permitsEncipherment (ExtKeyUsage flags) =
        KeyUsage_keyEncipherment `elem` flags

-- | The key-exchange signature algorithm a credential can provide, if any.
--
-- Signing is refused when the leaf certificate carries a keyUsage extension
-- that lacks digitalSignature. A missing extension imposes no restriction.
-- Otherwise the choice is delegated to 'findKeyExchangeSignatureAlg',
-- applied to the leaf's public key and the credential's private key.
credentialCanSign :: Credential -> Maybe KeyExchangeSignatureAlg
credentialCanSign (chain, priv)
    | signingAllowed = findKeyExchangeSignatureAlg (certPubKey leafCert, priv)
    | otherwise = Nothing
  where
    leafCert = getCertificate (getCertificateChainLeaf chain)

    signingAllowed =
        maybe True permitsSigning (extensionGet (certExtensions leafCert))

    permitsSigning (ExtKeyUsage flags) = KeyUsage_digitalSignature `elem` flags

-- | The public key of the chain's leaf certificate, paired with the
-- credential's private key.
--
-- The public key is forced to WHNF before the pair is returned, so a
-- problem extracting it surfaces at the call rather than later.
credentialPublicPrivateKeys :: Credential -> (PubKey, PrivKey)
credentialPublicPrivateKeys (chain, priv) =
    let leafPub = certPubKey (getCertificate (getCertificateChainLeaf chain))
     in leafPub `seq` (leafPub, priv)

-- | Translate the algorithm a certificate was signed with into the TLS
-- hash/signature pair that names it. Returns 'Nothing' for anything not
-- handled below.
--
-- * RSA, DSA and ECDSA are paired with MD5, SHA-1, SHA-224, SHA-256,
--   SHA-384 or SHA-512; any other hash gives 'Nothing'.
-- * RSASSA-PSS is mapped to the @rsa_pss_rsae@ code points, for SHA-256,
--   SHA-384 and SHA-512 only.
-- * Ed25519 and Ed448 use the intrinsic hash.
getHashSignature :: SignedCertificate -> Maybe TLS.HashAndSignatureAlgorithm
getHashSignature signed =
    case signedAlg (getSigned signed) of
        SignatureALG h PubKeyALG_RSA -> pairWith TLS.SignatureRSA h
        SignatureALG h PubKeyALG_DSA -> pairWith TLS.SignatureDSA h
        SignatureALG h PubKeyALG_EC -> pairWith TLS.SignatureECDSA h
        SignatureALG X509.HashSHA256 PubKeyALG_RSAPSS ->
            intrinsic TLS.SignatureRSApssRSAeSHA256
        SignatureALG X509.HashSHA384 PubKeyALG_RSAPSS ->
            intrinsic TLS.SignatureRSApssRSAeSHA384
        SignatureALG X509.HashSHA512 PubKeyALG_RSAPSS ->
            intrinsic TLS.SignatureRSApssRSAeSHA512
        SignatureALG_IntrinsicHash PubKeyALG_Ed25519 ->
            intrinsic TLS.SignatureEd25519
        SignatureALG_IntrinsicHash PubKeyALG_Ed448 ->
            intrinsic TLS.SignatureEd448
        _ -> Nothing
  where
    -- Signature schemes whose hash is fixed by the scheme itself.
    intrinsic sig = Just (TLS.HashIntrinsic, sig)

    -- Pair a signature with the TLS equivalent of an X.509 hash, if any.
    pairWith sig h = fmap (\th -> (th, sig)) (toTlsHash h)

    toTlsHash X509.HashMD5 = Just TLS.HashMD5
    toTlsHash X509.HashSHA1 = Just TLS.HashSHA1
    toTlsHash X509.HashSHA224 = Just TLS.HashSHA224
    toTlsHash X509.HashSHA256 = Just TLS.HashSHA256
    toTlsHash X509.HashSHA384 = Just TLS.HashSHA384
    toTlsHash X509.HashSHA512 = Just TLS.HashSHA512
    toTlsHash _ = Nothing

-- | Checks whether the certificate signatures in a credential's chain are
-- acceptable under a list of hash/signature algorithm pairs.
--
-- Currently only the leaf certificate's signature is checked:
--
-- * an empty chain is always accepted;
-- * a self-signed leaf (subject DN equal to issuer DN) is accepted without
--   checking its signature;
-- * otherwise the leaf's signature algorithm must translate to a pair that
--   appears in the list.
--
-- The check may be extended to further chain elements in the future.
credentialMatchesHashSignatures
    :: [TLS.HashAndSignatureAlgorithm] -> Credential -> Bool
credentialMatchesHashSignatures hashSigs (CertificateChain certs, _) =
    case certs of
        [] -> True
        leaf : _ -> selfSigned leaf || acceptedAlgorithm leaf
  where
    acceptedAlgorithm sc = maybe False (`elem` hashSigs) (getHashSignature sc)

    selfSigned sc =
        let c = getCertificate sc
         in certSubjectDN c == certIssuerDN c