module Spark.Core.Internal.AggregationFunctions(
collect,
collect',
count,
count',
countCol,
countCol',
sumCol,
sumCol',
AggTry,
UniversalAggregator(..),
applyUAOUnsafe,
applyUntypedUniAgg3
) where
import Data.Aeson(Value(Null))
import qualified Data.Text as T
import qualified Data.Vector as V
import Spark.Core.Internal.DatasetStructures
import Spark.Core.Internal.ColumnStructures
import Spark.Core.Internal.ColumnFunctions(colType, untypedCol)
import Spark.Core.Internal.DatasetFunctions
import Spark.Core.Internal.RowGenerics(ToSQL)
import Spark.Core.Internal.LocalDataFunctions()
import Spark.Core.Internal.FunctionsInternals
import Spark.Core.Internal.OpStructures
import Spark.Core.Internal.TypesStructures
import Spark.Core.Internal.Utilities
import Spark.Core.Internal.TypesFunctions(arrayType')
import Spark.Core.StructuresInternal(emptyFieldPath)
import Spark.Core.Types
import Spark.Core.Try
-- | Sums all the values of a typed column into a single local value of the
-- same type, using the "SUM" aggregator built by '_sumAgg''.
--
-- Fails at evaluation time (through 'applyUAOUnsafe') if the aggregation
-- cannot be built.
sumCol
  :: forall ref a. (Num a, SQLTypeable a, ToSQL a)
  => Column ref a
  -> LocalData a
sumCol column = applyUAOUnsafe _sumAgg' column
-- | Untyped counterpart of 'sumCol': sums a dynamic column into a local
-- frame. Errors are kept inside the resulting 'LocalFrame'.
sumCol' :: DynColumn -> LocalFrame
sumCol' dynCol = applyUntypedUniAgg3 _sumAgg' dynCol
-- | Counts the rows of a dataset. The dataset is viewed as a column with
-- 'asCol', and that column is counted with 'countCol'.
count :: forall a. Dataset a -> LocalData Int
count dataset = countCol (asCol dataset)
-- | Untyped counterpart of 'count': counts the rows of a dataframe by
-- viewing it as a dynamic column ('asCol'') and counting that column.
count' :: DataFrame -> LocalFrame
count' frame = countCol' (asCol' frame)
-- | Counts the elements of a typed column, using the "COUNT" aggregator
-- built by '_countAgg''.
--
-- Fails at evaluation time (through 'applyUAOUnsafe') if the aggregation
-- cannot be built.
countCol :: Column ref a -> LocalData Int
countCol column = applyUAOUnsafe _countAgg' column
-- | Untyped counterpart of 'countCol'. Errors are kept inside the resulting
-- 'LocalFrame'.
countCol' :: DynColumn -> LocalFrame
countCol' dynCol = applyUntypedUniAgg3 _countAgg' dynCol
-- | Gathers every value of a typed column into a single local list, using
-- the aggregator built by '_collectAgg''.
--
-- The order of the resulting list is determined by the backend operators
-- used in '_collectAgg''; it is not established in this module.
-- Fails at evaluation time (through 'applyUAOUnsafe') if the aggregation
-- cannot be built.
collect
  :: forall ref a. (SQLTypeable a)
  => Column ref a
  -> LocalData [a]
collect column = applyUAOUnsafe _collectAgg' column
-- | Untyped counterpart of 'collect'. Errors are kept inside the resulting
-- 'LocalFrame'.
collect' :: DynColumn -> LocalFrame
collect' dynCol = applyUntypedUniAgg3 _collectAgg' dynCol
-- | Result of building an aggregator: either a textual error message or the
-- built value.
type AggTry = Either T.Text
-- | Typed description of an aggregator over values of type @a@ that uses an
-- intermediate buffer of type @buff@.
--
-- NOTE(review): this type is exported but not used in this module; the
-- field descriptions below are based on the field names and types only.
data UniversalAggregator a buff = UniversalAggregator
  { uaMergeType :: SQLType buff
    -- ^ SQL type of the buffer.
  , uaInitialOuter :: Dataset a -> LocalData buff
    -- ^ Turns a dataset into a buffer value.
  , uaMergeBuffer :: LocalData buff -> LocalData buff -> LocalData buff
    -- ^ Combines two buffer values into one.
  }
-- | Builds the aggregator used by 'sumCol' and 'sumCol''.
--
-- The initial step is the inner "SUM" aggregate function applied at the
-- empty field path. Buffers are merged with the "SUM_SL" column semigroup
-- law. The merge type is the input data type itself. Always succeeds.
--
-- NOTE(review): '_countAgg'' merges with "SUM" while this uses "SUM_SL";
-- confirm the backend distinguishes the two intentionally.
_sumAgg' :: DataType -> AggTry UniversalAggregatorOp
_sumAgg' dt = Right UniversalAggregatorOp
  { uaoMergeType = dt
  , uaoInitialOuter = InnerAggOp sumFn
  , uaoMergeBuffer = ColumnSemiGroupLaw "SUM_SL"
  }
  where
    sumFn = AggFunction "SUM" (V.singleton emptyFieldPath)
-- | Builds the aggregator used by 'countCol' and 'countCol''.
--
-- The input data type is ignored; the result type is always the strict
-- 'IntType'. The initial step is the inner "COUNT" aggregate function at the
-- empty field path, and partial counts are merged with the "SUM" column
-- semigroup law. Always succeeds.
_countAgg' :: DataType -> AggTry UniversalAggregatorOp
_countAgg' _ = Right UniversalAggregatorOp
  { uaoMergeType = StrictType IntType
  , uaoInitialOuter = InnerAggOp countFn
  , uaoMergeBuffer = ColumnSemiGroupLaw "SUM"
  }
  where
    countFn = AggFunction "COUNT" (V.singleton emptyFieldPath)
-- | Builds the aggregator used by 'collect' and 'collect''.
--
-- The merge type is the array type of the input type ('arrayType''). Both
-- steps are opaque standard operators producing that array type:
--
-- * the initial step is "org.spark.Collect";
-- * buffers are merged with "org.spark.CatSorted".
--
-- Neither operator carries extra parameters ('soExtra' is 'Null').
-- Always succeeds.
_collectAgg' :: DataType -> AggTry UniversalAggregatorOp
_collectAgg' dt = Right UniversalAggregatorOp
  { uaoMergeType = listType
  , uaoInitialOuter = OpaqueAggTransform (opaqueOp "org.spark.Collect")
  , uaoMergeBuffer = OpaqueSemiGroupLaw (opaqueOp "org.spark.CatSorted")
  }
  where
    listType = arrayType' dt
    -- Both operators differ only by name, so they share one constructor.
    opaqueOp name = StandardOperator
      { soName = name
      , soOutputType = listType
      , soExtra = Null
      }
-- | Applies an aggregator factory to a dynamic column and returns the
-- resulting (untyped) local data node.
--
-- The factory receives the column's data type. Its error, if any, is
-- converted with 'tryEither'; an error already present in the dynamic
-- column is returned first. On success, the node is a
-- 'NodeAggregatorReduction' typed with the aggregator's merge type, whose
-- only parent is the dataset obtained from the column with 'pack1'.
applyUntypedUniAgg3
  :: (DataType -> AggTry UniversalAggregatorOp)
  -> DynColumn
  -> LocalFrame
applyUntypedUniAgg3 mkAgg dynCol = do
  col <- dynCol
  aggOp <- tryEither (mkAgg (unSQLType (colType col)))
  let node = NodeAggregatorReduction aggOp
      outType = SQLType (uaoMergeType aggOp)
      parent = untyped (pack1 col)
  pure (emptyLocalData node outType `parents` [parent])
-- | Typed, partial wrapper around 'applyUntypedUniAgg3'.
--
-- The column is erased with 'untypedCol', aggregated, and converted back to
-- @LocalData b@ through 'asObservable'. Any failure is forced with
-- 'forceRight' — presumably an error raised at evaluation time with the
-- call stack attached (confirm against 'forceRight') — hence "Unsafe".
applyUAOUnsafe
  :: forall a b ref. (SQLTypeable b, HasCallStack)
  => (DataType -> AggTry UniversalAggregatorOp)
  -> Column ref a
  -> LocalData b
applyUAOUnsafe mkAgg column = forceRight (asObservable untypedResult)
  where
    untypedResult = applyUntypedUniAgg3 mkAgg (untypedCol column)