
Commit e4950a0

nikitaved authored and facebook-github-bot committed on Sep 28, 2020
Backward support for generalized eigenvalue solver with LOBPCG in forward [only k-rank SYMEIG case] (pytorch#43002)
Summary: As per title. Fixes pytorch#38948, which contains some blueprints for the algorithm used in this PR.

Pull Request resolved: pytorch#43002

Reviewed By: zou3519

Differential Revision: D23931326

Pulled By: albanD

fbshipit-source-id: e6994af70d94145f974ef87aa5cea166d6deff1e
1 parent 6417a70 commit e4950a0

File tree: 2 files changed, +453 -16 lines
 

test/test_autograd.py  (+61)
@@ -2592,6 +2592,67 @@ def run_test(upper, dims):
         for upper, dims in product([True, False], [(3, 3), (5, 3, 3), (4, 3, 2, 2)]):
             run_test(upper, dims)
 
+    @slowTest
+    @skipIfNoLapack
+    def test_lobpcg(self):
+
+        def func(k, A, largest=True, B=None):
+            X_shape = list(A.shape)
+            X_shape[-1] = k
+            X = torch.eye(A.size(-2), k, dtype=A.dtype, device=A.device)
+            if A.dim() > 2:
+                X = X.expand(X_shape)
+
+            D, U = torch.lobpcg(A=A, k=k, B=B, X=X)
+
+            # LOBPCG uses a random initial eigenspace approximation
+            # if the parameter `X` is not provided.
+            # This may cause non-deterministic behavior
+            # when it comes to the sign of an eigenvector
+            # (note that if v is an eigenvector, so is -v),
+            # hence we eliminate this non-determinism
+            # by making sure that each column of U
+            # gets multiplied by the sign of its max (in absolute value) element.
+            # Also, gradcheck changes the content of the input by +/- eps (defaults to 1e-06)
+            # to compute the numerical gradient, which can also cause the signs to flip.
+            _, idx = U.abs().max(-2, keepdim=True)
+            sign = U.gather(-2, idx).sign()
+            U = U * sign
+            return D, U
+
+        def run_symeig_test(k, sizes, largest=True):
+            A = torch.rand(*sizes).double()
+            A = A.matmul(A.transpose(-1, -2)) / 10
+            A.requires_grad_(True)
+
+            gradcheck(lambda A: func(k, A, largest), A)
+
+            # Custom gradient vectors for better stability due to some
+            # non-determinism in lobpcg's forward.
+            # Note that this is not required if symeig is used in the forward instead (tested).
+            D_grad = torch.rand(*A.shape[:-2], k) / 100
+            U_grad = torch.rand(*A.shape[:-1], k) / 100
+            gradgradcheck(lambda A: func(k, A, largest), A, [D_grad, U_grad], atol=1e-4)
+
+            # check whether A.grad is symmetric
+            A = A.detach().requires_grad_(True)
+            D, U = func(k, A, largest)
+            (D.sum() + U.sum()).backward()
+            self.assertEqual(A.grad, A.grad.transpose(-1, -2))
+
+        # the tests below take about 1-2 minutes to finish,
+        # but we want to be extra sure that the backward is correct.
+        for largest in [True, False]:
+            run_symeig_test(1, (6, 6), largest=largest)
+            run_symeig_test(1, (2, 6, 6), largest=largest)
+            run_symeig_test(1, (2, 2, 6, 6), largest=largest)
+            run_symeig_test(2, (6, 6), largest=largest)
+            run_symeig_test(2, (2, 6, 6), largest=largest)
+            run_symeig_test(2, (2, 2, 6, 6), largest=largest)
+            run_symeig_test(3, (9, 9), largest=largest)
+            run_symeig_test(3, (2, 9, 9), largest=largest)
+            run_symeig_test(3, (2, 2, 9, 9), largest=largest)
+
     @skipIfNoLapack
     def test_cholesky_inverse(self):
         def _test_with_size(upper, dims):
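Outside the diff itself, here is a minimal, self-contained sketch of the behaviour this test exercises: differentiating through torch.lobpcg on a small symmetric positive-definite matrix and checking that the returned gradient is symmetric. Sizes and names are illustrative only and mirror run_symeig_test above.

    import torch

    # build a small symmetric positive-definite matrix, as in run_symeig_test above
    A = torch.rand(6, 6, dtype=torch.float64)
    A = A.matmul(A.t()) / 10
    A.requires_grad_(True)

    # two largest eigenpairs; backpropagating through them is what this commit enables
    D, U = torch.lobpcg(A, k=2, largest=True)

    # any scalar loss of (D, U) can now be backpropagated to A
    (D.sum() + U.sum()).backward()

    # per the docstring warning added in this PR, the returned gradient is symmetric
    assert torch.allclose(A.grad, A.grad.t())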

torch/_lobpcg.py  (+392 -16)
@@ -13,23 +13,343 @@
 
 __all__ = ['lobpcg']
 
+def _symeig_backward_complete_eigenspace(D_grad, U_grad, A, D, U):
+    # compute F, such that F_ij = (d_j - d_i)^{-1} for i != j, F_ii = 0
+    F = D.unsqueeze(-2) - D.unsqueeze(-1)
+    F.diagonal(dim1=-2, dim2=-1).fill_(float('inf'))
+    F.pow_(-1)
+
+    # A.grad = U (D.grad + (U^T U.grad * F)) U^T
+    Ut = U.transpose(-1, -2).contiguous()
+    res = torch.matmul(
+        U,
+        torch.matmul(
+            torch.diag_embed(D_grad) + torch.matmul(Ut, U_grad) * F,
+            Ut
+        )
+    )
+
+    return res
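As a reading aid (not part of the committed diff), the comment above corresponds to the standard backward rule for a complete symmetric eigendecomposition A = U D U^T; assuming distinct eigenvalues, in LaTeX:

    \bar{A} \;=\; U \Big( \operatorname{diag}(\bar{D}) + F \circ (U^\top \bar{U}) \Big) U^\top,
    \qquad
    F_{ij} \;=\; \begin{cases} (d_j - d_i)^{-1} & i \neq j \\ 0 & i = j \end{cases}

Here \bar{D} and \bar{U} are the incoming gradients D_grad and U_grad, and \circ is the elementwise product; zeroing the diagonal of F is exactly what the fill_(inf) followed by pow_(-1) achieves.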
+
+
+def _polynomial_coefficients_given_roots(roots):
+    """
+    Given the `roots` of a polynomial, find the polynomial's coefficients.
+
+    If roots = (r_1, ..., r_n), then the method returns
+    coefficients (a_0, a_1, ..., a_n (== 1)) so that
+    p(x) = (x - r_1) * ... * (x - r_n)
+         = x^n + a_{n-1} * x^{n-1} + ... + a_1 * x + a_0
+
+    Note: for better performance this requires writing a low-level kernel
+    """
+    poly_order = roots.shape[-1]
+    poly_coeffs_shape = list(roots.shape)
+    # we assume p(x) = x^n + a_{n-1} * x^{n-1} + ... + a_1 * x + a_0,
+    # so poly_coeffs = {a_0, ..., a_n, a_{n+1}(== 1)},
+    # but we insert one extra coefficient to enable better vectorization below
+    poly_coeffs_shape[-1] += 2
+    poly_coeffs = roots.new_zeros(poly_coeffs_shape)
+    poly_coeffs[..., 0] = 1
+    poly_coeffs[..., -1] = 1
+
+    # perform Horner's rule
+    for i in range(1, poly_order + 1):
+        # note that it is computationally hard to compute backward for this method,
+        # because given the coefficients it would require finding the roots and/or
+        # calculating the sensitivity based on Vieta's theorem.
+        # So the code below tries to circumvent explicit root finding with a series
+        # of operations on memory copies imitating Horner's method.
+        # The memory copies are required to construct nodes in the computational graph
+        # by exploiting the explicit (not in-place, separate node for each step)
+        # recursion of Horner's method.
+        # Needs more memory, O(... * k^2), but with only O(... * k^2) complexity.
+        poly_coeffs_new = poly_coeffs.clone() if roots.requires_grad else poly_coeffs
+        out = poly_coeffs_new.narrow(-1, poly_order - i, i + 1)
+        out -= roots.narrow(-1, i - 1, 1) * poly_coeffs.narrow(-1, poly_order - i + 1, i + 1)
+        poly_coeffs = poly_coeffs_new
+
+    return poly_coeffs.narrow(-1, 1, poly_order + 1)
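As a quick sanity check of the coefficient convention used here (coefficients returned from a_0 up to a_n == 1), the following standalone plain-Python sketch is illustrative only and is not part of the patch; the helper name is made up:

    def poly_coeffs_from_roots(roots):
        # coefficients in increasing powers of x; start from the constant polynomial 1
        coeffs = [1.0]
        for r in roots:
            by_x = [0.0] + coeffs                          # current polynomial multiplied by x
            by_minus_r = [-r * c for c in coeffs] + [0.0]  # current polynomial multiplied by -r
            coeffs = [a + b for a, b in zip(by_x, by_minus_r)]
        return coeffs

    # (x - 1)(x - 2)(x - 3) = x^3 - 6x^2 + 11x - 6
    print(poly_coeffs_from_roots([1.0, 2.0, 3.0]))  # [-6.0, 11.0, -6.0, 1.0]

Per the docstring above, _polynomial_coefficients_given_roots(torch.tensor([1., 2., 3.])) is expected to return the same values in the same order.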
+
+
+def _polynomial_value(poly, x, zero_power, transition):
+    """
+    A generic method for computing poly(x) using Horner's rule.
 
-def lobpcg(A,                   # type: Tensor
-           k=None,              # type: Optional[int]
-           B=None,              # type: Optional[Tensor]
-           X=None,              # type: Optional[Tensor]
-           n=None,              # type: Optional[int]
-           iK=None,             # type: Optional[Tensor]
-           niter=None,          # type: Optional[int]
-           tol=None,            # type: Optional[float]
-           largest=None,        # type: Optional[bool]
-           method=None,         # type: Optional[str]
-           tracker=None,        # type: Optional[None]
-           ortho_iparams=None,  # type: Optional[Dict[str, int]]
-           ortho_fparams=None,  # type: Optional[Dict[str, float]]
-           ortho_bparams=None,  # type: Optional[Dict[str, bool]]
-           ):
-    # type: (...) -> Tuple[Tensor, Tensor]
+    Arguments:
+      poly (Tensor): the (possibly batched) 1D Tensor representing
+                     polynomial coefficients such that
+                     poly[..., i] = (a_{i_0}, ..., a_{i_n} (== 1)), and
+                     poly(x) = poly[..., 0] * zero_power + ... + poly[..., n] * x^n
+
+      x (Tensor): the (possibly batched) value to evaluate the polynomial `poly` at.
+
+      zero_power (Tensor): the representation of `x^0`. It is application-specific.
+
+      transition (Callable): the function that accepts some intermediate result `int_val`,
+                             the `x` and a specific polynomial coefficient
+                             `poly[..., k]` for some iteration `k`.
+                             It basically performs one iteration of Horner's rule
+                             defined as `x * int_val + poly[..., k] * zero_power`.
+                             Note that `zero_power` is not a parameter,
+                             because the step `+ poly[..., k] * zero_power` depends on `x`,
+                             whether it is a vector, a matrix, or something else, so this
+                             functionality is delegated to the user.
+    """
+
+    res = zero_power.clone()
+    for k in range(poly.size(-1) - 2, -1, -1):
+        res = transition(res, x, poly[..., k])
+    return res
+
+def _matrix_polynomial_value(poly, x, zero_power=None):
+    """
+    Evaluates `poly(x)` for the (batched) matrix input `x`.
+    Check out the `_polynomial_value` function for more details.
+    """
+
+    # matrix-aware Horner's rule iteration
+    def transition(curr_poly_val, x, poly_coeff):
+        res = x.matmul(curr_poly_val)
+        res.diagonal(dim1=-2, dim2=-1).add_(poly_coeff.unsqueeze(-1))
+        return res
+
+    if zero_power is None:
+        zero_power = torch.eye(x.size(-1), x.size(-1), dtype=x.dtype, device=x.device) \
+            .view(*([1] * len(list(x.shape[:-2]))), x.size(-1), x.size(-1))
+
+    return _polynomial_value(poly, x, zero_power, transition)
+
+def _vector_polynomial_value(poly, x, zero_power=None):
+    """
+    Evaluates `poly(x)` for the (batched) vector input `x`.
+    Check out the `_polynomial_value` function for more details.
+    """
+
+    # vector-aware Horner's rule iteration
+    def transition(curr_poly_val, x, poly_coeff):
+        res = torch.addcmul(poly_coeff.unsqueeze(-1), x, curr_poly_val)
+        return res
+
+    if zero_power is None:
+        zero_power = x.new_ones(1).expand(x.shape)
+
+    return _polynomial_value(poly, x, zero_power, transition)
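To make the `transition` contract above concrete, here is a small self-contained sketch (not part of the patch) that mirrors what `_vector_polynomial_value` computes for p(x) = x^2 - 3x + 2, i.e. poly = (a_0, a_1, a_2) = (2, -3, 1):

    import torch

    poly = torch.tensor([2.0, -3.0, 1.0])  # a_0 + a_1 * x + a_2 * x^2, leading coefficient == 1
    x = torch.tensor([5.0])

    # Horner's rule: start from the leading coefficient (encoded by `zero_power`, here x^0 == 1)
    # and repeatedly apply res = x * res + a_k, from k = n - 1 down to 0.
    res = torch.ones_like(x)
    for k in range(poly.size(-1) - 2, -1, -1):
        res = x * res + poly[..., k]  # same step as torch.addcmul(poly[..., k], x, res)

    print(res)  # tensor([12.]) since 5**2 - 3*5 + 2 == 12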
+
+def _symeig_backward_partial_eigenspace(D_grad, U_grad, A, D, U, largest):
+    # compute a projection operator onto the orthogonal complement of the
+    # subspace spanned by the columns of U, defined as (I - UU^T)
+    Ut = U.transpose(-2, -1).contiguous()
+    proj_U_ortho = -U.matmul(Ut)
+    proj_U_ortho.diagonal(dim1=-2, dim2=-1).add_(1)
+
+    # compute U_ortho, a basis for the orthogonal complement to span(U),
+    # by projecting a random [..., m, m - k] matrix onto the subspace orthogonal
+    # to the columns of U.
+    #
+    # fix the generator for determinism
+    gen = torch.Generator(A.device)
+
+    # orthogonal complement to span(U)
+    U_ortho = proj_U_ortho.matmul(
+        torch.randn(
+            (*A.shape[:-1], A.size(-1) - D.size(-1)),
+            dtype=A.dtype,
+            device=A.device,
+            generator=gen
+        )
+    )
+    U_ortho_t = U_ortho.transpose(-2, -1).contiguous()
+
+    # compute the coefficients of the characteristic polynomial of the tensor D.
+    # Note that D is diagonal, so its diagonal elements are exactly the roots
+    # of the characteristic polynomial.
+    chr_poly_D = _polynomial_coefficients_given_roots(D)
+
+    # the code below finds the explicit solution to the Sylvester equation
+    # U_ortho^T A U_ortho dX - dX D = -U_ortho^T A U
+    # and incorporates it into the whole gradient stored in the `res` variable.
+    #
+    # Equivalent to the following naive implementation:
+    # res = A.new_zeros(A.shape)
+    # p_res = A.new_zeros(*A.shape[:-1], D.size(-1))
+    # for k in range(1, chr_poly_D.size(-1)):
+    #     p_res.zero_()
+    #     for i in range(0, k):
+    #         p_res += (A.matrix_power(k - 1 - i) @ U_grad) * D.pow(i).unsqueeze(-2)
+    #     res -= chr_poly_D[k] * (U_ortho @ poly_D_at_A.inverse() @ U_ortho_t @ p_res @ U.t())
+    #
+    # Note that dX is a differential, so the gradient contribution comes from the backward sensitivity
+    # Tr(f(U_grad, D_grad, A, U, D)^T dX) = Tr(g(U_grad, A, U, D)^T dA) for some functions f and g,
+    # and we need to compute g(U_grad, A, U, D)
+    #
+    # The naive implementation is based on the paper
+    # Hu, Qingxi, and Daizhan Cheng.
+    # "The polynomial solution to the Sylvester matrix equation."
+    # Applied Mathematics Letters 19.9 (2006): 859-864.
+    #
+    # We can compute `p_res` from above in a more efficient way:
+    # p_res =   U_grad * (chr_poly_D[1] * D.pow(0) + ... + chr_poly_D[k] * D.pow(k)).unsqueeze(-2)
+    #         + A U_grad * (chr_poly_D[2] * D.pow(0) + ... + chr_poly_D[k] * D.pow(k - 1)).unsqueeze(-2)
+    #         + ...
+    #         + A.matrix_power(k - 1) U_grad * chr_poly_D[k]
+    # Note that this saves us from redundant matrix products with A (elimination of matrix_power)
+    U_grad_projected = U_grad
+    series_acc = U_grad_projected.new_zeros(U_grad_projected.shape)
+    for k in range(1, chr_poly_D.size(-1)):
+        poly_D = _vector_polynomial_value(chr_poly_D[..., k:], D)
+        series_acc += U_grad_projected * poly_D.unsqueeze(-2)
+        U_grad_projected = A.matmul(U_grad_projected)
+
+    # compute chr_poly_D(A) which essentially is:
+    #
+    #   chr_poly_D_at_A = A.new_zeros(A.shape)
+    #   for k in range(chr_poly_D.size(-1)):
+    #       chr_poly_D_at_A += chr_poly_D[k] * A.matrix_power(k)
+    #
+    # Note, however, that for better performance we use Horner's rule
+    chr_poly_D_at_A = _matrix_polynomial_value(chr_poly_D, A)
+
+    # compute the action of `chr_poly_D_at_A` restricted to U_ortho_t
+    chr_poly_D_at_A_to_U_ortho = torch.matmul(
+        U_ortho_t,
+        torch.matmul(
+            chr_poly_D_at_A,
+            U_ortho
+        )
+    )
+    # we need to invert `chr_poly_D_at_A_to_U_ortho`; for that we compute its
+    # Cholesky decomposition and then use `torch.cholesky_solve` for better stability.
+    # The Cholesky decomposition requires the input to be positive-definite.
+    # Note that `chr_poly_D_at_A_to_U_ortho` is positive-definite if
+    # 1. `largest` == False, or
+    # 2. `largest` == True and `k` is even
+    # under the assumption that `A` has distinct eigenvalues.
+    #
+    # check if `chr_poly_D_at_A_to_U_ortho` is positive-definite or negative-definite
+    chr_poly_D_at_A_to_U_ortho_sign = -1 if (largest and (k % 2 == 1)) else +1
+    chr_poly_D_at_A_to_U_ortho_L = torch.cholesky(
+        chr_poly_D_at_A_to_U_ortho_sign * chr_poly_D_at_A_to_U_ortho
+    )
+
+    # compute the gradient part in span(U)
+    res = _symeig_backward_complete_eigenspace(
+        D_grad, U_grad, A, D, U
+    )
+
+    # incorporate the Sylvester equation solution into the full gradient;
+    # it resides in span(U_ortho)
+    res -= U_ortho.matmul(
+        chr_poly_D_at_A_to_U_ortho_sign * torch.cholesky_solve(
+            U_ortho_t.matmul(series_acc),
+            chr_poly_D_at_A_to_U_ortho_L
+        )
+    ).matmul(Ut)
+
+    return res
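A compact restatement of the comments above, added as a reading aid rather than as part of the diff. Writing \tilde{A} = U_\perp^\top A U_\perp for A restricted to the orthogonal complement of span(U), the off-eigenspace block of the gradient solves a Sylvester equation of the form

    \tilde{A}\,X - X D = C,

where C corresponds to the projected U_grad contribution described in the sensitivity comment. The result of Hu and Cheng (2006) cited above expresses the solution through the characteristic polynomial p_D(\lambda) = \lambda^k + a_{k-1}\lambda^{k-1} + \dots + a_0 of D:

    p_D(\tilde{A})\,X \;=\; \sum_{j=1}^{k} a_j \sum_{i=0}^{j-1} \tilde{A}^{\,j-1-i}\, C\, D^{\,i}, \qquad a_k = 1.

The loop over `chr_poly_D` regroups this double sum so that each additional power of A costs a single extra matmul applied to U_grad, and `torch.cholesky_solve` then plays the role of applying p_D(\tilde{A})^{-1}, up to the sign handled by `chr_poly_D_at_A_to_U_ortho_sign`.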
+
+def _symeig_backward(D_grad, U_grad, A, D, U, largest):
+    # if `U` is square, then its columns form a complete eigenspace
+    if U.size(-1) == U.size(-2):
+        return _symeig_backward_complete_eigenspace(
+            D_grad, U_grad, A, D, U
+        )
+    else:
+        return _symeig_backward_partial_eigenspace(
+            D_grad, U_grad, A, D, U, largest
+        )
+
+class LOBPCGAutogradFunction(torch.autograd.Function):
+
+    @staticmethod
+    def forward(ctx,
+                A: Tensor,
+                k: Optional[int] = None,
+                B: Optional[Tensor] = None,
+                X: Optional[Tensor] = None,
+                n: Optional[int] = None,
+                iK: Optional[Tensor] = None,
+                niter: Optional[int] = None,
+                tol: Optional[float] = None,
+                largest: Optional[bool] = None,
+                method: Optional[str] = None,
+                tracker: Optional[None] = None,
+                ortho_iparams: Optional[Dict[str, int]] = None,
+                ortho_fparams: Optional[Dict[str, float]] = None,
+                ortho_bparams: Optional[Dict[str, bool]] = None
+                ) -> Tuple[Tensor, Tensor]:
+
+        # make sure that the input is contiguous for efficiency.
+        # Note: autograd does not support dense gradients for sparse input yet.
+        A = A.contiguous() if (not A.is_sparse) else A
+        if B is not None:
+            B = B.contiguous() if (not B.is_sparse) else B
+
+        D, U = _lobpcg(
+            A, k, B, X,
+            n, iK, niter, tol, largest, method, tracker,
+            ortho_iparams, ortho_fparams, ortho_bparams
+        )
+
+        ctx.save_for_backward(A, B, D, U, largest)
+
+        return D, U
+
+    @staticmethod
+    def backward(ctx, D_grad, U_grad):
+        A_grad = B_grad = None
+        grads = [None] * 14
+
+        A, B, D, U, largest = ctx.saved_tensors
+
+        # lobpcg.backward has some limitations. Check for unsupported input
+        if A.is_sparse or (B is not None and B.is_sparse and ctx.needs_input_grad[2]):
+            raise ValueError(
+                'lobpcg.backward does not support sparse input yet. '
+                'Note that lobpcg.forward does though.'
+            )
+        if A.dtype in (torch.complex64, torch.complex128) or \
+                B is not None and B.dtype in (torch.complex64, torch.complex128):
+            raise ValueError(
+                'lobpcg.backward does not support complex input yet. '
+                'Note that lobpcg.forward does though.'
+            )
+        if B is not None:
+            raise ValueError(
+                'lobpcg.backward does not support backward with B != I yet.'
+            )
+
+        if largest is None:
+            largest = True
+
+        # symeig backward
+        if B is None:
+            A_grad = _symeig_backward(
+                D_grad, U_grad, A, D, U, largest
+            )
+
+        # A has index 0
+        grads[0] = A_grad
+        # B has index 2
+        grads[2] = B_grad
+        return tuple(grads)
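As a generic reading aid (not part of the patch), the backward of a torch.autograd.Function must return one entry per forward argument, which is why `grads` has length 14 here and only the slot of A (index 0) is populated while B's slot (index 2) stays None. A toy Function with the same contract:

    import torch

    class Scale(torch.autograd.Function):
        # toy Function illustrating the "one gradient slot per forward argument" contract

        @staticmethod
        def forward(ctx, x, factor):
            ctx.factor = factor
            return x * factor

        @staticmethod
        def backward(ctx, grad_out):
            # two forward arguments -> two returned gradients; `factor` is non-differentiable
            return grad_out * ctx.factor, None

    x = torch.ones(3, requires_grad=True)
    Scale.apply(x, 2.0).sum().backward()
    print(x.grad)  # tensor([2., 2., 2.])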
+
+
+def lobpcg(A: Tensor,
+           k: Optional[int] = None,
+           B: Optional[Tensor] = None,
+           X: Optional[Tensor] = None,
+           n: Optional[int] = None,
+           iK: Optional[Tensor] = None,
+           niter: Optional[int] = None,
+           tol: Optional[float] = None,
+           largest: Optional[bool] = None,
+           method: Optional[str] = None,
+           tracker: Optional[None] = None,
+           ortho_iparams: Optional[Dict[str, int]] = None,
+           ortho_fparams: Optional[Dict[str, float]] = None,
+           ortho_bparams: Optional[Dict[str, bool]] = None
+           ) -> Tuple[Tensor, Tensor]:
 
     """Find the k largest (or smallest) eigenvalues and the corresponding
     eigenvectors of a symmetric positive definite generalized
@@ -53,6 +373,17 @@ def lobpcg(A, # type: Tensor
                   not recommended but there exist cases where the usage of the
                   basic method may be preferred.
 
+    .. warning:: The backward method does not support sparse and complex inputs.
+                 It works only when `B` is not provided (i.e. `B == None`).
+                 We are actively working on extensions, and the details of
+                 the algorithms are going to be published promptly.
+
+    .. warning:: While it is assumed that `A` is symmetric, `A.grad` is not.
+                 To make sure that `A.grad` is symmetric, so that `A - t * A.grad` is symmetric
+                 in first-order optimization routines, prior to running `lobpcg`
+                 we do the following symmetrization map: `A -> (A + A.t()) / 2`.
+                 The map is performed only when `A` requires gradients.
+
     Arguments:
 
       A (Tensor): the input tensor of size :math:`(*, m, m)`
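A one-line justification of the symmetrization warning, added as a reading aid (not part of the diff): if the loss only ever sees the symmetrized matrix S = (A + A^T) / 2, then by the chain rule

    \frac{\partial L}{\partial A}
      \;=\; \tfrac{1}{2}\left( \frac{\partial L}{\partial S} + \left(\frac{\partial L}{\partial S}\right)^{\!\top} \right),

which is symmetric by construction, so a first-order update A - t * A.grad keeps A symmetric.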
@@ -175,6 +506,51 @@ def lobpcg(A, # type: Tensor
                 ortho_fparams=ortho_fparams,
                 ortho_bparams=ortho_bparams)
 
+    if not torch._jit_internal.is_scripting():
+        if A.requires_grad or (B is not None and B.requires_grad):
+            # While it is expected that `A` is symmetric,
+            # `A_grad` might not be. Therefore we perform the trick below,
+            # so that `A_grad` becomes symmetric.
+            # The symmetrization is important for first-order optimization methods,
+            # so that (A - alpha * A_grad) is still a symmetric matrix.
+            # The same holds for `B`.
+            A_sym = (A + A.transpose(-2, -1)) / 2
+            B_sym = (B + B.transpose(-2, -1)) / 2 if (B is not None) else None
+
+            return LOBPCGAutogradFunction.apply(
+                A_sym, k, B_sym, X, n, iK, niter, tol, largest,
+                method, tracker, ortho_iparams, ortho_fparams, ortho_bparams
+            )
+    else:
+        if A.requires_grad or (B is not None and B.requires_grad):
+            raise RuntimeError(
+                'Script and require grads is not supported atm. '
+                'If you just want to do the forward, use .detach() '
+                'on A and B before calling into lobpcg'
+            )
+
+    return _lobpcg(
+        A, k, B, X,
+        n, iK, niter, tol, largest, method, tracker,
+        ortho_iparams, ortho_fparams, ortho_bparams
+    )
+
+def _lobpcg(A: Tensor,
+            k: Optional[int] = None,
+            B: Optional[Tensor] = None,
+            X: Optional[Tensor] = None,
+            n: Optional[int] = None,
+            iK: Optional[Tensor] = None,
+            niter: Optional[int] = None,
+            tol: Optional[float] = None,
+            largest: Optional[bool] = None,
+            method: Optional[str] = None,
+            tracker: Optional[None] = None,
+            ortho_iparams: Optional[Dict[str, int]] = None,
+            ortho_fparams: Optional[Dict[str, float]] = None,
+            ortho_bparams: Optional[Dict[str, bool]] = None
+            ) -> Tuple[Tensor, Tensor]:
+
 
     # A must be square:
     assert A.shape[-2] == A.shape[-1], A.shape
