module Database.Persist.MigratePostgres
( getMigrationStrategy
) where
import Database.Persist.Sql
import Control.Monad.IO.Class (MonadIO (..))
import Control.Monad.Trans.Resource (runResourceT)
import qualified Data.Text as T
import Data.Text (pack,Text)
import Data.Either (partitionEithers)
import Control.Arrow
import Data.List (find, intercalate, sort, groupBy)
import Data.Function (on)
import Data.Conduit
import qualified Data.Conduit.List as CL
import Data.Maybe (mapMaybe)
import qualified Data.Text.Encoding as T
import Database.Persist.ODBCTypes
#if DEBUG
import Debug.Trace
-- | Debug build: print the message (via 'trace') and then yield the second
-- argument.
tracex :: String -> a -> a
tracex msg value = trace msg value
#else
-- | Non-debug build: the message is discarded and the second argument is
-- returned unchanged.
tracex :: String -> a -> a
tracex = const id
#endif
-- | Build the Postgres 'MigrationStrategy' for the given 'DBType'.
--
-- Only a 'Postgres' value is accepted; any other database type aborts with
-- 'error', since this module only knows how to migrate Postgres schemas.
getMigrationStrategy :: DBType -> MigrationStrategy
getMigrationStrategy dbtype =
    case dbtype of
        Postgres {} ->
            MigrationStrategy
                { dbmsLimitOffset = decorateSQLWithLimitOffset "LIMIT ALL"
                , dbmsMigrate     = migrate'
                , dbmsInsertSql   = insertSql'
                , dbmsEscape      = escape
                , dbmsType        = dbtype
                }
        _ -> error $ "Postgres: calling with invalid dbtype " ++ show dbtype
-- | Compute the migration statements for one entity.
--
-- The table's current columns and unique constraints are read with
-- 'getColumns'. If any of them could not be decoded, the decode errors are
-- returned in 'Left'. Otherwise:
--
-- * if nothing was found, the table is treated as new. The result is a
--   @CREATE TABLE@ statement followed by its unique constraints, the foreign
--   keys of individual referencing columns, and the foreign keys built from
--   the entity's foreign definitions;
-- * otherwise the existing schema is diffed against the entity with
--   'getAlters'.
--
-- Each rendered statement is paired with the flag produced by 'showAlterDb'.
migrate' :: [EntityDef a]
         -> (Text -> IO Statement)
         -> EntityDef SqlType
         -> IO (Either [Text] [(Bool, Text)])
