diff --git a/stan/math/prim/core/complex_base.hpp b/stan/math/prim/core/complex_base.hpp index 222a93009ef..c9993037145 100644 --- a/stan/math/prim/core/complex_base.hpp +++ b/stan/math/prim/core/complex_base.hpp @@ -204,7 +204,8 @@ class complex_base { template complex_type& operator/=(const std::complex& other) { using stan::math::square; - value_type sum_sq_im = square(other.real()) + square(other.imag()); + value_type sum_sq_im + = (other.real() * other.real()) + (other.imag() * other.imag()); value_type re_temp = (re_ * other.real() + im_ * other.imag()) / sum_sq_im; im_ = (im_ * other.real() - re_ * other.imag()) / sum_sq_im; re_ = re_temp; diff --git a/stan/math/prim/fun/as_array_or_scalar.hpp b/stan/math/prim/fun/as_array_or_scalar.hpp index 6e91323d447..46bafdd7ee9 100644 --- a/stan/math/prim/fun/as_array_or_scalar.hpp +++ b/stan/math/prim/fun/as_array_or_scalar.hpp @@ -20,6 +20,18 @@ inline T as_array_or_scalar(T&& v) { return std::forward(v); } +/** + * Returns a reference to rvalue specified input value. + * + * @tparam T Type of element. + * @param v Specified value. + * @return Same value. + */ +template * = nullptr> +inline T& as_array_or_scalar(T& v) { + return v; +} + /** * Returns specified input value. * diff --git a/stan/math/prim/fun/cos.hpp b/stan/math/prim/fun/cos.hpp index 82735e60f40..c0796a11dc0 100644 --- a/stan/math/prim/fun/cos.hpp +++ b/stan/math/prim/fun/cos.hpp @@ -13,6 +13,16 @@ namespace stan { namespace math { +/** + * Arithmetic version of `cos()` + * @tparam T An `Arithmetic` type. + * @param x Arithmetic scalar. + */ +template * = nullptr> +inline auto cos(T x) { + return std::cos(x); +} + /** * Structure to wrap `cos()` so it can be vectorized. * @@ -23,7 +33,6 @@ namespace math { struct cos_fun { template static inline T fun(const T& x) { - using std::cos; return cos(x); } }; @@ -38,6 +47,7 @@ struct cos_fun { */ template * = nullptr, + require_not_stan_scalar_t* = nullptr, require_not_var_matrix_t* = nullptr, require_all_not_nonscalar_prim_or_rev_kernel_expression_t< Container>* = nullptr> diff --git a/stan/math/prim/fun/divide.hpp b/stan/math/prim/fun/divide.hpp index 183a6e294a1..c15ed453953 100644 --- a/stan/math/prim/fun/divide.hpp +++ b/stan/math/prim/fun/divide.hpp @@ -1,9 +1,10 @@ #ifndef STAN_MATH_PRIM_FUN_DIVIDE_HPP #define STAN_MATH_PRIM_FUN_DIVIDE_HPP +#include #include #include -#include +#include #include #include #include @@ -40,26 +41,10 @@ inline int divide(int x, int y) { * @param[in] c specified scalar * @return matrix divided by the scalar */ -template , - require_stan_scalar_t* = nullptr, - require_all_not_var_t, Scal>* = nullptr> -inline auto divide(const Mat& m, Scal c) { - return m / c; -} - -/** - * Return matrix divided by matrix. - * - * @tparam Mat1 type of the matrix or expression - * @tparam Mat2 type of the matrix or expression - * @param[in] m specified matrix or expression - * @param[in] c specified matrix or expression - * @return matrix divided elementwise by `c` - */ -template * = nullptr> -inline auto divide(const Mat1& m, const Mat2& c) { - return (m.array() / c.array()).matrix(); +template * = nullptr, + require_all_not_st_var* = nullptr> +inline auto divide(const T1& m, const T2& c) { + return (as_array_or_scalar(m) / as_array_or_scalar(c)).matrix(); } } // namespace math diff --git a/stan/math/prim/fun/inv_cloglog.hpp b/stan/math/prim/fun/inv_cloglog.hpp index 6559ebba0db..daaec48c278 100644 --- a/stan/math/prim/fun/inv_cloglog.hpp +++ b/stan/math/prim/fun/inv_cloglog.hpp @@ -73,6 +73,7 @@ struct inv_cloglog_fun { */ template * = nullptr, + require_not_var_matrix_t* = nullptr, require_all_not_nonscalar_prim_or_rev_kernel_expression_t< Container>* = nullptr> inline auto inv_cloglog(const Container& x) { diff --git a/stan/math/prim/fun/inv_sqrt.hpp b/stan/math/prim/fun/inv_sqrt.hpp index d97b755a56f..d2ec2b01885 100644 --- a/stan/math/prim/fun/inv_sqrt.hpp +++ b/stan/math/prim/fun/inv_sqrt.hpp @@ -12,6 +12,11 @@ namespace stan { namespace math { +template * = nullptr> +inline auto inv_sqrt(T x) { + using std::sqrt; + return inv(sqrt(x)); +} /** * Structure to wrap `1 / sqrt(x)` so that it can be vectorized. * @@ -22,8 +27,7 @@ namespace math { struct inv_sqrt_fun { template static inline T fun(const T& x) { - using std::sqrt; - return inv(sqrt(x)); + return inv_sqrt(x); } }; @@ -37,6 +41,8 @@ struct inv_sqrt_fun { */ template * = nullptr, + require_not_var_matrix_t* = nullptr, + require_not_stan_scalar_t* = nullptr, require_all_not_nonscalar_prim_or_rev_kernel_expression_t< Container>* = nullptr> inline auto inv_sqrt(const Container& x) { @@ -51,7 +57,7 @@ inline auto inv_sqrt(const Container& x) { * @param x Container * @return inverse square root each variable in the container. */ -template * = nullptr, require_container_st* = nullptr> inline auto inv_sqrt(const Container& x) { return apply_vector_unary::apply( diff --git a/stan/math/prim/fun/lambert_w.hpp b/stan/math/prim/fun/lambert_w.hpp index e4c8b312409..c7b907ad204 100644 --- a/stan/math/prim/fun/lambert_w.hpp +++ b/stan/math/prim/fun/lambert_w.hpp @@ -78,7 +78,8 @@ struct lambert_wm1_fun { * @return value of the W0 branch of the Lambert W function for each value in x * @throw std::domain_error if x is less than or equal to `-e^(-1)` */ -template * = nullptr> +template * = nullptr, + require_not_var_matrix_t* = nullptr> inline auto lambert_w0(const T& x) { return apply_scalar_unary::apply(x); } @@ -92,7 +93,8 @@ inline auto lambert_w0(const T& x) { * @throw std::domain_error if x is less than or equal to `-e^(-1)` or greater * than or equal to 0 */ -template * = nullptr> +template * = nullptr, + require_not_var_matrix_t* = nullptr> inline auto lambert_wm1(const T& x) { return apply_scalar_unary::apply(x); } diff --git a/stan/math/prim/fun/lgamma.hpp b/stan/math/prim/fun/lgamma.hpp index 8c3bedea622..93896185525 100644 --- a/stan/math/prim/fun/lgamma.hpp +++ b/stan/math/prim/fun/lgamma.hpp @@ -114,7 +114,7 @@ struct lgamma_fun { * applied to each value in x. * @throw std::domain_error if any value is a negative integer or 0. */ -template * = nullptr, require_not_nonscalar_prim_or_rev_kernel_expression_t* = nullptr> inline auto lgamma(const T& x) { return apply_scalar_unary::apply(x); diff --git a/stan/math/prim/fun/log10.hpp b/stan/math/prim/fun/log10.hpp index 96ff37ee4aa..a5702086bfd 100644 --- a/stan/math/prim/fun/log10.hpp +++ b/stan/math/prim/fun/log10.hpp @@ -35,7 +35,7 @@ struct log10_fun { * @return Log base-10 applied to each value in x. */ template < - typename Container, + typename Container, require_not_var_matrix_t* = nullptr, require_not_container_st* = nullptr, require_not_nonscalar_prim_or_rev_kernel_expression_t* = nullptr> inline auto log10(const Container& x) { diff --git a/stan/math/prim/fun/log1m_exp.hpp b/stan/math/prim/fun/log1m_exp.hpp index e5023795d16..5243088649d 100644 --- a/stan/math/prim/fun/log1m_exp.hpp +++ b/stan/math/prim/fun/log1m_exp.hpp @@ -77,8 +77,9 @@ struct log1m_exp_fun { * @param x container * @return Natural log of (1 - exp()) applied to each value in x. */ -template * = nullptr> +template < + typename T, require_not_var_matrix_t* = nullptr, + require_all_not_nonscalar_prim_or_rev_kernel_expression_t* = nullptr> inline auto log1m_exp(const T& x) { return apply_scalar_unary::apply(x); } diff --git a/stan/math/prim/fun/log1m_inv_logit.hpp b/stan/math/prim/fun/log1m_inv_logit.hpp index bb7164c9bd9..4755da4f21e 100644 --- a/stan/math/prim/fun/log1m_inv_logit.hpp +++ b/stan/math/prim/fun/log1m_inv_logit.hpp @@ -80,7 +80,7 @@ struct log1m_inv_logit_fun { * @param x container * @return Elementwise log1m_inv_logit of members of container. */ -template * = nullptr, require_not_nonscalar_prim_or_rev_kernel_expression_t* = nullptr> inline typename apply_scalar_unary::return_t log1m_inv_logit(const T& x) { diff --git a/stan/math/prim/fun/log2.hpp b/stan/math/prim/fun/log2.hpp index 0bc2b3081b2..ef0dd4dd844 100644 --- a/stan/math/prim/fun/log2.hpp +++ b/stan/math/prim/fun/log2.hpp @@ -44,7 +44,7 @@ struct log2_fun { * @param x container * @return elementwise log2 of container elements */ -template * = nullptr, require_not_nonscalar_prim_or_rev_kernel_expression_t* = nullptr> inline auto log2(const T& x) { return apply_scalar_unary::apply(x); diff --git a/stan/math/prim/fun/logit.hpp b/stan/math/prim/fun/logit.hpp index f208d76fdae..affdae68b73 100644 --- a/stan/math/prim/fun/logit.hpp +++ b/stan/math/prim/fun/logit.hpp @@ -86,6 +86,7 @@ struct logit_fun { template < typename Container, require_not_container_st* = nullptr, + require_not_var_matrix_t* = nullptr, require_not_nonscalar_prim_or_rev_kernel_expression_t* = nullptr> inline auto logit(const Container& x) { return apply_scalar_unary::apply(x); diff --git a/stan/math/prim/fun/square.hpp b/stan/math/prim/fun/square.hpp index 8d165560afc..16d2d787fd7 100644 --- a/stan/math/prim/fun/square.hpp +++ b/stan/math/prim/fun/square.hpp @@ -1,9 +1,8 @@ #ifndef STAN_MATH_PRIM_FUN_SQUARE_HPP #define STAN_MATH_PRIM_FUN_SQUARE_HPP -#include #include -#include +#include #include #include #include @@ -48,7 +47,7 @@ struct square_fun { * @return Each value in x squared. */ template < - typename Container, + typename Container, require_not_stan_scalar_t* = nullptr, require_not_container_st* = nullptr, require_not_var_matrix_t* = nullptr, require_not_nonscalar_prim_or_rev_kernel_expression_t* = nullptr> diff --git a/stan/math/prim/meta.hpp b/stan/math/prim/meta.hpp index adc3b92f12b..2fa62a1e55e 100644 --- a/stan/math/prim/meta.hpp +++ b/stan/math/prim/meta.hpp @@ -215,6 +215,7 @@ #include #include #include +#include #include #include #include diff --git a/stan/math/prim/meta/is_stan_scalar_or_eigen.hpp b/stan/math/prim/meta/is_stan_scalar_or_eigen.hpp new file mode 100644 index 00000000000..00693a84172 --- /dev/null +++ b/stan/math/prim/meta/is_stan_scalar_or_eigen.hpp @@ -0,0 +1,30 @@ +#ifndef STAN_MATH_PRIM_META_IS_STAN_SCALAR_OR_EIGEN_HPP +#define STAN_MATH_PRIM_META_IS_STAN_SCALAR_OR_EIGEN_HPP + +#include +#include +#include +#include +#include +#include + +#include + +namespace stan { + +/** \ingroup type_trait + * Extends std::true_type if all the provided types are either a Stan Scalar + * type or a type inheriting from `EigenBase`. + */ +template +using is_stan_scalar_or_eigen + = bool_constant>::value + || is_eigen>::value>; + +STAN_ADD_REQUIRE_UNARY(stan_scalar_or_eigen, is_stan_scalar_or_eigen, + require_stan_scalar_real); +STAN_ADD_REQUIRE_UNARY_INNER(stan_scalar_or_eigen, is_stan_scalar_or_eigen, + require_stan_scalar_real); + +} // namespace stan +#endif diff --git a/stan/math/rev/core/operator_divide_equal.hpp b/stan/math/rev/core/operator_divide_equal.hpp index 23687c7011e..a8176552b15 100644 --- a/stan/math/rev/core/operator_divide_equal.hpp +++ b/stan/math/rev/core/operator_divide_equal.hpp @@ -10,7 +10,7 @@ namespace math { template inline var_value& var_value>::operator/=( const var_value& b) { - vi_ = new internal::divide_vv_vari(vi_, b.vi_); + vi_ = divide(*this, b).vi_; return *this; } @@ -20,7 +20,24 @@ inline var_value& var_value>::operator/=( if (b == 1.0) { return *this; } - vi_ = new internal::divide_vd_vari(vi_, b); + vi_ = divide(*this, b).vi_; + return *this; +} + +template +inline var_value& var_value>:: +operator/=(const var_value& b) { + vi_ = divide(*this, b).vi_; + return *this; +} + +template +inline var_value& var_value>:: +operator/=(T b) { + if (b == 1.0) { + return *this; + } + vi_ = divide(*this, b).vi_; return *this; } diff --git a/stan/math/rev/core/operator_division.hpp b/stan/math/rev/core/operator_division.hpp index c61013f6328..bf0307c6135 100644 --- a/stan/math/rev/core/operator_division.hpp +++ b/stan/math/rev/core/operator_division.hpp @@ -1,54 +1,25 @@ #ifndef STAN_MATH_REV_CORE_OPERATOR_DIVISION_HPP #define STAN_MATH_REV_CORE_OPERATOR_DIVISION_HPP +#include #include #include #include #include +#include #include #include -#include -#include -#include #include #include #include -#include +#include +#include #include #include namespace stan { namespace math { -namespace internal { -// (dividend/divisor)' = dividend' * (1 / divisor) - divisor' * (dividend / -// [divisor * divisor]) -class divide_vv_vari final : public op_vv_vari { - public: - divide_vv_vari(vari* dividend_vi, vari* divisor_vi) - : op_vv_vari(dividend_vi->val_ / divisor_vi->val_, dividend_vi, - divisor_vi) {} - void chain() { - avi_->adj_ += adj_ / bvi_->val_; - bvi_->adj_ -= adj_ * avi_->val_ / (bvi_->val_ * bvi_->val_); - } -}; - -class divide_vd_vari final : public op_vd_vari { - public: - divide_vd_vari(vari* dividend_vi, double divisor) - : op_vd_vari(dividend_vi->val_ / divisor, dividend_vi, divisor) {} - void chain() { avi_->adj_ += adj_ / bd_; } -}; - -class divide_dv_vari final : public op_dv_vari { - public: - divide_dv_vari(double dividend, vari* divisor_vi) - : op_dv_vari(dividend / divisor_vi->val_, dividend, divisor_vi) {} - void chain() { bvi_->adj_ -= adj_ * ad_ / (bvi_->val_ * bvi_->val_); } -}; -} // namespace internal - /** * Division operator for two variables (C++). * @@ -88,7 +59,12 @@ class divide_dv_vari final : public op_dv_vari { * second. */ inline var operator/(const var& dividend, const var& divisor) { - return {new internal::divide_vv_vari(dividend.vi_, divisor.vi_)}; + return make_callback_var( + dividend.val() / divisor.val(), [dividend, divisor](auto&& vi) { + dividend.adj() += vi.adj() / divisor.val(); + divisor.adj() + -= vi.adj() * dividend.val() / (divisor.val() * divisor.val()); + }); } /** @@ -109,7 +85,9 @@ inline var operator/(const var& dividend, Arith divisor) { if (divisor == 1.0) { return dividend; } - return {new internal::divide_vd_vari(dividend.vi_, divisor)}; + return make_callback_var( + dividend.val() / divisor, + [dividend, divisor](auto&& vi) { dividend.adj() += vi.adj() / divisor; }); } /** @@ -126,7 +104,156 @@ inline var operator/(const var& dividend, Arith divisor) { */ template * = nullptr> inline var operator/(Arith dividend, const var& divisor) { - return {new internal::divide_dv_vari(dividend, divisor.vi_)}; + return make_callback_var( + dividend / divisor.val(), [dividend, divisor](auto&& vi) { + divisor.adj() -= vi.adj() * dividend / (divisor.val() * divisor.val()); + }); +} + +/** + * Return matrix divided by scalar. + * + * @tparam Mat A type inheriting from `EigenBase` with an `Arithmetic` scalar + * type. + * @param[in] m specified matrix or expression + * @param[in] c specified scalar + * @return matrix divided by the scalar + */ +template * = nullptr, + require_stan_scalar_t* = nullptr, + require_all_st_var_or_arithmetic* = nullptr, + require_any_st_var* = nullptr> +inline auto divide(const Mat& m, Scalar c) { + if (!is_constant::value && !is_constant::value) { + arena_t> arena_m = m; + var arena_c = c; + auto inv_c = (1.0 / arena_c.val()); + arena_t> res = inv_c * arena_m.val(); + reverse_pass_callback([arena_c, inv_c, arena_m, res]() mutable { + auto inv_times_adj = (inv_c * res.adj().array()).eval(); + arena_c.adj() -= (inv_times_adj * res.val().array()).sum(); + arena_m.adj().array() += inv_times_adj; + }); + return promote_scalar_t(res); + } else if (!is_constant::value) { + arena_t> arena_m = m; + auto inv_c = (1.0 / value_of(c)); + arena_t> res = inv_c * arena_m.val(); + reverse_pass_callback([inv_c, arena_m, res]() mutable { + arena_m.adj().array() += inv_c * res.adj_op().array(); + }); + return promote_scalar_t(res); + } else { + var arena_c = c; + auto inv_c = (1.0 / arena_c.val()); + arena_t> res = inv_c * value_of(m).array(); + reverse_pass_callback([arena_c, inv_c, res]() mutable { + arena_c.adj() -= inv_c * (res.adj().array() * res.val().array()).sum(); + }); + return promote_scalar_t(res); + } +} + +/** + * Return scalar divided by matrix. + * + * @tparam Mat Either a type inheriting from `EigenBase` with a scalar type of + * `var` or a `var_value` with type `T` inheriting from `EigenBase`. + * @param[in] m specified matrix or expression + * @param[in] c specified scalar + * @return matrix divided by the scalar + */ +template * = nullptr, + require_stan_scalar_t* = nullptr, + require_all_st_var_or_arithmetic* = nullptr, + require_any_st_var* = nullptr> +inline auto divide(Scalar c, const Mat& m) { + if (!is_constant::value && !is_constant::value) { + arena_t> arena_m = m; + auto inv_m = to_arena(arena_m.val().array().inverse()); + var arena_c = c; + arena_t> res = arena_c.val() * inv_m; + reverse_pass_callback([arena_c, inv_m, arena_m, res]() mutable { + auto inv_times_res = (inv_m * res.adj().array()).eval(); + arena_m.adj().array() -= inv_times_res * res.val().array(); + arena_c.adj() += (inv_times_res).sum(); + }); + return promote_scalar_t(res); + } else if (!is_constant::value) { + arena_t> arena_m = m; + auto inv_m = to_arena(arena_m.val().array().inverse()); + arena_t> res = value_of(c) * inv_m; + reverse_pass_callback([inv_m, arena_m, res]() mutable { + arena_m.adj().array() -= inv_m * res.adj().array() * res.val().array(); + }); + return promote_scalar_t(res); + } else { + auto inv_m = to_arena(value_of(m).array().inverse()); + var arena_c = c; + arena_t> res = arena_c.val() * inv_m; + reverse_pass_callback([arena_c, inv_m, res]() mutable { + arena_c.adj() += (inv_m * res.adj().array()).sum(); + }); + return promote_scalar_t(res); + } +} + +/// + +/** + * Return a matrix divided by a matrix elementwise. + * @tparam Mat1 Either a type inheriting from `EigenBase` or a `var_value` + * with type `T` inheriting from `EigenBase`. + * @tparam Mat2 Either a type inheriting from `EigenBase` or a `var_value` + * with type `T` inheriting from `EigenBase`. + * @param[in] m1 specified matrix or expression + * @param[in] m2 specified matrix or expression + */ +template * = nullptr, + require_any_matrix_st* = nullptr> +inline auto divide(const Mat1& m1, const Mat2& m2) { + if (!is_constant::value && !is_constant::value) { + arena_t> arena_m1 = m1; + arena_t> arena_m2 = m2; + auto inv_m2 = to_arena(arena_m2.val().array().inverse()); + using val_ret = decltype((inv_m2 * arena_m1.val().array()).matrix().eval()); + using ret_type = return_var_matrix_t; + arena_t res = (inv_m2.array() * arena_m1.val().array()).matrix(); + reverse_pass_callback([inv_m2, arena_m1, arena_m2, res]() mutable { + auto inv_times_res = (inv_m2 * res.adj().array()).eval(); + arena_m1.adj().array() += inv_times_res; + arena_m2.adj().array() -= inv_times_res * res.val().array(); + }); + return ret_type(res); + } else if (!is_constant::value) { + arena_t> arena_m1 = value_of(m1); + arena_t> arena_m2 = m2; + auto inv_m2 = to_arena(arena_m2.val().array().inverse()); + using val_ret = decltype((inv_m2 * arena_m1.array()).matrix().eval()); + using ret_type = return_var_matrix_t; + arena_t res = (inv_m2.array() * arena_m1.array()).matrix(); + reverse_pass_callback([inv_m2, arena_m1, arena_m2, res]() mutable { + arena_m2.adj().array() -= inv_m2 * res.adj().array() * res.val().array(); + }); + return ret_type(res); + } else { + arena_t> arena_m1 = m1; + arena_t> arena_m2 = value_of(m2); + auto inv_m2 = to_arena(arena_m2.array().inverse()); + using val_ret = decltype((inv_m2 * arena_m1.val().array()).matrix().eval()); + using ret_type = return_var_matrix_t; + arena_t res = (inv_m2.array() * arena_m1.val().array()).matrix(); + reverse_pass_callback([inv_m2, arena_m1, arena_m2, res]() mutable { + arena_m1.adj().array() += inv_m2 * res.adj().array(); + }); + return ret_type(res); + } +} + +template * = nullptr> +inline auto operator/(const T1& dividend, const T2& divisor) { + return divide(dividend, divisor); } inline std::complex operator/(const std::complex& x1, diff --git a/stan/math/rev/core/var.hpp b/stan/math/rev/core/var.hpp index 0c7ade74951..b51c06545d2 100644 --- a/stan/math/rev/core/var.hpp +++ b/stan/math/rev/core/var.hpp @@ -295,6 +295,13 @@ class var_value> { } }; +namespace internal { +template +using require_matrix_var_value = require_t::value || is_kernel_expression_and_not_scalar::value) + && std::is_floating_point>::value>>; +} + /** * Independent (input) and dependent (output) variables for gradients. * @@ -310,10 +317,7 @@ class var_value> { * @tparam T An Floating point type. */ template -class var_value< - T, require_t::value || is_kernel_expression_and_not_scalar::value) - && std::is_floating_point>::value>>> { +class var_value> { public: using value_type = T; // type in vari_value. using vari_type = std::conditional_t::value, diff --git a/stan/math/rev/fun/cos.hpp b/stan/math/rev/fun/cos.hpp index ed53bfb2f18..4b48a5e5fe4 100644 --- a/stan/math/rev/fun/cos.hpp +++ b/stan/math/rev/fun/cos.hpp @@ -43,7 +43,7 @@ namespace math { * @param a Variable for radians of angle. * @return Cosine of variable. */ -inline var cos(const var& a) { +inline var cos(var a) { return make_callback_var(std::cos(a.val()), [a](const auto& vi) mutable { a.adj() -= vi.adj() * std::sin(a.val()); }); diff --git a/stan/math/rev/fun/divide.hpp b/stan/math/rev/fun/divide.hpp index 06d1452f5a1..df6ea6640c7 100644 --- a/stan/math/rev/fun/divide.hpp +++ b/stan/math/rev/fun/divide.hpp @@ -1,208 +1,8 @@ #ifndef STAN_MATH_REV_FUN_DIVIDE_HPP #define STAN_MATH_REV_FUN_DIVIDE_HPP +#include #include #include -#include -#include -#include -#include - -namespace stan { -namespace math { -namespace internal { - -template -class matrix_scalar_divide_dv_vari : public vari { - public: - int rows_; - int cols_; - vari* adjCRef_; - vari** adjResultRef_; - double invc_; - - explicit matrix_scalar_divide_dv_vari(const Eigen::Matrix& m, - const var& c) - : vari(0), - rows_(m.rows()), - cols_(m.cols()), - adjCRef_(c.vi_), - adjResultRef_(ChainableStack::instance_->memalloc_.alloc_array( - m.rows() * m.cols())), - invc_(1.0 / c.val()) { - Eigen::Map(adjResultRef_, rows_, cols_) - = (invc_ * m).unaryExpr([](double x) { return new vari(x, false); }); - } - - virtual void chain() { - Eigen::Map adjResult(adjResultRef_, rows_, cols_); - adjCRef_->adj_ - -= invc_ * (adjResult.adj().array() * adjResult.val().array()).sum(); - } -}; - -template -class matrix_scalar_divide_vd_vari : public vari { - public: - int rows_; - int cols_; - vari** adjMRef_; - vari** adjResultRef_; - double invc_; - - explicit matrix_scalar_divide_vd_vari(const Eigen::Matrix& m, - const double& c) - : vari(0), - rows_(m.rows()), - cols_(m.cols()), - adjMRef_(ChainableStack::instance_->memalloc_.alloc_array( - m.rows() * m.cols())), - adjResultRef_(ChainableStack::instance_->memalloc_.alloc_array( - m.rows() * m.cols())), - invc_(1.0 / c) { - Eigen::Map(adjMRef_, rows_, cols_) = m.vi(); - Eigen::Map(adjResultRef_, rows_, cols_) - = (invc_ * m.val()).unaryExpr([](double x) { - return new vari(x, false); - }); - } - - virtual void chain() { - Eigen::Map adjM(adjMRef_, rows_, cols_); - Eigen::Map adjResult(adjResultRef_, rows_, cols_); - adjM.adj() += invc_ * adjResult.adj(); - } -}; - -template -class matrix_scalar_divide_vv_vari : public vari { - public: - int rows_; - int cols_; - vari** adjMRef_; - vari* adjC_; - vari** adjResultRef_; - double invc_; - - explicit matrix_scalar_divide_vv_vari(const Eigen::Matrix& m, - const var& c) - : vari(0), - rows_(m.rows()), - cols_(m.cols()), - adjMRef_(ChainableStack::instance_->memalloc_.alloc_array( - m.rows() * m.cols())), - adjC_(c.vi_), - adjResultRef_(ChainableStack::instance_->memalloc_.alloc_array( - m.rows() * m.cols())), - invc_(1.0 / c.val()) { - Eigen::Map(adjMRef_, rows_, cols_) = m.vi(); - Eigen::Map(adjResultRef_, rows_, cols_) - = (invc_ * m.val()).unaryExpr([](double x) { - return new vari(x, false); - }); - } - - virtual void chain() { - Eigen::Map adjM(adjMRef_, rows_, cols_); - Eigen::Map adjResult(adjResultRef_, rows_, cols_); - adjC_->adj_ - -= invc_ * (adjResult.adj().array() * adjResult.val().array()).sum(); - adjM.adj() += invc_ * adjResult.adj(); - } -}; - -} // namespace internal - -/** - * Return matrix divided by scalar. - * - * @tparam Mat type of the matrix or expression - * @param[in] m specified matrix or expression - * @param[in] c specified scalar - * @return matrix divided by the scalar - */ -template > -inline auto divide(const Mat& m, const var& c) { - auto* baseVari - = new internal::matrix_scalar_divide_dv_vari(m, - c); - Eigen::Matrix result( - m.rows(), m.cols()); - result.vi() - = Eigen::Map(baseVari->adjResultRef_, m.rows(), m.cols()); - return result; -} - -/** - * Return matrix divided by scalar. - * - * @tparam Mat type of the matrix or expression - * @param[in] m specified matrix or expression - * @param[in] c specified scalar - * @return matrix divided by the scalar - */ -template > -inline auto divide(const Mat& m, const double& c) { - auto* baseVari - = new internal::matrix_scalar_divide_vd_vari(m, - c); - Eigen::Matrix result( - m.rows(), m.cols()); - result.vi() - = Eigen::Map(baseVari->adjResultRef_, m.rows(), m.cols()); - return result; -} - -/** - * Return matrix divided by scalar. - * - * @tparam Mat type of the matrix or expression - * @param[in] m specified matrix or expression - * @param[in] c specified scalar - * @return matrix divided by the scalar - */ -template , - typename = void> -inline auto divide(const Mat& m, const var& c) { - auto* baseVari - = new internal::matrix_scalar_divide_vv_vari(m, - c); - Eigen::Matrix result( - m.rows(), m.cols()); - result.vi() - = Eigen::Map(baseVari->adjResultRef_, m.rows(), m.cols()); - return result; -} - -/** - * Return matrix divided by scalar. - * - * @tparam Mat type of the matrix - * @tparam Scal type of the scalar - * @param[in] m input matrix - * @param[in] c input scalar - * @return matrix divided by the scalar - */ -template * = nullptr, - require_stan_scalar_t* = nullptr> -inline auto divide(const Mat& m, const Scal& c) { - double invc = 1.0 / value_of(c); - - plain_type_t res = invc * m.val(); - - reverse_pass_callback([m, c, res, invc]() mutable { - m.adj() += invc * res.adj(); - if (!is_constant::value) - forward_as(c).adj() - -= invc * (res.adj().array() * res.val().array()).sum(); - }); - - return res; -} -} // namespace math -} // namespace stan #endif diff --git a/stan/math/rev/fun/exp.hpp b/stan/math/rev/fun/exp.hpp index 3cb96eb80d9..5f04d8917d2 100644 --- a/stan/math/rev/fun/exp.hpp +++ b/stan/math/rev/fun/exp.hpp @@ -9,6 +9,7 @@ #include #include #include +#include #include #include #include diff --git a/stan/math/rev/fun/inv.hpp b/stan/math/rev/fun/inv.hpp index f3c365b0404..8e920e59925 100644 --- a/stan/math/rev/fun/inv.hpp +++ b/stan/math/rev/fun/inv.hpp @@ -9,7 +9,7 @@ namespace stan { namespace math { /** - * + * @tparam T Arithmetic or a type inheriting from `EigenBase`. \f[ \mbox{inv}(x) = \begin{cases} @@ -27,9 +27,11 @@ namespace math { \f] * */ -inline var inv(const var& a) { - return make_callback_var(inv(a.val()), [a](auto& vi) mutable { - a.adj() -= vi.adj() / (a.val() * a.val()); +template * = nullptr> +inline auto inv(const var_value& a) { + auto denom = to_arena(as_array_or_scalar(square(a.val()))); + return make_callback_var(inv(a.val()), [a, denom](auto& vi) mutable { + as_array_or_scalar(a.adj()) -= as_array_or_scalar(vi.adj()) / denom; }); } diff --git a/stan/math/rev/fun/inv_cloglog.hpp b/stan/math/rev/fun/inv_cloglog.hpp index b312034e3d1..78d58737932 100644 --- a/stan/math/rev/fun/inv_cloglog.hpp +++ b/stan/math/rev/fun/inv_cloglog.hpp @@ -19,15 +19,18 @@ namespace math { * * \f$\frac{d}{dx} \mbox{cloglog}^{-1}(x) = \exp (x - \exp (x))\f$. * + * @tparam T Arithmetic or a type inheriting from `EigenBase`. * @param a Variable argument. * @return The inverse complementary log-log of the specified * argument. */ -inline var inv_cloglog(const var& a) { - auto precomp_exp = std::exp(a.val() - std::exp(a.val())); +template * = nullptr> +inline auto inv_cloglog(const var_value& a) { + auto precomp_exp = to_arena(as_array_or_scalar(exp(a.val() - exp(a.val())))); return make_callback_var(inv_cloglog(a.val()), [a, precomp_exp](auto& vi) mutable { - a.adj() += vi.adj() * precomp_exp; + as_array_or_scalar(a.adj()) + += as_array_or_scalar(vi.adj()) * precomp_exp; }); } diff --git a/stan/math/rev/fun/inv_logit.hpp b/stan/math/rev/fun/inv_logit.hpp index d64d3b50c1f..086b5f45bc9 100644 --- a/stan/math/rev/fun/inv_logit.hpp +++ b/stan/math/rev/fun/inv_logit.hpp @@ -18,27 +18,16 @@ namespace math { * \f$\frac{d}{dx} \mbox{logit}^{-1}(x) = \mbox{logit}^{-1}(x) (1 - * \mbox{logit}^{-1}(x))\f$. * + * @tparam T Arithmetic or a type inheriting from `EigenBase`. * @param a Argument variable. * @return Inverse logit of argument. */ -inline var inv_logit(const var& a) { +template * = nullptr> +inline auto inv_logit(const var_value& a) { return make_callback_var(inv_logit(a.val()), [a](auto& vi) mutable { - a.adj() += vi.adj() * vi.val() * (1.0 - vi.val()); - }); -} - -/** - * Return the inverse logit of the elements of x - * - * @tparam T type of x - * @param x argument - * @return elementwise inverse logit of x - */ -template * = nullptr> -inline auto inv_logit(const T& x) { - return make_callback_var(inv_logit(x.val()), [x](const auto& vi) mutable { - x.adj() += (vi.adj().array() * vi.val().array() * (1.0 - vi.val().array())) - .matrix(); + as_array_or_scalar(a).adj() += as_array_or_scalar(vi.adj()) + * as_array_or_scalar(vi.val()) + * (1.0 - as_array_or_scalar(vi.val())); }); } diff --git a/stan/math/rev/fun/inv_sqrt.hpp b/stan/math/rev/fun/inv_sqrt.hpp index 5f7b35c6416..3135779ed2a 100644 --- a/stan/math/rev/fun/inv_sqrt.hpp +++ b/stan/math/rev/fun/inv_sqrt.hpp @@ -4,12 +4,14 @@ #include #include #include +#include #include namespace stan { namespace math { /** + * @tparam T Arithmetic or a type inheriting from `EigenBase`. * \f[ \mbox{inv\_sqrt}(x) = @@ -28,10 +30,12 @@ namespace math { \f] * */ -inline var inv_sqrt(const var& a) { - auto denom = a.val() * std::sqrt(a.val()); +template * = nullptr> +inline auto inv_sqrt(const var_value& a) { + auto denom = to_arena(as_array_or_scalar(a.val()) + * as_array_or_scalar(sqrt(a.val()))); return make_callback_var(inv_sqrt(a.val()), [a, denom](auto& vi) mutable { - a.adj() -= 0.5 * vi.adj() / denom; + as_array_or_scalar(a.adj()) -= 0.5 * as_array_or_scalar(vi.adj()) / denom; }); } diff --git a/stan/math/rev/fun/lambert_w.hpp b/stan/math/rev/fun/lambert_w.hpp index 36fa8715ffd..6399a384e9a 100644 --- a/stan/math/rev/fun/lambert_w.hpp +++ b/stan/math/rev/fun/lambert_w.hpp @@ -14,12 +14,16 @@ namespace math { /** * Return the Lambert W function on W0 branch applied to the specified variable. * + * @tparam T Arithmetic or a type inheriting from `EigenBase`. * @param a Variable argument. * @return the Lambert W function (W0 branch) applied to the specified argument. */ -inline var lambert_w0(const var& a) { +template * = nullptr> +inline auto lambert_w0(const var_value& a) { return make_callback_var(lambert_w0(a.val()), [a](auto& vi) mutable { - a.adj() += (vi.adj() / (a.val() + exp(vi.val()))); + as_array_or_scalar(a.adj()) + += (as_array_or_scalar(vi.adj()) + / as_array_or_scalar(a.val() + exp(vi.val()))); }); } @@ -27,13 +31,17 @@ inline var lambert_w0(const var& a) { * Return the Lambert W function on W-1 branch applied to the specified * variable. * + * @tparam T Arithmetic or a type inheriting from `EigenBase`. * @param a Variable argument. * @return the Lambert W function (W-1 branch) applied to the specified * argument. */ -inline var lambert_wm1(const var& a) { +template * = nullptr> +inline auto lambert_wm1(const var_value& a) { return make_callback_var(lambert_wm1(a.val()), [a](auto& vi) mutable { - a.adj() += (vi.adj() / (a.val() + exp(vi.val()))); + as_array_or_scalar(a.adj()) + += (as_array_or_scalar(vi.adj()) + / as_array_or_scalar(a.val() + exp(vi.val()))); }); } diff --git a/stan/math/rev/fun/lgamma.hpp b/stan/math/rev/fun/lgamma.hpp index 1f50706cba6..7afe24991a5 100644 --- a/stan/math/rev/fun/lgamma.hpp +++ b/stan/math/rev/fun/lgamma.hpp @@ -16,12 +16,15 @@ namespace math { * * \f$\frac{d}{dx} \Gamma(x) = \psi^{(0)}(x)\f$. * + * @tparam T Arithmetic or a type inheriting from `EigenBase`. * @param a The variable. * @return Log gamma of the variable. */ -inline var lgamma(const var& a) { +template * = nullptr> +inline auto lgamma(const var_value& a) { return make_callback_var(lgamma(a.val()), [a](auto& vi) mutable { - a.adj() += vi.adj() * digamma(a.val()); + as_array_or_scalar(a.adj()) + += as_array_or_scalar(vi.adj()) * as_array_or_scalar(digamma(a.val())); }); } diff --git a/stan/math/rev/fun/log.hpp b/stan/math/rev/fun/log.hpp index c063d55dfe7..8c7d0fc284f 100644 --- a/stan/math/rev/fun/log.hpp +++ b/stan/math/rev/fun/log.hpp @@ -41,12 +41,15 @@ namespace math { \end{cases} \f] * + * @tparam T Arithmetic or a type inheriting from `EigenBase`. * @param a Variable whose log is taken. * @return Natural log of variable. */ -inline var log(const var& a) { - return make_callback_var(std::log(a.val()), [a](auto& vi) mutable { - a.adj() += vi.adj() / a.val(); +template * = nullptr> +inline auto log(const var_value& a) { + return make_callback_var(log(a.val()), [a](auto& vi) mutable { + as_array_or_scalar(a.adj()) + += as_array_or_scalar(vi.adj()) / as_array_or_scalar(a.val()); }); } @@ -60,21 +63,6 @@ inline std::complex log(const std::complex& z) { return internal::complex_log(z); } -/** - * Return the natural log of the elements of x - * - * @tparam T type of x - * @param x argument - * @return elementwise natural log of x - */ -template * = nullptr> -inline auto log(const T& x) { - return make_callback_var( - x.val().array().log().matrix(), [x](const auto& vi) mutable { - x.adj() += (vi.adj().array() / x.val().array()).matrix(); - }); -} - } // namespace math } // namespace stan #endif diff --git a/stan/math/rev/fun/log10.hpp b/stan/math/rev/fun/log10.hpp index 64cf9b745cd..d037bae30b7 100644 --- a/stan/math/rev/fun/log10.hpp +++ b/stan/math/rev/fun/log10.hpp @@ -41,12 +41,15 @@ namespace math { \end{cases} \f] * + * @tparam T Arithmetic or a type inheriting from `EigenBase`. * @param a Variable whose log is taken. * @return Base 10 log of variable. */ -inline var log10(const var& a) { - return make_callback_var(std::log10(a.val()), [a](auto& vi) mutable { - a.adj() += vi.adj() / (LOG_TEN * a.val()); +template * = nullptr> +inline auto log10(const var_value& a) { + return make_callback_var(log10(a.val()), [a](auto& vi) mutable { + as_array_or_scalar(a.adj()) += as_array_or_scalar(vi.adj()) + / (LOG_TEN * as_array_or_scalar(a.val())); }); } diff --git a/stan/math/rev/fun/log1m.hpp b/stan/math/rev/fun/log1m.hpp index 762db31a520..751ff99e28b 100644 --- a/stan/math/rev/fun/log1m.hpp +++ b/stan/math/rev/fun/log1m.hpp @@ -15,30 +15,18 @@ namespace math { * * \f$\frac{d}{dx} \log (1 - x) = -\frac{1}{1 - x}\f$. * + * @tparam T Arithmetic or a type inheriting from `EigenBase`. * @param a The variable. * @return The variable representing log of 1 minus the variable. */ -inline var log1m(const var& a) { +template * = nullptr> +inline auto log1m(const var_value& a) { return make_callback_var(log1m(a.val()), [a](auto& vi) mutable { - a.adj() += vi.adj() / (a.val() - 1); + as_array_or_scalar(a.adj()) + += as_array_or_scalar(vi.adj()) / (as_array_or_scalar(a.val()) - 1.0); }); } -/** - * Return the elementwise log of 1 - x - * - * @tparam T type of x - * @param x argument - * @return elementwise log of 1 - x - */ -template * = nullptr> -inline auto log1m(const T& x) { - return make_callback_var( - stan::math::log1m(x.val()), [x](const auto& vi) mutable { - x.adj() += (vi.adj().array() / (x.val().array() - 1.0)).matrix(); - }); -} - } // namespace math } // namespace stan #endif diff --git a/stan/math/rev/fun/log1m_exp.hpp b/stan/math/rev/fun/log1m_exp.hpp index 3087d1cbfa5..80c45440fb7 100644 --- a/stan/math/rev/fun/log1m_exp.hpp +++ b/stan/math/rev/fun/log1m_exp.hpp @@ -16,13 +16,16 @@ namespace math { *

The derivative of log(1 - exp(x)) with respect * to x is -1 / expm1(-x). * + * @tparam T Arithmetic or a type inheriting from `EigenBase`. * @param[in] x Argument. * @return Natural logarithm of one minus the exponential of the * argument. */ -inline var log1m_exp(const var& x) { +template * = nullptr> +inline auto log1m_exp(const var_value& x) { return make_callback_var(log1m_exp(x.val()), [x](auto& vi) mutable { - x.adj() -= vi.adj() / expm1(-x.val()); + as_array_or_scalar(x.adj()) + -= as_array_or_scalar(vi.adj()) / as_array_or_scalar(expm1(-x.val())); }); } diff --git a/stan/math/rev/fun/log1m_inv_logit.hpp b/stan/math/rev/fun/log1m_inv_logit.hpp index 6db335ff16b..7592707cf98 100644 --- a/stan/math/rev/fun/log1m_inv_logit.hpp +++ b/stan/math/rev/fun/log1m_inv_logit.hpp @@ -13,15 +13,18 @@ namespace math { * Return the natural logarithm of one minus the inverse logit of * the specified argument. * + * @tparam T Arithmetic or a type inheriting from `EigenBase`. * @param u argument * @return log of one minus the inverse logit of the argument */ -inline var log1m_inv_logit(const var& u) { - auto precomp_inv_logit = -inv_logit(u.val()); - return make_callback_var(log1m_inv_logit(u.val()), - [u, precomp_inv_logit](auto& vi) mutable { - u.adj() += vi.adj() * precomp_inv_logit; - }); +template * = nullptr> +inline auto log1m_inv_logit(const var_value& u) { + auto precomp_inv_logit = to_arena(as_array_or_scalar(-inv_logit(u.val()))); + return make_callback_var( + log1m_inv_logit(u.val()), [u, precomp_inv_logit](auto& vi) mutable { + as_array_or_scalar(u.adj()) + += as_array_or_scalar(vi.adj()) * precomp_inv_logit; + }); } } // namespace math diff --git a/stan/math/rev/fun/log1p.hpp b/stan/math/rev/fun/log1p.hpp index edb4796b4ad..dc83088d5c2 100644 --- a/stan/math/rev/fun/log1p.hpp +++ b/stan/math/rev/fun/log1p.hpp @@ -15,31 +15,18 @@ namespace math { * * \f$\frac{d}{dx} \log (1 + x) = \frac{1}{1 + x}\f$. * + * @tparam T Arithmetic or a type inheriting from `EigenBase`. * @param a The variable. * @return The log of 1 plus the variable. */ -inline var log1p(const var& a) { +template * = nullptr> +inline auto log1p(const var_value& a) { return make_callback_var(log1p(a.val()), [a](auto& vi) mutable { - a.adj() += vi.adj() / (1 + a.val()); + as_array_or_scalar(a.adj()) + += as_array_or_scalar(vi.adj()) / (1.0 + as_array_or_scalar(a.val())); }); } -/** - * Return the elementwise log of (1 + x) - * - * @tparam T type of input - * @param x input - * @return Elementwise log(1 + x) - */ -template * = nullptr> -inline auto log1p(const T& x) { - check_greater_or_equal("log1p", "x", x.val(), -1.0); - return make_callback_var( - x.val().array().log1p().matrix(), [x](const auto& vi) { - x.adj().array() += vi.adj().array() / (1 + x.val().array()); - }); -} - } // namespace math } // namespace stan #endif diff --git a/stan/math/rev/fun/log1p_exp.hpp b/stan/math/rev/fun/log1p_exp.hpp index d0d7071f078..4a73c220ee9 100644 --- a/stan/math/rev/fun/log1p_exp.hpp +++ b/stan/math/rev/fun/log1p_exp.hpp @@ -12,27 +12,17 @@ namespace math { /** * Return the log of 1 plus the exponential of the specified * variable. + * @tparam T Arithmetic or a type inheriting from `EigenBase`. + * @param a The variable. */ -inline var log1p_exp(const var& a) { - auto precomp_inv_logit = inv_logit(a.val()); - return make_callback_var(log1p_exp(a.val()), - [a, precomp_inv_logit](auto& vi) mutable { - a.adj() += vi.adj() * precomp_inv_logit; - }); -} - -/** - * Return the elementwise log(1 + exp(x)) - * - * @tparam T type of input - * @param x input - * @return Elementwise log(1 + exp(x)) - */ -template * = nullptr> -inline auto log1p_exp(const T& x) { - return make_callback_var(log1p_exp(x.val()), [x](const auto& vi) { - x.adj().array() += vi.adj().array() * inv_logit(x.val().array()); - }); +template * = nullptr> +inline auto log1p_exp(const var_value& a) { + auto precomp_inv_logit = to_arena(as_array_or_scalar(inv_logit(a.val()))); + return make_callback_var( + log1p_exp(a.val()), [a, precomp_inv_logit](auto& vi) mutable { + as_array_or_scalar(a.adj()) + += as_array_or_scalar(vi.adj()) * precomp_inv_logit; + }); } } // namespace math diff --git a/stan/math/rev/fun/log2.hpp b/stan/math/rev/fun/log2.hpp index 636de21317a..05ee0f3db1d 100644 --- a/stan/math/rev/fun/log2.hpp +++ b/stan/math/rev/fun/log2.hpp @@ -36,12 +36,15 @@ namespace math { \end{cases} \f] * - * @param a Specified variable. + * @tparam T Arithmetic or a type inheriting from `EigenBase`. + * @param a The variable. * @return Base 2 logarithm of the variable. */ -inline var log2(const var& a) { +template * = nullptr> +inline auto log2(const var_value& a) { return make_callback_var(log2(a.val()), [a](auto& vi) mutable { - a.adj() += vi.adj() / (LOG_TWO * a.val()); + as_array_or_scalar(a.adj()) += as_array_or_scalar(vi.adj()) + / (LOG_TWO * as_array_or_scalar(a.val())); }); } diff --git a/stan/math/rev/fun/logit.hpp b/stan/math/rev/fun/logit.hpp index 00c858cc39d..10aaf8b0545 100644 --- a/stan/math/rev/fun/logit.hpp +++ b/stan/math/rev/fun/logit.hpp @@ -4,6 +4,7 @@ #include #include #include +#include namespace stan { namespace math { @@ -11,13 +12,15 @@ namespace math { /** * Return the log odds of the specified argument. * - * @param u argument + * @tparam T Arithmetic or a type inheriting from `EigenBase`. + * @param u The variable. * @return log odds of argument */ -inline var logit(const var& u) { - auto denom = (1.0 / (u.val() - u.val() * u.val())); +template * = nullptr> +inline auto logit(const var_value& u) { + auto denom = to_arena(1.0 / as_array_or_scalar(u.val() - square(u.val()))); return make_callback_var(logit(u.val()), [u, denom](auto& vi) mutable { - u.adj() += vi.adj() * denom; + as_array_or_scalar(u.adj()) += as_array_or_scalar(vi.adj()) * denom; }); } diff --git a/stan/math/rev/fun/norm.hpp b/stan/math/rev/fun/norm.hpp index 41a6b6af10c..846f2a1651a 100644 --- a/stan/math/rev/fun/norm.hpp +++ b/stan/math/rev/fun/norm.hpp @@ -3,6 +3,7 @@ #include #include +#include #include namespace stan { diff --git a/stan/math/rev/fun/to_arena.hpp b/stan/math/rev/fun/to_arena.hpp index dadf68f2cf7..9a17e400a40 100644 --- a/stan/math/rev/fun/to_arena.hpp +++ b/stan/math/rev/fun/to_arena.hpp @@ -1,9 +1,9 @@ #ifndef STAN_MATH_REV_FUN_TO_ARENA_HPP #define STAN_MATH_REV_FUN_TO_ARENA_HPP +#include #include #include -#include #include #include diff --git a/test/unit/math/mix/core/operator_division_test.cpp b/test/unit/math/mix/core/operator_division_test.cpp index 06bf2e9b291..e64c0f7621d 100644 --- a/test/unit/math/mix/core/operator_division_test.cpp +++ b/test/unit/math/mix/core/operator_division_test.cpp @@ -33,3 +33,35 @@ TEST(mathMixCore, operatorDivision) { } } } + +namespace stan { +namespace test { +struct operator_divide_tester { + template * = nullptr> + auto operator()(const T1& x, const T2& y) const { + return x / y; + } + template * = nullptr, + require_all_not_var_matrix_t* = nullptr> + auto operator()(const T1& x, const T2& y) const { + return (stan::math::as_array_or_scalar(x) + / stan::math::as_array_or_scalar(y)) + .matrix() + .eval(); + } +}; +} // namespace test +} // namespace stan + +TEST(mathMixCore, operatorDivisionVarMat) { + auto f = [](const auto& x1, const auto& x2) { return x1 / x2; }; + + Eigen::MatrixXd mat1(2, 2); + mat1 << -2, -1, 0.5, 2.8; + Eigen::MatrixXd mat2 = mat1.reverse(); + stan::test::expect_ad_matvar(stan::test::operator_divide_tester{}, mat1, + mat2); + stan::test::expect_ad_matvar(stan::test::operator_divide_tester{}, mat1, 2.0); + stan::test::expect_ad_matvar(stan::test::operator_divide_tester{}, 2.0, mat2); +} diff --git a/test/unit/math/mix/fun/divide2_test.cpp b/test/unit/math/mix/fun/divide2_test.cpp new file mode 100644 index 00000000000..8afaef56ec4 --- /dev/null +++ b/test/unit/math/mix/fun/divide2_test.cpp @@ -0,0 +1,66 @@ +#include +#include + +TEST(MathMixMatFun, divide_vec_scal) { + auto f + = [](const auto& x, const auto& y) { return stan::math::divide(x, y); }; + + double x1 = 10; + double x2 = -2; + + Eigen::VectorXd v(1); + v << 10; + stan::test::expect_ad(f, x2, v); + stan::test::expect_ad_matvar(f, x2, v); + + Eigen::RowVectorXd rv = v; + stan::test::expect_ad(f, x2, rv); + stan::test::expect_ad_matvar(f, x2, rv); + + Eigen::MatrixXd m(1, 1); + m << 10; + stan::test::expect_ad(f, x2, m); + stan::test::expect_ad_matvar(f, x2, m); + + Eigen::MatrixXd p(3, 2); + p << 1, 2, 3, 4, 5, 6; + stan::test::expect_ad(f, x2, p); + stan::test::expect_ad_matvar(f, x2, m); + Eigen::VectorXd w(3); + w << 100, 2, -3; + stan::test::expect_ad(f, x2, w); + stan::test::expect_ad_matvar(f, x2, m); + + Eigen::VectorXd u(4); + u << 100, 0.5, -3, 4; + stan::test::expect_ad(f, x2, u); + stan::test::expect_ad_matvar(f, x2, m); + + Eigen::VectorXd v0(0); + Eigen::RowVectorXd rv0(0); + Eigen::MatrixXd m00(0, 0); + stan::test::expect_ad(f, x1, v0); + stan::test::expect_ad(f, x1, rv0); + stan::test::expect_ad(f, x1, m00); + + stan::test::expect_ad_matvar(f, x1, v0); + stan::test::expect_ad_matvar(f, x1, rv0); + stan::test::expect_ad_matvar(f, x1, m00); + + Eigen::RowVectorXd rv4(4); + rv4 << -5, 10, 7, 8.2; + stan::test::expect_ad(f, x2, rv4); + stan::test::expect_ad_matvar(f, x2, rv4); + + double inf = std::numeric_limits::infinity(); + double nan = std::numeric_limits::quiet_NaN(); + for (double value : {inf, -inf, nan}) { + stan::test::expect_ad(f, value, u); + stan::test::expect_ad(f, value, rv4); + stan::test::expect_ad(f, value, p); + + stan::test::expect_ad_matvar(f, value, u); + stan::test::expect_ad_matvar(f, value, rv4); + stan::test::expect_ad_matvar(f, value, p); + } +} diff --git a/test/unit/math/mix/fun/divide_test.cpp b/test/unit/math/mix/fun/divide_test.cpp index aa5aac5831a..509168ae332 100644 --- a/test/unit/math/mix/fun/divide_test.cpp +++ b/test/unit/math/mix/fun/divide_test.cpp @@ -34,10 +34,15 @@ TEST(MathMixMatFun, divide) { stan::test::expect_ad_matvar(f, m, x2); Eigen::VectorXd u(4); - u << 100, 0, -3, 4; + u << 100, 0.5, -3, 4; stan::test::expect_ad(f, u, x2); stan::test::expect_ad_matvar(f, m, x2); + Eigen::VectorXd u1(4); + u1 << 100, 0.5, -3, 4; + Eigen::VectorXd u2 = u1.reverse(); + stan::test::expect_ad(f, u1, u2); + stan::test::expect_ad_matvar(f, u1, u2); Eigen::VectorXd v0(0); Eigen::RowVectorXd rv0(0); Eigen::MatrixXd m00(0, 0); diff --git a/test/unit/math/mix/fun/inv_cloglog_test.cpp b/test/unit/math/mix/fun/inv_cloglog_test.cpp index 1359fd877f5..19dc4eabc96 100644 --- a/test/unit/math/mix/fun/inv_cloglog_test.cpp +++ b/test/unit/math/mix/fun/inv_cloglog_test.cpp @@ -5,3 +5,21 @@ TEST(mathMixMatFun, invCLogLog) { stan::test::expect_common_unary_vectorized(f); stan::test::expect_unary_vectorized(f, -2.6, -2, -1.2, -0.2, 0.5, 1.3); } + +TEST(mathMixMatFun, invcloglog_varmat) { + using stan::math::vec_concat; + using stan::test::expect_ad_vector_matvar; + using stan::test::internal::common_args; + auto f = [](const auto& x1) { + using stan::math::inv_cloglog; + return inv_cloglog(x1); + }; + std::vector com_args = common_args(); + std::vector args{-2.6, -2, -1.2, -0.2, 0.5, 1.3}; + auto all_args = vec_concat(com_args, args); + Eigen::VectorXd A(all_args.size()); + for (int i = 0; i < all_args.size(); ++i) { + A(i) = all_args[i]; + } + expect_ad_vector_matvar(f, A); +} diff --git a/test/unit/math/mix/fun/inv_sqrt_test.cpp b/test/unit/math/mix/fun/inv_sqrt_test.cpp index 2c1863be486..d7a90961f80 100644 --- a/test/unit/math/mix/fun/inv_sqrt_test.cpp +++ b/test/unit/math/mix/fun/inv_sqrt_test.cpp @@ -6,3 +6,21 @@ TEST(mathMixMatFun, invSqrt) { stan::test::expect_unary_vectorized(f, -2.6, -2, -1.2, -0.2, 0.5, 1, 1.3, 1.5, 3, 10.2); } + +TEST(mathMixMatFun, invsqrt_varmat) { + using stan::math::vec_concat; + using stan::test::expect_ad_vector_matvar; + using stan::test::internal::common_args; + auto f = [](const auto& x1) { + using stan::math::inv_sqrt; + return inv_sqrt(x1); + }; + std::vector com_args = common_args(); + std::vector args{-2.6, -2, -1.2, -0.2, 0.5, 1, 1.3, 1.5, 3, 10.2}; + auto all_args = vec_concat(com_args, args); + Eigen::VectorXd A(all_args.size()); + for (int i = 0; i < all_args.size(); ++i) { + A(i) = all_args[i]; + } + expect_ad_vector_matvar(f, A); +} diff --git a/test/unit/math/mix/fun/inv_square_test.cpp b/test/unit/math/mix/fun/inv_square_test.cpp index 0a1bb040772..c86ad25dfd7 100644 --- a/test/unit/math/mix/fun/inv_square_test.cpp +++ b/test/unit/math/mix/fun/inv_square_test.cpp @@ -5,3 +5,21 @@ TEST(mathMixMatFun, invSquare) { stan::test::expect_common_unary_vectorized(f); stan::test::expect_unary_vectorized(f, -2.6, -2, -0.2, 1, 1.3, 3); } + +TEST(mathMixMatFun, invsquare_varmat) { + using stan::math::vec_concat; + using stan::test::expect_ad_vector_matvar; + using stan::test::internal::common_args; + auto f = [](const auto& x1) { + using stan::math::inv_square; + return inv_square(x1); + }; + std::vector com_args = common_args(); + std::vector args{-2.6, -2, -0.2, 1, 1.3, 3}; + auto all_args = vec_concat(com_args, args); + Eigen::VectorXd A(all_args.size()); + for (int i = 0; i < all_args.size(); ++i) { + A(i) = all_args[i]; + } + expect_ad_vector_matvar(f, A); +} diff --git a/test/unit/math/mix/fun/inv_test.cpp b/test/unit/math/mix/fun/inv_test.cpp index 3a610033a01..65ba8443b8d 100644 --- a/test/unit/math/mix/fun/inv_test.cpp +++ b/test/unit/math/mix/fun/inv_test.cpp @@ -5,3 +5,21 @@ TEST(mathMixMatFun, inv) { stan::test::expect_common_unary_vectorized(f); stan::test::expect_unary_vectorized(f, -2.6, -2, -0.2, 1.3, 3); } + +TEST(mathMixMatFun, inv_varmat) { + using stan::math::vec_concat; + using stan::test::expect_ad_vector_matvar; + using stan::test::internal::common_args; + auto f = [](const auto& x1) { + using stan::math::inv; + return inv(x1); + }; + std::vector com_args = common_args(); + std::vector args{-2.6, -2, -0.2, 1.3, 3}; + auto all_args = vec_concat(com_args, args); + Eigen::VectorXd A(all_args.size()); + for (int i = 0; i < all_args.size(); ++i) { + A(i) = all_args[i]; + } + expect_ad_vector_matvar(f, A); +} diff --git a/test/unit/math/mix/fun/lambertw_test.cpp b/test/unit/math/mix/fun/lambertw_test.cpp index 461d105a59b..6822b45c911 100644 --- a/test/unit/math/mix/fun/lambertw_test.cpp +++ b/test/unit/math/mix/fun/lambertw_test.cpp @@ -20,6 +20,24 @@ TEST(mathMixMatFun, lambert_w0) { stan::math::recover_memory(); } +TEST(mathMixMatFun, lambert_w0_varmat) { + using stan::math::vec_concat; + using stan::test::expect_ad_vector_matvar; + using stan::test::internal::common_args; + auto f = [](const auto& x1) { + using stan::math::lambert_w0; + return lambert_w0(x1); + }; + std::vector com_args = common_args(); + std::vector args{-0.3, -0.1, 0.0, 1, 10, 20, -0.38}; + auto all_args = vec_concat(com_args, args); + Eigen::VectorXd A(all_args.size()); + for (int i = 0; i < all_args.size(); ++i) { + A(i) = all_args[i]; + } + expect_ad_vector_matvar(f, A); +} + TEST(mathMixMatFun, lambert_wm1) { auto f = [](const auto& x1) { using stan::math::lambert_wm1; @@ -36,3 +54,21 @@ TEST(mathMixMatFun, lambert_wm1) { stan::math::recover_memory(); } + +TEST(mathMixMatFun, lambert_wm1_varmat) { + using stan::math::vec_concat; + using stan::test::expect_ad_vector_matvar; + using stan::test::internal::common_args; + auto f = [](const auto& x1) { + using stan::math::lambert_wm1; + return lambert_wm1(x1); + }; + std::vector com_args = common_args(); + std::vector args{-0.35, -0.3, -0.1, -0.01, -0.38, 0.001}; + auto all_args = vec_concat(com_args, args); + Eigen::VectorXd A(all_args.size()); + for (int i = 0; i < all_args.size(); ++i) { + A(i) = all_args[i]; + } + expect_ad_vector_matvar(f, A); +} diff --git a/test/unit/math/mix/fun/lgamma_test.cpp b/test/unit/math/mix/fun/lgamma_test.cpp index faa33e15d3f..f5f42691a6e 100644 --- a/test/unit/math/mix/fun/lgamma_test.cpp +++ b/test/unit/math/mix/fun/lgamma_test.cpp @@ -6,3 +6,21 @@ TEST(mathMixMatFun, lgamma) { stan::test::expect_unary_vectorized(f, -2.6, -2, -0.2, 1, 1.3, 3, 3.7, 10, 22); } + +TEST(mathMixMatFun, lgamma_varmat) { + using stan::math::vec_concat; + using stan::test::expect_ad_vector_matvar; + using stan::test::internal::common_args; + auto f = [](const auto& x1) { + using stan::math::lgamma; + return lgamma(x1); + }; + std::vector com_args = common_args(); + std::vector args{-2.6, -2, -0.2, 1, 1.3, 3, 3.7, 10, 22}; + auto all_args = vec_concat(com_args, args); + Eigen::VectorXd A(all_args.size()); + for (int i = 0; i < all_args.size(); ++i) { + A(i) = all_args[i]; + } + expect_ad_vector_matvar(f, A); +} diff --git a/test/unit/math/mix/fun/log10_test.cpp b/test/unit/math/mix/fun/log10_test.cpp index b572db0dcd5..6b33ea92f4d 100644 --- a/test/unit/math/mix/fun/log10_test.cpp +++ b/test/unit/math/mix/fun/log10_test.cpp @@ -25,3 +25,21 @@ TEST(mathMixMatFun, log10) { stan::test::expect_ad(f, std::complex{2.1, -0.0}); // (negative real and zero imaginary illegal) } + +TEST(mathMixMatFun, log10_varmat) { + using stan::math::vec_concat; + using stan::test::expect_ad_vector_matvar; + using stan::test::internal::common_args; + auto f = [](const auto& x1) { + using stan::math::log10; + return log10(x1); + }; + std::vector com_args = common_args(); + std::vector args{-0.2, 1e-3, 1, 1.3, 3, 3.7, 10, 10.2, 1e6}; + auto all_args = vec_concat(com_args, args); + Eigen::VectorXd A(all_args.size()); + for (int i = 0; i < all_args.size(); ++i) { + A(i) = all_args[i]; + } + expect_ad_vector_matvar(f, A); +} diff --git a/test/unit/math/mix/fun/log1m_exp_test.cpp b/test/unit/math/mix/fun/log1m_exp_test.cpp index 61833abfbee..19dcfba626f 100644 --- a/test/unit/math/mix/fun/log1m_exp_test.cpp +++ b/test/unit/math/mix/fun/log1m_exp_test.cpp @@ -6,3 +6,21 @@ TEST(mathMixMatFun, log1m_exp) { stan::test::expect_unary_vectorized(f, -14, -12.6, -2, -1, -0.2, -0.5, 1.3, 3); } + +TEST(mathMixMatFun, log1m_exp_varmat) { + using stan::math::vec_concat; + using stan::test::expect_ad_vector_matvar; + using stan::test::internal::common_args; + auto f = [](const auto& x1) { + using stan::math::log1m_exp; + return log1m_exp(x1); + }; + std::vector com_args = common_args(); + std::vector args{-14, -12.6, -2, -1, -0.2, -0.5, 1.3, 3}; + auto all_args = vec_concat(com_args, args); + Eigen::VectorXd A(all_args.size()); + for (int i = 0; i < all_args.size(); ++i) { + A(i) = all_args[i]; + } + expect_ad_vector_matvar(f, A); +} diff --git a/test/unit/math/mix/fun/log1m_inv_logit_test.cpp b/test/unit/math/mix/fun/log1m_inv_logit_test.cpp index 7955fb3c109..18be8401c75 100644 --- a/test/unit/math/mix/fun/log1m_inv_logit_test.cpp +++ b/test/unit/math/mix/fun/log1m_inv_logit_test.cpp @@ -6,3 +6,21 @@ TEST(mathMixMatFun, log1mInvLogit) { stan::test::expect_unary_vectorized(f, -2.6, -2, -1, -0.5, -0.2, 0.5, 1, 1.3, 3, 5); } + +TEST(mathMixMatFun, log1minvlogit_varmat) { + using stan::math::vec_concat; + using stan::test::expect_ad_vector_matvar; + using stan::test::internal::common_args; + auto f = [](const auto& x1) { + using stan::math::log1m_inv_logit; + return log1m_inv_logit(x1); + }; + std::vector com_args = common_args(); + std::vector args{-2.6, -2, -1, -0.5, -0.2, 0.5, 1, 1.3, 3, 5}; + auto all_args = vec_concat(com_args, args); + Eigen::VectorXd A(all_args.size()); + for (int i = 0; i < all_args.size(); ++i) { + A(i) = all_args[i]; + } + expect_ad_vector_matvar(f, A); +} diff --git a/test/unit/math/mix/fun/log2_test.cpp b/test/unit/math/mix/fun/log2_test.cpp index cc04b93ab7e..e57f2f3c603 100644 --- a/test/unit/math/mix/fun/log2_test.cpp +++ b/test/unit/math/mix/fun/log2_test.cpp @@ -6,3 +6,21 @@ TEST(mathMixMatFun, log2) { stan::test::expect_unary_vectorized(f, -0.2, 1e-3, 1, 1.3, 3, 3.7, 10, 10.2, 1e6); } + +TEST(mathMixMatFun, log2_varmat) { + using stan::math::vec_concat; + using stan::test::expect_ad_vector_matvar; + using stan::test::internal::common_args; + auto f = [](const auto& x1) { + using stan::math::log2; + return log2(x1); + }; + std::vector com_args = common_args(); + std::vector args{-0.2, 1e-3, 1, 1.3, 3, 3.7, 10, 10.2, 1e6}; + auto all_args = vec_concat(com_args, args); + Eigen::VectorXd A(all_args.size()); + for (int i = 0; i < all_args.size(); ++i) { + A(i) = all_args[i]; + } + expect_ad_vector_matvar(f, A); +} diff --git a/test/unit/math/mix/fun/logit_test.cpp b/test/unit/math/mix/fun/logit_test.cpp index f4bb592335f..a43bb4978ca 100644 --- a/test/unit/math/mix/fun/logit_test.cpp +++ b/test/unit/math/mix/fun/logit_test.cpp @@ -5,3 +5,21 @@ TEST(mathMixMatFun, logit) { stan::test::expect_common_unary_vectorized(f); stan::test::expect_unary_vectorized(f, -1.2, -0.5, 0.01, 0.5, 0.99, 1.5); } + +TEST(mathMixMatFun, logit_varmat) { + using stan::math::vec_concat; + using stan::test::expect_ad_vector_matvar; + using stan::test::internal::common_args; + auto f = [](const auto& x1) { + using stan::math::logit; + return logit(x1); + }; + std::vector com_args = common_args(); + std::vector args{-1.2, -0.5, 0.01, 0.5, 0.99, 1.5}; + auto all_args = vec_concat(com_args, args); + Eigen::VectorXd A(all_args.size()); + for (int i = 0; i < all_args.size(); ++i) { + A(i) = all_args[i]; + } + expect_ad_vector_matvar(f, A); +}