From 225253b5eaeb1fe9a78f2eed851cd4d5b9d74632 Mon Sep 17 00:00:00 2001 From: Mehmet Ozan Kabak Date: Thu, 28 Dec 2023 23:49:18 +0300 Subject: [PATCH] Incorporate review suggestions --- .../core/src/datasource/file_format/csv.rs | 16 +++++----------- .../core/src/datasource/file_format/json.rs | 6 +----- .../core/src/datasource/file_format/write/mod.rs | 9 +++------ .../file_format/write/orchestration.rs | 9 ++++----- .../src/datasource/physical_plan/file_stream.rs | 6 +----- 5 files changed, 14 insertions(+), 32 deletions(-) diff --git a/datafusion/core/src/datasource/file_format/csv.rs b/datafusion/core/src/datasource/file_format/csv.rs index 180f772920c9..963a30c1e56b 100644 --- a/datafusion/core/src/datasource/file_format/csv.rs +++ b/datafusion/core/src/datasource/file_format/csv.rs @@ -425,21 +425,15 @@ impl CsvSerializationSchema { #[async_trait] impl SerializationSchema for CsvSerializationSchema { - async fn serialize(&self, batch: RecordBatch) -> Result { + async fn serialize(&self, batch: RecordBatch, initial: bool) -> Result { let mut buffer = Vec::with_capacity(4096); let builder = self.builder.clone(); - let mut writer = builder.with_header(self.header).build(&mut buffer); + let header = self.header && initial; + let mut writer = builder.with_header(header).build(&mut buffer); writer.write(&batch)?; drop(writer); Ok(Bytes::from(buffer)) } - - fn duplicate_headerless(&self) -> Arc { - let new_self = CsvSerializationSchema::new() - .with_builder(self.builder.clone()) - .with_header(false); - Arc::new(new_self) as _ - } } /// Implements [`DataSink`] for writing to a CSV file. @@ -828,7 +822,7 @@ mod tests { .await?; let batch = concat_batches(&batches[0].schema(), &batches)?; let serializer = CsvSerializationSchema::new(); - let bytes = serializer.serialize(batch).await?; + let bytes = serializer.serialize(batch, true).await?; assert_eq!( "c2,c3\n2,1\n5,-40\n1,29\n1,-85\n5,-82\n4,-111\n3,104\n3,13\n1,38\n4,-38\n", String::from_utf8(bytes.into()).unwrap() @@ -852,7 +846,7 @@ mod tests { .await?; let batch = concat_batches(&batches[0].schema(), &batches)?; let serializer = CsvSerializationSchema::new().with_header(false); - let bytes = serializer.serialize(batch).await?; + let bytes = serializer.serialize(batch, true).await?; assert_eq!( "2,1\n5,-40\n1,29\n1,-85\n5,-82\n4,-111\n3,104\n3,13\n1,38\n4,-38\n", String::from_utf8(bytes.into()).unwrap() diff --git a/datafusion/core/src/datasource/file_format/json.rs b/datafusion/core/src/datasource/file_format/json.rs index 5f2f4e6b7afb..78211ff6be42 100644 --- a/datafusion/core/src/datasource/file_format/json.rs +++ b/datafusion/core/src/datasource/file_format/json.rs @@ -206,16 +206,12 @@ impl JsonSerializationSchema { #[async_trait] impl SerializationSchema for JsonSerializationSchema { - async fn serialize(&self, batch: RecordBatch) -> Result { + async fn serialize(&self, batch: RecordBatch, _initial: bool) -> Result { let mut buffer = Vec::with_capacity(4096); let mut writer = json::LineDelimitedWriter::new(&mut buffer); writer.write(&batch)?; Ok(Bytes::from(buffer)) } - - fn duplicate_headerless(&self) -> Arc { - Arc::new(JsonSerializationSchema::new()) as _ - } } /// Implements [`DataSink`] for writing to a Json file. diff --git a/datafusion/core/src/datasource/file_format/write/mod.rs b/datafusion/core/src/datasource/file_format/write/mod.rs index 11592cb5dfdb..fc88725b39c4 100644 --- a/datafusion/core/src/datasource/file_format/write/mod.rs +++ b/datafusion/core/src/datasource/file_format/write/mod.rs @@ -147,12 +147,9 @@ impl AsyncWrite for AbortableWrite { #[async_trait] pub trait SerializationSchema: Sync + Send { /// Asynchronously serializes a `RecordBatch` and returns the serialized bytes. - async fn serialize(&self, batch: RecordBatch) -> Result; - - /// Duplicates itself (sans header configuration) to support serializing - /// multiple batches in parallel on multiple cores. Unless we are serializing - /// a CSV file, this method is no-op. - fn duplicate_headerless(&self) -> Arc; + /// Parameter `initial` signals whether the given batch is the first batch. + /// This distinction is important for certain serializers (like CSV). + async fn serialize(&self, batch: RecordBatch, initial: bool) -> Result; } /// Returns an [`AbortableWrite`] which writes to the given object store location diff --git a/datafusion/core/src/datasource/file_format/write/orchestration.rs b/datafusion/core/src/datasource/file_format/write/orchestration.rs index aa8087e2b028..d32a2b0137af 100644 --- a/datafusion/core/src/datasource/file_format/write/orchestration.rs +++ b/datafusion/core/src/datasource/file_format/write/orchestration.rs @@ -47,24 +47,23 @@ type SerializerType = Arc; /// so that the caller may handle aborting failed writes. pub(crate) async fn serialize_rb_stream_to_object_store( mut data_rx: Receiver, - mut serializer: Arc, + serializer: Arc, mut writer: AbortableWrite>, ) -> std::result::Result<(WriterType, u64), (WriterType, DataFusionError)> { let (tx, mut rx) = mpsc::channel::>>(100); - // Initially, has_header can be true for CSV use cases. Then, we must turn - // it off to maintain the integrity of the writing process. let serialize_task = tokio::spawn(async move { + // Some serializers (like CSV) handle the first batch differently than + // subsequent batches, so we track that here. let mut initial = true; while let Some(batch) = data_rx.recv().await { let serializer_clone = serializer.clone(); let handle = tokio::spawn(async move { let num_rows = batch.num_rows(); - let bytes = serializer_clone.serialize(batch).await?; + let bytes = serializer_clone.serialize(batch, initial).await?; Ok((num_rows, bytes)) }); if initial { - serializer = serializer.duplicate_headerless(); initial = false; } tx.send(handle).await.map_err(|_| { diff --git a/datafusion/core/src/datasource/physical_plan/file_stream.rs b/datafusion/core/src/datasource/physical_plan/file_stream.rs index 5205f613f5b2..c266a795e335 100644 --- a/datafusion/core/src/datasource/physical_plan/file_stream.rs +++ b/datafusion/core/src/datasource/physical_plan/file_stream.rs @@ -991,12 +991,8 @@ mod tests { #[async_trait] impl SerializationSchema for TestSerializer { - async fn serialize(&self, _batch: RecordBatch) -> Result { + async fn serialize(&self, _batch: RecordBatch, _initial: bool) -> Result { Ok(self.bytes.clone()) } - - fn duplicate_headerless(&self) -> Arc { - unimplemented!() - } } }