migrate' allDefs getter val = fmap (fmap $ map showAlterDb) $ do
    let name = entityDB val
    old <- getColumns getter val
    case partitionEithers old of
        ([], old'') -> do
            let old' = partitionEithers old''
            let (newcols', udefs, fdefs) = mkColumns allDefs val
            -- Columns whose field is marked SafeToRemove are never (re)created.
            let newcols = filter (not . safeToRemove val . cName) newcols'
            let udspair = map udToPair udefs
            if null old
                then do
                    let idtxt = case entityPrimary val of
                            Just pdef -> concat
                                [ " PRIMARY KEY ("
                                , intercalate "," $ map (T.unpack . escape . snd) $ primaryFields pdef
                                , ")"
                                ]
                            Nothing -> concat
                                [ T.unpack $ escape $ entityID val
                                , " SERIAL PRIMARY KEY UNIQUE"
                                ]
                    let addTable = AddTable $ concat
                            [ "CREATE TABLE "
                            , T.unpack $ escape name
                            , "("
                            , idtxt
                            , if null newcols then [] else ","
                            , intercalate "," $ map showColumn newcols
                            , ")"
                            ]
                    let uniques =
                            [ AlterTable name $ AddUniqueConstraint uname ucols
                            | (uname, ucols) <- udspair
                            ]
                    -- Foreign keys for individual columns that carry a reference.
                    -- The record pattern skips columns without one, so no
                    -- partial lambda is needed.
                    let references =
                            [ alter
                            | c@Column { cName = cname, cReference = Just (refTblName, _) } <- newcols
                            , Just alter <- [getAddReference allDefs name refTblName cname (cReference c)]
                            ]
                    -- Foreign keys built from the entity's foreign definitions.
                    let foreignsAlt =
                            [ AlterColumn name
                                ( foreignRefTableDBName fdef
                                , AddReference (foreignConstraintNameDBName fdef) childfields parentfields
                                )
                            | fdef <- fdefs
                            , let (childfields, parentfields) =
                                      unzip [ (b, d) | (_, b, _, d) <- foreignFields fdef ]
                            ]
                    return $ Right $ addTable : uniques ++ references ++ foreignsAlt
                else do
                    let (acs, ats) = getAlters allDefs val (newcols, udspair) old'
                    return $ Right $ map (AlterColumn name) acs ++ map (AlterTable name) ats
        (errs, _) -> return $ Left errs
-- | Whether dropping a column has been explicitly allowed, via a
-- @SafeToRemove@ field attribute (see 'safeToRemove').
type SafeToRemove = Bool

-- | A single change to one column of a table.
data AlterColumn
    = Type SqlType          -- ^ change the column's type
    | IsNull                -- ^ make the column nullable
    | NotNull               -- ^ make the column NOT NULL
    | Add' Column           -- ^ add a new column
    | Drop SafeToRemove     -- ^ drop the column
    | Default String        -- ^ set the column default expression
    | NoDefault             -- ^ remove the column default
    | Update' String        -- ^ fill existing NULLs with the given expression
    | AddReference DBName [DBName] [DBName]
      -- ^ add a foreign key: constraint name, local columns, referenced columns
    | DropReference DBName  -- ^ drop the named foreign-key constraint

-- | A column change paired with the name it targets. That name is normally
-- the column, but for 'AddReference' it is the referenced table.
type AlterColumn' = (DBName, AlterColumn)

-- | A table-level change.
data AlterTable
    = AddUniqueConstraint DBName [DBName]  -- ^ constraint name and its columns
    | DropConstraint DBName                -- ^ drop the named constraint

-- | One migration step.
data AlterDB
    = AddTable String                 -- ^ a complete @CREATE TABLE@ statement
    | AlterColumn DBName AlterColumn' -- ^ a column change on the named table
    | AlterTable DBName AlterTable    -- ^ a table change on the named table
-- | Read the current schema of an entity's table from @information_schema@.
--
-- The result lists the table's columns first (the id column is excluded),
-- each decoded by 'getColumn'. After them come its key constraints other than
-- PRIMARY KEY / FOREIGN KEY, each as a constraint name with its columns. A
-- column row that could not be decoded is reported as a @Left@ message.
--
-- Every column of each constraint is fetched, ordered by constraint and then
-- column. A multi-column unique constraint therefore keeps all of its
-- columns and compares correctly in 'getAlters'. The previous
-- @ordinal_position=1@ filter kept only the first column, so such
-- constraints were dropped and re-added on every migration.
getColumns :: (Text -> IO Statement)
           -> EntityDef a
           -> IO [Either Text (Either Column (DBName, [DBName]))]
getColumns getter def = do
    stmt <- getter $ pack sqlv
    cs <- runResourceT $ stmtQuery stmt vals $$ helperClmns
    stmt' <- getter $ pack sqlc
    us <- runResourceT $ stmtQuery stmt' vals $$ helperU
    -- Debug output goes through tracex, so it is only printed in DEBUG builds
    -- instead of on every migration.
    return $ tracex ("\n\ngetColumns cs=" ++ show cs ++ "\n\nus=" ++ show us) (cs ++ us)
  where
    -- Both queries are parameterised by the table name and the id column name.
    vals =
        [ PersistText $ unDBName $ entityDB def
        , PersistText $ unDBName $ entityID def
        ]
    sqlv = concat
        [ "SELECT "
        , "column_name "
        , ",is_nullable "
        , ",udt_name "
        , ",column_default "
        , ",numeric_precision "
        , ",numeric_scale "
        , "FROM information_schema.columns "
        , "WHERE table_catalog=current_database() "
        , "AND table_schema=current_schema() "
        , "AND table_name=? "
        , "AND column_name <> ?"
        ]
    sqlc = concat
        [ "SELECT "
        , "c.constraint_name, "
        , "c.column_name "
        , "FROM information_schema.key_column_usage c, "
        , "information_schema.table_constraints k "
        , "WHERE c.table_catalog=current_database() "
        , "AND c.table_catalog=k.table_catalog "
        , "AND c.table_schema=current_schema() "
        , "AND c.table_schema=k.table_schema "
        , "AND c.table_name=? "
        , "AND c.table_name=k.table_name "
        , "AND c.column_name <> ? "
        , "AND c.constraint_name=k.constraint_name "
        , "AND k.constraint_type not in ('PRIMARY KEY','FOREIGN KEY') "
        , "ORDER BY c.constraint_name, c.column_name"
        ]
    -- Collect all (constraint, column) rows, accepting either text encoding.
    getAll front = do
        x <- CL.head
        case x of
            Nothing -> return $ front []
            Just [PersistText con, PersistText col] -> getAll (front . (:) (con, col))
            Just [PersistByteString con, PersistByteString col] ->
                getAll (front . (:) (T.decodeUtf8 con, T.decodeUtf8 col))
            Just xx -> error $ "oops: unexpected datatype returned odbc postgres xx=" ++ show xx
    -- Group consecutive rows by constraint name. The pattern replaces the
    -- partial 'head' (groupBy never yields an empty group anyway).
    helperU = do
        rows <- getAll id
        return
            [ Right (Right (DBName con, map (DBName . snd) grp))
            | grp@((con, _) : _) <- groupBy ((==) `on` fst) rows
            ]
    helperClmns = CL.mapM getIt =$ CL.consume
      where
        getIt = fmap (either Left (Right . Left))
              . liftIO
              . getColumn getter (entityDB def)
-- | True when some field of the entity is stored in the given column and
-- lists @SafeToRemove@ among its attributes.
safeToRemove :: EntityDef a -> DBName -> Bool
safeToRemove def (DBName rawName) =
    or [ "SafeToRemove" `elem` fieldAttrs field
       | field <- entityFields def
       , fieldDB field == DBName rawName
       ]
-- | Diff the desired columns and unique constraints (first pair) against
-- those currently in the database (second pair).
--
-- Desired columns are matched by name with 'findAlters'. Existing columns
-- left unmatched are dropped, with the drop flagged via 'safeToRemove'.
-- Unique constraints are added, or dropped and re-added when their column
-- sets differ. Unmatched existing constraints are dropped, except those
-- whose name starts with @__manual_@.
getAlters :: [EntityDef a]
          -> EntityDef SqlType
          -> ([Column], [(DBName, [DBName])])
          -> ([Column], [(DBName, [DBName])])
          -> ([AlterColumn'], [AlterTable])
getAlters allDefs def (c1, u1) (c2, u2) = (columnChanges c1 c2, uniqueChanges u1 u2)
  where
    -- Each desired column consumes its existing counterpart; leftovers get dropped.
    columnChanges [] leftover =
        [ (cName c, Drop $ safeToRemove def $ cName c) | c <- leftover ]
    columnChanges (wanted : rest) existing =
        let (changes, remaining) = findAlters allDefs (entityDB def) wanted existing
        in changes ++ columnChanges rest remaining

    uniqueChanges :: [(DBName, [DBName])]
                  -> [(DBName, [DBName])]
                  -> [AlterTable]
    uniqueChanges [] leftover =
        [ DropConstraint cname | (cname, _) <- leftover, not (isManual cname) ]
    uniqueChanges ((cname, cols) : rest) existing =
        case lookup cname existing of
            Nothing -> AddUniqueConstraint cname cols : uniqueChanges rest existing
            Just oldCols ->
                let remaining = filter ((/= cname) . fst) existing
                    next = uniqueChanges rest remaining
                in if sort cols == sort oldCols
                       then next
                       else DropConstraint cname : AddUniqueConstraint cname cols : next

    isManual (DBName x) = "__manual_" `T.isPrefixOf` x
-- | Decode one row of the column query issued by 'getColumns' into a 'Column'.
--
-- The row holds, in order: column name, @is_nullable@, @udt_name@,
-- @column_default@, @numeric_precision@ and @numeric_scale@. The three text
-- fields are accepted as either 'PersistText' or 'PersistByteString' (decoded
-- as UTF-8), the same as the file's other result decoders. Previously only
-- 'PersistByteString' was accepted.
--
-- The column's foreign-key reference is looked up with a second query,
-- issued through @getter@, for the constraint named by 'refName'. An
-- unrecognised default, an undecodable numeric type or any other row shape
-- yields a 'Left' message.
getColumn :: (Text -> IO Statement)
          -> DBName -> [PersistValue]
          -> IO (Either Text Column)
getColumn getter tname [rawName, rawNull, rawType, d, npre, nscl]
    | Just colName <- asText rawName
    , Just nullable <- asText rawNull
    , Just udtName <- asText rawType
    = case d' of
        Left s -> return $ Left s
        Right d'' ->
            case getType udtName of
                Left s -> return $ Left s
                Right t -> do
                    let cname = DBName colName
                    ref <- getRef cname
                    return $ Right Column
                        { cName = cname
                        , cNull = nullable == "YES"
                        , cSqlType = t
                        , cDefault = d''
                        , cDefaultConstraintName = Nothing
                        , cMaxLen = Nothing
                        , cReference = ref
                        }
  where
    -- Accept either text encoding an ODBC driver may return.
    asText (PersistText t) = Just t
    asText (PersistByteString bs) = Just (T.decodeUtf8 bs)
    asText _ = Nothing

    getRef cname = do
        let sql = pack $ concat
                [ "SELECT "
                , "tc.table_name, "
                , "kcu.column_name, "
                , "ccu.table_name AS foreign_table_name, "
                , "ccu.column_name AS foreign_column_name, "
                , "kcu.ordinal_position "
                , "FROM "
                , "information_schema.table_constraints AS tc "
                , "JOIN information_schema.key_column_usage "
                , "AS kcu ON tc.constraint_name = kcu.constraint_name "
                , "JOIN information_schema.constraint_column_usage "
                , "AS ccu ON ccu.constraint_name = tc.constraint_name "
                , "WHERE constraint_type = 'FOREIGN KEY' "
                , "and tc.table_name=? "
                , "and tc.constraint_name=? "
                , "and tc.table_catalog=current_database() "
                , "AND tc.table_catalog=kcu.table_catalog "
                , "AND tc.table_catalog=ccu.table_catalog "
                , "AND tc.table_schema=current_schema() "
                , "AND tc.table_schema=kcu.table_schema "
                , "AND tc.table_schema=ccu.table_schema "
                , "order by tc.table_name,tc.constraint_name, kcu.ordinal_position "
                ]
        let ref = refName tname cname
        stmt <- getter sql
        runResourceT $ stmtQuery stmt
            [ PersistText $ unDBName tname
            , PersistText $ unDBName ref
            ] $$ do
                m <- CL.head
                return $ case m of
                    Nothing -> Nothing
                    -- Only the referenced table is used. The other columns
                    -- (including ordinal_position) are ignored whatever their
                    -- encoding.
                    Just [_table, _col, rawRef, _refcol, _pos]
                        | Just reftable <- asText rawRef -> Just (DBName reftable, ref)
                    _ -> error $ "unexpected result found [" ++ show m ++ "]"

    d' = case d of
        PersistNull -> Right Nothing
        PersistText t -> Right $ Just t
        PersistByteString bs -> Right $ Just $ T.decodeUtf8 bs
        _ -> Left $ pack $ "Invalid default column: " ++ show d

    getType "int4" = Right $ SqlInt32
    getType "int8" = Right $ SqlInt64
    getType "varchar" = Right $ SqlString
    getType "date" = Right $ SqlDay
    getType "bool" = Right $ SqlBool
    getType "timestamp" = Right $ SqlDayTime
    getType "timestamptz" = Right $ SqlDayTimeZoned
    getType "float4" = Right $ SqlReal
    getType "float8" = Right $ SqlReal
    getType "bytea" = Right $ SqlBlob
    getType "time" = Right $ SqlTime
    getType "numeric" = getNumeric npre nscl
    getType a = Right $ SqlOther a

    getNumeric (PersistInt64 a) (PersistInt64 b) = Right $ SqlNumeric (fromIntegral a) (fromIntegral b)
    getNumeric a b = Left $ pack $ "Can not get numeric field precision, got: " ++ show a ++ " and " ++ show b ++ " as precision and scale"
getColumn _ a2 x =
    return $ Left $ pack $ "Invalid result from information_schema: " ++ show x ++ " a2[" ++ show a2 ++ "]"
-- | Compare one desired column with the existing columns of table
-- @tablename@.
--
-- Returns the changes needed for that column, together with the existing
-- columns minus the one matched by name. The caller treats whatever remains
-- as columns to drop. A column with no existing counterpart is added.
--
-- For a matched column the changes come in this order:
--
-- 1. foreign-key changes: drop the old constraint and/or add the new one,
--    when the constraint names differ;
-- 2. default changes;
-- 3. nullability changes: existing NULLs are back-filled with the default
--    before @SET NOT NULL@ when a default exists;
-- 4. type changes.
findAlters :: [EntityDef a] -> DBName -> Column -> [Column] -> ([AlterColumn'], [Column])
findAlters defs tablename col@(Column name isNull sqltype def _defConstraintName _maxLen ref) cols =
    tracex ("\n\n\nfindAlters tablename="++show tablename++ " name="++ show name++" col="++show col++"\ncols="++show cols++"\n\n\n") $
    case filter ((name ==) . cName) cols of
        [] -> ([(name, Add' col)], cols)
        Column _ isNull' sqltype' def' defConstraintName' _maxLen' ref' : _ ->
            let refDrop Nothing = []
                refDrop (Just (_, cname)) =
                    tracex ("\n\n\n44444 findAlters dropping fkey defConstraintName'="++show defConstraintName' ++" name="++show name++" cname="++show cname++" tablename="++show tablename++"\n\n\n") $
                    [(name, DropReference cname)]
                -- The new reference points at the id column of the referenced entity.
                refAdd Nothing = []
                refAdd (Just (tname, a)) =
                    tracex ("\n\n\n33333 findAlters adding fkey defConstraintName'="++show defConstraintName' ++" name="++show name++" tname="++show tname++" a="++show a++" tablename="++show tablename++"\n\n\n") $
                    case find ((==tname) . entityDB) defs of
                        Just refdef -> [(tname, AddReference a [name] [entityID refdef])]
                        Nothing -> error $ "could not find the entityDef for reftable[" ++ show tname ++ "]"
                -- References are compared by constraint name only. The trace
                -- label previously said "modType" and printed the SQL types;
                -- it now reports the refs being compared.
                modRef = tracex ("modRef: ref[" ++ show ref ++ "] ref'[" ++ show ref' ++ "] name=" ++ show name) $
                    if fmap snd ref == fmap snd ref'
                        then []
                        else tracex ("\n\n\nmodRef findAlters drop/add cos ref doesnt match ref[" ++ show ref ++ "] ref'[" ++ show ref' ++ "] tablename="++show tablename++"\n\n\n") $
                             refDrop ref' ++ refAdd ref
                modNull = case (isNull, isNull') of
                    (True, False) -> [(name, IsNull)]
                    (False, True) ->
                        -- Back-fill existing NULLs with the default before
                        -- the NOT NULL constraint is applied.
                        let up = case def of
                                Nothing -> id
                                Just s -> (:) (name, Update' $ T.unpack s)
                        in up [(name, NotNull)]
                    _ -> []
                modType = tracex ("modType: sqltype[" ++ show sqltype ++ "] sqltype'[" ++ show sqltype' ++ "] name=" ++ show name) $
                    if sqltype == sqltype' then [] else [(name, Type sqltype)]
                modDef = tracex ("modDef col=" ++ show col ++ " def=" ++ show def ++ " def'=" ++ show def') $
                    if cmpdef def def'
                        then []
                        else case def of
                            Nothing -> [(name, NoDefault)]
                            Just s -> [(name, Default $ T.unpack s)]
            in ( modRef ++ modDef ++ modNull ++ modType
               , filter (\c -> cName c /= name) cols
               )
-- | Decide whether a desired column default matches the one the database
-- reports.
--
-- Two absent defaults match, and so do identical texts. A reported default
-- also matches when it ends in a cast: it is cut after its last @:@, and if
-- that prefix ends in @::@, the text before the @::@ is compared instead
-- (e.g. @'x'::character varying@ matches @'x'@).
cmpdef :: Maybe Text -> Maybe Text -> Bool
cmpdef Nothing Nothing = True
cmpdef (Just wanted) (Just actual)
    | wanted == actual = True
    | otherwise = maybe False (== wanted) (T.stripSuffix "::" uptoLastColon)
  where
    (uptoLastColon, _) = T.breakOnEnd ":" actual
cmpdef _ _ = False
-- | Build the foreign-key step for column @cname@ of @table@.
--
-- Returns 'Nothing' when @ref@ is 'Nothing'. Otherwise the constraint is
-- named by 'refName' and references the id column of the entity whose table
-- is @reftable@. Forcing that id aborts with 'error' if no such entity
-- exists in @allDefs@.
getAddReference :: [EntityDef a] -> DBName -> DBName -> DBName -> Maybe (DBName, DBName) -> Maybe AlterDB
getAddReference allDefs table reftable cname ref = do
    (s, z) <- ref
    tracex ("\n\ngetaddreference table="++ show table++" reftable="++show reftable++" s="++show s++" z=" ++ show z++"\n\n") $
        Just $ AlterColumn table (s, AddReference (refName table cname) [cname] [targetId])
  where
    targetId =
        case find ((== reftable) . entityDB) allDefs of
            Just entDef -> entityID entDef
            Nothing -> error $ "Could not find ID of entity " ++ show reftable
-- | Render a column definition: escaped name, SQL type, @NULL@/@NOT NULL@
-- and an optional @DEFAULT@ clause.
showColumn :: Column -> String
showColumn (Column colName nullable colType colDefault _ maxLen _) =
    T.unpack (escape colName)
        ++ " " ++ showSqlType colType maxLen
        ++ " " ++ (if nullable then "NULL" else "NOT NULL")
        ++ maybe "" (\expr -> " DEFAULT " ++ T.unpack expr) colDefault
-- | Postgres type name for a 'SqlType'. The length argument only matters
-- for 'SqlString', where it becomes @VARCHAR(n)@.
showSqlType :: SqlType -> Maybe Integer -> String
showSqlType sqlType maxLen =
    case sqlType of
        SqlString -> maybe "VARCHAR" (\len -> "VARCHAR(" ++ show len ++ ")") maxLen
        SqlInt32 -> "INT4"
        SqlInt64 -> "INT8"
        SqlReal -> "DOUBLE PRECISION"
        SqlNumeric s prec -> "NUMERIC(" ++ show s ++ "," ++ show prec ++ ")"
        SqlDay -> "DATE"
        SqlTime -> "TIME"
        SqlDayTime -> "TIMESTAMP"
        SqlDayTimeZoned -> "TIMESTAMP WITH TIME ZONE"
        SqlBlob -> "BYTEA"
        SqlBool -> "BOOLEAN"
        SqlOther t -> T.unpack t
-- | Render a migration step as SQL, paired with an "unsafe" flag. The flag
-- is 'True' only for a column drop that was not marked safe to remove.
showAlterDb :: AlterDB -> (Bool, Text)
showAlterDb step =
    case step of
        AddTable s -> (False, pack s)
        AlterColumn t (c, ac) -> (isDangerous ac, pack $ showAlter t (c, ac))
        AlterTable t at -> (False, pack $ showAlterTable t at)
  where
    isDangerous (Drop markedSafe) = not markedSafe
    isDangerous _ = False
-- | Render a table-level change as an @ALTER TABLE@ statement.
showAlterTable :: DBName -> AlterTable -> String
showAlterTable table change =
    case change of
        AddUniqueConstraint cname cols ->
            prefix ++ " ADD CONSTRAINT " ++ esc cname
                ++ " UNIQUE(" ++ intercalate "," (map esc cols) ++ ")"
        DropConstraint cname ->
            prefix ++ " DROP CONSTRAINT " ++ esc cname
  where
    esc = T.unpack . escape
    prefix = "ALTER TABLE " ++ esc table
-- | Render a column change on @table@ as SQL.
--
-- The first tuple component is the column being changed. For
-- 'AddReference' it is the referenced table instead.
showAlter :: DBName -> AlterColumn' -> String
showAlter table (target, change) =
    case change of
        Type t -> onColumn (" TYPE " ++ showSqlType t Nothing)
        IsNull -> onColumn " DROP NOT NULL"
        NotNull -> onColumn " SET NOT NULL"
        Add' col -> onTable (" ADD COLUMN " ++ showColumn col)
        Drop _ -> onTable (" DROP COLUMN " ++ esc target)
        Default s -> onColumn (" SET DEFAULT " ++ s)
        NoDefault -> onColumn " DROP DEFAULT"
        Update' s -> concat
            [ "UPDATE ", esc table
            , " SET ", esc target, "=", s
            , " WHERE ", esc target, " IS NULL"
            ]
        AddReference fkeyname t2 id2 -> onTable $ concat
            [ " ADD CONSTRAINT ", esc fkeyname
            , " FOREIGN KEY(", escList t2
            , ") REFERENCES ", esc target
            , "(", escList id2, ")"
            ]
        DropReference cname -> onTable (" DROP CONSTRAINT " ++ esc cname)
  where
    esc = T.unpack . escape
    escList = T.unpack . T.intercalate "," . map escape
    onTable rest = "ALTER TABLE " ++ esc table ++ rest
    onColumn rest = onTable (" ALTER COLUMN " ++ esc target ++ rest)
-- | Quote an identifier for Postgres: wrap it in double quotes and double
-- every embedded double quote.
escape :: DBName -> Text
escape (DBName s) = T.concat ["\"", T.replace "\"" "\"\"" s, "\""]
-- | Name of the foreign-key constraint for a column: @<table>_<column>_fkey@.
refName :: DBName -> DBName -> DBName
refName (DBName table) (DBName column) =
    DBName $ T.intercalate "_" [table, column, "fkey"]
-- | A unique definition as (constraint name, database column names).
udToPair :: UniqueDef -> (DBName, [DBName])
udToPair ud = (uniqueDBName ud, [ dbCol | (_, dbCol) <- uniqueFields ud ])
-- | Build the INSERT for an entity, with one @?@ placeholder per field.
--
-- An entity with an explicit primary key gets a plain INSERT
-- ('ISRManyKeys', passing @vals@ through). Otherwise the statement appends
-- @RETURNING <id column>@ ('ISRSingle').
insertSql' :: EntityDef SqlType -> [PersistValue] -> InsertSqlResult
insertSql' ent vals =
    tracex ("\n\n\nGBTEST " ++ show (entityFields ent) ++ "\n\n\n") $
    case entityPrimary ent of
        Just _ -> ISRManyKeys (pack insertStmt) vals
        Nothing -> ISRSingle $ pack $ insertStmt ++ " RETURNING " ++ T.unpack (escape (entityID ent))
  where
    fields = entityFields ent
    insertStmt = concat
        [ "INSERT INTO "
        , T.unpack $ escape $ entityDB ent
        , "("
        , intercalate "," $ map (T.unpack . escape . fieldDB) fields
        , ") VALUES("
        , intercalate "," $ map (const "?") fields
        , ")"
        ]