2021/9/20 1. Add prepare_corpus.py, run_mlm_wwm.py, and run_mlm_wwm.sh for pre-training; the related hyperparameters are still to be determined. 2. Rework the nezha.modeling_nezha.NeZhaForMaskedLM implementation, adding a return_dict parameter.
louishsu committed Sep 20, 2021
1 parent 185a302 commit e62585d
Showing 4 changed files with 281 additions and 34 deletions.
33 changes: 27 additions & 6 deletions nezha/modeling_nezha.py
@@ -9,7 +9,11 @@
from .configuration_nezha import NeZhaConfig
from transformers.file_utils import add_start_docstrings, add_start_docstrings_to_model_forward
from transformers.modeling_utils import PreTrainedModel, prune_linear_layer
from transformers.modeling_outputs import BaseModelOutputWithPoolingAndCrossAttentions, BaseModelOutputWithPastAndCrossAttentions
from transformers.modeling_outputs import (
BaseModelOutputWithPoolingAndCrossAttentions,
BaseModelOutputWithPastAndCrossAttentions,
MaskedLMOutput,
)
from transformers.models.bert.modeling_bert import (
BertOutput,
BertPooler,
@@ -710,7 +714,10 @@ def forward(
inputs_embeds=None,
encoder_hidden_states=None,
encoder_attention_mask=None,
output_attentions=None,
output_hidden_states=None,
labels=None,
return_dict=None,
):
r"""
masked_lm_labels (:obj:`torch.LongTensor` of shape :obj:`(batch_size, sequence_length)`, `optional`, defaults to :obj:`None`):
@@ -758,33 +765,47 @@ def forward(
loss, prediction_scores = outputs[:2]
"""
return_dict = return_dict if return_dict is not None else self.config.use_return_dict

outputs = self.bert(
input_ids,
attention_mask=attention_mask,
token_type_ids=token_type_ids,
head_mask=head_mask,
position_ids=position_ids,
inputs_embeds=inputs_embeds,
encoder_hidden_states=encoder_hidden_states,
encoder_attention_mask=encoder_attention_mask,
output_attentions=output_attentions,
output_hidden_states=output_hidden_states,
return_dict=return_dict,
)

sequence_output = outputs[0]
prediction_scores = self.cls(sequence_output)
outputs = (prediction_scores,) + outputs[2:] # Add hidden states and attention if they are here

# Although this may seem awkward, BertForMaskedLM supports two scenarios:
# 1. If a tensor that contains the indices of masked labels is provided,
# the cross-entropy is the MLM cross-entropy that measures the likelihood
# of predictions for masked words.
# 2. If `lm_labels` is provided we are in a causal scenario where we
# try to predict the next token for each input in the decoder.
masked_lm_labels = None
masked_lm_loss = None
if labels is not None:
loss_fct = CrossEntropyLoss() # -100 index = padding token
masked_lm_loss = loss_fct(prediction_scores.view(-1, self.config.vocab_size), labels.view(-1))
outputs = (masked_lm_loss,) + outputs
return outputs # (ltr_lm_loss), (masked_lm_loss), prediction_scores, (hidden_states), (attentions)


if not return_dict:
output = (prediction_scores,) + outputs[2:] # Add hidden states and attention if they are here
return ((masked_lm_loss,) + output) if masked_lm_loss is not None else output

return MaskedLMOutput(
loss=masked_lm_loss,
logits=prediction_scores,
hidden_states=outputs.hidden_states,
attentions=outputs.attentions,
)

def prepare_inputs_for_generation(self, input_ids, attention_mask=None, **model_kwargs):
input_shape = input_ids.shape
effective_batch_size = input_shape[0]
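For context, here is a minimal usage sketch of the reworked forward pass. The checkpoint path is hypothetical, and the labels simply reuse the input ids; real whole-word-masking pre-training would set non-masked positions to -100 so the loss ignores them.

from transformers import BertTokenizer
from nezha.modeling_nezha import NeZhaForMaskedLM

tokenizer = BertTokenizer.from_pretrained("path/to/nezha-checkpoint")  # hypothetical path
model = NeZhaForMaskedLM.from_pretrained("path/to/nezha-checkpoint")

inputs = tokenizer("南京市[MASK]花台区", return_tensors="pt")
labels = inputs["input_ids"].clone()  # toy labels; set unmasked positions to -100 in practice

# With return_dict=True the model now returns a MaskedLMOutput.
outputs = model(**inputs, labels=labels, return_dict=True)
print(outputs.loss, outputs.logits.shape)

# With return_dict=False it falls back to the legacy tuple output.
loss, logits = model(**inputs, labels=labels, return_dict=False)[:2]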
198 changes: 176 additions & 22 deletions prepare_corpus.py
@@ -1,43 +1,197 @@
# -*- coding: utf-8 -*-
'''
Description:
Version:
Author: louishsu
Github: https://github.com/isLouisHsu
E-mail: [email protected]
Date: 2021-09-19 14:53:15
LastEditTime: 2021-09-20 16:21:17
LastEditors: louishsu
FilePath: \CAIL2021-information-extraction\prepare_corpus.py
'''
import os
import re
import sys
import json
import random
from tqdm import tqdm
from collections import Counter
from argparse import ArgumentParser

def get_xxcq_corpus():
""" 信息抽取 """
...
def _process(document):
document = re.sub(r"\s+", "", document)
# document = document.translate({ord(f): ord(t) for f, t in zip(
# u',.!?[]()<>"\'', u',。!?【】()《》“‘')})
return document

def _split_doc(document):
sentences = re.split(r"[。;;]", document)
return sentences

def get_sfzy_corpus():
""" 司法摘要 """
...
def load_cail2018_corpus(filepaths):
corpus = []
for filepath in filepaths:
with open(filepath, "r", encoding="utf-8") as f:
while True:
line = f.readline()
if line == "": break
line = json.loads(line.strip())
document = line["fact"].strip()
document = _process(document)
sentences = _split_doc(document)
sentences = [sentence + "。" for sentence in sentences if len(sentence) > 0]
corpus.extend(sentences)
print(f"{sys._getframe().f_code.co_name} #{len(corpus)}")
return corpus

def get_sfks_corpus():
""" 司法考试 """
...
def load_cail2020_ydlj_corpus(filepaths):
corpus = []
for filepath in filepaths:
with open(filepath, "r", encoding="utf-8") as f:
lines = json.load(f)
for line in lines:
for sentence in line["context"][0][1]:
corpus.append(_process(sentence))
print(f"{sys._getframe().f_code.co_name} #{len(corpus)}")
return corpus

def get_aqbq_corpus():
def load_cail2021_aqbq_corpus(filepaths):
""" 案情标签 """
...
corpus = []
for filepath in filepaths:
with open(filepath, "r", encoding="utf-8") as f:
lines = json.load(f)
for line in lines:
for sentence in line["content"]:
corpus.append(_process(sentence))
print(f"{sys._getframe().f_code.co_name} #{len(corpus)}")
return corpus

def get_aljs_corpus():
def load_cail2021_aljs_candidate_corpus(dirname):
""" 案类检索 """
...
corpus = []
subdirs = os.listdir(dirname)
for subdir in tqdm(subdirs, desc="Loading...", total=len(subdirs)):
if subdir.startswith("."): continue
subdir = os.path.join(dirname, subdir)
for filename in os.listdir(subdir):
filename = os.path.join(subdir, filename)
with open(filename, "r", encoding="utf-8") as f:
line = json.load(f)
for key in ["ajjbqk", "cpfxgc", "pjjg", "qw"]:
document = line.get(key, None)
if document is None: continue
document = _process(document)
sentences = _split_doc(document)
sentences = [sentence + "。" for sentence in sentences if len(sentence) > 0]
corpus.extend(sentences)
print(f"{sys._getframe().f_code.co_name} #{len(corpus)}")
return corpus

def get_bllj_corpus():
""" 辩论理解 """
...

def get_ydlj_corpus():
def load_cail2021_ydlj_corpus(filepaths):
""" 阅读理解 """
...
corpus = []
for filepath in filepaths:
with open(filepath, "r", encoding="utf-8") as f:
lines = json.load(f)["data"]
for line in lines:
document = line["paragraphs"][0]["context"]
sentences = _split_doc(document)
sentences = [sentence + "。" for sentence in sentences if len(sentence) > 0]
corpus.extend(sentences)
print(f"{sys._getframe().f_code.co_name} #{len(corpus)}")
return corpus

def load_cail2021_xxcq_corpus(filepaths):
""" 信息抽取 """
corpus = []
for filepath in filepaths:
with open(filepath, "r", encoding="utf-8") as f:
while True:
line = f.readline()
if line == "": break
line = json.loads(line.strip())
document = line["context"].strip()
document = _process(document)
sentences = _split_doc(document)
sentences = [sentence + "。" for sentence in sentences if len(sentence) > 0]
corpus.extend(sentences)
print(f"{sys._getframe().f_code.co_name} #{len(corpus)}")
return corpus

def main(args):
args.output_dir = os.path.join(args.output_dir, f"mlm-seed{args.seed}")
args.output_dir = os.path.join(args.output_dir, f"mlm-minlen{args.min_length}-maxlen{args.max_length}-seed{args.seed}")
os.makedirs(args.output_dir, exist_ok=True)

...
corpus = []
corpus.extend(
load_cail2018_corpus([
"../cail_raw_data/2018/CAIL2018_ALL_DATA/final_all_data/exercise_contest/data_train.json",
"../cail_raw_data/2018/CAIL2018_ALL_DATA/final_all_data/exercise_contest/data_valid.json",
"../cail_raw_data/2018/CAIL2018_ALL_DATA/final_all_data/exercise_contest/data_test.json",
"../cail_raw_data/2018/CAIL2018_ALL_DATA/final_all_data/first_stage/train.json",
"../cail_raw_data/2018/CAIL2018_ALL_DATA/final_all_data/first_stage/test.json",
"../cail_raw_data/2018/CAIL2018_ALL_DATA/final_all_data/restData/rest_data.json",
"../cail_raw_data/2018/CAIL2018_ALL_DATA/final_all_data/final_test.json",
]))

corpus.extend(
load_cail2020_ydlj_corpus([
"../cail_raw_data/2020/ydlj_small_data/train.json",
"../cail_raw_data/2020/ydlj_big_data/train.json",
]))

corpus.extend(
load_cail2021_aqbq_corpus([
"../cail_raw_data/2021/案情标签_第一阶段/aqbq/train.json",
]))

corpus.extend(
load_cail2021_aljs_candidate_corpus(
"../cail_raw_data/2021/类案检索_第一阶段/small/candidates/"
))

corpus.extend(
load_cail2021_ydlj_corpus([
"../cail_raw_data/2021/阅读理解_第一阶段/ydlj_cjrc3.0_small_train.json"
]))

corpus.extend(
load_cail2021_xxcq_corpus([
"../cail_raw_data/2021/信息抽取_第二阶段/xxcq_mid.json",
]))

# Keep only sentences longer than `min_length` and shorter than `max_length`
corpus = list(filter(lambda x: len(x) > args.min_length and len(x) < args.max_length, corpus))

# Statistics
lengths = list(map(len, corpus))
length_counter = Counter(lengths)
num_corpus = len(corpus)
print(f"{sys._getframe().f_code.co_name} #{len(corpus)}")
# corpus = sorted(corpus, key=lambda x: -len(x)) # for debug

# Shuffle and save
random.shuffle(corpus)
corpus = list(map(lambda x: x + "\n", corpus))
with open(os.path.join(args.output_dir, "corpus.txt"), "w", encoding="utf-8") as f:
f.writelines(corpus)
num_corpus_train = int(num_corpus * args.train_ratio)
corpus_train = corpus[: num_corpus_train]
corpus_valid = corpus[num_corpus_train: ]
with open(os.path.join(args.output_dir, "corpus.train.txt"), "w", encoding="utf-8") as f:
f.writelines(corpus_train)
with open(os.path.join(args.output_dir, "corpus.valid.txt"), "w", encoding="utf-8") as f:
f.writelines(corpus_valid)

if __name__ == '__main__':
parser = ArgumentParser()
parser.add_argument("--data_dir", type=str, default="data/")
# parser.add_argument("--data_dir", type=str, default="data/")
parser.add_argument("--output_dir", type=str, default="data/")
parser.add_argument("--min_length", type=int, default=20)
parser.add_argument("--max_length", type=int, default=256)
parser.add_argument("--train_ratio", type=float, default=0.8)
parser.add_argument("--seed", default=42, type=int, help="Seed.")
args = parser.parse_args()

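Taken together, `_process` and `_split_doc` strip all whitespace and then cut on sentence-final punctuation before the trailing `。` is restored. A self-contained sketch of that pipeline, with an invented sample document:

import re

def _process(document):
    return re.sub(r"\s+", "", document)

def _split_doc(document):
    return re.split(r"[。;;]", document)

document = "被告人张某于2020年1月实施盗窃; 经审理查明。 事实清楚。"
sentences = [s + "。" for s in _split_doc(_process(document)) if len(s) > 0]
print(sentences)
# ['被告人张某于2020年1月实施盗窃。', '经审理查明。', '事实清楚。']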
20 changes: 14 additions & 6 deletions run_mlm_wwm.py
@@ -46,11 +46,17 @@
)
from transformers.trainer_utils import get_last_checkpoint, is_main_process

from transformers import BertConfig, BertTokenizer
from nezha.modeling_nezha import NeZhaForMaskedLM

logger = logging.getLogger(__name__)
MODEL_CONFIG_CLASSES = list(MODEL_FOR_MASKED_LM_MAPPING.keys())
MODEL_TYPES = tuple(conf.model_type for conf in MODEL_CONFIG_CLASSES)

MODEL_CLASSES = {
"default": (AutoConfig, AutoModelForMaskedLM, AutoTokenizer),
"nezha": (BertConfig, NeZhaForMaskedLM, BertTokenizer),
}

@dataclass
class ModelArguments:
@@ -263,15 +269,17 @@ def main():
# Distributed training:
# The .from_pretrained methods guarantee that only one local process can concurrently
# download model & vocab.
model_args.model_type = "default" if model_args.model_type is None else model_args.model_type
config_class, model_class, tokenizer_class = MODEL_CLASSES[model_args.model_type]
config_kwargs = {
"cache_dir": model_args.cache_dir,
"revision": model_args.model_revision,
"use_auth_token": True if model_args.use_auth_token else None,
}
if model_args.config_name:
config = AutoConfig.from_pretrained(model_args.config_name, **config_kwargs)
config = config_class.from_pretrained(model_args.config_name, **config_kwargs)
elif model_args.model_name_or_path:
config = AutoConfig.from_pretrained(model_args.model_name_or_path, **config_kwargs)
config = config_class.from_pretrained(model_args.model_name_or_path, **config_kwargs)
else:
config = CONFIG_MAPPING[model_args.model_type]()
logger.warning("You are instantiating a new config instance from scratch.")
@@ -283,17 +291,17 @@ def main():
"use_auth_token": True if model_args.use_auth_token else None,
}
if model_args.tokenizer_name:
tokenizer = AutoTokenizer.from_pretrained(model_args.tokenizer_name, **tokenizer_kwargs)
tokenizer = tokenizer_class.from_pretrained(model_args.tokenizer_name, **tokenizer_kwargs)
elif model_args.model_name_or_path:
tokenizer = AutoTokenizer.from_pretrained(model_args.model_name_or_path, **tokenizer_kwargs)
tokenizer = tokenizer_class.from_pretrained(model_args.model_name_or_path, **tokenizer_kwargs)
else:
raise ValueError(
"You are instantiating a new tokenizer from scratch. This is not supported by this script."
"You can do it from another script, save it, and load it from here, using --tokenizer_name."
)

if model_args.model_name_or_path:
model = AutoModelForMaskedLM.from_pretrained(
model = model_class.from_pretrained(
model_args.model_name_or_path,
from_tf=bool(".ckpt" in model_args.model_name_or_path),
config=config,
Expand All @@ -303,7 +311,7 @@ def main():
)
else:
logger.info("Training new model from scratch")
model = AutoModelForMaskedLM.from_config(config)
model = model_class.from_config(config)

model.resize_token_embeddings(len(tokenizer))

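The MODEL_CLASSES table above swaps the Auto* classes for an explicit per-model-type dispatch, presumably because the local NeZha implementation is not registered with the Auto* mappings. A minimal sketch of what `--model_type nezha` resolves to (checkpoint path hypothetical):

from transformers import BertConfig, BertTokenizer
from nezha.modeling_nezha import NeZhaForMaskedLM

MODEL_CLASSES = {
    "nezha": (BertConfig, NeZhaForMaskedLM, BertTokenizer),
}

config_class, model_class, tokenizer_class = MODEL_CLASSES["nezha"]
config = config_class.from_pretrained("path/to/nezha-checkpoint")  # hypothetical path
tokenizer = tokenizer_class.from_pretrained("path/to/nezha-checkpoint")
model = model_class.from_pretrained("path/to/nezha-checkpoint", config=config)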
Diff for the fourth changed file, run_mlm_wwm.sh, not shown here.