Skip to content

Commit

Permalink
baseline 5fold
Browse files Browse the repository at this point in the history
  • Loading branch information
louishsu committed Sep 5, 2021
1 parent 204ff72 commit 453bf34
Show file tree
Hide file tree
Showing 12 changed files with 5,491 additions and 181 deletions.
2 changes: 2 additions & 0 deletions .gitignore
Original file line number Diff line number Diff line change
@@ -1,4 +1,6 @@
*.pyc
output.json
args/pred*
output/*
cache/*
tmp/*
Expand Down
Original file line number Diff line number Diff line change
@@ -1,13 +1,13 @@
{
"version": "baseline",
"version": "baseline-fold0",
"device": "cuda:0",
"n_gpu": 1,
"task_name": "ner",
"dataset_name": "cail_ner",
"data_dir": "./data/ner-ctx0-train0.8-seed42/",
"train_file": "train.json",
"dev_file": "dev.json",
"test_file": "dev.json",
"data_dir": "./data/ner-ctx0-5fold-seed42/",
"train_file": "train.0.json",
"dev_file": "dev.0.json",
"test_file": "dev.0.json",
"model_type": "bert_span",
"model_name_or_path": "/home/louishsu/NewDisk/Garage/weights/transformers/chinese-roberta-wwm",
"output_dir": "output/",
Expand Down
Binary file added data/信息抽取_第二阶段.zip
Binary file not shown.
5,247 changes: 5,247 additions & 0 deletions data/信息抽取_第二阶段/xxcq_mid.json

Large diffs are not rendered by default.

38 changes: 12 additions & 26 deletions eda.ipynb
Original file line number Diff line number Diff line change
Expand Up @@ -3,7 +3,6 @@
{
"cell_type": "code",
"execution_count": 1,
"id": "42482b6a",
"metadata": {},
"outputs": [],
"source": [
Expand All @@ -15,7 +14,6 @@
{
"cell_type": "code",
"execution_count": 2,
"id": "e534daf4",
"metadata": {},
"outputs": [],
"source": [
Expand All @@ -24,7 +22,6 @@
},
{
"cell_type": "markdown",
"id": "e6469600",
"metadata": {},
"source": [
"## 文本长度统计"
Expand All @@ -33,7 +30,6 @@
{
"cell_type": "code",
"execution_count": 3,
"id": "f7ac750c",
"metadata": {},
"outputs": [],
"source": [
Expand All @@ -43,7 +39,6 @@
{
"cell_type": "code",
"execution_count": 4,
"id": "71da647b",
"metadata": {},
"outputs": [
{
Expand All @@ -60,7 +55,6 @@
},
{
"cell_type": "markdown",
"id": "de0910b0",
"metadata": {},
"source": [
"## 标签统计"
Expand All @@ -69,7 +63,6 @@
{
"cell_type": "code",
"execution_count": 5,
"id": "60120294",
"metadata": {},
"outputs": [
{
Expand Down Expand Up @@ -99,7 +92,6 @@
{
"cell_type": "code",
"execution_count": 6,
"id": "87c78261",
"metadata": {},
"outputs": [],
"source": [
Expand All @@ -110,7 +102,6 @@
{
"cell_type": "code",
"execution_count": 7,
"id": "b3b997bd",
"metadata": {},
"outputs": [],
"source": [
Expand All @@ -120,7 +111,6 @@
{
"cell_type": "code",
"execution_count": 8,
"id": "d210524a",
"metadata": {},
"outputs": [
{
Expand All @@ -141,7 +131,6 @@
{
"cell_type": "code",
"execution_count": 9,
"id": "4f9680c8",
"metadata": {},
"outputs": [
{
Expand Down Expand Up @@ -171,7 +160,6 @@
{
"cell_type": "code",
"execution_count": 10,
"id": "324d63d6",
"metadata": {},
"outputs": [],
"source": [
Expand All @@ -181,7 +169,6 @@
{
"cell_type": "code",
"execution_count": 11,
"id": "0aec0c72",
"metadata": {},
"outputs": [
{
Expand Down Expand Up @@ -244,7 +231,6 @@
{
"cell_type": "code",
"execution_count": 12,
"id": "c144f03e",
"metadata": {},
"outputs": [],
"source": [
Expand All @@ -256,40 +242,40 @@
{
"cell_type": "code",
"execution_count": 13,
"id": "bfe15b7e",
"metadata": {},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"NHCS 犯罪嫌疑人\n",
"Counter({3: 2711, 2: 162, 4: 28, 5: 19, 9: 9, 7: 4, 1: 1, 6: 1})\n",
"[(1, 1), (2, 162), (3, 2711), (4, 28), (5, 19), (6, 1), (7, 4), (9, 9)]\n",
"NHVI 受害人\n",
"Counter({3: 1274, 2: 20, 4: 4, 8: 1})\n",
"[(2, 20), (3, 1274), (4, 4), (8, 1)]\n",
"NASI 被盗物品\n",
"Counter({4: 342, 5: 275, 2: 267, 7: 203, 3: 199, 6: 199, 8: 190, 10: 146, 9: 139, 11: 112, 12: 101, 13: 71, 14: 60, 17: 33, 15: 32, 18: 26, 16: 25, 1: 21, 21: 17, 20: 16, 19: 13, 23: 10, 26: 9, 22: 8, 27: 6, 28: 5, 25: 5, 24: 5, 30: 4, 32: 2, 35: 2, 40: 2, 29: 2, 31: 2, 33: 1, 39: 1, 38: 1, 44: 1, 45: 1, 57: 1})\n",
"[(1, 21), (2, 267), (3, 199), (4, 342), (5, 275), (6, 199), (7, 203), (8, 190), (9, 139), (10, 146), (11, 112), (12, 101), (13, 71), (14, 60), (15, 32), (16, 25), (17, 33), (18, 26), (19, 13), (20, 16), (21, 17), (22, 8), (23, 10), (24, 5), (25, 5), (26, 9), (27, 6), (28, 5), (29, 2), (30, 4), (31, 2), (32, 2), (33, 1), (35, 2), (38, 1), (39, 1), (40, 2), (44, 1), (45, 1), (57, 1)]\n",
"NT 时间\n",
"Counter({12: 215, 14: 155, 13: 150, 11: 147, 10: 136, 15: 105, 9: 79, 16: 62, 17: 37, 7: 32, 8: 30, 18: 14, 5: 12, 6: 10, 22: 7, 4: 7, 3: 6, 2: 6, 19: 5, 20: 4, 23: 3, 24: 1, 26: 1, 21: 1, 32: 1})\n",
"[(2, 6), (3, 6), (4, 7), (5, 12), (6, 10), (7, 32), (8, 30), (9, 79), (10, 136), (11, 147), (12, 215), (13, 150), (14, 155), (15, 105), (16, 62), (17, 37), (18, 14), (19, 5), (20, 4), (21, 1), (22, 7), (23, 3), (24, 1), (26, 1), (32, 1)]\n",
"NS 地点\n",
"Counter({5: 99, 14: 91, 17: 87, 15: 87, 6: 83, 19: 78, 18: 78, 16: 74, 4: 71, 12: 70, 13: 69, 11: 69, 20: 68, 9: 63, 7: 62, 23: 57, 8: 53, 22: 52, 21: 50, 10: 39, 24: 30, 3: 28, 25: 23, 2: 21, 26: 18, 28: 16, 27: 15, 31: 8, 29: 8, 34: 4, 32: 3, 33: 3, 30: 1, 37: 1, 36: 1})\n",
"[(2, 21), (3, 28), (4, 71), (5, 99), (6, 83), (7, 62), (8, 53), (9, 63), (10, 39), (11, 69), (12, 70), (13, 69), (14, 91), (15, 87), (16, 74), (17, 87), (18, 78), (19, 78), (20, 68), (21, 50), (22, 52), (23, 57), (24, 30), (25, 23), (26, 18), (27, 15), (28, 16), (29, 8), (30, 1), (31, 8), (32, 3), (33, 3), (34, 4), (36, 1), (37, 1)]\n",
"NCGV 物品价值\n",
"Counter({8: 255, 5: 194, 7: 187, 4: 122, 9: 67, 6: 64, 11: 20, 10: 19, 3: 15, 12: 2, 13: 1, 20: 1})\n",
"[(3, 15), (4, 122), (5, 194), (6, 64), (7, 187), (8, 255), (9, 67), (10, 19), (11, 20), (12, 2), (13, 1), (20, 1)]\n",
"NO 组织机构\n",
"Counter({9: 94, 4: 81, 6: 45, 12: 22, 10: 21, 8: 18, 11: 17, 5: 15, 13: 6, 16: 6, 7: 6, 15: 5, 14: 3, 17: 3, 18: 2, 20: 1, 3: 1})\n",
"[(3, 1), (4, 81), (5, 15), (6, 45), (7, 6), (8, 18), (9, 94), (10, 21), (11, 17), (12, 22), (13, 6), (14, 3), (15, 5), (16, 6), (17, 3), (18, 2), (20, 1)]\n",
"NATS 作案工具\n",
"Counter({3: 83, 2: 80, 5: 32, 4: 28, 15: 9, 13: 9, 8: 7, 6: 7, 12: 6, 7: 6, 16: 6, 11: 5, 20: 5, 19: 3, 21: 2, 9: 2, 10: 2, 26: 1, 24: 1})\n",
"[(2, 80), (3, 83), (4, 28), (5, 32), (6, 7), (7, 6), (8, 7), (9, 2), (10, 2), (11, 5), (12, 6), (13, 9), (15, 9), (16, 6), (19, 3), (20, 5), (21, 2), (24, 1), (26, 1)]\n",
"NCSP 盗窃获利\n",
"Counter({4: 58, 5: 46, 7: 25, 8: 23, 3: 13, 6: 12, 9: 4, 11: 3, 15: 1, 10: 1})\n",
"[(3, 13), (4, 58), (5, 46), (6, 12), (7, 25), (8, 23), (9, 4), (10, 1), (11, 3), (15, 1)]\n",
"NCSM 被盗货币\n",
"Counter({8: 74, 7: 70, 6: 55, 10: 51, 9: 47, 5: 33, 11: 24, 4: 23, 12: 12, 2: 9, 3: 7, 13: 5, 35: 2, 14: 2, 15: 2, 17: 2})\n"
"[(2, 9), (3, 7), (4, 23), (5, 33), (6, 55), (7, 70), (8, 74), (9, 47), (10, 51), (11, 24), (12, 12), (13, 5), (14, 2), (15, 2), (17, 2), (35, 2)]\n"
]
}
],
"source": [
"for label, entities in label_entities_map.items():\n",
" print(label, utils.LABEL_MEANING_MAP[label])\n",
" print(Counter([len(entity) for entity in entities]))\n",
" counter = Counter([len(entity) for entity in entities])\n",
" print(sorted(counter.items(), key=lambda x: x[0]))\n",
" entities = sorted(list(set(entities)), key=len)\n",
" with open(f\"tmp/{label}_{utils.LABEL_MEANING_MAP[label]}.txt\", \"w\") as f:\n",
" f.writelines([entity + \"\\n\" for entity in entities])"
Expand Down
143 changes: 88 additions & 55 deletions evaluate.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,57 +2,73 @@
import sys
import json
from collections import Counter
from matplotlib import pyplot as plt
from sklearn.metrics import confusion_matrix, ConfusionMatrixDisplay
from utils import LABEL_MEANING_MAP

def get_score(ground_truth_path, output_path):
try:
ground_truth = {}
prediction = {}
with open(ground_truth_path, 'r', encoding="utf-8") as f:
for line in f:
data = json.loads(line)
id = data['id']
data.pop('id')
ground_truth[id] = data
with open(output_path, 'r', encoding="utf-8") as f:
for line in f:
data = json.loads(line)
id = data['id']
data.pop('id')
prediction[id] = data

ground_truth_num = 0
prediction_num = 0
tp = 0
plt.rcParams['font.sans-serif'] = ['KaiTi']
plt.rcParams['axes.unicode_minus'] = False

def score(ground_truth, prediction, labels=None):
    """Compute micro-averaged precision/recall/F1 over entity spans.

    Args:
        ground_truth: dict mapping example id -> {"entities": [{"label": str,
            "span": [str, ...]}, ...]} (spans are "begin;end" strings).
        prediction: dict of the same shape holding predicted entities.
        labels: optional collection of label names; when given, only entities
            whose label is in it are scored (used for per-label evaluation).

    Returns:
        {"p": precision, "r": recall, "f": f1}, or {"p": -1, "r": -1, "f": -1}
        when any score is undefined (no predicted spans, no gold spans, or
        zero overlap making p + r == 0).
    """
    ground_truth_num = 0  # total gold spans considered
    prediction_num = 0    # total predicted spans considered
    tp = 0                # exact matches: same id, same label, same span

    for id, ground_truth_data in ground_truth.items():
        try:
            pred_data = prediction[id]
        except KeyError:
            # NOTE(review): an example missing from the prediction file is
            # skipped entirely -- its gold spans do not count against recall.
            continue
        ground_truth_entities_dict = {}
        for entitie in ground_truth_data['entities']:
            if labels is not None and entitie["label"] not in labels:
                continue
            ground_truth_num += len(entitie['span'])
            ground_truth_entities_dict[entitie['label']] = entitie['span']
        pred_entities_dict = {}
        for entitie in pred_data['entities']:
            if labels is not None and entitie["label"] not in labels:
                continue
            prediction_num += len(entitie['span'])
            pred_entities_dict[entitie['label']] = entitie['span']
        for label in ground_truth_entities_dict.keys():
            # BUGFIX: use .get() with an empty default -- a gold label with no
            # predicted spans for this example previously raised KeyError and
            # aborted the whole evaluation instead of contributing 0 matches.
            tp += len(set(ground_truth_entities_dict[label])
                      .intersection(pred_entities_dict.get(label, [])))

    try:
        p = tp / prediction_num
        r = tp / ground_truth_num
        f = 2 * p * r / (p + r)
        result = {"p": p, "r": r, "f": f}
    except ZeroDivisionError:
        # Undefined metric (empty denominator or p + r == 0).
        result = {"p": -1, "r": -1, "f": -1}
    return result


def _load_jsonl(path):
    """Load a JSON-lines file into a dict keyed by each record's 'id'.

    The 'id' field is popped from every record, so values hold only the
    remaining payload (e.g. the 'entities' list).
    """
    records = {}
    with open(path, 'r', encoding="utf-8") as f:
        for line in f:
            data = json.loads(line)
            # pop() both reads the key and removes it, replacing the previous
            # read-then-pop pair that shadowed the builtin `id`.
            key = data.pop('id')
            records[key] = data
    return records


def get_scores(ground_truth_path, output_path):
    """Evaluate a prediction file against a ground-truth file.

    Args:
        ground_truth_path: path to the gold JSON-lines file.
        output_path: path to the predicted JSON-lines file.

    Returns:
        dict mapping "总计" (overall) and each label in LABEL_MEANING_MAP to
        its {"p", "r", "f"} score dict as produced by score().
    """
    ground_truth = _load_jsonl(ground_truth_path)
    prediction = _load_jsonl(output_path)
    scores = {"总计": score(ground_truth, prediction)}
    for label in LABEL_MEANING_MAP.keys():
        scores[label] = score(ground_truth, prediction, [label])
    return scores

def analyze_error(ground_truth_path, output_path):
get_position = lambda x: "-".join(x.split("-")[:2])
get_content = lambda x: "-".join(x.split("-")[2:])
get_label = lambda x: x.split("-")[-1]
ground_truth = {}
prediction = {}
with open(ground_truth_path, 'r', encoding="utf-8") as f:
Expand All @@ -77,6 +93,7 @@ def analyze_error(ground_truth_path, output_path):
for span in entity["span"]:
b, e = [int(i) for i in span.split(";")]
ground_truth_entities.append(f"{id}-{span}-{text[b: e]}-{label}")
ground_truth_entities = sorted(ground_truth_entities, key=get_position)
prediction_entities = []
for id, prediction_data in prediction.items():
text = ground_truth_id_text_map[id]
Expand All @@ -85,12 +102,10 @@ def analyze_error(ground_truth_path, output_path):
for span in entity["span"]:
b, e = [int(i) for i in span.split(";")]
prediction_entities.append(f"{id}-{span}-{text[b: e]}-{label}")
prediction_entities = sorted(prediction_entities, key=get_position)

# in_gt_not_in_pred = sorted(set(ground_truth_entities) - set(prediction_entities))
# in_pred_not_in_gt = sorted(set(prediction_entities) - set(ground_truth_entities))
get_position = lambda x: "-".join(x.split("-")[:2])
get_content = lambda x: "-".join(x.split("-")[2:])
get_label = lambda x: x.split("-")[-1]
ground_truth_positions_content_map = {get_position(entity): get_content(entity)
for entity in ground_truth_entities}
prediction_positions_content_map = {get_position(entity): get_content(entity)
Expand All @@ -101,33 +116,51 @@ def analyze_error(ground_truth_path, output_path):
for entity in prediction_entities}
## 第一类错误:定位错误,未识别到标注实体
location_missed = set(ground_truth_positions_content_map.keys()) - set(prediction_positions_content_map.keys())
location_missed = sorted(location_missed)
location_missed = [k + "-" + ground_truth_positions_content_map[k] for k in location_missed]
location_missed = sorted(location_missed)
## 第二类错误:定位错误,识别到未标注实体
location_found = set(prediction_positions_content_map.keys()) - set(ground_truth_positions_content_map.keys())
location_found = sorted(location_found)
location_found = [k + "-" + prediction_positions_content_map[k] for k in location_found]
location_found = sorted(location_found)
## 第三类错误:定位准确但分类错误
itersection_positions = set(ground_truth_positions_content_map.keys()) \
.intersection(prediction_positions_content_map.keys())
label_error = [k + "-" + ground_truth_positions_content_map[k] + \
ground_truth_positions_label_map[k] + "->" + prediction_positions_label_map[k]
label_error = [k + "-" + ground_truth_positions_content_map[k].split("-")[0] + "-" + \
ground_truth_positions_label_map[k] + "-" + prediction_positions_label_map[k]
for k in itersection_positions if ground_truth_positions_label_map[k] != prediction_positions_label_map[k]]
# 混淆矩阵
y_true = [ground_truth_positions_label_map[k] for k in itersection_positions]
y_pred = [prediction_positions_label_map[k] for k in itersection_positions]
labels = list(LABEL_MEANING_MAP.values())
cm = confusion_matrix(y_true, y_pred, labels=labels)
disp = ConfusionMatrixDisplay(cm, display_labels=labels).plot()
label_error_type_count_map = Counter(["-".join(e.split("-")[-2:]) for e in label_error])
# 保存
for list_name in ["location_missed", "location_found", "label_error"]:
for list_name in ["ground_truth_entities", "prediction_entities",
"location_missed", "location_found", "label_error"]:
with open(os.path.join("tmp", list_name + ".txt"), "w") as f:
for line in locals()[list_name]:
f.write(line + "\n")
print(label_error_type_count_map)
print(labels)
plt.savefig(os.path.join("tmp", "cm.jpg"))

if __name__ == '__main__':
# print(get_score(sys.argv[1], sys.argv[2]))
# print(get_score(
ground_truth_path, output_path = sys.argv[1], sys.argv[2]
for label, score in get_scores(ground_truth_path, output_path).items():
print(LABEL_MEANING_MAP.get(label, label))
print(score)
# analyze_error(ground_truth_path, output_path)

# for label, score in get_scores(
# "./data/ner-ctx0-5fold-seed42/dev.gt.0.json",
# "output/ner-cail_ner-bert_span-baseline-42/test_prediction.json"
# ).items():
# print(LABEL_MEANING_MAP.get(label, label))
# print(score)
# analyze_error(
# "./data/ner-ctx0-5fold-seed42/dev.gt.0.json",
# "output/ner-cail_ner-bert_span-baseline-fold0-42/test_prediction.json"
# ))
analyze_error(
"./data/ner-ctx0-5fold-seed42/dev.gt.0.json",
"output/ner-cail_ner-bert_span-baseline-fold0-42/test_prediction.json"
)
# "output/ner-cail_ner-bert_span-baseline-42/test_prediction.json"
# )
Loading

0 comments on commit 453bf34

Please sign in to comment.