Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Remove v_vari #2422

Merged
merged 25 commits into from
Apr 1, 2021
Merged
Show file tree
Hide file tree
Changes from 20 commits
Commits
Show all changes
25 commits
Select commit Hold shift + click to select a range
8f4379d
uses make_callback_var for more scalar var operations
SteveBronder Mar 13, 2021
6cdf20b
get rid of more op_v_vari
SteveBronder Mar 13, 2021
c7d8564
get rid of v_vari
SteveBronder Mar 13, 2021
17c7af5
cleanup core
SteveBronder Mar 13, 2021
a2b56c8
use = instead of += for all NaN assignments in reverse pass
SteveBronder Mar 13, 2021
0e0bd5c
[Jenkins] auto-formatting by clang-format version 6.0.0-1ubuntu2~16.0…
stan-buildbot Mar 13, 2021
f29f383
fix includes
SteveBronder Mar 13, 2021
4dfa517
precomp things that seem expensive
SteveBronder Mar 17, 2021
63038eb
Merge commit 'bd0404db9509819c9bcb7905d38759dc773e010d' into HEAD
yashikno Mar 17, 2021
bac549d
[Jenkins] auto-formatting by clang-format version 6.0.0-1ubuntu2~16.0…
stan-buildbot Mar 17, 2021
3104870
remove precomp_v_vari test
SteveBronder Mar 17, 2021
40e459d
remove nan checks
SteveBronder Mar 18, 2021
6821b96
Merge commit '0b5fbd41430702f5508eca1aff52fa7f508870b2' into HEAD
yashikno Mar 18, 2021
95ff82b
[Jenkins] auto-formatting by clang-format version 6.0.0-1ubuntu2~16.0…
stan-buildbot Mar 18, 2021
c4a788e
fix a sign flip
SteveBronder Mar 18, 2021
a6242e5
[Jenkins] auto-formatting by clang-format version 6.0.0-1ubuntu2~16.0…
stan-buildbot Mar 18, 2021
7bed51e
have opencl line up with rev for no NA values and logit
SteveBronder Mar 19, 2021
474c638
Merge commit '80f8578e1e531d0467d9233ac158a4419fec1ee4' into HEAD
yashikno Mar 19, 2021
14b111a
[Jenkins] auto-formatting by clang-format version 6.0.0-1ubuntu2~16.0…
stan-buildbot Mar 19, 2021
f36c574
remove na checks from opencl functions
SteveBronder Mar 20, 2021
c211c78
Couple simplifications
bbbales2 Mar 30, 2021
eea024f
Merge commit '40803f80dd9e3cdbcfc014b91a9a316d9618fd23' into HEAD
yashikno Mar 30, 2021
c489500
[Jenkins] auto-formatting by clang-format version 6.0.0-1ubuntu2~16.0…
stan-buildbot Mar 30, 2021
210f802
Merge remote-tracking branch 'origin/develop' into cleanup/vari-inherits
SteveBronder Mar 31, 2021
16e2681
uncomment test
SteveBronder Mar 31, 2021
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
5 changes: 1 addition & 4 deletions stan/math/opencl/rev/ceil.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -18,10 +18,7 @@ namespace math {
template <typename T,
require_all_kernel_expressions_and_none_scalar_t<T>* = nullptr>
inline var_value<matrix_cl<double>> ceil(const var_value<T>& A) {
return make_callback_var(
ceil(A.val()), [A](vari_value<matrix_cl<double>>& res) mutable {
A.adj() = select(isnan(A.val()), NOT_A_NUMBER, A.adj());
});
return var_value<matrix_cl<double>>(ceil(A.val()));
}

} // namespace math
Expand Down
5 changes: 1 addition & 4 deletions stan/math/opencl/rev/floor.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -19,10 +19,7 @@ namespace math {
template <typename T,
require_all_kernel_expressions_and_none_scalar_t<T>* = nullptr>
inline var_value<matrix_cl<double>> floor(const var_value<T>& A) {
return make_callback_var(
floor(A.val()), [A](vari_value<matrix_cl<double>>& res) mutable {
A.adj() = select(isnan(A.val()), NOT_A_NUMBER, A.adj());
});
return var_value<matrix_cl<double>>(floor(A.val()));
}

} // namespace math
Expand Down
3 changes: 2 additions & 1 deletion stan/math/opencl/rev/logit.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -20,7 +20,8 @@ template <typename T,
inline var_value<matrix_cl<double>> logit(const var_value<T>& A) {
return make_callback_var(
logit(A.val()), [A](vari_value<matrix_cl<double>>& res) mutable {
A.adj() += elt_divide(res.adj(), A.val() * LOG_TEN);
A.adj() += elt_multiply(res.adj(),
elt_divide(1.0, A.val() - square(A.val())));
Copy link
Collaborator Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

@t4c1 fyi I think this was wrong before? I wrote it to match up to the rev version

Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

It seems so. How did the tests pass?

Anyway you can simplify this into:
A.adj() += elt_divide(res.adj(), A.val() - square(A.val()));

Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I reverted this change and the tests failed for me locally. It looks like the right tests are running in CI. I don't know. This seems like a canary worth investigating, but if nobody can reproduce it then nobody can reproduce it :/.

Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Oh also I switched to the simplified A.adj() += elt_divide(res.adj(), A.val() - square(A.val()));

});
}

