{-# LANGUAGE DeriveDataTypeable, ForeignFunctionInterface, RecordWildCards #-}
-- |
-- Module      : Database.MySQL.Base
--
-- Haskell wrappers around the MySQL C client API functions imported from
-- "Database.MySQL.Base.C": opening and configuring connections, issuing
-- queries, and reading result sets. Failures are reported by throwing
-- 'MySQLError'.
module Database.MySQL.Base
    (
    -- * Types
      ConnectInfo(..)
    , SSLInfo(..)
    , Seconds
    , Protocol(..)
    , Option(..)
    , defaultConnectInfo
    , defaultSSLInfo
    , Connection
    , Result
    , Type(..)
    , Row
    -- * Errors
    , MySQLError(errFunction, errNumber, errMessage)
    -- * Connection management
    , connect
    , close
    , autocommit
    , ping
    , changeUser
    , selectDB
    , setCharacterSet
    -- ** Connection information
    , threadId
    , serverInfo
    , hostInfo
    , protocolInfo
    , characterSet
    , sslCipher
    , serverStatus
    -- * Querying
    , query
    , insertID
    -- ** Escaping
    , escape
    -- ** Results
    , fieldCount
    , affectedRows
    , isResultValid
    , freeResult
    , storeResult
    , useResult
    , fetchRow
    , fetchFields
    , dataSeek
    , rowSeek
    , rowTell
    -- ** Multiple results
    , nextResult
    -- * Transactions
    , commit
    , rollback
    -- * Client library information
    , clientInfo
    , clientVersion
    ) where
import Control.Applicative ((<$>), (<*>))
import Control.Exception (Exception, onException, throw, throwIO)
import Control.Monad (forM_, unless, when)
import Data.Bits ((.|.))
import Data.ByteString.Char8 ()
import Data.ByteString.Internal (ByteString, create, createAndTrim, memcpy)
import Data.ByteString.Unsafe (unsafeUseAsCStringLen)
import Data.IORef (IORef, atomicModifyIORef, newIORef, readIORef, writeIORef)
import Data.Int (Int64)
import Data.List (foldl')
import Data.Typeable (Typeable)
import Data.Word (Word, Word16, Word64)
import Database.MySQL.Base.C
import Database.MySQL.Base.Types
import Foreign.C.String (CString, peekCString, withCString)
import Foreign.C.Types (CULong)
import Foreign.Concurrent (newForeignPtr)
import Foreign.ForeignPtr hiding (newForeignPtr)
import Foreign.Marshal.Array (peekArray)
import Foreign.Ptr (Ptr, castPtr, nullPtr)
import System.IO.Unsafe (unsafePerformIO)
import System.Mem.Weak (Weak, deRefWeak, mkWeakPtr)
-- | Settings used by 'connect' to open a connection.
--
-- Any 'String' / 'FilePath' field left empty is passed to the C library
-- as @NULL@ (see 'withString').
data ConnectInfo = ConnectInfo {
      connectHost :: String
      -- ^ Server host name, passed to @mysql_real_connect@.
    , connectPort :: Word16
      -- ^ Server port, passed to @mysql_real_connect@.
    , connectUser :: String
      -- ^ User name to log in as.
    , connectPassword :: String
      -- ^ Password for 'connectUser'.
    , connectDatabase :: String
      -- ^ Database selected on connect.
    , connectOptions :: [Option]
      -- ^ Applied one by one with @mysql_options@ before connecting; each
      -- is also mapped through @toConnectFlag@ and combined into the
      -- client-flags argument of @mysql_real_connect@.
    , connectPath :: FilePath
      -- ^ Passed to @mysql_real_connect@ as the argument following the port
      -- (the socket / named-pipe path in the C API).
    , connectSSL :: Maybe SSLInfo
      -- ^ When present, applied with @mysql_ssl_set@ before connecting.
    } deriving (Eq, Read, Show, Typeable)
-- | SSL settings handed to @mysql_ssl_set@ by 'connect', in the order
-- key, cert, CA, CA path, ciphers. Empty fields are passed as @NULL@.
data SSLInfo = SSLInfo {
      sslKey :: FilePath
      -- ^ Key file path.
    , sslCert :: FilePath
      -- ^ Certificate file path.
    , sslCA :: FilePath
      -- ^ Certificate-authority file path.
    , sslCAPath :: FilePath
      -- ^ Directory of certificate-authority files.
    , sslCiphers :: String
      -- ^ Cipher list.
    } deriving (Eq, Read, Show, Typeable)
-- | Exceptions thrown by this module.
data MySQLError
    -- | A failure reported by the client library for a connection handle;
    -- code and message come from @mysql_errno@ / @mysql_error@.
    = ConnectionError {
      errFunction :: String
      -- ^ Name of the Haskell API function that failed.
    , errNumber :: Int
      -- ^ Numeric error code.
    , errMessage :: String
      -- ^ Human-readable error text.
    }
    -- | Raised when a 'Result' cannot be used (e.g. after it was freed);
    -- 'withRes' raises it with code 0.
    | ResultError {
      errFunction :: String
    , errNumber :: Int
    , errMessage :: String
    } deriving (Eq, Show, Typeable)

instance Exception MySQLError
-- | An open connection created by 'connect'.
data Connection = Connection {
      connFP :: ForeignPtr MYSQL
      -- ^ The connection handle; its finalizer runs the same action as
      -- 'connClose'.
    , connClose :: IO ()
      -- ^ Frees the current result (if any), then closes the handle once;
      -- later calls are no-ops.
    , connResult :: IORef (Maybe (Weak Result))
      -- ^ Weak reference to the most recent result produced by
      -- 'storeResult' / 'useResult', so it can be freed before a new
      -- result is obtained or the connection is closed.
    } deriving (Typeable)
-- | A result set obtained from 'storeResult' or 'useResult'.
--
-- NOTE: the record selectors are partial; they fail on 'EmptyResult'.
data Result = Result {
      resFP :: ForeignPtr MYSQL_RES
      -- ^ The C result handle; its finalizer frees it (at most once).
    , resFields :: {-# UNPACK #-} !Int
      -- ^ Column count, from @mysql_field_count@.
    , resConnection :: Connection
      -- ^ Connection the result was obtained from (used for error reports).
    , resValid :: IORef Bool
      -- ^ 'True' until the result is freed; checked by 'withRes'.
    , resFetchFields :: Ptr MYSQL_RES -> IO (Ptr Field)
      -- ^ Field-metadata fetcher chosen by 'storeResult' / 'useResult'.
    , resFetchRow :: Ptr MYSQL_RES -> IO MYSQL_ROW
      -- ^ Row fetcher chosen by 'storeResult' / 'useResult'.
    , resFetchLengths :: Ptr MYSQL_RES -> IO (Ptr CULong)
      -- ^ Column-length fetcher chosen by 'storeResult' / 'useResult'.
    , resFreeResult :: Ptr MYSQL_RES -> IO ()
      -- ^ Function that releases the C result handle.
    }
    -- | The statement produced no result set (the library returned NULL
    -- with a field count of zero).
    | EmptyResult
    deriving (Typeable)
-- | Opaque row position returned by 'rowTell' / 'rowSeek' (wraps the C
-- library's row offset).
newtype Row = Row MYSQL_ROW_OFFSET
    deriving (Typeable)
-- | Baseline settings: user @root@ with an empty password on
-- @localhost@ port 3306, database @test@, character set @utf8@, no
-- socket path and no SSL. Override individual fields with record update.
defaultConnectInfo :: ConnectInfo
defaultConnectInfo =
  ConnectInfo
    { connectUser     = "root"
    , connectPassword = ""
    , connectHost     = "localhost"
    , connectPort     = 3306
    , connectPath     = ""
    , connectDatabase = "test"
    , connectSSL      = Nothing
    , connectOptions  = [CharsetName "utf8"]
    }
-- | SSL settings with every field empty; 'connect' passes each empty
-- field to @mysql_ssl_set@ as @NULL@. Fill in the fields you need.
defaultSSLInfo :: SSLInfo
defaultSSLInfo =
  SSLInfo
    { sslCiphers = ""
    , sslCAPath  = ""
    , sslCA      = ""
    , sslCert    = ""
    , sslKey     = ""
    }
-- | Open a connection to a MySQL server.
--
-- Steps:
--
-- 1. Allocate a handle with @mysql_init@.
-- 2. Apply 'connectSSL' (if any) with @mysql_ssl_set@.
-- 3. Apply each of 'connectOptions' with @mysql_options@.
-- 4. Call @mysql_real_connect@.
--
-- Empty strings in the 'ConnectInfo' are passed to C as @NULL@.
--
-- Throws 'ConnectionError' on failure. If anything fails after the
-- handle has been allocated, the handle is released with @mysql_close@
-- before the exception propagates, so a failed attempt does not leak it.
--
-- The returned connection is closed by 'close' or by its finalizer,
-- whichever comes first. Closing more than once is harmless.
connect :: ConnectInfo -> IO Connection
connect ConnectInfo{..} = do
    closed <- newIORef False
    ptr0 <- mysql_init nullPtr
    -- mysql_init returns NULL when it cannot allocate a handle; none of
    -- the calls below may be made on a NULL handle.
    when (ptr0 == nullPtr) $
      throwIO $ ConnectionError "connect" 0 "mysql_init: insufficient memory"
    -- On any failure while configuring/connecting, free the handle.
    ptr <- establish ptr0 `onException` mysql_close ptr0
    res <- newIORef Nothing
    let realClose = do
          cleanupConnResult res
          -- Flip the flag atomically so only the first caller closes.
          wasClosed <- atomicModifyIORef closed $ \prev -> (True, prev)
          unless wasClosed $ mysql_close ptr
    fp <- newForeignPtr ptr realClose
    return Connection {
                 connFP = fp
               , connClose = realClose
               , connResult = res
               }
  where
    -- Client flags are bit masks: OR them so an option listed twice still
    -- sets its bit once (adding would carry into an unrelated bit).
    flags = foldl' (.|.) 0 . map toConnectFlag $ connectOptions

    -- Configure the freshly allocated handle and connect it. On failure
    -- connectionError_ reads the error from the handle before throwing,
    -- i.e. before the caller's cleanup closes it.
    establish ptr0 = do
      case connectSSL of
        Nothing -> return ()
        Just SSLInfo{..} ->
          withString sslKey $ \ckey ->
          withString sslCert $ \ccert ->
          withString sslCA $ \cca ->
          withString sslCAPath $ \ccapath ->
          withString sslCiphers $ \ccipher ->
            mysql_ssl_set ptr0 ckey ccert cca ccapath ccipher
            >> return ()
      forM_ connectOptions $ \opt -> do
        r <- mysql_options ptr0 opt
        unless (r == 0) $ connectionError_ "connect" ptr0
      ptr <- withString connectHost $ \chost ->
             withString connectUser $ \cuser ->
             withString connectPassword $ \cpass ->
             withString connectDatabase $ \cdb ->
             withString connectPath $ \cpath ->
               mysql_real_connect ptr0 chost cuser cpass cdb
                                  (fromIntegral connectPort) cpath flags
      when (ptr == nullPtr) $
        connectionError_ "connect" ptr0
      return ptr
-- | Free the result recorded in a connection's result slot, if one is
-- recorded and still alive. A dead weak reference (the result was
-- already collected) or an empty slot is ignored. Freeing is idempotent
-- (see 'freeResult_'), so calling this repeatedly is harmless.
cleanupConnResult :: IORef (Maybe (Weak Result)) -> IO ()
cleanupConnResult slot =
  readIORef slot >>= mapM_ (\weakRes -> deRefWeak weakRes >>= mapM_ freeResult)
-- | Close a connection: frees its current result (if any) and then the
-- connection handle. Calling it again, or the finalizer running later,
-- has no further effect.
close :: Connection -> IO ()
close conn = connClose conn
{-# INLINE close #-}
-- | Check the connection with @mysql_ping@; throws 'ConnectionError' if
-- the call reports a non-zero status.
ping :: Connection -> IO ()
ping conn = withConn conn $ \ptr -> do
  status <- mysql_ping ptr
  check "ping" conn status
-- | The connection's thread ID as reported by @mysql_thread_id@.
threadId :: Connection -> IO Word
threadId conn = do
  tid <- withConn conn mysql_thread_id
  return (fromIntegral tid)
-- | Server version string, from @mysql_get_server_info@.
serverInfo :: Connection -> IO String
serverInfo conn = withConn conn (\ptr -> mysql_get_server_info ptr >>= peekCString)
-- | Connection host description, from @mysql_get_host_info@.
hostInfo :: Connection -> IO String
hostInfo conn = withConn conn $ \ptr -> do
  cstr <- mysql_get_host_info ptr
  peekCString cstr
-- | Protocol version of the connection, from @mysql_get_proto_info@.
protocolInfo :: Connection -> IO Word
protocolInfo conn = fmap fromIntegral (withConn conn mysql_get_proto_info)
-- | Set the connection's character set with @mysql_set_character_set@;
-- throws 'ConnectionError' on a non-zero status.
setCharacterSet :: Connection -> String -> IO ()
setCharacterSet conn cs =
  withCString cs $ \charsetName ->
    withConn conn $ \ptr -> do
      rc <- mysql_set_character_set ptr charsetName
      check "setCharacterSet" conn rc
-- | Name of the connection's current character set, from
-- @mysql_character_set_name@.
characterSet :: Connection -> IO String
characterSet conn = withConn conn $ \ptr -> do
  name <- mysql_character_set_name ptr
  peekCString name
-- | The SSL cipher in use, from @mysql_get_ssl_cipher@. Returns
-- 'Nothing' when the library returns a NULL pointer.
sslCipher :: Connection -> IO (Maybe String)
sslCipher conn = withConn conn $ \ptr -> do
  cipher <- mysql_get_ssl_cipher ptr
  withPtr peekCString cipher
-- | Server status summary from @mysql_stat@; a NULL return is reported
-- as 'ConnectionError'.
serverStatus :: Connection -> IO String
serverStatus conn = withConn conn $ \ptr ->
  mysql_stat ptr >>= \statusStr ->
    checkNull "serverStatus" conn statusStr >> peekCString statusStr
-- | Client library version string, from @mysql_get_client_info@.
--
-- 'unsafePerformIO' is used because the binding is a plain 'CString'
-- value. It is read once (NOINLINE keeps it a single shared CAF).
-- NOTE(review): this assumes the C string is constant for the life of the
-- process — confirm against the library documentation.
clientInfo :: String
clientInfo = unsafePerformIO readInfo
  where
    readInfo = peekCString mysql_get_client_info
{-# NOINLINE clientInfo #-}
-- | Client library version as a number, from @mysql_get_client_version@.
clientVersion :: Word
clientVersion = let version = mysql_get_client_version in fromIntegral version
{-# NOINLINE clientVersion #-}
-- | Turn autocommit on ('True') or off ('False') with
-- @mysql_autocommit@; throws 'ConnectionError' on a non-zero status.
autocommit :: Connection -> Bool -> IO ()
autocommit conn onOff = withConn conn $ \ptr ->
    check "autocommit" conn =<< mysql_autocommit ptr mode
  where
    mode | onOff     = 1
         | otherwise = 0
-- | Change the logged-in user (and optionally the database, 'Nothing'
-- passes NULL) with @mysql_change_user@; throws 'ConnectionError' on a
-- non-zero status.
changeUser :: Connection -> String -> String -> Maybe String -> IO ()
changeUser conn user pass mdb =
  withCString user $ \userC ->
  withCString pass $ \passC ->
  withMaybeString mdb $ \dbC ->
  withConn conn $ \ptr -> do
    rc <- mysql_change_user ptr userC passC dbC
    check "changeUser" conn rc
-- | Make the named database the default with @mysql_select_db@; throws
-- 'ConnectionError' on a non-zero status.
selectDB :: Connection -> String -> IO ()
selectDB conn db =
  withCString db $ \dbName ->
    withConn conn $ \ptr -> do
      rc <- mysql_select_db ptr dbName
      check "selectDB" conn rc
-- | Send a statement with @mysql_real_query@. The byte length is passed
-- explicitly, so the text need not be NUL-terminated. Throws
-- 'ConnectionError' on a non-zero status.
query :: Connection -> ByteString -> IO ()
query conn q =
  withConn conn $ \ptr ->
    unsafeUseAsCStringLen q $ \(sql, len) -> do
      rc <- mysql_real_query ptr sql (fromIntegral len)
      check "query" conn rc
-- | Value reported by @mysql_insert_id@ for this connection.
insertID :: Connection -> IO Word64
insertID conn = do
  lastId <- withConn conn mysql_insert_id
  return (fromIntegral lastId)
-- | Column count.
--
-- * For a connection: @mysql_field_count@.
-- * For a result: its stored column count.
-- * For 'EmptyResult': 0.
fieldCount :: Either Connection Result -> IO Int
fieldCount (Left conn) =
  withConn conn (\ptr -> fromIntegral <$> mysql_field_count ptr)
fieldCount (Right r) =
  case r of
    EmptyResult -> return 0
    Result{}    -> return (resFields r)
-- | Row count reported by @mysql_affected_rows@, converted to 'Int64'.
affectedRows :: Connection -> IO Int64
affectedRows conn = withConn conn $ \ptr -> do
  n <- mysql_affected_rows ptr
  return (fromIntegral n)
-- | Retrieve the result of the last statement with @mysql_store_result@.
-- The returned 'Result' uses the @*_nonblock@ fetch/free bindings.
-- NOTE(review): those bindings presumably differ from the plain ones only
-- in FFI call safety — confirm in "Database.MySQL.Base.C".
storeResult :: Connection -> IO Result
storeResult conn =
  frobResult "storeResult" mysql_store_result
             mysql_fetch_fields_nonblock
             mysql_fetch_row_nonblock
             mysql_fetch_lengths_nonblock
             mysql_free_result_nonblock
             conn
-- | Begin retrieving the result of the last statement with
-- @mysql_use_result@. The returned 'Result' uses the plain fetch/free
-- bindings.
useResult :: Connection -> IO Result
useResult conn =
  frobResult "useResult" mysql_use_result
             mysql_fetch_fields
             mysql_fetch_row
             mysql_fetch_lengths
             mysql_free_result
             conn
-- | Shared implementation of 'storeResult' and 'useResult'.
--
-- Arguments, in order:
--
-- * the function name reported in errors;
-- * the C call that obtains the result handle from the connection;
-- * the fetch-fields function the returned 'Result' will use;
-- * the fetch-row function;
-- * the fetch-lengths function;
-- * the free function.
--
-- Any result previously recorded on the connection is freed first
-- ('cleanupConnResult'). After the call:
--
-- * a NULL handle with a field count of zero yields 'EmptyResult';
-- * a NULL handle with a non-zero field count raises 'ConnectionError'.
--
-- Both the 'ForeignPtr' finalizer and the weak-pointer finalizer free the
-- handle through 'freeResult_'. That function consults the shared
-- validity flag, so the handle is freed at most once.
frobResult :: String
           -> (Ptr MYSQL -> IO (Ptr MYSQL_RES))
           -> (Ptr MYSQL_RES -> IO (Ptr Field))
           -> (Ptr MYSQL_RES -> IO MYSQL_ROW)
           -> (Ptr MYSQL_RES -> IO (Ptr CULong))
           -> (Ptr MYSQL_RES -> IO ())
           -> Connection -> IO Result
frobResult func frob fetchFieldsFunc fetchRowFunc fetchLengthsFunc
           myFreeResult conn =
  withConn conn $ \ptr -> do
    -- Release the previous result (if still alive) before asking for a
    -- new one.
    cleanupConnResult (connResult conn)
    res <- frob ptr
    fields <- mysql_field_count ptr
    valid <- newIORef True
    if res == nullPtr
      then if fields == 0
             then return EmptyResult
             else connectionError func conn
      else do
        fp <- newForeignPtr res $ freeResult_ valid myFreeResult res
        let ret = Result {
                    resFP = fp
                  , resFields = fromIntegral fields
                  , resConnection = conn
                  , resValid = valid
                  , resFetchFields = fetchFieldsFunc
                  , resFetchRow = fetchRowFunc
                  , resFetchLengths = fetchLengthsFunc
                  , resFreeResult = myFreeResult
                  }
        -- Record the result weakly: the connection can free it early
        -- (cleanupConnResult) without keeping it alive itself.
        weak <- mkWeakPtr ret (Just (freeResult_ valid myFreeResult res))
        writeIORef (connResult conn) (Just weak)
        return ret
-- | Free a result's C handle now instead of waiting for its finalizer.
-- This is safe to call more than once. It does nothing for 'EmptyResult'.
freeResult :: Result -> IO ()
freeResult EmptyResult = return ()
freeResult r@Result{} =
  withForeignPtr (resFP r) (freeResult_ (resValid r) (resFreeResult r))
-- | 'True' while a result has not been freed. Always 'False' for
-- 'EmptyResult'.
isResultValid :: Result -> IO Bool
isResultValid r = case r of
  EmptyResult -> return False
  Result{}    -> readIORef (resValid r)
-- | Clear the validity flag atomically and run the free function only if
-- the flag was previously set. The handle is therefore freed at most
-- once, no matter how many callers or finalizers reach this.
freeResult_ :: IORef Bool -> (Ptr MYSQL_RES -> IO ()) -> Ptr MYSQL_RES -> IO ()
freeResult_ validRef release resPtr = do
  stillValid <- atomicModifyIORef validRef (\v -> (False, v))
  if stillValid then release resPtr else return ()
-- | Fetch the next row as one entry per column.
--
-- * A SQL NULL column becomes 'Nothing'.
-- * Any other column is copied into a fresh 'ByteString'.
-- * An empty list means the row fetcher returned NULL.
-- * 'EmptyResult' always gives @[]@.
--
-- Throws 'ResultError' if the result was freed. Throws 'ConnectionError'
-- if the column-length array is NULL.
fetchRow :: Result -> IO [Maybe ByteString]
fetchRow res@Result{..} = withRes "fetchRow" res $ \ptr -> do
    rowPtr <- resFetchRow ptr
    if rowPtr == nullPtr
      then return []
      else do
        lenPtr <- resFetchLengths ptr
        checkNull "fetchRow" resConnection lenPtr
        lengths <- peekArray resFields lenPtr
        columns <- peekArray resFields rowPtr
        sequence (zipWith copyColumn lengths columns)
  where
    -- Copy exactly 'len' bytes, taken from the lengths array rather than
    -- found by scanning for NUL, so embedded NUL bytes are preserved.
    copyColumn len = withPtr $ \colPtr ->
      create (fromIntegral len) $ \dst ->
        memcpy dst (castPtr colPtr) (fromIntegral len)
fetchRow EmptyResult = return []
-- | Column metadata of a result, one 'Field' per column. Gives @[]@ for
-- 'EmptyResult'. Throws 'ResultError' if the result was freed.
fetchFields :: Result -> IO [Field]
fetchFields EmptyResult = return []
fetchFields res@Result{..} = withRes "fetchFields" res $ \ptr ->
  resFetchFields ptr >>= peekArray resFields
-- | Move to the given row number with @mysql_data_seek@.
dataSeek :: Result -> Int64 -> IO ()
dataSeek res row =
  withRes "dataSeek" res (\ptr -> mysql_data_seek ptr (fromIntegral row))
-- | Current row position (@mysql_row_tell@), usable with 'rowSeek'.
rowTell :: Result -> IO Row
rowTell res = withRes "rowTell" res (fmap Row . mysql_row_tell)
-- | Jump to a row position obtained from 'rowTell' (@mysql_row_seek@).
-- Returns the position reported by the library.
rowSeek :: Result -> Row -> IO Row
rowSeek res (Row offset) =
  withRes "rowSeek" res (\ptr -> fmap Row (mysql_row_seek ptr offset))
-- | Advance to the next result of a multi-result statement.
--
-- Frees the current result first, then calls @mysql_next_result@:
--
-- * 0 means another result exists, and this returns 'True'.
-- * -1 means there are no more results, and this returns 'False'.
-- * Any other value is reported as 'ConnectionError'.
nextResult :: Connection -> IO Bool
nextResult conn = withConn conn $ \ptr -> do
    cleanupConnResult (connResult conn)
    status <- mysql_next_result ptr
    case () of
      _ | status == 0    -> return True
        | status == (-1) -> return False
        | otherwise      -> connectionError "nextResult" conn
-- | Commit the current transaction (@mysql_commit@); throws
-- 'ConnectionError' on a non-zero status.
commit :: Connection -> IO ()
commit conn = withConn conn $ \ptr -> do
  rc <- mysql_commit ptr
  check "commit" conn rc
-- | Roll back the current transaction (@mysql_rollback@); throws
-- 'ConnectionError' on a non-zero status.
rollback :: Connection -> IO ()
rollback conn = withConn conn $ \ptr -> do
  rc <- mysql_rollback ptr
  check "rollback" conn rc
-- | Escape a byte string with @mysql_real_escape_string@ on this
-- connection.
--
-- The output buffer holds @2 * length + 1@ bytes. It is trimmed to the
-- length the C function reports.
-- NOTE(review): newer client libraries may report an error by returning
-- @(unsigned long)-1@. That case is not handled here — confirm against
-- the library version in use.
escape :: Connection -> ByteString -> IO ByteString
escape conn bs =
  withConn conn $ \ptr ->
    unsafeUseAsCStringLen bs $ \(src, srcLen) ->
      createAndTrim (srcLen * 2 + 1) $ \dst -> do
        written <- mysql_real_escape_string ptr (castPtr dst) src (fromIntegral srcLen)
        return (fromIntegral written)
-- | Run an action on the raw connection handle, keeping the 'Connection'
-- alive for the duration.
withConn :: Connection -> (Ptr MYSQL -> IO a) -> IO a
withConn = withForeignPtr . connFP
-- | Run an action on the raw result handle of a 'Result'.
--
-- The first argument is the calling API function's name, reported in any
-- error. Throws 'ResultError' (code 0) in two cases:
--
-- * the result is 'EmptyResult', which has no handle; without this case
--   the partial field selectors would crash;
-- * the result has already been freed.
withRes :: String -> Result -> (Ptr MYSQL_RES -> IO a) -> IO a
withRes func EmptyResult _ =
  throwIO $ ResultError func 0 "empty result: no result set to operate on"
withRes func res act = do
  valid <- readIORef (resValid res)
  unless valid . throwIO $ ResultError func 0 "result is no longer usable"
  withForeignPtr (resFP res) act
-- | Like 'withCString', but passes @NULL@ for the empty string, so an
-- unset setting is sent to the C library as "not given".
withString :: String -> (CString -> IO a) -> IO a
withString str act
  | null str  = act nullPtr
  | otherwise = withCString str act
-- | Pass @NULL@ for 'Nothing'; otherwise marshal the string with
-- 'withCString'. Unlike 'withString', @Just ""@ yields a non-NULL empty
-- C string.
withMaybeString :: Maybe String -> (CString -> IO a) -> IO a
withMaybeString mstr act = maybe (act nullPtr) (`withCString` act) mstr
-- | Treat a non-zero status code from a C call as failure and raise the
-- connection's current error as 'ConnectionError'.
check :: (Eq a, Num a) => String -> Connection -> a -> IO ()
check func conn r
  | r == 0    = return ()
  | otherwise = connectionError func conn
{-# INLINE check #-}
-- | Treat a NULL pointer from a C call as failure and raise the
-- connection's current error as 'ConnectionError'.
checkNull :: String -> Connection -> Ptr a -> IO ()
checkNull func conn p
  | p == nullPtr = connectionError func conn
  | otherwise    = return ()
{-# INLINE checkNull #-}
-- | Apply the action to a non-NULL pointer, wrapping the result in
-- 'Just'. Return 'Nothing' for @NULL@ without running the action.
withPtr :: (Ptr a -> IO b) -> Ptr a -> IO (Maybe b)
withPtr act p =
  if p == nullPtr
    then return Nothing
    else fmap Just (act p)
-- | Raise the connection's current error as 'ConnectionError', tagged
-- with the given function name.
connectionError :: String -> Connection -> IO a
connectionError func conn = withConn conn (\ptr -> connectionError_ func ptr)
-- | Read @mysql_errno@ and @mysql_error@ from a raw handle and throw them
-- as 'ConnectionError' tagged with the given function name.
--
-- Both values are read before throwing, so the caller may release the
-- handle while the exception propagates. 'throwIO' (rather than 'throw')
-- guarantees the exception is raised at this point in the IO sequence.
connectionError_ :: String -> Ptr MYSQL -> IO a
connectionError_ func ptr = do
  errno <- mysql_errno ptr
  msg <- peekCString =<< mysql_error ptr
  throwIO $ ConnectionError func (fromIntegral errno) msg