module Database.MySQL.Connection where
import Control.Exception (Exception, bracketOnError, finally, throwIO)
import Control.Monad
import qualified Crypto.Hash as Crypto
import qualified Data.Binary as Binary
import qualified Data.Binary.Put as Binary
import Data.Bits
import qualified Data.ByteArray as BA
import Data.ByteString (ByteString)
import qualified Data.ByteString as B
import qualified Data.ByteString.Lazy as L
import qualified Data.ByteString.Unsafe as B
import Data.IORef (IORef, newIORef, readIORef, writeIORef)
import qualified Data.TLSSetting as TLS
import Data.Typeable
import Data.Word
import Database.MySQL.Protocol.Auth
import Database.MySQL.Protocol.Command
import Database.MySQL.Protocol.Packet
import Network.Socket (HostName, PortNumber)
import qualified Network.Socket as N
import qualified Network.TLS as TLS
import System.IO.Streams (InputStream, OutputStream)
import qualified System.IO.Streams as Stream
import qualified System.IO.Streams.Binary as Binary
import qualified System.IO.Streams.TCP as TCP
import qualified System.IO.Streams.TLS as TLS
-- | Handle for an established, authenticated connection.
data MySQLConn = MySQLConn
    { mysqlRead        :: !(InputStream Packet)
      -- ^ Packets received from the server.
    , mysqlWrite       :: !(OutputStream Packet)
      -- ^ Packets to be sent to the server.
    , mysqlCloseSocket :: IO ()
      -- ^ Action that shuts down the underlying transport
      -- (plain socket or TLS context).
    , isConsumed       :: !(IORef Bool)
      -- ^ Checked by 'guardUnconsumed' before a command is sent; set to
      -- 'True' when the connection is created.
    }
-- | Everything needed to open and authenticate a connection.
--
-- NOTE(review): the derived 'Show' instance prints 'ciPassword' verbatim;
-- avoid logging values of this type.
data ConnectInfo = ConnectInfo
    { ciHost     :: HostName
      -- ^ Server host name or address.
    , ciPort     :: PortNumber
      -- ^ Server TCP port.
    , ciDatabase :: ByteString
      -- ^ Database name sent in the authentication packet.
    , ciUser     :: ByteString
      -- ^ User name sent in the authentication packet.
    , ciPassword :: ByteString
      -- ^ Clear-text password; only a scrambled digest of it is sent.
    , ciTLSInfo  :: Maybe (TLS.ClientParams, String)
      -- ^ 'Nothing' for plain TCP, or TLS client parameters paired with
      -- the server name handed to the TLS connect call.
    } deriving Show
-- | Defaults: @127.0.0.1:3306@, user @root@, empty password, no database
-- selected and no TLS.
defaultConnectInfo :: ConnectInfo
defaultConnectInfo = ConnectInfo
    { ciHost     = "127.0.0.1"
    , ciPort     = 3306
    , ciDatabase = ""
    , ciUser     = "root"
    , ciPassword = ""
    , ciTLSInfo  = Nothing
    }
