Commit 69b0e8c

EmbedBackward with no loops -- use caffe_gpu_atomic_add instead
1 parent 86aad5c commit 69b0e8c
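
In outline: the old EmbedBackward assigned one thread per weight entry and had each thread scan all M inputs for matching indices; the new kernel assigns one thread per top element and scatters its gradient directly into the selected weight row. Because several inputs can select the same row, the scatter must be atomic, hence caffe_gpu_atomic_add. A minimal standalone sketch of the pattern, not part of this commit, with hypothetical names (scatter_add) and toy sizes, and the CUDA built-in atomicAdd standing in for caffe_gpu_atomic_add:

// toy_scatter.cu -- hypothetical standalone demo (not part of this commit).
#include <cstdio>
#include <cuda_runtime.h>

// One thread per top element: thread i owns element d = i % N of input
// row n = i / N, and adds its gradient into the weight row selected by
// rows[n]. atomicAdd keeps concurrent updates to a shared row correct.
__global__ void scatter_add(const float* top_diff, const int* rows,
                            int count, int N, float* weight_diff) {
  int i = blockIdx.x * blockDim.x + threadIdx.x;
  if (i >= count) return;
  int row = rows[i / N];                    // several i may share this row
  atomicAdd(weight_diff + row * N + i % N, top_diff[i]);
}

int main() {
  const int M = 4, N = 2, K = 3;            // 4 inputs, dim 2, 3-row table
  int h_rows[M] = {1, 0, 1, 2};             // row 1 is picked twice
  float h_top[M * N] = {1, 1, 2, 2, 3, 3, 4, 4};
  float h_wd[K * N] = {0};

  int* d_rows; float *d_top, *d_wd;
  cudaMalloc(&d_rows, sizeof h_rows);
  cudaMalloc(&d_top, sizeof h_top);
  cudaMalloc(&d_wd, sizeof h_wd);
  cudaMemcpy(d_rows, h_rows, sizeof h_rows, cudaMemcpyHostToDevice);
  cudaMemcpy(d_top, h_top, sizeof h_top, cudaMemcpyHostToDevice);
  cudaMemcpy(d_wd, h_wd, sizeof h_wd, cudaMemcpyHostToDevice);

  scatter_add<<<1, 64>>>(d_top, d_rows, M * N, N, d_wd);
  cudaMemcpy(h_wd, d_wd, sizeof h_wd, cudaMemcpyDeviceToHost);

  // Expected: row 0 = (2, 2); row 1 = (1+3, 1+3); row 2 = (4, 4).
  for (int r = 0; r < K; ++r)
    printf("row %d: %g %g\n", r, h_wd[r * N], h_wd[r * N + 1]);
  cudaFree(d_rows); cudaFree(d_top); cudaFree(d_wd);
  return 0;
}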

File tree

1 file changed: +15 -10 lines changed

src/caffe/layers/embed_layer.cu

@@ -5,6 +5,7 @@
 #include "caffe/common_layers.hpp"
 #include "caffe/filler.hpp"
 #include "caffe/layer.hpp"
+#include "caffe/util/gpu_util.cuh"
 #include "caffe/util/math_functions.hpp"
 
 namespace caffe {
@@ -22,18 +23,21 @@ __global__ void EmbedForward(const int nthreads, const Dtype* bottom_data,
   }
 }
 
+template <typename Dtype>
+__global__ void EmbedBackward(const int nthreads, const Dtype* bottom_data,
+    const Dtype* top_diff, const int M, const int N, const int K,
+    Dtype* weight_diff);
+
 template <typename Dtype>
 __global__ void EmbedBackward(const int nthreads, const Dtype* bottom_data,
     const Dtype* top_diff, const int M, const int N, const int K,
     Dtype* weight_diff) {
-  CUDA_KERNEL_LOOP(weight_index, nthreads) {
-    const int index = weight_index / N;
-    const int output_index = weight_index % N;
-    for (int n = 0; n < M; ++n) {
-      if (static_cast<int>(bottom_data[n]) == index) {
-        weight_diff[weight_index] += top_diff[n * N + output_index];
-      }
-    }
+  CUDA_KERNEL_LOOP(top_index, nthreads) {
+    const int n = top_index / N;
+    const int d = top_index % N;
+    const int index = static_cast<int>(bottom_data[n]);
+    const int weight_index = index * N + d;
+    caffe_gpu_atomic_add(top_diff[top_index], weight_diff + weight_index);
   }
 }

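The caffe_gpu_atomic_add used above comes from the newly included caffe/util/gpu_util.cuh. As a rough sketch of what such a helper has to provide (an assumption; the exact Caffe code may differ, and the name my_atomic_add is hypothetical): wrap the native single-precision atomicAdd, and emulate double precision with a compare-and-swap loop, following the pattern in the CUDA C Programming Guide. Note the (val, address) argument order matches the call site in the kernel.

// Hypothetical sketch of a type-generic atomic add, along the lines of
// what gpu_util.cuh would need to provide (exact Caffe code may differ).
__device__ float my_atomic_add(const float val, float* address) {
  return atomicAdd(address, val);           // native single-precision atomic
}

__device__ double my_atomic_add(const double val, double* address) {
  // No native double-precision atomicAdd on older GPUs: emulate with a
  // compare-and-swap loop on the 64-bit representation.
  unsigned long long int* address_as_ull =
      reinterpret_cast<unsigned long long int*>(address);
  unsigned long long int old = *address_as_ull, assumed;
  do {
    assumed = old;
    old = atomicCAS(address_as_ull, assumed,
        __double_as_longlong(val + __longlong_as_double(assumed)));
  } while (assumed != old);                 // retry if another thread won
  return __longlong_as_double(old);
}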
@@ -59,13 +63,14 @@ void EmbedLayer<Dtype>::Backward_gpu(const vector<Blob<Dtype>*>& top,
     const vector<bool>& propagate_down, const vector<Blob<Dtype>*>& bottom) {
   CHECK(!propagate_down[0]) << "Can't backpropagate to EmbedLayer input.";
   if (this->param_propagate_down_[0]) {
+    const int top_count = top[0]->count();
     const int count = this->blobs_[0]->count();
     const Dtype* top_diff = top[0]->gpu_diff();
     const Dtype* bottom_data = bottom[0]->gpu_data();
     Dtype* weight_diff = this->blobs_[0]->mutable_gpu_diff();
     EmbedBackward<Dtype>  // NOLINT_NEXT_LINE(whitespace/operators)
-        <<<CAFFE_GET_BLOCKS(count), CAFFE_CUDA_NUM_THREADS>>>(
-        count, bottom_data, top_diff, M_, N_, K_, weight_diff);
+        <<<CAFFE_GET_BLOCKS(top_count), CAFFE_CUDA_NUM_THREADS>>>(
+        top_count, bottom_data, top_diff, M_, N_, K_, weight_diff);
   }
   if (bias_term_ && this->param_propagate_down_[1]) {
     const Dtype* top_diff = top[0]->gpu_diff();
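
Note the matching launch-size change in Backward_gpu: the grid now covers top_count = top[0]->count() (the M x N top elements) rather than count = this->blobs_[0]->count() (the K x N weight entries). With illustrative sizes not taken from the commit, say M = 64 inputs, N = 128 dimensions, K = 10000 vocabulary rows, that is 8,192 threads each issuing one atomic add, instead of 1,280,000 threads each scanning all 64 inputs.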
