From f39c040ace0b34b0775827907aa01d6bb71cbb14 Mon Sep 17 00:00:00 2001 From: Andy Grove Date: Thu, 28 Dec 2023 11:38:16 -0700 Subject: [PATCH] Add serde support for CSV FileTypeWriterOptions (#8641) --- datafusion/proto/proto/datafusion.proto | 18 ++ datafusion/proto/src/generated/pbjson.rs | 213 ++++++++++++++++++ datafusion/proto/src/generated/prost.rs | 29 ++- datafusion/proto/src/logical_plan/mod.rs | 74 ++++++ .../proto/src/physical_plan/from_proto.rs | 12 +- .../tests/cases/roundtrip_logical_plan.rs | 64 +++++- 6 files changed, 406 insertions(+), 4 deletions(-) diff --git a/datafusion/proto/proto/datafusion.proto b/datafusion/proto/proto/datafusion.proto index d02fc8e91b41..59b82efcbb43 100644 --- a/datafusion/proto/proto/datafusion.proto +++ b/datafusion/proto/proto/datafusion.proto @@ -1207,6 +1207,7 @@ message FileTypeWriterOptions { oneof FileType { JsonWriterOptions json_options = 1; ParquetWriterOptions parquet_options = 2; + CsvWriterOptions csv_options = 3; } } @@ -1218,6 +1219,23 @@ message ParquetWriterOptions { WriterProperties writer_properties = 1; } +message CsvWriterOptions { + // Optional column delimiter. Defaults to `b','` + string delimiter = 1; + // Whether to write column names as file headers. Defaults to `true` + bool has_header = 2; + // Optional date format for date arrays + string date_format = 3; + // Optional datetime format for datetime arrays + string datetime_format = 4; + // Optional timestamp format for timestamp arrays + string timestamp_format = 5; + // Optional time format for time arrays + string time_format = 6; + // Optional value to represent null + string null_value = 7; +} + message WriterProperties { uint64 data_page_size_limit = 1; uint64 dictionary_page_size_limit = 2; diff --git a/datafusion/proto/src/generated/pbjson.rs b/datafusion/proto/src/generated/pbjson.rs index f860b1f1e6a0..956244ffdbc2 100644 --- a/datafusion/proto/src/generated/pbjson.rs +++ b/datafusion/proto/src/generated/pbjson.rs @@ -5151,6 +5151,205 @@ impl<'de> serde::Deserialize<'de> for CsvScanExecNode { deserializer.deserialize_struct("datafusion.CsvScanExecNode", FIELDS, GeneratedVisitor) } } +impl serde::Serialize for CsvWriterOptions { + #[allow(deprecated)] + fn serialize(&self, serializer: S) -> std::result::Result + where + S: serde::Serializer, + { + use serde::ser::SerializeStruct; + let mut len = 0; + if !self.delimiter.is_empty() { + len += 1; + } + if self.has_header { + len += 1; + } + if !self.date_format.is_empty() { + len += 1; + } + if !self.datetime_format.is_empty() { + len += 1; + } + if !self.timestamp_format.is_empty() { + len += 1; + } + if !self.time_format.is_empty() { + len += 1; + } + if !self.null_value.is_empty() { + len += 1; + } + let mut struct_ser = serializer.serialize_struct("datafusion.CsvWriterOptions", len)?; + if !self.delimiter.is_empty() { + struct_ser.serialize_field("delimiter", &self.delimiter)?; + } + if self.has_header { + struct_ser.serialize_field("hasHeader", &self.has_header)?; + } + if !self.date_format.is_empty() { + struct_ser.serialize_field("dateFormat", &self.date_format)?; + } + if !self.datetime_format.is_empty() { + struct_ser.serialize_field("datetimeFormat", &self.datetime_format)?; + } + if !self.timestamp_format.is_empty() { + struct_ser.serialize_field("timestampFormat", &self.timestamp_format)?; + } + if !self.time_format.is_empty() { + struct_ser.serialize_field("timeFormat", &self.time_format)?; + } + if !self.null_value.is_empty() { + struct_ser.serialize_field("nullValue", &self.null_value)?; + } + struct_ser.end() + } +} +impl<'de> serde::Deserialize<'de> for CsvWriterOptions { + #[allow(deprecated)] + fn deserialize(deserializer: D) -> std::result::Result + where + D: serde::Deserializer<'de>, + { + const FIELDS: &[&str] = &[ + "delimiter", + "has_header", + "hasHeader", + "date_format", + "dateFormat", + "datetime_format", + "datetimeFormat", + "timestamp_format", + "timestampFormat", + "time_format", + "timeFormat", + "null_value", + "nullValue", + ]; + + #[allow(clippy::enum_variant_names)] + enum GeneratedField { + Delimiter, + HasHeader, + DateFormat, + DatetimeFormat, + TimestampFormat, + TimeFormat, + NullValue, + } + impl<'de> serde::Deserialize<'de> for GeneratedField { + fn deserialize(deserializer: D) -> std::result::Result + where + D: serde::Deserializer<'de>, + { + struct GeneratedVisitor; + + impl<'de> serde::de::Visitor<'de> for GeneratedVisitor { + type Value = GeneratedField; + + fn expecting(&self, formatter: &mut std::fmt::Formatter<'_>) -> std::fmt::Result { + write!(formatter, "expected one of: {:?}", &FIELDS) + } + + #[allow(unused_variables)] + fn visit_str(self, value: &str) -> std::result::Result + where + E: serde::de::Error, + { + match value { + "delimiter" => Ok(GeneratedField::Delimiter), + "hasHeader" | "has_header" => Ok(GeneratedField::HasHeader), + "dateFormat" | "date_format" => Ok(GeneratedField::DateFormat), + "datetimeFormat" | "datetime_format" => Ok(GeneratedField::DatetimeFormat), + "timestampFormat" | "timestamp_format" => Ok(GeneratedField::TimestampFormat), + "timeFormat" | "time_format" => Ok(GeneratedField::TimeFormat), + "nullValue" | "null_value" => Ok(GeneratedField::NullValue), + _ => Err(serde::de::Error::unknown_field(value, FIELDS)), + } + } + } + deserializer.deserialize_identifier(GeneratedVisitor) + } + } + struct GeneratedVisitor; + impl<'de> serde::de::Visitor<'de> for GeneratedVisitor { + type Value = CsvWriterOptions; + + fn expecting(&self, formatter: &mut std::fmt::Formatter<'_>) -> std::fmt::Result { + formatter.write_str("struct datafusion.CsvWriterOptions") + } + + fn visit_map(self, mut map_: V) -> std::result::Result + where + V: serde::de::MapAccess<'de>, + { + let mut delimiter__ = None; + let mut has_header__ = None; + let mut date_format__ = None; + let mut datetime_format__ = None; + let mut timestamp_format__ = None; + let mut time_format__ = None; + let mut null_value__ = None; + while let Some(k) = map_.next_key()? { + match k { + GeneratedField::Delimiter => { + if delimiter__.is_some() { + return Err(serde::de::Error::duplicate_field("delimiter")); + } + delimiter__ = Some(map_.next_value()?); + } + GeneratedField::HasHeader => { + if has_header__.is_some() { + return Err(serde::de::Error::duplicate_field("hasHeader")); + } + has_header__ = Some(map_.next_value()?); + } + GeneratedField::DateFormat => { + if date_format__.is_some() { + return Err(serde::de::Error::duplicate_field("dateFormat")); + } + date_format__ = Some(map_.next_value()?); + } + GeneratedField::DatetimeFormat => { + if datetime_format__.is_some() { + return Err(serde::de::Error::duplicate_field("datetimeFormat")); + } + datetime_format__ = Some(map_.next_value()?); + } + GeneratedField::TimestampFormat => { + if timestamp_format__.is_some() { + return Err(serde::de::Error::duplicate_field("timestampFormat")); + } + timestamp_format__ = Some(map_.next_value()?); + } + GeneratedField::TimeFormat => { + if time_format__.is_some() { + return Err(serde::de::Error::duplicate_field("timeFormat")); + } + time_format__ = Some(map_.next_value()?); + } + GeneratedField::NullValue => { + if null_value__.is_some() { + return Err(serde::de::Error::duplicate_field("nullValue")); + } + null_value__ = Some(map_.next_value()?); + } + } + } + Ok(CsvWriterOptions { + delimiter: delimiter__.unwrap_or_default(), + has_header: has_header__.unwrap_or_default(), + date_format: date_format__.unwrap_or_default(), + datetime_format: datetime_format__.unwrap_or_default(), + timestamp_format: timestamp_format__.unwrap_or_default(), + time_format: time_format__.unwrap_or_default(), + null_value: null_value__.unwrap_or_default(), + }) + } + } + deserializer.deserialize_struct("datafusion.CsvWriterOptions", FIELDS, GeneratedVisitor) + } +} impl serde::Serialize for CubeNode { #[allow(deprecated)] fn serialize(&self, serializer: S) -> std::result::Result @@ -7893,6 +8092,9 @@ impl serde::Serialize for FileTypeWriterOptions { file_type_writer_options::FileType::ParquetOptions(v) => { struct_ser.serialize_field("parquetOptions", v)?; } + file_type_writer_options::FileType::CsvOptions(v) => { + struct_ser.serialize_field("csvOptions", v)?; + } } } struct_ser.end() @@ -7909,12 +8111,15 @@ impl<'de> serde::Deserialize<'de> for FileTypeWriterOptions { "jsonOptions", "parquet_options", "parquetOptions", + "csv_options", + "csvOptions", ]; #[allow(clippy::enum_variant_names)] enum GeneratedField { JsonOptions, ParquetOptions, + CsvOptions, } impl<'de> serde::Deserialize<'de> for GeneratedField { fn deserialize(deserializer: D) -> std::result::Result @@ -7938,6 +8143,7 @@ impl<'de> serde::Deserialize<'de> for FileTypeWriterOptions { match value { "jsonOptions" | "json_options" => Ok(GeneratedField::JsonOptions), "parquetOptions" | "parquet_options" => Ok(GeneratedField::ParquetOptions), + "csvOptions" | "csv_options" => Ok(GeneratedField::CsvOptions), _ => Err(serde::de::Error::unknown_field(value, FIELDS)), } } @@ -7972,6 +8178,13 @@ impl<'de> serde::Deserialize<'de> for FileTypeWriterOptions { return Err(serde::de::Error::duplicate_field("parquetOptions")); } file_type__ = map_.next_value::<::std::option::Option<_>>()?.map(file_type_writer_options::FileType::ParquetOptions) +; + } + GeneratedField::CsvOptions => { + if file_type__.is_some() { + return Err(serde::de::Error::duplicate_field("csvOptions")); + } + file_type__ = map_.next_value::<::std::option::Option<_>>()?.map(file_type_writer_options::FileType::CsvOptions) ; } } diff --git a/datafusion/proto/src/generated/prost.rs b/datafusion/proto/src/generated/prost.rs index 459d5a965cd3..32e892e663ef 100644 --- a/datafusion/proto/src/generated/prost.rs +++ b/datafusion/proto/src/generated/prost.rs @@ -1642,7 +1642,7 @@ pub struct PartitionColumn { #[allow(clippy::derive_partial_eq_without_eq)] #[derive(Clone, PartialEq, ::prost::Message)] pub struct FileTypeWriterOptions { - #[prost(oneof = "file_type_writer_options::FileType", tags = "1, 2")] + #[prost(oneof = "file_type_writer_options::FileType", tags = "1, 2, 3")] pub file_type: ::core::option::Option, } /// Nested message and enum types in `FileTypeWriterOptions`. @@ -1654,6 +1654,8 @@ pub mod file_type_writer_options { JsonOptions(super::JsonWriterOptions), #[prost(message, tag = "2")] ParquetOptions(super::ParquetWriterOptions), + #[prost(message, tag = "3")] + CsvOptions(super::CsvWriterOptions), } } #[allow(clippy::derive_partial_eq_without_eq)] @@ -1670,6 +1672,31 @@ pub struct ParquetWriterOptions { } #[allow(clippy::derive_partial_eq_without_eq)] #[derive(Clone, PartialEq, ::prost::Message)] +pub struct CsvWriterOptions { + /// Optional column delimiter. Defaults to `b','` + #[prost(string, tag = "1")] + pub delimiter: ::prost::alloc::string::String, + /// Whether to write column names as file headers. Defaults to `true` + #[prost(bool, tag = "2")] + pub has_header: bool, + /// Optional date format for date arrays + #[prost(string, tag = "3")] + pub date_format: ::prost::alloc::string::String, + /// Optional datetime format for datetime arrays + #[prost(string, tag = "4")] + pub datetime_format: ::prost::alloc::string::String, + /// Optional timestamp format for timestamp arrays + #[prost(string, tag = "5")] + pub timestamp_format: ::prost::alloc::string::String, + /// Optional time format for time arrays + #[prost(string, tag = "6")] + pub time_format: ::prost::alloc::string::String, + /// Optional value to represent null + #[prost(string, tag = "7")] + pub null_value: ::prost::alloc::string::String, +} +#[allow(clippy::derive_partial_eq_without_eq)] +#[derive(Clone, PartialEq, ::prost::Message)] pub struct WriterProperties { #[prost(uint64, tag = "1")] pub data_page_size_limit: u64, diff --git a/datafusion/proto/src/logical_plan/mod.rs b/datafusion/proto/src/logical_plan/mod.rs index d137a41fa19b..e997bcde426e 100644 --- a/datafusion/proto/src/logical_plan/mod.rs +++ b/datafusion/proto/src/logical_plan/mod.rs @@ -15,6 +15,7 @@ // specific language governing permissions and limitations // under the License. +use arrow::csv::WriterBuilder; use std::collections::HashMap; use std::fmt::Debug; use std::str::FromStr; @@ -64,6 +65,7 @@ use datafusion_expr::{ }; use datafusion::parquet::file::properties::{WriterProperties, WriterVersion}; +use datafusion_common::file_options::csv_writer::CsvWriterOptions; use datafusion_common::file_options::parquet_writer::ParquetWriterOptions; use datafusion_expr::dml::CopyOptions; use prost::bytes::BufMut; @@ -846,6 +848,20 @@ impl AsLogicalPlan for LogicalPlanNode { Some(copy_to_node::CopyOptions::WriterOptions(opt)) => { match &opt.file_type { Some(ft) => match ft { + file_type_writer_options::FileType::CsvOptions( + writer_options, + ) => { + let writer_builder = + csv_writer_options_from_proto(writer_options)?; + CopyOptions::WriterOptions(Box::new( + FileTypeWriterOptions::CSV( + CsvWriterOptions::new( + writer_builder, + CompressionTypeVariant::UNCOMPRESSED, + ), + ), + )) + } file_type_writer_options::FileType::ParquetOptions( writer_options, ) => { @@ -1630,6 +1646,40 @@ impl AsLogicalPlan for LogicalPlanNode { } CopyOptions::WriterOptions(opt) => { match opt.as_ref() { + FileTypeWriterOptions::CSV(csv_opts) => { + let csv_options = &csv_opts.writer_options; + let csv_writer_options = protobuf::CsvWriterOptions { + delimiter: (csv_options.delimiter() as char) + .to_string(), + has_header: csv_options.header(), + date_format: csv_options + .date_format() + .unwrap_or("") + .to_owned(), + datetime_format: csv_options + .datetime_format() + .unwrap_or("") + .to_owned(), + timestamp_format: csv_options + .timestamp_format() + .unwrap_or("") + .to_owned(), + time_format: csv_options + .time_format() + .unwrap_or("") + .to_owned(), + null_value: csv_options.null().to_owned(), + }; + let csv_options = + file_type_writer_options::FileType::CsvOptions( + csv_writer_options, + ); + Some(copy_to_node::CopyOptions::WriterOptions( + protobuf::FileTypeWriterOptions { + file_type: Some(csv_options), + }, + )) + } FileTypeWriterOptions::Parquet(parquet_opts) => { let parquet_writer_options = protobuf::ParquetWriterOptions { @@ -1674,6 +1724,30 @@ impl AsLogicalPlan for LogicalPlanNode { } } +pub(crate) fn csv_writer_options_from_proto( + writer_options: &protobuf::CsvWriterOptions, +) -> Result { + let mut builder = WriterBuilder::new(); + if !writer_options.delimiter.is_empty() { + if let Some(delimiter) = writer_options.delimiter.chars().next() { + if delimiter.is_ascii() { + builder = builder.with_delimiter(delimiter as u8); + } else { + return Err(proto_error("CSV Delimiter is not ASCII")); + } + } else { + return Err(proto_error("Error parsing CSV Delimiter")); + } + } + Ok(builder + .with_header(writer_options.has_header) + .with_date_format(writer_options.date_format.clone()) + .with_datetime_format(writer_options.datetime_format.clone()) + .with_timestamp_format(writer_options.timestamp_format.clone()) + .with_time_format(writer_options.time_format.clone()) + .with_null(writer_options.null_value.clone())) +} + pub(crate) fn writer_properties_to_proto( props: &WriterProperties, ) -> protobuf::WriterProperties { diff --git a/datafusion/proto/src/physical_plan/from_proto.rs b/datafusion/proto/src/physical_plan/from_proto.rs index 824eb60a5715..6f1e811510c6 100644 --- a/datafusion/proto/src/physical_plan/from_proto.rs +++ b/datafusion/proto/src/physical_plan/from_proto.rs @@ -39,6 +39,7 @@ use datafusion::physical_plan::windows::create_window_expr; use datafusion::physical_plan::{ functions, ColumnStatistics, Partitioning, PhysicalExpr, Statistics, WindowExpr, }; +use datafusion_common::file_options::csv_writer::CsvWriterOptions; use datafusion_common::file_options::json_writer::JsonWriterOptions; use datafusion_common::file_options::parquet_writer::ParquetWriterOptions; use datafusion_common::parsers::CompressionTypeVariant; @@ -53,7 +54,7 @@ use crate::logical_plan; use crate::protobuf; use crate::protobuf::physical_expr_node::ExprType; -use crate::logical_plan::writer_properties_from_proto; +use crate::logical_plan::{csv_writer_options_from_proto, writer_properties_from_proto}; use chrono::{TimeZone, Utc}; use object_store::path::Path; use object_store::ObjectMeta; @@ -766,11 +767,18 @@ impl TryFrom<&protobuf::FileTypeWriterOptions> for FileTypeWriterOptions { let file_type = value .file_type .as_ref() - .ok_or_else(|| proto_error("Missing required field in protobuf"))?; + .ok_or_else(|| proto_error("Missing required file_type field in protobuf"))?; match file_type { protobuf::file_type_writer_options::FileType::JsonOptions(opts) => Ok( Self::JSON(JsonWriterOptions::new(opts.compression().into())), ), + protobuf::file_type_writer_options::FileType::CsvOptions(opt) => { + let write_options = csv_writer_options_from_proto(opt)?; + Ok(Self::CSV(CsvWriterOptions::new( + write_options, + CompressionTypeVariant::UNCOMPRESSED, + ))) + } protobuf::file_type_writer_options::FileType::ParquetOptions(opt) => { let props = opt.writer_properties.clone().unwrap_or_default(); let writer_properties = writer_properties_from_proto(&props)?; diff --git a/datafusion/proto/tests/cases/roundtrip_logical_plan.rs b/datafusion/proto/tests/cases/roundtrip_logical_plan.rs index 3eeae01a643e..2d7d85abda96 100644 --- a/datafusion/proto/tests/cases/roundtrip_logical_plan.rs +++ b/datafusion/proto/tests/cases/roundtrip_logical_plan.rs @@ -20,6 +20,7 @@ use std::fmt::{self, Debug, Formatter}; use std::sync::Arc; use arrow::array::{ArrayRef, FixedSizeListArray}; +use arrow::csv::WriterBuilder; use arrow::datatypes::{ DataType, Field, Fields, Int32Type, IntervalDayTimeType, IntervalMonthDayNanoType, IntervalUnit, Schema, SchemaRef, TimeUnit, UnionFields, UnionMode, @@ -35,8 +36,10 @@ use datafusion::parquet::file::properties::{WriterProperties, WriterVersion}; use datafusion::physical_plan::functions::make_scalar_function; use datafusion::prelude::{create_udf, CsvReadOptions, SessionConfig, SessionContext}; use datafusion::test_util::{TestTableFactory, TestTableProvider}; +use datafusion_common::file_options::csv_writer::CsvWriterOptions; use datafusion_common::file_options::parquet_writer::ParquetWriterOptions; use datafusion_common::file_options::StatementOptions; +use datafusion_common::parsers::CompressionTypeVariant; use datafusion_common::{internal_err, not_impl_err, plan_err, FileTypeWriterOptions}; use datafusion_common::{DFField, DFSchema, DFSchemaRef, DataFusionError, ScalarValue}; use datafusion_common::{FileType, Result}; @@ -386,10 +389,69 @@ async fn roundtrip_logical_plan_copy_to_writer_options() -> Result<()> { } _ => panic!(), } - Ok(()) } +#[tokio::test] +async fn roundtrip_logical_plan_copy_to_csv() -> Result<()> { + let ctx = SessionContext::new(); + + let input = create_csv_scan(&ctx).await?; + + let writer_properties = WriterBuilder::new() + .with_delimiter(b'*') + .with_date_format("dd/MM/yyyy".to_string()) + .with_datetime_format("dd/MM/yyyy HH:mm:ss".to_string()) + .with_timestamp_format("HH:mm:ss.SSSSSS".to_string()) + .with_time_format("HH:mm:ss".to_string()) + .with_null("NIL".to_string()); + + let plan = LogicalPlan::Copy(CopyTo { + input: Arc::new(input), + output_url: "test.csv".to_string(), + file_format: FileType::CSV, + single_file_output: true, + copy_options: CopyOptions::WriterOptions(Box::new(FileTypeWriterOptions::CSV( + CsvWriterOptions::new( + writer_properties, + CompressionTypeVariant::UNCOMPRESSED, + ), + ))), + }); + + let bytes = logical_plan_to_bytes(&plan)?; + let logical_round_trip = logical_plan_from_bytes(&bytes, &ctx)?; + assert_eq!(format!("{plan:?}"), format!("{logical_round_trip:?}")); + + match logical_round_trip { + LogicalPlan::Copy(copy_to) => { + assert_eq!("test.csv", copy_to.output_url); + assert_eq!(FileType::CSV, copy_to.file_format); + assert!(copy_to.single_file_output); + match ©_to.copy_options { + CopyOptions::WriterOptions(y) => match y.as_ref() { + FileTypeWriterOptions::CSV(p) => { + let props = &p.writer_options; + assert_eq!(b'*', props.delimiter()); + assert_eq!("dd/MM/yyyy", props.date_format().unwrap()); + assert_eq!( + "dd/MM/yyyy HH:mm:ss", + props.datetime_format().unwrap() + ); + assert_eq!("HH:mm:ss.SSSSSS", props.timestamp_format().unwrap()); + assert_eq!("HH:mm:ss", props.time_format().unwrap()); + assert_eq!("NIL", props.null()); + } + _ => panic!(), + }, + _ => panic!(), + } + } + _ => panic!(), + } + + Ok(()) +} async fn create_csv_scan(ctx: &SessionContext) -> Result { ctx.register_csv("t1", "tests/testdata/test.csv", CsvReadOptions::default()) .await?;