Skip to content

Commit 122d8b5

Browse files
authored
Merge pull request #2338 from bstatcomp/opencl_ordered_logistic
Add OpenCL ordered_logistic_lpmf
2 parents 24eeaad + af17849 commit 122d8b5

9 files changed

+567
-16
lines changed
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,186 @@
1+
#ifndef STAN_MATH_OPENCL_KERNELS_ORDERED_LOGISTIC_LPMF_HPP
2+
#define STAN_MATH_OPENCL_KERNELS_ORDERED_LOGISTIC_LPMF_HPP
3+
#ifdef STAN_OPENCL
4+
5+
#include <stan/math/opencl/kernel_cl.hpp>
6+
#include <stan/math/opencl/kernels/device_functions/log1m_exp.hpp>
7+
#include <stan/math/opencl/kernels/device_functions/log1p_exp.hpp>
8+
9+
namespace stan {
10+
namespace math {
11+
namespace opencl_kernels {
12+
13+
// \cond
14+
static const char* ordered_logistic_kernel_code = STRINGIFY(
15+
// \endcond
16+
/** \ingroup opencl_kernels
17+
* GPU implementation of ordinal regression.
18+
*
19+
* Must be run with at least N_instances threads and local size equal to
20+
* LOCAL_SIZE_.
21+
* @param[out] logp_global partially summed log probability (1 value per
22+
* work group)
23+
* @param[out] lambda_derivative derivative wrt lambda
24+
* @param[out] cuts_derivative partially summed derivative wrt cuts (1
25+
* column per work group)
26+
* @param[in] y_global a scalar or vector of classes.
27+
* @param[in] lambda_global vector of continuous lambda variables
28+
* @param[in] cuts cutpoints vector
29+
* @param N_instances number of cases
30+
* @param N_classes number of classes
31+
* @param is_y_vector 0 or 1 - whether y is a vector (alternatively it is a
32+
* scalar we need to broadcast)
33+
* @param is_cuts_matrix 0 or 1 - whether cuts is a matrix (alternatively it
34+
* is a vector we need to broadcast)
35+
* @param need_lambda_derivative 0 or 1 - whether lambda_derivative needs to
36+
* be computed
37+
* @param need_cuts_derivative 0 or 1 - whether cuts_derivative needs to be
38+
* computed
39+
*/
40+
__kernel void ordered_logistic(
41+
__global double* logp_global, __global double* lambda_derivative,
42+
__global double* cuts_derivative, const __global int* y_global,
43+
const __global double* lambda_global, const __global double* cuts,
44+
const int N_instances, const int N_classes, const int is_y_vector,
45+
const int is_cuts_matrix, const int need_lambda_derivative,
46+
const int need_cuts_derivative) {
47+
const int gid = get_global_id(0);
48+
const int lid = get_local_id(0);
49+
const int lsize = get_local_size(0);
50+
const int wg_id = get_group_id(0);
51+
const int ngroups = get_num_groups(0);
52+
53+
__local double local_storage[LOCAL_SIZE_];
54+
55+
double logp = 0;
56+
double d1 = 0;
57+
double d2 = 0;
58+
int y;
59+
int cuts_start = (N_classes - 1) * gid * is_cuts_matrix;
60+
// Most calculations only happen for relevant data within next if.
61+
// Exceptions are reductions between threads that need barriers.
62+
if (gid < N_instances) {
63+
double lambda = lambda_global[gid];
64+
y = y_global[gid * is_y_vector];
65+
if (y < 1 || y > N_classes || !isfinite(lambda)) {
66+
logp = NAN;
67+
} else {
68+
const double cut_y1
69+
= y == N_classes ? INFINITY : cuts[cuts_start + y - 1];
70+
const double cut_y2 = y == 1 ? -INFINITY : cuts[cuts_start + y - 2];
71+
const double cut1 = lambda - cut_y1;
72+
const double cut2 = lambda - cut_y2;
73+
74+
if (y != N_classes) {
75+
logp -= log1p_exp(cut1);
76+
}
77+
if (y != 1) {
78+
logp -= log1p_exp(-cut2);
79+
}
80+
if (y != 1 && y != N_classes) {
81+
logp += log1m_exp(cut1 - cut2);
82+
}
83+
84+
if (need_lambda_derivative || need_cuts_derivative) {
85+
double exp_cuts_diff = exp(cut_y2 - cut_y1);
86+
if (cut2 > 0) {
87+
double exp_m_cut2 = exp(-cut2);
88+
d1 = exp_m_cut2 / (1 + exp_m_cut2);
89+
} else {
90+
d1 = 1 / (1 + exp(cut2));
91+
}
92+
d1 -= exp_cuts_diff / (exp_cuts_diff - 1);
93+
d2 = 1 / (1 - exp_cuts_diff);
94+
if (cut1 > 0) {
95+
double exp_m_cut1 = exp(-cut1);
96+
d2 -= exp_m_cut1 / (1 + exp_m_cut1);
97+
} else {
98+
d2 -= 1 / (1 + exp(cut1));
99+
}
100+
101+
if (need_lambda_derivative) {
102+
lambda_derivative[gid] = d1 - d2;
103+
}
104+
}
105+
}
106+
}
107+
if (need_cuts_derivative) {
108+
if (is_cuts_matrix) {
109+
if (gid < N_instances) {
110+
for (int i = 0; i < N_classes - 1; i++) {
111+
if (y - 1 == i) {
112+
cuts_derivative[cuts_start + i] = d2;
113+
} else if (y - 2 == i) {
114+
cuts_derivative[cuts_start + i] = -d1;
115+
} else {
116+
cuts_derivative[cuts_start + i] = 0.0;
117+
}
118+
}
119+
}
120+
} else {
121+
for (int i = 0; i < N_classes - 1; i++) {
122+
local_storage[lid] = 0;
123+
if (gid < N_instances) {
124+
if (y - 1 == i) {
125+
local_storage[lid] = d2;
126+
} else if (y - 2 == i) {
127+
local_storage[lid] = -d1;
128+
}
129+
}
130+
// Sum cuts_derivative, calculated by different threads.
131+
// Since we can't sum between different work groups, we emit one
132+
// number per work group. These must be summed on CPU for final
133+
// result.
134+
barrier(CLK_LOCAL_MEM_FENCE);
135+
for (int step = lsize / REDUCTION_STEP_SIZE; step > 0;
136+
step /= REDUCTION_STEP_SIZE) {
137+
if (lid < step) {
138+
for (int i = 1; i < REDUCTION_STEP_SIZE; i++) {
139+
local_storage[lid] += local_storage[lid + step * i];
140+
}
141+
}
142+
barrier(CLK_LOCAL_MEM_FENCE);
143+
}
144+
if (lid == 0) {
145+
cuts_derivative[(N_classes - 1) * wg_id + i] = local_storage[0];
146+
}
147+
barrier(CLK_LOCAL_MEM_FENCE);
148+
}
149+
}
150+
}
151+
local_storage[lid] = logp;
152+
barrier(CLK_LOCAL_MEM_FENCE);
153+
for (int step = lsize / REDUCTION_STEP_SIZE; step > 0;
154+
step /= REDUCTION_STEP_SIZE) {
155+
if (lid < step) {
156+
for (int i = 1; i < REDUCTION_STEP_SIZE; i++) {
157+
local_storage[lid] += local_storage[lid + step * i];
158+
}
159+
}
160+
barrier(CLK_LOCAL_MEM_FENCE);
161+
}
162+
if (lid == 0) {
163+
logp_global[wg_id] = local_storage[0];
164+
}
165+
}
166+
// \cond
167+
);
168+
// \endcond
169+
170+
/** \ingroup opencl_kernels
171+
* See the docs for \link kernels/ordered_logistic_lpmf.hpp
172+
* ordered_logistic() \endlink
173+
*/
174+
const kernel_cl<out_buffer, out_buffer, out_buffer, in_buffer, in_buffer,
175+
in_buffer, int, int, int, int, int, int>
176+
ordered_logistic("ordered_logistic",
177+
{log1p_exp_device_function, log1m_exp_device_function,
178+
ordered_logistic_kernel_code},
179+
{{"REDUCTION_STEP_SIZE", 4}, {"LOCAL_SIZE_", 64}});
180+
181+
} // namespace opencl_kernels
182+
} // namespace math
183+
} // namespace stan
184+
185+
#endif
186+
#endif

