-- | SQLite backend: exposes 'withSqlitePool' and 'withSqliteConn' and
-- re-exports "Database.Persist" and "Database.Persist.GenericSql".
module Database.Persist.Sqlite
( withSqlitePool
, withSqliteConn
, module Database.Persist
, module Database.Persist.GenericSql
) where
import Database.Persist
import Database.Persist.Base
import Database.Persist.GenericSql
import Database.Persist.GenericSql.Internal
import qualified Database.Sqlite as Sqlite
import Control.Monad.IO.Class (MonadIO (..))
import Data.List (intercalate)
import Data.IORef
import qualified Data.Map as Map
import Control.Monad.IO.Control (MonadControlIO)
import Control.Exception.Control (finally)
import Data.Text (Text, pack, unpack)
-- | Run an action with a pool of SQLite connections.
--
-- Every connection in the pool is produced by @open'@ from the given
-- database identifier; the remaining arguments are handed to
-- 'withSqlPool' unchanged.
withSqlitePool :: MonadControlIO m
               => Text -- ^ database identifier given to @Sqlite.open@ (presumably a file path -- confirm)
               -> Int -- ^ forwarded to 'withSqlPool' (presumably the pool size -- confirm)
               -> (ConnectionPool -> m a) -- ^ action to run with the pool
               -> m a
