Commit 17d4af2

added broadcastable 1D arrays (#28)

maximedion2 authored Nov 10, 2024
1 parent 36d5d52 commit 17d4af2

Showing 9 changed files with 633 additions and 367 deletions.
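This commit makes 1D arrays broadcastable: a one-dimensional array (a lat or lon coordinate, for example) can stand in for a full N-dimensional column, and the new array_broadcast_tests below checks that reading the 1D representation yields the same record batches as the full representation. To support this, each column is mapped to an optional broadcast axis (the HashMap<String, Option<usize>> introduced in this diff). A minimal sketch of the broadcasting idea itself; the broadcast_1d helper is hypothetical and not part of this commit:

// Hypothetical illustration, not code from this commit: repeat a 1D
// array along one axis so it fills a 2D chunk of shape (rows, cols).
fn broadcast_1d(values: &[f64], axis: usize, rows: usize, cols: usize) -> Vec<f64> {
    let mut out = Vec::with_capacity(rows * cols);
    for r in 0..rows {
        for c in 0..cols {
            // axis 0: values vary along rows; axis 1: values vary along columns
            out.push(values[if axis == 0 { r } else { c }]);
        }
    }
    out
}

fn main() {
    // a 1D "lat" coordinate broadcast along axis 0 of a 3x3 chunk
    let lat = [10.0, 20.0, 30.0];
    assert_eq!(
        broadcast_1d(&lat, 0, 3, 3),
        [10.0, 10.0, 10.0, 20.0, 20.0, 20.0, 30.0, 30.0, 30.0]
    );
}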
Cargo.toml (2 changes: 1 addition & 1 deletion)
@@ -7,7 +7,7 @@ authors = ["Maxime Dion <[email protected]>"]
 license = "Apache-2.0"
 keywords = ["arrow"]
 edition = "2021"
-rust-version = "1.62"
+rust-version = "1.64"
 
 [dependencies]
 async-trait = { version = "0.1.53" }
src/async_reader/mod.rs (233 changes: 109 additions & 124 deletions)
Expand Up @@ -92,7 +92,7 @@ use async_trait::async_trait;
use futures::stream::{BoxStream, Stream};
use futures::{ready, FutureExt};
use futures_util::future::BoxFuture;
use itertools::Itertools;
use std::collections::HashMap;
use std::pin::Pin;
use std::task::{Context, Poll};

@@ -118,8 +118,7 @@ pub struct ZarrStoreAsync<T: for<'a> ZarrReadAsync<'a>> {
     projection: ZarrProjection,
     curr_chunk: usize,
     io_uring_worker_pool: WorkerPool,
-    n_pre_read_chunks_left: Option<usize>,
-    pre_read_chunks: Vec<Option<ZarrResult<ZarrInMemoryChunk>>>,
+    broadcastable_array_axes: HashMap<String, Option<usize>>,
 }
 
 impl<T: for<'a> ZarrReadAsync<'a>> ZarrStoreAsync<T> {
@@ -129,28 +128,24 @@ impl<T: for<'a> ZarrReadAsync<'a>> ZarrStoreAsync<T> {
         projection: ZarrProjection,
     ) -> ZarrResult<Self> {
         let meta = zarr_reader.get_zarr_metadata().await?;
+        let mut bdc_axes: HashMap<String, Option<usize>> = HashMap::new();
+        for col in meta.get_columns() {
+            let mut axis = None;
+            if let Some(params) = meta.get_array_meta(col)?.get_ond_d_array_params() {
+                axis = Some(params.1);
+            }
+            bdc_axes.insert(col.to_string(), axis);
+        }
         Ok(Self {
             meta,
             chunk_positions,
             zarr_reader,
             projection,
             curr_chunk: 0,
             io_uring_worker_pool: WorkerPool::new(_IO_URING_SIZE, _IO_URING_N_WORKERS)?,
-            n_pre_read_chunks_left: None,
-            pre_read_chunks: Vec::new(),
+            broadcastable_array_axes: bdc_axes,
         })
     }
-
-    // Not used for now, but could enable it later. This would basically
-    // be to read multiple chunks at once, cache them, and then return
-    // one at a time as an iterator.
-    fn _with_pre_reads(mut self, n_pre_reads: usize) -> Self {
-        self.n_pre_read_chunks_left = Some(0);
-        for _ in 0..n_pre_reads {
-            self.pre_read_chunks.push(None);
-        }
-        self
-    }
 }
 
 /// A trait exposing a method to asynchronously get zarr chunk data, but also to
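The constructor above records, for each column, Some(axis) when the metadata marks the column as a broadcastable 1D array (in this diff, the second element of the value returned by get_ond_d_array_params is used as the broadcast axis) and None for a regular array; the map is then passed to get_zarr_chunk on every read. A small sketch of how such a map separates the two cases, with a hypothetical needs_broadcast helper that is not part of this commit:

use std::collections::HashMap;

// Illustration only: given the per-column map, report whether a column's
// chunk data needs to be broadcast, and along which axis.
fn needs_broadcast(axes: &HashMap<String, Option<usize>>, col: &str) -> Option<usize> {
    // a None entry (or an unknown column) means a regular, full-dimension array
    axes.get(col).copied().flatten()
}

fn main() {
    let mut axes: HashMap<String, Option<usize>> = HashMap::new();
    axes.insert("lat".to_string(), Some(0)); // 1D array, broadcast along axis 0
    axes.insert("float_data".to_string(), None); // regular array
    assert_eq!(needs_broadcast(&axes, "lat"), Some(0));
    assert_eq!(needs_broadcast(&axes, "float_data"), None);
}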
@@ -176,62 +171,18 @@ where
         let cols = self.projection.apply_selection(self.meta.get_columns());
         let cols = unwrap_or_return!(cols);
 
-        if let Some(n) = self.n_pre_read_chunks_left {
-            if n == 0 {
-                let final_idx = std::cmp::min(
-                    self.curr_chunk + self.pre_read_chunks.len(),
-                    self.chunk_positions.len(),
-                );
-                let positions = self.chunk_positions[self.curr_chunk..final_idx].to_vec();
-                let real_dims = positions
-                    .iter()
-                    .map(|pos| self.meta.get_real_dims(pos))
-                    .collect_vec();
-                let chunks = self
-                    .zarr_reader
-                    .get_zarr_chunks(
-                        positions,
-                        &cols,
-                        real_dims,
-                        self.meta.get_chunk_patterns(),
-                        &mut self.io_uring_worker_pool,
-                    )
-                    .await;
-                if let Ok(chunks) = chunks {
-                    self.pre_read_chunks = chunks.into_iter().map(|c| Some(Ok(c))).collect();
-                    self.n_pre_read_chunks_left = Some(final_idx - self.curr_chunk);
-                } else {
-                    return Some(Err(ZarrError::Read(
-                        "could not read batch of chunks".to_string(),
-                    )));
-                }
-            }
-        }
-
-        let chnk = match self.n_pre_read_chunks_left {
-            Some(n) => {
-                if n == 0 {
-                    panic!("unexpected condition in async reader with pre reads");
-                }
-                let pre_read_idx = self.pre_read_chunks.len() - n;
-                self.n_pre_read_chunks_left = Some(n - 1);
-                self.pre_read_chunks[pre_read_idx]
-                    .take()
-                    .expect("unexpected condition in async reader with pre reads")
-            }
-            None => {
-                let pos = &self.chunk_positions[self.curr_chunk];
-                self.zarr_reader
-                    .get_zarr_chunk(
-                        pos,
-                        &cols,
-                        self.meta.get_real_dims(pos),
-                        self.meta.get_chunk_patterns(),
-                        &mut self.io_uring_worker_pool,
-                    )
-                    .await
-            }
-        };
+        let pos = &self.chunk_positions[self.curr_chunk];
+        let chnk = self
+            .zarr_reader
+            .get_zarr_chunk(
+                pos,
+                &cols,
+                self.meta.get_real_dims(pos),
+                self.meta.get_chunk_patterns(),
+                &mut self.io_uring_worker_pool,
+                &self.broadcastable_array_axes,
+            )
+            .await;
 
         self.curr_chunk += 1;
         Some(chnk)
@@ -767,19 +718,17 @@ mod zarr_async_reader_tests {
     use itertools::enumerate;
     use object_store::{local::LocalFileSystem, path::Path};
     use std::sync::Arc;
-    use std::{collections::HashMap, fmt::Debug, path::PathBuf};
+    use std::{collections::HashMap, fmt::Debug};
 
     use super::*;
    use crate::async_reader::zarr_read_async::ZarrPath;
     use crate::reader::{ZarrArrowPredicate, ZarrArrowPredicateFn};
+    use crate::tests::{get_test_v2_data_path, get_test_v3_data_path};
 
-    fn get_v2_test_data_path(zarr_store: String) -> ZarrPath {
-        let p = PathBuf::from(env!("CARGO_MANIFEST_DIR"))
-            .join("test-data/data/zarr/v2_data")
-            .join(zarr_store);
+    fn get_v2_test_zarr_path(zarr_store: String) -> ZarrPath {
         ZarrPath::new(
             Arc::new(LocalFileSystem::new()),
-            Path::from_absolute_path(p).unwrap(),
+            Path::from_absolute_path(get_test_v2_data_path(zarr_store)).unwrap(),
         )
     }
 
@@ -826,39 +775,8 @@
         assert!(matched);
     }
 
-    #[tokio::test]
-    async fn projection_tests() {
-        let zp = get_v2_test_data_path("compression_example.zarr".to_string());
-        let proj = ZarrProjection::keep(vec!["bool_data".to_string(), "int_data".to_string()]);
-        let stream_builder = ZarrRecordBatchStreamBuilder::new(zp).with_projection(proj);
-
-        let stream = stream_builder.build().await.unwrap();
-        let records: Vec<_> = stream.try_collect().await.unwrap();
-
-        let target_types = HashMap::from([
-            ("bool_data".to_string(), DataType::Boolean),
-            ("int_data".to_string(), DataType::Int64),
-        ]);
-
-        // center chunk
-        let rec = &records[4];
-        validate_names_and_types(&target_types, rec);
-        validate_bool_column(
-            "bool_data",
-            rec,
-            &[false, true, false, false, true, false, false, true, false],
-        );
-        validate_primitive_column::<Int64Type, i64>(
-            "int_data",
-            rec,
-            &[-4, -3, -2, 4, 5, 6, 12, 13, 14],
-        );
-    }
-
-    #[tokio::test]
-    async fn filters_tests() {
-        // set the filters to select part of the raster, based on lat and
-        // lon coordinates.
+    // create a test filter
+    fn create_filter() -> ZarrChunkFilter {
         let mut filters: Vec<Box<dyn ZarrArrowPredicate>> = Vec::new();
         let f = ZarrArrowPredicateFn::new(
             ZarrProjection::keep(vec!["lat".to_string()]),
@@ -891,9 +809,42 @@
         );
         filters.push(Box::new(f));
 
-        let zp = get_v2_test_data_path("lat_lon_example.zarr".to_string());
-        let stream_builder =
-            ZarrRecordBatchStreamBuilder::new(zp).with_filter(ZarrChunkFilter::new(filters));
+        ZarrChunkFilter::new(filters)
+    }
+
+    #[tokio::test]
+    async fn projection_tests() {
+        let zp = get_v2_test_zarr_path("compression_example.zarr".to_string());
+        let proj = ZarrProjection::keep(vec!["bool_data".to_string(), "int_data".to_string()]);
+        let stream_builder = ZarrRecordBatchStreamBuilder::new(zp).with_projection(proj);
+
+        let stream = stream_builder.build().await.unwrap();
+        let records: Vec<_> = stream.try_collect().await.unwrap();
+
+        let target_types = HashMap::from([
+            ("bool_data".to_string(), DataType::Boolean),
+            ("int_data".to_string(), DataType::Int64),
+        ]);
+
+        // center chunk
+        let rec = &records[4];
+        validate_names_and_types(&target_types, rec);
+        validate_bool_column(
+            "bool_data",
+            rec,
+            &[false, true, false, false, true, false, false, true, false],
+        );
+        validate_primitive_column::<Int64Type, i64>(
+            "int_data",
+            rec,
+            &[-4, -3, -2, 4, 5, 6, 12, 13, 14],
+        );
+    }
+
+    #[tokio::test]
+    async fn filters_tests() {
+        let zp = get_v2_test_zarr_path("lat_lon_example.zarr".to_string());
+        let stream_builder = ZarrRecordBatchStreamBuilder::new(zp).with_filter(create_filter());
         let stream = stream_builder.build().await.unwrap();
         let records: Vec<_> = stream.try_collect().await.unwrap();
 
@@ -938,7 +889,7 @@
 
     #[tokio::test]
     async fn multiple_readers_tests() {
-        let zp = get_v2_test_data_path("compression_example.zarr".to_string());
+        let zp = get_v2_test_zarr_path("compression_example.zarr".to_string());
         let stream1 = ZarrRecordBatchStreamBuilder::new(zp.clone())
             .build_partial_reader(Some((0, 5)))
             .await
@@ -1008,7 +959,7 @@
 
     #[tokio::test]
     async fn empty_query_tests() {
-        let zp = get_v2_test_data_path("lat_lon_example.zarr".to_string());
+        let zp = get_v2_test_zarr_path("lat_lon_example.zarr".to_string());
         let mut builder = ZarrRecordBatchStreamBuilder::new(zp);
 
         // set a filter that will filter out all the data, there should be nothing left after
@@ -1033,19 +984,53 @@
         assert_eq!(records.len(), 0);
     }
 
-    fn get_v3_test_data_path(zarr_store: String) -> ZarrPath {
-        let p = PathBuf::from(env!("CARGO_MANIFEST_DIR"))
-            .join("test-data/data/zarr/v3_data")
-            .join(zarr_store);
+    #[tokio::test]
+    async fn array_broadcast_tests() {
+        // reference that doesn't broadcast a 1D array
+        let zp = get_v2_test_zarr_path("lat_lon_example.zarr".to_string());
+        let mut builder = ZarrRecordBatchStreamBuilder::new(zp);
+
+        builder = builder.with_filter(create_filter());
+        let stream = builder.build().await.unwrap();
+        let records: Vec<_> = stream.try_collect().await.unwrap();
+
+        // v2 format with array broadcast
+        let zp = get_v2_test_zarr_path("lat_lon_example_broadcastable.zarr".to_string());
+        let mut builder = ZarrRecordBatchStreamBuilder::new(zp);
+
+        builder = builder.with_filter(create_filter());
+        let stream = builder.build().await.unwrap();
+        let records_from_one_d_repr: Vec<_> = stream.try_collect().await.unwrap();
+
+        assert_eq!(records_from_one_d_repr.len(), records.len());
+        for (rec, rec_from_one_d_repr) in records.iter().zip(records_from_one_d_repr.iter()) {
+            assert_eq!(rec, rec_from_one_d_repr);
+        }
+
+        // v3 format with array broadcast
+        let zp = get_v3_test_zarr_path("with_broadcastable_array.zarr".to_string());
+        let mut builder = ZarrRecordBatchStreamBuilder::new(zp);
+
+        builder = builder.with_filter(create_filter());
+        let stream = builder.build().await.unwrap();
+        let records_from_one_d_repr: Vec<_> = stream.try_collect().await.unwrap();
+
+        assert_eq!(records_from_one_d_repr.len(), records.len());
+        for (rec, rec_from_one_d_repr) in records.iter().zip(records_from_one_d_repr.iter()) {
+            assert_eq!(rec, rec_from_one_d_repr);
+        }
+    }
+
+    fn get_v3_test_zarr_path(zarr_store: String) -> ZarrPath {
         ZarrPath::new(
             Arc::new(LocalFileSystem::new()),
-            Path::from_absolute_path(p).unwrap(),
+            Path::from_absolute_path(get_test_v3_data_path(zarr_store)).unwrap(),
         )
     }
 
     #[tokio::test]
     async fn with_sharding_tests() {
-        let zp = get_v3_test_data_path("with_sharding.zarr".to_string());
+        let zp = get_v3_test_zarr_path("with_sharding.zarr".to_string());
         let stream_builder = ZarrRecordBatchStreamBuilder::new(zp);
 
         let stream = stream_builder.build().await.unwrap();
@@ -1077,7 +1062,7 @@
 
     #[tokio::test]
     async fn three_dims_with_sharding_with_edge_tests() {
-        let zp = get_v3_test_data_path("with_sharding_with_edge_3d.zarr".to_string());
+        let zp = get_v3_test_zarr_path("with_sharding_with_edge_3d.zarr".to_string());
         let stream_builder = ZarrRecordBatchStreamBuilder::new(zp);
 
         let stream = stream_builder.build().await.unwrap();
@@ -1101,7 +1086,7 @@
 
     #[tokio::test]
     async fn no_sharding_tests() {
-        let zp = get_v3_test_data_path("no_sharding.zarr".to_string());
+        let zp = get_v3_test_zarr_path("no_sharding.zarr".to_string());
         let stream_builder = ZarrRecordBatchStreamBuilder::new(zp);
 
         let stream = stream_builder.build().await.unwrap();
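Reading a store that contains a broadcastable 1D array goes through the same public API exercised by the tests in this diff; a minimal sketch mirroring array_broadcast_tests, reusing its helper and store name:

// Sketch only, mirroring array_broadcast_tests above; assumes the same
// test data and the get_v2_test_zarr_path helper from this diff.
#[tokio::test]
async fn read_broadcastable_store() {
    let zp = get_v2_test_zarr_path("lat_lon_example_broadcastable.zarr".to_string());
    let stream = ZarrRecordBatchStreamBuilder::new(zp).build().await.unwrap();
    // 1D columns come back broadcast to the full chunk shape, so the record
    // batches match the ones from the non-broadcast representation
    let records: Vec<_> = stream.try_collect().await.unwrap();
    assert!(!records.is_empty());
}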
(Diffs for the remaining changed files were not loaded on this page.)