Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

use immediately invoked lambdas in size and range error checks #2255

Merged
merged 10 commits into from
Dec 21, 2020
16 changes: 8 additions & 8 deletions stan/math/prim/err/check_column_index.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -27,15 +27,15 @@ namespace math {
template <typename T_y, typename = require_eigen_t<T_y>>
inline void check_column_index(const char* function, const char* name,
const T_y& y, size_t i) {
if (i >= stan::error_index::value
&& i < static_cast<size_t>(y.cols()) + stan::error_index::value) {
return;
if (!(i >= stan::error_index::value
&& i < static_cast<size_t>(y.cols()) + stan::error_index::value)) {
[&]() STAN_COLD_PATH {
std::stringstream msg;
msg << " for columns of " << name;
std::string msg_str(msg.str());
out_of_range(function, y.cols(), i, msg_str.c_str());
}();
}

std::stringstream msg;
msg << " for columns of " << name;
std::string msg_str(msg.str());
out_of_range(function, y.cols(), i, msg_str.c_str());
}

} // namespace math
Expand Down
28 changes: 14 additions & 14 deletions stan/math/prim/err/check_consistent_size.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -24,21 +24,21 @@ namespace math {
template <typename T>
inline void check_consistent_size(const char* function, const char* name,
const T& x, size_t expected_size) {
if (!is_vector<T>::value
|| (is_vector<T>::value && expected_size == stan::math::size(x))) {
return;
}

std::stringstream msg;
msg << ", expecting dimension = " << expected_size
<< "; a function was called with arguments of different "
<< "scalar, array, vector, or matrix types, and they were not "
<< "consistently sized; all arguments must be scalars or "
<< "multidimensional values of the same shape.";
std::string msg_str(msg.str());
if (!(!is_vector<T>::value
|| (is_vector<T>::value && expected_size == stan::math::size(x)))) {
[&]() STAN_COLD_PATH {
std::stringstream msg;
msg << ", expecting dimension = " << expected_size
<< "; a function was called with arguments of different "
<< "scalar, array, vector, or matrix types, and they were not "
<< "consistently sized; all arguments must be scalars or "
<< "multidimensional values of the same shape.";
std::string msg_str(msg.str());

invalid_argument(function, name, stan::math::size(x),
"has dimension = ", msg_str.c_str());
invalid_argument(function, name, stan::math::size(x),
"has dimension = ", msg_str.c_str());
}();
}
}

} // namespace math
Expand Down
17 changes: 10 additions & 7 deletions stan/math/prim/err/check_consistent_sizes.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -53,13 +53,16 @@ inline void check_consistent_sizes(const char* function, const char* name1,
} else if (stan::math::size(x1) == stan::math::size(x2)) {
check_consistent_sizes(function, name1, x1, names_and_xs...);
} else {
size_t size_x1 = stan::math::size(x1);
size_t size_x2 = stan::math::size(x2);
std::stringstream msg;
msg << ", but " << name2 << " has size " << size_x2
<< "; and they must be the same size.";
std::string msg_str(msg.str());
invalid_argument(function, name1, size_x1, "has size = ", msg_str.c_str());
[&]() STAN_COLD_PATH {
size_t size_x1 = stan::math::size(x1);
size_t size_x2 = stan::math::size(x2);
std::stringstream msg;
msg << ", but " << name2 << " has size " << size_x2
<< "; and they must be the same size.";
std::string msg_str(msg.str());
invalid_argument(function, name1, size_x1,
"has size = ", msg_str.c_str());
}();
}
}

Expand Down
17 changes: 10 additions & 7 deletions stan/math/prim/err/check_consistent_sizes_mvt.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -54,13 +54,16 @@ inline void check_consistent_sizes_mvt(const char* function, const char* name1,
} else if (stan::math::size(x1) == stan::math::size(x2)) {
check_consistent_sizes_mvt(function, name1, x1, names_and_xs...);
} else {
size_t size_x1 = stan::math::size(x1);
size_t size_x2 = stan::math::size(x2);
std::stringstream msg;
msg << ", but " << name2 << " has size " << size_x2
<< "; and they must be the same size.";
std::string msg_str(msg.str());
invalid_argument(function, name1, size_x1, "has size = ", msg_str.c_str());
[&]() STAN_COLD_PATH {
size_t size_x1 = stan::math::size(x1);
size_t size_x2 = stan::math::size(x2);
std::stringstream msg;
msg << ", but " << name2 << " has size " << size_x2
<< "; and they must be the same size.";
std::string msg_str(msg.str());
invalid_argument(function, name1, size_x1,
"has size = ", msg_str.c_str());
}();
}
}

