Skip to content

Commit

Permalink
Merge pull request #265 from paulooctavio/feature/Update-validate-met…
Browse files Browse the repository at this point in the history
…hods-to-return-boolean-clean

Added new parameter `return_bool` to validate dataframe methods
  • Loading branch information
jeffbrennan authored Oct 3, 2024
2 parents e669317 + 187f98a commit a0849f3
Show file tree
Hide file tree
Showing 2 changed files with 98 additions and 28 deletions.
52 changes: 37 additions & 15 deletions quinn/dataframe_validator.py
Original file line number Diff line number Diff line change
Expand Up @@ -14,7 +14,7 @@
from __future__ import annotations

import copy
from typing import TYPE_CHECKING
from typing import TYPE_CHECKING, Union

if TYPE_CHECKING:
from pyspark.sql import DataFrame
Expand All @@ -33,29 +33,37 @@ class DataFrameProhibitedColumnError(ValueError):
"""Raise this when a DataFrame includes prohibited columns."""


def validate_presence_of_columns(df: DataFrame, required_col_names: list[str]) -> None:
def validate_presence_of_columns(df: DataFrame, required_col_names: list[str], return_bool: bool = False) -> Union[None, bool]:
"""Validate the presence of column names in a DataFrame.
:param df: A spark DataFrame.
:type df: DataFrame`
:type df: DataFrame
:param required_col_names: List of the required column names for the DataFrame.
:type required_col_names: :py:class:`list` of :py:class:`str`
:return: None.
:type required_col_names: list[str]
:param return_bool: If True, return a boolean instead of raising an exception.
:type return_bool: bool
:return: None if return_bool is False, otherwise a boolean indicating if validation passed.
:raises DataFrameMissingColumnError: if any of the requested column names are
not present in the DataFrame.
not present in the DataFrame and return_bool is False.
"""
all_col_names = df.columns
missing_col_names = [x for x in required_col_names if x not in all_col_names]
error_message = f"The {missing_col_names} columns are not included in the DataFrame with the following columns {all_col_names}"

Check failure on line 51 in quinn/dataframe_validator.py

View workflow job for this annotation

GitHub Actions / ruff

Ruff (W293)

quinn/dataframe_validator.py:51:1: W293 Blank line contains whitespace
if missing_col_names:
error_message = f"The {missing_col_names} columns are not included in the DataFrame with the following columns {all_col_names}"
if return_bool:
return False
raise DataFrameMissingColumnError(error_message)

Check failure on line 57 in quinn/dataframe_validator.py

View workflow job for this annotation

GitHub Actions / ruff

Ruff (W293)

quinn/dataframe_validator.py:57:1: W293 Blank line contains whitespace
return True if return_bool else None


