//===----------------------------------------------------------------------===// // DuckDB // // duckdb/common/vector_operations/ternary_executor.hpp // // //===----------------------------------------------------------------------===// #pragma once #include "duckdb/common/exception.hpp" #include "duckdb/common/types/vector.hpp" #include "duckdb/common/vector_operations/vector_operations.hpp" #include namespace duckdb { template struct TernaryStandardOperatorWrapper { template static inline RESULT_TYPE Operation(FUN fun, A_TYPE a, B_TYPE b, C_TYPE c, ValidityMask &mask, idx_t idx) { return OP::template Operation(a, b, c); } }; struct TernaryLambdaWrapper { template static inline RESULT_TYPE Operation(FUN fun, A_TYPE a, B_TYPE b, C_TYPE c, ValidityMask &mask, idx_t idx) { return fun(a, b, c); } }; struct TernaryLambdaWrapperWithNulls { template static inline RESULT_TYPE Operation(FUN fun, A_TYPE a, B_TYPE b, C_TYPE c, ValidityMask &mask, idx_t idx) { return fun(a, b, c, mask, idx); } }; struct TernaryExecutor { private: template static inline void ExecuteLoop(const A_TYPE *__restrict adata, const B_TYPE *__restrict bdata, const C_TYPE *__restrict cdata, RESULT_TYPE *__restrict result_data, idx_t count, const SelectionVector &asel, const SelectionVector &bsel, const SelectionVector &csel, ValidityMask &avalidity, ValidityMask &bvalidity, ValidityMask &cvalidity, ValidityMask &result_validity, FUN fun) { if (!avalidity.AllValid() || !bvalidity.AllValid() || !cvalidity.AllValid()) { for (idx_t i = 0; i < count; i++) { auto aidx = asel.get_index(i); auto bidx = bsel.get_index(i); auto cidx = csel.get_index(i); if (avalidity.RowIsValid(aidx) && bvalidity.RowIsValid(bidx) && cvalidity.RowIsValid(cidx)) { result_data[i] = OPWRAPPER::template Operation( fun, adata[aidx], bdata[bidx], cdata[cidx], result_validity, i); } else { result_validity.SetInvalid(i); } } } else { for (idx_t i = 0; i < count; i++) { auto aidx = asel.get_index(i); auto bidx = bsel.get_index(i); auto cidx = csel.get_index(i); result_data[i] = OPWRAPPER::template Operation( fun, adata[aidx], bdata[bidx], cdata[cidx], result_validity, i); } } } public: template static void ExecuteGeneric(Vector &a, Vector &b, Vector &c, Vector &result, idx_t count, FUN fun) { if (a.GetVectorType() == VectorType::CONSTANT_VECTOR && b.GetVectorType() == VectorType::CONSTANT_VECTOR && c.GetVectorType() == VectorType::CONSTANT_VECTOR) { result.SetVectorType(VectorType::CONSTANT_VECTOR); if (ConstantVector::IsNull(a) || ConstantVector::IsNull(b) || ConstantVector::IsNull(c)) { ConstantVector::SetNull(result, true); } else { auto adata = ConstantVector::GetData(a); auto bdata = ConstantVector::GetData(b); auto cdata = ConstantVector::GetData(c); auto result_data = ConstantVector::GetData(result); auto &result_validity = ConstantVector::Validity(result); result_data[0] = OPWRAPPER::template Operation( fun, adata[0], bdata[0], cdata[0], result_validity, 0); } } else { result.SetVectorType(VectorType::FLAT_VECTOR); UnifiedVectorFormat adata, bdata, cdata; a.ToUnifiedFormat(count, adata); b.ToUnifiedFormat(count, bdata); c.ToUnifiedFormat(count, cdata); ExecuteLoop( UnifiedVectorFormat::GetData(adata), UnifiedVectorFormat::GetData(bdata), UnifiedVectorFormat::GetData(cdata), FlatVector::GetData(result), count, *adata.sel, *bdata.sel, *cdata.sel, adata.validity, bdata.validity, cdata.validity, FlatVector::Validity(result), fun); } } template > static void Execute(Vector &a, Vector &b, Vector &c, Vector &result, idx_t count, FUN fun) { ExecuteGeneric(a, b, c, result, count, fun); } template static void ExecuteStandard(Vector &a, Vector &b, Vector &c, Vector &result, idx_t count) { ExecuteGeneric, bool>(a, b, c, result, count, false); } template > static void ExecuteWithNulls(Vector &a, Vector &b, Vector &c, Vector &result, idx_t count, FUN fun) { ExecuteGeneric(a, b, c, result, count, fun); } private: template static inline idx_t SelectLoop(const A_TYPE *__restrict adata, const B_TYPE *__restrict bdata, const C_TYPE *__restrict cdata, const SelectionVector *result_sel, idx_t count, const SelectionVector &asel, const SelectionVector &bsel, const SelectionVector &csel, ValidityMask &avalidity, ValidityMask &bvalidity, ValidityMask &cvalidity, SelectionVector *true_sel, SelectionVector *false_sel) { idx_t true_count = 0, false_count = 0; for (idx_t i = 0; i < count; i++) { auto result_idx = result_sel->get_index(i); auto aidx = asel.get_index(i); auto bidx = bsel.get_index(i); auto cidx = csel.get_index(i); bool comparison_result = (NO_NULL || (avalidity.RowIsValid(aidx) && bvalidity.RowIsValid(bidx) && cvalidity.RowIsValid(cidx))) && OP::Operation(adata[aidx], bdata[bidx], cdata[cidx]); if (HAS_TRUE_SEL) { true_sel->set_index(true_count, result_idx); true_count += comparison_result; } if (HAS_FALSE_SEL) { false_sel->set_index(false_count, result_idx); false_count += !comparison_result; } } if (HAS_TRUE_SEL) { return true_count; } else { return count - false_count; } } template static inline idx_t SelectLoopSelSwitch(UnifiedVectorFormat &adata, UnifiedVectorFormat &bdata, UnifiedVectorFormat &cdata, const SelectionVector *sel, idx_t count, SelectionVector *true_sel, SelectionVector *false_sel) { if (true_sel && false_sel) { return SelectLoop( UnifiedVectorFormat::GetData(adata), UnifiedVectorFormat::GetData(bdata), UnifiedVectorFormat::GetData(cdata), sel, count, *adata.sel, *bdata.sel, *cdata.sel, adata.validity, bdata.validity, cdata.validity, true_sel, false_sel); } else if (true_sel) { return SelectLoop( UnifiedVectorFormat::GetData(adata), UnifiedVectorFormat::GetData(bdata), UnifiedVectorFormat::GetData(cdata), sel, count, *adata.sel, *bdata.sel, *cdata.sel, adata.validity, bdata.validity, cdata.validity, true_sel, false_sel); } else { D_ASSERT(false_sel); return SelectLoop( UnifiedVectorFormat::GetData(adata), UnifiedVectorFormat::GetData(bdata), UnifiedVectorFormat::GetData(cdata), sel, count, *adata.sel, *bdata.sel, *cdata.sel, adata.validity, bdata.validity, cdata.validity, true_sel, false_sel); } } template static inline idx_t SelectLoopSwitch(UnifiedVectorFormat &adata, UnifiedVectorFormat &bdata, UnifiedVectorFormat &cdata, const SelectionVector *sel, idx_t count, SelectionVector *true_sel, SelectionVector *false_sel) { if (!adata.validity.AllValid() || !bdata.validity.AllValid() || !cdata.validity.AllValid()) { return SelectLoopSelSwitch(adata, bdata, cdata, sel, count, true_sel, false_sel); } else { return SelectLoopSelSwitch(adata, bdata, cdata, sel, count, true_sel, false_sel); } } public: template static idx_t Select(Vector &a, Vector &b, Vector &c, const SelectionVector *sel, idx_t count, SelectionVector *true_sel, SelectionVector *false_sel) { if (!sel) { sel = FlatVector::IncrementalSelectionVector(); } UnifiedVectorFormat adata, bdata, cdata; a.ToUnifiedFormat(count, adata); b.ToUnifiedFormat(count, bdata); c.ToUnifiedFormat(count, cdata); return SelectLoopSwitch(adata, bdata, cdata, sel, count, true_sel, false_sel); } }; } // namespace duckdb