diff --git a/README.md b/README.md
index bd0b58c79f6..91dd030df4e 100644
--- a/README.md
+++ b/README.md
@@ -1,3 +1,24 @@
+# Faster Center Loss Implementation
+This branch is forked from [ydwen's caffe-face](https://github.com/ydwen/caffe-face) and modified by mfs6174 (mfs6174@gmail.com).
+
+Compared to the original implementation by the paper's author, the backward time complexity of this implementation is reduced from O(MK+NM) to O(MK).
+
+In the original implementation, the backward pass of the center loss layer takes O(MK+NM) time. It becomes very slow when training with a large number of classes, because the backward running time grows with the number of classes (N). Unfortunately, this is the common case when training face recognition models (e.g. 750k unique persons).
+
+This implementation rewrites the backward code. The time complexity is reduced to O(MK) at the cost of O(N) extra space. Because M (batch size) << N and K (feature length) << N usually hold for face recognition problems, this modification improves training speed significantly.
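+
+To make the change concrete, here is a rough, standalone C++ sketch (not the actual layer code): it assumes plain `std::vector` buffers with `float` data instead of Caffe blobs, and the function names `backward_original` and `backward_fast` are made up for illustration only.
+
+```cpp
+#include <vector>
+
+// Original backward: for every class n, scan the whole batch -> O(MK + NM).
+void backward_original(int M, int K, int N,
+                       const std::vector<int>& label,        // M labels
+                       const std::vector<float>& distance,   // M x K buffer
+                       std::vector<float>& center_diff) {    // N x K output
+  for (int n = 0; n < N; ++n) {
+    std::vector<float> variation_sum(K, 0.f);
+    int count = 0;
+    for (int m = 0; m < M; ++m) {
+      if (label[m] == n) {
+        ++count;
+        for (int k = 0; k < K; ++k)
+          variation_sum[k] -= distance[m * K + k];
+      }
+    }
+    for (int k = 0; k < K; ++k)
+      center_diff[n * K + k] = variation_sum[k] / (count + 1.f);
+  }
+}
+
+// This PR: visit each sample once and scatter into its own class -> O(MK),
+// using an extra count[N] buffer (the O(N) space mentioned above).
+// Centers of classes absent from the batch are simply left untouched.
+void backward_fast(int M, int K, int N,
+                   const std::vector<int>& label,
+                   const std::vector<float>& distance,
+                   std::vector<float>& center_diff) {
+  std::vector<float> variation_sum(N * K, 0.f);
+  std::vector<int> count(N, 0);
+  for (int m = 0; m < M; ++m) {
+    const int n = label[m];
+    for (int k = 0; k < K; ++k)
+      variation_sum[n * K + k] -= distance[m * K + k];
+    ++count[n];
+  }
+  for (int m = 0; m < M; ++m) {
+    const int n = label[m];
+    for (int k = 0; k < K; ++k)
+      center_diff[n * K + k] = variation_sum[n * K + k] / (count[n] + 1.f);
+  }
+}
+```
+
+The real layer applies the same idea with Caffe math primitives (caffe_sub, caffe_cpu_axpby) and an extra count blob, as shown in the diff below.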
+
+For a GoogLeNet v2 model trained on Everphoto's dataset of 750k unique persons, on a single Nvidia GTX Titan X with batch size 24 and iter_size = 5, the average backward iteration times are:
+
+1. Softmax only: 230ms
+2. Softmax + center loss, original implementation: 3485ms (center loss layer: 3332ms)
+3. Softmax + center loss, implementation in this PR: 235.6ms (center loss layer: 5.4ms)
+
+This is a more than 600x speed-up of the center loss layer.
+
+For the author's mnist example, running on a single GTX Titan X, the training time of the original implementation vs. this PR is 4min20s vs. 3min50s, so even on a small dataset with only 10 classes there is still some improvement.
+
+This PR also fixes the code style so that it passes Caffe's lint test (make lint) and may be ready to be merged into Caffe's master branch.
+
 # Deep Face Recognition with Caffe Implementation
 
 This branch is developed for deep face recognition, the related paper is as follows.
@@ -185,4 +206,4 @@ Please cite Caffe in your publications if it helps your research:
     Journal = {arXiv preprint arXiv:1408.5093},
     Title = {Caffe: Convolutional Architecture for Fast Feature Embedding},
     Year = {2014}
-  }
\ No newline at end of file
+  }
diff --git a/include/caffe/layers/center_loss_layer.hpp b/include/caffe/layers/center_loss_layer.hpp
index cd6fd1cf994..8d28dbb21d4 100644
--- a/include/caffe/layers/center_loss_layer.hpp
+++ b/include/caffe/layers/center_loss_layer.hpp
@@ -38,11 +38,11 @@ class CenterLossLayer : public LossLayer<Dtype> {
   int M_;
   int K_;
   int N_;
-
   Blob<Dtype> distance_;
   Blob<Dtype> variation_sum_;
+  Blob<int> count_;
 };
 
 }  // namespace caffe
 
-#endif  // CAFFE_CENTER_LOSS_LAYER_HPP_
\ No newline at end of file
+#endif  // CAFFE_CENTER_LOSS_LAYER_HPP_
diff --git a/include/caffe/layers/normalize_layer.hpp b/include/caffe/layers/normalize_layer.hpp
new file mode 100644
index 00000000000..26590457d94
--- /dev/null
+++ b/include/caffe/layers/normalize_layer.hpp
@@ -0,0 +1,44 @@
+#ifndef CAFFE_NORMALIZE_LAYER_HPP_
+#define CAFFE_NORMALIZE_LAYER_HPP_
+
+#include <vector>
+
+#include "caffe/blob.hpp"
+#include "caffe/layer.hpp"
+#include "caffe/proto/caffe.pb.h"
+
+namespace caffe {
+
+/**
+ * @brief Normalizes each sample of the input blob (@f$ x @f$) to unit L2
+ *        norm, producing an equally-sized output blob (@f$ y @f$).
+ */
+template <typename Dtype>
+class NormalizeLayer : public Layer<Dtype> {
+ public:
+  explicit NormalizeLayer(const LayerParameter& param)
+      : Layer<Dtype>(param) {}
+  virtual void Reshape(const vector<Blob<Dtype>*>& bottom,
+      const vector<Blob<Dtype>*>& top);
+
+  virtual inline int ExactNumBottomBlobs() const { return 1; }
+  virtual inline int ExactNumTopBlobs() const { return 1; }
+
+ protected:
+  virtual void Forward_cpu(const vector<Blob<Dtype>*>& bottom,
+      const vector<Blob<Dtype>*>& top);
+  virtual void Forward_gpu(const vector<Blob<Dtype>*>& bottom,
+      const vector<Blob<Dtype>*>& top);
+  virtual void Backward_cpu(const vector<Blob<Dtype>*>& top,
+      const vector<bool>& propagate_down, const vector<Blob<Dtype>*>& bottom);
+  virtual void Backward_gpu(const vector<Blob<Dtype>*>& top,
+      const vector<bool>& propagate_down, const vector<Blob<Dtype>*>& bottom);
+
+  Blob<Dtype> sum_multiplier_, norm_, squared_;
+};
+
+}  // namespace caffe
+
+#endif  // CAFFE_NORMALIZE_LAYER_HPP_
diff --git a/src/caffe/layers/center_loss_layer.cpp b/src/caffe/layers/center_loss_layer.cpp
index 5e79c3af528..0b69d0d1c67 100644
--- a/src/caffe/layers/center_loss_layer.cpp
+++ b/src/caffe/layers/center_loss_layer.cpp
@@ -9,7 +9,7 @@ namespace caffe {
 template <typename Dtype>
 void CenterLossLayer<Dtype>::LayerSetUp(const vector<Blob<Dtype>*>& bottom,
       const vector<Blob<Dtype>*>& top) {
-  const int num_output = this->layer_param_.center_loss_param().num_output();  
+  const int num_output = this->layer_param_.center_loss_param().num_output();
   N_ = num_output;
   const int axis = bottom[0]->CanonicalAxisIndex(
       this->layer_param_.center_loss_param().axis());
@@ -31,7 +31,6 @@ void CenterLossLayer<Dtype>::LayerSetUp(const vector<Blob<Dtype>*>& bottom,
     shared_ptr<Filler<Dtype> > center_filler(GetFiller<Dtype>(
         this->layer_param_.center_loss_param().center_filler()));
     center_filler->Fill(this->blobs_[0].get());
-
   }  // parameter initialization
   this->param_propagate_down_.resize(this->blobs_.size(), true);
 }
@@ -48,6 +47,9 @@ void CenterLossLayer<Dtype>::Reshape(const vector<Blob<Dtype>*>& bottom,
   LossLayer<Dtype>::Reshape(bottom, top);
   distance_.ReshapeLike(*bottom[0]);
   variation_sum_.ReshapeLike(*this->blobs_[0]);
+  vector<int> count_shape(1);
+  count_shape[0] = N_;
+  count_.Reshape(count_shape);
 }
 
 template <typename Dtype>
@@ -57,47 +59,55 @@ void CenterLossLayer<Dtype>::Forward_cpu(const vector<Blob<Dtype>*>& bottom,
     const vector<Blob<Dtype>*>& top) {
   const Dtype* bottom_data = bottom[0]->cpu_data();
   const Dtype* label = bottom[1]->cpu_data();
   const Dtype* center = this->blobs_[0]->cpu_data();
   Dtype* distance_data = distance_.mutable_cpu_data();
-  // the i-th distance_data
   for (int i = 0; i < M_; i++) {
     const int label_value = static_cast<int>(label[i]);
     // D(i,:) = X(i,:) - C(y(i),:)
-    caffe_sub(K_, bottom_data + i * K_, center + label_value * K_, distance_data + i * K_);
+    caffe_sub(K_, bottom_data + i * K_,
+              center + label_value * K_, distance_data + i * K_);
   }
-  Dtype dot = caffe_cpu_dot(M_ * K_, distance_.cpu_data(), distance_.cpu_data());
+  Dtype dot = caffe_cpu_dot(M_ * K_, distance_.cpu_data(),
+                            distance_.cpu_data());
   Dtype loss = dot / M_ / Dtype(2);
   top[0]->mutable_cpu_data()[0] = loss;
 }
 
 template <typename Dtype>
 void CenterLossLayer<Dtype>::Backward_cpu(const vector<Blob<Dtype>*>& top,
-                                          const vector<bool>& propagate_down,
-                                          const vector<Blob<Dtype>*>& bottom) {
+    const vector<bool>& propagate_down,
+    const vector<Blob<Dtype>*>& bottom) {
   // Gradient with respect to centers
   if (this->param_propagate_down_[0]) {
     const Dtype* label = bottom[1]->cpu_data();
     Dtype* center_diff = this->blobs_[0]->mutable_cpu_diff();
     Dtype* variation_sum_data = variation_sum_.mutable_cpu_data();
+    int* count_data = count_.mutable_cpu_data();
+    const Dtype* distance_data = distance_.cpu_data();
     // \sum_{y_i==j}
     caffe_set(N_ * K_, (Dtype)0., variation_sum_.mutable_cpu_data());
-    for (int n = 0; n < N_; n++) {
-      int count = 0;
-      for (int m = 0; m < M_; m++) {
-        const int label_value = static_cast<int>(label[m]);
-        if (label_value == n) {
-          count++;
-          caffe_sub(K_, variation_sum_data + n * K_, distance_data + m * K_, variation_sum_data + n * K_);
-        }
-      }
-      caffe_axpy(K_, (Dtype)1. / (count + (Dtype)1.), variation_sum_data + n * K_, center_diff + n * K_);
+    caffe_set(N_, 0, count_.mutable_cpu_data());
+
+    for (int m = 0; m < M_; m++) {
+      const int label_value = static_cast<int>(label[m]);
+      caffe_sub(K_, variation_sum_data + label_value * K_,
+                distance_data + m * K_, variation_sum_data + label_value * K_);
+      count_data[label_value]++;
+    }
+    for (int m = 0; m < M_; m++) {
+      const int n = static_cast<int>(label[m]);
+      caffe_cpu_axpby(K_, (Dtype)1. / (count_data[n] + (Dtype)1.),
+                      variation_sum_data + n * K_,
+                      (Dtype)0., center_diff + n * K_);
     }
   }
-  // Gradient with respect to bottom data 
+  // Gradient with respect to bottom data
   if (propagate_down[0]) {
-    caffe_copy(M_ * K_, distance_.cpu_data(), bottom[0]->mutable_cpu_diff());
-    caffe_scal(M_ * K_, top[0]->cpu_diff()[0] / M_, bottom[0]->mutable_cpu_diff());
+    caffe_copy(M_ * K_, distance_.cpu_data(),
+               bottom[0]->mutable_cpu_diff());
+    caffe_scal(M_ * K_, top[0]->cpu_diff()[0] / M_,
+               bottom[0]->mutable_cpu_diff());
   }
   if (propagate_down[1]) {
     LOG(FATAL) << this->type()
@@ -105,6 +115,7 @@ void CenterLossLayer<Dtype>::Backward_cpu(const vector<Blob<Dtype>*>& top,
   }
 }
 
+
 #ifdef CPU_ONLY
 STUB_GPU(CenterLossLayer);
 #endif
diff --git a/src/caffe/layers/center_loss_layer.cu b/src/caffe/layers/center_loss_layer.cu
index f493557d5fd..24e8a63b8a3 100644
--- a/src/caffe/layers/center_loss_layer.cu
+++ b/src/caffe/layers/center_loss_layer.cu
@@ -7,8 +7,11 @@ namespace caffe {
 
 template <typename Dtype>
-__global__ void Compute_distance_data_gpu(int nthreads, const int K, const Dtype* bottom,
-        const Dtype* label, const Dtype* center, Dtype* distance) {
+__global__ void Compute_distance_data_gpu(int nthreads, const int K,
+                                          const Dtype* bottom,
+                                          const Dtype* label,
+                                          const Dtype* center,
+                                          Dtype* distance) {
   CUDA_KERNEL_LOOP(index, nthreads) {
     int m = index / K;
     int k = index % K;
@@ -17,55 +20,68 @@ __global__ void Compute_distance_data_gpu(int nthreads, const int K, const Dtype* bottom,
     distance[index] = bottom[index] - center[label_value * K + k];
   }
 }
-
 template <typename Dtype>
-__global__ void Compute_center_diff_gpu(int nthreads, const int M, const int K,
-        const Dtype* label, const Dtype* distance, Dtype* variation_sum,
-        Dtype* center_diff) {
+__global__ void Compute_variation_sum_gpu(int nthreads, const int K,
+                                          const Dtype* label,
+                                          const Dtype* distance,
+                                          Dtype* variation_sum, int* count) {
   CUDA_KERNEL_LOOP(index, nthreads) {
-    int count = 0;
-    for (int m = 0; m < M; m++) {
-      const int label_value = static_cast<int>(label[m]);
-      if (label_value == index) {
-        count++;
-        for (int k = 0; k < K; k++) {
-          variation_sum[index * K + k] -= distance[m * K + k];
-        }
-      }
-    }
-    for (int k = 0; k < K; k++) {
-      center_diff[index * K + k] = variation_sum[index * K + k] / (count + (Dtype)1.);
-    }
+    int m = index / K;
+    int k = index % K;
+    const int label_value = static_cast<int>(label[m]);
+    variation_sum[label_value * K + k] -= distance[m * K + k];
+    count[label_value] += ((k == 0) ? 1 : 0);
   }
 }
-
-
+
+template <typename Dtype>
+__global__ void Compute_center_diff_gpu(int nthreads, const int K,
+                                        const Dtype* label,
+                                        Dtype* variation_sum,
+                                        int* count, Dtype* center_diff) {
+  CUDA_KERNEL_LOOP(index, nthreads) {
+    int m = index / K;
+    int k = index % K;
+    const int n = static_cast<int>(label[m]);
+    center_diff[n * K + k] = variation_sum[n * K + k]
+                             / (count[n] + (Dtype)1.);
+  }
+}
 
 template <typename Dtype>
 void CenterLossLayer<Dtype>::Forward_gpu(const vector<Blob<Dtype>*>& bottom,
     const vector<Blob<Dtype>*>& top) {
   int nthreads = M_ * K_;
-  Compute_distance_data_gpu<Dtype><<<CAFFE_GET_BLOCKS(nthreads),
-      CAFFE_CUDA_NUM_THREADS>>>(nthreads, K_, bottom[0]->gpu_data(), bottom[1]->gpu_data(),
-      this->blobs_[0]->gpu_data(), distance_.mutable_gpu_data());
+  Compute_distance_data_gpu<Dtype> <<< CAFFE_GET_BLOCKS(nthreads),
+      CAFFE_CUDA_NUM_THREADS>>>(nthreads, K_, bottom[0]->gpu_data(),
+                                bottom[1]->gpu_data(),
+                                this->blobs_[0]->gpu_data(),
+                                distance_.mutable_gpu_data());
   Dtype dot;
   caffe_gpu_dot(M_ * K_, distance_.gpu_data(), distance_.gpu_data(), &dot);
   Dtype loss = dot / M_ / Dtype(2);
   top[0]->mutable_cpu_data()[0] = loss;
 }
-
 template <typename Dtype>
 void CenterLossLayer<Dtype>::Backward_gpu(const vector<Blob<Dtype>*>& top,
     const vector<bool>& propagate_down,
     const vector<Blob<Dtype>*>& bottom) {
-  int nthreads = N_;
-  caffe_gpu_set(N_ * K_, (Dtype)0., variation_sum_.mutable_cpu_data());
-  Compute_center_diff_gpu<Dtype><<<CAFFE_GET_BLOCKS(nthreads),
-      CAFFE_CUDA_NUM_THREADS>>>(nthreads, M_, K_, bottom[1]->gpu_data(), distance_.gpu_data(),
-      variation_sum_.mutable_cpu_data(), this->blobs_[0]->mutable_gpu_diff());
-
+  caffe_gpu_set(N_ * K_, (Dtype)0., variation_sum_.mutable_gpu_data());
+  caffe_gpu_set(N_, 0, count_.mutable_gpu_data());
+  int nthreads = M_ * K_;
+  Compute_variation_sum_gpu<Dtype> <<< CAFFE_GET_BLOCKS(nthreads),
+      CAFFE_CUDA_NUM_THREADS>>>(nthreads, K_, bottom[1]->gpu_data(),
+                                distance_.gpu_data(),
+                                variation_sum_.mutable_gpu_data(),
+                                count_.mutable_gpu_data());
+  Compute_center_diff_gpu<Dtype> <<< CAFFE_GET_BLOCKS(nthreads),
+      CAFFE_CUDA_NUM_THREADS>>>(nthreads, K_, bottom[1]->gpu_data(),
+                                variation_sum_.mutable_gpu_data(),
+                                count_.mutable_gpu_data(),
+                                this->blobs_[0]->mutable_gpu_diff());
   if (propagate_down[0]) {
-    caffe_gpu_scale(M_ * K_, top[0]->cpu_diff()[0] / M_,
-                    distance_.gpu_data(), bottom[0]->mutable_gpu_diff());
+    caffe_gpu_scale(M_ * K_,
+                    top[0]->cpu_diff()[0] / M_,
+                    distance_.gpu_data(),
+                    bottom[0]->mutable_gpu_diff());
   }
   if (propagate_down[1]) {
     LOG(FATAL) << this->type()
diff --git a/src/caffe/layers/normalize_layer.cpp b/src/caffe/layers/normalize_layer.cpp
new file mode 100644
index 00000000000..6968151129e
--- /dev/null
+++ b/src/caffe/layers/normalize_layer.cpp
@@ -0,0 +1,62 @@
+#include <algorithm>
+#include <cmath>
+#include <vector>
+
+#include "caffe/layer.hpp"
+#include "caffe/util/math_functions.hpp"
+#include "caffe/layers/normalize_layer.hpp"
+
+namespace caffe {
+
+template <typename Dtype>
+void NormalizeLayer<Dtype>::Reshape(const vector<Blob<Dtype>*>& bottom,
+    const vector<Blob<Dtype>*>& top) {
+  top[0]->Reshape(bottom[0]->num(), bottom[0]->channels(),
+      bottom[0]->height(), bottom[0]->width());
+  squared_.Reshape(bottom[0]->num(), bottom[0]->channels(),
+      bottom[0]->height(), bottom[0]->width());
+}
+
+template <typename Dtype>
+void NormalizeLayer<Dtype>::Forward_cpu(const vector<Blob<Dtype>*>& bottom,
+    const vector<Blob<Dtype>*>& top) {
+  const Dtype* bottom_data = bottom[0]->cpu_data();
+  Dtype* top_data = top[0]->mutable_cpu_data();
+  Dtype* squared_data = squared_.mutable_cpu_data();
+  int n = bottom[0]->num();
+  int d = bottom[0]->count() / n;
+  caffe_sqr<Dtype>(n * d, bottom_data, squared_data);
+  for (int i = 0; i < n; ++i) {
+    Dtype normsqr = caffe_cpu_asum<Dtype>(d, squared_data + i * d);
+    normsqr = (normsqr > Dtype(0)) ? normsqr : Dtype(1e-6);
+    caffe_cpu_scale<Dtype>(d, pow(normsqr, -0.5),
+        bottom_data + i * d, top_data + i * d);
+  }
+}
+
+template <typename Dtype>
+void NormalizeLayer<Dtype>::Backward_cpu(const vector<Blob<Dtype>*>& top,
+    const vector<bool>& propagate_down, const vector<Blob<Dtype>*>& bottom) {
+  const Dtype* top_diff = top[0]->cpu_diff();
+  const Dtype* top_data = top[0]->cpu_data();
+  const Dtype* bottom_data = bottom[0]->cpu_data();
+  Dtype* bottom_diff = bottom[0]->mutable_cpu_diff();
+  int n = top[0]->num();
+  int d = top[0]->count() / n;
+  for (int i = 0; i < n; ++i) {
+    Dtype a = caffe_cpu_dot(d, top_data + i * d, top_diff + i * d);
+    caffe_cpu_scale(d, a, top_data + i * d, bottom_diff + i * d);
+    caffe_sub(d, top_diff + i * d, bottom_diff + i * d, bottom_diff + i * d);
+    a = caffe_cpu_dot(d, bottom_data + i * d, bottom_data + i * d);
+    caffe_cpu_scale<Dtype>(d, pow(a, -0.5),
+        bottom_diff + i * d, bottom_diff + i * d);
+  }
+}
+
+#ifdef CPU_ONLY
+STUB_GPU(NormalizeLayer);
+#endif
+
+INSTANTIATE_CLASS(NormalizeLayer);
+REGISTER_LAYER_CLASS(Normalize);
+
+}  // namespace caffe
diff --git a/src/caffe/layers/normalize_layer.cu b/src/caffe/layers/normalize_layer.cu
new file mode 100644
--- /dev/null
+++ b/src/caffe/layers/normalize_layer.cu
@@ -0,0 +1,51 @@
+#include <algorithm>
+#include <cmath>
+#include <vector>
+#include "caffe/layer.hpp"
+#include "caffe/util/math_functions.hpp"
+#include "caffe/layers/normalize_layer.hpp"
+
+namespace caffe {
+
+template <typename Dtype>
+void NormalizeLayer<Dtype>::Forward_gpu(const vector<Blob<Dtype>*>& bottom,
+    const vector<Blob<Dtype>*>& top) {
+  const Dtype* bottom_data = bottom[0]->gpu_data();
+  Dtype* top_data = top[0]->mutable_gpu_data();
+  Dtype* squared_data = squared_.mutable_gpu_data();
+  Dtype normsqr;
+  int n = bottom[0]->num();
+  int d = bottom[0]->count() / n;
+  caffe_gpu_powx(n * d, bottom_data, Dtype(2), squared_data);
+  for (int i = 0; i < n; ++i) {
+    caffe_gpu_asum<Dtype>(d, squared_data + i * d, &normsqr);
+    normsqr = (normsqr > Dtype(0)) ? normsqr : Dtype(1e-6);
+    caffe_gpu_scale<Dtype>(d, pow(normsqr, -0.5),
+        bottom_data + i * d, top_data + i * d);
+  }
+}
+
+template <typename Dtype>
+void NormalizeLayer<Dtype>::Backward_gpu(const vector<Blob<Dtype>*>& top,
+    const vector<bool>& propagate_down, const vector<Blob<Dtype>*>& bottom) {
+  const Dtype* top_diff = top[0]->gpu_diff();
+  const Dtype* top_data = top[0]->gpu_data();
+  const Dtype* bottom_data = bottom[0]->gpu_data();
+  Dtype* bottom_diff = bottom[0]->mutable_gpu_diff();
+  int n = top[0]->num();
+  int d = top[0]->count() / n;
+  Dtype a;
+  for (int i = 0; i < n; ++i) {
+    caffe_gpu_dot(d, top_data + i * d, top_diff + i * d, &a);
+    caffe_gpu_scale(d, a, top_data + i * d, bottom_diff + i * d);
+    caffe_gpu_sub(d, top_diff + i * d,
+        bottom_diff + i * d, bottom_diff + i * d);
+    caffe_gpu_dot(d, bottom_data + i * d, bottom_data + i * d, &a);
+    caffe_gpu_scale<Dtype>(d, pow(a, -0.5),
+        bottom_diff + i * d, bottom_diff + i * d);
+  }
+}
+
+INSTANTIATE_LAYER_GPU_FUNCS(NormalizeLayer);
+
+}  // namespace caffe