From d55c1ef52410a568af75924e40fe0ddf59e96332 Mon Sep 17 00:00:00 2001
From: louishsu
Date: Fri, 24 Sep 2021 23:01:16 +0800
Subject: [PATCH] 2021/9/24 1. Add EMA; 2. Make evaluate() decode consistent
 with test; 3. nezha-fgm1.0-lsr0.1_f82.19_p75.58_r90.06; 4. Pretrained model
 still to be tested.
MIME-Version: 1.0
Content-Type: text/plain; charset=UTF-8
Content-Transfer-Encoding: 8bit

---
 main.py                |   8 ++-
 main_local.py          |   8 ++-
 run_span.py            | 106 ++++++++++++++++++++++++++++++----
 scripts/run_mlm_wwm.sh |   7 +--
 scripts/run_span.sh    | 127 ++++++++++++++++++++++++++++++-----------
 5 files changed, 204 insertions(+), 52 deletions(-)

diff --git a/main.py b/main.py
index d0cecf8..440071d 100644
--- a/main.py
+++ b/main.py
@@ -42,7 +42,13 @@ def main():
     # n_splits = 5
     # seed=42
     # --------------------------
-    version = "nezha-rdrop0.1-fgm1.0-aug_ctx0.15"
+    # version = "nezha-rdrop0.1-fgm1.0-aug_ctx0.15"
+    # model_type = "nezha_span"
+    # dataset_name = "cail_ner"
+    # n_splits = 5
+    # seed=42
+    # --------------------------
+    version = "nezha-fgm1.0-lsr0.1"
     model_type = "nezha_span"
     dataset_name = "cail_ner"
     n_splits = 5
diff --git a/main_local.py b/main_local.py
index 2d96f55..01bd195 100644
--- a/main_local.py
+++ b/main_local.py
@@ -48,7 +48,13 @@ def main():
     # n_splits = 5
     # seed=42
     # --------------------------
-    version = "nezha-rdrop0.1-fgm1.0-aug_ctx0.15"
+    # version = "nezha-rdrop0.1-fgm1.0-aug_ctx0.15"
+    # model_type = "nezha_span"
+    # dataset_name = "cail_ner"
+    # n_splits = 5
+    # seed=42
+    # --------------------------
+    version = "nezha-fgm1.0-lsr0.1"
     model_type = "nezha_span"
     dataset_name = "cail_ner"
     n_splits = 5
diff --git a/run_span.py b/run_span.py
index ac5f61c..6c17ba6 100644
--- a/run_span.py
+++ b/run_span.py
@@ -354,6 +354,55 @@ def forward(self, **kwargs):
             return forward_rdrop(self, args.rdrop_alpha, **kwargs)
         return forward(self, **kwargs)
 
+
+class ExponentialMovingAverage(object):
+    '''
+    Exponential moving average of the model weights, giving more weight to recent updates.
+    usage:
+        # initialization
+        ema = ExponentialMovingAverage(model, 0.999, device)
+        # during training, after the parameters are updated, sync the shadow weights
+        def train():
+            optimizer.step()
+            ema.update(model)
+        # before eval, apply the shadow weights;
+        # after eval (once the model has been saved), restore the original parameters
+        def evaluate():
+            ema.apply_shadow(model)
+            # evaluate
+            ema.restore(model)
+    '''
+    def __init__(self, model, decay, device):
+        self.decay = decay
+        self.device = device
+        self.shadow = {}
+        self.backup = {}
+        for name, param in model.named_parameters():
+            if param.requires_grad:
+                self.shadow[name] = param.data.clone().cpu()
+
+    def update(self, model):
+        for name, param in model.named_parameters():
+            if param.requires_grad:
+                assert name in self.shadow
+                new_average = (1.0 - self.decay) * param.data.cpu() + self.decay * self.shadow[name]
+                self.shadow[name] = new_average.clone()
+
+    def apply_shadow(self, model):
+        for name, param in model.named_parameters():
+            if param.requires_grad:
+                assert name in self.shadow
+                self.backup[name] = param.data.cpu()
+                param.data = self.shadow[name].to(self.device)
+
+    def restore(self, model):
+        for name, param in model.named_parameters():
+            if param.requires_grad:
+                assert name in self.backup
+                param.data = self.backup[name].to(self.device)
+        self.backup = {}
+
+
 class NerArgumentParser(ArgumentParser):
 
     def __init__(self, **kwargs):
@@ -399,6 +448,7 @@ def build_arguments(self):
         self.add_argument("--augment_entity_replace_p", default=None, type=float)
         self.add_argument("--rdrop_alpha", default=None, type=float)
         self.add_argument("--vat_alpha", default=None, type=float)
