Skip to content

Commit effd263

Browse files
authored
Merge pull request #2255 from stan-dev/feature/cold-path-range-size-errors
use immediately invoked lambdas in size and range error checks
2 parents d932ecf + f9f0dc0 commit effd263

28 files changed

+335
-282
lines changed

stan/math/prim/err/check_column_index.hpp

+8-8
Original file line numberDiff line numberDiff line change
@@ -27,15 +27,15 @@ namespace math {
2727
template <typename T_y, typename = require_eigen_t<T_y>>
2828
inline void check_column_index(const char* function, const char* name,
2929
const T_y& y, size_t i) {
30-
if (i >= stan::error_index::value
31-
&& i < static_cast<size_t>(y.cols()) + stan::error_index::value) {
32-
return;
30+
if (!(i >= stan::error_index::value
31+
&& i < static_cast<size_t>(y.cols()) + stan::error_index::value)) {
32+
[&]() STAN_COLD_PATH {
33+
std::stringstream msg;
34+
msg << " for columns of " << name;
35+
std::string msg_str(msg.str());
36+
out_of_range(function, y.cols(), i, msg_str.c_str());
37+
}();
3338
}
34-
35-
std::stringstream msg;
36-
msg << " for columns of " << name;
37-
std::string msg_str(msg.str());
38-
out_of_range(function, y.cols(), i, msg_str.c_str());
3939
}
4040

4141
} // namespace math

stan/math/prim/err/check_consistent_size.hpp

+14-14
Original file line numberDiff line numberDiff line change
@@ -24,21 +24,21 @@ namespace math {
2424
template <typename T>
2525
inline void check_consistent_size(const char* function, const char* name,
2626
const T& x, size_t expected_size) {
27-
if (!is_vector<T>::value
28-
|| (is_vector<T>::value && expected_size == stan::math::size(x))) {
29-
return;
30-
}
31-
32-
std::stringstream msg;
33-
msg << ", expecting dimension = " << expected_size
34-
<< "; a function was called with arguments of different "
35-
<< "scalar, array, vector, or matrix types, and they were not "
36-
<< "consistently sized; all arguments must be scalars or "
37-
<< "multidimensional values of the same shape.";
38-
std::string msg_str(msg.str());
27+
if (!(!is_vector<T>::value
28+
|| (is_vector<T>::value && expected_size == stan::math::size(x)))) {
29+
[&]() STAN_COLD_PATH {
30+
std::stringstream msg;
31+
msg << ", expecting dimension = " << expected_size
32+
<< "; a function was called with arguments of different "
33+
<< "scalar, array, vector, or matrix types, and they were not "
34+
<< "consistently sized; all arguments must be scalars or "
35+
<< "multidimensional values of the same shape.";
36+
std::string msg_str(msg.str());
3937

40-
invalid_argument(function, name, stan::math::size(x),
41-
"has dimension = ", msg_str.c_str());
38+
invalid_argument(function, name, stan::math::size(x),
39+
"has dimension = ", msg_str.c_str());
40+
}();
41+
}
4242
}
4343

4444
} // namespace math

stan/math/prim/err/check_consistent_sizes.hpp

+10-7
Original file line numberDiff line numberDiff line change
@@ -53,13 +53,16 @@ inline void check_consistent_sizes(const char* function, const char* name1,
5353
} else if (stan::math::size(x1) == stan::math::size(x2)) {
5454
check_consistent_sizes(function, name1, x1, names_and_xs...);
5555
} else {
56-
size_t size_x1 = stan::math::size(x1);
57-
size_t size_x2 = stan::math::size(x2);
58-
std::stringstream msg;
59-
msg << ", but " << name2 << " has size " << size_x2
60-
<< "; and they must be the same size.";
61-
std::string msg_str(msg.str());
62-
invalid_argument(function, name1, size_x1, "has size = ", msg_str.c_str());
56+
[&]() STAN_COLD_PATH {
57+
size_t size_x1 = stan::math::size(x1);
58+
size_t size_x2 = stan::math::size(x2);
59+
std::stringstream msg;
60+
msg << ", but " << name2 << " has size " << size_x2
61+
<< "; and they must be the same size.";
62+
std::string msg_str(msg.str());
63+
invalid_argument(function, name1, size_x1,
64+
"has size = ", msg_str.c_str());
65+
}();
6366
}
6467
}
6568

