diff --git a/vegafusion-common/src/data/table.rs b/vegafusion-common/src/data/table.rs index 50568364..8a0ce9b4 100644 --- a/vegafusion-common/src/data/table.rs +++ b/vegafusion-common/src/data/table.rs @@ -34,9 +34,9 @@ use { use { arrow::pyarrow::{FromPyArrow, ToPyArrow}, pyo3::{ - prelude::PyModule, + prelude::*, types::{PyList, PyTuple}, - PyAny, PyErr, PyObject, Python, + Bound, PyAny, PyErr, PyObject, Python, }, }; @@ -268,18 +268,17 @@ impl VegaFusionTable { } #[cfg(feature = "pyarrow")] - pub fn from_pyarrow(py: Python, pyarrow_table: &PyAny) -> std::result::Result { + pub fn from_pyarrow(pyarrow_table: &Bound) -> std::result::Result { // Extract table.schema as a Rust Schema - let getattr_args = PyTuple::new(py, vec!["schema"]); - let schema_object = pyarrow_table.call_method1("__getattribute__", getattr_args)?; - let schema = Schema::from_pyarrow(schema_object)?; + let schema_object = pyarrow_table.getattr("schema")?; + let schema = Schema::from_pyarrow_bound(&schema_object)?; // Extract table.to_batches() as a Rust Vec let batches_object = pyarrow_table.call_method0("to_batches")?; let batches_list = batches_object.downcast::()?; let batches = batches_list .iter() - .map(|batch_any| Ok(RecordBatch::from_pyarrow(batch_any)?)) + .map(|batch_any| Ok(RecordBatch::from_pyarrow_bound(&batch_any)?)) .collect::>>()?; Ok(VegaFusionTable::try_new(Arc::new(schema), batches)?) @@ -288,14 +287,16 @@ impl VegaFusionTable { #[cfg(feature = "pyarrow")] pub fn to_pyarrow(&self, py: Python) -> std::result::Result { // Convert table's record batches into Python list of pyarrow batches - let pyarrow_module = PyModule::import(py, "pyarrow")?; + + use pyo3::types::PyAnyMethods; + let pyarrow_module = PyModule::import_bound(py, "pyarrow")?; let table_cls = pyarrow_module.getattr("Table")?; let batch_objects = self .batches .iter() .map(|batch| Ok(batch.to_pyarrow(py)?)) .collect::>>()?; - let batches_list = PyList::new(py, batch_objects); + let batches_list = PyList::new_bound(py, batch_objects); // Convert table's schema into pyarrow schema let schema = if let Some(batch) = self.batches.first() { @@ -308,7 +309,7 @@ impl VegaFusionTable { let schema_object = schema.to_pyarrow(py)?; // Build pyarrow table - let args = PyTuple::new(py, vec![batches_list.as_ref(), schema_object.as_ref(py)]); + let args = PyTuple::new_bound(py, vec![&batches_list, schema_object.bind(py)]); let pa_table = table_cls.call_method1("from_batches", args)?; Ok(PyObject::from(pa_table)) } diff --git a/vegafusion-python-embed/src/connection.rs b/vegafusion-python-embed/src/connection.rs index b14fd19c..cba6dfd5 100644 --- a/vegafusion-python-embed/src/connection.rs +++ b/vegafusion-python-embed/src/connection.rs @@ -49,13 +49,13 @@ fn get_dialect_and_fallback_connection( fn perform_fetch_query(query: &str, schema: &Schema, conn: &PyObject) -> Result { let table = Python::with_gil(|py| -> std::result::Result<_, PyErr> { - let query_object = PyString::new(py, query); + let query_object = PyString::new_bound(py, query); let query_object = query_object.as_ref(); let schema_object = schema.to_pyarrow(py)?; - let schema_object = schema_object.as_ref(py); - let args = PyTuple::new(py, vec![query_object, schema_object]); - let table_object = conn.call_method(py, "fetch_query", args, None)?; - VegaFusionTable::from_pyarrow(py, table_object.as_ref(py)) + let schema_object = schema_object.bind(py); + let args = PyTuple::new_bound(py, vec![query_object, schema_object]); + let table_object = conn.call_method_bound(py, "fetch_query", args, None)?; + VegaFusionTable::from_pyarrow(table_object.bind(py)) })?; Ok(table) } @@ -91,14 +91,14 @@ impl Connection for PySqlConnection { async fn tables(&self) -> Result> { let tables = Python::with_gil(|py| -> std::result::Result<_, PyErr> { let tables_object = self.conn.call_method0(py, "tables")?; - let tables_dict = tables_object.downcast::(py)?; + let tables_dict = tables_object.downcast_bound::(py)?; let mut tables: HashMap = HashMap::new(); for key in tables_dict.keys() { - let value = tables_dict.get_item(key)?.unwrap(); + let value = tables_dict.get_item(key.clone())?.unwrap(); let key_string = key.extract::()?; - let value_schema = Schema::from_pyarrow(value)?; + let value_schema = Schema::from_pyarrow_bound(&value)?; tables.insert(key_string, value_schema); } Ok(tables) @@ -130,12 +130,14 @@ impl Connection for PySqlConnection { // Register table with Python connection let table_name_object = table_name.clone().into_py(py); let is_temporary_object = true.into_py(py); - let args = PyTuple::new(py, vec![table_name_object, pa_table, is_temporary_object]); + let args = + PyTuple::new_bound(py, vec![table_name_object, pa_table, is_temporary_object]); match self.conn.call_method1(py, "register_arrow", args) { Ok(_) => {} Err(err) => { - let exception_name = err.get_type(py).name()?; + let err_bound = err.get_type_bound(py); + let exception_name = err_bound.name()?; // Check if we have a fallback connection and this is a RegistrationNotSupportedError if let Some(fallback_connection) = &self.fallback_conn { @@ -172,7 +174,7 @@ impl Connection for PySqlConnection { let inner_opts = opts.clone(); let fallback_connection = Python::with_gil(|py| -> std::result::Result<_, PyErr> { // Build Python CsvReadOptions - let vegafusion_module = PyModule::import(py, "vegafusion.connection")?; + let vegafusion_module = PyModule::import_bound(py, "vegafusion.connection")?; let csv_opts_class = vegafusion_module.getattr("CsvReadOptions")?; let pyschema = inner_opts @@ -188,28 +190,29 @@ impl Connection for PySqlConnection { ("file_extension", inner_opts.file_extension.into_py(py)), ("schema", pyschema), ] - .into_py_dict(py); - let args = PyTuple::empty(py); - let csv_opts = csv_opts_class.call(args, Some(kwargs))?; + .into_py_dict_bound(py); + let args = PyTuple::empty_bound(py); + let csv_opts = csv_opts_class.call(args, Some(&kwargs))?; // Register table with Python connection let table_name_object = table_name.clone().into_py(py); let path_name_object = url.to_string().into_py(py); let is_temporary_object = true.into_py(py); - let args = PyTuple::new( + let args = PyTuple::new_bound( py, vec![ - table_name_object.as_ref(py), - path_name_object.as_ref(py), - csv_opts, - is_temporary_object.as_ref(py), + table_name_object.bind(py), + path_name_object.bind(py), + &csv_opts, + is_temporary_object.bind(py), ], ); match self.conn.call_method1(py, "register_csv", args) { Ok(_) => {} Err(err) => { - let exception_name = err.get_type(py).name()?; + let err_bound = err.get_type_bound(py); + let exception_name = err_bound.name()?; // Check if we have a fallback connection and this is a RegistrationNotSupportedError if let Some(fallback_connection) = &self.fallback_conn { @@ -249,18 +252,19 @@ impl Connection for PySqlConnection { let path_name_object = path.to_string().into_py(py); let is_temporary_object = true.into_py(py); - let args = PyTuple::new( + let args = PyTuple::new_bound( py, vec![ - table_name_object.as_ref(py), - path_name_object.as_ref(py), - is_temporary_object.as_ref(py), + table_name_object.bind(py), + path_name_object.bind(py), + is_temporary_object.bind(py), ], ); match self.conn.call_method1(py, "register_arrow_file", args) { Ok(_) => {} Err(err) => { - let exception_name = err.get_type(py).name()?; + let err_bound = err.get_type_bound(py); + let exception_name = err_bound.name()?; // Check if we have a fallback connection and this is a RegistrationNotSupportedError if let Some(fallback_connection) = &self.fallback_conn { @@ -300,18 +304,19 @@ impl Connection for PySqlConnection { let path_name_object = path.to_string().into_py(py); let is_temporary_object = true.into_py(py); - let args = PyTuple::new( + let args = PyTuple::new_bound( py, vec![ - table_name_object.as_ref(py), - path_name_object.as_ref(py), - is_temporary_object.as_ref(py), + table_name_object.bind(py), + path_name_object.bind(py), + is_temporary_object.bind(py), ], ); match self.conn.call_method1(py, "register_parquet", args) { Ok(_) => {} Err(err) => { - let exception_name = err.get_type(py).name()?; + let err_bound = err.get_type_bound(py); + let exception_name = err_bound.name()?; // Check if we have a fallback connection and this is a RegistrationNotSupportedError if let Some(fallback_connection) = &self.fallback_conn { @@ -377,7 +382,7 @@ impl PySqlDataset { let table_name = table_name_obj.extract::(py)?; let table_schema_obj = dataset.call_method0(py, "table_schema")?; - let table_schema = Schema::from_pyarrow(table_schema_obj.as_ref(py))?; + let table_schema = Schema::from_pyarrow_bound(table_schema_obj.bind(py))?; Ok((table_name, table_schema)) })?; diff --git a/vegafusion-python-embed/src/dataframe.rs b/vegafusion-python-embed/src/dataframe.rs index b5d11f6d..1c5b7878 100644 --- a/vegafusion-python-embed/src/dataframe.rs +++ b/vegafusion-python-embed/src/dataframe.rs @@ -5,7 +5,7 @@ use datafusion_proto::logical_plan::to_proto::serialize_expr; use datafusion_proto::logical_plan::DefaultLogicalExtensionCodec; use prost::Message; use pyo3::prelude::PyModule; -use pyo3::types::{PyBytes, PyTuple}; +use pyo3::types::{PyAnyMethods, PyBytes, PyTuple, PyTypeMethods}; use pyo3::{pyclass, pymethods, IntoPy, PyErr, PyObject, Python}; use serde_json::Value; use std::any::Any; @@ -48,7 +48,7 @@ impl DataFrame for PyDataFrame { fn schema(&self) -> Schema { Python::with_gil(|py| -> std::result::Result<_, PyErr> { let schema_obj = self.dataframe.call_method0(py, "schema")?; - let schema = Schema::from_pyarrow(schema_obj.as_ref(py))?; + let schema = Schema::from_pyarrow_bound(schema_obj.bind(py))?; Ok(schema) }) .expect("Failed to return Schema of DataFrameDatasource") @@ -72,7 +72,7 @@ impl DataFrame for PyDataFrame { async fn collect(&self) -> Result { let table = Python::with_gil(|py| -> std::result::Result<_, PyErr> { let table_object = self.dataframe.call_method0(py, "collect")?; - VegaFusionTable::from_pyarrow(py, table_object.as_ref(py)) + VegaFusionTable::from_pyarrow(table_object.bind(py)) })?; Ok(table) } @@ -85,12 +85,13 @@ impl DataFrame for PyDataFrame { let py_limit = limit.into_py(py); // Build arguments for Python sort method - let args = PyTuple::new(py, vec![py_exprs, py_limit]); + let args = PyTuple::new_bound(py, vec![py_exprs, py_limit]); - let new_py_df = match self.dataframe.call_method(py, "sort", args, None) { + let new_py_df = match self.dataframe.call_method_bound(py, "sort", args, None) { Ok(new_py_df) => new_py_df, Err(err) => { - let exception_name = err.get_type(py).name()?; + let err_bound = err.get_type_bound(py); + let exception_name = err_bound.name()?; if exception_name == "DataFrameOperationNotSupportedError" { // Should fall back to fallback connection below return Ok(None); @@ -112,7 +113,7 @@ impl DataFrame for PyDataFrame { // Fallback let table = Python::with_gil(|py| -> std::result::Result<_, PyErr> { let table = self.dataframe.call_method0(py, "collect")?; - VegaFusionTable::from_pyarrow(py, table.as_ref(py)) + VegaFusionTable::from_pyarrow(table.bind(py)) })?; let new_df: Arc = self.fallback_conn.scan_arrow(table).await?; @@ -125,12 +126,13 @@ impl DataFrame for PyDataFrame { let new_df = Python::with_gil(|py| -> std::result::Result<_, PyErr> { let py_exprs = exprs_to_py(py, exprs.clone())?; // Build arguments for Python sort method - let args = PyTuple::new(py, vec![py_exprs]); + let args = PyTuple::new_bound(py, vec![py_exprs]); - let new_py_df = match self.dataframe.call_method(py, "select", args, None) { + let new_py_df = match self.dataframe.call_method_bound(py, "select", args, None) { Ok(new_py_df) => new_py_df, Err(err) => { - let exception_name = err.get_type(py).name()?; + let err_bound = err.get_type_bound(py); + let exception_name = err_bound.name()?; if exception_name == "DataFrameOperationNotSupportedError" { // Should fall back to fallback connection below return Ok(None); @@ -152,7 +154,7 @@ impl DataFrame for PyDataFrame { // Fallback let table = Python::with_gil(|py| -> std::result::Result<_, PyErr> { let table = self.dataframe.call_method0(py, "collect")?; - VegaFusionTable::from_pyarrow(py, table.as_ref(py)) + VegaFusionTable::from_pyarrow(table.bind(py)) })?; let new_df: Arc = self.fallback_conn.scan_arrow(table).await?; @@ -171,12 +173,16 @@ impl DataFrame for PyDataFrame { let py_aggr_exprs = exprs_to_py(py, aggr_exprs.clone())?; // Build arguments for Python sort method - let args = PyTuple::new(py, vec![py_group_exprs, py_aggr_exprs]); + let args = PyTuple::new_bound(py, vec![py_group_exprs, py_aggr_exprs]); - let new_py_df = match self.dataframe.call_method(py, "aggregate", args, None) { + let new_py_df = match self + .dataframe + .call_method_bound(py, "aggregate", args, None) + { Ok(new_py_df) => new_py_df, Err(err) => { - let exception_name = err.get_type(py).name()?; + let err_bound = err.get_type_bound(py); + let exception_name = err_bound.name()?; if exception_name == "DataFrameOperationNotSupportedError" { // Should fall back to fallback connection below return Ok(None); @@ -198,7 +204,7 @@ impl DataFrame for PyDataFrame { // Fallback let table = Python::with_gil(|py| -> std::result::Result<_, PyErr> { let table = self.dataframe.call_method0(py, "collect")?; - VegaFusionTable::from_pyarrow(py, table.as_ref(py)) + VegaFusionTable::from_pyarrow(table.bind(py)) })?; let new_df: Arc = self.fallback_conn.scan_arrow(table).await?; @@ -217,12 +223,16 @@ impl DataFrame for PyDataFrame { let py_aggr_exprs = exprs_to_py(py, aggr_exprs.clone())?; // Build arguments for Python sort method - let args = PyTuple::new(py, vec![py_group_exprs, py_aggr_exprs]); + let args = PyTuple::new_bound(py, vec![py_group_exprs, py_aggr_exprs]); - let new_py_df = match self.dataframe.call_method(py, "joinaggregate", args, None) { + let new_py_df = match self + .dataframe + .call_method_bound(py, "joinaggregate", args, None) + { Ok(new_py_df) => new_py_df, Err(err) => { - let exception_name = err.get_type(py).name()?; + let err_bound = err.get_type_bound(py); + let exception_name = err_bound.name()?; if exception_name == "DataFrameOperationNotSupportedError" { // Should fall back to fallback connection below return Ok(None); @@ -244,7 +254,7 @@ impl DataFrame for PyDataFrame { // Fallback let table = Python::with_gil(|py| -> std::result::Result<_, PyErr> { let table = self.dataframe.call_method0(py, "collect")?; - VegaFusionTable::from_pyarrow(py, table.as_ref(py)) + VegaFusionTable::from_pyarrow(table.bind(py)) })?; let new_df: Arc = self.fallback_conn.scan_arrow(table).await?; @@ -257,12 +267,13 @@ impl DataFrame for PyDataFrame { let new_df = Python::with_gil(|py| -> std::result::Result<_, PyErr> { let py_exprs = expr_to_py(py, &predicate)?; // Build arguments for Python sort method - let args = PyTuple::new(py, vec![py_exprs]); + let args = PyTuple::new_bound(py, vec![py_exprs]); - let new_py_df = match self.dataframe.call_method(py, "filter", args, None) { + let new_py_df = match self.dataframe.call_method_bound(py, "filter", args, None) { Ok(new_py_df) => new_py_df, Err(err) => { - let exception_name = err.get_type(py).name()?; + let err_bound = err.get_type_bound(py); + let exception_name = err_bound.name()?; if exception_name == "DataFrameOperationNotSupportedError" { // Should fall back to fallback connection below return Ok(None); @@ -284,7 +295,7 @@ impl DataFrame for PyDataFrame { // Fallback let table = Python::with_gil(|py| -> std::result::Result<_, PyErr> { let table = self.dataframe.call_method0(py, "collect")?; - VegaFusionTable::from_pyarrow(py, table.as_ref(py)) + VegaFusionTable::from_pyarrow(table.bind(py)) })?; let new_df: Arc = self.fallback_conn.scan_arrow(table).await?; @@ -299,12 +310,13 @@ impl DataFrame for PyDataFrame { let py_limit = limit.into_py(py); // Build arguments for Python sort method - let args = PyTuple::new(py, vec![py_limit]); + let args = PyTuple::new_bound(py, vec![py_limit]); - let new_py_df = match self.dataframe.call_method(py, "limit", args, None) { + let new_py_df = match self.dataframe.call_method_bound(py, "limit", args, None) { Ok(new_py_df) => new_py_df, Err(err) => { - let exception_name = err.get_type(py).name()?; + let err_bound = err.get_type_bound(py); + let exception_name = err_bound.name()?; if exception_name == "DataFrameOperationNotSupportedError" { // Should fall back to fallback connection below return Ok(None); @@ -326,7 +338,7 @@ impl DataFrame for PyDataFrame { // Fallback let table = Python::with_gil(|py| -> std::result::Result<_, PyErr> { let table = self.dataframe.call_method0(py, "collect")?; - VegaFusionTable::from_pyarrow(py, table.as_ref(py)) + VegaFusionTable::from_pyarrow(table.bind(py)) })?; let new_df: Arc = self.fallback_conn.scan_arrow(table).await?; @@ -350,15 +362,16 @@ impl DataFrame for PyDataFrame { let py_order_field = order_field.into_py(py); // Build arguments for Python sort method - let args = PyTuple::new( + let args = PyTuple::new_bound( py, vec![py_fields, py_value_col, py_key_col, py_order_field], ); - let new_py_df = match self.dataframe.call_method(py, "fold", args, None) { + let new_py_df = match self.dataframe.call_method_bound(py, "fold", args, None) { Ok(new_py_df) => new_py_df, Err(err) => { - let exception_name = err.get_type(py).name()?; + let err_bound = err.get_type_bound(py); + let exception_name = err_bound.name()?; if exception_name == "DataFrameOperationNotSupportedError" { // Should fall back to fallback connection below return Ok(None); @@ -380,7 +393,7 @@ impl DataFrame for PyDataFrame { // Fallback let table = Python::with_gil(|py| -> std::result::Result<_, PyErr> { let table = self.dataframe.call_method0(py, "collect")?; - VegaFusionTable::from_pyarrow(py, table.as_ref(py)) + VegaFusionTable::from_pyarrow(table.bind(py)) })?; let new_df: Arc = self.fallback_conn.scan_arrow(table).await?; @@ -408,7 +421,7 @@ impl DataFrame for PyDataFrame { let py_mode = mode.to_string().to_ascii_lowercase().into_py(py); // Build arguments for Python sort method - let args = PyTuple::new( + let args = PyTuple::new_bound( py, vec![ py_field, @@ -420,10 +433,11 @@ impl DataFrame for PyDataFrame { ], ); - let new_py_df = match self.dataframe.call_method(py, "stack", args, None) { + let new_py_df = match self.dataframe.call_method_bound(py, "stack", args, None) { Ok(new_py_df) => new_py_df, Err(err) => { - let exception_name = err.get_type(py).name()?; + let err_bound = err.get_type_bound(py); + let exception_name = err_bound.name()?; if exception_name == "DataFrameOperationNotSupportedError" { // Should fall back to fallback connection below return Ok(None); @@ -445,7 +459,7 @@ impl DataFrame for PyDataFrame { // Fallback let table = Python::with_gil(|py| -> std::result::Result<_, PyErr> { let table = self.dataframe.call_method0(py, "collect")?; - VegaFusionTable::from_pyarrow(py, table.as_ref(py)) + VegaFusionTable::from_pyarrow(table.bind(py)) })?; let new_df: Arc = self.fallback_conn.scan_arrow(table).await?; @@ -473,15 +487,16 @@ impl DataFrame for PyDataFrame { let py_order_field = order_field.into_py(py); // Build arguments for Python sort method - let args = PyTuple::new( + let args = PyTuple::new_bound( py, vec![py_field, py_value, py_key, py_groupby, py_order_field], ); - let new_py_df = match self.dataframe.call_method(py, "impute", args, None) { + let new_py_df = match self.dataframe.call_method_bound(py, "impute", args, None) { Ok(new_py_df) => new_py_df, Err(err) => { - let exception_name = err.get_type(py).name()?; + let err_bound = err.get_type_bound(py); + let exception_name = err_bound.name()?; if exception_name == "DataFrameOperationNotSupportedError" { // Should fall back to fallback connection below return Ok(None); @@ -503,7 +518,7 @@ impl DataFrame for PyDataFrame { // Fallback let table = Python::with_gil(|py| -> std::result::Result<_, PyErr> { let table = self.dataframe.call_method0(py, "collect")?; - VegaFusionTable::from_pyarrow(py, table.as_ref(py)) + VegaFusionTable::from_pyarrow(table.bind(py)) })?; let new_df: Arc = self.fallback_conn.scan_arrow(table).await?; @@ -524,7 +539,7 @@ fn exprs_to_py(py: Python, exprs: Vec) -> Result { fn expr_to_py(py: Python, expr: &Expr) -> Result { let extension_codec = DefaultLogicalExtensionCodec {}; - let proto_module = PyModule::import(py, "vegafusion.proto.datafusion_pb2")?; + let proto_module = PyModule::import_bound(py, "vegafusion.proto.datafusion_pb2")?; let logical_expr_class = proto_module.getattr("LogicalExprNode")?; let proto_sort_expr = serialize_expr(expr, &extension_codec)?; @@ -532,11 +547,11 @@ fn expr_to_py(py: Python, expr: &Expr) -> Result { let sort_expr_bytes: Vec = proto_sort_expr.encode_to_vec(); // py_logical_expr = LogicalExprNode() - let py_logical_expr = logical_expr_class.call(PyTuple::empty(py), None)?; + let py_logical_expr = logical_expr_class.call(PyTuple::empty_bound(py), None)?; // py_logical_expr.ParseFromString(sort_expr_bytes) - let py_bytes = PyBytes::new(py, sort_expr_bytes.as_slice()); - let args = PyTuple::new(py, vec![py_bytes]); + let py_bytes = PyBytes::new_bound(py, sort_expr_bytes.as_slice()); + let args = PyTuple::new_bound(py, vec![py_bytes]); py_logical_expr.call_method("ParseFromString", args, None)?; // From &PyAny to PyObject to maintain ownership diff --git a/vegafusion-python-embed/src/lib.rs b/vegafusion-python-embed/src/lib.rs index ea3b3ce8..8469f313 100644 --- a/vegafusion-python-embed/src/lib.rs +++ b/vegafusion-python-embed/src/lib.rs @@ -16,7 +16,7 @@ use vegafusion_runtime::task_graph::runtime::{ChartState as RsChartState, VegaFu use crate::connection::{PySqlConnection, PySqlDataset}; use crate::dataframe::PyDataFrame; use env_logger::{Builder, Target}; -use pythonize::{depythonize, pythonize}; +use pythonize::{depythonize_bound, pythonize}; use serde_json::json; use vegafusion_common::data::table::VegaFusionTable; use vegafusion_core::patch::patch_pre_transformed_spec; @@ -79,7 +79,7 @@ impl PyChartState { pub fn update(&self, py: Python, updates: Vec) -> PyResult> { let updates = updates .into_iter() - .map(|el| Ok(depythonize::(el.as_ref(py))?)) + .map(|el| Ok(depythonize_bound::(el.bind(py).clone())?)) .collect::>>()?; let result_updates = py.allow_threads(|| { @@ -142,27 +142,30 @@ struct PyVegaFusionRuntime { impl PyVegaFusionRuntime { fn process_inline_datasets( &self, - inline_datasets: Option<&PyDict>, + inline_datasets: Option<&Bound>, ) -> PyResult<(HashMap, bool)> { let mut any_main_thread = false; if let Some(inline_datasets) = inline_datasets { Python::with_gil(|py| -> PyResult<_> { - let vegafusion_dataset_module = PyModule::import(py, "vegafusion.dataset")?; + let vegafusion_dataset_module = PyModule::import_bound(py, "vegafusion.dataset")?; let sql_dataset_type = vegafusion_dataset_module.getattr("SqlDataset")?; let df_dataset_type = vegafusion_dataset_module.getattr("DataFrameDataset")?; - let vegafusion_datasource_module = PyModule::import(py, "vegafusion.datasource")?; + let vegafusion_datasource_module = + PyModule::import_bound(py, "vegafusion.datasource")?; let datasource_type = vegafusion_datasource_module.getattr("Datasource")?; let imported_datasets = inline_datasets .iter() .map(|(name, inline_dataset)| { - let dataset = if inline_dataset.is_instance(sql_dataset_type)? { + let inline_dataset = inline_dataset.to_object(py); + let inline_dataset = inline_dataset.bind(py); + let dataset = if inline_dataset.is_instance(&sql_dataset_type)? { let main_thread = inline_dataset .call_method0("main_thread")? .extract::()?; any_main_thread = any_main_thread || main_thread; - let sql_dataset = PySqlDataset::new(inline_dataset.into_py(py))?; + let sql_dataset = PySqlDataset::new(inline_dataset.to_object(py))?; let rt = if main_thread { &self.tokio_runtime_current_thread } else { @@ -172,15 +175,15 @@ impl PyVegaFusionRuntime { rt.block_on(sql_dataset.scan_table(&sql_dataset.table_name)) })?; VegaFusionDataset::DataFrame(df) - } else if inline_dataset.is_instance(df_dataset_type)? { + } else if inline_dataset.is_instance(&df_dataset_type)? { let main_thread = inline_dataset .call_method0("main_thread")? .extract::()?; any_main_thread = any_main_thread || main_thread; - let df = Arc::new(PyDataFrame::new(inline_dataset.into_py(py))?); + let df = Arc::new(PyDataFrame::new(inline_dataset.to_object(py))?); VegaFusionDataset::DataFrame(df) - } else if inline_dataset.is_instance(datasource_type)? { + } else if inline_dataset.is_instance(&datasource_type)? { let df = self.tokio_runtime_connection.block_on( self.runtime .conn @@ -192,7 +195,7 @@ impl PyVegaFusionRuntime { // We convert to ipc bytes for two reasons: // - It allows VegaFusionDataset to compute an accurate hash of the table // - It works around https://github.com/hex-inc/vegafusion/issues/268 - let table = VegaFusionTable::from_pyarrow(py, inline_dataset)?; + let table = VegaFusionTable::from_pyarrow(inline_dataset)?; VegaFusionDataset::from_table_ipc_bytes(&table.to_ipc_bytes()?)? }; @@ -260,7 +263,7 @@ impl PyVegaFusionRuntime { local_tz: String, default_input_tz: Option, row_limit: Option, - inline_datasets: Option<&PyDict>, + inline_datasets: Option<&Bound>, ) -> PyResult { let spec = parse_json_spec(spec)?; let tz_config = TzConfig { @@ -291,13 +294,17 @@ impl PyVegaFusionRuntime { }) } - pub fn process_request_bytes(&self, py: Python, request_bytes: &PyBytes) -> PyResult { + pub fn process_request_bytes( + &self, + py: Python, + request_bytes: &Bound, + ) -> PyResult { let request_bytes = request_bytes.as_bytes(); let response_bytes = py.allow_threads(|| { self.tokio_runtime_connection .block_on(self.runtime.query_request_bytes(request_bytes)) })?; - Ok(PyBytes::new(py, &response_bytes).into()) + Ok(PyBytes::new_bound(py, &response_bytes).into()) } #[allow(clippy::too_many_arguments)] @@ -348,7 +355,7 @@ impl PyVegaFusionRuntime { default_input_tz: Option, row_limit: Option, preserve_interactivity: Option, - inline_datasets: Option<&PyDict>, + inline_datasets: Option<&Bound>, keep_signals: Option)>>, keep_datasets: Option)>>, ) -> PyResult<(PyObject, PyObject)> { @@ -407,7 +414,7 @@ impl PyVegaFusionRuntime { local_tz: String, default_input_tz: Option, row_limit: Option, - inline_datasets: Option<&PyDict>, + inline_datasets: Option<&Bound>, ) -> PyResult<(PyObject, PyObject)> { let (inline_datasets, any_main_thread_sources) = self.process_inline_datasets(inline_datasets)?; @@ -457,7 +464,7 @@ impl PyVegaFusionRuntime { .collect(); Python::with_gil(|py| -> PyResult<(PyObject, PyObject)> { - let py_response_list = PyList::empty(py); + let py_response_list = PyList::empty_bound(py); for value in values { let pytable: PyObject = if let TaskValue::Table(table) = value { table.to_pyarrow(py)? @@ -484,7 +491,7 @@ impl PyVegaFusionRuntime { preserve_interactivity: Option, extract_threshold: Option, extracted_format: Option, - inline_datasets: Option<&PyDict>, + inline_datasets: Option<&Bound>, keep_signals: Option)>>, keep_datasets: Option)>>, ) -> PyResult<(PyObject, Vec, PyObject)> { @@ -545,7 +552,7 @@ impl PyVegaFusionRuntime { let table = match extracted_format.as_str() { "pyarrow" => table.to_pyarrow(py)?, "arrow-ipc" => { - PyBytes::new(py, table.to_ipc_bytes()?.as_slice()).to_object(py) + PyBytes::new_bound(py, table.to_ipc_bytes()?.as_slice()).to_object(py) } "arrow-ipc-base64" => table.to_ipc_base64()?.into_py(py), _ => { @@ -556,7 +563,8 @@ impl PyVegaFusionRuntime { } }; - let dataset: PyObject = PyTuple::new(py, &[name, scope, table]).into_py(py); + let dataset: PyObject = + PyTuple::new_bound(py, &[name, scope, table]).into_py(py); Ok(dataset) }) .collect::>>()?; @@ -614,7 +622,7 @@ impl PyVegaFusionRuntime { /// the `lib.name` setting in the `Cargo.toml`, else Python will not be able to /// import the module. #[pymodule] -fn vegafusion_embed(_py: Python, m: &PyModule) -> PyResult<()> { +fn vegafusion_embed(_py: Python, m: &Bound) -> PyResult<()> { m.add_class::()?; m.add_class::()?; m.add_class::()?; @@ -631,8 +639,8 @@ fn parse_json_spec(chart_spec: PyObject) -> PyResult { "Failed to parse chart_spec string as Vega: {err}" ))), } - } else if let Ok(chart_spec) = chart_spec.downcast::(py) { - match depythonize(chart_spec) { + } else if let Ok(chart_spec) = chart_spec.downcast_bound::(py) { + match depythonize_bound::(chart_spec.clone()) { Ok(chart_spec) => Ok(chart_spec), Err(err) => Err(PyValueError::new_err(format!( "Failed to parse chart_spec dict as Vega: {err}" diff --git a/vegafusion-sql/src/connection/datafusion_py_datasource.rs b/vegafusion-sql/src/connection/datafusion_py_datasource.rs index e3a38909..9b28a3ae 100644 --- a/vegafusion-sql/src/connection/datafusion_py_datasource.rs +++ b/vegafusion-sql/src/connection/datafusion_py_datasource.rs @@ -28,7 +28,7 @@ impl PyDatasource { pub fn try_new(py_datasource: PyObject) -> Result { Python::with_gil(|py| -> Result<_, PyErr> { let table_schema_obj = py_datasource.call_method0(py, "schema")?; - let schema = Arc::new(Schema::from_pyarrow(table_schema_obj.as_ref(py))?); + let schema = Arc::new(Schema::from_pyarrow_bound(table_schema_obj.bind(py))?); Ok(Self { py_datasource, schema, @@ -141,9 +141,9 @@ impl ExecutionPlan for PyDatasourceExec { .iter() .map(|field| field.name().clone()) .collect::>(); - let args = PyTuple::new(py, vec![column_names.into_py(py)]); + let args = PyTuple::new_bound(py, vec![column_names.into_py(py)]); let pa_table = self.db.py_datasource.call_method1(py, "fetch", args)?; - let table = VegaFusionTable::from_pyarrow(py, pa_table.as_ref(py))?; + let table = VegaFusionTable::from_pyarrow(pa_table.bind(py))?; Ok(table) }) .map_err(|err| DataFusionError::Execution(err.to_string()))?; diff --git a/vegafusion-sql/tests/test_aggregate.rs b/vegafusion-sql/tests/test_aggregate.rs index e61df58d..18c42326 100644 --- a/vegafusion-sql/tests/test_aggregate.rs +++ b/vegafusion-sql/tests/test_aggregate.rs @@ -77,6 +77,7 @@ mod test_median_agg { #[apply(dialect_names)] async fn test(dialect_name: &str) { + use datafusion::functions_aggregate::median::median_udaf; use sqlparser::ast::NullTreatment; println!("{dialect_name}"); @@ -98,7 +99,7 @@ mod test_median_agg { vec![ count(flat_col("a")).alias("count_a"), Expr::AggregateFunction(expr::AggregateFunction { - func_def: AggregateFunctionDefinition::BuiltIn(AggregateFunction::Median), + func_def: AggregateFunctionDefinition::UDF(median_udaf()), args: vec![flat_col("a")], distinct: false, filter: None, @@ -126,6 +127,7 @@ mod test_median_agg { #[cfg(test)] mod test_variance_aggs { use crate::*; + use datafusion::functions_aggregate::variance::var_samp_udaf; use datafusion_expr::expr::AggregateFunctionDefinition; use vegafusion_common::column::flat_col; @@ -176,7 +178,7 @@ mod test_variance_aggs { .div(lit(100)) .alias("stddev_pop_a"), round(vec![Expr::AggregateFunction(expr::AggregateFunction { - func_def: AggregateFunctionDefinition::BuiltIn(AggregateFunction::Variance), + func_def: AggregateFunctionDefinition::UDF(var_samp_udaf()), args: vec![flat_col("a")], distinct: false, filter: None, diff --git a/vegafusion-sql/tests/test_select.rs b/vegafusion-sql/tests/test_select.rs index 1d5311e7..de44cd15 100644 --- a/vegafusion-sql/tests/test_select.rs +++ b/vegafusion-sql/tests/test_select.rs @@ -565,7 +565,7 @@ mod test_is_finite { use arrow::array::{Float64Array, Int32Array}; use arrow::datatypes::{DataType, Field, Schema, SchemaRef}; use arrow::record_batch::RecordBatch; - use datafusion_expr::{expr, Expr, ScalarFunctionDefinition}; + use datafusion_expr::{expr, Expr}; use std::sync::Arc; use vegafusion_common::column::flat_col; @@ -602,12 +602,12 @@ mod test_is_finite { flat_col("a"), flat_col("b"), Expr::ScalarFunction(expr::ScalarFunction { - func_def: ScalarFunctionDefinition::UDF(Arc::new(ISFINITE_UDF.clone())), + func: Arc::new(ISFINITE_UDF.clone()), args: vec![flat_col("a")], }) .alias("f1"), Expr::ScalarFunction(expr::ScalarFunction { - func_def: ScalarFunctionDefinition::UDF(Arc::new(ISFINITE_UDF.clone())), + func: Arc::new(ISFINITE_UDF.clone()), args: vec![flat_col("b")], }) .alias("f2"), @@ -638,7 +638,7 @@ mod test_is_finite { #[cfg(test)] mod test_str_to_utc_timestamp { use crate::*; - use datafusion_expr::{expr, lit, Expr, ScalarFunctionDefinition}; + use datafusion_expr::{expr, lit, Expr}; use std::sync::Arc; use vegafusion_common::column::flat_col; use vegafusion_datafusion_udfs::udfs::datetime::str_to_utc_timestamp::STR_TO_UTC_TIMESTAMP_UDF; @@ -663,9 +663,7 @@ mod test_str_to_utc_timestamp { flat_col("a"), flat_col("b"), Expr::ScalarFunction(expr::ScalarFunction { - func_def: ScalarFunctionDefinition::UDF(Arc::new( - STR_TO_UTC_TIMESTAMP_UDF.clone(), - )), + func: Arc::new(STR_TO_UTC_TIMESTAMP_UDF.clone()), args: vec![flat_col("b"), lit("America/New_York")], }) .alias("b_utc"), @@ -702,7 +700,7 @@ mod test_str_to_utc_timestamp { #[cfg(test)] mod test_date_part_tz { use crate::*; - use datafusion_expr::{expr, lit, Expr, ScalarFunctionDefinition}; + use datafusion_expr::{expr, lit, Expr}; use std::sync::Arc; use vegafusion_common::column::flat_col; use vegafusion_datafusion_udfs::udfs::datetime::date_part_tz::DATE_PART_TZ_UDF; @@ -728,9 +726,7 @@ mod test_date_part_tz { flat_col("a"), flat_col("b"), Expr::ScalarFunction(expr::ScalarFunction { - func_def: ScalarFunctionDefinition::UDF(Arc::new( - STR_TO_UTC_TIMESTAMP_UDF.clone(), - )), + func: Arc::new(STR_TO_UTC_TIMESTAMP_UDF.clone()), args: vec![flat_col("b"), lit("America/New_York")], }) .alias("b_utc"), @@ -743,17 +739,17 @@ mod test_date_part_tz { flat_col("b"), flat_col("b_utc"), Expr::ScalarFunction(expr::ScalarFunction { - func_def: ScalarFunctionDefinition::UDF(Arc::new(DATE_PART_TZ_UDF.clone())), + func: Arc::new(DATE_PART_TZ_UDF.clone()), args: vec![lit("hour"), flat_col("b_utc"), lit("UTC")], }) .alias("hours_utc"), Expr::ScalarFunction(expr::ScalarFunction { - func_def: ScalarFunctionDefinition::UDF(Arc::new(DATE_PART_TZ_UDF.clone())), + func: Arc::new(DATE_PART_TZ_UDF.clone()), args: vec![lit("hour"), flat_col("b_utc"), lit("America/Los_Angeles")], }) .alias("hours_la"), Expr::ScalarFunction(expr::ScalarFunction { - func_def: ScalarFunctionDefinition::UDF(Arc::new(DATE_PART_TZ_UDF.clone())), + func: Arc::new(DATE_PART_TZ_UDF.clone()), args: vec![lit("hour"), flat_col("b_utc"), lit("America/New_York")], }) .alias("hours_nyc"), @@ -793,7 +789,7 @@ mod test_date_part_tz { #[cfg(test)] mod test_date_trunc_tz { use crate::*; - use datafusion_expr::{expr, lit, Expr, ScalarFunctionDefinition}; + use datafusion_expr::{expr, lit, Expr}; use std::sync::Arc; use vegafusion_common::column::flat_col; use vegafusion_datafusion_udfs::udfs::datetime::date_trunc_tz::DATE_TRUNC_TZ_UDF; @@ -819,9 +815,7 @@ mod test_date_trunc_tz { flat_col("a"), flat_col("b"), Expr::ScalarFunction(expr::ScalarFunction { - func_def: ScalarFunctionDefinition::UDF(Arc::new( - STR_TO_UTC_TIMESTAMP_UDF.clone(), - )), + func: Arc::new(STR_TO_UTC_TIMESTAMP_UDF.clone()), args: vec![flat_col("b"), lit("America/New_York")], }) .alias("b_utc"), @@ -834,17 +828,17 @@ mod test_date_trunc_tz { flat_col("b"), flat_col("b_utc"), Expr::ScalarFunction(expr::ScalarFunction { - func_def: ScalarFunctionDefinition::UDF(Arc::new(DATE_TRUNC_TZ_UDF.clone())), + func: Arc::new(DATE_TRUNC_TZ_UDF.clone()), args: vec![lit("day"), flat_col("b_utc"), lit("UTC")], }) .alias("day_utc"), Expr::ScalarFunction(expr::ScalarFunction { - func_def: ScalarFunctionDefinition::UDF(Arc::new(DATE_TRUNC_TZ_UDF.clone())), + func: Arc::new(DATE_TRUNC_TZ_UDF.clone()), args: vec![lit("day"), flat_col("b_utc"), lit("America/Los_Angeles")], }) .alias("day_la"), Expr::ScalarFunction(expr::ScalarFunction { - func_def: ScalarFunctionDefinition::UDF(Arc::new(DATE_TRUNC_TZ_UDF.clone())), + func: Arc::new(DATE_TRUNC_TZ_UDF.clone()), args: vec![lit("day"), flat_col("b_utc"), lit("America/New_York")], }) .alias("day_nyc"), @@ -884,7 +878,7 @@ mod test_date_trunc_tz { #[cfg(test)] mod test_make_timestamp_tz { use crate::*; - use datafusion_expr::{expr, lit, Expr, ScalarFunctionDefinition}; + use datafusion_expr::{expr, lit, Expr}; use std::sync::Arc; use vegafusion_common::column::flat_col; use vegafusion_datafusion_udfs::udfs::datetime::make_utc_timestamp::MAKE_UTC_TIMESTAMP; @@ -907,7 +901,7 @@ mod test_make_timestamp_tz { .select(vec![ flat_col("a"), Expr::ScalarFunction(expr::ScalarFunction { - func_def: ScalarFunctionDefinition::UDF(Arc::new(MAKE_UTC_TIMESTAMP.clone())), + func: Arc::new(MAKE_UTC_TIMESTAMP.clone()), args: vec![ flat_col("Y"), flat_col("M"), @@ -921,7 +915,7 @@ mod test_make_timestamp_tz { }) .alias("ts_utc"), Expr::ScalarFunction(expr::ScalarFunction { - func_def: ScalarFunctionDefinition::UDF(Arc::new(MAKE_UTC_TIMESTAMP.clone())), + func: Arc::new(MAKE_UTC_TIMESTAMP.clone()), args: vec![ flat_col("Y"), flat_col("M"), @@ -935,7 +929,7 @@ mod test_make_timestamp_tz { }) .alias("ts_nyc"), Expr::ScalarFunction(expr::ScalarFunction { - func_def: ScalarFunctionDefinition::UDF(Arc::new(MAKE_UTC_TIMESTAMP.clone())), + func: Arc::new(MAKE_UTC_TIMESTAMP.clone()), args: vec![ flat_col("Y"), flat_col("M"), @@ -981,7 +975,7 @@ mod test_make_timestamp_tz { #[cfg(test)] mod test_epoch_to_utc_timestamp { use crate::*; - use datafusion_expr::{expr, Expr, ScalarFunctionDefinition}; + use datafusion_expr::{expr, Expr}; use std::sync::Arc; use vegafusion_common::column::flat_col; use vegafusion_datafusion_udfs::udfs::datetime::epoch_to_utc_timestamp::EPOCH_MS_TO_UTC_TIMESTAMP_UDF; @@ -1004,9 +998,7 @@ mod test_epoch_to_utc_timestamp { flat_col("a"), flat_col("t"), Expr::ScalarFunction(expr::ScalarFunction { - func_def: ScalarFunctionDefinition::UDF(Arc::new( - EPOCH_MS_TO_UTC_TIMESTAMP_UDF.clone(), - )), + func: Arc::new(EPOCH_MS_TO_UTC_TIMESTAMP_UDF.clone()), args: vec![flat_col("t")], }) .alias("t_utc"), @@ -1043,7 +1035,7 @@ mod test_epoch_to_utc_timestamp { #[cfg(test)] mod test_utc_timestamp_to_epoch_ms { use crate::*; - use datafusion_expr::{expr, Expr, ScalarFunctionDefinition}; + use datafusion_expr::{expr, Expr}; use std::sync::Arc; use vegafusion_common::column::flat_col; use vegafusion_datafusion_udfs::udfs::datetime::epoch_to_utc_timestamp::EPOCH_MS_TO_UTC_TIMESTAMP_UDF; @@ -1068,9 +1060,7 @@ mod test_utc_timestamp_to_epoch_ms { flat_col("a"), flat_col("t"), Expr::ScalarFunction(expr::ScalarFunction { - func_def: ScalarFunctionDefinition::UDF(Arc::new( - EPOCH_MS_TO_UTC_TIMESTAMP_UDF.clone(), - )), + func: Arc::new(EPOCH_MS_TO_UTC_TIMESTAMP_UDF.clone()), args: vec![flat_col("t")], }) .alias("t_utc"), @@ -1083,9 +1073,7 @@ mod test_utc_timestamp_to_epoch_ms { flat_col("t"), flat_col("t_utc"), Expr::ScalarFunction(expr::ScalarFunction { - func_def: ScalarFunctionDefinition::UDF(Arc::new( - UTC_TIMESTAMP_TO_EPOCH_MS.clone(), - )), + func: Arc::new(UTC_TIMESTAMP_TO_EPOCH_MS.clone()), args: vec![flat_col("t_utc")], }) .alias("epoch_millis"), @@ -1125,7 +1113,7 @@ mod test_utc_timestamp_to_epoch_ms { #[cfg(test)] mod test_date_add_tz { use crate::*; - use datafusion_expr::{expr, lit, Expr, ScalarFunctionDefinition}; + use datafusion_expr::{expr, lit, Expr}; use std::sync::Arc; use vegafusion_common::column::flat_col; use vegafusion_datafusion_udfs::udfs::datetime::str_to_utc_timestamp::STR_TO_UTC_TIMESTAMP_UDF; @@ -1150,9 +1138,7 @@ mod test_date_add_tz { flat_col("a"), flat_col("b"), Expr::ScalarFunction(expr::ScalarFunction { - func_def: ScalarFunctionDefinition::UDF(Arc::new( - STR_TO_UTC_TIMESTAMP_UDF.clone(), - )), + func: Arc::new(STR_TO_UTC_TIMESTAMP_UDF.clone()), args: vec![flat_col("b"), lit("UTC")], }) .alias("b_utc"), @@ -1165,12 +1151,12 @@ mod test_date_add_tz { flat_col("b"), flat_col("b_utc"), Expr::ScalarFunction(expr::ScalarFunction { - func_def: ScalarFunctionDefinition::UDF(Arc::new(DATE_ADD_TZ_UDF.clone())), + func: Arc::new(DATE_ADD_TZ_UDF.clone()), args: vec![lit("month"), lit(1), flat_col("b_utc"), lit("UTC")], }) .alias("month_utc"), Expr::ScalarFunction(expr::ScalarFunction { - func_def: ScalarFunctionDefinition::UDF(Arc::new(DATE_ADD_TZ_UDF.clone())), + func: Arc::new(DATE_ADD_TZ_UDF.clone()), args: vec![ lit("month"), lit(1), @@ -1215,7 +1201,7 @@ mod test_date_add_tz { #[cfg(test)] mod test_utc_timestamp_to_str { use crate::*; - use datafusion_expr::{expr, lit, Expr, ScalarFunctionDefinition}; + use datafusion_expr::{expr, lit, Expr}; use std::sync::Arc; use vegafusion_common::column::flat_col; use vegafusion_datafusion_udfs::udfs::datetime::str_to_utc_timestamp::STR_TO_UTC_TIMESTAMP_UDF; @@ -1241,9 +1227,7 @@ mod test_utc_timestamp_to_str { flat_col("a"), flat_col("b"), Expr::ScalarFunction(expr::ScalarFunction { - func_def: ScalarFunctionDefinition::UDF(Arc::new( - STR_TO_UTC_TIMESTAMP_UDF.clone(), - )), + func: Arc::new(STR_TO_UTC_TIMESTAMP_UDF.clone()), args: vec![flat_col("b"), lit("UTC")], }) .alias("b_utc"), @@ -1256,16 +1240,12 @@ mod test_utc_timestamp_to_str { flat_col("b"), flat_col("b_utc"), Expr::ScalarFunction(expr::ScalarFunction { - func_def: ScalarFunctionDefinition::UDF(Arc::new( - UTC_TIMESTAMP_TO_STR_UDF.clone(), - )), + func: Arc::new(UTC_TIMESTAMP_TO_STR_UDF.clone()), args: vec![flat_col("b_utc"), lit("UTC")], }) .alias("str_utc"), Expr::ScalarFunction(expr::ScalarFunction { - func_def: ScalarFunctionDefinition::UDF(Arc::new( - UTC_TIMESTAMP_TO_STR_UDF.clone(), - )), + func: Arc::new(UTC_TIMESTAMP_TO_STR_UDF.clone()), args: vec![flat_col("b_utc"), lit("America/New_York")], }) .alias("str_nyc"), @@ -1308,7 +1288,7 @@ mod test_date_to_utc_timestamp { use arrow::array::{ArrayRef, Date32Array, Int32Array}; use arrow::datatypes::{DataType, Field, Schema, SchemaRef}; use arrow::record_batch::RecordBatch; - use datafusion_expr::{expr, lit, Expr, ScalarFunctionDefinition}; + use datafusion_expr::{expr, lit, Expr}; use std::sync::Arc; use vegafusion_common::column::flat_col; use vegafusion_datafusion_udfs::udfs::datetime::date_to_utc_timestamp::DATE_TO_UTC_TIMESTAMP_UDF; @@ -1341,9 +1321,7 @@ mod test_date_to_utc_timestamp { flat_col("a"), flat_col("b"), Expr::ScalarFunction(expr::ScalarFunction { - func_def: ScalarFunctionDefinition::UDF(Arc::new( - DATE_TO_UTC_TIMESTAMP_UDF.clone(), - )), + func: Arc::new(DATE_TO_UTC_TIMESTAMP_UDF.clone()), args: vec![flat_col("b"), lit("America/New_York")], }) .alias("b_utc"), @@ -1383,7 +1361,7 @@ mod test_timestamp_to_utc_timestamp { use arrow::array::{ArrayRef, Int32Array, TimestampMillisecondArray}; use arrow::datatypes::{DataType, Field, Schema, SchemaRef, TimeUnit}; use arrow::record_batch::RecordBatch; - use datafusion_expr::{expr, lit, Expr, ScalarFunctionDefinition}; + use datafusion_expr::{expr, lit, Expr}; use std::sync::Arc; use vegafusion_common::column::flat_col; use vegafusion_datafusion_udfs::udfs::datetime::to_utc_timestamp::TO_UTC_TIMESTAMP_UDF; @@ -1415,7 +1393,7 @@ mod test_timestamp_to_utc_timestamp { flat_col("a"), flat_col("b"), Expr::ScalarFunction(expr::ScalarFunction { - func_def: ScalarFunctionDefinition::UDF(Arc::new(TO_UTC_TIMESTAMP_UDF.clone())), + func: Arc::new(TO_UTC_TIMESTAMP_UDF.clone()), args: vec![flat_col("b"), lit("America/New_York")], }) .alias("b_utc"),