From 8cd51870d1900c31bff21a0b7e196dbd21e07bfc Mon Sep 17 00:00:00 2001 From: Jon Mease Date: Mon, 12 Aug 2024 07:40:05 -0400 Subject: [PATCH] DataFusion 41 fixes --- vegafusion-sql/src/compile/expr.rs | 49 +++++++------------ vegafusion-sql/src/compile/scalar.rs | 6 +-- .../src/connection/datafusion_conn.rs | 1 + 3 files changed, 21 insertions(+), 35 deletions(-) diff --git a/vegafusion-sql/src/compile/expr.rs b/vegafusion-sql/src/compile/expr.rs index c39e9a2e..d3856a07 100644 --- a/vegafusion-sql/src/compile/expr.rs +++ b/vegafusion-sql/src/compile/expr.rs @@ -3,11 +3,11 @@ use crate::compile::scalar::ToSqlScalar; use arrow::datatypes::DataType; use datafusion_common::{DFSchema, ScalarValue}; use sqlparser::ast::{ - BinaryOperator as SqlBinaryOperator, CastKind, Expr as SqlExpr, Function as SqlFunction, - FunctionArg as SqlFunctionArg, FunctionArgumentList, FunctionArguments, Ident, - ObjectName as SqlObjectName, UnaryOperator as SqlUnaryOperator, WindowFrame as SqlWindowFrame, - WindowFrameBound as SqlWindowBound, WindowFrameUnits as SqlWindowFrameUnits, - WindowSpec as SqlWindowSpec, WindowType, + BinaryOperator as SqlBinaryOperator, CastKind, DuplicateTreatment, Expr as SqlExpr, + Function as SqlFunction, FunctionArg as SqlFunctionArg, FunctionArgumentList, + FunctionArguments, Ident, ObjectName as SqlObjectName, UnaryOperator as SqlUnaryOperator, + WindowFrame as SqlWindowFrame, WindowFrameBound as SqlWindowBound, + WindowFrameUnits as SqlWindowFrameUnits, WindowSpec as SqlWindowSpec, WindowType, }; use datafusion_expr::expr::{BinaryExpr, Case, Cast, Sort}; @@ -367,11 +367,7 @@ impl ToSqlExpr for Expr { null_treatment: None, over: None, within_group: vec![], - parameters: FunctionArguments::List(FunctionArgumentList { - args: vec![], - duplicate_treatment: None, - clauses: vec![], - }), + parameters: FunctionArguments::None, })); } UnorderedRowNumberMode::OrderByConstant => { @@ -463,11 +459,7 @@ impl ToSqlExpr for Expr { null_treatment: None, over: Some(over), within_group: vec![], - parameters: FunctionArguments::List(FunctionArgumentList { - args: vec![], - duplicate_treatment: None, - clauses: vec![], - }), + parameters: FunctionArguments::None, }; Ok(SqlExpr::Function(sql_fun)) @@ -570,11 +562,7 @@ fn translate_scalar_function( null_treatment: None, over: None, within_group: vec![], - parameters: FunctionArguments::List(FunctionArgumentList { - args: vec![], - duplicate_treatment: None, - clauses: vec![], - }), + parameters: FunctionArguments::None, })) } else if let Some(transformer) = dialect.scalar_transformers.get(fun_name) { // Supported through AST transformation @@ -590,7 +578,7 @@ fn translate_scalar_function( fn translate_aggregate_function( fun_name: &str, args: &[Expr], - _distinct: bool, // Where should distinct go now? + distinct: bool, dialect: &Dialect, schema: &DFSchema, ) -> Result { @@ -600,24 +588,25 @@ fn translate_aggregate_function( quote_style: None, }; let args = translate_function_args(args, dialect, schema)?; - - Ok(SqlExpr::Function(SqlFunction { + let fn_expr = SqlExpr::Function(SqlFunction { name: SqlObjectName(vec![ident]), args: FunctionArguments::List(FunctionArgumentList { args, - duplicate_treatment: None, + duplicate_treatment: if distinct { + Some(DuplicateTreatment::Distinct) + } else { + None + }, clauses: vec![], }), filter: None, null_treatment: None, over: None, within_group: vec![], - parameters: FunctionArguments::List(FunctionArgumentList { - args: vec![], - duplicate_treatment: None, - clauses: vec![], - }), - })) + parameters: FunctionArguments::None, + }); + println!("{}", fn_expr.to_string()); + Ok(fn_expr) } else if let Some(transformer) = dialect.aggregate_transformers.get(fun_name) { // Supported through AST transformation transformer.transform(args, dialect, schema) diff --git a/vegafusion-sql/src/compile/scalar.rs b/vegafusion-sql/src/compile/scalar.rs index 3314674a..92887fc7 100644 --- a/vegafusion-sql/src/compile/scalar.rs +++ b/vegafusion-sql/src/compile/scalar.rs @@ -209,11 +209,7 @@ impl ToSqlScalar for ScalarValue { null_treatment: None, over: None, within_group: vec![], - parameters: FunctionArguments::List(FunctionArgumentList { - args: vec![], - duplicate_treatment: None, - clauses: vec![], - }), + parameters: FunctionArguments::None, })) } ScalarValue::Date32(v) => date32_to_date(v, dialect), diff --git a/vegafusion-sql/src/connection/datafusion_conn.rs b/vegafusion-sql/src/connection/datafusion_conn.rs index 1ff8400e..1b65de52 100644 --- a/vegafusion-sql/src/connection/datafusion_conn.rs +++ b/vegafusion-sql/src/connection/datafusion_conn.rs @@ -488,6 +488,7 @@ pub fn make_datafusion_context() -> SessionContext { let session_state = SessionStateBuilder::new() .with_config(config) .with_runtime_env(runtime) + .with_default_features() .build(); let ctx = SessionContext::new_with_state(session_state);