Commit

Merge remote-tracking branch 'apache/main' into lowercase-count
alamb committed Jun 21, 2024
2 parents 287a00d + 098ba30 · commit 74c0411
Showing 40 changed files with 1,234 additions and 728 deletions.
48 changes: 26 additions & 22 deletions datafusion-examples/examples/rewrite_expr.rs
@@ -18,13 +18,13 @@
 use arrow::datatypes::{DataType, Field, Schema, SchemaRef};
 use datafusion_common::config::ConfigOptions;
 use datafusion_common::tree_node::{Transformed, TransformedResult, TreeNode};
-use datafusion_common::{plan_err, Result, ScalarValue};
+use datafusion_common::{plan_err, DataFusionError, Result, ScalarValue};
 use datafusion_expr::{
     AggregateUDF, Between, Expr, Filter, LogicalPlan, ScalarUDF, TableSource, WindowUDF,
 };
 use datafusion_optimizer::analyzer::{Analyzer, AnalyzerRule};
-use datafusion_optimizer::optimizer::Optimizer;
-use datafusion_optimizer::{utils, OptimizerConfig, OptimizerContext, OptimizerRule};
+use datafusion_optimizer::optimizer::{ApplyOrder, Optimizer};
+use datafusion_optimizer::{OptimizerConfig, OptimizerContext, OptimizerRule};
 use datafusion_sql::planner::{ContextProvider, SqlToRel};
 use datafusion_sql::sqlparser::dialect::PostgreSqlDialect;
 use datafusion_sql::sqlparser::parser::Parser;
@@ -133,30 +133,34 @@ impl OptimizerRule for MyOptimizerRule {

     fn try_optimize(
         &self,
-        plan: &LogicalPlan,
-        config: &dyn OptimizerConfig,
+        _plan: &LogicalPlan,
+        _config: &dyn OptimizerConfig,
     ) -> Result<Option<LogicalPlan>> {
-        // recurse down and optimize children first
-        let optimized_plan = utils::optimize_children(self, plan, config)?;
-        match optimized_plan {
-            Some(LogicalPlan::Filter(filter)) => {
+        unreachable!()
+    }
+
+    fn apply_order(&self) -> Option<ApplyOrder> {
+        Some(ApplyOrder::BottomUp)
+    }
+
+    fn supports_rewrite(&self) -> bool {
+        true
+    }
+
+    fn rewrite(
+        &self,
+        plan: LogicalPlan,
+        _config: &dyn OptimizerConfig,
+    ) -> Result<Transformed<LogicalPlan>, DataFusionError> {
+        match plan {
+            LogicalPlan::Filter(filter) => {
                 let predicate = my_rewrite(filter.predicate.clone())?;
-                Ok(Some(LogicalPlan::Filter(Filter::try_new(
+                Ok(Transformed::yes(LogicalPlan::Filter(Filter::try_new(
                     predicate,
-                    filter.input,
+                    filter.input.clone(),
                 )?)))
             }
-            Some(optimized_plan) => Ok(Some(optimized_plan)),
-            None => match plan {
-                LogicalPlan::Filter(filter) => {
-                    let predicate = my_rewrite(filter.predicate.clone())?;
-                    Ok(Some(LogicalPlan::Filter(Filter::try_new(
-                        predicate,
-                        filter.input.clone(),
-                    )?)))
-                }
-                _ => Ok(None),
-            },
+            _ => Ok(Transformed::no(plan)),
         }
     }
 }
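The hunk above migrates the example rule from the old `try_optimize` API, which returned `Option<LogicalPlan>` and had to recurse into children itself via `utils::optimize_children`, to the `rewrite` API: the optimizer now drives the recursion (bottom-up here, per `apply_order`), and the rule reports whether it changed the node through `Transformed::yes`/`Transformed::no`. As a minimal sketch of driving such a rule, mirroring the wiring already present in this example file and assuming the owned-plan `Optimizer::optimize` signature of this DataFusion release:

    use std::sync::Arc;
    use datafusion_common::Result;
    use datafusion_expr::LogicalPlan;
    use datafusion_optimizer::optimizer::Optimizer;
    use datafusion_optimizer::{OptimizerContext, OptimizerRule};

    fn run_my_rule(plan: LogicalPlan) -> Result<LogicalPlan> {
        // register the rule; the optimizer walks the plan and calls
        // `rewrite` on each node because `supports_rewrite` is true
        let optimizer = Optimizer::with_rules(vec![Arc::new(MyOptimizerRule {})]);
        let config = OptimizerContext::default();
        // observer callback invoked after each rule application (no-op here)
        optimizer.optimize(plan, &config, |_plan, _rule| {})
    }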
83 changes: 81 additions & 2 deletions datafusion/common/src/tree_node.rs
@@ -582,7 +582,11 @@ impl TreeNodeRecursion {

 /// Result of tree walk / transformation APIs
 ///
-/// API users control the transformation by returning:
+/// `Transformed` is a wrapper around the tree node data (e.g. `Expr` or
+/// `LogicalPlan`). It is used to indicate whether the node was transformed
+/// and how the recursion should proceed.
+///
+/// [`TreeNode`] API users control the transformation by returning:
 /// - The resulting (possibly transformed) node,
 /// - `transformed`: flag indicating whether any change was made to the node
 /// - `tnr`: [`TreeNodeRecursion`] specifying how to proceed with the recursion.
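The three items listed above correspond directly to the struct's public fields (`data`, `transformed`, `tnr`). A hedged sketch of constructing one explicitly, equivalent to `Transformed::yes` and assuming the `Transformed::new` constructor and the `TreeNodeRecursion::Continue` variant:

    use datafusion_common::tree_node::{Transformed, TreeNodeRecursion};

    // payload, `transformed` flag, and how the walk should proceed
    let t = Transformed::new(42_i64, true, TreeNodeRecursion::Continue);
    assert_eq!(t.data, 42);
    assert!(t.transformed);
    assert!(matches!(t.tnr, TreeNodeRecursion::Continue));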
@@ -592,7 +596,66 @@ impl TreeNodeRecursion {
 /// - `transformed`: flag indicating whether any change was made to the node
 /// - `tnr`: [`TreeNodeRecursion`] specifying how the recursion ended.
 ///
-/// Example APIs:
+/// See also
+/// * [`Transformed::update_data`] to modify the node without changing the `transformed` flag
+/// * [`Transformed::map_data`] for fallible operations that return the same type
+/// * [`Transformed::transform_data`] to chain fallible transformations
+/// * [`TransformedResult`] for working with `Result<Transformed<U>>`
+///
+/// # Examples
+///
+/// Use [`Transformed::yes`] and [`Transformed::no`] to signal that a node was
+/// rewritten and the recursion should continue:
+///
+/// ```
+/// # use datafusion_common::tree_node::Transformed;
+/// # // note use i64 instead of Expr as Expr is not in datafusion-common
+/// # fn orig_expr() -> i64 { 1 }
+/// # fn make_new_expr(i: i64) -> i64 { 2 }
+/// let expr = orig_expr();
+///
+/// // Create a new `Transformed` object signaling the node was not rewritten
+/// let ret = Transformed::no(expr.clone());
+/// assert!(!ret.transformed);
+///
+/// // Create a new `Transformed` object signaling the node was rewritten
+/// let ret = Transformed::yes(expr);
+/// assert!(ret.transformed)
+/// ```
+///
+/// Access the node within the `Transformed` object:
+/// ```
+/// # use datafusion_common::tree_node::Transformed;
+/// # // note use i64 instead of Expr as Expr is not in datafusion-common
+/// # fn orig_expr() -> i64 { 1 }
+/// # fn make_new_expr(i: i64) -> i64 { 2 }
+/// let expr = orig_expr();
+///
+/// // `Transformed` object signaling the node was not rewritten
+/// let ret = Transformed::no(expr.clone());
+/// // Access the inner object using .data
+/// assert_eq!(expr, ret.data);
+/// ```
+///
+/// Transform the node within the `Transformed` object.
+///
+/// ```
+/// # use datafusion_common::tree_node::Transformed;
+/// # // note use i64 instead of Expr as Expr is not in datafusion-common
+/// # fn orig_expr() -> i64 { 1 }
+/// # fn make_new_expr(i: i64) -> i64 { 2 }
+/// let expr = orig_expr();
+/// let ret = Transformed::no(expr.clone())
+///     .transform_data(|expr| {
+///         // closure returns a result and potentially transforms the node
+///         // in this example, it does transform the node
+///         let new_expr = make_new_expr(expr);
+///         Ok(Transformed::yes(new_expr))
+///     }).unwrap();
+/// // the transformed flag is the union of the original and the closure's transformed flags
+/// assert!(ret.transformed);
+/// ```
+/// # Example APIs that use `TreeNode`
 /// - [`TreeNode`],
 /// - [`TreeNode::rewrite`],
 /// - [`TreeNode::transform_down`],
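The "See also" list above names `update_data` and `map_data`, but the doctests do not exercise them; a short sketch in the same i64-instead-of-`Expr` style, assuming the signatures referenced in that list:

    use datafusion_common::tree_node::Transformed;

    // `update_data` maps the payload with an infallible closure and leaves
    // the `transformed` flag untouched
    let ret = Transformed::no(1_i64).update_data(|v| v + 1);
    assert_eq!(ret.data, 2);
    assert!(!ret.transformed);

    // `map_data` applies a fallible closure, propagating any error and
    // otherwise preserving the flag
    let ret = Transformed::yes(2_i64).map_data(|v| Ok(v * 10)).unwrap();
    assert_eq!(ret.data, 20);
    assert!(ret.transformed);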
@@ -833,6 +896,22 @@ macro_rules! map_until_stop_and_collect {
 }
 
 /// Transformation helper to access [`Transformed`] fields in a [`Result`] easily.
+///
+/// # Example
+/// Access the internal data of a `Result<Transformed<T>>`
+/// as a `Result<T>` using the `data` method:
+/// ```
+/// # use datafusion_common::Result;
+/// # use datafusion_common::tree_node::{Transformed, TransformedResult};
+/// # // note use i64 instead of Expr as Expr is not in datafusion-common
+/// # fn update_expr() -> i64 { 1 }
+/// # fn main() -> Result<()> {
+/// let transformed: Result<Transformed<_>> = Ok(Transformed::yes(update_expr()));
+/// // access the internal data of the transformed result, or return the error
+/// let transformed_expr = transformed.data()?;
+/// # Ok(())
+/// # }
+/// ```
 pub trait TransformedResult<T> {
     fn data(self) -> Result<T>;
 
169 changes: 169 additions & 0 deletions datafusion/core/src/dataframe/mod.rs
@@ -244,6 +244,42 @@ impl DataFrame {
         })
     }
 
+    /// Returns a new DataFrame containing all columns except the specified columns.
+    ///
+    /// ```
+    /// # use datafusion::prelude::*;
+    /// # use datafusion::error::Result;
+    /// # #[tokio::main]
+    /// # async fn main() -> Result<()> {
+    /// let ctx = SessionContext::new();
+    /// let df = ctx.read_csv("tests/data/example.csv", CsvReadOptions::new()).await?;
+    /// let df = df.drop_columns(&["a"])?;
+    /// # Ok(())
+    /// # }
+    /// ```
+    pub fn drop_columns(self, columns: &[&str]) -> Result<DataFrame> {
+        let fields_to_drop = columns
+            .iter()
+            .map(|name| {
+                self.plan
+                    .schema()
+                    .qualified_field_with_unqualified_name(name)
+            })
+            .filter(|r| r.is_ok())
+            .collect::<Result<Vec<_>>>()?;
+        let expr: Vec<Expr> = self
+            .plan
+            .schema()
+            .fields()
+            .into_iter()
+            .enumerate()
+            .map(|(idx, _)| self.plan.schema().qualified_field(idx))
+            .filter(|(qualifier, f)| !fields_to_drop.contains(&(*qualifier, f)))
+            .map(|(qualifier, field)| Expr::Column(Column::from((qualifier, field))))
+            .collect();
+        self.select(expr)
+    }
+
     /// Expand each list element of a column to multiple rows.
     #[deprecated(since = "37.0.0", note = "use unnest_columns instead")]
     pub fn unnest_column(self, column: &str) -> Result<DataFrame> {
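Note that `drop_columns` filters on `r.is_ok()` before collecting, so names that do not resolve to a field simply never enter `fields_to_drop`: unknown columns are ignored rather than raising an error (the tests below pin this down). A usage sketch, reusing the file path from the doc example above:

    use datafusion::error::Result;
    use datafusion::prelude::*;

    async fn drop_some_columns() -> Result<DataFrame> {
        let ctx = SessionContext::new();
        let df = ctx
            .read_csv("tests/data/example.csv", CsvReadOptions::new())
            .await?;
        // "no_such_column" does not exist, so only "a" is dropped
        df.drop_columns(&["a", "no_such_column"])
    }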
@@ -1799,6 +1835,139 @@ mod tests {
         Ok(())
     }
 
+    #[tokio::test]
+    async fn drop_columns() -> Result<()> {
+        // build plan using Table API
+        let t = test_table().await?;
+        let t2 = t.drop_columns(&["c2", "c11"])?;
+        let plan = t2.plan.clone();
+
+        // build query using SQL
+        let sql_plan = create_plan(
+            "SELECT c1,c3,c4,c5,c6,c7,c8,c9,c10,c12,c13 FROM aggregate_test_100",
+        )
+        .await?;
+
+        // the two plans should be identical
+        assert_same_plan(&plan, &sql_plan);
+
+        Ok(())
+    }
+
+    #[tokio::test]
+    async fn drop_columns_with_duplicates() -> Result<()> {
+        // build plan using Table API
+        let t = test_table().await?;
+        let t2 = t.drop_columns(&["c2", "c11", "c2", "c2"])?;
+        let plan = t2.plan.clone();
+
+        // build query using SQL
+        let sql_plan = create_plan(
+            "SELECT c1,c3,c4,c5,c6,c7,c8,c9,c10,c12,c13 FROM aggregate_test_100",
+        )
+        .await?;
+
+        // the two plans should be identical
+        assert_same_plan(&plan, &sql_plan);
+
+        Ok(())
+    }
+
+    #[tokio::test]
+    async fn drop_columns_with_nonexistent_columns() -> Result<()> {
+        // build plan using Table API
+        let t = test_table().await?;
+        let t2 = t.drop_columns(&["canada", "c2", "rocks"])?;
+        let plan = t2.plan.clone();
+
+        // build query using SQL
+        let sql_plan = create_plan(
+            "SELECT c1,c3,c4,c5,c6,c7,c8,c9,c10,c11,c12,c13 FROM aggregate_test_100",
+        )
+        .await?;
+
+        // the two plans should be identical
+        assert_same_plan(&plan, &sql_plan);
+
+        Ok(())
+    }
+
+    #[tokio::test]
+    async fn drop_columns_with_empty_array() -> Result<()> {
+        // build plan using Table API
+        let t = test_table().await?;
+        let t2 = t.drop_columns(&[])?;
+        let plan = t2.plan.clone();
+
+        // build query using SQL
+        let sql_plan = create_plan(
+            "SELECT c1,c2,c3,c4,c5,c6,c7,c8,c9,c10,c11,c12,c13 FROM aggregate_test_100",
+        )
+        .await?;
+
+        // the two plans should be identical
+        assert_same_plan(&plan, &sql_plan);
+
+        Ok(())
+    }
+
+    #[tokio::test]
+    async fn drop_with_quotes() -> Result<()> {
+        // define data with a column name that has a quote in it:
+        let array1: Int32Array = [1, 10].into_iter().collect();
+        let array2: Int32Array = [2, 11].into_iter().collect();
+        let batch = RecordBatch::try_from_iter(vec![
+            ("f\"c1", Arc::new(array1) as _),
+            ("f\"c2", Arc::new(array2) as _),
+        ])?;
+
+        let ctx = SessionContext::new();
+        ctx.register_batch("t", batch)?;
+
+        let df = ctx.table("t").await?.drop_columns(&["f\"c1"])?;
+
+        let df_results = df.collect().await?;
+
+        assert_batches_sorted_eq!(
+            [
+                "+------+",
+                "| f\"c2 |",
+                "+------+",
+                "| 2    |",
+                "| 11   |",
+                "+------+"
+            ],
+            &df_results
+        );
+
+        Ok(())
+    }
+
+    #[tokio::test]
+    async fn drop_with_periods() -> Result<()> {
+        // define data with a column name that has a "." in it:
+        let array1: Int32Array = [1, 10].into_iter().collect();
+        let array2: Int32Array = [2, 11].into_iter().collect();
+        let batch = RecordBatch::try_from_iter(vec![
+            ("f.c1", Arc::new(array1) as _),
+            ("f.c2", Arc::new(array2) as _),
+        ])?;
+
+        let ctx = SessionContext::new();
+        ctx.register_batch("t", batch)?;
+
+        let df = ctx.table("t").await?.drop_columns(&["f.c1"])?;
+
+        let df_results = df.collect().await?;
+
+        assert_batches_sorted_eq!(
+            ["+------+", "| f.c2 |", "+------+", "| 2    |", "| 11   |", "+------+"],
+            &df_results
+        );
+
+        Ok(())
+    }
+
     #[tokio::test]
     async fn aggregate() -> Result<()> {
         // build plan using DataFrame API