{-# LANGUAGE RankNTypes #-}
-- #hide
{-| Basic type class & type definitions for DB interfacing.
-}
module Database.HSQL.Types
    (FieldDef,SqlType(..),SqlError(..),sqlErrorTc
    ,Connection(..),Statement(..),SqlBind(..)) where

import Control.Concurrent.MVar(MVar)
import Control.Exception(throw,throwIO)
import Foreign(nullPtr)
import Foreign.C(CString,peekCStringLen)

import DB.HSQL.Type
import DB.HSQL.Error

-- | Description of a single table field or result column, as a triple of
--
--  * the field (column) name,
--
--  * its 'SqlType',
--
--  * a flag whose meaning is not established in this module
--    (presumably whether the column may hold NULL -- TODO confirm against
--    the backends' @describe@ implementations).
type FieldDef = (String, SqlType, Bool)

-- | A 'Connection' represents a connection to a database, through which
-- you can operate on it. Its fields hold the operations available on the
-- connection, plus its closing state.
--
-- To create a connection, use the @connect@ function from the module of
-- your preferred backend.
data Connection
  = Connection {
      -- | Disconnect from the database.
      connDisconnect :: IO (),
      -- | Execute a query that returns no result.
      connExecute :: String -> IO (),
      -- | Execute a query that returns a result, delivered as a 'Statement'.
      connQuery :: String -> IO Statement,
      -- | Retrieve the names of the tables reachable through this connection.
      connTables :: IO [String],
      -- | Retrieve the field definitions of the given table.
      connDescribe :: String -> IO [FieldDef],
      -- | Begin a transaction.
      connBeginTransaction :: IO (),
      -- | Commit the pending transaction.
      connCommitTransaction :: IO (),
      -- | Roll back the pending transaction.
      connRollbackTransaction :: IO (),
      -- | Closing state of the connection.
      -- NOTE(review): presumably 'True' once the connection has been
      -- closed; confirm against the backends.
      connClosed :: MVar Bool }


-- | A 'Statement' represents the result of executing a given SQL query.
data Statement
    = Statement {
        -- | The 'Connection' associated with this statement.
        stmtConn :: Connection,
        -- | Close the statement.
        stmtClose :: IO (),
        -- | Advance the row pointer to the next result row; the result
        -- indicates whether the pointer is still within the range of
        -- available rows.
        stmtFetch :: IO Bool,
        -- | Extract a field from the current result row, given
        --
        --  * the column index,
        --
        --  * the column's field definition,
        --
        --  * a generic extraction function that receives the field
        --    definition, the field's value as a C string, and that
        --    string's length.
        --
        -- 'fromSqlCStringLen' has exactly the type required of the
        -- extraction function.
        stmtGetCol :: forall a . Int                   -- column index
                               -> FieldDef              -- field definition
                               -> (FieldDef -> CString -> Int -> IO a)
                                                        -- extraction function
                               -> IO a,
        -- | Field descriptors of the result columns.
        stmtFields :: [FieldDef],
        -- | Closing state of the statement.
        stmtClosed :: MVar Bool }


-- | Types whose values can be converted to and from their SQL textual
-- representation.
--
-- Minimal complete definition: 'toSqlValue' and 'fromSqlValue'.
class SqlBind a where
  -- | Render a value as the 'String' that represents it in SQL.
  -- NOTE(review): whether the result is already quoted/escaped is up to
  -- each instance -- confirm before embedding it in query text.
  toSqlValue :: a -> String

  -- | Parse the textual value of a field of the given 'SqlType'.
  -- Returns 'Nothing' when the text cannot be converted to @a@.
  fromSqlValue :: SqlType -> String -> Maybe a

  -- | Convert a raw field, delivered as a C string and its length, into a
  -- value. Instances may override this for faster conversion, e.g. for
  -- integral numeric types; the default decodes the string and delegates
  -- to 'fromSqlValue'.
  --
  -- Throws 'SqlFetchNull' (with the field name) when the C string is
  -- @NULL@, and 'SqlBadTypeCast' (with the field name and type) when
  -- 'fromSqlValue' returns 'Nothing'.
  fromSqlCStringLen :: FieldDef -> CString -> Int -> IO a
  fromSqlCStringLen (name, sqlType, _) cstr cstrLen
    -- throwIO (not throw) so the exception is raised when the action
    -- runs, not whenever the action value happens to be evaluated.
    | cstr == nullPtr = throwIO (SqlFetchNull name)
    | otherwise = do
        str <- peekCStringLen (cstr, cstrLen)
        case fromSqlValue sqlType str of
          Nothing -> throwIO (SqlBadTypeCast name sqlType)
          Just v  -> return v