Skip to content

Commit

Permalink
fixup! feat: use mapplotlibcpp
Browse files Browse the repository at this point in the history
  • Loading branch information
satoshi-ota committed Sep 26, 2024
1 parent 64bcb02 commit 1e527ff
Show file tree
Hide file tree
Showing 4 changed files with 45 additions and 20 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -17,6 +17,7 @@
efficiency: 1.0
safety: 1.0
achievability: 1.0
consistency: 1.0

grid_seach:
min: 0.1
Expand Down
5 changes: 3 additions & 2 deletions planning/autoware_planning_data_analyzer/src/data_structs.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -297,14 +297,15 @@ double CommonData::get(const SCORE & score_type) const
}

double CommonData::total(
const double w0, const double w1, const double w2, const double w3, const double w4) const
const double w0, const double w1, const double w2, const double w3, const double w4,
const double w5) const
{
return w0 * scores.at(static_cast<size_t>(SCORE::LATERAL_COMFORTABILITY)) +
w1 * scores.at(static_cast<size_t>(SCORE::LONGITUDINAL_COMFORTABILITY)) +
w2 * scores.at(static_cast<size_t>(SCORE::EFFICIENCY)) +
w3 * scores.at(static_cast<size_t>(SCORE::SAFETY)) +
w4 * scores.at(static_cast<size_t>(SCORE::ACHIEVABILITY)) +
1.0 * scores.at(static_cast<size_t>(SCORE::CONSISTENCY));
w5 * scores.at(static_cast<size_t>(SCORE::CONSISTENCY));
}

ManualDrivingData::ManualDrivingData(
Expand Down
37 changes: 23 additions & 14 deletions planning/autoware_planning_data_analyzer/src/data_structs.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -98,14 +98,17 @@ struct Parameters
double w2{1.0};
double w3{1.0};
double w4{1.0};
double w5{1.0};
GridSearchParameters grid_seach{};
TargetStateParameters target_state{};
};

struct Result
{
Result(const double w0, const double w1, const double w2, const double w3, const double w4)
: w0{w0}, w1{w1}, w2{w2}, w3{w3}, w4{w4}
Result(
const double w0, const double w1, const double w2, const double w3, const double w4,
const double w5)
: w0{w0}, w1{w1}, w2{w2}, w3{w3}, w4{w4}, w5{w5}
{
}
double loss{0.0};
Expand All @@ -114,6 +117,7 @@ struct Result
double w2{0.0};
double w3{0.0};
double w4{0.0};
double w5{0.0};
};

struct BufferBase
Expand Down Expand Up @@ -247,7 +251,8 @@ struct CommonData
double consistency() const;

double total(
const double w0, const double w1, const double w2, const double w3, const double w4) const;
const double w0, const double w1, const double w2, const double w3, const double w4,
const double w5) const;

double get(const SCORE & score_type) const;

Expand Down Expand Up @@ -345,15 +350,16 @@ struct SamplingTrajectoryData
const std::shared_ptr<BagData> & bag_data, const vehicle_info_utils::VehicleInfo & vehicle_info,
const std::shared_ptr<Parameters> & parameters, const std::optional<TrajectoryPoints> & t_best);

auto best(const double w0, const double w1, const double w2, const double w3, const double w4)
const -> std::optional<TrajectoryData>
auto best(
const double w0, const double w1, const double w2, const double w3, const double w4,
const double w5) const -> std::optional<TrajectoryData>
{
auto sort_by_score = data;

std::sort(
sort_by_score.begin(), sort_by_score.end(),
[&w0, &w1, &w2, &w3, &w4](const auto & a, const auto & b) {
return a.total(w0, w1, w2, w3, w4) > b.total(w0, w1, w2, w3, w4);
[&w0, &w1, &w2, &w3, &w4, &w5](const auto & a, const auto & b) {
return a.total(w0, w1, w2, w3, w4, w5) > b.total(w0, w1, w2, w3, w4, w5);
});

const auto itr = std::remove_if(
Expand Down Expand Up @@ -418,8 +424,9 @@ struct DataSet
data.normalize(s3_min, s3_max, static_cast<size_t>(SCORE::SAFETY));
data.normalize(s4_min, s4_max, static_cast<size_t>(SCORE::ACHIEVABILITY), true);
data.normalize(s5_min, s5_max, static_cast<size_t>(SCORE::CONSISTENCY), true);
data.scores.at(static_cast<size_t>(SCORE::TOTAL)) =
data.total(parameters->w0, parameters->w1, parameters->w2, parameters->w3, parameters->w4);
data.scores.at(static_cast<size_t>(SCORE::TOTAL)) = data.total(
parameters->w0, parameters->w1, parameters->w2, parameters->w3, parameters->w4,
parameters->w5);
}

const auto [total_min, total_max] = range(static_cast<size_t>(SCORE::TOTAL));
Expand All @@ -429,10 +436,11 @@ struct DataSet
}
}