Expand Down
26 changes: 15 additions & 11 deletions stan/math/prim/err/check_greater.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -20,11 +20,13 @@ struct greater {
scalar_seq_view<T_low> low_vec(low);
for (size_t n = 0; n < stan::math::size(low); n++) {
if (!(y > low_vec[n])) {
std::stringstream msg;
msg << ", but must be greater than ";
msg << low_vec[n];
std::string msg_str(msg.str());
throw_domain_error(function, name, y, "is ", msg_str.c_str());
[&]() STAN_COLD_PATH {
std::stringstream msg;
msg << ", but must be greater than ";
msg << low_vec[n];
std::string msg_str(msg.str());
throw_domain_error(function, name, y, "is ", msg_str.c_str());
}();
}
}
}
Expand All @@ -38,12 +40,14 @@ struct greater<T_y, T_low, true> {
const auto& y_ref = to_ref(y);
for (size_t n = 0; n < stan::math::size(y_ref); n++) {
if (!(stan::get(y_ref, n) > low_vec[n])) {
std::stringstream msg;
msg << ", but must be greater than ";
msg << low_vec[n];
std::string msg_str(msg.str());
throw_domain_error_vec(function, name, y_ref, n, "is ",
msg_str.c_str());
[&]() STAN_COLD_PATH {
std::stringstream msg;
msg << ", but must be greater than ";
msg << low_vec[n];
std::string msg_str(msg.str());
throw_domain_error_vec(function, name, y_ref, n, "is ",
msg_str.c_str());
}();
}
}
}
Expand Down
26 changes: 15 additions & 11 deletions stan/math/prim/err/check_greater_or_equal.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -20,11 +20,13 @@ struct greater_or_equal {
scalar_seq_view<T_low> low_vec(low);
for (size_t n = 0; n < stan::math::size(low); n++) {
if (!(y >= low_vec[n])) {
std::stringstream msg;
msg << ", but must be greater than or equal to ";
msg << low_vec[n];
std::string msg_str(msg.str());
throw_domain_error(function, name, y, "is ", msg_str.c_str());
[&]() STAN_COLD_PATH {
std::stringstream msg;
msg << ", but must be greater than or equal to ";
msg << low_vec[n];
std::string msg_str(msg.str());
throw_domain_error(function, name, y, "is ", msg_str.c_str());
}();
}
}
}
Expand All @@ -38,12 +40,14 @@ struct greater_or_equal<T_y, T_low, true> {
const auto& y_ref = to_ref(y);
for (size_t n = 0; n < stan::math::size(y_ref); n++) {
if (!(stan::get(y_ref, n) >= low_vec[n])) {
std::stringstream msg;
msg << ", but must be greater than or equal to ";
msg << low_vec[n];
std::string msg_str(msg.str());
throw_domain_error_vec(function, name, y_ref, n, "is ",
msg_str.c_str());
[&]() STAN_COLD_PATH {
std::stringstream msg;
msg << ", but must be greater than or equal to ";
msg << low_vec[n];
std::string msg_str(msg.str());
throw_domain_error_vec(function, name, y_ref, n, "is ",
msg_str.c_str());
}();
}
}
}
Expand Down
26 changes: 15 additions & 11 deletions stan/math/prim/err/check_less.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -20,11 +20,13 @@ struct less {
scalar_seq_view<T_high> high_vec(high);
for (size_t n = 0; n < stan::math::size(high); n++) {
if (!(y < high_vec[n])) {
std::stringstream msg;
msg << ", but must be less than ";
msg << high_vec[n];
std::string msg_str(msg.str());
throw_domain_error(function, name, y, "is ", msg_str.c_str());
[&]() STAN_COLD_PATH {
std::stringstream msg;
msg << ", but must be less than ";
msg << high_vec[n];
std::string msg_str(msg.str());
throw_domain_error(function, name, y, "is ", msg_str.c_str());
}();
}
}
}
Expand All @@ -38,12 +40,14 @@ struct less<T_y, T_high, true> {
const auto& y_ref = to_ref(y);
for (size_t n = 0; n < stan::math::size(y_ref); n++) {
if (!(stan::get(y_ref, n) < high_vec[n])) {
std::stringstream msg;
msg << ", but must be less than ";
msg << high_vec[n];
std::string msg_str(msg.str());
throw_domain_error_vec(function, name, y_ref, n, "is ",
msg_str.c_str());
[&]() STAN_COLD_PATH {
std::stringstream msg;
msg << ", but must be less than ";
msg << high_vec[n];
std::string msg_str(msg.str());
throw_domain_error_vec(function, name, y_ref, n, "is ",
msg_str.c_str());
}();
}
}
}
Expand Down
26 changes: 15 additions & 11 deletions stan/math/prim/err/check_less_or_equal.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -20,11 +20,13 @@ struct less_or_equal {
scalar_seq_view<T_high> high_vec(high);
for (size_t n = 0; n < stan::math::size(high); n++) {
if (!(y <= high_vec[n])) {
std::stringstream msg;
msg << ", but must be less than or equal to ";
msg << high_vec[n];
std::string msg_str(msg.str());
throw_domain_error(function, name, y, "is ", msg_str.c_str());
[&]() STAN_COLD_PATH {
std::stringstream msg;
msg << ", but must be less than or equal to ";
msg << high_vec[n];
std::string msg_str(msg.str());
throw_domain_error(function, name, y, "is ", msg_str.c_str());
}();
}
}
}
Expand All @@ -38,12 +40,14 @@ struct less_or_equal<T_y, T_high, true> {
const auto& y_ref = to_ref(y);
for (size_t n = 0; n < stan::math::size(y_ref); n++) {
if (!(stan::get(y_ref, n) <= high_vec[n])) {
std::stringstream msg;
msg << ", but must be less than or equal to ";
msg << high_vec[n];
std::string msg_str(msg.str());
throw_domain_error_vec(function, name, y_ref, n, "is ",
msg_str.c_str());
[&]() STAN_COLD_PATH {
std::stringstream msg;
msg << ", but must be less than or equal to ";
msg << high_vec[n];
std::string msg_str(msg.str());
throw_domain_error_vec(function, name, y_ref, n, "is ",
msg_str.c_str());
}();
}
}
}
Expand Down
47 changes: 25 additions & 22 deletions stan/math/prim/err/check_matching_dims.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -29,17 +29,7 @@ inline void check_matching_dims(const char* function, const char* name1,
std::vector<int> y1_d = dims(y1);
std::vector<int> y2_d = dims(y2);
bool error = false;
if (y1_d.size() != y2_d.size()) {
error = true;
} else {
for (int i = 0; i < y1_d.size(); i++) {
if (y1_d[i] != y2_d[i]) {
error = true;
break;
}
}
}
if (error) {
auto error_throw = [&]() STAN_COLD_PATH {
std::ostringstream y1s;
if (y1_d.size() > 0) {
y1s << y1_d[0];
Expand All @@ -58,6 +48,15 @@ inline void check_matching_dims(const char* function, const char* name1,
msg << ") must match in size";
std::string msg_str(msg.str());
invalid_argument(function, name1, y1s.str(), "(", msg_str.c_str());
};
if (y1_d.size() != y2_d.size()) {
error_throw();
} else {
for (int i = 0; i < y1_d.size(); i++) {
if (y1_d[i] != y2_d[i]) {
error_throw();
}
}
}
}

Expand All @@ -77,12 +76,14 @@ template <typename T1, typename T2, require_all_matrix_t<T1, T2>* = nullptr>
inline void check_matching_dims(const char* function, const char* name1,
const T1& y1, const char* name2, const T2& y2) {
if (y1.rows() != y2.rows() || y1.cols() != y2.cols()) {
std::ostringstream y1_err;
std::ostringstream msg_str;
y1_err << "(" << y1.rows() << ", " << y1.cols() << ")";
msg_str << y2.rows() << ", " << y2.cols() << ") must match in size";
invalid_argument(function, name1, y1_err.str(), "(",
std::string(msg_str.str()).c_str());
[&]() STAN_COLD_PATH {
std::ostringstream y1_err;
std::ostringstream msg_str;
y1_err << "(" << y1.rows() << ", " << y1.cols() << ")";
msg_str << y2.rows() << ", " << y2.cols() << ") must match in size";
invalid_argument(function, name1, y1_err.str(), "(",
std::string(msg_str.str()).c_str());
}();
}
}

Expand Down Expand Up @@ -135,11 +136,13 @@ inline void check_matching_dims(const char* function, const char* name1,
!= static_cast<int>(Mat2::RowsAtCompileTime)
|| static_cast<int>(Mat1::ColsAtCompileTime)
!= static_cast<int>(Mat2::ColsAtCompileTime))) {
std::ostringstream msg;
msg << "Static rows and cols of " << name1 << " and " << name2
<< " must match in size.";
std::string msg_str(msg.str());
invalid_argument(function, msg_str.c_str(), "", "");
[&]() STAN_COLD_PATH {
std::ostringstream msg;
msg << "Static rows and cols of " << name1 << " and " << name2
<< " must match in size.";
std::string msg_str(msg.str());
invalid_argument(function, msg_str.c_str(), "", "");
}();
}
check_matching_dims(function, name1, y1, name2, y2);
}
Expand Down
9 changes: 5 additions & 4 deletions stan/math/prim/err/check_nonzero_size.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -21,11 +21,12 @@ namespace math {
template <typename T_y>
inline void check_nonzero_size(const char* function, const char* name,
const T_y& y) {
if (y.size() > 0) {
return;
if (y.size() == 0) {
[&]() STAN_COLD_PATH {
invalid_argument(function, name, 0, "has size ",
", but must have a non-zero size");
}();
}
invalid_argument(function, name, 0, "has size ",
", but must have a non-zero size");
}

} // namespace math
Expand Down
Loading