//===----------------------------------------------------------------------===// // DuckDB // // duckdb/core_functions/aggregate/algebraic/corr.hpp // // //===----------------------------------------------------------------------===// #pragma once #include "duckdb/function/aggregate_function.hpp" #include "duckdb/core_functions/aggregate/algebraic/covar.hpp" #include "duckdb/core_functions/aggregate/algebraic/stddev.hpp" namespace duckdb { struct CorrState { CovarState cov_pop; StddevState dev_pop_x; StddevState dev_pop_y; }; // Returns the correlation coefficient for non-null pairs in a group. // CORR(y, x) = COVAR_POP(y, x) / (STDDEV_POP(x) * STDDEV_POP(y)) struct CorrOperation { template static void Initialize(STATE &state) { CovarOperation::Initialize(state.cov_pop); STDDevBaseOperation::Initialize(state.dev_pop_x); STDDevBaseOperation::Initialize(state.dev_pop_y); } template static void Operation(STATE &state, const A_TYPE &x_input, const B_TYPE &y_input, AggregateBinaryInput &idata) { CovarOperation::Operation(state.cov_pop, x_input, y_input, idata); STDDevBaseOperation::Execute(state.dev_pop_x, x_input); STDDevBaseOperation::Execute(state.dev_pop_y, y_input); } template static void Combine(const STATE &source, STATE &target, AggregateInputData &aggr_input_data) { CovarOperation::Combine(source.cov_pop, target.cov_pop, aggr_input_data); STDDevBaseOperation::Combine(source.dev_pop_x, target.dev_pop_x, aggr_input_data); STDDevBaseOperation::Combine(source.dev_pop_y, target.dev_pop_y, aggr_input_data); } template static void Finalize(STATE &state, T &target, AggregateFinalizeData &finalize_data) { if (state.cov_pop.count == 0 || state.dev_pop_x.count == 0 || state.dev_pop_y.count == 0) { finalize_data.ReturnNull(); } else { auto cov = state.cov_pop.co_moment / state.cov_pop.count; auto std_x = state.dev_pop_x.count > 1 ? sqrt(state.dev_pop_x.dsquared / state.dev_pop_x.count) : 0; if (!Value::DoubleIsFinite(std_x)) { throw OutOfRangeException("STDDEV_POP for X is out of range!"); } auto std_y = state.dev_pop_y.count > 1 ? sqrt(state.dev_pop_y.dsquared / state.dev_pop_y.count) : 0; if (!Value::DoubleIsFinite(std_y)) { throw OutOfRangeException("STDDEV_POP for Y is out of range!"); } if (std_x * std_y == 0) { finalize_data.ReturnNull(); return; } target = cov / (std_x * std_y); } } static bool IgnoreNull() { return true; } }; } // namespace duckdb