module Database.Persist.Sqlite
( withSqlitePool
, withSqliteConn
, createSqlitePool
, module Database.Persist.Sql
, SqliteConf (..)
, runSqlite
, wrapConnection
) where
import Database.Persist.Sql
import qualified Database.Sqlite as Sqlite
import Control.Monad.IO.Class (MonadIO (..))
import Control.Monad.Logger (NoLoggingT, runNoLoggingT)
import Data.IORef
import qualified Data.Map as Map
import Control.Monad.Trans.Control (control)
import qualified Control.Exception as E
import Data.Text (Text)
import Control.Monad (mzero)
import Data.Aeson
import qualified Data.Text as T
import Data.Conduit
import qualified Data.Conduit.List as CL
import Control.Applicative
import Data.Int (Int64)
import Control.Monad ((>=>))
import Data.Monoid ((<>))
import Control.Monad.Trans.Control (MonadBaseControl)
import Control.Monad.Trans.Resource (ResourceT, MonadResource, runResourceT)
-- | Create a connection pool of the given size. Every pooled connection is
-- opened from the given connection string via 'open''.
createSqlitePool :: MonadIO m => Text -> Int -> m ConnectionPool
createSqlitePool = createSqlPool . open'
-- | Run an action with a connection pool whose connections are opened
-- from the given connection string. Pool lifetime is managed by
-- 'withSqlPool'.
withSqlitePool :: (MonadBaseControl IO m, MonadIO m)
               => Text                      -- ^ connection string
               -> Int                       -- ^ pool size
               -> (ConnectionPool -> m a)   -- ^ action using the pool
               -> m a
