Skip to content

Commit a4abbe2

Browse files
committed
Merge remote-tracking branch 'origin/develop' into feature/lb_constrain-matvar
2 parents ae36dad + 0d75edc commit a4abbe2

7 files changed

+205
-60
lines changed

stan/math/opencl/prim/beta_binomial_lpmf.hpp

+7-3
Original file line numberDiff line numberDiff line change
@@ -58,8 +58,11 @@ return_type_t<T_n_cl, T_size1_cl, T_size2_cl> beta_binomial_lpmf(
5858
return 0.0;
5959
}
6060

61-
const auto& alpha_val = value_of(alpha);
62-
const auto& beta_val = value_of(beta);
61+
const auto& alpha_col = as_column_vector_or_scalar(alpha);
62+
const auto& beta_col = as_column_vector_or_scalar(beta);
63+
64+
const auto& alpha_val = value_of(alpha_col);
65+
const auto& beta_val = value_of(beta_col);
6366

6467
auto check_N_nonnegative
6568
= check_cl(function, "Population size parameter", N, "nonnegative");
@@ -102,7 +105,8 @@ return_type_t<T_n_cl, T_size1_cl, T_size2_cl> beta_binomial_lpmf(
102105

103106
double logp = sum(from_matrix_cl(logp_cl));
104107

105-
operands_and_partials<T_size1_cl, T_size2_cl> ops_partials(alpha, beta);
108+
operands_and_partials<decltype(alpha_col), decltype(beta_col)> ops_partials(
109+
alpha_col, beta_col);
106110
if (!is_constant<T_size1_cl>::value) {
107111
ops_partials.edge1_.partials_ = std::move(alpha_deriv_cl);
108112
}

stan/math/opencl/prim/binomial_logit_lpmf.hpp

+3-2
Original file line numberDiff line numberDiff line change
@@ -47,7 +47,8 @@ return_type_t<T_prob_cl> binomial_logit_lpmf(const T_n_cl& n, const T_N_cl N,
4747
return 0.0;
4848
}
4949

50-
const auto& alpha_val = value_of(alpha);
50+
const auto& alpha_col = as_column_vector_or_scalar(alpha);
51+
const auto& alpha_val = value_of(alpha_col);
5152

5253
auto check_n_bounded
5354
= check_cl(function, "Successes variable", n, "in the interval [0, N]");
@@ -81,7 +82,7 @@ return_type_t<T_prob_cl> binomial_logit_lpmf(const T_n_cl& n, const T_N_cl N,
8182
calc_if<!is_constant<T_prob_cl>::value>(alpha_deriv));
8283

8384
T_partials_return logp = sum(from_matrix_cl(logp_cl));
84-
operands_and_partials<T_prob_cl> ops_partials(alpha);
85+
operands_and_partials<decltype(alpha_col)> ops_partials(alpha_col);
8586
if (!is_constant<T_prob_cl>::value) {
8687
ops_partials.edge1_.partials_ = std::move(alpha_deriv_cl);
8788
}

stan/math/opencl/rev.hpp

+1-1
Original file line numberDiff line numberDiff line change
@@ -3,9 +3,9 @@
33
#ifdef STAN_OPENCL
44

55
#include <stan/math/opencl/prim.hpp>
6-
#include <stan/math/opencl/rev/as_column_vector_or_scalar.hpp>
76
#include <stan/math/opencl/rev/append_col.hpp>
87
#include <stan/math/opencl/rev/append_row.hpp>
8+
#include <stan/math/opencl/rev/as_column_vector_or_scalar.hpp>
99
#include <stan/math/opencl/rev/acos.hpp>
1010
#include <stan/math/opencl/rev/acosh.hpp>
1111
#include <stan/math/opencl/rev/add.hpp>

test/unit/math/opencl/rev/beta_binomial_lpmf_test.cpp

+32
Original file line numberDiff line numberDiff line change
@@ -100,6 +100,12 @@ TEST(ProbDistributionsBetaBinomial, opencl_matches_cpu_small) {
100100
N, alpha, beta);
101101
stan::math::test::compare_cpu_opencl_prim_rev(
102102
beta_binomial_lpmf_functor_propto, n, N, alpha, beta);
103+
stan::math::test::compare_cpu_opencl_prim_rev(beta_binomial_lpmf_functor, n,
104+
N, alpha.transpose().eval(),
105+
beta.transpose().eval());
106+
stan::math::test::compare_cpu_opencl_prim_rev(
107+
beta_binomial_lpmf_functor_propto, n, N, alpha.transpose().eval(),
108+
beta.transpose().eval());
103109
}
104110

