Skip to content

Commit 3473789

Browse files
committed
Merge branch 'develop' into generalize_weibull
# Conflicts: # test/expressions/stan_math_sigs_exceptions.expected
2 parents 62408d2 + 304b817 commit 3473789

Some content is hidden

Large Commits have some content hidden by default. Use the searchbox below for content that may be hidden.

55 files changed

+494
-425
lines changed

stan/math/fwd/fun/unit_vector_constrain.hpp

+22-30
Original file line numberDiff line numberDiff line change
@@ -15,44 +15,36 @@
1515
namespace stan {
1616
namespace math {
1717

18-
template <typename T, int R, int C>
19-
inline Eigen::Matrix<fvar<T>, R, C> unit_vector_constrain(
20-
const Eigen::Matrix<fvar<T>, R, C>& y) {
21-
using Eigen::Matrix;
22-
using std::sqrt;
23-
24-
Matrix<T, R, C> y_t(y.size());
25-
for (int k = 0; k < y.size(); ++k) {
26-
y_t.coeffRef(k) = y.coeff(k).val_;
27-
}
28-
29-
Matrix<T, R, C> unit_vector_y_t = unit_vector_constrain(y_t);
30-
Matrix<fvar<T>, R, C> unit_vector_y(y.size());
31-
for (int k = 0; k < y.size(); ++k) {
32-
unit_vector_y.coeffRef(k).val_ = unit_vector_y_t.coeff(k);
33-
}
34-
35-
T squared_norm = dot_self(y_t);
36-
T norm = sqrt(squared_norm);
37-
T inv_norm = inv(norm);
38-
Matrix<T, Eigen::Dynamic, Eigen::Dynamic> J
39-
= divide(tcrossprod(y_t), -norm * squared_norm);
40-
41-
for (int m = 0; m < y.size(); ++m) {
18+
template <typename EigMat, require_eigen_vt<is_fvar, EigMat>* = nullptr>
19+
inline auto unit_vector_constrain(const EigMat& y) {
20+
using eig_partial = partials_type_t<value_type_t<EigMat>>;
21+
promote_scalar_t<eig_partial, EigMat> y_val(value_of(y));
22+
plain_type_t<EigMat> unit_vector_y(y_val.size());
23+
unit_vector_y.val() = unit_vector_constrain(y_val);
24+
25+
eig_partial squared_norm = dot_self(y_val);
26+
eig_partial norm = sqrt(squared_norm);
27+
eig_partial inv_norm = inv(norm);
28+
Eigen::Matrix<eig_partial, Eigen::Dynamic, Eigen::Dynamic> J
29+
= divide(tcrossprod(y_val), -norm * squared_norm);
30+
31+
for (Eigen::Index m = 0; m < y_val.size(); ++m) {
4232
J.coeffRef(m, m) += inv_norm;
43-
for (int k = 0; k < y.size(); ++k) {
33+
for (Eigen::Index k = 0; k < y_val.size(); ++k) {
4434
unit_vector_y.coeffRef(k).d_ = J.coeff(k, m);
4535
}
4636
}
4737
return unit_vector_y;
4838
}
4939

50-
template <typename T, int R, int C>
51-
inline Eigen::Matrix<fvar<T>, R, C> unit_vector_constrain(
52-
const Eigen::Matrix<fvar<T>, R, C>& y, fvar<T>& lp) {
53-
fvar<T> squared_norm = dot_self(y);
40+
template <typename EigMat, typename T,
41+
require_eigen_vt<is_fvar, EigMat>* = nullptr,
42+
require_stan_scalar_t<T>* = nullptr>
43+
inline auto unit_vector_constrain(const EigMat& y, T& lp) {
44+
const auto& y_ref = to_ref(y);
45+
const value_type_t<EigMat> squared_norm = dot_self(y_ref);
5446
lp -= 0.5 * squared_norm;
55-
return unit_vector_constrain(y);
47+
return unit_vector_constrain(y_ref);
5648
}
5749

5850
} // namespace math

stan/math/prim/err/check_cholesky_factor_corr.hpp

+13-12
Original file line numberDiff line numberDiff line change
@@ -1,8 +1,9 @@
11
#ifndef STAN_MATH_PRIM_ERR_CHECK_CHOLESKY_FACTOR_CORR_HPP
22
#define STAN_MATH_PRIM_ERR_CHECK_CHOLESKY_FACTOR_CORR_HPP
33

4-
#include <stan/math/prim/meta.hpp>
54
#include <stan/math/prim/fun/Eigen.hpp>
5+
#include <stan/math/prim/meta.hpp>
6+
#include <stan/math/prim/fun/to_ref.hpp>
67
#include <stan/math/prim/err/check_positive.hpp>
78
#include <stan/math/prim/err/check_lower_triangular.hpp>
89
#include <stan/math/prim/err/check_square.hpp>
@@ -19,24 +20,24 @@ namespace math {
1920
* be square, but require at least as many rows M as columns N
2021
* (i.e., M &gt;= N).
2122
* Tolerance is specified by <code>math::CONSTRAINT_TOLERANCE</code>.
22-
* @tparam T_y Type of elements of Cholesky factor
23+
* @tparam EigMat Type inheriting from `MatrixBase` with dynamic rows and
24+
* columns.
2325
* @param function Function name (for error messages)
2426
* @param name Variable name (for error messages)
2527
* @param y Matrix to test
2628
* @throw <code>std::domain_error</code> if y is not a valid Cholesky
2729
* factor, if number of rows is less than the number of columns,
2830
* if there are 0 columns, or if any element in matrix is NaN
2931
*/
30-
template <typename T_y>
31-
void check_cholesky_factor_corr(
32-
const char* function, const char* name,
33-
const Eigen::Matrix<T_y, Eigen::Dynamic, Eigen::Dynamic>& y) {
34-
check_square(function, name, y);
35-
check_lower_triangular(function, name, y);
36-
check_positive(function, name, y.diagonal());
37-
for (int i = 0; i < y.rows(); ++i) {
38-
Eigen::Matrix<T_y, Eigen::Dynamic, 1> y_i = y.row(i).transpose();
39-
check_unit_vector(function, name, y_i);
32+
template <typename EigMat, require_eigen_matrix_dynamic_t<EigMat>* = nullptr>
33+
void check_cholesky_factor_corr(const char* function, const char* name,
34+
const EigMat& y) {
35+
const auto& y_ref = to_ref(y);
36+
check_square(function, name, y_ref);
37+
check_lower_triangular(function, name, y_ref);
38+
check_positive(function, name, y_ref.diagonal());
39+
for (Eigen::Index i = 0; i < y_ref.rows(); ++i) {
40+
check_unit_vector(function, name, y_ref.row(i));
4041
}
4142
}
4243

stan/math/prim/err/check_corr_matrix.hpp

+14-15
Original file line numberDiff line numberDiff line change
@@ -1,11 +1,12 @@
11
#ifndef STAN_MATH_PRIM_ERR_CHECK_CORR_MATRIX_HPP
22
#define STAN_MATH_PRIM_ERR_CHECK_CORR_MATRIX_HPP
33

4+
#include <stan/math/prim/fun/Eigen.hpp>
45
#include <stan/math/prim/meta.hpp>
56
#include <stan/math/prim/err/throw_domain_error.hpp>
67
#include <stan/math/prim/err/check_pos_definite.hpp>
78
#include <stan/math/prim/err/check_square.hpp>
8-
#include <stan/math/prim/fun/Eigen.hpp>
9+
#include <stan/math/prim/fun/to_ref.hpp>
910
#include <sstream>
1011
#include <string>
1112
#include <cmath>
@@ -20,7 +21,8 @@ namespace math {
2021
* (inclusive).
2122
* This function throws exceptions if the variable is not a valid
2223
* correlation matrix.
23-
* @tparam T_y Type of scalar
24+
* @tparam EigMat Type inheriting from `MatrixBase` with dynamic rows and
25+
* columns.
2426
* @param function Name of the function this was called from
2527
* @param name Name of the variable
2628
* @param y Matrix to test
@@ -29,31 +31,28 @@ namespace math {
2931
* diagonals not near 1, not positive definite, or any of the
3032
* elements nan
3133
*/
32-
template <typename T_y>
33-
inline void check_corr_matrix(
34-
const char* function, const char* name,
35-
const Eigen::Matrix<T_y, Eigen::Dynamic, Eigen::Dynamic>& y) {
36-
using size_type
37-
= index_type_t<Eigen::Matrix<T_y, Eigen::Dynamic, Eigen::Dynamic>>;
38-
39-
check_square(function, name, y);
34+
template <typename EigMat, require_eigen_matrix_dynamic_t<EigMat>* = nullptr>
35+
inline void check_corr_matrix(const char* function, const char* name,
36+
const EigMat& y) {
37+
const auto& y_ref = to_ref(y);
38+
check_square(function, name, y_ref);
4039
using std::fabs;
41-
if (y.size() == 0) {
40+
if (y_ref.size() == 0) {
4241
return;
4342
}
4443

45-
for (size_type k = 0; k < y.rows(); ++k) {
46-
if (!(fabs(y(k, k) - 1.0) <= CONSTRAINT_TOLERANCE)) {
44+
for (Eigen::Index k = 0; k < y.rows(); ++k) {
45+
if (!(fabs(y_ref(k, k) - 1.0) <= CONSTRAINT_TOLERANCE)) {
4746
std::ostringstream msg;
4847
msg << "is not a valid correlation matrix. " << name << "("
4948
<< stan::error_index::value + k << "," << stan::error_index::value + k
5049
<< ") is ";
5150
std::string msg_str(msg.str());
52-
throw_domain_error(function, name, y(k, k), msg_str.c_str(),
51+
throw_domain_error(function, name, y_ref(k, k), msg_str.c_str(),
5352
", but should be near 1.0");
5453
}
5554
}
56-
check_pos_definite(function, "y", y);
55+
check_pos_definite(function, "y", y_ref);
5756
}
5857

5958
} // namespace math

stan/math/prim/err/check_cov_matrix.hpp

+6-6
Original file line numberDiff line numberDiff line change
@@ -1,8 +1,8 @@
11
#ifndef STAN_MATH_PRIM_ERR_CHECK_COV_MATRIX_HPP
22
#define STAN_MATH_PRIM_ERR_CHECK_COV_MATRIX_HPP
33

4-
#include <stan/math/prim/meta.hpp>
54
#include <stan/math/prim/fun/Eigen.hpp>
5+
#include <stan/math/prim/meta.hpp>
66
#include <stan/math/prim/err/check_pos_definite.hpp>
77

88
namespace stan {
@@ -11,7 +11,8 @@ namespace math {
1111
* Check if the specified matrix is a valid covariance matrix.
1212
* A valid covariance matrix is a square, symmetric matrix that is
1313
* positive definite.
14-
* @tparam T Type of scalar.
14+
* @tparam EigMat Type inheriting from `MatrixBase` with dynamic rows and
15+
* columns.
1516
* @param function Function name (for error messages)
1617
* @param name Variable name (for error messages)
1718
* @param y Matrix to test
@@ -21,10 +22,9 @@ namespace math {
2122
* if the matrix is not positive definite,
2223
* or if any element of the matrix is nan
2324
*/
24-
template <typename T_y>
25-
inline void check_cov_matrix(
26-
const char* function, const char* name,
27-
const Eigen::Matrix<T_y, Eigen::Dynamic, Eigen::Dynamic>& y) {
25+
template <typename EigMat, require_eigen_matrix_dynamic_t<EigMat>* = nullptr>
26+
inline void check_cov_matrix(const char* function, const char* name,
27+
const EigMat& y) {
2828
check_pos_definite(function, name, y);
2929
}
3030

stan/math/prim/err/check_greater.hpp

+6-4
Original file line numberDiff line numberDiff line change
@@ -6,7 +6,7 @@
66
#include <stan/math/prim/err/throw_domain_error_vec.hpp>
77
#include <stan/math/prim/fun/get.hpp>
88
#include <stan/math/prim/fun/size.hpp>
9-
#include <functional>
9+
#include <stan/math/prim/fun/to_ref.hpp>
1010
#include <string>
1111

1212
namespace stan {
@@ -35,13 +35,15 @@ struct greater<T_y, T_low, true> {
3535
static void check(const char* function, const char* name, const T_y& y,
3636
const T_low& low) {
3737
scalar_seq_view<T_low> low_vec(low);
38-
for (size_t n = 0; n < stan::math::size(y); n++) {
39-
if (!(stan::get(y, n) > low_vec[n])) {
38+
const auto& y_ref = to_ref(y);
39+
for (size_t n = 0; n < stan::math::size(y_ref); n++) {
40+
if (!(stan::get(y_ref, n) > low_vec[n])) {
4041
std::stringstream msg;
4142
msg << ", but must be greater than ";
4243
msg << low_vec[n];
4344
std::string msg_str(msg.str());
44-
throw_domain_error_vec(function, name, y, n, "is ", msg_str.c_str());
45+
throw_domain_error_vec(function, name, y_ref, n, "is ",
46+
msg_str.c_str());
4547
}
4648
}
4749
}

stan/math/prim/err/check_greater_or_equal.hpp

+6-3
Original file line numberDiff line numberDiff line change
@@ -6,6 +6,7 @@
66
#include <stan/math/prim/err/throw_domain_error.hpp>
77
#include <stan/math/prim/fun/get.hpp>
88
#include <stan/math/prim/fun/size.hpp>
9+
#include <stan/math/prim/fun/to_ref.hpp>
910
#include <string>
1011

1112
namespace stan {
@@ -34,13 +35,15 @@ struct greater_or_equal<T_y, T_low, true> {
3435
static void check(const char* function, const char* name, const T_y& y,
3536
const T_low& low) {
3637
scalar_seq_view<T_low> low_vec(low);
37-
for (size_t n = 0; n < stan::math::size(y); n++) {
38-
if (!(stan::get(y, n) >= low_vec[n])) {
38+
const auto& y_ref = to_ref(y);
39+
for (size_t n = 0; n < stan::math::size(y_ref); n++) {
40+
if (!(stan::get(y_ref, n) >= low_vec[n])) {
3941
std::stringstream msg;
4042
msg << ", but must be greater than or equal to ";
4143
msg << low_vec[n];
4244
std::string msg_str(msg.str());
43-
throw_domain_error_vec(function, name, y, n, "is ", msg_str.c_str());
45+
throw_domain_error_vec(function, name, y_ref, n, "is ",
46+
msg_str.c_str());
4447
}
4548
}
4649
}

stan/math/prim/err/check_less.hpp

+6-4
Original file line numberDiff line numberDiff line change
@@ -6,7 +6,7 @@
66
#include <stan/math/prim/err/throw_domain_error_vec.hpp>
77
#include <stan/math/prim/fun/get.hpp>
88
#include <stan/math/prim/fun/size.hpp>
9-
#include <functional>
9+
#include <stan/math/prim/fun/to_ref.hpp>
1010
#include <string>
1111

1212
namespace stan {
@@ -35,13 +35,15 @@ struct less<T_y, T_high, true> {
3535
static void check(const char* function, const char* name, const T_y& y,
3636
const T_high& high) {
3737
scalar_seq_view<T_high> high_vec(high);
38-
for (size_t n = 0; n < stan::math::size(y); n++) {
39-
if (!(stan::get(y, n) < high_vec[n])) {
38+
const auto& y_ref = to_ref(y);
39+
for (size_t n = 0; n < stan::math::size(y_ref); n++) {
40+
if (!(stan::get(y_ref, n) < high_vec[n])) {
4041
std::stringstream msg;
4142
msg << ", but must be less than ";
4243
msg << high_vec[n];
4344
std::string msg_str(msg.str());
44-
throw_domain_error_vec(function, name, y, n, "is ", msg_str.c_str());
45+
throw_domain_error_vec(function, name, y_ref, n, "is ",
46+
msg_str.c_str());
4547
}
4648
}
4749
}

stan/math/prim/err/check_less_or_equal.hpp

+6-3
Original file line numberDiff line numberDiff line change
@@ -6,6 +6,7 @@
66
#include <stan/math/prim/err/throw_domain_error_vec.hpp>
77
#include <stan/math/prim/fun/get.hpp>
88
#include <stan/math/prim/fun/size.hpp>
9+
#include <stan/math/prim/fun/to_ref.hpp>
910
#include <string>
1011

1112
namespace stan {
@@ -34,13 +35,15 @@ struct less_or_equal<T_y, T_high, true> {
3435
static void check(const char* function, const char* name, const T_y& y,
3536
const T_high& high) {
3637
scalar_seq_view<T_high> high_vec(high);
37-
for (size_t n = 0; n < stan::math::size(y); n++) {
38-
if (!(stan::get(y, n) <= high_vec[n])) {
38+
const auto& y_ref = to_ref(y);
39+
for (size_t n = 0; n < stan::math::size(y_ref); n++) {
40+
if (!(stan::get(y_ref, n) <= high_vec[n])) {
3941
std::stringstream msg;
4042
msg << ", but must be less than or equal to ";
4143
msg << high_vec[n];
4244
std::string msg_str(msg.str());
43-
throw_domain_error_vec(function, name, y, n, "is ", msg_str.c_str());
45+
throw_domain_error_vec(function, name, y_ref, n, "is ",
46+
msg_str.c_str());
4447
}
4548
}
4649
}

stan/math/prim/err/check_lower_triangular.hpp

+1-1
Original file line numberDiff line numberDiff line change
@@ -26,7 +26,7 @@ namespace math {
2626
template <typename T_y, require_eigen_t<T_y>* = nullptr>
2727
inline void check_lower_triangular(const char* function, const char* name,
2828
const T_y& y) {
29-
const Eigen::Ref<const plain_type_t<T_y>>& y_ref = y;
29+
const auto& y_ref = to_ref(y);
3030
for (int n = 1; n < y.cols(); ++n) {
3131
for (int m = 0; m < n && m < y.rows(); ++m) {
3232
if (y_ref(m, n) != 0) {

stan/math/prim/err/check_ordered.hpp

+8-5
Original file line numberDiff line numberDiff line change
@@ -1,9 +1,10 @@
11
#ifndef STAN_MATH_PRIM_ERR_CHECK_ORDERED_HPP
22
#define STAN_MATH_PRIM_ERR_CHECK_ORDERED_HPP
33

4+
#include <stan/math/prim/fun/Eigen.hpp>
45
#include <stan/math/prim/meta.hpp>
56
#include <stan/math/prim/err/throw_domain_error.hpp>
6-
#include <stan/math/prim/fun/Eigen.hpp>
7+
#include <stan/math/prim/fun/to_ref.hpp>
78
#include <sstream>
89
#include <string>
910
#include <vector>
@@ -23,16 +24,18 @@ namespace math {
2324
*/
2425
template <typename T_y, require_eigen_vector_t<T_y>* = nullptr>
2526
void check_ordered(const char* function, const char* name, const T_y& y) {
26-
for (Eigen::Index n = 1; n < y.size(); n++) {
27-
if (!(y[n] > y[n - 1])) {
27+
const auto& y_ref = to_ref(y);
28+
for (Eigen::Index n = 1; n < y_ref.size(); n++) {
29+
if (!(y_ref[n] > y_ref[n - 1])) {
2830
std::ostringstream msg1;
2931
msg1 << "is not a valid ordered vector."
3032
<< " The element at " << stan::error_index::value + n << " is ";
3133
std::string msg1_str(msg1.str());
3234
std::ostringstream msg2;
33-
msg2 << ", but should be greater than the previous element, " << y[n - 1];
35+
msg2 << ", but should be greater than the previous element, "
36+
<< y_ref[n - 1];
3437
std::string msg2_str(msg2.str());
35-
throw_domain_error(function, name, y[n], msg1_str.c_str(),
38+
throw_domain_error(function, name, y_ref[n], msg1_str.c_str(),
3639
msg2_str.c_str());
3740
}
3841
}

0 commit comments

Comments
 (0)