Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

adds unary var matrix functions #2362

Merged
merged 9 commits into from
Feb 24, 2021
6 changes: 4 additions & 2 deletions stan/math/prim/fun/Phi.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -63,8 +63,10 @@ struct Phi_fun {
* @param x container
* @return Unit normal CDF of each value in x.
*/
template <typename T, require_all_not_nonscalar_prim_or_rev_kernel_expression_t<
T>* = nullptr>
template <
typename T,
require_all_not_nonscalar_prim_or_rev_kernel_expression_t<T>* = nullptr,
require_not_var_matrix_t<T>* = nullptr>
inline auto Phi(const T& x) {
return apply_scalar_unary<Phi_fun, T>::apply(x);
}
Expand Down
6 changes: 4 additions & 2 deletions stan/math/prim/fun/Phi_approx.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -62,8 +62,10 @@ struct Phi_approx_fun {
* @param x container
* @return elementwise Phi_approx of container elements
*/
template <typename T, require_all_not_nonscalar_prim_or_rev_kernel_expression_t<
T>* = nullptr>
template <
typename T,
require_all_not_nonscalar_prim_or_rev_kernel_expression_t<T>* = nullptr,
require_not_var_matrix_t<T>* = nullptr>
inline auto Phi_approx(const T& x) {
return apply_scalar_unary<Phi_approx_fun, T>::apply(x);
}
Expand Down
3 changes: 2 additions & 1 deletion stan/math/prim/fun/digamma.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -72,7 +72,8 @@ struct digamma_fun {
* @throw std::domain_error if any value is a negative integer or 0
*/
template <typename T,
require_not_nonscalar_prim_or_rev_kernel_expression_t<T>* = nullptr>
require_not_nonscalar_prim_or_rev_kernel_expression_t<T>* = nullptr,
require_not_var_matrix_t<T>* = nullptr>
inline auto digamma(const T& x) {
return apply_scalar_unary<digamma_fun, T>::apply(x);
}
Expand Down
3 changes: 1 addition & 2 deletions stan/math/prim/fun/distance.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -43,8 +43,7 @@ inline return_type_t<T1, T2> distance(const T1& x1, const T2& x2) {
* @throw std::domain_error If the vectors are not the same
* size.
*/
template <typename T1, typename T2,
require_all_eigen_vector_t<T1, T2>* = nullptr>
template <typename T1, typename T2, require_all_vector_t<T1, T2>* = nullptr>
inline return_type_t<T1, T2> distance(const T1& x1, const T2& x2) {
using std::sqrt;
check_matching_sizes("distance", "x1", x1, "x2", x2);
Expand Down
6 changes: 4 additions & 2 deletions stan/math/prim/fun/inv_Phi.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -101,8 +101,10 @@ struct inv_Phi_fun {
* @return Inverse unit normal CDF of each value in x.
* @throw std::domain_error if any value is not between 0 and 1.
*/
template <typename T, require_all_not_nonscalar_prim_or_rev_kernel_expression_t<
T>* = nullptr>
template <
typename T,
require_all_not_nonscalar_prim_or_rev_kernel_expression_t<T>* = nullptr,
require_not_var_matrix_t<T>* = nullptr>
inline auto inv_Phi(const T& x) {
return apply_scalar_unary<inv_Phi_fun, T>::apply(x);
}
Expand Down
6 changes: 4 additions & 2 deletions stan/math/prim/fun/sqrt.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -36,7 +36,8 @@ struct sqrt_fun {
template <typename Container,
require_not_container_st<std::is_arithmetic, Container>* = nullptr,
require_all_not_nonscalar_prim_or_rev_kernel_expression_t<
Container>* = nullptr>
Container>* = nullptr,
require_not_var_matrix_t<Container>* = nullptr>
inline auto sqrt(const Container& x) {
return apply_scalar_unary<sqrt_fun, Container>::apply(x);
}
Expand All @@ -50,7 +51,8 @@ inline auto sqrt(const Container& x) {
* @return Square root of each value in x.
*/
template <typename Container,
require_container_st<std::is_arithmetic, Container>* = nullptr>
require_container_st<std::is_arithmetic, Container>* = nullptr,
require_not_var_matrix_t<Container>* = nullptr>
inline auto sqrt(const Container& x) {
return apply_vector_unary<Container>::apply(
x, [](const auto& v) { return v.array().sqrt(); });
Expand Down
2 changes: 1 addition & 1 deletion stan/math/prim/fun/tail.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -19,7 +19,7 @@ namespace math {
* @return The last n elements of v.
* @throw std::out_of_range if n is out of range.
*/
template <typename T, require_eigen_vector_t<T>* = nullptr>
template <typename T, require_vector_t<T>* = nullptr>
inline auto tail(const T& v, size_t n) {
if (n != 0) {
check_vector_index("tail", "n", v, n);
Expand Down
6 changes: 4 additions & 2 deletions stan/math/prim/fun/tgamma.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -46,8 +46,10 @@ struct tgamma_fun {
* @return Gamma function applied to each value in x.
* @throw std::domain_error if any value is 0 or a negative integer
*/
template <typename T, require_all_not_nonscalar_prim_or_rev_kernel_expression_t<
T>* = nullptr>
template <
typename T,
require_all_not_nonscalar_prim_or_rev_kernel_expression_t<T>* = nullptr,
require_not_var_matrix_t<T>* = nullptr>
inline auto tgamma(const T& x) {
return apply_scalar_unary<tgamma_fun, T>::apply(x);
}
Expand Down
26 changes: 13 additions & 13 deletions stan/math/rev/fun/Phi.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -9,18 +9,6 @@
namespace stan {
namespace math {

namespace internal {
// NOTE(review): legacy vari-based reverse-mode implementation of Phi; this
// PR deletes it in favor of a make_callback_var-based overload below.
class Phi_vari : public op_v_vari {
public:
// Forward pass: value is Phi(a); the operand vari is stored for chain().
explicit Phi_vari(vari* avi) : op_v_vari(Phi(avi->val_), avi) {}
// Reverse pass: d/da Phi(a) = exp(-a^2 / 2) / sqrt(2 * pi).
void chain() {
static const double NEG_HALF = -0.5;
avi_->adj_ += adj_ * INV_SQRT_TWO_PI
* std::exp(NEG_HALF * avi_->val_ * avi_->val_);
}
};
} // namespace internal

/**
* The unit normal cumulative density function for variables (stan).
*
Expand Down Expand Up @@ -62,7 +50,19 @@ class Phi_vari : public op_v_vari {
* @param a Variable argument.
* @return The unit normal cdf evaluated at the specified argument.
*/
// Reverse pass: d/da Phi(a) = exp(-a^2 / 2) / sqrt(2 * pi).
inline var Phi(const var& a) {
  return make_callback_var(Phi(a.val()), [a](auto& vi) mutable {
    a.adj() += vi.adj() * INV_SQRT_TWO_PI * std::exp(-0.5 * a.val() * a.val());
  });
}

template <typename T, require_var_matrix_t<T>* = nullptr>
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

We probably need doxygen here?

inline auto Phi(const T& a) {
return make_callback_var(Phi(a.val()), [a](auto& vi) mutable {
a.adj().array() += vi.adj().array() * INV_SQRT_TWO_PI
* (-0.5 * a.val().array().square()).exp();
});
}

} // namespace math
} // namespace stan
Expand Down
28 changes: 23 additions & 5 deletions stan/math/rev/fun/Phi_approx.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -44,12 +44,30 @@ namespace math {
* @return The corresponding unit normal cdf approximation.
*/
// Logistic approximation to the unit normal CDF:
// Phi_approx(a) = inv_logit(0.07056 a^3 + 1.5976 a).
// The partial da is precomputed in the forward pass so the callback only
// accumulates the adjoint.
inline var Phi_approx(const var& a) {
  double av_squared = a.val() * a.val();
  double f = inv_logit(0.07056 * a.val() * av_squared + 1.5976 * a.val());
  // chain rule through inv_logit: f * (1 - f) * d/da (0.07056 a^3 + 1.5976 a)
  double da = f * (1 - f) * (3.0 * 0.07056 * av_squared + 1.5976);
  return make_callback_var(
      f, [a, da](auto& vi) mutable { a.adj() += vi.adj() * da; });
}

template <typename T, require_var_matrix_t<T>* = nullptr>
inline auto Phi_approx(const T& a) {
arena_t<value_type_t<T>> f(a.rows(), a.cols());
arena_t<value_type_t<T>> da(a.rows(), a.cols());
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Does it make sense to compute this in the forward pass? I guess we would compute av_squared twice if we compute this in the backward pass. Is that the reason?

Copy link
Collaborator Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

yeah that's the general jist. we can compute adjoint while calculating the forward pass so it's just nice to alloc everything at once, throw it in one big loop, and then have an easy peasy reverse pass.

Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Cool.

for (Eigen::Index j = 0; j < a.cols(); ++j) {
for (Eigen::Index i = 0; i < a.rows(); ++i) {
const auto a_val = a.val().coeff(i, j);
const auto av_squared = a_val * a_val;
f.coeffRef(i, j) = inv_logit(0.07056 * a_val * av_squared
+ 1.5976 * a.val().coeff(i, j));
da.coeffRef(i, j) = f.coeff(i, j) * (1 - f.coeff(i, j))
* (3.0 * 0.07056 * av_squared + 1.5976);
}
}
return make_callback_var(f, [a, da](auto& vi) mutable {
a.adj().array() += vi.adj().array() * da.array();
});
}

} // namespace math
Expand Down
19 changes: 18 additions & 1 deletion stan/math/rev/fun/digamma.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -18,7 +18,24 @@ class digamma_vari : public op_v_vari {
} // namespace internal

// Reverse pass: d/da digamma(a) = trigamma(a).
inline var digamma(const var& a) {
  return make_callback_var(digamma(a.val()), [a](auto& vi) {
    a.adj() += vi.adj() * trigamma(a.val());
  });
}

/**
 * Return the elementwise digamma function of a `var_value` matrix.
 *
 * The reverse pass uses d/dx digamma(x) = trigamma(x).
 *
 * @tparam T a `var_value` with an inner Eigen matrix type
 * @param a matrix argument
 * @return elementwise digamma of the argument
 */
template <typename T, require_var_matrix_t<T>* = nullptr>
inline auto digamma(const T& a) {
  // forward value: digamma applied to every coefficient
  auto digamma_val
      = a.val().array().unaryExpr([](auto& x) { return digamma(x); });
  return make_callback_var(digamma_val.matrix().eval(), [a](auto& vi) mutable {
    // elementwise chain rule with trigamma as the local gradient
    auto trigamma_val
        = a.val().array().unaryExpr([](auto& x) { return trigamma(x); });
    a.adj().array() += vi.adj().array() * trigamma_val;
  });
}

} // namespace math
Expand Down
12 changes: 11 additions & 1 deletion stan/math/rev/fun/inv_Phi.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -32,7 +32,17 @@ class inv_Phi_vari : public op_v_vari {
* @return The unit normal inverse cdf evaluated at p
*/
// Reverse pass: with y = inv_Phi(p), d/dp inv_Phi(p) = 1 / phi(y)
// = sqrt(2 * pi) / exp(-y^2 / 2); vi.val() caches y.
inline var inv_Phi(const var& p) {
  return make_callback_var(inv_Phi(p.val()), [p](auto& vi) mutable {
    p.adj() += vi.adj() * SQRT_TWO_PI / std::exp(-0.5 * vi.val() * vi.val());
  });
}

/**
 * Return the elementwise inverse unit normal CDF of a `var_value` matrix.
 *
 * With y = inv_Phi(p), d/dp inv_Phi(p) = sqrt(2 * pi) / exp(-y^2 / 2);
 * the cached forward value vi.val() supplies y.
 *
 * @tparam T a `var_value` with an inner Eigen matrix type
 * @param p matrix of probabilities
 * @return elementwise inverse unit normal cdf of the argument
 */
template <typename T, require_var_matrix_t<T>* = nullptr>
inline auto inv_Phi(const T& p) {
  return make_callback_var(inv_Phi(p.val()), [p](auto& vi) mutable {
    // standard-normal density kernel evaluated at the result y = vi.val()
    auto density_kernel = (-0.5 * vi.val().array().square()).exp();
    p.adj().array() += vi.adj().array() * SQRT_TWO_PI / density_kernel;
  });
}

} // namespace math
Expand Down
18 changes: 7 additions & 11 deletions stan/math/rev/fun/pow.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -140,23 +140,19 @@ template <typename T, typename = require_arithmetic_t<T>>
inline var pow(const var& base, T exponent) {
if (exponent == 0.5) {
return sqrt(base);
}
if (exponent == 1.0) {
} else if (exponent == 1.0) {
return base;
}
if (exponent == 2.0) {
} else if (exponent == 2.0) {
return square(base);
}
if (exponent == -2.0) {
} else if (exponent == -2.0) {
return inv_square(base);
}
if (exponent == -1.0) {
} else if (exponent == -1.0) {
return inv(base);
}
if (exponent == -0.5) {
} else if (exponent == -0.5) {
return inv_sqrt(base);
} else {
return {new internal::pow_vd_vari(base.vi_, exponent)};
}
return {new internal::pow_vd_vari(base.vi_, exponent)};
}

/**
Expand Down
14 changes: 13 additions & 1 deletion stan/math/rev/fun/sqrt.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -49,7 +49,19 @@ class sqrt_vari : public op_v_vari {
* @param a Variable whose square root is taken.
* @return Square root of variable.
*/
// Reverse pass: d/da sqrt(a) = 1 / (2 * sqrt(a)); vi.val() caches sqrt(a).
inline var sqrt(const var& a) {
  return make_callback_var(std::sqrt(a.val()), [a](auto& vi) mutable {
    a.adj() += vi.adj() / (2.0 * vi.val());
  });
}

/**
 * Return the elementwise square root of a `var_value` matrix.
 *
 * The reverse pass reuses the cached forward value:
 * d/da sqrt(a) = 1 / (2 * sqrt(a)) = 1 / (2 * vi.val()).
 *
 * @tparam T a `var_value` with an inner Eigen matrix type
 * @param a matrix argument
 * @return elementwise square root of the argument
 */
template <typename T, require_var_matrix_t<T>* = nullptr>
inline auto sqrt(const T& a) {
  auto sqrt_val = a.val().array().sqrt().matrix();
  return make_callback_var(sqrt_val, [a](auto& vi) mutable {
    a.adj().array() += vi.adj().array() / (2.0 * vi.val_op().array());
  });
}

/**
* Return the square root of the complex argument.
Expand Down
12 changes: 11 additions & 1 deletion stan/math/rev/fun/tgamma.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -54,7 +54,17 @@ class tgamma_vari : public op_v_vari {
* @return The Gamma function applied to the specified argument.
*/
// Reverse pass: d/da tgamma(a) = tgamma(a) * digamma(a);
// vi.val() caches tgamma(a).
inline var tgamma(const var& a) {
  return make_callback_var(tgamma(a.val()), [a](auto& vi) mutable {
    a.adj() += vi.adj() * vi.val() * digamma(a.val());
  });
}

/**
 * Return the elementwise gamma function of a `var_value` matrix.
 *
 * The reverse pass uses d/da tgamma(a) = tgamma(a) * digamma(a), reading the
 * cached forward value vi.val() for tgamma(a).
 *
 * @tparam T a `var_value` with an inner Eigen matrix type
 * @param a matrix argument
 * @return elementwise gamma function of the argument
 */
template <typename T, require_var_matrix_t<T>* = nullptr>
inline auto tgamma(const T& a) {
  return make_callback_var(tgamma(a.val()), [a](auto& vi) mutable {
    // (adj * value) first, then * digamma — same evaluation order as before
    auto weighted_adj = vi.adj().array() * vi.val().array();
    a.adj().array() += weighted_adj * digamma(a.val()).array();
  });
}

} // namespace math
Expand Down
18 changes: 18 additions & 0 deletions test/unit/math/mix/fun/Phi_approx_test.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -5,3 +5,21 @@ TEST(mathMixMatFun, PhiApprox) {
stan::test::expect_common_unary_vectorized(f);
stan::test::expect_unary_vectorized(f, -3.0, 1, 1.3, 3);
}

// Verify Phi_approx works elementwise on var<Matrix> (varmat) arguments.
TEST(mathMixMatFun, PhiApprox_varmat) {
  using stan::math::vec_concat;
  using stan::test::expect_ad_vector_matvar;
  using stan::test::internal::common_nonzero_args;
  auto f = [](const auto& x1) {
    using stan::math::Phi_approx;
    return Phi_approx(x1);
  };
  std::vector<double> com_args = common_nonzero_args();
  std::vector<double> args{-3.0, 1, 1.3, 3};
  auto all_args = vec_concat(com_args, args);
  // Map the test points straight into an Eigen vector; this also avoids the
  // signed/unsigned comparison of the original `int i < size()` copy loop.
  Eigen::VectorXd A
      = Eigen::Map<const Eigen::VectorXd>(all_args.data(), all_args.size());
  expect_ad_vector_matvar(f, A);
}
21 changes: 21 additions & 0 deletions test/unit/math/mix/fun/Phi_test.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -7,3 +7,24 @@ TEST(mathMixMatFun, Phi) {
for (double x = -37.5; x <= 10; x += 0.5)
stan::test::expect_unary_vectorized(x);
}

// Verify Phi works elementwise on var<Matrix> (varmat) arguments, including
// the numerically delicate far-tail inputs.
TEST(mathMixMatFun, Phi_varmat) {
  using stan::math::vec_concat;
  using stan::test::expect_ad_vector_matvar;
  using stan::test::internal::common_nonzero_args;
  auto f = [](const auto& x1) {
    using stan::math::Phi;
    return Phi(x1);
  };
  std::vector<double> com_args = common_nonzero_args();
  std::vector<double> args{-27.5, 27.5};
  for (double x = -37.5; x <= 10; x += 0.5) {
    args.push_back(x);
  }
  auto all_args = vec_concat(com_args, args);
  // Map the test points straight into an Eigen vector; this also avoids the
  // signed/unsigned comparison of the original `int i < size()` copy loop.
  Eigen::VectorXd A
      = Eigen::Map<const Eigen::VectorXd>(all_args.data(), all_args.size());
  expect_ad_vector_matvar(f, A);
}
18 changes: 18 additions & 0 deletions test/unit/math/mix/fun/digamma_test.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -5,3 +5,21 @@ TEST(mathMixMatFun, digamma) {
stan::test::expect_common_nonzero_unary_vectorized(f);
stan::test::expect_unary_vectorized(f, -25, -10.2, -1.2, -1, 2.3, 5.7);
}

// Verify digamma works elementwise on var<Matrix> (varmat) arguments.
TEST(mathMixMatFun, digamma_varmat) {
  using stan::math::vec_concat;
  using stan::test::expect_ad_vector_matvar;
  using stan::test::internal::common_nonzero_args;
  auto f = [](const auto& x1) {
    using stan::math::digamma;
    return digamma(x1);
  };
  std::vector<double> com_args = common_nonzero_args();
  std::vector<double> args{-25, -10.2, -1.2, -1, 2.3, 5.7};
  auto all_args = vec_concat(com_args, args);
  // Map the test points straight into an Eigen vector; this also avoids the
  // signed/unsigned comparison of the original `int i < size()` copy loop.
  Eigen::VectorXd A
      = Eigen::Map<const Eigen::VectorXd>(all_args.data(), all_args.size());
  expect_ad_vector_matvar(f, A);
}
6 changes: 6 additions & 0 deletions test/unit/math/mix/fun/distance_test.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -8,31 +8,37 @@ TEST(MathMixMatFun, distance) {
Eigen::VectorXd x0(0);
Eigen::VectorXd y0(0);
stan::test::expect_ad(f, x0, y0);
stan::test::expect_ad_matvar(f, x0, y0);

// 1 x 1
Eigen::VectorXd x1(1);
x1 << 1;
Eigen::VectorXd y1(1);
y1 << -2.3;
stan::test::expect_ad(f, x1, y1);
stan::test::expect_ad_matvar(f, x1, y1);

// 2 x 2
Eigen::VectorXd x2(2);
x2 << 2, -3;
Eigen::VectorXd y2(2);
y2 << -2.3, 1.1;
stan::test::expect_ad(f, x2, y2);
stan::test::expect_ad_matvar(f, x2, y2);

// 3 x 3
Eigen::VectorXd x(3);
x << 1, 3, -5;
Eigen::VectorXd y(3);
y << 4, -2, -1;
stan::test::expect_ad(f, x, y);
stan::test::expect_ad_matvar(f, x, y);

// exception cases
Eigen::VectorXd z(2);
z << 1, 2;
stan::test::expect_ad(f, x, z);
stan::test::expect_ad(f, z, x);
stan::test::expect_ad_matvar(f, x, z);
stan::test::expect_ad_matvar(f, z, x);
}
Loading