From 66e50aaa611f6d070bda056a954b90038ba5b5ff Mon Sep 17 00:00:00 2001 From: Darryl Ong Date: Wed, 18 Oct 2023 16:59:38 +0800 Subject: [PATCH] Added all items into test_dec_graph for all users --- cornac/models/gcmc/gcmc.py | 46 ++++++++++++++++++++++++++++++++++++-- 1 file changed, 44 insertions(+), 2 deletions(-) diff --git a/cornac/models/gcmc/gcmc.py b/cornac/models/gcmc/gcmc.py index 6b1229218..7bcaef0df 100644 --- a/cornac/models/gcmc/gcmc.py +++ b/cornac/models/gcmc/gcmc.py @@ -133,6 +133,44 @@ def _generate_dec_graph(data_set): ) +def _generate_test_dec_graph(data_set): + """ + Generates decoding graph given a cornac data set + + Parameters + ---------- + data_set : cornac.data.dataset.Dataset + The data set as provided by cornac + + Returns + ------- + graph : dgl.heterograph + Heterograph containing user-item edges and nodes + """ + uid_list = data_set.uir_tuple[0] + uid_list = np.unique(uid_list) + + u_list = np.array([user_idx for _ in range(data_set.total_items) for user_idx in uid_list]) + i_list = np.array([item_idx for item_idx in range(data_set.total_items) for _ in uid_list]) + + rating_pairs = (u_list, i_list) + ones = np.ones_like(rating_pairs[0]) + user_item_ratings_coo = sp.coo_matrix( + (ones, rating_pairs), + shape=(data_set.total_users, data_set.total_items), + dtype=np.float32, + ) + + graph = dgl.bipartite_from_scipy( + user_item_ratings_coo, utype="_U", etype="_E", vtype="_V" + ) + + return dgl.heterograph( + {("user", "rate", "item"): graph.edges()}, + num_nodes_dict={"user": data_set.total_users, "item": data_set.total_items}, + ) + + class Model: def __init__( self, @@ -458,7 +496,7 @@ def predict(self, test_set): Dictionary containing '{user_idx}-{item_idx}' as key and {score} as value. """ - test_dec_graph = _generate_dec_graph(test_set) + test_dec_graph = _generate_test_dec_graph(test_set) test_dec_graph = test_dec_graph.int().to(self.device) self.net.eval() @@ -479,7 +517,11 @@ def predict(self, test_set): test_pred_ratings = test_pred_ratings.cpu().numpy() - (u_list, i_list, _) = test_set.uir_tuple + uid_list = test_set.uir_tuple[0] + uid_list = np.unique(uid_list) + + u_list = np.array([user_idx for _ in range(test_set.total_items) for user_idx in uid_list]) + i_list = np.array([item_idx for item_idx in range(test_set.total_items) for _ in uid_list]) u_list = u_list.tolist() i_list = i_list.tolist()