Skip to content
Permalink

Comparing changes

Choose two branches to see what’s changed or to start a new pull request. If you need to, you can also or learn more about diff comparisons.

Open a pull request

Create a new pull request by comparing changes across two branches. If you need to, you can also . Learn more about diff comparisons here.
base repository: nhuurre/math
Failed to load repositories. Confirm that selected base ref is valid, then try again.
Loading
base: feature/closures-v2
Choose a base ref
...
head repository: stan-dev/math
Failed to load repositories. Confirm that selected head ref is valid, then try again.
Loading
compare: review2/closure
Choose a head ref
Able to merge. These branches can be automatically merged.
  • 1 commit
  • 3 files changed
  • 1 contributor

Commits on Aug 13, 2021

  1. Verified

    This commit was signed with the committer’s verified signature.
    SteveBronder Steve Bronder
    Copy the full SHA
    60c25c3 View commit details
Showing with 25 additions and 24 deletions.
  1. +1 −1 stan/math/prim/fun/value_of.hpp
  2. +23 −22 stan/math/prim/functor/closure_adapter.hpp
  3. +1 −1 stan/math/rev/core/deep_copy_vars.hpp
2 changes: 1 addition & 1 deletion stan/math/prim/fun/value_of.hpp
Original file line number Diff line number Diff line change
@@ -91,7 +91,7 @@ template <typename F, require_stan_closure_t<F>* = nullptr,
inline auto value_of(const F& f) {
return apply(
[&f](const auto&... s) {
return typename F::partials_closure_t_(f.f_, eval(value_of(s))...);
return typename F::partials_closure_t_(f.f_, value_of(s)...);
},
f.captures_);
}
45 changes: 23 additions & 22 deletions stan/math/prim/functor/closure_adapter.hpp
Original file line number Diff line number Diff line change
@@ -12,15 +12,17 @@ namespace internal {
/**
* A closure that wraps a C++ lambda and captures values.
*/
template <bool Ref, typename F, typename... Ts>
template <typename F, typename... Ts>
struct base_closure {
using return_scalar_t_ = return_type_t<Ts...>;
/*The base closure with `Ts` as the non-expression partials of `Ts`*/
using partials_closure_t_
= base_closure<false, F, decltype(eval(value_of(std::declval<Ts>())))...>;
using Base_ = base_closure<false, F, Ts...>;
= base_closure<F, decltype(value_of(std::declval<Ts>()))...>;
using Base_ = base_closure<F, Ts...>;
/* The closure with captures_ as the plain object instead of a reference */
using PlainBase_ = base_closure<F, std::decay_t<plain_type_t<decltype(eval(std::declval<Ts>()))>>...>;
std::decay_t<F> f_;
std::tuple<closure_return_type_t<Ts, Ref>...> captures_;
std::tuple<Ts...> captures_;
template <typename FF, require_same_t<FF, F>* = nullptr, typename... Args>
explicit base_closure(FF&& f, Args&&... args)
: f_(std::forward<FF>(f)), captures_(std::forward<Args>(args)...) {}
@@ -36,13 +38,13 @@ struct base_closure {
/**
* A closure that takes rng argument.
*/
template <bool Ref, typename F, typename... Ts>
template <typename F, typename... Ts>
struct closure_rng {
using return_scalar_t_ = double;
using partials_closure_t_ = closure_rng<false, F, Ts...>;
using Base_ = closure_rng<false, F, Ts...>;
using partials_closure_t_ = closure_rng<F, Ts...>;
using Base_ = closure_rng<F, Ts...>;
std::decay_t<F> f_;
std::tuple<closure_return_type_t<Ts, Ref>...> captures_;
std::tuple<Ts...> captures_;

template <typename FF, require_same_t<FF, F>* = nullptr, typename... Args>
explicit closure_rng(FF&& f, Args&&... args)
@@ -61,13 +63,13 @@ struct closure_rng {
/**
* A closure that can be called with `propto` template argument.
*/
template <bool Propto, bool Ref, typename F, typename... Ts>
template <bool Propto, typename F, typename... Ts>
struct closure_lpdf {
using return_scalar_t_ = return_type_t<Ts...>;
using partials_closure_t_ = closure_lpdf<Propto, false, F, Ts...>;
using Base_ = closure_lpdf<Propto, false, F, Ts...>;
using partials_closure_t_ = closure_lpdf<Propto, F, Ts...>;
using Base_ = closure_lpdf<Propto, F, Ts...>;
std::decay_t<F> f_;
std::tuple<closure_return_type_t<Ts, Ref>...> captures_;
std::tuple<Ts...> captures_;

template <typename FF, require_same_t<FF, F>* = nullptr, typename... Args>
explicit closure_lpdf(FF&& f, Args&&... args)
@@ -77,8 +79,7 @@ struct closure_lpdf {
auto with_propto() {
return apply(
[this](const auto&... args) {
return closure_lpdf < Propto && propto, true, F,
Ts... > (this->f_, args...);
return closure_lpdf < Propto && propto, F, Ts...> (this->f_, args...);
},
captures_);
}
@@ -96,13 +97,13 @@ struct closure_lpdf {
/**
* A closure that accesses logprob accumulator.
*/
template <bool Propto, bool Ref, typename F, typename... Ts>
template <bool Propto, typename F, typename... Ts>
struct closure_lp {
using return_scalar_t_ = return_type_t<Ts...>;
using partials_closure_t_ = closure_lp<Propto, true, F, Ts...>;
using Base_ = closure_lp<Propto, true, F, Ts...>;
using partials_closure_t_ = closure_lp<Propto, F, Ts...>;
using Base_ = closure_lp<Propto, F, Ts...>;
std::decay_t<F> f_;
std::tuple<closure_return_type_t<Ts, Ref>...> captures_;
std::tuple<Ts...> captures_;

template <typename FF, require_same_t<FF, F>* = nullptr, typename... Args>
explicit closure_lp(FF&& f, Args&&... args)
@@ -149,7 +150,7 @@ struct integrate_ode_closure_adapter {
*/
template <typename F, typename... Args>
auto from_lambda(F&& f, Args&&... args) {
return internal::base_closure<true, F, Args...>(std::forward<F>(f),
return internal::base_closure<F, Args...>(std::forward<F>(f),
std::forward<Args>(args)...);
}

@@ -158,7 +159,7 @@ auto from_lambda(F&& f, Args&&... args) {
*/
template <typename F, typename... Args>
auto rng_from_lambda(F&& f, Args&&... args) {
return internal::closure_rng<true, F, Args...>(std::forward<F>(f),
return internal::closure_rng<F, Args...>(std::forward<F>(f),
std::forward<Args>(args)...);
}

@@ -167,7 +168,7 @@ auto rng_from_lambda(F&& f, Args&&... args) {
*/
template <bool propto, typename F, typename... Args>
auto lpdf_from_lambda(F&& f, Args&&... args) {
return internal::closure_lpdf<propto, true, F, Args...>(
return internal::closure_lpdf<propto, F, Args...>(
std::forward<F>(f), std::forward<Args>(args)...);
}

@@ -176,7 +177,7 @@ auto lpdf_from_lambda(F&& f, Args&&... args) {
*/
template <bool Propto, typename F, typename... Args>
auto lp_from_lambda(F&& f, Args&&... args) {
return internal::closure_lp<Propto, true, F, Args...>(
return internal::closure_lp<Propto, F, Args...>(
std::forward<F>(f), std::forward<Args>(args)...);
}

2 changes: 1 addition & 1 deletion stan/math/rev/core/deep_copy_vars.hpp
Original file line number Diff line number Diff line change
@@ -95,7 +95,7 @@ template <typename F, require_stan_closure_t<F>* = nullptr>
inline auto deep_copy_vars(const F& f) {
return apply(
[&f](const auto&... s) {
return typename F::Base_(f.f_, eval(deep_copy_vars(s))...);
return typename F::PlainBase_(f.f_, eval(deep_copy_vars(s))...);
},
f.captures_);
}