-
Notifications
You must be signed in to change notification settings - Fork 9
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
Dev liuyibo #70
base: main
Are you sure you want to change the base?
Dev liuyibo #70
Changes from all commits
File filter
Filter by extension
Conversations
Jump to
Diff view
Diff view
There are no files selected for viewing
Large diffs are not rendered by default.
Large diffs are not rendered by default.
Large diffs are not rendered by default.
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,21 @@ | ||
#!/bin/bash
# Run inference for the text-generation (keyphrase generation) example task.
# Requires that the training script has already created the task directory.

# Fail fast on errors, unset variables, and pipeline failures.
set -euo pipefail

# Repository root: three levels up from this script's resolved location.
WORK_DIR=$(dirname "$(dirname "$(dirname "$(readlink -f "$0")")")")
echo "working directory: $WORK_DIR"

cd "$WORK_DIR"

TASK_DIR="$WORK_DIR/tasks/text_generation_example"

# The task directory is produced by training; inference cannot run without it.
if [ ! -d "$TASK_DIR" ]; then
    echo "task dir $TASK_DIR not exists, please train first."
    exit 1
fi

export CUDA_VISIBLE_DEVICES=0
# NOTE(review): output file name "tnews_output.json" looks copied from the
# tnews classification example — consider renaming for this task.
python gts_engine/gts_engine_inference.py \
    --task_dir="$TASK_DIR" \
    --engine_type=qiankunding \
    --task_type=generation \
    --input_path=examples/text_generation/kpg_test.json \
    --output_path=tnews_output.json
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,31 @@ | ||
#!/bin/bash
# Train the qiankunding keyphrase-generation example task with the
# Randeng T5 model, then evaluate on the test split.

# Fail fast on errors, unset variables, and pipeline failures.
set -euo pipefail

# Repository root: three levels up from this script's resolved location.
WORK_DIR=$(dirname "$(dirname "$(dirname "$(readlink -f "$0")")")")
echo "working directory: $WORK_DIR"

cd "$WORK_DIR"
mkdir -p "$WORK_DIR/tasks"
mkdir -p "$WORK_DIR/pretrained"

PRETRAINED_DIR="$WORK_DIR/pretrained"
TASK_DIR="$WORK_DIR/tasks/text_generation_example"
mkdir -p "$TASK_DIR"

