From 950f0f97d8ab40175b8d2bae554e9f90b009df2d Mon Sep 17 00:00:00 2001 From: alexander-beedie Date: Tue, 31 Oct 2023 00:34:39 +0400 Subject: [PATCH] build: update `sqlparser-rs` to `0.39` --- Cargo.toml | 2 +- crates/polars-sql/src/context.rs | 4 +-- crates/polars-sql/src/sql_expr.rs | 38 ++++++++++++++++++++-------- py-polars/Cargo.lock | 4 +-- py-polars/tests/unit/sql/test_sql.py | 18 ++++++++++--- 5 files changed, 46 insertions(+), 20 deletions(-) diff --git a/Cargo.toml b/Cargo.toml index a32e5f0a7c9c..279fd36459bd 100644 --- a/Cargo.toml +++ b/Cargo.toml @@ -47,7 +47,7 @@ serde = "1.0.188" serde_json = "1" simd-json = { version = "0.12", features = ["known-key"] } smartstring = "1" -sqlparser = "0.38" +sqlparser = "0.39" strum_macros = "0.25" thiserror = "1" tokio = "1.26" diff --git a/crates/polars-sql/src/context.rs b/crates/polars-sql/src/context.rs index cbb856c20c2d..9774cbbd932c 100644 --- a/crates/polars-sql/src/context.rs +++ b/crates/polars-sql/src/context.rs @@ -223,13 +223,11 @@ impl SQLContext { concatenated.map(|lf| lf.unique(None, UniqueKeepStrategy::Any)) }, // UNION ALL BY NAME - // TODO: add recognition for SetQuantifier::DistinctByName - // when "https://github.com/sqlparser-rs/sqlparser-rs/pull/997" is available #[cfg(feature = "diagonal_concat")] SetQuantifier::AllByName => concat_lf_diagonal(vec![left, right], opts), // UNION [DISTINCT] BY NAME #[cfg(feature = "diagonal_concat")] - SetQuantifier::ByName => { + SetQuantifier::ByName | SetQuantifier::DistinctByName => { let concatenated = concat_lf_diagonal(vec![left, right], opts); concatenated.map(|lf| lf.unique(None, UniqueKeepStrategy::Any)) }, diff --git a/crates/polars-sql/src/sql_expr.rs b/crates/polars-sql/src/sql_expr.rs index a9013bb95dff..79486051f03c 100644 --- a/crates/polars-sql/src/sql_expr.rs +++ b/crates/polars-sql/src/sql_expr.rs @@ -7,9 +7,9 @@ use polars_plan::prelude::{col, lit, when}; use rand::distributions::Alphanumeric; use rand::{thread_rng, Rng}; use sqlparser::ast::{ - ArrayAgg, BinaryOperator as SQLBinaryOperator, BinaryOperator, DataType as SQLDataType, - Expr as SqlExpr, Function as SQLFunction, Ident, JoinConstraint, OrderByExpr, - Query as Subquery, SelectItem, TrimWhereField, UnaryOperator, Value as SqlValue, + ArrayAgg, ArrayElemTypeDef, BinaryOperator as SQLBinaryOperator, BinaryOperator, CastFormat, + DataType as SQLDataType, Expr as SqlExpr, Function as SQLFunction, Ident, JoinConstraint, + OrderByExpr, Query as Subquery, SelectItem, TrimWhereField, UnaryOperator, Value as SqlValue, }; use sqlparser::dialect::GenericDialect; use sqlparser::parser::{Parser, ParserOptions}; @@ -19,7 +19,8 @@ use crate::SQLContext; pub(crate) fn map_sql_polars_datatype(data_type: &SQLDataType) -> PolarsResult { Ok(match data_type { - SQLDataType::Array(Some(inner_type)) => { + SQLDataType::Array(ArrayElemTypeDef::AngleBracket(inner_type)) + | SQLDataType::Array(ArrayElemTypeDef::SquareBracket(inner_type)) => { DataType::List(Box::new(map_sql_polars_datatype(inner_type)?)) }, SQLDataType::BigInt(_) => DataType::Int64, @@ -32,7 +33,7 @@ pub(crate) fn map_sql_polars_datatype(data_type: &SQLDataType) -> PolarsResult DataType::Utf8, @@ -90,7 +91,11 @@ impl SqlExprVisitor<'_> { high, } => self.visit_between(expr, *negated, low, high), SqlExpr::BinaryOp { left, op, right } => self.visit_binary_op(left, op, right), - SqlExpr::Cast { expr, data_type } => self.visit_cast(expr, data_type), + SqlExpr::Cast { + expr, + data_type, + format, + } => self.visit_cast(expr, data_type, format), SqlExpr::Ceil { expr, .. } => Ok(self.visit_expr(expr)?.ceil()), SqlExpr::CompoundIdentifier(idents) => self.visit_compound_identifier(idents), SqlExpr::Floor { expr, .. } => Ok(self.visit_expr(expr)?.floor()), @@ -124,7 +129,8 @@ impl SqlExprVisitor<'_> { expr, trim_where, trim_what, - } => self.visit_trim(expr, trim_where, trim_what), + trim_characters, + } => self.visit_trim(expr, trim_where, trim_what, trim_characters), SqlExpr::UnaryOp { op, expr } => self.visit_unary_op(op, expr), SqlExpr::Value(value) => self.visit_literal(value), e @ SqlExpr::Case { .. } => self.visit_when_then(e), @@ -342,7 +348,15 @@ impl SqlExprVisitor<'_> { /// Visit a SQL CAST /// /// e.g. `CAST(column AS INT)` or `column::INT` - fn visit_cast(&mut self, expr: &SqlExpr, data_type: &SQLDataType) -> PolarsResult { + fn visit_cast( + &mut self, + expr: &SqlExpr, + data_type: &SQLDataType, + format: &Option, + ) -> PolarsResult { + if format.is_some() { + return Err(polars_err!(ComputeError: "unsupported use of FORMAT in CAST expression")); + } let polars_type = map_sql_polars_datatype(data_type)?; let expr = self.visit_expr(expr)?; @@ -440,7 +454,12 @@ impl SqlExprVisitor<'_> { expr: &SqlExpr, trim_where: &Option, trim_what: &Option>, + trim_characters: &Option>, ) -> PolarsResult { + if trim_characters.is_some() { + // TODO: allow compact snowflake/bigquery syntax? + return Err(polars_err!(ComputeError: "unsupported TRIM syntax")); + }; let expr = self.visit_expr(expr)?; let trim_what = trim_what.as_ref().map(|e| self.visit_expr(e)).transpose()?; let trim_what = match trim_what { @@ -448,7 +467,6 @@ impl SqlExprVisitor<'_> { None => None, _ => return self.err(&expr), }; - Ok(match (trim_where, trim_what) { (None | Some(TrimWhereField::Both), None) => expr.str().strip_chars(lit(Null)), (None | Some(TrimWhereField::Both), Some(val)) => expr.str().strip_chars(lit(val)), @@ -676,7 +694,7 @@ pub(super) fn process_join_constraint( ) -> PolarsResult<(Vec, Vec)> { if let JoinConstraint::On(SqlExpr::BinaryOp { left, op, right }) = constraint { if op != &BinaryOperator::Eq { - polars_bail!(InvalidOperation: + polars_bail!(InvalidOperation: "SQL interface (currently) only supports basic equi-join \ constraints; found '{:?}' op in\n{:?}", op, constraint) } diff --git a/py-polars/Cargo.lock b/py-polars/Cargo.lock index b9ba1326cf62..fde9cea763d1 100644 --- a/py-polars/Cargo.lock +++ b/py-polars/Cargo.lock @@ -2513,9 +2513,9 @@ checksum = "6e63cff320ae2c57904679ba7cb63280a3dc4613885beafb148ee7bf9aa9042d" [[package]] name = "sqlparser" -version = "0.38.0" +version = "0.39.0" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "0272b7bb0a225320170c99901b4b5fb3a4384e255a7f2cc228f61e2ba3893e75" +checksum = "743b4dc2cbde11890ccb254a8fc9d537fa41b36da00de2a1c5e9848c9bc42bd7" dependencies = [ "log", ] diff --git a/py-polars/tests/unit/sql/test_sql.py b/py-polars/tests/unit/sql/test_sql.py index 545a6fab8ea7..dd59125c2bfa 100644 --- a/py-polars/tests/unit/sql/test_sql.py +++ b/py-polars/tests/unit/sql/test_sql.py @@ -70,6 +70,11 @@ def test_sql_cast() -> None: (5.0, 5.0, 5, 5, 5, 1, "5", "5.5", b"e", b"e", "true"), ] + with pytest.raises(pl.ComputeError, match="unsupported use of FORMAT in CAST"): + pl.SQLContext(df=df, eager_execution=True).execute( + "SELECT CAST(a AS STRING FORMAT 'HEX') FROM df" + ) + def test_sql_any_all() -> None: df = pl.DataFrame( @@ -896,7 +901,8 @@ def test_sql_substr() -> None: def test_sql_trim(foods_ipc_path: Path) -> None: - out = pl.SQLContext(foods1=pl.scan_ipc(foods_ipc_path)).execute( + lf = pl.scan_ipc(foods_ipc_path) + out = pl.SQLContext(foods1=lf).execute( """ SELECT DISTINCT TRIM(LEADING 'vmf' FROM category) as new_category FROM foods1 @@ -907,6 +913,13 @@ def test_sql_trim(foods_ipc_path: Path) -> None: assert out.to_dict(as_series=False) == { "new_category": ["seafood", "ruit", "egetables", "eat"] } + with pytest.raises(pl.ComputeError, match="unsupported TRIM"): + # currently unsupported (snowflake) trim syntax + pl.SQLContext(foods=lf).execute( + """ + SELECT DISTINCT TRIM('*^xxxx^*', '^*') as new_category FROM foods + """, + ) @pytest.mark.parametrize( @@ -947,9 +960,6 @@ def test_sql_trim(foods_ipc_path: Path) -> None: ["c2", "c1"], "DISTINCT BY NAME", [(1, "zz"), (2, "yy"), (3, "xx")], - # TODO: Remove xfail marker when supported added in sqlparser-rs - # https://github.com/sqlparser-rs/sqlparser-rs/pull/997 - marks=pytest.mark.xfail, ), ], )