105111
TEST(ProbDistributionsBetaBinomial, opencl_broadcast_n) {
@@ -116,6 +122,12 @@ TEST(ProbDistributionsBetaBinomial, opencl_broadcast_n) {
116122
beta_binomial_lpmf_functor, n, N, alpha, beta);
117123
stan::math::test::test_opencl_broadcasting_prim_rev<0>(
118124
beta_binomial_lpmf_functor_propto, n, N, alpha, beta);
125+
stan::math::test::test_opencl_broadcasting_prim_rev<0>(
126+
beta_binomial_lpmf_functor, n, N, alpha.transpose().eval(),
127+
beta.transpose().eval());
128+
stan::math::test::test_opencl_broadcasting_prim_rev<0>(
129+
beta_binomial_lpmf_functor_propto, n, N, alpha.transpose().eval(),
130+
beta.transpose().eval());
119131
}
120132

121133
TEST(ProbDistributionsBetaBinomial, opencl_broadcast_N) {
@@ -132,6 +144,12 @@ TEST(ProbDistributionsBetaBinomial, opencl_broadcast_N) {
132144
beta_binomial_lpmf_functor, n, N, alpha, beta);
133145
stan::math::test::test_opencl_broadcasting_prim_rev<1>(
134146
beta_binomial_lpmf_functor_propto, n, N, alpha, beta);
147+
stan::math::test::test_opencl_broadcasting_prim_rev<1>(
148+
beta_binomial_lpmf_functor, n, N, alpha.transpose().eval(),
149+
beta.transpose().eval());
150+
stan::math::test::test_opencl_broadcasting_prim_rev<1>(
151+
beta_binomial_lpmf_functor_propto, n, N, alpha.transpose().eval(),
152+
beta.transpose().eval());
135153
}
136154

137155
TEST(ProbDistributionsBetaBinomial, opencl_broadcast_alpha) {
@@ -147,6 +165,10 @@ TEST(ProbDistributionsBetaBinomial, opencl_broadcast_alpha) {
147165
beta_binomial_lpmf_functor, n, N, alpha, beta);
148166
stan::math::test::test_opencl_broadcasting_prim_rev<2>(
149167
beta_binomial_lpmf_functor_propto, n, N, alpha, beta);
168+
stan::math::test::test_opencl_broadcasting_prim_rev<2>(
169+
beta_binomial_lpmf_functor, n, N, alpha, beta.transpose().eval());
170+
stan::math::test::test_opencl_broadcasting_prim_rev<2>(
171+
beta_binomial_lpmf_functor_propto, n, N, alpha, beta.transpose().eval());
150172
}
151173

152174
TEST(ProbDistributionsBetaBinomial, opencl_broadcast_beta) {
@@ -161,6 +183,10 @@ TEST(ProbDistributionsBetaBinomial, opencl_broadcast_beta) {
161183
beta_binomial_lpmf_functor, n, N, alpha, beta);
162184
stan::math::test::test_opencl_broadcasting_prim_rev<3>(
163185
beta_binomial_lpmf_functor_propto, n, N, alpha, beta);
186+
stan::math::test::test_opencl_broadcasting_prim_rev<3>(
187+
beta_binomial_lpmf_functor, n, N, alpha.transpose().eval(), beta);
188+
stan::math::test::test_opencl_broadcasting_prim_rev<3>(
189+
beta_binomial_lpmf_functor_propto, n, N, alpha.transpose().eval(), beta);
164190
}
165191

