Skip to content

Commit

Permalink
Add type hints to db_access (#262)
Browse files (browse the repository at this point in the history)
  • Loading branch information
kklein authored Feb 4, 2025
1 parent b9acf3c commit 1429f9d
Show file tree
Hide file tree
Showing 4 changed files with 245 additions and 127 deletions.
40 changes: 21 additions & 19 deletions src/datajudge/constraints/date.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,7 @@
from __future__ import annotations

import datetime as dt
from typing import Any, Optional, Tuple, Union
from typing import Any, Union

import sqlalchemy as sa

Expand Down Expand Up @@ -38,15 +40,15 @@ def __init__(
ref: DataReference,
use_lower_bound_reference: bool,
column_type: str,
name: Optional[str] = None,
name: str | None = None,
cache_size=None,
*,
ref2: Optional[DataReference] = None,
min_value: Optional[str] = None,
ref2: DataReference | None = None,
min_value: str | None = None,
):
self.format = get_format_from_column_type(column_type)
self.use_lower_bound_reference = use_lower_bound_reference
min_date: Optional[dt.date] = None
min_date: dt.date | None = None
if min_value is not None:
min_date = dt.datetime.strptime(min_value, INPUT_DATE_FORMAT).date()
super().__init__(
Expand All @@ -59,11 +61,11 @@ def __init__(

def retrieve(
self, engine: sa.engine.Engine, ref: DataReference
) -> Tuple[dt.date, OptionalSelections]:
) -> tuple[dt.date, OptionalSelections]:
result, selections = db_access.get_min(engine, ref)
return convert_to_date(result, self.format), selections

def compare(self, min_factual: dt.date, min_target: dt.date) -> Tuple[bool, str]:
def compare(self, min_factual: dt.date, min_target: dt.date) -> tuple[bool, str]:
if min_target is None:
return TestResult(True, "")
if min_factual is None:
Expand Down Expand Up @@ -91,15 +93,15 @@ def __init__(
ref: DataReference,
use_upper_bound_reference: bool,
column_type: str,
name: Optional[str] = None,
name: str | None = None,
cache_size=None,
*,
ref2: Optional[DataReference] = None,
max_value: Optional[str] = None,
ref2: DataReference | None = None,
max_value: str | None = None,
):
self.format = get_format_from_column_type(column_type)
self.use_upper_bound_reference = use_upper_bound_reference
max_date: Optional[dt.date] = None
max_date: dt.date | None = None
if max_value is not None:
max_date = dt.datetime.strptime(max_value, INPUT_DATE_FORMAT).date()
super().__init__(
Expand All @@ -112,11 +114,11 @@ def __init__(

def retrieve(
self, engine: sa.engine.Engine, ref: DataReference
) -> Tuple[dt.date, OptionalSelections]:
) -> tuple[dt.date, OptionalSelections]:
value, selections = db_access.get_max(engine, ref)
return convert_to_date(value, self.format), selections

def compare(self, max_factual: dt.date, max_target: dt.date) -> Tuple[bool, str]:
def compare(self, max_factual: dt.date, max_target: dt.date) -> tuple[bool, str]:
if max_factual is None:
return True, None
if max_target is None:
Expand Down Expand Up @@ -146,7 +148,7 @@ def __init__(
min_fraction: float,
lower_bound: str,
upper_bound: str,
name: Optional[str] = None,
name: str | None = None,
cache_size=None,
):
super().__init__(ref, ref_value=min_fraction, name=name, cache_size=cache_size)
Expand All @@ -155,14 +157,14 @@ def __init__(

def retrieve(
self, engine: sa.engine.Engine, ref: DataReference
) -> Tuple[float, OptionalSelections]:
) -> tuple[float | None, OptionalSelections]:
return db_access.get_fraction_between(
engine, ref, self.lower_bound, self.upper_bound
)

def compare(
self, fraction_factual: float, fraction_target: float
) -> Tuple[bool, str]:
) -> tuple[bool, str]:
assertion_text = (
f"{self.ref} has {fraction_factual} < "
f"{fraction_target} of values between {self.lower_bound} and "
Expand All @@ -175,7 +177,7 @@ def compare(
class DateNoOverlap(NoOverlapConstraint):
_DIMENSIONS = 1

def compare(self, factual: Tuple[int, int], target: Any) -> Tuple[bool, str]:
def compare(self, factual: tuple[int, int], target: Any) -> tuple[bool, str]:
n_violation_keys, n_distinct_key_values = factual
if n_distinct_key_values == 0:
return TestResult.success()
Expand All @@ -193,7 +195,7 @@ def compare(self, factual: Tuple[int, int], target: Any) -> Tuple[bool, str]:
class DateNoOverlap2d(NoOverlapConstraint):
_DIMENSIONS = 2

def compare(self, factual: Tuple[int, int], target: Any) -> Tuple[bool, str]:
def compare(self, factual: tuple[int, int], target: Any) -> tuple[bool, str]:
n_violation_keys, n_distinct_key_values = factual
if n_distinct_key_values == 0:
return TestResult.success()
Expand Down Expand Up @@ -225,7 +227,7 @@ def select(self, engine: sa.engine.Engine, ref: DataReference):
# executing it, one would want to list this selection here as well.
return sample_selection, n_violations_selection

def compare(self, factual: Tuple[int, int], target: Any) -> Tuple[bool, str]:
def compare(self, factual: tuple[int, int], target: Any) -> tuple[bool, str]:
n_violation_keys, n_distinct_key_values = factual
if n_distinct_key_values == 0:
return TestResult.success()
Expand Down
60 changes: 29 additions & 31 deletions src/datajudge/constraints/numeric.py
Original file line number Diff line number Diff line change
@@ -1,4 +1,6 @@
from typing import Any, Optional, Tuple
from __future__ import annotations

from typing import Any

import sqlalchemy as sa

Expand All @@ -12,11 +14,11 @@ class NumericMin(Constraint):
def __init__(
self,
ref: DataReference,
name: Optional[str] = None,
name: str | None = None,
cache_size=None,
*,
ref2: Optional[DataReference] = None,
min_value: Optional[float] = None,
ref2: DataReference | None = None,
min_value: float | None = None,
):
super().__init__(
ref,
Expand All @@ -28,12 +30,10 @@ def __init__(

def retrieve(
self, engine: sa.engine.Engine, ref: DataReference
) -> Tuple[float, OptionalSelections]:
) -> tuple[float, OptionalSelections]:
return db_access.get_min(engine, ref)

def compare(
self, min_factual: float, min_target: float
) -> Tuple[bool, Optional[str]]:
def compare(self, min_factual: float, min_target: float) -> tuple[bool, str | None]:
if min_target is None:
return True, None
if min_factual is None:
Expand All @@ -52,11 +52,11 @@ class NumericMax(Constraint):
def __init__(
self,
ref: DataReference,
name: Optional[str] = None,
name: str | None = None,
cache_size=None,
*,
ref2: Optional[DataReference] = None,
max_value: Optional[float] = None,
ref2: DataReference | None = None,
max_value: float | None = None,
):
super().__init__(
ref,
Expand All @@ -68,12 +68,10 @@ def __init__(

def retrieve(
self, engine: sa.engine.Engine, ref: DataReference
) -> Tuple[float, OptionalSelections]:
) -> tuple[float, OptionalSelections]:
return db_access.get_max(engine, ref)

def compare(
self, max_factual: float, max_target: float
) -> Tuple[bool, Optional[str]]:
def compare(self, max_factual: float, max_target: float) -> tuple[bool, str | None]:
if max_factual is None:
return True, None
if max_target is None:
Expand All @@ -95,7 +93,7 @@ def __init__(
min_fraction: float,
lower_bound: float,
upper_bound: float,
name: Optional[str] = None,
name: str | None = None,
cache_size=None,
):
super().__init__(ref, ref_value=min_fraction, name=name, cache_size=cache_size)
Expand All @@ -104,7 +102,7 @@ def __init__(

def retrieve(
self, engine: sa.engine.Engine, ref: DataReference
) -> Tuple[float, OptionalSelections]:
) -> tuple[float | None, OptionalSelections]:
return db_access.get_fraction_between(
engine,
ref,
Expand All @@ -114,7 +112,7 @@ def retrieve(

def compare(
self, fraction_factual: float, fraction_target: float
) -> Tuple[bool, Optional[str]]:
) -> tuple[bool, str | None]:
if fraction_factual is None:
return True, "Empty selection."
assertion_text = (
Expand All @@ -132,11 +130,11 @@ def __init__(
self,
ref: DataReference,
max_absolute_deviation: float,
name: Optional[str] = None,
name: str | None = None,
cache_size=None,
*,
ref2: Optional[DataReference] = None,
mean_value: Optional[float] = None,
ref2: DataReference | None = None,
mean_value: float | None = None,
):
super().__init__(
ref,
Expand All @@ -149,7 +147,7 @@ def __init__(

def retrieve(
self, engine: sa.engine.Engine, ref: DataReference
) -> Tuple[float, OptionalSelections]:
) -> tuple[float, OptionalSelections]:
result, selections = db_access.get_mean(engine, ref)
return result, selections

Expand Down Expand Up @@ -178,13 +176,13 @@ def __init__(
self,
ref: DataReference,
percentage: float,
max_absolute_deviation: Optional[float] = None,
max_relative_deviation: Optional[float] = None,
name: Optional[str] = None,
max_absolute_deviation: float | None = None,
max_relative_deviation: float | None = None,
name: str | None = None,
cache_size=None,
*,
ref2: Optional[DataReference] = None,
expected_percentile: Optional[float] = None,
ref2: DataReference | None = None,
expected_percentile: float | None = None,
):
super().__init__(
ref,
Expand Down Expand Up @@ -216,13 +214,13 @@ def __init__(

def retrieve(
self, engine: sa.engine.Engine, ref: DataReference
) -> Tuple[float, OptionalSelections]:
) -> tuple[float, OptionalSelections]:
result, selections = db_access.get_percentile(engine, ref, self.percentage)
return result, selections

def compare(
self, percentile_factual: float, percentile_target: float
) -> Tuple[bool, Optional[str]]:
) -> tuple[bool, str | None]:
abs_diff = abs(percentile_factual - percentile_target)
if (
self.max_absolute_deviation is not None
Expand Down Expand Up @@ -269,7 +267,7 @@ def select(self, engine: sa.engine.Engine, ref: DataReference):
# executing it, one would want to list this selection here as well.
return sample_selection, n_violations_selection

def compare(self, factual: Tuple[int, int], target: Any) -> Tuple[bool, str]:
def compare(self, factual: tuple[int, int], target: Any) -> tuple[bool, str]:
n_violation_keys, n_distinct_key_values = factual
if n_distinct_key_values == 0:
return TestResult.success()
Expand All @@ -287,7 +285,7 @@ def compare(self, factual: Tuple[int, int], target: Any) -> Tuple[bool, str]:
class NumericNoOverlap(NoOverlapConstraint):
_DIMENSIONS = 1

def compare(self, factual: Tuple[int, int], target: Any) -> Tuple[bool, str]:
def compare(self, factual: tuple[int, int], target: Any) -> tuple[bool, str]:
n_violation_keys, n_distinct_key_values = factual
if n_distinct_key_values == 0:
return TestResult.success()
Expand Down
14 changes: 14 additions & 0 deletions src/datajudge/constraints/row.py
Original file line number Diff line number Diff line change
Expand Up @@ -22,6 +22,8 @@ def __init__(
self.max_missing_fraction_getter = max_missing_fraction_getter

def test(self, engine: sa.engine.Engine) -> TestResult:
if self.ref is None or self.ref2 is None:
raise ValueError()
if db_access.is_impala(engine):
raise NotImplementedError("Currently not implemented for impala.")
self.max_missing_fraction = self.max_missing_fraction_getter(engine)
Expand All @@ -36,6 +38,8 @@ def test(self, engine: sa.engine.Engine) -> TestResult:

class RowEquality(Row):
def get_factual_value(self, engine: sa.engine.Engine) -> Tuple[int, int]:
if self.ref is None or self.ref2 is None:
raise ValueError()
n_rows_missing_left, selections_left = db_access.get_row_difference_count(
engine, self.ref, self.ref2
)
Expand All @@ -46,6 +50,8 @@ def get_factual_value(self, engine: sa.engine.Engine) -> Tuple[int, int]:
return n_rows_missing_left, n_rows_missing_right

def get_target_value(self, engine: sa.engine.Engine) -> int:
if self.ref is None or self.ref2 is None:
raise ValueError()
n_rows_total, selections = db_access.get_unique_count_union(
engine, self.ref, self.ref2
)
Expand Down Expand Up @@ -80,6 +86,8 @@ def compare(
class RowSubset(Row):
@lru_cache(maxsize=None)
def get_factual_value(self, engine: sa.engine.Engine) -> int:
if self.ref is None or self.ref2 is None:
raise ValueError()
n_rows_missing, selections = db_access.get_row_difference_count(
engine,
self.ref,
Expand Down Expand Up @@ -118,13 +126,17 @@ def compare(

class RowSuperset(Row):
def get_factual_value(self, engine: sa.engine.Engine) -> int:
if self.ref is None or self.ref2 is None:
raise ValueError()
n_rows_missing, selections = db_access.get_row_difference_count(
engine, self.ref2, self.ref
)
self.factual_selections = selections
return n_rows_missing

def get_target_value(self, engine: sa.engine.Engine) -> int:
if self.ref is None or self.ref2 is None:
raise ValueError()
n_rows_total, selections = db_access.get_unique_count(engine, self.ref2)
self.target_selections = selections
return n_rows_total
Expand Down Expand Up @@ -180,6 +192,8 @@ def __init__(
)

def test(self, engine: sa.engine.Engine) -> TestResult:
if self.ref is None or self.ref2 is None:
raise ValueError()
missing_fraction, n_rows_match, selections = db_access.get_row_mismatch(
engine, self.ref, self.ref2, self.match_and_compare
)
Expand Down
Loading

0 comments on commit 1429f9d

Please sign in to comment.