{-# LANGUAGE ConstraintKinds #-}
{-# LANGUAGE DataKinds #-}
{-# LANGUAGE FlexibleContexts #-}
{-# LANGUAGE GADTs #-}
{-# LANGUAGE LambdaCase #-}
{-# LANGUAGE PolyKinds #-}
{-# LANGUAGE RankNTypes #-}
{-# LANGUAGE ScopedTypeVariables #-}
{-# LANGUAGE TypeApplications #-}
{-# LANGUAGE ViewPatterns #-}
module Database.Beam.AutoMigrate
(
defaultAnnotatedDbSettings,
fromAnnotatedDbSettings,
deAnnotateDatabase,
Migration,
migrate,
runMigrationUnsafe,
runMigrationWithEditUpdate,
tryRunMigrationsWithEditUpdate,
createMigration,
splitEditsOnSafety,
fastApproximateRowCountFor,
prettyEditActionDescription,
prettyEditSQL,
printMigration,
printMigrationIO,
unsafeRunMigration,
module Exports,
FromAnnotated,
ToAnnotated,
sqlSingleQuoted,
sqlEscaped,
editToSqlCommand,
)
where
import Control.Exception
import Control.Monad.Except
import Control.Monad.IO.Class (MonadIO, liftIO)
import Control.Monad.Identity (runIdentity)
import Control.Monad.State.Strict
import Data.Bifunctor (first)
import Data.Function ((&))
import Data.Int (Int64)
import Data.List (foldl', partition)
import qualified Data.Map.Strict as M
import Data.Maybe (fromMaybe)
import Data.Proxy
import qualified Data.Set as S
import Data.String.Conv (toS)
import Data.Text (Text)
import qualified Data.Text as T
import qualified Data.Text.Encoding as TE
import qualified Data.Text.Lazy as LT
import Database.Beam (MonadBeam)
import Database.Beam.AutoMigrate.Annotated as Exports
import Database.Beam.AutoMigrate.Compat as Exports
import Database.Beam.AutoMigrate.Diff as Exports
import Database.Beam.AutoMigrate.Generic as Exports
import Database.Beam.AutoMigrate.Postgres (getSchema)
import Database.Beam.AutoMigrate.Types as Exports
import Database.Beam.AutoMigrate.Util hiding (tableName)
import Database.Beam.AutoMigrate.Validity as Exports
import Database.Beam.Backend.SQL hiding (tableName)
import qualified Database.Beam.Backend.SQL.AST as AST
import qualified Database.Beam.Postgres as Pg
import qualified Database.Beam.Postgres.Syntax as Pg
import Database.Beam.Schema (Database, DatabaseSettings)
import Database.Beam.Schema.Tables (DatabaseEntity (..))
import qualified Database.PostgreSQL.Simple as Pg
import GHC.Generics hiding (prec)
import Lens.Micro (over, (^.), _1, _2)
import qualified Text.Pretty.Simple as PS
-- | Constraint bundle needed by 'defaultAnnotatedDbSettings' to walk a
-- @db (e1 be db)@ value and produce a @db (e2 be db)@ value with 'zipTables'.
--
-- Both database shapes must be 'Generic', @db@ must be a beam 'Database', and
-- 'GZipDatabase' is instantiated with the entity order @(e1, e2, e2)@.
type ToAnnotated (be :: *) (db :: DatabaseKind) e1 e2 =
  ( Generic (db (e1 be db))
  , Generic (db (e2 be db))
  , Database be db
  , GZipDatabase be
      (e1 be db) (e2 be db) (e2 be db)
      (Rep (db (e1 be db))) (Rep (db (e2 be db))) (Rep (db (e2 be db)))
  )
-- | Constraint bundle needed by 'deAnnotateDatabase' and
-- 'fromAnnotatedDbSettings' to walk a @db (e2 be db)@ value and produce a
-- @db (e1 be db)@ value with 'zipTables'.
--
-- Mirrors 'ToAnnotated', but 'GZipDatabase' is instantiated with the entity
-- order @(e2, e2, e1)@.
type FromAnnotated (be :: *) (db :: DatabaseKind) e1 e2 =
  ( Generic (db (e1 be db))
  , Generic (db (e2 be db))
  , Database be db
  , GZipDatabase be
      (e2 be db) (e2 be db) (e1 be db)
      (Rep (db (e2 be db))) (Rep (db (e2 be db))) (Rep (db (e1 be db)))
  )
-- | Turn plain beam 'DatabaseSettings' into 'AnnotatedDatabaseSettings'.
--
-- Every entity gets the annotation produced by 'dbAnnotatedEntityAuto' from
-- its descriptor, and the original 'DatabaseEntity' is kept inside the
-- resulting 'AnnotatedDatabaseEntity'.
defaultAnnotatedDbSettings ::
  forall be db.
  ToAnnotated be db DatabaseEntity AnnotatedDatabaseEntity =>
  DatabaseSettings be db ->
  AnnotatedDatabaseSettings be db
defaultAnnotatedDbSettings settings =
  runIdentity (zipTables (Proxy @be) withAutoAnnotation settings placeholder)
  where
    -- 'zipTables' walks two database values in lockstep, but
    -- 'withAutoAnnotation' never looks at its second argument, so this value
    -- is never forced.
    placeholder :: AnnotatedDatabaseSettings be db
    placeholder = undefined

    withAutoAnnotation ::
      ( Monad m,
        IsAnnotatedDatabaseEntity be ty,
        AnnotatedDatabaseEntityRegularRequirements be ty
      ) =>
      DatabaseEntity be db ty ->
      AnnotatedDatabaseEntity be db ty ->
      m (AnnotatedDatabaseEntity be db ty)
    withAutoAnnotation entity@(DatabaseEntity desc) _ignored =
      pure (AnnotatedDatabaseEntity (dbAnnotatedEntityAuto desc) entity)
-- | Strip the annotations from 'AnnotatedDatabaseSettings', recovering the
-- plain beam 'DatabaseSettings' stored inside each annotated entity (read
-- through the 'deannotate' lens).
deAnnotateDatabase ::
  forall be db.
  FromAnnotated be db DatabaseEntity AnnotatedDatabaseEntity =>
  AnnotatedDatabaseSettings be db ->
  DatabaseSettings be db
deAnnotateDatabase annotatedDb =
  runIdentity $
    zipTables
      (Proxy @be)
      (\annotatedEntity _ -> pure (annotatedEntity ^. deannotate))
      annotatedDb
      annotatedDb
-- | Derive the Haskell-side 'Schema' from annotated database settings.
--
-- The 'Proxy' carries a type-level list of extra annotations that is handed
-- to 'gSchema' together with the generic representation of the settings.
fromAnnotatedDbSettings ::
  ( FromAnnotated be db DatabaseEntity AnnotatedDatabaseEntity,
    GSchema be db anns (Rep (AnnotatedDatabaseSettings be db))
  ) =>
  AnnotatedDatabaseSettings be db ->
  Proxy (anns :: [Annotation]) ->
  Schema
fromAnnotatedDbSettings annotatedDb annsProxy =
  let genericDb = from annotatedDb
   in gSchema annotatedDb annsProxy genericDb
-- | Render every prioritised edit to Postgres syntax (via 'toSqlSyntax'),
-- discarding the priority and keeping the list order.
editsToPgSyntax :: [WithPriority Edit] -> [Pg.PgSyntax]
editsToPgSyntax prioritised =
  [toSqlSyntax (fst entry) | entry <- unPriority <$> prioritised]
-- | A migration computation over base monad @m@: the prioritised edits are
-- kept in 'StateT', and the computation can abort with a 'MigrationError'
-- through 'ExceptT'.
type Migration m =
  ExceptT MigrationError (StateT [WithPriority Edit] m) ()
-- | Reasons a migration can fail.
data MigrationError
  = -- | Diffing the Haskell schema against the database schema failed.
    DiffFailed DiffError
  | -- | The schema derived from the Haskell side did not validate.
    HaskellSchemaValidationFailed [ValidationFailed]
  | -- | The schema read from the database did not validate.
    DatabaseSchemaValidationFailed [ValidationFailed]
  | -- | The listed edit actions were refused as unsafe.
    UnsafeEditsDetected [EditAction]
  deriving (Show)

-- | Default 'Exception' methods, so a 'MigrationError' can be thrown with
-- 'throwIO' and caught like any other exception.
instance Exception MigrationError
-- | Split prioritised edits into @(unsafe, others)@.
--
-- The first component holds every edit for which @'editSafetyIs' 'Unsafe'@
-- holds, and the second holds all remaining edits.
--
-- Both components keep the relative order of the input, so a
-- priority-sorted input yields priority-sorted outputs. The previous
-- left-fold-and-cons implementation reversed both lists.
splitEditsOnSafety :: [WithPriority Edit] -> ([WithPriority Edit], [WithPriority Edit])
splitEditsOnSafety = partition (editSafetyIs Unsafe . fst . unPriority)
-- | Build a 'Migration' that brings the database behind @conn@ in line with
-- @hsSchema@.
--
-- Steps:
--
-- * Read the live schema with 'getSchema'.
-- * Validate the Haskell schema, aborting with 'HaskellSchemaValidationFailed'.
-- * Validate the database schema, aborting with 'DatabaseSchemaValidationFailed'.
-- * Store the edits computed by 'diff' as the migration state, or abort with
--   'DiffFailed'.
migrate :: MonadIO m => Pg.Connection -> Schema -> Migration m
migrate conn hsSchema = do
  dbSchema <- liftIO (getSchema conn)
  liftEither (first HaskellSchemaValidationFailed (validateSchema hsSchema))
  liftEither (first DatabaseSchemaValidationFailed (validateSchema dbSchema))
  edits <- liftEither (first DiffFailed (diff hsSchema dbSchema))
  lift (put edits)
-- | Evaluate a 'Migration' and execute all of its edits as a single DDL
-- command in the current beam session.
--
-- * Edits are ordered with 'sortEdits'.
-- * No safety checks are performed.
-- * A 'MigrationError' raised while evaluating the migration is rethrown
--   with 'throwIO'.
-- * When the migration yields no edits, nothing is sent to the database
--   rather than issuing an empty command.
--
-- NOTE(review): the SQL comes from 'toSqlSyntax', which prefixes every edit
-- that is not 'Safe' with an @<UNSAFE>@ marker. Confirm the server accepts
-- that text when such edits are executed here.
unsafeRunMigration :: (MonadBeam Pg.Postgres m, MonadIO m) => Migration m -> m ()
unsafeRunMigration m = do
  migs <- evalMigration m
  case migs of
    Left e -> liftIO $ throwIO e
    Right (sortEdits -> edits) ->
      unless (null edits) $
        runNoReturn $ Pg.PgCommandSyntax Pg.PgCommandTypeDdl (mconcat . editsToPgSyntax $ edits)
-- | Run a 'Migration' against @conn@ inside one database transaction,
-- applying every edit without safety checks (see 'unsafeRunMigration').
runMigrationUnsafe :: MonadBeam Pg.Postgres Pg.Pg => Pg.Connection -> Migration Pg.Pg -> IO ()
runMigrationUnsafe conn =
  Pg.withTransaction conn . Pg.runBeamPostgres conn . unsafeRunMigration
-- | Compute the migration from the database behind @conn@ to @hsSchema@, let
-- @editUpdate@ adjust the edits, then apply them inside one transaction.
--
-- * The edits are sorted with 'sortEdits' both before and after
--   @editUpdate@.
-- * If any resulting edit is 'Unsafe', nothing is executed and
--   'UnsafeEditsDetected' is thrown listing only the unsafe edit actions.
-- * An edit whose safety is conditional ('Left' in '_editCondition') has its
--   '_editCondition_check' run first. An 'Unsafe' verdict aborts the
--   transaction with 'UnsafeEditsDetected' for that edit.
-- * 'PotentiallySlow' edits are announced via 'printmsg' before they run.
runMigrationWithEditUpdate ::
  MonadBeam Pg.Postgres Pg.Pg =>
  ([WithPriority Edit] -> [WithPriority Edit]) ->
  Pg.Connection ->
  Schema ->
  IO ()
runMigrationWithEditUpdate editUpdate conn hsSchema = do
  edits <- either throwIO pure =<< evalMigration (migrate conn hsSchema)
  let newEdits = sortEdits $ editUpdate $ sortEdits edits
      -- Report only the offending edits, consistent with the per-edit
      -- UnsafeEditsDetected errors raised below.
      unsafeActions =
        [_editAction e | WithPriority (e, _) <- newEdits, editSafetyIs Unsafe e]
  unless (null unsafeActions) $
    throwIO $ UnsafeEditsDetected unsafeActions
  Pg.withTransaction conn $
    Pg.runBeamPostgres conn $
      forM_ newEdits $ \(WithPriority (edit, _)) ->
        case _editCondition edit of
          Right Unsafe -> liftIO $ throwIO $ UnsafeEditsDetected [_editAction edit]
          Right safeMaybeSlow -> safeOrSlow safeMaybeSlow edit
          Left ec -> do
            printmsg $ "edit has condition: " <> toS (prettyEditConditionQuery ec)
            checkedSafety <- _editCondition_check ec
            case checkedSafety of
              Unsafe -> do
                printmsg "edit unsafe by condition"
                liftIO $ throwIO $ UnsafeEditsDetected [_editAction edit]
              safeMaybeSlow -> do
                printmsg "edit condition satisfied"
                safeOrSlow safeMaybeSlow edit
  where
    -- Execute one edit, logging it first when it may be slow.
    safeOrSlow safety edit = do
      when (safety == PotentiallySlow) $ do
        printmsg "Running potentially slow edit"
        printmsg $ T.unpack $ prettyEditActionDescription $ _editAction edit
      runNoReturn $ editToSqlCommand edit

-- | Print a progress message to stdout, prefixed with @[beam-migrate] @.
printmsg :: MonadIO m => String -> m ()
printmsg = liftIO . putStrLn . mappend "[beam-migrate] "
-- | Ask Postgres for its approximate row count of a table, taken from
-- @pg_class.reltuples@ rather than by counting rows.
--
-- * The table name is compared as a string literal, escaped with
--   'sqlSingleQuoted'. The identifier quoting of 'sqlEscaped' would make
--   Postgres read the name as a column reference.
-- * @reltuples@ is a floating-point column, so it is cast to @bigint@ to be
--   decodable as 'Int64'.
-- * The result comes from beam's 'runReturningOne'.
--
-- NOTE(review): 'runReturningOne' yields 'Nothing' when zero or several rows
-- match, e.g. the same table name in several schemas. Confirm this is
-- acceptable for callers.
fastApproximateRowCountFor :: TableName -> Pg.Pg (Maybe Int64)
fastApproximateRowCountFor tblName = runReturningOne $ selectCmd $ Pg.PgSelectSyntax qry
  where
    qry =
      Pg.emit $
        toS $
          "SELECT reltuples::bigint AS approximate_row_count FROM pg_class WHERE relname = "
            <> sqlSingleQuoted (tableName tblName)
            <> ";"
-- | Direction of a column-constraint change. 'renderColumnConstraint' uses it
-- to decide whether a @DEFAULT@ constraint carries its value
-- ('SetConstraint') or is the bare keyword ('DropConstraint').
data AlterTableAction = SetConstraint | DropConstraint
  deriving (Eq, Show)
-- | Render one 'Edit' as Postgres SQL.
--
-- Table, column and type statements built here are terminated with @";\n"@.
-- Sequence and enum-type creation are delegated to 'createSequenceSyntax',
-- 'dropSequenceSyntax' and 'createTypeSyntax'.
--
-- Every statement gets a prefix:
--
-- * @" "@ when @'editSafetyIs' Safe@ holds for the edit;
-- * the marker @<UNSAFE>@ otherwise.
--
-- NOTE(review): this rendering is also what 'unsafeRunMigration' and
-- 'editToSqlCommand' (and thus 'runMigrationWithEditUpdate') send to the
-- server. Any edit that is not 'Safe' (presumably including potentially slow
-- ones) would then be executed with the @<UNSAFE>@ text in front. Confirm
-- this is intended.
toSqlSyntax :: Edit -> Pg.PgSyntax
toSqlSyntax e =
  withSafetyMarker $
    _editAction e & \case
      TableAdded tblName tbl ->
        terminated
          ( "CREATE TABLE " <> sqlEscaped (tableName tblName)
              <> " ("
              <> T.intercalate ", " (renderTableColumn <$> M.toList (tableColumns tbl))
              <> ")"
          )
      TableRemoved tblName ->
        terminated ("DROP TABLE " <> sqlEscaped (tableName tblName))
      TableConstraintAdded tblName cstr ->
        terminated (alterTable tblName <> renderAddConstraint cstr)
      TableConstraintRemoved tblName cstr ->
        terminated (alterTable tblName <> renderDropConstraint cstr)
      SequenceAdded sName (Sequence _tName _cName) ->
        createSequenceSyntax sName
      SequenceRemoved sName ->
        dropSequenceSyntax sName
      EnumTypeAdded tyName vals ->
        createTypeSyntax tyName vals
      EnumTypeRemoved (EnumerationName tyName) ->
        terminated ("DROP TYPE " <> tyName)
      EnumTypeValueAdded (EnumerationName tyName) newVal order insPoint ->
        terminated
          ( "ALTER TYPE " <> tyName
              <> " ADD VALUE "
              <> sqlSingleQuoted newVal
              <> " "
              <> renderInsertionOrder order
              <> " "
              <> sqlSingleQuoted insPoint
          )
      ColumnAdded tblName colName col ->
        -- The column is rendered exactly as it would be inside CREATE TABLE.
        terminated (alterTable tblName <> "ADD COLUMN " <> renderTableColumn (colName, col))
      ColumnRemoved tblName colName ->
        terminated (alterTable tblName <> "DROP COLUMN " <> sqlEscaped (columnName colName))
      ColumnTypeChanged tblName colName _old new ->
        terminated (alterColumn tblName colName <> " TYPE " <> renderDataType new)
      ColumnConstraintAdded tblName colName cstr ->
        terminated (alterColumn tblName colName <> " SET " <> renderColumnConstraint SetConstraint cstr)
      ColumnConstraintRemoved tblName colName cstr ->
        terminated (alterColumn tblName colName <> " DROP " <> renderColumnConstraint DropConstraint cstr)
  where
    withSafetyMarker body
      | editSafetyIs Safe e = Pg.emit " " <> body
      | otherwise = Pg.emit "<UNSAFE>" <> body

    -- Encode a statement as UTF-8 and terminate it.
    terminated sql = Pg.emit (TE.encodeUtf8 (sql <> ";\n"))

    -- Shared "ALTER TABLE <t> ALTER COLUMN <c>" prefix.
    alterColumn tblName colName =
      alterTable tblName <> "ALTER COLUMN " <> sqlEscaped (columnName colName)
-- | The @ALTER TABLE <escaped name> @ prefix (including a trailing space)
-- shared by table-altering statements.
alterTable :: TableName -> Text
alterTable tbl = "ALTER TABLE " <> sqlEscaped (tableName tbl) <> " "
-- | Render a column definition as
-- @<escaped name> <data type> <constraints>@.
--
-- The constraints are space-separated, each rendered with 'SetConstraint'.
renderTableColumn :: (ColumnName, Column) -> Text
renderTableColumn (colName, col) =
  T.intercalate " " [escapedName, renderDataType (columnType col), constraintList]
  where
    escapedName = sqlEscaped (columnName colName)
    constraintList =
      T.intercalate " " [renderColumnConstraint SetConstraint c | c <- S.toList (columnConstraints col)]
-- | SQL keyword for where a new enum value goes relative to an existing one.
renderInsertionOrder :: InsertionOrder -> Text
renderInsertionOrder order = case order of
  Before -> "BEFORE"
  After -> "AFTER"
-- | Render a table constraint in one of these forms:
--
-- * @CONSTRAINT <name> UNIQUE (cols)@
-- * @CONSTRAINT <name> PRIMARY KEY (cols)@
-- * @CONSTRAINT <name> FOREIGN KEY (cols) REFERENCES <table>(cols)@,
--   followed by the @ON DELETE@ / @ON UPDATE@ clauses from 'renderAction'.
--
-- Names are escaped with 'sqlEscaped'.
renderCreateTableConstraint :: TableConstraint -> Text
renderCreateTableConstraint constraint = case constraint of
  Unique fname cols ->
    named fname <> " UNIQUE (" <> columnList (S.toList cols) <> ")"
  PrimaryKey fname cols ->
    named fname <> " PRIMARY KEY (" <> columnList (S.toList cols) <> ")"
  ForeignKey fname referencedTable colPairs onDelete onUpdate ->
    -- Both column lists come from the same pair list, so they stay aligned.
    let pairs = S.toList colPairs
     in named fname
          <> " FOREIGN KEY ("
          <> columnList (map fst pairs)
          <> ") REFERENCES "
          <> sqlEscaped (tableName referencedTable)
          <> "("
          <> columnList (map snd pairs)
          <> ")"
          <> renderAction "ON DELETE" onDelete
          <> renderAction "ON UPDATE" onUpdate
  where
    named fname = "CONSTRAINT " <> sqlEscaped fname
    columnList names = T.intercalate ", " (map (sqlEscaped . columnName) names)
-- | The @ADD CONSTRAINT ...@ clause of an @ALTER TABLE@ statement.
renderAddConstraint :: TableConstraint -> Text
renderAddConstraint cstr = "ADD " <> renderCreateTableConstraint cstr
-- | The @DROP CONSTRAINT <escaped name>@ clause of an @ALTER TABLE@
-- statement; only the constraint's name is used.
renderDropConstraint :: TableConstraint -> Text
renderDropConstraint = ("DROP CONSTRAINT " <>) . sqlEscaped . nameOf
  where
    nameOf = \case
      Unique cName _ -> cName
      PrimaryKey cName _ -> cName
      ForeignKey cName _ _ _ _ -> cName
-- | Render a referential-action clause such as @ ON DELETE CASCADE @.
--
-- @actionPrefix@ is the clause keyword (e.g. @"ON DELETE"@). The result has
-- a leading and a trailing space. 'NoAction' renders as the empty string.
renderAction actionPrefix action = case action of
  NoAction -> mempty
  Cascade -> clause "CASCADE "
  Restrict -> clause "RESTRICT "
  SetNull -> clause "SET NULL "
  SetDefault -> clause "SET DEFAULT "
  where
    clause keyword = " " <> actionPrefix <> " " <> keyword
-- | Render a column constraint.
--
-- * @NOT NULL@ is the same whether it is being set or dropped.
-- * A 'Default' includes its value when being set ('SetConstraint').
-- * A 'Default' is the bare keyword @DEFAULT@ when being dropped.
renderColumnConstraint :: AlterTableAction -> ColumnConstraint -> Text
renderColumnConstraint act cstr = case (act, cstr) of
  (_, NotNull) -> "NOT NULL"
  (SetConstraint, Default defValue) -> "DEFAULT " <> defValue
  (DropConstraint, Default _) -> "DEFAULT"
-- | Render @CREATE TYPE <name> AS ENUM ('a','b',...);@ for an enumeration.
--
-- The type name is emitted verbatim. Each value is wrapped with
-- 'sqlSingleQuoted'.
createTypeSyntax :: EnumerationName -> Enumeration -> Pg.PgSyntax
createTypeSyntax (EnumerationName ty) (Enumeration vals) =
  Pg.emit (toS statement)
  where
    valueList = T.intercalate "," (sqlSingleQuoted <$> vals)
    statement = "CREATE TYPE " <> ty <> " AS ENUM (" <> valueList <> ");\n"
-- | Render @CREATE SEQUENCE <escaped name>;@.
createSequenceSyntax :: SequenceName -> Pg.PgSyntax
createSequenceSyntax (SequenceName s) =
  let stmt = "CREATE SEQUENCE " <> sqlEscaped s <> ";\n"
   in Pg.emit (toS stmt)
-- | Render @DROP SEQUENCE <escaped name>;@.
dropSequenceSyntax :: SequenceName -> Pg.PgSyntax
dropSequenceSyntax (SequenceName s) =
  let stmt = "DROP SEQUENCE " <> sqlEscaped s <> ";\n"
   in Pg.emit (toS stmt)
-- | Render a standard SQL data type as Postgres type syntax.
--
-- * @TIME@ / @TIMESTAMP@ become @TIME[(p)] WITH[OUT] TIME ZONE@. The
--   optional precision appears once, right after the type keyword.
-- * Array types render as @<element type>[<size>]@.
-- * Domain types render as the double-quoted domain name.
-- * Interval and row types have no rendering here and call 'error'.
renderStdType :: AST.DataType -> Text
renderStdType = \case
  (AST.DataTypeChar False prec charSet) ->
    -- A fixed-width CHAR without an explicit length is rendered as CHAR(1).
    "CHAR" <> sqlOptPrec (Just $ fromMaybe 1 prec) <> sqlOptCharSet charSet
  (AST.DataTypeChar True prec charSet) ->
    "VARCHAR" <> sqlOptPrec prec <> sqlOptCharSet charSet
  (AST.DataTypeNationalChar varying prec) ->
    let ty = if varying then "NATIONAL CHARACTER VARYING" else "NATIONAL CHAR"
     in ty <> sqlOptPrec prec
  (AST.DataTypeBit varying prec) ->
    let ty = if varying then "BIT VARYING" else "BIT"
     in ty <> sqlOptPrec prec
  (AST.DataTypeNumeric prec) -> "NUMERIC" <> sqlOptNumericPrec prec
  (AST.DataTypeDecimal prec) -> "NUMERIC" <> sqlOptNumericPrec prec
  AST.DataTypeInteger -> "INT"
  AST.DataTypeSmallInt -> "SMALLINT"
  AST.DataTypeBigInt -> "BIGINT"
  (AST.DataTypeFloat prec) -> "FLOAT" <> sqlOptPrec prec
  AST.DataTypeReal -> "REAL"
  AST.DataTypeDoublePrecision -> "DOUBLE PRECISION"
  AST.DataTypeDate -> "DATE"
  -- 'wTz' already places the precision right after the keyword, as in
  -- TIMESTAMP(3) WITH TIME ZONE. Appending it again after "TIME ZONE" would
  -- produce invalid SQL.
  (AST.DataTypeTime prec withTz) -> wTz withTz "TIME" prec
  (AST.DataTypeTimeStamp prec withTz) -> wTz withTz "TIMESTAMP" prec
  (AST.DataTypeInterval _i) ->
    error $
      "Impossible: DataTypeInterval doesn't map to any SQLXX beam typeclass, so we don't know"
        <> " how to render it."
  (AST.DataTypeIntervalFromTo _from _to) ->
    error $
      "Impossible: DataTypeIntervalFromTo doesn't map to any SQLXX beam typeclass, so we don't know"
        <> " how to render it."
  AST.DataTypeBoolean -> "BOOL"
  AST.DataTypeBinaryLargeObject -> "BYTEA"
  AST.DataTypeCharacterLargeObject -> "TEXT"
  (AST.DataTypeArray dt sz) ->
    renderStdType dt <> "[" <> T.pack (show sz) <> "]"
  (AST.DataTypeRow _rows) ->
    error "DataTypeRow not supported both for beam-postgres and this library."
  (AST.DataTypeDomain nm) -> "\"" <> nm <> "\""
  where
    -- Render "<keyword>[(p)] WITH TIME ZONE" or "... WITHOUT TIME ZONE".
    wTz withTz tt prec =
      tt <> sqlOptPrec prec <> (if withTz then " WITH" else " WITHOUT") <> " TIME ZONE"
-- | Render a column type as Postgres type syntax.
--
-- * Standard SQL types go through 'renderStdType'.
-- * 'DbEnumeration' columns are rendered as unbounded @VARCHAR@.
-- * Postgres-specific types use beam-postgres' own names.
-- * Postgres enumerations use the enum type's name verbatim.
renderDataType :: ColumnType -> Text
renderDataType columnTy = case columnTy of
  SqlStdType stdType -> renderStdType stdType
  DbEnumeration (EnumerationName _) _ ->
    renderStdType (AST.DataTypeChar True Nothing Nothing)
  PgSpecificType pgTy -> renderPgType pgTy
  where
    renderPgType = \case
      PgJson -> toS $ displaySyntax Pg.pgJsonType
      PgJsonB -> toS $ displaySyntax Pg.pgJsonbType
      PgRangeInt4 -> toS $ Pg.rangeName @Pg.PgInt4Range
      PgRangeInt8 -> toS $ Pg.rangeName @Pg.PgInt8Range
      PgRangeNum -> toS $ Pg.rangeName @Pg.PgNumRange
      PgRangeTs -> toS $ Pg.rangeName @Pg.PgTsRange
      PgRangeTsTz -> toS $ Pg.rangeName @Pg.PgTsTzRange
      PgRangeDate -> toS $ Pg.rangeName @Pg.PgDateRange
      PgUuid -> toS $ displaySyntax Pg.pgUuidType
      PgEnumeration (EnumerationName ty) -> ty
-- | Run a 'Migration' starting from an empty edit list.
--
-- Returns the error it raised, or else the final list of edits held in its
-- state.
evalMigration :: Monad m => Migration m -> m (Either MigrationError [WithPriority Edit])
evalMigration migration = do
  (outcome, accumulated) <- runStateT (runExceptT migration) mempty
  either (pure . Left) (\() -> pure (Right accumulated)) outcome
-- | Turn a precomputed 'Diff' into a 'Migration'.
--
-- A failed diff aborts with 'DiffFailed'. A successful one replaces the
-- migration state with its edits.
createMigration :: Monad m => Diff -> Migration m
createMigration = either (throwError . DiffFailed) (lift . put)
-- | Evaluate a 'Migration' and print the SQL of its edits to stdout.
--
-- The edits are sorted with 'sortEdits' and each rendered statement is
-- followed by a newline. A 'MigrationError' is rethrown with 'throwIO'.
printMigration :: MonadIO m => Migration m -> m ()
printMigration m = do
  (outcome, edits) <- runStateT (runExceptT m) mempty
  liftIO $ case outcome of
    Left err -> throwIO err
    Right () -> putStrLn (renderAll (sortEdits edits))
  where
    renderAll = unlines . map displaySyntax . editsToPgSyntax
-- | 'printMigration' run in the beam-postgres 'Pg.Pg' monad.
--
-- NOTE(review): the connection given to 'Pg.runBeamPostgres' is 'undefined'.
-- This only works as long as nothing run here uses that connection; any
-- beam query issued by the migration would hit 'undefined'. Confirm before
-- relying on it.
printMigrationIO :: Migration Pg.Pg -> IO ()
printMigrationIO = Pg.runBeamPostgres (undefined :: Pg.Connection) . printMigration
-- | Wrap the SQL rendered by 'toSqlSyntax' for an 'Edit' in a DDL command.
editToSqlCommand :: Edit -> Pg.PgCommandSyntax
editToSqlCommand edit = Pg.PgCommandSyntax Pg.PgCommandTypeDdl (toSqlSyntax edit)
-- | The SQL for a single 'Edit', as displayable 'Text'.
prettyEditSQL :: Edit -> Text
prettyEditSQL edit =
  T.pack (displaySyntax (Pg.fromPgCommand (editToSqlCommand edit)))
-- | A human-readable description of an 'EditAction', built from
-- space-separated pieces.
--
-- * Table, column and sequence names are wrapped in single quotes.
-- * Structured values (tables, columns, constraints, enums, sequences) are
--   pretty-printed with "Text.Pretty.Simple".
prettyEditActionDescription :: EditAction -> Text
prettyEditActionDescription action = T.unwords (wordsFor action)
  where
    wordsFor = \case
      TableAdded tblName table ->
        ["create table:", quotedTable tblName, "\n", pretty table]
      TableRemoved tblName ->
        ["remove table:", quotedTable tblName]
      TableConstraintAdded tblName tableConstraint ->
        ["add table constraint to:", quotedTable tblName, "\n", pretty tableConstraint]
      TableConstraintRemoved tblName tableConstraint ->
        ["remove table constraint from:", quotedTable tblName, "\n", pretty tableConstraint]
      ColumnAdded tblName colName column ->
        ["add column:", quotedColumn colName, ", from:", quotedTable tblName, "\n", pretty column]
      ColumnRemoved tblName colName ->
        ["remove column:", quotedColumn colName, ", from:", quotedTable tblName]
      ColumnTypeChanged tblName colName oldColumnType newColumnType ->
        [ "change type of column:",
          quotedColumn colName,
          "in table:",
          quotedTable tblName,
          "\nfrom:",
          renderDataType oldColumnType,
          "\nto:",
          renderDataType newColumnType
        ]
      ColumnConstraintAdded tblName colName columnConstraint ->
        onColumn "add column constraint to:" tblName colName columnConstraint
      ColumnConstraintRemoved tblName colName columnConstraint ->
        onColumn "remove column constraint from:" tblName colName columnConstraint
      EnumTypeAdded eName enumeration ->
        ["add enum type:", enumName eName, pretty enumeration]
      EnumTypeRemoved eName ->
        ["remove enum type:", enumName eName]
      EnumTypeValueAdded eName newValue insertionOrder insertedAt ->
        [ "add enum value to enum:",
          enumName eName,
          ", value:",
          newValue,
          ", with order:",
          pretty insertionOrder,
          ", at pos",
          insertedAt
        ]
      SequenceAdded sequenceName sequence0 ->
        ["add sequence:", quotedSequence sequenceName, pretty sequence0]
      SequenceRemoved sequenceName ->
        ["remove sequence:", quotedSequence sequenceName]

    -- Shared layout for adding/removing a column constraint.
    onColumn verb tblName colName columnConstraint =
      [verb, quotedColumn colName, "in table:", quotedTable tblName, "\n", pretty columnConstraint]

    quoted t = "'" <> t <> "'"
    quotedTable = quoted . tableName
    quotedColumn = quoted . columnName
    quotedSequence = quoted . seqName

    pretty :: Show a => a -> Text
    pretty = LT.toStrict . PS.pShow
-- | Compare the database behind @conn@ with the schema derived from
-- @annotatedDb@ and migrate it if they differ.
--
-- * If no migration is needed, only a message is printed.
-- * Otherwise the required SQL is printed and applied with
--   'runMigrationWithEditUpdate', leaving the edits unchanged.
-- * A failure to compute the diff is printed and the function returns
--   normally.
-- * A failure while migrating becomes an 'error' carrying the exception's
--   'displayException' text.
tryRunMigrationsWithEditUpdate
  :: ( Generic (db (DatabaseEntity be db))
     , Generic (db (AnnotatedDatabaseEntity be db))
     , Database be db
     , GZipDatabase be
         (AnnotatedDatabaseEntity be db)
         (AnnotatedDatabaseEntity be db)
         (DatabaseEntity be db)
         (Rep (db (AnnotatedDatabaseEntity be db)))
         (Rep (db (AnnotatedDatabaseEntity be db)))
         (Rep (db (DatabaseEntity be db)))
     , GSchema be db '[] (Rep (db (AnnotatedDatabaseEntity be db)))
     )
  => AnnotatedDatabaseSettings be db
  -> Pg.Connection
  -> IO ()
tryRunMigrationsWithEditUpdate annotatedDb conn = do
  let hsSchema = fromAnnotatedDbSettings annotatedDb (Proxy @'[])
  dbSchema <- getSchema conn
  either reportDiffError (applyEdits hsSchema) (diff hsSchema dbSchema)
  where
    reportDiffError err = do
      putStrLn "Error detecting database migration requirements: "
      print err

    applyEdits _ [] =
      putStrLn "No database migration required, continuing startup."
    applyEdits hsSchema edits = do
      putStrLn "Database migration required, attempting..."
      putStrLn . T.unpack . T.unlines $ map (prettyEditSQL . fst . unPriority) edits
      outcome <- try (runMigrationWithEditUpdate Prelude.id conn hsSchema)
      case outcome of
        Left (e :: SomeException) ->
          error $ "Database migration error: " <> displayException e
        Right _ ->
          pure ()