[SPARK-49477][PYTHON] Improve pandas udf invalid return type error message

### What changes were proposed in this pull request?

This PR improves the error message raised when the specified return type of a pandas UDF does not match the actual return type.

### Why are the changes needed?

To improve the error message.

Before this PR:
`pyspark.errors.exceptions.base.PySparkValueError: A field of type StructType expects a pandas.DataFrame, but got: <class 'pandas.core.series.Series'>`

After this PR:
`pyspark.errors.exceptions.base.PySparkValueError: Invalid return type. Please make sure that the UDF returns a pandas.DataFrame when the specified return type is StructType.`
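
To make the mismatch concrete, here is a minimal sketch of a function that would trigger this error and one that would not. The names `upper_series`, `upper_frame`, and `demo` are illustrative and not part of this PR, and `demo` assumes an active SparkSession is passed in.

```python
import pandas as pd


def upper_series(s: pd.Series) -> pd.Series:
    # Returns a Series: fine for a plain "string" return type, not for a struct.
    return s.str.upper()


def upper_frame(s: pd.Series) -> pd.DataFrame:
    # Returns a DataFrame with one column per struct field (here just "s").
    return pd.DataFrame({"s": s.str.upper()})


def demo(spark):
    # Hypothetical helper: `spark` is assumed to be an active SparkSession.
    from pyspark.sql.functions import pandas_udf

    df = spark.createDataFrame([("a",)], schema="s string")

    # "s string" is parsed as a struct with a single field "s", so a Series
    # result is rejected; df.select(bad("s")).collect() would raise the
    # "Invalid return type" error.
    bad = pandas_udf(upper_series, returnType="s string")

    # Same declared return type, but the function returns a DataFrame.
    good = pandas_udf(upper_frame, returnType="s string")

    return df.select(good("s")).collect()
```

The two plain functions differ only in what they return. Whether Spark accepts a result depends on how that result pairs with the declared return type.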

### Does this PR introduce _any_ user-facing change?

No

### How was this patch tested?

New unit test

### Was this patch authored or co-authored using generative AI tooling?

No

Closes apache#47942 from allisonwang-db/spark-49477-pandas-udf-err-msg.

Authored-by: allisonwang-db <[email protected]>
Signed-off-by: Hyukjin Kwon <[email protected]>
allisonwang-db authored and HyukjinKwon committed Sep 4, 2024
1 parent 6165b12 commit 7d564b7
Showing 2 changed files with 15 additions and 2 deletions.
4 changes: 2 additions & 2 deletions python/pyspark/sql/pandas/serializers.py
@@ -510,8 +510,8 @@ def _create_batch(self, series):
                     # If it returns a pd.Series, it should throw an error.
                     if not isinstance(s, pd.DataFrame):
                         raise PySparkValueError(
-                            "A field of type StructType expects a pandas.DataFrame, "
-                            "but got: %s" % str(type(s))
+                            "Invalid return type. Please make sure that the UDF returns a "
+                            "pandas.DataFrame when the specified return type is StructType."
                         )
                     arrs.append(self._create_struct_array(s, t))
                 else:
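
The check in this hunk can be sketched in isolation as follows. This is a simplified stand-in, not the serializer's code: `require_dataframe` is a hypothetical name, and `ValueError` replaces `PySparkValueError` so the sketch runs without PySpark installed.

```python
import pandas as pd


def require_dataframe(result):
    # Hypothetical helper mirroring the isinstance check for struct fields.
    # ValueError is an assumption standing in for PySparkValueError.
    if not isinstance(result, pd.DataFrame):
        raise ValueError(
            "Invalid return type. Please make sure that the UDF returns a "
            "pandas.DataFrame when the specified return type is StructType."
        )
    return result
```

Passing a `pd.Series` raises the error with the new message, while a `pd.DataFrame` is returned unchanged.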
13 changes: 13 additions & 0 deletions python/pyspark/sql/tests/pandas/test_pandas_udf.py
@@ -339,6 +339,19 @@ def noop(s: pd.Series) -> pd.Series:
         self.assertEqual(df.schema[0].dataType.simpleString(), "interval day to second")
         self.assertEqual(df.first()[0], datetime.timedelta(microseconds=123))

+    def test_pandas_udf_return_type_error(self):
+        import pandas as pd
+
+        @pandas_udf("s string")
+        def upper(s: pd.Series) -> pd.Series:
+            return s.str.upper()
+
+        df = self.spark.createDataFrame([("a",)], schema="s string")
+
+        self.assertRaisesRegex(
+            PythonException, "Invalid return type", df.select(upper("s")).collect
+        )
+

 class PandasUDFTests(PandasUDFTestsMixin, ReusedSQLTestCase):
     pass