
Commit cd1b0cb

Use varis to get nth gradient
1 parent 36847d9 commit cd1b0cb

File tree

1 file changed: +11 -24 lines changed

stan/math/rev/functor/integrate_1d.hpp  (+11 -24)
@@ -102,43 +102,30 @@ inline return_type_t<T_a, T_b, Args...> integrate_1d_impl(
     // The arguments copy is used multiple times in the following nests, so
     // do it once in a separate nest for efficiency
     auto args_tuple_local_copy = std::make_tuple(deep_copy_vars(args)...);
+
+    // Save the varis so it's easy to efficiently access the nth adjoint
+    std::vector<vari*> local_varis(num_vars_args);
+    apply([&](const auto&... args) {
+      save_varis(local_varis.data(), args...);
+    }, args_tuple_local_copy);
+
     for (size_t n = 0; n < num_vars_args; ++n) {
       // This computes the integral of the gradient of f with respect to the
       // nth parameter in args using a nested nested reverse mode autodiff
       *partials_ptr = integrate(
           [&](const auto &x, const auto &xc) {
             argument_nest.set_zero_all_adjoints();
+
             nested_rev_autodiff gradient_nest;
             var fx = apply(
                 [&f, &x, &xc, msgs](auto &&... local_args) {
                   return f(x, xc, msgs, local_args...);
                 },
                 args_tuple_local_copy);
             fx.grad();
-            size_t adjoint_count = 0;
-            double gradient = 0;
-            bool not_found = true;
-            // for_each is guaranteed to go off from left to right.
-            // So for var argument we count the number of previous vars
-            // until we go past n, then index into that argument to get
-            // the correct adjoint.
-            stan::math::for_each(
-                [&](auto &arg) {
-                  using arg_t = decltype(arg);
-                  using scalar_arg_t = scalar_type_t<arg_t>;
-                  if (is_var<scalar_arg_t>::value) {
-                    size_t var_count = count_vars(arg);
-                    if (((adjoint_count + var_count) < n) && not_found) {
-                      adjoint_count += var_count;
-                    } else if (not_found) {
-                      not_found = false;
-                      gradient
-                          = forward_as<var>(stan::get(arg, n - adjoint_count))
-                                .adj();
-                    }
-                  }
-                },
-                args_tuple_local_copy);
+
+            double gradient = local_varis[n]->adj();
+
             // Gradients that evaluate to NaN are set to zero if the function
             // itself evaluates to zero. If the function is not zero and the
             // gradient evaluates to NaN, a std::domain_error is thrown
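For context, here is a minimal standalone sketch (not part of this commit) of why holding the varis makes the nth gradient a constant-time lookup after grad(). It assumes the usual Stan Math reverse-mode API (var, vari, the public vi_ member, adj(), grad(), recover_memory()); the argument values and the small objective are illustrative only.

#include <stan/math.hpp>
#include <iostream>
#include <vector>

int main() {
  using stan::math::var;
  using stan::math::vari;

  // A scalar argument and a vector argument, all reverse-mode vars.
  var a = 1.5;
  std::vector<var> v{2.0, 3.0};

  // Record the varis in left-to-right argument order, mirroring what
  // save_varis does for the flattened argument pack in integrate_1d_impl.
  std::vector<vari*> varis;
  varis.push_back(a.vi_);
  for (auto& x : v) {
    varis.push_back(x.vi_);
  }

  var fx = a * v[0] + v[1];  // stand-in for the integrand evaluation
  fx.grad();                 // propagate adjoints back to the arguments

  // The nth gradient is a single pointer dereference instead of a
  // count_vars / for_each scan over the argument tuple.
  for (size_t n = 0; n < varis.size(); ++n) {
    std::cout << "d(fx)/d(arg " << n << ") = " << varis[n]->adj() << "\n";
  }

  stan::math::recover_memory();
}

The diff applies the same idea inside the per-parameter integration loop: save_varis is called once on the copied argument tuple, and each iteration then reads local_varis[n]->adj() directly.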
