Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

多线程 inference.cpp 因为全局变量 no_grad 出现异常 #4

Open
guanquanchen opened this issue Sep 11, 2023 · 0 comments
Open

Comments

@guanquanchen
Copy link

#ifndef CNN_ARCHITECTURES_H
#define CNN_ARCHITECTURES_H

// C++
#include
#include
// self
#include "pipeline.h"

namespace architectures {
using namespace pipeline;

// Used for random initialization: the numbers C++ generates are too large, so the
// values before softmax reach several hundred and blow up.
extern data_type random_times;
//data_type random_times =10.f;
// Global variable: whether backward is needed; somewhat slower to access.
// NOTE(review): declared as a plain bool with no synchronization visible here;
// how it behaves when several threads read/write it needs confirming — TODO confirm.
extern bool no_grad;

我尝试用线程池包装推理函数，并用约 2.2 万张图片做多线程测试，总耗时约 420 s，但多线程下无法返回完整的结果。
经排查，原因应该是 no_grad 是一个可读写的全局变量：多个线程同时读写这个全局变量，导致了错误。

下面是我添加的 多线程

// #include "pool_number.cpp" 我把线程文件放到这里
#ifndef THREAD_POOL_H
#define THREAD_POOL_H

#include
#include
#include
#include
#include
#include <condition_variable>
#include
#include
#include

// Pool of worker threads that pull std::function<void()> jobs from a shared queue.
class ThreadPool {
public:
    // Takes the number of worker threads to start.
    ThreadPool(size_t);

    // Queues a call of f(args...) and returns a future for its result.
    template<typename F, typename... Args>
    auto enqueue(F&& f, Args&&... args)
        -> std::future<typename std::result_of<F(Args...)>::type>;

    ~ThreadPool();

private:
    // Worker threads; stored so they can be joined later.
    std::vector<std::thread> workers;
    // Jobs waiting to be picked up by a worker.
    std::queue<std::function<void()>> tasks;

