Commit
2021/9/20 1. Add prepare_corpus.py, run_mlm_wwm.py, and run_mlm_wwm.sh; pretraining-related parameters are still to be determined. 2. Modify the nezha.modeling_nezha.NeZhaForMaskedLM implementation to add a return_dict parameter.
louishsu committed Sep 20, 2021
1 parent 185a302 · commit e62585d
Showing 4 changed files with 281 additions and 34 deletions.
prepare_corpus.py
@@ -1,43 +1,197 @@
 # -*- coding: utf-8 -*-
 '''
 Description: 
 Version: 
 Author: louishsu
 Github: https://github.com/isLouisHsu
 E-mail: [email protected]
 Date: 2021-09-19 14:53:15
 LastEditTime: 2021-09-20 16:21:17
 LastEditors: louishsu
 FilePath: \CAIL2021-information-extraction\prepare_corpus.py
 '''
 import os
 import re
 import sys
 import json
 import random
 from tqdm import tqdm
 from collections import Counter
 from argparse import ArgumentParser

-def get_xxcq_corpus():
-    """ 信息抽取 (information extraction) """
-    ...
+def _process(document):
+    # collapse all whitespace; full-width punctuation mapping left disabled
+    document = re.sub(r"\s+", "", document)
+    # document = document.translate({ord(f): ord(t) for f, t in zip(
+    #     u',.!?[]()<>"\'', u',。!?【】()《》“‘')})
+    return document
+
+def _split_doc(document):
+    # split on sentence-final punctuation (。 ; ;)
+    sentences = re.split(r"[。;;]", document)
+    return sentences

-def get_sfzy_corpus():
-    """ 司法摘要 (judicial summarization) """
-    ...
+def load_cail2018_corpus(filepaths):
+    corpus = []
+    for filepath in filepaths:
+        with open(filepath, "r", encoding="utf-8") as f:
+            while True:
+                line = f.readline()
+                if line == "": break
+                line = json.loads(line.strip())
+                document = line["fact"].strip()
+                document = _process(document)
+                sentences = _split_doc(document)
+                sentences = [sentence + "。" for sentence in sentences if len(sentence) > 0]
+                corpus.extend(sentences)
+    print(f"{sys._getframe().f_code.co_name} #{len(corpus)}")
+    return corpus

-def get_sfks_corpus():
-    """ 司法考试 (judicial examination) """
-    ...
+def load_cail2020_ydlj_corpus(filepaths):
+    corpus = []
+    for filepath in filepaths:
+        with open(filepath, "r", encoding="utf-8") as f:
+            lines = json.load(f)
+            for line in lines:
+                for sentence in line["context"][0][1]:
+                    corpus.append(_process(sentence))
+    print(f"{sys._getframe().f_code.co_name} #{len(corpus)}")
+    return corpus

-def get_aqbq_corpus():
+def load_cail2021_aqbq_corpus(filepaths):
     """ 案情标签 (case tagging) """
-    ...
+    corpus = []
+    for filepath in filepaths:
+        with open(filepath, "r", encoding="utf-8") as f:
+            lines = json.load(f)
+            for line in lines:
+                for sentence in line["content"]:
+                    corpus.append(_process(sentence))
+    print(f"{sys._getframe().f_code.co_name} #{len(corpus)}")
+    return corpus

-def get_aljs_corpus():
+def load_cail2021_aljs_candidate_corpus(dirname):
     """ 案类检索 (case retrieval) """
-    ...
+    corpus = []
+    subdirs = os.listdir(dirname)
+    for subdir in tqdm(subdirs, desc="Loading...", total=len(subdirs)):
+        if subdir.startswith("."): continue
+        subdir = os.path.join(dirname, subdir)
+        for filename in os.listdir(subdir):
+            filename = os.path.join(subdir, filename)
+            with open(filename, "r", encoding="utf-8") as f:
+                line = json.load(f)
+            for key in ["ajjbqk", "cpfxgc", "pjjg", "qw"]:
+                document = line.get(key, None)
+                if document is None: continue
+                document = _process(document)
+                sentences = _split_doc(document)
+                sentences = [sentence + "。" for sentence in sentences if len(sentence) > 0]
+                corpus.extend(sentences)
+    print(f"{sys._getframe().f_code.co_name} #{len(corpus)}")
+    return corpus

 def get_bllj_corpus():
     """ 辩论理解 (debate understanding) """
     ...

-def get_ydlj_corpus():
+def load_cail2021_ydlj_corpus(filepaths):
     """ 阅读理解 (reading comprehension) """
-    ...
+    corpus = []
+    for filepath in filepaths:
+        with open(filepath, "r", encoding="utf-8") as f:
+            lines = json.load(f)["data"]
+            for line in lines:
+                document = line["paragraphs"][0]["context"]
+                sentences = _split_doc(document)
+                sentences = [sentence + "。" for sentence in sentences if len(sentence) > 0]
+                corpus.extend(sentences)
+    print(f"{sys._getframe().f_code.co_name} #{len(corpus)}")
+    return corpus
+
+def load_cail2021_xxcq_corpus(filepaths):
+    """ 信息抽取 (information extraction) """
+    corpus = []
+    for filepath in filepaths:
+        with open(filepath, "r", encoding="utf-8") as f:
+            while True:
+                line = f.readline()
+                if line == "": break
+                line = json.loads(line.strip())
+                document = line["context"].strip()
+                document = _process(document)
+                sentences = _split_doc(document)
+                sentences = [sentence + "。" for sentence in sentences if len(sentence) > 0]
+                corpus.extend(sentences)
+    print(f"{sys._getframe().f_code.co_name} #{len(corpus)}")
+    return corpus

 def main(args):
-    args.output_dir = os.path.join(args.output_dir, f"mlm-seed{args.seed}")
+    args.output_dir = os.path.join(args.output_dir, f"mlm-minlen{args.min_length}-maxlen{args.max_length}-seed{args.seed}")
     os.makedirs(args.output_dir, exist_ok=True)

-    ...
+    corpus = []
+    corpus.extend(
+        load_cail2018_corpus([
+            "../cail_raw_data/2018/CAIL2018_ALL_DATA/final_all_data/exercise_contest/data_train.json",
+            "../cail_raw_data/2018/CAIL2018_ALL_DATA/final_all_data/exercise_contest/data_valid.json",
+            "../cail_raw_data/2018/CAIL2018_ALL_DATA/final_all_data/exercise_contest/data_test.json",
+            "../cail_raw_data/2018/CAIL2018_ALL_DATA/final_all_data/first_stage/train.json",
+            "../cail_raw_data/2018/CAIL2018_ALL_DATA/final_all_data/first_stage/test.json",
+            "../cail_raw_data/2018/CAIL2018_ALL_DATA/final_all_data/restData/rest_data.json",
+            "../cail_raw_data/2018/CAIL2018_ALL_DATA/final_all_data/final_test.json",
+        ]))
+    corpus.extend(
+        load_cail2020_ydlj_corpus([
+            "../cail_raw_data/2020/ydlj_small_data/train.json",
+            "../cail_raw_data/2020/ydlj_big_data/train.json",
+        ]))
+    corpus.extend(
+        load_cail2021_aqbq_corpus([
+            "../cail_raw_data/2021/案情标签_第一阶段/aqbq/train.json",
+        ]))
+    corpus.extend(
+        load_cail2021_aljs_candidate_corpus(
+            "../cail_raw_data/2021/类案检索_第一阶段/small/candidates/"
+        ))
+    corpus.extend(
+        load_cail2021_ydlj_corpus([
+            "../cail_raw_data/2021/阅读理解_第一阶段/ydlj_cjrc3.0_small_train.json"
+        ]))
+    corpus.extend(
+        load_cail2021_xxcq_corpus([
+            "../cail_raw_data/2021/信息抽取_第二阶段/xxcq_mid.json",
+        ]))
+
+    # keep only sentences longer than `min_length` and shorter than `max_length`
+    corpus = list(filter(lambda x: len(x) > args.min_length and len(x) < args.max_length, corpus))
+
+    # statistics
+    lengths = list(map(len, corpus))
+    length_counter = Counter(lengths)
+    num_corpus = len(corpus)
+    print(f"{sys._getframe().f_code.co_name} #{len(corpus)}")
+    # corpus = sorted(corpus, key=lambda x: -len(x))    # for debug
+
+    # save the full corpus, then a train/valid split by `train_ratio`
+    random.shuffle(corpus)
+    corpus = list(map(lambda x: x + "\n", corpus))
+    with open(os.path.join(args.output_dir, "corpus.txt"), "w", encoding="utf-8") as f:
+        f.writelines(corpus)
+    num_corpus_train = int(num_corpus * args.train_ratio)
+    corpus_train = corpus[: num_corpus_train]
+    corpus_valid = corpus[num_corpus_train: ]
+    with open(os.path.join(args.output_dir, "corpus.train.txt"), "w", encoding="utf-8") as f:
+        f.writelines(corpus_train)
+    with open(os.path.join(args.output_dir, "corpus.valid.txt"), "w", encoding="utf-8") as f:
+        f.writelines(corpus_valid)

 if __name__ == '__main__':
     parser = ArgumentParser()
-    parser.add_argument("--data_dir", type=str, default="data/")
+    # parser.add_argument("--data_dir", type=str, default="data/")
     parser.add_argument("--output_dir", type=str, default="data/")
+    parser.add_argument("--min_length", type=int, default=20)
+    parser.add_argument("--max_length", type=int, default=256)
+    parser.add_argument("--train_ratio", type=float, default=0.8)
     parser.add_argument("--seed", default=42, type=int, help="Seed.")
     args = parser.parse_args()

     main(args)
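With the defaults above, running python prepare_corpus.py writes corpus.txt, corpus.train.txt, and corpus.valid.txt under data/mlm-minlen20-maxlen256-seed42/. A minimal smoke test for those outputs, assuming the default arguments (the directory name simply follows the f-string in main, and this checker script is hypothetical, not part of the commit):

# check_corpus.py -- hypothetical smoke test for the files written by main()
import os

output_dir = "data/mlm-minlen20-maxlen256-seed42"  # assumes the argparse defaults
for name in ["corpus.txt", "corpus.train.txt", "corpus.valid.txt"]:
    with open(os.path.join(output_dir, name), "r", encoding="utf-8") as f:
        lines = f.read().splitlines()
    lengths = list(map(len, lines))
    # every line was filtered to the (min_length, max_length) range before saving
    print(f"{name}: {len(lines)} sentences, lengths {min(lengths)}..{max(lengths)}")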
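The commit also adds run_mlm_wwm.py and run_mlm_wwm.sh for whole-word-masking pretraining; those diffs did not render on this page, and per the commit message the pretraining parameters are still to be determined. The corpus files above are one sentence per line, the plain-text format that Hugging Face MLM example scripts consume. A quick sketch of loading them that way (paths assume the defaults above; the data-files layout is an assumption, not taken from the unseen scripts):

# Load the prepared corpus the way an MLM training script typically would.
from datasets import load_dataset  # Hugging Face `datasets` library

data_dir = "data/mlm-minlen20-maxlen256-seed42"  # assumed default output_dir
raw_datasets = load_dataset(
    "text",
    data_files={
        "train": f"{data_dir}/corpus.train.txt",
        "validation": f"{data_dir}/corpus.valid.txt",
    },
)
print(raw_datasets)  # DatasetDict with one "text" column per split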
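Finally, the commit message notes that nezha.modeling_nezha.NeZhaForMaskedLM was modified to accept a return_dict parameter; that diff is likewise not shown here. A sketch of the conventional transformers-style pattern such a change usually follows (the exact signature, submodule names like self.bert and self.cls, and tuple layout in this repo's NeZha implementation are assumptions):

# Hedged sketch only -- not the repo's actual NeZhaForMaskedLM code.
import torch.nn as nn
from transformers.modeling_outputs import MaskedLMOutput

def forward(self, input_ids=None, attention_mask=None, token_type_ids=None,
            labels=None, return_dict=None):
    # fall back to the model config when the caller does not specify
    return_dict = return_dict if return_dict is not None else self.config.use_return_dict
    outputs = self.bert(input_ids, attention_mask=attention_mask,
                        token_type_ids=token_type_ids)
    sequence_output = outputs[0]
    prediction_scores = self.cls(sequence_output)

    masked_lm_loss = None
    if labels is not None:
        loss_fct = nn.CrossEntropyLoss()
        masked_lm_loss = loss_fct(
            prediction_scores.view(-1, self.config.vocab_size), labels.view(-1))

    if not return_dict:
        # legacy tuple output: (loss?, logits, extra outputs...)
        output = (prediction_scores,) + outputs[2:]
        return ((masked_lm_loss,) + output) if masked_lm_loss is not None else output
    return MaskedLMOutput(loss=masked_lm_loss, logits=prediction_scores)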