Skip to content

Commit

Permalink
Merge pull request #145 from jeffbrennan/hotfix_transformations
Browse files Browse the repository at this point in the history
fix incorrect import order
  • Loading branch information
jeffbrennan authored Nov 5, 2023
2 parents ad271eb + a586404 commit 79262db
Showing 1 changed file with 8 additions and 4 deletions.
12 changes: 8 additions & 4 deletions quinn/transformations.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,6 @@
+from __future__ import annotations
 import re
 import pyspark.sql.functions as F # noqa: N812
-from __future__ import annotations
 from collections.abc import Callable
 from pyspark.sql import DataFrame
 from pyspark.sql.types import ArrayType, MapType, StructField, StructType
Expand Down Expand Up @@ -100,7 +100,7 @@ def sort_columns(
:return: A DataFrame with the columns sorted in the chosen order
:rtype: pyspark.sql.DataFrame
"""

def sort_nested_cols(schema, is_reversed, base_field="") -> list[str]:
# recursively check nested fields and sort them
# https://stackoverflow.com/questions/57821538/how-to-sort-columns-of-nested-structs-alphabetically-in-pyspark
Expand Down Expand Up @@ -282,6 +282,7 @@ def flatten_map(df: DataFrame, col_name: str, separator: str = ":") -> DataFrame
[F.col(f"`{col}`") for col in df.columns if col != col_name] + key_cols,
)


def flatten_dataframe(
df: DataFrame,
separator: str = ":",
Expand Down Expand Up @@ -331,6 +332,7 @@ def flatten_dataframe(
>>> flattened_df_with_hyphen = flatten_dataframe(df, replace_char="-")
>>> flattened_df_with_hyphen.show()
"""

def sanitize_column_name(name: str, rc: str = "_") -> str:
"""Sanitizes column names by replacing special characters with the specified character.
Expand All @@ -353,7 +355,9 @@ def explode_array(df: DataFrame, col_name: str) -> DataFrame:
 :return: The DataFrame with the exploded ArrayType column.
 :rtype: DataFrame
 """
-return df.select("*", F.explode_outer(F.col(f"`{col_name}`")).alias(col_name)).drop(
+return df.select(
+"*", F.explode_outer(F.col(f"`{col_name}`")).alias(col_name)
+).drop(
 col_name,
 )

Expand All @@ -380,4 +384,4 @@ def explode_array(df: DataFrame, col_name: str) -> DataFrame:
]
df = df.toDF(*sanitized_columns) # noqa: PD901

-return df
+return df

0 comments on commit 79262db

Please sign in to comment.