diff --git a/README.md b/README.md
index cd4c8c9..6e93a20 100644
--- a/README.md
+++ b/README.md
@@ -193,6 +193,7 @@ Alternatively, you can use the following environment variables when starting pos
 `pg_parquet` supports the following options in the `COPY FROM` command:
 - `format parquet`: you need to specify this option to read or write Parquet files which do not end with the `.parquet[.<compression>]` extension,
+- `match_by_name`: matches Parquet file fields to PostgreSQL table columns by name rather than by their position in the schema, which is the default behavior. Defaults to `false`. Useful when the field order differs between the Parquet file and the table but the column names match.
 
 ## Configuration
 
 There is currently only one GUC parameter to enable/disable the `pg_parquet`:
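A minimal usage sketch of the new option, based on the syntax the README and tests use (the `orders` table and the file path are illustrative, not from the patch):

```sql
-- the table declares its columns in a different order than the Parquet file,
-- but the names line up, so match fields by name instead of by position
CREATE TABLE orders (amount int, note text);

COPY orders FROM '/tmp/orders.parquet' WITH (format parquet, match_by_name true);
```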
diff --git a/src/arrow_parquet/arrow_to_pg.rs b/src/arrow_parquet/arrow_to_pg.rs
index 079ee08..2aa8d0e 100644
--- a/src/arrow_parquet/arrow_to_pg.rs
+++ b/src/arrow_parquet/arrow_to_pg.rs
@@ -163,7 +163,7 @@ impl ArrowToPgAttributeContext {
         };
 
         let attributes =
-            collect_attributes_for(CollectAttributesFor::Struct, attribute_tupledesc);
+            collect_attributes_for(CollectAttributesFor::Other, attribute_tupledesc);
 
         // we only cast the top-level attributes, which already covers the nested attributes
         let cast_to_types = None;
diff --git a/src/arrow_parquet/parquet_reader.rs b/src/arrow_parquet/parquet_reader.rs
index cbeff07..e3366ae 100644
--- a/src/arrow_parquet/parquet_reader.rs
+++ b/src/arrow_parquet/parquet_reader.rs
@@ -16,7 +16,11 @@ use url::Url;
 
 use crate::{
     arrow_parquet::{
-        arrow_to_pg::to_pg_datum, schema_parser::parquet_schema_string_from_attributes,
+        arrow_to_pg::to_pg_datum,
+        schema_parser::{
+            error_if_copy_from_match_by_position_with_generated_columns,
+            parquet_schema_string_from_attributes,
+        },
     },
     pgrx_utils::{collect_attributes_for, CollectAttributesFor},
     type_compat::{geometry::reset_postgis_context, map::reset_map_context},
@@ -38,15 +42,18 @@ pub(crate) struct ParquetReaderContext {
     parquet_reader: ParquetRecordBatchStream<ParquetObjectReader>,
     attribute_contexts: Vec<ArrowToPgAttributeContext>,
     binary_out_funcs: Vec<PgBox<FmgrInfo>>,
+    match_by_name: bool,
 }
 
 impl ParquetReaderContext {
-    pub(crate) fn new(uri: Url, tupledesc: &PgTupleDesc) -> Self {
+    pub(crate) fn new(uri: Url, match_by_name: bool, tupledesc: &PgTupleDesc) -> Self {
         // Postgis and Map contexts are used throughout reading the parquet file.
         // We need to reset them to avoid reading stale data. (e.g. extension could be dropped)
         reset_postgis_context();
         reset_map_context();
 
+        error_if_copy_from_match_by_position_with_generated_columns(tupledesc, match_by_name);
+
         let parquet_reader = parquet_reader_from_uri(&uri);
 
         let parquet_file_schema = parquet_reader.schema();
@@ -69,6 +76,7 @@ impl ParquetReaderContext {
             parquet_file_schema.clone(),
             tupledesc_schema.clone(),
             &attributes,
+            match_by_name,
         );
 
         let attribute_contexts = collect_arrow_to_pg_attribute_contexts(
@@ -85,6 +93,7 @@ impl ParquetReaderContext {
             attribute_contexts,
             parquet_reader,
             binary_out_funcs,
+            match_by_name,
             started: false,
             finished: false,
         }
@@ -116,15 +125,23 @@ impl ParquetReaderContext {
     fn record_batch_to_tuple_datums(
         record_batch: RecordBatch,
        attribute_contexts: &[ArrowToPgAttributeContext],
+        match_by_name: bool,
     ) -> Vec<Option<Datum>> {
         let mut datums = vec![];
 
-        for attribute_context in attribute_contexts {
+        for (attribute_idx, attribute_context) in attribute_contexts.iter().enumerate() {
             let name = attribute_context.name();
 
-            let column_array = record_batch
-                .column_by_name(name)
-                .unwrap_or_else(|| panic!("column {} not found", name));
+            let column_array = if match_by_name {
+                record_batch
+                    .column_by_name(name)
+                    .unwrap_or_else(|| panic!("column {} not found", name))
+            } else {
+                record_batch
+                    .columns()
+                    .get(attribute_idx)
+                    .unwrap_or_else(|| panic!("column {} not found", name))
+            };
 
             let datum = if attribute_context.needs_cast() {
                 // should fail instead of returning None if the cast fails at runtime
@@ -181,8 +198,11 @@ impl ParquetReaderContext {
             self.buffer.extend_from_slice(&attnum_len_bytes);
 
             // convert the columnar arrays in record batch to tuple datums
-            let tuple_datums =
-                Self::record_batch_to_tuple_datums(record_batch, &self.attribute_contexts);
+            let tuple_datums = Self::record_batch_to_tuple_datums(
+                record_batch,
+                &self.attribute_contexts,
+                self.match_by_name,
+            );
 
             // write the tuple datums to the ParquetReader's internal buffer in PG copy format
             for (datum, out_func) in tuple_datums.into_iter().zip(self.binary_out_funcs.iter())
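In the reader, by-name matching keeps the old `column_by_name` lookup, while the default by-position path indexes `record_batch.columns()` by attribute order, so the file's field names are ignored; only order and coercible types matter. A SQL sketch of the default behavior, mirroring the new test added further down (table and file names follow the test, not production guidance):

```sql
-- the Parquet fields are named (a, b); the table columns are (x, y)
COPY (SELECT 1 AS a, 'hello' AS b) TO '/tmp/test.parquet';

CREATE TABLE test_table (x bigint, y varchar);

-- matched by position: a -> x (int coerced to bigint), b -> y
COPY test_table FROM '/tmp/test.parquet';
```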
diff --git a/src/arrow_parquet/pg_to_arrow.rs b/src/arrow_parquet/pg_to_arrow.rs
index 530c7f7..17d774b 100644
--- a/src/arrow_parquet/pg_to_arrow.rs
+++ b/src/arrow_parquet/pg_to_arrow.rs
@@ -148,7 +148,7 @@ impl PgToArrowAttributeContext {
         };
 
         let attributes =
-            collect_attributes_for(CollectAttributesFor::Struct, &attribute_tupledesc);
+            collect_attributes_for(CollectAttributesFor::Other, &attribute_tupledesc);
 
         collect_pg_to_arrow_attribute_contexts(&attributes, &fields)
     });
diff --git a/src/arrow_parquet/schema_parser.rs b/src/arrow_parquet/schema_parser.rs
index b76ee70..b4c5798 100644
--- a/src/arrow_parquet/schema_parser.rs
+++ b/src/arrow_parquet/schema_parser.rs
@@ -16,7 +16,7 @@ use pgrx::{check_for_interrupts, prelude::*, PgTupleDesc};
 use crate::{
     pgrx_utils::{
         array_element_typoid, collect_attributes_for, domain_array_base_elem_typoid, is_array_type,
-        is_composite_type, tuple_desc, CollectAttributesFor,
+        is_composite_type, is_generated_attribute, tuple_desc, CollectAttributesFor,
     },
     type_compat::{
         geometry::is_postgis_geometry_type,
@@ -95,7 +95,7 @@ fn parse_struct_schema(tupledesc: PgTupleDesc, elem_name: &str, field_id: &mut i
 
     let mut child_fields: Vec<Arc<Field>> = vec![];
 
-    let attributes = collect_attributes_for(CollectAttributesFor::Struct, &tupledesc);
+    let attributes = collect_attributes_for(CollectAttributesFor::Other, &tupledesc);
 
     for attribute in attributes {
         if attribute.is_dropped() {
@@ -342,6 +342,30 @@ fn adjust_map_entries_field(field: FieldRef) -> FieldRef {
     Arc::new(entries_field)
 }
 
+pub(crate) fn error_if_copy_from_match_by_position_with_generated_columns(
+    tupledesc: &PgTupleDesc,
+    match_by_name: bool,
+) {
+    // match_by_name can handle generated columns
+    if match_by_name {
+        return;
+    }
+
+    let attributes = collect_attributes_for(CollectAttributesFor::Other, tupledesc);
+
+    for attribute in attributes {
+        if is_generated_attribute(&attribute) {
+            ereport!(
+                PgLogLevel::ERROR,
+                PgSqlErrorCode::ERRCODE_FEATURE_NOT_SUPPORTED,
+                "COPY FROM parquet with generated columns is not supported",
+                "Try COPY FROM parquet WITH (match_by_name true). \
+                 It works only if the column names match the parquet file's.",
+            );
+        }
+    }
+}
+
 // ensure_file_schema_match_tupledesc_schema throws an error if the file's schema does not match the table schema.
 // If the file's arrow schema is castable to the table's arrow schema, it returns a vector of Option<DataType>
 // to cast to for each field.
@@ -349,21 +373,38 @@ pub(crate) fn ensure_file_schema_match_tupledesc_schema(
     file_schema: Arc<Schema>,
     tupledesc_schema: Arc<Schema>,
     attributes: &[FormData_pg_attribute],
+    match_by_name: bool,
 ) -> Vec<Option<DataType>> {
     let mut cast_to_types = Vec::new();
 
+    if !match_by_name && tupledesc_schema.fields().len() != file_schema.fields().len() {
+        panic!(
+            "column count mismatch between table and parquet file. \
+             parquet file has {} columns, but table has {} columns",
+            file_schema.fields().len(),
+            tupledesc_schema.fields().len()
+        );
+    }
+
     for (tupledesc_schema_field, attribute) in
         tupledesc_schema.fields().iter().zip(attributes.iter())
     {
         let field_name = tupledesc_schema_field.name();
 
-        let file_schema_field = file_schema.column_with_name(field_name);
+        let file_schema_field = if match_by_name {
+            let file_schema_field = file_schema.column_with_name(field_name);
 
-        if file_schema_field.is_none() {
-            panic!("column \"{}\" is not found in parquet file", field_name);
-        }
+            if file_schema_field.is_none() {
+                panic!("column \"{}\" is not found in parquet file", field_name);
+            }
+
+            let (_, file_schema_field) = file_schema_field.unwrap();
+
+            file_schema_field
+        } else {
+            file_schema.field(attribute.attnum as usize - 1)
+        };
 
-        let (_, file_schema_field) = file_schema_field.unwrap();
         let file_schema_field = Arc::new(file_schema_field.clone());
 
         let from_type = file_schema_field.data_type();
@@ -378,7 +419,7 @@ pub(crate) fn ensure_file_schema_match_tupledesc_schema(
         if !is_coercible(from_type, to_type, attribute.atttypid, attribute.atttypmod) {
             panic!(
                 "type mismatch for column \"{}\" between table and parquet file.\n\n\
-                table has \"{}\"\n\nparquet file has \"{}\"",
+                 table has \"{}\"\n\nparquet file has \"{}\"",
                 field_name, to_type, from_type
             );
         }
@@ -413,7 +454,7 @@ fn is_coercible(from_type: &DataType, to_type: &DataType, to_typoid: Oid, to_typ
 
     let tupledesc = tuple_desc(to_typoid, to_typmod);
 
-    let attributes = collect_attributes_for(CollectAttributesFor::Struct, &tupledesc);
+    let attributes = collect_attributes_for(CollectAttributesFor::Other, &tupledesc);
 
     for (from_field, (to_field, to_attribute)) in from_fields
         .iter()
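By-position matching requires the field counts to line up exactly, and it refuses tables with generated columns (COPY FROM skips generated columns, so positions would silently shift). Both failure modes, as exercised by the new tests (the `t3`/`g` table names are illustrative):

```sql
-- column count mismatch: the file has 2 fields, the table has 3 columns
CREATE TABLE t3 (x bigint, y varchar, z int);
COPY t3 FROM '/tmp/test.parquet';
-- ERROR: column count mismatch between table and parquet file

-- generated column without match_by_name
CREATE TABLE g (a int, b int GENERATED ALWAYS AS (10) STORED, c text);
COPY g FROM '/tmp/test.parquet' WITH (format parquet);
-- ERROR: COPY FROM parquet with generated columns is not supported
-- HINT: Try COPY FROM parquet WITH (match_by_name true).
```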
diff --git a/src/parquet_copy_hook/copy_from.rs b/src/parquet_copy_hook/copy_from.rs
index bf3a878..6a44446 100644
--- a/src/parquet_copy_hook/copy_from.rs
+++ b/src/parquet_copy_hook/copy_from.rs
@@ -20,8 +20,8 @@ use crate::{
 };
 
 use super::copy_utils::{
-    copy_stmt_attribute_list, copy_stmt_create_namespace_item, copy_stmt_create_parse_state,
-    create_filtered_tupledesc_for_relation,
+    copy_from_stmt_match_by_name, copy_stmt_attribute_list, copy_stmt_create_namespace_item,
+    copy_stmt_create_parse_state, create_filtered_tupledesc_for_relation,
 };
 
 // stack to store parquet reader contexts for COPY FROM.
@@ -131,9 +131,11 @@ pub(crate) fn execute_copy_from(
     let tupledesc = create_filtered_tupledesc_for_relation(p_stmt, &relation);
 
+    let match_by_name = copy_from_stmt_match_by_name(p_stmt);
+
     unsafe {
         // parquet reader context is used throughout the COPY FROM operation.
-        let parquet_reader_context = ParquetReaderContext::new(uri, &tupledesc);
+        let parquet_reader_context = ParquetReaderContext::new(uri, match_by_name, &tupledesc);
         push_parquet_reader_context(parquet_reader_context);
 
         // makes sure to set binary format
diff --git a/src/parquet_copy_hook/copy_utils.rs b/src/parquet_copy_hook/copy_utils.rs
index 068e95e..02c3c5d 100644
--- a/src/parquet_copy_hook/copy_utils.rs
+++ b/src/parquet_copy_hook/copy_utils.rs
@@ -3,11 +3,12 @@ use std::{ffi::CStr, str::FromStr};
 use pgrx::{
     is_a,
     pg_sys::{
-        addRangeTableEntryForRelation, defGetInt32, defGetInt64, defGetString, get_namespace_name,
-        get_rel_namespace, makeDefElem, makeString, make_parsestate, quote_qualified_identifier,
-        AccessShareLock, AsPgCStr, CopyStmt, CreateTemplateTupleDesc, DefElem, List, NoLock, Node,
-        NodeTag::T_CopyStmt, Oid, ParseNamespaceItem, ParseState, PlannedStmt, QueryEnvironment,
-        RangeVar, RangeVarGetRelidExtended, RowExclusiveLock, TupleDescInitEntry,
+        addRangeTableEntryForRelation, defGetBoolean, defGetInt32, defGetInt64, defGetString,
+        get_namespace_name, get_rel_namespace, makeDefElem, makeString, make_parsestate,
+        quote_qualified_identifier, AccessShareLock, AsPgCStr, CopyStmt, CreateTemplateTupleDesc,
+        DefElem, List, NoLock, Node, NodeTag::T_CopyStmt, Oid, ParseNamespaceItem, ParseState,
+        PlannedStmt, QueryEnvironment, RangeVar, RangeVarGetRelidExtended, RowExclusiveLock,
+        TupleDescInitEntry,
     },
     PgBox, PgList, PgRelation, PgTupleDesc,
 };
@@ -109,7 +110,7 @@ pub(crate) fn validate_copy_to_options(p_stmt: &PgBox<PlannedStmt>, uri: &Url) {
 }
 
 pub(crate) fn validate_copy_from_options(p_stmt: &PgBox<PlannedStmt>) {
-    validate_copy_option_names(p_stmt, &["format", "freeze"]);
+    validate_copy_option_names(p_stmt, &["format", "match_by_name", "freeze"]);
 
     let format_option = copy_stmt_get_option(p_stmt, "format");
 
@@ -253,6 +254,16 @@ pub(crate) fn copy_from_stmt_create_option_list(p_stmt: &PgBox<PlannedStmt>) ->
     new_copy_options
 }
 
+pub(crate) fn copy_from_stmt_match_by_name(p_stmt: &PgBox<PlannedStmt>) -> bool {
+    let match_by_name_option = copy_stmt_get_option(p_stmt, "match_by_name");
+
+    if match_by_name_option.is_null() {
+        false
+    } else {
+        unsafe { defGetBoolean(match_by_name_option.as_ptr()) }
+    }
+}
+
 pub(crate) fn copy_stmt_get_option(
     p_stmt: &PgBox<PlannedStmt>,
     option_name: &str,
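Because the option value goes through PostgreSQL's standard `defGetBoolean` parser, the usual boolean spellings should all be accepted, and a bare option with no value should read as true (this follows from `defGetBoolean` semantics, not from anything tested in this patch):

```sql
COPY test_table FROM '/tmp/test.parquet' WITH (match_by_name);       -- bare option means true
COPY test_table FROM '/tmp/test.parquet' WITH (match_by_name 'on');  -- on/off, true/false, 0/1
```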
1").unwrap(); assert_eq!(result, (Some("hello"), Some(1))); } + #[pg_test] + fn test_table_with_different_name_match_by_position() { + let copy_to = "COPY (SELECT 1 as a, 'hello' as b) TO '/tmp/test.parquet'"; + Spi::run(copy_to).unwrap(); + + let create_table = "CREATE TABLE test_table (x bigint, y varchar)"; + Spi::run(create_table).unwrap(); + + let copy_from = "COPY test_table FROM '/tmp/test.parquet'"; + Spi::run(copy_from).unwrap(); + + let result = Spi::get_two::("SELECT x, y FROM test_table LIMIT 1").unwrap(); + assert_eq!(result, (Some(1), Some("hello"))); + } + + #[pg_test] + #[should_panic(expected = "column count mismatch between table and parquet file")] + fn test_table_with_different_name_match_by_position_fail() { + let copy_to = "COPY (SELECT 1 as a, 'hello' as b) TO '/tmp/test.parquet'"; + Spi::run(copy_to).unwrap(); + + let create_table = "CREATE TABLE test_table (x bigint, y varchar, z int)"; + Spi::run(create_table).unwrap(); + + let copy_from = "COPY test_table FROM '/tmp/test.parquet'"; + Spi::run(copy_from).unwrap(); + } + #[pg_test] #[should_panic(expected = "column \"name\" is not found in parquet file")] fn test_missing_column_in_parquet() { @@ -992,7 +1023,10 @@ mod tests { let copy_to_parquet = format!("copy (select 100 as id) to '{}';", LOCAL_TEST_FILE_PATH); Spi::run(©_to_parquet).unwrap(); - let copy_from = format!("COPY test_table FROM '{}'", LOCAL_TEST_FILE_PATH); + let copy_from = format!( + "COPY test_table FROM '{}' with (match_by_name true)", + LOCAL_TEST_FILE_PATH + ); Spi::run(©_from).unwrap(); } diff --git a/src/pgrx_tests/copy_pg_rules.rs b/src/pgrx_tests/copy_pg_rules.rs index b0347e2..35c44d6 100644 --- a/src/pgrx_tests/copy_pg_rules.rs +++ b/src/pgrx_tests/copy_pg_rules.rs @@ -101,6 +101,21 @@ mod tests { Spi::run(©_from).unwrap(); } + #[pg_test] + #[should_panic(expected = "COPY FROM parquet with generated columns is not supported")] + fn test_copy_from_by_position_with_generated_columns_not_supported() { + Spi::run("DROP TABLE IF EXISTS test_table;").unwrap(); + + Spi::run("CREATE TABLE test_table (a int, b int generated always as (10) stored, c text);") + .unwrap(); + + let copy_from_query = format!( + "COPY test_table FROM '{}' WITH (format parquet);", + LOCAL_TEST_FILE_PATH + ); + Spi::run(copy_from_query.as_str()).unwrap(); + } + #[pg_test] fn test_with_generated_and_dropped_columns() { Spi::run("DROP TABLE IF EXISTS test_table;").unwrap(); @@ -123,7 +138,7 @@ mod tests { Spi::run("TRUNCATE test_table;").unwrap(); let copy_from_query = format!( - "COPY test_table FROM '{}' WITH (format parquet);", + "COPY test_table FROM '{}' WITH (format parquet, match_by_name true);", LOCAL_TEST_FILE_PATH ); Spi::run(copy_from_query.as_str()).unwrap(); diff --git a/src/pgrx_utils.rs b/src/pgrx_utils.rs index cb8f9fa..0793def 100644 --- a/src/pgrx_utils.rs +++ b/src/pgrx_utils.rs @@ -12,7 +12,7 @@ use pgrx::{ pub(crate) enum CollectAttributesFor { CopyFrom, CopyTo, - Struct, + Other, } // collect_attributes_for collects not-dropped attributes from the tuple descriptor. 
diff --git a/src/pgrx_utils.rs b/src/pgrx_utils.rs
index cb8f9fa..0793def 100644
--- a/src/pgrx_utils.rs
+++ b/src/pgrx_utils.rs
@@ -12,7 +12,7 @@ use pgrx::{
 pub(crate) enum CollectAttributesFor {
     CopyFrom,
     CopyTo,
-    Struct,
+    Other,
 }
 
 // collect_attributes_for collects not-dropped attributes from the tuple descriptor.
@@ -23,7 +23,7 @@ pub(crate) fn collect_attributes_for(
 ) -> Vec<FormData_pg_attribute> {
     let include_generated_columns = match copy_operation {
         CollectAttributesFor::CopyFrom => false,
-        CollectAttributesFor::CopyTo | CollectAttributesFor::Struct => true,
+        CollectAttributesFor::CopyTo | CollectAttributesFor::Other => true,
     };
 
     let mut attributes = vec![];
@@ -35,7 +35,7 @@ pub(crate) fn collect_attributes_for(
             continue;
         }
 
-        if !include_generated_columns && attribute.attgenerated != 0 {
+        if !include_generated_columns && is_generated_attribute(attribute) {
             continue;
         }
@@ -55,6 +55,10 @@ pub(crate) fn collect_attributes_for(
     attributes
 }
 
+pub(crate) fn is_generated_attribute(attribute: &FormData_pg_attribute) -> bool {
+    attribute.attgenerated != 0
+}
+
 pub(crate) fn tuple_desc(typoid: Oid, typmod: i32) -> PgTupleDesc<'static> {
     let tupledesc = unsafe { lookup_rowtype_tupdesc(typoid, typmod) };
     unsafe { PgTupleDesc::from_pg(tupledesc) }