Skip to content

Commit

Permalink
refactor
Browse files Browse the repository at this point in the history
  • Loading branch information
tqtg committed Nov 21, 2023
1 parent 091667b commit 9bfef69
Show file tree
Hide file tree
Showing 2 changed files with 28 additions and 36 deletions.
55 changes: 23 additions & 32 deletions cornac/models/mf/backend_pt.py
Original file line number Diff line number Diff line change
Expand Up @@ -29,48 +29,39 @@
class MF(nn.Module):
def __init__(
self,
n_users,
n_items,
n_factors=10,
global_mean=0,
dropout=0,
init_params={},
u_factors,
i_factors,
u_biases,
i_biases,
use_bias,
global_mean,
dropout,
):
super(MF, self).__init__()
self.n_users = n_users
self.n_items = n_items
self.u_factors = nn.Embedding(n_users, n_factors)
self.i_factors = nn.Embedding(n_items, n_factors)
self.u_biases = nn.Embedding(n_users, 1)
self.i_biases = nn.Embedding(n_items, 1)

self.use_bias = use_bias
self.global_mean = global_mean
self.dropout = nn.Dropout(p=dropout)

self._init_params(init_params)

def _init_params(self, init_params={}):
if not init_params:
nn.init.normal_(self.u_factors.weight, std=0.01)
nn.init.normal_(self.i_factors.weight, std=0.01)
self.u_biases.weight.data.fill_(0.0)
self.i_biases.weight.data.fill_(0.0)
return

if "U" in init_params:
self.u_factors.weight.data = torch.from_numpy(init_params["U"])
if "V" in init_params:
self.i_factors.weight.data = torch.from_numpy(init_params["V"])
if "Bu" in init_params:
self.u_biases.weight.data = torch.from_numpy(init_params["Bu"])
if "Bi" in init_params:
self.i_biases.weight.data = torch.from_numpy(init_params["Bi"])
self.u_factors = nn.Embedding(*u_factors.shape)
self.i_factors = nn.Embedding(*i_factors.shape)
self.u_biases = nn.Embedding(*u_biases.shape)
self.i_biases = nn.Embedding(*i_biases.shape)

# init params
self.u_factors.weight.data = torch.from_numpy(u_factors)
self.i_factors.weight.data = torch.from_numpy(i_factors)
self.u_biases.weight.data = torch.from_numpy(u_biases)
self.i_biases.weight.data = torch.from_numpy(i_biases)

def forward(self, uids, iids):
ues = self.u_factors(uids)
uis = self.i_factors(iids)

preds = self.u_biases(uids) + self.i_biases(iids) + self.global_mean
preds += (self.dropout(ues) * self.dropout(uis)).sum(dim=1, keepdim=True)
preds = (self.dropout(ues) * self.dropout(uis)).sum(dim=1, keepdim=True)
if self.use_bias:
preds += self.u_biases(uids) + self.i_biases(iids) + self.global_mean

return preds.squeeze()

def __call__(self, *args):
Expand Down
9 changes: 5 additions & 4 deletions cornac/models/mf/recom_mf.py
Original file line number Diff line number Diff line change
Expand Up @@ -226,12 +226,13 @@ def _fit_pt(self, train_set, val_set):

if not hasattr(self, "model"):
self.model = MF(
self.num_users,
self.num_items,
self.k,
self.u_factors,
self.i_factors,
self.u_biases.reshape(-1, 1),
self.i_biases.reshape(-1, 1),
self.use_bias,
self.global_mean,
self.droppout,
self.init_params,
)

learn(
Expand Down

0 comments on commit 9bfef69

Please sign in to comment.