{-# LANGUAGE PackageImports #-}
-- | A sqlite backend for persistent.
module Database.Persist.Sqlite
    ( withSqlitePool
    , withSqliteConn
    , module Database.Persist
    , module Database.Persist.GenericSql
    ) where

import Database.Persist
import Database.Persist.Base
import Database.Persist.GenericSql
import Database.Persist.GenericSql.Internal

import qualified Database.Sqlite as Sqlite

import Control.Monad.IO.Class (MonadIO (..))
import Data.List (intercalate)
import Data.IORef
import qualified Data.Map as Map
import Control.Monad.Invert (MonadInvertIO)

-- | Run an action with a pool of connections, each opened on the given
-- database string via 'open''.
withSqlitePool :: MonadInvertIO m
               => String
               -> Int -- ^ number of connections to open
               -> (ConnectionPool -> m a) -> m a
withSqlitePool s = withSqlPool $ open' s

-- | Run an action with a single connection opened on the given database
-- string via 'open''.
withSqliteConn :: MonadInvertIO m => String -> (Connection -> m a) -> m a
withSqliteConn = withSqlConn . open'

-- | Open the database and build the backend-independent 'Connection'
-- record, wiring every field to its sqlite-specific implementation.
open' :: String -> IO Connection
open' s = do
    conn <- Sqlite.open s
    smap <- newIORef Map.empty
    return Connection
        { prepare = prepare' conn
        , stmtMap = smap
        , insertSql = insertSql'
        , close = Sqlite.close conn
        , migrateSql = migrate'
        , begin = helper "BEGIN"
        , commit = helper "COMMIT"
        , rollback = helper "ROLLBACK"
        , escapeName = escape
          -- SQLite reads a negative LIMIT as "no upper bound".
        , noLimit = "LIMIT -1"
        }
  where
    -- Obtain the statement for a fixed SQL command through the supplied
    -- getter and execute it without parameters.
    helper t getter = do
        stmt <- getter t
        execute stmt []

-- | Prepare a SQL string on the connection and wrap the resulting sqlite
-- statement in the generic 'Statement' record.
prepare' :: Sqlite.Connection -> String -> IO Statement
prepare' conn sql = do
    stmt <- Sqlite.prepare conn sql
    return Statement
        { finalize = Sqlite.finalize stmt
        , reset = Sqlite.reset stmt
        , execute = execute' stmt
        , withStmt = withStmt' stmt
        }

-- | Build the SQL used to insert a row: a parameterised @INSERT@ (one @?@
-- per column) paired with the query that fetches the new row id. Always
-- returns 'Right'.
insertSql' :: RawName -> [RawName] -> Either String (String, String)
insertSql' t cols =
    Right (ins, sel)
  where
    sel = "SELECT last_insert_rowid()"
    ins = concat
        [ "INSERT INTO "
        , escape t
        , "("
        , intercalate "," $ map escape cols
        , ") VALUES("
        , intercalate "," (map (const "?") cols)
        , ")"
        ]

-- | Bind the parameters and step the statement once, expecting it to
-- finish without producing a row.
--
-- Raises an 'IOError' (via 'userError') if the step yields anything other
-- than 'Sqlite.Done'.
execute' :: Sqlite.Statement -> [PersistValue] -> IO ()
execute' stmt vals = do
    Sqlite.bind stmt vals
    res <- Sqlite.step stmt
    case res of
        Sqlite.Done -> return ()
        -- Previously a failable pattern bind; keep failing with an IOError
        -- but say what actually went wrong.
        _ -> ioError $ userError
            "Database.Persist.Sqlite.execute': statement returned a row, expected Done"

-- | Bind the parameters, hand the callback a row popper, then reset the
-- statement and return the callback's result.
--
-- The popper steps the statement on each call: 'Nothing' once the
-- statement is done, otherwise 'Just' the current row's columns.
--
-- NOTE(review): the reset is skipped if the callback throws; no exception
-- handling is visible here.
withStmt' :: MonadInvertIO m
          => Sqlite.Statement
          -> [PersistValue]
          -> (RowPopper m -> m a)
          -> m a
withStmt' stmt vals f = do
    liftIO $ Sqlite.bind stmt vals
    x <- f go
    liftIO $ Sqlite.reset stmt
    return x
  where
    go = liftIO $ do
        x <- Sqlite.step stmt
        case x of
            Sqlite.Done -> return Nothing
            Sqlite.Row -> do
                cols <- Sqlite.columns stmt
                return $ Just cols

-- | Column type name emitted in @CREATE TABLE@ for each 'SqlType'.
showSqlType :: SqlType -> String
showSqlType SqlString = "VARCHAR"
showSqlType SqlInt32 = "INTEGER"
showSqlType SqlInteger = "INTEGER"
showSqlType SqlReal = "REAL"
showSqlType SqlDay = "DATE"
showSqlType SqlTime = "TIME"
showSqlType SqlDayTime = "TIMESTAMP"
showSqlType SqlBlob = "BLOB"
showSqlType SqlBool = "BOOLEAN"

-- | Compute the statements needed to bring the entity's table up to date.
--
-- The stored @CREATE TABLE@ text is read from @sqlite_master@ and compared
-- with the freshly generated one: no table yields a single create
-- statement, identical text yields nothing, and any difference yields the
-- copy-through-temp-table sequence from 'getCopyTable'. Always returns
-- 'Right'.
migrate' :: PersistEntity val
         => (String -> IO Statement)
         -> val
         -> IO (Either [String] [(Bool, String)])
migrate' getter val = do
    let (cols, uniqs) = mkColumns val
    let newSql = mkCreateTable False table (cols, uniqs)
    stmt <- getter $ "SELECT sql FROM sqlite_master WHERE " ++
                     "type='table' AND name=?"
    oldSql' <- withStmt stmt [PersistString $ unRawName table] go
    case oldSql' of
        Nothing -> return $ Right [(False, newSql)]
        Just oldSql ->
            if oldSql == newSql
                then return $ Right []
                else fmap Right $ getCopyTable getter val
  where
    def = entityDef val
    table = rawTableName def
    -- Read at most one row, which must hold exactly one string value.
    go pop = do
        x <- pop
        case x of
            Nothing -> return Nothing
            Just [PersistString y] -> return $ Just y
            Just y -> error $ "Unexpected result from sqlite_master: " ++ show y

-- | Build the statement sequence that rebuilds a table with the new
-- schema: create a temp table, copy the columns shared by old and new
-- schema into it, drop the old table, create the new one, copy everything
-- back, and drop the temp table.
--
-- The 'Bool' paired with each statement is 'False' for every step except
-- the drop of the old table, which is flagged 'True' when at least one
-- existing column is absent from the new schema (its data is not copied).
getCopyTable :: PersistEntity val
             => (String -> IO Statement)
             -> val
             -> IO [(Bool, Sql)]
getCopyTable getter val = do
    stmt <- getter $ "PRAGMA table_info(" ++ escape table ++ ")"
    oldCols' <- withStmt stmt [] getCols
    let oldCols = map RawName $ filter (/= "id") oldCols'
    let newCols = map cName cols
    let common = filter (`elem` oldCols) newCols
    -- Compare by membership rather than list equality, so that merely
    -- reordering columns is not reported as losing data.
    let losesColumns = any (`notElem` newCols) oldCols
    return [ (False, tmpSql)
           , (False, copyToTemp $ RawName "id" : common)
           , (losesColumns, dropOld)
           , (False, newSql)
           , (False, copyToFinal $ RawName "id" : newCols)
           , (False, dropTmp)
           ]
  where
    def = entityDef val
    -- Drain the result set, taking the second value of each row as the
    -- column name.
    getCols pop = do
        x <- pop
        case x of
            Nothing -> return []
            Just (_:PersistString name:_) -> do
                names <- getCols pop
                return $ name : names
            Just y -> error $ "Invalid result from PRAGMA table_info: " ++ show y
    table = rawTableName def
    tableTmp = RawName $ unRawName table ++ "_backup"
    (cols, uniqs) = mkColumns val
    newSql = mkCreateTable False table (cols, uniqs)
    tmpSql = mkCreateTable True tableTmp (cols, uniqs)
    dropTmp = "DROP TABLE " ++ escape tableTmp
    dropOld = "DROP TABLE " ++ escape table
    -- Column names are listed explicitly on both sides of this copy.
    copyToTemp common = concat
        [ "INSERT INTO "
        , escape tableTmp
        , "("
        , intercalate "," $ map escape common
        , ") SELECT "
        , intercalate "," $ map escape common
        , " FROM "
        , escape table
        ]
    -- No target column list here: values are inserted positionally.
    copyToFinal newCols = concat
        [ "INSERT INTO "
        , escape table
        , " SELECT "
        , intercalate "," $ map escape newCols
        , " FROM "
        , escape tableTmp
        ]

-- | Render a @CREATE TABLE@ (or @CREATE TEMP TABLE@ when the flag is set)
-- statement with an @id INTEGER PRIMARY KEY@ column followed by the given
-- columns and unique constraints.
mkCreateTable :: Bool -> RawName -> ([Column], [UniqueDef]) -> Sql
mkCreateTable isTemp table (cols, uniqs) = concat
    [ "CREATE"
    , if isTemp then " TEMP" else ""
    , " TABLE "
    , escape table
    , "(id INTEGER PRIMARY KEY"
    , concatMap sqlColumn cols
    , concatMap sqlUnique uniqs
    , ")"
    ]

-- | Render one column definition, prefixed with a comma. The default
-- expression, when present, is inserted verbatim (not escaped).
sqlColumn :: Column -> String
sqlColumn (Column name isNull typ def ref) = concat
    [ ","
    , escape name
    , " "
    , showSqlType typ
    , if isNull then " NULL" else " NOT NULL"
    , case def of
        Nothing -> ""
        Just d -> " DEFAULT " ++ d
    , case ref of
        Nothing -> ""
        Just (table, _) -> " REFERENCES " ++ escape table
    ]

-- | Render one named @UNIQUE@ constraint, prefixed with a comma.
sqlUnique :: UniqueDef -> String
sqlUnique (cname, cols) = concat
    [ ",CONSTRAINT "
    , escape cname
    , " UNIQUE ("
    , intercalate "," $ map escape cols
    , ")"
    ]

-- | A SQL statement as text.
type Sql = String

-- | Quote an identifier: wrap it in double quotes and double every
-- embedded double quote.
escape :: RawName -> String
escape (RawName s) =
    '"' : go s ++ "\""
  where
    go "" = ""
    go ('"':xs) = "\"\"" ++ go xs
    go (x:xs) = x : go xs