    // Synchronization: the mutex guards `tasks` and `stop`; the condition
    // variable is what idle workers wait on.
    std::mutex queue_mutex;
    std::condition_variable condition;
    bool stop;
};

// Constructor: clears the stop flag and starts `threads` workers, each of which
// keeps taking jobs off the queue until the pool is stopped and the queue is empty.
inline ThreadPool::ThreadPool(size_t threads)
    : stop(false)
{
    // Loop run by every worker thread.
    const auto worker_loop = [this] {
        while (true) {
            std::function<void()> job;

            {
                std::unique_lock<std::mutex> guard(queue_mutex);
                // Sleep until there is work or the pool is shutting down.
                while (!stop && tasks.empty())
                    condition.wait(guard);
                // Exit only once the queue is drained.
                if (stop && tasks.empty())
                    return;
                job = std::move(tasks.front());
                tasks.pop();
            }

            // Run the job with the queue lock released.
            job();
        }
    };

    for (size_t spawned = 0; spawned != threads; ++spawned)
        workers.emplace_back(worker_loop);
}

// add new work item to the pool
template<class F, class... Args>
auto ThreadPool::enqueue(F&& f, Args&&... args)
-> std::future<typename std::result_of<F(Args...)>::type>
{
using return_type = typename std::result_of<F(Args...)>::type;

auto task = std::make_shared< std::packaged_task<return_type()> >(
    std::bind(std::forward<F>(f), std::forward<Args>(args)...)
    );

std::future<return_type> res = task->get_future();
{
    std::unique_lock<std::mutex> lock(queue_mutex);

    // don't allow enqueueing after stopping the pool
    if (stop)
        throw std::runtime_error("enqueue on stopped ThreadPool");

    tasks.emplace([task]() { (*task)(); });
}
condition.notify_one();
return res;

}

// Destructor: sets the stop flag under the queue lock, wakes every worker and
// joins all worker threads. Workers drain the remaining queue before exiting.
inline ThreadPool::~ThreadPool()
{
    {
        // The template argument was lost in the original paste
        // ("std::unique_lockstd::mutex"), which does not compile.
        std::unique_lock<std::mutex> lock(queue_mutex);
        stop = true;
    }
    condition.notify_all();
    for (std::thread& worker : workers) {
        // join() on a non-joinable thread throws; skip such threads.
        if (worker.joinable())
            worker.join();
    }
}

#endif

#include
#include <windows.h>
#include
#include <basci/basci.h>
// #include
// #include "pool_number.cpp"
#include
#include
#include
#include <opencv2/imgproc.hpp>
#include <opencv2/highgui.hpp>

#include "func.h"
#include "architectures.h"

// Returns the number of hardware threads to use for the pool, never less than 1.
// std::thread::hardware_concurrency() is allowed to return 0 when the value
// cannot be determined; a pool built with 0 workers would never run any task.
int cpu_number() {
    const unsigned int detected = std::thread::hardware_concurrency();
    if (detected == 0) {
        return 1;
    }
    return static_cast<int>(detected);
}

namespace {

// Debug helper: displays one_image in a HighGUI window titled `info`,
// calls cv::waitKey(0), then destroys all windows.
void cv_show(const cv::Mat& one_image, const char* info = "")
{
    cv::imshow(info, one_image);
    cv::waitKey(0);
    cv::destroyAllWindows();
}

}  // namespace

#include
#include

std::vectorstd::string get_png_list(int main_i,vector images_list){
// 输出不要放在缓冲区, 到时间了及时输出
std::setbuf(stdout, 0);

string main_string = "for_i  "+ to_string(main_i);
cout<<main_string<<endl;
using namespace architectures;
//std::cout << "inference\n";

// 指定一些参数
const std::vector<std::string> categories({"dog", "panda", "bird"});

// 定义网络结构
const int num_classes = categories.size(); // 分类的数目
AlexNet network(num_classes);
// 直接加载
network.load_weights("../checkpoints/AlexNet_aug_1e-3/iter_395000_train_0.918_valid_0.913.model");

auto start = std::chrono::high_resolution_clock::now();

// 准备一块图像内容存放的空间
const std::tuple<int, int, int> image_size({3, 224, 224});
tensor buffer_data(new Tensor3D(image_size, "inference_buffer"));
std::vector<tensor> image_buffer({buffer_data});

// 去掉梯度计算
WithoutGrad guard;

// std::cout<<"\n\n\n main_list = [\n";
// 逐一读取图像, 做变换
Json::Value root;
vector json_list;
json_list.push_back(to_string(main_i));
for(const auto& image_path : images_list) {
// try{
// 读取图像
cv::Mat origin = cv::imread(image_path);
if(origin.empty() || !std::filesystem::exists(image_path)) {
std::cout << "Failed to read image file " << image_path << "\n";
continue;
}
// // 图像 resize 到规定的大小, 224 X 224
// cv::resize(origin, origin, {std::get<1>(image_size), std::get<2>(image_size)});
// // 转化为 tensor 数据
// image_buffer[0]->read_from_opencv_mat(origin.data);
// // 经过卷积神经网络得到输出
// const auto output = network.forward(image_buffer);
// // softmax 得到输出
// const auto prob = softmax(output);
// // 找到最大概率的输出
// const int max_index = prob[0]->argmax();

        // // 创建内层Json对象
        // Json::Value inner;
        // inner["path"] = image_path;
        // inner["catgories"] = categories[max_index];
        // inner["data"] = prob[0]->data[max_index];
        // root[image_path]= inner;

        // json_list.push_back(json_str(inner));

        json_list.push_back(image_path);

        //std::cout <<"['" <<image_path << "','" << categories[max_index] << "','" << prob[0]->data[max_index] << "'],\n";
        //cv_show(origin);
// } catch (const std::exception& e)
//     {
//         std::cout << "Exception caught: " << e.what() << std::endl;
//     }
}



//std::cout<<"]\n"<<std::endl;

// cout<<root<<endl;
//write_json_file("./json/main "+to_string(main_i)+".json",root);
// 获取结束时间点
cout<<"ok "+to_string(main_i)<<endl;
auto end = std::chrono::high_resolution_clock::now();
// 计算代码执行时间(以毫秒为单位)
auto duration = std::chrono::duration_cast<std::chrono::milliseconds>(end - start).count();
// 输出执行时间
std::cout << "usetime:" << duration << " ms" << std::endl;
cout << endl;
// return root;
return json_list;

}

int main(){
u_init();
std::string read_path = main_path+ utf8togbk("/png");
std::cout<<"exe_path "<<read_path<<std::endl;
ulist main_list =get_all_file(read_path);
//main_list.print();
cout<<"list len "<<main_list.len()<<endl;

ustring main_str ;
vector<vector<string >> png_list;
vector<string> one_list;

int main_len = 0;
for (node *p = main_list.head->next; p != main_list.head; p = p->next) {
    main_str.str = p->str;
    main_str.replace("\\","/");
    if (main_len>=1000){
        png_list.push_back(one_list);
        one_list.clear();
        main_len=0;
    }else{
        one_list.push_back(std::string(main_str.str));
    }
    main_len=main_len+1;
}

if (one_list.size()>=1){
    png_list.push_back(one_list);
}

cout<<"main_size   "<<png_list.size()<<endl;

auto start = std::chrono::high_resolution_clock::now();
cout << "cpu_number " << cpu_number() << endl;
int pool_size = cpu_number();
ThreadPool threadpool(pool_size);

int fori=0;
vector<future<vector<string>>> resVec;
for(auto one:png_list){
    cout<<"list size "<<one.size()<<endl;
    resVec.emplace_back(
            threadpool.enqueue(
                // [fori,one]{return com_list::get_png_list(fori,one);}
                [fori,one]{return get_png_list(fori,one);}
            )

        );
    fori +=1;
}


vector<vector<string>> main_list1;
/*打印每个任务的返回值*/
for (auto & result : resVec) {
    main_list1.push_back(result.get());
}

for(auto dd:main_list1){
    cout<<"ok_size "<<dd[0]<<"  "<<dd.size()-1<<endl;
}

// 获取结束时间点
auto end = std::chrono::high_resolution_clock::now();
// 计算代码执行时间(以毫秒为单位)
auto duration = std::chrono::duration_cast<std::chrono::milliseconds>(end - start).count();
// 输出执行时间
std::cout << "总消耗时间    "<< duration << " ms" << std::endl;
cout << endl;

system("pause");

}

Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment
Labels
None yet
Projects
None yet
Development

No branches or pull requests

1 participant