Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Consistently use Python 3.10 type annotations #263

Merged
merged 12 commits into from
Feb 5, 2025
48 changes: 25 additions & 23 deletions src/datajudge/constraints/base.py
Original file line number Diff line number Diff line change
@@ -1,7 +1,9 @@
from __future__ import annotations

import abc
from dataclasses import dataclass, field
from functools import lru_cache
from typing import Any, Callable, Collection, List, Optional, Tuple, TypeVar, Union
from typing import Any, Callable, Collection, List, Optional, TypeVar

import sqlalchemy as sa

Expand All @@ -16,7 +18,7 @@
ToleranceGetter = Callable[[sa.engine.Engine], float]


def uncommon_substrings(string1: str, string2: str) -> Tuple[str, str]:
def uncommon_substrings(string1: str, string2: str) -> tuple[str, str]:
qualifiers1 = string1.split(".")
qualifiers2 = string2.split(".")
if qualifiers1[0] != qualifiers2[0]:
Expand All @@ -29,29 +31,29 @@ def uncommon_substrings(string1: str, string2: str) -> Tuple[str, str]:
@dataclass(frozen=True)
class TestResult:
outcome: bool
_failure_message: Optional[str] = field(default=None, repr=False)
_constraint_description: Optional[str] = field(default=None, repr=False)
_factual_queries: Optional[str] = field(default=None, repr=False)
_target_queries: Optional[str] = field(default=None, repr=False)
_failure_message: str | None = field(default=None, repr=False)
_constraint_description: str | None = field(default=None, repr=False)
_factual_queries: str | None = field(default=None, repr=False)
_target_queries: str | None = field(default=None, repr=False)

def formatted_failure_message(self, formatter: Formatter) -> Optional[str]:
def formatted_failure_message(self, formatter: Formatter) -> str | None:
return (
formatter.fmt_str(self._failure_message) if self._failure_message else None
)

def formatted_constraint_description(self, formatter: Formatter) -> Optional[str]:
def formatted_constraint_description(self, formatter: Formatter) -> str | None:
return (
formatter.fmt_str(self._constraint_description)
if self._constraint_description
else None
)

@property
def failure_message(self) -> Optional[str]:
def failure_message(self) -> str | None:
return self.formatted_failure_message(DEFAULT_FORMATTER)

@property
def constraint_description(self) -> Optional[str]:
def constraint_description(self) -> str | None:
return self.formatted_constraint_description(DEFAULT_FORMATTER)

