From ef6f07d7167afb497d21686151416d31c4834f19 Mon Sep 17 00:00:00 2001 From: Aykut Bozkurt Date: Mon, 11 Nov 2024 14:50:37 +0300 Subject: [PATCH] Coerce types on read `COPY FROM parquet` is too strict when matching the Postgres tupledesc schema to the schema from the parquet file. e.g. an `INT32` type in the parquet schema cannot be read into a Postgres column with `int64` type. We can avoid this situation by adding an `is_coercible(from_type, to_type)` check while matching the expected schema from the parquet file. With that we can coerce as shown below from parquet source types to Postgres destination types: - INT16 => {int32, int64} - INT32 => {int64} - UINT16 => {int16, int32, int64} - UINT32 => {int32, int64} - UINT64 => {int64} - FLOAT32 => {double} As we use arrow as an intermediate format, it might be the case that `LargeUtf8` or `LargeBinary` types are used by the external writer instead of `Utf8` and `Binary`. That is why we also need to support the coercions below for arrow source types: - `Utf8 | LargeUtf8` => {text} - `Binary | LargeBinary` => {bytea} Closes #67. 
--- src/arrow_parquet/arrow_to_pg.rs | 475 ++++++++++++------ src/arrow_parquet/arrow_to_pg/bytea.rs | 27 +- src/arrow_parquet/arrow_to_pg/char.rs | 28 +- .../arrow_to_pg/fallback_to_text.rs | 28 +- src/arrow_parquet/arrow_to_pg/float4.rs | 21 + src/arrow_parquet/arrow_to_pg/geometry.rs | 27 +- src/arrow_parquet/arrow_to_pg/int2.rs | 107 +++- src/arrow_parquet/arrow_to_pg/int4.rs | 65 ++- src/arrow_parquet/arrow_to_pg/int8.rs | 23 +- src/arrow_parquet/arrow_to_pg/text.rs | 24 +- src/arrow_parquet/schema_parser.rs | 50 +- src/lib.rs | 220 ++++++++ 12 files changed, 922 insertions(+), 173 deletions(-) diff --git a/src/arrow_parquet/arrow_to_pg.rs b/src/arrow_parquet/arrow_to_pg.rs index ec7c9ce..43592df 100644 --- a/src/arrow_parquet/arrow_to_pg.rs +++ b/src/arrow_parquet/arrow_to_pg.rs @@ -1,14 +1,17 @@ +use std::ops::Deref; + use arrow::array::{ Array, ArrayData, BinaryArray, BooleanArray, Date32Array, Decimal128Array, Float32Array, - Float64Array, Int16Array, Int32Array, Int64Array, ListArray, MapArray, StringArray, - StructArray, Time64MicrosecondArray, TimestampMicrosecondArray, UInt32Array, + Float64Array, Int16Array, Int32Array, Int64Array, LargeBinaryArray, LargeStringArray, + ListArray, MapArray, StringArray, StructArray, Time64MicrosecondArray, + TimestampMicrosecondArray, UInt16Array, UInt32Array, UInt64Array, }; -use arrow_schema::Fields; +use arrow_schema::{DataType, FieldRef, Fields, TimeUnit}; use pgrx::{ datum::{Date, Time, TimeWithTimeZone, Timestamp, TimestampWithTimeZone}, pg_sys::{ - Datum, Oid, BOOLOID, BYTEAOID, CHAROID, DATEOID, FLOAT4OID, FLOAT8OID, INT2OID, INT4OID, - INT8OID, NUMERICOID, OIDOID, TEXTOID, TIMEOID, TIMESTAMPOID, TIMESTAMPTZOID, TIMETZOID, + Datum, Oid, CHAROID, FLOAT4OID, FLOAT8OID, INT2OID, INT4OID, INT8OID, NUMERICOID, OIDOID, + TEXTOID, TIMEOID, }, prelude::PgHeapTuple, AllocatedByRust, AnyNumeric, IntoDatum, PgTupleDesc, @@ -23,9 +26,7 @@ use crate::{ fallback_to_text::{reset_fallback_to_text_context, FallbackToText}, 
geometry::{is_postgis_geometry_type, Geometry}, map::{is_map_type, Map}, - pg_arrow_type_conversions::{ - extract_precision_and_scale_from_numeric_typmod, should_write_numeric_as_text, - }, + pg_arrow_type_conversions::extract_precision_and_scale_from_numeric_typmod, }, }; @@ -57,12 +58,10 @@ pub(crate) trait ArrowArrayToPgType: From { #[derive(Clone)] pub(crate) struct ArrowToPgAttributeContext { name: String, + field: FieldRef, typoid: Oid, typmod: i32, - is_array: bool, - is_composite: bool, is_geometry: bool, - is_map: bool, attribute_contexts: Option>, attribute_tupledesc: Option>, precision: Option, @@ -157,12 +156,10 @@ impl ArrowToPgAttributeContext { Self { name: name.to_string(), + field, typoid: attribute_typoid, typmod, - is_array, - is_composite, is_geometry, - is_map, attribute_contexts, attribute_tupledesc, scale, @@ -206,7 +203,7 @@ pub(crate) fn to_pg_datum( attribute_array: ArrayData, attribute_context: &ArrowToPgAttributeContext, ) -> Option { - if attribute_context.is_array { + if matches!(attribute_array.data_type(), DataType::List(_)) { to_pg_array_datum(attribute_array, attribute_context) } else { to_pg_nonarray_datum(attribute_array, attribute_context) @@ -227,43 +224,71 @@ fn to_pg_nonarray_datum( primitive_array: ArrayData, attribute_context: &ArrowToPgAttributeContext, ) -> Option { - match attribute_context.typoid { - FLOAT4OID => { - to_pg_datum!(Float32Array, f32, primitive_array, attribute_context) + match attribute_context.field.data_type() { + DataType::Float32 => { + if attribute_context.typoid == FLOAT4OID { + to_pg_datum!(Float32Array, f32, primitive_array, attribute_context) + } else { + debug_assert!(attribute_context.typoid == FLOAT8OID); + to_pg_datum!(Float32Array, f64, primitive_array, attribute_context) + } } - FLOAT8OID => { + DataType::Float64 => { to_pg_datum!(Float64Array, f64, primitive_array, attribute_context) } - INT2OID => { - to_pg_datum!(Int16Array, i16, primitive_array, attribute_context) - } - INT4OID => { - 
to_pg_datum!(Int32Array, i32, primitive_array, attribute_context) + DataType::Int16 => { + if attribute_context.typoid == INT2OID { + to_pg_datum!(Int16Array, i16, primitive_array, attribute_context) + } else if attribute_context.typoid == INT4OID { + to_pg_datum!(Int16Array, i32, primitive_array, attribute_context) + } else { + debug_assert!(attribute_context.typoid == INT8OID); + to_pg_datum!(Int16Array, i64, primitive_array, attribute_context) + } } - INT8OID => { - to_pg_datum!(Int64Array, i64, primitive_array, attribute_context) + DataType::UInt16 => { + if attribute_context.typoid == INT2OID { + to_pg_datum!(UInt16Array, i16, primitive_array, attribute_context) + } else if attribute_context.typoid == INT4OID { + to_pg_datum!(UInt16Array, i32, primitive_array, attribute_context) + } else { + debug_assert!(attribute_context.typoid == INT8OID); + to_pg_datum!(UInt16Array, i64, primitive_array, attribute_context) + } } - BOOLOID => { - to_pg_datum!(BooleanArray, bool, primitive_array, attribute_context) + DataType::Int32 => { + if attribute_context.typoid == INT4OID { + to_pg_datum!(Int32Array, i32, primitive_array, attribute_context) + } else { + debug_assert!(attribute_context.typoid == INT8OID); + to_pg_datum!(Int32Array, i64, primitive_array, attribute_context) + } } - CHAROID => { - to_pg_datum!(StringArray, i8, primitive_array, attribute_context) + DataType::UInt32 => { + if attribute_context.typoid == OIDOID { + to_pg_datum!(UInt32Array, Oid, primitive_array, attribute_context) + } else if attribute_context.typoid == INT4OID { + to_pg_datum!(UInt32Array, i32, primitive_array, attribute_context) + } else { + debug_assert!(attribute_context.typoid == INT8OID); + to_pg_datum!(UInt32Array, i64, primitive_array, attribute_context) + } } - TEXTOID => { - to_pg_datum!(StringArray, String, primitive_array, attribute_context) + DataType::Int64 => { + to_pg_datum!(Int64Array, i64, primitive_array, attribute_context) } - BYTEAOID => { - to_pg_datum!(BinaryArray, Vec, 
primitive_array, attribute_context) + DataType::UInt64 => { + to_pg_datum!(UInt64Array, i64, primitive_array, attribute_context) } - OIDOID => { - to_pg_datum!(UInt32Array, Oid, primitive_array, attribute_context) + DataType::Boolean => { + to_pg_datum!(BooleanArray, bool, primitive_array, attribute_context) } - NUMERICOID => { - let precision = attribute_context - .precision - .expect("missing precision in context"); - - if should_write_numeric_as_text(precision) { + DataType::Utf8 => { + if attribute_context.typoid == CHAROID { + to_pg_datum!(StringArray, i8, primitive_array, attribute_context) + } else if attribute_context.typoid == TEXTOID { + to_pg_datum!(StringArray, String, primitive_array, attribute_context) + } else { reset_fallback_to_text_context(attribute_context.typoid, attribute_context.typmod); to_pg_datum!( @@ -272,72 +297,110 @@ fn to_pg_nonarray_datum( primitive_array, attribute_context ) + } + } + DataType::LargeUtf8 => { + if attribute_context.typoid == CHAROID { + to_pg_datum!(LargeStringArray, i8, primitive_array, attribute_context) + } else if attribute_context.typoid == TEXTOID { + to_pg_datum!(LargeStringArray, String, primitive_array, attribute_context) } else { + reset_fallback_to_text_context(attribute_context.typoid, attribute_context.typmod); + to_pg_datum!( - Decimal128Array, - AnyNumeric, + LargeStringArray, + FallbackToText, primitive_array, attribute_context ) } } - DATEOID => { - to_pg_datum!(Date32Array, Date, primitive_array, attribute_context) + DataType::Binary => { + if attribute_context.is_geometry { + to_pg_datum!(BinaryArray, Geometry, primitive_array, attribute_context) + } else { + to_pg_datum!(BinaryArray, Vec, primitive_array, attribute_context) + } + } + DataType::LargeBinary => { + if attribute_context.is_geometry { + to_pg_datum!( + LargeBinaryArray, + Geometry, + primitive_array, + attribute_context + ) + } else { + to_pg_datum!( + LargeBinaryArray, + Vec, + primitive_array, + attribute_context + ) + } } - TIMEOID 
=> { + DataType::Decimal128(_, _) => { to_pg_datum!( - Time64MicrosecondArray, - Time, + Decimal128Array, + AnyNumeric, primitive_array, attribute_context ) } - TIMETZOID => { + DataType::Date32 => { + to_pg_datum!(Date32Array, Date, primitive_array, attribute_context) + } + DataType::Time64(TimeUnit::Microsecond) => { + if attribute_context.typoid == TIMEOID { + to_pg_datum!( + Time64MicrosecondArray, + Time, + primitive_array, + attribute_context + ) + } else { + to_pg_datum!( + Time64MicrosecondArray, + TimeWithTimeZone, + primitive_array, + attribute_context + ) + } + } + DataType::Timestamp(TimeUnit::Microsecond, None) => { to_pg_datum!( - Time64MicrosecondArray, - TimeWithTimeZone, + TimestampMicrosecondArray, + Timestamp, primitive_array, attribute_context ) } - TIMESTAMPOID => { + DataType::Timestamp(TimeUnit::Microsecond, Some(timezone_str)) + if timezone_str.deref() == "+00:00" => + { to_pg_datum!( TimestampMicrosecondArray, - Timestamp, + TimestampWithTimeZone, primitive_array, attribute_context ) } - TIMESTAMPTZOID => { + DataType::Struct(_) => { to_pg_datum!( - TimestampMicrosecondArray, - TimestampWithTimeZone, + StructArray, + PgHeapTuple, primitive_array, attribute_context ) } + DataType::Map(_, _) => { + to_pg_datum!(MapArray, Map, primitive_array, attribute_context) + } _ => { - if attribute_context.is_composite { - to_pg_datum!( - StructArray, - PgHeapTuple, - primitive_array, - attribute_context - ) - } else if attribute_context.is_map { - to_pg_datum!(MapArray, Map, primitive_array, attribute_context) - } else if attribute_context.is_geometry { - to_pg_datum!(BinaryArray, Geometry, primitive_array, attribute_context) - } else { - reset_fallback_to_text_context(attribute_context.typoid, attribute_context.typmod); - - to_pg_datum!( - StringArray, - FallbackToText, - primitive_array, - attribute_context - ) - } + panic!( + "unsupported data type: {:?}", + attribute_context.field.data_type() + ); } } } @@ -354,16 +417,31 @@ fn to_pg_array_datum( 
let list_array = list_array.value(0).to_data(); - match attribute_context.typoid { - FLOAT4OID => { - to_pg_datum!( - Float32Array, - Vec>, - list_array, - attribute_context - ) + let element_field = match attribute_context.field.data_type() { + DataType::List(field) => field, + _ => unreachable!(), + }; + + match element_field.data_type() { + DataType::Float32 => { + if attribute_context.typoid == FLOAT4OID { + to_pg_datum!( + Float32Array, + Vec>, + list_array, + attribute_context + ) + } else { + debug_assert!(attribute_context.typoid == FLOAT8OID); + to_pg_datum!( + Float32Array, + Vec>, + list_array, + attribute_context + ) + } } - FLOAT8OID => { + DataType::Float64 => { to_pg_datum!( Float64Array, Vec>, @@ -371,51 +449,69 @@ fn to_pg_array_datum( attribute_context ) } - INT2OID => { - to_pg_datum!(Int16Array, Vec>, list_array, attribute_context) + DataType::Int16 => { + if attribute_context.typoid == INT2OID { + to_pg_datum!(Int16Array, Vec>, list_array, attribute_context) + } else if attribute_context.typoid == INT4OID { + to_pg_datum!(Int16Array, Vec>, list_array, attribute_context) + } else { + debug_assert!(attribute_context.typoid == INT8OID); + to_pg_datum!(Int16Array, Vec>, list_array, attribute_context) + } } - INT4OID => { - to_pg_datum!(Int32Array, Vec>, list_array, attribute_context) + DataType::UInt16 => { + if attribute_context.typoid == INT2OID { + to_pg_datum!(UInt16Array, Vec>, list_array, attribute_context) + } else if attribute_context.typoid == INT4OID { + to_pg_datum!(UInt16Array, Vec>, list_array, attribute_context) + } else { + debug_assert!(attribute_context.typoid == INT8OID); + to_pg_datum!(UInt16Array, Vec>, list_array, attribute_context) + } } - INT8OID => { - to_pg_datum!(Int64Array, Vec>, list_array, attribute_context) + DataType::Int32 => { + if attribute_context.typoid == INT4OID { + to_pg_datum!(Int32Array, Vec>, list_array, attribute_context) + } else { + debug_assert!(attribute_context.typoid == INT8OID); + 
to_pg_datum!(Int32Array, Vec>, list_array, attribute_context) + } } - BOOLOID => { - to_pg_datum!( - BooleanArray, - Vec>, - list_array, - attribute_context - ) + DataType::UInt32 => { + if attribute_context.typoid == OIDOID { + to_pg_datum!(UInt32Array, Vec>, list_array, attribute_context) + } else if attribute_context.typoid == INT4OID { + to_pg_datum!(UInt32Array, Vec>, list_array, attribute_context) + } else { + debug_assert!(attribute_context.typoid == INT8OID); + to_pg_datum!(UInt32Array, Vec>, list_array, attribute_context) + } } - CHAROID => { - to_pg_datum!(StringArray, Vec>, list_array, attribute_context) + DataType::Int64 => { + to_pg_datum!(Int64Array, Vec>, list_array, attribute_context) } - TEXTOID => { - to_pg_datum!( - StringArray, - Vec>, - list_array, - attribute_context - ) + DataType::UInt64 => { + to_pg_datum!(UInt64Array, Vec>, list_array, attribute_context) } - BYTEAOID => { + DataType::Boolean => { to_pg_datum!( - BinaryArray, - Vec>>, + BooleanArray, + Vec>, list_array, attribute_context ) } - OIDOID => { - to_pg_datum!(UInt32Array, Vec>, list_array, attribute_context) - } - NUMERICOID => { - let precision = attribute_context - .precision - .expect("missing precision in context"); - - if should_write_numeric_as_text(precision) { + DataType::Utf8 => { + if attribute_context.typoid == CHAROID { + to_pg_datum!(StringArray, Vec>, list_array, attribute_context) + } else if attribute_context.typoid == TEXTOID { + to_pg_datum!( + StringArray, + Vec>, + list_array, + attribute_context + ) + } else { reset_fallback_to_text_context(attribute_context.typoid, attribute_context.typmod); to_pg_datum!( @@ -424,82 +520,135 @@ fn to_pg_array_datum( list_array, attribute_context ) + } + } + DataType::LargeUtf8 => { + if attribute_context.typoid == CHAROID { + to_pg_datum!( + LargeStringArray, + Vec>, + list_array, + attribute_context + ) + } else if attribute_context.typoid == TEXTOID { + to_pg_datum!( + LargeStringArray, + Vec>, + list_array, + 
attribute_context + ) + } else { + reset_fallback_to_text_context(attribute_context.typoid, attribute_context.typmod); + + to_pg_datum!( + LargeStringArray, + Vec>, + list_array, + attribute_context + ) + } + } + DataType::Binary => { + if attribute_context.is_geometry { + to_pg_datum!( + BinaryArray, + Vec>, + list_array, + attribute_context + ) } else { to_pg_datum!( - Decimal128Array, - Vec>, + BinaryArray, + Vec>>, list_array, attribute_context ) } } - DATEOID => { + DataType::LargeBinary => { + if attribute_context.is_geometry { + to_pg_datum!( + LargeBinaryArray, + Vec>, + list_array, + attribute_context + ) + } else { + to_pg_datum!( + LargeBinaryArray, + Vec>>, + list_array, + attribute_context + ) + } + } + DataType::Decimal128(_, _) => { to_pg_datum!( - Date32Array, - Vec>, + Decimal128Array, + Vec>, list_array, attribute_context ) } - TIMEOID => { + DataType::Date32 => { to_pg_datum!( - Time64MicrosecondArray, - Vec>, + Date32Array, + Vec>, list_array, attribute_context ) } - TIMETZOID => { + DataType::Time64(TimeUnit::Microsecond) => { + if attribute_context.typoid == TIMEOID { + to_pg_datum!( + Time64MicrosecondArray, + Vec>, + list_array, + attribute_context + ) + } else { + to_pg_datum!( + Time64MicrosecondArray, + Vec>, + list_array, + attribute_context + ) + } + } + DataType::Timestamp(TimeUnit::Microsecond, None) => { to_pg_datum!( - Time64MicrosecondArray, - Vec>, + TimestampMicrosecondArray, + Vec>, list_array, attribute_context ) } - TIMESTAMPOID => { + DataType::Timestamp(TimeUnit::Microsecond, Some(timezone_str)) + if timezone_str.deref() == "+00:00" => + { to_pg_datum!( TimestampMicrosecondArray, - Vec>, + Vec>, list_array, attribute_context ) } - TIMESTAMPTZOID => { + DataType::Struct(_) => { to_pg_datum!( - TimestampMicrosecondArray, - Vec>, + StructArray, + Vec>>, list_array, attribute_context ) } + DataType::Map(_, _) => { + to_pg_datum!(MapArray, Vec>, list_array, attribute_context) + } _ => { - if attribute_context.is_composite { - 
to_pg_datum!( - StructArray, - Vec>>, - list_array, - attribute_context - ) - } else if attribute_context.is_map { - to_pg_datum!(MapArray, Vec>, list_array, attribute_context) - } else if attribute_context.is_geometry { - to_pg_datum!( - BinaryArray, - Vec>, - list_array, - attribute_context - ) - } else { - reset_fallback_to_text_context(attribute_context.typoid, attribute_context.typmod); - - to_pg_datum!( - StringArray, - Vec>, - list_array, - attribute_context - ) - } + panic!( + "unsupported data type: {:?}", + attribute_context.field.data_type() + ); } } } diff --git a/src/arrow_parquet/arrow_to_pg/bytea.rs b/src/arrow_parquet/arrow_to_pg/bytea.rs index fc67c2c..d17262b 100644 --- a/src/arrow_parquet/arrow_to_pg/bytea.rs +++ b/src/arrow_parquet/arrow_to_pg/bytea.rs @@ -1,4 +1,4 @@ -use arrow::array::{Array, BinaryArray}; +use arrow::array::{Array, BinaryArray, LargeBinaryArray}; use super::{ArrowArrayToPgType, ArrowToPgAttributeContext}; @@ -13,6 +13,16 @@ impl ArrowArrayToPgType> for BinaryArray { } } +impl ArrowArrayToPgType> for LargeBinaryArray { + fn to_pg_type(self, _context: &ArrowToPgAttributeContext) -> Option> { + if self.is_null(0) { + None + } else { + Some(self.value(0).to_vec()) + } + } +} + // Bytea[] impl ArrowArrayToPgType>>> for BinaryArray { fn to_pg_type(self, _context: &ArrowToPgAttributeContext) -> Option>>> { @@ -28,3 +38,18 @@ impl ArrowArrayToPgType>>> for BinaryArray { Some(vals) } } + +impl ArrowArrayToPgType>>> for LargeBinaryArray { + fn to_pg_type(self, _context: &ArrowToPgAttributeContext) -> Option>>> { + let mut vals = vec![]; + for val in self.iter() { + if let Some(val) = val { + vals.push(Some(val.to_vec())); + } else { + vals.push(None); + } + } + + Some(vals) + } +} diff --git a/src/arrow_parquet/arrow_to_pg/char.rs b/src/arrow_parquet/arrow_to_pg/char.rs index 2a23187..7b55d23 100644 --- a/src/arrow_parquet/arrow_to_pg/char.rs +++ b/src/arrow_parquet/arrow_to_pg/char.rs @@ -1,4 +1,4 @@ -use arrow::array::{Array, 
StringArray}; +use arrow::array::{Array, LargeStringArray, StringArray}; use super::{ArrowArrayToPgType, ArrowToPgAttributeContext}; @@ -15,6 +15,18 @@ impl ArrowArrayToPgType for StringArray { } } +impl ArrowArrayToPgType for LargeStringArray { + fn to_pg_type(self, _context: &ArrowToPgAttributeContext) -> Option { + if self.is_null(0) { + None + } else { + let val = self.value(0); + let val: i8 = val.chars().next().expect("unexpected ascii char") as i8; + Some(val) + } + } +} + // Char[] impl ArrowArrayToPgType>> for StringArray { fn to_pg_type(self, _context: &ArrowToPgAttributeContext) -> Option>> { @@ -29,3 +41,17 @@ impl ArrowArrayToPgType>> for StringArray { Some(vals) } } + +impl ArrowArrayToPgType>> for LargeStringArray { + fn to_pg_type(self, _context: &ArrowToPgAttributeContext) -> Option>> { + let mut vals = vec![]; + for val in self.iter() { + let val = val.map(|val| { + let val: i8 = val.chars().next().expect("unexpected ascii char") as i8; + val + }); + vals.push(val); + } + Some(vals) + } +} diff --git a/src/arrow_parquet/arrow_to_pg/fallback_to_text.rs b/src/arrow_parquet/arrow_to_pg/fallback_to_text.rs index 5144787..a07bd08 100644 --- a/src/arrow_parquet/arrow_to_pg/fallback_to_text.rs +++ b/src/arrow_parquet/arrow_to_pg/fallback_to_text.rs @@ -1,4 +1,4 @@ -use arrow::array::{Array, StringArray}; +use arrow::array::{Array, LargeStringArray, StringArray}; use crate::type_compat::fallback_to_text::FallbackToText; @@ -17,6 +17,18 @@ impl ArrowArrayToPgType for StringArray { } } +impl ArrowArrayToPgType for LargeStringArray { + fn to_pg_type(self, _context: &ArrowToPgAttributeContext) -> Option { + if self.is_null(0) { + None + } else { + let text_repr = self.value(0).to_string(); + let val = FallbackToText(text_repr); + Some(val) + } + } +} + // Text[] representation of any type impl ArrowArrayToPgType>> for StringArray { fn to_pg_type( @@ -31,3 +43,17 @@ impl ArrowArrayToPgType>> for StringArray { Some(vals) } } + +impl ArrowArrayToPgType>> for 
LargeStringArray { + fn to_pg_type( + self, + _context: &ArrowToPgAttributeContext, + ) -> Option>> { + let mut vals = vec![]; + for val in self.iter() { + let val = val.map(|val| FallbackToText(val.to_string())); + vals.push(val); + } + Some(vals) + } +} diff --git a/src/arrow_parquet/arrow_to_pg/float4.rs b/src/arrow_parquet/arrow_to_pg/float4.rs index 48f36e2..19ffb6a 100644 --- a/src/arrow_parquet/arrow_to_pg/float4.rs +++ b/src/arrow_parquet/arrow_to_pg/float4.rs @@ -14,6 +14,17 @@ impl ArrowArrayToPgType for Float32Array { } } +impl ArrowArrayToPgType for Float32Array { + fn to_pg_type(self, _context: &ArrowToPgAttributeContext) -> Option { + if self.is_null(0) { + None + } else { + let val = self.value(0) as _; + Some(val) + } + } +} + // Float4[] impl ArrowArrayToPgType>> for Float32Array { fn to_pg_type(self, _context: &ArrowToPgAttributeContext) -> Option>> { @@ -24,3 +35,13 @@ impl ArrowArrayToPgType>> for Float32Array { Some(vals) } } + +impl ArrowArrayToPgType>> for Float32Array { + fn to_pg_type(self, _context: &ArrowToPgAttributeContext) -> Option>> { + let mut vals = vec![]; + for val in self.iter() { + vals.push(val.map(|val| val as _)); + } + Some(vals) + } +} diff --git a/src/arrow_parquet/arrow_to_pg/geometry.rs b/src/arrow_parquet/arrow_to_pg/geometry.rs index eea86af..6b8e3c8 100644 --- a/src/arrow_parquet/arrow_to_pg/geometry.rs +++ b/src/arrow_parquet/arrow_to_pg/geometry.rs @@ -1,4 +1,4 @@ -use arrow::array::{Array, BinaryArray}; +use arrow::array::{Array, BinaryArray, LargeBinaryArray}; use crate::type_compat::geometry::Geometry; @@ -15,6 +15,16 @@ impl ArrowArrayToPgType for BinaryArray { } } +impl ArrowArrayToPgType for LargeBinaryArray { + fn to_pg_type(self, _context: &ArrowToPgAttributeContext) -> Option { + if self.is_null(0) { + None + } else { + Some(self.value(0).to_vec().into()) + } + } +} + // Geometry[] impl ArrowArrayToPgType>> for BinaryArray { fn to_pg_type(self, _context: &ArrowToPgAttributeContext) -> Option>> { @@ -30,3 
+40,18 @@ impl ArrowArrayToPgType>> for BinaryArray { Some(vals) } } + +impl ArrowArrayToPgType>> for LargeBinaryArray { + fn to_pg_type(self, _context: &ArrowToPgAttributeContext) -> Option>> { + let mut vals = vec![]; + for val in self.iter() { + if let Some(val) = val { + vals.push(Some(val.to_vec().into())); + } else { + vals.push(None); + } + } + + Some(vals) + } +} diff --git a/src/arrow_parquet/arrow_to_pg/int2.rs b/src/arrow_parquet/arrow_to_pg/int2.rs index 6f814db..d1c4e73 100644 --- a/src/arrow_parquet/arrow_to_pg/int2.rs +++ b/src/arrow_parquet/arrow_to_pg/int2.rs @@ -1,4 +1,4 @@ -use arrow::array::{Array, Int16Array}; +use arrow::array::{Array, Int16Array, UInt16Array}; use super::{ArrowArrayToPgType, ArrowToPgAttributeContext}; @@ -14,6 +14,61 @@ impl ArrowArrayToPgType for Int16Array { } } +impl ArrowArrayToPgType for Int16Array { + fn to_pg_type(self, _context: &ArrowToPgAttributeContext) -> Option { + if self.is_null(0) { + None + } else { + let val = self.value(0) as _; + Some(val) + } + } +} + +impl ArrowArrayToPgType for Int16Array { + fn to_pg_type(self, _context: &ArrowToPgAttributeContext) -> Option { + if self.is_null(0) { + None + } else { + let val = self.value(0) as _; + Some(val) + } + } +} + +impl ArrowArrayToPgType for UInt16Array { + fn to_pg_type(self, _context: &ArrowToPgAttributeContext) -> Option { + if self.is_null(0) { + None + } else { + let val = self.value(0) as _; + Some(val) + } + } +} + +impl ArrowArrayToPgType for UInt16Array { + fn to_pg_type(self, _context: &ArrowToPgAttributeContext) -> Option { + if self.is_null(0) { + None + } else { + let val = self.value(0) as _; + Some(val) + } + } +} + +impl ArrowArrayToPgType for UInt16Array { + fn to_pg_type(self, _context: &ArrowToPgAttributeContext) -> Option { + if self.is_null(0) { + None + } else { + let val = self.value(0) as _; + Some(val) + } + } +} + // Int2[] impl ArrowArrayToPgType>> for Int16Array { fn to_pg_type(self, _context: &ArrowToPgAttributeContext) -> 
Option>> { @@ -24,3 +79,53 @@ impl ArrowArrayToPgType>> for Int16Array { Some(vals) } } + +impl ArrowArrayToPgType>> for Int16Array { + fn to_pg_type(self, _context: &ArrowToPgAttributeContext) -> Option>> { + let mut vals = vec![]; + for val in self.iter() { + vals.push(val.map(|val| val as _)); + } + Some(vals) + } +} + +impl ArrowArrayToPgType>> for Int16Array { + fn to_pg_type(self, _context: &ArrowToPgAttributeContext) -> Option>> { + let mut vals = vec![]; + for val in self.iter() { + vals.push(val.map(|val| val as _)); + } + Some(vals) + } +} + +impl ArrowArrayToPgType>> for UInt16Array { + fn to_pg_type(self, _context: &ArrowToPgAttributeContext) -> Option>> { + let mut vals = vec![]; + for val in self.iter() { + vals.push(val.map(|val| val as _)); + } + Some(vals) + } +} + +impl ArrowArrayToPgType>> for UInt16Array { + fn to_pg_type(self, _context: &ArrowToPgAttributeContext) -> Option>> { + let mut vals = vec![]; + for val in self.iter() { + vals.push(val.map(|val| val as _)); + } + Some(vals) + } +} + +impl ArrowArrayToPgType>> for UInt16Array { + fn to_pg_type(self, _context: &ArrowToPgAttributeContext) -> Option>> { + let mut vals = vec![]; + for val in self.iter() { + vals.push(val.map(|val| val as _)); + } + Some(vals) + } +} diff --git a/src/arrow_parquet/arrow_to_pg/int4.rs b/src/arrow_parquet/arrow_to_pg/int4.rs index 87a06e4..ecae4d4 100644 --- a/src/arrow_parquet/arrow_to_pg/int4.rs +++ b/src/arrow_parquet/arrow_to_pg/int4.rs @@ -1,4 +1,4 @@ -use arrow::array::{Array, Int32Array}; +use arrow::array::{Array, Int32Array, UInt32Array}; use super::{ArrowArrayToPgType, ArrowToPgAttributeContext}; @@ -14,6 +14,39 @@ impl ArrowArrayToPgType for Int32Array { } } +impl ArrowArrayToPgType for Int32Array { + fn to_pg_type(self, _context: &ArrowToPgAttributeContext) -> Option { + if self.is_null(0) { + None + } else { + let val = self.value(0) as _; + Some(val) + } + } +} + +impl ArrowArrayToPgType for UInt32Array { + fn to_pg_type(self, _context: 
&ArrowToPgAttributeContext) -> Option { + if self.is_null(0) { + None + } else { + let val = self.value(0) as _; + Some(val) + } + } +} + +impl ArrowArrayToPgType for UInt32Array { + fn to_pg_type(self, _context: &ArrowToPgAttributeContext) -> Option { + if self.is_null(0) { + None + } else { + let val = self.value(0) as _; + Some(val) + } + } +} + // Int4[] impl ArrowArrayToPgType>> for Int32Array { fn to_pg_type(self, _context: &ArrowToPgAttributeContext) -> Option>> { @@ -24,3 +57,33 @@ impl ArrowArrayToPgType>> for Int32Array { Some(vals) } } + +impl ArrowArrayToPgType>> for Int32Array { + fn to_pg_type(self, _context: &ArrowToPgAttributeContext) -> Option>> { + let mut vals = vec![]; + for val in self.iter() { + vals.push(val.map(|val| val as _)); + } + Some(vals) + } +} + +impl ArrowArrayToPgType>> for UInt32Array { + fn to_pg_type(self, _context: &ArrowToPgAttributeContext) -> Option>> { + let mut vals = vec![]; + for val in self.iter() { + vals.push(val.map(|val| val as _)); + } + Some(vals) + } +} + +impl ArrowArrayToPgType>> for UInt32Array { + fn to_pg_type(self, _context: &ArrowToPgAttributeContext) -> Option>> { + let mut vals = vec![]; + for val in self.iter() { + vals.push(val.map(|val| val as _)); + } + Some(vals) + } +} diff --git a/src/arrow_parquet/arrow_to_pg/int8.rs b/src/arrow_parquet/arrow_to_pg/int8.rs index 151b99e..978f70b 100644 --- a/src/arrow_parquet/arrow_to_pg/int8.rs +++ b/src/arrow_parquet/arrow_to_pg/int8.rs @@ -1,4 +1,4 @@ -use arrow::array::{Array, Int64Array}; +use arrow::array::{Array, Int64Array, UInt64Array}; use super::{ArrowArrayToPgType, ArrowToPgAttributeContext}; @@ -14,6 +14,17 @@ impl ArrowArrayToPgType for Int64Array { } } +impl ArrowArrayToPgType for UInt64Array { + fn to_pg_type(self, _context: &ArrowToPgAttributeContext) -> Option { + if self.is_null(0) { + None + } else { + let val = self.value(0) as _; + Some(val) + } + } +} + // Int8[] impl ArrowArrayToPgType>> for Int64Array { fn to_pg_type(self, _context: 
&ArrowToPgAttributeContext) -> Option>> { @@ -24,3 +35,13 @@ impl ArrowArrayToPgType>> for Int64Array { Some(vals) } } + +impl ArrowArrayToPgType>> for UInt64Array { + fn to_pg_type(self, _context: &ArrowToPgAttributeContext) -> Option>> { + let mut vals = vec![]; + for val in self.iter() { + vals.push(val.map(|v| v as _)); + } + Some(vals) + } +} diff --git a/src/arrow_parquet/arrow_to_pg/text.rs b/src/arrow_parquet/arrow_to_pg/text.rs index ba784e0..b4190a1 100644 --- a/src/arrow_parquet/arrow_to_pg/text.rs +++ b/src/arrow_parquet/arrow_to_pg/text.rs @@ -1,4 +1,4 @@ -use arrow::array::{Array, StringArray}; +use arrow::array::{Array, LargeStringArray, StringArray}; use super::{ArrowArrayToPgType, ArrowToPgAttributeContext}; @@ -14,6 +14,17 @@ impl ArrowArrayToPgType for StringArray { } } +impl ArrowArrayToPgType for LargeStringArray { + fn to_pg_type(self, _context: &ArrowToPgAttributeContext) -> Option { + if self.is_null(0) { + None + } else { + let val = self.value(0); + Some(val.to_string()) + } + } +} + // Text[] impl ArrowArrayToPgType>> for StringArray { fn to_pg_type(self, _context: &ArrowToPgAttributeContext) -> Option>> { @@ -25,3 +36,14 @@ impl ArrowArrayToPgType>> for StringArray { Some(vals) } } + +impl ArrowArrayToPgType>> for LargeStringArray { + fn to_pg_type(self, _context: &ArrowToPgAttributeContext) -> Option>> { + let mut vals = vec![]; + for val in self.iter() { + let val = val.map(|val| val.to_string()); + vals.push(val); + } + Some(vals) + } +} diff --git a/src/arrow_parquet/schema_parser.rs b/src/arrow_parquet/schema_parser.rs index 8dd79cf..9d1fc83 100644 --- a/src/arrow_parquet/schema_parser.rs +++ b/src/arrow_parquet/schema_parser.rs @@ -1,7 +1,7 @@ use std::{collections::HashMap, ops::Deref, sync::Arc}; use arrow::datatypes::{Field, Fields, Schema}; -use arrow_schema::FieldRef; +use arrow_schema::{DataType, FieldRef}; use parquet::arrow::{arrow_to_parquet_schema, PARQUET_FIELD_ID_META_KEY}; use pg_sys::{ Oid, BOOLOID, BYTEAOID, CHAROID, 
DATEOID, FLOAT4OID, FLOAT8OID, INT2OID, INT4OID, INT8OID, @@ -315,6 +315,50 @@ fn adjust_map_entries_field(field: FieldRef) -> FieldRef { Arc::new(entries_field) } +fn is_coercible_types(from_type: &DataType, to_type: &DataType) -> bool { + if let (DataType::List(from_elem_field), DataType::List(to_elem_field)) = (from_type, to_type) { + return is_coercible_types(from_elem_field.data_type(), to_elem_field.data_type()); + } else if let (DataType::Struct(from_fields), DataType::Struct(to_fields)) = + (from_type, to_type) + { + if from_fields.len() != to_fields.len() { + return false; + } + + for (from_field, to_field) in from_fields.iter().zip(to_fields.iter()) { + if from_field.name() != to_field.name() { + return false; + } + + if !is_coercible_types(from_field.data_type(), to_field.data_type()) { + return false; + } + } + + return true; + } else if let (DataType::Map(from_entries_field, _), DataType::Map(to_entries_field, _)) = + (from_type, to_type) + { + return is_coercible_types(from_entries_field.data_type(), to_entries_field.data_type()); + } + + matches!( + (from_type, to_type), + (DataType::Float32, DataType::Float64) + | (DataType::Int16, DataType::Int32) + | (DataType::Int16, DataType::Int64) + | (DataType::Int32, DataType::Int64) + | (DataType::UInt16, DataType::Int16) + | (DataType::UInt16, DataType::Int32) + | (DataType::UInt16, DataType::Int64) + | (DataType::UInt32, DataType::Int32) + | (DataType::UInt32, DataType::Int64) + | (DataType::UInt64, DataType::Int64) + | (DataType::LargeUtf8, DataType::Utf8) + | (DataType::LargeBinary, DataType::Binary) + ) +} + pub(crate) fn ensure_arrow_schema_match_tupledesc( file_schema: Arc, tupledesc: &PgTupleDesc, @@ -330,7 +374,9 @@ pub(crate) fn ensure_arrow_schema_match_tupledesc( if let Some(file_schema_field) = file_schema_field { let file_schema_field_type = file_schema_field.1.data_type(); - if file_schema_field_type != table_schema_field_type { + if file_schema_field_type != table_schema_field_type + && 
!is_coercible_types(file_schema_field_type, table_schema_field_type) + { panic!( "type mismatch for column \"{}\" between table and parquet file. table expected \"{}\" but file had \"{}\"", table_schema_field_name, diff --git a/src/lib.rs b/src/lib.rs index 57584bb..ab6cf99 100644 --- a/src/lib.rs +++ b/src/lib.rs @@ -37,8 +37,11 @@ pub extern "C" fn _PG_init() { #[cfg(any(test, feature = "pg_test"))] #[pg_schema] mod tests { + use std::fs::File; use std::io::Write; use std::marker::PhantomData; + use std::sync::Arc; + use std::vec; use std::{collections::HashMap, fmt::Debug}; use crate::arrow_parquet::compression::PgParquetCompression; @@ -48,6 +51,12 @@ mod tests { use crate::type_compat::pg_arrow_type_conversions::{ DEFAULT_UNBOUNDED_NUMERIC_PRECISION, DEFAULT_UNBOUNDED_NUMERIC_SCALE, }; + use arrow::array::{ + Float32Array, Int16Array, Int32Array, LargeBinaryArray, LargeStringArray, RecordBatch, + UInt16Array, UInt32Array, UInt64Array, + }; + use arrow_schema::{DataType, Field, Schema, SchemaRef}; + use parquet::arrow::ArrowWriter; use pgrx::pg_sys::Oid; use pgrx::{ composite_type, @@ -1391,6 +1400,217 @@ mod tests { Spi::run("DROP TYPE dog;").unwrap(); } + fn write_record_batch_to_parquet(schema: SchemaRef, record_batch: RecordBatch) { + let file = File::create("/tmp/test.parquet").unwrap(); + let mut writer = ArrowWriter::try_new(file, schema, None).unwrap(); + + writer.write(&record_batch).unwrap(); + writer.close().unwrap(); + } + + #[pg_test] + fn test_coerce_types() { + // INT16 => {int, bigint} + let schema = Arc::new(Schema::new(vec![ + Field::new("x", DataType::Int16, true), + Field::new("y", DataType::Int16, true), + ])); + + let x = Arc::new(Int16Array::from(vec![1])); + let y = Arc::new(Int16Array::from(vec![2])); + + let batch = RecordBatch::try_new(schema.clone(), vec![x, y]).unwrap(); + write_record_batch_to_parquet(schema, batch); + + let create_table = "CREATE TABLE test_table (x int, y bigint)"; + Spi::run(create_table).unwrap(); + + let 
copy_from = "COPY test_table FROM '/tmp/test.parquet'"; + Spi::run(copy_from).unwrap(); + + let value = Spi::get_two::("SELECT x, y FROM test_table LIMIT 1").unwrap(); + assert_eq!(value, (Some(1), Some(2))); + + let drop_table = "DROP TABLE test_table"; + Spi::run(drop_table).unwrap(); + + // INT32 => {bigint} + let schema = Arc::new(Schema::new(vec![Field::new("x", DataType::Int32, true)])); + + let x = Arc::new(Int32Array::from(vec![1])); + + let batch = RecordBatch::try_new(schema.clone(), vec![x]).unwrap(); + + write_record_batch_to_parquet(schema, batch); + + let create_table = "CREATE TABLE test_table (x bigint)"; + Spi::run(create_table).unwrap(); + + let copy_from = "COPY test_table FROM '/tmp/test.parquet'"; + Spi::run(copy_from).unwrap(); + + let value = Spi::get_one::("SELECT x FROM test_table LIMIT 1") + .unwrap() + .unwrap(); + assert_eq!(value, 1); + + let drop_table = "DROP TABLE test_table"; + Spi::run(drop_table).unwrap(); + + // FLOAT32 => {double} + let schema = Arc::new(Schema::new(vec![Field::new("x", DataType::Float32, true)])); + + let x = Arc::new(Float32Array::from(vec![1.123])); + + let batch = RecordBatch::try_new(schema.clone(), vec![x]).unwrap(); + + write_record_batch_to_parquet(schema, batch); + + let create_table = "CREATE TABLE test_table (x double precision)"; + Spi::run(create_table).unwrap(); + + let copy_from = "COPY test_table FROM '/tmp/test.parquet'"; + Spi::run(copy_from).unwrap(); + + let value = Spi::get_one::("SELECT x FROM test_table LIMIT 1") + .unwrap() + .unwrap(); + assert_eq!(value as f32, 1.123); + + let drop_table = "DROP TABLE test_table"; + Spi::run(drop_table).unwrap(); + + // UINT16 => {smallint, int, bigint} + let schema = Arc::new(Schema::new(vec![ + Field::new("x", DataType::UInt16, true), + Field::new("y", DataType::UInt16, true), + Field::new("z", DataType::UInt16, true), + ])); + + let x = Arc::new(UInt16Array::from(vec![1])); + let y = Arc::new(UInt16Array::from(vec![2])); + let z = 
Arc::new(UInt16Array::from(vec![3])); + + let batch = RecordBatch::try_new(schema.clone(), vec![x, y, z]).unwrap(); + + write_record_batch_to_parquet(schema, batch); + + let create_table = "CREATE TABLE test_table (x smallint, y int, z bigint)"; + Spi::run(create_table).unwrap(); + + let copy_from = "COPY test_table FROM '/tmp/test.parquet'"; + Spi::run(copy_from).unwrap(); + + let value = + Spi::get_three::("SELECT x, y, z FROM test_table LIMIT 1").unwrap(); + assert_eq!(value, (Some(1), Some(2), Some(3))); + + let drop_table = "DROP TABLE test_table"; + Spi::run(drop_table).unwrap(); + + // UINT32 => {int, bigint} + let schema = Arc::new(Schema::new(vec![ + Field::new("x", DataType::UInt32, true), + Field::new("y", DataType::UInt32, true), + ])); + + let x = Arc::new(UInt32Array::from(vec![1])); + let y = Arc::new(UInt32Array::from(vec![2])); + + let batch = RecordBatch::try_new(schema.clone(), vec![x, y]).unwrap(); + + write_record_batch_to_parquet(schema, batch); + + let create_table = "CREATE TABLE test_table (x int, y bigint)"; + Spi::run(create_table).unwrap(); + + let copy_from = "COPY test_table FROM '/tmp/test.parquet'"; + Spi::run(copy_from).unwrap(); + + let value = Spi::get_two::("SELECT x, y FROM test_table LIMIT 1").unwrap(); + assert_eq!(value, (Some(1), Some(2))); + + let drop_table = "DROP TABLE test_table"; + Spi::run(drop_table).unwrap(); + + // UINT64 => {bigint} + let schema = Arc::new(Schema::new(vec![Field::new("x", DataType::UInt64, true)])); + + let x = Arc::new(UInt64Array::from(vec![1])); + + let batch = RecordBatch::try_new(schema.clone(), vec![x]).unwrap(); + + write_record_batch_to_parquet(schema, batch); + + let create_table = "CREATE TABLE test_table (x bigint)"; + Spi::run(create_table).unwrap(); + + let copy_from = "COPY test_table FROM '/tmp/test.parquet'"; + Spi::run(copy_from).unwrap(); + + let value = Spi::get_one::("SELECT x FROM test_table LIMIT 1") + .unwrap() + .unwrap(); + assert_eq!(value, 1); + + let drop_table = "DROP 
TABLE test_table"; + Spi::run(drop_table).unwrap(); + + // LargeUtf8 => {text} + let schema = Arc::new(Schema::new(vec![Field::new( + "x", + DataType::LargeUtf8, + true, + )])); + + let x = Arc::new(LargeStringArray::from(vec!["test"])); + + let batch = RecordBatch::try_new(schema.clone(), vec![x]).unwrap(); + + write_record_batch_to_parquet(schema, batch); + + let create_table = "CREATE TABLE test_table (x text)"; + Spi::run(create_table).unwrap(); + + let copy_from = "COPY test_table FROM '/tmp/test.parquet'"; + Spi::run(copy_from).unwrap(); + + let value = Spi::get_one::("SELECT x FROM test_table LIMIT 1") + .unwrap() + .unwrap(); + assert_eq!(value, "test"); + + let drop_table = "DROP TABLE test_table"; + Spi::run(drop_table).unwrap(); + + // LargeBinary => {bytea} + let schema = Arc::new(Schema::new(vec![Field::new( + "x", + DataType::LargeBinary, + true, + )])); + + let x = Arc::new(LargeBinaryArray::from(vec!["abc".as_bytes()])); + + let batch = RecordBatch::try_new(schema.clone(), vec![x]).unwrap(); + + write_record_batch_to_parquet(schema, batch); + + let create_table = "CREATE TABLE test_table (x bytea)"; + Spi::run(create_table).unwrap(); + + let copy_from = "COPY test_table FROM '/tmp/test.parquet'"; + Spi::run(copy_from).unwrap(); + + let value = Spi::get_one::>("SELECT x FROM test_table LIMIT 1") + .unwrap() + .unwrap(); + assert_eq!(value, "abc".as_bytes()); + + let drop_table = "DROP TABLE test_table"; + Spi::run(drop_table).unwrap(); + } + #[pg_test] fn test_copy_with_empty_options() { let test_table = TestTable::::new("int4".into())