diff --git a/genesis/grad/creation_ops.py b/genesis/grad/creation_ops.py index ca01e32a..863edce0 100644 --- a/genesis/grad/creation_ops.py +++ b/genesis/grad/creation_ops.py @@ -89,7 +89,7 @@ def from_torch(torch_tensor, dtype=None, requires_grad=False, detach=True, scene ) requires_grad = True - gs_tensor = Tensor(torch_tensor.to(gs.device).to(dtype), scene=scene).clone() + gs_tensor = Tensor(torch_tensor.to(device=gs.device, dtype=dtype), scene=scene).clone() if detach: gs_tensor = gs_tensor.detach(sceneless=False)