From 44545bdd8a78cd95578097eeb88a603ae8650163 Mon Sep 17 00:00:00 2001 From: Dev-iL <6509619+Dev-iL@users.noreply.github.com> Date: Fri, 9 Aug 2024 10:53:30 +0300 Subject: [PATCH] Add temporary workaround for SPARK-48710 https://issues.apache.org/jira/browse/SPARK-48710 --- hamilton/registry.py | 7 ++++++- 1 file changed, 6 insertions(+), 1 deletion(-) diff --git a/hamilton/registry.py b/hamilton/registry.py index f4d432192..14e6d562b 100644 --- a/hamilton/registry.py +++ b/hamilton/registry.py @@ -3,6 +3,9 @@ import importlib import logging from typing import Any, Dict, Optional, Type +from unittest.mock import patch + +from numpy import nan logger = logging.getLogger(__name__) @@ -78,7 +81,9 @@ def load_extension(plugin_module: str): :param plugin_module: the module name sans .py. e.g. pandas, polars, pyspark_pandas. """ - mod = importlib.import_module(f"hamilton.plugins.{plugin_module}_extensions") + with patch.dict("np.NaN", nan): + # Workaround for https://issues.apache.org/jira/browse/SPARK-48710 + mod = importlib.import_module(f"hamilton.plugins.{plugin_module}_extensions") # We have various plugin extensions. We default to assuming it's a dataframe extension with columns, # unless it explicitly says it's not. # We need to check the following if we are to enable `@extract_columns` for example.