Skip to content

Commit

Permalink
Match fields by position via option
Browse files Browse the repository at this point in the history
We add an option for `COPY FROM` called `match_by_position` which matches Parquet file fields to PostgreSQL table columns
`by their position` in the schema rather than `by their names`. By default, the option is `false`. The option is useful
when field names differ between the Parquet file and the table, but their order aligns.

Closes #39.
  • Loading branch information
aykut-bozkurt committed Nov 27, 2024
1 parent fbaeadb commit ed4907a
Show file tree
Hide file tree
Showing 6 changed files with 96 additions and 23 deletions.
1 change: 1 addition & 0 deletions README.md
Original file line number Diff line number Diff line change
Expand Up @@ -193,6 +193,7 @@ Alternatively, you can use the following environment variables when starting pos

`pg_parquet` supports the following options in the `COPY FROM` command:
- `format parquet`: you need to specify this option to read or write Parquet files which does not end with `.parquet[.<compression>]` extension,
- `match_by_position <bool>`: matches Parquet file fields to PostgreSQL table columns by their position in the schema rather than by their names. By default, the option is `false`. The option is useful when field names differ between the Parquet file and the table, but their order aligns.

## Configuration
There is currently only one GUC parameter to enable/disable the `pg_parquet`:
Expand Down
28 changes: 21 additions & 7 deletions src/arrow_parquet/parquet_reader.rs
Original file line number Diff line number Diff line change
Expand Up @@ -38,10 +38,11 @@ pub(crate) struct ParquetReaderContext {
parquet_reader: ParquetRecordBatchStream<ParquetObjectReader>,
attribute_contexts: Vec<ArrowToPgAttributeContext>,
binary_out_funcs: Vec<PgBox<FmgrInfo>>,
match_by_position: bool,
}

