Skip to content

Commit c5e8f08

Browse files
authored
Merge pull request #2980 from stan-dev/multi-normal-derivatives-2
Add derivatives for mvn
2 parents f72a360 + 52b36e6 commit c5e8f08

File tree

3 files changed

+251
-43
lines changed

3 files changed

+251
-43
lines changed

stan/math/prim/prob/multi_normal_lpdf.hpp

+191-42
Original file line numberDiff line numberDiff line change
@@ -5,24 +5,40 @@
55
#include <stan/math/prim/err.hpp>
66
#include <stan/math/prim/fun/as_column_vector_or_scalar.hpp>
77
#include <stan/math/prim/fun/constants.hpp>
8+
#include <stan/math/prim/fun/dot_product.hpp>
9+
#include <stan/math/prim/fun/eval.hpp>
10+
#include <stan/math/prim/fun/log.hpp>
811
#include <stan/math/prim/fun/log_determinant_ldlt.hpp>
912
#include <stan/math/prim/fun/max_size_mvt.hpp>
13+
#include <stan/math/prim/fun/mdivide_left_ldlt.hpp>
1014
#include <stan/math/prim/fun/size_mvt.hpp>
15+
#include <stan/math/prim/fun/sum.hpp>
1116
#include <stan/math/prim/fun/to_ref.hpp>
12-
#include <stan/math/prim/fun/trace_inv_quad_form_ldlt.hpp>
17+
#include <stan/math/prim/fun/transpose.hpp>
1318
#include <stan/math/prim/fun/vector_seq_view.hpp>
19+
#include <stan/math/prim/functor/partials_propagator.hpp>
1420

