{-|
Module      : Database.MySQL.Connection
Description : Connection management
Copyright   : (c) Winterland, 2016
License     : BSD
Maintainer  : drkoster@qq.com
Stability   : experimental
Portability : PORTABLE

This is an internal module; the 'MySQLConn' type should not be accessed directly by users.

-}

module Database.MySQL.Connection where

import           Control.Exception               (Exception, bracketOnError,
                                                  throwIO, catch, finally,
                                                  SomeException)
import           Control.Monad
import qualified Crypto.Hash                     as Crypto
import qualified Data.Binary                     as Binary
import qualified Data.Binary.Put                 as Binary
import           Data.Bits
import qualified Data.ByteArray                  as BA
import           Data.ByteString                 (ByteString)
import qualified Data.ByteString                 as B
import qualified Data.ByteString.Lazy            as L
import qualified Data.ByteString.Unsafe          as B
import           Data.IORef                      (IORef, newIORef, readIORef,
                                                  writeIORef)
import           Data.Typeable
import           Data.Word
import           Database.MySQL.Protocol.Auth
import           Database.MySQL.Protocol.Command
import           Database.MySQL.Protocol.Packet
import           Network.Socket                  (HostName, PortNumber)
import           System.IO.Streams               (InputStream)
import qualified System.IO.Streams               as Stream
import qualified System.IO.Streams.TCP           as TCP
import qualified Data.Connection                 as TCP

--------------------------------------------------------------------------------

