Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Reorganize threads in generated OpenCL kernels #2392

Merged
merged 11 commits into from
Mar 2, 2021
68 changes: 58 additions & 10 deletions stan/math/opencl/kernel_generator/colwise_reduction.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -10,6 +10,7 @@
#include <stan/math/opencl/kernel_generator/operation_cl.hpp>
#include <stan/math/opencl/kernel_generator/as_operation_cl.hpp>
#include <stan/math/opencl/kernel_generator/rowwise_reduction.hpp>
#include <stan/math/opencl/kernel_generator/calc_if.hpp>
#include <map>
#include <string>
#include <type_traits>
Expand All @@ -21,6 +22,27 @@ namespace math {
* @{
*/

namespace internal {
class colwise_reduction_base {};

/**
* Determine number of work groups in rows direction that will be run fro
* colwise reduction of given size.
* @param n_rows number of rows of expression to resuce
* @param n_cols number of columns of expression to resuce
* @return number of work groups in rows direction
*/
inline int colwise_reduction_wgs_rows(int n_rows, int n_cols) {
int local = opencl_context.base_opts().at("LOCAL_SIZE_");
int preferred_work_groups
= opencl_context.device()[0].getInfo<CL_DEVICE_MAX_COMPUTE_UNITS>() * 16;
// round up n_rows/local/n_cols
return (std::min(preferred_work_groups, (n_rows + local - 1) / local) + n_cols
- 1)
/ n_cols;
}
} // namespace internal

/**
* Represents a column wise reduction in kernel generator expressions. So as to
* be efficient column wise reductions are only done partially. That means
Expand All @@ -36,7 +58,8 @@ namespace math {
*/
template <typename Derived, typename T, typename Operation>
class colwise_reduction
: public operation_cl<Derived, typename std::remove_reference_t<T>::Scalar,
: public internal::colwise_reduction_base,
public operation_cl<Derived, typename std::remove_reference_t<T>::Scalar,
T> {
public:
using Scalar = typename std::remove_reference_t<T>::Scalar;
Expand Down Expand Up @@ -84,7 +107,7 @@ class colwise_reduction
parts.args += out_parts.args;
parts.reduction += "if (lid_i == 0) {\n"
+ result.var_name_
+ "_global[j * blocks_rows + wg_id_i] = "
+ "_global[j * n_groups_i + wg_id_i] = "
+ derived().var_name_ + "_local[0];\n"
"}\n";
return parts;
Expand All @@ -105,10 +128,11 @@ class colwise_reduction
const std::string& var_name_arg) const {
kernel_parts res;
res.declarations = "__local " + type_str<Scalar>() + " " + var_name_
+ "_local[LOCAL_SIZE_];\n";
res.initialization
= type_str<Scalar>() + " " + var_name_ + " = " + init_ + ";\n";
res.body = var_name_ + " = " + var_name_arg + ";\n";
+ "_local[LOCAL_SIZE_];\n" + type_str<Scalar>() + " "
+ var_name_ + ";\n";
res.initialization = var_name_ + " = " + init_ + ";\n";
res.body = var_name_ + " = " + Operation::generate(var_name_, var_name_arg)
+ ";\n";
res.reduction =
var_name_ + "_local[lid_i] = " + var_name_ + ";\n"
"barrier(CLK_LOCAL_MEM_FENCE);\n"
Expand All @@ -132,10 +156,15 @@ class colwise_reduction
* @return number of rows
*/
inline int rows() const {
int local_rows = opencl_context.base_opts().at("LOCAL_SIZE_");
int wgs_rows
= (this->template get_arg<0>().rows() + local_rows - 1) / local_rows;
return wgs_rows;
int arg_rows = this->template get_arg<0>().rows();
int arg_cols = this->template get_arg<0>().cols();
if (arg_cols == 0) {
return 1;
}
if (arg_cols == -1) {
return -1;
}
return internal::colwise_reduction_wgs_rows(arg_rows, arg_cols);
}

/**
Expand Down Expand Up @@ -287,6 +316,25 @@ inline auto colwise_min(T&& a) {
return colwise_min_<as_operation_cl_t<T>>(
as_operation_cl(std::forward<T>(a)));
}

namespace internal {
template <typename T>
struct is_colwise_reduction_impl
: public std::is_base_of<internal::colwise_reduction_base,
std::decay_t<T>> {};
template <typename T>
struct is_colwise_reduction_impl<calc_if_<true, T>>
: public std::is_base_of<internal::colwise_reduction_base,
std::decay_t<T>> {};
} // namespace internal

/**
* Check whether a kernel generator expression is a colwise reduction.
*/
template <typename T>
using is_colwise_reduction
= internal::is_colwise_reduction_impl<std::decay_t<T>>;

/** @}*/
} // namespace math
} // namespace stan
Expand Down
2 changes: 1 addition & 1 deletion stan/math/opencl/kernel_generator/constant.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -33,7 +33,7 @@ class constant_ : public operation_cl<constant_<T>, T> {

public:
static_assert(std::is_arithmetic<T>::value,
"class scalar_<T>: std::is_arithmetic<T> must be true!");
"class constant_<T>: std::is_arithmetic<T> must be true!");
using Scalar = T;
using base = operation_cl<constant_<T>, T>;
using base::var_name_;
Expand Down
37 changes: 21 additions & 16 deletions stan/math/opencl/kernel_generator/multi_result_kernel.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -8,6 +8,7 @@
#include <stan/math/opencl/kernel_generator/as_operation_cl.hpp>
#include <stan/math/opencl/kernel_generator/calc_if.hpp>
#include <stan/math/opencl/kernel_generator/check_cl.hpp>
#include <stan/math/opencl/kernel_generator/colwise_reduction.hpp>
#include <stan/math/opencl/kernel_generator/load.hpp>
#include <stan/math/opencl/opencl_context.hpp>
#include <algorithm>
Expand Down Expand Up @@ -71,23 +72,29 @@ struct multi_result_kernel_internal {
const auto& result = std::get<N>(assignment_pairs).first;
const char* function = "results.operator=";

int expressin_rows = expression.rows();
int expressin_cols = expression.cols();
int expression_rows = expression.rows();
int expression_cols = expression.cols();
if (is_colwise_reduction<T_current_expression>::value
&& expression_cols == -1) {
expression_cols = n_cols;
expression_rows = internal::colwise_reduction_wgs_rows(
expression.thread_rows(), expression_cols);
}
if (expression.thread_rows() != -1) {
check_size_match(function, "Rows of ", "expression",
expression.thread_rows(), "rows of ",
"first expression", n_rows);
} else {
expressin_rows = n_rows;
expression_rows = n_rows;
}
if (expression.thread_cols() != -1) {
check_size_match(function, "Columns of ", "expression",
expression.thread_cols(), "columns of ",
"first expression", n_cols);
} else {
expressin_cols = n_cols;
expression_cols = n_cols;
}
result.check_assign_dimensions(expressin_rows, expressin_cols);
result.check_assign_dimensions(expression_rows, expression_cols);
int bottom_written = 1 - expression.rows();
int top_written = expression.cols() - 1;
std::pair<int, int> extreme_diagonals = expression.extreme_diagonals();
Expand Down Expand Up @@ -373,21 +380,17 @@ class results_cl {
"kernel void calculate(" + parts.args +
"const int rows, const int cols){\n"
"const int gid_i = get_global_id(0);\n"
"const int gid_j = get_global_id(1);\n"
"const int lid_i = get_local_id(0);\n"
"const int lsize_i = get_local_size(0);\n"
"const int gsize_i = get_global_size(0);\n"
"const int gsize_j = get_global_size(1);\n"
"const int wg_id_i = get_group_id(0);\n"
"const int wg_id_j = get_group_id(1);\n"
"const int n_groups_i = get_num_groups(0);\n"
"const int blocks_rows = (rows + lsize_i - 1) / lsize_i;\n"
"const int blocks_cols = (cols + lsize_i - 1) / lsize_i;\n"
"const int i0 = lsize_i * wg_id_i;\n"
"const int i = i0 + lid_i;\n"
"const int j0 = lsize_i * wg_id_j;\n"
+ parts.declarations +
"for(int lid_j = 0; lid_j < min(cols - j0, lsize_i); lid_j++){\n"
"const int j = j0 + lid_j;\n"
"for(int j = gid_j; j < cols; j+=gsize_j){\n"
+ parts.initialization +
"if(i < rows){\n"
"for(int i = gid_i; i < rows; i+=gsize_i){\n"
+ parts.body
+ parts.body_suffix +
"}\n"
Expand Down Expand Up @@ -481,9 +484,11 @@ class results_cl {
if (require_specific_local_size) {
kernel.setArg(arg_num++, n_rows);
kernel.setArg(arg_num++, n_cols);

int local = opencl_context.base_opts().at("LOCAL_SIZE_");
int wgs_rows = (n_rows + local - 1) / local;
int wgs_cols = (n_cols + local - 1) / local;

int wgs_rows = internal::colwise_reduction_wgs_rows(n_rows, n_cols);
int wgs_cols = (n_cols + wgs_rows - 1) / wgs_rows;

opencl_context.queue().enqueueNDRangeKernel(
kernel, cl::NullRange, cl::NDRange(local * wgs_rows, wgs_cols),
Expand Down
26 changes: 14 additions & 12 deletions stan/math/opencl/prim/dirichlet_lpdf.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -122,28 +122,30 @@ inline return_type_t<T_prob_cl, T_prior_size_cl> dirichlet_lpdf(
calc_if<!is_constant<T_prob_cl>::value>(theta_deriv),
calc_if<!is_constant<T_prior_size_cl>::value>(alpha_deriv));

while (theta_csum_cl.rows() > 1) {
matrix_cl<double> theta_csum_cl2;
matrix_cl<double> theta_log_alpha_m_1_csum_cl2;
if (include_summand<propto, T_prior_size_cl>::value) {
matrix_cl<double> alpha_csum_cl2;
matrix_cl<double> lgamma_alpha_csum_cl2;
while (alpha_csum_cl.rows() > 1) {
results(alpha_csum_cl2, lgamma_alpha_csum_cl2) = expressions(
calc_if<include_summand<propto, T_prior_size_cl>::value>(
colwise_sum(alpha_csum_cl)),
calc_if<include_summand<propto, T_prior_size_cl>::value>(
colwise_sum(lgamma_alpha_csum_cl)));
alpha_csum_cl = std::move(alpha_csum_cl2);
lgamma_alpha_csum_cl = std::move(lgamma_alpha_csum_cl2);
}
}
matrix_cl<double> theta_csum_cl2;
matrix_cl<double> theta_log_alpha_m_1_csum_cl2;
while (theta_csum_cl.rows() > 1) {
results(theta_csum_cl2, theta_log_alpha_m_1_csum_cl2) = expressions(
colwise_sum(theta_csum_cl),
calc_if<include_summand<propto, T_prob_cl, T_prior_size_cl>::value>(
colwise_sum(theta_log_alpha_m_1_csum_cl)));
results(alpha_csum_cl2, lgamma_alpha_csum_cl2) = expressions(
calc_if<include_summand<propto, T_prior_size_cl>::value>(
colwise_sum(alpha_csum_cl)),
calc_if<include_summand<propto, T_prior_size_cl>::value>(
colwise_sum(lgamma_alpha_csum_cl)));
theta_csum_cl = std::move(theta_csum_cl2);
if (include_summand<propto, T_prob_cl, T_prior_size_cl>::value) {
theta_log_alpha_m_1_csum_cl = std::move(theta_log_alpha_m_1_csum_cl2);
}
if (include_summand<propto, T_prior_size_cl>::value) {
alpha_csum_cl = std::move(alpha_csum_cl2);
lgamma_alpha_csum_cl = std::move(lgamma_alpha_csum_cl2);
}
}
} else {
auto alpha_csum = colwise_sum(alpha_val);
Expand Down
55 changes: 42 additions & 13 deletions test/unit/math/opencl/kernel_generator/colwise_reduction_test.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -13,6 +13,7 @@
using Eigen::Dynamic;
using Eigen::MatrixXd;
using Eigen::MatrixXi;
using Eigen::VectorXd;
using stan::math::matrix_cl;

TEST(KernelGenerator, colwise_sum_test) {
Expand Down Expand Up @@ -123,19 +124,24 @@ TEST(KernelGenerator, nested_rowwise_colwise_sum) {
}

TEST(KernelGenerator, colwise_sum_test_large) {
int N = 153;
int M = 201;
MatrixXd m = MatrixXd::Random(N, M);

matrix_cl<double> m_cl(m);
matrix_cl<double> res_cl = stan::math::colwise_sum(m_cl);
MatrixXd raw_res = stan::math::from_matrix_cl(res_cl);
EXPECT_GT(m.rows(), raw_res.rows());
MatrixXd res = raw_res.colwise().sum();
MatrixXd correct = m.colwise().sum();
EXPECT_EQ(correct.rows(), res.rows());
EXPECT_EQ(correct.cols(), res.cols());
EXPECT_MATRIX_NEAR(correct, res, 1e-9);
for (int M : {1, 2, 5, 9, 63, 64, 65, 4095, 4096, 4967, 4096 * 4}) {
for (int N : {1, 2, 5, 9, 63, 64, 65, 4095, 4096, 4967, 4096 * 4}) {
if (N * M > 1e6) {
continue;
}
MatrixXd m = MatrixXd::Random(N, M);

matrix_cl<double> m_cl(m);
matrix_cl<double> res_cl = stan::math::colwise_sum(m_cl);
MatrixXd raw_res = stan::math::from_matrix_cl(res_cl);
EXPECT_GE(m.rows(), raw_res.rows());
MatrixXd res = raw_res.colwise().sum();
MatrixXd correct = m.colwise().sum();
EXPECT_EQ(correct.rows(), res.rows());
EXPECT_EQ(correct.cols(), res.cols());
EXPECT_MATRIX_NEAR(correct, res, 1e-9);
}
}
}

TEST(KernelGenerator, colwise_sum_and_id_test) {
Expand All @@ -158,4 +164,27 @@ TEST(KernelGenerator, colwise_sum_and_id_test) {
EXPECT_MATRIX_NEAR(correct2, res2, 1e-9);
}

TEST(KernelGenerator, colwise_reduction_of_rowwise_broadcast_test) {
MatrixXd m(3, 2);
m << 1.1, 1.2, 1.3, 1.4, 1.5, 1.6;
VectorXd v(3);
v << 4, 5, 6;

matrix_cl<double> m_cl(m);
matrix_cl<double> v_cl(v);

matrix_cl<double> res1_cl;
matrix_cl<double> res2_cl;

stan::math::results(res1_cl, res2_cl) = stan::math::expressions(
m_cl, stan::math::colwise_sum(stan::math::rowwise_broadcast(v_cl)));
MatrixXd raw_res = stan::math::from_matrix_cl(res2_cl);
EXPECT_GE(v.rows(), raw_res.rows());
MatrixXd res = raw_res.colwise().sum();
MatrixXd correct = v.replicate(1, 2).colwise().sum();
EXPECT_EQ(correct.rows(), res.rows());
EXPECT_EQ(correct.cols(), res.cols());
EXPECT_MATRIX_NEAR(correct, res, 1e-9);
}

#endif
Original file line number Diff line number Diff line change
@@ -1,22 +1,19 @@
kernel void calculate(__global double* var2_global, int var2_rows, int var2_view, __global double* var3_global, int var3_rows, int var3_view, const int rows, const int cols){
const int gid_i = get_global_id(0);
const int gid_j = get_global_id(1);
const int lid_i = get_local_id(0);
const int lsize_i = get_local_size(0);
const int gsize_i = get_global_size(0);
const int gsize_j = get_global_size(1);
const int wg_id_i = get_group_id(0);
const int wg_id_j = get_group_id(1);
const int n_groups_i = get_num_groups(0);
const int blocks_rows = (rows + lsize_i - 1) / lsize_i;
const int blocks_cols = (cols + lsize_i - 1) / lsize_i;
const int i0 = lsize_i * wg_id_i;
const int i = i0 + lid_i;
const int j0 = lsize_i * wg_id_j;
__local double var1_local[LOCAL_SIZE_];
for(int lid_j = 0; lid_j < min(cols - j0, lsize_i); lid_j++){
const int j = j0 + lid_j;
double var1 = 0;
if(i < rows){
double var1;
for(int j = gid_j; j < cols; j+=gsize_j){
var1 = 0;
for(int i = gid_i; i < rows; i+=gsize_i){
double var2 = 0; if (!((!contains_nonzero(var2_view, LOWER) && j < i) || (!contains_nonzero(var2_view, UPPER) && j > i))) {var2 = var2_global[i + var2_rows * j];}
var1 = var2;
var1 = var1 + var2;
}
var1_local[lid_i] = var1;
barrier(CLK_LOCAL_MEM_FENCE);
Expand All @@ -29,7 +26,7 @@ for (int step = lsize_i / REDUCTION_STEP_SIZE; step > 0; step /= REDUCTION_STEP_
barrier(CLK_LOCAL_MEM_FENCE);
}
if (lid_i == 0) {
var3_global[j * blocks_rows + wg_id_i] = var1_local[0];
var3_global[j * n_groups_i + wg_id_i] = var1_local[0];
}
}
}
Loading