diff --git a/lib/dl_api_lib/dl_api_lib_tests/db/data_api/result/complex_queries/test_ext_agg_basic.py b/lib/dl_api_lib/dl_api_lib_tests/db/data_api/result/complex_queries/test_ext_agg_basic.py index 9291f1332..4de7afb15 100644 --- a/lib/dl_api_lib/dl_api_lib_tests/db/data_api/result/complex_queries/test_ext_agg_basic.py +++ b/lib/dl_api_lib/dl_api_lib_tests/db/data_api/result/complex_queries/test_ext_agg_basic.py @@ -312,7 +312,7 @@ def test_workaround_for_inconsistent_agg(self, control_api, data_api, saved_data api_v1=control_api, dataset=saved_dataset, formulas={ - "My Field": "SUM(SUM([sales] INCLUDE [Order ID]) / SUM([sales] FIXED))", + "My Field": "SUM(SUM([sales] INCLUDE [city]) / SUM([sales] FIXED))", }, ) result_resp = data_api.get_result( @@ -829,7 +829,7 @@ def test_lod_fixed_markup(self, control_api, data_api, saved_dataset): ) assert result_resp.status_code == HTTPStatus.OK, result_resp.json - def test_bug_bi_3425_deeply_nested_bfb(self, control_api, data_api, db, saved_connection_id): + def test_deeply_nested_bfb(self, control_api, data_api, db, saved_connection_id): raw_data = [ {"id": 10, "city": "New York", "category": "Office Supplies", "sales": 1}, {"id": 11, "city": "New York", "category": "Office Supplies", "sales": 10}, @@ -977,7 +977,7 @@ def test_replace_original_dim_with_another(self, control_api, data_api, saved_da data_rows = get_data_rows(result_resp) assert len(data_rows) == 3 # There are 3 categories - def test_bi_4534_inconsistent_aggregation(self, control_api, data_api, saved_dataset): + def test_inconsistent_aggregation(self, control_api, data_api, saved_dataset): ds = add_formulas_to_dataset( api_v1=control_api, dataset=saved_dataset, @@ -997,7 +997,7 @@ def test_bi_4534_inconsistent_aggregation(self, control_api, data_api, saved_dat ) assert result_resp.status_code == HTTPStatus.OK - def test_bi_4652_measure_filter_with_total_in_select(self, control_api, data_api, saved_dataset): + def test_measure_filter_with_total_in_select(self, control_api, data_api, saved_dataset): ds = add_formulas_to_dataset( api_v1=control_api, dataset=saved_dataset, @@ -1023,7 +1023,6 @@ def test_bi_4652_measure_filter_with_total_in_select(self, control_api, data_api dim_values = [row[0] for row in data_rows] assert len(dim_values) == len(set(dim_values)), "Dimension values are not unique" - @pytest.mark.xfail(reason="https://github.com/datalens-tech/datalens-backend/issues/98") # FIXME def test_fixed_with_unknown_field(self, control_api, data_api, saved_dataset): ds = add_formulas_to_dataset( api_v1=control_api, @@ -1031,6 +1030,7 @@ def test_fixed_with_unknown_field(self, control_api, data_api, saved_dataset): formulas={ "sales sum fx unknown": "SUM([sales] FIXED [unknown])", }, + exp_status=HTTPStatus.BAD_REQUEST, ) result_resp = data_api.get_result( diff --git a/lib/dl_api_lib/dl_api_lib_tests/db/data_api/result/complex_queries/test_lookup_functions.py b/lib/dl_api_lib/dl_api_lib_tests/db/data_api/result/complex_queries/test_lookup_functions.py index 228c1d4dd..9356995aa 100644 --- a/lib/dl_api_lib/dl_api_lib_tests/db/data_api/result/complex_queries/test_lookup_functions.py +++ b/lib/dl_api_lib/dl_api_lib_tests/db/data_api/result/complex_queries/test_lookup_functions.py @@ -382,7 +382,7 @@ def test_ago_with_bfb(self, control_api, data_api, saved_connection_id, db): "sum": "SUM([int_value])", "date_duplicate": "[date_value]", "ago": f'AGO([sum], [date_value], "day", {day_offset})', - "ago_bfb": (f'AGO([sum], [date_value], "day", {day_offset} ' f"BEFORE FILTER BY [date_duplicate])"), + "ago_bfb": f'AGO([sum], [date_value], "day", {day_offset} ' f"BEFORE FILTER BY [date_duplicate])", "ago_bfb_nested": ( f'AGO(AGO([sum], [date_value], "day", 1), [date_value], "day", {day_offset - 1} ' "BEFORE FILTER BY [date_duplicate])" diff --git a/lib/dl_formula/dl_formula/core/nodes.py b/lib/dl_formula/dl_formula/core/nodes.py index 09b7ad30a..5ad0e413a 100644 --- a/lib/dl_formula/dl_formula/core/nodes.py +++ b/lib/dl_formula/dl_formula/core/nodes.py @@ -322,17 +322,17 @@ def replace_nodes( to_replace: dict[int, FormulaItem] = {} for idx, child in enumerate(self.__children): + modified_child = child.replace_nodes(match_func, replace_func, parent_stack_w_self) + if modified_child is not child or modified_child != child: + child = to_replace[idx] = modified_child + is_modified = True + if match_func(child, parent_stack_w_self): modified_child = replace_func(child, parent_stack_w_self) if modified_child is not child or modified_child != child: - child = to_replace[idx] = modified_child + to_replace[idx] = modified_child is_modified = True - modified_child = child.replace_nodes(match_func, replace_func, parent_stack_w_self) - if modified_child is not child or modified_child != child: - to_replace[idx] = modified_child - is_modified = True - if is_modified: children = cast( list[FormulaItem], diff --git a/lib/dl_formula/dl_formula/mutation/lod.py b/lib/dl_formula/dl_formula/mutation/lod.py index 3adcca989..284ee99c1 100644 --- a/lib/dl_formula/dl_formula/mutation/lod.py +++ b/lib/dl_formula/dl_formula/mutation/lod.py @@ -12,10 +12,7 @@ import dl_formula.core.fork_nodes as fork_nodes import dl_formula.core.nodes as nodes from dl_formula.inspect.expression import is_aggregate_expression -from dl_formula.inspect.node import ( - is_aggregate_function, - qfork_is_aggregation, -) +from dl_formula.inspect.node import is_aggregate_function from dl_formula.mutation.dim_resolution import DimensionResolvingMutationBase from dl_formula.mutation.mutation import FormulaMutation from dl_formula.shortcuts import n @@ -29,21 +26,19 @@ class ExtAggregationToQueryForkMutation(DimensionResolvingMutationBase): """ def match_node(self, node: nodes.FormulaItem, parent_stack: Tuple[nodes.FormulaItem, ...]) -> bool: - is_agg = is_aggregate_function(node) - direct_parent = parent_stack[-1] - already_patched = ( - isinstance(direct_parent, fork_nodes.QueryFork) - and direct_parent.result_expr is node - and qfork_is_aggregation(direct_parent) # is an aggregation fork (vs. AGO forks) - ) - return is_agg and not already_patched + return is_aggregate_function(node) def make_replacement( self, old: nodes.FormulaItem, parent_stack: Tuple[nodes.FormulaItem, ...] ) -> nodes.FormulaItem: assert isinstance(old, nodes.FuncCall) - dimensions, _, parent_dimension_set = self._generate_dimensions(node=old, parent_stack=parent_stack) + dimensions: list[nodes.FormulaItem] + if old.lod.list_node_type(aux_nodes.ErrorNode): + # there are errors in current LODs, propagate them + dimensions = list(old.lod.children) + else: + dimensions, _, _ = self._generate_dimensions(node=old, parent_stack=parent_stack) lod = nodes.FixedLodSpecifier.make(dim_list=dimensions) condition_list: List[fork_nodes.JoinConditionBase] = [] diff --git a/lib/dl_formula/dl_formula/mutation/window.py b/lib/dl_formula/dl_formula/mutation/window.py index 0cc1edf0a..994355ea8 100644 --- a/lib/dl_formula/dl_formula/mutation/window.py +++ b/lib/dl_formula/dl_formula/mutation/window.py @@ -16,7 +16,6 @@ is_bound_only_to, ) from dl_formula.inspect.function import uses_default_ordering -from dl_formula.inspect.node import qfork_is_window from dl_formula.mutation.dim_resolution import DimensionResolvingMutationBase from dl_formula.mutation.mutation import FormulaMutation @@ -135,21 +134,14 @@ class WindowFunctionToQueryForkMutation(DimensionResolvingMutationBase): """ def match_node(self, node: nodes.FormulaItem, parent_stack: Tuple[nodes.FormulaItem, ...]) -> bool: - is_winfunc = isinstance(node, nodes.WindowFuncCall) - direct_parent = parent_stack[-1] - already_patched = ( - isinstance(direct_parent, fork_nodes.QueryFork) - and direct_parent.result_expr is node - and qfork_is_window(direct_parent) - ) - return is_winfunc and not already_patched + return isinstance(node, nodes.WindowFuncCall) def make_replacement( self, old: nodes.FormulaItem, parent_stack: Tuple[nodes.FormulaItem, ...] ) -> nodes.FormulaItem: assert isinstance(old, nodes.WindowFuncCall) - dimensions, _, parent_dimension_set = self._generate_dimensions(node=old, parent_stack=parent_stack) + dimensions, _, _ = self._generate_dimensions(node=old, parent_stack=parent_stack) lod = nodes.FixedLodSpecifier.make(dim_list=dimensions) condition_list: List[fork_nodes.JoinConditionBase] = [] diff --git a/lib/dl_formula/dl_formula_tests/unit/mutation/test_optimization.py b/lib/dl_formula/dl_formula_tests/unit/mutation/test_optimization.py index 2bc940f7f..ab1a6d7a1 100644 --- a/lib/dl_formula/dl_formula_tests/unit/mutation/test_optimization.py +++ b/lib/dl_formula/dl_formula_tests/unit/mutation/test_optimization.py @@ -247,6 +247,50 @@ def test_optimize_if_mutation(): assert formula_obj == n.formula(n.field("then field 3")) +def test_optimize_if_mutation_with_const_comparison(): + # Check false comparison result in removal of a corresponding branch + formula_obj = n.formula( + n.func.IF( + n.binary("==", left=n.lit(2), right=n.lit(3)), + n.field("then field 1"), + n.binary("!=", left=n.lit(2), right=n.lit(2)), + n.field("then field 2"), + n.field("else field"), + ) + ) + formula_obj = apply_mutations( + formula_obj, + mutations=[ + OptimizeConstComparisonMutation(), + OptimizeConstFuncMutation(), + ], + ) + assert formula_obj == n.formula(n.field("else field")) + + # Check true comparison result in single field result + formula_obj = n.formula( + n.func.IF( + n.field("cond 1"), + n.field("then field 1"), + n.binary("!=", left=n.lit(2), right=n.lit(2)), + n.field("then field 2"), + n.binary("==", left=n.lit(2), right=n.lit(2)), + n.field("then field 3"), + n.field("cond 4"), + n.field("then field 4"), + n.field("else field"), + ) + ) + formula_obj = apply_mutations( + formula_obj, + mutations=[ + OptimizeConstComparisonMutation(), + OptimizeConstFuncMutation(), + ], + ) + assert formula_obj == n.formula(n.field("then field 3")) + + def test_optimize_case_mutation(): # Check removal of false conditions formula_obj = n.formula( diff --git a/lib/dl_query_processing/dl_query_processing/compilation/formula_compiler.py b/lib/dl_query_processing/dl_query_processing/compilation/formula_compiler.py index c148fc410..41270d554 100644 --- a/lib/dl_query_processing/dl_query_processing/compilation/formula_compiler.py +++ b/lib/dl_query_processing/dl_query_processing/compilation/formula_compiler.py @@ -753,9 +753,6 @@ def apply_pre_sub_mutations( Apply pre-substitution mutations required for window functions to be translated correctly. """ - # prepare mutations - mutations: List[FormulaMutation] = [] - # prepare default ordering (for patching RSUM, MSUM functions and the like) default_order_by = [] ob_expr_obj: formula_nodes.FormulaItem @@ -765,19 +762,18 @@ def apply_pre_sub_mutations( ob_expr_obj = formula_nodes.OrderDescending.make(expr=ob_expr_obj) default_order_by.append(ob_expr_obj) + mutations = [ + IgnoreParenthesisWrapperMutation(), + ConvertBlocksToFunctionsMutation(), + DefaultWindowOrderingMutation(default_order_by=default_order_by), + LookupDefaultBfbMutation(), + ] + formula_obj = apply_mutations(formula_obj, mutations=mutations) + # Only measures can contain BFB clauses title_id_map = {f.title: f.guid for f in self._fields} - mutations.extend( - [ - IgnoreParenthesisWrapperMutation(), - ConvertBlocksToFunctionsMutation(), - DefaultWindowOrderingMutation(default_order_by=default_order_by), - LookupDefaultBfbMutation(), - RemapBfbMutation(name_mapping=title_id_map), - ] - ) + formula_obj = apply_mutations(formula_obj, mutations=[RemapBfbMutation(name_mapping=title_id_map)]) - formula_obj = apply_mutations(formula_obj, mutations=mutations) return formula_obj def _apply_function_by_name(self, formula_obj: formula_nodes.Formula, func_name: str) -> formula_nodes.Formula: