module DB.HSQL.ODBC.Core(handleSqlResult,withStatement)
where
import Control.Concurrent.MVar(newMVar)
import Control.Exception(onException,throw,throwIO)
import Foreign(plusPtr,peekByteOff,toBool,Ptr,nullPtr,castPtr,copyBytes
              ,allocaBytes,alloca,mallocBytes,free,peek)
import Foreign.C(CString,peekCString)

import Database.HSQL
import Database.HSQL.Types

import DB.HSQL.ODBC.Functions
import DB.HSQL.ODBC.Status
import DB.HSQL.ODBC.Type
-- | Allocate an ODBC statement handle on @hDBC@, run @f@ on it, and wrap the
-- result in an HSQL 'Statement' positioned on the first result set that has
-- columns (see 'moveToFirstResult').
--
-- If @f@ reports an error, or describing the result columns fails, the
-- freshly allocated statement handle is released with 'sqlDrop' before the
-- exception propagates. Previously such failures leaked the handle.
-- On success, ownership of the handle and of the column buffer passes to
-- the returned statement's 'stmtClose' ('odbcCloseStatement').
withStatement :: Connection -> HDBC -> (HSTMT -> IO SQLRETURN) -> IO Statement
withStatement connection hDBC f =
  allocaBytes sizeOfField $ \pFIELD -> do
    -- sqlAllocStmt writes the new handle into the scratch area at offset 0.
    res <- sqlAllocStmt hDBC (pFIELD `plusPtr` 0)
    handleSqlResult sqlHandleDbc hDBC res
    hSTMT <- peekByteOff pFIELD 0
    let prepare = do
          f hSTMT >>= handleSqlResult sqlHandleStmt hSTMT
          moveToFirstResult hSTMT pFIELD
    -- Free the handle if anything between allocation and Statement
    -- construction throws; the exception is re-raised by onException.
    fields <- prepare `onException` sqlFreeStmt hSTMT sqlDrop
    buffer <- mallocBytes (fromIntegral stmtBufferSize)
    refFalse <- newMVar False
    return Statement
      { stmtConn   = connection
      , stmtClose  = odbcCloseStatement hSTMT buffer
      , stmtFetch  = odbcFetch hSTMT
      , stmtGetCol = getColValue hSTMT buffer
      , stmtFields = fields
      , stmtClosed = refFalse
      }
-- | Describe result columns @n@ through @count@ (ODBC column numbers,
-- 1-based, inclusive) and return one 'FieldDef' per column:
-- (column name, SQL type built by 'mkSqlType', nullable flag).
--
-- @pFIELD@ is a scratch buffer reused for every column. The byte offsets
-- passed to 'sqlDescribeCol' presumably follow SQLDescribeCol's output
-- arguments: name (6), name length (262), data type (264), column size
-- (268), decimal digits (272), nullable (274). TODO confirm against the
-- struct layout in DB.HSQL.ODBC.Type.
getFieldDefs:: HSTMT -> Ptr a -> SQLUSMALLINT -> SQLUSMALLINT -> IO [FieldDef]
getFieldDefs hSTMT pFIELD n count
  | n <= count = do
      rc <- sqlDescribeCol hSTMT n
              (pFIELD `plusPtr` 6) fieldNameLength
              (pFIELD `plusPtr` 262)
              (pFIELD `plusPtr` 264)
              (pFIELD `plusPtr` 268)
              (pFIELD `plusPtr` 272)
              (pFIELD `plusPtr` 274)
      handleSqlResult sqlHandleStmt hSTMT rc
      colName <- peekCString (pFIELD `plusPtr` 6)
      colType <- peekByteOff pFIELD 264
      colSize <- peekByteOff pFIELD 268
      colDigits <- peekByteOff pFIELD 272
      colNullable <- peekByteOff pFIELD 274 :: IO SQLSMALLINT
      let def = (colName, mkSqlType colType colSize colDigits, toBool colNullable)
      fmap (def :) (getFieldDefs hSTMT pFIELD (n + 1) count)
  | otherwise = return []
-- | Return the column definitions of the current result set.
-- Result sets that report zero columns are skipped with 'sqlMoreResults'
-- until one with columns is found. If 'sqlMoreResults' returns 'sqlNoData'
-- first, there are no column-bearing result sets and @[]@ is returned.
--
-- @pFIELD@ is a scratch buffer: the column count is written at byte
-- offset 4, and 'getFieldDefs' reuses the same buffer.
moveToFirstResult :: HSTMT -> Ptr a -> IO [FieldDef]
moveToFirstResult hSTMT pFIELD = do
  sqlNumResultCols hSTMT (pFIELD `plusPtr` 4)
    >>= handleSqlResult sqlHandleStmt hSTMT
  numCols <- peekByteOff pFIELD 4
  if numCols /= 0
    then getFieldDefs hSTMT pFIELD 1 numCols
    else do
      more <- sqlMoreResults hSTMT
      handleSqlResult sqlHandleStmt hSTMT more
      if more == sqlNoData
        then return []
        else moveToFirstResult hSTMT pFIELD
-- | Advance the statement's cursor to the next row.
-- Returns 'False' once 'sqlFetch' reports 'sqlNoData' (no more rows) and
-- 'True' otherwise; error return codes are raised by 'handleSqlResult'.
odbcFetch :: HSTMT -> IO Bool
odbcFetch hSTMT = sqlFetch hSTMT >>= \rc -> do
  handleSqlResult sqlHandleStmt hSTMT rc
  return (rc /= sqlNoData)
