{-# LANGUAGE ScopedTypeVariables #-}
-- | Statement handling for the HSQL ODBC backend: creating statements,
-- reading column descriptions, fetching rows and column values, and
-- translating ODBC return codes into 'SqlError' exceptions.
module DB.HSQL.ODBC.Core(handleSqlResult,withStatement) where

import Control.Exception(onException,throw,throwIO)
import Database.HSQL
import Database.HSQL.Types
import Control.Concurrent.MVar(newMVar)
import Foreign(plusPtr,peekByteOff,toBool,Ptr,nullPtr,castPtr,copyBytes
              ,allocaBytes,alloca,mallocBytes,free,peek)
import Foreign.C(CString,peekCString)
#ifdef DEBUG
import Debug.Trace(traceIO)
#endif

import DB.HSQL.ODBC.Type
import DB.HSQL.ODBC.Functions
import DB.HSQL.ODBC.Status

#include <HsODBC.h>

-- | Allocate a statement handle on @hDBC@, run @f@ on it, check the
-- 'SQLRETURN' that @f@ produces, and wrap the handle in an HSQL 'Statement'.
--
-- Column descriptions are read eagerly (see 'moveToFirstResult') and a
-- 'stmtBufferSize'-byte value buffer is allocated; both are released by
-- the statement's 'stmtClose'. If anything fails after the handle was
-- allocated, the handle is dropped before the exception propagates.
withStatement :: Connection -> HDBC -> (HSTMT -> IO SQLRETURN) -> IO Statement
withStatement connection hDBC f =
  allocaBytes sizeOfField $ \pFIELD -> do
    res <- sqlAllocStmt hDBC ((#ptr FIELD, hSTMT) pFIELD)
    handleSqlResult sqlHandleDbc hDBC res
    hSTMT <- (#peek FIELD, hSTMT) pFIELD
    -- No 'Statement' owns the handle yet, so free it ourselves if any of
    -- the following steps throws; otherwise it would leak.
    (fields, buffer) <- flip onException (sqlFreeStmt hSTMT sqlDrop) $ do
#if defined(MSSQL_ODBC)
      sqlSetStmtAttr hSTMT sqlAttrRowArraySize 2 sqlIsInteger
      sqlSetStmtAttr hSTMT sqlAttrCursorType sqlCursorStatic sqlIsInteger
#endif
      f hSTMT >>= handleSqlResult sqlHandleStmt hSTMT
      fields <- moveToFirstResult hSTMT pFIELD
      buffer <- mallocBytes stmtBufferSize
      return (fields, buffer)
    refFalse <- newMVar False
    return Statement
      { stmtConn   = connection
      , stmtClose  = odbcCloseStatement hSTMT buffer
      , stmtFetch  = odbcFetch hSTMT
      , stmtGetCol = getColValue hSTMT buffer
      , stmtFields = fields
      , stmtClosed = refFalse
      }

-- | Describe result columns @n@ through @count@ (1-based, inclusive) of the
-- current result set. The C @FIELD@ record at @pFIELD@ is used as scratch
-- space for the out-parameters of 'sqlDescribeCol'.
getFieldDefs :: HSTMT -> Ptr a -> SQLUSMALLINT -> SQLUSMALLINT -> IO [ColDef]
getFieldDefs hSTMT pFIELD n count
  | n > count = return []
  | otherwise = do
      res <- sqlDescribeCol hSTMT n
               ((#ptr FIELD, fieldName) pFIELD)
               fieldNameLength
               ((#ptr FIELD, NameLength) pFIELD)
               ((#ptr FIELD, DataType) pFIELD)
               ((#ptr FIELD, ColumnSize) pFIELD)
               ((#ptr FIELD, DecimalDigits) pFIELD)
               ((#ptr FIELD, Nullable) pFIELD)
      handleSqlResult sqlHandleStmt hSTMT res
      name <- peekCString ((#ptr FIELD, fieldName) pFIELD)
      dataType <- (#peek FIELD, DataType) pFIELD
      columnSize <- (#peek FIELD, ColumnSize) pFIELD
      decimalDigits <- (#peek FIELD, DecimalDigits) pFIELD
      (nullable :: SQLSMALLINT) <- (#peek FIELD, Nullable) pFIELD
      let sqlType = mkSqlType dataType columnSize decimalDigits
      fields <- getFieldDefs hSTMT pFIELD (n+1) count
      return ((name,sqlType,toBool nullable):fields)

-- | Return the column descriptions of the current result set. While the
-- current result has no columns, advance with 'sqlMoreResults'; return
-- @[]@ once the driver reports that no further results exist.
moveToFirstResult :: HSTMT -> Ptr a -> IO [ColDef]
moveToFirstResult hSTMT pFIELD = do
  res <- sqlNumResultCols hSTMT ((#ptr FIELD, fieldsCount) pFIELD)
  handleSqlResult sqlHandleStmt hSTMT res
  count <- (#peek FIELD, fieldsCount) pFIELD
  if count == 0
    then do
#if defined(MSSQL_ODBC)
      sqlSetStmtAttr hSTMT sqlAttrRowArraySize 2 sqlIsInteger
      sqlSetStmtAttr hSTMT sqlAttrCursorType sqlCursorStatic sqlIsInteger
#endif
      moreRes <- sqlMoreResults hSTMT
      handleSqlResult sqlHandleStmt hSTMT moreRes
      if moreRes == sqlNoData
        then return []
        else moveToFirstResult hSTMT pFIELD
    else getFieldDefs hSTMT pFIELD 1 count

-- | Advance to the next row. Returns 'False' once 'sqlFetch' reports
-- 'sqlNoData', 'True' otherwise.
odbcFetch :: HSTMT -> IO Bool
odbcFetch hSTMT = do
  res <- sqlFetch hSTMT
  handleSqlResult sqlHandleStmt hSTMT res
  return (res /= sqlNoData)

-- | Free the statement's value buffer, then drop the statement handle.
odbcCloseStatement :: HSTMT -> CString -> IO ()
odbcCloseStatement hSTMT buffer = do
  free buffer
  sqlFreeStmt hSTMT sqlDrop >>= handleSqlResult sqlHandleStmt hSTMT

------------------------------------------------------------------------------
-- routines for handling exceptions
------------------------------------------------------------------------------

-- | Check an ODBC return code for the given handle.
--
-- Success and 'sqlNoData' are accepted silently. 'sqlSuccessWithInfo' is
-- accepted too; in @DEBUG@ builds its diagnostic is traced. The other
-- known failure codes are thrown as the corresponding 'SqlError'; an
-- unrecognised code is a program error.
handleSqlResult :: SQLSMALLINT -> SQLHANDLE -> SQLRETURN -> IO ()
handleSqlResult handleType handle res
  | res == sqlSuccess || res == sqlNoData = return ()
  | res == sqlSuccessWithInfo = do
#ifdef DEBUG
      getSqlError >>= (traceIO . show)
#else
      return ()
#endif
  | res == sqlInvalidHandle  = throwIO SqlInvalidHandle
  | res == sqlStillExecuting = throwIO SqlStillExecuting
  | res == sqlNeedData       = throwIO SqlNeedMoreData
  | res == sqlError          = getSqlError >>= throwIO
  | otherwise                = error (show res)
  where
    -- Read the first diagnostic record of the handle into an 'SqlError'
    -- ('SqlNoMoreData' when the driver has no record).
    getSqlError =
      allocaBytes 256 $ \pState ->
      alloca $ \pNative ->
      allocaBytes 256 $ \pMsg ->
      alloca $ \pTextLen -> do
        diagRes <- sqlGetDiagRec handleType handle 1 pState pNative pMsg 256 pTextLen
        if diagRes == sqlNoData
          then return SqlNoMoreData
          else do
            state <- peekCString pState
            native <- peek pNative
            msg <- peekCString pMsg
            return SqlError
              { seState       = state
              , seNativeError = fromIntegral native
              , seErrorMsg    = msg
              }

-- | Size in bytes of the per-statement value buffer used by 'getColValue'
-- (this includes the terminating NUL written after character data).
stmtBufferSize :: Int
stmtBufferSize = 256

-- | Read column @colNumber@ (0-based here; it is passed to 'sqlGetData' as
-- @colNumber + 1@) of the current row as character data and hand it to
-- @f@ together with its length in bytes.
--
-- * A SQL NULL is passed as 'nullPtr' with length 0.
-- * A value that fits is passed in @buffer@.
-- * A value that was truncated is completed by a second 'sqlGetData' call
--   into a temporary buffer that is only valid while @f@ runs.
getColValue :: HSTMT -> CString -> Int -> ColDef -> (ColDef -> CString -> Int -> IO a) -> IO a
getColValue hSTMT buffer colNumber fieldDef f = do
  (res, lenOrInd) <- getData buffer bufSize
  dispatch res lenOrInd
  where
    bufSize :: SQLINTEGER
    bufSize = fromIntegral stmtBufferSize

    dispatch res lenOrInd
      | lenOrInd == sqlNullData = f fieldDef nullPtr 0
      -- Any other negative indicator (e.g. SQL_NO_TOTAL) gives no usable
      -- length; previously it reached allocaBytes with a negative size.
      | lenOrInd < 0 =
          error ("getColValue: unsupported length/indicator value " ++ show lenOrInd)
      -- SQL_SUCCESS_WITH_INFO can carry warnings other than truncation;
      -- only take the long path when the value really did not fit
      -- (a value of length L needs L+1 bytes including the NUL).
      | res == sqlSuccessWithInfo && lenOrInd >= bufSize = getLongData lenOrInd
      | otherwise = f fieldDef buffer (fromIntegral lenOrInd)

    getData :: CString -> SQLINTEGER -> IO (SQLRETURN, SQLINTEGER)
    getData buf size =
      alloca $ \lenP -> do
        res <- sqlGetData hSTMT (fromIntegral colNumber + 1) sqlCChar (castPtr buf) size lenP
        handleSqlResult sqlHandleStmt hSTMT res
        lenOrInd <- peek lenP
        return (res, lenOrInd)

    -- Called only when the value did not fit into 'buffer' (totalLen >=
    -- bufSize). A new buffer large enough for the whole value is allocated.
    -- The old data is copied into it, and the next 'sqlGetData' call writes
    -- the rest right after it.
    getLongData totalLen =
      allocaBytes (fromIntegral newBufSize) $ \newBuf -> do
        copyBytes newBuf buffer stmtBufferSize
        -- The last byte of the old data is always the NUL terminator, so
        -- it is overwritten with the first byte of the new data.
        let newDataStart = newBuf `plusPtr` (stmtBufferSize - 1)
            newDataLen   = newBufSize - (bufSize - 1)
        _ <- getData newDataStart newDataLen
        f fieldDef newBuf (fromIntegral totalLen)
      where
        newBufSize = totalLen + 1  -- to allow for terminating null character