Skip to content

Commit

Permalink
2021/9/25 nezha-legal-fgm1.0-lsr0.1_f82.71_p76.04_r90.66
Browse files Browse the repository at this point in the history
  • Loading branch information
louishsu committed Sep 25, 2021
1 parent d55c1ef commit 37dbe30
Show file tree
Hide file tree
Showing 4 changed files with 81 additions and 7 deletions.
15 changes: 12 additions & 3 deletions main.py
Original file line number Diff line number Diff line change
Expand Up @@ -48,16 +48,25 @@ def main():
# n_splits = 5
# seed=42
# --------------------------
version = "nezha-fgm1.0-lsr0.1"
# version = "nezha-fgm1.0-lsr0.1"
# model_type = "nezha_span"
# dataset_name = "cail_ner"
# n_splits = 5
# seed=42
# --------------------------
version = "nezha-legal-fgm1.0-lsr0.1"
model_type = "nezha_span"
dataset_name = "cail_ner"
n_splits = 5
seed=42

test_examples = None
test_batches = None
for k in range(n_splits):
model_path = f"./output/ner-{dataset_name}-{model_type}-{version}-fold{k}-{seed}/"
# for k in range(n_splits):
# model_path = f"./output/ner-{dataset_name}-{model_type}-{version}-fold{k}-{seed}/"
model_paths = [f"./output/ner-{dataset_name}-{model_type}-{version}-fold{k}-{seed}/" for k in range(n_splits)]
for k in range(len(model_paths)):
model_path = model_paths[k]
# 生成测试运行参数
json_file = os.path.join(model_path, "training_args.json")
parser = NerArgumentParser()
Expand Down
15 changes: 12 additions & 3 deletions main_local.py
Original file line number Diff line number Diff line change
Expand Up @@ -54,16 +54,25 @@ def main():
# n_splits = 5
# seed=42
# --------------------------
version = "nezha-fgm1.0-lsr0.1"
# version = "nezha-fgm1.0-lsr0.1"
# model_type = "nezha_span"
# dataset_name = "cail_ner"
# n_splits = 5
# seed=42
# --------------------------
version = "nezha-legal-fgm1.0-lsr0.1"
model_type = "nezha_span"
dataset_name = "cail_ner"
n_splits = 5
seed=42

test_examples = []
test_batches = []
for k in range(n_splits):
model_path = f"./output/ner-{dataset_name}-{model_type}-{version}-fold{k}-{seed}/"
# for k in range(n_splits):
# model_path = f"./output/ner-{dataset_name}-{model_type}-{version}-fold{k}-{seed}/"
model_paths = [f"./output/ner-{dataset_name}-{model_type}-{version}-fold{k}-{seed}/" for k in range(n_splits)]
for k in range(len(model_paths)):
model_path = model_paths[k]
# 生成测试运行参数
json_file = os.path.join(model_path, "training_args.json")
parser = NerArgumentParser()
Expand Down
2 changes: 1 addition & 1 deletion scripts/clear_before_submit.sh
Original file line number Diff line number Diff line change
@@ -1,5 +1,5 @@
rm -rf args/pred*
rm -rf tmp/*
# rm -rf tmp/*
rm -rf output/*/test*
rm -rf output/*/checkpoint-999999/

Expand Down
56 changes: 56 additions & 0 deletions scripts/run_span.sh
Original file line number Diff line number Diff line change
Expand Up @@ -475,6 +475,62 @@ done
# 组织机构
# {'p': 0.8513513513513513, 'r': 0.8598014888337469, 'f': 0.8555555555555555}

# Further-pretrain LSR
# Per-fold training for the "nezha-legal-fgm1.0-lsr0.1" run: invoke
# run_span.py once for each fold index k in 0..4. Fold k trains on
# train.${k}.json and passes dev.${k}.json as both the dev file and the
# test file. Every other flag is identical across folds (FGM enabled with
# epsilon 1.0, loss type "lsr" with label-smoothing eps 0.1, seed 42).
# NOTE(review): --model_name_or_path is a machine-specific absolute path;
# adjust it before running on another host.
for k in 0 1 2 3 4; do
    python run_span.py \
        --version="nezha-legal-fgm1.0-lsr0.1-fold${k}" \
        --data_dir=./data/ner-ctx0-5fold-seed42/ \
        --train_file="train.${k}.json" \
        --dev_file="dev.${k}.json" \
        --test_file="dev.${k}.json" \
        --model_type=nezha_span \
        --model_name_or_path=/home/louishsu/NewDisk/Code/CAIL2021/nezha-legal-cn-base-wwm/ \
        --do_train \
        --overwrite_output_dir \
        --evaluate_during_training \
        --evaluate_each_epoch \
        --save_best_checkpoints \
        --max_span_length=40 \
        --width_embedding_dim=128 \
        --train_max_seq_length=512 \
        --eval_max_seq_length=512 \
        --do_lower_case \
        --per_gpu_train_batch_size=8 \
        --per_gpu_eval_batch_size=16 \
        --gradient_accumulation_steps=2 \
        --learning_rate=5e-5 \
        --other_learning_rate=1e-3 \
        --num_train_epochs=8.0 \
        --warmup_proportion=0.1 \
        --do_fgm \
        --fgm_epsilon=1.0 \
        --loss_type=lsr \
        --label_smooth_eps=0.1 \
        --seed=42
done
# main_local
# avg
# {'p': 0.9032722314800787, 'r': 0.8945650950827051, 'f': 0.898897578441534}
# 犯罪嫌疑人
# {'p': 0.9654852190063458, 'r': 0.965186445922946, 'f': 0.9653358093469513}
# 受害人
# {'p': 0.9271829682196853, 'r': 0.9668597168597168, 'f': 0.9466057646873522}
# 被盗货币
# {'p': 0.8280590717299579, 'r': 0.8579234972677595, 'f': 0.8427267847557702}
# 物品价值
# {'p': 0.9745069745069745, 'r': 0.969377990430622, 'f': 0.9719357159990405}
# 盗窃获利
# {'p': 0.8812877263581489, 'r': 0.9106029106029107, 'f': 0.8957055214723928}
# 被盗物品
# {'p': 0.8194369732831271, 'r': 0.7905206711641585, 'f': 0.8047191406937841}
# 作案工具
# {'p': 0.7972972972972973, 'r': 0.8027210884353742, 'f': 0.7999999999999999}
# 时间
# {'p': 0.950575994054255, 'r': 0.9251356238698011, 'f': 0.937683284457478}
# 地点
# {'p': 0.8714835652946402, 'r': 0.836792721069093, 'f': 0.8537859007832898}
# 组织机构
# {'p': 0.8789407313997478, 'r': 0.8647642679900744, 'f': 0.8717948717948718}

# TODO: 全部数据 (all data — presumably train on the full dataset instead of per-fold splits; confirm)

# TODO: EMA
Expand Down

0 comments on commit 37dbe30

Please sign in to comment.