@@ -80,71 +80,117 @@ TEST(ProbDistributionsDirichlet, opencl_matches_cpu_small) {
80
80
81
81
Eigen::VectorXd theta1 (N);
82
82
theta1 << 0.5 , 0.4 , 0.1 ;
83
- Eigen::VectorXd theta2 (N);
83
+ Eigen::RowVectorXd theta2 (N);
84
84
theta2 << 0.6 , 0.2 , 0.2 ;
85
- std::vector<Eigen::VectorXd> theta{theta1, theta2};
85
+ std::vector<Eigen::VectorXd> theta3{theta1, theta2};
86
+ std::vector<Eigen::RowVectorXd> theta4{theta2, theta1};
86
87
Eigen::VectorXd alpha1 (N);
87
88
alpha1 << 0.5 , 0.1 , 12.3 ;
88
- Eigen::VectorXd alpha2 (N);
89
+ Eigen::RowVectorXd alpha2 (N);
89
90
alpha2 << 2.1 , 3.4 , 2.3 ;
90
- std::vector<Eigen::VectorXd> alpha{alpha1, alpha2};
91
+ std::vector<Eigen::VectorXd> alpha3{alpha1, alpha2};
92
+ std::vector<Eigen::RowVectorXd> alpha4{alpha2, alpha1};
91
93
92
- stan::math::test::compare_cpu_opencl_prim_rev (dirichlet_lpdf_functor, theta ,
93
- alpha );
94
+ stan::math::test::compare_cpu_opencl_prim_rev (dirichlet_lpdf_functor, theta1 ,
95
+ alpha1 );
94
96
stan::math::test::compare_cpu_opencl_prim_rev (dirichlet_lpdf_functor_propto,
95
- theta, alpha);
96
-
97
+ theta1, alpha1);
98
+ stan::math::test::compare_cpu_opencl_prim_rev (dirichlet_lpdf_functor, theta1,
99
+ alpha3);
100
+ stan::math::test::compare_cpu_opencl_prim_rev (dirichlet_lpdf_functor_propto,
101
+ theta1, alpha3);
97
102
stan::math::test::compare_cpu_opencl_prim_rev (dirichlet_lpdf_functor, theta1,
98
- alpha );
103
+ alpha4 );
99
104
stan::math::test::compare_cpu_opencl_prim_rev (dirichlet_lpdf_functor_propto,
100
- theta1, alpha );
105
+ theta1, alpha4 );
101
106
102
- stan::math::test::compare_cpu_opencl_prim_rev (dirichlet_lpdf_functor, theta ,
107
+ stan::math::test::compare_cpu_opencl_prim_rev (dirichlet_lpdf_functor, theta3 ,
103
108
alpha1);
104
109
stan::math::test::compare_cpu_opencl_prim_rev (dirichlet_lpdf_functor_propto,
105
- theta, alpha1);
106
-
110
+ theta3, alpha1);
107
111
stan::math::test::compare_cpu_opencl_prim_rev (dirichlet_lpdf_functor, theta1,
112
+ alpha3);
113
+ stan::math::test::compare_cpu_opencl_prim_rev (dirichlet_lpdf_functor_propto,
114
+ theta3, alpha3);
115
+ stan::math::test::compare_cpu_opencl_prim_rev (dirichlet_lpdf_functor, theta3,
116
+ alpha4);
117
+ stan::math::test::compare_cpu_opencl_prim_rev (dirichlet_lpdf_functor_propto,
118
+ theta3, alpha4);
119
+
120
+ stan::math::test::compare_cpu_opencl_prim_rev (dirichlet_lpdf_functor, theta4,
108
121
alpha1);
109
122
stan::math::test::compare_cpu_opencl_prim_rev (dirichlet_lpdf_functor_propto,
110
- theta1, alpha1);
123
+ theta4, alpha1);
124
+ stan::math::test::compare_cpu_opencl_prim_rev (dirichlet_lpdf_functor, theta4,
125
+ alpha3);
126
+ stan::math::test::compare_cpu_opencl_prim_rev (dirichlet_lpdf_functor_propto,
127
+ theta4, alpha3);
128
+ stan::math::test::compare_cpu_opencl_prim_rev (dirichlet_lpdf_functor, theta4,
129
+ alpha4);
130
+ stan::math::test::compare_cpu_opencl_prim_rev (dirichlet_lpdf_functor_propto,
131
+ theta4, alpha4);
111
132
}
112
133
113
134
TEST (ProbDistributionsDirichlet, opencl_matches_cpu_big) {
114
135
int N = 153 ;
115
136
int M = 11 ;
116
137
Eigen::VectorXd theta1;
117
138
Eigen::VectorXd alpha1;
118
- std::vector<Eigen::VectorXd> theta;
119
- std::vector<Eigen::VectorXd> alpha;
139
+ std::vector<Eigen::VectorXd> theta3;
140
+ std::vector<Eigen::VectorXd> alpha3;
141
+ std::vector<Eigen::RowVectorXd> theta4;
142
+ std::vector<Eigen::RowVectorXd> alpha4;
120
143
121
144
for (int i = 0 ; i < M; i++) {
122
145
theta1 = Eigen::Array<double , Eigen::Dynamic, 1 >::Random (N, 1 ).abs ();
123
146
theta1 /= theta1.sum ();
124
147
alpha1 = Eigen::Array<double , Eigen::Dynamic, 1 >::Random (N, 1 ).abs ();
125
- theta.push_back (theta1);
126
- alpha.push_back (alpha1);
148
+ theta3.push_back (theta1);
149
+ alpha3.push_back (alpha1);
150
+ theta4.push_back (theta1);
151
+ alpha4.push_back (alpha1);
127
152
}
153
+ Eigen::RowVectorXd theta2 = theta1;
154
+ Eigen::RowVectorXd alpha2 = alpha1;
128
155
129
- stan::math::test::compare_cpu_opencl_prim_rev (dirichlet_lpdf_functor, theta ,
130
- alpha );
156
+ stan::math::test::compare_cpu_opencl_prim_rev (dirichlet_lpdf_functor, theta1 ,
157
+ alpha1 );
131
158
stan::math::test::compare_cpu_opencl_prim_rev (dirichlet_lpdf_functor_propto,
132
- theta, alpha);
133
-
159
+ theta1, alpha1);
160
+ stan::math::test::compare_cpu_opencl_prim_rev (dirichlet_lpdf_functor, theta1,
161
+ alpha3);
162
+ stan::math::test::compare_cpu_opencl_prim_rev (dirichlet_lpdf_functor_propto,
163
+ theta1, alpha3);
134
164
stan::math::test::compare_cpu_opencl_prim_rev (dirichlet_lpdf_functor, theta1,
135
- alpha );
165
+ alpha4 );
136
166
stan::math::test::compare_cpu_opencl_prim_rev (dirichlet_lpdf_functor_propto,
137
- theta1, alpha );
167
+ theta1, alpha4 );
138
168
139
- stan::math::test::compare_cpu_opencl_prim_rev (dirichlet_lpdf_functor, theta ,
169
+ stan::math::test::compare_cpu_opencl_prim_rev (dirichlet_lpdf_functor, theta3 ,
140
170
alpha1);
141
171
stan::math::test::compare_cpu_opencl_prim_rev (dirichlet_lpdf_functor_propto,
142
- theta, alpha1);
143
-
172
+ theta3, alpha1);
144
173
stan::math::test::compare_cpu_opencl_prim_rev (dirichlet_lpdf_functor, theta1,
174
+ alpha3);
175
+ stan::math::test::compare_cpu_opencl_prim_rev (dirichlet_lpdf_functor_propto,
176
+ theta3, alpha3);
177
+ stan::math::test::compare_cpu_opencl_prim_rev (dirichlet_lpdf_functor, theta3,
178
+ alpha4);
179
+ stan::math::test::compare_cpu_opencl_prim_rev (dirichlet_lpdf_functor_propto,
180
+ theta3, alpha4);
181
+
182
+ stan::math::test::compare_cpu_opencl_prim_rev (dirichlet_lpdf_functor, theta4,
145
183
alpha1);
146
184
stan::math::test::compare_cpu_opencl_prim_rev (dirichlet_lpdf_functor_propto,
147
- theta1, alpha1);
185
+ theta4, alpha1);
186
+ stan::math::test::compare_cpu_opencl_prim_rev (dirichlet_lpdf_functor, theta4,
187
+ alpha3);
188
+ stan::math::test::compare_cpu_opencl_prim_rev (dirichlet_lpdf_functor_propto,
189
+ theta4, alpha3);
190
+ stan::math::test::compare_cpu_opencl_prim_rev (dirichlet_lpdf_functor, theta4,
191
+ alpha4);
192
+ stan::math::test::compare_cpu_opencl_prim_rev (dirichlet_lpdf_functor_propto,
193
+ theta4, alpha4);
148
194
}
149
195
150
196
#endif
0 commit comments