1521
namespace stan {
1622
namespace math {
1723

18-
template <bool propto, typename T_y, typename T_loc, typename T_covar>
24+
template <bool propto, typename T_y, typename T_loc, typename T_covar,
25+
require_any_not_vector_vt<is_stan_scalar, T_y, T_loc>* = nullptr,
26+
require_all_not_nonscalar_prim_or_rev_kernel_expression_t<
27+
T_y, T_loc, T_covar>* = nullptr>
1928
return_type_t<T_y, T_loc, T_covar> multi_normal_lpdf(const T_y& y,
2029
const T_loc& mu,
2130
const T_covar& Sigma) {
2231
using T_covar_elem = typename scalar_type<T_covar>::type;
23-
using lp_type = return_type_t<T_y, T_loc, T_covar>;
24-
using Eigen::Dynamic;
25-
static constexpr const char* function = "multi_normal_lpdf";
32+
using T_return = return_type_t<T_y, T_loc, T_covar>;
33+
using T_partials_return = partials_return_t<T_y, T_loc, T_covar>;
34+
using matrix_partials_t
35+
= Eigen::Matrix<T_partials_return, Eigen::Dynamic, Eigen::Dynamic>;
36+
using vector_partials_t = Eigen::Matrix<T_partials_return, Eigen::Dynamic, 1>;
37+
using T_y_ref = ref_type_t<T_y>;
38+
using T_mu_ref = ref_type_t<T_loc>;
39+
using T_Sigma_ref = ref_type_t<T_covar>;
40+
41+
static const char* function = "multi_normal_lpdf";
2642
check_positive(function, "Covariance matrix rows", Sigma.rows());
2743

2844
check_consistent_sizes_mvt(function, "y", y, "mu", mu);
@@ -32,32 +48,36 @@ return_type_t<T_y, T_loc, T_covar> multi_normal_lpdf(const T_y& y,
3248
return 0.0;
3349
}
3450

35-
lp_type lp(0.0);
36-
vector_seq_view<T_y> y_vec(y);
37-
vector_seq_view<T_loc> mu_vec(mu);
38-
size_t size_vec = max_size_mvt(y, mu);
51+
T_y_ref y_ref = y;
52+
T_mu_ref mu_ref = mu;
53+
T_Sigma_ref Sigma_ref = Sigma;
54+
vector_seq_view<T_y_ref> y_vec(y_ref);
55+
vector_seq_view<T_mu_ref> mu_vec(mu_ref);
56+
const size_t size_vec = max_size_mvt(y, mu);
57+
const int K = Sigma.rows();
3958

4059
int size_y = y_vec[0].size();
4160
int size_mu = mu_vec[0].size();
42-
if (size_vec > 1) {
43-
for (size_t i = 1, size_mvt_y = size_mvt(y); i < size_mvt_y; i++) {
44-
check_size_match(function,
45-
"Size of one of the vectors of "
46-
"the random variable",
47-
y_vec[i].size(),
48-
"Size of the first vector of the "
49-
"random variable",
50-
size_y);
51-
}
52-
for (size_t i = 1, size_mvt_mu = size_mvt(mu); i < size_mvt_mu; i++) {
53-
check_size_match(function,
54-
"Size of one of the vectors of "
55-
"the location variable",
56-
mu_vec[i].size(),
57-
"Size of the first vector of the "
58-
"location variable",
59-
size_mu);
60-
}
61+
62+
// check size consistency of all random variables y
63+
for (size_t i = 1, size_mvt_y = size_mvt(y); i < size_mvt_y; i++) {
64+
check_size_match(function,
65+
"Size of one of the vectors of "
66+
"the random variable",
67+
y_vec[i].size(),
68+
"Size of the first vector of the "
69+
"random variable",
70+
size_y);
71+
}
72+
// check size consistency of all means mu
73+
for (size_t i = 1, size_mvt_mu = size_mvt(mu); i < size_mvt_mu; i++) {
74+
check_size_match(function,
75+
"Size of one of the vectors of "
76+
"the location variable",
77+
mu_vec[i].size(),
78+
"Size of the first vector of the "
79+
"location variable",
80+
size_mu);
6181
}
6282

6383
check_size_match(function, "Size of random variable", size_y,
@@ -71,35 +91,164 @@ return_type_t<T_y, T_loc, T_covar> multi_normal_lpdf(const T_y& y,
7191
check_finite(function, "Location parameter", mu_vec[i]);
7292
check_not_nan(function, "Random variable", y_vec[i]);
7393
}
74-
const auto& Sigma_ref = to_ref(Sigma);
7594
check_symmetric(function, "Covariance matrix", Sigma_ref);
7695

77-
auto ldlt_Sigma = make_ldlt_factor(Sigma_ref);
96+
auto ldlt_Sigma = make_ldlt_factor(value_of(Sigma_ref));
97+
7898
check_ldlt_factor(function, "LDLT_Factor of covariance parameter",
7999
ldlt_Sigma);
80100

81-
if (size_y == 0) {
82-
return lp;
101+
if (unlikely(size_y == 0)) {
102+
return T_return(0);
83103
}
84104

105+
auto ops_partials = make_partials_propagator(y_ref, mu_ref, Sigma_ref);
106+
107+
T_partials_return logp(0);
108+
85109
if (include_summand<propto>::value) {
86-
lp += NEG_LOG_SQRT_TWO_PI * size_y * size_vec;
110+
logp += NEG_LOG_SQRT_TWO_PI * size_y * size_vec;
111+
}
112+
113+
if (include_summand<propto, T_y, T_loc, T_covar_elem>::value) {
114+
vector_partials_t half(size_vec);
115+
vector_partials_t y_val_minus_mu_val(size_vec);
116+
117+
T_partials_return sum_lp_vec(0.0);
118+
for (size_t i = 0; i < size_vec; i++) {
119+
const auto& y_val = as_value_column_vector_or_scalar(y_vec[i]);
120+
const auto& mu_val = as_value_column_vector_or_scalar(mu_vec[i]);
121+
y_val_minus_mu_val = eval(y_val - mu_val);
122+
half = mdivide_left_ldlt(ldlt_Sigma, y_val_minus_mu_val);
123+
124+
sum_lp_vec += dot_product(y_val_minus_mu_val, half);
125+
126+
if (!is_constant_all<T_y>::value) {
127+
partials_vec<0>(ops_partials)[i] += -half;
128+
}
129+
if (!is_constant_all<T_loc>::value) {
130+
partials_vec<1>(ops_partials)[i] += half;
131+
}
132+
if (!is_constant<T_covar_elem>::value) {
133+
partials_vec<2>(ops_partials)[i] += 0.5 * half * half.transpose();
134+
}
135+
}
136+
137+
logp += -0.5 * sum_lp_vec;
138+
139+
// If the covariance is not autodiff, we can avoid computing a matrix
140+
// inverse
141+
if (is_constant<T_covar_elem>::value) {
142+
if (include_summand<propto>::value) {
143+
logp += -0.5 * log_determinant_ldlt(ldlt_Sigma) * size_vec;
144+
}
145+
} else {
146+
matrix_partials_t inv_Sigma
147+
= mdivide_left_ldlt(ldlt_Sigma, Eigen::MatrixXd::Identity(K, K));
148+
149+
logp += -0.5 * log_determinant_ldlt(ldlt_Sigma) * size_vec;
150+
151+
partials<2>(ops_partials) += -0.5 * size_vec * inv_Sigma;
152+
}
153+
}
154+
155+
return ops_partials.build(logp);
156+
}
157+
158+
template <bool propto, typename T_y, typename T_loc, typename T_covar,
159+
require_all_vector_vt<is_stan_scalar, T_y, T_loc>* = nullptr,
160+
require_all_not_nonscalar_prim_or_rev_kernel_expression_t<
161+
T_y, T_loc, T_covar>* = nullptr>
162+
return_type_t<T_y, T_loc, T_covar> multi_normal_lpdf(const T_y& y,
163+
const T_loc& mu,
164+
const T_covar& Sigma) {
165+
using T_covar_elem = typename scalar_type<T_covar>::type;
166+
using T_return = return_type_t<T_y, T_loc, T_covar>;
167+
using T_partials_return = partials_return_t<T_y, T_loc, T_covar>;
168+
using matrix_partials_t
169+
= Eigen::Matrix<T_partials_return, Eigen::Dynamic, Eigen::Dynamic>;
170+
using vector_partials_t = Eigen::Matrix<T_partials_return, Eigen::Dynamic, 1>;
171+
using T_y_ref = ref_type_t<T_y>;
172+
using T_mu_ref = ref_type_t<T_loc>;
173+
using T_Sigma_ref = ref_type_t<T_covar>;
174+
175+
static const char* function = "multi_normal_lpdf";
176+
check_positive(function, "Covariance matrix rows", Sigma.rows());
177+
178+
T_y_ref y_ref = y;
179+
T_mu_ref mu_ref = mu;
180+
T_Sigma_ref Sigma_ref = Sigma;
181+
182+
decltype(auto) y_val = as_value_column_vector_or_scalar(y_ref);
183+
decltype(auto) mu_val = as_value_column_vector_or_scalar(mu_ref);
184+
185+
const int size_y = y_ref.size();
186+
const int size_mu = mu_ref.size();
187+
const unsigned int K = Sigma.rows();
188+
189+
check_finite(function, "Location parameter", mu_val);
190+
check_not_nan(function, "Random variable", y_val);
191+
192+
check_size_match(function, "Size of random variable", size_y,
193+
"size of location parameter", size_mu);
194+
check_size_match(function, "Size of random variable", size_y,
195+
"rows of covariance parameter", Sigma.rows());
196+
check_size_match(function, "Size of random variable", size_y,
197+
"columns of covariance parameter", Sigma.cols());
198+
199+
check_symmetric(function, "Covariance matrix", Sigma_ref);
200+
201+
auto ldlt_Sigma = make_ldlt_factor(value_of(Sigma_ref));
202+
check_ldlt_factor(function, "LDLT_Factor of covariance parameter",
203+
ldlt_Sigma);
204+
205+
if (unlikely(size_y == 0)) {
206+
return T_return(0);
87207
}
88208

89-
if (include_summand<propto, T_covar_elem>::value) {
90-
lp -= 0.5 * log_determinant_ldlt(ldlt_Sigma) * size_vec;
209+
auto ops_partials = make_partials_propagator(y_ref, mu_ref, Sigma_ref);
210+
211+
T_partials_return logp(0);
212+
213+
if (include_summand<propto>::value) {
214+
logp += NEG_LOG_SQRT_TWO_PI * size_y;
91215
}
92216

93217
if (include_summand<propto, T_y, T_loc, T_covar_elem>::value) {
94-
lp_type sum_lp_vec(0.0);
95-
for (size_t i = 0; i < size_vec; i++) {
96-
const auto& y_col = as_column_vector_or_scalar(y_vec[i]);
97-
const auto& mu_col = as_column_vector_or_scalar(mu_vec[i]);
98-
sum_lp_vec += trace_inv_quad_form_ldlt(ldlt_Sigma, y_col - mu_col);
218+
vector_partials_t half(size_y);
219+
vector_partials_t y_val_minus_mu_val = eval(y_val - mu_val);
220+
221+
// If the covariance is not autodiff, we can avoid computing a matrix
222+
// inverse
223+
if (is_constant<T_covar_elem>::value) {
224+
half = mdivide_left_ldlt(ldlt_Sigma, y_val_minus_mu_val);
225+
226+
if (include_summand<propto>::value) {
227+
logp += -0.5 * log_determinant_ldlt(ldlt_Sigma);
228+
}
229+
} else {
230+
matrix_partials_t inv_Sigma
231+
= mdivide_left_ldlt(ldlt_Sigma, Eigen::MatrixXd::Identity(K, K));
232+
233+
half.noalias() = inv_Sigma * y_val_minus_mu_val;
234+
235+
logp += -0.5 * log_determinant_ldlt(ldlt_Sigma);
236+
237+
edge<2>(ops_partials).partials_
238+
+= 0.5 * (half * half.transpose() - inv_Sigma);
239+
}
240+
241+
logp += -0.5 * dot_product(y_val_minus_mu_val, half);
242+
243+
if (!is_constant_all<T_y>::value) {
244+
partials<0>(ops_partials) += -half;
245+
}
246+
if (!is_constant_all<T_loc>::value) {
247+
partials<1>(ops_partials) += half;
99248
}
100-
lp -= 0.5 * sum_lp_vec;
101249
}
102-
return lp;
250+
251+
return ops_partials.build(logp);
103252
}
104253

105254
template <typename T_y, typename T_loc, typename T_covar>

test/unit/math/rev/prob/multi_normal2_test.cpp

+1-1
Original file line numberDiff line numberDiff line change
@@ -120,7 +120,7 @@ TEST_F(agrad_distributions_multi_normal_multi_row, ProptoSigma) {
120120
stan::math::recover_memory();
121121
}
122122

123-
TEST(ProbDistributionsMultiNormal, MultiNormalVar) {
123+
TEST(ProbDistributionsMultiNormal, MultiNormalVar2) {
124124
using Eigen::Dynamic;
125125
using Eigen::Matrix;
126126
using stan::math::var;
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,59 @@
1+
#include <stan/math/rev.hpp>
2+
#include <test/unit/math/rev/util.hpp>
3+
#include <gtest/gtest.h>
4+
5+
TEST(ProbDistributionsMultiNormal, MultiNormalVar) {
6+
using Eigen::Dynamic;
7+
using Eigen::Matrix;
8+
using stan::math::var;
9+
using std::vector;
10+
Matrix<var, Dynamic, 1> y(3, 1);
11+
y << 2.0, -2.0, 11.0;
12+
Matrix<var, Dynamic, 1> mu(3, 1);
13+
mu << 1.0, -1.0, 3.0;
14+
Matrix<var, Dynamic, Dynamic> Sigma(3, 3);
15+
Sigma << 9.0, -3.0, 0.0, -3.0, 4.0, 0.0, 0.0, 0.0, 5.0;
16+
EXPECT_FLOAT_EQ(-11.73908, stan::math::multi_normal_lpdf(y, mu, Sigma).val());
17+
}
18+
19+
TEST(ProbDistributionsMultiNormal, check_varis_on_stack) {
20+
using Eigen::Dynamic;
21+
using Eigen::Matrix;
22+
using stan::math::to_var;
23+
using std::vector;
24+
Matrix<double, Dynamic, 1> y(3, 1);
25+
y << 2.0, -2.0, 11.0;
26+
Matrix<double, Dynamic, 1> mu(3, 1);
27+
mu << 1.0, -1.0, 3.0;
28+
Matrix<double, Dynamic, Dynamic> Sigma(3, 3);
29+
Sigma << 9.0, -3.0, 0.0, -3.0, 4.0, 0.0, 0.0, 0.0, 5.0;
30+
test::check_varis_on_stack(stan::math::multi_normal_lpdf<true>(
31+
to_var(y), to_var(mu), to_var(Sigma)));
32+
test::check_varis_on_stack(
33+
stan::math::multi_normal_lpdf<true>(to_var(y), to_var(mu), Sigma));
34+
test::check_varis_on_stack(
35+
stan::math::multi_normal_lpdf<true>(to_var(y), mu, to_var(Sigma)));
36+
test::check_varis_on_stack(
37+
stan::math::multi_normal_lpdf<true>(to_var(y), mu, Sigma));
38+
test::check_varis_on_stack(
39+
stan::math::multi_normal_lpdf<true>(y, to_var(mu), to_var(Sigma)));
40+
test::check_varis_on_stack(
41+
stan::math::multi_normal_lpdf<true>(y, to_var(mu), Sigma));
42+
test::check_varis_on_stack(
43+
stan::math::multi_normal_lpdf<true>(y, mu, to_var(Sigma)));
44+
45+
test::check_varis_on_stack(stan::math::multi_normal_lpdf<false>(
46+
to_var(y), to_var(mu), to_var(Sigma)));
47+
test::check_varis_on_stack(
48+
stan::math::multi_normal_lpdf<false>(to_var(y), to_var(mu), Sigma));
49+
test::check_varis_on_stack(
50+
stan::math::multi_normal_lpdf<false>(to_var(y), mu, to_var(Sigma)));
51+
test::check_varis_on_stack(
52+
stan::math::multi_normal_lpdf<false>(to_var(y), mu, Sigma));
53+
test::check_varis_on_stack(
54+
stan::math::multi_normal_lpdf<false>(y, to_var(mu), to_var(Sigma)));
55+
test::check_varis_on_stack(
56+
stan::math::multi_normal_lpdf<false>(y, to_var(mu), Sigma));
57+
test::check_varis_on_stack(
58+
stan::math::multi_normal_lpdf<false>(y, mu, to_var(Sigma)));
59+
}

0 commit comments

Comments
 (0)