{-# LANGUAGE BangPatterns #-}
{-# LANGUAGE DeriveDataTypeable #-}
-- |
-- A slightly lower-level version of "Database.SQLite3".  In particular:
--
-- * It returns errors instead of throwing them.
--
-- * It only uses cheap conversions.  None of these bindings convert from
--   'String' or 'T.Text'.
module Database.SQLite3.Direct
    ( -- * Connection management
      open
    , close
    , errmsg
    , setTrace
    , getAutoCommit

      -- * Simple query execution
      -- | Run SQL text directly, without preparing a 'Statement' first.
    , exec
    , execWithCallback
    , ExecCallback

      -- * Statement management
    , prepare
    , getStatementDatabase
    , step
    , reset
    , finalize
    , clearBindings
    , statementSql

      -- * Parameter and column information
    , bindParameterCount
    , bindParameterName
    , columnCount
    , columnName

      -- * Binding values to a prepared statement
      -- | Each binder reports failure as an 'Error' rather than throwing.
    , bindInt64
    , bindDouble
    , bindText
    , bindBlob
    , bindNull

      -- * Reading the result row
      -- | Typed accessors for result columns, addressed by 'ColumnIndex'.
    , columnType
    , columnInt64
    , columnDouble
    , columnText
    , columnBlob

      -- * Result statistics
    , lastInsertRowId
    , changes
    , totalChanges

      -- * Interrupting a long-running query
    , interrupt

      -- * Types
    , Database(..)
    , Statement(..)
    , ColumnType(..)

      -- ** Results and errors
    , StepResult(..)
    , Error(..)

      -- ** Special types
    , Utf8(..)
    , ParamIndex(..)
    , ColumnIndex(..)
    , ColumnCount
    ) where

import Database.SQLite3.Bindings

import qualified Data.ByteString        as BS
import qualified Data.ByteString.Unsafe as BSU
import qualified Data.Text              as T
import qualified Data.Text.Encoding     as T

import Control.Applicative      ((<$>))
import Control.Exception        as E
import Control.Monad            (join)
import Data.ByteString          (ByteString)
import Data.IORef
import Data.Monoid
import Data.String              (IsString(..))
import Data.Text.Encoding.Error (lenientDecode)
import Foreign
import Foreign.C

-- | Handle to an open SQLite connection (wraps a @Ptr CDatabase@).
newtype Database = Database (Ptr CDatabase)
    deriving (Eq, Show)

-- | Handle to a prepared statement (wraps a @Ptr CStatement@).
newtype Statement = Statement (Ptr CStatement)
    deriving (Eq, Show)

-- | Successful outcome of 'step': either a result row is available
-- ('Row') or evaluation has finished ('Done').
data StepResult
    = Row
    | Done
    deriving (Eq, Show)

-- | A 'ByteString' containing UTF8-encoded text with no NUL characters.
--
-- The invariant is not checked anywhere: the 'Utf8' constructor is exported,
-- so callers are trusted to supply well-formed, NUL-free bytes.
newtype Utf8 = Utf8 ByteString
    deriving (Eq, Ord)

-- | Decodes leniently, so invalid UTF-8 never makes 'show' fail.
instance Show Utf8 where
    show (Utf8 s) = (show . T.decodeUtf8With lenientDecode) s

-- | @fromString = Utf8 . 'T.encodeUtf8' . 'T.pack'@
instance IsString Utf8 where
    fromString = Utf8 . T.encodeUtf8 . T.pack

-- | Concatenation of the underlying bytes.
--
-- Since base-4.11 (GHC 8.4) 'Semigroup' is a superclass of 'Monoid', so a
-- 'Monoid' instance without this one does not compile.
instance Semigroup Utf8 where
    Utf8 a <> Utf8 b = Utf8 (BS.append a b)

instance Monoid Utf8 where
    mempty  = Utf8 BS.empty
    mappend = (<>)
    mconcat = Utf8 . BS.concat . map (\(Utf8 s) -> s)

-- | Wrap a NUL-terminated C string as 'Utf8' and pass it to the
-- continuation @f@.  A NULL pointer yields the default @n@ instead.
packUtf8 :: a -> (Utf8 -> a) -> CString -> IO a
packUtf8 n f cstr
    | cstr == nullPtr = return n
    | otherwise       = f . Utf8 <$> BS.packCString cstr

-- | Copy @len@ bytes starting at @cstr@ into a fresh 'ByteString'.
--
-- A zero length short-circuits to 'BS.empty' without touching @cstr@.
-- SQLite hands back a NULL pointer for zero-length BLOBs (and NULL columns),
-- and copying even zero bytes from NULL via @memcpy@ is undefined in C.
packCStringLen :: CString -> CNumBytes -> IO ByteString
packCStringLen cstr len
    | len == 0  = return BS.empty
    | otherwise = BS.packCStringLen (cstr, fromIntegral len)

-- | Read @count@ C strings from the array at @base@, handling each with
-- 'packUtf8': NULL entries run @onNull@, the others run @onUtf8@.
packUtf8Array :: IO a -> (Utf8 -> IO a) -> Int -> Ptr CString -> IO [a]
packUtf8Array onNull onUtf8 count base =
    peekArray count base >>= mapM (join . packUtf8 onNull onUtf8)

-- | Like 'unsafeUseAsCStringLen', but if the string is empty,
-- never pass the callback a null pointer.
--
-- The SQLite bind functions treat a NULL data pointer as SQL NULL, so an
-- empty string must be passed as a non-NULL pointer with length 0.  The
-- dummy address is never read, because the length is 0.
unsafeUseAsCStringLenNoNull
    :: ByteString
    -> (CString -> CNumBytes -> IO a)
    -> IO a
unsafeUseAsCStringLenNoNull bs cb
    | BS.null bs = cb (intPtrToPtr 1) 0
    | otherwise  = BSU.unsafeUseAsCStringLen bs $ \(ptr, len) ->
        cb ptr (fromIntegral len)

-- | Apply @f@ to a pointer, or return 'Nothing' if it is NULL.
wrapNullablePtr :: (Ptr a -> b) -> Ptr a -> Maybe b
wrapNullablePtr f ptr
    | ptr == nullPtr = Nothing
    | otherwise      = Just (f ptr)

type Result a = Either Error a

-- Convert a 'CError' to a 'Result', in the common case where
-- SQLITE_OK signals success and anything else signals an error.
--
-- Note that SQLITE_OK == 0.
toResult :: a -> CError -> Result a
toResult a (CError 0) = Right a
toResult _ code       = Left $ decodeError code

-- Only perform the action if the 'CError' is SQLITE_OK.
toResultM :: Monad m => m a -> CError -> m (Result a)
toResultM m (CError 0) = fmap Right m
toResultM _ code       = return $ Left $ decodeError code

-- Map the return code of a step call: the "row" and "done" codes are
-- successes ('Row' / 'Done'); every other decoded code is an error.
toStepResult :: CError -> Result StepResult
toStepResult code =
    case decodeError code of
        ErrorRow  -> Right Row
        ErrorDone -> Right Done
        err       -> Left err

------------------------------------------------------------------------

-- | Open a database connection at the given path.
--
-- On failure, the 'Error' is paired with the connection's error message,
-- and the handle SQLite produced is closed before returning.
open :: Utf8 -> IO (Either (Error, Utf8) Database)
open (Utf8 path) =
    BS.useAsCString path $ \path' ->
    alloca $ \database -> do
        rc <- c_sqlite3_open path' database
        db <- Database <$> peek database
            -- sqlite3_open returns a sqlite3 even on failure.
            -- That's where we get a more descriptive error message.
        case toResult () rc of
            Left err -> do
                msg <- errmsg db -- This returns "out of memory" if db is null.
                _   <- close db  -- This is harmless if db is null.
                return $ Left (err, msg)
            Right () ->
                if db == Database nullPtr
                    then fail "sqlite3_open unexpectedly returned NULL"
                    else return $ Right db

-- | Close a database connection, reporting any failure as an 'Error'.
close :: Database -> IO (Either Error ())
close (Database db) =
    toResult () <$> c_sqlite3_close db

-- | Cause any pending operation on the 'Database' handle to stop at its
-- earliest opportunity.  This simply sets a flag and returns immediately.
-- It does not wait for the pending operation to finish.
--
-- You'll need to compile with @-threaded@ for this to do any good.
-- Without @-threaded@, FFI calls block the whole RTS, meaning 'interrupt'
-- would never run at the same time as 'step'.
--
-- (Thin wrapper over @sqlite3_interrupt@.)
interrupt :: Database -> IO ()
interrupt (Database db) =
    c_sqlite3_interrupt db

-- | Fetch the connection's current error message via @sqlite3_errmsg@.
-- A NULL result becomes an empty 'Utf8'.
errmsg :: Database -> IO Utf8
errmsg (Database db) =
    c_sqlite3_errmsg db >>= packUtf8 (Utf8 BS.empty) id

-- | Run SQL text with @sqlite3_exec@, discarding any result rows.
--
-- On failure, the 'Error' is paired with the message SQLite reported.
exec :: Database -> Utf8 -> IO (Either (Error, Utf8) ())
exec (Database db) (Utf8 sql) =
    BS.useAsCString sql $ \sql' ->
    -- Start the out-pointer at NULL so reading and freeing it is safe even
    -- if sqlite3_exec returns without storing a message.
    with nullPtr $ \msgPtrOut -> do
        rc <- c_sqlite3_exec db sql' nullFunPtr nullPtr msgPtrOut
        case toResult () rc of
            Left err -> do
                msg <- takeExecErrmsg msgPtrOut
                return $ Left (err, msg)
            Right () -> return $ Right ()

-- Read the error message sqlite3_exec stored through its errmsg
-- out-parameter, then release it with sqlite3_free.  The free runs even if
-- copying the message fails, and freeing a NULL pointer is a no-op.
takeExecErrmsg :: Ptr CString -> IO Utf8
takeExecErrmsg msgPtrOut = do
    msgPtr <- peek msgPtrOut
    packUtf8 (Utf8 BS.empty) id msgPtr `E.finally` c_sqlite3_free msgPtr

-- | Like 'exec', but invoke the callback for each result row.
--
-- If the callback throws an exception, it will be rethrown by
-- 'execWithCallback'.  The exception (any 'SomeException', including
-- asynchronous ones) is captured inside the C callback, which then returns 1
-- to abort the query.  Once @sqlite3_exec@ returns 'ErrorAbort', the error
-- message buffer is freed and the captured exception is rethrown.
execWithCallback :: Database -> Utf8 -> ExecCallback -> IO (Either (Error, Utf8) ())
execWithCallback (Database db) (Utf8 sql) cb = do
    abortReason <- newIORef Nothing :: IO (IORef (Maybe SomeException))
    cbCache <- newIORef Nothing :: IO (IORef (Maybe ([Maybe Utf8] -> IO ())))

    -- Cache the partial application of column count and name, so if the
    -- caller wants to convert them to something else, it only has to do
    -- the conversions once.
    let getCallback cCount cNames = do
            m <- readIORef cbCache
            case m of
                Nothing -> do
                    names <- packUtf8Array (fail "execWithCallback: NULL column name")
                                           return
                                           (fromIntegral cCount) cNames
                    let !cb' = cb (fromFFI cCount) names
                    writeIORef cbCache $ Just cb'
                    return cb'
                Just cb' -> return cb'

    -- Never let a Haskell exception unwind through C: record it and ask
    -- SQLite to abort instead (a non-zero callback result).
    let onExceptionAbort io =
            (io >> return 0) `E.catch` \ex -> do
                writeIORef abortReason $ Just ex
                return 1

    let cExecCallback _ctx cCount cValues cNames =
            onExceptionAbort $ do
                cb' <- getCallback cCount cNames
                values <- packUtf8Array (return Nothing) (return . Just)
                                        (fromIntegral cCount) cValues
                cb' values

    BS.useAsCString sql $ \sql' ->
        with nullPtr $ \msgPtrOut ->
        bracket (mkCExecCallback cExecCallback) freeHaskellFunPtr $ \pExecCallback -> do
            rc <- c_sqlite3_exec db sql' pExecCallback nullPtr msgPtrOut
            case toResult () rc of
                Right () -> return $ Right ()
                Left err -> do
                    -- Take (and free) the message before deciding whether to
                    -- rethrow, so it is not leaked on the exception path.
                    msg <- takeExecErrmsg msgPtrOut
                    reason <- readIORef abortReason
                    case (err, reason) of
                        (ErrorAbort, Just ex) -> throwIO ex
                        _                     -> return $ Left (err, msg)

-- | Row callback used by 'execWithCallback'.
type ExecCallback
     = ColumnCount   -- ^ Number of columns, which is the number of items in
                     --   the following lists.  This will be the same for
                     --   every row.
    -> [Utf8]        -- ^ List of column names.  This will be the same
                     --   for every row.
    -> [Maybe Utf8]  -- ^ List of column values, as returned by 'columnText'.
    -> IO ()

-- | Enable/disable tracing of SQL execution.  Tracing can be disabled
-- by setting 'Nothing' as the logger callback.
--
-- Warning: If the logger callback throws an exception, your whole
-- program will crash.  Enable only for debugging!
setTrace :: Database -> Maybe (Utf8 -> IO ()) -> IO ()
setTrace (Database db) logger =
    case logger of
        Nothing -> do
            _ <- c_sqlite3_trace db nullFunPtr nullPtr
            return ()
        Just output -> do
            -- NB: this FunPtr never gets freed.  Shouldn't be a big deal,
            -- though, since 'setTrace' is mainly for debugging, and is
            -- typically only called once per application invocation.
            cb <- mkCTraceCallback $ \_ctx cStr -> do
                msg <- packUtf8 (Utf8 BS.empty) id cStr
                output msg
            _ <- c_sqlite3_trace db cb nullPtr
            return ()

-- | Return 'True' if the connection is in autocommit mode, or 'False' if a
-- transaction started with @BEGIN@ is still active.
--
-- Be warned that some errors roll back the transaction automatically,
-- and that @ROLLBACK@ will throw an error if no transaction is active.
-- Use 'getAutoCommit' to avoid such an error:
--
-- @
-- autocommit <- 'getAutoCommit' conn
-- 'Control.Monad.when' (not autocommit) $
--     'Database.SQLite3.exec' conn \"ROLLBACK\"
-- @
getAutoCommit :: Database -> IO Bool
getAutoCommit (Database db) =
    fmap (/= 0) (c_sqlite3_get_autocommit db)

-- | Compile SQL text into a prepared 'Statement' with @sqlite3_prepare_v2@.
-- The SQL is passed NUL-terminated (length @-1@) and no tail pointer is
-- requested.
--
-- If the query contains no SQL statements, this returns
-- @'Right' 'Nothing'@.
prepare :: Database -> Utf8 -> IO (Either Error (Maybe Statement))
prepare (Database db) (Utf8 sql) =
    BS.useAsCString sql $ \sqlPtr ->
    alloca $ \stmtOut -> do
        rc <- c_sqlite3_prepare_v2 db sqlPtr (-1) stmtOut nullPtr
        -- The out-pointer is only read when the call reported success.
        toResultM (wrapNullablePtr Statement <$> peek stmtOut) rc

-- | Look up the 'Database' a 'Statement' belongs to.  Fails in 'IO' if
-- @sqlite3_db_handle@ returns NULL.
getStatementDatabase :: Statement -> IO Database
getStatementDatabase (Statement stmt) =
    c_sqlite3_db_handle stmt >>= \dbPtr ->
        if dbPtr /= nullPtr
            then return (Database dbPtr)
            else fail ("sqlite3_db_handle(" ++ show stmt ++ ") returned NULL")

-- | Advance the statement once, yielding 'Row' or 'Done' on success.
step :: Statement -> IO (Either Error StepResult)
step (Statement stmt) =
    fmap toStepResult (c_sqlite3_step stmt)

-- | Reset the statement so it can be stepped again.
--
-- Warning:
--
--  * If the most recent 'step' call failed,
--    this will return the corresponding error.
--
--  * This does not reset the bindings on a prepared statement.
--    Use 'clearBindings' to do that.
reset :: Statement -> IO (Either Error ())
reset (Statement stmt) =
    fmap (toResult ()) (c_sqlite3_reset stmt)

-- | Destroy the prepared statement.
--
-- /Warning:/ If the most recent 'step' call failed,
-- this will return the corresponding error.
finalize :: Statement -> IO (Either Error ())
finalize (Statement stmt) =
    fmap (toResult ()) (c_sqlite3_finalize stmt)

-- | Return a copy of the original SQL text used to compile the statement,
-- or 'Nothing' if SQLite returns a NULL pointer.
statementSql :: Statement -> IO (Maybe Utf8)
statementSql (Statement stmt) =
    packUtf8 Nothing Just =<< c_sqlite3_sql stmt

-- | Set all parameters in the prepared statement to null.
-- The status code returned by @sqlite3_clear_bindings@ is discarded.
clearBindings :: Statement -> IO ()
clearBindings (Statement stmt) =
    c_sqlite3_clear_bindings stmt >> return ()

-- | Return the index of the largest (rightmost) parameter.
-- Note that this is not necessarily the number of parameters.  If numbered
-- parameters like @?5@ are used, there may be gaps in the list.
--
-- See 'ParamIndex' for more information.
bindParameterCount :: Statement -> IO ParamIndex
bindParameterCount (Statement stmt) =
    fmap fromFFI (c_sqlite3_bind_parameter_count stmt)

-- | Name of the parameter at the given index, or 'Nothing' when SQLite
-- returns a NULL name pointer.
bindParameterName :: Statement -> ParamIndex -> IO (Maybe Utf8)
bindParameterName (Statement stmt) idx =
    packUtf8 Nothing Just =<< c_sqlite3_bind_parameter_name stmt (toFFI idx)

-- | Number of columns in the statement's result set.
columnCount :: Statement -> IO ColumnCount
columnCount (Statement stmt) =
    fmap fromFFI (c_sqlite3_column_count stmt)

-- | Name of the result column at the given index, or 'Nothing' when SQLite
-- returns a NULL name pointer.
columnName :: Statement -> ColumnIndex -> IO (Maybe Utf8)
columnName (Statement stmt) idx =
    packUtf8 Nothing Just =<< c_sqlite3_column_name stmt (toFFI idx)

-- Each binder below forwards to the matching sqlite3_bind_* call and turns
-- its status code into a 'Result' with 'toResult'.

bindInt64 :: Statement -> ParamIndex -> Int64 -> IO (Either Error ())
bindInt64 (Statement stmt) idx value =
    fmap (toResult ()) (c_sqlite3_bind_int64 stmt (toFFI idx) value)

bindDouble :: Statement -> ParamIndex -> Double -> IO (Either Error ())
bindDouble (Statement stmt) idx value =
    fmap (toResult ()) (c_sqlite3_bind_double stmt (toFFI idx) value)

-- | Bind text.  The bytes are passed with @SQLITE_TRANSIENT@, so SQLite
-- takes its own copy and the borrowed buffer need not outlive the call.
bindText :: Statement -> ParamIndex -> Utf8 -> IO (Either Error ())
bindText (Statement stmt) idx (Utf8 value) =
    unsafeUseAsCStringLenNoNull value $ \cstr nbytes ->
        fmap (toResult ()) $
            c_sqlite3_bind_text stmt (toFFI idx) cstr nbytes c_SQLITE_TRANSIENT

-- | Bind a blob.  Like 'bindText', the bytes are copied by SQLite
-- (@SQLITE_TRANSIENT@).
bindBlob :: Statement -> ParamIndex -> ByteString -> IO (Either Error ())
bindBlob (Statement stmt) idx value =
    unsafeUseAsCStringLenNoNull value $ \cstr nbytes ->
        fmap (toResult ()) $
            c_sqlite3_bind_blob stmt (toFFI idx) cstr nbytes c_SQLITE_TRANSIENT

bindNull :: Statement -> ParamIndex -> IO (Either Error ())
bindNull (Statement stmt) idx =
    fmap (toResult ()) (c_sqlite3_bind_null stmt (toFFI idx))

-- | Storage class of a result column, decoded with 'decodeColumnType'.
columnType :: Statement -> ColumnIndex -> IO ColumnType
columnType (Statement stmt) idx =
    fmap decodeColumnType (c_sqlite3_column_type stmt (toFFI idx))

columnInt64 :: Statement -> ColumnIndex -> IO Int64
columnInt64 (Statement stmt) = c_sqlite3_column_int64 stmt . toFFI

columnDouble :: Statement -> ColumnIndex -> IO Double
columnDouble (Statement stmt) = c_sqlite3_column_double stmt . toFFI

-- | Copy a column as 'Utf8'.  The pointer is fetched before the byte count,
-- in that order, and the byte count is used as the length.
columnText :: Statement -> ColumnIndex -> IO Utf8
columnText (Statement stmt) idx = do
    textPtr <- c_sqlite3_column_text stmt (toFFI idx)
    nbytes <- c_sqlite3_column_bytes stmt (toFFI idx)
    fmap Utf8 (packCStringLen textPtr nbytes)

-- | Copy a column as raw bytes (pointer first, then byte count).
columnBlob :: Statement -> ColumnIndex -> IO ByteString
columnBlob (Statement stmt) idx = do
    blobPtr <- c_sqlite3_column_blob stmt (toFFI idx)
    nbytes <- c_sqlite3_column_bytes stmt (toFFI idx)
    packCStringLen blobPtr nbytes

-- | Row id reported by @sqlite3_last_insert_rowid@ for this connection.
lastInsertRowId :: Database -> IO Int64
lastInsertRowId (Database db) =
    c_sqlite3_last_insert_rowid db

-- | Return the number of rows that were changed, inserted, or deleted
-- by the most recent @INSERT@, @DELETE@, or @UPDATE@ statement.
changes :: Database -> IO Int
changes (Database db) =
    fmap fromIntegral (c_sqlite3_changes db)

-- | Return the total number of row changes caused by @INSERT@, @DELETE@,
-- or @UPDATE@ statements since the 'Database' was opened.
totalChanges :: Database -> IO Int
totalChanges (Database db) =
    fmap fromIntegral (c_sqlite3_total_changes db)