diff --git a/linear_relational/lib/torch_utils.py b/linear_relational/lib/torch_utils.py index 7a9ca86..a1a4eeb 100644 --- a/linear_relational/lib/torch_utils.py +++ b/linear_relational/lib/torch_utils.py @@ -22,7 +22,7 @@ def get_device(model: nn.Module) -> torch.device: """ Returns the device on which the model is running. """ - if isinstance(model.device, torch.device): + if hasattr(model, "device") and isinstance(model.device, torch.device): return model.device return next(model.parameters()).device