166192
TEST(ProbDistributionsBetaBinomial, opencl_matches_cpu_big) {
@@ -182,6 +208,12 @@ TEST(ProbDistributionsBetaBinomial, opencl_matches_cpu_big) {
182208
N, alpha, beta);
183209
stan::math::test::compare_cpu_opencl_prim_rev(
184210
beta_binomial_lpmf_functor_propto, n, N, alpha, beta);
211+
stan::math::test::compare_cpu_opencl_prim_rev(beta_binomial_lpmf_functor, n,
212+
N, alpha.transpose().eval(),
213+
beta.transpose().eval());
214+
stan::math::test::compare_cpu_opencl_prim_rev(
215+
beta_binomial_lpmf_functor_propto, n, N, alpha.transpose().eval(),
216+
beta.transpose().eval());
185217
}
186218

187219
#endif

test/unit/math/opencl/rev/binomial_logit_lpmf_test.cpp

+16
Original file line numberDiff line numberDiff line change
@@ -71,6 +71,10 @@ TEST(ProbDistributionsBinomialLogit, opencl_matches_cpu_small) {
7171
m, alpha);
7272
stan::math::test::compare_cpu_opencl_prim_rev(
7373
binomial_logit_lpmf_functor_propto, n, m, alpha);
74+
stan::math::test::compare_cpu_opencl_prim_rev(binomial_logit_lpmf_functor, n,
75+
m, alpha.transpose().eval());
76+
stan::math::test::compare_cpu_opencl_prim_rev(
77+
binomial_logit_lpmf_functor_propto, n, m, alpha.transpose().eval());
7478
}
7579

7680
TEST(ProbDistributionsBinomialLogit, opencl_broadcast_n) {
@@ -85,6 +89,10 @@ TEST(ProbDistributionsBinomialLogit, opencl_broadcast_n) {
8589
binomial_logit_lpmf_functor, n, m, alpha);
8690
stan::math::test::test_opencl_broadcasting_prim_rev<0>(
8791
binomial_logit_lpmf_functor_propto, n, m, alpha);
92+
stan::math::test::test_opencl_broadcasting_prim_rev<0>(
93+
binomial_logit_lpmf_functor, n, m, alpha.transpose().eval());
94+
stan::math::test::test_opencl_broadcasting_prim_rev<0>(
95+
binomial_logit_lpmf_functor_propto, n, m, alpha.transpose().eval());
8896
}
8997

9098
TEST(ProbDistributionsBinomialLogit, opencl_broadcast_N) {
@@ -99,6 +107,10 @@ TEST(ProbDistributionsBinomialLogit, opencl_broadcast_N) {
99107
binomial_logit_lpmf_functor, n, m, alpha);
100108
stan::math::test::test_opencl_broadcasting_prim_rev<1>(
101109
binomial_logit_lpmf_functor_propto, n, m, alpha);
110+
stan::math::test::test_opencl_broadcasting_prim_rev<1>(
111+
binomial_logit_lpmf_functor, n, m, alpha.transpose().eval());
112+
stan::math::test::test_opencl_broadcasting_prim_rev<1>(
113+
binomial_logit_lpmf_functor_propto, n, m, alpha.transpose().eval());
102114
}
103115

104116
TEST(ProbDistributionsBinomialLogit, opencl_broadcast_alpha) {
@@ -131,6 +143,10 @@ TEST(ProbDistributionsBinomialLogit, opencl_matches_cpu_big) {
131143
m, alpha);
132144
stan::math::test::compare_cpu_opencl_prim_rev(
133145
binomial_logit_lpmf_functor_propto, n, m, alpha);
146+
stan::math::test::compare_cpu_opencl_prim_rev(binomial_logit_lpmf_functor, n,
147+
m, alpha.transpose().eval());
148+
stan::math::test::compare_cpu_opencl_prim_rev(
149+
binomial_logit_lpmf_functor_propto, n, m, alpha.transpose().eval());
134150
}
135151

