From c328067b64760154822458edb3a9742277084824 Mon Sep 17 00:00:00 2001 From: louishsu Date: Thu, 16 Sep 2021 23:21:57 +0800 Subject: [PATCH] 2021/9/16 nezha_rdrop0.1-fgm1.0_f81.46_p74.55_r89.77 --- evaluate.py | 4 ++-- main.py | 15 +++++++++++-- main_local.py | 17 +++++++++++--- scripts/run_span.sh | 55 +++++++++++++++++++++++++++++++++++++++++++++ 4 files changed, 84 insertions(+), 7 deletions(-) diff --git a/evaluate.py b/evaluate.py index 07573a6..9e12a40 100644 --- a/evaluate.py +++ b/evaluate.py @@ -148,8 +148,8 @@ def analyze_error(ground_truth_path, output_path): plt.savefig(os.path.join("tmp", "cm.jpg")) if __name__ == '__main__': - # ground_truth_path, output_path = sys.argv[1], sys.argv[2] - ground_truth_path, output_path = "data/ner-ctx0-5fold-seed42/dev.gt.all.json", "output.json" + ground_truth_path, output_path = sys.argv[1], sys.argv[2] + # ground_truth_path, output_path = "data/ner-ctx0-5fold-seed42/dev.gt.all.json", "output.json" for label, score in get_scores(ground_truth_path, output_path).items(): print(LABEL_MEANING_MAP.get(label, label)) print(score) diff --git a/main.py b/main.py index 7d537b8..be32241 100644 --- a/main.py +++ b/main.py @@ -19,8 +19,19 @@ def main(): # WARNING:以下配置需要在提交前指定!!! # version = "baseline" - version = "rdrop0.1-fgm1.0" - model_type = "bert_span" + # model_type = "bert_span" + # dataset_name = "cail_ner" + # n_splits = 5 + # seed=42 + # -------------------------- + # version = "rdrop0.1-fgm1.0" + # model_type = "bert_span" + # dataset_name = "cail_ner" + # n_splits = 5 + # seed=42 + # -------------------------- + version = "nezha_rdrop0.1-fgm1.0" + model_type = "nezha_span" dataset_name = "cail_ner" n_splits = 5 seed=42 diff --git a/main_local.py b/main_local.py index a917415..2b4dd1a 100644 --- a/main_local.py +++ b/main_local.py @@ -13,14 +13,25 @@ def main(): parser = ArgumentParser() parser.add_argument("--local_debug", action="store_true", default=True) - parser.add_argument("--test_file", type=str, default=None) + # parser.add_argument("--test_file", type=str, default=None) run_args = parser.parse_args() local_debug = run_args.local_debug # WARNING:以下配置需要在提交前指定!!! # version = "baseline" - version = "rdrop0.1-fgm1.0" - model_type = "bert_span" + # model_type = "bert_span" + # dataset_name = "cail_ner" + # n_splits = 5 + # seed=42 + # -------------------------- + # version = "rdrop0.1-fgm1.0" + # model_type = "bert_span" + # dataset_name = "cail_ner" + # n_splits = 5 + # seed=42 + # -------------------------- + version = "nezha_rdrop0.1-fgm1.0" + model_type = "nezha_span" dataset_name = "cail_ner" n_splits = 5 seed=42 diff --git a/scripts/run_span.sh b/scripts/run_span.sh index c195a5b..513ccf9 100644 --- a/scripts/run_span.sh +++ b/scripts/run_span.sh @@ -195,6 +195,61 @@ done # 组织机构 # {'p': 0.8940397350993378, 'r': 0.9507042253521126, 'f': 0.9215017064846417} +for k in 0 1 2 3 4 +do +python run_span.py \ + --version=nezha_rdrop0.1-fgm1.0-fold${k} \ + --data_dir=./data/ner-ctx0-5fold-seed42/ \ + --train_file=train.${k}.json \ + --dev_file=dev.${k}.json \ + --test_file=dev.${k}.json \ + --model_type=nezha_span \ + --model_name_or_path=/home/louishsu/NewDisk/Garage/weights/transformers/nezha-cn-base/ \ + --do_train \ + --overwrite_output_dir \ + --evaluate_during_training \ + --evaluate_each_epoch \ + --save_best_checkpoints \ + --max_span_length=40 \ + --width_embedding_dim=128 \ + --train_max_seq_length=512 \ + --eval_max_seq_length=512 \ + --do_lower_case \ + --per_gpu_train_batch_size=8 \ + --per_gpu_eval_batch_size=16 \ + --gradient_accumulation_steps=2 \ + --learning_rate=5e-5 \ + --other_learning_rate=1e-3 \ + --num_train_epochs=4.0 \ + --warmup_proportion=0.1 \ + --rdrop_alpha=0.1 \ + --do_fgm --fgm_epsilon=1.0 \ + --seed=42 +done +# main_local +# avg +# {'p': 0.9433435148625493, 'r': 0.949889351487191, 'f': 0.9466051170874838} +# 犯罪嫌疑人 +# {'p': 0.9784416384354789, 'r': 0.983134767136005, 'f': 0.9807825885621673} +# 受害人 +# {'p': 0.946600434647625, 'r': 0.981016731016731, 'f': 0.9635013430241745} +# 被盗货币 +# {'p': 0.8947368421052632, 'r': 0.9475409836065574, 'f': 0.9203821656050956} +# 物品价值 +# {'p': 0.9842931937172775, 'r': 0.9894736842105263, 'f': 0.9868766404199475} +# 盗窃获利 +# {'p': 0.9301397205588823, 'r': 0.9688149688149689, 'f': 0.9490835030549899} +# 被盗物品 +# {'p': 0.9147816938453446, 'r': 0.9024390243902439, 'f': 0.9085684430512017} +# 作案工具 +# {'p': 0.8863636363636364, 'r': 0.9551020408163265, 'f': 0.9194499017681729} +# 时间 +# {'p': 0.955661414437523, 'r': 0.9432188065099457, 'f': 0.9493993447397161} +# 地点 +# {'p': 0.9213002566295979, 'r': 0.9186806937731021, 'f': 0.9199886104783599} +# 组织机构 +# {'p': 0.9203860072376358, 'r': 0.9466501240694789, 'f': 0.9333333333333333} + # for k in 0 1 2 3 4 # do # python run_span.py \