Skip to content

Commit

Permalink
fix: BI-5359 apply mutations to node's children before the node itself (
Browse files Browse the repository at this point in the history
  • Loading branch information
MCPN authored Oct 3, 2024
1 parent 6e2a07d commit d90c569
Show file tree
Hide file tree
Showing 7 changed files with 75 additions and 48 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -312,7 +312,7 @@ def test_workaround_for_inconsistent_agg(self, control_api, data_api, saved_data
api_v1=control_api,
dataset=saved_dataset,
formulas={
"My Field": "SUM(SUM([sales] INCLUDE [Order ID]) / SUM([sales] FIXED))",
"My Field": "SUM(SUM([sales] INCLUDE [city]) / SUM([sales] FIXED))",
},
)
result_resp = data_api.get_result(
Expand Down Expand Up @@ -829,7 +829,7 @@ def test_lod_fixed_markup(self, control_api, data_api, saved_dataset):
)
assert result_resp.status_code == HTTPStatus.OK, result_resp.json

def test_bug_bi_3425_deeply_nested_bfb(self, control_api, data_api, db, saved_connection_id):
def test_deeply_nested_bfb(self, control_api, data_api, db, saved_connection_id):
raw_data = [
{"id": 10, "city": "New York", "category": "Office Supplies", "sales": 1},
{"id": 11, "city": "New York", "category": "Office Supplies", "sales": 10},
Expand Down Expand Up @@ -977,7 +977,7 @@ def test_replace_original_dim_with_another(self, control_api, data_api, saved_da
data_rows = get_data_rows(result_resp)
assert len(data_rows) == 3 # There are 3 categories

def test_bi_4534_inconsistent_aggregation(self, control_api, data_api, saved_dataset):
def test_inconsistent_aggregation(self, control_api, data_api, saved_dataset):
ds = add_formulas_to_dataset(
api_v1=control_api,
dataset=saved_dataset,
Expand All @@ -997,7 +997,7 @@ def test_bi_4534_inconsistent_aggregation(self, control_api, data_api, saved_dat
)
assert result_resp.status_code == HTTPStatus.OK

def test_bi_4652_measure_filter_with_total_in_select(self, control_api, data_api, saved_dataset):
def test_measure_filter_with_total_in_select(self, control_api, data_api, saved_dataset):
ds = add_formulas_to_dataset(
api_v1=control_api,
dataset=saved_dataset,
Expand All @@ -1023,14 +1023,14 @@ def test_bi_4652_measure_filter_with_total_in_select(self, control_api, data_api
dim_values = [row[0] for row in data_rows]
assert len(dim_values) == len(set(dim_values)), "Dimension values are not unique"

@pytest.mark.xfail(reason="https://github.com/datalens-tech/datalens-backend/issues/98") # FIXME
def test_fixed_with_unknown_field(self, control_api, data_api, saved_dataset):
ds = add_formulas_to_dataset(
api_v1=control_api,
dataset=saved_dataset,
formulas={
"sales sum fx unknown": "SUM([sales] FIXED [unknown])",
},
exp_status=HTTPStatus.BAD_REQUEST,
)

result_resp = data_api.get_result(
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -382,7 +382,7 @@ def test_ago_with_bfb(self, control_api, data_api, saved_connection_id, db):
"sum": "SUM([int_value])",
"date_duplicate": "[date_value]",
"ago": f'AGO([sum], [date_value], "day", {day_offset})',
"ago_bfb": (f'AGO([sum], [date_value], "day", {day_offset} ' f"BEFORE FILTER BY [date_duplicate])"),
"ago_bfb": f'AGO([sum], [date_value], "day", {day_offset} ' f"BEFORE FILTER BY [date_duplicate])",
"ago_bfb_nested": (
f'AGO(AGO([sum], [date_value], "day", 1), [date_value], "day", {day_offset - 1} '
"BEFORE FILTER BY [date_duplicate])"
Expand Down
12 changes: 6 additions & 6 deletions lib/dl_formula/dl_formula/core/nodes.py
Original file line number Diff line number Diff line change
Expand Up @@ -322,17 +322,17 @@ def replace_nodes(
to_replace: dict[int, FormulaItem] = {}

for idx, child in enumerate(self.__children):
modified_child = child.replace_nodes(match_func, replace_func, parent_stack_w_self)
if modified_child is not child or modified_child != child:
child = to_replace[idx] = modified_child
is_modified = True

if match_func(child, parent_stack_w_self):
modified_child = replace_func(child, parent_stack_w_self)
if modified_child is not child or modified_child != child:
child = to_replace[idx] = modified_child
to_replace[idx] = modified_child
is_modified = True

modified_child = child.replace_nodes(match_func, replace_func, parent_stack_w_self)
if modified_child is not child or modified_child != child:
to_replace[idx] = modified_child
is_modified = True

if is_modified:
children = cast(
list[FormulaItem],
Expand Down
21 changes: 8 additions & 13 deletions lib/dl_formula/dl_formula/mutation/lod.py
Original file line number Diff line number Diff line change
Expand Up @@ -12,10 +12,7 @@
import dl_formula.core.fork_nodes as fork_nodes
import dl_formula.core.nodes as nodes
from dl_formula.inspect.expression import is_aggregate_expression
from dl_formula.inspect.node import (
is_aggregate_function,
qfork_is_aggregation,
)
from dl_formula.inspect.node import is_aggregate_function
from dl_formula.mutation.dim_resolution import DimensionResolvingMutationBase
from dl_formula.mutation.mutation import FormulaMutation
from dl_formula.shortcuts import n
Expand All @@ -29,21 +26,19 @@ class ExtAggregationToQueryForkMutation(DimensionResolvingMutationBase):
"""

def match_node(self, node: nodes.FormulaItem, parent_stack: Tuple[nodes.FormulaItem, ...]) -> bool:
is_agg = is_aggregate_function(node)
direct_parent = parent_stack[-1]
already_patched = (
isinstance(direct_parent, fork_nodes.QueryFork)
and direct_parent.result_expr is node
and qfork_is_aggregation(direct_parent) # is an aggregation fork (vs. AGO forks)
)
return is_agg and not already_patched
return is_aggregate_function(node)

def make_replacement(
self, old: nodes.FormulaItem, parent_stack: Tuple[nodes.FormulaItem, ...]
) -> nodes.FormulaItem:
assert isinstance(old, nodes.FuncCall)

dimensions, _, parent_dimension_set = self._generate_dimensions(node=old, parent_stack=parent_stack)
dimensions: list[nodes.FormulaItem]
if old.lod.list_node_type(aux_nodes.ErrorNode):
# there are errors in current LODs, propagate them
dimensions = list(old.lod.children)
else:
dimensions, _, _ = self._generate_dimensions(node=old, parent_stack=parent_stack)
lod = nodes.FixedLodSpecifier.make(dim_list=dimensions)

condition_list: List[fork_nodes.JoinConditionBase] = []
Expand Down
12 changes: 2 additions & 10 deletions lib/dl_formula/dl_formula/mutation/window.py
Original file line number Diff line number Diff line change
Expand Up @@ -16,7 +16,6 @@
is_bound_only_to,
)
from dl_formula.inspect.function import uses_default_ordering
from dl_formula.inspect.node import qfork_is_window
from dl_formula.mutation.dim_resolution import DimensionResolvingMutationBase
from dl_formula.mutation.mutation import FormulaMutation

Expand Down Expand Up @@ -135,21 +134,14 @@ class WindowFunctionToQueryForkMutation(DimensionResolvingMutationBase):
"""

def match_node(self, node: nodes.FormulaItem, parent_stack: Tuple[nodes.FormulaItem, ...]) -> bool:
is_winfunc = isinstance(node, nodes.WindowFuncCall)
direct_parent = parent_stack[-1]
already_patched = (
isinstance(direct_parent, fork_nodes.QueryFork)
and direct_parent.result_expr is node
and qfork_is_window(direct_parent)
)
return is_winfunc and not already_patched
return isinstance(node, nodes.WindowFuncCall)

def make_replacement(
self, old: nodes.FormulaItem, parent_stack: Tuple[nodes.FormulaItem, ...]
) -> nodes.FormulaItem:
assert isinstance(old, nodes.WindowFuncCall)

dimensions, _, parent_dimension_set = self._generate_dimensions(node=old, parent_stack=parent_stack)
dimensions, _, _ = self._generate_dimensions(node=old, parent_stack=parent_stack)
lod = nodes.FixedLodSpecifier.make(dim_list=dimensions)

condition_list: List[fork_nodes.JoinConditionBase] = []
Expand Down
44 changes: 44 additions & 0 deletions lib/dl_formula/dl_formula_tests/unit/mutation/test_optimization.py
Original file line number Diff line number Diff line change
Expand Up @@ -247,6 +247,50 @@ def test_optimize_if_mutation():
assert formula_obj == n.formula(n.field("then field 3"))


def test_optimize_if_mutation_with_const_comparison():
# Check false comparison result in removal of a corresponding branch
formula_obj = n.formula(
n.func.IF(
n.binary("==", left=n.lit(2), right=n.lit(3)),
n.field("then field 1"),
n.binary("!=", left=n.lit(2), right=n.lit(2)),
n.field("then field 2"),
n.field("else field"),
)
)
formula_obj = apply_mutations(
formula_obj,
mutations=[
OptimizeConstComparisonMutation(),
OptimizeConstFuncMutation(),
],
)
assert formula_obj == n.formula(n.field("else field"))

# Check true comparison result in single field result
formula_obj = n.formula(
n.func.IF(
n.field("cond 1"),
n.field("then field 1"),
n.binary("!=", left=n.lit(2), right=n.lit(2)),
n.field("then field 2"),
n.binary("==", left=n.lit(2), right=n.lit(2)),
n.field("then field 3"),
n.field("cond 4"),
n.field("then field 4"),
n.field("else field"),
)
)
formula_obj = apply_mutations(
formula_obj,
mutations=[
OptimizeConstComparisonMutation(),
OptimizeConstFuncMutation(),
],
)
assert formula_obj == n.formula(n.field("then field 3"))


def test_optimize_case_mutation():
# Check removal of false conditions
formula_obj = n.formula(
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -753,9 +753,6 @@ def apply_pre_sub_mutations(
Apply pre-substitution mutations required for window functions to be translated correctly.
"""

# prepare mutations
mutations: List[FormulaMutation] = []

# prepare default ordering (for patching RSUM, MSUM functions and the like)
default_order_by = []
ob_expr_obj: formula_nodes.FormulaItem
Expand All @@ -765,19 +762,18 @@ def apply_pre_sub_mutations(
ob_expr_obj = formula_nodes.OrderDescending.make(expr=ob_expr_obj)
default_order_by.append(ob_expr_obj)

mutations = [
IgnoreParenthesisWrapperMutation(),
ConvertBlocksToFunctionsMutation(),
DefaultWindowOrderingMutation(default_order_by=default_order_by),
LookupDefaultBfbMutation(),
]
formula_obj = apply_mutations(formula_obj, mutations=mutations)

# Only measures can contain BFB clauses
title_id_map = {f.title: f.guid for f in self._fields}
mutations.extend(
[
IgnoreParenthesisWrapperMutation(),
ConvertBlocksToFunctionsMutation(),
DefaultWindowOrderingMutation(default_order_by=default_order_by),
LookupDefaultBfbMutation(),
RemapBfbMutation(name_mapping=title_id_map),
]
)
formula_obj = apply_mutations(formula_obj, mutations=[RemapBfbMutation(name_mapping=title_id_map)])

formula_obj = apply_mutations(formula_obj, mutations=mutations)
return formula_obj

def _apply_function_by_name(self, formula_obj: formula_nodes.Formula, func_name: str) -> formula_nodes.Formula:
Expand Down

0 comments on commit d90c569

Please sign in to comment.