Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

add_squence_extraction #67

Open
wants to merge 4 commits into
base: master
Choose a base branch
from
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
46 changes: 46 additions & 0 deletions examples/causal_generation.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,46 @@
# -*-coding:utf-8-*-
import sys
import os
sys.path.append(os.path.abspath(os.path.join(__file__, "..", "..")))
import argparse
from openks.models import OpenKSModel


parser = argparse.ArgumentParser()
parser.add_argument("--num_epoch", type=int, default=50, help="Number of epoches for fine-tuning.")
parser.add_argument("--learning_rate", type=float, default=1e-5, help="Learning rate used to train with warmup.")
parser.add_argument("--train_data", type=str, default='../openks/data/data_for_causality_extraction/train-corpus.json', help="train data")
parser.add_argument("--test_data", type=str, default='../openks/data/data_for_causality_extraction/test-corpus.json', help="test data")

parser.add_argument("--predict_save_path", type=str, default='../openks/data/data_for_causality_extraction/predict.json', help="predict data save path")
parser.add_argument("--predict_data", type=str, default='../openks/data/data_for_causality_extraction/test-corpus.json', help="predict data")
parser.add_argument("--MLP_save_path", type=str, default='checkpoints/MLP', help="predict data")
parser.add_argument("--weight_decay", type=float, default=0.01, help="Weight decay rate for L2 regularizer.")
parser.add_argument("--warmup_proportion", type=float, default=0.1,
help="Warmup proportion params for warmup strategy")
parser.add_argument("--max_seq_len", type=int, default=512, help="Number of words of the longest seqence.")
parser.add_argument("--valid_step", type=int, default=100, help="validation step")
parser.add_argument("--skip_step", type=int, default=20, help="skip step")
parser.add_argument("--batch_size", type=int, default=16, help="Total examples' number in batch for training.")
parser.add_argument("--checkpoints", type=str, default='checkpoints/Erine', help="Directory to model checkpoint")
parser.add_argument("--init_ckpt", type=str, default='checkpoints',
help="already pretraining trigger detection model checkpoint")
parser.add_argument("--seed", type=int, default=1000, help="random seed for initialization")
parser.add_argument('--device', choices=['cpu', 'gpu'], default="gpu",
help="Select which device to train model, defaults to gpu.")

args = parser.parse_args()

platform = 'Paddle'
executor = 'Causality_Extraction'
model = 'Causality_Extraction'
print("根据配置,使用 {} 框架,{} 执行器训练 {} 模型。".format(platform, executor, model))
print("-----------------------------------------------")
# 模型训练
executor = OpenKSModel.get_module(platform, executor)
Event_Extraction = executor(args=args)
Event_Extraction.run()

print("-----------------------------------------------")

print("-----------------------------------------------")
51 changes: 51 additions & 0 deletions examples/sequence_extraction.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,51 @@
# -*-coding:utf-8-*-
import sys
import os
sys.path.append(os.path.abspath(os.path.join(__file__, "..", "..")))
import argparse
from openks.models import OpenKSModel



parser = argparse.ArgumentParser()
parser = argparse.ArgumentParser()
parser.add_argument("--num_epoch", type=int, default=5, help="Number of epoches for fine-tuning.")
parser.add_argument("--learning_rate", type=float, default=5e-5, help="Learning rate used to train with warmup.")
parser.add_argument("--train_data_path", type=str, default='data/SeRI_mod/train_pair.dat',
help="train_data_path")
parser.add_argument("--test_data_path", type=str, default='data/SeRI_mod/test_pair.dat', help="test_data_path")

parser.add_argument("--predict_data", type=str, default='', help="predict data")
parser.add_argument("--do_train", type=ast.literal_eval, default=True, help="do train")
parser.add_argument("--do_predict", type=ast.literal_eval, default=True, help="do predict")
parser.add_argument("--weight_decay", type=float, default=0.01, help="Weight decay rate for L2 regularizer.")
parser.add_argument("--warmup_proportion", type=float, default=0.1,
help="Warmup proportion params for warmup strategy")
parser.add_argument("--max_seq_len", type=int, default=512, help="Number of words of the longest seqence.")
parser.add_argument("--valid_step", type=int, default=100, help="validation step")
parser.add_argument("--skip_step", type=int, default=20, help="skip step")
parser.add_argument("--batch_size", type=int, default=16, help="Total examples' number in batch for training.")
parser.add_argument("--checkpoints", type=str, default='checkpoints/sub_Erine',
help="Directory to model checkpoint(save model)")
parser.add_argument("--init_ckpt", type=str, default='checkpoints/sub_Erine',
help="already pretraining model checkpoint()")
parser.add_argument("--predict_save_path", type=str, default='data/predict.json', help="predict data save path")
parser.add_argument("--seed", type=int, default=1000, help="random seed for initialization")
parser.add_argument('--device', choices=['cpu', 'gpu'], default="gpu",
help="Select which device to train model, defaults to gpu.")

args = parser.parse_args()

args = parser.parse_args()

platform = 'Paddle'
executor = 'Sequence_Extraction'
model = 'Sequence_Extraction'
print("根据配置,使用 {} 框架,{} 执行器训练 {} 模型。".format(platform, executor, model))
print("-----------------------------------------------")
# 模型训练
executor = OpenKSModel.get_module(platform, executor)
Event_Extraction = executor(args=args)
Event_Extraction.run()

print("-----------------------------------------------")
Loading