Skip to content

Commit

Permalink
Fixed normalize function for RecordBatch. Adjusted test case to mat…
Browse files Browse the repository at this point in the history
…ch the example from PyArrow.
  • Loading branch information
nglime committed Nov 24, 2024
1 parent 55eb953 commit 30d6294
Show file tree
Hide file tree
Showing 2 changed files with 68 additions and 82 deletions.
108 changes: 53 additions & 55 deletions arrow-array/src/record_batch.rs
Original file line number Diff line number Diff line change
Expand Up @@ -18,12 +18,11 @@
//! A two-dimensional batch of column-oriented data with a defined
//! [schema](arrow_schema::Schema).

use crate::cast::AsArray;
use crate::{new_empty_array, Array, ArrayRef, StructArray};
use arrow_schema::{
ArrowError, DataType, Field, FieldRef, Fields, Schema, SchemaBuilder, SchemaRef,
};
use arrow_schema::{ArrowError, DataType, Field, FieldRef, Schema, SchemaBuilder, SchemaRef};
use std::collections::VecDeque;
use std::ops::{Deref, Index};
use std::ops::Index;
use std::sync::Arc;

/// Trait for types that can read `RecordBatch`'s.
Expand Down Expand Up @@ -406,7 +405,8 @@ impl RecordBatch {
)
}

/// Normalize a semi-structured RecordBatch into a flat table
/// Normalize a semi-structured [`RecordBatch`] into a flat table.
///
/// If max_level is 0, normalizes all levels.
pub fn normalize(&self, separator: &str, mut max_level: usize) -> Result<Self, ArrowError> {
if max_level == 0 {
Expand All @@ -418,54 +418,47 @@ impl RecordBatch {
self.schema.normalize(separator, max_level)?,
)));
}
let mut queue: VecDeque<(usize, &Arc<dyn Array>, &FieldRef)> = VecDeque::new();
let mut queue: VecDeque<(usize, (ArrayRef, FieldRef))> = VecDeque::new();

// push fields
for (c, f) in self.columns.iter().zip(self.schema.fields()) {
queue.push_front((0, c, f));
queue.push_back((0, ((*c).clone(), (*f).clone())));
}

