-
-
Notifications
You must be signed in to change notification settings - Fork 190
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
Improve Numerical Stability of Bernoulli CDF functions #2784
Improve Numerical Stability of Bernoulli CDF functions #2784
Conversation
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Thanks---looks great. I think there are a couple name changes that would make this much easier to follow.
@andrjohns: please let me know when this is ready to review and merge. Thanks! |
Thanks @bob-carpenter! This is ready for another look. I've updated the vectorisation by adding a |
@SteveBronder when you have a minute (no rush at all), can you have a look at this PR? It involves re-implementing the the OpenCL code's |
@bob-carpenter are you able to re-review this? @SteveBronder are you able to double check the opencl select code? |
dismiss the closed/reopened, I hit the trackpad on my laptop by accident |
don't think I understand the current C++ well enough
Thanks for the heads up. I just dismissed my review so that someone else could review it. I still don't feel I understand our new C++ conventions well enough to review PRs. |
Sorry didn't have time today but Tuesday I can look at this |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Few Qs around the new version of select. I also think we should just write an any()
function that for bool just returns the input and for Eigen types holding bools calls the .any()
method. Would make things simpler to read
stan/math/prim/fun/select.hpp
Outdated
inline auto select(const bool c, const T_true y_true, const T_false y_false) { | ||
return c ? y_true : y_false; | ||
} |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Not sure if @t4c1 still checks github, but I'm not sure if we need common_type
here or if auto is fine? I wouldn't mind just using return_type_t<>
, though that will only work with arithmetic types since return_type_t
has a minimum of double
as the returned type. We could just write another another overload to handle the double integral case though
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
I still get notifications if pinged. auto
will here be same as T_true
(that is how ternary operator works), so some common type is a better idea. Not sure if retrun_type
will do promotion to var even if neither T_true nor T_false are var, but we do not want that here.
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
auto
will here be same asT_true
I've done some tests and it doesn't look like an issue when mixing types: https://godbolt.org/z/dvcxvvxhs
But let me know if I've missed something basic!
stan/math/prim/fun/select.hpp
Outdated
return y_true | ||
.binaryExpr(y_false, [&](auto&& x, auto&& y) { return c ? x : y; }) | ||
.eval(); |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
If c is constant here should we just be returning y_true
or y_false
? We just need to use promotion rules on the output types scalar value with promote_scalar_t<return_type_t<T_true, T_false>>
stan/math/prim/fun/select.hpp
Outdated
if (c) { | ||
return y_true; | ||
} | ||
|
||
return y_true.unaryExpr([&](auto&& y) { return y_false; }); | ||
} |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
I'd use
if () {
} else {
}
with promote_type_t
again.
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
That is true for all of them.
stan/math/prim/fun/select.hpp
Outdated
inline auto select(const T_bool c, const T_true y_true, const T_false y_false) { | ||
return c.select(y_true, y_false).eval(); | ||
} |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Does this work if y_true
has a double scalar type and y_false
has an integer scalar type?
} | ||
if (sum(n_arr >= 1)) { |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
} | |
if (sum(n_arr >= 1)) { | |
} else if (sum(n_arr >= 1)) { |
if (sum(n_arr < 0)) { | ||
return ops_partials.build(NEGATIVE_INFTY); |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
We could just write an any()
function that takes in a scalar or vector and returns true or false. Think that would just be easier to read imo
@SteveBronder I'll add an |
I'll update this PR once the helper functions from #2852 have been added and merged |
@SteveBronder would you mind having another look at this when you get a minute? No rush |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
lgtm!
I'm guessing this is due to
@andrjohns are you able to take a look at this soon? Otherwise I think we may need to revert this to unblock our CI |
Ah damn, yeah I'll have a look now |
Summary
This PR updates the Bernoulli CDF functions (
_cdf
,_lcdf
, and_lccdf
) to operate on the log scale as much as possible, to avoid issues with underflow and resolution around 1Tests
Additional
mix/prob
tests have been added to ensure that the gradients aren't impacted (prim
behaviour covered by the distribution tests)Side Effects
N/A
Release notes
Improved numerical stability of Bernoulli CDF functions
Checklist
Math issue Improve Bernoulli (LC)CDF Numerical Stability #2783
Copyright holder: Andrew Johnson
The copyright holder is typically you or your assignee, such as a university or company. By submitting this pull request, the copyright holder is agreeing to the license the submitted work under the following licenses:
- Code: BSD 3-clause (https://opensource.org/licenses/BSD-3-Clause)
- Documentation: CC-BY 4.0 (https://creativecommons.org/licenses/by/4.0/)
the basic tests are passing
./runTests.py test/unit
)make test-headers
)make test-math-dependencies
)make doxygen
)make cpplint
)the code is written in idiomatic C++ and changes are documented in the doxygen
the new changes are tested