From 4f8860470d75cc915bc386a544cb9d75a10e4c32 Mon Sep 17 00:00:00 2001 From: Sajid Alam <90610031+SajidAlamQB@users.noreply.github.com> Date: Thu, 24 Aug 2023 15:20:15 +0100 Subject: [PATCH] feat(datasets): Add Support Python 3.11 for kedro-datasets (#297) * Add 3.11 tests to CI * update dependencies * 3.11 should only run on kedro-datasets * update setup matrix * refine setup matrix * use steps * pytables for unix * add pytables via conda * fix indent * add conda * activate windows * undo last commit * remove pytables conda install * Update pyproject.toml * pin pytables 3.8 for python 3.11 * lint * pin sqlalchemy for 3.11 * install gfortran * add pip_verbose and run onlu ubuntu tests * undo gfortran install * Update pyproject.toml * add conda install pytables * remove tables from pyproject.toml * Update pyproject.toml * pin tables 3.6 for unix * Update pyproject.toml * update pyspark * Update check-plugin.yml * change base spark dependency pin 3.4 * fix spark tests * Update test_spark_jdbc_dataset.py * coverage * fix coverage * Fix streaming dataset * fix streaming test * Update test_spark_hive_dataset.py * fix base pin * Remove delta-spark pin Signed-off-by: Nok * lower pyspark pin for 3.11 * pin delta-spark to 2.4 for python 3.11 * Update conftest.py * revert * Set delta version base on delta-spark version Signed-off-by: Nok * Update setup.py * open-bound * Update setup.py * scikit pin * Update setup.py * Update pyproject.toml * align with framework setup * Update setup.py * importlib_metadata backport * add deltalake * update holoviews pin * remove miniconda and conda pytables * add windows test back in running all parallel no spark * Update check-plugin.yml * add msbuild and run only windows * Update check-plugin.yml * lint * lint * update pandas pin * replace semver with packaging * lint * lint * add empty stacktrace for AnalysisException * revert * lint with 3.11 * update python version for linting * Remove setup matrix * Update check-plugin.yml * Update check-plugin.yml * overhaul python version and os pass in * Update check-plugin.yml * revert changes * Update check-plugin.yml * rtd with 3.8 * add snowflake-snowpark * remove repeated * release notes --------- Signed-off-by: Nok Co-authored-by: Nok Lam Chan --- .github/workflows/check-plugin.yml | 41 +++++++++--- kedro-datasets/RELEASE.md | 2 + .../spark/deltatable_dataset.py | 4 +- .../kedro_datasets/spark/spark_dataset.py | 7 +- .../spark/spark_jdbc_dataset.py | 2 +- .../spark/spark_streaming_dataset.py | 6 +- kedro-datasets/pyproject.toml | 2 +- kedro-datasets/setup.py | 46 +++++++------ kedro-datasets/tests/databricks/conftest.py | 6 +- .../tests/spark/test_deltatable_dataset.py | 18 ++++-- .../tests/spark/test_spark_dataset.py | 22 +++++-- .../tests/spark/test_spark_hive_dataset.py | 9 +-- .../tests/spark/test_spark_jdbc_dataset.py | 64 +++++++++---------- .../spark/test_spark_streaming_dataset.py | 21 ++++-- 14 files changed, 156 insertions(+), 94 deletions(-) diff --git a/.github/workflows/check-plugin.yml b/.github/workflows/check-plugin.yml index 00bc8a083..e0df1114c 100644 --- a/.github/workflows/check-plugin.yml +++ b/.github/workflows/check-plugin.yml @@ -7,14 +7,31 @@ on: type: string jobs: + + setup-matrix: + # kedro-datasets is the only plugin that supports python 3.11 + runs-on: ubuntu-latest + outputs: + python-versions: ${{ steps.set-matrix.outputs.matrix }} + steps: + - id: set-matrix + run: | + if [[ "${{ inputs.plugin }}" == "kedro-datasets" ]]; then + MATRIX='["3.7", "3.8", "3.9", "3.10", "3.11"]' + else + 
MATRIX='["3.7", "3.8", "3.9", "3.10"]' + fi + echo "matrix=${MATRIX}" >> $GITHUB_OUTPUT + unit-tests: + needs: setup-matrix defaults: run: shell: bash strategy: matrix: os: [ ubuntu-latest, windows-latest ] - python-version: [ "3.7", "3.8", "3.9", "3.10" ] + python-version: ${{fromJson(needs.setup-matrix.outputs.python-versions)}} runs-on: ${{ matrix.os }} steps: - name: Checkout code @@ -39,6 +56,9 @@ jobs: restore-keys: ${{inputs.plugin}} - name: Install Kedro run: pip install git+https://github.com/kedro-org/kedro@main + - name: Add MSBuild to PATH + if: matrix.os == 'windows-latest' + uses: microsoft/setup-msbuild@v1 - name: Install dependencies run: | cd ${{ inputs.plugin }} @@ -53,12 +73,8 @@ jobs: run: | cd ${{ inputs.plugin }} pytest tests - - name: Run unit tests for Windows / kedro-datasets / no spark sequential - if: matrix.os == 'windows-latest' && inputs.plugin == 'kedro-datasets' && matrix.python-version == '3.10' - run: | - make test-no-spark-sequential - name: Run unit tests for Windows / kedro-datasets / no spark parallel - if: matrix.os == 'windows-latest' && inputs.plugin == 'kedro-datasets' && matrix.python-version != '3.10' + if: matrix.os == 'windows-latest' && inputs.plugin == 'kedro-datasets' run: | make test-no-spark @@ -70,10 +86,19 @@ jobs: steps: - name: Checkout code uses: actions/checkout@v3 - - name: Set up Python 3.8 + # kedro-datasets is the only plugin that supports python 3.11 + - name: Determine Python version for linting + id: get-python-version + run: | + if [[ "${{ inputs.plugin }}" == "kedro-datasets" ]]; then + echo "version=3.11" >> $GITHUB_OUTPUT + else + echo "version=3.8" >> $GITHUB_OUTPUT + fi + - name: Set up Python uses: actions/setup-python@v3 with: - python-version: 3.8 + python-version: ${{ steps.get-python-version.outputs.version }} - name: Cache python packages uses: actions/cache@v3 with: diff --git a/kedro-datasets/RELEASE.md b/kedro-datasets/RELEASE.md index a5129d9e5..81f728dcc 100644 --- a/kedro-datasets/RELEASE.md +++ b/kedro-datasets/RELEASE.md @@ -1,5 +1,7 @@ # Upcoming Release + ## Major features and improvements +* Added support for Python 3.11. 
## Bug fixes and other changes diff --git a/kedro-datasets/kedro_datasets/spark/deltatable_dataset.py b/kedro-datasets/kedro_datasets/spark/deltatable_dataset.py index a5cb02b36..07eeee64f 100644 --- a/kedro-datasets/kedro_datasets/spark/deltatable_dataset.py +++ b/kedro-datasets/kedro_datasets/spark/deltatable_dataset.py @@ -98,7 +98,9 @@ def _exists(self) -> bool: try: self._get_spark().read.load(path=load_path, format="delta") except AnalysisException as exception: - if "is not a Delta table" in exception.desc: + # `AnalysisException.desc` is deprecated with pyspark >= 3.4 + message = exception.desc if hasattr(exception, "desc") else str(exception) + if "Path does not exist:" in message or "is not a Delta table" in message: return False raise diff --git a/kedro-datasets/kedro_datasets/spark/spark_dataset.py b/kedro-datasets/kedro_datasets/spark/spark_dataset.py index 88a858330..449adefc9 100644 --- a/kedro-datasets/kedro_datasets/spark/spark_dataset.py +++ b/kedro-datasets/kedro_datasets/spark/spark_dataset.py @@ -406,10 +406,9 @@ def _exists(self) -> bool: try: self._get_spark().read.load(load_path, self._file_format) except AnalysisException as exception: - if ( - exception.desc.startswith("Path does not exist:") - or "is not a Delta table" in exception.desc - ): + # `AnalysisException.desc` is deprecated with pyspark >= 3.4 + message = exception.desc if hasattr(exception, "desc") else str(exception) + if "Path does not exist:" in message or "is not a Delta table" in message: return False raise return True diff --git a/kedro-datasets/kedro_datasets/spark/spark_jdbc_dataset.py b/kedro-datasets/kedro_datasets/spark/spark_jdbc_dataset.py index b68b081e0..193f01103 100644 --- a/kedro-datasets/kedro_datasets/spark/spark_jdbc_dataset.py +++ b/kedro-datasets/kedro_datasets/spark/spark_jdbc_dataset.py @@ -170,7 +170,7 @@ def _describe(self) -> Dict[str, Any]: } @staticmethod - def _get_spark(): + def _get_spark(): # pragma: no cover return SparkSession.builder.getOrCreate() def _load(self) -> DataFrame: diff --git a/kedro-datasets/kedro_datasets/spark/spark_streaming_dataset.py b/kedro-datasets/kedro_datasets/spark/spark_streaming_dataset.py index d2609748f..e61785bf7 100644 --- a/kedro-datasets/kedro_datasets/spark/spark_streaming_dataset.py +++ b/kedro-datasets/kedro_datasets/spark/spark_streaming_dataset.py @@ -149,9 +149,11 @@ def _exists(self) -> bool: load_path, self._file_format ) except AnalysisException as exception: + # `AnalysisException.desc` is deprecated with pyspark >= 3.4 + message = exception.desc if hasattr(exception, "desc") else str(exception) if ( - exception.desc.startswith("Path does not exist:") - or "is not a Streaming data" in exception.desc + "Path does not exist:" in message + or "is not a Streaming data" in message ): return False raise diff --git a/kedro-datasets/pyproject.toml b/kedro-datasets/pyproject.toml index 31c66c227..96828d508 100644 --- a/kedro-datasets/pyproject.toml +++ b/kedro-datasets/pyproject.toml @@ -8,7 +8,7 @@ authors = [ {name = "Kedro"} ] description = "Kedro-Datasets is where you can find all of Kedro's data connectors." 
-requires-python = ">=3.7, <3.11" +requires-python = ">=3.7" license = {text = "Apache Software License (Apache 2.0)"} dependencies = [ "kedro>=0.16", diff --git a/kedro-datasets/setup.py b/kedro-datasets/setup.py index 63cd364cc..926210984 100644 --- a/kedro-datasets/setup.py +++ b/kedro-datasets/setup.py @@ -4,7 +4,7 @@ # at least 1.3 to be able to use XMLDataSet and pandas integration with fsspec PANDAS = "pandas>=1.3, <3.0" -SPARK = "pyspark>=2.2, <4.0" +SPARK = "pyspark>=2.2, <3.4" HDFS = "hdfs>=2.5.8, <3.0" S3FS = "s3fs>=0.3.0, <0.5" POLARS = "polars~=0.17.0" @@ -139,24 +139,26 @@ def _collect_requirements(requires): "Jinja2<3.1.0", ] extras_require["test"] = [ - "adlfs>=2021.7.1, <=2022.2", + "adlfs>=2021.7.1, <=2022.2; python_version == '3.7'", + "adlfs~=2023.1; python_version >= '3.8'", "bandit>=1.6.2, <2.0", "behave==1.2.6", "biopython~=1.73", "blacken-docs==1.9.2", "black~=22.0", - "compress-pickle[lz4]~=1.2.0", + "compress-pickle[lz4]~=2.1.0", "coverage[toml]", - "dask[complete]", - "delta-spark~=1.2.1", - # 1.2.0 has a bug that breaks some of our tests: https://github.com/delta-io/delta/issues/1070 + "dask[complete]~=2021.10", # pinned by Snyk to avoid a vulnerability + "delta-spark>=1.2.1; python_version >= '3.11'", # 1.2.0 has a bug that breaks some of our tests: https://github.com/delta-io/delta/issues/1070 + "delta-spark~=1.2.1; python_version < '3.11'", "deltalake>=0.10.0", "dill~=0.3.1", "filelock>=3.4.0, <4.0", - "gcsfs>=2021.4, <=2022.1", + "gcsfs>=2021.4, <=2023.1; python_version == '3.7'", + "gcsfs>=2023.1, <2023.3; python_version >= '3.8'", "geopandas>=0.6.0, <1.0", "hdfs>=2.5.8, <3.0", - "holoviews~=1.13.0", + "holoviews>=1.13.0", "import-linter[toml]==1.2.6", "ipython>=7.31.1, <8.0", "Jinja2<3.1.0", @@ -165,25 +167,27 @@ def _collect_requirements(requires): "jupyter~=1.0", "lxml~=4.6", "matplotlib>=3.0.3, <3.4; python_version < '3.10'", # 3.4.0 breaks holoviews - "matplotlib>=3.5, <3.6; python_version == '3.10'", + "matplotlib>=3.5, <3.6; python_version >= '3.10'", "memory_profiler>=0.50.0, <1.0", "moto==1.3.7; python_version < '3.10'", - "moto==3.0.4; python_version == '3.10'", + "moto==4.1.12; python_version >= '3.10'", "networkx~=2.4", "opencv-python~=4.5.5.64", "openpyxl>=3.0.3, <4.0", - "pandas-gbq>=0.12.0, <0.18.0", - "pandas>=1.3, <2", # 1.3 for read_xml/to_xml, <2 for compatibility with Spark < 3.4 + "pandas-gbq>=0.12.0, <0.18.0; python_version < '3.11'", + "pandas-gbq>=0.18.0; python_version >= '3.11'", + "pandas~=1.3 # 1.3 for read_xml/to_xml", "Pillow~=9.0", "plotly>=4.8.0, <6.0", "polars~=0.15.13", "pre-commit>=2.9.2, <3.0", # The hook `mypy` requires pre-commit version 2.9.2. - "psutil==5.8.0", - "pyarrow~=8.0", + "pyarrow>=1.0; python_version < '3.11'", + "pyarrow>=7.0; python_version >= '3.11'", # Adding to avoid numpy build errors "pylint>=2.5.2, <3.0", "pyodbc~=4.0.35", "pyproj~=3.0", - "pyspark>=2.2, <4.0", + "pyspark>=2.2, <3.4; python_version < '3.11'", + "pyspark>=3.4; python_version >= '3.11'", "pytest-cov~=3.0", "pytest-mock>=1.7.1, <2.0", "pytest-xdist[psutil]~=2.2.1", @@ -192,12 +196,14 @@ def _collect_requirements(requires): "requests-mock~=1.6", "requests~=2.20", "s3fs>=0.3.0, <0.5", # Needs to be at least 0.3.0 to make use of `cachable` attribute on S3FileSystem. - "scikit-learn~=1.0.2", - "scipy~=1.7.3", "snowflake-snowpark-python~=1.0.0; python_version == '3.8'", - "SQLAlchemy>=1.4, <3.0", - # The `Inspector.has_table()` method replaces the `Engine.has_table()` method in version 1.4. 
- "tables~=3.7", + "scikit-learn>=1.0.2,<2", + "scipy>=1.7.3", + "packaging", + "SQLAlchemy~=1.2", + "tables~=3.6.0; platform_system == 'Windows' and python_version<'3.8'", + "tables~=3.8.0; platform_system == 'Windows' and python_version>='3.8'", # Import issues with python 3.8 with pytables pinning to 3.8.0 fixes this https://github.com/PyTables/PyTables/issues/933#issuecomment-1555917593 + "tables~=3.6; platform_system != 'Windows'", "tensorflow-macos~=2.0; platform_system == 'Darwin' and platform_machine == 'arm64'", "tensorflow~=2.0; platform_system != 'Darwin' or platform_machine != 'arm64'", "triad>=0.6.7, <1.0", diff --git a/kedro-datasets/tests/databricks/conftest.py b/kedro-datasets/tests/databricks/conftest.py index 26d63b056..958ee6a83 100644 --- a/kedro-datasets/tests/databricks/conftest.py +++ b/kedro-datasets/tests/databricks/conftest.py @@ -4,15 +4,19 @@ discover them automatically. More info here: https://docs.pytest.org/en/latest/fixture.html """ +# importlib_metadata needs backport for python 3.8 and older +import importlib_metadata as importlib_metadata # pylint: disable=useless-import-alias import pytest from pyspark.sql import SparkSession +DELTA_VERSION = importlib_metadata.version("delta-spark") + @pytest.fixture(scope="class", autouse=True) def spark_session(): spark = ( SparkSession.builder.appName("test") - .config("spark.jars.packages", "io.delta:delta-core_2.12:1.2.1") + .config("spark.jars.packages", f"io.delta:delta-core_2.12:{DELTA_VERSION}") .config("spark.sql.extensions", "io.delta.sql.DeltaSparkSessionExtension") .config( "spark.sql.catalog.spark_catalog", diff --git a/kedro-datasets/tests/spark/test_deltatable_dataset.py b/kedro-datasets/tests/spark/test_deltatable_dataset.py index 5cbbe62b7..c39a8b1bf 100644 --- a/kedro-datasets/tests/spark/test_deltatable_dataset.py +++ b/kedro-datasets/tests/spark/test_deltatable_dataset.py @@ -4,12 +4,16 @@ from kedro.pipeline import node from kedro.pipeline.modular_pipeline import pipeline as modular_pipeline from kedro.runner import ParallelRunner +from packaging.version import Version +from pyspark import __version__ from pyspark.sql import SparkSession from pyspark.sql.types import IntegerType, StringType, StructField, StructType from pyspark.sql.utils import AnalysisException from kedro_datasets.spark import DeltaTableDataSet, SparkDataSet +SPARK_VERSION = Version(__version__) + @pytest.fixture def sample_spark_df(): @@ -65,10 +69,16 @@ def test_exists(self, tmp_path, sample_spark_df): def test_exists_raises_error(self, mocker): delta_ds = DeltaTableDataSet(filepath="") - mocker.patch.object( - delta_ds, "_get_spark", side_effect=AnalysisException("Other Exception", []) - ) - + if SPARK_VERSION >= Version("3.4.0"): + mocker.patch.object( + delta_ds, "_get_spark", side_effect=AnalysisException("Other Exception") + ) + else: + mocker.patch.object( + delta_ds, + "_get_spark", + side_effect=AnalysisException("Other Exception", []), + ) with pytest.raises(DataSetError, match="Other Exception"): delta_ds.exists() diff --git a/kedro-datasets/tests/spark/test_spark_dataset.py b/kedro-datasets/tests/spark/test_spark_dataset.py index 9452b007d..ab2ff7107 100644 --- a/kedro-datasets/tests/spark/test_spark_dataset.py +++ b/kedro-datasets/tests/spark/test_spark_dataset.py @@ -13,6 +13,8 @@ from kedro.pipeline.modular_pipeline import pipeline as modular_pipeline from kedro.runner import ParallelRunner, SequentialRunner from moto import mock_s3 +from packaging.version import Version as PackagingVersion +from pyspark import 
__version__ from pyspark.sql import SparkSession from pyspark.sql.functions import col from pyspark.sql.types import ( @@ -57,6 +59,8 @@ (HDFS_PREFIX + "/2019-02-01T00.00.00.000Z", [], ["other_file"]), ] +SPARK_VERSION = PackagingVersion(__version__) + @pytest.fixture def sample_pandas_df() -> pd.DataFrame: @@ -403,12 +407,18 @@ def test_exists_raises_error(self, mocker): # exists should raise all errors except for # AnalysisExceptions clearly indicating a missing file spark_data_set = SparkDataSet(filepath="") - mocker.patch.object( - spark_data_set, - "_get_spark", - side_effect=AnalysisException("Other Exception", []), - ) - + if SPARK_VERSION >= PackagingVersion("3.4.0"): + mocker.patch.object( + spark_data_set, + "_get_spark", + side_effect=AnalysisException("Other Exception"), + ) + else: + mocker.patch.object( + spark_data_set, + "_get_spark", + side_effect=AnalysisException("Other Exception", []), + ) with pytest.raises(DataSetError, match="Other Exception"): spark_data_set.exists() diff --git a/kedro-datasets/tests/spark/test_spark_hive_dataset.py b/kedro-datasets/tests/spark/test_spark_hive_dataset.py index e0b8fc333..038200358 100644 --- a/kedro-datasets/tests/spark/test_spark_hive_dataset.py +++ b/kedro-datasets/tests/spark/test_spark_hive_dataset.py @@ -293,12 +293,9 @@ def test_read_from_non_existent_table(self): ) with pytest.raises( DataSetError, - match=r"Failed while loading data from data set " - r"SparkHiveDataSet\(database=default_1, format=hive, " - r"table=table_doesnt_exist, table_pk=\[\], write_mode=append\)\.\n" - r"Table or view not found: default_1.table_doesnt_exist;\n" - r"'UnresolvedRelation \[default_1, " - r"table_doesnt_exist\], \[\], false\n", + match=r"Failed while loading data from data set SparkHiveDataSet" + r"|table_doesnt_exist" + r"|UnresolvedRelation", ): dataset.load() diff --git a/kedro-datasets/tests/spark/test_spark_jdbc_dataset.py b/kedro-datasets/tests/spark/test_spark_jdbc_dataset.py index 0f3d0e66b..46b86f42b 100644 --- a/kedro-datasets/tests/spark/test_spark_jdbc_dataset.py +++ b/kedro-datasets/tests/spark/test_spark_jdbc_dataset.py @@ -1,5 +1,3 @@ -from unittest import mock - import pytest from kedro.io import DataSetError @@ -53,57 +51,52 @@ def test_missing_table(): SparkJDBCDataSet(url="dummy_url", table=None) -def mock_save(arg_dict): - mock_data = mock.Mock() - data_set = SparkJDBCDataSet(**arg_dict) +def test_save(mocker, spark_jdbc_args): + mock_data = mocker.Mock() + data_set = SparkJDBCDataSet(**spark_jdbc_args) data_set.save(mock_data) - return mock_data - - -def test_save(spark_jdbc_args): - data = mock_save(spark_jdbc_args) - data.write.jdbc.assert_called_with("dummy_url", "dummy_table") + mock_data.write.jdbc.assert_called_with("dummy_url", "dummy_table") -def test_save_credentials(spark_jdbc_args_credentials): - data = mock_save(spark_jdbc_args_credentials) - data.write.jdbc.assert_called_with( +def test_save_credentials(mocker, spark_jdbc_args_credentials): + mock_data = mocker.Mock() + data_set = SparkJDBCDataSet(**spark_jdbc_args_credentials) + data_set.save(mock_data) + mock_data.write.jdbc.assert_called_with( "dummy_url", "dummy_table", properties={"user": "dummy_user", "password": "dummy_pw"}, ) -def test_save_args(spark_jdbc_args_save_load): - data = mock_save(spark_jdbc_args_save_load) - data.write.jdbc.assert_called_with( +def test_save_args(mocker, spark_jdbc_args_save_load): + mock_data = mocker.Mock() + data_set = SparkJDBCDataSet(**spark_jdbc_args_save_load) + data_set.save(mock_data) + 
mock_data.write.jdbc.assert_called_with( "dummy_url", "dummy_table", properties={"driver": "dummy_driver"} ) -def test_except_bad_credentials(spark_jdbc_args_credentials_with_none_password): +def test_except_bad_credentials(mocker, spark_jdbc_args_credentials_with_none_password): pattern = r"Credential property 'password' cannot be None(.+)" with pytest.raises(DataSetError, match=pattern): - mock_save(spark_jdbc_args_credentials_with_none_password) + mock_data = mocker.Mock() + data_set = SparkJDBCDataSet(**spark_jdbc_args_credentials_with_none_password) + data_set.save(mock_data) -@mock.patch("kedro_datasets.spark.spark_jdbc_dataset.SparkSession.builder.getOrCreate") -def mock_load(mock_get_or_create, arg_dict): - spark = mock_get_or_create.return_value - data_set = SparkJDBCDataSet(**arg_dict) +def test_load(mocker, spark_jdbc_args): + spark = mocker.patch.object(SparkJDBCDataSet, "_get_spark").return_value + data_set = SparkJDBCDataSet(**spark_jdbc_args) data_set.load() - return spark - - -def test_load(spark_jdbc_args): - # pylint: disable=no-value-for-parameter - spark = mock_load(arg_dict=spark_jdbc_args) spark.read.jdbc.assert_called_with("dummy_url", "dummy_table") -def test_load_credentials(spark_jdbc_args_credentials): - # pylint: disable=no-value-for-parameter - spark = mock_load(arg_dict=spark_jdbc_args_credentials) +def test_load_credentials(mocker, spark_jdbc_args_credentials): + spark = mocker.patch.object(SparkJDBCDataSet, "_get_spark").return_value + data_set = SparkJDBCDataSet(**spark_jdbc_args_credentials) + data_set.load() spark.read.jdbc.assert_called_with( "dummy_url", "dummy_table", @@ -111,9 +104,10 @@ def test_load_credentials(spark_jdbc_args_credentials): ) -def test_load_args(spark_jdbc_args_save_load): - # pylint: disable=no-value-for-parameter - spark = mock_load(arg_dict=spark_jdbc_args_save_load) +def test_load_args(mocker, spark_jdbc_args_save_load): + spark = mocker.patch.object(SparkJDBCDataSet, "_get_spark").return_value + data_set = SparkJDBCDataSet(**spark_jdbc_args_save_load) + data_set.load() spark.read.jdbc.assert_called_with( "dummy_url", "dummy_table", properties={"driver": "dummy_driver"} ) diff --git a/kedro-datasets/tests/spark/test_spark_streaming_dataset.py b/kedro-datasets/tests/spark/test_spark_streaming_dataset.py index c4fb6c005..d3e16f8a8 100644 --- a/kedro-datasets/tests/spark/test_spark_streaming_dataset.py +++ b/kedro-datasets/tests/spark/test_spark_streaming_dataset.py @@ -4,6 +4,8 @@ import pytest from kedro.io.core import DataSetError from moto import mock_s3 +from packaging.version import Version +from pyspark import __version__ from pyspark.sql import SparkSession from pyspark.sql.types import IntegerType, StringType, StructField, StructType from pyspark.sql.utils import AnalysisException @@ -15,6 +17,8 @@ BUCKET_NAME = "test_bucket" AWS_CREDENTIALS = {"key": "FAKE_ACCESS_KEY", "secret": "FAKE_SECRET_KEY"} +SPARK_VERSION = Version(__version__) + def sample_schema(schema_path): """read the schema file from json path""" @@ -168,11 +172,18 @@ def test_exists_raises_error(self, mocker): # exists should raise all errors except for # AnalysisExceptions clearly indicating a missing file spark_data_set = SparkStreamingDataSet(filepath="") - mocker.patch.object( - spark_data_set, - "_get_spark", - side_effect=AnalysisException("Other Exception", []), - ) + if SPARK_VERSION >= Version("3.4.0"): + mocker.patch.object( + spark_data_set, + "_get_spark", + side_effect=AnalysisException("Other Exception"), + ) + else: + mocker.patch.object( + 
spark_data_set, + "_get_spark", + side_effect=AnalysisException("Other Exception", []), + ) with pytest.raises(DataSetError, match="Other Exception"): spark_data_set.exists()
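
The changes to deltatable_dataset.py, spark_dataset.py and spark_streaming_dataset.py above all apply the same compatibility pattern, because `AnalysisException.desc` is deprecated from pyspark 3.4 onwards while the older release still expects a (desc, stackTrace) pair when the exception is constructed. The sketch below restates that pattern outside the diff as a minimal example, assuming only that pyspark and packaging are installed; the helper name `looks_like_missing_path` is illustrative and does not appear in the patch, where the equivalent logic is inlined in each dataset's `_exists` method.

    from packaging.version import Version
    from pyspark import __version__
    from pyspark.sql.utils import AnalysisException

    SPARK_VERSION = Version(__version__)


    def looks_like_missing_path(exception: AnalysisException) -> bool:
        # `desc` is only reliable on pyspark < 3.4; newer versions carry the text in str(exception).
        message = exception.desc if hasattr(exception, "desc") else str(exception)
        return "Path does not exist:" in message or "is not a Delta table" in message


    # The tests need the same guard when constructing the exception:
    # pyspark < 3.4 expects (desc, stackTrace), pyspark >= 3.4 accepts a single message.
    if SPARK_VERSION >= Version("3.4.0"):
        other_error = AnalysisException("Other Exception")
    else:
        other_error = AnalysisException("Other Exception", [])

    # "Other Exception" is neither a missing path nor a non-Delta table, so it should re-raise.
    assert not looks_like_missing_path(other_error)

Probing for the attribute with `hasattr` rather than comparing version strings keeps a single runtime code path working against both the `pyspark>=2.2, <3.4` and `pyspark>=3.4` pins that setup.py now selects by Python version.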