stan/math/opencl/prim.hpp

+1
Original file line numberDiff line numberDiff line change
@@ -156,6 +156,7 @@
156156
#include <stan/math/opencl/prim/normal_lpdf.hpp>
157157
#include <stan/math/opencl/prim/num_elements.hpp>
158158
#include <stan/math/opencl/prim/ordered_logistic_glm_lpmf.hpp>
159+
#include <stan/math/opencl/prim/ordered_logistic_lpmf.hpp>
159160
#include <stan/math/opencl/prim/pareto_lpdf.hpp>
160161
#include <stan/math/opencl/prim/pareto_type_2_lpdf.hpp>
161162
#include <stan/math/opencl/prim/poisson_log_glm_lpmf.hpp>
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,163 @@
1+
#ifndef STAN_MATH_OPENCL_PRIM_ORDERED_LOGISTIC_LPMF_HPP
2+
#define STAN_MATH_OPENCL_PRIM_ORDERED_LOGISTIC_LPMF_HPP
3+
#ifdef STAN_OPENCL
4+
5+
#include <stan/math/opencl/kernel_generator.hpp>
6+
#include <stan/math/opencl/kernels/add.hpp>
7+
#include <stan/math/opencl/kernels/ordered_logistic_lpmf.hpp>
8+
#include <stan/math/prim/meta.hpp>
9+
#include <stan/math/prim/err.hpp>
10+
#include <stan/math/prim/fun/constants.hpp>
11+
#include <stan/math/prim/fun/elt_divide.hpp>
12+
#include <stan/math/prim/fun/elt_multiply.hpp>
13+
#include <stan/math/prim/functor/operands_and_partials.hpp>
14+
#include <stan/math/prim/err/constraint_tolerance.hpp>
15+
16+
namespace stan {
17+
namespace math {
18+
19+
/** \ingroup opencl
20+
* Returns the (natural) log probability of the specified array
21+
* of integers given the vector of continuous locations and
22+
* specified cutpoints in an ordered logistic model.
23+
*
24+
* <p>Typically the continuous lambda
25+
* will be the dot product of a vector of regression coefficients
26+
* and a vector of predictors for the outcome
27+
*
28+
\f[
29+
\frac{\partial }{\partial \lambda} =
30+
\begin{cases}\\
31+
-\mathrm{logit}^{-1}(\lambda - c_1) & \mbox{if } k = 1,\\
32+
-(((1-e^{c_{k-1}-c_{k-2}})^{-1} - \mathrm{logit}^{-1}(c_{k-2}-\lambda)) +
33+
((1-e^{c_{k-2}-c_{k-1}})^{-1} - \mathrm{logit}^{-1}(c_{k-1}-\lambda)))
34+
& \mathrm{if } 1 < k < K, \mathrm{and}\\
35+
\mathrm{logit}^{-1}(c_{K-2}-\lambda) & \mathrm{if } k = K.
36+
\end{cases}
37+
\f]
38+
39+
\f[
40+
\frac{\partial }{\partial \lambda} =
41+
\begin{cases}
42+
-\mathrm{logit}^{-1}(\lambda - c_1) & \text{if } k = 1,\\
43+
-(((1-e^{c_{k-1}-c_{k-2}})^{-1} - \mathrm{logit}^{-1}(c_{k-2}-\lambda)) +
44+
((1-e^{c_{k-2}-c_{k-1}})^{-1} - \mathrm{logit}^{-1}(c_{k-1}-\lambda)))
45+
& \text{if } 1 < k < K, \text{ and}\\
46+
\mathrm{logit}^{-1}(c_{K-2}-\lambda) & \text{if } k = K.
47+
\end{cases}
48+
\f]
49+
*
50+
* @tparam propto True if calculating up to a proportion.
51+
* @tparam T_y Y variable type (integer or array of integers).
52+
* @tparam T_loc lambda type.
53+
* @tparam T_cut Cut-point type.
54+
* @param y Array of integers
55+
* @param lambda Vector of continuous lambda variables.
56+
* @param cuts Positive increasing vector of cutpoints.
57+
* @return Log probability of outcome given lambda and
58+
* cutpoints.
59+
* @throw std::domain_error If the outcome is not between 1 and
60+
* the number of cutpoints plus 2; if the cutpoint vector is
61+
* empty; if the cutpoint vector contains a non-positive,
62+
* non-finite value; or if the cutpoint vector is not sorted in
63+
* ascending order.
64+
* @throw std::invalid_argument If y and lambda are different
65+
* lengths.
66+
*/
67+
template <bool propto, typename T_y_cl, typename T_loc_cl, typename T_cuts_cl,
68+
require_all_prim_or_rev_kernel_expression_t<T_y_cl, T_loc_cl,
69+
T_cuts_cl>* = nullptr>
70+
inline return_type_t<T_y_cl, T_loc_cl, T_cuts_cl> ordered_logistic_lpmf(
71+
const T_y_cl& y, const T_loc_cl& lambda, const T_cuts_cl& cuts) {
72+
constexpr bool is_y_vector = !is_stan_scalar<T_y_cl>::value;
73+
static const char* function = "ordered_logistic_lpmf(OpenCL)";
74+
75+
if (size(y) != 1) {
76+
check_size_match(function, "Size of ", "y", size(y), "Size of", "lambda",
77+
size(lambda));
78+
}
79+
80+
int N_instances = max_size(y, lambda);
81+
int N_classes = cuts.rows() + 1;
82+
int N_cut_sets = cuts.cols();
83+
84+
if (N_cut_sets > 1) {
85+
check_size_match(function, "Length of lambda variables ", N_instances,
86+
"Number of cutpoint vectors ", N_cut_sets);
87+
}
88+
if (N_instances == 0 || N_classes == 1) {
89+
return 0.0;
90+
}
91+
const auto& cuts_val = eval(value_of(cuts));
92+
if (N_classes >= 2) {
93+
auto cuts_head
94+
= block_zero_based(cuts_val, 0, 0, cuts.rows() - 1, N_cut_sets);
95+
auto cuts_tail
96+
= block_zero_based(cuts_val, 1, 0, cuts.rows() - 1, N_cut_sets);
97+
check_cl(function, "Cuts", cuts_head, "ordered and finite")
98+
= cuts_head < cuts_tail && isfinite(cuts_head) && isfinite(cuts_tail);
99+
} else if (N_classes == 1) {
100+
check_cl(function, "Cuts", cuts_val, "finite") = isfinite(cuts_val);
101+
}
102+
103+
if (!include_summand<propto, T_loc_cl, T_cuts_cl>::value) {
104+
return 0.0;
105+
}
106+
107+
const auto& y_val = eval(value_of(y));
108+
const auto& lambda_val = eval(value_of(lambda));
109+
110+
const auto& y_val_cl = to_matrix_cl(y_val);
111+
112+
const int local_size
113+
= opencl_kernels::ordered_logistic.get_option("LOCAL_SIZE_");
114+
const int wgs = (N_instances + local_size - 1) / local_size;
115+
116+
bool need_lambda_derivative = !is_constant_all<T_loc_cl>::value;
117+
bool need_cuts_derivative = !is_constant_all<T_cuts_cl>::value;
118+
bool need_broadcasting = N_cut_sets == 1 && N_instances != 1;
119+
matrix_cl<double> logp_cl(wgs, 1);
120+
matrix_cl<double> lambda_derivative_cl(N_instances,
121+
need_lambda_derivative ? 1 : 0);
122+
matrix_cl<double> cuts_derivative_cl(
123+
N_classes - 1,
124+
need_cuts_derivative ? (need_broadcasting ? wgs : N_cut_sets) : 0);
125+
126+
try {
127+
opencl_kernels::ordered_logistic(
128+
cl::NDRange(local_size * wgs), cl::NDRange(local_size), logp_cl,
129+
lambda_derivative_cl, cuts_derivative_cl, y_val_cl, lambda_val,
130+
cuts_val, N_instances, N_classes, is_y_vector, !need_broadcasting,
131+
need_lambda_derivative, need_cuts_derivative);
132+
} catch (const cl::Error& e) {
133+
check_opencl_error(function, e);
134+
}
135+
136+
double logp = sum(from_matrix_cl(logp_cl));
137+
138+
if (!std::isfinite(logp)) {
139+
results(check_cl(function, "Vector of dependent variables", y_val,
140+
"between 0 and number of classes"),
141+
check_cl(function, "lambda vector", lambda_val, "finite"))
142+
= expressions(y_val >= 1 && y_val <= static_cast<int>(N_classes),
143+
isfinite(lambda_val));
144+
}
145+
operands_and_partials<T_loc_cl, T_cuts_cl> ops_partials(lambda, cuts);
146+
147+
if (!is_constant_all<T_loc_cl>::value) {
148+
ops_partials.edge1_.partials_ = lambda_derivative_cl;
149+
}
150+
if (!is_constant_all<T_cuts_cl>::value) {
151+
if (need_broadcasting) {
152+
ops_partials.edge2_.partials_ = rowwise_sum(cuts_derivative_cl);
153+
} else {
154+
ops_partials.edge2_.partials_ = std::move(cuts_derivative_cl);
155+
}
156+
}
157+
return ops_partials.build(logp);
158+
}
159+
160+
} // namespace math
161+
} // namespace stan
162+
#endif
163+
#endif

stan/math/opencl/rev/operands_and_partials.hpp

+1-1
Original file line numberDiff line numberDiff line change
@@ -20,7 +20,7 @@ class ops_partials_edge<double, var_value<Op>, require_matrix_cl_t<Op>> {
2020
partials_t partials_; // For univariate use-cases
2121
broadcast_array<partials_t> partials_vec_; // For multivariate
2222
explicit ops_partials_edge(const var_value<Op>& ops)
23-
: partials_(constant(0, ops.vi_->rows(), ops.vi_->cols())),
23+
: partials_(constant(0.0, ops.vi_->rows(), ops.vi_->cols())),
2424
partials_vec_(partials_),
2525
operands_(ops) {}
2626

0 commit comments

Comments
 (0)