stan/math/prim/err/check_consistent_sizes_mvt.hpp

+10-7
Original file line numberDiff line numberDiff line change
@@ -54,13 +54,16 @@ inline void check_consistent_sizes_mvt(const char* function, const char* name1,
5454
} else if (stan::math::size(x1) == stan::math::size(x2)) {
5555
check_consistent_sizes_mvt(function, name1, x1, names_and_xs...);
5656
} else {
57-
size_t size_x1 = stan::math::size(x1);
58-
size_t size_x2 = stan::math::size(x2);
59-
std::stringstream msg;
60-
msg << ", but " << name2 << " has size " << size_x2
61-
<< "; and they must be the same size.";
62-
std::string msg_str(msg.str());
63-
invalid_argument(function, name1, size_x1, "has size = ", msg_str.c_str());
57+
[&]() STAN_COLD_PATH {
58+
size_t size_x1 = stan::math::size(x1);
59+
size_t size_x2 = stan::math::size(x2);
60+
std::stringstream msg;
61+
msg << ", but " << name2 << " has size " << size_x2
62+
<< "; and they must be the same size.";
63+
std::string msg_str(msg.str());
64+
invalid_argument(function, name1, size_x1,
65+
"has size = ", msg_str.c_str());
66+
}();
6467
}
6568
}
6669

stan/math/prim/err/check_greater.hpp

+15-11
Original file line numberDiff line numberDiff line change
@@ -20,11 +20,13 @@ struct greater {
2020
scalar_seq_view<T_low> low_vec(low);
2121
for (size_t n = 0; n < stan::math::size(low); n++) {
2222
if (!(y > low_vec[n])) {
23-
std::stringstream msg;
24-
msg << ", but must be greater than ";
25-
msg << low_vec[n];
26-
std::string msg_str(msg.str());
27-
throw_domain_error(function, name, y, "is ", msg_str.c_str());
23+
[&]() STAN_COLD_PATH {
24+
std::stringstream msg;
25+
msg << ", but must be greater than ";
26+
msg << low_vec[n];
27+
std::string msg_str(msg.str());
28+
throw_domain_error(function, name, y, "is ", msg_str.c_str());
29+
}();
2830
}
2931
}
3032
}
@@ -38,12 +40,14 @@ struct greater<T_y, T_low, true> {
3840
const auto& y_ref = to_ref(y);
3941
for (size_t n = 0; n < stan::math::size(y_ref); n++) {
4042
if (!(stan::get(y_ref, n) > low_vec[n])) {
41-
std::stringstream msg;
42-
msg << ", but must be greater than ";
43-
msg << low_vec[n];
44-
std::string msg_str(msg.str());
45-
throw_domain_error_vec(function, name, y_ref, n, "is ",
46-
msg_str.c_str());
43+
[&]() STAN_COLD_PATH {
44+
std::stringstream msg;
45+
msg << ", but must be greater than ";
46+
msg << low_vec[n];
47+
std::string msg_str(msg.str());
48+
throw_domain_error_vec(function, name, y_ref, n, "is ",
49+
msg_str.c_str());
50+
}();
4751
}
4852
}
4953
}

stan/math/prim/err/check_greater_or_equal.hpp

+15-11
Original file line numberDiff line numberDiff line change
@@ -20,11 +20,13 @@ struct greater_or_equal {
2020
scalar_seq_view<T_low> low_vec(low);
2121
for (size_t n = 0; n < stan::math::size(low); n++) {
2222
if (!(y >= low_vec[n])) {
23-
std::stringstream msg;
24-
msg << ", but must be greater than or equal to ";
25-
msg << low_vec[n];
26-
std::string msg_str(msg.str());
27-
throw_domain_error(function, name, y, "is ", msg_str.c_str());
23+
[&]() STAN_COLD_PATH {
24+
std::stringstream msg;
25+
msg << ", but must be greater than or equal to ";
26+
msg << low_vec[n];
27+
std::string msg_str(msg.str());
28+
throw_domain_error(function, name, y, "is ", msg_str.c_str());
29+
}();
2830
}
2931
}
3032
}
@@ -38,12 +40,14 @@ struct greater_or_equal<T_y, T_low, true> {
3840
const auto& y_ref = to_ref(y);
3941
for (size_t n = 0; n < stan::math::size(y_ref); n++) {
4042
if (!(stan::get(y_ref, n) >= low_vec[n])) {
41-
std::stringstream msg;
42-
msg << ", but must be greater than or equal to ";
43-
msg << low_vec[n];
44-
std::string msg_str(msg.str());
45-
throw_domain_error_vec(function, name, y_ref, n, "is ",
46-
msg_str.c_str());
43+
[&]() STAN_COLD_PATH {
44+
std::stringstream msg;
45+
msg << ", but must be greater than or equal to ";
46+
msg << low_vec[n];
47+
std::string msg_str(msg.str());
48+
throw_domain_error_vec(function, name, y_ref, n, "is ",
49+
msg_str.c_str());
50+
}();
4751
}
4852
}
4953
}

