From 2e8f377187b30d85667bef375dc1c03ba288f677 Mon Sep 17 00:00:00 2001 From: Dmitrii Ovsyannikov Date: Fri, 20 Dec 2024 13:45:47 +0100 Subject: [PATCH] feat: BI-4316 add ago filtration (#674) * feat: add ago dimension filtration * fix: dl_formula tests * fix: * fix: review - namings * fix: PR --- .../complex_queries/test_lookup_functions.py | 38 +++++++++++++++ lib/dl_formula/dl_formula/core/fork_nodes.py | 47 +++++++++++++++++-- lib/dl_formula/dl_formula/mutation/lookup.py | 10 ++++ lib/dl_formula/dl_formula/shortcuts.py | 2 + .../unit/mutation/test_subquery_mutations.py | 7 +++ lib/dl_formula/pyproject.toml | 1 + .../multi_query/splitters/mask_based.py | 4 ++ .../multi_query/splitters/query_fork.py | 38 ++++++++++++++- .../multi_query/splitters/win_func.py | 1 + .../unit/multi_query/test_splitter.py | 1 + lib/dl_query_processing/pyproject.toml | 1 + tools/taskfiles/taskfile_dev.yml | 4 +- 12 files changed, 148 insertions(+), 6 deletions(-) diff --git a/lib/dl_api_lib/dl_api_lib_tests/db/data_api/result/complex_queries/test_lookup_functions.py b/lib/dl_api_lib/dl_api_lib_tests/db/data_api/result/complex_queries/test_lookup_functions.py index 31acb6c67..2b842848c 100644 --- a/lib/dl_api_lib/dl_api_lib_tests/db/data_api/result/complex_queries/test_lookup_functions.py +++ b/lib/dl_api_lib/dl_api_lib_tests/db/data_api/result/complex_queries/test_lookup_functions.py @@ -1,5 +1,6 @@ import datetime from http import HTTPStatus +import re from typing import Optional import pytest @@ -55,6 +56,43 @@ def test_ago(self, control_api, data_api, saved_dataset): data_rows = get_data_rows(result_resp) check_ago_data(data_rows=data_rows, date_idx=1, value_idx=2, ago_idx=3, day_offset=1) + def test_ago_filtered(self, control_api, data_api, saved_dataset): + ds = add_formulas_to_dataset( + api_v1=control_api, + dataset=saved_dataset, + formulas={ + "Sales Sum": "SUM([sales])", + "Sales Sum Yesterday": 'AGO([Sales Sum], [order_date], "day", 5)', + }, + ) + result_resp = data_api.get_result( + dataset=ds, + fields=[ + ds.find_field(title="category"), + ds.find_field(title="order_date"), + ds.find_field(title="Sales Sum"), + ds.find_field(title="Sales Sum Yesterday"), + ], + order_by=[ + ds.find_field(title="category"), + ds.find_field(title="order_date"), + ], + filters=[ + ds.find_field(title="category").filter(op=WhereClauseOperation.EQ, values=["Office Supplies"]), + ds.find_field(title="order_date").filter(op=WhereClauseOperation.GTE, values=["2014-01-06"]), + ], + fail_ok=True, + ) + assert result_resp.status_code == HTTPStatus.OK, result_resp.json + + query: str = result_resp.json["blocks"][0]["query"] + expected_query_pattern = r"JOIN.*\(.*order_date.*5.*>=.*2014-01-06.*\).*ON" + assert re.search( + expected_query_pattern, + query, + flags=re.DOTALL, + ), "Expected to find pattern 'JOIN (... order_date >= 2014-01-06 ...) ON' in query" + def test_ago_variants(self, control_api, data_api, saved_dataset): ds = add_formulas_to_dataset( api_v1=control_api, diff --git a/lib/dl_formula/dl_formula/core/fork_nodes.py b/lib/dl_formula/dl_formula/core/fork_nodes.py index f09774dc6..a7c678957 100644 --- a/lib/dl_formula/dl_formula/core/fork_nodes.py +++ b/lib/dl_formula/dl_formula/core/fork_nodes.py @@ -120,6 +120,43 @@ def is_self_eq_join(self) -> bool: return all(isinstance(child, SelfEqualityJoinCondition) for child in self.children) +class BfbFilterMutationSpec(nodes.FormulaItem): + __slots__ = () + + show_names = nodes.FormulaItem.show_names + ("original", "replacement") + + original: nodes.Child[nodes.FormulaItem] = nodes.Child(0) + replacement: nodes.Child[nodes.FormulaItem] = nodes.Child(1) + + @classmethod + def make( + cls, + original: nodes.FormulaItem, + replacement: nodes.FormulaItem, + *, + meta: Optional[nodes.NodeMeta] = None, + ) -> BfbFilterMutationSpec: + children = (original, replacement) + return cls(*children, meta=meta) + + +class BfbFilterMutationCollectionSpec(nodes.FormulaItem): + __slots__ = () + + show_names = nodes.FormulaItem.show_names + ("mutations",) + + mutations: nodes.MultiChild[BfbFilterMutationSpec] = nodes.MultiChild(slice(0, None)) + + @classmethod + def make( + cls, + *mutations: BfbFilterMutationSpec, + meta: Optional[nodes.NodeMeta] = None, + ) -> BfbFilterMutationCollectionSpec: + children = mutations + return cls(*children, meta=meta) + + class QueryFork(nodes.FormulaItem): """ Represents a point where the query should be forked in two: @@ -129,12 +166,13 @@ class QueryFork(nodes.FormulaItem): __slots__ = () - show_names = nodes.FormulaItem.show_names + ("join_type", "joining", "result_expr", "lod") + show_names = nodes.FormulaItem.show_names + ("join_type", "joining", "result_expr", "lod", "bfb_filter_mutations") joining: nodes.Child[QueryForkJoiningBase] = nodes.Child(0) result_expr: nodes.Child[nodes.FormulaItem] = nodes.Child(1) before_filter_by: nodes.Child[nodes.BeforeFilterBy] = nodes.Child(2) lod: nodes.Child[nodes.LodSpecifier] = nodes.Child(3) + bfb_filter_mutations: nodes.Child[BfbFilterMutationCollectionSpec] = nodes.Child(4) @classmethod def make( @@ -144,20 +182,23 @@ def make( result_expr: nodes.FormulaItem, before_filter_by: Optional[nodes.BeforeFilterBy] = None, lod: Optional[nodes.LodSpecifier] = None, + bfb_filter_mutations: Optional[BfbFilterMutationCollectionSpec] = None, meta: Optional[nodes.NodeMeta] = None, ) -> QueryFork: if before_filter_by is None: before_filter_by = nodes.BeforeFilterBy.make() if lod is None: lod = nodes.InheritedLodSpecifier() + if bfb_filter_mutations is None: + bfb_filter_mutations = BfbFilterMutationCollectionSpec.make() - children = (joining, result_expr, before_filter_by, lod) + children = (joining, result_expr, before_filter_by, lod, bfb_filter_mutations) internal_value = (join_type,) return cls(*children, internal_value=internal_value, meta=meta) @classmethod def validate_children(cls, children: Sequence[nodes.FormulaItem]) -> None: - assert len(children) == 4 + assert len(children) == 5 @classmethod def validate_internal_value(cls, internal_value: tuple[Optional[Hashable], ...]) -> None: diff --git a/lib/dl_formula/dl_formula/mutation/lookup.py b/lib/dl_formula/dl_formula/mutation/lookup.py index 78cca32e9..b03d3650b 100644 --- a/lib/dl_formula/dl_formula/mutation/lookup.py +++ b/lib/dl_formula/dl_formula/mutation/lookup.py @@ -253,6 +253,15 @@ def make_replacement( joining = fork_nodes.QueryForkJoiningWithList.make(condition_list=condition_list) lod = nodes.InheritedLodSpecifier() + bfb_filter_mutations = fork_nodes.BfbFilterMutationCollectionSpec.make( + fork_nodes.BfbFilterMutationSpec.make( + original=lookup_dimension, + replacement=nodes.FuncCall.make( + name="dateadd", + args=[lookup_dimension, *old.args[2:]], + ), + ) + ) new_node = fork_nodes.QueryFork.make( join_type=fork_nodes.JoinType.left, @@ -260,6 +269,7 @@ def make_replacement( joining=joining, before_filter_by=old.before_filter_by, lod=lod, + bfb_filter_mutations=bfb_filter_mutations, meta=old.meta, ) return new_node diff --git a/lib/dl_formula/dl_formula/shortcuts.py b/lib/dl_formula/dl_formula/shortcuts.py index 5fedf10de..4f25c86b5 100644 --- a/lib/dl_formula/dl_formula/shortcuts.py +++ b/lib/dl_formula/dl_formula/shortcuts.py @@ -334,6 +334,7 @@ def fork( before_filter_by: Optional[List[Union[nodes.Field, builtins.str]]] = None, lod: Optional[nodes.FixedLodSpecifier] = None, join_type: fork_nodes.JoinType = fork_nodes.JoinType.left, + bfb_filter_mutations: Optional[fork_nodes.BfbFilterMutationCollectionSpec] = None, ) -> fork_nodes.QueryFork: before_filter_by_node = self._normalize_raw_bfb(before_filter_by=before_filter_by) return fork_nodes.QueryFork.make( @@ -342,6 +343,7 @@ def fork( result_expr=result_expr, before_filter_by=before_filter_by_node, lod=lod, + bfb_filter_mutations=bfb_filter_mutations, ) def bin_condition(self, expr: nodes.FormulaItem, fork_expr: nodes.FormulaItem) -> fork_nodes.BinaryJoinCondition: diff --git a/lib/dl_formula/dl_formula_tests/unit/mutation/test_subquery_mutations.py b/lib/dl_formula/dl_formula_tests/unit/mutation/test_subquery_mutations.py index ad4f18e6c..44fb60efe 100644 --- a/lib/dl_formula/dl_formula_tests/unit/mutation/test_subquery_mutations.py +++ b/lib/dl_formula/dl_formula_tests/unit/mutation/test_subquery_mutations.py @@ -1,5 +1,6 @@ from __future__ import annotations +import dl_formula.core.fork_nodes as fork_nodes from dl_formula.mutation.lookup import ( LookupDefaultBfbMutation, LookupFunctionToQueryForkMutation, @@ -38,6 +39,12 @@ def test_ago_to_query_fork_mutation(): ], ), lod=n.inherited(), + bfb_filter_mutations=fork_nodes.BfbFilterMutationCollectionSpec.make( + fork_nodes.BfbFilterMutationSpec.make( + original=n.field("date"), + replacement=n.func.DATEADD(n.field("date")), + ), + ), ) ), ) diff --git a/lib/dl_formula/pyproject.toml b/lib/dl_formula/pyproject.toml index 77cec5126..fe2a1cfac 100644 --- a/lib/dl_formula/pyproject.toml +++ b/lib/dl_formula/pyproject.toml @@ -70,6 +70,7 @@ labels = ["no_compose"] bi-formula-cli = "dl_formula.scripts.formula_cli:main" [tool.mypy] +exclude = ["dl_formula_tests/"] warn_unused_configs = true disallow_untyped_defs = true check_untyped_defs = true diff --git a/lib/dl_query_processing/dl_query_processing/multi_query/splitters/mask_based.py b/lib/dl_query_processing/dl_query_processing/multi_query/splitters/mask_based.py index 98cfee0e7..b0e2cf856 100644 --- a/lib/dl_query_processing/dl_query_processing/multi_query/splitters/mask_based.py +++ b/lib/dl_query_processing/dl_query_processing/multi_query/splitters/mask_based.py @@ -121,6 +121,7 @@ class QuerySplitMask: formula_split_masks: tuple[AliasedFormulaSplitMask, ...] = attr.ib(kw_only=True) add_formulas: tuple[AddFormulaInfo, ...] = attr.ib(kw_only=True) filter_indices: frozenset[int] = attr.ib(kw_only=True) + add_filters: tuple[CompiledFormulaInfo, ...] = attr.ib(kw_only=True) join_type: Optional[JoinType] = attr.ib(kw_only=True) joining_node: Optional[formula_fork_nodes.QueryForkJoiningBase] = attr.ib(kw_only=True) is_base: bool = attr.ib(kw_only=True, default=False) @@ -254,6 +255,8 @@ def _generate_subquery_for_mask( filters: list[CompiledFormulaInfo] = [ formula for filter_idx, formula in enumerate(query.filters) if filter_idx in split_mask.filter_indices ] + for filter in split_mask.add_filters: + filters.append(filter) join_on = query.join_on joined_from = query.joined_from @@ -850,6 +853,7 @@ def _patch_query_masks_with_base( formula_split_masks=tuple(base_formula_split_masks), add_formulas=tuple(new_add_formulas), filter_indices=frozenset(new_filter_indices), + add_filters=tuple(), join_type=None, joining_node=None, is_base=True, diff --git a/lib/dl_query_processing/dl_query_processing/multi_query/splitters/query_fork.py b/lib/dl_query_processing/dl_query_processing/multi_query/splitters/query_fork.py index 86b1c6cf8..04cd22050 100644 --- a/lib/dl_query_processing/dl_query_processing/multi_query/splitters/query_fork.py +++ b/lib/dl_query_processing/dl_query_processing/multi_query/splitters/query_fork.py @@ -12,6 +12,10 @@ from dl_formula.inspect.env import InspectionEnvironment import dl_formula.inspect.expression as inspect_expression import dl_formula.inspect.node as inspect_node +from dl_formula.mutation.mutation import ( + FormulaMutation, + apply_mutations, +) from dl_query_processing.compilation.primitives import CompiledQuery from dl_query_processing.enums import QueryPart from dl_query_processing.multi_query.splitters.mask_based import ( @@ -41,12 +45,27 @@ class SubqueryForkSignature: join_type: JoinType +@attr.s(frozen=True) +class SimpleReplacementFormulaMutation(FormulaMutation): + original: formula_nodes.FormulaItem = attr.ib() + replacement: formula_nodes.FormulaItem = attr.ib() + + def match_node(self, node: formula_nodes.FormulaItem, parent_stack: tuple[formula_nodes.FormulaItem, ...]) -> bool: + return node == self.original + + def make_replacement( + self, old: formula_nodes.FormulaItem, parent_stack: tuple[formula_nodes.FormulaItem, ...] + ) -> formula_nodes.FormulaItem: + return self.replacement + + @attr.s(frozen=True) class QueryForkInfo: subquery_type: SubqueryType = attr.ib(kw_only=True) joining_node: formula_fork_nodes.QueryForkJoiningBase = attr.ib(kw_only=True) bfb_field_ids: frozenset[str] = attr.ib(kw_only=True) add_formulas: tuple[AddFormulaInfo, ...] = attr.ib(kw_only=True) + bfb_filter_mutations: tuple[SimpleReplacementFormulaMutation, ...] = attr.ib(kw_only=True) join_type: JoinType = attr.ib(kw_only=True) aliases_by_extract: dict[NodeExtract, str] = attr.ib(kw_only=True, factory=dict) formula_split_masks: list[AliasedFormulaSplitMask] = attr.ib(kw_only=True, factory=list) @@ -338,12 +357,21 @@ def _normalize_bfb(qfork_node: formula_fork_nodes.QueryFork) -> frozenset[str]: ) add_formulas = tuple(dim_add_formulas + non_dim_add_formulas) + bfb_filter_mutations = tuple( + SimpleReplacementFormulaMutation( + original=mutation.original, + replacement=mutation.replacement, + ) + for mutation in qfork_node.bfb_filter_mutations.mutations + ) + qfork_info = QueryForkInfo( subquery_type=subquery_type, add_formulas=add_formulas, joining_node=joining, join_type=join_type, bfb_field_ids=normalized_bfb, + bfb_filter_mutations=bfb_filter_mutations, ) qforks_by_signature[qfork_signature] = qfork_info result.append(qfork_info) @@ -426,11 +454,18 @@ def get_split_masks( for extract, alias in sorted(extract_to_alias_map.items(), key=lambda t: t[0].complexity, reverse=True): joining_node = self._normalize_joining_node(joining_node=joining_node, extract=extract, alias=alias) + add_filters = [] + # Collect indices of filters that should be applied to the sub-query filter_indices: set[int] = set() for filter_idx, filter_formula in enumerate(query.filters): if filter_formula.original_field_id in qfork_info.bfb_field_ids: - # Filter field is in BFB, so exclude it. + # Filter field is in BFB, so exclude it unless it is mutated by one of the BFB mutations + new_formula_obj = apply_mutations(filter_formula.formula_obj, qfork_info.bfb_filter_mutations) + + if new_formula_obj is not filter_formula.formula_obj: + new_filter = filter_formula.clone(formula_obj=new_formula_obj) + add_filters.append(new_filter) continue if filter_idx in split_filter_indices: # This filter can only be applied to a higher-level query @@ -446,6 +481,7 @@ def get_split_masks( subquery_id=query_id_gen.get_id(), formula_split_masks=tuple(qfork_info.formula_split_masks), filter_indices=frozenset(filter_indices), + add_filters=tuple(add_filters), add_formulas=qfork_info.add_formulas, join_type=qfork_info.join_type, joining_node=joining_node, diff --git a/lib/dl_query_processing/dl_query_processing/multi_query/splitters/win_func.py b/lib/dl_query_processing/dl_query_processing/multi_query/splitters/win_func.py index 6266277f3..63a5e5e5f 100644 --- a/lib/dl_query_processing/dl_query_processing/multi_query/splitters/win_func.py +++ b/lib/dl_query_processing/dl_query_processing/multi_query/splitters/win_func.py @@ -108,6 +108,7 @@ def get_split_masks( formula_split_masks=tuple(formula_split_masks), filter_indices=frozenset(), add_formulas=(), + add_filters=(), join_type=JoinType.inner, joining_node=formula_fork_nodes.QueryForkJoiningWithList.make(condition_list=[]), ) diff --git a/lib/dl_query_processing/dl_query_processing_tests/unit/multi_query/test_splitter.py b/lib/dl_query_processing/dl_query_processing_tests/unit/multi_query/test_splitter.py index c89b74b61..816788f1d 100644 --- a/lib/dl_query_processing/dl_query_processing_tests/unit/multi_query/test_splitter.py +++ b/lib/dl_query_processing/dl_query_processing_tests/unit/multi_query/test_splitter.py @@ -86,6 +86,7 @@ def get_split_masks( is_group_by=True, ), ), + add_filters=(), ) ] diff --git a/lib/dl_query_processing/pyproject.toml b/lib/dl_query_processing/pyproject.toml index 88e0681fb..d4b3106f4 100644 --- a/lib/dl_query_processing/pyproject.toml +++ b/lib/dl_query_processing/pyproject.toml @@ -46,6 +46,7 @@ skip_compose = "true" labels = ["no_compose"] [tool.mypy] +exclude = ["dl_query_processing_tests"] warn_unused_configs = true disallow_untyped_defs = true check_untyped_defs = true diff --git a/tools/taskfiles/taskfile_dev.yml b/tools/taskfiles/taskfile_dev.yml index 49b0f046d..c8ec2c9b8 100644 --- a/tools/taskfiles/taskfile_dev.yml +++ b/tools/taskfiles/taskfile_dev.yml @@ -267,9 +267,9 @@ tasks: env: PORTS: sh: DOCKER_HOST=ssh://{{.VM_HOST}} docker ps --format "{{`{{.Ports}}`}}" - | grep -o ':[0-9]*->' + | grep -o '0.0.0.0:[0-9]*->' | sed 's/->//g' - | sed 's/://g' + | sed 's/0.0.0.0://g' | uniq ssh-forward-stop: