{-# LANGUAGE ForeignFunctionInterface #-}
{-# LANGUAGE EmptyDataDecls #-}
{-# LANGUAGE OverloadedStrings #-}
module Squeather.Internal where
import qualified Control.Exception as Exception
import Control.Exception (throwIO)
import Control.Monad (when)
import Data.ByteString (ByteString)
import qualified Data.ByteString as ByteString
import Data.Int (Int64)
import Data.Text (Text)
import qualified Data.Text as Text
import qualified Data.Text.Encoding as Encoding
import qualified Foreign
import Foreign.C.Types (CInt(CInt), CChar, CUChar)
import Foreign (Ptr, FunPtr)
import Squeather.Internal.Bindings (SQLData(SQLNull, SQLText, SQLFloat, SQLInteger, SQLBlob))
import qualified Squeather.Internal.Bindings as Bindings
import Squeather.Internal.Types (ErrorFlag, StepResult, OpenFlags)
import qualified Squeather.Internal.Types as Types
-- | Opaque stand-in for the C @sqlite3@ connection object. It has no
-- constructors and is only ever used behind a 'Ptr'.
data C'sqlite3
-- | Opaque stand-in for the C @sqlite3_stmt@ prepared-statement object. It
-- has no constructors and is only ever used behind a 'Ptr'.
data C'sqlite3_stmt
-- | Opaque stand-in for C @void@.
-- NOTE(review): not referenced in the visible part of this module; confirm
-- whether anything still uses it.
data C'void
-- | An open SQLite database connection.
data Database = Database
  { dbPtr :: (Ptr C'sqlite3)
  -- ^ Raw connection handle, as written by @sqlite3_open_v2@ in
  -- 'openWithFlags'.
  , dbFilename :: Text
  -- ^ The filename the connection was opened with. It is copied into
  -- 'errorFilename' when an 'Error' is raised for this connection.
  } deriving (Eq, Ord, Show)
-- | A prepared statement, together with the SQL text and the connection it
-- was prepared from.
data Statement = Statement
  { stmtPtr :: Ptr C'sqlite3_stmt
  -- ^ Raw statement handle produced by @sqlite3_prepare_v2@ in 'prepare'.
  , stmtSql :: Text
  -- ^ SQL text the statement was prepared from. Most operations on the
  -- statement use it as the 'errorContext'.
  , stmtDb :: Database
  -- ^ Connection the statement was prepared on.
  } deriving (Eq, Ord, Show)
-- | Failures detected by Squeather itself rather than reported by SQLite.
-- They appear as the 'Right' side of 'errorFlag'.
data SqueatherErrorFlag
  = ParameterNotFound
  -- ^ A named parameter was not found in the statement
  -- ('getParameterIndex').
  | ExecFailed
  -- ^ @sqlite3_exec@ produced an error message ('exec').
  | IntConversion
  -- ^ A number did not fit when converting between 'Int' and 'CInt'.
  | UnknownColumnType CInt
  -- ^ A column type code that @Bindings.convertCColumnType@ does not
  -- recognise.
  | UnknownSqliteError CInt
  -- ^ A result code that @Bindings.parseError@ does not recognise.
  | IncompleteBackup
  -- ^ A backup step returned @SQLITE_OK@ instead of finishing ('backup').
  | Bug
  -- ^ SQLite returned a result that should be impossible at that call site.
  | ColumnNameNull Int
  -- ^ @sqlite3_column_name@ returned NULL for this column index.
  deriving (Eq, Ord, Show)
-- | The exception this module throws for SQLite and Squeather failures.
data Error = Error
  { errorContext :: Text
  -- ^ What was being done: usually the SQL text, or a short description
  -- such as @opening database@.
  , errorFlag :: Either ErrorFlag SqueatherErrorFlag
  -- ^ 'Left' for an error code reported by SQLite, 'Right' for a problem
  -- detected by Squeather itself.
  , errorText :: Text
  -- ^ Human-readable message. For 'Left' flags raised by 'checkError' and
  -- 'checkStepError', this is the connection's @sqlite3_errmsg@.
  , errorFilename :: Text
  -- ^ Usually the database filename ('dbFilename').
  -- NOTE(review): 'backup' stores its context description here instead.
  } deriving (Eq, Ord, Show)
-- | Allows 'Error' to be thrown with 'Exception.throwIO' and caught.
instance Exception.Exception Error
-- | Raw binding to @sqlite3_extended_result_codes@. 'openWithFlags' calls
-- it with @1@ to enable extended result codes on a new connection. It
-- returns a result code.
--
-- The C prototype is @int sqlite3_extended_result_codes(sqlite3*, int)@, so
-- the on/off argument is a 'CInt'. Haskell's 'Int' is 64 bits on common
-- 64-bit targets and does not match a C @int@ argument.
foreign import ccall unsafe "sqlite3_extended_result_codes" sqlite3_extended_result_codes
  :: Ptr C'sqlite3
  -> CInt
  -> IO CInt
-- | Raw binding to @sqlite3_open_v2@. Arguments, in order:
--
-- * the UTF-8 filename;
-- * the location where the new connection handle is written;
-- * the open flags;
-- * per the SQLite C API, the VFS name ('openWithFlags' passes
--   'Foreign.nullPtr').
--
-- Returns the result code.
foreign import ccall unsafe "sqlite3_open_v2" sqlite3_open_v2
  :: Ptr CChar
  -> Ptr (Ptr C'sqlite3)
  -> CInt
  -> Ptr CChar
  -> IO CInt
-- | Raw binding to @sqlite3_errmsg@. Returns the connection's current error
-- message as a C string; this module reads it with 'readUtf8'.
foreign import ccall unsafe "sqlite3_errmsg" sqlite3_errmsg
  :: Ptr C'sqlite3
  -> IO (Ptr CChar)
-- | Copy a NUL-terminated C string into Haskell memory and decode it as
-- UTF-8.
--
-- Because the bytes are copied first, the result does not depend on the C
-- buffer staying alive. Decoding uses 'Encoding.decodeUtf8' and is not
-- forced here. Invalid UTF-8 therefore throws only when the 'Text' is
-- evaluated.
readUtf8 :: Ptr CChar -> IO Text
readUtf8 cstr = Encoding.decodeUtf8 <$> ByteString.packCString cstr
-- | Encode text as UTF-8 and hand it to the callback as a NUL-terminated C
-- string. The pointer is only valid while the callback runs.
--
-- NOTE(review): a 'Text' containing U+0000 produces an embedded NUL, so C
-- code reading the string sees it truncated at that point.
writeUtf8 :: Text -> (Ptr CChar -> IO a) -> IO a
writeUtf8 txt = ByteString.useAsCString (Encoding.encodeUtf8 txt)
-- | Encode text as UTF-8 and hand the callback a pointer plus byte length.
-- The pointer is only valid while the callback runs.
--
-- Because the length is passed explicitly, embedded NULs are preserved.
writeUtf8Len :: Text -> ((Ptr CChar, Int) -> IO a) -> IO a
writeUtf8Len = ByteString.useAsCStringLen . Encoding.encodeUtf8
-- | Check a result code from an SQLite call that should report @SQLITE_OK@.
--
-- * @SQLITE_OK@ returns normally.
-- * A step result is thrown as a 'Bug'.
-- * A recognised SQLite error is thrown with 'Left' and the connection's
--   @sqlite3_errmsg@ text.
-- * An unrecognised code is thrown as 'UnknownSqliteError'.
checkError
  :: Database
  -- ^ Connection the code came from. It supplies the message and filename.
  -> Text
  -- ^ Context stored in 'errorContext'.
  -> CInt
  -- ^ Result code returned by SQLite.
  -> IO ()
checkError (Database conn fileName) ctx err =
  case Bindings.parseError err of
    Bindings.ParseErrorOk -> return ()
    Bindings.ParseErrorStep _ ->
      raise (Right Bug)
        "Squeather.checkError: returned StepResult - should never happen"
    Bindings.ParseErrorError flg -> do
      msg <- sqlite3_errmsg conn >>= readUtf8
      raise (Left flg) msg
    Bindings.ParseErrorNotFound ->
      raise (Right (UnknownSqliteError err))
        (Text.pack ("Squeather.checkError: returned unknown error code " ++ show err))
  where
    raise flag txt = Exception.throwIO (Error ctx flag txt fileName)
-- | Check the result code of @sqlite3_step@.
--
-- A step result (row or done) is returned. Every other outcome throws an
-- 'Error':
--
-- * @SQLITE_OK@ becomes 'Bug'.
-- * A recognised SQLite error is thrown with 'Left' and the connection's
--   @sqlite3_errmsg@ text.
-- * An unrecognised code becomes 'UnknownSqliteError'.
checkStepError
  :: Database
  -- ^ Connection the statement belongs to.
  -> Text
  -- ^ Context stored in 'errorContext'.
  -> CInt
  -- ^ Result code returned by @sqlite3_step@.
  -> IO StepResult
checkStepError (Database conn fileName) ctx err =
  case Bindings.parseError err of
    Bindings.ParseErrorStep r -> return r
    Bindings.ParseErrorOk ->
      raise (Right Bug)
        "Squeather.checkStepError: returned SQLITE_OK - should never happen"
    Bindings.ParseErrorError flag -> do
      msg <- readUtf8 =<< sqlite3_errmsg conn
      raise (Left flag) msg
    Bindings.ParseErrorNotFound ->
      raise (Right (UnknownSqliteError err))
        (Text.pack ("Squeather.checkStepError: returned unknown error code " ++ show err))
  where
    raise flg txt = Exception.throwIO (Error ctx flg txt fileName)
-- | Open a database using the default 'openFlags' (read/write with
-- @Types.Create@).
open
  :: Text
  -- ^ Database filename.
  -> IO Database
open path = openWithFlags openFlags path
-- | Open a database connection with explicit 'OpenFlags'.
--
-- If the open result code is an error, or if enabling extended result codes
-- fails, the handle written by @sqlite3_open_v2@ is closed with
-- @sqlite3_close@ before the 'Error' propagates. On success, extended result
-- codes are enabled on the returned connection.
openWithFlags
  :: OpenFlags
  -- ^ Flags for @sqlite3_open_v2@, converted with @Bindings.flagsToInt@.
  -> Text
  -- ^ Database filename.
  -> IO Database
openWithFlags flags fn =
  writeUtf8 fn $ \cFilename ->
  Foreign.alloca $ \handleOut ->
    Exception.bracketOnError (acquire cFilename handleOut) release configure
  where
    -- Open first and check the code afterwards. That way the handle is
    -- already owned by bracketOnError when a check throws.
    acquire cFilename handleOut = do
      rc <- sqlite3_open_v2 cFilename handleOut (Bindings.flagsToInt flags) Foreign.nullPtr
      handle <- Foreign.peek handleOut
      return (rc, Database handle fn)
    release (_, Database handle _) = sqlite3_close handle
    configure (rc, db) = do
      checkError db "opening database" rc
      sqlite3_extended_result_codes (dbPtr db) 1
        >>= checkError db "setting extended result codes"
      return db
-- | Raw binding to @sqlite3_prepare_v2@. Arguments, in order:
--
-- * the connection;
-- * the UTF-8 SQL;
-- * the SQL length in bytes;
-- * the location for the new statement handle;
-- * per the SQLite C API, a location for the unconsumed tail of the SQL.
--
-- 'prepare' passes 'Foreign.nullPtr' for the tail, so any text after the
-- first statement is not reported back.
foreign import ccall unsafe "sqlite3_prepare_v2" sqlite3_prepare_v2
  :: Ptr C'sqlite3
  -> Ptr CChar
  -> CInt
  -> Ptr (Ptr C'sqlite3_stmt)
  -> Ptr (Ptr CChar)
  -> IO CInt
-- | Compile SQL into a prepared 'Statement' on the given connection.
--
-- If the result code is an error, the statement handle is finalized before
-- the 'Error' propagates. The SQL text is used as the error context.
-- NOTE(review): only the first statement in the text is compiled; any
-- remainder is silently ignored.
prepare
  :: Database
  -- ^ Connection to prepare on.
  -> Text
  -- ^ SQL text.
  -> IO Statement
prepare db@(Database conn fileName) sql =
  writeUtf8Len sql $ \(cSql, byteLen) ->
  Foreign.alloca $ \stmtOut ->
    Exception.bracketOnError (compile cSql byteLen stmtOut) discard wrap
  where
    compile cSql byteLen stmtOut = do
      len <- intToCInt sql fileName byteLen
      rc <- sqlite3_prepare_v2 conn cSql len stmtOut Foreign.nullPtr
      stmt <- Foreign.peek stmtOut
      return (rc, stmt)
    discard (_, stmt) = sqlite3_finalize stmt
    wrap (rc, stmt) = do
      checkError db sql rc
      return (Statement stmt sql db)
-- | Raw binding to @sqlite3_bind_parameter_index@. A result of 0 means no
-- parameter has the given name.
foreign import ccall unsafe "sqlite3_bind_parameter_index" sqlite3_bind_parameter_index
  :: Ptr C'sqlite3_stmt
  -> Ptr CChar
  -> IO CInt
-- | Look up the index SQLite assigns to a named parameter.
--
-- Throws 'Error' with 'ParameterNotFound' when SQLite reports index 0.
-- NOTE(review): presumably the name must include its prefix character
-- (e.g. @:name@), as SQLite expects; confirm against callers.
getParameterIndex
  :: Statement
  -- ^ Statement to search.
  -> Text
  -- ^ Parameter name.
  -> IO CInt
getParameterIndex (Statement stPtr stSql (Database _ dbFn)) param =
  writeUtf8 param $ \cName -> do
    found <- sqlite3_bind_parameter_index stPtr cName
    when (found == 0) $
      throwIO (Error stSql (Right ParameterNotFound) ("parameter not found: " <> param) dbFn)
    return found
-- | Raw binding to @sqlite3_bind_blob@. Unlike the other bind imports in
-- this module, it is declared @safe@.
foreign import ccall safe "sqlite3_bind_blob" sqlite3_bind_blob
  :: Ptr C'sqlite3_stmt
  -> CInt
  -> Ptr a
  -> CInt
  -> FunPtr (Ptr a -> IO ())
  -> IO CInt
-- | Bind a 'ByteString' as a BLOB to a named parameter.
--
-- Throws 'Error' if the name is unknown ('ParameterNotFound'), if the length
-- does not fit a 'CInt', or if SQLite rejects the binding.
bindBlob
  :: Statement
  -- ^ Statement whose parameter is bound.
  -> Text
  -- ^ Parameter name.
  -> ByteString
  -- ^ Bytes to bind.
  -> IO ()
bindBlob st@(Statement stmt sql conn) paramName blob =
  ByteString.useAsCStringLen blob $ \(bytes, byteCount) -> do
    idx <- getParameterIndex st paramName
    len <- intToCInt sql (dbFilename conn) byteCount
    rc <- sqlite3_bind_blob stmt idx bytes len transient
    checkError conn sql rc
  where
    -- Destructor argument built from SQLITE_TRANSIENT. The temporary buffer
    -- from useAsCStringLen is released once the callback returns, so SQLite
    -- must not keep the pointer.
    transient = Foreign.castPtrToFunPtr (Foreign.intPtrToPtr Bindings.c'SQLITE_TRANSIENT)
-- | Raw binding to @sqlite3_bind_double@.
foreign import ccall unsafe "sqlite3_bind_double" sqlite3_bind_double
  :: Ptr C'sqlite3_stmt
  -> CInt
  -> Double
  -> IO CInt
-- | Bind a 'Double' to a named parameter.
--
-- Throws 'Error' if the name is unknown ('ParameterNotFound') or if SQLite
-- rejects the binding.
bindDouble
  :: Statement
  -- ^ Statement whose parameter is bound.
  -> Text
  -- ^ Parameter name.
  -> Double
  -- ^ Value to bind.
  -> IO ()
bindDouble st@(Statement stmt sql conn) paramName dbl = do
  idx <- getParameterIndex st paramName
  rc <- sqlite3_bind_double stmt idx dbl
  checkError conn sql rc
-- | Raw binding to @sqlite3_bind_int64@.
foreign import ccall unsafe "sqlite3_bind_int64" sqlite3_bind_int64
  :: Ptr C'sqlite3_stmt
  -> CInt
  -> Int64
  -> IO CInt
-- | Bind an 'Int64' to a named parameter.
--
-- Throws 'Error' if the name is unknown ('ParameterNotFound') or if SQLite
-- rejects the binding.
bindInt64
  :: Statement
  -- ^ Statement whose parameter is bound.
  -> Text
  -- ^ Parameter name.
  -> Int64
  -- ^ Value to bind.
  -> IO ()
bindInt64 st@(Statement stmt sql conn) paramName int64 =
  getParameterIndex st paramName >>= \idx ->
    sqlite3_bind_int64 stmt idx int64 >>= checkError conn sql
-- | Raw binding to @sqlite3_bind_null@.
foreign import ccall unsafe "sqlite3_bind_null" sqlite3_bind_null
  :: Ptr C'sqlite3_stmt
  -> CInt
  -> IO CInt
-- | Bind SQL NULL to a named parameter.
--
-- Throws 'Error' if the name is unknown ('ParameterNotFound') or if SQLite
-- rejects the binding.
bindNull
  :: Statement
  -- ^ Statement whose parameter is bound.
  -> Text
  -- ^ Parameter name.
  -> IO ()
bindNull st@(Statement stmt sql conn) paramName = do
  idx <- getParameterIndex st paramName
  sqlite3_bind_null stmt idx >>= checkError conn sql
-- | Raw binding to @sqlite3_bind_text@.
foreign import ccall unsafe "sqlite3_bind_text" sqlite3_bind_text
  :: Ptr C'sqlite3_stmt
  -> CInt
  -> Ptr CChar
  -> CInt
  -> FunPtr (Ptr a -> IO ())
  -> IO CInt
-- | Bind 'Text' to a named parameter.
--
-- The text is passed to SQLite as UTF-8 with an explicit byte length.
-- Throws 'Error' if the name is unknown ('ParameterNotFound'), if the length
-- does not fit a 'CInt', or if SQLite rejects the binding.
bindText
  :: Statement
  -- ^ Statement whose parameter is bound.
  -> Text
  -- ^ Parameter name.
  -> Text
  -- ^ Text to bind.
  -> IO ()
bindText st@(Statement stmt sql conn) paramName txt =
  writeUtf8Len txt $ \(utf8, byteCount) -> do
    idx <- getParameterIndex st paramName
    len <- intToCInt sql (dbFilename conn) byteCount
    sqlite3_bind_text stmt idx utf8 len transient >>= checkError conn sql
  where
    -- Destructor argument built from SQLITE_TRANSIENT. The UTF-8 buffer from
    -- writeUtf8Len is released once the callback returns, so SQLite must
    -- not keep the pointer.
    transient = Foreign.castPtrToFunPtr (Foreign.intPtrToPtr Bindings.c'SQLITE_TRANSIENT)
-- | Bind any 'SQLData' value to a named parameter by dispatching to the
-- matching @bind*@ function.
bindSqlData
  :: Statement
  -- ^ Statement whose parameter is bound.
  -> Text
  -- ^ Parameter name.
  -> SQLData
  -- ^ Value to bind.
  -> IO ()
bindSqlData st name SQLNull = bindNull st name
bindSqlData st name (SQLText txt) = bindText st name txt
bindSqlData st name (SQLFloat dbl) = bindDouble st name dbl
bindSqlData st name (SQLInteger i64) = bindInt64 st name i64
bindSqlData st name (SQLBlob blob) = bindBlob st name blob
-- | Raw binding to @sqlite3_step@.
foreign import ccall unsafe "sqlite3_step" sqlite3_step
  :: Ptr C'sqlite3_stmt
  -> IO CInt
-- | Advance a statement by one step and return the 'StepResult'.
-- Any other outcome is thrown as an 'Error' by 'checkStepError'.
step :: Statement -> IO StepResult
step (Statement sPtr sSql db) = sqlite3_step sPtr >>= checkStepError db sSql
-- | Raw binding to @sqlite3_column_count@. Returns the number of columns in
-- the statement's result set.
foreign import ccall unsafe "sqlite3_column_count" sqlite3_column_count
  :: Ptr C'sqlite3_stmt
  -> IO CInt
-- | Raw binding to @sqlite3_column_bytes@. Returns the size in bytes of a
-- column value; 'column' calls it after fetching the blob/text pointer.
foreign import ccall unsafe "sqlite3_column_bytes" sqlite3_column_bytes
  :: Ptr C'sqlite3_stmt
  -> CInt
  -> IO CInt
-- | Raw binding to @sqlite3_column_type@. Returns the type code of a column
-- value, which 'column' decodes with @Bindings.convertCColumnType@.
foreign import ccall unsafe "sqlite3_column_type" sqlite3_column_type
  :: Ptr C'sqlite3_stmt
  -> CInt
  -> IO CInt
-- | Raw binding to @sqlite3_column_blob@. Returns a pointer to the blob
-- bytes; 'column' copies them using the length from @sqlite3_column_bytes@.
foreign import ccall unsafe "sqlite3_column_blob" sqlite3_column_blob
  :: Ptr C'sqlite3_stmt
  -> CInt
  -> IO (Ptr a)
-- | Raw binding to @sqlite3_column_double@.
foreign import ccall unsafe "sqlite3_column_double" sqlite3_column_double
  :: Ptr C'sqlite3_stmt
  -> CInt
  -> IO Double
-- | Raw binding to @sqlite3_column_int64@.
foreign import ccall unsafe "sqlite3_column_int64" sqlite3_column_int64
  :: Ptr C'sqlite3_stmt
  -> CInt
  -> IO Int64
-- | Raw binding to @sqlite3_column_text@. 'column' casts the pointer to
-- 'CChar', copies the bytes and decodes them as UTF-8.
foreign import ccall unsafe "sqlite3_column_text" sqlite3_column_text
  :: Ptr C'sqlite3_stmt
  -> CInt
  -> IO (Ptr CUChar)
-- | Read one column of the current result row, by 0-based index.
--
-- Blob and text values are copied into Haskell memory. Text is decoded as
-- UTF-8, lazily. Throws 'Error' with 'UnknownColumnType' for an
-- unrecognised type code, or with 'IntConversion' if an index or length
-- does not fit.
column
  :: Statement
  -- ^ Statement positioned on a row.
  -> Int
  -- ^ 0-based column index.
  -> IO SQLData
column (Statement stPtr sSql db) intIdx = do
  idx <- toCInt intIdx
  typeCode <- sqlite3_column_type stPtr idx
  colType <- maybe (unknownType typeCode) return (Bindings.convertCColumnType typeCode)
  case colType of
    SQLNull -> return SQLNull
    SQLInteger _ -> SQLInteger <$> sqlite3_column_int64 stPtr idx
    SQLFloat _ -> SQLFloat <$> sqlite3_column_double stPtr idx
    SQLText _ -> do
      bytes <- sqlite3_column_text stPtr idx >>= readBytes idx . Foreign.castPtr
      return (SQLText (Encoding.decodeUtf8 bytes))
    SQLBlob _ -> SQLBlob <$> (sqlite3_column_blob stPtr idx >>= readBytes idx)
  where
    fileName = dbFilename db
    toCInt = intToCInt sSql fileName
    unknownType code = Exception.throwIO $ Error
      { errorContext = sSql
      , errorFlag = Right $ UnknownColumnType code
      , errorText = "Unknown column type found"
      , errorFilename = fileName
      }
    -- Copy a column's bytes. The pointer must already have been fetched:
    -- the byte count is requested only afterwards, matching the original
    -- call order.
    readBytes i ptr = do
      len <- sqlite3_column_bytes stPtr i >>= intFromCInt sSql fileName
      ByteString.packCStringLen (ptr, len)
-- | Number of columns in the statement's result set, as an 'Int'.
columnCount :: Statement -> IO Int
columnCount (Statement stPtr sSql db) = do
  n <- sqlite3_column_count stPtr
  intFromCInt sSql (dbFilename db) n
-- | Read every column of the current result row, in index order.
columns :: Statement -> IO [SQLData]
columns st = do
  count <- columnCount st
  traverse (column st) (enumFromTo 0 (count - 1))
-- | Step the statement until it reports 'Types.Done', collecting the
-- columns of every row in the order SQLite produces them.
allRows :: Statement -> IO [[SQLData]]
allRows st = collect []
  where
    -- Rows are accumulated newest-first and reversed at the end.
    collect acc = do
      outcome <- step st
      case outcome of
        Types.Done -> return (reverse acc)
        Types.Row -> columns st >>= \cols -> collect (cols : acc)
-- | Bind each @(name, value)@ pair to the statement, in list order.
bindParams
  :: Statement
  -> [(Text, SQLData)]
  -> IO ()
bindParams st params = mapM_ bindOne params
  where
    bindOne (name, value) = bindSqlData st name value
-- | Prepare a single SQL statement, run it to completion and return all
-- rows. The statement is finalized even if an exception is thrown.
execute
  :: Database
  -> Text
  -> IO [[SQLData]]
execute db sql = Exception.bracket acquire finalize run
  where
    acquire = prepare db sql
    run = allRows
-- | Like 'execute', but binds the named parameters before stepping. The
-- statement is finalized even if an exception is thrown.
executeNamed
  :: Database
  -> Text
  -> [(Text, SQLData)]
  -> IO [[SQLData]]
executeNamed db sql params =
  Exception.bracket (prepare db sql) finalize $ \stmt ->
    bindParams stmt params >> allRows stmt
-- | Like 'executeNamed', but also returns the result column names. Names
-- are read after all rows have been stepped. The statement is finalized
-- even if an exception is thrown.
executeNamedWithColumns
  :: Database
  -> Text
  -> [(Text, SQLData)]
  -> IO ([Text], [[SQLData]])
executeNamedWithColumns db sql params =
  Exception.bracket (prepare db sql) finalize $ \stmt -> do
    bindParams stmt params
    resultRows <- allRows stmt
    headers <- columnNames stmt
    return (headers, resultRows)
-- | Raw binding to @sqlite3_reset@.
foreign import ccall unsafe "sqlite3_reset" sqlite3_reset
  :: Ptr C'sqlite3_stmt
  -> IO CInt
-- | Reset a statement so it can be stepped again. The result code of
-- @sqlite3_reset@ is discarded.
-- NOTE(review): per the SQLite docs, that code repeats the error from the
-- last failed step, which 'step' would already have thrown. Confirm this is
-- why it is ignored.
reset :: Statement -> IO ()
reset (Statement stPtr _ _) = () <$ sqlite3_reset stPtr
-- | Raw binding to @sqlite3_clear_bindings@.
foreign import ccall unsafe "sqlite3_clear_bindings" sqlite3_clear_bindings
  :: Ptr C'sqlite3_stmt
  -> IO CInt
-- | Clear every parameter binding on the statement. Throws 'Error' if
-- SQLite reports a failure.
clearBindings :: Statement -> IO ()
clearBindings (Statement stPtr _ db) = do
  rc <- sqlite3_clear_bindings stPtr
  checkError db "clearing bindings" rc
-- | Raw binding to @sqlite3_finalize@. It is also used directly by
-- 'prepare' for cleanup.
foreign import ccall unsafe "sqlite3_finalize" sqlite3_finalize
  :: Ptr C'sqlite3_stmt
  -> IO CInt
-- | Destroy a prepared statement. The result code of @sqlite3_finalize@ is
-- discarded.
finalize :: Statement -> IO ()
finalize (Statement stPtr _ _) = () <$ sqlite3_finalize stPtr
-- | Raw binding to @sqlite3_close@. 'openWithFlags' uses it to release a
-- handle when opening fails.
foreign import ccall unsafe "sqlite3_close" sqlite3_close
  :: Ptr C'sqlite3
  -> IO CInt
-- | Raw binding to @sqlite3_close_v2@, used by 'close'.
foreign import ccall unsafe "sqlite3_close_v2" sqlite3_close_v2
  :: Ptr C'sqlite3
  -> IO CInt
-- | Close a connection with @sqlite3_close_v2@. Throws 'Error' if SQLite
-- reports a failure.
close :: Database -> IO ()
close db@(Database conn _) = checkError db "close database" =<< sqlite3_close_v2 conn
-- | Shape of the callback accepted by @sqlite3_exec@. Per the SQLite C API,
-- its arguments are the user-data pointer, the column count, the column
-- values and the column names, and it returns a status code.
-- NOTE(review): 'exec' always passes 'Foreign.nullFunPtr', so no callback of
-- this type is invoked from this module.
type ExecCallback a
  = Ptr a
  -> CInt
  -> Ptr (Ptr CChar)
  -> Ptr (Ptr CChar)
  -> IO CInt
-- | Raw binding to @sqlite3_exec@. Arguments, in order: connection, SQL,
-- callback, callback user data, and the location for an error message.
-- It is imported without @unsafe@ (a safe call), unlike most bindings here.
foreign import ccall "sqlite3_exec" sqlite3_exec
  :: Ptr C'sqlite3
  -> Ptr CChar
  -> FunPtr (ExecCallback a)
  -> Ptr a
  -> Ptr (Ptr CChar)
  -> IO CInt
-- | Raw binding to @sqlite3_free@. 'exec' uses it to release the error
-- message written by @sqlite3_exec@.
foreign import ccall unsafe "sqlite3_free" sqlite3_free
  :: Ptr a
  -> IO ()
-- | Run SQL text with @sqlite3_exec@. No callback is passed, so no result
-- rows are delivered.
--
-- * If SQLite writes an error message, it is thrown as an 'Error' flagged
--   'ExecFailed'.
-- * Otherwise the result code is checked with 'checkError'.
--
-- The error-message buffer is always released with @sqlite3_free@.
exec
  :: Database
  -> Text
  -> IO ()
exec db@(Database conn fileName) sqlTxt =
  writeUtf8 sqlTxt $ \cSql ->
    Foreign.alloca $ \errOut -> do
      -- Start from NULL so the cleanup is safe even if SQLite writes nothing.
      Foreign.poke errOut Foreign.nullPtr
      Exception.finally (run cSql errOut) (Foreign.peek errOut >>= sqlite3_free)
  where
    run cSql errOut = do
      rc <- sqlite3_exec conn cSql Foreign.nullFunPtr Foreign.nullPtr errOut
      errPtr <- Foreign.peek errOut
      when (errPtr /= Foreign.nullPtr) $
        readUtf8 errPtr >>= \msg ->
          Exception.throwIO (Error sqlTxt (Right ExecFailed) msg fileName)
      checkError db sqlTxt rc
-- | Raw binding to @sqlite3_last_insert_rowid@.
foreign import ccall unsafe "sqlite3_last_insert_rowid" sqlite3_last_insert_rowid
  :: Ptr C'sqlite3
  -> IO Int64
-- | The rowid reported by @sqlite3_last_insert_rowid@ for this connection.
lastInsertRowId :: Database -> IO Int64
lastInsertRowId db = sqlite3_last_insert_rowid (dbPtr db)
-- | Convert an 'Int' to a 'CInt'. If the value is outside the 'CInt'
-- range, throws an 'Error' flagged 'IntConversion'.
intToCInt
  :: Text
  -- ^ Error context.
  -> Text
  -- ^ Filename stored in 'errorFilename'.
  -> Int
  -- ^ Value to convert.
  -> IO CInt
intToCInt ctx fn i
  | wide > upper = failWith "number too big to convert to CInt: "
  | wide < lower = failWith "number too small to convert to CInt: "
  | otherwise = return (fromIntegral i)
  where
    -- Compare as Integer so neither side can wrap around.
    wide = toInteger i
    upper = toInteger (maxBound :: CInt)
    lower = toInteger (minBound :: CInt)
    failWith prefix = Exception.throwIO $ Error
      { errorContext = ctx
      , errorFlag = Right IntConversion
      , errorText = Text.pack (prefix ++ show i)
      , errorFilename = fn
      }
-- | Convert a 'CInt' to an 'Int'. If the value is outside the 'Int' range,
-- throws an 'Error' flagged 'IntConversion'.
intFromCInt
  :: Text
  -- ^ Error context.
  -> Text
  -- ^ Filename stored in 'errorFilename'.
  -> CInt
  -- ^ Value to convert.
  -> IO Int
intFromCInt ctx fn i
  | wide > upper = failWith "number too big to convert to Int: "
  | wide < lower = failWith "number too small to convert to Int: "
  | otherwise = return (fromIntegral i)
  where
    -- Compare as Integer so neither side can wrap around.
    wide = toInteger i
    upper = toInteger (maxBound :: Int)
    lower = toInteger (minBound :: Int)
    failWith prefix = Exception.throwIO $ Error
      { errorContext = ctx
      , errorFlag = Right IntConversion
      , errorText = Text.pack (prefix ++ show i)
      , errorFilename = fn
      }
-- | SQLite version string, taken from @Bindings.c'SQLITE_VERSION@.
-- NOTE(review): presumably the compile-time @SQLITE_VERSION@ macro rather
-- than the runtime library version; confirm in the Bindings module.
sqliteVersion :: String
sqliteVersion = Bindings.c'SQLITE_VERSION
-- | Default flags used by 'open': read/write with create, full mutex, and
-- every other option switched off.
openFlags :: OpenFlags
openFlags = Types.OpenFlags
  { Types.writeMode = Types.ReadWrite Types.Create
  , Types.fullMutex = True
  , Types.noMutex = False
  , Types.sharedCache = False
  , Types.privateCache = False
  , Types.uri = False
  , Types.memory = False
  , Types.noFollow = False
  }
-- | Opaque stand-in for the C @sqlite3_backup@ object; only used behind a
-- 'Ptr'.
data C'sqlite3_backup
-- | Raw binding to @sqlite3_backup_init@. As called in 'backup', the
-- arguments are: destination connection, destination database name, source
-- connection, source database name. Returns the backup handle.
foreign import ccall unsafe "sqlite3_backup_init" sqlite3_backup_init
  :: Ptr C'sqlite3
  -> Ptr CChar
  -> Ptr C'sqlite3
  -> Ptr CChar
  -> IO (Ptr C'sqlite3_backup)
-- | Raw binding to @sqlite3_backup_step@. The second argument is a page
-- count; 'backup' passes @-1@. Returns a result code.
foreign import ccall unsafe "sqlite3_backup_step" sqlite3_backup_step
  :: Ptr C'sqlite3_backup
  -> CInt
  -> IO CInt
-- | Raw binding to @sqlite3_backup_finish@. 'backup' uses it as the release
-- action for the backup handle.
foreign import ccall unsafe "sqlite3_backup_finish" sqlite3_backup_finish
  :: Ptr C'sqlite3_backup
  -> IO CInt
-- | Raw binding to @sqlite3_backup_remaining@.
-- NOTE(review): not used in the visible part of this module.
foreign import ccall unsafe "sqlite3_backup_remaining" sqlite3_backup_remaining
  :: Ptr C'sqlite3_backup
  -> IO CInt
-- | Raw binding to @sqlite3_backup_pagecount@.
-- NOTE(review): not used in the visible part of this module.
foreign import ccall unsafe "sqlite3_backup_pagecount" sqlite3_backup_pagecount
  :: Ptr C'sqlite3_backup
  -> IO CInt
-- | The database to copy from in a 'backup'.
data Source = Source
  { sourceConnection :: Database
  -- ^ Connection to read from.
  , sourceName :: Text
  -- ^ Database name on that connection, passed to @sqlite3_backup_init@.
  -- NOTE(review): presumably @main@ for the primary database, per the
  -- SQLite backup API; confirm.
  } deriving (Eq, Ord, Show)
-- | The database to copy into in a 'backup'.
data Destination = Destination
  { destConnection :: Database
  -- ^ Connection to write to.
  , destName :: Text
  -- ^ Database name on that connection, passed to @sqlite3_backup_init@.
  } deriving (Eq, Ord, Show)
-- | Raw binding to @sqlite3_errcode@. Returns the result code of the most
-- recent failed API call on a connection. 'backup' uses it to find out why
-- @sqlite3_backup_init@ returned NULL.
foreign import ccall unsafe "sqlite3_errcode" sqlite3_errcode
  :: Ptr C'sqlite3
  -> IO CInt

-- | Copy a whole database from one connection to another with SQLite's
-- online backup API, in a single @sqlite3_backup_step(-1)@ call.
--
-- The backup handle is always released with @sqlite3_backup_finish@; that
-- result code is not checked. Throws 'Error' when:
--
-- * @sqlite3_backup_init@ fails. The error recorded on the destination
--   connection is reported via 'checkError'.
-- * The step returns @SQLITE_OK@ instead of finishing ('IncompleteBackup').
-- * The step returns an SQLite error code or an unrecognised code.
--
-- For errors raised from the step, both 'errorContext' and 'errorFilename'
-- hold a description naming the source and destination files.
backup :: Source -> Destination -> IO ()
backup src dest = Exception.bracket acq rel use
  where
    srcDb = sourceConnection src
    destDb = destConnection dest
    acq = do
      bkpPtr <- writeUtf8 (sourceName src) $ \ptrSrcName ->
        writeUtf8 (destName dest) $ \ptrDestName ->
          sqlite3_backup_init (dbPtr destDb) ptrDestName
            (dbPtr srcDb) ptrSrcName
      -- On failure sqlite3_backup_init returns NULL and records the error on
      -- the destination connection. Passing NULL on to sqlite3_backup_step
      -- would dereference it, so report the failure here. Nothing has been
      -- acquired yet, so 'rel' does not need to run.
      when (bkpPtr == Foreign.nullPtr) $ do
        sqlite3_errcode (dbPtr destDb) >>= checkError destDb ctx
        -- Only reached if no error code was recorded on the connection.
        Exception.throwIO $ Error
          { errorContext = ctx
          , errorFlag = Right Bug
          , errorText = "Squeather.backup: sqlite3_backup_init returned NULL without an error code"
          , errorFilename = ctx
          }
      return bkpPtr
    rel = sqlite3_backup_finish
    use bkpPtr = do
      -- -1 asks for all remaining pages in one step.
      code <- sqlite3_backup_step bkpPtr (-1)
      case Bindings.parseError code of
        Bindings.ParseErrorStep Types.Done -> return ()
        Bindings.ParseErrorOk -> Exception.throwIO $ Error
          { errorContext = ctx
          , errorFlag = Right IncompleteBackup
          , errorText = "Squeather.backup: backup did not complete"
          , errorFilename = ctx
          }
        Bindings.ParseErrorStep Types.Row -> Exception.throwIO $ Error
          { errorContext = ctx
          , errorFlag = Right Bug
          , errorText = "Squeather.backup: returned Row StepResult - should never happen"
          , errorFilename = ctx
          }
        Bindings.ParseErrorError flg -> Exception.throwIO $ Error
          { errorContext = ctx
          , errorFlag = Left flg
          , errorText = "Squeather.backup: error during backup"
          , errorFilename = ctx
          }
        Bindings.ParseErrorNotFound -> Exception.throwIO $ Error
          { errorContext = ctx
          , errorFlag = Right $ UnknownSqliteError code
          , errorText = "Squeather.backup: error during backup - code not found"
          , errorFilename = ctx
          }
    ctx = "during backup from " <> dbFilename srcDb <> " to "
      <> dbFilename destDb
-- | Raw binding to @sqlite3_changes@.
foreign import ccall unsafe "sqlite3_changes" sqlite3_changes
  :: Ptr C'sqlite3
  -> IO CInt
-- | Row-change count reported by @sqlite3_changes@ for this connection,
-- converted to 'Int'.
changes :: Database -> IO Int
changes (Database conn fileName) = do
  n <- sqlite3_changes conn
  intFromCInt "changes" fileName n
-- | Raw binding to @sqlite3_column_name@. 'columnName' treats a NULL result
-- as an error.
foreign import ccall unsafe "sqlite3_column_name" sqlite3_column_name
  :: Ptr C'sqlite3_stmt
  -> CInt
  -> IO (Ptr CChar)
-- | Name of a result column, by 0-based index.
--
-- Throws 'Error' flagged 'ColumnNameNull' if SQLite returns NULL, or flagged
-- 'IntConversion' if the index does not fit a 'CInt'.
columnName
  :: Statement
  -> Int
  -> IO Text
columnName (Statement stPtr stSql db) idx = do
  cIdx <- intToCInt ("getting column name in " <> stSql) (dbFilename db) idx
  namePtr <- sqlite3_column_name stPtr cIdx
  if namePtr /= Foreign.nullPtr
    then readUtf8 namePtr
    else throwIO $ Error
      { errorContext = stSql
      , errorFlag = Right (ColumnNameNull idx)
      , errorText = Text.pack ("null pointer returned when getting column name for index " ++ show idx)
      , errorFilename = dbFilename db
      }
-- | Names of all result columns, in index order.
columnNames :: Statement -> IO [Text]
columnNames stmt =
  columnCount stmt >>= \count ->
    traverse (columnName stmt) (enumFromTo 0 (count - 1))