diff --git a/docs/getting_started/models.md b/docs/getting_started/models.md
index fb7db0adc9..476391e8e9 100644
--- a/docs/getting_started/models.md
+++ b/docs/getting_started/models.md
@@ -2,6 +2,9 @@
Model + Paper | Local/global | Data layout | Architecture/method | Implementation
-------------------------------------------------------------|--------------|--------------------------|---------------------|----------------
+PatchTST<br>[Nie et al. 2023][Nie2023] | Global | Univariate | MLP, multi-head attention | [PyTorch][PatchTST_torch]
+LagTST | Global | Univariate | MLP, multi-head attention | [PyTorch][LagTST_torch]
+DLinear<br>[Zeng et al. 2023][Zeng2023] | Global | Univariate | MLP | [PyTorch][DLinear_torch]
DeepAR<br>[Salinas et al. 2020][Salinas2020] | Global | Univariate | RNN | [MXNet][DeepAR_mx], [PyTorch][DeepAR_torch]
DeepState<br>[Rangapuram et al. 2018][Rangapuram2018] | Global | Univariate | RNN, state-space model | [MXNet][DeepState]
DeepFactor<br>[Wang et al. 2019][Wang2019] | Global | Univariate | RNN, state-space model, Gaussian process | [MXNet][DeepFactor]
@@ -30,6 +33,8 @@ NPTS | Local | Un
+[Nie2023]: https://arxiv.org/abs/2211.14730
+[Zeng2023]: https://arxiv.org/abs/2205.13504
[Rangapuram2021]: https://proceedings.mlr.press/v139/rangapuram21a.html
[Salinas2020]: https://doi.org/10.1016/j.ijforecast.2019.07.001
[Rangapuram2018]: https://papers.nips.cc/paper/2018/hash/5cf68969fb67aa6082363a6d4e6468e2-Abstract.html
@@ -52,6 +57,9 @@ NPTS | Local | Un
+[PatchTST_torch]: https://github.com/awslabs/gluonts/blob/dev/src/gluonts/torch/model/patch_tst/estimator.py
+[LagTST_torch]: https://github.com/awslabs/gluonts/blob/dev/src/gluonts/torch/model/lag_tst/estimator.py
+[DLinear_torch]: https://github.com/awslabs/gluonts/blob/dev/src/gluonts/torch/model/d_linear/estimator.py
[DeepAR_mx]: https://github.com/awslabs/gluonts/blob/dev/src/gluonts/mx/model/deepar/_estimator.py
[DeepAR_torch]: https://github.com/awslabs/gluonts/blob/dev/src/gluonts/torch/model/deepar/estimator.py
[DeepState]: https://github.com/awslabs/gluonts/blob/dev/src/gluonts/mx/model/deepstate/_estimator.py
diff --git a/docs/tutorials/advanced_topics/howto_pytorch_lightning.md.template b/docs/tutorials/advanced_topics/howto_pytorch_lightning.md.template
index 34e6937b19..9290260a13 100644
--- a/docs/tutorials/advanced_topics/howto_pytorch_lightning.md.template
+++ b/docs/tutorials/advanced_topics/howto_pytorch_lightning.md.template
@@ -110,10 +110,10 @@ class FeedForwardNetwork(nn.Module):
torch.nn.init.zeros_(lin.bias)
return lin
- def forward(self, context):
- scale = self.scaling(context)
- scaled_context = context / scale
- nn_out = self.nn(scaled_context)
+ def forward(self, past_target):
+ scale = self.scaling(past_target)
+ scaled_past_target = past_target / scale
+ nn_out = self.nn(scaled_past_target)
nn_out_reshaped = nn_out.reshape(-1, self.prediction_length, self.hidden_dimensions[-1])
distr_args = self.args_proj(nn_out_reshaped)
return distr_args, torch.zeros_like(scale), scale
@@ -143,15 +143,15 @@ class LightningFeedForwardNetwork(FeedForwardNetwork, pl.LightningModule):
super().__init__(*args, **kwargs)
def training_step(self, batch, batch_idx):
- context = batch["past_target"]
- target = batch["future_target"]
+ past_target = batch["past_target"]
+ future_target = batch["future_target"]
- assert context.shape[-1] == self.context_length
- assert target.shape[-1] == self.prediction_length
+ assert past_target.shape[-1] == self.context_length
+ assert future_target.shape[-1] == self.prediction_length
- distr_args, loc, scale = self(context)
+ distr_args, loc, scale = self(past_target)
distr = self.distr_output.distribution(distr_args, loc, scale)
- loss = -distr.log_prob(target)
+ loss = -distr.log_prob(future_target)
return loss.mean()
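
The rename above matters because GluonTS batches are dictionaries keyed by field names such as `past_target` and `future_target`, so matching the parameter names to the batch keys lets the network be called by keyword straight from a batch. A minimal sketch, with a toy module standing in for the tutorial's `FeedForwardNetwork` and illustrative shapes:

```python
import torch
import torch.nn as nn


# Toy stand-in for the tutorial's FeedForwardNetwork, reduced to the renamed
# signature; sizes here are illustrative only.
class ToyNetwork(nn.Module):
    def __init__(self, context_length: int = 24, prediction_length: int = 12):
        super().__init__()
        self.linear = nn.Linear(context_length, prediction_length)

    def forward(self, past_target):
        return self.linear(past_target)


batch = {
    "past_target": torch.rand(32, 24),
    "future_target": torch.rand(32, 12),
}

net = ToyNetwork()
# Because the parameter name matches the batch key, keyword dispatch works:
out = net(past_target=batch["past_target"])
assert out.shape == (32, 12)
```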
diff --git a/requirements/requirements-extras-anomaly-evaluation.txt b/requirements/requirements-extras-anomaly-evaluation.txt
index 5e5bf9e2bb..1f664c4994 100644
--- a/requirements/requirements-extras-anomaly-evaluation.txt
+++ b/requirements/requirements-extras-anomaly-evaluation.txt
@@ -1,2 +1,2 @@
numba~=0.51,<0.54
-scikit-learn~=0.22
\ No newline at end of file
+scikit-learn~=1.0
\ No newline at end of file
diff --git a/requirements/requirements-extras-sagemaker-sdk.txt b/requirements/requirements-extras-sagemaker-sdk.txt
index 87c31e0b5d..8c7c2cb795 100644
--- a/requirements/requirements-extras-sagemaker-sdk.txt
+++ b/requirements/requirements-extras-sagemaker-sdk.txt
@@ -1,4 +1,4 @@
-sagemaker~=2.0
+sagemaker~=2.0,>2.214.3
s3fs~=0.6; python_version >= "3.7.0"
s3fs~=0.5; python_version < "3.7.0"
fsspec~=0.8,<0.9; python_version < "3.7.0"
diff --git a/requirements/requirements-pytorch.txt b/requirements/requirements-pytorch.txt
index 683bb101d7..e3e9432c0d 100644
--- a/requirements/requirements-pytorch.txt
+++ b/requirements/requirements-pytorch.txt
@@ -1,6 +1,6 @@
torch>=1.9,<3
-lightning>=2.0,<2.2
+lightning>=2.2.2,<2.4
# Capping `lightning` does not cap `pytorch_lightning`, so we cap manually
-pytorch_lightning>=2.0,<2.2
+pytorch_lightning>=2.2.2,<2.4
scipy~=1.10; python_version > "3.7.0"
scipy~=1.7.3; python_version <= "3.7.0"
diff --git a/requirements/requirements-rotbaum.txt b/requirements/requirements-rotbaum.txt
index 4c597d89ad..7a47b62f8b 100644
--- a/requirements/requirements-rotbaum.txt
+++ b/requirements/requirements-rotbaum.txt
@@ -1,2 +1,2 @@
xgboost>=0.90,<2
-scikit-learn>=0.22,<2
+scikit-learn~=1.0
diff --git a/src/gluonts/core/serde/_base.py b/src/gluonts/core/serde/_base.py
index 9efee3aab8..97de04d65d 100644
--- a/src/gluonts/core/serde/_base.py
+++ b/src/gluonts/core/serde/_base.py
@@ -287,6 +287,15 @@ def encode_partial(v: partial) -> Any:
}
+decode_disallow = [
+ eval,
+ exec,
+ compile,
+ open,
+ input,
+]
+
+
def decode(r: Any) -> Any:
"""
Decodes a value from an intermediate representation `r`.
@@ -312,7 +321,10 @@ def decode(r: Any) -> Any:
kind = r["__kind__"]
cls = cast(Any, locate(r["class"]))
- assert cls is not None, f"Can not locate {r['class']}."
+ if cls is None:
+ raise ValueError(f"Cannot locate {r['class']}.")
+ if cls in decode_disallow:
+ raise ValueError(f"{r['class']} cannot be run.")
if kind == Kind.Type:
return cls
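
A minimal sketch of the behavior the new `decode_disallow` list enforces: a payload naming a dangerous builtin is rejected at decode time instead of being located and potentially executed. The payload below is hypothetical; the `__kind__`/`class` keys follow the intermediate representation used by `gluonts.core.serde`:

```python
from gluonts.core.serde import decode

# Hypothetical malicious payload naming a disallowed builtin as the class.
payload = {"__kind__": "instance", "class": "builtins.eval", "args": ["1+1"]}

try:
    decode(payload)
except ValueError as err:
    print(err)  # builtins.eval cannot be run.
```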
diff --git a/src/gluonts/ext/rotbaum/_predictor.py b/src/gluonts/ext/rotbaum/_predictor.py
index bab40f11fd..b97e715ae7 100644
--- a/src/gluonts/ext/rotbaum/_predictor.py
+++ b/src/gluonts/ext/rotbaum/_predictor.py
@@ -13,7 +13,6 @@
import concurrent.futures
import logging
-import pickle
from itertools import chain
from typing import Iterator, List, Optional, Any, Dict
from toolz import first
@@ -24,6 +23,7 @@
from itertools import compress
from gluonts.core.component import validated
+from gluonts.core.serde import dump_json, load_json
from gluonts.dataset.common import Dataset
from gluonts.dataset.util import forecast_start
from gluonts.model.forecast import Forecast
@@ -355,8 +355,8 @@ class name, version information and constructor arguments.
        generated when serializing the TreePredictor.
"""
super().serialize(path)
- with (path / "predictor.pkl").open("wb") as f:
- pickle.dump(self.model_list, f)
+ with (path / "model_list.json").open("w") as fp:
+ print(dump_json(self.model_list), file=fp)
@classmethod
def deserialize(cls, path: Path, **kwargs: Any) -> "TreePredictor":
@@ -369,8 +369,8 @@ def deserialize(cls, path: Path, **kwargs: Any) -> "TreePredictor":
predictor = super().deserialize(path)
assert isinstance(predictor, cls)
- with (path / "predictor.pkl").open("rb") as f:
- predictor.model_list = pickle.load(f)
+ with (path / "model_list.json").open("r") as fp:
+ predictor.model_list = load_json(fp.read())
return predictor
def explain(
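
Switching from `pickle` to the project's own serde makes the stored artifact a human-readable text file that round-trips through `dump_json`/`load_json`. A small sketch of that round trip, with a plain dict standing in for the trained model list:

```python
from gluonts.core.serde import dump_json, load_json

# A plain serializable object stands in for `self.model_list` here.
model_list = {"quantiles": [0.1, 0.5, 0.9], "max_depth": 6}

text = dump_json(model_list)
assert load_json(text) == model_list
```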
diff --git a/src/gluonts/model/forecast_generator.py b/src/gluonts/model/forecast_generator.py
index 42caf1fa9e..0148a8e1e6 100644
--- a/src/gluonts/model/forecast_generator.py
+++ b/src/gluonts/model/forecast_generator.py
@@ -82,6 +82,18 @@ def make_distribution_forecast(distr, *args, **kwargs) -> Forecast:
raise NotImplementedError
+def make_predictions(prediction_net, inputs: dict):
+ try:
+ # Feed inputs as positional arguments for MXNet block predictors
+ import mxnet as mx
+
+ if isinstance(prediction_net, mx.gluon.Block):
+ return prediction_net(*inputs.values())
+ except ImportError:
+ pass
+ return prediction_net(**inputs)
+
+
class ForecastGenerator:
"""
    Classes used to convert the output of a network into ``Forecast`` objects.
@@ -115,7 +127,7 @@ def __call__(
) -> Iterator[Forecast]:
for batch in inference_data_loader:
inputs = select(input_names, batch, ignore_missing=True)
- (outputs,), loc, scale = prediction_net(*inputs.values())
+ (outputs,), loc, scale = make_predictions(prediction_net, inputs)
outputs = to_numpy(outputs)
if scale is not None:
outputs = outputs * to_numpy(scale[..., None])
@@ -159,14 +171,16 @@ def __call__(
) -> Iterator[Forecast]:
for batch in inference_data_loader:
inputs = select(input_names, batch, ignore_missing=True)
- outputs = to_numpy(prediction_net(*inputs.values()))
+ outputs = to_numpy(make_predictions(prediction_net, inputs))
if output_transform is not None:
outputs = output_transform(batch, outputs)
if num_samples:
num_collected_samples = outputs[0].shape[0]
collected_samples = [outputs]
while num_collected_samples < num_samples:
- outputs = to_numpy(prediction_net(*inputs.values()))
+ outputs = to_numpy(
+ make_predictions(prediction_net, inputs)
+ )
if output_transform is not None:
outputs = output_transform(batch, outputs)
collected_samples.append(outputs)
@@ -209,7 +223,7 @@ def __call__(
) -> Iterator[Forecast]:
for batch in inference_data_loader:
inputs = select(input_names, batch, ignore_missing=True)
- outputs = prediction_net(*inputs.values())
+ outputs = make_predictions(prediction_net, inputs)
if output_transform:
log_once(OUTPUT_TRANSFORM_NOT_SUPPORTED_MSG)
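
The new `make_predictions` helper keeps the PyTorch path keyword-based while preserving positional calls for MXNet Gluon blocks. A sketch of the keyword path, with a plain callable standing in for a prediction network:

```python
import torch

from gluonts.model.forecast_generator import make_predictions


# A plain callable stands in for a PyTorch prediction network.
def toy_net(past_target, past_observed_values):
    return past_target * past_observed_values


inputs = {
    "past_target": torch.rand(4, 24),
    "past_observed_values": torch.ones(4, 24),
}

# Not an MXNet block, so the helper falls through to keyword arguments.
out = make_predictions(toy_net, inputs)
assert out.shape == (4, 24)
```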
diff --git a/src/gluonts/nursery/few_shot_prediction/pyproject.toml b/src/gluonts/nursery/few_shot_prediction/pyproject.toml
index 0c0100af9e..2b61158b4d 100644
--- a/src/gluonts/nursery/few_shot_prediction/pyproject.toml
+++ b/src/gluonts/nursery/few_shot_prediction/pyproject.toml
@@ -16,7 +16,7 @@ gluonts = {git = "https://github.com/awslabs/gluon-ts.git"}
pandas = "^1.3.1"
python = "^3.8,<3.10"
pytorch-lightning = "^1.4.4"
-sagemaker = "^2.40.0,<2.41.0"
+sagemaker = "^2.218.0"
scikit-learn = "^1.4.0"
torch = "^1.9.0"
sagemaker-training = "^3.9.2"
@@ -27,7 +27,7 @@ catch22 = "^0.2.0"
seaborn = "^0.11.2"
[tool.poetry.dev-dependencies]
-black = "^21.7b0"
+black = "^24.3.0"
isort = "^5.9.3"
jupyter = "^1.0.0"
pylint = "^2.10.2"
diff --git a/src/gluonts/nursery/robust-mts-attack/pts/modules/distribution_output.py b/src/gluonts/nursery/robust-mts-attack/pts/modules/distribution_output.py
index 9b6a4cb157..dd528262af 100644
--- a/src/gluonts/nursery/robust-mts-attack/pts/modules/distribution_output.py
+++ b/src/gluonts/nursery/robust-mts-attack/pts/modules/distribution_output.py
@@ -44,11 +44,10 @@
TransformedImplicitQuantile,
)
from gluonts.core.component import validated
-from gluonts.torch.modules.distribution_output import (
- DistributionOutput,
- LambdaLayer,
- PtArgProj,
-)
+from gluonts.torch.distributions.distribution_output import DistributionOutput
+from gluonts.torch.modules.lambda_layer import LambdaLayer
+from gluonts.torch.distributions.output import PtArgProj
+
from pts.modules.iqn_modules import ImplicitQuantileModule
diff --git a/src/gluonts/nursery/tsbench/pyproject.toml b/src/gluonts/nursery/tsbench/pyproject.toml
index 6232fd0e03..767cfe0f2f 100644
--- a/src/gluonts/nursery/tsbench/pyproject.toml
+++ b/src/gluonts/nursery/tsbench/pyproject.toml
@@ -25,7 +25,7 @@ plotly = "^5.3.1"
pyarrow = "^14.0.1"
pydantic = "^1.8.2"
pygmo = "^2.16.1"
-pymongo = "^3.12.0"
+pymongo = "^4.6.3"
pystan = "^2.0.0"
python = ">=3.8,<3.9"
pytorch-lightning = "^1.5.0"
@@ -43,7 +43,7 @@ ujson = "^5.1.0"
xgboost = "^1.4.1"
[tool.poetry.dev-dependencies]
-black = "^21.5b1"
+black = "^24.3.0"
isort = "^5.8.0"
jupyter = "^1.0.0"
mypy = "^0.812"
diff --git a/src/gluonts/time_feature/_base.py b/src/gluonts/time_feature/_base.py
index 0d88971002..3aa53a55ef 100644
--- a/src/gluonts/time_feature/_base.py
+++ b/src/gluonts/time_feature/_base.py
@@ -11,6 +11,7 @@
# express or implied. See the License for the specific language governing
# permissions and limitations under the License.
+from packaging.version import Version
from typing import Any, Callable, Dict, List
import numpy as np
@@ -196,7 +197,10 @@ def norm_freq_str(freq_str: str) -> str:
    # Note: the secondly frequency ("S") also exists, and there we do not
    # want to remove the "S"!
if len(base_freq) >= 2 and base_freq.endswith("S"):
- return base_freq[:-1]
+ base_freq = base_freq[:-1]
+ # In pandas >= 2.2, period end frequencies have been renamed, e.g. "M" -> "ME"
+ if Version(pd.__version__) >= Version("2.2.0"):
+ base_freq += "E"
return base_freq
@@ -252,17 +256,13 @@ def time_features_from_frequency_str(freq_str: str) -> List[TimeFeature]:
Unsupported frequency {freq_str}
The following frequencies are supported:
-
- Y - yearly
- alias: A
- Q - quarterly
- M - monthly
- W - weekly
- D - daily
- B - business days
- H - hourly
- T - minutely
- alias: min
- S - secondly
+
"""
+
+ for offset_cls in features_by_offsets:
+ offset = offset_cls()
+ supported_freq_msg += (
+            f"\t{offset.freqstr.split('-')[0]} - {offset_cls.__name__}\n"
+ )
+
raise RuntimeError(supported_freq_msg)
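
With this change the normalized name depends on the installed pandas version; a sketch of both branches:

```python
import pandas as pd
from packaging.version import Version
from pandas.tseries.frequencies import to_offset

from gluonts.time_feature import norm_freq_str

# Start-anchored aliases lose their trailing "S"; on pandas >= 2.2 the
# period-end suffix "E" is appended to match the renamed offsets
# (e.g. "M" -> "ME", "Q" -> "QE").
name = to_offset("QS").name  # "QS"
if Version(pd.__version__) >= Version("2.2.0"):
    assert norm_freq_str(name) == "QE"
else:
    assert norm_freq_str(name) == "Q"
```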
diff --git a/src/gluonts/time_feature/seasonality.py b/src/gluonts/time_feature/seasonality.py
index 62026fc691..9cb2581a24 100644
--- a/src/gluonts/time_feature/seasonality.py
+++ b/src/gluonts/time_feature/seasonality.py
@@ -33,6 +33,7 @@
"ME": 12,
"B": 5,
"Q": 4,
+ "QE": 4,
}
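
A quick check of the new entry; note that the "QE" alias itself requires pandas >= 2.2:

```python
from gluonts.time_feature import get_seasonality

# Quarterly period-end ("QE", pandas >= 2.2) and the legacy "Q" alias now
# resolve to the same seasonality.
assert get_seasonality("QE") == 4
assert get_seasonality("Q") == 4
```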
diff --git a/src/gluonts/torch/distributions/distribution_output.py b/src/gluonts/torch/distributions/distribution_output.py
index af786ca4ef..e583d941b5 100644
--- a/src/gluonts/torch/distributions/distribution_output.py
+++ b/src/gluonts/torch/distributions/distribution_output.py
@@ -90,14 +90,6 @@ def loss(
nll = nll * (variance.detach() ** self.beta)
return nll
- @property
- def event_shape(self) -> Tuple:
- r"""
- Shape of each individual event contemplated by the distributions that
- this object constructs.
- """
- raise NotImplementedError()
-
@property
def event_dim(self) -> int:
r"""
diff --git a/src/gluonts/torch/distributions/output.py b/src/gluonts/torch/distributions/output.py
index 49385f0e55..83d22246bb 100644
--- a/src/gluonts/torch/distributions/output.py
+++ b/src/gluonts/torch/distributions/output.py
@@ -105,6 +105,13 @@ def loss(
"""
raise NotImplementedError()
+ @property
+ def event_shape(self) -> Tuple:
+ r"""
+ Shape of each individual event compatible with the output object.
+ """
+ raise NotImplementedError()
+
@property
def forecast_generator(self) -> ForecastGenerator:
raise NotImplementedError()
diff --git a/src/gluonts/torch/distributions/quantile_output.py b/src/gluonts/torch/distributions/quantile_output.py
index 4bce0fad53..ce104c5703 100644
--- a/src/gluonts/torch/distributions/quantile_output.py
+++ b/src/gluonts/torch/distributions/quantile_output.py
@@ -37,6 +37,10 @@ def __init__(self, quantiles: List[float]) -> None:
def forecast_generator(self) -> ForecastGenerator:
return QuantileForecastGenerator(quantiles=self.quantiles)
+ @property
+ def event_shape(self) -> Tuple:
+ return ()
+
def domain_map(self, *args: torch.Tensor) -> Tuple[torch.Tensor, ...]:
return args
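
With `event_shape` moved onto the `Output` base class, quantile outputs can declare themselves univariate explicitly; a quick check:

```python
from gluonts.torch.distributions import QuantileOutput

# The empty tuple marks each target entry as a scalar event, matching the
# convention used by DistributionOutput subclasses.
quantile_output = QuantileOutput(quantiles=[0.1, 0.5, 0.9])
assert quantile_output.event_shape == ()
```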
diff --git a/src/gluonts/torch/model/patch_tst/estimator.py b/src/gluonts/torch/model/patch_tst/estimator.py
index 94ae5b5444..4eea0be9d7 100644
--- a/src/gluonts/torch/model/patch_tst/estimator.py
+++ b/src/gluonts/torch/model/patch_tst/estimator.py
@@ -30,6 +30,7 @@
TestSplitSampler,
ExpectedNumInstanceSampler,
SelectFields,
+ RenameFields,
)
from gluonts.torch.model.estimator import PyTorchLightningEstimator
from gluonts.torch.model.predictor import PyTorchPredictor
@@ -74,6 +75,8 @@ class PatchTSTEstimator(PyTorchLightningEstimator):
Number of attention heads in the Transformer encoder which must divide d_model.
dim_feedforward
Size of hidden layers in the Transformer encoder.
+ num_feat_dynamic_real
+ Number of dynamic real features in the data (default: 0).
dropout
Dropout probability in the Transformer encoder.
activation
@@ -115,6 +118,7 @@ def __init__(
d_model: int = 32,
nhead: int = 4,
dim_feedforward: int = 128,
+ num_feat_dynamic_real: int = 0,
dropout: float = 0.1,
activation: str = "relu",
norm_first: bool = False,
@@ -151,6 +155,7 @@ def __init__(
self.d_model = d_model
self.nhead = nhead
self.dim_feedforward = dim_feedforward
+ self.num_feat_dynamic_real = num_feat_dynamic_real
self.dropout = dropout
self.activation = activation
self.norm_first = norm_first
@@ -166,17 +171,26 @@ def __init__(
)
def create_transformation(self) -> Transformation:
- return SelectFields(
- [
- FieldName.ITEM_ID,
- FieldName.INFO,
- FieldName.START,
- FieldName.TARGET,
- ],
- allow_missing=True,
- ) + AddObservedValuesIndicator(
- target_field=FieldName.TARGET,
- output_field=FieldName.OBSERVED_VALUES,
+ return (
+ SelectFields(
+ [
+ FieldName.ITEM_ID,
+ FieldName.INFO,
+ FieldName.START,
+ FieldName.TARGET,
+ ]
+ + (
+ [FieldName.FEAT_DYNAMIC_REAL]
+ if self.num_feat_dynamic_real > 0
+ else []
+ ),
+ allow_missing=True,
+ )
+ + RenameFields({FieldName.FEAT_DYNAMIC_REAL: FieldName.FEAT_TIME})
+ + AddObservedValuesIndicator(
+ target_field=FieldName.TARGET,
+ output_field=FieldName.OBSERVED_VALUES,
+ )
)
def create_lightning_module(self) -> pl.LightningModule:
@@ -192,6 +206,7 @@ def create_lightning_module(self) -> pl.LightningModule:
"d_model": self.d_model,
"nhead": self.nhead,
"dim_feedforward": self.dim_feedforward,
+ "num_feat_dynamic_real": self.num_feat_dynamic_real,
"dropout": self.dropout,
"activation": self.activation,
"norm_first": self.norm_first,
@@ -220,7 +235,10 @@ def _create_instance_splitter(
instance_sampler=instance_sampler,
past_length=self.context_length,
future_length=self.prediction_length,
- time_series_fields=[FieldName.OBSERVED_VALUES],
+ time_series_fields=[FieldName.OBSERVED_VALUES]
+ + (
+ [FieldName.FEAT_TIME] if self.num_feat_dynamic_real > 0 else []
+ ),
dummy_value=self.distr_output.value_in_support,
)
@@ -239,7 +257,15 @@ def create_training_data_loader(
instances,
batch_size=self.batch_size,
shuffle_buffer_length=shuffle_buffer_length,
- field_names=TRAINING_INPUT_NAMES,
+ field_names=TRAINING_INPUT_NAMES
+ + (
+ [
+ f"past_{FieldName.FEAT_TIME}",
+ f"future_{FieldName.FEAT_TIME}",
+ ]
+ if self.num_feat_dynamic_real > 0
+ else []
+ ),
output_type=torch.tensor,
num_batches_per_epoch=self.num_batches_per_epoch,
)
@@ -253,7 +279,15 @@ def create_validation_data_loader(
return as_stacked_batches(
instances,
batch_size=self.batch_size,
- field_names=TRAINING_INPUT_NAMES,
+ field_names=TRAINING_INPUT_NAMES
+ + (
+ [
+ f"past_{FieldName.FEAT_TIME}",
+ f"future_{FieldName.FEAT_TIME}",
+ ]
+ if self.num_feat_dynamic_real > 0
+ else []
+ ),
output_type=torch.tensor,
)
@@ -264,7 +298,15 @@ def create_predictor(
return PyTorchPredictor(
input_transform=transformation + prediction_splitter,
- input_names=PREDICTION_INPUT_NAMES,
+ input_names=PREDICTION_INPUT_NAMES
+ + (
+ [
+ f"past_{FieldName.FEAT_TIME}",
+ f"future_{FieldName.FEAT_TIME}",
+ ]
+ if self.num_feat_dynamic_real > 0
+ else []
+ ),
prediction_net=module,
forecast_generator=self.distr_output.forecast_generator,
batch_size=self.batch_size,
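
A sketch of how the new option is meant to be used, with illustrative (not tuned) hyperparameters; with `num_feat_dynamic_real=0` the transformation and data loaders behave exactly as before:

```python
from gluonts.torch.model.patch_tst import PatchTSTEstimator

estimator = PatchTSTEstimator(
    prediction_length=24,
    context_length=48,
    patch_len=16,
    num_feat_dynamic_real=3,  # must match the feat_dynamic_real dim in the data
    batch_size=32,
    num_batches_per_epoch=50,
    trainer_kwargs={"max_epochs": 1},
)
# predictor = estimator.train(training_data)  # dataset must carry 3 dynamic real features
```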
diff --git a/src/gluonts/torch/model/patch_tst/module.py b/src/gluonts/torch/model/patch_tst/module.py
index 3a59f80299..5f5c3e9fbc 100644
--- a/src/gluonts/torch/model/patch_tst/module.py
+++ b/src/gluonts/torch/model/patch_tst/module.py
@@ -11,7 +11,7 @@
# express or implied. See the License for the specific language governing
# permissions and limitations under the License.
-from typing import Tuple
+from typing import Optional, Tuple
import numpy as np
import torch
@@ -21,7 +21,7 @@
from gluonts.model import Input, InputSpec
from gluonts.torch.distributions import StudentTOutput
from gluonts.torch.scaler import StdScaler, MeanScaler, NOPScaler
-from gluonts.torch.util import unsqueeze_expand, weighted_average
+from gluonts.torch.util import take_last, unsqueeze_expand, weighted_average
from gluonts.torch.model.simple_feedforward import make_linear_layer
@@ -85,6 +85,8 @@ class PatchTSTModel(nn.Module):
Number of time points to predict.
context_length
        Number of time steps prior to prediction time that the model takes as inputs.
+ num_feat_dynamic_real
+ Number of dynamic real features in the data (default: 0).
distr_output
Distribution to use to evaluate observations and sample predictions.
Default: ``StudentTOutput()``.
@@ -101,6 +103,7 @@ def __init__(
d_model: int,
nhead: int,
dim_feedforward: int,
+ num_feat_dynamic_real: int,
dropout: float,
activation: str,
norm_first: bool,
@@ -120,6 +123,7 @@ def __init__(
self.d_model = d_model
self.padding_patch = padding_patch
self.distr_output = distr_output
+ self.num_feat_dynamic_real = num_feat_dynamic_real
if scaling == "mean":
self.scaler = MeanScaler(keepdim=True)
@@ -133,8 +137,11 @@ def __init__(
self.padding_patch_layer = nn.ReplicationPad1d((0, self.stride))
self.patch_num += 1
- # project from patch_len + 2 features (loc and scale) to d_model
- self.patch_proj = make_linear_layer(patch_len + 2, d_model)
+ # project from `patch_len` + 2 features (`loc` and `scale`) +
+ # `num_feat_dynamic_real` x `patch_len` to d_model
+ self.patch_proj = make_linear_layer(
+ patch_len + 2 + self.num_feat_dynamic_real * patch_len, d_model
+ )
self.positional_encoding = SinusoidalPositionalEmbedding(
self.patch_num, d_model
@@ -163,6 +170,28 @@ def __init__(
self.args_proj = self.distr_output.get_args_proj(d_model)
def describe_inputs(self, batch_size=1) -> InputSpec:
+ if self.num_feat_dynamic_real > 0:
+ input_spec_feat = {
+ "past_time_feat": Input(
+ shape=(
+ batch_size,
+ self.context_length,
+ self.num_feat_dynamic_real,
+ ),
+ dtype=torch.float,
+ ),
+ "future_time_feat": Input(
+ shape=(
+ batch_size,
+ self.prediction_length,
+ self.num_feat_dynamic_real,
+ ),
+ dtype=torch.float,
+ ),
+ }
+ else:
+ input_spec_feat = {}
+
return InputSpec(
{
"past_target": Input(
@@ -171,6 +200,7 @@ def describe_inputs(self, batch_size=1) -> InputSpec:
"past_observed_values": Input(
shape=(batch_size, self.context_length), dtype=torch.float
),
+ **input_spec_feat,
},
torch.zeros,
)
@@ -179,6 +209,8 @@ def forward(
self,
past_target: torch.Tensor,
past_observed_values: torch.Tensor,
+ past_time_feat: Optional[torch.Tensor] = None,
+ future_time_feat: Optional[torch.Tensor] = None,
) -> Tuple[Tuple[torch.Tensor, ...], torch.Tensor, torch.Tensor]:
# scale the input
past_target_scaled, loc, scale = self.scaler(
@@ -192,6 +224,25 @@ def forward(
dimension=1, size=self.patch_len, step=self.stride
)
+ # do patching for time features as well
+ if self.num_feat_dynamic_real > 0:
+ # shift time features by `prediction_length` so that they are
+ # aligned with the target input.
+ time_feat = take_last(
+ torch.cat((past_time_feat, future_time_feat), dim=1),
+ dim=1,
+ num=self.context_length,
+ )
+
+ # (bs x T x d) --> (bs x d x T) because the 1D padding is done on
+ # the last dimension.
+ time_feat = self.padding_patch_layer(
+ time_feat.transpose(-2, -1)
+ ).transpose(-2, -1)
+ time_feat_patches = time_feat.unfold(
+ dimension=1, size=self.patch_len, step=self.stride
+ ).flatten(-2, -1)
+
# add loc and scale to past_target_patches as additional features
log_abs_loc = loc.abs().log1p()
log_scale = scale.log()
@@ -202,6 +253,9 @@ def forward(
)
inputs = torch.cat((past_target_patches, expanded_static_feat), dim=-1)
+ if self.num_feat_dynamic_real > 0:
+ inputs = torch.cat((inputs, time_feat_patches), dim=-1)
+
# project patches
enc_in = self.patch_proj(inputs)
embed_pos = self.positional_encoding(enc_in.size())
@@ -224,9 +278,14 @@ def loss(
past_observed_values: torch.Tensor,
future_target: torch.Tensor,
future_observed_values: torch.Tensor,
+ past_time_feat: Optional[torch.Tensor] = None,
+ future_time_feat: Optional[torch.Tensor] = None,
) -> torch.Tensor:
distr_args, loc, scale = self(
- past_target=past_target, past_observed_values=past_observed_values
+ past_target=past_target,
+ past_observed_values=past_observed_values,
+ past_time_feat=past_time_feat,
+ future_time_feat=future_time_feat,
)
loss = self.distr_output.loss(
target=future_target, distr_args=distr_args, loc=loc, scale=scale
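
The projection width grows by `num_feat_dynamic_real * patch_len` because each patch of time features is flattened feature-wise. A standalone sketch of that patching arithmetic:

```python
import torch

bs, context_length, num_feat = 2, 32, 3
patch_len, stride = 16, 8

time_feat = torch.rand(bs, context_length, num_feat)

# unfold yields (bs, num_patches, num_feat, patch_len); flattening the last
# two dims concatenates patch_len values per feature within each patch.
patches = time_feat.unfold(dimension=1, size=patch_len, step=stride)
patches = patches.flatten(-2, -1)

num_patches = (context_length - patch_len) // stride + 1
assert patches.shape == (bs, num_patches, num_feat * patch_len)
```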
diff --git a/src/gluonts/torch/model/tide/estimator.py b/src/gluonts/torch/model/tide/estimator.py
index ab6c004c02..35f8cf5362 100644
--- a/src/gluonts/torch/model/tide/estimator.py
+++ b/src/gluonts/torch/model/tide/estimator.py
@@ -21,7 +21,6 @@
from gluonts.dataset.field_names import FieldName
from gluonts.dataset.loader import as_stacked_batches
from gluonts.itertools import Cyclic
-from gluonts.model.forecast_generator import DistributionForecastGenerator
from gluonts.time_feature import (
minute_of_hour,
hour_of_day,
@@ -49,10 +48,7 @@
from gluonts.torch.model.estimator import PyTorchLightningEstimator
from gluonts.torch.model.predictor import PyTorchPredictor
-from gluonts.torch.distributions import (
- DistributionOutput,
- StudentTOutput,
-)
+from gluonts.torch.distributions import Output, StudentTOutput
from .lightning_module import TiDELightningModule
@@ -174,7 +170,7 @@ def __init__(
weight_decay: float = 1e-8,
patience: int = 10,
scaling: Optional[str] = "mean",
- distr_output: DistributionOutput = StudentTOutput(),
+ distr_output: Output = StudentTOutput(),
batch_size: int = 32,
num_batches_per_epoch: int = 50,
trainer_kwargs: Optional[Dict[str, Any]] = None,
@@ -403,9 +399,7 @@ def create_predictor(
input_transform=transformation + prediction_splitter,
input_names=PREDICTION_INPUT_NAMES,
prediction_net=module,
- forecast_generator=DistributionForecastGenerator(
- self.distr_output
- ),
+ forecast_generator=self.distr_output.forecast_generator,
batch_size=self.batch_size,
prediction_length=self.prediction_length,
device="auto",
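
Widening the annotation from `DistributionOutput` to the `Output` base class lets TiDE train directly against quantile loss, since the forecast generator now comes from `distr_output` itself. A sketch with illustrative hyperparameters:

```python
from gluonts.torch.distributions import QuantileOutput
from gluonts.torch.model.tide import TiDEEstimator

estimator = TiDEEstimator(
    freq="H",
    prediction_length=24,
    distr_output=QuantileOutput(quantiles=[0.1, 0.5, 0.9]),
    batch_size=32,
    num_batches_per_epoch=50,
    trainer_kwargs={"max_epochs": 1},
)
```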
diff --git a/src/gluonts/torch/model/tide/module.py b/src/gluonts/torch/model/tide/module.py
index 875a05f8b2..e0eb06cb85 100644
--- a/src/gluonts/torch/model/tide/module.py
+++ b/src/gluonts/torch/model/tide/module.py
@@ -19,7 +19,7 @@
from gluonts.core.component import validated
from gluonts.torch.modules.feature import FeatureEmbedder
from gluonts.model import Input, InputSpec
-from gluonts.torch.distributions import DistributionOutput
+from gluonts.torch.distributions import Output
from gluonts.torch.scaler import StdScaler, MeanScaler, NOPScaler
from gluonts.torch.model.simple_feedforward import make_linear_layer
from gluonts.torch.util import weighted_average
@@ -242,7 +242,7 @@ def __init__(
num_layers_encoder: int,
num_layers_decoder: int,
layer_norm: bool,
- distr_output: DistributionOutput,
+ distr_output: Output,
scaling: str,
) -> None:
super().__init__()
diff --git a/test/core/test_serde.py b/test/core/test_serde.py
index 34ecdcb9ff..b651874d1c 100644
--- a/test/core/test_serde.py
+++ b/test/core/test_serde.py
@@ -142,3 +142,26 @@ def test_serde_method():
def test_np_str_dtype():
a = np.array(["foo"])
    assert serde.decode(serde.encode(a.dtype)) == a.dtype
+
+
+@pytest.mark.parametrize(
+ "obj",
+ [
+ {"__kind__": 42, "class": cls_str}
+ for cls_str in [
+ "builtins.eval",
+ "builtins.exec",
+ "builtins.compile",
+ "builtins.open",
+ "builtins.input",
+ "eval",
+ "exec",
+ "compile",
+ "open",
+ "input",
+ ]
+ ],
+)
+def test_decode_disallow(obj):
+ with pytest.raises(ValueError):
+ serde.decode(obj)
diff --git a/test/ext/r_forecast/test_r_univariate_predictor.py b/test/ext/r_forecast/test_r_univariate_predictor.py
index eb5b59115c..4eefa1bbb0 100644
--- a/test/ext/r_forecast/test_r_univariate_predictor.py
+++ b/test/ext/r_forecast/test_r_univariate_predictor.py
@@ -48,6 +48,10 @@ def test_forecasts(method_name):
            "MLP currently does not work because "
            "the `neuralnet` package is not yet updated with a known bug fix in `bips-hb/neuralnet`"
)
+ if method_name == "fourier.arima.xreg":
+ pytest.xfail(
+ "Method `fourier.arima.xreg` does not work because of a known issue."
+ )
dataset = datasets.get_dataset("constant")
diff --git a/test/time_feature/__init__.py b/test/time_feature/__init__.py
new file mode 100644
index 0000000000..f342912f9b
--- /dev/null
+++ b/test/time_feature/__init__.py
@@ -0,0 +1,12 @@
+# Copyright 2018 Amazon.com, Inc. or its affiliates. All Rights Reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License").
+# You may not use this file except in compliance with the License.
+# A copy of the License is located at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# or in the "license" file accompanying this file. This file is distributed
+# on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either
+# express or implied. See the License for the specific language governing
+# permissions and limitations under the License.
diff --git a/test/time_feature/common.py b/test/time_feature/common.py
new file mode 100644
index 0000000000..89e19a23f8
--- /dev/null
+++ b/test/time_feature/common.py
@@ -0,0 +1,28 @@
+# Copyright 2018 Amazon.com, Inc. or its affiliates. All Rights Reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License").
+# You may not use this file except in compliance with the License.
+# A copy of the License is located at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# or in the "license" file accompanying this file. This file is distributed
+# on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either
+# express or implied. See the License for the specific language governing
+# permissions and limitations under the License.
+
+import pandas as pd
+from packaging.version import Version
+
+if Version(pd.__version__) < Version("2.2.0"):
+ S = "S"
+ H = "H"
+ M = "M"
+ Q = "Q"
+ Y = "A"
+else:
+ S = "s"
+ H = "h"
+ M = "ME"
+ Q = "QE"
+ Y = "YE"
diff --git a/test/time_feature/test_agg_lags.py b/test/time_feature/test_agg_lags.py
index dd3b2f2d9b..6e299e498d 100644
--- a/test/time_feature/test_agg_lags.py
+++ b/test/time_feature/test_agg_lags.py
@@ -16,7 +16,6 @@
import pytest
from gluonts.dataset.common import ListDataset
-
from gluonts.dataset.field_names import FieldName
from gluonts.transform import AddAggregateLags
diff --git a/test/time_feature/test_base.py b/test/time_feature/test_base.py
index 8e249eba86..e448156b2b 100644
--- a/test/time_feature/test_base.py
+++ b/test/time_feature/test_base.py
@@ -11,21 +11,23 @@
# express or implied. See the License for the specific language governing
# permissions and limitations under the License.
+import pytest
from pandas.tseries.frequencies import to_offset
from gluonts.time_feature import norm_freq_str
+from .common import M, Q, S, Y
-def test_norm_freq_str():
- assert norm_freq_str(to_offset("Y").name) in ["A", "YE"]
- assert norm_freq_str(to_offset("YS").name) in ["A", "Y"]
- assert norm_freq_str(to_offset("A").name) in ["A", "YE"]
- assert norm_freq_str(to_offset("AS").name) in ["A", "Y"]
- assert norm_freq_str(to_offset("Q").name) in ["Q", "QE"]
- assert norm_freq_str(to_offset("QS").name) == "Q"
-
- assert norm_freq_str(to_offset("M").name) in ["M", "ME"]
- assert norm_freq_str(to_offset("MS").name) in ["M", "ME"]
-
- assert norm_freq_str(to_offset("S").name) in ["S", "s"]
+@pytest.mark.parametrize(
+    "aliases, normalized_freq_str",
+ [
+ (["Y", "YS", "A", "AS"], Y),
+ (["Q", "QS"], Q),
+ (["M", "MS"], M),
+ (["S"], S),
+ ],
+)
+def test_norm_freq_str(aliases, normalized_freq_str):
+ for alias in aliases:
+ assert norm_freq_str(to_offset(alias).name) == normalized_freq_str
diff --git a/test/time_feature/test_features.py b/test/time_feature/test_features.py
index 96c590fbf2..1c59db9909 100644
--- a/test/time_feature/test_features.py
+++ b/test/time_feature/test_features.py
@@ -16,7 +16,6 @@
import pytest
from gluonts import zebras as zb
-
from gluonts.time_feature import (
Constant,
TimeFeature,
diff --git a/test/time_feature/test_lag.py b/test/time_feature/test_lag.py
index 951a5f9cb4..2ce9651e0c 100644
--- a/test/time_feature/test_lag.py
+++ b/test/time_feature/test_lag.py
@@ -15,12 +15,16 @@
Test the lags computed for different frequencies.
"""
+import pytest
+
import gluonts.time_feature.lag as date_feature_set
+from .common import H, M, Q, Y
+
# These are the expected lags for common frequencies and corner cases.
# By default all frequencies have the following lags: [1, 2, 3, 4, 5, 6, 7].
# Remaining lags correspond to the same `season` (+/- `delta`) in previous `k` cycles.
-expected_lags = {
+EXPECTED_LAGS = {
# (apart from the default lags) centered around each of the last 3 hours (delta = 2)
"4S": [
1,
@@ -179,7 +183,7 @@
]
+ [329, 330, 331, 494, 495, 496, 659, 660, 661, 707, 708, 709],
# centered around each of the last 3 hours (delta = 2) + last 7 days (delta = 1) + last 6 weeks (delta = 1)
- "H": [1, 2, 3, 4, 5, 6, 7]
+ H: [1, 2, 3, 4, 5, 6, 7]
+ [
23,
24,
@@ -206,7 +210,7 @@
+ [335, 336, 337, 503, 504, 505, 671, 672, 673, 719, 720, 721],
# centered around each of the last 7 days (delta = 1) + last 4 weeks (delta = 1) + last 1 month (delta = 1) +
# last 8th and 12th weeks (delta = 0)
- "6H": [
+ ("6" + H): [
1,
2,
3,
@@ -237,21 +241,21 @@
+ [224, 336],
# centered around each of the last 7 days (delta = 1) + last 4 weeks (delta = 1) + last 1 month (delta = 1) +
# last 8th and 12th weeks (delta = 0) + last year (delta = 1)
- "12H": [1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15]
+ ("12" + H): [1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15]
+ [27, 28, 29, 41, 42, 43, 55, 56, 57]
+ [59, 60, 61]
+ [112, 168]
+ [727, 728, 729],
# centered around each of the last 7 days (delta = 1) + last 4 weeks (delta = 1) + last 1 month (delta = 1) +
# last 8th and 12th weeks (delta = 0) + last 3 years (delta = 1)
- "23H": [1, 2, 3, 4, 5, 6, 7, 8]
+ ("23" + H): [1, 2, 3, 4, 5, 6, 7, 8]
+ [13, 14, 15, 20, 21, 22, 28, 29]
+ [30, 31, 32]
+ [58, 87]
+ [378, 379, 380, 758, 759, 760, 1138, 1139, 1140],
# centered around each of the last 7 days (delta = 1) + last 4 weeks (delta = 1) + last 1 month (delta = 1) +
# last 8th and 12th weeks (delta = 0) + last 3 years (delta = 1)
- "25H": [1, 2, 3, 4, 5, 6, 7]
+ ("25" + H): [1, 2, 3, 4, 5, 6, 7]
+ [12, 13, 14, 19, 20, 21, 25, 26, 27]
+ [28, 29]
+ [53, 80]
@@ -285,64 +289,31 @@
# centered around each of the last 3 years (delta = 1)
"5W": [1, 2, 3, 4, 5, 6, 7, 9, 10, 11, 19, 20, 21, 30, 31, 32],
# centered around each of the last 3 years (delta = 1)
- "M": [1, 2, 3, 4, 5, 6, 7, 11, 12, 13, 23, 24, 25, 35, 36, 37],
+ M: [1, 2, 3, 4, 5, 6, 7, 11, 12, 13, 23, 24, 25, 35, 36, 37],
# default
- "6M": [1, 2, 3, 4, 5, 6, 7],
+ "6" + M: [1, 2, 3, 4, 5, 6, 7],
# default
- "12M": [1, 2, 3, 4, 5, 6, 7],
+ "12" + M: [1, 2, 3, 4, 5, 6, 7],
+ Q: [1, 2, 3, 4, 5, 6, 7, 8, 9, 11, 12, 13],
+ "QS": [1, 2, 3, 4, 5, 6, 7, 8, 9, 11, 12, 13],
+ Y: [1, 2, 3, 4, 5, 6, 7],
+ "YS": [1, 2, 3, 4, 5, 6, 7],
}
# For the default multiple (1)
-for freq in ["min", "H", "D", "W", "M"]:
- expected_lags["1" + freq] = expected_lags[freq]
+for freq in ["min", H, "D", "W", M]:
+ EXPECTED_LAGS["1" + freq] = EXPECTED_LAGS[freq]
# For frequencies that do not have unique form
-expected_lags["60min"] = expected_lags["1H"]
-expected_lags["24H"] = expected_lags["1D"]
-expected_lags["7D"] = expected_lags["1W"]
-
-
-def test_lags():
- freq_strs = [
- "4S",
- "min",
- "1min",
- "15min",
- "30min",
- "59min",
- "60min",
- "61min",
- "H",
- "1H",
- "6H",
- "12H",
- "23H",
- "24H",
- "25H",
- "D",
- "1D",
- "2D",
- "6D",
- "7D",
- "8D",
- "W",
- "1W",
- "3W",
- "4W",
- "5W",
- "M",
- "6M",
- "12M",
- ]
+EXPECTED_LAGS["60min"] = EXPECTED_LAGS["1" + H]
+EXPECTED_LAGS["24" + H] = EXPECTED_LAGS["1D"]
+EXPECTED_LAGS["7D"] = EXPECTED_LAGS["1W"]
- for freq_str in freq_strs:
- lags = date_feature_set.get_lags_for_frequency(freq_str)
- assert (
- lags == expected_lags[freq_str]
- ), "lags do not match for the frequency '{}':\nexpected: {},\nprovided: {}".format(
- freq_str, expected_lags[freq_str], lags
- )
+@pytest.mark.parametrize("freq_str, expected_lags", EXPECTED_LAGS.items())
+def test_lags(freq_str, expected_lags):
+ lags = date_feature_set.get_lags_for_frequency(freq_str)
+ assert lags == expected_lags
if __name__ == "__main__":
diff --git a/test/time_feature/test_seasonality.py b/test/time_feature/test_seasonality.py
index 0323e52ebe..3817416c2c 100644
--- a/test/time_feature/test_seasonality.py
+++ b/test/time_feature/test_seasonality.py
@@ -15,25 +15,42 @@
from gluonts.time_feature import get_seasonality
+from .common import H, M, Q, Y
-@pytest.mark.parametrize(
- "freq, expected_seasonality",
- [
- ("30min", 48),
- ("1H", 24),
- ("H", 24),
- ("2H", 12),
- ("3H", 8),
- ("4H", 6),
- ("15H", 1),
- ("5B", 1),
- ("1B", 5),
- ("2W", 1),
- ("3M", 4),
- ("1D", 1),
- ("7D", 1),
- ("8D", 1),
- ],
-)
+TEST_CASES = [
+ ("30min", 48),
+ ("5B", 1),
+ ("1B", 5),
+ ("2W", 1),
+ ("1D", 1),
+ ("7D", 1),
+ ("8D", 1),
+ # Monthly
+ ("MS", 12),
+ ("3MS", 4),
+ (M, 12),
+ ("3" + M, 4),
+ # Quarterly
+ ("QS", 4),
+ ("2QS", 2),
+ (Q, 4),
+ ("2" + Q, 2),
+ ("3" + Q, 1),
+ # Hourly
+ ("1" + H, 24),
+ (H, 24),
+ ("2" + H, 12),
+ ("3" + H, 8),
+ ("4" + H, 6),
+ ("15" + H, 1),
+ # Yearly
+ (Y, 1),
+ ("2" + Y, 1),
+ ("YS", 1),
+ ("2YS", 1),
+]
+
+
+@pytest.mark.parametrize("freq, expected_seasonality", TEST_CASES)
def test_get_seasonality(freq, expected_seasonality):
assert get_seasonality(freq) == expected_seasonality
diff --git a/test/torch/model/test_estimators.py b/test/torch/model/test_estimators.py
index 3c2ef3ea51..faea91b313 100644
--- a/test/torch/model/test_estimators.py
+++ b/test/torch/model/test_estimators.py
@@ -148,6 +148,14 @@
num_batches_per_epoch=3,
trainer_kwargs=dict(max_epochs=2),
),
+ lambda dataset: TiDEEstimator(
+ freq=dataset.metadata.freq,
+ prediction_length=dataset.metadata.prediction_length,
+ distr_output=QuantileOutput(quantiles=[0.1, 0.6, 0.85]),
+ batch_size=4,
+ num_batches_per_epoch=3,
+ trainer_kwargs=dict(max_epochs=2),
+ ),
lambda dataset: WaveNetEstimator(
freq=dataset.metadata.freq,
prediction_length=dataset.metadata.prediction_length,
@@ -296,6 +304,25 @@ def test_estimator_constant_dataset(
num_batches_per_epoch=3,
epochs=2,
),
+ lambda freq, prediction_length: PatchTSTEstimator(
+ prediction_length=prediction_length,
+ context_length=2 * prediction_length,
+ num_feat_dynamic_real=3,
+ patch_len=16,
+ batch_size=4,
+ num_batches_per_epoch=3,
+ trainer_kwargs=dict(max_epochs=2),
+ ),
+ lambda freq, prediction_length: PatchTSTEstimator(
+ prediction_length=prediction_length,
+ context_length=2 * prediction_length,
+ num_feat_dynamic_real=3,
+ distr_output=QuantileOutput(quantiles=[0.1, 0.6, 0.85]),
+ patch_len=16,
+ batch_size=4,
+ num_batches_per_epoch=3,
+ trainer_kwargs=dict(max_epochs=2),
+ ),
lambda freq, prediction_length: WaveNetEstimator(
freq=freq,
prediction_length=prediction_length,