Optimize LightGCN Model #531

Merged · 25 commits · Oct 17, 2023
159 changes: 92 additions & 67 deletions cornac/models/lightgcn/lightgcn.py
@@ -1,9 +1,14 @@
 import torch
 import torch.nn as nn
+import torch.nn.functional as F
 import dgl
 import dgl.function as fn
 
 
+USER_KEY = "user"
+ITEM_KEY = "item"
+
+
 def construct_graph(data_set):
     """
     Generates graph given a cornac data set
@@ -14,89 +19,109 @@ def construct_graph(data_set):
         The data set as provided by cornac
     """
     user_indices, item_indices, _ = data_set.uir_tuple
-    user_nodes, item_nodes = (
-        torch.from_numpy(user_indices),
-        torch.from_numpy(
-            item_indices + data_set.total_users
-        ),  # increment item node idx by num users
-    )
-
-    u = torch.cat([user_nodes, item_nodes], dim=0)
-    v = torch.cat([item_nodes, user_nodes], dim=0)
+    data_dict = {
+        (USER_KEY, "user_item", ITEM_KEY): (user_indices, item_indices),
+        (ITEM_KEY, "item_user", USER_KEY): (item_indices, user_indices),
+    }
+    num_dict = {USER_KEY: data_set.total_users, ITEM_KEY: data_set.total_items}
 
-    g = dgl.graph((u, v), num_nodes=(data_set.total_users + data_set.total_items))
-    return g
+    return dgl.heterograph(data_dict, num_nodes_dict=num_dict)
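As a quick illustration (not part of this diff), the rewritten construct_graph builds a bipartite heterograph with mirrored edge types; a minimal sketch with toy index lists standing in for data_set.uir_tuple:

import dgl

data_dict = {
    ("user", "user_item", "item"): ([0, 0, 1, 2], [0, 1, 1, 0]),
    ("item", "item_user", "user"): ([0, 1, 1, 0], [0, 0, 1, 2]),
}
g = dgl.heterograph(data_dict, num_nodes_dict={"user": 3, "item": 2})
print(g.num_edges("user_item"))  # 4 user->item edges, mirrored as "item_user"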


 class GCNLayer(nn.Module):
-    def __init__(self):
+    def __init__(self, norm_dict):
         super(GCNLayer, self).__init__()
-
-    def forward(self, graph, src_embedding, dst_embedding):
-        with graph.local_scope():
-            inner_product = torch.cat((src_embedding, dst_embedding), dim=0)
-
-            out_degs = graph.out_degrees().to(src_embedding.device).float().clamp(min=1)
-            norm_out_degs = torch.pow(out_degs, -0.5).view(-1, 1)  # D^-1/2
-
-            inner_product = inner_product * norm_out_degs
-
-            graph.ndata["h"] = inner_product
-            graph.update_all(
-                message_func=fn.copy_u("h", "m"), reduce_func=fn.sum("m", "h")
-            )
-
-            res = graph.ndata["h"]
-
-            in_degs = graph.in_degrees().to(src_embedding.device).float().clamp(min=1)
-            norm_in_degs = torch.pow(in_degs, -0.5).view(-1, 1)  # D^-1/2
-
-            res = res * norm_in_degs
-            return res
+        # norm
+        self.norm_dict = norm_dict
+
+    def forward(self, g, feat_dict):
+        funcs = {}  # message and reduce functions dict
+        # for each type of edges, compute messages and reduce them all
+        for srctype, etype, dsttype in g.canonical_etypes:
+            src, dst = g.edges(etype=(srctype, etype, dsttype))
+            norm = self.norm_dict[(srctype, etype, dsttype)]
+            messages = norm * feat_dict[srctype][src]  # compute messages
+            g.edges[(srctype, etype, dsttype)].data[
+                etype
+            ] = messages  # store in edata
+            funcs[(srctype, etype, dsttype)] = (
+                fn.copy_e(etype, "m"),
+                fn.sum("m", "h"),
+            )  # define message and reduce functions
+
+        g.multi_update_all(
+            funcs, "sum"
+        )  # update all; reduce within each edge type first, then across types
+        feature_dict = {}
+        for ntype in g.ntypes:
+            h = F.normalize(g.nodes[ntype].data["h"], dim=1, p=2)  # l2 normalize
+            feature_dict[ntype] = h
+        return feature_dict


 class Model(nn.Module):
-    def __init__(self, user_size, item_size, hidden_size, num_layers=3, device=None):
+    def __init__(self, g, in_size, num_layers, lambda_reg, device=None):
         super(Model, self).__init__()
-        self.user_size = user_size
-        self.item_size = item_size
-        self.hidden_size = hidden_size
-        self.embedding_weights = self._init_weights()
-        self.layers = nn.ModuleList([GCNLayer() for _ in range(num_layers)])
+        self.norm_dict = dict()
+        self.lambda_reg = lambda_reg
         self.device = device
 
-    def forward(self, graph):
-        user_embedding = self.embedding_weights["user_embedding"]
-        item_embedding = self.embedding_weights["item_embedding"]
+        for srctype, etype, dsttype in g.canonical_etypes:
+            src, dst = g.edges(etype=(srctype, etype, dsttype))
+            dst_degree = g.in_degrees(
+                dst, etype=(srctype, etype, dsttype)
+            ).float()  # obtain degrees
+            src_degree = g.out_degrees(src, etype=(srctype, etype, dsttype)).float()
+            norm = torch.pow(src_degree * dst_degree, -0.5).unsqueeze(1)  # compute norm
+            self.norm_dict[(srctype, etype, dsttype)] = norm
 
-        for i, layer in enumerate(self.layers, start=1):
-            if i == 1:
-                embeddings = layer(graph, user_embedding, item_embedding)
-            else:
-                embeddings = layer(
-                    graph, embeddings[: self.user_size], embeddings[self.user_size:]
-                )
+        self.layers = nn.ModuleList([GCNLayer(self.norm_dict) for _ in range(num_layers)])
 
-            user_embedding = user_embedding + embeddings[: self.user_size] * (
-                1 / (i + 1)
-            )
-            item_embedding = item_embedding + embeddings[self.user_size:] * (
-                1 / (i + 1)
-            )
-
-        return user_embedding, item_embedding
+        self.initializer = nn.init.xavier_uniform_
 
-    def _init_weights(self):
-        initializer = nn.init.xavier_uniform_
-
-        weights_dict = nn.ParameterDict(
+        # embeddings for different types of nodes
+        self.feature_dict = nn.ParameterDict(
             {
-                "user_embedding": nn.Parameter(
-                    initializer(torch.empty(self.user_size, self.hidden_size))
-                ),
-                "item_embedding": nn.Parameter(
-                    initializer(torch.empty(self.item_size, self.hidden_size))
-                ),
+                ntype: nn.Parameter(
+                    self.initializer(torch.empty(g.num_nodes(ntype), in_size))
+                )
+                for ntype in g.ntypes
             }
         )
-        return weights_dict
 
+    def forward(self, g, users=None, pos_items=None, neg_items=None):
+        h_dict = {ntype: self.feature_dict[ntype] for ntype in g.ntypes}
+        # obtain features of each layer and concatenate them all
+        user_embeds = h_dict[USER_KEY]
+        item_embeds = h_dict[ITEM_KEY]
+
+        for k, layer in enumerate(self.layers):
+            h_dict = layer(g, h_dict)
+            user_embeds = user_embeds + (h_dict[USER_KEY] * 1 / (k + 1))
+            item_embeds = item_embeds + (h_dict[ITEM_KEY] * 1 / (k + 1))
+
+        u_g_embeddings = user_embeds if users is None else user_embeds[users, :]
+        pos_i_g_embeddings = item_embeds if pos_items is None else item_embeds[pos_items, :]
+        neg_i_g_embeddings = item_embeds if neg_items is None else item_embeds[neg_items, :]
+
+        return u_g_embeddings, pos_i_g_embeddings, neg_i_g_embeddings
+
+    def loss_fn(self, users, pos_items, neg_items):
+        pos_scores = (users * pos_items).sum(1)
+        neg_scores = (users * neg_items).sum(1)
+
+        bpr_loss = F.softplus(neg_scores - pos_scores).mean()
+        reg_loss = (
+            (1 / 2)
+            * (
+                torch.norm(users) ** 2
+                + torch.norm(pos_items) ** 2
+                + torch.norm(neg_items) ** 2
+            )
+            / len(users)
+        )
+
+        return bpr_loss + self.lambda_reg * reg_loss, bpr_loss, reg_loss
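Two things worth noting about the new loss (illustration, not part of the diff): softplus(neg - pos) is algebraically the classic BPR term -log sigmoid(pos - neg), since softplus(x) = log(1 + e^x); and L2 regularization now lives in this explicit reg_loss over the sampled batch embeddings instead of Adam's weight_decay (dropped in fit below). A sanity check with made-up scores:

import torch
import torch.nn.functional as F

pos_scores = torch.tensor([2.0, 0.5])  # hypothetical positive-item scores
neg_scores = torch.tensor([1.0, 1.5])  # hypothetical negative-item scores

bpr_softplus = F.softplus(neg_scores - pos_scores).mean()
bpr_logsigmoid = -torch.log(torch.sigmoid(pos_scores - neg_scores)).mean()
assert torch.allclose(bpr_softplus, bpr_logsigmoid)  # same objective, two spellings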
115 changes: 37 additions & 78 deletions cornac/models/lightgcn/recom_lightgcn.py
@@ -28,21 +28,18 @@ class LightGCN(Recommender):
     name: string, default: 'LightGCN'
         The name of the recommender model.
 
+    emb_size: int, default: 64
+        Size of the node embeddings.
+
     num_epochs: int, default: 1000
-        Maximum number of iterations or the number of epochs
+        Maximum number of iterations or the number of epochs.
 
     learning_rate: float, default: 0.001
         The learning rate that determines the step size at each iteration
 
-    train_batch_size: int, default: 1024
+    batch_size: int, default: 1024
         Mini-batch size used for train set
 
-    test_batch_size: int, default: 100
-        Mini-batch size used for test set
-
-    hidden_dim: int, default: 64
-        The embedding size of the model
-
     num_layers: int, default: 3
         Number of LightGCN Layers
 
@@ -80,11 +77,10 @@ class LightGCN(Recommender):
     def __init__(
         self,
         name="LightGCN",
+        emb_size=64,
         num_epochs=1000,
         learning_rate=0.001,
-        train_batch_size=1024,
-        test_batch_size=100,
-        hidden_dim=64,
+        batch_size=1024,
         num_layers=3,
         early_stopping=None,
         lambda_reg=1e-4,
@@ -93,13 +89,11 @@ def __init__(
         seed=2020,
     ):
         super().__init__(name=name, trainable=trainable, verbose=verbose)
-
+        self.emb_size = emb_size
         self.num_epochs = num_epochs
         self.learning_rate = learning_rate
-        self.hidden_dim = hidden_dim
+        self.batch_size = batch_size
         self.num_layers = num_layers
-        self.train_batch_size = train_batch_size
-        self.test_batch_size = test_batch_size
         self.early_stopping = early_stopping
         self.lambda_reg = lambda_reg
         self.seed = seed
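Hypothetical usage of the renamed hyperparameters (emb_size replaces hidden_dim, batch_size replaces train_batch_size, and test_batch_size is gone); the early_stopping keys follow cornac's early_stop() convention and the other values mirror the defaults:

from cornac.models import LightGCN

model = LightGCN(
    emb_size=64,
    num_epochs=1000,
    learning_rate=0.001,
    batch_size=1024,
    num_layers=3,
    early_stopping={"min_delta": 1e-4, "patience": 50},  # assumed illustrative values
    lambda_reg=1e-4,
    seed=2020,
)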
@@ -135,19 +129,15 @@ def fit(self, train_set, val_set=None):
         if torch.cuda.is_available():
             torch.cuda.manual_seed_all(self.seed)
 
+        graph = construct_graph(train_set).to(self.device)
         model = Model(
-            train_set.total_users,
-            train_set.total_items,
-            self.hidden_dim,
+            graph,
+            self.emb_size,
             self.num_layers,
+            self.lambda_reg,
         ).to(self.device)
 
-        graph = construct_graph(train_set).to(self.device)
-
-        optimizer = torch.optim.Adam(
-            model.parameters(), lr=self.learning_rate, weight_decay=self.lambda_reg
-        )
-        loss_fn = torch.nn.BCELoss(reduction="sum")
+        optimizer = torch.optim.Adam(model.parameters(), lr=self.learning_rate)
 
         # model training
         pbar = trange(
@@ -163,53 +153,43 @@ def fit(self, train_set, val_set=None):
             accum_loss = 0.0
             for batch_u, batch_i, batch_j in tqdm(
                 train_set.uij_iter(
-                    batch_size=self.train_batch_size,
+                    batch_size=self.batch_size,
                     shuffle=True,
                 ),
                 desc="Epoch",
-                total=train_set.num_batches(self.train_batch_size),
+                total=train_set.num_batches(self.batch_size),
                 leave=False,
                 position=1,
                 disable=not self.verbose,
             ):
-                user_embeddings, item_embeddings = model(graph)
-
                 batch_u = torch.from_numpy(batch_u).long().to(self.device)
                 batch_i = torch.from_numpy(batch_i).long().to(self.device)
                 batch_j = torch.from_numpy(batch_j).long().to(self.device)
 
-                user_embed = user_embeddings[batch_u]
-                positive_item_embed = item_embeddings[batch_i]
-                negative_item_embed = item_embeddings[batch_j]
-
-                ui_scores = (user_embed * positive_item_embed).sum(dim=1)
-                uj_scores = (user_embed * negative_item_embed).sum(dim=1)
+                u_g_embeddings, pos_i_g_embeddings, neg_i_g_embeddings = model(
+                    graph, batch_u, batch_i, batch_j
+                )
 
-                loss = loss_fn(
-                    torch.sigmoid(ui_scores - uj_scores), torch.ones_like(ui_scores)
+                batch_loss, batch_bpr_loss, batch_reg_loss = model.loss_fn(
+                    u_g_embeddings, pos_i_g_embeddings, neg_i_g_embeddings
                 )
-                accum_loss += loss.cpu().item()
+                accum_loss += batch_loss.cpu().item() * len(batch_u)
 
                 optimizer.zero_grad()
-                loss.backward()
+                batch_loss.backward()
                 optimizer.step()
 
+            accum_loss /= len(train_set.uir_tuple[0])  # normalize over all observations
             pbar.set_postfix(loss=accum_loss)
 
             # store user and item embedding matrices for prediction
             model.eval()
-            self.U, self.V = model(graph)
+            u_embs, i_embs, _ = model(graph)
+            # we will use numpy for faster prediction in the score function, no need torch
+            self.U = u_embs.cpu().detach().numpy()
+            self.V = i_embs.cpu().detach().numpy()
 
             if self.early_stopping is not None and self.early_stop(
                 **self.early_stopping
             ):
                 break
-
-        # we will use numpy for faster prediction in the score function, no need torch
-        self.U = self.U.cpu().detach().numpy()
-        self.V = self.V.cpu().detach().numpy()
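Why the new accounting multiplies by len(batch_u) before the final division (toy arithmetic, not from the diff): batch_loss is a per-example mean, so re-weighting by batch size yields an exact mean over all observations even when the last mini-batch is short:

batch_means = [0.5, 0.3]   # mean loss reported by two mini-batches
batch_sizes = [1024, 512]  # ragged final batch
epoch_mean = sum(m * n for m, n in zip(batch_means, batch_sizes)) / sum(batch_sizes)
print(epoch_mean)  # ~0.4333, not the naive average 0.4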

     def monitor_value(self):
         """Calculating monitored value used for early stopping on validation set (`val_set`).
         This function will be called by `early_stop()` function.
@@ -223,38 +203,17 @@ def monitor_value(self):
         if self.val_set is None:
             return None
 
-        import torch
+        from ...metrics import Recall
+        from ...eval_methods import ranking_eval
 
-        loss_fn = torch.nn.BCELoss(reduction="sum")
-        accum_loss = 0.0
-        pbar = tqdm(
-            self.val_set.uij_iter(batch_size=self.test_batch_size),
-            desc="Validation",
-            total=self.val_set.num_batches(self.test_batch_size),
-            leave=False,
-            position=1,
-            disable=not self.verbose,
-        )
-        for batch_u, batch_i, batch_j in pbar:
-            batch_u = torch.from_numpy(batch_u).long().to(self.device)
-            batch_i = torch.from_numpy(batch_i).long().to(self.device)
-            batch_j = torch.from_numpy(batch_j).long().to(self.device)
-
-            user_embed = self.U[batch_u]
-            positive_item_embed = self.V[batch_i]
-            negative_item_embed = self.V[batch_j]
-
-            ui_scores = (user_embed * positive_item_embed).sum(dim=1)
-            uj_scores = (user_embed * negative_item_embed).sum(dim=1)
-
-            loss = loss_fn(
-                torch.sigmoid(ui_scores - uj_scores), torch.ones_like(ui_scores)
-            )
-            accum_loss += loss.cpu().item()
-            pbar.set_postfix(val_loss=accum_loss)
-
-        accum_loss /= len(self.val_set.uir_tuple[0])
-        return -accum_loss  # higher is better -> smaller loss is better
+        recall_20 = ranking_eval(
+            model=self,
+            metrics=[Recall(k=20)],
+            train_set=self.train_set,
+            test_set=self.val_set
+        )[0][0]
+
+        return recall_20  # Section 4.1.2 in the paper, same strategy as NGCF.

     def score(self, user_idx, item_idx=None):
         """Predict the scores/ratings of a user for an item.