Commit d55c1ef
2021/9/24 1. Added EMA; 2. evaluate decoding made consistent with test-time decoding; 3. nezha-fgm1.0-lsr0.1_f82.19_…p75.58_r90.06; 4. pretrained model still to be tested.
louishsu committed Sep 24, 2021
1 parent cb68095 commit d55c1ef
Showing 5 changed files with 204 additions and 52 deletions.
8 changes: 7 additions & 1 deletion main.py
@@ -42,7 +42,13 @@ def main():
# n_splits = 5
# seed=42
# --------------------------
version = "nezha-rdrop0.1-fgm1.0-aug_ctx0.15"
# version = "nezha-rdrop0.1-fgm1.0-aug_ctx0.15"
# model_type = "nezha_span"
# dataset_name = "cail_ner"
# n_splits = 5
# seed=42
# --------------------------
version = "nezha-fgm1.0-lsr0.1"
model_type = "nezha_span"
dataset_name = "cail_ner"
n_splits = 5
8 changes: 7 additions & 1 deletion main_local.py
@@ -48,7 +48,13 @@ def main():
# n_splits = 5
# seed=42
# --------------------------
version = "nezha-rdrop0.1-fgm1.0-aug_ctx0.15"
# version = "nezha-rdrop0.1-fgm1.0-aug_ctx0.15"
# model_type = "nezha_span"
# dataset_name = "cail_ner"
# n_splits = 5
# seed=42
# --------------------------
version = "nezha-fgm1.0-lsr0.1"
model_type = "nezha_span"
dataset_name = "cail_ner"
n_splits = 5
106 changes: 95 additions & 11 deletions run_span.py
@@ -354,6 +354,55 @@ def forward(self, **kwargs):
return forward_rdrop(self, args.rdrop_alpha, **kwargs)
return forward(self, **kwargs)


class ExponentialMovingAverage(object):
    '''
    Exponential moving average (EMA) of the model weights, giving more weight to recent updates.
    Usage:
        # initialization
        ema = ExponentialMovingAverage(model, decay=0.999, device=device)
        # during training, update the shadow weights right after the optimizer step
        def train():
            optimizer.step()
            ema.update(model)
        # before eval, apply the shadow weights;
        # after eval (and after saving the model), restore the original parameters
        def evaluate():
            ema.apply_shadow(model)
            # evaluate
            ema.restore(model)
    '''
    def __init__(self, model, decay, device):
        self.decay = decay
        self.device = device
        self.shadow = {}
        self.backup = {}
        for name, param in model.named_parameters():
            if param.requires_grad:
                self.shadow[name] = param.data.clone().cpu()

    def update(self, model):
        for name, param in model.named_parameters():
            if param.requires_grad:
                assert name in self.shadow
                new_average = (1.0 - self.decay) * param.data.cpu() + self.decay * self.shadow[name]
                self.shadow[name] = new_average.clone()

    def apply_shadow(self, model):
        for name, param in model.named_parameters():
            if param.requires_grad:
                assert name in self.shadow
                self.backup[name] = param.data.cpu()
                param.data = self.shadow[name].to(self.device)

    def restore(self, model):
        for name, param in model.named_parameters():
            if param.requires_grad:
                assert name in self.backup
                param.data = self.backup[name].to(self.device)
        self.backup = {}
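
# The class above implements the standard EMA recurrence
#     shadow = decay * shadow + (1 - decay) * param
# applied after every optimizer step. A minimal, self-contained sketch of how it is meant
# to be driven (the toy model, data, and hyperparameters below are illustrative only):
#
#     import torch
#     model = torch.nn.Linear(4, 2)
#     optimizer = torch.optim.SGD(model.parameters(), lr=0.1)
#     ema = ExponentialMovingAverage(model, decay=0.999, device="cpu")
#     for step in range(100):
#         loss = model(torch.randn(8, 4)).pow(2).mean()
#         loss.backward()
#         optimizer.step()
#         optimizer.zero_grad()
#         ema.update(model)        # keep the shadow weights in sync after each step
#     ema.apply_shadow(model)      # evaluate / save with the averaged weights
#     ema.restore(model)           # then switch back to the raw weights for training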


class NerArgumentParser(ArgumentParser):

def __init__(self, **kwargs):
@@ -399,6 +448,7 @@ def build_arguments(self):
self.add_argument("--augment_entity_replace_p", default=None, type=float)
self.add_argument("--rdrop_alpha", default=None, type=float)
self.add_argument("--vat_alpha", default=None, type=float)
self.add_argument("--do_ema", action="store_true")

# Other parameters
self.add_argument('--scheme', default='IOB2', type=str,
@@ -855,6 +905,8 @@ def train(args, model, processor, tokenizer):
find_unused_parameters=True)
if args.do_fgm:
fgm = FGM(model, emb_name=args.fgm_name, epsilon=args.fgm_epsilon)
if args.do_ema:
ema = ExponentialMovingAverage(model, decay=0.999, device=args.device)
# Train!
logger.info("***** Running training *****")
logger.info(" Num examples = %d", len(train_dataset))
@@ -937,6 +989,8 @@ def train(args, model, processor, tokenizer):
optimizer.step()
scheduler.step() # Update learning rate schedule
model.zero_grad()
if args.do_ema:
ema.update(model)
global_step += 1
if args.local_rank in [-1, 0] and args.evaluate_during_training and \
args.logging_steps > 0 and global_step % args.logging_steps == 0:
@@ -967,6 +1021,23 @@ def train(args, model, processor, tokenizer):
torch.save(optimizer.state_dict(), os.path.join(output_dir, "optimizer.pt"))
torch.save(scheduler.state_dict(), os.path.join(output_dir, "scheduler.pt"))
logger.info("Saving optimizer and scheduler states to %s", output_dir)

if args.do_ema:
logger.info("{:*^50s}".format("EMA"))
ema.apply_shadow(model)
eval_results_ema = evaluate(args, model, processor, tokenizer)
logger.info(f"[{epoch_no}] loss={eval_results_ema.pop('loss')}")
for entity, metrics in eval_results_ema.items():
logger.info("{:*^50s}".format(entity))
logger.info("\t".join(f"{metric:s}={value:f}"
for metric, value in metrics.items()))
if eval_results_ema["avg"]["f"] > best_f1:
model_to_save = (
model.module if hasattr(model, "module") else model
) # Take care of distributed/parallel training
model_to_save.save_pretrained(output_dir)
ema.restore(model)

elif args.local_rank in [-1, 0] and \
args.save_steps > 0 and global_step % args.save_steps == 0:
# Save model checkpoint
@@ -1029,17 +1100,29 @@ def evaluate(args, model, processor, tokenizer, prefix=""):
nb_eval_steps += 1

# calculate metrics
preds = SpanV2.decode_batch(logits, batch["spans"], batch["span_mask"])
for pred_no, (pred, input_len, start, end) in enumerate(zip(
preds, batch["input_len"], batch["sent_start"], batch["sent_end"])):
pred = [(LABEL_MEANING_MAP[id2label[t]], b, e) for t, b, e in pred if id2label[t] != "O"]
pred = [(t, b - start, e - start) for t, b, e in pred]
sample = eval_dataset.examples[args.eval_batch_size * step + pred_no][1]
label_entities_map = {v: [] for v in LABEL_MEANING_MAP.values()}
for t, b, e in pred:
label_entities_map[t].append(f"{b};{e+1}")
entities = [{"label": k, "span": v} for k, v in label_entities_map.items()]
y_pred.append({"id": sample["id"].split("-")[-1], "entities": entities})
# preds = SpanV2.decode_batch(logits, batch["spans"], batch["span_mask"])
# for pred_no, (pred, input_len, start, end) in enumerate(zip(
# preds, batch["input_len"], batch["sent_start"], batch["sent_end"])):
# pred = [(LABEL_MEANING_MAP[id2label[t]], b, e) for t, b, e in pred if id2label[t] != "O"]
# pred = [(t, b - start, e - start) for t, b, e in pred]
# sample = eval_dataset.examples[args.eval_batch_size * step + pred_no][1]
# label_entities_map = {v: [] for v in LABEL_MEANING_MAP.values()}
# for t, b, e in pred:
# label_entities_map[t].append(f"{b};{e+1}")
# entities = [{"label": k, "span": v} for k, v in label_entities_map.items()]
# y_pred.append({"id": sample["id"].split("-")[-1], "entities": entities})

# decoding in evaluate() now mirrors the test-time (predict) path
for pred_no in range(logits.size(0)):
    example_ = eval_dataset.examples[args.eval_batch_size * step + pred_no][1]
    # indexing with [[pred_no]] keeps the batch dimension, so predict_decode_batch
    # receives a batch of size 1, exactly as during prediction
    batch_ = {k: v[[pred_no]] for k, v in batch.items()}
    batch_["logits"] = logits[[pred_no]]
    y_pred_ = predict_decode_batch(example_, batch_, id2label, post_process=True)
    for entity_no, entity in enumerate(y_pred_["entities"]):
        y_pred_["entities"][entity_no] = {
            "label": LABEL_MEANING_MAP[entity["label"]],
            "span": entity["span"]}
    y_pred.append(y_pred_)

labels = SpanV2.decode_batch(batch["label"], batch["spans"], batch["span_mask"], is_logits=False)
for label_no, (label, input_len, start, end) in enumerate(zip(
@@ -1389,6 +1472,7 @@ def load_dataset(args, processor, tokenizer, data_type='train'):
# save the early-stopping model to the output directory
if args.save_best_checkpoints:
best_checkpoints = os.path.join(args.output_dir, "checkpoint-999999")
logger.info("Loading model checkpoint from %s", best_checkpoints)
config = config_class.from_pretrained(best_checkpoints,
num_labels=num_labels, max_span_length=args.max_span_length,
cache_dir=args.cache_dir if args.cache_dir else None, )
7 changes: 2 additions & 5 deletions scripts/run_mlm_wwm.sh
@@ -91,13 +91,10 @@ nohup python run_mlm_wwm.py \
--mlm_probability=0.15 \
--output_dir=output/nezha-legal-cn-base-wwm/ \
--overwrite_output_dir \
--do_train --do_eval \
--warmup_steps=1000 \
--do_train \
--warmup_steps=1500 \
--max_steps=30000 \
--evaluation_strategy=steps \
--eval_steps=1500 \
--per_device_train_batch_size=48 \
--per_device_eval_batch_size=48 \
--gradient_accumulation_steps=4 \
--label_smoothing_factor=0.0 \
--learning_rate=5e-5 \
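# Note: with per_device_train_batch_size=48 and gradient_accumulation_steps=4, each
# optimizer step sees 48 * 4 = 192 sequences per device, so max_steps=30000 corresponds
# to roughly 5.76M training sequences per device (multiplied by the number of GPUs when
# training on more than one).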
127 changes: 93 additions & 34 deletions scripts/run_span.sh
@@ -363,7 +363,7 @@ done
# 组织机构
# {'p': 0.8795336787564767, 'r': 0.8424317617866005, 'f': 0.8605830164765527}

# TODO: context-aware, augment only the weaker classes 受害人 (victim) and 犯罪嫌疑人 (suspect)
# context-aware, augment only the weaker classes 受害人 (victim) and 犯罪嫌疑人 (suspect)
for k in 0 1 2 3 4
do
python run_span.py \
@@ -419,48 +419,107 @@ done
# 组织机构
# {'p': 0.8423586040914561, 'r': 0.8684863523573201, 'f': 0.855222968845449}

# LSR
for k in 0 1 2 3 4
do
python run_span.py \
--version=nezha-fgm1.0-lsr0.1-fold${k} \
--data_dir=./data/ner-ctx0-5fold-seed42/ \
--train_file=train.${k}.json \
--dev_file=dev.${k}.json \
--test_file=dev.${k}.json \
--model_type=nezha_span \
--model_name_or_path=/home/louishsu/NewDisk/Garage/weights/transformers/nezha-cn-base/ \
--do_train \
--overwrite_output_dir \
--evaluate_during_training \
--evaluate_each_epoch \
--save_best_checkpoints \
--max_span_length=40 \
--width_embedding_dim=128 \
--train_max_seq_length=512 \
--eval_max_seq_length=512 \
--do_lower_case \
--per_gpu_train_batch_size=8 \
--per_gpu_eval_batch_size=16 \
--gradient_accumulation_steps=2 \
--learning_rate=5e-5 \
--other_learning_rate=1e-3 \
--num_train_epochs=8.0 \
--warmup_proportion=0.1 \
--do_fgm --fgm_epsilon=1.0 \
--loss_type=lsr --label_smooth_eps=0.1 \
--seed=42
done
# main_local
# avg
# {'p': 0.8992657967816792, 'r': 0.8866509133190803, 'f': 0.8929138022210471}
# 犯罪嫌疑人
# {'p': 0.9609907120743034, 'r': 0.9605446387126721, 'f': 0.9607676236168073}
# 受害人
# {'p': 0.9242566510172144, 'r': 0.9501287001287001, 'f': 0.9370141202601936}
# 被盗货币
# {'p': 0.8159574468085107, 'r': 0.8382513661202186, 'f': 0.8269541778975741}
# 物品价值
# {'p': 0.9697696737044146, 'r': 0.9669856459330144, 'f': 0.9683756588404409}
# 盗窃获利
# {'p': 0.8627450980392157, 'r': 0.9147609147609148, 'f': 0.8879919273461151}
# 被盗物品
# {'p': 0.8217407137654771, 'r': 0.7806607853312576, 'f': 0.8006741772376476}
# 作案工具
# {'p': 0.801994301994302, 'r': 0.7659863945578231, 'f': 0.7835768963117605}
# 时间
# {'p': 0.9488699518340126, 'r': 0.9262206148282097, 'f': 0.9374084919472914}
# 地点
# {'p': 0.861102919492775, 'r': 0.8302530565823145, 'f': 0.8453966415749856}
# 组织机构
# {'p': 0.8513513513513513, 'r': 0.8598014888337469, 'f': 0.8555555555555555}
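
# For reference, the --loss_type=lsr --label_smooth_eps=0.1 flags above select a
# label-smoothed cross-entropy. A minimal Python sketch of the standard formulation
# (assumed to match the repo's lsr loss in spirit; the helper name is illustrative):
#
#     import torch.nn.functional as F
#     def label_smoothing_ce(logits, target, eps=0.1):
#         log_probs = F.log_softmax(logits, dim=-1)
#         nll = -log_probs.gather(dim=-1, index=target.unsqueeze(-1)).squeeze(-1)
#         smooth = -log_probs.mean(dim=-1)       # uniform prior over all classes
#         return ((1.0 - eps) * nll + eps * smooth).mean()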

# TODO: train on the full dataset

# TODO: EMA
for k in 0 1 2 3 4
do
python run_span.py \
--version=nezha-fgm1.0-ema0.999-fold${k} \
--data_dir=./data/ner-ctx0-5fold-seed42/ \
--train_file=train.${k}.json \
--dev_file=dev.${k}.json \
--test_file=dev.${k}.json \
--model_type=nezha_span \
--model_name_or_path=/home/louishsu/NewDisk/Garage/weights/transformers/nezha-cn-base/ \
--do_train \
--overwrite_output_dir \
--evaluate_during_training \
--evaluate_each_epoch \
--save_best_checkpoints \
--max_span_length=40 \
--width_embedding_dim=128 \
--train_max_seq_length=512 \
--eval_max_seq_length=512 \
--do_lower_case \
--per_gpu_train_batch_size=8 \
--per_gpu_eval_batch_size=16 \
--gradient_accumulation_steps=2 \
--learning_rate=5e-5 \
--other_learning_rate=1e-3 \
--num_train_epochs=4.0 \
--warmup_proportion=0.1 \
--do_fgm --fgm_epsilon=1.0 \
--do_ema \
--seed=42
done
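# Note: with decay=0.999 (hard-coded in run_span.py), the shadow weights average over
# roughly the last 1/(1 - 0.999) = 1000 optimizer steps.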

# TODO: albert
# TODO: distill
# TODO: further pretraining, adding data from the other competition tracks
# TODO: pseudo-labels, using the other tracks' data as unlabeled data
# TODO: previous years' data
# TODO: TTA

# <<< Stage 2 <<<

# ==================================================================================================================
# TODO: label smooth 0.1
# for k in 0 1 2 3 4
# do
# python run_span.py \
# --version=nezha-rdrop0.1-fgm1.0-lsr0.1-fold${k} \
# --data_dir=./data/ner-ctx0-5fold-seed42/ \
# --train_file=train.${k}.json \
# --dev_file=dev.${k}.json \
# --test_file=dev.${k}.json \
# --model_type=nezha_span \
# --model_name_or_path=/home/louishsu/NewDisk/Garage/weights/transformers/nezha-cn-base/ \
# --do_train \
# --overwrite_output_dir \
# --evaluate_during_training \
# --evaluate_each_epoch \
# --save_best_checkpoints \
# --max_span_length=40 \
# --width_embedding_dim=128 \
# --train_max_seq_length=512 \
# --eval_max_seq_length=512 \
# --do_lower_case \
# --per_gpu_train_batch_size=8 \
# --per_gpu_eval_batch_size=16 \
# --gradient_accumulation_steps=2 \
# --learning_rate=5e-5 \
# --other_learning_rate=1e-3 \
# --num_train_epochs=4.0 \
# --warmup_proportion=0.1 \
# --rdrop_alpha=0.1 \
# --do_fgm --fgm_epsilon=1.0 \
# --loss_type=lsr --label_smooth_eps=0.1 \
# --seed=42
# done

# for k in 0 1 2 3 4
# do