+        self.add_argument("--do_ema", action="store_true")
 
         # Other parameters
         self.add_argument('--scheme', default='IOB2', type=str,
@@ -855,6 +905,8 @@ def train(args, model, processor, tokenizer):
                                                           find_unused_parameters=True)
     if args.do_fgm:
         fgm = FGM(model, emb_name=args.fgm_name, epsilon=args.fgm_epsilon)
+    if args.do_ema:
+        ema = ExponentialMovingAverage(model, decay=0.999, device=args.device)
     # Train!
     logger.info("***** Running training *****")
     logger.info("  Num examples = %d", len(train_dataset))
@@ -937,6 +989,8 @@ def train(args, model, processor, tokenizer):
                 optimizer.step()
                 scheduler.step()  # Update learning rate schedule
                 model.zero_grad()
+                if args.do_ema:
+                    ema.update(model)
                 global_step += 1
                 if args.local_rank in [-1, 0] and args.evaluate_during_training and \
                         args.logging_steps > 0 and global_step % args.logging_steps == 0:
@@ -967,6 +1021,23 @@ def train(args, model, processor, tokenizer):
                         torch.save(optimizer.state_dict(), os.path.join(output_dir, "optimizer.pt"))
                         torch.save(scheduler.state_dict(), os.path.join(output_dir, "scheduler.pt"))
                         logger.info("Saving optimizer and scheduler states to %s", output_dir)
+
+                    if args.do_ema:
+                        logger.info("{:*^50s}".format("EMA"))
+                        ema.apply_shadow(model)
+                        eval_results_ema = evaluate(args, model, processor, tokenizer)
+                        logger.info(f"[{epoch_no}] loss={eval_results_ema.pop('loss')}")
+                        for entity, metrics in eval_results_ema.items():
+                            logger.info("{:*^50s}".format(entity))
+                            logger.info("\t".join(f"{metric:s}={value:f}"
+                                for metric, value in metrics.items()))
+                        if eval_results_ema["avg"]["f"] > best_f1:
+                            model_to_save = (
+                                model.module if hasattr(model, "module") else model
+                            )  # Take care of distributed/parallel training
+                            model_to_save.save_pretrained(output_dir)
+                        ema.restore(model)
+
                 elif args.local_rank in [-1, 0] and \
                         args.save_steps > 0 and global_step % args.save_steps == 0:
                     # Save model checkpoint
@@ -1029,17 +1100,29 @@ def evaluate(args, model, processor, tokenizer, prefix=""):
         nb_eval_steps += 1
 
         # calculate metrics
