{-# OPTIONS -fglasgow-exts #-}
------------------------------------------------------------------------------
{-| Module      :  Database.HSQL.ODBC
    Copyright   :  (c) Krasimir Angelov 2003
    License     :  BSD-style

    Maintainer  :  kr.angelov@gmail.com
    Stability   :  provisional
    Portability :  portable

    The module provides interface to ODBC
-}
------------------------------------------------------------------------------

module Database.HSQL.ODBC(connect, driverConnect, module Database.HSQL) where

import Database.HSQL
import Database.HSQL.Types
import Data.Word(Word32, Word16)
import Data.Int(Int32, Int16)
import Data.Maybe
import Foreign
import Foreign.C
import Control.Monad(unless)
import Control.OldException(throwDyn)
import Control.Concurrent.MVar
import System.IO.Unsafe
import System.Time

#ifdef DEBUG
import Debug.Trace
#endif

-- HsODBC.h is the header named by every foreign import below; it also
-- provides the FIELD struct used with #peek / #ptr.
#include <time.h>
#include <HsODBC.h>

type SQLHANDLE = Ptr ()
type HENV      = SQLHANDLE
type HDBC      = SQLHANDLE
type HSTMT     = SQLHANDLE
type HENVRef   = ForeignPtr ()

type SQLSMALLINT  = #type SQLSMALLINT
type SQLUSMALLINT = #type SQLUSMALLINT
type SQLINTEGER   = #type SQLINTEGER
type SQLUINTEGER  = #type SQLUINTEGER
type SQLRETURN    = SQLSMALLINT

-- NOTE(review): SQLLEN/SQLULEN are aliased to the 32-bit SQLINTEGER here.
-- ODBC 3.52 headers may define them as 64-bit on 64-bit platforms; confirm
-- against HsODBC.h's FIELD layout and the SQLGetData/SQLBindCol call sites
-- before changing these, since the peeks of FIELD depend on the size.
type SQLLEN  = SQLINTEGER
type SQLULEN = SQLINTEGER

#ifdef mingw32_HOST_OS
#let CALLCONV = "stdcall"
#else
#let CALLCONV = "ccall"
#endif

foreign import #{CALLCONV} "HsODBC.h SQLAllocEnv"
  sqlAllocEnv :: Ptr HENV -> IO SQLRETURN

-- The environment finalizer is imported as a ccall FunPtr; on Windows it
-- goes through a C wrapper (my_sqlFreeEnv), presumably because the stdcall
-- SQLFreeEnv cannot be used directly as a ForeignPtr finalizer.
#ifdef mingw32_HOST_OS
foreign import ccall "HsODBC.h &my_sqlFreeEnv"
  sqlFreeEnv_p :: FunPtr (HENV -> IO ())
#else
foreign import ccall "HsODBC.h &SQLFreeEnv"
  sqlFreeEnv_p :: FunPtr (HENV -> IO ())
#endif

foreign import #{CALLCONV} "HsODBC.h SQLAllocConnect"
  sqlAllocConnect :: HENV -> Ptr HDBC -> IO SQLRETURN

foreign import #{CALLCONV} "HsODBC.h SQLFreeConnect"
  sqlFreeConnect :: HDBC -> IO SQLRETURN

-- The three length arguments of SQLConnect are SQLSMALLINTs in C.
foreign import #{CALLCONV} "HsODBC.h SQLConnect"
  sqlConnect :: HDBC
             -> CString -> SQLSMALLINT
             -> CString -> SQLSMALLINT
             -> CString -> SQLSMALLINT
             -> IO SQLRETURN

foreign import #{CALLCONV} "HsODBC.h SQLDriverConnect"
  sqlDriverConnect :: HDBC -> Ptr ()
                   -> CString -> SQLSMALLINT
                   -> CString -> SQLSMALLINT -> Ptr SQLSMALLINT
                   -> SQLUSMALLINT
                   -> IO SQLRETURN

foreign import #{CALLCONV} "HsODBC.h SQLDisconnect"
  sqlDisconnect :: HDBC -> IO SQLRETURN

foreign import #{CALLCONV} "HsODBC.h SQLAllocStmt"
  sqlAllocStmt :: HDBC -> Ptr HSTMT -> IO SQLRETURN

foreign import #{CALLCONV} "HsODBC.h SQLFreeStmt"
  sqlFreeStmt :: HSTMT -> SQLUSMALLINT -> IO SQLRETURN

foreign import #{CALLCONV} "HsODBC.h SQLNumResultCols"
  sqlNumResultCols :: HSTMT -> Ptr SQLUSMALLINT -> IO SQLRETURN

foreign import #{CALLCONV} "HsODBC.h SQLDescribeCol"
  sqlDescribeCol :: HSTMT -> SQLUSMALLINT
                 -> CString -> SQLSMALLINT -> Ptr SQLSMALLINT
                 -> Ptr SQLSMALLINT -> Ptr SQLULEN
                 -> Ptr SQLSMALLINT -> Ptr SQLSMALLINT
                 -> IO SQLRETURN

foreign import #{CALLCONV} "HsODBC.h SQLBindCol"
  sqlBindCol :: HSTMT -> SQLUSMALLINT -> SQLSMALLINT
             -> Ptr a -> SQLLEN -> Ptr SQLINTEGER
             -> IO SQLRETURN

foreign import #{CALLCONV} "HsODBC.h SQLFetch"
  sqlFetch :: HSTMT -> IO SQLRETURN

foreign import #{CALLCONV} "HsODBC.h SQLGetDiagRec"
  sqlGetDiagRec :: SQLSMALLINT -> SQLHANDLE -> SQLSMALLINT
                -> CString -> Ptr SQLINTEGER
                -> CString -> SQLSMALLINT -> Ptr SQLSMALLINT
                -> IO SQLRETURN

foreign import #{CALLCONV} "HsODBC.h SQLExecDirect"
  sqlExecDirect :: HSTMT -> CString -> Int -> IO SQLRETURN

foreign import #{CALLCONV} "HsODBC.h SQLSetConnectOption"
  sqlSetConnectOption :: HDBC -> SQLUSMALLINT -> SQLULEN -> IO SQLRETURN

foreign import #{CALLCONV} "HsODBC.h SQLTransact"
  sqlTransact :: HENV -> HDBC -> SQLUSMALLINT -> IO SQLRETURN

foreign import #{CALLCONV} "HsODBC.h SQLGetData"
  sqlGetData :: HSTMT -> SQLUSMALLINT -> SQLSMALLINT
             -> Ptr () -> SQLINTEGER -> Ptr SQLINTEGER
             -> IO SQLRETURN

foreign import #{CALLCONV} "HsODBC.h SQLTables"
  sqlTables :: HSTMT
            -> CString -> SQLSMALLINT
            -> CString -> SQLSMALLINT
            -> CString -> SQLSMALLINT
            -> CString -> SQLSMALLINT
            -> IO SQLRETURN

foreign import #{CALLCONV} "HsODBC.h SQLColumns"
  sqlColumns :: HSTMT
             -> CString -> SQLSMALLINT
             -> CString -> SQLSMALLINT
             -> CString -> SQLSMALLINT
             -> CString -> SQLSMALLINT
             -> IO SQLRETURN

foreign import #{CALLCONV} "HsODBC.h SQLMoreResults"
  sqlMoreResults :: HSTMT -> IO SQLRETURN

#if defined(MSSQL_ODBC)
foreign import #{CALLCONV} "HsODBC.h SQLSetStmtAttr"
  sqlSetStmtAttr :: HSTMT -> SQLINTEGER -> SQLINTEGER -> SQLINTEGER -> IO SQLRETURN
#endif

------------------------------------------------------------------------------
-- routines for handling exceptions
------------------------------------------------------------------------------

-- | Checks the return code of an ODBC call made on @handle@.
--
-- @SQL_SUCCESS@ and @SQL_NO_DATA@ are accepted silently.
-- @SQL_SUCCESS_WITH_INFO@ is accepted too; in a @DEBUG@ build its first
-- diagnostic record is traced.
-- @SQL_INVALID_HANDLE@, @SQL_STILL_EXECUTING@ and @SQL_NEED_DATA@ are
-- thrown (via 'throwDyn') as the matching 'SqlError' constructor.
-- @SQL_ERROR@ is thrown as the first diagnostic record of @handle@, or as
-- 'SqlNoData' when no record can be retrieved.
-- Any other code is reported with 'error'.
handleSqlResult :: SQLSMALLINT  -- ^ handle type of @handle@ (SQL_HANDLE_ENV/DBC/STMT)
                -> SQLHANDLE    -- ^ handle the ODBC call was made on
                -> SQLRETURN    -- ^ return code to check
                -> IO ()
handleSqlResult handleType handle res
  | res == (#const SQL_SUCCESS) || res == (#const SQL_NO_DATA) = return ()
  | res == (#const SQL_SUCCESS_WITH_INFO) = do
#ifdef DEBUG
      e <- getSqlError
      putTraceMsg (show e)
#else
      return ()
#endif
  | res == (#const SQL_INVALID_HANDLE)  = throwDyn SqlInvalidHandle
  | res == (#const SQL_STILL_EXECUTING) = throwDyn SqlStillExecuting
  | res == (#const SQL_NEED_DATA)       = throwDyn SqlNeedData
  | res == (#const SQL_ERROR)           = do
      e <- getSqlError
      throwDyn e
  | otherwise = error (show res)
  where
    -- Reads diagnostic record 1 of @handle@. The buffers are only decoded
    -- when SQLGetDiagRec reports success; on any other result (no record,
    -- invalid handle, error) they were never written, so 'SqlNoData' is
    -- returned instead of decoding uninitialised memory.
    getSqlError :: IO SqlError
    getSqlError =
      allocaBytes 256 $ \pState ->
      alloca $ \pNative ->
      allocaBytes 256 $ \pMsg ->
      alloca $ \pTextLen -> do
        diagRes <- sqlGetDiagRec handleType handle 1 pState pNative pMsg 256 pTextLen
        if diagRes /= (#const SQL_SUCCESS) && diagRes /= (#const SQL_SUCCESS_WITH_INFO)
          then return SqlNoData
          else do
            state  <- peekCString pState
            native <- peek pNative
            msg    <- peekCString pMsg
            return (SqlError { seState       = state
                             , seNativeError = fromIntegral native
                             , seErrorMsg    = msg
                             })

------------------------------------------------------------------------------
-- keeper of HENV
------------------------------------------------------------------------------

-- | Process-wide ODBC environment handle, allocated on first use.
-- NOINLINE keeps 'unsafePerformIO' from duplicating the allocation; the
-- handle is released through 'sqlFreeEnv_p' when the 'ForeignPtr' is
-- finalized.
{-# NOINLINE myEnvironment #-}
myEnvironment :: HENVRef
myEnvironment = unsafePerformIO $ alloca $ \ (phEnv :: Ptr HENV) -> do
  res <- sqlAllocEnv phEnv
  -- Check the result before trusting the output handle. There is no
  -- environment handle to read diagnostics from yet, hence 0/nullPtr.
  handleSqlResult 0 nullPtr res
  hEnv <- peek phEnv
  newForeignPtr sqlFreeEnv_p hEnv

------------------------------------------------------------------------------
-- Connect/Disconnect
------------------------------------------------------------------------------

-- | Makes a new connection to the ODBC data source
connect :: String         -- ^ Data source name
        -> String         -- ^ User identifier
        -> String         -- ^ Authentication string (password)
        -> IO Connection  -- ^ the returned value represents the new connection
connect server user authentication = connectHelper $ \hDBC ->
  withCString server $ \pServer ->
  withCString user $ \pUser ->
  withCString authentication $ \pAuthentication ->
    sqlConnect hDBC pServer (#const SQL_NTS)
                    pUser (#const SQL_NTS)
                    pAuthentication (#const SQL_NTS)

-- | 'driverConnect' is an alternative to 'connect'. It supports data sources that
-- require more connection information than the three arguments in 'connect'
-- and data sources that are not defined in the system information.
driverConnect :: String         -- ^ Connection string
              -> IO Connection  -- ^ the returned value represents the new connection
driverConnect connString = connectHelper $ \hDBC ->
  withCString connString $ \pConnString ->
  allocaBytes 1024 $ \pOutConnString ->
  alloca $ \pLen ->
    -- The completed connection string written to pOutConnString is not used.
    sqlDriverConnect hDBC nullPtr pConnString (#const SQL_NTS)
                     pOutConnString 1024 pLen (#const SQL_DRIVER_NOPROMPT)

-- Shared tail of 'connect' and 'driverConnect'.
-- It allocates a connection handle in 'myEnvironment' and runs the supplied
-- connect call on it. It then wraps the handle in an HSQL 'Connection' whose
-- operations are the helpers in the @where@ clause.
-- (They are presumably kept local so they do not clash with the same-named
-- functions exported by "Database.HSQL".)
--
-- NOTE(review): if the connect call fails, handleSqlResult throws and the
-- allocated connection handle is never freed. Releasing it on that path
-- needs Control.Exception, which this module does not import.
connectHelper :: (HDBC -> IO SQLRETURN) -> IO Connection
connectHelper connectFunction = withForeignPtr myEnvironment $ \hEnv -> do
    hDBC <- alloca $ \ (phDBC :: Ptr HDBC) -> do
      res <- sqlAllocConnect hEnv phDBC
      handleSqlResult (#const SQL_HANDLE_ENV) hEnv res
      peek phDBC
    res <- connectFunction hDBC
    handleSqlResult (#const SQL_HANDLE_DBC) hDBC res
    refFalse <- newMVar False
    -- The record is recursive: query/tables/describe are given the very
    -- connection they belong to.
    let connection = Connection
          { connDisconnect          = disconnect hDBC
          , connExecute             = execute hDBC
          , connQuery               = query connection hDBC
          , connTables              = tables connection hDBC
          , connDescribe            = describe connection hDBC
          , connBeginTransaction    = beginTransaction myEnvironment hDBC
          , connCommitTransaction   = commitTransaction myEnvironment hDBC
          , connRollbackTransaction = rollbackTransaction myEnvironment hDBC
          , connClosed              = refFalse
          }
    return connection
  where
    -- Disconnects, then frees the connection handle.
    disconnect :: HDBC -> IO ()
    disconnect hDBC = do
      sqlDisconnect hDBC >>= handleSqlResult (#const SQL_HANDLE_DBC) hDBC
      sqlFreeConnect hDBC >>= handleSqlResult (#const SQL_HANDLE_DBC) hDBC

    -- Runs a statement that produces no result set, on a fresh statement
    -- handle that is dropped afterwards.
    -- NOTE(review): if SQLExecDirect fails, handleSqlResult throws before
    -- SQLFreeStmt runs, so the statement handle leaks. Fixing that needs
    -- finally/bracket from Control.Exception, which is not imported here.
    execute :: HDBC -> String -> IO ()
    execute hDBC sqlText =
      allocaBytes (#const sizeof(HSTMT)) $ \pStmt -> do
        res <- sqlAllocStmt hDBC pStmt
        handleSqlResult (#const SQL_HANDLE_DBC) hDBC res
        hSTMT <- peek pStmt
        withCStringLen sqlText $ \(pQuery, len) -> do
          execRes <- sqlExecDirect hSTMT pQuery len
          handleSqlResult (#const SQL_HANDLE_STMT) hSTMT execRes
        freeRes <- sqlFreeStmt hSTMT (#const SQL_DROP)
        handleSqlResult (#const SQL_HANDLE_STMT) hSTMT freeRes

    -- Size in bytes of the per-statement SQL_C_CHAR buffer used by
    -- getColValue (including the terminating NUL).
    stmtBufferSize :: Int
    stmtBufferSize = 256

    -- Allocates a statement handle and runs @f@ on it. It then skips ahead
    -- to the first result that has columns and packages the handle as an
    -- HSQL 'Statement'. The FIELD scratch struct (from HsODBC.h) receives
    -- the statement handle and the column descriptions.
    -- NOTE(review): as in 'execute', a failure after SQLAllocStmt leaks
    -- the statement handle (no exception-safe cleanup available here).
    withStatement :: Connection -> HDBC -> (HSTMT -> IO SQLRETURN) -> IO Statement
    withStatement connection hDBC f =
      allocaBytes (#const sizeof(FIELD)) $ \pFIELD -> do
        res <- sqlAllocStmt hDBC ((#ptr FIELD, hSTMT) pFIELD)
        handleSqlResult (#const SQL_HANDLE_DBC) hDBC res
        hSTMT <- (#peek FIELD, hSTMT) pFIELD
        let handleResult = handleSqlResult (#const SQL_HANDLE_STMT) hSTMT
#if defined(MSSQL_ODBC)
        sqlSetStmtAttr hSTMT (#const SQL_ATTR_ROW_ARRAY_SIZE) 2 (#const SQL_IS_INTEGER)
        sqlSetStmtAttr hSTMT (#const SQL_ATTR_CURSOR_TYPE) (#const SQL_CURSOR_STATIC) (#const SQL_IS_INTEGER)
#endif
        f hSTMT >>= handleResult
        fields <- moveToFirstResult hSTMT pFIELD
        buffer <- mallocBytes stmtBufferSize
        refFalse <- newMVar False
        return Statement
          { stmtConn   = connection
          , stmtClose  = closeStatement hSTMT buffer
          , stmtFetch  = fetch hSTMT
          , stmtGetCol = getColValue hSTMT buffer
          , stmtFields = fields
          , stmtClosed = refFalse
          }
      where
        -- Skips results that have no columns (via SQLMoreResults) until one
        -- with columns is found, and describes its columns. Returns [] once
        -- SQLMoreResults reports SQL_NO_DATA.
        moveToFirstResult :: HSTMT -> Ptr a -> IO [FieldDef]
        moveToFirstResult hSTMT pFIELD = do
          res <- sqlNumResultCols hSTMT ((#ptr FIELD, fieldsCount) pFIELD)
          handleSqlResult (#const SQL_HANDLE_STMT) hSTMT res
          count <- (#peek FIELD, fieldsCount) pFIELD
          if count == 0
            then do
#if defined(MSSQL_ODBC)
              sqlSetStmtAttr hSTMT (#const SQL_ATTR_ROW_ARRAY_SIZE) 2 (#const SQL_IS_INTEGER)
              sqlSetStmtAttr hSTMT (#const SQL_ATTR_CURSOR_TYPE) (#const SQL_CURSOR_STATIC) (#const SQL_IS_INTEGER)
#endif
              moreRes <- sqlMoreResults hSTMT
              handleSqlResult (#const SQL_HANDLE_STMT) hSTMT moreRes
              if moreRes == (#const SQL_NO_DATA)
                then return []
                else moveToFirstResult hSTMT pFIELD
            else getFieldDefs hSTMT pFIELD 1 count

        -- Describes columns n..count (1-based) with SQLDescribeCol, using
        -- the FIELD struct as the output area.
        getFieldDefs :: HSTMT -> Ptr a -> SQLUSMALLINT -> SQLUSMALLINT -> IO [FieldDef]
        getFieldDefs hSTMT pFIELD n count
          | n > count = return []
          | otherwise = do
              res <- sqlDescribeCol hSTMT n
                       ((#ptr FIELD, fieldName) pFIELD)
                       (#const FIELD_NAME_LENGTH)
                       ((#ptr FIELD, NameLength) pFIELD)
                       ((#ptr FIELD, DataType) pFIELD)
                       ((#ptr FIELD, ColumnSize) pFIELD)
                       ((#ptr FIELD, DecimalDigits) pFIELD)
                       ((#ptr FIELD, Nullable) pFIELD)
              handleSqlResult (#const SQL_HANDLE_STMT) hSTMT res
              name <- peekCString ((#ptr FIELD, fieldName) pFIELD)
              dataType <- (#peek FIELD, DataType) pFIELD
              columnSize <- (#peek FIELD, ColumnSize) pFIELD
              decimalDigits <- (#peek FIELD, DecimalDigits) pFIELD
              (nullable :: SQLSMALLINT) <- (#peek FIELD, Nullable) pFIELD
              let sqlType = mkSqlType dataType columnSize decimalDigits
              fields <- getFieldDefs hSTMT pFIELD (n + 1) count
              return ((name, sqlType, toBool nullable) : fields)

    -- Maps an ODBC SQL data type code (plus column size and decimal
    -- digits) to HSQL's 'SqlType'. Unrecognised codes become 'SqlUnknown'.
    mkSqlType :: SQLSMALLINT -> SQLULEN -> SQLSMALLINT -> SqlType
    mkSqlType (#const SQL_CHAR)          size _    = SqlChar (fromIntegral size)
    mkSqlType (#const SQL_VARCHAR)       size _    = SqlVarChar (fromIntegral size)
    mkSqlType (#const SQL_LONGVARCHAR)   size _    = SqlLongVarChar (fromIntegral size)
    mkSqlType (#const SQL_DECIMAL)       size prec = SqlDecimal (fromIntegral size) (fromIntegral prec)
    mkSqlType (#const SQL_NUMERIC)       size prec = SqlNumeric (fromIntegral size) (fromIntegral prec)
    mkSqlType (#const SQL_SMALLINT)      _    _    = SqlSmallInt
    mkSqlType (#const SQL_INTEGER)       _    _    = SqlInteger
    mkSqlType (#const SQL_REAL)          _    _    = SqlReal
    -- From: http://msdn.microsoft.com/library/en-us/odbc/htm/odappdpr_2.asp
    -- "Depending on the implementation, the precision of SQL_FLOAT can be either 24 or 53:
    -- if it is 24, the SQL_FLOAT data type is the same as SQL_REAL;
    -- if it is 53, the SQL_FLOAT data type is the same as SQL_DOUBLE."
    mkSqlType (#const SQL_FLOAT)         _    _    = SqlFloat
    mkSqlType (#const SQL_DOUBLE)        _    _    = SqlDouble
    mkSqlType (#const SQL_BIT)           _    _    = SqlBit
    mkSqlType (#const SQL_TINYINT)       _    _    = SqlTinyInt
    mkSqlType (#const SQL_BIGINT)        _    _    = SqlBigInt
    mkSqlType (#const SQL_BINARY)        size _    = SqlBinary (fromIntegral size)
    mkSqlType (#const SQL_VARBINARY)     size _    = SqlVarBinary (fromIntegral size)
    mkSqlType (#const SQL_LONGVARBINARY) size _    = SqlLongVarBinary (fromIntegral size)
    mkSqlType (#const SQL_DATE)          _    _    = SqlDate
    mkSqlType (#const SQL_TIME)          _    _    = SqlTime
    mkSqlType (#const SQL_TIMESTAMP)     _    _    = SqlDateTime
    mkSqlType (#const SQL_WCHAR)         size _    = SqlWChar (fromIntegral size)
    mkSqlType (#const SQL_WVARCHAR)      size _    = SqlWVarChar (fromIntegral size)
    mkSqlType (#const SQL_WLONGVARCHAR)  size _    = SqlWLongVarChar (fromIntegral size)
    mkSqlType tp                         _    _    = SqlUnknown (fromIntegral tp)

    -- Executes a query and returns a 'Statement' positioned on its first
    -- result that has columns.
    query :: Connection -> HDBC -> String -> IO Statement
    query connection hDBC q = withStatement connection hDBC doQuery
      where
        doQuery hSTMT = withCStringLen q (uncurry (sqlExecDirect hSTMT))

    -- Starts a transaction by switching the connection to manual commit.
    -- A failure is now reported instead of being ignored. Otherwise later
    -- statements would silently auto-commit and a rollback could not undo
    -- them.
    beginTransaction :: HENVRef -> HDBC -> IO ()
    beginTransaction _ hDBC =
      sqlSetConnectOption hDBC (#const SQL_AUTOCOMMIT) (#const SQL_AUTOCOMMIT_OFF)
        >>= handleSqlResult (#const SQL_HANDLE_DBC) hDBC

    -- Commits the current transaction, then re-enables auto-commit.
    -- The SQLTransact result is checked before any further call on the
    -- connection, so the diagnostics read by handleSqlResult still belong
    -- to the failed commit. If the commit fails, the error is thrown and
    -- the connection is left in manual-commit mode.
    commitTransaction :: HENVRef -> HDBC -> IO ()
    commitTransaction env hDBC = withForeignPtr env $ \hEnv -> do
      sqlTransact hEnv hDBC (#const SQL_COMMIT)
        >>= handleSqlResult (#const SQL_HANDLE_DBC) hDBC
      sqlSetConnectOption hDBC (#const SQL_AUTOCOMMIT) (#const SQL_AUTOCOMMIT_ON)
        >>= handleSqlResult (#const SQL_HANDLE_DBC) hDBC

    -- Rolls back the current transaction, then re-enables auto-commit.
    -- Return codes are deliberately ignored (best effort).
    -- NOTE(review): presumably this avoids a rollback issued from an error
    -- handler masking the original error; confirm against the callers in
    -- Database.HSQL.
    rollbackTransaction :: HENVRef -> HDBC -> IO ()
    rollbackTransaction env hDBC = withForeignPtr env $ \hEnv -> do
      _ <- sqlTransact hEnv hDBC (#const SQL_ROLLBACK)
      _ <- sqlSetConnectOption hDBC (#const SQL_AUTOCOMMIT) (#const SQL_AUTOCOMMIT_ON)
      return ()

    -- Lists the table names reported by SQLTables (all filters NULL).
    tables :: Connection -> HDBC -> IO [String]
    tables connection hDBC = do
        stmt <- withStatement connection hDBC sqlTables'
        -- SQLTables returns (column names may vary):
        -- Column name     #   Type
        -- TABLE_NAME      3   VARCHAR
        collectRows (\s -> getFieldValue s "TABLE_NAME") stmt
      where
        sqlTables' hSTMT = sqlTables hSTMT nullPtr 0 nullPtr 0 nullPtr 0 nullPtr 0

    -- Describes the columns of @table@ using SQLColumns.
    describe :: Connection -> HDBC -> String -> IO [FieldDef]
    describe connection hDBC table = do
        stmt <- withStatement connection hDBC (sqlColumns' table)
        collectRows getColumnInfo stmt
      where
        sqlColumns' tbl hSTMT =
          withCStringLen tbl $ \(pTable, len) ->
            sqlColumns hSTMT nullPtr 0 nullPtr 0 pTable (fromIntegral len) nullPtr 0

        -- SQLColumns returns (column names may vary):
        -- Column name     #   Type
        -- COLUMN_NAME     4   Varchar not NULL
        -- DATA_TYPE       5   Smallint not NULL
        -- COLUMN_SIZE     7   Integer
        -- DECIMAL_DIGITS  9   Smallint
        -- NULLABLE       11   Smallint not NULL
        getColumnInfo stmt = do
          column_name <- getFieldValue stmt "COLUMN_NAME"
          (data_type :: Int) <- getFieldValue stmt "DATA_TYPE"
          (column_size :: Int) <- getFieldValue' stmt "COLUMN_SIZE" 0
          (decimal_digits :: Int) <- getFieldValue' stmt "DECIMAL_DIGITS" 0
          (nullable :: Int) <- getFieldValue stmt "NULLABLE"
          let sqlType = mkSqlType (fromIntegral data_type)
                                  (fromIntegral column_size)
                                  (fromIntegral decimal_digits)
          return (column_name, sqlType, toBool nullable)

    -- Advances to the next row; False once SQLFetch reports SQL_NO_DATA.
    fetch :: HSTMT -> IO Bool
    fetch hSTMT = do
      res <- sqlFetch hSTMT
      handleSqlResult (#const SQL_HANDLE_STMT) hSTMT res
      return (res /= (#const SQL_NO_DATA))

    -- Reads column @colNumber@ (0-based) of the current row as SQL_C_CHAR
    -- and passes it to @f@ as (pointer, length). SQL NULL is passed as
    -- (nullPtr, 0). Values that do not fit in @buffer@ are fetched in
    -- further SQLGetData calls into a larger buffer.
    getColValue :: HSTMT -> CString -> Int -> FieldDef -> (FieldDef -> CString -> Int -> IO a) -> IO a
    getColValue hSTMT buffer colNumber fieldDef f =
        getData buffer (fromIntegral stmtBufferSize) >>= uncurry dispatch
      where
        dispatch res lenOrInd
          | lenOrInd == (#const SQL_NULL_DATA) = f fieldDef nullPtr 0
          -- SQL_SUCCESS_WITH_INFO is not only returned for truncation, so
          -- only take the long-data paths when the value really was cut off.
          -- The old code went to getLongData on any warning and then copied
          -- stmtBufferSize bytes into a buffer of only (length + 1) bytes.
          | res /= (#const SQL_SUCCESS_WITH_INFO) || not (isTruncated stmtBufferSize lenOrInd)
              = f fieldDef buffer (fromIntegral lenOrInd)
          | lenOrInd == (#const SQL_NO_TOTAL) = getUnknownLengthData
          | otherwise = getLongData lenOrInd

        -- With SQL_C_CHAR one byte of the buffer is reserved for the NUL,
        -- so a value was cut short when its reported length is at least
        -- the buffer size, or when the driver cannot report a length at all.
        isTruncated :: Int -> SQLINTEGER -> Bool
        isTruncated bufSize ind =
          ind == (#const SQL_NO_TOTAL) || ind >= fromIntegral bufSize

        getData :: CString -> SQLINTEGER -> IO (SQLRETURN, SQLINTEGER)
        getData buf size = alloca $ \lenP -> do
          -- Pre-zero the indicator, so a call that writes nothing to it
          -- (e.g. SQL_NO_DATA) reports length 0 rather than uninitialised
          -- memory.
          poke lenP 0
          res <- sqlGetData hSTMT (fromIntegral colNumber + 1) (#const SQL_C_CHAR) (castPtr buf) size lenP
          handleSqlResult (#const SQL_HANDLE_STMT) hSTMT res
          lenOrInd <- peek lenP
          return (res, lenOrInd)

        -- The value was truncated and its total length @len@ is known.
        -- Allocate room for all of it plus the NUL, copy the part already
        -- in @buffer@, and fetch the rest right after it. The last byte of
        -- the old data is the NUL, so it is overwritten by the first new
        -- byte.
        getLongData len =
            allocaBytes newBufSize $ \newBuf -> do
              copyBytes newBuf buffer stmtBufferSize
              let newDataStart = newBuf `plusPtr` (stmtBufferSize - 1)
                  newDataLen   = fromIntegral (newBufSize - (stmtBufferSize - 1))
              _ <- getData newDataStart newDataLen
              f fieldDef newBuf (newBufSize - 1)
          where
            newBufSize = fromIntegral len + 1 :: Int  -- +1 for the terminating NUL

        -- The value was truncated and the driver reported SQL_NO_TOTAL, so
        -- its length is unknown. Keep fetching fixed-size chunks until the
        -- driver stops reporting truncation, then concatenate them. Each
        -- chunk lives in its own allocaBytes frame, so nothing leaks if @f@
        -- throws.
        getUnknownLengthData = collect [(buffer, stmtBufferSize - 1)]

        longChunkSize :: Int
        longChunkSize = 4096

        -- Chunks are accumulated newest-first as (pointer, data length).
        collect revChunks =
          allocaBytes longChunkSize $ \chunk -> do
            (res, ind) <- getData chunk (fromIntegral longChunkSize)
            if res == (#const SQL_SUCCESS_WITH_INFO) && isTruncated longChunkSize ind
              then collect ((chunk, longChunkSize - 1) : revChunks)
              else assemble ((chunk, max 0 (fromIntegral ind)) : revChunks)

        assemble revChunks =
            allocaBytes (total + 1) $ \out -> do
              copyChunks out chunks
              pokeByteOff out total (0 :: CChar)
              f fieldDef out total
          where
            chunks = reverse revChunks
            total  = sum (map snd chunks)
            copyChunks _ [] = return ()
            copyChunks dst ((src, n) : rest) = do
              copyBytes dst src n
              copyChunks (dst `plusPtr` n) rest

    -- Frees the column buffer and drops the statement handle.
    closeStatement :: HSTMT -> CString -> IO ()
    closeStatement hSTMT buffer = do
      free buffer
      sqlFreeStmt hSTMT (#const SQL_DROP) >>= handleSqlResult (#const SQL_HANDLE_STMT) hSTMT