Incorporate review suggestions
ozankabak committed Dec 28, 2023
1 parent 15f9fb6 commit 225253b
Showing 5 changed files with 14 additions and 32 deletions.
16 changes: 5 additions & 11 deletions datafusion/core/src/datasource/file_format/csv.rs
@@ -425,21 +425,15 @@ impl CsvSerializationSchema {
 
 #[async_trait]
 impl SerializationSchema for CsvSerializationSchema {
-    async fn serialize(&self, batch: RecordBatch) -> Result<Bytes> {
+    async fn serialize(&self, batch: RecordBatch, initial: bool) -> Result<Bytes> {
         let mut buffer = Vec::with_capacity(4096);
         let builder = self.builder.clone();
-        let mut writer = builder.with_header(self.header).build(&mut buffer);
+        let header = self.header && initial;
+        let mut writer = builder.with_header(header).build(&mut buffer);
         writer.write(&batch)?;
         drop(writer);
         Ok(Bytes::from(buffer))
     }
-
-    fn duplicate_headerless(&self) -> Arc<dyn SerializationSchema> {
-        let new_self = CsvSerializationSchema::new()
-            .with_builder(self.builder.clone())
-            .with_header(false);
-        Arc::new(new_self) as _
-    }
 }
 
 /// Implements [`DataSink`] for writing to a CSV file.
@@ -828,7 +822,7 @@ mod tests {
             .await?;
         let batch = concat_batches(&batches[0].schema(), &batches)?;
         let serializer = CsvSerializationSchema::new();
-        let bytes = serializer.serialize(batch).await?;
+        let bytes = serializer.serialize(batch, true).await?;
         assert_eq!(
             "c2,c3\n2,1\n5,-40\n1,29\n1,-85\n5,-82\n4,-111\n3,104\n3,13\n1,38\n4,-38\n",
             String::from_utf8(bytes.into()).unwrap()
@@ -852,7 +846,7 @@
             .await?;
         let batch = concat_batches(&batches[0].schema(), &batches)?;
         let serializer = CsvSerializationSchema::new().with_header(false);
-        let bytes = serializer.serialize(batch).await?;
+        let bytes = serializer.serialize(batch, true).await?;
         assert_eq!(
             "2,1\n5,-40\n1,29\n1,-85\n5,-82\n4,-111\n3,104\n3,13\n1,38\n4,-38\n",
             String::from_utf8(bytes.into()).unwrap()
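The two tests above only exercise the first-batch path (initial = true). A minimal sketch of the header-once behavior across a multi-batch write, assuming the API shown in this diff; batch1 and batch2 are hypothetical RecordBatch values:

// Hypothetical two-batch write: the caller, not the serializer, now decides
// which batch is first, so the header is emitted exactly once per file.
let serializer = CsvSerializationSchema::new(); // header is on by default (see first test)
let first = serializer.serialize(batch1, true).await?;  // "c2,c3\n..." header + rows
let rest = serializer.serialize(batch2, false).await?;  // rows only, no header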
6 changes: 1 addition & 5 deletions datafusion/core/src/datasource/file_format/json.rs
@@ -206,16 +206,12 @@ impl JsonSerializationSchema {
 
 #[async_trait]
 impl SerializationSchema for JsonSerializationSchema {
-    async fn serialize(&self, batch: RecordBatch) -> Result<Bytes> {
+    async fn serialize(&self, batch: RecordBatch, _initial: bool) -> Result<Bytes> {
         let mut buffer = Vec::with_capacity(4096);
         let mut writer = json::LineDelimitedWriter::new(&mut buffer);
         writer.write(&batch)?;
         Ok(Bytes::from(buffer))
     }
-
-    fn duplicate_headerless(&self) -> Arc<dyn SerializationSchema> {
-        Arc::new(JsonSerializationSchema::new()) as _
-    }
 }
 
 /// Implements [`DataSink`] for writing to a Json file.
9 changes: 3 additions & 6 deletions datafusion/core/src/datasource/file_format/write/mod.rs
@@ -147,12 +147,9 @@ impl<W: AsyncWrite + Unpin + Send> AsyncWrite for AbortableWrite<W> {
 #[async_trait]
 pub trait SerializationSchema: Sync + Send {
     /// Asynchronously serializes a `RecordBatch` and returns the serialized bytes.
-    async fn serialize(&self, batch: RecordBatch) -> Result<Bytes>;
-
-    /// Duplicates itself (sans header configuration) to support serializing
-    /// multiple batches in parallel on multiple cores. Unless we are serializing
-    /// a CSV file, this method is no-op.
-    fn duplicate_headerless(&self) -> Arc<dyn SerializationSchema>;
+    /// Parameter `initial` signals whether the given batch is the first batch.
+    /// This distinction is important for certain serializers (like CSV).
+    async fn serialize(&self, batch: RecordBatch, initial: bool) -> Result<Bytes>;
 }
 
 /// Returns an [`AbortableWrite`] which writes to the given object store location
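For illustration, a minimal sketch of an implementer of the revised trait. Everything here except the trait signature is invented: MarkerSerializationSchema and its "# begin" preamble are hypothetical stand-ins for the CSV-header pattern, and the sketch assumes the same imports as the modules above (it reuses json::LineDelimitedWriter exactly as the Json schema does).

// Hypothetical implementer: emits a marker line exactly once per file by
// gating it on `initial`, instead of carrying mutable "already wrote the
// header" state inside the serializer.
struct MarkerSerializationSchema;

#[async_trait]
impl SerializationSchema for MarkerSerializationSchema {
    async fn serialize(&self, batch: RecordBatch, initial: bool) -> Result<Bytes> {
        let mut buffer = Vec::with_capacity(4096);
        if initial {
            buffer.extend_from_slice(b"# begin\n");
        }
        let mut writer = json::LineDelimitedWriter::new(&mut buffer);
        writer.write(&batch)?;
        Ok(Bytes::from(buffer))
    }
}

Because serialize takes &self and the first-batch distinction arrives as an argument, one Arc<dyn SerializationSchema> can serve many concurrent serialization tasks without the cloning dance that duplicate_headerless existed to support.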
9 changes: 4 additions & 5 deletions
@@ -47,24 +47,23 @@ type SerializerType = Arc<dyn SerializationSchema>;
 /// so that the caller may handle aborting failed writes.
 pub(crate) async fn serialize_rb_stream_to_object_store(
     mut data_rx: Receiver<RecordBatch>,
-    mut serializer: Arc<dyn SerializationSchema>,
+    serializer: Arc<dyn SerializationSchema>,
     mut writer: AbortableWrite<Box<dyn AsyncWrite + Send + Unpin>>,
 ) -> std::result::Result<(WriterType, u64), (WriterType, DataFusionError)> {
     let (tx, mut rx) =
         mpsc::channel::<JoinHandle<Result<(usize, Bytes), DataFusionError>>>(100);
-    // Initially, has_header can be true for CSV use cases. Then, we must turn
-    // it off to maintain the integrity of the writing process.
     let serialize_task = tokio::spawn(async move {
+        // Some serializers (like CSV) handle the first batch differently than
+        // subsequent batches, so we track that here.
         let mut initial = true;
         while let Some(batch) = data_rx.recv().await {
             let serializer_clone = serializer.clone();
             let handle = tokio::spawn(async move {
                 let num_rows = batch.num_rows();
-                let bytes = serializer_clone.serialize(batch).await?;
+                let bytes = serializer_clone.serialize(batch, initial).await?;
                 Ok((num_rows, bytes))
             });
             if initial {
-                serializer = serializer.duplicate_headerless();
                 initial = false;
             }
             tx.send(handle).await.map_err(|_| {
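A note on why serializer no longer needs to be mut: initial is a bool, and bool is Copy, so each async move block captures its own private copy of the flag, while the serializer stays behind an immutable Arc and is never swapped out. A simplified model of the loop body above, with spawn_one a hypothetical helper named only for this sketch:

// `initial` is Copy, so the spawned task receives a private copy; the caller
// may flip its own flag to false for later batches without affecting tasks
// already in flight. The Arc is cloned per batch and never mutated.
fn spawn_one(
    serializer: Arc<dyn SerializationSchema>,
    batch: RecordBatch,
    initial: bool,
) -> JoinHandle<Result<Bytes>> {
    tokio::spawn(async move { serializer.serialize(batch, initial).await })
}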
6 changes: 1 addition & 5 deletions datafusion/core/src/datasource/physical_plan/file_stream.rs
@@ -991,12 +991,8 @@ mod tests {
 
     #[async_trait]
     impl SerializationSchema for TestSerializer {
-        async fn serialize(&self, _batch: RecordBatch) -> Result<Bytes> {
+        async fn serialize(&self, _batch: RecordBatch, _initial: bool) -> Result<Bytes> {
             Ok(self.bytes.clone())
         }
-
-        fn duplicate_headerless(&self) -> Arc<dyn SerializationSchema> {
-            unimplemented!()
-        }
     }
 }