stan/math/prim/err/check_less.hpp

+15-11
Original file line numberDiff line numberDiff line change
@@ -20,11 +20,13 @@ struct less {
2020
scalar_seq_view<T_high> high_vec(high);
2121
for (size_t n = 0; n < stan::math::size(high); n++) {
2222
if (!(y < high_vec[n])) {
23-
std::stringstream msg;
24-
msg << ", but must be less than ";
25-
msg << high_vec[n];
26-
std::string msg_str(msg.str());
27-
throw_domain_error(function, name, y, "is ", msg_str.c_str());
23+
[&]() STAN_COLD_PATH {
24+
std::stringstream msg;
25+
msg << ", but must be less than ";
26+
msg << high_vec[n];
27+
std::string msg_str(msg.str());
28+
throw_domain_error(function, name, y, "is ", msg_str.c_str());
29+
}();
2830
}
2931
}
3032
}
@@ -38,12 +40,14 @@ struct less<T_y, T_high, true> {
3840
const auto& y_ref = to_ref(y);
3941
for (size_t n = 0; n < stan::math::size(y_ref); n++) {
4042
if (!(stan::get(y_ref, n) < high_vec[n])) {
41-
std::stringstream msg;
42-
msg << ", but must be less than ";
43-
msg << high_vec[n];
44-
std::string msg_str(msg.str());
45-
throw_domain_error_vec(function, name, y_ref, n, "is ",
46-
msg_str.c_str());
43+
[&]() STAN_COLD_PATH {
44+
std::stringstream msg;
45+
msg << ", but must be less than ";
46+
msg << high_vec[n];
47+
std::string msg_str(msg.str());
48+
throw_domain_error_vec(function, name, y_ref, n, "is ",
49+
msg_str.c_str());
50+
}();
4751
}
4852
}
4953
}

stan/math/prim/err/check_less_or_equal.hpp

+15-11
Original file line numberDiff line numberDiff line change
@@ -20,11 +20,13 @@ struct less_or_equal {
2020
scalar_seq_view<T_high> high_vec(high);
2121
for (size_t n = 0; n < stan::math::size(high); n++) {
2222
if (!(y <= high_vec[n])) {
23-
std::stringstream msg;
24-
msg << ", but must be less than or equal to ";
25-
msg << high_vec[n];
26-
std::string msg_str(msg.str());
27-
throw_domain_error(function, name, y, "is ", msg_str.c_str());
23+
[&]() STAN_COLD_PATH {
24+
std::stringstream msg;
25+
msg << ", but must be less than or equal to ";
26+
msg << high_vec[n];
27+
std::string msg_str(msg.str());
28+
throw_domain_error(function, name, y, "is ", msg_str.c_str());
29+
}();
2830
}
2931
}
3032
}
@@ -38,12 +40,14 @@ struct less_or_equal<T_y, T_high, true> {
3840
const auto& y_ref = to_ref(y);
3941
for (size_t n = 0; n < stan::math::size(y_ref); n++) {
4042
if (!(stan::get(y_ref, n) <= high_vec[n])) {
41-
std::stringstream msg;
42-
msg << ", but must be less than or equal to ";
43-
msg << high_vec[n];
44-
std::string msg_str(msg.str());
45-
throw_domain_error_vec(function, name, y_ref, n, "is ",
46-
msg_str.c_str());
43+
[&]() STAN_COLD_PATH {
44+
std::stringstream msg;
45+
msg << ", but must be less than or equal to ";
46+
msg << high_vec[n];
47+
std::string msg_str(msg.str());
48+
throw_domain_error_vec(function, name, y_ref, n, "is ",
49+
msg_str.c_str());
50+
}();
4751
}
4852
}
4953
}