-        preds = SpanV2.decode_batch(logits, batch["spans"], batch["span_mask"])
-        for pred_no, (pred, input_len, start, end) in enumerate(zip(
-                preds, batch["input_len"], batch["sent_start"], batch["sent_end"])):
-            pred = [(LABEL_MEANING_MAP[id2label[t]], b, e) for t, b, e in pred if id2label[t] != "O"]
-            pred = [(t, b - start, e - start) for t, b, e in pred]
-            sample = eval_dataset.examples[args.eval_batch_size * step + pred_no][1]
-            label_entities_map = {v: [] for v in LABEL_MEANING_MAP.values()}
-            for t, b, e in pred:
-                label_entities_map[t].append(f"{b};{e+1}")
-            entities = [{"label": k, "span": v} for k, v in label_entities_map.items()]
-            y_pred.append({"id": sample["id"].split("-")[-1], "entities": entities})
+        # preds = SpanV2.decode_batch(logits, batch["spans"], batch["span_mask"])
+        # for pred_no, (pred, input_len, start, end) in enumerate(zip(
+        #         preds, batch["input_len"], batch["sent_start"], batch["sent_end"])):
+        #     pred = [(LABEL_MEANING_MAP[id2label[t]], b, e) for t, b, e in pred if id2label[t] != "O"]
+        #     pred = [(t, b - start, e - start) for t, b, e in pred]
+        #     sample = eval_dataset.examples[args.eval_batch_size * step + pred_no][1]
+        #     label_entities_map = {v: [] for v in LABEL_MEANING_MAP.values()}
+        #     for t, b, e in pred:
+        #         label_entities_map[t].append(f"{b};{e+1}")
+        #     entities = [{"label": k, "span": v} for k, v in label_entities_map.items()]
+        #     y_pred.append({"id": sample["id"].split("-")[-1], "entities": entities})
+
+        # decode in evaluate() the same way as in test
+        for pred_no in range(logits.size(0)):
+            example_ = eval_dataset.examples[args.eval_batch_size * step + pred_no][1]
+            batch_ = {k: v[[pred_no]] for k, v in batch.items()}
+            batch_["logits"] = logits[[pred_no]]
+            y_pred_ = predict_decode_batch(example_, batch_, id2label, post_process=True)
+            for entity_no, entity in enumerate(y_pred_["entities"]):
+                y_pred_["entities"][entity_no] = {
+                    "label": LABEL_MEANING_MAP[entity["label"]],
+                    "span": entity["span"]}
+            y_pred.append(y_pred_)
 
         labels = SpanV2.decode_batch(batch["label"], batch["spans"], batch["span_mask"], is_logits=False)
         for label_no, (label, input_len, start, end) in enumerate(zip(
@@ -1389,6 +1472,7 @@ def load_dataset(args, processor, tokenizer, data_type='train'):
     # save the early-stopping model to the output directory
     if args.save_best_checkpoints:
         best_checkpoints = os.path.join(args.output_dir, "checkpoint-999999")
+        logger.info("Loading model checkpoint from %s", best_checkpoints)
         config = config_class.from_pretrained(best_checkpoints, num_labels=num_labels,
                                               max_span_length=args.max_span_length,
                                               cache_dir=args.cache_dir if args.cache_dir else None, )
diff --git a/scripts/run_mlm_wwm.sh b/scripts/run_mlm_wwm.sh
index 548133d..ed4d088 100644
--- a/scripts/run_mlm_wwm.sh
+++ b/scripts/run_mlm_wwm.sh
@@ -91,13 +91,10 @@ nohup python run_mlm_wwm.py \
     --mlm_probability=0.15 \
     --output_dir=output/nezha-legal-cn-base-wwm/ \
     --overwrite_output_dir \
-    --do_train --do_eval \
-    --warmup_steps=1000 \
+    --do_train \
+    --warmup_steps=1500 \
     --max_steps=30000 \
-    --evaluation_strategy=steps \
-    --eval_steps=1500 \
     --per_device_train_batch_size=48 \
-    --per_device_eval_batch_size=48 \
     --gradient_accumulation_steps=4 \
     --label_smoothing_factor=0.0 \
     --learning_rate=5e-5 \
diff --git a/scripts/run_span.sh b/scripts/run_span.sh
index 3226fd0..99689a1 100644
--- a/scripts/run_span.sh
+++ b/scripts/run_span.sh
@@ -363,7 +363,7 @@ done
 # 组织机构
 # {'p': 0.8795336787564767, 'r': 0.8424317617866005, 'f': 0.8605830164765527}
 
-# TODO: context-aware augmentation, only for the weaker classes 受害人 / 犯罪嫌疑人
+# context-aware augmentation, only for the weaker classes 受害人 / 犯罪嫌疑人
 for k in 0 1 2 3 4
 do
 python run_span.py \
@@ -419,48 +419,107 @@ done
 # 组织机构
 # {'p': 0.8423586040914561, 'r': 0.8684863523573201, 'f': 0.855222968845449}
 
+# LSR
+for k in 0 1 2 3 4
+do
+python run_span.py \
+    --version=nezha-fgm1.0-lsr0.1-fold${k} \
+    --data_dir=./data/ner-ctx0-5fold-seed42/ \
+    --train_file=train.${k}.json \
+    --dev_file=dev.${k}.json \
+    --test_file=dev.${k}.json \
+    --model_type=nezha_span \
+    --model_name_or_path=/home/louishsu/NewDisk/Garage/weights/transformers/nezha-cn-base/ \
+    --do_train \
+    --overwrite_output_dir \
+    --evaluate_during_training \
+    --evaluate_each_epoch \
+    --save_best_checkpoints \
+    --max_span_length=40 \
+    --width_embedding_dim=128 \
+    --train_max_seq_length=512 \
+    --eval_max_seq_length=512 \
+    --do_lower_case \
+    --per_gpu_train_batch_size=8 \
+    --per_gpu_eval_batch_size=16 \
+    --gradient_accumulation_steps=2 \
+    --learning_rate=5e-5 \
+    --other_learning_rate=1e-3 \
+    --num_train_epochs=8.0 \
+    --warmup_proportion=0.1 \
+    --do_fgm --fgm_epsilon=1.0 \
+    --loss_type=lsr --label_smooth_eps=0.1 \
+    --seed=42
+done
+# main_local
+# avg
+# {'p': 0.8992657967816792, 'r': 0.8866509133190803, 'f': 0.8929138022210471}
+# 犯罪嫌疑人
+# {'p': 0.9609907120743034, 'r': 0.9605446387126721, 'f': 0.9607676236168073}
+# 受害人
+# {'p': 0.9242566510172144, 'r': 0.9501287001287001, 'f': 0.9370141202601936}
+# 被盗货币
+# {'p': 0.8159574468085107, 'r': 0.8382513661202186, 'f': 0.8269541778975741}
+# 物品价值
+# {'p': 0.9697696737044146, 'r': 0.9669856459330144, 'f': 0.9683756588404409}
+# 盗窃获利
+# {'p': 0.8627450980392157, 'r': 0.9147609147609148, 'f': 0.8879919273461151}
+# 被盗物品
+# {'p': 0.8217407137654771, 'r': 0.7806607853312576, 'f': 0.8006741772376476}
+# 作案工具
+# {'p': 0.801994301994302, 'r': 0.7659863945578231, 'f': 0.7835768963117605}
+# 时间
+# {'p': 0.9488699518340126, 'r': 0.9262206148282097, 'f': 0.9374084919472914}
+# 地点
+# {'p': 0.861102919492775, 'r': 0.8302530565823145, 'f': 0.8453966415749856}
+# 组织机构
+# {'p': 0.8513513513513513, 'r': 0.8598014888337469, 'f': 0.8555555555555555}
+
+# TODO: all data
+
+# TODO: EMA
+for k in 0 1 2 3 4
+do
+python run_span.py \
+    --version=nezha-fgm1.0-ema0.999-fold${k} \
+    --data_dir=./data/ner-ctx0-5fold-seed42/ \
+    --train_file=train.${k}.json \
+    --dev_file=dev.${k}.json \
+    --test_file=dev.${k}.json \
+    --model_type=nezha_span \
+    --model_name_or_path=/home/louishsu/NewDisk/Garage/weights/transformers/nezha-cn-base/ \
+    --do_train \
+    --overwrite_output_dir \
+    --evaluate_during_training \
+    --evaluate_each_epoch \
+    --save_best_checkpoints \
+    --max_span_length=40 \
+    --width_embedding_dim=128 \
+    --train_max_seq_length=512 \
+    --eval_max_seq_length=512 \
+    --do_lower_case \
+    --per_gpu_train_batch_size=8 \
+    --per_gpu_eval_batch_size=16 \
+    --gradient_accumulation_steps=2 \
+    --learning_rate=5e-5 \
+    --other_learning_rate=1e-3 \
+    --num_train_epochs=4.0 \
+    --warmup_proportion=0.1 \
+    --do_fgm --fgm_epsilon=1.0 \
+    --do_ema \
+    --seed=42
+done
+
 # TODO: albert
 # TODO: distill
 # TODO: further pretrain, add data from the other tracks
 # TODO: pseudo labels, use the other tracks' data as unlabeled data
 # TODO: data from previous years
+# TODO: TTA
 
 # <<< Stage 2 <<<
 # ==================================================================================================================
 
-# TODO: label smooth 0.1
-# for k in 0 1 2 3 4
-# do
-# python run_span.py \
-#     --version=nezha-rdrop0.1-fgm1.0-lsr0.1-fold${k} \
-#     --data_dir=./data/ner-ctx0-5fold-seed42/ \
-#     --train_file=train.${k}.json \
-#     --dev_file=dev.${k}.json \
-#     --test_file=dev.${k}.json \
-#     --model_type=nezha_span \
-#     --model_name_or_path=/home/louishsu/NewDisk/Garage/weights/transformers/nezha-cn-base/ \
-#     --do_train \
-#     --overwrite_output_dir \
-#     --evaluate_during_training \
-#     --evaluate_each_epoch \
-#     --save_best_checkpoints \
-#     --max_span_length=40 \
-#     --width_embedding_dim=128 \
-#     --train_max_seq_length=512 \
-#     --eval_max_seq_length=512 \
-#     --do_lower_case \
-#     --per_gpu_train_batch_size=8 \
-#     --per_gpu_eval_batch_size=16 \
-#     --gradient_accumulation_steps=2 \
-#     --learning_rate=5e-5 \
-#     --other_learning_rate=1e-3 \
-#     --num_train_epochs=4.0 \
-#     --warmup_proportion=0.1 \
-#     --rdrop_alpha=0.1 \
-#     --do_fgm --fgm_epsilon=1.0 \
-#     --loss_type=lsr --label_smooth_eps=0.1 \
-#     --seed=42
-# done
 
 # for k in 0 1 2 3 4
 # do