Dev liuyibo #70
base: main
python gts_engine/gts_engine_inference.py \
--task_dir=$TASK_DIR \
--engine_type=qiankunding \
--task_type=generation \
Change the task type to keyphrase_generation.
--engine_type=qiankunding \
--train_mode=standard \
--task_dir=$TASK_DIR \
--task_type=generation \
Change the task type to keyphrase_generation.
--data_dir=$WORK_DIR/examples/text_generation \
--save_path=$TASK_DIR/outputs \
--pretrained_model_dir=$PRETRAINED_DIR \
--train_batchsize=32 \
Have you measured how much GPU memory a batch size this large uses?
@@ -0,0 +1,141 @@
import json
Change the directory to: gts_engine/qiankunding/dataloaders/keyphrase_generation
@@ -0,0 +1,181 @@
from genericpath import exists
Change it to: gts_engine/qiankunding/models/keyphrase_generation
logger = Logger().get_log()


class T5KG(BaseModel):
Let's call it KeyphraseGenerationT5; the name isn't that long anyway.
inputs = self.train_inputs(batch)
outputs = self.model.generate(
input_ids = inputs['input_ids'],
max_length = 32,
Was this 32 picked arbitrarily?
outputs = self.model.generate(
input_ids = inputs['input_ids'],
max_length=32
Make this 32 a class member variable.
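One possible shape for this, as a minimal sketch rather than the PR's actual code: the generation length becomes a class-level default that both call sites read, with an optional per-instance override. The attribute name `max_generate_length`, the constructor signature, and the `RecordingModel` stub are assumptions made for illustration; the real class extends `BaseModel`, which is left out here so the snippet runs on its own.

```python
class KeyphraseGenerationT5:
    """Illustrative stand-in; the PR's class extends BaseModel (omitted here)."""

    # Hypothetical class-level default; 32 is the value hard-coded in the diff.
    max_generate_length = 32

    def __init__(self, model, max_generate_length=None):
        self.model = model
        # An instance may override the class default without touching call sites.
        if max_generate_length is not None:
            self.max_generate_length = max_generate_length

    def generate(self, inputs):
        # Every call site reads the same attribute instead of a literal 32.
        return self.model.generate(
            input_ids=inputs["input_ids"],
            max_length=self.max_generate_length,
        )


class RecordingModel:
    """Hypothetical stub that returns the keyword arguments it receives."""

    def generate(self, **kwargs):
        return kwargs
```

With this layout, changing the limit for validation and prediction is a one-line edit, and the value can later be wired to a command-line argument if needed.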
results.append(pred)
TP, total_pred, total_true = 0, 0, 0
The computation here is nearly the same as validation_step in the model above; why not factor it out into a shared function?
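A shared helper could look roughly like the sketch below. It assumes each prediction and each reference is a list of keyphrase strings and that matching is exact after deduplication; the PR's actual matching rule is not visible in this hunk, so treat both as assumptions. The function names `count_keyphrase_matches` and `precision_recall_f1` are hypothetical.

```python
from typing import Sequence, Tuple


def count_keyphrase_matches(
    preds: Sequence[Sequence[str]], golds: Sequence[Sequence[str]]
) -> Tuple[int, int, int]:
    """Accumulate TP, total_pred and total_true over a batch of examples."""
    tp = total_pred = total_true = 0
    for pred, gold in zip(preds, golds):
        # Assumption: exact string match, duplicates counted once.
        pred_set, gold_set = set(pred), set(gold)
        tp += len(pred_set & gold_set)
        total_pred += len(pred_set)
        total_true += len(gold_set)
    return tp, total_pred, total_true


def precision_recall_f1(
    tp: int, total_pred: int, total_true: int
) -> Tuple[float, float, float]:
    """Turn the three counts into precision, recall and F1, guarding zero division."""
    precision = tp / total_pred if total_pred else 0.0
    recall = tp / total_true if total_true else 0.0
    denom = precision + recall
    f1 = 2 * precision * recall / denom if denom else 0.0
    return precision, recall, f1
```

Both the model's validation_step and the evaluator could then call the same two functions, so the counting logic lives in one place.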
special_tokens += [f"[choice{i+1}]" for i in range(200)]
# special_tokens += [f"{i+1}" for i in range(200)]

print("pretrained_model_path", pretrained_model_path)
Remove output that serves no purpose.
1. Add the Keyphrase Generation project.
2. Add the Keyphrase Generation launch scripts and a small amount of data under examples/text_generation/.
3. Add gts_engine/pipelines/qiankunding_generation.py, the main code.
4. Add the data-processing code in gts_engine/qiankunding/dataloaders/text_generation/dataloader_kgt5.py.
5. Add the model code in gts_engine/qiankunding/models/text_generation/t5_kg.py.
6. Add the TextGenerateEvaluator class to gts_engine/qiankunding/utils/evaluation.py for evaluating generation.
7. Add T5 tokenization to gts_engine/qiankunding/utils/tokenization.py.
8. Running examples/text_generation/run_train_qiankunding.sh succeeds, as shown below.
9. Running examples/text_generation/run_inference_qiankunding.sh succeeds, as shown below.