-- | Helpers for setting up a TLS connection with the @tls@ package.
-- For further customization, please refer to the @tls@ package.
--
-- Note: functions in this module will throw an error if they can't load certificates or the CA store.
--
module Data.TLSSetting
    ( -- * Choose a CAStore
      TrustedCAStore(..)
      -- * Make TLS settings
    , makeClientParams
    , makeClientParams'
    , makeServerParams
    , makeServerParams'
      -- * Internal
    , mozillaCAStorePath
    ) where

import qualified Data.ByteString            as B
import           Data.Default.Class         (def)
import qualified Data.PEM                   as X509
import qualified Data.X509                  as X509
import qualified Data.X509.CertificateStore as X509
import qualified Network.TLS                as TLS
import qualified Network.TLS.Extra          as TLS
import           Paths_mysql_haskell          (getDataFileName)
import qualified System.X509                as X509

-- | The source of the certificates a peer already trusts.
--
-- The whole point of TLS is that a peer should have already trusted some
-- certificates, which can be used for validating the other peer's certificates.
-- If the certificates sent by the other side form a chain, and one of them is
-- issued by one of the certificates in the 'TrustedCAStore', then the other
-- peer will be trusted.
--
data TrustedCAStore
    = SystemCAStore                   -- ^ provided by your operating system.
    | MozillaCAStore                  -- ^ provided by <https://curl.haxx.se/docs/caextract.html Mozilla>.
    | CustomCAStore FilePath          -- ^ provided by yourself, the CA file can contain multiple certificates.
  deriving (Show, Eq)

-- | Get the path of the built-in Mozilla CA bundle, i.e. the
-- @mozillaCAStore.pem@ data file resolved through this package's
-- 'getDataFileName'.
mozillaCAStorePath :: IO FilePath
mozillaCAStorePath = getDataFileName "mozillaCAStore.pem"

-- Load the certificate store selected by a 'TrustedCAStore'.
--
-- * 'SystemCAStore' asks the operating system for its store.
-- * 'MozillaCAStore' loads the bundled file located by 'mozillaCAStorePath'.
-- * 'CustomCAStore' reads the given file, which may hold several PEM sections.
--
-- Calls 'error' with the parser's message if the file can't be parsed as PEM
-- or if any PEM section can't be decoded as a signed certificate.
makeCAStore :: TrustedCAStore -> IO X509.CertificateStore
makeCAStore SystemCAStore      = X509.getSystemCertificateStore
makeCAStore MozillaCAStore     = makeCAStore . CustomCAStore =<< mozillaCAStorePath
makeCAStore (CustomCAStore fp) = do
    bs <- B.readFile fp
    -- Both steps live in @Either String@: split the file into PEM sections,
    -- then decode every section; the first failure short-circuits.
    let decoded = X509.pemParseBS bs
              >>= mapM (X509.decodeSignedCertificate . X509.pemContent)
    case decoded of
        Right cas -> return (X509.makeCertificateStore cas)
        Left err  -> error err

-- | Make a simple 'TLS.ClientParams' that will validate the server and use a
-- TLS connection without providing the client's own certificate. Suitable for
-- connecting to servers which don't validate clients.
--
-- The parameters are built with an empty host name and an empty service id:
-- setting 'TLS.clientServerIdentification' is deferred to the connecting phase.
--
-- Note, tls's default validating method requires the server to have a v3
-- certificate. You can use openssl's V3 extension to issue such a certificate,
-- or change the 'TLS.ClientParams' before connecting.
--
-- Throws an error if the CA store can't be loaded.
--
makeClientParams :: TrustedCAStore          -- ^ trusted certificates.
                 -> IO TLS.ClientParams
makeClientParams tca = do
    caStore <- makeCAStore tca
    return (TLS.defaultParamsClient "" B.empty)
        {   TLS.clientSupported = def { TLS.supportedCiphers = TLS.ciphersuite_default }
        ,   TLS.clientShared    = def
            {   TLS.sharedCAStore         = caStore
            ,   TLS.sharedValidationCache = def
            }
        }