-- | Buffer size in bytes (16 KiB) passed to the TCP connect call.
bUFSIZE :: Int
bUFSIZE = 16 * 1024
-- | Open and authenticate a connection, discarding the server greeting.
-- See 'connectDetail' for the failure modes.
connect :: ConnectInfo -> IO MySQLConn
connect ci = snd <$> connectDetail ci
-- | Open a connection and run the authentication handshake, returning the
-- server 'Greeting' together with the ready-to-use connection.
--
-- A plain TCP socket is used when 'ciTLSInfo' is 'Nothing'; otherwise a TLS
-- session is opened with the supplied client parameters and server name.
-- If anything fails after the transport has been opened, the transport is
-- closed before the exception propagates.
--
-- Throws 'ERRException' when the reply to the authentication packet is not
-- an OK packet, and 'NetworkException' (via 'readPacket') when the stream
-- ends prematurely.
connectDetail :: ConnectInfo -> IO (Greeting, MySQLConn)
connectDetail ci@(ConnectInfo host port _ _ _ tls) =
    case tls of
        Nothing ->
            bracketOnError
                (TCP.connectWithBufferSize host port bUFSIZE)
                (\(_, _, sock) -> N.close sock)
                (\(is, os, sock) -> handshake is os (N.close sock))
        Just (cp, sname) ->
            bracketOnError
                (TLS.connect cp (Just sname) host port)
                (\(_, _, ctx) -> TLS.close ctx)
                (\(is, os, ctx) -> handshake is os (TLS.close ctx))
  where
    -- Handshake shared by both transports: read the greeting, send the
    -- auth packet (sequence number 1), then expect an OK reply.
    -- @closeTransport@ is stored in the resulting connection.
    handshake
        :: InputStream ByteString
        -> OutputStream ByteString
        -> IO ()
        -> IO (Greeting, MySQLConn)
    handshake is os closeTransport = do
        is' <- decodeInputStream is
        os' <- Binary.encodeOutputStream os
        p <- readPacket is'
        greet <- decodeFromPacket p
        let auth = mkAuth ci greet
        Stream.write (Just (encodeToPacket 1 auth)) os'
        q <- readPacket is'
        if isOK q
            then do
                -- A fresh connection has no pending result set.
                consumed <- newIORef True
                return (greet, MySQLConn is' os' closeTransport consumed)
            else Stream.write Nothing os'
                    >> decodeFromPacket q
                    >>= throwIO . ERRException

    -- Build the authentication packet from the connect info and the salt
    -- carried in the server greeting.
    mkAuth :: ConnectInfo -> Greeting -> Auth
    mkAuth (ConnectInfo _ _ db user pass _) greet =
        let salt        = greetingSalt1 greet `B.append` greetingSalt2 greet
            scrambleBuf = scramble salt pass
        in Auth clientCap clientMaxPacketSize clientCharset user scrambleBuf db

    -- SHA1(pass) XOR SHA1(salt ++ SHA1(SHA1(pass))); an empty password
    -- yields an empty scramble.
    scramble :: ByteString -> ByteString -> ByteString
    scramble salt pass
        | B.null pass = B.empty
        | otherwise   = B.pack (B.zipWith xor sha1pass withSalt)
      where
        sha1pass = sha1 pass
        withSalt = sha1 (salt `B.append` sha1 sha1pass)

    -- SHA-1 digest as a strict 'ByteString'.
    sha1 :: ByteString -> ByteString
    sha1 = BA.convert . (Crypto.hash :: ByteString -> Crypto.Digest Crypto.SHA1)
-- | Turn a raw byte stream into a stream of 'Packet's.
--
-- Each packet starts with a 4-byte header: a 3-byte little-endian payload
-- length followed by a 1-byte sequence number.  The payload is then read
-- with 'loopRead'.
decodeInputStream :: InputStream ByteString -> IO (InputStream Packet)
decodeInputStream is = Stream.makeInputStream nextPacket
  where
    nextPacket = do
        header <- Stream.readExactly 4 is
        -- Indexing is safe: 'Stream.readExactly' returned exactly 4 bytes.
        let byteAt i = fromIntegral (header `B.unsafeIndex` i)
            len      = byteAt 0
                   .|. (byteAt 1 `shiftL` 8)
                   .|. (byteAt 2 `shiftL` 16)
            seqN     = header `B.unsafeIndex` 3
        body <- loopRead [] len is
        return (Just (Packet len seqN body))
-- | Read exactly @k@ bytes from the stream and return them as a lazy
-- 'L.ByteString'.
--
-- @acc@ holds the chunks collected so far in reverse order.  Bytes read
-- beyond the requested amount are pushed back onto the stream.
-- Throws 'NetworkException' if the stream ends before @k@ bytes arrive.
loopRead
    :: Integral n
    => [ByteString]
    -> n
    -> InputStream ByteString
    -> IO L.ByteString
