{-# LANGUAGE ForeignFunctionInterface, DeriveDataTypeable #-}
-- | Low-level bindings to the SQLite3 C API.
--
-- Failures are reported by throwing an 'IOError' (via 'fail') whose message
-- names the SQLite result code, the operation being attempted and, where a
-- connection is available, SQLite's own error message.
module Database.Sqlite
  ( Database
  , Statement
  , Error
  , StepResult(Row, Done)
  , SQLData(SQLInteger, SQLFloat, SQLText, SQLBlob, SQLNull)
  , ColumnType(IntegerColumn, FloatColumn, TextColumn, BlobColumn, NullColumn)
  , open
  , close
  , prepare
  , step
  , reset
  , finalize
  , bindParameterCount
  , bindParameterName
  , bindBlob
  , bindDouble
  , bindInt
  , bindInt64
  , bindNull
  , bindText
  , bind
  , column
  , columns
  , unsafeColumns
  ) where

import qualified Data.ByteString as BS
import qualified Data.ByteString.Internal as BSI
import qualified Data.Text as T
import qualified Data.Text.Encoding as T
import Data.Typeable
import Foreign
import Foreign.C

-- | An open connection: an opaque @sqlite3*@ handle.
newtype Database = Database (Ptr ())

-- | A prepared statement (opaque @sqlite3_stmt*@) together with the
-- connection it was prepared on, which is used for error messages.
data Statement = Statement (Ptr ()) Database

-- | A raw SQLite result code.
type Error = Int

errorOK :: Error
errorOK = 0

errorRow :: Error
errorRow = 100

errorDone :: Error
errorDone = 101

-- | Outcome of a successful 'step'.
data StepResult
  = Row   -- ^ the step produced a row (result code 100)
  | Done  -- ^ the statement finished (result code 101)
  deriving (Eq, Show)

-- | Storage class of a column value, as reported by @sqlite3_column_type@.
data ColumnType
  = IntegerColumn
  | FloatColumn
  | TextColumn
  | BlobColumn
  | NullColumn
  deriving (Eq, Show)

-- | A single SQL value, used both for binding parameters and reading columns.
data SQLData
  = SQLInteger Int64
  | SQLFloat Double
  | SQLText T.Text
  | SQLBlob BS.ByteString
  | SQLNull
  deriving (Eq, Show, Typeable)

-- | Human-readable name of an SQLite result code (used in error messages).
showError :: Error -> String
showError 0 = "ErrorOK"
showError 1 = "ErrorError"
showError 2 = "ErrorInternal"
showError 3 = "ErrorPermission"
showError 4 = "ErrorAbort"
showError 5 = "ErrorBusy"
showError 6 = "ErrorLocked"
showError 7 = "ErrorNoMemory"
showError 8 = "ErrorReadOnly"
showError 9 = "ErrorInterrupt"
showError 10 = "ErrorIO"
showError 11 = "ErrorNotFound"
showError 12 = "ErrorCorrupt"
showError 13 = "ErrorFull"
showError 14 = "ErrorCan'tOpen"
showError 15 = "ErrorProtocol"
showError 16 = "ErrorEmpty"
showError 17 = "ErrorSchema"
showError 18 = "ErrorTooBig"
showError 19 = "ErrorConstraint"
showError 20 = "ErrorMismatch"
showError 21 = "ErrorMisuse"
showError 22 = "ErrorNoLargeFileSupport"
showError 23 = "ErrorAuthorization"
showError 24 = "ErrorFormat"
showError 25 = "ErrorRange"
showError 26 = "ErrorNotADatabase"
showError 100 = "ErrorRow"
showError 101 = "ErrorDone"
showError x = "Unknown error: " ++ show x

-- | Map the integer returned by @sqlite3_column_type@ to a 'ColumnType'.
-- Any other value is a programming error and aborts with 'error'.
decodeColumnType :: Int -> ColumnType
decodeColumnType 1 = IntegerColumn
decodeColumnType 2 = FloatColumn
decodeColumnType 3 = TextColumn
decodeColumnType 4 = BlobColumn
decodeColumnType 5 = NullColumn
decodeColumnType x = error $ "Unknown column type: " ++ show x

-- | The @SQLITE_TRANSIENT@ destructor sentinel, @(sqlite3_destructor_type)-1@.
-- It makes SQLite copy bound data before the bind call returns, so the
-- temporary Haskell buffer may be released as soon as the call finishes.
sqliteTransient :: FunPtr (Ptr () -> IO ())
sqliteTransient = castPtrToFunPtr (intPtrToPtr (-1))

foreign import ccall "sqlite3_errmsg"
  errmsgC :: Ptr () -> IO CString

-- | SQLite's UTF-8 description of the most recent error on a connection.
errmsg :: Database -> IO String
errmsg (Database database) = do
  message <- errmsgC database
  fmap (T.unpack . T.decodeUtf8) (BS.packCString message)

-- | Throw an 'IOError' describing a failed SQLite call.
-- @details@ is appended verbatim to the message.
failSql :: String -> Error -> String -> IO a
failSql functionName err details =
  fail $ "SQLite3 returned " ++ showError err
      ++ " while attempting to perform " ++ functionName ++ details

-- | Throw an 'IOError' for a failed call. When a connection is supplied,
-- its current error message is included.
sqlError :: Maybe Database -> String -> Error -> IO a
sqlError maybeDatabase functionName err = do
  details <- case maybeDatabase of
    Just database -> fmap (": " ++) (errmsg database)
    Nothing -> return "."
  failSql functionName err details

-- | Succeed on 'errorOK'; otherwise raise the error via 'sqlError'.
checkOK :: Database -> String -> Error -> IO ()
checkOK database functionName err
  | err == errorOK = return ()
  | otherwise = sqlError (Just database) functionName err

foreign import ccall "sqlite3_open"
  openC :: CString -> Ptr (Ptr ()) -> IO CInt

-- | Open a connection. On failure, return the result code and, if SQLite
-- handed back a handle anyway, that handle's error message.
--
-- @sqlite3_open@ allocates a handle even for most failures, and that handle
-- must still be released with @sqlite3_close@. It is closed here (after
-- reading its message) so that a failed open does not leak it.
openError :: String -> IO (Either Database (Error, Maybe String))
openError path =
  BS.useAsCString (T.encodeUtf8 (T.pack path)) $ \pathC ->
    alloca $ \databaseC -> do
      err <- fmap fromIntegral (openC pathC databaseC)
      rawHandle <- peek databaseC
      if err == errorOK
        then return (Left (Database rawHandle))
        else if rawHandle == nullPtr
          then return (Right (err, Nothing))
          else do
            details <- errmsg (Database rawHandle)
            _ <- closeC rawHandle
            return (Right (err, Just details))

-- | Open (or create) the database at the given path; the path is passed to
-- SQLite UTF-8 encoded. Throws an 'IOError' on failure.
open :: String -> IO Database
open path = do
  databaseOrError <- openError path
  case databaseOrError of
    Left database -> return database
    Right (err, details) ->
      failSql ("open " ++ show path) err (maybe "." (": " ++) details)

foreign import ccall "sqlite3_close"
  closeC :: Ptr () -> IO CInt

-- | Close a connection and return the raw result code.
closeError :: Database -> IO Error
closeError (Database database) = fmap fromIntegral (closeC database)

-- | Close a connection, throwing an 'IOError' if SQLite refuses.
close :: Database -> IO ()
close database = closeError database >>= checkOK database "close"

foreign import ccall "sqlite3_prepare_v2"
  prepareC :: Ptr () -> CString -> CInt -> Ptr (Ptr ()) -> Ptr (Ptr ()) -> IO CInt

-- | Compile the first SQL statement in @text@. Any text after that statement
-- is ignored, because no tail pointer is requested.
--
-- NOTE: for input with no SQL in it (empty or comment-only), SQLite reports
-- success with a NULL statement handle; that handle is returned as-is.
prepareError :: Database -> BS.ByteString -> IO (Either Statement Error)
prepareError db@(Database database) text =
  BS.useAsCString text $ \textC ->
    alloca $ \statementC -> do
      -- nByte = -1: SQLite reads up to the NUL that useAsCString appends.
      err <- fmap fromIntegral (prepareC database textC (-1) statementC nullPtr)
      if err == errorOK
        then fmap (\statement -> Left (Statement statement db)) (peek statementC)
        else return (Right err)

-- | Compile a UTF-8 SQL statement, throwing an 'IOError' on failure.
prepare :: Database -> BS.ByteString -> IO Statement
prepare database text = do
  statementOrError <- prepareError database text
  either return (sqlError (Just database) ("prepare " ++ show text)) statementOrError

foreign import ccall "sqlite3_step"
  stepC :: Ptr () -> IO CInt

-- | Advance a statement. Any result other than row or done is thrown.
step :: Statement -> IO StepResult
step (Statement statement database) = do
  err <- fmap fromIntegral (stepC statement)
  decode err
  where
    decode err
      | err == errorRow = return Row
      | err == errorDone = return Done
      | otherwise = sqlError (Just database) "step" err

foreign import ccall "sqlite3_reset"
  resetC :: Ptr () -> IO CInt

-- | Reset a statement so it can be stepped again from the start.
reset :: Statement -> IO ()
reset (Statement statement database) =
  fmap fromIntegral (resetC statement) >>= checkOK database "reset"

foreign import ccall "sqlite3_finalize"
  finalizeC :: Ptr () -> IO CInt

-- | Destroy a prepared statement; it must not be used afterwards.
finalize :: Statement -> IO ()
finalize (Statement statement database) =
  fmap fromIntegral (finalizeC statement) >>= checkOK database "finalize"

foreign import ccall "sqlite3_bind_parameter_count"
  bindParameterCountC :: Ptr () -> IO CInt

-- | Find the number of SQL parameters in a prepared statement.
bindParameterCount :: Statement -> IO Int
bindParameterCount (Statement stmt _) = fmap fromIntegral (bindParameterCountC stmt)

-- | Copy a NUL-terminated C string, mapping a NULL pointer to 'Nothing'.
maybeNullCString :: CString -> IO (Maybe BS.ByteString)
maybeNullCString s =
  if s == nullPtr
    then return Nothing
    else fmap Just (BS.packCString s)

foreign import ccall "sqlite3_bind_parameter_name"
  bindParameterNameC :: Ptr () -> CInt -> IO CString

-- | Return the N-th SQL parameter name.
--
-- Named parameters are returned as-is. E.g. \":v\" is returned as
-- @Just \":v\"@. Unnamed parameters, however, are converted to
-- @Nothing@.
--
-- Note that the parameter index starts at 1, not 0.
bindParameterName :: Statement -> Int -> IO (Maybe String)
bindParameterName (Statement stmt _) colNdx = do
  mn <- bindParameterNameC stmt (fromIntegral colNdx) >>= maybeNullCString
  return (fmap (T.unpack . T.decodeUtf8) mn)

foreign import ccall "sqlite3_bind_blob"
  bindBlobC :: Ptr () -> CInt -> Ptr () -> CInt -> FunPtr (Ptr () -> IO ()) -> IO CInt

-- | Bind a blob to a (1-based) parameter. SQLite copies the bytes.
bindBlob :: Statement -> Int -> BS.ByteString -> IO ()
bindBlob (Statement statement database) parameterIndex byteString = do
  -- useAsCStringLen copies into a freshly allocated buffer, so the pointer is
  -- non-NULL even for an empty ByteString. A NULL pointer would make SQLite
  -- bind SQL NULL instead of a zero-length blob.
  err <- BS.useAsCStringLen byteString $ \(dataC, size) ->
    bindBlobC statement (fromIntegral parameterIndex) (castPtr dataC)
              (fromIntegral size) sqliteTransient
  checkOK database "bind blob" (fromIntegral err)

foreign import ccall "sqlite3_bind_double"
  bindDoubleC :: Ptr () -> CInt -> Double -> IO CInt

-- | Bind a floating-point value to a (1-based) parameter.
bindDouble :: Statement -> Int -> Double -> IO ()
bindDouble (Statement statement database) parameterIndex datum = do
  err <- bindDoubleC statement (fromIntegral parameterIndex) datum
  checkOK database "bind double" (fromIntegral err)

-- | Bind an 'Int' to a (1-based) parameter. The value is widened to 64 bits
-- and bound with @sqlite3_bind_int64@, so it is never truncated to a C @int@.
bindInt :: Statement -> Int -> Int -> IO ()
bindInt (Statement statement database) parameterIndex datum = do
  err <- bindInt64C statement (fromIntegral parameterIndex) (fromIntegral datum)
  checkOK database "bind int" (fromIntegral err)

foreign import ccall "sqlite3_bind_int64"
  bindInt64C :: Ptr () -> CInt -> Int64 -> IO CInt

-- | Bind a 64-bit integer to a (1-based) parameter.
bindInt64 :: Statement -> Int -> Int64 -> IO ()
bindInt64 (Statement statement database) parameterIndex datum = do
  err <- bindInt64C statement (fromIntegral parameterIndex) datum
  checkOK database "bind int64" (fromIntegral err)

foreign import ccall "sqlite3_bind_null"
  bindNullC :: Ptr () -> CInt -> IO CInt

-- | Bind SQL NULL to a (1-based) parameter.
bindNull :: Statement -> Int -> IO ()
bindNull (Statement statement database) parameterIndex = do
  err <- bindNullC statement (fromIntegral parameterIndex)
  checkOK database "bind null" (fromIntegral err)

foreign import ccall "sqlite3_bind_text"
  bindTextC :: Ptr () -> CInt -> CString -> CInt -> FunPtr (Ptr () -> IO ()) -> IO CInt

-- | Bind text (encoded as UTF-8) to a (1-based) parameter. SQLite copies it.
bindText :: Statement -> Int -> T.Text -> IO ()
bindText (Statement statement database) parameterIndex text = do
  -- The explicit byte length lets text containing NUL characters bind intact.
  err <- BS.useAsCStringLen (T.encodeUtf8 text) $ \(dataC, size) ->
    bindTextC statement (fromIntegral parameterIndex) dataC
              (fromIntegral size) sqliteTransient
  checkOK database "bind text" (fromIntegral err)

-- | Bind a list of values to parameters 1, 2, 3, ... in order.
bind :: Statement -> [SQLData] -> IO ()
bind statement sqlData = mapM_ (uncurry bindOne) (zip [1 ..] sqlData)
  where
    bindOne parameterIndex datum = case datum of
      SQLInteger int64 -> bindInt64 statement parameterIndex int64
      SQLFloat double -> bindDouble statement parameterIndex double
      SQLText text -> bindText statement parameterIndex text
      SQLBlob blob -> bindBlob statement parameterIndex blob
      SQLNull -> bindNull statement parameterIndex

foreign import ccall "sqlite3_column_type"
  columnTypeC :: Ptr () -> CInt -> IO CInt

-- | Storage class of a (0-based) column in the current row.
columnType :: Statement -> Int -> IO ColumnType
columnType (Statement statement _) columnIndex =
  fmap (decodeColumnType . fromIntegral) (columnTypeC statement (fromIntegral columnIndex))

foreign import ccall "sqlite3_column_bytes"
  columnBytesC :: Ptr () -> CInt -> IO CInt

foreign import ccall "sqlite3_column_blob"
  columnBlobC :: Ptr () -> CInt -> IO (Ptr ())

-- | Copy a column's value out as raw bytes.
columnBlob :: Statement -> Int -> IO BS.ByteString
columnBlob (Statement statement _) columnIndex = do
  let i = fromIntegral columnIndex
  -- SQLite's documented safe order is sqlite3_column_blob first and then
  -- sqlite3_column_bytes, so the length refers to the buffer just returned.
  dataPtr <- columnBlobC statement i
  size <- fmap fromIntegral (columnBytesC statement i)
  if dataPtr == nullPtr || size <= 0
    then return BS.empty  -- zero-length blobs come back as a NULL pointer
    else BSI.create size $ \resultPtr ->
      copyBytes resultPtr (castPtr dataPtr) size

foreign import ccall "sqlite3_column_int64"
  columnInt64C :: Ptr () -> CInt -> IO Int64

-- | A column's value as a 64-bit integer.
columnInt64 :: Statement -> Int -> IO Int64
columnInt64 (Statement statement _) columnIndex =
  columnInt64C statement (fromIntegral columnIndex)

foreign import ccall "sqlite3_column_double"
  columnDoubleC :: Ptr () -> CInt -> IO Double

-- | A column's value as a 'Double'.
columnDouble :: Statement -> Int -> IO Double
columnDouble (Statement statement _) columnIndex =
  columnDoubleC statement (fromIntegral columnIndex)

foreign import ccall "sqlite3_column_text"
  columnTextC :: Ptr () -> CInt -> IO CString

-- | A column's value as UTF-8-decoded text.
--
-- A NULL pointer from SQLite (e.g. when 'unsafeColumns' reads an SQL NULL
-- as 'TextColumn') yields empty text instead of dereferencing NULL.
columnText :: Statement -> Int -> IO T.Text
columnText (Statement statement _) columnIndex = do
  let i = fromIntegral columnIndex
  -- Text pointer first, then its byte length (SQLite's documented order).
  -- The explicit length keeps embedded NUL characters.
  textPtr <- columnTextC statement i
  if textPtr == nullPtr
    then return T.empty
    else do
      size <- fmap fromIntegral (columnBytesC statement i)
      fmap T.decodeUtf8 (BS.packCStringLen (textPtr, size))

foreign import ccall "sqlite3_column_count"
  columnCountC :: Ptr () -> IO CInt

-- | Number of columns in the statement's result set.
columnCount :: Statement -> IO Int
columnCount (Statement statement _) = fmap fromIntegral (columnCountC statement)

-- | Read a (0-based) column of the current row, using the column's own
-- storage class to pick the 'SQLData' constructor.
column :: Statement -> Int -> IO SQLData
column statement columnIndex = do
  theType <- columnType statement columnIndex
  typedColumn theType statement columnIndex

-- | Read a (0-based) column as the given type, without consulting its actual
-- storage class.
typedColumn :: ColumnType -> Statement -> Int -> IO SQLData
typedColumn theType statement columnIndex = case theType of
  IntegerColumn -> fmap SQLInteger (columnInt64 statement columnIndex)
  FloatColumn -> fmap SQLFloat (columnDouble statement columnIndex)
  TextColumn -> fmap SQLText (columnText statement columnIndex)
  BlobColumn -> fmap SQLBlob (columnBlob statement columnIndex)
  NullColumn -> return SQLNull

-- | Read every column of the current row, left to right.
columns :: Statement -> IO [SQLData]
columns statement = do
  count <- columnCount statement
  mapM (column statement) [0 .. count - 1]

-- | Read one column per list element, starting at column 0.
--
-- @Just t@ reads that column as @t@ without checking its actual type.
-- 'Nothing' falls back to 'column', which asks SQLite for the type. Only as
-- many columns as there are list elements are read.
unsafeColumns :: Statement -> [Maybe ColumnType] -> IO [SQLData]
unsafeColumns statement types = sequence (zipWith fetch [0 ..] types)
  where
    fetch :: Int -> Maybe ColumnType -> IO SQLData
    fetch n Nothing = column statement n
    fetch n (Just t) = typedColumn t statement n