{-# language FlexibleContexts #-}
{-# language MonoLocalBinds #-}
{-# language ScopedTypeVariables #-}
{-# language TypeApplications #-}

module Rel8.Query.Aggregate
  ( aggregate
  , countRows
  , mode
  )
where

-- base
import Data.Functor.Contravariant ( (>$<) )
import Data.Int ( Int64 )
import Prelude

-- opaleye
import qualified Opaleye.Aggregate as Opaleye

-- rel8
import Rel8.Aggregate ( Aggregates )
import Rel8.Expr ( Expr )
import Rel8.Expr.Aggregate ( countStar )
import Rel8.Expr.Order ( desc )
import Rel8.Query ( Query )
import Rel8.Query.Limit ( limit )
import Rel8.Query.Maybe ( optional )
import Rel8.Query.Opaleye ( mapOpaleye )
import Rel8.Query.Order ( orderBy )
import Rel8.Table ( toColumns )
import Rel8.Table.Aggregate ( hgroupBy )
import Rel8.Table.Cols ( Cols( Cols ), fromCols )
import Rel8.Table.Eq ( EqTable, eqTable )
import Rel8.Table.Opaleye ( aggregator )
import Rel8.Table.Maybe ( maybeTable )


-- | Apply an aggregation to all rows returned by a 'Query'.
--
-- This lifts Opaleye's aggregation over the underlying query (via
-- 'mapOpaleye'), using the 'aggregator' provided by the
-- @'Aggregates' aggregates exprs@ instance to turn aggregate expressions
-- into their resulting expressions.
aggregate :: Aggregates aggregates exprs => Query aggregates -> Query exprs
aggregate = mapOpaleye (Opaleye.aggregate aggregator)


-- | Count the number of rows returned by a query. Note that this is different
-- from 'countStar', as even if the given query yields no rows, @countRows@
-- will return @0@.
countRows :: Query a -> Query (Expr Int64)
countRows =
  -- Every input row is replaced by 'countStar' and aggregated. 'optional'
  -- wraps the (possibly absent) aggregate result in a 'MaybeTable', and
  -- 'maybeTable' collapses the absent case to 0, so exactly one count is
  -- always produced.
  fmap (maybeTable 0 id) . optional . aggregate . fmap (const countStar)


-- | Return the most common row in a query.
--
-- Rows are grouped by all of their columns (using the row's 'EqTable'
-- instance), each group is counted with 'countStar', the groups are ordered
-- by that count in descending order, and only the first one is kept. No
-- secondary ordering is applied, so when several distinct rows are equally
-- common, which of them is returned is not determined by this function.
mode :: forall a. EqTable a => Query a -> Query a
mode rows =
  limit 1 $
    fmap (fromCols . snd) $
      orderBy (fst >$< desc) $
        aggregate $ do
          row <- toColumns <$> rows
          -- Pair the per-group count with the grouped columns. 'Cols' wraps
          -- the columns so they can travel alongside the count and be turned
          -- back into a row with 'fromCols' above. The explicit type
          -- application pins which 'EqTable' instance supplies the equality
          -- dictionaries used for grouping.
          pure (countStar, Cols $ hgroupBy (eqTable @a) row)