fix(parquet): handle nested data types correctly #20156

Merged (6 commits) on Jan 26, 2025
36 changes: 24 additions & 12 deletions e2e_test/s3/file_sink.py
@@ -17,6 +17,12 @@
def gen_data(file_num, item_num_per_file):
assert item_num_per_file % 2 == 0, \
f'item_num_per_file should be even to ensure sum(mark) == 0: {item_num_per_file}'

struct_type = pa.struct([
('field1', pa.int32()),
('field2', pa.string())
])

return [
[{
'id': file_id * item_num_per_file + item_id,
@@ -44,6 +50,7 @@ def gen_data(file_num, item_num_per_file):
'test_timestamptz_ms': pa.scalar(datetime.now().timestamp() * 1000, type=pa.timestamp('ms', tz='+00:00')),
'test_timestamptz_us': pa.scalar(datetime.now().timestamp() * 1000000, type=pa.timestamp('us', tz='+00:00')),
'test_timestamptz_ns': pa.scalar(datetime.now().timestamp() * 1000000000, type=pa.timestamp('ns', tz='+00:00')),
'nested_struct': pa.scalar((item_id, f'struct_value_{item_id}'), type=struct_type),
} for item_id in range(item_num_per_file)]
for file_id in range(file_num)
]
@@ -65,7 +72,7 @@ def _table():
print("test table function file scan")
cur.execute(f'''
SELECT
id,
id,
name,
sex,
mark,
@@ -89,7 +96,8 @@ def _table():
test_timestamptz_s,
test_timestamptz_ms,
test_timestamptz_us,
test_timestamptz_ns
test_timestamptz_ns,
nested_struct
FROM file_scan(
'parquet',
's3',
@@ -104,7 +112,6 @@ def _table():
except ValueError as e:
print(f"cur.fetchone() got ValueError: {e}")


print("file scan test pass")
# Execute a SELECT statement
cur.execute(f'''CREATE TABLE {_table()}(
@@ -132,8 +139,8 @@ def _table():
test_timestamptz_s timestamptz,
test_timestamptz_ms timestamptz,
test_timestamptz_us timestamptz,
test_timestamptz_ns timestamptz

test_timestamptz_ns timestamptz,
nested_struct STRUCT<"field1" int, "field2" varchar>
) WITH (
connector = 's3',
match_pattern = '*.parquet',
@@ -213,7 +220,8 @@ def _table():
test_timestamptz_s,
test_timestamptz_ms,
test_timestamptz_us,
test_timestamptz_ns
test_timestamptz_ns,
nested_struct
from {_table()} WITH (
connector = 's3',
match_pattern = '*.parquet',
@@ -230,7 +238,7 @@ def _table():
print('Sink into s3 in parquet encode...')
# Execute a SELECT statement
cur.execute(f'''CREATE TABLE test_parquet_sink_table(
id bigint primary key,\
id bigint primary key,
name TEXT,
sex bigint,
mark bigint,
@@ -254,7 +262,8 @@ def _table():
test_timestamptz_s timestamptz,
test_timestamptz_ms timestamptz,
test_timestamptz_us timestamptz,
test_timestamptz_ns timestamptz
test_timestamptz_ns timestamptz,
nested_struct STRUCT<"field1" int, "field2" varchar>
) WITH (
connector = 's3',
match_pattern = 'test_parquet_sink/*.parquet',
@@ -263,8 +272,8 @@ def _table():
s3.credentials.access = 'hummockadmin',
s3.credentials.secret = 'hummockadmin',
s3.endpoint_url = 'http://hummock001.127.0.0.1:9301',
refresh.interval.sec = 1,
) FORMAT PLAIN ENCODE PARQUET;''')

total_rows = file_num * item_num_per_file
MAX_RETRIES = 40
for retry_no in range(MAX_RETRIES):
@@ -305,7 +314,8 @@ def _table():
test_timestamptz_s,
test_timestamptz_ms,
test_timestamptz_us,
test_timestamptz_ns
test_timestamptz_ns,
nested_struct
from {_table()} WITH (
connector = 'snowflake',
match_pattern = '*.parquet',
@@ -316,7 +326,8 @@ def _table():
s3.endpoint_url = 'http://hummock001.127.0.0.1:9301',
s3.path = 'test_json_sink/',
type = 'append-only',
force_append_only='true'
force_append_only='true',
refresh.interval.sec = 1,
) FORMAT PLAIN ENCODE JSON(force_append_only='true');''')

print('Sink into s3 in json encode...')
@@ -346,7 +357,8 @@ def _table():
test_timestamptz_s timestamptz,
test_timestamptz_ms timestamptz,
test_timestamptz_us timestamptz,
test_timestamptz_ns timestamptz
test_timestamptz_ns timestamptz,
nested_struct STRUCT<"field1" int, "field2" varchar>
) WITH (
connector = 's3',
match_pattern = 'test_json_sink/*.json',
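For readers unfamiliar with how the new `nested_struct` fixture looks on disk, here is a minimal standalone pyarrow sketch (not part of the PR; the file name and values are illustrative) that writes and reads back a struct column shaped like the one the test generates:

```python
import pyarrow as pa
import pyarrow.parquet as pq

# Same shape as the test fixture: a plain id column plus a struct column.
struct_type = pa.struct([('field1', pa.int32()), ('field2', pa.string())])
table = pa.table({
    'id': pa.array([0, 1], type=pa.int64()),
    'nested_struct': pa.array(
        [{'field1': 0, 'field2': 'struct_value_0'},
         {'field1': 1, 'field2': 'struct_value_1'}],
        type=struct_type,
    ),
})
pq.write_table(table, 'nested.parquet')
print(pq.read_table('nested.parquet').schema)  # id: int64, nested_struct: struct<field1: int32, field2: string>
```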
5 changes: 5 additions & 0 deletions src/connector/src/parser/parquet_parser.rs
@@ -26,6 +26,7 @@ use risingwave_common::util::tokio_util::compat::FuturesAsyncReadCompatExt;
use crate::parser::ConnectorResult;
use crate::source::filesystem::opendal_source::opendal_enumerator::OpendalEnumerator;
use crate::source::filesystem::opendal_source::{OpendalGcs, OpendalPosixFs, OpendalS3};
use crate::source::iceberg::is_parquet_schema_match_source_schema;
use crate::source::reader::desc::SourceDesc;
use crate::source::{ConnectorProperties, SourceColumnDesc};
/// `ParquetParser` is responsible for converting the incoming `record_batch_stream`
@@ -109,6 +110,10 @@ impl ParquetParser {

if let Some(parquet_column) =
record_batch.column_by_name(rw_column_name)
&& is_parquet_schema_match_source_schema(
parquet_column.data_type(),
rw_data_type,
)
{
let arrow_field = IcebergArrowConvert
.to_arrow_field(rw_column_name, rw_data_type)?;
108 changes: 56 additions & 52 deletions src/connector/src/source/iceberg/parquet_file_handler.rs
@@ -36,7 +36,7 @@ use parquet::file::metadata::{FileMetaData, ParquetMetaData, ParquetMetaDataRead
use risingwave_common::array::arrow::arrow_schema_udf::{DataType as ArrowDateType, IntervalUnit};
use risingwave_common::array::arrow::IcebergArrowConvert;
use risingwave_common::array::StreamChunk;
use risingwave_common::catalog::ColumnId;
use risingwave_common::catalog::{ColumnDesc, ColumnId};
use risingwave_common::types::DataType as RwDataType;
use risingwave_common::util::tokio_util::compat::FuturesAsyncReadCompatExt;
use url::Url;
@@ -217,55 +217,64 @@ pub async fn list_data_directory(
}
}

/// Extracts valid column indices from a Parquet file schema based on the user's requested schema.
/// Extracts a suitable `ProjectionMask` from a Parquet file schema based on the user's requested schema.
///
/// This function is used for column pruning of Parquet files. It calculates the intersection
/// between the columns in the currently read Parquet file and the schema provided by the user.
/// This is useful for reading a `RecordBatch` with the appropriate `ProjectionMask`, ensuring that
/// only the necessary columns are read.
/// This function is used for column pruning of Parquet files. It matches the user's requested
/// schema against the root-level fields of the Parquet file being read. If the provided `columns`
/// are `None`, it returns `ProjectionMask::all()`. Otherwise, it keeps only the root columns whose
/// name and data type both match the requested schema, so nested data types such as structs are
/// projected as whole top-level columns, facilitating efficient reading of the `RecordBatch`.
///
/// # Parameters
/// - `columns`: A vector of `Column` representing the user's requested schema.
/// - `columns`: An optional vector of `Column` representing the user's requested schema.
/// - `metadata`: A reference to `FileMetaData` containing the schema and metadata of the Parquet file.
///
/// # Returns
/// - A `ConnectorResult<Vec<usize>>`, which contains the indices of the valid columns in the
/// Parquet file schema that match the requested schema. If an error occurs during processing,
/// it returns an appropriate error.
pub fn extract_valid_column_indices(
rw_columns: Vec<Column>,
/// - A `ConnectorResult<ProjectionMask>`, which represents the valid columns in the Parquet file schema
/// that correspond to the requested schema. If an error occurs during processing, it returns an
/// appropriate error.
pub fn get_project_mask(
columns: Option<Vec<Column>>,
metadata: &FileMetaData,
) -> ConnectorResult<Vec<usize>> {
let parquet_column_names = metadata
.schema_descr()
.columns()
.iter()
.map(|c| c.name())
.collect_vec();
) -> ConnectorResult<ProjectionMask> {
match columns {
Some(rw_columns) => {
let root_column_names = metadata
.schema_descr()
.root_schema()
.get_fields()
.iter()
.map(|field| field.name())
.collect_vec();

let converted_arrow_schema =
parquet_to_arrow_schema(metadata.schema_descr(), metadata.key_value_metadata())
.map_err(anyhow::Error::from)?;
let converted_arrow_schema =
parquet_to_arrow_schema(metadata.schema_descr(), metadata.key_value_metadata())
.map_err(anyhow::Error::from)?;
let valid_column_indices: Vec<usize> = rw_columns
.iter()
.filter_map(|column| {
root_column_names
.iter()
.position(|&name| name == column.name)
.and_then(|pos| {
let arrow_data_type: &risingwave_common::array::arrow::arrow_schema_udf::DataType = converted_arrow_schema.field_with_name(&column.name).ok()?.data_type();
let rw_data_type: &risingwave_common::types::DataType = &column.data_type;
if is_parquet_schema_match_source_schema(arrow_data_type, rw_data_type) {
Some(pos)
} else {
None
}
})
})
.collect();

let valid_column_indices: Vec<usize> = rw_columns
.iter()
.filter_map(|column| {
parquet_column_names
.iter()
.position(|&name| name == column.name)
.and_then(|pos| {
let arrow_data_type: &risingwave_common::array::arrow::arrow_schema_udf::DataType = converted_arrow_schema.field(pos).data_type();
let rw_data_type: &risingwave_common::types::DataType = &column.data_type;

if is_parquet_schema_match_source_schema(arrow_data_type, rw_data_type) {
Some(pos)
} else {
None
}
})
})
.collect();
Ok(valid_column_indices)
Ok(ProjectionMask::roots(
metadata.schema_descr(),
valid_column_indices,
))
}
None => Ok(ProjectionMask::all()),
}
}

/// Reads a specified Parquet file and converts its content into a stream of chunks.
@@ -289,13 +298,7 @@ pub async fn read_parquet_file(
let parquet_metadata = reader.get_metadata().await.map_err(anyhow::Error::from)?;

let file_metadata = parquet_metadata.file_metadata();
let projection_mask = match rw_columns {
Some(columns) => {
let column_indices = extract_valid_column_indices(columns, file_metadata)?;
ProjectionMask::leaves(file_metadata.schema_descr(), column_indices)
}
None => ProjectionMask::all(),
};
let projection_mask = get_project_mask(rw_columns, file_metadata)?;

// For the Parquet format, we directly convert from a record batch to a stream chunk.
// Therefore, the offset of the Parquet file represents the current position in terms of the number of rows read from the file.
@@ -318,11 +321,12 @@
.enumerate()
.map(|(index, field_ref)| {
let data_type = IcebergArrowConvert.type_from_field(field_ref).unwrap();
SourceColumnDesc::simple(
let column_desc = ColumnDesc::named(
field_ref.name().clone(),
data_type,
ColumnId::new(index as i32),
)
data_type,
);
SourceColumnDesc::from(&column_desc)
})
.collect(),
};
@@ -367,7 +371,7 @@ pub async fn get_parquet_fields(
/// - Arrow's `UInt32` matches with RisingWave's `Int64`.
/// - Arrow's `UInt64` matches with RisingWave's `Decimal`.
/// - Arrow's `Float16` matches with RisingWave's `Float32`.
fn is_parquet_schema_match_source_schema(
pub fn is_parquet_schema_match_source_schema(
arrow_data_type: &ArrowDateType,
rw_data_type: &RwDataType,
) -> bool {
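As an aside (not the connector's code path), the effect of the switch from leaf-level indices with `ProjectionMask::leaves` to `ProjectionMask::roots` can be pictured with pyarrow, which likewise prunes at the level of top-level columns and reads a struct column together with all of its subfields; this assumes the illustrative `nested.parquet` from the earlier sketch:

```python
import pyarrow.parquet as pq

# Selecting root columns keeps the struct column intact, including every subfield,
# which mirrors root-level (rather than leaf-level) projection.
pruned = pq.read_table('nested.parquet', columns=['id', 'nested_struct'])
print(pruned.column_names)                         # ['id', 'nested_struct']
print(pruned.schema.field('nested_struct').type)   # struct<field1: int32, field2: string>
```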
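And a small numeric sanity check (pyarrow again, purely illustrative) of why the unsigned-to-wider-type mappings listed for `is_parquet_schema_match_source_schema` are safe:

```python
import pyarrow as pa

# uint32 values can exceed int32's range but always fit in int64, hence UInt32 -> Int64.
u32 = pa.array([2**32 - 1], type=pa.uint32())
print(u32.cast(pa.int64()))   # [4294967295], lossless

# uint64 values can exceed int64's range, so a wider type (Decimal) is used for UInt64.
print(2**64 - 1 > 2**63 - 1)  # True
```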