withSqlitePool connStr size action = withSqlPool (open' connStr) size action
-- | Run an action with a single SQLite connection created by @open'@
-- from the given database identifier; connection handling is delegated
-- to 'withSqlConn'.
withSqliteConn :: MonadControlIO m => Text -> (Connection -> m a) -> m a
withSqliteConn connStr = withSqlConn (open' connStr)
-- | Open a SQLite database and wrap it in a generic-SQL 'Connection'.
--
-- The returned record starts with an empty statement map and wires the
-- SQLite-specific implementations (statement preparation, insert SQL,
-- migration, name escaping) into the backend-independent interface.
open' :: Text -> IO Connection
open' s = do
    conn <- Sqlite.open s
    smap <- newIORef Map.empty
    return Connection
        { prepare = prepare' conn
        , stmtMap = smap
        , insertSql = insertSql'
        , close = Sqlite.close conn
        , migrateSql = migrate'
        , begin = runSimple "BEGIN"
        , commit = runSimple "COMMIT"
        , rollback = runSimple "ROLLBACK"
        , escapeName = escape
        , noLimit = "LIMIT -1"
        }
  where
    -- Obtain the statement for a fixed piece of SQL through the supplied
    -- getter, run it without parameters, then reset it.
    runSimple sql getter =
        getter sql >>= \stmt -> execute stmt [] >> reset stmt
-- | Prepare a SQL string on the given SQLite connection and expose the
-- resulting low-level statement through the generic 'Statement' record.
prepare' :: Sqlite.Connection -> Text -> IO Statement
prepare' conn sql = fmap wrap (Sqlite.prepare conn sql)
  where
    -- Adapt a raw SQLite statement to the backend-independent interface.
    wrap raw = Statement
        { finalize = Sqlite.finalize raw
        , reset = Sqlite.reset raw
        , execute = execute' raw
        , withStmt = withStmt' raw
        }
-- | Build the SQL used to insert a row into a table.
--
-- The result is always 'Right': the first component is the @INSERT@
-- statement with one @?@ placeholder per column, the second is the query
-- that retrieves the row id of the freshly inserted row.
--
-- When the column list is empty the statement is
-- @INSERT INTO \<table\> DEFAULT VALUES@, because SQLite rejects an empty
-- column list together with an empty @VALUES()@ clause as a syntax error.
insertSql' :: RawName -> [RawName] -> Either Text (Text, Text)
insertSql' t cols =
    Right (pack ins, sel)
  where
    sel = "SELECT last_insert_rowid()"
    ins
        -- No columns to name: let SQLite fill in every column's default.
        | null cols = "INSERT INTO " ++ escape t ++ " DEFAULT VALUES"
        | otherwise = concat
            [ "INSERT INTO "
            , escape t
            , "("
            , intercalate "," $ map escape cols
            , ") VALUES("
            , intercalate "," (map (const "?") cols)
            , ")"
            ]
-- | Bind the given values to a statement and step it once, expecting the
-- step to report completion. The statement is reset afterwards whether
-- or not binding or stepping succeeded.
execute' :: Sqlite.Statement -> [PersistValue] -> IO ()
execute' stmt vals =
    run `finally` Sqlite.reset stmt
  where
    run = do
        Sqlite.bind stmt vals
        -- Any step result other than Done fails the pattern match here.
        Sqlite.Done <- Sqlite.step stmt
        return ()
-- | Bind values to a statement and run a consumer that pulls result rows
-- one at a time.
--
-- The consumer receives a \"row popper\": each call steps the statement
-- and yields @Just columns@ for a row or @Nothing@ once the statement is
-- done. The statement is reset when the consumer finishes or throws.
withStmt' :: MonadControlIO m
          => Sqlite.Statement
          -> [PersistValue]
          -> (RowPopper m -> m a)
          -> m a
withStmt' stmt vals f =
    body `finally` liftIO (Sqlite.reset stmt)
  where
    body = liftIO (Sqlite.bind stmt vals) >> f popRow
    -- Advance the statement by one step and report the row, if any.
    popRow = liftIO $ do
        res <- Sqlite.step stmt
        case res of
            Sqlite.Done -> return Nothing
            Sqlite.Row -> fmap Just (Sqlite.columns stmt)
-- | Name of the SQLite column type used for a given 'SqlType' when
-- generating @CREATE TABLE@ statements.
showSqlType :: SqlType -> String
showSqlType sqlType = case sqlType of
    SqlString -> "VARCHAR"
    SqlInt32 -> "INTEGER"
    SqlInteger -> "INTEGER"
    SqlReal -> "REAL"
    SqlDay -> "DATE"
    SqlTime -> "TIME"
    SqlDayTime -> "TIMESTAMP"
    SqlBlob -> "BLOB"
    SqlBool -> "BOOLEAN"
-- | Work out the statements needed to bring an entity's table up to date.
--
-- The @CREATE TABLE@ text stored in @sqlite_master@ is compared with the
-- text this module would generate for the entity:
--
-- * no stored table: a single @CREATE TABLE@ statement;
-- * identical text: nothing to do;
-- * different text: the copy-and-rebuild sequence from @getCopyTable@.
--
-- Always returns 'Right'. The 'Bool' paired with each statement is
-- presumably an \"unsafe\" marker -- confirm against the caller.
migrate' :: PersistEntity val
         => (Text -> IO Statement)
         -> val
         -> IO (Either [Text] [(Bool, Text)])
migrate' getter val = do
    stmt <- getter "SELECT sql FROM sqlite_master WHERE type='table' AND name=?"
    existing <- withStmt stmt [PersistText $ pack $ unRawName table] firstRow
    case existing of
        Nothing -> return $ Right [(False, wanted)]
        Just current
            | current == wanted -> return $ Right []
            | otherwise -> fmap Right (getCopyTable getter val)
  where
    table = rawTableName (entityDef val)
    wanted = mkCreateTable False table (mkColumns val)
    -- Read at most one row and expect it to hold exactly one text column.
    firstRow pop = pop >>= interpret
    interpret Nothing = return Nothing
    interpret (Just [PersistText sql]) = return $ Just sql
    interpret (Just other) =
        error $ "Unexpected result from sqlite_master: " ++ show other
-- | Build the statement sequence that rebuilds an entity's table with its
-- current definition while keeping the data of columns present in both
-- the old and the new layout.
--
-- The sequence is: create a temporary @\<table\>_backup@ table, copy the
-- shared columns (plus @id@) into it, drop the old table, create the new
-- table, copy everything back, drop the temporary table.
--
-- Only the statement dropping the old table can carry 'True'; it does so
-- when the shared columns differ from the old columns.
getCopyTable :: PersistEntity val => (Text -> IO Statement) -> val
             -> IO [(Bool, Sql)]
getCopyTable getter val = do
    stmt <- getter $ pack $ "PRAGMA table_info(" ++ escape table ++ ")"
    existing <- withStmt stmt [] (gather [])
    let oldCols = [RawName (unpack n) | n <- existing, n /= "id"]
        newCols = map cName cols
        common = [c | c <- newCols, c `elem` oldCols]
    return
        [ (False, tmpSql)
        , (False, copyToTemp (idCol : common))
          -- NOTE(review): this is list equality, so it is also True when
          -- the same columns merely appear in a different order.
        , (common /= oldCols, pack dropOld)
        , (False, newSql)
        , (False, copyToFinal (idCol : newCols))
        , (False, pack dropTmp)
        ]
  where
    idCol = RawName "id"
    table = rawTableName (entityDef val)
    tableTmp = RawName (unRawName table ++ "_backup")
    schema = mkColumns val
    cols = fst schema
    newSql = mkCreateTable False table schema
    tmpSql = mkCreateTable True tableTmp schema
    dropTmp = "DROP TABLE " ++ escape tableTmp
    dropOld = "DROP TABLE " ++ escape table
    -- Escaped, comma-separated column list.
    commaSep names = intercalate "," (map escape names)
    -- Collect the second field (the column name) of every result row,
    -- preserving row order.
    gather acc pop = do
        row <- pop
        case row of
            Nothing -> return (reverse acc)
            Just (_:PersistText name:_) -> gather (name : acc) pop
            Just other ->
                error $ "Invalid result from PRAGMA table_info: " ++ show other
    copyToTemp names = pack $
        "INSERT INTO " ++ escape tableTmp
            ++ "(" ++ commaSep names
            ++ ") SELECT " ++ commaSep names
            ++ " FROM " ++ escape table
    copyToFinal names = pack $
        "INSERT INTO " ++ escape table
            ++ " SELECT " ++ commaSep names
            ++ " FROM " ++ escape tableTmp
-- | Render the @CREATE TABLE@ statement for a table.
--
-- The first argument selects a @TEMP@ table. Every table gets an
-- @id INTEGER PRIMARY KEY@ column, followed by the given columns and
-- unique constraints.
mkCreateTable :: Bool -> RawName -> ([Column], [UniqueDef]) -> Sql
mkCreateTable isTemp table (cols, uniqs) =
    pack (header ++ body ++ ")")
  where
    header = "CREATE" ++ tempKeyword ++ " TABLE " ++ escape table
    tempKeyword
        | isTemp = " TEMP"
        | otherwise = ""
    body = "(id INTEGER PRIMARY KEY"
        ++ concatMap sqlColumn cols
        ++ concatMap sqlUnique uniqs
-- | Render one column definition for use inside @CREATE TABLE@, with a
-- leading comma: escaped name, type, nullability, and the optional
-- @DEFAULT@ and @REFERENCES@ clauses.
sqlColumn :: Column -> String
sqlColumn (Column name isNull typ def ref) =
    "," ++ escape name ++ " " ++ showSqlType typ
        ++ nullability ++ defaultClause ++ refClause
  where
    nullability
        | isNull = " NULL"
        | otherwise = " NOT NULL"
    defaultClause = maybe "" (" DEFAULT " ++) def
    -- Only the referenced table name is used; the second component of
    -- the reference is ignored.
    refClause = maybe "" ((" REFERENCES " ++) . escape . fst) ref
-- | Render a named @UNIQUE@ constraint for use inside @CREATE TABLE@,
-- with a leading comma.
sqlUnique :: UniqueDef -> String
sqlUnique (cname, cols) =
    ",CONSTRAINT " ++ escape cname ++ " UNIQUE (" ++ columnList ++ ")"
  where
    columnList = intercalate "," (map escape cols)
-- | Text of a SQL statement, as produced by the statement builders in
-- this module.
type Sql = Text
-- | Quote a raw name as a SQL identifier: wrap it in double quotes and
-- double every double quote inside it.
escape :: RawName -> String
escape (RawName s) =
    '"' : concatMap quoteChar s ++ "\""
  where
    quoteChar '"' = "\"\""
    quoteChar c = [c]