Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Add OpenCL/prim add_diag, diag_matrix, subtract and minor OpenCL fixes #2250

Merged
merged 20 commits into from
Dec 11, 2020
Merged
Show file tree
Hide file tree
Changes from 9 commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
3 changes: 1 addition & 2 deletions stan/math/opencl/err/check_diagonal_zeros.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -33,8 +33,7 @@ inline void check_diagonal_zeros(const char* function, const char* name,
cl::Context ctx = opencl_context.context();
try {
int zero_on_diagonal_flag = 0;
matrix_cl<int> zeros_flag(1, 1);
zeros_flag = to_matrix_cl(zero_on_diagonal_flag);
matrix_cl<int> zeros_flag = constant(0, 1, 1);
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

[optional] This looks good, but we could completely replace this function with a kernel generator implementation.

opencl_kernels::check_diagonal_zeros(cl::NDRange(y.rows(), y.cols()), y,
zeros_flag, y.rows(), y.cols());
zero_on_diagonal_flag = from_matrix_cl_error_code(zeros_flag);
Expand Down
5 changes: 4 additions & 1 deletion stan/math/opencl/kernel_generator/binary_operation.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -185,7 +185,10 @@ class binary_operation : public operation_cl<Derived, T_res, T_a, T_b> {
}

ADD_BINARY_OPERATION(addition_, operator+, common_scalar_t<T_a COMMA T_b>, "+");
ADD_BINARY_OPERATION(subtraction_, operator-, common_scalar_t<T_a COMMA T_b>,
ADD_BINARY_OPERATION(addition_operator_, add, common_scalar_t<T_a COMMA T_b>, "+");
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

You have the class names for addition the wrong way around - the other one is the operator. Also, it might be better to add a wrapper function for add that delegates to the operator instead of constructing another class.

ADD_BINARY_OPERATION(subtraction_operator_, operator-, common_scalar_t<T_a COMMA T_b>,
"-");
ADD_BINARY_OPERATION(subtraction_, subtract, common_scalar_t<T_a COMMA T_b>,
"-");
ADD_BINARY_OPERATION_WITH_CUSTOM_CODE(
elt_multiply_, elt_multiply, common_scalar_t<T_a COMMA T_b>, "*",
Expand Down
2 changes: 1 addition & 1 deletion stan/math/opencl/kernel_generator/check_cl.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -58,7 +58,7 @@ class check_cl_ : public operation_cl_lhs<check_cl_<T>, bool> {
function_(function),
err_variable_(err_variable),
must_be_(must_be) {
buffer_.zeros();
buffer_ = constant(0, buffer_.rows(), buffer_.cols());
buffer_.view(matrix_cl_view::Entire);
}

Expand Down
Original file line number Diff line number Diff line change
@@ -1,5 +1,5 @@
#ifndef STAN_MATH_OPENCL_KERNELS_FILL_HPP
#define STAN_MATH_OPENCL_KERNELS_FILL_HPP
#ifndef STAN_MATH_OPENCL_KERNELS_FILL_STRICT_TRI_HPP
#define STAN_MATH_OPENCL_KERNELS_FILL_STRICT_TRI_HPP
#ifdef STAN_OPENCL

#include <stan/math/opencl/kernel_cl.hpp>
Expand All @@ -10,45 +10,6 @@
namespace stan {
namespace math {
namespace opencl_kernels {
// \cond
static const std::string fill_kernel_code = STRINGIFY(
// \endcond
/** \ingroup opencl_kernels
* Stores a constant in the matrix on the OpenCL device.
* Supports writing the constant to the lower triangular part, the upper
* triangular part or the whole matrix. The diagonal element is written for
* every value of view_A. Work items with indices outside of rows x cols
* write nothing.
*
* @param[out] A matrix
* @param val value to replicate in the matrix
* @param rows Number of rows for matrix A
* @param cols Number of columns for matrix A
* @param view_A triangular part of matrix A to use
*
* @note Code is a <code>std::string</code> held in
* <code>fill_kernel_code</code>.
* This kernel uses the helper macros available in helpers.cl.
*/
__kernel void fill(__global double* A, double val, unsigned int rows,
unsigned int cols, unsigned int view_A) {
const int i = get_global_id(0);
const int j = get_global_id(1);
if (i < rows && j < cols) {
if ((contains_nonzero(view_A, LOWER) && j <= i)
|| (contains_nonzero(view_A, UPPER) && j >= i) || i == j) {
A(i, j) = val;
}
}
}
// \cond
);
// \endcond

/** \ingroup opencl_kernels
* Kernel object for the "fill" kernel, assembled from indexing_helpers,
* view_kernel_helpers and fill_kernel_code. Its arguments are declared as an
* output buffer, a double, two ints and a matrix_cl_view.
*
* See the docs for \link kernels/fill.hpp fill() \endlink
*/
const kernel_cl<out_buffer, double, int, int, matrix_cl_view> fill(
"fill", {indexing_helpers, view_kernel_helpers, fill_kernel_code});

// \cond
static const std::string fill_strict_tri_kernel_code = STRINGIFY(
// \endcond
Expand Down Expand Up @@ -85,7 +46,7 @@ static const std::string fill_strict_tri_kernel_code = STRINGIFY(
// \endcond

/** \ingroup opencl_kernels
* See the docs for \link kernels/fill.hpp fill_strict_tri_kernel_code()
* See the docs for \link kernels/fill_strict_tri.hpp fill_strict_tri_kernel_code()
* \endlink
*/
const kernel_cl<out_buffer, double, int, int, matrix_cl_view> fill_strict_tri(
Expand Down
2 changes: 0 additions & 2 deletions stan/math/opencl/matrix_cl.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -53,8 +53,6 @@ class matrix_cl<T, require_arithmetic_t<T>> : public matrix_cl_base {
using type = T; // Underlying type of the matrix
// Forward declare the methods that work in place on the matrix
template <matrix_cl_view matrix_view = matrix_cl_view::Entire>
inline void zeros();
template <matrix_cl_view matrix_view = matrix_cl_view::Entire>
inline void zeros_strict_tri();
template <TriangularMapCL triangular_map = TriangularMapCL::LowerToUpper>
inline void triangular_transpose();
Expand Down
3 changes: 1 addition & 2 deletions stan/math/opencl/multiply.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -9,7 +9,7 @@
#include <stan/math/opencl/kernels/add.hpp>
#include <stan/math/opencl/scalar_type.hpp>
#include <stan/math/opencl/sub_block.hpp>
#include <stan/math/opencl/zeros.hpp>
#include <stan/math/opencl/zeros_strict_tri.hpp>
#include <stan/math/prim/fun/Eigen.hpp>
#include <stan/math/prim/meta.hpp>
#include <algorithm>
Expand Down Expand Up @@ -43,7 +43,6 @@ inline matrix_cl<return_type_t<T1, T2>> multiply(const T1& A, const T2& B) {
matrix_cl<return_type_t<T1, T2>> temp(A.rows(), B.cols(),
either(A.view(), B.view()));
if (A.size() == 0 || B.size() == 0) {
temp.zeros();
return temp;
}
if (A.rows() == 1) {
Expand Down
5 changes: 3 additions & 2 deletions stan/math/opencl/prim.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -97,9 +97,9 @@
#include <stan/math/opencl/sub_block.hpp>
#include <stan/math/opencl/triangular_transpose.hpp>
#include <stan/math/opencl/value_type.hpp>
#include <stan/math/opencl/zeros.hpp>
#include <stan/math/opencl/zeros_strict_tri.hpp>

#include <stan/math/opencl/prim/add.hpp>
#include <stan/math/opencl/prim/add_diag.hpp>
#include <stan/math/opencl/prim/bernoulli_lpmf.hpp>
#include <stan/math/opencl/prim/bernoulli_logit_lpmf.hpp>
#include <stan/math/opencl/prim/bernoulli_logit_glm_lpmf.hpp>
Expand All @@ -113,6 +113,7 @@
#include <stan/math/opencl/prim/col.hpp>
#include <stan/math/opencl/prim/cols.hpp>
#include <stan/math/opencl/prim/crossprod.hpp>
#include <stan/math/opencl/prim/diag_matrix.hpp>
#include <stan/math/opencl/prim/dims.hpp>
#include <stan/math/opencl/prim/divide.hpp>
#include <stan/math/opencl/prim/divide_columns.hpp>
Expand Down
31 changes: 0 additions & 31 deletions stan/math/opencl/prim/add.hpp

This file was deleted.

40 changes: 40 additions & 0 deletions stan/math/opencl/prim/add_diag.hpp
Original file line number Diff line number Diff line change
@@ -0,0 +1,40 @@
#ifndef STAN_MATH_OPENCL_PRIM_ADD_DIAG_HPP
#define STAN_MATH_OPENCL_PRIM_ADD_DIAG_HPP
#ifdef STAN_OPENCL
#include <stan/math/opencl/prim/size.hpp>
#include <stan/math/opencl/kernel_generator.hpp>
#include <stan/math/prim/meta.hpp>
#include <stan/math/prim/err.hpp>

namespace stan {
namespace math {

/**
 * Adds a scalar, or the elements of a vector, to the main diagonal of a
 * matrix.
 *
 * @tparam T_m type of the kernel generator expression for the input matrix
 * @tparam T_a type of the scalar or kernel generator expression that is added
 * along the diagonal
 *
 * @param mat input kernel generator expression
 * @param to_add scalar value or input kernel generator expression to add
 * along the diagonal
 * @return `mat` with `to_add` added along its main diagonal
 */
template <typename T_m, typename T_a,
          require_all_kernel_expressions_and_none_scalar_t<T_m>* = nullptr,
          require_all_kernel_expressions_t<T_a>* = nullptr>
inline auto add_diag(T_m&& mat, T_a&& to_add) {  // NOLINT
  if (is_vector<T_a>::value) {
    // The main diagonal has min(rows, cols) entries.
    const size_t diag_size = std::min(mat.rows(), mat.cols());
    check_consistent_sizes("add_diag (OpenCL)", "number of elements of to_add",
                           to_add, "diagonal", diag_size);
  }
  // NOTE(review): the sum is assigned back into the diagonal of `mat` itself,
  // so an lvalue argument looks like it is modified in place - confirm that
  // callers expect this.
  diagonal(mat) = diagonal(mat) + to_add;
  return mat;
}
} // namespace math
} // namespace stan

#endif
#endif
31 changes: 31 additions & 0 deletions stan/math/opencl/prim/diag_matrix.hpp
Original file line number Diff line number Diff line change
@@ -0,0 +1,31 @@
#ifndef STAN_MATH_OPENCL_PRIM_DIAG_MATRIX_HPP
#define STAN_MATH_OPENCL_PRIM_DIAG_MATRIX_HPP
#ifdef STAN_OPENCL
#include <stan/math/opencl/kernel_generator.hpp>
#include <stan/math/prim/meta.hpp>
#include <stan/math/prim/err.hpp>

namespace stan {
namespace math {

/**
 * Return a square diagonal matrix with the specified vector of
 * coefficients as the diagonal values.
 *
 * The result is built by adding `x` to the diagonal of an evaluated
 * `x.size()` by `x.size()` matrix of zeros.
 *
 * @tparam T_x type of input kernel generator expression for the
 * diagonal
 *
 * @param x input kernel generator expression for the diagonal
 *
 * @return matrix with `x` on the main diagonal and zeros elsewhere
 */
template <typename T_x,
          require_all_kernel_expressions_and_none_scalar_t<T_x>* = nullptr>
inline auto diag_matrix(T_x&& x) {  // NOLINT
  // The zeros are created from a floating point literal. With the integer
  // literal `0` the evaluated zero matrix would presumably have an integer
  // scalar type, and the values of `x` would then be stored into integer
  // storage - TODO confirm against `constant()`.
  return add_diag(constant(0.0, x.size(), x.size()).eval(), x);
}
} // namespace math
} // namespace stan

#endif
#endif
4 changes: 2 additions & 2 deletions stan/math/opencl/prim/inv_cloglog.hpp
Original file line number Diff line number Diff line change
@@ -1,5 +1,5 @@
#ifndef STAN_MATH_OPENCL_PRIM_FUN_INV_SQUARE_HPP
#define STAN_MATH_OPENCL_PRIM_FUN_INV_SQUARE_HPP
#ifndef STAN_MATH_OPENCL_PRIM_FUN_INV_CLOGLOG_HPP
#define STAN_MATH_OPENCL_PRIM_FUN_INV_CLOGLOG_HPP
#ifdef STAN_OPENCL

#include <stan/math/opencl/matrix_cl.hpp>
Expand Down
4 changes: 0 additions & 4 deletions stan/math/opencl/rev/arena_matrix_cl.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -91,10 +91,6 @@ class arena_matrix_cl {
// Wrappers for functions with explicit template parameters are implemented
// without macros.
/**
 * Forwards to `zeros<matrix_view>()` of the object that `impl_` points to.
 *
 * @tparam matrix_view view handed through to the wrapped call; defaults to
 * `matrix_cl_view::Entire`
 */
template <matrix_cl_view matrix_view = matrix_cl_view::Entire>
inline void zeros() {
impl_->template zeros<matrix_view>();
}
/**
 * Forwards to `zeros_strict_tri<matrix_view>()` of the object that `impl_`
 * points to.
 *
 * @tparam matrix_view view handed through to the wrapped call; defaults to
 * `matrix_cl_view::Entire`
 */
template <matrix_cl_view matrix_view = matrix_cl_view::Entire>
inline void zeros_strict_tri() {
impl_->template zeros_strict_tri<matrix_view>();
}
Expand Down
8 changes: 3 additions & 5 deletions stan/math/opencl/tri_inverse.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -11,7 +11,7 @@
#include <stan/math/opencl/err.hpp>
#include <stan/math/opencl/identity.hpp>
#include <stan/math/opencl/sub_block.hpp>
#include <stan/math/opencl/zeros.hpp>
#include <stan/math/opencl/zeros_strict_tri.hpp>
#include <stan/math/opencl/kernel_generator.hpp>
#include <stan/math/prim/meta.hpp>
#include <cmath>
Expand Down Expand Up @@ -79,11 +79,9 @@ inline matrix_cl<T> tri_inverse(const matrix_cl<T>& A) {
* thread_block_size_1D;

matrix_cl<T> temp(A_rows_padded, A_rows_padded);
matrix_cl<T> inv_padded(A_rows_padded, A_rows_padded);
matrix_cl<T> inv_padded = constant(0, A_rows_padded, A_rows_padded);
matrix_cl<T> inv_mat(A);
matrix_cl<T> zero_mat(A_rows_padded - A.rows(), A_rows_padded);
zero_mat.template zeros<stan::math::matrix_cl_view::Entire>();
inv_padded.template zeros<stan::math::matrix_cl_view::Entire>();
matrix_cl<T> zero_mat = constant(0, A_rows_padded - A.rows(), A_rows_padded);
if (tri_view == matrix_cl_view::Upper) {
inv_mat = transpose(inv_mat).eval();
}
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -6,37 +6,14 @@
#include <stan/math/opencl/matrix_cl_view.hpp>
#include <stan/math/opencl/matrix_cl.hpp>
#include <stan/math/opencl/err.hpp>
#include <stan/math/opencl/kernels/fill.hpp>
#include <stan/math/opencl/kernels/fill_strict_tri.hpp>
#include <stan/math/prim/meta.hpp>

#include <CL/cl2.hpp>

namespace stan {
namespace math {

/** \ingroup matrix_cl_group
 * Stores zeros in the matrix on the OpenCL device.
 * Supports writing zeros to the lower triangular part, the upper triangular
 * part or the whole matrix. A matrix of size zero is left untouched.
 *
 * @tparam T scalar type of the matrix
 * @tparam matrix_view Specifies if zeros are assigned to
 * the entire matrix, lower triangular or upper triangular. The
 * value must be of type matrix_cl_view
 */
template <typename T>
template <matrix_cl_view matrix_view>
inline void matrix_cl<T, require_arithmetic_t<T>>::zeros() try {
  if (size() == 0) {
    return;
  }
  // The stored view is combined with the inverse of the view being zeroed.
  this->view_ = both(this->view_, invert(matrix_view));
  // One work item per matrix element; the fill kernel writes 0.0.
  opencl_kernels::fill(cl::NDRange(this->rows(), this->cols()), *this, 0.0,
                       this->rows(), this->cols(), matrix_view);
} catch (const cl::Error& e) {
  // OpenCL errors are passed on to check_opencl_error.
  check_opencl_error("zeros", e);
}

/** \ingroup matrix_cl_group
* Stores zeros in the strict triangular part (excluding the diagonal)
* of a matrix on the OpenCL device.
Expand Down
8 changes: 4 additions & 4 deletions stan/math/prim.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -3,6 +3,10 @@

#include <stan/math/version.hpp>

#ifdef STAN_OPENCL
#include <stan/math/opencl/prim.hpp>
#endif

#include <stan/math/prim/core.hpp>
#include <stan/math/prim/meta.hpp>
#include <stan/math/prim/err.hpp>
Expand All @@ -11,8 +15,4 @@
#include <stan/math/prim/functor.hpp>
#include <stan/math/prim/prob.hpp>

#ifdef STAN_OPENCL
#include <stan/math/opencl/prim.hpp>
#endif

#endif
40 changes: 40 additions & 0 deletions test/unit/math/opencl/prim/add_diag_test.cpp
Original file line number Diff line number Diff line change
@@ -0,0 +1,40 @@
#ifdef STAN_OPENCL
#include <stan/math/prim.hpp>
#include <test/unit/math/opencl/util.hpp>
#include <test/unit/util.hpp>
#include <gtest/gtest.h>
#include <algorithm>

// Generic callable that passes both of its arguments on to
// stan::math::add_diag and returns the result.
auto add_diag_functor = [](const auto& a, const auto& b) {
  return stan::math::add_diag(a, b);
};

// Runs add_diag_functor through compare_cpu_opencl_prim for a square and a
// non-square matrix, each with a scalar and with a vector of addends.
TEST(OpenCLPrim, add_diag_small) {
  stan::math::vector_d addends(3);
  addends << -1, 15, 500;

  // 3 x 3 input
  stan::math::matrix_d square(3, 3);
  square << 1, 2, 3, 4, 5, 6, 7, 8, 9;
  stan::math::test::compare_cpu_opencl_prim(add_diag_functor, square, 2);
  stan::math::test::compare_cpu_opencl_prim(add_diag_functor, square, addends);

  // 3 x 5 input
  stan::math::matrix_d wide(3, 5);
  wide << 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15;
  stan::math::test::compare_cpu_opencl_prim(add_diag_functor, wide, 2);
  stan::math::test::compare_cpu_opencl_prim(add_diag_functor, wide, addends);
}

// Runs add_diag_functor through compare_cpu_opencl_prim with a 0 x 0 matrix
// and a vector of size 0.
TEST(OpenCLPrim, add_diag_zero) {
  stan::math::matrix_d empty_mat(0, 0);
  stan::math::vector_d empty_vec(0);
  stan::math::test::compare_cpu_opencl_prim(add_diag_functor, empty_mat,
                                            empty_vec);
}

// add_diag is expected to throw std::invalid_argument when the vector of
// addends has a different number of elements than the diagonal.
TEST(OpenCLPrim, add_diag_exception) {
  stan::math::matrix_d square(3, 3);
  stan::math::matrix_d wide(2, 3);
  stan::math::vector_d five(5);
  stan::math::vector_d three(3);
  // 3 x 3 matrix with 5 addends
  EXPECT_THROW(stan::math::add_diag(stan::math::to_matrix_cl(square),
                                    stan::math::to_matrix_cl(five)),
               std::invalid_argument);
  // 2 x 3 matrix with 3 addends
  EXPECT_THROW(stan::math::add_diag(stan::math::to_matrix_cl(wide),
                                    stan::math::to_matrix_cl(three)),
               std::invalid_argument);
}

#endif
Loading