loopRead acc 0 _  = return $! L.fromChunks (reverse acc)
loopRead acc k is = Stream.read is >>= maybe (throwIO NetworkException) step
  where
    step chunk
        | l >= k = do
            -- This chunk completes the payload; return any surplus.
            let (a, rest) = B.splitAt (fromIntegral k) chunk
            unless (B.null rest) (Stream.unRead rest is)
            return $! L.fromChunks (reverse (a : acc))
        | otherwise = do
            -- Still short: subtract what was just read and keep going.
            let k' = k - l
            k' `seq` loopRead (chunk : acc) k' is
      where
        l = fromIntegral (B.length chunk)
-- | Close the connection: signal end-of-stream on the write side, then run
-- the transport close action.
--
-- The close action runs even if finishing the output stream throws, so the
-- underlying socket or TLS context is not leaked.
close :: MySQLConn -> IO ()
close (MySQLConn _ os closeSocket _) =
    Stream.write Nothing os `finally` closeSocket
-- | Send @COM_PING@ and wait for the server's OK reply.
ping :: MySQLConn -> IO OK
ping conn = command conn COM_PING
-- | Send a command and wait for its OK reply.
--
-- Throws 'UnconsumedResultSet' if the connection's consumed flag is
-- 'False'; reply handling is described on 'waitCommandReply'.
command :: MySQLConn -> Command -> IO OK
command conn cmd =
       guardUnconsumed conn
    >> writeCommand cmd (mysqlWrite conn)
    >> waitCommandReply (mysqlRead conn)
-- | Read one logical packet and interpret it as a command reply.
--
-- An ERR packet is decoded and thrown as 'ERRException', an OK packet is
-- decoded and returned, and anything else is thrown as 'UnexpectedPacket'.
waitCommandReply :: InputStream Packet -> IO OK
waitCommandReply is = readPacket is >>= dispatch
  where
    dispatch p
        | isERR p   = decodeFromPacket p >>= throwIO . ERRException
        | isOK p    = decodeFromPacket p
        | otherwise = throwIO (UnexpectedPacket p)
-- | Read one logical packet, reassembling payloads that were split across
-- several wire packets.
--
-- A wire packet whose length is 16777215 (0xFFFFFF) is followed by a
-- continuation; bodies are concatenated until a shorter packet arrives.
-- The result carries the summed length and the sequence number of the
-- last wire packet.  Throws 'NetworkException' if the stream ends.
readPacket :: InputStream Packet -> IO Packet
readPacket is = do
    mp <- Stream.read is
    case mp of
        Nothing -> throwIO NetworkException
        Just p@(Packet len _ body)
            | len < 16777215 -> return p
            | otherwise      -> collect len [body]
  where
    -- @chunks@ holds the bodies seen so far, newest first.
    collect total chunks = do
        mp <- Stream.read is
        case mp of
            Nothing -> throwIO NetworkException
            Just (Packet len seqN body) -> do
                let total'  = total + len
                    chunks' = body : chunks
                if len < 16777215
                    then return (Packet total' seqN (L.concat (reverse chunks')))
                    else total' `seq` collect total' chunks'
-- | Serialise a command and write it to the stream, splitting payloads of
-- 16777215 (0xFFFFFF) bytes or more across several packets.
--
-- Every full-size packet is followed by another packet holding the rest;
-- a payload that is an exact multiple of 16777215 bytes therefore ends
-- with an empty packet.  The sequence number starts at 0 and increases by
-- one for each packet written.
writeCommand :: Command -> OutputStream Packet -> IO ()
writeCommand a os =
    let bs = Binary.runPut (Binary.put a)
    in go (fromIntegral (L.length bs)) 0 bs
  where
    go len seqN bs
        | len < 16777215 = Stream.write (Just (Packet len seqN bs)) os
        | otherwise = do
            let (chunk, rest) = L.splitAt 16777215 bs
                seqN'         = seqN + 1
                -- Bytes still to be sent after this full-size packet.
                len'          = len - 16777215
            Stream.write (Just (Packet 16777215 seqN chunk)) os
            seqN' `seq` len' `seq` go len' seqN' rest
