-
-
Notifications
You must be signed in to change notification settings - Fork 190
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
adds unary functions for var<Matrix> #2527
Conversation
I can review this one, just lmk when it's ready |
Jenkins Console Log Machine informationProductName: Mac OS X ProductVersion: 10.11.6 BuildVersion: 15G22010CPU: G++: Clang: |
…4.1 (tags/RELEASE_600/final)
@SteveBronder JFYI; After e80abe9: stan/math/prim/fun/square.hpp:56:58: error: incomplete type ‘stan::math::apply_scalar_unary<stan::math::square_fun, stan::math::var_value<double>, void>’ used in nested name specifier
56 | return apply_scalar_unary<square_fun, Container>::apply(x);
detected during:
instantiation of "auto stan::math::square(const Container &) [with Container=stan::math::var, <unnamed>=(void *)nullptr, <unnamed>=(void *)nullptr, <unnamed>=(void *)nullptr]" at line 207 of "stan/math/prim/core/complex_base.hpp"
instantiation of "stan::math::complex_base<ValueType>::complex_type &stan::math::complex_base<ValueType>::operator/=(const std::complex<U> &) [with ValueType=stan::math::var, U=stan::math::var]" at line 24 of "stan/math/prim/core/operator_division.hpp"
instantiation of "stan::complex_return_t<U, V> stan::math::internal::complex_divide(const U &, const V &) [with U=std::complex<stan::math::var>, V=std::complex<stan::math::var>]" at line 122 of "stan/math/rev/core/operator_division.hpp" |
Oh my! Yes much appreciated! Where did you get that error? |
Testing with the experimental version of RStan. |
This reverts commit 4b434b0.
stan/math/prim/core/complex_base.hpp:207:34: error: no matching function for call to ‘square(stan::math::complex_base<stan::math::var_value<double> >::value_type)’
207 | value_type sum_sq_im = square(other.real()) + square(other.imag());
stan/math/rev/core/operator_division.hpp:122:34: required from here
/usr/include/c++/11.1.0/type_traits:2514:11: error: no type named ‘type’ in ‘struct std::enable_if<false, void>’
2514 | using enable_if_t = typename enable_if<_Cond, _Tp>::type; Log:In file included from /usr/lib/R/library/StanHeaders/include/stan/math/rev/core/std_complex.hpp:5,
from /usr/lib/R/library/StanHeaders/include/stan/math/rev/core/operator_division.hpp:11,
from /usr/lib/R/library/StanHeaders/include/stan/math/rev/core/operator_divide_equal.hpp:5,
from /usr/lib/R/library/StanHeaders/include/stan/math/rev/core.hpp:29,
from /usr/lib/R/library/StanHeaders/include/src/stan/model/model_base.hpp:5,
from Module.cpp:2:
/usr/lib/R/library/StanHeaders/include/stan/math/prim/core/complex_base.hpp: In instantiation of ‘stan::math::complex_base<ValueType>::complex_type& stan::math::complex_base<ValueType>::operator/=(const std::
complex<_Up>&) [with U = stan::math::var_value<double>; ValueType = stan::math::var_value<double>; stan::math::complex_base<ValueType>::complex_type = std::complex<stan::math::var_value<double> >]’:
/usr/lib/R/library/StanHeaders/include/stan/math/prim/core/operator_division.hpp:24:5: required from ‘stan::complex_return_t<U, V> stan::math::internal::complex_divide(const U&, const V&) [with U = std::com
plex<stan::math::var_value<double> >; V = std::complex<stan::math::var_value<double> >; stan::complex_return_t<U, V> = std::complex<stan::math::var_value<double> >]’
/usr/lib/R/library/StanHeaders/include/stan/math/rev/core/operator_division.hpp:122:34: required from here
/usr/lib/R/library/StanHeaders/include/stan/math/prim/core/complex_base.hpp:207:34: error: no matching function for call to ‘square(stan::math::complex_base<stan::math::var_value<double> >::value_type)’
207 | value_type sum_sq_im = square(other.real()) + square(other.imag());
| ~~~~~~^~~~~~~~~~~~~~
In file included from /usr/lib/R/library/StanHeaders/include/stan/math/prim/core/complex_base.hpp:4,
from /usr/lib/R/library/StanHeaders/include/stan/math/rev/core/std_complex.hpp:5,
from /usr/lib/R/library/StanHeaders/include/stan/math/rev/core/operator_division.hpp:11,
from /usr/lib/R/library/StanHeaders/include/stan/math/rev/core/operator_divide_equal.hpp:5,
from /usr/lib/R/library/StanHeaders/include/stan/math/rev/core.hpp:29,
from /usr/lib/R/library/StanHeaders/include/src/stan/model/model_base.hpp:5,
from Module.cpp:2:
/usr/lib/R/library/StanHeaders/include/stan/math/prim/fun/square.hpp:68:13: note: candidate: ‘template<class Container, stan::require_container_st<std::is_arithmetic, Container>* <anonymous> > auto stan::math
::square(const Container&)’
68 | inline auto square(const Container& x) {
| ^~~~~~
/usr/lib/R/library/StanHeaders/include/stan/math/prim/fun/square.hpp:68:13: note: template argument deduction/substitution failed:
In file included from /usr/include/c++/11.1.0/unordered_map:38,
from /usr/lib/R/library/Rcpp/include/Rcpp/platform/compiler.h:153,
from /usr/lib/R/library/Rcpp/include/Rcpp/r/headers.h:66,
from /usr/lib/R/library/Rcpp/include/RcppCommon.h:30,
from /usr/lib/R/library/Rcpp/include/Rcpp.h:27,
from Module.cpp:1:
/usr/include/c++/11.1.0/type_traits: In substitution of ‘template<bool _Cond, class _Tp> using enable_if_t = typename std::enable_if::type [with bool _Cond = false; _Tp = void]’:
/usr/lib/R/library/StanHeaders/include/stan/math/prim/meta/require_helpers.hpp:19:7: required by substitution of ‘template<class Check> using require_t = std::enable_if_t<Check::value> [with Check = std::in
tegral_constant<bool, false>]’
/usr/lib/R/library/StanHeaders/include/stan/math/prim/meta/is_container.hpp:26:1: required by substitution of ‘template<template<class ...> class TypeCheck, class ... Check> using require_container_st = sta
n::require_t<std::integral_constant<bool, stan::math::conjunction<std::integral_constant<bool, stan::math::disjunction<stan::is_eigen<typename std::decay<Check>::type>, stan::is_std_vector<typename std::decay
<Check>::type, void> >::value>..., TypeCheck<typename stan::scalar_type<Check, void>::type>...>::value> > [with TypeCheck = std::is_arithmetic; Check = {stan::math::var_value<double, void>}]’
/usr/lib/R/library/StanHeaders/include/stan/math/prim/fun/square.hpp:67:66: required from ‘stan::math::complex_base<ValueType>::complex_type& stan::math::complex_base<ValueType>::operator/=(const std::compl
ex<_Up>&) [with U = stan::math::var_value<double>; ValueType = stan::math::var_value<double>; stan::math::complex_base<ValueType>::complex_type = std::complex<stan::math::var_value<double> >]’
/usr/lib/R/library/StanHeaders/include/stan/math/prim/core/operator_division.hpp:24:5: required from ‘stan::complex_return_t<U, V> stan::math::internal::complex_divide(const U&, const V&) [with U = std::com
plex<stan::math::var_value<double> >; V = std::complex<stan::math::var_value<double> >; stan::complex_return_t<U, V> = std::complex<stan::math::var_value<double> >]’
/usr/lib/R/library/StanHeaders/include/stan/math/rev/core/operator_division.hpp:122:34: required from here
/usr/include/c++/11.1.0/type_traits:2514:11: error: no type named ‘type’ in ‘struct std::enable_if<false, void>’
2514 | using enable_if_t = typename enable_if<_Cond, _Tp>::type;
| ^~~~~~~~~~~
In file included from /usr/lib/R/library/StanHeaders/include/stan/math/prim/core/complex_base.hpp:4,
from /usr/lib/R/library/StanHeaders/include/stan/math/rev/core/std_complex.hpp:5,
from /usr/lib/R/library/StanHeaders/include/stan/math/rev/core/operator_division.hpp:11,
from /usr/lib/R/library/StanHeaders/include/stan/math/rev/core/operator_divide_equal.hpp:5,
from /usr/lib/R/library/StanHeaders/include/stan/math/rev/core.hpp:29,
from /usr/lib/R/library/StanHeaders/include/src/stan/model/model_base.hpp:5,
from Module.cpp:2:
/usr/lib/R/library/StanHeaders/include/stan/math/prim/core/complex_base.hpp: In instantiation of ‘stan::math::complex_base<ValueType>::complex_type& stan::math::complex_base<ValueType>::operator/=(const std::
complex<_Up>&) [with U = stan::math::var_value<double>; ValueType = stan::math::var_value<double>; stan::math::complex_base<ValueType>::complex_type = std::complex<stan::math::var_value<double> >]’:
/usr/lib/R/library/StanHeaders/include/stan/math/prim/core/operator_division.hpp:24:5: required from ‘stan::complex_return_t<U, V> stan::math::internal::complex_divide(const U&, const V&) [with U = std::com
plex<stan::math::var_value<double> >; V = std::complex<stan::math::var_value<double> >; stan::complex_return_t<U, V> = std::complex<stan::math::var_value<double> >]’
/usr/lib/R/library/StanHeaders/include/stan/math/rev/core/operator_division.hpp:122:34: required from here
/usr/lib/R/library/StanHeaders/include/stan/math/prim/fun/square.hpp:54:13: note: candidate: ‘template<class Container, stan::require_not_stan_scalar_t<Container>* <anonymous>, stan::require_not_var_matrix_t<
Container>* <anonymous>, stan::require_not_nonscalar_prim_or_rev_kernel_expression_t<Container>* <anonymous> > auto stan::math::square(const Container&)’
54 | inline auto square(const Container& x) {
| ^~~~~~
/usr/lib/R/library/StanHeaders/include/stan/math/prim/fun/square.hpp:54:13: note: template argument deduction/substitution failed:
/usr/lib/R/library/StanHeaders/include/stan/math/prim/fun/square.hpp:26:15: note: candidate: ‘double stan::math::square(double)’
26 | inline double square(double x) { return std::pow(x, 2); }
| ^~~~~~
/usr/lib/R/library/StanHeaders/include/stan/math/prim/fun/square.hpp:26:29: note: no known conversion for argument 1 from ‘stan::math::complex_base<stan::math::var_value<double> >::value_type’ {aka ‘stan::m
ath::var_value<double>’} to ‘double’
26 | inline double square(double x) { return std::pow(x, 2); }
| ~~~~~~~^
In file included from /usr/lib/R/library/StanHeaders/include/stan/math/rev/core/std_complex.hpp:5,
from /usr/lib/R/library/StanHeaders/include/stan/math/rev/core/operator_division.hpp:11,
from /usr/lib/R/library/StanHeaders/include/stan/math/rev/core/operator_divide_equal.hpp:5,
from /usr/lib/R/library/StanHeaders/include/stan/math/rev/core.hpp:29,
from /usr/lib/R/library/StanHeaders/include/src/stan/model/model_base.hpp:5,
from Module.cpp:2:
/usr/lib/R/library/StanHeaders/include/stan/math/prim/core/complex_base.hpp:207:57: error: no matching function for call to ‘square(stan::math::complex_base<stan::math::var_value<double> >::value_type)’
207 | value_type sum_sq_im = square(other.real()) + square(other.imag());
| ~~~~~~^~~~~~~~~~~~~~
In file included from /usr/lib/R/library/StanHeaders/include/stan/math/prim/core/complex_base.hpp:4,
from /usr/lib/R/library/StanHeaders/include/stan/math/rev/core/std_complex.hpp:5,
from /usr/lib/R/library/StanHeaders/include/stan/math/rev/core/operator_division.hpp:11,
from /usr/lib/R/library/StanHeaders/include/stan/math/rev/core/operator_divide_equal.hpp:5,
from /usr/lib/R/library/StanHeaders/include/stan/math/rev/core.hpp:29,
from /usr/lib/R/library/StanHeaders/include/src/stan/model/model_base.hpp:5,
from Module.cpp:2:
/usr/lib/R/library/StanHeaders/include/stan/math/prim/fun/square.hpp:68:13: note: candidate: ‘template<class Container, stan::require_container_st<std::is_arithmetic, Container>* <anonymous> > auto stan::math
::square(const Container&)’
68 | inline auto square(const Container& x) {
| ^~~~~~
/usr/lib/R/library/StanHeaders/include/stan/math/prim/fun/square.hpp:68:13: note: template argument deduction/substitution failed:
/usr/lib/R/library/StanHeaders/include/stan/math/prim/fun/square.hpp:54:13: note: candidate: ‘template<class Container, stan::require_not_stan_scalar_t<Container>* <anonymous>, stan::require_not_var_matrix_t<
Container>* <anonymous>, stan::require_not_nonscalar_prim_or_rev_kernel_expression_t<Container>* <anonymous> > auto stan::math::square(const Container&)’
54 | inline auto square(const Container& x) {
| ^~~~~~
/usr/lib/R/library/StanHeaders/include/stan/math/prim/fun/square.hpp:54:13: note: template argument deduction/substitution failed:
/usr/lib/R/library/StanHeaders/include/stan/math/prim/fun/square.hpp:26:15: note: candidate: ‘double stan::math::square(double)’
26 | inline double square(double x) { return std::pow(x, 2); }
| ^~~~~~
/usr/lib/R/library/StanHeaders/include/stan/math/prim/fun/square.hpp:26:29: note: no known conversion for argument 1 from ‘stan::math::complex_base<stan::math::var_value<double> >::value_type’ {aka ‘stan::m
ath::var_value<double>’} to ‘double’
26 | inline double square(double x) { return std::pow(x, 2); }
| ~~~~~~~^
make: *** [/usr/lib64/R/etc/Makeconf:177: Module.o] Error 1 |
@hsbadr ty! Yes I'm seeing that locally as well. The commit I just made should handle that. The problem was that the complex number stuff was using |
Yes, it compiles successfully now. Thanks! |
This PR is also stalled because of a Jenkins hiccup as it looks. @hsbadr sorry for off-topic... but is there any hope for rstan to hit CRAN any time soon? I am asking as our next major production system is gearing up for an update and I would like to get rstan in a more recent version. If you have any suggestions let me know... should I ping you on discourse (what was your username there again?). |
…4.1 (tags/RELEASE_600/final)
Not until @bgoodri gives it sometime or agrees on creating a new package.
You may add our r-packages to your repos: install.packages("StanHeaders", repos = c("https://mc-stan.org/r-packages/", getOption("repos")))
install.packages("rstan", repos = c("https://mc-stan.org/r-packages/", getOption("repos"))) or install the
|
…4.1 (tags/RELEASE_600/final)
Aight @andrjohns I think this is ready! Couple notes to point out
Besides that I think it's a pretty normal PR! |
Jenkins Console Log Machine informationProductName: Mac OS X ProductVersion: 10.11.6 BuildVersion: 15G22010CPU: G++: Clang: |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Looking good! My main feedback is that I think most of the additional rev
overloads wouldn't be necessary, as long as the existing specifications were modified to use Math functions that are scalar/Matrix agnostic
stan/math/prim/fun/divide.hpp
Outdated
*/ | ||
template <typename Scalar, typename Mat, require_eigen_t<Mat>* = nullptr, | ||
require_t<bool_constant<std::is_arithmetic<Scalar>::value | ||
|| is_fvar<Scalar>::value>>* = nullptr, |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
I wonder if you could replace all of these specialisations with:
template <typename Ta, typename Tb>
inline auto divide(const Ta& a, const Tb& b) {
return (as_array_or_scalar(a) / as_array_or_scalar(b)).matrix();
}
With the appropriate requires
to disambiguate from the varmat
overloads?
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Oh absolutely! Much cleaner!
stan/math/rev/fun/inv.hpp
Outdated
template <typename T, require_eigen_t<T>* = nullptr> | ||
inline auto inv(const var_value<T>& a) { | ||
auto denom = to_arena(a.val().array().square()); | ||
return make_callback_var(inv(a.val()), [a, denom](auto& vi) mutable { | ||
a.adj().array() -= vi.adj().array() / denom; | ||
}); | ||
} | ||
|
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
If you use square()
, you can combine both overloads:
template <typename T>
inline auto inv(const var_value<T>& a) {
auto denom = to_arena(square(a.val()));
return make_callback_var(inv(a.val()), [a, denom](auto& vi) mutable {
a.adj() -= vi.adj() / denom;
});
}
stan/math/rev/fun/inv_cloglog.hpp
Outdated
@@ -31,6 +31,15 @@ inline var inv_cloglog(const var& a) { | |||
}); | |||
} | |||
|
|||
template <typename T, require_eigen_t<T>* = nullptr> | |||
inline auto inv_cloglog(const var_value<T>& a) { | |||
auto precomp_exp = (a.val().array() - a.val().array().exp()).exp(); |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Same here. if you use the Stan Math exp()
then you can collapse the overloads
stan/math/rev/fun/inv_sqrt.hpp
Outdated
@@ -35,6 +36,14 @@ inline var inv_sqrt(const var& a) { | |||
}); | |||
} | |||
|
|||
template <typename T, require_eigen_t<T>* = nullptr> | |||
inline auto inv_sqrt(const var_value<T>& a) { | |||
auto denom = to_arena(a.val().array() * a.val().array().sqrt()); |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
I'll stop adding comments for this, but same here
Ty on the catch for reducing the number of overloads, I had to make a little change to |
…4.1 (tags/RELEASE_600/final)
…4.1 (tags/RELEASE_600/final)
…4.1 (tags/RELEASE_600/final)
@andrjohns I think this is ready for review! |
Jenkins Console Log Machine informationProductName: Mac OS X ProductVersion: 10.11.6 BuildVersion: 15G22010CPU: G++: Clang: |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Just some super minor comments, almost done!
@@ -204,7 +204,8 @@ class complex_base { | |||
template <typename U> | |||
complex_type& operator/=(const std::complex<U>& other) { | |||
using stan::math::square; | |||
value_type sum_sq_im = square(other.real()) + square(other.imag()); | |||
value_type sum_sq_im | |||
= (other.real() * other.real()) + (other.imag() * other.imag()); |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Might be worth opening a new issue about the square
oddities so it gets looked at
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Yeah I'm going to open a bigger issue about our include orders.
* Extends std::true_type if all the provided types are either fvar or | ||
* an arithmetic type, extends std::false_type otherwise. |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Need to update the doc on this
@@ -295,6 +295,13 @@ class var_value<T, require_floating_point_t<T>> { | |||
} | |||
}; | |||
|
|||
namespace internal { | |||
template <typename T> | |||
using require_matrix_var_value = require_t<bool_constant< |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
I'm not a huge fan of defining additional require_
s within individual headers, but if you don't see this being used anywhere else then it's fine to leave
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
So this one is very particular. I'd like to leave it here till unless we need it in other places in the future. This was just in the vari<>
template at first and hasn't been used anywhere else. I just pulled it out because it's a bit cleaner looking.
Jenkins Console Log Machine informationProductName: Mac OS X ProductVersion: 10.11.6 BuildVersion: 15G22010CPU: G++: Clang: |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
LGTM!
Summary
Adds several unary functions for var matrix as well as overloads for division.
Tests
Tests added for each new function
Side Effects
Nopes
Release notes
Adds several unary functions for var as well as division.
Checklist
Math issue How to add static matrix? #1805
Copyright holder: Steve Bronder
The copyright holder is typically you or your assignee, such as a university or company. By submitting this pull request, the copyright holder is agreeing to the license the submitted work under the following licenses:
- Code: BSD 3-clause (https://opensource.org/licenses/BSD-3-Clause)
- Documentation: CC-BY 4.0 (https://creativecommons.org/licenses/by/4.0/)
the basic tests are passing
./runTests.py test/unit
)make test-headers
)make test-math-dependencies
)make doxygen
)make cpplint
)the code is written in idiomatic C++ and changes are documented in the doxygen
the new changes are tested