module Database.HSQL.MySQL(connect, module Database.HSQL) where
import Database.HSQL
import Database.HSQL.Types
import Data.Dynamic
import Data.Bits
import Data.Char
import Foreign
import Foreign.C
import Control.Monad(when,unless)
import Control.Exception (throwDyn, finally)
import Control.Concurrent.MVar
import System.Time
import System.IO.Unsafe
import Text.ParserCombinators.ReadP
import Text.Read
-- Handles into the MySQL C client library. From the Haskell side they are
-- untyped pointers that are only ever passed back into the FFI calls below.
type MYSQL       = Ptr ()   -- connection handle
type MYSQL_RES   = Ptr ()   -- result-set handle
type MYSQL_FIELD = Ptr ()   -- column metadata record (read via peekByteOff)

-- Per-row data as exposed by mysql_fetch_row / mysql_fetch_lengths.
type MYSQL_ROW     = Ptr CString   -- one C string pointer per column
type MYSQL_LENGTHS = Ptr CULong    -- one value length per column
-- FFI bindings to the MySQL C client library, imported through the
-- HsMySQL.h header. None carries a safety keyword, so all are (default)
-- safe ccall imports.

-- Connection lifecycle: allocate a handle, connect it, release it.
foreign import ccall "HsMySQL.h mysql_init" mysql_init :: MYSQL -> IO MYSQL
foreign import ccall "HsMySQL.h mysql_real_connect" mysql_real_connect :: MYSQL -> CString -> CString -> CString -> CString -> CInt -> CString -> CInt -> IO MYSQL
foreign import ccall "HsMySQL.h mysql_close" mysql_close :: MYSQL -> IO ()
-- Last error code / message recorded on a handle (read by handleSqlError).
foreign import ccall "HsMySQL.h mysql_errno" mysql_errno :: MYSQL -> IO CInt
foreign import ccall "HsMySQL.h mysql_error" mysql_error :: MYSQL -> IO CString
-- Send a statement; callers in this file treat a non-zero result as failure.
foreign import ccall "HsMySQL.h mysql_query" mysql_query :: MYSQL -> CString -> IO CInt
-- Obtain the current result set (NULL when there is none or on error).
foreign import ccall "HsMySQL.h mysql_use_result" mysql_use_result :: MYSQL -> IO MYSQL_RES
-- Column metadata, one MYSQL_FIELD per call until NULL (see getFieldDefs).
foreign import ccall "HsMySQL.h mysql_fetch_field" mysql_fetch_field :: MYSQL_RES -> IO MYSQL_FIELD
foreign import ccall "HsMySQL.h mysql_free_result" mysql_free_result :: MYSQL_RES -> IO ()
-- Row data: the column value pointers and their lengths for the current row.
foreign import ccall "HsMySQL.h mysql_fetch_row" mysql_fetch_row :: MYSQL_RES -> IO MYSQL_ROW
foreign import ccall "HsMySQL.h mysql_fetch_lengths" mysql_fetch_lengths :: MYSQL_RES -> IO MYSQL_LENGTHS
-- Catalog queries backing connTables / connDescribe (NULL wildcard is passed here).
foreign import ccall "HsMySQL.h mysql_list_tables" mysql_list_tables :: MYSQL -> CString -> IO MYSQL_RES
foreign import ccall "HsMySQL.h mysql_list_fields" mysql_list_fields :: MYSQL -> CString -> CString -> IO MYSQL_RES
-- Advance to the next result of a multi-result query; getFirstResult loops
-- while this returns 0.
foreign import ccall "HsMySQL.h mysql_next_result" mysql_next_result :: MYSQL -> IO CInt
-- | Raise the error currently recorded on a connection handle as a dynamic
-- 'SqlError' exception (empty SQL state, native error code, message text).
-- Never returns normally, which is why the result type is fully polymorphic.
handleSqlError :: MYSQL -> IO a
handleSqlError pMYSQL = do
    nativeCode <- fmap fromIntegral (mysql_errno pMYSQL)
    message    <- peekCString =<< mysql_error pMYSQL
    throwDyn (SqlError "" nativeCode message)
-- | Open a connection to a MySQL server and wrap it as an HSQL 'Connection'.
--
-- Failures are raised as dynamic 'SqlError' exceptions (via 'throwDyn').
-- If the server rejects the connection, the native handle obtained from
-- @mysql_init@ is closed before the exception propagates, so a failed
-- 'connect' does not leak it.
connect :: String      -- ^ server host name
        -> String      -- ^ database name
        -> String      -- ^ user name
        -> String      -- ^ password
        -> IO Connection
