
check generalized inverse for full rank symmetric mat #2577

Merged · 27 commits · Sep 18, 2021

Changes from 22 commits
e1a375b
check symmetric mat for full rank
spinkney Sep 8, 2021
88e3ae2
[Jenkins] auto-formatting by clang-format version 6.0.0-1ubuntu2~16.0…
stan-buildbot Sep 8, 2021
294fbb4
remove fullpivLU
spinkney Sep 8, 2021
9ef5618
Merge branch 'generalized_inverse_fix' of https://github.com/spinkney…
spinkney Sep 8, 2021
d459414
[Jenkins] auto-formatting by clang-format version 6.0.0-1ubuntu2~16.0…
stan-buildbot Sep 8, 2021
f176c35
Update generalized_inverse.hpp
spinkney Sep 8, 2021
cf36885
fix for expression tests
spinkney Sep 9, 2021
8a1786d
Revert "fix for expression tests"
spinkney Sep 9, 2021
cadfa12
[Jenkins] auto-formatting by clang-format version 6.0.0-1ubuntu2~16.0…
stan-buildbot Sep 9, 2021
e20d39e
fix expression tests
spinkney Sep 9, 2021
3df6a8d
Merge branch 'generalized_inverse_fix' of https://github.com/spinkney…
spinkney Sep 9, 2021
0ebe908
[Jenkins] auto-formatting by clang-format version 6.0.0-1ubuntu2~16.0…
stan-buildbot Sep 9, 2021
456bd74
rank revealing decomp for symmetric matrix
spinkney Sep 9, 2021
c78fc21
[Jenkins] auto-formatting by clang-format version 6.0.0-1ubuntu2~16.0…
stan-buildbot Sep 9, 2021
27c2647
small optimization in derivative
spinkney Sep 9, 2021
9bae32e
Revert "small optimization in derivative"
spinkney Sep 9, 2021
52c1be3
[Jenkins] auto-formatting by clang-format version 6.0.0-1ubuntu2~16.0…
stan-buildbot Sep 9, 2021
9c357d7
Merge branch 'generalized_inverse_fix' of https://github.com/spinkney…
spinkney Sep 9, 2021
3f1ea94
[Jenkins] auto-formatting by clang-format version 6.0.0-1ubuntu2~16.0…
stan-buildbot Sep 9, 2021
8d60e00
Update generalized_inverse.hpp
spinkney Sep 9, 2021
1670e27
add braces for if
spinkney Sep 9, 2021
70a21d0
[Jenkins] auto-formatting by clang-format version 6.0.0-1ubuntu2~16.0…
stan-buildbot Sep 9, 2021
9a3e82a
update for review comments
spinkney Sep 9, 2021
2c686b4
update test names
spinkney Sep 9, 2021
1de51d0
revert add_diag
spinkney Sep 17, 2021
7722ddf
Merge commit '8920da4297b72d50f3fa456ed8f2cd2bd714265c' into HEAD
yashikno Sep 17, 2021
3644737
[Jenkins] auto-formatting by clang-format version 6.0.0-1ubuntu2~16.0…
stan-buildbot Sep 17, 2021
19 changes: 13 additions & 6 deletions stan/math/prim/fun/generalized_inverse.hpp
@@ -31,15 +31,22 @@ template <typename EigMat, require_eigen_t<EigMat>* = nullptr,
inline Eigen::Matrix<value_type_t<EigMat>, EigMat::ColsAtCompileTime,
EigMat::RowsAtCompileTime>
generalized_inverse(const EigMat& G) {
if (G.size() == 0)
const auto& G_ref = to_ref(G);
if (G_ref.size() == 0)
return {};

if (G.rows() == G.cols())
return inverse(G);

const auto& G_ref = to_ref(G);
if (G_ref.rows() == G_ref.cols()) {
Eigen::CompleteOrthogonalDecomposition<
Eigen::Matrix<value_type_t<EigMat>, EigMat::RowsAtCompileTime,
EigMat::ColsAtCompileTime>>
complete_ortho_decomp_G = G_ref.completeOrthogonalDecomposition();
if (!(complete_ortho_decomp_G.rank() < G_ref.rows()))
return inverse(G_ref);
Comment on lines +38 to +44
Collaborator:

So my only question is the cost and accuracy of computing the completeOrthogonalDecomposition(), using it for the rank check, and then sometimes also using it for the pseudoinverse. If someone passes an NxM matrix here, it sort of feels like they either

  1. Would know ahead of time that they will or will not have a square matrix
  2. Know their matrix is low rank and they can't use inverse()

Are those not good assumptions? If they are, then should we always just do the pseudoinverse()? Is the accuracy of the pseudoinverse pretty low compared to inverse()? I almost feel like, even if so, someone using this would expect a pseudoinverse instead of the inverse.
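The trade-off in question can be sketched numerically. The snippet below is a NumPy illustration (not part of the PR) of the distinction the new rank check draws: for a full-rank square matrix the inverse and the pseudoinverse coincide, while for a rank-deficient square matrix only the Moore-Penrose pseudoinverse is defined.

```python
import numpy as np

# Full-rank square matrix: inverse and pseudoinverse coincide.
A = np.array([[4.0, 1.0],
              [2.0, 3.0]])
assert np.allclose(np.linalg.inv(A), np.linalg.pinv(A))

# Rank-deficient square matrix: inv() is unusable (singular), but the
# Moore-Penrose pseudoinverse still satisfies its defining conditions.
B = np.array([[1.0, 2.0],
              [2.0, 4.0]])  # rank 1
assert np.linalg.matrix_rank(B) == 1
B_pinv = np.linalg.pinv(B)
assert np.allclose(B @ B_pinv @ B, B)            # G G+ G = G
assert np.allclose(B_pinv @ B @ B_pinv, B_pinv)  # G+ G G+ = G+
```

This is the case the diff guards against: inverse() is only used when the decomposition reports full rank.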

Collaborator Author:

These are great questions!

On

If someone passes an NxM matrix here, it sort of feels like they either

  1. Would know ahead of time that they will or will not have a square matrix
  2. Know their matrix is low rank and they can't use inverse()

I'm interested in the case where you'd specifically pass a square matrix but not know ahead of time if the matrix is low rank or not. It's probably low-rank but it may not be (say in a factor model).

This got me to revisit the paper the code is based on. The author does something that I missed because he notates L as a Cholesky factor when it is actually a modified Cholesky factor that works even for low-rank matrices. That is how he is able to skip a rank determination (which one usually does, either as I've done here or using the SVD). I've coded up the prim version at https://github.com/spinkney/math/tree/generalized_inverse_low_rank_chol. The point is that the rank issue must be dealt with somehow, or there are numerical instability issues.

The code now follows the paper explicitly. The major change is adding a cholesky_low_rank_decomposition lambda, which I've coded using Eigen block notation. I believe there are further optimizations that could be made if someone with a background in high-performance computing helps. It could also be added to Stan as a separate function, which is cool in its own right.

The prim tests pass. I haven't benchmarked it yet, though, to see how it compares to the completeOrthogonalDecomposition() version here.
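For readers following along, the idea of a rank-revealing ("low-rank") Cholesky can be sketched as a greedy pivoted factorization. This is an illustrative NumPy version, not the paper's exact algorithm or the code on the linked branch:

```python
import numpy as np

def low_rank_cholesky(A, tol=1e-10):
    """Greedy pivoted Cholesky: returns L with A ~= L @ L.T, where L has
    rank(A) columns, even when the PSD matrix A is rank-deficient."""
    n = A.shape[0]
    L = np.zeros((n, 0))
    d = np.diag(A).astype(float).copy()   # residual diagonal
    while d.max() > tol:
        i = int(np.argmax(d))             # pivot on the largest residual
        col = A[:, i] - L @ L[i, :]       # residual column i of A
        L = np.column_stack([L, col / np.sqrt(col[i])])
        d = np.maximum(d - L[:, -1] ** 2, 0.0)
    return L

A = np.array([[4.0, 4.0, -2.0],
              [4.0, 4.0, -2.0],
              [-2.0, -2.0, 10.0]])        # symmetric, rank 2
L = low_rank_cholesky(A)
assert L.shape == (3, 2)                  # rank revealed by column count
assert np.allclose(L @ L.T, A)

# With A = L L^T and L of full column rank, the pseudoinverse follows
# without any separate rank determination:
M = np.linalg.inv(L.T @ L)
assert np.allclose(L @ M @ M @ L.T, np.linalg.pinv(A))
```

The last assertion is the point of the approach: once the factor L is in hand, the generalized inverse of a symmetric PSD matrix needs no explicit rank check.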

Collaborator Author:

I've put together an example of how the low-rank Cholesky can transform a low-rank matrix of "factors" into a valid correlation matrix. I'm not sure if this works all the time, but it's promising. low_rank_chol_eg.R.zip

Collaborator:

Oh nice! So with the low-rank chol version, should we close this PR and open one up based on the low-rank chol? Or should we merge this and then make a separate PR for that?

Collaborator Author (spinkney), Sep 16, 2021:

I'm thinking that the low-rank chol should be the final version. Though I think it may be worth getting this update in to fix the symmetric low-rank case for this release cycle, and then in the next release having two PRs: 1) add the low-rank chol as a function and 2) update gen inverse to use it in a new PR.

else
return complete_ortho_decomp_G.pseudoInverse();
}

if (G.rows() < G.cols()) {
if (G_ref.rows() < G_ref.cols()) {
return (G_ref * G_ref.transpose()).ldlt().solve(G_ref).transpose();
} else {
return (G_ref.transpose() * G_ref).ldlt().solve(G_ref.transpose());
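For context on the unchanged non-square branches: with full row rank (rows < cols) the generalized inverse reduces to Gᵀ(GGᵀ)⁻¹, and with full column rank to (GᵀG)⁻¹Gᵀ, which is what the ldlt().solve() calls compute. A NumPy check of the same algebra (illustration only):

```python
import numpy as np

G = np.array([[1.0, 3.0, 5.0],
              [2.0, 4.0, 6.0]])  # 2 x 3, full row rank

# rows < cols branch: solve (G G^T) X = G, then transpose -> G^T (G G^T)^-1
wide = np.linalg.solve(G @ G.T, G).T
assert np.allclose(wide, np.linalg.pinv(G))

H = G.T  # 3 x 2, full column rank
# rows >= cols branch: solve (H^T H) X = H^T -> (H^T H)^-1 H^T
tall = np.linalg.solve(H.T @ H, H.T)
assert np.allclose(tall, np.linalg.pinv(H))
```

These shortcuts are only valid at full (row or column) rank, which is why the square case above needs the extra rank check.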
27 changes: 21 additions & 6 deletions stan/math/rev/fun/generalized_inverse.hpp
@@ -1,6 +1,7 @@
#ifndef STAN_MATH_REV_FUN_GENERALIZED_INVERSE_HPP
#define STAN_MATH_REV_FUN_GENERALIZED_INVERSE_HPP

#include <stan/math/prim/fun/add_diag.hpp>
#include <stan/math/rev/core.hpp>
#include <stan/math/prim/err.hpp>
#include <stan/math/prim/fun/Eigen.hpp>
@@ -21,17 +22,19 @@ namespace internal {
template <typename T1, typename T2>
inline auto generalized_inverse_lambda(T1& G_arena, T2& inv_G) {
return [G_arena, inv_G]() mutable {
Eigen::VectorXd ones1
= Eigen::VectorXd::Constant(std::min(G_arena.rows(), inv_G.cols()), 1);
Eigen::VectorXd ones2
= Eigen::VectorXd::Constant(std::min(inv_G.rows(), G_arena.cols()), 1);
G_arena.adj()
+= -(inv_G.val_op().transpose() * inv_G.adj_op()
* inv_G.val_op().transpose())
+ (-G_arena.val_op() * inv_G.val_op()
+ Eigen::MatrixXd::Identity(G_arena.rows(), inv_G.cols()))
+ add_diag(-G_arena.val_op() * inv_G.val_op(), ones1)
* inv_G.adj_op().transpose() * inv_G.val_op()
* inv_G.val_op().transpose()
+ inv_G.val_op().transpose() * inv_G.val_op()
* inv_G.adj_op().transpose()
* (-inv_G.val_op() * G_arena.val_op()
+ Eigen::MatrixXd::Identity(inv_G.rows(), G_arena.cols()));
* add_diag(-inv_G.val_op() * G_arena.val_op(), ones2);
};
}
} // namespace internal
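The callback above applies the standard reverse-mode adjoint of the pseudoinverse; the add_diag calls build the projector-like terms (I − G G⁺) and (I − G⁺ G) without materializing full identity matrices. A NumPy transcription of the same update (an illustration, not the Stan code) can be validated against finite differences:

```python
import numpy as np

def pinv_vjp(G, P, P_bar):
    """Adjoint of P = pinv(G): returns G_bar given the upstream adjoint
    P_bar, mirroring generalized_inverse_lambda term by term."""
    I_rows = np.eye(G.shape[0])  # role of add_diag(..., ones1)
    I_cols = np.eye(G.shape[1])  # role of add_diag(..., ones2)
    return (-P.T @ P_bar @ P.T
            + (I_rows - G @ P) @ P_bar.T @ P @ P.T
            + P.T @ P @ P_bar.T @ (I_cols - P @ G))

G = np.array([[1.0, 3.0, 5.0],
              [2.0, 4.0, 6.0]])
P = np.linalg.pinv(G)
P_bar = np.ones_like(P)          # adjoint of f(G) = sum of pinv(G) entries
G_bar = pinv_vjp(G, P, P_bar)

# Central finite differences of f(G) = pinv(G).sum()
eps, fd = 1e-6, np.zeros_like(G)
for i in range(G.shape[0]):
    for j in range(G.shape[1]):
        Gp, Gm = G.copy(), G.copy()
        Gp[i, j] += eps
        Gm[i, j] -= eps
        fd[i, j] = (np.linalg.pinv(Gp).sum()
                    - np.linalg.pinv(Gm).sum()) / (2 * eps)
assert np.allclose(G_bar, fd, atol=1e-5)
```

The finite-difference agreement holds wherever the rank is locally constant, which is the regime in which the pseudoinverse is differentiable.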
@@ -66,8 +69,20 @@ inline auto generalized_inverse(const VarMat& G) {
if (G.size() == 0)
return ret_type(G);

if (G.rows() == G.cols())
return ret_type(inverse(G));
if (G.rows() == G.cols()) {
arena_t<VarMat> G_arena(G);
Eigen::CompleteOrthogonalDecomposition<Eigen::MatrixXd>
complete_ortho_decomp_G
= G_arena.val().completeOrthogonalDecomposition();
if (!(complete_ortho_decomp_G.rank() < G.rows())) {
return ret_type(inverse(G));
} else {
arena_t<ret_type> inv_G(complete_ortho_decomp_G.pseudoInverse());
reverse_pass_callback(
internal::generalized_inverse_lambda(G_arena, inv_G));
return ret_type(inv_G);
}
}

if (G.rows() < G.cols()) {
arena_t<VarMat> G_arena(G);
52 changes: 51 additions & 1 deletion test/unit/math/mix/fun/generalized_inverse_test.cpp
@@ -2,7 +2,7 @@
#include <vector>
#include <gtest/gtest.h>

TEST(mathMixMatFun, ad_tests) {
TEST(mathMixMatFun, ad_tests1) {
using stan::test::expect_ad;
using stan::test::expect_ad_matvar;

@@ -11,25 +11,61 @@ TEST(mathMixMatFun, ad_tests1) {
Eigen::MatrixXd t(0, 0);
expect_ad(f, t);
expect_ad_matvar(f, t);
}

TEST(mathMixMatFun, ad_tests2) {
using stan::test::expect_ad;
using stan::test::expect_ad_matvar;

auto f = [](const auto& G) { return stan::math::generalized_inverse(G); };

Eigen::MatrixXd u(1, 1);
u << 2;
expect_ad(f, u);
expect_ad_matvar(f, u);
}

TEST(mathMixMatFun, ad_tests3) {
using stan::test::expect_ad;
using stan::test::expect_ad_matvar;

auto f = [](const auto& G) { return stan::math::generalized_inverse(G); };

Eigen::MatrixXd v(2, 3);
v << 1, 3, 5, 2, 4, 6;
expect_ad(f, v);
expect_ad_matvar(f, v);
}

TEST(mathMixMatFun, ad_tests4) {
using stan::test::expect_ad;
using stan::test::expect_ad_matvar;

auto f = [](const auto& G) { return stan::math::generalized_inverse(G); };

Eigen::MatrixXd v(2, 3);
v << 1.9, 1.3, 2.5, 0.4, 1.7, 0.1;
expect_ad(f, v);
expect_ad_matvar(f, v);
}

TEST(mathMixMatFun, ad_tests5) {
using stan::test::expect_ad;
using stan::test::expect_ad_matvar;

auto f = [](const auto& G) { return stan::math::generalized_inverse(G); };

Eigen::MatrixXd s(2, 4);
s << 3.4, 2, 5, 1.2, 2, 1, 3.2, 3.1;
expect_ad(f, s);
expect_ad_matvar(f, s);
}

TEST(mathMixMatFun, ad_tests6) {
using stan::test::expect_ad;
using stan::test::expect_ad_matvar;

auto f = [](const auto& G) { return stan::math::generalized_inverse(G); };

// issues around zero require looser tolerances for hessians
stan::test::ad_tolerances tols;
@@ -40,10 +76,24 @@
w << 2, 3, 5, 7, 11, 13, 17, 19, 23, 25, 27, 29;
expect_ad(tols, f, w);
expect_ad_matvar(f, w);
}

TEST(mathMixMatFun, ad_tests8) {
using stan::test::expect_ad;
using stan::test::expect_ad_matvar;

auto f = [](const auto& G) { return stan::math::generalized_inverse(G); };

Eigen::MatrixXd z(2, 2);
z << 1, 2, 5, std::numeric_limits<double>::quiet_NaN();
EXPECT_NO_THROW(stan::math::generalized_inverse(z));
}

TEST(mathMixMatFun, ad_tests9) {
using stan::test::expect_ad;
using stan::test::expect_ad_matvar;

auto f = [](const auto& G) { return stan::math::generalized_inverse(G); };

// autodiff throws, so following fails (throw behavior must match to pass)

Expand Down
14 changes: 14 additions & 0 deletions test/unit/math/prim/fun/generalized_inverse_test.cpp
@@ -45,3 +45,17 @@ TEST(MathMatrixPrim, Equal2) {
stan::math::matrix_d m3 = m1 * generalized_inverse(m1);
EXPECT_MATRIX_NEAR(m2, m3, 1e-9);
}

TEST(MathMatrixPrim, Equal3) {
using stan::math::generalized_inverse;

stan::math::matrix_d m1(3, 3);
m1 << 4, 4, -2, 4, 4, -2, -2, -2, 10;

stan::math::matrix_d m2(3, 3);
m2 << 0.069444444, 0.069444444, 0.0277777777, 0.069444444, 0.069444444,
0.0277777777, 0.027777777, 0.027777777, 0.1111111111;

stan::math::matrix_d m3 = generalized_inverse(m1);
EXPECT_MATRIX_NEAR(m2, m3, 1e-9);
}
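The hard-coded expected values in Equal3 can be cross-checked independently of Stan; a NumPy sanity check (not part of the test suite) confirms m2 is the Moore-Penrose pseudoinverse of the rank-2 symmetric m1:

```python
import numpy as np

# the symmetric, rank-deficient matrix from TEST(MathMatrixPrim, Equal3)
m1 = np.array([[4.0, 4.0, -2.0],
               [4.0, 4.0, -2.0],
               [-2.0, -2.0, 10.0]])
assert np.linalg.matrix_rank(m1) == 2   # square but not invertible

m2 = np.linalg.pinv(m1)
# the reference values hard-coded in the test
expected = np.array([[0.069444444, 0.069444444, 0.0277777777],
                     [0.069444444, 0.069444444, 0.0277777777],
                     [0.027777777, 0.027777777, 0.1111111111]])
assert np.allclose(m2, expected, atol=1e-8)
```

Before this PR, a square input like m1 would have been routed straight to inverse(), which is exactly the failure mode this test pins down.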