Skip to content

Commit

Permalink
perf: multithread
Browse files Browse the repository at this point in the history
Signed-off-by: satoshi-ota <[email protected]>
  • Loading branch information
satoshi-ota committed Sep 12, 2024
1 parent 0fddae1 commit 67a222b
Show file tree
Hide file tree
Showing 3 changed files with 68 additions and 29 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -20,5 +20,6 @@
grid_seach:
min: 0.1
max: 1.0
resolusion: 0.02
resolusion: 0.2
dt: 1.0
thread_num: 8
5 changes: 5 additions & 0 deletions planning/autoware_planning_data_analyzer/src/data_structs.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -80,6 +80,7 @@ struct GridSearchParameters
double max{1.0};
double resolusion{0.01};
double dt{1.0};
size_t thread_num{4};
};

struct Parameters
Expand All @@ -96,6 +97,10 @@ struct Parameters

struct Result
{
Result(const double w0, const double w1, const double w2, const double w3)
: w0{w0}, w1{w1}, w2{w2}, w3{w3}
{
}
double loss{0.0};
double w0{0.0};
double w1{0.0};
Expand Down
89 changes: 61 additions & 28 deletions planning/autoware_planning_data_analyzer/src/node.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -14,6 +14,8 @@

#include "node.hpp"

#include "autoware/universe_utils/system/stop_watch.hpp"

#include <autoware/universe_utils/ros/marker_helper.hpp>

namespace autoware::behavior_analyzer
Expand Down Expand Up @@ -79,6 +81,7 @@ BehaviorAnalyzerNode::BehaviorAnalyzerNode(const rclcpp::NodeOptions & node_opti
parameters_->grid_seach.min = declare_parameter<double>("grid_seach.min");
parameters_->grid_seach.max = declare_parameter<double>("grid_seach.max");
parameters_->grid_seach.resolusion = declare_parameter<double>("grid_seach.resolusion");
parameters_->grid_seach.thread_num = declare_parameter<int>("grid_seach.thread_num");
parameters_->target_state.lat_positions =
declare_parameter<std::vector<double>>("target_state.lateral_positions");
parameters_->target_state.lat_velocities =
Expand Down Expand Up @@ -204,17 +207,23 @@ void BehaviorAnalyzerNode::weight(
const auto bag_data = std::make_shared<BagData>(
duration_cast<nanoseconds>(reader_.get_metadata().starting_time.time_since_epoch()).count());

const auto e = std::numeric_limits<double>::epsilon();

const size_t grid_size =
std::ceil((p->grid_seach.max - p->grid_seach.min + e) / p->grid_seach.resolusion);

std::vector<Result> results(std::pow(grid_size, 4));
std::vector<Result> weight_grid;

std::cout << "S:" << results.size() << std::endl;
double resolusion = p->grid_seach.resolusion;
double min = p->grid_seach.min;
double max = p->grid_seach.max;
for (double w0 = min; w0 < max + 0.1 * resolusion; w0 += resolusion) {
for (double w1 = min; w1 < max + 0.1 * resolusion; w1 += resolusion) {
for (double w2 = min; w2 < max + 0.1 * resolusion; w2 += resolusion) {
for (double w3 = min; w3 < max + 0.1 * resolusion; w3 += resolusion) {
weight_grid.emplace_back(w0, w1, w2, w3);
}
}
}
}

const auto show_best_result = [&results]() {
auto sort_by_loss = results;
const auto show_best_result = [&weight_grid]() {
auto sort_by_loss = weight_grid;
std::sort(sort_by_loss.begin(), sort_by_loss.end(), [](const auto & a, const auto & b) {
return a.loss < b.loss;
});
Expand All @@ -224,39 +233,63 @@ void BehaviorAnalyzerNode::weight(
std::cout << std::fixed;
std::cout << std::setprecision(4);
std::cout << " [w0]:" << best.w0 << " [w1]:" << best.w1 << " [w2]:" << best.w2
<< " [w3]:" << best.w3 << std::endl;
<< " [w3]:" << best.w3 << " [loss]:" << best.loss << std::endl;
};

autoware::universe_utils::StopWatch<std::chrono::milliseconds> stop_watch;

stop_watch.tic("total_time");
while (reader_.has_next() && rclcpp::ok()) {
update(bag_data, p->grid_seach.dt);

if (!bag_data->ready()) break;

const auto data_set = std::make_shared<DataSet>(bag_data, vehicle_info_, p);
size_t i = 0;
for (double w0 = p->grid_seach.min; w0 < p->grid_seach.max + 0.1 * p->grid_seach.resolusion;
w0 += p->grid_seach.resolusion) {
for (double w1 = p->grid_seach.min; w1 < p->grid_seach.max + 0.1 * p->grid_seach.resolusion;
w1 += p->grid_seach.resolusion) {
for (double w2 = p->grid_seach.min; w2 < p->grid_seach.max + 0.1 * p->grid_seach.resolusion;
w2 += p->grid_seach.resolusion) {
for (double w3 = p->grid_seach.min;
w3 < p->grid_seach.max + 0.1 * p->grid_seach.resolusion;
w3 += p->grid_seach.resolusion) {
results.at(i).loss += data_set->loss(w0, w1, w2, w3);
results.at(i).w0 = w0;
results.at(i).w1 = w1;
results.at(i).w2 = w2;
results.at(i).w3 = w3;
i++;
}

std::mutex grid_mutex;

const auto update = [&weight_grid, &grid_mutex](const auto & data_set, const auto idx) {
double w0 = 0.0;
double w1 = 0.0;
double w2 = 0.0;
double w3 = 0.0;

{
std::lock_guard<std::mutex> lock(grid_mutex);
if (idx + 1 > weight_grid.size()) return;
w0 = weight_grid.at(idx).w0;
w1 = weight_grid.at(idx).w1;
w2 = weight_grid.at(idx).w2;
w3 = weight_grid.at(idx).w3;
}

const auto loss = data_set->loss(w0, w1, w2, w3);

{
std::lock_guard<std::mutex> lock(grid_mutex);
if (idx < weight_grid.size()) {
weight_grid.at(idx).loss += loss;
}
}
};

size_t i = 0;
while (rclcpp::ok()) {
std::vector<std::thread> threads;
for (size_t thread_id = 0; thread_id < p->grid_seach.thread_num; thread_id++) {
threads.emplace_back(update, data_set, i + thread_id);
}

for (auto & t : threads) t.join();

if (i + 1 > weight_grid.size()) break;

i += p->grid_seach.thread_num;
}
std::cout << "A:" << i << std::endl;

show_best_result();
}
std::cout << "process time: " << stop_watch.toc("total_time") << "[ms]" << std::endl;

RCLCPP_INFO(get_logger(), "finish weight grid seach.");

Expand Down

0 comments on commit 67a222b

Please sign in to comment.