{-|
Module      : Database.MySQL.OpenSSL
Description : Alternative TLS support for mysql-haskell via the @HsOpenSSL@ package.
Copyright   : (c) Winterland, 2016
License     : BSD
Maintainer  : drkoster@qq.com
Stability   : experimental
Portability : PORTABLE

This module provides secure (TLS) MySQL connections using the @HsOpenSSL@ package.

-}

module Database.MySQL.OpenSSL
    ( connect
    , connectDetail
    , module Data.OpenSSLSetting
    ) where

import           Control.Exception              (ErrorCall (..), bracketOnError, throwIO)
import           Control.Monad
import qualified Data.Binary                    as Binary
import qualified Data.Binary.Put                as Binary
import           Data.Connection                as Conn
import           Data.IORef                     (newIORef)
import           Data.OpenSSLSetting
import           Database.MySQL.Connection      hiding (connect, connectDetail)
import           Database.MySQL.Protocol.Auth
import           Database.MySQL.Protocol.Packet
import qualified OpenSSL                        as SSL
import qualified OpenSSL.Session                as Session
import qualified OpenSSL.X509                   as X509
import qualified System.IO.Streams.OpenSSL      as SSL
import qualified System.IO.Streams.TCP          as TCP

--------------------------------------------------------------------------------

-- | Establish a TLS-secured MySQL connection, discarding the server greeting.
--
-- The pair supplies the 'Session.SSLContext' used for the TLS handshake and
-- the subject name that the server certificate's @CN@ must match. This is
-- 'connectDetail' with the 'Greeting' dropped; it throws whatever
-- 'connectDetail' throws.
connect :: ConnectInfo -> (Session.SSLContext, String) -> IO MySQLConn
connect info tlsParams = snd <$> connectDetail info tlsParams

-- | Open a TLS-secured MySQL connection and also return the server's
-- initial 'Greeting'.
--
-- Steps:
--
--   1. Open a plain TCP connection and read and decode the server greeting.
--   2. If the greeting's capabilities do not include TLS, throw 'ErrorCall'.
--   3. Send an SSL request packet, then run the OpenSSL client handshake on
--      the same socket using the given 'Session.SSLContext'.
--   4. Require both a successful OpenSSL verify result and a peer certificate
--      whose subject @CN@ equals the given subject name; otherwise throw
--      'Session.ProtocolError'.
--   5. Send the authentication packet over TLS. An OK reply yields the
--      connection. Any other reply is decoded as an 'ERR' packet and thrown
--      as 'ERRException'.
--
-- Cleanup on failure releases each resource exactly once. The inner
-- 'bracketOnError' shuts the TLS session down, and the outer one closes the
-- TCP socket. Closing a resource twice risks a second TLS shutdown on an
-- already-closed socket, whose error could mask the original exception
-- (e.g. the server's authentication error).
connectDetail :: ConnectInfo -> (Session.SSLContext, String) -> IO (Greeting, MySQLConn)
connectDetail (ConnectInfo host port db user pass charset) (ctx, subname) =
    bracketOnError (connectWithBufferSize host port bUFSIZE) Conn.close $ \ conn -> do
        is' <- decodeInputStream (Conn.source conn)
        greet <- readPacket is' >>= decodeFromPacket
        unless (supportTLS (greetingCaps greet)) $
            throwIO (ErrorCall "Database.MySQL.OpenSSL: server doesn't support TLS connection")
        SSL.withOpenSSL $ do
            write conn (encodeToPacket 1 $ sslRequest charset)
            let (sock, sockAddr) = Conn.connExtraInfo conn
            -- Only the TLS session is released here; the enclosing
            -- bracketOnError already closes the TCP socket on error.
            bracketOnError (Session.connection ctx sock)
                           (\ ssl -> Session.shutdown ssl Session.Unidirectional) $ \ ssl -> do
                Session.connect ssl
                trusted <- Session.getVerifyResult ssl
                cert <- Session.getPeerCertificate ssl
                subnames <- maybe (return []) (`X509.getSubjectName` False) cert
                -- The certificate is accepted only if OpenSSL trusts it AND
                -- its subject CN is exactly the expected name.
                let verified = lookup "CN" subnames == Just subname
                unless (trusted && verified) $
                    throwIO $ Session.ProtocolError "fail to verify certificate"
                sconn <- SSL.sslToConnection (ssl, sockAddr)
                write sconn (encodeToPacket 2 $ mkAuth db user pass charset greet)
                sis' <- decodeInputStream (Conn.source sconn)
                q <- readPacket sis'
                if isOK q
                    then do
                        -- NOTE(review): presumably 'True' means no unread
                        -- result set is pending; confirm against MySQLConn.
                        consumed <- newIORef True
                        let mconn = MySQLConn sis' (write sconn) (Conn.close sconn) consumed
                        return (greet, mconn)
                    -- No explicit close here: the bracket handlers above
                    -- release the TLS session and socket exactly once.
                    else decodeFromPacket q >>= throwIO . ERRException
  where
    -- Open a TCP socket to host @h@, port @p@ and wrap it as a 'Connection'
    -- using buffer size @bs@.
    connectWithBufferSize h p bs = TCP.connectSocket h p >>= TCP.socketToConnection bs

    -- Serialise a value with its 'Binary.Binary' instance and send it.
    write :: Binary.Binary t => Connection a -> t -> IO ()
    write c a = Conn.send c . Binary.runPut . Binary.put $ a