-- | A live MySQL connection at the 'Packet' level.
--
-- It bundles the incoming packet stream, the packet writer, the action that
-- shuts the session down, and a flag guarding against sending commands while
-- a result set is still pending.
--
-- Do not use one 'MySQLConn' from several threads at once; if you must share
-- it, protect it with an @MVar@.
--
data MySQLConn = MySQLConn
    { mysqlRead        :: {-# UNPACK #-} !(InputStream Packet)
      -- ^ decoded packets coming from the server
    , mysqlWrite       :: Packet -> IO ()
      -- ^ send a single packet to the server
    , mysqlCloseSocket :: IO ()
      -- ^ end the session and release the socket
    , isConsumed       :: {-# UNPACK #-} !(IORef Bool)
      -- ^ must be 'True' for 'guardUnconsumed' to allow a new command
    }

-- | Everything needed to establish a MySQL connection.
--
-- To set up a TLS connection, use module "Database.MySQL.TLS" or "Database.MySQL.OpenSSL".
--
-- Note: the derived 'Show' instance renders 'ciPassword' in clear text, so
-- avoid logging a 'ConnectInfo' value as-is.
--
data ConnectInfo = ConnectInfo
    { ciHost     :: HostName    -- ^ server host name or address
    , ciPort     :: PortNumber  -- ^ server TCP port
    , ciDatabase :: ByteString  -- ^ database to use (the defaults pass @\"\"@)
    , ciUser     :: ByteString  -- ^ user name to authenticate as
    , ciPassword :: ByteString  -- ^ password; an empty one sends an empty auth response
    , ciCharset  :: Word8       -- ^ collation id sent in the handshake
    }
    deriving (Show)

-- | A simple 'ConnectInfo' targeting localhost with @user=root@ and an empty password.
--
-- The default charset is @utf8_general_ci@ to support older (< 5.5.3) MySQL
-- versions, but be aware that this is a partial utf8 encoding; you may want to
-- use 'defaultConnectInfoMB4' instead to support the full utf8 charset (emoji,
-- etc.). You can query which collations your server supports with
-- @SELECT id, collation_name FROM information_schema.collations ORDER BY id;@
--
defaultConnectInfo :: ConnectInfo
defaultConnectInfo = ConnectInfo
    { ciHost     = "127.0.0.1"
    , ciPort     = 3306
    , ciDatabase = ""
    , ciUser     = "root"
    , ciPassword = ""
    , ciCharset  = utf8_general_ci
    }

-- | 'defaultConnectInfo' with the charset set to @utf8mb4_unicode_ci@.
--
-- This is recommended on any MySQL server version >= 5.5.3.
--
defaultConnectInfoMB4 :: ConnectInfo
defaultConnectInfoMB4 = defaultConnectInfo { ciCharset = utf8mb4_unicode_ci }

-- | Collation id of @utf8_general_ci@ (partial, 3-byte utf8), i.e. 33.
utf8_general_ci :: Word8
utf8_general_ci = 0x21

-- | Collation id of @utf8mb4_unicode_ci@ (full utf8), i.e. 224.
utf8mb4_unicode_ci :: Word8
utf8mb4_unicode_ci = 0xE0

--------------------------------------------------------------------------------

-- | Socket buffer size in bytes (16 KiB), passed to
-- 'TCP.socketToConnection' when connecting.
--
-- This may be exposed through 'ConnectInfo' later.
--
bUFSIZE :: Int
bUFSIZE = 16 * 1024

-- | Establish a MySQL connection, discarding the server's 'Greeting'.
--
-- Use 'connectDetail' when you need the greeting (e.g. the server version).
--
connect :: ConnectInfo -> IO MySQLConn
connect ci = snd <$> connectDetail ci

-- | Establish a MySQL connection and also return the server's 'Greeting',
-- so you can find the server's version etc.
--
-- Opens a TCP connection, reads the server greeting, sends the
-- authentication packet and expects an OK reply. If anything in this
-- handshake throws, the TCP connection is closed (via 'bracketOnError').
-- If the server answers the authentication with a non-OK packet, it is
-- decoded as an 'ERR' and thrown as 'ERRException'.
--
-- The returned connection's close action sends @COM_QUIT@, waits
-- (best-effort) for the reply, and always closes the socket afterwards,
-- even if sending @COM_QUIT@ fails.
--
connectDetail :: ConnectInfo -> IO (Greeting, MySQLConn)
connectDetail (ConnectInfo host port db user pass charset)
    = bracketOnError open TCP.close go
  where
    open = connectWithBufferSize host port bUFSIZE

    go c = do
        let is = TCP.source c
        is' <- decodeInputStream is
        p <- readPacket is'
        greet <- decodeFromPacket p
        let auth = mkAuth db user pass charset greet
        write c $ encodeToPacket 1 auth
        q <- readPacket is'
        if isOK q
            then do
                consumed <- newIORef True
                let waitNotMandatoryOK :: IO ()
                    waitNotMandatoryOK =
                        -- The server either replies with an OK packet or simply
                        -- closes the connection; both are treated as success.
                        -- NOTE(review): catching 'SomeException' also swallows
                        -- asynchronous exceptions delivered during this wait.
                        void (waitCommandReply is') `catch` ignoreAll

                    ignoreAll :: SomeException -> IO ()
                    ignoreAll _ = return ()

                    -- 'finally' guarantees the socket is released even when
                    -- writing COM_QUIT throws (e.g. the peer already hung up).
                    quit =
                        (writeCommand COM_QUIT (write c) >> waitNotMandatoryOK)
                            `finally` TCP.close c

                    conn = MySQLConn is' (write c) quit consumed
                return (greet, conn)
            -- No explicit close here: the exception leaving 'go' makes
            -- 'bracketOnError' close the socket. Closing it twice could let a
            -- failing second close replace the 'ERRException'.
            else decodeFromPacket q >>= throwIO . ERRException

    connectWithBufferSize h p bs = TCP.connectSocket h p >>= TCP.socketToConnection bs

    -- Serialize any 'Binary.Binary' value and send it over the TCP connection.
    write c a = TCP.send c $ Binary.runPut . Binary.put $ a

-- | Build the handshake response ('Auth') from the connection settings and
-- the server's 'Greeting'.
--
-- The password is scrambled as
-- @SHA1(pass) XOR SHA1(salt <> SHA1(SHA1(pass)))@, where @salt@ is the
-- concatenation of the greeting's two salt parts. An empty password yields
-- an empty scramble.
mkAuth :: ByteString -> ByteString -> ByteString -> Word8 -> Greeting -> Auth
mkAuth db user pass charset greet =
    Auth clientCap clientMaxPacketSize charset user scrambled db
  where
    scrambled :: ByteString
    scrambled = scramble (greetingSalt1 greet `B.append` greetingSalt2 greet) pass

    scramble :: ByteString -> ByteString -> ByteString
    scramble salt secret
        | B.null secret = B.empty
        | otherwise     = B.pack (B.zipWith xor hashed salted)
      where
        hashed = sha1 secret
        salted = sha1 (salt `B.append` sha1 hashed)

    sha1 :: ByteString -> ByteString
    sha1 = BA.convert . (Crypto.hash :: ByteString -> Crypto.Digest Crypto.SHA1)

-- | Turn a raw byte stream into a stream of MySQL wire 'Packet's.
--
-- This is a hand-specialized decoder, kept for speed. Every wire packet starts
-- with a 4-byte header:
--
-- * a 3-byte little-endian payload length, then
-- * a 1-byte sequence number.
--
-- The header is followed by the payload.
--
-- The packet stream ends ('Nothing') when the byte stream is exhausted exactly
-- at a packet boundary, so 'readPacket' reports a closed connection as
-- 'NetworkException'. If the byte stream ends in the middle of a header or a
-- payload, 'NetworkException' is thrown directly.
--
decodeInputStream :: InputStream ByteString -> IO (InputStream Packet)
decodeInputStream is = Stream.makeInputStream $ do
    eof <- Stream.atEOF is
    if eof
        then return Nothing
        else do
            -- 'loopRead' (rather than 'Stream.readExactly') makes a truncated
            -- header fail with 'NetworkException', like a truncated payload.
            hdr <- B.concat . L.toChunks <$> loopRead [] (4 :: Int) is
            -- 'hdr' is exactly 4 bytes long, so the unchecked indexing is safe.
            let len =  fromIntegral (hdr `B.unsafeIndex` 0)
                   .|. fromIntegral (hdr `B.unsafeIndex` 1) `shiftL` 8
                   .|. fromIntegral (hdr `B.unsafeIndex` 2) `shiftL` 16
                seqN = hdr `B.unsafeIndex` 3
            body <- loopRead [] len is
            return . Just $ Packet len seqN body
  where
    -- Read exactly @k@ bytes as a lazy 'L.ByteString', pushing any surplus of
    -- the last chunk back onto the stream. Throws 'NetworkException' if the
    -- stream ends first.
    loopRead :: Integral t => [ByteString] -> t -> InputStream ByteString -> IO L.ByteString
    loopRead acc 0 _   = return $! L.fromChunks (reverse acc)
    loopRead acc k is' = do
        mchunk <- Stream.read is'
        case mchunk of
            Nothing -> throwIO NetworkException
            Just chunk -> do
                let l = fromIntegral (B.length chunk)
                if l >= k
                    then do
                        let (a, rest) = B.splitAt (fromIntegral k) chunk
                        unless (B.null rest) (Stream.unRead rest is')
                        return $! L.fromChunks (reverse (a : acc))
                    else do
                        let k' = k - l
                        k' `seq` loopRead (chunk : acc) k' is'

-- | Close a MySQL connection by running its 'mysqlCloseSocket' action.
--
close :: MySQLConn -> IO ()
close = mysqlCloseSocket

-- | Send a 'COM_PING' and wait for the server's 'OK' reply.
--
ping :: MySQLConn -> IO OK
ping conn = command conn COM_PING

--------------------------------------------------------------------------------
-- helpers

-- | Send a 'Command' that doesn't return a result set, and wait for its reply.
--
-- Throws 'UnconsumedResultSet' (via 'guardUnconsumed') if the connection's
-- 'isConsumed' flag is not set. The reply is handled by 'waitCommandReply'.
--
command :: MySQLConn -> Command -> IO OK
command conn cmd = do
    guardUnconsumed conn
    writeCommand cmd (mysqlWrite conn)
    waitCommandReply (mysqlRead conn)
{-# INLINE command #-}

-- | Read one reply packet.
--
-- * An 'OK' packet is decoded and returned.
-- * An error packet is thrown as 'ERRException'.
-- * Any other packet is thrown as 'UnexpectedPacket'.
waitCommandReply :: InputStream Packet -> IO OK
waitCommandReply is = readPacket is >>= classify
  where
    classify p
        | isERR p   = decodeFromPacket p >>= throwIO . ERRException
        | isOK p    = decodeFromPacket p
        | otherwise = throwIO (UnexpectedPacket p)
{-# INLINE waitCommandReply #-}

-- | Read a chain of 'OK' replies.
--
-- Reading continues while the latest 'OK' reports 'isThereMore'. An error
-- packet is thrown as 'ERRException'; any other non-OK packet is thrown as
-- 'UnexpectedPacket'.
waitCommandReplys :: InputStream Packet -> IO [OK]
waitCommandReplys is = readPacket is >>= classify
  where
    classify p
        | isERR p   = decodeFromPacket p >>= throwIO . ERRException
        | isOK p    = do
            ok <- decodeFromPacket p
            if isThereMore ok
                then fmap (ok :) (waitCommandReplys is)
                else return [ok]
        | otherwise = throwIO (UnexpectedPacket p)
{-# INLINE waitCommandReplys #-}

-- | Read one logical packet, reassembling payloads split across wire packets.
--
-- A wire packet whose length is 16777215 (0xFFFFFF) is continued by the next
-- one. The chain ends with a shorter (possibly empty) packet.
--
-- The reassembled 'Packet' carries the summed length and the sequence number
-- of the last fragment. Throws 'NetworkException' if the stream ends.
readPacket :: InputStream Packet -> IO Packet
readPacket is = next >>= \p@(Packet len _ bs) ->
    if len < maxPayload then return p else collect len [bs]
  where
    maxPayload = 16777215

    -- Pull the next wire packet, failing on end of stream.
    next = Stream.read is >>= maybe (throwIO NetworkException) return

    -- Accumulate continuation fragments (newest first) until a short one arrives.
    collect total chunks = do
        Packet len seqN bs <- next
        let total'  = total + len
            chunks' = bs : chunks
        if len < maxPayload
            then return (Packet total' seqN (L.concat (reverse chunks')))
            else total' `seq` collect total' chunks'
{-# INLINE readPacket #-}

-- | Serialize a 'Command' and send it through the given packet writer.
--
-- Sequence numbers start at 0. A payload of 16777215 (0xFFFFFF) bytes or
-- more is split as follows:
--
-- * maximum-size packets with increasing sequence numbers, then
-- * a final shorter, possibly empty, packet.
--
-- A payload of exactly 16777215 bytes is therefore followed by an empty packet.
writeCommand :: Command -> (Packet -> IO ()) -> IO ()
writeCommand a writePacket = go (L.length payload) 0 payload
  where
    payload = Binary.runPut (putCommand a)

    maxPayload = 16777215

    go len seqN bs
        | len < maxPayload = writePacket (Packet len seqN bs)
        | otherwise = do
            let (bs', rest) = L.splitAt maxPayload bs
                seqN'       = seqN + 1
                len'        = len - maxPayload
            writePacket (Packet maxPayload seqN bs')
            seqN' `seq` len' `seq` go len' seqN' rest
{-# INLINE writeCommand #-}

-- | Throw 'UnconsumedResultSet' unless the connection's 'isConsumed' flag is 'True'.
guardUnconsumed :: MySQLConn -> IO ()
guardUnconsumed conn = do
    consumed <- readIORef (isConsumed conn)
    unless consumed $ throwIO UnconsumedResultSet
{-# INLINE guardUnconsumed #-}

-- | Strict 'writeIORef'.
--
-- The new value is evaluated to weak head normal form before it is stored, so
-- a thunk is never placed in the reference. If that evaluation throws, the
-- reference is left unchanged.
writeIORef' :: IORef a -> a -> IO ()
writeIORef' ref x = writeIORef ref $! x
{-# INLINE writeIORef' #-}

--------------------------------------------------------------------------------
-- Exceptions

-- | Thrown when the server's byte or packet stream ends while more data is expected.
data NetworkException = NetworkException
    deriving (Show, Typeable)

instance Exception NetworkException

-- | Thrown by 'guardUnconsumed' when a command is attempted while the
-- connection's 'isConsumed' flag is 'False'.
data UnconsumedResultSet = UnconsumedResultSet
    deriving (Show, Typeable)

instance Exception UnconsumedResultSet

-- | Wraps an 'ERR' packet received from the server in place of the expected reply.
data ERRException = ERRException ERR
    deriving (Show, Typeable)

instance Exception ERRException

-- | Thrown when a reply is neither an OK nor an ERR packet where one of those
-- was expected. Carries the offending packet.
data UnexpectedPacket = UnexpectedPacket Packet
    deriving (Show, Typeable)

instance Exception UnexpectedPacket