-
Notifications
You must be signed in to change notification settings - Fork 2
Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
- Loading branch information
Showing
8 changed files
with
311 additions
and
273 deletions.
There are no files selected for viewing
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -1,162 +1,169 @@ | ||
use arrow::datatypes::DataType; | ||
use std::collections::HashMap; | ||
|
||
use std::{collections::HashMap, sync::Arc}; | ||
use eyre::{Context, OptionExt, Report, Result}; | ||
|
||
use eyre::{ContextCompat, Result}; | ||
|
||
/// Creates a lookup table (`HashMap`) from the fields of a union. | ||
/// Converts an Arrow `UnionArray` into a `HashMap`. | ||
/// | ||
/// This function takes a reference to `arrow::datatypes::UnionFields` and | ||
/// creates a `HashMap` where the field names are the keys (as `String`) and | ||
/// the associated values are the field identifiers (`i8`). | ||
/// This function takes an Arrow `UnionArray` and converts it into a `HashMap` where the keys | ||
/// are the names of the fields and the values are the corresponding `ArrayRef` objects. | ||
/// | ||
/// # Arguments | ||
/// | ||
/// * `fields` - A reference to the fields of the Arrow union data structure (`arrow::datatypes::UnionFields`). | ||
/// * `array` - An `arrow::array::UnionArray` containing the data to be converted. | ||
/// | ||
/// # Returns | ||
/// | ||
/// A `HashMap` with the field names as keys and their identifiers (`i8`) as values. | ||
/// | ||
/// # Example | ||
/// | ||
/// ``` | ||
/// use arrow::datatypes::{Field, DataType, UnionFields}; | ||
/// use std::collections::HashMap; | ||
/// | ||
/// use fastformat::arrow::union_lookup_table; | ||
/// | ||
/// let fields = UnionFields::new( | ||
/// vec![1, 2], | ||
/// vec![ | ||
/// Field::new("field1", DataType::Int32, false), | ||
/// Field::new("field2", DataType::Float64, false), | ||
/// ], | ||
/// ); | ||
/// A `Result` containing the constructed `HashMap<String, arrow::array::ArrayRef>` if successful, | ||
/// or an error otherwise. | ||
/// | ||
/// let lookup_table = union_lookup_table(&fields); | ||
/// # Errors | ||
/// | ||
/// assert_eq!(lookup_table.get("field1"), Some(&1)); | ||
/// assert_eq!(lookup_table.get("field2"), Some(&2)); | ||
/// ``` | ||
pub fn union_lookup_table(fields: &arrow::datatypes::UnionFields) -> HashMap<String, i8> { | ||
/// Returns an error if the union array field index is invalid or if there are issues | ||
/// in accessing the children of the union array. | ||
pub fn arrow_union_into_map( | ||
array: arrow::array::UnionArray, | ||
) -> Result<HashMap<String, arrow::array::ArrayRef>> { | ||
let mut result = HashMap::new(); | ||
|
||
for field in fields.iter() { | ||
let (a, b) = field; | ||
let (union_fields, _, _, children) = array.into_parts(); | ||
|
||
result.insert(b.name().to_string(), a); | ||
for (a, b) in union_fields.iter() { | ||
let child = children | ||
.get(a as usize) | ||
.ok_or_eyre(Report::msg( | ||
"Invalid union array field index. Must be >= 0 and correspond to children index in the array.", | ||
))? | ||
.clone(); | ||
|
||
result.insert(b.name().to_string(), child); | ||
} | ||
|
||
result | ||
Ok(result) | ||
} | ||
|
||
/// Retrieves a column from a `UnionArray` by its field name and downcasts it to the specified type. | ||
/// Extracts a primitive array from a `HashMap` and converts it to a `Vec`. | ||
/// | ||
/// This function takes a reference to an `arrow::array::UnionArray`, a field name, | ||
/// and a lookup table mapping field names to their identifiers. It retrieves the column | ||
/// corresponding to the field name from the union array and attempts to downcast it to | ||
/// the specified type `T`. | ||
/// This function takes a `HashMap` containing `ArrayRef` objects, extracts the array corresponding | ||
/// to the specified field, and converts it into a `Vec` of the primitive type `T`. | ||
/// | ||
/// # Arguments | ||
/// | ||
/// * `array` - A reference to the `UnionArray` from which to retrieve the column. | ||
/// * `field` - The name of the field whose column is to be retrieved. | ||
/// * `lookup_table` - A reference to a `HashMap` that maps field names (`String`) to their identifiers (`i8`). | ||
/// * `field` - A string slice representing the key in the `HashMap`. | ||
/// * `map` - A mutable reference to the `HashMap<String, arrow::array::ArrayRef>`; the entry for `field` is removed from it. | ||
/// | ||
/// # Returns | ||
/// | ||
/// A `Result` containing a reference to the column cast to type `T` if successful, or an error otherwise. | ||
/// A `Result` containing the constructed `Vec<T>` if successful, or an error otherwise. | ||
/// | ||
/// # Type Parameters | ||
/// | ||
/// * `T` - The native primitive type to be extracted. | ||
/// * `G` - The Arrow primitive type that corresponds to `T`. | ||
/// | ||
/// # Errors | ||
/// | ||
/// Returns an error if the field name is not found in the lookup table, or if the retrieved column | ||
/// cannot be downcast to the specified type `T`. | ||
/// Returns an error if the specified field is not present in the `HashMap`, or if the array's | ||
/// buffer cannot be converted into a `Vec<T>` (mismatched layout, or another reference to it exists). | ||
/// | ||
/// # Example | ||
/// | ||
/// ``` | ||
/// use arrow::array::Array; | ||
/// | ||
/// use fastformat::image::Image; | ||
/// use fastformat::arrow::union_lookup_table; | ||
/// use fastformat::arrow::column_by_name; | ||
/// | ||
/// let data = vec![0; 27]; // 3x3 image with 3 bytes per pixel | ||
/// let image = Image::new_bgr8(data, 3, 3, None).unwrap(); | ||
/// let array = image.into_arrow().unwrap(); | ||
/// use std::collections::HashMap; | ||
/// use arrow::array::{Int32Array, ArrayRef}; | ||
/// use std::sync::Arc; | ||
/// | ||
/// let union_fields = match array.data_type() { | ||
/// arrow::datatypes::DataType::Union(fields, ..) => fields, | ||
/// _ => panic!("Unexpected data type for image array") | ||
/// }; | ||
/// use fastformat::arrow::get_primitive_array_from_map; | ||
/// | ||
/// let lookup_table = union_lookup_table(&union_fields); | ||
/// let mut map = HashMap::new(); | ||
/// map.insert("field".to_string(), Arc::new(Int32Array::from(vec![1, 2, 3])) as ArrayRef); | ||
/// | ||
/// let int_column = column_by_name::<arrow::array::Int32Array>(&array, "field1", &lookup_table); | ||
/// let result: Vec<i32> = get_primitive_array_from_map::<i32, arrow::datatypes::Int32Type>("field", &mut map).unwrap(); | ||
/// ``` | ||
pub fn column_by_name<'a, T: 'static>( | ||
array: &'a arrow::array::UnionArray, | ||
field: &'a str, | ||
lookup_table: &'a HashMap<String, i8>, | ||
) -> Result<&'a T> { | ||
let index = lookup_table | ||
.get(field) | ||
.cloned() | ||
.wrap_err(format!("Couldn't get field {} from look_up table", field))?; | ||
|
||
return array | ||
.child(index) | ||
.as_any() | ||
.downcast_ref::<T>() | ||
.wrap_err(format!("Couldn't downcast field {} to type T", field)); | ||
pub fn get_primitive_array_from_map< | ||
T: arrow::datatypes::ArrowNativeType, | ||
G: arrow::datatypes::ArrowPrimitiveType, | ||
>( | ||
field: &str, | ||
map: &mut HashMap<String, arrow::array::ArrayRef>, | ||
) -> Result<Vec<T>> { | ||
use arrow::array::Array; | ||
|
||
let array_data = map | ||
.remove(field) | ||
.ok_or_eyre(Report::msg("Invalid field for this map."))? | ||
.into_data(); | ||
|
||
let array = arrow::array::PrimitiveArray::<G>::from(array_data); | ||
let (_, buffer, _) = array.into_parts(); | ||
let buffer = buffer.into_inner(); | ||
|
||
match buffer.into_vec::<T>() { | ||
Ok(vec) => Ok(vec), | ||
Err(e) => Err(Report::msg(format!( | ||
"T is not a valid type for this buffer. Must have the same layout as the buffer (it usually occurs when type is incorrect or when an other reference exists). Error: {:?}", e | ||
))), | ||
} | ||
} | ||
|
||
/// Creates a tuple representing a union field with an index and an `Arc`-wrapped `Field`. | ||
/// Extracts a UTF-8 encoded string array from a `HashMap` and converts it to a `Vec<String>`. | ||
/// | ||
/// This function constructs a tuple where the first element is the given index and the second element | ||
/// is an `Arc`-wrapped `Field` constructed using the provided name, data type, and nullability. | ||
/// This function takes a `HashMap` containing `ArrayRef` objects, extracts the UTF-8 encoded | ||
/// string array corresponding to the specified field, and converts it into a `Vec<String>`. | ||
/// | ||
/// # Arguments | ||
/// | ||
/// * `index` - An identifier (`i8`) for the union field. | ||
/// * `name` - A string slice representing the name of the field. | ||
/// * `data_type` - The data type of the field (`arrow::datatypes::DataType`). | ||
/// * `nullable` - A boolean indicating whether the field is nullable. | ||
/// * `field` - A string slice representing the key in the `HashMap`. | ||
/// * `map` - A mutable reference to the `HashMap<String, arrow::array::ArrayRef>`; the entry for `field` is removed from it. | ||
/// | ||
/// # Returns | ||
/// | ||
/// A tuple where the first element is the given index (`i8`) and the second element is an `Arc`-wrapped | ||
/// `Field` constructed from the provided name, data type, and nullability. | ||
/// A `Result` containing the constructed `Vec<String>` if successful, or an error otherwise. | ||
/// | ||
/// # Errors | ||
/// | ||
/// Returns an error if the specified field is not present in the `HashMap`, or if the array | ||
/// is not UTF-8 encoded. | ||
/// | ||
/// # Example | ||
/// | ||
/// ``` | ||
/// use arrow::datatypes::{DataType, Field}; | ||
/// use std::collections::HashMap; | ||
/// use arrow::array::{StringArray, ArrayRef}; | ||
/// use std::sync::Arc; | ||
/// use fastformat::arrow::union_field; | ||
/// | ||
/// let index = 1; | ||
/// let name = "field1"; | ||
/// let data_type = DataType::Int32; | ||
/// let nullable = false; | ||
/// use fastformat::arrow::get_utf8_array_from_map; | ||
/// | ||
/// let union_field_tuple = union_field(index, name, data_type, nullable); | ||
/// let mut map = HashMap::new(); | ||
/// map.insert("field".to_string(), Arc::new(StringArray::from(vec!["a", "b", "c"])) as ArrayRef); | ||
/// | ||
/// assert_eq!(union_field_tuple.0, 1); | ||
/// assert_eq!(union_field_tuple.1.name(), "field1"); | ||
/// assert_eq!(union_field_tuple.1.data_type(), &DataType::Int32); | ||
/// assert_eq!(union_field_tuple.1.is_nullable(), false); | ||
/// let result = get_utf8_array_from_map("field", &mut map).unwrap(); | ||
/// ``` | ||
/// | ||
pub fn union_field( | ||
index: i8, | ||
name: &str, | ||
data_type: DataType, | ||
nullable: bool, | ||
) -> (i8, Arc<arrow::datatypes::Field>) { | ||
( | ||
index, | ||
Arc::new(arrow::datatypes::Field::new(name, data_type, nullable)), | ||
) | ||
pub fn get_utf8_array_from_map( | ||
field: &str, | ||
map: &mut HashMap<String, arrow::array::ArrayRef>, | ||
) -> Result<Vec<String>> { | ||
use arrow::array::Array; | ||
|
||
let array_data = map | ||
.remove(field) | ||
.ok_or_eyre(Report::msg("Invalid field for this map."))? | ||
.into_data(); | ||
|
||
let array = arrow::array::StringArray::from(array_data); | ||
let (offsets, buffer, _) = array.into_parts(); | ||
|
||
let slice = buffer.as_slice(); | ||
let mut last_offset = 0; | ||
let mut iterator = offsets.iter(); | ||
iterator.next(); | ||
|
||
iterator | ||
.map(|&offset| { | ||
let offset = offset as usize; | ||
let slice = &slice[last_offset..offset]; | ||
last_offset = offset; | ||
|
||
String::from_utf8(slice.to_vec()).wrap_err(Report::msg("Array is not UTF-8 encoded.")) | ||
}) | ||
.collect::<Result<Vec<String>>>() | ||
} |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Oops, something went wrong.