diff --git a/examples/quickstart-pytorch/pytorchexample/task.py b/examples/quickstart-pytorch/pytorchexample/task.py index 8e0808871616..d115c9f1a469 100644 --- a/examples/quickstart-pytorch/pytorchexample/task.py +++ b/examples/quickstart-pytorch/pytorchexample/task.py @@ -100,6 +100,7 @@ def train(net, trainloader, valloader, epochs, learning_rate, device): def test(net, testloader, device): """Validate the model on the test set.""" + net.to(device) # move model to GPU if available criterion = torch.nn.CrossEntropyLoss() correct, loss = 0, 0.0 with torch.no_grad():