Skip to content

Commit

Permalink
fix tests
Browse files Browse the repository at this point in the history
  • Loading branch information
alexbarghi-nv committed Jan 27, 2025
1 parent 4cbfb46 commit 60be19e
Show file tree
Hide file tree
Showing 12 changed files with 110 additions and 64 deletions.
4 changes: 2 additions & 2 deletions python/cugraph-pyg/cugraph_pyg/examples/gcn_dist_mnmg.py
Original file line number Diff line number Diff line change
Expand Up @@ -150,10 +150,10 @@ def load_partitioned_data(
)

# Load features
feature_store["node", "x"] = torch.load(
feature_store["node", "x", None] = torch.load(
os.path.join(feature_path, f"rank={rank}_x.pt")
)
feature_store["node", "y"] = torch.load(
feature_store["node", "y", None] = torch.load(
os.path.join(feature_path, f"rank={rank}_y.pt")
)

Expand Down
4 changes: 2 additions & 2 deletions python/cugraph-pyg/cugraph_pyg/examples/gcn_dist_sg.py
Original file line number Diff line number Diff line change
Expand Up @@ -128,8 +128,8 @@ def load_data(
] = data.edge_index

feature_store = cugraph_pyg.data.TensorDictFeatureStore()
feature_store["node", "x"] = data.x
feature_store["node", "y"] = data.y
feature_store["node", "x", None] = data.x
feature_store["node", "y", None] = data.y

return (
(feature_store, graph_store),
Expand Down
4 changes: 2 additions & 2 deletions python/cugraph-pyg/cugraph_pyg/examples/gcn_dist_snmg.py
Original file line number Diff line number Diff line change
Expand Up @@ -108,8 +108,8 @@ def run_train(
] = ixr

feature_store = TensorDictFeatureStore()
feature_store["node", "x"] = data.x
feature_store["node", "y"] = data.y
feature_store["node", "x", None] = data.x
feature_store["node", "y", None] = data.y

dist.barrier()

Expand Down
24 changes: 15 additions & 9 deletions python/cugraph-pyg/cugraph_pyg/examples/rgcn_link_class_mnmg.py
Original file line number Diff line number Diff line change
@@ -1,4 +1,4 @@
# Copyright (c) 2024, NVIDIA CORPORATION.
# Copyright (c) 2024-2025, NVIDIA CORPORATION.
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
Expand Down Expand Up @@ -93,7 +93,11 @@ def train(epoch, model, optimizer, train_loader, edge_feature_store, num_steps=N
optimizer.zero_grad()

for i, batch in enumerate(train_loader):
r = edge_feature_store[("n", "e", "n"), "rel"][batch.e_id].flatten().cuda()
r = (
edge_feature_store[("n", "e", "n"), "rel", None][batch.e_id]
.flatten()
.cuda()
)
z = model.encode(batch.edge_index, r)

loss = model.recon_loss(z, batch.edge_index)
Expand Down Expand Up @@ -301,13 +305,18 @@ def load_partitioned_data(rank, edge_path, rel_path, pos_path, neg_path, meta_pa
feature_store = TensorDictFeatureStore()
edge_feature_store = WholeFeatureStore()

with open(meta_path, "r") as f:
meta = json.load(f)

print("num nodes:", meta["num_nodes"])

# Load edge index
graph_store[("n", "e", "n"), "coo"] = torch.load(
os.path.join(edge_path, f"rank={rank}.pt")
)
graph_store[
("n", "e", "n"), "coo", False, (meta["num_nodes"], meta["num_nodes"])
] = torch.load(os.path.join(edge_path, f"rank={rank}.pt"))

# Load edge rel type
edge_feature_store[("n", "e", "n"), "rel"] = torch.load(
edge_feature_store[("n", "e", "n"), "rel", None] = torch.load(
os.path.join(rel_path, f"rank={rank}.pt")
)

Expand All @@ -333,9 +342,6 @@ def load_partitioned_data(rank, edge_path, rel_path, pos_path, neg_path, meta_pa
splits[stage]["tail_neg"] = tail_neg
splits[stage]["relation"] = relation

with open(meta_path, "r") as f:
meta = json.load(f)

return (feature_store, graph_store), edge_feature_store, splits, meta


Expand Down
8 changes: 5 additions & 3 deletions python/cugraph-pyg/cugraph_pyg/examples/rgcn_link_class_sg.py
Original file line number Diff line number Diff line change
@@ -1,4 +1,4 @@
# Copyright (c) 2024, NVIDIA CORPORATION.
# Copyright (c) 2024-2025, NVIDIA CORPORATION.
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
Expand Down Expand Up @@ -90,8 +90,10 @@ def load_data(
edge_feature_store = TensorDictFeatureStore()
meta = {}

graph_store[("n", "e", "n"), "coo"] = dataset.edge_index
edge_feature_store[("n", "e", "n"), "rel"] = dataset.edge_reltype.pin_memory()
graph_store[
("n", "e", "n"), "coo", False, (dataset.num_nodes, dataset.num_nodes)
] = dataset.edge_index
edge_feature_store[("n", "e", "n"), "rel", None] = dataset.edge_reltype.pin_memory()
meta["num_nodes"] = dataset.num_nodes
meta["num_rels"] = dataset.edge_reltype.max() + 1

Expand Down
12 changes: 7 additions & 5 deletions python/cugraph-pyg/cugraph_pyg/examples/rgcn_link_class_snmg.py
Original file line number Diff line number Diff line change
@@ -1,4 +1,4 @@
# Copyright (c) 2024, NVIDIA CORPORATION.
# Copyright (c) 2024-2025, NVIDIA CORPORATION.
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
Expand Down Expand Up @@ -111,11 +111,13 @@ def load_data(
feature_store = TensorDictFeatureStore() # empty fs required by PyG
edge_feature_store = WholeFeatureStore()

graph_store[("n", "e", "n"), "coo"] = torch.tensor_split(
data.edge_index.cuda(), world_size, dim=1
)[rank]
print("num nodes:", data.num_nodes)

graph_store[
("n", "e", "n"), "coo", False, (data.num_nodes, data.num_nodes)
] = torch.tensor_split(data.edge_index.cuda(), world_size, dim=1)[rank]

edge_feature_store[("n", "e", "n"), "rel"] = torch.tensor_split(
edge_feature_store[("n", "e", "n"), "rel", None] = torch.tensor_split(
data.edge_reltype.cuda(),
world_size,
)[rank]
Expand Down
10 changes: 5 additions & 5 deletions python/cugraph-pyg/cugraph_pyg/tests/data/test_feature_store.py
Original file line number Diff line number Diff line change
@@ -1,4 +1,4 @@
# Copyright (c) 2024, NVIDIA CORPORATION.
# Copyright (c) 2024-2025, NVIDIA CORPORATION.
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
Expand Down Expand Up @@ -30,15 +30,15 @@ def test_tensordict_feature_store_basic_api():

other_features = torch.randint(1024, (10, 5))

feature_store["node", "feat0"] = node_features_0
feature_store["node", "feat1"] = node_features_1
feature_store["other", "feat"] = other_features
feature_store["node", "feat0", None] = node_features_0
feature_store["node", "feat1", None] = node_features_1
feature_store["other", "feat", None] = other_features

assert (feature_store["node"]["feat0"][:] == node_features_0).all()
assert (feature_store["node"]["feat1"][:] == node_features_1).all()
assert (feature_store["other"]["feat"][:] == other_features).all()

assert len(feature_store.get_all_tensor_attrs()) == 3

del feature_store["node", "feat0"]
del feature_store["node", "feat0", None]
assert len(feature_store.get_all_tensor_attrs()) == 2
Original file line number Diff line number Diff line change
@@ -1,4 +1,4 @@
# Copyright (c) 2024, NVIDIA CORPORATION.
# Copyright (c) 2024-2025, NVIDIA CORPORATION.
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
Expand Down Expand Up @@ -46,22 +46,24 @@ def run_test_wholegraph_feature_store_basic_api(rank, world_size, dtype):
features = features.reshape((features.numel() // 100, 100)).to(torch_dtype)

tensordict_store = TensorDictFeatureStore()
tensordict_store["node", "fea"] = features
tensordict_store["node", "fea", None] = features

whole_store = WholeFeatureStore()
whole_store["node", "fea"] = torch.tensor_split(features, world_size)[rank]
whole_store["node", "fea", None] = torch.tensor_split(features, world_size)[rank]

ix = torch.arange(features.shape[0])
assert (
whole_store["node", "fea"][ix].cpu() == tensordict_store["node", "fea"][ix]
whole_store["node", "fea", None][ix].cpu()
== tensordict_store["node", "fea", None][ix]
).all()

label = torch.arange(0, features.shape[0]).reshape((features.shape[0], 1))
tensordict_store["node", "label"] = label
whole_store["node", "label"] = torch.tensor_split(label, world_size)[rank]
tensordict_store["node", "label", None] = label
whole_store["node", "label", None] = torch.tensor_split(label, world_size)[rank]

assert (
whole_store["node", "fea"][ix].cpu() == tensordict_store["node", "fea"][ix]
whole_store["node", "fea", None][ix].cpu()
== tensordict_store["node", "fea", None][ix]
).all()

pylibwholegraph.torch.initialize.finalize()
Expand Down
8 changes: 6 additions & 2 deletions python/cugraph-pyg/cugraph_pyg/tests/data/test_graph_store.py
Original file line number Diff line number Diff line change
@@ -1,4 +1,4 @@
# Copyright (c) 2024, NVIDIA CORPORATION.
# Copyright (c) 2024-2025, NVIDIA CORPORATION.
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
Expand Down Expand Up @@ -30,8 +30,12 @@ def test_graph_store_basic_api():

ei = torch.stack([dst, src])

num_nodes = karate.number_of_nodes()

graph_store = GraphStore()
graph_store.put_edge_index(ei, ("person", "knows", "person"), "coo")
graph_store.put_edge_index(
ei, ("person", "knows", "person"), "coo", False, (num_nodes, num_nodes)
)

rei = graph_store.get_edge_index(("person", "knows", "person"), "coo")

Expand Down
Original file line number Diff line number Diff line change
@@ -1,4 +1,4 @@
# Copyright (c) 2024, NVIDIA CORPORATION.
# Copyright (c) 2024-2025, NVIDIA CORPORATION.
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
Expand Down Expand Up @@ -30,8 +30,12 @@ def test_graph_store_basic_api_mg():

ei = torch.stack([dst, src])

num_nodes = karate.number_of_nodes()

graph_store = GraphStore(is_multi_gpu=True)
graph_store.put_edge_index(ei, ("person", "knows", "person"), "coo")
graph_store.put_edge_index(
ei, ("person", "knows", "person"), "coo", False, (num_nodes, num_nodes)
)

rei = graph_store.get_edge_index(("person", "knows", "person"), "coo")

Expand Down
50 changes: 37 additions & 13 deletions python/cugraph-pyg/cugraph_pyg/tests/loader/test_neighbor_loader.py
Original file line number Diff line number Diff line change
Expand Up @@ -37,11 +37,15 @@ def test_neighbor_loader():

ei = torch.stack([dst, src])

num_nodes = karate.number_of_nodes()

graph_store = GraphStore()
graph_store.put_edge_index(ei, ("person", "knows", "person"), "coo")
graph_store.put_edge_index(
ei, ("person", "knows", "person"), "coo", False, (num_nodes, num_nodes)
)

feature_store = TensorDictFeatureStore()
feature_store["person", "feat"] = torch.randint(128, (34, 16))
feature_store["person", "feat", None] = torch.randint(128, (34, 16))

loader = NeighborLoader(
(feature_store, graph_store),
Expand All @@ -51,7 +55,7 @@ def test_neighbor_loader():

for batch in loader:
assert isinstance(batch, torch_geometric.data.Data)
assert (feature_store["person", "feat"][batch.n_id] == batch.feat).all()
assert (feature_store["person", "feat", None][batch.n_id] == batch.feat).all()


@pytest.mark.skipif(isinstance(torch, MissingModule), reason="torch not available")
Expand All @@ -64,12 +68,16 @@ def test_neighbor_loader_biased():
]
)

num_nodes = 6

graph_store = GraphStore()
graph_store.put_edge_index(eix, ("person", "knows", "person"), "coo")
graph_store.put_edge_index(
eix, ("person", "knows", "person"), "coo", False, (num_nodes, num_nodes)
)

feature_store = TensorDictFeatureStore()
feature_store["person", "feat"] = torch.randint(128, (6, 12))
feature_store[("person", "knows", "person"), "bias"] = torch.tensor(
feature_store["person", "feat", None] = torch.randint(128, (6, 12))
feature_store[("person", "knows", "person"), "bias", None] = torch.tensor(
[0, 12, 14], dtype=torch.float32
)

Expand Down Expand Up @@ -104,7 +112,7 @@ def test_link_neighbor_loader_basic(
feature_store = TensorDictFeatureStore()

eix = torch.randperm(num_edges)[:select_edges]
graph_store[("n", "e", "n"), "coo"] = torch.stack(
graph_store[("n", "e", "n"), "coo", False, (num_nodes, num_nodes)] = torch.stack(
[
torch.randint(0, num_nodes, (num_edges,)),
torch.randint(0, num_nodes, (num_edges,)),
Expand Down Expand Up @@ -140,7 +148,7 @@ def test_link_neighbor_loader_negative_sampling_basic(batch_size):
feature_store = TensorDictFeatureStore()

eix = torch.randperm(num_edges)[:select_edges]
graph_store[("n", "e", "n"), "coo"] = torch.stack(
graph_store[("n", "e", "n"), "coo", False, (num_nodes, num_nodes)] = torch.stack(
[
torch.randint(0, num_nodes, (num_edges,)),
torch.randint(0, num_nodes, (num_edges,)),
Expand Down Expand Up @@ -174,7 +182,7 @@ def test_link_neighbor_loader_negative_sampling_uneven(batch_size):
feature_store = TensorDictFeatureStore()

eix = torch.randperm(num_edges)[:select_edges]
graph_store[("n", "e", "n"), "coo"] = torch.stack(
graph_store[("n", "e", "n"), "coo", False, (num_nodes, num_nodes)] = torch.stack(
[
torch.randint(0, num_nodes, (num_edges,)),
torch.randint(0, num_nodes, (num_edges,)),
Expand Down Expand Up @@ -205,11 +213,19 @@ def test_neighbor_loader_hetero_basic():
asrc = torch.tensor([0, 1, 2, 3, 3, 0]) # author
adst = torch.tensor([0, 1, 2, 3, 4, 5]) # paper

num_authors = 4
num_papers = 6

graph_store = GraphStore()
feature_store = TensorDictFeatureStore()

graph_store[("paper", "cites", "paper"), "coo"] = [src, dst]
graph_store[("author", "writes", "paper"), "coo"] = [asrc, adst]
graph_store[("paper", "cites", "paper"), "coo", False, (num_papers, num_papers)] = [
src,
dst,
]
graph_store[
("author", "writes", "paper"), "coo", False, (num_authors, num_papers)
] = [asrc, adst]

from cugraph_pyg.loader import NeighborLoader

Expand All @@ -235,11 +251,19 @@ def test_neighbor_loader_hetero_single_etype():
asrc = torch.tensor([0, 1, 2, 3, 3, 0]) # author
adst = torch.tensor([0, 1, 2, 3, 4, 5]) # paper

num_authors = 4
num_papers = 6

graph_store = GraphStore()
feature_store = TensorDictFeatureStore()

graph_store[("paper", "cites", "paper"), "coo"] = [src, dst]
graph_store[("author", "writes", "paper"), "coo"] = [asrc, adst]
graph_store[("paper", "cites", "paper"), "coo", False, (num_papers, num_papers)] = [
src,
dst,
]
graph_store[
("author", "writes", "paper"), "coo", False, (num_authors, num_papers)
] = [asrc, adst]

from cugraph_pyg.loader import NeighborLoader

Expand Down
Loading

0 comments on commit 60be19e

Please sign in to comment.