Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Improve facet lifting to handle formula and filter transforms #446

Merged
merged 2 commits into from
Jan 25, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
242 changes: 140 additions & 102 deletions vegafusion-core/src/planning/lift_facet_aggregations.rs
Original file line number Diff line number Diff line change
Expand Up @@ -64,6 +64,18 @@ impl ExtractFacetAggregationsVisitor {
}
Ok(())
}

fn num_children_of_dataset(&self, name: &str, scope: &[u32]) -> usize {
let facet_dataset_var: ScopedVariable = (Variable::new_data(&name), Vec::from(scope));
let Some(facet_dataset_idx) = self.node_indexes.get(&facet_dataset_var) else {
return 0;
};
let edges = self
.graph
.edges_directed(*facet_dataset_idx, Direction::Outgoing);
let edges_vec = edges.into_iter().collect::<Vec<_>>();
edges_vec.len()
}
}

impl MutChartVisitor for ExtractFacetAggregationsVisitor {
Expand All @@ -76,123 +88,152 @@ impl MutChartVisitor for ExtractFacetAggregationsVisitor {
};

// Check for child datasets
let facet_dataset_var: ScopedVariable = (Variable::new_data(&facet.name), Vec::from(scope));
let Some(facet_dataset_idx) = self.node_indexes.get(&facet_dataset_var) else {
return Ok(());
};
let edges = self
.graph
.edges_directed(*facet_dataset_idx, Direction::Outgoing);
let edges_vec = edges.into_iter().collect::<Vec<_>>();
if edges_vec.len() != 1 {
// We don't have exactly one child dataset so we cannot lift
let num_facet_children = self.num_children_of_dataset(&facet.name, scope);
if num_facet_children != 1 {
// We don't have exactly one child dataset, so we cannot lift
return Ok(());
}

// Collect datasets that are immediate children of the facet dataset
let mut child_datasets = mark
.data
.iter_mut()
.filter(|d| d.source.as_ref() == Some(&facet.name))
.collect::<Vec<_>>();

if child_datasets.len() != 1 {
// Child dataset isn't located in this facet's dataset.
// I don't think this shouldn't happen, but bail out in case
let mark_datasets = mark.data.clone();
let mut child_dataset = if let Some(idx) = mark_datasets
.iter()
.position(|d| d.source.as_ref() == Some(&facet.name.to_string()))
{
&mut mark.data[idx]
} else {
return Ok(());
}
};

let child_dataset = &mut child_datasets[0];
let Some(TransformSpec::Aggregate(mut agg)) = child_dataset.transform.get(0).cloned()
else {
// dataset does not have a aggregate transform as the first transform, nothing to lift
return Ok(());
let mut lifted_transforms: Vec<TransformSpec> = Vec::new();

let agg = loop {
match child_dataset.transform.get(0).cloned() {
None => {
// End of transforms for this dataset, advance to child dataset if possible
if self.num_children_of_dataset(&child_dataset.name, scope) != 1 {
break None;
}

if let Some(idx) = mark_datasets
.iter()
.position(|d| d.source.as_ref() == Some(&child_dataset.name))
{
child_dataset = &mut mark.data[idx]
} else {
break None;
};
}
Some(TransformSpec::Aggregate(agg)) => {
// Reached an aggregation, bail out
child_dataset.transform.remove(0);
break Some(agg);
}
Some(TransformSpec::Formula(tx)) => {
lifted_transforms.push(TransformSpec::Formula(tx));
child_dataset.transform.remove(0);
}
Some(TransformSpec::Filter(tx)) => {
lifted_transforms.push(TransformSpec::Filter(tx));
child_dataset.transform.remove(0);
}
_ => {
// Reached unsupported transform type without an aggregation
break None;
}
}
};

// Add facet groupby fields as aggregate transform groupby fields
let facet_groupby_fields: Vec<Field> = facet
.groupby
.clone()
.unwrap_or_default()
.to_vec()
.into_iter()
.map(Field::String)
.collect();
if let Some(mut agg) = agg {
// Add facet groupby fields as aggregate transform groupby fields
let facet_groupby_fields: Vec<Field> = facet
.groupby
.clone()
.unwrap_or_default()
.to_vec()
.into_iter()
.map(Field::String)
.collect();

agg.groupby.extend(facet_groupby_fields.clone());
agg.groupby.extend(facet_groupby_fields.clone());

let mut lifted_transforms: Vec<TransformSpec> = Vec::new();
// When the facet defines an aggregation, we need to perform it with a joinaggregate
// prior to the lifted aggregation.
//
// Leave `cross` field as-is
if let Some(facet_aggregate) = &mut facet.aggregate {
if facet_aggregate.fields.is_some()
&& facet_aggregate.ops.is_some()
&& facet_aggregate.as_.is_some()
{
// Add joinaggregate transform that performs the facet's aggregation using the same
// grouping columns as the facet
lifted_transforms.push(TransformSpec::JoinAggregate(
JoinAggregateTransformSpec {
groupby: Some(facet_groupby_fields),
fields: facet_aggregate.fields.clone().unwrap(),
ops: facet_aggregate.ops.clone().unwrap(),
as_: facet_aggregate.as_.clone(),
extra: Default::default(),
},
));

// When the facet defines an aggregation, we need to perform it with a joinaggregate
// prior to the lifted aggregation.
//
// Leave `cross` field as-is
if let Some(facet_aggregate) = &mut facet.aggregate {
if facet_aggregate.fields.is_some()
&& facet_aggregate.ops.is_some()
&& facet_aggregate.as_.is_some()
{
// Add joinaggregate transform that performs the facet's aggregation using the same
// grouping columns as the facet
lifted_transforms.push(TransformSpec::JoinAggregate(JoinAggregateTransformSpec {
groupby: Some(facet_groupby_fields),
fields: facet_aggregate.fields.clone().unwrap(),
ops: facet_aggregate.ops.clone().unwrap(),
as_: facet_aggregate.as_.clone(),
extra: Default::default(),
}));
// Add aggregations to the lifted aggregate transform that pass through the
// fields that the joinaggregate above calculates
let mut new_fields = agg.fields.clone().unwrap_or_default();
let mut new_ops = agg.ops.clone().unwrap_or_default();
let mut new_as = agg.as_.clone().unwrap_or_default();

// Add aggregations to the lifted aggregate transform that pass through the
// fields that the joinaggregate above calculates
let mut new_fields = agg.fields.clone().unwrap_or_default();
let mut new_ops = agg.ops.clone().unwrap_or_default();
let mut new_as = agg.as_.clone().unwrap_or_default();
new_fields.extend(
facet_aggregate
.as_
.clone()
.unwrap()
.into_iter()
.map(|s| s.map(Field::String)),
);
// Use min aggregate to pass through single unique value
new_ops.extend(facet_aggregate.ops.iter().map(|_| AggregateOpSpec::Min));
new_as.extend(facet_aggregate.as_.clone().unwrap());

new_fields.extend(
facet_aggregate
.as_
.clone()
.unwrap()
.into_iter()
.map(|s| s.map(Field::String)),
);
// Use min aggregate to pass through single unique value
new_ops.extend(facet_aggregate.ops.iter().map(|_| AggregateOpSpec::Min));
new_as.extend(facet_aggregate.as_.clone().unwrap());
agg.fields = Some(new_fields);
agg.ops = Some(new_ops);
agg.as_ = Some(new_as);

agg.fields = Some(new_fields);
agg.ops = Some(new_ops);
agg.as_ = Some(new_as);
// Update facet aggregate to pass through the fields compute in joinaggregate
facet_aggregate.fields = Some(
facet_aggregate
.as_
.clone()
.unwrap()
.into_iter()
.map(|s| s.map(Field::String))
.collect(),
);
facet_aggregate.ops = Some(
facet_aggregate
.ops
.iter()
.map(|_| AggregateOpSpec::Min)
.collect(),
);
} else if facet_aggregate.fields.is_some()
|| facet_aggregate.ops.is_some()
|| facet_aggregate.as_.is_some()
{
// Not all of fields, ops, and as are defined so skip lifting
return Ok(());
}
}

// Update facet aggregate to pass through the fields compute in joinaggregate
facet_aggregate.fields = Some(
facet_aggregate
.as_
.clone()
.unwrap()
.into_iter()
.map(|s| s.map(Field::String))
.collect(),
);
facet_aggregate.ops = Some(
facet_aggregate
.ops
.iter()
.map(|_| AggregateOpSpec::Min)
.collect(),
);
} else if facet_aggregate.fields.is_some()
|| facet_aggregate.ops.is_some()
|| facet_aggregate.as_.is_some()
{
// Not all of fields, ops, and as are defined so skip lifting
// Add lifted aggregate transform, potentially after the joinaggregate transform
lifted_transforms.push(TransformSpec::Aggregate(agg));
} else {
if lifted_transforms.is_empty() {
// No supported transforms found
return Ok(());
}
}

// Add lifted aggregate transform, potentially after the joinaggregate transform
lifted_transforms.push(TransformSpec::Aggregate(agg));

// Create facet dataset name and increment counter to keep names unique even if the same
// source dataset is used in multiple facets
let facet_dataset_name = format!("{}_facet_{}{}", facet.data, facet.name, self.counter);
Expand Down Expand Up @@ -223,9 +264,6 @@ impl MutChartVisitor for ExtractFacetAggregationsVisitor {
.or_default()
.push(new_dataset);

// Remove leading aggregate transform from child dataset
child_dataset.transform.remove(0);

// Rename source dataset in facet
facet.data = facet_dataset_name;

Expand Down
Original file line number Diff line number Diff line change
@@ -0,0 +1,35 @@
{
"server_to_client": [
{
"name": "data_2",
"namespace": "data",
"scope": []
},
{
"name": "data_2_x_domain_day_0",
"namespace": "data",
"scope": []
},
{
"name": "data_3",
"namespace": "data",
"scope": []
},
{
"name": "data_3_x_domain_day_1",
"namespace": "data",
"scope": []
},
{
"name": "row_domain",
"namespace": "data",
"scope": []
},
{
"name": "tips_facet_facet0",
"namespace": "data",
"scope": []
}
],
"client_to_server": []
}
Loading
Loading