{-|
Module      : Database.MySQL.Connection
Description : Connection management
Copyright   : (c) Winterland, 2016
License     : BSD
Maintainer  : drkoster@qq.com
Stability   : experimental
Portability : PORTABLE

This is an internal module; the 'MySQLConn' type should not be accessed directly by users.

-}

module Database.MySQL.Connection where

import           Control.Applicative
import           Control.Exception               (Exception, bracketOnError, throwIO, catch, SomeException)
import           Control.Monad
import qualified Crypto.Hash                     as Crypto
import qualified Data.Binary                     as Binary
import qualified Data.Binary.Put                 as Binary
import           Data.Bits
import qualified Data.ByteArray                  as BA
import           Data.ByteString                 (ByteString)
import qualified Data.ByteString                 as B
import qualified Data.ByteString.Lazy            as L
import qualified Data.ByteString.Unsafe          as B
import           Data.IORef                      (IORef, newIORef, readIORef, writeIORef)
import           Data.Typeable
import           Data.Word
import           Database.MySQL.Protocol.Auth
import           Database.MySQL.Protocol.Command
import           Database.MySQL.Protocol.Packet
import           Network.Socket                  (HostName, PortNumber)
import           System.IO.Streams               (InputStream)
import qualified System.IO.Streams               as Stream
import qualified System.IO.Streams.TCP           as TCP
import qualified Data.Connection                 as TCP

--------------------------------------------------------------------------------

