Description
When a `PAST_FEAT_DYNAMIC_CAT` field is used with the torch implementation of `TemporalFusionTransformerEstimator`, training fails. Commenting out the `past_dynamic_cardinalities` argument makes the example work.
To Reproduce
```python
from gluonts.dataset.common import ListDataset
import pandas as pd
import numpy as np

# Example data
data = [
    {
        "start": pd.Timestamp("2021-01-01 00:00:00"),
        "target": np.random.rand(100),
        "feat_static_real": [1.0],
        "feat_static_cat": [0],
        "feat_dynamic_real": [np.random.rand(100)],
        "feat_dynamic_cat": [np.random.randint(0, 10, size=100)],
        "past_feat_dynamic_real": [np.random.rand(100)],
        "past_feat_dynamic_cat": [np.random.randint(0, 10, size=100)],
    }
]
dataset = ListDataset(data, freq="1H")

from gluonts.torch.model.tft import TemporalFusionTransformerEstimator
from gluonts.torch.distributions import QuantileOutput

# Define the estimator
estimator = TemporalFusionTransformerEstimator(
    freq="1H",
    prediction_length=24,
    context_length=24,
    quantiles=[0.1, 0.5, 0.9],
    num_heads=4,
    hidden_dim=32,
    variable_dim=32,
    static_dims=[1],  # Size of feat_static_real
    dynamic_dims=[1],  # Size of feat_dynamic_real
    past_dynamic_dims=[1],  # Size of past_feat_dynamic_real
    static_cardinalities=[1],  # Cardinality of feat_static_cat
    dynamic_cardinalities=[10],  # Cardinality of feat_dynamic_cat
    past_dynamic_cardinalities=[10],  # Cardinality of past_feat_dynamic_cat
    time_features=None,
    lr=0.001,
    weight_decay=1e-8,
    dropout_rate=0.1,
    patience=10,
    batch_size=32,
    num_batches_per_epoch=5,
)

predictor = estimator.train(dataset)
```
Error message or code output
```text
GPU available: False, used: False
TPU available: False, using: 0 TPU cores
IPU available: False, using: 0 IPUs
HPU available: False, using: 0 HPUs
/home/coder/projects/efds-fcpf-forecasting-engine/.venv/lib/python3.11/site-packages/lightning/pytorch/trainer/configuration_validator.py:74: You defined a `validation_step` but have no `val_dataloader`. Skipping val loop.
  | Name  | Type                           | Params | In sizes                                                                          | Out sizes
-------------------------------------------------------------------------------------------------------------------------------------------------------------------------------
0 | model | TemporalFusionTransformerModel | 156 K  | [[1, 24], [1, 24], [1, 1], [1, 1], [1, 48, 5], [1, 48, 1], [1, 24, 1], [1, 24, 1]] | [[[1, 24, 3]], [1, 1], [1, 1]]
-------------------------------------------------------------------------------------------------------------------------------------------------------------------------------
156 K     Trainable params
0         Non-trainable params
156 K     Total params
0.624     Total estimated model params size (MB)
```
```text
---------------------------------------------------------------------------
TypeError Traceback (most recent call last)
Cell In[17], line 1
----> 1 predictor = estimator.train(dataset)
File ~/projects/efds-fcpf-forecasting-engine/.venv/lib/python3.11/site-packages/gluonts/torch/model/estimator.py:246, in PyTorchLightningEstimator.train(self, training_data, validation_data, shuffle_buffer_length, cache_data, ckpt_path, **kwargs)
237 def train(
238 self,
239 training_data: Dataset,
(...)
244 **kwargs,
245 ) -> PyTorchPredictor:
--> 246 return self.train_model(
247 training_data,
248 validation_data,
249 shuffle_buffer_length=shuffle_buffer_length,
250 cache_data=cache_data,
251 ckpt_path=ckpt_path,
252 ).predictor
File ~/projects/efds-fcpf-forecasting-engine/.venv/lib/python3.11/site-packages/gluonts/torch/model/estimator.py:209, in PyTorchLightningEstimator.train_model(self, training_data, validation_data, from_predictor, shuffle_buffer_length, cache_data, ckpt_path, **kwargs)
200 custom_callbacks = self.trainer_kwargs.pop("callbacks", [])
201 trainer = pl.Trainer(
202 **{
203 "accelerator": "auto",
(...)
206 }
207 )
--> 209 trainer.fit(
210 model=training_network,
211 train_dataloaders=training_data_loader,
212 val_dataloaders=validation_data_loader,
213 ckpt_path=ckpt_path,
214 )
216 if checkpoint.best_model_path != "":
217 logger.info(
218 f"Loading best model from {checkpoint.best_model_path}"
219 )
File ~/projects/efds-fcpf-forecasting-engine/.venv/lib/python3.11/site-packages/lightning/pytorch/trainer/trainer.py:544, in Trainer.fit(self, model, train_dataloaders, val_dataloaders, datamodule, ckpt_path)
542 self.state.status = TrainerStatus.RUNNING
543 self.training = True
--> 544 call._call_and_handle_interrupt(
545 self, self._fit_impl, model, train_dataloaders, val_dataloaders, datamodule, ckpt_path
546 )
File ~/projects/efds-fcpf-forecasting-engine/.venv/lib/python3.11/site-packages/lightning/pytorch/trainer/call.py:44, in _call_and_handle_interrupt(trainer, trainer_fn, *args, **kwargs)
42 if trainer.strategy.launcher is not None:
43 return trainer.strategy.launcher.launch(trainer_fn, *args, trainer=trainer, **kwargs)
---> 44 return trainer_fn(*args, **kwargs)
46 except _TunerExitException:
47 _call_teardown_hook(trainer)
File ~/projects/efds-fcpf-forecasting-engine/.venv/lib/python3.11/site-packages/lightning/pytorch/trainer/trainer.py:580, in Trainer._fit_impl(self, model, train_dataloaders, val_dataloaders, datamodule, ckpt_path)
573 assert self.state.fn is not None
574 ckpt_path = self._checkpoint_connector._select_ckpt_path(
575 self.state.fn,
576 ckpt_path,
577 model_provided=True,
578 model_connected=self.lightning_module is not None,
579 )
--> 580 self._run(model, ckpt_path=ckpt_path)
582 assert self.state.stopped
583 self.training = False
File ~/projects/efds-fcpf-forecasting-engine/.venv/lib/python3.11/site-packages/lightning/pytorch/trainer/trainer.py:989, in Trainer._run(self, model, ckpt_path)
984 self._signal_connector.register_signal_handlers()
986 # ----------------------------
987 # RUN THE TRAINER
988 # ----------------------------
--> 989 results = self._run_stage()
991 # ----------------------------
992 # POST-Training CLEAN UP
993 # ----------------------------
994 log.debug(f"{self.__class__.__name__}: trainer tearing down")
File ~/projects/efds-fcpf-forecasting-engine/.venv/lib/python3.11/site-packages/lightning/pytorch/trainer/trainer.py:1035, in Trainer._run_stage(self)
1033 self._run_sanity_check()
1034 with torch.autograd.set_detect_anomaly(self._detect_anomaly):
-> 1035 self.fit_loop.run()
1036 return None
1037 raise RuntimeError(f"Unexpected state {self.state}")
File ~/projects/efds-fcpf-forecasting-engine/.venv/lib/python3.11/site-packages/lightning/pytorch/loops/fit_loop.py:194, in _FitLoop.run(self)
193 def run(self) -> None:
--> 194 self.setup_data()
195 if self.skip:
196 return
File ~/projects/efds-fcpf-forecasting-engine/.venv/lib/python3.11/site-packages/lightning/pytorch/loops/fit_loop.py:258, in _FitLoop.setup_data(self)
256 self._data_fetcher = _select_data_fetcher(trainer, RunningStage.TRAINING)
257 self._data_fetcher.setup(combined_loader)
--> 258 iter(self._data_fetcher) # creates the iterator inside the fetcher
259 max_batches = sized_len(combined_loader)
260 self.max_batches = max_batches if max_batches is not None else float("inf")
...
547 fill_value=self.dummy_value,
548 dtype=d[field].dtype,
549 )
TypeError: list indices must be integers or slices, not tuple
```
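The final `TypeError` is what Python raises when a plain `list` is indexed with a tuple, i.e. NumPy-style indexing such as `x[..., :n]`. One possible reading of the traceback (an assumption; the frames between the fit loop and line 547 are elided) is that the `past_feat_dynamic_cat` value reaches the failing transformation as a Python list rather than a NumPy array. The sketch below reproduces the message in isolation and shows that the same indexing works once the value is converted with `np.asarray`. The helper `as_2d_array` is hypothetical and not part of GluonTS.

```python
import numpy as np


def as_2d_array(value, dtype=None):
    """Hypothetical helper: turn a list of 1-D series into a (num_features, T) array."""
    arr = np.asarray(value, dtype=dtype)
    # A single 1-D series is treated as one feature row.
    if arr.ndim == 1:
        arr = arr[np.newaxis, :]
    return arr


# Same kind of value as "past_feat_dynamic_cat" in the reproduction:
# a plain Python list holding one integer array of length 100.
past_feat_dynamic_cat = [np.random.randint(0, 10, size=100)]

try:
    # NumPy-style indexing with a tuple is not supported by a plain list.
    past_feat_dynamic_cat[..., :24]
except TypeError as err:
    print(f"TypeError: {err}")  # TypeError: list indices must be integers or slices, not tuple

# After conversion, the same indexing selects the first 24 time steps of each feature.
arr = as_2d_array(past_feat_dynamic_cat, dtype=np.int64)
print(arr.shape, arr[..., :24].shape)  # (1, 100) (1, 24)
```

If that reading holds, passing the field as an array in the data dictionary, for example `"past_feat_dynamic_cat": np.array([np.random.randint(0, 10, size=100)])`, may be worth trying; this has not been verified against GluonTS 0.15.1.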
Environment
- Operating system: Debian GNU/Linux 11 (bullseye)
- Python version: 3.11
- GluonTS version: 0.15.1
- Torch version: 2.4.0
- pytorch_lightning version: 2.1.4
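Versions like the ones above can be gathered in one go with the standard library's `importlib.metadata`. This is a small sketch; the distribution names in `PACKAGES` are assumptions about what is installed. Because the traceback imports from `lightning.pytorch`, the `lightning` package is listed alongside `pytorch_lightning`. Note that `platform.platform()` reports the kernel/platform string, not the distribution name (e.g. "Debian GNU/Linux 11").

```python
import platform
import sys
from importlib.metadata import PackageNotFoundError, version

# Distribution names to report (assumed; adjust to the environment).
PACKAGES = ["gluonts", "torch", "lightning", "pytorch_lightning"]


def environment_report(packages=PACKAGES):
    """Return platform details plus installed package versions (None if not installed)."""
    report = {
        "Operating system": platform.platform(),
        "Python version": sys.version.split()[0],
    }
    for name in packages:
        try:
            report[name] = version(name)
        except PackageNotFoundError:
            report[name] = None
    return report


if __name__ == "__main__":
    # Print in the same "- key: value" shape as the list above.
    for key, value in environment_report().items():
        print(f"- {key}: {value}")
```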