
Commit 241afc9

RockingJavaBean authored and facebook-github-bot committed on Sep 25, 2020
Migrate addr from the TH to Aten (CPU) (pytorch#44364)
Summary: Related: pytorch#24507. Fixes pytorch#24666. This PR modernizes the CPU implementation of the vector `outer product` by migrating the existing TH implementation of `torch.addr` to `aten`; `torch.ger` relies on the `addr` functions to compute the outer product.
Pull Request resolved: pytorch#44364
Reviewed By: ezyang
Differential Revision: D23866733
Pulled By: mruberry
fbshipit-source-id: 5159ea22f0e3c991123fe7c19cc9beb6ad00301e
1 parent 99e0a87 commit 241afc9

16 files changed, +194 -672 lines changed
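
A quick sanity sketch of the identity the port relies on (not code from this commit): the deprecation warning added in LinearAlgebra.cpp below states that torch.addr can be written as alpha * torch.outer(vec1, vec2) + beta * input. A minimal standalone C++ check of that equivalence, assuming an ATen build that already contains this change (so that at::outer is available):

#include <ATen/ATen.h>
#include <iostream>

int main() {
  at::Tensor vec1  = at::randn({3});
  at::Tensor vec2  = at::randn({4});
  at::Tensor input = at::randn({3, 4});

  // addr(input, vec1, vec2, beta, alpha) should match beta * input + alpha * outer(vec1, vec2)
  at::Tensor via_addr  = at::addr(input, vec1, vec2, /*beta=*/0.5, /*alpha=*/2.0);
  at::Tensor via_outer = 0.5 * input + 2.0 * at::outer(vec1, vec2);

  std::cout << "match: " << at::allclose(via_addr, via_outer) << std::endl;  // expected: match: 1
  return 0;
}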
 

‎aten/src/ATen/LegacyTHFunctionsCPU.cpp

-255 lines; large diffs are not rendered by default.

‎aten/src/ATen/LegacyTHFunctionsCPU.h

-3 lines

@@ -39,9 +39,6 @@ Tensor & _th_renorm_(Tensor & self, Scalar p, int64_t dim, Scalar maxnorm);
 Tensor & _th_histc_out(Tensor & result, const Tensor & self, int64_t bins, Scalar min, Scalar max);
 Tensor _th_histc(const Tensor & self, int64_t bins, Scalar min, Scalar max);
 Tensor _th_trace(const Tensor & self);
-Tensor & _th_addr_out(Tensor & result, const Tensor & self, const Tensor & vec1, const Tensor & vec2, Scalar beta, Scalar alpha);
-Tensor _th_addr(const Tensor & self, const Tensor & vec1, const Tensor & vec2, Scalar beta, Scalar alpha);
-Tensor & _th_addr_(Tensor & self, const Tensor & vec1, const Tensor & vec2, Scalar beta, Scalar alpha);
 std::tuple<Tensor &,Tensor &> _th_gels_out(Tensor & res1, Tensor & res2, const Tensor & self, const Tensor & A);
 std::tuple<Tensor,Tensor> _th_gels(const Tensor & self, const Tensor & A);
 std::tuple<Tensor &,Tensor &> _th_eig_out(Tensor & res1, Tensor & res2, const Tensor & self, bool eigenvectors);

‎aten/src/ATen/cuda/CUDABlas.cpp

-40 lines

@@ -498,46 +498,6 @@ void gemv<at::BFloat16>(CUDABLAS_GEMV_ARGTYPES(at::BFloat16)) {
 }
 #endif
 
-namespace {
-template<typename scalar_t>
-cublasStatus_t cublasGer(const cublasHandle_t &handle, int64_t m, int64_t n, scalar_t *alpha, scalar_t *x, int64_t incx, scalar_t *y, int64_t incy, scalar_t *a, int64_t lda) {
-  TORCH_CHECK(false, "cublas ger is defined only for float and double");
-  return {};
-}
-template<>
-cublasStatus_t cublasGer<float>(const cublasHandle_t &handle, int64_t m, int64_t n, float *alpha, float *x, int64_t incx, float *y, int64_t incy, float *a, int64_t lda) {
-  return cublasSger(handle, m, n, alpha, x, incx, y, incy, a, lda);
-}
-template<>
-cublasStatus_t cublasGer<double>(const cublasHandle_t &handle, int64_t m, int64_t n, double *alpha, double *x, int64_t incx, double *y, int64_t incy, double *a, int64_t lda) {
-  return cublasDger(handle, m, n, alpha, x, incx, y, incy, a, lda);
-}
-} // anonymous namespace
-
-template<typename scalar_t>
-void ger(int64_t m, int64_t n, scalar_t alpha, scalar_t *x, int64_t incx, scalar_t *y, int64_t incy, scalar_t *a, int64_t lda)
-{
-  _cublasAdjustLdLevel2(m, n, &lda);
-  TORCH_CHECK((m <= INT_MAX) &&
-              (n <= INT_MAX) &&
-              (lda <= INT_MAX) &&
-              (incx <= INT_MAX) &&
-              (incy <= INT_MAX),
-              "cublasSger/cublasDger only supports m, n, lda, incx, incy with "
-              "the bound [val] <= %d", INT_MAX);
-  int i_m = (int)m;
-  int i_n = (int)n;
-  int i_lda = (int)lda;
-  int i_incx = (int)incx;
-  int i_incy = (int)incy;
-
-  cublasHandle_t handle = at::cuda::getCurrentCUDABlasHandle();
-  TORCH_CUDABLAS_CHECK(cublasGer<scalar_t>(
-    handle, i_m, i_n, &alpha, x, i_incx, y, i_incy, a, i_lda));
-}
-template void ger<float>(int64_t m, int64_t n, float alpha, float *x, int64_t incx, float *y, int64_t incy, float *a, int64_t lda);
-template void ger<double>(int64_t m, int64_t n, double alpha, double *x, int64_t incx, double *y, int64_t incy, double *a, int64_t lda);
-
 /* LEVEL 1 BLAS FUNCTIONS */
 
 template <>
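
For context on the bound checks in the wrapper removed above: cublasSger and cublasDger take 32-bit int sizes and strides, so the 64-bit values had to be range-checked before narrowing. A minimal illustration of that guard (hypothetical helper name, not code from this commit):

#include <climits>
#include <cstdint>
#include <stdexcept>
#include <string>

// Reject 64-bit dimensions/strides that do not fit the int parameters cuBLAS expects.
static int narrow_to_int(int64_t value, const char* name) {
  if (value > INT_MAX) {
    throw std::runtime_error(std::string(name) + " exceeds INT_MAX; cublasSger/cublasDger cannot be called");
  }
  return static_cast<int>(value);
}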

‎aten/src/ATen/native/LinearAlgebra.cpp

+37 -26 lines

@@ -143,50 +143,61 @@ static void check_1d(const Tensor& t, const char* arg, const char* fn) {
 }
 
 Tensor addr(const Tensor& self, const Tensor& vec1, const Tensor& vec2, Scalar beta, Scalar alpha) {
-  check_1d(vec1, "vec1", "addr");
-  check_1d(vec2, "vec2", "addr");
-  Tensor b_self;
-  std::tie(b_self) = expand_size(self, {vec1.size(0), vec2.size(0)}, "addr");
-  return at::_addr(b_self, vec1, vec2, beta, alpha);
+  TORCH_WARN(
+    "torch.addr is deprecated and may be removed in a future PyTorch release. "
+    "This function can be implemented using torch.outer as "
+    "alpha * torch.outer(vec1, vec2) + beta * input when beta is not zero, "
+    "alpha * torch.outer(vec1, vec2) when beta is zero.");
+
+  Tensor outer_result = at::outer(vec1, vec2) * alpha;
+  if (beta.to<double>() == 0.0) {
+    return outer_result;
+  }
+  return outer_result + (self * beta);
 }
 
 Tensor& addr_(Tensor& self, const Tensor& vec1, const Tensor& vec2, Scalar beta, Scalar alpha) {
-  check_1d(vec1, "vec1", "addr");
-  check_1d(vec2, "vec2", "addr");
-  return at::_addr_(self, vec1, vec2, beta, alpha);
+  return at::addr_out(self, self, vec1, vec2, beta, alpha);
 }
 
 Tensor& addr_out(Tensor &result, const Tensor& self, const Tensor& vec1, const Tensor& vec2, Scalar beta, Scalar alpha) {
-  check_1d(vec1, "vec1", "addr");
-  check_1d(vec2, "vec2", "addr");
-  Tensor b_self;
-  std::tie(b_self) = expand_size(self, {vec1.size(0), vec2.size(0)}, "addr_out");
-  return at::_addr_out(result, b_self, vec1, vec2, beta, alpha);
+  auto addr_result = at::addr(self, vec1, vec2, beta, alpha);
+  // Validates safe casting
+  const auto result_dtype = addr_result.scalar_type();
+  TORCH_CHECK(canCast(result_dtype, result.scalar_type()),
+              "result type ", result_dtype,
+              " can't be cast to the desired output type ", result.scalar_type());
+
+  at::native::resize_output(result, addr_result.sizes().vec());
+  result.copy_(addr_result);
+  return result;
 }
 
+// torch.ger, alias for torch.outer
 Tensor& ger_out(Tensor &result, const Tensor& self, const Tensor& vec2) {
-  check_1d(self, "self", "ger");
-  check_1d(vec2, "vec2", "ger");
-  if (result.dim() != 2 || result.size(0) != self.size(0) || result.size(1) != vec2.size(0)) {
-    result.resize_({ self.size(0), vec2.size(0) });
-  }
-  // resize_ does the "broadcasting", don't need to broadcast again.
-  return at::_addr_out(result, result, self, vec2, Scalar(0), Scalar(1));
+  TORCH_WARN("torch.ger is deprecated and will be removed in a future PyTorch release. "
+             "Use torch.outer instead.");
+  return at::outer_out(result, self, vec2);
 }
 
 Tensor ger(const Tensor& self, const Tensor& vec2) {
-  Tensor result = at::empty({0}, self.options());
-  at::ger_out(result, self, vec2);
-  return result;
+  return self.outer(vec2);
 }
 
-// torch.outer, alias for torch.ger
 Tensor& outer_out(Tensor &result, const Tensor& self, const Tensor& vec2) {
-  return at::ger_out(result, self, vec2);
+  check_1d(self, "self", "outer");
+  check_1d(vec2, "vec2", "outer");
+
+  // torch.outer is implemented as a composite op using reshape and mul
+  at::mul_out(result, self.reshape({self.size(0), 1}), vec2);
+  return result;
 }
 
 Tensor outer(const Tensor& self, const Tensor& vec2) {
-  return self.ger(vec2);
+  check_1d(self, "self", "outer");
+  check_1d(vec2, "vec2", "outer");
+
+  return self.reshape({self.size(0), 1}) * vec2;
 }
 
 static void addmm_impl_cpu_(
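
One behavioural point in the composite CPU addr above: when beta is zero the function returns alpha * outer(vec1, vec2) directly and never reads self, so values in self (including NaN or Inf) are not folded into the result, exactly as the deprecation warning describes. A small sketch of that case (an illustrative example, assuming the composite implementation shown above):

#include <ATen/ATen.h>
#include <iostream>
#include <limits>

int main() {
  at::Tensor vec1 = at::ones({2}, at::kDouble);
  at::Tensor vec2 = at::ones({3}, at::kDouble);
  // self is full of NaN, but with beta == 0 the composite never multiplies it in
  at::Tensor self = at::full({2, 3}, std::numeric_limits<double>::quiet_NaN(), at::kDouble);

  at::Tensor out = at::addr(self, vec1, vec2, /*beta=*/0, /*alpha=*/1);
  std::cout << out << std::endl;  // all ones; no NaN propagated from self
  return 0;
}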

‎aten/src/ATen/native/cuda/LinearAlgebra.cu

-114 lines

@@ -178,120 +178,6 @@ Tensor& addmm__cuda(Tensor& self, const Tensor& mat1, const Tensor& mat2,
   return self;
 }
 
-template<typename scalar_t>
-void addr_impl_ger_cuda(Tensor &out, const Tensor &self,
-                        const Tensor& vec1, const Tensor& vec2,
-                        scalar_t alpha, scalar_t beta) {
-  static_assert(std::is_same<scalar_t, float>::value ||
-                std::is_same<scalar_t, double>::value,
-                "addr_impl_ger_cuda: only float and double are supported");
-  if (&out != &self) {
-    at::native::resize_as_(out, self);
-    at::native::copy_(out, self);
-  }
-  if (beta == 0.0) {
-    at::native::zero_(out);
-  }
-  if (beta != 1.0) {
-    at::native::mul_(out, beta);
-  }
-  if (out.stride(0) == 1) {
-    at::cuda::blas::ger<scalar_t>(
-      vec1.size(0), vec2.size(0), alpha,
-      vec1.data_ptr<scalar_t>(), vec1.stride(0),
-      vec2.data_ptr<scalar_t>(), vec2.stride(0),
-      out.data_ptr<scalar_t>(), out.stride(1)
-    );
-  } else if (out.stride(1) == 1) {
-    at::cuda::blas::ger<scalar_t>(
-      vec2.size(0), vec1.size(0), alpha,
-      vec2.data_ptr<scalar_t>(), vec2.stride(0),
-      vec1.data_ptr<scalar_t>(), vec1.stride(0),
-      out.data_ptr<scalar_t>(), out.stride(0)
-    );
-  } else {
-    Tensor cr = out.clone();
-    at::cuda::blas::ger<scalar_t>(
-      vec2.size(0), vec1.size(0), alpha,
-      vec2.data_ptr<scalar_t>(), vec2.stride(0),
-      vec1.data_ptr<scalar_t>(), vec1.stride(0),
-      out.data_ptr<scalar_t>(), out.stride(0)
-    );
-    out.set_(cr);
-  }
-}
-
-template<typename scalar_t>
-void addr_impl_cuda(Tensor &out, const Tensor &self,
-                    const Tensor& vec1, const Tensor& vec2,
-                    scalar_t alpha, scalar_t beta) {
-  // currently no Hger/SgerEx in Cublas.
-  Tensor vec2T = vec2.reshape({1, vec2.size(0)});
-  Tensor vec1M = vec1.reshape({vec1.size(0), 1});
-  addmm_out_cuda(out, self, vec1M, vec2T, beta, alpha);
-}
-template<>
-void addr_impl_cuda<float>(Tensor &out, const Tensor &self,
-                           const Tensor& vec1, const Tensor& vec2,
-                           float alpha, float beta) {
-  addr_impl_ger_cuda<float>(out, self, vec1, vec2, alpha, beta);
-}
-template<>
-void addr_impl_cuda<double>(Tensor &out, const Tensor &self,
-                            const Tensor& vec1, const Tensor& vec2,
-                            double alpha, double beta) {
-  addr_impl_ger_cuda<double>(out, self, vec1, vec2, alpha, beta);
-}
-
-Tensor& addr_out_cuda(Tensor &out, const Tensor& self,
-                      const Tensor& vec1, const Tensor& vec2,
-                      Scalar beta, Scalar alpha) {
-  TORCH_CHECK(vec1.dim() == 1 && vec2.dim() == 1,
-              "vec1 and vec2 should be 1-dimensional vectors. Got dimensions ",
-              vec1.dim(), " and ", vec2.dim());
-
-  Tensor self_;
-  if (&out != &self) {
-    std::tie(self_) = expand_size(self, {vec1.size(0), vec2.size(0)}, "addr");
-  } else {
-    self_ = self;
-  }
-
-  TORCH_CHECK(out.device() == self_.device() &&
-              out.device() == vec1.device() &&
-              out.device() == vec2.device(),
-              "Expected all tensors to be on the same device. Found: ",
-              out.device(), ", ", self_.device(), ", ",
-              vec1.device(), " and ", vec2.device());
-  TORCH_CHECK(self_.dim() == 2,
-              "2D tensor expected, got ", self_.dim(), "D tensor for input");
-  TORCH_CHECK(self_.size(0) == vec1.size(0) && self_.size(1) == vec2.size(0),
-              "size mismatch",
-              ", input: ", self_.sizes(),
-              ", v1: ", vec1.sizes(),
-              ", v2: ", vec2.sizes());
-  AT_DISPATCH_FLOATING_TYPES_AND2(kBFloat16, kHalf, self_.scalar_type(), "addr_out_cuda", [&] {
-    addr_impl_cuda<scalar_t>(out, self_, vec1, vec2,
-                             alpha.to<scalar_t>(), beta.to<scalar_t>());
-  });
-  return out;
-}
-
-Tensor& addr__cuda(Tensor& self,
-                   const Tensor& vec1, const Tensor& vec2,
-                   Scalar beta, Scalar alpha) {
-  addr_out_cuda(self, self, vec1, vec2, beta, alpha);
-  return self;
-}
-
-Tensor addr_cuda(const Tensor& self,
-                 const Tensor& vec1, const Tensor& vec2,
-                 Scalar beta, Scalar alpha) {
-  Tensor out = at::empty({0}, self.options());
-  addr_out_cuda(out, self, vec1, vec2, beta, alpha);
-  return out;
-}
-
 Tensor& addbmm_out_cuda(Tensor& out, const Tensor& self,
                         const Tensor& batch1, const Tensor& batch2,
                         Scalar beta, Scalar alpha) {

‎aten/src/ATen/native/native_functions.yaml

-17 lines

@@ -6238,23 +6238,6 @@
   use_c10_dispatcher: full
   variants: method, function
 
-- func: _addr(Tensor self, Tensor vec1, Tensor vec2, *, Scalar beta=1, Scalar alpha=1) -> Tensor
-  use_c10_dispatcher: full
-  dispatch:
-    CPU: legacy::cpu::_th_addr
-    CUDA: addr_cuda
-
-- func: _addr_(Tensor(a!) self, Tensor vec1, Tensor vec2, *, Scalar beta=1, Scalar alpha=1) -> Tensor(a!)
-  use_c10_dispatcher: full
-  dispatch:
-    CPU: legacy::cpu::_th_addr_
-    CUDA: addr__cuda
-
-- func: _addr.out(Tensor self, Tensor vec1, Tensor vec2, *, Scalar beta=1, Scalar alpha=1, Tensor(a!) out) -> Tensor(a!)
-  dispatch:
-    CPU: legacy::cpu::_th_addr_out
-    CUDA: addr_out_cuda
-
 - func: _index_copy_(Tensor(a!) self, int dim, Tensor index, Tensor source) -> Tensor(a!)
   use_c10_dispatcher: full
   dispatch:

‎aten/src/TH/generic/THBlas.cpp

-49 lines

@@ -14,8 +14,6 @@ TH_EXTERNC void dcopy_(int *n, double *x, int *incx, double *y, int *incy);
 TH_EXTERNC void scopy_(int *n, float *x, int *incx, float *y, int *incy);
 TH_EXTERNC void daxpy_(int *n, double *a, double *x, int *incx, double *y, int *incy);
 TH_EXTERNC void saxpy_(int *n, float *a, float *x, int *incx, float *y, int *incy);
-TH_EXTERNC void dger_(int *m, int *n, double *alpha, double *x, int *incx, double *y, int *incy, double *a, int *lda);
-TH_EXTERNC void sger_(int *m, int *n, float *alpha, float *x, int *incx, float *y, int *incy, float *a, int *lda);
 
 void THBlas_(swap)(int64_t n, scalar_t *x, int64_t incx, scalar_t *y, int64_t incy)
 {
@@ -111,51 +109,4 @@ void THBlas_(axpy)(int64_t n, scalar_t a, scalar_t *x, int64_t incx, scalar_t *y
   }
 }
 
-void THBlas_(ger)(
-    int64_t m,
-    int64_t n,
-    scalar_t alpha,
-    scalar_t *x,
-    int64_t incx,
-    scalar_t *y,
-    int64_t incy,
-    scalar_t *a,
-    int64_t lda)
-{
-  if(n == 1)
-    lda = m;
-
-#if defined(USE_BLAS) && (defined(TH_REAL_IS_DOUBLE) || defined(TH_REAL_IS_FLOAT))
-  if( (m <= INT_MAX) && (n <= INT_MAX) && (lda <= INT_MAX) &&
-      (incx > 0) && (incx <= INT_MAX) &&
-      (incy > 0) && (incy <= INT_MAX) )
-  {
-    THArgCheck(lda >= THMax(1, m), 9,
-      "lda should be at least max(1, m=%d), but have %d", m, lda);
-    int i_m = (int)m;
-    int i_n = (int)n;
-    int i_lda = (int)lda;
-    int i_incx = (int)incx;
-    int i_incy = (int)incy;
-
-#if defined(TH_REAL_IS_DOUBLE)
-    dger_(&i_m, &i_n, &alpha, x, &i_incx, y, &i_incy, a, &i_lda);
-#else
-    sger_(&i_m, &i_n, &alpha, x, &i_incx, y, &i_incy, a, &i_lda);
-#endif
-    return;
-  }
-#endif
-  {
-    int64_t i, j;
-    for(j = 0; j < n; j++)
-    {
-      scalar_t *column_ = a+j*lda;
-      scalar_t z = alpha*y[j*incy];
-      for(i = 0; i < m; i++)
-        column_[i] += z*x[i*incx] ;
-    }
-  }
-}
-
 #endif
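
The THBlas_(ger) routine removed above is the classic column-major rank-1 update, a[i + j*lda] += alpha * x[i*incx] * y[j*incy], with a hand-written loop as the non-BLAS fallback. A self-contained sketch of the same update without the TH macros (illustrative only; ger_reference is not a name used in the codebase):

#include <cstdint>

// Column-major rank-1 update A += alpha * x * y^T, mirroring the removed fallback loop.
template <typename scalar_t>
void ger_reference(int64_t m, int64_t n, scalar_t alpha,
                   const scalar_t* x, int64_t incx,
                   const scalar_t* y, int64_t incy,
                   scalar_t* a, int64_t lda) {
  for (int64_t j = 0; j < n; j++) {
    scalar_t* column = a + j * lda;
    const scalar_t z = alpha * y[j * incy];
    for (int64_t i = 0; i < m; i++) {
      column[i] += z * x[i * incx];
    }
  }
}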

‎aten/src/TH/generic/THBlas.h

-3 lines

@@ -7,7 +7,4 @@ TH_API void THBlas_(swap)(int64_t n, scalar_t *x, int64_t incx, scalar_t *y, int
 TH_API void THBlas_(copy)(int64_t n, scalar_t *x, int64_t incx, scalar_t *y, int64_t incy);
 TH_API void THBlas_(axpy)(int64_t n, scalar_t a, scalar_t *x, int64_t incx, scalar_t *y, int64_t incy);
 
-/* Level 2 */
-TH_API void THBlas_(ger)(int64_t m, int64_t n, scalar_t alpha, scalar_t *x, int64_t incx, scalar_t *y, int64_t incy, scalar_t *a, int64_t lda);
-
 #endif