connect server database user authentication = do
    pMYSQL <- mysql_init nullPtr
    -- mysql_init yields NULL when it cannot allocate a handle; passing that
    -- on to mysql_real_connect would dereference a null pointer.
    when (pMYSQL == nullPtr) $
      throwDyn (SqlError "" 0 "mysql_init: unable to allocate a connection handle")
    -- withCString frees every marshalled argument even if a later
    -- allocation or the connect call itself throws.
    res <-
      withCString server $ \pServer ->
      withCString database $ \pDatabase ->
      withCString user $ \pUser ->
      withCString authentication $ \pAuthentication ->
        mysql_real_connect pMYSQL pServer pUser pAuthentication pDatabase
                           0 nullPtr clientFlags
    -- Read the error out of the handle first, then release the handle.
    when (res == nullPtr) $
      handleSqlError pMYSQL `finally` mysql_close pMYSQL
    refFalse <- newMVar False
    let connection = Connection
          { connDisconnect          = mysql_close pMYSQL
          , connExecute             = execute pMYSQL
          , connQuery               = query connection pMYSQL
          , connTables              = tables connection pMYSQL
          , connDescribe            = describe connection pMYSQL
          , connBeginTransaction    = execute pMYSQL "begin"
          , connCommitTransaction   = execute pMYSQL "commit"
          , connRollbackTransaction = execute pMYSQL "rollback"
          , connClosed              = refFalse
          }
    return connection
  where
    -- Client flag word passed to mysql_real_connect. 65536 is 1 << 16,
    -- presumably CLIENT_MULTI_STATEMENTS (consistent with the
    -- mysql_next_result loop in getFirstResult) -- confirm in mysql_com.h.
    clientFlags :: CInt
    clientFlags = 65536

    -- Run a statement whose result is not needed; non-zero status throws.
    execute :: MYSQL -> String -> IO ()
    execute pMYSQL sql = do
      rc <- withCString sql (mysql_query pMYSQL)
      when (rc /= 0) (handleSqlError pMYSQL)

    -- Wrap a (possibly NULL) result handle as an HSQL 'Statement'.
    -- A NULL handle is only an error if the connection reports one;
    -- otherwise it is an empty statement with no fields and no rows.
    withStatement :: Connection -> MYSQL -> MYSQL_RES -> IO Statement
    withStatement conn pMYSQL pRes = do
      currRow  <- newMVar (nullPtr, nullPtr)
      refFalse <- newMVar False
      let statement close fields = Statement
            { stmtConn   = conn
            , stmtClose  = close
            , stmtFetch  = fetch pMYSQL pRes currRow
            , stmtGetCol = getColValue currRow
            , stmtFields = fields
            , stmtClosed = refFalse
            }
      if pRes == nullPtr
        then do
          errno <- mysql_errno pMYSQL
          when (errno /= 0) (handleSqlError pMYSQL)
          return (statement (return ()) [])
        else do
          fieldDefs <- getFieldDefs pRes
          return (statement (mysql_free_result pRes) fieldDefs)

    -- Read every column description of a result set, in column order.
    -- Each MYSQL_FIELD member is peeked with its C type so that no more
    -- bytes are read than the member occupies.
    -- NOTE(review): the byte offsets are hard-coded hsc2hs output and look
    -- like a 32-bit layout (name at 0, length at 28 after seven pointers);
    -- they must be regenerated for any other ABI.
    getFieldDefs :: MYSQL_RES -> IO [FieldDef]
    getFieldDefs res = do
      pField <- mysql_fetch_field res
      if pField == nullPtr
        then return []
        else do
          name          <- peekByteOff pField 0 >>= peekCString
          dataType      <- peekByteOff pField 76 :: IO CInt
          columnSize    <- peekByteOff pField 28 :: IO CULong
          flags         <- peekByteOff pField 64 :: IO CUInt
          decimalDigits <- peekByteOff pField 68 :: IO CUInt
          let sqlType  = mkSqlType (fromIntegral dataType)
                                   (fromIntegral columnSize)
                                   (fromIntegral decimalDigits)
              -- bit 0 of the flags word is tested as the NOT NULL flag
              nullable = flags .&. 1 == 0
          defs <- getFieldDefs res
          return ((name, sqlType, nullable) : defs)

    -- Map a MySQL column type code (plus size / precision) to an SqlType.
    -- Unrecognised codes are preserved in SqlUnknown.
    mkSqlType :: Int -> Int -> Int -> SqlType
    mkSqlType 254 size _    = SqlChar size
    mkSqlType 253 size _    = SqlVarChar size
    mkSqlType 0   size prec = SqlNumeric size prec
    mkSqlType 2   _    _    = SqlSmallInt
    mkSqlType 9   _    _    = SqlMedInt
    mkSqlType 3   _    _    = SqlInteger
    mkSqlType 4   _    _    = SqlReal
    mkSqlType 5   _    _    = SqlDouble
    mkSqlType 1   _    _    = SqlTinyInt
    mkSqlType 8   _    _    = SqlBigInt
    mkSqlType 10  _    _    = SqlDate
    mkSqlType 11  _    _    = SqlTime
    mkSqlType 7   _    _    = SqlTimeStamp
    mkSqlType 12  _    _    = SqlDateTime
    mkSqlType 13  _    _    = SqlYear
    mkSqlType 252 _    _    = SqlBLOB
    mkSqlType 248 _    _    = SqlSET
    mkSqlType 247 _    _    = SqlENUM
    mkSqlType tp  _    _    = SqlUnknown tp

    -- Send a query and wrap its first result set (if any) as a Statement.
    query :: Connection -> MYSQL -> String -> IO Statement
    query conn pMYSQL sql = do
      rc <- withCString sql (mysql_query pMYSQL)
      when (rc /= 0) (handleSqlError pMYSQL)
      pRes <- getFirstResult pMYSQL
      withStatement conn pMYSQL pRes

    -- Skip results without a result set until one is found; NULL means
    -- none remained (or mysql_next_result failed, which withStatement
    -- then reports via mysql_errno).
    getFirstResult :: MYSQL -> IO MYSQL_RES
    getFirstResult pMYSQL = do
      pRes <- mysql_use_result pMYSQL
      if pRes /= nullPtr
        then return pRes
        else do
          -- mysql_next_result resets the handle's error state, so a failed
          -- mysql_use_result must be reported before advancing; otherwise
          -- the error is silently lost.
          errno <- mysql_errno pMYSQL
          when (errno /= 0) (handleSqlError pMYSQL)
          rc <- mysql_next_result pMYSQL
          if rc == 0
            then getFirstResult pMYSQL
            else return nullPtr

    -- Advance to the next row, storing its value and length arrays in
    -- currRow. Returns False at end of data (or for a NULL result handle).
    fetch :: MYSQL -> MYSQL_RES -> MVar (MYSQL_ROW, MYSQL_LENGTHS) -> IO Bool
    fetch pMYSQL pRes currRow
      | pRes == nullPtr = return False
      | otherwise = modifyMVar currRow $ \_ -> do
          pRow     <- mysql_fetch_row pRes
          pLengths <- mysql_fetch_lengths pRes
          -- A NULL row is either end of data or an error (e.g. the
          -- connection dropped mid-stream); only mysql_errno tells them
          -- apart. If this throws, modifyMVar keeps the previous row.
          when (pRow == nullPtr) $ do
            errno <- mysql_errno pMYSQL
            when (errno /= 0) (handleSqlError pMYSQL)
          return ((pRow, pLengths), pRow /= nullPtr)

    -- Hand column colNumber of the current row (value pointer and length)
    -- to the decoder f. Assumes a preceding fetch returned True and that
    -- colNumber is in range -- neither is checked here.
    getColValue :: MVar (MYSQL_ROW, MYSQL_LENGTHS) -> Int -> FieldDef
                -> (FieldDef -> CString -> Int -> IO a) -> IO a
    getColValue currRow colNumber fieldDef f = do
      (row, lengths) <- readMVar currRow
      pValue <- peekElemOff row colNumber
      len    <- fmap fromIntegral (peekElemOff lengths colNumber)
      f fieldDef pValue len

    -- Names of all tables in the current database; a NULL name becomes "".
    tables :: Connection -> MYSQL -> IO [String]
    tables conn pMYSQL = do
      pRes <- mysql_list_tables pMYSQL nullPtr
      stmt <- withStatement conn pMYSQL pRes
      let tableName s =
            fmap (maybe "" id)
                 (stmtGetCol s 0 ("Tables", SqlVarChar 0, False) fromSqlCStringLen)
      collectRows tableName stmt

    -- Column definitions of the given table. The field list is read
    -- eagerly by withStatement (getFieldDefs runs in IO), so the native
    -- result can be released before the definitions are returned.
    describe :: Connection -> MYSQL -> String -> IO [FieldDef]
    describe conn pMYSQL table = do
      pRes <- withCString table (\pTable -> mysql_list_fields pMYSQL pTable nullPtr)
      stmt <- withStatement conn pMYSQL pRes
      return (getFieldsTypes stmt) `finally` stmtClose stmt