Skip to content

Commit

Permalink
build: Update sqlparser to 0.39 (#12173)
Browse files Browse the repository at this point in the history
  • Loading branch information
alexander-beedie authored Nov 1, 2023
1 parent daac19f commit 8e3de87
Show file tree
Hide file tree
Showing 5 changed files with 46 additions and 20 deletions.
2 changes: 1 addition & 1 deletion Cargo.toml
Original file line number Diff line number Diff line change
Expand Up @@ -47,7 +47,7 @@ serde = "1.0.188"
serde_json = "1"
simd-json = { version = "0.13", features = ["known-key"] }
smartstring = "1"
sqlparser = "0.38"
sqlparser = "0.39"
strum_macros = "0.25"
thiserror = "1"
tokio = "1.26"
Expand Down
4 changes: 1 addition & 3 deletions crates/polars-sql/src/context.rs
Original file line number Diff line number Diff line change
Expand Up @@ -223,13 +223,11 @@ impl SQLContext {
concatenated.map(|lf| lf.unique(None, UniqueKeepStrategy::Any))
},
// UNION ALL BY NAME
// TODO: add recognition for SetQuantifier::DistinctByName
// when "https://github.com/sqlparser-rs/sqlparser-rs/pull/997" is available
#[cfg(feature = "diagonal_concat")]
SetQuantifier::AllByName => concat_lf_diagonal(vec![left, right], opts),
// UNION [DISTINCT] BY NAME
#[cfg(feature = "diagonal_concat")]
SetQuantifier::ByName => {
SetQuantifier::ByName | SetQuantifier::DistinctByName => {
let concatenated = concat_lf_diagonal(vec![left, right], opts);
concatenated.map(|lf| lf.unique(None, UniqueKeepStrategy::Any))
},
Expand Down
38 changes: 28 additions & 10 deletions crates/polars-sql/src/sql_expr.rs
Original file line number Diff line number Diff line change
Expand Up @@ -7,9 +7,9 @@ use polars_plan::prelude::{col, lit, when};
use rand::distributions::Alphanumeric;
use rand::{thread_rng, Rng};
use sqlparser::ast::{
ArrayAgg, BinaryOperator as SQLBinaryOperator, BinaryOperator, DataType as SQLDataType,
Expr as SqlExpr, Function as SQLFunction, Ident, JoinConstraint, OrderByExpr,
Query as Subquery, SelectItem, TrimWhereField, UnaryOperator, Value as SqlValue,
ArrayAgg, ArrayElemTypeDef, BinaryOperator as SQLBinaryOperator, BinaryOperator, CastFormat,
DataType as SQLDataType, Expr as SqlExpr, Function as SQLFunction, Ident, JoinConstraint,
OrderByExpr, Query as Subquery, SelectItem, TrimWhereField, UnaryOperator, Value as SqlValue,
};
use sqlparser::dialect::GenericDialect;
use sqlparser::parser::{Parser, ParserOptions};
Expand All @@ -19,7 +19,8 @@ use crate::SQLContext;

pub(crate) fn map_sql_polars_datatype(data_type: &SQLDataType) -> PolarsResult<DataType> {
Ok(match data_type {
SQLDataType::Array(Some(inner_type)) => {
SQLDataType::Array(ArrayElemTypeDef::AngleBracket(inner_type))
| SQLDataType::Array(ArrayElemTypeDef::SquareBracket(inner_type)) => {
DataType::List(Box::new(map_sql_polars_datatype(inner_type)?))
},
SQLDataType::BigInt(_) => DataType::Int64,
Expand All @@ -32,7 +33,7 @@ pub(crate) fn map_sql_polars_datatype(data_type: &SQLDataType) -> PolarsResult<D
| SQLDataType::Character(_)
| SQLDataType::CharacterVarying(_)
| SQLDataType::Clob(_)
| SQLDataType::String
| SQLDataType::String(_)
| SQLDataType::Text
| SQLDataType::Uuid
| SQLDataType::Varchar(_) => DataType::Utf8,
Expand Down Expand Up @@ -90,7 +91,11 @@ impl SqlExprVisitor<'_> {
high,
} => self.visit_between(expr, *negated, low, high),
SqlExpr::BinaryOp { left, op, right } => self.visit_binary_op(left, op, right),
SqlExpr::Cast { expr, data_type } => self.visit_cast(expr, data_type),
SqlExpr::Cast {
expr,
data_type,
format,
} => self.visit_cast(expr, data_type, format),
SqlExpr::Ceil { expr, .. } => Ok(self.visit_expr(expr)?.ceil()),
SqlExpr::CompoundIdentifier(idents) => self.visit_compound_identifier(idents),
SqlExpr::Floor { expr, .. } => Ok(self.visit_expr(expr)?.floor()),
Expand Down Expand Up @@ -124,7 +129,8 @@ impl SqlExprVisitor<'_> {
expr,
trim_where,
trim_what,
} => self.visit_trim(expr, trim_where, trim_what),
trim_characters,
} => self.visit_trim(expr, trim_where, trim_what, trim_characters),
SqlExpr::UnaryOp { op, expr } => self.visit_unary_op(op, expr),
SqlExpr::Value(value) => self.visit_literal(value),
e @ SqlExpr::Case { .. } => self.visit_when_then(e),
Expand Down Expand Up @@ -342,7 +348,15 @@ impl SqlExprVisitor<'_> {
/// Visit a SQL CAST
///
/// e.g. `CAST(column AS INT)` or `column::INT`
fn visit_cast(&mut self, expr: &SqlExpr, data_type: &SQLDataType) -> PolarsResult<Expr> {
fn visit_cast(
&mut self,
expr: &SqlExpr,
data_type: &SQLDataType,
format: &Option<CastFormat>,
) -> PolarsResult<Expr> {
if format.is_some() {
return Err(polars_err!(ComputeError: "unsupported use of FORMAT in CAST expression"));
}
let polars_type = map_sql_polars_datatype(data_type)?;
let expr = self.visit_expr(expr)?;

Expand Down Expand Up @@ -440,15 +454,19 @@ impl SqlExprVisitor<'_> {
expr: &SqlExpr,
trim_where: &Option<TrimWhereField>,
trim_what: &Option<Box<SqlExpr>>,
trim_characters: &Option<Vec<SqlExpr>>,
) -> PolarsResult<Expr> {
if trim_characters.is_some() {
// TODO: allow compact snowflake/bigquery syntax?
return Err(polars_err!(ComputeError: "unsupported TRIM syntax"));
};
let expr = self.visit_expr(expr)?;
let trim_what = trim_what.as_ref().map(|e| self.visit_expr(e)).transpose()?;
let trim_what = match trim_what {
Some(Expr::Literal(LiteralValue::Utf8(val))) => Some(val),
None => None,
_ => return self.err(&expr),
};

Ok(match (trim_where, trim_what) {
(None | Some(TrimWhereField::Both), None) => expr.str().strip_chars(lit(Null)),
(None | Some(TrimWhereField::Both), Some(val)) => expr.str().strip_chars(lit(val)),
Expand Down Expand Up @@ -676,7 +694,7 @@ pub(super) fn process_join_constraint(
) -> PolarsResult<(Vec<Expr>, Vec<Expr>)> {
if let JoinConstraint::On(SqlExpr::BinaryOp { left, op, right }) = constraint {
if op != &BinaryOperator::Eq {
polars_bail!(InvalidOperation:
polars_bail!(InvalidOperation:
"SQL interface (currently) only supports basic equi-join \
constraints; found '{:?}' op in\n{:?}", op, constraint)
}
Expand Down
4 changes: 2 additions & 2 deletions py-polars/Cargo.lock

Some generated files are not rendered by default. Learn more about how customized files appear on GitHub.

18 changes: 14 additions & 4 deletions py-polars/tests/unit/sql/test_sql.py
Original file line number Diff line number Diff line change
Expand Up @@ -70,6 +70,11 @@ def test_sql_cast() -> None:
(5.0, 5.0, 5, 5, 5, 1, "5", "5.5", b"e", b"e", "true"),
]

with pytest.raises(pl.ComputeError, match="unsupported use of FORMAT in CAST"):
pl.SQLContext(df=df, eager_execution=True).execute(
"SELECT CAST(a AS STRING FORMAT 'HEX') FROM df"
)


def test_sql_any_all() -> None:
df = pl.DataFrame(
Expand Down Expand Up @@ -896,7 +901,8 @@ def test_sql_substr() -> None:


def test_sql_trim(foods_ipc_path: Path) -> None:
out = pl.SQLContext(foods1=pl.scan_ipc(foods_ipc_path)).execute(
lf = pl.scan_ipc(foods_ipc_path)
out = pl.SQLContext(foods1=lf).execute(
"""
SELECT DISTINCT TRIM(LEADING 'vmf' FROM category) as new_category
FROM foods1
Expand All @@ -907,6 +913,13 @@ def test_sql_trim(foods_ipc_path: Path) -> None:
assert out.to_dict(as_series=False) == {
"new_category": ["seafood", "ruit", "egetables", "eat"]
}
with pytest.raises(pl.ComputeError, match="unsupported TRIM"):
# currently unsupported (snowflake) trim syntax
pl.SQLContext(foods=lf).execute(
"""
SELECT DISTINCT TRIM('*^xxxx^*', '^*') as new_category FROM foods
""",
)


@pytest.mark.parametrize(
Expand Down Expand Up @@ -947,9 +960,6 @@ def test_sql_trim(foods_ipc_path: Path) -> None:
["c2", "c1"],
"DISTINCT BY NAME",
[(1, "zz"), (2, "yy"), (3, "xx")],
# TODO: Remove xfail marker when supported added in sqlparser-rs
# https://github.com/sqlparser-rs/sqlparser-rs/pull/997
marks=pytest.mark.xfail,
),
],
)
Expand Down

0 comments on commit 8e3de87

Please sign in to comment.