Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

var<mat> implementation of mdivide_left_spd and mdivide_left_tri #2213

Merged
merged 7 commits into from
Dec 17, 2020
113 changes: 112 additions & 1 deletion stan/math/rev/fun/mdivide_left_spd.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -19,7 +19,7 @@ class mdivide_left_spd_alloc : public chainable_alloc {
public:
virtual ~mdivide_left_spd_alloc() {}

Eigen::LLT<Eigen::Matrix<double, R1, C1> > llt_;
Eigen::LLT<Eigen::Matrix<double, R1, C1>> llt_;
Eigen::Matrix<double, R2, C2> C_;
};

Expand Down Expand Up @@ -238,6 +238,117 @@ mdivide_left_spd(const EigMat1 &A, const EigMat2 &b) {
return res;
}

/**
* Returns the solution of the system Ax=B where A is symmetric positive
* definite.
*
* This overload handles arguments where one of T1 or T2 are
* `var_value<T>` where `T` is an Eigen type. The other type can
* also be a `var_value` or it can be a matrix type that inherits
* from EigenBase
*
* @tparam T1 type of the first matrix
* @tparam T2 type of the right-hand side matrix or vector
*
* @param A Matrix.
* @param B Right hand side matrix or vector.
* @return x = A^-1 B, solution of the linear system.
* @throws std::domain_error if A is not square or B does not have
* as many rows as A has columns.
*/
template <typename T1, typename T2, require_all_matrix_t<T1, T2> * = nullptr,
require_any_var_matrix_t<T1, T2> * = nullptr>
inline auto mdivide_left_spd(const T1 &A, const T2 &B) {
using ret_val_type = plain_type_t<decltype(value_of(A) * value_of(B))>;
using ret_type = var_value<ret_val_type>;

if (A.size() == 0) {
return ret_type(ret_val_type(0, B.cols()));
}

check_multiplicable("mdivide_left_spd", "A", A, "B", B);

if (!is_constant<T1>::value && !is_constant<T2>::value) {
arena_t<promote_scalar_t<var, T1>> arena_A = A;
arena_t<promote_scalar_t<var, T2>> arena_B = B;

check_symmetric("mdivide_left_spd", "A", arena_A.val());
check_not_nan("mdivide_left_spd", "A", arena_A.val());

auto A_llt = arena_A.val().llt();

check_pos_definite("mdivide_left_spd", "A", A_llt);

arena_t<Eigen::MatrixXd> arena_A_llt = A_llt.matrixL();
arena_t<ret_type> res = A_llt.solve(arena_B.val());

reverse_pass_callback([arena_A, arena_B, arena_A_llt, res]() mutable {
promote_scalar_t<double, T2> adjB = res.adj();

arena_A_llt.template triangularView<Eigen::Lower>().solveInPlace(adjB);
arena_A_llt.template triangularView<Eigen::Lower>()
.transpose()
.solveInPlace(adjB);

arena_A.adj() -= adjB * res.val_op().transpose();
arena_B.adj() += adjB;
});

return ret_type(res);
} else if (!is_constant<T1>::value) {
arena_t<promote_scalar_t<var, T1>> arena_A = A;

check_symmetric("mdivide_left_spd", "A", arena_A.val());
check_not_nan("mdivide_left_spd", "A", arena_A.val());

auto A_llt = arena_A.val().llt();

check_pos_definite("mdivide_left_spd", "A", A_llt);

arena_t<Eigen::MatrixXd> arena_A_llt = A_llt.matrixL();
arena_t<ret_type> res = A_llt.solve(value_of(B));

reverse_pass_callback([arena_A, arena_A_llt, res]() mutable {
promote_scalar_t<double, T2> adjB = res.adj();

arena_A_llt.template triangularView<Eigen::Lower>().solveInPlace(adjB);
arena_A_llt.template triangularView<Eigen::Lower>()
.transpose()
.solveInPlace(adjB);

arena_A.adj() -= adjB * res.val().transpose().eval();
});

return ret_type(res);
} else {
const auto &A_ref = to_ref(value_of(A));
arena_t<promote_scalar_t<var, T2>> arena_B = B;

check_symmetric("mdivide_left_spd", "A", A_ref);
check_not_nan("mdivide_left_spd", "A", A_ref);

auto A_llt = A_ref.llt();

check_pos_definite("mdivide_left_spd", "A", A_llt);

arena_t<Eigen::MatrixXd> arena_A_llt = A_llt.matrixL();
arena_t<ret_type> res = A_llt.solve(arena_B.val());

reverse_pass_callback([arena_B, arena_A_llt, res]() mutable {
promote_scalar_t<double, T2> adjB = res.adj();

arena_A_llt.template triangularView<Eigen::Lower>().solveInPlace(adjB);
arena_A_llt.template triangularView<Eigen::Lower>()
.transpose()
.solveInPlace(adjB);

arena_B.adj() += adjB;
});

return ret_type(res);
}
}

} // namespace math
} // namespace stan
#endif
92 changes: 90 additions & 2 deletions stan/math/rev/fun/mdivide_left_tri.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -285,11 +285,11 @@ class mdivide_left_tri_vd_vari : public vari {
} else {
#endif
adjA.noalias()
= -Map<Matrix<double, R1, C1> >(A_, M_, M_)
= -Map<Matrix<double, R1, C1>>(A_, M_, M_)
.template triangularView<TriView>()
.transpose()
.solve(adjC
* Map<Matrix<double, R1, C2> >(C_, M_, N_).transpose());
* Map<Matrix<double, R1, C2>>(C_, M_, N_).transpose());
#ifdef STAN_OPENCL
}
#endif
Expand Down Expand Up @@ -389,6 +389,94 @@ mdivide_left_tri(const T1 &A, const T2 &b) {
return res;
}

