diff --git a/src/gluonts/core/serde/_base.py b/src/gluonts/core/serde/_base.py index 6f51e02a70..5432bd8105 100644 --- a/src/gluonts/core/serde/_base.py +++ b/src/gluonts/core/serde/_base.py @@ -309,7 +309,7 @@ def decode(r: Any) -> Any: """ # structural recursion over the possible shapes of r - if type(r) == dict and "__kind__" in r: + if isinstance(r, dict) and "__kind__" in r: kind = r["__kind__"] cls = cast(Any, locate(r["class"])) @@ -331,10 +331,10 @@ def decode(r: Any) -> Any: raise ValueError(f"Unknown kind {kind}.") - if type(r) == dict: + if isinstance(r, dict): return valmap(decode, r) - if type(r) == list: + if isinstance(r, list): return list(map(decode, r)) return r diff --git a/src/gluonts/torch/model/wavenet/estimator.py b/src/gluonts/torch/model/wavenet/estimator.py index 504bc0fc4d..87a12a80f5 100644 --- a/src/gluonts/torch/model/wavenet/estimator.py +++ b/src/gluonts/torch/model/wavenet/estimator.py @@ -18,7 +18,7 @@ import numpy as np from gluonts.core.component import validated -from gluonts.dataset.common import DataEntry, Dataset +from gluonts.dataset.common import Dataset from gluonts.dataset.field_names import FieldName from gluonts.dataset.loader import as_stacked_batches from gluonts.itertools import Cyclic @@ -41,7 +41,7 @@ InstanceSplitter, SetField, RemoveFields, - SimpleTransformation, + QuantizeMeanScaled, VstackFeatures, Identity, TestSplitSampler, @@ -67,70 +67,6 @@ ] -class QuantizeScaled(SimpleTransformation): - """Rescale and quantize the target variable. - Requires `past_target_field`, and `future_target_field` to be present. - - The mean absolute value of the past_target is used to rescale - past_target and future_target. Then the bin_edges are used to quantize - the rescaled target. - - The calculated scale is stored in the `scale_field`. - - Parameters - ---------- - bin_edges - The bin edges for quantization. - past_target_field, optional - The field name that contains `past_target`, - by default "past_target" - past_observed_values_field, optional - The field name that contains `past_observed_values`, - by default "past_observed_values" - future_target_field, optional - The field name that contains `future_target`, - by default "future_target" - scale_field, optional - The field name where scale will be stored, - by default "scale" - """ - - @validated() - def __init__( - self, - bin_edges: List[float], - past_target_field: str = "past_target", - past_observed_values_field: str = "past_observed_values", - future_target_field: str = "future_target", - scale_field: str = "scale", - ): - self.bin_edges = np.array(bin_edges) - self.future_target_field = future_target_field - self.past_target_field = past_target_field - self.past_observed_values_field = past_observed_values_field - self.scale_field = scale_field - - def transform(self, data: DataEntry) -> DataEntry: - target = data[self.past_target_field] - weights = data.get( - self.past_observed_values_field, np.ones_like(target) - ) - m = np.sum(np.abs(target) * weights) / np.sum(weights) - scale = m if m > 0 else 1.0 - data[self.future_target_field] = np.digitize( - data[self.future_target_field] / scale, - bins=self.bin_edges, - right=False, - ) - data[self.past_target_field] = np.digitize( - data[self.past_target_field] / scale, - bins=self.bin_edges, - right=False, - ) - data[self.scale_field] = np.array([scale], dtype=np.float32) - return data - - class WaveNetEstimator(PyTorchLightningEstimator): @validated() def __init__( @@ -392,7 +328,7 @@ def _create_instance_splitter(self, mode: str): FieldName.FEAT_TIME, FieldName.OBSERVED_VALUES, ], - ) + QuantizeScaled(bin_edges=self.bin_edges) + ) + QuantizeMeanScaled(bin_edges=self.bin_edges) def create_training_data_loader( self, diff --git a/src/gluonts/torch/model/wavenet/module.py b/src/gluonts/torch/model/wavenet/module.py index 788b5c453d..a910df3f91 100644 --- a/src/gluonts/torch/model/wavenet/module.py +++ b/src/gluonts/torch/model/wavenet/module.py @@ -18,27 +18,7 @@ from gluonts.core.component import validated from gluonts.torch.modules.feature import FeatureEmbedder - - -class LookupValues(nn.Module): - """Lookup bin values from bin indices. - - Parameters - ---------- - bin_values - Tensor of bin values with shape (num_bins, ). - """ - - @validated() - def __init__(self, bin_values: torch.Tensor): - super().__init__() - self.register_buffer("bin_values", bin_values) - - def forward(self, indices: torch.Tensor) -> torch.Tensor: - indices = torch.clamp(indices, 0, self.bin_values.shape[0] - 1) - return torch.index_select( - self.bin_values, 0, indices.reshape(-1) - ).view_as(indices) +from gluonts.torch.modules.lookup_table import LookupValues class CausalDilatedResidualLayer(nn.Module): diff --git a/src/gluonts/torch/modules/lookup_table.py b/src/gluonts/torch/modules/lookup_table.py new file mode 100644 index 0000000000..b412f0ba96 --- /dev/null +++ b/src/gluonts/torch/modules/lookup_table.py @@ -0,0 +1,38 @@ +# Copyright 2018 Amazon.com, Inc. or its affiliates. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"). +# You may not use this file except in compliance with the License. +# A copy of the License is located at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# or in the "license" file accompanying this file. This file is distributed +# on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either +# express or implied. See the License for the specific language governing +# permissions and limitations under the License. + +import torch +import torch.nn as nn + +from gluonts.core.component import validated + + +class LookupValues(nn.Module): + """A lookup table mapping bin indices to values. + + Parameters + ---------- + bin_values + Tensor of bin values with shape (num_bins, ). + """ + + @validated() + def __init__(self, bin_values: torch.Tensor): + super().__init__() + self.register_buffer("bin_values", bin_values) + + def forward(self, indices: torch.Tensor) -> torch.Tensor: + indices = torch.clamp(indices, 0, self.bin_values.shape[0] - 1) + return torch.index_select( + self.bin_values, 0, indices.reshape(-1) + ).view_as(indices) diff --git a/src/gluonts/transform/__init__.py b/src/gluonts/transform/__init__.py index 6f9b39bdb8..dddfcbc5bb 100644 --- a/src/gluonts/transform/__init__.py +++ b/src/gluonts/transform/__init__.py @@ -41,6 +41,7 @@ "InstanceSplitter", "ListFeatures", "MapTransformation", + "QuantizeMeanScaled", "RemoveFields", "RenameFields", "SampleTargetDim", @@ -89,6 +90,7 @@ VstackFeatures, cdf_to_gaussian_forward_transform, Valmap, + QuantizeMeanScaled, ) from .feature import ( AddAgeFeature, diff --git a/src/gluonts/transform/convert.py b/src/gluonts/transform/convert.py index ed4855a2da..156ddcf638 100644 --- a/src/gluonts/transform/convert.py +++ b/src/gluonts/transform/convert.py @@ -904,3 +904,67 @@ def flatmap_transform( if len(times) > 0 or not self.drop_empty: data[self.target_field] = [times, sizes] yield data + + +class QuantizeMeanScaled(SimpleTransformation): + """Rescale and quantize the target variable. + Requires `past_target_field`, and `future_target_field` to be present. + + The mean absolute value of the past_target is used to rescale + past_target and future_target. Then the bin_edges are used to quantize + the rescaled target. + + The calculated scale is stored in the `scale_field`. + + Parameters + ---------- + bin_edges + The bin edges for quantization. + past_target_field, optional + The field name that contains `past_target`, + by default "past_target" + past_observed_values_field, optional + The field name that contains `past_observed_values`, + by default "past_observed_values" + future_target_field, optional + The field name that contains `future_target`, + by default "future_target" + scale_field, optional + The field name where scale will be stored, + by default "scale" + """ + + @validated() + def __init__( + self, + bin_edges: List[float], + past_target_field: str = "past_target", + past_observed_values_field: str = "past_observed_values", + future_target_field: str = "future_target", + scale_field: str = "scale", + ): + self.bin_edges = np.array(bin_edges) + self.future_target_field = future_target_field + self.past_target_field = past_target_field + self.past_observed_values_field = past_observed_values_field + self.scale_field = scale_field + + def transform(self, data: DataEntry) -> DataEntry: + target = data[self.past_target_field] + weights = data.get( + self.past_observed_values_field, np.ones_like(target) + ) + m = np.sum(np.abs(target) * weights) / np.sum(weights) + scale = m if m > 0 else 1.0 + data[self.future_target_field] = np.digitize( + data[self.future_target_field] / scale, + bins=self.bin_edges, + right=False, + ) + data[self.past_target_field] = np.digitize( + data[self.past_target_field] / scale, + bins=self.bin_edges, + right=False, + ) + data[self.scale_field] = np.array([scale], dtype=np.float32) + return data