From e5d28bd0743eb06352de0073fae046bbea7165c0 Mon Sep 17 00:00:00 2001
From: lthoang
Date: Sun, 14 Jan 2024 22:50:46 +0800
Subject: [PATCH] Add DNNTSP model

---
 README.md                             |   1 +
 cornac/models/__init__.py             |   1 +
 cornac/models/dnntsp/__init__.py      |  16 +
 cornac/models/dnntsp/dnntsp.py        | 707 ++++++++++++++++++++++++++
 cornac/models/dnntsp/recom_dnntsp.py  | 130 +++++
 cornac/models/dnntsp/requirements.txt |   2 +
 docs/source/api_ref/models.rst        |   5 +
 examples/README.md                    |   2 +
 examples/dnntsp_tafeng.py             |  55 ++
 9 files changed, 919 insertions(+)
 create mode 100644 cornac/models/dnntsp/__init__.py
 create mode 100644 cornac/models/dnntsp/dnntsp.py
 create mode 100644 cornac/models/dnntsp/recom_dnntsp.py
 create mode 100644 cornac/models/dnntsp/requirements.txt
 create mode 100644 examples/dnntsp_tafeng.py

diff --git a/README.md b/README.md
index 32c1d4581..b486fb62d 100644
--- a/README.md
+++ b/README.md
@@ -153,6 +153,7 @@ The recommender models supported by Cornac are listed below. Why don't you join
 |   | [Hybrid neural recommendation with joint deep representation learning of ratings and reviews (HRDR)](cornac/models/hrdr), [paper](https://www.sciencedirect.com/science/article/abs/pii/S0925231219313207) | [requirements.txt](cornac/models/hrdr/requirements.txt) | [hrdr_example.py](examples/hrdr_example.py)
 |   | [LightGCN: Simplifying and Powering Graph Convolution Network for Recommendation](cornac/models/lightgcn), [paper](https://arxiv.org/pdf/2002.02126.pdf) | [requirements.txt](cornac/models/lightgcn/requirements.txt) | [lightgcn_example.py](examples/lightgcn_example.py)
 |   | [New Variational Autoencoder for Top-N Recommendations with Implicit Feedback (RecVAE)](cornac/models/recvae), [paper](https://doi.org/10.1145/3336191.3371831) | [requirements.txt](cornac/models/recvae/requirements.txt) | [recvae_example.py](examples/recvae_example.py)
+|   | [Predicting Temporal Sets with Deep Neural Networks (DNNTSP)](cornac/models/dnntsp), [paper](https://arxiv.org/pdf/2006.11483.pdf) | [requirements.txt](cornac/models/dnntsp/requirements.txt) | [dnntsp_tafeng.py](examples/dnntsp_tafeng.py)
 |   | [Temporal-Item-Frequency-based User-KNN (TIFUKNN)](cornac/models/tifuknn), [paper](https://arxiv.org/pdf/2006.00556.pdf) | N/A | [tifuknn_tafeng.py](examples/tifuknn_tafeng.py)
 | 2019 | [Embarrassingly Shallow Autoencoders for Sparse Data (EASEᴿ)](cornac/models/ease), [paper](https://arxiv.org/pdf/1905.03375.pdf) | N/A | [ease_movielens.py](examples/ease_movielens.py)
 |   | [Neural Graph Collaborative Filtering (NGCF)](cornac/models/ngcf), [paper](https://arxiv.org/pdf/1905.08108.pdf) | [requirements.txt](cornac/models/ngcf/requirements.txt) | [ngcf_example.py](examples/ngcf_example.py)
diff --git a/cornac/models/__init__.py b/cornac/models/__init__.py
index cde6fcb37..a3e3db60c 100644
--- a/cornac/models/__init__.py
+++ b/cornac/models/__init__.py
@@ -37,6 +37,7 @@ from .ctr import CTR
 from .cvae import CVAE
 from .cvaecf import CVAECF
+from .dnntsp import DNNTSP
 from .ease import EASE
 from .efm import EFM
 from .fm import FM
diff --git a/cornac/models/dnntsp/__init__.py b/cornac/models/dnntsp/__init__.py
new file mode 100644
index 000000000..4c3e9eea0
--- /dev/null
+++ b/cornac/models/dnntsp/__init__.py
@@ -0,0 +1,16 @@
+# Copyright 2023 The Cornac Authors. All Rights Reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#     http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+# ============================================================================
+
+from .recom_dnntsp import DNNTSP
diff --git a/cornac/models/dnntsp/dnntsp.py b/cornac/models/dnntsp/dnntsp.py
new file mode 100644
index 000000000..0e78d4ae1
--- /dev/null
+++ b/cornac/models/dnntsp/dnntsp.py
@@ -0,0 +1,707 @@
+import itertools
+import random
+from collections import defaultdict
+from typing import List
+
+import dgl
+import dgl.function as fn
+import numpy as np
+import torch
+import torch.nn as nn
+from tqdm.auto import trange
+
+from ..mf.backend_pt import OPTIMIZER_DICT
+
+
+def get_optimizer(model, lr=0.001, weight_decay=0, momentum=0.9, optimizer="adam"):
+    if optimizer == "adam":
+        # Adam has no momentum argument; momentum is only used by SGD
+        return torch.optim.Adam(
+            model.parameters(),
+            lr=lr,
+            weight_decay=weight_decay,
+        )
+    elif optimizer == "sgd":
+        return torch.optim.SGD(
+            model.parameters(),
+            lr=lr,
+            momentum=momentum,
+        )
+    else:
+        raise NotImplementedError(f"unsupported optimizer: {optimizer}")
+
+
+def scheduler_fn(optimizer):
+    return torch.optim.lr_scheduler.ReduceLROnPlateau(optimizer)
+
+
+class MaskedSelfAttention(nn.Module):
+    def __init__(self, input_dim, output_dim, n_heads=4, attention_aggregate="concat"):
+        super(MaskedSelfAttention, self).__init__()
+        # aggregate multi-heads by concatenation or mean
+        self.attention_aggregate = attention_aggregate
+
+        self.input_dim = input_dim
+        self.output_dim = output_dim
+
+        self.n_heads = n_heads
+
+        # for "concat", each head works in an output_dim // n_heads subspace
+        if attention_aggregate == "concat":
+            self.per_head_dim = self.dq = self.dk = self.dv = output_dim // n_heads
+        elif attention_aggregate == "mean":
+            self.per_head_dim = self.dq = self.dk = self.dv = output_dim
+        else:
+            raise ValueError(f"wrong value for aggregate {attention_aggregate}")
+
+        self.Wq = nn.Linear(input_dim, n_heads * self.dq, bias=False)
+        self.Wk = nn.Linear(input_dim, n_heads * self.dk, bias=False)
+        self.Wv = nn.Linear(input_dim, n_heads * self.dv, bias=False)
+
+    def forward(self, input_tensor):
+        """
+        Args:
+            input_tensor: tensor, shape (nodes_num, T_max, features_num)
+        Returns:
+            output: tensor, shape (nodes_num, T_max, output_dim = features_num)
+        """
+        seq_length = input_tensor.shape[1]
+        # tensor, shape (nodes_num, T_max, n_heads * dim_per_head)
+        Q = self.Wq(input_tensor)
+        K = self.Wk(input_tensor)
+        V = self.Wv(input_tensor)
+        # multi-head attention
+        # Q, tensor, shape (nodes_num, n_heads, T_max, dim_per_head)
+        Q = Q.reshape(
+            input_tensor.shape[0], input_tensor.shape[1], self.n_heads, self.dq
+        ).transpose(1, 2)
+        # K after transpose, tensor, shape (nodes_num, n_heads, dim_per_head, T_max)
+        K = K.reshape(
+            input_tensor.shape[0], input_tensor.shape[1], self.n_heads, self.dk
+        ).permute(0, 2, 3, 1)
+        # V, tensor, shape (nodes_num, n_heads, T_max, dim_per_head)
+        V = V.reshape(
+            input_tensor.shape[0], input_tensor.shape[1], self.n_heads, self.dv
+        ).transpose(1, 2)
+
+        # scaled attention_score, tensor, shape (nodes_num, n_heads, T_max, T_max)
+        attention_score = Q.matmul(K) / np.sqrt(self.per_head_dim)
+
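+        # Causal mask: torch.tril keeps the lower triangle, so every entry
+        # above the diagonal is set to -inf and receives zero weight after the
+        # softmax -- each timestamp attends only to itself and earlier ones.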
+        # attention_mask, tensor, shape (T_max, T_max)
+        attention_mask = (
+            torch.zeros(seq_length, seq_length)
+            .masked_fill(torch.tril(torch.ones(seq_length, seq_length)) == 0, -np.inf)
+            .to(input_tensor.device)
+        )
+        # attention_mask will be broadcast to (nodes_num, n_heads, T_max, T_max)
+        attention_score = attention_score + attention_mask
+        # (nodes_num, n_heads, T_max, T_max)
+        attention_score = torch.softmax(attention_score, dim=-1)
+
+        # multi_head_result, tensor, shape (nodes_num, n_heads, T_max, dim_per_head)
+        multi_head_result = attention_score.matmul(V)
+        if self.attention_aggregate == "concat":
+            # concat multi-head attention results,
+            # shape (nodes_num, T_max, n_heads * dim_per_head = output_dim)
+            output = multi_head_result.transpose(1, 2).reshape(
+                input_tensor.shape[0], seq_length, self.n_heads * self.per_head_dim
+            )
+        elif self.attention_aggregate == "mean":
+            # average multi-head attention results,
+            # shape (nodes_num, T_max, dim_per_head = output_dim)
+            output = multi_head_result.transpose(1, 2).mean(dim=2)
+        else:
+            raise ValueError(f"wrong value for aggregate {self.attention_aggregate}")
+
+        return output
+
+
+class GlobalGatedUpdate(nn.Module):
+    def __init__(self, n_items, embedding_matrix):
+        super(GlobalGatedUpdate, self).__init__()
+        self.n_items = n_items
+        self.embedding_matrix = embedding_matrix
+
+        # alpha -> the learned per-item gate for updating
+        self.alpha = nn.Parameter(torch.rand(n_items, 1), requires_grad=True)
+
+    def forward(self, graph, nodes, nodes_output):
+        """
+        :param graph: batched graphs, with the total number of nodes `node_num`,
+            including `batch_size` disconnected subgraphs
+        :param nodes: tensor (n_1+n_2+..., )
+        :param nodes_output: output of the self-attention model in the time dimension, (n_1+n_2+..., F)
+        :return: (batch_size, n_items, item_embed_dim)
+        """
+        nums_nodes, idx = graph.batch_num_nodes(), 0
+        items_embedding = self.embedding_matrix(
+            torch.arange(self.n_items, device=nodes.device)
+        )
+        batch_embedding = []
+        for num_nodes in nums_nodes:
+            # tensor, shape (user_nodes, item_embed_dim)
+            output_node_features = nodes_output[idx : idx + num_nodes, :]
+            # get each user's nodes
+            output_nodes = nodes[idx : idx + num_nodes]
+            # beta, tensor, (n_items, 1), indicator vector: 1 if the item appears, else 0
+            beta = torch.zeros(self.n_items, 1).to(nodes.device)
+            beta[output_nodes] = 1
+            # update the global embedding by the gated mechanism
+            # broadcast (n_items, 1) * (n_items, item_embed_dim) -> (n_items, item_embed_dim)
+            embed = (1 - beta * self.alpha) * items_embedding.clone()
+            # appearing items: (1 - alpha) * origin + alpha * update; absent items: origin
+            embed[output_nodes, :] = (
+                embed[output_nodes, :] + self.alpha[output_nodes] * output_node_features
+            )
+            batch_embedding.append(embed)
+            idx += num_nodes
+        # (B, n_items, item_embed_dim)
+        batch_embedding = torch.stack(batch_embedding)
+        return batch_embedding
+
+
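+# The next module collapses each node's (T_max, F) sequence into a single
+# F-dimensional vector: a learned linear layer scores every timestamp and the
+# scores serve as weights in a weighted sum over time.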
+class AggregateNodesTemporalFeature(nn.Module):
+    def __init__(self, item_embed_dim):
+        """
+        :param item_embed_dim: the dimension of input features
+        """
+        super(AggregateNodesTemporalFeature, self).__init__()
+
+        self.Wq = nn.Linear(item_embed_dim, 1, bias=False)
+
+    def forward(self, graph, lengths, nodes_output):
+        """
+        :param graph: batched graphs, with the total number of nodes `node_num`,
+            including `batch_size` disconnected subgraphs
+        :param lengths: tensor, (batch_size, )
+        :param nodes_output: output of the self-attention model in the time dimension, (n_1+n_2+..., T_max, F)
+        :return: aggregated_features, (n_1+n_2+..., F)
+        """
+        nums_nodes, idx = graph.batch_num_nodes(), 0
+        aggregated_features = []
+        for num_nodes, length in zip(nums_nodes, lengths):
+            # get each user's sequence, tensor, shape (user_nodes, user_length, item_embed_dim)
+            output_node_features = nodes_output[idx : idx + num_nodes, :length, :]
+            # weights for each timestamp, tensor, shape (user_nodes, 1, user_length):
+            # (user_nodes, user_length, 1) transposed to (user_nodes, 1, user_length)
+            weights = self.Wq(output_node_features).transpose(1, 2)
+            # (user_nodes, 1, user_length) matmul (user_nodes, user_length, item_embed_dim)
+            # -> (user_nodes, 1, item_embed_dim), squeezed to (user_nodes, item_embed_dim)
+            aggregated_feature = weights.matmul(output_node_features).squeeze(dim=1)
+            aggregated_features.append(aggregated_feature)
+            idx += num_nodes
+        # (n_1+n_2+..., item_embed_dim)
+        aggregated_features = torch.cat(aggregated_features, dim=0)
+        return aggregated_features
+
+
+class WeightedGraphConv(nn.Module):
+    """
+    Apply graph convolution over an input signal.
+    """
+
+    def __init__(self, in_features: int, out_features: int):
+        super(WeightedGraphConv, self).__init__()
+        self.in_features = in_features
+        self.out_features = out_features
+        self.linear = nn.Linear(in_features, out_features, bias=True)
+
+    def forward(self, graph, node_features, edge_weights):
+        """Compute weighted graph convolution.
+        -----
+        Input:
+        graph : DGLGraph, batched graph.
+        node_features : torch.Tensor, input features for nodes,
+            (n_1+n_2+..., in_features) or (n_1+n_2+..., T, in_features)
+        edge_weights : torch.Tensor, input weights for edges, (T, n_1^2+n_2^2+...)
+
+        Output:
+        shape: (N, T, out_features)
+        """
+        graph = graph.local_var()
+        # aggregate weighted neighbor features first, then project them with
+        # the linear layer (bias included)
+        # (N, F) / (N, T, F)
+        graph.ndata["n"] = node_features
+        # edge_weights, shape (T, N^2)
+        # one way: using dgl.function is faster and needs less GPU memory
+        graph.edata["e"] = edge_weights.t().unsqueeze(dim=-1)  # (E, T, 1)
+        graph.update_all(fn.u_mul_e("n", "e", "msg"), fn.sum("msg", "h"))
+
+        # another way: user-defined functions, which need more GPU memory
+        # graph.edata['e'] = edge_weights.t()
+        # graph.update_all(self.gcn_message, self.gcn_reduce)
+
+        node_features = graph.ndata.pop("h")
+        output = self.linear(node_features)
+        return output
+
+    @staticmethod
+    def gcn_message(edges):
+        if edges.src["n"].dim() == 2:
+            # (E, T, 1) matmul (E, 1, F) -> (E, T, F)
+            return {
+                "msg": torch.matmul(
+                    edges.data["e"].unsqueeze(dim=-1), edges.src["n"].unsqueeze(dim=1)
+                )
+            }
+
+        elif edges.src["n"].dim() == 3:
+            # (E, T, 1) mul (E, T, F) -> (E, T, F)
+            return {"msg": torch.mul(edges.data["e"].unsqueeze(dim=-1), edges.src["n"])}
+
+        else:
+            raise ValueError(
+                f"wrong shape for edges.src['n'], the length of shape is {edges.src['n'].dim()}"
+            )
+
+    @staticmethod
+    def gcn_reduce(nodes):
+        # propagate; the first dimension is the number of nodes in a batch
+        # h, tensor, shape (N, neighbors, T, F) -> (N, T, F)
+        return {"h": torch.sum(nodes.mailbox["msg"], 1)}
+
+
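+# Shape sketch (hypothetical sizes): for a user graph with 5 nodes (25 edges,
+# self-loops included) and T=3 baskets, WeightedGraphConv maps node features
+# of shape (5, F_in) -- or (5, 3, F_in) after the first layer -- together with
+# edge weights of shape (3, 25) to an output of shape (5, 3, F_out).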
+class WeightedGCN(nn.Module):
+    def __init__(self, in_features: int, hidden_sizes: List[int], out_features: int):
+        super(WeightedGCN, self).__init__()
+        gcns, relus, bns = nn.ModuleList(), nn.ModuleList(), nn.ModuleList()
+        # hidden layers
+        input_size = in_features
+        for hidden_size in hidden_sizes:
+            gcns.append(WeightedGraphConv(input_size, hidden_size))
+            relus.append(nn.ReLU())
+            bns.append(nn.BatchNorm1d(hidden_size))
+            input_size = hidden_size
+        # output layer
+        gcns.append(WeightedGraphConv(hidden_sizes[-1], out_features))
+        relus.append(nn.ReLU())
+        bns.append(nn.BatchNorm1d(out_features))
+        self.gcns, self.relus, self.bns = gcns, relus, bns
+
+    def forward(
+        self,
+        graph: dgl.DGLGraph,
+        node_features: torch.Tensor,
+        edges_weight: torch.Tensor,
+    ):
+        """
+        :param graph: a graph
+        :param node_features: shape (n_1+n_2+..., n_features)
+        :param edges_weight: shape (T, n_1^2+n_2^2+...)
+        :return: (n_1+n_2+..., T, out_features)
+        """
+        h = node_features
+        for gcn, relu, bn in zip(self.gcns, self.relus, self.bns):
+            # (n_1+n_2+..., T, features)
+            h = gcn(graph, h, edges_weight)
+            # BatchNorm1d expects the feature dimension at position 1,
+            # hence the transpose before and after
+            h = bn(h.transpose(1, -1)).transpose(1, -1)
+            h = relu(h)
+        return h
+
+
+class StackedWeightedGCNBlocks(nn.ModuleList):
+    def __init__(self, *args, **kwargs):
+        super(StackedWeightedGCNBlocks, self).__init__(*args, **kwargs)
+
+    def forward(self, *inputs):
+        g, nodes_feature, edge_weights = inputs
+        h = nodes_feature
+        for module in self:
+            h = module(g, h, edge_weights)
+        return h
+
+
+class TemporalSetPrediction(nn.Module):
+    def __init__(self, n_items, emb_dim):
+        """
+        :param n_items: int, total number of items
+        :param emb_dim: int, dimension of the item embeddings
+        """
+        super(TemporalSetPrediction, self).__init__()
+
+        self.embedding_matrix = nn.Embedding(n_items, emb_dim)
+
+        self.emb_dim = emb_dim
+        self.n_items = n_items
+        self.stacked_gcn = StackedWeightedGCNBlocks(
+            [WeightedGCN(emb_dim, [emb_dim], emb_dim)]
+        )
+
+        self.masked_self_attention = MaskedSelfAttention(
+            input_dim=emb_dim, output_dim=emb_dim
+        )
+
+        self.aggregate_nodes_temporal_feature = AggregateNodesTemporalFeature(
+            item_embed_dim=emb_dim
+        )
+
+        self.global_gated_update = GlobalGatedUpdate(
+            n_items=n_items, embedding_matrix=self.embedding_matrix
+        )
+
+        self.fc_output = nn.Linear(emb_dim, 1)
+
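+    # forward() chains four stages per user: a weighted GCN over the item
+    # co-occurrence graph, masked self-attention across baskets, attention-based
+    # temporal aggregation, and a gated update of the global item embeddings,
+    # before a final linear layer scores every item.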
+    def forward(
+        self,
+        batch_graph,
+        batch_nodes_feature,
+        batch_edges_weight,
+        batch_lengths,
+        batch_nodes,
+    ):
+        """
+        :param batch_graph: list of user graphs, each fully connected over the user's items
+        :param batch_nodes_feature: list of node features, [n_1+n_2+..., F]
+        :param batch_edges_weight: list of edge weights, [T_max, n_1^2+n_2^2+...]
+        :param batch_lengths: list of basket-sequence lengths, [batch_size, ]
+        :param batch_nodes: list of node ids, [n_1+n_2+..., ]
+        :return: scores, (batch_size, n_items)
+        """
+        # perform weighted gcn on dynamic graphs (n_1+n_2+..., T_max, item_embed_dim)
+        batch_nodes_output = [
+            self.stacked_gcn(graph, nodes_feature, edges_weight)
+            for graph, nodes_feature, edges_weight in zip(
+                batch_graph, batch_nodes_feature, batch_edges_weight
+            )
+        ]
+
+        # self-attention in the time dimension, (n_1+n_2+..., T_max, item_embed_dim)
+        batch_nodes_output = [
+            self.masked_self_attention(nodes_output)
+            for nodes_output in batch_nodes_output
+        ]
+
+        # aggregate node features in the temporal dimension, (n_1+n_2+..., item_embed_dim)
+        batch_nodes_output = [
+            self.aggregate_nodes_temporal_feature(graph, lengths, nodes_output)
+            for graph, lengths, nodes_output in zip(
+                batch_graph, batch_lengths, batch_nodes_output
+            )
+        ]
+
+        # (batch_size, n_items, item_embed_dim)
+        batch_nodes_output = [
+            self.global_gated_update(graph, nodes, nodes_output)
+            for graph, nodes, nodes_output in zip(
+                batch_graph, batch_nodes, batch_nodes_output
+            )
+        ]
+
+        # stack per-user scores into a single (batch_size, n_items) tensor so
+        # that every supported loss can consume predictions directly
+        outputs = torch.stack(
+            [
+                self.fc_output(nodes_output).squeeze(dim=-1).squeeze(dim=0)
+                for nodes_output in batch_nodes_output
+            ]
+        )
+
+        return outputs
+
+
+def get_edges_weight(history_baskets):
+    edges_weight_dict = defaultdict(float)
+    for basket_items in history_baskets:
+        for i in range(len(basket_items)):
+            for j in range(i + 1, len(basket_items)):
+                edges_weight_dict[(basket_items[i], basket_items[j])] += 1
+                edges_weight_dict[(basket_items[j], basket_items[i])] += 1
+    return edges_weight_dict
+
+
+def transform_data(bi_batch, item_embedding, total_items, device=torch.device("cpu"), is_test=False):
+    if is_test:
+        batch_history_items = [
+            [np.unique(basket).tolist() for basket in basket_items]
+            for basket_items in bi_batch
+        ]
+    else:
+        # the last basket of each sequence is held out as the prediction target
+        batch_history_items = [
+            [np.unique(basket).tolist() for basket in basket_items[:-1]]
+            for basket_items in bi_batch
+        ]
+    # number of baskets (timestamps) per user, wrapped in a singleton list to
+    # match the per-user graphs, each of which is a "batch" of one graph
+    batch_lengths = [[len(history_items)] for history_items in batch_history_items]
+    if is_test:
+        batch_targets = None
+    else:
+        # multi-hot 0/1 target vector over all items for the held-out basket;
+        # float targets keep all supported losses usable
+        batch_targets = np.zeros((len(bi_batch), total_items), dtype="uint8")
+        for inc, basket_items in enumerate(bi_batch):
+            batch_targets[inc, basket_items[-1]] = 1
+        batch_targets = torch.tensor(batch_targets, dtype=torch.float32, device=device)
+    batch_nodes = [
+        torch.tensor(
+            list(set(itertools.chain.from_iterable(history_items))),
+            dtype=torch.int32,
+            device=device,
+        )
+        for history_items in batch_history_items
+    ]
+    batch_nodes_feature = [item_embedding(nodes) for nodes in batch_nodes]
+
+    batch_project_nodes = [
+        torch.tensor(list(range(nodes.shape[0]))) for nodes in batch_nodes
+    ]
+    # build a fully connected graph (self-loops included) over each user's items
+    batch_src = [
+        (
+            torch.stack([project_nodes for _ in range(project_nodes.shape[0])], dim=1)
+            .flatten()
+            .tolist()
+        )
+        for project_nodes in batch_project_nodes
+    ]
+    batch_dst = [
+        (
+            torch.stack([project_nodes for _ in range(project_nodes.shape[0])], dim=0)
+            .flatten()
+            .tolist()
+        )
+        for project_nodes in batch_project_nodes
+    ]
+    batch_g = [
+        dgl.graph((src, dst), num_nodes=project_nodes.shape[0]).to(device)
+        for src, dst, project_nodes in zip(batch_src, batch_dst, batch_project_nodes)
+    ]
+    batch_edges_weight_dict = [
+        get_edges_weight(history_items) for history_items in batch_history_items
+    ]
+
+    # ensure every self-loop has a weight, then normalize weights per user
+    for i, nodes in enumerate(batch_nodes):
+        for node in nodes.tolist():
+            if batch_edges_weight_dict[i][(node, node)] == 0.0:
+                batch_edges_weight_dict[i][(node, node)] = 1.0
+        max_weight = max(batch_edges_weight_dict[i].values())
+        for k, v in batch_edges_weight_dict[i].items():
+            batch_edges_weight_dict[i][k] = v / max_weight
+
+    batch_edges_weight = []
+    for edges_weight_dict, history_items, nodes in zip(
+        batch_edges_weight_dict, batch_history_items, batch_nodes
+    ):
+        edges_weight = []
+        for basket in history_items:
+            edge_weight = []
+            for node_1 in nodes.tolist():
+                for node_2 in nodes.tolist():
+                    if (node_1 in basket and node_2 in basket) or (node_1 == node_2):
+                        edge_weight.append(edges_weight_dict.get((node_1, node_2), 0.0))
+                    else:
+                        edge_weight.append(0.0)
+            edges_weight.append(torch.Tensor(edge_weight))
+        batch_edges_weight.append(torch.stack(edges_weight).to(device))
+    return (
+        batch_g,
+        batch_nodes_feature,
+        batch_edges_weight,
+        batch_lengths,
+        batch_nodes,
+        batch_targets,
+    )
+
+
+class BPRLoss(nn.Module):
+    def __init__(self):
+        super(BPRLoss, self).__init__()
+
+    def forward(self, predict, truth):
+        """
+        Args:
+            predict: (batch_size, items_total) / (batch_size, baskets_num, items_total)
+            truth: (batch_size, items_total) / (batch_size, baskets_num, items_total)
+        Returns:
+            output: tensor
+        """
+        result = self.batch_bpr_loss(predict, truth)
+
+        return result
+
+    def batch_bpr_loss(self, predict, truth):
+        """
+        Args:
+            predict: (batch_size, items_total)
+            truth: (batch_size, items_total)
+        Returns:
+            output: tensor
+        """
+        items_total = truth.shape[1]
+        nll = 0
+        for user, predict_user in zip(truth, predict):
+            # boolean mask and scores of the positive (observed) items
+            pos_idx = user.bool()
+            pos_scores = predict_user[pos_idx]
+            # sample one negative (unobserved) item per positive item
+            non_zero_list = list(itertools.chain.from_iterable(torch.nonzero(user)))
+            random_list = list(set(range(0, items_total)) - set(non_zero_list))
+            random.shuffle(random_list)
+            neg_idx = torch.tensor(random_list[: len(pos_scores)])
+            score = pos_scores - predict_user[neg_idx]
+            nll += -torch.mean(torch.nn.LogSigmoid()(score))
+        return nll
+
+
+class WeightMSELoss(nn.Module):
+    def __init__(self, weights=None):
+        """
+        Args:
+            weights: tensor, (items_total, )
+        """
+        super(WeightMSELoss, self).__init__()
+        self.weights = weights
+        if weights is not None:
+            self.weights = torch.sqrt(weights)
+        self.mse_loss = nn.MSELoss(reduction="sum")
+
+    def forward(self, predict, truth):
+        """
+        Args:
+            predict: tensor, (batch_size, items_total)
+            truth: tensor, (batch_size, items_total)
+        Returns:
+            output: tensor
+        """
+        predict = torch.sigmoid(predict)
+        truth = truth.float()
+        if self.weights is not None:
+            self.weights = self.weights.to(truth.device)
+            predict = predict * self.weights
+            truth = truth * self.weights
+
+        loss = self.mse_loss(predict, truth)
+        return loss
+
+
+class W_multilabel(nn.Module):
+    def __init__(self, weight):
+        super(W_multilabel, self).__init__()
+        self.weights = weight
+
+    def forward(self, predict, truth):
+        predict = torch.sigmoid(predict)
+        truth = truth.float()
+        assert len(self.weights) == 1
+        self.weights = self.weights.to(truth.device)
+        batch_size = predict.size(0)
+        w_loss = 0
+        for ind in range(batch_size):
+            # weighted binary cross-entropy: the weights scale the positive term only
+            c_loss = self.weights * (truth[ind] * torch.log(predict[ind])) + (
+                (1 - truth[ind]) * torch.log(1 - predict[ind])
+            )
+            w_loss += torch.neg(torch.sum(c_loss))
+        return w_loss
+
+
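+# loss_fn maps the `loss_type` strings exposed by the DNNTSP recommender to
+# the loss modules above; "bpr" pairs every observed item with one sampled
+# unobserved item, while the MSE and multi-label variants score all items.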
+def loss_fn(loss_type=None, weights=None):
+    if loss_type == "bpr":
+        return BPRLoss()
+    elif loss_type == "mse":
+        return WeightMSELoss()
+    elif loss_type == "weight_mse":
+        assert weights is not None, f"weight_mse loss requires 'weights', got {weights}"
+        return WeightMSELoss(weights=weights)
+    elif loss_type == "multi_label_soft_margin":
+        return nn.MultiLabelSoftMarginLoss(reduction="mean")
+    else:
+        raise ValueError(f"Unknown loss function: {loss_type}")
+
+
+def learn(
+    model: TemporalSetPrediction,
+    train_set,
+    total_items,
+    val_set=None,
+    n_epochs=10,
+    batch_size=64,
+    lr=0.001,
+    weight_decay=0.0,
+    loss_type="bpr",
+    optimizer="adam",
+    device=torch.device("cpu"),
+    verbose=True,
+):
+    model = model.to(device)
+    criteria = loss_fn(loss_type=loss_type)
+    optimizer = OPTIMIZER_DICT[optimizer](
+        params=model.parameters(),
+        lr=lr,
+        weight_decay=weight_decay,
+    )
+
+    progress_bar = trange(1, n_epochs + 1, disable=not verbose)
+    last_val_loss = np.inf
+    last_loss = np.inf
+    for _ in progress_bar:
+        model.train()
+        for inc, (_, _, bi_batch) in enumerate(
+            train_set.ubi_iter(batch_size, shuffle=True)
+        ):
+            (
+                g,
+                nodes_feature,
+                edges_weight,
+                lengths,
+                nodes,
+                targets,
+            ) = transform_data(
+                bi_batch,
+                item_embedding=model.embedding_matrix,
+                total_items=total_items,
+                device=device,
+            )
+            preds = model(g, nodes_feature, edges_weight, lengths, nodes)
+            loss = criteria(preds, targets)
+
+            optimizer.zero_grad()
+            loss.backward()
+            optimizer.step()
+
+            last_loss = loss.item()
+            if inc % 10 == 0:
+                progress_bar.set_postfix(loss=last_loss, val_loss=last_val_loss)
+
+        if val_set is not None:
+            model.eval()
+            # no gradients are needed for validation
+            with torch.no_grad():
+                for inc, (_, _, bi_batch) in enumerate(
+                    val_set.ubi_iter(batch_size, shuffle=False)
+                ):
+                    (
+                        g,
+                        nodes_feature,
+                        edges_weight,
+                        lengths,
+                        nodes,
+                        targets,
+                    ) = transform_data(
+                        bi_batch,
+                        item_embedding=model.embedding_matrix,
+                        total_items=total_items,
+                        device=device,
+                    )
+                    preds = model(g, nodes_feature, edges_weight, lengths, nodes)
+                    loss = criteria(preds, targets)
+
+                    last_val_loss = loss.item()
+                    if inc % 10 == 0:
+                        progress_bar.set_postfix(
+                            loss=last_loss, val_loss=last_val_loss
+                        )
+
+
+def score(model: TemporalSetPrediction, history_baskets, total_items, device="cpu"):
+    model = model.to(device)
+    model.eval()
+    with torch.no_grad():
+        (
+            g,
+            nodes_feature,
+            edges_weight,
+            lengths,
+            nodes,
+            _,
+        ) = transform_data(
+            [history_baskets],
+            item_embedding=model.embedding_matrix,
+            total_items=total_items,
+            device=device,
+            is_test=True,
+        )
+        preds = model(g, nodes_feature, edges_weight, lengths, nodes)
+    return preds[0].cpu().numpy()
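+
+# A minimal scoring sketch (hypothetical sizes and item indices, assuming a
+# trained model):
+#
+#   model = TemporalSetPrediction(n_items=100, emb_dim=32)
+#   history = [[1, 2, 3], [2, 4]]  # one user's past baskets
+#   item_scores = score(model, history, total_items=100)  # ndarray, shape (100,)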
diff --git a/cornac/models/dnntsp/recom_dnntsp.py b/cornac/models/dnntsp/recom_dnntsp.py
new file mode 100644
index 000000000..9196d335b
--- /dev/null
+++ b/cornac/models/dnntsp/recom_dnntsp.py
@@ -0,0 +1,130 @@
+# Copyright 2023 The Cornac Authors. All Rights Reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#     http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+# ============================================================================
+
+from ..recommender import NextBasketRecommender
+
+
+class DNNTSP(NextBasketRecommender):
+    """Deep Neural Network for Temporal Sets Prediction (DNNTSP).
+
+    Parameters
+    ----------
+    name: string, default: 'DNNTSP'
+        The name of the recommender model.
+
+    emb_dim: int, optional, default: 32
+        Number of hidden factors.
+
+    loss_type: string, optional, default: "bpr"
+        Loss type, one of:
+        "bpr": BPRLoss
+        "mse": MSELoss
+        "weight_mse": WeightMSELoss
+        "multi_label_soft_margin": MultiLabelSoftMarginLoss
+
+    optimizer: string, optional, default: "adam"
+        Optimizer, one of "adam" or "sgd".
+
+    lr: float, optional, default: 0.001
+        Learning rate.
+
+    weight_decay: float, optional, default: 0
+        Weight decay for the adaptive optimizer.
+
+    n_epochs: int, optional, default: 100
+        Number of epochs.
+
+    batch_size: int, optional, default: 64
+        Batch size.
+
+    device: string, optional, default: "cpu"
+        Device for training and evaluation; use "cuda:0" for GPU.
+
+    trainable: boolean, optional, default: True
+        When False, the model will not be re-trained, and input of pre-trained parameters are required.
+
+    verbose: boolean, optional, default: False
+        When True, running logs are displayed.
+
+    seed: int, optional, default: None
+        Random seed.
+
+    References
+    ----------
+    Le Yu, Leilei Sun, Bowen Du, Chuanren Liu, Hui Xiong, and Weifeng Lv. 2020.
+    Predicting Temporal Sets with Deep Neural Networks.
+    In Proceedings of the 26th ACM SIGKDD International Conference on Knowledge
+    Discovery & Data Mining (KDD '20). Association for Computing Machinery,
+    New York, NY, USA, 1083-1091. https://doi.org/10.1145/3394486.3403152
+    """
+
+    def __init__(
+        self,
+        name="DNNTSP",
+        emb_dim=32,
+        loss_type="bpr",
+        optimizer="adam",
+        lr=0.001,
+        weight_decay=0,
+        n_epochs=100,
+        batch_size=64,
+        device="cpu",
+        trainable=True,
+        verbose=False,
+        seed=None,
+    ):
+        super().__init__(name=name, trainable=trainable, verbose=verbose)
+        self.emb_dim = emb_dim
+        self.loss_type = loss_type
+        self.optimizer = optimizer
+        self.lr = lr
+        self.weight_decay = weight_decay
+        self.n_epochs = n_epochs
+        self.batch_size = batch_size
+        self.seed = seed
+        self.device = device
+
+    def fit(self, train_set, val_set=None):
+        super().fit(train_set=train_set, val_set=val_set)
+        from .dnntsp import TemporalSetPrediction, learn
+
+        self.model = TemporalSetPrediction(
+            n_items=self.total_items,
+            emb_dim=self.emb_dim,
+        )
+
+        learn(
+            model=self.model,
+            train_set=train_set,
+            total_items=self.total_items,
+            val_set=val_set,
+            n_epochs=self.n_epochs,
+            batch_size=self.batch_size,
+            lr=self.lr,
+            weight_decay=self.weight_decay,
+            loss_type=self.loss_type,
+            optimizer=self.optimizer,
+            device=self.device,
+            verbose=self.verbose,
+        )
+
+        return self
+
+    def score(self, user_idx, history_baskets, **kwargs):
+        from .dnntsp import score
+
+        item_scores = score(
+            self.model, history_baskets, self.total_items, device=self.device
+        )
+        return item_scores
diff --git a/cornac/models/dnntsp/requirements.txt b/cornac/models/dnntsp/requirements.txt
new file mode 100644
index 000000000..32f294fbc
--- /dev/null
+++ b/cornac/models/dnntsp/requirements.txt
@@ -0,0 +1,2 @@
+torch>=2.0.0
+dgl>=1.1.0
\ No newline at end of file
diff --git a/docs/source/api_ref/models.rst b/docs/source/api_ref/models.rst
index 5006dd84f..90341bdb2 100644
--- a/docs/source/api_ref/models.rst
+++ b/docs/source/api_ref/models.rst
@@ -49,6 +49,11 @@ New Variational Autoencoder for Top-N Recommendations with Implicit Feedback (Re
 .. automodule:: cornac.models.recvae.recom_recvae
     :members:
 
+Predicting Temporal Sets with Deep Neural Networks (DNNTSP)
+-----------------------------------------------------------
+.. automodule:: cornac.models.dnntsp.recom_dnntsp
+    :members:
+
 Temporal-Item-Frequency-based User-KNN (TIFUKNN)
 ---------------------------------------------------
 .. automodule:: cornac.models.tifuknn.recom_tifuknn
diff --git a/examples/README.md b/examples/README.md
index 4ec27a6a6..e2fc53a45 100644
--- a/examples/README.md
+++ b/examples/README.md
@@ -120,4 +120,6 @@
 [gp_top_tafeng.py](gp_top_tafeng.py) - Next-basket recommendation model that merely uses item top frequency.
 
+[dnntsp_tafeng.py](dnntsp_tafeng.py) - Example of Predicting Temporal Sets with Deep Neural Networks (DNNTSP).
+
 [tifuknn_tafeng.py](tifuknn_tafeng.py) - Example of Temporal-Item-Frequency-based User-KNN (TIFUKNN).
diff --git a/examples/dnntsp_tafeng.py b/examples/dnntsp_tafeng.py
new file mode 100644
index 000000000..2b41735ed
--- /dev/null
+++ b/examples/dnntsp_tafeng.py
@@ -0,0 +1,55 @@
+# Copyright 2023 The Cornac Authors. All Rights Reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#     http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+# ============================================================================
+"""Example of Predicting Temporal Sets with Deep Neural Networks (DNNTSP)"""
+
+import cornac
+from cornac.eval_methods import NextBasketEvaluation
+from cornac.metrics import NDCG, HitRatio, Recall
+from cornac.models import DNNTSP
+
+data = cornac.datasets.tafeng.load_basket(
+    reader=cornac.data.Reader(
+        min_basket_size=3, max_basket_size=50, min_basket_sequence=2
+    )
+)
+
+next_basket_eval = NextBasketEvaluation(
+    data=data, fmt="UBITJson", test_size=0.2, val_size=0.08, seed=123, verbose=True
+)
+
+models = [
+    DNNTSP(
+        emb_dim=32,
+        loss_type="bpr",
+        optimizer="adam",
+        lr=0.001,
+        weight_decay=0,
+        batch_size=64,
+        n_epochs=10,
+        device="cuda:0",
+        verbose=True,
+    )
+]
+
+metrics = [
+    Recall(k=10),
+    Recall(k=50),
+    NDCG(k=10),
+    NDCG(k=50),
+    HitRatio(k=10),
+    HitRatio(k=50),
+]
+
+cornac.Experiment(eval_method=next_basket_eval, models=models, metrics=metrics).run()
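+
+# After the experiment, the fitted model can also score an arbitrary basket
+# history directly (hypothetical item indices; DNNTSP ignores `user_idx`):
+#
+#   item_scores = models[0].score(user_idx=None, history_baskets=[[1, 2], [3]])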