{-# LANGUAGE PackageImports #-}
-- | A PostgreSQL backend for persistent, implemented on top of HDBC.
module Database.Persist.Postgresql
    ( withPostgresqlPool
    , withPostgresqlConn
    , module Database.Persist
    , module Database.Persist.GenericSql
    ) where

import Database.Persist
import Database.Persist.Base
import Database.Persist.GenericSql
import Database.Persist.GenericSql.Internal

import qualified Database.HDBC as H
import qualified Database.HDBC.PostgreSQL as H

import Control.Monad.IO.Class (MonadIO (..))
import Data.List (groupBy, intercalate, sort)
import Data.IORef
import qualified Data.Map as Map
import Data.Either (partitionEithers)
import Control.Arrow
import Data.Function (on)
import qualified Data.ByteString.UTF8 as BSU
import Control.Monad.Invert (MonadInvertIO)

-- | Build a connection pool from a PostgreSQL connection string and run the
-- given action with it. Each pooled connection is opened with 'open''.
withPostgresqlPool :: MonadInvertIO m
                   => String -- ^ HDBC PostgreSQL connection string
                   -> Int -- ^ number of connections to open
                   -> (ConnectionPool -> m a)
                   -> m a
withPostgresqlPool s = withSqlPool $ open' s

-- | Open a single connection from a PostgreSQL connection string and run the
-- given action with it.
withPostgresqlConn :: MonadInvertIO m => String -> (Connection -> m a) -> m a
withPostgresqlConn = withSqlConn . open'

-- | Connect through HDBC and wrap the handle in persistent's 'Connection'
-- record.
open' :: String -> IO Connection
open' s = do
    conn <- H.connectPostgreSQL s
    smap <- newIORef Map.empty
    return Connection
        { prepare = prepare' conn
        , stmtMap = smap
        , insertSql = insertSql'
        , close = H.disconnect conn
        , migrateSql = migrate'
          -- 'begin' is a no-op; commit/rollback go straight to HDBC.
          -- NOTE(review): presumably relies on HDBC keeping a transaction
          -- implicitly open -- confirm against HDBC docs.
        , begin = const $ return ()
        , commit = const $ H.commit conn
        , rollback = const $ H.rollback conn
        , escapeName = escape
        , noLimit = "LIMIT ALL"
        }

-- | Prepare an SQL string with HDBC and expose it as a persistent
-- 'Statement'. 'finalize' and 'reset' do nothing for this backend.
prepare' :: H.Connection -> String -> IO Statement
prepare' conn sql = do
    stmt <- H.prepare conn sql
    return Statement
        { finalize = return ()
        , reset = return ()
        , execute = execute' stmt
        , withStmt = withStmt' stmt
        }

-- | Build a single INSERT statement that uses @RETURNING id@ to report the
-- new row's key. It is always returned in the 'Left' form.
-- NOTE(review): the meaning of Left vs Right is defined by
-- Database.Persist.GenericSql -- confirm there.
insertSql' :: RawName -> [RawName] -> Either String (String, String)
insertSql' t cols = Left $ concat
    [ "INSERT INTO "
    , escape t
    , "("
    , intercalate "," $ map escape cols
    , ") VALUES("
    , intercalate "," (map (const "?") cols)
    , ") RETURNING id"
    ]

-- | Execute a prepared statement with the given parameters, discarding the
-- row count HDBC returns.
execute' :: H.Statement -> [PersistValue] -> IO ()
execute' stmt vals = do
    _ <- H.execute stmt $ map pToSql vals
    return ()

-- | Execute a prepared statement, then hand the callback an action that
-- fetches the next result row (converted to 'PersistValue's), or 'Nothing'
-- once the rows are exhausted.
withStmt' :: MonadInvertIO m
          => H.Statement
          -> [PersistValue]
          -> (RowPopper m -> m a)
          -> m a
withStmt' stmt vals f = do
    _ <- liftIO $ H.execute stmt $ map pToSql vals
    f $ liftIO $ (fmap . fmap) (map pFromSql) $ H.fetchRow stmt

-- | Convert a persistent value to an HDBC parameter value.
pToSql :: PersistValue -> H.SqlValue
pToSql (PersistString s) = H.SqlString s
pToSql (PersistByteString bs) = H.SqlByteString bs
pToSql (PersistInt64 i) = H.SqlInt64 i
pToSql (PersistDouble d) = H.SqlDouble d
pToSql (PersistBool b) = H.SqlBool b
pToSql (PersistDay d) = H.SqlLocalDate d
pToSql (PersistTimeOfDay t) = H.SqlLocalTimeOfDay t
pToSql (PersistUTCTime t) = H.SqlUTCTime t
pToSql PersistNull = H.SqlNull

-- | Convert an HDBC result value to a persistent value.
-- Every integral HDBC representation (and 'H.SqlChar', via its code point)
-- becomes 'PersistInt64'. Anything not listed is forced to a string.
pFromSql :: H.SqlValue -> PersistValue
pFromSql (H.SqlString s) = PersistString s
pFromSql (H.SqlByteString bs) = PersistByteString bs
pFromSql (H.SqlWord32 i) = PersistInt64 $ fromIntegral i
pFromSql (H.SqlWord64 i) = PersistInt64 $ fromIntegral i
pFromSql (H.SqlInt32 i) = PersistInt64 $ fromIntegral i
pFromSql (H.SqlInt64 i) = PersistInt64 $ fromIntegral i
pFromSql (H.SqlInteger i) = PersistInt64 $ fromIntegral i
pFromSql (H.SqlChar c) = PersistInt64 $ fromIntegral $ fromEnum c
pFromSql (H.SqlBool b) = PersistBool b
pFromSql (H.SqlDouble b) = PersistDouble b
pFromSql (H.SqlRational b) = PersistDouble $ fromRational b
pFromSql (H.SqlLocalDate d) = PersistDay d
pFromSql (H.SqlLocalTimeOfDay d) = PersistTimeOfDay d
pFromSql (H.SqlUTCTime d) = PersistUTCTime d
pFromSql H.SqlNull = PersistNull
pFromSql x = PersistString $ H.fromSql x -- FIXME

-- | Compute the SQL needed to bring the table for an entity in line with
-- its definition.
--
-- Returns 'Left' with every introspection error if any existing column could
-- not be understood. Otherwise returns 'Right' with a list of statements; the
-- 'Bool' attached to each is 'True' only for column drops (see
-- 'showAlterDb'). If the table has no columns or constraints yet, a
-- CREATE TABLE plus its unique constraints is produced; otherwise a diff
-- against the live schema.
migrate' :: PersistEntity val
         => (String -> IO Statement)
         -> val
         -> IO (Either [String] [(Bool, String)])
migrate' getter val = do
    let name = rawTableName $ entityDef val
    old <- getColumns getter name
    case partitionEithers old of
        ([], old'') -> do
            let old' = partitionEithers old''
            let new = mkColumns val
            if null old
                then do
                    let addTable = AddTable $ concat
                            [ "CREATE TABLE "
                            , escape name
                            , "(id SERIAL PRIMARY KEY UNIQUE"
                            , concatMap (\x -> ',' : showColumn x) $ fst new
                            , ")"
                            ]
                    let rest = flip concatMap (snd new) $ \(uname, ucols) ->
                            [AlterTable name $ AddUniqueConstraint uname ucols]
                    return $ Right $ map showAlterDb $ addTable : rest
                else do
                    let (acs, ats) = getAlters new old'
                    let acs' = map (AlterColumn name) acs
                    let ats' = map (AlterTable name) ats
                    return $ Right $ map showAlterDb $ acs' ++ ats'
        (errs, _) -> return $ Left errs

-- | A single change to one column.
data AlterColumn = Type SqlType
                 | IsNull
                 | NotNull
                 | Add Column
                 | Drop
                 | Default String
                 | NoDefault
                 | Update String
                 | AddReference RawName
                 | DropReference RawName

-- | A column change paired with the name of the column it applies to.
type AlterColumn' = (RawName, AlterColumn)

-- | A table-level change (unique constraints).
data AlterTable = AddUniqueConstraint RawName [RawName]
                | DropConstraint RawName

-- | Any schema change this backend can emit.
data AlterDB = AddTable String
             | AlterColumn RawName AlterColumn'
             | AlterTable RawName AlterTable

-- | Returns all of the columns in the given table currently in the database,
-- followed by the constraints touching its non-id columns (grouped by
-- constraint name). Unparseable columns are reported as 'Left'.
--
-- NOTE(review): both queries filter on @table_name@ only, not on schema, so
-- same-named tables in other schemas would be merged in -- confirm.
-- NOTE(review): @constraint_column_usage@ is not limited to UNIQUE
-- constraints; anything it reports is treated as a unique constraint here.
getColumns :: (String -> IO Statement)
           -> RawName
           -> IO [Either String (Either Column UniqueDef)]
getColumns getter name = do
    stmt <- getter $ "SELECT column_name,is_nullable,udt_name,column_default "
                  ++ "FROM information_schema.columns "
                  ++ "WHERE table_name=? AND column_name <> 'id'"
    cs <- withStmt stmt [PersistString $ unRawName name] helper
    stmt' <- getter $ concat
        [ "SELECT constraint_name, column_name "
        , "FROM information_schema.constraint_column_usage "
        , "WHERE table_name=? AND column_name <> 'id' "
        , "ORDER BY constraint_name, column_name"
        ]
    us <- withStmt stmt' [PersistString $ unRawName name] helperU
    return $ cs ++ us
  where
    -- Collect (constraint, column) pairs in row order using a difference list.
    getAll pop front = do
        x <- pop
        case x of
            Nothing -> return $ front []
            Just [PersistByteString con, PersistByteString col] ->
                getAll pop (front . (:) (BSU.toString con, BSU.toString col))
            Just _ -> getAll pop front -- FIXME error message?
    helperU pop = do
        rows <- getAll pop id
        return $ concatMap toUnique $ groupBy ((==) `on` fst) rows
    -- groupBy never yields an empty group; matching it explicitly keeps this
    -- total instead of relying on 'head'.
    toUnique [] = []
    toUnique grp@((con, _) : _) =
        [Right $ Right (RawName con, map (RawName . snd) grp)]
    -- Parse every column row; errors are kept per column rather than aborting.
    helper pop = do
        x <- pop
        case x of
            Nothing -> return []
            Just x' -> do
                col <- getColumn getter name x'
                let col' = case col of
                        Left e -> Left e
                        Right c -> Right $ Left c
                cols <- helper pop
                return $ col' : cols

-- | Diff the desired (columns, uniques) against the existing ones.
-- Existing columns/constraints with no desired counterpart are dropped.
getAlters :: ([Column], [UniqueDef])
          -> ([Column], [UniqueDef])
          -> ([AlterColumn'], [AlterTable])
getAlters (c1, u1) (c2, u2) =
    (getAltersC c1 c2, getAltersU u1 u2)
  where
    getAltersC [] old = map (\x -> (cName x, Drop)) old
    getAltersC (new:news) old =
        let (alters, old') = findAlters new old
         in alters ++ getAltersC news old'
    getAltersU [] old = map (DropConstraint . fst) old
    getAltersU ((name, cols):news) old =
        case lookup name old of
            Nothing -> AddUniqueConstraint name cols : getAltersU news old
            Just ocols ->
                let old' = filter (\(x, _) -> x /= name) old
                 -- Sort both sides: ocols arrives in the database's ORDER BY
                 -- collation, which need not agree with Haskell's Ord, so
                 -- comparing against it unsorted could rebuild an unchanged
                 -- constraint.
                 in if sort cols == sort ocols
                        then getAltersU news old'
                        else DropConstraint name
                           : AddUniqueConstraint name cols
                           : getAltersU news old'

-- | Turn one @information_schema.columns@ row
-- (name, is_nullable, udt_name, default) into a 'Column'. An extra query
-- checks whether the foreign-key constraint named by 'refName' exists.
getColumn :: (String -> IO Statement)
          -> RawName
          -> [PersistValue]
          -> IO (Either String Column)
getColumn getter tname [PersistByteString x, PersistByteString y, PersistByteString z, d] =
    case d' of
        Left s -> return $ Left s
        Right d'' ->
            case getType $ BSU.toString z of
                Left s -> return $ Left s
                Right t -> do
                    let cname = RawName $ BSU.toString x
                    ref <- getRef cname
                    return $ Right $ Column cname (BSU.toString y == "YES") t d'' ref
  where
    -- The referenced table is not looked up; only the constraint name is
    -- recorded, with an empty table name.
    getRef cname = do
        let sql = concat
                [ "SELECT COUNT(*) FROM "
                , "information_schema.table_constraints "
                , "WHERE table_name=? "
                , "AND constraint_type='FOREIGN KEY' "
                , "AND constraint_name=?"
                ]
        let ref = refName tname cname
        stmt <- getter sql
        withStmt stmt
            [ PersistString $ unRawName tname
            , PersistString $ unRawName ref
            ] $ \pop -> do
                row <- pop
                case row of
                    Just [PersistInt64 i] ->
                        return $ if i == 0 then Nothing else Just (RawName "", ref)
                    _ -> fail $ "Unexpected result from foreign key lookup: "
                             ++ show row
    d' = case d of
        PersistNull -> Right Nothing
        PersistByteString a -> Right $ Just $ BSU.toString a
        _ -> Left $ "Invalid default column: " ++ show d
    -- Must cover every udt_name produced by 'showSqlType'.
    getType "int4" = Right SqlInt32
    getType "int8" = Right SqlInteger
    getType "varchar" = Right SqlString
    getType "date" = Right SqlDay
    getType "bool" = Right SqlBool
    getType "time" = Right SqlTime
    getType "timestamp" = Right SqlDayTime
    getType "float4" = Right SqlReal
    getType "float8" = Right SqlReal
    getType "bytea" = Right SqlBlob
    getType a = Left $ "Unknown type: " ++ a
getColumn _ _ x =
    return $ Left $ "Invalid result from information_schema: " ++ show x

-- | Compare a desired column with the existing columns. Returns the alters
-- needed for it, plus the existing columns with any same-named entry removed.
findAlters :: Column -> [Column] -> ([AlterColumn'], [Column])
findAlters col@(Column name isNull type_ def ref) cols =
    case filter (\c -> cName c == name) cols of
        [] -> ([(name, Add col)], cols)
        Column _ isNull' type_' def' ref' : _ ->
            let refDrop Nothing = []
                refDrop (Just (_, cname)) = [(name, DropReference cname)]
                refAdd Nothing = []
                refAdd (Just (tname, _)) = [(name, AddReference tname)]
                -- References are compared by constraint name only.
                modRef =
                    if fmap snd ref == fmap snd ref'
                        then []
                        else refDrop ref' ++ refAdd ref
                modNull = case (isNull, isNull') of
                    (True, False) -> [(name, IsNull)]
                    -- Before adding NOT NULL, fill existing NULLs with the
                    -- default, if there is one.
                    (False, True) ->
                        let up = case def of
                                Nothing -> id
                                Just s -> (:) (name, Update s)
                         in up [(name, NotNull)]
                    _ -> []
                modType = if type_ == type_' then [] else [(name, Type type_)]
                modDef =
                    if def == def'
                        then []
                        else case def of
                            Nothing -> [(name, NoDefault)]
                            Just s -> [(name, Default s)]
             in ( modRef ++ modDef ++ modNull ++ modType
                , filter (\c -> cName c /= name) cols
                )

-- | Render a column definition as used in CREATE TABLE / ADD COLUMN.
showColumn :: Column -> String
showColumn (Column n nu t def ref) = concat
    [ escape n
    , " "
    , showSqlType t
    , " "
    , if nu then "NULL" else "NOT NULL"
    , case def of
        Nothing -> ""
        Just s -> " DEFAULT " ++ s
    , case ref of
        Nothing -> ""
        Just (s, _) -> " REFERENCES " ++ escape s
    ]

-- | PostgreSQL type name for each persistent SQL type.
showSqlType :: SqlType -> String
showSqlType SqlString = "VARCHAR"
showSqlType SqlInt32 = "INT4"
showSqlType SqlInteger = "INT8"
showSqlType SqlReal = "DOUBLE PRECISION"
showSqlType SqlDay = "DATE"
showSqlType SqlTime = "TIME"
showSqlType SqlDayTime = "TIMESTAMP"
showSqlType SqlBlob = "BYTEA"
showSqlType SqlBool = "BOOLEAN"

-- | Render a schema change. The flag is 'True' only for dropping a column.
showAlterDb :: AlterDB -> (Bool, String)
showAlterDb (AddTable s) = (False, s)
showAlterDb (AlterColumn t (c, ac)) =
    (isUnsafe ac, showAlter t (c, ac))
  where
    isUnsafe Drop = True
    isUnsafe _ = False
showAlterDb (AlterTable t at) = (False, showAlterTable t at)

-- | Render a table-level (constraint) change.
showAlterTable :: RawName -> AlterTable -> String
showAlterTable table (AddUniqueConstraint cname cols) = concat
    [ "ALTER TABLE "
    , escape table
    , " ADD CONSTRAINT "
    , escape cname
    , " UNIQUE("
    , intercalate "," $ map escape cols
    , ")"
    ]
showAlterTable table (DropConstraint cname) = concat
    [ "ALTER TABLE "
    , escape table
    , " DROP CONSTRAINT "
    , escape cname
    ]

-- | Render a column-level change as a single SQL statement.
showAlter :: RawName -> AlterColumn' -> String
showAlter table (n, Type t) = concat
    [ "ALTER TABLE "
    , escape table
    , " ALTER COLUMN "
    , escape n
    , " TYPE "
    , showSqlType t
    ]
showAlter table (n, IsNull) = concat
    [ "ALTER TABLE "
    , escape table
    , " ALTER COLUMN "
    , escape n
    , " DROP NOT NULL"
    ]
showAlter table (n, NotNull) = concat
    [ "ALTER TABLE "
    , escape table
    , " ALTER COLUMN "
    , escape n
    , " SET NOT NULL"
    ]
showAlter table (_, Add col) = concat
    [ "ALTER TABLE "
    , escape table
    , " ADD COLUMN "
    , showColumn col
    ]
showAlter table (n, Drop) = concat
    [ "ALTER TABLE "
    , escape table
    , " DROP COLUMN "
    , escape n
    ]
showAlter table (n, Default s) = concat
    [ "ALTER TABLE "
    , escape table
    , " ALTER COLUMN "
    , escape n
    , " SET DEFAULT "
    , s
    ]
showAlter table (n, NoDefault) = concat
    [ "ALTER TABLE "
    , escape table
    , " ALTER COLUMN "
    , escape n
    , " DROP DEFAULT"
    ]
-- Backfill NULLs with the given default expression.
showAlter table (n, Update s) = concat
    [ "UPDATE "
    , escape table
    , " SET "
    , escape n
    , "="
    , s
    , " WHERE "
    , escape n
    , " IS NULL"
    ]
showAlter table (n, AddReference t2) = concat
    [ "ALTER TABLE "
    , escape table
    , " ADD CONSTRAINT "
    , escape $ refName table n
    , " FOREIGN KEY("
    , escape n
    , ") REFERENCES "
    , escape t2
    ]
showAlter table (_, DropReference cname) =
    "ALTER TABLE " ++ escape table ++ " DROP CONSTRAINT " ++ escape cname

-- | Quote an identifier in double quotes, doubling any embedded double quote.
escape :: RawName -> String
escape (RawName s) = '"' : concatMap quoteChar s ++ "\""
  where
    quoteChar '"' = "\"\""
    quoteChar c = [c]