|
15 | 15 | namespace stan {
|
16 | 16 | namespace math {
|
17 | 17 |
|
18 |
| -template <typename T, int R, int C> |
19 |
| -inline Eigen::Matrix<fvar<T>, R, C> unit_vector_constrain( |
20 |
| - const Eigen::Matrix<fvar<T>, R, C>& y) { |
21 |
| - using Eigen::Matrix; |
22 |
| - using std::sqrt; |
23 |
| - |
24 |
| - Matrix<T, R, C> y_t(y.size()); |
25 |
| - for (int k = 0; k < y.size(); ++k) { |
26 |
| - y_t.coeffRef(k) = y.coeff(k).val_; |
27 |
| - } |
28 |
| - |
29 |
| - Matrix<T, R, C> unit_vector_y_t = unit_vector_constrain(y_t); |
30 |
| - Matrix<fvar<T>, R, C> unit_vector_y(y.size()); |
31 |
| - for (int k = 0; k < y.size(); ++k) { |
32 |
| - unit_vector_y.coeffRef(k).val_ = unit_vector_y_t.coeff(k); |
33 |
| - } |
34 |
| - |
35 |
| - T squared_norm = dot_self(y_t); |
36 |
| - T norm = sqrt(squared_norm); |
37 |
| - T inv_norm = inv(norm); |
38 |
| - Matrix<T, Eigen::Dynamic, Eigen::Dynamic> J |
39 |
| - = divide(tcrossprod(y_t), -norm * squared_norm); |
40 |
| - |
41 |
| - for (int m = 0; m < y.size(); ++m) { |
| 18 | +template <typename EigMat, require_eigen_vt<is_fvar, EigMat>* = nullptr> |
| 19 | +inline auto unit_vector_constrain(const EigMat& y) { |
| 20 | + using eig_partial = partials_type_t<value_type_t<EigMat>>; |
| 21 | + promote_scalar_t<eig_partial, EigMat> y_val(value_of(y)); |
| 22 | + plain_type_t<EigMat> unit_vector_y(y_val.size()); |
| 23 | + unit_vector_y.val() = unit_vector_constrain(y_val); |
| 24 | + |
| 25 | + eig_partial squared_norm = dot_self(y_val); |
| 26 | + eig_partial norm = sqrt(squared_norm); |
| 27 | + eig_partial inv_norm = inv(norm); |
| 28 | + Eigen::Matrix<eig_partial, Eigen::Dynamic, Eigen::Dynamic> J |
| 29 | + = divide(tcrossprod(y_val), -norm * squared_norm); |
| 30 | + |
| 31 | + for (Eigen::Index m = 0; m < y_val.size(); ++m) { |
42 | 32 | J.coeffRef(m, m) += inv_norm;
|
43 |
| - for (int k = 0; k < y.size(); ++k) { |
| 33 | + for (Eigen::Index k = 0; k < y_val.size(); ++k) { |
44 | 34 | unit_vector_y.coeffRef(k).d_ = J.coeff(k, m);
|
45 | 35 | }
|
46 | 36 | }
|
47 | 37 | return unit_vector_y;
|
48 | 38 | }
|
49 | 39 |
|
50 |
| -template <typename T, int R, int C> |
51 |
| -inline Eigen::Matrix<fvar<T>, R, C> unit_vector_constrain( |
52 |
| - const Eigen::Matrix<fvar<T>, R, C>& y, fvar<T>& lp) { |
53 |
| - fvar<T> squared_norm = dot_self(y); |
| 40 | +template <typename EigMat, typename T, |
| 41 | + require_eigen_vt<is_fvar, EigMat>* = nullptr, |
| 42 | + require_stan_scalar_t<T>* = nullptr> |
| 43 | +inline auto unit_vector_constrain(const EigMat& y, T& lp) { |
| 44 | + const auto& y_ref = to_ref(y); |
| 45 | + const value_type_t<EigMat> squared_norm = dot_self(y_ref); |
54 | 46 | lp -= 0.5 * squared_norm;
|
55 |
| - return unit_vector_constrain(y); |
| 47 | + return unit_vector_constrain(y_ref); |
56 | 48 | }
|
57 | 49 |
|
58 | 50 | } // namespace math
|
|
0 commit comments