136152
TEST(ProbDistributionsBinomialLogit, opencl_n_N_scalar) {

test/unit/math/opencl/rev/dirichlet_lpdf_test.cpp

+74-28
Original file line numberDiff line numberDiff line change
@@ -80,71 +80,117 @@ TEST(ProbDistributionsDirichlet, opencl_matches_cpu_small) {
8080

8181
Eigen::VectorXd theta1(N);
8282
theta1 << 0.5, 0.4, 0.1;
83-
Eigen::VectorXd theta2(N);
83+
Eigen::RowVectorXd theta2(N);
8484
theta2 << 0.6, 0.2, 0.2;
85-
std::vector<Eigen::VectorXd> theta{theta1, theta2};
85+
std::vector<Eigen::VectorXd> theta3{theta1, theta2};
86+
std::vector<Eigen::RowVectorXd> theta4{theta2, theta1};
8687
Eigen::VectorXd alpha1(N);
8788
alpha1 << 0.5, 0.1, 12.3;
88-
Eigen::VectorXd alpha2(N);
89+
Eigen::RowVectorXd alpha2(N);
8990
alpha2 << 2.1, 3.4, 2.3;
90-
std::vector<Eigen::VectorXd> alpha{alpha1, alpha2};
91+
std::vector<Eigen::VectorXd> alpha3{alpha1, alpha2};
92+
std::vector<Eigen::RowVectorXd> alpha4{alpha2, alpha1};
9193

92-
stan::math::test::compare_cpu_opencl_prim_rev(dirichlet_lpdf_functor, theta,
93-
alpha);
94+
stan::math::test::compare_cpu_opencl_prim_rev(dirichlet_lpdf_functor, theta1,
95+
alpha1);
9496
stan::math::test::compare_cpu_opencl_prim_rev(dirichlet_lpdf_functor_propto,
95-
theta, alpha);
96-
97+
theta1, alpha1);
98+
stan::math::test::compare_cpu_opencl_prim_rev(dirichlet_lpdf_functor, theta1,
99+
alpha3);
100+
stan::math::test::compare_cpu_opencl_prim_rev(dirichlet_lpdf_functor_propto,
101+
theta1, alpha3);
97102
stan::math::test::compare_cpu_opencl_prim_rev(dirichlet_lpdf_functor, theta1,
98-
alpha);
103+
alpha4);
99104
stan::math::test::compare_cpu_opencl_prim_rev(dirichlet_lpdf_functor_propto,
100-
theta1, alpha);
105+
theta1, alpha4);
101106

102-
stan::math::test::compare_cpu_opencl_prim_rev(dirichlet_lpdf_functor, theta,
107+
stan::math::test::compare_cpu_opencl_prim_rev(dirichlet_lpdf_functor, theta3,
103108
alpha1);
104109
stan::math::test::compare_cpu_opencl_prim_rev(dirichlet_lpdf_functor_propto,
105-
theta, alpha1);
106-
110+
theta3, alpha1);
107111
stan::math::test::compare_cpu_opencl_prim_rev(dirichlet_lpdf_functor, theta1,
112+
alpha3);
113+
stan::math::test::compare_cpu_opencl_prim_rev(dirichlet_lpdf_functor_propto,
114+
theta3, alpha3);
115+
stan::math::test::compare_cpu_opencl_prim_rev(dirichlet_lpdf_functor, theta3,
116+
alpha4);
117+
stan::math::test::compare_cpu_opencl_prim_rev(dirichlet_lpdf_functor_propto,
118+
theta3, alpha4);
119+
120+
stan::math::test::compare_cpu_opencl_prim_rev(dirichlet_lpdf_functor, theta4,
108121
alpha1);
109122
stan::math::test::compare_cpu_opencl_prim_rev(dirichlet_lpdf_functor_propto,
110-
theta1, alpha1);
123+
theta4, alpha1);
124+
stan::math::test::compare_cpu_opencl_prim_rev(dirichlet_lpdf_functor, theta4,
125+
alpha3);
126+
stan::math::test::compare_cpu_opencl_prim_rev(dirichlet_lpdf_functor_propto,
127+
theta4, alpha3);
128+
stan::math::test::compare_cpu_opencl_prim_rev(dirichlet_lpdf_functor, theta4,
129+
alpha4);
130+
stan::math::test::compare_cpu_opencl_prim_rev(dirichlet_lpdf_functor_propto,
131+
theta4, alpha4);
111132
}
112133

