From 7d564b7d1e535d1c6c1a828ce35411dfda1037ec Mon Sep 17 00:00:00 2001 From: allisonwang-db Date: Wed, 4 Sep 2024 09:26:20 +0900 Subject: [PATCH] [SPARK-49477][PYTHON] Improve pandas udf invalid return type error message ### What changes were proposed in this pull request? This PR improves the error message when the specified return type of a pandas udf mismatch the actual return type. ### Why are the changes needed? To improve the error message. Before this PR: `pyspark.errors.exceptions.base.PySparkValueError: A field of type StructType expects a pandas.DataFrame, but got: ` After this PR: `pyspark.errors.exceptions.base.PySparkValueError: Invalid return type. Please make sure that the UDF returns a pandas.DataFrame when the specified return type is StructType.` ### Does this PR introduce _any_ user-facing change? No ### How was this patch tested? New unit test ### Was this patch authored or co-authored using generative AI tooling? No Closes #47942 from allisonwang-db/spark-49477-pandas-udf-err-msg. Authored-by: allisonwang-db Signed-off-by: Hyukjin Kwon --- python/pyspark/sql/pandas/serializers.py | 4 ++-- python/pyspark/sql/tests/pandas/test_pandas_udf.py | 13 +++++++++++++ 2 files changed, 15 insertions(+), 2 deletions(-) diff --git a/python/pyspark/sql/pandas/serializers.py b/python/pyspark/sql/pandas/serializers.py index 6203d4d19d866..076226865f3a7 100644 --- a/python/pyspark/sql/pandas/serializers.py +++ b/python/pyspark/sql/pandas/serializers.py @@ -510,8 +510,8 @@ def _create_batch(self, series): # If it returns a pd.Series, it should throw an error. if not isinstance(s, pd.DataFrame): raise PySparkValueError( - "A field of type StructType expects a pandas.DataFrame, " - "but got: %s" % str(type(s)) + "Invalid return type. Please make sure that the UDF returns a " + "pandas.DataFrame when the specified return type is StructType." ) arrs.append(self._create_struct_array(s, t)) else: diff --git a/python/pyspark/sql/tests/pandas/test_pandas_udf.py b/python/pyspark/sql/tests/pandas/test_pandas_udf.py index 6720dfc37d0cc..228fc30b497cc 100644 --- a/python/pyspark/sql/tests/pandas/test_pandas_udf.py +++ b/python/pyspark/sql/tests/pandas/test_pandas_udf.py @@ -339,6 +339,19 @@ def noop(s: pd.Series) -> pd.Series: self.assertEqual(df.schema[0].dataType.simpleString(), "interval day to second") self.assertEqual(df.first()[0], datetime.timedelta(microseconds=123)) + def test_pandas_udf_return_type_error(self): + import pandas as pd + + @pandas_udf("s string") + def upper(s: pd.Series) -> pd.Series: + return s.str.upper() + + df = self.spark.createDataFrame([("a",)], schema="s string") + + self.assertRaisesRegex( + PythonException, "Invalid return type", df.select(upper("s")).collect + ) + class PandasUDFTests(PandasUDFTestsMixin, ReusedSQLTestCase): pass