From 4241caef4129cc7983ea0f5c1c1f6dc864e82880 Mon Sep 17 00:00:00 2001 From: Aykut Bozkurt Date: Tue, 26 Nov 2024 23:07:41 +0300 Subject: [PATCH] Match fields by name via option We add an option for `COPY FROM` called `match_by_name` which matches Parquet file fields to PostgreSQL table columns `by their names` rather than `by their order` in the schema. By default, the option is `false`. The option is useful when field order differs between the Parquet file and the table, but their names match. **!!IMPORTANT!!**: This is a breaking change. Before the PR, we match always by name. This is a bit strict and not common way to match schemas. (e.g. COPY FROM csv at postgres or COPY FROM of duckdb match by field position by default) This is why we match by position by default and have a COPY FROM option `match_by_name` that can be set to true for the old behaviour. Closes #39. --- README.md | 1 + src/arrow_parquet/arrow_to_pg.rs | 2 +- src/arrow_parquet/parquet_reader.rs | 36 ++++++++++++++---- src/arrow_parquet/pg_to_arrow.rs | 2 +- src/arrow_parquet/schema_parser.rs | 59 ++++++++++++++++++++++++----- src/parquet_copy_hook/copy_from.rs | 8 ++-- src/parquet_copy_hook/copy_utils.rs | 23 ++++++++--- src/pgrx_tests/copy_from_coerce.rs | 40 +++++++++++++++++-- src/pgrx_tests/copy_pg_rules.rs | 17 ++++++++- src/pgrx_utils.rs | 10 +++-- 10 files changed, 163 insertions(+), 35 deletions(-) diff --git a/README.md b/README.md index cd4c8c9..6e93a20 100644 --- a/README.md +++ b/README.md @@ -193,6 +193,7 @@ Alternatively, you can use the following environment variables when starting pos `pg_parquet` supports the following options in the `COPY FROM` command: - `format parquet`: you need to specify this option to read or write Parquet files which does not end with `.parquet[.]` extension, +- `match_by_name `: matches Parquet file fields to PostgreSQL table columns by their name rather than by their position in the schema (default). By default, the option is `false`. The option is useful when field order differs between the Parquet file and the table, but their names match. ## Configuration There is currently only one GUC parameter to enable/disable the `pg_parquet`: diff --git a/src/arrow_parquet/arrow_to_pg.rs b/src/arrow_parquet/arrow_to_pg.rs index 079ee08..2aa8d0e 100644 --- a/src/arrow_parquet/arrow_to_pg.rs +++ b/src/arrow_parquet/arrow_to_pg.rs @@ -163,7 +163,7 @@ impl ArrowToPgAttributeContext { }; let attributes = - collect_attributes_for(CollectAttributesFor::Struct, attribute_tupledesc); + collect_attributes_for(CollectAttributesFor::Other, attribute_tupledesc); // we only cast the top-level attributes, which already covers the nested attributes let cast_to_types = None; diff --git a/src/arrow_parquet/parquet_reader.rs b/src/arrow_parquet/parquet_reader.rs index cbeff07..e3366ae 100644 --- a/src/arrow_parquet/parquet_reader.rs +++ b/src/arrow_parquet/parquet_reader.rs @@ -16,7 +16,11 @@ use url::Url; use crate::{ arrow_parquet::{ - arrow_to_pg::to_pg_datum, schema_parser::parquet_schema_string_from_attributes, + arrow_to_pg::to_pg_datum, + schema_parser::{ + error_if_copy_from_match_by_position_with_generated_columns, + parquet_schema_string_from_attributes, + }, }, pgrx_utils::{collect_attributes_for, CollectAttributesFor}, type_compat::{geometry::reset_postgis_context, map::reset_map_context}, @@ -38,15 +42,18 @@ pub(crate) struct ParquetReaderContext { parquet_reader: ParquetRecordBatchStream, attribute_contexts: Vec, binary_out_funcs: Vec>, + match_by_name: bool, } impl ParquetReaderContext { - pub(crate) fn new(uri: Url, tupledesc: &PgTupleDesc) -> Self { + pub(crate) fn new(uri: Url, match_by_name: bool, tupledesc: &PgTupleDesc) -> Self { // Postgis and Map contexts are used throughout reading the parquet file. // We need to reset them to avoid reading the stale data. (e.g. extension could be dropped) reset_postgis_context(); reset_map_context(); + error_if_copy_from_match_by_position_with_generated_columns(tupledesc, match_by_name); + let parquet_reader = parquet_reader_from_uri(&uri); let parquet_file_schema = parquet_reader.schema(); @@ -69,6 +76,7 @@ impl ParquetReaderContext { parquet_file_schema.clone(), tupledesc_schema.clone(), &attributes, + match_by_name, ); let attribute_contexts = collect_arrow_to_pg_attribute_contexts( @@ -85,6 +93,7 @@ impl ParquetReaderContext { attribute_contexts, parquet_reader, binary_out_funcs, + match_by_name, started: false, finished: false, } @@ -116,15 +125,23 @@ impl ParquetReaderContext { fn record_batch_to_tuple_datums( record_batch: RecordBatch, attribute_contexts: &[ArrowToPgAttributeContext], + match_by_name: bool, ) -> Vec> { let mut datums = vec![]; - for attribute_context in attribute_contexts { + for (attribute_idx, attribute_context) in attribute_contexts.iter().enumerate() { let name = attribute_context.name(); - let column_array = record_batch - .column_by_name(name) - .unwrap_or_else(|| panic!("column {} not found", name)); + let column_array = if match_by_name { + record_batch + .column_by_name(name) + .unwrap_or_else(|| panic!("column {} not found", name)) + } else { + record_batch + .columns() + .get(attribute_idx) + .unwrap_or_else(|| panic!("column {} not found", name)) + }; let datum = if attribute_context.needs_cast() { // should fail instead of returning None if the cast fails at runtime @@ -181,8 +198,11 @@ impl ParquetReaderContext { self.buffer.extend_from_slice(&attnum_len_bytes); // convert the columnar arrays in record batch to tuple datums - let tuple_datums = - Self::record_batch_to_tuple_datums(record_batch, &self.attribute_contexts); + let tuple_datums = Self::record_batch_to_tuple_datums( + record_batch, + &self.attribute_contexts, + self.match_by_name, + ); // write the tuple datums to the ParquetReader's internal buffer in PG copy format for (datum, out_func) in tuple_datums.into_iter().zip(self.binary_out_funcs.iter()) diff --git a/src/arrow_parquet/pg_to_arrow.rs b/src/arrow_parquet/pg_to_arrow.rs index 530c7f7..17d774b 100644 --- a/src/arrow_parquet/pg_to_arrow.rs +++ b/src/arrow_parquet/pg_to_arrow.rs @@ -148,7 +148,7 @@ impl PgToArrowAttributeContext { }; let attributes = - collect_attributes_for(CollectAttributesFor::Struct, &attribute_tupledesc); + collect_attributes_for(CollectAttributesFor::Other, &attribute_tupledesc); collect_pg_to_arrow_attribute_contexts(&attributes, &fields) }); diff --git a/src/arrow_parquet/schema_parser.rs b/src/arrow_parquet/schema_parser.rs index b76ee70..b4c5798 100644 --- a/src/arrow_parquet/schema_parser.rs +++ b/src/arrow_parquet/schema_parser.rs @@ -16,7 +16,7 @@ use pgrx::{check_for_interrupts, prelude::*, PgTupleDesc}; use crate::{ pgrx_utils::{ array_element_typoid, collect_attributes_for, domain_array_base_elem_typoid, is_array_type, - is_composite_type, tuple_desc, CollectAttributesFor, + is_composite_type, is_generated_attribute, tuple_desc, CollectAttributesFor, }, type_compat::{ geometry::is_postgis_geometry_type, @@ -95,7 +95,7 @@ fn parse_struct_schema(tupledesc: PgTupleDesc, elem_name: &str, field_id: &mut i let mut child_fields: Vec> = vec![]; - let attributes = collect_attributes_for(CollectAttributesFor::Struct, &tupledesc); + let attributes = collect_attributes_for(CollectAttributesFor::Other, &tupledesc); for attribute in attributes { if attribute.is_dropped() { @@ -342,6 +342,30 @@ fn adjust_map_entries_field(field: FieldRef) -> FieldRef { Arc::new(entries_field) } +pub(crate) fn error_if_copy_from_match_by_position_with_generated_columns( + tupledesc: &PgTupleDesc, + match_by_name: bool, +) { + // match_by_name can handle generated columns + if match_by_name { + return; + } + + let attributes = collect_attributes_for(CollectAttributesFor::Other, tupledesc); + + for attribute in attributes { + if is_generated_attribute(&attribute) { + ereport!( + PgLogLevel::ERROR, + PgSqlErrorCode::ERRCODE_FEATURE_NOT_SUPPORTED, + "COPY FROM parquet with generated columns is not supported", + "Try COPY FROM parquet WITH (match_by_name true). \" + It works only if the column names match with parquet file's.", + ); + } + } +} + // ensure_file_schema_match_tupledesc_schema throws an error if the file's schema does not match the table schema. // If the file's arrow schema is castable to the table's arrow schema, it returns a vector of Option // to cast to for each field. @@ -349,21 +373,38 @@ pub(crate) fn ensure_file_schema_match_tupledesc_schema( file_schema: Arc, tupledesc_schema: Arc, attributes: &[FormData_pg_attribute], + match_by_name: bool, ) -> Vec> { let mut cast_to_types = Vec::new(); + if !match_by_name && tupledesc_schema.fields().len() != file_schema.fields().len() { + panic!( + "column count mismatch between table and parquet file. \ + parquet file has {} columns, but table has {} columns", + file_schema.fields().len(), + tupledesc_schema.fields().len() + ); + } + for (tupledesc_schema_field, attribute) in tupledesc_schema.fields().iter().zip(attributes.iter()) { let field_name = tupledesc_schema_field.name(); - let file_schema_field = file_schema.column_with_name(field_name); + let file_schema_field = if match_by_name { + let file_schema_field = file_schema.column_with_name(field_name); - if file_schema_field.is_none() { - panic!("column \"{}\" is not found in parquet file", field_name); - } + if file_schema_field.is_none() { + panic!("column \"{}\" is not found in parquet file", field_name); + } + + let (_, file_schema_field) = file_schema_field.unwrap(); + + file_schema_field + } else { + file_schema.field(attribute.attnum as usize - 1) + }; - let (_, file_schema_field) = file_schema_field.unwrap(); let file_schema_field = Arc::new(file_schema_field.clone()); let from_type = file_schema_field.data_type(); @@ -378,7 +419,7 @@ pub(crate) fn ensure_file_schema_match_tupledesc_schema( if !is_coercible(from_type, to_type, attribute.atttypid, attribute.atttypmod) { panic!( "type mismatch for column \"{}\" between table and parquet file.\n\n\ - table has \"{}\"\n\nparquet file has \"{}\"", + table has \"{}\"\n\nparquet file has \"{}\"", field_name, to_type, from_type ); } @@ -413,7 +454,7 @@ fn is_coercible(from_type: &DataType, to_type: &DataType, to_typoid: Oid, to_typ let tupledesc = tuple_desc(to_typoid, to_typmod); - let attributes = collect_attributes_for(CollectAttributesFor::Struct, &tupledesc); + let attributes = collect_attributes_for(CollectAttributesFor::Other, &tupledesc); for (from_field, (to_field, to_attribute)) in from_fields .iter() diff --git a/src/parquet_copy_hook/copy_from.rs b/src/parquet_copy_hook/copy_from.rs index bf3a878..6a44446 100644 --- a/src/parquet_copy_hook/copy_from.rs +++ b/src/parquet_copy_hook/copy_from.rs @@ -20,8 +20,8 @@ use crate::{ }; use super::copy_utils::{ - copy_stmt_attribute_list, copy_stmt_create_namespace_item, copy_stmt_create_parse_state, - create_filtered_tupledesc_for_relation, + copy_from_stmt_match_by_name, copy_stmt_attribute_list, copy_stmt_create_namespace_item, + copy_stmt_create_parse_state, create_filtered_tupledesc_for_relation, }; // stack to store parquet reader contexts for COPY FROM. @@ -131,9 +131,11 @@ pub(crate) fn execute_copy_from( let tupledesc = create_filtered_tupledesc_for_relation(p_stmt, &relation); + let match_by_name = copy_from_stmt_match_by_name(p_stmt); + unsafe { // parquet reader context is used throughout the COPY FROM operation. - let parquet_reader_context = ParquetReaderContext::new(uri, &tupledesc); + let parquet_reader_context = ParquetReaderContext::new(uri, match_by_name, &tupledesc); push_parquet_reader_context(parquet_reader_context); // makes sure to set binary format diff --git a/src/parquet_copy_hook/copy_utils.rs b/src/parquet_copy_hook/copy_utils.rs index 068e95e..02c3c5d 100644 --- a/src/parquet_copy_hook/copy_utils.rs +++ b/src/parquet_copy_hook/copy_utils.rs @@ -3,11 +3,12 @@ use std::{ffi::CStr, str::FromStr}; use pgrx::{ is_a, pg_sys::{ - addRangeTableEntryForRelation, defGetInt32, defGetInt64, defGetString, get_namespace_name, - get_rel_namespace, makeDefElem, makeString, make_parsestate, quote_qualified_identifier, - AccessShareLock, AsPgCStr, CopyStmt, CreateTemplateTupleDesc, DefElem, List, NoLock, Node, - NodeTag::T_CopyStmt, Oid, ParseNamespaceItem, ParseState, PlannedStmt, QueryEnvironment, - RangeVar, RangeVarGetRelidExtended, RowExclusiveLock, TupleDescInitEntry, + addRangeTableEntryForRelation, defGetBoolean, defGetInt32, defGetInt64, defGetString, + get_namespace_name, get_rel_namespace, makeDefElem, makeString, make_parsestate, + quote_qualified_identifier, AccessShareLock, AsPgCStr, CopyStmt, CreateTemplateTupleDesc, + DefElem, List, NoLock, Node, NodeTag::T_CopyStmt, Oid, ParseNamespaceItem, ParseState, + PlannedStmt, QueryEnvironment, RangeVar, RangeVarGetRelidExtended, RowExclusiveLock, + TupleDescInitEntry, }, PgBox, PgList, PgRelation, PgTupleDesc, }; @@ -109,7 +110,7 @@ pub(crate) fn validate_copy_to_options(p_stmt: &PgBox, uri: &Url) { } pub(crate) fn validate_copy_from_options(p_stmt: &PgBox) { - validate_copy_option_names(p_stmt, &["format", "freeze"]); + validate_copy_option_names(p_stmt, &["format", "match_by_name", "freeze"]); let format_option = copy_stmt_get_option(p_stmt, "format"); @@ -253,6 +254,16 @@ pub(crate) fn copy_from_stmt_create_option_list(p_stmt: &PgBox) -> new_copy_options } +pub(crate) fn copy_from_stmt_match_by_name(p_stmt: &PgBox) -> bool { + let match_by_name_option = copy_stmt_get_option(p_stmt, "match_by_name"); + + if match_by_name_option.is_null() { + false + } else { + unsafe { defGetBoolean(match_by_name_option.as_ptr()) } + } +} + pub(crate) fn copy_stmt_get_option( p_stmt: &PgBox, option_name: &str, diff --git a/src/pgrx_tests/copy_from_coerce.rs b/src/pgrx_tests/copy_from_coerce.rs index 75c7af8..ed0ebb9 100644 --- a/src/pgrx_tests/copy_from_coerce.rs +++ b/src/pgrx_tests/copy_from_coerce.rs @@ -966,7 +966,7 @@ mod tests { } #[pg_test] - fn test_table_with_different_field_position() { + fn test_table_with_different_position_match_by_name() { let copy_to = format!( "COPY (SELECT 1 as x, 'hello' as y) TO '{}'", LOCAL_TEST_FILE_PATH @@ -976,13 +976,44 @@ mod tests { let create_table = "CREATE TABLE test_table (y text, x int)"; Spi::run(create_table).unwrap(); - let copy_from = format!("COPY test_table FROM '{}'", LOCAL_TEST_FILE_PATH); + let copy_from = format!( + "COPY test_table FROM '{}' WITH (match_by_name true)", + LOCAL_TEST_FILE_PATH + ); Spi::run(©_from).unwrap(); let result = Spi::get_two::<&str, i32>("SELECT y, x FROM test_table LIMIT 1").unwrap(); assert_eq!(result, (Some("hello"), Some(1))); } + #[pg_test] + fn test_table_with_different_name_match_by_position() { + let copy_to = "COPY (SELECT 1 as a, 'hello' as b) TO '/tmp/test.parquet'"; + Spi::run(copy_to).unwrap(); + + let create_table = "CREATE TABLE test_table (x bigint, y varchar)"; + Spi::run(create_table).unwrap(); + + let copy_from = "COPY test_table FROM '/tmp/test.parquet'"; + Spi::run(copy_from).unwrap(); + + let result = Spi::get_two::("SELECT x, y FROM test_table LIMIT 1").unwrap(); + assert_eq!(result, (Some(1), Some("hello"))); + } + + #[pg_test] + #[should_panic(expected = "column count mismatch between table and parquet file")] + fn test_table_with_different_name_match_by_position_fail() { + let copy_to = "COPY (SELECT 1 as a, 'hello' as b) TO '/tmp/test.parquet'"; + Spi::run(copy_to).unwrap(); + + let create_table = "CREATE TABLE test_table (x bigint, y varchar, z int)"; + Spi::run(create_table).unwrap(); + + let copy_from = "COPY test_table FROM '/tmp/test.parquet'"; + Spi::run(copy_from).unwrap(); + } + #[pg_test] #[should_panic(expected = "column \"name\" is not found in parquet file")] fn test_missing_column_in_parquet() { @@ -992,7 +1023,10 @@ mod tests { let copy_to_parquet = format!("copy (select 100 as id) to '{}';", LOCAL_TEST_FILE_PATH); Spi::run(©_to_parquet).unwrap(); - let copy_from = format!("COPY test_table FROM '{}'", LOCAL_TEST_FILE_PATH); + let copy_from = format!( + "COPY test_table FROM '{}' with (match_by_name true)", + LOCAL_TEST_FILE_PATH + ); Spi::run(©_from).unwrap(); } diff --git a/src/pgrx_tests/copy_pg_rules.rs b/src/pgrx_tests/copy_pg_rules.rs index b0347e2..35c44d6 100644 --- a/src/pgrx_tests/copy_pg_rules.rs +++ b/src/pgrx_tests/copy_pg_rules.rs @@ -101,6 +101,21 @@ mod tests { Spi::run(©_from).unwrap(); } + #[pg_test] + #[should_panic(expected = "COPY FROM parquet with generated columns is not supported")] + fn test_copy_from_by_position_with_generated_columns_not_supported() { + Spi::run("DROP TABLE IF EXISTS test_table;").unwrap(); + + Spi::run("CREATE TABLE test_table (a int, b int generated always as (10) stored, c text);") + .unwrap(); + + let copy_from_query = format!( + "COPY test_table FROM '{}' WITH (format parquet);", + LOCAL_TEST_FILE_PATH + ); + Spi::run(copy_from_query.as_str()).unwrap(); + } + #[pg_test] fn test_with_generated_and_dropped_columns() { Spi::run("DROP TABLE IF EXISTS test_table;").unwrap(); @@ -123,7 +138,7 @@ mod tests { Spi::run("TRUNCATE test_table;").unwrap(); let copy_from_query = format!( - "COPY test_table FROM '{}' WITH (format parquet);", + "COPY test_table FROM '{}' WITH (format parquet, match_by_name true);", LOCAL_TEST_FILE_PATH ); Spi::run(copy_from_query.as_str()).unwrap(); diff --git a/src/pgrx_utils.rs b/src/pgrx_utils.rs index cb8f9fa..0793def 100644 --- a/src/pgrx_utils.rs +++ b/src/pgrx_utils.rs @@ -12,7 +12,7 @@ use pgrx::{ pub(crate) enum CollectAttributesFor { CopyFrom, CopyTo, - Struct, + Other, } // collect_attributes_for collects not-dropped attributes from the tuple descriptor. @@ -23,7 +23,7 @@ pub(crate) fn collect_attributes_for( ) -> Vec { let include_generated_columns = match copy_operation { CollectAttributesFor::CopyFrom => false, - CollectAttributesFor::CopyTo | CollectAttributesFor::Struct => true, + CollectAttributesFor::CopyTo | CollectAttributesFor::Other => true, }; let mut attributes = vec![]; @@ -35,7 +35,7 @@ pub(crate) fn collect_attributes_for( continue; } - if !include_generated_columns && attribute.attgenerated != 0 { + if !include_generated_columns && is_generated_attribute(attribute) { continue; } @@ -55,6 +55,10 @@ pub(crate) fn collect_attributes_for( attributes } +pub(crate) fn is_generated_attribute(attribute: &FormData_pg_attribute) -> bool { + attribute.attgenerated != 0 +} + pub(crate) fn tuple_desc(typoid: Oid, typmod: i32) -> PgTupleDesc<'static> { let tupledesc = unsafe { lookup_rowtype_tupdesc(typoid, typmod) }; unsafe { PgTupleDesc::from_pg(tupledesc) }