from argparse import ArgumentParser
import enum
import sys
sys.path.insert(0, '/home/Fengshenbang-LM')
from fengshen import UbertPipelines
from fengshen.metric import metric
import finetuneutils as utils
import pandas as pd
import copy
import json
import time
from collections import Counter
import os
os.environ['CUDA_VISIBLE_DEVICES']='-1'
import logging
logger = logging.getLogger(__name__)
class EntityScore(object):
    def __init__(self):
        self.reset()
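Only the constructor of EntityScore is shown above. A minimal sketch of what the remaining methods might look like, reverse-engineered from how metric_test_res uses them below (the micro-averaged formulas and the EntityScoreSketch name are my assumptions, not the author's code):

class EntityScoreSketch(object):
    """Hypothetical stand-in: update() accumulates gold/predicted entity names,
    and result()[0] yields a dict with "acc", "recall", "f1" as consumed below."""
    def __init__(self):
        self.reset()
    def reset(self):
        self.origins, self.founds, self.rights = [], [], []
    def update(self, true_entities, pred_entities):
        self.origins.extend(true_entities)
        self.founds.extend(pred_entities)
        self.rights.extend([e for e in pred_entities if e in true_entities])
    def result(self):
        origin, found, right = len(self.origins), len(self.founds), len(self.rights)
        recall = right / origin if origin else 0.0
        acc = right / found if found else 0.0   # "acc" behaves like precision here
        f1 = 2 * acc * recall / (acc + recall) if (acc + recall) else 0.0
        return [{"acc": acc, "recall": recall, "f1": f1}]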
def metric_test_res(rawResult, test_data_label):
    test_pre_metric = []
    test_data_pred = []
    m1 = EntityScore()  # evaluation metric
    for i in range(len(test_data_label)):
        res_dic = rawResult[i]
        pred_res = [choice["entity_list"][-1]["entity_name"] for choice in res_dic["choices"] if len(choice["entity_list"]) > 0]
        test_data_pred.append(pred_res)
        true_res = test_data_label[i]
        m1.update(true_res, pred_res)
        m1_res = m1.result()[0]
        test_pre_metric.append([m1_res["acc"], m1_res["recall"], m1_res["f1"]])
    acc_list = [m[0] for m in test_pre_metric]
    recall_list = [m[1] for m in test_pre_metric]
    f1_list = [m[2] for m in test_pre_metric]
    Acc = sum(acc_list) / len(acc_list)
    Recall = sum(recall_list) / len(recall_list)
    f1 = sum(f1_list) / len(f1_list)
    print(f"Accuracy: {Acc}, recall: {Recall}, F1: {f1}")
    metric_res = {"Acc": Acc, "Recall": Recall, "F1": f1}
    TestText = [tdl["text"] for tdl in rawResult]
    dic1 = {"text": TestText, "truelabel": test_data_label, "predictlabel": test_data_pred}
    testResultFile = pd.DataFrame(dic1)
    return metric_res, testResultFile
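For reference, a toy invocation, assuming rawResult follows the Ubert output layout used above; the text and entity values below are illustrative only:

toy_raw = [{
    "text": "患者服用阿司匹林。",
    "choices": [
        {"entity_type": "药品名称", "entity_list": [{"entity_name": "阿司匹林"}]},
        {"entity_type": "疾病名称", "entity_list": []},
    ],
}]
toy_labels = [["阿司匹林"]]
metric_res, result_df = metric_test_res(toy_raw, toy_labels)
# result_df holds one row per text with "text", "truelabel", "predictlabel" columns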
def init_model(config, state=False):
    total_parser = ArgumentParser("TASK NAME")
    total_parser = UbertPipelines.pipelines_args(total_parser)
    args, _ = total_parser.parse_known_args()
    args.pretrained_model_path = config['pretrained_model_path']
    args.default_root_dir = config["default_root_dir"]  # default root dir for logs, tensorboard, etc.
    args.checkpoint_path = config["ckpt_save_path"]
    if state:
        args.load_checkpoints_path = config["ckpt_load_path"]
        print(f"Loaded checkpoint {config['ckpt_load_path']} successfully")
    args.learning_rate = config["learning_rate"]
    args.max_epochs = config["max_epochs"]
    args.batch_size = config["batch_size"]
    args.num_works = config["num_works"]
    args.max_length = config["max_length"]
    args.save_weights_only = False
    args.save_top_k = 1
    args.log_every_n_steps = 5
    model = UbertPipelines(args)
    return model
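An illustrative config dict for init_model: every key below is read by the function above, but the concrete values and paths are placeholders, not the author's settings:

config = {
    "pretrained_model_path": "IDEA-CCNL/Erlangshen-Ubert-110M-Chinese",  # placeholder model name
    "default_root_dir": "./logs",
    "ckpt_save_path": "./ckpt",
    "ckpt_load_path": "./ckpt/last.ckpt",
    "learning_rate": 2e-5,
    "max_epochs": 10,
    "batch_size": 8,
    "num_works": 4,      # matches the args.num_works field set above
    "max_length": 512,
}
model = init_model(config, state=True)  # state=True re-imports the saved checkpoint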
def model_predict(model, train_data_list, test_data_list, dev_data_list, test_data_label, config):
    """Load the model and run few-shot fine-tuning."""
    print("-------------------------- Test-set metrics before fine-tuning --------------------------", end="\n\n\n")
    test_data_list1 = copy.deepcopy(test_data_list)
    test_data_list2 = copy.deepcopy(test_data_list)
    rawResult1 = model.predict(test_data_list, cuda=False)
    # print(f"rawResult1:\n{rawResult1}")
    # print(f"test_data_label:\n{test_data_label}")
    metric_res1, testResultFile1 = metric_test_res(rawResult1, test_data_label)
    testResultFile1.to_csv("./output_319(未训练).csv", encoding='utf_8_sig')
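The rest of model_predict is not shown. A hedged sketch of the continuation, matching steps 2-4 of the problem description further down (fit/predict follow the fengshen Ubert examples; re-building the model via init_model to reload last.ckpt is my assumption about how step 4 was done):

    model.fit(train_data_list, dev_data_list)                  # step 2: few-shot fine-tuning
    rawResult2 = model.predict(test_data_list1, cuda=False)    # step 3: re-test after training
    metric_res2, testResultFile2 = metric_test_res(rawResult2, test_data_label)
    # step 4 (assumed): reload the freshest last.ckpt and evaluate once more
    reloaded = init_model(config, state=True)
    rawResult3 = reloaded.predict(test_data_list2, cuda=False)
    metric_res3, _ = metric_test_res(rawResult3, test_data_label)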
def pack(data_json_path, fileDirectory, template, do_label):
    """Wrap the annotated texts into the Ubert input format; attach labels when do_label is True."""
    dataDirectory = os.path.dirname(fileDirectory)
    with open(data_json_path, 'r') as f:
        data = json.load(f)
    train_datas = data["data"]
    data_list = []
    # sorted() keeps the txt/annotation pairing deterministic; bare os.listdir() order is arbitrary
    txt_files = [os.path.join(dataDirectory, f) for f in sorted(os.listdir(dataDirectory)) if f.endswith(".txt")]
    txt_filenames = [os.path.basename(f) for f in txt_files]
    for i, train_data in enumerate(train_datas):
        file_path = os.path.join(dataDirectory, txt_filenames[i])
        Dic_entity_list = train_data["preLabel"]["results"][0]["preLabel"]["results"]
        Temp = copy.deepcopy(template)
        Temp["id"] = i
        with open(file_path, "r") as tf:
            Temp["text"] = tf.read()
        if do_label:
            for el in Dic_entity_list:
                for j in range(len(Temp["choices"])):
                    if Temp["choices"][j]["entity_type"] == el["tagName"]:
                        dic_choice = {}
                        dic_choice["entity_type"] = el["tagName"]
                        dic_choice["entity_name"] = el["textSelectResults"]
                        dic_choice["entity_idx"] = [[el["startIndex"], el["endIndex"]]]
                        Temp["choices"][j]["entity_list"].append(dic_choice)
        data_list.append(Temp)
    return data_list
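pack expects template to look like the config's UbertTemplate entry. An illustrative shape, where the field names match what the code above reads but the task headers and entity types are hypothetical:

template = {
    "task_type": "抽取任务",        # hypothetical Ubert task header
    "subtask_type": "实体识别",
    "id": 0,
    "text": "",
    "choices": [
        {"entity_type": "药品名称", "entity_list": []},
        {"entity_type": "疾病名称", "entity_list": []},
    ],
}
train_data_list = pack("data/finetune/train/annotation.json",
                       "data/finetune/train/annotation.json", template, do_label=True)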
def Data_Read_Process(config):
    """Read the json-format training data from the given paths and wrap it."""
    ## Load the training and test sets
    current_dir = os.path.dirname(os.path.abspath(__file__))  # current directory
    train_data_path = os.path.join(current_dir, "data/finetune/train/annotation.json")
    test_data_path = os.path.join(current_dir, "data/finetune/val/annotation.json")
    template = config["UbertTemplate"]
    train_data_list = pack(train_data_path, train_data_path, template, True)
    test_data_list = pack(test_data_path, test_data_path, template, False)
    print(f"Test set size: {len(test_data_list)}")
    dev_data_list = train_data_list[:20]
    ## Build the ground-truth entities for the test set
    with open(test_data_path, 'r') as f:
        test_datas = json.load(f)
    test_true_entity = []
    for data_dic in test_datas["data"]:
        one_text_entity = []
        entity_list = data_dic["preLabel"]["results"][0]["preLabel"]["results"]
        for el in entity_list:
            one_text_entity.append(el["textSelectResults"])
        test_true_entity.append(one_text_entity)
    return train_data_list, test_data_list, dev_data_list, test_true_entity
def read_json(json_path):
    """Read the json config file."""
    with open(json_path, encoding='utf_8_sig') as json_file:
        config = json.load(json_file)
    return config
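Putting the pieces together, the intended call sequence appears to be roughly the following (the config.json filename is a placeholder):

config = read_json("config.json")  # hypothetical config path
train_data_list, test_data_list, dev_data_list, test_true_entity = Data_Read_Process(config)
model = init_model(config, state=False)
model_predict(model, train_data_list, test_data_list, dev_data_list, test_true_entity, config)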
def toDbLog(config, metric_res, cost_time):
    current_dir = os.path.dirname(os.path.abspath(__file__))  # current directory
    log_path = os.path.join(current_dir, "train_log.log")
    # Update one record in the platform's mysql (mysql_info is presumably defined elsewhere)
    utils.postProcessing(
        id=config["fineTuneId"],
        model_path=config["ckpt_load_path"],
        run_time=f"{cost_time}s",
        mysql_info=mysql_info,
        log_path=log_path,
        metrics=metric_res
    )
if __name__ == "__main__":
    ......
This is my training script. Inside the model_predict function:
1. Predicting on the test set gave accuracy 0.38861435524828686, recall 0.23892033777267305, F1 0.29555310050381006.
2. I then ran model.fit.
3. Testing on the test set again gave accuracy 0.722521935459677, recall 0.5737778997677414, F1 0.6349043602768729.
4. I re-imported the latest last.ckpt generated during training and tested on the test set; the accuracy matched step 3.
However, in my validation script, loading the same .ckpt file as in step 4 and using the same test set, I get accuracy 0.4253157506616308, recall 0.2049549567217007, F1 0.2750863361696849.
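One quick way to narrow this down is to confirm both environments actually see the same weights before comparing scripts; a hedged diagnostic sketch in plain PyTorch (the "last.ckpt" path and the Lightning-style "state_dict" key are assumptions):

import torch
ckpt = torch.load("last.ckpt", map_location="cpu")
sd = ckpt.get("state_dict", ckpt)   # Lightning checkpoints nest weights under "state_dict"
name, tensor = next(iter(sd.items()))
print(name, float(tensor.float().abs().mean()))  # should print identical values in both scripts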
My validation script is:
from pymongo import MongoClient
import utils2
import requests
from urllib.parse import unquote
from argparse import ArgumentParser
import pandas as pd
from collections import Counter
import enum, sys, copy, json, time, os
sys.path.insert(0, '/home/Fengshenbang-LM')
os.environ['CUDA_VISIBLE_DEVICES']='-1'
from fengshen import UbertPipelines
from fengshen.metric import metric
inferencer = None
def inferenceFunc(model_info):
    ...  # body omitted in the issue
def InferAndExtract(source_path):
    """Model inference + extraction of the result info."""
    global inferencer  # model instance (global variable)
    if isinstance(source_path, str) and os.path.isdir(source_path):
        # a directory path
        source_list = get_nlp_files(cfg.fileDirectory)
    elif isinstance(source_path, list):
        # a list of txt file paths to run inference on
        source_list = source_path
    elif isinstance(source_path, str) and os.path.isfile(source_path) and source_path.endswith('.txt'):
        # a single txt file
        source_list = [source_path]
These are the training parameters I used:
Why does the .ckpt checkpoint become ineffective like this? Could anyone suggest a solution?