Skip to content

Commit

Permalink
Update model
Browse files Browse the repository at this point in the history
  • Loading branch information
tqtg committed Oct 26, 2023
1 parent 1638d05 commit d0d7c98
Showing 1 changed file with 10 additions and 17 deletions.
27 changes: 10 additions & 17 deletions cornac/models/wmf/recom_wmf.py
Original file line number Diff line number Diff line change
Expand Up @@ -119,12 +119,10 @@ def __init__(

def _init(self):
rng = get_rng(self.seed)
n_users, n_items = self.train_set.num_users, self.train_set.num_items

if self.U is None:
self.U = xavier_uniform((n_users, self.k), rng)
self.U = xavier_uniform((self.num_users, self.k), rng)
if self.V is None:
self.V = xavier_uniform((n_items, self.k), rng)
self.V = xavier_uniform((self.num_items, self.k), rng)

def fit(self, train_set, val_set=None):
"""Fit the model to observations.
Expand All @@ -146,28 +144,27 @@ def fit(self, train_set, val_set=None):
self._init()

if self.trainable:
self._fit_cf()
self._fit_cf(train_set)

return self

def _fit_cf(self,):
def _fit_cf(self, train_set):
import tensorflow.compat.v1 as tf
from .wmf import Model

np.random.seed(self.seed)
os.environ["TF_CPP_MIN_LOG_LEVEL"] = "3"
tf.compat.v1.logging.set_verbosity(tf.compat.v1.logging.ERROR)

R = self.train_set.csc_matrix # csc for efficient slicing over items
n_users, n_items, = self.train_set.num_users, self.train_set.num_items
R = train_set.csc_matrix # csc for efficient slicing over items

# Build model
graph = tf.Graph()
with graph.as_default():
tf.set_random_seed(self.seed)
model = Model(
n_users=n_users,
n_items=n_items,
n_users=self.num_users,
n_items=self.num_items,
k=self.k,
lambda_u=self.lambda_u,
lambda_v=self.lambda_v,
Expand All @@ -184,11 +181,10 @@ def _fit_cf(self,):

loop = trange(self.max_iter, disable=not self.verbose)
for _ in loop:

sum_loss = 0
count = 0
for i, batch_ids in enumerate(
self.train_set.item_iter(self.batch_size, shuffle=True)
train_set.item_iter(self.batch_size, shuffle=True)
):
batch_R = R[:, batch_ids]
batch_C = np.ones(batch_R.shape) * self.b
Expand Down Expand Up @@ -232,17 +228,14 @@ def score(self, user_idx, item_idx=None):
Relative scores that the user gives to the item or to all known items
"""
if item_idx is None:
if self.train_set.is_unk_user(user_idx):
if not self.knows_user(user_idx):
raise ScoreException(
"Can't make score prediction for (user_id=%d)" % user_idx
)

known_item_scores = self.V.dot(self.U[user_idx, :])
return known_item_scores
else:
if self.train_set.is_unk_user(user_idx) or self.train_set.is_unk_item(
item_idx
):
if not (self.knows_user(user_idx) and self.knows_item(item_idx)):
raise ScoreException(
"Can't make score prediction for (user_id=%d, item_id=%d)"
% (user_idx, item_idx)
Expand Down

0 comments on commit d0d7c98

Please sign in to comment.