From 3cdef76054ccaea19626373e9711da239b42ff91 Mon Sep 17 00:00:00 2001 From: Lorenzo Stella Date: Fri, 31 May 2024 20:21:46 +0200 Subject: [PATCH] Backports v0.15.1 (reprise) (#3191) *Description of changes:* backporting fixes - #3188 - #3189 By submitting this pull request, I confirm that you can use, modify, copy, and redistribute this contribution, under the terms of your choice. **Please tag this pr with at least one of these labels to make our release process faster:** BREAKING, new feature, bug fix, other change, dev setup --------- Co-authored-by: Oleksandr Shchur --- .../howto_pytorch_lightning.md.template | 20 +++++++++---------- src/gluonts/model/forecast_generator.py | 15 ++++++++------ .../distributions/distribution_output.py | 8 -------- src/gluonts/torch/distributions/output.py | 7 +++++++ .../torch/distributions/quantile_output.py | 4 ++++ src/gluonts/torch/model/tide/estimator.py | 12 +++-------- src/gluonts/torch/model/tide/module.py | 4 ++-- test/torch/model/test_estimators.py | 8 ++++++++ 8 files changed, 43 insertions(+), 35 deletions(-) diff --git a/docs/tutorials/advanced_topics/howto_pytorch_lightning.md.template b/docs/tutorials/advanced_topics/howto_pytorch_lightning.md.template index 34e6937b19..9290260a13 100644 --- a/docs/tutorials/advanced_topics/howto_pytorch_lightning.md.template +++ b/docs/tutorials/advanced_topics/howto_pytorch_lightning.md.template @@ -110,10 +110,10 @@ class FeedForwardNetwork(nn.Module): torch.nn.init.zeros_(lin.bias) return lin - def forward(self, context): - scale = self.scaling(context) - scaled_context = context / scale - nn_out = self.nn(scaled_context) + def forward(self, past_target): + scale = self.scaling(past_target) + scaled_past_target = past_target / scale + nn_out = self.nn(scaled_past_target) nn_out_reshaped = nn_out.reshape(-1, self.prediction_length, self.hidden_dimensions[-1]) distr_args = self.args_proj(nn_out_reshaped) return distr_args, torch.zeros_like(scale), scale @@ -143,15 +143,15 @@ class LightningFeedForwardNetwork(FeedForwardNetwork, pl.LightningModule): super().__init__(*args, **kwargs) def training_step(self, batch, batch_idx): - context = batch["past_target"] - target = batch["future_target"] + past_target = batch["past_target"] + future_target = batch["future_target"] - assert context.shape[-1] == self.context_length - assert target.shape[-1] == self.prediction_length + assert past_target.shape[-1] == self.context_length + assert future_target.shape[-1] == self.prediction_length - distr_args, loc, scale = self(context) + distr_args, loc, scale = self(past_target) distr = self.distr_output.distribution(distr_args, loc, scale) - loss = -distr.log_prob(target) + loss = -distr.log_prob(future_target) return loss.mean() diff --git a/src/gluonts/model/forecast_generator.py b/src/gluonts/model/forecast_generator.py index 33b0320808..0148a8e1e6 100644 --- a/src/gluonts/model/forecast_generator.py +++ b/src/gluonts/model/forecast_generator.py @@ -83,12 +83,15 @@ def make_distribution_forecast(distr, *args, **kwargs) -> Forecast: def make_predictions(prediction_net, inputs: dict): - # MXNet predictors only support positional arguments - class_name = prediction_net.__class__.__module__ - if class_name.startswith("gluonts.mx") or class_name.startswith("mxnet"): - return prediction_net(*inputs.values()) - else: - return prediction_net(**inputs) + try: + # Feed inputs as positional arguments for MXNet block predictors + import mxnet as mx + + if isinstance(prediction_net, mx.gluon.Block): + return prediction_net(*inputs.values()) + except ImportError: + pass + return prediction_net(**inputs) class ForecastGenerator: diff --git a/src/gluonts/torch/distributions/distribution_output.py b/src/gluonts/torch/distributions/distribution_output.py index af786ca4ef..e583d941b5 100644 --- a/src/gluonts/torch/distributions/distribution_output.py +++ b/src/gluonts/torch/distributions/distribution_output.py @@ -90,14 +90,6 @@ def loss( nll = nll * (variance.detach() ** self.beta) return nll - @property - def event_shape(self) -> Tuple: - r""" - Shape of each individual event contemplated by the distributions that - this object constructs. - """ - raise NotImplementedError() - @property def event_dim(self) -> int: r""" diff --git a/src/gluonts/torch/distributions/output.py b/src/gluonts/torch/distributions/output.py index 49385f0e55..83d22246bb 100644 --- a/src/gluonts/torch/distributions/output.py +++ b/src/gluonts/torch/distributions/output.py @@ -105,6 +105,13 @@ def loss( """ raise NotImplementedError() + @property + def event_shape(self) -> Tuple: + r""" + Shape of each individual event compatible with the output object. + """ + raise NotImplementedError() + @property def forecast_generator(self) -> ForecastGenerator: raise NotImplementedError() diff --git a/src/gluonts/torch/distributions/quantile_output.py b/src/gluonts/torch/distributions/quantile_output.py index 4bce0fad53..ce104c5703 100644 --- a/src/gluonts/torch/distributions/quantile_output.py +++ b/src/gluonts/torch/distributions/quantile_output.py @@ -37,6 +37,10 @@ def __init__(self, quantiles: List[float]) -> None: def forecast_generator(self) -> ForecastGenerator: return QuantileForecastGenerator(quantiles=self.quantiles) + @property + def event_shape(self) -> Tuple: + return () + def domain_map(self, *args: torch.Tensor) -> Tuple[torch.Tensor, ...]: return args diff --git a/src/gluonts/torch/model/tide/estimator.py b/src/gluonts/torch/model/tide/estimator.py index ab6c004c02..35f8cf5362 100644 --- a/src/gluonts/torch/model/tide/estimator.py +++ b/src/gluonts/torch/model/tide/estimator.py @@ -21,7 +21,6 @@ from gluonts.dataset.field_names import FieldName from gluonts.dataset.loader import as_stacked_batches from gluonts.itertools import Cyclic -from gluonts.model.forecast_generator import DistributionForecastGenerator from gluonts.time_feature import ( minute_of_hour, hour_of_day, @@ -49,10 +48,7 @@ from gluonts.torch.model.estimator import PyTorchLightningEstimator from gluonts.torch.model.predictor import PyTorchPredictor -from gluonts.torch.distributions import ( - DistributionOutput, - StudentTOutput, -) +from gluonts.torch.distributions import Output, StudentTOutput from .lightning_module import TiDELightningModule @@ -174,7 +170,7 @@ def __init__( weight_decay: float = 1e-8, patience: int = 10, scaling: Optional[str] = "mean", - distr_output: DistributionOutput = StudentTOutput(), + distr_output: Output = StudentTOutput(), batch_size: int = 32, num_batches_per_epoch: int = 50, trainer_kwargs: Optional[Dict[str, Any]] = None, @@ -403,9 +399,7 @@ def create_predictor( input_transform=transformation + prediction_splitter, input_names=PREDICTION_INPUT_NAMES, prediction_net=module, - forecast_generator=DistributionForecastGenerator( - self.distr_output - ), + forecast_generator=self.distr_output.forecast_generator, batch_size=self.batch_size, prediction_length=self.prediction_length, device="auto", diff --git a/src/gluonts/torch/model/tide/module.py b/src/gluonts/torch/model/tide/module.py index 875a05f8b2..e0eb06cb85 100644 --- a/src/gluonts/torch/model/tide/module.py +++ b/src/gluonts/torch/model/tide/module.py @@ -19,7 +19,7 @@ from gluonts.core.component import validated from gluonts.torch.modules.feature import FeatureEmbedder from gluonts.model import Input, InputSpec -from gluonts.torch.distributions import DistributionOutput +from gluonts.torch.distributions import Output from gluonts.torch.scaler import StdScaler, MeanScaler, NOPScaler from gluonts.torch.model.simple_feedforward import make_linear_layer from gluonts.torch.util import weighted_average @@ -242,7 +242,7 @@ def __init__( num_layers_encoder: int, num_layers_decoder: int, layer_norm: bool, - distr_output: DistributionOutput, + distr_output: Output, scaling: str, ) -> None: super().__init__() diff --git a/test/torch/model/test_estimators.py b/test/torch/model/test_estimators.py index 3c2ef3ea51..1833e3e966 100644 --- a/test/torch/model/test_estimators.py +++ b/test/torch/model/test_estimators.py @@ -148,6 +148,14 @@ num_batches_per_epoch=3, trainer_kwargs=dict(max_epochs=2), ), + lambda dataset: TiDEEstimator( + freq=dataset.metadata.freq, + prediction_length=dataset.metadata.prediction_length, + distr_output=QuantileOutput(quantiles=[0.1, 0.6, 0.85]), + batch_size=4, + num_batches_per_epoch=3, + trainer_kwargs=dict(max_epochs=2), + ), lambda dataset: WaveNetEstimator( freq=dataset.metadata.freq, prediction_length=dataset.metadata.prediction_length,