Add better tests for split columns (#138)
MrPowers authored Oct 12, 2023
1 parent 92932e7 commit fe33bbe
Showing 2 changed files with 44 additions and 26 deletions.
quinn/split_columns.py (9 changes: 5 additions & 4 deletions)
@@ -1,6 +1,7 @@
from __future__ import annotations

from typing import TYPE_CHECKING
from typing import Optional

from pyspark.sql.functions import length, split, trim, udf, when
from pyspark.sql.types import IntegerType
@@ -14,8 +15,8 @@ def split_col( # noqa: PLR0913
col_name: str,
delimiter: str,
new_col_names: list[str],
mode: str = "strict",
default: str = "default",
mode: str = "permissive",
default: Optional[str] = None,

Check failure on line 19 in quinn/split_columns.py (GitHub Actions / ruff): UP007 Use `X | Y` for type annotations (quinn/split_columns.py:19:14)
) -> DataFrame:
"""Splits the given column based on the delimiter and creates new columns with the split values.
@@ -27,7 +28,7 @@ def split_col( # noqa: PLR0913
:type delimiter: str
:param new_col_names: A list of two strings for the new column names
:type new_col_names: (List[str])
:param mode: The split mode. Can be "strict" or "permissive". Default is "strict"
:param mode: The split mode. Can be "strict" or "permissive". Default is "permissive"
:type mode: str
:param default: If the mode is "permissive" then default value will be assigned to column
:type default: str
@@ -81,7 +82,7 @@ def _num_delimiter(col_value1: str) -> int:
# If col_value is None, return 0
return 0

num_udf = udf(lambda y: _num_delimiter(y), IntegerType())
num_udf = udf(lambda y: None if y is None else _num_delimiter(y), IntegerType())

# Get the column expression for the column to be split
col_expr = df[col_name]
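
The num_udf change above is the core behavioral fix in this file: the lambda now short-circuits on NULL input instead of handing None to _num_delimiter. A minimal standalone sketch of that pattern follows; the helper body here is illustrative (the diff only shows its None branch), and the "XX" delimiter is borrowed from the tests below rather than from the real implementation.

from pyspark.sql.functions import udf
from pyspark.sql.types import IntegerType

def _num_delimiter(col_value1: str) -> int:
    # Illustrative body: count how many times the delimiter occurs.
    # The real quinn helper is only partially visible in this diff.
    return col_value1.count("XX")

# NULL row values are passed through as None, so _num_delimiter never receives
# a null input and the resulting IntegerType column stays nullable.
num_udf = udf(lambda y: None if y is None else _num_delimiter(y), IntegerType())
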
tests/test_split_columns.py (61 changes: 39 additions & 22 deletions)
@@ -1,37 +1,54 @@
import quinn
from tests.conftest import auto_inject_fixtures
import chispa
import pytest


@auto_inject_fixtures("spark")
def test_split_columns(spark):
# Create Spark DataFrame
data = [("chrisXXmoe", 2025, "bio"),
("davidXXbb", 2026, "physics"),
("sophiaXXraul", 2022, "bio"),
("fredXXli", 2025, "physics"),
("someXXperson", 2023, "math"),
("liXXyao", 2025, "physics")]

(None, 2025, "physics")]
df = spark.createDataFrame(data, ["student_name", "graduation_year", "major"])
# Define the delimiter
delimiter = "XX"

# New column names
new_col_names = ["student_first_name", "student_last_name"]

col_name = "student_name"
mode = "strict"
# Call split_col() function to split "student_name" column
new_df = quinn.split_col(df, col_name, delimiter, new_col_names, mode)

new_df = quinn.split_col(
df,
col_name="student_name",
delimiter="XX",
new_col_names=["student_first_name", "student_last_name"],
mode="permissive")
data = [(2025, "bio", "chris", "moe"),
(2026, "physics", "david", "bb"),
(2022, "bio", "sophia", "raul"),
(2025, "physics", "fred", "li"),
(2023, "math", "some", "person"),
(2025, "physics", "li", "yao")]

(2025, "physics", None, None)]
expected = spark.createDataFrame(data, ["graduation_year", "major", "student_first_name", "student_last_name"])
chispa.assert_df_equality(new_df, expected)

def test_split_columns_advanced(spark):
data = [("chrisXXsomethingXXmoe", 2025, "bio"),
("davidXXbb", 2026, "physics"),
(None, 2025, "physics")]
df = spark.createDataFrame(data, ["student_name", "graduation_year", "major"])
new_df = quinn.split_col(
df,
col_name="student_name",
delimiter="XX",
new_col_names=["student_first_name", "student_middle_name", "student_last_name"],
mode="permissive")
data = [(2025, "bio", "chris", "something", "moe"),
(2026, "physics", "david", "bb", None),
(2025, "physics", None, None, None)]
expected = spark.createDataFrame(data, ["graduation_year", "major", "student_first_name", "student_middle_name", "student_last_name"])
chispa.assert_df_equality(new_df, expected)

def test_split_columns_strict(spark):
data = [("chrisXXsomethingXXmoe", 2025, "bio"),
("davidXXbb", 2026, "physics"),
(None, 2025, "physics")]
df = spark.createDataFrame(data, ["student_name", "graduation_year", "major"])
df2 = quinn.split_col(
df,
col_name="student_name",
delimiter="XX",
new_col_names=["student_first_name", "student_middle_name", "student_last_name"],
mode="strict", default="hi")
with pytest.raises(IndexError):
df2.show()
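
The new tests exercise permissive mode with default left as None and strict mode with a mismatched column count. A hypothetical companion test is sketched below; it is not part of this commit, reuses the same spark fixture and imports as the file above, and the test name, the "n/a" value, and the assumption that default fills split parts that cannot be populated all come from the docstring's description rather than from verified behavior.

def test_split_columns_permissive_default(spark):
    # Hypothetical: assumes permissive mode substitutes `default` for split
    # parts that are missing, per the docstring in quinn/split_columns.py.
    data = [("davidXXbb", 2026, "physics"),
            (None, 2025, "physics")]
    df = spark.createDataFrame(data, ["student_name", "graduation_year", "major"])
    new_df = quinn.split_col(
        df,
        col_name="student_name",
        delimiter="XX",
        new_col_names=["student_first_name", "student_last_name"],
        mode="permissive",
        default="n/a")
    # Inspect rather than assert, since the exact fill behavior for the None
    # row is not pinned down by this diff.
    new_df.show()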
