diff --git a/py/rvspecfit/nn/train_interpolator.py b/py/rvspecfit/nn/train_interpolator.py index 8cf3814..78ed44b 100644 --- a/py/rvspecfit/nn/train_interpolator.py +++ b/py/rvspecfit/nn/train_interpolator.py @@ -202,7 +202,7 @@ def main(args): Tvecs0 = torch.FloatTensor(data=vecs) Tdat0 = torch.as_tensor(dats) - batch_on_dev = args.batch_on_dev + batch_on_dev = args.batch_on_device if not batch_on_dev: Tvecs0 = Tvecs0.to(train_dev)