Skip to content

Commit

Permalink
Fix pyo3 deprecations
Browse files: browse the repository at this point in the history
  • Loading branch information
jonmmease committed Aug 11, 2024
1 parent 5f46cd1 commit 6839a34
Show file tree
Hide file tree
Showing 7 changed files with 179 additions and 170 deletions.
21 changes: 11 additions & 10 deletions vegafusion-common/src/data/table.rs
Original file line number | Diff line number | Diff line change
Expand Up @@ -34,9 +34,9 @@ use {
use {
arrow::pyarrow::{FromPyArrow, ToPyArrow},
pyo3::{
prelude::PyModule,
prelude::*,
types::{PyList, PyTuple},
PyAny, PyErr, PyObject, Python,
Bound, PyAny, PyErr, PyObject, Python,
},
};

Expand Down Expand Up @@ -268,18 +268,17 @@ impl VegaFusionTable {
}

#[cfg(feature = "pyarrow")]
pub fn from_pyarrow(py: Python, pyarrow_table: &PyAny) -> std::result::Result<Self, PyErr> {
pub fn from_pyarrow(pyarrow_table: &Bound<PyAny>) -> std::result::Result<Self, PyErr> {
// Extract table.schema as a Rust Schema
let getattr_args = PyTuple::new(py, vec!["schema"]);
let schema_object = pyarrow_table.call_method1("__getattribute__", getattr_args)?;
let schema = Schema::from_pyarrow(schema_object)?;
let schema_object = pyarrow_table.getattr("schema")?;
let schema = Schema::from_pyarrow_bound(&schema_object)?;

// Extract table.to_batches() as a Rust Vec<RecordBatch>
let batches_object = pyarrow_table.call_method0("to_batches")?;
let batches_list = batches_object.downcast::<PyList>()?;
let batches = batches_list
.iter()
.map(|batch_any| Ok(RecordBatch::from_pyarrow(batch_any)?))
.map(|batch_any| Ok(RecordBatch::from_pyarrow_bound(&batch_any)?))
.collect::<Result<Vec<RecordBatch>>>()?;

Ok(VegaFusionTable::try_new(Arc::new(schema), batches)?)
Expand All @@ -288,14 +287,16 @@ impl VegaFusionTable {
#[cfg(feature = "pyarrow")]
pub fn to_pyarrow(&self, py: Python) -> std::result::Result<PyObject, PyErr> {
// Convert table's record batches into Python list of pyarrow batches
let pyarrow_module = PyModule::import(py, "pyarrow")?;

use pyo3::types::PyAnyMethods;
let pyarrow_module = PyModule::import_bound(py, "pyarrow")?;
let table_cls = pyarrow_module.getattr("Table")?;
let batch_objects = self
.batches
.iter()
.map(|batch| Ok(batch.to_pyarrow(py)?))
.collect::<Result<Vec<_>>>()?;
let batches_list = PyList::new(py, batch_objects);
let batches_list = PyList::new_bound(py, batch_objects);

// Convert table's schema into pyarrow schema
let schema = if let Some(batch) = self.batches.first() {
Expand All @@ -308,7 +309,7 @@ impl VegaFusionTable {
let schema_object = schema.to_pyarrow(py)?;

// Build pyarrow table
let args = PyTuple::new(py, vec![batches_list.as_ref(), schema_object.as_ref(py)]);
let args = PyTuple::new_bound(py, vec![&batches_list, schema_object.bind(py)]);
let pa_table = table_cls.call_method1("from_batches", args)?;
Ok(PyObject::from(pa_table))
}
Expand Down
67 changes: 36 additions & 31 deletions vegafusion-python-embed/src/connection.rs
Original file line number | Diff line number | Diff line change
Expand Up @@ -49,13 +49,13 @@ fn get_dialect_and_fallback_connection(

fn perform_fetch_query(query: &str, schema: &Schema, conn: &PyObject) -> Result<VegaFusionTable> {
let table = Python::with_gil(|py| -> std::result::Result<_, PyErr> {
let query_object = PyString::new(py, query);
let query_object = PyString::new_bound(py, query);
let query_object = query_object.as_ref();
let schema_object = schema.to_pyarrow(py)?;
let schema_object = schema_object.as_ref(py);
let args = PyTuple::new(py, vec![query_object, schema_object]);
let table_object = conn.call_method(py, "fetch_query", args, None)?;
VegaFusionTable::from_pyarrow(py, table_object.as_ref(py))
let schema_object = schema_object.bind(py);
let args = PyTuple::new_bound(py, vec![query_object, schema_object]);
let table_object = conn.call_method_bound(py, "fetch_query", args, None)?;
VegaFusionTable::from_pyarrow(table_object.bind(py))
})?;
Ok(table)
}
Expand Down Expand Up @@ -91,14 +91,14 @@ impl Connection for PySqlConnection {
async fn tables(&self) -> Result<HashMap<String, Schema>> {
let tables = Python::with_gil(|py| -> std::result::Result<_, PyErr> {
let tables_object = self.conn.call_method0(py, "tables")?;
let tables_dict = tables_object.downcast::<PyDict>(py)?;
let tables_dict = tables_object.downcast_bound::<PyDict>(py)?;

let mut tables: HashMap<String, Schema> = HashMap::new();

for key in tables_dict.keys() {
let value = tables_dict.get_item(key)?.unwrap();
let value = tables_dict.get_item(key.clone())?.unwrap();
let key_string = key.extract::<String>()?;
let value_schema = Schema::from_pyarrow(value)?;
let value_schema = Schema::from_pyarrow_bound(&value)?;
tables.insert(key_string, value_schema);
}
Ok(tables)
Expand Down Expand Up @@ -130,12 +130,14 @@ impl Connection for PySqlConnection {
// Register table with Python connection
let table_name_object = table_name.clone().into_py(py);
let is_temporary_object = true.into_py(py);
let args = PyTuple::new(py, vec![table_name_object, pa_table, is_temporary_object]);
let args =
PyTuple::new_bound(py, vec![table_name_object, pa_table, is_temporary_object]);

match self.conn.call_method1(py, "register_arrow", args) {
Ok(_) => {}
Err(err) => {
let exception_name = err.get_type(py).name()?;
let err_bound = err.get_type_bound(py);
let exception_name = err_bound.name()?;

// Check if we have a fallback connection and this is a RegistrationNotSupportedError
if let Some(fallback_connection) = &self.fallback_conn {
Expand Down Expand Up @@ -172,7 +174,7 @@ impl Connection for PySqlConnection {
let inner_opts = opts.clone();
let fallback_connection = Python::with_gil(|py| -> std::result::Result<_, PyErr> {
// Build Python CsvReadOptions
let vegafusion_module = PyModule::import(py, "vegafusion.connection")?;
let vegafusion_module = PyModule::import_bound(py, "vegafusion.connection")?;
let csv_opts_class = vegafusion_module.getattr("CsvReadOptions")?;

let pyschema = inner_opts
Expand All @@ -188,28 +190,29 @@ impl Connection for PySqlConnection {
("file_extension", inner_opts.file_extension.into_py(py)),
("schema", pyschema),
]
.into_py_dict(py);
let args = PyTuple::empty(py);
let csv_opts = csv_opts_class.call(args, Some(kwargs))?;
.into_py_dict_bound(py);
let args = PyTuple::empty_bound(py);
let csv_opts = csv_opts_class.call(args, Some(&kwargs))?;

// Register table with Python connection
let table_name_object = table_name.clone().into_py(py);
let path_name_object = url.to_string().into_py(py);
let is_temporary_object = true.into_py(py);
let args = PyTuple::new(
let args = PyTuple::new_bound(
py,
vec![
table_name_object.as_ref(py),
path_name_object.as_ref(py),
csv_opts,
is_temporary_object.as_ref(py),
table_name_object.bind(py),
path_name_object.bind(py),
&csv_opts,
is_temporary_object.bind(py),
],
);

match self.conn.call_method1(py, "register_csv", args) {
Ok(_) => {}
Err(err) => {
let exception_name = err.get_type(py).name()?;
let err_bound = err.get_type_bound(py);
let exception_name = err_bound.name()?;

// Check if we have a fallback connection and this is a RegistrationNotSupportedError
if let Some(fallback_connection) = &self.fallback_conn {
Expand Down Expand Up @@ -249,18 +252,19 @@ impl Connection for PySqlConnection {
let path_name_object = path.to_string().into_py(py);
let is_temporary_object = true.into_py(py);

let args = PyTuple::new(
let args = PyTuple::new_bound(
py,
vec![
table_name_object.as_ref(py),
path_name_object.as_ref(py),
is_temporary_object.as_ref(py),
table_name_object.bind(py),
path_name_object.bind(py),
is_temporary_object.bind(py),
],
);
match self.conn.call_method1(py, "register_arrow_file", args) {
Ok(_) => {}
Err(err) => {
let exception_name = err.get_type(py).name()?;
let err_bound = err.get_type_bound(py);
let exception_name = err_bound.name()?;

// Check if we have a fallback connection and this is a RegistrationNotSupportedError
if let Some(fallback_connection) = &self.fallback_conn {
Expand Down Expand Up @@ -300,18 +304,19 @@ impl Connection for PySqlConnection {
let path_name_object = path.to_string().into_py(py);
let is_temporary_object = true.into_py(py);

let args = PyTuple::new(
let args = PyTuple::new_bound(
py,
vec![
table_name_object.as_ref(py),
path_name_object.as_ref(py),
is_temporary_object.as_ref(py),
table_name_object.bind(py),
path_name_object.bind(py),
is_temporary_object.bind(py),
],
);
match self.conn.call_method1(py, "register_parquet", args) {
Ok(_) => {}
Err(err) => {
let exception_name = err.get_type(py).name()?;
let err_bound = err.get_type_bound(py);
let exception_name = err_bound.name()?;

// Check if we have a fallback connection and this is a RegistrationNotSupportedError
if let Some(fallback_connection) = &self.fallback_conn {
Expand Down Expand Up @@ -377,7 +382,7 @@ impl PySqlDataset {
let table_name = table_name_obj.extract::<String>(py)?;

let table_schema_obj = dataset.call_method0(py, "table_schema")?;
let table_schema = Schema::from_pyarrow(table_schema_obj.as_ref(py))?;
let table_schema = Schema::from_pyarrow_bound(table_schema_obj.bind(py))?;
Ok((table_name, table_schema))
})?;

Expand Down
Loading

0 comments on commit 6839a34

Please sign in to comment.