-- | A 'MySQLConn' bundles everything needed to talk to a server in terms of
-- MySQL 'Packet's: a stream to read from, an action to write with, and an
-- action that shuts the connection down.
--
-- A single 'MySQLConn' should not be shared between threads; if you have to,
-- consider protecting it with an @MVar@.
--
data MySQLConn = MySQLConn
    { mysqlRead        :: {-# UNPACK #-} !(InputStream Packet)
      -- ^ Incoming packets.
    , mysqlWrite       :: (Packet -> IO ())
      -- ^ Send one packet to the server.
    , mysqlCloseSocket :: IO ()
      -- ^ The action run by 'close'.
    , isConsumed       :: {-# UNPACK #-} !(IORef Bool)
      -- ^ Checked by 'guardUnconsumed' before a command is sent; a 'False'
      -- value makes it throw 'UnconsumedResultSet'.
    }

-- | Everything you need to establish a MySQL connection.
--
-- To setup a TLS connection, use module "Database.MySQL.TLS" or "Database.MySQL.OpenSSL".
--
data ConnectInfo = ConnectInfo
    { ciHost     :: HostName
    , ciPort     :: PortNumber
    , ciDatabase :: ByteString
    , ciUser     :: ByteString
    , ciPassword :: ByteString
    , ciCharset  :: Word8
    } deriving Show

-- | A simple 'ConnectInfo' targeting localhost with @user=root@ and an empty password.
--
-- The default charset is @utf8_general_ci@ so that older (< 5.5.3) MySQL versions
-- are supported. Be aware that this is only a partial utf8 encoding; use
-- 'defaultConnectInfoMB4' instead if you need the full utf8 charset (emoji, etc.).
-- You can query your server's support with
-- @SELECT id, collation_name FROM information_schema.collations ORDER BY id;@
--
defaultConnectInfo :: ConnectInfo
defaultConnectInfo = ConnectInfo
    { ciHost     = "127.0.0.1"
    , ciPort     = 3306
    , ciDatabase = ""
    , ciUser     = "root"
    , ciPassword = ""
    , ciCharset  = utf8_general_ci
    }

-- | 'defaultConnectInfo' with the charset set to @utf8mb4_unicode_ci@.
--
-- This is recommended on any MySQL server version >= 5.5.3.
--
defaultConnectInfoMB4 :: ConnectInfo
defaultConnectInfoMB4 = defaultConnectInfo { ciCharset = utf8mb4_unicode_ci }

-- | Collation id used by 'defaultConnectInfo'.
utf8_general_ci :: Word8
utf8_general_ci = 33

-- | Collation id used by 'defaultConnectInfoMB4'.
utf8mb4_unicode_ci :: Word8
utf8mb4_unicode_ci = 224

--------------------------------------------------------------------------------

-- | Socket buffer size.
--
-- This may be exposed through 'ConnectInfo' later.
--
bUFSIZE :: Int
bUFSIZE = 16384

-- | Establish a MySQL connection.
--
-- This is 'connectDetail' with the server greeting dropped.
--
connect :: ConnectInfo -> IO MySQLConn
connect info = snd <$> connectDetail info

-- | Establish a MySQL connection and also return the server's 'Greeting', so you can find the server's version etc.
--
-- The handshake reads the server greeting, answers with the 'Auth' packet
-- built by 'mkAuth' (sequence number 1) and then inspects the reply: an OK
-- packet yields a ready-to-use 'MySQLConn', anything else is decoded as an
-- error packet and thrown as 'ERRException'.
--
-- The whole handshake runs under 'bracketOnError', so the underlying
-- connection is closed if any step throws; the failure branch therefore does
-- not close it a second time.
--
connectDetail :: ConnectInfo -> IO (Greeting, MySQLConn)
connectDetail (ConnectInfo host port db user pass charset)
    = bracketOnError open TCP.close go
  where
    open = connectWithBufferSize host port bUFSIZE

    go c = do
        let is = TCP.source c
        is' <- decodeInputStream is
        p <- readPacket is'
        greet <- decodeFromPacket p
        let auth = mkAuth db user pass charset greet
        write c $ encodeToPacket 1 auth
        q <- readPacket is'
        if isOK q
        then do
            consumed <- newIORef True
            let -- After COM_QUIT the server will either reply with an OK packet
                -- or directly close the connection, so any failure while
                -- waiting for the reply is deliberately ignored.
                waitNotMandatoryOK = catch
                    (void (waitCommandReply is'))
                    ((\ _ -> return ()) :: SomeException -> IO ())

                quit = writeCommand COM_QUIT (write c) >> waitNotMandatoryOK

                -- If sending COM_QUIT itself fails we still have to release
                -- the socket; close it and re-raise the original exception.
                closeAndRethrow :: SomeException -> IO ()
                closeAndRethrow e = TCP.close c >> throwIO e

                closeConn = catch quit closeAndRethrow >> TCP.close c

                conn = MySQLConn is' (write c) closeConn consumed
            return (greet, conn)
        -- No explicit close here: 'bracketOnError' closes the connection
        -- once the exception leaves 'go'.
        else decodeFromPacket q >>= throwIO . ERRException

    connectWithBufferSize h p bs = TCP.connectSocket h p >>= TCP.socketToConnection bs

    write c a = TCP.send c $ Binary.runPut . Binary.put $ a

-- | Build the 'Auth' packet answering a server 'Greeting'.
--
-- The two salt halves of the greeting are concatenated and combined with the
-- password as @SHA1(pass) XOR SHA1(salt ++ SHA1(SHA1(pass)))@. An empty
-- password produces an empty scramble.
--
mkAuth :: ByteString -> ByteString -> ByteString -> Word8 -> Greeting -> Auth
mkAuth db user pass charset greet =
    let salt        = greetingSalt1 greet `B.append` greetingSalt2 greet
        scrambleBuf = scramble salt pass
    in Auth clientCap clientMaxPacketSize charset user scrambleBuf db
  where
    scramble :: ByteString -> ByteString -> ByteString
    scramble salt pass'
        | B.null pass' = B.empty
        | otherwise    = B.pack (B.zipWith xor sha1pass withSalt)
      where
        sha1pass = sha1 pass'
        withSalt = sha1 (salt `B.append` sha1 sha1pass)

    sha1 :: ByteString -> ByteString
    sha1 = BA.convert . (Crypto.hash :: ByteString -> Crypto.Digest Crypto.SHA1)

-- | A specialized 'decodeInputStream' here for speed.
--
-- Every element is produced by reading a 4-byte header (a 3-byte
-- little-endian payload length followed by a 1-byte sequence number) and then
-- exactly that many payload bytes. Bytes read past the end of the payload are
-- pushed back onto the source stream. 'NetworkException' is thrown if the
-- source ends in the middle of a payload.
--
decodeInputStream :: InputStream ByteString -> IO (InputStream Packet)
decodeInputStream is = Stream.makeInputStream $ do
    bs <- Stream.readExactly 4 is
    -- 'readExactly' guarantees 4 bytes, which makes 'B.unsafeIndex' safe here.
    let len =  fromIntegral (bs `B.unsafeIndex` 0)
           .|. fromIntegral (bs `B.unsafeIndex` 1) `shiftL` 8
           .|. fromIntegral (bs `B.unsafeIndex` 2) `shiftL` 16
        seqN = bs `B.unsafeIndex` 3
    body <- loopRead [] len is
    return . Just $ Packet len seqN body
  where
    -- Collect chunks (in reverse order) until @k@ payload bytes were read.
    loopRead acc 0 _ = return $! L.fromChunks (reverse acc)
    loopRead acc k is' = do
        bs <- Stream.read is'
        case bs of
            Nothing -> throwIO NetworkException
            Just bs' -> do
                let l = fromIntegral (B.length bs')
                if l >= k
                then do
                    let (a, rest) = B.splitAt (fromIntegral k) bs'
                    -- the surplus belongs to the next packet
                    unless (B.null rest) (Stream.unRead rest is')
                    return $! L.fromChunks (reverse (a:acc))
                else do
                    let k' = k - l
                    k' `seq` loopRead (bs':acc) k' is'

-- | Close a MySQL connection.
--
-- Runs the close action stored in the 'MySQLConn'.
--
close :: MySQLConn -> IO ()
close (MySQLConn _ _ closeSocket _) = closeSocket

-- | Send a 'COM_PING'.
--
ping :: MySQLConn -> IO OK
ping = flip command COM_PING

--------------------------------------------------------------------------------
-- helpers

-- | Send a 'Command' which doesn't return a result set.
--
-- Throws 'UnconsumedResultSet' if the connection is not marked as consumed,
-- otherwise writes the command and waits for the reply with 'waitCommandReply'.
--
command :: MySQLConn -> Command -> IO OK
command conn@(MySQLConn is os _ _) cmd = do
    guardUnconsumed conn
    writeCommand cmd os
    waitCommandReply is
{-# INLINE command #-}

-- | Read one reply packet and decode it as 'OK'.
--
-- An error packet is thrown as 'ERRException'; a packet that is neither an
-- error nor an OK packet is thrown as 'UnexpectedPacket'.
--
waitCommandReply :: InputStream Packet -> IO OK
waitCommandReply is = do
    p <- readPacket is
    if  | isERR p   -> decodeFromPacket p >>= throwIO . ERRException
        | isOK  p   -> decodeFromPacket p
        | otherwise -> throwIO (UnexpectedPacket p)
{-# INLINE waitCommandReply #-}

-- | Like 'waitCommandReply', but keeps reading for as long as 'isThereMore'
-- holds for the decoded 'OK', returning all of them in order.
--
waitCommandReplys :: InputStream Packet -> IO [OK]
waitCommandReplys is = do
    p <- readPacket is
    if  | isERR p -> decodeFromPacket p >>= throwIO . ERRException
        | isOK  p -> do
            ok <- decodeFromPacket p
            if isThereMore ok
            then (ok :) <$> waitCommandReplys is
            else return [ok]
        | otherwise -> throwIO (UnexpectedPacket p)
{-# INLINE waitCommandReplys #-}

-- | Read one logical packet.
--
-- A packet whose length is 16777215 (@0xFFFFFF@) is continued by the next
-- packet, so such packets are merged: lengths are summed, bodies are
-- concatenated in order, and the sequence number of the last part is kept.
-- Throws 'NetworkException' if the stream ends.
--
readPacket :: InputStream Packet -> IO Packet
readPacket is = Stream.read is >>= maybe
    (throwIO NetworkException)
    (\ p@(Packet len _ bs) -> if len < 16777215 then return p else go len [bs])
  where
    -- @acc@ holds the bodies read so far, newest first.
    go len acc = Stream.read is >>= maybe
        (throwIO NetworkException)
        (\ (Packet len' seqN bs) -> do
            let len'' = len + len'
                acc'  = bs:acc
            if len' < 16777215
            then return (Packet len'' seqN (L.concat . reverse $ acc'))
            else len'' `seq` go len'' acc'
        )
{-# INLINE readPacket #-}

-- | Serialize a 'Command' and send it with the given packet writer.
--
-- Payloads of 16777215 (@0xFFFFFF@) bytes or more are split into full-sized
-- packets followed by a final shorter (possibly empty) one; the sequence
-- number starts at 0 and increases by one per packet.
--
writeCommand :: Command -> (Packet -> IO ()) -> IO ()
writeCommand a writePacket =
    let bs = Binary.runPut (putCommand a)
    in go (fromIntegral (L.length bs)) 0 bs
  where
    go len seqN bs =
        if len < 16777215
        then writePacket (Packet len seqN bs)
        else do
            let (bs', rest) = L.splitAt 16777215 bs
                seqN' = seqN + 1
                len'  = len - 16777215
            writePacket (Packet 16777215 seqN bs')
            seqN' `seq` len' `seq` go len' seqN' rest
{-# INLINE writeCommand #-}

-- | Throw 'UnconsumedResultSet' unless the connection's 'isConsumed' flag is 'True'.
--
guardUnconsumed :: MySQLConn -> IO ()
guardUnconsumed (MySQLConn _ _ _ consumed) = do
    c <- readIORef consumed
    unless c (throwIO UnconsumedResultSet)
{-# INLINE guardUnconsumed #-}

-- | 'writeIORef' that evaluates the value to weak head normal form first.
--
writeIORef' :: IORef a -> a -> IO ()
writeIORef' ref x = x `seq` writeIORef ref x
{-# INLINE writeIORef' #-}

--------------------------------------------------------------------------------
-- Exceptions

-- | Thrown when the packet stream ends while more data is expected.
data NetworkException = NetworkException deriving (Typeable, Show)
instance Exception NetworkException

-- | Thrown by 'guardUnconsumed' when the connection is not marked as consumed.
data UnconsumedResultSet = UnconsumedResultSet deriving (Typeable, Show)
instance Exception UnconsumedResultSet

-- | Wraps an error packet sent by the server.
data ERRException = ERRException ERR deriving (Typeable, Show)
instance Exception ERRException

-- | Thrown when a reply is neither an OK nor an error packet.
data UnexpectedPacket = UnexpectedPacket Packet deriving (Typeable, Show)
instance Exception UnexpectedPacket