def validate_schema(
df: DataFrame,
required_schema: StructType,
ignore_nullable: bool = False,
) -> None:
return_bool: bool = False,
) -> Union[None, bool]:
"""Function that validate if a given DataFrame has a given StructType as its schema.
:param df: DataFrame to validate
Expand All @@ -65,9 +73,11 @@ def validate_schema(
:param ignore_nullable: (Optional) A flag for if nullable fields should be
ignored during validation
:type ignore_nullable: bool, optional
:param return_bool: If True, return a boolean instead of raising an exception.
:type return_bool: bool
:return: None if return_bool is False, otherwise a boolean indicating if validation passed.
:raises DataFrameMissingStructFieldError: if any StructFields from the required
schema are not included in the DataFrame schema
schema are not included in the DataFrame schema and return_bool is False.
"""
_all_struct_fields = copy.deepcopy(df.schema)
_required_schema = copy.deepcopy(required_schema)
Expand All @@ -80,22 +90,34 @@ def validate_schema(
x.nullable = None

missing_struct_fields = [x for x in _required_schema if x not in _all_struct_fields]
error_message = f"The {missing_struct_fields} StructFields are not included in the DataFrame with the following StructFields {_all_struct_fields}"


Check failure on line 93 in quinn/dataframe_validator.py

View workflow job for this annotation

GitHub Actions / ruff

Ruff (W293)

quinn/dataframe_validator.py:93:1: W293 Blank line contains whitespace
if missing_struct_fields:
error_message = f"The {missing_struct_fields} StructFields are not included in the DataFrame with the following StructFields {_all_struct_fields}"

Check failure on line 95 in quinn/dataframe_validator.py

View workflow job for this annotation

GitHub Actions / ruff

Ruff (E501)

quinn/dataframe_validator.py:95:151: E501 Line too long (154 > 150 characters)
if return_bool:
return False
raise DataFrameMissingStructFieldError(error_message)

Check failure on line 99 in quinn/dataframe_validator.py

View workflow job for this annotation

GitHub Actions / ruff

Ruff (W293)

quinn/dataframe_validator.py:99:1: W293 Blank line contains whitespace
return True if return_bool else None


def validate_absence_of_columns(df: DataFrame, prohibited_col_names: list[str]) -> None:
def validate_absence_of_columns(df: DataFrame, prohibited_col_names: list[str], return_bool: bool = False) -> Union[None, bool]:
"""Validate that none of the prohibited column names are present among specified DataFrame columns.
:param df: DataFrame containing columns to be checked.
:param prohibited_col_names: List of prohibited column names.
:param return_bool: If True, return a boolean instead of raising an exception.
:type return_bool: bool
:return: None if return_bool is False, otherwise a boolean indicating if validation passed.
:raises DataFrameProhibitedColumnError: If the prohibited column names are
present among the specified DataFrame columns.
present among the specified DataFrame columns and return_bool is False.
"""
all_col_names = df.columns
extra_col_names = [x for x in all_col_names if x in prohibited_col_names]
error_message = f"The {extra_col_names} columns are not allowed to be included in the DataFrame with the following columns {all_col_names}"

Check failure on line 116 in quinn/dataframe_validator.py

View workflow job for this annotation

GitHub Actions / ruff

Ruff (W293)

quinn/dataframe_validator.py:116:1: W293 Blank line contains whitespace
if extra_col_names:
error_message = f"The {extra_col_names} columns are not allowed to be included in the DataFrame with the following columns {all_col_names}"
if return_bool:
return False
raise DataFrameProhibitedColumnError(error_message)

Check failure on line 122 in quinn/dataframe_validator.py

View workflow job for this annotation

GitHub Actions / ruff

Ruff (W293)

quinn/dataframe_validator.py:122:1: W293 Blank line contains whitespace
return True if return_bool else None

Check failure on line 123 in quinn/dataframe_validator.py

View workflow job for this annotation

GitHub Actions / ruff

Ruff (W292)

quinn/dataframe_validator.py:123:41: W292 No newline at end of file
74 changes: 61 additions & 13 deletions tests/test_dataframe_validator.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,24 +7,36 @@


def describe_validate_presence_of_columns():
def it_raises_if_a_required_column_is_missing():
def it_raises_if_a_required_column_is_missing_and_return_bool_is_false():
data = [("jose", 1), ("li", 2), ("luisa", 3)]
source_df = spark.createDataFrame(data, ["name", "age"])
with pytest.raises(quinn.DataFrameMissingColumnError) as excinfo:
quinn.validate_presence_of_columns(source_df, ["name", "age", "fun"])
quinn.validate_presence_of_columns(source_df, ["name", "age", "fun"], False)
assert (
excinfo.value.args[0]
== "The ['fun'] columns are not included in the DataFrame with the following columns ['name', 'age']"
)

def it_does_nothing_if_all_required_columns_are_present():
def it_does_nothing_if_all_required_columns_are_present_and_return_bool_is_false():
data = [("jose", 1), ("li", 2), ("luisa", 3)]
source_df = spark.createDataFrame(data, ["name", "age"])
quinn.validate_presence_of_columns(source_df, ["name"])
quinn.validate_presence_of_columns(source_df, ["name"], False)

def it_returns_false_if_a_required_column_is_missing_and_return_bool_is_true():
data = [("jose", 1), ("li", 2), ("luisa", 3)]
source_df = spark.createDataFrame(data, ["name", "age"])
result = quinn.validate_presence_of_columns(source_df, ["name", "age", "fun"], True)
assert result is False

def it_returns_true_if_all_required_columns_are_present_and_return_bool_is_true():
data = [("jose", 1), ("li", 2), ("luisa", 3)]
source_df = spark.createDataFrame(data, ["name", "age"])
result = quinn.validate_presence_of_columns(source_df, ["name"], True)
assert result is True


def describe_validate_schema():
def it_raises_when_struct_field_is_missing1():
def it_raises_when_struct_field_is_missing_and_return_bool_is_false():
data = [("jose", 1), ("li", 2), ("luisa", 3)]
source_df = spark.createDataFrame(data, ["name", "age"])
required_schema = StructType(
Expand All @@ -34,7 +46,7 @@ def it_raises_when_struct_field_is_missing1():
]
)
with pytest.raises(quinn.DataFrameMissingStructFieldError) as excinfo:
quinn.validate_schema(source_df, required_schema)
quinn.validate_schema(source_df, required_schema, return_bool = False)

