diff --git a/datafusion-cli/src/print_format.rs b/datafusion-cli/src/print_format.rs
index ea418562495d..0a8c7b4b3e3a 100644
--- a/datafusion-cli/src/print_format.rs
+++ b/datafusion-cli/src/print_format.rs
@@ -161,23 +161,29 @@ impl PrintFormat {
         maxrows: MaxRows,
         with_header: bool,
     ) -> Result<()> {
-        if batches.is_empty() || batches[0].num_rows() == 0 {
+        // filter out any empty batches
+        let batches: Vec<_> = batches
+            .iter()
+            .filter(|b| b.num_rows() > 0)
+            .cloned()
+            .collect();
+        if batches.is_empty() {
             return Ok(());
         }
 
         match self {
             Self::Csv | Self::Automatic => {
-                print_batches_with_sep(writer, batches, b',', with_header)
+                print_batches_with_sep(writer, &batches, b',', with_header)
             }
-            Self::Tsv => print_batches_with_sep(writer, batches, b'\t', with_header),
+            Self::Tsv => print_batches_with_sep(writer, &batches, b'\t', with_header),
             Self::Table => {
                 if maxrows == MaxRows::Limited(0) {
                     return Ok(());
                 }
-                format_batches_with_maxrows(writer, batches, maxrows)
+                format_batches_with_maxrows(writer, &batches, maxrows)
             }
-            Self::Json => batches_to_json!(ArrayWriter, writer, batches),
-            Self::NdJson => batches_to_json!(LineDelimitedWriter, writer, batches),
+            Self::Json => batches_to_json!(ArrayWriter, writer, &batches),
+            Self::NdJson => batches_to_json!(LineDelimitedWriter, writer, &batches),
         }
     }
 }
@@ -189,7 +195,7 @@ mod tests {
 
     use super::*;
 
-    use arrow::array::Int32Array;
+    use arrow::array::{ArrayRef, Int32Array};
     use arrow::datatypes::{DataType, Field, Schema};
     use datafusion::error::Result;
 
@@ -351,4 +357,115 @@ mod tests {
 
         Ok(())
     }
+
+    #[test]
+    fn test_print_batches_empty_batches() -> Result<()> {
+        let batch = one_column_batch();
+        let empty_batch = RecordBatch::new_empty(batch.schema());
+
+        // empty batches before and after the data should be ignored
+        #[rustfmt::skip]
+        let expected =&[
+            "+---+",
+            "| a |",
+            "+---+",
+            "| 1 |",
+            "| 2 |",
+            "| 3 |",
+            "+---+\n",
+        ];
+
+        PrintBatchesTest::new()
+            .with_format(PrintFormat::Table)
+            .with_batches(vec![empty_batch.clone(), batch, empty_batch])
+            .with_expected(expected)
+            .run();
+        Ok(())
+    }
+
+    #[test]
+    fn test_print_batches_empty_batches_no_header() -> Result<()> {
+        let empty_batch = RecordBatch::new_empty(one_column_batch().schema());
+
+        // empty batches should not print a header
+        let expected = &[""];
+
+        PrintBatchesTest::new()
+            .with_format(PrintFormat::Table)
+            .with_batches(vec![empty_batch])
+            .with_header(true)
+            .with_expected(expected)
+            .run();
+        Ok(())
+    }
+
+    /// Test harness that calls `PrintFormat::print_batches` on an in-memory
+    /// buffer and compares the output with the expected lines joined by `\n`
+    struct PrintBatchesTest {
+        format: PrintFormat,
+        batches: Vec<RecordBatch>,
+        maxrows: MaxRows,
+        with_header: bool,
+        expected: Vec<&'static str>,
+    }
+
+    impl PrintBatchesTest {
+        /// create a test: Table format, no batches, unlimited rows, no header, no expected output
+        fn new() -> Self {
+            Self {
+                format: PrintFormat::Table,
+                batches: vec![],
+                maxrows: MaxRows::Unlimited,
+                with_header: false,
+                expected: vec![],
+            }
+        }
+
+        /// set the format
+        fn with_format(mut self, format: PrintFormat) -> Self {
+            self.format = format;
+            self
+        }
+
+        /// set the batches to convert
+        fn with_batches(mut self, batches: Vec<RecordBatch>) -> Self {
+            self.batches = batches;
+            self
+        }
+
+        /// set whether to include a header
+        fn with_header(mut self, with_header: bool) -> Self {
+            self.with_header = with_header;
+            self
+        }
+
+        /// set expected output
+        fn with_expected(mut self, expected: &[&'static str]) -> Self {
+            self.expected = expected.to_vec();
+            self
+        }
+
+        /// run the test
+        fn run(self) {
+            let mut buffer: Vec<u8> = vec![];
+            self.format
+                .print_batches(&mut buffer, &self.batches, self.maxrows, self.with_header)
+                .unwrap();
+            let actual = String::from_utf8(buffer).unwrap();
+            let expected = self.expected.join("\n");
+            assert_eq!(
+                actual, expected,
+                "actual:\n\n{actual}expected:\n\n{expected}"
+            );
+        }
+    }
+
+    /// return a batch with one column and three rows
+    fn one_column_batch() -> RecordBatch {
+        RecordBatch::try_from_iter(vec![(
+            "a",
+            Arc::new(Int32Array::from(vec![1, 2, 3])) as ArrayRef,
+        )])
+        .unwrap()
+    }
 }