-
-
Notifications
You must be signed in to change notification settings - Fork 190
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
Improve Numerical Stability of Bernoulli CDF functions #2784
Changes from 17 commits
2b8ae14
c2022be
a9d0800
bcfaec1
143f707
62d8640
0edeb5f
fd6d39c
f03c35f
2b3cf01
a371c2b
2bade77
409c971
6220f14
2c11c42
136fc9c
ccde5d5
a4b384c
da1f7a8
abfa364
6103b16
53ce103
9872fcf
bd11e54
8f5cdb8
cabafd6
47c4f9a
File filter
Filter by extension
Conversations
Jump to
Diff view
Diff view
There are no files selected for viewing
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,140 @@ | ||
#ifndef STAN_MATH_PRIM_FUN_SELECT_HPP | ||
#define STAN_MATH_PRIM_FUN_SELECT_HPP | ||
|
||
#include <stan/math/prim/meta.hpp> | ||
|
||
namespace stan { | ||
namespace math { | ||
|
||
/** | ||
* Return the second argument if the first argument is true | ||
* and otherwise return the third argument. | ||
* | ||
* <code>select(c, y1, y0) = c ? y1 : y0</code>. | ||
* | ||
* @tparam T_true type of the true argument | ||
* @tparam T_false type of the false argument | ||
* @param c Boolean condition value. | ||
* @param y_true Value to return if condition is true. | ||
* @param y_false Value to return if condition is false. | ||
*/ | ||
template <typename T_true, typename T_false, | ||
require_all_stan_scalar_t<T_true, T_false>* = nullptr> | ||
inline auto select(const bool c, const T_true y_true, const T_false y_false) { | ||
return c ? y_true : y_false; | ||
} | ||
|
||
/** | ||
* Return the second argument if the first argument is true | ||
* and otherwise return the third argument. Overload for use with two Eigen | ||
* objects. | ||
* | ||
* @tparam T_true type of the true argument | ||
* @tparam T_false type of the false argument | ||
* @param c Boolean condition value. | ||
* @param y_true Value to return if condition is true. | ||
* @param y_false Value to return if condition is false. | ||
*/ | ||
template <typename T_true, typename T_false, | ||
require_all_eigen_t<T_true, T_false>* = nullptr> | ||
inline auto select(const bool c, const T_true y_true, const T_false y_false) { | ||
return y_true | ||
.binaryExpr(y_false, [&](auto&& x, auto&& y) { return c ? x : y; }) | ||
.eval(); | ||
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. If c is constant here should we just be returning |
||
} | ||
|
||
/** | ||
* Return the second Eigen argument if the first argument is true | ||
* and otherwise return the second Eigen argument. Overload for use with one | ||
* scalar and one Eigen object. If chosen, the scalar is returned as an Eigen | ||
* object of the same size and type as the provided argument. | ||
* | ||
* @tparam T_true type of the true argument | ||
* @tparam T_false type of the false argument | ||
* @param c Boolean condition value. | ||
* @param y_true Value to return if condition is true. | ||
* @param y_false Value to return if condition is false. | ||
*/ | ||
template <typename T_true, typename T_false, | ||
typename ReturnT = promote_scalar_t<return_type_t<T_true, T_false>, | ||
plain_type_t<T_true>>, | ||
require_eigen_t<T_true>* = nullptr, | ||
require_stan_scalar_t<T_false>* = nullptr> | ||
inline ReturnT select(const bool c, const T_true& y_true, | ||
const T_false& y_false) { | ||
if (c) { | ||
return y_true; | ||
} | ||
|
||
return y_true.unaryExpr([&](auto&& y) { return y_false; }); | ||
} | ||
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. I'd use
with There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. That is true for all of them. |
||
|
||
/** | ||
* Return the second Eigen argument if the first argument is true | ||
* and otherwise return the second Eigen argument. Overload for use with one | ||
* scalar and one Eigen object. If chosen, the scalar is returned as an Eigen | ||
* object of the same size and type as the provided argument. | ||
* | ||
* @tparam T_true type of the true argument | ||
* @tparam T_false type of the false argument | ||
* @param c Boolean condition value. | ||
* @param y_true Value to return if condition is true. | ||
* @param y_false Value to return if condition is false. | ||
*/ | ||
template <typename T_true, typename T_false, | ||
typename ReturnT = promote_scalar_t<return_type_t<T_true, T_false>, | ||
plain_type_t<T_false>>, | ||
require_stan_scalar_t<T_true>* = nullptr, | ||
require_eigen_t<T_false>* = nullptr> | ||
inline ReturnT select(const bool c, const T_true y_true, | ||
const T_false y_false) { | ||
if (c) { | ||
return y_false.unaryExpr([&](auto&& y) { return y_true; }); | ||
} | ||
|
||
return y_false; | ||
} | ||
|
||
/** | ||
* Return the second argument if the first argument is true | ||
* and otherwise return the third argument. Overload for use with an Eigen | ||
* object of booleans, and two scalars. The chosen scalar is returned as an | ||
* Eigen object of the same dimension as the input Eigen argument | ||
* | ||
* @tparam T_bool type of Eigen boolean object | ||
* @tparam T_true type of the true argument | ||
* @tparam T_false type of the false argument | ||
* @param c Eigen object of boolean condition values. | ||
* @param y_true Value to return if condition is true. | ||
* @param y_false Value to return if condition is false. | ||
*/ | ||
template <typename T_bool, typename T_true, typename T_false, | ||
require_eigen_array_t<T_bool>* = nullptr, | ||
require_all_stan_scalar_t<T_true, T_false>* = nullptr> | ||
inline auto select(const T_bool c, const T_true y_true, const T_false y_false) { | ||
return c.unaryExpr([&](bool cond) { return cond ? y_true : y_false; }).eval(); | ||
} | ||
|
||
/** | ||
* Return the second argument if the first argument is true | ||
* and otherwise return the third argument. Overload for use with an Eigen | ||
* object of booleans, and at least one Eigen object as input. | ||
* | ||
* @tparam T_bool type of Eigen boolean object | ||
* @tparam T_true type of the true argument | ||
* @tparam T_false type of the false argument | ||
* @param c Eigen object of boolean condition values. | ||
* @param y_true Value to return if condition is true. | ||
* @param y_false Value to return if condition is false. | ||
*/ | ||
template <typename T_bool, typename T_true, typename T_false, | ||
require_eigen_array_t<T_bool>* = nullptr, | ||
require_any_eigen_array_t<T_true, T_false>* = nullptr> | ||
inline auto select(const T_bool c, const T_true y_true, const T_false y_false) { | ||
return c.select(y_true, y_false).eval(); | ||
} | ||
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. Does this work if |
||
|
||
} // namespace math | ||
} // namespace stan | ||
|
||
#endif |
Original file line number | Diff line number | Diff line change | ||||||
---|---|---|---|---|---|---|---|---|
|
@@ -8,6 +8,7 @@ | |||||||
#include <stan/math/prim/fun/log.hpp> | ||||||||
#include <stan/math/prim/fun/max_size.hpp> | ||||||||
#include <stan/math/prim/fun/scalar_seq_view.hpp> | ||||||||
#include <stan/math/prim/fun/select.hpp> | ||||||||
#include <stan/math/prim/fun/size.hpp> | ||||||||
#include <stan/math/prim/fun/size_zero.hpp> | ||||||||
#include <stan/math/prim/fun/value_of.hpp> | ||||||||
|
@@ -33,50 +34,38 @@ template <typename T_n, typename T_prob, | |||||||
require_all_not_nonscalar_prim_or_rev_kernel_expression_t< | ||||||||
T_n, T_prob>* = nullptr> | ||||||||
return_type_t<T_prob> bernoulli_lccdf(const T_n& n, const T_prob& theta) { | ||||||||
using T_partials_return = partials_return_t<T_n, T_prob>; | ||||||||
using T_theta_ref = ref_type_t<T_prob>; | ||||||||
using std::log; | ||||||||
static const char* function = "bernoulli_lccdf"; | ||||||||
check_consistent_sizes(function, "Random variable", n, | ||||||||
"Probability parameter", theta); | ||||||||
T_theta_ref theta_ref = theta; | ||||||||
const auto& n_arr = as_array_or_scalar(n); | ||||||||
check_bounded(function, "Probability parameter", value_of(theta_ref), 0.0, | ||||||||
1.0); | ||||||||
|
||||||||
if (size_zero(n, theta)) { | ||||||||
return 0.0; | ||||||||
} | ||||||||
|
||||||||
T_partials_return P(0.0); | ||||||||
operands_and_partials<T_theta_ref> ops_partials(theta_ref); | ||||||||
|
||||||||
scalar_seq_view<T_n> n_vec(n); | ||||||||
scalar_seq_view<T_theta_ref> theta_vec(theta_ref); | ||||||||
size_t max_size_seq_view = max_size(n, theta); | ||||||||
|
||||||||
// Explicit return for extreme values | ||||||||
// The gradients are technically ill-defined, but treated as zero | ||||||||
for (size_t i = 0; i < stan::math::size(n); i++) { | ||||||||
const double n_dbl = n_vec.val(i); | ||||||||
if (n_dbl < 0) { | ||||||||
return ops_partials.build(0.0); | ||||||||
} | ||||||||
if (n_dbl >= 1) { | ||||||||
return ops_partials.build(NEGATIVE_INFTY); | ||||||||
} | ||||||||
if (sum(n_arr < 0)) { | ||||||||
return ops_partials.build(0.0); | ||||||||
} | ||||||||
if (sum(n_arr >= 1)) { | ||||||||
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more.
Suggested change
|
||||||||
return ops_partials.build(NEGATIVE_INFTY); | ||||||||
} | ||||||||
|
||||||||
for (size_t i = 0; i < max_size_seq_view; i++) { | ||||||||
const T_partials_return Pi = theta_vec.val(i); | ||||||||
|
||||||||
P += log(Pi); | ||||||||
const auto& theta_arr = as_value_column_array_or_scalar(theta_ref); | ||||||||
|
||||||||
if (!is_constant_all<T_prob>::value) { | ||||||||
ops_partials.edge1_.partials_[i] += inv(Pi); | ||||||||
} | ||||||||
if (!is_constant_all<T_prob>::value) { | ||||||||
ops_partials.edge1_.partials_ = select(true, inv(theta_arr), n_arr); | ||||||||
} | ||||||||
|
||||||||
return ops_partials.build(P); | ||||||||
return ops_partials.build(sum(select(true, log(theta_arr), n_arr))); | ||||||||
} | ||||||||
|
||||||||
} // namespace math | ||||||||
|
Original file line number | Diff line number | Diff line change |
---|---|---|
|
@@ -8,6 +8,7 @@ | |
#include <stan/math/prim/fun/log.hpp> | ||
#include <stan/math/prim/fun/max_size.hpp> | ||
#include <stan/math/prim/fun/scalar_seq_view.hpp> | ||
#include <stan/math/prim/fun/select.hpp> | ||
#include <stan/math/prim/fun/size.hpp> | ||
#include <stan/math/prim/fun/size_zero.hpp> | ||
#include <stan/math/prim/fun/value_of.hpp> | ||
|
@@ -33,52 +34,36 @@ template <typename T_n, typename T_prob, | |
require_all_not_nonscalar_prim_or_rev_kernel_expression_t< | ||
T_n, T_prob>* = nullptr> | ||
return_type_t<T_prob> bernoulli_lcdf(const T_n& n, const T_prob& theta) { | ||
using T_partials_return = partials_return_t<T_n, T_prob>; | ||
using T_theta_ref = ref_type_t<T_prob>; | ||
using std::log; | ||
static const char* function = "bernoulli_lcdf"; | ||
check_consistent_sizes(function, "Random variable", n, | ||
"Probability parameter", theta); | ||
T_theta_ref theta_ref = theta; | ||
const auto& n_arr = as_array_or_scalar(n); | ||
check_bounded(function, "Probability parameter", value_of(theta_ref), 0.0, | ||
1.0); | ||
|
||
if (size_zero(n, theta)) { | ||
return 0.0; | ||
} | ||
|
||
T_partials_return P(0.0); | ||
operands_and_partials<T_theta_ref> ops_partials(theta_ref); | ||
|
||
scalar_seq_view<T_n> n_vec(n); | ||
scalar_seq_view<T_theta_ref> theta_vec(theta_ref); | ||
size_t max_size_seq_view = max_size(n, theta); | ||
|
||
// Explicit return for extreme values | ||
// The gradients are technically ill-defined, but treated as zero | ||
for (size_t i = 0; i < stan::math::size(n); i++) { | ||
if (n_vec.val(i) < 0) { | ||
return ops_partials.build(NEGATIVE_INFTY); | ||
} | ||
if (sum(n_arr < 0)) { | ||
return ops_partials.build(NEGATIVE_INFTY); | ||
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. We could just write an |
||
} | ||
|
||
for (size_t i = 0; i < max_size_seq_view; i++) { | ||
// Explicit results for extreme values | ||
// The gradients are technically ill-defined, but treated as zero | ||
if (n_vec.val(i) >= 1) { | ||
continue; | ||
} | ||
|
||
const T_partials_return Pi = 1 - theta_vec.val(i); | ||
|
||
P += log(Pi); | ||
const auto& theta_arr = as_value_column_array_or_scalar(theta_ref); | ||
const auto& log1m_theta = select(theta_arr == 1, 0.0, log1m(theta_arr)); | ||
|
||
if (!is_constant_all<T_prob>::value) { | ||
ops_partials.edge1_.partials_[i] -= inv(Pi); | ||
} | ||
if (!is_constant_all<T_prob>::value) { | ||
ops_partials.edge1_.partials_ = select(n_arr == 0, -exp(-log1m_theta), 0.0); | ||
} | ||
|
||
return ops_partials.build(P); | ||
return ops_partials.build(sum(select(n_arr == 0, log1m_theta, 0.0))); | ||
} | ||
|
||
} // namespace math | ||
|
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Not sure if @t4c1 still checks github, but I'm not sure if we need
common_type
here or if auto is fine? I wouldn't mind just usingreturn_type_t<>
, though that will only work with arithmetic types sincereturn_type_t
has a minimum ofdouble
as the returned type. We could just write another another overload to handle the double integral case thoughThere was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
I still get notifications if pinged.
auto
will here be same asT_true
(that is how ternary operator works), so some common type is a better idea. Not sure ifretrun_type
will do promotion to var even if neither T_true nor T_false are var, but we do not want that here.There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
I've done some tests and it doesn't look like an issue when mixing types: https://godbolt.org/z/dvcxvvxhs
But let me know if I've missed something basic!