{-|
Module      : Database.MySQL.TLS
Description : TLS support for mysql-haskell via @tls@ package.
Copyright   : (c) Winterland, 2016
License     : BSD
Maintainer  : drkoster@qq.com
Stability   : experimental
Portability : PORTABLE

This module provides secure MySQL connections using the @tls@ package. Please make sure your certificate has the v3 extensions enabled.

-}

module Database.MySQL.TLS (
      connect
    , connectDetail
    , module Data.TLSSetting
    ) where

import           Control.Exception              (bracketOnError, throwIO)
import qualified Data.Binary                    as Binary
import qualified Data.Binary.Put                as Binary
import qualified Data.Connection                as Conn
import           Data.IORef                     (newIORef)
import           Data.TLSSetting
import           Database.MySQL.Connection      hiding (connect, connectDetail)
import           Database.MySQL.Protocol.Auth
import           Database.MySQL.Protocol.Packet
import qualified Network.TLS                    as TLS
import qualified System.IO.Streams.TCP          as TCP
import qualified Data.Connection                as TCP
import qualified System.IO.Streams.TLS          as TLS

--------------------------------------------------------------------------------

-- | Open a TLS-secured connection to a MySQL server.
--
-- The second argument pairs the 'TLS.ClientParams' used for the TLS handshake
-- with the subject name used as the server identification. This is
-- 'connectDetail' with the server's 'Greeting' discarded.
connect :: ConnectInfo -> (TLS.ClientParams, String) -> IO MySQLConn
connect ci tlsSetting = snd <$> connectDetail ci tlsSetting

-- | Like 'connect', but also return the 'Greeting' the server sent when the
-- TCP connection was opened.
--
-- The connection is established in these steps:
--
--   1. Open a TCP connection and decode the server's initial greeting.
--   2. If the greeting's capability flags advertise TLS, send an SSL request
--      packet (sequence number 1), then run the TLS handshake on the same
--      socket. SNI is disabled, and the server identification is set to
--      @(subName, \"\")@.
--   3. Send the authentication packet (sequence number 2) over TLS and read
--      the server's reply.
--
-- On an OK reply, the TLS connection is wrapped into a 'MySQLConn'. Otherwise
-- the server's error packet is thrown as an 'ERRException'. If the server does
-- not advertise TLS support, 'error' is called. If any exception escapes the
-- handshake, the outer 'bracketOnError' closes the TCP connection.
connectDetail :: ConnectInfo -> (TLS.ClientParams, String) -> IO (Greeting, MySQLConn)
connectDetail (ConnectInfo host port db user pass charset) (cparams, subName) =
    bracketOnError (connectWithBufferSize host port bUFSIZE) TCP.close $ \c -> do
        let is = TCP.source c
        is' <- decodeInputStream is
        p <- readPacket is'
        greet <- decodeFromPacket p
        if supportTLS (greetingCaps greet)
        then do
            let cparams' = cparams
                    { TLS.clientUseServerNameIndication = False
                    , TLS.clientServerIdentification    = (subName, "")
                    }
                (sock, sockAddr) = Conn.connExtraInfo c
            -- Ask the server to switch this connection over to TLS.
            write c (encodeToPacket 1 $ sslRequest charset)
            bracketOnError (TLS.contextNew sock cparams')
                           (\ctx -> TLS.bye ctx >> TCP.close c) $ \ctx -> do
                TLS.handshake ctx
                tc <- TLS.tLsToConnection (ctx, sockAddr)
                tlsIs' <- decodeInputStream (TCP.source tc)
                let auth = mkAuth db user pass charset greet
                write tc (encodeToPacket 2 auth)
                q <- readPacket tlsIs'
                if isOK q
                then do
                    consumed <- newIORef True
                    let conn = MySQLConn tlsIs' (write tc) (TCP.close tc) consumed
                    return (greet, conn)
                -- Do NOT close the socket here before throwing. The cleanup
                -- handler above runs 'TLS.bye' (which writes a close_notify
                -- alert to the socket) and only then closes it. Closing first
                -- made 'TLS.bye' write to an already-closed socket, and the
                -- resulting I/O error replaced the server's 'ERRException'.
                else decodeFromPacket q >>= throwIO . ERRException
        else error "Database.MySQL.TLS: server doesn't support TLS connection"
  where
    -- Open a TCP socket to host @h@, port @p@, and wrap it as a connection,
    -- passing @bs@ as the buffer size.
    connectWithBufferSize h p bs = TCP.connectSocket h p >>= TCP.socketToConnection bs

    -- Serialize a 'Binary' value (e.g. a 'Packet') and send the resulting
    -- bytes on the given connection.
    write c a = TCP.send c $ Binary.runPut . Binary.put $ a