auto loss(const double w0, const double w1, const double w2, const double w3, const double w4)
const -> double
auto loss(
const double w0, const double w1, const double w2, const double w3, const double w4,
const double w5) const -> double
{
const auto best = sampling.best(w0, w1, w2, w3, w4);
const auto best = sampling.best(w0, w1, w2, w3, w4, w5);
if (!best.has_value()) {
return 0.0;
}
Expand All @@ -455,8 +463,9 @@ struct DataSet

void show()
{
const auto best =
sampling.best(parameters->w0, parameters->w1, parameters->w2, parameters->w3, parameters->w4);
const auto best = sampling.best(
parameters->w0, parameters->w1, parameters->w2, parameters->w3, parameters->w4,
parameters->w5);

if (!best.has_value()) {
return;
Expand Down
22 changes: 18 additions & 4 deletions planning/autoware_planning_data_analyzer/src/node.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -90,6 +90,7 @@ BehaviorAnalyzerNode::BehaviorAnalyzerNode(const rclcpp::NodeOptions & node_opti
parameters_->w2 = declare_parameter<double>("weight.efficiency");
parameters_->w3 = declare_parameter<double>("weight.safety");
parameters_->w4 = declare_parameter<double>("weight.achievability");
parameters_->w5 = declare_parameter<double>("weight.consistency");
parameters_->grid_seach.dt = declare_parameter<double>("grid_seach.dt");
parameters_->grid_seach.min = declare_parameter<double>("grid_seach.min");
parameters_->grid_seach.max = declare_parameter<double>("grid_seach.max");
Expand Down Expand Up @@ -290,7 +291,9 @@ void BehaviorAnalyzerNode::weight(
for (double w2 = min; w2 < max + 0.1 * resolusion; w2 += resolusion) {
for (double w3 = min; w3 < max + 0.1 * resolusion; w3 += resolusion) {
for (double w4 = min; w4 < max + 0.1 * resolusion; w4 += resolusion) {
weight_grid.emplace_back(w0, w1, w2, w3, w4);
for (double w5 = min; w5 < max + 0.1 * resolusion; w5 += resolusion) {
weight_grid.emplace_back(w0, w1, w2, w3, w4, w5);
}
}
}
}
Expand All @@ -313,6 +316,7 @@ void BehaviorAnalyzerNode::weight(
ss << " [w2]:" << best.w2;
ss << " [w3]:" << best.w3;
ss << " [w4]:" << best.w4;
ss << " [w5]:" << best.w5;
ss << " [loss]:" << best.loss << std::endl;
// clang-format on
RCLCPP_INFO_STREAM(get_logger(), ss.str());
Expand All @@ -335,6 +339,7 @@ void BehaviorAnalyzerNode::weight(
double w2 = 0.0;
double w3 = 0.0;
double w4 = 0.0;
double w5 = 0.0;

{
std::lock_guard<std::mutex> lock(grid_mutex);
Expand All @@ -344,9 +349,10 @@ void BehaviorAnalyzerNode::weight(
w2 = weight_grid.at(idx).w2;
w3 = weight_grid.at(idx).w3;
w4 = weight_grid.at(idx).w4;
w5 = weight_grid.at(idx).w5;
}

const auto loss = data_set->loss(w0, w1, w2, w3, w4);
const auto loss = data_set->loss(w0, w1, w2, w3, w4, w5);

{
std::lock_guard<std::mutex> lock(grid_mutex);
Expand All @@ -370,6 +376,13 @@ void BehaviorAnalyzerNode::weight(
i += p->grid_seach.thread_num;
}

const auto t_best = data_set->sampling.best(p->w0, p->w1, p->w2, p->w3, p->w4, p->w5);
if (t_best.has_value()) {
best = t_best.value().points;
} else {
best = std::nullopt;
}

std::cout << "IDX:" << i << " GRID:" << weight_grid.size() << std::endl;

show_best_result();
Expand Down Expand Up @@ -517,7 +530,8 @@ void BehaviorAnalyzerNode::visualize(const std::shared_ptr<DataSet> & data_set)
}

const auto best = data_set->sampling.best(
parameters_->w0, parameters_->w1, parameters_->w2, parameters_->w3, parameters_->w4);
parameters_->w0, parameters_->w1, parameters_->w2, parameters_->w3, parameters_->w4,
parameters_->w5);
if (best.has_value()) {
Marker marker = createDefaultMarker(
"map", rclcpp::Clock{RCL_ROS_TIME}.now(), "best_score", 0L, Marker::LINE_STRIP,
Expand Down Expand Up @@ -626,7 +640,7 @@ void BehaviorAnalyzerNode::plot(const std::shared_ptr<DataSet> & data_set) const
subplot(SCORE::ACHIEVABILITY, 3, 3, 5);
subplot(SCORE::CONSISTENCY, 3, 3, 6);
subplot(SCORE::TOTAL, 3, 3, 7);
plot_best(data_set->sampling.best(p->w0, p->w1, p->w2, p->w3, p->w4), 3, 3, 8);
plot_best(data_set->sampling.best(p->w0, p->w1, p->w2, p->w3, p->w4, p->w5), 3, 3, 8);
matplotlibcpp::pause(1e-9);
}
} // namespace autoware::behavior_analyzer
Expand Down

0 comments on commit 1e527ff

Please sign in to comment.