module Database.MySQL.OpenSSL
( connect
, connectDetail
, module Data.OpenSSLSetting
) where
import Control.Exception (bracketOnError, throwIO)
import Control.Monad
import Data.IORef (newIORef)
import Data.Connection as Conn
import qualified Data.Binary as Binary
import qualified Data.Binary.Put as Binary
import Database.MySQL.Connection hiding (connect, connectDetail)
import Database.MySQL.Protocol.Auth
import Database.MySQL.Protocol.Packet
import qualified OpenSSL as SSL
import qualified OpenSSL.X509 as X509
import qualified OpenSSL.Session as Session
import qualified System.IO.Streams.OpenSSL as SSL
import qualified System.IO.Streams.TCP as TCP
import Data.OpenSSLSetting
-- | Open a TLS-protected MySQL connection.
--
-- Takes the usual 'ConnectInfo' plus a pair of an OpenSSL
-- 'Session.SSLContext' and the subject name the server certificate is
-- expected to carry. This is 'connectDetail' with the server 'Greeting'
-- dropped from the result; every exception raised by 'connectDetail'
-- propagates unchanged.
connect :: ConnectInfo -> (Session.SSLContext, String) -> IO MySQLConn
connect c cp = snd <$> connectDetail c cp
-- | Open a TLS-protected MySQL connection and return it together with the
-- 'Greeting' the server sent before the TLS upgrade.
--
-- Steps, in order:
--
-- 1. Open a plain TCP connection to @host@ / @port@ and decode the first
--    packet as the server 'Greeting'.
-- 2. If the greeting capabilities pass 'supportTLS', send an 'sslRequest'
--    packet and run the OpenSSL handshake on the same socket.
-- 3. Require that OpenSSL's verify result is positive /and/ that the @CN@
--    entry of the peer certificate's subject equals the supplied subject
--    name; otherwise throw 'Session.ProtocolError'
--    (@"fail to verify certificate"@).
-- 4. Send the 'Auth' packet over TLS and read the reply. An OK reply yields
--    the 'MySQLConn'; any other reply is decoded as an error packet and
--    thrown as 'ERRException' after the TLS connection has been closed.
--
-- If the server does not advertise TLS support the function fails via
-- 'error'.
connectDetail :: ConnectInfo -> (Session.SSLContext, String) -> IO (Greeting, MySQLConn)
connectDetail (ConnectInfo host port db user pass charset) (ctx, subname) =
    -- 'bracketOnError' closes the raw TCP connection only if the body throws;
    -- on success the socket lives on inside the returned 'MySQLConn'.
    bracketOnError (connectWithBufferSize host port bUFSIZE) Conn.close $ \ conn -> do
        plainPackets <- decodeInputStream (Conn.source conn)
        greetPacket <- readPacket plainPackets
        greet <- decodeFromPacket greetPacket
        if supportTLS (greetingCaps greet)
            then SSL.withOpenSSL $ do
                -- The first argument of 'encodeToPacket' is presumably the
                -- packet sequence number -- TODO confirm.
                write conn (encodeToPacket 1 (sslRequest charset))
                let (sock, sockAddr) = Conn.connExtraInfo conn
                bracketOnError (Session.connection ctx sock) (abortTLS conn) $ \ ssl -> do
                    Session.connect ssl
                    trusted <- Session.getVerifyResult ssl
                    cert <- Session.getPeerCertificate ssl
                    -- No peer certificate means no subject entries, which
                    -- makes the CN check below fail.
                    subnames <- maybe (return []) (`X509.getSubjectName` False) cert
                    let verified = maybe False (== subname) (lookup "CN" subnames)
                    unless (trusted && verified) $
                        throwIO (Session.ProtocolError "fail to verify certificate")
                    sconn <- SSL.sslToConnection (ssl, sockAddr)
                    write sconn (encodeToPacket 2 (mkAuth db user pass charset greet))
                    tlsPackets <- decodeInputStream (Conn.source sconn)
                    reply <- readPacket tlsPackets
                    if isOK reply
                        then do
                            consumed <- newIORef True
                            return ( greet
                                   , MySQLConn tlsPackets (write sconn) (Conn.close sconn) consumed
                                   )
                        else do
                            -- NOTE(review): the 'ERRException' thrown below also
                            -- triggers both enclosing 'bracketOnError' handlers,
                            -- so the connection is torn down again after this
                            -- explicit close -- confirm repeated close/shutdown
                            -- is harmless.
                            Conn.close sconn
                            decodeFromPacket reply >>= throwIO . ERRException
            -- NOTE(review): raised with 'error' rather than 'throwIO'; kept
            -- as-is so the exception type callers may observe is unchanged.
            else error "Database.MySQL.OpenSSL: server doesn't support TLS connection"
  where
    -- Connect a socket and wrap it as a connection with the given buffer size.
    connectWithBufferSize h p bs = TCP.connectSocket h p >>= TCP.socketToConnection bs

    -- Serialise a value with its 'Binary.put' encoder and send it.
    write c a = Conn.send c (Binary.runPut (Binary.put a))

    -- Handler for a failure after the SSL object has been created: send a
    -- unidirectional TLS shutdown, then close the underlying TCP connection.
    -- NOTE(review): the outer 'bracketOnError' closes the same connection on
    -- the same exception, so it may be closed twice -- confirm 'Conn.close'
    -- tolerates that.
    abortTLS tcpConn ssl = do
        Session.shutdown ssl Session.Unidirectional
        Conn.close tcpConn