impl ParquetReaderContext {
pub(crate) fn new(uri: Url, tupledesc: &PgTupleDesc) -> Self {
pub(crate) fn new(uri: Url, match_by_position: bool, tupledesc: &PgTupleDesc) -> Self {
// Postgis and Map contexts are used throughout reading the parquet file.
// We need to reset them to avoid reading the stale data. (e.g. extension could be dropped)
reset_postgis_context();
Expand Down Expand Up @@ -69,6 +70,7 @@ impl ParquetReaderContext {
parquet_file_schema.clone(),
tupledesc_schema.clone(),
&attributes,
match_by_position,
);

let attribute_contexts = collect_arrow_to_pg_attribute_contexts(
Expand All @@ -85,6 +87,7 @@ impl ParquetReaderContext {
attribute_contexts,
parquet_reader,
binary_out_funcs,
match_by_position,
started: false,
finished: false,
}
Expand Down Expand Up @@ -116,15 +119,23 @@ impl ParquetReaderContext {
fn record_batch_to_tuple_datums(
record_batch: RecordBatch,
attribute_contexts: &[ArrowToPgAttributeContext],
match_by_position: bool,
) -> Vec<Option<Datum>> {
let mut datums = vec![];

for attribute_context in attribute_contexts {
for (attribute_idx, attribute_context) in attribute_contexts.iter().enumerate() {
let name = attribute_context.name();

let column_array = record_batch
.column_by_name(name)
.unwrap_or_else(|| panic!("column {} not found", name));
let column_array = if match_by_position {
record_batch
.columns()
.get(attribute_idx)
.unwrap_or_else(|| panic!("column {} not found", name))
} else {
record_batch
.column_by_name(name)
.unwrap_or_else(|| panic!("column {} not found", name))
};

let datum = if attribute_context.needs_cast() {
// should fail instead of returning None if the cast fails at runtime
Expand Down Expand Up @@ -181,8 +192,11 @@ impl ParquetReaderContext {
self.buffer.extend_from_slice(&attnum_len_bytes);

// convert the columnar arrays in record batch to tuple datums
let tuple_datums =
Self::record_batch_to_tuple_datums(record_batch, &self.attribute_contexts);
let tuple_datums = Self::record_batch_to_tuple_datums(
record_batch,
&self.attribute_contexts,
self.match_by_position,
);

// write the tuple datums to the ParquetReader's internal buffer in PG copy format
for (datum, out_func) in tuple_datums.into_iter().zip(self.binary_out_funcs.iter())
Expand Down
29 changes: 23 additions & 6 deletions src/arrow_parquet/schema_parser.rs
Original file line number Diff line number Diff line change
Expand Up @@ -349,21 +349,38 @@ pub(crate) fn ensure_file_schema_match_tupledesc_schema(
file_schema: Arc<Schema>,
tupledesc_schema: Arc<Schema>,
attributes: &[FormData_pg_attribute],
match_by_position: bool,
) -> Vec<Option<DataType>> {
let mut cast_to_types = Vec::new();

if match_by_position && tupledesc_schema.fields().len() != file_schema.fields().len() {
panic!(
"column count mismatch between table and parquet file. \
parquet file has {} columns, but table has {} columns",
file_schema.fields().len(),
tupledesc_schema.fields().len()
);
}

for (tupledesc_schema_field, attribute) in
tupledesc_schema.fields().iter().zip(attributes.iter())
{
let field_name = tupledesc_schema_field.name();

let file_schema_field = file_schema.column_with_name(field_name);
let file_schema_field = if match_by_position {
file_schema.field(attribute.attnum as usize - 1)
} else {
let file_schema_field = file_schema.column_with_name(field_name);

if file_schema_field.is_none() {
panic!("column \"{}\" is not found in parquet file", field_name);
}
if file_schema_field.is_none() {
panic!("column \"{}\" is not found in parquet file", field_name);
}

let (_, file_schema_field) = file_schema_field.unwrap();

file_schema_field
};

let (_, file_schema_field) = file_schema_field.unwrap();
let file_schema_field = Arc::new(file_schema_field.clone());

let from_type = file_schema_field.data_type();
Expand All @@ -378,7 +395,7 @@ pub(crate) fn ensure_file_schema_match_tupledesc_schema(
if !is_coercible(from_type, to_type, attribute.atttypid, attribute.atttypmod) {
panic!(
"type mismatch for column \"{}\" between table and parquet file.\n\n\
table has \"{}\"\n\nparquet file has \"{}\"",
table has \"{}\"\n\nparquet file has \"{}\"",
field_name, to_type, from_type
);
}
Expand Down
8 changes: 5 additions & 3 deletions src/parquet_copy_hook/copy_from.rs
Original file line number Diff line number Diff line change
Expand Up @@ -20,8 +20,8 @@ use crate::{
};

use super::copy_utils::{
copy_stmt_attribute_list, copy_stmt_create_namespace_item, copy_stmt_create_parse_state,
create_filtered_tupledesc_for_relation,
copy_from_stmt_match_by_position, copy_stmt_attribute_list, copy_stmt_create_namespace_item,
copy_stmt_create_parse_state, create_filtered_tupledesc_for_relation,
};

// stack to store parquet reader contexts for COPY FROM.
Expand Down Expand Up @@ -131,9 +131,11 @@ pub(crate) fn execute_copy_from(

let tupledesc = create_filtered_tupledesc_for_relation(p_stmt, &relation);

let match_by_position = copy_from_stmt_match_by_position(p_stmt);

unsafe {
// parquet reader context is used throughout the COPY FROM operation.
let parquet_reader_context = ParquetReaderContext::new(uri, &tupledesc);
let parquet_reader_context = ParquetReaderContext::new(uri, match_by_position, &tupledesc);
push_parquet_reader_context(parquet_reader_context);

// makes sure to set binary format
Expand Down
23 changes: 17 additions & 6 deletions src/parquet_copy_hook/copy_utils.rs
Original file line number Diff line number Diff line change
Expand Up @@ -3,11 +3,12 @@ use std::{ffi::CStr, str::FromStr};
use pgrx::{
is_a,
pg_sys::{
addRangeTableEntryForRelation, defGetInt32, defGetInt64, defGetString, get_namespace_name,
get_rel_namespace, makeDefElem, makeString, make_parsestate, quote_qualified_identifier,
AccessShareLock, AsPgCStr, CopyStmt, CreateTemplateTupleDesc, DefElem, List, NoLock, Node,
NodeTag::T_CopyStmt, Oid, ParseNamespaceItem, ParseState, PlannedStmt, QueryEnvironment,
RangeVar, RangeVarGetRelidExtended, RowExclusiveLock, TupleDescInitEntry,
addRangeTableEntryForRelation, defGetBoolean, defGetInt32, defGetInt64, defGetString,
get_namespace_name, get_rel_namespace, makeDefElem, makeString, make_parsestate,
quote_qualified_identifier, AccessShareLock, AsPgCStr, CopyStmt, CreateTemplateTupleDesc,
DefElem, List, NoLock, Node, NodeTag::T_CopyStmt, Oid, ParseNamespaceItem, ParseState,
PlannedStmt, QueryEnvironment, RangeVar, RangeVarGetRelidExtended, RowExclusiveLock,
TupleDescInitEntry,
},
PgBox, PgList, PgRelation, PgTupleDesc,
};
Expand Down Expand Up @@ -109,7 +110,7 @@ pub(crate) fn validate_copy_to_options(p_stmt: &PgBox<PlannedStmt>, uri: &Url) {
}

pub(crate) fn validate_copy_from_options(p_stmt: &PgBox<PlannedStmt>) {
validate_copy_option_names(p_stmt, &["format", "freeze"]);
validate_copy_option_names(p_stmt, &["format", "match_by_position", "freeze"]);

let format_option = copy_stmt_get_option(p_stmt, "format");

Expand Down Expand Up @@ -253,6 +254,16 @@ pub(crate) fn copy_from_stmt_create_option_list(p_stmt: &PgBox<PlannedStmt>) ->
new_copy_options
}

pub(crate) fn copy_from_stmt_match_by_position(p_stmt: &PgBox<PlannedStmt>) -> bool {
let match_by_position_option = copy_stmt_get_option(p_stmt, "match_by_position");

if match_by_position_option.is_null() {
false
} else {
unsafe { defGetBoolean(match_by_position_option.as_ptr()) }
}
}

pub(crate) fn copy_stmt_get_option(
p_stmt: &PgBox<PlannedStmt>,
option_name: &str,
Expand Down
30 changes: 29 additions & 1 deletion src/pgrx_tests/copy_from_coerce.rs
Original file line number Diff line number Diff line change
Expand Up @@ -966,7 +966,7 @@ mod tests {
}

#[pg_test]
fn test_table_with_different_field_position() {
fn test_table_with_different_position_match_by_name() {
let copy_to = format!(
"COPY (SELECT 1 as x, 'hello' as y) TO '{}'",
LOCAL_TEST_FILE_PATH
Expand All @@ -983,6 +983,34 @@ mod tests {
assert_eq!(result, (Some("hello"), Some(1)));
}

#[pg_test]
fn test_table_with_different_name_match_by_position() {
let copy_to = "COPY (SELECT 1 as a, 'hello' as b) TO '/tmp/test.parquet'";
Spi::run(copy_to).unwrap();

let create_table = "CREATE TABLE test_table (x bigint, y varchar)";
Spi::run(create_table).unwrap();

let copy_from = "COPY test_table FROM '/tmp/test.parquet' WITH (match_by_position true)";
Spi::run(copy_from).unwrap();

let result = Spi::get_two::<i64, &str>("SELECT x, y FROM test_table LIMIT 1").unwrap();
assert_eq!(result, (Some(1), Some("hello")));
}

#[pg_test]
#[should_panic(expected = "column count mismatch between table and parquet file")]
fn test_table_with_different_name_match_by_position_fail() {
let copy_to = "COPY (SELECT 1 as a, 'hello' as b) TO '/tmp/test.parquet'";
Spi::run(copy_to).unwrap();

let create_table = "CREATE TABLE test_table (x bigint, y varchar, z int)";
Spi::run(create_table).unwrap();

let copy_from = "COPY test_table FROM '/tmp/test.parquet' WITH (match_by_position true)";
Spi::run(copy_from).unwrap();
}

#[pg_test]
#[should_panic(expected = "column \"name\" is not found in parquet file")]
fn test_missing_column_in_parquet() {
Expand Down

0 comments on commit ed4907a

Please sign in to comment.