From 08e4e6ad020136cac18710d0a18f31d44e025457 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Sava=20Vrane=C5=A1evi=C4=87?= <20240220+svranesevic@users.noreply.github.com> Date: Sat, 22 Jun 2024 15:25:39 +0200 Subject: [PATCH] Fix `FormatOptions::CSV` propagation (#10912) * Fix sink output schema being passed in to `FileSinkExec` where input schema was expected * Propagate CSV options (quote, double quote, and escape) through protos * Add test for double quotes * Test quote escape when double quotes are disabled * regen --------- Co-authored-by: svranesevic Co-authored-by: Andrew Lamb --- datafusion/common/src/config.rs | 9 +++ .../common/src/file_options/csv_writer.rs | 6 ++ datafusion/core/tests/data/double_quote.csv | 5 ++ .../proto/datafusion_common.proto | 7 ++ datafusion/proto-common/src/from_proto/mod.rs | 26 ++++++- .../proto-common/src/generated/pbjson.rs | 73 +++++++++++++++++++ .../proto-common/src/generated/prost.rs | 12 +++ datafusion/proto-common/src/to_proto/mod.rs | 4 + .../src/generated/datafusion_proto_common.rs | 12 +++ datafusion/proto/src/physical_plan/mod.rs | 12 +-- .../sqllogictest/test_files/csv_files.slt | 67 +++++++++++++++++ 11 files changed, 226 insertions(+), 7 deletions(-) create mode 100644 datafusion/core/tests/data/double_quote.csv diff --git a/datafusion/common/src/config.rs b/datafusion/common/src/config.rs index c59cdba7c829..47da14574c5d 100644 --- a/datafusion/common/src/config.rs +++ b/datafusion/common/src/config.rs @@ -1566,6 +1566,7 @@ config_namespace! { pub delimiter: u8, default = b',' pub quote: u8, default = b'"' pub escape: Option, default = None + pub double_quote: Option, default = None pub compression: CompressionTypeVariant, default = CompressionTypeVariant::UNCOMPRESSED pub schema_infer_max_rec: usize, default = 100 pub date_format: Option, default = None @@ -1631,6 +1632,13 @@ impl CsvOptions { self } + /// Set true to indicate that the CSV quotes should be doubled. + /// - default to true + pub fn with_double_quote(mut self, double_quote: bool) -> Self { + self.double_quote = Some(double_quote); + self + } + /// Set a `CompressionTypeVariant` of CSV /// - defaults to `CompressionTypeVariant::UNCOMPRESSED` pub fn with_file_compression_type( @@ -1675,6 +1683,7 @@ pub enum FormatOptions { AVRO, ARROW, } + impl Display for FormatOptions { fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result { let out = match self { diff --git a/datafusion/common/src/file_options/csv_writer.rs b/datafusion/common/src/file_options/csv_writer.rs index 4f948a29adc4..5792cfdba9e0 100644 --- a/datafusion/common/src/file_options/csv_writer.rs +++ b/datafusion/common/src/file_options/csv_writer.rs @@ -69,6 +69,12 @@ impl TryFrom<&CsvOptions> for CsvWriterOptions { if let Some(v) = &value.null_value { builder = builder.with_null(v.into()) } + if let Some(v) = &value.escape { + builder = builder.with_escape(*v) + } + if let Some(v) = &value.double_quote { + builder = builder.with_double_quote(*v) + } Ok(CsvWriterOptions { writer_options: builder, compression: value.compression, diff --git a/datafusion/core/tests/data/double_quote.csv b/datafusion/core/tests/data/double_quote.csv new file mode 100644 index 000000000000..95a6f0c4077a --- /dev/null +++ b/datafusion/core/tests/data/double_quote.csv @@ -0,0 +1,5 @@ +c1,c2 +id0,"""value0""" +id1,"""value1""" +id2,"""value2""" +id3,"""value3""" diff --git a/datafusion/proto-common/proto/datafusion_common.proto b/datafusion/proto-common/proto/datafusion_common.proto index e523ef1a5e93..225bb9ddf661 100644 --- a/datafusion/proto-common/proto/datafusion_common.proto +++ b/datafusion/proto-common/proto/datafusion_common.proto @@ -385,6 +385,12 @@ message CsvWriterOptions { string time_format = 7; // Optional value to represent null string null_value = 8; + // Optional quote. Defaults to `b'"'` + string quote = 9; + // Optional escape. Defaults to `'\\'` + string escape = 10; + // Optional flag whether to double quotes, instead of escaping. Defaults to `true` + bool double_quote = 11; } // Options controlling CSV format @@ -402,6 +408,7 @@ message CsvOptions { string time_format = 11; // Optional time format string null_value = 12; // Optional representation of null value bytes comment = 13; // Optional comment character as a byte + bytes double_quote = 14; // Indicates if quotes are doubled } // Options controlling CSV format diff --git a/datafusion/proto-common/src/from_proto/mod.rs b/datafusion/proto-common/src/from_proto/mod.rs index be87123fb13f..de9fede9ee86 100644 --- a/datafusion/proto-common/src/from_proto/mod.rs +++ b/datafusion/proto-common/src/from_proto/mod.rs @@ -857,6 +857,7 @@ impl TryFrom<&protobuf::CsvOptions> for CsvOptions { delimiter: proto_opts.delimiter[0], quote: proto_opts.quote[0], escape: proto_opts.escape.first().copied(), + double_quote: proto_opts.has_header.first().map(|h| *h != 0), compression: proto_opts.compression().into(), schema_infer_max_rec: proto_opts.schema_infer_max_rec as usize, date_format: (!proto_opts.date_format.is_empty()) @@ -1091,11 +1092,34 @@ pub(crate) fn csv_writer_options_from_proto( return Err(proto_error("Error parsing CSV Delimiter")); } } + if !writer_options.quote.is_empty() { + if let Some(quote) = writer_options.quote.chars().next() { + if quote.is_ascii() { + builder = builder.with_quote(quote as u8); + } else { + return Err(proto_error("CSV Quote is not ASCII")); + } + } else { + return Err(proto_error("Error parsing CSV Quote")); + } + } + if !writer_options.escape.is_empty() { + if let Some(escape) = writer_options.escape.chars().next() { + if escape.is_ascii() { + builder = builder.with_escape(escape as u8); + } else { + return Err(proto_error("CSV Escape is not ASCII")); + } + } else { + return Err(proto_error("Error parsing CSV Escape")); + } + } Ok(builder .with_header(writer_options.has_header) .with_date_format(writer_options.date_format.clone()) .with_datetime_format(writer_options.datetime_format.clone()) .with_timestamp_format(writer_options.timestamp_format.clone()) .with_time_format(writer_options.time_format.clone()) - .with_null(writer_options.null_value.clone())) + .with_null(writer_options.null_value.clone()) + .with_double_quote(writer_options.double_quote)) } diff --git a/datafusion/proto-common/src/generated/pbjson.rs b/datafusion/proto-common/src/generated/pbjson.rs index ead29d9b92e0..3cf34aeb6d01 100644 --- a/datafusion/proto-common/src/generated/pbjson.rs +++ b/datafusion/proto-common/src/generated/pbjson.rs @@ -1881,6 +1881,9 @@ impl serde::Serialize for CsvOptions { if !self.comment.is_empty() { len += 1; } + if !self.double_quote.is_empty() { + len += 1; + } let mut struct_ser = serializer.serialize_struct("datafusion_common.CsvOptions", len)?; if !self.has_header.is_empty() { #[allow(clippy::needless_borrow)] @@ -1929,6 +1932,10 @@ impl serde::Serialize for CsvOptions { #[allow(clippy::needless_borrow)] struct_ser.serialize_field("comment", pbjson::private::base64::encode(&self.comment).as_str())?; } + if !self.double_quote.is_empty() { + #[allow(clippy::needless_borrow)] + struct_ser.serialize_field("doubleQuote", pbjson::private::base64::encode(&self.double_quote).as_str())?; + } struct_ser.end() } } @@ -1960,6 +1967,8 @@ impl<'de> serde::Deserialize<'de> for CsvOptions { "null_value", "nullValue", "comment", + "double_quote", + "doubleQuote", ]; #[allow(clippy::enum_variant_names)] @@ -1977,6 +1986,7 @@ impl<'de> serde::Deserialize<'de> for CsvOptions { TimeFormat, NullValue, Comment, + DoubleQuote, } impl<'de> serde::Deserialize<'de> for GeneratedField { fn deserialize(deserializer: D) -> std::result::Result @@ -2011,6 +2021,7 @@ impl<'de> serde::Deserialize<'de> for CsvOptions { "timeFormat" | "time_format" => Ok(GeneratedField::TimeFormat), "nullValue" | "null_value" => Ok(GeneratedField::NullValue), "comment" => Ok(GeneratedField::Comment), + "doubleQuote" | "double_quote" => Ok(GeneratedField::DoubleQuote), _ => Err(serde::de::Error::unknown_field(value, FIELDS)), } } @@ -2043,6 +2054,7 @@ impl<'de> serde::Deserialize<'de> for CsvOptions { let mut time_format__ = None; let mut null_value__ = None; let mut comment__ = None; + let mut double_quote__ = None; while let Some(k) = map_.next_key()? { match k { GeneratedField::HasHeader => { @@ -2135,6 +2147,14 @@ impl<'de> serde::Deserialize<'de> for CsvOptions { Some(map_.next_value::<::pbjson::private::BytesDeserialize<_>>()?.0) ; } + GeneratedField::DoubleQuote => { + if double_quote__.is_some() { + return Err(serde::de::Error::duplicate_field("doubleQuote")); + } + double_quote__ = + Some(map_.next_value::<::pbjson::private::BytesDeserialize<_>>()?.0) + ; + } } } Ok(CsvOptions { @@ -2151,6 +2171,7 @@ impl<'de> serde::Deserialize<'de> for CsvOptions { time_format: time_format__.unwrap_or_default(), null_value: null_value__.unwrap_or_default(), comment: comment__.unwrap_or_default(), + double_quote: double_quote__.unwrap_or_default(), }) } } @@ -2189,6 +2210,15 @@ impl serde::Serialize for CsvWriterOptions { if !self.null_value.is_empty() { len += 1; } + if !self.quote.is_empty() { + len += 1; + } + if !self.escape.is_empty() { + len += 1; + } + if self.double_quote { + len += 1; + } let mut struct_ser = serializer.serialize_struct("datafusion_common.CsvWriterOptions", len)?; if self.compression != 0 { let v = CompressionTypeVariant::try_from(self.compression) @@ -2216,6 +2246,15 @@ impl serde::Serialize for CsvWriterOptions { if !self.null_value.is_empty() { struct_ser.serialize_field("nullValue", &self.null_value)?; } + if !self.quote.is_empty() { + struct_ser.serialize_field("quote", &self.quote)?; + } + if !self.escape.is_empty() { + struct_ser.serialize_field("escape", &self.escape)?; + } + if self.double_quote { + struct_ser.serialize_field("doubleQuote", &self.double_quote)?; + } struct_ser.end() } } @@ -2240,6 +2279,10 @@ impl<'de> serde::Deserialize<'de> for CsvWriterOptions { "timeFormat", "null_value", "nullValue", + "quote", + "escape", + "double_quote", + "doubleQuote", ]; #[allow(clippy::enum_variant_names)] @@ -2252,6 +2295,9 @@ impl<'de> serde::Deserialize<'de> for CsvWriterOptions { TimestampFormat, TimeFormat, NullValue, + Quote, + Escape, + DoubleQuote, } impl<'de> serde::Deserialize<'de> for GeneratedField { fn deserialize(deserializer: D) -> std::result::Result @@ -2281,6 +2327,9 @@ impl<'de> serde::Deserialize<'de> for CsvWriterOptions { "timestampFormat" | "timestamp_format" => Ok(GeneratedField::TimestampFormat), "timeFormat" | "time_format" => Ok(GeneratedField::TimeFormat), "nullValue" | "null_value" => Ok(GeneratedField::NullValue), + "quote" => Ok(GeneratedField::Quote), + "escape" => Ok(GeneratedField::Escape), + "doubleQuote" | "double_quote" => Ok(GeneratedField::DoubleQuote), _ => Err(serde::de::Error::unknown_field(value, FIELDS)), } } @@ -2308,6 +2357,9 @@ impl<'de> serde::Deserialize<'de> for CsvWriterOptions { let mut timestamp_format__ = None; let mut time_format__ = None; let mut null_value__ = None; + let mut quote__ = None; + let mut escape__ = None; + let mut double_quote__ = None; while let Some(k) = map_.next_key()? { match k { GeneratedField::Compression => { @@ -2358,6 +2410,24 @@ impl<'de> serde::Deserialize<'de> for CsvWriterOptions { } null_value__ = Some(map_.next_value()?); } + GeneratedField::Quote => { + if quote__.is_some() { + return Err(serde::de::Error::duplicate_field("quote")); + } + quote__ = Some(map_.next_value()?); + } + GeneratedField::Escape => { + if escape__.is_some() { + return Err(serde::de::Error::duplicate_field("escape")); + } + escape__ = Some(map_.next_value()?); + } + GeneratedField::DoubleQuote => { + if double_quote__.is_some() { + return Err(serde::de::Error::duplicate_field("doubleQuote")); + } + double_quote__ = Some(map_.next_value()?); + } } } Ok(CsvWriterOptions { @@ -2369,6 +2439,9 @@ impl<'de> serde::Deserialize<'de> for CsvWriterOptions { timestamp_format: timestamp_format__.unwrap_or_default(), time_format: time_format__.unwrap_or_default(), null_value: null_value__.unwrap_or_default(), + quote: quote__.unwrap_or_default(), + escape: escape__.unwrap_or_default(), + double_quote: double_quote__.unwrap_or_default(), }) } } diff --git a/datafusion/proto-common/src/generated/prost.rs b/datafusion/proto-common/src/generated/prost.rs index b306f3212a2f..57893321e665 100644 --- a/datafusion/proto-common/src/generated/prost.rs +++ b/datafusion/proto-common/src/generated/prost.rs @@ -575,6 +575,15 @@ pub struct CsvWriterOptions { /// Optional value to represent null #[prost(string, tag = "8")] pub null_value: ::prost::alloc::string::String, + /// Optional quote. Defaults to `b'"'` + #[prost(string, tag = "9")] + pub quote: ::prost::alloc::string::String, + /// Optional escape. Defaults to `'\\'` + #[prost(string, tag = "10")] + pub escape: ::prost::alloc::string::String, + /// Optional flag whether to double quote instead of escaping. Defaults to `true` + #[prost(bool, tag = "11")] + pub double_quote: bool, } /// Options controlling CSV format #[allow(clippy::derive_partial_eq_without_eq)] @@ -619,6 +628,9 @@ pub struct CsvOptions { /// Optional comment character as a byte #[prost(bytes = "vec", tag = "13")] pub comment: ::prost::alloc::vec::Vec, + /// Indicates if quotes are doubled + #[prost(bytes = "vec", tag = "14")] + pub double_quote: ::prost::alloc::vec::Vec, } /// Options controlling CSV format #[allow(clippy::derive_partial_eq_without_eq)] diff --git a/datafusion/proto-common/src/to_proto/mod.rs b/datafusion/proto-common/src/to_proto/mod.rs index a3dc826a79ca..877043f66809 100644 --- a/datafusion/proto-common/src/to_proto/mod.rs +++ b/datafusion/proto-common/src/to_proto/mod.rs @@ -896,6 +896,7 @@ impl TryFrom<&CsvOptions> for protobuf::CsvOptions { delimiter: vec![opts.delimiter], quote: vec![opts.quote], escape: opts.escape.map_or_else(Vec::new, |e| vec![e]), + double_quote: opts.double_quote.map_or_else(Vec::new, |h| vec![h as u8]), compression: compression.into(), schema_infer_max_rec: opts.schema_infer_max_rec as u64, date_format: opts.date_format.clone().unwrap_or_default(), @@ -1022,5 +1023,8 @@ pub(crate) fn csv_writer_options_to_proto( timestamp_format: csv_options.timestamp_format().unwrap_or("").to_owned(), time_format: csv_options.time_format().unwrap_or("").to_owned(), null_value: csv_options.null().to_owned(), + quote: (csv_options.quote() as char).to_string(), + escape: (csv_options.escape() as char).to_string(), + double_quote: csv_options.double_quote(), } } diff --git a/datafusion/proto/src/generated/datafusion_proto_common.rs b/datafusion/proto/src/generated/datafusion_proto_common.rs index b306f3212a2f..875fe8992e90 100644 --- a/datafusion/proto/src/generated/datafusion_proto_common.rs +++ b/datafusion/proto/src/generated/datafusion_proto_common.rs @@ -575,6 +575,15 @@ pub struct CsvWriterOptions { /// Optional value to represent null #[prost(string, tag = "8")] pub null_value: ::prost::alloc::string::String, + /// Optional quote. Defaults to `b'"'` + #[prost(string, tag = "9")] + pub quote: ::prost::alloc::string::String, + /// Optional escape. Defaults to `'\\'` + #[prost(string, tag = "10")] + pub escape: ::prost::alloc::string::String, + /// Optional flag whether to double quotes, instead of escaping. Defaults to `true` + #[prost(bool, tag = "11")] + pub double_quote: bool, } /// Options controlling CSV format #[allow(clippy::derive_partial_eq_without_eq)] @@ -619,6 +628,9 @@ pub struct CsvOptions { /// Optional comment character as a byte #[prost(bytes = "vec", tag = "13")] pub comment: ::prost::alloc::vec::Vec, + /// Indicates if quotes are doubled + #[prost(bytes = "vec", tag = "14")] + pub double_quote: ::prost::alloc::vec::Vec, } /// Options controlling CSV format #[allow(clippy::derive_partial_eq_without_eq)] diff --git a/datafusion/proto/src/physical_plan/mod.rs b/datafusion/proto/src/physical_plan/mod.rs index 8a488d30cf24..56e702704798 100644 --- a/datafusion/proto/src/physical_plan/mod.rs +++ b/datafusion/proto/src/physical_plan/mod.rs @@ -1010,7 +1010,7 @@ impl AsExecutionPlan for protobuf::PhysicalPlanNode { .as_ref() .ok_or_else(|| proto_error("Missing required field in protobuf"))? .try_into()?; - let sink_schema = convert_required!(sink.sink_schema)?; + let sink_schema = input.schema(); let sort_order = sink .sort_order .as_ref() @@ -1027,7 +1027,7 @@ impl AsExecutionPlan for protobuf::PhysicalPlanNode { Ok(Arc::new(DataSinkExec::new( input, Arc::new(data_sink), - Arc::new(sink_schema), + sink_schema, sort_order, ))) } @@ -1040,7 +1040,7 @@ impl AsExecutionPlan for protobuf::PhysicalPlanNode { .as_ref() .ok_or_else(|| proto_error("Missing required field in protobuf"))? .try_into()?; - let sink_schema = convert_required!(sink.sink_schema)?; + let sink_schema = input.schema(); let sort_order = sink .sort_order .as_ref() @@ -1057,7 +1057,7 @@ impl AsExecutionPlan for protobuf::PhysicalPlanNode { Ok(Arc::new(DataSinkExec::new( input, Arc::new(data_sink), - Arc::new(sink_schema), + sink_schema, sort_order, ))) } @@ -1070,7 +1070,7 @@ impl AsExecutionPlan for protobuf::PhysicalPlanNode { .as_ref() .ok_or_else(|| proto_error("Missing required field in protobuf"))? .try_into()?; - let sink_schema = convert_required!(sink.sink_schema)?; + let sink_schema = input.schema(); let sort_order = sink .sort_order .as_ref() @@ -1087,7 +1087,7 @@ impl AsExecutionPlan for protobuf::PhysicalPlanNode { Ok(Arc::new(DataSinkExec::new( input, Arc::new(data_sink), - Arc::new(sink_schema), + sink_schema, sort_order, ))) } diff --git a/datafusion/sqllogictest/test_files/csv_files.slt b/datafusion/sqllogictest/test_files/csv_files.slt index 8902b3eebf24..a8a689cbb8b5 100644 --- a/datafusion/sqllogictest/test_files/csv_files.slt +++ b/datafusion/sqllogictest/test_files/csv_files.slt @@ -226,3 +226,70 @@ SELECT * from stored_table_with_comments; ---- column1 column2 2 3 + +# read csv with double quote +statement ok +CREATE EXTERNAL TABLE csv_with_double_quote ( +c1 VARCHAR, +c2 VARCHAR +) STORED AS CSV +OPTIONS ('format.delimiter' ',', + 'format.has_header' 'true', + 'format.double_quote' 'true') +LOCATION '../core/tests/data/double_quote.csv'; + +query TT +select * from csv_with_double_quote +---- +id0 "value0" +id1 "value1" +id2 "value2" +id3 "value3" + +# ensure that double quote option is used when writing to csv +query TT +COPY csv_with_double_quote TO 'test_files/scratch/csv_files/table_with_double_quotes.csv' +STORED AS csv +OPTIONS ('format.double_quote' 'true'); +---- +4 + +statement ok +CREATE EXTERNAL TABLE stored_table_with_double_quotes ( +col1 TEXT, +col2 TEXT +) STORED AS CSV +LOCATION 'test_files/scratch/csv_files/table_with_double_quotes.csv' +OPTIONS ('format.double_quote' 'true'); + +query TT +select * from stored_table_with_double_quotes; +---- +id0 "value0" +id1 "value1" +id2 "value2" +id3 "value3" + +# ensure when double quote option is disabled that quotes are escaped instead +query TT +COPY csv_with_double_quote TO 'test_files/scratch/csv_files/table_with_escaped_quotes.csv' +STORED AS csv +OPTIONS ('format.double_quote' 'false', 'format.escape' '#'); +---- +4 + +statement ok +CREATE EXTERNAL TABLE stored_table_with_escaped_quotes ( +col1 TEXT, +col2 TEXT +) STORED AS CSV +LOCATION 'test_files/scratch/csv_files/table_with_escaped_quotes.csv' +OPTIONS ('format.double_quote' 'false', 'format.escape' '#'); + +query TT +select * from stored_table_with_escaped_quotes; +---- +id0 "value0" +id1 "value1" +id2 "value2" +id3 "value3"