-- | Helpers for setting up a TLS connection with the @tls@ package.
-- For further customization, please refer to the @tls@ package.
--
-- Note: functions in this module will throw an error if they can't load
-- certificates or the CA store.
--
module Data.TLSSetting
    ( -- * Choose a CAStore
      TrustedCAStore(..)
      -- * Make TLS settings
    , makeClientParams
    , makeClientParams'
    , makeServerParams
    , makeServerParams'
      -- * Internal
    , mozillaCAStorePath
    ) where

import qualified Data.ByteString            as B
import           Data.Default.Class         (def)
import qualified Data.PEM                   as X509
import qualified Data.X509                  as X509
import qualified Data.X509.CertificateStore as X509
import qualified Network.TLS                as TLS
import qualified Network.TLS.Extra          as TLS
import           Paths_mysql_haskell          (getDataFileName)
import qualified System.X509                as X509

-- | Which set of certificates a peer trusts.
--
-- The whole point of TLS is that each peer already trusts some certificates
-- and uses them to validate the certificates sent by the other side. If the
-- other side sends a certificate chain and one of its certificates is issued
-- by a certificate in the chosen 'TrustedCAStore', that peer is trusted.
data TrustedCAStore
    = SystemCAStore           -- ^ Provided by your operating system.
    | MozillaCAStore          -- ^ Provided by <https://curl.haxx.se/docs/caextract.html Mozilla>.
    | CustomCAStore FilePath  -- ^ Provided by yourself; the CA file may contain multiple certificates.
    deriving (Eq, Show)

-- | Resolve the path of the built-in Mozilla CA bundle
-- (@mozillaCAStore.pem@). The path is looked up with the generated
-- 'getDataFileName' from @Paths_mysql_haskell@.
mozillaCAStorePath :: IO FilePath
mozillaCAStorePath = getDataFileName bundleName
  where
    bundleName = "mozillaCAStore.pem"

-- | Build a certificate store for the given 'TrustedCAStore'.
--
-- * 'SystemCAStore' uses 'X509.getSystemCertificateStore'.
-- * 'MozillaCAStore' loads the bundled file at 'mozillaCAStorePath',
--   exactly as a 'CustomCAStore' pointing at that file would.
-- * 'CustomCAStore' reads a PEM file, which may contain several
--   certificates, and decodes every entry.
--
-- Calls 'error' if the PEM file cannot be parsed or if any entry fails
-- to decode as a signed certificate.
makeCAStore :: TrustedCAStore -> IO X509.CertificateStore
makeCAStore SystemCAStore      = X509.getSystemCertificateStore
makeCAStore MozillaCAStore     = makeCAStore . CustomCAStore =<< mozillaCAStorePath
makeCAStore (CustomCAStore fp) = do
    bs <- B.readFile fp
    -- Match explicitly rather than using an irrefutable @let Right ...@
    -- binding. A malformed file then yields a descriptive error instead of
    -- an opaque pattern-match failure.
    pems <- case X509.pemParseBS bs of
        Right ps -> return ps
        Left err -> error ("makeCAStore: cannot parse PEM file " ++ fp ++ ": " ++ err)
    case mapM (X509.decodeSignedCertificate . X509.pemContent) pems of
        Right cas -> return (X509.makeCertificateStore cas)
        Left err  -> error err

-- | Make a simple tls 'TLS.ClientParams' that validates the server and uses
-- a TLS connection without providing the client's own certificate.
-- Suitable for connecting to servers that don't validate clients.
--
-- Setting 'TLS.clientServerIdentification' is deferred to the connecting
-- phase; empty placeholders are passed to 'TLS.defaultParamsClient' here.
--
-- Every cipher in 'TLS.ciphersuite_all' is offered. The CA store is built
-- with 'makeCAStore' from the given 'TrustedCAStore'.
--
-- Note: tls's default validation method requires the server to have a v3
-- certificate. You can use openssl's V3 extensions to issue such a
-- certificate, or change the 'TLS.ClientParams' before connecting.
makeClientParams :: TrustedCAStore          -- ^ trusted certificates.
                 -> IO TLS.ClientParams
