From 27edbcdbd610923544c60d42be200e8f78cf56a2 Mon Sep 17 00:00:00 2001
From: Nate Parsons <4307001+thehomebrewnerd@users.noreply.github.com>
Date: Mon, 13 May 2024 10:47:42 -0500
Subject: [PATCH] Use filter arg to safe extract archives (#1862)

* use filter to safe extract archives

* update release notes

* update actions

* add tests

* fix action

* final test
---
Review notes (ignored by `git am`; not part of the commit message):

  * release_notes_updated.yaml: the step runs with `shell: python`, so the
    literal "$REF" is never expanded. The branch name is now read from the
    environment with os.environ, which keeps the untrusted ref out of the
    script text while still giving the regex check the real value.
  * The RuntimeError text said "EntitySet archive" (a Featuretools term);
    it now says "Woodwork table archive". The tests only match the
    "Please upgrade your Python version" prefix.
  * Line structure and indentation of this patch were reconstructed from a
    whitespace-collapsed copy. Verify context whitespace against the tree
    before applying. The `index` blob ids below were not recomputed for the
    files whose hunks changed.

 .github/workflows/pull_request_check.yaml     |  2 +-
 .github/workflows/release_notes_updated.yaml  |  5 +-
 docs/source/release_notes.rst                 |  5 +-
 woodwork/deserializers/deserializer_base.py   |  8 ++-
 .../deserializers/parquet_deserializer.py     |  8 ++-
 woodwork/deserializers/utils.py               |  8 ++-
 woodwork/tests/accessor/test_serialization.py | 70 ++++++++++++++++++-
 7 files changed, 98 insertions(+), 8 deletions(-)

diff --git a/.github/workflows/pull_request_check.yaml b/.github/workflows/pull_request_check.yaml
index 074222531..e749af5e0 100644
--- a/.github/workflows/pull_request_check.yaml
+++ b/.github/workflows/pull_request_check.yaml
@@ -7,7 +7,7 @@ jobs:
     name: pull request check
     runs-on: ubuntu-latest
     steps:
-      - uses: nearform/github-action-check-linked-issues@v1
+      - uses: nearform-actions/github-action-check-linked-issues@v1
         id: check-linked-issues
         with:
           exclude-branches: "release_v**, backport_v**, main, latest-dep-update-**, min-dep-update-**, dependabot/**"
diff --git a/.github/workflows/release_notes_updated.yaml b/.github/workflows/release_notes_updated.yaml
index 11dac4b93..3ea6d9f68 100644
--- a/.github/workflows/release_notes_updated.yaml
+++ b/.github/workflows/release_notes_updated.yaml
@@ -12,6 +12,9 @@ jobs:
       - name: Check for development branch
         id: branch
         shell: python
+        env:
+          REF: ${{ github.event.pull_request.head.ref }}
         run: |
           from re import compile
+          from os import environ
           main = '^main$'
@@ -21,7 +24,7 @@ jobs:
           min_dep_update = '^min-dep-update-[a-f0-9]{7}$'
           regex = main, release, backport, dep_update, min_dep_update
           patterns = list(map(compile, regex))
-          ref = "${{ github.event.pull_request.head.ref }}"
+          ref = environ["REF"]
           is_dev = not any(pattern.match(ref) for pattern in patterns)
           print('::set-output name=is_dev::' + str(is_dev))
       - if: ${{ steps.branch.outputs.is_dev == 'True' }}
diff --git a/docs/source/release_notes.rst b/docs/source/release_notes.rst
index a5ca1a639..cafc5f5fe 100644
--- a/docs/source/release_notes.rst
+++ b/docs/source/release_notes.rst
@@ -6,10 +6,11 @@ Release Notes
 Future Release
 ==============
     * Enhancements
+        * Add support for Python 3.12 :pr:`1855`
     * Fixes
     * Changes
-        * Add support for Python 3.12 :pr:`1855`
-        * Drop support for using Woodwork with Dask or Pyspark dataframes (:pr:`1857`)
+        * Drop support for using Woodwork with Dask or Pyspark dataframes :pr:`1857`
+        * Use ``filter`` arg in call to ``tarfile.extractall`` to safely deserialize DataFrames :pr:`1862`
     * Documentation Changes
     * Testing Changes
 
diff --git a/woodwork/deserializers/deserializer_base.py b/woodwork/deserializers/deserializer_base.py
index 418e426f3..ee8b10c6b 100644
--- a/woodwork/deserializers/deserializer_base.py
+++ b/woodwork/deserializers/deserializer_base.py
@@ -2,6 +2,7 @@
 import tarfile
 import tempfile
 import warnings
+from inspect import getfullargspec
 from itertools import zip_longest
 from pathlib import Path
 
@@ -125,7 +126,12 @@ def read_from_s3(self, profile_name):
             use_smartopen(tar_filepath, self.path, transport_params)
 
             with tarfile.open(str(tar_filepath)) as tar:
-                tar.extractall(path=tmpdir)
+                if "filter" in getfullargspec(tar.extractall).kwonlyargs:
+                    tar.extractall(path=tmpdir, filter="data")
+                else:
+                    raise RuntimeError(
+                        "Please upgrade your Python version to the latest patch release to allow for safe extraction of the Woodwork table archive.",
+                    )
             self.read_path = os.path.join(
                 tmpdir,
                 self.typing_info["loading_info"]["location"],
diff --git a/woodwork/deserializers/parquet_deserializer.py b/woodwork/deserializers/parquet_deserializer.py
index 27a03ce29..662aa6aee 100644
--- a/woodwork/deserializers/parquet_deserializer.py
+++ b/woodwork/deserializers/parquet_deserializer.py
@@ -2,6 +2,7 @@
 import os
 import tarfile
 import tempfile
+from inspect import getfullargspec
 from pathlib import Path
 
 import pandas as pd
@@ -61,7 +62,12 @@ def read_from_s3(self, profile_name):
             use_smartopen(tar_filepath, self.path, transport_params)
 
             with tarfile.open(str(tar_filepath)) as tar:
-                tar.extractall(path=tmpdir)
+                if "filter" in getfullargspec(tar.extractall).kwonlyargs:
+                    tar.extractall(path=tmpdir, filter="data")
+                else:
+                    raise RuntimeError(
+                        "Please upgrade your Python version to the latest patch release to allow for safe extraction of the Woodwork table archive.",
+                    )
 
             self.read_path = os.path.join(tmpdir, self.data_subdirectory, self.filename)
 
diff --git a/woodwork/deserializers/utils.py b/woodwork/deserializers/utils.py
index 2d25c775c..1c7ca561e 100644
--- a/woodwork/deserializers/utils.py
+++ b/woodwork/deserializers/utils.py
@@ -2,6 +2,7 @@
 import os
 import tarfile
 import tempfile
+from inspect import getfullargspec
 from pathlib import Path
 
 from woodwork.deserializers import (
@@ -99,7 +100,12 @@ def read_table_typing_information(path, typing_info_filename, profile_name):
             use_smartopen(file_path, path, transport_params)
 
             with tarfile.open(str(file_path)) as tar:
-                tar.extractall(path=tmpdir)
+                if "filter" in getfullargspec(tar.extractall).kwonlyargs:
+                    tar.extractall(path=tmpdir, filter="data")
+                else:
+                    raise RuntimeError(
+                        "Please upgrade your Python version to the latest patch release to allow for safe extraction of the Woodwork table archive.",
+                    )
 
             file = os.path.join(tmpdir, typing_info_filename)
             with open(file, "r") as file:
diff --git a/woodwork/tests/accessor/test_serialization.py b/woodwork/tests/accessor/test_serialization.py
index c169965e6..23412dca1 100644
--- a/woodwork/tests/accessor/test_serialization.py
+++ b/woodwork/tests/accessor/test_serialization.py
@@ -2,7 +2,7 @@
 import os
 import shutil
 import warnings
-from unittest.mock import patch
+from unittest.mock import MagicMock, patch
 
 import boto3
 import pandas as pd
@@ -662,6 +662,35 @@ def test_to_csv_S3(sample_df, s3_client, s3_bucket, profile_name):
     assert sample_df.ww.schema == deserialized_df.ww.schema
 
 
+@patch("woodwork.deserializers.utils.getfullargspec")
+def test_to_csv_S3_errors_if_python_version_unsafe(
+    mock_inspect,
+    sample_df,
+    s3_client,
+    s3_bucket,
+):
+    mock_response = MagicMock()
+    mock_response.kwonlyargs = []
+    mock_inspect.return_value = mock_response
+    sample_df.ww.init(
+        name="test_data",
+        index="id",
+        semantic_tags={"id": "tag1"},
+        logical_types={"age": Ordinal(order=[25, 33, 57])},
+    )
+    sample_df.ww.to_disk(
+        TEST_S3_URL,
+        format="csv",
+        encoding="utf-8",
+        engine="python",
+        profile_name=None,
+    )
+    make_public(s3_client, s3_bucket)
+
+    with pytest.raises(RuntimeError, match="Please upgrade your Python version"):
+        read_woodwork_table(TEST_S3_URL, profile_name=None)
+
+
 @pytest.mark.parametrize("profile_name", [None, False])
 def test_serialize_s3_pickle(sample_df, s3_client, s3_bucket, profile_name):
     sample_df.ww.init()
@@ -673,6 +702,23 @@ def test_serialize_s3_pickle(sample_df, s3_client, s3_bucket, profile_name):
     assert sample_df.ww.schema == deserialized_df.ww.schema
 
 
+@patch("woodwork.deserializers.deserializer_base.getfullargspec")
+def test_serialize_s3_pickle_errors_if_python_version_unsafe(
+    mock_inspect,
+    sample_df,
+    s3_client,
+    s3_bucket,
+):
+    mock_response = MagicMock()
+    mock_response.kwonlyargs = []
+    mock_inspect.return_value = mock_response
+    sample_df.ww.init()
+    sample_df.ww.to_disk(TEST_S3_URL, format="pickle", profile_name=None)
+    make_public(s3_client, s3_bucket)
+    with pytest.raises(RuntimeError, match="Please upgrade your Python version"):
+        read_woodwork_table(TEST_S3_URL, profile_name=None)
+
+
 @pytest.mark.parametrize("profile_name", [None, False])
 def test_serialize_s3_parquet(sample_df, s3_client, s3_bucket, profile_name):
     sample_df.ww.init()
@@ -688,6 +734,28 @@ def test_serialize_s3_parquet(sample_df, s3_client, s3_bucket, profile_name):
     assert sample_df.ww.schema == deserialized_df.ww.schema
 
 
+@patch("woodwork.deserializers.parquet_deserializer.getfullargspec")
+def test_serialize_s3_parquet_errors_if_python_version_unsafe(
+    mock_inspect,
+    sample_df,
+    s3_client,
+    s3_bucket,
+):
+    mock_response = MagicMock()
+    mock_response.kwonlyargs = []
+    mock_inspect.return_value = mock_response
+    sample_df.ww.init()
+    sample_df.ww.to_disk(TEST_S3_URL, format="parquet", profile_name=None)
+    make_public(s3_client, s3_bucket)
+
+    with pytest.raises(RuntimeError, match="Please upgrade your Python version"):
+        read_woodwork_table(
+            TEST_S3_URL,
+            filename="data.parquet",
+            profile_name=None,
+        )
+
+
 def create_test_credentials(test_path):
     with open(test_path, "w+") as f:
         f.write("[test]\n")