Match fields by name via option
We add a `COPY FROM` option called `match_by_name`, which matches Parquet file fields to PostgreSQL table columns by name rather than by their order in the schema. The option defaults to `false`. It is useful when the field order differs between the Parquet file and the table but the field names match.

**!!IMPORTANT!!**: This is a breaking change. Before this PR, we always matched by name, which is stricter than the common way of matching schemas (e.g. PostgreSQL's `COPY FROM` for CSV and DuckDB's `COPY FROM` both match by field position by default). We now match by position by default and provide the `COPY FROM` option `match_by_name`, which can be set to `true` to restore the old behaviour.
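A minimal sketch of both modes (the table and file names are illustrative, not taken from this PR):

```sql
-- Write a two-field Parquet file.
COPY (SELECT 1 AS a, 'hello' AS b) TO '/tmp/example.parquet';

-- Default: fields are matched to columns by position (a -> x, b -> y),
-- so only the order and types matter, not the names.
CREATE TABLE items (x bigint, y text);
COPY items FROM '/tmp/example.parquet';

-- match_by_name: fields are matched to columns by name,
-- so a reordered table still loads correctly.
CREATE TABLE items_reordered (b text, a bigint);
COPY items_reordered FROM '/tmp/example.parquet' WITH (match_by_name true);
```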

Closes #39.
aykut-bozkurt committed Nov 27, 2024
1 parent fbaeadb commit 4241cae
Showing 10 changed files with 163 additions and 35 deletions.
1 change: 1 addition & 0 deletions README.md
@@ -193,6 +193,7 @@ Alternatively, you can use the following environment variables when starting postgres

`pg_parquet` supports the following options in the `COPY FROM` command:
- `format parquet`: you need to specify this option to read or write Parquet files whose names do not end with a `.parquet[.<compression>]` extension,
- `match_by_name <bool>`: matches Parquet file fields to PostgreSQL table columns by name rather than by position in the schema. Defaults to `false` (match by position). Useful when the field order differs between the Parquet file and the table but the field names match.

## Configuration
There is currently only one GUC parameter to enable/disable the `pg_parquet`:
2 changes: 1 addition & 1 deletion src/arrow_parquet/arrow_to_pg.rs
@@ -163,7 +163,7 @@ impl ArrowToPgAttributeContext {
};

let attributes =
collect_attributes_for(CollectAttributesFor::Struct, attribute_tupledesc);
collect_attributes_for(CollectAttributesFor::Other, attribute_tupledesc);

// we only cast the top-level attributes, which already covers the nested attributes
let cast_to_types = None;
36 changes: 28 additions & 8 deletions src/arrow_parquet/parquet_reader.rs
@@ -16,7 +16,11 @@ use url::Url;

use crate::{
arrow_parquet::{
arrow_to_pg::to_pg_datum, schema_parser::parquet_schema_string_from_attributes,
arrow_to_pg::to_pg_datum,
schema_parser::{
error_if_copy_from_match_by_position_with_generated_columns,
parquet_schema_string_from_attributes,
},
},
pgrx_utils::{collect_attributes_for, CollectAttributesFor},
type_compat::{geometry::reset_postgis_context, map::reset_map_context},
@@ -38,15 +42,18 @@ pub(crate) struct ParquetReaderContext {
parquet_reader: ParquetRecordBatchStream<ParquetObjectReader>,
attribute_contexts: Vec<ArrowToPgAttributeContext>,
binary_out_funcs: Vec<PgBox<FmgrInfo>>,
match_by_name: bool,
}

impl ParquetReaderContext {
pub(crate) fn new(uri: Url, tupledesc: &PgTupleDesc) -> Self {
pub(crate) fn new(uri: Url, match_by_name: bool, tupledesc: &PgTupleDesc) -> Self {
// Postgis and Map contexts are used throughout reading the parquet file.
// We need to reset them to avoid reading the stale data. (e.g. extension could be dropped)
reset_postgis_context();
reset_map_context();

error_if_copy_from_match_by_position_with_generated_columns(tupledesc, match_by_name);

let parquet_reader = parquet_reader_from_uri(&uri);

let parquet_file_schema = parquet_reader.schema();
@@ -69,6 +76,7 @@ impl ParquetReaderContext {
parquet_file_schema.clone(),
tupledesc_schema.clone(),
&attributes,
match_by_name,
);

let attribute_contexts = collect_arrow_to_pg_attribute_contexts(
@@ -85,6 +93,7 @@ impl ParquetReaderContext {
attribute_contexts,
parquet_reader,
binary_out_funcs,
match_by_name,
started: false,
finished: false,
}
@@ -116,15 +125,23 @@ impl ParquetReaderContext {
fn record_batch_to_tuple_datums(
record_batch: RecordBatch,
attribute_contexts: &[ArrowToPgAttributeContext],
match_by_name: bool,
) -> Vec<Option<Datum>> {
let mut datums = vec![];

for attribute_context in attribute_contexts {
for (attribute_idx, attribute_context) in attribute_contexts.iter().enumerate() {
let name = attribute_context.name();

let column_array = record_batch
.column_by_name(name)
.unwrap_or_else(|| panic!("column {} not found", name));
let column_array = if match_by_name {
record_batch
.column_by_name(name)
.unwrap_or_else(|| panic!("column {} not found", name))
} else {
record_batch
.columns()
.get(attribute_idx)
.unwrap_or_else(|| panic!("column {} not found", name))
};

let datum = if attribute_context.needs_cast() {
// should fail instead of returning None if the cast fails at runtime
@@ -181,8 +198,11 @@ impl ParquetReaderContext {
self.buffer.extend_from_slice(&attnum_len_bytes);

// convert the columnar arrays in record batch to tuple datums
let tuple_datums =
Self::record_batch_to_tuple_datums(record_batch, &self.attribute_contexts);
let tuple_datums = Self::record_batch_to_tuple_datums(
record_batch,
&self.attribute_contexts,
self.match_by_name,
);

// write the tuple datums to the ParquetReader's internal buffer in PG copy format
for (datum, out_func) in tuple_datums.into_iter().zip(self.binary_out_funcs.iter())
2 changes: 1 addition & 1 deletion src/arrow_parquet/pg_to_arrow.rs
@@ -148,7 +148,7 @@ impl PgToArrowAttributeContext {
};

let attributes =
collect_attributes_for(CollectAttributesFor::Struct, &attribute_tupledesc);
collect_attributes_for(CollectAttributesFor::Other, &attribute_tupledesc);

collect_pg_to_arrow_attribute_contexts(&attributes, &fields)
});
59 changes: 50 additions & 9 deletions src/arrow_parquet/schema_parser.rs
@@ -16,7 +16,7 @@ use pgrx::{check_for_interrupts, prelude::*, PgTupleDesc};
use crate::{
pgrx_utils::{
array_element_typoid, collect_attributes_for, domain_array_base_elem_typoid, is_array_type,
is_composite_type, tuple_desc, CollectAttributesFor,
is_composite_type, is_generated_attribute, tuple_desc, CollectAttributesFor,
},
type_compat::{
geometry::is_postgis_geometry_type,
@@ -95,7 +95,7 @@ fn parse_struct_schema(tupledesc: PgTupleDesc, elem_name: &str, field_id: &mut i

let mut child_fields: Vec<Arc<Field>> = vec![];

let attributes = collect_attributes_for(CollectAttributesFor::Struct, &tupledesc);
let attributes = collect_attributes_for(CollectAttributesFor::Other, &tupledesc);

for attribute in attributes {
if attribute.is_dropped() {
@@ -342,28 +342,69 @@ fn adjust_map_entries_field(field: FieldRef) -> FieldRef {
Arc::new(entries_field)
}

pub(crate) fn error_if_copy_from_match_by_position_with_generated_columns(
tupledesc: &PgTupleDesc,
match_by_name: bool,
) {
// match_by_name can handle generated columns
if match_by_name {
return;
}

let attributes = collect_attributes_for(CollectAttributesFor::Other, tupledesc);

for attribute in attributes {
if is_generated_attribute(&attribute) {
ereport!(
PgLogLevel::ERROR,
PgSqlErrorCode::ERRCODE_FEATURE_NOT_SUPPORTED,
"COPY FROM parquet with generated columns is not supported",
"Try COPY FROM parquet WITH (match_by_name true). \"
It works only if the column names match with parquet file's.",
);
}
}
}

// ensure_file_schema_match_tupledesc_schema throws an error if the file's schema does not match the table schema.
// If the file's arrow schema is castable to the table's arrow schema, it returns a vector of Option<DataType>
// to cast to for each field.
pub(crate) fn ensure_file_schema_match_tupledesc_schema(
file_schema: Arc<Schema>,
tupledesc_schema: Arc<Schema>,
attributes: &[FormData_pg_attribute],
match_by_name: bool,
) -> Vec<Option<DataType>> {
let mut cast_to_types = Vec::new();

if !match_by_name && tupledesc_schema.fields().len() != file_schema.fields().len() {
panic!(
"column count mismatch between table and parquet file. \
parquet file has {} columns, but table has {} columns",
file_schema.fields().len(),
tupledesc_schema.fields().len()
);
}

for (tupledesc_schema_field, attribute) in
tupledesc_schema.fields().iter().zip(attributes.iter())
{
let field_name = tupledesc_schema_field.name();

let file_schema_field = file_schema.column_with_name(field_name);
let file_schema_field = if match_by_name {
let file_schema_field = file_schema.column_with_name(field_name);

if file_schema_field.is_none() {
panic!("column \"{}\" is not found in parquet file", field_name);
}
if file_schema_field.is_none() {
panic!("column \"{}\" is not found in parquet file", field_name);
}

let (_, file_schema_field) = file_schema_field.unwrap();

file_schema_field
} else {
file_schema.field(attribute.attnum as usize - 1)
};

let (_, file_schema_field) = file_schema_field.unwrap();
let file_schema_field = Arc::new(file_schema_field.clone());

let from_type = file_schema_field.data_type();
Expand All @@ -378,7 +419,7 @@ pub(crate) fn ensure_file_schema_match_tupledesc_schema(
if !is_coercible(from_type, to_type, attribute.atttypid, attribute.atttypmod) {
panic!(
"type mismatch for column \"{}\" between table and parquet file.\n\n\
table has \"{}\"\n\nparquet file has \"{}\"",
table has \"{}\"\n\nparquet file has \"{}\"",
field_name, to_type, from_type
);
}
@@ -413,7 +454,7 @@ fn is_coercible(from_type: &DataType, to_type: &DataType, to_typoid: Oid, to_typ

let tupledesc = tuple_desc(to_typoid, to_typmod);

let attributes = collect_attributes_for(CollectAttributesFor::Struct, &tupledesc);
let attributes = collect_attributes_for(CollectAttributesFor::Other, &tupledesc);

for (from_field, (to_field, to_attribute)) in from_fields
.iter()
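The `error_if_copy_from_match_by_position_with_generated_columns` guard above exists because, when matching by position, the file has no field that can be skipped for a generated column. A hedged SQL sketch of the behaviour (table and file names are illustrative):

```sql
CREATE TABLE events (id int, year int GENERATED ALWAYS AS (2024) STORED);

-- Matching by position errors out:
-- ERROR: COPY FROM parquet with generated columns is not supported
COPY events FROM '/tmp/events.parquet';

-- Matching by name works, provided the file's field names
-- match the table's non-generated columns (here: id).
COPY events FROM '/tmp/events.parquet' WITH (match_by_name true);
```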
8 changes: 5 additions & 3 deletions src/parquet_copy_hook/copy_from.rs
@@ -20,8 +20,8 @@ use crate::{
};

use super::copy_utils::{
copy_stmt_attribute_list, copy_stmt_create_namespace_item, copy_stmt_create_parse_state,
create_filtered_tupledesc_for_relation,
copy_from_stmt_match_by_name, copy_stmt_attribute_list, copy_stmt_create_namespace_item,
copy_stmt_create_parse_state, create_filtered_tupledesc_for_relation,
};

// stack to store parquet reader contexts for COPY FROM.
@@ -131,9 +131,11 @@ pub(crate) fn execute_copy_from(

let tupledesc = create_filtered_tupledesc_for_relation(p_stmt, &relation);

let match_by_name = copy_from_stmt_match_by_name(p_stmt);

unsafe {
// parquet reader context is used throughout the COPY FROM operation.
let parquet_reader_context = ParquetReaderContext::new(uri, &tupledesc);
let parquet_reader_context = ParquetReaderContext::new(uri, match_by_name, &tupledesc);
push_parquet_reader_context(parquet_reader_context);

// makes sure to set binary format
23 changes: 17 additions & 6 deletions src/parquet_copy_hook/copy_utils.rs
@@ -3,11 +3,12 @@ use std::{ffi::CStr, str::FromStr};
use pgrx::{
is_a,
pg_sys::{
addRangeTableEntryForRelation, defGetInt32, defGetInt64, defGetString, get_namespace_name,
get_rel_namespace, makeDefElem, makeString, make_parsestate, quote_qualified_identifier,
AccessShareLock, AsPgCStr, CopyStmt, CreateTemplateTupleDesc, DefElem, List, NoLock, Node,
NodeTag::T_CopyStmt, Oid, ParseNamespaceItem, ParseState, PlannedStmt, QueryEnvironment,
RangeVar, RangeVarGetRelidExtended, RowExclusiveLock, TupleDescInitEntry,
addRangeTableEntryForRelation, defGetBoolean, defGetInt32, defGetInt64, defGetString,
get_namespace_name, get_rel_namespace, makeDefElem, makeString, make_parsestate,
quote_qualified_identifier, AccessShareLock, AsPgCStr, CopyStmt, CreateTemplateTupleDesc,
DefElem, List, NoLock, Node, NodeTag::T_CopyStmt, Oid, ParseNamespaceItem, ParseState,
PlannedStmt, QueryEnvironment, RangeVar, RangeVarGetRelidExtended, RowExclusiveLock,
TupleDescInitEntry,
},
PgBox, PgList, PgRelation, PgTupleDesc,
};
@@ -109,7 +110,7 @@ pub(crate) fn validate_copy_to_options(p_stmt: &PgBox<PlannedStmt>, uri: &Url) {
}

pub(crate) fn validate_copy_from_options(p_stmt: &PgBox<PlannedStmt>) {
validate_copy_option_names(p_stmt, &["format", "freeze"]);
validate_copy_option_names(p_stmt, &["format", "match_by_name", "freeze"]);

let format_option = copy_stmt_get_option(p_stmt, "format");

@@ -253,6 +254,16 @@ pub(crate) fn copy_from_stmt_create_option_list(p_stmt: &PgBox<PlannedStmt>) ->
new_copy_options
}

pub(crate) fn copy_from_stmt_match_by_name(p_stmt: &PgBox<PlannedStmt>) -> bool {
let match_by_name_option = copy_stmt_get_option(p_stmt, "match_by_name");

if match_by_name_option.is_null() {
false
} else {
unsafe { defGetBoolean(match_by_name_option.as_ptr()) }
}
}

pub(crate) fn copy_stmt_get_option(
p_stmt: &PgBox<PlannedStmt>,
option_name: &str,
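Since `copy_from_stmt_match_by_name` reads the option through PostgreSQL's `defGetBoolean`, the usual boolean spellings should all be accepted; a sketch assuming standard `DefElem` boolean parsing:

```sql
COPY items FROM '/tmp/example.parquet' WITH (match_by_name true);
COPY items FROM '/tmp/example.parquet' WITH (match_by_name on);  -- same as true
COPY items FROM '/tmp/example.parquet' WITH (match_by_name);     -- a bare boolean option reads as true
```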
40 changes: 37 additions & 3 deletions src/pgrx_tests/copy_from_coerce.rs
@@ -966,7 +966,7 @@ mod tests {
}

#[pg_test]
fn test_table_with_different_field_position() {
fn test_table_with_different_position_match_by_name() {
let copy_to = format!(
"COPY (SELECT 1 as x, 'hello' as y) TO '{}'",
LOCAL_TEST_FILE_PATH
@@ -976,13 +976,44 @@
let create_table = "CREATE TABLE test_table (y text, x int)";
Spi::run(create_table).unwrap();

let copy_from = format!("COPY test_table FROM '{}'", LOCAL_TEST_FILE_PATH);
let copy_from = format!(
"COPY test_table FROM '{}' WITH (match_by_name true)",
LOCAL_TEST_FILE_PATH
);
Spi::run(&copy_from).unwrap();

let result = Spi::get_two::<&str, i32>("SELECT y, x FROM test_table LIMIT 1").unwrap();
assert_eq!(result, (Some("hello"), Some(1)));
}

#[pg_test]
fn test_table_with_different_name_match_by_position() {
let copy_to = "COPY (SELECT 1 as a, 'hello' as b) TO '/tmp/test.parquet'";
Spi::run(copy_to).unwrap();

let create_table = "CREATE TABLE test_table (x bigint, y varchar)";
Spi::run(create_table).unwrap();

let copy_from = "COPY test_table FROM '/tmp/test.parquet'";
Spi::run(copy_from).unwrap();

let result = Spi::get_two::<i64, &str>("SELECT x, y FROM test_table LIMIT 1").unwrap();
assert_eq!(result, (Some(1), Some("hello")));
}

#[pg_test]
#[should_panic(expected = "column count mismatch between table and parquet file")]
fn test_table_with_different_name_match_by_position_fail() {
let copy_to = "COPY (SELECT 1 as a, 'hello' as b) TO '/tmp/test.parquet'";
Spi::run(copy_to).unwrap();

let create_table = "CREATE TABLE test_table (x bigint, y varchar, z int)";
Spi::run(create_table).unwrap();

let copy_from = "COPY test_table FROM '/tmp/test.parquet'";
Spi::run(copy_from).unwrap();
}

#[pg_test]
#[should_panic(expected = "column \"name\" is not found in parquet file")]
fn test_missing_column_in_parquet() {
@@ -992,7 +1023,10 @@
let copy_to_parquet = format!("copy (select 100 as id) to '{}';", LOCAL_TEST_FILE_PATH);
Spi::run(&copy_to_parquet).unwrap();

let copy_from = format!("COPY test_table FROM '{}'", LOCAL_TEST_FILE_PATH);
let copy_from = format!(
"COPY test_table FROM '{}' with (match_by_name true)",
LOCAL_TEST_FILE_PATH
);
Spi::run(&copy_from).unwrap();
}

17 changes: 16 additions & 1 deletion src/pgrx_tests/copy_pg_rules.rs
@@ -101,6 +101,21 @@ mod tests {
Spi::run(&copy_from).unwrap();
}

#[pg_test]
#[should_panic(expected = "COPY FROM parquet with generated columns is not supported")]
fn test_copy_from_by_position_with_generated_columns_not_supported() {
Spi::run("DROP TABLE IF EXISTS test_table;").unwrap();

Spi::run("CREATE TABLE test_table (a int, b int generated always as (10) stored, c text);")
.unwrap();

let copy_from_query = format!(
"COPY test_table FROM '{}' WITH (format parquet);",
LOCAL_TEST_FILE_PATH
);
Spi::run(copy_from_query.as_str()).unwrap();
}

#[pg_test]
fn test_with_generated_and_dropped_columns() {
Spi::run("DROP TABLE IF EXISTS test_table;").unwrap();
@@ -123,7 +138,7 @@
Spi::run("TRUNCATE test_table;").unwrap();

let copy_from_query = format!(
"COPY test_table FROM '{}' WITH (format parquet);",
"COPY test_table FROM '{}' WITH (format parquet, match_by_name true);",
LOCAL_TEST_FILE_PATH
);
Spi::run(copy_from_query.as_str()).unwrap();
