RuntimeError: cumsum_cuda_kernel does not have a deterministic implementation, but you set 'torch.use_deterministic_algorithms(True)' #21
Comments
It seems to be a problem with the torchmetrics version. Could you try downgrading your torchmetrics?
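For reference, a downgrade is just a pinned install inside the mgca environment; which older release actually sidesteps this code path is an assumption here and may take some trial and error:

    pip install "torchmetrics==0.6.0"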
Hi! I have the same issue. Did you fix it?
Could you try changing MGCA/mgca/models/mgca/mgca_module.py, line 476 (at commit b9ec84f)?
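What line 476 contains is an assumption here; a plausible reading is that it is where the Lightning Trainer's deterministic flag gets set, in which case the suggested change would be roughly the following hypothetical sketch (all arguments other than deterministic are placeholders, not taken from the repository):

    from pytorch_lightning import Trainer

    # Hypothetical sketch: assumes the referenced line controls Lightning's
    # deterministic mode. Relaxing it lets non-deterministic CUDA kernels such
    # as cumsum_cuda_kernel run without raising.
    trainer = Trainer(
        gpus=1,               # placeholder
        max_epochs=50,        # placeholder
        deterministic=False,  # presumably True at line 476; False avoids the error
    )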
| Name | Type | Params
0 | train_auc | AUROC | 0
1 | val_auc | AUROC | 0
2 | test_auc | AUROC | 0
3 | backbone | ImageEncoder | 89.5 M
4 | linear_layer | SSLEvaluator | 769
769 Trainable params
89.5 M Non-trainable params
89.5 M Total params
357.928 Total estimated model params size (MB)
Validation sanity check: 0%| | 0/2 [00:00<?, ?it/s]
Traceback (most recent call last):
File "mgca_finetuner.py", line 128, in
cli_main()
File "mgca_finetuner.py", line 119, in cli_main
trainer.fit(tuner, datamodule)
File "/opt/conda/envs/mgca/lib/python3.8/site-packages/pytorch_lightning/trainer/trainer.py", line 740, in fit
self._call_and_handle_interrupt(
File "/opt/conda/envs/mgca/lib/python3.8/site-packages/pytorch_lightning/trainer/trainer.py", line 685, in _call_and_handle_interrupt
return trainer_fn(*args, **kwargs)
File "/opt/conda/envs/mgca/lib/python3.8/site-packages/pytorch_lightning/trainer/trainer.py", line 777, in _fit_impl
self._run(model, ckpt_path=ckpt_path)
File "/opt/conda/envs/mgca/lib/python3.8/site-packages/pytorch_lightning/trainer/trainer.py", line 1199, in _run
self._dispatch()
File "/opt/conda/envs/mgca/lib/python3.8/site-packages/pytorch_lightning/trainer/trainer.py", line 1279, in _dispatch
self.training_type_plugin.start_training(self)
File "/opt/conda/envs/mgca/lib/python3.8/site-packages/pytorch_lightning/plugins/training_type/training_type_plugin.py", line 202, in start_training
self._results = trainer.run_stage()
File "/opt/conda/envs/mgca/lib/python3.8/site-packages/pytorch_lightning/trainer/trainer.py", line 1289, in run_stage
return self._run_train()
File "/opt/conda/envs/mgca/lib/python3.8/site-packages/pytorch_lightning/trainer/trainer.py", line 1311, in _run_train
self._run_sanity_check(self.lightning_module)
File "/opt/conda/envs/mgca/lib/python3.8/site-packages/pytorch_lightning/trainer/trainer.py", line 1375, in _run_sanity_check
self._evaluation_loop.run()
File "/opt/conda/envs/mgca/lib/python3.8/site-packages/pytorch_lightning/loops/base.py", line 145, in run
self.advance(*args, **kwargs)
File "/opt/conda/envs/mgca/lib/python3.8/site-packages/pytorch_lightning/loops/dataloader/evaluation_loop.py", line 110, in advance
dl_outputs = self.epoch_loop.run(dataloader, dataloader_idx, dl_max_batches, self.num_dataloaders)
File "/opt/conda/envs/mgca/lib/python3.8/site-packages/pytorch_lightning/loops/base.py", line 145, in run
self.advance(*args, **kwargs)
File "/opt/conda/envs/mgca/lib/python3.8/site-packages/pytorch_lightning/loops/epoch/evaluation_epoch_loop.py", line 122, in advance
output = self._evaluation_step(batch, batch_idx, dataloader_idx)
File "/opt/conda/envs/mgca/lib/python3.8/site-packages/pytorch_lightning/loops/epoch/evaluation_epoch_loop.py", line 217, in _evaluation_step
output = self.trainer.accelerator.validation_step(step_kwargs)
File "/opt/conda/envs/mgca/lib/python3.8/site-packages/pytorch_lightning/accelerators/accelerator.py", line 239, in validation_step
return self.training_type_plugin.validation_step(*step_kwargs.values())
File "/opt/conda/envs/mgca/lib/python3.8/site-packages/pytorch_lightning/plugins/training_type/training_type_plugin.py", line 219, in validation_step
return self.model.validation_step(*args, **kwargs)
File "/code/mgca/models/ssl_finetuner.py", line 75, in validation_step
self.val_auc(torch.sigmoid(logits).float(), y.long())
File "/opt/conda/envs/mgca/lib/python3.8/site-packages/torch/nn/modules/module.py", line 1501, in _call_impl
return forward_call(*args, **kwargs)
File "/opt/conda/envs/mgca/lib/python3.8/site-packages/torchmetrics/metric.py", line 264, in forward
self._forward_cache = self.compute()
File "/opt/conda/envs/mgca/lib/python3.8/site-packages/torchmetrics/metric.py", line 440, in wrapped_func
value = compute(*args, **kwargs)
File "/opt/conda/envs/mgca/lib/python3.8/site-packages/torchmetrics/classification/auroc.py", line 175, in compute
return _auroc_compute(
File "/opt/conda/envs/mgca/lib/python3.8/site-packages/torchmetrics/functional/classification/auroc.py", line 150, in _auroc_compute
fpr, tpr, _ = roc(preds, target, num_classes, pos_label, sample_weights)
File "/opt/conda/envs/mgca/lib/python3.8/site-packages/torchmetrics/functional/classification/roc.py", line 282, in roc
return _roc_compute(preds, target, num_classes, pos_label, sample_weights)
File "/opt/conda/envs/mgca/lib/python3.8/site-packages/torchmetrics/functional/classification/roc.py", line 193, in _roc_compute
return _roc_compute_single_class(preds, target, pos_label, sample_weights)
File "/opt/conda/envs/mgca/lib/python3.8/site-packages/torchmetrics/functional/classification/roc.py", line 67, in _roc_compute_single_class
fps, tps, thresholds = _binary_clf_curve(
File "/opt/conda/envs/mgca/lib/python3.8/site-packages/torchmetrics/functional/classification/precision_recall_curve.py", line 52, in _binary_clf_curve
tps = torch.cumsum(target * weight, dim=0)[threshold_idxs]
RuntimeError: cumsum_cuda_kernel does not have a deterministic implementation, but you set 'torch.use_deterministic_algorithms(True)'. You can turn off determinism just for this operation, or you can use the 'warn_only=True' option, if that's acceptable for your application. You can also file an issue at https://github.com/pytorch/pytorch/issues to help us prioritize adding deterministic support for this operation.
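The error text itself offers one escape hatch: keep deterministic mode but let unsupported operations fall back with a warning via warn_only=True (available in recent PyTorch releases). A minimal sketch, assuming the call can be placed before trainer.fit and is not overridden again by Lightning's own deterministic handling:

    import torch

    # Keep deterministic algorithms where they exist, but only warn (instead of
    # raising) for ops like cumsum_cuda_kernel that have no deterministic CUDA
    # implementation.
    torch.use_deterministic_algorithms(True, warn_only=True)

An alternative, since the failure happens inside the AUROC update in validation_step, is to hand the metric CPU tensors (for example, calling self.val_auc with .cpu() inputs in ssl_finetuner.py); whether that interacts cleanly with multi-GPU metric syncing in this setup is not verified here.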