Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Nonnegative predictions for deepar mxnet #2957

Merged
merged 14 commits into from
Aug 11, 2023
7 changes: 7 additions & 0 deletions src/gluonts/mx/model/deepar/_estimator.py
Original file line number Diff line number Diff line change
Expand Up @@ -154,6 +154,10 @@ class DeepAREstimator(GluonEstimator):
num_imputation_samples
How many samples to use to impute values when
impute_missing_values=True
nonnegative_pred_samples
Should final prediction samples be non-negative? If yes, an activation
function is applied to ensure non-negativity. Note that this is applied
only to the final prediction samples, not during training.
"""

@validated()
Expand Down Expand Up @@ -189,6 +193,7 @@ def __init__(
minimum_scale: float = 1e-10,
impute_missing_values: bool = False,
num_imputation_samples: int = 1,
nonnegative_pred_samples: bool = False,
) -> None:
super().__init__(trainer=trainer, batch_size=batch_size, dtype=dtype)

Expand Down Expand Up @@ -285,6 +290,7 @@ def __init__(
self.default_scale = default_scale
self.minimum_scale = minimum_scale
self.impute_missing_values = impute_missing_values
self.nonnegative_pred_samples = nonnegative_pred_samples

@classmethod
def derive_auto_fields(cls, train_iter):
Expand Down Expand Up @@ -470,6 +476,7 @@ def create_predictor(
default_scale=self.default_scale,
minimum_scale=self.minimum_scale,
impute_missing_values=self.impute_missing_values,
nonnegative_pred_samples=self.nonnegative_pred_samples,
)

copy_parameters(trained_network, prediction_network)
Expand Down
21 changes: 21 additions & 0 deletions src/gluonts/mx/model/deepar/_network.py
Original file line number Diff line number Diff line change
Expand Up @@ -56,6 +56,7 @@ def __init__(
minimum_scale: float = 1e-10,
impute_missing_values: bool = False,
default_scale: Optional[float] = None,
nonnegative_pred_samples: bool = False,
**kwargs,
) -> None:
super().__init__(**kwargs)
Expand All @@ -72,6 +73,7 @@ def __init__(
self.num_cat = len(cardinality)
self.scaling = scaling
self.dtype = dtype
self.nonnegative_pred_samples = nonnegative_pred_samples

assert len(cardinality) == len(embedding_dimension), (
"embedding_dimension should be a list with the same size as"
Expand Down Expand Up @@ -789,6 +791,24 @@ def unroll_encoder_default(
# static_feat: (batch_size, num_features + prod(target_shape))
return outputs, state, scale, static_feat, sequence

def post_process_samples(self, F, samples: Tensor) -> Tensor:
    """
    Enforce domain-specific constraints on the generated samples.

    For example, forecasts can be constrained to be nonnegative. This is
    applied only to the final prediction samples (it is called from
    ``sampling_decoder``), not during training.

    Parameters
    ----------
    F
        MXNet function namespace (``mx.nd`` or ``mx.sym``), as passed to
        hybridizable blocks.
    samples
        Tensor of shape (batch_size * num_samples, 1).

    Returns
    -------
    Tensor
        Samples with the same shape; negative values are clipped to zero
        via ReLU when ``self.nonnegative_pred_samples`` is set, otherwise
        the input is returned unchanged.
    """
    if self.nonnegative_pred_samples:
        # relu(x) = max(x, 0): zeroes out negative sample values while
        # leaving nonnegative ones untouched.
        return F.relu(samples)

    return samples


class DeepARTrainingNetwork(DeepARNetwork):
@validated()
Expand Down Expand Up @@ -1098,6 +1118,7 @@ def sampling_decoder(

# (batch_size * num_samples, 1, *target_shape)
new_samples = distr.sample(dtype=self.dtype)
new_samples = self.post_process_samples(F, new_samples)

# (batch_size * num_samples, seq_len, *target_shape)
repeated_past_target = F.concat(
Expand Down
56 changes: 56 additions & 0 deletions test/mx/model/deepar/test_nonnegative_pred_samples.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,56 @@
# Copyright 2018 Amazon.com, Inc. or its affiliates. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License").
# You may not use this file except in compliance with the License.
# A copy of the License is located at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# or in the "license" file accompanying this file. This file is distributed
# on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either
# express or implied. See the License for the specific language governing
# permissions and limitations under the License.

import numpy as np
import pytest

from gluonts.mx import DeepAREstimator
from gluonts.mx.distribution import StudentTOutput
from gluonts.mx.trainer import Trainer
from gluonts.testutil.dummy_datasets import make_dummy_datasets_with_features


@pytest.mark.parametrize("datasets", [make_dummy_datasets_with_features()])
@pytest.mark.parametrize("distr_output", [StudentTOutput()])
@pytest.mark.parametrize("dtype", [np.float32, np.float64])
@pytest.mark.parametrize("impute_missing_values", [False, True])
@pytest.mark.parametrize("symbol_block_predictor", [False, True])
def test_deepar_nonnegative_pred_samples(
    distr_output,
    datasets,
    dtype,
    impute_missing_values,
    symbol_block_predictor,
):
    """
    Train a minimal DeepAR model with ``nonnegative_pred_samples=True``
    and verify that every forecast sample is nonnegative, both for the
    regular predictor and (when ``symbol_block_predictor`` is set) after
    conversion to a symbol-block predictor.
    """
    train_data, test_data = datasets

    # One epoch / one batch keeps the test fast; we only care about the
    # sign of the sampled forecasts, not model quality.
    predictor = DeepAREstimator(
        distr_output=distr_output,
        dtype=dtype,
        impute_missing_values=impute_missing_values,
        nonnegative_pred_samples=True,
        freq="D",
        prediction_length=3,
        trainer=Trainer(epochs=1, num_batches_per_epoch=1),
    ).train(train_data)

    if symbol_block_predictor:
        # Exercise the hybridized (symbolic) code path as well.
        predictor = predictor.as_symbol_block_predictor(dataset=test_data)

    forecasts = list(predictor.predict(test_data))
    assert all(forecast.samples.dtype == dtype for forecast in forecasts)
    assert len(forecasts) == len(test_data)

    for forecast in forecasts:
        assert (forecast.samples >= 0).all()
Loading