Skip to content

Commit

Permalink
Added extract property to query objects
Browse files Browse the repository at this point in the history
  • Loading branch information
altvod committed Oct 27, 2023
1 parent 26128aa commit 50b5a54
Show file tree
Hide file tree
Showing 5 changed files with 146 additions and 3 deletions.
4 changes: 4 additions & 0 deletions lib/dl_formula/dl_formula/core/datatype.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,6 +7,7 @@
)
from typing import (
FrozenSet,
Hashable,
Optional,
)

Expand Down Expand Up @@ -101,6 +102,9 @@ class DataTypeParams:
timezone: Optional[str] = attr.ib(default=None)
# Other possible cases: decimal precision, datetime sub-second precision, nullable, enum values.

def as_primitive(self) -> tuple[Optional[Hashable]]:
return (self.timezone,)


_AUTOCAST_FROM_TYPES = OrderedDict(
(
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -22,13 +22,12 @@
OrderDirection,
)
import dl_formula.core.nodes as formula_nodes
from dl_query_processing.compilation.query_meta import QueryMetaInfo
from dl_query_processing.compilation.query_meta import QueryMetaInfo, QueryElementExtract
from dl_query_processing.enums import (
ExecutionLevel,
QueryPart,
)


_COMPILED_FLA_TV = TypeVar("_COMPILED_FLA_TV", bound="CompiledFormulaInfo")


Expand All @@ -45,6 +44,17 @@ class CompiledFormulaInfo:
alias: Optional[str] = attr.ib()
avatar_ids: Set[str] = attr.ib(factory=set)
original_field_id: Optional[str] = attr.ib(default=None)

@property
def extract(self) -> QueryElementExtract:
return QueryElementExtract(
values=(
self.formula_obj.extract,
self.alias,
frozenset(self.avatar_ids),
self.original_field_id,
),
)

@property
def not_none_alias(self) -> str:
Expand Down Expand Up @@ -96,6 +106,17 @@ class CompiledOrderByFormulaInfo(CompiledFormulaInfo): # noqa

direction: OrderDirection = attr.ib(kw_only=True)

@property
def extract(self) -> QueryElementExtract:
return QueryElementExtract(
values=(
self.formula_obj.extract,
self.alias,
frozenset(self.avatar_ids),
self.original_field_id,
self.direction,
),
)

@attr.s(slots=True, frozen=True)
class CompiledJoinOnFormulaInfo(CompiledFormulaInfo): # noqa
Expand All @@ -109,6 +130,20 @@ class CompiledJoinOnFormulaInfo(CompiledFormulaInfo): # noqa
right_id: str = attr.ib(kw_only=True)
join_type: JoinType = attr.ib(kw_only=True)

@property
def extract(self) -> QueryElementExtract:
return QueryElementExtract(
values=(
self.formula_obj.extract,
self.alias,
frozenset(self.avatar_ids),
self.original_field_id,
self.left_id,
self.right_id,
self.join_type.name,
),
)


_FROM_OBJ_TV = TypeVar("_FROM_OBJ_TV", bound="FromObject")

Expand All @@ -121,13 +156,32 @@ class FromColumn:
def clone(self, **updates: Any) -> FromColumn:
return attr.evolve(self, **updates)

@property
def extract(self) -> QueryElementExtract:
return QueryElementExtract(
values=(
self.id,
self.name,
),
)


@attr.s(frozen=True)
class FromObject:
id: str = attr.ib(kw_only=True)
alias: str = attr.ib(kw_only=True)
columns: tuple[FromColumn, ...] = attr.ib(kw_only=True)

@property
def extract(self) -> QueryElementExtract:
return QueryElementExtract(
values=(
self.id,
self.alias,
tuple(col.extract for col in self.columns),
),
)

def clone(self: _FROM_OBJ_TV, **updates: Any) -> _FROM_OBJ_TV:
return attr.evolve(self, **updates)

Expand All @@ -137,6 +191,15 @@ class JoinedFromObject:
root_from_id: Optional[str] = attr.ib(kw_only=True, default=None)
froms: Sequence[FromObject] = attr.ib(kw_only=True, default=())

@property
def extract(self) -> QueryElementExtract:
return QueryElementExtract(
values=(
self.root_from_id,
tuple(from_obj.extract for from_obj in self.froms),
),
)

def iter_ids(self) -> Iterable[str]:
return (from_obj.id for from_obj in self.froms)

Expand Down Expand Up @@ -175,6 +238,24 @@ class CompiledQuery:
offset: Optional[int] = attr.ib(kw_only=True, default=None)
meta: QueryMetaInfo = attr.ib(kw_only=True, factory=QueryMetaInfo)

@property
def extract(self) -> QueryElementExtract:
return QueryElementExtract(
values=(
self.id,
self.level_type.name,
tuple(formula.extract for formula in self.select),
tuple(formula.extract for formula in self.group_by),
tuple(formula.extract for formula in self.filters),
tuple(formula.extract for formula in self.order_by),
tuple(formula.extract for formula in self.join_on),
self.joined_from.extract,
self.limit,
self.offset,
self.meta.extract,
),
)

def get_complexity(self) -> int:
return sum(formula.complexity for formula in self.all_formulas)

Expand Down Expand Up @@ -262,6 +343,14 @@ class SubqueryFromObject(FromObject):

@attr.s(frozen=True)
class CompiledMultiQueryBase(abc.ABC):
@property
def extract(self) -> QueryElementExtract:
return QueryElementExtract(
values=(
tuple(query.extract for query in sorted(self.iter_queries(), key=lambda query: query.id)),
),
)

@abc.abstractmethod
def iter_queries(self) -> Iterable[CompiledQuery]:
raise NotImplementedError
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -4,6 +4,8 @@
Any,
Optional,
TypeVar,
Hashable,
NamedTuple,
)

