Skip to content

Commit

Permalink
feat: support InList clause in streaming SQL (#1694)
Browse files Browse the repository at this point in the history
* feat: support InList clause in streaming SQL

Signed-off-by: hi-rustin <[email protected]>

* feat: support not in list

Signed-off-by: hi-rustin <[email protected]>

* fix: use users

Signed-off-by: hi-rustin <[email protected]>

* Remove broken tests

Signed-off-by: hi-rustin <[email protected]>

* Fix fmt

Signed-off-by: hi-rustin <[email protected]>

* Better test code

Signed-off-by: hi-rustin <[email protected]>

---------

Signed-off-by: hi-rustin <[email protected]>
Co-authored-by: Dario Pizzamiglio <[email protected]>
  • Loading branch information
Rustin170506 and mediuminvader authored Jul 7, 2023
1 parent 5831ba4 commit bbaa21f
Show file tree
Hide file tree
Showing 6 changed files with 180 additions and 0 deletions.
28 changes: 28 additions & 0 deletions dozer-sql/src/pipeline/expression/builder.rs
Original file line number Diff line number Diff line change
Expand Up @@ -108,6 +108,12 @@ impl ExpressionBuilder {
escape_char,
schema,
),
SqlExpr::InList {
expr,
list,
negated,
} => self.parse_sql_in_list_operator(parse_aggregations, expr, list, *negated, schema),

SqlExpr::Cast { expr, data_type } => {
self.parse_sql_cast_operator(parse_aggregations, expr, data_type, schema)
}
Expand Down Expand Up @@ -764,6 +770,28 @@ impl ExpressionBuilder {
return_type,
})
}

/// Builds an [`Expression::InList`] from the parts of a SQL `[NOT] IN (...)` clause.
///
/// The tested operand is parsed first, followed by each list element in
/// order, all through `parse_sql_expression` with the same
/// `parse_aggregations` flag and `schema`.
///
/// # Errors
///
/// Returns the first `PipelineError` produced while parsing either the
/// operand or any list element; later elements are not parsed after a failure.
fn parse_sql_in_list_operator(
    &mut self,
    parse_aggregations: bool,
    expr: &Expr,
    list: &[Expr],
    negated: bool,
    schema: &Schema,
) -> Result<Expression, PipelineError> {
    // Left-hand side of the IN operator.
    let operand = Box::new(self.parse_sql_expression(parse_aggregations, expr, schema)?);

    // Right-hand side: one parsed expression per list element, in source order.
    let mut candidates = Vec::with_capacity(list.len());
    for element in list {
        candidates.push(self.parse_sql_expression(parse_aggregations, element, schema)?);
    }

    Ok(Expression::InList {
        expr: operand,
        list: candidates,
        negated,
    })
}
}

#[derive(Debug, Clone, Hash, PartialEq, Eq)]
Expand Down
37 changes: 37 additions & 0 deletions dozer-sql/src/pipeline/expression/execution.rs
Original file line number Diff line number Diff line change
Expand Up @@ -22,6 +22,7 @@ use uuid::Uuid;

use super::aggregate::AggregateFunctionType;
use super::cast::CastOperatorType;
use super::in_list::evaluate_in_list;
use super::scalar::string::{evaluate_like, get_like_operator_type};

