Skip to content

Commit

Permalink
Backports v0.15.1 (reprise) (#3191)
Browse files Browse the repository at this point in the history
*Description of changes:* backporting fixes
- #3188 
- #3189


By submitting this pull request, I confirm that you can use, modify,
copy, and redistribute this contribution, under the terms of your
choice.


**Please tag this pr with at least one of these labels to make our
release process faster:** BREAKING, new feature, bug fix, other change,
dev setup

---------

Co-authored-by: Oleksandr Shchur <[email protected]>
  • Loading branch information
lostella and shchur authored May 31, 2024
1 parent 0cb0808 commit 3cdef76
Show file tree
Hide file tree
Showing 8 changed files with 43 additions and 35 deletions.
20 changes: 10 additions & 10 deletions docs/tutorials/advanced_topics/howto_pytorch_lightning.md.template
Original file line number Diff line number Diff line change
Expand Up @@ -110,10 +110,10 @@ class FeedForwardNetwork(nn.Module):
torch.nn.init.zeros_(lin.bias)
return lin

def forward(self, context):
scale = self.scaling(context)
scaled_context = context / scale
nn_out = self.nn(scaled_context)
def forward(self, past_target):
scale = self.scaling(past_target)
scaled_past_target = past_target / scale
nn_out = self.nn(scaled_past_target)
nn_out_reshaped = nn_out.reshape(-1, self.prediction_length, self.hidden_dimensions[-1])
distr_args = self.args_proj(nn_out_reshaped)
return distr_args, torch.zeros_like(scale), scale
Expand Down Expand Up @@ -143,15 +143,15 @@ class LightningFeedForwardNetwork(FeedForwardNetwork, pl.LightningModule):
super().__init__(*args, **kwargs)

def training_step(self, batch, batch_idx):
context = batch["past_target"]
target = batch["future_target"]
past_target = batch["past_target"]
future_target = batch["future_target"]

assert context.shape[-1] == self.context_length
assert target.shape[-1] == self.prediction_length
assert past_target.shape[-1] == self.context_length
assert future_target.shape[-1] == self.prediction_length

distr_args, loc, scale = self(context)
distr_args, loc, scale = self(past_target)
distr = self.distr_output.distribution(distr_args, loc, scale)
loss = -distr.log_prob(target)
loss = -distr.log_prob(future_target)

return loss.mean()

Expand Down
15 changes: 9 additions & 6 deletions src/gluonts/model/forecast_generator.py
Original file line number Diff line number Diff line change
Expand Up @@ -83,12 +83,15 @@ def make_distribution_forecast(distr, *args, **kwargs) -> Forecast:


def make_predictions(prediction_net, inputs: dict):
# MXNet predictors only support positional arguments
class_name = prediction_net.__class__.__module__
if class_name.startswith("gluonts.mx") or class_name.startswith("mxnet"):
return prediction_net(*inputs.values())
else:
return prediction_net(**inputs)
try:
# Feed inputs as positional arguments for MXNet block predictors
import mxnet as mx

if isinstance(prediction_net, mx.gluon.Block):
return prediction_net(*inputs.values())
except ImportError:
pass
return prediction_net(**inputs)


class ForecastGenerator:
Expand Down
8 changes: 0 additions & 8 deletions src/gluonts/torch/distributions/distribution_output.py
Original file line number Diff line number Diff line change
Expand Up @@ -90,14 +90,6 @@ def loss(
nll = nll * (variance.detach() ** self.beta)
return nll

@property
def event_shape(self) -> Tuple:
r"""
Shape of each individual event contemplated by the distributions that
this object constructs.
"""
raise NotImplementedError()

@property
def event_dim(self) -> int:
r"""
Expand Down
7 changes: 7 additions & 0 deletions src/gluonts/torch/distributions/output.py
Original file line number Diff line number Diff line change
Expand Up @@ -105,6 +105,13 @@ def loss(
"""
raise NotImplementedError()

@property
def event_shape(self) -> Tuple:
r"""
Shape of each individual event compatible with the output object.
"""
raise NotImplementedError()

@property
def forecast_generator(self) -> ForecastGenerator:
raise NotImplementedError()
Expand Down
4 changes: 4 additions & 0 deletions src/gluonts/torch/distributions/quantile_output.py
Original file line number Diff line number Diff line change
Expand Up @@ -37,6 +37,10 @@ def __init__(self, quantiles: List[float]) -> None:
def forecast_generator(self) -> ForecastGenerator:
return QuantileForecastGenerator(quantiles=self.quantiles)

@property
def event_shape(self) -> Tuple:
return ()

def domain_map(self, *args: torch.Tensor) -> Tuple[torch.Tensor, ...]:
return args

Expand Down
12 changes: 3 additions & 9 deletions src/gluonts/torch/model/tide/estimator.py
Original file line number Diff line number Diff line change
Expand Up @@ -21,7 +21,6 @@
from gluonts.dataset.field_names import FieldName
from gluonts.dataset.loader import as_stacked_batches
from gluonts.itertools import Cyclic
from gluonts.model.forecast_generator import DistributionForecastGenerator
from gluonts.time_feature import (
minute_of_hour,
hour_of_day,
Expand Down Expand Up @@ -49,10 +48,7 @@

from gluonts.torch.model.estimator import PyTorchLightningEstimator
from gluonts.torch.model.predictor import PyTorchPredictor
from gluonts.torch.distributions import (
DistributionOutput,
StudentTOutput,
)
from gluonts.torch.distributions import Output, StudentTOutput

from .lightning_module import TiDELightningModule

Expand Down Expand Up @@ -174,7 +170,7 @@ def __init__(
weight_decay: float = 1e-8,
patience: int = 10,
scaling: Optional[str] = "mean",
distr_output: DistributionOutput = StudentTOutput(),
distr_output: Output = StudentTOutput(),
batch_size: int = 32,
num_batches_per_epoch: int = 50,
trainer_kwargs: Optional[Dict[str, Any]] = None,
Expand Down Expand Up @@ -403,9 +399,7 @@ def create_predictor(
input_transform=transformation + prediction_splitter,
input_names=PREDICTION_INPUT_NAMES,
prediction_net=module,
forecast_generator=DistributionForecastGenerator(
self.distr_output
),
forecast_generator=self.distr_output.forecast_generator,
batch_size=self.batch_size,
prediction_length=self.prediction_length,
device="auto",
Expand Down
4 changes: 2 additions & 2 deletions src/gluonts/torch/model/tide/module.py
Original file line number Diff line number Diff line change
Expand Up @@ -19,7 +19,7 @@
from gluonts.core.component import validated
from gluonts.torch.modules.feature import FeatureEmbedder
from gluonts.model import Input, InputSpec
from gluonts.torch.distributions import DistributionOutput
from gluonts.torch.distributions import Output
from gluonts.torch.scaler import StdScaler, MeanScaler, NOPScaler
from gluonts.torch.model.simple_feedforward import make_linear_layer
from gluonts.torch.util import weighted_average
Expand Down Expand Up @@ -242,7 +242,7 @@ def __init__(
num_layers_encoder: int,
num_layers_decoder: int,
layer_norm: bool,
distr_output: DistributionOutput,
distr_output: Output,
scaling: str,
) -> None:
super().__init__()
Expand Down
8 changes: 8 additions & 0 deletions test/torch/model/test_estimators.py
Original file line number Diff line number Diff line change
Expand Up @@ -148,6 +148,14 @@
num_batches_per_epoch=3,
trainer_kwargs=dict(max_epochs=2),
),
lambda dataset: TiDEEstimator(
freq=dataset.metadata.freq,
prediction_length=dataset.metadata.prediction_length,
distr_output=QuantileOutput(quantiles=[0.1, 0.6, 0.85]),
batch_size=4,
num_batches_per_epoch=3,
trainer_kwargs=dict(max_epochs=2),
),
lambda dataset: WaveNetEstimator(
freq=dataset.metadata.freq,
prediction_length=dataset.metadata.prediction_length,
Expand Down

0 comments on commit 3cdef76

Please sign in to comment.