diff --git a/src/lighteval/data.py b/src/lighteval/data.py
index 0305b073..74dedf22 100644
--- a/src/lighteval/data.py
+++ b/src/lighteval/data.py
@@ -252,11 +252,11 @@ def init_split_limits(self, num_dataset_splits):
             )
 
         if len(self.sorted_data) > 0:
-            all_sorting_criterion = [self._sorting_criteria(self.sorted_data[0])[:2]]
+            all_sorting_criterion = [self._sorting_criteria(self.sorted_data[0])[:-1]]
         splits_indices = [[0, None]]
         for ix, req in enumerate(self.sorted_data):
             current_sorting_criteria = self._sorting_criteria(req)
-            current_key = current_sorting_criteria[:2]
+            current_key = current_sorting_criteria[:-1]
             if current_key not in all_sorting_criterion:
                 all_sorting_criterion.append(current_key)
                 splits_indices[-1][1] = ix
@@ -269,7 +269,7 @@ def init_split_limits(self, num_dataset_splits):
         splits_indices = [tuple(e) for e in splits_indices]
         return num_dataset_splits, splits_indices
 
-    def _sorting_criteria(self, request: GreedyUntilRequest) -> tuple[bool, bool, list, int]:
+    def _sorting_criteria(self, request: GreedyUntilRequest) -> tuple[bool, bool, list, int, int]:
         """
         Collate function for generating batches.
 
@@ -284,7 +284,13 @@ def _sorting_criteria(self, request: GreedyUntilRequest) -> tuple[bool, bool, li
         # The generative task has no limit except the model context
         if gen_length is None:
             gen_length = 0
-        return request.do_sample, request.use_logits, request.stop_sequence, -(len(toks) + gen_length)
+        return (
+            request.do_sample,
+            request.use_logits,
+            tuple(request.stop_sequence),
+            gen_length,
+            -(len(toks) + gen_length),
+        )
 
 
 class GenerativeTaskDatasetNanotron(GenerativeTaskDataset):
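
For context on why the key changes shape: `init_split_limits` now builds the split key from every element of the sorting criteria except the final negated-length term, so requests are only grouped into the same split when `do_sample`, `use_logits`, `stop_sequence`, and `gen_length` all match, while the last element still orders samples by length within a split. Casting `stop_sequence` to a `tuple` makes the key safely comparable and hashable regardless of the list the request carries. A minimal, self-contained sketch of the grouping behaviour (`Req` and `sorting_criteria` here are illustrative stand-ins, not the `GreedyUntilRequest` API):

```python
from __future__ import annotations

from dataclasses import dataclass


@dataclass
class Req:
    """Hypothetical stand-in for GreedyUntilRequest (illustration only)."""

    do_sample: bool
    use_logits: bool
    stop_sequence: list[str]
    gen_length: int | None
    num_toks: int  # stands in for len(toks) of the tokenized context


def sorting_criteria(req: Req) -> tuple[bool, bool, tuple[str, ...], int, int]:
    """Mirrors the new 5-tuple: the last element is the negated total
    length, used only for sorting within a split."""
    gen_length = req.gen_length if req.gen_length is not None else 0
    return (
        req.do_sample,
        req.use_logits,
        tuple(req.stop_sequence),  # tuple so the key compares/hashes reliably
        gen_length,
        -(req.num_toks + gen_length),
    )


reqs = [
    Req(False, False, ["\n"], 100, 512),
    Req(False, False, ["\n"], 100, 256),  # same params, different prompt length
    Req(False, False, ["\n"], 50, 512),   # different gen_length
]
keys = [sorting_criteria(r)[:-1] for r in reqs]
assert keys[0] == keys[1]  # same split: only the length sort term differs
assert keys[0] != keys[2]  # new split: gen_length is now part of the key
```

A side benefit of slicing with `[:-1]` rather than the previous hard-coded `[:2]` is that the split key automatically tracks any criteria added before the length term in future changes.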