export CUDA_VISIBLE_DEVICES=1
# NOTE(review): task_type may need to be "keyphrase_generation" instead of
# "generation" — confirm against the pipeline registry.
# NOTE(review): verify GPU memory usage at train_batchsize=32 before merging.
python gts_engine/gts_engine_train.py \
    --engine_type=qiankunding \
    --train_mode=standard \
    --task_dir="$TASK_DIR" \
    --task_type=generation \
    --train_data=kpg_train.json \
    --valid_data=kpg_val.json \
    --test_data=kpg_test.json \
    --data_dir="$WORK_DIR/examples/text_generation" \
    --save_path="$TASK_DIR/outputs" \
    --pretrained_model_dir="$PRETRAINED_DIR" \
    --train_batchsize=32 \
    --valid_batchsize=32 \
    --max_len=256 \
    --max_epochs=2 \
    --min_epochs=2 \
    --seed=123
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,129 @@ | ||
#encoding=utf8 | ||
import os | ||
import sys | ||
import json | ||
import pickle | ||
import shutil | ||
import numpy as np | ||
from torch.utils.data import DataLoader | ||
from transformers import BertTokenizer, MegatronBertForMaskedLM | ||
|
||
from gts_common.registry import PIPELINE_REGISTRY | ||
from gts_common.pipeline_utils import download_model_from_huggingface, generate_common_trainer, load_args, save_args | ||
from qiankunding.utils.tokenization import get_t5_tokenizer | ||
from qiankunding.utils import knn_utils | ||
from qiankunding.dataloaders.text_generation.dataloader_kgt5 import TaskDataModelKGT5, TaskDatasetKGT5, kg_collate_fn | ||
from qiankunding.models.text_generation.t5_kg import T5KG | ||
from qiankunding.utils.evaluation import TextGenerateEvaluator | ||
from qiankunding.utils.utils import json2list, list2json | ||
from gts_common.logs_utils import Logger | ||
from transformers import T5ForConditionalGeneration, BertTokenizer, T5Tokenizer | ||
|
||
logger = Logger().get_log() | ||
|
||
|
||
def train_generation(args):
    """Fine-tune the Randeng T5 keyphrase-generation model and evaluate it.

    Downloads the pretrained checkpoint if it is not cached locally, trains
    with a common Lightning trainer, and — when test data is configured —
    evaluates the best checkpoint and writes the test F1 back into the
    task's task_info.json.

    Args:
        args: parsed pipeline arguments; must provide pretrained_model_dir,
            save_path, task_dir, and optionally test_data.
    """
    model_name = "Randeng-T5-Keyphrase-Generation-Sci"
    # Download pretrained model from HuggingFace if it does not exist locally.
    download_model_from_huggingface(
        args.pretrained_model_dir, model_name,
        model_class=T5ForConditionalGeneration, tokenizer_class=T5Tokenizer)
    # Point the trainer at the locally cached pretrained weights.
    args.pretrained_model = os.path.join(args.pretrained_model_dir, model_name)
    # Save the tokenizer next to the outputs so inference can reload it later.
    tokenizer = get_t5_tokenizer(args=args)
    tokenizer.save_pretrained(args.save_path)
    # Build data module and model, then train.
    data_model = TaskDataModelKGT5(args, tokenizer)
    model = T5KG(args, tokenizer)
    trainer, checkpoint = generate_common_trainer(args, args.save_path)
    trainer.fit(model, data_model)
    # Path of the checkpoint that scored best on the validation set.
    checkpoint_path = checkpoint.best_model_path

    if args.test_data:
        output_save_path = os.path.join(args.save_path, 'predictions/')
        # exist_ok avoids the check-then-create race of the explicit exists test.
        os.makedirs(output_save_path, exist_ok=True)

        # Evaluation on the test split with the best checkpoint.
        logger.info("Load checkpoint from {}".format(checkpoint_path))
        model = T5KG.load_from_checkpoint(checkpoint_path, tokenizer=tokenizer)
        model.cuda()
        model.eval()

        evaluator = TextGenerateEvaluator(args, model, data_model, output_save_path)
        test_f1 = evaluator.evaluation(mode='test', data_set="test")

        # Record the test score; context managers close the file handles
        # (the original json.load(open(...)) leaked one).
        task_info_path = os.path.join(args.task_dir, "task_info.json")
        with open(task_info_path) as f:
            task_info = json.load(f)
        task_info["test_f1"] = test_f1
        with open(task_info_path, mode="w") as f:
            json.dump(task_info, f, indent=4)
|
||
|
||
@PIPELINE_REGISTRY.register(suffix=__name__)
def train_pipeline(args):
    """Registered entry point for standard training: persist args, then train."""
    # Snapshot the arguments alongside the task outputs before training starts.
    args = save_args(args)
    logger.info("******start standard train******")
    train_generation(args)
|
||
|
||
@PIPELINE_REGISTRY.register(suffix=__name__)
def prepare_inference(save_path):
    """Load everything needed to run inference for the trained KG-T5 task.

    Args:
        save_path: directory containing the saved args, the tokenizer files
            written by training, and best_model.ckpt.

    Returns:
        dict with keys "tokenizer", "model" (on GPU, eval mode) and "args".
    """
    # load args
    args = load_args(save_path)

    # Reload the tokenizer from the directory training saved it to.
    # (The old log mentioned vocab.txt, but T5Tokenizer loads SentencePiece
    # files from the directory, not a BERT-style vocab.txt.)
    logger.info("Load tokenizer from {}".format(save_path))
    inference_tokenizer = T5Tokenizer.from_pretrained(save_path)

    # Restore the best checkpoint and freeze it for inference.
    checkpoint_path = os.path.join(save_path, "best_model.ckpt")
    inference_model = T5KG.load_from_checkpoint(checkpoint_path, tokenizer=inference_tokenizer)
    inference_model.eval()
    inference_model = inference_model.cuda()

    inference_suite = {
        "tokenizer": inference_tokenizer,
        "model": inference_model,
        "args": args,
    }
    return inference_suite
