 
 __all__ = ['lobpcg']
 
+ def _symeig_backward_complete_eigenspace(D_grad, U_grad, A, D, U):
+     # compute F, such that F_ij = (d_j - d_i)^{-1} for i != j, F_ii = 0
+     F = D.unsqueeze(-2) - D.unsqueeze(-1)
+     F.diagonal(dim1=-2, dim2=-1).fill_(float('inf'))
+     F.pow_(-1)
+
+     # A.grad = U (D.grad + (U^T U.grad * F)) U^T
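+     # This is the adjoint of the first-order perturbation identities for a
+     # symmetric A = U diag(D) U^T with distinct eigenvalues (a brief sketch of
+     # the reasoning, not a computation step):
+     #   dD_i = u_i^T dA u_i
+     #   du_i = sum_{j != i} u_j (u_j^T dA u_i) / (d_i - d_j)
+     # Transposing these linear maps in dA gives the expression below, with F
+     # collecting the 1 / (d_j - d_i) factors.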
+     Ut = U.transpose(-1, -2).contiguous()
+     res = torch.matmul(
+         U,
+         torch.matmul(
+             torch.diag_embed(D_grad) + torch.matmul(Ut, U_grad) * F,
+             Ut
+         )
+     )
+
+     return res
+
+
+ def _polynomial_coefficients_given_roots(roots):
+     """
+     Given the `roots` of a polynomial, find the polynomial's coefficients.
+
+     If roots = (r_1, ..., r_n), then the method returns
+     coefficients (a_0, a_1, ..., a_n (== 1)) so that
+     p(x) = (x - r_1) * ... * (x - r_n)
+          = x^n + a_{n-1} * x^{n-1} + ... + a_1 * x + a_0
+
+     Note: for better performance this would require writing a low-level kernel.
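+
+     For example, for roots = (1, 2) the method returns the coefficients (2, -3, 1),
+     since p(x) = (x - 1) * (x - 2) = x^2 - 3 * x + 2.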
45
+ """
46
+ poly_order = roots .shape [- 1 ]
47
+ poly_coeffs_shape = list (roots .shape )
48
+ # we assume p(x) = x^n + a_{n-1} * x^{n-1} + ... + a_1 * x + a_0,
49
+ # so poly_coeffs = {a_0, ..., a_n, a_{n+1}(== 1)},
50
+ # but we insert one extra coefficient to enable better vectorization below
51
+ poly_coeffs_shape [- 1 ] += 2
52
+ poly_coeffs = roots .new_zeros (poly_coeffs_shape )
53
+ poly_coeffs [..., 0 ] = 1
54
+ poly_coeffs [..., - 1 ] = 1
55
+
56
+ # perform the Horner's rule
57
+ for i in range (1 , poly_order + 1 ):
58
+ # note that it is computationally hard to compute backward for this method,
59
+ # because then given the coefficients it would require finding the roots and/or
60
+ # calculating the sensitivity based on the Vieta's theorem.
61
+ # So the code below tries to circumvent the explicit root finding by series
62
+ # of operations on memory copies imitating the Horner's method.
63
+ # The memory copies are required to construct nodes in the computational graph
64
+ # by exploting the explicit (not in-place, separate node for each step)
65
+ # recursion of the Horner's method.
66
+ # Needs more memory, O(... * k^2), but with only O(... * k^2) complexity.
67
+ poly_coeffs_new = poly_coeffs .clone () if roots .requires_grad else poly_coeffs
68
+ out = poly_coeffs_new .narrow (- 1 , poly_order - i , i + 1 )
69
+ out -= roots .narrow (- 1 , i - 1 , 1 ) * poly_coeffs .narrow (- 1 , poly_order - i + 1 , i + 1 )
70
+ poly_coeffs = poly_coeffs_new
71
+
72
+ return poly_coeffs .narrow (- 1 , 1 , poly_order + 1 )
+
+
+ def _polynomial_value(poly, x, zero_power, transition):
+     """
+     A generic method for computing poly(x) using Horner's rule.
 
- def lobpcg(A,                   # type: Tensor
-            k=None,              # type: Optional[int]
-            B=None,              # type: Optional[Tensor]
-            X=None,              # type: Optional[Tensor]
-            n=None,              # type: Optional[int]
-            iK=None,             # type: Optional[Tensor]
-            niter=None,          # type: Optional[int]
-            tol=None,            # type: Optional[float]
-            largest=None,        # type: Optional[bool]
-            method=None,         # type: Optional[str]
-            tracker=None,        # type: Optional[None]
-            ortho_iparams=None,  # type: Optional[Dict[str, int]]
-            ortho_fparams=None,  # type: Optional[Dict[str, float]]
-            ortho_bparams=None,  # type: Optional[Dict[str, bool]]
-            ):
-     # type: (...) -> Tuple[Tensor, Tensor]
+     Arguments:
+       poly (Tensor): the (possibly batched) 1D Tensor representing
+                      polynomial coefficients such that
+                      poly[..., i] = (a_{i_0}, ..., a_{i_n} (== 1)), and
+                      poly(x) = poly[..., 0] * zero_power + ... + poly[..., n] * x^n
+
+       x (Tensor): the (possibly batched) value to evaluate the polynomial `poly` at.
+
+       zero_power (Tensor): the representation of `x^0`. It is application-specific.
+
+       transition (Callable): the function that accepts some intermediate result `int_val`,
+                              the `x` and a specific polynomial coefficient
+                              `poly[..., k]` for some iteration `k`.
+                              It basically performs one iteration of Horner's rule
+                              defined as `x * int_val + poly[..., k] * zero_power`.
+                              Note that `zero_power` is not a parameter,
+                              because the step `+ poly[..., k] * zero_power` depends on `x`,
+                              whether it is a vector, a matrix, or something else, so this
+                              functionality is delegated to the user.
+     """
+
+     res = zero_power.clone()
+     for k in range(poly.size(-1) - 2, -1, -1):
+         res = transition(res, x, poly[..., k])
+     return res
+
+ def _matrix_polynomial_value(poly, x, zero_power=None):
+     """
+     Evaluates `poly(x)` for the (batched) matrix input `x`.
+     Check out the `_polynomial_value` function for more details.
+     """
+
+     # matrix-aware Horner's rule iteration
+     def transition(curr_poly_val, x, poly_coeff):
+         res = x.matmul(curr_poly_val)
+         res.diagonal(dim1=-2, dim2=-1).add_(poly_coeff.unsqueeze(-1))
+         return res
+
+     if zero_power is None:
+         zero_power = torch.eye(x.size(-1), x.size(-1), dtype=x.dtype, device=x.device) \
+             .view(*([1] * len(list(x.shape[:-2]))), x.size(-1), x.size(-1))
+
+     return _polynomial_value(poly, x, zero_power, transition)
+
+ def _vector_polynomial_value(poly, x, zero_power=None):
+     """
+     Evaluates `poly(x)` for the (batched) vector input `x`.
+     Check out the `_polynomial_value` function for more details.
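+
+     For example, for poly = (2, -3, 1), that is p(x) = x^2 - 3 * x + 2,
+     and x = 3 the result is p(3) = 2.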
+     """
+
+     # vector-aware Horner's rule iteration
+     def transition(curr_poly_val, x, poly_coeff):
+         res = torch.addcmul(poly_coeff.unsqueeze(-1), x, curr_poly_val)
+         return res
+
+     if zero_power is None:
+         zero_power = x.new_ones(1).expand(x.shape)
+
+     return _polynomial_value(poly, x, zero_power, transition)
+
+ def _symeig_backward_partial_eigenspace(D_grad, U_grad, A, D, U, largest):
+     # compute the projection operator onto the subspace orthogonal to the
+     # columns of U, defined as (I - UU^T)
+     Ut = U.transpose(-2, -1).contiguous()
+     proj_U_ortho = -U.matmul(Ut)
+     proj_U_ortho.diagonal(dim1=-2, dim2=-1).add_(1)
+
+     # compute U_ortho, a basis for the orthogonal complement to the span(U),
+     # by projecting a random [..., m, m - k] matrix onto the subspace orthogonal
+     # to the columns of U.
+     #
+     # fix generator for determinism
+     gen = torch.Generator(A.device)
+
+     # orthogonal complement to the span(U)
+     U_ortho = proj_U_ortho.matmul(
+         torch.randn(
+             (*A.shape[:-1], A.size(-1) - D.size(-1)),
+             dtype=A.dtype,
+             device=A.device,
+             generator=gen
+         )
+     )
+     U_ortho_t = U_ortho.transpose(-2, -1).contiguous()
+
+     # compute the coefficients of the characteristic polynomial of the tensor D.
+     # Note that D is diagonal, so the diagonal elements are exactly the roots
+     # of the characteristic polynomial.
+     chr_poly_D = _polynomial_coefficients_given_roots(D)
+
+     # the code below finds the explicit solution to the Sylvester equation
+     #   U_ortho^T A U_ortho dX - dX D = -U_ortho^T A U
+     # and incorporates it into the whole gradient stored in the `res` variable.
+     #
+     # Equivalent to the following naive implementation:
+     #   res = A.new_zeros(A.shape)
+     #   p_res = A.new_zeros(*A.shape[:-1], D.size(-1))
+     #   for k in range(1, chr_poly_D.size(-1)):
+     #       p_res.zero_()
+     #       for i in range(0, k):
+     #           p_res += (A.matrix_power(k - 1 - i) @ U_grad) * D.pow(i).unsqueeze(-2)
+     #       res -= chr_poly_D[k] * (U_ortho @ poly_D_at_A.inverse() @ U_ortho_t @ p_res @ U.t())
+     #
+     # Note that dX is a differential, so the gradient contribution comes from the backward sensitivity
+     #   Tr(f(U_grad, D_grad, A, U, D)^T dX) = Tr(g(U_grad, A, U, D)^T dA) for some functions f and g,
+     # and we need to compute g(U_grad, A, U, D)
+     #
+     # The naive implementation is based on the paper
+     #   Hu, Qingxi, and Daizhan Cheng.
+     #   "The polynomial solution to the Sylvester matrix equation."
+     #   Applied Mathematics Letters 19.9 (2006): 859-864.
+     #
+     # We can modify the computation of `p_res` from above in a more efficient way:
+     #   p_res =   U_grad * (chr_poly_D[1] * D.pow(0) + ... + chr_poly_D[k] * D.pow(k - 1)).unsqueeze(-2)
+     #           + A U_grad * (chr_poly_D[2] * D.pow(0) + ... + chr_poly_D[k] * D.pow(k - 2)).unsqueeze(-2)
+     #           + ...
+     #           + A.matrix_power(k - 1) U_grad * chr_poly_D[k]
+     # Note that this saves us from redundant matrix products with A (elimination of matrix_power).
+     U_grad_projected = U_grad
+     series_acc = U_grad_projected.new_zeros(U_grad_projected.shape)
+     for k in range(1, chr_poly_D.size(-1)):
+         poly_D = _vector_polynomial_value(chr_poly_D[..., k:], D)
+         series_acc += U_grad_projected * poly_D.unsqueeze(-2)
+         U_grad_projected = A.matmul(U_grad_projected)
+
+     # compute chr_poly_D(A) which essentially is:
+     #
+     #   chr_poly_D_at_A = A.new_zeros(A.shape)
+     #   for k in range(chr_poly_D.size(-1)):
+     #       chr_poly_D_at_A += chr_poly_D[k] * A.matrix_power(k)
+     #
+     # Note, however, for better performance we use Horner's rule
+     chr_poly_D_at_A = _matrix_polynomial_value(chr_poly_D, A)
+
+     # compute the action of `chr_poly_D_at_A` restricted to U_ortho_t
+     chr_poly_D_at_A_to_U_ortho = torch.matmul(
+         U_ortho_t,
+         torch.matmul(
+             chr_poly_D_at_A,
+             U_ortho
+         )
+     )
+     # we need to invert `chr_poly_D_at_A_to_U_ortho`; for that we compute its
+     # Cholesky decomposition and then use `torch.cholesky_solve` for better stability.
+     # Cholesky decomposition requires the input to be positive-definite.
+     # Note that `chr_poly_D_at_A_to_U_ortho` is positive-definite if
+     # 1. `largest` == False, or
+     # 2. `largest` == True and `k` is even
+     # under the assumption that `A` has distinct eigenvalues.
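+     # (Indeed, restricted to span(U_ortho) the eigenvalues of chr_poly_D(A) are
+     # p(lambda_i) = prod_j (lambda_i - d_j) taken over the eigenvalues lambda_i of A
+     # that are not in D: with `largest` == False every factor is positive, while
+     # with `largest` == True every factor is negative, so the product has sign (-1)^k.)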
+     #
+     # check if `chr_poly_D_at_A_to_U_ortho` is positive-definite or negative-definite
+     chr_poly_D_at_A_to_U_ortho_sign = -1 if (largest and (k % 2 == 1)) else +1
+     chr_poly_D_at_A_to_U_ortho_L = torch.cholesky(
+         chr_poly_D_at_A_to_U_ortho_sign * chr_poly_D_at_A_to_U_ortho
+     )
+
+     # compute the gradient part in span(U)
+     res = _symeig_backward_complete_eigenspace(
+         D_grad, U_grad, A, D, U
+     )
+
+     # incorporate the Sylvester equation solution into the full gradient;
+     # it resides in span(U_ortho)
+     res -= U_ortho.matmul(
+         chr_poly_D_at_A_to_U_ortho_sign * torch.cholesky_solve(
+             U_ortho_t.matmul(series_acc),
+             chr_poly_D_at_A_to_U_ortho_L
+         )
+     ).matmul(Ut)
+
+     return res
+
+ def _symeig_backward(D_grad, U_grad, A, D, U, largest):
+     # if `U` is square, then the columns of `U` span a complete eigenspace
+     if U.size(-1) == U.size(-2):
+         return _symeig_backward_complete_eigenspace(
+             D_grad, U_grad, A, D, U
+         )
+     else:
+         return _symeig_backward_partial_eigenspace(
+             D_grad, U_grad, A, D, U, largest
+         )
+
+ class LOBPCGAutogradFunction(torch.autograd.Function):
+
+     @staticmethod
+     def forward(ctx,
+                 A: Tensor,
+                 k: Optional[int] = None,
+                 B: Optional[Tensor] = None,
+                 X: Optional[Tensor] = None,
+                 n: Optional[int] = None,
+                 iK: Optional[Tensor] = None,
+                 niter: Optional[int] = None,
+                 tol: Optional[float] = None,
+                 largest: Optional[bool] = None,
+                 method: Optional[str] = None,
+                 tracker: Optional[None] = None,
+                 ortho_iparams: Optional[Dict[str, int]] = None,
+                 ortho_fparams: Optional[Dict[str, float]] = None,
+                 ortho_bparams: Optional[Dict[str, bool]] = None
+                 ) -> Tuple[Tensor, Tensor]:
+
+         # makes sure that input is contiguous for efficiency.
+         # Note: autograd does not support dense gradients for sparse input yet.
+         A = A.contiguous() if (not A.is_sparse) else A
+         if B is not None:
+             B = B.contiguous() if (not B.is_sparse) else B
+
+         D, U = _lobpcg(
+             A, k, B, X,
+             n, iK, niter, tol, largest, method, tracker,
+             ortho_iparams, ortho_fparams, ortho_bparams
+         )
+
+         # `largest` is a bool, not a Tensor, so it is stashed on the context
+         # instead of going through save_for_backward
+         ctx.save_for_backward(A, B, D, U)
+         ctx.largest = largest
+
+         return D, U
+
+     @staticmethod
+     def backward(ctx, D_grad, U_grad):
+         A_grad = B_grad = None
+         grads = [None] * 14
+
+         A, B, D, U = ctx.saved_tensors
+         largest = ctx.largest
+
+         # lobpcg.backward has some limitations. Checks for unsupported input
+         if A.is_sparse or (B is not None and B.is_sparse and ctx.needs_input_grad[2]):
+             raise ValueError(
+                 'lobpcg.backward does not support sparse input yet. '
+                 'Note that lobpcg.forward does though.'
+             )
+         if A.dtype in (torch.complex64, torch.complex128) or \
+                 B is not None and B.dtype in (torch.complex64, torch.complex128):
+             raise ValueError(
+                 'lobpcg.backward does not support complex input yet. '
+                 'Note that lobpcg.forward does though.'
+             )
+         if B is not None:
+             raise ValueError(
+                 'lobpcg.backward does not support backward with B != I yet.'
+             )
+
+         if largest is None:
+             largest = True
+
+         # symeig backward
+         if B is None:
+             A_grad = _symeig_backward(
+                 D_grad, U_grad, A, D, U, largest
+             )
+
+         # A has index 0
+         grads[0] = A_grad
+         # B has index 2
+         grads[2] = B_grad
+         return tuple(grads)
+
+
+ def lobpcg(A: Tensor,
+            k: Optional[int] = None,
+            B: Optional[Tensor] = None,
+            X: Optional[Tensor] = None,
+            n: Optional[int] = None,
+            iK: Optional[Tensor] = None,
+            niter: Optional[int] = None,
+            tol: Optional[float] = None,
+            largest: Optional[bool] = None,
+            method: Optional[str] = None,
+            tracker: Optional[None] = None,
+            ortho_iparams: Optional[Dict[str, int]] = None,
+            ortho_fparams: Optional[Dict[str, float]] = None,
+            ortho_bparams: Optional[Dict[str, bool]] = None
+            ) -> Tuple[Tensor, Tensor]:
 
     """Find the k largest (or smallest) eigenvalues and the corresponding
     eigenvectors of a symmetric positive definite generalized
@@ -53,6 +373,17 @@ def lobpcg(A, # type: Tensor
       not recommended but there exist cases where the usage of the
       basic method may be preferred.
 
+     .. warning:: The backward method does not support sparse and complex inputs.
+                  It works only when `B` is not provided (i.e. `B == None`).
+                  We are actively working on extensions, and the details of
+                  the algorithms will be published promptly.
+
+     .. warning:: While it is assumed that `A` is symmetric, `A.grad` is not.
+                  To make sure that `A.grad` is symmetric, so that `A - t * A.grad` is symmetric
+                  in first-order optimization routines, prior to running `lobpcg`
+                  we do the following symmetrization map: `A -> (A + A.t()) / 2`.
+                  The map is performed only when `A` requires gradients.
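+
+     For example, a minimal sketch of backpropagating through :func:`torch.lobpcg`
+     (the sizes below are chosen so that `m` is comfortably larger than `k`)::
+
+         A = torch.rand(9, 9, dtype=torch.double)
+         A = A @ A.t() + torch.eye(9, dtype=torch.double)  # symmetric positive definite
+         A.requires_grad_(True)
+         D, U = torch.lobpcg(A, k=2, largest=True)
+         D.sum().backward()
+         # A.grad approximates U @ U.t(), the classical gradient of the sum
+         # of the two largest eigenvalues with respect to A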
+
 
     Arguments:
 
       A (Tensor): the input tensor of size :math:`(*, m, m)`
@@ -175,6 +506,51 @@ def lobpcg(A, # type: Tensor
             ortho_fparams=ortho_fparams,
             ortho_bparams=ortho_bparams)
 
+     if not torch._jit_internal.is_scripting():
+         if A.requires_grad or (B is not None and B.requires_grad):
+             # While it is expected that `A` is symmetric,
+             # the `A_grad` might not be. Therefore we perform the trick below,
+             # so that `A_grad` becomes symmetric.
+             # The symmetrization is important for first-order optimization methods,
+             # so that (A - alpha * A_grad) is still a symmetric matrix.
+             # The same holds for `B`.
+             A_sym = (A + A.transpose(-2, -1)) / 2
+             B_sym = (B + B.transpose(-2, -1)) / 2 if (B is not None) else None
+
+             return LOBPCGAutogradFunction.apply(
+                 A_sym, k, B_sym, X, n, iK, niter, tol, largest,
+                 method, tracker, ortho_iparams, ortho_fparams, ortho_bparams
+             )
+     else:
+         if A.requires_grad or (B is not None and B.requires_grad):
+             raise RuntimeError(
+                 'Script and require grads is not supported atm. '
+                 'If you just want to do the forward, use .detach() '
+                 'on A and B before calling into lobpcg'
+             )
+
+     return _lobpcg(
+         A, k, B, X,
+         n, iK, niter, tol, largest, method, tracker,
+         ortho_iparams, ortho_fparams, ortho_bparams
+     )
+
+ def _lobpcg(A: Tensor,
+             k: Optional[int] = None,
+             B: Optional[Tensor] = None,
+             X: Optional[Tensor] = None,
+             n: Optional[int] = None,
+             iK: Optional[Tensor] = None,
+             niter: Optional[int] = None,
+             tol: Optional[float] = None,
+             largest: Optional[bool] = None,
+             method: Optional[str] = None,
+             tracker: Optional[None] = None,
+             ortho_iparams: Optional[Dict[str, int]] = None,
+             ortho_fparams: Optional[Dict[str, float]] = None,
+             ortho_bparams: Optional[Dict[str, bool]] = None
+             ) -> Tuple[Tensor, Tensor]:
+
     # A must be square:
     assert A.shape[-2] == A.shape[-1], A.shape
     if B is not None: