//===----------------------------------------------------------------------===// // DuckDB // // duckdb/common/vector_operations/aggregate_executor.hpp // // //===----------------------------------------------------------------------===// #pragma once #include "duckdb/common/exception.hpp" #include "duckdb/common/types/vector.hpp" #include "duckdb/common/vector_operations/vector_operations.hpp" #include "duckdb/function/aggregate_state.hpp" namespace duckdb { struct AggregateInputData; typedef std::pair FrameBounds; class AggregateExecutor { private: template static inline void NullaryFlatLoop(STATE_TYPE **__restrict states, AggregateInputData &aggr_input_data, idx_t count) { for (idx_t i = 0; i < count; i++) { OP::template Operation(*states[i], aggr_input_data, i); } } template static inline void NullaryScatterLoop(STATE_TYPE **__restrict states, AggregateInputData &aggr_input_data, const SelectionVector &ssel, idx_t count) { for (idx_t i = 0; i < count; i++) { auto sidx = ssel.get_index(i); OP::template Operation(*states[sidx], aggr_input_data, sidx); } } template static inline void UnaryFlatLoop(const INPUT_TYPE *__restrict idata, AggregateInputData &aggr_input_data, STATE_TYPE **__restrict states, ValidityMask &mask, idx_t count) { if (OP::IgnoreNull() && !mask.AllValid()) { AggregateUnaryInput input(aggr_input_data, mask); auto &base_idx = input.input_idx; base_idx = 0; auto entry_count = ValidityMask::EntryCount(count); for (idx_t entry_idx = 0; entry_idx < entry_count; entry_idx++) { auto validity_entry = mask.GetValidityEntry(entry_idx); idx_t next = MinValue(base_idx + ValidityMask::BITS_PER_VALUE, count); if (ValidityMask::AllValid(validity_entry)) { // all valid: perform operation for (; base_idx < next; base_idx++) { OP::template Operation(*states[base_idx], idata[base_idx], input); } } else if (ValidityMask::NoneValid(validity_entry)) { // nothing valid: skip all base_idx = next; continue; } else { // partially valid: need to check individual elements for validity idx_t start = base_idx; for (; base_idx < next; base_idx++) { if (ValidityMask::RowIsValid(validity_entry, base_idx - start)) { OP::template Operation(*states[base_idx], idata[base_idx], input); } } } } } else { AggregateUnaryInput input(aggr_input_data, mask); auto &i = input.input_idx; for (i = 0; i < count; i++) { OP::template Operation(*states[i], idata[i], input); } } } template static inline void UnaryScatterLoop(const INPUT_TYPE *__restrict idata, AggregateInputData &aggr_input_data, STATE_TYPE **__restrict states, const SelectionVector &isel, const SelectionVector &ssel, ValidityMask &mask, idx_t count) { if (OP::IgnoreNull() && !mask.AllValid()) { // potential NULL values and NULL values are ignored AggregateUnaryInput input(aggr_input_data, mask); for (idx_t i = 0; i < count; i++) { input.input_idx = isel.get_index(i); auto sidx = ssel.get_index(i); if (mask.RowIsValid(input.input_idx)) { OP::template Operation(*states[sidx], idata[input.input_idx], input); } } } else { // quick path: no NULL values or NULL values are not ignored AggregateUnaryInput input(aggr_input_data, mask); for (idx_t i = 0; i < count; i++) { input.input_idx = isel.get_index(i); auto sidx = ssel.get_index(i); OP::template Operation(*states[sidx], idata[input.input_idx], input); } } } template static inline void UnaryFlatUpdateLoop(const INPUT_TYPE *__restrict idata, AggregateInputData &aggr_input_data, STATE_TYPE *__restrict state, idx_t count, ValidityMask &mask) { AggregateUnaryInput input(aggr_input_data, mask); auto &base_idx = input.input_idx; base_idx = 0; auto entry_count = ValidityMask::EntryCount(count); for (idx_t entry_idx = 0; entry_idx < entry_count; entry_idx++) { auto validity_entry = mask.GetValidityEntry(entry_idx); idx_t next = MinValue(base_idx + ValidityMask::BITS_PER_VALUE, count); if (!OP::IgnoreNull() || ValidityMask::AllValid(validity_entry)) { // all valid: perform operation for (; base_idx < next; base_idx++) { OP::template Operation(*state, idata[base_idx], input); } } else if (ValidityMask::NoneValid(validity_entry)) { // nothing valid: skip all base_idx = next; continue; } else { // partially valid: need to check individual elements for validity idx_t start = base_idx; for (; base_idx < next; base_idx++) { if (ValidityMask::RowIsValid(validity_entry, base_idx - start)) { OP::template Operation(*state, idata[base_idx], input); } } } } } template static inline void UnaryUpdateLoop(const INPUT_TYPE *__restrict idata, AggregateInputData &aggr_input_data, STATE_TYPE *__restrict state, idx_t count, ValidityMask &mask, const SelectionVector &__restrict sel_vector) { AggregateUnaryInput input(aggr_input_data, mask); if (OP::IgnoreNull() && !mask.AllValid()) { // potential NULL values and NULL values are ignored for (idx_t i = 0; i < count; i++) { input.input_idx = sel_vector.get_index(i); if (mask.RowIsValid(input.input_idx)) { OP::template Operation(*state, idata[input.input_idx], input); } } } else { // quick path: no NULL values or NULL values are not ignored for (idx_t i = 0; i < count; i++) { input.input_idx = sel_vector.get_index(i); OP::template Operation(*state, idata[input.input_idx], input); } } } template static inline void BinaryScatterLoop(const A_TYPE *__restrict adata, AggregateInputData &aggr_input_data, const B_TYPE *__restrict bdata, STATE_TYPE **__restrict states, idx_t count, const SelectionVector &asel, const SelectionVector &bsel, const SelectionVector &ssel, ValidityMask &avalidity, ValidityMask &bvalidity) { AggregateBinaryInput input(aggr_input_data, avalidity, bvalidity); if (OP::IgnoreNull() && (!avalidity.AllValid() || !bvalidity.AllValid())) { // potential NULL values and NULL values are ignored for (idx_t i = 0; i < count; i++) { input.lidx = asel.get_index(i); input.ridx = bsel.get_index(i); auto sidx = ssel.get_index(i); if (avalidity.RowIsValid(input.lidx) && bvalidity.RowIsValid(input.ridx)) { OP::template Operation(*states[sidx], adata[input.lidx], bdata[input.ridx], input); } } } else { // quick path: no NULL values or NULL values are not ignored for (idx_t i = 0; i < count; i++) { input.lidx = asel.get_index(i); input.ridx = bsel.get_index(i); auto sidx = ssel.get_index(i); OP::template Operation(*states[sidx], adata[input.lidx], bdata[input.ridx], input); } } } template static inline void BinaryUpdateLoop(const A_TYPE *__restrict adata, AggregateInputData &aggr_input_data, const B_TYPE *__restrict bdata, STATE_TYPE *__restrict state, idx_t count, const SelectionVector &asel, const SelectionVector &bsel, ValidityMask &avalidity, ValidityMask &bvalidity) { AggregateBinaryInput input(aggr_input_data, avalidity, bvalidity); if (OP::IgnoreNull() && (!avalidity.AllValid() || !bvalidity.AllValid())) { // potential NULL values and NULL values are ignored for (idx_t i = 0; i < count; i++) { input.lidx = asel.get_index(i); input.ridx = bsel.get_index(i); if (avalidity.RowIsValid(input.lidx) && bvalidity.RowIsValid(input.ridx)) { OP::template Operation(*state, adata[input.lidx], bdata[input.ridx], input); } } } else { // quick path: no NULL values or NULL values are not ignored for (idx_t i = 0; i < count; i++) { input.lidx = asel.get_index(i); input.ridx = bsel.get_index(i); OP::template Operation(*state, adata[input.lidx], bdata[input.ridx], input); } } } public: template static void NullaryScatter(Vector &states, AggregateInputData &aggr_input_data, idx_t count) { if (states.GetVectorType() == VectorType::CONSTANT_VECTOR) { auto sdata = ConstantVector::GetData(states); OP::template ConstantOperation(**sdata, aggr_input_data, count); } else if (states.GetVectorType() == VectorType::FLAT_VECTOR) { auto sdata = FlatVector::GetData(states); NullaryFlatLoop(sdata, aggr_input_data, count); } else { UnifiedVectorFormat sdata; states.ToUnifiedFormat(count, sdata); NullaryScatterLoop((STATE_TYPE **)sdata.data, aggr_input_data, *sdata.sel, count); } } template static void NullaryUpdate(data_ptr_t state, AggregateInputData &aggr_input_data, idx_t count) { OP::template ConstantOperation(*reinterpret_cast(state), aggr_input_data, count); } template static void UnaryScatter(Vector &input, Vector &states, AggregateInputData &aggr_input_data, idx_t count) { if (input.GetVectorType() == VectorType::CONSTANT_VECTOR && states.GetVectorType() == VectorType::CONSTANT_VECTOR) { if (OP::IgnoreNull() && ConstantVector::IsNull(input)) { // constant NULL input in function that ignores NULL values return; } // regular constant: get first state auto idata = ConstantVector::GetData(input); auto sdata = ConstantVector::GetData(states); AggregateUnaryInput input_data(aggr_input_data, ConstantVector::Validity(input)); OP::template ConstantOperation(**sdata, *idata, input_data, count); } else if (input.GetVectorType() == VectorType::FLAT_VECTOR && states.GetVectorType() == VectorType::FLAT_VECTOR) { auto idata = FlatVector::GetData(input); auto sdata = FlatVector::GetData(states); UnaryFlatLoop(idata, aggr_input_data, sdata, FlatVector::Validity(input), count); } else { UnifiedVectorFormat idata, sdata; input.ToUnifiedFormat(count, idata); states.ToUnifiedFormat(count, sdata); UnaryScatterLoop(UnifiedVectorFormat::GetData(idata), aggr_input_data, (STATE_TYPE **)sdata.data, *idata.sel, *sdata.sel, idata.validity, count); } } template static void UnaryUpdate(Vector &input, AggregateInputData &aggr_input_data, data_ptr_t state, idx_t count) { switch (input.GetVectorType()) { case VectorType::CONSTANT_VECTOR: { if (OP::IgnoreNull() && ConstantVector::IsNull(input)) { return; } auto idata = ConstantVector::GetData(input); AggregateUnaryInput input_data(aggr_input_data, ConstantVector::Validity(input)); OP::template ConstantOperation(*reinterpret_cast(state), *idata, input_data, count); break; } case VectorType::FLAT_VECTOR: { auto idata = FlatVector::GetData(input); UnaryFlatUpdateLoop(idata, aggr_input_data, (STATE_TYPE *)state, count, FlatVector::Validity(input)); break; } default: { UnifiedVectorFormat idata; input.ToUnifiedFormat(count, idata); UnaryUpdateLoop(UnifiedVectorFormat::GetData(idata), aggr_input_data, (STATE_TYPE *)state, count, idata.validity, *idata.sel); break; } } } template static void BinaryScatter(AggregateInputData &aggr_input_data, Vector &a, Vector &b, Vector &states, idx_t count) { UnifiedVectorFormat adata, bdata, sdata; a.ToUnifiedFormat(count, adata); b.ToUnifiedFormat(count, bdata); states.ToUnifiedFormat(count, sdata); BinaryScatterLoop( UnifiedVectorFormat::GetData(adata), aggr_input_data, UnifiedVectorFormat::GetData(bdata), (STATE_TYPE **)sdata.data, count, *adata.sel, *bdata.sel, *sdata.sel, adata.validity, bdata.validity); } template static void BinaryUpdate(AggregateInputData &aggr_input_data, Vector &a, Vector &b, data_ptr_t state, idx_t count) { UnifiedVectorFormat adata, bdata; a.ToUnifiedFormat(count, adata); b.ToUnifiedFormat(count, bdata); BinaryUpdateLoop( UnifiedVectorFormat::GetData(adata), aggr_input_data, UnifiedVectorFormat::GetData(bdata), (STATE_TYPE *)state, count, *adata.sel, *bdata.sel, adata.validity, bdata.validity); } template static void Combine(Vector &source, Vector &target, AggregateInputData &aggr_input_data, idx_t count) { D_ASSERT(source.GetType().id() == LogicalTypeId::POINTER && target.GetType().id() == LogicalTypeId::POINTER); auto sdata = FlatVector::GetData(source); auto tdata = FlatVector::GetData(target); for (idx_t i = 0; i < count; i++) { OP::template Combine(*sdata[i], *tdata[i], aggr_input_data); } } template static void Finalize(Vector &states, AggregateInputData &aggr_input_data, Vector &result, idx_t count, idx_t offset) { if (states.GetVectorType() == VectorType::CONSTANT_VECTOR) { result.SetVectorType(VectorType::CONSTANT_VECTOR); auto sdata = ConstantVector::GetData(states); auto rdata = ConstantVector::GetData(result); AggregateFinalizeData finalize_data(result, aggr_input_data); OP::template Finalize(**sdata, *rdata, finalize_data); } else { D_ASSERT(states.GetVectorType() == VectorType::FLAT_VECTOR); result.SetVectorType(VectorType::FLAT_VECTOR); auto sdata = FlatVector::GetData(states); auto rdata = FlatVector::GetData(result); AggregateFinalizeData finalize_data(result, aggr_input_data); for (idx_t i = 0; i < count; i++) { finalize_data.result_idx = i + offset; OP::template Finalize(*sdata[i], rdata[finalize_data.result_idx], finalize_data); } } } template static void VoidFinalize(Vector &states, AggregateInputData &aggr_input_data, Vector &result, idx_t count, idx_t offset) { if (states.GetVectorType() == VectorType::CONSTANT_VECTOR) { result.SetVectorType(VectorType::CONSTANT_VECTOR); auto sdata = ConstantVector::GetData(states); AggregateFinalizeData finalize_data(result, aggr_input_data); OP::template Finalize(**sdata, finalize_data); } else { D_ASSERT(states.GetVectorType() == VectorType::FLAT_VECTOR); result.SetVectorType(VectorType::FLAT_VECTOR); auto sdata = FlatVector::GetData(states); AggregateFinalizeData finalize_data(result, aggr_input_data); for (idx_t i = 0; i < count; i++) { finalize_data.result_idx = i + offset; OP::template Finalize(*sdata[i], finalize_data); } } } template static void UnaryWindow(Vector &input, const ValidityMask &ifilter, AggregateInputData &aggr_input_data, data_ptr_t state, const FrameBounds &frame, const FrameBounds &prev, Vector &result, idx_t rid, idx_t bias) { auto idata = FlatVector::GetData(input) - bias; const auto &ivalid = FlatVector::Validity(input); OP::template Window( idata, ifilter, ivalid, aggr_input_data, *reinterpret_cast(state), frame, prev, result, rid, bias); } template static void Destroy(Vector &states, AggregateInputData &aggr_input_data, idx_t count) { auto sdata = FlatVector::GetData(states); for (idx_t i = 0; i < count; i++) { OP::template Destroy(*sdata[i], aggr_input_data); } } }; } // namespace duckdb