//===----------------------------------------------------------------------===// // DuckDB // // duckdb/function/scalar_function.hpp // // //===----------------------------------------------------------------------===// #pragma once #include "duckdb/common/vector_operations/binary_executor.hpp" #include "duckdb/common/vector_operations/ternary_executor.hpp" #include "duckdb/common/vector_operations/unary_executor.hpp" #include "duckdb/common/vector_operations/vector_operations.hpp" #include "duckdb/execution/expression_executor_state.hpp" #include "duckdb/function/function.hpp" #include "duckdb/planner/plan_serialization.hpp" #include "duckdb/storage/statistics/base_statistics.hpp" #include "duckdb/common/optional_ptr.hpp" namespace duckdb { struct FunctionLocalState { DUCKDB_API virtual ~FunctionLocalState(); template TARGET &Cast() { D_ASSERT(dynamic_cast(this)); return reinterpret_cast(*this); } template const TARGET &Cast() const { D_ASSERT(dynamic_cast(this)); return reinterpret_cast(*this); } }; class Binder; class BoundFunctionExpression; class DependencyList; class ScalarFunctionCatalogEntry; struct FunctionStatisticsInput { FunctionStatisticsInput(BoundFunctionExpression &expr_p, optional_ptr bind_data_p, vector &child_stats_p, unique_ptr *expr_ptr_p) : expr(expr_p), bind_data(bind_data_p), child_stats(child_stats_p), expr_ptr(expr_ptr_p) { } BoundFunctionExpression &expr; optional_ptr bind_data; vector &child_stats; unique_ptr *expr_ptr; }; //! The type used for scalar functions typedef std::function scalar_function_t; //! Binds the scalar function and creates the function data typedef unique_ptr (*bind_scalar_function_t)(ClientContext &context, ScalarFunction &bound_function, vector> &arguments); typedef unique_ptr (*init_local_state_t)(ExpressionState &state, const BoundFunctionExpression &expr, FunctionData *bind_data); typedef unique_ptr (*function_statistics_t)(ClientContext &context, FunctionStatisticsInput &input); //! Adds the dependencies of this BoundFunctionExpression to the set of dependencies typedef void (*dependency_function_t)(BoundFunctionExpression &expr, DependencyList &dependencies); typedef void (*function_serialize_t)(FieldWriter &writer, const FunctionData *bind_data, const ScalarFunction &function); typedef unique_ptr (*function_deserialize_t)(PlanDeserializationState &state, FieldReader &reader, ScalarFunction &function); class ScalarFunction : public BaseScalarFunction { public: DUCKDB_API ScalarFunction(string name, vector arguments, LogicalType return_type, scalar_function_t function, bind_scalar_function_t bind = nullptr, dependency_function_t dependency = nullptr, function_statistics_t statistics = nullptr, init_local_state_t init_local_state = nullptr, LogicalType varargs = LogicalType(LogicalTypeId::INVALID), FunctionSideEffects side_effects = FunctionSideEffects::NO_SIDE_EFFECTS, FunctionNullHandling null_handling = FunctionNullHandling::DEFAULT_NULL_HANDLING); DUCKDB_API ScalarFunction(vector arguments, LogicalType return_type, scalar_function_t function, bind_scalar_function_t bind = nullptr, dependency_function_t dependency = nullptr, function_statistics_t statistics = nullptr, init_local_state_t init_local_state = nullptr, LogicalType varargs = LogicalType(LogicalTypeId::INVALID), FunctionSideEffects side_effects = FunctionSideEffects::NO_SIDE_EFFECTS, FunctionNullHandling null_handling = FunctionNullHandling::DEFAULT_NULL_HANDLING); //! The main scalar function to execute scalar_function_t function; //! The bind function (if any) bind_scalar_function_t bind; //! Init thread local state for the function (if any) init_local_state_t init_local_state; //! The dependency function (if any) dependency_function_t dependency; //! The statistics propagation function (if any) function_statistics_t statistics; function_serialize_t serialize; function_deserialize_t deserialize; DUCKDB_API bool operator==(const ScalarFunction &rhs) const; DUCKDB_API bool operator!=(const ScalarFunction &rhs) const; DUCKDB_API bool Equal(const ScalarFunction &rhs) const; private: bool CompareScalarFunctionT(const scalar_function_t &other) const; public: DUCKDB_API static void NopFunction(DataChunk &input, ExpressionState &state, Vector &result); template static void UnaryFunction(DataChunk &input, ExpressionState &state, Vector &result) { D_ASSERT(input.ColumnCount() >= 1); UnaryExecutor::Execute(input.data[0], result, input.size()); } template static void BinaryFunction(DataChunk &input, ExpressionState &state, Vector &result) { D_ASSERT(input.ColumnCount() == 2); BinaryExecutor::ExecuteStandard(input.data[0], input.data[1], result, input.size()); } template static void TernaryFunction(DataChunk &input, ExpressionState &state, Vector &result) { D_ASSERT(input.ColumnCount() == 3); TernaryExecutor::ExecuteStandard(input.data[0], input.data[1], input.data[2], result, input.size()); } public: template static scalar_function_t GetScalarUnaryFunction(LogicalType type) { scalar_function_t function; switch (type.id()) { case LogicalTypeId::TINYINT: function = &ScalarFunction::UnaryFunction; break; case LogicalTypeId::SMALLINT: function = &ScalarFunction::UnaryFunction; break; case LogicalTypeId::INTEGER: function = &ScalarFunction::UnaryFunction; break; case LogicalTypeId::BIGINT: function = &ScalarFunction::UnaryFunction; break; case LogicalTypeId::UTINYINT: function = &ScalarFunction::UnaryFunction; break; case LogicalTypeId::USMALLINT: function = &ScalarFunction::UnaryFunction; break; case LogicalTypeId::UINTEGER: function = &ScalarFunction::UnaryFunction; break; case LogicalTypeId::UBIGINT: function = &ScalarFunction::UnaryFunction; break; case LogicalTypeId::HUGEINT: function = &ScalarFunction::UnaryFunction; break; case LogicalTypeId::FLOAT: function = &ScalarFunction::UnaryFunction; break; case LogicalTypeId::DOUBLE: function = &ScalarFunction::UnaryFunction; break; default: throw InternalException("Unimplemented type for GetScalarUnaryFunction"); } return function; } template static scalar_function_t GetScalarUnaryFunctionFixedReturn(LogicalType type) { scalar_function_t function; switch (type.id()) { case LogicalTypeId::TINYINT: function = &ScalarFunction::UnaryFunction; break; case LogicalTypeId::SMALLINT: function = &ScalarFunction::UnaryFunction; break; case LogicalTypeId::INTEGER: function = &ScalarFunction::UnaryFunction; break; case LogicalTypeId::BIGINT: function = &ScalarFunction::UnaryFunction; break; case LogicalTypeId::UTINYINT: function = &ScalarFunction::UnaryFunction; break; case LogicalTypeId::USMALLINT: function = &ScalarFunction::UnaryFunction; break; case LogicalTypeId::UINTEGER: function = &ScalarFunction::UnaryFunction; break; case LogicalTypeId::UBIGINT: function = &ScalarFunction::UnaryFunction; break; case LogicalTypeId::HUGEINT: function = &ScalarFunction::UnaryFunction; break; case LogicalTypeId::FLOAT: function = &ScalarFunction::UnaryFunction; break; case LogicalTypeId::DOUBLE: function = &ScalarFunction::UnaryFunction; break; default: throw InternalException("Unimplemented type for GetScalarUnaryFunctionFixedReturn"); } return function; } }; } // namespace duckdb