diff --git a/stan/math/opencl/kernel_generator/colwise_reduction.hpp b/stan/math/opencl/kernel_generator/colwise_reduction.hpp index 4e3972517f4..3d24eff50ce 100644 --- a/stan/math/opencl/kernel_generator/colwise_reduction.hpp +++ b/stan/math/opencl/kernel_generator/colwise_reduction.hpp @@ -10,6 +10,7 @@ #include #include #include +#include #include #include #include @@ -21,6 +22,27 @@ namespace math { * @{ */ +namespace internal { +class colwise_reduction_base {}; + +/** + * Determine number of work groups in rows direction that will be run for + * colwise reduction of given size. + * @param n_rows number of rows of expression to reduce + * @param n_cols number of columns of expression to reduce + * @return number of work groups in rows direction + */ +inline int colwise_reduction_wgs_rows(int n_rows, int n_cols) { + int local = opencl_context.base_opts().at("LOCAL_SIZE_"); + int preferred_work_groups + = opencl_context.device()[0].getInfo() * 16; + // round up n_rows/local/n_cols + return (std::min(preferred_work_groups, (n_rows + local - 1) / local) + n_cols + - 1) + / n_cols; +} +} // namespace internal + /** * Represents a column wise reduction in kernel generator expressions. So as to * be efficient column wise reductions are only done partially. 
That means @@ -36,7 +58,8 @@ namespace math { */ template class colwise_reduction - : public operation_cl::Scalar, + : public internal::colwise_reduction_base, + public operation_cl::Scalar, T> { public: using Scalar = typename std::remove_reference_t::Scalar; @@ -84,7 +107,7 @@ class colwise_reduction parts.args += out_parts.args; parts.reduction += "if (lid_i == 0) {\n" + result.var_name_ - + "_global[j * blocks_rows + wg_id_i] = " + + "_global[j * n_groups_i + wg_id_i] = " + derived().var_name_ + "_local[0];\n" "}\n"; return parts; @@ -105,10 +128,11 @@ class colwise_reduction const std::string& var_name_arg) const { kernel_parts res; res.declarations = "__local " + type_str() + " " + var_name_ - + "_local[LOCAL_SIZE_];\n"; - res.initialization - = type_str() + " " + var_name_ + " = " + init_ + ";\n"; - res.body = var_name_ + " = " + var_name_arg + ";\n"; + + "_local[LOCAL_SIZE_];\n" + type_str() + " " + + var_name_ + ";\n"; + res.initialization = var_name_ + " = " + init_ + ";\n"; + res.body = var_name_ + " = " + Operation::generate(var_name_, var_name_arg) + + ";\n"; res.reduction = var_name_ + "_local[lid_i] = " + var_name_ + ";\n" "barrier(CLK_LOCAL_MEM_FENCE);\n" @@ -132,10 +156,15 @@ class colwise_reduction * @return number of rows */ inline int rows() const { - int local_rows = opencl_context.base_opts().at("LOCAL_SIZE_"); - int wgs_rows - = (this->template get_arg<0>().rows() + local_rows - 1) / local_rows; - return wgs_rows; + int arg_rows = this->template get_arg<0>().rows(); + int arg_cols = this->template get_arg<0>().cols(); + if (arg_cols == 0) { + return 1; + } + if (arg_cols == -1) { + return -1; + } + return internal::colwise_reduction_wgs_rows(arg_rows, arg_cols); } /** @@ -287,6 +316,25 @@ inline auto colwise_min(T&& a) { return colwise_min_>( as_operation_cl(std::forward(a))); } + +namespace internal { +template +struct is_colwise_reduction_impl + : public std::is_base_of> {}; +template +struct is_colwise_reduction_impl> + : public 
std::is_base_of> {}; +} // namespace internal + +/** + * Check whether a kernel generator expression is a colwise reduction. + */ +template +using is_colwise_reduction + = internal::is_colwise_reduction_impl>; + /** @}*/ } // namespace math } // namespace stan diff --git a/stan/math/opencl/kernel_generator/constant.hpp b/stan/math/opencl/kernel_generator/constant.hpp index 7211b9c8c9f..313309e8746 100644 --- a/stan/math/opencl/kernel_generator/constant.hpp +++ b/stan/math/opencl/kernel_generator/constant.hpp @@ -33,7 +33,7 @@ class constant_ : public operation_cl, T> { public: static_assert(std::is_arithmetic::value, - "class scalar_: std::is_arithmetic must be true!"); + "class constant_: std::is_arithmetic must be true!"); using Scalar = T; using base = operation_cl, T>; using base::var_name_; diff --git a/stan/math/opencl/kernel_generator/multi_result_kernel.hpp b/stan/math/opencl/kernel_generator/multi_result_kernel.hpp index 98ea0a1149e..f895a058314 100644 --- a/stan/math/opencl/kernel_generator/multi_result_kernel.hpp +++ b/stan/math/opencl/kernel_generator/multi_result_kernel.hpp @@ -8,6 +8,7 @@ #include #include #include +#include #include #include #include @@ -71,23 +72,29 @@ struct multi_result_kernel_internal { const auto& result = std::get(assignment_pairs).first; const char* function = "results.operator="; - int expressin_rows = expression.rows(); - int expressin_cols = expression.cols(); + int expression_rows = expression.rows(); + int expression_cols = expression.cols(); + if (is_colwise_reduction::value + && expression_cols == -1) { + expression_cols = n_cols; + expression_rows = internal::colwise_reduction_wgs_rows( + expression.thread_rows(), expression_cols); + } if (expression.thread_rows() != -1) { check_size_match(function, "Rows of ", "expression", expression.thread_rows(), "rows of ", "first expression", n_rows); } else { - expressin_rows = n_rows; + expression_rows = n_rows; } if (expression.thread_cols() != -1) { check_size_match(function, 
"Columns of ", "expression", expression.thread_cols(), "columns of ", "first expression", n_cols); } else { - expressin_cols = n_cols; + expression_cols = n_cols; } - result.check_assign_dimensions(expressin_rows, expressin_cols); + result.check_assign_dimensions(expression_rows, expression_cols); int bottom_written = 1 - expression.rows(); int top_written = expression.cols() - 1; std::pair extreme_diagonals = expression.extreme_diagonals(); @@ -373,21 +380,17 @@ class results_cl { "kernel void calculate(" + parts.args + "const int rows, const int cols){\n" "const int gid_i = get_global_id(0);\n" + "const int gid_j = get_global_id(1);\n" "const int lid_i = get_local_id(0);\n" "const int lsize_i = get_local_size(0);\n" + "const int gsize_i = get_global_size(0);\n" + "const int gsize_j = get_global_size(1);\n" "const int wg_id_i = get_group_id(0);\n" - "const int wg_id_j = get_group_id(1);\n" "const int n_groups_i = get_num_groups(0);\n" - "const int blocks_rows = (rows + lsize_i - 1) / lsize_i;\n" - "const int blocks_cols = (cols + lsize_i - 1) / lsize_i;\n" - "const int i0 = lsize_i * wg_id_i;\n" - "const int i = i0 + lid_i;\n" - "const int j0 = lsize_i * wg_id_j;\n" + parts.declarations + - "for(int lid_j = 0; lid_j < min(cols - j0, lsize_i); lid_j++){\n" - "const int j = j0 + lid_j;\n" + "for(int j = gid_j; j < cols; j+=gsize_j){\n" + parts.initialization + - "if(i < rows){\n" + "for(int i = gid_i; i < rows; i+=gsize_i){\n" + parts.body + parts.body_suffix + "}\n" @@ -481,9 +484,11 @@ class results_cl { if (require_specific_local_size) { kernel.setArg(arg_num++, n_rows); kernel.setArg(arg_num++, n_cols); + int local = opencl_context.base_opts().at("LOCAL_SIZE_"); - int wgs_rows = (n_rows + local - 1) / local; - int wgs_cols = (n_cols + local - 1) / local; + + int wgs_rows = internal::colwise_reduction_wgs_rows(n_rows, n_cols); + int wgs_cols = (n_cols + wgs_rows - 1) / wgs_rows; opencl_context.queue().enqueueNDRangeKernel( kernel, cl::NullRange, cl::NDRange(local 
* wgs_rows, wgs_cols), diff --git a/stan/math/opencl/prim/dirichlet_lpdf.hpp b/stan/math/opencl/prim/dirichlet_lpdf.hpp index d28e7fdc943..eaa05aaa061 100644 --- a/stan/math/opencl/prim/dirichlet_lpdf.hpp +++ b/stan/math/opencl/prim/dirichlet_lpdf.hpp @@ -122,28 +122,30 @@ inline return_type_t dirichlet_lpdf( calc_if::value>(theta_deriv), calc_if::value>(alpha_deriv)); - while (theta_csum_cl.rows() > 1) { - matrix_cl theta_csum_cl2; - matrix_cl theta_log_alpha_m_1_csum_cl2; + if (include_summand::value) { matrix_cl alpha_csum_cl2; matrix_cl lgamma_alpha_csum_cl2; + while (alpha_csum_cl.rows() > 1) { + results(alpha_csum_cl2, lgamma_alpha_csum_cl2) = expressions( + calc_if::value>( + colwise_sum(alpha_csum_cl)), + calc_if::value>( + colwise_sum(lgamma_alpha_csum_cl))); + alpha_csum_cl = std::move(alpha_csum_cl2); + lgamma_alpha_csum_cl = std::move(lgamma_alpha_csum_cl2); + } + } + matrix_cl theta_csum_cl2; + matrix_cl theta_log_alpha_m_1_csum_cl2; + while (theta_csum_cl.rows() > 1) { results(theta_csum_cl2, theta_log_alpha_m_1_csum_cl2) = expressions( colwise_sum(theta_csum_cl), calc_if::value>( colwise_sum(theta_log_alpha_m_1_csum_cl))); - results(alpha_csum_cl2, lgamma_alpha_csum_cl2) = expressions( - calc_if::value>( - colwise_sum(alpha_csum_cl)), - calc_if::value>( - colwise_sum(lgamma_alpha_csum_cl))); theta_csum_cl = std::move(theta_csum_cl2); if (include_summand::value) { theta_log_alpha_m_1_csum_cl = std::move(theta_log_alpha_m_1_csum_cl2); } - if (include_summand::value) { - alpha_csum_cl = std::move(alpha_csum_cl2); - lgamma_alpha_csum_cl = std::move(lgamma_alpha_csum_cl2); - } } } else { auto alpha_csum = colwise_sum(alpha_val); diff --git a/test/unit/math/opencl/kernel_generator/colwise_reduction_test.cpp b/test/unit/math/opencl/kernel_generator/colwise_reduction_test.cpp index b5754a1b5d0..b9ff37785bb 100644 --- a/test/unit/math/opencl/kernel_generator/colwise_reduction_test.cpp +++ b/test/unit/math/opencl/kernel_generator/colwise_reduction_test.cpp @@ 
-13,6 +13,7 @@ using Eigen::Dynamic; using Eigen::MatrixXd; using Eigen::MatrixXi; +using Eigen::VectorXd; using stan::math::matrix_cl; TEST(KernelGenerator, colwise_sum_test) { @@ -123,19 +124,24 @@ TEST(KernelGenerator, nested_rowwise_colwise_sum) { } TEST(KernelGenerator, colwise_sum_test_large) { - int N = 153; - int M = 201; - MatrixXd m = MatrixXd::Random(N, M); - - matrix_cl m_cl(m); - matrix_cl res_cl = stan::math::colwise_sum(m_cl); - MatrixXd raw_res = stan::math::from_matrix_cl(res_cl); - EXPECT_GT(m.rows(), raw_res.rows()); - MatrixXd res = raw_res.colwise().sum(); - MatrixXd correct = m.colwise().sum(); - EXPECT_EQ(correct.rows(), res.rows()); - EXPECT_EQ(correct.cols(), res.cols()); - EXPECT_MATRIX_NEAR(correct, res, 1e-9); + for (int M : {1, 2, 5, 9, 63, 64, 65, 4095, 4096, 4967, 4096 * 4}) { + for (int N : {1, 2, 5, 9, 63, 64, 65, 4095, 4096, 4967, 4096 * 4}) { + if (N * M > 1e6) { + continue; + } + MatrixXd m = MatrixXd::Random(N, M); + + matrix_cl m_cl(m); + matrix_cl res_cl = stan::math::colwise_sum(m_cl); + MatrixXd raw_res = stan::math::from_matrix_cl(res_cl); + EXPECT_GE(m.rows(), raw_res.rows()); + MatrixXd res = raw_res.colwise().sum(); + MatrixXd correct = m.colwise().sum(); + EXPECT_EQ(correct.rows(), res.rows()); + EXPECT_EQ(correct.cols(), res.cols()); + EXPECT_MATRIX_NEAR(correct, res, 1e-9); + } + } } TEST(KernelGenerator, colwise_sum_and_id_test) { @@ -158,4 +164,27 @@ TEST(KernelGenerator, colwise_sum_and_id_test) { EXPECT_MATRIX_NEAR(correct2, res2, 1e-9); } +TEST(KernelGenerator, colwise_reduction_of_rowwise_broadcast_test) { + MatrixXd m(3, 2); + m << 1.1, 1.2, 1.3, 1.4, 1.5, 1.6; + VectorXd v(3); + v << 4, 5, 6; + + matrix_cl m_cl(m); + matrix_cl v_cl(v); + + matrix_cl res1_cl; + matrix_cl res2_cl; + + stan::math::results(res1_cl, res2_cl) = stan::math::expressions( + m_cl, stan::math::colwise_sum(stan::math::rowwise_broadcast(v_cl))); + MatrixXd raw_res = stan::math::from_matrix_cl(res2_cl); + EXPECT_GE(v.rows(), 
raw_res.rows()); + MatrixXd res = raw_res.colwise().sum(); + MatrixXd correct = v.replicate(1, 2).colwise().sum(); + EXPECT_EQ(correct.rows(), res.rows()); + EXPECT_EQ(correct.cols(), res.cols()); + EXPECT_MATRIX_NEAR(correct, res, 1e-9); +} + #endif diff --git a/test/unit/math/opencl/kernel_generator/reference_kernels/colwise_sum.cl b/test/unit/math/opencl/kernel_generator/reference_kernels/colwise_sum.cl index 855dd535ae3..64e78e8c727 100644 --- a/test/unit/math/opencl/kernel_generator/reference_kernels/colwise_sum.cl +++ b/test/unit/math/opencl/kernel_generator/reference_kernels/colwise_sum.cl @@ -1,22 +1,19 @@ kernel void calculate(__global double* var2_global, int var2_rows, int var2_view, __global double* var3_global, int var3_rows, int var3_view, const int rows, const int cols){ const int gid_i = get_global_id(0); +const int gid_j = get_global_id(1); const int lid_i = get_local_id(0); const int lsize_i = get_local_size(0); +const int gsize_i = get_global_size(0); +const int gsize_j = get_global_size(1); const int wg_id_i = get_group_id(0); -const int wg_id_j = get_group_id(1); const int n_groups_i = get_num_groups(0); -const int blocks_rows = (rows + lsize_i - 1) / lsize_i; -const int blocks_cols = (cols + lsize_i - 1) / lsize_i; -const int i0 = lsize_i * wg_id_i; -const int i = i0 + lid_i; -const int j0 = lsize_i * wg_id_j; __local double var1_local[LOCAL_SIZE_]; -for(int lid_j = 0; lid_j < min(cols - j0, lsize_i); lid_j++){ -const int j = j0 + lid_j; -double var1 = 0; -if(i < rows){ +double var1; +for(int j = gid_j; j < cols; j+=gsize_j){ +var1 = 0; +for(int i = gid_i; i < rows; i+=gsize_i){ double var2 = 0; if (!((!contains_nonzero(var2_view, LOWER) && j < i) || (!contains_nonzero(var2_view, UPPER) && j > i))) {var2 = var2_global[i + var2_rows * j];} -var1 = var2; +var1 = var1 + var2; } var1_local[lid_i] = var1; barrier(CLK_LOCAL_MEM_FENCE); @@ -29,7 +26,7 @@ for (int step = lsize_i / REDUCTION_STEP_SIZE; step > 0; step /= REDUCTION_STEP_ 
barrier(CLK_LOCAL_MEM_FENCE); } if (lid_i == 0) { -var3_global[j * blocks_rows + wg_id_i] = var1_local[0]; +var3_global[j * n_groups_i + wg_id_i] = var1_local[0]; } } } diff --git a/test/unit/math/opencl/util.hpp b/test/unit/math/opencl/util.hpp index 4bf58f32817..f922bea8908 100644 --- a/test/unit/math/opencl/util.hpp +++ b/test/unit/math/opencl/util.hpp @@ -161,23 +161,29 @@ void compare_cpu_opencl_prim_rev_impl(const Functor& functor, const Args&... args) { prim_rev_argument_combinations( [&functor](const auto& args_for_cpu, const auto& args_for_opencl) { - auto res_cpu = eval(functor(std::get(args_for_cpu)...)); - auto res_opencl - = eval(functor(opencl_argument(std::get(args_for_opencl))...)); std::string signature = type_name().data(); - expect_eq(res_opencl, res_cpu, - ("CPU and OpenCL return values do not match for signature " - + signature + "!") - .c_str()); - var(recursive_sum(res_cpu) + recursive_sum(res_opencl)).grad(); + try { + auto res_cpu = eval(functor(std::get(args_for_cpu)...)); + auto res_opencl = eval( + functor(opencl_argument(std::get(args_for_opencl))...)); + expect_eq(res_opencl, res_cpu, + ("CPU and OpenCL return values do not match for signature " + + signature + "!") + .c_str()); + var(recursive_sum(res_cpu) + recursive_sum(res_opencl)).grad(); - static_cast(std::initializer_list{ - (expect_adj_near( - std::get(args_for_opencl), std::get(args_for_cpu), - ("CPU and OpenCL adjoints do not match for argument " - + std::to_string(Is) + " for signature " + signature + "!") - .c_str()), - 0)...}); + static_cast(std::initializer_list{ + (expect_adj_near( + std::get(args_for_opencl), std::get(args_for_cpu), + ("CPU and OpenCL adjoints do not match for argument " + + std::to_string(Is) + " for signature " + signature + "!") + .c_str()), + 0)...}); + } catch (...) { + std::cerr << "exception thrown in signature " << signature << ":" + << std::endl; + throw; + } set_zero_all_adjoints(); },