From 2075cd125dc0c132be5cb9dbf65748abf52243f1 Mon Sep 17 00:00:00 2001
From: Vaibhav Rabber
Date: Wed, 13 Sep 2023 16:27:39 +0530
Subject: [PATCH] csv: Add option to specify custom null values (#4795)

* csv: Add option to specify custom null regex

Custom strings can be specified as `NULL` values for CSVs via a regular
expression. This allows reading CSV files that have placeholders for
NULL values instead of empty strings.

Fixes #4794

Signed-off-by: Vaibhav

* Apply suggestions from code review

---------

Signed-off-by: Vaibhav
Co-authored-by: Raphael Taylor-Davies <1781103+tustvold@users.noreply.github.com>
---
 arrow-csv/src/reader/mod.rs              | 203 +++++++++++++++++++----
 arrow-csv/test/data/custom_null_test.csv |   6 +
 2 files changed, 180 insertions(+), 29 deletions(-)
 create mode 100644 arrow-csv/test/data/custom_null_test.csv

diff --git a/arrow-csv/src/reader/mod.rs b/arrow-csv/src/reader/mod.rs
index 328c2cd41f3b..695e3d47965d 100644
--- a/arrow-csv/src/reader/mod.rs
+++ b/arrow-csv/src/reader/mod.rs
@@ -133,8 +133,8 @@ use arrow_schema::*;
 use chrono::{TimeZone, Utc};
 use csv::StringRecord;
 use lazy_static::lazy_static;
-use regex::RegexSet;
-use std::fmt;
+use regex::{Regex, RegexSet};
+use std::fmt::{self, Debug};
 use std::fs::File;
 use std::io::{BufRead, BufReader as StdBufReader, Read, Seek, SeekFrom};
 use std::sync::Arc;
@@ -157,6 +157,22 @@ lazy_static! {
     ]).unwrap();
 }
 
+/// A wrapper over `Option<Regex>` to check if the value is `NULL`.
+#[derive(Debug, Clone, Default)]
+struct NullRegex(Option<Regex>);
+
+impl NullRegex {
+    /// Returns true if the value should be considered as `NULL` according to
+    /// the provided regular expression.
+    #[inline]
+    fn is_null(&self, s: &str) -> bool {
+        match &self.0 {
+            Some(r) => r.is_match(s),
+            None => s.is_empty(),
+        }
+    }
+}
+
 #[derive(Default, Copy, Clone)]
 struct InferredDataType {
     /// Packed booleans indicating type
@@ -213,6 +229,7 @@ pub struct Format {
     escape: Option<u8>,
     quote: Option<u8>,
     terminator: Option<u8>,
+    null_regex: NullRegex,
 }
 
 impl Format {
@@ -241,6 +258,12 @@ impl Format {
         self
     }
 
+    /// Provide a regex to match null values, defaults to `^$`
+    pub fn with_null_regex(mut self, null_regex: Regex) -> Self {
+        self.null_regex = NullRegex(Some(null_regex));
+        self
+    }
+
     /// Infer schema of CSV records from the provided `reader`
     ///
     /// If `max_records` is `None`, all records will be read, otherwise up to `max_records`
@@ -287,7 +310,7 @@ impl Format {
                 column_types.iter_mut().enumerate().take(header_length)
             {
                 if let Some(string) = record.get(i) {
-                    if !string.is_empty() {
+                    if !self.null_regex.is_null(string) {
                         column_type.update(string)
                     }
                 }
@@ -557,6 +580,9 @@ pub struct Decoder {
 
     /// A decoder for [`StringRecords`]
     record_decoder: RecordDecoder,
+
+    /// Check if the string matches this pattern for `NULL`.
+    null_regex: NullRegex,
 }
 
 impl Decoder {
@@ -603,6 +629,7 @@ impl Decoder {
             Some(self.schema.metadata.clone()),
             self.projection.as_ref(),
             self.line_number,
+            &self.null_regex,
         )?;
         self.line_number += rows.len();
         Ok(Some(batch))
@@ -621,6 +648,7 @@ fn parse(
     metadata: Option<std::collections::HashMap<String, String>>,
     projection: Option<&Vec<usize>>,
     line_number: usize,
+    null_regex: &NullRegex,
 ) -> Result<RecordBatch, ArrowError> {
     let projection: Vec<usize> = match projection {
         Some(v) => v.clone(),
@@ -633,7 +661,9 @@ fn parse(
             let i = *i;
             let field = &fields[i];
             match field.data_type() {
-                DataType::Boolean => build_boolean_array(line_number, rows, i),
+                DataType::Boolean => {
+                    build_boolean_array(line_number, rows, i, null_regex)
+                }
                 DataType::Decimal128(precision, scale) => {
                     build_decimal_array::<Decimal128Type>(
                         line_number,
@@ -641,6 +671,7 @@ fn parse(
                         i,
                         *precision,
                         *scale,
+                        null_regex,
                     )
                 }
                 DataType::Decimal256(precision, scale) => {
@@ -650,53 +681,73 @@ fn parse(
                         i,
                         *precision,
                         *scale,
+                        null_regex,
                     )
                 }
-                DataType::Int8 => build_primitive_array::<Int8Type>(line_number, rows, i),
+                DataType::Int8 => {
+                    build_primitive_array::<Int8Type>(line_number, rows, i, null_regex)
+                }
                 DataType::Int16 => {
-                    build_primitive_array::<Int16Type>(line_number, rows, i)
+                    build_primitive_array::<Int16Type>(line_number, rows, i, null_regex)
                 }
                 DataType::Int32 => {
-                    build_primitive_array::<Int32Type>(line_number, rows, i)
+                    build_primitive_array::<Int32Type>(line_number, rows, i, null_regex)
                 }
                 DataType::Int64 => {
-                    build_primitive_array::<Int64Type>(line_number, rows, i)
+                    build_primitive_array::<Int64Type>(line_number, rows, i, null_regex)
                 }
                 DataType::UInt8 => {
-                    build_primitive_array::<UInt8Type>(line_number, rows, i)
+                    build_primitive_array::<UInt8Type>(line_number, rows, i, null_regex)
                 }
                 DataType::UInt16 => {
-                    build_primitive_array::<UInt16Type>(line_number, rows, i)
+                    build_primitive_array::<UInt16Type>(line_number, rows, i, null_regex)
                 }
                 DataType::UInt32 => {
-                    build_primitive_array::<UInt32Type>(line_number, rows, i)
+                    build_primitive_array::<UInt32Type>(line_number, rows, i, null_regex)
                 }
                 DataType::UInt64 => {
-                    build_primitive_array::<UInt64Type>(line_number, rows, i)
+                    build_primitive_array::<UInt64Type>(line_number, rows, i, null_regex)
                 }
                 DataType::Float32 => {
-                    build_primitive_array::<Float32Type>(line_number, rows, i)
+                    build_primitive_array::<Float32Type>(line_number, rows, i, null_regex)
                 }
                 DataType::Float64 => {
-                    build_primitive_array::<Float64Type>(line_number, rows, i)
+                    build_primitive_array::<Float64Type>(line_number, rows, i, null_regex)
                 }
                 DataType::Date32 => {
-                    build_primitive_array::<Date32Type>(line_number, rows, i)
+                    build_primitive_array::<Date32Type>(line_number, rows, i, null_regex)
                 }
                 DataType::Date64 => {
-                    build_primitive_array::<Date64Type>(line_number, rows, i)
-                }
-                DataType::Time32(TimeUnit::Second) => {
-                    build_primitive_array::<Time32SecondType>(line_number, rows, i)
+                    build_primitive_array::<Date64Type>(line_number, rows, i, null_regex)
                 }
+                DataType::Time32(TimeUnit::Second) => build_primitive_array::<
+                    Time32SecondType,
+                >(
+                    line_number, rows, i, null_regex
+                ),
                 DataType::Time32(TimeUnit::Millisecond) => {
-                    build_primitive_array::<Time32MillisecondType>(line_number, rows, i)
+                    build_primitive_array::<Time32MillisecondType>(
+                        line_number,
+                        rows,
+                        i,
+                        null_regex,
+                    )
                 }
                 DataType::Time64(TimeUnit::Microsecond) => {
-                    build_primitive_array::<Time64MicrosecondType>(line_number, rows, i)
+                    build_primitive_array::<Time64MicrosecondType>(
+                        line_number,
+                        rows,
+                        i,
+                        null_regex,
+                    )
                 }
                 DataType::Time64(TimeUnit::Nanosecond) => {
-                    build_primitive_array::<Time64NanosecondType>(line_number, rows, i)
+                    build_primitive_array::<Time64NanosecondType>(
+                        line_number,
+                        rows,
+                        i,
+                        null_regex,
+                    )
                 }
                 DataType::Timestamp(TimeUnit::Second, tz) => {
                     build_timestamp_array::<TimestampSecondType>(
@@ -704,6 +755,7 @@ fn parse(
                         rows,
                         i,
                         tz.as_deref(),
+                        null_regex,
                     )
                 }
                 DataType::Timestamp(TimeUnit::Millisecond, tz) => {
@@ -712,6 +764,7 @@ fn parse(
                         rows,
                         i,
                         tz.as_deref(),
+                        null_regex,
                     )
                 }
                 DataType::Timestamp(TimeUnit::Microsecond, tz) => {
@@ -720,6 +773,7 @@ fn parse(
                         rows,
                         i,
                         tz.as_deref(),
+                        null_regex,
                     )
                 }
                 DataType::Timestamp(TimeUnit::Nanosecond, tz) => {
@@ -728,6 +782,7 @@ fn parse(
                         rows,
                         i,
                         tz.as_deref(),
+                        null_regex,
                     )
                 }
                 DataType::Utf8 => Ok(Arc::new(
@@ -827,11 +882,12 @@ fn build_decimal_array(
     col_idx: usize,
     precision: u8,
     scale: i8,
+    null_regex: &NullRegex,
 ) -> Result<ArrayRef, ArrowError> {
     let mut decimal_builder = PrimitiveBuilder::<T>::with_capacity(rows.len());
     for row in rows.iter() {
         let s = row.get(col_idx);
-        if s.is_empty() {
+        if null_regex.is_null(s) {
             // append null
             decimal_builder.append_null();
         } else {
@@ -859,12 +915,13 @@ fn build_primitive_array(
     line_number: usize,
     rows: &StringRecords<'_>,
     col_idx: usize,
+    null_regex: &NullRegex,
 ) -> Result<ArrayRef, ArrowError> {
     rows.iter()
         .enumerate()
         .map(|(row_index, row)| {
             let s = row.get(col_idx);
-            if s.is_empty() {
+            if null_regex.is_null(s) {
                 return Ok(None);
             }
 
@@ -888,14 +945,27 @@ fn build_timestamp_array(
     rows: &StringRecords<'_>,
     col_idx: usize,
     timezone: Option<&str>,
+    null_regex: &NullRegex,
 ) -> Result<ArrayRef, ArrowError> {
     Ok(Arc::new(match timezone {
         Some(timezone) => {
             let tz: Tz = timezone.parse()?;
-            build_timestamp_array_impl::<T>(line_number, rows, col_idx, &tz)?
-                .with_timezone(timezone)
+            build_timestamp_array_impl::<T>(
+                line_number,
+                rows,
+                col_idx,
+                &tz,
+                null_regex,
+            )?
+            .with_timezone(timezone)
         }
-        None => build_timestamp_array_impl::<T>(line_number, rows, col_idx, &Utc)?,
+        None => build_timestamp_array_impl::<T>(
+            line_number,
+            rows,
+            col_idx,
+            &Utc,
+            null_regex,
+        )?,
     }))
 }
 
@@ -904,12 +974,13 @@ fn build_timestamp_array_impl(
     rows: &StringRecords<'_>,
     col_idx: usize,
     timezone: &Tz,
+    null_regex: &NullRegex,
 ) -> Result<PrimitiveArray<T>, ArrowError> {
     rows.iter()
         .enumerate()
         .map(|(row_index, row)| {
             let s = row.get(col_idx);
-            if s.is_empty() {
+            if null_regex.is_null(s) {
                 return Ok(None);
             }
 
@@ -936,12 +1007,13 @@ fn build_boolean_array(
     line_number: usize,
     rows: &StringRecords<'_>,
     col_idx: usize,
+    null_regex: &NullRegex,
 ) -> Result<ArrayRef, ArrowError> {
     rows.iter()
         .enumerate()
         .map(|(row_index, row)| {
             let s = row.get(col_idx);
-            if s.is_empty() {
+            if null_regex.is_null(s) {
                 return Ok(None);
             }
             let parsed = parse_bool(s);
@@ -1042,6 +1114,12 @@ impl ReaderBuilder {
         self
     }
 
+    /// Provide a regex to match null values, defaults to `^$`
+    pub fn with_null_regex(mut self, null_regex: Regex) -> Self {
+        self.format.null_regex = NullRegex(Some(null_regex));
+        self
+    }
+
     /// Set the batch size (number of records to load at one time)
     pub fn with_batch_size(mut self, batch_size: usize) -> Self {
         self.batch_size = batch_size;
@@ -1100,6 +1178,7 @@ impl ReaderBuilder {
             end,
             projection: self.projection,
             batch_size: self.batch_size,
+            null_regex: self.format.null_regex,
         }
     }
 }
@@ -1426,6 +1505,36 @@ mod tests {
         assert!(!batch.column(1).is_null(4));
     }
 
+    #[test]
+    fn test_custom_nulls() {
+        let schema = Arc::new(Schema::new(vec![
+            Field::new("c_int", DataType::UInt64, true),
+            Field::new("c_float", DataType::Float32, true),
+            Field::new("c_string", DataType::Utf8, true),
+            Field::new("c_bool", DataType::Boolean, true),
+        ]));
+
+        let file = File::open("test/data/custom_null_test.csv").unwrap();
+
+        let null_regex = Regex::new("^nil$").unwrap();
+
+        let mut csv = ReaderBuilder::new(schema)
+            .has_header(true)
+            .with_null_regex(null_regex)
+            .build(file)
+            .unwrap();
+
+        let batch = csv.next().unwrap().unwrap();
+
+        // "nil"s should be NULL
+        assert!(batch.column(0).is_null(1));
+        assert!(batch.column(1).is_null(2));
+        assert!(batch.column(3).is_null(4));
+        // String won't be empty
+        assert!(!batch.column(2).is_null(3));
+        assert!(!batch.column(2).is_null(4));
+    }
+
     #[test]
     fn test_nulls_with_inference() {
         let mut file = File::open("test/data/various_types.csv").unwrap();
@@ -1485,6 +1594,42 @@ mod tests {
         assert!(!batch.column(1).is_null(4));
     }
 
+    #[test]
+    fn test_custom_nulls_with_inference() {
+        let mut file = File::open("test/data/custom_null_test.csv").unwrap();
+
+        let null_regex = Regex::new("^nil$").unwrap();
+
+        let format = Format::default()
+            .with_header(true)
+            .with_null_regex(null_regex);
+
+        let (schema, _) = format.infer_schema(&mut file, None).unwrap();
+        file.rewind().unwrap();
+
+        let expected_schema = Schema::new(vec![
+            Field::new("c_int", DataType::Int64, true),
+            Field::new("c_float", DataType::Float64, true),
+            Field::new("c_string", DataType::Utf8, true),
+            Field::new("c_bool", DataType::Boolean, true),
+        ]);
+
+        assert_eq!(schema, expected_schema);
+
+        let builder = ReaderBuilder::new(Arc::new(schema))
+            .with_format(format)
+            .with_batch_size(512)
+            .with_projection(vec![0, 1, 2, 3]);
+
+        let mut csv = builder.build(file).unwrap();
+        let batch = csv.next().unwrap().unwrap();
+
+        assert_eq!(5, batch.num_rows());
+        assert_eq!(4, batch.num_columns());
+
+        assert_eq!(batch.schema().as_ref(), &expected_schema);
+    }
+
     #[test]
     fn test_parse_invalid_csv() {
         let file = File::open("test/data/various_types_invalid.csv").unwrap();
diff --git a/arrow-csv/test/data/custom_null_test.csv b/arrow-csv/test/data/custom_null_test.csv
new file mode 100644
index 000000000000..39f9fc4b3eff
--- /dev/null
+++ b/arrow-csv/test/data/custom_null_test.csv
@@ -0,0 +1,6 @@
+c_int,c_float,c_string,c_bool
+1,1.1,"1.11",True
+nil,2.2,"2.22",TRUE
+3,nil,"3.33",true
+4,4.4,nil,False
+5,6.6,"",nil
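
Editor's note: for reference, the sketch below shows how the option added by this
patch is intended to be used from application code. It mirrors the new
`test_custom_nulls` test; the file path, schema, and the "^nil$" placeholder
pattern are illustrative, and the `main`/error-handling scaffolding is not part
of the patch. Only `ReaderBuilder::with_null_regex` (and `Format::with_null_regex`)
are new here; the other builder calls are pre-existing arrow-csv APIs.

    use std::fs::File;
    use std::sync::Arc;

    use arrow_array::Array;
    use arrow_csv::ReaderBuilder;
    use arrow_schema::{DataType, Field, Schema};
    use regex::Regex;

    fn main() -> Result<(), Box<dyn std::error::Error>> {
        let schema = Arc::new(Schema::new(vec![
            Field::new("c_int", DataType::UInt64, true),
            Field::new("c_float", DataType::Float32, true),
            Field::new("c_string", DataType::Utf8, true),
            Field::new("c_bool", DataType::Boolean, true),
        ]));

        // Treat the literal string "nil" as NULL instead of the default
        // empty-string behaviour.
        let null_regex = Regex::new("^nil$")?;

        let file = File::open("arrow-csv/test/data/custom_null_test.csv")?;
        let mut reader = ReaderBuilder::new(schema)
            .has_header(true)
            .with_null_regex(null_regex)
            .build(file)?;

        let batch = reader.next().unwrap()?;
        // The "nil" in the second row of c_int is decoded as a null value.
        assert!(batch.column(0).is_null(1));
        Ok(())
    }

Schema inference honours the same pattern when it is set through
`Format::with_null_regex`, as exercised by `test_custom_nulls_with_inference` above.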