
Commit

added documentation
dannymeijer committed Nov 8, 2024
1 parent b37a302 commit 6d6ccbd
Showing 2 changed files with 11 additions and 5 deletions.
6 changes: 3 additions & 3 deletions src/koheesio/integrations/box.py
@@ -362,12 +362,12 @@ class BoxReaderBase(Box, Reader, ABC):
        default_factory=dict,
        description="[Optional] Set of extra parameters that should be passed to the Spark reader.",
    )

    file_encoding: Optional[str] = Field(
-        default="utf-8",
-        description="[Optional] Set file encoding format. By default is utf-8."
+        default="utf-8", description="[Optional] Set file encoding format. By default is utf-8."
    )


class BoxCsvFileReader(BoxReaderBase):
    """
    Class facilitates reading one or multiple CSV files with the same structure directly from Box and
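For context, the reformatted field follows the usual Pydantic pattern of optional fields with a default and a description. The snippet below is a minimal standalone sketch of that pattern, not koheesio's actual class; the `ReaderOptions` and `extra_params` names are made up for illustration.

```python
from typing import Optional

from pydantic import BaseModel, Field


class ReaderOptions(BaseModel):
    """Hypothetical model mirroring the two fields shown in the BoxReaderBase diff."""

    extra_params: dict = Field(
        default_factory=dict,
        description="[Optional] Set of extra parameters that should be passed to the Spark reader.",
    )
    file_encoding: Optional[str] = Field(
        default="utf-8", description="[Optional] Set file encoding format. By default is utf-8."
    )


opts = ReaderOptions()
print(opts.file_encoding)  # -> "utf-8"
```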
10 changes: 8 additions & 2 deletions src/koheesio/spark/utils/common.py
@@ -78,6 +78,7 @@ def get_spark_minor_version() -> float:


def check_if_pyspark_connect_is_supported() -> bool:
"""Check if the current version of PySpark supports the connect module"""
result = False
module_name: str = "pyspark"
if SPARK_MINOR_VERSION >= 3.5:
Expand All @@ -93,6 +94,7 @@ def check_if_pyspark_connect_is_supported() -> bool:


if check_if_pyspark_connect_is_supported():
"""Only import the connect module if the current version of PySpark supports it"""
from pyspark.errors.exceptions.captured import (
ParseException as CapturedParseException,
)
Expand Down Expand Up @@ -122,6 +124,7 @@ def check_if_pyspark_connect_is_supported() -> bool:
DataStreamWriter = Union[sql.streaming.readwriter.DataStreamWriter, DataStreamWriter]
StreamingQuery = StreamingQuery
else:
"""Import the regular PySpark modules if the current version of PySpark does not support the connect module"""
try:
from pyspark.errors.exceptions.captured import ParseException # type: ignore
except (ImportError, ModuleNotFoundError):
Expand Down Expand Up @@ -152,6 +155,7 @@ def check_if_pyspark_connect_is_supported() -> bool:


def get_active_session() -> SparkSession: # type: ignore
"""Get the active Spark session"""
if check_if_pyspark_connect_is_supported():
from pyspark.sql.connect.session import SparkSession as _ConnectSparkSession

@@ -321,7 +325,6 @@ def import_pandas_based_on_pyspark_version() -> ModuleType:
        raise ImportError("Pandas module is not installed.") from e


-# noinspection PyProtectedMember
def show_string(df: DataFrame, n: int = 20, truncate: Union[bool, int] = True, vertical: bool = False) -> str: # type: ignore
"""Returns a string representation of the DataFrame
The default implementation of DataFrame.show() hardcodes a print statement, which is not always desirable.
Expand All @@ -348,12 +351,13 @@ def show_string(df: DataFrame, n: int = 20, truncate: Union[bool, int] = True, v
If set to True, display the DataFrame vertically, by default False
"""
if SPARK_MINOR_VERSION < 3.5:
# noinspection PyProtectedMember
return df._jdf.showString(n, truncate, vertical) # type: ignore
# as per spark 3.5, the _show_string method is now available making calls to _jdf.showString obsolete
# noinspection PyProtectedMember
return df._show_string(n, truncate, vertical)


-# noinspection PyProtectedMember
def get_column_name(col: Column) -> str: # type: ignore
"""Get the column name from a Column object
Expand All @@ -373,8 +377,10 @@ def get_column_name(col: Column) -> str: # type: ignore
# we have to distinguish between the Column object from column from local session and remote
if hasattr(col, "_jc"):
# In case of a 'regular' Column object, we can directly access the name attribute through the _jc attribute
# noinspection PyProtectedMember
name = col._jc.toString() # type: ignore[operator]
elif any(cls.__module__ == "pyspark.sql.connect.column" for cls in inspect.getmro(col.__class__)):
# noinspection PyProtectedMember
name = col._expr.name()
else:
raise ValueError("Column object is not a valid Column object")
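Taken together, the documented helpers are small utilities around PySpark's classic/connect split. The sketch below shows how they might be exercised, assuming they are importable from `koheesio.spark.utils.common` (the module path in this diff) and that a local Spark installation is available; it is an illustrative usage sketch, not part of the commit.

```python
from pyspark.sql import SparkSession, functions as F

from koheesio.spark.utils.common import (
    check_if_pyspark_connect_is_supported,
    get_active_session,
    get_column_name,
    show_string,
)

# Start a plain local session for the demo; get_active_session() should then find it.
SparkSession.builder.master("local[1]").appName("docs-demo").getOrCreate()
spark = get_active_session()

df = spark.createDataFrame([(1, "a"), (2, "b")], ["id", "value"])

# show_string returns the table as a string instead of printing it, per the docstring above
print(show_string(df, n=5, truncate=False))

# get_column_name resolves the name for both classic (_jc) and connect (_expr) Column objects
print(get_column_name(F.col("value")))  # -> "value"

# check_if_pyspark_connect_is_supported gates the connect-only imports seen in this diff
print(check_if_pyspark_connect_is_supported())
```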