-- | Make a simple 'TLS.ClientParams' that will validate the server and use a
-- TLS connection while providing the client's own certificate as well.
-- Suitable for connecting to servers which validate clients.
--
-- Built on top of 'makeClientParams', so it also only accepts v3 certificates.
--
-- Throws an error if the CA store or the client credential can't be loaded.
--
makeClientParams' :: FilePath       -- ^ public certificate (X.509 format).
                  -> [FilePath]     -- ^ chain certificates (X.509 format).
                                    --   the root of your certificate chain should be
                                    --   already trusted by server, or tls will fail.
                  -> FilePath       -- ^ private key associated.
                  -> TrustedCAStore -- ^ trusted certificates.
                  -> IO TLS.ClientParams
makeClientParams' pub certs priv tca = do
    p <- makeClientParams tca
    c <- TLS.credentialLoadX509Chain pub certs priv
    case c of
        Right c' ->
            return p
                {   TLS.clientShared = (TLS.clientShared p)
                    {   TLS.sharedCredentials = TLS.Credentials [c']
                    }
                    -- A tls client answers the server's certificate request
                    -- through this hook, and the default hook offers no
                    -- certificate; without it the credential is never sent.
                ,   TLS.clientHooks = (TLS.clientHooks p)
                    {   TLS.onCertificateRequest = \_ -> return (Just c')
                    }
                }
        Left err -> error err

-- | Make a simple 'TLS.ServerParams' without validating the client's certificate.
--
-- The loaded credential is installed as the server's only credential, the
-- certificates of its chain are used as 'TLS.serverCACertificates', and the
-- supported ciphers are set to 'TLS.ciphersuite_strong'.
--
-- Throws an error if the credential can't be loaded.
--
makeServerParams :: FilePath        -- ^ public certificate (X.509 format).
                 -> [FilePath]      -- ^ chain certificates (X.509 format).
                                    --   the root of your certificate chain should be
                                    --   already trusted by client, or tls will fail.
                 -> FilePath        -- ^ private key associated.
                 -> IO TLS.ServerParams
makeServerParams pub certs priv = do
    c <- TLS.credentialLoadX509Chain pub certs priv
    case c of
        -- c' is the whole credential, c'' the certificates of its chain.
        Right c'@(X509.CertificateChain c'', _) ->
            return def
                {   TLS.serverCACertificates = c''
                ,   TLS.serverShared = def
                    {   TLS.sharedCredentials = TLS.Credentials [c']
                    }
                ,   TLS.serverSupported = def { TLS.supportedCiphers = TLS.ciphersuite_strong }
                }
        Left err -> error err

-- | Make a 'TLS.ServerParams' that also validates the client's certificate.
--
-- Starts from 'makeServerParams', then requests a client certificate and
-- replaces the shared CA store with the one selected by the 'TrustedCAStore'.
--
-- Throws an error if the CA store or the server credential can't be loaded.
--
makeServerParams' :: FilePath       -- ^ public certificate (X.509 format).
                  -> [FilePath]     -- ^ chain certificates (X.509 format).
                  -> FilePath       -- ^ private key associated.
                  -> TrustedCAStore -- ^ server will use these certificates to validate clients.
                  -> IO TLS.ServerParams
makeServerParams' pub certs priv tca = do
    caStore <- makeCAStore tca
    p <- makeServerParams pub certs priv
    -- NOTE(review): no 'TLS.onClientCertificate' hook is installed here, and
    -- tls's default hook appears to reject client certificates. Confirm whether
    -- a validating hook (against caStore) is needed for handshakes to succeed.
    return p
        {   TLS.serverWantClientCert = True
        ,   TLS.serverShared = (TLS.serverShared p)
            {   TLS.sharedCAStore = caStore
            }
        }