Skip to content

Commit 29dc3c5

Browse files
aocsa and facebook-github-bot
authored and committed Sep 24, 2020
Sparse softmax support (CUDA) (pytorch#42307)
Summary: This PR implements softmax support for sparse tensors. Resolves pytorch#23651 for CUDA.

- [x] sparse softmax
  - [x] CUDA C++ implementation
  - [x] unittests
  - [x] update softmax documentation
  - [x] autograd support
- [x] sparse log_softmax
  - [x] CUDA C++ implementation
  - [x] unittests
  - [x] update log_softmax documentation
  - [x] autograd support

Here are some benchmark results (the script is [here](https://gist.github.com/aocsa/fbc1827b3e49901512a33ba96092cbc1)) for `torch.sparse.softmax` and `torch.softmax`, using CPU and GPU. Values are float64 scalars and the timing repeat count is 1000:

| size         | density | sparse CUDA | sparse CPU |
|--------------|---------|-------------|------------|
| (32, 10000)  | 0.01    | 380.2       | 687.5      |
| (32, 10000)  | 0.05    | 404.3       | 2357.9     |
| (32, 10000)  | 0.1     | 405.9       | 3677.2     |
| (512, 10000) | 0.01    | 438.0       | 5443.4     |
| (512, 10000) | 0.05    | 888.1       | 24485.0    |
| (512, 10000) | 0.1     | 1921.3      | 45340.5    |

| size         | density | dense CUDA  | dense CPU  |
|--------------|---------|-------------|------------|
| (32, 10000)  | 0.01    | 23.6        | 1943.2     |
| (32, 10000)  | 0.05    | 23.6        | 1954.0     |
| (32, 10000)  | 0.1     | 23.5        | 1950.0     |
| (512, 10000) | 0.01    | 639.3       | 39797.9    |
| (512, 10000) | 0.05    | 640.3       | 39374.4    |
| (512, 10000) | 0.1     | 639.6       | 39192.3    |

Times are in microseconds (us).

Quick note: I updated the performance test again.

Pull Request resolved: pytorch#42307

Reviewed By: ngimel

Differential Revision: D23774427

Pulled By: mruberry

fbshipit-source-id: bfabf726075b39dde544c10249f27ae1871f82c7
1 parent b3d7c2f commit 29dc3c5

File tree

6 files changed

+777
-64
lines changed

6 files changed

+777
-64
lines changed
 

‎aten/src/ATen/native/native_functions.yaml

+4
Original file line numberDiff line numberDiff line change
@@ -3676,11 +3676,13 @@
36763676
use_c10_dispatcher: full
36773677
dispatch:
36783678
SparseCPU: softmax_sparse_cpu
3679+
SparseCUDA: softmax_sparse_cuda
36793680

36803681
- func: _sparse_softmax_backward_data(Tensor grad_output, Tensor output, int dim, Tensor self) -> Tensor
36813682
use_c10_dispatcher: full
36823683
dispatch:
36833684
SparseCPU: softmax_backward_sparse_cpu
3685+
SparseCUDA: softmax_backward_sparse_cuda
36843686

36853687
- func: _sparse_log_softmax.int(Tensor self, int dim, ScalarType? dtype=None) -> Tensor
36863688
use_c10_dispatcher: full
@@ -3693,11 +3695,13 @@
36933695
use_c10_dispatcher: full
36943696
dispatch:
36953697
SparseCPU: log_softmax_sparse_cpu
3698+
SparseCUDA: log_softmax_sparse_cuda
36963699

36973700
- func: _sparse_log_softmax_backward_data(Tensor grad_output, Tensor output, int dim, Tensor self) -> Tensor
36983701
use_c10_dispatcher: full
36993702
dispatch:
37003703
SparseCPU: log_softmax_backward_sparse_cpu
3704+
SparseCUDA: log_softmax_backward_sparse_cuda
37013705

37023706
- func: norm.ScalarOpt_dtype(Tensor self, Scalar? p, *, ScalarType dtype) -> Tensor
37033707
use_c10_dispatcher: full
+53
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,53 @@
1+
#include <ATen/native/sparse/ParamUtils.h>
2+
#include <ATen/TensorUtils.h>
3+
#include <ATen/ATen.h>
4+
#include <tuple>
5+
6+
namespace at {
7+
namespace native {
8+
9+
std::pair<Tensor, Tensor> softmax_sparse_input_preprocessing(
10+
const Tensor& input_,
11+
const int64_t dim_,
12+
const bool half_to_float,
13+
CheckedFrom function_name) {
14+
TORCH_INTERNAL_ASSERT(input_.is_sparse());
15+
TORCH_CHECK(
16+
!half_to_float,
17+
std::string(function_name) +
18+
": with half to float conversion is not supported on " +
19+
input_.device().str());
20+
auto input = input_.coalesce();
21+
Tensor output = at::native::empty_like(input);
22+
TORCH_CHECK(
23+
dim_ >= 0 && dim_ < input.dim(),
24+
": dim must be non-negative and less than input dimensions");
25+
return std::make_pair(input, output);
26+
}
27+
28+
std::tuple<Tensor, Tensor, Tensor> softmax_backward_sparse_input_preprocessing(
29+
const Tensor& grad_,
30+
const Tensor& output_,
31+
int64_t dim_,
32+
const Tensor& input_,
33+
CheckedFrom function_name) {
34+
TensorArg grad_arg{grad_, "grad", 1}, output_arg{output_, "output", 2};
35+
checkSameSize(function_name, grad_arg, output_arg);
36+
37+
int64_t dim = maybe_wrap_dim(dim_, grad_.dim());
38+
39+
auto grad = grad_.coalesce();
40+
auto output = output_.coalesce();
41+
42+
Tensor grad_input = at::native::empty_like(output);
43+
TORCH_CHECK(
44+
dim >= 0 && dim < grad.dim(),
45+
": dim must be non-negative and less than input dimensions");
46+
TORCH_CHECK(
47+
grad.sparse_dim() == output.sparse_dim(),
48+
": grad and output sparse dimensions must be equal");
49+
return std::make_tuple(grad_input, grad, output);
50+
}
51+
52+
} // namespace native
53+
} // namespace at
+24
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,24 @@
1+
#pragma once
2+
3+
#include <ATen/ATen.h>
4+
#include <ATen/TensorUtils.h>
5+
#include <tuple>
6+
7+
namespace at {
8+
namespace native {
9+
10+
// Argument validation for the sparse softmax / log_softmax forward
// implementations. Per the definition in ParamUtils.cpp it asserts that
// input_ is sparse, rejects half_to_float == true, requires
// 0 <= dim_ < input.dim() (dim_ is not wrapped), and returns the pair
// (coalesced input, output allocated with empty_like(input)).
// function_name is used in the error messages.
TORCH_API std::pair<Tensor, Tensor> softmax_sparse_input_preprocessing(
11+
const Tensor& input_,
12+
const int64_t dim_,
13+
const bool half_to_float,
14+
CheckedFrom function_name);
15+
16+
// Argument validation for the sparse softmax / log_softmax backward
// implementations. Per the definition in ParamUtils.cpp it checks that grad_
// and output_ have the same size, wraps dim_ against grad_.dim() for
// validation, requires equal sparse_dim() on grad and output, and returns the
// tuple (grad_input allocated with empty_like(output), coalesced grad,
// coalesced output). The wrapped dim is not returned, and input_ is not read
// by the definition.
TORCH_API std::tuple<Tensor, Tensor, Tensor> softmax_backward_sparse_input_preprocessing(
17+
const Tensor& grad_,
18+
const Tensor& output_,
19+
int64_t dim_,
20+
const Tensor& input_,
21+
CheckedFrom function_name);
22+
23+
} // namespace native
24+
} // namespace at

‎aten/src/ATen/native/sparse/SoftMax.cpp

+54-63
Original file line numberDiff line numberDiff line change
@@ -4,6 +4,7 @@
44
#include <ATen/SparseTensorUtils.h>
55
#include <ATen/Parallel.h>
66
#include <ATen/NamedTensorUtils.h>
7+
#include <ATen/native/sparse/ParamUtils.h>
78
#include <map>
89

910
namespace at {
@@ -291,10 +292,10 @@ void cpu_sparse_coo_softmax(Tensor output, const Tensor& input, const int64_t di
291292
if (dim >= sparse_dim) {
292293
if (LogSoftMax) {
293294
auto new_values = log_softmax_cpu(values, dim - sparse_dim + 1, false);
294-
out_values.copy_(new_values);
295+
out_values.set_(new_values);
295296
} else {
296297
auto new_values = softmax_cpu(values, dim - sparse_dim + 1, false);
297-
out_values.copy_(new_values);
298+
out_values.set_(new_values);
298299
}
299300
return;
300301
}
@@ -411,17 +412,27 @@ void cpu_sparse_coo_softmax_backward(Tensor& grad_input, const Tensor& grad, con
411412
auto grad_offsets = get_offsets(grad_indices, sizes, -1);
412413

413414
if (dim >= sparse_dim) {
414-
for(int64_t i=0; i<out_nnz; i++) {
415-
Tensor unused;
416-
auto low = std::lower_bound(grad_offsets.begin(), grad_offsets.end(), out_offsets[i]);
417-
auto j = low - grad_offsets.begin();
418-
if (j < grad_nnz && out_offsets[i] == grad_offsets[j]) {
419-
if (LogSoftMax) {
420-
auto r = log_softmax_backward_cpu(grad_values[j], out_values[i], dim - sparse_dim, unused);
421-
values[i].copy_(r);
422-
} else {
423-
auto r = softmax_backward_cpu(grad_values[j], out_values[i], dim - sparse_dim, unused);
424-
values[i].copy_(r);
415+
Tensor unused;
416+
if (out_offsets == grad_offsets) {
417+
if (LogSoftMax) {
418+
auto r = log_softmax_backward_cpu(grad_values, out_values, dim - sparse_dim + 1, unused);
419+
values.set_(r);
420+
} else {
421+
auto r = softmax_backward_cpu(grad_values, out_values, dim - sparse_dim + 1, unused);
422+
values.set_(r);
423+
}
424+
} else {
425+
for(int64_t i=0; i<out_nnz; i++) {
426+
auto low = std::lower_bound(grad_offsets.begin(), grad_offsets.end(), out_offsets[i]);
427+
auto j = low - grad_offsets.begin();
428+
if (j < grad_nnz && out_offsets[i] == grad_offsets[j]) {
429+
if (LogSoftMax) {
430+
auto r = log_softmax_backward_cpu(grad_values[j], out_values[i], dim - sparse_dim, unused);
431+
values[i].copy_(r);
432+
} else {
433+
auto r = softmax_backward_cpu(grad_values[j], out_values[i], dim - sparse_dim, unused);
434+
values[i].copy_(r);
435+
}
425436
}
426437
}
427438
}
@@ -503,36 +514,36 @@ void cpu_sparse_coo_softmax_backward(Tensor& grad_input, const Tensor& grad, con
503514
});
504515
}
505516

506-
} // namespace
517+
} // anonymous namespace
507518

508-
Tensor softmax_sparse_cpu(const Tensor& input_, const int64_t dim_, const bool half_to_float) {
509-
TORCH_INTERNAL_ASSERT(input_.is_sparse());
510-
TORCH_CHECK(!half_to_float, "softmax with half to float conversion is not supported on CPU");
511-
auto input = input_.coalesce();
512-
Tensor output = at::native::empty_like(input);
519+
Tensor softmax_sparse_cpu(
520+
const Tensor& input_,
521+
const int64_t dim,
522+
const bool half_to_float) {
523+
Tensor input, output;
524+
std::tie(input, output) = softmax_sparse_input_preprocessing(
525+
input_, dim, half_to_float, "softmax");
513526
if (input.numel() == 0) {
514527
return output;
515528
}
516-
TORCH_CHECK(dim_ >= 0 && dim_ < input.dim(),
517-
"dim must be non-negative and less than input dimensions");
518529
AT_DISPATCH_FLOATING_TYPES(input.scalar_type(), "softmax", [&] {
519-
cpu_sparse_coo_softmax<scalar_t, false>(output, input, dim_);
530+
cpu_sparse_coo_softmax<scalar_t, false>(output, input, dim);
520531
});
521532
return output;
522533
}
523534

524-
Tensor log_softmax_sparse_cpu(const Tensor& input_, const int64_t dim_, const bool half_to_float) {
525-
TORCH_INTERNAL_ASSERT(input_.is_sparse());
526-
TORCH_CHECK(!half_to_float, "log_softmax with half to float conversion is not supported on CPU");
527-
auto input = input_.coalesce();
528-
Tensor output = at::native::empty_like(input);
535+
Tensor log_softmax_sparse_cpu(
536+
const Tensor& input_,
537+
const int64_t dim,
538+
const bool half_to_float) {
539+
Tensor input, output;
540+
std::tie(input, output) = softmax_sparse_input_preprocessing(
541+
input_, dim, half_to_float, "log_softmax");
529542
if (input.numel() == 0) {
530543
return output;
531544
}
532-
TORCH_CHECK(dim_ >= 0 && dim_ < input.dim(),
533-
"dim must be non-negative and less than input dimensions");
534545
AT_DISPATCH_FLOATING_TYPES(input.scalar_type(), "log_softmax", [&] {
535-
cpu_sparse_coo_softmax<scalar_t, true>(output, input, dim_);
546+
cpu_sparse_coo_softmax<scalar_t, true>(output, input, dim);
536547
});
537548
return output;
538549
}
@@ -542,26 +553,16 @@ Tensor softmax_backward_sparse_cpu(
542553
const Tensor& output_,
543554
int64_t dim_,
544555
const Tensor& input_) {
545-
TensorArg grad_arg{grad_, "grad", 1}, output_arg{output_, "output", 2};
546-
checkSameSize("softmax_backward", grad_arg, output_arg);
547-
548-
int64_t dim = maybe_wrap_dim(dim_, grad_.dim());
549-
550-
auto grad = grad_.coalesce();
551-
auto output = output_.coalesce();
552-
553-
Tensor grad_input = at::native::empty_like(output);
556+
Tensor grad_input, grad, output;
557+
std::tie(grad_input, grad, output) =
558+
softmax_backward_sparse_input_preprocessing(
559+
grad_, output_, dim_, input_, "softmax_backward");
554560
if (output.numel() == 0) {
555561
return grad_input;
556562
}
557-
TORCH_CHECK(
558-
dim >= 0 && dim < grad.dim(),
559-
"dim must be non-negative and less than input dimensions");
560-
TORCH_CHECK(
561-
grad.sparse_dim() == output.sparse_dim(),
562-
"grad and output sparse dimensions must be equal");
563563
AT_DISPATCH_FLOATING_TYPES(grad.scalar_type(), "softmax_backward", [&] {
564-
cpu_sparse_coo_softmax_backward<scalar_t, false>(grad_input, grad, output, dim);
564+
cpu_sparse_coo_softmax_backward<scalar_t, false>(
565+
grad_input, grad, output, dim_);
565566
});
566567
return grad_input;
567568
}
@@ -571,26 +572,16 @@ Tensor log_softmax_backward_sparse_cpu(
571572
const Tensor& output_,
572573
int64_t dim_,
573574
const Tensor& input_) {
574-
TensorArg grad_arg{grad_, "grad", 1}, output_arg{output_, "output", 2};
575-
checkSameSize("log_softmax_backward", grad_arg, output_arg);
576-
577-
int64_t dim = maybe_wrap_dim(dim_, grad_.dim());
578-
579-
auto grad = grad_.coalesce();
580-
auto output = output_.coalesce();
581-
582-
Tensor grad_input = at::native::empty_like(output);
575+
Tensor grad_input, grad, output;
576+
std::tie(grad_input, grad, output) =
577+
softmax_backward_sparse_input_preprocessing(
578+
grad_, output_, dim_, input_, "log_softmax_backward");
583579
if (output.numel() == 0) {
584580
return grad_input;
585581
}
586-
TORCH_CHECK(
587-
dim >= 0 && dim < grad.dim(),
588-
"dim must be non-negative and less than input dimensions");
589-
TORCH_CHECK(
590-
grad.sparse_dim() == output.sparse_dim(),
591-
"grad and output sparse dimensions must be equal");
592-
AT_DISPATCH_FLOATING_TYPES(grad.scalar_type(), "softmax_backward", [&] {
593-
cpu_sparse_coo_softmax_backward<scalar_t, true>(grad_input, grad, output, dim);
582+
AT_DISPATCH_FLOATING_TYPES(grad.scalar_type(), "log_softmax_backward", [&] {
583+
cpu_sparse_coo_softmax_backward<scalar_t, true>(
584+
grad_input, grad, output, dim_);
594585
});
595586
return grad_input;
596587
}

0 commit comments

Comments
 (0)
Please sign in to comment.