diff --git a/src/datafusion/file_opener.rs b/src/datafusion/file_opener.rs index 8faec30..a8fb99a 100644 --- a/src/datafusion/file_opener.rs +++ b/src/datafusion/file_opener.rs @@ -52,17 +52,17 @@ impl FileOpener for ZarrFileOpener { let zarr_path = ZarrPath::new(config.object_store, file_meta.object_meta.location); let rng = file_meta.range.map(|r| (r.start as usize, r.end as usize)); let projection = ZarrProjection::from(config.projection.as_ref()); - let mut batch_reader_builder = ZarrRecordBatchStreamBuilder::new(zarr_path.clone()).with_projection(projection); if let Some(filters) = filters_to_pushdown { - let schema = zarr_path + let file_schema = zarr_path .get_zarr_metadata() .await .map_err(|e| DataFusionError::External(Box::new(e)))? .arrow_schema() .map_err(|e| DataFusionError::External(Box::new(e)))?; - let filters = build_row_filter(&filters, &schema)?; + let filters = build_row_filter(&filters, &file_schema)?; + if let Some(filters) = filters { batch_reader_builder = batch_reader_builder.with_filter(filters); } @@ -71,9 +71,7 @@ impl FileOpener for ZarrFileOpener { .build_partial_reader(rng) .await .map_err(|e| DataFusionError::External(Box::new(e)))?; - let stream = batch_reader.map_err(|e| ArrowError::from_external_error(Box::new(e))); - Ok(stream.boxed()) })) } diff --git a/src/datafusion/helpers.rs b/src/datafusion/helpers.rs index e30f991..0bebc30 100644 --- a/src/datafusion/helpers.rs +++ b/src/datafusion/helpers.rs @@ -16,18 +16,30 @@ // under the License. 
use crate::reader::{ZarrArrowPredicate, ZarrChunkFilter, ZarrProjection}; -use arrow::array::BooleanArray; +use arrow::array::{ArrayRef, BooleanArray, StringBuilder}; +use arrow::compute::{and, cast, prep_null_mask_filter}; +use arrow::datatypes::{DataType, Field}; use arrow::error::ArrowError; use arrow::record_batch::RecordBatch; -use arrow_schema::Schema; +use arrow_array::cast::AsArray; +use arrow_array::Array; +use arrow_schema::{Fields, Schema}; +use datafusion::datasource::listing::{ListingTableUrl, PartitionedFile}; use datafusion_common::cast::as_boolean_array; +use datafusion_common::scalar::ScalarValue; use datafusion_common::tree_node::{RewriteRecursion, TreeNode, TreeNodeRewriter, VisitRecursion}; use datafusion_common::Result as DataFusionResult; -use datafusion_common::{internal_err, DataFusionError}; +use datafusion_common::{internal_err, DFField, DFSchema, DataFusionError}; use datafusion_expr::{Expr, ScalarFunctionDefinition, Volatility}; +use datafusion_physical_expr::create_physical_expr; +use datafusion_physical_expr::execution_props::ExecutionProps; use datafusion_physical_expr::expressions::Column; use datafusion_physical_expr::utils::reassign_predicate_columns; use datafusion_physical_expr::{split_conjunction, PhysicalExpr}; +use futures::stream::FuturesUnordered; +use futures::stream::{self, BoxStream, StreamExt}; +use object_store::path::DELIMITER; +use object_store::{path::Path, ObjectStore}; use std::collections::BTreeSet; use std::sync::Arc; @@ -117,6 +129,7 @@ struct ZarrFilterCandidateBuilder<'a> { expr: Arc, file_schema: &'a Schema, required_column_indices: BTreeSet, + projected_columns: bool, } impl<'a> ZarrFilterCandidateBuilder<'a> { @@ -125,12 +138,19 @@ impl<'a> ZarrFilterCandidateBuilder<'a> { expr, file_schema, required_column_indices: BTreeSet::default(), + projected_columns: false, } } pub fn build(mut self) -> DataFusionResult> { let expr = self.expr.clone().rewrite(&mut self)?; + // if we are dealing with a projected 
column, which here means it's + // a partitioned column, we don't produce a filter for it. + if self.projected_columns { + return Ok(None); + } + Ok(Some(ZarrFilterCandidate { expr, projection: self.required_column_indices.into_iter().collect(), @@ -145,6 +165,14 @@ impl<'a> TreeNodeRewriter for ZarrFilterCandidateBuilder<'a> { if let Some(column) = node.as_any().downcast_ref::() { if let Ok(idx) = self.file_schema.index_of(column.name()) { self.required_column_indices.insert(idx); + } else { + // set the flag if we detect that the column is not in the file schema. for the + // zarr implementation, this would mean that the column is actually a partitioned + // column, and we shouldn't be pushing down a filter for it. + // TODO handle cases where a filter contains a column that doesn't exist (not even + // as a partition). + self.projected_columns = true; + return Ok(RewriteRecursion::Stop); } } @@ -243,3 +271,337 @@ pub(crate) fn build_row_filter( Ok(Some(chunk_filter)) } } + +// Below is all the logic related to hive-style partitioning, mostly copied and +// slightly modified from datafusion. 
+const CONCURRENCY_LIMIT: usize = 100; +const MAX_PARTITION_DEPTH: usize = 64; + +#[derive(Debug, Clone, PartialEq, Eq, PartialOrd, Ord)] +struct Partition { + /// The path to the partition, including the table prefix + path: Path, + /// How many path segments below the table prefix `path` contains + /// or equivalently the number of partition values in `path` + depth: usize, +} + +impl Partition { + /// List the direct children of this partition, returning the partition itself + /// together with the list of its child "directories" (the common prefixes) + async fn list(self, store: &dyn ObjectStore) -> DataFusionResult<(Self, Vec)> { + let prefix = Some(&self.path).filter(|p| !p.as_ref().is_empty()); + let result = store.list_with_delimiter(prefix).await?; + Ok((self, result.common_prefixes)) + } +} + +async fn list_partitions( + store: &dyn ObjectStore, + table_path: &ListingTableUrl, + max_depth: usize, +) -> DataFusionResult> { + let partition = Partition { + path: table_path.prefix().clone(), + depth: 0, + }; + + let mut final_partitions = Vec::with_capacity(MAX_PARTITION_DEPTH); + let mut pending = vec![]; + let mut futures = FuturesUnordered::new(); + futures.push(partition.list(store)); + + while let Some((partition, paths)) = futures.next().await.transpose()? { + // If pending contains a future it implies prior to this iteration + // `futures.len == CONCURRENCY_LIMIT`. We can therefore add a single + // future from `pending` to the working set + if let Some(next) = pending.pop() { + futures.push(next) + } + + let depth = partition.depth; + if depth == max_depth { + final_partitions.push(partition); + } + for path in paths { + let child = Partition { + path, + depth: depth + 1, + }; + // if we have reached the max depth, we don't need to list all the + // directories under the last partition, those will all be zarr arrays, + // and the last partition itself will be a store to be read atomically. 
+ if depth < max_depth { + if futures.len() < CONCURRENCY_LIMIT { + futures.push(child.list(store)); + } else { + pending.push(child.list(store)) + } + } + } + } + Ok(final_partitions) +} + +fn parse_partitions_for_path( + table_path: &ListingTableUrl, + file_path: &Path, + table_partition_cols: Vec<&str>, +) -> Option> { + let mut stripped = file_path + .as_ref() + .strip_prefix(table_path.prefix().as_ref())?; + if !stripped.is_empty() && !table_path.prefix().as_ref().is_empty() { + stripped = stripped.strip_prefix(DELIMITER)?; + } + let subpath = stripped.split_terminator(DELIMITER).map(|s| s.to_string()); + + let mut part_values = vec![]; + for (part, pn) in subpath.zip(table_partition_cols) { + match part.split_once('=') { + Some((name, val)) if name == pn => part_values.push(val.to_string()), + _ => { + return None; + } + } + } + Some(part_values) +} + +async fn prune_partitions( + table_path: &ListingTableUrl, + partitions: Vec, + filters: &[Expr], + partition_cols: &[(String, DataType)], +) -> DataFusionResult> { + if filters.is_empty() { + return Ok(partitions); + } + + let mut builders: Vec<_> = (0..partition_cols.len()) + .map(|_| StringBuilder::with_capacity(partitions.len(), partitions.len() * 10)) + .collect(); + + for partition in &partitions { + let cols = partition_cols.iter().map(|x| x.0.as_str()).collect(); + let parsed = + parse_partitions_for_path(table_path, &partition.path, cols).unwrap_or_default(); + + let mut builders = builders.iter_mut(); + for (p, b) in parsed.iter().zip(&mut builders) { + b.append_value(p); + } + builders.for_each(|b| b.append_null()); + } + + let arrays = partition_cols + .iter() + .zip(builders) + .map(|((_, d), mut builder)| { + let array = builder.finish(); + cast(&array, d) + }) + .collect::>()?; + + let fields: Fields = partition_cols + .iter() + .map(|(n, d)| Field::new(n, d.clone(), true)) + .collect(); + let schema = Arc::new(Schema::new(fields)); + + let df_schema = DFSchema::new_with_metadata( + 
partition_cols + .iter() + .map(|(n, d)| DFField::new_unqualified(n, d.clone(), true)) + .collect(), + Default::default(), + )?; + + let batch = RecordBatch::try_new(schema.clone(), arrays)?; + + // TODO: Plumb this down + let props = ExecutionProps::new(); + + // Applies `filter` to `batch` returning `None` on error + let do_filter = |filter| -> Option { + let expr = create_physical_expr(filter, &df_schema, &props).ok()?; + expr.evaluate(&batch) + .ok()? + .into_array(partitions.len()) + .ok() + }; + + // Compute the conjunction of the filters, ignoring errors + let mask = filters + .iter() + .fold(None, |acc, filter| match (acc, do_filter(filter)) { + (Some(a), Some(b)) => Some(and(&a, b.as_boolean()).unwrap_or(a)), + (None, Some(r)) => Some(r.as_boolean().clone()), + (r, None) => r, + }); + + let mask = match mask { + Some(mask) => mask, + None => return Ok(partitions), + }; + + // Don't retain partitions that evaluated to null + let prepared = match mask.null_count() { + 0 => mask, + _ => prep_null_mask_filter(&mask), + }; + + let filtered = partitions + .into_iter() + .zip(prepared.values()) + .filter_map(|(p, f)| f.then_some(p)) + .collect(); + + Ok(filtered) +} + +pub async fn pruned_partition_list<'a>( + store: &'a dyn ObjectStore, + table_path: &'a ListingTableUrl, + filters: &'a [Expr], + partition_cols: &'a [(String, DataType)], +) -> DataFusionResult>> { + // if no partition col => simply return the table path + if partition_cols.is_empty() { + let pf = PartitionedFile::new(table_path.prefix().clone(), 0); + return Ok(Box::pin(stream::iter(vec![Ok(pf)]))); + } + + let partitions = list_partitions(store, table_path, partition_cols.len()).await?; + let pruned = prune_partitions(table_path, partitions, filters, partition_cols).await?; + + let stream = futures::stream::iter(pruned) + .map(move |partition: Partition| async move { + let cols = partition_cols.iter().map(|x| x.0.as_str()).collect(); + let parsed = parse_partitions_for_path(table_path, 
&partition.path, cols); + + let partition_values = parsed + .into_iter() + .flatten() + .zip(partition_cols) + .map(|(parsed, (_, datatype))| { + ScalarValue::try_from_string(parsed.to_string(), datatype) + }) + .collect::>>()?; + + let mut pf = PartitionedFile::new(partition.path, 0); + pf.partition_values.clone_from(&partition_values); + + Ok(pf) + }) + .buffer_unordered(CONCURRENCY_LIMIT) + .boxed(); + + Ok(stream) +} + +// copied from datafusion +pub fn split_files( + mut partitioned_files: Vec, + n: usize, +) -> Vec> { + if partitioned_files.is_empty() { + return vec![]; + } + + // ObjectStore::list does not guarantee any consistent order and for some + // implementations such as LocalFileSystem, it may be inconsistent. Thus + // Sort files by path to ensure consistent plans when run more than once. + partitioned_files.sort_by(|a, b| a.path().cmp(b.path())); + + // effectively this is div with rounding up instead of truncating + let chunk_size = (partitioned_files.len() + n - 1) / n; + partitioned_files + .chunks(chunk_size) + .map(|c| c.to_vec()) + .collect() +} + +#[cfg(test)] +mod helpers_tests { + use super::*; + use crate::tests::get_test_v2_data_path; + use datafusion_expr::{and, col, lit}; + use itertools::Itertools; + use object_store::local::LocalFileSystem; + + #[tokio::test] + async fn test_listing_and_pruning_partitions() { + let table_path = get_test_v2_data_path("lat_lon_w_groups_example.zarr".to_string()) + .to_str() + .unwrap() + .to_string(); + + let store = LocalFileSystem::new(); + let url = ListingTableUrl::parse(&table_path).unwrap(); + let partitions = list_partitions(&store, &url, 2).await.unwrap(); + + let expr1 = col("var").eq(lit(1_i32)); + let expr2 = col("other_var").eq(lit::("b".to_string())); + let partition_cols = [ + ("var".to_string(), DataType::Int32), + ("other_var".to_string(), DataType::Utf8), + ]; + + let part_1a = Partition { + path: Path::parse(&table_path) + .unwrap() + .child("var=1") + .child("other_var=a"), + depth: 
2, + }; + let part_1b = Partition { + path: Path::parse(&table_path) + .unwrap() + .child("var=1") + .child("other_var=b"), + depth: 2, + }; + let part_2b = Partition { + path: Path::parse(table_path) + .unwrap() + .child("var=2") + .child("other_var=b"), + depth: 2, + }; + + let filters = [expr1.clone()]; + let pruned = prune_partitions(&url, partitions.clone(), &filters, &partition_cols) + .await + .unwrap(); + assert_eq!( + pruned.into_iter().sorted().collect::>(), + vec![part_1a.clone(), part_1b.clone()] + .into_iter() + .sorted() + .collect::>(), + ); + + let filters = [expr2.clone()]; + let pruned = prune_partitions(&url, partitions.clone(), &filters, &partition_cols) + .await + .unwrap(); + assert_eq!( + pruned.into_iter().sorted().collect::>(), + vec![part_1b.clone(), part_2b.clone()] + .into_iter() + .sorted() + .collect::>(), + ); + + let expr = and(expr1, expr2); + let filters = [expr]; + let pruned = prune_partitions(&url, partitions.clone(), &filters, &partition_cols) + .await + .unwrap(); + assert_eq!( + pruned.into_iter().sorted().collect::>(), + vec![part_1b].into_iter().sorted().collect::>(), + ); + } +} diff --git a/src/datafusion/scanner.rs b/src/datafusion/scanner.rs index a4356f3..cb3bfa4 100644 --- a/src/datafusion/scanner.rs +++ b/src/datafusion/scanner.rs @@ -102,8 +102,18 @@ impl ExecutionPlan for ZarrScan { .runtime_env() .object_store(&self.base_config.object_store_url)?; - let config = - ZarrConfig::new(object_store).with_projection(self.base_config.projection.clone()); + // This is just replicating the `file_column_projection_indices` method on + // `FileScanConfig`, which is only pub within the datafusion crate. We need + // to remove column indices that correspond to partitions, since we can't + // pass those to the zarr reader. 
+ let projection = self.base_config.projection.as_ref().map(|p| { + p.iter() + .filter(|col_idx| **col_idx < self.base_config.file_schema.fields().len()) + .copied() + .collect() + }); + + let config = ZarrConfig::new(object_store).with_projection(projection); let opener = ZarrFileOpener::new(config, self.filters.clone()); let stream = FileStream::new(&self.base_config, partition, opener, &self.metrics)?; diff --git a/src/datafusion/table_factory.rs b/src/datafusion/table_factory.rs index 5259668..55f4db3 100644 --- a/src/datafusion/table_factory.rs +++ b/src/datafusion/table_factory.rs @@ -17,8 +17,11 @@ use std::sync::Arc; +use arrow::datatypes::{DataType, Schema}; +use arrow_schema::Field; use async_trait::async_trait; use datafusion::{ + common::arrow_datafusion_err, datasource::{listing::ListingTableUrl, provider::TableProviderFactory, TableProvider}, error::DataFusionError, execution::context::SessionState, @@ -42,16 +45,73 @@ impl TableProviderFactory for ZarrListingTableFactory { )); } - let table_path = ListingTableUrl::parse(&cmd.location)?; - - let options = ListingZarrTableOptions {}; - let schema = options - .infer_schema(state, &table_path) - .await - .map_err(|e| DataFusionError::Execution(format!("infer error: {:?}", e)))?; + // mostly copied over from datafusion. + let (provided_schema, table_partition_cols) = if cmd.schema.fields().is_empty() { + ( + None, + cmd.table_partition_cols + .iter() + .map(|x| { + ( + x.clone(), + DataType::Dictionary( + Box::new(DataType::UInt16), + Box::new(DataType::Utf8), + ), + ) + }) + .collect::>(), + ) + } else { + // this bit here is to ensure that the fields in the schema are alphabetically + // ordered. because a zarr store doesn't provide any ordering, we need some + // convention, and the schema needs to follow that convention here. 
+ let mut schema: Schema = cmd.schema.as_ref().into(); + let mut fields: Vec = Vec::new(); + for f in schema.fields() { + let test = Field::new(f.name(), f.data_type().clone(), false); + fields.push(test); + } + fields.sort_by(|f1, f2| f1.name().cmp(f2.name())); + schema = Schema::new(fields); + + let table_partition_cols = cmd + .table_partition_cols + .iter() + .map(|col| { + schema + .field_with_name(col) + .map_err(|e| arrow_datafusion_err!(e)) + }) + .collect::>>()? + .into_iter() + .map(|f| (f.name().to_owned(), f.data_type().to_owned())) + .collect(); + // exclude partition columns to support creating partitioned external table + // with a specified column definition like `create external table a(c0 int, c1 int) + // stored as csv partitioned by (c1)...` + let mut project_idx = Vec::new(); + for i in 0..schema.fields().len() { + if !cmd.table_partition_cols.contains(schema.field(i).name()) { + project_idx.push(i); + } + } + let schema = schema.project(&project_idx)?; + (Some(schema), table_partition_cols) + }; - let table_provider = - ZarrTableProvider::new(ListingZarrTableConfig::new(table_path), schema); + let table_path = ListingTableUrl::parse(&cmd.location)?; + let options = ListingZarrTableOptions::new().with_partition_cols(table_partition_cols); + let schema = match provided_schema { + None => options + .infer_schema(state, &table_path) + .await + .map_err(|e| DataFusionError::Execution(format!("infer error: {:?}", e)))?, + Some(s) => s, + }; + + let config = ListingZarrTableConfig::new(table_path, schema, options); + let table_provider = ZarrTableProvider::try_new(config)?; Ok(Arc::new(table_provider)) } @@ -61,8 +121,8 @@ impl TableProviderFactory for ZarrListingTableFactory { mod tests { use crate::tests::get_test_v2_data_path; use arrow::record_batch::RecordBatch; - use arrow_array::cast::AsArray; use arrow_array::types::*; + use arrow_array::{cast::AsArray, StringArray}; use arrow_buffer::ScalarBuffer; use datafusion::execution::{ @@ -84,6 
+144,14 @@ mod tests { .clone() } + fn extract_str_col(col_name: &str, rec_batch: &RecordBatch) -> StringArray { + rec_batch + .column_by_name(col_name) + .unwrap() + .as_string() + .to_owned() + } + #[tokio::test] async fn test_create() -> Result<(), Box> { let mut state = SessionState::new_with_config_rt( @@ -239,4 +307,122 @@ mod tests { Ok(()) } + + #[tokio::test] + async fn test_partitions() -> Result<(), Box> { + let mut state = SessionState::new_with_config_rt( + SessionConfig::default(), + Arc::new(RuntimeEnv::default()), + ); + + state + .table_factories_mut() + .insert("ZARR".into(), Arc::new(super::ZarrListingTableFactory {})); + + let test_data = get_test_v2_data_path("lat_lon_w_groups_example.zarr".to_string()); + + let sql = format!( + "CREATE EXTERNAL TABLE zarr_table ( + lat double, + lon double, + float_data double, + var int, + other_var string + ) + STORED AS ZARR LOCATION '{}' + PARTITIONED BY (var, other_var)", + test_data.display(), + ); + + let session = SessionContext::new_with_state(state); + session.sql(&sql).await?; + + // select a particular partition for each partitioned variable + let sql = "SELECT lat, lon, var, other_var FROM zarr_table + WHERE var=1 + AND other_var='b'"; + let df = session.sql(sql).await?; + + let batches = df.collect().await?; + for batch in batches { + let lat_values = extract_col::("lat", &batch); + let lon_values = extract_col::("lon", &batch); + let var_values = extract_col::("var", &batch); + let other_var_values = extract_str_col("other_var", &batch); + assert!(lat_values + .iter() + .zip(lon_values.iter()) + .all(|(lat, lon)| *lat >= 38.0 + && *lat <= 39.0 + && *lon >= -108.9 + && *lon <= -107.9)); + assert!(var_values.iter().all(|var| var == &1)); + assert!(other_var_values + .iter() + .all(|other_var| other_var.unwrap() == "b")); + } + + // select a different partition for each partitioned variable + let sql = "SELECT lat, lon, var, other_var FROM zarr_table + WHERE var=2 + AND other_var='a'"; + let df = 
session.sql(sql).await?; + + let batches = df.collect().await?; + for batch in batches { + let lat_values = extract_col::("lat", &batch); + let lon_values = extract_col::("lon", &batch); + let var_values = extract_col::("var", &batch); + let other_var_values = extract_str_col("other_var", &batch); + assert!(lat_values + .iter() + .zip(lon_values.iter()) + .all(|(lat, lon)| *lat >= 39.0 + && *lat <= 40.0 + && *lon >= -110.0 + && *lon <= -108.9)); + assert!(var_values.iter().all(|var| var == &2)); + assert!(other_var_values + .iter() + .all(|other_var| other_var.unwrap() == "a")); + } + + // select the same partition but without selection the partitioned variables + let sql = "SELECT lat, lon FROM zarr_table + WHERE var=2 + AND other_var='a'"; + let df = session.sql(sql).await?; + + let batches = df.collect().await?; + for batch in batches { + let lat_values = extract_col::("lat", &batch); + let lon_values = extract_col::("lon", &batch); + assert!(lat_values + .iter() + .zip(lon_values.iter()) + .all(|(lat, lon)| *lat >= 39.0 + && *lat <= 40.0 + && *lon >= -110.0 + && *lon <= -108.9)); + } + + // select a partition for only one of the partitioned variables + let sql = "SELECT lat, lon, var, other_var FROM zarr_table + WHERE var=1"; + let df = session.sql(sql).await?; + + let batches = df.collect().await?; + for batch in batches { + let lat_values = extract_col::("lat", &batch); + let var_values = extract_col::("var", &batch); + let other_var_values = extract_str_col("other_var", &batch); + assert!(lat_values.iter().all(|lat| *lat >= 38.0 && *lat <= 39.0)); + assert!(var_values.iter().all(|var| var == &1)); + assert!(other_var_values + .iter() + .all(|other_var| other_var.unwrap() == "a" || other_var.unwrap() == "b")); + } + + Ok(()) + } } diff --git a/src/datafusion/table_provider.rs b/src/datafusion/table_provider.rs index ecbc84f..edc4b0f 100644 --- a/src/datafusion/table_provider.rs +++ b/src/datafusion/table_provider.rs @@ -17,10 +17,11 @@ use std::sync::Arc; 
-use arrow_schema::{Schema, SchemaRef}; +use arrow::datatypes::DataType; +use arrow_schema::{Field, Schema, SchemaBuilder, SchemaRef}; use async_trait::async_trait; use datafusion::{ - common::{Statistics, ToDFSchema}, + common::{Result as DataFusionResult, Statistics, ToDFSchema}, datasource::{ listing::{ListingTableUrl, PartitionedFile}, physical_plan::FileScanConfig, @@ -31,55 +32,158 @@ use datafusion::{ physical_plan::ExecutionPlan, }; use datafusion_physical_expr::create_physical_expr; +use futures::StreamExt; use crate::{ async_reader::{ZarrPath, ZarrReadAsync}, - reader::ZarrResult, + reader::{ZarrError, ZarrResult}, }; -use super::helpers::expr_applicable_for_cols; +use super::helpers::{expr_applicable_for_cols, pruned_partition_list, split_files}; use super::scanner::ZarrScan; -pub struct ListingZarrTableOptions {} +pub struct ListingZarrTableOptions { + pub table_partition_cols: Vec<(String, DataType)>, + pub target_partitions: usize, +} + +impl Default for ListingZarrTableOptions { + fn default() -> Self { + Self::new() + } +} impl ListingZarrTableOptions { + pub fn new() -> Self { + Self { + table_partition_cols: vec![], + target_partitions: 1, + } + } + + pub fn with_partition_cols(mut self, table_partition_cols: Vec<(String, DataType)>) -> Self { + self.table_partition_cols = table_partition_cols; + self + } + + pub fn with_target_partitions(mut self, target_partitions: usize) -> Self { + self.target_partitions = target_partitions; + self + } + pub async fn infer_schema( &self, state: &SessionState, table_path: &ListingTableUrl, ) -> ZarrResult { let store = state.runtime_env().object_store(table_path)?; + let prefix = table_path.prefix(); - let zarr_path = ZarrPath::new(store, table_path.prefix().clone()); - let schema = zarr_path.get_zarr_metadata().await?.arrow_schema()?; + let n_partitions = self.table_partition_cols.len(); + let mut files = table_path.list_all_files(state, &store, "zgroup").await?; + let mut schema_to_return: Option = None; + 
while let Some(file) = files.next().await { + let mut p = prefix.clone(); + let file = file?.location; + for (cnt, part) in file.prefix_match(prefix).unwrap().enumerate() { + if cnt == n_partitions { + if let Some(ext) = file.extension() { + if ext == "zgroup" { + let schema = ZarrPath::new(store.clone(), p.clone()) + .get_zarr_metadata() + .await? + .arrow_schema()?; + if let Some(sch) = &schema_to_return { + if sch != &schema { + return Err(ZarrError::InvalidMetadata( + "mismatch between different partition schemas".into(), + )); + } + } else { + schema_to_return = Some(schema); + } + } + } + } + p = p.child(part); + } + } - Ok(schema) + if let Some(schema_to_return) = schema_to_return { + return Ok(schema_to_return); + } + Err(ZarrError::InvalidMetadata( + "could not infer schema for zarr table path".into(), + )) } } pub struct ListingZarrTableConfig { - /// The inner listing table configuration table_path: ListingTableUrl, + pub file_schema: Schema, + pub options: ListingZarrTableOptions, } impl ListingZarrTableConfig { /// Create a new ListingZarrTableConfig - pub fn new(table_path: ListingTableUrl) -> Self { - Self { table_path } + pub fn new( + table_path: ListingTableUrl, + file_schema: Schema, + options: ListingZarrTableOptions, + ) -> Self { + Self { + table_path, + file_schema, + options, + } } } pub struct ZarrTableProvider { + // the distinction between the file schema and the table schema is + // that the latter could include partitioned columns. 
+ file_schema: Schema, table_schema: Schema, - config: ListingZarrTableConfig, + table_path: ListingTableUrl, + options: ListingZarrTableOptions, } impl ZarrTableProvider { - pub fn new(config: ListingZarrTableConfig, table_schema: Schema) -> Self { - Self { + pub fn try_new(config: ListingZarrTableConfig) -> DataFusionResult { + let mut builder = SchemaBuilder::from(config.file_schema.clone()); + for (part_col_name, part_col_type) in &config.options.table_partition_cols { + builder.push(Field::new(part_col_name, part_col_type.clone(), false)); + } + let table_schema = builder.finish(); + + Ok(Self { + file_schema: config.file_schema, table_schema, - config, + table_path: config.table_path, + options: config.options, + }) + } + + async fn list_stores_for_scan<'a>( + &'a self, + ctx: &'a SessionState, + filters: &'a [Expr], + ) -> datafusion::error::Result>> { + let store = ctx.runtime_env().object_store(&self.table_path)?; + let mut partition_stream = pruned_partition_list( + store.as_ref(), + &self.table_path, + filters, + &self.options.table_partition_cols, + ) + .await?; + + let mut partition_list = vec![]; + while let Some(partition) = partition_stream.next().await { + partition_list.push(partition?); } + + Ok(split_files(partition_list, self.options.target_partitions)) } } @@ -101,22 +205,21 @@ impl TableProvider for ZarrTableProvider { &self, filters: &[&Expr], ) -> datafusion::error::Result> { - // TODO handle predicates on partition columns as Exact. 
Ok(filters .iter() .map(|filter| { if expr_applicable_for_cols( &self - .table_schema - .fields + .options + .table_partition_cols .iter() - .map(|field| field.name().to_string()) + .map(|x| x.0.clone()) .collect::>(), filter, ) { - TableProviderFilterPushDown::Inexact + TableProviderFilterPushDown::Exact } else { - TableProviderFilterPushDown::Unsupported + TableProviderFilterPushDown::Inexact } }) .collect()) @@ -129,10 +232,9 @@ impl TableProvider for ZarrTableProvider { filters: &[Expr], limit: Option, ) -> datafusion::error::Result> { - let object_store_url = self.config.table_path.object_store(); + let object_store_url = self.table_path.object_store(); - let pf = PartitionedFile::new(self.config.table_path.prefix().clone(), 0); - let file_groups = vec![vec![pf]]; + let file_groups = self.list_stores_for_scan(state, filters).await?; let filters = if let Some(expr) = conjunction(filters.to_vec()) { let table_df_schema = self.table_schema.clone().to_dfschema()?; @@ -142,19 +244,25 @@ impl TableProvider for ZarrTableProvider { None }; + let table_partition_cols = self + .options + .table_partition_cols + .iter() + .map(|col| Ok(self.table_schema.field_with_name(&col.0)?.clone())) + .collect::>>()?; + let file_scan_config = FileScanConfig { object_store_url, - file_schema: Arc::new(self.table_schema.clone()), // TODO differentiate between file and table schema + file_schema: Arc::new(self.file_schema.clone()), file_groups, statistics: Statistics::new_unknown(&self.table_schema), projection: projection.cloned(), limit, - table_partition_cols: vec![], + table_partition_cols, output_ordering: vec![], }; let scanner = ZarrScan::new(file_scan_config, filters); - Ok(Arc::new(scanner)) } } diff --git a/test-data b/test-data index 333fb9a..2af70b4 160000 --- a/test-data +++ b/test-data @@ -1 +1 @@ -Subproject commit 333fb9ac03de3aa471933f03afe29b56e864c0ac +Subproject commit 2af70b4c0f59921734b109520acc3ea83e681476