diff --git a/src/gluonts/core/component.py b/src/gluonts/core/component.py
index 8fcdc07f8f..c5f18d011a 100644
--- a/src/gluonts/core/component.py
+++ b/src/gluonts/core/component.py
@@ -296,12 +296,16 @@ def validator(init):
         init_params = inspect.signature(init).parameters
         init_fields = {
             param.name: (
-                param.annotation
-                if param.annotation != inspect.Parameter.empty
-                else Any,
-                param.default
-                if param.default != inspect.Parameter.empty
-                else ...,
+                (
+                    param.annotation
+                    if param.annotation != inspect.Parameter.empty
+                    else Any
+                ),
+                (
+                    param.default
+                    if param.default != inspect.Parameter.empty
+                    else ...
+                ),
             )
             for param in init_params.values()
             if param.name != "self"
diff --git a/src/gluonts/dataset/arrow/file.py b/src/gluonts/dataset/arrow/file.py
index 40f1a06638..7bdb6cf898 100644
--- a/src/gluonts/dataset/arrow/file.py
+++ b/src/gluonts/dataset/arrow/file.py
@@ -51,16 +51,13 @@ def infer(
             return ArrowStreamFile(path)
 
     @abc.abstractmethod
-    def metadata(self) -> Dict[str, str]:
-        ...
+    def metadata(self) -> Dict[str, str]: ...
 
     @abc.abstractmethod
-    def __iter__(self):
-        ...
+    def __iter__(self): ...
 
     @abc.abstractmethod
-    def __len__(self):
-        ...
+    def __len__(self): ...
 
 
 @dataclass
diff --git a/src/gluonts/dataset/repository/_lstnet.py b/src/gluonts/dataset/repository/_lstnet.py
index ba46cf03ae..e933666c77 100644
--- a/src/gluonts/dataset/repository/_lstnet.py
+++ b/src/gluonts/dataset/repository/_lstnet.py
@@ -200,9 +200,9 @@ def generate_lstnet_dataset(
     meta = MetaData(
         **metadata(
             cardinality=ds_info.num_series,
-            freq=ds_info.freq
-            if ds_info.agg_freq is None
-            else ds_info.agg_freq,
+            freq=(
+                ds_info.freq if ds_info.agg_freq is None else ds_info.agg_freq
+            ),
             prediction_length=prediction_length or ds_info.prediction_length,
         )
     )
diff --git a/src/gluonts/dataset/stat.py b/src/gluonts/dataset/stat.py
index 0becdb4bc4..c3b9a39dfb 100644
--- a/src/gluonts/dataset/stat.py
+++ b/src/gluonts/dataset/stat.py
@@ -418,12 +418,12 @@ def calculate_dataset_statistics(ts_dataset: Any) -> DatasetStatistics:
         max_target_length=max_target_length,
         min_target=min_target,
         num_missing_values=num_missing_values,
-        feat_static_real=observed_feat_static_real
-        if observed_feat_static_real
-        else [],
-        feat_static_cat=observed_feat_static_cat
-        if observed_feat_static_cat
-        else [],
+        feat_static_real=(
+            observed_feat_static_real if observed_feat_static_real else []
+        ),
+        feat_static_cat=(
+            observed_feat_static_cat if observed_feat_static_cat else []
+        ),
         num_past_feat_dynamic_real=num_past_feat_dynamic_real,
         num_feat_dynamic_real=num_feat_dynamic_real,
         num_feat_dynamic_cat=num_feat_dynamic_cat,
diff --git a/src/gluonts/evaluation/_base.py b/src/gluonts/evaluation/_base.py
index f913c7d8fd..b623bf3d75 100644
--- a/src/gluonts/evaluation/_base.py
+++ b/src/gluonts/evaluation/_base.py
@@ -378,9 +378,9 @@ def get_base_metrics(
         return {
             "item_id": forecast.item_id,
             "forecast_start": forecast.start_date,
-            "MSE": mse(pred_target, mean_fcst)
-            if mean_fcst is not None
-            else None,
+            "MSE": (
+                mse(pred_target, mean_fcst) if mean_fcst is not None else None
+            ),
             "abs_error": abs_error(pred_target, median_fcst),
             "abs_target_sum": abs_target_sum(pred_target),
             "abs_target_mean": abs_target_mean(pred_target),
diff --git a/src/gluonts/ext/rotbaum/_model.py b/src/gluonts/ext/rotbaum/_model.py
index 29bd1e3cf7..3c0270a49a 100644
--- a/src/gluonts/ext/rotbaum/_model.py
+++ b/src/gluonts/ext/rotbaum/_model.py
@@ -340,11 +340,11 @@ def _get_and_cache_quantile_computation(
             The quantile of the associated true value bin.
         """
         if feature_vector_in_train not in self.quantile_dicts[quantile]:
-            self.quantile_dicts[quantile][
-                feature_vector_in_train
-            ] = np.percentile(
-                self.id_to_bins[self.preds_to_id[feature_vector_in_train]],
-                quantile * 100,
+            self.quantile_dicts[quantile][feature_vector_in_train] = (
+                np.percentile(
+                    self.id_to_bins[self.preds_to_id[feature_vector_in_train]],
+                    quantile * 100,
+                )
             )
         return self.quantile_dicts[quantile][feature_vector_in_train]
 
diff --git a/src/gluonts/ext/rotbaum/_predictor.py b/src/gluonts/ext/rotbaum/_predictor.py
index 63daea1a97..6631e8dde0 100644
--- a/src/gluonts/ext/rotbaum/_predictor.py
+++ b/src/gluonts/ext/rotbaum/_predictor.py
@@ -254,13 +254,17 @@ def train(
                 target_data[:, train_QRX_only_using_timestep],
             )
             self.model_list = [
-                QRX(
-                    xgboost_params=self.model_params,
-                    min_bin_size=self.min_bin_size,
-                    model=self.model_list[train_QRX_only_using_timestep].model,
+                (
+                    QRX(
+                        xgboost_params=self.model_params,
+                        min_bin_size=self.min_bin_size,
+                        model=self.model_list[
+                            train_QRX_only_using_timestep
+                        ].model,
+                    )
+                    if i != train_QRX_only_using_timestep
+                    else self.model_list[i]
                 )
-                if i != train_QRX_only_using_timestep
-                else self.model_list[i]
                 for i in range(n_models)
             ]
             with concurrent.futures.ThreadPoolExecutor(
diff --git a/src/gluonts/itertools.py b/src/gluonts/itertools.py
index e90281a1d8..198559d6f6 100644
--- a/src/gluonts/itertools.py
+++ b/src/gluonts/itertools.py
@@ -42,11 +42,9 @@
 
 @runtime_checkable
 class SizedIterable(Protocol):
-    def __len__(self):
-        ...
+    def __len__(self): ...
 
-    def __iter__(self):
-        ...
+    def __iter__(self): ...
 
 
 T = TypeVar("T")
diff --git a/src/gluonts/model/forecast_generator.py b/src/gluonts/model/forecast_generator.py
index ac33aae158..d66a1d361d 100644
--- a/src/gluonts/model/forecast_generator.py
+++ b/src/gluonts/model/forecast_generator.py
@@ -132,9 +132,11 @@ def __call__(
                 yield QuantileForecast(
                     output.T,
                     start_date=batch[FieldName.FORECAST_START][i],
-                    item_id=batch[FieldName.ITEM_ID][i]
-                    if FieldName.ITEM_ID in batch
-                    else None,
+                    item_id=(
+                        batch[FieldName.ITEM_ID][i]
+                        if FieldName.ITEM_ID in batch
+                        else None
+                    ),
                     info=batch["info"][i] if "info" in batch else None,
                     forecast_keys=self.quantiles,
                 )
@@ -181,9 +183,11 @@ def __call__(
                 yield SampleForecast(
                     output,
                     start_date=batch[FieldName.FORECAST_START][i],
-                    item_id=batch[FieldName.ITEM_ID][i]
-                    if FieldName.ITEM_ID in batch
-                    else None,
+                    item_id=(
+                        batch[FieldName.ITEM_ID][i]
+                        if FieldName.ITEM_ID in batch
+                        else None
+                    ),
                     info=batch["info"][i] if "info" in batch else None,
                 )
             assert i + 1 == len(batch[FieldName.FORECAST_START])
@@ -221,9 +225,11 @@ def __call__(
                 yield make_distribution_forecast(
                     distr,
                     start_date=batch[FieldName.FORECAST_START][i],
-                    item_id=batch[FieldName.ITEM_ID][i]
-                    if FieldName.ITEM_ID in batch
-                    else None,
+                    item_id=(
+                        batch[FieldName.ITEM_ID][i]
+                        if FieldName.ITEM_ID in batch
+                        else None
+                    ),
                     info=batch["info"][i] if "info" in batch else None,
                 )
             assert i + 1 == len(batch[FieldName.FORECAST_START])
diff --git a/src/gluonts/mx/batchify.py b/src/gluonts/mx/batchify.py
index ea1de2adb0..b534b7f523 100644
--- a/src/gluonts/mx/batchify.py
+++ b/src/gluonts/mx/batchify.py
@@ -117,13 +117,16 @@ def as_in_context(batch: dict, ctx: mx.Context = None) -> DataBatch:
     Move data into new context, should only be in main process.
     """
     batch = {
-        k: v.as_in_context(ctx) if isinstance(v, mx.nd.NDArray)
-        # Workaround due to MXNet not being able to handle NDArrays with 0 in
-        # shape properly:
-        else (
-            stack(v, ctx=ctx, dtype=v.dtype, variable_length=False)
-            if isinstance(v[0], np.ndarray) and 0 in v[0].shape
-            else v
+        k: (
+            v.as_in_context(ctx)
+            if isinstance(v, mx.nd.NDArray)
+            # Workaround due to MXNet not being able to handle NDArrays with 0 in
+            # shape properly:
+            else (
+                stack(v, ctx=ctx, dtype=v.dtype, variable_length=False)
+                if isinstance(v[0], np.ndarray) and 0 in v[0].shape
+                else v
+            )
         )
         for k, v in batch.items()
     }
diff --git a/src/gluonts/mx/block/dropout.py b/src/gluonts/mx/block/dropout.py
index e1de5e22e0..2522820a2c 100644
--- a/src/gluonts/mx/block/dropout.py
+++ b/src/gluonts/mx/block/dropout.py
@@ -231,9 +231,11 @@ def mask(p, like):
         # mask as output, instead of simply copy output to the first element
         # in case that the base cell is ResidualCell
         new_states = [
-            F.where(output_mask, next_states[0], states[0])
-            if p_outputs != 0.0
-            else next_states[0]
+            (
+                F.where(output_mask, next_states[0], states[0])
+                if p_outputs != 0.0
+                else next_states[0]
+            )
         ]
         new_states.extend(
             [
diff --git a/src/gluonts/mx/distribution/box_cox_transform.py b/src/gluonts/mx/distribution/box_cox_transform.py
index 46d9afe87a..99125a032a 100644
--- a/src/gluonts/mx/distribution/box_cox_transform.py
+++ b/src/gluonts/mx/distribution/box_cox_transform.py
@@ -116,6 +116,7 @@ class BoxCoxTransform(Bijection):
         `tol_lambda_1`
     F
     """
+
     arg_names = ["box_cox.lambda_1", "box_cox.lambda_2"]
 
     @validated()
diff --git a/src/gluonts/mx/distribution/inflated_beta.py b/src/gluonts/mx/distribution/inflated_beta.py
index 345126a291..93c0781893 100644
--- a/src/gluonts/mx/distribution/inflated_beta.py
+++ b/src/gluonts/mx/distribution/inflated_beta.py
@@ -113,6 +113,7 @@ class ZeroInflatedBeta(ZeroAndOneInflatedBeta):
         `(*batch_shape, *event_shape)`.
     F
     """
+
     is_reparameterizable = False
 
     @validated()
@@ -145,6 +146,7 @@ class OneInflatedBeta(ZeroAndOneInflatedBeta):
         `(*batch_shape, *event_shape)`.
     F
     """
+
     is_reparameterizable = False
 
     @validated()
diff --git a/src/gluonts/mx/distribution/lds.py b/src/gluonts/mx/distribution/lds.py
index 6167a00d1e..31d3ae225a 100644
--- a/src/gluonts/mx/distribution/lds.py
+++ b/src/gluonts/mx/distribution/lds.py
@@ -409,9 +409,11 @@ def sample(
             #   (num_samples, batch_size, latent_dim, latent_dim)
             # innovation_coeff_t: (num_samples, batch_size, 1, latent_dim)
             emission_coeff_t, transition_coeff_t, innovation_coeff_t = (
-                _broadcast_param(coeff, axes=[0], sizes=[num_samples])
-                if num_samples is not None
-                else coeff
+                (
+                    _broadcast_param(coeff, axes=[0], sizes=[num_samples])
+                    if num_samples is not None
+                    else coeff
+                )
                 for coeff in [
                     self.emission_coeff[t],
                     self.transition_coeff[t],
@@ -458,9 +460,11 @@ def sample(
             if scale is None
             else F.broadcast_mul(
                 samples,
-                scale.expand_dims(axis=1).expand_dims(axis=0)
-                if num_samples is not None
-                else scale.expand_dims(axis=1),
+                (
+                    scale.expand_dims(axis=1).expand_dims(axis=0)
+                    if num_samples is not None
+                    else scale.expand_dims(axis=1)
+                ),
             )
         )
 
diff --git a/src/gluonts/mx/distribution/lowrank_gp.py b/src/gluonts/mx/distribution/lowrank_gp.py
index 443cc40acb..2e6de964a8 100644
--- a/src/gluonts/mx/distribution/lowrank_gp.py
+++ b/src/gluonts/mx/distribution/lowrank_gp.py
@@ -101,9 +101,7 @@ def hybrid_forward(self, F, x: Tensor) -> Tuple[Tensor, Tensor, Tensor]:
         D_vector = self.proj[1](x_plus_w).squeeze(axis=-1)
 
         d_bias = (
-            0.0
-            if self.sigma_init == 0.0
-            else inv_softplus(self.sigma_init**2)
+            0.0 if self.sigma_init == 0.0 else inv_softplus(self.sigma_init**2)
         )
 
         D_positive = (
diff --git a/src/gluonts/mx/distribution/lowrank_multivariate_gaussian.py b/src/gluonts/mx/distribution/lowrank_multivariate_gaussian.py
index b8131178a5..f446799367 100644
--- a/src/gluonts/mx/distribution/lowrank_multivariate_gaussian.py
+++ b/src/gluonts/mx/distribution/lowrank_multivariate_gaussian.py
@@ -408,9 +408,7 @@ def domain_map(self, F, mu_vector, D_vector, W_vector=None):
         """
 
         d_bias = (
-            inv_softplus(self.sigma_init**2)
-            if self.sigma_init > 0.0
-            else 0.0
+            inv_softplus(self.sigma_init**2) if self.sigma_init > 0.0 else 0.0
         )
 
         # sigma_minimum helps avoiding cholesky problems, we could also jitter
diff --git a/src/gluonts/mx/model/deepar/_network.py b/src/gluonts/mx/model/deepar/_network.py
index f59bc517f9..2b2049afd2 100644
--- a/src/gluonts/mx/model/deepar/_network.py
+++ b/src/gluonts/mx/model/deepar/_network.py
@@ -572,9 +572,11 @@ def unroll_encoder_imputation(
         static_feat = F.concat(
             embedded_cat,
             feat_static_real,
-            F.log(scale)
-            if len(self.target_shape) == 0
-            else F.log(scale.squeeze(axis=1)),
+            (
+                F.log(scale)
+                if len(self.target_shape) == 0
+                else F.log(scale.squeeze(axis=1))
+            ),
             dim=1,
         )
 
@@ -603,9 +605,9 @@ def unroll_encoder_imputation(
         begin_state = self.rnn.begin_state(
             func=F.zeros,
             dtype=self.dtype,
-            batch_size=inputs.shape[0]
-            if isinstance(inputs, mx.nd.NDArray)
-            else 0,
+            batch_size=(
+                inputs.shape[0] if isinstance(inputs, mx.nd.NDArray) else 0
+            ),
         )
 
         unroll_results = self.imputation_rnn_unroll(
@@ -726,9 +728,11 @@ def unroll_encoder_default(
         static_feat = F.concat(
             embedded_cat,
             feat_static_real,
-            F.log(scale)
-            if len(self.target_shape) == 0
-            else F.log(scale.squeeze(axis=1)),
+            (
+                F.log(scale)
+                if len(self.target_shape) == 0
+                else F.log(scale.squeeze(axis=1))
+            ),
             dim=1,
         )
 
@@ -757,9 +761,9 @@ def unroll_encoder_default(
         begin_state = self.rnn.begin_state(
             func=F.zeros,
             dtype=self.dtype,
-            batch_size=inputs.shape[0]
-            if isinstance(inputs, mx.nd.NDArray)
-            else 0,
+            batch_size=(
+                inputs.shape[0] if isinstance(inputs, mx.nd.NDArray) else 0
+            ),
         )
         state = begin_state
         # This is a dummy computation to avoid deferred initialization error
diff --git a/src/gluonts/mx/model/deepvar/_estimator.py b/src/gluonts/mx/model/deepvar/_estimator.py
index 0a1828369d..9881aa5df9 100644
--- a/src/gluonts/mx/model/deepvar/_estimator.py
+++ b/src/gluonts/mx/model/deepvar/_estimator.py
@@ -304,9 +304,9 @@ def __init__(
         self.scaling = scaling
 
         if self.use_marginal_transformation:
-            self.output_transform: Optional[
-                Callable
-            ] = cdf_to_gaussian_forward_transform
+            self.output_transform: Optional[Callable] = (
+                cdf_to_gaussian_forward_transform
+            )
         else:
             self.output_transform = None
 
diff --git a/src/gluonts/mx/model/estimator.py b/src/gluonts/mx/model/estimator.py
index 1cf992a796..122e34fad3 100644
--- a/src/gluonts/mx/model/estimator.py
+++ b/src/gluonts/mx/model/estimator.py
@@ -177,9 +177,9 @@ def train_model(
         transformation = self.create_transformation()
 
         with env._let(max_idle_transforms=max(len(training_data), 100)):
-            transformed_training_data: Union[
-                TransformedDataset, Cached
-            ] = transformation.apply(training_data)
+            transformed_training_data: Union[TransformedDataset, Cached] = (
+                transformation.apply(training_data)
+            )
             if cache_data:
                 transformed_training_data = Cached(transformed_training_data)
 
diff --git a/src/gluonts/mx/model/gpvar/_estimator.py b/src/gluonts/mx/model/gpvar/_estimator.py
index 58edcf3982..29aea6ac31 100644
--- a/src/gluonts/mx/model/gpvar/_estimator.py
+++ b/src/gluonts/mx/model/gpvar/_estimator.py
@@ -63,7 +63,6 @@
 
 
 class GPVAREstimator(GluonEstimator):
-
     """
     Constructs a GPVAR estimator.
 
diff --git a/src/gluonts/mx/model/gpvar/_network.py b/src/gluonts/mx/model/gpvar/_network.py
index 03691aeeec..1a8a60a8b8 100644
--- a/src/gluonts/mx/model/gpvar/_network.py
+++ b/src/gluonts/mx/model/gpvar/_network.py
@@ -116,9 +116,9 @@ def unroll(
                 length=unroll_length,
                 layout="NTC",
                 merge_outputs=True,
-                begin_state=begin_state[i]
-                if begin_state is not None
-                else None,
+                begin_state=(
+                    begin_state[i] if begin_state is not None else None
+                ),
             )
             outputs.append(outputs_single_dim)
             states.append(state)
diff --git a/src/gluonts/mx/model/n_beats/_ensemble.py b/src/gluonts/mx/model/n_beats/_ensemble.py
index f1f8f8eb81..4351163251 100644
--- a/src/gluonts/mx/model/n_beats/_ensemble.py
+++ b/src/gluonts/mx/model/n_beats/_ensemble.py
@@ -201,9 +201,11 @@ def predict(
             yield SampleForecast(
                 output,
                 start_date=start_date,
-                item_id=item[FieldName.ITEM_ID]
-                if FieldName.ITEM_ID in item
-                else None,
+                item_id=(
+                    item[FieldName.ITEM_ID]
+                    if FieldName.ITEM_ID in item
+                    else None
+                ),
                 info=item["info"] if "info" in item else None,
             )
 
diff --git a/src/gluonts/mx/model/renewal/_network.py b/src/gluonts/mx/model/renewal/_network.py
index aa83273506..d0d7bc0edd 100644
--- a/src/gluonts/mx/model/renewal/_network.py
+++ b/src/gluonts/mx/model/renewal/_network.py
@@ -89,9 +89,11 @@ def distribution(
         cond_interval, cond_size = F.split(cond_mean, num_outputs=2, axis=-1)
 
         alpha_biases = [
-            F.broadcast_mul(F.ones_like(cond_interval), bias)
-            if bias is not None
-            else None
+            (
+                F.broadcast_mul(F.ones_like(cond_interval), bias)
+                if bias is not None
+                else None
+            )
             for bias in [interval_alpha_bias, size_alpha_bias]
         ]
 
diff --git a/src/gluonts/mx/model/renewal/_transform.py b/src/gluonts/mx/model/renewal/_transform.py
index 479c6a7f83..884ec31051 100644
--- a/src/gluonts/mx/model/renewal/_transform.py
+++ b/src/gluonts/mx/model/renewal/_transform.py
@@ -53,9 +53,11 @@ def transform(self, data: DataEntry) -> DataEntry:
         target = data[self.target_field]
         data[self.output_field] = np.array(
             [
-                len(target)
-                if isinstance(target, list)
-                else target.shape[self.axis]
+                (
+                    len(target)
+                    if isinstance(target, list)
+                    else target.shape[self.axis]
+                )
             ]
         )
         return data
diff --git a/src/gluonts/mx/model/seq2seq/_transform.py b/src/gluonts/mx/model/seq2seq/_transform.py
index 4134e8b01b..7c443e82d7 100644
--- a/src/gluonts/mx/model/seq2seq/_transform.py
+++ b/src/gluonts/mx/model/seq2seq/_transform.py
@@ -180,21 +180,21 @@ def flatmap_transform(
                         # (Fortran) ordering with strides =
                         # (dtype, dtype*n_rows)
                         stride = decoder_fields.strides
-                        out[self._future(ts_field)][
-                            pad_length_dec:
-                        ] = as_strided(
-                            decoder_fields,
-                            shape=(
-                                self.num_forking - pad_length_dec,
-                                self.dec_len,
-                                ts_len,
-                            ),
-                            # strides for 2D array expanded to 3D array of
-                            # shape (dim1, dim2, dim3) =(1, n_rows, n_cols).
-                            # For transposed data, strides = (dtype, dtype *
-                            # dim1, dtype*dim1*dim2) = (dtype, dtype,
-                            # dtype*n_rows).
-                            strides=stride[0:1] + stride,
+                        out[self._future(ts_field)][pad_length_dec:] = (
+                            as_strided(
+                                decoder_fields,
+                                shape=(
+                                    self.num_forking - pad_length_dec,
+                                    self.dec_len,
+                                    ts_len,
+                                ),
+                                # strides for 2D array expanded to 3D array of
+                                # shape (dim1, dim2, dim3) =(1, n_rows, n_cols).
+                                # For transposed data, strides = (dtype, dtype *
+                                # dim1, dtype*dim1*dim2) = (dtype, dtype,
+                                # dtype*n_rows).
+                                strides=stride[0:1] + stride,
+                            )
                         )
 
                     # edge case for prediction_length = 1
diff --git a/src/gluonts/mx/model/tft/_estimator.py b/src/gluonts/mx/model/tft/_estimator.py
index db19544603..a4d8335719 100644
--- a/src/gluonts/mx/model/tft/_estimator.py
+++ b/src/gluonts/mx/model/tft/_estimator.py
@@ -172,13 +172,13 @@ def __init__(
         self.past_dynamic_feature_dims = {}
         for name in self.past_dynamic_features:
             if name in self.dynamic_cardinalities:
-                self.past_dynamic_cardinalities[
-                    name
-                ] = self.dynamic_cardinalities.pop(name)
+                self.past_dynamic_cardinalities[name] = (
+                    self.dynamic_cardinalities.pop(name)
+                )
             elif name in self.dynamic_feature_dims:
-                self.past_dynamic_feature_dims[
-                    name
-                ] = self.dynamic_feature_dims.pop(name)
+                self.past_dynamic_feature_dims[name] = (
+                    self.dynamic_feature_dims.pop(name)
+                )
             else:
                 raise ValueError(
                     f"Feature name {name} is not provided in feature dicts"
diff --git a/src/gluonts/mx/model/tpp/distribution/base.py b/src/gluonts/mx/model/tpp/distribution/base.py
index a4d93c88aa..990d57d880 100644
--- a/src/gluonts/mx/model/tpp/distribution/base.py
+++ b/src/gluonts/mx/model/tpp/distribution/base.py
@@ -187,6 +187,7 @@ class TPPDistributionOutput(DistributionOutput):
     1. Location param cannot be specified (all distributions must start at 0).
     2. The return type is either TPPDistribution or TPPTransformedDistribution.
     """
+
     distr_cls: type
 
     def distribution(
diff --git a/src/gluonts/mx/model/tpp/distribution/weibull.py b/src/gluonts/mx/model/tpp/distribution/weibull.py
index 3055d47c68..09b73b51f6 100644
--- a/src/gluonts/mx/model/tpp/distribution/weibull.py
+++ b/src/gluonts/mx/model/tpp/distribution/weibull.py
@@ -34,6 +34,7 @@ class Weibull(TPPDistribution):
     parameter :math:`\lambda > 0` and the shape parameter :math:`k > 0`, and
     :math:`\lambda = b^{-1/k}`.
     """
+
     is_reparametrizable = True
 
     @validated()
diff --git a/src/gluonts/mx/model/tpp/predictor.py b/src/gluonts/mx/model/tpp/predictor.py
index cb706931e1..159427651e 100644
--- a/src/gluonts/mx/model/tpp/predictor.py
+++ b/src/gluonts/mx/model/tpp/predictor.py
@@ -86,9 +86,9 @@ def __call__(  # type: ignore
                     start_date=batch["forecast_start"][i],
                     freq=freq,
                     prediction_interval_length=prediction_net.prediction_interval_length,  # noqa: E501
-                    item_id=batch["item_id"][i]
-                    if "item_id" in batch
-                    else None,
+                    item_id=(
+                        batch["item_id"][i] if "item_id" in batch else None
+                    ),
                     info=batch["info"][i] if "info" in batch else None,
                 )
 
diff --git a/src/gluonts/mx/model/transformer/_network.py b/src/gluonts/mx/model/transformer/_network.py
index 5e3082e968..245acd0c57 100644
--- a/src/gluonts/mx/model/transformer/_network.py
+++ b/src/gluonts/mx/model/transformer/_network.py
@@ -194,9 +194,11 @@ def create_network_input(
         # prediction too(batch_size, num_features + prod(target_shape))
         static_feat = F.concat(
             embedded_cat,
-            F.log(scale)
-            if len(self.target_shape) == 0
-            else F.log(scale.squeeze(axis=1)),
+            (
+                F.log(scale)
+                if len(self.target_shape) == 0
+                else F.log(scale.squeeze(axis=1))
+            ),
             dim=1,
         )
 
diff --git a/src/gluonts/mx/model/wavenet/_estimator.py b/src/gluonts/mx/model/wavenet/_estimator.py
index 28fca9d2cc..293676af6b 100644
--- a/src/gluonts/mx/model/wavenet/_estimator.py
+++ b/src/gluonts/mx/model/wavenet/_estimator.py
@@ -309,9 +309,11 @@ def _create_instance_splitter(self, mode: str):
             forecast_start_field=FieldName.FORECAST_START,
             instance_sampler=instance_sampler,
             past_length=self.context_length,
-            future_length=self.prediction_length
-            if mode == "test"
-            else self.train_window_length,
+            future_length=(
+                self.prediction_length
+                if mode == "test"
+                else self.train_window_length
+            ),
             output_NTC=False,
             time_series_fields=[
                 FieldName.FEAT_TIME,
diff --git a/src/gluonts/nursery/SCott/dataset_tools/synthetic.py b/src/gluonts/nursery/SCott/dataset_tools/synthetic.py
index a16034b83c..8e2dc383b6 100644
--- a/src/gluonts/nursery/SCott/dataset_tools/synthetic.py
+++ b/src/gluonts/nursery/SCott/dataset_tools/synthetic.py
@@ -57,12 +57,12 @@ def get_mixed_pattern(unit_length=16, num_duplicates=1000):
             for j in range(num_duplicates):
                 context = torch.arange(context_length, dtype=torch.float)
                 for i in range(1, pattern_number):
-                    context[
-                        unit_length * (i - 1) : unit_length * i
-                    ] = _get_mixed_pattern(
-                        context[unit_length * (i - 1) : unit_length * i]
-                        - unit_length * (i - 1),
-                        pattern[(gid + i) % pattern_number],
+                    context[unit_length * (i - 1) : unit_length * i] = (
+                        _get_mixed_pattern(
+                            context[unit_length * (i - 1) : unit_length * i]
+                            - unit_length * (i - 1),
+                            pattern[(gid + i) % pattern_number],
+                        )
                     )
                 ts_sample = torch.cat(
                     [
diff --git a/src/gluonts/nursery/SCott/preprocess_data.py b/src/gluonts/nursery/SCott/preprocess_data.py
index 4f247eb229..40b181820f 100644
--- a/src/gluonts/nursery/SCott/preprocess_data.py
+++ b/src/gluonts/nursery/SCott/preprocess_data.py
@@ -57,12 +57,12 @@ def get_mixed_pattern(unit_length=16, num_duplicates=1000):
             for j in range(num_duplicates):
                 context = torch.arange(context_length, dtype=torch.float)
                 for i in range(1, pattern_number):
-                    context[
-                        unit_length * (i - 1) : unit_length * i
-                    ] = _get_mixed_pattern(
-                        context[unit_length * (i - 1) : unit_length * i]
-                        - unit_length * (i - 1),
-                        pattern[(gid + i) % pattern_number],
+                    context[unit_length * (i - 1) : unit_length * i] = (
+                        _get_mixed_pattern(
+                            context[unit_length * (i - 1) : unit_length * i]
+                            - unit_length * (i - 1),
+                            pattern[(gid + i) % pattern_number],
+                        )
                     )
                 ts_sample = torch.cat(
                     [
diff --git a/src/gluonts/nursery/daf/network/kernel.py b/src/gluonts/nursery/daf/network/kernel.py
index 3a26e78595..9dccdd3e3b 100644
--- a/src/gluonts/nursery/daf/network/kernel.py
+++ b/src/gluonts/nursery/daf/network/kernel.py
@@ -105,17 +105,21 @@ def __init__(
             _query_weight = nn.Parameter(Tensor(d_hidden, d_hidden))
             self._query_weights.append(_query_weight)
             self._key_weights.append(
-                _query_weight
-                if self.symmetric
-                else nn.Parameter(Tensor(d_hidden, d_hidden)),
+                (
+                    _query_weight
+                    if self.symmetric
+                    else nn.Parameter(Tensor(d_hidden, d_hidden))
+                ),
             )
             if self.bias:
                 _query_bias = nn.Parameter(Tensor(d_hidden))
                 self._query_biases.append(_query_bias)
                 self._key_biases.append(
-                    _query_bias
-                    if self.symmetric
-                    else nn.Parameter(Tensor(d_hidden)),
+                    (
+                        _query_bias
+                        if self.symmetric
+                        else nn.Parameter(Tensor(d_hidden))
+                    ),
                 )
             else:
                 self._query_biases.append(None)
diff --git a/src/gluonts/nursery/daf/tslib/dataset/timeseries.py b/src/gluonts/nursery/daf/tslib/dataset/timeseries.py
index 0036ef6acf..4bd63964b8 100644
--- a/src/gluonts/nursery/daf/tslib/dataset/timeseries.py
+++ b/src/gluonts/nursery/daf/tslib/dataset/timeseries.py
@@ -335,12 +335,10 @@ def __len__(self):
         return len(self.target)
 
     @overload
-    def index_by_timestamp(self, index: pd.Timestamp) -> int:
-        ...
+    def index_by_timestamp(self, index: pd.Timestamp) -> int: ...
 
     @overload
-    def index_by_timestamp(self, index: List[pd.Timestamp]) -> List[int]:
-        ...
+    def index_by_timestamp(self, index: List[pd.Timestamp]) -> List[int]: ...
 
     def index_by_timestamp(self, index):
         return pd.Series(
@@ -348,24 +346,19 @@ def index_by_timestamp(self, index):
         ).loc[index]
 
     @overload
-    def __getitem__(self, index: int) -> TimeSeriesInstant:
-        ...
+    def __getitem__(self, index: int) -> TimeSeriesInstant: ...
 
     @overload
-    def __getitem__(self, index: pd.Timestamp) -> TimeSeriesInstant:
-        ...
+    def __getitem__(self, index: pd.Timestamp) -> TimeSeriesInstant: ...
 
     @overload
-    def __getitem__(self, index: slice) -> TimeSeries:
-        ...
+    def __getitem__(self, index: slice) -> TimeSeries: ...
 
     @overload
-    def __getitem__(self, index: List[int]) -> TimeSeries:
-        ...
+    def __getitem__(self, index: List[int]) -> TimeSeries: ...
 
     @overload
-    def __getitem__(self, index: List[pd.Timestamp]) -> TimeSeries:
-        ...
+    def __getitem__(self, index: List[pd.Timestamp]) -> TimeSeries: ...
 
     def __getitem__(self, index):
         if isinstance(index, pd.Timestamp) or (
diff --git a/src/gluonts/nursery/daf/tslib/engine/evaluator.py b/src/gluonts/nursery/daf/tslib/engine/evaluator.py
index 636f6e6b0d..f61c5a16fe 100644
--- a/src/gluonts/nursery/daf/tslib/engine/evaluator.py
+++ b/src/gluonts/nursery/daf/tslib/engine/evaluator.py
@@ -63,9 +63,11 @@ def load(self, tag: str):
             path = self.log_dir.joinpath(f"{tag}.pt.tar")
             state = pt.load(
                 path,
-                map_location=f"cuda:{self.cuda_device}"
-                if self.cuda_device >= 0
-                else "cpu",
+                map_location=(
+                    f"cuda:{self.cuda_device}"
+                    if self.cuda_device >= 0
+                    else "cpu"
+                ),
             )
             print(f"Load checkpoint from {path}")
         except FileNotFoundError:
diff --git a/src/gluonts/nursery/daf/tslib/engine/trainer.py b/src/gluonts/nursery/daf/tslib/engine/trainer.py
index 07d7f9f9c4..63ad4484ec 100644
--- a/src/gluonts/nursery/daf/tslib/engine/trainer.py
+++ b/src/gluonts/nursery/daf/tslib/engine/trainer.py
@@ -143,9 +143,11 @@ def load(
             path = self.log_dir.joinpath(f"{tag}.pt.tar")
             state = pt.load(
                 path,
-                map_location=f"cuda:{self.cuda_device}"
-                if self.cuda_device >= 0
-                else "cpu",
+                map_location=(
+                    f"cuda:{self.cuda_device}"
+                    if self.cuda_device >= 0
+                    else "cpu"
+                ),
             )
             print(f"Load checkpoint from {path}")
         except FileNotFoundError:
diff --git a/src/gluonts/nursery/few_shot_prediction/src/meta/data/batch.py b/src/gluonts/nursery/few_shot_prediction/src/meta/data/batch.py
index 9c9ff7f227..6512d7f5df 100644
--- a/src/gluonts/nursery/few_shot_prediction/src/meta/data/batch.py
+++ b/src/gluonts/nursery/few_shot_prediction/src/meta/data/batch.py
@@ -29,12 +29,14 @@ class SeriesBatch:
     and the splits sizes indicating the corresponding base dataset.
     """
 
-    sequences: torch.Tensor  # shape [batch, num_sequences, max_sequence_length]
+    sequences: (
+        torch.Tensor
+    )  # shape [batch, num_sequences, max_sequence_length]
     lengths: torch.Tensor  # shape [batch]
     split_sections: torch.Tensor  # shape [batch]
-    scales: Optional[
-        torch.Tensor
-    ] = None  # shape[batch, 2] contains mean and std the ts has been scaled with
+    scales: Optional[torch.Tensor] = (
+        None  # shape[batch, 2] contains mean and std the ts has been scaled with
+    )
 
     @classmethod
     def from_lists(
@@ -61,9 +63,11 @@ def from_lists(
             pad_sequence(values, batch_first=True),
             lengths=torch.as_tensor([len(s) for s in series]),
             split_sections=torch.as_tensor(split_sections),
-            scales=torch.stack([s.scale for s in series])
-            if series[0].scale is not None
-            else None,
+            scales=(
+                torch.stack([s.scale for s in series])
+                if series[0].scale is not None
+                else None
+            ),
         )
 
     def pin_memory(self):
diff --git a/src/gluonts/nursery/few_shot_prediction/src/meta/data/sampling.py b/src/gluonts/nursery/few_shot_prediction/src/meta/data/sampling.py
index a86ddd348b..9fc089fafe 100644
--- a/src/gluonts/nursery/few_shot_prediction/src/meta/data/sampling.py
+++ b/src/gluonts/nursery/few_shot_prediction/src/meta/data/sampling.py
@@ -170,12 +170,14 @@ def __iter__(self) -> Iterator[Triplet]:
                 supps_size=self.support_set_size,
                 length=self.support_length,
                 dataset=self.dataset,
-                cheat_query=cheat_query[0]
-                if np.random.rand() < self.cheat
-                else None,
-                index_iterator=self.index_iterator
-                if self.catch22_nn is None
-                else iter(self.catch22_nn[query_idx]),
+                cheat_query=(
+                    cheat_query[0] if np.random.rand() < self.cheat else None
+                ),
+                index_iterator=(
+                    self.index_iterator
+                    if self.catch22_nn is None
+                    else iter(self.catch22_nn[query_idx])
+                ),
             )
             yield Triplet(support_set, query_past, query_future)
 
@@ -311,9 +313,11 @@ def __getitem__(self, index: int) -> Triplet:
             # but this will be slower
             # seed=self.seed + index if self.seed else None,
             cheat_query=cheat_query if np.random.rand() < self.cheat else None,
-            index_iterator=self.index_iterator
-            if self.catch22_nn is None
-            else iter(self.catch22_nn[query_past[0].item_id]),
+            index_iterator=(
+                self.index_iterator
+                if self.catch22_nn is None
+                else iter(self.catch22_nn[query_past[0].item_id])
+            ),
         )
         return Triplet(support_set, query_past, query_future)
 
diff --git a/src/gluonts/nursery/few_shot_prediction/src/meta/datasets/artificial.py b/src/gluonts/nursery/few_shot_prediction/src/meta/datasets/artificial.py
index 19e1475c66..5a2dab200b 100644
--- a/src/gluonts/nursery/few_shot_prediction/src/meta/datasets/artificial.py
+++ b/src/gluonts/nursery/few_shot_prediction/src/meta/datasets/artificial.py
@@ -593,9 +593,9 @@ def generate_artificial_tuplets(
                 np.arange(0, context_length - signal_length)
             )
             si = np.random.choice(support_set_size)
-            support_set[si][
-                marker_start : marker_start + signal_length
-            ] = query[-signal_length:]
+            support_set[si][marker_start : marker_start + signal_length] = (
+                query[-signal_length:]
+            )
         else:
             signal = np.concatenate(
                 (np.ones((4,)), query[-prediction_length:])
@@ -647,9 +647,9 @@ def generate_artificial_tuplets(
                 np.arange(0, context_length - signal_length)
             )
             si = np.random.choice(support_set_size)
-            support_set[si][
-                marker_start : marker_start + signal_length
-            ] = query[-signal_length:]
+            support_set[si][marker_start : marker_start + signal_length] = (
+                query[-signal_length:]
+            )
         # else:
         #     signal = query[-prediction_length:]
         #     signal_length = prediction_length
diff --git a/src/gluonts/nursery/few_shot_prediction/src/meta/datasets/m4.py b/src/gluonts/nursery/few_shot_prediction/src/meta/datasets/m4.py
index 240afc9b2f..a510676208 100644
--- a/src/gluonts/nursery/few_shot_prediction/src/meta/datasets/m4.py
+++ b/src/gluonts/nursery/few_shot_prediction/src/meta/datasets/m4.py
@@ -127,9 +127,11 @@ def generate_m4_dataset(
 
     start_dates = list(meta_df.StartingDate)
     start_dates = [
-        sd
-        if pd.Timestamp(sd) <= pd.Timestamp("2022")
-        else str(pd.Timestamp("1900"))
+        (
+            sd
+            if pd.Timestamp(sd) <= pd.Timestamp("2022")
+            else str(pd.Timestamp("1900"))
+        )
         for sd in start_dates
     ]
 
diff --git a/src/gluonts/nursery/few_shot_prediction/src/meta/models/module.py b/src/gluonts/nursery/few_shot_prediction/src/meta/models/module.py
index 3a4e24080c..c5f7b3ad7f 100644
--- a/src/gluonts/nursery/few_shot_prediction/src/meta/models/module.py
+++ b/src/gluonts/nursery/few_shot_prediction/src/meta/models/module.py
@@ -109,7 +109,7 @@ def configure_optimizers(self) -> Dict:
                     verbose=True,
                 ),
                 "monitor": self.lr_scheduler_monitor,
-                "frequency": 1
+                "frequency": 1,
                 # If "monitor" references validation metrics, then "frequency" should be set to a
                 # multiple of "trainer.check_val_every_n_epoch".
             },
diff --git a/src/gluonts/nursery/gmm_tpp/gmm_base.py b/src/gluonts/nursery/gmm_tpp/gmm_base.py
index a2ad97691e..902c05fa31 100644
--- a/src/gluonts/nursery/gmm_tpp/gmm_base.py
+++ b/src/gluonts/nursery/gmm_tpp/gmm_base.py
@@ -54,9 +54,11 @@ def __init__(
                 "log_prior_",
                 shape=(num_clusters,),
                 lr_mult=lr_mult,
-                init=mx.init.Constant(np.log(1 / self.num_clusters))
-                if log_prior_ is None
-                else mx.init.Constant(log_prior_),
+                init=(
+                    mx.init.Constant(np.log(1 / self.num_clusters))
+                    if log_prior_ is None
+                    else mx.init.Constant(log_prior_)
+                ),
             )
 
             self.mu_ = self.params.get(
diff --git a/src/gluonts/nursery/robust-mts-attack/attack_and_save.py b/src/gluonts/nursery/robust-mts-attack/attack_and_save.py
index 0f1918bf6f..a3366e3bfc 100644
--- a/src/gluonts/nursery/robust-mts-attack/attack_and_save.py
+++ b/src/gluonts/nursery/robust-mts-attack/attack_and_save.py
@@ -133,9 +133,11 @@ def main():
             )
             best_perturbation = attack.attack_batch(
                 batch,
-                true_future_target=future_target
-                if device == "cpu"
-                else torch.from_numpy(future_target).float().to(device),
+                true_future_target=(
+                    future_target
+                    if device == "cpu"
+                    else torch.from_numpy(future_target).float().to(device)
+                ),
             )
 
             batch_res = AttackResults(
diff --git a/src/gluonts/nursery/robust-mts-attack/attack_sparse_layer.py b/src/gluonts/nursery/robust-mts-attack/attack_sparse_layer.py
index a1309388cd..739888ee89 100644
--- a/src/gluonts/nursery/robust-mts-attack/attack_sparse_layer.py
+++ b/src/gluonts/nursery/robust-mts-attack/attack_sparse_layer.py
@@ -127,9 +127,11 @@ def main():
             )
             best_perturbation = attack.attack_batch(
                 batch,
-                true_future_target=future_target
-                if device == "cpu"
-                else torch.from_numpy(future_target).float().to(device),
+                true_future_target=(
+                    future_target
+                    if device == "cpu"
+                    else torch.from_numpy(future_target).float().to(device)
+                ),
             )
 
             batch_res = AttackResults(
diff --git a/src/gluonts/nursery/robust-mts-attack/multivariate/datasets/dataset.py b/src/gluonts/nursery/robust-mts-attack/multivariate/datasets/dataset.py
index a56828f49e..768fe06070 100644
--- a/src/gluonts/nursery/robust-mts-attack/multivariate/datasets/dataset.py
+++ b/src/gluonts/nursery/robust-mts-attack/multivariate/datasets/dataset.py
@@ -121,9 +121,11 @@ def make_multivariate_dataset(
         align_data=False, num_test_dates=num_test_dates, max_target_dim=dim
     )
     return MultivariateDatasetInfo(
-        dataset_name
-        if dataset_benchmark_name is None
-        else dataset_benchmark_name,
+        (
+            dataset_name
+            if dataset_benchmark_name is None
+            else dataset_benchmark_name
+        ),
         grouper_train(train_ds),
         grouper_test(test_ds),
         prediction_length,
diff --git a/src/gluonts/nursery/robust-mts-attack/pts/model/causal_deepar/causal_deepar_network.py b/src/gluonts/nursery/robust-mts-attack/pts/model/causal_deepar/causal_deepar_network.py
index c162cb97f6..12fa525867 100644
--- a/src/gluonts/nursery/robust-mts-attack/pts/model/causal_deepar/causal_deepar_network.py
+++ b/src/gluonts/nursery/robust-mts-attack/pts/model/causal_deepar/causal_deepar_network.py
@@ -216,12 +216,16 @@ def unroll_encoder(
             (
                 embedded_cat,
                 feat_static_real,
-                scale.log()
-                if len(self.target_shape) == 0
-                else scale.squeeze(1).log(),
-                control_scale.log()
-                if len(self.target_shape) == 0
-                else control_scale.squeeze(1).log(),
+                (
+                    scale.log()
+                    if len(self.target_shape) == 0
+                    else scale.squeeze(1).log()
+                ),
+                (
+                    control_scale.log()
+                    if len(self.target_shape) == 0
+                    else control_scale.squeeze(1).log()
+                ),
             ),
             dim=1,
         )
diff --git a/src/gluonts/nursery/robust-mts-attack/pts/model/deepar/deepar_network.py b/src/gluonts/nursery/robust-mts-attack/pts/model/deepar/deepar_network.py
index b756bc17e9..8d66c91ec0 100644
--- a/src/gluonts/nursery/robust-mts-attack/pts/model/deepar/deepar_network.py
+++ b/src/gluonts/nursery/robust-mts-attack/pts/model/deepar/deepar_network.py
@@ -190,9 +190,11 @@ def unroll_encoder(
             (
                 embedded_cat,
                 feat_static_real,
-                scale.log()
-                if len(self.target_shape) == 0
-                else scale.squeeze(1).log(),
+                (
+                    scale.log()
+                    if len(self.target_shape) == 0
+                    else scale.squeeze(1).log()
+                ),
             ),
             dim=1,
         )
diff --git a/src/gluonts/nursery/robust-mts-attack/pts/model/deepvar/deepvar_estimator.py b/src/gluonts/nursery/robust-mts-attack/pts/model/deepvar/deepvar_estimator.py
index fec7d5a2ee..4bea3ddf23 100644
--- a/src/gluonts/nursery/robust-mts-attack/pts/model/deepvar/deepvar_estimator.py
+++ b/src/gluonts/nursery/robust-mts-attack/pts/model/deepvar/deepvar_estimator.py
@@ -139,9 +139,9 @@ def __init__(
         self.scaling = scaling
 
         if self.use_marginal_transformation:
-            self.output_transform: Optional[
-                Callable
-            ] = cdf_to_gaussian_forward_transform
+            self.output_transform: Optional[Callable] = (
+                cdf_to_gaussian_forward_transform
+            )
         else:
             self.output_transform = None
 
diff --git a/src/gluonts/nursery/robust-mts-attack/pts/model/deepvar/deepvar_network.py b/src/gluonts/nursery/robust-mts-attack/pts/model/deepvar/deepvar_network.py
index 5a4e0e0534..64353b4adf 100644
--- a/src/gluonts/nursery/robust-mts-attack/pts/model/deepvar/deepvar_network.py
+++ b/src/gluonts/nursery/robust-mts-attack/pts/model/deepvar/deepvar_network.py
@@ -188,9 +188,11 @@ def unroll(
             (
                 embedded_cat,
                 feat_static_real,
-                scale.log()
-                if len(self.target_shape) == 0
-                else scale.squeeze(1).log(),
+                (
+                    scale.log()
+                    if len(self.target_shape) == 0
+                    else scale.squeeze(1).log()
+                ),
             ),
             dim=1,
         )
diff --git a/src/gluonts/nursery/robust-mts-attack/pts/model/n_beats/n_beats_ensemble.py b/src/gluonts/nursery/robust-mts-attack/pts/model/n_beats/n_beats_ensemble.py
index eb8c3ef7ae..ff9234273d 100644
--- a/src/gluonts/nursery/robust-mts-attack/pts/model/n_beats/n_beats_ensemble.py
+++ b/src/gluonts/nursery/robust-mts-attack/pts/model/n_beats/n_beats_ensemble.py
@@ -102,9 +102,11 @@ def predict(
                 output,
                 start_date=start_date,
                 freq=start_date.freqstr,
-                item_id=item[FieldName.ITEM_ID]
-                if FieldName.ITEM_ID in item
-                else None,
+                item_id=(
+                    item[FieldName.ITEM_ID]
+                    if FieldName.ITEM_ID in item
+                    else None
+                ),
                 info=item["info"] if "info" in item else None,
             )
 
diff --git a/src/gluonts/nursery/robust-mts-attack/pts/model/tft/tft_estimator.py b/src/gluonts/nursery/robust-mts-attack/pts/model/tft/tft_estimator.py
index af2805c0b3..fdb47f7d87 100644
--- a/src/gluonts/nursery/robust-mts-attack/pts/model/tft/tft_estimator.py
+++ b/src/gluonts/nursery/robust-mts-attack/pts/model/tft/tft_estimator.py
@@ -106,13 +106,13 @@ def __init__(
         self.past_dynamic_feature_dims = {}
         for name in self.past_dynamic_features:
             if name in self.dynamic_cardinalities:
-                self.past_dynamic_cardinalities[
-                    name
-                ] = self.dynamic_cardinalities.pop(name)
+                self.past_dynamic_cardinalities[name] = (
+                    self.dynamic_cardinalities.pop(name)
+                )
             elif name in self.dynamic_feature_dims:
-                self.past_dynamic_feature_dims[
-                    name
-                ] = self.dynamic_feature_dims.pop(name)
+                self.past_dynamic_feature_dims[name] = (
+                    self.dynamic_feature_dims.pop(name)
+                )
             else:
                 raise ValueError(
                     f"Feature name {name} is not provided in feature dicts"
diff --git a/src/gluonts/nursery/robust-mts-attack/pts/model/transformer/transformer_network.py b/src/gluonts/nursery/robust-mts-attack/pts/model/transformer/transformer_network.py
index 37ae2b2bf8..3ac17fef6e 100644
--- a/src/gluonts/nursery/robust-mts-attack/pts/model/transformer/transformer_network.py
+++ b/src/gluonts/nursery/robust-mts-attack/pts/model/transformer/transformer_network.py
@@ -206,9 +206,11 @@ def create_network_input(
             (
                 embedded_cat,
                 feat_static_real,
-                torch.log(scale)
-                if len(self.target_shape) == 0
-                else torch.log(scale.squeeze(1)),
+                (
+                    torch.log(scale)
+                    if len(self.target_shape) == 0
+                    else torch.log(scale.squeeze(1))
+                ),
             ),
             dim=1,
         )
diff --git a/src/gluonts/nursery/robust-mts-attack/pts/model/transformer_tempflow/transformer_tempflow_network.py b/src/gluonts/nursery/robust-mts-attack/pts/model/transformer_tempflow/transformer_tempflow_network.py
index 56b7b4d457..55d60d1537 100644
--- a/src/gluonts/nursery/robust-mts-attack/pts/model/transformer_tempflow/transformer_tempflow_network.py
+++ b/src/gluonts/nursery/robust-mts-attack/pts/model/transformer_tempflow/transformer_tempflow_network.py
@@ -165,7 +165,11 @@ def create_network_input(
         future_time_feat: Optional[torch.Tensor],
         future_target_cdf: Optional[torch.Tensor],
         target_dimension_indicator: torch.Tensor,
-    ) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor,]:
+    ) -> Tuple[
+        torch.Tensor,
+        torch.Tensor,
+        torch.Tensor,
+    ]:
         """
         Unrolls the RNN encoder over past and, if present, future data.
         Returns outputs and state of the encoder, plus the scale of
diff --git a/src/gluonts/nursery/robust-mts-attack/pts/modules/gaussian_diffusion.py b/src/gluonts/nursery/robust-mts-attack/pts/modules/gaussian_diffusion.py
index 2819be62a4..73195e7957 100644
--- a/src/gluonts/nursery/robust-mts-attack/pts/modules/gaussian_diffusion.py
+++ b/src/gluonts/nursery/robust-mts-attack/pts/modules/gaussian_diffusion.py
@@ -80,9 +80,7 @@ def __init__(
             if beta_schedule == "linear":
                 betas = np.linspace(1e-4, beta_end, diff_steps)
             elif beta_schedule == "quad":
-                betas = (
-                    np.linspace(1e-4**0.5, beta_end**0.5, diff_steps) ** 2
-                )
+                betas = np.linspace(1e-4**0.5, beta_end**0.5, diff_steps) ** 2
             elif beta_schedule == "const":
                 betas = beta_end * np.ones(diff_steps)
             elif beta_schedule == "jsd":  # 1/T, 1/(T-1), 1/(T-2), ..., 1
diff --git a/src/gluonts/nursery/robust-mts-attack/utils.py b/src/gluonts/nursery/robust-mts-attack/utils.py
index 3bafea8d4d..75a36ec169 100644
--- a/src/gluonts/nursery/robust-mts-attack/utils.py
+++ b/src/gluonts/nursery/robust-mts-attack/utils.py
@@ -263,13 +263,13 @@ def calc_loss(
             if (
                 true_future_target[:, attack_idx][..., target_items] != 0
             ).prod() == 0:
-                mape[attack_type][
-                    testset_idx : testset_idx + batch_size
-                ] = np.abs(
-                    forecasts[attack_type][i][:, :, attack_idx][
-                        ..., target_items
-                    ].mean(1)
-                    - true_future_target[:, attack_idx][..., target_items]
+                mape[attack_type][testset_idx : testset_idx + batch_size] = (
+                    np.abs(
+                        forecasts[attack_type][i][:, :, attack_idx][
+                            ..., target_items
+                        ].mean(1)
+                        - true_future_target[:, attack_idx][..., target_items]
+                    )
                 )
                 mse[attack_type][testset_idx : testset_idx + batch_size] = (
                     forecasts[attack_type][i][:, :, attack_idx][
@@ -290,14 +290,14 @@ def calc_loss(
                         j, testset_idx : testset_idx + batch_size
                     ] = quantile_loss(true, pred, quantile)
             else:
-                mape[attack_type][
-                    testset_idx : testset_idx + batch_size
-                ] = np.abs(
-                    forecasts[attack_type][i][:, :, attack_idx][
-                        ..., target_items
-                    ].mean(1)
-                    / true_future_target[:, attack_idx][..., target_items]
-                    - 1
+                mape[attack_type][testset_idx : testset_idx + batch_size] = (
+                    np.abs(
+                        forecasts[attack_type][i][:, :, attack_idx][
+                            ..., target_items
+                        ].mean(1)
+                        / true_future_target[:, attack_idx][..., target_items]
+                        - 1
+                    )
                 )
                 mse[attack_type][testset_idx : testset_idx + batch_size] = (
                     mape[attack_type][testset_idx : testset_idx + batch_size]
diff --git a/src/gluonts/nursery/san/_network.py b/src/gluonts/nursery/san/_network.py
index b3d08665f2..126dbdd0c2 100644
--- a/src/gluonts/nursery/san/_network.py
+++ b/src/gluonts/nursery/san/_network.py
@@ -213,9 +213,11 @@ def _assemble_covariates(
                 covariates.append(
                     feat_static_real.expand_dims(axis=1).repeat(
                         axis=1,
-                        repeats=self.context_length
-                        if is_past
-                        else self.prediction_length,
+                        repeats=(
+                            self.context_length
+                            if is_past
+                            else self.prediction_length
+                        ),
                     )
                 )
             if len(covariates) > 0:
@@ -231,9 +233,11 @@ def _assemble_covariates(
                 categories.append(
                     feat_static_cat.expand_dims(axis=1).repeat(
                         axis=1,
-                        repeats=self.context_length
-                        if is_past
-                        else self.prediction_length,
+                        repeats=(
+                            self.context_length
+                            if is_past
+                            else self.prediction_length
+                        ),
                     )
                 )
             if len(categories) > 0:
diff --git a/src/gluonts/nursery/temporal_hierarchical_forecasting/model/cop_deepar/_estimator.py b/src/gluonts/nursery/temporal_hierarchical_forecasting/model/cop_deepar/_estimator.py
index 03bb055f00..8c8b297fa6 100644
--- a/src/gluonts/nursery/temporal_hierarchical_forecasting/model/cop_deepar/_estimator.py
+++ b/src/gluonts/nursery/temporal_hierarchical_forecasting/model/cop_deepar/_estimator.py
@@ -238,10 +238,11 @@ def __init__(
 
             # adapt window_length if RollingMeanValueImputation is used
             if isinstance(imputation_method, RollingMeanValueImputation):
-                base_estimator_hps_agg[
-                    "imputation_method"
-                ] = RollingMeanValueImputation(
-                    window_size=imputation_method.window_size // agg_multiple
+                base_estimator_hps_agg["imputation_method"] = (
+                    RollingMeanValueImputation(
+                        window_size=imputation_method.window_size
+                        // agg_multiple
+                    )
                 )
 
             # Hack to enforce correct serialization of lags_seq and history length
diff --git a/src/gluonts/nursery/temporal_hierarchical_forecasting/model/cop_deepar/gluonts_fixes.py b/src/gluonts/nursery/temporal_hierarchical_forecasting/model/cop_deepar/gluonts_fixes.py
index a8ac13bc14..b66c8f6887 100644
--- a/src/gluonts/nursery/temporal_hierarchical_forecasting/model/cop_deepar/gluonts_fixes.py
+++ b/src/gluonts/nursery/temporal_hierarchical_forecasting/model/cop_deepar/gluonts_fixes.py
@@ -52,15 +52,17 @@ def batchify_with_dict(
     is_right_pad: bool = True,
 ) -> DataBatch:
     return {
-        key: stack(
-            data=[item[key] for item in data],
-            ctx=ctx,
-            dtype=dtype,
-            variable_length=variable_length,
-            is_right_pad=is_right_pad,
+        key: (
+            stack(
+                data=[item[key] for item in data],
+                ctx=ctx,
+                dtype=dtype,
+                variable_length=variable_length,
+                is_right_pad=is_right_pad,
+            )
+            if not isinstance(data[0][key], dict)
+            else batchify_with_dict(data=[item[key] for item in data])
         )
-        if not isinstance(data[0][key], dict)
-        else batchify_with_dict(data=[item[key] for item in data])
         for key in data[0].keys()
     }
 
diff --git a/src/gluonts/nursery/tsbench/src/cli/analysis/scripts/ensemble.py b/src/gluonts/nursery/tsbench/src/cli/analysis/scripts/ensemble.py
index e1d03db304..8d1a17668c 100644
--- a/src/gluonts/nursery/tsbench/src/cli/analysis/scripts/ensemble.py
+++ b/src/gluonts/nursery/tsbench/src/cli/analysis/scripts/ensemble.py
@@ -66,9 +66,9 @@ def main(
         tracker,
         ensemble_size=size,
         ensemble_weighting=weighting,
-        config_class=MODEL_REGISTRY[model_class]
-        if model_class is not None
-        else None,
+        config_class=(
+            MODEL_REGISTRY[model_class] if model_class is not None else None
+        ),
     )
     df, configs = evaluator.run()
 
diff --git a/src/gluonts/nursery/tsbench/src/cli/evaluations/schedule.py b/src/gluonts/nursery/tsbench/src/cli/evaluations/schedule.py
index a49864cc59..80cb563da4 100644
--- a/src/gluonts/nursery/tsbench/src/cli/evaluations/schedule.py
+++ b/src/gluonts/nursery/tsbench/src/cli/evaluations/schedule.py
@@ -177,12 +177,14 @@ def job_factory() -> str:
             tags=[
                 {"Key": "Experiment", "Value": experiment},
             ],
-            instance_type="local"
-            if local
-            else (
-                configuration["__instance_type__"]
-                if "__instance_type__" in configuration
-                else instance_type
+            instance_type=(
+                "local"
+                if local
+                else (
+                    configuration["__instance_type__"]
+                    if "__instance_type__" in configuration
+                    else instance_type
+                )
             ),
             instance_count=1,
             volume_size=30,
diff --git a/src/gluonts/nursery/tsbench/src/cli/utils/config.py b/src/gluonts/nursery/tsbench/src/cli/utils/config.py
index 0ec10b3863..41425e22e7 100644
--- a/src/gluonts/nursery/tsbench/src/cli/utils/config.py
+++ b/src/gluonts/nursery/tsbench/src/cli/utils/config.py
@@ -73,14 +73,16 @@ def explode_key_values(
     into independent configurations.
     """
     all_combinations = {
-        primary: itertools.product(
-            *[
-                [(option["key"], value) for value in option["values"]]
-                for option in choices
-            ]
+        primary: (
+            itertools.product(
+                *[
+                    [(option["key"], value) for value in option["values"]]
+                    for option in choices
+                ]
+            )
+            if choices
+            else []
         )
-        if choices
-        else []
         for primary, choices in mapping.items()
     }
 
diff --git a/src/gluonts/nursery/tsbench/src/tsbench/recommender/_base.py b/src/gluonts/nursery/tsbench/src/tsbench/recommender/_base.py
index 7ff3c335c6..078692d442 100644
--- a/src/gluonts/nursery/tsbench/src/tsbench/recommender/_base.py
+++ b/src/gluonts/nursery/tsbench/src/tsbench/recommender/_base.py
@@ -116,9 +116,11 @@ def recommend(
         # Then, we perform a nondominated sort
         argsort = argsort_nondominated(
             df.to_numpy(),  # type: ignore
-            dim=df.columns.tolist().index(self.focus)
-            if self.focus is not None
-            else None,
+            dim=(
+                df.columns.tolist().index(self.focus)
+                if self.focus is not None
+                else None
+            ),
             max_items=max_count,
         )
 
diff --git a/src/gluonts/nursery/tsbench/src/tsbench/recommender/_factory.py b/src/gluonts/nursery/tsbench/src/tsbench/recommender/_factory.py
index 96cbb43b37..a84ba975ff 100644
--- a/src/gluonts/nursery/tsbench/src/tsbench/recommender/_factory.py
+++ b/src/gluonts/nursery/tsbench/src/tsbench/recommender/_factory.py
@@ -16,9 +16,9 @@
 from ._base import Recommender
 
 RECOMMENDER_REGISTRY: Dict[str, Type[Recommender[ModelConfig]]] = {}
-ENSEMBLE_RECOMMENDER_REGISTRY: Dict[
-    str, Type[Recommender[EnsembleConfig]]
-] = {}
+ENSEMBLE_RECOMMENDER_REGISTRY: Dict[str, Type[Recommender[EnsembleConfig]]] = (
+    {}
+)
 
 R = TypeVar("R", bound=Type[Recommender[ModelConfig]])
 E = TypeVar("E", bound=Type[Recommender[EnsembleConfig]])
diff --git a/src/gluonts/time_feature/lag.py b/src/gluonts/time_feature/lag.py
index 9abadb3a7f..d93fee8547 100644
--- a/src/gluonts/time_feature/lag.py
+++ b/src/gluonts/time_feature/lag.py
@@ -146,7 +146,9 @@ def _make_lags_for_month(multiple, num_cycles=3):
             + _make_lags_for_hour(offset.n / (60 * 60))
         )
     else:
-        raise ValueError(f"invalid frequency | `freq_str={freq_str}` -> `offset_name={offset_name}`")
+        raise ValueError(
+            f"invalid frequency | `freq_str={freq_str}` -> `offset_name={offset_name}`"
+        )
 
     # flatten lags list and filter
     lags = [
diff --git a/src/gluonts/torch/distributions/binned_uniforms.py b/src/gluonts/torch/distributions/binned_uniforms.py
index d8672e8028..46a8a5e586 100644
--- a/src/gluonts/torch/distributions/binned_uniforms.py
+++ b/src/gluonts/torch/distributions/binned_uniforms.py
@@ -35,6 +35,7 @@ class BinnedUniforms(Distribution):
             These are softmaxed. The tensor is of shape (*batch_shape,)
         validate_args (bool) from the pytorch Distribution class
     """
+
     arg_constraints = {"logits": constraints.real}
     support = constraints.real
     has_rsample = False
diff --git a/src/gluonts/torch/distributions/generalized_pareto.py b/src/gluonts/torch/distributions/generalized_pareto.py
index a35a2ba1dc..9255474ad2 100644
--- a/src/gluonts/torch/distributions/generalized_pareto.py
+++ b/src/gluonts/torch/distributions/generalized_pareto.py
@@ -37,6 +37,7 @@ class GeneralizedPareto(Distribution):
         Tensor containing the beta scale parameters. The tensor is of
         shape (*batch_shape, 1)
     """
+
     arg_constraints = {
         "xi": constraints.positive,
         "beta": constraints.positive,
diff --git a/src/gluonts/torch/distributions/spliced_binned_pareto.py b/src/gluonts/torch/distributions/spliced_binned_pareto.py
index 93eef5fa3c..94e36a56cf 100644
--- a/src/gluonts/torch/distributions/spliced_binned_pareto.py
+++ b/src/gluonts/torch/distributions/spliced_binned_pareto.py
@@ -36,6 +36,7 @@ class SplicedBinnedPareto(BinnedUniforms):
             each tail. Default value is 0.05. NB: This symmetric percentile
             can still represent asymmetric upper and lower tails.
     """
+
     arg_constraints = {
         "logits": constraints.real,
         "lower_gp_xi": constraints.positive,
diff --git a/src/gluonts/torch/model/deep_npts/_estimator.py b/src/gluonts/torch/model/deep_npts/_estimator.py
index 8d396d7c96..f3cca87da1 100755
--- a/src/gluonts/torch/model/deep_npts/_estimator.py
+++ b/src/gluonts/torch/model/deep_npts/_estimator.py
@@ -343,9 +343,11 @@ def train_model(
         )
 
         data_loader = self.training_data_loader(
-            transformed_dataset
-            if not cache_data
-            else Cached(transformed_dataset),
+            (
+                transformed_dataset
+                if not cache_data
+                else Cached(transformed_dataset)
+            ),
             batch_size=self.batch_size,
             num_batches_per_epoch=self.num_batches_per_epoch,
         )
diff --git a/src/gluonts/torch/model/deepar/module.py b/src/gluonts/torch/model/deepar/module.py
index 406468079c..ebe4ef2479 100644
--- a/src/gluonts/torch/model/deepar/module.py
+++ b/src/gluonts/torch/model/deepar/module.py
@@ -221,7 +221,11 @@ def prepare_rnn_input(
         past_observed_values: torch.Tensor,
         future_time_feat: torch.Tensor,
         future_target: Optional[torch.Tensor] = None,
-    ) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor,]:
+    ) -> Tuple[
+        torch.Tensor,
+        torch.Tensor,
+        torch.Tensor,
+    ]:
         context = past_target[..., -self.context_length :]
         observed_context = past_observed_values[..., -self.context_length :]
 
diff --git a/src/gluonts/torch/model/tft/module.py b/src/gluonts/torch/model/tft/module.py
index f0dd7c00a1..f113160cb7 100644
--- a/src/gluonts/torch/model/tft/module.py
+++ b/src/gluonts/torch/model/tft/module.py
@@ -102,65 +102,66 @@ def __init__(
         self.target_proj = nn.Linear(in_features=1, out_features=self.d_var)
         # Past-only dynamic features
         if self.d_past_feat_dynamic_real:
-            self.past_feat_dynamic_proj: Optional[
-                FeatureProjector
-            ] = FeatureProjector(
-                feature_dims=self.d_past_feat_dynamic_real,
-                embedding_dims=[self.d_var]
-                * len(self.d_past_feat_dynamic_real),
+            self.past_feat_dynamic_proj: Optional[FeatureProjector] = (
+                FeatureProjector(
+                    feature_dims=self.d_past_feat_dynamic_real,
+                    embedding_dims=[self.d_var]
+                    * len(self.d_past_feat_dynamic_real),
+                )
             )
         else:
             self.past_feat_dynamic_proj = None
 
         if self.c_past_feat_dynamic_cat:
-            self.past_feat_dynamic_embed: Optional[
-                FeatureEmbedder
-            ] = FeatureEmbedder(
-                cardinalities=self.c_past_feat_dynamic_cat,
-                embedding_dims=[self.d_var]
-                * len(self.c_past_feat_dynamic_cat),
+            self.past_feat_dynamic_embed: Optional[FeatureEmbedder] = (
+                FeatureEmbedder(
+                    cardinalities=self.c_past_feat_dynamic_cat,
+                    embedding_dims=[self.d_var]
+                    * len(self.c_past_feat_dynamic_cat),
+                )
             )
         else:
             self.past_feat_dynamic_embed = None
 
         # Known dynamic features
         if self.d_feat_dynamic_real:
-            self.feat_dynamic_proj: Optional[
-                FeatureProjector
-            ] = FeatureProjector(
-                feature_dims=self.d_feat_dynamic_real,
-                embedding_dims=[self.d_var] * len(self.d_feat_dynamic_real),
+            self.feat_dynamic_proj: Optional[FeatureProjector] = (
+                FeatureProjector(
+                    feature_dims=self.d_feat_dynamic_real,
+                    embedding_dims=[self.d_var]
+                    * len(self.d_feat_dynamic_real),
+                )
             )
         else:
             self.feat_dynamic_proj = None
 
         if self.c_feat_dynamic_cat:
-            self.feat_dynamic_embed: Optional[
-                FeatureEmbedder
-            ] = FeatureEmbedder(
-                cardinalities=self.c_feat_dynamic_cat,
-                embedding_dims=[self.d_var] * len(self.c_feat_dynamic_cat),
+            self.feat_dynamic_embed: Optional[FeatureEmbedder] = (
+                FeatureEmbedder(
+                    cardinalities=self.c_feat_dynamic_cat,
+                    embedding_dims=[self.d_var] * len(self.c_feat_dynamic_cat),
+                )
             )
         else:
             self.feat_dynamic_embed = None
 
         # Static features
         if self.d_feat_static_real:
-            self.feat_static_proj: Optional[
-                FeatureProjector
-            ] = FeatureProjector(
-                feature_dims=self.d_feat_static_real,
-                embedding_dims=[self.d_var] * len(self.d_feat_static_real),
+            self.feat_static_proj: Optional[FeatureProjector] = (
+                FeatureProjector(
+                    feature_dims=self.d_feat_static_real,
+                    embedding_dims=[self.d_var] * len(self.d_feat_static_real),
+                )
             )
         else:
             self.feat_static_proj = None
 
         if self.c_feat_static_cat:
-            self.feat_static_embed: Optional[
-                FeatureEmbedder
-            ] = FeatureEmbedder(
-                cardinalities=self.c_feat_static_cat,
-                embedding_dims=[self.d_var] * len(self.c_feat_static_cat),
+            self.feat_static_embed: Optional[FeatureEmbedder] = (
+                FeatureEmbedder(
+                    cardinalities=self.c_feat_static_cat,
+                    embedding_dims=[self.d_var] * len(self.c_feat_static_cat),
+                )
             )
         else:
             self.feat_static_embed = None
diff --git a/src/gluonts/torch/model/wavenet/estimator.py b/src/gluonts/torch/model/wavenet/estimator.py
index 2da9e87d91..234ecff237 100644
--- a/src/gluonts/torch/model/wavenet/estimator.py
+++ b/src/gluonts/torch/model/wavenet/estimator.py
@@ -267,12 +267,18 @@ def create_transformation(self) -> Transformation:
         return Chain(
             [
                 RemoveFields(field_names=remove_field_names),
-                SetField(output_field=FieldName.FEAT_STATIC_CAT, value=[0])
-                if self.num_feat_static_cat == 0
-                else Identity(),
-                SetField(output_field=FieldName.FEAT_STATIC_REAL, value=[0.0])
-                if self.num_feat_static_real == 0
-                else Identity(),
+                (
+                    SetField(output_field=FieldName.FEAT_STATIC_CAT, value=[0])
+                    if self.num_feat_static_cat == 0
+                    else Identity()
+                ),
+                (
+                    SetField(
+                        output_field=FieldName.FEAT_STATIC_REAL, value=[0.0]
+                    )
+                    if self.num_feat_static_real == 0
+                    else Identity()
+                ),
                 AsNumpyArray(
                     field=FieldName.FEAT_STATIC_CAT, expected_ndim=1, dtype=int
                 ),
diff --git a/src/gluonts/transform/feature.py b/src/gluonts/transform/feature.py
index ff9f7bdd5e..f5f519d5a7 100644
--- a/src/gluonts/transform/feature.py
+++ b/src/gluonts/transform/feature.py
@@ -529,11 +529,11 @@ def map_transform(self, data: DataEntry, is_train: bool) -> DataEntry:
         lags = np.vstack(
             [
                 agg_vals[
-                    -(l * self.ratio - self.half_window + len(t)) : -(
-                        l * self.ratio - self.half_window
+                    -(l * self.ratio - self.half_window + len(t)) : (
+                        -(l * self.ratio - self.half_window)
+                        if -(l * self.ratio - self.half_window) != 0
+                        else None
                     )
-                    if -(l * self.ratio - self.half_window) != 0
-                    else None
                 ]
                 for l in self.valid_lags
             ]
diff --git a/src/gluonts/zebras/_period.py b/src/gluonts/zebras/_period.py
index 66f32e311b..4cda3178a9 100644
--- a/src/gluonts/zebras/_period.py
+++ b/src/gluonts/zebras/_period.py
@@ -330,12 +330,10 @@ def __len__(self):
         return len(self.data)
 
     @overload
-    def __getitem__(self, idx: int) -> Period:
-        ...
+    def __getitem__(self, idx: int) -> Period: ...
 
     @overload
-    def __getitem__(self, idx: slice) -> Periods:
-        ...
+    def __getitem__(self, idx: slice) -> Periods: ...
 
     def __getitem__(self, idx):
         if _is_number(idx):
diff --git a/test/mx/model/seq2seq/test_forking_sequence_splitter.py b/test/mx/model/seq2seq/test_forking_sequence_splitter.py
index 6f9076d5d5..00d1df560f 100644
--- a/test/mx/model/seq2seq/test_forking_sequence_splitter.py
+++ b/test/mx/model/seq2seq/test_forking_sequence_splitter.py
@@ -125,9 +125,11 @@ def make_dataset(N, train_length):
                 pred_length=10,
             ),
             ForkingSequenceSplitter(
-                instance_sampler=ValidationSplitSampler(min_future=dec_len)
-                if is_train
-                else TSplitSampler(),
+                instance_sampler=(
+                    ValidationSplitSampler(min_future=dec_len)
+                    if is_train
+                    else TSplitSampler()
+                ),
                 enc_len=enc_len,
                 dec_len=dec_len,
                 num_forking=num_forking,