From f73ac449b115ded4ee1ab590e6fcdf711c450360 Mon Sep 17 00:00:00 2001 From: Nick Proskurin Date: Fri, 25 Oct 2024 17:43:55 +0200 Subject: [PATCH] fix(formula): BI-5455 fix LODs with constant functions in query (#668) --- .../test_ext_agg_corner_cases.py | 23 ++++++++++++ .../db/data_api/result/test_markup.py | 20 ++--------- .../definitions/functions_markup.py | 2 +- .../dl_formula/definitions/registry.py | 2 +- .../dl_formula/inspect/expression.py | 8 +++++ .../unit/inspect/test_expression.py | 35 ++++++++++++++----- .../unit/scripts/test_formula_cli.py | 8 ++--- 7 files changed, 66 insertions(+), 32 deletions(-) diff --git a/lib/dl_api_lib/dl_api_lib_tests/db/data_api/result/complex_queries/test_ext_agg_corner_cases.py b/lib/dl_api_lib/dl_api_lib_tests/db/data_api/result/complex_queries/test_ext_agg_corner_cases.py index ae965934a..92c2f008d 100644 --- a/lib/dl_api_lib/dl_api_lib_tests/db/data_api/result/complex_queries/test_ext_agg_corner_cases.py +++ b/lib/dl_api_lib/dl_api_lib_tests/db/data_api/result/complex_queries/test_ext_agg_corner_cases.py @@ -2,6 +2,8 @@ from http import HTTPStatus +import pytest + from dl_api_client.dsmaker.shortcuts.dataset import add_formulas_to_dataset from dl_api_client.dsmaker.shortcuts.result_data import get_data_rows from dl_api_lib_tests.db.base import DefaultApiTestBase @@ -123,3 +125,24 @@ def test_lod_include_measure(self, control_api, data_api, saved_dataset): fail_ok=True, ) assert result_resp.status_code == HTTPStatus.BAD_REQUEST + + @pytest.mark.parametrize("function", ["COUNT()", "NOW()"]) + def test_lod_with_avatarless_function(self, control_api, data_api, saved_dataset, function): + ds = add_formulas_to_dataset( + api_v1=control_api, + dataset=saved_dataset, + formulas={ + "Agg": "SUM(SUM([sales] FIXED [region]))", + "Agg with avatarless function": f"CONCAT({function}, ': ', [Agg])", + }, + ) + + result_resp = data_api.get_result( + dataset=ds, + fields=[ + ds.find_field(title="Agg with avatarless function"), + ], + ) + assert result_resp.status_code == HTTPStatus.OK, result_resp.json + rows = get_data_rows(result_resp) + assert len(rows) == 1 diff --git a/lib/dl_api_lib/dl_api_lib_tests/db/data_api/result/test_markup.py b/lib/dl_api_lib/dl_api_lib_tests/db/data_api/result/test_markup.py index 008e83a1a..f62b7686e 100644 --- a/lib/dl_api_lib/dl_api_lib_tests/db/data_api/result/test_markup.py +++ b/lib/dl_api_lib/dl_api_lib_tests/db/data_api/result/test_markup.py @@ -1,27 +1,11 @@ -from __future__ import annotations - import uuid -import pytest - from dl_api_client.dsmaker.shortcuts.result_data import get_data_rows -from dl_api_lib_testing.data_api_base import DataApiTestParams from dl_api_lib_tests.db.base import DefaultApiTestBase -class TestUMarkup(DefaultApiTestBase): - @pytest.fixture(scope="function") - def data_api_test_params(self, sample_table) -> DataApiTestParams: - # This default is defined for the sample table - return DataApiTestParams( - two_dims=("category", "city"), - summable_field="sales", - range_field="sales", - distinct_field="city", - date_field="order_date", - ) - - def test_markup(self, saved_dataset, data_api, data_api_test_params): +class TestMarkup(DefaultApiTestBase): + def test_markup(self, saved_dataset, data_api): ds = saved_dataset field_a, field_b, field_c, field_d, field_e, field_nulled = (str(uuid.uuid4()) for _ in range(6)) diff --git a/lib/dl_formula/dl_formula/definitions/functions_markup.py b/lib/dl_formula/dl_formula/definitions/functions_markup.py index 16fee714b..180e4b666 100644 --- a/lib/dl_formula/dl_formula/definitions/functions_markup.py +++ b/lib/dl_formula/dl_formula/definitions/functions_markup.py @@ -173,7 +173,7 @@ class MarkupTypeStrategy(Fixed): def __init__(self) -> None: super().__init__(DataType.MARKUP) - def get_from_args(self, arg_types): # type: ignore # 2024-01-24 # TODO: Function is missing a type annotation [no-untyped-def] + def get_from_args(self, arg_types: list[DataType]) -> DataType: if all(arg.casts_to(DataType.CONST_MARKUP) or arg.casts_to(DataType.CONST_STRING) for arg in arg_types): return DataType.CONST_MARKUP return DataType.MARKUP diff --git a/lib/dl_formula/dl_formula/definitions/registry.py b/lib/dl_formula/dl_formula/definitions/registry.py index fa9573b98..a152f48a8 100644 --- a/lib/dl_formula/dl_formula/definitions/registry.py +++ b/lib/dl_formula/dl_formula/definitions/registry.py @@ -80,7 +80,7 @@ def get_definition( if dialect is not None and for_any_dialect or dialect is None and not for_any_dialect: raise ValueError( - "Either dialect should be provided or for_any_dialect be set to True." " Cannot provide both." + "Either dialect should be provided or for_any_dialect be set to True. Cannot provide both." ) arg_types = list(arg_types) diff --git a/lib/dl_formula/dl_formula/inspect/expression.py b/lib/dl_formula/dl_formula/inspect/expression.py index f30469c58..9433890f4 100644 --- a/lib/dl_formula/dl_formula/inspect/expression.py +++ b/lib/dl_formula/dl_formula/inspect/expression.py @@ -59,6 +59,14 @@ def is_constant_expression( result = True elif isinstance(node, nodes.ParenthesizedExpr): result = is_constant_expression(node.expr, env=env) + elif isinstance(node, nodes.OperationCall): + # check if all arguments of a function or an operator are constant; + # avoid aggregates and WFs, for example count() or count(1) + result = ( + not dl_formula.inspect.node.is_aggregate_function(node) + and not dl_formula.inspect.node.is_window_function(node) + and all(is_constant_expression(arg, env) for arg in node.args) + ) else: result = False diff --git a/lib/dl_formula/dl_formula_tests/unit/inspect/test_expression.py b/lib/dl_formula/dl_formula_tests/unit/inspect/test_expression.py index 7649eba8b..c3116d240 100644 --- a/lib/dl_formula/dl_formula_tests/unit/inspect/test_expression.py +++ b/lib/dl_formula/dl_formula_tests/unit/inspect/test_expression.py @@ -21,6 +21,8 @@ used_fields, used_func_calls, ) +from dl_formula.inspect.registry.item import NameOpItem +from dl_formula.inspect.registry.registry import LOWLEVEL_OP_REGISTRY from dl_formula.shortcuts import n @@ -96,16 +98,33 @@ def test_used_fields(): def test_is_constant_expression(): - env = InspectionEnvironment() - assert is_constant_expression(nodes.LiteralString.make("qwe"), env=env) - assert is_constant_expression(nodes.Null(), env=env) - assert not is_constant_expression(nodes.FuncCall.make(name="func", args=[nodes.LiteralString.make("qwe")]), env=env) - assert is_constant_expression(n.p(nodes.LiteralString.make("qwe")), env=env) - assert not is_constant_expression(nodes.FuncCall.make(name="func", args=[nodes.LiteralString.make("qwe")]), env=env) - assert not is_constant_expression( - n.p(nodes.FuncCall.make(name="func", args=[nodes.LiteralString.make("qwe")])), env=env + LOWLEVEL_OP_REGISTRY._name_op_registry["agg_func"] = NameOpItem( + name="agg_func", + can_be_window=False, + can_be_aggregate=True, + can_be_nonwindow=True, ) + env = InspectionEnvironment() + literal_string = nodes.LiteralString.make("qwe") + const_func_call = nodes.FuncCall.make(name="func", args=[literal_string]) + non_const_func_call = nodes.FuncCall.make(name="func", args=[literal_string, nodes.Field.make("field")]) + agg_func_call = nodes.FuncCall.make(name="agg_func", args=[literal_string]) + window_func_call = nodes.WindowFuncCall.make(name="func", args=[literal_string]) + + try: + assert is_constant_expression(literal_string, env=env) + assert is_constant_expression(nodes.Null(), env=env) + assert is_constant_expression(const_func_call, env=env) + assert is_constant_expression(n.p(literal_string), env=env) + assert is_constant_expression(n.p(const_func_call), env=env) + + assert not is_constant_expression(non_const_func_call, env=env) + assert not is_constant_expression(window_func_call, env=env) + assert not is_constant_expression(agg_func_call, env=env) + finally: + del LOWLEVEL_OP_REGISTRY._name_op_registry["agg_func"] + def test_is_is_bound_only_to(): dim_1 = nodes.Field.make("Dim Field") diff --git a/lib/dl_formula/dl_formula_tests/unit/scripts/test_formula_cli.py b/lib/dl_formula/dl_formula_tests/unit/scripts/test_formula_cli.py index 751dc1f9a..e9ea01952 100644 --- a/lib/dl_formula/dl_formula_tests/unit/scripts/test_formula_cli.py +++ b/lib/dl_formula/dl_formula_tests/unit/scripts/test_formula_cli.py @@ -131,14 +131,14 @@ def test_slice(tool): "--diff", "--levels", "aggregate,window,toplevel", - "456 + SUM(FUNC(123)) + RMAX([coeff] * AVG(456 + [qwe] + [rty] + [uio]) - 1)", + "456 + SUM(FUNC([field])) + RMAX([coeff] * AVG(456 + [qwe] + [rty] + [uio]) - 1)", ] ) assert ( stdout.strip() == ( - "toplevel ▫▫▫▫▫▫▫▫▫▫▫▫▫▫▫▫▫▫▫▫ + RMAX(▫▫▫▫▫▫▫▫▫▫▫▫▫▫▫▫▫▫▫▫▫▫▫▫▫▫▫▫▫▫▫▫▫▫▫▫▫▫▫▫▫▫▫▫▫▫)\n" - "window 456 + SUM(▫▫▫▫▫▫▫▫▫) ▫▫▫▫▫▫▫ * AVG(▫▫▫▫▫▫▫▫▫▫▫▫▫▫▫▫▫▫▫▫▫▫▫▫▫▫▫) - 1 \n" - "aggregate FUNC(123) [coeff] 456 + [qwe] + [rty] + [uio] \n" + "toplevel ▫▫▫▫▫▫▫▫▫▫▫▫▫▫▫▫▫▫▫▫▫▫▫▫ + RMAX(▫▫▫▫▫▫▫▫▫▫▫▫▫▫▫▫▫▫▫▫▫▫▫▫▫▫▫▫▫▫▫▫▫▫▫▫▫▫▫▫▫▫▫▫▫▫)\n" + "window 456 + SUM(▫▫▫▫▫▫▫▫▫▫▫▫▫) ▫▫▫▫▫▫▫ * AVG(▫▫▫▫▫▫▫▫▫▫▫▫▫▫▫▫▫▫▫▫▫▫▫▫▫▫▫) - 1 \n" + "aggregate FUNC([field]) [coeff] 456 + [qwe] + [rty] + [uio] \n" ).strip() )