Commit

perf: flow if then using interleave
discord9 committed Aug 26, 2024
1 parent 9ef585c commit 101f901
Showing 1 changed file with 54 additions and 92 deletions.
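For context: the commit replaces a slice-and-concat strategy with `arrow::compute::interleave`, which assembles one output array from a list of (source array, row index) pairs. The snippet below is a standalone illustration of that API only — it is not part of this diff, assumes nothing beyond the `arrow` crate, and uses made-up array names.

use arrow::array::{Array, ArrayRef, Int32Array};
use arrow::compute::interleave;

fn main() -> Result<(), arrow::error::ArrowError> {
    // Two source arrays standing in for pre-computed "then" and "else" outputs.
    let then_vals = Int32Array::from(vec![10, 20, 30]);
    let else_vals = Int32Array::from(vec![-1, -2, -3]);
    let sources: Vec<&dyn Array> = vec![&then_vals, &else_vals];

    // Each (array_index, row_index) pair selects one output row, in order.
    let indices = [(0usize, 0usize), (1, 1), (0, 2)];
    let out: ArrayRef = interleave(&sources, &indices)?;

    let out = out.as_any().downcast_ref::<Int32Array>().unwrap();
    assert_eq!(out.iter().collect::<Vec<_>>(), vec![Some(10), Some(-2), Some(30)]);
    Ok(())
}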
146 changes: 54 additions & 92 deletions src/flow/src/expr/scalar.rs
@@ -289,7 +289,8 @@ impl ScalarExpr {
         }
     }
 
-    // TODO(discord9): optimize using `arrow::compute::concat` instead
+    /// NOTE: this if-then eval impl assumes all given exprs are pure and will not change the state of the world,
+    /// since it evaluates both the then and else branches and filters the result
     fn eval_if_then(
         batch: &Batch,
         cond: &ScalarExpr,
@@ -308,108 +309,69 @@
             })?
             .as_boolean_array();
 
-        let mut input_batch: Vec<(Option<bool>, Batch)> = Vec::with_capacity(bool_conds.len());
-
-        let mut prev_cond_and_idx: Option<(Option<bool>, usize)> = None;
-        // first put different conds' vector into different batches
-        for (idx, cond) in bool_conds.iter().enumerate() {
-            // if belong to same slice and not last one continue
-            if let Some((prev_cond, prev_start_idx)) = prev_cond_and_idx {
-                if prev_cond == cond {
-                    continue;
-                } else {
-                    // put a slice to corresponding batch
-                    let slice_offset = prev_start_idx;
-                    let slice_length = idx - prev_start_idx;
-                    let to_be_append = batch.slice(slice_offset, slice_length)?;
-                    input_batch.push((prev_cond, to_be_append));
-                    prev_cond_and_idx = Some((cond, idx));
-                }
-            } else {
-                prev_cond_and_idx = Some((cond, idx));
-            }
-        }
-
-        // deal with empty and last slice case
-        if let Some((prev_cond, prev_start_idx)) = prev_cond_and_idx {
-            let slice_length = bool_conds.len() - prev_start_idx;
-            let to_be_append = batch.slice(prev_start_idx, slice_length)?;
-            input_batch.push((prev_cond, to_be_append));
-        }
-
-        let mut output_arrays = Vec::with_capacity(input_batch.len());
-        let mut first_type = None;
-        for (cond, input_batch) in input_batch {
-            let len = input_batch.row_count();
-            let out = match cond {
-                Some(true) => {
-                    let ret = then.eval_batch(&input_batch)?;
-                    if first_type.is_none() {
-                        first_type = Some(ret.data_type());
-                    }
-                    (Some(ret.to_arrow_array()), len)
-                }
-                Some(false) => {
-                    let ret = els.eval_batch(&input_batch)?;
-                    if first_type.is_none() {
-                        first_type = Some(ret.data_type());
-                    }
-                    (Some(ret.to_arrow_array()), len)
-                }
-                None => (None, len),
-            };
-            output_arrays.push(out);
-        }
-        fn new_nulls(dt: &arrow_schema::DataType, len: usize) -> ArrayRef {
-            let data = ArrayData::new_null(dt, len);
-            make_array(data)
-        }
-
-        // get persumed type from first output array
-        let persumed_type = first_type
-            .unwrap_or(ConcreteDataType::null_datatype())
-            .as_arrow_type();
-
-        // and create null array of same type for concat
-        let cast_output_array = output_arrays
+        let indices = bool_conds
             .into_iter()
-            .map(|(x, len)| x.unwrap_or_else(|| new_nulls(&persumed_type, len)))
+            .enumerate()
+            .map(|(idx, b)| {
+                (
+                    match b {
+                        Some(true) => 0,  // then branch vector
+                        Some(false) => 1, // else branch vector
+                        None => 2,        // null vector
+                    },
+                    idx,
+                )
+            })
             .collect_vec();
 
-        let out_type = cast_output_array
-            .iter()
-            .map(|x| x.data_type())
-            .collect::<Vec<_>>();
+        let then_input_vec = then.eval_batch(batch)?;
+        let else_input_vec = els.eval_batch(batch)?;
 
-        let is_same = out_type.windows(2).all(|w| w[0] == w[1]);
+        ensure!(
+            then_input_vec.data_type() == else_input_vec.data_type(),
+            TypeMismatchSnafu {
+                expected: then_input_vec.data_type(),
+                actual: else_input_vec.data_type(),
+            }
+        );
 
-        if !is_same {
+        ensure!(
+            then_input_vec.len() == else_input_vec.len() && then_input_vec.len() == batch.row_count(),
             InvalidArgumentSnafu {
                 reason: format!(
-                    "if then else return different type, found {:?}",
-                    BTreeSet::from_iter(out_type.iter())
-                ),
+                    "then and else branches must have the same length (found {} and {}), which must equal the input batch's row count ({})",
+                    then_input_vec.len(),
+                    else_input_vec.len(),
+                    batch.row_count()
+                )
             }
-            .fail()?;
-        }
+        );
+
+        fn new_nulls(dt: &arrow_schema::DataType, len: usize) -> ArrayRef {
+            let data = ArrayData::new_null(dt, len);
+            make_array(data)
+        }
 
-        let out_ref = cast_output_array
-            .iter()
-            .map(|v| v.as_ref())
-            .collect::<Vec<_>>();
-        let concated = if out_ref.is_empty() {
-            new_nulls(&persumed_type, 0)
-        } else {
-            arrow::compute::concat(&out_ref).context(ArrowSnafu {
-                context: "Failed to concat output arrays",
-            })?
-        };
+        let null_input_vec = new_nulls(
+            &then_input_vec.data_type().as_arrow_type(),
+            batch.row_count(),
+        );
 
-        let out_vec = Helper::try_into_vector(concated).context(DataTypeSnafu {
+        let interleave_values = vec![
+            then_input_vec.to_arrow_array(),
+            else_input_vec.to_arrow_array(),
+            null_input_vec,
+        ];
+        let int_ref: Vec<_> = interleave_values.iter().map(|x| x.as_ref()).collect();
+
+        let interleave_res_arr =
+            arrow::compute::interleave(&int_ref, &indices).context(ArrowSnafu {
+                context: "Failed to interleave output arrays",
+            })?;
+        let res_vec = Helper::try_into_vector(interleave_res_arr).context(DataTypeSnafu {
             msg: "Failed to convert arrow array to vector",
         })?;
 
-        Ok(out_vec)
+        Ok(res_vec)
     }
 
     /// Eval this expression with the given values.
@@ -731,7 +693,7 @@ impl ScalarExpr {
 
 #[cfg(test)]
 mod test {
-    use datatypes::vectors::{Int32Vector, NullVector, Vector};
+    use datatypes::vectors::{Int32Vector, Vector};
     use pretty_assertions::assert_eq;
 
     use super::*;
@@ -886,7 +848,7 @@ mod test {
         let vectors = vec![Int32Vector::from(raw).slice(0, raw_len)];
 
         let batch = Batch::try_new(vectors, raw_len).unwrap();
-        let expected = NullVector::new(raw_len).slice(0, raw_len);
+        let expected = Int32Vector::from(vec![]).slice(0, raw_len);
         assert_eq!(expr.eval_batch(&batch).unwrap(), expected);
     }
 }
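
For reference, a standalone sketch of the pattern the new `eval_if_then` follows, using plain arrow arrays in place of the flow crate's `Batch`/`VectorRef` types (an assumption made purely for illustration): every row's condition maps to a source index — 0 for the then column, 1 for the else column, 2 for an all-null column — and `arrow::compute::interleave` assembles the result in one pass.

use std::sync::Arc;

use arrow::array::{new_null_array, Array, ArrayRef, BooleanArray, Int32Array};
use arrow::compute::interleave;
use arrow::datatypes::DataType;

fn main() -> Result<(), arrow::error::ArrowError> {
    // A nullable condition plus fully evaluated then/else columns of equal length.
    let cond = BooleanArray::from(vec![Some(true), Some(false), None, Some(true)]);
    let then_vals: ArrayRef = Arc::new(Int32Array::from(vec![1, 2, 3, 4]));
    let else_vals: ArrayRef = Arc::new(Int32Array::from(vec![-1, -2, -3, -4]));
    // Source 2: an all-null array of the same type, used wherever the condition is NULL.
    let nulls: ArrayRef = new_null_array(&DataType::Int32, cond.len());

    // Map each row to (source array, row index): 0 = then, 1 = else, 2 = null.
    let indices: Vec<(usize, usize)> = cond
        .iter()
        .enumerate()
        .map(|(row, c)| match c {
            Some(true) => (0, row),
            Some(false) => (1, row),
            None => (2, row),
        })
        .collect();

    let sources: Vec<&dyn Array> = vec![then_vals.as_ref(), else_vals.as_ref(), nulls.as_ref()];
    let out = interleave(&sources, &indices)?;

    let out = out.as_any().downcast_ref::<Int32Array>().unwrap();
    assert_eq!(out.iter().collect::<Vec<_>>(), vec![Some(1), Some(-2), None, Some(4)]);
    Ok(())
}

Because both branches are evaluated over the full batch before rows are selected, this approach is only valid for pure expressions — which is exactly what the new doc comment on `eval_if_then` calls out.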

0 comments on commit 101f901
