Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Re-organize and rename aggregates physical plan #2388

Merged
merged 3 commits
May 2, 2022
Merged
Show file tree
Hide file tree
Changes from 2 commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
4 changes: 2 additions & 2 deletions ballista/rust/core/proto/ballista.proto
Original file line number Diff line number Diff line change
Expand Up @@ -266,7 +266,7 @@ message PhysicalPlanNode {
ProjectionExecNode projection = 4;
GlobalLimitExecNode global_limit = 6;
LocalLimitExecNode local_limit = 7;
HashAggregateExecNode hash_aggregate = 8;
AggregateExecNode aggregate = 8;
HashJoinExecNode hash_join = 9;
ShuffleReaderExecNode shuffle_reader = 10;
SortExecNode sort = 11;
Expand Down Expand Up @@ -519,7 +519,7 @@ message WindowAggExecNode {
datafusion.Schema input_schema = 4;
}

message HashAggregateExecNode {
message AggregateExecNode {
repeated PhysicalExprNode group_expr = 1;
repeated PhysicalExprNode aggr_expr = 2;
AggregateMode mode = 3;
Expand Down
34 changes: 18 additions & 16 deletions ballista/rust/core/src/serde/physical_plan/mod.rs
Original file line number Diff line number Diff line change
Expand Up @@ -28,7 +28,8 @@ use datafusion::datasource::listing::PartitionedFile;
use datafusion::execution::runtime_env::RuntimeEnv;
use datafusion::logical_plan::window_frames::WindowFrame;
use datafusion::logical_plan::FunctionRegistry;
use datafusion::physical_plan::aggregates::create_aggregate_expr;
use datafusion::physical_plan::aggregates::AggregateExec;
use datafusion::physical_plan::aggregates::{create_aggregate_expr, AggregateMode};
use datafusion::physical_plan::coalesce_batches::CoalesceBatchesExec;
use datafusion::physical_plan::coalesce_partitions::CoalescePartitionsExec;
use datafusion::physical_plan::cross_join::CrossJoinExec;
Expand All @@ -39,7 +40,6 @@ use datafusion::physical_plan::file_format::{
AvroExec, CsvExec, FileScanConfig, ParquetExec,
};
use datafusion::physical_plan::filter::FilterExec;
use datafusion::physical_plan::hash_aggregate::{AggregateMode, HashAggregateExec};
use datafusion::physical_plan::hash_join::{HashJoinExec, PartitionMode};
use datafusion::physical_plan::limit::{GlobalLimitExec, LocalLimitExec};
use datafusion::physical_plan::projection::ProjectionExec;
Expand Down Expand Up @@ -306,19 +306,21 @@ impl AsExecutionPlan for PhysicalPlanNode {
Arc::new((&input_schema).try_into()?),
)?))
}
PhysicalPlanType::HashAggregate(hash_agg) => {
PhysicalPlanType::Aggregate(hash_agg) => {
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

👍

let input: Arc<dyn ExecutionPlan> = into_physical_plan!(
hash_agg.input,
registry,
runtime,
extension_codec
)?;
let mode = protobuf::AggregateMode::from_i32(hash_agg.mode).ok_or_else(|| {
proto_error(format!(
"Received a HashAggregateNode message with unknown AggregateMode {}",
let mode = protobuf::AggregateMode::from_i32(hash_agg.mode).ok_or_else(
|| {
proto_error(format!(
"Received a AggregateNode message with unknown AggregateMode {}",
hash_agg.mode
))
})?;
},
)?;
let agg_mode: AggregateMode = match mode {
protobuf::AggregateMode::Partial => AggregateMode::Partial,
protobuf::AggregateMode::Final => AggregateMode::Final,
Expand All @@ -341,7 +343,7 @@ impl AsExecutionPlan for PhysicalPlanNode {
.as_ref()
.ok_or_else(|| {
BallistaError::General(
"input_schema in HashAggregateNode is missing.".to_owned(),
"input_schema in AggregateNode is missing.".to_owned(),
)
})?
.clone();
Expand Down Expand Up @@ -384,14 +386,14 @@ impl AsExecutionPlan for PhysicalPlanNode {
)?)
}
_ => Err(BallistaError::General(
"Invalid aggregate expression for HashAggregateExec"
"Invalid aggregate expression for AggregateExec"
.to_string(),
)),
}
})
.collect::<Result<Vec<_>, _>>()?;

Ok(Arc::new(HashAggregateExec::try_new(
Ok(Arc::new(AggregateExec::try_new(
agg_mode,
group,
physical_aggr_expr,
Expand Down Expand Up @@ -730,7 +732,7 @@ impl AsExecutionPlan for PhysicalPlanNode {
},
))),
})
} else if let Some(exec) = plan.downcast_ref::<HashAggregateExec>() {
} else if let Some(exec) = plan.downcast_ref::<AggregateExec>() {
let groups = exec
.group_expr()
.iter()
Expand Down Expand Up @@ -768,8 +770,8 @@ impl AsExecutionPlan for PhysicalPlanNode {
extension_codec,
)?;
Ok(protobuf::PhysicalPlanNode {
physical_plan_type: Some(PhysicalPlanType::HashAggregate(Box::new(
protobuf::HashAggregateExecNode {
physical_plan_type: Some(PhysicalPlanType::Aggregate(Box::new(
protobuf::AggregateExecNode {
group_expr: groups,
group_expr_name: group_names,
aggr_expr: agg,
Expand Down Expand Up @@ -1080,12 +1082,12 @@ mod roundtrip_tests {
datasource::listing::PartitionedFile,
logical_plan::{JoinType, Operator},
physical_plan::{
aggregates::{AggregateExec, AggregateMode},
empty::EmptyExec,
expressions::{binary, col, lit, InListExpr, NotExpr},
expressions::{Avg, Column, PhysicalSortExpr},
file_format::{FileScanConfig, ParquetExec},
filter::FilterExec,
hash_aggregate::{AggregateMode, HashAggregateExec},
hash_join::{HashJoinExec, PartitionMode},
limit::{GlobalLimitExec, LocalLimitExec},
sorts::sort::SortExec,
Expand Down Expand Up @@ -1212,7 +1214,7 @@ mod roundtrip_tests {
}

#[test]
fn rountrip_hash_aggregate() -> Result<()> {
fn rountrip_aggregate() -> Result<()> {
let field_a = Field::new("a", DataType::Int64, false);
let field_b = Field::new("b", DataType::Int64, false);
let schema = Arc::new(Schema::new(vec![field_a, field_b]));
Expand All @@ -1226,7 +1228,7 @@ mod roundtrip_tests {
DataType::Float64,
))];

roundtrip_test(Arc::new(HashAggregateExec::try_new(
roundtrip_test(Arc::new(AggregateExec::try_new(
AggregateMode::Final,
groups.clone(),
aggregates.clone(),
Expand Down
6 changes: 3 additions & 3 deletions ballista/rust/core/src/utils.rs
Original file line number Diff line number Diff line change
Expand Up @@ -48,9 +48,9 @@ use datafusion::physical_plan::common::batch_byte_size;
use datafusion::physical_plan::empty::EmptyExec;

use datafusion::execution::runtime_env::{RuntimeConfig, RuntimeEnv};
use datafusion::physical_plan::aggregates::AggregateExec;
use datafusion::physical_plan::file_format::{CsvExec, ParquetExec};
use datafusion::physical_plan::filter::FilterExec;
use datafusion::physical_plan::hash_aggregate::HashAggregateExec;
use datafusion::physical_plan::hash_join::HashJoinExec;
use datafusion::physical_plan::projection::ProjectionExec;
use datafusion::physical_plan::sorts::sort::SortExec;
Expand Down Expand Up @@ -151,8 +151,8 @@ fn build_exec_plan_diagram(
id: &mut AtomicUsize,
draw_entity: bool,
) -> Result<usize> {
let operator_str = if plan.as_any().downcast_ref::<HashAggregateExec>().is_some() {
"HashAggregateExec"
let operator_str = if plan.as_any().downcast_ref::<AggregateExec>().is_some() {
"AggregateExec"
} else if plan.as_any().downcast_ref::<SortExec>().is_some() {
"SortExec"
} else if plan.as_any().downcast_ref::<ProjectionExec>().is_some() {
Expand Down
24 changes: 12 additions & 12 deletions ballista/rust/scheduler/src/planner.rs
Original file line number Diff line number Diff line change
Expand Up @@ -276,8 +276,8 @@ mod test {
use ballista_core::error::BallistaError;
use ballista_core::execution_plans::UnresolvedShuffleExec;
use ballista_core::serde::{protobuf, AsExecutionPlan, BallistaCodec};
use datafusion::physical_plan::aggregates::{AggregateExec, AggregateMode};
use datafusion::physical_plan::coalesce_batches::CoalesceBatchesExec;
use datafusion::physical_plan::hash_aggregate::{AggregateMode, HashAggregateExec};
use datafusion::physical_plan::hash_join::HashJoinExec;
use datafusion::physical_plan::sorts::sort::SortExec;
use datafusion::physical_plan::{
Expand All @@ -298,7 +298,7 @@ mod test {
}

#[tokio::test]
async fn distributed_hash_aggregate_plan() -> Result<(), BallistaError> {
async fn distributed_aggregate_plan() -> Result<(), BallistaError> {
let ctx = datafusion_test_context("testdata").await?;

// simplified form of TPC-H query 1
Expand Down Expand Up @@ -327,12 +327,12 @@ mod test {
/* Expected result:

ShuffleWriterExec: Some(Hash([Column { name: "l_returnflag", index: 0 }], 2))
HashAggregateExec: mode=Partial, gby=[l_returnflag@1 as l_returnflag], aggr=[SUM(l_extendedprice Multiply Int64(1))]
AggregateExec: mode=Partial, gby=[l_returnflag@1 as l_returnflag], aggr=[SUM(l_extendedprice Multiply Int64(1))]
CsvExec: source=Path(testdata/lineitem: [testdata/lineitem/partition0.tbl,testdata/lineitem/partition1.tbl]), has_header=false

ShuffleWriterExec: None
ProjectionExec: expr=[l_returnflag@0 as l_returnflag, SUM(lineitem.l_extendedprice Multiply Int64(1))@1 as sum_disc_price]
HashAggregateExec: mode=FinalPartitioned, gby=[l_returnflag@0 as l_returnflag], aggr=[SUM(l_extendedprice Multiply Int64(1))]
AggregateExec: mode=FinalPartitioned, gby=[l_returnflag@0 as l_returnflag], aggr=[SUM(l_extendedprice Multiply Int64(1))]
CoalesceBatchesExec: target_batch_size=4096
UnresolvedShuffleExec

Expand All @@ -346,14 +346,14 @@ mod test {

// verify stage 0
let stage0 = stages[0].children()[0].clone();
let partial_hash = downcast_exec!(stage0, HashAggregateExec);
let partial_hash = downcast_exec!(stage0, AggregateExec);
assert!(*partial_hash.mode() == AggregateMode::Partial);

// verify stage 1
let stage1 = stages[1].children()[0].clone();
let projection = downcast_exec!(stage1, ProjectionExec);
let final_hash = projection.children()[0].clone();
let final_hash = downcast_exec!(final_hash, HashAggregateExec);
let final_hash = downcast_exec!(final_hash, AggregateExec);
assert!(*final_hash.mode() == AggregateMode::FinalPartitioned);
let coalesce = final_hash.children()[0].clone();
let coalesce = downcast_exec!(coalesce, CoalesceBatchesExec);
Expand Down Expand Up @@ -449,7 +449,7 @@ order by
CsvExec: source=Path(testdata/orders: [testdata/orders/orders.tbl]), has_header=false

ShuffleWriterExec: Some(Hash([Column { name: "l_shipmode", index: 0 }], 2))
HashAggregateExec: mode=Partial, gby=[l_shipmode@4 as l_shipmode], aggr=[SUM(CASE WHEN #orders.o_orderpriority Eq Utf8("1-URGENT") Or #orders.o_orderpriority Eq Utf8("2-HIGH") THEN Int64(1) ELSE Int64(0) END), SUM(CASE WHEN #orders.o_orderpriority NotEq Utf8("1-URGENT") And #orders.o_orderpriority NotEq Utf8("2-HIGH") THEN Int64(1) ELSE Int64(0) END)]
AggregateExec: mode=Partial, gby=[l_shipmode@4 as l_shipmode], aggr=[SUM(CASE WHEN #orders.o_orderpriority Eq Utf8("1-URGENT") Or #orders.o_orderpriority Eq Utf8("2-HIGH") THEN Int64(1) ELSE Int64(0) END), SUM(CASE WHEN #orders.o_orderpriority NotEq Utf8("1-URGENT") And #orders.o_orderpriority NotEq Utf8("2-HIGH") THEN Int64(1) ELSE Int64(0) END)]
CoalesceBatchesExec: target_batch_size=4096
HashJoinExec: mode=Partitioned, join_type=Inner, on=[(Column { name: "l_orderkey", index: 0 }, Column { name: "o_orderkey", index: 0 })]
CoalesceBatchesExec: target_batch_size=4096
Expand All @@ -459,7 +459,7 @@ order by

ShuffleWriterExec: None
ProjectionExec: expr=[l_shipmode@0 as l_shipmode, SUM(CASE WHEN #orders.o_orderpriority Eq Utf8("1-URGENT") Or #orders.o_orderpriority Eq Utf8("2-HIGH") THEN Int64(1) ELSE Int64(0) END)@1 as high_line_count, SUM(CASE WHEN #orders.o_orderpriority NotEq Utf8("1-URGENT") And #orders.o_orderpriority NotEq Utf8("2-HIGH") THEN Int64(1) ELSE Int64(0) END)@2 as low_line_count]
HashAggregateExec: mode=FinalPartitioned, gby=[l_shipmode@0 as l_shipmode], aggr=[SUM(CASE WHEN #orders.o_orderpriority Eq Utf8("1-URGENT") Or #orders.o_orderpriority Eq Utf8("2-HIGH") THEN Int64(1) ELSE Int64(0) END), SUM(CASE WHEN #orders.o_orderpriority NotEq Utf8("1-URGENT") And #orders.o_orderpriority NotEq Utf8("2-HIGH") THEN Int64(1) ELSE Int64(0) END)]
AggregateExec: mode=FinalPartitioned, gby=[l_shipmode@0 as l_shipmode], aggr=[SUM(CASE WHEN #orders.o_orderpriority Eq Utf8("1-URGENT") Or #orders.o_orderpriority Eq Utf8("2-HIGH") THEN Int64(1) ELSE Int64(0) END), SUM(CASE WHEN #orders.o_orderpriority NotEq Utf8("1-URGENT") And #orders.o_orderpriority NotEq Utf8("2-HIGH") THEN Int64(1) ELSE Int64(0) END)]
CoalesceBatchesExec: target_batch_size=4096
UnresolvedShuffleExec

Expand Down Expand Up @@ -514,7 +514,7 @@ order by
.partition_count()
);

let hash_agg = downcast_exec!(input, HashAggregateExec);
let hash_agg = downcast_exec!(input, AggregateExec);

let coalesce_batches = hash_agg.children()[0].clone();
let coalesce_batches = downcast_exec!(coalesce_batches, CoalesceBatchesExec);
Expand Down Expand Up @@ -560,7 +560,7 @@ order by
}

#[tokio::test]
async fn roundtrip_serde_hash_aggregate() -> Result<(), BallistaError> {
async fn roundtrip_serde_aggregate() -> Result<(), BallistaError> {
let ctx = datafusion_test_context("testdata").await?;

// simplified form of TPC-H query 1
Expand All @@ -586,8 +586,8 @@ order by
let partial_hash = stages[0].children()[0].clone();
let partial_hash_serde = roundtrip_operator(partial_hash.clone())?;

let partial_hash = downcast_exec!(partial_hash, HashAggregateExec);
let partial_hash_serde = downcast_exec!(partial_hash_serde, HashAggregateExec);
let partial_hash = downcast_exec!(partial_hash, AggregateExec);
let partial_hash_serde = downcast_exec!(partial_hash_serde, AggregateExec);

assert_eq!(
format!("{:?}", partial_hash),
Expand Down
2 changes: 1 addition & 1 deletion datafusion/core/src/lib.rs
Original file line number Diff line number Diff line change
Expand Up @@ -158,7 +158,7 @@
//!
//! * Projection: [`ProjectionExec`](physical_plan::projection::ProjectionExec)
//! * Filter: [`FilterExec`](physical_plan::filter::FilterExec)
//! * Hash and Grouped aggregations: [`HashAggregateExec`](physical_plan::hash_aggregate::HashAggregateExec)
//! * Grouped and non-grouped aggregations: [`AggregateExec`](physical_plan::aggregates::AggregateExec)
//! * Sort: [`SortExec`](physical_plan::sorts::sort::SortExec)
//! * Coalesce partitions: [`CoalescePartitionsExec`](physical_plan::coalesce_partitions::CoalescePartitionsExec)
//! * Limit: [`LocalLimitExec`](physical_plan::limit::LocalLimitExec) and [`GlobalLimitExec`](physical_plan::limit::GlobalLimitExec)
Expand Down
Loading