113134
TEST(ProbDistributionsDirichlet, opencl_matches_cpu_big) {
114135
int N = 153;
115136
int M = 11;
116137
Eigen::VectorXd theta1;
117138
Eigen::VectorXd alpha1;
118-
std::vector<Eigen::VectorXd> theta;
119-
std::vector<Eigen::VectorXd> alpha;
139+
std::vector<Eigen::VectorXd> theta3;
140+
std::vector<Eigen::VectorXd> alpha3;
141+
std::vector<Eigen::RowVectorXd> theta4;
142+
std::vector<Eigen::RowVectorXd> alpha4;
120143

121144
for (int i = 0; i < M; i++) {
122145
theta1 = Eigen::Array<double, Eigen::Dynamic, 1>::Random(N, 1).abs();
123146
theta1 /= theta1.sum();
124147
alpha1 = Eigen::Array<double, Eigen::Dynamic, 1>::Random(N, 1).abs();
125-
theta.push_back(theta1);
126-
alpha.push_back(alpha1);
148+
theta3.push_back(theta1);
149+
alpha3.push_back(alpha1);
150+
theta4.push_back(theta1);
151+
alpha4.push_back(alpha1);
127152
}
153+
Eigen::RowVectorXd theta2 = theta1;
154+
Eigen::RowVectorXd alpha2 = alpha1;
128155

129-
stan::math::test::compare_cpu_opencl_prim_rev(dirichlet_lpdf_functor, theta,
130-
alpha);
156+
stan::math::test::compare_cpu_opencl_prim_rev(dirichlet_lpdf_functor, theta1,
157+
alpha1);
131158
stan::math::test::compare_cpu_opencl_prim_rev(dirichlet_lpdf_functor_propto,
132-
theta, alpha);
133-
159+
theta1, alpha1);
160+
stan::math::test::compare_cpu_opencl_prim_rev(dirichlet_lpdf_functor, theta1,
161+
alpha3);
162+
stan::math::test::compare_cpu_opencl_prim_rev(dirichlet_lpdf_functor_propto,
163+
theta1, alpha3);
134164
stan::math::test::compare_cpu_opencl_prim_rev(dirichlet_lpdf_functor, theta1,
135-
alpha);
165+
alpha4);
136166
stan::math::test::compare_cpu_opencl_prim_rev(dirichlet_lpdf_functor_propto,
137-
theta1, alpha);
167+
theta1, alpha4);
138168

139-
stan::math::test::compare_cpu_opencl_prim_rev(dirichlet_lpdf_functor, theta,
169+
stan::math::test::compare_cpu_opencl_prim_rev(dirichlet_lpdf_functor, theta3,
140170
alpha1);
141171
stan::math::test::compare_cpu_opencl_prim_rev(dirichlet_lpdf_functor_propto,
142-
theta, alpha1);
143-
172+
theta3, alpha1);
144173
stan::math::test::compare_cpu_opencl_prim_rev(dirichlet_lpdf_functor, theta1,
174+
alpha3);
175+
stan::math::test::compare_cpu_opencl_prim_rev(dirichlet_lpdf_functor_propto,
176+
theta3, alpha3);
177+
stan::math::test::compare_cpu_opencl_prim_rev(dirichlet_lpdf_functor, theta3,
178+
alpha4);
179+
stan::math::test::compare_cpu_opencl_prim_rev(dirichlet_lpdf_functor_propto,
180+
theta3, alpha4);
181+
182+
stan::math::test::compare_cpu_opencl_prim_rev(dirichlet_lpdf_functor, theta4,
145183
alpha1);
146184
stan::math::test::compare_cpu_opencl_prim_rev(dirichlet_lpdf_functor_propto,
147-
theta1, alpha1);
185+
theta4, alpha1);
186+
stan::math::test::compare_cpu_opencl_prim_rev(dirichlet_lpdf_functor, theta4,
187+
alpha3);
188+
stan::math::test::compare_cpu_opencl_prim_rev(dirichlet_lpdf_functor_propto,
189+
theta4, alpha3);
190+
stan::math::test::compare_cpu_opencl_prim_rev(dirichlet_lpdf_functor, theta4,
191+
alpha4);
192+
stan::math::test::compare_cpu_opencl_prim_rev(dirichlet_lpdf_functor_propto,
193+
theta4, alpha4);
148194
}
149195

150196
#endif

0 commit comments

Comments
 (0)