diff --git a/README.md b/README.md
index cd4c8c9..1583f2c 100644
--- a/README.md
+++ b/README.md
@@ -193,6 +193,7 @@ Alternatively, you can use the following environment variables when starting pos
 `pg_parquet` supports the following options in the `COPY FROM` command:
 - `format parquet`: you need to specify this option to read or write Parquet files which does not end with `.parquet[.<compression>]` extension,
+- `match_by_position`: matches Parquet file fields to PostgreSQL table columns by their position in the schema rather than by their names. By default, the option is `false`. The option is useful when field names differ between the Parquet file and the table, but their order aligns.
 
 ## Configuration
 
 There is currently only one GUC parameter to enable/disable the `pg_parquet`:
diff --git a/src/arrow_parquet/parquet_reader.rs b/src/arrow_parquet/parquet_reader.rs
index cbeff07..4377860 100644
--- a/src/arrow_parquet/parquet_reader.rs
+++ b/src/arrow_parquet/parquet_reader.rs
@@ -38,10 +38,11 @@ pub(crate) struct ParquetReaderContext {
     parquet_reader: ParquetRecordBatchStream<ParquetObjectReader>,
     attribute_contexts: Vec<ArrowToPgAttributeContext>,
     binary_out_funcs: Vec<PgBox<FmgrInfo>>,
+    match_by_position: bool,
 }
 
 impl ParquetReaderContext {
-    pub(crate) fn new(uri: Url, tupledesc: &PgTupleDesc) -> Self {
+    pub(crate) fn new(uri: Url, match_by_position: bool, tupledesc: &PgTupleDesc) -> Self {
         // Postgis and Map contexts are used throughout reading the parquet file.
         // We need to reset them to avoid reading the stale data. (e.g. extension could be dropped)
         reset_postgis_context();
@@ -69,6 +70,7 @@ impl ParquetReaderContext {
             parquet_file_schema.clone(),
             tupledesc_schema.clone(),
             &attributes,
+            match_by_position,
         );
 
         let attribute_contexts = collect_arrow_to_pg_attribute_contexts(
@@ -85,6 +87,7 @@ impl ParquetReaderContext {
             attribute_contexts,
             parquet_reader,
             binary_out_funcs,
+            match_by_position,
             started: false,
             finished: false,
         }
@@ -116,15 +119,23 @@ impl ParquetReaderContext {
     fn record_batch_to_tuple_datums(
         record_batch: RecordBatch,
         attribute_contexts: &[ArrowToPgAttributeContext],
+        match_by_position: bool,
     ) -> Vec<Option<Datum>> {
         let mut datums = vec![];
 
-        for attribute_context in attribute_contexts {
+        for (attribute_idx, attribute_context) in attribute_contexts.iter().enumerate() {
             let name = attribute_context.name();
 
-            let column_array = record_batch
-                .column_by_name(name)
-                .unwrap_or_else(|| panic!("column {} not found", name));
+            let column_array = if match_by_position {
+                record_batch
+                    .columns()
+                    .get(attribute_idx)
+                    .unwrap_or_else(|| panic!("column {} not found", name))
+            } else {
+                record_batch
+                    .column_by_name(name)
+                    .unwrap_or_else(|| panic!("column {} not found", name))
+            };
 
             let datum = if attribute_context.needs_cast() {
                 // should fail instead of returning None if the cast fails at runtime
@@ -181,8 +192,11 @@ impl ParquetReaderContext {
             self.buffer.extend_from_slice(&attnum_len_bytes);
 
             // convert the columnar arrays in record batch to tuple datums
-            let tuple_datums =
-                Self::record_batch_to_tuple_datums(record_batch, &self.attribute_contexts);
+            let tuple_datums = Self::record_batch_to_tuple_datums(
+                record_batch,
+                &self.attribute_contexts,
+                self.match_by_position,
+            );
 
             // write the tuple datums to the ParquetReader's internal buffer in PG copy format
             for (datum, out_func) in tuple_datums.into_iter().zip(self.binary_out_funcs.iter())
diff --git a/src/arrow_parquet/schema_parser.rs b/src/arrow_parquet/schema_parser.rs
index b76ee70..9030424 100644
--- a/src/arrow_parquet/schema_parser.rs
+++ b/src/arrow_parquet/schema_parser.rs
@@ -349,21 +349,38 @@ pub(crate) fn ensure_file_schema_match_tupledesc_schema(
     file_schema: Arc<Schema>,
     tupledesc_schema: Arc<Schema>,
     attributes: &[FormData_pg_attribute],
+    match_by_position: bool,
 ) -> Vec<Option<DataType>> {
     let mut cast_to_types = Vec::new();
 
+    if match_by_position && tupledesc_schema.fields().len() != file_schema.fields().len() {
+        panic!(
+            "column count mismatch between table and parquet file. \
+             parquet file has {} columns, but table has {} columns",
+            file_schema.fields().len(),
+            tupledesc_schema.fields().len()
+        );
+    }
+
     for (tupledesc_schema_field, attribute) in
         tupledesc_schema.fields().iter().zip(attributes.iter())
     {
         let field_name = tupledesc_schema_field.name();
 
-        let file_schema_field = file_schema.column_with_name(field_name);
+        let file_schema_field = if match_by_position {
+            file_schema.field(attribute.attnum as usize - 1)
+        } else {
+            let file_schema_field = file_schema.column_with_name(field_name);
 
-        if file_schema_field.is_none() {
-            panic!("column \"{}\" is not found in parquet file", field_name);
-        }
+            if file_schema_field.is_none() {
+                panic!("column \"{}\" is not found in parquet file", field_name);
+            }
+
+            let (_, file_schema_field) = file_schema_field.unwrap();
+
+            file_schema_field
+        };
 
-        let (_, file_schema_field) = file_schema_field.unwrap();
         let file_schema_field = Arc::new(file_schema_field.clone());
 
         let from_type = file_schema_field.data_type();
@@ -378,7 +395,7 @@ pub(crate) fn ensure_file_schema_match_tupledesc_schema(
         if !is_coercible(from_type, to_type, attribute.atttypid, attribute.atttypmod) {
             panic!(
                 "type mismatch for column \"{}\" between table and parquet file.\n\n\
-                table has \"{}\"\n\nparquet file has \"{}\"",
+                 table has \"{}\"\n\nparquet file has \"{}\"",
                 field_name, to_type, from_type
             );
         }
diff --git a/src/parquet_copy_hook/copy_from.rs b/src/parquet_copy_hook/copy_from.rs
index bf3a878..4987112 100644
--- a/src/parquet_copy_hook/copy_from.rs
+++ b/src/parquet_copy_hook/copy_from.rs
@@ -20,8 +20,8 @@ use crate::{
 };
 
 use super::copy_utils::{
-    copy_stmt_attribute_list, copy_stmt_create_namespace_item, copy_stmt_create_parse_state,
-    create_filtered_tupledesc_for_relation,
+    copy_from_stmt_match_by_position, copy_stmt_attribute_list, copy_stmt_create_namespace_item,
+    copy_stmt_create_parse_state, create_filtered_tupledesc_for_relation,
 };
 
 // stack to store parquet reader contexts for COPY FROM.
@@ -131,9 +131,11 @@ pub(crate) fn execute_copy_from(
     let tupledesc = create_filtered_tupledesc_for_relation(p_stmt, &relation);
 
+    let match_by_position = copy_from_stmt_match_by_position(p_stmt);
+
     unsafe {
         // parquet reader context is used throughout the COPY FROM operation.
-        let parquet_reader_context = ParquetReaderContext::new(uri, &tupledesc);
+        let parquet_reader_context = ParquetReaderContext::new(uri, match_by_position, &tupledesc);
         push_parquet_reader_context(parquet_reader_context);
 
         // makes sure to set binary format
diff --git a/src/parquet_copy_hook/copy_utils.rs b/src/parquet_copy_hook/copy_utils.rs
index 068e95e..869fb1c 100644
--- a/src/parquet_copy_hook/copy_utils.rs
+++ b/src/parquet_copy_hook/copy_utils.rs
@@ -3,11 +3,12 @@ use std::{ffi::CStr, str::FromStr};
 use pgrx::{
     is_a,
     pg_sys::{
-        addRangeTableEntryForRelation, defGetInt32, defGetInt64, defGetString, get_namespace_name,
-        get_rel_namespace, makeDefElem, makeString, make_parsestate, quote_qualified_identifier,
-        AccessShareLock, AsPgCStr, CopyStmt, CreateTemplateTupleDesc, DefElem, List, NoLock, Node,
-        NodeTag::T_CopyStmt, Oid, ParseNamespaceItem, ParseState, PlannedStmt, QueryEnvironment,
-        RangeVar, RangeVarGetRelidExtended, RowExclusiveLock, TupleDescInitEntry,
+        addRangeTableEntryForRelation, defGetBoolean, defGetInt32, defGetInt64, defGetString,
+        get_namespace_name, get_rel_namespace, makeDefElem, makeString, make_parsestate,
+        quote_qualified_identifier, AccessShareLock, AsPgCStr, CopyStmt, CreateTemplateTupleDesc,
+        DefElem, List, NoLock, Node, NodeTag::T_CopyStmt, Oid, ParseNamespaceItem, ParseState,
+        PlannedStmt, QueryEnvironment, RangeVar, RangeVarGetRelidExtended, RowExclusiveLock,
+        TupleDescInitEntry,
     },
     PgBox, PgList, PgRelation, PgTupleDesc,
 };
@@ -109,7 +110,7 @@ pub(crate) fn validate_copy_to_options(p_stmt: &PgBox<PlannedStmt>, uri: &Url) {
 }
 
 pub(crate) fn validate_copy_from_options(p_stmt: &PgBox<PlannedStmt>) {
-    validate_copy_option_names(p_stmt, &["format", "freeze"]);
+    validate_copy_option_names(p_stmt, &["format", "match_by_position", "freeze"]);
 
     let format_option = copy_stmt_get_option(p_stmt, "format");
 
@@ -253,6 +254,16 @@ pub(crate) fn copy_from_stmt_create_option_list(p_stmt: &PgBox<PlannedStmt>) ->
     new_copy_options
 }
 
+pub(crate) fn copy_from_stmt_match_by_position(p_stmt: &PgBox<PlannedStmt>) -> bool {
+    let match_by_position_option = copy_stmt_get_option(p_stmt, "match_by_position");
+
+    if match_by_position_option.is_null() {
+        false
+    } else {
+        unsafe { defGetBoolean(match_by_position_option.as_ptr()) }
+    }
+}
+
 pub(crate) fn copy_stmt_get_option(
     p_stmt: &PgBox<PlannedStmt>,
     option_name: &str,
diff --git a/src/pgrx_tests/copy_from_coerce.rs b/src/pgrx_tests/copy_from_coerce.rs
index 75c7af8..faef64c 100644
--- a/src/pgrx_tests/copy_from_coerce.rs
+++ b/src/pgrx_tests/copy_from_coerce.rs
@@ -966,7 +966,7 @@
     }
 
     #[pg_test]
-    fn test_table_with_different_field_position() {
+    fn test_table_with_different_position_match_by_name() {
         let copy_to = format!(
             "COPY (SELECT 1 as x, 'hello' as y) TO '{}'",
             LOCAL_TEST_FILE_PATH
@@ -983,6 +983,34 @@
         assert_eq!(result, (Some("hello"), Some(1)));
     }
 
+    #[pg_test]
+    fn test_table_with_different_name_match_by_position() {
+        let copy_to = "COPY (SELECT 1 as a, 'hello' as b) TO '/tmp/test.parquet'";
+        Spi::run(copy_to).unwrap();
+
+        let create_table = "CREATE TABLE test_table (x bigint, y varchar)";
+        Spi::run(create_table).unwrap();
+
+        let copy_from = "COPY test_table FROM '/tmp/test.parquet' WITH (match_by_position true)";
+        Spi::run(copy_from).unwrap();
+
+        let result = Spi::get_two::<i64, &str>("SELECT x, y FROM test_table LIMIT 1").unwrap();
+        assert_eq!(result, (Some(1), Some("hello")));
+    }
+
+    #[pg_test]
+    #[should_panic(expected = "column count mismatch between table and parquet file")]
+    fn test_table_with_different_name_match_by_position_fail() {
+        let copy_to = "COPY (SELECT 1 as a, 'hello' as b) TO '/tmp/test.parquet'";
+        Spi::run(copy_to).unwrap();
+
+        let create_table = "CREATE TABLE test_table (x bigint, y varchar, z int)";
+        Spi::run(create_table).unwrap();
+
+        let copy_from = "COPY test_table FROM '/tmp/test.parquet' WITH (match_by_position true)";
+        Spi::run(copy_from).unwrap();
+    }
+
     #[pg_test]
     #[should_panic(expected = "column \"name\" is not found in parquet file")]
     fn test_missing_column_in_parquet() {
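
For reference, a minimal usage sketch of the new option (the table name and file path below are illustrative, not part of this change; the Parquet file's field names may differ from the table's column names as long as the field order lines up):

    CREATE TABLE destination (id bigint, label varchar);
    -- assumes /tmp/data.parquet has two fields in the matching order (e.g. a, b)
    COPY destination FROM '/tmp/data.parquet' WITH (format parquet, match_by_position true);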