From f5026e955ce4cebc8cfed570fefbac316a6da6d8 Mon Sep 17 00:00:00 2001 From: louishsu Date: Fri, 1 Oct 2021 23:34:17 +0800 Subject: [PATCH] 2021/10/1 pseudo label --- run_span.py | 33 +++++++++++++++++++-------------- scripts/run_span.sh | 13 ++++++------- 2 files changed, 25 insertions(+), 21 deletions(-) diff --git a/run_span.py b/run_span.py index c7b725d..229f84e 100644 --- a/run_span.py +++ b/run_span.py @@ -217,15 +217,20 @@ def forward(self, if args.do_pseudo: proba = logits.softmax(dim=-1) proba, index = proba.max(dim=-1) - condition = ( - label > args.pseudo_proba_ub - ) | ( - label < args.pseudo_proba_lb - ) & ( - label == PSEUDO_TAG - ) - label = torch.where(condition, index, label) - loss_mask = loss_mask & (label.view(-1) != PSEUDO_TAG) + + is_pseudo = label == PSEUDO_TAG + label = torch.where(is_pseudo, index, label) # 用预测标签替换无标签 + pseudo_valid_mask = is_pseudo & ( + proba > args.pseudo_proba_thresh + ) # 有效伪标签:是伪标签、且大于阈值 + # pseudo_valid_mask = is_pseudo & ( + # proba > args.pseudo_proba_thresh + # ) & ( + # index != 0 + # ) # 有效伪标签:是伪标签、且大于阈值、是实体 + loss_mask = (mask == 1) & (~is_pseudo) # 重新初始化loss_mask:真实标签 + loss_mask = loss_mask | pseudo_valid_mask # 合并`真实标签`和`有效伪标签` + loss_mask = loss_mask.view(-1) loss = self.loss_fct(logits.view(-1, num_labels), label.view(-1)) loss = loss[loss_mask].mean() @@ -474,8 +479,7 @@ def build_arguments(self): self.add_argument("--pseudo_data_dir", default=None, type=str) self.add_argument("--pseudo_data_file", default=None, type=str) self.add_argument("--pseudo_num_sample", default=None, type=int) - self.add_argument("--pseudo_proba_ub", default=0.99, type=float) - self.add_argument("--pseudo_proba_lb", default=0.01, type=float) + self.add_argument("--pseudo_proba_thresh", default=0.99, type=float) # Other parameters self.add_argument('--scheme', default='IOB2', type=str, @@ -606,6 +610,10 @@ def _create_examples(self, data_dir, data_file, mode): logger.info(f"Creating examples from {data_path}...") with open(data_path, encoding="utf-8") as f: lines = [json.loads(line) for line in f.readlines()] + # 无标签数据数量限制 + if mode == "pseudo" and args.pseudo_num_sample is not None: + random.shuffle(lines) + lines = lines[:args.pseudo_num_sample] logger.info(f"Totally {len(lines)} examples.") for sentence_counter, line in enumerate(lines): sentence = ( @@ -1417,9 +1425,6 @@ def load_dataset(args, processor, tokenizer, data_type='train'): elif data_type == 'pseudo': examples = processor.get_train_examples(args.data_dir, args.train_file) examples_pseudo = processor.get_pseudo_examples(args.pseudo_data_dir, args.pseudo_data_file) - if args.pseudo_num_sample is not None: - random.shuffle(examples_pseudo) - examples_pseudo = examples_pseudo[:args.pseudo_num_sample] examples.extend(examples_pseudo) if args.local_rank == 0 and not evaluate: torch.distributed.barrier() # Make sure only the first process in distributed training process the dataset, and the others will use the cache diff --git a/scripts/run_span.sh b/scripts/run_span.sh index 7720270..fbc1a12 100644 --- a/scripts/run_span.sh +++ b/scripts/run_span.sh @@ -582,7 +582,7 @@ python prepare_data.py \ for k in 0 1 2 3 4 do python run_span.py \ - --version=nezha-legal-fgm1.0-lsr0.1-pseudo_u0.99_l0.01-fold${k} \ + --version=nezha-legal-fgm1.0-lsr0.1-pseudo_t0.9-fold${k} \ --data_dir=./data/ner-ctx0-5fold-seed42/ \ --train_file=train.${k}.json \ --dev_file=dev.${k}.json \ @@ -602,18 +602,17 @@ python run_span.py \ --per_gpu_train_batch_size=8 \ --per_gpu_eval_batch_size=16 \ --gradient_accumulation_steps=2 \ - --learning_rate=1e-6 \ - --other_learning_rate=1e-6 \ - --num_train_epochs=2.0 \ + --learning_rate=1e-5 \ + --other_learning_rate=1e-5 \ + --num_train_epochs=1.0 \ --warmup_proportion=0.1 \ --do_fgm --fgm_epsilon=1.0 \ --loss_type=lsr --label_smooth_eps=0.1 \ --do_pseudo \ --pseudo_data_dir=../cail_processed_data/ner-ctx0-1fold-seed42/ \ --pseudo_data_file=train.json \ - --pseudo_num_sample=2000 \ - --pseudo_proba_ub=0.99 \ - --pseudo_proba_lb=0.01 \ + --pseudo_num_sample=1500 \ + --pseudo_proba_thresh=0.9 \ --seed=42 done