Skip to content

Commit

Permalink
Use regex instead of hash-maps
Browse files (browse the repository at this point in the history)
  • Loading branch information
vrongmeal committed Sep 11, 2023
1 parent ed47efe commit 4c11b3b
Show file tree
Hide file tree
Showing 2 changed files with 75 additions and 66 deletions.
139 changes: 74 additions & 65 deletions arrow-csv/src/reader/mod.rs
Original file line number Diff line number Diff line change
Expand Up @@ -133,8 +133,7 @@ use arrow_schema::*;
use chrono::{TimeZone, Utc};
use csv::StringRecord;
use lazy_static::lazy_static;
use regex::RegexSet;
use std::collections::HashSet;
use regex::{Regex, RegexSet};
use std::fmt::{self, Debug};
use std::fs::File;
use std::io::{BufRead, BufReader as StdBufReader, Read, Seek, SeekFrom};
Expand Down Expand Up @@ -214,7 +213,7 @@ pub struct Format {
escape: Option<u8>,
quote: Option<u8>,
terminator: Option<u8>,
nulls: HashSet<String>,
null_regex: Option<Regex>,
}

impl Format {
Expand Down Expand Up @@ -243,8 +242,8 @@ impl Format {
self
}

pub fn with_nulls(mut self, nulls: HashSet<String>) -> Self {
self.nulls = nulls;
pub fn with_null_regex(mut self, null_regex: Regex) -> Self {
self.null_regex = Some(null_regex);
self
}

Expand Down Expand Up @@ -326,6 +325,7 @@ impl Format {
if let Some(t) = self.terminator {
builder.terminator(csv::Terminator::Any(t));
}
// TODO: Null regex
builder.from_reader(reader)
}

Expand All @@ -343,6 +343,7 @@ impl Format {
if let Some(t) = self.terminator {
builder.terminator(csv_core::Terminator::Any(t));
}
// TODO: Null regex
builder.build()
}
}
Expand Down Expand Up @@ -564,8 +565,8 @@ pub struct Decoder {
/// A decoder for [`StringRecords`]
record_decoder: RecordDecoder,

/// Check for if the string is `NULL` value or not.
is_null: Box<dyn Fn(&str) -> bool>,
/// Check if the string matches this pattern for `NULL`.
null_regex: Option<Regex>,
}

impl Debug for Decoder {
Expand Down Expand Up @@ -626,7 +627,7 @@ impl Decoder {
Some(self.schema.metadata.clone()),
self.projection.as_ref(),
self.line_number,
&self.is_null,
self.null_regex.as_ref(),
)?;
self.line_number += rows.len();
Ok(Some(batch))
Expand All @@ -645,7 +646,7 @@ fn parse(
metadata: Option<std::collections::HashMap<String, String>>,
projection: Option<&Vec<usize>>,
line_number: usize,
is_null: &dyn Fn(&str) -> bool,
null_regex: Option<&Regex>,
) -> Result<RecordBatch, ArrowError> {
let projection: Vec<usize> = match projection {
Some(v) => v.clone(),
Expand All @@ -658,15 +659,17 @@ fn parse(
let i = *i;
let field = &fields[i];
match field.data_type() {
DataType::Boolean => build_boolean_array(line_number, rows, i, is_null),
DataType::Boolean => {
build_boolean_array(line_number, rows, i, null_regex)
}
DataType::Decimal128(precision, scale) => {
build_decimal_array::<Decimal128Type>(
line_number,
rows,
i,
*precision,
*scale,
is_null,
null_regex,
)
}
DataType::Decimal256(precision, scale) => {
Expand All @@ -676,78 +679,81 @@ fn parse(
i,
*precision,
*scale,
is_null,
null_regex,
)
}
DataType::Int8 => {
build_primitive_array::<Int8Type>(line_number, rows, i, is_null)
build_primitive_array::<Int8Type>(line_number, rows, i, null_regex)
}
DataType::Int16 => {
build_primitive_array::<Int16Type>(line_number, rows, i, is_null)
build_primitive_array::<Int16Type>(line_number, rows, i, null_regex)
}
DataType::Int32 => {
build_primitive_array::<Int32Type>(line_number, rows, i, is_null)
build_primitive_array::<Int32Type>(line_number, rows, i, null_regex)
}
DataType::Int64 => {
build_primitive_array::<Int64Type>(line_number, rows, i, is_null)
build_primitive_array::<Int64Type>(line_number, rows, i, null_regex)
}
DataType::UInt8 => {
build_primitive_array::<UInt8Type>(line_number, rows, i, is_null)
build_primitive_array::<UInt8Type>(line_number, rows, i, null_regex)
}
DataType::UInt16 => {
build_primitive_array::<UInt16Type>(line_number, rows, i, is_null)
build_primitive_array::<UInt16Type>(line_number, rows, i, null_regex)
}
DataType::UInt32 => {
build_primitive_array::<UInt32Type>(line_number, rows, i, is_null)
build_primitive_array::<UInt32Type>(line_number, rows, i, null_regex)
}
DataType::UInt64 => {
build_primitive_array::<UInt64Type>(line_number, rows, i, is_null)
build_primitive_array::<UInt64Type>(line_number, rows, i, null_regex)
}
DataType::Float32 => {
build_primitive_array::<Float32Type>(line_number, rows, i, is_null)
build_primitive_array::<Float32Type>(line_number, rows, i, null_regex)
}
DataType::Float64 => {
build_primitive_array::<Float64Type>(line_number, rows, i, is_null)
build_primitive_array::<Float64Type>(line_number, rows, i, null_regex)
}
DataType::Date32 => {
build_primitive_array::<Date32Type>(line_number, rows, i, is_null)
build_primitive_array::<Date32Type>(line_number, rows, i, null_regex)
}
DataType::Date64 => {
build_primitive_array::<Date64Type>(line_number, rows, i, is_null)
build_primitive_array::<Date64Type>(line_number, rows, i, null_regex)
}
DataType::Time32(TimeUnit::Second) => build_primitive_array::<
Time32SecondType,
>(
line_number, rows, i, is_null
line_number, rows, i, null_regex
),
DataType::Time32(TimeUnit::Millisecond) => {
build_primitive_array::<Time32MillisecondType>(
line_number,
rows,
i,
is_null,
null_regex,
)
}
DataType::Time64(TimeUnit::Microsecond) => {
build_primitive_array::<Time64MicrosecondType>(
line_number,
rows,
i,
is_null,
null_regex,
)
}
DataType::Time64(TimeUnit::Nanosecond) => {
build_primitive_array::<Time64NanosecondType>(
line_number,
rows,
i,
null_regex,
)
}
DataType::Time64(TimeUnit::Nanosecond) => build_primitive_array::<
Time64NanosecondType,
>(
line_number, rows, i, is_null
),
DataType::Timestamp(TimeUnit::Second, tz) => {
build_timestamp_array::<TimestampSecondType>(
line_number,
rows,
i,
tz.as_deref(),
is_null,
null_regex,
)
}
DataType::Timestamp(TimeUnit::Millisecond, tz) => {
Expand All @@ -756,7 +762,7 @@ fn parse(
rows,
i,
tz.as_deref(),
is_null,
null_regex,
)
}
DataType::Timestamp(TimeUnit::Microsecond, tz) => {
Expand All @@ -765,7 +771,7 @@ fn parse(
rows,
i,
tz.as_deref(),
is_null,
null_regex,
)
}
DataType::Timestamp(TimeUnit::Nanosecond, tz) => {
Expand All @@ -774,7 +780,7 @@ fn parse(
rows,
i,
tz.as_deref(),
is_null,
null_regex,
)
}
DataType::Utf8 => Ok(Arc::new(
Expand Down Expand Up @@ -874,12 +880,12 @@ fn build_decimal_array<T: DecimalType>(
col_idx: usize,
precision: u8,
scale: i8,
is_null: &dyn Fn(&str) -> bool,
null_regex: Option<&Regex>,
) -> Result<ArrayRef, ArrowError> {
let mut decimal_builder = PrimitiveBuilder::<T>::with_capacity(rows.len());
for row in rows.iter() {
let s = row.get(col_idx);
if is_null(s) {
if s.is_empty() || null_regex.is_some_and(|r| r.is_match(s)) {
// append null
decimal_builder.append_null();
} else {
Expand Down Expand Up @@ -907,13 +913,13 @@ fn build_primitive_array<T: ArrowPrimitiveType + Parser>(
line_number: usize,
rows: &StringRecords<'_>,
col_idx: usize,
is_null: &dyn Fn(&str) -> bool,
null_regex: Option<&Regex>,
) -> Result<ArrayRef, ArrowError> {
rows.iter()
.enumerate()
.map(|(row_index, row)| {
let s = row.get(col_idx);
if is_null(s) {
if s.is_empty() || null_regex.is_some_and(|r| r.is_match(s)) {
return Ok(None);
}

Expand All @@ -937,17 +943,27 @@ fn build_timestamp_array<T: ArrowTimestampType>(
rows: &StringRecords<'_>,
col_idx: usize,
timezone: Option<&str>,
is_null: &dyn Fn(&str) -> bool,
null_regex: Option<&Regex>,
) -> Result<ArrayRef, ArrowError> {
Ok(Arc::new(match timezone {
Some(timezone) => {
let tz: Tz = timezone.parse()?;
build_timestamp_array_impl::<T, _>(line_number, rows, col_idx, &tz, is_null)?
.with_timezone(timezone)
}
None => {
build_timestamp_array_impl::<T, _>(line_number, rows, col_idx, &Utc, is_null)?
build_timestamp_array_impl::<T, _>(
line_number,
rows,
col_idx,
&tz,
null_regex,
)?
.with_timezone(timezone)
}
None => build_timestamp_array_impl::<T, _>(
line_number,
rows,
col_idx,
&Utc,
null_regex,
)?,
}))
}

Expand All @@ -956,13 +972,13 @@ fn build_timestamp_array_impl<T: ArrowTimestampType, Tz: TimeZone>(
rows: &StringRecords<'_>,
col_idx: usize,
timezone: &Tz,
is_null: &dyn Fn(&str) -> bool,
null_regex: Option<&Regex>,
) -> Result<PrimitiveArray<T>, ArrowError> {
rows.iter()
.enumerate()
.map(|(row_index, row)| {
let s = row.get(col_idx);
if is_null(s) {
if s.is_empty() || null_regex.is_some_and(|r| r.is_match(s)) {
return Ok(None);
}

Expand All @@ -989,13 +1005,13 @@ fn build_boolean_array(
line_number: usize,
rows: &StringRecords<'_>,
col_idx: usize,
is_null: &dyn Fn(&str) -> bool,
null_regex: Option<&Regex>,
) -> Result<ArrayRef, ArrowError> {
rows.iter()
.enumerate()
.map(|(row_index, row)| {
let s = row.get(col_idx);
if is_null(s) {
if s.is_empty() || null_regex.is_some_and(|r| r.is_match(s)) {
return Ok(None);
}
let parsed = parse_bool(s);
Expand Down Expand Up @@ -1029,8 +1045,8 @@ pub struct ReaderBuilder {
bounds: Bounds,
/// Optional projection for which columns to load (zero-based column indices)
projection: Option<Vec<usize>>,
/// Strings to consider as `NULL` when parsing.
nulls: HashSet<String>,
/// Pattern to consider as `NULL` when parsing.
null_regex: Option<Regex>,
}

impl ReaderBuilder {
Expand Down Expand Up @@ -1062,7 +1078,7 @@ impl ReaderBuilder {
batch_size: 1024,
bounds: None,
projection: None,
nulls: HashSet::new(),
null_regex: None,
}
}

Expand Down Expand Up @@ -1099,8 +1115,8 @@ impl ReaderBuilder {
self
}

pub fn with_nulls(mut self, nulls: HashSet<String>) -> Self {
self.nulls = nulls;
pub fn with_null_regex(mut self, null_regex: Regex) -> Self {
self.null_regex = Some(null_regex);
self
}

Expand Down Expand Up @@ -1154,13 +1170,6 @@ impl ReaderBuilder {
None => (header, usize::MAX),
};

let is_null: Box<dyn Fn(&str) -> bool> = if self.nulls.is_empty() {
Box::new(|s| s.is_empty())
} else {
let nulls = self.nulls;
Box::new(move |s| s.is_empty() || nulls.contains(s))
};

Decoder {
schema: self.schema,
to_skip: start,
Expand All @@ -1169,7 +1178,7 @@ impl ReaderBuilder {
end,
projection: self.projection,
batch_size: self.batch_size,
is_null,
null_regex: self.null_regex,
}
}
}
Expand Down Expand Up @@ -1507,11 +1516,11 @@ mod tests {

let file = File::open("test/data/custom_null_test.csv").unwrap();

let nulls: HashSet<String> = ["nil"].into_iter().map(|s| s.to_string()).collect();
let null_regex = Regex::new("^nil$").unwrap();

let mut csv = ReaderBuilder::new(schema)
.has_header(true)
.with_nulls(nulls)
.with_null_regex(null_regex)
.build(file)
.unwrap();

Expand Down
2 changes: 1 addition & 1 deletion arrow-csv/test/data/custom_null_test.csv
Original file line number Diff line number Diff line change
Expand Up @@ -3,4 +3,4 @@ c_int,c_float,c_string,c_bool
nil,2.2,"2.22",TRUE
3,nil,"3.33",true
4,4.4,nil,False
5,6.6,"",nil
5,6.6,"",nil

0 comments on commit 4c11b3b

Please sign in to comment.