Skip to content

Commit

Permalink
Let models manage total_users and total_items instead of train_set, t…
Browse files Browse the repository at this point in the history
…rain_set shouldn't worry about users and items in validation and test
  • Loading branch information
tqtg committed Oct 26, 2023
1 parent a429d61 commit 8a4f2f8
Show file tree
Hide file tree
Showing 15 changed files with 262 additions and 311 deletions.
3 changes: 3 additions & 0 deletions .gitignore
Original file line number Diff line number Diff line change
Expand Up @@ -8,6 +8,9 @@ __pycache__/

# C extensions
*.so
cornac/models/*/*.cpp
cornac/models/*/cython/*.cpp
cornac/utils/*.cpp

# Distribution / packaging
bin/
Expand Down
24 changes: 0 additions & 24 deletions cornac/data/dataset.py
Original file line number Diff line number Diff line change
Expand Up @@ -98,8 +98,6 @@ def __init__(
self.min_rating = np.min(r_values)
self.global_mean = np.mean(r_values)

self.__total_users = None
self.__total_items = None
self.__user_ids = None
self.__item_ids = None

Expand All @@ -111,28 +109,6 @@ def __init__(
self.__csc_matrix = None
self.__dok_matrix = None

@property
def total_users(self):
"""Total number of users including test and validation users if exists"""
return self.__total_users if self.__total_users is not None else self.num_users

@total_users.setter
def total_users(self, input_value):
"""Set total number of users for the dataset"""
assert input_value >= self.num_users
self.__total_users = input_value

@property
def total_items(self):
"""Total number of items including test and validation items if exists"""
return self.__total_items if self.__total_items is not None else self.num_items

@total_items.setter
def total_items(self, input_value):
"""Set total number of items for the dataset"""
assert input_value >= self.num_items
self.__total_items = input_value

@property
def user_ids(self):
"""Return the list of raw user ids"""
Expand Down
3 changes: 0 additions & 3 deletions cornac/eval_methods/base_method.py
Original file line number Diff line number Diff line change
Expand Up @@ -540,9 +540,6 @@ def _build_datasets(self, train_data, test_data, val_data=None):
print("Total users = {}".format(self.total_users))
print("Total items = {}".format(self.total_items))

self.train_set.total_users = self.total_users
self.train_set.total_items = self.total_items

def _build_modalities(self):
for user_modality in [
self.user_feature,
Expand Down
4 changes: 2 additions & 2 deletions cornac/models/causalrec/recom_causalrec.py
Original file line number Diff line number Diff line change
Expand Up @@ -185,8 +185,8 @@ def fit(self, train_set, val_set=None):
train_features = train_set.item_image.features[: self.total_items]
train_features = train_features.astype(np.float32)
self._init(
n_users=train_set.total_users,
n_items=train_set.total_items,
n_users=self.total_users,
n_items=self.total_items,
features=train_features,
)

Expand Down
2 changes: 1 addition & 1 deletion cornac/models/fm/recom_fm.pyx
Original file line number Diff line number Diff line change
Expand Up @@ -265,7 +265,7 @@ class FM(Recommender):
(uid, iid, val) = train_set.uir_tuple
cdef Data *train = _prepare_data(
uid,
iid + train_set.total_users,
iid + self.total_users,
val.astype(np.float32),
num_feature,
self.method in ["als", "mcmc"],
Expand Down
Loading

0 comments on commit 8a4f2f8

Please sign in to comment.