Skip to content

Commit

Permalink
Backports v0.9.10 (#2653)
Browse files Browse the repository at this point in the history
* Add `gluonts.util.safe_extract` (#2606)

Co-authored-by: Jasper <[email protected]>
Co-authored-by: Lorenzo Stella <[email protected]>

* Fix call to `extractall` (#2648)

* remove py36

* fix requirements

* fix flake8

* Cap numpy compatibility in `mxnet` extra requirements (#2506)

* xfail multivariate grouper test

Co-authored-by: Lorenzo Stella <[email protected]>
Co-authored-by: Jasper <[email protected]>

* roll back undesired change

* fixup

* remove np.long

* fix

* fix flake8

---------

Co-authored-by: Jasper <[email protected]>
Co-authored-by: Lorenzo Stella <[email protected]>
Co-authored-by: Jasper <[email protected]>
  • Loading branch information
4 people authored Feb 15, 2023
1 parent 5424442 commit 3099156
Show file tree
Hide file tree
Showing 17 changed files with 114 additions and 29 deletions.
2 changes: 1 addition & 1 deletion .github/workflows/tests-torch.yml
Original file line number Diff line number Diff line change
Expand Up @@ -8,7 +8,7 @@ jobs:
max-parallel: 4
fail-fast: false
matrix:
python-version: [3.6, 3.7, 3.8]
python-version: [3.7, 3.8]
platform: [ubuntu-latest]

runs-on: ${{ matrix.platform }}
Expand Down
10 changes: 2 additions & 8 deletions .github/workflows/tests.yml
Original file line number Diff line number Diff line change
Expand Up @@ -8,7 +8,7 @@ jobs:
max-parallel: 4
fail-fast: false
matrix:
python-version: [3.6, 3.7, 3.8]
python-version: [3.7, 3.8]
platform: [ubuntu-latest]

runs-on: ${{ matrix.platform }}
Expand All @@ -19,16 +19,10 @@ jobs:
uses: actions/setup-python@v1
with:
python-version: ${{ matrix.python-version }}
- name: Install MXNet (Linux)
if: ${{ runner.os == 'Linux' }}
run: pip install mxnet~=1.8.0
- name: Install MXNet (Windows)
if: ${{ runner.os == 'Windows' }}
run: pip install mxnet~=1.7.0
- name: Install other dependencies
run: |
python -m pip install -U pip
pip install ".[shell]"
pip install ".[mxnet,shell]"
pip install -r requirements/requirements-test.txt
pip install -r requirements/requirements-extras-m-competitions.txt
- name: Test with pytest
Expand Down
2 changes: 1 addition & 1 deletion requirements/requirements-extras-prophet.txt
Original file line number Diff line number Diff line change
@@ -1 +1 @@
fbprophet>=0.4.*
fbprophet>=0.4.0
2 changes: 1 addition & 1 deletion requirements/requirements-extras-r.txt
Original file line number Diff line number Diff line change
@@ -1 +1 @@
rpy2>=2.9.*,<3.*
rpy2>=2.9.0,<3.0
3 changes: 3 additions & 0 deletions requirements/requirements-mxnet.txt
Original file line number Diff line number Diff line change
@@ -1 +1,4 @@
# upper bound added since numpy==1.24 broke importing mxnet,
# see https://github.com/awslabs/gluonts/pull/2506
numpy<1.24
mxnet~=1.7
2 changes: 1 addition & 1 deletion requirements/requirements-test.txt
Original file line number Diff line number Diff line change
@@ -1,5 +1,5 @@
flaky~=3.6
pytest-cov==2.6.*
pytest-cov~=2.6.0
pytest-timeout~=1.3
pytest-xdist~=1.27
pytest~=5.0
Expand Down
3 changes: 2 additions & 1 deletion src/gluonts/dataset/repository/_gp_copula_2019.py
Original file line number Diff line number Diff line change
Expand Up @@ -27,6 +27,7 @@
from gluonts.dataset.common import FileDataset
from gluonts.dataset.field_names import FieldName
from gluonts.dataset.repository._util import metadata, save_to_file, to_dict
from gluonts.util import safe_extractall


class GPCopulaDataset(NamedTuple):
Expand Down Expand Up @@ -122,7 +123,7 @@ def download_dataset(dataset_path: Path, ds_info: GPCopulaDataset):
request.urlretrieve(ds_info.url, dataset_path / f"{ds_info.name}.tar.gz")

with tarfile.open(dataset_path / f"{ds_info.name}.tar.gz") as tar:
tar.extractall(path=dataset_path)
safe_extractall(tar, path=dataset_path)


def save_metadata(
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -128,7 +128,7 @@ def create_transformation(self, is_full_batch=False) -> Transformation:
AsNumpyArray(
field=FieldName.FEAT_STATIC_CAT,
expected_ndim=1,
dtype=np.long,
dtype=int,
),
AsNumpyArray(
field=FieldName.FEAT_STATIC_REAL,
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -208,7 +208,7 @@ def use_marginal_transformation(
AsNumpyArray(
field=FieldName.FEAT_STATIC_CAT,
expected_ndim=1,
dtype=np.long,
dtype=int,
),
AsNumpyArray(
field=FieldName.FEAT_STATIC_REAL, expected_ndim=1
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -122,7 +122,7 @@ def create_transformation(self) -> Transformation:
AsNumpyArray(
field=FieldName.FEAT_STATIC_CAT,
expected_ndim=1,
dtype=np.long,
dtype=int,
),
AsNumpyArray(
field=FieldName.FEAT_STATIC_REAL,
Expand Down
3 changes: 2 additions & 1 deletion src/gluonts/nursery/sagemaker_sdk/estimator.py
Original file line number Diff line number Diff line change
Expand Up @@ -31,6 +31,7 @@
from gluonts.dataset.repository import datasets
from gluonts.model.estimator import Estimator
from gluonts.model.predictor import Predictor
from gluonts.util import safe_extractall

from .defaults import (
ENTRY_POINTS_FOLDER,
Expand Down Expand Up @@ -502,7 +503,7 @@ def _retrieve_model(self, locations):
with self._s3fs.open(locations.model_archive, "rb") as stream:
with tarfile.open(mode="r:gz", fileobj=stream) as archive:
with TemporaryDirectory() as temp_dir:
archive.extractall(temp_dir)
safe_extractall(archive, temp_dir)
predictor = Predictor.deserialize(Path(temp_dir))

return predictor
Expand Down
3 changes: 2 additions & 1 deletion src/gluonts/nursery/tsbench/src/cli/evaluations/download.py
Original file line number Diff line number Diff line change
Expand Up @@ -19,6 +19,7 @@
from typing import Any, cast, Dict, List, Optional
import botocore
import click
# ``gluonts.util`` exposes ``safe_extractall`` (the name used further down in
# this module); importing the non-existent ``safe_extract`` would fail.
from gluonts.util import safe_extractall
from tqdm.auto import tqdm
from tqdm.contrib.concurrent import process_map
from tsbench.analysis.utils import run_parallel
Expand Down Expand Up @@ -97,7 +98,7 @@ def _download_public_evaluations(
file = Path(tmp) / "metrics.tar.gz"
client.download_file(public_bucket, "metrics.tar.gz", str(file))
with tarfile.open(file, mode="r:gz") as tar:
tar.extractall(evaluations_path)
safe_extractall(tar, evaluations_path)

# Then, optionally download the forecasts
if include_forecasts:
Expand Down
12 changes: 8 additions & 4 deletions src/gluonts/shell/sagemaker/dyn.py
Original file line number Diff line number Diff line change
Expand Up @@ -30,6 +30,8 @@
from pathlib import Path
from typing import Optional

from gluonts.util import safe_extractall


class Installer:
def __init__(self, packages):
Expand Down Expand Up @@ -63,10 +65,12 @@ def pip_install(self, path: Path):
def install(self, path):
if path.is_file():
if tarfile.is_tarfile(path):
self.handle_archive(tarfile.open, path)
self.handle_archive(tarfile.open, safe_extractall, path)

elif zipfile.is_zipfile(path):
self.handle_archive(zipfile.ZipFile, path)
self.handle_archive(
zipfile.ZipFile, zipfile.ZipFile.extractall, path
)

elif path.suffix == ".py":
self.copy_install(path)
Expand All @@ -80,14 +84,14 @@ def install(self, path):
for subpath in path.iterdir():
self.install(subpath)

def handle_archive(self, open_fn, path):
def handle_archive(self, open_fn, extractall_fn, path):
with open_fn(path) as archive:
tempdir = tempfile.mkdtemp()
self.cleanups.append(
partial(shutil.rmtree, tempdir, ignore_errors=True)
)

archive.extractall(tempdir)
extractall_fn(archive, tempdir)
self.install(Path(tempdir))


Expand Down
4 changes: 1 addition & 3 deletions src/gluonts/torch/model/deepar/estimator.py
Original file line number Diff line number Diff line change
Expand Up @@ -13,8 +13,6 @@

from typing import List, Optional, Iterable, Dict, Any

import numpy as np

import torch
from torch.utils.data import DataLoader

Expand Down Expand Up @@ -166,7 +164,7 @@ def create_transformation(self) -> Transformation:
AsNumpyArray(
field=FieldName.FEAT_STATIC_CAT,
expected_ndim=1,
dtype=np.long,
dtype=int,
),
AsNumpyArray(
field=FieldName.FEAT_STATIC_REAL,
Expand Down
48 changes: 48 additions & 0 deletions src/gluonts/util.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,48 @@
# Copyright 2018 Amazon.com, Inc. or its affiliates. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License").
# You may not use this file except in compliance with the License.
# A copy of the License is located at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# or in the "license" file accompanying this file. This file is distributed
# on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either
# express or implied. See the License for the specific language governing
# permissions and limitations under the License.

import tarfile
from pathlib import Path


def will_extractall_into(tar: tarfile.TarFile, path: Path) -> None:
    """
    Check that the content of ``tar`` will be extracted within ``path``
    upon calling ``extractall``.

    Besides the member names, the targets of symbolic and hard links are
    checked too: a link that points outside of ``path`` would allow later
    members to be written through it, even if every member *name* looks
    harmless.

    Parameters
    ----------
    tar
        Archive, opened for reading, whose members are inspected.
    path
        Directory the archive is going to be extracted into.

    Raises
    ------
    PermissionError
        If a member, or the target of a link member, resolves to a location
        outside of ``path``.
    """
    path = Path(path).resolve()

    for member in tar.getmembers():
        member_path = path / member.name

        # (description, location) pairs that must all stay within ``path``.
        locations = [(member.name, member_path)]

        if member.issym():
            # Symlink targets are interpreted relative to the directory
            # that contains the link (absolute targets replace it entirely).
            locations.append(
                (
                    f"{member.name} -> {member.linkname}",
                    member_path.parent / member.linkname,
                )
            )
        elif member.islnk():
            # Hard-link targets are interpreted relative to the extraction
            # root.
            locations.append(
                (
                    f"{member.name} -> {member.linkname}",
                    path / member.linkname,
                )
            )

        for description, location in locations:
            try:
                location.resolve().relative_to(path)
            except ValueError:
                raise PermissionError(
                    f"'{description}' extracts out of target."
                ) from None


def safe_extractall(
    tar: tarfile.TarFile,
    path: Path = Path("."),
    members=None,
    *,
    numeric_owner=False,
):
    """
    Extract ``tar`` into ``path``, refusing archives that would write
    outside of it.

    The archive is first validated with ``will_extractall_into`` (which
    raises ``PermissionError`` on offending members); only then is
    ``TarFile.extractall`` invoked with the given ``members`` and
    ``numeric_owner`` arguments.
    """
    # Validate first: nothing is written if the check raises.
    will_extractall_into(tar, path)

    tar.extractall(
        path=path,
        members=members,
        numeric_owner=numeric_owner,
    )
3 changes: 3 additions & 0 deletions test/dataset/test_multivariate_grouper.py
Original file line number Diff line number Diff line change
Expand Up @@ -119,6 +119,9 @@ def test_multivariate_grouper_train(
MAX_TARGET_DIM = [2, 1]


@pytest.mark.xfail(
reason="This test is known to fail with numpy>=1.24, and a fix is pending"
)
@pytest.mark.parametrize(
"univariate_ts, multivariate_ts, test_fill_rule, max_target_dim",
zip(
Expand Down
40 changes: 36 additions & 4 deletions test/test_forecaster_entrypoints.py
Original file line number Diff line number Diff line change
Expand Up @@ -11,9 +11,41 @@
# express or implied. See the License for the specific language governing
# permissions and limitations under the License.

import pkg_resources
import tempfile
import tarfile
from pathlib import Path
from typing import Optional

import pytest

# def test_forecaster_entrypoints():
# for entry_point in pkg_resources.iter_entry_points("gluonts_forecasters"):
# entry_point.load()
from gluonts.util import will_extractall_into


@pytest.mark.parametrize(
    "arcname, expect_failure",
    [
        (None, False),
        ("./file.txt", False),
        ("/a/../file.txt", False),
        ("/a/../../file.txt", True),
        ("../file.txt", True),
    ],
)
def test_will_extractall_into(arcname: Optional[str], expect_failure: bool):
    """
    Pack a single file under ``arcname`` and verify that
    ``will_extractall_into`` raises ``PermissionError`` exactly when the
    archive member would land outside of the extraction directory.
    """
    with tempfile.TemporaryDirectory() as tempdir:
        root = Path(tempdir)
        archive_path = root / "archive.tar.gz"
        destination = root / "b"

        # Payload file that gets stored in the archive.
        payload = root / "a" / "file.txt"
        payload.parent.mkdir(parents=True)
        payload.touch()

        with tarfile.open(archive_path, "w:gz") as tar:
            tar.add(payload, arcname=arcname)

        if not expect_failure:
            with tarfile.open(archive_path, "r:gz") as tar:
                will_extractall_into(tar, destination)
            return

        with pytest.raises(PermissionError):
            with tarfile.open(archive_path, "r:gz") as tar:
                will_extractall_into(tar, destination)

0 comments on commit 3099156

Please sign in to comment.