module Database.Persist.Postgresql
( withPostgresqlPool
, withPostgresqlConn
, module Database.Persist
, module Database.Persist.GenericSql
) where
import Database.Persist
import Database.Persist.Base
import Database.Persist.GenericSql
import Database.Persist.GenericSql.Internal
import qualified Database.HDBC as H
import qualified Database.HDBC.PostgreSQL as H
import Control.Monad.IO.Class (MonadIO (..))
import Data.List (intercalate)
import "MonadCatchIO-transformers" Control.Monad.CatchIO
import Data.IORef
import qualified Data.Map as Map
import Data.Either (partitionEithers)
import Control.Arrow
import Data.List (sort, groupBy)
import Data.Function (on)
import qualified Data.ByteString.UTF8 as BSU
-- | Run an action with a 'ConnectionPool' of PostgreSQL connections.
--
-- Every connection the pool opens is created by 'open'' from the given
-- connection string. The 'Int' is forwarded unchanged to 'withSqlPool'
-- (presumably the pool size; confirm against 'withSqlPool').
withPostgresqlPool :: MonadCatchIO m
                   => String                   -- ^ connection string for 'H.connectPostgreSQL'
                   -> Int
                   -> (ConnectionPool -> m a) -> m a
withPostgresqlPool = withSqlPool . open'
-- | Open one PostgreSQL connection (via 'open'') and run the action with it,
-- delegating lifetime management to 'withSqlConn'.
withPostgresqlConn :: MonadCatchIO m => String -> (Connection -> m a) -> m a
withPostgresqlConn connStr = withSqlConn (open' connStr)
-- | Connect to PostgreSQL through HDBC and wrap the handle in the generic
-- 'Connection' record used by "Database.Persist.GenericSql".
--
-- A fresh, empty prepared-statement cache ('stmtMap') is created per
-- connection.
open' :: String -> IO Connection
open' connStr = do
    hdbc  <- H.connectPostgreSQL connStr
    cache <- newIORef Map.empty
    return Connection
        { prepare    = prepare' hdbc
        , stmtMap    = cache
        , insertSql  = insertSql'
        , close      = H.disconnect hdbc
        , migrateSql = migrate'
          -- No BEGIN is sent; only commit/rollback reach HDBC. NOTE(review):
          -- presumably relies on HDBC keeping the connection in a transaction.
        , begin      = \_ -> return ()
        , commit     = \_ -> H.commit hdbc
        , rollback   = \_ -> H.rollback hdbc
        , escapeName = escape
        , noLimit    = "LIMIT ALL"
        }
-- | Prepare a SQL string with HDBC and expose it as a generic 'Statement'.
--
-- For this backend 'finalize' and 'reset' do nothing; execution and row
-- fetching are delegated to 'execute'' and 'withStmt''.
prepare' :: H.Connection -> String -> IO Statement
prepare' conn sql = fmap wrap (H.prepare conn sql)
  where
    wrap hstmt = Statement
        { finalize = return ()
        , reset    = return ()
        , execute  = execute' hstmt
        , withStmt = withStmt' hstmt
        }
-- | Build the INSERT statement for a table.
--
-- Always returns 'Left': a single @INSERT ... RETURNING id@ statement with
-- one HDBC @?@ placeholder per column, with table and column names quoted by
-- 'escape'.
insertSql' :: RawName -> [RawName] -> Either String (String, String)
insertSql' t cols = Left $
    "INSERT INTO " ++ escape t
        ++ "(" ++ commaSep (map escape cols)
        ++ ") VALUES(" ++ commaSep (map (const "?") cols)
        ++ ") RETURNING id"
  where
    commaSep = intercalate ","
-- | Execute a prepared HDBC statement with the given parameters (converted
-- by 'pToSql'), discarding the result HDBC returns.
execute' :: H.Statement -> [PersistValue] -> IO ()
execute' stmt vals = H.execute stmt (map pToSql vals) >> return ()
-- | Execute a prepared statement, then hand the callback a row popper.
--
-- Each call of the popper fetches the next row with 'H.fetchRow' and decodes
-- it with 'pFromSql'. It returns 'Nothing' once 'H.fetchRow' does.
withStmt' :: MonadCatchIO m
          => H.Statement
          -> [PersistValue]
          -> (RowPopper m -> m a)
          -> m a
withStmt' stmt vals f = do
    _ <- liftIO $ H.execute stmt (map pToSql vals)
    f (liftIO (fmap decodeRow (H.fetchRow stmt)))
  where
    decodeRow mrow = fmap (map pFromSql) mrow
-- | Convert a 'PersistValue' into the HDBC 'H.SqlValue' passed as a
-- statement parameter.
pToSql :: PersistValue -> H.SqlValue
pToSql value = case value of
    PersistString s     -> H.SqlString s
    PersistByteString b -> H.SqlByteString b
    PersistInt64 n      -> H.SqlInt64 n
    PersistDouble x     -> H.SqlDouble x
    PersistBool b       -> H.SqlBool b
    PersistDay day      -> H.SqlLocalDate day
    PersistTimeOfDay tm -> H.SqlLocalTimeOfDay tm
    PersistUTCTime tm   -> H.SqlUTCTime tm
    PersistNull         -> H.SqlNull
-- | Convert an HDBC result value back into a 'PersistValue'.
--
-- All integral HDBC values (and 'H.SqlChar', via 'fromEnum') become
-- 'PersistInt64' through 'fromIntegral'. NOTE(review): that conversion wraps
-- for values outside the 'Int64' range. Rationals become doubles. Any
-- constructor not listed falls back to its string form via 'H.fromSql'.
pFromSql :: H.SqlValue -> PersistValue
pFromSql sqlValue = case sqlValue of
    H.SqlString s         -> PersistString s
    H.SqlByteString b     -> PersistByteString b
    H.SqlWord32 n         -> asInt64 n
    H.SqlWord64 n         -> asInt64 n
    H.SqlInt32 n          -> asInt64 n
    H.SqlInt64 n          -> asInt64 n
    H.SqlInteger n        -> asInt64 n
    H.SqlChar c           -> asInt64 (fromEnum c)
    H.SqlBool b           -> PersistBool b
    H.SqlDouble x         -> PersistDouble x
    H.SqlRational r       -> PersistDouble (fromRational r)
    H.SqlLocalDate day    -> PersistDay day
    H.SqlLocalTimeOfDay t -> PersistTimeOfDay t
    H.SqlUTCTime t        -> PersistUTCTime t
    H.SqlNull             -> PersistNull
    other                 -> PersistString (H.fromSql other)
  where
    asInt64 n = PersistInt64 (fromIntegral n)
-- | Compute the migration steps for one entity's table.
--
-- Reads the table's existing columns and unique constraints with
-- 'getColumns'. If any column could not be interpreted, the collected error
-- messages are returned in 'Left'. If nothing exists yet, the result is a
-- @CREATE TABLE@ (with a SERIAL @id@ primary key) followed by one
-- @ADD CONSTRAINT ... UNIQUE@ per unique definition. Otherwise it is the diff
-- from 'getAlters'. Each step is rendered by 'showAlterDb', whose 'Bool'
-- flags column drops.
migrate' :: PersistEntity val
         => (String -> IO Statement)
         -> val
         -> IO (Either [String] [(Bool, String)])
migrate' getter val = do
    existing <- getColumns getter name
    return $ case partitionEithers existing of
        ([], found)
            | null existing -> Right $ map showAlterDb createTable
            | otherwise     -> Right $ map showAlterDb $
                                   alterTable (partitionEithers found)
        (errs, _) -> Left errs
  where
    name = rawTableName $ entityDef val
    (newCols, newUniqs) = mkColumns val

    createTable = AddTable createSql : map addUnique newUniqs
    createSql = concat
        [ "CREATE TABLE "
        , escape name
        , "(id SERIAL PRIMARY KEY UNIQUE"
        , concatMap ((',' :) . showColumn) newCols
        , ")"
        ]
    addUnique (uname, ucols) =
        AlterTable name $ AddUniqueConstraint uname ucols

    alterTable oldDefs =
        let (colAlters, tableAlters) = getAlters (newCols, newUniqs) oldDefs
        in map (AlterColumn name) colAlters ++ map (AlterTable name) tableAlters
-- | A single column-level change; rendered to SQL by 'showAlter'.
data AlterColumn = Type SqlType            -- ^ @ALTER COLUMN ... TYPE ...@
                 | IsNull                  -- ^ @ALTER COLUMN ... DROP NOT NULL@
                 | NotNull                 -- ^ @ALTER COLUMN ... SET NOT NULL@
                 | Add Column              -- ^ @ADD COLUMN@ with the full column definition
                 | Drop                    -- ^ @DROP COLUMN@ (flagged by 'showAlterDb')
                 | Default String          -- ^ @ALTER COLUMN ... SET DEFAULT@ with this expression
                 | NoDefault               -- ^ @ALTER COLUMN ... DROP DEFAULT@
                 | Update String           -- ^ @UPDATE ... SET col=expr WHERE col IS NULL@
                 | AddReference RawName    -- ^ add a FOREIGN KEY to this table
                 | DropReference RawName   -- ^ drop the named constraint
-- | A column change paired with the name of the column it applies to.
type AlterColumn' = (RawName, AlterColumn)
-- | A table-level change on unique constraints; rendered by 'showAlterTable'.
data AlterTable = AddUniqueConstraint RawName [RawName]  -- ^ constraint name, columns
                | DropConstraint RawName                 -- ^ constraint name
-- | One migration step, rendered by 'showAlterDb'.
data AlterDB = AddTable String                    -- ^ a complete CREATE TABLE statement
             | AlterColumn RawName AlterColumn'   -- ^ table name, column change
             | AlterTable RawName AlterTable      -- ^ table name, table change
-- | Read a table's current schema from @information_schema@.
--
-- Returns one entry per non-@id@ column (parsed by 'getColumn'; a 'Left' is
-- a parse error) followed by one entry per constraint listed in
-- @constraint_column_usage@ for the table's non-@id@ columns, grouped by
-- constraint name. Constraint rows not made of two byte strings are skipped.
getColumns :: (String -> IO Statement)
           -> RawName -> IO [Either String (Either Column UniqueDef)]
getColumns getter name = do
    colStmt <- getter $
        "SELECT column_name,is_nullable,udt_name,column_default " ++
        "FROM information_schema.columns " ++
        "WHERE table_name=? AND column_name <> 'id'"
    cols <- withStmt colStmt [PersistString $ unRawName name] readColumns
    uniqStmt <- getter $ concat
        [ "SELECT constraint_name, column_name "
        , "FROM information_schema.constraint_column_usage "
        , "WHERE table_name=? AND column_name <> 'id' "
        , "ORDER BY constraint_name, column_name"
        ]
    uniqs <- withStmt uniqStmt [PersistString $ unRawName name] readUniques
    return $ cols ++ uniqs
  where
    -- Drain the popper into (constraint, column) pairs, preserving row order.
    collectPairs pop = go []
      where
        go acc = do
            row <- pop
            case row of
                Nothing -> return (reverse acc)
                Just [PersistByteString con, PersistByteString col] ->
                    go ((BSU.toString con, BSU.toString col) : acc)
                Just _ -> go acc

    -- Consecutive rows with the same constraint name form one UniqueDef.
    readUniques pop = do
        pairs <- collectPairs pop
        return [ Right (Right (RawName cname, map (RawName . snd) grp))
               | grp@((cname, _) : _) <- groupBy ((==) `on` fst) pairs
               ]

    readColumns pop = do
        row <- pop
        case row of
            Nothing -> return []
            Just vals -> do
                col  <- getColumn getter name vals
                rest <- readColumns pop
                return (fmap Left col : rest)
-- | Diff the desired schema (first argument) against the schema found in the
-- database (second argument).
--
-- Columns are matched by name with 'findAlters'. Existing columns with no
-- desired counterpart are dropped. Unique constraints are matched by name:
-- a missing one is added, a leftover one is dropped, and one whose column
-- set differs is dropped and re-added.
getAlters :: ([Column], [UniqueDef])
          -> ([Column], [UniqueDef])
          -> ([AlterColumn'], [AlterTable])
getAlters (c1, u1) (c2, u2) =
    (getAltersC c1 c2, getAltersU u1 u2)
  where
    getAltersC [] old = map (\x -> (cName x, Drop)) old
    getAltersC (new:news) old =
        let (alters, old') = findAlters new old
        in alters ++ getAltersC news old'

    getAltersU [] old = map (DropConstraint . fst) old
    getAltersU ((name, cols):news) old =
        case lookup name old of
            Nothing -> AddUniqueConstraint name cols : getAltersU news old
            Just ocols ->
                let old' = filter (\(x, _) -> x /= name) old
                -- Sort BOTH sides. The existing columns arrive in SQL
                -- ORDER BY order, which follows the database collation and
                -- need not match Haskell's 'compare' on String. Without this,
                -- an unchanged constraint could be dropped and re-added on
                -- every migration.
                in if sort cols == sort ocols
                    then getAltersU news old'
                    else DropConstraint name
                       : AddUniqueConstraint name cols
                       : getAltersU news old'
-- | Parse one row of the @information_schema.columns@ query issued by
-- 'getColumns' (column_name, is_nullable, udt_name, column_default) into a
-- 'Column'.
--
-- Returns 'Left' for an unrecognised default value, an unknown @udt_name@,
-- or a row of unexpected shape. The column's reference is looked up with a
-- separate query: if a FOREIGN KEY constraint named @refName tname cname@
-- exists, the reference is reported with an empty table name and that
-- constraint name.
getColumn :: (String -> IO Statement)
          -> RawName -> [PersistValue]
          -> IO (Either String Column)
getColumn getter tname
          [PersistByteString x, PersistByteString y,
           PersistByteString z, d] =
    case d' of
        Left s -> return $ Left s
        Right d'' ->
            case getType $ BSU.toString z of
                Left s -> return $ Left s
                Right t -> do
                    let cname = RawName $ BSU.toString x
                    ref <- getRef cname
                    return $ Right $ Column cname (BSU.toString y == "YES")
                                            t d'' ref
  where
    getRef cname = do
        let sql = concat
                [ "SELECT COUNT(*) FROM "
                , "information_schema.table_constraints "
                , "WHERE table_name=? "
                , "AND constraint_type='FOREIGN KEY' "
                , "AND constraint_name=?"
                ]
            ref = refName tname cname
        stmt <- getter sql
        withStmt stmt
            [ PersistString $ unRawName tname
            , PersistString $ unRawName ref
            ] $ \pop -> do
                -- COUNT(*) yields exactly one integer row; any other shape
                -- fails this pattern and raises in IO.
                Just [PersistInt64 i] <- pop
                return $ if i == 0 then Nothing else Just (RawName "", ref)
    d' = case d of
        PersistNull -> Right Nothing
        PersistByteString a -> Right $ Just $ BSU.toString a
        _ -> Left $ "Invalid default column: " ++ show d
    -- Map PostgreSQL udt_name values back to SqlType. This must cover every
    -- type 'showSqlType' can emit, or a later migration of a table created
    -- by this module fails with "Unknown type".
    getType "int4"      = Right $ SqlInt32
    getType "int8"      = Right $ SqlInteger
    getType "varchar"   = Right $ SqlString
    getType "date"      = Right $ SqlDay
    getType "bool"      = Right $ SqlBool
    getType "timestamp" = Right $ SqlDayTime
    getType "time"      = Right $ SqlTime   -- emitted as TIME by showSqlType
    getType "float4"    = Right $ SqlReal
    getType "float8"    = Right $ SqlReal
    getType "bytea"     = Right $ SqlBlob
    getType a           = Left $ "Unknown type: " ++ a
getColumn _ _ x =
    return $ Left $ "Invalid result from information_schema: " ++ show x
-- | Compare one desired column with the existing columns.
--
-- If no existing column has the same name, the column is added. Otherwise
-- the first same-named column is compared, and the changes are emitted in
-- this order: reference, default, nullability, type. When a column becomes
-- NOT NULL and has a default, an UPDATE filling NULLs with that default
-- precedes SET NOT NULL. The second component is the existing columns minus
-- every column with this name.
findAlters :: Column -> [Column] -> ([AlterColumn'], [Column])
findAlters col@(Column name isNull type_ def ref) cols =
    case filter ((== name) . cName) cols of
        [] -> ([(name, Add col)], cols)
        (Column _ oldNull oldType oldDef oldRef : _) ->
            ( concat
                [ refChanges oldRef
                , defChanges oldDef
                , nullChanges oldNull
                , typeChanges oldType
                ]
            , filter ((/= name) . cName) cols
            )
  where
    -- References are compared by constraint name only.
    refChanges oldRef
        | fmap snd ref == fmap snd oldRef = []
        | otherwise = dropRef oldRef ++ addRef ref
    dropRef = maybe [] (\(_, cname) -> [(name, DropReference cname)])
    addRef = maybe [] (\(tname, _) -> [(name, AddReference tname)])

    defChanges oldDef
        | def == oldDef = []
        | otherwise = [(name, maybe NoDefault Default def)]

    nullChanges oldNull
        | isNull && not oldNull = [(name, IsNull)]
        | not isNull && oldNull =
            maybe [] (\s -> [(name, Update s)]) def ++ [(name, NotNull)]
        | otherwise = []

    typeChanges oldType
        | type_ == oldType = []
        | otherwise = [(name, Type type_)]
-- | Render a column definition for CREATE TABLE / ADD COLUMN: quoted name,
-- SQL type, NULL or NOT NULL, an optional DEFAULT expression and an optional
-- REFERENCES clause naming the referenced table.
showColumn :: Column -> String
showColumn (Column n nu t def ref) =
    escape n ++ " " ++ showSqlType t ++ " " ++ nullability
        ++ maybe "" (" DEFAULT " ++) def
        ++ maybe "" ((" REFERENCES " ++) . escape . fst) ref
  where
    nullability = if nu then "NULL" else "NOT NULL"
-- | The PostgreSQL type name emitted for each 'SqlType'.
showSqlType :: SqlType -> String
showSqlType sqlType = case sqlType of
    SqlString  -> "VARCHAR"
    SqlInt32   -> "INT4"
    SqlInteger -> "INT8"
    SqlReal    -> "DOUBLE PRECISION"
    SqlDay     -> "DATE"
    SqlTime    -> "TIME"
    SqlDayTime -> "TIMESTAMP"
    SqlBlob    -> "BYTEA"
    SqlBool    -> "BOOLEAN"
-- | Render a migration step to SQL.
--
-- The 'Bool' is 'True' only for a column 'Drop'. Every other step, including
-- CREATE TABLE and constraint changes, is 'False'.
showAlterDb :: AlterDB -> (Bool, String)
showAlterDb step = case step of
    AddTable sql -> (False, sql)
    AlterColumn t change@(_, ac) -> (dropsColumn ac, showAlter t change)
    AlterTable t at -> (False, showAlterTable t at)
  where
    dropsColumn Drop = True
    dropsColumn _    = False
-- | Render a table-level constraint change as an @ALTER TABLE@ statement.
showAlterTable :: RawName -> AlterTable -> String
showAlterTable table change = "ALTER TABLE " ++ escape table ++ action change
  where
    action (AddUniqueConstraint cname cols) =
        " ADD CONSTRAINT " ++ escape cname
            ++ " UNIQUE(" ++ intercalate "," (map escape cols) ++ ")"
    action (DropConstraint cname) =
        " DROP CONSTRAINT " ++ escape cname
-- | Render a column-level change on the given table as one SQL statement.
--
-- Most changes are @ALTER TABLE@ statements. 'Update' is an @UPDATE@ that
-- replaces NULLs in the column with the given expression. 'AddReference'
-- names its foreign-key constraint with 'refName'.
showAlter :: RawName -> AlterColumn' -> String
showAlter table (n, change) = case change of
    Type t              -> alterCol $ " TYPE " ++ showSqlType t
    IsNull              -> alterCol " DROP NOT NULL"
    NotNull             -> alterCol " SET NOT NULL"
    Add col             -> alterTable $ " ADD COLUMN " ++ showColumn col
    Drop                -> alterTable $ " DROP COLUMN " ++ escape n
    Default s           -> alterCol $ " SET DEFAULT " ++ s
    NoDefault           -> alterCol " DROP DEFAULT"
    Update s            -> concat
        [ "UPDATE ", escape table
        , " SET ", escape n, "=", s
        , " WHERE ", escape n, " IS NULL"
        ]
    AddReference t2     -> alterTable $ concat
        [ " ADD CONSTRAINT ", escape (refName table n)
        , " FOREIGN KEY(", escape n
        , ") REFERENCES ", escape t2
        ]
    DropReference cname -> alterTable $ " DROP CONSTRAINT " ++ escape cname
  where
    alterTable rest = "ALTER TABLE " ++ escape table ++ rest
    alterCol rest   = alterTable $ " ALTER COLUMN " ++ escape n ++ rest
-- | Quote an identifier for PostgreSQL: wrap it in double quotes and double
-- any embedded double quote.
escape :: RawName -> String
escape (RawName s) = '"' : concatMap quoteChar s ++ "\""
  where
    quoteChar '"' = "\"\""
    quoteChar c   = [c]