withSqlitePool connStr poolSize action =
    withSqlPool (open' connStr) poolSize action
-- | Run an action with a single connection opened from the given
-- connection string. Opening and closing are delegated to 'withSqlConn'.
withSqliteConn :: (MonadBaseControl IO m, MonadIO m)
               => Text -> (Connection -> m a) -> m a
withSqliteConn connStr = withSqlConn (open' connStr)
-- | Open a raw SQLite handle and wrap it as a persistent 'Connection'.
open' :: Text -> IO Connection
open' connStr = Sqlite.open connStr >>= wrapConnection
-- | Wrap a raw "Database.Sqlite" handle in persistent's backend-agnostic
-- 'Connection' record.
--
-- Each transaction verb (@BEGIN@, @COMMIT@, @ROLLBACK@) is run by fetching
-- a statement through the getter persistent supplies, executing it with
-- no parameters, and resetting it. Any exception raised while rolling
-- back is discarded, so rollback is best-effort.
wrapConnection :: Sqlite.Connection -> IO Connection
wrapConnection conn = do
    stmtCache <- newIORef Map.empty
    return Connection
        { connPrepare     = prepare' conn
        , connStmtMap     = stmtCache
        , connInsertSql   = insertSql'
        , connClose       = Sqlite.close conn
        , connMigrateSql  = migrate'
        , connBegin       = runVerb "BEGIN"
        , connCommit      = runVerb "COMMIT"
        , connRollback    = \getStmt ->
              runVerb "ROLLBACK" getStmt
                  `E.catch` (\(_ :: E.SomeException) -> return ())
        , connEscapeName  = escape
        , connNoLimit     = "LIMIT -1"
        , connRDBMS       = "sqlite"
        , connLimitOffset = decorateSQLWithLimitOffset "LIMIT -1"
        }
  where
    -- Fetch the statement for a fixed SQL verb, run it once, then reset it.
    runVerb sqlText getStmt = do
        stmt <- getStmt sqlText
        _ <- stmtExecute stmt []
        stmtReset stmt
-- | Run a 'SqlPersistT' action against one connection opened with
-- 'withSqliteConn'. The action runs inside 'ResourceT' with logging
-- disabled ('NoLoggingT').
runSqlite :: (MonadBaseControl IO m, MonadIO m)
          => Text                                        -- ^ connection string
          -> SqlPersistT (NoLoggingT (ResourceT m)) a    -- ^ database action
          -> m a
runSqlite connstr action =
    runResourceT $ runNoLoggingT $ withSqliteConn connstr $ runSqlConn action
-- | Prepare SQL on the given handle and expose it as a persistent
-- 'Statement' whose operations all delegate to "Database.Sqlite".
prepare' :: Sqlite.Connection -> Text -> IO Statement
prepare' conn sql = toStatement <$> Sqlite.prepare conn sql
  where
    toStatement raw = Statement
        { stmtFinalize = Sqlite.finalize raw
        , stmtReset    = Sqlite.reset conn raw
        , stmtExecute  = execute' conn raw
        , stmtQuery    = withStmt' conn raw
        }
-- | Build the INSERT for an entity.
--
-- Both key styles share the same statement:
-- @INSERT INTO "table"("f1","f2",...) VALUES(?,?,...)@.
--
-- * Entities with an explicit primary key ('entityPrimary' is 'Just')
--   return 'ISRManyKeys' together with the inserted values.
-- * Otherwise the generated row id is read back with
--   @SELECT last_insert_rowid()@ via 'ISRInsertGet'.
insertSql' :: EntityDef SqlType -> [PersistValue] -> InsertSqlResult
insertSql' ent vals =
    case entityPrimary ent of
        Just _  -> ISRManyKeys insertStmt vals
        Nothing -> ISRInsertGet insertStmt "SELECT last_insert_rowid()"
  where
    fields = entityFields ent
    -- Built once so the two branches cannot drift apart.
    insertStmt = T.concat
        [ "INSERT INTO "
        , escape $ entityDB ent
        , "("
        , T.intercalate "," $ map (escape . fieldDB) fields
        , ") VALUES("
        , T.intercalate "," $ map (const "?") fields
        , ")"
        ]
-- | Bind the parameters, step the statement once, and return the count
-- reported by 'Sqlite.changes'. The statement is reset afterwards even
-- if binding or stepping throws.
execute' :: Sqlite.Connection -> Sqlite.Statement -> [PersistValue] -> IO Int64
execute' conn stmt vals = runOnce `finally` Sqlite.reset conn stmt
  where
    runOnce = do
        Sqlite.bind stmt vals
        _ <- Sqlite.step stmt
        Sqlite.changes conn
-- | Stream the rows of a prepared statement.
--
-- Parameters are bound when the resource is acquired, and the statement is
-- reset on release (via 'bracketP'). Each row's columns are yielded until
-- 'Sqlite.step' reports 'Sqlite.Done'.
withStmt'
  :: MonadResource m
  => Sqlite.Connection
  -> Sqlite.Statement
  -> [PersistValue]
  -> Source m [PersistValue]
withStmt' conn stmt vals =
    bracketP bindParams (Sqlite.reset conn) (const nextRow)
  where
    bindParams = do
        Sqlite.bind stmt vals
        return stmt
    nextRow = do
        result <- liftIO (Sqlite.step stmt)
        case result of
            Sqlite.Done -> return ()
            Sqlite.Row  -> liftIO (Sqlite.columns stmt) >>= yield >> nextRow
-- | Render a persistent 'SqlType' as the column type used in SQLite DDL.
showSqlType :: SqlType -> Text
showSqlType sqlType = case sqlType of
    SqlString        -> "VARCHAR"
    SqlInt32         -> "INTEGER"
    SqlInt64         -> "INTEGER"
    SqlReal          -> "REAL"
    SqlNumeric p s   ->
        "NUMERIC(" <> T.pack (show p) <> "," <> T.pack (show s) <> ")"
    SqlDay           -> "DATE"
    SqlTime          -> "TIME"
    SqlDayTimeZoned  -> "TIMESTAMP"
    SqlDayTime       -> "TIMESTAMP"
    SqlBlob          -> "BLOB"
    SqlBool          -> "BOOLEAN"
    SqlOther custom  -> custom
-- | Compute the migration statements for one entity.
--
-- The expected CREATE TABLE text is compared with the @sql@ stored in
-- @sqlite_master@ for the table:
--
-- * no stored table: emit the CREATE TABLE;
-- * identical text: nothing to do;
-- * different text: rebuild the table via 'getCopyTable'.
migrate' :: [EntityDef a]
         -> (Text -> IO Statement)
         -> EntityDef SqlType
         -> IO (Either [Text] [(Bool, Text)])
migrate' allDefs getter val = do
    let (cols, uniqs, _) = mkColumns allDefs val
        keptCols = filter (not . safeToRemove val . cName) cols
        newSql   = mkCreateTable False val (keptCols, uniqs)
    stmt <- getter "SELECT sql FROM sqlite_master WHERE type='table' AND name=?"
    existing <- runResourceT $
        stmtQuery stmt [PersistText $ unDBName $ entityDB val] $$ firstSql
    case existing of
        Nothing -> return $ Right [(False, newSql)]
        Just oldSql
            | oldSql == newSql -> return $ Right []
            | otherwise        -> Right <$> getCopyTable allDefs getter val
  where
    -- Take at most one row; it must be a single text column.
    firstSql = do
        row <- CL.head
        case row of
            Nothing               -> return Nothing
            Just [PersistText sql] -> return $ Just sql
            Just other            ->
                error $ "Unexpected result from sqlite_master: " ++ show other
-- | True when some field of the entity whose DB name equals the given
-- column carries the @SafeToRemove@ attribute.
safeToRemove :: EntityDef a -> DBName -> Bool
safeToRemove def name =
    any (\fd -> fieldDB fd == name && "SafeToRemove" `elem` fieldAttrs fd)
        (entityFields def)
-- | Statements that rebuild an existing table so it matches the entity.
--
-- In order:
--
-- 1. create a TEMP backup table with the new layout;
-- 2. copy the key column (if any) plus the columns common to old and new
--    layouts into the backup;
-- 3. drop the old table. This step is flagged unsafe ('True') unless the
--    carried-over columns equal the old table's non-@SafeToRemove@
--    columns;
-- 4. recreate the table;
-- 5. copy the key column (if any) plus all new columns back;
-- 6. drop the backup.
--
-- The surrogate key column is identified by 'entityID' rather than by
-- assuming it is named @id@. Entities with an explicit primary key
-- ('entityPrimary' is 'Just') have no such column: 'mkCreateTable' does
-- not create it, so it is neither filtered nor copied for them.
getCopyTable :: [EntityDef a]
             -> (Text -> IO Statement)
             -> EntityDef SqlType
             -> IO [(Bool, Text)]
getCopyTable allDefs getter val = do
    stmt <- getter $ T.concat [ "PRAGMA table_info(", escape table, ")" ]
    oldCols' <- runResourceT $ stmtQuery stmt [] $$ getCols
    let oldCols = filter (`notElem` keyCols) $ map DBName oldCols'
    let newCols = filter (not . safeToRemove def) $ map cName cols
    let common = filter (`elem` oldCols) newCols
    return [ (False, tmpSql)
           , (False, copyToTemp $ keyCols ++ common)
           , (common /= filter (not . safeToRemove def) oldCols, dropOld)
           , (False, newSql)
           , (False, copyToFinal $ keyCols ++ newCols)
           , (False, dropTmp)
           ]
  where
    def = val
    -- The auto-generated key column, present only for surrogate-key tables.
    keyCols = case entityPrimary val of
        Nothing -> [entityID val]
        Just _  -> []
    -- Column names are the second field of each PRAGMA table_info row.
    getCols = do
        x <- CL.head
        case x of
            Nothing -> return []
            Just (_:PersistText name:_) -> do
                names <- getCols
                return $ name : names
            Just y -> error $ "Invalid result from PRAGMA table_info: " ++ show y
    table = entityDB def
    tableTmp = DBName $ unDBName table <> "_backup"
    (cols, uniqs, _) = mkColumns allDefs val
    cols' = filter (not . safeToRemove def . cName) cols
    newSql = mkCreateTable False def (cols', uniqs)
    tmpSql = mkCreateTable True def { entityDB = tableTmp } (cols', uniqs)
    dropTmp = "DROP TABLE " <> escape tableTmp
    dropOld = "DROP TABLE " <> escape table
    copyToTemp common = T.concat
        [ "INSERT INTO "
        , escape tableTmp
        , "("
        , T.intercalate "," $ map escape common
        , ") SELECT "
        , T.intercalate "," $ map escape common
        , " FROM "
        , escape table
        ]
    -- Positional copy: relies on the backup having the same column order
    -- as the recreated table (both are built from cols').
    copyToFinal newCols = T.concat
        [ "INSERT INTO "
        , escape table
        , " SELECT "
        , T.intercalate "," $ map escape newCols
        , " FROM "
        , escape tableTmp
        ]
-- | Render the CREATE TABLE (or CREATE TEMP TABLE when the flag is set)
-- statement for an entity from its columns and unique constraints.
--
-- * Explicit primary key ('entityPrimary' is 'Just'): the columns are
--   followed by a table-level PRIMARY KEY clause.
-- * Otherwise: an @INTEGER PRIMARY KEY@ column named by 'entityID' comes
--   first, followed by the columns and the UNIQUE constraints.
--
-- NOTE(review): the explicit-key branch emits no UNIQUE constraints and
-- lists every entity field (not just the key fields) in PRIMARY KEY.
-- Confirm whether that is intended.
mkCreateTable :: Bool -> EntityDef a -> ([Column], [UniqueDef]) -> Text
mkCreateTable isTemp entity (cols, uniqs) = T.concat
    [ "CREATE"
    , if isTemp then " TEMP" else ""
    , " TABLE "
    , escape $ entityDB entity
    , "("
    , tableBody
    , ")"
    ]
  where
    columnDefs = T.concat (map sqlColumn cols)
    tableBody = case entityPrimary entity of
        Just _ -> T.concat
            [ T.drop 1 columnDefs  -- strip the leading comma of the first column
            , ", PRIMARY KEY ("
            , T.intercalate "," (map (escape . fieldDB) (entityFields entity))
            , ")"
            ]
        Nothing -> T.concat
            [ escape (entityID entity)
            , " INTEGER PRIMARY KEY"
            , columnDefs
            , T.concat (map sqlUnique uniqs)
            ]
-- | Render one column definition, including its leading comma: name, type,
-- nullability, optional DEFAULT and optional REFERENCES clause.
sqlColumn :: Column -> Text
sqlColumn (Column colName nullable sqlTy defVal _ _ fkRef) =
    "," <> escape colName <> " " <> showSqlType sqlTy
        <> (if nullable then " NULL" else " NOT NULL")
        <> maybe "" (" DEFAULT " <>) defVal
        <> maybe "" (\(refTable, _) -> " REFERENCES " <> escape refTable) fkRef
-- | Render a named UNIQUE table constraint, including its leading comma.
sqlUnique :: UniqueDef -> Text
sqlUnique (UniqueDef _ constraintName fields _) =
    ",CONSTRAINT " <> escape constraintName <> " UNIQUE ("
        <> T.intercalate "," (map (escape . snd) fields)
        <> ")"
-- | Quote an identifier with double quotes, doubling any embedded double
-- quote characters.
escape :: DBName -> Text
escape (DBName s) = quote <> T.replace quote (quote <> quote) s <> quote
  where
    quote :: Text
    quote = "\""
-- | Settings used to build an SQLite connection pool (see the
-- 'PersistConfig' instance below).
data SqliteConf = SqliteConf
    { sqlDatabase :: Text  -- ^ connection string passed to 'createSqlitePool'
    , sqlPoolSize :: Int   -- ^ number of connections in the pool
    }
instance PersistConfig SqliteConf where
    type PersistConfigBackend SqliteConf = SqlPersistT
    type PersistConfigPool SqliteConf = ConnectionPool

    createPoolConfig (SqliteConf database poolSize) =
        createSqlitePool database poolSize

    runPool _ = runSqlPool

    -- Expects an object with "database" and "poolsize" keys; any other
    -- JSON value fails via 'mzero'.
    loadConfig value = case value of
        Object o -> SqliteConf <$> o .: "database" <*> o .: "poolsize"
        _        -> mzero
-- | 'E.finally' lifted to any monad with a 'MonadBaseControl' 'IO'
-- instance: the cleanup action runs after the main action, including when
-- the main action throws.
finally :: MonadBaseControl IO m
        => m a   -- ^ main action
        -> m b   -- ^ cleanup action
        -> m a
finally action cleanup =
    control $ \runInBase -> runInBase action `E.finally` runInBase cleanup