2021/9/11 baseline.5-fold.postprocess._f80.78_p81.7_r79.88
louishsu committed Sep 11, 2021
1 parent 88d3aef commit 320c2cb
Showing 6 changed files with 111 additions and 12 deletions.
9 changes: 8 additions & 1 deletion TODO → TODO.md
@@ -6,6 +6,13 @@
6. FGM
7. Analyze the three error types: gold entities never predicted, predicted entities with wrong boundaries, and predicted entities with wrong types (see the sketch after this list)
8. K-fold: yes, already added
9. nezha / legal-domain BERT
9. hfl/chinese-legal-electra-base-discriminator
``` py
from transformers import AutoConfig, AutoModel, AutoTokenizer
model_name_or_path = "hfl/chinese-legal-electra-base-discriminator"
config = AutoConfig.from_pretrained(model_name_or_path)
model = AutoModel.from_pretrained(model_name_or_path, config=config)
tokenizer = AutoTokenizer.from_pretrained(model_name_or_path)
```
10. Entities in this task tend to be long; try bert_pointer: classification is accurate, so this is not a model problem
11. Analyze the data to determine which entities post-processing should keep
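
For item 7, a helper along the following lines could drive the error breakdown — a minimal sketch, not repo code; `split_errors` and the `(label, start, end)` half-open tuple format are assumptions:

```py
# Hypothetical helper for the three-way error split (assumption, not repo code).
# Entities are (label, start, end) tuples with half-open [start, end) spans.
def split_errors(gold, pred):
    gold, pred = set(gold), set(pred)
    # same label and overlapping half-open spans
    overlaps = lambda a, b: a[0] == b[0] and max(a[1], b[1]) < min(a[2], b[2])
    wrong = pred - gold
    # right label, overlapping but inexact span -> boundary error
    boundary = {p for p in wrong if any(overlaps(g, p) for g in gold)}
    # exact span, wrong label -> type error
    wrong_type = {p for p in wrong - boundary if any(g[1:] == p[1:] for g in gold)}
    # gold entity with no same-label overlapping prediction -> missed entirely
    missed = {g for g in gold - pred if not any(overlaps(g, p) for p in pred)}
    return missed, boundary, wrong_type
```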
2 changes: 1 addition & 1 deletion evaluate.py
@@ -152,7 +152,7 @@ def analyze_error(ground_truth_path, output_path):
for label, score in get_scores(ground_truth_path, output_path).items():
print(LABEL_MEANING_MAP.get(label, label))
print(score)
- # analyze_error(ground_truth_path, output_path)
+ analyze_error(ground_truth_path, output_path)

# for label, score in get_scores(
# "./data/ner-ctx0-5fold-seed42/dev.gt.0.json",
5 changes: 4 additions & 1 deletion main.py
@@ -10,7 +10,7 @@
outfile = "/output/output.json"

def main():
- local_debug = True
+ local_debug = False
version = "baseline"
model_type = "bert_span"
dataset_name = "cail_ner"
@@ -63,6 +63,9 @@ def main():
with open(output_predict_file, "w") as writer:
for record in results:
writer.write(json.dumps(record) + '\n')

if local_debug:
os.system("python evaluate.py data/ner-ctx0-5fold-seed42/dev.gt.0.json output.json")

if __name__ == '__main__':
main()
102 changes: 94 additions & 8 deletions run_span.py
@@ -6,6 +6,7 @@
import glob
import random
import logging
import itertools
from pathlib import Path
from argparse import ArgumentParser, Namespace
from tqdm import tqdm
@@ -45,7 +46,7 @@
get_entities
)
from evaluate import score
- from utils import LABEL_MEANING_MAP, get_ner_tags
+ from utils import LABEL_MEANING_MAP, MEANING_LABEL_MAP, get_ner_tags

class BertConfigSpanV2(BertConfig):

@@ -896,21 +897,106 @@ def evaluate(args, model, processor, tokenizer, prefix=""):
results['loss'] = eval_loss / nb_eval_steps
return results

- def predict_decode_batch(example, batch, id2label):
+ def predict_decode_batch(example, batch, id2label, post_process=True):
if example["id"].split("-")[-1] == "0004d0b59e19461ff126e3a08a814c33":
    print()  # leftover debug hook for one specific example id
is_intersect = lambda a, b: min(a[1], b[1]) - max(a[0], b[0]) > 0  # spans overlap
is_a_included_by_b = lambda a, b: min(a[1], b[1]) - max(a[0], b[0]) == a[1] - a[0]  # a lies inside b
is_contain_special_char = lambda x: any([c in text[x[0]: x[1]] for c in [",", ",", "、"]])  # span crosses punctuation
is_length_le_n = lambda x, n: x[1] - x[0] < n  # NOTE: despite the name, this checks strictly less than n
entities2spans = lambda entities: [(int(e.split(";")[0]), int(e.split(";")[1])) for e in entities]  # "b;e" -> (b, e)
spans2entities = lambda spans: [f"{b};{e}" for b, e in spans]  # (b, e) -> "b;e"
def merge_spans(spans, keep_type="short"):
spans = sorted(spans, key=lambda x: (x[0], x[1] - x[0]))  # sort by (start position, span length)
spans_new = []
for span in spans:
if not spans_new:
spans_new.append(span)
else:
spans_last = spans_new[-1]
if not is_intersect(spans_last, span):
spans_new.append(span)
else:
if keep_type == "long":
if is_a_included_by_b(spans_last, span):
spans_new.pop(-1)
spans_new.append(span)
elif is_a_included_by_b(span, spans_last):
pass
else:
spans_new.append(span)
elif keep_type == "short":
if is_a_included_by_b(spans_last, span):
pass
elif is_a_included_by_b(span, spans_last):
spans_new.pop(-1)
spans_new.append(span)
else:
spans_new.append(span)
# if len(spans_new) < len(spans): print(spans, "->", spans_new)
return spans_new
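# illustrative behaviour on hypothetical spans:
#   merge_spans([(0, 5), (2, 4), (6, 9)], keep_type="short") -> [(2, 4), (6, 9)]
#   merge_spans([(0, 5), (2, 4), (6, 9)], keep_type="long")  -> [(0, 5), (6, 9)]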

text = "".join(example["tokens"])
logits = batch["logits"]
preds = SpanV2.decode_batch(logits, batch["spans"], batch["span_mask"])
pred, input_len = preds[0], batch["input_len"][0]
start, end = batch["sent_start"].item(), batch["sent_end"].item()
pred = [(id2label[t], b, e) for t, b, e in pred if id2label[t] != "O"]
pred = [(t, b - start, e - start) for t, b, e in pred]
label_entities_map = {label: [] for label in LABEL_MEANING_MAP.keys()}
- for t, b, e in pred:
-     label_entities_map[t].append(f"{b};{e+1}")
+ for t, b, e in pred: label_entities_map[t].append(f"{b};{e+1}")
if post_process:
# if time or location entities overlap, keep the longer one
for meaning in ["时间", "地点"]:
label = MEANING_LABEL_MAP[meaning]
entities = label_entities_map[label]  # half-open interval [b, e)
if entities:
spans = entities2spans(entities)
spans = list(filter(lambda x: not is_contain_special_char(x), spans))
spans = merge_spans(spans, keep_type="long")
entities = spans2entities(spans)
label_entities_map[label] = entities
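# e.g. overlapping time spans (0, 4) and (2, 4) -> only the longer (0, 4) is kept (illustrative)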

# 1. if stolen-item entities overlap, keep the shortest; 2. stolen items should be tied to victim names
meaning = "被盗物品"
label = MEANING_LABEL_MAP[meaning]
entities = label_entities_map[label]  # half-open interval [b, e)
if entities:
spans = entities2spans(entities)
spans = list(filter(lambda x: not is_contain_special_char(x), spans))
# >>> victim-name handling >>>
entities_name = label_entities_map[MEANING_LABEL_MAP["受害人"]]
spans_name = entities2spans(entities_name)
# add combined `victim + stolen item` spans
spans.extend([(a[0], b[1]) for a, b in itertools.product(
spans_name, spans) if a[1] - b[0] in [-1, 0]])
# given both `victim + stolen item` and the bare `stolen item`, prefer `victim + stolen item`
is_todel = [False] * len(spans)
for i, a in enumerate(spans_name):
for j, b in enumerate(spans):
u = (a[0], b[1])
if u in spans and u != b:
is_todel[j] = True
spans = [span for flag, span in zip(is_todel, spans) if not flag]
# <<< victim-name handling <<<
spans = merge_spans(spans, keep_type="short")
entities = spans2entities(spans)
label_entities_map[label] = entities
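# e.g. victim (0, 3) adjacent to item (3, 7) yields the combined span (0, 7),
# which then replaces the bare (3, 7) (illustrative spans)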

# cap entity length for victims and suspects (limit 10)
for meaning in ["受害人", "犯罪嫌疑人"]:
label = MEANING_LABEL_MAP[meaning]
entities = label_entities_map[label]
if entities:
spans = entities2spans(entities)
spans = list(filter(lambda x: (not is_contain_special_char(x)) and is_length_le_n(x, 10), spans))
entities = spans2entities(spans)
label_entities_map[label] = entities

entities = [{"label": label, "span": label_entities_map[label]} \
for label in LABEL_MEANING_MAP.keys()]
# the prediction file is in JSON format with two fields per record: ``id`` and ``entities``
- id_ = example["id"].split("-")[1]
- return {"id": id_, "entities": entities}
+ return {"id": example["id"].split("-")[1], "entities": entities}

def predict(args, model, processor, tokenizer, prefix=""):
pred_output_dir = args.output_dir
Expand Down Expand Up @@ -948,7 +1034,7 @@ def predict(args, model, processor, tokenizer, prefix=""):
batch.pop("token_type_ids")
# decode the outputs
example = test_dataset.examples[step][1]
- results.append(predict_decode_batch(example, batch, id2label))
+ results.append(predict_decode_batch(example, batch, id2label, post_process=True))
# for k-fold
batch_all.append({k: v.detach().cpu() for k, v in batch.items()})
logger.info("\n")
@@ -996,7 +1082,7 @@ def load_dataset(args, processor, tokenizer, data_type='train'):
args = parser.parse_args_from_json(json_file=os.path.abspath(sys.argv[1]))
else:
args = parser.build_arguments().parse_args()
# args = parser.parse_args_from_json(json_file="args/bert_span-baseline-fold0.json")
# args = parser.parse_args_from_json(json_file="args/pred.0.json")

# Set seed before initializing model.
seed_everything(args.seed)
4 changes: 3 additions & 1 deletion scripts/clear_before_submit.sh
@@ -1,4 +1,6 @@
rm -rf args/pred*
rm -rf tmp/*
rm -rf output/*/test*
- rm -rf output/*/checkpoint-999999/
+ rm -rf output/*/checkpoint-999999/

# zip -r xxx.zip ./
1 change: 1 addition & 0 deletions utils.py
@@ -39,6 +39,7 @@ def seed_everything(seed=None, reproducibility=True):
"NS": "地点",
"NO": "组织机构",
}
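# inverse mapping for post-processing lookups, e.g. MEANING_LABEL_MAP["地点"] == "NS"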
MEANING_LABEL_MAP = {v: k for k, v in LABEL_MEANING_MAP.items()}

def load_raw(filepath):
raw = []
