Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Add OpenCL/prim add_diag, diag_matrix, subtract and minor OpenCL fixes #2250

Merged
merged 20 commits into from
Dec 11, 2020
Merged
Show file tree
Hide file tree
Changes from 5 commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
1 change: 1 addition & 0 deletions stan/math/opencl/err/check_diagonal_zeros.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -9,6 +9,7 @@
#include <stan/math/opencl/copy.hpp>
#include <stan/math/opencl/matrix_cl.hpp>
#include <stan/math/opencl/kernels/check_diagonal_zeros.hpp>
#include <stan/math/opencl/kernel_generator/constant.hpp>
#include <vector>

namespace stan {
Expand Down
6 changes: 3 additions & 3 deletions stan/math/opencl/kernel_generator/binary_operation.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -184,9 +184,9 @@ class binary_operation : public operation_cl<Derived, T_res, T_a, T_b> {
as_operation_cl(std::forward<T_b>(b))}; \
}

ADD_BINARY_OPERATION(addition_, operator+, common_scalar_t<T_a COMMA T_b>, "+");
ADD_BINARY_OPERATION(addition_operator_, add, common_scalar_t<T_a COMMA T_b>,
"+");
ADD_BINARY_OPERATION(addition_operator_, operator+,
common_scalar_t<T_a COMMA T_b>, "+");
ADD_BINARY_OPERATION(addition_, add, common_scalar_t<T_a COMMA T_b>, "+");
ADD_BINARY_OPERATION(subtraction_operator_, operator-,
common_scalar_t<T_a COMMA T_b>, "-");
ADD_BINARY_OPERATION(subtraction_, subtract, common_scalar_t<T_a COMMA T_b>,
Expand Down
4 changes: 2 additions & 2 deletions stan/math/opencl/kernel_generator/check_cl.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -7,6 +7,7 @@
#include <stan/math/opencl/value_type.hpp>
#include <stan/math/opencl/kernel_generator/as_operation_cl.hpp>
#include <stan/math/opencl/kernel_generator/operation_cl_lhs.hpp>
#include <stan/math/opencl/kernel_generator/constant.hpp>
#include <stan/math/opencl/kernel_generator/scalar.hpp>
#include <map>

Expand Down Expand Up @@ -52,13 +53,12 @@ class check_cl_ : public operation_cl_lhs<check_cl_<T>, bool> {
*/
check_cl_(const char* function, const char* err_variable, T&& y,
const char* must_be)
: buffer_(3, 1),
: buffer_(constant(0, 3, 1)),
value_(1, 1),
arg_(std::forward<T>(y)),
function_(function),
err_variable_(err_variable),
must_be_(must_be) {
buffer_ = constant(0, buffer_.rows(), buffer_.cols());
buffer_.view(matrix_cl_view::Entire);
}

Expand Down
1 change: 0 additions & 1 deletion stan/math/opencl/multiply.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -9,7 +9,6 @@
#include <stan/math/opencl/kernels/add.hpp>
#include <stan/math/opencl/scalar_type.hpp>
#include <stan/math/opencl/sub_block.hpp>
#include <stan/math/opencl/zeros_strict_tri.hpp>
#include <stan/math/prim/fun/Eigen.hpp>
#include <stan/math/prim/meta.hpp>
#include <algorithm>
Expand Down
5 changes: 3 additions & 2 deletions stan/math/opencl/prim/add_diag.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -31,8 +31,9 @@ inline auto add_diag(T_m&& mat, T_a&& to_add) { // NOLINT
check_consistent_sizes("add_diag (OpenCL)", "number of elements of to_add",
to_add, "diagonal", length_diag);
}
diagonal(mat) = diagonal(mat) + to_add;
return mat;
matrix_cl<typename std::decay_t<T_m>::Scalar> mat_eval = mat.eval();
diagonal(mat_eval) = diagonal(mat_eval) + to_add;
return mat_eval;
}
} // namespace math
} // namespace stan
Expand Down
2 changes: 1 addition & 1 deletion stan/math/opencl/prim/diag_matrix.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -22,7 +22,7 @@ namespace math {
template <typename T_x,
require_all_kernel_expressions_and_none_scalar_t<T_x>* = nullptr>
inline auto diag_matrix(T_x&& x) { // NOLINT
return add_diag(constant(0, x.size(), x.size()).eval(), x);
return add_diag(constant(value_type_t<T_x>(0), x.size(), x.size()).eval(), x);
}
} // namespace math
} // namespace stan
Expand Down
5 changes: 3 additions & 2 deletions stan/math/opencl/tri_inverse.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -79,9 +79,10 @@ inline matrix_cl<T> tri_inverse(const matrix_cl<T>& A) {
* thread_block_size_1D;

matrix_cl<T> temp(A_rows_padded, A_rows_padded);
matrix_cl<T> inv_padded = constant(0, A_rows_padded, A_rows_padded);
matrix_cl<T> inv_padded = constant(0.0, A_rows_padded, A_rows_padded);
matrix_cl<T> inv_mat(A);
matrix_cl<T> zero_mat = constant(0, A_rows_padded - A.rows(), A_rows_padded);
matrix_cl<T> zero_mat
= constant(0.0, A_rows_padded - A.rows(), A_rows_padded);
if (tri_view == matrix_cl_view::Upper) {
inv_mat = transpose(inv_mat).eval();
}
Expand Down
10 changes: 0 additions & 10 deletions test/unit/math/opencl/prim/add_test.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -7,62 +7,52 @@

auto add_functor
= [](const auto& a, const auto& b) { return stan::math::add(a, b); };
auto plusop_functor = [](const auto& a, const auto& b) { return a + b; };

TEST(OpenCLPrim, add_v_small_zero) {
stan::math::vector_d d1(3), d2(3);
d1 << 1, 2, 3;
d2 << 3, 2, 1;
stan::math::test::compare_cpu_opencl_prim(add_functor, d1, d2);
stan::math::test::compare_cpu_opencl_prim(plusop_functor, d1, d2);

stan::math::vector_d d0(0);
stan::math::test::compare_cpu_opencl_prim(add_functor, d1, d2);
stan::math::test::compare_cpu_opencl_prim(plusop_functor, d1, d2);
}

TEST(OpenCLPrim, add_rv_small_zero) {
stan::math::row_vector_d d1(3), d2(3);
d1 << 1, 2, 3;
d2 << 3, 2, 1;
stan::math::test::compare_cpu_opencl_prim(add_functor, d1, d2);
stan::math::test::compare_cpu_opencl_prim(plusop_functor, d1, d2);

stan::math::vector_d d0(0);
stan::math::test::compare_cpu_opencl_prim(add_functor, d1, d2);
stan::math::test::compare_cpu_opencl_prim(plusop_functor, d1, d2);
}

TEST(OpenCLPrim, add_m_small_zero) {
stan::math::matrix_d d1(3, 3), d2(3, 3);
d1 << 1, 2, 3, 4, 5, 6, 7, 8, 9;
d2 << 10, 100, 1000, 0, -10, -12, 2, 4, 8;
stan::math::test::compare_cpu_opencl_prim(add_functor, d1, d2);
stan::math::test::compare_cpu_opencl_prim(plusop_functor, d1, d2);

stan::math::matrix_d d0(0, 0);
stan::math::test::compare_cpu_opencl_prim(add_functor, d1, d2);
stan::math::test::compare_cpu_opencl_prim(plusop_functor, d1, d2);
}

TEST(OpenCLPrim, add_exceptions) {
stan::math::vector_d vd1(2), vd2(3);
stan::math::matrix_cl<double> vd11(vd1);
stan::math::matrix_cl<double> vd22(vd2);
EXPECT_THROW(stan::math::add(vd11, vd22), std::invalid_argument);
EXPECT_THROW((vd11 + vd22), std::invalid_argument);

stan::math::row_vector_d rvd1(2), rvd2(3);
stan::math::matrix_cl<double> rvd11(rvd1);
stan::math::matrix_cl<double> rvd22(rvd2);
EXPECT_THROW(stan::math::add(rvd11, rvd22), std::invalid_argument);
EXPECT_THROW((rvd11 + rvd22), std::invalid_argument);

stan::math::matrix_d md1(2, 2), md2(3, 3);
stan::math::matrix_cl<double> md11(md1);
stan::math::matrix_cl<double> md22(md2);
EXPECT_THROW(stan::math::add(md11, md22), std::invalid_argument);
EXPECT_THROW((md11 + md22), std::invalid_argument);
}

TEST(OpenCLPrim, add_tri_value_check) {
Expand Down
96 changes: 20 additions & 76 deletions test/unit/math/opencl/prim/subtract_test.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -7,62 +7,52 @@

auto subtract_functor
= [](const auto& a, const auto& b) { return stan::math::subtract(a, b); };
auto minusop_functor = [](const auto& a, const auto& b) { return a - b; };

TEST(OpenCLPrim, subtract_v_small_zero) {
stan::math::vector_d d1(3), d2(3);
d1 << 1, 2, 3;
d2 << 3, 2, 1;
stan::math::test::compare_cpu_opencl_prim(subtract_functor, d1, d2);
stan::math::test::compare_cpu_opencl_prim(minusop_functor, d1, d2);

stan::math::vector_d d0(0);
stan::math::test::compare_cpu_opencl_prim(subtract_functor, d1, d2);
stan::math::test::compare_cpu_opencl_prim(minusop_functor, d1, d2);
}

TEST(OpenCLPrim, subtract_rv_small_zero) {
stan::math::row_vector_d d1(3), d2(3);
d1 << 1, 2, 3;
d2 << 3, 2, 1;
stan::math::test::compare_cpu_opencl_prim(subtract_functor, d1, d2);
stan::math::test::compare_cpu_opencl_prim(minusop_functor, d1, d2);

stan::math::vector_d d0(0);
stan::math::test::compare_cpu_opencl_prim(subtract_functor, d1, d2);
stan::math::test::compare_cpu_opencl_prim(minusop_functor, d1, d2);
}

TEST(OpenCLPrim, subtract_m_small_zero) {
stan::math::matrix_d d1(3, 3), d2(3, 3);
d1 << 1, 2, 3, 4, 5, 6, 7, 8, 9;
d2 << 10, 100, 1000, 0, -10, -12, 2, 4, 8;
stan::math::test::compare_cpu_opencl_prim(subtract_functor, d1, d2);
stan::math::test::compare_cpu_opencl_prim(minusop_functor, d1, d2);

stan::math::matrix_d d0(0, 0);
stan::math::test::compare_cpu_opencl_prim(subtract_functor, d1, d2);
stan::math::test::compare_cpu_opencl_prim(minusop_functor, d1, d2);
}

TEST(OpenCLPrim, subtract_exceptions) {
stan::math::vector_d vd1(2), vd2(3);
stan::math::matrix_cl<double> vd11(vd1);
stan::math::matrix_cl<double> vd22(vd2);
EXPECT_THROW(stan::math::subtract(vd11, vd22), std::invalid_argument);
EXPECT_THROW((vd11 + vd22), std::invalid_argument);

stan::math::row_vector_d rvd1(2), rvd2(3);
stan::math::matrix_cl<double> rvd11(rvd1);
stan::math::matrix_cl<double> rvd22(rvd2);
EXPECT_THROW(stan::math::subtract(rvd11, rvd22), std::invalid_argument);
EXPECT_THROW((rvd11 + rvd22), std::invalid_argument);

stan::math::matrix_d md1(2, 2), md2(3, 3);
stan::math::matrix_cl<double> md11(md1);
stan::math::matrix_cl<double> md22(md2);
EXPECT_THROW(stan::math::subtract(md11, md22), std::invalid_argument);
EXPECT_THROW((md11 + md22), std::invalid_argument);
}

TEST(MathMatrixCL, subtract_tri_value_check) {
Expand All @@ -79,109 +69,63 @@ TEST(MathMatrixCL, subtract_tri_value_check) {
c_cl = a_cl - b_cl;
EXPECT_EQ(c_cl.view(), stan::math::matrix_cl_view::Lower);
c = stan::math::from_matrix_cl(c_cl);
EXPECT_EQ(-2, c(0, 0));
EXPECT_EQ(1, c(1, 0));
EXPECT_EQ(2, c(1, 1));
EXPECT_EQ(4, c(2, 0));
EXPECT_EQ(5, c(2, 1));
EXPECT_EQ(6, c(2, 2));
EXPECT_MATRIX_EQ((Eigen::MatrixXd(a.triangularView<Eigen::Lower>())
- Eigen::MatrixXd(b.triangularView<Eigen::Lower>())),
c);

c_cl = stan::math::subtract(a_cl, b_cl);
EXPECT_EQ(c_cl.view(), stan::math::matrix_cl_view::Lower);
c = stan::math::from_matrix_cl(c_cl);
EXPECT_EQ(-2, c(0, 0));
EXPECT_EQ(1, c(1, 0));
EXPECT_EQ(2, c(1, 1));
EXPECT_EQ(4, c(2, 0));
EXPECT_EQ(5, c(2, 1));
EXPECT_EQ(6, c(2, 2));
EXPECT_MATRIX_EQ((Eigen::MatrixXd(a.triangularView<Eigen::Lower>())
- Eigen::MatrixXd(b.triangularView<Eigen::Lower>())),
c);

a_cl.view(stan::math::matrix_cl_view::Lower);
b_cl.view(stan::math::matrix_cl_view::Upper);
c_cl = a_cl - b_cl;
EXPECT_EQ(c_cl.view(), stan::math::matrix_cl_view::Entire);
c = stan::math::from_matrix_cl(c_cl);
EXPECT_EQ(-2, c(0, 0));
EXPECT_EQ(-3, c(0, 1));
EXPECT_EQ(-3, c(0, 2));
EXPECT_EQ(4, c(1, 0));
EXPECT_EQ(2, c(1, 1));
EXPECT_EQ(-3, c(1, 2));
EXPECT_EQ(7, c(2, 0));
EXPECT_EQ(8, c(2, 1));
EXPECT_EQ(6, c(2, 2));
EXPECT_MATRIX_EQ((Eigen::MatrixXd(a.triangularView<Eigen::Lower>())
- Eigen::MatrixXd(b.triangularView<Eigen::Upper>())),
c);

c_cl = stan::math::subtract(a_cl, b_cl);
EXPECT_EQ(c_cl.view(), stan::math::matrix_cl_view::Entire);
c = stan::math::from_matrix_cl(c_cl);
EXPECT_EQ(-2, c(0, 0));
EXPECT_EQ(-3, c(0, 1));
EXPECT_EQ(-3, c(0, 2));
EXPECT_EQ(4, c(1, 0));
EXPECT_EQ(2, c(1, 1));
EXPECT_EQ(-3, c(1, 2));
EXPECT_EQ(7, c(2, 0));
EXPECT_EQ(8, c(2, 1));
EXPECT_EQ(6, c(2, 2));
EXPECT_MATRIX_EQ((Eigen::MatrixXd(a.triangularView<Eigen::Lower>())
- Eigen::MatrixXd(b.triangularView<Eigen::Upper>())),
c);

a_cl.view(stan::math::matrix_cl_view::Upper);
b_cl.view(stan::math::matrix_cl_view::Lower);
c_cl = a_cl - b_cl;
EXPECT_EQ(c_cl.view(), stan::math::matrix_cl_view::Entire);
c = stan::math::from_matrix_cl(c_cl);
EXPECT_EQ(-2, c(0, 0));
EXPECT_EQ(2, c(0, 1));
EXPECT_EQ(3, c(0, 2));
EXPECT_EQ(-3, c(1, 0));
EXPECT_EQ(2, c(1, 1));
EXPECT_EQ(6, c(1, 2));
EXPECT_EQ(-3, c(2, 0));
EXPECT_EQ(-3, c(2, 1));
EXPECT_EQ(6, c(2, 2));
EXPECT_MATRIX_EQ((Eigen::MatrixXd(a.triangularView<Eigen::Upper>())
- Eigen::MatrixXd(b.triangularView<Eigen::Lower>())),
c);

a_cl.view(stan::math::matrix_cl_view::Upper);
b_cl.view(stan::math::matrix_cl_view::Lower);
c_cl = stan::math::subtract(a_cl, b_cl);
EXPECT_EQ(c_cl.view(), stan::math::matrix_cl_view::Entire);
c = stan::math::from_matrix_cl(c_cl);
EXPECT_EQ(-2, c(0, 0));
EXPECT_EQ(2, c(0, 1));
EXPECT_EQ(3, c(0, 2));
EXPECT_EQ(-3, c(1, 0));
EXPECT_EQ(2, c(1, 1));
EXPECT_EQ(6, c(1, 2));
EXPECT_EQ(-3, c(2, 0));
EXPECT_EQ(-3, c(2, 1));
EXPECT_EQ(6, c(2, 2));
EXPECT_MATRIX_EQ((Eigen::MatrixXd(a.triangularView<Eigen::Upper>())
- Eigen::MatrixXd(b.triangularView<Eigen::Lower>())),
c);

a_cl.view(stan::math::matrix_cl_view::Entire);
b_cl.view(stan::math::matrix_cl_view::Lower);
c_cl = a_cl - b_cl;
EXPECT_EQ(c_cl.view(), stan::math::matrix_cl_view::Entire);
c = stan::math::from_matrix_cl(c_cl);
EXPECT_EQ(-2, c(0, 0));
EXPECT_EQ(2, c(0, 1));
EXPECT_EQ(3, c(0, 2));
EXPECT_EQ(1, c(1, 0));
EXPECT_EQ(2, c(1, 1));
EXPECT_EQ(6, c(1, 2));
EXPECT_EQ(4, c(2, 0));
EXPECT_EQ(5, c(2, 1));
EXPECT_EQ(6, c(2, 2));
EXPECT_MATRIX_EQ((a - Eigen::MatrixXd(b.triangularView<Eigen::Lower>())), c);

a_cl.view(stan::math::matrix_cl_view::Entire);
b_cl.view(stan::math::matrix_cl_view::Lower);
c_cl = stan::math::subtract(a_cl, b_cl);
EXPECT_EQ(c_cl.view(), stan::math::matrix_cl_view::Entire);
c = stan::math::from_matrix_cl(c_cl);
EXPECT_EQ(-2, c(0, 0));
EXPECT_EQ(2, c(0, 1));
EXPECT_EQ(3, c(0, 2));
EXPECT_EQ(1, c(1, 0));
EXPECT_EQ(2, c(1, 1));
EXPECT_EQ(6, c(1, 2));
EXPECT_EQ(4, c(2, 0));
EXPECT_EQ(5, c(2, 1));
EXPECT_EQ(6, c(2, 2));
EXPECT_MATRIX_EQ((a - Eigen::MatrixXd(b.triangularView<Eigen::Lower>())), c);
}
#endif
20 changes: 0 additions & 20 deletions test/unit/math/opencl/util.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -142,26 +142,6 @@ void prim_rev_argument_combinations(const Functor& f, const Arg0& arg0,
args...);
}

template <typename Functor>
void prim_argument_combinations(Functor f) {
f(std::make_tuple(), std::make_tuple());
}

template <typename Functor, typename Arg0, typename... Args>
void prim_argument_combinations(const Functor& f, const Arg0& arg0,
const Args&... args) {
prim_argument_combinations(
[&f, &arg0](const auto& args1, const auto& args2) {
constexpr size_t Size
= std::tuple_size<std::decay_t<decltype(args1)>>::value;
return index_apply<Size>([&](auto... Is) {
return f(std::make_tuple(arg0, std::get<Is>(args1)...),
std::make_tuple(arg0, std::get<Is>(args2)...));
});
},
args...);
}

template <typename Functor, std::size_t... Is, typename... Args>
void compare_cpu_opencl_prim_rev_impl(const Functor& functor,
std::index_sequence<Is...>,
Expand Down