Expand Down
5 changes: 1 addition & 4 deletions stan/math/opencl/rev/round.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -20,10 +20,7 @@ namespace math {
template <typename T,
require_all_kernel_expressions_and_none_scalar_t<T>* = nullptr>
inline var_value<matrix_cl<double>> round(const var_value<T>& A) {
return make_callback_var(
round(A.val()), [A](vari_value<matrix_cl<double>>& res) mutable {
A.adj() = select(isnan(A.val()), NOT_A_NUMBER, A.adj());
});
return var_value<matrix_cl<double>>(round(A.val()));
}

} // namespace math
Expand Down
5 changes: 1 addition & 4 deletions stan/math/opencl/rev/trunc.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -19,10 +19,7 @@ namespace math {
template <typename T,
require_all_kernel_expressions_and_none_scalar_t<T>* = nullptr>
inline var_value<matrix_cl<double>> trunc(const var_value<T>& A) {
return make_callback_var(
trunc(A.val()), [A](vari_value<matrix_cl<double>>& res) mutable {
A.adj() = select(isnan(A.val()), NOT_A_NUMBER, A.adj());
});
return var_value<matrix_cl<double>>(trunc(A.val()));
}

} // namespace math
Expand Down
2 changes: 0 additions & 2 deletions stan/math/rev/core.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -46,7 +46,6 @@
#include <stan/math/rev/core/operator_unary_negative.hpp>
#include <stan/math/rev/core/operator_unary_not.hpp>
#include <stan/math/rev/core/operator_unary_plus.hpp>
#include <stan/math/rev/core/precomp_v_vari.hpp>
#include <stan/math/rev/core/precomp_vv_vari.hpp>
#include <stan/math/rev/core/precomp_vvv_vari.hpp>
#include <stan/math/rev/core/precomputed_gradients.hpp>
Expand All @@ -65,7 +64,6 @@
#include <stan/math/rev/core/std_numeric_limits.hpp>
#include <stan/math/rev/core/stored_gradient_vari.hpp>
#include <stan/math/rev/core/typedefs.hpp>
#include <stan/math/rev/core/v_vari.hpp>
#include <stan/math/rev/core/var.hpp>
#include <stan/math/rev/core/vari.hpp>
#include <stan/math/rev/core/vd_vari.hpp>
Expand Down
20 changes: 5 additions & 15 deletions stan/math/rev/core/operator_addition.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -53,13 +53,8 @@ namespace math {
inline var operator+(const var& a, const var& b) {
return make_callback_vari(a.vi_->val_ + b.vi_->val_,
[avi = a.vi_, bvi = b.vi_](const auto& vi) mutable {
if (unlikely(std::isnan(vi.val_))) {
avi->adj_ = NOT_A_NUMBER;
bvi->adj_ = NOT_A_NUMBER;
} else {
avi->adj_ += vi.adj_;
bvi->adj_ += vi.adj_;
}
avi->adj_ += vi.adj_;
bvi->adj_ += vi.adj_;
});
}

Expand All @@ -80,14 +75,9 @@ inline var operator+(const var& a, Arith b) {
if (unlikely(b == 0.0)) {
return a;
}
return make_callback_vari(a.vi_->val_ + b,
[avi = a.vi_, b](const auto& vi) mutable {
if (unlikely(std::isnan(vi.val_))) {
avi->adj_ = NOT_A_NUMBER;
} else {
avi->adj_ += vi.adj_;
}
});
return make_callback_vari(
a.vi_->val_ + b,
[avi = a.vi_, b](const auto& vi) mutable { avi->adj_ += vi.adj_; });
}

/**
Expand Down
17 changes: 3 additions & 14 deletions stan/math/rev/core/operator_division.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -29,27 +29,16 @@ class divide_vv_vari final : public op_vv_vari {
: op_vv_vari(dividend_vi->val_ / divisor_vi->val_, dividend_vi,
divisor_vi) {}
void chain() {
if (unlikely(is_any_nan(avi_->val_, bvi_->val_))) {
avi_->adj_ = NOT_A_NUMBER;
bvi_->adj_ = NOT_A_NUMBER;
} else {
avi_->adj_ += adj_ / bvi_->val_;
bvi_->adj_ -= adj_ * avi_->val_ / (bvi_->val_ * bvi_->val_);
}
avi_->adj_ += adj_ / bvi_->val_;
bvi_->adj_ -= adj_ * avi_->val_ / (bvi_->val_ * bvi_->val_);
}
};

class divide_vd_vari final : public op_vd_vari {
public:
divide_vd_vari(vari* dividend_vi, double divisor)
: op_vd_vari(dividend_vi->val_ / divisor, dividend_vi, divisor) {}
void chain() {
if (unlikely(is_any_nan(avi_->val_, bd_))) {
avi_->adj_ = NOT_A_NUMBER;
} else {
avi_->adj_ += adj_ / bd_;
}
}
void chain() { avi_->adj_ += adj_ / bd_; }
};

class divide_dv_vari final : public op_dv_vari {
Expand Down
17 changes: 3 additions & 14 deletions stan/math/rev/core/operator_multiplication.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -19,26 +19,15 @@ class multiply_vv_vari final : public op_vv_vari {
multiply_vv_vari(vari* avi, vari* bvi)
: op_vv_vari(avi->val_ * bvi->val_, avi, bvi) {}
void chain() {
if (unlikely(is_any_nan(avi_->val_, bvi_->val_))) {
avi_->adj_ = NOT_A_NUMBER;
bvi_->adj_ = NOT_A_NUMBER;
} else {
avi_->adj_ += bvi_->val_ * adj_;
bvi_->adj_ += avi_->val_ * adj_;
}
avi_->adj_ += bvi_->val_ * adj_;
bvi_->adj_ += avi_->val_ * adj_;
}
};

class multiply_vd_vari final : public op_vd_vari {
public:
multiply_vd_vari(vari* avi, double b) : op_vd_vari(avi->val_ * b, avi, b) {}
void chain() {
if (unlikely(is_any_nan(avi_->val_, bd_))) {
avi_->adj_ = NOT_A_NUMBER;
} else {
avi_->adj_ += adj_ * bd_;
}
}
void chain() { avi_->adj_ += adj_ * bd_; }
};
} // namespace internal

Expand Down
31 changes: 8 additions & 23 deletions stan/math/rev/core/operator_subtraction.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -56,13 +56,8 @@ namespace math {
inline var operator-(const var& a, const var& b) {
return make_callback_vari(a.vi_->val_ - b.vi_->val_,
[avi = a.vi_, bvi = b.vi_](const auto& vi) mutable {
if (unlikely(is_nan(vi.val_))) {
avi->adj_ = NOT_A_NUMBER;
bvi->adj_ = NOT_A_NUMBER;
} else {
avi->adj_ += vi.adj_;
bvi->adj_ -= vi.adj_;
}
avi->adj_ += vi.adj_;
bvi->adj_ -= vi.adj_;
});
}

Expand All @@ -84,14 +79,9 @@ inline var operator-(const var& a, Arith b) {
if (unlikely(b == 0.0)) {
return a;
}
return make_callback_vari(a.vi_->val_ - b,
[avi = a.vi_, b](const auto& vi) mutable {
if (unlikely(is_nan(vi.val_))) {
avi->adj_ = NOT_A_NUMBER;
} else {
avi->adj_ += vi.adj_;
}
});
return make_callback_vari(
a.vi_->val_ - b,
[avi = a.vi_, b](const auto& vi) mutable { avi->adj_ += vi.adj_; });
}

/**
Expand All @@ -109,14 +99,9 @@ inline var operator-(const var& a, Arith b) {
*/
template <typename Arith, require_arithmetic_t<Arith>* = nullptr>
inline var operator-(Arith a, const var& b) {
return make_callback_vari(a - b.vi_->val_,
[bvi = b.vi_, a](const auto& vi) mutable {
if (unlikely(is_nan(vi.val_))) {
bvi->adj_ = NOT_A_NUMBER;
} else {
bvi->adj_ -= vi.adj_;
}
});
return make_callback_vari(
a - b.vi_->val_,
[bvi = b.vi_, a](const auto& vi) mutable { bvi->adj_ -= vi.adj_; });
}

/**
Expand Down
20 changes: 3 additions & 17 deletions stan/math/rev/core/operator_unary_decrement.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -3,27 +3,13 @@

#include <stan/math/prim/meta.hpp>
#include <stan/math/rev/core/var.hpp>
#include <stan/math/rev/core/v_vari.hpp>
#include <stan/math/rev/core/callback_vari.hpp>
#include <stan/math/prim/fun/constants.hpp>
#include <stan/math/prim/fun/is_nan.hpp>

namespace stan {
namespace math {

namespace internal {
class decrement_vari final : public op_v_vari {
public:
explicit decrement_vari(vari* avi) : op_v_vari(avi->val_ - 1.0, avi) {}
void chain() {
if (unlikely(is_nan(avi_->val_))) {
avi_->adj_ = NOT_A_NUMBER;
} else {
avi_->adj_ += adj_;
}
}
};
} // namespace internal

/**
* Prefix decrement operator for variables (C++).
*
Expand All @@ -38,7 +24,7 @@ class decrement_vari final : public op_v_vari {
* @return Reference the result of decrementing this input variable.
*/
inline var& operator--(var& a) {
a.vi_ = new internal::decrement_vari(a.vi_);
a = make_callback_var(a.val() - 1.0, [a](auto& vi) { a.adj() += vi.adj(); });
return a;
}

Expand All @@ -55,7 +41,7 @@ inline var& operator--(var& a) {
*/
inline var operator--(var& a, int /*dummy*/) {
var temp(a);
a.vi_ = new internal::decrement_vari(a.vi_);
a = make_callback_var(a.val() - 1.0, [a](auto& vi) { a.adj() += vi.adj(); });
return temp;
}

Expand Down
20 changes: 3 additions & 17 deletions stan/math/rev/core/operator_unary_increment.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -3,27 +3,13 @@

#include <stan/math/prim/meta.hpp>
#include <stan/math/rev/core/var.hpp>
#include <stan/math/rev/core/v_vari.hpp>
#include <stan/math/rev/core/callback_vari.hpp>
#include <stan/math/prim/fun/constants.hpp>
#include <stan/math/prim/fun/is_nan.hpp>

namespace stan {
namespace math {

namespace internal {
class increment_vari final : public op_v_vari {
public:
explicit increment_vari(vari* avi) : op_v_vari(avi->val_ + 1.0, avi) {}
void chain() {
if (unlikely(is_nan(avi_->val_))) {
avi_->adj_ = NOT_A_NUMBER;
} else {
avi_->adj_ += adj_;
}
}
};
} // namespace internal

/**
* Prefix increment operator for variables (C++). Following C++,
* (++a) is defined to behave exactly as (a = a + 1.0) does,
Expand All @@ -34,7 +20,7 @@ class increment_vari final : public op_v_vari {
* @return Reference the result of incrementing this input variable.
*/
inline var& operator++(var& a) {
a.vi_ = new internal::increment_vari(a.vi_);
a = make_callback_var(a.val() + 1.0, [a](auto& vi) { a.adj() += vi.adj(); });
return a;
}

Expand All @@ -51,7 +37,7 @@ inline var& operator++(var& a) {
*/
inline var operator++(var& a, int /*dummy*/) {
var temp(a);
a.vi_ = new internal::increment_vari(a.vi_);
a = make_callback_var(a.val() + 1.0, [a](auto& vi) { a.adj() += vi.adj(); });
return temp;
}

Expand Down
32 changes: 4 additions & 28 deletions stan/math/rev/core/operator_unary_negative.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -3,28 +3,13 @@

#include <stan/math/prim/meta.hpp>
#include <stan/math/rev/core/var.hpp>
#include <stan/math/rev/core/v_vari.hpp>
#include <stan/math/rev/core/callback_vari.hpp>
#include <stan/math/prim/fun/constants.hpp>
#include <stan/math/prim/fun/is_nan.hpp>

namespace stan {
namespace math {

namespace internal {
class neg_vari final : public op_v_vari {
public:
explicit neg_vari(vari* avi) : op_v_vari(-(avi->val_), avi) {}
void chain() {
if (unlikely(is_nan(avi_->val_))) {
avi_->adj_ = NOT_A_NUMBER;
} else {
avi_->adj_ -= adj_;
}
}
};
} // namespace internal

/**
* Unary negation operator for variables (C++).
*
Expand All @@ -50,13 +35,8 @@ class neg_vari final : public op_v_vari {
* @return Negation of variable.
*/
inline var operator-(const var& a) {
return make_callback_var(-a.val(), [a](const auto vi) {
if (unlikely(is_nan(a.val()))) {
a.adj() = NOT_A_NUMBER;
} else {
a.adj() -= vi.adj();
}
});
return make_callback_var(
-a.val(), [a](const auto& vi) mutable { a.adj() -= vi.adj(); });
}

/**
Expand All @@ -68,14 +48,10 @@ inline var operator-(const var& a) {
*/
template <typename T, require_var_matrix_t<T>* = nullptr>
inline auto operator-(const T& a) {
return make_callback_var(-a.val(), [a](const auto vi) {
return make_callback_var(-a.val(), [a](const auto& vi) mutable {
for (Eigen::Index j = 0; j < a.cols(); ++j) {
for (Eigen::Index i = 0; i < a.rows(); ++i) {
if (unlikely(is_nan(a.val().coeffRef(i, j)))) {
a.adj().coeffRef(i, j) = NOT_A_NUMBER;
} else {
a.adj().coeffRef(i, j) -= vi.adj().coeff(i, j);
}
a.adj().coeffRef(i, j) -= vi.adj().coeff(i, j);
}
}
});
Expand Down
Loading