diff --git a/training/run_distillation.py b/training/run_distillation.py index 3e4ac3e..f2a49fd 100644 --- a/training/run_distillation.py +++ b/training/run_distillation.py @@ -1189,13 +1189,13 @@ def prepare_train_dataset(batch): # check that the length of the prompt does not exceed more than half the max label length (224) if len(prev_ids) > prompt_cutoff_length: prev_ids = prev_ids[-prompt_cutoff_length + 1 :] - prev_ids = [decoder_prev_token_id] + prev_ids # and that the total length of the labels does not exceed the max label length (448) - if len(prev_ids + token_ids) > max_label_length: - trim_length = len(prev_ids + token_ids) - max_label_length + 1 + if len(prev_ids + token_ids) + 1 > max_label_length: + trim_length = len(token_ids) - max_label_length + 1 prev_ids = prev_ids[trim_length:] - prev_ids = [decoder_prev_token_id] + prev_ids + + prev_ids = [decoder_prev_token_id] + prev_ids token_ids = prev_ids + token_ids