{-# LINE 1 "DB/HSQL/ODBC/Core.hsc" #-}
{-# LANGUAGE ScopedTypeVariables #-}
{-# LINE 2 "DB/HSQL/ODBC/Core.hsc" #-}
module DB.HSQL.ODBC.Core(handleSqlResult,withStatement) 
where


import Control.Concurrent.MVar(newMVar)
import Control.Exception(onException,throw,throwIO)
import Foreign(plusPtr,peekByteOff,toBool,Ptr,nullPtr,castPtr,copyBytes
              ,allocaBytes,alloca,mallocBytes,free,peek)
import Foreign.C(CString,peekCString)

import Database.HSQL
import Database.HSQL.Types

{-# LINE 16 "DB/HSQL/ODBC/Core.hsc" #-}

import DB.HSQL.ODBC.Type
import DB.HSQL.ODBC.Functions
import DB.HSQL.ODBC.Status


{-# LINE 22 "DB/HSQL/ODBC/Core.hsc" #-}

-- | Allocate an ODBC statement handle on @hDBC@, run @f@ on it (the call
-- that prepares/executes the query) and wrap the result in an HSQL
-- 'Statement'.
--
-- The column descriptions of the first result set that has columns are read
-- eagerly via 'moveToFirstResult'. The returned statement owns the ODBC
-- statement handle and a 'stmtBufferSize'-byte fetch buffer; both are
-- released by its 'stmtClose' action ('odbcCloseStatement').
--
-- If @f@ reports an error, or describing the result set fails, the statement
-- handle is dropped before the exception is re-thrown, so it does not leak.
withStatement :: Connection -> HDBC -> (HSTMT -> IO SQLRETURN) -> IO Statement
withStatement connection hDBC f =
  allocaBytes sizeOfField $ \pFIELD -> do
    -- sqlAllocStmt stores the new statement handle at offset 0 of the
    -- scratch area; it is read back from the same offset.
    allocRes <- sqlAllocStmt hDBC (pFIELD `plusPtr` 0)
    handleSqlResult sqlHandleDbc hDBC allocRes
    hSTMT <- peekByteOff pFIELD 0
    let checkStmt = handleSqlResult sqlHandleStmt hSTMT
        -- Best-effort release on the error path. The return code is
        -- deliberately ignored so the original exception is the one that
        -- propagates to the caller.
        releaseStmt = sqlFreeStmt hSTMT sqlDrop >> return ()
    fields <- (f hSTMT >>= checkStmt >> moveToFirstResult hSTMT pFIELD)
                `onException` releaseStmt
    buffer <- mallocBytes (fromIntegral stmtBufferSize)
    refFalse <- newMVar False
    return Statement
      { stmtConn   = connection
      , stmtClose  = odbcCloseStatement hSTMT buffer
      , stmtFetch  = odbcFetch hSTMT
      , stmtGetCol = getColValue hSTMT buffer
      , stmtFields = fields
      , stmtClosed = refFalse
      }


-- | Describe result columns @n@ through @count@ (inclusive) of @hSTMT@,
-- producing one 'FieldDef' (column name, SQL type, nullability) per column,
-- in column order.
--
-- @pFIELD@ is scratch memory handed to 'sqlDescribeCol'. The byte offsets
-- used below were computed by hsc2hs and must stay in sync with the C
-- layout described in the .hsc source.
getFieldDefs:: HSTMT -> Ptr a -> SQLUSMALLINT -> SQLUSMALLINT -> IO [FieldDef]
getFieldDefs hSTMT pFIELD n count =
  if n > count
    then return []
    else do
      def  <- describeColumn
      rest <- getFieldDefs hSTMT pFIELD (n + 1) count
      return (def : rest)
  where
    -- Describe column @n@ and decode the values the driver wrote into the
    -- scratch slots of pFIELD.
    describeColumn = do
      rc <- sqlDescribeCol hSTMT n
              (pFIELD `plusPtr` 6) fieldNameLength
              (pFIELD `plusPtr` 262)
              (pFIELD `plusPtr` 264)
              (pFIELD `plusPtr` 268)
              (pFIELD `plusPtr` 272)
              (pFIELD `plusPtr` 274)
      handleSqlResult sqlHandleStmt hSTMT rc
      colName     <- peekCString (pFIELD `plusPtr` 6)
      sqlTypeCode <- peekByteOff pFIELD 264
      colSize     <- peekByteOff pFIELD 268
      digits      <- peekByteOff pFIELD 272
      (nullable :: SQLSMALLINT) <- peekByteOff pFIELD 274
      return (colName, mkSqlType sqlTypeCode colSize digits, toBool nullable)


-- | Advance @hSTMT@ to the first result set that has at least one column and
-- return that result set's column descriptions.
--
-- Result sets reporting zero columns are skipped with 'sqlMoreResults'. When
-- the driver answers 'sqlNoData' (no further result sets), the empty list is
-- returned. @pFIELD@ is scratch memory: offset 4 receives the column count,
-- and the pointer is passed on to 'getFieldDefs'.
moveToFirstResult :: HSTMT -> Ptr a -> IO [FieldDef]
moveToFirstResult hSTMT pFIELD = do
  sqlNumResultCols hSTMT (pFIELD `plusPtr` 4) >>= checkStmt
  numCols <- peekByteOff pFIELD 4
  if numCols /= 0
    then getFieldDefs hSTMT pFIELD 1 numCols
    else do
      moreRes <- sqlMoreResults hSTMT
      checkStmt moreRes
      if moreRes == sqlNoData
        then return []
        else moveToFirstResult hSTMT pFIELD
  where
    checkStmt = handleSqlResult sqlHandleStmt hSTMT



-- | Advance @hSTMT@ to its next row.
--
-- Yields 'False' once the driver reports 'sqlNoData' (no more rows) and
-- 'True' otherwise; error return codes are raised by 'handleSqlResult'.
odbcFetch :: HSTMT -> IO Bool
odbcFetch hSTMT =
  sqlFetch hSTMT >>= \rc ->
    handleSqlResult sqlHandleStmt hSTMT rc >> return (rc /= sqlNoData)


-- | Close-action installed by 'withStatement': free the fetch buffer first,
-- then drop the ODBC statement handle with 'sqlFreeStmt' / 'sqlDrop'.
-- A failure of the drop is reported through 'handleSqlResult'.
odbcCloseStatement :: HSTMT -> CString -> IO ()
odbcCloseStatement hSTMT buffer = do
  free buffer
  rc <- sqlFreeStmt hSTMT sqlDrop
  handleSqlResult sqlHandleStmt hSTMT rc


------------------------------------------------------------------------------
-- routines for handling exceptions
------------------------------------------------------------------------------
-- | Check the return code of an ODBC call and turn failures into exceptions.
--
-- @handleType@ and @handle@ identify the handle the call was made on; they
-- are used to fetch the first diagnostic record when the call failed.
--
-- * 'sqlSuccess', 'sqlNoData' and 'sqlSuccessWithInfo' are accepted.
-- * 'sqlInvalidHandle', 'sqlStillExecuting' and 'sqlNeedData' throw the
--   corresponding 'SqlError' constructor.
-- * 'sqlError' throws the diagnostic record as a 'SqlError'. If there is no
--   record, 'SqlNoData' is thrown instead.
-- * Any other code is reported with 'error'.
handleSqlResult :: SQLSMALLINT -> SQLHANDLE -> SQLRETURN -> IO ()
handleSqlResult handleType handle res
    | res == sqlSuccess || res == sqlNoData = return ()
    -- Warnings are currently ignored (the generated .hsc debug hook for
    -- them is compiled out).
    | res == sqlSuccessWithInfo = return ()
    | res == sqlInvalidHandle = throwIO SqlInvalidHandle
    | res == sqlStillExecuting = throwIO SqlStillExecuting
    | res == sqlNeedData = throwIO SqlNeedData
    | res == sqlError = getSqlError >>= throwIO
    | otherwise = error ("handleSqlResult: unexpected SQLRETURN " ++ show res)
    where
      -- Read diagnostic record 1 for the handle. The output buffers are
      -- only decoded when sqlGetDiagRec says it filled them. On any other
      -- result they are uninitialised and may lack a NUL terminator.
      getSqlError =
        allocaBytes 256 $ \pState ->
        alloca $ \pNative ->
        allocaBytes 256 $ \pMsg ->
        alloca $ \pTextLen -> do
          diagRes <- sqlGetDiagRec handleType handle 1
                       pState pNative pMsg 256 pTextLen
          if diagRes == sqlNoData
            then return SqlNoData
            else if diagRes == sqlSuccess || diagRes == sqlSuccessWithInfo
              then do
                state  <- peekCString pState
                native <- peek pNative
                msg    <- peekCString pMsg
                return SqlError { seState       = state
                                , seNativeError = fromIntegral native
                                , seErrorMsg    = msg }
              else return SqlError
                     { seState       = ""
                     , seNativeError = 0
                     , seErrorMsg    = "SQLGetDiagRec failed with return code "
                                       ++ show diagRes }


-- | Size in bytes of the per-statement fetch buffer allocated by
-- 'withStatement' and filled by 'getColValue'. Column values longer than
-- this are fetched in a second 'sqlGetData' call.
stmtBufferSize :: Int
stmtBufferSize = 256


-- | Fetch one column of the current row as character data ('sqlCChar') and
-- pass it to the continuation @f@. @colNumber@ is converted to the driver's
-- column number by adding 1.
--
-- @buffer@ is the statement's 'stmtBufferSize'-byte fetch buffer. @f@ is
-- called with the field definition, a pointer to the data and its length:
--
-- * SQL NULL: @f fieldDef nullPtr 0@.
-- * Data that fits in @buffer@: @f fieldDef buffer len@.
-- * Data the driver truncated (a 'sqlSuccessWithInfo' result whose reported
--   length does not fit): the remainder is fetched into a temporary buffer,
--   which is only valid while @f@ runs.
--
-- Throws a 'SqlError' if the driver returns a negative length other than
-- 'sqlNullData' (for example ODBC's SQL_NO_TOTAL), since the total size is
-- then unknown.
getColValue :: HSTMT -> CString -> Int -> FieldDef -> (FieldDef -> CString -> Int -> IO a) -> IO a
getColValue hSTMT buffer colNumber fieldDef f = do
  (res, lenOrInd) <- getData buffer bufSize
  if lenOrInd == sqlNullData
    then f fieldDef nullPtr 0
    else if lenOrInd < 0
      then throwIO SqlError
             { seState       = ""
             , seNativeError = 0
             , seErrorMsg    = "getColValue: unsupported length/indicator "
                               ++ show (fromIntegral lenOrInd :: Int)
                               ++ " for column " ++ show colNumber }
      -- A warning alone does not mean truncation. Only take the long path
      -- when the data really did not fit; otherwise getLongData would copy
      -- stmtBufferSize bytes into a smaller buffer.
      else if res == sqlSuccessWithInfo && lenOrInd >= bufSize
        then getLongData lenOrInd
        else f fieldDef buffer (fromIntegral lenOrInd)
  where
    bufSize :: SQLINTEGER
    bufSize = fromIntegral stmtBufferSize

    -- One sqlGetData call into @dest@ (capacity @size@ bytes). Returns the
    -- driver's return code and the length/indicator it reported.
    getData :: CString -> SQLINTEGER -> IO (SQLRETURN, SQLINTEGER)
    getData dest size = alloca $ \lenP -> do
      rc <- sqlGetData hSTMT (fromIntegral colNumber + 1)
                       sqlCChar (castPtr dest) size lenP
      handleSqlResult sqlHandleStmt hSTMT rc
      reported <- peek lenP
      return (rc, reported)

    -- Called only when there is more data than fits in the normal buffer
    -- (len >= stmtBufferSize). A new buffer big enough for the old and the
    -- new data is allocated. The old data is copied into it, and a second
    -- sqlGetData call writes the rest right after it.
    getLongData len = allocaBytes (fromIntegral newBufSize) $ \newBuf -> do
      copyBytes newBuf buffer stmtBufferSize
      -- The last byte of the old data will always be the NUL terminator,
      -- so it is overwritten with the first byte of the new data.
      let newDataStart = newBuf `plusPtr` (stmtBufferSize - 1)
          newDataLen   = newBufSize - (bufSize - 1)
      _ <- getData newDataStart newDataLen
      f fieldDef newBuf (fromIntegral newBufSize - 1)
      where
        newBufSize = len + 1 -- room for the terminating NUL