Skip to content

Commit

Permalink
Fix splitting for generative tasks (#400)
Browse files Browse the repository at this point in the history
Co-authored-by: Clémentine Fourrier <[email protected]>
  • Loading branch information
NathanHB and clefourrier authored Nov 25, 2024
1 parent 2c9bf97 commit ea46419
Showing 1 changed file with 10 additions and 4 deletions.
14 changes: 10 additions & 4 deletions src/lighteval/data.py
Original file line number Diff line number Diff line change
Expand Up @@ -252,11 +252,11 @@ def init_split_limits(self, num_dataset_splits):
)

if len(self.sorted_data) > 0:
all_sorting_criterion = [self._sorting_criteria(self.sorted_data[0])[:2]]
all_sorting_criterion = [self._sorting_criteria(self.sorted_data[0])[:-1]]
splits_indices = [[0, None]]
for ix, req in enumerate(self.sorted_data):
current_sorting_criteria = self._sorting_criteria(req)
current_key = current_sorting_criteria[:2]
current_key = current_sorting_criteria[:-1]
if current_key not in all_sorting_criterion:
all_sorting_criterion.append(current_key)
splits_indices[-1][1] = ix
Expand All @@ -269,7 +269,7 @@ def init_split_limits(self, num_dataset_splits):
splits_indices = [tuple(e) for e in splits_indices]
return num_dataset_splits, splits_indices

def _sorting_criteria(self, request: GreedyUntilRequest) -> tuple[bool, bool, list, int]:
def _sorting_criteria(self, request: GreedyUntilRequest) -> tuple[bool, bool, list, int, int]:
"""
Collate function for generating batches.
Expand All @@ -284,7 +284,13 @@ def _sorting_criteria(self, request: GreedyUntilRequest) -> tuple[bool, bool, li
# The generative task has no limit except the model context
if gen_length is None:
gen_length = 0
return request.do_sample, request.use_logits, request.stop_sequence, -(len(toks) + gen_length)
return (
request.do_sample,
request.use_logits,
tuple(request.stop_sequence),
gen_length,
-(len(toks) + gen_length),
)


class GenerativeTaskDatasetNanotron(GenerativeTaskDataset):
Expand Down

0 comments on commit ea46419

Please sign in to comment.