/**
* Returns the solution of the system Ax=B when A is triangular.
*
* This overload handles arguments where one of T1 or T2 are
* `var_value<T>` where `T` is an Eigen type. The other type can
* also be a `var_value` or it can be a matrix type that inherits
* from EigenBase
*
* @tparam TriView Specifies whether A is upper (Eigen::Upper)
* or lower triangular (Eigen::Lower).
* @tparam T1 type of the triangular matrix
* @tparam T2 type of the right-hand side matrix or vector
*
* @param A Triangular matrix.
* @param B Right hand side matrix or vector.
* @return x = A^-1 B, solution of the linear system.
* @throws std::domain_error if A is not square or B does not have
* as many rows as A has columns.
*/
template <Eigen::UpLoType TriView, typename T1, typename T2,
require_all_matrix_t<T1, T2> * = nullptr,
require_any_var_matrix_t<T1, T2> * = nullptr>
inline auto mdivide_left_tri(const T1 &A, const T2 &B) {
using ret_val_type = plain_type_t<decltype(value_of(A) * value_of(B))>;
using ret_type = var_value<ret_val_type>;

if (A.size() == 0) {
return ret_type(ret_val_type(0, B.cols()));
}

check_square("mdivide_left_tri", "A", A);
check_multiplicable("mdivide_left_tri", "A", A, "B", B);

if (!is_constant<T1>::value && !is_constant<T2>::value) {
arena_t<promote_scalar_t<var, T1>> arena_A = A;
arena_t<promote_scalar_t<var, T2>> arena_B = B;
auto arena_A_val = to_arena(arena_A.val());

arena_t<ret_type> res
= arena_A_val.template triangularView<TriView>().solve(arena_B.val());

reverse_pass_callback([arena_A, arena_B, arena_A_val, res]() mutable {
promote_scalar_t<double, T2> adjB
= arena_A_val.template triangularView<TriView>().transpose().solve(
res.adj());

arena_B.adj() += adjB;
arena_A.adj() -= (adjB * res.val().transpose().eval())
.template triangularView<TriView>();
});

return ret_type(res);
} else if (!is_constant<T1>::value) {
arena_t<promote_scalar_t<var, T1>> arena_A = A;
auto arena_A_val = to_arena(arena_A.val());

arena_t<ret_type> res
= arena_A_val.template triangularView<TriView>().solve(value_of(B));

reverse_pass_callback([arena_A, arena_A_val, res]() mutable {
promote_scalar_t<double, T2> adjB
= arena_A_val.template triangularView<TriView>().transpose().solve(
res.adj());

arena_A.adj() -= (adjB * res.val().transpose().eval())
.template triangularView<TriView>();
});

return ret_type(res);
} else {
arena_t<promote_scalar_t<double, T1>> arena_A = value_of(A);
arena_t<promote_scalar_t<var, T2>> arena_B = B;

arena_t<ret_type> res
= arena_A.template triangularView<TriView>().solve(arena_B.val());

reverse_pass_callback([arena_A, arena_B, res]() mutable {
promote_scalar_t<double, T2> adjB
= arena_A.template triangularView<TriView>().transpose().solve(
res.adj());

arena_B.adj() += adjB;
});

return ret_type(res);
}
}

} // namespace math
} // namespace stan
#endif
22 changes: 21 additions & 1 deletion test/unit/math/mix/fun/mdivide_left_spd_test.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -4,7 +4,8 @@ TEST(MathMixMatFun, mdivideLeftSpd) {
auto f = [](const auto& x, const auto& y) {
if (x.rows() != x.cols())
return stan::math::mdivide_left_spd(x, y);
auto x_sym = ((x + x.transpose()) * 0.5).eval(); // sym for finite diffs
auto x_sym = stan::math::eval(
stan::math::multiply(x + x.transpose(), 0.5)); // sym for finite diffs
return stan::math::mdivide_left_spd(x_sym, y);
};

Expand All @@ -15,18 +16,24 @@ TEST(MathMixMatFun, mdivideLeftSpd) {
stan::test::expect_ad(f, m00, m00);
stan::test::expect_ad(f, m00, m02);
stan::test::expect_ad(f, m00, v0);
stan::test::expect_ad_matvar(f, m00, m00);
stan::test::expect_ad_matvar(f, m00, m02);
stan::test::expect_ad_matvar(f, m00, v0);

Eigen::MatrixXd aa(1, 1);
aa << 1;
Eigen::MatrixXd bb(1, 1);
bb << 2;
stan::test::expect_ad(f, aa, bb);
stan::test::expect_ad_matvar(f, aa, bb);
Eigen::MatrixXd b0(1, 0);
stan::test::expect_ad(f, aa, b0);
stan::test::expect_ad_matvar(f, aa, b0);

Eigen::VectorXd cc(1);
cc << 3;
stan::test::expect_ad(f, aa, cc);
stan::test::expect_ad_matvar(f, aa, cc);

Eigen::MatrixXd a(2, 2);
a << 2, 3, 3, 7;
Expand All @@ -45,9 +52,15 @@ TEST(MathMixMatFun, mdivideLeftSpd) {
stan::test::expect_ad(f, b, b);
stan::test::expect_ad(f, a, c);
stan::test::expect_ad(f, b, c);
stan::test::expect_ad_matvar(f, a, a);
stan::test::expect_ad_matvar(f, b, b);
stan::test::expect_ad_matvar(f, a, c);
stan::test::expect_ad_matvar(f, b, c);
// matrix, vector : ditto
stan::test::expect_ad(f, a, d);
stan::test::expect_ad(f, b, d);
stan::test::expect_ad_matvar(f, a, d);
stan::test::expect_ad_matvar(f, b, d);

Eigen::MatrixXd m33 = Eigen::MatrixXd::Zero(3, 3);
Eigen::MatrixXd m44 = Eigen::MatrixXd::Zero(4, 4);
Expand All @@ -58,17 +71,24 @@ TEST(MathMixMatFun, mdivideLeftSpd) {
// exceptions: not symmetric
stan::test::expect_ad(f, c, a);
stan::test::expect_ad(f, c, d);
stan::test::expect_ad_matvar(f, c, a);
stan::test::expect_ad_matvar(f, c, d);

// exceptions: not pos def
stan::test::expect_ad(f, m33, m33);
stan::test::expect_ad(f, m33, v3);
stan::test::expect_ad_matvar(f, m33, m33);
stan::test::expect_ad_matvar(f, m33, v3);

// exceptions: wrong sizes
stan::test::expect_ad(f, m33, m44);
stan::test::expect_ad(f, m33, v4);
stan::test::expect_ad_matvar(f, m33, m44);
stan::test::expect_ad_matvar(f, m33, v4);

// exceptions: wrong types
stan::test::expect_ad(f, m33, rv3);
stan::test::expect_ad_matvar(f, m33, rv3);

stan::math::recover_memory();
}
23 changes: 23 additions & 0 deletions test/unit/math/mix/fun/mdivide_left_tri_test.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -13,6 +13,8 @@ TEST(MathMixMatFun, mdivideLeftTri) {
Eigen::VectorXd v0(0);
stan::test::expect_ad(f, m00, v0);
stan::test::expect_ad(f, m00, m00);
stan::test::expect_ad_matvar(f, m00, v0);
stan::test::expect_ad_matvar(f, m00, m00);

// signature 1 of 2: matrix-matrix
Eigen::MatrixXd aa(1, 1);
Expand All @@ -21,12 +23,16 @@ TEST(MathMixMatFun, mdivideLeftTri) {
bb << 2;
stan::test::expect_ad(f, aa, bb);
stan::test::expect_ad(f_up, aa, bb);
stan::test::expect_ad_matvar(f, aa, bb);
stan::test::expect_ad_matvar(f_up, aa, bb);

// signature 2 of 2: matrix-vector
Eigen::VectorXd cc(1);
cc << 3;
stan::test::expect_ad(f, aa, cc);
stan::test::expect_ad(f_up, aa, cc);
stan::test::expect_ad_matvar(f, aa, cc);
stan::test::expect_ad_matvar(f_up, aa, cc);

Eigen::MatrixXd a(2, 2);
a << 2, 0, 5, 7;
Expand All @@ -41,6 +47,12 @@ TEST(MathMixMatFun, mdivideLeftTri) {
stan::test::expect_ad(f_up, a_tr, a);
stan::test::expect_ad(f_up, a_tr, b);
stan::test::expect_ad(f_up, a_tr, c);
stan::test::expect_ad_matvar(f, a, a);
stan::test::expect_ad_matvar(f, a, b);
stan::test::expect_ad_matvar(f, a, c);
stan::test::expect_ad_matvar(f_up, a_tr, a);
stan::test::expect_ad_matvar(f_up, a_tr, b);
stan::test::expect_ad_matvar(f_up, a_tr, c);

Eigen::MatrixXd y(3, 3);
y << 1, 0, 0, 2, 3, 0, 4, 5, 6;
Expand All @@ -53,12 +65,17 @@ TEST(MathMixMatFun, mdivideLeftTri) {
stan::test::expect_ad(f, u, y);
stan::test::expect_ad(f_up, y_tr, z);
stan::test::expect_ad(f_up, y_tr, y);
stan::test::expect_ad_matvar(f, y, z);
stan::test::expect_ad_matvar(f, u, y);
stan::test::expect_ad_matvar(f_up, y_tr, z);
stan::test::expect_ad_matvar(f_up, y_tr, y);

Eigen::MatrixXd uu(2, 2);
uu << 3, 0, 1, 4;
Eigen::MatrixXd vv(2, 2);
vv << 2, 3, 5, 7;
stan::test::expect_ad(f, uu, vv);
stan::test::expect_ad_matvar(f, uu, vv);

// exception cases
Eigen::MatrixXd d(3, 2);
Expand All @@ -68,6 +85,9 @@ TEST(MathMixMatFun, mdivideLeftTri) {
stan::test::expect_ad(f, d, b);
stan::test::expect_ad(f, d, c);
stan::test::expect_ad(f, a, e);
stan::test::expect_ad_matvar(f, d, b);
stan::test::expect_ad_matvar(f, d, c);
stan::test::expect_ad_matvar(f, a, e);

Eigen::MatrixXd m33 = Eigen::MatrixXd::Zero(3, 3);
Eigen::MatrixXd m44 = Eigen::MatrixXd::Zero(4, 4);
Expand All @@ -77,7 +97,10 @@ TEST(MathMixMatFun, mdivideLeftTri) {
// exceptions: wrong sizes
stan::test::expect_ad(f, m33, m44);
stan::test::expect_ad(f, m33, v4);
stan::test::expect_ad_matvar(f, m33, m44);
stan::test::expect_ad_matvar(f, m33, v4);

// exceptions: wrong types
stan::test::expect_ad(f, m33, rv3);
stan::test::expect_ad_matvar(f, m33, rv3);
}