-- | Release a statement's resources: first the malloc'd column buffer, then
-- the ODBC statement handle (via 'sqlFreeStmt' with 'sqlDrop'). A failure
-- from 'sqlFreeStmt' is raised through 'handleSqlResult' after the buffer
-- has already been freed.
odbcCloseStatement :: HSTMT -> CString -> IO ()
odbcCloseStatement hSTMT buffer = do
  free buffer
  rc <- sqlFreeStmt hSTMT sqlDrop
  handleSqlResult sqlHandleStmt hSTMT rc
-- | Turn an ODBC return code into an HSQL exception.
--
-- 'sqlSuccess', 'sqlNoData' and 'sqlSuccessWithInfo' are accepted silently
-- (warnings are discarded). 'sqlInvalidHandle', 'sqlStillExecuting' and
-- 'sqlNeedData' raise the corresponding 'SqlError' constructor. 'sqlError'
-- raises an 'SqlError' built from the first diagnostic record of @handle@
-- (whose kind is given by @handleType@). Any other code is treated as
-- impossible and reported with 'error'.
handleSqlResult :: SQLSMALLINT -> SQLHANDLE -> SQLRETURN -> IO ()
handleSqlResult handleType handle res
  | res == sqlSuccess || res == sqlNoData = return ()
  | res == sqlSuccessWithInfo = return ()
  | res == sqlInvalidHandle = throwIO SqlInvalidHandle
  | res == sqlStillExecuting = throwIO SqlStillExecuting
  | res == sqlNeedData = throwIO SqlNeedData
  | res == sqlError = getSqlError >>= throwIO
  | otherwise = error ("handleSqlResult: unexpected SQLRETURN " ++ show res)
  where
    -- Fetch diagnostic record 1 for the failing handle.
    getSqlError =
      allocaBytes 256 $ \pState ->
      alloca $ \pNative ->
      allocaBytes 256 $ \pMsg ->
      alloca $ \pTextLen -> do
        diagRes <- sqlGetDiagRec handleType handle 1
                                 pState pNative pMsg 256 pTextLen
        if diagRes == sqlNoData
          then return SqlNoData
          else if diagRes /= sqlSuccess && diagRes /= sqlSuccessWithInfo
            -- The out-buffers are uninitialised when the diagnostic call
            -- itself fails, so they must not be read.
            then return (SqlError { seState = ""
                                  , seNativeError = 0
                                  , seErrorMsg = "SQLGetDiagRec failed with return code "
                                                 ++ show diagRes })
            else do
              state <- peekCString pState
              native <- peek pNative
              msg <- peekCString pMsg
              return (SqlError { seState = state
                               , seNativeError = fromIntegral native
                               , seErrorMsg = msg })
-- | Size in bytes (including the NUL terminator) of the per-statement column
-- buffer allocated in 'withStatement'. 'getColValue' uses it for the first
-- 'sqlGetData' call on every column.
stmtBufferSize :: Int
stmtBufferSize = 256
-- | Read column @colNumber@ (0-based here; ODBC columns are 1-based, hence
-- the @+ 1@) of the current row as character data and hand it to @f@:
--
-- * SQL NULL: @f@ receives 'nullPtr' and length 0.
-- * A value that fits in the 'stmtBufferSize'-byte @buffer@: @f@ receives
--   @buffer@ and the value's length.
-- * A longer value (the driver returned 'sqlSuccessWithInfo' and a length
--   that does not fit): a temporary buffer of the full length is allocated,
--   the rest is fetched with a second 'sqlGetData' call, and @f@ receives
--   that buffer. The temporary buffer is only valid while @f@ runs.
--
-- NOTE(review): a negative length other than 'sqlNullData' (e.g. ODBC's
-- SQL_NO_TOTAL) is not handled and is passed to @f@ as-is. Confirm that the
-- drivers in use never report it.
getColValue :: HSTMT -> CString -> Int -> FieldDef -> (FieldDef -> CString -> Int -> IO a) -> IO a
getColValue hSTMT buffer colNumber fieldDef f = do
  (res, lenOrInd) <- getData buffer (fromIntegral stmtBufferSize)
  if lenOrInd == sqlNullData
    then f fieldDef nullPtr 0
    -- sqlSuccessWithInfo does not by itself mean truncation; only take the
    -- long-data path when the reported length cannot fit in the buffer.
    -- Otherwise getLongData would copy stmtBufferSize bytes into a smaller
    -- allocation.
    else if res == sqlSuccessWithInfo && lenOrInd >= fromIntegral stmtBufferSize
      then getLongData lenOrInd
      else f fieldDef buffer (fromIntegral lenOrInd)
  where
    -- Read (the next chunk of) the column as SQL_C_CHAR into dst, which has
    -- room for size bytes including the NUL. Returns the driver's return
    -- code and the length/indicator value.
    getData :: CString -> SQLINTEGER -> IO (SQLRETURN, SQLINTEGER)
    getData dst size = alloca $ \lenP -> do
      rc <- sqlGetData hSTMT (fromIntegral colNumber + 1)
                       sqlCChar (castPtr dst) size lenP
      handleSqlResult sqlHandleStmt hSTMT rc
      lenOrInd <- peek lenP
      return (rc, lenOrInd)

    -- The first call left stmtBufferSize - 1 data bytes plus a NUL in
    -- buffer. Copy that prefix into a buffer sized for the whole value and
    -- let the second call continue writing at the old NUL position.
    getLongData len = allocaBytes (fromIntegral newBufSize) $ \newBuf -> do
        copyBytes newBuf buffer stmtBufferSize
        let newDataStart = newBuf `plusPtr` (stmtBufferSize - 1)
            newDataLen = newBufSize - (fromIntegral stmtBufferSize - 1)
        _ <- getData newDataStart newDataLen
        f fieldDef newBuf (fromIntegral (newBufSize - 1))
      where
        -- room for the full value plus its NUL terminator
        newBufSize = len + 1