Initial bump to datafusion-python 34.0.0
charlesbluca committed Jan 11, 2024
1 parent 3ffba21 commit d8e31e2
Showing 14 changed files with 456 additions and 472 deletions.
777 changes: 379 additions & 398 deletions Cargo.lock

Large diffs are not rendered by default.

4 changes: 2 additions & 2 deletions Cargo.toml
@@ -11,10 +11,10 @@ include = ["/src", "/dask_sql", "/LICENSE.txt", "pyproject.toml", "Cargo.toml",

 [dependencies]
 async-trait = "0.1.77"
-datafusion-python = { git = "https://github.com/apache/arrow-datafusion-python.git", ref = "da6c183" }
+datafusion-python = "34.0.0"
 env_logger = "0.10"
 log = "^0.4"
-pyo3 = { version = "0.19.2", features = ["extension-module", "abi3", "abi3-py39"] }
+pyo3 = { version = "0.20.0", features = ["extension-module", "abi3", "abi3-py39"] }
 pyo3-log = "0.9.0"

 [build-dependencies]
12 changes: 12 additions & 0 deletions src/dialect.rs
@@ -74,6 +74,8 @@ impl Dialect for DaskDialect {
             Ok(Some(Expr::Function(Function {
                 name: ObjectName(vec![Ident::new("timestampceil")]),
                 args,
+                filter: None,
+                null_treatment: None,
                 over: None,
                 distinct: false,
                 special: false,
@@ -106,6 +108,8 @@ impl Dialect for DaskDialect {
             Ok(Some(Expr::Function(Function {
                 name: ObjectName(vec![Ident::new("timestampfloor")]),
                 args,
+                filter: None,
+                null_treatment: None,
                 over: None,
                 distinct: false,
                 special: false,
@@ -135,6 +139,8 @@ impl Dialect for DaskDialect {
             Ok(Some(Expr::Function(Function {
                 name: ObjectName(vec![Ident::new("timestampadd")]),
                 args,
+                filter: None,
+                null_treatment: None,
                 over: None,
                 distinct: false,
                 special: false,
@@ -163,6 +169,8 @@ impl Dialect for DaskDialect {
             Ok(Some(Expr::Function(Function {
                 name: ObjectName(vec![Ident::new("timestampdiff")]),
                 args,
+                filter: None,
+                null_treatment: None,
                 over: None,
                 distinct: false,
                 special: false,
@@ -193,6 +201,8 @@ impl Dialect for DaskDialect {
             Ok(Some(Expr::Function(Function {
                 name: ObjectName(vec![Ident::new("dsql_totimestamp")]),
                 args,
+                filter: None,
+                null_treatment: None,
                 over: None,
                 distinct: false,
                 special: false,
@@ -223,6 +233,8 @@ impl Dialect for DaskDialect {
             Ok(Some(Expr::Function(Function {
                 name: ObjectName(vec![Ident::new("extract_date")]),
                 args,
+                filter: None,
+                null_treatment: None,
                 over: None,
                 distinct: false,
                 special: false,
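The recurring `filter: None, null_treatment: None` additions come from the sqlparser upgrade pulled in by datafusion-python 34: sqlparser's `ast::Function` grew fields for the aggregate FILTER (WHERE ...) clause and IGNORE/RESPECT NULLS treatment, so every hand-built `Expr::Function` must now initialize them. A minimal sketch of reading those fields back out, assuming the sqlparser 0.40-era API bundled with this release (`describe` is a hypothetical helper, not part of dask-sql):

use sqlparser::ast::{Expr, Function};

// Destructure only the fields of interest; `..` shields the sketch from
// whatever other fields the Function struct carries in this sqlparser release.
fn describe(expr: &Expr) {
    if let Expr::Function(Function {
        name,
        filter,
        null_treatment,
        ..
    }) = expr
    {
        // `filter` holds a FILTER (WHERE ...) predicate, `null_treatment` an
        // IGNORE NULLS / RESPECT NULLS marker; both default to None.
        println!("call {name}: filter={filter:?}, null_treatment={null_treatment:?}");
    }
}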
48 changes: 17 additions & 31 deletions src/expression.rs
@@ -6,15 +6,13 @@ use datafusion_python::{
     datafusion_expr::{
         expr::{
             AggregateFunction,
-            AggregateUDF,
             Alias,
             BinaryExpr,
             Cast,
             Exists,
             InList,
             InSubquery,
             ScalarFunction,
-            ScalarUDF,
             Sort,
             TryCast,
             WindowFunction,
@@ -29,6 +27,7 @@ use datafusion_python::{
         Like,
         LogicalPlan,
         Operator,
+        ScalarFunctionDefinition,
     },
     datafusion_sql::TableReference,
 };
@@ -105,10 +104,9 @@ impl PyExpr {
     fn _rex_type(&self, expr: &Expr) -> RexType {
         match expr {
             Expr::Alias(..) => RexType::Alias,
-            Expr::Column(..)
-            | Expr::QualifiedWildcard { .. }
-            | Expr::GetIndexedField { .. }
-            | Expr::Wildcard => RexType::Reference,
+            Expr::Column(..) | Expr::GetIndexedField { .. } | Expr::Wildcard { .. } => {
+                RexType::Reference
+            }
             Expr::ScalarVariable(..) | Expr::Literal(..) => RexType::Literal,
             Expr::BinaryExpr { .. }
             | Expr::Not(..)
@@ -125,9 +123,7 @@ impl PyExpr {
             | Expr::ScalarFunction { .. }
             | Expr::AggregateFunction { .. }
             | Expr::WindowFunction { .. }
-            | Expr::AggregateUDF { .. }
             | Expr::InList { .. }
-            | Expr::ScalarUDF { .. }
             | Expr::Exists { .. }
             | Expr::InSubquery { .. }
             | Expr::GroupingSet(..)
@@ -280,11 +276,8 @@ impl PyExpr {
             | Expr::AggregateFunction { .. }
             | Expr::InList { .. }
             | Expr::InSubquery { .. }
-            | Expr::ScalarUDF { .. }
-            | Expr::AggregateUDF { .. }
             | Expr::Exists { .. }
             | Expr::ScalarSubquery(..)
-            | Expr::QualifiedWildcard { .. }
             | Expr::Not(..)
             | Expr::OuterReferenceColumn(_, _)
             | Expr::GroupingSet(..) => self.expr.variant_name(),
@@ -305,7 +298,7 @@ impl PyExpr {
             | Expr::TryCast { .. }
             | Expr::WindowFunction { .. }
             | Expr::Placeholder { .. }
-            | Expr::Wildcard => {
+            | Expr::Wildcard { .. } => {
                 return Err(py_type_err(format!(
                     "Encountered unsupported expression type: {}",
                     &self.expr.variant_name()
@@ -362,9 +355,7 @@ impl PyExpr {

             // Expr variants containing a collection of Expr(s) for operands
             Expr::AggregateFunction(AggregateFunction { args, .. })
-            | Expr::AggregateUDF(AggregateUDF { args, .. })
             | Expr::ScalarFunction(ScalarFunction { args, .. })
-            | Expr::ScalarUDF(ScalarUDF { args, .. })
             | Expr::WindowFunction(WindowFunction { args, .. }) => Ok(args
                 .iter()
                 .map(|arg| PyExpr::from(arg.clone(), self.input_plan.clone()))
@@ -437,15 +428,14 @@ impl PyExpr {
                 PyExpr::from(*low.clone(), self.input_plan.clone()),
                 PyExpr::from(*high.clone(), self.input_plan.clone()),
             ]),
-            Expr::Wildcard => Ok(vec![PyExpr::from(
+            Expr::Wildcard { .. } => Ok(vec![PyExpr::from(
                 self.expr.clone(),
                 self.input_plan.clone(),
             )]),

             // Currently un-support/implemented Expr types for Rex Call operations
             Expr::GroupingSet(..)
             | Expr::OuterReferenceColumn(_, _)
-            | Expr::QualifiedWildcard { .. }
             | Expr::ScalarSubquery(..)
             | Expr::Placeholder { .. }
             | Expr::Exists { .. } => Err(py_runtime_err(format!(
@@ -463,8 +453,7 @@ impl PyExpr {
                 op,
                 right: _,
             }) => format!("{op}"),
-            Expr::ScalarFunction(ScalarFunction { fun, args: _ }) => format!("{fun}"),
-            Expr::ScalarUDF(ScalarUDF { fun, .. }) => fun.name.clone(),
+            Expr::ScalarFunction(ScalarFunction { func_def, args: _ }) => format!("{func_def:?}"),
             Expr::Cast { .. } => "cast".to_string(),
             Expr::Between { .. } => "between".to_string(),
             Expr::Case { .. } => "case".to_string(),
@@ -586,18 +575,19 @@ impl PyExpr {
                 ScalarValue::List(..) => "List",
                 ScalarValue::Struct(..) => "Struct",
                 ScalarValue::FixedSizeBinary(_, _) => "FixedSizeBinary",
-                ScalarValue::Fixedsizelist(..) => "Fixedsizelist",
+                ScalarValue::FixedSizeList(..) => "FixedSizeList",
                 ScalarValue::DurationSecond(..) => "DurationSecond",
                 ScalarValue::DurationMillisecond(..) => "DurationMillisecond",
                 ScalarValue::DurationMicrosecond(..) => "DurationMicrosecond",
                 ScalarValue::DurationNanosecond(..) => "DurationNanosecond",
+                ScalarValue::LargeList(..) => "LargeList",
             },
-            Expr::ScalarFunction(ScalarFunction { fun, args: _ }) => match fun {
-                BuiltinScalarFunction::Abs => "Abs",
-                BuiltinScalarFunction::DatePart => "DatePart",
+            Expr::ScalarFunction(ScalarFunction { func_def, args: _ }) => match func_def {
+                ScalarFunctionDefinition::BuiltIn(BuiltinScalarFunction::Abs) => "Abs",
+                ScalarFunctionDefinition::BuiltIn(BuiltinScalarFunction::DatePart) => "DatePart",
                 _ => {
                     return Err(py_type_err(format!(
-                        "Catch all triggered for ScalarFunction in get_type; {fun:?}"
+                        "Catch all triggered for ScalarFunction in get_type; {func_def:?}"
                     )))
                 }
             },
@@ -673,8 +663,7 @@ impl PyExpr {
         // TODO refactor to avoid duplication
         match &self.expr {
             Expr::Alias(Alias { expr, .. }) => match expr.as_ref() {
-                Expr::AggregateFunction(AggregateFunction { filter, .. })
-                | Expr::AggregateUDF(AggregateUDF { filter, .. }) => match filter {
+                Expr::AggregateFunction(AggregateFunction { filter, .. }) => match filter {
                     Some(filter) => {
                         Ok(Some(PyExpr::from(*filter.clone(), self.input_plan.clone())))
                     }
@@ -684,8 +673,7 @@ impl PyExpr {
                     "getFilterExpr() - Non-aggregate expression encountered",
                 )),
             },
-            Expr::AggregateFunction(AggregateFunction { filter, .. })
-            | Expr::AggregateUDF(AggregateUDF { filter, .. }) => match filter {
+            Expr::AggregateFunction(AggregateFunction { filter, .. }) => match filter {
                 Some(filter) => Ok(Some(PyExpr::from(*filter.clone(), self.input_plan.clone()))),
                 None => Ok(None),
             },
@@ -843,10 +831,8 @@ impl PyExpr {
         // TODO refactor to avoid duplication
         match &self.expr {
             Expr::AggregateFunction(funct) => Ok(funct.distinct),
-            Expr::AggregateUDF { .. } => Ok(false),
             Expr::Alias(Alias { expr, .. }) => match expr.as_ref() {
                 Expr::AggregateFunction(funct) => Ok(funct.distinct),
-                Expr::AggregateUDF { .. } => Ok(false),
                 _ => Err(py_type_err(
                     "isDistinctAgg() - Non-aggregate expression encountered",
                 )),
@@ -916,7 +902,7 @@ fn unexpected_literal_value(value: &ScalarValue) -> PyErr {
 fn get_expr_name(expr: &Expr) -> Result<String> {
     match expr {
         Expr::Alias(Alias { expr, .. }) => get_expr_name(expr),
-        Expr::Wildcard => {
+        Expr::Wildcard { .. } => {
             // 'Wildcard' means any and all columns. We get the first valid column name here
             Ok("*".to_owned())
         }
@@ -932,7 +918,7 @@ pub fn expr_to_field(expr: &Expr, input_plan: &LogicalPlan) -> Result<DFField> {
         // appear in projections) so we just delegate to the contained expression instead
         expr_to_field(expr, input_plan)
     }
-    Expr::Wildcard => {
+    Expr::Wildcard { .. } => {
         // Any column will do. We use the first column to keep things consistent
         Ok(input_plan.schema().field(0).clone())
     }
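Most of the churn above follows one upstream change: DataFusion 34 removed the separate `Expr::ScalarUDF` and `Expr::AggregateUDF` variants and folded UDFs into `Expr::ScalarFunction` / `Expr::AggregateFunction`, keyed by a `ScalarFunctionDefinition` (and `Expr::Wildcard` became a struct variant, hence the new `{ .. }` patterns). A hedged sketch of the new dispatch, mirroring what the commit does in `_rex_call_operator_name` (`scalar_function_label` is a hypothetical helper):

use datafusion_python::datafusion_expr::{expr::ScalarFunction, Expr, ScalarFunctionDefinition};

// Built-ins and UDFs now arrive through the same Expr variant; only the
// embedded definition distinguishes them.
fn scalar_function_label(expr: &Expr) -> Option<String> {
    match expr {
        Expr::ScalarFunction(ScalarFunction { func_def, .. }) => match func_def {
            // Built-in functions keep their Display names ("abs", "date_part", ...).
            ScalarFunctionDefinition::BuiltIn(fun) => Some(fun.to_string()),
            // Fall back to the Debug form for UDFs and any other variants,
            // exactly as the commit itself does.
            other => Some(format!("{other:?}")),
        },
        _ => None,
    }
}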
2 changes: 1 addition & 1 deletion src/parser.rs
@@ -1101,7 +1101,7 @@ impl<'a> DaskParser<'a> {
         } else {
             self.parser.prev_token();
             Ok(DaskStatement::Statement(Box::from(
-                self.parser.parse_create_view(or_replace)?,
+                self.parser.parse_create_view(or_replace, false)?,
             )))
         }
     }
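The extra `false` argument tracks a sqlparser signature change: `Parser::parse_create_view` now takes a second boolean, which (an assumption based on the upstream sqlparser release, not stated in this commit) flags TEMPORARY views; passing `false` preserves the old behavior. A quick, self-contained way to exercise the parse path dask-sql delegates to:

use sqlparser::{dialect::GenericDialect, parser::Parser};

fn main() -> Result<(), sqlparser::parser::ParserError> {
    // DaskParser hands plain CREATE [OR REPLACE] VIEW statements to
    // sqlparser; this drives the same code path end to end.
    let statements =
        Parser::parse_sql(&GenericDialect {}, "CREATE OR REPLACE VIEW v AS SELECT 1 AS x")?;
    println!("{:#?}", statements[0]);
    Ok(())
}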
2 changes: 1 addition & 1 deletion src/sql.rs
@@ -91,7 +91,7 @@ pub struct DaskSQLContext {
 }

 impl ContextProvider for DaskSQLContext {
-    fn get_table_provider(
+    fn get_table_source(
         &self,
         name: TableReference,
     ) -> Result<Arc<dyn TableSource>, DataFusionError> {
6 changes: 3 additions & 3 deletions src/sql/logical.rs
@@ -37,7 +37,7 @@ pub mod window;

 use datafusion_python::{
     datafusion_common::{DFSchemaRef, DataFusionError},
-    datafusion_expr::{DdlStatement, LogicalPlan},
+    datafusion_expr::{logical_plan::Distinct, DdlStatement, LogicalPlan},
 };
 use pyo3::prelude::*;

@@ -411,8 +411,8 @@ impl PyLogicalPlan {
                 lhs_fields.append(&mut rhs_fields);
                 Ok(RelDataType::new(false, lhs_fields))
             }
-            LogicalPlan::Distinct(distinct) => {
-                let schema = distinct.input.schema();
+            LogicalPlan::Distinct(Distinct::All(input)) => {
+                let schema = input.schema();
                 let rel_fields: Vec<RelDataTypeField> = schema
                     .fields()
                     .iter()
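`LogicalPlan::Distinct` stopped wrapping a plain struct in DataFusion 34: it now carries a `Distinct` enum so that DISTINCT ON (...) can be represented separately from plain SELECT DISTINCT, which is why the match arm destructures `Distinct::All(input)`. A sketch of exhaustive handling under that assumption (the `On` arm and its `DistinctOn` shape are inferred from the upstream API, not from this commit):

use datafusion_python::datafusion_expr::logical_plan::{Distinct, LogicalPlan};

// Resolve the schema field names beneath either flavor of DISTINCT.
fn distinct_field_names(plan: &LogicalPlan) -> Vec<String> {
    match plan {
        // Plain SELECT DISTINCT: the input plan carries the schema.
        LogicalPlan::Distinct(Distinct::All(input)) => input.schema().field_names(),
        // SELECT DISTINCT ON (...): a dedicated node with its own input.
        LogicalPlan::Distinct(Distinct::On(distinct_on)) => {
            distinct_on.input.schema().field_names()
        }
        _ => vec![],
    }
}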
18 changes: 7 additions & 11 deletions src/sql/logical/aggregate.rs
Original file line number Diff line number Diff line change
@@ -1,5 +1,5 @@
 use datafusion_python::datafusion_expr::{
-    expr::{AggregateFunction, AggregateUDF, Alias},
+    expr::{AggregateFunction, Alias},
     logical_plan::{Aggregate, Distinct},
     Expr,
     LogicalPlan,
@@ -24,8 +24,8 @@ impl PyAggregate {
     #[pyo3(name = "getDistinctColumns")]
     pub fn distinct_columns(&self) -> PyResult<Vec<String>> {
         match &self.distinct {
-            Some(e) => Ok(e.input.schema().field_names()),
-            None => Err(py_type_err(
+            Some(Distinct::All(input)) => Ok(input.schema().field_names()),
+            _ => Err(py_type_err(
                 "distinct_columns invoked for non distinct instance",
             )),
         }
@@ -74,8 +74,9 @@ impl PyAggregate {
     fn _aggregation_arguments(&self, expr: &Expr) -> PyResult<Vec<PyExpr>> {
         match expr {
             Expr::Alias(Alias { expr, .. }) => self._aggregation_arguments(expr.as_ref()),
-            Expr::AggregateFunction(AggregateFunction { fun: _, args, .. })
-            | Expr::AggregateUDF(AggregateUDF { fun: _, args, .. }) => match &self.aggregate {
+            Expr::AggregateFunction(AggregateFunction {
+                func_def: _, args, ..
+            }) => match &self.aggregate {
                 Some(e) => py_expr_list(&e.input, args),
                 None => Ok(vec![]),
             },
@@ -89,8 +90,7 @@ impl PyAggregate {
 fn _agg_func_name(expr: &Expr) -> PyResult<String> {
     match expr {
         Expr::Alias(Alias { expr, .. }) => _agg_func_name(expr.as_ref()),
-        Expr::AggregateFunction(AggregateFunction { fun, .. }) => Ok(fun.to_string()),
-        Expr::AggregateUDF(AggregateUDF { fun, .. }) => Ok(fun.name.clone()),
+        Expr::AggregateFunction(AggregateFunction { func_def, .. }) => Ok(format!("{func_def:?}")),
         _ => Err(py_type_err(
             "Encountered a non Aggregate type in agg_func_name",
         )),
@@ -101,10 +101,6 @@ fn _distinct_agg_expr(expr: &Expr) -> PyResult<bool> {
     match expr {
         Expr::Alias(Alias { expr, .. }) => _distinct_agg_expr(expr.as_ref()),
         Expr::AggregateFunction(AggregateFunction { distinct, .. }) => Ok(*distinct),
-        Expr::AggregateUDF { .. } => {
-            // DataFusion does not support DISTINCT in UDAFs
-            Ok(false)
-        }
         _ => Err(py_type_err(
             "Encountered a non Aggregate type in distinct_agg_expr",
         )),
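The same consolidation applies to aggregates: `Expr::AggregateUDF` is gone, and `AggregateFunction` now carries a `func_def` describing either a built-in or a UDAF. The commit falls back to the Debug form for the name; a slightly friendlier sketch, assuming the `AggregateFunctionDefinition` variants mirror the scalar side (`agg_func_label` is a hypothetical helper):

use datafusion_python::datafusion_expr::expr::{AggregateFunction, AggregateFunctionDefinition};
use datafusion_python::datafusion_expr::Expr;

// One match arm now covers what used to be two Expr variants.
fn agg_func_label(expr: &Expr) -> Option<String> {
    match expr {
        Expr::AggregateFunction(AggregateFunction { func_def, .. }) => match func_def {
            // Built-in aggregates (SUM, COUNT, ...) keep their Display names.
            AggregateFunctionDefinition::BuiltIn(fun) => Some(fun.to_string()),
            // UDAFs and anything else fall back to Debug, as the commit does.
            other => Some(format!("{other:?}")),
        },
        _ => None,
    }
}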
2 changes: 1 addition & 1 deletion src/sql/logical/projection.rs
@@ -40,7 +40,7 @@ impl PyProjection {
             PyExpr::from(expression, Some(vec![self.projection.input.clone()]));
         for expr in self.projected_expressions(&py_expr) {
             match expr.expr {
-                Expr::Alias(Alias { expr, name }) => named.push((
+                Expr::Alias(Alias { expr, name, .. }) => named.push((
                     name.to_string(),
                     PyExpr::from(*expr, Some(vec![self.projection.input.clone()])),
                 )),
8 changes: 3 additions & 5 deletions src/sql/optimizer.rs
@@ -14,12 +14,10 @@ use datafusion_python::{
         eliminate_cross_join::EliminateCrossJoin,
         eliminate_limit::EliminateLimit,
         eliminate_outer_join::EliminateOuterJoin,
-        eliminate_project::EliminateProjection,
         filter_null_join_keys::FilterNullJoinKeys,
         optimizer::{Optimizer, OptimizerRule},
         push_down_filter::PushDownFilter,
         push_down_limit::PushDownLimit,
-        push_down_projection::PushDownProjection,
         rewrite_disjunctive_predicate::RewriteDisjunctivePredicate,
         scalar_subquery_to_join::ScalarSubqueryToJoin,
         simplify_expressions::SimplifyExpressions,
@@ -81,8 +79,8 @@ impl DaskSqlOptimizer {
             Arc::new(SimplifyExpressions::new()),
             Arc::new(UnwrapCastInComparison::new()),
             // Arc::new(CommonSubexprEliminate::new()),
-            Arc::new(PushDownProjection::new()),
-            Arc::new(EliminateProjection::new()),
+            // Arc::new(PushDownProjection::new()),
+            // Arc::new(EliminateProjection::new()),
             // PushDownProjection can pushdown Projections through Limits, do PushDownLimit again.
             Arc::new(PushDownLimit::new()),
         ];
@@ -191,7 +189,7 @@ mod tests {
             &self.options
         }

-        fn get_table_provider(
+        fn get_table_source(
             &self,
             name: TableReference,
         ) -> datafusion_python::datafusion_common::Result<Arc<dyn TableSource>> {
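The projection rules are commented out rather than deleted because DataFusion 34 removed `PushDownProjection` and `EliminateProjection` outright (note the dropped imports above); upstream consolidated that work into a single `OptimizeProjections` rule. If that rule is exposed by this datafusion-python release (an assumption; the commit itself leaves projection pruning disabled), re-enabling it could look like:

use std::sync::Arc;

use datafusion_python::datafusion_optimizer::{
    optimize_projections::OptimizeProjections, optimizer::OptimizerRule,
};

// Hypothetical replacement entry for the rule list in DaskSqlOptimizer::new.
fn projection_pruning_rule() -> Arc<dyn OptimizerRule + Send + Sync> {
    Arc::new(OptimizeProjections::new())
}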
