Skip to content

Commit d47c331

Browse files
authored
Merge pull request #2185 from bstatcomp/opencl_gamma_lpdf
Opencl gamma_lpdf and gumbel_lpdf
2 parents aafad99 + 7aca964 commit d47c331

File tree

10 files changed

+591
-6
lines changed

10 files changed

+591
-6
lines changed

stan/math/opencl/kernel_generator/check_cl.hpp

+2-1
Original file line numberDiff line numberDiff line change
@@ -58,6 +58,7 @@ class check_cl_ : public operation_cl_lhs<check_cl_<T>, bool> {
5858
err_variable_(err_variable),
5959
must_be_(must_be) {
6060
buffer_.zeros();
61+
buffer_.view(matrix_cl_view::Entire);
6162
}
6263

6364
// this operation can not be used on the right hand side of assignment
@@ -86,7 +87,7 @@ class check_cl_ : public operation_cl_lhs<check_cl_<T>, bool> {
8687
+ type_str<value_type_t<T>>() + "* " + var_name_ + "_value, ";
8788
res.body += "bool " + var_name_;
8889
res.body_suffix += "if(!" + var_name_ +
89-
" && atomic_xchg(" + var_name_ + "_buffer, 1) == 1){\n"
90+
" && atomic_xchg(" + var_name_ + "_buffer, 1) == 0){\n"
9091
+ var_name_ + "_buffer[1] = " + row_index_name + ";\n"
9192
+ var_name_ + "_buffer[2] = " + col_index_name + ";\n"
9293
+ var_name_ + "_value[0] = " + arg_.var_name_ + ";\n"

stan/math/opencl/opencl.hpp

+2
Original file line numberDiff line numberDiff line change
@@ -119,7 +119,9 @@
119119
#include <stan/math/opencl/prim/exp_mod_normal_lpdf.hpp>
120120
#include <stan/math/opencl/prim/exponential_lpdf.hpp>
121121
#include <stan/math/opencl/prim/frechet_lpdf.hpp>
122+
#include <stan/math/opencl/prim/gamma_lpdf.hpp>
122123
#include <stan/math/opencl/prim/gp_exp_quad_cov.hpp>
124+
#include <stan/math/opencl/prim/gumbel_lpdf.hpp>
123125
#include <stan/math/opencl/prim/inv.hpp>
124126
#include <stan/math/opencl/prim/inv_chi_square_lpdf.hpp>
125127
#include <stan/math/opencl/prim/inv_cloglog.hpp>

stan/math/opencl/prim/gamma_lpdf.hpp

+137
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,137 @@
1+
#ifndef STAN_MATH_OPENCL_PRIM_GAMMA_LPDF_HPP
2+
#define STAN_MATH_OPENCL_PRIM_GAMMA_LPDF_HPP
3+
#ifdef STAN_OPENCL
4+
5+
#include <stan/math/opencl/kernel_generator.hpp>
6+
#include <stan/math/opencl/prim/size.hpp>
7+
#include <stan/math/prim/meta.hpp>
8+
#include <stan/math/prim/err.hpp>
9+
#include <stan/math/prim/fun/constants.hpp>
10+
#include <stan/math/prim/fun/digamma.hpp>
11+
#include <stan/math/prim/fun/lgamma.hpp>
12+
#include <stan/math/prim/fun/max_size.hpp>
13+
#include <stan/math/prim/functor/operands_and_partials.hpp>
14+
15+
namespace stan {
16+
namespace math {
17+
18+
/** \ingroup opencl
19+
* The log of a gamma density for y with the specified
20+
* shape and inverse scale parameters.
21+
* Shape and inverse scale parameters must be greater than 0.
22+
* y must be greater than or equal to 0.
23+
*
24+
\f{eqnarray*}{
25+
y &\sim& \mbox{\sf{Gamma}}(\alpha, \beta) \\
26+
\log (p (y \, |\, \alpha, \beta) ) &=& \log \left(
27+
\frac{\beta^\alpha}{\Gamma(\alpha)} y^{\alpha - 1} \exp^{- \beta y} \right) \\
28+
&=& \alpha \log(\beta) - \log(\Gamma(\alpha)) + (\alpha - 1) \log(y) - \beta
29+
y\\ & & \mathrm{where} \; y > 0 \f}
30+
*
31+
* @tparam T_y_cl type of scalar
32+
* @tparam T_shape_cl type of shape
33+
* @tparam T_inv_scale_cl type of inverse scale
34+
* @param y A scalar variable.
35+
* @param alpha Shape parameter.
36+
* @param beta Inverse scale parameter.
37+
* @throw std::domain_error if alpha is not greater than 0.
38+
* @throw std::domain_error if beta is not greater than 0.
39+
* @throw std::domain_error if y is not greater than or equal to 0.
40+
*/
41+
template <bool propto, typename T_y_cl, typename T_shape_cl,
42+
typename T_inv_scale_cl,
43+
require_all_prim_or_rev_kernel_expression_t<
44+
T_y_cl, T_shape_cl, T_inv_scale_cl>* = nullptr,
45+
require_any_not_stan_scalar_t<T_y_cl, T_shape_cl,
46+
T_inv_scale_cl>* = nullptr>
47+
return_type_t<T_y_cl, T_shape_cl, T_inv_scale_cl> gamma_lpdf(
48+
const T_y_cl& y, const T_shape_cl& alpha, const T_inv_scale_cl& beta) {
49+
using std::isfinite;
50+
using std::isnan;
51+
static const char* function = "gamma_lpdf(OpenCL)";
52+
using T_partials_return
53+
= partials_return_t<T_y_cl, T_shape_cl, T_inv_scale_cl>;
54+
55+
check_consistent_sizes(function, "Random variable", y, "Shape parameter",
56+
alpha, "Inverse scale parameter", beta);
57+
const size_t N = max_size(y, alpha, beta);
58+
if (N == 0) {
59+
return 0.0;
60+
}
61+
if (!include_summand<propto, T_y_cl, T_shape_cl, T_inv_scale_cl>::value) {
62+
return 0.0;
63+
}
64+
65+
const auto& y_val = value_of(y);
66+
const auto& alpha_val = value_of(alpha);
67+
const auto& beta_val = value_of(beta);
68+
69+
auto check_y_not_nan
70+
= check_cl(function, "Random variable", y_val, "not NaN");
71+
auto y_not_nan_expr = !isnan(y_val);
72+
auto check_alpha_pos_finite
73+
= check_cl(function, "Shape parameter", alpha_val, "positive finite");
74+
auto alpha_pos_finite_expr = alpha_val > 0 && isfinite(alpha_val);
75+
auto check_beta_pos_finite = check_cl(function, "Inverse scale parameter",
76+
beta_val, "positive finite");
77+
auto beta_pos_finite_expr = beta_val > 0 && isfinite(beta_val);
78+
79+
auto any_y_negative_expr = colwise_max(constant(0, N, 1) + (y_val < 0));
80+
auto log_y_expr = log(y_val);
81+
auto log_beta_expr = log(beta_val);
82+
auto logp1_expr = static_select<include_summand<propto, T_shape_cl>::value>(
83+
-lgamma(alpha_val), constant(0.0, N, 1));
84+
auto logp2_expr = static_select<
85+
include_summand<propto, T_shape_cl, T_inv_scale_cl>::value>(
86+
logp1_expr + elt_multiply(alpha_val, log_beta_expr), logp1_expr);
87+
auto logp3_expr
88+
= static_select<include_summand<propto, T_y_cl, T_shape_cl>::value>(
89+
logp2_expr + elt_multiply(alpha_val - 1.0, log_y_expr), logp2_expr);
90+
auto logp_expr = colwise_sum(
91+
static_select<include_summand<propto, T_y_cl, T_inv_scale_cl>::value>(
92+
logp3_expr - elt_multiply(beta_val, y_val), logp3_expr));
93+
94+
auto y_deriv_expr = elt_divide(alpha_val - 1, y_val) - beta_val;
95+
auto alpha_deriv_expr = log_beta_expr + log_y_expr - digamma(alpha_val);
96+
auto beta_deriv_expr = elt_divide(alpha_val, beta_val) - y_val;
97+
98+
matrix_cl<int> any_y_negative_cl;
99+
matrix_cl<double> logp_cl;
100+
matrix_cl<double> y_deriv_cl;
101+
matrix_cl<double> alpha_deriv_cl;
102+
matrix_cl<double> beta_deriv_cl;
103+
104+
results(check_y_not_nan, check_alpha_pos_finite, check_beta_pos_finite,
105+
any_y_negative_cl, logp_cl, y_deriv_cl, alpha_deriv_cl, beta_deriv_cl)
106+
= expressions(
107+
y_not_nan_expr, alpha_pos_finite_expr, beta_pos_finite_expr,
108+
any_y_negative_expr, logp_expr,
109+
calc_if<!is_constant<T_y_cl>::value>(y_deriv_expr),
110+
calc_if<!is_constant<T_shape_cl>::value>(alpha_deriv_expr),
111+
calc_if<!is_constant<T_inv_scale_cl>::value>(beta_deriv_expr));
112+
113+
if (from_matrix_cl(any_y_negative_cl).any()) {
114+
return LOG_ZERO;
115+
}
116+
117+
T_partials_return logp = sum(from_matrix_cl(logp_cl));
118+
119+
operands_and_partials<T_y_cl, T_shape_cl, T_inv_scale_cl> ops_partials(
120+
y, alpha, beta);
121+
if (!is_constant<T_y_cl>::value) {
122+
ops_partials.edge1_.partials_ = std::move(y_deriv_cl);
123+
}
124+
if (!is_constant<T_shape_cl>::value) {
125+
ops_partials.edge2_.partials_ = std::move(alpha_deriv_cl);
126+
}
127+
if (!is_constant<T_inv_scale_cl>::value) {
128+
ops_partials.edge3_.partials_ = std::move(beta_deriv_cl);
129+
}
130+
131+
return ops_partials.build(logp);
132+
}
133+
134+
} // namespace math
135+
} // namespace stan
136+
#endif
137+
#endif

stan/math/opencl/prim/gumbel_lpdf.hpp

+114
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,114 @@
1+
#ifndef STAN_MATH_OPENCL_PRIM_GUMBEL_LPDF_HPP
2+
#define STAN_MATH_OPENCL_PRIM_GUMBEL_LPDF_HPP
3+
#ifdef STAN_OPENCL
4+
5+
#include <stan/math/opencl/kernel_generator.hpp>
6+
#include <stan/math/opencl/prim/size.hpp>
7+
#include <stan/math/prim/meta.hpp>
8+
#include <stan/math/prim/err.hpp>
9+
#include <stan/math/prim/fun/elt_divide.hpp>
10+
#include <stan/math/prim/fun/max_size.hpp>
11+
#include <stan/math/prim/functor/operands_and_partials.hpp>
12+
13+
namespace stan {
14+
namespace math {
15+
16+
/** \ingroup opencl
17+
* Returns the Gumbel log probability density for the given
18+
* location and scale. Given containers of matching sizes, returns the
19+
* log sum of densities.
20+
*
21+
* @tparam T_y_cl type of real parameter
22+
* @tparam T_loc_cl type of location parameter
23+
* @tparam T_scale_cl type of scale parameter
24+
* @param y real parameter
25+
* @param mu location parameter
26+
* @param beta scale parameter
27+
* @return log probability density or log sum of probability densities
28+
* @throw std::domain_error if y is nan, mu is infinite, or beta is nonpositive
29+
* @throw std::invalid_argument if container sizes mismatch
30+
*/
31+
template <
32+
bool propto, typename T_y_cl, typename T_loc_cl, typename T_scale_cl,
33+
require_all_prim_or_rev_kernel_expression_t<T_y_cl, T_loc_cl,
34+
T_scale_cl>* = nullptr,
35+
require_any_not_stan_scalar_t<T_y_cl, T_loc_cl, T_scale_cl>* = nullptr>
36+
return_type_t<T_y_cl, T_loc_cl, T_scale_cl> gumbel_lpdf(
37+
const T_y_cl& y, const T_loc_cl& mu, const T_scale_cl& beta) {
38+
using std::isfinite;
39+
using std::isnan;
40+
static const char* function = "gumbel_lpdf(OpenCL)";
41+
using T_partials_return = partials_return_t<T_y_cl, T_loc_cl, T_scale_cl>;
42+
43+
check_consistent_sizes(function, "Random variable", y, "Location parameter",
44+
mu, "Scale parameter", beta);
45+
const size_t N = max_size(y, mu, beta);
46+
if (N == 0) {
47+
return 0.0;
48+
}
49+
if (!include_summand<propto, T_y_cl, T_loc_cl, T_scale_cl>::value) {
50+
return 0.0;
51+
}
52+
53+
const auto& y_val = value_of(y);
54+
const auto& mu_val = value_of(mu);
55+
const auto& beta_val = value_of(beta);
56+
57+
auto check_y_not_nan
58+
= check_cl(function, "Random variable", y_val, "not NaN");
59+
auto y_not_nan_expr = !isnan(y_val);
60+
auto check_mu_finite
61+
= check_cl(function, "Location parameter", mu_val, "finite");
62+
auto mu_finite_expr = isfinite(mu_val);
63+
auto check_beta_positive
64+
= check_cl(function, "Scale parameter", beta_val, "positive ");
65+
auto beta_positive_expr = beta_val > 0;
66+
67+
auto inv_beta_expr = elt_divide(1.0, beta_val);
68+
auto y_minus_mu_over_beta_expr = elt_multiply(y_val - mu_val, inv_beta_expr);
69+
auto exp_y_m_mu_over_beta_expr = exp(-y_minus_mu_over_beta_expr);
70+
71+
auto logp1_expr = -y_minus_mu_over_beta_expr - exp_y_m_mu_over_beta_expr;
72+
auto logp_expr
73+
= colwise_sum(static_select<include_summand<propto, T_scale_cl>::value>(
74+
logp1_expr - log(beta_val), logp1_expr));
75+
76+
auto scaled_diff_expr
77+
= elt_multiply(inv_beta_expr, exp_y_m_mu_over_beta_expr) - inv_beta_expr;
78+
auto beta_deriv_expr
79+
= elt_multiply(-y_minus_mu_over_beta_expr, scaled_diff_expr)
80+
- inv_beta_expr;
81+
82+
matrix_cl<double> logp_cl;
83+
matrix_cl<double> y_deriv_cl;
84+
matrix_cl<double> mu_deriv_cl;
85+
matrix_cl<double> beta_deriv_cl;
86+
87+
results(check_y_not_nan, check_mu_finite, check_beta_positive, logp_cl,
88+
y_deriv_cl, mu_deriv_cl, beta_deriv_cl)
89+
= expressions(y_not_nan_expr, mu_finite_expr, beta_positive_expr,
90+
logp_expr,
91+
calc_if<!is_constant<T_y_cl>::value>(scaled_diff_expr),
92+
calc_if<!is_constant<T_loc_cl>::value>(-scaled_diff_expr),
93+
calc_if<!is_constant<T_scale_cl>::value>(beta_deriv_expr));
94+
95+
T_partials_return logp = sum(from_matrix_cl(logp_cl));
96+
97+
operands_and_partials<T_y_cl, T_loc_cl, T_scale_cl> ops_partials(y, mu, beta);
98+
if (!is_constant<T_y_cl>::value) {
99+
ops_partials.edge1_.partials_ = std::move(y_deriv_cl);
100+
}
101+
if (!is_constant<T_loc_cl>::value) {
102+
ops_partials.edge2_.partials_ = std::move(mu_deriv_cl);
103+
}
104+
if (!is_constant<T_scale_cl>::value) {
105+
ops_partials.edge3_.partials_ = std::move(beta_deriv_cl);
106+
}
107+
108+
return ops_partials.build(logp);
109+
}
110+
111+
} // namespace math
112+
} // namespace stan
113+
#endif
114+
#endif

stan/math/opencl/prim/size.hpp

+1-1
Original file line numberDiff line numberDiff line change
@@ -9,7 +9,7 @@ namespace math {
99

1010
/**
1111
* Returns the size (number of the elements) of a `matrix_cl`.
12-
* @param m `matric_cl` to determine size of
12+
* @param m `matrix_cl` to determine size of
1313
* @return number of elements in a
1414
*/
1515
template <typename T,

stan/math/prim/prob/gamma_lpdf.hpp

+3-1
Original file line numberDiff line numberDiff line change
@@ -42,7 +42,9 @@ namespace math {
4242
* @throw std::domain_error if beta is not greater than 0.
4343
* @throw std::domain_error if y is not greater than or equal to 0.
4444
*/
45-
template <bool propto, typename T_y, typename T_shape, typename T_inv_scale>
45+
template <bool propto, typename T_y, typename T_shape, typename T_inv_scale,
46+
require_all_not_nonscalar_prim_or_rev_kernel_expression_t<
47+
T_y, T_shape, T_inv_scale>* = nullptr>
4648
return_type_t<T_y, T_shape, T_inv_scale> gamma_lpdf(const T_y& y,
4749
const T_shape& alpha,
4850
const T_inv_scale& beta) {

stan/math/prim/prob/gumbel_lpdf.hpp

+3-2
Original file line numberDiff line numberDiff line change
@@ -31,7 +31,9 @@ namespace math {
3131
* @throw std::domain_error if y is nan, mu is infinite, or beta is nonpositive
3232
* @throw std::invalid_argument if container sizes mismatch
3333
*/
34-
template <bool propto, typename T_y, typename T_loc, typename T_scale>
34+
template <bool propto, typename T_y, typename T_loc, typename T_scale,
35+
require_all_not_nonscalar_prim_or_rev_kernel_expression_t<
36+
T_y, T_loc, T_scale>* = nullptr>
3537
return_type_t<T_y, T_loc, T_scale> gumbel_lpdf(const T_y& y, const T_loc& mu,
3638
const T_scale& beta) {
3739
using T_partials_return = partials_return_t<T_y, T_loc, T_scale>;
@@ -59,7 +61,6 @@ return_type_t<T_y, T_loc, T_scale> gumbel_lpdf(const T_y& y, const T_loc& mu,
5961

6062
check_not_nan(function, "Random variable", y_val);
6163
check_finite(function, "Location parameter", mu_val);
62-
check_not_nan(function, "Scale parameter", beta_val);
6364
check_positive(function, "Scale parameter", beta_val);
6465

6566
if (size_zero(y, mu, beta)) {

test/unit/math/opencl/kernel_generator/reference_kernels/check_cl_positive.cl

+1-1
Original file line numberDiff line numberDiff line change
@@ -5,7 +5,7 @@ double var2 = 0; if (!((!contains_nonzero(var2_view, LOWER) && j < i) || (!conta
55
bool var1 = var2 > var3;
66
double var5 = 0; if (!((!contains_nonzero(var5_view, LOWER) && j < i) || (!contains_nonzero(var5_view, UPPER) && j > i))) {var5 = var5_global[i + var5_rows * j];}
77
bool var4 = var1;
8-
if(!var4 && atomic_xchg(var4_buffer, 1) == 1){
8+
if(!var4 && atomic_xchg(var4_buffer, 1) == 0){
99
var4_buffer[1] = i;
1010
var4_buffer[2] = j;
1111
var4_value[0] = var5;

0 commit comments

Comments
 (0)