{-| Module: Squeal.PostgreSQL.Expression.Aggregate Description: aggregate functions and arguments Copyright: (c) Eitan Chatav, 2019 Maintainer: eitan@morphism.tech Stability: experimental aggregate functions and arguments -} {-# LANGUAGE DataKinds , DeriveGeneric , FlexibleContexts , FlexibleInstances , FunctionalDependencies , LambdaCase , MultiParamTypeClasses , OverloadedStrings , PatternSynonyms , PolyKinds , ScopedTypeVariables , StandaloneDeriving , TypeApplications , TypeFamilies , TypeOperators , UndecidableInstances #-} module Squeal.PostgreSQL.Expression.Aggregate ( -- * Aggregate Aggregate (..) -- * Aggregate Arguments , AggregateArg (..) , pattern All , pattern Alls , allNotNull , pattern Distinct , pattern Distincts , distinctNotNull , FilterWhere (..) -- * Aggregate Types , PGSum , PGAvg ) where import Data.ByteString (ByteString) import GHC.TypeLits import qualified Generics.SOP as SOP import Squeal.PostgreSQL.Type.Alias import Squeal.PostgreSQL.Expression import Squeal.PostgreSQL.Expression.Logic import Squeal.PostgreSQL.Expression.Null import Squeal.PostgreSQL.Expression.Sort import Squeal.PostgreSQL.Type.List import Squeal.PostgreSQL.Render import Squeal.PostgreSQL.Type.Schema -- $setup -- >>> import Squeal.PostgreSQL {- | `Aggregate` functions compute a single result from a set of input values. `Aggregate` functions can be used as `Grouped` `Expression`s as well as `Squeal.PostgreSQL.Expression.Window.WindowFunction`s. -} class Aggregate arg expr | expr -> arg where -- | A special aggregation that does not require an input -- -- >>> :{ -- let -- expression :: Expression ('Grouped bys) '[] with db params from ('NotNull 'PGint8) -- expression = countStar -- in printSQL expression -- :} -- count(*) countStar :: expr lat with db params from ('NotNull 'PGint8) -- | >>> :{ -- let -- expression :: Expression ('Grouped bys) '[] with db params '[tab ::: '["col" ::: null ty]] ('NotNull 'PGint8) -- expression = count (All #col) -- in printSQL expression -- :} -- count(ALL "col") count :: arg '[ty] lat with db params from -- ^ argument -> expr lat with db params from ('NotNull 'PGint8) -- | >>> :{ -- let -- expression :: Expression ('Grouped bys) '[] with db params '[tab ::: '["col" ::: 'Null 'PGnumeric]] ('Null 'PGnumeric) -- expression = sum_ (Distinct #col & filterWhere (#col .< 100)) -- in printSQL expression -- :} -- sum(DISTINCT "col") FILTER (WHERE ("col" < (100.0 :: numeric))) sum_ :: arg '[null ty] lat with db params from -- ^ argument -> expr lat with db params from ('Null (PGSum ty)) -- | input values, including nulls, concatenated into an array -- -- >>> :{ -- let -- expression :: Expression ('Grouped bys) '[] with db params '[tab ::: '["col" ::: 'Null 'PGnumeric]] ('Null ('PGvararray ('Null 'PGnumeric))) -- expression = arrayAgg (All #col & orderBy [AscNullsFirst #col] & filterWhere (#col .< 100)) -- in printSQL expression -- :} -- array_agg(ALL "col" ORDER BY "col" ASC NULLS FIRST) FILTER (WHERE ("col" < (100.0 :: numeric))) arrayAgg :: arg '[ty] lat with db params from -- ^ argument -> expr lat with db params from ('Null ('PGvararray ty)) -- | aggregates values as a JSON array jsonAgg :: arg '[ty] lat with db params from -- ^ argument -> expr lat with db params from ('Null 'PGjson) -- | aggregates values as a JSON array jsonbAgg :: arg '[ty] lat with db params from -- ^ argument -> expr lat with db params from ('Null 'PGjsonb) {- | the bitwise AND of all non-null input values, or null if none >>> :{ let expression :: Expression ('Grouped bys) '[] with db params '[tab ::: '["col" ::: null 'PGint4]] ('Null 'PGint4) expression = bitAnd (Distinct #col) in printSQL expression :} bit_and(DISTINCT "col") -} bitAnd :: int `In` PGIntegral => arg '[null int] lat with db params from -- ^ argument -> expr lat with db params from ('Null int) {- | the bitwise OR of all non-null input values, or null if none >>> :{ let expression :: Expression ('Grouped bys) '[] with db params '[tab ::: '["col" ::: null 'PGint4]] ('Null 'PGint4) expression = bitOr (All #col) in printSQL expression :} bit_or(ALL "col") -} bitOr :: int `In` PGIntegral => arg '[null int] lat with db params from -- ^ argument -> expr lat with db params from ('Null int) {- | true if all input values are true, otherwise false >>> :{ let winFun :: WindowFunction 'Ungrouped '[] with db params '[tab ::: '["col" ::: null 'PGbool]] ('Null 'PGbool) winFun = boolAnd (Window #col) in printSQL winFun :} bool_and("col") -} boolAnd :: arg '[null 'PGbool] lat with db params from -- ^ argument -> expr lat with db params from ('Null 'PGbool) {- | true if at least one input value is true, otherwise false >>> :{ let expression :: Expression ('Grouped bys) '[] with db params '[tab ::: '["col" ::: null 'PGbool]] ('Null 'PGbool) expression = boolOr (All #col) in printSQL expression :} bool_or(ALL "col") -} boolOr :: arg '[null 'PGbool] lat with db params from -- ^ argument -> expr lat with db params from ('Null 'PGbool) {- | equivalent to `boolAnd` >>> :{ let expression :: Expression ('Grouped bys) '[] with db params '[tab ::: '["col" ::: null 'PGbool]] ('Null 'PGbool) expression = every (Distinct #col) in printSQL expression :} every(DISTINCT "col") -} every :: arg '[null 'PGbool] lat with db params from -- ^ argument -> expr lat with db params from ('Null 'PGbool) {- |maximum value of expression across all input values-} max_ :: arg '[null ty] lat with db params from -- ^ argument -> expr lat with db params from ('Null ty) -- | minimum value of expression across all input values min_ :: arg '[null ty] lat with db params from -- ^ argument -> expr lat with db params from ('Null ty) -- | the average (arithmetic mean) of all input values avg :: arg '[null ty] lat with db params from -- ^ argument -> expr lat with db params from ('Null (PGAvg ty)) {- | correlation coefficient >>> :{ let expression :: Expression ('Grouped g) '[] c s p '[t ::: '["x" ::: 'NotNull 'PGfloat8, "y" ::: 'NotNull 'PGfloat8]] ('Null 'PGfloat8) expression = corr (Alls (#y *: #x)) in printSQL expression :} corr(ALL "y", "x") -} corr :: arg '[null 'PGfloat8, null 'PGfloat8] lat with db params from -- ^ arguments -> expr lat with db params from ('Null 'PGfloat8) {- | population covariance >>> :{ let expression :: Expression ('Grouped g) '[] c s p '[t ::: '["x" ::: 'NotNull 'PGfloat8, "y" ::: 'NotNull 'PGfloat8]] ('Null 'PGfloat8) expression = covarPop (Alls (#y *: #x)) in printSQL expression :} covar_pop(ALL "y", "x") -} covarPop :: arg '[null 'PGfloat8, null 'PGfloat8] lat with db params from -- ^ arguments -> expr lat with db params from ('Null 'PGfloat8) {- | sample covariance >>> :{ let winFun :: WindowFunction 'Ungrouped '[] c s p '[t ::: '["x" ::: 'NotNull 'PGfloat8, "y" ::: 'NotNull 'PGfloat8]] ('Null 'PGfloat8) winFun = covarSamp (Windows (#y *: #x)) in printSQL winFun :} covar_samp("y", "x") -} covarSamp :: arg '[null 'PGfloat8, null 'PGfloat8] lat with db params from -- ^ arguments -> expr lat with db params from ('Null 'PGfloat8) {- | average of the independent variable (sum(X)/N) >>> :{ let expression :: Expression ('Grouped g) '[] c s p '[t ::: '["x" ::: 'NotNull 'PGfloat8, "y" ::: 'NotNull 'PGfloat8]] ('Null 'PGfloat8) expression = regrAvgX (Alls (#y *: #x)) in printSQL expression :} regr_avgx(ALL "y", "x") -} regrAvgX :: arg '[null 'PGfloat8, null 'PGfloat8] lat with db params from -- ^ arguments -> expr lat with db params from ('Null 'PGfloat8) {- | average of the dependent variable (sum(Y)/N) >>> :{ let winFun :: WindowFunction 'Ungrouped '[] c s p '[t ::: '["x" ::: 'NotNull 'PGfloat8, "y" ::: 'NotNull 'PGfloat8]] ('Null 'PGfloat8) winFun = regrAvgY (Windows (#y *: #x)) in printSQL winFun :} regr_avgy("y", "x") -} regrAvgY :: arg '[null 'PGfloat8, null 'PGfloat8] lat with db params from -- ^ arguments -> expr lat with db params from ('Null 'PGfloat8) {- | number of input rows in which both expressions are nonnull >>> :{ let winFun :: WindowFunction 'Ungrouped '[] c s p '[t ::: '["x" ::: 'NotNull 'PGfloat8, "y" ::: 'NotNull 'PGfloat8]] ('Null 'PGint8) winFun = regrCount (Windows (#y *: #x)) in printSQL winFun :} regr_count("y", "x") -} regrCount :: arg '[null 'PGfloat8, null 'PGfloat8] lat with db params from -- ^ arguments -> expr lat with db params from ('Null 'PGint8) {- | y-intercept of the least-squares-fit linear equation determined by the (X, Y) pairs >>> :{ let expression :: Expression ('Grouped g) '[] c s p '[t ::: '["x" ::: 'NotNull 'PGfloat8, "y" ::: 'NotNull 'PGfloat8]] ('Null 'PGfloat8) expression = regrIntercept (Alls (#y *: #x)) in printSQL expression :} regr_intercept(ALL "y", "x") -} regrIntercept :: arg '[null 'PGfloat8, null 'PGfloat8] lat with db params from -- ^ arguments -> expr lat with db params from ('Null 'PGfloat8) -- | @regr_r2(Y, X)@, square of the correlation coefficient regrR2 :: arg '[null 'PGfloat8, null 'PGfloat8] lat with db params from -- ^ arguments -> expr lat with db params from ('Null 'PGfloat8) -- | @regr_slope(Y, X)@, slope of the least-squares-fit linear equation -- determined by the (X, Y) pairs regrSlope :: arg '[null 'PGfloat8, null 'PGfloat8] lat with db params from -- ^ arguments -> expr lat with db params from ('Null 'PGfloat8) -- | @regr_sxx(Y, X)@, sum(X^2) - sum(X)^2/N -- (“sum of squares” of the independent variable) regrSxx :: arg '[null 'PGfloat8, null 'PGfloat8] lat with db params from -- ^ arguments -> expr lat with db params from ('Null 'PGfloat8) -- | @regr_sxy(Y, X)@, sum(X*Y) - sum(X) * sum(Y)/N -- (“sum of products” of independent times dependent variable) regrSxy :: arg '[null 'PGfloat8, null 'PGfloat8] lat with db params from -- ^ arguments -> expr lat with db params from ('Null 'PGfloat8) -- | @regr_syy(Y, X)@, sum(Y^2) - sum(Y)^2/N -- (“sum of squares” of the dependent variable) regrSyy :: arg '[null 'PGfloat8, null 'PGfloat8] lat with db params from -- ^ arguments -> expr lat with db params from ('Null 'PGfloat8) -- | historical alias for `stddevSamp` stddev :: arg '[null ty] lat with db params from -- ^ argument -> expr lat with db params from ('Null (PGAvg ty)) -- | population standard deviation of the input values stddevPop :: arg '[null ty] lat with db params from -- ^ argument -> expr lat with db params from ('Null (PGAvg ty)) -- | sample standard deviation of the input values stddevSamp :: arg '[null ty] lat with db params from -- ^ argument -> expr lat with db params from ('Null (PGAvg ty)) -- | historical alias for `varSamp` variance :: arg '[null ty] lat with db params from -- ^ argument -> expr lat with db params from ('Null (PGAvg ty)) -- | population variance of the input values -- (square of the population standard deviation) varPop :: arg '[null ty] lat with db params from -- ^ argument -> expr lat with db params from ('Null (PGAvg ty)) -- | sample variance of the input values -- (square of the sample standard deviation) varSamp :: arg '[null ty] lat with db params from -- ^ argument -> expr lat with db params from ('Null (PGAvg ty)) {- | `AggregateArg`s are used for the input of `Aggregate` `Expression`s. -} data AggregateArg (xs :: [NullType]) (lat :: FromType) (with :: FromType) (db :: SchemasType) (params :: [NullType]) (from :: FromType) = AggregateAll { aggregateArgs :: NP (Expression 'Ungrouped lat with db params from) xs , aggregateOrder :: [SortExpression 'Ungrouped lat with db params from] -- ^ `orderBy` , aggregateFilter :: [Condition 'Ungrouped lat with db params from] -- ^ `filterWhere` } | AggregateDistinct { aggregateArgs :: NP (Expression 'Ungrouped lat with db params from) xs , aggregateOrder :: [SortExpression 'Ungrouped lat with db params from] -- ^ `orderBy` , aggregateFilter :: [Condition 'Ungrouped lat with db params from] -- ^ `filterWhere` } instance (HasUnique tab (Join from lat) row, Has col row ty) => IsLabel col (AggregateArg '[ty] lat with db params from) where fromLabel = All (fromLabel @col) instance (Has tab (Join from lat) row, Has col row ty) => IsQualified tab col (AggregateArg '[ty] lat with db params from) where tab ! col = All (tab ! col) instance SOP.SListI xs => RenderSQL (AggregateArg xs lat with db params from) where renderSQL = \case AggregateAll args sorts filters -> parenthesized ("ALL" <+> renderCommaSeparated renderSQL args<> renderSQL sorts) <> renderFilters filters AggregateDistinct args sorts filters -> parenthesized ("DISTINCT" <+> renderCommaSeparated renderSQL args <> renderSQL sorts) <> renderFilters filters where renderFilter wh = "FILTER" <+> parenthesized ("WHERE" <+> wh) renderFilters = \case [] -> "" wh:whs -> " " <> renderFilter (renderSQL (foldr (.&&) wh whs)) instance OrderBy (AggregateArg xs) 'Ungrouped where orderBy sorts1 = \case AggregateAll xs sorts0 whs -> AggregateAll xs (sorts0 ++ sorts1) whs AggregateDistinct xs sorts0 whs -> AggregateDistinct xs (sorts0 ++ sorts1) whs -- | `All` invokes the aggregate on a single -- argument once for each input row. pattern All :: Expression 'Ungrouped lat with db params from x -- ^ argument -> AggregateArg '[x] lat with db params from pattern All x = Alls (x :* Nil) -- | `All` invokes the aggregate on multiple -- arguments once for each input row. pattern Alls :: NP (Expression 'Ungrouped lat with db params from) xs -- ^ arguments -> AggregateArg xs lat with db params from pattern Alls xs = AggregateAll xs [] [] -- | `allNotNull` invokes the aggregate on a single -- argument once for each input row where the argument -- is not null allNotNull :: Expression 'Ungrouped lat with db params from ('Null x) -- ^ argument -> AggregateArg '[ 'NotNull x] lat with db params from allNotNull x = All (unsafeNotNull x) & filterWhere (not_ (isNull x)) {- | `Distinct` invokes the aggregate once for each distinct value of the expression found in the input. -} pattern Distinct :: Expression 'Ungrouped lat with db params from x -- ^ argument -> AggregateArg '[x] lat with db params from pattern Distinct x = Distincts (x :* Nil) {- | `Distincts` invokes the aggregate once for each distinct set of values, for multiple expressions, found in the input. -} pattern Distincts :: NP (Expression 'Ungrouped lat with db params from) xs -- ^ arguments -> AggregateArg xs lat with db params from pattern Distincts xs = AggregateDistinct xs [] [] {- | `distinctNotNull` invokes the aggregate once for each distinct, not null value of the expression found in the input. -} distinctNotNull :: Expression 'Ungrouped lat with db params from ('Null x) -- ^ argument -> AggregateArg '[ 'NotNull x] lat with db params from distinctNotNull x = Distinct (unsafeNotNull x) & filterWhere (not_ (isNull x)) -- | Permits filtering -- `Squeal.PostgreSQL.Expression.Window.WindowArg`s and `AggregateArg`s class FilterWhere arg grp | arg -> grp where {- | If `filterWhere` is specified, then only the input rows for which the `Condition` evaluates to true are fed to the aggregate function; other rows are discarded. -} filterWhere :: Condition grp lat with db params from -- ^ include rows which evaluate to true -> arg xs lat with db params from -> arg xs lat with db params from instance FilterWhere AggregateArg 'Ungrouped where filterWhere wh = \case AggregateAll xs sorts whs -> AggregateAll xs sorts (wh : whs) AggregateDistinct xs sorts whs -> AggregateDistinct xs sorts (wh : whs) instance Aggregate AggregateArg (Expression ('Grouped bys)) where countStar = UnsafeExpression "count(*)" count = unsafeAggregate "count" sum_ = unsafeAggregate "sum" arrayAgg = unsafeAggregate "array_agg" jsonAgg = unsafeAggregate "json_agg" jsonbAgg = unsafeAggregate "jsonb_agg" bitAnd = unsafeAggregate "bit_and" bitOr = unsafeAggregate "bit_or" boolAnd = unsafeAggregate "bool_and" boolOr = unsafeAggregate "bool_or" every = unsafeAggregate "every" max_ = unsafeAggregate "max" min_ = unsafeAggregate "min" avg = unsafeAggregate "avg" corr = unsafeAggregate "corr" covarPop = unsafeAggregate "covar_pop" covarSamp = unsafeAggregate "covar_samp" regrAvgX = unsafeAggregate "regr_avgx" regrAvgY = unsafeAggregate "regr_avgy" regrCount = unsafeAggregate "regr_count" regrIntercept = unsafeAggregate "regr_intercept" regrR2 = unsafeAggregate "regr_r2" regrSlope = unsafeAggregate "regr_slope" regrSxx = unsafeAggregate "regr_sxx" regrSxy = unsafeAggregate "regr_sxy" regrSyy = unsafeAggregate "regr_syy" stddev = unsafeAggregate "stddev" stddevPop = unsafeAggregate "stddev_pop" stddevSamp = unsafeAggregate "stddev_samp" variance = unsafeAggregate "variance" varPop = unsafeAggregate "var_pop" varSamp = unsafeAggregate "var_samp" -- provides a nicer type error when we forget to group by -- note that we need to make our 'a' polymorphic so that we can still match when it's ambiguous instance ( TypeError ('Text "Cannot use aggregate functions to construct an Ungrouped Expression. Add a 'groupBy' to your TableExpression. If you want to aggregate across the entire result set, use 'groupBy Nil'.") , a ~ AggregateArg ) => Aggregate a (Expression 'Ungrouped) where countStar = impossibleAggregateError count = impossibleAggregateError sum_ = impossibleAggregateError arrayAgg = impossibleAggregateError jsonAgg = impossibleAggregateError jsonbAgg = impossibleAggregateError bitAnd = impossibleAggregateError bitOr = impossibleAggregateError boolAnd = impossibleAggregateError boolOr = impossibleAggregateError every = impossibleAggregateError max_ = impossibleAggregateError min_ = impossibleAggregateError avg = impossibleAggregateError corr = impossibleAggregateError covarPop = impossibleAggregateError covarSamp = impossibleAggregateError regrAvgX = impossibleAggregateError regrAvgY = impossibleAggregateError regrCount = impossibleAggregateError regrIntercept = impossibleAggregateError regrR2 = impossibleAggregateError regrSlope = impossibleAggregateError regrSxx = impossibleAggregateError regrSxy = impossibleAggregateError regrSyy = impossibleAggregateError stddev = impossibleAggregateError stddevPop = impossibleAggregateError stddevSamp = impossibleAggregateError variance = impossibleAggregateError varPop = impossibleAggregateError varSamp = impossibleAggregateError -- | helper function for our errors above impossibleAggregateError :: a impossibleAggregateError = error "impossible; called aggregate function for Ungrouped even though the Aggregate instance has a type error constraint." -- | escape hatch to define aggregate functions unsafeAggregate :: SOP.SListI xs => ByteString -- ^ function -> AggregateArg xs lat with db params from -> Expression ('Grouped bys) lat with db params from y unsafeAggregate fun xs = UnsafeExpression $ fun <> renderSQL xs -- | A type family that calculates `PGSum` `PGType` of -- a given argument `PGType`. type family PGSum ty where PGSum 'PGint2 = 'PGint8 PGSum 'PGint4 = 'PGint8 PGSum 'PGint8 = 'PGnumeric PGSum 'PGfloat4 = 'PGfloat4 PGSum 'PGfloat8 = 'PGfloat8 PGSum 'PGnumeric = 'PGnumeric PGSum 'PGinterval = 'PGinterval PGSum 'PGmoney = 'PGmoney PGSum pg = TypeError ( 'Text "Squeal type error: Cannot sum with argument type " ':<>: 'ShowType pg ) -- | A type family that calculates `PGAvg` type of a `PGType`. type family PGAvg ty where PGAvg 'PGint2 = 'PGnumeric PGAvg 'PGint4 = 'PGnumeric PGAvg 'PGint8 = 'PGnumeric PGAvg 'PGnumeric = 'PGnumeric PGAvg 'PGfloat4 = 'PGfloat8 PGAvg 'PGfloat8 = 'PGfloat8 PGAvg 'PGinterval = 'PGinterval PGAvg pg = TypeError ('Text "Squeal type error: No average for " ':<>: 'ShowType pg)