Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

[SPARK-50133][PYTHON] Support df.argument() for conversion to table argument in Spark Classic #48914

Draft
wants to merge 6 commits into
base: master
Choose a base branch
from
Draft
Show file tree
Hide file tree
Changes from 3 commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Original file line number Diff line number Diff line change
Expand Up @@ -587,6 +587,11 @@ class Dataset[T] private[sql] (

// TODO(SPARK-50134): Support scalar Subquery API in Spark Connect
// scalastyle:off not.implemented.error.usage
/** @inheritdoc */
// TODO(SPARK-50133): table-argument conversion is not yet supported in Spark
// Connect; `???` throws scala.NotImplementedError if this is ever called.
def argument(): Column = {
  ???
}

/** @inheritdoc */
def scalar(): Column = {
???
Expand Down
3 changes: 3 additions & 0 deletions python/pyspark/sql/classic/dataframe.py
Original file line number Diff line number Diff line change
Expand Up @@ -1786,6 +1786,9 @@ def transpose(self, indexColumn: Optional["ColumnOrName"] = None) -> ParentDataF
else:
return DataFrame(self._jdf.transpose(), self.sparkSession)

def argument(self) -> Column:
    """Expose this DataFrame as a table-argument ``Column``.

    Delegates to the JVM-side ``Dataset.argument()`` and wraps the
    resulting Java column object.
    """
    jcol = self._jdf.argument()
    return Column(jcol)

def scalar(self) -> Column:
    """Return a scalar-subquery ``Column`` for this DataFrame.

    Delegates to the JVM-side ``Dataset.scalar()`` and wraps the
    resulting Java column object.
    """
    jcol = self._jdf.scalar()
    return Column(jcol)

Expand Down
6 changes: 6 additions & 0 deletions python/pyspark/sql/connect/dataframe.py
Original file line number Diff line number Diff line change
Expand Up @@ -1784,6 +1784,12 @@ def transpose(self, indexColumn: Optional["ColumnOrName"] = None) -> ParentDataF
self._session,
)

def argument(self) -> Column:
    """Unsupported under Spark Connect; always raises.

    Raises
    ------
    PySparkNotImplementedError
        Always, since table-argument conversion has no Connect implementation.
    """
    params = {"feature": "argument()"}
    raise PySparkNotImplementedError(
        errorClass="NOT_IMPLEMENTED", messageParameters=params
    )

def scalar(self) -> Column:
# TODO(SPARK-50134): Implement this method
raise PySparkNotImplementedError(
Expand Down
33 changes: 33 additions & 0 deletions python/pyspark/sql/dataframe.py
Original file line number Diff line number Diff line change
Expand Up @@ -6476,6 +6476,39 @@ def transpose(self, indexColumn: Optional["ColumnOrName"] = None) -> "DataFrame"
"""
...

def argument(self) -> Column:
    """
    Converts the DataFrame into a `Column` object for use with table-valued functions (TVFs)
    or user-defined table functions (UDTFs).

    .. versionadded:: 4.0.0

    Returns
    -------
    :class:`Column`
        A `Column` object representing the DataFrame.

    Notes
    -----
    This is currently only implemented in Spark Classic; under Spark Connect
    calling it raises ``PySparkNotImplementedError``.

    Examples
    --------
    >>> from pyspark.sql import Row
    >>> from pyspark.sql.functions import udtf
    >>>
    >>> @udtf(returnType="a: int")
    ... class TestUDTF:
    ...     def eval(self, row: Row):
    ...         if row[0] > 5:
    ...             yield row[0],
    >>> df = spark.range(8)
    >>> TestUDTF(df.argument()).show()  # doctest: +SKIP
    +---+
    |  a|
    +---+
    |  6|
    |  7|
    +---+
    """
    ...

def scalar(self) -> Column:
"""
Return a `Column` object for a SCALAR Subquery containing exactly one row and one column.
Expand Down
4 changes: 4 additions & 0 deletions python/pyspark/sql/tests/connect/test_parity_udtf.py
Original file line number Diff line number Diff line change
Expand Up @@ -76,6 +76,10 @@ def test_udtf_with_analyze_using_file(self):
def test_udtf_access_spark_session(self):
super().test_udtf_access_spark_session()

@unittest.skip("Spark Connect does not support df.argument()")
def test_df_argument(self):
    # Parity override: df.argument() is only implemented in Spark Classic,
    # so the inherited test is skipped when running under Spark Connect.
    super().test_df_argument()

def _add_pyfile(self, path):
self.spark.addArtifacts(path, pyfile=True)

Expand Down
13 changes: 13 additions & 0 deletions python/pyspark/sql/tests/test_udtf.py
Original file line number Diff line number Diff line change
Expand Up @@ -1031,6 +1031,19 @@ def eval(self, row: Row):
[Row(a=6), Row(a=7)],
)

def test_df_argument(self):
    """df.argument() should feed the DataFrame to a UDTF as a table argument."""

    # UDTF that emits only the ids greater than 5 from each input row.
    class TestUDTF:
        def eval(self, row: Row):
            if row["id"] > 5:
                yield (row["id"],)

    filtering_udtf = udtf(TestUDTF, returnType="a: int")
    source_df = self.spark.range(8)
    expected = [Row(a=6), Row(a=7)]
    self.assertEqual(filtering_udtf(source_df.argument()).collect(), expected)
ueshin marked this conversation as resolved.
Show resolved Hide resolved

def test_udtf_with_int_and_table_argument_query(self):
class TestUDTF:
def eval(self, i: int, row: Row):
Expand Down
9 changes: 9 additions & 0 deletions sql/api/src/main/scala/org/apache/spark/sql/api/Dataset.scala
Original file line number Diff line number Diff line change
Expand Up @@ -1699,6 +1699,15 @@ abstract class Dataset[T] extends Serializable {
*/
def transpose(): Dataset[Row]

/**
 * Converts the DataFrame into a `Column` object for use with table-valued functions (TVFs) or
 * user-defined table functions (UDTFs).
 *
 * @note
 *   Not yet implemented in Spark Connect — the Connect implementation currently throws
 *   `NotImplementedError`.
 * @group typedrel
 * @since 4.0.0
 */
def argument(): Column

/**
* Return a `Column` object for a SCALAR Subquery containing exactly one row and one column.
*
Expand Down
6 changes: 6 additions & 0 deletions sql/core/src/main/scala/org/apache/spark/sql/Dataset.scala
Original file line number Diff line number Diff line change
Expand Up @@ -1003,6 +1003,12 @@ class Dataset[T] private[sql](
)
}

/** @inheritdoc */
def argument(): Column = {
  // Wrap this Dataset's logical plan as a table-argument expression so it can
  // be passed to table-valued functions / UDTFs as a column.
  Column(FunctionTableSubqueryArgumentExpression(logicalPlan))
}

/** @inheritdoc */
def scalar(): Column = {
Column(ExpressionColumnNode(
Expand Down