{-# LANGUAGE GeneralizedNewtypeDeriving #-}
{-# LANGUAGE FlexibleInstances #-}
{-# LANGUAGE MultiParamTypeClasses #-}
module Database.Relational.Monad.Trans.Aggregating
       ( 
         Aggregatings, aggregatings,
         AggregatingSetT, AggregatingSetListT, AggregatingPowerSetT, PartitioningSetT,
         
         extractAggregateTerms,
         
         AggregatingSet, AggregatingPowerSet,  AggregatingSetList, PartitioningSet,
         key, key', set,
         bkey, rollup, cube, groupingSets,
       ) where
import Control.Monad.Trans.Class (MonadTrans (lift))
import Control.Monad.Trans.Writer (WriterT, runWriterT, tell)
import Control.Applicative (Applicative, pure, (<$>))
import Control.Arrow (second)
import Data.DList (DList, toList)
import Data.Functor.Identity (Identity (runIdentity))
import Database.Relational.Internal.ContextType
  (Flat, Aggregated, Set, Power, SetList)
import Database.Relational.SqlSyntax
  (Record,
   AggregateColumnRef, AggregateElem, aggregateColumnRef, AggregateSet, aggregateGroupingSet,
   AggregateBitKey, aggregatePowerKey, aggregateRollup, aggregateCube, aggregateSets,
   AggregateKey, aggregateKeyRecord, aggregateKeyElement, unsafeAggregateKey)
import qualified Database.Relational.Record as Record
import Database.Relational.Monad.Class
  (MonadQualify (..), MonadRestrict(..), MonadQuery(..), MonadAggregate(..), MonadPartition(..))
-- | Monad transformer that accumulates aggregate terms.
--
--   It is a newtype over a 'WriterT' whose log is a 'DList' of terms of
--   type @at@.  The @ac@ parameter does not occur on the right-hand side,
--   so it is a phantom tag that only distinguishes contexts at the type
--   level.  'MonadTrans', 'Monad', 'Functor' and 'Applicative' are taken
--   from the wrapped 'WriterT' through newtype deriving.
newtype Aggregatings ac at m a =
  Aggregatings (WriterT (DList at) m a)
  deriving (MonadTrans, Monad, Functor, Applicative)
-- | Lift an action of the underlying monad into 'Aggregatings'.
--   The lifted action runs in the wrapped writer layer and contributes
--   no aggregate term of its own.
aggregatings :: Monad m => m a -> Aggregatings ac at m a
aggregatings inner = Aggregatings (lift inner)
-- | 'Aggregatings' tagged with 'Set', accumulating 'AggregateElem' terms.
type AggregatingSetT      = Aggregatings Set       AggregateElem
-- | 'Aggregatings' tagged with 'SetList', accumulating 'AggregateSet' terms.
type AggregatingSetListT  = Aggregatings SetList   AggregateSet
-- | 'Aggregatings' tagged with 'Power', accumulating 'AggregateBitKey' terms.
type AggregatingPowerSetT = Aggregatings Power     AggregateBitKey
-- | 'Aggregatings' tagged with a caller-chosen context @c@, accumulating
--   'AggregateColumnRef' terms.
type PartitioningSetT c   = Aggregatings c         AggregateColumnRef
-- | Restriction is delegated to the underlying monad: the argument is
--   handed to the inner 'restrict' and the resulting action is lifted
--   with 'aggregatings'.
instance MonadRestrict c m => MonadRestrict c (AggregatingSetT m) where
  restrict predicate = aggregatings (restrict predicate)
-- | Qualification is delegated to the underlying monad: the action is
--   passed to the inner 'liftQualify' and then lifted with 'aggregatings'.
instance MonadQualify q m => MonadQualify q (AggregatingSetT m) where
  liftQualify qualifyAction = aggregatings (liftQualify qualifyAction)
-- | Every query operation is forwarded to the underlying 'MonadQuery'
--   and lifted with 'aggregatings'; this layer adds nothing to them.
instance MonadQuery m => MonadQuery (AggregatingSetT m) where
  setDuplication dup = aggregatings (setDuplication dup)
  restrictJoin cond  = aggregatings (restrictJoin cond)
  query' rel         = aggregatings (query' rel)
  queryMaybe' rel    = aggregatings (queryMaybe' rel)
-- | Append a single term to the accumulated term list.
--   The term is wrapped into a one-element log with 'pure' and written
--   with 'tell'; no check is made on the term itself.
unsafeAggregateWithTerm :: Monad m => at -> Aggregatings ac at m ()
unsafeAggregateWithTerm term = Aggregatings (tell (pure term))
-- | Record the element of an 'AggregateKey' (via 'aggregateKeyElement')
--   as an aggregate term, then return the key's record part
--   (via 'aggregateKeyRecord').
aggregateKey :: Monad m => AggregateKey a -> Aggregatings ac AggregateElem m a
aggregateKey k =
  unsafeAggregateWithTerm (aggregateKeyElement k)
    >> return (aggregateKeyRecord k)
-- | Aggregation support for the single grouping-set context.
--
--   'groupBy' records one term per column of the given record (each
--   converted with 'aggregateColumnRef') and returns the record
--   re-tagged by 'Record.unsafeToAggregated'.  'groupBy'' records an
--   already-built key through 'aggregateKey'.
instance MonadQuery m => MonadAggregate (AggregatingSetT m) where
  groupBy flat =
    mapM_ (unsafeAggregateWithTerm . aggregateColumnRef) (Record.columns flat)
      >> return (Record.unsafeToAggregated flat)
  groupBy' k = aggregateKey k
-- | Partitioning support: 'partitionBy' records every column of the
--   given record, in order, as a term.
instance Monad m => MonadPartition c (PartitioningSetT c m) where
  partitionBy flat = mapM_ unsafeAggregateWithTerm (Record.columns flat)
-- | Run the accumulating layer and return, in the underlying monad, the
--   computation's result paired with the accumulated terms as a plain list.
extractAggregateTerms :: (Monad m, Functor m) => Aggregatings ac at m a -> m (a, [at])
extractAggregateTerms (Aggregatings writer) = fmap listTerms (runWriterT writer)
  where
    -- Lazy pattern: the pair is not forced until one of its components is
    -- demanded, matching the laziness of 'second' on plain functions.
    listTerms ~(result, terms) = (result, toList terms)
-- | Specialisation of 'extractAggregateTerms' to the 'Identity' monad,
--   returning the result and the term list directly.
extractTermList :: Aggregatings ac at Identity a -> (a, [at])
extractTermList agg = runIdentity (extractAggregateTerms agg)
-- | 'AggregatingSetT' specialised to the 'Identity' monad.
type AggregatingSet      = AggregatingSetT      Identity
-- | 'AggregatingPowerSetT' specialised to the 'Identity' monad.
type AggregatingPowerSet = AggregatingPowerSetT Identity
-- | 'AggregatingSetListT' specialised to the 'Identity' monad.
type AggregatingSetList  = AggregatingSetListT  Identity
-- | 'PartitioningSetT' specialised to the 'Identity' monad.
type PartitioningSet c   = PartitioningSetT c   Identity
-- | Add every column of a flat record to the current grouping set.
--   Each column is recorded through 'aggregateColumnRef'; the result is
--   the same record re-tagged with 'Record.unsafeToAggregated' and
--   wrapped by 'Record.just'.
key :: Record Flat r
    -> AggregatingSet (Record Aggregated (Maybe r))
key flat =
  mapM_ (unsafeAggregateWithTerm . aggregateColumnRef) (Record.columns flat)
    >> return (Record.just (Record.unsafeToAggregated flat))
-- | Add an already-built 'AggregateKey' to the current grouping set and
--   return its record part; this is 'aggregateKey' at the
--   'AggregatingSet' type.
key' :: AggregateKey a
     -> AggregatingSet a
key' k = aggregateKey k
-- | Close a single grouping-set context and register it in the
--   enclosing set-list context.  The elements collected by the inner
--   context are combined with 'aggregateGroupingSet' and recorded as one
--   term; the inner context's result is returned unchanged.
set :: AggregatingSet a
    -> AggregatingSetList a
set s = unsafeAggregateWithTerm (aggregateGroupingSet elems) >> return result
  where
    -- Lazy pattern binding over the evaluated inner context.
    (result, elems) = extractTermList s
-- | Add a flat record to the current power-set context as a single
--   term built from all of its columns with 'aggregatePowerKey'.
--   The result is the same record re-tagged with
--   'Record.unsafeToAggregated' and wrapped by 'Record.just'.
bkey :: Record Flat r
     -> AggregatingPowerSet (Record Aggregated (Maybe r))
bkey flat =
  unsafeAggregateWithTerm (aggregatePowerKey (Record.columns flat))
    >> return (Record.just (Record.unsafeToAggregated flat))
-- | Turn a power-set context into an 'AggregateKey'.
--   The context is evaluated, its collected bit keys are folded into a
--   single 'AggregateElem' by the supplied function, and that element is
--   paired with the context's result using 'unsafeAggregateKey'.
finalizePower :: ([AggregateBitKey] -> AggregateElem)
              -> AggregatingPowerSet a -> AggregateKey a
finalizePower finalize pow = unsafeAggregateKey (result, finalize bitKeys)
  where
    -- Lazy pattern binding over the evaluated power-set context.
    (result, bitKeys) = extractTermList pow
-- | Finalize a power-set context into an 'AggregateKey', combining the
--   collected bit keys with 'aggregateRollup'.
rollup :: AggregatingPowerSet a -> AggregateKey a
rollup pow = finalizePower aggregateRollup pow
-- | Finalize a power-set context into an 'AggregateKey', combining the
--   collected bit keys with 'aggregateCube'.
cube :: AggregatingPowerSet a -> AggregateKey a
cube pow = finalizePower aggregateCube pow
-- | Finalize a set-list context into an 'AggregateKey'.
--   The collected grouping sets are combined with 'aggregateSets' and
--   paired with the context's result using 'unsafeAggregateKey'.
groupingSets :: AggregatingSetList a -> AggregateKey a
groupingSets setList = unsafeAggregateKey (result, aggregateSets sets)
  where
    -- Lazy pattern binding over the evaluated set-list context.
    (result, sets) = extractTermList setList