2021/9/12 bert-rdrop0.1-fgm1.0
louishsu committed Sep 12, 2021
1 parent ac7a1b5 commit a2f923b
Showing 3 changed files with 202 additions and 67 deletions.
3 changes: 2 additions & 1 deletion main.py
@@ -11,7 +11,8 @@

def main():
local_debug = False
version = "baseline"
# version = "baseline"
version = "rdrop0.1-fgm1.0"
model_type = "bert_span"
dataset_name = "cail_ner"
n_splits = 5
160 changes: 109 additions & 51 deletions run_span.py
@@ -32,6 +32,7 @@
)
from transformers.modeling_outputs import TokenClassifierOutput
from nezha.modeling_nezha import NeZhaModel, NeZhaPreTrainedModel
from nezha.modeling_nezha import relative_position_encoding

# trainer & training arguments
from transformers import AdamW, get_linear_schedule_with_warmup
@@ -236,8 +237,8 @@ def __init__(self, config):
self.init_weights()

def forward(self, **kwargs):
# if args.rdrop_alpha is not None:
# return forward_rdrop(self, args.rdrop_alpha, **kwargs)
if args.rdrop_alpha is not None:
return forward_rdrop(self, args.rdrop_alpha, **kwargs)
return forward(self, **kwargs)

class NeZhaSpanV2ForNer(NeZhaPreTrainedModel):
@@ -252,8 +253,8 @@ def __init__(self, config):
self.init_weights()

def forward(self, **kwargs):
# if args.rdrop_alpha is not None:
# return forward_rdrop(self, args.rdrop_alpha, **kwargs)
if args.rdrop_alpha is not None:
return forward_rdrop(self, args.rdrop_alpha, **kwargs)
return forward(self, **kwargs)

class NerArgumentParser(ArgumentParser):
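
Both span heads now route through `forward_rdrop` whenever `--rdrop_alpha` is set; the body of `forward_rdrop` sits outside this hunk. As orientation only, a minimal sketch of the usual R-Drop recipe (two dropout-perturbed forward passes whose task losses are averaged and whose predictive distributions are tied together by a symmetric KL term weighted by alpha) might look like the following; the `logits`/`loss` fields and the call signature are assumptions, not the repository's actual implementation.

import torch.nn.functional as F

def rdrop_forward_sketch(model, rdrop_alpha, **inputs):
    # two stochastic forward passes: active dropout makes their outputs differ
    out1 = model(**inputs)
    out2 = model(**inputs)
    # symmetric KL between the two predictive distributions
    logp = F.log_softmax(out1.logits, dim=-1)
    logq = F.log_softmax(out2.logits, dim=-1)
    kl = 0.5 * (F.kl_div(logp, logq, log_target=True, reduction="batchmean")
                + F.kl_div(logq, logp, log_target=True, reduction="batchmean"))
    # averaged task loss plus the alpha-weighted consistency term
    out1.loss = 0.5 * (out1.loss + out2.loss) + rdrop_alpha * kl
    return out1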
@@ -297,8 +298,9 @@ def build_arguments(self):
self.add_argument("--width_embedding_dim", default=128, type=int)
self.add_argument("--optimizer", default="adamw", type=str)
# self.add_argument("--context_window", default=0, type=int)
# self.add_argument("--augment_context_aware_p", default=None, type=float)
# self.add_argument("--rdrop_alpha", default=None, type=float)
self.add_argument("--augment_context_aware_p", default=None, type=float)
self.add_argument("--augment_entity_replace_p", default=None, type=float)
self.add_argument("--rdrop_alpha", default=None, type=float)

# Other parameters
self.add_argument('--scheme', default='IOB2', type=str,
@@ -446,6 +448,7 @@ def __getitem__(self, index):
example = self.examples[index]
# preprocessing
for proc in self.process_pipline:
if proc is None: continue
example = proc(example)
# convert to features
return example
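
The added `None` check lets the preprocessing pipeline carry optional entries, so augmenters that are switched off (see `load_dataset` further down) can simply be passed as `None` instead of being filtered out. A toy illustration of the pattern, with made-up processors:

def add_sep(example):
    example["tokens"] = example["tokens"] + ["[SEP]"]
    return example

use_augment = False
process_pipline = [
    add_sep if use_augment else None,  # optional step, disabled here
    lambda ex: ex,                     # always-on step (identity as a stand-in)
]
example = {"tokens": ["被", "害", "人"]}
for proc in process_pipline:
    if proc is None:
        continue
    example = proc(example)
print(example["tokens"])  # ['被', '害', '人'] -- unchanged because the optional step is off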
@@ -470,39 +473,92 @@ def collate_fn(batch):
collated[k] = t
return collated

# class AugmentContextAware:

# def __init__(self, p, min_mlm_span_length, max_mlm_span_length):
# self.p = p
# self.min_mlm_span_length = min_mlm_span_length
# self.max_mlm_span_length = max_mlm_span_length

# def __call__(self, example):
# id_ = example[1]["id"]
# tokens = example[1]["tokens"]
# entities = example[1]["entities"]
# sent_start = example[1]["sent_start"]
# sent_end = example[1]["sent_end"]

# entities = sorted(entities, key=lambda x: x[1], reverse=True) # sort by entity start
# ner_tags = get_ner_tags(entities, len(tokens))

# for label, start, end, span_text in entities:
# if random.random() < self.p:
# entity_text_mlm = ["[MASK]"] * random.randint(
# self.min_mlm_span_length, self.max_mlm_span_length)
# tokens = tokens[: start] + entity_text_mlm + tokens[end + 1: ]
# ner_span_tag = [f"B-{label}"] + [f"I-{label}"] * (len(entity_text_mlm) - 1)
# ner_tags = ner_tags[: start] + ner_span_tag + ner_tags[end + 1: ]

# entities_new = [[t, s, e, "".join(tokens[s: e])] for t, s, e in get_entities(ner_tags)]
# return [example[0], {
# "id": id_,
# "tokens": tokens,
# "entities": entities_new,
# "sent_start": sent_start,
# "sent_end": sent_start + len(tokens)
# }]
class AugmentContextAware:

def __init__(self, p):
self.p = p

self.augment_entity_meanings = [
"物品价值", "被盗货币", "盗窃获利",
"受害人", "犯罪嫌疑人"
]

def __call__(self, example):
id_ = example[1]["id"]
tokens = example[1]["tokens"]
entities = example[1]["entities"]
sent_start = example[1]["sent_start"]
sent_end = example[1]["sent_end"]

random.shuffle(entities)
for entity_type, entity_start, entity_end, entity_text in entities:
if LABEL_MEANING_MAP[entity_type] in self.augment_entity_meanings:
if random.random() > self.p: continue
if any([tk == "[MASK]" for tk in tokens[entity_start: entity_end + 1]]):
continue
for i in range(entity_start, entity_end + 1):
tokens[i] = "[MASK]"
example[1]["tokens"] = tokens
return example
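
`AugmentContextAware` masks out entities of the listed meanings with probability `p`, so the model has to recognise those mentions from context rather than from memorised surface forms. A self-contained toy reproduction of the masking step (inclusive end indices, as in the class; the `NHVI` → 受害人 label mapping is an assumption made for the example):

import random

def mask_entities(tokens, entities, maskable_meanings, label_meaning, p=1.0):
    # replace every token of an entity whose meaning is maskable with [MASK]
    tokens = list(tokens)
    for label, start, end, _text in entities:
        if label_meaning.get(label) not in maskable_meanings:
            continue
        if random.random() > p:
            continue
        for i in range(start, end + 1):
            tokens[i] = "[MASK]"
    return tokens

tokens = list("2020年1月张三家中被盗")
entities = [["NT", 0, 6, "2020年1月"], ["NHVI", 7, 8, "张三"]]
print(mask_entities(tokens, entities, {"受害人"}, {"NT": "时间", "NHVI": "受害人"}))
# ['2', '0', '2', '0', '年', '1', '月', '[MASK]', '[MASK]', '家', '中', '被', '盗']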

class AugmentEntityReplace:

def __init__(self, p, examples):
self.p = p

self.wordType_entityTypes_map = {
"姓名": ["受害人", "犯罪嫌疑人", ],
"价值": ["物品价值", "被盗货币", "盗窃获利", ],
}
self.entityType_wordType_map = dict()
for word_type, entity_types in self.wordType_entityTypes_map.items():
for entity_type in entity_types:
self.entityType_wordType_map[entity_type] = word_type

self.wordType_words_map = {
"姓名": set(),
"价值": set(),
}
for example in examples:
for entity_type, entity_start, entity_end, entity_text in example[1]["entities"]:
meaning = LABEL_MEANING_MAP[entity_type]
if meaning not in self.entityType_wordType_map:
continue
self.wordType_words_map[self.entityType_wordType_map[meaning]] \
.add(entity_text)
self.wordType_words_map = {k: list(v) for k, v in self.wordType_words_map.items()}

def __call__(self, example):
id_ = example[1]["id"]
tokens = example[1]["tokens"]
entities = example[1]["entities"]
sent_start = example[1]["sent_start"]
sent_end = example[1]["sent_end"]

text = "".join(tokens)
entities = sorted(entities, key=lambda x: x[0])
for i, (entity_type, entity_start, entity_end, entity_text) in enumerate(entities):
if random.random() > self.p: continue
meaning = LABEL_MEANING_MAP[entity_type]
if meaning not in self.entityType_wordType_map:
continue
entity_text_new = random.choice(self.wordType_words_map[self.entityType_wordType_map[meaning]])
len_diff = len(entity_text_new) - len(entity_text)
text = text[: entity_start] + entity_text_new + text[entity_end + 1:]
entity_start, entity_end = entity_start, entity_start + len(entity_text_new) - 1
entities[i] = [entity_type, entity_start, entity_end, text[entity_start: entity_end + 1]]
# adjust the positions of the other entities
adjust_pos = lambda x: x if x <= entity_start else x + len_diff
for j, (l, s, e, t) in enumerate(entities):
if j == i: continue  # skip the entity just replaced; it is already re-indexed
s, e = adjust_pos(s), adjust_pos(e)
t = text[s: e + 1]
entities[j] = [l, s, e, t]

example[1]["tokens"] = list(text)
example[1]["entities"] = entities
example[1]["sent_start"] = sent_start
example[1]["sent_end"] = sent_start + len(text)
return example
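
`AugmentEntityReplace` swaps an entity's surface form for another mention of the same word type collected from the training set, then shifts every later span by the length difference. A simplified, self-contained reproduction of that replace-and-shift bookkeeping (the label codes and sample sentence are made up):

def replace_entity(text, entities, idx, new_text):
    # entities: [label, start, end (inclusive), text]; replace entity `idx`, re-index the rest
    label, start, end, old_text = entities[idx]
    delta = len(new_text) - len(old_text)
    text = text[:start] + new_text + text[end + 1:]
    entities[idx] = [label, start, start + len(new_text) - 1, new_text]
    for j, (l, s, e, t) in enumerate(entities):
        if j == idx:
            continue  # the replaced entity was already re-indexed above
        if s > start:  # spans after the edit point shift by the length difference
            s, e = s + delta, e + delta
        entities[j] = [l, s, e, text[s:e + 1]]
    return text, entities

text = "被害人张三丢失现金500元"
entities = [["NHVI", 3, 4, "张三"], ["NCSM", 7, 12, "现金500元"]]
text, entities = replace_entity(text, entities, 0, "欧阳小明")
print(text)      # 被害人欧阳小明丢失现金500元
print(entities)  # [['NHVI', 3, 6, '欧阳小明'], ['NCSM', 9, 14, '现金500元']]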

# TODO:
class ReDataMasking:
@@ -802,7 +858,7 @@ def train(args, model, processor, tokenizer):
torch.save(optimizer.state_dict(), os.path.join(output_dir, "optimizer.pt"))
torch.save(scheduler.state_dict(), os.path.join(output_dir, "scheduler.pt"))
logger.info("Saving optimizer and scheduler states to %s", output_dir)
if args.local_rank in [-1, 0] and \
elif args.local_rank in [-1, 0] and \
args.save_steps > 0 and global_step % args.save_steps == 0:
# Save model checkpoint
output_dir = os.path.join(args.output_dir, "checkpoint-{}".format(
@@ -970,14 +1026,14 @@ def merge_spans(spans, keep_type="short"):
# add combined `受害人+被盗物品` spans
spans.extend([(a[0], b[1]) for a, b in itertools.product(
spans_name, spans) if a[1] - b[0] in [-1, 0]])
# when both `受害人+被盗物品` and `被盗物品` are present, prefer keeping `受害人+被盗物品`
is_todel = [False] * len(spans)
for i, a in enumerate(spans_name):
for j, b in enumerate(spans):
u = (a[0], b[1])
if u in spans and u != b:
is_todel[j] = True
spans = [span for flag, span in zip(is_todel, spans) if not flag]
# # when both `受害人+被盗物品` and `被盗物品` are present, prefer keeping `受害人+被盗物品`
# is_todel = [False] * len(spans)
# for i, a in enumerate(spans_name):
# for j, b in enumerate(spans):
# u = (a[0], b[1])
# if u in spans and u != b:
# is_todel[j] = True
# spans = [span for flag, span in zip(is_todel, spans) if not flag]
# <<< name handling <<<
spans = merge_spans(spans, keep_type="short")
entities = spans2entities(spans)
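
The block above glues a 受害人 name span onto any span that starts immediately after it (or at its last position), producing combined `受害人+被盗物品` candidates before `merge_spans` runs. A self-contained illustration of the adjacency test, with made-up span boundaries:

import itertools

spans_name = [(5, 6)]        # e.g. the 受害人 mention
spans = [(7, 12), (20, 25)]  # e.g. 被盗物品 candidates
spans = spans + [(a[0], b[1]) for a, b in itertools.product(spans_name, spans)
                 if a[1] - b[0] in [-1, 0]]
print(spans)  # [(7, 12), (20, 25), (5, 12)] -- only the adjacent pair is joined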
@@ -1067,8 +1123,10 @@ def load_dataset(args, processor, tokenizer, data_type='train'):
torch.distributed.barrier() # Make sure only the first process in distributed training process the dataset, and the others will use the cache
max_seq_length = args.train_max_seq_length if data_type == 'train' else args.eval_max_seq_length
return NerDataset(examples, process_pipline=[
# AugmentContextAware(args.augment_context_aware_p, 2, 30
# ) if (data_type == 'train' and args.augment_context_aware_p is not None) else None,
AugmentEntityReplace(args.augment_entity_replace_p, examples,
) if (data_type == 'train' and args.augment_entity_replace_p is not None) else None,
AugmentContextAware(args.augment_context_aware_p,
) if (data_type == 'train' and args.augment_context_aware_p is not None) else None,
Example2Feature(tokenizer, processor.label2id, max_seq_length, config.max_span_length),
])

@@ -1082,7 +1140,7 @@ def load_dataset(args, processor, tokenizer, data_type='train'):
args = parser.parse_args_from_json(json_file=os.path.abspath(sys.argv[1]))
else:
args = parser.build_arguments().parse_args()
# args = parser.parse_args_from_json(json_file="args/pred.0.json")
# args = parser.parse_args_from_json(json_file="output/ner-cail_ner-bert_span-aug_ctx1.0-fold0-42/training_args.json")

# Set seed before initializing model.
seed_everything(args.seed)
106 changes: 91 additions & 15 deletions scripts/run_span.sh
@@ -118,40 +118,116 @@ python evaluate.py \
./data/ner-ctx0-5fold-seed42/dev.gt.${k}.json \
output/ner-cail_ner-bert_span-baseline-fold${k}-42/test_prediction.json
done
# avg
# {'p': 0.9142199194012666, 'r': 0.9188042430086789, 'f': 0.9165063485956138}
# 犯罪嫌疑人
# {'p': 0.9647829647829648, 'r': 0.9546191247974068, 'f': 0.959674134419552}
# 受害人
# {'p': 0.8846153846153846, 'r': 0.968013468013468, 'f': 0.9244372990353698}
# 被盗货币
# {'p': 0.8152173913043478, 'r': 0.8670520231213873, 'f': 0.8403361344537815}
# 物品价值
# {'p': 0.9780487804878049, 'r': 0.9828431372549019, 'f': 0.9804400977995109}
# 盗窃获利
# {'p': 0.8780487804878049, 'r': 0.9152542372881356, 'f': 0.896265560165975}
# 被盗物品
# {'p': 0.903485254691689, 'r': 0.8829694323144105, 'f': 0.8931095406360424}
# 作案工具
# {'p': 0.7407407407407407, 'r': 0.7874015748031497, 'f': 0.7633587786259541}
# 时间
# {'p': 0.9361702127659575, 'r': 0.9219047619047619, 'f': 0.928982725527831}
# 地点
# {'p': 0.8971830985915493, 'r': 0.885952712100139, 'f': 0.8915325402379286}
# 组织机构
# {'p': 0.8450704225352113, 'r': 0.8450704225352113, 'f': 0.8450704225352113}

# INFO:root:Counter({'犯罪嫌疑人': 1312, '被盗物品': 1202, '地点': 740, '时间': 573, '受害人': 570, '物品价值': 454, '作案工具': 171, '被盗货币': 170, '组织机构': 158, '盗窃获利': 79})
for k in 0 1 2 3 4
do
python run_span.py \
--version=legal_electra_base-fold${k} \
--version=rdrop0.1-fgm1.0-fold${k} \
--data_dir=./data/ner-ctx0-5fold-seed42/ \
--train_file=train.${k}.json \
--dev_file=dev.${k}.json \
--test_file=dev.${k}.json \
--model_type=bert_span \
--model_name_or_path=/home/louishsu/NewDisk/Garage/weights/transformers/hfl_chinese-legal-electra-base-discriminator/ \
--model_name_or_path=/home/louishsu/NewDisk/Garage/weights/transformers/chinese-roberta-wwm/ \
--do_train \
--do_eval \
--do_predict \
--overwrite_output_dir \
--evaluate_during_training \
--evaluate_each_epoch \
--save_best_checkpoints \
--max_span_length=50 \
--max_span_length=40 \
--width_embedding_dim=128 \
--train_max_seq_length=512 \
--eval_max_seq_length=512 \
--do_lower_case \
--per_gpu_train_batch_size=12 \
--per_gpu_eval_batch_size=24 \
--per_gpu_train_batch_size=8 \
--per_gpu_eval_batch_size=16 \
--gradient_accumulation_steps=2 \
--learning_rate=2e-5 \
--other_learning_rate=1e-4 \
--num_train_epochs=8.0 \
--learning_rate=5e-5 \
--other_learning_rate=1e-3 \
--num_train_epochs=4.0 \
--warmup_proportion=0.1 \
--rdrop_alpha=0.1 \
--do_fgm --fgm_epsilon=1.0 \
--seed=42
python evaluate.py \
./data/ner-ctx0-5fold-seed42/dev.gt.${k}.json \
output/ner-cail_ner-bert_span-rdrop0.1-fgm1.0-fold${k}-42/test_prediction.json
done
# <<< second stage <<<
# avg
# {'p': 0.9452107279693487, 'r': 0.9515911282545805, 'f': 0.9483901970206632}
# 犯罪嫌疑人
# {'p': 0.9821573398215734, 'r': 0.9813614262560778, 'f': 0.9817592217267938}
# 受害人
# {'p': 0.93026941362916, 'r': 0.9882154882154882, 'f': 0.9583673469387756}
# 被盗货币
# {'p': 0.8461538461538461, 'r': 0.953757225433526, 'f': 0.8967391304347825}
# 物品价值
# {'p': 0.9901960784313726, 'r': 0.9901960784313726, 'f': 0.9901960784313726}
# 盗窃获利
# {'p': 0.9652173913043478, 'r': 0.940677966101695, 'f': 0.9527896995708155}
# 被盗物品
# {'p': 0.9335106382978723, 'r': 0.9196506550218341, 'f': 0.9265288165420149}
# 作案工具
# {'p': 0.8472222222222222, 'r': 0.9606299212598425, 'f': 0.9003690036900368}
# 时间
# {'p': 0.9455252918287937, 'r': 0.9257142857142857, 'f': 0.9355149181905678}
# 地点
# {'p': 0.9415121255349501, 'r': 0.9179415855354659, 'f': 0.9295774647887323}
# 组织机构
# {'p': 0.8940397350993378, 'r': 0.9507042253521126, 'f': 0.9215017064846417}
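
The new run also passes `--do_fgm --fgm_epsilon=1.0`. The FGM code itself is not part of this diff; the class below is a hedged sketch of the commonly used recipe (perturb the word-embedding weights along the gradient direction by an epsilon-scaled step, backpropagate the adversarial loss, then restore the weights before the optimizer step), an assumption about the typical wiring rather than the repository's implementation:

import torch

class FGMSketch:
    def __init__(self, model, epsilon=1.0, emb_name="word_embeddings"):
        self.model, self.epsilon, self.emb_name = model, epsilon, emb_name
        self.backup = {}

    def attack(self):
        # add an epsilon-scaled perturbation along the embedding gradient
        for name, param in self.model.named_parameters():
            if param.requires_grad and self.emb_name in name and param.grad is not None:
                self.backup[name] = param.data.clone()
                norm = torch.norm(param.grad)
                if norm != 0 and not torch.isnan(norm):
                    param.data.add_(self.epsilon * param.grad / norm)

    def restore(self):
        # undo the perturbation before the optimizer step
        for name, param in self.model.named_parameters():
            if name in self.backup:
                param.data = self.backup[name]
        self.backup = {}

# typical loop: loss.backward(); fgm.attack(); adv_loss.backward(); fgm.restore(); optimizer.step()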

# for k in 0 1 2 3 4
# do
# python run_span.py \
# --version=legal_electra_base-fold${k} \
# --data_dir=./data/ner-ctx0-5fold-seed42/ \
# --train_file=train.${k}.json \
# --dev_file=dev.${k}.json \
# --test_file=dev.${k}.json \
# --model_type=bert_span \
# --model_name_or_path=/home/louishsu/NewDisk/Garage/weights/transformers/hfl_chinese-legal-electra-base-discriminator/ \
# --do_train \
# --do_eval \
# --do_predict \
# --overwrite_output_dir \
# --evaluate_during_training \
# --evaluate_each_epoch \
# --save_best_checkpoints \
# --max_span_length=50 \
# --width_embedding_dim=128 \
# --train_max_seq_length=512 \
# --eval_max_seq_length=512 \
# --do_lower_case \
# --per_gpu_train_batch_size=12 \
# --per_gpu_eval_batch_size=24 \
# --gradient_accumulation_steps=2 \
# --learning_rate=2e-5 \
# --other_learning_rate=1e-4 \
# --num_train_epochs=8.0 \
# --warmup_proportion=0.1 \
# --seed=42
# python evaluate.py \
# ./data/ner-ctx0-5fold-seed42/dev.gt.${k}.json \
# output/ner-cail_ner-bert_span-legal_electra_base-fold${k}-42/test_prediction.json
# done
# <<< second stage <<<
