{-# LANGUAGE OverloadedStrings #-}
{-# LANGUAGE TypeFamilies #-}
{-# LANGUAGE CPP #-}
{-# LANGUAGE FlexibleContexts #-}
{-# LANGUAGE MultiParamTypeClasses, ScopedTypeVariables #-}
-- | An ODBC backend for persistent.
module Database.Persist.ODBC
    ( withODBCPool
    , withODBCConn
    , createODBCPool
    , module Database.Persist.Sql
    , ConnectionString
    , OdbcConf (..)
    , openSimpleConn
    , DBType (..)
    , mysql,postgres,mssqlMin2012,mssql,oracleMin12c,oracle,db2,sqlite
    ) where

-- base
import qualified Control.Exception as E -- (catch, SomeException)
import Control.Monad (mzero)
import Control.Monad.IO.Class (MonadIO (..))
import Data.Int (Int64)
import Data.IORef (newIORef)
import qualified Data.List as L

-- external packages
import Control.Monad.Logger
import Control.Monad.Trans.Control (MonadBaseControl)
--import Control.Monad.Trans.Resource (MonadResource)
import Data.Acquire (Acquire, mkAcquire)
import Data.Aeson -- (Object(..), (.:))
import Data.Conduit
import qualified Data.Convertible as DC
import qualified Data.Map as Map
import Data.Text (Text)
import qualified Data.Text as T
import Data.Time (ZonedTime(..))
import Data.Time.LocalTime (localTimeToUTC, utc)
import qualified Database.HDBC as O
import qualified Database.HDBC.ODBC as O
import qualified Database.HDBC.SqlValue as HSV
import Database.Persist.Sql

-- this package
import qualified Database.Persist.MigrateDB2 as DB2
import qualified Database.Persist.MigrateMSSQL as MSSQL
import qualified Database.Persist.MigrateMySQL as MYSQL
import qualified Database.Persist.MigrateOracle as ORACLE
import qualified Database.Persist.MigratePostgres as PG
import qualified Database.Persist.MigrateSqlite as SQLITE
import Database.Persist.ODBCTypes

-- | An @HDBC-odbc@ connection string, for example @DSN=hdbctest1@.
type ConnectionString = String

-- | Create an ODBC connection pool and run the given action with it.
-- The pool is released once the action finishes. Do not let the
-- 'ConnectionPool' escape the action, because by the time you use it
-- the pool may already have been released.
withODBCPool :: (MonadBaseControl IO m, MonadLogger m, MonadIO m)
             => Maybe DBType            -- ^ Database type; 'Nothing' detects it from the connection metadata.
             -> ConnectionString        -- ^ Connection string to the database.
             -> Int                     -- ^ Number of connections to be kept open in the pool.
             -> (ConnectionPool -> m a) -- ^ Action to be executed that uses the connection pool.
             -> m a
withODBCPool dbt ci = withSqlPool connect
  where
    -- builds one pooled backend from persistent's log function
    connect logFn = open' logFn dbt ci

-- | Create an ODBC connection pool whose lifetime is not tied to any
-- action. Closing it when it is no longer needed is the caller's job.
-- Prefer 'withODBCPool' for automatic resource control.
createODBCPool :: (MonadLogger m, MonadIO m, MonadBaseControl IO m)
               => Maybe DBType     -- ^ Database type; 'Nothing' detects it from the connection metadata.
               -> ConnectionString -- ^ Connection string to the database.
               -> Int              -- ^ Number of connections to be kept open in the pool.
               -> m ConnectionPool
createODBCPool dbt ci = createSqlPool connect
  where
    connect logFn = open' logFn dbt ci

-- | Same as 'withODBCPool', but instead of opening a pool
-- of connections, only one connection is opened.
withODBCConn :: (MonadLogger m, MonadIO m, MonadBaseControl IO m)
             => Maybe DBType -> ConnectionString -> (SqlBackend -> m a) -> m a
withODBCConn dbt cs = withSqlConn connect
  where
    connect logFn = open' logFn dbt cs

-- | Open a raw ODBC connection for the connection string and wrap it in a
-- 'SqlBackend' via 'openSimpleConn'. The given 'DBType' is used when present;
-- with 'Nothing' the type is detected from the connection.
open' :: LogFunc -> Maybe DBType -> ConnectionString -> IO SqlBackend
open' logFunc mdbtype cstr = do
    conn <- O.connectODBC cstr
    openSimpleConn logFunc mdbtype conn

-- | Guess the 'DBType' from the connection metadata triple
-- @(client name, client version, server version)@. It is used when the
-- caller did not pass a database type explicitly.
--
-- For Oracle and SQL Server it also records whether the server major
-- version is at least 12 and 11 respectively. An unrecognised driver name
-- is a fatal 'error' asking the caller to set the 'DBType' explicitly.
findDBMS :: (String, String, String) -> DBType
findDBMS dvs@(driver, ver, serverver)
  | isDriver "Oracle"                = Oracle (serverAtLeast 12)
  | "DB2" `L.isPrefixOf` driver      = DB2
  | isDriver "Microsoft SQL Server"  = MSSQL (serverAtLeast 11)
  | isDriver "MySQL"                 = MySQL
  | "PostgreSQL" `L.isPrefixOf` driver = Postgres
  | "SQLite" `L.isPrefixOf` driver   = Sqlite False
  | otherwise = error $ "unknown or unsupported driver[" ++ driver ++ "] ver[" ++ ver ++ "] serverver[" ++ serverver ++ "]\nExplicitly set the type of dbms using DBType and try again!"
  where
    isDriver = (== driver)
    -- the version is only parsed when a branch actually needs it
    serverAtLeast n = getServerVersionNumber dvs >= n
-- | Extract the major server version number, i.e. everything before the
-- first @.@ of the server version string. Fails with 'error' (mentioning
-- the whole metadata triple) when that prefix is not a plain integer.
getServerVersionNumber :: (String, String, String) -> Integer
getServerVersionNumber (driver, ver, serverver) =
    case reads $ takeWhile (/='.') serverver of
      [(a,"")] -> a
      xs -> error $ "getServerVersionNumber of findDBMS:cannot tell the version xs=" ++ show xs ++ ":" ++ "driver[" ++ driver ++ "] ver[" ++ ver ++ "] serverver[" ++ serverver ++ "]"

-- | Generate a persistent 'SqlBackend' from an odbc 'O.Connection'.
--
-- When no 'DBType' is supplied, it is detected from the connection's client
-- name, client version and server version (see 'findDBMS'). The detected or
-- supplied type selects the migration strategy, which provides insert SQL,
-- migrations, name escaping and LIMIT/OFFSET handling.
openSimpleConn :: LogFunc -> Maybe DBType -> O.Connection -> IO SqlBackend
openSimpleConn logFunc mdbtype conn = do
    let dbtype = case mdbtype of
          Nothing -> findDBMS (O.proxiedClientName conn, O.proxiedClientVer conn, O.dbServerVer conn)
          Just t  -> t
        mig = getMigrationStrategy dbtype
    smap <- newIORef Map.empty
    return SqlBackend
        { connLogFunc     = logFunc
        , connPrepare     = prepare' conn
        , connStmtMap     = smap
        , connInsertSql   = dbmsInsertSql mig
        , connClose       = O.disconnect conn
        , connMigrateSql  = dbmsMigrate mig
          -- There are no nested transactions: beginning a transaction means
          -- committing the previous one. A failing commit is ignored (best
          -- effort), but asynchronous exceptions are re-thrown, not swallowed.
        , connBegin       = const $ E.catch (O.commit conn) ignoreSyncException
        , connCommit      = const $ O.commit conn
        , connRollback    = const $ O.rollback conn
        , connEscapeName  = dbmsEscape mig
          -- esqueleto uses this but needs to use connLimitOffset; then we can dump this field
        , connNoLimit     = ""
        , connRDBMS       = T.pack $ show (dbmsType mig)
        , connLimitOffset = dbmsLimitOffset mig
        }
  where
    -- Swallow ordinary (synchronous) failures of the implicit commit, but
    -- never async ones such as thread cancellation or timeouts.
    ignoreSyncException :: E.SomeException -> IO ()
    ignoreSyncException e =
        case (E.fromException e :: Maybe E.SomeAsyncException) of
          Just _  -> E.throwIO e
          Nothing -> return ()

-- | Choose the migration strategy matching the given database type.
getMigrationStrategy :: DBType -> MigrationStrategy
getMigrationStrategy dbtype = case dbtype of
    Postgres  -> PG.getMigrationStrategy dbtype
    MySQL     -> MYSQL.getMigrationStrategy dbtype
    MSSQL {}  -> MSSQL.getMigrationStrategy dbtype
    Oracle {} -> ORACLE.getMigrationStrategy dbtype
    DB2 {}    -> DB2.getMigrationStrategy dbtype
    Sqlite {} -> SQLITE.getMigrationStrategy dbtype

-- | Prepare an SQL statement on the ODBC connection and expose it as a
-- persistent 'Statement'.
prepare' :: O.Connection -> Text -> IO Statement
prepare' conn sql = do
#if DEBUG
    putStrLn $ "Database.Persist.ODBC.prepare': sql = " ++ T.unpack sql
#endif
    stmt <- O.prepare conn $ T.unpack sql
    return Statement
        { stmtFinalize = O.finish stmt
        , stmtReset    = return () -- rollback conn ?
        , stmtExecute  = execute' stmt
        , stmtQuery    = withStmt' stmt
        }

-- | Execute a prepared statement with the given parameters (converted
-- through 'P') and return the count HDBC's 'O.execute' reports as 'Int64'.
execute' :: O.Statement -> [PersistValue] -> IO Int64
execute' query vals = fmap fromInteger $ O.execute query $ map (HSV.toSql . P) vals

-- | Run a prepared statement with the given parameters and stream its rows.
-- Executing the statement is the acquire step of the returned 'Acquire' and
-- 'O.finish' is its release step. Each fetched row is converted to
-- @[PersistValue]@ via the 'P' conversion instances.
withStmt' :: MonadIO m => O.Statement -> [PersistValue] -> Acquire (Source m [PersistValue])
withStmt' stmt vals = do
#if DEBUG
    liftIO $ putStrLn $ "withStmt': vals: " ++ show vals
#endif
    _ <- mkAcquire openS closeS
    return pullRows
  where
    openS = execute' stmt vals >> return ()
    closeS _ = O.finish stmt
    -- fetch rows one at a time until the driver reports no more
    pullRows = do
        mrow <- liftIO $ O.fetchRow stmt
        case mrow of
          Nothing -> return ()
          Just row -> do
#if DEBUG
              liftIO $ putStrLn $ "withStmt': yield: " ++ show row
              liftIO $ putStrLn $ "withStmt': yield2: " ++ show (map (unP . HSV.fromSql) row)
#endif
              yield (map (unP . HSV.fromSql) row)
              pullRows

-- | Information required to connect to a database over ODBC
-- using @persistent@'s generic facilities.
-- These values are the same that are given to 'withODBCPool'.
data OdbcConf = OdbcConf
    { odbcConnStr  :: ConnectionString -- ^ The connection string.
    , odbcPoolSize :: Int -- ^ How many connections should be held on the connection pool.
    , odbcDbtype   :: String
      -- ^ The database type in 'Read' syntax, either as a @Maybe DBType@
      -- (e.g. @Nothing@ to auto-detect, or @Just Postgres@) or as a bare
      -- 'DBType' (e.g. @Postgres@).
    }

instance PersistConfig OdbcConf where
    type PersistConfigBackend OdbcConf = SqlPersistT
    type PersistConfigPool OdbcConf = ConnectionPool
    createPoolConfig (OdbcConf cs size dbtype) =
        runNoLoggingT $ createODBCPool (parseOdbcDbtype dbtype) cs size
    runPool _ = runSqlPool
    loadConfig (Object o) = do
        cstr   <- o .: "connStr"
        pool   <- o .: "poolsize"
        -- a missing "dbtype" means: detect the database type from the connection
        dbtype <- o .:? "dbtype" .!= "Nothing"
        return $ OdbcConf cstr pool dbtype
    loadConfig _ = mzero
    applyEnv c0 = return c0

-- | Parse the 'odbcDbtype' configuration string.
--
-- The 'Read' syntax of @Maybe DBType@ (e.g. @Nothing@, @Just Postgres@) is
-- still accepted. A bare 'DBType' (e.g. @Postgres@) is also accepted and
-- taken to mean @Just@ that type. Anything else is a fatal configuration
-- 'error' that names the offending value, instead of @Prelude.read: no parse@.
parseOdbcDbtype :: String -> Maybe DBType
parseOdbcDbtype s =
    case readAll s :: [Maybe DBType] of
      [mdbt] -> mdbt
      _ -> case readAll s :: [DBType] of
             [dbt] -> Just dbt
             _ -> error $ "Database.Persist.ODBC: cannot parse dbtype " ++ show s
                       ++ "; expected e.g. \"Nothing\", \"Just Postgres\" or \"Postgres\""
  where
    -- complete parses only (trailing whitespace allowed), as 'read' requires
    readAll :: Read a => String -> [a]
    readAll str = [x | (x, rest) <- reads str, ("", "") <- lex rest]

-- | Avoid orphan instances.
newtype P = P { unP :: PersistValue } deriving Show

-- | Convert persistent values into HDBC values for binding as statement
-- parameters. Booleans are sent as the integers 1/0, rationals as 'Double',
-- and lists/maps as their JSON text. Object ids and any constructor not
-- handled here are rejected with a 'DC.ConvertError'.
instance DC.Convertible P HSV.SqlValue where
    safeConvert (P (PersistText t))       = Right $ HSV.toSql t
    safeConvert (P (PersistByteString bs)) = Right $ HSV.toSql bs
    safeConvert (P (PersistInt64 i))      = Right $ HSV.toSql i
    safeConvert (P (PersistRational r))   = Right $ HSV.toSql (fromRational r::Double)
    safeConvert (P (PersistDouble d))     = Right $ HSV.toSql d
    safeConvert (P (PersistBool b))       = Right $ HSV.SqlInteger (if b then 1 else 0)
    safeConvert (P (PersistDay d))        = Right $ HSV.toSql d
    safeConvert (P (PersistTimeOfDay t))  = Right $ HSV.toSql t
    safeConvert (P (PersistUTCTime t))    = Right $ HSV.toSql t
--  safeConvert (P (PersistZonedTime (ZT t))) = Right $ HSV.toSql t
    safeConvert (P PersistNull)           = Right HSV.SqlNull
    safeConvert (P (PersistList l))       = Right $ HSV.toSql $ listToJSON l
    safeConvert (P (PersistMap m))        = Right $ HSV.toSql $ mapToJSON m
    safeConvert p@(P (PersistObjectId _)) = Left DC.ConvertError
        { DC.convSourceValue   = show p
        , DC.convSourceType    = "P (PersistValue)"
        , DC.convDestType      = "SqlValue"
        , DC.convErrorMessage  = "Refusing to serialize a PersistObjectId to an ODBC value"
        }
    -- safeConvert must not throw: report unhandled constructors as a Left
    safeConvert xs = Left DC.ConvertError
        { DC.convSourceValue   = show xs
        , DC.convSourceType    = "P (PersistValue)"
        , DC.convDestType      = "SqlValue"
        , DC.convErrorMessage  = "unhandled safeConvert xs=" ++ show xs
        }

-- FIXME: check if those are correct and complete.
-- | Convert HDBC values returned by the driver into persistent values.
--
-- Unbounded and unsigned integers become 'PersistInt64' when they fit and an
-- exact 'PersistRational' otherwise, so out-of-range values are never
-- silently wrapped. A zoned timestamp is converted to UTC using its own
-- offset. A local time without a zone is interpreted as UTC.
instance DC.Convertible HSV.SqlValue P where
    safeConvert (HSV.SqlString s)                 = Right $ P $ PersistText $ T.pack s
    safeConvert (HSV.SqlByteString bs)            = Right $ P $ PersistByteString bs
    safeConvert (HSV.SqlWord32 w)                 = Right $ P $ integerToPersistValue $ toInteger w
    safeConvert (HSV.SqlWord64 w)                 = Right $ P $ integerToPersistValue $ toInteger w
    safeConvert (HSV.SqlInt32 i)                  = Right $ P $ PersistInt64 $ fromIntegral i
    safeConvert (HSV.SqlInt64 i)                  = Right $ P $ PersistInt64 i
    safeConvert (HSV.SqlInteger i)                = Right $ P $ integerToPersistValue i
    safeConvert (HSV.SqlChar c)                   = Right $ P $ charChk c
    safeConvert (HSV.SqlBool b)                   = Right $ P $ PersistBool b
    safeConvert (HSV.SqlDouble d)                 = Right $ P $ PersistDouble d
    safeConvert (HSV.SqlRational r)               = Right $ P $ PersistRational r
    safeConvert (HSV.SqlLocalDate d)              = Right $ P $ PersistDay d
    safeConvert (HSV.SqlLocalTimeOfDay t)         = Right $ P $ PersistTimeOfDay t
    -- the zone is dropped: a bare time of day carries no date to shift
    safeConvert (HSV.SqlZonedLocalTimeOfDay td _) = Right $ P $ PersistTimeOfDay td
    safeConvert (HSV.SqlLocalTime t)              = Right $ P $ PersistUTCTime $ localTimeToUTC utc t
    -- Apply the value's own offset (same as zonedTimeToUTC). Reading the
    -- wall-clock part as UTC would shift the instant by the offset.
    safeConvert (HSV.SqlZonedTime zt)             = Right $ P $ PersistUTCTime $ localTimeToUTC (zonedTimeZone zt) (zonedTimeToLocalTime zt)
    safeConvert (HSV.SqlUTCTime t)                = Right $ P $ PersistUTCTime t
    safeConvert (HSV.SqlDiffTime ndt)             = Right $ P $ PersistDouble $ fromRational $ toRational ndt
    safeConvert (HSV.SqlPOSIXTime pt)             = Right $ P $ PersistDouble $ fromRational $ toRational pt
    safeConvert (HSV.SqlEpochTime e)              = Right $ P $ PersistInt64 $ fromIntegral e
    safeConvert (HSV.SqlTimeDiff i)               = Right $ P $ PersistInt64 $ fromIntegral i
    safeConvert HSV.SqlNull                       = Right $ P PersistNull

-- | Map an unbounded integer to a persistent value without overflow:
-- 'PersistInt64' when it lies in the 'Int64' range, otherwise an exact
-- 'PersistRational'.
integerToPersistValue :: Integer -> PersistValue
integerToPersistValue i
  | i >= toInteger (minBound :: Int64) && i <= toInteger (maxBound :: Int64) = PersistInt64 (fromInteger i)
  | otherwise = PersistRational (fromInteger i)

-- | Interpret a single character returned by the driver. The characters
-- with codes 0 and 1 become 'PersistBool' False and True respectively.
-- Any other character becomes one-character 'PersistText'.
charChk :: Char -> PersistValue
charChk '\0' = PersistBool False
charChk '\1' = PersistBool True
charChk c    = PersistText $ T.singleton c