Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Add pr-curve support #8

Open
wants to merge 5 commits into
base: master
Choose a base branch
from
Open
Show file tree
Hide file tree
Changes from 4 commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
1 change: 1 addition & 0 deletions README.md
Original file line number Diff line number Diff line change
Expand Up @@ -9,6 +9,7 @@ C++ API to log data in tensorboard format. Only support `scalar`, `histogram`, `
![text](./assets/text.jpg)
![embedding](./assets/embedding.png)
![multiple-image](./assets/multi-image.png)
![pr-curve](./assets/pr_curve.png)

# Acknowledgement

Expand Down
Binary file added assets/pr_curve.png
Loading
Sorry, something went wrong. Reload?
Sorry, we cannot display this file.
Sorry, this file is invalid so it cannot be displayed.
27 changes: 22 additions & 5 deletions include/tensorboard_logger.h
Original file line number Diff line number Diff line change
Expand Up @@ -8,6 +8,7 @@

#include "crc.h"
#include "event.pb.h"
#include "plugin_pr_curve.pb.h"

using tensorflow::Event;
using tensorflow::Summary;
Expand Down Expand Up @@ -44,9 +45,10 @@ class TensorBoardLogger {
template <typename T>
int add_histogram(const std::string &tag, int step, const T *value,
size_t num) {
if (bucket_limits_ == nullptr) {
generate_default_buckets();
}

double max_range = static_cast<double>(*(std::max(value,value+num-1)));
double min_range = static_cast<double>(*(std::min(value,value+num-1)));
generate_default_buckets({max_range, min_range}, num, false, true);

std::vector<int> counts(bucket_limits_->size(), 0);
double min = std::numeric_limits<double>::max();
Expand Down Expand Up @@ -140,9 +142,24 @@ class TensorBoardLogger {
const std::vector<std::string> &metadata = std::vector<std::string>(),
const std::string &metadata_filename = "",
int step = 1 /* no effect */);

int prcurve(const std::string tag,
const std::vector<double>labels,
Copy link
Owner

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

use reference

const std::vector<double>predictions,
const int num_thresholds = 127,
std::vector<double>weights = {},
const std::string &display_name = "",
const std::string &description = "");
private:
int generate_default_buckets();
std::vector<std::vector<double>> compute_curve(
const std::vector<double>labels,
Copy link
Owner

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

references

const std::vector<double>predictions,
int num_thresholds = 127,
std::vector<double>weights = {});
int generate_default_buckets(
std::vector<double> range = {(-1*1e-12), 1e20},
size_t num_of_bins = 10,
bool ignore_outside_range = false,
bool regenerate = false);
int add_event(int64_t step, Summary *summary);
int write(Event &event);

Expand Down
10 changes: 10 additions & 0 deletions proto/plugin_pr_curve.proto
Original file line number Diff line number Diff line change
@@ -0,0 +1,10 @@
syntax = "proto3";

package tensorflow;

message PrCurvePluginData {
// Version `0` is the only supported version.
int32 version = 1;

uint32 num_thresholds = 2;
}
152 changes: 136 additions & 16 deletions src/tensorboard_logger.cc
Original file line number Diff line number Diff line change
Expand Up @@ -15,6 +15,7 @@

#include "event.pb.h"
#include "projector_config.pb.h"
#include "plugin_pr_curve.pb.h"

using namespace std;
using google::protobuf::TextFormat;
Expand All @@ -26,27 +27,31 @@ using tensorflow::ProjectorConfig;
using tensorflow::Summary;
using tensorflow::SummaryMetadata;
using tensorflow::TensorProto;
using tensorflow::PrCurvePluginData;
using tensorflow::TensorShapeProto;

// https://github.com/dmlc/tensorboard/blob/master/python/tensorboard/summary.py#L115
int TensorBoardLogger::generate_default_buckets() {
if (bucket_limits_ == nullptr) {
int TensorBoardLogger::generate_default_buckets(std::vector<double> range,
size_t num_of_bins,
bool ignore_outside_range,
bool regenerate ) {
if (bucket_limits_ == nullptr || regenerate == true) {
bucket_limits_ = new vector<double>;
vector<double> pos_buckets, neg_buckets;
double v = 1e-12;
while (v < 1e20) {
pos_buckets.push_back(v);
neg_buckets.push_back(-v);
v *= 1.1;
double v = range[0];
double width = (range[1] - range[0]) / num_of_bins ;
if (width == 0)
width = 1;
if(!ignore_outside_range)
bucket_limits_->push_back(numeric_limits<double>::lowest());
while (v <= range[1]) {
bucket_limits_->push_back(v);
v = v + width;
}
if(!ignore_outside_range)
{
bucket_limits_->push_back(numeric_limits<double>::max());
}
pos_buckets.push_back(numeric_limits<double>::max());
neg_buckets.push_back(numeric_limits<double>::lowest());

bucket_limits_->insert(bucket_limits_->end(), neg_buckets.rbegin(),
neg_buckets.rend());
bucket_limits_->insert(bucket_limits_->end(), pos_buckets.begin(),
pos_buckets.end());
}

return 0;
}

Expand Down Expand Up @@ -243,6 +248,121 @@ int TensorBoardLogger::add_embedding(
tensor_shape, step);
}

std::vector<std::vector<double>> TensorBoardLogger::compute_curve(
const std::vector<double>labels,
Copy link
Owner

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

reference

const std::vector<double>predictions,
int num_thresholds,
std::vector<double>weights)
{
// misbheaves when thresholds is greater than 127
num_thresholds = min(num_thresholds,127);
double min_count = 1e-7;
std::vector<std::vector<double>> data;
while (weights.size()<labels.size())
{
weights.push_back(1.0);
}
generate_default_buckets({0, (double)num_thresholds - 1}, num_thresholds, true, true);
Copy link
Owner

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Can this num_thresholds be fixed to 127, so buckets only need to be generated once? I don't see the necessity to change it. (I will look deeper later)

vector<double> tp(bucket_limits_->size(), 0), fp(bucket_limits_->size(), 0);

for (size_t i = 0; i < labels.size(); ++i)
{
float v = labels[i];
int item = predictions[i] * (num_thresholds -1);
auto lb =
lower_bound(bucket_limits_->begin(), bucket_limits_->end(), item);
{
tp[lb - bucket_limits_->begin()] = tp[lb - bucket_limits_->begin()] + (v*weights[i]);
fp[lb - bucket_limits_->begin()] = fp[lb - bucket_limits_->begin()] + ((1-v)*weights[i]);
}
}

// Reverse cummulative sum
for(int i = tp.size() - 1; i >= 0 ;i--)
{
tp[i] = tp[i] + tp[i+1];
fp[i] = fp[i] + fp[i+1];
}
reverse(tp.begin(), tp.end());
reverse(fp.begin(), fp.end());
for(int i = tp.size() - 1; i >= 0 ;i--)
{

tp[i] = tp[i] + tp[i+1];
fp[i] = fp[i] + fp[i+1];
}
std::vector<double> tn(tp.size()), fn(tp.size()), precision(tp.size()), recall(tp.size());
for(size_t i = 0; i < tp.size() ;i++)
{
tn[i] = tp[0] - tp[i];
fn[i] = fp[0] - fp[i];
precision[i] = tp[i] / max(min_count,tp[i]+fp[i]);
recall[i] = tp[i] / max(min_count,tp[i]+fn[i]);
}
data.push_back(tp);
data.push_back(fp);
data.push_back(tn);
data.push_back(fn);
data.push_back(precision);
data.push_back(recall);
return data;
}
int TensorBoardLogger::prcurve(
const std::string tag,
Copy link
Owner

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

references

const std::vector<double>labels,
const std::vector<double>predictions,
const int num_thresholds,
std::vector<double>weights,
const std::string &display_name,
const std::string &description)
{
// Pr plugin
PrCurvePluginData *pr_curve_plugin = new PrCurvePluginData();
pr_curve_plugin->set_version(0);
pr_curve_plugin->set_num_thresholds(num_thresholds);
std::string pr_curve_content;
pr_curve_plugin->SerializeToString(&pr_curve_content);

// PluginMeta data
auto *plugin_data = new SummaryMetadata::PluginData();
plugin_data->set_plugin_name("pr_curves");
plugin_data->set_content(pr_curve_content);

// Summary Meta data
auto *meta = new SummaryMetadata();
meta->set_display_name(display_name == "" ? tag : display_name);
meta->set_summary_description(description);
meta->set_allocated_plugin_data(plugin_data);

std::vector<std::vector<double>> data =
compute_curve(labels, predictions, num_thresholds, weights);

// Prepare Tensor
auto *tensorshape = new TensorShapeProto();
auto rowdim = tensorshape->add_dim();
rowdim->set_size(data.size());
auto coldim = tensorshape->add_dim();
coldim->set_size(data[0].size());
auto *tensor = new TensorProto();
tensor->set_dtype(tensorflow::DataType::DT_DOUBLE);
tensor->set_allocated_tensor_shape(tensorshape);
for(int i=0;i<data.size();i++)
Copy link
Owner

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Range-based for loop is better suited here.

{
for(int j=0;j<data[0].size();j++)
{
tensor->add_double_val(data[i][j]);
}
}

auto *summary = new Summary();
auto *v = summary->add_value();
v->set_tag(tag);
v->set_allocated_tensor(tensor);
v->set_allocated_metadata(meta);

return add_event(0, summary);
}

int TensorBoardLogger::add_embedding(const std::string &tensor_name,
const float *tensor,
const std::vector<uint32_t> &tensor_shape,
Expand Down
13 changes: 12 additions & 1 deletion tests/test_tensorboard_logger.cc
Original file line number Diff line number Diff line change
Expand Up @@ -106,7 +106,18 @@ int test_log(const char* log_file) {
tensor_shape.push_back(tensor[0].size());
logger.add_embedding("binary tensor 1d", tensor_1d, tensor_shape,
"tensor_1d.bin", meta, "binary_tensor_1d.tsv");
delete[] tensor_1d;
delete[] tensor_1d;

// test pr curver
vector<double> labels, predictions;
for(int i=0;i<100;i++)
{
double item = (double) rand()/RAND_MAX;
double sem_item = (double) rand()/RAND_MAX;
labels.push_back(round(sem_item));
predictions.push_back(item);
}
logger.prcurve("pr_curve",labels,predictions,127);

return 0;
}
Expand Down