Skip to content

Commit

Permalink
consider using num_user_batches() instead of overwriting existing num…
Browse files Browse the repository at this point in the history
…_batches()
  • Loading branch information
tqtg committed Dec 7, 2023
1 parent 51bfb0c commit d173455
Showing 1 changed file with 8 additions and 4 deletions.
12 changes: 8 additions & 4 deletions cornac/data/dataset.py
Original file line number Diff line number Diff line change
Expand Up @@ -415,6 +415,14 @@ def num_batches(self, batch_size):
"""Estimate number of batches per epoch"""
return estimate_batches(len(self.uir_tuple[0]), batch_size)

def num_user_batches(self, batch_size):
"""Estimate number of batches per epoch iterating over users"""
return estimate_batches(self.num_users, batch_size)

def num_item_batches(self, batch_size):
"""Estimate number of batches per epoch iterating over items"""
return estimate_batches(self.num_items, batch_size)

def idx_iter(self, idx_range, batch_size=1, shuffle=False):
"""Create an iterator over batch of indices
Expand Down Expand Up @@ -914,10 +922,6 @@ def from_ubitjson(cls, data, seed=None):
"""
return cls.build(data, fmt="UBITJson", seed=seed)

def num_batches(self, batch_size):
"""Estimate number of batches per epoch"""
return estimate_batches(len(self.user_data), batch_size)

def ub_iter(self, batch_size=1, shuffle=False):
"""Create an iterator over data yielding batch of users and batch of baskets
Expand Down

0 comments on commit d173455

Please sign in to comment.