-- | Throw 'UnconsumedResultSet' unless the connection's consumed flag
-- is 'True'.
guardUnconsumed :: MySQLConn -> IO ()
guardUnconsumed conn =
    readIORef (isConsumed conn) >>= \done ->
        unless done (throwIO UnconsumedResultSet)
-- | Strict 'writeIORef': the value is forced to weak head normal form
-- before it is stored, so no thunk is left in the reference.
writeIORef' :: IORef a -> a -> IO ()
writeIORef' ref x = writeIORef ref $! x
-- Client flag bit masks, expanded by the C preprocessor.
-- Each macro is a single distinct bit; the subset this client actually
-- uses is OR-ed together in clientCap.
#define CLIENT_LONG_PASSWORD 0x00000001
#define CLIENT_FOUND_ROWS 0x00000002
#define CLIENT_LONG_FLAG 0x00000004
#define CLIENT_CONNECT_WITH_DB 0x00000008
#define CLIENT_NO_SCHEMA 0x00000010
#define CLIENT_COMPRESS 0x00000020
#define CLIENT_ODBC 0x00000040
#define CLIENT_LOCAL_FILES 0x00000080
#define CLIENT_IGNORE_SPACE 0x00000100
#define CLIENT_PROTOCOL_41 0x00000200
#define CLIENT_INTERACTIVE 0x00000400
#define CLIENT_SSL 0x00000800
#define CLIENT_IGNORE_SIGPIPE 0x00001000
#define CLIENT_TRANSACTIONS 0x00002000
#define CLIENT_RESERVED 0x00004000
#define CLIENT_SECURE_CONNECTION 0x00008000
#define CLIENT_MULTI_STATEMENTS 0x00010000
#define CLIENT_MULTI_RESULTS 0x00020000
#define CLIENT_PS_MULTI_RESULTS 0x00040000
#define CLIENT_PLUGIN_AUTH 0x00080000
#define CLIENT_CONNECT_ATTRS 0x00100000
#define CLIENT_PLUGIN_AUTH_LENENC_CLIENT_DATA 0x00200000
-- | Flag word placed in the 'Auth' packet: the bitwise OR of the
-- @CLIENT_*@ masks listed below.
clientCap :: Word32
clientCap = foldr (.|.) 0
    [ CLIENT_LONG_PASSWORD
    , CLIENT_LONG_FLAG
    , CLIENT_CONNECT_WITH_DB
    , CLIENT_IGNORE_SPACE
    , CLIENT_PROTOCOL_41
    , CLIENT_TRANSACTIONS
    , CLIENT_MULTI_STATEMENTS
    , CLIENT_SECURE_CONNECTION
    ]
-- | Maximum packet size placed in the 'Auth' packet: 0xFFFFFF bytes.
clientMaxPacketSize :: Word32
clientMaxPacketSize = 0x00ffffff
-- | Character-set id placed in the 'Auth' packet (0x21 = 33).
clientCharset :: Word8
clientCharset = 0x21
-- | Thrown when the packet stream ends while more data is expected
-- (see 'readPacket').
data NetworkException = NetworkException deriving (Typeable, Show)
instance Exception NetworkException
-- | Thrown by 'guardUnconsumed' when a command is issued while the
-- connection's consumed flag is 'False'.
data UnconsumedResultSet = UnconsumedResultSet deriving (Typeable, Show)
instance Exception UnconsumedResultSet
-- | Wraps the 'ERR' decoded from a server reply: a rejected authentication
-- in 'connectDetail' or an ERR packet in 'waitCommandReply'.
data ERRException = ERRException ERR deriving (Typeable, Show)
instance Exception ERRException
-- | Thrown by 'waitCommandReply' when the reply is neither an OK nor an
-- ERR packet; carries the offending packet.
data UnexpectedPacket = UnexpectedPacket Packet deriving (Typeable, Show)
instance Exception UnexpectedPacket