#[derive(Clone, Debug, PartialEq)]
Expand Down Expand Up @@ -73,6 +74,11 @@ pub enum Expression {
pattern: Box<Expression>,
escape: Option<char>,
},
InList {
expr: Box<Expression>,
list: Vec<Expression>,
negated: bool,
},
Now {
fun: DateTimeFunctionType,
},
Expand Down Expand Up @@ -219,6 +225,22 @@ impl Expression {
pattern,
escape: _,
} => arg.to_string(schema) + " LIKE " + pattern.to_string(schema).as_str(),
Expression::InList {
expr,
list,
negated,
} => {
expr.to_string(schema)
+ if *negated { " NOT" } else { "" }
+ " IN ("
+ list
.iter()
.map(|e| e.to_string(schema))
.collect::<Vec<String>>()
.join(",")
.as_str()
+ ")"
}
Expression::GeoFunction { fun, args } => {
fun.to_string()
+ "("
Expand Down Expand Up @@ -318,6 +340,11 @@ impl ExpressionExecutor for Expression {
pattern,
escape,
} => evaluate_like(schema, arg, pattern, *escape, record),
Expression::InList {
expr,
list,
negated,
} => evaluate_in_list(schema, expr, list, *negated, record),
Expression::Cast { arg, typ } => typ.evaluate(schema, arg, record),
Expression::GeoFunction { fun, args } => fun.evaluate(schema, args, record),
Expression::ConditionalExpression { fun, args } => fun.evaluate(schema, args, record),
Expand Down Expand Up @@ -384,6 +411,16 @@ impl ExpressionExecutor for Expression {
pattern,
escape: _,
} => get_like_operator_type(arg, pattern, schema),
Expression::InList {
expr: _,
list: _,
negated: _,
} => Ok(ExpressionType::new(
FieldType::Boolean,
false,
SourceDefinition::Dynamic,
false,
)),
Expression::Cast { arg, typ } => typ.get_return_type(schema, arg),
Expression::GeoFunction { fun, args } => get_geo_function_type(fun, args, schema),
Expression::DateTimeFunction { fun, arg } => {
Expand Down
27 changes: 27 additions & 0 deletions dozer-sql/src/pipeline/expression/in_list.rs
Original file line number Diff line number Diff line change
@@ -0,0 +1,27 @@
use dozer_types::types::{Field, Record, Schema};

use crate::pipeline::errors::PipelineError;
use crate::pipeline::expression::execution::{Expression, ExpressionExecutor};

/// Evaluates a SQL `expr [NOT] IN (list...)` predicate against `record`.
///
/// `expr` is evaluated once; the list elements are then evaluated one at a
/// time, in order, and compared to it with `==`. Evaluation stops at the
/// first equal element, so later elements are not evaluated after a match.
///
/// Returns `Field::Boolean(true)` when a match exists and `negated` is
/// `false`, or when no match exists and `negated` is `true`; otherwise
/// `Field::Boolean(false)`.
///
/// # Errors
///
/// Propagates any `PipelineError` raised while evaluating `expr` or a list
/// element that is reached before a match is found.
///
/// NOTE(review): membership is decided purely by `Field` equality and the
/// result is always a `Field::Boolean`; SQL three-valued logic for NULL
/// operands is not applied here -- confirm this is intended.
pub(crate) fn evaluate_in_list(
    schema: &Schema,
    expr: &Expression,
    list: &[Expression],
    negated: bool,
    record: &Record,
) -> Result<Field, PipelineError> {
    let needle = expr.evaluate(record, schema)?;

    for candidate in list {
        if needle == candidate.evaluate(record, schema)? {
            // Match found: `IN` is true, `NOT IN` is false.
            return Ok(Field::Boolean(!negated));
        }
    }

    // No element matched: `IN` is false, `NOT IN` is true.
    Ok(Field::Boolean(negated))
}
1 change: 1 addition & 0 deletions dozer-sql/src/pipeline/expression/mod.rs
Original file line number Diff line number Diff line change
Expand Up @@ -8,6 +8,7 @@ pub mod conditional;
mod datetime;
pub mod execution;
pub mod geo;
pub mod in_list;
mod json_functions;
pub mod logical;
pub mod mathematical;
Expand Down
86 changes: 86 additions & 0 deletions dozer-sql/src/pipeline/expression/tests/in_list.rs
Original file line number Diff line number Diff line change
@@ -0,0 +1,86 @@
use crate::pipeline::expression::tests::test_common::run_fct;
use dozer_types::types::{Field, FieldDefinition, FieldType, Schema, SourceDefinition};

/// Checks `IN (...)` for both a literal operand and a column operand,
/// covering the "absent from list" and "present in list" outcomes.
#[test]
fn test_in_list() {
    // Literal operand: no schema and no record values are needed.
    let literal_cases = [
        ("SELECT 42 IN (1, 2, 3, 4, 5, 6, 7, 8, 9, 10)", false),
        ("SELECT 42 IN (1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 42)", true),
    ];
    for (sql, expected) in literal_cases {
        let actual = run_fct(sql, Schema::empty(), vec![]);
        assert_eq!(actual, Field::Boolean(expected), "query: {sql}");
    }

    // Column operand: a single non-nullable Int column `age` holding 42.
    let age_column = FieldDefinition::new(
        String::from("age"),
        FieldType::Int,
        false,
        SourceDefinition::Dynamic,
    );
    let users_schema = Schema::empty().field(age_column, false).clone();

    let column_cases = [
        (
            "SELECT age IN (1, 2, 3, 4, 5, 6, 7, 8, 9, 10) FROM users",
            false,
        ),
        (
            "SELECT age IN (1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 42) FROM users",
            true,
        ),
    ];
    for (sql, expected) in column_cases {
        let actual = run_fct(sql, users_schema.clone(), vec![Field::Int(42)]);
        assert_eq!(actual, Field::Boolean(expected), "query: {sql}");
    }
}

/// Checks `NOT IN (...)` for both a literal operand and a column operand;
/// expectations are the inverse of the plain `IN` cases.
#[test]
fn test_not_in_list() {
    // Literal operand: no schema and no record values are needed.
    let literal_cases = [
        ("SELECT 42 NOT IN (1, 2, 3, 4, 5, 6, 7, 8, 9, 10)", true),
        ("SELECT 42 NOT IN (1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 42)", false),
    ];
    for (sql, expected) in literal_cases {
        let actual = run_fct(sql, Schema::empty(), vec![]);
        assert_eq!(actual, Field::Boolean(expected), "query: {sql}");
    }

    // Column operand: a single non-nullable Int column `age` holding 42.
    let age_column = FieldDefinition::new(
        String::from("age"),
        FieldType::Int,
        false,
        SourceDefinition::Dynamic,
    );
    let users_schema = Schema::empty().field(age_column, false).clone();

    let column_cases = [
        (
            "SELECT age NOT IN (1, 2, 3, 4, 5, 6, 7, 8, 9, 10) FROM users",
            true,
        ),
        (
            "SELECT age NOT IN (1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 42) FROM users",
            false,
        ),
    ];
    for (sql, expected) in column_cases {
        let actual = run_fct(sql, users_schema.clone(), vec![Field::Int(42)]);
        assert_eq!(actual, Field::Boolean(expected), "query: {sql}");
    }
}
1 change: 1 addition & 0 deletions dozer-sql/src/pipeline/expression/tests/mod.rs
Original file line number Diff line number Diff line change
Expand Up @@ -15,6 +15,7 @@ mod conditional;
mod datetime;
#[cfg(test)]
mod distance;
mod in_list;
#[cfg(test)]
mod json_functions;
#[cfg(test)]
Expand Down

0 comments on commit bbaa21f

Please sign in to comment.