{-# OPTIONS -fglasgow-exts #-}
{-| Module      :  Database.HSQL.ODBC
    Copyright   :  (c) Krasimir Angelov 2003
    License     :  BSD-style

    Maintainer  :  kr.angelov@gmail.com
    Stability   :  provisional
    Portability :  portable

    This module provides an interface to ODBC.
-}
module Database.HSQL.ODBC(connect, driverConnect
                         ,module Database.HSQL) where

import Database.HSQL
import Database.HSQL.Types
import Foreign(toBool,Ptr,allocaBytes,alloca,nullPtr,peek
              ,newForeignPtr,withForeignPtr)
import Foreign.C(withCString,withCStringLen)
import Control.Concurrent.MVar(newMVar)
import Control.Exception(onException)
import System.IO.Unsafe(unsafePerformIO)

import DB.HSQL.ODBC.Type(HDBC,SQLRETURN,HENV,HENVRef,mkSqlType,SQL)
import DB.HSQL.ODBC.Functions
import DB.HSQL.ODBC.Core(handleSqlResult,withStatement)

------------------------------------------------------------------------------
-- Connect/Disconnect
------------------------------------------------------------------------------
-- | Makes a new connection to an ODBC data source identified by name,
-- authenticating with the given user identifier and password.
connect :: String               -- ^ Data source name
        -> String               -- ^ User identifier
        -> String               -- ^ Authentication string (password)
        -> IO Connection        -- ^ the returned value represents
                                -- the new connection
connect server user authentication = connectHelper login
  where
    -- Every string is passed NUL-terminated, so each length argument is
    -- the sqlNts marker rather than an explicit byte count.
    login hDBC =
      withCString server $ \pServer ->
      withCString user $ \pUser ->
      withCString authentication $ \pAuthentication ->
        sqlConnect hDBC pServer         (fromIntegral sqlNts)
                        pUser           (fromIntegral sqlNts)
                        pAuthentication (fromIntegral sqlNts)

-- | 'driverConnect' is an alternative to 'connect'. It supports data sources
-- that require more connection information than the three arguments in
-- 'connect' and data sources that are not defined in the system information.
driverConnect :: String               -- ^ Connection string
              -> IO Connection        -- ^ the returned value represents
                                      -- the new connection
driverConnect connString = connectHelper openWith
  where
    -- No window handle is supplied and the no-prompt flag is used. The
    -- driver may write a completed connection string into the 1024-byte
    -- output buffer; that buffer (and its length slot) is never read
    -- afterwards.
    openWith hDBC =
      withCString connString $ \pConnString ->
        allocaBytes 1024 $ \pOutConnString ->
          alloca $ \pLen ->
            sqlDriverConnect hDBC nullPtr
                             pConnString sqlNts
                             pOutConnString 1024 pLen
                             sqlDriverNoPrompt

-- | Shared implementation of 'connect' and 'driverConnect'.
--
-- Allocates a connection handle in the process-wide environment
-- ('myEnvironment') and runs the supplied driver call on it. On success,
-- it packages the handle's operations into a 'Connection' record. If the
-- driver call reports an error, the connection handle is released before
-- the error propagates, so failed connection attempts do not leak handles.
connectHelper :: (HDBC -> IO SQLRETURN) -- ^ driver call that opens the
                                        -- connection on the given handle
              -> IO Connection
connectHelper connectFunction = withForeignPtr myEnvironment $ \hEnv -> do
  hDBC <- alloca $ \ (phDBC :: Ptr HDBC) -> do
    res <- sqlAllocConnect hEnv phDBC
    handleSqlResult sqlHandleEnv hEnv res
    peek phDBC
  -- handleSqlResult is relied on to raise for an error result (nothing here
  -- branches on the code). Free the unconnected handle before that error
  -- escapes. The free is best-effort cleanup, so its own result is ignored.
  (connectFunction hDBC >>= handleSqlResult sqlHandleDbc hDBC)
    `onException` sqlFreeConnect hDBC
  refFalse <- newMVar False
  -- 'connection' refers to itself: query/table/describe operations receive
  -- the record they belong to.
  let connection =
        Connection { connDisconnect          = odbcDisconnect hDBC
                   , connExecute             = odbcExecute hDBC
                   , connQuery               = odbcQuery connection hDBC
                   , connTables              = odbcTables connection hDBC
                   , connDescribe            = odbcDescribe connection hDBC
                   , connBeginTransaction    =
                       beginTransaction myEnvironment hDBC
                   , connCommitTransaction   =
                       commitTransaction myEnvironment hDBC
                   , connRollbackTransaction =
                       rollbackTransaction myEnvironment hDBC
                   , connClosed              = refFalse
                   }
  return connection

-- | Closes an ODBC connection and releases its handle.
--
-- The connection is first closed with 'sqlDisconnect', and then the handle
-- is freed with 'sqlFreeConnect'. Each step's return code is passed to
-- 'handleSqlResult' before the next step runs.
odbcDisconnect :: HDBC -- ^ ODBC handle
               -> IO ()
odbcDisconnect hDBC = do
  disconnectResult <- sqlDisconnect hDBC
  handleSqlResult sqlHandleDbc hDBC disconnectResult
  freeResult <- sqlFreeConnect hDBC
  handleSqlResult sqlHandleDbc hDBC freeResult

-- | Executes an SQL command directly on the connection, discarding any
-- result.
--
-- A temporary statement handle is allocated and the command is run with
-- 'sqlExecDirect'. The handle is then dropped with 'sqlFreeStmt' /
-- 'sqlDrop'. If execution fails, the handle is still dropped before the
-- error propagates, so failed commands do not leak statement handles.
odbcExecute :: HDBC -- ^ ODBC handle
            -> SQL -- ^ SQL Query
            -> IO ()
odbcExecute hDBC query = allocaBytes sizeOfHStmt $ \pStmt -> do
  sqlAllocStmt hDBC pStmt >>= handleSqlResult sqlHandleDbc hDBC
  hSTMT <- peek pStmt
  let execDirect =
        withCStringLen query $ \(pQuery, len) ->
          sqlExecDirect hSTMT pQuery len >>= handleSqlResult sqlHandleStmt hSTMT
  -- handleSqlResult is relied on to raise for an error result; its
  -- diagnostics are read before the raise, so dropping the handle in the
  -- cleanup is safe. The cleanup is best-effort, so its result is ignored.
  execDirect `onException` sqlFreeStmt hSTMT sqlDrop
  sqlFreeStmt hSTMT sqlDrop >>= handleSqlResult sqlHandleStmt hSTMT

-- | Runs a query on a fresh statement handle and returns the 'Statement'
-- that 'withStatement' builds around it.
--
-- The query text is passed to 'sqlExecDirect' together with its byte
-- length, as produced by 'withCStringLen'.
odbcQuery :: Connection 
          -> HDBC -- ^ ODBC handle
          -> String -- ^ SQL Query
          -> IO Statement
odbcQuery connection hDBC q =
    withStatement connection hDBC $ \hSTMT ->
        withCStringLen q $ \(pQuery, len) ->
            sqlExecDirect hSTMT pQuery len

-- | Lists the table names reported by the driver for this connection.
--
-- All four name/length argument pairs of 'sqlTables' are passed as null/0,
-- so no catalog, schema, table or type filter is supplied. The result set
-- contains (column names may vary):
--
-- > Column name     #   Type
-- > TABLE_NAME      3   VARCHAR
odbcTables :: Connection 
           -> HDBC -- ^ ODBC handle
           -> IO [String]
odbcTables connection hDBC =
  withStatement connection hDBC allTables >>= collectRows tableName
  where
    allTables hSTMT = sqlTables hSTMT nullPtr 0 nullPtr 0 nullPtr 0 nullPtr 0
    tableName stmt  = getFieldValue stmt "TABLE_NAME"

-- | Describes the columns of a table. Each row returned by 'sqlColumns'
-- for the given table name is converted into a 'FieldDef'.
odbcDescribe :: Connection 
             -> HDBC -- ^ ODBC handle
             -> String -- ^ table name
             -> IO [FieldDef]
odbcDescribe connection hDBC table =
  collectRows getColumnInfo
    =<< withStatement connection hDBC (odbcSqlColumns table)

-- | Issues 'sqlColumns' on the given statement handle for one table.
--
-- Only the table name is supplied, with its explicit byte length. The
-- catalog, schema and column-name arguments are null/0.
-- NOTE(review): ODBC may treat the table name as a search pattern, so '_'
-- and '%' in the name could match other tables. Confirm for your driver.
odbcSqlColumns table hSTMT = withCStringLen table columnsOf
  where
    columnsOf (pTable, len) =
      sqlColumns hSTMT nullPtr 0
                       nullPtr 0
                       pTable (fromIntegral len)
                       nullPtr 0

-- | Converts the current row of an @SQLColumns@ result set into a
-- 'FieldDef': column name, SQL type and nullability.
--
-- SQLColumns returns (column names may vary):
--
-- > Column name     #   Type
-- > COLUMN_NAME     4   Varchar not NULL
-- > DATA_TYPE       5   Smallint not NULL
-- > COLUMN_SIZE     7   Integer
-- > DECIMAL_DIGITS  9   Smallint
-- > NULLABLE       11   Smallint not NULL
--
-- @COLUMN_SIZE@ and @DECIMAL_DIGITS@ can be NULL. They are read with
-- 'getFieldValue'' and a default of 0. A @NULLABLE@ value of 0 is reported
-- as not nullable. Any other value (including ODBC's "unknown", 2) is
-- reported as nullable.
getColumnInfo :: Statement -> IO FieldDef
getColumnInfo stmt = do
  column_name <- getFieldValue stmt "COLUMN_NAME"
  (data_type::Int) <- getFieldValue stmt "DATA_TYPE"
  (column_size::Int) <- getFieldValue' stmt "COLUMN_SIZE" 0
  (decimal_digits::Int) <- getFieldValue' stmt "DECIMAL_DIGITS" 0
  let sqlType = mkSqlType (fromIntegral data_type) 
                          (fromIntegral column_size) 
                          (fromIntegral decimal_digits)
  (nullable::Int) <- getFieldValue stmt "NULLABLE"
  return (column_name, sqlType, toBool nullable)

------------------------------------------------------------------------------
-- transaction management
------------------------------------------------------------------------------
-- | Starts a transaction by switching the connection out of autocommit
-- mode.
--
-- The environment reference is not needed here. It is accepted only so that
-- the three transaction functions share one signature. The return code of
-- 'sqlSetConnectOption' is checked with 'handleSqlResult' rather than
-- ignored: if autocommit could not be disabled, later statements would be
-- committed one by one while the caller believes it is inside a
-- transaction.
beginTransaction :: HENVRef -> HDBC -> IO ()
beginTransaction _ hDBC =
  sqlSetConnectOption hDBC sqlAutoCommit sqlAutoCommitOff
    >>= handleSqlResult sqlHandleDbc hDBC

-- | Commits the current transaction and switches the connection back to
-- autocommit mode.
--
-- Both the commit and the autocommit switch are checked with
-- 'handleSqlResult', so a failed commit is reported instead of being lost.
-- The commit's result is checked before any further call is made on the
-- connection, because a subsequent ODBC call on the same handle may
-- replace its diagnostics. If the commit fails, autocommit stays off until
-- 'rollbackTransaction' (or a later commit) restores it.
commitTransaction :: HENVRef -> HDBC -> IO ()
commitTransaction envRef hDBC =
    withForeignPtr envRef $ \hEnv -> do
      sqlTransact hEnv hDBC sqlCommit >>= handleSqlResult sqlHandleDbc hDBC
      sqlSetConnectOption hDBC sqlAutoCommit sqlAutoCommitOn
        >>= handleSqlResult sqlHandleDbc hDBC

-- | Rolls back the current transaction and switches the connection back to
-- autocommit mode.
--
-- Both the rollback and the autocommit switch are checked with
-- 'handleSqlResult', so failures are reported instead of being silently
-- discarded. The rollback's result is checked before any further call is
-- made on the connection, because a subsequent ODBC call on the same handle
-- may replace its diagnostics.
rollbackTransaction :: HENVRef -> HDBC -> IO ()
rollbackTransaction envRef hDBC =
    withForeignPtr envRef $ \hEnv -> do
      sqlTransact hEnv hDBC sqlRollback >>= handleSqlResult sqlHandleDbc hDBC
      sqlSetConnectOption hDBC sqlAutoCommit sqlAutoCommitOn
        >>= handleSqlResult sqlHandleDbc hDBC

------------------------------------------------------------------------------
-- keeper of HENV
------------------------------------------------------------------------------
-- | The single ODBC environment handle shared by every connection opened
-- through this module.
--
-- It is allocated lazily, the first time it is demanded, via
-- 'unsafePerformIO'. The @NOINLINE@ pragma is what keeps this a single
-- shared value: if GHC were allowed to inline the definition, each use
-- site could perform its own allocation. The raw handle is wrapped in a
-- 'ForeignPtr' with 'sqlFreeEnv_p' as its finalizer (presumably a pointer
-- to the driver's environment-free routine; confirm in the Functions
-- module).
{-# NOINLINE myEnvironment #-}
myEnvironment :: HENVRef
myEnvironment = unsafePerformIO allocEnvironment
  where
    allocEnvironment = alloca $ \ (phEnv :: Ptr HENV) -> do
      -- No environment exists yet, so errors are reported without a handle.
      sqlAllocEnv phEnv >>= handleSqlResult 0 nullPtr
      peek phEnv >>= newForeignPtr sqlFreeEnv_p