{-# language DataKinds #-}
{-# language FlexibleContexts #-}
{-# language FlexibleInstances #-}
{-# language MultiParamTypeClasses #-}
{-# language RankNTypes #-}
{-# language StandaloneKindSignatures #-}
{-# language TypeFamilies #-}
{-# language UndecidableInstances #-}

module Rel8.Aggregate
  ( Aggregate(..), zipOutputs
  , unsafeMakeAggregate
  , Aggregates
  )
where

-- base
import Control.Applicative ( liftA2 )
import Data.Functor.Identity ( Identity( Identity ) )
import Data.Kind ( Constraint, Type )
import Prelude

-- profunctors
import Data.Profunctor ( dimap )

-- opaleye
import qualified Opaleye.Aggregate as Opaleye
import qualified Opaleye.Internal.HaskellDB.PrimQuery as Opaleye
import qualified Opaleye.Internal.Column as Opaleye

-- rel8
import Rel8.Expr ( Expr )
import Rel8.Schema.HTable.Identity ( HIdentity(..) )
import qualified Rel8.Schema.Kind as K
import Rel8.Schema.Null ( Sql )
import Rel8.Table
  ( Table, Columns, Context, fromColumns, toColumns
  , FromExprs, fromResult, toResult
  , Transpose
  )
import Rel8.Table.Transpose ( Transposes )
import Rel8.Type ( DBType )


-- | 'Aggregate' is a special context used by 'Rel8.aggregate'.
--
-- An @'Aggregate' a@ wraps an Opaleye 'Opaleye.Aggregator' whose input is
-- @()@ and whose output is an @'Expr' a@.
type Aggregate :: K.Context
newtype Aggregate a
  = Aggregate (Opaleye.Aggregator () (Expr a))


-- | An @'Aggregate' a@ is a one-column table: its columns are an
-- @'HIdentity' a@, its context is 'Aggregate' itself, and it corresponds to
-- a plain @a@ on the Haskell side.
instance Sql DBType a => Table Aggregate (Aggregate a) where
  type Columns (Aggregate a) = HIdentity a
  type Context (Aggregate a) = Aggregate
  type FromExprs (Aggregate a) = a
  -- Transposing into another context @to@ simply yields @to a@.
  type Transpose to (Aggregate a) = to a

  -- The single column is the aggregate itself, wrapped in 'HIdentity'.
  toColumns = HIdentity
  fromColumns (HIdentity a) = a

  -- A result value is wrapped in 'Identity' and then 'HIdentity';
  -- 'fromResult' peels both layers off again.
  toResult = HIdentity . Identity
  fromResult (HIdentity (Identity a)) = a


-- | @Aggregates a b@ means that every column in @a@ is an 'Aggregate'
-- standing for the corresponding 'Expr' column in @b@.
--
-- This is a class synonym for @'Transposes' 'Aggregate' 'Expr' a b@: the
-- superclass plus the single catch-all instance make @Aggregates a b@ hold
-- exactly when that 'Transposes' constraint holds.
type Aggregates :: Type -> Type -> Constraint
class Transposes Aggregate Expr a b => Aggregates a b
instance Transposes Aggregate Expr a b => Aggregates a b


-- | Merge two aggregations into one by combining their output expressions
-- with @f@.
--
-- The two wrapped 'Opaleye.Aggregator's are combined with 'liftA2', i.e.
-- through the 'Applicative' instance of 'Opaleye.Aggregator', and @f@ is
-- applied to the two resulting 'Expr's.
zipOutputs :: ()
  => (Expr a -> Expr b -> Expr c) -> Aggregate a -> Aggregate b -> Aggregate c
zipOutputs f (Aggregate a) (Aggregate b) = Aggregate (liftA2 f a b)


-- | Build an 'Aggregate' from a raw Opaleye aggregator.
--
-- The expression @expr@ is converted to an 'Opaleye.PrimExpr' with @input@
-- and fed to @aggregator@ as an Opaleye column. The aggregator's output
-- column is then converted back into an 'Expr' with @output@.
--
-- This is unsafe because the Opaleye column types (@n@, @a@, @n'@, @a'@) are
-- completely unrelated to @input@ and @output@ in the signature. The caller
-- must ensure that the conversion functions and the aggregator agree.
unsafeMakeAggregate :: forall (input :: Type) (output :: Type) n n' a a'. ()
  => (Expr input -> Opaleye.PrimExpr)
  -> (Opaleye.PrimExpr -> Expr output)
  -> Opaleye.Aggregator (Opaleye.Field_ n a) (Opaleye.Field_ n' a')
  -> Expr input
  -> Aggregate output
unsafeMakeAggregate input output aggregator expr =
  Aggregate $ dimap in_ out aggregator
  where
    -- 'Aggregate' requires an aggregator over @()@: ignore that unit input
    -- and supply @expr@ (as an Opaleye column) in its place.
    in_ = Opaleye.Column . input . const expr
    -- Unwrap the aggregator's result column back into a Rel8 'Expr'.
    out = output . Opaleye.unColumn