{-# LANGUAGE TupleSections #-}
{-# LANGUAGE OverloadedStrings #-}
{-# LANGUAGE DeriveDataTypeable #-}
{-# LANGUAGE DeriveGeneric #-}
{-# LANGUAGE BangPatterns #-}
{-# LANGUAGE ViewPatterns #-}
{-# LANGUAGE MultiWayIf #-}
{-# LANGUAGE GeneralizedNewtypeDeriving #-}
{-# LANGUAGE RankNTypes #-}
{-# LANGUAGE FlexibleContexts #-}
{-# LANGUAGE ForeignFunctionInterface #-}
module ODBC
(
connect
, close
, exec
, query
, Value(..)
, Connection
) where
import Control.Concurrent.Async
import Control.Concurrent.MVar
import Control.DeepSeq
import Control.Exception
import qualified Data.ByteString as S
import Data.Coerce
import Data.Data
import Data.Int
import Data.Text (Text)
import qualified Data.Text as T
import qualified Data.Text.Encoding as T
import qualified Data.Text.Foreign as T
import Foreign
import Foreign.C
import GHC.Generics
-- | The ODBC return codes this module compares against:
-- @SQL_SUCCESS@ (0), @SQL_SUCCESS_WITH_INFO@ (1) and @SQL_NO_DATA@ (100).
sql_success, sql_success_with_info, sql_no_data :: RETCODE
sql_success = RETCODE 0
sql_success_with_info = RETCODE 1
sql_no_data = RETCODE 100
-- | Errors thrown by the wrappers in this module.
data ODBCException
  = UnsuccessfulReturnCode !String !RETCODE
    -- ^ The named ODBC call returned a code other than success/success-with-info.
  | AllocationReturnedNull !String
    -- ^ The named allocation call returned a null handle.
  | UnknownType !Int16
    -- ^ A result column has a SQL type code 'getData' does not map.
  | DatabaseIsClosed !String
    -- ^ The named operation was attempted after 'close'.
  deriving (Show, Eq, Typeable)

instance Exception ODBCException
-- | Metadata for one result column, as filled in by SQLDescribeColW.
data Column = Column
  { columnType   :: !SQLSMALLINT -- ^ SQL data type code
  , columnSize   :: !SQLULEN     -- ^ reported column size
  , columnDigits :: !SQLSMALLINT -- ^ reported decimal digits
  , columnNull   :: !SQLSMALLINT -- ^ reported nullability code
  }
  deriving (Show)
-- | One cell of a row returned by 'query'.
data Value
  = TextValue !Text     -- ^ read from SQL type code -9 as UTF-16LE text
  | BoolValue !Bool     -- ^ read from SQL type code -7 (non-zero byte is 'True')
  | DoubleValue !Double -- ^ read as a C double
  | IntValue !Int       -- ^ read from SQL type codes 4 and 5
  deriving (Eq, Ord, Show, Data, Generic, Typeable)

-- | Uses the 'Generic'-derived default implementation.
instance NFData Value
-- | A connection handle. The MVar holds 'Nothing' once 'close' has run.
newtype Connection = Connection
  { connectionMVar :: MVar (Maybe ConnectionState)
  }

-- | The environment and connection handles owned by an open 'Connection'.
data ConnectionState = ConnectionState
  { connectionEnv :: !(ForeignPtr SQLHENV)
  , connectionDbc :: !(ForeignPtr SQLHDBC)
  }
-- | Open a connection using an ODBC connection string.
--
-- Allocates an environment handle and a connection handle. Both are owned
-- by 'ForeignPtr's that free them when finalized. After a successful
-- driver connect, a disconnect finalizer is added to the connection
-- handle. All FFI work happens on a bound thread (see 'withBound').
--
-- The connection string is encoded explicitly as UTF-16LE. This matches
-- how 'getData' decodes wide text, and does not depend on the internal
-- representation of 'Text' (UTF-16 before text-2.0, UTF-8 after).
--
-- NOTE(review): if a 'Connection' is dropped without 'close', GHC does not
-- order finalizers across different ForeignPtrs. The env handle could then
-- be freed before the dbc handle; confirm this is acceptable for the driver.
connect ::
     Text
  -> IO Connection
connect string =
  withBound
    (do env <-
          uninterruptibleMask_
            (do ptr <- assertNotNull "odbc_SQLAllocEnv" odbc_SQLAllocEnv
                newForeignPtr odbc_SQLFreeEnv (coerce ptr))
        assertSuccess "odbc_SetEnvAttr" (withForeignPtr env odbc_SetEnvAttr)
        dbc <-
          uninterruptibleMask_
            (do ptr <-
                  assertNotNull
                    "odbc_SQLAllocDbc"
                    (withForeignPtr env odbc_SQLAllocDbc)
                newForeignPtr odbc_SQLFreeDbc (coerce ptr))
        S.useAsCStringLen
          (T.encodeUtf16LE string)
          (\(wstring, byteLen) ->
             uninterruptibleMask_
               (do assertSuccess
                     "odbc_SQLDriverConnect"
                     (withForeignPtr
                        dbc
                        (\dbcPtr ->
                           odbc_SQLDriverConnectW
                             dbcPtr
                             (coerce wstring)
                             -- The length is in SQLWCHAR units (2 bytes each).
                             (fromIntegral (byteLen `div` 2))))
                   -- addForeignPtrFinalizer runs this before the earlier
                   -- free finalizer, so disconnect precedes the free.
                   addForeignPtrFinalizer odbc_SQLDisconnect dbc))
        mvar <-
          newMVar
            (Just (ConnectionState {connectionEnv = env, connectionDbc = dbc}))
        pure (Connection mvar))
-- | Close a connection. Calling it more than once is harmless.
--
-- The state is swapped for 'Nothing' first, so later operations throw
-- 'DatabaseIsClosed'. The connection handle is then finalized before the
-- environment handle.
close :: Connection -> IO ()
close conn =
  withBound $ do
    previous <-
      modifyMVar (connectionMVar conn) (\st -> pure (Nothing, st))
    mapM_ release previous
  where
    release (ConnectionState env dbc) = do
      finalizeForeignPtr dbc
      finalizeForeignPtr env
-- | Run a statement for its side effects. Any result set is ignored.
exec ::
     Connection
  -> Text
  -> IO ()
exec conn string =
  withBound $
    withHDBC conn "exec" $ \dbc ->
      withExecDirect dbc string (\_ -> pure ())
-- | Run a statement and return every row it produces, in order.
query ::
     Connection
  -> Text
  -> IO [[Value]]
query conn string = withBound (withHDBC conn "query" runQuery)
  where
    runQuery dbc = withExecDirect dbc string fetchStatementRows
-- | Run an action with the raw connection handle while holding the
-- connection's MVar. Throws 'DatabaseIsClosed' (tagged with @label@) if
-- the connection has been closed. Both handles are touched afterwards so
-- they stay alive for the whole action.
withHDBC :: Connection -> String -> (Ptr SQLHDBC -> IO a) -> IO a
withHDBC conn label f =
  withMVar (connectionMVar conn) $ \state ->
    case state of
      Nothing -> throwIO (DatabaseIsClosed label)
      Just (ConnectionState {connectionEnv = env, connectionDbc = db}) -> do
        result <- withForeignPtr db f
        touchForeignPtr db
        touchForeignPtr env
        pure result
-- | Allocate a statement, execute @string@ on it with SQLExecDirectW, then
-- pass the statement to @cont@. The statement is freed afterwards.
--
-- The SQL text is encoded explicitly as UTF-16LE, matching how 'getData'
-- decodes wide text, whatever the internal representation of 'Text'.
--
-- SQL_NO_DATA is accepted along with the success codes, because ODBC
-- returns it for a searched UPDATE/DELETE that affects no rows. Any other
-- code throws 'UnsuccessfulReturnCode'.
withExecDirect :: Ptr SQLHDBC -> Text -> (forall s. SQLHSTMT s -> IO a) -> IO a
withExecDirect dbc string cont =
  withStmt
    dbc
    (\stmt -> do
       retcode <-
         S.useAsCStringLen
           (T.encodeUtf16LE string)
           (\(wstring, byteLen) ->
              odbc_SQLExecDirectW
                stmt
                (coerce wstring)
                -- The length is in SQLWCHAR units (2 bytes each).
                (fromIntegral (byteLen `div` 2)))
       if retcode == sql_success ||
          retcode == sql_success_with_info || retcode == sql_no_data
         then cont stmt
         else throwIO (UnsuccessfulReturnCode "odbc_SQLExecDirectW" retcode))
-- | Allocate a statement handle on @hdbc@, run @cont@ with it, and always
-- free it afterwards (via 'bracket'). Throws 'AllocationReturnedNull' if
-- allocation yields a null handle.
--
-- @cont@ is taken explicitly rather than eta-reduced. Its type is rank-2,
-- and GHC >= 9.0 (simplified subsumption) rejects passing a rank-2
-- argument through a partially applied 'bracket'.
withStmt :: Ptr SQLHDBC -> (forall s. SQLHSTMT s -> IO a) -> IO a
withStmt hdbc cont =
  bracket
    (assertNotNull "odbc_SQLAllocStmt" (odbc_SQLAllocStmt hdbc))
    odbc_SQLFreeStmt
    cont
-- | Run an action on a freshly spawned bound OS thread and wait for its
-- result. Exceptions from the action are re-thrown here by 'wait'.
withBound :: IO a -> IO a
withBound action = withAsyncBound action wait
-- | Collect every row of every result set produced by @stmt@, in order.
--
-- Each result set is described separately, because after SQLMoreResults
-- the columns may differ from the previous set. A result set with no
-- columns has no rows, so it is skipped without calling SQLFetch.
-- Unexpected codes from SQLFetch or SQLMoreResults throw
-- 'UnsuccessfulReturnCode' instead of silently ending the results.
fetchStatementRows :: SQLHSTMT s -> IO [[Value]]
fetchStatementRows stmt = resultSet id
  where
    -- Describe the current result set, then drain its rows.
    -- @rows@ is a difference list of everything collected so far.
    resultSet rows = do
      cols <- numResultCols
      types <- mapM (describeColumn stmt) [1 .. cols]
      if null types
        then nextResultSet rows
        else fetchLoop types rows
    -- Fetch rows of the current result set until SQL_NO_DATA.
    fetchLoop types rows = do
      retcode <- odbc_SQLFetch stmt
      if | retcode == sql_no_data -> nextResultSet rows
         | isSuccess retcode -> do
             fields <- sequence (zipWith (getData stmt) [1 ..] types)
             fetchLoop types (rows . (fields :))
         | otherwise ->
             throwIO (UnsuccessfulReturnCode "odbc_SQLFetch" retcode)
    -- Advance to the next result set, or finish on SQL_NO_DATA.
    nextResultSet rows = do
      retcode <- odbc_SQLMoreResults stmt
      if | retcode == sql_no_data -> pure (rows [])
         | isSuccess retcode -> resultSet rows
         | otherwise ->
             throwIO (UnsuccessfulReturnCode "odbc_SQLMoreResults" retcode)
    -- Number of columns in the current result set.
    numResultCols =
      withMalloc
        (\sizep -> do
           assertSuccess
             "odbc_SQLNumResultCols"
             (odbc_SQLNumResultCols stmt sizep)
           SQLSMALLINT n <- peek sizep
           pure n)
    isSuccess r = r == sql_success || r == sql_success_with_info
-- | Describe column @i@ (1-based) of the current result set of @stmt@.
--
-- Keeps the SQL type, size, decimal digits and nullability code. The
-- column name is written into a scratch buffer and discarded.
--
-- The scratch buffer is an explicit zeroed byte allocation. It allows 4
-- bytes per character plus a terminator, so it cannot overflow whether
-- the driver's SQLWCHAR is 2 or 4 bytes wide. It does not depend on the
-- internal encoding of 'Text'.
describeColumn :: SQLHSTMT s -> Int16 -> IO Column
describeColumn stmt i =
  withCallocBytes
    nameBufferBytes
    (\namep ->
       withMalloc
         (\namelenp ->
            withMalloc
              (\typep ->
                 withMalloc
                   (\sizep ->
                      withMalloc
                        (\digitsp ->
                           withMalloc
                             (\nullp -> do
                                assertSuccess
                                  "odbc_SQLDescribeColW"
                                  (odbc_SQLDescribeColW
                                     stmt
                                     (SQLUSMALLINT (fromIntegral i))
                                     namep
                                     (SQLSMALLINT nameBufferChars)
                                     namelenp
                                     typep
                                     sizep
                                     digitsp
                                     nullp)
                                typ <- peek typep
                                size <- peek sizep
                                digits <- peek digitsp
                                isnull <- peek nullp
                                evaluate
                                  Column
                                    { columnType = typ
                                    , columnSize = size
                                    , columnDigits = digits
                                    , columnNull = isnull
                                    }))))))
  where
    -- Capacity in characters that the driver is told about.
    nameBufferChars :: Int16
    nameBufferChars = 1000
    -- Bytes actually allocated: up to 4 bytes per character, plus a terminator.
    nameBufferBytes :: Int
    nameBufferBytes = (fromIntegral nameBufferChars + 1) * 4
-- | Fetch column @i@ (1-based) of the current row of @stmt@ as a 'Value'.
-- The C target type is chosen from the column's reported SQL type code.
-- Unmapped type codes throw 'UnknownType'.
--
-- Wide text (type -9) is read with SQLGetData in bounded chunks until the
-- driver reports no more data. Only the bytes actually copied into the
-- buffer are read back, so truncated data or SQL_NO_TOTAL can no longer
-- cause a read past the end of the buffer.
--
-- NOTE(review): the bit and numeric branches never inspect the indicator
-- written to @ignored@. A SQL NULL in those columns therefore yields
-- whatever bytes were in the uninitialised buffer. 'Value' has no null
-- constructor to report it with; confirm the intended behaviour.
getData :: SQLHSTMT s -> Int -> Column -> IO Value
getData stmt i col =
  case columnType col of
    SQLSMALLINT (-9) ->
      -- Chunks are concatenated before UTF-16LE decoding, so a surrogate
      -- pair split across two chunks still decodes correctly.
      withCallocBytes
        bufferBytes
        (\bufferp ->
           withMalloc
             (\indPtr -> do
                bs <- fetchWideChunks bufferp indPtr []
                evaluate (TextValue (T.decodeUtf16LE bs))))
    SQLSMALLINT (-7) ->
      withMalloc
        (\ignored ->
           withMalloc
             (\bitPtr -> do
                apply (columnType col) (coerce bitPtr) (SQLLEN 1) ignored
                fmap (BoolValue . (/= (0 :: Word8))) (peek bitPtr)))
    -- Type codes 6, 7 and 8 (float/real/double) are all read as C double (8).
    SQLSMALLINT 6 -> getDouble
    SQLSMALLINT 7 -> getDouble
    SQLSMALLINT 8 -> getDouble
    SQLSMALLINT 4 ->
      withMalloc
        (\intPtr ->
           withMalloc
             (\ignored -> do
                apply (columnType col) (coerce intPtr) (SQLLEN 4) ignored
                fmap (IntValue . fromIntegral) (peek (intPtr :: Ptr Int32))))
    SQLSMALLINT 5 ->
      withMalloc
        (\intPtr ->
           withMalloc
             (\ignored -> do
                apply (columnType col) (coerce intPtr) (SQLLEN 2) ignored
                fmap (IntValue . fromIntegral) (peek (intPtr :: Ptr Int16))))
    _ ->
      throwIO
        (UnknownType
           (let SQLSMALLINT n = columnType col
             in n))
  where
    -- One SQLGetData call for this column. Success codes are required.
    apply ty bufferp bufferlen strlenp =
      assertSuccess
        "odbc_SQLGetData"
        (odbc_SQLGetData
           stmt
           (SQLUSMALLINT (fromIntegral i))
           ty
           bufferp
           bufferlen
           strlenp)
    -- Read the column as a C double (type 8).
    getDouble =
      withMalloc
        (\doublePtr ->
           withMalloc
             (\ignored -> do
                apply (SQLSMALLINT 8) (coerce doublePtr) (SQLLEN 8) ignored
                !d <- fmap DoubleValue (peek doublePtr)
                pure d))
    -- Characters requested per call. This is the reported column size,
    -- clamped below so that a size of 0 (reportedly used by some drivers
    -- for unbounded columns) still makes progress. It is clamped above so
    -- that a huge reported size does not force a huge allocation.
    chunkChars :: Word64
    chunkChars = max 1024 (min 1048576 (coerce (columnSize col)))
    -- Two bytes per character plus a two-byte terminator.
    bufferBytes :: Int
    bufferBytes = fromIntegral (chunkChars * 2 + 2)
    -- Bytes of character data held by a completely filled buffer.
    dataBytes :: Int
    dataBytes = bufferBytes - 2
    -- ODBC indicator value SQL_NO_TOTAL.
    sqlNoTotal :: Int64
    sqlNoTotal = -4
    -- Repeatedly fetch the column as C type -8 (wide char), accumulating
    -- chunks in reverse, until one call is not truncated or SQL_NO_DATA
    -- signals the end.
    fetchWideChunks ::
         Ptr CChar -> Ptr SQLLEN -> [S.ByteString] -> IO S.ByteString
    fetchWideChunks bufferp indPtr acc = do
      retcode <-
        odbc_SQLGetData
          stmt
          (SQLUSMALLINT (fromIntegral i))
          (SQLSMALLINT (-8))
          (coerce bufferp)
          (SQLLEN (fromIntegral bufferBytes))
          indPtr
      if | retcode == sql_no_data && not (null acc) ->
             pure (S.concat (reverse acc))
         | retcode == sql_success || retcode == sql_success_with_info -> do
             SQLLEN ind <- peek indPtr
             -- When truncated, the indicator holds the remaining total (or
             -- SQL_NO_TOTAL), not the number of bytes copied. Only
             -- dataBytes bytes are valid in that case.
             let truncated = ind == sqlNoTotal || ind > fromIntegral dataBytes
                 copied
                   | truncated = dataBytes
                   | otherwise = fromIntegral ind
             -- A NULL (negative indicator) is passed through unchanged,
             -- so packCStringLen rejects it with a negative-length error.
             chunk <- S.packCStringLen (bufferp, copied)
             if truncated
               then fetchWideChunks bufferp indPtr (chunk : acc)
               else pure (S.concat (reverse (chunk : acc)))
         | otherwise ->
             throwIO (UnsuccessfulReturnCode "odbc_SQLGetData" retcode)
-- | Run an allocation action and return its handle. Throws
-- 'AllocationReturnedNull' (tagged with @label@) if the handle is null.
assertNotNull :: (Coercible a (Ptr ())) => String -> IO a -> IO a
assertNotNull label action = do
  result <- action
  if coerce result /= nullPtr
    then pure result
    else throwIO (AllocationReturnedNull label)
-- | Run an ODBC call. Throws 'UnsuccessfulReturnCode' (tagged with
-- @label@) unless it returns success or success-with-info.
assertSuccess :: String -> IO RETCODE -> IO ()
assertSuccess label action =
  action >>= \rc ->
    if rc `elem` [sql_success, sql_success_with_info]
      then pure ()
      else throwIO (UnsuccessfulReturnCode label rc)
-- | Opaque environment handle type; only used behind a 'Ptr'.
data SQLHENV

-- | Opaque connection handle type; only used behind a 'Ptr'.
data SQLHDBC

-- | Statement handle. 'withStmt' and 'withExecDirect' quantify the phantom
-- @s@ in their continuation types.
newtype SQLHSTMT s = SQLHSTMT (Ptr (SQLHSTMT s))

-- | Untyped data buffer passed to SQLGetData.
newtype SQLPOINTER = SQLPOINTER (Ptr SQLPOINTER)

-- | Return code from an ODBC call.
newtype RETCODE = RETCODE Int16
  deriving (Eq, Show)

newtype SQLUSMALLINT = SQLUSMALLINT Word16
  deriving (Eq, Show, Storable)

newtype SQLUCHAR = SQLUCHAR Word8
  deriving (Eq, Show, Storable)

newtype SQLCHAR = SQLCHAR CChar
  deriving (Eq, Show, Storable)

newtype SQLSMALLINT = SQLSMALLINT Int16
  deriving (Eq, Show, Storable, Num)

newtype SQLLEN = SQLLEN Int64
  deriving (Eq, Show, Storable, Num)

newtype SQLULEN = SQLULEN Word64
  deriving (Eq, Show, Storable)

-- NOTE(review): ODBC C headers declare SQLINTEGER as 32-bit. Confirm the
-- C shim's parameter type matches this 64-bit representation.
newtype SQLINTEGER = SQLINTEGER Int64
  deriving (Eq, Show, Storable, Num)

-- | Despite its name, this wraps a /pointer/ to wide characters
-- ('CWString'), not a single character.
newtype SQLWCHAR = SQLWCHAR CWString
  deriving (Eq, Show, Storable)
-- Bindings to the odbc_-prefixed wrapper functions (presumably a thin C
-- shim around the ODBC API; confirm against the C sources). The "&"
-- imports yield function pointers; 'connect' uses them as 'ForeignPtr'
-- finalizers.

foreign import ccall "odbc odbc_SQLAllocEnv"
  odbc_SQLAllocEnv
    :: IO (Ptr SQLHENV)

foreign import ccall "odbc &odbc_SQLFreeEnv"
  odbc_SQLFreeEnv
    :: FunPtr (Ptr SQLHENV -> IO ())

foreign import ccall "odbc odbc_SetEnvAttr"
  odbc_SetEnvAttr
    :: Ptr SQLHENV
    -> IO RETCODE

foreign import ccall "odbc odbc_SQLAllocDbc"
  odbc_SQLAllocDbc
    :: Ptr SQLHENV
    -> IO (Ptr SQLHDBC)

foreign import ccall "odbc &odbc_SQLFreeDbc"
  odbc_SQLFreeDbc
    :: FunPtr (Ptr SQLHDBC -> IO ())

foreign import ccall "odbc odbc_SQLDriverConnect"
  odbc_SQLDriverConnectW
    :: Ptr SQLHDBC
    -> SQLWCHAR
    -> SQLSMALLINT
    -> IO RETCODE

foreign import ccall "odbc &odbc_SQLDisconnect"
  odbc_SQLDisconnect
    :: FunPtr (Ptr SQLHDBC -> IO ())

foreign import ccall "odbc odbc_SQLAllocStmt"
  odbc_SQLAllocStmt
    :: Ptr SQLHDBC
    -> IO (SQLHSTMT s)

foreign import ccall "odbc odbc_SQLFreeStmt"
  odbc_SQLFreeStmt
    :: SQLHSTMT s
    -> IO ()

foreign import ccall "odbc odbc_SQLExecDirectW"
  odbc_SQLExecDirectW
    :: SQLHSTMT s
    -> SQLWCHAR
    -> SQLINTEGER
    -> IO RETCODE

foreign import ccall "odbc odbc_SQLFetch"
  odbc_SQLFetch
    :: SQLHSTMT s
    -> IO RETCODE

foreign import ccall "odbc odbc_SQLMoreResults"
  odbc_SQLMoreResults
    :: SQLHSTMT s
    -> IO RETCODE

foreign import ccall "odbc odbc_SQLNumResultCols"
  odbc_SQLNumResultCols
    :: SQLHSTMT s
    -> Ptr SQLSMALLINT
    -> IO RETCODE

foreign import ccall "odbc odbc_SQLGetData"
  odbc_SQLGetData
    :: SQLHSTMT s
    -> SQLUSMALLINT
    -> SQLSMALLINT
    -> SQLPOINTER
    -> SQLLEN
    -> Ptr SQLLEN
    -> IO RETCODE

foreign import ccall "odbc odbc_SQLDescribeColW"
  odbc_SQLDescribeColW
    :: SQLHSTMT s
    -> SQLUSMALLINT
    -> Ptr SQLWCHAR
    -> SQLSMALLINT
    -> Ptr SQLSMALLINT
    -> Ptr SQLSMALLINT
    -> Ptr SQLULEN
    -> Ptr SQLSMALLINT
    -> Ptr SQLSMALLINT
    -> IO RETCODE
-- | Heap-allocate (uninitialised) space for one @a@, run the action with
-- it, and free it afterwards, even if the action throws.
withMalloc :: Storable a => (Ptr a -> IO b) -> IO b
withMalloc = bracket malloc free
-- | Allocate @n@ zero-filled bytes, run the action with the pointer, and
-- free it afterwards, even if the action throws.
withCallocBytes :: Storable a => Int -> (Ptr a -> IO b) -> IO b
withCallocBytes size action = bracket (callocBytes size) free action