Skip to content

Commit f8765cf

Browse files
authored
Merge pull request #2577 from spinkney/generalized_inverse_fix
check generalized inverse for full rank symmetric mat
2 parents 8920da4 + 3644737 commit f8765cf

File tree

4 files changed

+93
-11
lines changed

4 files changed

+93
-11
lines changed

stan/math/prim/fun/generalized_inverse.hpp

+13-6
Original file line numberDiff line numberDiff line change
@@ -31,15 +31,22 @@ template <typename EigMat, require_eigen_t<EigMat>* = nullptr,
3131
inline Eigen::Matrix<value_type_t<EigMat>, EigMat::ColsAtCompileTime,
3232
EigMat::RowsAtCompileTime>
3333
generalized_inverse(const EigMat& G) {
34-
if (G.size() == 0)
34+
const auto& G_ref = to_ref(G);
35+
if (G_ref.size() == 0)
3536
return {};
3637

37-
if (G.rows() == G.cols())
38-
return inverse(G);
39-
40-
const auto& G_ref = to_ref(G);
38+
if (G_ref.rows() == G_ref.cols()) {
39+
Eigen::CompleteOrthogonalDecomposition<
40+
Eigen::Matrix<value_type_t<EigMat>, EigMat::RowsAtCompileTime,
41+
EigMat::ColsAtCompileTime>>
42+
complete_ortho_decomp_G = G_ref.completeOrthogonalDecomposition();
43+
if (!(complete_ortho_decomp_G.rank() < G_ref.rows()))
44+
return inverse(G_ref);
45+
else
46+
return complete_ortho_decomp_G.pseudoInverse();
47+
}
4148

42-
if (G.rows() < G.cols()) {
49+
if (G_ref.rows() < G_ref.cols()) {
4350
return (G_ref * G_ref.transpose()).ldlt().solve(G_ref).transpose();
4451
} else {
4552
return (G_ref.transpose() * G_ref).ldlt().solve(G_ref.transpose());

stan/math/rev/fun/generalized_inverse.hpp

+15-4
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,7 @@
11
#ifndef STAN_MATH_REV_FUN_GENERALIZED_INVERSE_HPP
22
#define STAN_MATH_REV_FUN_GENERALIZED_INVERSE_HPP
33

4+
#include <stan/math/prim/fun/add_diag.hpp>
45
#include <stan/math/rev/core.hpp>
56
#include <stan/math/prim/err.hpp>
67
#include <stan/math/prim/fun/Eigen.hpp>
@@ -66,10 +67,20 @@ inline auto generalized_inverse(const VarMat& G) {
6667
if (G.size() == 0)
6768
return ret_type(G);
6869

69-
if (G.rows() == G.cols())
70-
return ret_type(inverse(G));
71-
72-
if (G.rows() < G.cols()) {
70+
if (G.rows() == G.cols()) {
71+
arena_t<VarMat> G_arena(G);
72+
Eigen::CompleteOrthogonalDecomposition<Eigen::MatrixXd>
73+
complete_ortho_decomp_G
74+
= G_arena.val().completeOrthogonalDecomposition();
75+
if (!(complete_ortho_decomp_G.rank() < G.rows())) {
76+
return ret_type(inverse(G_arena));
77+
} else {
78+
arena_t<ret_type> inv_G(complete_ortho_decomp_G.pseudoInverse());
79+
reverse_pass_callback(
80+
internal::generalized_inverse_lambda(G_arena, inv_G));
81+
return ret_type(inv_G);
82+
}
83+
} else if (G.rows() < G.cols()) {
7384
arena_t<VarMat> G_arena(G);
7485
arena_t<ret_type> inv_G((G_arena.val_op() * G_arena.val_op().transpose())
7586
.ldlt()

test/unit/math/mix/fun/generalized_inverse_test.cpp

+51-1
Original file line numberDiff line numberDiff line change
@@ -2,7 +2,7 @@
22
#include <vector>
33
#include <gtest/gtest.h>
44

5-
TEST(mathMixMatFun, ad_tests) {
5+
TEST(mathMixMatFun, gen_inv_test_1) {
66
using stan::test::expect_ad;
77
using stan::test::expect_ad_matvar;
88

@@ -11,25 +11,61 @@ TEST(mathMixMatFun, ad_tests) {
1111
Eigen::MatrixXd t(0, 0);
1212
expect_ad(f, t);
1313
expect_ad_matvar(f, t);
14+
}
15+
16+
TEST(mathMixMatFun, gen_inv_test_2) {
17+
using stan::test::expect_ad;
18+
using stan::test::expect_ad_matvar;
19+
20+
auto f = [](const auto& G) { return stan::math::generalized_inverse(G); };
1421

1522
Eigen::MatrixXd u(1, 1);
1623
u << 2;
1724
expect_ad(f, u);
1825
expect_ad_matvar(f, u);
26+
}
27+
28+
TEST(mathMixMatFun, gen_inv_test_3) {
29+
using stan::test::expect_ad;
30+
using stan::test::expect_ad_matvar;
31+
32+
auto f = [](const auto& G) { return stan::math::generalized_inverse(G); };
1933

2034
Eigen::MatrixXd v(2, 3);
2135
v << 1, 3, 5, 2, 4, 6;
2236
expect_ad(f, v);
2337
expect_ad_matvar(f, v);
38+
}
39+
40+
TEST(mathMixMatFun, gen_inv_test_4) {
41+
using stan::test::expect_ad;
42+
using stan::test::expect_ad_matvar;
2443

44+
auto f = [](const auto& G) { return stan::math::generalized_inverse(G); };
45+
46+
Eigen::MatrixXd v(2, 3);
2547
v << 1.9, 1.3, 2.5, 0.4, 1.7, 0.1;
2648
expect_ad(f, v);
2749
expect_ad_matvar(f, v);
50+
}
51+
52+
TEST(mathMixMatFun, gen_inv_test_5) {
53+
using stan::test::expect_ad;
54+
using stan::test::expect_ad_matvar;
55+
56+
auto f = [](const auto& G) { return stan::math::generalized_inverse(G); };
2857

2958
Eigen::MatrixXd s(2, 4);
3059
s << 3.4, 2, 5, 1.2, 2, 1, 3.2, 3.1;
3160
expect_ad(f, s);
3261
expect_ad_matvar(f, s);
62+
}
63+
64+
TEST(mathMixMatFun, gen_inv_test_6) {
65+
using stan::test::expect_ad;
66+
using stan::test::expect_ad_matvar;
67+
68+
auto f = [](const auto& G) { return stan::math::generalized_inverse(G); };
3369

3470
// issues around zero require looser tolerances for hessians
3571
stan::test::ad_tolerances tols;
@@ -40,10 +76,24 @@ TEST(mathMixMatFun, ad_tests) {
4076
w << 2, 3, 5, 7, 11, 13, 17, 19, 23, 25, 27, 29;
4177
expect_ad(tols, f, w);
4278
expect_ad_matvar(f, w);
79+
}
80+
81+
TEST(mathMixMatFun, gen_inv_test_8) {
82+
using stan::test::expect_ad;
83+
using stan::test::expect_ad_matvar;
84+
85+
auto f = [](const auto& G) { return stan::math::generalized_inverse(G); };
4386

4487
Eigen::MatrixXd z(2, 2);
4588
z << 1, 2, 5, std::numeric_limits<double>::quiet_NaN();
4689
EXPECT_NO_THROW(stan::math::generalized_inverse(z));
90+
}
91+
92+
TEST(mathMixMatFun, gen_inv_test_9) {
93+
using stan::test::expect_ad;
94+
using stan::test::expect_ad_matvar;
95+
96+
auto f = [](const auto& G) { return stan::math::generalized_inverse(G); };
4797

4898
// autodiff throws, so following fails (throw behavior must match to pass)
4999

test/unit/math/prim/fun/generalized_inverse_test.cpp

+14
Original file line numberDiff line numberDiff line change
@@ -45,3 +45,17 @@ TEST(MathMatrixPrim, Equal2) {
4545
stan::math::matrix_d m3 = m1 * generalized_inverse(m1);
4646
EXPECT_MATRIX_NEAR(m2, m3, 1e-9);
4747
}
48+
49+
TEST(MathMatrixPrim, Equal3) {
50+
using stan::math::generalized_inverse;
51+
52+
stan::math::matrix_d m1(3, 3);
53+
m1 << 4, 4, -2, 4, 4, -2, -2, -2, 10;
54+
55+
stan::math::matrix_d m2(3, 3);
56+
m2 << 0.069444444, 0.069444444, 0.0277777777, 0.069444444, 0.069444444,
57+
0.0277777777, 0.027777777, 0.027777777, 0.1111111111;
58+
59+
stan::math::matrix_d m3 = generalized_inverse(m1);
60+
EXPECT_MATRIX_NEAR(m2, m3, 1e-9);
61+
}

0 commit comments

Comments
 (0)