{-# LINE 1 "Database/HSQL/MySQL.hsc" #-}
-----------------------------------------------------------------------------------------
{-# LINE 2 "Database/HSQL/MySQL.hsc" #-}
{-| Module      :  Database.HSQL.MySQL
    Copyright   :  (c) Krasimir Angelov 2003
    License     :  BSD-style

    Maintainer  :  ka2_mail@yahoo.com
    Stability   :  provisional
    Portability :  portable

    This module provides an interface to the MySQL database server.
-}
-----------------------------------------------------------------------------------------

module Database.HSQL.MySQL(connect, module Database.HSQL) where

import Database.HSQL
import Database.HSQL.Types
import Data.Dynamic
import Data.Bits
import Data.Char
import Foreign
import Foreign.C
import Control.Monad(when,unless)
import Control.Exception (throwDyn, finally)
import Control.Concurrent.MVar
import System.Time
import System.IO.Unsafe
import Text.ParserCombinators.ReadP
import Text.Read


{-# LINE 32 "Database/HSQL/MySQL.hsc" #-}

-- Handles of the MySQL C client library.  The connection, result and field
-- handles are untyped pointers; the row and length handles are typed arrays
-- so that individual column entries can be peeked directly.
type MYSQL         = Ptr ()           -- connection handle (MYSQL *)
type MYSQL_RES     = Ptr ()           -- result set handle (MYSQL_RES *)
type MYSQL_FIELD   = Ptr ()           -- column descriptor (MYSQL_FIELD *)
type MYSQL_ROW     = Ptr (Ptr CChar)  -- current row: one C string per column
type MYSQL_LENGTHS = Ptr CULong       -- byte length of each column in the row


{-# LINE 42 "Database/HSQL/MySQL.hsc" #-}

{-# LINE 43 "Database/HSQL/MySQL.hsc" #-}

{-# LINE 44 "Database/HSQL/MySQL.hsc" #-}

-- Foreign bindings to the MySQL C client library, resolved through the
-- HsMySQL.h header.  Argument and result types follow the C prototypes;
-- in particular client_flag is an unsigned long and port / errno are
-- unsigned ints, so they are bound with the matching C widths.

-- Connection management.
foreign import ccall "HsMySQL.h mysql_init"
  mysql_init :: MYSQL -> IO MYSQL
foreign import ccall "HsMySQL.h mysql_real_connect"
  mysql_real_connect
    :: MYSQL    -- handle returned by mysql_init
    -> CString  -- host
    -> CString  -- user
    -> CString  -- password
    -> CString  -- database
    -> CUInt    -- port (unsigned int)
    -> CString  -- unix socket
    -> CULong   -- client_flag (unsigned long)
    -> IO MYSQL
foreign import ccall "HsMySQL.h mysql_close"
  mysql_close :: MYSQL -> IO ()

-- Error reporting.
foreign import ccall "HsMySQL.h mysql_errno"
  mysql_errno :: MYSQL -> IO CUInt
foreign import ccall "HsMySQL.h mysql_error"
  mysql_error :: MYSQL -> IO CString

-- Statement execution and result sets.
foreign import ccall "HsMySQL.h mysql_query"
  mysql_query :: MYSQL -> CString -> IO CInt
foreign import ccall "HsMySQL.h mysql_use_result"
  mysql_use_result :: MYSQL -> IO MYSQL_RES
foreign import ccall "HsMySQL.h mysql_fetch_field"
  mysql_fetch_field :: MYSQL_RES -> IO MYSQL_FIELD
foreign import ccall "HsMySQL.h mysql_free_result"
  mysql_free_result :: MYSQL_RES -> IO ()
foreign import ccall "HsMySQL.h mysql_fetch_row"
  mysql_fetch_row :: MYSQL_RES -> IO MYSQL_ROW
foreign import ccall "HsMySQL.h mysql_fetch_lengths"
  mysql_fetch_lengths :: MYSQL_RES -> IO MYSQL_LENGTHS
foreign import ccall "HsMySQL.h mysql_next_result"
  mysql_next_result :: MYSQL -> IO CInt

-- Catalog queries.
foreign import ccall "HsMySQL.h mysql_list_tables"
  mysql_list_tables :: MYSQL -> CString -> IO MYSQL_RES
foreign import ccall "HsMySQL.h mysql_list_fields"
  mysql_list_fields :: MYSQL -> CString -> CString -> IO MYSQL_RES



-----------------------------------------------------------------------------------------
-- routines for handling exceptions
-----------------------------------------------------------------------------------------

-- | Throws the most recent error recorded on a connection handle as a
-- dynamic 'SqlError' built from an empty state string, the MySQL error
-- number and the MySQL error message.  It never returns normally, which is
-- why the result type is polymorphic.
handleSqlError :: MYSQL -> IO a
handleSqlError pMYSQL = do
  code <- mysql_errno pMYSQL
  msg  <- peekCString =<< mysql_error pMYSQL
  throwDyn (SqlError "" (fromIntegral code) msg)

-----------------------------------------------------------------------------------------
-- Connect/Disconnect
-----------------------------------------------------------------------------------------

-- | Makes a new connection to the database server.
--
-- The session is opened with the @CLIENT_MULTI_STATEMENTS@ client flag, so
-- a string passed to 'query' may contain several statements; the resulting
-- 'Statement' reads the first result set they produce.  If the connection
-- cannot be established, the client handle is released and an 'SqlError'
-- carrying the MySQL error number and message is thrown.
connect :: String   -- ^ Server name
        -> String   -- ^ Database name
        -> String   -- ^ User identifier
        -> String   -- ^ Authentication string (password)
        -> IO Connection
connect server database user authentication = do
    pMYSQL <- mysql_init nullPtr
    -- mysql_init returns NULL only when it cannot allocate the handle, and
    -- mysql_real_connect must not be called with a NULL handle.
    when (pMYSQL == nullPtr) $
      throwDyn (SqlError "" 2008 "MySQL client ran out of memory")
    -- withCString releases every marshalled argument, even on exceptions.
    res <- withCString server         $ \pServer ->
           withCString user           $ \pUser ->
           withCString authentication $ \pAuthentication ->
           withCString database       $ \pDatabase ->
             mysql_real_connect pMYSQL pServer pUser pAuthentication pDatabase
                                0 nullPtr 65536  -- CLIENT_MULTI_STATEMENTS
    -- The error text lives in the handle, so read it (handleSqlError) first
    -- and only then free the handle that mysql_init allocated.
    when (res == nullPtr) $
      handleSqlError pMYSQL `finally` mysql_close pMYSQL
    closedRef <- newMVar False
    let connection = Connection
          { connDisconnect          = mysql_close pMYSQL
          , connExecute             = execute pMYSQL
          , connQuery               = query connection pMYSQL
          , connTables              = tables connection pMYSQL
          , connDescribe            = describe connection pMYSQL
          , connBeginTransaction    = execute pMYSQL "begin"
          , connCommitTransaction   = execute pMYSQL "commit"
          , connRollbackTransaction = execute pMYSQL "rollback"
          , connClosed              = closedRef
          }
    return connection
  where
    -- Sends a statement with mysql_query and raises an SqlError on failure.
    -- Any result set the statement produces is not read here.
    execute :: MYSQL -> String -> IO ()
    execute pMYSQL sql = do
      rc <- withCString sql (mysql_query pMYSQL)
      when (rc /= 0) (handleSqlError pMYSQL)

    -- Sends a (possibly multi-statement) query and wraps its first result set.
    query :: Connection -> MYSQL -> String -> IO Statement
    query conn pMYSQL sql = do
      rc <- withCString sql (mysql_query pMYSQL)
      when (rc /= 0) (handleSqlError pMYSQL)
      pRes <- getFirstResult pMYSQL
      withStatement conn pMYSQL pRes

    -- Skips statements that produce no result set until one does, returning
    -- NULL when no further results remain.  mysql_next_result clears the
    -- pending error, so a failure of mysql_use_result is raised before it is
    -- called.  A positive mysql_next_result code (error) yields NULL and is
    -- reported by withStatement through mysql_errno.
    getFirstResult :: MYSQL -> IO MYSQL_RES
    getFirstResult pMYSQL = do
      pRes <- mysql_use_result pMYSQL
      if pRes /= nullPtr
        then return pRes
        else do
          errno <- mysql_errno pMYSQL
          when (errno /= 0) (handleSqlError pMYSQL)
          rc <- mysql_next_result pMYSQL
          if rc == 0 then getFirstResult pMYSQL else return nullPtr

    -- Wraps a result handle (possibly NULL) in an HSQL 'Statement'.  A NULL
    -- handle with a zero errno denotes a statement without a result set;
    -- such a statement has no fields, fetches nothing and frees nothing.
    withStatement :: Connection -> MYSQL -> MYSQL_RES -> IO Statement
    withStatement conn pMYSQL pRes = do
      currRow   <- newMVar (nullPtr, nullPtr)
      closedRef <- newMVar False
      let mkStatement close fields = Statement
            { stmtConn   = conn
            , stmtClose  = close
            , stmtFetch  = fetch pMYSQL pRes currRow
            , stmtGetCol = getColValue currRow
            , stmtFields = fields
            , stmtClosed = closedRef
            }
      if pRes == nullPtr
        then do
          errno <- mysql_errno pMYSQL
          when (errno /= 0) (handleSqlError pMYSQL)
          return (mkStatement (return ()) [])
        else do
          fieldDefs <- getFieldDefs pRes
          return (mkStatement (mysql_free_result pRes) fieldDefs)

    -- Reads the column descriptions of a result set, in column order.
    --
    -- NOTE(review): the byte offsets into MYSQL_FIELD were emitted by hsc2hs
    -- for the build platform (they appear to match a 32-bit layout); they
    -- must be regenerated from the .hsc file rather than edited by hand.
    -- Each member is read at its C width instead of as a Haskell Int, whose
    -- size is platform dependent.
    getFieldDefs :: MYSQL_RES -> IO [FieldDef]
    getFieldDefs pRes = do
      pField <- mysql_fetch_field pRes
      if pField == nullPtr
        then return []
        else do
          name          <- peekByteOff pField 0 >>= peekCString
          dataType      <- peekByteOff pField 76 :: IO CInt    -- enum_field_types type
          columnSize    <- peekByteOff pField 28 :: IO CULong  -- unsigned long length
          flags         <- peekByteOff pField 64 :: IO CUInt   -- unsigned int flags
          decimalDigits <- peekByteOff pField 68 :: IO CUInt   -- unsigned int decimals
          let sqlType  = mkSqlType (fromIntegral dataType)
                                   (fromIntegral columnSize)
                                   (fromIntegral decimalDigits)
              -- NOT_NULL_FLAG is bit 0 of the column flags.
              nullable = (flags .&. 1) == 0
          rest <- getFieldDefs pRes
          return ((name, sqlType, nullable) : rest)

    -- Maps a MySQL column type code (enum_field_types) to an HSQL SqlType.
    mkSqlType :: Int -> Int -> Int -> SqlType
    mkSqlType 254 size _    = SqlChar size          -- MYSQL_TYPE_STRING
    mkSqlType 253 size _    = SqlVarChar size       -- MYSQL_TYPE_VAR_STRING
    mkSqlType 0   size prec = SqlNumeric size prec  -- MYSQL_TYPE_DECIMAL
    mkSqlType 246 size prec = SqlNumeric size prec  -- MYSQL_TYPE_NEWDECIMAL (MySQL >= 5.0.3)
    mkSqlType 2   _    _    = SqlSmallInt           -- MYSQL_TYPE_SHORT
    mkSqlType 9   _    _    = SqlMedInt             -- MYSQL_TYPE_INT24
    mkSqlType 3   _    _    = SqlInteger            -- MYSQL_TYPE_LONG
    mkSqlType 4   _    _    = SqlReal               -- MYSQL_TYPE_FLOAT
    mkSqlType 5   _    _    = SqlDouble             -- MYSQL_TYPE_DOUBLE
    mkSqlType 1   _    _    = SqlTinyInt            -- MYSQL_TYPE_TINY
    mkSqlType 8   _    _    = SqlBigInt             -- MYSQL_TYPE_LONGLONG
    mkSqlType 10  _    _    = SqlDate               -- MYSQL_TYPE_DATE
    mkSqlType 11  _    _    = SqlTime               -- MYSQL_TYPE_TIME
    mkSqlType 7   _    _    = SqlTimeStamp          -- MYSQL_TYPE_TIMESTAMP
    mkSqlType 12  _    _    = SqlDateTime           -- MYSQL_TYPE_DATETIME
    mkSqlType 13  _    _    = SqlYear               -- MYSQL_TYPE_YEAR
    mkSqlType 252 _    _    = SqlBLOB               -- MYSQL_TYPE_BLOB
    mkSqlType 248 _    _    = SqlSET                -- MYSQL_TYPE_SET
    mkSqlType 247 _    _    = SqlENUM               -- MYSQL_TYPE_ENUM
    mkSqlType tp  _    _    = SqlUnknown tp

    -- Advances to the next row and stores it (with its column lengths) as
    -- the current row.  Returns False when there is no result set or no more
    -- rows.  mysql_fetch_row also returns NULL on error, so mysql_errno is
    -- checked before a NULL row is reported as end of data.
    fetch :: MYSQL -> MYSQL_RES -> MVar (MYSQL_ROW, MYSQL_LENGTHS) -> IO Bool
    fetch pMYSQL pRes currRow
      | pRes == nullPtr = return False
      | otherwise = modifyMVar currRow $ \_ -> do
          pRow <- mysql_fetch_row pRes
          when (pRow == nullPtr) $ do
            errno <- mysql_errno pMYSQL
            when (errno /= 0) (handleSqlError pMYSQL)
          pLengths <- mysql_fetch_lengths pRes
          return ((pRow, pLengths), pRow /= nullPtr)

    -- Hands the value pointer and byte length of column colNumber in the
    -- current row to the HSQL conversion function f.
    -- NOTE(review): assumes a row is current (the last fetch returned True);
    -- with a NULL row pointer the peeks below dereference NULL.
    getColValue :: MVar (MYSQL_ROW, MYSQL_LENGTHS) -> Int -> FieldDef -> (FieldDef -> CString -> Int -> IO a) -> IO a
    getColValue currRow colNumber fieldDef f = do
      (row, lengths) <- readMVar currRow
      pValue <- peekElemOff row colNumber
      len    <- fromIntegral `fmap` peekElemOff lengths colNumber
      f fieldDef pValue len

    -- Lists the tables of the current database.  mysql_list_tables returns
    -- a single VARCHAR column (Tables_in_<db>), read here as column 0;
    -- SQL NULL names become "".
    tables :: Connection -> MYSQL -> IO [String]
    tables conn pMYSQL = do
      pRes <- mysql_list_tables pMYSQL nullPtr
      stmt <- withStatement conn pMYSQL pRes
      collectRows (\s -> do
          mbName <- stmtGetCol s 0 ("Tables", SqlVarChar 0, False) fromSqlCStringLen
          return (maybe "" id mbName)) stmt

    -- Returns the column descriptions of a table.  Only the metadata of the
    -- mysql_list_fields result is needed; it is read eagerly by
    -- getFieldDefs, so the result set is freed before returning.
    describe :: Connection -> MYSQL -> String -> IO [FieldDef]
    describe conn pMYSQL table = do
      pRes <- withCString table (\pTable -> mysql_list_fields pMYSQL pTable nullPtr)
      stmt <- withStatement conn pMYSQL pRes
      let fields = getFieldsTypes stmt
      stmtClose stmt
      return fields