Skip to content

Commit

Permalink
Add VegaFusionTable method to convert to types compatible with Vega.js
Browse files Browse the repository at this point in the history
  • Loading branch information
jonmmease committed Oct 6, 2023
1 parent d94fee4 commit 9cd2fde
Showing 1 changed file with 71 additions and 3 deletions.
74 changes: 71 additions & 3 deletions vegafusion-common/src/data/table.rs
Original file line number Diff line number Diff line change
@@ -1,9 +1,9 @@
use datafusion_common::ScalarValue;

use arrow::{
array::{ArrayRef, StructArray, UInt32Array},
compute::concat_batches,
datatypes::{DataType, Field, Schema, SchemaRef},
array::{Array, ArrayRef, StructArray, UInt32Array},
compute::{cast, concat_batches},
datatypes::{DataType, Field, FieldRef, Schema, SchemaRef},
ipc::{reader::StreamReader, writer::StreamWriter},
record_batch::RecordBatch,
};
Expand Down Expand Up @@ -330,6 +330,47 @@ impl VegaFusionTable {
Ok(PyObject::from(pa_table))
}

/// Convert columns to types that are compatible with the Vega JavaScript library
pub fn to_js_types(&self) -> Result<Self> {
let new_fields = self
.schema
.fields
.iter()
.map(|f| {
if f.data_type().is_numeric() {
FieldRef::new(Field::new(
f.name().clone(),
DataType::Float64,
f.is_nullable(),
))
} else {
f.clone()
}
})
.collect::<Vec<_>>();
let new_schema = Schema::new(new_fields);

let mut new_record_batches: Vec<RecordBatch> = Vec::with_capacity(self.batches.len());
for batch in self.batches.iter() {
let mut new_columns = Vec::with_capacity(batch.num_columns());
for c in 0..batch.num_columns() {
let column = batch.column(c);
let dtype = column.data_type();
new_columns.push(if dtype.is_numeric() && dtype != &DataType::Float64 {
cast(column, &DataType::Float64)?
} else {
column.clone()
});
}
new_record_batches.push(RecordBatch::try_new(
Arc::new(new_schema.clone()),
new_columns,
)?);
}

Self::try_new(Arc::new(new_schema.clone()), new_record_batches)
}

// Serialize to bytes using Arrow IPC format
pub fn to_ipc_bytes(&self) -> Result<Vec<u8>> {
let buffer: Vec<u8> = Vec::new();
Expand Down Expand Up @@ -400,6 +441,7 @@ impl Hash for VegaFusionTable {
#[cfg(test)]
mod tests {
use crate::data::table::VegaFusionTable;
use arrow::datatypes::DataType;
use serde_json::json;

#[test]
Expand Down Expand Up @@ -439,4 +481,30 @@ mod tests {
assert_eq!(result_table2.batches.len(), 1);
assert_eq!(result_table2.to_json().unwrap(), expected_json);
}

#[test]
fn test_to_js_types() {
// Test that Integer columns are cast to double
let table1 = VegaFusionTable::from_json(&json!([
{"a": 1, "b": "A"},
{"a": 2, "b": "BB"},
{"a": 10, "b": "CCC"},
{"a": 20, "b": "DDDD"},
]))
.unwrap();
let f1 = &table1.schema.fields[0];
let f2 = &table1.schema.fields[1];
assert_eq!(f1.name(), "a");
assert_eq!(f1.data_type(), &DataType::Int64);
assert_eq!(f2.name(), "b");
assert_eq!(f2.data_type(), &DataType::Utf8);

let table2 = table1.to_js_types().unwrap();
let f1 = &table2.schema.fields[0];
let f2 = &table2.schema.fields[1];
assert_eq!(f1.name(), "a");
assert_eq!(f1.data_type(), &DataType::Float64);
assert_eq!(f2.name(), "b");
assert_eq!(f2.data_type(), &DataType::Utf8);
}
}

0 comments on commit 9cd2fde

Please sign in to comment.