diff --git a/src/gluonts/core/component.py b/src/gluonts/core/component.py
index 8fcdc07f8f..c5f18d011a 100644
--- a/src/gluonts/core/component.py
+++ b/src/gluonts/core/component.py
@@ -296,12 +296,16 @@ def validator(init):
     init_params = inspect.signature(init).parameters
     init_fields = {
         param.name: (
-            param.annotation
-            if param.annotation != inspect.Parameter.empty
-            else Any,
-            param.default
-            if param.default != inspect.Parameter.empty
-            else ...,
+            (
+                param.annotation
+                if param.annotation != inspect.Parameter.empty
+                else Any
+            ),
+            (
+                param.default
+                if param.default != inspect.Parameter.empty
+                else ...
+            ),
         )
         for param in init_params.values()
         if param.name != "self"
diff --git a/src/gluonts/dataset/arrow/file.py b/src/gluonts/dataset/arrow/file.py
index 40f1a06638..7bdb6cf898 100644
--- a/src/gluonts/dataset/arrow/file.py
+++ b/src/gluonts/dataset/arrow/file.py
@@ -51,16 +51,13 @@ def infer(
         return ArrowStreamFile(path)

     @abc.abstractmethod
-    def metadata(self) -> Dict[str, str]:
-        ...
+    def metadata(self) -> Dict[str, str]: ...

     @abc.abstractmethod
-    def __iter__(self):
-        ...
+    def __iter__(self): ...

     @abc.abstractmethod
-    def __len__(self):
-        ...
+    def __len__(self): ...


 @dataclass
diff --git a/src/gluonts/dataset/repository/_lstnet.py b/src/gluonts/dataset/repository/_lstnet.py
index ba46cf03ae..e933666c77 100644
--- a/src/gluonts/dataset/repository/_lstnet.py
+++ b/src/gluonts/dataset/repository/_lstnet.py
@@ -200,9 +200,9 @@ def generate_lstnet_dataset(
     meta = MetaData(
         **metadata(
             cardinality=ds_info.num_series,
-            freq=ds_info.freq
-            if ds_info.agg_freq is None
-            else ds_info.agg_freq,
+            freq=(
+                ds_info.freq if ds_info.agg_freq is None else ds_info.agg_freq
+            ),
             prediction_length=prediction_length or ds_info.prediction_length,
         )
     )
diff --git a/src/gluonts/dataset/stat.py b/src/gluonts/dataset/stat.py
index 0becdb4bc4..c3b9a39dfb 100644
--- a/src/gluonts/dataset/stat.py
+++ b/src/gluonts/dataset/stat.py
@@ -418,12 +418,12 @@ def calculate_dataset_statistics(ts_dataset: Any) -> DatasetStatistics:
         max_target_length=max_target_length,
         min_target=min_target,
         num_missing_values=num_missing_values,
-        feat_static_real=observed_feat_static_real
-        if observed_feat_static_real
-        else [],
-        feat_static_cat=observed_feat_static_cat
-        if observed_feat_static_cat
-        else [],
+        feat_static_real=(
+            observed_feat_static_real if observed_feat_static_real else []
+        ),
+        feat_static_cat=(
+            observed_feat_static_cat if observed_feat_static_cat else []
+        ),
         num_past_feat_dynamic_real=num_past_feat_dynamic_real,
         num_feat_dynamic_real=num_feat_dynamic_real,
         num_feat_dynamic_cat=num_feat_dynamic_cat,
diff --git a/src/gluonts/evaluation/_base.py b/src/gluonts/evaluation/_base.py
index f913c7d8fd..b623bf3d75 100644
--- a/src/gluonts/evaluation/_base.py
+++ b/src/gluonts/evaluation/_base.py
@@ -378,9 +378,9 @@ def get_base_metrics(
     return {
         "item_id": forecast.item_id,
         "forecast_start": forecast.start_date,
-        "MSE": mse(pred_target, mean_fcst)
-        if mean_fcst is not None
-        else None,
+        "MSE": (
+            mse(pred_target, mean_fcst) if mean_fcst is not None else None
+        ),
         "abs_error": abs_error(pred_target, median_fcst),
         "abs_target_sum": abs_target_sum(pred_target),
         "abs_target_mean": abs_target_mean(pred_target),
diff --git a/src/gluonts/ext/rotbaum/_model.py b/src/gluonts/ext/rotbaum/_model.py
index 29bd1e3cf7..3c0270a49a 100644
--- a/src/gluonts/ext/rotbaum/_model.py
+++ b/src/gluonts/ext/rotbaum/_model.py
@@ -340,11 +340,11 @@ def _get_and_cache_quantile_computation(
             The quantile of the associated true
             value bin.
         """
         if feature_vector_in_train not in self.quantile_dicts[quantile]:
-            self.quantile_dicts[quantile][
-                feature_vector_in_train
-            ] = np.percentile(
-                self.id_to_bins[self.preds_to_id[feature_vector_in_train]],
-                quantile * 100,
+            self.quantile_dicts[quantile][feature_vector_in_train] = (
+                np.percentile(
+                    self.id_to_bins[self.preds_to_id[feature_vector_in_train]],
+                    quantile * 100,
+                )
             )
         return self.quantile_dicts[quantile][feature_vector_in_train]
diff --git a/src/gluonts/ext/rotbaum/_predictor.py b/src/gluonts/ext/rotbaum/_predictor.py
index 63daea1a97..6631e8dde0 100644
--- a/src/gluonts/ext/rotbaum/_predictor.py
+++ b/src/gluonts/ext/rotbaum/_predictor.py
@@ -254,13 +254,17 @@ def train(
             target_data[:, train_QRX_only_using_timestep],
         )
         self.model_list = [
-            QRX(
-                xgboost_params=self.model_params,
-                min_bin_size=self.min_bin_size,
-                model=self.model_list[train_QRX_only_using_timestep].model,
+            (
+                QRX(
+                    xgboost_params=self.model_params,
+                    min_bin_size=self.min_bin_size,
+                    model=self.model_list[
+                        train_QRX_only_using_timestep
+                    ].model,
+                )
+                if i != train_QRX_only_using_timestep
+                else self.model_list[i]
             )
-            if i != train_QRX_only_using_timestep
-            else self.model_list[i]
             for i in range(n_models)
         ]
         with concurrent.futures.ThreadPoolExecutor(
diff --git a/src/gluonts/itertools.py b/src/gluonts/itertools.py
index e90281a1d8..198559d6f6 100644
--- a/src/gluonts/itertools.py
+++ b/src/gluonts/itertools.py
@@ -42,11 +42,9 @@

 @runtime_checkable
 class SizedIterable(Protocol):
-    def __len__(self):
-        ...
+    def __len__(self): ...

-    def __iter__(self):
-        ...
+    def __iter__(self): ...


 T = TypeVar("T")
diff --git a/src/gluonts/model/forecast_generator.py b/src/gluonts/model/forecast_generator.py
index ac33aae158..d66a1d361d 100644
--- a/src/gluonts/model/forecast_generator.py
+++ b/src/gluonts/model/forecast_generator.py
@@ -132,9 +132,11 @@ def __call__(
             yield QuantileForecast(
                 output.T,
                 start_date=batch[FieldName.FORECAST_START][i],
-                item_id=batch[FieldName.ITEM_ID][i]
-                if FieldName.ITEM_ID in batch
-                else None,
+                item_id=(
+                    batch[FieldName.ITEM_ID][i]
+                    if FieldName.ITEM_ID in batch
+                    else None
+                ),
                 info=batch["info"][i] if "info" in batch else None,
                 forecast_keys=self.quantiles,
             )
@@ -181,9 +183,11 @@ def __call__(
             yield SampleForecast(
                 output,
                 start_date=batch[FieldName.FORECAST_START][i],
-                item_id=batch[FieldName.ITEM_ID][i]
-                if FieldName.ITEM_ID in batch
-                else None,
+                item_id=(
+                    batch[FieldName.ITEM_ID][i]
+                    if FieldName.ITEM_ID in batch
+                    else None
+                ),
                 info=batch["info"][i] if "info" in batch else None,
             )
         assert i + 1 == len(batch[FieldName.FORECAST_START])
@@ -221,9 +225,11 @@ def __call__(
             yield make_distribution_forecast(
                 distr,
                 start_date=batch[FieldName.FORECAST_START][i],
-                item_id=batch[FieldName.ITEM_ID][i]
-                if FieldName.ITEM_ID in batch
-                else None,
+                item_id=(
+                    batch[FieldName.ITEM_ID][i]
+                    if FieldName.ITEM_ID in batch
+                    else None
+                ),
                 info=batch["info"][i] if "info" in batch else None,
             )
         assert i + 1 == len(batch[FieldName.FORECAST_START])
diff --git a/src/gluonts/mx/batchify.py b/src/gluonts/mx/batchify.py
index ea1de2adb0..b534b7f523 100644
--- a/src/gluonts/mx/batchify.py
+++ b/src/gluonts/mx/batchify.py
@@ -117,13 +117,16 @@ def as_in_context(batch: dict, ctx: mx.Context = None) -> DataBatch:
     Move data into new context, should only be in main process.
     """
     batch = {
-        k: v.as_in_context(ctx) if isinstance(v, mx.nd.NDArray)
-        # Workaround due to MXNet not being able to handle NDArrays with 0 in
-        # shape properly:
-        else (
-            stack(v, ctx=ctx, dtype=v.dtype, variable_length=False)
-            if isinstance(v[0], np.ndarray) and 0 in v[0].shape
-            else v
+        k: (
+            v.as_in_context(ctx)
+            if isinstance(v, mx.nd.NDArray)
+            # Workaround due to MXNet not being able to handle NDArrays with 0 in
+            # shape properly:
+            else (
+                stack(v, ctx=ctx, dtype=v.dtype, variable_length=False)
+                if isinstance(v[0], np.ndarray) and 0 in v[0].shape
+                else v
+            )
         )
         for k, v in batch.items()
     }
diff --git a/src/gluonts/mx/block/dropout.py b/src/gluonts/mx/block/dropout.py
index e1de5e22e0..2522820a2c 100644
--- a/src/gluonts/mx/block/dropout.py
+++ b/src/gluonts/mx/block/dropout.py
@@ -231,9 +231,11 @@ def mask(p, like):
         # mask as output, instead of simply copy output to the first element
         # in case that the base cell is ResidualCell
         new_states = [
-            F.where(output_mask, next_states[0], states[0])
-            if p_outputs != 0.0
-            else next_states[0]
+            (
+                F.where(output_mask, next_states[0], states[0])
+                if p_outputs != 0.0
+                else next_states[0]
+            )
         ]
         new_states.extend(
             [
diff --git a/src/gluonts/mx/distribution/box_cox_transform.py b/src/gluonts/mx/distribution/box_cox_transform.py
index 46d9afe87a..99125a032a 100644
--- a/src/gluonts/mx/distribution/box_cox_transform.py
+++ b/src/gluonts/mx/distribution/box_cox_transform.py
@@ -116,6 +116,7 @@ class BoxCoxTransform(Bijection):
         `tol_lambda_1`
     F
     """
+
     arg_names = ["box_cox.lambda_1", "box_cox.lambda_2"]

     @validated()
diff --git a/src/gluonts/mx/distribution/inflated_beta.py b/src/gluonts/mx/distribution/inflated_beta.py
index 345126a291..93c0781893 100644
--- a/src/gluonts/mx/distribution/inflated_beta.py
+++ b/src/gluonts/mx/distribution/inflated_beta.py
@@ -113,6 +113,7 @@ class ZeroInflatedBeta(ZeroAndOneInflatedBeta):
         `(*batch_shape, *event_shape)`.
     F
     """
+
     is_reparameterizable = False

     @validated()
@@ -145,6 +146,7 @@ class OneInflatedBeta(ZeroAndOneInflatedBeta):
         `(*batch_shape, *event_shape)`.
F """ + is_reparameterizable = False @validated() diff --git a/src/gluonts/mx/distribution/lds.py b/src/gluonts/mx/distribution/lds.py index 6167a00d1e..31d3ae225a 100644 --- a/src/gluonts/mx/distribution/lds.py +++ b/src/gluonts/mx/distribution/lds.py @@ -409,9 +409,11 @@ def sample( # (num_samples, batch_size, latent_dim, latent_dim) # innovation_coeff_t: (num_samples, batch_size, 1, latent_dim) emission_coeff_t, transition_coeff_t, innovation_coeff_t = ( - _broadcast_param(coeff, axes=[0], sizes=[num_samples]) - if num_samples is not None - else coeff + ( + _broadcast_param(coeff, axes=[0], sizes=[num_samples]) + if num_samples is not None + else coeff + ) for coeff in [ self.emission_coeff[t], self.transition_coeff[t], @@ -458,9 +460,11 @@ def sample( if scale is None else F.broadcast_mul( samples, - scale.expand_dims(axis=1).expand_dims(axis=0) - if num_samples is not None - else scale.expand_dims(axis=1), + ( + scale.expand_dims(axis=1).expand_dims(axis=0) + if num_samples is not None + else scale.expand_dims(axis=1) + ), ) ) diff --git a/src/gluonts/mx/distribution/lowrank_gp.py b/src/gluonts/mx/distribution/lowrank_gp.py index 443cc40acb..2e6de964a8 100644 --- a/src/gluonts/mx/distribution/lowrank_gp.py +++ b/src/gluonts/mx/distribution/lowrank_gp.py @@ -101,9 +101,7 @@ def hybrid_forward(self, F, x: Tensor) -> Tuple[Tensor, Tensor, Tensor]: D_vector = self.proj[1](x_plus_w).squeeze(axis=-1) d_bias = ( - 0.0 - if self.sigma_init == 0.0 - else inv_softplus(self.sigma_init**2) + 0.0 if self.sigma_init == 0.0 else inv_softplus(self.sigma_init**2) ) D_positive = ( diff --git a/src/gluonts/mx/distribution/lowrank_multivariate_gaussian.py b/src/gluonts/mx/distribution/lowrank_multivariate_gaussian.py index b8131178a5..f446799367 100644 --- a/src/gluonts/mx/distribution/lowrank_multivariate_gaussian.py +++ b/src/gluonts/mx/distribution/lowrank_multivariate_gaussian.py @@ -408,9 +408,7 @@ def domain_map(self, F, mu_vector, D_vector, W_vector=None): """ d_bias = ( - inv_softplus(self.sigma_init**2) - if self.sigma_init > 0.0 - else 0.0 + inv_softplus(self.sigma_init**2) if self.sigma_init > 0.0 else 0.0 ) # sigma_minimum helps avoiding cholesky problems, we could also jitter diff --git a/src/gluonts/mx/model/deepar/_network.py b/src/gluonts/mx/model/deepar/_network.py index f59bc517f9..2b2049afd2 100644 --- a/src/gluonts/mx/model/deepar/_network.py +++ b/src/gluonts/mx/model/deepar/_network.py @@ -572,9 +572,11 @@ def unroll_encoder_imputation( static_feat = F.concat( embedded_cat, feat_static_real, - F.log(scale) - if len(self.target_shape) == 0 - else F.log(scale.squeeze(axis=1)), + ( + F.log(scale) + if len(self.target_shape) == 0 + else F.log(scale.squeeze(axis=1)) + ), dim=1, ) @@ -603,9 +605,9 @@ def unroll_encoder_imputation( begin_state = self.rnn.begin_state( func=F.zeros, dtype=self.dtype, - batch_size=inputs.shape[0] - if isinstance(inputs, mx.nd.NDArray) - else 0, + batch_size=( + inputs.shape[0] if isinstance(inputs, mx.nd.NDArray) else 0 + ), ) unroll_results = self.imputation_rnn_unroll( @@ -726,9 +728,11 @@ def unroll_encoder_default( static_feat = F.concat( embedded_cat, feat_static_real, - F.log(scale) - if len(self.target_shape) == 0 - else F.log(scale.squeeze(axis=1)), + ( + F.log(scale) + if len(self.target_shape) == 0 + else F.log(scale.squeeze(axis=1)) + ), dim=1, ) @@ -757,9 +761,9 @@ def unroll_encoder_default( begin_state = self.rnn.begin_state( func=F.zeros, dtype=self.dtype, - batch_size=inputs.shape[0] - if isinstance(inputs, mx.nd.NDArray) - else 0, + 
batch_size=( + inputs.shape[0] if isinstance(inputs, mx.nd.NDArray) else 0 + ), ) state = begin_state # This is a dummy computation to avoid deferred initialization error diff --git a/src/gluonts/mx/model/deepvar/_estimator.py b/src/gluonts/mx/model/deepvar/_estimator.py index 0a1828369d..9881aa5df9 100644 --- a/src/gluonts/mx/model/deepvar/_estimator.py +++ b/src/gluonts/mx/model/deepvar/_estimator.py @@ -304,9 +304,9 @@ def __init__( self.scaling = scaling if self.use_marginal_transformation: - self.output_transform: Optional[ - Callable - ] = cdf_to_gaussian_forward_transform + self.output_transform: Optional[Callable] = ( + cdf_to_gaussian_forward_transform + ) else: self.output_transform = None diff --git a/src/gluonts/mx/model/estimator.py b/src/gluonts/mx/model/estimator.py index 1cf992a796..122e34fad3 100644 --- a/src/gluonts/mx/model/estimator.py +++ b/src/gluonts/mx/model/estimator.py @@ -177,9 +177,9 @@ def train_model( transformation = self.create_transformation() with env._let(max_idle_transforms=max(len(training_data), 100)): - transformed_training_data: Union[ - TransformedDataset, Cached - ] = transformation.apply(training_data) + transformed_training_data: Union[TransformedDataset, Cached] = ( + transformation.apply(training_data) + ) if cache_data: transformed_training_data = Cached(transformed_training_data) diff --git a/src/gluonts/mx/model/gpvar/_estimator.py b/src/gluonts/mx/model/gpvar/_estimator.py index 58edcf3982..29aea6ac31 100644 --- a/src/gluonts/mx/model/gpvar/_estimator.py +++ b/src/gluonts/mx/model/gpvar/_estimator.py @@ -63,7 +63,6 @@ class GPVAREstimator(GluonEstimator): - """ Constructs a GPVAR estimator. diff --git a/src/gluonts/mx/model/gpvar/_network.py b/src/gluonts/mx/model/gpvar/_network.py index 03691aeeec..1a8a60a8b8 100644 --- a/src/gluonts/mx/model/gpvar/_network.py +++ b/src/gluonts/mx/model/gpvar/_network.py @@ -116,9 +116,9 @@ def unroll( length=unroll_length, layout="NTC", merge_outputs=True, - begin_state=begin_state[i] - if begin_state is not None - else None, + begin_state=( + begin_state[i] if begin_state is not None else None + ), ) outputs.append(outputs_single_dim) states.append(state) diff --git a/src/gluonts/mx/model/n_beats/_ensemble.py b/src/gluonts/mx/model/n_beats/_ensemble.py index f1f8f8eb81..4351163251 100644 --- a/src/gluonts/mx/model/n_beats/_ensemble.py +++ b/src/gluonts/mx/model/n_beats/_ensemble.py @@ -201,9 +201,11 @@ def predict( yield SampleForecast( output, start_date=start_date, - item_id=item[FieldName.ITEM_ID] - if FieldName.ITEM_ID in item - else None, + item_id=( + item[FieldName.ITEM_ID] + if FieldName.ITEM_ID in item + else None + ), info=item["info"] if "info" in item else None, ) diff --git a/src/gluonts/mx/model/renewal/_network.py b/src/gluonts/mx/model/renewal/_network.py index aa83273506..d0d7bc0edd 100644 --- a/src/gluonts/mx/model/renewal/_network.py +++ b/src/gluonts/mx/model/renewal/_network.py @@ -89,9 +89,11 @@ def distribution( cond_interval, cond_size = F.split(cond_mean, num_outputs=2, axis=-1) alpha_biases = [ - F.broadcast_mul(F.ones_like(cond_interval), bias) - if bias is not None - else None + ( + F.broadcast_mul(F.ones_like(cond_interval), bias) + if bias is not None + else None + ) for bias in [interval_alpha_bias, size_alpha_bias] ] diff --git a/src/gluonts/mx/model/renewal/_transform.py b/src/gluonts/mx/model/renewal/_transform.py index 479c6a7f83..884ec31051 100644 --- a/src/gluonts/mx/model/renewal/_transform.py +++ b/src/gluonts/mx/model/renewal/_transform.py @@ -53,9 +53,11 @@ def 
transform(self, data: DataEntry) -> DataEntry: target = data[self.target_field] data[self.output_field] = np.array( [ - len(target) - if isinstance(target, list) - else target.shape[self.axis] + ( + len(target) + if isinstance(target, list) + else target.shape[self.axis] + ) ] ) return data diff --git a/src/gluonts/mx/model/seq2seq/_transform.py b/src/gluonts/mx/model/seq2seq/_transform.py index 4134e8b01b..7c443e82d7 100644 --- a/src/gluonts/mx/model/seq2seq/_transform.py +++ b/src/gluonts/mx/model/seq2seq/_transform.py @@ -180,21 +180,21 @@ def flatmap_transform( # (Fortran) ordering with strides = # (dtype, dtype*n_rows) stride = decoder_fields.strides - out[self._future(ts_field)][ - pad_length_dec: - ] = as_strided( - decoder_fields, - shape=( - self.num_forking - pad_length_dec, - self.dec_len, - ts_len, - ), - # strides for 2D array expanded to 3D array of - # shape (dim1, dim2, dim3) =(1, n_rows, n_cols). - # For transposed data, strides = (dtype, dtype * - # dim1, dtype*dim1*dim2) = (dtype, dtype, - # dtype*n_rows). - strides=stride[0:1] + stride, + out[self._future(ts_field)][pad_length_dec:] = ( + as_strided( + decoder_fields, + shape=( + self.num_forking - pad_length_dec, + self.dec_len, + ts_len, + ), + # strides for 2D array expanded to 3D array of + # shape (dim1, dim2, dim3) =(1, n_rows, n_cols). + # For transposed data, strides = (dtype, dtype * + # dim1, dtype*dim1*dim2) = (dtype, dtype, + # dtype*n_rows). + strides=stride[0:1] + stride, + ) ) # edge case for prediction_length = 1 diff --git a/src/gluonts/mx/model/tft/_estimator.py b/src/gluonts/mx/model/tft/_estimator.py index db19544603..a4d8335719 100644 --- a/src/gluonts/mx/model/tft/_estimator.py +++ b/src/gluonts/mx/model/tft/_estimator.py @@ -172,13 +172,13 @@ def __init__( self.past_dynamic_feature_dims = {} for name in self.past_dynamic_features: if name in self.dynamic_cardinalities: - self.past_dynamic_cardinalities[ - name - ] = self.dynamic_cardinalities.pop(name) + self.past_dynamic_cardinalities[name] = ( + self.dynamic_cardinalities.pop(name) + ) elif name in self.dynamic_feature_dims: - self.past_dynamic_feature_dims[ - name - ] = self.dynamic_feature_dims.pop(name) + self.past_dynamic_feature_dims[name] = ( + self.dynamic_feature_dims.pop(name) + ) else: raise ValueError( f"Feature name {name} is not provided in feature dicts" diff --git a/src/gluonts/mx/model/tpp/distribution/base.py b/src/gluonts/mx/model/tpp/distribution/base.py index a4d93c88aa..990d57d880 100644 --- a/src/gluonts/mx/model/tpp/distribution/base.py +++ b/src/gluonts/mx/model/tpp/distribution/base.py @@ -187,6 +187,7 @@ class TPPDistributionOutput(DistributionOutput): 1. Location param cannot be specified (all distributions must start at 0). 2. The return type is either TPPDistribution or TPPTransformedDistribution. """ + distr_cls: type def distribution( diff --git a/src/gluonts/mx/model/tpp/distribution/weibull.py b/src/gluonts/mx/model/tpp/distribution/weibull.py index 3055d47c68..09b73b51f6 100644 --- a/src/gluonts/mx/model/tpp/distribution/weibull.py +++ b/src/gluonts/mx/model/tpp/distribution/weibull.py @@ -34,6 +34,7 @@ class Weibull(TPPDistribution): parameter :math:`\lambda > 0` and the shape parameter :math:`k > 0`, and :math:`\lambda = b^{-1/k}`. 
""" + is_reparametrizable = True @validated() diff --git a/src/gluonts/mx/model/tpp/predictor.py b/src/gluonts/mx/model/tpp/predictor.py index cb706931e1..159427651e 100644 --- a/src/gluonts/mx/model/tpp/predictor.py +++ b/src/gluonts/mx/model/tpp/predictor.py @@ -86,9 +86,9 @@ def __call__( # type: ignore start_date=batch["forecast_start"][i], freq=freq, prediction_interval_length=prediction_net.prediction_interval_length, # noqa: E501 - item_id=batch["item_id"][i] - if "item_id" in batch - else None, + item_id=( + batch["item_id"][i] if "item_id" in batch else None + ), info=batch["info"][i] if "info" in batch else None, ) diff --git a/src/gluonts/mx/model/transformer/_network.py b/src/gluonts/mx/model/transformer/_network.py index 5e3082e968..245acd0c57 100644 --- a/src/gluonts/mx/model/transformer/_network.py +++ b/src/gluonts/mx/model/transformer/_network.py @@ -194,9 +194,11 @@ def create_network_input( # prediction too(batch_size, num_features + prod(target_shape)) static_feat = F.concat( embedded_cat, - F.log(scale) - if len(self.target_shape) == 0 - else F.log(scale.squeeze(axis=1)), + ( + F.log(scale) + if len(self.target_shape) == 0 + else F.log(scale.squeeze(axis=1)) + ), dim=1, ) diff --git a/src/gluonts/mx/model/wavenet/_estimator.py b/src/gluonts/mx/model/wavenet/_estimator.py index 28fca9d2cc..293676af6b 100644 --- a/src/gluonts/mx/model/wavenet/_estimator.py +++ b/src/gluonts/mx/model/wavenet/_estimator.py @@ -309,9 +309,11 @@ def _create_instance_splitter(self, mode: str): forecast_start_field=FieldName.FORECAST_START, instance_sampler=instance_sampler, past_length=self.context_length, - future_length=self.prediction_length - if mode == "test" - else self.train_window_length, + future_length=( + self.prediction_length + if mode == "test" + else self.train_window_length + ), output_NTC=False, time_series_fields=[ FieldName.FEAT_TIME, diff --git a/src/gluonts/nursery/SCott/dataset_tools/synthetic.py b/src/gluonts/nursery/SCott/dataset_tools/synthetic.py index a16034b83c..8e2dc383b6 100644 --- a/src/gluonts/nursery/SCott/dataset_tools/synthetic.py +++ b/src/gluonts/nursery/SCott/dataset_tools/synthetic.py @@ -57,12 +57,12 @@ def get_mixed_pattern(unit_length=16, num_duplicates=1000): for j in range(num_duplicates): context = torch.arange(context_length, dtype=torch.float) for i in range(1, pattern_number): - context[ - unit_length * (i - 1) : unit_length * i - ] = _get_mixed_pattern( - context[unit_length * (i - 1) : unit_length * i] - - unit_length * (i - 1), - pattern[(gid + i) % pattern_number], + context[unit_length * (i - 1) : unit_length * i] = ( + _get_mixed_pattern( + context[unit_length * (i - 1) : unit_length * i] + - unit_length * (i - 1), + pattern[(gid + i) % pattern_number], + ) ) ts_sample = torch.cat( [ diff --git a/src/gluonts/nursery/SCott/preprocess_data.py b/src/gluonts/nursery/SCott/preprocess_data.py index 4f247eb229..40b181820f 100644 --- a/src/gluonts/nursery/SCott/preprocess_data.py +++ b/src/gluonts/nursery/SCott/preprocess_data.py @@ -57,12 +57,12 @@ def get_mixed_pattern(unit_length=16, num_duplicates=1000): for j in range(num_duplicates): context = torch.arange(context_length, dtype=torch.float) for i in range(1, pattern_number): - context[ - unit_length * (i - 1) : unit_length * i - ] = _get_mixed_pattern( - context[unit_length * (i - 1) : unit_length * i] - - unit_length * (i - 1), - pattern[(gid + i) % pattern_number], + context[unit_length * (i - 1) : unit_length * i] = ( + _get_mixed_pattern( + context[unit_length * (i - 1) : unit_length * 
i] + - unit_length * (i - 1), + pattern[(gid + i) % pattern_number], + ) ) ts_sample = torch.cat( [ diff --git a/src/gluonts/nursery/daf/network/kernel.py b/src/gluonts/nursery/daf/network/kernel.py index 3a26e78595..9dccdd3e3b 100644 --- a/src/gluonts/nursery/daf/network/kernel.py +++ b/src/gluonts/nursery/daf/network/kernel.py @@ -105,17 +105,21 @@ def __init__( _query_weight = nn.Parameter(Tensor(d_hidden, d_hidden)) self._query_weights.append(_query_weight) self._key_weights.append( - _query_weight - if self.symmetric - else nn.Parameter(Tensor(d_hidden, d_hidden)), + ( + _query_weight + if self.symmetric + else nn.Parameter(Tensor(d_hidden, d_hidden)) + ), ) if self.bias: _query_bias = nn.Parameter(Tensor(d_hidden)) self._query_biases.append(_query_bias) self._key_biases.append( - _query_bias - if self.symmetric - else nn.Parameter(Tensor(d_hidden)), + ( + _query_bias + if self.symmetric + else nn.Parameter(Tensor(d_hidden)) + ), ) else: self._query_biases.append(None) diff --git a/src/gluonts/nursery/daf/tslib/dataset/timeseries.py b/src/gluonts/nursery/daf/tslib/dataset/timeseries.py index 0036ef6acf..4bd63964b8 100644 --- a/src/gluonts/nursery/daf/tslib/dataset/timeseries.py +++ b/src/gluonts/nursery/daf/tslib/dataset/timeseries.py @@ -335,12 +335,10 @@ def __len__(self): return len(self.target) @overload - def index_by_timestamp(self, index: pd.Timestamp) -> int: - ... + def index_by_timestamp(self, index: pd.Timestamp) -> int: ... @overload - def index_by_timestamp(self, index: List[pd.Timestamp]) -> List[int]: - ... + def index_by_timestamp(self, index: List[pd.Timestamp]) -> List[int]: ... def index_by_timestamp(self, index): return pd.Series( @@ -348,24 +346,19 @@ def index_by_timestamp(self, index): ).loc[index] @overload - def __getitem__(self, index: int) -> TimeSeriesInstant: - ... + def __getitem__(self, index: int) -> TimeSeriesInstant: ... @overload - def __getitem__(self, index: pd.Timestamp) -> TimeSeriesInstant: - ... + def __getitem__(self, index: pd.Timestamp) -> TimeSeriesInstant: ... @overload - def __getitem__(self, index: slice) -> TimeSeries: - ... + def __getitem__(self, index: slice) -> TimeSeries: ... @overload - def __getitem__(self, index: List[int]) -> TimeSeries: - ... + def __getitem__(self, index: List[int]) -> TimeSeries: ... @overload - def __getitem__(self, index: List[pd.Timestamp]) -> TimeSeries: - ... + def __getitem__(self, index: List[pd.Timestamp]) -> TimeSeries: ... 
     def __getitem__(self, index):
         if isinstance(index, pd.Timestamp) or (
diff --git a/src/gluonts/nursery/daf/tslib/engine/evaluator.py b/src/gluonts/nursery/daf/tslib/engine/evaluator.py
index 636f6e6b0d..f61c5a16fe 100644
--- a/src/gluonts/nursery/daf/tslib/engine/evaluator.py
+++ b/src/gluonts/nursery/daf/tslib/engine/evaluator.py
@@ -63,9 +63,11 @@ def load(self, tag: str):
             path = self.log_dir.joinpath(f"{tag}.pt.tar")
             state = pt.load(
                 path,
-                map_location=f"cuda:{self.cuda_device}"
-                if self.cuda_device >= 0
-                else "cpu",
+                map_location=(
+                    f"cuda:{self.cuda_device}"
+                    if self.cuda_device >= 0
+                    else "cpu"
+                ),
             )
             print(f"Load checkpoint from {path}")
         except FileNotFoundError:
diff --git a/src/gluonts/nursery/daf/tslib/engine/trainer.py b/src/gluonts/nursery/daf/tslib/engine/trainer.py
index 07d7f9f9c4..63ad4484ec 100644
--- a/src/gluonts/nursery/daf/tslib/engine/trainer.py
+++ b/src/gluonts/nursery/daf/tslib/engine/trainer.py
@@ -143,9 +143,11 @@ def load(
             path = self.log_dir.joinpath(f"{tag}.pt.tar")
             state = pt.load(
                 path,
-                map_location=f"cuda:{self.cuda_device}"
-                if self.cuda_device >= 0
-                else "cpu",
+                map_location=(
+                    f"cuda:{self.cuda_device}"
+                    if self.cuda_device >= 0
+                    else "cpu"
+                ),
             )
             print(f"Load checkpoint from {path}")
         except FileNotFoundError:
diff --git a/src/gluonts/nursery/few_shot_prediction/src/meta/data/batch.py b/src/gluonts/nursery/few_shot_prediction/src/meta/data/batch.py
index 9c9ff7f227..6512d7f5df 100644
--- a/src/gluonts/nursery/few_shot_prediction/src/meta/data/batch.py
+++ b/src/gluonts/nursery/few_shot_prediction/src/meta/data/batch.py
@@ -29,12 +29,14 @@ class SeriesBatch:
     and the splits sizes indicating the corresponding base dataset.
     """

-    sequences: torch.Tensor  # shape [batch, num_sequences, max_sequence_length]
+    sequences: (
+        torch.Tensor
+    )  # shape [batch, num_sequences, max_sequence_length]
     lengths: torch.Tensor  # shape [batch]
     split_sections: torch.Tensor  # shape [batch]
-    scales: Optional[
-        torch.Tensor
-    ] = None  # shape[batch, 2] contains mean and std the ts has been scaled with
+    scales: Optional[torch.Tensor] = (
+        None  # shape[batch, 2] contains mean and std the ts has been scaled with
+    )

     @classmethod
     def from_lists(
@@ -61,9 +63,11 @@ def from_lists(
             pad_sequence(values, batch_first=True),
             lengths=torch.as_tensor([len(s) for s in series]),
             split_sections=torch.as_tensor(split_sections),
-            scales=torch.stack([s.scale for s in series])
-            if series[0].scale is not None
-            else None,
+            scales=(
+                torch.stack([s.scale for s in series])
+                if series[0].scale is not None
+                else None
+            ),
         )

     def pin_memory(self):
diff --git a/src/gluonts/nursery/few_shot_prediction/src/meta/data/sampling.py b/src/gluonts/nursery/few_shot_prediction/src/meta/data/sampling.py
index a86ddd348b..9fc089fafe 100644
--- a/src/gluonts/nursery/few_shot_prediction/src/meta/data/sampling.py
+++ b/src/gluonts/nursery/few_shot_prediction/src/meta/data/sampling.py
@@ -170,12 +170,14 @@ def __iter__(self) -> Iterator[Triplet]:
                 supps_size=self.support_set_size,
                 length=self.support_length,
                 dataset=self.dataset,
-                cheat_query=cheat_query[0]
-                if np.random.rand() < self.cheat
-                else None,
-                index_iterator=self.index_iterator
-                if self.catch22_nn is None
-                else iter(self.catch22_nn[query_idx]),
+                cheat_query=(
+                    cheat_query[0] if np.random.rand() < self.cheat else None
+                ),
+                index_iterator=(
+                    self.index_iterator
+                    if self.catch22_nn is None
+                    else iter(self.catch22_nn[query_idx])
+                ),
             )
             yield Triplet(support_set, query_past, query_future)

@@ -311,9 +313,11 @@ def __getitem__(self, index: int) -> Triplet:
             # but this will be slower
             # seed=self.seed + index if self.seed else None,
             cheat_query=cheat_query if np.random.rand() < self.cheat else None,
-            index_iterator=self.index_iterator
-            if self.catch22_nn is None
-            else iter(self.catch22_nn[query_past[0].item_id]),
+            index_iterator=(
+                self.index_iterator
+                if self.catch22_nn is None
+                else iter(self.catch22_nn[query_past[0].item_id])
+            ),
         )

         return Triplet(support_set, query_past, query_future)
diff --git a/src/gluonts/nursery/few_shot_prediction/src/meta/datasets/artificial.py b/src/gluonts/nursery/few_shot_prediction/src/meta/datasets/artificial.py
index 19e1475c66..5a2dab200b 100644
--- a/src/gluonts/nursery/few_shot_prediction/src/meta/datasets/artificial.py
+++ b/src/gluonts/nursery/few_shot_prediction/src/meta/datasets/artificial.py
@@ -593,9 +593,9 @@ def generate_artificial_tuplets(
                 np.arange(0, context_length - signal_length)
             )
             si = np.random.choice(support_set_size)
-            support_set[si][
-                marker_start : marker_start + signal_length
-            ] = query[-signal_length:]
+            support_set[si][marker_start : marker_start + signal_length] = (
+                query[-signal_length:]
+            )
         else:
             signal = np.concatenate(
                 (np.ones((4,)), query[-prediction_length:])
@@ -647,9 +647,9 @@ def generate_artificial_tuplets(
                 np.arange(0, context_length - signal_length)
             )
             si = np.random.choice(support_set_size)
-            support_set[si][
-                marker_start : marker_start + signal_length
-            ] = query[-signal_length:]
+            support_set[si][marker_start : marker_start + signal_length] = (
+                query[-signal_length:]
+            )
         # else:
         #     signal = query[-prediction_length:]
         #     signal_length = prediction_length
diff --git a/src/gluonts/nursery/few_shot_prediction/src/meta/datasets/m4.py b/src/gluonts/nursery/few_shot_prediction/src/meta/datasets/m4.py
index 240afc9b2f..a510676208 100644
--- a/src/gluonts/nursery/few_shot_prediction/src/meta/datasets/m4.py
+++ b/src/gluonts/nursery/few_shot_prediction/src/meta/datasets/m4.py
@@ -127,9 +127,11 @@ def generate_m4_dataset(
     start_dates = list(meta_df.StartingDate)

     start_dates = [
-        sd
-        if pd.Timestamp(sd) <= pd.Timestamp("2022")
-        else str(pd.Timestamp("1900"))
+        (
+            sd
+            if pd.Timestamp(sd) <= pd.Timestamp("2022")
+            else str(pd.Timestamp("1900"))
+        )
         for sd in start_dates
     ]

diff --git a/src/gluonts/nursery/few_shot_prediction/src/meta/models/module.py b/src/gluonts/nursery/few_shot_prediction/src/meta/models/module.py
index 3a4e24080c..c5f7b3ad7f 100644
--- a/src/gluonts/nursery/few_shot_prediction/src/meta/models/module.py
+++ b/src/gluonts/nursery/few_shot_prediction/src/meta/models/module.py
@@ -109,7 +109,7 @@ def configure_optimizers(self) -> Dict:
                     verbose=True,
                 ),
                 "monitor": self.lr_scheduler_monitor,
-                "frequency": 1
+                "frequency": 1,
                 # If "monitor" references validation metrics, then "frequency" should be set to a
                 # multiple of "trainer.check_val_every_n_epoch".
             },
diff --git a/src/gluonts/nursery/gmm_tpp/gmm_base.py b/src/gluonts/nursery/gmm_tpp/gmm_base.py
index a2ad97691e..902c05fa31 100644
--- a/src/gluonts/nursery/gmm_tpp/gmm_base.py
+++ b/src/gluonts/nursery/gmm_tpp/gmm_base.py
@@ -54,9 +54,11 @@ def __init__(
             "log_prior_",
             shape=(num_clusters,),
             lr_mult=lr_mult,
-            init=mx.init.Constant(np.log(1 / self.num_clusters))
-            if log_prior_ is None
-            else mx.init.Constant(log_prior_),
+            init=(
+                mx.init.Constant(np.log(1 / self.num_clusters))
+                if log_prior_ is None
+                else mx.init.Constant(log_prior_)
+            ),
         )

         self.mu_ = self.params.get(
diff --git a/src/gluonts/nursery/robust-mts-attack/attack_and_save.py b/src/gluonts/nursery/robust-mts-attack/attack_and_save.py
index 0f1918bf6f..a3366e3bfc 100644
--- a/src/gluonts/nursery/robust-mts-attack/attack_and_save.py
+++ b/src/gluonts/nursery/robust-mts-attack/attack_and_save.py
@@ -133,9 +133,11 @@ def main():
         )
         best_perturbation = attack.attack_batch(
             batch,
-            true_future_target=future_target
-            if device == "cpu"
-            else torch.from_numpy(future_target).float().to(device),
+            true_future_target=(
+                future_target
+                if device == "cpu"
+                else torch.from_numpy(future_target).float().to(device)
+            ),
         )

         batch_res = AttackResults(
diff --git a/src/gluonts/nursery/robust-mts-attack/attack_sparse_layer.py b/src/gluonts/nursery/robust-mts-attack/attack_sparse_layer.py
index a1309388cd..739888ee89 100644
--- a/src/gluonts/nursery/robust-mts-attack/attack_sparse_layer.py
+++ b/src/gluonts/nursery/robust-mts-attack/attack_sparse_layer.py
@@ -127,9 +127,11 @@ def main():
         )
         best_perturbation = attack.attack_batch(
             batch,
-            true_future_target=future_target
-            if device == "cpu"
-            else torch.from_numpy(future_target).float().to(device),
+            true_future_target=(
+                future_target
+                if device == "cpu"
+                else torch.from_numpy(future_target).float().to(device)
+            ),
         )

         batch_res = AttackResults(
diff --git a/src/gluonts/nursery/robust-mts-attack/multivariate/datasets/dataset.py b/src/gluonts/nursery/robust-mts-attack/multivariate/datasets/dataset.py
index a56828f49e..768fe06070 100644
--- a/src/gluonts/nursery/robust-mts-attack/multivariate/datasets/dataset.py
+++ b/src/gluonts/nursery/robust-mts-attack/multivariate/datasets/dataset.py
@@ -121,9 +121,11 @@ def make_multivariate_dataset(
         align_data=False, num_test_dates=num_test_dates, max_target_dim=dim
     )
     return MultivariateDatasetInfo(
-        dataset_name
-        if dataset_benchmark_name is None
-        else dataset_benchmark_name,
+        (
+            dataset_name
+            if dataset_benchmark_name is None
+            else dataset_benchmark_name
+        ),
         grouper_train(train_ds),
         grouper_test(test_ds),
         prediction_length,
diff --git a/src/gluonts/nursery/robust-mts-attack/pts/model/causal_deepar/causal_deepar_network.py b/src/gluonts/nursery/robust-mts-attack/pts/model/causal_deepar/causal_deepar_network.py
index c162cb97f6..12fa525867 100644
--- a/src/gluonts/nursery/robust-mts-attack/pts/model/causal_deepar/causal_deepar_network.py
+++ b/src/gluonts/nursery/robust-mts-attack/pts/model/causal_deepar/causal_deepar_network.py
@@ -216,12 +216,16 @@ def unroll_encoder(
             (
                 embedded_cat,
                 feat_static_real,
-                scale.log()
-                if len(self.target_shape) == 0
-                else scale.squeeze(1).log(),
-                control_scale.log()
-                if len(self.target_shape) == 0
-                else control_scale.squeeze(1).log(),
+                (
+                    scale.log()
+                    if len(self.target_shape) == 0
+                    else scale.squeeze(1).log()
+                ),
+                (
+                    control_scale.log()
+                    if len(self.target_shape) == 0
+                    else control_scale.squeeze(1).log()
+                ),
             ),
             dim=1,
         )
diff --git a/src/gluonts/nursery/robust-mts-attack/pts/model/deepar/deepar_network.py b/src/gluonts/nursery/robust-mts-attack/pts/model/deepar/deepar_network.py
index b756bc17e9..8d66c91ec0 100644
--- a/src/gluonts/nursery/robust-mts-attack/pts/model/deepar/deepar_network.py
+++ b/src/gluonts/nursery/robust-mts-attack/pts/model/deepar/deepar_network.py
@@ -190,9 +190,11 @@ def unroll_encoder(
             (
                 embedded_cat,
                 feat_static_real,
-                scale.log()
-                if len(self.target_shape) == 0
-                else scale.squeeze(1).log(),
+                (
+                    scale.log()
+                    if len(self.target_shape) == 0
+                    else scale.squeeze(1).log()
+                ),
             ),
             dim=1,
         )
diff --git a/src/gluonts/nursery/robust-mts-attack/pts/model/deepvar/deepvar_estimator.py b/src/gluonts/nursery/robust-mts-attack/pts/model/deepvar/deepvar_estimator.py
index fec7d5a2ee..4bea3ddf23 100644
--- a/src/gluonts/nursery/robust-mts-attack/pts/model/deepvar/deepvar_estimator.py
+++ b/src/gluonts/nursery/robust-mts-attack/pts/model/deepvar/deepvar_estimator.py
@@ -139,9 +139,9 @@ def __init__(
         self.scaling = scaling

         if self.use_marginal_transformation:
-            self.output_transform: Optional[
-                Callable
-            ] = cdf_to_gaussian_forward_transform
+            self.output_transform: Optional[Callable] = (
+                cdf_to_gaussian_forward_transform
+            )
         else:
             self.output_transform = None

diff --git a/src/gluonts/nursery/robust-mts-attack/pts/model/deepvar/deepvar_network.py b/src/gluonts/nursery/robust-mts-attack/pts/model/deepvar/deepvar_network.py
index 5a4e0e0534..64353b4adf 100644
--- a/src/gluonts/nursery/robust-mts-attack/pts/model/deepvar/deepvar_network.py
+++ b/src/gluonts/nursery/robust-mts-attack/pts/model/deepvar/deepvar_network.py
@@ -188,9 +188,11 @@ def unroll(
             (
                 embedded_cat,
                 feat_static_real,
-                scale.log()
-                if len(self.target_shape) == 0
-                else scale.squeeze(1).log(),
+                (
+                    scale.log()
+                    if len(self.target_shape) == 0
+                    else scale.squeeze(1).log()
+                ),
             ),
             dim=1,
         )
diff --git a/src/gluonts/nursery/robust-mts-attack/pts/model/n_beats/n_beats_ensemble.py b/src/gluonts/nursery/robust-mts-attack/pts/model/n_beats/n_beats_ensemble.py
index eb8c3ef7ae..ff9234273d 100644
--- a/src/gluonts/nursery/robust-mts-attack/pts/model/n_beats/n_beats_ensemble.py
+++ b/src/gluonts/nursery/robust-mts-attack/pts/model/n_beats/n_beats_ensemble.py
@@ -102,9 +102,11 @@ def predict(
                 output,
                 start_date=start_date,
                 freq=start_date.freqstr,
-                item_id=item[FieldName.ITEM_ID]
-                if FieldName.ITEM_ID in item
-                else None,
+                item_id=(
+                    item[FieldName.ITEM_ID]
+                    if FieldName.ITEM_ID in item
+                    else None
+                ),
                 info=item["info"] if "info" in item else None,
             )

diff --git a/src/gluonts/nursery/robust-mts-attack/pts/model/tft/tft_estimator.py b/src/gluonts/nursery/robust-mts-attack/pts/model/tft/tft_estimator.py
index af2805c0b3..fdb47f7d87 100644
--- a/src/gluonts/nursery/robust-mts-attack/pts/model/tft/tft_estimator.py
+++ b/src/gluonts/nursery/robust-mts-attack/pts/model/tft/tft_estimator.py
@@ -106,13 +106,13 @@ def __init__(
         self.past_dynamic_feature_dims = {}
         for name in self.past_dynamic_features:
             if name in self.dynamic_cardinalities:
-                self.past_dynamic_cardinalities[
-                    name
-                ] = self.dynamic_cardinalities.pop(name)
+                self.past_dynamic_cardinalities[name] = (
+                    self.dynamic_cardinalities.pop(name)
+                )
             elif name in self.dynamic_feature_dims:
-                self.past_dynamic_feature_dims[
-                    name
-                ] = self.dynamic_feature_dims.pop(name)
+                self.past_dynamic_feature_dims[name] = (
+                    self.dynamic_feature_dims.pop(name)
+                )
             else:
                 raise ValueError(
                     f"Feature name {name} is not provided in feature dicts"
diff --git a/src/gluonts/nursery/robust-mts-attack/pts/model/transformer/transformer_network.py b/src/gluonts/nursery/robust-mts-attack/pts/model/transformer/transformer_network.py
index 37ae2b2bf8..3ac17fef6e 100644
--- a/src/gluonts/nursery/robust-mts-attack/pts/model/transformer/transformer_network.py
+++ b/src/gluonts/nursery/robust-mts-attack/pts/model/transformer/transformer_network.py
@@ -206,9 +206,11 @@ def create_network_input(
             (
                 embedded_cat,
                 feat_static_real,
-                torch.log(scale)
-                if len(self.target_shape) == 0
-                else torch.log(scale.squeeze(1)),
+                (
+                    torch.log(scale)
+                    if len(self.target_shape) == 0
+                    else torch.log(scale.squeeze(1))
+                ),
             ),
             dim=1,
         )
diff --git a/src/gluonts/nursery/robust-mts-attack/pts/model/transformer_tempflow/transformer_tempflow_network.py b/src/gluonts/nursery/robust-mts-attack/pts/model/transformer_tempflow/transformer_tempflow_network.py
index 56b7b4d457..55d60d1537 100644
--- a/src/gluonts/nursery/robust-mts-attack/pts/model/transformer_tempflow/transformer_tempflow_network.py
+++ b/src/gluonts/nursery/robust-mts-attack/pts/model/transformer_tempflow/transformer_tempflow_network.py
@@ -165,7 +165,11 @@ def create_network_input(
         future_time_feat: Optional[torch.Tensor],
         future_target_cdf: Optional[torch.Tensor],
         target_dimension_indicator: torch.Tensor,
-    ) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor,]:
+    ) -> Tuple[
+        torch.Tensor,
+        torch.Tensor,
+        torch.Tensor,
+    ]:
         """
         Unrolls the RNN encoder over past and, if present, future data.
         Returns outputs and state of the encoder, plus the scale of
diff --git a/src/gluonts/nursery/robust-mts-attack/pts/modules/gaussian_diffusion.py b/src/gluonts/nursery/robust-mts-attack/pts/modules/gaussian_diffusion.py
index 2819be62a4..73195e7957 100644
--- a/src/gluonts/nursery/robust-mts-attack/pts/modules/gaussian_diffusion.py
+++ b/src/gluonts/nursery/robust-mts-attack/pts/modules/gaussian_diffusion.py
@@ -80,9 +80,7 @@ def __init__(
         if beta_schedule == "linear":
             betas = np.linspace(1e-4, beta_end, diff_steps)
         elif beta_schedule == "quad":
-            betas = (
-                np.linspace(1e-4**0.5, beta_end**0.5, diff_steps) ** 2
-            )
+            betas = np.linspace(1e-4**0.5, beta_end**0.5, diff_steps) ** 2
         elif beta_schedule == "const":
             betas = beta_end * np.ones(diff_steps)
         elif beta_schedule == "jsd":  # 1/T, 1/(T-1), 1/(T-2), ..., 1
diff --git a/src/gluonts/nursery/robust-mts-attack/utils.py b/src/gluonts/nursery/robust-mts-attack/utils.py
index 3bafea8d4d..75a36ec169 100644
--- a/src/gluonts/nursery/robust-mts-attack/utils.py
+++ b/src/gluonts/nursery/robust-mts-attack/utils.py
@@ -263,13 +263,13 @@ def calc_loss(
             if (
                 true_future_target[:, attack_idx][..., target_items] != 0
             ).prod() == 0:
-                mape[attack_type][
-                    testset_idx : testset_idx + batch_size
-                ] = np.abs(
-                    forecasts[attack_type][i][:, :, attack_idx][
-                        ..., target_items
-                    ].mean(1)
-                    - true_future_target[:, attack_idx][..., target_items]
+                mape[attack_type][testset_idx : testset_idx + batch_size] = (
+                    np.abs(
+                        forecasts[attack_type][i][:, :, attack_idx][
+                            ..., target_items
+                        ].mean(1)
+                        - true_future_target[:, attack_idx][..., target_items]
+                    )
                 )
                 mse[attack_type][testset_idx : testset_idx + batch_size] = (
                     forecasts[attack_type][i][:, :, attack_idx][
@@ -290,14 +290,14 @@ def calc_loss(
                         j, testset_idx : testset_idx + batch_size
                     ] = quantile_loss(true, pred, quantile)
             else:
-                mape[attack_type][
-                    testset_idx : testset_idx + batch_size
-                ] = np.abs(
-                    forecasts[attack_type][i][:, :, attack_idx][
-                        ..., target_items
-                    ].mean(1)
-                    / true_future_target[:, attack_idx][..., target_items]
-                    - 1
+                mape[attack_type][testset_idx : testset_idx + batch_size] = (
+                    np.abs(
+                        forecasts[attack_type][i][:, :, attack_idx][
+                            ..., target_items
+                        ].mean(1)
+                        / true_future_target[:, attack_idx][..., target_items]
+                        - 1
+                    )
                 )
                 mse[attack_type][testset_idx : testset_idx + batch_size] = (
                     mape[attack_type][testset_idx : testset_idx + batch_size]
diff --git a/src/gluonts/nursery/san/_network.py b/src/gluonts/nursery/san/_network.py
index b3d08665f2..126dbdd0c2 100644
--- a/src/gluonts/nursery/san/_network.py
+++ b/src/gluonts/nursery/san/_network.py
@@ -213,9 +213,11 @@ def _assemble_covariates(
             covariates.append(
                 feat_static_real.expand_dims(axis=1).repeat(
                     axis=1,
-                    repeats=self.context_length
-                    if is_past
-                    else self.prediction_length,
+                    repeats=(
+                        self.context_length
+                        if is_past
+                        else self.prediction_length
+                    ),
                 )
             )
         if len(covariates) > 0:
@@ -231,9 +233,11 @@ def _assemble_covariates(
             categories.append(
                 feat_static_cat.expand_dims(axis=1).repeat(
                     axis=1,
-                    repeats=self.context_length
-                    if is_past
-                    else self.prediction_length,
+                    repeats=(
+                        self.context_length
+                        if is_past
+                        else self.prediction_length
+                    ),
                 )
             )
         if len(categories) > 0:
diff --git a/src/gluonts/nursery/temporal_hierarchical_forecasting/model/cop_deepar/_estimator.py b/src/gluonts/nursery/temporal_hierarchical_forecasting/model/cop_deepar/_estimator.py
index 03bb055f00..8c8b297fa6 100644
--- a/src/gluonts/nursery/temporal_hierarchical_forecasting/model/cop_deepar/_estimator.py
+++ b/src/gluonts/nursery/temporal_hierarchical_forecasting/model/cop_deepar/_estimator.py
@@ -238,10 +238,11 @@ def __init__(

             # adapt window_length if RollingMeanValueImputation is used
             if isinstance(imputation_method, RollingMeanValueImputation):
-                base_estimator_hps_agg[
-                    "imputation_method"
-                ] = RollingMeanValueImputation(
-                    window_size=imputation_method.window_size // agg_multiple
+                base_estimator_hps_agg["imputation_method"] = (
+                    RollingMeanValueImputation(
+                        window_size=imputation_method.window_size
+                        // agg_multiple
+                    )
                 )

             # Hack to enforce correct serialization of lags_seq and history length
diff --git a/src/gluonts/nursery/temporal_hierarchical_forecasting/model/cop_deepar/gluonts_fixes.py b/src/gluonts/nursery/temporal_hierarchical_forecasting/model/cop_deepar/gluonts_fixes.py
index a8ac13bc14..b66c8f6887 100644
--- a/src/gluonts/nursery/temporal_hierarchical_forecasting/model/cop_deepar/gluonts_fixes.py
+++ b/src/gluonts/nursery/temporal_hierarchical_forecasting/model/cop_deepar/gluonts_fixes.py
@@ -52,15 +52,17 @@ def batchify_with_dict(
     is_right_pad: bool = True,
 ) -> DataBatch:
     return {
-        key: stack(
-            data=[item[key] for item in data],
-            ctx=ctx,
-            dtype=dtype,
-            variable_length=variable_length,
-            is_right_pad=is_right_pad,
+        key: (
+            stack(
+                data=[item[key] for item in data],
+                ctx=ctx,
+                dtype=dtype,
+                variable_length=variable_length,
+                is_right_pad=is_right_pad,
+            )
+            if not isinstance(data[0][key], dict)
+            else batchify_with_dict(data=[item[key] for item in data])
         )
-        if not isinstance(data[0][key], dict)
-        else batchify_with_dict(data=[item[key] for item in data])
         for key in data[0].keys()
     }

diff --git a/src/gluonts/nursery/tsbench/src/cli/analysis/scripts/ensemble.py b/src/gluonts/nursery/tsbench/src/cli/analysis/scripts/ensemble.py
index e1d03db304..8d1a17668c 100644
--- a/src/gluonts/nursery/tsbench/src/cli/analysis/scripts/ensemble.py
+++ b/src/gluonts/nursery/tsbench/src/cli/analysis/scripts/ensemble.py
@@ -66,9 +66,9 @@ def main(
         tracker,
         ensemble_size=size,
         ensemble_weighting=weighting,
-        config_class=MODEL_REGISTRY[model_class]
-        if model_class is not None
-        else None,
+        config_class=(
+            MODEL_REGISTRY[model_class] if model_class is not None else None
+        ),
     )
     df, configs = evaluator.run()

diff --git a/src/gluonts/nursery/tsbench/src/cli/evaluations/schedule.py b/src/gluonts/nursery/tsbench/src/cli/evaluations/schedule.py
index a49864cc59..80cb563da4 100644
--- a/src/gluonts/nursery/tsbench/src/cli/evaluations/schedule.py
+++ b/src/gluonts/nursery/tsbench/src/cli/evaluations/schedule.py
@@ -177,12 +177,14 @@ def job_factory() -> str:
             tags=[
                 {"Key": "Experiment", "Value": experiment},
             ],
-            instance_type="local"
-            if local
-            else (
-                configuration["__instance_type__"]
-                if "__instance_type__" in configuration
-                else instance_type
+            instance_type=(
+                "local"
+                if local
+                else (
+                    configuration["__instance_type__"]
+                    if "__instance_type__" in configuration
+                    else instance_type
+                )
             ),
             instance_count=1,
             volume_size=30,
diff --git a/src/gluonts/nursery/tsbench/src/cli/utils/config.py b/src/gluonts/nursery/tsbench/src/cli/utils/config.py
index 0ec10b3863..41425e22e7 100644
--- a/src/gluonts/nursery/tsbench/src/cli/utils/config.py
+++ b/src/gluonts/nursery/tsbench/src/cli/utils/config.py
@@ -73,14 +73,16 @@ def explode_key_values(
     into independent configurations.
     """
     all_combinations = {
-        primary: itertools.product(
-            *[
-                [(option["key"], value) for value in option["values"]]
-                for option in choices
-            ]
+        primary: (
+            itertools.product(
+                *[
+                    [(option["key"], value) for value in option["values"]]
+                    for option in choices
+                ]
+            )
+            if choices
+            else []
         )
-        if choices
-        else []
         for primary, choices in mapping.items()
     }

diff --git a/src/gluonts/nursery/tsbench/src/tsbench/recommender/_base.py b/src/gluonts/nursery/tsbench/src/tsbench/recommender/_base.py
index 7ff3c335c6..078692d442 100644
--- a/src/gluonts/nursery/tsbench/src/tsbench/recommender/_base.py
+++ b/src/gluonts/nursery/tsbench/src/tsbench/recommender/_base.py
@@ -116,9 +116,11 @@ def recommend(
         # Then, we perform a nondominated sort
         argsort = argsort_nondominated(
             df.to_numpy(),  # type: ignore
-            dim=df.columns.tolist().index(self.focus)
-            if self.focus is not None
-            else None,
+            dim=(
+                df.columns.tolist().index(self.focus)
+                if self.focus is not None
+                else None
+            ),
             max_items=max_count,
         )

diff --git a/src/gluonts/nursery/tsbench/src/tsbench/recommender/_factory.py b/src/gluonts/nursery/tsbench/src/tsbench/recommender/_factory.py
index 96cbb43b37..a84ba975ff 100644
--- a/src/gluonts/nursery/tsbench/src/tsbench/recommender/_factory.py
+++ b/src/gluonts/nursery/tsbench/src/tsbench/recommender/_factory.py
@@ -16,9 +16,9 @@
 from ._base import Recommender

 RECOMMENDER_REGISTRY: Dict[str, Type[Recommender[ModelConfig]]] = {}
-ENSEMBLE_RECOMMENDER_REGISTRY: Dict[
-    str, Type[Recommender[EnsembleConfig]]
-] = {}
+ENSEMBLE_RECOMMENDER_REGISTRY: Dict[str, Type[Recommender[EnsembleConfig]]] = (
+    {}
+)

 R = TypeVar("R", bound=Type[Recommender[ModelConfig]])
 E = TypeVar("E", bound=Type[Recommender[EnsembleConfig]])
diff --git a/src/gluonts/time_feature/lag.py b/src/gluonts/time_feature/lag.py
index 9abadb3a7f..d93fee8547 100644
--- a/src/gluonts/time_feature/lag.py
+++ b/src/gluonts/time_feature/lag.py
@@ -146,7 +146,9 @@ def _make_lags_for_month(multiple, num_cycles=3):
             + _make_lags_for_hour(offset.n / (60 * 60))
         )
     else:
-        raise ValueError(f"invalid frequency | `freq_str={freq_str}` -> `offset_name={offset_name}`")
+        raise ValueError(
+            f"invalid frequency | `freq_str={freq_str}` -> `offset_name={offset_name}`"
+        )

     # flatten lags list and filter
     lags = [
diff --git a/src/gluonts/torch/distributions/binned_uniforms.py b/src/gluonts/torch/distributions/binned_uniforms.py
index d8672e8028..46a8a5e586 100644
--- a/src/gluonts/torch/distributions/binned_uniforms.py
+++ b/src/gluonts/torch/distributions/binned_uniforms.py
@@ -35,6 +35,7 @@ class BinnedUniforms(Distribution):
         These are softmaxed. The tensor is of shape (*batch_shape,)
     validate_args (bool) from the pytorch Distribution class
     """
+
     arg_constraints = {"logits": constraints.real}
     support = constraints.real
     has_rsample = False
diff --git a/src/gluonts/torch/distributions/generalized_pareto.py b/src/gluonts/torch/distributions/generalized_pareto.py
index a35a2ba1dc..9255474ad2 100644
--- a/src/gluonts/torch/distributions/generalized_pareto.py
+++ b/src/gluonts/torch/distributions/generalized_pareto.py
@@ -37,6 +37,7 @@ class GeneralizedPareto(Distribution):
         Tensor containing the beta scale parameters. The tensor is of
         shape (*batch_shape, 1)
     """
+
     arg_constraints = {
         "xi": constraints.positive,
         "beta": constraints.positive,
diff --git a/src/gluonts/torch/distributions/spliced_binned_pareto.py b/src/gluonts/torch/distributions/spliced_binned_pareto.py
index 93eef5fa3c..94e36a56cf 100644
--- a/src/gluonts/torch/distributions/spliced_binned_pareto.py
+++ b/src/gluonts/torch/distributions/spliced_binned_pareto.py
@@ -36,6 +36,7 @@ class SplicedBinnedPareto(BinnedUniforms):
         each tail. Default value is 0.05.
         NB: This symmetric percentile can still represent asymmetric upper and
         lower tails.
     """
+
     arg_constraints = {
         "logits": constraints.real,
         "lower_gp_xi": constraints.positive,
diff --git a/src/gluonts/torch/model/deep_npts/_estimator.py b/src/gluonts/torch/model/deep_npts/_estimator.py
index 8d396d7c96..f3cca87da1 100755
--- a/src/gluonts/torch/model/deep_npts/_estimator.py
+++ b/src/gluonts/torch/model/deep_npts/_estimator.py
@@ -343,9 +343,11 @@ def train_model(
         )

         data_loader = self.training_data_loader(
-            transformed_dataset
-            if not cache_data
-            else Cached(transformed_dataset),
+            (
+                transformed_dataset
+                if not cache_data
+                else Cached(transformed_dataset)
+            ),
             batch_size=self.batch_size,
             num_batches_per_epoch=self.num_batches_per_epoch,
         )
diff --git a/src/gluonts/torch/model/deepar/module.py b/src/gluonts/torch/model/deepar/module.py
index 406468079c..ebe4ef2479 100644
--- a/src/gluonts/torch/model/deepar/module.py
+++ b/src/gluonts/torch/model/deepar/module.py
@@ -221,7 +221,11 @@ def prepare_rnn_input(
         past_observed_values: torch.Tensor,
         future_time_feat: torch.Tensor,
         future_target: Optional[torch.Tensor] = None,
-    ) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor,]:
+    ) -> Tuple[
+        torch.Tensor,
+        torch.Tensor,
+        torch.Tensor,
+    ]:
         context = past_target[..., -self.context_length :]
         observed_context = past_observed_values[..., -self.context_length :]

diff --git a/src/gluonts/torch/model/tft/module.py b/src/gluonts/torch/model/tft/module.py
index f0dd7c00a1..f113160cb7 100644
--- a/src/gluonts/torch/model/tft/module.py
+++ b/src/gluonts/torch/model/tft/module.py
@@ -102,65 +102,66 @@ def __init__(
         self.target_proj = nn.Linear(in_features=1, out_features=self.d_var)
         # Past-only dynamic features
         if self.d_past_feat_dynamic_real:
-            self.past_feat_dynamic_proj: Optional[
-                FeatureProjector
-            ] = FeatureProjector(
-                feature_dims=self.d_past_feat_dynamic_real,
-                embedding_dims=[self.d_var]
-                * len(self.d_past_feat_dynamic_real),
+            self.past_feat_dynamic_proj: Optional[FeatureProjector] = (
+                FeatureProjector(
+                    feature_dims=self.d_past_feat_dynamic_real,
+                    embedding_dims=[self.d_var]
+                    * len(self.d_past_feat_dynamic_real),
+                )
             )
         else:
             self.past_feat_dynamic_proj = None

         if self.c_past_feat_dynamic_cat:
-            self.past_feat_dynamic_embed: Optional[
-                FeatureEmbedder
-            ] = FeatureEmbedder(
-                cardinalities=self.c_past_feat_dynamic_cat,
-                embedding_dims=[self.d_var]
-                * len(self.c_past_feat_dynamic_cat),
+            self.past_feat_dynamic_embed: Optional[FeatureEmbedder] = (
+                FeatureEmbedder(
+                    cardinalities=self.c_past_feat_dynamic_cat,
+                    embedding_dims=[self.d_var]
+                    * len(self.c_past_feat_dynamic_cat),
+                )
             )
         else:
             self.past_feat_dynamic_embed = None

         # Known dynamic features
         if self.d_feat_dynamic_real:
-            self.feat_dynamic_proj: Optional[
-                FeatureProjector
-            ] = FeatureProjector(
-                feature_dims=self.d_feat_dynamic_real,
-                embedding_dims=[self.d_var] * len(self.d_feat_dynamic_real),
+            self.feat_dynamic_proj: Optional[FeatureProjector] = (
+                FeatureProjector(
+                    feature_dims=self.d_feat_dynamic_real,
+                    embedding_dims=[self.d_var]
+                    * len(self.d_feat_dynamic_real),
+                )
             )
         else:
             self.feat_dynamic_proj = None

         if self.c_feat_dynamic_cat:
-            self.feat_dynamic_embed: Optional[
-                FeatureEmbedder
-            ] = FeatureEmbedder(
-                cardinalities=self.c_feat_dynamic_cat,
-                embedding_dims=[self.d_var] * len(self.c_feat_dynamic_cat),
+            self.feat_dynamic_embed: Optional[FeatureEmbedder] = (
+                FeatureEmbedder(
+                    cardinalities=self.c_feat_dynamic_cat,
+                    embedding_dims=[self.d_var] * len(self.c_feat_dynamic_cat),
+                )
             )
         else:
             self.feat_dynamic_embed = None

         # Static features
         if self.d_feat_static_real:
-            self.feat_static_proj: Optional[
-                FeatureProjector
-            ] = FeatureProjector(
-                feature_dims=self.d_feat_static_real,
-                embedding_dims=[self.d_var] * len(self.d_feat_static_real),
+            self.feat_static_proj: Optional[FeatureProjector] = (
+                FeatureProjector(
+                    feature_dims=self.d_feat_static_real,
+                    embedding_dims=[self.d_var] * len(self.d_feat_static_real),
+                )
             )
         else:
             self.feat_static_proj = None

         if self.c_feat_static_cat:
-            self.feat_static_embed: Optional[
-                FeatureEmbedder
-            ] = FeatureEmbedder(
-                cardinalities=self.c_feat_static_cat,
-                embedding_dims=[self.d_var] * len(self.c_feat_static_cat),
+            self.feat_static_embed: Optional[FeatureEmbedder] = (
+                FeatureEmbedder(
+                    cardinalities=self.c_feat_static_cat,
+                    embedding_dims=[self.d_var] * len(self.c_feat_static_cat),
+                )
             )
         else:
             self.feat_static_embed = None
diff --git a/src/gluonts/torch/model/wavenet/estimator.py b/src/gluonts/torch/model/wavenet/estimator.py
index 2da9e87d91..234ecff237 100644
--- a/src/gluonts/torch/model/wavenet/estimator.py
+++ b/src/gluonts/torch/model/wavenet/estimator.py
@@ -267,12 +267,18 @@ def create_transformation(self) -> Transformation:
         return Chain(
             [
                 RemoveFields(field_names=remove_field_names),
-                SetField(output_field=FieldName.FEAT_STATIC_CAT, value=[0])
-                if self.num_feat_static_cat == 0
-                else Identity(),
-                SetField(output_field=FieldName.FEAT_STATIC_REAL, value=[0.0])
-                if self.num_feat_static_real == 0
-                else Identity(),
+                (
+                    SetField(output_field=FieldName.FEAT_STATIC_CAT, value=[0])
+                    if self.num_feat_static_cat == 0
+                    else Identity()
+                ),
+                (
+                    SetField(
+                        output_field=FieldName.FEAT_STATIC_REAL, value=[0.0]
+                    )
+                    if self.num_feat_static_real == 0
+                    else Identity()
+                ),
                 AsNumpyArray(
                     field=FieldName.FEAT_STATIC_CAT, expected_ndim=1, dtype=int
                 ),
diff --git a/src/gluonts/transform/feature.py b/src/gluonts/transform/feature.py
index ff9f7bdd5e..f5f519d5a7 100644
--- a/src/gluonts/transform/feature.py
+++ b/src/gluonts/transform/feature.py
@@ -529,11 +529,11 @@ def map_transform(self, data: DataEntry, is_train: bool) -> DataEntry:
         lags = np.vstack(
             [
                 agg_vals[
-                    -(l * self.ratio - self.half_window + len(t)) : -(
-                        l * self.ratio - self.half_window
+                    -(l * self.ratio - self.half_window + len(t)) : (
+                        -(l * self.ratio - self.half_window)
+                        if -(l * self.ratio - self.half_window) != 0
+                        else None
                     )
-                    if -(l * self.ratio - self.half_window) != 0
-                    else None
                 ]
                 for l in self.valid_lags
             ]
diff --git a/src/gluonts/zebras/_period.py b/src/gluonts/zebras/_period.py
index 66f32e311b..4cda3178a9 100644
--- a/src/gluonts/zebras/_period.py
+++ b/src/gluonts/zebras/_period.py
@@ -330,12 +330,10 @@ def __len__(self):
         return len(self.data)

     @overload
-    def __getitem__(self, idx: int) -> Period:
-        ...
+    def __getitem__(self, idx: int) -> Period: ...

     @overload
-    def __getitem__(self, idx: slice) -> Periods:
-        ...
+    def __getitem__(self, idx: slice) -> Periods: ...

     def __getitem__(self, idx):
         if _is_number(idx):
diff --git a/test/mx/model/seq2seq/test_forking_sequence_splitter.py b/test/mx/model/seq2seq/test_forking_sequence_splitter.py
index 6f9076d5d5..00d1df560f 100644
--- a/test/mx/model/seq2seq/test_forking_sequence_splitter.py
+++ b/test/mx/model/seq2seq/test_forking_sequence_splitter.py
@@ -125,9 +125,11 @@ def make_dataset(N, train_length):
                 pred_length=10,
             ),
             ForkingSequenceSplitter(
-                instance_sampler=ValidationSplitSampler(min_future=dec_len)
-                if is_train
-                else TSplitSampler(),
+                instance_sampler=(
+                    ValidationSplitSampler(min_future=dec_len)
+                    if is_train
+                    else TSplitSampler()
+                ),
                 enc_len=enc_len,
                 dec_len=dec_len,
                 num_forking=num_forking,