Skip to content

Commit

Permalink
2021/9/16 nezha_rdrop0.1-fgm1.0_f81.46_p74.55_r89.77
Browse files Browse the repository at this point in the history
  • Loading branch information
louishsu committed Sep 16, 2021
1 parent f442600 commit c328067
Show file tree
Hide file tree
Showing 4 changed files with 84 additions and 7 deletions.
4 changes: 2 additions & 2 deletions evaluate.py
Original file line number Diff line number Diff line change
Expand Up @@ -148,8 +148,8 @@ def analyze_error(ground_truth_path, output_path):
plt.savefig(os.path.join("tmp", "cm.jpg"))

if __name__ == '__main__':
# ground_truth_path, output_path = sys.argv[1], sys.argv[2]
ground_truth_path, output_path = "data/ner-ctx0-5fold-seed42/dev.gt.all.json", "output.json"
ground_truth_path, output_path = sys.argv[1], sys.argv[2]
# ground_truth_path, output_path = "data/ner-ctx0-5fold-seed42/dev.gt.all.json", "output.json"
for label, score in get_scores(ground_truth_path, output_path).items():
print(LABEL_MEANING_MAP.get(label, label))
print(score)
Expand Down
15 changes: 13 additions & 2 deletions main.py
Original file line number Diff line number Diff line change
Expand Up @@ -19,8 +19,19 @@ def main():

# WARNING: the following settings must be specified before submission!!!
# version = "baseline"
version = "rdrop0.1-fgm1.0"
model_type = "bert_span"
# model_type = "bert_span"
# dataset_name = "cail_ner"
# n_splits = 5
# seed=42
# --------------------------
# version = "rdrop0.1-fgm1.0"
# model_type = "bert_span"
# dataset_name = "cail_ner"
# n_splits = 5
# seed=42
# --------------------------
version = "nezha_rdrop0.1-fgm1.0"
model_type = "nezha_span"
dataset_name = "cail_ner"
n_splits = 5
seed=42
Expand Down
17 changes: 14 additions & 3 deletions main_local.py
Original file line number Diff line number Diff line change
Expand Up @@ -13,14 +13,25 @@
def main():
parser = ArgumentParser()
parser.add_argument("--local_debug", action="store_true", default=True)
parser.add_argument("--test_file", type=str, default=None)
# parser.add_argument("--test_file", type=str, default=None)
run_args = parser.parse_args()
local_debug = run_args.local_debug

# WARNING: the following settings must be specified before submission!!!
# version = "baseline"
version = "rdrop0.1-fgm1.0"
model_type = "bert_span"
# model_type = "bert_span"
# dataset_name = "cail_ner"
# n_splits = 5
# seed=42
# --------------------------
# version = "rdrop0.1-fgm1.0"
# model_type = "bert_span"
# dataset_name = "cail_ner"
# n_splits = 5
# seed=42
# --------------------------
version = "nezha_rdrop0.1-fgm1.0"
model_type = "nezha_span"
dataset_name = "cail_ner"
n_splits = 5
seed=42
Expand Down
55 changes: 55 additions & 0 deletions scripts/run_span.sh
Original file line number Diff line number Diff line change
Expand Up @@ -195,6 +195,61 @@ done
# 组织机构
# {'p': 0.8940397350993378, 'r': 0.9507042253521126, 'f': 0.9215017064846417}

# 5-fold training run for the "nezha_rdrop0.1-fgm1.0" experiment.
#
# Invokes run_span.py once per fold index k in 0..4. For each fold:
#   * the run is tagged --version=nezha_rdrop0.1-fgm1.0-fold${k};
#   * data is read from ./data/ner-ctx0-5fold-seed42/ using train.${k}.json
#     for training and dev.${k}.json for BOTH --dev_file and --test_file
#     (the fold's dev split doubles as the test set here);
#   * --model_type=nezha_span with weights loaded from --model_name_or_path.
#
# NOTE(review): --model_name_or_path is an absolute path on the author's
# machine; it must be changed to run this loop anywhere else.
# NOTE(review): the exact semantics of the remaining flags (span length,
# width embedding, learning rates, warmup, R-Drop alpha, FGM epsilon, ...)
# are defined by run_span.py's argument parser, which is not shown here —
# confirm there before changing values.
#
# The comment block following the loop appears to record the per-label
# p/r/f scores obtained for this configuration via main_local.
for k in 0 1 2 3 4
do
python run_span.py \
--version=nezha_rdrop0.1-fgm1.0-fold${k} \
--data_dir=./data/ner-ctx0-5fold-seed42/ \
--train_file=train.${k}.json \
--dev_file=dev.${k}.json \
--test_file=dev.${k}.json \
--model_type=nezha_span \
--model_name_or_path=/home/louishsu/NewDisk/Garage/weights/transformers/nezha-cn-base/ \
--do_train \
--overwrite_output_dir \
--evaluate_during_training \
--evaluate_each_epoch \
--save_best_checkpoints \
--max_span_length=40 \
--width_embedding_dim=128 \
--train_max_seq_length=512 \
--eval_max_seq_length=512 \
--do_lower_case \
--per_gpu_train_batch_size=8 \
--per_gpu_eval_batch_size=16 \
--gradient_accumulation_steps=2 \
--learning_rate=5e-5 \
--other_learning_rate=1e-3 \
--num_train_epochs=4.0 \
--warmup_proportion=0.1 \
--rdrop_alpha=0.1 \
--do_fgm --fgm_epsilon=1.0 \
--seed=42
done
# main_local
# avg
# {'p': 0.9433435148625493, 'r': 0.949889351487191, 'f': 0.9466051170874838}
# 犯罪嫌疑人
# {'p': 0.9784416384354789, 'r': 0.983134767136005, 'f': 0.9807825885621673}
# 受害人
# {'p': 0.946600434647625, 'r': 0.981016731016731, 'f': 0.9635013430241745}
# 被盗货币
# {'p': 0.8947368421052632, 'r': 0.9475409836065574, 'f': 0.9203821656050956}
# 物品价值
# {'p': 0.9842931937172775, 'r': 0.9894736842105263, 'f': 0.9868766404199475}
# 盗窃获利
# {'p': 0.9301397205588823, 'r': 0.9688149688149689, 'f': 0.9490835030549899}
# 被盗物品
# {'p': 0.9147816938453446, 'r': 0.9024390243902439, 'f': 0.9085684430512017}
# 作案工具
# {'p': 0.8863636363636364, 'r': 0.9551020408163265, 'f': 0.9194499017681729}
# 时间
# {'p': 0.955661414437523, 'r': 0.9432188065099457, 'f': 0.9493993447397161}
# 地点
# {'p': 0.9213002566295979, 'r': 0.9186806937731021, 'f': 0.9199886104783599}
# 组织机构
# {'p': 0.9203860072376358, 'r': 0.9466501240694789, 'f': 0.9333333333333333}

# for k in 0 1 2 3 4
# do
# python run_span.py \
Expand Down

0 comments on commit c328067

Please sign in to comment.