Feature custom check (#267)
* Added `is_custom` check to PySpark

* Added `is_custom` validation for PySpark
canimus authored Jun 29, 2024
1 parent 55d74f6 commit d9ff574
Showing 48 changed files with 273 additions and 81 deletions.
1 change: 1 addition & 0 deletions README.md
@@ -239,6 +239,7 @@ Check | Description | DataType
`is_on_schedule` | Confirms time windows on date fields, e.g. `9:00 - 17:00` | _timestamp_
`is_daily` | Verifies daily continuity on date fields by default. `[2,3,4,5,6]` represents `Mon-Fri` in PySpark; new schedules can be supplied for custom date continuity | _date_
`has_workflow` | Adjacency matrix validation on `3-column` graph, based on `group`, `event`, `order` columns. | _agnostic_
+ `is_custom` | User-defined custom `function` applied to the dataframe for row-based validation. | _agnostic_
`satisfies` | An open `SQL expression` builder to construct custom checks | _agnostic_
`validate` | The ultimate transformation of a check with a `dataframe` input for validation | _agnostic_
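
To illustrate the new entry, here is a minimal sketch of `is_custom` in PySpark; the dataframe, column name, and lambda below are invented for the example:

    # Hypothetical usage of the new is_custom check
    import pyspark.sql.functions as F
    from pyspark.sql import SparkSession
    from cuallee import Check, CheckLevel

    spark = SparkSession.builder.getOrCreate()
    df = spark.createDataFrame([(1, 10), (2, 20), (3, -5)], ["id", "amount"])

    check = Check(CheckLevel.WARNING, "custom_demo")
    # The callable receives the dataframe; the last column of its result
    # carries the per-row boolean outcome used to summarize the check.
    check.is_custom("amount", lambda d: d.withColumn("ok", F.col("amount") > 0))
    check.validate(df).show(truncate=False)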

20 changes: 19 additions & 1 deletion cuallee/__init__.py
@@ -7,7 +7,7 @@
from dataclasses import dataclass
from datetime import datetime, timedelta, timezone
from types import ModuleType
- from typing import Any, Dict, List, Literal, Optional, Protocol, Tuple, Union
+ from typing import Any, Dict, List, Literal, Optional, Protocol, Tuple, Union, Callable
from toolz import compose, valfilter # type: ignore
from toolz.curried import map as map_curried

@@ -55,6 +55,8 @@
except (ModuleNotFoundError, ImportError):
logger.debug("KO: BigQuery")

+ class CustomComputeException(Exception):
+     pass

class CheckLevel(enum.Enum):
"""Level of verifications in cuallee"""
@@ -1165,6 +1167,22 @@ def has_workflow(
)
return self

+ def is_custom(
+     self, column: Union[str, List[str]], fn: Callable = None, pct: float = 1.0
+ ):
+     """
+     Applies a user-defined function to the to-be-validated dataframe and uses
+     the last column of the transformed dataframe to summarize the check
+
+     Args:
+         column (str): Column(s) required by the custom function
+         fn (Callable): A function that receives a dataframe as input and returns a dataframe with at least 1 column as result
+         pct (float): The threshold percentage required to pass
+     """
+     (Rule("is_custom", column, fn, CheckDataType.AGNOSTIC, pct) >> self._rule)
+     return self
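
As a usage note, and with hypothetical names, the method chains like any other check, and `pct` relaxes the threshold:

    # Hypothetical: pass when at least 80% of rows satisfy the custom predicate
    check = Check(CheckLevel.WARNING, "demo")
    check.is_custom("value", lambda df: df.withColumn("ok", df.value.isNotNull()), pct=0.8)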

def validate(self, dataframe: Any):
"""
Compute all rules in this check for specific data frame
39 changes: 36 additions & 3 deletions cuallee/pyspark_validation.py
@@ -8,10 +8,10 @@
import pyspark.sql.types as T
from pyspark.sql import Window as W
from pyspark.sql import Column, DataFrame, Row
- from toolz import first, valfilter # type: ignore
+ from toolz import first, valfilter, last # type: ignore

import cuallee.utils as cuallee_utils
- from cuallee import Check, ComputeEngine, Rule
+ from cuallee import Check, ComputeEngine, Rule, CustomComputeException

import os

@@ -587,6 +587,32 @@ def _execute(dataframe: DataFrame, key: str):

return self.compute_instruction

+ def is_custom(self, rule: Rule):
+     """Validates the dataframe by applying a custom function and resolving boolean values in its last column"""
+
+     predicate = None
+
+     def _execute(dataframe: DataFrame, key: str):
+         try:
+             assert isinstance(rule.value, Callable), "Please provide a Callable/Function for validation"
+             computed_frame = rule.value(dataframe)
+             assert isinstance(computed_frame, DataFrame), "Custom function does not return a PySpark DataFrame"
+             assert len(computed_frame.columns) >= 1, "Custom function should return at least one column"
+             computed_column = last(computed_frame.columns)
+             return computed_frame.select(
+                 F.sum(F.col(f"`{computed_column}`").cast("integer")).alias(key)
+             )
+         except Exception as err:
+             raise CustomComputeException(str(err))
+
+     self.compute_instruction = ComputeInstruction(
+         predicate, _execute, ComputeMethod.TRANSFORM
+     )
+
+     return self.compute_instruction
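
For context, a sketch (function and column names invented) of a callable that satisfies this contract: it returns a DataFrame whose last column is boolean, which `_execute` casts to integer and sums into the summary key:

    import pyspark.sql.functions as F
    from pyspark.sql import DataFrame

    def non_negative_amount(df: DataFrame) -> DataFrame:
        # is_custom inspects the *last* column of the returned frame
        return df.withColumn("amount_ok", F.col("amount") >= 0)

    # check.is_custom("amount", non_negative_amount) then sums the integer casts
    # of `amount_ok`; rows where the flag is False count as violations.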

def _field_type_filter(
dataframe: DataFrame,
@@ -769,6 +795,13 @@ def summary(check: Check, dataframe: DataFrame) -> DataFrame:
# TODO: Check should have options for compute engine
spark = SparkSession.builder.getOrCreate()

+ def _value(x):
+     """Removes verbosity for Callable values"""
+     if isinstance(x, Callable):
+         return "f(x)"
+     else:
+         return str(x)
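
A quick sketch of the intent, with illustrative inputs:

    _value(lambda df: df)  # -> "f(x)"   callables collapse to a short token
    _value(0.75)           # -> "0.75"   anything else falls back to str(x)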

# Compute the expression
computed_expressions = compute(check._rule)
if (int(spark.version.replace(".", "")[:3]) < 330) or (
@@ -807,7 +840,7 @@ def summary(check: Check, dataframe: DataFrame) -> DataFrame:
check.level.name,
str(rule.column),
str(rule.method),
- str(rule.value),
+ _value(rule.value),
int(check.rows),
int(rule.violations),
float(rule.pass_rate),
7 changes: 4 additions & 3 deletions cuallee/report/__init__.py
@@ -1,11 +1,12 @@
from typing import List, Tuple
from fpdf import FPDF
- #from datetime import datetime, timezone
+ # from datetime import datetime, timezone


def pdf(data: List[Tuple[str]], name: str = "cuallee.pdf"):
- #today = datetime.now(timezone.utc).strftime("%Y%m%d%H%M%S")
- #style = FontFace(fill_color="#AAAAAA")
+ # today = datetime.now(timezone.utc).strftime("%Y%m%d%H%M%S")
+ # style = FontFace(fill_color="#AAAAAA")
pdf = FPDF(orientation="landscape", format="A4")
pdf.add_page()
pdf.set_font("Helvetica", size=6)
2 changes: 1 addition & 1 deletion pyproject.toml
@@ -4,7 +4,7 @@ build-backend = "setuptools.build_meta"

[project]
name = "cuallee"
version = "0.11.0"
version = "0.11.1"
authors = [
{ name="Herminio Vazquez", email="[email protected]"},
{ name="Virginie Grosboillot", email="[email protected]" }
2 changes: 1 addition & 1 deletion setup.cfg
@@ -1,5 +1,5 @@
[metadata]
name = cuallee
- version = 0.11.0
+ version = 0.11.1
[options]
packages = find:
2 changes: 1 addition & 1 deletion test/performance/cuallee/test_performance_cuallee.py
@@ -16,7 +16,7 @@


start = datetime.now()
- check.validate(df).show(n=int(len(df.columns)*2), truncate=False)
+ check.validate(df).show(n=int(len(df.columns) * 2), truncate=False)
end = datetime.now()
elapsed = end - start
print("START:", start)
8 changes: 6 additions & 2 deletions test/performance/greatexpectations/test_performance_gx.py
@@ -12,8 +12,12 @@

start = datetime.now()

- check_unique = [check.expect_column_values_to_be_unique(name).success for name in df.columns]
- check_complete = [check.expect_column_values_to_not_be_null(name).success for name in df.columns]
+ check_unique = [
+     check.expect_column_values_to_be_unique(name).success for name in df.columns
+ ]
+ check_complete = [
+     check.expect_column_values_to_not_be_null(name).success for name in df.columns
+ ]

end = datetime.now()
print(check_unique + check_complete)
9 changes: 6 additions & 3 deletions test/performance/soda/test_performance_soda.py
@@ -95,8 +95,11 @@
api_key_id: $soda_key
api_key_secret: $soda_secret
"""
- scan.add_configuration_yaml_str(Template(config).substitute(soda_key=os.environ.get("SODA_KEY"), soda_secret=os.environ.get("SODA_SECRET")))
+ scan.add_configuration_yaml_str(
+     Template(config).substitute(
+         soda_key=os.environ.get("SODA_KEY"), soda_secret=os.environ.get("SODA_SECRET")
+     )
+ )


start = datetime.now()
@@ -114,4 +117,4 @@
print("END:", end)
print("ELAPSED:", elapsed)
print("FRAMEWORK: soda")
- spark.stop()
+ spark.stop()
4 changes: 1 addition & 3 deletions test/unit/bigquery/test_is_complete.py
@@ -12,7 +12,6 @@ def test_positive():
rs = check.validate(df)
assert rs.status.str.match("PASS")[1]
assert rs.violations[1] == 0



def test_negative():
@@ -23,7 +22,6 @@ def test_negative():
assert rs.status.str.match("FAIL")[1]
assert rs.violations[1] >= 1589
assert rs.pass_threshold[1] == 1.0



# def test_parameters():
@@ -37,5 +35,5 @@ def test_coverage():
rs = check.validate(df)
assert rs.status.str.match("PASS")[1]
assert rs.violations[1] >= 1589
- #assert rs.pass_threshold[1] == 0.7
+ # assert rs.pass_threshold[1] == 0.7
# assert rs.pass_rate[1] == 0.9999117752439066 # 207158222/207176656
12 changes: 6 additions & 6 deletions test/unit/bigquery/test_is_daily.py
@@ -11,7 +11,7 @@ def test_positive():
check = Check(CheckLevel.WARNING, "pytest")
check.is_daily("trip_start_timestamp")
rs = check.validate(df)
- #assert rs.violations[1] > 1
+ # assert rs.violations[1] > 1


def test_negative():
@@ -20,7 +20,7 @@
check.is_daily("trip_end_timestamp")
rs = check.validate(df)
assert rs.status.str.match("FAIL")[1]
- #assert rs.violations[1] >= 1
+ # assert rs.violations[1] >= 1
# assert rs.pass_rate[1] <= 208914146 / 208943621


@@ -34,15 +34,15 @@ def test_parameters(rule_value):
check = Check(CheckLevel.WARNING, "pytest")
check.is_daily("trip_start_timestamp", rule_value)
rs = check.validate(df)
- #assert rs.status.str.match("FAIL")[1]
- #assert rs.violations[1] > 0
+ # assert rs.status.str.match("FAIL")[1]
+ # assert rs.violations[1] > 0


def test_coverage():
df = bigquery.dataset.Table("bigquery-public-data.chicago_taxi_trips.taxi_trips")
check = Check(CheckLevel.WARNING, "pytest")
check.is_daily("trip_end_timestamp", pct=0.7)
rs = check.validate(df)
- #assert rs.status.str.match("PASS")[1]
- #assert rs.pass_threshold[1] == 0.7
+ # assert rs.status.str.match("PASS")[1]
+ # assert rs.pass_threshold[1] == 0.7
# assert rs.pass_rate[1] <= 208914146 / 208943621
12 changes: 6 additions & 6 deletions test/unit/bigquery/test_is_unique.py
@@ -12,17 +12,17 @@ def test_positive():
rs = check.validate(df)
# assert rs.status.str.match("PASS")[1]
# assert rs.violations[1] == 0
- #assert rs.pass_rate[1] == 1.0
+ # assert rs.pass_rate[1] == 1.0


def test_negative():
df = bigquery.dataset.Table("bigquery-public-data.chicago_taxi_trips.taxi_trips")
check = Check(CheckLevel.WARNING, "pytest")
check.is_unique("taxi_id")
rs = check.validate(df)
- #assert rs.status.str.match("FAIL")[1]
- #assert rs.violations[1] >= 102580503
- #assert rs.pass_threshold[1] == 1.0
+ # assert rs.status.str.match("FAIL")[1]
+ # assert rs.violations[1] >= 102580503
+ # assert rs.pass_threshold[1] == 1.0
# assert rs.pass_rate[1] == 9738 / 208943621


@@ -35,7 +35,7 @@ def test_coverage():
check = Check(CheckLevel.WARNING, "pytest")
check.is_unique("taxi_id", 0.000007)
rs = check.validate(df)
- #assert rs.status.str.match("PASS")[1]
- #assert rs.violations[1] >= 102580503
+ # assert rs.status.str.match("PASS")[1]
+ # assert rs.violations[1] >= 102580503
# assert rs.pass_threshold[1] == 0.000007
# assert rs.pass_rate[1] == 9738 / 208943621
12 changes: 6 additions & 6 deletions test/unit/bigquery/test_not_contained_in.py
@@ -11,9 +11,9 @@ def test_positive():
check = Check(CheckLevel.WARNING, "pytest")
check.not_contained_in("payment_type", ["Dinero"])
rs = check.validate(df)
- #assert rs.status.str.match("PASS")[1]
- #assert rs.violations[1] == 0
- #assert rs.pass_rate[1] == 1.0
+ # assert rs.status.str.match("PASS")[1]
+ # assert rs.violations[1] == 0
+ # assert rs.pass_rate[1] == 1.0


def test_negative():
@@ -114,15 +114,15 @@ def test_parameters(column_name, rule_value):
check = Check(CheckLevel.WARNING, "pytest")
check.not_contained_in(column_name, rule_value)
rs = check.validate(df)
- #assert rs.status.str.match("FAIL")[1]
- #assert rs.pass_rate[1] <= 1.0
+ # assert rs.status.str.match("FAIL")[1]
+ # assert rs.pass_rate[1] <= 1.0


def test_coverage():
df = bigquery.dataset.Table("bigquery-public-data.chicago_taxi_trips.taxi_trips")
check = Check(CheckLevel.WARNING, "pytest")
check.not_contained_in("payment_type", ("Dinero", "Metalico"), 0.7)
rs = check.validate(df)
- #assert rs.status.str.match("PASS")[1]
+ # assert rs.status.str.match("PASS")[1]
# assert rs.violations[1] == 0
# assert rs.pass_threshold[1] == 0.7
5 changes: 4 additions & 1 deletion test/unit/daft/test_are_complete.py
Original file line number Diff line number Diff line change
@@ -35,5 +35,8 @@ def test_coverage(check: Check):

col_pass_rate = daft.col("pass_rate")
assert (
- result.agg(col_pass_rate.max()).select(col_pass_rate == 0.75).to_pandas().pass_rate.all()
+ result.agg(col_pass_rate.max())
+ .select(col_pass_rate == 0.75)
+ .to_pandas()
+ .pass_rate.all()
)
5 changes: 4 additions & 1 deletion test/unit/daft/test_are_unique.py
@@ -36,5 +36,8 @@ def test_coverage(check: Check):

col_pass_rate = daft.col("pass_rate")
assert (
- result.agg(col_pass_rate.max()).select(col_pass_rate == 0.75).to_pandas().pass_rate.all()
+ result.agg(col_pass_rate.max())
+ .select(col_pass_rate == 0.75)
+ .to_pandas()
+ .pass_rate.all()
)
5 changes: 4 additions & 1 deletion test/unit/daft/test_has_pattern.py
@@ -49,5 +49,8 @@ def test_coverage(check: Check):
assert result.select(daft.col("status").str.match("PASS")).to_pandas().status.all()
col_pass_rate = daft.col("pass_rate")
assert (
- result.agg(col_pass_rate.max()).select(col_pass_rate == 0.75).to_pandas().pass_rate.all()
+ result.agg(col_pass_rate.max())
+ .select(col_pass_rate == 0.75)
+ .to_pandas()
+ .pass_rate.all()
)
5 changes: 4 additions & 1 deletion test/unit/daft/test_has_workflow.py
@@ -33,5 +33,8 @@ def test_coverage(check: Check):

col_pass_rate = daft.col("pass_rate")
assert (
- result.agg(col_pass_rate.max()).select(col_pass_rate == 4/6).to_pandas().pass_rate.all()
+ result.agg(col_pass_rate.max())
+ .select(col_pass_rate == 4 / 6)
+ .to_pandas()
+ .pass_rate.all()
)
5 changes: 4 additions & 1 deletion test/unit/daft/test_is_between.py
@@ -48,5 +48,8 @@ def test_coverage(check: Check):

col_pass_rate = daft.col("pass_rate")
assert (
- result.agg(col_pass_rate.max()).select(col_pass_rate == 0.55).to_pandas().pass_rate.all()
+ result.agg(col_pass_rate.max())
+ .select(col_pass_rate == 0.55)
+ .to_pandas()
+ .pass_rate.all()
)
5 changes: 4 additions & 1 deletion test/unit/daft/test_is_complete.py
@@ -23,5 +23,8 @@ def test_coverage(check: Check):
assert result.select(daft.col("status").str.match("PASS")).to_pandas().status.all()
col_pass_rate = daft.col("pass_rate")
assert (
- result.agg(col_pass_rate.max()).select(col_pass_rate == 0.50).to_pandas().pass_rate.all()
+ result.agg(col_pass_rate.max())
+ .select(col_pass_rate == 0.50)
+ .to_pandas()
+ .pass_rate.all()
)
5 changes: 4 additions & 1 deletion test/unit/daft/test_is_contained_in.py
@@ -35,5 +35,8 @@ def test_coverage(check: Check):
assert result.select(daft.col("status").str.match("PASS")).to_pandas().status.all()
col_pass_rate = daft.col("pass_rate")
assert (
- result.agg(col_pass_rate.max()).select(col_pass_rate == 0.50).to_pandas().pass_rate.all()
+ result.agg(col_pass_rate.max())
+ .select(col_pass_rate == 0.50)
+ .to_pandas()
+ .pass_rate.all()
)