current_spark_version = semver.Version.parse(spark.version)
spark_330 = semver.Version.parse("3.3.0")
Expand All @@ -44,7 +56,7 @@ def it_raises_when_struct_field_is_missing1():
expected_error_message = "The [StructField(city,StringType,true)] StructFields are not included in the DataFrame with the following StructFields StructType(List(StructField(name,StringType,true),StructField(age,LongType,true)))" # noqa
assert excinfo.value.args[0] == expected_error_message

def it_does_nothing_when_the_schema_matches():
def it_does_nothing_when_the_schema_matches_and_return_bool_is_false():
data = [("jose", 1), ("li", 2), ("luisa", 3)]
source_df = spark.createDataFrame(data, ["name", "age"])
required_schema = StructType(
Expand All @@ -53,7 +65,31 @@ def it_does_nothing_when_the_schema_matches():
StructField("age", LongType(), True),
]
)
quinn.validate_schema(source_df, required_schema)
quinn.validate_schema(source_df, required_schema, return_bool = False)

def it_returns_false_when_struct_field_is_missing_and_return_bool_is_true():
data = [("jose", 1), ("li", 2), ("luisa", 3)]
source_df = spark.createDataFrame(data, ["name", "age"])
required_schema = StructType(
[
StructField("name", StringType(), True),
StructField("city", StringType(), True),
]
)
result = quinn.validate_schema(source_df, required_schema, return_bool = True)
assert result is False

def it_returns_true_when_the_schema_matches_and_return_bool_is_true():
data = [("jose", 1), ("li", 2), ("luisa", 3)]
source_df = spark.createDataFrame(data, ["name", "age"])
required_schema = StructType(
[
StructField("name", StringType(), True),
StructField("age", LongType(), True),
]
)
result = quinn.validate_schema(source_df, required_schema, return_bool = True)
assert result is True

def nullable_column_mismatches_are_ignored():
data = [("jose", 1), ("li", 2), ("luisa", 3)]
Expand All @@ -64,21 +100,33 @@ def nullable_column_mismatches_are_ignored():
StructField("age", LongType(), False),
]
)
quinn.validate_schema(source_df, required_schema, ignore_nullable=True)
quinn.validate_schema(source_df, required_schema, ignore_nullable=True, return_bool = False)


def describe_validate_absence_of_columns():
def it_raises_when_a_unallowed_column_is_present():
def it_raises_when_a_unallowed_column_is_present_and_return_bool_is_false():
data = [("jose", 1), ("li", 2), ("luisa", 3)]
source_df = spark.createDataFrame(data, ["name", "age"])
with pytest.raises(quinn.DataFrameProhibitedColumnError) as excinfo:
quinn.validate_absence_of_columns(source_df, ["age", "cool"])
quinn.validate_absence_of_columns(source_df, ["age", "cool"], False)
assert (
excinfo.value.args[0]
== "The ['age'] columns are not allowed to be included in the DataFrame with the following columns ['name', 'age']" # noqa
)

def it_does_nothing_when_no_unallowed_columns_are_present():
def it_does_nothing_when_no_unallowed_columns_are_present_and_return_bool_is_false():
data = [("jose", 1), ("li", 2), ("luisa", 3)]
source_df = spark.createDataFrame(data, ["name", "age"])
quinn.validate_absence_of_columns(source_df, ["favorite_color"], False)

def it_returns_false_when_a_unallowed_column_is_present_and_return_bool_is_true():
data = [("jose", 1), ("li", 2), ("luisa", 3)]
source_df = spark.createDataFrame(data, ["name", "age"])
result = quinn.validate_absence_of_columns(source_df, ["age", "cool"], True)
assert result is False

def it_returns_true_when_no_unallowed_columns_are_present_and_return_bool_is_true():
data = [("jose", 1), ("li", 2), ("luisa", 3)]
source_df = spark.createDataFrame(data, ["name", "age"])
quinn.validate_absence_of_columns(source_df, ["favorite_color"])
result = quinn.validate_absence_of_columns(source_df, ["favorite_color"], True)
assert result is True

0 comments on commit a0849f3

Please sign in to comment.