stan/math/prim/err/check_matching_dims.hpp

+25-22
Original file line numberDiff line numberDiff line change
@@ -29,17 +29,7 @@ inline void check_matching_dims(const char* function, const char* name1,
2929
std::vector<int> y1_d = dims(y1);
3030
std::vector<int> y2_d = dims(y2);
3131
bool error = false;
32-
if (y1_d.size() != y2_d.size()) {
33-
error = true;
34-
} else {
35-
for (int i = 0; i < y1_d.size(); i++) {
36-
if (y1_d[i] != y2_d[i]) {
37-
error = true;
38-
break;
39-
}
40-
}
41-
}
42-
if (error) {
32+
auto error_throw = [&]() STAN_COLD_PATH {
4333
std::ostringstream y1s;
4434
if (y1_d.size() > 0) {
4535
y1s << y1_d[0];
@@ -58,6 +48,15 @@ inline void check_matching_dims(const char* function, const char* name1,
5848
msg << ") must match in size";
5949
std::string msg_str(msg.str());
6050
invalid_argument(function, name1, y1s.str(), "(", msg_str.c_str());
51+
};
52+
if (y1_d.size() != y2_d.size()) {
53+
error_throw();
54+
} else {
55+
for (int i = 0; i < y1_d.size(); i++) {
56+
if (y1_d[i] != y2_d[i]) {
57+
error_throw();
58+
}
59+
}
6160
}
6261
}
6362

@@ -77,12 +76,14 @@ template <typename T1, typename T2, require_all_matrix_t<T1, T2>* = nullptr>
7776
inline void check_matching_dims(const char* function, const char* name1,
7877
const T1& y1, const char* name2, const T2& y2) {
7978
if (y1.rows() != y2.rows() || y1.cols() != y2.cols()) {
80-
std::ostringstream y1_err;
81-
std::ostringstream msg_str;
82-
y1_err << "(" << y1.rows() << ", " << y1.cols() << ")";
83-
msg_str << y2.rows() << ", " << y2.cols() << ") must match in size";
84-
invalid_argument(function, name1, y1_err.str(), "(",
85-
std::string(msg_str.str()).c_str());
79+
[&]() STAN_COLD_PATH {
80+
std::ostringstream y1_err;
81+
std::ostringstream msg_str;
82+
y1_err << "(" << y1.rows() << ", " << y1.cols() << ")";
83+
msg_str << y2.rows() << ", " << y2.cols() << ") must match in size";
84+
invalid_argument(function, name1, y1_err.str(), "(",
85+
std::string(msg_str.str()).c_str());
86+
}();
8687
}
8788
}
8889

@@ -135,11 +136,13 @@ inline void check_matching_dims(const char* function, const char* name1,
135136
!= static_cast<int>(Mat2::RowsAtCompileTime)
136137
|| static_cast<int>(Mat1::ColsAtCompileTime)
137138
!= static_cast<int>(Mat2::ColsAtCompileTime))) {
138-
std::ostringstream msg;
139-
msg << "Static rows and cols of " << name1 << " and " << name2
140-
<< " must match in size.";
141-
std::string msg_str(msg.str());
142-
invalid_argument(function, msg_str.c_str(), "", "");
139+
[&]() STAN_COLD_PATH {
140+
std::ostringstream msg;
141+
msg << "Static rows and cols of " << name1 << " and " << name2
142+
<< " must match in size.";
143+
std::string msg_str(msg.str());
144+
invalid_argument(function, msg_str.c_str(), "", "");
145+
}();
143146
}
144147
check_matching_dims(function, name1, y1, name2, y2);
145148
}

stan/math/prim/err/check_nonzero_size.hpp

+5-4
Original file line numberDiff line numberDiff line change
@@ -21,11 +21,12 @@ namespace math {
2121
template <typename T_y>
2222
inline void check_nonzero_size(const char* function, const char* name,
2323
const T_y& y) {
24-
if (y.size() > 0) {
25-
return;
24+
if (y.size() == 0) {
25+
[&]() STAN_COLD_PATH {
26+
invalid_argument(function, name, 0, "has size ",
27+
", but must have a non-zero size");
28+
}();
2629
}
27-
invalid_argument(function, name, 0, "has size ",
28-
", but must have a non-zero size");
2930
}
3031

3132
} // namespace math

0 commit comments

Comments
 (0)