Skip to content

Commit

Permalink
[feat] Add Ability to Generate Function Visibility to arrow-udf (#52)
Browse files Browse the repository at this point in the history
This PR will add an additional meta parameter `visibility` to
`arrow-udf`. I might want this to be added while working on
apache/datafusion#11413. Sometimes it is
better to reference the symbol directly instead of using the function
registry.

---------

Co-authored-by: Runji Wang <[email protected]>
  • Loading branch information
xinlifoobar and wangrunji0408 authored Jul 22, 2024
1 parent e379764 commit 92fdc16
Show file tree
Hide file tree
Showing 7 changed files with 239 additions and 8 deletions.
15 changes: 13 additions & 2 deletions arrow-udf-macros/src/gen.rs
Original file line number Diff line number Diff line change
Expand Up @@ -108,6 +108,17 @@ impl FunctionAttr {
user_fn: &UserFunctionAttr,
eval_fn_name: &Ident,
) -> Result<TokenStream2> {
let fn_with_visibility = if let Some(visiblity) = &self.visibility {
// handle the scope of the visibility by parsing the visibility string
match syn::parse_str::<syn::Visibility>(visiblity)? {
syn::Visibility::Public(token) => quote! { #token fn },
syn::Visibility::Restricted(vis_restricted) => quote! { #vis_restricted fn },
syn::Visibility::Inherited => quote! { fn },
}
} else {
quote! { fn }
};

let variadic = matches!(self.args.last(), Some(t) if t == "...");
let num_args = self.args.len() - if variadic { 1 } else { 0 };
let user_fn_name = format_ident!("{}", user_fn.name);
Expand Down Expand Up @@ -420,7 +431,7 @@ impl FunctionAttr {

Ok(if self.is_table_function {
quote! {
fn #eval_fn_name<'a>(input: &'a ::arrow_udf::codegen::arrow_array::RecordBatch)
#fn_with_visibility #eval_fn_name<'a>(input: &'a ::arrow_udf::codegen::arrow_array::RecordBatch)
-> ::arrow_udf::Result<Box<dyn Iterator<Item = ::arrow_udf::codegen::arrow_array::RecordBatch> + 'a>>
{
const BATCH_SIZE: usize = 1024;
Expand All @@ -432,7 +443,7 @@ impl FunctionAttr {
}
} else {
quote! {
fn #eval_fn_name(input: &::arrow_udf::codegen::arrow_array::RecordBatch)
#fn_with_visibility #eval_fn_name(input: &::arrow_udf::codegen::arrow_array::RecordBatch)
-> ::arrow_udf::Result<::arrow_udf::codegen::arrow_array::RecordBatch>
{
#downcast_arrays
Expand Down
2 changes: 2 additions & 0 deletions arrow-udf-macros/src/lib.rs
Original file line number Diff line number Diff line change
Expand Up @@ -323,6 +323,8 @@ struct FunctionAttr {
/// Generated batch function name.
/// If not specified, the macro will not generate batch function.
output: Option<String>,
/// Customized function visibility.
visibility: Option<String>,
}

/// Attributes from function signature `fn(..)`
Expand Down
2 changes: 2 additions & 0 deletions arrow-udf-macros/src/parse.rs
Original file line number Diff line number Diff line change
Expand Up @@ -84,6 +84,8 @@ impl Parse for FunctionAttr {
parsed.volatile = true;
} else if meta.path().is_ident("append_only") {
parsed.append_only = true;
} else if meta.path().is_ident("visibility") {
parsed.visibility = Some(get_value()?);
} else {
return Err(Error::new(
meta.span(),
Expand Down
15 changes: 15 additions & 0 deletions arrow-udf/tests/cases/mod.rs
Original file line number Diff line number Diff line change
@@ -0,0 +1,15 @@
// Copyright 2024 RisingWave Labs
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.

pub mod visibility_tests;
142 changes: 142 additions & 0 deletions arrow-udf/tests/cases/visibility_tests.rs
Original file line number Diff line number Diff line change
@@ -0,0 +1,142 @@
// Copyright 2024 RisingWave Labs
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.

use std::sync::Arc;

use crate::common::check;
use arrow_array::{Float32Array, Float64Array, Int32Array, RecordBatch, StringArray, UInt32Array};
use arrow_schema::{DataType, Field, Schema};
use arrow_udf::function;
use expect_test::expect;

// test visibility
#[function("maybe_visible(int) -> int", output = "maybe_visible_udf")]
#[function(
"maybe_visible(uint32) -> uint32",
output = "maybe_visible_pub_udf",
visibility = "pub"
)]
#[function(
"maybe_visible(float32) -> float32",
output = "maybe_visible_pub_crate_udf",
visibility = "pub(crate)"
)]
#[function(
"maybe_visible(float64) -> float64",
output = "maybe_visible_pub_self_udf",
visibility = "pub(self)"
)]
#[function(
"maybe_visible(string) -> string",
output = "maybe_visible_pub_super_udf",
visibility = "pub(super)"
)]
fn maybe_visible<T>(x: T) -> T {
x
}

#[test]
fn test_default() {
let schema = Schema::new(vec![Field::new("int", DataType::Int32, true)]);
let arg0 = Int32Array::from(vec![Some(1), None]);
let input = RecordBatch::try_new(Arc::new(schema), vec![Arc::new(arg0)]).unwrap();

let output = maybe_visible_udf(&input).unwrap();
check(
&[output],
expect![[r#"
+---------------+
| maybe_visible |
+---------------+
| 1 |
| |
+---------------+"#]],
);
}

#[test]
fn test_pub() {
let schema = Schema::new(vec![Field::new("uint32", DataType::UInt32, true)]);
let arg0 = UInt32Array::from(vec![Some(1), None]);
let input = RecordBatch::try_new(Arc::new(schema), vec![Arc::new(arg0)]).unwrap();

let output = maybe_visible_pub_udf(&input).unwrap();
check(
&[output],
expect![[r#"
+---------------+
| maybe_visible |
+---------------+
| 1 |
| |
+---------------+"#]],
);
}

#[test]
fn test_pub_crate() {
let schema = Schema::new(vec![Field::new("float32", DataType::Float32, true)]);
let arg0 = Float32Array::from(vec![Some(1.0), None]);
let input = RecordBatch::try_new(Arc::new(schema), vec![Arc::new(arg0)]).unwrap();

let output = maybe_visible_pub_crate_udf(&input).unwrap();
check(
&[output],
expect![[r#"
+---------------+
| maybe_visible |
+---------------+
| 1.0 |
| |
+---------------+"#]],
);
}

#[test]
fn test_pub_self() {
let schema = Schema::new(vec![Field::new("float64", DataType::Float64, true)]);
let arg0 = Float64Array::from(vec![Some(1.0), None]);
let input = RecordBatch::try_new(Arc::new(schema), vec![Arc::new(arg0)]).unwrap();

let output = maybe_visible_pub_self_udf(&input).unwrap();
check(
&[output],
expect![[r#"
+---------------+
| maybe_visible |
+---------------+
| 1.0 |
| |
+---------------+"#]],
);
}

#[test]
fn test_pub_super() {
let schema = Schema::new(vec![Field::new("string", DataType::Utf8, true)]);
let arg0 = StringArray::from(vec![Some("1.0"), None]);
let input = RecordBatch::try_new(Arc::new(schema), vec![Arc::new(arg0)]).unwrap();

let output = maybe_visible_pub_super_udf(&input).unwrap();
check(
&[output],
expect![[r#"
+---------------+
| maybe_visible |
+---------------+
| 1.0 |
| |
+---------------+"#]],
);
}
23 changes: 23 additions & 0 deletions arrow-udf/tests/common/mod.rs
Original file line number Diff line number Diff line change
@@ -0,0 +1,23 @@
// Copyright 2024 RisingWave Labs
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.

use arrow_array::RecordBatch;
use arrow_cast::pretty::pretty_format_batches;
use expect_test::Expect;

/// Compare the actual output with the expected output.
#[track_caller]
pub fn check(actual: &[RecordBatch], expect: Expect) {
expect.assert_eq(&pretty_format_batches(actual).unwrap().to_string());
}
48 changes: 42 additions & 6 deletions arrow-udf/tests/tests.rs
Original file line number Diff line number Diff line change
Expand Up @@ -20,11 +20,15 @@ use arrow_array::cast::AsArray;
use arrow_array::temporal_conversions::time_to_time64us;
use arrow_array::types::{Date32Type, Int32Type};
use arrow_array::*;
use arrow_cast::pretty::pretty_format_batches;
use arrow_schema::{DataType, Field, Schema, TimeUnit};
use arrow_udf::function;
use arrow_udf::types::*;
use expect_test::{expect, Expect};
use cases::visibility_tests::{maybe_visible_pub_crate_udf, maybe_visible_pub_udf};
use common::check;
use expect_test::expect;

mod cases;
mod common;

// test no return value
#[function("null()")]
Expand Down Expand Up @@ -670,10 +674,42 @@ fn test_json_array_elements() {
);
}

/// Compare the actual output with the expected output.
#[track_caller]
fn check(actual: &[RecordBatch], expect: Expect) {
expect.assert_eq(&pretty_format_batches(actual).unwrap().to_string());
#[test]
fn test_pub() {
let schema = Schema::new(vec![Field::new("uint32", DataType::UInt32, true)]);
let arg0 = UInt32Array::from(vec![Some(1), None]);
let input = RecordBatch::try_new(Arc::new(schema), vec![Arc::new(arg0)]).unwrap();

let output = maybe_visible_pub_udf(&input).unwrap();
check(
&[output],
expect![[r#"
+---------------+
| maybe_visible |
+---------------+
| 1 |
| |
+---------------+"#]],
);
}

#[test]
fn test_pub_crate() {
let schema = Schema::new(vec![Field::new("float32", DataType::Float32, true)]);
let arg0 = Float32Array::from(vec![Some(1.0), None]);
let input = RecordBatch::try_new(Arc::new(schema), vec![Arc::new(arg0)]).unwrap();

let output = maybe_visible_pub_crate_udf(&input).unwrap();
check(
&[output],
expect![[r#"
+---------------+
| maybe_visible |
+---------------+
| 1.0 |
| |
+---------------+"#]],
);
}

/// Returns a field with JSON type.
Expand Down

0 comments on commit 92fdc16

Please sign in to comment.