Skip to content

Commit

Permalink
Merge branch 'apache:main' into main
Browse files Browse the repository at this point in the history
  • Loading branch information
mustafasrepo authored Feb 28, 2024
2 parents 38db3d8 + 32d906f commit e21ac2b
Show file tree
Hide file tree
Showing 22 changed files with 358 additions and 152 deletions.
193 changes: 92 additions & 101 deletions datafusion-cli/Cargo.lock

Large diffs are not rendered by default.

33 changes: 33 additions & 0 deletions datafusion/core/src/execution/context/mod.rs
Original file line number Diff line number Diff line change
Expand Up @@ -849,6 +849,21 @@ impl SessionContext {
self.state.write().register_udwf(Arc::new(f)).ok();
}

/// Deregisters a UDF within this context.
pub fn deregister_udf(&self, name: &str) {
self.state.write().deregister_udf(name).ok();
}

/// Deregisters a UDAF within this context.
pub fn deregister_udaf(&self, name: &str) {
self.state.write().deregister_udaf(name).ok();
}

/// Deregisters a UDWF within this context.
pub fn deregister_udwf(&self, name: &str) {
self.state.write().deregister_udwf(name).ok();
}

/// Creates a [`DataFrame`] for reading a data source.
///
/// For more control such as reading multiple files, you can use
Expand Down Expand Up @@ -2026,6 +2041,24 @@ impl FunctionRegistry for SessionState {
fn register_udwf(&mut self, udwf: Arc<WindowUDF>) -> Result<Option<Arc<WindowUDF>>> {
Ok(self.window_functions.insert(udwf.name().into(), udwf))
}

fn deregister_udf(&mut self, name: &str) -> Result<Option<Arc<ScalarUDF>>> {
let udf = self.scalar_functions.remove(name);
if let Some(udf) = &udf {
for alias in udf.aliases() {
self.scalar_functions.remove(alias);
}
}
Ok(udf)
}

fn deregister_udaf(&mut self, name: &str) -> Result<Option<Arc<AggregateUDF>>> {
Ok(self.aggregate_functions.remove(name))
}

fn deregister_udwf(&mut self, name: &str) -> Result<Option<Arc<WindowUDF>>> {
Ok(self.window_functions.remove(name))
}
}

