{-# LANGUAGE FlexibleContexts #-}
{-# LANGUAGE UndecidableInstances #-}

module Database.Beam.AutoMigrate.Util where

import Control.Applicative.Lift
import Control.Monad.Except
import Data.Char
import Data.Functor.Constant
import Data.Set (Set)
import qualified Data.Set as Set
import Data.String (fromString)
import Data.Text (Text)
import qualified Data.Text as T
import Database.Beam.AutoMigrate.Types (ColumnName(..), TableName(..))
import Database.Beam.Schema (Beamable, PrimaryKey, TableEntity, TableSettings)
import qualified Database.Beam.Schema as Beam
import Database.Beam.Schema.Tables
import Lens.Micro ((^.))

--
-- Retrieving all the column names for a beam entity.
--

-- | Collects the 'ColumnName's reachable through a selector on a table.
--
-- 'colNames' receives the table's field settings together with a function
-- projecting an @entity@ out of them, and returns the names of the column(s)
-- that make up the projected entity.
class HasColumnNames entity tbl where
  colNames ::
    tbl (Beam.TableField tbl) ->
    (tbl (Beam.TableField tbl) -> entity) ->
    [ColumnName]

-- | The selector yields a 'PrimaryKey': one 'ColumnName' per field of the key.
instance
  Beam.Beamable (PrimaryKey tbl) =>
  HasColumnNames (PrimaryKey tbl (Beam.TableField c)) tbl
  where
  colNames field fn =
    [ ColumnName name
    | name <- allBeamValues (\(Columnar' col) -> col ^. fieldName) (fn field)
    ]

-- | Same as above, but the 'PrimaryKey' belongs to a table @tbl@ while the
-- selector is applied to a table @tbl'@.
--
-- NOTE(review): this instance head is more general than the previous one
-- (@tbl'@ may be instantiated to @tbl@), so the two heads overlap. Confirm how
-- instance resolution is meant to behave at use sites.
instance
  Beam.Beamable (PrimaryKey tbl) =>
  HasColumnNames (PrimaryKey tbl (Beam.TableField c)) tbl'
  where
  colNames field fn =
    [ ColumnName name
    | name <- allBeamValues (\(Columnar' col) -> col ^. fieldName) (fn field)
    ]

-- | The selector yields a single 'Beam.TableField': exactly one 'ColumnName'.
instance HasColumnNames (Beam.TableField tbl ty) tbl where
  colNames field fn = [ColumnName selectedName]
    where
      selectedName = fn field ^. Beam.fieldName

--
-- General utility functions
--

-- | Extracts the 'TableSettings' out of the input 'DatabaseEntity'.
tableSettings :: Beam.DatabaseEntity be db (TableEntity tbl) -> TableSettings tbl
tableSettings entity =
  let descriptor = entity ^. dbEntityDescriptor
   in dbTableSettings descriptor

-- | Extracts the name of the table behind the input 'DatabaseEntity',
-- wrapped as a 'TableName'.
tableName :: Beam.Beamable tbl => Beam.DatabaseEntity be db (TableEntity tbl) -> TableName
tableName entity = TableName (entity ^. (dbEntityDescriptor . dbEntityName))

-- | Extracts the primary key of a table as a list of 'ColumnName'.
pkFieldNames ::
  (Beamable (PrimaryKey tbl), Beam.Table tbl) =>
  Beam.DatabaseEntity be db (TableEntity tbl) ->
  [ColumnName]
pkFieldNames = fieldAsColumnNames . primaryKey . tableSettings

-- | Similar to 'pkFieldNames', but it works on any entity that derives 'Beamable'.
-- Every field of the input is visited and its 'fieldName' is wrapped in a
-- 'ColumnName'.
fieldAsColumnNames :: Beamable tbl => tbl (Beam.TableField c) -> [ColumnName]
fieldAsColumnNames =
  map ColumnName . allBeamValues (\(Columnar' col) -> col ^. fieldName)

-- | Returns /all/ the 'ColumnName's for a given 'DatabaseEntity'.
allColumnNames :: Beamable tbl => Beam.DatabaseEntity be db (TableEntity tbl) -> [ColumnName]
allColumnNames = fieldAsColumnNames . tableSettings

--
-- Reporting multiple errors at once
--

-- | Lifts an 'Either' into the error-accumulating 'Errors' applicative:
-- a 'Left' becomes a recorded failure, a 'Right' becomes a pure value.
--
-- See https://teh.id.au/posts/2017/03/13/accumulating-errors/index.html
hoistErrors :: Either e a -> Errors e a
hoistErrors = either (Other . Constant) Pure

-- | Like 'sequence', but accumulating all errors in case of at least one 'Left'.
sequenceEither :: (Monoid e, Traversable f) => f (Either e a) -> Either e (f a)
sequenceEither results = runErrors (traverse hoistErrors results)

-- | Evaluate each action in sequence, accumulating all errors in case of a failure.
-- Note that this means each action will be run independently, regardless of failure.
sequenceExceptT ::
  (Monad m, Monoid w, Traversable t) =>
  t (ExceptT w m a) ->
  ExceptT w m (t a)
sequenceExceptT es =
  -- Run every action in the base monad first, then merge the outcomes.
  ExceptT (sequenceEither <$> traverse runExceptT es)

-- NOTE(adn) Unfortunately these combinators are not re-exported by beam.
-- | Renders an optional precision as a parenthesised suffix, e.g. @(10)@.
-- 'Nothing' renders as the empty text.
sqlOptPrec :: Maybe Word -> Text
sqlOptPrec Nothing = mempty
sqlOptPrec (Just x) = "(" <> fromString (show x) <> ")"

-- | Renders an optional @CHARACTER SET@ clause, including its leading space.
-- 'Nothing' renders as the empty text.
sqlOptCharSet :: Maybe Text -> Text
sqlOptCharSet Nothing = mempty
sqlOptCharSet (Just cs) = " CHARACTER SET " <> cs

-- | Escape a sql identifier according to the rules defined in the postgres manual.
--
-- Identifiers accepted by 'sqlValidUnescaped' are returned untouched; anything
-- else is wrapped in double quotes.
sqlEscaped :: Text -> Text
sqlEscaped t
  | sqlValidUnescaped t = t
  -- Double-quotes inside identifier names must be escaped with an additional double-quote
  | otherwise = "\"" <> T.replace "\"" "\"\"" t <> "\""

-- | Check whether an identifier is valid without escaping (True) or must be escaped (False)
-- according to the postgres manual (lexical structure, \"Identifiers and Key Words\").
--
-- An identifier may be left unquoted when:
--
-- * its first character is an underscore or a lower-case letter (a leading
--   digit is not a valid identifier start, and an upper-case letter would be
--   folded to lower case by the server, changing the name);
-- * every following character is a lower-case letter, an ASCII digit, @$@ or @_@;
-- * it is not a reserved keyword.
--
-- The empty identifier is reported as valid, so 'sqlEscaped' returns it unchanged.
sqlValidUnescaped :: Text -> Bool
sqlValidUnescaped t = case T.uncons t of
  Nothing -> True
  Just (c, rest) ->
    validUnescapedHead c
      && T.all validUnescapedTail rest
      && not (sqlIsReservedKeyword t)
  where
    validUnescapedHead c = c == '_' || (isAlpha c && isLower c)
    validUnescapedTail r =
      (isAlpha r && isLower r) || r `elem` ("1234567890$_" :: String)

-- | Whether the identifier, compared case-insensitively, is one of
-- 'postgresKeywordsReserved'.
sqlIsReservedKeyword :: Text -> Bool
sqlIsReservedKeyword t = T.toCaseFold t `Set.member` postgresKeywordsReserved

-- | Reserved keywords according to
-- https://www.postgresql.org/docs/current/sql-keywords-appendix.html
--
-- Besides the fully reserved words, the set also contains the words listed as
-- \"reserved (can be function or type)\": those cannot be used as unquoted
-- table or column names either. Quoting a lower-case identifier never changes
-- the name it denotes, so erring on the side of quoting is harmless.
postgresKeywordsReserved :: Set Text
postgresKeywordsReserved =
  Set.fromList $
    map
      T.toCaseFold
      [ -- Fully reserved.
        "ALL"
      , "ANALYSE"
      , "ANALYZE"
      , "AND"
      , "ANY"
      , "ARRAY"
      , "AS"
      , "ASC"
      , "ASYMMETRIC"
      , "BOTH"
      , "CASE"
      , "CAST"
      , "CHECK"
      , "COLLATE"
      , "COLUMN"
      , "CONSTRAINT"
      , "CREATE"
      , "CURRENT_CATALOG"
      , "CURRENT_DATE"
      , "CURRENT_ROLE"
      , "CURRENT_TIME"
      , "CURRENT_TIMESTAMP"
      , "CURRENT_USER"
      , "DEFAULT"
      , "DEFERRABLE"
      , "DESC"
      , "DISTINCT"
      , "DO"
      , "ELSE"
      , "END"
      , "EXCEPT"
      , "FALSE"
      , "FETCH"
      , "FOR"
      , "FOREIGN"
      , "FROM"
      , "GRANT"
      , "GROUP"
      , "HAVING"
      , "IN"
      , "INITIALLY"
      , "INTERSECT"
      , "INTO"
      , "LATERAL"
      , "LEADING"
      , "LIMIT"
      , "LOCALTIME"
      , "LOCALTIMESTAMP"
      , "NOT"
      , "NULL"
      , "OFFSET"
      , "ON"
      , "ONLY"
      , "OR"
      , "ORDER"
      , "PLACING"
      , "PRIMARY"
      , "REFERENCES"
      , "RETURNING"
      , "SELECT"
      , "SESSION_USER"
      , "SOME"
      , "SYMMETRIC"
      , "TABLE"
      , "THEN"
      , "TO"
      , "TRAILING"
      , "TRUE"
      , "UNION"
      , "UNIQUE"
      , "USER"
      , "USING"
      , "VARIADIC"
      , "WHEN"
      , "WHERE"
      , "WINDOW"
      , "WITH"
      , -- Reserved (can be function or type).
        "AUTHORIZATION"
      , "BINARY"
      , "COLLATION"
      , "CONCURRENTLY"
      , "CROSS"
      , "CURRENT_SCHEMA"
      , "FREEZE"
      , "FULL"
      , "ILIKE"
      , "INNER"
      , "IS"
      , "ISNULL"
      , "JOIN"
      , "LEFT"
      , "LIKE"
      , "NATURAL"
      , "NOTNULL"
      , "OUTER"
      , "OVERLAPS"
      , "RIGHT"
      , "SIMILAR"
      , "TABLESAMPLE"
      , "VERBOSE"
      ]

-- | Wraps the input in single quotes, producing a SQL string literal.
--
-- Single quotes inside the input are doubled, mirroring what 'sqlEscaped'
-- does for double quotes, so the result stays one well-formed literal.
-- NOTE(review): callers must pass the raw, not already escaped, text.
sqlSingleQuoted :: Text -> Text
sqlSingleQuoted t = "'" <> T.replace "'" "''" t <> "'"

-- | Renders an optional numeric precision, with an optional scale, as a
-- parenthesised suffix: @(p)@ or @(p, s)@. 'Nothing' renders as the empty text.
sqlOptNumericPrec :: Maybe (Word, Maybe Word) -> Text
sqlOptNumericPrec Nothing = mempty
sqlOptNumericPrec (Just (prec, Nothing)) = sqlOptPrec (Just prec)
sqlOptNumericPrec (Just (prec, Just dec)) =
  "(" <> fromString (show prec) <> ", " <> fromString (show dec) <> ")"