Skip to content

Commit 0956b10

Browse files
authored
Merge pull request #2397 from stan-dev/feature/variadic-integrate-1d
Update integrate_1d to use variadic autodiff stuff internally in preparation for closures
2 parents 40803f8 + 49ba955 commit 0956b10

File tree

6 files changed

+1638
-140
lines changed

6 files changed

+1638
-140
lines changed

stan/math/prim/functor.hpp

+1
Original file line numberDiff line numberDiff line change
@@ -10,6 +10,7 @@
1010
#include <stan/math/prim/functor/finite_diff_gradient_auto.hpp>
1111
#include <stan/math/prim/functor/for_each.hpp>
1212
#include <stan/math/prim/functor/integrate_1d.hpp>
13+
#include <stan/math/prim/functor/integrate_1d_adapter.hpp>
1314
#include <stan/math/prim/functor/integrate_ode_rk45.hpp>
1415
#include <stan/math/prim/functor/integrate_ode_std_vector_interface_adapter.hpp>
1516
#include <stan/math/prim/functor/ode_ckrk.hpp>

stan/math/prim/functor/integrate_1d.hpp

+109-62
Original file line numberDiff line numberDiff line change
@@ -1,9 +1,10 @@
1-
#ifndef STAN_MATH_PRIM_FUNCTOR_integrate_1d_HPP
2-
#define STAN_MATH_PRIM_FUNCTOR_integrate_1d_HPP
1+
#ifndef STAN_MATH_PRIM_FUNCTOR_INTEGRATE_1D_HPP
2+
#define STAN_MATH_PRIM_FUNCTOR_INTEGRATE_1D_HPP
33