|
||
@PIPELINE_REGISTRY.register(suffix=__name__)
def inference(samples, inference_suite):
    """Run batched keyphrase-generation inference over raw samples.

    Args:
        samples: list of dicts, each with "content" and optionally "label".
        inference_suite: dict produced by prepare_inference ("tokenizer",
            "model", "args").

    Returns:
        dict {"predictions": [...]} with one generated string per input
        sample, in input order.
    """
    # Normalize raw inputs to the dataset's expected schema. "label" may be
    # absent at inference time, so default to an empty string instead of
    # crashing with KeyError. (Removed an unused classification-style
    # "question" prompt left over from another pipeline.)
    inner_samples = [
        {"id": idx, "content": sample["content"], "label": sample.get("label", "")}
        for idx, sample in enumerate(samples)
    ]

    dataset = TaskDatasetKGT5(
        data_path=None,
        args=inference_suite["args"],
        tokenizer=inference_suite["tokenizer"],
        load_from_list=True,
        samples=inner_samples,
    )

    dataloader = DataLoader(
        dataset,
        shuffle=False,
        collate_fn=kg_collate_fn,
        batch_size=inference_suite["args"].valid_batchsize,
    )

    pred_labels = []
    for batch in dataloader:
        # predict() returns (ids, sentences, predicts, labels); only the
        # decoded predictions are needed here.
        _, _, predicts, _ = inference_suite["model"].predict(batch)
        pred_labels.extend(predicts)

    result = {'predictions': pred_labels}
    return result
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,141 @@ | ||
import json | ||
Reviewer note: move this file into the directory gts_engine/qiankunding/dataloaders/keyphrase_generation.
import os | ||
import torch | ||
import torch.nn as nn | ||
import numpy as np | ||
from tqdm import tqdm | ||
import pytorch_lightning as pl | ||
from typing import Optional | ||
from torch.utils.data import DataLoader | ||
os.environ["TOKENIZERS_PARALLELISM"] = "false" | ||
|
||
from gts_common.logs_utils import Logger | ||
logger = Logger().get_log() | ||
|
||
class TaskDatasetKGT5(torch.utils.data.Dataset):
    """Dataset for T5 keyphrase generation.

    Each raw sample is a dict with "id", "content" and "label". Samples are
    tokenized lazily in __getitem__ and returned as tensors suitable for
    kg_collate_fn.
    """

    def __init__(self, data_path=None, args=None, tokenizer=None, load_from_list=False, samples=None):
        """Build the dataset.

        Args:
            data_path: path to a JSON-lines file (one sample dict per line);
                ignored when load_from_list is True.
            args: must provide max_len (tokenizer truncation length).
            tokenizer: tokenizer callable used to encode text and labels.
            load_from_list: when True, take samples from `samples` (the
                in-memory inference path) instead of reading from disk.
            samples: in-memory list of sample dicts.
        """
        super().__init__()

        self.tokenizer = tokenizer
        self.max_length = args.max_len
        self.args = args
        self.load_from_list = load_from_list
        self.samples = samples
        self.data = self.load_data(data_path, args, load_from_list, samples)

    def __len__(self):
        return len(self.data)

    def __getitem__(self, index):
        return self.encode(self.data[index])

    def load_data(self, data_path, args=None, load_from_list=False, sentences=None):
        """Return the list of raw sample dicts, from memory or from disk."""
        if load_from_list:
            # Samples are already in memory; a tqdm-wrapped copy loop is
            # pointless here — just take a defensive shallow copy.
            return list(sentences)

        samples = []
        # Stream the JSON-lines file instead of materializing it with
        # readlines(), which loaded the entire file into memory first.
        with open(data_path, 'r', encoding='utf8') as f:
            for line in tqdm(f):
                samples.append(json.loads(line))
        return samples

    def encode(self, item):
        """Tokenize one sample into model inputs and target label ids."""
        # Prompt format expected by the Randeng T5 keyphrase model.
        text = f"""这段文本的关键词是?""" + f"""【{item["content"]}】"""
        label = item["label"]

        encode_dict = self.tokenizer(text, max_length=self.max_length, padding='longest', truncation=True)
        # Targets get half the input length budget.
        decode_dict = self.tokenizer(label, max_length=self.max_length // 2, padding='longest', truncation=True)

        encoded = {
            "id": item["id"],
            "sentence": text,
            "input_ids": torch.tensor(encode_dict['input_ids']).long(),
            "attention_mask": torch.tensor(encode_dict['attention_mask']).long(),
            "labels": torch.tensor(decode_dict['input_ids']).long(),
        }
        return encoded
|
||
|
||
class TaskDataModelKGT5(pl.LightningDataModule):
    """Lightning data module wiring train/valid/test KGT5 dataset splits."""

    @staticmethod
    def add_data_specific_args(parent_args):
        """Register data-related CLI arguments on the shared parser."""
        group = parent_args.add_argument_group('TASK NAME DataModel')
        group.add_argument('--data_dir', default='./data', type=str)
        group.add_argument('--train_data', default='train.json', type=str)
        group.add_argument('--valid_data', default='dev.json', type=str)
        group.add_argument('--test_data', default='test.json', type=str)
        group.add_argument('--train_batchsize', default=16, type=int)
        group.add_argument('--valid_batchsize', default=32, type=int)
        group.add_argument('--max_len', default=128, type=int)
        return parent_args

    def __init__(self, args, tokenizer):
        super().__init__()
        self.train_batchsize = args.train_batchsize
        self.valid_batchsize = args.valid_batchsize
        self.num_workers = args.num_workers
        self.test_batchsize = args.test_batchsize
        self.tokenizer = tokenizer

        # Eagerly build all three dataset splits from files under data_dir.
        def _dataset(filename):
            return TaskDatasetKGT5(
                os.path.join(args.data_dir, filename), args, tokenizer=tokenizer)

        self.train_data = _dataset(args.train_data)
        self.valid_data = _dataset(args.valid_data)
        self.test_data = _dataset(args.test_data)

    def _loader(self, dataset, batch_size, shuffle):
        # Shared DataLoader construction for all three splits.
        return DataLoader(dataset, shuffle=shuffle, collate_fn=kg_collate_fn,
                          batch_size=batch_size, pin_memory=False,
                          num_workers=self.num_workers)

    def train_dataloader(self):
        return self._loader(self.train_data, self.train_batchsize, True)

    def val_dataloader(self):
        return self._loader(self.valid_data, self.valid_batchsize, False)

    def test_dataloader(self):
        return self._loader(self.test_data, self.test_batchsize, False)
|
||
|
||
def kg_collate_fn(batch):
    """Collate per-sample dicts into one padded batch.

    Pads input_ids and attention_mask with 0 and labels with -100 (the value
    ignored by the loss). "kpg_labels" mirrors "labels" with the -100 padding
    replaced by token id 0 so it can be fed where negative ids are invalid.

    batch = [ins1_dict, ins2_dict, ..., insN_dict]
    returns {'sentence': [...], 'input_ids': tensor, ...}
    """
    pad = nn.utils.rnn.pad_sequence

    ids = [example["id"] for example in batch]
    sentences = [example["sentence"] for example in batch]

    input_ids = pad([example["input_ids"] for example in batch],
                    batch_first=True, padding_value=0)
    attention_mask = pad([example["attention_mask"] for example in batch],
                         batch_first=True, padding_value=0)
    labels = pad([example["labels"] for example in batch],
                 batch_first=True, padding_value=-100)

    # Same targets, but with the -100 loss-ignore padding mapped to 0.
    kpg_labels = labels.masked_fill(labels < 0, 0)

    return {
        "id": ids,
        "sentence": sentences,
        "input_ids": input_ids,
        "attention_mask": attention_mask,
        "labels": labels,
        "kpg_labels": kpg_labels,
    }
Reviewer note: change the task type from "generation" to "keyphrase_generation".