Skip to content

Commit

Permalink
Added all items into test_dec_graph for all users
Browse files Browse the repository at this point in the history
  • Loading branch information
darrylong committed Oct 18, 2023
1 parent 4737350 commit 66e50aa
Showing 1 changed file with 44 additions and 2 deletions.
46 changes: 44 additions & 2 deletions cornac/models/gcmc/gcmc.py
Original file line number Diff line number Diff line change
Expand Up @@ -133,6 +133,44 @@ def _generate_dec_graph(data_set):
)


def _generate_test_dec_graph(data_set):
"""
Generates decoding graph given a cornac data set
Parameters
----------
data_set : cornac.data.dataset.Dataset
The data set as provided by cornac
Returns
-------
graph : dgl.heterograph
Heterograph containing user-item edges and nodes
"""
uid_list = data_set.uir_tuple[0]
uid_list = np.unique(uid_list)

u_list = np.array([user_idx for _ in range(data_set.total_items) for user_idx in uid_list])
i_list = np.array([item_idx for item_idx in range(data_set.total_items) for _ in uid_list])

rating_pairs = (u_list, i_list)
ones = np.ones_like(rating_pairs[0])
user_item_ratings_coo = sp.coo_matrix(
(ones, rating_pairs),
shape=(data_set.total_users, data_set.total_items),
dtype=np.float32,
)

graph = dgl.bipartite_from_scipy(
user_item_ratings_coo, utype="_U", etype="_E", vtype="_V"
)

return dgl.heterograph(
{("user", "rate", "item"): graph.edges()},
num_nodes_dict={"user": data_set.total_users, "item": data_set.total_items},
)


class Model:
def __init__(
self,
Expand Down Expand Up @@ -458,7 +496,7 @@ def predict(self, test_set):
Dictionary containing '{user_idx}-{item_idx}' as key
and {score} as value.
"""
test_dec_graph = _generate_dec_graph(test_set)
test_dec_graph = _generate_test_dec_graph(test_set)
test_dec_graph = test_dec_graph.int().to(self.device)

self.net.eval()
Expand All @@ -479,7 +517,11 @@ def predict(self, test_set):

test_pred_ratings = test_pred_ratings.cpu().numpy()

(u_list, i_list, _) = test_set.uir_tuple
uid_list = test_set.uir_tuple[0]
uid_list = np.unique(uid_list)

u_list = np.array([user_idx for _ in range(test_set.total_items) for user_idx in uid_list])
i_list = np.array([item_idx for item_idx in range(test_set.total_items) for _ in uid_list])

u_list = u_list.tolist()
i_list = i_list.tolist()
Expand Down

0 comments on commit 66e50aa

Please sign in to comment.