diff --git a/.gitignore b/.gitignore index bb30b8566..5ff0a0ffa 100644 --- a/.gitignore +++ b/.gitignore @@ -1,5 +1,6 @@ tests/vocab.pkl .idea/ +.vscode/ # Byte-compiled / optimized / DLL files __pycache__/ diff --git a/cornac/models/__init__.py b/cornac/models/__init__.py index 7d4ef4b56..91610481b 100644 --- a/cornac/models/__init__.py +++ b/cornac/models/__init__.py @@ -74,4 +74,3 @@ "FM model is only supported on Linux.\n" + "Windows executable can be found at http://www.libfm.org." ) - diff --git a/cornac/models/ncf/backend_pt.py b/cornac/models/ncf/backend_pt.py new file mode 100644 index 000000000..a1da0f991 --- /dev/null +++ b/cornac/models/ncf/backend_pt.py @@ -0,0 +1,176 @@ +import torch +import torch.nn as nn + + +optimizer_dict = { + "sgd": torch.optim.SGD, + "adam": torch.optim.Adam, + "rmsprop": torch.optim.RMSprop, + "adagrad": torch.optim.Adagrad, +} + +activation_functions = { + "sigmoid": nn.Sigmoid(), + "tanh": nn.Tanh(), + "elu": nn.ELU(), + "selu": nn.SELU(), + "relu": nn.ReLU(), + "relu6": nn.ReLU6(), + "leakyrelu": nn.LeakyReLU(), +} + + +class GMF(nn.Module): + def __init__( + self, + num_users: int, + num_items: int, + num_factors: int = 8, + ): + super(GMF, self).__init__() + + self.num_users = num_users + self.num_items = num_items + self.user_embedding = nn.Embedding(num_users, num_factors) + self.item_embedding = nn.Embedding(num_items, num_factors) + + self.logit = nn.Linear(num_factors, 1) + self.Sigmoid = nn.Sigmoid() + + self._init_weight() + + def _init_weight(self): + nn.init.normal_(self.user_embedding.weight, std=1e-2) + nn.init.normal_(self.item_embedding.weight, std=1e-2) + nn.init.normal_(self.logit.weight, std=1e-2) + + def from_pretrained(self, pretrained_gmf): + self.user_embedding.weight.data.copy_(pretrained_gmf.user_embedding.weight) + self.item_embedding.weight.data.copy_(pretrained_gmf.item_embedding.weight) + self.logit.weight.data.copy_(pretrained_gmf.logit.weight) + 
self.logit.bias.data.copy_(pretrained_gmf.logit.bias) + + def h(self, users, items): + return self.user_embedding(users) * self.item_embedding(items) + + def forward(self, users, items): + h = self.h(users, items) + output = self.Sigmoid(self.logit(h)).view(-1) + return output + + +class MLP(nn.Module): + def __init__( + self, + num_users: int, + num_items: int, + layers=(64, 32, 16, 8), + act_fn="relu", + ): + super(MLP, self).__init__() + + self.num_users = num_users + self.num_items = num_items + self.user_embedding = nn.Embedding(num_users, layers[0] // 2) + self.item_embedding = nn.Embedding(num_items, layers[0] // 2) + + mlp_layers = [] + for idx, factor in enumerate(layers[:-1]): + mlp_layers.append(nn.Linear(factor, layers[idx + 1])) + mlp_layers.append(activation_functions[act_fn.lower()]) + + # unpacking layers in to torch.nn.Sequential + self.mlp_model = nn.Sequential(*mlp_layers) + + self.logit = nn.Linear(layers[-1], 1) + self.Sigmoid = nn.Sigmoid() + + self._init_weight() + + def _init_weight(self): + nn.init.normal_(self.user_embedding.weight, std=1e-2) + nn.init.normal_(self.item_embedding.weight, std=1e-2) + for layer in self.mlp_model: + if isinstance(layer, nn.Linear): + nn.init.xavier_uniform_(layer.weight) + nn.init.normal_(self.logit.weight, std=1e-2) + + def from_pretrained(self, pretrained_mlp): + self.user_embedding.weight.data.copy_(pretrained_mlp.user_embedding.weight) + self.item_embedding.weight.data.copy_(pretrained_mlp.item_embedding.weight) + for layer, pretrained_layer in zip(self.mlp_model, pretrained_mlp.mlp_model): + if isinstance(layer, nn.Linear) and isinstance(pretrained_layer, nn.Linear): + layer.weight.data.copy_(pretrained_layer.weight) + layer.bias.data.copy_(pretrained_layer.bias) + self.logit.weight.data.copy_(pretrained_mlp.logit.weight) + self.logit.bias.data.copy_(pretrained_mlp.logit.bias) + + def h(self, users, items): + embed_user = self.user_embedding(users) + embed_item = self.item_embedding(items) + embed_input 
= torch.cat((embed_user, embed_item), dim=-1) + return self.mlp_model(embed_input) + + def forward(self, users, items): + h = self.h(users, items) + output = self.Sigmoid(self.logit(h)).view(-1) + return output + + def __call__(self, *args): + return self.forward(*args) + + +class NeuMF(nn.Module): + def __init__( + self, + num_users: int, + num_items: int, + num_factors: int = 8, + layers=(64, 32, 16, 8), + act_fn="relu", + ): + super(NeuMF, self).__init__() + + # layer for MLP + if layers is None: + layers = [64, 32, 16, 8] + if num_factors is None: + num_factors = layers[-1] + + assert layers[-1] == num_factors + + self.logit = nn.Linear(num_factors + layers[-1], 1) + self.Sigmoid = nn.Sigmoid() + + self.gmf = GMF(num_users, num_items, num_factors) + self.mlp = MLP( + num_users=num_users, num_items=num_items, layers=layers, act_fn=act_fn + ) + + nn.init.normal_(self.logit.weight, std=1e-2) + + def from_pretrained(self, pretrained_gmf, pretrained_mlp, alpha): + self.gmf.from_pretrained(pretrained_gmf) + self.mlp.from_pretrained(pretrained_mlp) + logit_weight = torch.cat( + [ + alpha * self.gmf.logit.weight, + (1.0 - alpha) * self.mlp.logit.weight, + ], + dim=1, + ) + logit_bias = alpha * self.gmf.logit.bias + (1.0 - alpha) * self.mlp.logit.bias + self.logit.weight.data.copy_(logit_weight) + self.logit.bias.data.copy_(logit_bias) + + def forward(self, users, items, gmf_users=None): + # gmf_users is there to take advantage of broadcasting + h_gmf = ( + self.gmf.h(users, items) + if gmf_users is None + else self.gmf.h(gmf_users, items) + ) + h_mlp = self.mlp.h(users, items) + h = torch.cat([h_gmf, h_mlp], dim=-1) + output = self.Sigmoid(self.logit(h)).view(-1) + return output diff --git a/cornac/models/ncf/ops.py b/cornac/models/ncf/backend_tf.py similarity index 100% rename from cornac/models/ncf/ops.py rename to cornac/models/ncf/backend_tf.py diff --git a/cornac/models/ncf/recom_gmf.py b/cornac/models/ncf/recom_gmf.py index 657d601d1..f55ec7eef 100644 --- 
a/cornac/models/ncf/recom_gmf.py +++ b/cornac/models/ncf/recom_gmf.py @@ -27,9 +27,9 @@ class GMF(NCFBase): ---------- num_factors: int, optional, default: 8 Embedding size of MF model. - - regs: float, optional, default: 0. - Regularization for user and item embeddings. + + reg: float, optional, default: 0. + Regularization (weight_decay). num_epochs: int, optional, default: 20 Number of epochs. @@ -45,7 +45,10 @@ class GMF(NCFBase): learner: str, optional, default: 'adam' Specify an optimizer: adagrad, adam, rmsprop, sgd - + + backend: str, optional, default: 'tensorflow' + Backend used for model training: tensorflow, pytorch + early_stopping: {min_delta: float, patience: int}, optional, default: None If `None`, no early stopping. Meaning of the arguments: @@ -77,12 +80,13 @@ def __init__( self, name="GMF", num_factors=8, - regs=(0.0, 0.0), + reg=0.0, num_epochs=20, batch_size=256, num_neg=4, lr=0.001, learner="adam", + backend="tensorflow", early_stopping=None, trainable=True, verbose=True, @@ -97,17 +101,21 @@ def __init__( num_neg=num_neg, lr=lr, learner=learner, + backend=backend, early_stopping=early_stopping, seed=seed, ) self.num_factors = num_factors - self.regs = regs + self.reg = reg - def _build_graph(self): + ######################## + ## TensorFlow backend ## + ######################## + def _build_graph_tf(self): import tensorflow.compat.v1 as tf - from .ops import gmf, loss_fn, train_fn + from .backend_tf import gmf, loss_fn, train_fn - super()._build_graph() + self.graph = tf.Graph() with self.graph.as_default(): tf.set_random_seed(self.seed) @@ -123,8 +131,8 @@ def _build_graph(self): num_users=self.num_users, num_items=self.num_items, emb_size=self.num_factors, - reg_user=self.regs[0], - reg_item=self.regs[1], + reg_user=self.reg, + reg_item=self.reg, seed=self.seed, ) @@ -144,50 +152,32 @@ def _build_graph(self): self.initializer = tf.global_variables_initializer() self.saver = tf.train.Saver() - self._sess_init() - - def score(self, user_idx, 
item_idx=None): - """Predict the scores/ratings of a user for an item. - - Parameters - ---------- - user_idx: int, required - The index of the user for whom to perform score prediction. - - item_idx: int, optional, default: None - The index of the item for which to perform score prediction. - If None, scores for all known items will be returned. - - Returns - ------- - res : A scalar or a Numpy array - Relative scores that the user gives to the item or to all known items - """ - if item_idx is None: - if self.train_set.is_unk_user(user_idx): - raise ScoreException( - "Can't make score prediction for (user_id=%d)" % user_idx - ) - - known_item_scores = self.sess.run( - self.prediction, - feed_dict={ - self.user_id: [user_idx], - self.item_id: np.arange(self.train_set.num_items), - }, - ) - return known_item_scores.ravel() - else: - if self.train_set.is_unk_user(user_idx) or self.train_set.is_unk_item( - item_idx - ): - raise ScoreException( - "Can't make score prediction for (user_id=%d, item_id=%d)" - % (user_idx, item_idx) - ) - - user_pred = self.sess.run( - self.prediction, - feed_dict={self.user_id: [user_idx], self.item_id: [item_idx]}, - ) - return user_pred.ravel() + self._sess_init_tf() + + def _score_tf(self, user_idx, item_idx): + feed_dict = { + self.user_id: [user_idx], + self.item_id: np.arange(self.num_items) if item_idx is None else [item_idx], + } + return self.sess.run(self.prediction, feed_dict=feed_dict) + + ##################### + ## PyTorch backend ## + ##################### + def _build_model_pt(self): + from .backend_pt import GMF + + return GMF(self.num_users, self.num_items, self.num_factors) + + def _score_pt(self, user_idx, item_idx): + import torch + + with torch.no_grad(): + users = torch.tensor(user_idx).unsqueeze(0).to(self.device) + items = ( + torch.from_numpy(np.arange(self.num_items)) + if item_idx is None + else torch.tensor(item_idx).unsqueeze(0) + ).to(self.device) + output = self.model(users, items) + return 
output.squeeze().cpu().numpy() diff --git a/cornac/models/ncf/recom_mlp.py b/cornac/models/ncf/recom_mlp.py index 4d5304490..6901b91c4 100644 --- a/cornac/models/ncf/recom_mlp.py +++ b/cornac/models/ncf/recom_mlp.py @@ -16,7 +16,6 @@ import numpy as np from .recom_ncf_base import NCFBase -from ...exception import ScoreException class MLP(NCFBase): @@ -32,9 +31,8 @@ class MLP(NCFBase): Name of the activation function used for the MLP layers. Supported functions: ['sigmoid', 'tanh', 'elu', 'relu', 'selu, 'relu6', 'leaky_relu'] - reg_layers: list, optional, default: [0., 0., 0., 0.] - Regularization for each MLP layer, - reg_layers[0] is the regularization for embeddings. + reg: float, optional, default: 0. + Regularization (weight_decay). num_epochs: int, optional, default: 20 Number of epochs. @@ -50,7 +48,10 @@ class MLP(NCFBase): learner: str, optional, default: 'adam' Specify an optimizer: adagrad, adam, rmsprop, sgd - + + backend: str, optional, default: 'tensorflow' + Backend used for model training: tensorflow, pytorch + early_stopping: {min_delta: float, patience: int}, optional, default: None If `None`, no early stopping. 
Meaning of the arguments: @@ -83,12 +84,13 @@ def __init__( name="MLP", layers=(64, 32, 16, 8), act_fn="relu", - reg_layers=(0.0, 0.0, 0.0, 0.0), + reg=0.0, num_epochs=20, batch_size=256, num_neg=4, lr=0.001, learner="adam", + backend="tensorflow", early_stopping=None, trainable=True, verbose=True, @@ -103,18 +105,22 @@ def __init__( num_neg=num_neg, lr=lr, learner=learner, + backend=backend, early_stopping=early_stopping, seed=seed, ) self.layers = layers self.act_fn = act_fn - self.reg_layers = reg_layers + self.reg = reg - def _build_graph(self): + ######################## + ## TensorFlow backend ## + ######################## + def _build_graph_tf(self): import tensorflow.compat.v1 as tf - from .ops import mlp, loss_fn, train_fn + from .backend_tf import mlp, loss_fn, train_fn - super()._build_graph() + self.graph = tf.Graph() with self.graph.as_default(): tf.set_random_seed(self.seed) @@ -130,7 +136,7 @@ def _build_graph(self): num_users=self.num_users, num_items=self.num_items, layers=self.layers, - reg_layers=self.reg_layers, + reg_layers=[self.reg] * len(self.layers), act_fn=self.act_fn, seed=self.seed, ) @@ -150,50 +156,43 @@ def _build_graph(self): self.initializer = tf.global_variables_initializer() self.saver = tf.train.Saver() - self._sess_init() - - def score(self, user_idx, item_idx=None): - """Predict the scores/ratings of a user for an item. - - Parameters - ---------- - user_idx: int, required - The index of the user for whom to perform score prediction. - - item_idx: int, optional, default: None - The index of the item for which to perform score prediction. - If None, scores for all known items will be returned. 
+ self._sess_init_tf() - Returns - ------- - res : A scalar or a Numpy array - Relative scores that the user gives to the item or to all known items - """ + def _score_tf(self, user_idx, item_idx): if item_idx is None: - if self.train_set.is_unk_user(user_idx): - raise ScoreException( - "Can't make score prediction for (user_id=%d)" % user_idx - ) - - known_item_scores = self.sess.run( - self.prediction, - feed_dict={ - self.user_id: np.ones(self.train_set.num_items) * user_idx, - self.item_id: np.arange(self.train_set.num_items), - }, - ) - return known_item_scores.ravel() + feed_dict = { + self.user_id: np.ones(self.num_items) * user_idx, + self.item_id: np.arange(self.num_items), + } else: - if self.train_set.is_unk_user(user_idx) or self.train_set.is_unk_item( - item_idx - ): - raise ScoreException( - "Can't make score prediction for (user_id=%d, item_id=%d)" - % (user_idx, item_idx) - ) - - user_pred = self.sess.run( - self.prediction, - feed_dict={self.user_id: [user_idx], self.item_id: [item_idx]}, - ) - return user_pred.ravel() + feed_dict = { + self.user_id: [user_idx], + self.item_id: [item_idx], + } + return self.sess.run(self.prediction, feed_dict=feed_dict) + + ##################### + ## PyTorch backend ## + ##################### + def _build_model_pt(self): + from .backend_pt import MLP + + return MLP( + num_users=self.num_users, + num_items=self.num_items, + layers=self.layers, + act_fn=self.act_fn, + ) + + def _score_pt(self, user_idx, item_idx): + import torch + + with torch.no_grad(): + if item_idx is None: + users = torch.from_numpy(np.ones(self.num_items, dtype=int) * user_idx) + items = (torch.from_numpy(np.arange(self.num_items))).to(self.device) + else: + users = torch.tensor(user_idx).unsqueeze(0) + items = torch.tensor(item_idx).unsqueeze(0) + output = self.model(users.to(self.device), items.to(self.device)) + return output.squeeze().cpu().numpy() diff --git a/cornac/models/ncf/recom_ncf_base.py b/cornac/models/ncf/recom_ncf_base.py index 
8510c69d2..2541a4272 100644 --- a/cornac/models/ncf/recom_ncf_base.py +++ b/cornac/models/ncf/recom_ncf_base.py @@ -14,10 +14,12 @@ # ============================================================================ +import numpy as np from tqdm.auto import trange from ..recommender import Recommender from ...utils import get_rng +from ...exception import ScoreException class NCFBase(Recommender): @@ -39,6 +41,9 @@ class NCFBase(Recommender): learner: str, optional, default: 'adam' Specify an optimizer: adagrad, adam, rmsprop, sgd + + backend: str, optional, default: 'tensorflow' + Backend used for model training: tensorflow, pytorch early_stopping: {min_delta: float, patience: int}, optional, default: None If `None`, no early stopping. Meaning of the arguments: @@ -71,6 +76,7 @@ def __init__( num_neg=4, lr=0.001, learner="adam", + backend="tensorflow", early_stopping=None, trainable=True, verbose=True, @@ -82,6 +88,7 @@ def __init__( self.num_neg = num_neg self.lr = lr self.learner = learner + self.backend = backend self.early_stopping = early_stopping self.seed = seed self.rng = get_rng(seed) @@ -119,20 +126,25 @@ def fit(self, train_set, val_set=None): Recommender.fit(self, train_set, val_set) if self.trainable: - if not hasattr(self, "graph"): - self.num_users = self.train_set.num_users - self.num_items = self.train_set.num_items - self._build_graph() - self._fit_tf() + self.num_users = self.train_set.num_users + self.num_items = self.train_set.num_items - return self + if self.backend == "tensorflow": + self._fit_tf() + elif self.backend == "pytorch": + self._fit_pt() + else: + raise ValueError(f"{self.backend} is not supported") - def _build_graph(self): - import tensorflow.compat.v1 as tf + return self - self.graph = tf.Graph() + ######################## + ## TensorFlow backend ## + ######################## + def _build_graph_tf(self): + raise NotImplementedError() - def _sess_init(self): + def _sess_init_tf(self): import tensorflow.compat.v1 as tf config = 
tf.ConfigProto() @@ -140,18 +152,17 @@ def _sess_init(self): self.sess = tf.Session(graph=self.graph, config=config) self.sess.run(self.initializer) - def _step_update(self, batch_users, batch_items, batch_ratings): - _, _loss = self.sess.run( - [self.train_op, self.loss], - feed_dict={ - self.user_id: batch_users, - self.item_id: batch_items, - self.labels: batch_ratings.reshape(-1, 1), - }, - ) - return _loss + def _get_feed_dict(self, batch_users, batch_items, batch_ratings): + return { + self.user_id: batch_users, + self.item_id: batch_items, + self.labels: batch_ratings.reshape(-1, 1), + } def _fit_tf(self): + if not hasattr(self, "graph"): + self._build_graph_tf() + loop = trange(self.num_epochs, disable=not self.verbose) for _ in loop: count = 0 @@ -161,7 +172,12 @@ def _fit_tf(self): self.batch_size, shuffle=True, binary=True, num_zeros=self.num_neg ) ): - _loss = self._step_update(batch_users, batch_items, batch_ratings) + _, _loss = self.sess.run( + [self.train_op, self.loss], + feed_dict=self._get_feed_dict( + batch_users, batch_items, batch_ratings + ), + ) count += len(batch_ratings) sum_loss += _loss * len(batch_ratings) if i % 10 == 0: @@ -173,6 +189,67 @@ def _fit_tf(self): break loop.close() + def _score_tf(self, user_idx, item_idx): + raise NotImplementedError() + + ##################### + ## PyTorch backend ## + ##################### + def _build_model_pt(self): + raise NotImplementedError() + + def _fit_pt(self): + import torch + import torch.nn as nn + from .backend_pt import optimizer_dict + + device = torch.device("cuda" if torch.cuda.is_available() else "cpu") + self.device = device + if self.seed is not None: + torch.manual_seed(self.seed) + np.random.seed(self.seed) + if torch.cuda.is_available(): + torch.cuda.manual_seed(self.seed) + + self.model = self._build_model_pt().to(self.device) + + optimizer = optimizer_dict[self.learner]( + self.model.parameters(), + lr=self.lr, + weight_decay=self.reg, + ) + criteria = nn.BCELoss() + + loop = 
trange(self.num_epochs, disable=not self.verbose) + for _ in loop: + count = 0 + sum_loss = 0 + for batch_id, (batch_users, batch_items, batch_ratings) in enumerate( + self.train_set.uir_iter( + self.batch_size, shuffle=True, binary=True, num_zeros=self.num_neg + ) + ): + batch_users = torch.from_numpy(batch_users).to(self.device) + batch_items = torch.from_numpy(batch_items).to(self.device) + batch_ratings = torch.tensor(batch_ratings, dtype=torch.float).to( + self.device + ) + + optimizer.zero_grad() + outputs = self.model(batch_users, batch_items) + loss = criteria(outputs, batch_ratings) + loss.backward() + optimizer.step() + + count += len(batch_users) + sum_loss += loss.data.item() + + if batch_id % 10 == 0: + loop.set_postfix(loss=(sum_loss / count)) + + def _score_pt(self, user_idx, item_idx): + raise NotImplementedError() + def save(self, save_dir=None): """Save a recommender model to the filesystem. @@ -186,8 +263,12 @@ return model_file = Recommender.save(self, save_dir) - # save TF weights - self.saver.save(self.sess, model_file.replace(".pkl", ".cpt")) + + if self.backend == "tensorflow": + self.saver.save(self.sess, model_file.replace(".pkl", ".cpt")) + elif self.backend == "pytorch": + # TODO: implement model saving for PyTorch + raise NotImplementedError() return model_file @@ -213,8 +294,12 @@ def load(model_path, trainable=False): if hasattr(model, "pretrained"): # NeuMF model.pretrained = False - model._build_graph() - model.saver.restore(model.sess, model.load_from.replace(".pkl", ".cpt")) + if model.backend == "tensorflow": + model._build_graph_tf() + model.saver.restore(model.sess, model.load_from.replace(".pkl", ".cpt")) + elif model.backend == "pytorch": + # TODO: implement model loading for PyTorch + raise NotImplementedError() return model @@ -242,3 +327,38 @@ def monitor_value(self): )[0][0] return ndcg_100 + + def score(self, user_idx, item_idx=None): + """Predict the scores/ratings of a user for an item. 
+ + Parameters + ---------- + user_idx: int, required + The index of the user for whom to perform score prediction. + + item_idx: int, optional, default: None + The index of the item for which to perform score prediction. + If None, scores for all known items will be returned. + + Returns + ------- + res : A scalar or a Numpy array + Relative scores that the user gives to the item or to all known items + """ + if self.train_set.is_unk_user(user_idx): + raise ScoreException( + "Can't make score prediction for (user_id=%d)" % user_idx + ) + + if item_idx is not None and self.train_set.is_unk_item(item_idx): + raise ScoreException( + "Can't make score prediction for (user_id=%d, item_id=%d)" + % (user_idx, item_idx) + ) + + if self.backend == "tensorflow": + pred_scores = self._score_tf(user_idx, item_idx) + elif self.backend == "pytorch": + pred_scores = self._score_pt(user_idx, item_idx) + + return pred_scores.ravel() diff --git a/cornac/models/ncf/recom_neumf.py b/cornac/models/ncf/recom_neumf.py index 595bf7830..760048d0f 100644 --- a/cornac/models/ncf/recom_neumf.py +++ b/cornac/models/ncf/recom_neumf.py @@ -35,8 +35,8 @@ class NeuMF(NCFBase): Name of the activation function used for the MLP layers. Supported functions: ['sigmoid', 'tanh', 'elu', 'relu', 'selu, 'relu6', 'leaky_relu'] - reg_mf: float, optional, default: 0. - Regularization for MF embeddings. + reg: float, optional, default: 0. + Regularization (weight_decay). reg_layers: list, optional, default: [0., 0., 0., 0.] Regularization for each MLP layer, @@ -57,6 +57,9 @@ class NeuMF(NCFBase): learner: str, optional, default: 'adam' Specify an optimizer: adagrad, adam, rmsprop, sgd + backend: str, optional, default: 'tensorflow' + Backend used for model training: tensorflow, pytorch + early_stopping: {min_delta: float, patience: int}, optional, default: None If `None`, no early stopping. 
Meaning of the arguments: @@ -90,13 +93,13 @@ def __init__( num_factors=8, layers=(64, 32, 16, 8), act_fn="relu", - reg_mf=0.0, - reg_layers=(0.0, 0.0, 0.0, 0.0), + reg=0.0, num_epochs=20, batch_size=256, num_neg=4, lr=0.001, learner="adam", + backend="tensorflow", early_stopping=None, trainable=True, verbose=True, @@ -111,51 +114,54 @@ def __init__( num_neg=num_neg, lr=lr, learner=learner, + backend=backend, early_stopping=early_stopping, seed=seed, ) self.num_factors = num_factors self.layers = layers self.act_fn = act_fn - self.reg_mf = reg_mf - self.reg_layers = reg_layers + self.reg = reg self.pretrained = False self.ignored_attrs.extend( [ "gmf_user_id", "mlp_user_id", - "gmf_model", - "mlp_model", + "pretrained_gmf", + "pretrained_mlp", "alpha", ] ) - def pretrain(self, gmf_model, mlp_model, alpha=0.5): + def from_pretrained(self, pretrained_gmf, pretrained_mlp, alpha=0.5): """Provide pre-trained GMF and MLP models. Section 3.4.1 of the paper. Parameters ---------- - gmf_model: object of type GMF, required + pretrained_gmf: object of type GMF, required Reference to trained/fitted GMF model. - gmf_model: object of type GMF, required - Reference to trained/fitted GMF model. + pretrained_mlp: object of type MLP, required + Reference to trained/fitted MLP model. alpha: float, optional, default: 0.5 Hyper-parameter determining the trade-off between the two pre-trained models. Details are described in the section 3.4.1 of the paper. 
""" self.pretrained = True - self.gmf_model = gmf_model - self.mlp_model = mlp_model + self.pretrained_gmf = pretrained_gmf + self.pretrained_mlp = pretrained_mlp self.alpha = alpha return self - def _build_graph(self): + ######################## + ## TensorFlow backend ## + ######################## + def _build_graph_tf(self): import tensorflow.compat.v1 as tf - from .ops import gmf, mlp, loss_fn, train_fn + from .backend_tf import gmf, mlp, loss_fn, train_fn - super()._build_graph() + self.graph = tf.Graph() with self.graph.as_default(): tf.set_random_seed(self.seed) @@ -176,8 +182,8 @@ def _build_graph(self): num_users=self.num_users, num_items=self.num_items, emb_size=self.num_factors, - reg_user=self.reg_mf, - reg_item=self.reg_mf, + reg_user=self.reg, + reg_item=self.reg, seed=self.seed, ) mlp_feat = mlp( @@ -186,7 +192,7 @@ def _build_graph(self): num_users=self.num_users, num_items=self.num_items, layers=self.layers, - reg_layers=self.reg_layers, + reg_layers=[self.reg] * len(self.layers), act_fn=self.act_fn, seed=self.seed, ) @@ -208,20 +214,20 @@ def _build_graph(self): self.initializer = tf.global_variables_initializer() self.saver = tf.train.Saver() - self._sess_init() + self._sess_init_tf() if self.pretrained: - gmf_kernel = self.gmf_model.sess.run( - self.gmf_model.sess.graph.get_tensor_by_name("logits/kernel:0") + gmf_kernel = self.pretrained_gmf.sess.run( + self.pretrained_gmf.sess.graph.get_tensor_by_name("logits/kernel:0") ) - gmf_bias = self.gmf_model.sess.run( - self.gmf_model.sess.graph.get_tensor_by_name("logits/bias:0") + gmf_bias = self.pretrained_gmf.sess.run( + self.pretrained_gmf.sess.graph.get_tensor_by_name("logits/bias:0") ) - mlp_kernel = self.mlp_model.sess.run( - self.mlp_model.sess.graph.get_tensor_by_name("logits/kernel:0") + mlp_kernel = self.pretrained_mlp.sess.run( + self.pretrained_mlp.sess.graph.get_tensor_by_name("logits/kernel:0") ) - mlp_bias = self.mlp_model.sess.run( - 
self.mlp_model.sess.graph.get_tensor_by_name("logits/bias:0") + mlp_bias = self.pretrained_mlp.sess.run( + self.pretrained_mlp.sess.graph.get_tensor_by_name("logits/bias:0") ) logits_kernel = np.concatenate( [self.alpha * gmf_kernel, (1 - self.alpha) * mlp_kernel] @@ -230,12 +236,12 @@ def _build_graph(self): for v in self.graph.get_collection(tf.GraphKeys.TRAINABLE_VARIABLES): if v.name.startswith("GMF"): - sess = self.gmf_model.sess + sess = self.pretrained_gmf.sess self.sess.run( tf.assign(v, sess.run(sess.graph.get_tensor_by_name(v.name))) ) elif v.name.startswith("MLP"): - sess = self.mlp_model.sess + sess = self.pretrained_mlp.sess self.sess.run( tf.assign(v, sess.run(sess.graph.get_tensor_by_name(v.name))) ) @@ -244,65 +250,59 @@ def _build_graph(self): elif v.name.startswith("logits/bias"): self.sess.run(tf.assign(v, logits_bias)) - def _step_update(self, batch_users, batch_items, batch_ratings): - _, _loss = self.sess.run( - [self.train_op, self.loss], - feed_dict={ - self.gmf_user_id: batch_users, - self.mlp_user_id: batch_users, - self.item_id: batch_items, - self.labels: batch_ratings.reshape(-1, 1), - }, - ) - return _loss - - def score(self, user_idx, item_idx=None): - """Predict the scores/ratings of a user for an item. - - Parameters - ---------- - user_idx: int, required - The index of the user for whom to perform score prediction. + def _get_feed_dict(self, batch_users, batch_items, batch_ratings): + return { + self.gmf_user_id: batch_users, + self.mlp_user_id: batch_users, + self.item_id: batch_items, + self.labels: batch_ratings.reshape(-1, 1), + } - item_idx: int, optional, default: None - The index of the item for which to perform score prediction. - If None, scores for all known items will be returned. 
- - Returns - ------- - res : A scalar or a Numpy array - Relative scores that the user gives to the item or to all known items - """ + def _score_tf(self, user_idx, item_idx): if item_idx is None: - if self.train_set.is_unk_user(user_idx): - raise ScoreException( - "Can't make score prediction for (user_id=%d)" % user_idx - ) - - known_item_scores = self.sess.run( - self.prediction, - feed_dict={ - self.gmf_user_id: [user_idx], - self.mlp_user_id: np.ones(self.train_set.num_items) * user_idx, - self.item_id: np.arange(self.train_set.num_items), - }, - ) - return known_item_scores.ravel() + feed_dict = { + self.gmf_user_id: [user_idx], + self.mlp_user_id: np.ones(self.num_items) * user_idx, + self.item_id: np.arange(self.num_items), + } else: - if self.train_set.is_unk_user(user_idx) or self.train_set.is_unk_item( - item_idx - ): - raise ScoreException( - "Can't make score prediction for (user_id=%d, item_id=%d)" - % (user_idx, item_idx) - ) - - user_pred = self.sess.run( - self.prediction, - feed_dict={ - self.gmf_user_id: [user_idx], - self.mlp_user_id: [user_idx], - self.item_id: [item_idx], - }, + feed_dict = { + self.gmf_user_id: [user_idx], + self.mlp_user_id: [user_idx], + self.item_id: [item_idx], + } + return self.sess.run(self.prediction, feed_dict=feed_dict) + + ##################### + ## PyTorch backend ## + ##################### + def _build_model_pt(self): + from .backend_pt import NeuMF + + model = NeuMF( + num_users=self.num_users, + num_items=self.num_items, + num_factors=self.num_factors, layers=self.layers, + act_fn=self.act_fn, + ) + if self.pretrained: + model.from_pretrained( + self.pretrained_gmf.model, self.pretrained_mlp.model, self.alpha + ) + return model + + def _score_pt(self, user_idx, item_idx): + import torch + + with torch.no_grad(): + if item_idx is None: + users = torch.from_numpy(np.ones(self.num_items, dtype=int) * user_idx) + items = (torch.from_numpy(np.arange(self.num_items))).to(self.device) + else: + users = torch.tensor(user_idx).unsqueeze(0) + items = 
torch.tensor(item_idx).unsqueeze(0) + gmf_users = torch.tensor(user_idx).unsqueeze(0).to(self.device) + output = self.model( + users.to(self.device), items.to(self.device), gmf_users.to(self.device) ) - return user_pred.ravel() + return output.squeeze().cpu().numpy() diff --git a/cornac/models/ncf/requirements.txt b/cornac/models/ncf/requirements.txt index a60e13761..c108aa3bf 100644 --- a/cornac/models/ncf/requirements.txt +++ b/cornac/models/ncf/requirements.txt @@ -1 +1,2 @@ -tensorflow==2.12.0 \ No newline at end of file +tensorflow==2.12.0 +torch>=0.4.1 \ No newline at end of file diff --git a/examples/ncf_example.py b/examples/ncf_example.py index 9902877b7..85586a964 100644 --- a/examples/ncf_example.py +++ b/examples/ncf_example.py @@ -33,11 +33,14 @@ verbose=True, ) +backend = "tensorflow" # or 'pytorch' + # Instantiate the recommender models to be compared gmf = cornac.models.GMF( num_factors=8, num_epochs=10, learner="adam", + backend=backend, batch_size=256, lr=0.001, num_neg=50, @@ -47,6 +50,7 @@ layers=[64, 32, 16, 8], act_fn="tanh", learner="adam", + backend=backend, num_epochs=10, batch_size=256, lr=0.001, @@ -58,6 +62,7 @@ layers=[64, 32, 16, 8], act_fn="tanh", learner="adam", + backend=backend, num_epochs=10, batch_size=256, lr=0.001, @@ -66,7 +71,8 @@ ) neumf2 = cornac.models.NeuMF( name="NeuMF_pretrained", - learner="adam", + learner="sgd", + backend=backend, num_epochs=10, batch_size=256, lr=0.001, @@ -75,7 +81,7 @@ num_factors=gmf.num_factors, layers=mlp.layers, act_fn=mlp.act_fn, -).pretrain(gmf, mlp) +).from_pretrained(gmf, mlp, alpha=0.5) # Instantiate evaluation metrics ndcg_50 = cornac.metrics.NDCG(k=50) @@ -84,6 +90,11 @@ # Put everything together into an experiment and run it cornac.Experiment( eval_method=ratio_split, - models=[gmf, mlp, neumf1, neumf2], + models=[ + gmf, + mlp, + neumf1, + neumf2, + ], metrics=[ndcg_50, rec_50], ).run()