Skip to content

Commit

Permalink
方案整理
Browse files Browse the repository at this point in the history
  • Loading branch information
louishsu committed Nov 21, 2021
1 parent c8775a8 commit 4b6f08e
Show file tree
Hide file tree
Showing 4 changed files with 193 additions and 24 deletions.
168 changes: 167 additions & 1 deletion README.md
Original file line number Diff line number Diff line change
@@ -1,6 +1,6 @@
# CAIL2021——司法文本信息抽取

[报名链接](http://cail.cipsc.org.cn/)
[报名链接](http://cail.cipsc.org.cn/),方案详细介绍见[中国法律智能技术评测(CAIL2021):信息抽取(Rank2) - LOUIS' BLOG](https://louishsu.xyz/2021/10/22/%E4%B8%AD%E5%9B%BD%E6%B3%95%E5%BE%8B%E6%99%BA%E8%83%BD%E6%8A%80%E6%9C%AF%E8%AF%84%E6%B5%8B(CAIL2021)%EF%BC%9A%E4%BF%A1%E6%81%AF%E6%8A%BD%E5%8F%96(Rank2).html)

## 数据说明

Expand Down Expand Up @@ -30,3 +30,169 @@
|NT|时间|
|NS|地点|
|NO|组织机构|

## 程序运行

#### 准备数据
``` sh
$ tree ../cail_raw_data/ -d
../cail_raw_data/
├── 2018
│ └── CAIL2018_ALL_DATA
│ └── final_all_data
│ ├── exercise_contest
│ ├── first_stage
│ └── restData
├── 2020
│ ├── ydlj_big_data
│ └── ydlj_small_data
└── 2021
├── 信息抽取_第一阶段
├── 信息抽取_第二阶段
├── 案情标签_第一阶段
│ ├── aqbq
│ └── __MACOSX
│ └── aqbq
├── 类案检索_第一阶段
│ ├── __MACOSX
│ │ └── small
│ │ └── candidates
│ └── small
│ └── candidates
│ ├── 1325
│ ├── 1355
│ ├── 1405
│ ├── 1430
│ ├── 1972
│ ├── 1978
│ ├── 2132
│ ├── 2143
│ ├── 221
│ ├── 2331
│ ├── 2361
│ ├── 2373
│ ├── 259
│ ├── 3228
│ ├── 3342
│ ├── 3746
│ ├── 3765
│ ├── 4738
│ ├── 4794
│ └── 4829
└── 阅读理解_第一阶段

43 directories
```

#### 领域预训练
``` sh
# 生成预训练语料
python prepare_corpus.py \
--output_dir=../cail_processed_data/ \
--min_length=30 \
--max_length=256 \
--seed=42

# 分词用于wwm
python run_chinese_ref.py \
--file_name=../cail_processed_data/mlm-minlen30-maxlen256-seed42/corpus.txt \
--ltp=/home/louishsu/NewDisk/Garage/weights/ltp/base1.tgz \
--bert=/home/louishsu/NewDisk/Garage/weights/transformers/nezha-cn-base/ \
--save_path=../cail_processed_data/mlm-minlen30-maxlen256-seed42/ref.txt

# 预训练
export WANDB_DISABLED=true
nohup python run_mlm_wwm.py \
--model_name_or_path=/home/louishsu/NewDisk/Garage/weights/transformers/nezha-cn-base/ \
--model_type=nezha \
--train_file=../cail_processed_data/mlm-minlen30-maxlen256-seed42/corpus.txt \
--train_ref_file=../cail_processed_data/mlm-minlen30-maxlen256-seed42/ref.txt \
--cache_dir=cache/ \
--overwrite_cache \
--max_seq_length=256 \
--preprocessing_num_workers=8 \
--mlm_probability=0.15 \
--output_dir=output/nezha-legal-cn-base-wwm/ \
--overwrite_output_dir \
--do_train \
--warmup_steps=1500 \
--max_steps=30000 \
--per_device_train_batch_size=48 \
--gradient_accumulation_steps=4 \
--label_smoothing_factor=0.0 \
--learning_rate=5e-5 \
--weight_decay=0.01 \
--logging_dir=output/nezha-legal-cn-base-wwm/log/ \
--logging_strategy=steps \
--logging_steps=1500 \
--save_strategy=steps \
--save_steps=1500 \
--save_total_limit=10 \
--dataloader_num_workers=4 \
--seed=42 \
--fp16 \
>> output/nezha-legal-cn-base-wwm.out &
```

#### 信息抽取微调

``` sh
# 生成数据
python prepare_data.py \
--data_files ./data/信息抽取_第二阶段/xxcq_mid.json \
--context_window 0 \
--n_splits 5 \
--output_dir data/ \
--seed 42

# 多折模型微调
for k in 0 1 2 3 4
do
python run_span.py \
--version=nezha-legal-fgm1.0-lsr0.1-fold${k} \
--data_dir=./data/ner-ctx0-5fold-seed42/ \
--train_file=train.${k}.json \
--dev_file=dev.${k}.json \
--test_file=dev.${k}.json \
--model_type=nezha_span \
--model_name_or_path=/home/louishsu/NewDisk/Code/CAIL2021/nezha-legal-cn-base-wwm/ \
--do_train \
--overwrite_output_dir \
--evaluate_during_training \
--evaluate_each_epoch \
--save_best_checkpoints \
--max_span_length=40 \
--width_embedding_dim=128 \
--train_max_seq_length=512 \
--eval_max_seq_length=512 \
--do_lower_case \
--per_gpu_train_batch_size=8 \
--per_gpu_eval_batch_size=16 \
--gradient_accumulation_steps=2 \
--learning_rate=5e-5 \
--other_learning_rate=1e-3 \
--num_train_epochs=8.0 \
--warmup_proportion=0.1 \
--do_fgm --fgm_epsilon=1.0 \
--loss_type=lsr --label_smooth_eps=0.1 \
--seed=42
done

# 本地线下预测得到OOF,并计算得分
python main_local.py
```

注:可通过以下命令校验划分得到的数据集是否一致
``` sh
$ cd data/ner-ctx0-5fold-seed42/
$ md5sum -c checksum
```

<!--
``` sh
$ for f in `ls`
> do
> md5sum $f >> checksum
> done
```
-->
24 changes: 12 additions & 12 deletions main.py
Original file line number Diff line number Diff line change
Expand Up @@ -18,8 +18,8 @@ def main():
local_debug = run_args.local_debug

# WARNING:以下配置需要在提交前指定!!!
# span_proba_thresh = 0.0
span_proba_thresh = 0.3
span_proba_thresh = 0.0
# span_proba_thresh = 0.3
# version = "baseline"
# model_type = "bert_span"
# dataset_name = "cail_ner"
Expand Down Expand Up @@ -56,11 +56,11 @@ def main():
# n_splits = 5
# seed=42
# --------------------------
# version = "nezha-legal-fgm1.0-lsr0.1"
# model_type = "nezha_span"
# dataset_name = "cail_ner"
# n_splits = 5
# seed=42
version = "nezha-legal-fgm1.0-lsr0.1"
model_type = "nezha_span"
dataset_name = "cail_ner"
n_splits = 5
seed=42
# --------------------------
# version = "nezha-legal-fgm1.0-lsr0.1-ema3"
# model_type = "nezha_span"
Expand All @@ -80,11 +80,11 @@ def main():
# n_splits = 5
# seed=42
# # --------------------------
version = "nezha-legal-fgm1.0-lsr0.1-v2"
model_type = "nezha_span"
dataset_name = "cail_ner"
n_splits = 5
seed=42
# version = "nezha-legal-fgm1.0-lsr0.1-v2"
# model_type = "nezha_span"
# dataset_name = "cail_ner"
# n_splits = 5
# seed=42
# seed=32
# seed=12345
# --------------------------
Expand Down
22 changes: 11 additions & 11 deletions main_local.py
Original file line number Diff line number Diff line change
Expand Up @@ -18,8 +18,8 @@ def main():
local_debug = run_args.local_debug

# WARNING:以下配置需要在提交前指定!!!
# span_proba_thresh = 0.0
span_proba_thresh = 0.3
span_proba_thresh = 0.0
# span_proba_thresh = 0.3
# version = "baseline"
# model_type = "bert_span"
# dataset_name = "cail_ner"
Expand Down Expand Up @@ -62,29 +62,29 @@ def main():
# n_splits = 5
# seed=42
# --------------------------
# version = "nezha-legal-fgm1.0-lsr0.1"
version = "nezha-legal-fgm1.0-lsr0.1"
model_type = "nezha_span"
dataset_name = "cail_ner"
n_splits = 5
seed=42
# --------------------------
# version = "nezha-legal-fgm1.0-lsr0.1-ema3"
# model_type = "nezha_span"
# dataset_name = "cail_ner"
# n_splits = 5
# seed=42
# --------------------------
# version = "nezha-legal-fgm1.0-lsr0.1-ema3"
# version = "nezha-legal-fgm2.0-lsr0.1"
# model_type = "nezha_span"
# dataset_name = "cail_ner"
# n_splits = 5
# seed=42
# --------------------------
# version = "nezha-legal-fgm2.0-lsr0.1"
# version = "nezha-legal-fgm1.0-lsr0.1-v2"
# model_type = "nezha_span"
# dataset_name = "cail_ner"
# n_splits = 5
# seed=42
# --------------------------
version = "nezha-legal-fgm1.0-lsr0.1-v2"
model_type = "nezha_span"
dataset_name = "cail_ner"
n_splits = 5
seed=42
# seed=32
# seed=12345
# --------------------------
Expand Down
3 changes: 3 additions & 0 deletions run_span.py
Original file line number Diff line number Diff line change
Expand Up @@ -1523,6 +1523,7 @@ def load_dataset(args, processor, tokenizer, data_type='train'):
if args.do_train:
config = config_class.from_pretrained(args.config_name if args.config_name else args.model_name_or_path,
num_labels=num_labels, max_span_length=args.max_span_length,
width_embedding_dim=args.width_embedding_dim,
cache_dir=args.cache_dir if args.cache_dir else None, )
tokenizer = tokenizer_class.from_pretrained(args.tokenizer_name if args.tokenizer_name else args.model_name_or_path,
do_lower_case=args.do_lower_case,
Expand All @@ -1542,6 +1543,7 @@ def load_dataset(args, processor, tokenizer, data_type='train'):
logger.info("Loading model checkpoint from %s", best_checkpoints)
config = config_class.from_pretrained(best_checkpoints,
num_labels=num_labels, max_span_length=args.max_span_length,
width_embedding_dim=args.width_embedding_dim,
cache_dir=args.cache_dir if args.cache_dir else None, )
tokenizer = tokenizer_class.from_pretrained(best_checkpoints,
do_lower_case=args.do_lower_case,
Expand All @@ -1566,6 +1568,7 @@ def load_dataset(args, processor, tokenizer, data_type='train'):
if args.do_eval and args.local_rank in [-1, 0]:
config = config_class.from_pretrained(args.output_dir,
num_labels=num_labels, max_span_length=args.max_span_length,
width_embedding_dim=args.width_embedding_dim,
cache_dir=args.cache_dir if args.cache_dir else None, )
tokenizer = tokenizer_class.from_pretrained(args.output_dir, do_lower_case=args.do_lower_case)
checkpoints = [args.output_dir]
Expand Down

0 comments on commit 4b6f08e

Please sign in to comment.