@property
Expand Down Expand Up @@ -121,12 +123,12 @@ def __init__(
self,
ref: DataReference,
*,
ref2: Optional[DataReference] = None,
ref_value: Optional[Any] = None,
name: Optional[str] = None,
output_processors: Optional[
Union[OutputProcessor, List[OutputProcessor]]
] = output_processor_limit,
ref2: DataReference | None = None,
ref_value: Any = None,
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Not critical, but the change from optional any to any feels wrong. But probably is an issue that Any is too vague to start with.

Copy link
Collaborator Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I see your point. Do you suggestion yet?

Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I have no suggestion. Well, one, but it's quite vague. To create a type that reflects what a ref_value can be. But it feels too much work for little gain.

name: str | None = None,
output_processors: OutputProcessor
| list[OutputProcessor]
| None = output_processor_limit,
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Why a None did sneak in here?

Copy link
Collaborator Author

@kklein kklein Feb 5, 2025

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Before this change, output_processors annotated with Optional[Union[OutputProcessor, List[OutputProcessor]].

If we convert the Optional, this gives us:
Union[OutputProcessor, List[OutputProcessor] | None.

If we now also covert the Union and (List to list), this gives us OutputProcessor | list[OutputProcessor] | None, which, afaict corresponds to the suggested change.

Does that make sense?

Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Seems I missed a step in the chain of transformations.

cache_size=None,
):
self._check_if_valid_between_or_within(ref2, ref_value)
Expand All @@ -136,8 +138,8 @@ def __init__(
self.name = name
self.factual_selections: OptionalSelections = None
self.target_selections: OptionalSelections = None
self.factual_queries: Optional[List[str]] = None
self.target_queries: Optional[List[str]] = None
self.factual_queries: list[str] | None = None
self.target_queries: list[str] | None = None

if (output_processors is not None) and (
not isinstance(output_processors, list)
Expand All @@ -156,7 +158,9 @@ def _setup_caching(self):
self.get_target_value = lru_cache(self.cache_size)(self.get_target_value) # type: ignore[method-assign]

def _check_if_valid_between_or_within(
self, ref2: Optional[DataReference], ref_value: Optional[Any]
self,
ref2: DataReference | None,
ref_value: Any,
):
"""Check whether exactly one of ref2 and ref_value arguments have been used."""
class_name = self.__class__.__name__
Expand Down Expand Up @@ -228,13 +232,11 @@ def condition_string(self) -> str:

def retrieve(
self, engine: sa.engine.Engine, ref: DataReference
) -> Tuple[Any, OptionalSelections]:
) -> tuple[Any, OptionalSelections]:
"""Retrieve the value of interest for a DataReference from database."""
pass

def compare(
self, value_factual: Any, value_target: Any
) -> Tuple[bool, Optional[str]]:
def compare(self, value_factual: Any, value_target: Any) -> tuple[bool, str | None]:
pass

def test(self, engine: sa.engine.Engine) -> TestResult:
Expand Down
31 changes: 16 additions & 15 deletions src/datajudge/constraints/column.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,6 @@
from __future__ import annotations

import abc
from typing import List, Optional, Tuple, Union

import sqlalchemy as sa

Expand All @@ -11,7 +12,7 @@
class Column(Constraint, abc.ABC):
def retrieve(
self, engine: sa.engine.Engine, ref: DataReference
) -> Tuple[List[str], OptionalSelections]:
) -> tuple[list[str], OptionalSelections]:
# TODO: This does not 'belong' here. Rather, `retrieve` should be free of
# side effects. This should be removed as soon as snowflake column capitalization
# is fixed by snowflake-sqlalchemy.
Expand All @@ -24,15 +25,15 @@ class ColumnExistence(Column):
def __init__(
self,
ref: DataReference,
columns: List[str],
name: Optional[str] = None,
columns: list[str],
name: str | None = None,
cache_size=None,
):
super().__init__(ref, ref_value=columns, name=name, cache_size=cache_size)

def compare(
self, column_names_factual: List[str], column_names_target: List[str]
) -> Tuple[bool, str]:
self, column_names_factual: list[str], column_names_target: list[str]
) -> tuple[bool, str]:
excluded_columns = list(
filter(lambda c: c not in column_names_factual, column_names_target)
)
Expand All @@ -45,8 +46,8 @@ def compare(

class ColumnSubset(Column):
def compare(
self, column_names_factual: List[str], column_names_target: List[str]
) -> Tuple[bool, str]:
self, column_names_factual: list[str], column_names_target: list[str]
) -> tuple[bool, str]:
missing_columns = list(
filter(lambda c: c not in column_names_target, column_names_factual)
)
Expand All @@ -59,8 +60,8 @@ def compare(

class ColumnSuperset(Column):
def compare(
self, column_names_factual: List[str], column_names_target: List[str]
) -> Tuple[bool, str]:
self, column_names_factual: list[str], column_names_target: list[str]
) -> tuple[bool, str]:
missing_columns = list(
filter(lambda c: c not in column_names_factual, column_names_target)
)
Expand All @@ -87,9 +88,9 @@ def __init__(
self,
ref: DataReference,
*,
ref2: Optional[DataReference] = None,
column_type: Optional[Union[str, sa.types.TypeEngine]] = None,
name: Optional[str] = None,
ref2: DataReference | None = None,
column_type: str | sa.types.TypeEngine | None = None,
name: str | None = None,
cache_size=None,
):
super().__init__(
Expand All @@ -103,11 +104,11 @@ def __init__(

def retrieve(
self, engine: sa.engine.Engine, ref: DataReference
) -> Tuple[sa.types.TypeEngine, OptionalSelections]:
) -> tuple[sa.types.TypeEngine, OptionalSelections]:
result, selections = db_access.get_column_type(engine, ref)
return result, selections

def compare(self, column_type_factual, column_type_target) -> Tuple[bool, str]:
def compare(self, column_type_factual, column_type_target) -> tuple[bool, str]:
assertion_message = (
f"{self.ref} is {column_type_factual} instead of {column_type_target}."
)
Expand Down
12 changes: 7 additions & 5 deletions src/datajudge/constraints/groupby.py
Original file line number Diff line number Diff line change
@@ -1,4 +1,6 @@
from typing import Any, Optional, Tuple
from __future__ import annotations

from typing import Any

import sqlalchemy as sa

Expand All @@ -13,11 +15,11 @@ def __init__(
ref: DataReference,
aggregation_column: str,
start_value: int = 0,
name: Optional[str] = None,
name: str | None = None,
cache_size=None,
*,
tolerance: float = 0,
ref2: Optional[DataReference] = None,
ref2: DataReference | None = None,
):
super().__init__(ref, ref2=ref2, ref_value=object(), name=name)
self.aggregation_column = aggregation_column
Expand All @@ -27,14 +29,14 @@ def __init__(

def retrieve(
self, engine: sa.engine.Engine, ref: DataReference
) -> Tuple[Any, OptionalSelections]:
) -> tuple[Any, OptionalSelections]:
result, selections = db_access.get_column_array_agg(
engine, ref, self.aggregation_column
)
result = {fact[:-1]: fact[-1] for fact in result}
return result, selections

def compare(self, factual: Any, target: Any) -> Tuple[bool, Optional[str]]:
def compare(self, factual: Any, target: Any) -> tuple[bool, str | None]:
def missing_from_range(values, start=0):
return set(range(start, max(values) + start)) - set(values)

Expand Down
32 changes: 17 additions & 15 deletions src/datajudge/constraints/interval.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,7 @@
from __future__ import annotations

import abc
from typing import Any, List, Optional, Tuple
from typing import Any

import sqlalchemy as sa

Expand All @@ -14,11 +16,11 @@ class IntervalConstraint(Constraint):
def __init__(
self,
ref: DataReference,
key_columns: Optional[List[str]],
start_columns: List[str],
end_columns: List[str],
key_columns: list[str] | None,
start_columns: list[str],
end_columns: list[str],
max_relative_n_violations: float,
name: Optional[str] = None,
name: str | None = None,
cache_size=None,
):
super().__init__(ref, ref_value=object(), name=name)
Expand All @@ -44,7 +46,7 @@ def _validate_dimensions(self):

def retrieve(
self, engine: sa.engine.Engine, ref: DataReference
) -> Tuple[Tuple[int, int], OptionalSelections]:
) -> tuple[tuple[int, int], OptionalSelections]:
keys_ref = DataReference(
data_source=self.ref.data_source,
columns=self.key_columns,
Expand All @@ -69,12 +71,12 @@ class NoOverlapConstraint(IntervalConstraint):
def __init__(
self,
ref: DataReference,
key_columns: Optional[List[str]],
start_columns: List[str],
end_columns: List[str],
key_columns: list[str] | None,
start_columns: list[str],
end_columns: list[str],
max_relative_n_violations: float,
end_included: bool,
name: Optional[str] = None,
name: str | None = None,
cache_size=None,
):
self.end_included = end_included
Expand Down Expand Up @@ -110,12 +112,12 @@ class NoGapConstraint(IntervalConstraint):
def __init__(
self,
ref: DataReference,
key_columns: Optional[List[str]],
start_columns: List[str],
end_columns: List[str],
key_columns: list[str] | None,
start_columns: list[str],
end_columns: list[str],
max_relative_n_violations: float,
legitimate_gap_size: float,
name: Optional[str] = None,
name: str | None = None,
cache_size=None,
):
self.legitimate_gap_size = legitimate_gap_size
Expand All @@ -134,5 +136,5 @@ def select(self, engine: sa.engine.Engine, ref: DataReference):
pass

@abc.abstractmethod
def compare(self, factual: Tuple[int, int], target: Any) -> Tuple[bool, str]:
def compare(self, factual: tuple[int, int], target: Any) -> tuple[bool, str]:
pass
25 changes: 13 additions & 12 deletions src/datajudge/constraints/miscs.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,6 @@
from __future__ import annotations

import warnings
from typing import List, Optional, Set, Tuple

import sqlalchemy as sa

Expand All @@ -12,24 +13,24 @@ class PrimaryKeyDefinition(Constraint):
def __init__(
self,
ref,
primary_keys: List[str],
name: Optional[str] = None,
primary_keys: list[str],
name: str | None = None,
cache_size=None,
):
super().__init__(ref, ref_value=set(primary_keys), name=name)

def retrieve(
self, engine: sa.engine.Engine, ref: DataReference
) -> Tuple[Set[str], OptionalSelections]:
) -> tuple[set[str], OptionalSelections]:
if db_access.is_impala(engine):
raise NotImplementedError("Primary key retrieval does not work for Impala.")
values, selections = db_access.get_primary_keys(engine, self.ref)
return set(values), selections

# Note: Exact equality!
def compare(
self, primary_keys_factual: Set[str], primary_keys_target: Set[str]
) -> Tuple[bool, Optional[str]]:
self, primary_keys_factual: set[str], primary_keys_target: set[str]
) -> tuple[bool, str | None]:
assertion_message = ""
result = True
# If both are true, just report one.
Expand Down Expand Up @@ -61,7 +62,7 @@ def __init__(
max_duplicate_fraction: float = 0,
max_absolute_n_duplicates: int = 0,
infer_pk_columns: bool = False,
name: Optional[str] = None,
name: str | None = None,
cache_size=None,
):
if max_duplicate_fraction != 0 and max_absolute_n_duplicates != 0:
Expand Down Expand Up @@ -125,7 +126,7 @@ def test(self, engine: sa.engine.Engine) -> TestResult:


class FunctionalDependency(Constraint):
def __init__(self, ref: DataReference, key_columns: List[str], **kwargs):
def __init__(self, ref: DataReference, key_columns: list[str], **kwargs):
super().__init__(ref, ref_value=object(), **kwargs)
self.key_columns = key_columns

Expand Down Expand Up @@ -155,10 +156,10 @@ def __init__(
self,
ref,
*,
ref2: Optional[DataReference] = None,
max_null_fraction: Optional[float] = None,
ref2: DataReference | None = None,
max_null_fraction: float | None = None,
max_relative_deviation: float = 0,
name: Optional[str] = None,
name: str | None = None,
cache_size=None,
):
super().__init__(
Expand All @@ -184,7 +185,7 @@ def retrieve(self, engine: sa.engine.Engine, ref: DataReference):

def compare(
self, missing_fraction_factual: float, missing_fracion_target: float
) -> Tuple[bool, Optional[str]]:
) -> tuple[bool, str | None]:
threshold = missing_fracion_target * (1 + self.max_relative_deviation)
result = missing_fraction_factual <= threshold
assertion_text = (
Expand Down
Loading
Loading