From e749023b14d36cdff234398de57d9b6c519e776a Mon Sep 17 00:00:00 2001 From: hongjin-su <114016954+hongjin-su@users.noreply.github.com> Date: Sat, 30 Dec 2023 15:12:30 +0800 Subject: [PATCH] Update train.py --- train.py | 22 +++++++++++----------- 1 file changed, 11 insertions(+), 11 deletions(-) diff --git a/train.py b/train.py index 30014bb..05f9eb3 100644 --- a/train.py +++ b/train.py @@ -92,9 +92,9 @@ def _get_train_sampler(self) : ) def compute_loss(self, model, inputs, return_outputs=False): - for task_id in inputs['task_name']: - assert task_id==inputs['task_name'][0],f"Examples in the same batch should come from the same task, " \ - f"but task {task_id} and task {inputs['task_name'][0]} are found" + for task_id in inputs['task_id']: + assert task_id==inputs['task_id'][0],f"Examples in the same batch should come from the same task, " \ + f"but task {task_id} and task {inputs['task_id'][0]} are found" cur_results = {} for k in ['query', 'pos', 'neg']: cur_inputs = { @@ -447,12 +447,12 @@ def main(): def get_examples_raw(old_examples_raw, total_n, real_batch_size): examples_raw = [] for idx in range(0, total_n, real_batch_size): - local_task_name = old_examples_raw[idx]['task_name'] + local_task_name = old_examples_raw[idx]['task_id'] cur_batch = [] include_batch = True for idx1 in range(idx, min(idx + real_batch_size, total_n)): - if not old_examples_raw[idx1]['task_name'] == local_task_name: - print(f'one batch in task {old_examples_raw[idx1]["task_name"]} is skipped') + if not old_examples_raw[idx1]['task_id'] == local_task_name: + print(f'one batch in task {old_examples_raw[idx1]["task_id"]} is skipped') include_batch = False break else: @@ -478,7 +478,7 @@ def get_examples_raw(old_examples_raw, total_n, real_batch_size): train_examples_raw = train_examples_raw[:int(data_args.debug_mode)] def get_dataset(examples_raw): - examples = {'query':[],'pos':[],'neg':[],'task_name':[]} + examples = {'query':[],'pos':[],'neg':[],'task_id':[]} task_name_map = {} total_num = len(examples_raw) task_count = 0 @@ -492,10 +492,10 @@ def get_dataset(examples_raw): cur_e[k][0] = '' assert cur_e[k][0].startswith('Represent ') or cur_e[k][0]=='' examples[k].append('!@#$%^&**!@#$%^&**'.join(cur_e[k])) - if not cur_e['task_name'] in task_name_map: - task_name_map[cur_e['task_name']] = task_count + if not cur_e['task_id'] in task_name_map: + task_name_map[cur_e['task_id']] = task_count task_count += 1 - examples['task_name'].append(task_name_map[cur_e['task_name']]) + examples['task_id'].append(task_name_map[cur_e['task_id']]) return examples train_raw_datasets = DatasetDict({'train':Dataset.from_dict(get_dataset(train_examples_raw))}) @@ -530,7 +530,7 @@ def preprocess_function(examples): all_tokenized[k] = all_tokenized[k].tolist() for k in keys: all_tokenized[f'{key}_{k}'] = tokenized[k].tolist() - all_tokenized['task_name'] = examples['task_name'] + all_tokenized['task_id'] = examples['task_id'] return all_tokenized train_dataset = train_raw_datasets["train"]