Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

bugfix arena_matrix_cl to matrix_cl conversions #2538

Merged
merged 4 commits into from
Jul 21, 2021
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
6 changes: 4 additions & 2 deletions stan/math/opencl/kernel_generator/matrix_cl_conversion.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -13,14 +13,16 @@ namespace math {
*/
template <typename T>
template <typename Expr,
require_all_kernel_expressions_and_none_scalar_t<Expr>*>
require_all_kernel_expressions_and_none_scalar_t<Expr>*,
require_not_matrix_cl_t<Expr>*>
matrix_cl<T>::matrix_cl(const Expr& expresion) : rows_(0), cols_(0) {
results(*this) = expressions(expresion);
}

template <typename T>
template <typename Expr,
require_all_kernel_expressions_and_none_scalar_t<Expr>*>
require_all_kernel_expressions_and_none_scalar_t<Expr>*,
require_not_matrix_cl_t<Expr>*>
matrix_cl<T>& matrix_cl<T>::operator=(const Expr& expresion) {
results(*this) = expressions(expresion);
return *this;
Expand Down
27 changes: 25 additions & 2 deletions stan/math/opencl/matrix_cl.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -31,6 +31,10 @@ namespace math {
* @{
*/

// forward declare
template <typename T>
class arena_matrix_cl;

template <typename>
class matrix_cl;

Expand Down Expand Up @@ -212,6 +216,13 @@ class matrix_cl : public matrix_cl_base {
write_events_(std::move(A.write_events_)),
read_events_(std::move(A.read_events_)) {}

/**
* Constructor from `arena_matrix_cl`.
* @param A matrix_cl to move
*/
// defined in rev/arena_matrix_cl.hpp
matrix_cl(const arena_matrix_cl<T>& A); // NOLINT(runtime/explicit)

/**
* Constructor for the matrix_cl that creates a copy of a std::vector of Eigen
* matrices on the OpenCL device. Each matrix is flattened into one column
Expand Down Expand Up @@ -415,8 +426,10 @@ class matrix_cl : public matrix_cl_base {
* @tparam Expr type of the expression
* @param expression expression
*/
// defined in kernel_generator/matrix_cl_conversion.hpp
template <typename Expr,
require_all_kernel_expressions_and_none_scalar_t<Expr>* = nullptr>
require_all_kernel_expressions_and_none_scalar_t<Expr>* = nullptr,
require_not_matrix_cl_t<Expr>* = nullptr>
matrix_cl(const Expr& expression); // NOLINT(runtime/explicit)

/**
Expand Down Expand Up @@ -460,10 +473,20 @@ class matrix_cl : public matrix_cl_base {
* @tparam Expr type of the expression
* @param expression expression
*/
// defined in kernel_generator/matrix_cl_conversion.hpp
template <typename Expr,
require_all_kernel_expressions_and_none_scalar_t<Expr>* = nullptr>
require_all_kernel_expressions_and_none_scalar_t<Expr>* = nullptr,
require_not_matrix_cl_t<Expr>* = nullptr>
matrix_cl<T>& operator=(const Expr& expression);

/**
* Assignment of `arena_matrix_cl<T>`.
* @tparam Expr type of the expression
* @param expression expression
*/
// defined in rev/arena_matrix_cl.hpp
matrix_cl<T>& operator=(const arena_matrix_cl<T>& other);

/**
* Evaluates `this`. This is a no-op.
* @return `*this`
Expand Down
35 changes: 31 additions & 4 deletions stan/math/opencl/rev/arena_matrix_cl.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -39,6 +39,8 @@ template <typename T>
class arena_matrix_cl : public matrix_cl_base {
private:
internal::arena_matrix_cl_impl<T>* impl_;
template <typename>
friend class matrix_cl;

public:
using Scalar = typename matrix_cl<T>::Scalar;
Expand Down Expand Up @@ -75,11 +77,14 @@ class arena_matrix_cl : public matrix_cl_base {

/**
* Implicit conversion operator to `matrix_cl`.
* @return `matrix_cl` equivalent to `*this`
* @return `matrix_cl` equivalent to `*this`. Returned matrix references the
* same underlying buffer.
*/
operator matrix_cl<T>() const& { return *impl_; } // NOLINT(runtime/explicit)
operator matrix_cl<T>() && { // NOLINT(runtime/explicit)
return std::move(*impl_);
operator const matrix_cl<T>&() const { // NOLINT(runtime/explicit)
return *static_cast<const matrix_cl<T>*>(impl_);
}
operator matrix_cl<T>&() { // NOLINT(runtime/explicit)
return *static_cast<matrix_cl<T>*>(impl_);
}

/**
Expand Down Expand Up @@ -139,6 +144,28 @@ class arena_matrix_cl : public matrix_cl_base {
#undef ARENA_MATRIX_CL_FUNCTION_WRAPPER
#undef ARENA_MATRIX_CL_CONST_FUNCTION_WRAPPER
};
template <typename T>
matrix_cl<T>::matrix_cl(const arena_matrix_cl<T>& A)
// works like a move constructor, except it does not modify `a`
: buffer_cl_(A.impl_->buffer_cl_),
rows_(A.impl_->rows_),
cols_(A.impl_->cols_),
view_(A.impl_->view_),
write_events_(A.impl_->write_events_),
read_events_(A.impl_->read_events_) {}

template <typename T>
matrix_cl<T>& matrix_cl<T>::operator=(const arena_matrix_cl<T>& a) {
// works like a move assignment operator, except it does not modify `a`
view_ = a.impl_->view();
rows_ = a.impl_->rows();
cols_ = a.impl_->cols();
this->wait_for_read_write_events();
buffer_cl_ = a.impl_->buffer_cl_;
write_events_ = a.impl_->write_events_;
read_events_ = a.impl_->read_events_;
return *this;
}

} // namespace math
} // namespace stan
Expand Down
38 changes: 38 additions & 0 deletions test/unit/math/opencl/rev/arena_matrix_cl_test.cpp
Original file line number Diff line number Diff line change
@@ -0,0 +1,38 @@
#ifdef STAN_OPENCL
#include <stan/math/opencl/rev.hpp>
#include <test/unit/util.hpp>
#include <gtest/gtest.h>

TEST(AgradRev, arena_matrix_cl_shallow_copies) {
stan::math::arena_matrix_cl<double> a(3, 2);
stan::math::arena_matrix_cl<double> b(a);
stan::math::arena_matrix_cl<double> c;
c = a;
EXPECT_EQ(a.buffer()(), b.buffer()());
EXPECT_EQ(a.buffer()(), c.buffer()());
}

TEST(AgradRev, arena_matrix_cl_to_matrix_cl_conversion) {
stan::math::arena_matrix_cl<double> a(3, 2);
const stan::math::matrix_cl<double>& b(a);
EXPECT_EQ(a.buffer()(), b.buffer()());
}

TEST(AgradRev, arena_matrix_cl_to_matrix_cl_move_construction) {
stan::math::arena_matrix_cl<double> a(3, 2);
cl::Buffer a_buf = a.buffer();
stan::math::matrix_cl<double> b(std::move(a));
EXPECT_EQ(a_buf(), a.buffer()());
EXPECT_EQ(a_buf(), b.buffer()());
}

TEST(AgradRev, arena_matrix_cl_to_matrix_cl_move_assignment) {
stan::math::arena_matrix_cl<double> a(3, 2);
cl::Buffer a_buf = a.buffer();
stan::math::matrix_cl<double> b;
b = std::move(a);
EXPECT_EQ(a_buf(), a.buffer()());
EXPECT_EQ(a_buf(), b.buffer()());
}

#endif
3 changes: 3 additions & 0 deletions test/unit/math/opencl/rev/to_arena_test.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -11,6 +11,7 @@ TEST(AgradRev, to_arena_matrix_cl_test) {
auto b = stan::math::to_arena(a);
EXPECT_MATRIX_EQ(stan::math::from_matrix_cl(b),
stan::math::from_matrix_cl(a));
EXPECT_EQ(a.buffer()(), b.buffer()());
EXPECT_FALSE((std::is_same<decltype(a), decltype(b)>::value));

auto c = stan::math::to_arena(b);
Expand Down Expand Up @@ -54,6 +55,8 @@ TEST(AgradRev, to_arena_var_value_matrix_cl_test) {
stan::math::from_matrix_cl(a.val()));
EXPECT_MATRIX_EQ(stan::math::from_matrix_cl(b.adj()),
stan::math::from_matrix_cl(a.adj()));
EXPECT_EQ(a.val().buffer()(), b.val().buffer()());
EXPECT_EQ(a.adj().buffer()(), b.adj().buffer()());
EXPECT_TRUE((std::is_same<decltype(a), decltype(b)>::value));

auto c = stan::math::to_arena(b);
Expand Down