import attr
Expand All @@ -15,6 +17,10 @@
)


class QueryElementExtract(NamedTuple):
values: tuple[Optional[Hashable], ...]


_QUERY_META_TV = TypeVar("_QUERY_META_TV", bound="QueryMetaInfo")


Expand All @@ -28,6 +34,20 @@ class QueryMetaInfo:
subquery_limit: int = attr.ib(kw_only=True, default=None)
empty_query_mode: EmptyQueryMode = attr.ib(kw_only=True, default=EmptyQueryMode.error)

@property
def extract(self) -> QueryElementExtract:
return QueryElementExtract(
values=(
self.query_type.name,
tuple(self.phantom_select_ids),
tuple(self.field_order) if self.field_order is not None else None,
self.row_count_hard_limit,
self.from_subquery,
self.subquery_limit,
self.empty_query_mode.name,
),
)

@property
def result_field_ids(self) -> list[str]:
assert self.field_order is not None
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -251,6 +251,7 @@ def translate_flat_query(
offset=compiled_flat_query.offset,
column_list=column_list,
meta=translated_meta,
extract=compiled_flat_query.extract,
)

def get_collected_stats(self) -> TranslationStats:
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -32,7 +32,7 @@
DataTypeParams,
)
from dl_query_processing.compilation.primitives import FromObject
from dl_query_processing.compilation.query_meta import QueryMetaInfo
from dl_query_processing.compilation.query_meta import QueryMetaInfo, QueryElementExtract
from dl_query_processing.enums import ExecutionLevel


Expand All @@ -42,6 +42,15 @@ class DetailedType(NamedTuple):
# TODO: native_type: Optional[GenericNativeType] = None
formula_data_type: Optional[DataType] = None
formula_data_type_params: Optional[DataTypeParams] = None

@property
def extract(self) -> QueryElementExtract:
return QueryElementExtract(values=(
self.field_id,
self.data_type.name,
self.formula_data_type.name if self.formula_data_type is not None else None,
self.formula_data_type_params.as_primitive() if self.formula_data_type is not None else None,
))


_META_TV = TypeVar("_META_TV", bound="TranslatedQueryMetaInfo")
Expand All @@ -51,6 +60,16 @@ class DetailedType(NamedTuple):
class TranslatedQueryMetaInfo(QueryMetaInfo):
detailed_types: Optional[list[Optional[DetailedType]]] = attr.ib(kw_only=True, factory=list) # type: ignore

@property
def extract(self) -> QueryElementExtract:
return QueryElementExtract(values=(
*super().extract,
(
dt.extract if dt is not None else None
for dt in self.detailed_types
) if self.detailed_types is not None else None,
))

@classmethod
def from_comp_meta(
cls: Type[_META_TV],
Expand Down Expand Up @@ -107,6 +126,8 @@ class TranslatedFlatQuery:
column_list: list[SchemaColumn] = attr.ib(kw_only=True)
meta: TranslatedQueryMetaInfo = attr.ib(kw_only=True, factory=TranslatedQueryMetaInfo)

extract: QueryElementExtract = attr.ib(kw_only=True)

def is_empty(self) -> bool:
return not self.select

Expand All @@ -116,6 +137,14 @@ def is_empty(self) -> bool:

@attr.s(frozen=True)
class TranslatedMultiQueryBase(abc.ABC):
@property
def extract(self) -> QueryElementExtract:
return QueryElementExtract(
values=(
tuple(query.extract for query in sorted(self.iter_queries(), key=lambda query: query.id)),
),
)

@abc.abstractmethod
def iter_queries(self) -> Iterable[TranslatedFlatQuery]:
raise NotImplementedError
Expand Down

0 comments on commit 50b5a54

Please sign in to comment.