Nonnegative predictions for deepar-pytorch (#2959)
* DeepAR-torch to return nonnegative pred samples

* add tests

* clean up

---------

Co-authored-by: Pedro Eduardo Mercado Lopez <[email protected]>
melopeo and Pedro Eduardo Mercado Lopez authored Aug 14, 2023
1 parent 589281e commit 33cb259
Showing 3 changed files with 77 additions and 0 deletions.
7 changes: 7 additions & 0 deletions src/gluonts/torch/model/deepar/estimator.py
@@ -144,6 +144,10 @@ class DeepAREstimator(PyTorchLightningEstimator):
Controls the sampling of windows during training.
validation_sampler
Controls the sampling of windows during validation.
nonnegative_pred_samples
If True, an activation function is applied to ensure that the final
prediction samples are non-negative. Note that this is applied only to
the final samples, not during training.
"""

@validated()
@@ -176,6 +180,7 @@ def __init__(
trainer_kwargs: Optional[Dict[str, Any]] = None,
train_sampler: Optional[InstanceSampler] = None,
validation_sampler: Optional[InstanceSampler] = None,
nonnegative_pred_samples: bool = False,
) -> None:
default_trainer_kwargs = {
"max_epochs": 100,
@@ -230,6 +235,7 @@ def __init__(
self.validation_sampler = validation_sampler or ValidationSplitSampler(
min_future=prediction_length
)
self.nonnegative_pred_samples = nonnegative_pred_samples

@classmethod
def derive_auto_fields(cls, train_iter):
@@ -397,6 +403,7 @@ def create_lightning_module(self) -> DeepARLightningModule:
"scaling": self.scaling,
"default_scale": self.default_scale,
"num_parallel_samples": self.num_parallel_samples,
"nonnegative_pred_samples": self.nonnegative_pred_samples,
},
)

28 changes: 28 additions & 0 deletions src/gluonts/torch/model/deepar/module.py
@@ -86,6 +86,10 @@ class DeepARModel(nn.Module):
num_parallel_samples
Number of samples to produce when unrolling the RNN in the prediction
time range.
nonnegative_pred_samples
If True, an activation function is applied to ensure that the final
prediction samples are non-negative. Note that this is applied only to
the final samples, not during training.
"""

@validated()
@@ -107,6 +111,7 @@ def __init__(
scaling: bool = True,
default_scale: Optional[float] = None,
num_parallel_samples: int = 100,
nonnegative_pred_samples: bool = False,
) -> None:
super().__init__()

@@ -154,6 +159,7 @@ def __init__(
dropout=dropout_rate,
batch_first=True,
)
self.nonnegative_pred_samples = nonnegative_pred_samples

def describe_inputs(self, batch_size=1) -> InputSpec:
return InputSpec(
@@ -350,6 +356,24 @@ def output_distribution(
sliced_params = [p[:, -trailing_n:] for p in params]
return self.distr_output.distribution(sliced_params, scale=scale)

def post_process_samples(self, samples: torch.Tensor) -> torch.Tensor:
"""
Enforce domain-specific constraints on the generated samples. For
example, forecasts can be constrained to be nonnegative.

Parameters
----------
samples
Tensor of samples

Returns
-------
Tensor of processed samples, with the same shape.
"""

if self.nonnegative_pred_samples:
return torch.relu(samples)

return samples
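
Since the clamp is a plain ReLU, negative sample values are mapped to exactly zero rather than resampled, which can place a point mass at zero; the fitted distribution itself is untouched during training. A standalone sketch of the same transformation:

import torch

samples = torch.tensor([[-1.2, 0.0, 3.4], [2.5, -0.1, 0.7]])

# Same transformation as post_process_samples with nonnegative_pred_samples=True
clamped = torch.relu(samples)

assert (clamped >= 0).all()
assert clamped.shape == samples.shape  # shape is preserved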

def forward(
self,
feat_static_cat: torch.Tensor,
@@ -451,6 +475,10 @@ def forward(

future_samples_concat = torch.cat(future_samples, dim=1)

future_samples_concat = self.post_process_samples(
future_samples_concat
)

return future_samples_concat.reshape(
(-1, num_parallel_samples, self.prediction_length)
)
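
For reference, a small sketch of the final reshape; the batch size, sample count, and horizon below are made-up numbers:

import torch

batch_size, num_parallel_samples, prediction_length = 2, 100, 3

# forward() concatenates samples into (batch * num_parallel_samples, prediction_length)
flat = torch.rand(batch_size * num_parallel_samples, prediction_length)

per_series = flat.reshape((-1, num_parallel_samples, prediction_length))
assert per_series.shape == (batch_size, num_parallel_samples, prediction_length)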
42 changes: 42 additions & 0 deletions test/torch/model/test_deepar_nonnegative_pred_samples.py
@@ -0,0 +1,42 @@
# Copyright 2018 Amazon.com, Inc. or its affiliates. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License").
# You may not use this file except in compliance with the License.
# A copy of the License is located at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# or in the "license" file accompanying this file. This file is distributed
# on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either
# express or implied. See the License for the specific language governing
# permissions and limitations under the License.

import pytest

from gluonts.torch import DeepAREstimator
from gluonts.torch.distributions import StudentTOutput, NormalOutput
from gluonts.testutil.dummy_datasets import make_dummy_datasets_with_features


@pytest.mark.parametrize("datasets", [make_dummy_datasets_with_features()])
@pytest.mark.parametrize("distr_output", [StudentTOutput(), NormalOutput()])
def test_deepar_nonnegative_pred_samples(
distr_output,
datasets,
):
estimator = DeepAREstimator(
distr_output=distr_output,
nonnegative_pred_samples=True,
freq="D",
prediction_length=3,
trainer_kwargs={"max_epochs": 1},
)

dataset_train, dataset_test = datasets
predictor = estimator.train(dataset_train)
forecasts = list(predictor.predict(dataset_test))

assert len(forecasts) == len(dataset_test)

for forecast in forecasts:
assert (forecast.samples >= 0).all()