44
#include <stan/math/prim/meta.hpp>
55
#include <stan/math/prim/err.hpp>
66
#include <stan/math/prim/fun/constants.hpp>
7+
#include <stan/math/prim/functor/integrate_1d_adapter.hpp>
78
#include <boost/math/quadrature/exp_sinh.hpp>
89
#include <boost/math/quadrature/sinh_sinh.hpp>
910
#include <boost/math/quadrature/tanh_sinh.hpp>
@@ -50,17 +51,53 @@ namespace math {
5051
template <typename F>
5152
inline double integrate(const F& f, double a, double b,
5253
double relative_tolerance) {
54+
static constexpr const char* function = "integrate";
5355
double error1 = 0.0;
5456
double error2 = 0.0;
5557
double L1 = 0.0;
5658
double L2 = 0.0;
57-
bool used_two_integrals = false;
5859
size_t levels;
59-
double Q = 0.0;
60-
auto f_wrap = [&](double x) { return f(x, NOT_A_NUMBER); };
60+
61+
auto one_integral_convergence_check = [](auto& error1, auto& rel_tol,
62+
auto& L1) {
63+
if (error1 > rel_tol * L1) {
64+
[error1]() STAN_COLD_PATH {
65+
throw_domain_error(
66+
function, "error estimate of integral", error1, "",
67+
" exceeds the given relative tolerance times norm of integral");
68+
}();
69+
}
70+
};
71+
72+
auto two_integral_convergence_check
73+
= [](auto& error1, auto& error2, auto& rel_tol, auto& L1, auto& L2) {
74+
if (error1 > rel_tol * L1) {
75+
[error1]() STAN_COLD_PATH {
76+
throw_domain_error(
77+
function, "error estimate of integral below zero", error1, "",
78+
" exceeds the given relative tolerance times norm of "
79+
"integral below zero");
80+
}();
81+
}
82+
if (error2 > rel_tol * L2) {
83+
[error2]() STAN_COLD_PATH {
84+
throw_domain_error(
85+
function, "error estimate of integral above zero", error2, "",
86+
" exceeds the given relative tolerance times norm of "
87+
"integral above zero");
88+
}();
89+
}
90+
};
91+
92+
// if a or b is infinite, set xc argument to NaN (see docs above for user
93+
// function for xc info)
94+
auto f_wrap = [&f](double x) { return f(x, NOT_A_NUMBER); };
6195
if (std::isinf(a) && std::isinf(b)) {
6296
boost::math::quadrature::sinh_sinh<double> integrator;
63-
Q = integrator.integrate(f_wrap, relative_tolerance, &error1, &L1, &levels);
97+
double Q = integrator.integrate(f_wrap, relative_tolerance, &error1, &L1,
98+
&levels);
99+
one_integral_convergence_check(error1, relative_tolerance, L1);
100+
return Q;
64101
} else if (std::isinf(a)) {
65102
boost::math::quadrature::exp_sinh<double> integrator;
66103
/**
@@ -69,66 +106,88 @@ inline double integrate(const F& f, double a, double b,
69106
* https://www.boost.org/doc/libs/1_66_0/libs/math/doc/html/math_toolkit/double_exponential/de_caveats.html)
70107
*/
71108
if (b <= 0.0) {
72-
Q = integrator.integrate(f_wrap, a, b, relative_tolerance, &error1, &L1,
73-
&levels);
109+
double Q = integrator.integrate(f_wrap, a, b, relative_tolerance, &error1,
110+
&L1, &levels);
111+
one_integral_convergence_check(error1, relative_tolerance, L1);
112+
return Q;
74113
} else {
75114
boost::math::quadrature::tanh_sinh<double> integrator_right;
76-
Q = integrator.integrate(f_wrap, a, 0.0, relative_tolerance, &error1, &L1,
77-
&levels)
78-
+ integrator_right.integrate(f_wrap, 0.0, b, relative_tolerance,
79-
&error2, &L2, &levels);
80-
used_two_integrals = true;
115+
double Q
116+
= integrator.integrate(f_wrap, a, 0.0, relative_tolerance, &error1,
117+
&L1, &levels)
118+
+ integrator_right.integrate(f_wrap, 0.0, b, relative_tolerance,
119+
&error2, &L2, &levels);
120+
two_integral_convergence_check(error1, error2, relative_tolerance, L1,
121+
L2);
122+
return Q;
81123
}
82124
} else if (std::isinf(b)) {
83125
boost::math::quadrature::exp_sinh<double> integrator;
84126
if (a >= 0.0) {
85-
Q = integrator.integrate(f_wrap, a, b, relative_tolerance, &error1, &L1,
86-
&levels);
127+
double Q = integrator.integrate(f_wrap, a, b, relative_tolerance, &error1,
128+
&L1, &levels);
129+
one_integral_convergence_check(error1, relative_tolerance, L1);
130+
return Q;
87131
} else {
88132
boost::math::quadrature::tanh_sinh<double> integrator_left;
89-
Q = integrator_left.integrate(f_wrap, a, 0, relative_tolerance, &error1,
90-
&L1, &levels)
91-
+ integrator.integrate(f_wrap, relative_tolerance, &error2, &L2,
92-
&levels);
93-
used_two_integrals = true;
133+
double Q = integrator_left.integrate(f_wrap, a, 0, relative_tolerance,
134+
&error1, &L1, &levels)
135+
+ integrator.integrate(f_wrap, relative_tolerance, &error2,
136+
&L2, &levels);
137+
two_integral_convergence_check(error1, error2, relative_tolerance, L1,
138+
L2);
139+
return Q;
94140
}
95141
} else {
96-
auto f_wrap = [&](double x, double xc) { return f(x, xc); };
142+
auto f_wrap = [&f](double x, double xc) { return f(x, xc); };
97143
boost::math::quadrature::tanh_sinh<double> integrator;
98144
if (a < 0.0 && b > 0.0) {
99-
Q = integrator.integrate(f_wrap, a, 0.0, relative_tolerance, &error1, &L1,
100-
&levels)
101-
+ integrator.integrate(f_wrap, 0.0, b, relative_tolerance, &error2,
102-
&L2, &levels);
103-
used_two_integrals = true;
145+
double Q = integrator.integrate(f_wrap, a, 0.0, relative_tolerance,
146+
&error1, &L1, &levels)
147+
+ integrator.integrate(f_wrap, 0.0, b, relative_tolerance,
148+
&error2, &L2, &levels);
149+
two_integral_convergence_check(error1, error2, relative_tolerance, L1,
150+
L2);
151+
return Q;
104152
} else {
105-
Q = integrator.integrate(f_wrap, a, b, relative_tolerance, &error1, &L1,
106-
&levels);
153+
double Q = integrator.integrate(f_wrap, a, b, relative_tolerance, &error1,
154+
&L1, &levels);
155+
one_integral_convergence_check(error1, relative_tolerance, L1);
156+
return Q;
107157
}
108158
}
159+
}
109160

110-
static const char* function = "integrate";
111-
if (used_two_integrals) {
112-
if (error1 > relative_tolerance * L1) {
113-
throw_domain_error(function, "error estimate of integral below zero",
114-
error1, "",
115-
" exceeds the given relative tolerance times norm of "
116-
"integral below zero");
117-
}
118-
if (error2 > relative_tolerance * L2) {
119-
throw_domain_error(function, "error estimate of integral above zero",
120-
error2, "",
121-
" exceeds the given relative tolerance times norm of "
122-
"integral above zero");
161+
/**
162+
* Compute the integral of the single variable function f from a to b to within
163+
* a specified relative tolerance. a and b can be finite or infinite.
164+
*
165+
* @tparam T Type of f
166+
* @param f the function to be integrated
167+
* @param a lower limit of integration
168+
* @param b upper limit of integration
169+
* @param relative_tolerance tolerance passed to Boost quadrature
170+
* @param[in, out] msgs the print stream for warning messages
171+
* @param args additional arguments passed to f
172+
* @return numeric integral of function f
173+
*/
174+
template <typename F, typename... Args,
175+
require_all_not_st_var<Args...>* = nullptr>
176+
inline double integrate_1d_impl(const F& f, double a, double b,
177+
double relative_tolerance, std::ostream* msgs,
178+
const Args&... args) {
179+
static constexpr const char* function = "integrate_1d";
180+
check_less_or_equal(function, "lower limit", a, b);
181+
if (unlikely(a == b)) {
182+
if (std::isinf(a)) {
183+
throw_domain_error(function, "Integration endpoints are both", a, "", "");
123184
}
185+
return 0.0;
124186
} else {
125-
if (error1 > relative_tolerance * L1) {
126-
throw_domain_error(
127-
function, "error estimate of integral", error1, "",
128-
" exceeds the given relative tolerance times norm of integral");
129-
}
187+
return integrate(
188+
[&](const auto& x, const auto& xc) { return f(x, xc, msgs, args...); },
189+
a, b, relative_tolerance);
130190
}
131-
return Q;
132191
}
133192

134193
/**
@@ -178,26 +237,14 @@ inline double integrate(const F& f, double a, double b,
178237
* @return numeric integral of function f
179238
*/
180239
template <typename F>
181-
inline double integrate_1d(const F& f, const double a, const double b,
240+
inline double integrate_1d(const F& f, double a, double b,
182241
const std::vector<double>& theta,
183242
const std::vector<double>& x_r,
184243
const std::vector<int>& x_i, std::ostream* msgs,
185244
const double relative_tolerance
186245
= std::sqrt(EPSILON)) {
187-
static const char* function = "integrate_1d";
188-
check_less_or_equal(function, "lower limit", a, b);
189-
190-
if (a == b) {
191-
if (std::isinf(a)) {
192-
throw_domain_error(function, "Integration endpoints are both", a, "", "");
193-
}
194-
return 0.0;
195-
} else {
196-
return integrate(
197-
std::bind<double>(f, std::placeholders::_1, std::placeholders::_2,
198-
theta, x_r, x_i, msgs),
199-
a, b, relative_tolerance);
200-
}
246+
return integrate_1d_impl(integrate_1d_adapter<F>(f), a, b, relative_tolerance,
247+
msgs, theta, x_r, x_i);
201248
}
202249

203250
} // namespace math
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,28 @@
1+
#ifndef STAN_MATH_PRIM_FUNCTOR_INTEGRATE_1D_ADAPTER_HPP
2+
#define STAN_MATH_PRIM_FUNCTOR_INTEGRATE_1D_ADAPTER_HPP
3+
4+
#include <ostream>
5+
#include <vector>
6+
7+
/**
8+
* Adapt the non-variadic integrate_1d arguments to the variadic
9+
* integrate_1d_impl interface
10+
*
11+
* @tparam F type of function to adapt
12+
*/
13+
template <typename F>
14+
struct integrate_1d_adapter {
15+
const F& f_;
16+
17+
explicit integrate_1d_adapter(const F& f) : f_(f) {}
18+
19+
template <typename T_a, typename T_b, typename T_theta>
20+
auto operator()(const T_a& x, const T_b& xc, std::ostream* msgs,
21+
const std::vector<T_theta>& theta,
22+
const std::vector<double>& x_r,
23+
const std::vector<int>& x_i) const {
24+
return f_(x, xc, theta, x_r, x_i, msgs);
25+
}
26+
};
27+
28+
#endif

0 commit comments

Comments
 (0)