module Database.MySQL.OpenSSL
( connect
, connectDetail
, module Data.OpenSSLSetting
) where
import Control.Exception (bracketOnError, throwIO)
import Control.Monad
import Data.IORef (newIORef)
import Data.Connection as Conn
import qualified Data.Binary as Binary
import qualified Data.Binary.Put as Binary
import Database.MySQL.Connection hiding (connect, connectDetail)
import Database.MySQL.Protocol.Auth
import Database.MySQL.Protocol.Packet
import qualified OpenSSL as SSL
import qualified OpenSSL.X509 as X509
import qualified OpenSSL.Session as Session
import qualified System.IO.Streams.OpenSSL as SSL
import qualified System.IO.Streams.TCP as TCP
import Data.OpenSSLSetting
connect :: ConnectInfo -> (Session.SSLContext, String) -> IO MySQLConn
connect c cp = fmap snd (connectDetail c cp)
connectDetail :: ConnectInfo -> (Session.SSLContext, String) -> IO (Greeting, MySQLConn)
connectDetail (ConnectInfo host port db user pass charset) (ctx, subname) =
bracketOnError (connectWithBufferSize host port bUFSIZE) Conn.close $ \ conn -> do
let is = Conn.source conn
is' <- decodeInputStream is
p <- readPacket is'
greet <- decodeFromPacket p
if supportTLS (greetingCaps greet)
then SSL.withOpenSSL $ do
write conn (encodeToPacket 1 $ sslRequest charset)
let (sock, sockAddr) = Conn.connExtraInfo conn
bracketOnError (Session.connection ctx sock)
(\ ssl -> do
Session.shutdown ssl Session.Unidirectional
Conn.close conn
) $ \ ssl -> do
Session.connect ssl
trusted <- Session.getVerifyResult ssl
cert <- Session.getPeerCertificate ssl
subnames <- maybe (return []) (`X509.getSubjectName` False) cert
let cnname = lookup "CN" subnames
verified = maybe False (== subname) cnname
unless (trusted && verified) (throwIO $ Session.ProtocolError "fail to verify certificate")
sconn <- SSL.sslToConnection (ssl, sockAddr)
let sis = Conn.source sconn
auth = mkAuth db user pass charset greet
write sconn (encodeToPacket 2 auth)
sis' <- decodeInputStream sis
q <- readPacket sis'
if isOK q
then do
consumed <- newIORef True
let mconn = MySQLConn sis' (write sconn) (Conn.close sconn) consumed
return (greet, mconn)
else Conn.close sconn >> decodeFromPacket q >>= throwIO . ERRException
else error "Database.MySQL.OpenSSL: server doesn't support TLS connection"
where
connectWithBufferSize h p bs = TCP.connectSocket h p >>= TCP.socketToConnection bs
write c a = Conn.send c $ Binary.runPut . Binary.put $ a