impl OptimizerConfig for SessionState {
Expand Down
23 changes: 23 additions & 0 deletions datafusion/core/tests/user_defined/user_defined_aggregates.rs
Original file line number Diff line number Diff line change
Expand Up @@ -255,6 +255,29 @@ async fn simple_udaf() -> Result<()> {
Ok(())
}

#[tokio::test]
async fn deregister_udaf() -> Result<()> {
let ctx = SessionContext::new();
let my_avg = create_udaf(
"my_avg",
vec![DataType::Float64],
Arc::new(DataType::Float64),
Volatility::Immutable,
Arc::new(|_| Ok(Box::<AvgAccumulator>::default())),
Arc::new(vec![DataType::UInt64, DataType::Float64]),
);

ctx.register_udaf(my_avg.clone());

assert!(ctx.state().aggregate_functions().contains_key("my_avg"));

ctx.deregister_udaf("my_avg");

assert!(!ctx.state().aggregate_functions().contains_key("my_avg"));

Ok(())
}

#[tokio::test]
async fn case_sensitive_identifiers_user_defined_aggregates() -> Result<()> {
let ctx = SessionContext::new();
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -498,6 +498,22 @@ async fn test_user_defined_functions_zero_argument() -> Result<()> {
Ok(())
}

#[tokio::test]
async fn deregister_udf() -> Result<()> {
let random_normal_udf = ScalarUDF::from(RandomUDF::new());
let ctx = SessionContext::new();

ctx.register_udf(random_normal_udf.clone());

assert!(ctx.udfs().contains("random_udf"));

ctx.deregister_udf("random_udf");

assert!(!ctx.udfs().contains("random_udf"));

Ok(())
}

#[derive(Debug)]
struct TakeUDF {
signature: Signature,
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -103,6 +103,21 @@ async fn test_udwf() {
assert_eq!(test_state.evaluate_all_called(), 2);
}

#[tokio::test]
async fn test_deregister_udwf() -> Result<()> {
let test_state = Arc::new(TestState::new());
let mut ctx = SessionContext::new();
OddCounter::register(&mut ctx, Arc::clone(&test_state));

assert!(ctx.state().window_functions().contains_key("odd_counter"));

ctx.deregister_udwf("odd_counter");

assert!(!ctx.state().window_functions().contains_key("odd_counter"));

Ok(())
}

/// Basic user defined window function with bounded window
#[tokio::test]
async fn test_udwf_bounded_window_ignores_frame() {
Expand Down
27 changes: 27 additions & 0 deletions datafusion/execution/src/registry.rs
Original file line number Diff line number Diff line change
Expand Up @@ -66,6 +66,33 @@ pub trait FunctionRegistry {
fn register_udwf(&mut self, _udaf: Arc<WindowUDF>) -> Result<Option<Arc<WindowUDF>>> {
not_impl_err!("Registering WindowUDF")
}

/// Deregisters a [`ScalarUDF`], returning the implementation that was
/// deregistered.
///
/// Returns an error (the default) if the function can not be deregistered,
/// for example if the registry is read only.
fn deregister_udf(&mut self, _name: &str) -> Result<Option<Arc<ScalarUDF>>> {
not_impl_err!("Deregistering ScalarUDF")
}

/// Deregisters a [`AggregateUDF`], returning the implementation that was
/// deregistered.
///
/// Returns an error (the default) if the function can not be deregistered,
/// for example if the registry is read only.
fn deregister_udaf(&mut self, _name: &str) -> Result<Option<Arc<AggregateUDF>>> {
not_impl_err!("Deregistering AggregateUDF")
}

/// Deregisters a [`WindowUDF`], returning the implementation that was
/// deregistered.
///
/// Returns an error (the default) if the function can not be deregistered,
/// for example if the registry is read only.
fn deregister_udwf(&mut self, _name: &str) -> Result<Option<Arc<WindowUDF>>> {
not_impl_err!("Deregistering WindowUDF")
}
}

/// Serializer and deserializer registry for extensions like [UserDefinedLogicalNode].
Expand Down
10 changes: 2 additions & 8 deletions datafusion/expr/src/built_in_function.rs
Original file line number Diff line number Diff line change
Expand Up @@ -42,8 +42,6 @@ use strum_macros::EnumIter;
#[derive(Debug, Clone, PartialEq, Eq, Hash, EnumIter, Copy)]
pub enum BuiltinScalarFunction {
// math functions
/// acos
Acos,
/// asin
Asin,
/// atan
Expand Down Expand Up @@ -362,7 +360,6 @@ impl BuiltinScalarFunction {
pub fn volatility(&self) -> Volatility {
match self {
// Immutable scalar builtins
BuiltinScalarFunction::Acos => Volatility::Immutable,
BuiltinScalarFunction::Asin => Volatility::Immutable,
BuiltinScalarFunction::Atan => Volatility::Immutable,
BuiltinScalarFunction::Atan2 => Volatility::Immutable,
Expand Down Expand Up @@ -873,8 +870,7 @@ impl BuiltinScalarFunction {
utf8_to_int_type(&input_expr_types[0], "levenshtein")
}

BuiltinScalarFunction::Acos
| BuiltinScalarFunction::Asin
BuiltinScalarFunction::Asin
| BuiltinScalarFunction::Atan
| BuiltinScalarFunction::Acosh
| BuiltinScalarFunction::Asinh
Expand Down Expand Up @@ -1346,8 +1342,7 @@ impl BuiltinScalarFunction {
vec![Exact(vec![Utf8, Utf8]), Exact(vec![LargeUtf8, LargeUtf8])],
self.volatility(),
),
BuiltinScalarFunction::Acos
| BuiltinScalarFunction::Asin
BuiltinScalarFunction::Asin
| BuiltinScalarFunction::Atan
| BuiltinScalarFunction::Acosh
| BuiltinScalarFunction::Asinh
Expand Down Expand Up @@ -1438,7 +1433,6 @@ impl BuiltinScalarFunction {
/// Returns all names that can be used to call this function
pub fn aliases(&self) -> &'static [&'static str] {
match self {
BuiltinScalarFunction::Acos => &["acos"],
BuiltinScalarFunction::Acosh => &["acosh"],
BuiltinScalarFunction::Asin => &["asin"],
BuiltinScalarFunction::Asinh => &["asinh"],
Expand Down
2 changes: 0 additions & 2 deletions datafusion/expr/src/expr_fn.rs
Original file line number Diff line number Diff line change
Expand Up @@ -531,7 +531,6 @@ scalar_expr!(Sinh, sinh, num, "hyperbolic sine");
scalar_expr!(Cosh, cosh, num, "hyperbolic cosine");
scalar_expr!(Tanh, tanh, num, "hyperbolic tangent");
scalar_expr!(Asin, asin, num, "inverse sine");
scalar_expr!(Acos, acos, num, "inverse cosine");
scalar_expr!(Atan, atan, num, "inverse tangent");
scalar_expr!(Asinh, asinh, num, "inverse hyperbolic sine");
scalar_expr!(Acosh, acosh, num, "inverse hyperbolic cosine");
Expand Down Expand Up @@ -1339,7 +1338,6 @@ mod test {
test_unary_scalar_expr!(Cosh, cosh);
test_unary_scalar_expr!(Tanh, tanh);
test_unary_scalar_expr!(Asin, asin);
test_unary_scalar_expr!(Acos, acos);
test_unary_scalar_expr!(Atan, atan);
test_unary_scalar_expr!(Asinh, asinh);
test_unary_scalar_expr!(Acosh, acosh);
Expand Down
6 changes: 1 addition & 5 deletions datafusion/functions/src/core/nvl.rs
Original file line number Diff line number Diff line change
Expand Up @@ -73,11 +73,7 @@ impl ScalarUDFImpl for NVLFunc {
}

fn return_type(&self, arg_types: &[DataType]) -> Result<DataType> {
// NVL has two args and they might get coerced, get a preview of this
let coerced_types = datafusion_expr::type_coercion::functions::data_types(arg_types, &self.signature);
coerced_types.map(|typs| typs[0].clone())
.map_err(|e| e.context("Failed to coerce arguments for NVL")
)
Ok(arg_types[0].clone())
}

fn invoke(&self, args: &[ColumnarValue]) -> Result<ColumnarValue> {
Expand Down
110 changes: 110 additions & 0 deletions datafusion/functions/src/math/acos.rs
Original file line number Diff line number Diff line change
@@ -0,0 +1,110 @@
// Licensed to the Apache Software Foundation (ASF) under one
// or more contributor license agreements. See the NOTICE file
// distributed with this work for additional information
// regarding copyright ownership. The ASF licenses this file
// to you under the Apache License, Version 2.0 (the
// "License"); you may not use this file except in compliance
// with the License. You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing,
// software distributed under the License is distributed on an
// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
// KIND, either express or implied. See the License for the
// specific language governing permissions and limitations
// under the License.

//! Math function: `acos()`.
use arrow::array::{ArrayRef, Float32Array, Float64Array};
use arrow::datatypes::DataType;
use datafusion_common::{exec_err, plan_datafusion_err, DataFusionError, Result};
use datafusion_expr::ColumnarValue;
use datafusion_expr::{
utils::generate_signature_error_msg, ScalarUDFImpl, Signature, Volatility,
};
use std::any::Any;
use std::sync::Arc;

#[derive(Debug)]
pub struct AcosFunc {
signature: Signature,
}

impl AcosFunc {
pub fn new() -> Self {
use DataType::*;
Self {
signature: Signature::uniform(
1,
vec![Float64, Float32],
Volatility::Immutable,
),
}
}
}

impl ScalarUDFImpl for AcosFunc {
fn as_any(&self) -> &dyn Any {
self
}
fn name(&self) -> &str {
"acos"
}

fn signature(&self) -> &Signature {
&self.signature
}

fn return_type(&self, arg_types: &[DataType]) -> Result<DataType> {
if arg_types.len() != 1 {
return Err(plan_datafusion_err!(
"{}",
generate_signature_error_msg(
self.name(),
self.signature().clone(),
arg_types,
)
));
}

let arg_type = &arg_types[0];

match arg_type {
DataType::Float64 => Ok(DataType::Float64),
DataType::Float32 => Ok(DataType::Float32),

// For other types (possible values null/int), use Float 64
_ => Ok(DataType::Float64),
}
}

fn invoke(&self, args: &[ColumnarValue]) -> Result<ColumnarValue> {
let args = ColumnarValue::values_to_arrays(args)?;

let arr: ArrayRef = match args[0].data_type() {
DataType::Float64 => Arc::new(make_function_scalar_inputs_return_type!(
&args[0],
self.name(),
Float64Array,
Float64Array,
{ f64::acos }
)),
DataType::Float32 => Arc::new(make_function_scalar_inputs_return_type!(
&args[0],
self.name(),
Float32Array,
Float32Array,
{ f32::acos }
)),
other => {
return exec_err!(
"Unsupported data type {other:?} for function {}",
self.name()
)
}
};
Ok(ColumnarValue::Array(arr))
}
}
19 changes: 15 additions & 4 deletions datafusion/functions/src/math/mod.rs
Original file line number Diff line number Diff line change
Expand Up @@ -17,15 +17,26 @@

//! "math" DataFusion functions
mod nans;
mod abs;
mod acos;
mod nans;

// create UDFs
make_udf_function!(nans::IsNanFunc, ISNAN, isnan);
make_udf_function!(abs::AbsFunc, ABS, abs);
make_udf_function!(acos::AcosFunc, ACOS, acos);

// Export the functions out of this package, both as expr_fn as well as a list of functions
export_functions!(
(isnan, num, "returns true if a given number is +NaN or -NaN otherwise returns false"),
(abs, num, "returns the absolute value of a given number")
);
(
isnan,
num,
"returns true if a given number is +NaN or -NaN otherwise returns false"
),
(abs, num, "returns the absolute value of a given number"),
(
acos,
num,
"returns the arc cosine or inverse cosine of a number"
)
);
2 changes: 1 addition & 1 deletion datafusion/functions/src/math/nans.rs
Original file line number Diff line number Diff line change
Expand Up @@ -15,7 +15,7 @@
// specific language governing permissions and limitations
// under the License.

//! Encoding expressions
//! Math function: `isnan()`.
use arrow::datatypes::DataType;
use datafusion_common::{exec_err, DataFusionError, Result};
Expand Down
4 changes: 2 additions & 2 deletions datafusion/optimizer/src/analyzer/type_coercion.rs
Original file line number Diff line number Diff line change
Expand Up @@ -889,14 +889,14 @@ mod test {
// test that automatic argument type coercion for scalar functions work
let empty = empty();
let lit_expr = lit(10i64);
let fun: BuiltinScalarFunction = BuiltinScalarFunction::Acos;
let fun: BuiltinScalarFunction = BuiltinScalarFunction::Floor;
let scalar_function_expr =
Expr::ScalarFunction(ScalarFunction::new(fun, vec![lit_expr]));
let plan = LogicalPlan::Projection(Projection::try_new(
vec![scalar_function_expr],
empty,
)?);
let expected = "Projection: acos(CAST(Int64(10) AS Float64))\n EmptyRelation";
let expected = "Projection: floor(CAST(Int64(10) AS Float64))\n EmptyRelation";
assert_analyzed_plan_eq(Arc::new(TypeCoercion::new()), &plan, expected)
}

Expand Down
Loading

0 comments on commit e21ac2b

Please sign in to comment.