Skip to content

Commit

Permalink
Arrow safe extraction
Browse files Browse the repository at this point in the history
  • Loading branch information
Hennzau committed Aug 8, 2024
1 parent e4a4f5d commit f6bd50b
Show file tree
Hide file tree
Showing 8 changed files with 311 additions and 273 deletions.
6 changes: 3 additions & 3 deletions examples/dummy-opencv-capture/main.rs
Original file line number Diff line number Diff line change
Expand Up @@ -5,7 +5,7 @@ use fastformat::image::Image;
fn camera_read() -> ndarray::Array<u8, ndarray::Ix3> {
// Dummy camera read

let flat_image = vec![0; 27];
let flat_image = (110..137).collect::<Vec<u8>>();
println!(
"Generate a camera image at address: {:?}",
flat_image.as_ptr()
Expand All @@ -16,10 +16,10 @@ fn camera_read() -> ndarray::Array<u8, ndarray::Ix3> {
return image.bgr8_into_ndarray().unwrap();
}

fn image_show(_frame: ndarray::ArrayView<u8, ndarray::Ix3>) {
fn image_show(frame: ndarray::ArrayView<u8, ndarray::Ix3>) {
// Dummy image show

println!("Showing an image.");
println!("{:?}", frame);
}

fn send_output(arrow_array: arrow::array::UnionArray) {
Expand Down
221 changes: 114 additions & 107 deletions src/arrow.rs
Original file line number Diff line number Diff line change
@@ -1,162 +1,169 @@
use arrow::datatypes::DataType;
use std::collections::HashMap;

use std::{collections::HashMap, sync::Arc};
use eyre::{Context, OptionExt, Report, Result};

use eyre::{ContextCompat, Result};

/// Creates a lookup table (`HashMap`) from the fields of a union.
/// Converts an Arrow `UnionArray` into a `HashMap`.
///
/// This function takes a reference to `arrow::datatypes::UnionFields` and
/// creates a `HashMap` where the field names are the keys (as `String`) and
/// the associated values are the field identifiers (`i8`).
/// This function takes an Arrow `UnionArray` and converts it into a `HashMap` where the keys
/// are the names of the fields and the values are the corresponding `ArrayRef` objects.
///
/// # Arguments
///
/// * `fields` - A reference to the fields of the Arrow union data structure (`arrow::datatypes::UnionFields`).
/// * `array` - An `arrow::array::UnionArray` containing the data to be converted.
///
/// # Returns
///
/// A `HashMap` with the field names as keys and their identifiers (`i8`) as values.
///
/// # Example
///
/// ```
/// use arrow::datatypes::{Field, DataType, UnionFields};
/// use std::collections::HashMap;
///
/// use fastformat::arrow::union_lookup_table;
///
/// let fields = UnionFields::new(
/// vec![1, 2],
/// vec![
/// Field::new("field1", DataType::Int32, false),
/// Field::new("field2", DataType::Float64, false),
/// ],
/// );
/// A `Result` containing the constructed `HashMap<String, arrow::array::ArrayRef>` if successful,
/// or an error otherwise.
///
/// let lookup_table = union_lookup_table(&fields);
/// # Errors
///
/// assert_eq!(lookup_table.get("field1"), Some(&1));
/// assert_eq!(lookup_table.get("field2"), Some(&2));
/// ```
pub fn union_lookup_table(fields: &arrow::datatypes::UnionFields) -> HashMap<String, i8> {
/// Returns an error if the union array field index is invalid or if there are issues
/// in accessing the children of the union array.
pub fn arrow_union_into_map(
array: arrow::array::UnionArray,
) -> Result<HashMap<String, arrow::array::ArrayRef>> {
let mut result = HashMap::new();

for field in fields.iter() {
let (a, b) = field;
let (union_fields, _, _, children) = array.into_parts();

result.insert(b.name().to_string(), a);
for (a, b) in union_fields.iter() {
let child = children
.get(a as usize)
.ok_or_eyre(Report::msg(
"Invalid union array field index. Must be >= 0 and correspond to children index in the array.",
))?
.clone();

result.insert(b.name().to_string(), child);
}

result
Ok(result)
}

/// Retrieves a column from a `UnionArray` by its field name and downcasts it to the specified type.
/// Extracts a primitive array from a `HashMap` and converts it to a `Vec`.
///
/// This function takes a reference to an `arrow::array::UnionArray`, a field name,
/// and a lookup table mapping field names to their identifiers. It retrieves the column
/// corresponding to the field name from the union array and attempts to downcast it to
/// the specified type `T`.
/// This function takes a `HashMap` containing `ArrayRef` objects, extracts the array corresponding
/// to the specified field, and converts it into a `Vec` of the primitive type `T`.
///
/// # Arguments
///
/// * `array` - A reference to the `UnionArray` from which to retrieve the column.
/// * `field` - The name of the field whose column is to be retrieved.
/// * `lookup_table` - A reference to a `HashMap` that maps field names (`String`) to their identifiers (`i8`).
/// * `field` - A string slice representing the key in the `HashMap`.
/// * `map` - A mutable reference to the `HashMap<String, arrow::array::ArrayRef>`.
///
/// # Returns
///
/// A `Result` containing a reference to the column cast to type `T` if successful, or an error otherwise.
/// A `Result` containing the constructed `Vec<T>` if successful, or an error otherwise.
///
/// # Type Parameters
///
/// * `T` - The native primitive type to be extracted.
/// * `G` - The Arrow primitive type that corresponds to `T`.
///
/// # Errors
///
/// Returns an error if the field name is not found in the lookup table, or if the retrieved column
/// cannot be downcast to the specified type `T`.
/// Returns an error if the specified field is not present in the `HashMap`, or if the type
/// conversion fails.
///
/// # Example
///
/// ```
/// use arrow::array::Array;
///
/// use fastformat::image::Image;
/// use fastformat::arrow::union_lookup_table;
/// use fastformat::arrow::column_by_name;
///
/// let data = vec![0; 27]; // 3x3 image with 3 bytes per pixel
/// let image = Image::new_bgr8(data, 3, 3, None).unwrap();
/// let array = image.into_arrow().unwrap();
/// use std::collections::HashMap;
/// use arrow::array::{Int32Array, ArrayRef};
/// use std::sync::Arc;
///
/// let union_fields = match array.data_type() {
/// arrow::datatypes::DataType::Union(fields, ..) => fields,
/// _ => panic!("Unexpected data type for image array")
/// };
/// use fastformat::arrow::get_primitive_array_from_map;
///
/// let lookup_table = union_lookup_table(&union_fields);
/// let mut map = HashMap::new();
/// map.insert("field".to_string(), Arc::new(Int32Array::from(vec![1, 2, 3])) as ArrayRef);
///
/// let int_column = column_by_name::<arrow::array::Int32Array>(&array, "field1", &lookup_table);
/// let result: Vec<i32> = get_primitive_array_from_map::<i32, arrow::datatypes::Int32Type>("field", &mut map).unwrap();
/// ```
pub fn column_by_name<'a, T: 'static>(
array: &'a arrow::array::UnionArray,
field: &'a str,
lookup_table: &'a HashMap<String, i8>,
) -> Result<&'a T> {
let index = lookup_table
.get(field)
.cloned()
.wrap_err(format!("Couldn't get field {} from look_up table", field))?;

return array
.child(index)
.as_any()
.downcast_ref::<T>()
.wrap_err(format!("Couldn't downcast field {} to type T", field));
pub fn get_primitive_array_from_map<
    T: arrow::datatypes::ArrowNativeType,
    G: arrow::datatypes::ArrowPrimitiveType,
>(
    field: &str,
    map: &mut HashMap<String, arrow::array::ArrayRef>,
) -> Result<Vec<T>> {
    use arrow::array::Array;

    // Validate the stored Arrow type BEFORE taking the entry out of the map.
    // `PrimitiveArray::<G>::from(ArrayData)` panics when the data type is not compatible with
    // `G`; checking first turns that panic into the error this function promises, and leaves
    // the map untouched when the caller asked for the wrong type.
    let stored_type = map
        .get(field)
        .ok_or_eyre("Invalid field for this map.")?
        .data_type();

    if !arrow::array::PrimitiveArray::<G>::is_compatible(stored_type) {
        return Err(Report::msg(format!(
            "Field {:?} holds an array of type {:?}, which is not compatible with the requested Arrow type {:?}.",
            field,
            stored_type,
            G::DATA_TYPE
        )));
    }

    // The entry is consumed from here on: ownership of the array is needed so that the
    // underlying buffer can be turned into a `Vec` without copying.
    let array_data = map
        .remove(field)
        .ok_or_eyre("Invalid field for this map.")?
        .into_data();

    let array = arrow::array::PrimitiveArray::<G>::from(array_data);

    // Only the values buffer is kept; the data type and the validity information returned by
    // `into_parts` are discarded.
    let (_, values, _) = array.into_parts();

    // `Buffer::into_vec` hands the original buffer back on failure, which is reported through
    // its `Debug` representation.
    values.into_inner().into_vec::<T>().map_err(|e| {
        Report::msg(format!(
            "T is not a valid type for this buffer. Must have the same layout as the buffer (it usually occurs when type is incorrect or when an other reference exists). Error: {:?}", e
        ))
    })
}

/// Creates a tuple representing a union field with an index and an `Arc`-wrapped `Field`.
/// Extracts a UTF-8 encoded string array from a `HashMap` and converts it to a `Vec<String>`.
///
/// This function constructs a tuple where the first element is the given index and the second element
/// is an `Arc`-wrapped `Field` constructed using the provided name, data type, and nullability.
/// This function takes a `HashMap` containing `ArrayRef` objects, extracts the UTF-8 encoded
/// string array corresponding to the specified field, and converts it into a `Vec<String>`.
///
/// # Arguments
///
/// * `index` - An identifier (`i8`) for the union field.
/// * `name` - A string slice representing the name of the field.
/// * `data_type` - The data type of the field (`arrow::datatypes::DataType`).
/// * `nullable` - A boolean indicating whether the field is nullable.
/// * `field` - A string slice representing the key in the `HashMap`.
/// * `map` - A mutable reference to the `HashMap<String, arrow::array::ArrayRef>`.
///
/// # Returns
///
/// A tuple where the first element is the given index (`i8`) and the second element is an `Arc`-wrapped
/// `Field` constructed from the provided name, data type, and nullability.
/// A `Result` containing the constructed `Vec<String>` if successful, or an error otherwise.
///
/// # Errors
///
/// Returns an error if the specified field is not present in the `HashMap`, or if the array
/// is not UTF-8 encoded.
///
/// # Example
///
/// ```
/// use arrow::datatypes::{DataType, Field};
/// use std::collections::HashMap;
/// use arrow::array::{StringArray, ArrayRef};
/// use std::sync::Arc;
/// use fastformat::arrow::union_field;
///
/// let index = 1;
/// let name = "field1";
/// let data_type = DataType::Int32;
/// let nullable = false;
/// use fastformat::arrow::get_utf8_array_from_map;
///
/// let union_field_tuple = union_field(index, name, data_type, nullable);
/// let mut map = HashMap::new();
/// map.insert("field".to_string(), Arc::new(StringArray::from(vec!["a", "b", "c"])) as ArrayRef);
///
/// assert_eq!(union_field_tuple.0, 1);
/// assert_eq!(union_field_tuple.1.name(), "field1");
/// assert_eq!(union_field_tuple.1.data_type(), &DataType::Int32);
/// assert_eq!(union_field_tuple.1.is_nullable(), false);
/// let result = get_utf8_array_from_map("field", &mut map).unwrap();
/// ```
///
pub fn union_field(
index: i8,
name: &str,
data_type: DataType,
nullable: bool,
) -> (i8, Arc<arrow::datatypes::Field>) {
(
index,
Arc::new(arrow::datatypes::Field::new(name, data_type, nullable)),
)
pub fn get_utf8_array_from_map(
    field: &str,
    map: &mut HashMap<String, arrow::array::ArrayRef>,
) -> Result<Vec<String>> {
    use arrow::array::Array;

    // Validate the stored Arrow type BEFORE taking the entry out of the map.
    // `StringArray::from(ArrayData)` panics when the data type is not `Utf8`; checking first
    // turns that panic into an error and leaves the map untouched on a type mismatch.
    let stored_type = map
        .get(field)
        .ok_or_eyre("Invalid field for this map.")?
        .data_type();

    if stored_type != &arrow::datatypes::DataType::Utf8 {
        return Err(Report::msg(format!(
            "Field {:?} holds an array of type {:?}, expected a Utf8 string array.",
            field, stored_type
        )));
    }

    let array_data = map
        .remove(field)
        .ok_or_eyre("Invalid field for this map.")?
        .into_data();

    let array = arrow::array::StringArray::from(array_data);

    // The validity information (third component) is discarded: every slot is decoded from its
    // byte range regardless of nullness.
    let (offsets, buffer, _) = array.into_parts();
    let bytes = buffer.as_slice();

    // Each consecutive pair of offsets delimits one string. Using the pair's own start offset
    // (instead of assuming the first string begins at byte 0) keeps the result correct when the
    // array is a slice whose first offset is not zero. An array with N strings has N + 1
    // offsets, so an empty array yields no windows and an empty `Vec`.
    offsets
        .windows(2)
        .map(|bounds| {
            // Offsets of a valid Arrow string array are non-negative and non-decreasing.
            let start = bounds[0] as usize;
            let end = bounds[1] as usize;

            let chunk = bytes
                .get(start..end)
                .ok_or_eyre("String offsets are out of bounds of the values buffer.")?;

            String::from_utf8(chunk.to_vec()).wrap_err("Array is not UTF-8 encoded.")
        })
        .collect::<Result<Vec<String>>>()
}
18 changes: 18 additions & 0 deletions src/bbox.rs
Original file line number Diff line number Diff line change
Expand Up @@ -16,6 +16,24 @@ pub struct BBox {
pub encoding: Encoding,
}

/// Owned triple of one-dimensional `ndarray` arrays: two `f32` arrays followed by one
/// `String` array.
///
/// NOTE(review): presumably the pieces a bounding box is split into when converted to
/// `ndarray` (e.g. coordinates, confidences, labels) — confirm against the `BBox` conversion
/// methods that use this alias.
type BboxNdArrayResult = (
ndarray::Array<f32, ndarray::Ix1>,
ndarray::Array<f32, ndarray::Ix1>,
ndarray::Array<String, ndarray::Ix1>,
);

/// Borrowed counterpart of the owned bounding-box triple: three read-only one-dimensional
/// `ndarray` views (two over `f32`, one over `String`), all tied to the lifetime `'a`.
///
/// NOTE(review): the meaning of each position is not visible here — confirm against the
/// `BBox` methods that return this alias.
type BboxNdArrayViewResult<'a> = (
ndarray::ArrayView<'a, f32, ndarray::Ix1>,
ndarray::ArrayView<'a, f32, ndarray::Ix1>,
ndarray::ArrayView<'a, String, ndarray::Ix1>,
);

/// Mutable counterpart of the bounding-box view triple: three mutable one-dimensional
/// `ndarray` views (two over `f32`, one over `String`), all tied to the lifetime `'a`.
///
/// NOTE(review): the meaning of each position is not visible here — confirm against the
/// `BBox` methods that return this alias.
type BboxNdArrayViewMutResult<'a> = (
ndarray::ArrayViewMut<'a, f32, ndarray::Ix1>,
ndarray::ArrayViewMut<'a, f32, ndarray::Ix1>,
ndarray::ArrayViewMut<'a, String, ndarray::Ix1>,
);

impl BBox {
pub fn into_xyxy(self) -> Result<Self> {
match self.encoding {
Expand Down
Loading

0 comments on commit f6bd50b

Please sign in to comment.