diff --git a/python/ray/train/examples/pytorch_geometric/distributed_sage_example.py b/python/ray/train/examples/pytorch_geometric/distributed_sage_example.py
index 4b880607de1c..fb8efa1894ae 100644
--- a/python/ray/train/examples/pytorch_geometric/distributed_sage_example.py
+++ b/python/ray/train/examples/pytorch_geometric/distributed_sage_example.py
@@ -1,5 +1,4 @@
-# Adapted from https://github.com/pyg-team/pytorch_geometric/blob/master/examples
-# /multi_gpu/distributed_sampling.py.
+# Adapted from https://github.com/pyg-team/pytorch_geometric/blob/2.1.0/examples/multi_gpu/distributed_sampling.py
 import os
 import argparse
 
@@ -46,7 +45,7 @@ def test(self, x_all, subgraph_loader):
             for batch_size, n_id, adj in subgraph_loader:
                 edge_index, _, size = adj
-                x = x_all[n_id].to(train.torch.get_device())
+                x = x_all[n_id]
                 x_target = x[: size[1]]
                 x = self.convs[i]((x, x_target), edge_index)
                 if i != self.num_layers - 1:
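
For context, the second hunk drops the per-batch device transfer in the layer-wise evaluation loop, so the gathered features stay on whichever device x_all already lives on. Below is a minimal sketch of the resulting loop written as a standalone function; the name layerwise_inference and its signature are illustrative (the example keeps this logic inside the test method), and keeping x_all on CPU is an assumption about the intent of the change.

    import torch
    import torch.nn.functional as F

    @torch.no_grad()
    def layerwise_inference(convs, x_all, subgraph_loader):
        # Sketch of the post-change loop: apply one conv layer at a time
        # over the full graph, with no explicit .to(device) call per batch.
        for i, conv in enumerate(convs):
            xs = []
            for batch_size, n_id, adj in subgraph_loader:
                edge_index, _, size = adj
                x = x_all[n_id]          # stays on x_all's device (assumed CPU)
                x_target = x[: size[1]]  # target nodes come first in n_id
                x = conv((x, x_target), edge_index)
                if i != len(convs) - 1:
                    x = F.relu(x)
                xs.append(x)
            x_all = torch.cat(xs, dim=0)  # input features for the next layer
        return x_all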