Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

adds unary functions for var<Matrix> #2527

Merged
merged 25 commits into from
Jul 18, 2021
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
Show all changes
25 commits
Select commit Hold shift + click to select a range
4e263ff
adds unary functions for var<Matrix>
SteveBronder Jul 1, 2021
711bf32
[Jenkins] auto-formatting by clang-format version 6.0.0-1ubuntu2~16.0…
stan-buildbot Jul 1, 2021
87a66ec
missing append to make/local in jenkins for CPU opencl
SteveBronder Jul 1, 2021
e80abe9
fixup operator/ to work for var<Matrix> types
SteveBronder Jul 2, 2021
26b5a78
[Jenkins] auto-formatting by clang-format version 6.0.0-1ubuntu2~16.0…
stan-buildbot Jul 2, 2021
4b434b0
fix headers
SteveBronder Jul 6, 2021
c534e81
Revert "fix headers"
SteveBronder Jul 6, 2021
77c0b2b
fix headers
SteveBronder Jul 6, 2021
329f76a
fix template for square
SteveBronder Jul 6, 2021
1b4b135
update cos
SteveBronder Jul 6, 2021
5b84c99
trying fvar and arithmetic specialization for divide()
SteveBronder Jul 6, 2021
a50a6c0
[Jenkins] auto-formatting by clang-format version 6.0.0-1ubuntu2~16.0…
stan-buildbot Jul 7, 2021
2895732
moves is_fvar_or_arithmetic to it's own header
SteveBronder Jul 7, 2021
f0d6a83
[Jenkins] auto-formatting by clang-format version 6.0.0-1ubuntu2~16.0…
stan-buildbot Jul 7, 2021
c8bfbf1
update to reduce number of overloads for var<matrix> functions
SteveBronder Jul 13, 2021
8de7f58
merge to develop
SteveBronder Jul 13, 2021
1a8a635
[Jenkins] auto-formatting by clang-format version 6.0.0-1ubuntu2~16.0…
stan-buildbot Jul 13, 2021
631b505
update to make additions as double double
SteveBronder Jul 13, 2021
90373bf
fix array err when adding scalar and eigen type
SteveBronder Jul 13, 2021
0d50121
[Jenkins] auto-formatting by clang-format version 6.0.0-1ubuntu2~16.0…
stan-buildbot Jul 13, 2021
a1f89e8
fix log
SteveBronder Jul 13, 2021
b58c301
update to cover templates for OpenCL types
SteveBronder Jul 14, 2021
35f8930
[Jenkins] auto-formatting by clang-format version 6.0.0-1ubuntu2~16.0…
stan-buildbot Jul 14, 2021
4ebad3b
delete is_fvar_or_arithmetic
SteveBronder Jul 14, 2021
c9c6f0d
update docs for is_stan_scalar_or_eigen
SteveBronder Jul 16, 2021
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
3 changes: 2 additions & 1 deletion stan/math/prim/core/complex_base.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -204,7 +204,8 @@ class complex_base {
template <typename U>
complex_type& operator/=(const std::complex<U>& other) {
using stan::math::square;
value_type sum_sq_im = square(other.real()) + square(other.imag());
value_type sum_sq_im
= (other.real() * other.real()) + (other.imag() * other.imag());
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Might be worth opening a new issue about the square oddities so it gets looked at

Copy link
Collaborator Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Yeah I'm going to open a bigger issue about our include orders.

value_type re_temp = (re_ * other.real() + im_ * other.imag()) / sum_sq_im;
im_ = (im_ * other.real() - re_ * other.imag()) / sum_sq_im;
re_ = re_temp;
Expand Down
12 changes: 12 additions & 0 deletions stan/math/prim/fun/as_array_or_scalar.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -20,6 +20,18 @@ inline T as_array_or_scalar(T&& v) {
return std::forward<T>(v);
}

/**
* Returns a reference to rvalue specified input value.
*
* @tparam T Type of element.
* @param v Specified value.
* @return Same value.
*/
template <typename T, require_stan_scalar_t<T>* = nullptr>
inline T& as_array_or_scalar(T& v) {
return v;
}

/**
* Returns specified input value.
*
Expand Down
12 changes: 11 additions & 1 deletion stan/math/prim/fun/cos.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -13,6 +13,16 @@
namespace stan {
namespace math {

/**
* Arithmetic version of `cos()`
* @tparam T An `Arithmetic` type.
* @param x Arithmetic scalar.
*/
template <typename T, require_arithmetic_t<T>* = nullptr>
inline auto cos(T x) {
return std::cos(x);
}

/**
* Structure to wrap `cos()` so it can be vectorized.
*
Expand All @@ -23,7 +33,6 @@ namespace math {
struct cos_fun {
template <typename T>
static inline T fun(const T& x) {
using std::cos;
return cos(x);
}
};
Expand All @@ -38,6 +47,7 @@ struct cos_fun {
*/
template <typename Container,
require_not_container_st<std::is_arithmetic, Container>* = nullptr,
require_not_stan_scalar_t<Container>* = nullptr,
require_not_var_matrix_t<Container>* = nullptr,
require_all_not_nonscalar_prim_or_rev_kernel_expression_t<
Container>* = nullptr>
Expand Down
27 changes: 6 additions & 21 deletions stan/math/prim/fun/divide.hpp
Original file line number Diff line number Diff line change
@@ -1,9 +1,10 @@
#ifndef STAN_MATH_PRIM_FUN_DIVIDE_HPP
#define STAN_MATH_PRIM_FUN_DIVIDE_HPP

#include <stan/math/prim/fun/Eigen.hpp>
#include <stan/math/prim/meta.hpp>
#include <stan/math/prim/err.hpp>
#include <stan/math/prim/fun/Eigen.hpp>
#include <stan/math/prim/fun/as_array_or_scalar.hpp>
#include <cstddef>
#include <cstdlib>
#include <type_traits>
Expand Down Expand Up @@ -40,26 +41,10 @@ inline int divide(int x, int y) {
* @param[in] c specified scalar
* @return matrix divided by the scalar
*/
template <typename Mat, typename Scal, typename = require_eigen_t<Mat>,
require_stan_scalar_t<Scal>* = nullptr,
require_all_not_var_t<scalar_type_t<Mat>, Scal>* = nullptr>
inline auto divide(const Mat& m, Scal c) {
return m / c;
}

/**
* Return matrix divided by matrix.
*
* @tparam Mat1 type of the matrix or expression
* @tparam Mat2 type of the matrix or expression
* @param[in] m specified matrix or expression
* @param[in] c specified matrix or expression
* @return matrix divided elementwise by `c`
*/
template <typename Mat1, typename Mat2,
require_all_eigen_t<Mat1, Mat2>* = nullptr>
inline auto divide(const Mat1& m, const Mat2& c) {
return (m.array() / c.array()).matrix();
template <typename T1, typename T2, require_any_eigen_t<T1, T2>* = nullptr,
require_all_not_st_var<T1, T2>* = nullptr>
inline auto divide(const T1& m, const T2& c) {
return (as_array_or_scalar(m) / as_array_or_scalar(c)).matrix();
}

} // namespace math
Expand Down
1 change: 1 addition & 0 deletions stan/math/prim/fun/inv_cloglog.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -73,6 +73,7 @@ struct inv_cloglog_fun {
*/
template <typename Container,
require_not_container_st<std::is_arithmetic, Container>* = nullptr,
require_not_var_matrix_t<Container>* = nullptr,
require_all_not_nonscalar_prim_or_rev_kernel_expression_t<
Container>* = nullptr>
inline auto inv_cloglog(const Container& x) {
Expand Down
12 changes: 9 additions & 3 deletions stan/math/prim/fun/inv_sqrt.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -12,6 +12,11 @@
namespace stan {
namespace math {

template <typename T, require_stan_scalar_t<T>* = nullptr>
inline auto inv_sqrt(T x) {
using std::sqrt;
return inv(sqrt(x));
}
/**
* Structure to wrap `1 / sqrt(x)` so that it can be vectorized.
*
Expand All @@ -22,8 +27,7 @@ namespace math {
struct inv_sqrt_fun {
template <typename T>
static inline T fun(const T& x) {
using std::sqrt;
return inv(sqrt(x));
return inv_sqrt(x);
}
};

Expand All @@ -37,6 +41,8 @@ struct inv_sqrt_fun {
*/
template <typename Container,
require_not_container_st<std::is_arithmetic, Container>* = nullptr,
require_not_var_matrix_t<Container>* = nullptr,
require_not_stan_scalar_t<Container>* = nullptr,
require_all_not_nonscalar_prim_or_rev_kernel_expression_t<
Container>* = nullptr>
inline auto inv_sqrt(const Container& x) {
Expand All @@ -51,7 +57,7 @@ inline auto inv_sqrt(const Container& x) {
* @param x Container
* @return inverse square root each variable in the container.
*/
template <typename Container,
template <typename Container, require_not_var_matrix_t<Container>* = nullptr,
require_container_st<std::is_arithmetic, Container>* = nullptr>
inline auto inv_sqrt(const Container& x) {
return apply_vector_unary<Container>::apply(
Expand Down
6 changes: 4 additions & 2 deletions stan/math/prim/fun/lambert_w.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -78,7 +78,8 @@ struct lambert_wm1_fun {
* @return value of the W0 branch of the Lambert W function for each value in x
* @throw std::domain_error if x is less than or equal to `-e^(-1)`
*/
template <typename T, require_not_stan_scalar_t<T>* = nullptr>
template <typename T, require_not_stan_scalar_t<T>* = nullptr,
require_not_var_matrix_t<T>* = nullptr>
inline auto lambert_w0(const T& x) {
return apply_scalar_unary<internal::lambert_w0_fun, T>::apply(x);
}
Expand All @@ -92,7 +93,8 @@ inline auto lambert_w0(const T& x) {
* @throw std::domain_error if x is less than or equal to `-e^(-1)` or greater
* than or equal to 0
*/
template <typename T, require_not_stan_scalar_t<T>* = nullptr>
template <typename T, require_not_stan_scalar_t<T>* = nullptr,
require_not_var_matrix_t<T>* = nullptr>
inline auto lambert_wm1(const T& x) {
return apply_scalar_unary<internal::lambert_wm1_fun, T>::apply(x);
}
Expand Down
2 changes: 1 addition & 1 deletion stan/math/prim/fun/lgamma.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -114,7 +114,7 @@ struct lgamma_fun {
* applied to each value in x.
* @throw std::domain_error if any value is a negative integer or 0.
*/
template <typename T,
template <typename T, require_not_var_matrix_t<T>* = nullptr,
require_not_nonscalar_prim_or_rev_kernel_expression_t<T>* = nullptr>
inline auto lgamma(const T& x) {
return apply_scalar_unary<lgamma_fun, T>::apply(x);
Expand Down
2 changes: 1 addition & 1 deletion stan/math/prim/fun/log10.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -35,7 +35,7 @@ struct log10_fun {
* @return Log base-10 applied to each value in x.
*/
template <
typename Container,
typename Container, require_not_var_matrix_t<Container>* = nullptr,
require_not_container_st<std::is_arithmetic, Container>* = nullptr,
require_not_nonscalar_prim_or_rev_kernel_expression_t<Container>* = nullptr>
inline auto log10(const Container& x) {
Expand Down
5 changes: 3 additions & 2 deletions stan/math/prim/fun/log1m_exp.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -77,8 +77,9 @@ struct log1m_exp_fun {
* @param x container
* @return Natural log of (1 - exp()) applied to each value in x.
*/
template <typename T, require_all_not_nonscalar_prim_or_rev_kernel_expression_t<
T>* = nullptr>
template <
typename T, require_not_var_matrix_t<T>* = nullptr,
require_all_not_nonscalar_prim_or_rev_kernel_expression_t<T>* = nullptr>
inline auto log1m_exp(const T& x) {
return apply_scalar_unary<log1m_exp_fun, T>::apply(x);
}
Expand Down
2 changes: 1 addition & 1 deletion stan/math/prim/fun/log1m_inv_logit.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -80,7 +80,7 @@ struct log1m_inv_logit_fun {
* @param x container
* @return Elementwise log1m_inv_logit of members of container.
*/
template <typename T,
template <typename T, require_not_var_matrix_t<T>* = nullptr,
require_not_nonscalar_prim_or_rev_kernel_expression_t<T>* = nullptr>
inline typename apply_scalar_unary<log1m_inv_logit_fun, T>::return_t
log1m_inv_logit(const T& x) {
Expand Down
2 changes: 1 addition & 1 deletion stan/math/prim/fun/log2.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -44,7 +44,7 @@ struct log2_fun {
* @param x container
* @return elementwise log2 of container elements
*/
template <typename T,
template <typename T, require_not_var_matrix_t<T>* = nullptr,
require_not_nonscalar_prim_or_rev_kernel_expression_t<T>* = nullptr>
inline auto log2(const T& x) {
return apply_scalar_unary<log2_fun, T>::apply(x);
Expand Down
1 change: 1 addition & 0 deletions stan/math/prim/fun/logit.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -86,6 +86,7 @@ struct logit_fun {
template <
typename Container,
require_not_container_st<std::is_arithmetic, Container>* = nullptr,
require_not_var_matrix_t<Container>* = nullptr,
require_not_nonscalar_prim_or_rev_kernel_expression_t<Container>* = nullptr>
inline auto logit(const Container& x) {
return apply_scalar_unary<logit_fun, Container>::apply(x);
Expand Down
5 changes: 2 additions & 3 deletions stan/math/prim/fun/square.hpp
Original file line number Diff line number Diff line change
@@ -1,9 +1,8 @@
#ifndef STAN_MATH_PRIM_FUN_SQUARE_HPP
#define STAN_MATH_PRIM_FUN_SQUARE_HPP

#include <stan/math/prim/meta.hpp>
#include <stan/math/prim/fun/Eigen.hpp>
#include <stan/math/prim/fun/square.hpp>
#include <stan/math/prim/meta.hpp>
#include <stan/math/prim/functor/apply_scalar_unary.hpp>
#include <stan/math/prim/functor/apply_vector_unary.hpp>
#include <cmath>
Expand Down Expand Up @@ -48,7 +47,7 @@ struct square_fun {
* @return Each value in x squared.
*/
template <
typename Container,
typename Container, require_not_stan_scalar_t<Container>* = nullptr,
require_not_container_st<std::is_arithmetic, Container>* = nullptr,
require_not_var_matrix_t<Container>* = nullptr,
require_not_nonscalar_prim_or_rev_kernel_expression_t<Container>* = nullptr>
Expand Down
1 change: 1 addition & 0 deletions stan/math/prim/meta.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -215,6 +215,7 @@
#include <stan/math/prim/meta/is_vector.hpp>
#include <stan/math/prim/meta/is_vector_like.hpp>
#include <stan/math/prim/meta/is_stan_scalar.hpp>
#include <stan/math/prim/meta/is_stan_scalar_or_eigen.hpp>
#include <stan/math/prim/meta/partials_return_type.hpp>
#include <stan/math/prim/meta/partials_type.hpp>
#include <stan/math/prim/meta/plain_type.hpp>
Expand Down
30 changes: 30 additions & 0 deletions stan/math/prim/meta/is_stan_scalar_or_eigen.hpp
Original file line number Diff line number Diff line change
@@ -0,0 +1,30 @@
#ifndef STAN_MATH_PRIM_META_IS_STAN_SCALAR_OR_EIGEN_HPP
#define STAN_MATH_PRIM_META_IS_STAN_SCALAR_OR_EIGEN_HPP

#include <stan/math/prim/meta/bool_constant.hpp>
#include <stan/math/prim/meta/is_stan_scalar.hpp>
#include <stan/math/prim/meta/is_eigen.hpp>
#include <stan/math/prim/meta/scalar_type.hpp>
#include <stan/math/prim/meta/conjunction.hpp>
#include <stan/math/prim/meta/require_helpers.hpp>

#include <type_traits>

namespace stan {

/** \ingroup type_trait
* Extends std::true_type if all the provided types are either a Stan Scalar
* type or a type inheriting from `EigenBase`.
*/
template <typename T>
using is_stan_scalar_or_eigen
= bool_constant<is_stan_scalar<std::decay_t<T>>::value
|| is_eigen<std::decay_t<T>>::value>;

STAN_ADD_REQUIRE_UNARY(stan_scalar_or_eigen, is_stan_scalar_or_eigen,
require_stan_scalar_real);
STAN_ADD_REQUIRE_UNARY_INNER(stan_scalar_or_eigen, is_stan_scalar_or_eigen,
require_stan_scalar_real);

} // namespace stan
#endif
21 changes: 19 additions & 2 deletions stan/math/rev/core/operator_divide_equal.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -10,7 +10,7 @@ namespace math {
template <typename T>
inline var_value<T>& var_value<T, require_floating_point_t<T>>::operator/=(
const var_value<T>& b) {
vi_ = new internal::divide_vv_vari(vi_, b.vi_);
vi_ = divide(*this, b).vi_;
return *this;
}

Expand All @@ -20,7 +20,24 @@ inline var_value<T>& var_value<T, require_floating_point_t<T>>::operator/=(
if (b == 1.0) {
return *this;
}
vi_ = new internal::divide_vd_vari(vi_, b);
vi_ = divide(*this, b).vi_;
return *this;
}

template <typename T>
inline var_value<T>& var_value<T, internal::require_matrix_var_value<T>>::
operator/=(const var_value<T>& b) {
vi_ = divide(*this, b).vi_;
return *this;
}

template <typename T>
inline var_value<T>& var_value<T, internal::require_matrix_var_value<T>>::
operator/=(T b) {
if (b == 1.0) {
return *this;
}
vi_ = divide(*this, b).vi_;
return *this;
}

Expand Down
Loading