Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

build: Update sqlparser to 0.39 #12173

Merged
merged 1 commit into from
Nov 1, 2023
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 1 addition & 1 deletion Cargo.toml
Original file line number Diff line number Diff line change
Expand Up @@ -47,7 +47,7 @@ serde = "1.0.188"
serde_json = "1"
simd-json = { version = "0.12", features = ["known-key"] }
smartstring = "1"
sqlparser = "0.38"
sqlparser = "0.39"
strum_macros = "0.25"
thiserror = "1"
tokio = "1.26"
Expand Down
4 changes: 1 addition & 3 deletions crates/polars-sql/src/context.rs
Original file line number Diff line number Diff line change
Expand Up @@ -223,13 +223,11 @@ impl SQLContext {
concatenated.map(|lf| lf.unique(None, UniqueKeepStrategy::Any))
},
// UNION ALL BY NAME
// TODO: add recognition for SetQuantifier::DistinctByName
// when "https://github.com/sqlparser-rs/sqlparser-rs/pull/997" is available
#[cfg(feature = "diagonal_concat")]
SetQuantifier::AllByName => concat_lf_diagonal(vec![left, right], opts),
// UNION [DISTINCT] BY NAME
#[cfg(feature = "diagonal_concat")]
SetQuantifier::ByName => {
SetQuantifier::ByName | SetQuantifier::DistinctByName => {
let concatenated = concat_lf_diagonal(vec![left, right], opts);
concatenated.map(|lf| lf.unique(None, UniqueKeepStrategy::Any))
},
Expand Down
38 changes: 28 additions & 10 deletions crates/polars-sql/src/sql_expr.rs
Original file line number Diff line number Diff line change
Expand Up @@ -7,9 +7,9 @@ use polars_plan::prelude::{col, lit, when};
use rand::distributions::Alphanumeric;
use rand::{thread_rng, Rng};
use sqlparser::ast::{
ArrayAgg, BinaryOperator as SQLBinaryOperator, BinaryOperator, DataType as SQLDataType,
Expr as SqlExpr, Function as SQLFunction, Ident, JoinConstraint, OrderByExpr,
Query as Subquery, SelectItem, TrimWhereField, UnaryOperator, Value as SqlValue,
ArrayAgg, ArrayElemTypeDef, BinaryOperator as SQLBinaryOperator, BinaryOperator, CastFormat,
DataType as SQLDataType, Expr as SqlExpr, Function as SQLFunction, Ident, JoinConstraint,
OrderByExpr, Query as Subquery, SelectItem, TrimWhereField, UnaryOperator, Value as SqlValue,
};
use sqlparser::dialect::GenericDialect;
use sqlparser::parser::{Parser, ParserOptions};
Expand All @@ -19,7 +19,8 @@ use crate::SQLContext;

pub(crate) fn map_sql_polars_datatype(data_type: &SQLDataType) -> PolarsResult<DataType> {
Ok(match data_type {
SQLDataType::Array(Some(inner_type)) => {
SQLDataType::Array(ArrayElemTypeDef::AngleBracket(inner_type))
| SQLDataType::Array(ArrayElemTypeDef::SquareBracket(inner_type)) => {
DataType::List(Box::new(map_sql_polars_datatype(inner_type)?))
},
SQLDataType::BigInt(_) => DataType::Int64,
Expand All @@ -32,7 +33,7 @@ pub(crate) fn map_sql_polars_datatype(data_type: &SQLDataType) -> PolarsResult<D
| SQLDataType::Character(_)
| SQLDataType::CharacterVarying(_)
| SQLDataType::Clob(_)
| SQLDataType::String
| SQLDataType::String(_)
| SQLDataType::Text
| SQLDataType::Uuid
| SQLDataType::Varchar(_) => DataType::Utf8,
Expand Down Expand Up @@ -90,7 +91,11 @@ impl SqlExprVisitor<'_> {
high,
} => self.visit_between(expr, *negated, low, high),
SqlExpr::BinaryOp { left, op, right } => self.visit_binary_op(left, op, right),
SqlExpr::Cast { expr, data_type } => self.visit_cast(expr, data_type),
SqlExpr::Cast {
expr,
data_type,
format,
} => self.visit_cast(expr, data_type, format),
SqlExpr::Ceil { expr, .. } => Ok(self.visit_expr(expr)?.ceil()),
SqlExpr::CompoundIdentifier(idents) => self.visit_compound_identifier(idents),
SqlExpr::Floor { expr, .. } => Ok(self.visit_expr(expr)?.floor()),
Expand Down Expand Up @@ -124,7 +129,8 @@ impl SqlExprVisitor<'_> {
expr,
trim_where,
trim_what,
} => self.visit_trim(expr, trim_where, trim_what),
trim_characters,
} => self.visit_trim(expr, trim_where, trim_what, trim_characters),
SqlExpr::UnaryOp { op, expr } => self.visit_unary_op(op, expr),
SqlExpr::Value(value) => self.visit_literal(value),
e @ SqlExpr::Case { .. } => self.visit_when_then(e),
Expand Down Expand Up @@ -342,7 +348,15 @@ impl SqlExprVisitor<'_> {
/// Visit a SQL CAST
///
/// e.g. `CAST(column AS INT)` or `column::INT`
fn visit_cast(&mut self, expr: &SqlExpr, data_type: &SQLDataType) -> PolarsResult<Expr> {
fn visit_cast(
&mut self,
expr: &SqlExpr,
data_type: &SQLDataType,
format: &Option<CastFormat>,
) -> PolarsResult<Expr> {
if format.is_some() {
return Err(polars_err!(ComputeError: "unsupported use of FORMAT in CAST expression"));
}
let polars_type = map_sql_polars_datatype(data_type)?;
let expr = self.visit_expr(expr)?;

Expand Down Expand Up @@ -440,15 +454,19 @@ impl SqlExprVisitor<'_> {
expr: &SqlExpr,
trim_where: &Option<TrimWhereField>,
trim_what: &Option<Box<SqlExpr>>,
trim_characters: &Option<Vec<SqlExpr>>,
) -> PolarsResult<Expr> {
if trim_characters.is_some() {
// TODO: allow compact snowflake/bigquery syntax?
return Err(polars_err!(ComputeError: "unsupported TRIM syntax"));
};
let expr = self.visit_expr(expr)?;
let trim_what = trim_what.as_ref().map(|e| self.visit_expr(e)).transpose()?;
let trim_what = match trim_what {
Some(Expr::Literal(LiteralValue::Utf8(val))) => Some(val),
None => None,
_ => return self.err(&expr),
};

Ok(match (trim_where, trim_what) {
(None | Some(TrimWhereField::Both), None) => expr.str().strip_chars(lit(Null)),
(None | Some(TrimWhereField::Both), Some(val)) => expr.str().strip_chars(lit(val)),
Expand Down Expand Up @@ -676,7 +694,7 @@ pub(super) fn process_join_constraint(
) -> PolarsResult<(Vec<Expr>, Vec<Expr>)> {
if let JoinConstraint::On(SqlExpr::BinaryOp { left, op, right }) = constraint {
if op != &BinaryOperator::Eq {
polars_bail!(InvalidOperation:
polars_bail!(InvalidOperation:
"SQL interface (currently) only supports basic equi-join \
constraints; found '{:?}' op in\n{:?}", op, constraint)
}
Expand Down
4 changes: 2 additions & 2 deletions py-polars/Cargo.lock

Some generated files are not rendered by default. Learn more about how customized files appear on GitHub.

18 changes: 14 additions & 4 deletions py-polars/tests/unit/sql/test_sql.py
Original file line number Diff line number Diff line change
Expand Up @@ -70,6 +70,11 @@ def test_sql_cast() -> None:
(5.0, 5.0, 5, 5, 5, 1, "5", "5.5", b"e", b"e", "true"),
]

with pytest.raises(pl.ComputeError, match="unsupported use of FORMAT in CAST"):
pl.SQLContext(df=df, eager_execution=True).execute(
"SELECT CAST(a AS STRING FORMAT 'HEX') FROM df"
)


def test_sql_any_all() -> None:
df = pl.DataFrame(
Expand Down Expand Up @@ -896,7 +901,8 @@ def test_sql_substr() -> None:


def test_sql_trim(foods_ipc_path: Path) -> None:
out = pl.SQLContext(foods1=pl.scan_ipc(foods_ipc_path)).execute(
lf = pl.scan_ipc(foods_ipc_path)
out = pl.SQLContext(foods1=lf).execute(
"""
SELECT DISTINCT TRIM(LEADING 'vmf' FROM category) as new_category
FROM foods1
Expand All @@ -907,6 +913,13 @@ def test_sql_trim(foods_ipc_path: Path) -> None:
assert out.to_dict(as_series=False) == {
"new_category": ["seafood", "ruit", "egetables", "eat"]
}
with pytest.raises(pl.ComputeError, match="unsupported TRIM"):
# currently unsupported (snowflake) trim syntax
pl.SQLContext(foods=lf).execute(
"""
SELECT DISTINCT TRIM('*^xxxx^*', '^*') as new_category FROM foods
""",
)


@pytest.mark.parametrize(
Expand Down Expand Up @@ -947,9 +960,6 @@ def test_sql_trim(foods_ipc_path: Path) -> None:
["c2", "c1"],
"DISTINCT BY NAME",
[(1, "zz"), (2, "yy"), (3, "xx")],
# TODO: Remove xfail marker when supported added in sqlparser-rs
# https://github.com/sqlparser-rs/sqlparser-rs/pull/997
marks=pytest.mark.xfail,
),
],
)
Expand Down