Skip to content

Commit

Permalink
Optimize performance of math::trunc
Browse files Browse the repository at this point in the history
Signed-off-by: Tai Le Manh <[email protected]>
  • Loading branch information
tlm365 committed Oct 13, 2024
1 parent 1582e8d commit 7e6f820
Show file tree
Hide file tree
Showing 3 changed files with 104 additions and 30 deletions.
5 changes: 5 additions & 0 deletions datafusion/functions/Cargo.toml
Original file line number Diff line number Diff line change
Expand Up @@ -186,3 +186,8 @@ required-features = ["unicode_expressions"]
harness = false
name = "strpos"
required-features = ["unicode_expressions"]

[[bench]]
harness = false
name = "trunc"
required-features = ["math_expressions"]
47 changes: 47 additions & 0 deletions datafusion/functions/benches/trunc.rs
Original file line number Diff line number Diff line change
@@ -0,0 +1,47 @@
// Licensed to the Apache Software Foundation (ASF) under one
// or more contributor license agreements. See the NOTICE file
// distributed with this work for additional information
// regarding copyright ownership. The ASF licenses this file
// to you under the Apache License, Version 2.0 (the
// "License"); you may not use this file except in compliance
// with the License. You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing,
// software distributed under the License is distributed on an
// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
// KIND, either express or implied. See the License for the
// specific language governing permissions and limitations
// under the License.

extern crate criterion;

use arrow::{
datatypes::{Float32Type, Float64Type},
util::bench_util::create_primitive_array,
};
use criterion::{black_box, criterion_group, criterion_main, Criterion};
use datafusion_expr::ColumnarValue;
use datafusion_functions::math::trunc;

use std::sync::Arc;

/// Criterion entry point: measures `trunc` invoked with a single array
/// argument, for Float32 and Float64 inputs of 1024, 4096 and 8192 elements.
fn criterion_benchmark(c: &mut Criterion) {
    let trunc_udf = trunc();

    for size in [1024, 4096, 8192] {
        // Inputs are built once per size, outside the timed closures, so the
        // measured work is the `invoke` call and the unwrap of its result.
        // NOTE(review): 0.2 is presumably the null density expected by
        // `create_primitive_array` — confirm against arrow's bench_util.
        let f32_input = Arc::new(create_primitive_array::<Float32Type>(size, 0.2));
        let f32_args = vec![ColumnarValue::Array(f32_input)];
        let f32_label = format!("trunc f32 array: {}", size);
        c.bench_function(&f32_label, |bencher| {
            bencher.iter(|| black_box(trunc_udf.invoke(&f32_args).unwrap()))
        });

        let f64_input = Arc::new(create_primitive_array::<Float64Type>(size, 0.2));
        let f64_args = vec![ColumnarValue::Array(f64_input)];
        let f64_label = format!("trunc f64 array: {}", size);
        c.bench_function(&f64_label, |bencher| {
            bencher.iter(|| black_box(trunc_udf.invoke(&f64_args).unwrap()))
        });
    }
}

criterion_group!(benches, criterion_benchmark);
criterion_main!(benches);
82 changes: 52 additions & 30 deletions datafusion/functions/src/math/trunc.rs
Original file line number Diff line number Diff line change
Expand Up @@ -20,11 +20,11 @@ use std::sync::Arc;

use crate::utils::make_scalar_function;

use arrow::array::{ArrayRef, Float32Array, Float64Array, Int64Array};
use arrow::datatypes::DataType;
use arrow::array::{ArrayRef, AsArray, PrimitiveArray};
use arrow::datatypes::DataType::{Float32, Float64};
use arrow::datatypes::{DataType, Float32Type, Float64Type, Int64Type};
use datafusion_common::ScalarValue::Int64;
use datafusion_common::{exec_err, DataFusionError, Result};
use datafusion_common::{exec_err, Result};
use datafusion_expr::sort_properties::{ExprProperties, SortProperties};
use datafusion_expr::TypeSignature::Exact;
use datafusion_expr::{ColumnarValue, ScalarUDFImpl, Signature, Volatility};
Expand Down Expand Up @@ -111,44 +111,66 @@ fn trunc(args: &[ArrayRef]) -> Result<ArrayRef> {
);
}

//if only one arg then invoke toolchain trunc(num) and precision = 0 by default
//or then invoke the compute_truncate method to process precision
// If only one argument is given, precision defaults to 0 and the standard trunc(num) is used;
// otherwise the compute_truncate method is invoked to apply the given precision
let num = &args[0];
let precision = if args.len() == 1 {
ColumnarValue::Scalar(Int64(Some(0)))
} else {
ColumnarValue::Array(Arc::clone(&args[1]))
};

match args[0].data_type() {
match num.data_type() {
Float64 => match precision {
ColumnarValue::Scalar(Int64(Some(0))) => Ok(Arc::new(
make_function_scalar_inputs!(num, "num", Float64Array, { f64::trunc }),
) as ArrayRef),
ColumnarValue::Array(precision) => Ok(Arc::new(make_function_inputs2!(
num,
precision,
"x",
"y",
Float64Array,
Int64Array,
{ compute_truncate64 }
)) as ArrayRef),
ColumnarValue::Scalar(Int64(Some(0))) => {
Ok(Arc::new(
args[0]
.as_primitive::<Float64Type>()
.unary::<_, Float64Type>(|x: f64| {
if x == 0_f64 {
0_f64
} else {
x.trunc()
}
}),
) as ArrayRef)
}
ColumnarValue::Array(precision) => {
let num_array = num.as_primitive::<Float64Type>();
let precision_array = precision.as_primitive::<Int64Type>();
let result: PrimitiveArray<Float64Type> =
arrow::compute::binary(num_array, precision_array, |x, y| {
compute_truncate64(x, y)
})?;

Ok(Arc::new(result) as ArrayRef)
}
_ => exec_err!("trunc function requires a scalar or array for precision"),
},
Float32 => match precision {
ColumnarValue::Scalar(Int64(Some(0))) => Ok(Arc::new(
make_function_scalar_inputs!(num, "num", Float32Array, { f32::trunc }),
) as ArrayRef),
ColumnarValue::Array(precision) => Ok(Arc::new(make_function_inputs2!(
num,
precision,
"x",
"y",
Float32Array,
Int64Array,
{ compute_truncate32 }
)) as ArrayRef),
ColumnarValue::Scalar(Int64(Some(0))) => {
Ok(Arc::new(
args[0]
.as_primitive::<Float32Type>()
.unary::<_, Float32Type>(|x: f32| {
if x == 0_f32 {
0_f32
} else {
x.trunc()
}
}),
) as ArrayRef)
}
ColumnarValue::Array(precision) => {
let num_array = num.as_primitive::<Float32Type>();
let precision_array = precision.as_primitive::<Int64Type>();
let result: PrimitiveArray<Float32Type> =
arrow::compute::binary(num_array, precision_array, |x, y| {
compute_truncate32(x, y)
})?;

Ok(Arc::new(result) as ArrayRef)
}
_ => exec_err!("trunc function requires a scalar or array for precision"),
},
other => exec_err!("Unsupported data type {other:?} for function trunc"),
Expand Down

0 comments on commit 7e6f820

Please sign in to comment.