Skip to content

Commit

Permalink
tests: add tests for writing hive-partitioned parquet (apache#9316)
Browse files Browse the repository at this point in the history
* tests: adds tests associated with apache#9237

* style: clippy
  • Loading branch information
tshauck authored Feb 26, 2024
1 parent b8c6e0b commit a26f583
Show file tree
Hide file tree
Showing 2 changed files with 158 additions and 76 deletions.
74 changes: 0 additions & 74 deletions datafusion/core/src/datasource/physical_plan/parquet/mod.rs
Original file line number Diff line number Diff line change
Expand Up @@ -2066,80 +2066,6 @@ mod tests {
Ok(())
}

#[tokio::test]
async fn write_parquet_results() -> Result<()> {
// create partitioned input file and context
let tmp_dir = TempDir::new()?;
// let mut ctx = create_ctx(&tmp_dir, 4).await?;
let ctx = SessionContext::new_with_config(
SessionConfig::new().with_target_partitions(8),
);
let schema = populate_csv_partitions(&tmp_dir, 4, ".csv")?;
// register csv file with the execution context
ctx.register_csv(
"test",
tmp_dir.path().to_str().unwrap(),
CsvReadOptions::new().schema(&schema),
)
.await?;

// register a local file system object store for /tmp directory
let local = Arc::new(LocalFileSystem::new_with_prefix(&tmp_dir)?);
let local_url = Url::parse("file://local").unwrap();
ctx.runtime_env().register_object_store(&local_url, local);

// execute a simple query and write the results to parquet
let out_dir = tmp_dir.as_ref().to_str().unwrap().to_string() + "/out/";
let out_dir_url = "file://local/out/";
let df = ctx.sql("SELECT c1, c2 FROM test").await?;
df.write_parquet(out_dir_url, DataFrameWriteOptions::new(), None)
.await?;
// write_parquet(&mut ctx, "SELECT c1, c2 FROM test", &out_dir, None).await?;

// create a new context and verify that the results were saved to a partitioned parquet file
let ctx = SessionContext::new();

// get write_id
let mut paths = fs::read_dir(&out_dir).unwrap();
let path = paths.next();
let name = path
.unwrap()?
.path()
.file_name()
.expect("Should be a file name")
.to_str()
.expect("Should be a str")
.to_owned();
let (parsed_id, _) = name.split_once('_').expect("File should contain _ !");
let write_id = parsed_id.to_owned();

// register each partition as well as the top level dir
ctx.register_parquet(
"part0",
&format!("{out_dir}/{write_id}_0.parquet"),
ParquetReadOptions::default(),
)
.await?;

ctx.register_parquet("allparts", &out_dir, ParquetReadOptions::default())
.await?;

let part0 = ctx.sql("SELECT c1, c2 FROM part0").await?.collect().await?;
let allparts = ctx
.sql("SELECT c1, c2 FROM allparts")
.await?
.collect()
.await?;

let allparts_count: usize = allparts.iter().map(|batch| batch.num_rows()).sum();

assert_eq!(part0[0].schema(), allparts[0].schema());

assert_eq!(allparts_count, 40);

Ok(())
}

fn logical2physical(expr: &Expr, schema: &Schema) -> Arc<dyn PhysicalExpr> {
let df_schema = schema.clone().to_dfschema().unwrap();
let execution_props = ExecutionProps::new();
Expand Down
160 changes: 158 additions & 2 deletions datafusion/core/tests/dataframe/mod.rs
Original file line number Diff line number Diff line change
Expand Up @@ -30,15 +30,19 @@ use arrow::{
};
use arrow_array::Float32Array;
use arrow_schema::ArrowError;
use object_store::local::LocalFileSystem;
use std::fs;
use std::sync::Arc;
use tempfile::TempDir;
use url::Url;

use datafusion::dataframe::DataFrame;
use datafusion::dataframe::{DataFrame, DataFrameWriteOptions};
use datafusion::datasource::MemTable;
use datafusion::error::Result;
use datafusion::execution::context::{SessionContext, SessionState};
use datafusion::prelude::JoinType;
use datafusion::prelude::{CsvReadOptions, ParquetReadOptions};
use datafusion::test_util::parquet_test_data;
use datafusion::test_util::{parquet_test_data, populate_csv_partitions};
use datafusion::{assert_batches_eq, assert_batches_sorted_eq};
use datafusion_common::{assert_contains, DataFusionError, ScalarValue, UnnestOptions};
use datafusion_execution::config::SessionConfig;
Expand Down Expand Up @@ -1896,3 +1900,155 @@ async fn test_dataframe_placeholder_missing_param_values() -> Result<()> {

Ok(())
}

#[tokio::test]
async fn write_partitioned_parquet_results() -> Result<()> {
// create partitioned input file and context
let tmp_dir = TempDir::new()?;

let ctx = SessionContext::new();

// Create an in memory table with schema C1 and C2, both strings
let schema = Arc::new(Schema::new(vec![
Field::new("c1", DataType::Utf8, false),
Field::new("c2", DataType::Utf8, false),
]));

let record_batch = RecordBatch::try_new(
schema.clone(),
vec![
Arc::new(StringArray::from(vec!["abc", "def"])),
Arc::new(StringArray::from(vec!["123", "456"])),
],
)?;

let mem_table = Arc::new(MemTable::try_new(schema, vec![vec![record_batch]])?);

// Register the table in the context
ctx.register_table("test", mem_table)?;

let local = Arc::new(LocalFileSystem::new_with_prefix(&tmp_dir)?);
let local_url = Url::parse("file://local").unwrap();
ctx.runtime_env().register_object_store(&local_url, local);

// execute a simple query and write the results to parquet
let out_dir = tmp_dir.as_ref().to_str().unwrap().to_string() + "/out/";
let out_dir_url = format!("file://{out_dir}");

// Write the results to parquet with partitioning
let df = ctx.sql("SELECT c1, c2 FROM test").await?;
let df_write_options =
DataFrameWriteOptions::new().with_partition_by(vec![String::from("c2")]);

df.write_parquet(&out_dir_url, df_write_options, None)
.await?;

// Explicitly read the parquet file at c2=123 to verify the physical files are partitioned
let partitioned_file = format!("{out_dir}/c2=123", out_dir = out_dir);
let filted_df = ctx
.read_parquet(&partitioned_file, ParquetReadOptions::default())
.await?;

// Check that the c2 column is gone and that c1 is abc.
let results = filted_df.collect().await?;
let expected = ["+-----+", "| c1 |", "+-----+", "| abc |", "+-----+"];

assert_batches_eq!(expected, &results);

// Read the entire set of parquet files
let df = ctx
.read_parquet(
&out_dir_url,
ParquetReadOptions::default()
.table_partition_cols(vec![(String::from("c2"), DataType::Utf8)]),
)
.await?;

// Check that the df has the entire set of data
let results = df.collect().await?;
let expected = [
"+-----+-----+",
"| c1 | c2 |",
"+-----+-----+",
"| abc | 123 |",
"| def | 456 |",
"+-----+-----+",
];

assert_batches_eq!(expected, &results);

Ok(())
}

#[tokio::test]
async fn write_parquet_results() -> Result<()> {
// create partitioned input file and context
let tmp_dir = TempDir::new()?;
// let mut ctx = create_ctx(&tmp_dir, 4).await?;
let ctx =
SessionContext::new_with_config(SessionConfig::new().with_target_partitions(8));
let schema = populate_csv_partitions(&tmp_dir, 4, ".csv")?;
// register csv file with the execution context
ctx.register_csv(
"test",
tmp_dir.path().to_str().unwrap(),
CsvReadOptions::new().schema(&schema),
)
.await?;

// register a local file system object store for /tmp directory
let local = Arc::new(LocalFileSystem::new_with_prefix(&tmp_dir)?);
let local_url = Url::parse("file://local").unwrap();
ctx.runtime_env().register_object_store(&local_url, local);

// execute a simple query and write the results to parquet
let out_dir = tmp_dir.as_ref().to_str().unwrap().to_string() + "/out/";
let out_dir_url = "file://local/out/";
let df = ctx.sql("SELECT c1, c2 FROM test").await?;
df.write_parquet(out_dir_url, DataFrameWriteOptions::new(), None)
.await?;
// write_parquet(&mut ctx, "SELECT c1, c2 FROM test", &out_dir, None).await?;

// create a new context and verify that the results were saved to a partitioned parquet file
let ctx = SessionContext::new();

// get write_id
let mut paths = fs::read_dir(&out_dir).unwrap();
let path = paths.next();
let name = path
.unwrap()?
.path()
.file_name()
.expect("Should be a file name")
.to_str()
.expect("Should be a str")
.to_owned();
let (parsed_id, _) = name.split_once('_').expect("File should contain _ !");
let write_id = parsed_id.to_owned();

// register each partition as well as the top level dir
ctx.register_parquet(
"part0",
&format!("{out_dir}/{write_id}_0.parquet"),
ParquetReadOptions::default(),
)
.await?;

ctx.register_parquet("allparts", &out_dir, ParquetReadOptions::default())
.await?;

let part0 = ctx.sql("SELECT c1, c2 FROM part0").await?.collect().await?;
let allparts = ctx
.sql("SELECT c1, c2 FROM allparts")
.await?
.collect()
.await?;

let allparts_count: usize = allparts.iter().map(|batch| batch.num_rows()).sum();

assert_eq!(part0[0].schema(), allparts[0].schema());

assert_eq!(allparts_count, 40);

Ok(())
}

0 comments on commit a26f583

Please sign in to comment.