From 9bfef69e0141d1c5bf13b5699a4822c6d1773456 Mon Sep 17 00:00:00 2001 From: tqtg Date: Tue, 21 Nov 2023 23:58:05 +0000 Subject: [PATCH] refactor --- cornac/models/mf/backend_pt.py | 55 ++++++++++++++-------------------- cornac/models/mf/recom_mf.py | 9 +++--- 2 files changed, 28 insertions(+), 36 deletions(-) diff --git a/cornac/models/mf/backend_pt.py b/cornac/models/mf/backend_pt.py index 8ab966ae4..12e31a74f 100644 --- a/cornac/models/mf/backend_pt.py +++ b/cornac/models/mf/backend_pt.py @@ -29,48 +29,39 @@ class MF(nn.Module): def __init__( self, - n_users, - n_items, - n_factors=10, - global_mean=0, - dropout=0, - init_params={}, + u_factors, + i_factors, + u_biases, + i_biases, + use_bias, + global_mean, + dropout, ): super(MF, self).__init__() - self.n_users = n_users - self.n_items = n_items - self.u_factors = nn.Embedding(n_users, n_factors) - self.i_factors = nn.Embedding(n_items, n_factors) - self.u_biases = nn.Embedding(n_users, 1) - self.i_biases = nn.Embedding(n_items, 1) + + self.use_bias = use_bias self.global_mean = global_mean self.dropout = nn.Dropout(p=dropout) - self._init_params(init_params) - - def _init_params(self, init_params={}): - if not init_params: - nn.init.normal_(self.u_factors.weight, std=0.01) - nn.init.normal_(self.i_factors.weight, std=0.01) - self.u_biases.weight.data.fill_(0.0) - self.i_biases.weight.data.fill_(0.0) - return - - if "U" in init_params: - self.u_factors.weight.data = torch.from_numpy(init_params["U"]) - if "V" in init_params: - self.i_factors.weight.data = torch.from_numpy(init_params["V"]) - if "Bu" in init_params: - self.u_biases.weight.data = torch.from_numpy(init_params["Bu"]) - if "Bi" in init_params: - self.i_biases.weight.data = torch.from_numpy(init_params["Bi"]) + self.u_factors = nn.Embedding(*u_factors.shape) + self.i_factors = nn.Embedding(*i_factors.shape) + self.u_biases = nn.Embedding(*u_biases.shape) + self.i_biases = nn.Embedding(*i_biases.shape) + + # init params + self.u_factors.weight.data = torch.from_numpy(u_factors) + self.i_factors.weight.data = torch.from_numpy(i_factors) + self.u_biases.weight.data = torch.from_numpy(u_biases) + self.i_biases.weight.data = torch.from_numpy(i_biases) def forward(self, uids, iids): ues = self.u_factors(uids) uis = self.i_factors(iids) - preds = self.u_biases(uids) + self.i_biases(iids) + self.global_mean - preds += (self.dropout(ues) * self.dropout(uis)).sum(dim=1, keepdim=True) + preds = (self.dropout(ues) * self.dropout(uis)).sum(dim=1, keepdim=True) + if self.use_bias: + preds += self.u_biases(uids) + self.i_biases(iids) + self.global_mean + return preds.squeeze() def __call__(self, *args): diff --git a/cornac/models/mf/recom_mf.py b/cornac/models/mf/recom_mf.py index e4b3f873e..127c042b4 100644 --- a/cornac/models/mf/recom_mf.py +++ b/cornac/models/mf/recom_mf.py @@ -226,12 +226,13 @@ def _fit_pt(self, train_set, val_set): if not hasattr(self, "model"): self.model = MF( - self.num_users, - self.num_items, - self.k, + self.u_factors, + self.i_factors, + self.u_biases.reshape(-1, 1), + self.i_biases.reshape(-1, 1), + self.use_bias, self.global_mean, self.droppout, - self.init_params, ) learn(