@@ -102,43 +102,30 @@ inline return_type_t<T_a, T_b, Args...> integrate_1d_impl(
102
102
// The arguments copy is used multiple times in the following nests, so
103
103
// do it once in a separate nest for efficiency
104
104
auto args_tuple_local_copy = std::make_tuple (deep_copy_vars (args)...);
105
+
106
+ // Save the varis so it's easy to efficiently access the nth adjoint
107
+ std::vector<vari*> local_varis (num_vars_args);
108
+ apply ([&](const auto &... args) {
109
+ save_varis (local_varis.data (), args...);
110
+ }, args_tuple_local_copy);
111
+
105
112
for (size_t n = 0 ; n < num_vars_args; ++n) {
106
113
// This computes the integral of the gradient of f with respect to the
107
114
// nth parameter in args using a nested nested reverse mode autodiff
108
115
*partials_ptr = integrate (
109
116
[&](const auto &x, const auto &xc) {
110
117
argument_nest.set_zero_all_adjoints ();
118
+
111
119
nested_rev_autodiff gradient_nest;
112
120
var fx = apply (
113
121
[&f, &x, &xc, msgs](auto &&... local_args) {
114
122
return f (x, xc, msgs, local_args...);
115
123
},
116
124
args_tuple_local_copy);
117
125
fx.grad ();
118
- size_t adjoint_count = 0 ;
119
- double gradient = 0 ;
120
- bool not_found = true ;
121
- // for_each is guaranteed to go off from left to right.
122
- // So for var argument we count the number of previous vars
123
- // until we go past n, then index into that argument to get
124
- // the correct adjoint.
125
- stan::math::for_each (
126
- [&](auto &arg) {
127
- using arg_t = decltype (arg);
128
- using scalar_arg_t = scalar_type_t <arg_t >;
129
- if (is_var<scalar_arg_t >::value) {
130
- size_t var_count = count_vars (arg);
131
- if (((adjoint_count + var_count) < n) && not_found) {
132
- adjoint_count += var_count;
133
- } else if (not_found) {
134
- not_found = false ;
135
- gradient
136
- = forward_as<var>(stan::get (arg, n - adjoint_count))
137
- .adj ();
138
- }
139
- }
140
- },
141
- args_tuple_local_copy);
126
+
127
+ double gradient = local_varis[n]->adj ();
128
+
142
129
// Gradients that evaluate to NaN are set to zero if the function
143
130
// itself evaluates to zero. If the function is not zero and the
144
131
// gradient evaluates to NaN, a std::domain_error is thrown
0 commit comments