-
-
Notifications
You must be signed in to change notification settings - Fork 190
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
adds unary var matrix functions #2362
Changes from 5 commits
8262c78
400a7b1
48671e5
a6f95a2
e525689
e14ad85
bfe60c0
d51a225
f34fc2d
File filter
Filter by extension
Conversations
Jump to
Diff view
Diff view
There are no files selected for viewing
Original file line number | Diff line number | Diff line change |
---|---|---|
|
@@ -44,12 +44,30 @@ namespace math { | |
* @return The corresponding unit normal cdf approximation. | ||
*/ | ||
inline var Phi_approx(const var& a) { | ||
double av = a.vi_->val_; | ||
double av_squared = av * av; | ||
double av_cubed = av * av_squared; | ||
double f = inv_logit(0.07056 * av_cubed + 1.5976 * av); | ||
double av_squared = a.val() * a.val(); | ||
double f = inv_logit(0.07056 * a.val() * av_squared + 1.5976 * a.val()); | ||
double da = f * (1 - f) * (3.0 * 0.07056 * av_squared + 1.5976); | ||
return var(new precomp_v_vari(f, a.vi_, da)); | ||
return make_callback_var( | ||
f, [a, da](auto& vi) mutable { a.adj() += vi.adj() * da; }); | ||
} | ||
|
||
template <typename T, require_var_matrix_t<T>* = nullptr> | ||
inline auto Phi_approx(const T& a) { | ||
arena_t<value_type_t<T>> f(a.rows(), a.cols()); | ||
arena_t<value_type_t<T>> da(a.rows(), a.cols()); | ||
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. Does it make sense to compute this in the forward pass? I guess we would compute av_squared twice if we compute this in the backward pass. Is that the reason? There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. yeah that's the general jist. we can compute adjoint while calculating the forward pass so it's just nice to alloc everything at once, throw it in one big loop, and then have an easy peasy reverse pass. There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. Cool. |
||
for (Eigen::Index j = 0; j < a.cols(); ++j) { | ||
for (Eigen::Index i = 0; i < a.rows(); ++i) { | ||
const auto a_val = a.val().coeff(i, j); | ||
const auto av_squared = a_val * a_val; | ||
f.coeffRef(i, j) = inv_logit(0.07056 * a_val * av_squared | ||
+ 1.5976 * a.val().coeff(i, j)); | ||
da.coeffRef(i, j) = f.coeff(i, j) * (1 - f.coeff(i, j)) | ||
* (3.0 * 0.07056 * av_squared + 1.5976); | ||
} | ||
} | ||
return make_callback_var(f, [a, da](auto& vi) mutable { | ||
a.adj().array() += vi.adj().array() * da.array(); | ||
}); | ||
} | ||
|
||
} // namespace math | ||
|
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
We probably need doxygen here?