From f4702c9a119118a417c0645485315e56565f5723 Mon Sep 17 00:00:00 2001 From: ddelange <14880945+ddelange@users.noreply.github.com> Date: Thu, 1 Dec 2022 06:44:32 +0100 Subject: [PATCH] Fix RuntimeError in adapted pyg example https://buildkite.com/ray-project/oss-ci-build-pr/builds/6341#0184b34e-75c5-41b6-9487-64ac1a84d7bc/2036-2444 Signed-off-by: ddelange <14880945+ddelange@users.noreply.github.com> --- .../examples/pytorch_geometric/distributed_sage_example.py | 5 ++--- 1 file changed, 2 insertions(+), 3 deletions(-) diff --git a/python/ray/train/examples/pytorch_geometric/distributed_sage_example.py b/python/ray/train/examples/pytorch_geometric/distributed_sage_example.py index 4b880607de1c..fb8efa1894ae 100644 --- a/python/ray/train/examples/pytorch_geometric/distributed_sage_example.py +++ b/python/ray/train/examples/pytorch_geometric/distributed_sage_example.py @@ -1,5 +1,4 @@ -# Adapted from https://github.com/pyg-team/pytorch_geometric/blob/master/examples -# /multi_gpu/distributed_sampling.py. +# Adapted from https://github.com/pyg-team/pytorch_geometric/blob/2.1.0/examples/multi_gpu/distributed_sampling.py import os import argparse @@ -46,7 +45,7 @@ def test(self, x_all, subgraph_loader): for batch_size, n_id, adj in subgraph_loader: edge_index, _, size = adj - x = x_all[n_id].to(train.torch.get_device()) + x = x_all[n_id] x_target = x[: size[1]] x = self.convs[i]((x, x_target), edge_index) if i != self.num_layers - 1: