{-# LANGUAGE LambdaCase #-}
{-# LANGUAGE MultiWayIf #-}
{-# LANGUAGE RecordWildCards #-}
{-# LANGUAGE ScopedTypeVariables #-}
{-# LANGUAGE ViewPatterns #-}
{-# OPTIONS_GHC -fno-warn-orphans #-}
module Database.Beam.AutoMigrate.Postgres
( getSchema,
)
where
import Control.Monad.State
import Data.Bits (shiftR, (.&.))
import Data.ByteString (ByteString)
import Data.Foldable (asum, foldlM)
import Data.Map (Map)
import qualified Data.Map.Strict as M
import Data.Maybe (fromMaybe)
import Data.Set (Set)
import qualified Data.Set as S
import Data.String
import Data.Text (Text)
import qualified Data.Text as T
import qualified Data.Text.Encoding as TE
import qualified Data.Vector as V
import Database.Beam.AutoMigrate.Types
import Database.Beam.Backend.SQL hiding (tableName)
import qualified Database.Beam.Backend.SQL.AST as AST
import qualified Database.PostgreSQL.Simple as Pg
import Database.PostgreSQL.Simple.FromField (FromField (..), fromField, returnError)
import Database.PostgreSQL.Simple.FromRow (FromRow (..), field)
import qualified Database.PostgreSQL.Simple.TypeInfo.Static as Pg
import qualified Database.PostgreSQL.Simple.Types as Pg
-- | The @pg_constraint.contype@ codes selected by 'otherConstraintsQ':
-- primary key (@'p'@) and unique (@'u'@).
data SqlRawOtherConstraintType
  = SQL_raw_pk
  | SQL_raw_unique
  deriving (Eq, Show)
-- | One row of 'otherConstraintsQ': a primary-key or unique constraint,
-- the table it is declared on, and its columns.
data SqlOtherConstraint = SqlOtherConstraint
  { -- | Constraint name (@pg_constraint.conname@).
    sqlCon_name :: Text,
    -- | Primary key or unique.
    sqlCon_constraint_type :: SqlRawOtherConstraintType,
    -- | Table owning the constraint.
    sqlCon_table :: TableName,
    -- | Constrained columns, in key order. Despite the (misspelt) name this
    -- holds the constraint's own columns, not foreign-key columns; the name
    -- is kept for compatibility.
    sqlCon_fk_colums :: V.Vector ColumnName
  }
  deriving (Eq, Show)
-- | Decode a row of 'otherConstraintsQ'. Columns are read in query order:
-- name, type code, table name, column-name array.
instance Pg.FromRow SqlOtherConstraint where
  fromRow = do
    conName <- field
    conType <- field
    conTable <- TableName <$> field
    conColumns <- V.map ColumnName <$> field
    pure
      SqlOtherConstraint
        { sqlCon_name = conName,
          sqlCon_constraint_type = conType,
          sqlCon_table = conTable,
          sqlCon_fk_colums = conColumns
        }
-- | One row of 'foreignKeysQ': a foreign-key constraint between two tables.
data SqlForeignConstraint = SqlForeignConstraint
  { -- | The table declaring the foreign key (the referencing table).
    sqlFk_foreign_table :: TableName,
    -- | The referenced table.
    sqlFk_primary_table :: TableName,
    -- | Referencing columns; paired positionally with 'sqlFk_pk_columns'.
    sqlFk_fk_columns :: V.Vector ColumnName,
    -- | Referenced columns; paired positionally with 'sqlFk_fk_columns'.
    sqlFk_pk_columns :: V.Vector ColumnName,
    -- | Constraint name.
    sqlFk_name :: Text
  }
  deriving (Eq, Show)
-- | Decode a row of 'foreignKeysQ'. Columns are read in query order:
-- referencing table, referenced table, referencing columns, referenced
-- columns, constraint name.
instance Pg.FromRow SqlForeignConstraint where
  fromRow = do
    referencing <- TableName <$> field
    referenced <- TableName <$> field
    fkCols <- V.map ColumnName <$> field
    pkCols <- V.map ColumnName <$> field
    conName <- field
    pure
      SqlForeignConstraint
        { sqlFk_foreign_table = referencing,
          sqlFk_primary_table = referenced,
          sqlFk_fk_columns = fkCols,
          sqlFk_pk_columns = pkCols,
          sqlFk_name = conName
        }
-- | Decode a textual column directly into a 'TableName'.
instance FromField TableName where
  fromField fieldMeta = fmap TableName . fromField fieldMeta
-- | Decode a textual column directly into a 'ColumnName'.
instance FromField ColumnName where
  fromField fieldMeta = fmap ColumnName . fromField fieldMeta
-- | Decode a @contype@ code: @"p"@ is a primary key, @"u"@ a unique
-- constraint. Any other code is reported as a 'Pg.ConversionFailed' error
-- carrying the offending code.
instance FromField SqlRawOtherConstraintType where
  fromField f dat =
    fromField f dat >>= \(code :: String) ->
      if
        | code == "p" -> pure SQL_raw_pk
        | code == "u" -> pure SQL_raw_unique
        | otherwise -> returnError Pg.ConversionFailed f code
-- | OID and name of every ordinary table (@relkind = 'r'@) in the schemas of
-- @current_schemas(false)@.
--
-- Tables whose name begins with the literal prefix @beam_@ are skipped
-- (NOTE(review): presumably migration bookkeeping tables — confirm). The
-- underscore is escaped on purpose: in a @LIKE@ pattern a bare @_@ matches
-- any single character, so @'beam_%'@ would also hide ordinary user tables
-- such as @beams@ or @beamer@. An explicit @ESCAPE '!'@ is used so the
-- pattern does not depend on @standard_conforming_strings@.
userTablesQ :: Pg.Query
userTablesQ =
  fromString $
    unlines
      [ "SELECT cl.oid, relname FROM pg_catalog.pg_class \"cl\" join pg_catalog.pg_namespace \"ns\" ",
        "on (ns.oid = relnamespace) where nspname = any (current_schemas(false)) and relkind='r' ",
        "and relname NOT LIKE 'beam!_%' ESCAPE '!'"
      ]
-- | Column defaults as @(table, column, default expression, data type)@ rows,
-- every value cast to @text@. Only columns that actually have a default are
-- returned, and @information_schema@ and @pg_catalog@ are excluded.
--
-- NOTE(review): unlike 'userTablesQ' this is not limited to the search path,
-- and rows carry only the bare table name, so same-named tables in different
-- schemas share one entry.
defaultsQ :: Pg.Query
defaultsQ =
  fromString
    ( unlines
        [ "SELECT col.table_name::text, col.column_name::text, col.column_default::text, col.data_type::text ",
          "FROM information_schema.columns col ",
          "WHERE col.column_default IS NOT NULL ",
          "AND col.table_schema NOT IN('information_schema', 'pg_catalog') ",
          "ORDER BY col.table_name"
        ]
    )
-- | Columns of the single relation whose OID is bound to the @?@ placeholder.
-- Each row holds the column name, type OID, type modifier, NOT NULL flag and
-- the type rendered by @format_type@. System columns (@attnum <= 0@) and
-- dropped columns are excluded.
tableColumnsQ :: Pg.Query
tableColumnsQ =
  fromString
    ( unlines
        [ "SELECT attname, atttypid, atttypmod, attnotnull, pg_catalog.format_type(atttypid, atttypmod) ",
          "FROM pg_catalog.pg_attribute att ",
          "WHERE att.attrelid=? AND att.attnum>0 AND att.attisdropped='f' "
        ]
    )
-- | Every enumeration type as @(type name, type OID, labels)@. Labels are
-- listed in their declared sort order (@enumsortorder@).
--
-- NOTE(review): not restricted to any schema.
enumerationsQ :: Pg.Query
enumerationsQ =
  fromString
    ( unlines
        [ "SELECT t.typname, t.oid, array_agg(e.enumlabel ORDER BY e.enumsortorder)",
          "FROM pg_enum e JOIN pg_type t ON t.oid = e.enumtypid",
          "GROUP BY t.typname, t.oid"
        ]
    )
-- | Names of all sequences (@relkind = 'S'@) visible in @pg_class@.
sequencesQ :: Pg.Query
sequencesQ =
  fromString
    ( unwords
        [ "SELECT c.relname",
          "FROM pg_class c",
          "WHERE c.relkind = 'S'"
        ]
    )
-- | Foreign-key constraints, one row per constraint. Each row holds the
-- referencing table, the referenced table, the referencing columns, the
-- referenced columns and the constraint name.
--
-- The two column arrays are consumed positionally (zipped) by the caller, so:
--
-- * each referencing column is joined to the referenced column it actually
--   points at, via @position_in_unique_constraint@. Joining on
--   @ordinal_position@ mis-pairs columns whenever the foreign key lists the
--   referenced key's columns in a different order than the key declares them.
--
-- * both arrays are aggregated in the same explicit order (the foreign key's
--   own column order), instead of relying on unspecified aggregation order.
--
-- Only constraints declared on tables in @current_schemas(false)@ are
-- returned; these are the schemas 'userTablesQ' reads tables from. Results
-- are keyed by bare table name, so without this filter foreign keys of a
-- same-named table in another schema would be merged in.
foreignKeysQ :: Pg.Query
foreignKeysQ =
  fromString $
    unlines
      [ "SELECT kcu.table_name::text as foreign_table,",
        "  rel_kcu.table_name::text as primary_table,",
        "  array_agg(kcu.column_name::text ORDER BY kcu.ordinal_position)::text[] as fk_columns,",
        "  array_agg(rel_kcu.column_name::text ORDER BY kcu.ordinal_position)::text[] as pk_columns,",
        "  kcu.constraint_name as cname",
        "FROM information_schema.table_constraints tco",
        "JOIN information_schema.key_column_usage kcu",
        "  on tco.constraint_schema = kcu.constraint_schema",
        "  and tco.constraint_name = kcu.constraint_name",
        "JOIN information_schema.referential_constraints rco",
        "  on tco.constraint_schema = rco.constraint_schema",
        "  and tco.constraint_name = rco.constraint_name",
        "JOIN information_schema.key_column_usage rel_kcu",
        "  on rco.unique_constraint_schema = rel_kcu.constraint_schema",
        "  and rco.unique_constraint_name = rel_kcu.constraint_name",
        "  and kcu.position_in_unique_constraint = rel_kcu.ordinal_position",
        "WHERE kcu.table_schema::text = ANY (current_schemas(false)::text[])",
        "GROUP BY foreign_table, primary_table, cname"
      ]
-- | Primary-key (@'p'@) and unique (@'u'@) constraints, one row per
-- constraint. Each row holds the constraint name, type code, owning table
-- and constrained columns in key order (@ORDER BY u.attposition@).
--
-- Only tables in @current_schemas(false)@ are considered; these are the
-- schemas 'userTablesQ' reads tables from. Results are keyed by bare table
-- name, so constraints of a same-named table in another schema (or of system
-- catalogs such as @pg_catalog.pg_class@, which carry PK/UNIQUE constraints
-- on recent Postgres versions) must not leak in. The @OR@ is parenthesised
-- so the schema filter applies to both constraint types.
otherConstraintsQ :: Pg.Query
otherConstraintsQ =
  fromString $
    unlines
      [ "SELECT c.conname AS constraint_name,",
        "  c.contype AS constraint_type,",
        "  tbl.relname AS \"table\",",
        "  ARRAY_AGG(col.attname ORDER BY u.attposition) AS columns",
        "FROM pg_constraint c",
        "  JOIN LATERAL UNNEST(c.conkey) WITH ORDINALITY AS u(attnum, attposition) ON TRUE",
        "  JOIN pg_class tbl ON tbl.oid = c.conrelid",
        "  JOIN pg_namespace sch ON sch.oid = tbl.relnamespace",
        "  JOIN pg_attribute col ON (col.attrelid = tbl.oid AND col.attnum = u.attnum)",
        "WHERE (c.contype = 'u' OR c.contype = 'p')",
        "  AND sch.nspname = ANY (current_schemas(false))",
        "GROUP BY constraint_name, constraint_type, \"table\"",
        "ORDER BY c.contype"
      ]
-- | ON DELETE / ON UPDATE action codes of foreign keys, as
-- @(constraint name, confdeltype, confupdtype)@ rows ordered by name. The
-- single-letter codes are decoded by 'mkAction'.
--
-- The inner query unnests the key columns, so a multi-column foreign key
-- produces one (identical) row per column.
--
-- The owning table's schema is matched against @current_schemas(false)@, the
-- same set 'userTablesQ' reads tables from. Using only @current_schema()@
-- (the first search-path schema) would make foreign keys of tables in later
-- search-path schemas fall back to default actions, because no entry would
-- be found for them.
referenceActionsQ :: Pg.Query
referenceActionsQ =
  fromString $
    unlines
      [ "SELECT c.conname, c.confdeltype, c.confupdtype FROM ",
        "(SELECT r.conrelid, r.confrelid, unnest(r.conkey) AS conkey, unnest(r.confkey) AS confkey, r.conname, r.confupdtype, r.confdeltype ",
        "FROM pg_catalog.pg_constraint r WHERE r.contype = 'f') AS c ",
        "INNER JOIN pg_attribute a_parent ON a_parent.attnum = c.confkey AND a_parent.attrelid = c.confrelid ",
        "INNER JOIN pg_class cl_parent ON cl_parent.oid = c.confrelid ",
        "INNER JOIN pg_namespace sch_parent ON sch_parent.oid = cl_parent.relnamespace ",
        "INNER JOIN pg_attribute a_child ON a_child.attnum = c.conkey AND a_child.attrelid = c.conrelid ",
        "INNER JOIN pg_class cl_child ON cl_child.oid = c.conrelid ",
        "INNER JOIN pg_namespace sch_child ON sch_child.oid = cl_child.relnamespace ",
        "WHERE sch_child.nspname = ANY (current_schemas(false)) ORDER BY c.conname "
      ]
-- | Read the live schema of the database behind @conn@.
--
-- Catalogue queries run in this order: table constraints
-- ('getAllConstraints'), column defaults ('getAllDefaults'), enumeration
-- types, sequences, then user tables. Each table costs one additional
-- 'tableColumnsQ' query for its columns.
--
-- Each column's type is resolved in this order: serial column, then built-in
-- type ('pgTypeToColumnType'), then known enumeration. If none matches, the
-- action fails (via 'fail' in 'IO').
getSchema :: Pg.Connection -> IO Schema
getSchema conn = do
  constraintsByTable <- getAllConstraints conn
  defaultsByTable <- getAllDefaults conn
  enumsByOid <- Pg.fold_ conn enumerationsQ mempty collectEnumeration
  seqs <- Pg.fold_ conn sequencesQ mempty collectSequence
  tbls <-
    Pg.fold_ conn userTablesQ mempty (collectTable defaultsByTable enumsByOid constraintsByTable)
  -- Enumerations are gathered by type OID (needed to resolve column types),
  -- but the schema exposes them keyed by name.
  pure (Schema tbls (M.fromList (M.elems enumsByOid)) seqs)
  where
    -- Record one enumeration type: its name and its labels (in the order the
    -- query returns them).
    collectEnumeration ::
      Map Pg.Oid (EnumerationName, Enumeration) ->
      (Text, Pg.Oid, V.Vector Text) ->
      IO (Map Pg.Oid (EnumerationName, Enumeration))
    collectEnumeration acc (enumName, typeOid, labels) =
      pure (M.insert typeOid (EnumerationName enumName, Enumeration (V.toList labels)) acc)

    -- Keep only sequences named "<table>___<column>___seq"; others are ignored.
    collectSequence :: Sequences -> Pg.Only Text -> IO Sequences
    collectSequence acc (Pg.Only seqName) =
      pure $ case T.splitOn "___" seqName of
        [owner, column, "seq"] ->
          M.insert (SequenceName seqName) (Sequence (TableName owner) (ColumnName column)) acc
        _ -> acc

    -- Fetch the columns of one table and add the table, with its collected
    -- constraints (or 'noTableConstraints' if none), to the accumulator.
    collectTable ::
      AllDefaults ->
      Map Pg.Oid (EnumerationName, Enumeration) ->
      AllTableConstraints ->
      Tables ->
      (Pg.Oid, Text) ->
      IO Tables
    collectTable defaults enums constraints acc (relOid, relName) = do
      let tbl = TableName relName
      rows <- Pg.query conn tableColumnsQ (Pg.Only relOid)
      cols <- foldlM (addColumn tbl enums defaults) mempty rows
      pure (M.insert tbl (Table (M.findWithDefault noTableConstraints tbl constraints) cols) acc)

    -- Turn one row of 'tableColumnsQ' into a 'Column' and add it under its name.
    addColumn ::
      TableName ->
      Map Pg.Oid (EnumerationName, Enumeration) ->
      AllDefaults ->
      Columns ->
      (ByteString, Pg.Oid, Int, Bool, ByteString) ->
      IO Columns
    addColumn tbl enums defaults acc (rawName, typeOid, typmod, notNull, formatted) =
      case resolved of
        Nothing ->
          fail $
            "Couldn't convert pgType "
              <> show formatted
              <> " of field "
              <> show rawName
              <> " into a valid ColumnType."
        Just colType -> pure (M.insert colName (Column colType inferred) acc)
      where
        colName = ColumnName (TE.decodeUtf8 rawName)
        mbDefault = M.lookup tbl defaults >>= M.lookup colName
        -- atttypmod is -1 when the column has no modifier. For bit/varbit it
        -- is the length itself; for other types the stored value is offset
        -- by 4 (Postgres' VARHDRSZ).
        mbWidth
          | typmod == -1 = Nothing
          | typeOid == Pg.typoid Pg.bit || typeOid == Pg.typoid Pg.varbit = Just typmod
          | otherwise = Just (typmod - 4)
        resolved =
          asum
            [ pgSerialTyColumnType typeOid mbDefault,
              pgTypeToColumnType typeOid mbWidth,
              pgEnumTypeToColumnType enums typeOid
            ]
        inferred =
          (if notNull then S.singleton NotNull else S.empty)
            <> maybe S.empty S.singleton mbDefault

    -- Map a type OID to a Postgres enumeration column, if it is a known enum.
    pgEnumTypeToColumnType ::
      Map Pg.Oid (EnumerationName, Enumeration) ->
      Pg.Oid ->
      Maybe ColumnType
    pgEnumTypeToColumnType enums typeOid = do
      (enumName, _) <- M.lookup typeOid enums
      pure (PgSpecificType (PgEnumeration enumName))

    -- An int4 column whose default text contains both "nextval" and "seq" is
    -- treated as a serial column; it is still reported as 'intType'.
    pgSerialTyColumnType :: Pg.Oid -> Maybe ColumnConstraint -> Maybe ColumnType
    pgSerialTyColumnType typeOid = \case
      Just (Default expr)
        | typeOid == Pg.typoid Pg.int4,
          "nextval" `T.isInfixOf` expr,
          "seq" `T.isInfixOf` expr ->
          Just (SqlStdType intType)
      _ -> Nothing
-- | Translate a built-in Postgres type, identified by its 'Pg.Oid', into a
-- 'ColumnType'. Returns 'Nothing' for any type not listed here.
--
-- The second argument is the already-decoded type modifier ('Nothing' when
-- absent). It is:
--
-- * the length for @char@, @varchar@, @bit@ and @varbit@;
-- * for @numeric@, the precision in the upper 16 bits and the scale in the
--   lower 16 bits.
--
-- Every other type ignores it.
pgTypeToColumnType :: Pg.Oid -> Maybe Int -> Maybe ColumnType
pgTypeToColumnType oid width = lookup oid builtinTypes
  where
    -- OIDs are pairwise distinct, so a first-match lookup is unambiguous.
    builtinTypes :: [(Pg.Oid, ColumnType)]
    builtinTypes =
      [ (Pg.typoid Pg.int2, SqlStdType smallIntType),
        (Pg.typoid Pg.int4, SqlStdType intType),
        (Pg.typoid Pg.int8, SqlStdType bigIntType),
        (Pg.typoid Pg.bpchar, SqlStdType (charType (fromIntegral <$> width) Nothing)),
        (Pg.typoid Pg.varchar, SqlStdType (varCharType (fromIntegral <$> width) Nothing)),
        (Pg.typoid Pg.bit, SqlStdType (bitType (fromIntegral <$> width))),
        (Pg.typoid Pg.varbit, SqlStdType (varBitType (fromIntegral <$> width))),
        (Pg.typoid Pg.numeric, SqlStdType (numericType numericSpec)),
        (Pg.typoid Pg.float4, SqlStdType realType),
        (Pg.typoid Pg.float8, SqlStdType doubleType),
        (Pg.typoid Pg.date, SqlStdType dateType),
        (Pg.typoid Pg.text, SqlStdType characterLargeObjectType),
        (Pg.typoid Pg.bytea, SqlStdType AST.DataTypeBinaryLargeObject),
        (Pg.typoid Pg.bool, SqlStdType booleanType),
        (Pg.typoid Pg.time, SqlStdType (timeType Nothing False)),
        (Pg.typoid Pg.timestamp, SqlStdType (timestampType Nothing False)),
        (Pg.typoid Pg.timestamptz, SqlStdType (timestampType Nothing True)),
        (Pg.typoid Pg.json, PgSpecificType PgJson),
        (Pg.typoid Pg.jsonb, PgSpecificType PgJsonB),
        (Pg.typoid Pg.int4range, PgSpecificType PgRangeInt4),
        (Pg.typoid Pg.int8range, PgSpecificType PgRangeInt8),
        (Pg.typoid Pg.numrange, PgSpecificType PgRangeNum),
        (Pg.typoid Pg.tsrange, PgSpecificType PgRangeTs),
        (Pg.typoid Pg.tstzrange, PgSpecificType PgRangeTsTz),
        (Pg.typoid Pg.daterange, PgSpecificType PgRangeDate),
        (Pg.typoid Pg.uuid, PgSpecificType PgUuid)
      ]

    -- Decoded numeric modifier (0 when absent).
    rawModifier = fromMaybe 0 width
    numericScale = rawModifier .&. 0xFFFF
    numericPrec = (rawModifier `shiftR` 16) .&. 0xFFFF

    -- No modifier at all, precision only, or precision and scale.
    numericSpec
      | numericPrec == 0 && numericScale == 0 = Nothing
      | numericScale == 0 = Just (fromIntegral numericPrec, Nothing)
      | otherwise = Just (fromIntegral numericPrec, Just (fromIntegral numericScale))
-- | Default-value constraints of one table, keyed by column name.
type Defaults = Map ColumnName ColumnConstraint

-- | Default-value constraints of every table, keyed by table name.
type AllDefaults = Map TableName Defaults

-- | Table-level constraints (foreign keys, unique, primary keys), keyed by
-- table name.
type AllTableConstraints = Map TableName (Set TableConstraint)
-- | Load every column default returned by 'defaultsQ', grouped by table.
--
-- A default expression that already contains a @::@ cast is stored as-is.
-- Otherwise surrounding single quotes are stripped, the value is re-quoted,
-- and it is cast to the column's @data_type@. For example, @0@ on an
-- @integer@ column becomes @'0'::integer@. If a table contributes the same
-- column twice, the first entry seen is kept.
--
-- NOTE(review): a cast-less function default such as @now()@ is also
-- rewritten into a quoted literal by this rule — confirm this is intended.
getAllDefaults :: Pg.Connection -> IO AllDefaults
getAllDefaults conn = Pg.fold_ conn defaultsQ M.empty step
  where
    step :: AllDefaults -> (TableName, ColumnName, Text, Text) -> IO AllDefaults
    step acc (tName, colName, rawDefault, colType) =
      pure $
        M.insertWith
          (\new old -> old <> new)
          tName
          (M.singleton colName (Default (normaliseDefault rawDefault colType)))
          acc

    normaliseDefault :: Text -> Text -> Text
    normaliseDefault rawDefault colType
      | "::" `T.isInfixOf` rawDefault = rawDefault
      | otherwise = "'" <> T.dropAround (== '\'') rawDefault <> "'::" <> colType
-- | Collect foreign-key, unique and primary-key constraints, grouped by the
-- table that owns them.
--
-- Reference actions are loaded first. Foreign keys are then folded in, each
-- decorated with its ON DELETE / ON UPDATE actions (both 'NoAction' when no
-- actions were found for its name). Unique and primary-key constraints are
-- added last.
getAllConstraints :: Pg.Connection -> IO AllTableConstraints
getAllConstraints conn = do
  refActions <- mkActions <$> Pg.query_ conn referenceActionsQ
  withForeignKeys <-
    Pg.fold_ conn foreignKeysQ mempty $ \acc fk ->
      pure (addFkConstraint refActions acc fk)
  Pg.fold_ conn otherConstraintsQ withForeignKeys $ \acc con ->
    pure (addOtherConstraint acc con)
  where
    addFkConstraint ::
      ReferenceActions ->
      AllTableConstraints ->
      SqlForeignConstraint ->
      AllTableConstraints
    addFkConstraint refActions acc fk =
      execState (addTableConstraint (sqlFk_foreign_table fk) fkConstraint) acc
      where
        -- Referencing and referenced columns are paired by array position.
        columnPairs =
          S.fromList (zip (V.toList (sqlFk_fk_columns fk)) (V.toList (sqlFk_pk_columns fk)))
        (onDelete, onUpdate) =
          maybe
            (NoAction, NoAction)
            (\a -> (actionOnDelete a, actionOnUpdate a))
            (M.lookup (sqlFk_name fk) (getActions refActions))
        fkConstraint =
          ForeignKey (sqlFk_name fk) (sqlFk_primary_table fk) columnPairs onDelete onUpdate

    addOtherConstraint ::
      AllTableConstraints ->
      SqlOtherConstraint ->
      AllTableConstraints
    addOtherConstraint acc con =
      execState (addTableConstraint (sqlCon_table con) (build (sqlCon_name con) columns)) acc
      where
        columns = S.fromList (V.toList (sqlCon_fk_colums con))
        build = case sqlCon_constraint_type con of
          SQL_raw_unique -> Unique
          SQL_raw_pk -> PrimaryKey
-- | ON DELETE / ON UPDATE actions of foreign keys, keyed by constraint name.
newtype ReferenceActions = ReferenceActions
  { getActions :: Map Text Actions
  }
-- | One decoded row of 'referenceActionsQ': constraint name, ON DELETE action
-- and ON UPDATE action, in that order.
newtype RefEntry = RefEntry
  { unRefEntry :: (Text, ReferenceAction, ReferenceAction)
  }
-- | Index decoded 'RefEntry' rows by constraint name.
--
-- NOTE(review): entries are keyed by name only, and 'M.fromList' keeps the
-- last row for a repeated name. That is harmless for the repeated rows of a
-- multi-column key, but two distinct foreign keys sharing a name on
-- different tables would collapse into one entry.
mkActions :: [RefEntry] -> ReferenceActions
mkActions entries =
  ReferenceActions $
    M.fromList
      [ (conName, Actions onDelete onUpdate)
        | RefEntry (conName, onDelete, onUpdate) <- entries
      ]
-- | Decode a row of 'referenceActionsQ'. Both action codes are passed
-- through 'mkAction'; the conversion is lazy, so an unknown code only fails
-- when the action is inspected.
instance Pg.FromRow RefEntry where
  fromRow = do
    conName <- field
    onDelete <- mkAction <$> field
    onUpdate <- mkAction <$> field
    pure (RefEntry (conName, onDelete, onUpdate))
-- | The referential actions attached to one foreign-key constraint.
data Actions = Actions
  { -- | Action for @ON DELETE@.
    actionOnDelete :: ReferenceAction,
    -- | Action for @ON UPDATE@.
    actionOnUpdate :: ReferenceAction
  }
-- | Decode a @confdeltype@ / @confupdtype@ code into a 'ReferenceAction'.
-- Any code other than @a@, @r@, @c@, @n@ or @d@ aborts with 'error'.
mkAction :: Text -> ReferenceAction
mkAction = \case
  "a" -> NoAction
  "r" -> Restrict
  "c" -> Cascade
  "n" -> SetNull
  "d" -> SetDefault
  unknown -> error . T.unpack $ "unknown reference action type: " <> unknown
-- | Add a constraint to the set recorded for the given table, creating the
-- table's entry on first use. The map is updated with 'modify''.
addTableConstraint ::
  TableName ->
  TableConstraint ->
  State AllTableConstraints ()
addTableConstraint tName cns =
  modify' (M.insertWith S.union tName (S.singleton cns))