while !queue.is_empty() {
match queue.pop_front() {
Some((depth, c, f)) => {

if depth < max_level {
match (c.data_type(), f.data_type()) {
//DataType::List(f) => field,
//DataType::ListView(_) => field,
//DataType::FixedSizeList(_, _) => field,
//DataType::LargeList(_) => field,
//DataType::LargeListView(_) => field,
(DataType::Struct(cf), DataType::Struct(ff)) => {
let field_name = f.name().as_str();
let new_key = format!("{key_string}{separator}{field_name}");
ff.iter().rev().zip(cf.iter().rev()).map(|(field, ())| {
let updated_field = Field::new(
format!("{key_string}{separator}{}", field.name()),
field.data_type().clone(),
field.is_nullable(),
);
queue.push_front((
depth + 1,
c, // TODO: need to modify c -- if it's a StructArray, it needs to have the fields modified.
&Arc::new(updated_field),
))
});
}
//DataType::Union(_, _) => field,
//DataType::Dictionary(_, _) => field,
//DataType::Map(_, _) => field,
//DataType::RunEndEncoded(_, _) => field, // not sure how to support this field
_ => queue.push_front((depth, c, f)),
let mut columns: Vec<ArrayRef> = Vec::new();
let mut fields: Vec<FieldRef> = Vec::new();

while let Some((depth, (c, f))) = queue.pop_front() {
if depth < max_level {
match f.data_type() {
DataType::Struct(ff) => {
// Need to zip these in reverse to maintain original order
for (cff, fff) in c
.as_struct()
.columns()
.iter()
.rev()
.zip(ff.into_iter().rev())
{
let new_key = format!("{}{separator}{}", f.name(), fff.name());
let updated_field = Field::new(
new_key.as_str(),
fff.data_type().clone(),
fff.is_nullable(),
);
queue.push_front((depth + 1, (cff.clone(), Arc::new(updated_field))))
}
} else {
queue.push_front((depth, c, f));
}
_ => {
columns.push(c);
fields.push(f);
}
}
None => break,
};
} else {
columns.push(c);
fields.push(f);
}
}
todo!()
RecordBatch::try_new(Arc::new(Schema::new(fields)), columns)
}

/// Returns the number of columns in the record batch.
Expand Down Expand Up @@ -1282,9 +1275,9 @@ mod tests {
let year_field = Arc::new(Field::new("year", DataType::Int64, true));

let a = Arc::new(StructArray::from(vec![
(animals_field.clone(), Arc::new(animals) as ArrayRef),
(n_legs_field.clone(), Arc::new(n_legs) as ArrayRef),
(year_field.clone(), Arc::new(year) as ArrayRef),
(animals_field.clone(), Arc::new(animals.clone()) as ArrayRef),
(n_legs_field.clone(), Arc::new(n_legs.clone()) as ArrayRef),
(year_field.clone(), Arc::new(year.clone()) as ArrayRef),
]));

let month = Arc::new(Int64Array::from(vec![Some(4), Some(6)]));
Expand All @@ -1297,16 +1290,21 @@ mod tests {
),
Field::new("month", DataType::Int64, true),
]);
let normalized = schema.clone().normalize(".", 0).unwrap();
println!("{:?}", normalized);

let record_batch =
RecordBatch::try_new(Arc::new(schema), vec![a, month]).expect("valid conversion");
let record_batch = RecordBatch::try_new(Arc::new(schema), vec![a, month.clone()])
.expect("valid conversion");

let normalized = record_batch.normalize(".", 0).expect("valid normalization");

println!("Fields: {:?}", record_batch.schema().fields());
println!("Metadata{:?}", record_batch.columns());
let expected = RecordBatch::try_from_iter_with_nullable(vec![
("a.animals", animals.clone(), true),
("a.n_legs", n_legs.clone(), true),
("a.year", year.clone(), true),
("month", month.clone(), true),
])
.expect("valid conversion");

//println!("{:?}", record_batch);
assert_eq!(expected, normalized);
}

#[test]
Expand Down
42 changes: 15 additions & 27 deletions arrow-schema/src/schema.rs
Original file line number Diff line number Diff line change
Expand Up @@ -419,14 +419,9 @@ impl Schema {
if max_level == 0 {
max_level = usize::MAX;
}
let mut new_fields: Vec<Field> = vec![];
let mut new_fields: Vec<FieldRef> = vec![];
for field in self.fields() {
match field.data_type() {
//DataType::List(f) => field,
//DataType::ListView(_) => field,
//DataType::FixedSizeList(_, _) => field,
//DataType::LargeList(_) => field,
//DataType::LargeListView(_) => field,
DataType::Struct(nested_fields) => {
let field_name = field.name().as_str();
new_fields = [
Expand All @@ -440,15 +435,11 @@ impl Schema {
]
.concat();
}
//DataType::Union(_, _) => field,
//DataType::Dictionary(_, _) => field,
//DataType::Map(_, _) => field,
//DataType::RunEndEncoded(_, _) => field, // not sure how to support this field
_ => new_fields.push(Field::new(
_ => new_fields.push(Arc::new(Field::new(
field.name(),
field.data_type().clone(),
field.is_nullable(),
)),
))),
};
}
Ok(Self::new_with_metadata(new_fields, self.metadata.clone()))
Expand All @@ -459,16 +450,11 @@ impl Schema {
key_string: &str,
separator: &str,
max_level: usize,
) -> Vec<Field> {
) -> Vec<FieldRef> {
let mut new_fields: Vec<FieldRef> = vec![];
if max_level > 0 {
let mut new_fields: Vec<Field> = vec![];
for field in fields {
match field.data_type() {
//DataType::List(f) => ,
//DataType::ListView(_) => ,
//DataType::FixedSizeList(_, _) => ,
//DataType::LargeList(_) => ,
//DataType::LargeListView(_) => ,
DataType::Struct(nested_fields) => {
let field_name = field.name().as_str();
let new_key = format!("{key_string}{separator}{field_name}");
Expand All @@ -483,21 +469,23 @@ impl Schema {
]
.concat();
}
//DataType::Union(_, _) => field,
//DataType::Dictionary(_, _) => field,
//DataType::Map(_, _) => field,
//DataType::RunEndEncoded(_, _) => field, // not sure how to support this field
_ => new_fields.push(Field::new(
_ => new_fields.push(Arc::new(Field::new(
format!("{key_string}{separator}{}", field.name()),
field.data_type().clone(),
field.is_nullable(),
)),
))),
};
}
new_fields
} else {
todo!()
for field in fields {
new_fields.push(Arc::new(Field::new(
format!("{key_string}{separator}{}", field.name()),
field.data_type().clone(),
field.is_nullable(),
)));
}
}
new_fields
}

/// Look up a column by name and return a immutable reference to the column along with
Expand Down

0 comments on commit 30d6294

Please sign in to comment.