{-# LANGUAGE DataKinds #-} {-# LANGUAGE FlexibleContexts #-} {-# LANGUAGE GADTs #-} {-# LANGUAGE OverloadedStrings #-} {-# LANGUAGE TypeFamilies #-} {-# LANGUAGE DeriveDataTypeable #-} {-# LANGUAGE ViewPatterns #-} {-# LANGUAGE Rank2Types #-} -- | A postgresql backend for persistent. module Database.Persist.Postgresql ( withPostgresqlPool , withPostgresqlConn , createPostgresqlPool , createPostgresqlPoolModified , module Database.Persist.Sql , ConnectionString , PostgresConf (..) , openSimpleConn , tableName , fieldName ) where import Database.Persist.Sql import Database.Persist.Sql.Util (dbIdColumnsEsc) import Data.Fixed (Pico) import qualified Database.PostgreSQL.Simple as PG import qualified Database.PostgreSQL.Simple.TypeInfo as PG import qualified Database.PostgreSQL.Simple.TypeInfo.Static as PS import qualified Database.PostgreSQL.Simple.Internal as PG import qualified Database.PostgreSQL.Simple.ToField as PGTF import qualified Database.PostgreSQL.Simple.FromField as PGFF import qualified Database.PostgreSQL.Simple.Types as PG import Database.PostgreSQL.Simple.Ok (Ok (..)) import qualified Database.PostgreSQL.LibPQ as LibPQ import Control.Monad.Trans.Resource import Control.Exception (throw) import Control.Monad.IO.Class (MonadIO (..)) import Data.Typeable import Data.IORef import qualified Data.Map as Map import Data.Maybe import Data.Either (partitionEithers) import Control.Arrow import Data.List (find, sort, groupBy) import Data.Function (on) import Data.Conduit import qualified Data.Conduit.List as CL import Control.Monad.Logger (MonadLogger, runNoLoggingT) import qualified Data.IntMap as I import Data.ByteString (ByteString) import qualified Data.ByteString.Char8 as B8 import qualified Data.Text as T import qualified Data.Text.Encoding as T import qualified Blaze.ByteString.Builder.Char8 as BBB import Data.Text (Text) import Data.Aeson import Data.Aeson.Types (modifyFailure) import Control.Monad (forM) import Data.Acquire (Acquire, mkAcquire, with) 
import System.Environment (getEnvironment)
import Data.Int (Int64)
import Data.Monoid ((<>))
import Data.Time (utc, localTimeToUTC)

-- | A @libpq@ connection string.  A simple example of connection
-- string would be @\"host=localhost port=5432 user=test
-- dbname=test password=test\"@.  Please read the connection-string
-- section of libpq's documentation for more details on how to
-- create such strings.
type ConnectionString = ByteString

-- | Create a PostgreSQL connection pool and run the given
-- action.  The pool is properly released after the action
-- finishes using it.  Note that you should not use the given
-- 'ConnectionPool' outside the action since it may already
-- have been released.
withPostgresqlPool :: (MonadBaseControl IO m, MonadLogger m, MonadIO m)
                   => ConnectionString
                   -- ^ Connection string to the database.
                   -> Int
                   -- ^ Number of connections to be kept open in
                   -- the pool.
                   -> (ConnectionPool -> m a)
                   -- ^ Action to be executed that uses the
                   -- connection pool.
                   -> m a
withPostgresqlPool ci = withSqlPool $ open' (const $ return ()) ci

-- | Create a PostgreSQL connection pool.  Note that it's your
-- responsibility to properly close the connection pool when
-- unneeded.  Use 'withPostgresqlPool' for an automatic resource
-- control.
createPostgresqlPool :: (MonadIO m, MonadBaseControl IO m, MonadLogger m)
                     => ConnectionString
                     -- ^ Connection string to the database.
                     -> Int
                     -- ^ Number of connections to be kept open
                     -- in the pool.
                     -> m ConnectionPool
createPostgresqlPool = createPostgresqlPoolModified (const $ return ())

-- | Same as 'createPostgresqlPool', but additionally takes a callback
-- function for some connection-specific tweaking to be performed after
-- connection creation.  This could be used, for example, to change the
-- schema.
--
-- Since 2.1.3
createPostgresqlPoolModified
    :: (MonadIO m, MonadBaseControl IO m, MonadLogger m)
    => (PG.Connection -> IO ()) -- ^ action to perform after connection is created
    -> ConnectionString -- ^ Connection string to the database.
    -> Int -- ^ Number of connections to be kept open in the pool.
    -> m ConnectionPool
createPostgresqlPoolModified modConn ci =
    createSqlPool $ open' modConn ci

-- | Same as 'withPostgresqlPool', but instead of opening a pool
-- of connections, only one connection is opened.
withPostgresqlConn :: (MonadIO m, MonadBaseControl IO m, MonadLogger m)
                   => ConnectionString -> (SqlBackend -> m a) -> m a
withPostgresqlConn = withSqlConn . open' (const $ return ())

-- | Connect with the given string, run the user's post-connect hook,
-- then wrap the raw connection as a 'SqlBackend'.
open' :: (PG.Connection -> IO ())
      -> ConnectionString -> LogFunc -> IO SqlBackend
open' modConn cstr logFunc = do
    conn <- PG.connectPostgreSQL cstr
    modConn conn
    openSimpleConn logFunc conn

-- | Generate a 'Connection' from a 'PG.Connection'
openSimpleConn :: LogFunc -> PG.Connection -> IO SqlBackend
openSimpleConn logFunc conn = do
    smap <- newIORef Map.empty
    return SqlBackend
        { connPrepare     = prepare' conn
        , connStmtMap     = smap
        , connInsertSql   = insertSql'
        , connClose       = PG.close conn
        , connMigrateSql  = migrate'
        , connBegin       = const $ PG.begin conn
        , connCommit      = const $ PG.commit conn
        , connRollback    = const $ PG.rollback conn
        , connEscapeName  = escape
        , connNoLimit     = "LIMIT ALL"
        , connRDBMS       = "postgresql"
        , connLimitOffset = decorateSQLWithLimitOffset "LIMIT ALL"
        , connLogFunc     = logFunc
        }

-- | Build a 'Statement' for the given SQL text.  No server-side prepared
-- statement is created: finalize/reset are no-ops and every execution
-- re-sends the (UTF-8 encoded) query with its parameters.
prepare' :: PG.Connection -> Text -> IO Statement
prepare' conn sql = do
    let query = PG.Query (T.encodeUtf8 sql)
    return Statement
        { stmtFinalize = return ()
        , stmtReset    = return ()
        , stmtExecute  = execute' conn query
        , stmtQuery    = withStmt' conn query
        }

-- | Produce the INSERT statement for an entity.  Entities without a
-- composite primary key get a @RETURNING <id>@ clause so the new key can
-- be read back; composite-key entities return their keys from the values.
insertSql' :: EntityDef -> [PersistValue] -> InsertSqlResult
insertSql' ent vals =
    case entityPrimary ent of
        Just _pdef -> ISRManyKeys sql vals
        Nothing    -> ISRSingle (sql <> " RETURNING " <> escape (fieldDB (entityId ent)))
  where
    fields = entityFields ent
    sql = T.concat
        [ "INSERT INTO "
        , escape $ entityDB ent
        , if null fields
            then " DEFAULT VALUES"
            else T.concat
                [ "("
                , T.intercalate "," $ map (escape . fieldDB) fields
                , ") VALUES("
                , T.intercalate "," (map (const "?") fields)
                , ")"
                ]
        ]

-- | Run a non-query statement and return the affected row count.
execute' :: PG.Connection -> PG.Query -> [PersistValue] -> IO Int64
execute' conn query vals = PG.execute conn query (map P vals)

-- | Run a query through raw libpq and expose its rows as a conduit
-- 'Source'.  The query is executed when the 'Acquire' is allocated; the
-- libpq result is freed when it is released.
withStmt' :: MonadIO m
          => PG.Connection
          -> PG.Query
          -> [PersistValue]
          -> Acquire (Source m [PersistValue])
withStmt' conn query vals =
    pull `fmap` mkAcquire openS closeS
  where
    openS = do
        -- Construct raw query
        rawquery <- PG.formatQuery conn query (map P vals)

        -- Take raw connection
        (rt, rr, rc, ids) <- PG.withConnection conn $ \rawconn -> do
            -- Execute query
            mret <- LibPQ.exec rawconn rawquery
            case mret of
                Nothing -> do
                    merr <- LibPQ.errorMessage rawconn
                    fail $ case merr of
                        Nothing -> "Postgresql.withStmt': unknown error"
                        Just e  -> "Postgresql.withStmt': " ++ B8.unpack e
                Just ret -> do
                    -- Check result status
                    status <- LibPQ.resultStatus ret
                    case status of
                        LibPQ.TuplesOk -> return ()
                        _ -> PG.throwResultError "Postgresql.withStmt': bad result status " ret status

                    -- Get number and type of columns
                    cols <- LibPQ.nfields ret
                    oids <- forM [0..cols-1] $ \col -> fmap ((,) col) (LibPQ.ftype ret col)
                    -- Ready to go!
                    rowRef   <- newIORef (LibPQ.Row 0)
                    rowCount <- LibPQ.ntuples ret
                    return (ret, rowRef, rowCount, oids)
        -- One decoder per column, chosen from the column's type OID.
        let getters =
                map (\(col, oid) -> getGetter conn oid $ PG.Field rt col oid) ids
        return (rt, rr, rc, getters)

    closeS (ret, _, _, _) = LibPQ.unsafeFreeResult ret

    pull x = do
        y <- liftIO $ pullS x
        case y of
            Nothing -> return ()
            Just z  -> yield z >> pull x

    -- Fetch the next row (post-incrementing the row cursor), decoding
    -- every column; Nothing once the cursor reaches the row count.
    pullS (ret, rowRef, rowCount, getters) = do
        row <- atomicModifyIORef rowRef (\r -> (r+1, r))
        if row == rowCount
            then return Nothing
            else fmap Just $ forM (zip getters [0..]) $ \(getter, col) -> do
                mbs <- LibPQ.getvalue' ret row col
                case mbs of
                    Nothing -> return PersistNull
                    Just bs -> do
                        ok <- PGFF.runConversion (getter mbs) conn
                        bs `seq` case ok of
                            Errors (exc:_) -> throw exc
                            Errors []      -> error "Got an Errors, but no exceptions"
                            Ok v           -> return v

-- | Avoid orphan instances.
newtype P = P PersistValue

-- | Encode a 'PersistValue' as a postgresql-simple query parameter.
instance PGTF.ToField P where
    toField (P (PersistText t))        = PGTF.toField t
    toField (P (PersistByteString bs)) = PGTF.toField (PG.Binary bs)
    toField (P (PersistInt64 i))       = PGTF.toField i
    toField (P (PersistDouble d))      = PGTF.toField d
    -- FIXME: Too Ambigous, can not select precision without information about field
    toField (P (PersistRational r))    =
        PGTF.Plain $ BBB.fromString $ show (fromRational r :: Pico)
    toField (P (PersistBool b))        = PGTF.toField b
    toField (P (PersistDay d))         = PGTF.toField d
    toField (P (PersistTimeOfDay t))   = PGTF.toField t
    toField (P (PersistUTCTime t))     = PGTF.toField t
    toField (P PersistNull)            = PGTF.toField PG.Null
    toField (P (PersistList l))        = PGTF.toField $ listToJSON l
    toField (P (PersistMap m))         = PGTF.toField $ mapToJSON m
    toField (P (PersistDbSpecific s))  = PGTF.toField (Unknown s)
    toField (P (PersistObjectId _))    =
        error "Refusing to serialize a PersistObjectId to a PostgreSQL value"

-- | Raw column bytes, passed through without interpretation.
newtype Unknown = Unknown { unUnknown :: ByteString }
    deriving (Eq, Show, Read, Ord, Typeable)

instance PGFF.FromField Unknown where
    fromField f mdata =
        case mdata of
            Nothing  -> PGFF.returnError PGFF.UnexpectedNull f ""
            Just dat -> return (Unknown dat)

instance PGTF.ToField Unknown where
    toField (Unknown a) = PGTF.Escape a

type Getter a = PGFF.FieldParser a

-- | Lift a postgresql-simple field parser through a 'PersistValue'
-- constructor.
convertPV :: PGFF.FromField a => (a -> b) -> Getter b
convertPV f = (fmap f .) . PGFF.fromField

-- | Decoders keyed by PostgreSQL type OID.
builtinGetters :: I.IntMap (Getter PersistValue)
builtinGetters = I.fromList
    [ (k PS.bool,        convertPV PersistBool)
    , (k PS.bytea,       convertPV (PersistByteString . unBinary))
    , (k PS.char,        convertPV PersistText)
    , (k PS.name,        convertPV PersistText)
    , (k PS.int8,        convertPV PersistInt64)
    , (k PS.int2,        convertPV PersistInt64)
    , (k PS.int4,        convertPV PersistInt64)
    , (k PS.text,        convertPV PersistText)
    , (k PS.xml,         convertPV PersistText)
    , (k PS.float4,      convertPV PersistDouble)
    , (k PS.float8,      convertPV PersistDouble)
    , (k PS.abstime,     convertPV PersistUTCTime)
    , (k PS.reltime,     convertPV PersistUTCTime)
    , (k PS.money,       convertPV PersistRational)
    , (k PS.bpchar,      convertPV PersistText)
    , (k PS.varchar,     convertPV PersistText)
    , (k PS.date,        convertPV PersistDay)
    , (k PS.time,        convertPV PersistTimeOfDay)
    -- TIMESTAMP (without time zone) values are interpreted as UTC.
    , (k PS.timestamp,   convertPV (PersistUTCTime . localTimeToUTC utc))
    , (k PS.timestamptz, convertPV PersistUTCTime)
    , (k PS.bit,         convertPV PersistInt64)
    , (k PS.varbit,      convertPV PersistInt64)
    , (k PS.numeric,     convertPV PersistRational)
    , (k PS.void,        \_ _ -> return PersistNull)
    , (k PS.json,        convertPV (PersistByteString . unUnknown))
    , (k PS.jsonb,       convertPV (PersistByteString . unUnknown))
    , (k PS.unknown,     convertPV (PersistByteString . unUnknown))

    -- array types: same order as above
    , (1000, listOf PersistBool)
    , (1001, listOf (PersistByteString . unBinary))
    , (1002, listOf PersistText)
    , (1003, listOf PersistText)
    , (1016, listOf PersistInt64)
    , (1005, listOf PersistInt64)
    , (1007, listOf PersistInt64)
    , (1009, listOf PersistText)
    , (143,  listOf PersistText)
    , (1021, listOf PersistDouble)
    , (1022, listOf PersistDouble)
    , (1023, listOf PersistUTCTime)
    , (1024, listOf PersistUTCTime)
    , (791,  listOf PersistRational)
    , (1014, listOf PersistText)
    , (1015, listOf PersistText)
    , (1182, listOf PersistDay)
    , (1183, listOf PersistTimeOfDay)
    -- timestamp[]: elements are TIMESTAMP (without time zone), so decode
    -- them as LocalTime and treat as UTC, exactly like the scalar entry.
    -- The UTCTime parser only accepts timestamptz elements.
    , (1115, listOf (PersistUTCTime . localTimeToUTC utc))
    , (1185, listOf PersistUTCTime)
    , (1561, listOf PersistInt64)
    , (1563, listOf PersistInt64)
    , (1231, listOf PersistRational)
    -- no array(void) type
    , (2951, listOf (PersistDbSpecific . unUnknown))
    , (199,  listOf (PersistByteString . unUnknown))
    , (3807, listOf (PersistByteString . unUnknown))
    -- no array(unknown) either
    ]
  where
    k (PGFF.typoid -> i) = PG.oid2int i
    listOf f = convertPV (PersistList . map f . PG.fromPGArray)

-- | Decoder for a column of the given type OID; OIDs without a builtin
-- decoder yield the raw bytes as 'PersistDbSpecific'.
getGetter :: PG.Connection -> PG.Oid -> Getter PersistValue
getGetter _conn oid =
    fromMaybe defaultGetter $ I.lookup (PG.oid2int oid) builtinGetters
  where
    defaultGetter = convertPV (PersistDbSpecific . unUnknown)

unBinary :: PG.Binary a -> a
unBinary (PG.Binary x) = x

-- | Does a table with this name exist in the current database and
-- schema?  Scoped the same way as 'getColumns', so a same-named table in
-- another schema is neither reported as existing nor makes the count
-- exceed one.
doesTableExist :: (Text -> IO Statement)
               -> DBName -- ^ table name
               -> IO Bool
doesTableExist getter (DBName name) = do
    stmt <- getter sql
    with (stmtQuery stmt vals) ($$ start)
  where
    sql = T.concat
        [ "SELECT COUNT(*) FROM information_schema.tables "
        , "WHERE table_catalog=current_database() "
        , "AND table_schema=current_schema() "
        , "AND table_name=?"
        ]
    vals = [PersistText name]

    start = await >>= maybe (error "No results when checking doesTableExist") start'
    start' [PersistInt64 0] = finish False
    start' [PersistInt64 1] = finish True
    start' res = error $ "doesTableExist returned unexpected result: " ++ show res
    finish x = await >>= maybe (return x) (error "Too many rows returned in doesTableExist")

-- | Compute the migration for one entity.  Returns @Left@ with the
-- errors met while reading the existing columns; otherwise @Right@ with
-- (unsafe?, SQL) pairs as produced by 'showAlterDb'.
migrate' :: [EntityDef]
         -> (Text -> IO Statement)
         -> EntityDef
         -> IO (Either [Text] [(Bool, Text)])
migrate' allDefs getter entity = fmap (fmap $ map showAlterDb) $ do
    old <- getColumns getter entity
    case partitionEithers old of
        ([], old'') -> do
            -- Check for table existence if there are no columns, workaround
            -- for https://github.com/yesodweb/persistent/issues/152
            exists <-
                if null old
                    then doesTableExist getter name
                    else return True
            return $ Right $ migrationText exists old''
        (errs, _) -> return $ Left errs
  where
    name = entityDB entity

    migrationText exists old''
        | not exists = createText newcols fdefs udspair
        | otherwise =
            let (acs, ats) = getAlters allDefs entity (newcols, udspair) old'
            in  map (AlterColumn name) acs ++ map (AlterTable name) ats
      where
        old' = partitionEithers old''
        (newcols', udefs, fdefs) = mkColumns allDefs entity
        newcols = filter (not . safeToRemove entity . cName) newcols'
        udspair = map udToPair udefs

    createText newcols fdefs udspair =
        addTable : uniques ++ references ++ foreignsAlt
      where
        addTable = AddTable $ T.concat
            -- Lower case e: see Database.Persist.Sql.Migration
            [ "CREATe TABLE " -- DO NOT FIX THE CAPITALIZATION!
            , escape name
            , "("
            , idtxt
            , if null newcols then "" else ","
            , T.intercalate "," $ map showColumn newcols
            , ")"
            ]
        uniques = flip concatMap udspair $ \(uname, ucols) ->
            [AlterTable name $ AddUniqueConstraint uname ucols]
        references = mapMaybe referenceFor newcols
        referenceFor c =
            case cReference c of
                Nothing -> Nothing
                Just (refTblName, _) ->
                    getAddReference allDefs name refTblName (cName c) (cReference c)
        foreignsAlt = map foreignAlter fdefs
        foreignAlter fdef =
            let (childfields, parentfields) =
                    unzip (map (\((_, b), (_, d)) -> (b, d)) (foreignFields fdef))
            in  AlterColumn name
                    ( foreignRefTableDBName fdef
                    , AddReference (foreignConstraintNameDBName fdef)
                                   childfields
                                   (map escape parentfields)
                    )
        idtxt = case entityPrimary entity of
            Just pdef ->
                T.concat
                    [ " PRIMARY KEY ("
                    , T.intercalate "," $ map (escape . fieldDB) $ compositeFields pdef
                    , ")"
                    ]
            Nothing ->
                let defText = defaultAttribute $ fieldAttrs $ entityId entity
                    sType = fieldSqlType $ entityId entity
                in  T.concat
                        [ escape $ fieldDB (entityId entity)
                        , maySerial sType defText
                        , " PRIMARY KEY UNIQUE"
                        , mayDefault defText
                        ]

-- | Column type for an id column.  An 'SqlInt64' id without a default
-- becomes an auto-incrementing @SERIAL8@ (bigserial).  Plain @SERIAL@ is a
-- 4-byte integer and would cap an Int64 key at 2^31-1.
maySerial :: SqlType -> Maybe Text -> Text
maySerial SqlInt64 Nothing = " SERIAL8 "
maySerial sType _ = " " <> showSqlType sType

-- | Render an optional DEFAULT clause.
mayDefault :: Maybe Text -> Text
mayDefault = maybe "" (" DEFAULT " <>)

type SafeToRemove = Bool

-- | A change to a single column (or a reference on it).  The 'DBName'
-- paired with it in 'AlterColumn'' is the column, except for
-- 'AddReference', where 'showAlter' renders it as the referenced table.
data AlterColumn = Type SqlType Text            -- ^ new type plus extra SQL (e.g. a USING clause)
                 | IsNull
                 | NotNull
                 | Add' Column
                 | Drop SafeToRemove
                 | Default Text
                 | NoDefault
                 | Update' Text                 -- ^ fill NULLs with this SQL expression
                 | AddReference DBName [DBName] [Text] -- ^ constraint name, child columns, escaped parent columns
                 | DropReference DBName

type AlterColumn' = (DBName, AlterColumn)

data AlterTable = AddUniqueConstraint DBName [DBName]
                | DropConstraint DBName

data AlterDB = AddTable Text
             | AlterColumn DBName AlterColumn'
             | AlterTable DBName AlterTable

-- | Returns all of the columns in the given table currently in the database.
getColumns :: (Text -> IO Statement)
           -> EntityDef
           -> IO [Either Text (Either Column (DBName, [DBName]))]
getColumns getter def = do
    let sqlv = T.concat
            [ "SELECT "
            , "column_name "
            , ",is_nullable "
            , ",udt_name "
            , ",column_default "
            , ",numeric_precision "
            , ",numeric_scale "
            , "FROM information_schema.columns "
            , "WHERE table_catalog=current_database() "
            , "AND table_schema=current_schema() "
            , "AND table_name=? "
            , "AND column_name <> ?"
            ]
    stmt <- getter sqlv
    -- Both queries skip the entity's id column.
    let vals =
            [ PersistText $ unDBName $ entityDB def
            , PersistText $ unDBName $ fieldDB (entityId def)
            ]
    cs <- with (stmtQuery stmt vals) ($$ helper)
    let sqlc = T.concat
            [ "SELECT "
            , "c.constraint_name, "
            , "c.column_name "
            , "FROM information_schema.key_column_usage c, "
            , "information_schema.table_constraints k "
            , "WHERE c.table_catalog=current_database() "
            , "AND c.table_catalog=k.table_catalog "
            , "AND c.table_schema=current_schema() "
            , "AND c.table_schema=k.table_schema "
            , "AND c.table_name=? "
            , "AND c.table_name=k.table_name "
            , "AND c.column_name <> ? "
            , "AND c.constraint_name=k.constraint_name "
            , "AND NOT k.constraint_type IN ('PRIMARY KEY', 'FOREIGN KEY') "
            , "ORDER BY c.constraint_name, c.column_name"
            ]
    stmt' <- getter sqlc
    us <- with (stmtQuery stmt' vals) ($$ helperU)
    return $ cs ++ us
  where
    -- Collect (constraint, column) rows, accepting text or bytea columns.
    getAll front = do
        x <- CL.head
        case x of
            Nothing -> return $ front []
            Just [PersistText con, PersistText col] ->
                getAll (front . (:) (con, col))
            Just [PersistByteString con, PersistByteString col] ->
                getAll (front . (:) (T.decodeUtf8 con, T.decodeUtf8 col))
            Just o -> error $ "unexpected datatype returned for postgres o="++show o

    -- Rows arrive ordered by constraint name, so grouping yields one entry
    -- per constraint (groupBy never produces an empty group, so head is safe).
    helperU = do
        rows <- getAll id
        return $ map (Right . Right . (DBName . fst . head &&& map (DBName . snd)))
               $ groupBy ((==) `on` fst) rows

    helper = do
        x <- CL.head
        case x of
            Nothing -> return []
            Just x' -> do
                col <- liftIO $ getColumn getter (entityDB def) x'
                let col' = either Left (Right . Left) col
                cols <- helper
                return $ col' : cols

-- | Check if a column name is listed as the "safe to remove" in the entity
-- list.
safeToRemove :: EntityDef -> DBName -> Bool
safeToRemove def (DBName colName)
    = any (elem "SafeToRemove" . fieldAttrs)
    $ filter ((== DBName colName) . fieldDB)
    $ entityFields def

-- | Diff desired (first) against existing (second) columns and unique
-- constraints.
getAlters :: [EntityDef]
          -> EntityDef
          -> ([Column], [(DBName, [DBName])])
          -> ([Column], [(DBName, [DBName])])
          -> ([AlterColumn'], [AlterTable])
getAlters defs def (c1, u1) (c2, u2) =
    (getAltersC c1 c2, getAltersU u1 u2)
  where
    -- Existing columns left unmatched are dropped.
    getAltersC [] old = map (\x -> (cName x, Drop $ safeToRemove def $ cName x)) old
    getAltersC (new:news) old =
        let (alters, old') = findAlters defs (entityDB def) new old
        in  alters ++ getAltersC news old'

    getAltersU :: [(DBName, [DBName])]
               -> [(DBName, [DBName])]
               -> [AlterTable]
    getAltersU [] old = map DropConstraint $ filter (not . isManual) $ map fst old
    getAltersU ((name, cols):news) old =
        case lookup name old of
            Nothing -> AddUniqueConstraint name cols : getAltersU news old
            Just ocols ->
                let old' = filter (\(x, _) -> x /= name) old
                in  if sort cols == sort ocols
                        then getAltersU news old'
                        else DropConstraint name
                           : AddUniqueConstraint name cols
                           : getAltersU news old'

    -- Don't drop constraints which were manually added.
    isManual (DBName x) = "__manual_" `T.isPrefixOf` x

-- | Turn one @information_schema.columns@ row into a 'Column', looking up
-- whether the conventionally-named foreign key exists for it.
getColumn :: (Text -> IO Statement)
          -> DBName -> [PersistValue]
          -> IO (Either Text Column)
getColumn getter tname [PersistText x, PersistText y, PersistText z, d, npre, nscl] =
    case d' of
        Left s -> return $ Left s
        Right d'' ->
            case getType z of
                Left s -> return $ Left s
                Right t -> do
                    let cname = DBName x
                    ref <- getRef cname
                    return $ Right Column
                        { cName = cname
                        , cNull = y == "YES"
                        , cSqlType = t
                        , cDefault = fmap stripSuffixes d''
                        , cDefaultConstraintName = Nothing
                        , cMaxLen = Nothing
                        , cReference = ref
                        }
  where
    -- Drop the first matching type-cast suffix from a default expression.
    stripSuffixes t =
        loop'
            [ "::character varying"
            , "::text"
            ]
      where
        loop' [] = t
        loop' (p:ps) =
            case T.stripSuffix p t of
                Nothing -> loop' ps
                Just t' -> t'

    getRef cname = do
        let sql = T.concat
                [ "SELECT COUNT(*) FROM "
                , "information_schema.table_constraints "
                , "WHERE table_catalog=current_database() "
                , "AND table_schema=current_schema() "
                , "AND table_name=? "
                , "AND constraint_type='FOREIGN KEY' "
                , "AND constraint_name=?"
                ]
        let ref = refName tname cname
        stmt <- getter sql
        with (stmtQuery stmt
                [ PersistText $ unDBName tname
                , PersistText $ unDBName ref
                ]) ($$ do
                    Just [PersistInt64 i] <- CL.head
                    return $ if i == 0 then Nothing else Just (DBName "", ref))

    d' = case d of
        PersistNull   -> Right Nothing
        PersistText t -> Right $ Just t
        _ -> Left $ T.pack $ "Invalid default column: " ++ show d

    getType "int4"        = Right SqlInt32
    getType "int8"        = Right SqlInt64
    getType "varchar"     = Right SqlString
    getType "date"        = Right SqlDay
    getType "bool"        = Right SqlBool
    getType "timestamptz" = Right SqlDayTime
    getType "float4"      = Right SqlReal
    getType "float8"      = Right SqlReal
    getType "bytea"       = Right SqlBlob
    getType "time"        = Right SqlTime
    getType "numeric"     = getNumeric npre nscl
    getType a             = Right $ SqlOther a

    getNumeric (PersistInt64 a) (PersistInt64 b) =
        Right $ SqlNumeric (fromIntegral a) (fromIntegral b)
    getNumeric a b =
        Left $ T.pack $ "Can not get numeric field precision, got: " ++ show a ++ " and " ++ show b ++ " as precision and scale"
getColumn _ _ x =
    return $ Left $ T.pack $ "Invalid result from information_schema: " ++ show x

-- | Intelligent comparison of SQL types, to account for SqlInt32 vs SqlOther integer
sqlTypeEq :: SqlType -> SqlType -> Bool
sqlTypeEq x y =
    T.toCaseFold (showSqlType x) == T.toCaseFold (showSqlType y)

-- | Compare a desired column against the existing ones.  Returns the
-- alterations needed for it and the existing columns with it removed.
findAlters :: [EntityDef] -> DBName -> Column -> [Column] -> ([AlterColumn'], [Column])
findAlters defs _tablename col@(Column name isNull sqltype def _defConstraintName _maxLen ref) cols =
    case filter (\c -> cName c == name) cols of
        [] -> ([(name, Add' col)], cols)
        Column _ isNull' sqltype' def' _defConstraintName' _maxLen' ref':_ ->
            let refDrop Nothing = []
                refDrop (Just (_, cname)) = [(name, DropReference cname)]
                refAdd Nothing = []
                refAdd (Just (tname, a)) =
                    case find ((== tname) . entityDB) defs of
                        Just refdef -> [(tname, AddReference a [name] (dbIdColumnsEsc escape refdef))]
                        Nothing -> error $ "could not find the entityDef for reftable[" ++ show tname ++ "]"
                modRef =
                    if fmap snd ref == fmap snd ref'
                        then []
                        else refDrop ref' ++ refAdd ref
                modNull = case (isNull, isNull') of
                    (True, False) -> [(name, IsNull)]
                    -- Before SET NOT NULL, backfill NULLs with the default if any.
                    (False, True) ->
                        let up = case def of
                                Nothing -> id
                                Just s -> (:) (name, Update' s)
                        in  up [(name, NotNull)]
                    _ -> []
                modType
                    | sqlTypeEq sqltype sqltype' = []
                    -- When converting from Persistent pre-2.0 databases, we
                    -- need to make sure that TIMESTAMP WITHOUT TIME ZONE is
                    -- treated as UTC.
                    | sqltype == SqlDayTime && sqltype' == SqlOther "timestamp" =
                        [(name, Type sqltype $ T.concat
                            [ " USING "
                            , escape name
                            , " AT TIME ZONE 'UTC'"
                            ])]
                    | otherwise = [(name, Type sqltype "")]
                modDef =
                    if def == def'
                        then []
                        else case def of
                            Nothing -> [(name, NoDefault)]
                            Just s  -> [(name, Default s)]
            in  ( modRef ++ modDef ++ modNull ++ modType
                , filter (\c -> cName c /= name) cols
                )

-- | Get the references to be added to a table for the given column.
getAddReference :: [EntityDef] -> DBName -> DBName -> DBName -> Maybe (DBName, DBName) -> Maybe AlterDB
getAddReference allDefs table reftable cname ref =
    case ref of
        Nothing -> Nothing
        Just (s, _) -> Just $ AlterColumn table (s, AddReference (refName table cname) [cname] id_)
  where
    id_ = fromMaybe (error $ "Could not find ID of entity " ++ show reftable) $ do
        entDef <- find ((== reftable) . entityDB) allDefs
        return $ dbIdColumnsEsc escape entDef

-- | Column definition as used in CREATE TABLE / ADD COLUMN.
showColumn :: Column -> Text
showColumn (Column n nu sqlType' def _defConstraintName _maxLen _ref) = T.concat
    [ escape n
    , " "
    , showSqlType sqlType'
    , " "
    , if nu then "NULL" else "NOT NULL"
    , case def of
        Nothing -> ""
        Just s -> " DEFAULT " <> s
    ]

showSqlType :: SqlType -> Text
showSqlType SqlString = "VARCHAR"
showSqlType SqlInt32 = "INT4"
showSqlType SqlInt64 = "INT8"
showSqlType SqlReal = "DOUBLE PRECISION"
showSqlType (SqlNumeric s prec) = T.concat [ "NUMERIC(", T.pack (show s), ",", T.pack (show prec), ")" ]
showSqlType SqlDay = "DATE"
showSqlType SqlTime = "TIME"
showSqlType SqlDayTime = "TIMESTAMP WITH TIME ZONE"
showSqlType SqlBlob = "BYTEA"
showSqlType SqlBool = "BOOLEAN"
-- Added for aliasing issues re: https://github.com/yesodweb/yesod/issues/682
showSqlType (SqlOther (T.toLower -> "integer")) = "INT4"
showSqlType (SqlOther t) = t

-- | Render a migration step as (is it unsafe?, SQL).  Only dropping a
-- column not marked SafeToRemove counts as unsafe.
showAlterDb :: AlterDB -> (Bool, Text)
showAlterDb (AddTable s) = (False, s)
showAlterDb (AlterColumn t (c, ac)) =
    (isUnsafe ac, showAlter t (c, ac))
  where
    isUnsafe (Drop safeRemove) = not safeRemove
    isUnsafe _ = False
showAlterDb (AlterTable t at) = (False, showAlterTable t at)

showAlterTable :: DBName -> AlterTable -> Text
showAlterTable table (AddUniqueConstraint cname cols) = T.concat
    [ "ALTER TABLE "
    , escape table
    , " ADD CONSTRAINT "
    , escape cname
    , " UNIQUE("
    , T.intercalate "," $ map escape cols
    , ")"
    ]
showAlterTable table (DropConstraint cname) = T.concat
    [ "ALTER TABLE "
    , escape table
    , " DROP CONSTRAINT "
    , escape cname
    ]

showAlter :: DBName -> AlterColumn' -> Text
showAlter table (n, Type t extra) = T.concat
    [ "ALTER TABLE "
    , escape table
    , " ALTER COLUMN "
    , escape n
    , " TYPE "
    , showSqlType t
    , extra
    ]
showAlter table (n, IsNull) = T.concat
    [ "ALTER TABLE "
    , escape table
    , " ALTER COLUMN "
    , escape n
    , " DROP NOT NULL"
    ]
showAlter table (n, NotNull) = T.concat
    [ "ALTER TABLE "
    , escape table
    , " ALTER COLUMN "
    , escape n
    , " SET NOT NULL"
    ]
showAlter table (_, Add' col) = T.concat
    [ "ALTER TABLE "
    , escape table
    , " ADD COLUMN "
    , showColumn col
    ]
showAlter table (n, Drop _) = T.concat
    [ "ALTER TABLE "
    , escape table
    , " DROP COLUMN "
    , escape n
    ]
showAlter table (n, Default s) = T.concat
    [ "ALTER TABLE "
    , escape table
    , " ALTER COLUMN "
    , escape n
    , " SET DEFAULT "
    , s
    ]
showAlter table (n, NoDefault) = T.concat
    [ "ALTER TABLE "
    , escape table
    , " ALTER COLUMN "
    , escape n
    , " DROP DEFAULT"
    ]
showAlter table (n, Update' s) = T.concat
    [ "UPDATE "
    , escape table
    , " SET "
    , escape n
    , "="
    , s
    , " WHERE "
    , escape n
    , " IS NULL"
    ]
showAlter table (reftable, AddReference fkeyname t2 id2) = T.concat
    [ "ALTER TABLE "
    , escape table
    , " ADD CONSTRAINT "
    , escape fkeyname
    , " FOREIGN KEY("
    , T.intercalate "," $ map escape t2
    , ") REFERENCES "
    , escape reftable
    , "("
    , T.intercalate "," id2
    , ")"
    ]
showAlter table (_, DropReference cname) = T.concat
    [ "ALTER TABLE "
    , escape table
    , " DROP CONSTRAINT "
    , escape cname
    ]

-- | Get the SQL string for the table that a 'PersistEntity' represents.
-- Useful for raw SQL queries.
tableName :: forall record.
             ( PersistEntity record
             , PersistEntityBackend record ~ SqlBackend
             )
          => record -> Text
tableName = escape . tableDBName

-- | Get the SQL string for the field that an 'EntityField' represents.
-- Useful for raw SQL queries.
fieldName :: forall record typ.
             ( PersistEntity record
             , PersistEntityBackend record ~ SqlBackend
             )
          => EntityField record typ -> Text
fieldName = escape . fieldDBName

-- | Quote an identifier with double quotes, doubling any embedded quote.
escape :: DBName -> Text
escape (DBName s) = T.concat ["\"", T.replace "\"" "\"\"" s, "\""]

-- | Information required to connect to a PostgreSQL database
-- using @persistent@'s generic facilities.  These values are the
-- same that are given to 'withPostgresqlPool'.
data PostgresConf = PostgresConf
    { pgConnStr  :: ConnectionString
      -- ^ The connection string.
    , pgPoolSize :: Int
      -- ^ How many connections should be held on the connection pool.
    } deriving Show

-- | Reads @database@, @host@, @user@, @password@, @poolsize@ and an
-- optional @port@ (default 5432).
instance FromJSON PostgresConf where
    parseJSON v = modifyFailure ("Persistent: error loading PostgreSQL conf: " ++) $
        flip (withObject "PostgresConf") v $ \o -> do
            database <- o .: "database"
            host     <- o .: "host"
            port     <- o .:? "port" .!= 5432
            user     <- o .: "user"
            password <- o .: "password"
            pool     <- o .: "poolsize"
            let ci = PG.ConnectInfo
                    { PG.connectHost     = host
                    , PG.connectPort     = port
                    , PG.connectUser     = user
                    , PG.connectPassword = password
                    , PG.connectDatabase = database
                    }
                cstr = PG.postgreSQLConnectionString ci
            return $ PostgresConf cstr pool

instance PersistConfig PostgresConf where
    type PersistConfigBackend PostgresConf = SqlPersistT
    type PersistConfigPool PostgresConf = ConnectionPool
    createPoolConfig (PostgresConf cs size) = runNoLoggingT $ createPostgresqlPool cs size -- FIXME
    runPool _ = runSqlPool
    loadConfig = parseJSON

    -- Append @key='value'@ pairs from the environment to the connection
    -- string; libpq uses the last occurrence of a repeated keyword.
    -- NOTE(review): the password is read from PGPASS, whereas libpq's own
    -- variable is PGPASSWORD; kept as-is for compatibility.
    applyEnv c0 = do
        env <- getEnvironment
        return $ addUser env
               $ addPass env
               $ addDatabase env
               $ addPort env
               $ addHost env c0
      where
        addParam param val c =
            c { pgConnStr = B8.concat [pgConnStr c, " ", param, "='", pgescape val, "'"] }

        -- Backslash-escape quotes and backslashes for a single-quoted
        -- libpq value, then encode as UTF-8 (B8.pack would truncate any
        -- non-Latin-1 character, e.g. in a password).
        pgescape = T.encodeUtf8 . T.pack . go
          where
            go ('\'':rest) = '\\' : '\'' : go rest
            go ('\\':rest) = '\\' : '\\' : go rest
            go (x:rest)    = x : go rest
            go []          = []

        maybeAddParam param envvar env =
            maybe id (addParam param) $ lookup envvar env

        addHost     = maybeAddParam "host"     "PGHOST"
        addPort     = maybeAddParam "port"     "PGPORT"
        addUser     = maybeAddParam "user"     "PGUSER"
        addPass     = maybeAddParam "password" "PGPASS"
        addDatabase = maybeAddParam "dbname"   "PGDATABASE"

-- | Conventional foreign-key constraint name: @<table>_<column>_fkey@.
refName :: DBName -> DBName -> DBName
refName (DBName table) (DBName column) =
    DBName $ T.concat [table, "_", column, "_fkey"]

-- | A unique constraint as (constraint name, its column names).
udToPair :: UniqueDef -> (DBName, [DBName])
udToPair ud = (uniqueDBName ud, map snd $ uniqueFields ud)