makeClientParams tca = fmap withStore (makeCAStore tca)
  where
    withStore store = (TLS.defaultParamsClient "" B.empty)
        { TLS.clientSupported = def { TLS.supportedCiphers = TLS.ciphersuite_all }
        , TLS.clientShared    = def
            { TLS.sharedCAStore         = store
            , TLS.sharedValidationCache = def
            }
        }

-- | Make a simple tls 'TLS.ClientParams' that validates the server and uses
-- a TLS connection while also providing the client's own certificate.
-- Suitable for connecting to servers that validate clients.
--
-- Builds on 'makeClientParams', so it also only accepts v3 certificates. The
-- credential loaded from the given files becomes the sole entry in
-- 'TLS.sharedCredentials'. Calls 'error' with the loader's message if the
-- credential cannot be loaded.
makeClientParams' :: FilePath       -- ^ public certificate (X.509 format).
                  -> [FilePath]     -- ^ chain certificates (X.509 format).
                                    --   the root of your certificate chain should be
                                    --   already trusted by server, or tls will fail.
                  -> FilePath       -- ^ private key associated.
                  -> TrustedCAStore -- ^ trusted certificates.
                  -> IO TLS.ClientParams
makeClientParams' pub certs priv tca = do
    params <- makeClientParams tca
    loaded <- TLS.credentialLoadX509Chain pub certs priv
    either error (return . withCredential params) loaded
  where
    -- Install the credential while keeping the rest of the shared settings.
    withCredential params cred =
        let shared = TLS.clientShared params
        in  params
                { TLS.clientShared =
                    shared { TLS.sharedCredentials = TLS.Credentials [cred] }
                }

-- | Make a simple tls 'TLS.ServerParams' that does not validate the client's
-- certificate.
--
-- The loaded credential becomes the only server credential. The
-- certificates of its chain are also used as 'TLS.serverCACertificates'.
-- Only 'TLS.ciphersuite_strong' ciphers are enabled. Calls 'error' with the
-- loader's message if the credential cannot be loaded.
makeServerParams :: FilePath        -- ^ public certificate (X.509 format).
                 -> [FilePath]      -- ^ chain certificates (X.509 format).
                                    --   the root of your certificate chain should be
                                    --   already trusted by client, or tls will fail.
                 -> FilePath        -- ^ private key associated.
                 -> IO TLS.ServerParams
makeServerParams pub certs priv =
    TLS.credentialLoadX509Chain pub certs priv >>= \loaded ->
        case loaded of
            Left err -> error err
            Right cred@(X509.CertificateChain chain, _) -> return def
                { TLS.serverCACertificates = chain
                , TLS.serverShared         = def
                    { TLS.sharedCredentials = TLS.Credentials [cred] }
                , TLS.serverSupported      = def
                    { TLS.supportedCiphers = TLS.ciphersuite_strong }
                }

-- | Make a tls 'TLS.ServerParams' that also validates the client's
-- certificate.
--
-- Starts from 'makeServerParams', then sets 'TLS.serverWantClientCert' and
-- installs the CA store built from the given 'TrustedCAStore'. The CA store
-- is built before the server credential is loaded.
makeServerParams' :: FilePath       -- ^ public certificate (X.509 format).
                  -> [FilePath]     -- ^ chain certificates (X.509 format).
                  -> FilePath       -- ^ private key associated.
                  -> TrustedCAStore -- ^ server will use these certificates to validate clients.
                  -> IO TLS.ServerParams
makeServerParams' pub certs priv tca =
    requireClientCert <$> makeCAStore tca <*> makeServerParams pub certs priv
  where
    -- Demand a client certificate and validate it against the given store.
    requireClientCert store params = params
        { TLS.serverWantClientCert = True
        , TLS.serverShared         =
            (TLS.serverShared params) { TLS.sharedCAStore = store }
        }