{-# LANGUAGE OverloadedStrings #-}

module Network.TLS.Handshake.Server.Common (
    applicationProtocol,
    checkValidClientCertChain,
    clientCertificate,
    credentialDigitalSignatureKey,
    filterCredentials,
    filterCredentialsWithHashSignatures,
    isCredentialAllowed,
    storePrivInfoServer,
) where

import Control.Monad.State.Strict
import Data.X509 (ExtKeyUsageFlag (..))

import Network.TLS.Context.Internal
import Network.TLS.Credentials
import Network.TLS.Crypto
import Network.TLS.Extension
import Network.TLS.Handshake.Certificate
import Network.TLS.Handshake.Common
import Network.TLS.Handshake.Key
import Network.TLS.Handshake.State
import Network.TLS.Imports
import Network.TLS.Parameters
import Network.TLS.State
import Network.TLS.Struct
import Network.TLS.Util (catchException)
import Network.TLS.X509

-- | Fetch the client certificate chain recorded in the handshake state.
--
-- The chain is read with 'getClientCertChain'.  When no chain has been
-- recorded, or when 'isNullCertificateChain' holds for the recorded chain,
-- @Error_Protocol errmsg UnexpectedMessage@ is thrown with 'throwCore';
-- otherwise the chain is returned unchanged.
checkValidClientCertChain
    :: MonadIO m => Context -> String -> m CertificateChain
checkValidClientCertChain ctx errmsg = do
    chain <- usingHState ctx getClientCertChain
    -- Same error for both the "missing" and the "empty" case.
    let throwerror = Error_Protocol errmsg UnexpectedMessage
    case chain of
        Nothing -> throwCore throwerror
        Just cc
            | isNullCertificateChain cc -> throwCore throwerror
            | otherwise -> return cc

-- | Public key of a credential, if its key pair can be used for digital
-- signatures.
--
-- Returns @Just pubkey@ when 'isDigitalSignaturePair' accepts the pair
-- produced by 'credentialPublicPrivateKeys', and 'Nothing' otherwise.
credentialDigitalSignatureKey :: Credential -> Maybe PubKey
credentialDigitalSignatureKey cred
    | isDigitalSignaturePair keys = Just pubkey
    | otherwise = Nothing
  where
    -- The whole pair is needed for the test, the public half for the result.
    keys@(pubkey, _) = credentialPublicPrivateKeys cred

-- | Keep only the credentials satisfying the predicate, preserving the
-- order of the underlying list.
filterCredentials :: (Credential -> Bool) -> Credentials -> Credentials
filterCredentials p (Credentials l) = Credentials (filter p l)

-- | Decide whether a credential may be used for the given protocol version
-- and client hello extensions.
--
-- The credential's public key must be 'versionCompatible' with @ver@ and
-- must pass 'satisfiesEcPredicate' for a group predicate derived from the
-- client's "supported_groups" extension (see the comment on @p@ below).
isCredentialAllowed :: Version -> [ExtensionRaw] -> Credential -> Bool
isCredentialAllowed ver exts cred =
    pubkey `versionCompatible` ver && satisfiesEcPredicate p pubkey
  where
    (pubkey, _) = credentialPublicPrivateKeys cred
    -- ECDSA keys are tested against supported elliptic curves until TLS12 but
    -- not after.  With TLS13, the curve is linked to the signature algorithm
    -- and client support is tested with signatureCompatible13.
    --
    -- When the extension is absent or cannot be decoded, every group is
    -- accepted.
    p
        | ver < TLS13 = case extensionLookup EID_SupportedGroups exts
            >>= extensionDecode MsgTClientHello of
            Nothing -> const True
            Just (SupportedGroups sg) -> (`elem` sg)
        | otherwise = const True

-- | Filters a list of candidate credentials with
-- 'credentialMatchesHashSignatures'.
--
-- Algorithms to filter with are taken from the "signature_algorithms_cert"
-- extension when it exists, else from "signature_algorithms" when clients do
-- not implement the new extension (see RFC 8446 section 4.2.3).  When
-- neither extension can be looked up and decoded, the credentials are
-- returned unfiltered.
--
-- The resulting credential list can be used as input to the hybrid
-- cipher-and-certificate selection for TLS12, or to the direct certificate
-- selection simplified with TLS13.  As filtering credential signatures with
-- client-advertised algorithms is not supposed to cause negotiation failure,
-- in case of a dead end with the subsequent selection process, this process
-- should always be restarted with the unfiltered credential list as input
-- (see fallback certificate chains, described in the same RFC section).
--
-- Calling code should not forget to apply the constraints of extension
-- "signature_algorithms" to any signature-based key exchange derived from the
-- output credentials.  Respecting client constraints on KX signatures is
-- mandatory but not implemented by this function.
filterCredentialsWithHashSignatures
    :: [ExtensionRaw] -> Credentials -> Credentials
filterCredentialsWithHashSignatures exts =
    case withExt EID_SignatureAlgorithmsCert of
        Just (SignatureAlgorithmsCert sas) -> withAlgs sas
        Nothing ->
            case withExt EID_SignatureAlgorithms of
                Nothing -> id
                Just (SignatureAlgorithms sas) -> withAlgs sas
  where
    -- Look up an extension in the client hello and decode it; used at two
    -- different result types above, hence the polymorphic signature.
    withExt :: Extension e => ExtensionID -> Maybe e
    withExt extId = extensionLookup extId exts >>= extensionDecode MsgTClientHello
    -- Keep the credentials matching the advertised algorithms.
    withAlgs sas = filterCredentials (credentialMatchesHashSignatures sas)

-- | Hand the credential's certificate chain and private key to
-- 'storePrivInfo', discarding the public key it returns.
storePrivInfoServer :: MonadIO m => Context -> Credential -> m ()
storePrivInfoServer ctx (cc, privkey) = void (storePrivInfo ctx cc privkey)

-- | Negotiate ALPN (Application Layer Protocol Negotiation) on the server
-- side and build the extension to send back.
--
-- Returns an empty list when the client hello carries no decodable ALPN
-- extension or when the server has no 'onALPNClientSuggest' hook.
-- Otherwise the hook is run on the client's protocol list:
--
-- * if it yields the empty string, @Error_Protocol@ with the
--   'NoApplicationProtocol' alert is thrown;
-- * otherwise the choice is recorded in the TLS state and a single ALPN
--   extension carrying the selected protocol is returned.
applicationProtocol
    :: Context -> [ExtensionRaw] -> ServerParams -> IO [ExtensionRaw]
applicationProtocol ctx exts sparams = do
    -- ALPN (Application Layer Protocol Negotiation)
    case extensionLookup EID_ApplicationLayerProtocolNegotiation exts
        >>= extensionDecode MsgTClientHello of
        Nothing -> return []
        Just (ApplicationLayerProtocolNegotiation protos) -> do
            case onALPNClientSuggest $ serverHooks sparams of
                Just io -> do
                    proto <- io protos
                    -- An empty selection means the hook found nothing
                    -- acceptable among the client's suggestions.
                    when (proto == "") $
                        throwCore $
                            Error_Protocol "no supported application protocols" NoApplicationProtocol
                    usingState_ ctx $ do
                        setExtensionALPN True
                        setNegotiatedProtocol proto
                    return
                        [ ExtensionRaw
                            EID_ApplicationLayerProtocolNegotiation
                            (extensionEncode $ ApplicationLayerProtocolNegotiation [proto])
                        ]
                _ -> return []

-- | Process a certificate chain received from the client.
--
-- Steps, in order:
--
-- 1. run the 'hookRecvCertificates' hook on the chain;
-- 2. ask the application, through 'onClientCertificate', whether the chain
--    is acceptable; an exception raised by that callback is handed to
--    'rejectOnException';
-- 3. on acceptance, check with 'verifyLeafKeyUsage' that the leaf
--    certificate allows 'KeyUsage_digitalSignature'; on rejection, call
--    'certificateRejected' with the reason;
-- 4. store the chain in the handshake state with 'setClientCertChain'.
clientCertificate :: ServerParams -> Context -> CertificateChain -> IO ()
clientCertificate sparams ctx certs = do
    -- run certificate recv hook
    ctxWithHooks ctx (`hookRecvCertificates` certs)
    -- Call application callback to see whether the
    -- certificate chain is acceptable.
    --
    usage <-
        liftIO $
            catchException
                (onClientCertificate (serverHooks sparams) certs)
                rejectOnException
    case usage of
        CertificateUsageAccept -> verifyLeafKeyUsage [KeyUsage_digitalSignature] certs
        CertificateUsageReject reason -> certificateRejected reason

    -- Remember cert chain for later use.
    --
    usingHState ctx $ setClientCertChain certs