Optimize LightGCN Model (#531)
* Generated model base from LightGCN

* wip

* wip example

* add self-connection

* refactor code

* added sanity check

* Changed train batch size in example to 1024

* Updated readme for example folder

* Update Readme

* update docs

* Update block comment

* WIP

* Updated validation metric

* Updated message handling

* Added legacy lightgcn for comparison purposes

* Changed layer combination to follow 'a_k = 1/(k+1)', using layer index k instead of i (see the formula below)

* Changed early stopping technique to follow NGCF

* removed test_batch_size, set early-stopping verbose to False

* Changed parameters to align with the paper and NGCF

* refactor code

* update docstring

* change param name to 'batch_size'

* Fix paper reference

---------

Co-authored-by: tqtg <[email protected]>
Co-authored-by: Quoc-Tuan Truong <[email protected]>
3 people authored Oct 17, 2023
1 parent c484988 commit a470d5c
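
A note on the layer-combination change in the commit message: LightGCN forms the final representation as a weighted sum over propagation layers. The formula the commit message refers to is

$$\mathbf{e}_u = \sum_{k=0}^{K} a_k\, \mathbf{e}_u^{(k)}, \qquad a_k = \frac{1}{k+1},$$

where $\mathbf{e}_u^{(k)}$ is the user embedding after $k$ propagation layers and $K$ is num_layers; item embeddings are combined the same way. The original paper instead sets $a_k$ uniformly to $1/(K+1)$ (Section 3.1).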
Showing 3 changed files with 132 additions and 148 deletions.
159 changes: 92 additions & 67 deletions cornac/models/lightgcn/lightgcn.py
@@ -1,9 +1,14 @@
 import torch
 import torch.nn as nn
+import torch.nn.functional as F
 import dgl
 import dgl.function as fn
 
 
+USER_KEY = "user"
+ITEM_KEY = "item"
+
+
 def construct_graph(data_set):
     """
     Generates graph given a cornac data set
@@ -14,89 +19,109 @@ def construct_graph(data_set):
         The data set as provided by cornac
     """
     user_indices, item_indices, _ = data_set.uir_tuple
-    user_nodes, item_nodes = (
-        torch.from_numpy(user_indices),
-        torch.from_numpy(
-            item_indices + data_set.total_users
-        ),  # increment item node idx by num users
-    )
-
-    u = torch.cat([user_nodes, item_nodes], dim=0)
-    v = torch.cat([item_nodes, user_nodes], dim=0)
+    data_dict = {
+        (USER_KEY, "user_item", ITEM_KEY): (user_indices, item_indices),
+        (ITEM_KEY, "item_user", USER_KEY): (item_indices, user_indices),
+    }
+    num_dict = {USER_KEY: data_set.total_users, ITEM_KEY: data_set.total_items}
 
-    g = dgl.graph((u, v), num_nodes=(data_set.total_users + data_set.total_items))
-    return g
+    return dgl.heterograph(data_dict, num_nodes_dict=num_dict)
 
 
 class GCNLayer(nn.Module):
-    def __init__(self):
+    def __init__(self, norm_dict):
         super(GCNLayer, self).__init__()
-
-    def forward(self, graph, src_embedding, dst_embedding):
-        with graph.local_scope():
-            inner_product = torch.cat((src_embedding, dst_embedding), dim=0)
-
-            out_degs = graph.out_degrees().to(src_embedding.device).float().clamp(min=1)
-            norm_out_degs = torch.pow(out_degs, -0.5).view(-1, 1)  # D^-1/2
-
-            inner_product = inner_product * norm_out_degs
-
-            graph.ndata["h"] = inner_product
-            graph.update_all(
-                message_func=fn.copy_u("h", "m"), reduce_func=fn.sum("m", "h")
-            )
-
-            res = graph.ndata["h"]
-
-            in_degs = graph.in_degrees().to(src_embedding.device).float().clamp(min=1)
-            norm_in_degs = torch.pow(in_degs, -0.5).view(-1, 1)  # D^-1/2
-
-            res = res * norm_in_degs
-            return res
+        # symmetric normalization terms, precomputed per edge type
+        self.norm_dict = norm_dict
+
+    def forward(self, g, feat_dict):
+        funcs = {}  # message and reduce functions dict
+        # for each type of edges, compute messages and reduce them all
+        for srctype, etype, dsttype in g.canonical_etypes:
+            src, dst = g.edges(etype=(srctype, etype, dsttype))
+            norm = self.norm_dict[(srctype, etype, dsttype)]
+            # TODO: CHECK HERE
+            messages = norm * feat_dict[srctype][src]  # compute messages
+            g.edges[(srctype, etype, dsttype)].data[
+                etype
+            ] = messages  # store in edata
+            funcs[(srctype, etype, dsttype)] = (
+                fn.copy_e(etype, "m"),
+                fn.sum("m", "h"),
+            )  # define message and reduce functions
+
+        g.multi_update_all(
+            funcs, "sum"
+        )  # update all: reduce within each edge type first, then across types
+        feature_dict = {}
+        for ntype in g.ntypes:
+            h = F.normalize(g.nodes[ntype].data["h"], dim=1, p=2)  # l2 normalize
+            feature_dict[ntype] = h
+        return feature_dict
 
 
 class Model(nn.Module):
-    def __init__(self, user_size, item_size, hidden_size, num_layers=3, device=None):
+    def __init__(self, g, in_size, num_layers, lambda_reg, device=None):
         super(Model, self).__init__()
-        self.user_size = user_size
-        self.item_size = item_size
-        self.hidden_size = hidden_size
-        self.embedding_weights = self._init_weights()
-        self.layers = nn.ModuleList([GCNLayer() for _ in range(num_layers)])
+        self.norm_dict = dict()
+        self.lambda_reg = lambda_reg
         self.device = device
 
-    def forward(self, graph):
-        user_embedding = self.embedding_weights["user_embedding"]
-        item_embedding = self.embedding_weights["item_embedding"]
+        for srctype, etype, dsttype in g.canonical_etypes:
+            src, dst = g.edges(etype=(srctype, etype, dsttype))
+            dst_degree = g.in_degrees(
+                dst, etype=(srctype, etype, dsttype)
+            ).float()  # obtain degrees
+            src_degree = g.out_degrees(src, etype=(srctype, etype, dsttype)).float()
+            norm = torch.pow(src_degree * dst_degree, -0.5).unsqueeze(1)  # compute norm
+            self.norm_dict[(srctype, etype, dsttype)] = norm
 
-        for i, layer in enumerate(self.layers, start=1):
-            if i == 1:
-                embeddings = layer(graph, user_embedding, item_embedding)
-            else:
-                embeddings = layer(
-                    graph, embeddings[: self.user_size], embeddings[self.user_size:]
-                )
+        self.layers = nn.ModuleList(
+            [GCNLayer(self.norm_dict) for _ in range(num_layers)]
+        )
 
-            user_embedding = user_embedding + embeddings[: self.user_size] * (
-                1 / (i + 1)
-            )
-            item_embedding = item_embedding + embeddings[self.user_size:] * (
-                1 / (i + 1)
-            )
-
-        return user_embedding, item_embedding
+        self.initializer = nn.init.xavier_uniform_
 
-    def _init_weights(self):
-        initializer = nn.init.xavier_uniform_
-
-        weights_dict = nn.ParameterDict(
+        # embeddings for different types of nodes
+        self.feature_dict = nn.ParameterDict(
             {
-                "user_embedding": nn.Parameter(
-                    initializer(torch.empty(self.user_size, self.hidden_size))
-                ),
-                "item_embedding": nn.Parameter(
-                    initializer(torch.empty(self.item_size, self.hidden_size))
-                ),
+                ntype: nn.Parameter(
+                    self.initializer(torch.empty(g.num_nodes(ntype), in_size))
+                )
+                for ntype in g.ntypes
             }
        )
-        return weights_dict
+
+    def forward(self, g, users=None, pos_items=None, neg_items=None):
+        h_dict = {ntype: self.feature_dict[ntype] for ntype in g.ntypes}
+        # accumulate each layer's features with weight 1/(k+1)
+        user_embeds = h_dict[USER_KEY]
+        item_embeds = h_dict[ITEM_KEY]
+
+        for k, layer in enumerate(self.layers):
+            h_dict = layer(g, h_dict)
+            user_embeds = user_embeds + (h_dict[USER_KEY] * 1 / (k + 1))
+            item_embeds = item_embeds + (h_dict[ITEM_KEY] * 1 / (k + 1))
+
+        u_g_embeddings = user_embeds if users is None else user_embeds[users, :]
+        pos_i_g_embeddings = item_embeds if pos_items is None else item_embeds[pos_items, :]
+        neg_i_g_embeddings = item_embeds if neg_items is None else item_embeds[neg_items, :]
+
+        return u_g_embeddings, pos_i_g_embeddings, neg_i_g_embeddings
+
+    def loss_fn(self, users, pos_items, neg_items):
+        pos_scores = (users * pos_items).sum(1)
+        neg_scores = (users * neg_items).sum(1)
+
+        bpr_loss = F.softplus(neg_scores - pos_scores).mean()
+        reg_loss = (
+            (1 / 2)
+            * (
+                torch.norm(users) ** 2
+                + torch.norm(pos_items) ** 2
+                + torch.norm(neg_items) ** 2
+            )
+            / len(users)
+        )
+
+        return bpr_loss + self.lambda_reg * reg_loss, bpr_loss, reg_loss
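
The loss_fn added above is the pairwise BPR objective used by LightGCN plus L2 regularization. Written out, with $\hat{y}_{ui} = \mathbf{e}_u^\top \mathbf{e}_i$ and a batch $\mathcal{B}$ of (user, positive item, negative item) triplets:

$$\mathcal{L} = \frac{1}{\lvert \mathcal{B} \rvert} \sum_{(u,i,j) \in \mathcal{B}} -\ln \sigma(\hat{y}_{ui} - \hat{y}_{uj}) \;+\; \frac{\lambda}{2 \lvert \mathcal{B} \rvert} \left( \lVert \mathbf{e}_u \rVert^2 + \lVert \mathbf{e}_i \rVert^2 + \lVert \mathbf{e}_j \rVert^2 \right),$$

using the identity softplus(x) = ln(1 + e^x) = -ln σ(-x), so F.softplus(neg_scores - pos_scores) equals -ln σ(pos_scores - neg_scores). Note that regularization now enters through lambda_reg inside the loss rather than through the optimizer's weight_decay, as the diff below shows.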
115 changes: 37 additions & 78 deletions cornac/models/lightgcn/recom_lightgcn.py
@@ -28,21 +28,18 @@ class LightGCN(Recommender):
     name: string, default: 'LightGCN'
         The name of the recommender model.
 
+    emb_size: int, default: 64
+        Size of the node embeddings.
+
     num_epochs: int, default: 1000
-        Maximum number of iterations or the number of epochs
+        Maximum number of iterations or the number of epochs.
 
     learning_rate: float, default: 0.001
         The learning rate that determines the step size at each iteration
 
-    train_batch_size: int, default: 1024
+    batch_size: int, default: 1024
         Mini-batch size used for train set
 
-    test_batch_size: int, default: 100
-        Mini-batch size used for test set
-
-    hidden_dim: int, default: 64
-        The embedding size of the model
-
     num_layers: int, default: 3
         Number of LightGCN Layers
@@ -80,11 +77,10 @@ class LightGCN(Recommender):
     def __init__(
         self,
         name="LightGCN",
+        emb_size=64,
         num_epochs=1000,
         learning_rate=0.001,
-        train_batch_size=1024,
-        test_batch_size=100,
-        hidden_dim=64,
+        batch_size=1024,
         num_layers=3,
         early_stopping=None,
         lambda_reg=1e-4,
@@ -93,13 +89,11 @@ def __init__(
         seed=2020,
     ):
         super().__init__(name=name, trainable=trainable, verbose=verbose)
-
+        self.emb_size = emb_size
         self.num_epochs = num_epochs
         self.learning_rate = learning_rate
-        self.hidden_dim = hidden_dim
+        self.batch_size = batch_size
         self.num_layers = num_layers
-        self.train_batch_size = train_batch_size
-        self.test_batch_size = test_batch_size
         self.early_stopping = early_stopping
         self.lambda_reg = lambda_reg
         self.seed = seed
@@ -135,19 +129,15 @@ def fit(self, train_set, val_set=None):
         if torch.cuda.is_available():
             torch.cuda.manual_seed_all(self.seed)
 
+        graph = construct_graph(train_set).to(self.device)
         model = Model(
-            train_set.total_users,
-            train_set.total_items,
-            self.hidden_dim,
+            graph,
+            self.emb_size,
             self.num_layers,
+            self.lambda_reg,
         ).to(self.device)
 
-        graph = construct_graph(train_set).to(self.device)
-
-        optimizer = torch.optim.Adam(
-            model.parameters(), lr=self.learning_rate, weight_decay=self.lambda_reg
-        )
-        loss_fn = torch.nn.BCELoss(reduction="sum")
+        optimizer = torch.optim.Adam(model.parameters(), lr=self.learning_rate)
 
         # model training
         pbar = trange(
@@ -163,53 +153,43 @@
             accum_loss = 0.0
             for batch_u, batch_i, batch_j in tqdm(
                 train_set.uij_iter(
-                    batch_size=self.train_batch_size,
+                    batch_size=self.batch_size,
                     shuffle=True,
                 ),
                 desc="Epoch",
-                total=train_set.num_batches(self.train_batch_size),
+                total=train_set.num_batches(self.batch_size),
                 leave=False,
                 position=1,
                 disable=not self.verbose,
             ):
-                user_embeddings, item_embeddings = model(graph)
-
                 batch_u = torch.from_numpy(batch_u).long().to(self.device)
                 batch_i = torch.from_numpy(batch_i).long().to(self.device)
                 batch_j = torch.from_numpy(batch_j).long().to(self.device)
 
-                user_embed = user_embeddings[batch_u]
-                positive_item_embed = item_embeddings[batch_i]
-                negative_item_embed = item_embeddings[batch_j]
-
-                ui_scores = (user_embed * positive_item_embed).sum(dim=1)
-                uj_scores = (user_embed * negative_item_embed).sum(dim=1)
+                u_g_embeddings, pos_i_g_embeddings, neg_i_g_embeddings = model(
+                    graph, batch_u, batch_i, batch_j
+                )
 
-                loss = loss_fn(
-                    torch.sigmoid(ui_scores - uj_scores), torch.ones_like(ui_scores)
+                batch_loss, batch_bpr_loss, batch_reg_loss = model.loss_fn(
+                    u_g_embeddings, pos_i_g_embeddings, neg_i_g_embeddings
                 )
-                accum_loss += loss.cpu().item()
+                accum_loss += batch_loss.cpu().item() * len(batch_u)
 
                 optimizer.zero_grad()
-                loss.backward()
+                batch_loss.backward()
                 optimizer.step()
 
+            accum_loss /= len(train_set.uir_tuple[0])  # normalize over all observations
             pbar.set_postfix(loss=accum_loss)
 
             # store user and item embedding matrices for prediction
             model.eval()
-            self.U, self.V = model(graph)
+            u_embs, i_embs, _ = model(graph)
+            # we will use numpy for faster prediction in the score function, no need torch
+            self.U = u_embs.cpu().detach().numpy()
+            self.V = i_embs.cpu().detach().numpy()
 
             if self.early_stopping is not None and self.early_stop(
                 **self.early_stopping
             ):
                 break
-
-        # we will use numpy for faster prediction in the score function, no need torch
-        self.U = self.U.cpu().detach().numpy()
-        self.V = self.V.cpu().detach().numpy()
 
     def monitor_value(self):
         """Calculating monitored value used for early stopping on validation set (`val_set`).
         This function will be called by `early_stop()` function.
@@ -223,38 +203,17 @@ def monitor_value(self):
         if self.val_set is None:
             return None
 
-        import torch
+        from ...metrics import Recall
+        from ...eval_methods import ranking_eval
 
-        loss_fn = torch.nn.BCELoss(reduction="sum")
-        accum_loss = 0.0
-        pbar = tqdm(
-            self.val_set.uij_iter(batch_size=self.test_batch_size),
-            desc="Validation",
-            total=self.val_set.num_batches(self.test_batch_size),
-            leave=False,
-            position=1,
-            disable=not self.verbose,
-        )
-        for batch_u, batch_i, batch_j in pbar:
-            batch_u = torch.from_numpy(batch_u).long().to(self.device)
-            batch_i = torch.from_numpy(batch_i).long().to(self.device)
-            batch_j = torch.from_numpy(batch_j).long().to(self.device)
-
-            user_embed = self.U[batch_u]
-            positive_item_embed = self.V[batch_i]
-            negative_item_embed = self.V[batch_j]
-
-            ui_scores = (user_embed * positive_item_embed).sum(dim=1)
-            uj_scores = (user_embed * negative_item_embed).sum(dim=1)
-
-            loss = loss_fn(
-                torch.sigmoid(ui_scores - uj_scores), torch.ones_like(ui_scores)
-            )
-            accum_loss += loss.cpu().item()
-            pbar.set_postfix(val_loss=accum_loss)
-
-        accum_loss /= len(self.val_set.uir_tuple[0])
-        return -accum_loss  # higher is better -> smaller loss is better
+        recall_20 = ranking_eval(
+            model=self,
+            metrics=[Recall(k=20)],
+            train_set=self.train_set,
+            test_set=self.val_set
+        )[0][0]
+
+        return recall_20  # Section 4.1.2 in the paper, same strategy as NGCF.
 
     def score(self, user_idx, item_idx=None):
         """Predict the scores/ratings of a user for an item.
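
A note on the new early-stopping signal in monitor_value above: validation quality is now measured by Recall@20 (higher is better), mirroring NGCF's protocol, instead of the previous validation BPR-style loss. For reference, for a user $u$ with held-out relevant items $\mathrm{Rel}_u$:

$$\mathrm{Recall@20}(u) = \frac{\lvert \mathrm{Top\text{-}20}(u) \cap \mathrm{Rel}_u \rvert}{\lvert \mathrm{Rel}_u \rvert},$$

and ranking_eval returns the average over validation users as its first result.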
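For reference, here is a minimal usage sketch against the refactored interface. Parameter names and defaults come from the diff above; the dataset choice, split sizes, and early-stopping thresholds are illustrative assumptions, not part of this commit.

import cornac
from cornac.datasets import movielens
from cornac.eval_methods import RatioSplit
from cornac.metrics import Recall
from cornac.models import LightGCN

# Any UIR feedback data works; MovieLens 100K is a stand-in here.
data = movielens.load_feedback(variant="100K")

# A validation split is needed for the Recall@20-based early stopping.
ratio_split = RatioSplit(data=data, test_size=0.2, val_size=0.1, seed=123)

lightgcn = LightGCN(
    emb_size=64,      # replaces the old hidden_dim
    num_epochs=1000,
    learning_rate=0.001,
    batch_size=1024,  # replaces train_batch_size; test_batch_size is gone
    num_layers=3,
    early_stopping={"min_delta": 1e-4, "patience": 50},  # assumed thresholds
    lambda_reg=1e-4,
    seed=2020,
)

cornac.Experiment(
    eval_method=ratio_split,
    models=[lightgcn],
    metrics=[Recall(k=20)],
).run()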