{-# LANGUAGE CPP #-} {-| Module : Database.MySQL.Connection Description : Connection managment Copyright : (c) Winterland, 2016 License : BSD Maintainer : drkoster@qq.com Stability : experimental Portability : PORTABLE This is an internal module, the 'MySQLConn' type should not directly acessed to user. -} module Database.MySQL.Connection where import Control.Exception (Exception, bracketOnError, throwIO) import Control.Monad import qualified Crypto.Hash as Crypto import qualified Data.Binary as Binary import qualified Data.Binary.Put as Binary import Data.Bits import qualified Data.ByteArray as BA import Data.ByteString (ByteString) import qualified Data.ByteString as B import qualified Data.ByteString.Lazy as L import qualified Data.ByteString.Unsafe as B import Data.IORef (IORef, newIORef, readIORef, writeIORef) import qualified Data.TLSSetting as TLS import Data.Typeable import Data.Word import Database.MySQL.Protocol.Auth import Database.MySQL.Protocol.Command import Database.MySQL.Protocol.Packet import Network.Socket (HostName, PortNumber) import qualified Network.Socket as N import qualified Network.TLS as TLS import System.IO.Streams (InputStream, OutputStream) import qualified System.IO.Streams as Stream import qualified System.IO.Streams.Binary as Binary import qualified System.IO.Streams.TCP as TCP import qualified System.IO.Streams.TLS as TLS -------------------------------------------------------------------------------- data MySQLConn = MySQLConn { mysqlRead :: {-# UNPACK #-} !(InputStream Packet) , mysqlWrite :: {-# UNPACK #-} !(OutputStream Packet) , mysqlCloseSocket :: IO () , isConsumed :: {-# UNPACK #-} !(IORef Bool) } -- | Everything you need to establish a MySQL connection. -- -- You may want some helpers in "System.IO.Streams.TLS" to setup TLS connection. -- data ConnectInfo = ConnectInfo { ciHost :: HostName , ciPort :: PortNumber , ciDatabase :: ByteString , ciUser :: ByteString , ciPassword :: ByteString , ciTLSInfo :: Maybe (TLS.ClientParams, String) -- ^ If 'TLS.ClientParams' and subject name are provided, -- TLS connection will be used. } deriving Show -- | A simple 'ConnectInfo' targeting localhost with @user=root@ and empty password. -- defaultConnectInfo :: ConnectInfo defaultConnectInfo = ConnectInfo "127.0.0.1" 3306 "" "root" "" Nothing -------------------------------------------------------------------------------- -- | Socket buffer size. -- -- maybe exposed to 'ConnectInfo' laster? -- bUFSIZE :: Int bUFSIZE = 16384 -- | Establish a MySQL connection. -- connect :: ConnectInfo -> IO MySQLConn connect = fmap snd . connectDetail -- | Establish a MySQL connection with 'Greeting' back, so you can find server's version .etc. -- connectDetail :: ConnectInfo -> IO (Greeting, MySQLConn) connectDetail ci@(ConnectInfo host port _ _ _ tls) = case tls of Nothing -> bracketOnError (TCP.connectWithBufferSize host port bUFSIZE) (\(_, _, sock) -> N.close sock) $ \ (is, os, sock) -> do is' <- decodeInputStream is os' <- Binary.encodeOutputStream os p <- readPacket is' greet <- decodeFromPacket p let auth = mkAuth ci greet Stream.write (Just (encodeToPacket 1 auth)) os' q <- readPacket is' if isOK q then do consumed <- newIORef True let conn = (MySQLConn is' os' (N.close sock) consumed) return (greet, conn) else Stream.write Nothing os' >> decodeFromPacket q >>= throwIO . ERRException Just (cp, sname) -> bracketOnError (TLS.connect cp (Just sname) host port) (\(_, _, ctx) -> TLS.close ctx) $ \ (is, os, ctx) -> do is' <- decodeInputStream is os' <- Binary.encodeOutputStream os p <- readPacket is' greet <- decodeFromPacket p let auth = mkAuth ci greet Stream.write (Just (encodeToPacket 1 auth)) os' q <- readPacket is' if isOK q then do consumed <- newIORef True let conn = (MySQLConn is' os' (TLS.close ctx) consumed) return (greet, conn) else Stream.write Nothing os' >> decodeFromPacket q >>= throwIO . ERRException where mkAuth :: ConnectInfo -> Greeting -> Auth mkAuth (ConnectInfo _ _ db user pass _) greet = let salt = greetingSalt1 greet `B.append` greetingSalt2 greet scambleBuf = scramble salt pass in Auth clientCap clientMaxPacketSize clientCharset user scambleBuf db scramble :: ByteString -> ByteString -> ByteString scramble salt pass | B.null pass = B.empty | otherwise = B.pack (B.zipWith xor sha1pass withSalt) where sha1pass = sha1 pass withSalt = sha1 (salt `B.append` sha1 sha1pass) sha1 :: ByteString -> ByteString sha1 = BA.convert . (Crypto.hash :: ByteString -> Crypto.Digest Crypto.SHA1) -- | A specialized 'decodeInputStream' here for speed decodeInputStream :: InputStream ByteString -> IO (InputStream Packet) decodeInputStream is = Stream.makeInputStream $ do bs <- Stream.readExactly 4 is let len = fromIntegral (bs `B.unsafeIndex` 0) .|. fromIntegral (bs `B.unsafeIndex` 1) `shiftL` 8 .|. fromIntegral (bs `B.unsafeIndex` 2) `shiftL` 16 seqN = bs `B.unsafeIndex` 3 body <- loopRead [] len is return . Just $ Packet len seqN body loopRead acc 0 _ = return $! L.fromChunks (reverse acc) loopRead acc k is = do bs <- Stream.read is case bs of Nothing -> throwIO NetworkException Just bs' -> do let l = fromIntegral (B.length bs') if l >= k then do let (a, rest) = B.splitAt (fromIntegral k) bs' unless (B.null rest) (Stream.unRead rest is) return $! L.fromChunks (reverse (a:acc)) else do let k' = k - l k' `seq` loopRead (bs':acc) k' is -- | Close a MySQL connection. -- close :: MySQLConn -> IO () close (MySQLConn _ os closeSocket _) = do Stream.write Nothing os closeSocket -- | Send a 'COM_PING'. -- ping :: MySQLConn -> IO OK ping = flip command COM_PING -------------------------------------------------------------------------------- -- helpers -- | Send a 'Command' which don't return a resultSet. -- command :: MySQLConn -> Command -> IO OK command conn@(MySQLConn is os _ _) cmd = do guardUnconsumed conn writeCommand cmd os waitCommandReply is {-# INLINE command #-} waitCommandReply :: InputStream Packet -> IO OK waitCommandReply is = do p <- readPacket is if | isERR p -> decodeFromPacket p >>= throwIO . ERRException | isOK p -> decodeFromPacket p | otherwise -> throwIO (UnexpectedPacket p) {-# INLINE waitCommandReply #-} readPacket :: InputStream Packet -> IO Packet readPacket is = Stream.read is >>= maybe (throwIO NetworkException) (\ p@(Packet len _ bs) -> if len < 16777215 then return p else go len [bs]) where go len acc = Stream.read is >>= maybe (throwIO NetworkException) (\ (Packet len' seqN bs) -> do let len'' = len + len' acc' = bs:acc if len' < 16777215 then return (Packet len'' seqN (L.concat . reverse $ acc')) else len'' `seq` go len'' acc' ) {-# INLINE readPacket #-} writeCommand :: Command -> OutputStream Packet -> IO () writeCommand a os = let bs = Binary.runPut (Binary.put a) in go (fromIntegral (L.length bs)) 0 bs os where go len seqN bs os' = do if len < 16777215 then Stream.write (Just (Packet len seqN bs)) os' else do let (bs', rest) = L.splitAt 16777215 bs seqN' = seqN + 1 len' = len - 16777215 Stream.write (Just (Packet 16777215 seqN bs')) os' seqN' `seq` len' `seq` go len' seqN' rest os' {-# INLINE writeCommand #-} guardUnconsumed :: MySQLConn -> IO () guardUnconsumed (MySQLConn _ _ _ consumed) = do c <- readIORef consumed unless c (throwIO UnconsumedResultSet) {-# INLINE guardUnconsumed #-} writeIORef' :: IORef a -> a -> IO () writeIORef' ref x = x `seq` writeIORef ref x {-# INLINE writeIORef' #-} -------------------------------------------------------------------------------- -- default Capability Flags #define CLIENT_LONG_PASSWORD 0x00000001 #define CLIENT_FOUND_ROWS 0x00000002 #define CLIENT_LONG_FLAG 0x00000004 #define CLIENT_CONNECT_WITH_DB 0x00000008 #define CLIENT_NO_SCHEMA 0x00000010 #define CLIENT_COMPRESS 0x00000020 #define CLIENT_ODBC 0x00000040 #define CLIENT_LOCAL_FILES 0x00000080 #define CLIENT_IGNORE_SPACE 0x00000100 #define CLIENT_PROTOCOL_41 0x00000200 #define CLIENT_INTERACTIVE 0x00000400 #define CLIENT_SSL 0x00000800 #define CLIENT_IGNORE_SIGPIPE 0x00001000 #define CLIENT_TRANSACTIONS 0x00002000 #define CLIENT_RESERVED 0x00004000 #define CLIENT_SECURE_CONNECTION 0x00008000 #define CLIENT_MULTI_STATEMENTS 0x00010000 #define CLIENT_MULTI_RESULTS 0x00020000 #define CLIENT_PS_MULTI_RESULTS 0x00040000 #define CLIENT_PLUGIN_AUTH 0x00080000 #define CLIENT_CONNECT_ATTRS 0x00100000 #define CLIENT_PLUGIN_AUTH_LENENC_CLIENT_DATA 0x00200000 clientCap :: Word32 clientCap = CLIENT_LONG_PASSWORD .|. CLIENT_LONG_FLAG .|. CLIENT_CONNECT_WITH_DB .|. CLIENT_IGNORE_SPACE .|. CLIENT_PROTOCOL_41 .|. CLIENT_TRANSACTIONS .|. CLIENT_MULTI_STATEMENTS .|. CLIENT_SECURE_CONNECTION clientMaxPacketSize :: Word32 clientMaxPacketSize = 0x00ffffff :: Word32 -- | Always use @utf8_general_ci@ when connecting mysql server, -- since this will simplify string decoding. clientCharset :: Word8 clientCharset = 0x21 :: Word8 -------------------------------------------------------------------------------- -- Exceptions data NetworkException = NetworkException deriving (Typeable, Show) instance Exception NetworkException data UnconsumedResultSet = UnconsumedResultSet deriving (Typeable, Show) instance Exception UnconsumedResultSet data ERRException = ERRException ERR deriving (Typeable, Show) instance Exception ERRException data UnexpectedPacket = UnexpectedPacket Packet deriving (Typeable, Show) instance Exception UnexpectedPacket