Skip to content

Commit 3f8c41a

Browse files
authoredMar 4, 2021
Merge pull request #2404 from bstatcomp/opencl_ref_type
added OpenCL ref_type
2 parents fbe53a9 + 1e1bd58 commit 3f8c41a

File tree

5 files changed

+298
-150
lines changed

5 files changed

+298
-150
lines changed
 

‎stan/math/opencl/prim.hpp

+1
Original file line numberDiff line numberDiff line change
@@ -93,6 +93,7 @@
9393
#include <stan/math/opencl/pinned_matrix.hpp>
9494
#include <stan/math/opencl/plain_type.hpp>
9595
#include <stan/math/opencl/ref_type_for_opencl.hpp>
96+
#include <stan/math/opencl/ref_type.hpp>
9697
#include <stan/math/opencl/to_ref_for_opencl.hpp>
9798
#include <stan/math/opencl/value_type.hpp>
9899
#include <stan/math/opencl/zeros_strict_tri.hpp>

‎stan/math/opencl/ref_type.hpp

+71
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,71 @@
1+
#ifndef STAN_MATH_OPENCL_REF_TYPE_HPP
2+
#define STAN_MATH_OPENCL_REF_TYPE_HPP
3+
#ifdef STAN_OPENCL
4+
5+
#include <stan/math/opencl/kernel_generator.hpp>
6+
#include <stan/math/opencl/plain_type.hpp>
7+
#include <stan/math/prim/meta/ref_type.hpp>
8+
#include <type_traits>
9+
10+
namespace stan {
11+
12+
namespace internal {
13+
14+
template <typename T>
15+
struct is_trivial_kg_expression : public std::false_type {};
16+
17+
template <>
18+
struct is_trivial_kg_expression<int> : public std::true_type {};
19+
template <>
20+
struct is_trivial_kg_expression<double> : public std::true_type {};
21+
template <typename T>
22+
struct is_trivial_kg_expression<math::matrix_cl<T>> : public std::true_type {};
23+
template <typename T>
24+
struct is_trivial_kg_expression<math::load_<T>> : public std::true_type {};
25+
template <typename T>
26+
struct is_trivial_kg_expression<math::scalar_<T>> : public std::true_type {};
27+
template <typename T>
28+
struct is_trivial_kg_expression<math::constant_<T>> : public std::true_type {};
29+
template <>
30+
struct is_trivial_kg_expression<math::row_index> : public std::true_type {};
31+
template <>
32+
struct is_trivial_kg_expression<math::col_index> : public std::true_type {};
33+
template <typename T>
34+
struct is_trivial_kg_expression<math::calc_if_<false, T>>
35+
: public std::true_type {};
36+
37+
template <typename T>
38+
struct is_trivial_kg_expression<math::as_column_vector_or_scalar_<T>>
39+
: public is_trivial_kg_expression<std::decay_t<T>> {};
40+
template <typename T>
41+
struct is_trivial_kg_expression<math::block_<T>>
42+
: public is_trivial_kg_expression<std::decay_t<T>> {};
43+
template <typename T, bool Colwise, bool Rowwise>
44+
struct is_trivial_kg_expression<math::broadcast_<T, Colwise, Rowwise>>
45+
: public is_trivial_kg_expression<std::decay_t<T>> {};
46+
template <typename T>
47+
struct is_trivial_kg_expression<math::calc_if_<true, T>>
48+
: public is_trivial_kg_expression<std::decay_t<T>> {};
49+
template <typename T>
50+
struct is_trivial_kg_expression<math::holder_cl_<T>>
51+
: public is_trivial_kg_expression<std::decay_t<T>> {};
52+
template <typename T, bool Colwise, bool Rowwise>
53+
struct is_trivial_kg_expression<math::optional_broadcast_<T, Colwise, Rowwise>>
54+
: public is_trivial_kg_expression<std::decay_t<T>> {};
55+
56+
} // namespace internal
57+
58+
template <bool Condition, typename T>
59+
struct ref_type_if<Condition, T, require_all_kernel_expressions_t<T>> {
60+
using T_plain = plain_type_t<T>;
61+
using T_optionally_ref
62+
= std::conditional_t<std::is_rvalue_reference<T>::value,
63+
std::remove_reference_t<T>, const T&>;
64+
using type = std::conditional_t<
65+
internal::is_trivial_kg_expression<std::decay_t<T>>::value || !Condition,
66+
T_optionally_ref, T_plain>;
67+
};
68+
69+
} // namespace stan
70+
#endif
71+
#endif

‎stan/math/prim/meta/ref_type.hpp

+8-6
Original file line numberDiff line numberDiff line change
@@ -23,6 +23,14 @@ namespace stan {
2323
*/
2424
template <bool Condition, typename T, typename = void>
2525
struct ref_type_if {
26+
using type = std::conditional_t<std::is_rvalue_reference<T>::value,
27+
std::remove_reference_t<T>, const T&>;
28+
};
29+
30+
template <bool Condition, typename T>
31+
struct ref_type_if<
32+
Condition, T,
33+
require_all_t<is_eigen<T>, bool_constant<!is_arena_matrix<T>::value>>> {
2634
using T_plain = plain_type_t<T>;
2735
using T_optionally_ref
2836
= std::conditional_t<std::is_rvalue_reference<T>::value,
@@ -36,12 +44,6 @@ struct ref_type_if {
3644
T_optionally_ref, T_plain>;
3745
};
3846

39-
template <bool Condition, typename T>
40-
struct ref_type_if<Condition, T, require_not_eigen_t<T>> {
41-
using type = std::conditional_t<std::is_rvalue_reference<T>::value,
42-
std::remove_reference_t<T>, const T&>;
43-
};
44-
4547
template <bool Condition, typename T>
4648
struct ref_type_if<Condition, T, require_arena_matrix_t<T>> {
4749
using type =
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,158 @@
1+
#ifdef STAN_OPENCL
2+
#include <stan/math/opencl/prim.hpp>
3+
#include <stan/math/prim.hpp>
4+
5+
#include <test/unit/util.hpp>
6+
#include <gtest/gtest.h>
7+
#include <stan/math/opencl/pinned_matrix.hpp>
8+
9+
TEST(MathMetaPrim, ref_type_for_opencl_for_opencl_non_eigen) {
10+
using stan::math::ref_type_for_opencl_t;
11+
std::vector<int> a{1, 2, 3};
12+
ref_type_for_opencl_t<std::vector<int>> a_ref1 = a;
13+
ref_type_for_opencl_t<std::vector<int>&> a_ref2 = a;
14+
ref_type_for_opencl_t<std::vector<int>&&> a_ref3 = std::vector<int>{1, 2, 3};
15+
16+
double b = 3;
17+
ref_type_for_opencl_t<double> b_ref1 = b;
18+
ref_type_for_opencl_t<double&> b_ref2 = b;
19+
ref_type_for_opencl_t<double&&> b_ref3 = 3;
20+
21+
const std::vector<double> c{0.5, 4, 0.7};
22+
ref_type_for_opencl_t<const std::vector<double>> c_ref1 = c;
23+
ref_type_for_opencl_t<const std::vector<double>&> c_ref2 = c;
24+
25+
EXPECT_STD_VECTOR_FLOAT_EQ(a_ref1, a);
26+
EXPECT_STD_VECTOR_FLOAT_EQ(a_ref2, a);
27+
EXPECT_STD_VECTOR_FLOAT_EQ(a_ref3, a);
28+
EXPECT_EQ(b_ref1, b);
29+
EXPECT_EQ(b_ref2, b);
30+
EXPECT_EQ(b_ref3, b);
31+
EXPECT_STD_VECTOR_FLOAT_EQ(c_ref1, c);
32+
EXPECT_STD_VECTOR_FLOAT_EQ(c_ref2, c);
33+
EXPECT_TRUE(std::is_lvalue_reference<ref_type_for_opencl_t<double>>::value);
34+
EXPECT_TRUE(std::is_lvalue_reference<ref_type_for_opencl_t<double&>>::value);
35+
EXPECT_FALSE(std::is_reference<ref_type_for_opencl_t<double&&>>::value);
36+
EXPECT_TRUE(std::is_lvalue_reference<
37+
ref_type_for_opencl_t<const std::vector<double>>>::value);
38+
EXPECT_TRUE(std::is_lvalue_reference<
39+
ref_type_for_opencl_t<const std::vector<double>&>>::value);
40+
EXPECT_FALSE(std::is_reference<
41+
ref_type_for_opencl_t<const std::vector<double>&&>>::value);
42+
}
43+
44+
TEST(MathMetaPrim, ref_type_for_opencl_eigen_contiguous) {
45+
using stan::math::ref_type_for_opencl_t;
46+
Eigen::MatrixXd a(3, 3);
47+
a << 1, 2, 3, 4, 5, 6, 7, 8, 9;
48+
Eigen::MatrixXd a2 = a;
49+
ref_type_for_opencl_t<Eigen::MatrixXd> a_ref1 = a;
50+
ref_type_for_opencl_t<Eigen::MatrixXd&> a_ref2 = a;
51+
ref_type_for_opencl_t<Eigen::MatrixXd&&> a_ref3 = std::move(a2);
52+
53+
auto b = a.leftCols(2);
54+
ref_type_for_opencl_t<decltype(b)> b_ref1 = b;
55+
ref_type_for_opencl_t<decltype(b)&> b_ref2 = b;
56+
ref_type_for_opencl_t<decltype(b)&&> b_ref3 = a.leftCols(2);
57+
58+
using ContiguousMap = Eigen::Map<Eigen::MatrixXd, 0, Eigen::Stride<0, 0>>;
59+
ContiguousMap c(a.data(), 3, 3);
60+
ContiguousMap c2(a.data(), 3, 3);
61+
ref_type_for_opencl_t<ContiguousMap> c_ref1 = c;
62+
ref_type_for_opencl_t<ContiguousMap&> c_ref2 = c;
63+
ref_type_for_opencl_t<ContiguousMap&&> c_ref3 = std::move(c2);
64+
65+
EXPECT_MATRIX_EQ(a_ref1, a);
66+
EXPECT_MATRIX_EQ(a_ref2, a);
67+
EXPECT_MATRIX_EQ(a_ref3, a);
68+
69+
EXPECT_MATRIX_EQ(b_ref1, b);
70+
EXPECT_MATRIX_EQ(b_ref2, b);
71+
EXPECT_MATRIX_EQ(b_ref3, b);
72+
73+
EXPECT_MATRIX_EQ(c_ref1, c);
74+
EXPECT_MATRIX_EQ(c_ref2, c);
75+
EXPECT_MATRIX_EQ(c_ref3, c);
76+
EXPECT_TRUE(
77+
(std::is_same<decltype(a), ref_type_for_opencl_t<decltype(a)&&>>::value));
78+
EXPECT_TRUE(
79+
(std::is_same<decltype(b), ref_type_for_opencl_t<decltype(b)&&>>::value));
80+
EXPECT_TRUE(
81+
(std::is_same<decltype(c), ref_type_for_opencl_t<decltype(c)&&>>::value));
82+
EXPECT_TRUE(
83+
std::is_lvalue_reference<ref_type_for_opencl_t<Eigen::MatrixXd>>::value);
84+
EXPECT_TRUE(
85+
std::is_lvalue_reference<ref_type_for_opencl_t<Eigen::MatrixXd&>>::value);
86+
EXPECT_FALSE(
87+
std::is_reference<ref_type_for_opencl_t<Eigen::MatrixXd&&>>::value);
88+
}
89+
90+
TEST(MathMetaPrim, ref_type_for_opencl_eigen_non_contiguous) {
91+
using stan::math::ref_type_for_opencl_t;
92+
Eigen::MatrixXd m(3, 3);
93+
m << 1, 2, 3, 4, 5, 6, 7, 8, 9;
94+
using RowMajorMatrixXd
95+
= Eigen::Matrix<double, Eigen::Dynamic, Eigen::Dynamic, Eigen::RowMajor>;
96+
RowMajorMatrixXd a = m;
97+
RowMajorMatrixXd a2 = m;
98+
ref_type_for_opencl_t<RowMajorMatrixXd> a_ref1 = a;
99+
ref_type_for_opencl_t<RowMajorMatrixXd&> a_ref2 = a;
100+
ref_type_for_opencl_t<RowMajorMatrixXd&&> a_ref3 = std::move(a2);
101+
102+
auto b = m.block(1, 0, 2, 2);
103+
ref_type_for_opencl_t<decltype(b)> b_ref1 = b;
104+
ref_type_for_opencl_t<decltype(b)&> b_ref2 = b;
105+
ref_type_for_opencl_t<decltype(b)&&> b_ref3 = a.block(1, 0, 2, 2);
106+
107+
Eigen::Ref<Eigen::MatrixXd> c = m;
108+
Eigen::Ref<Eigen::MatrixXd> c2 = m;
109+
ref_type_for_opencl_t<Eigen::Ref<Eigen::MatrixXd>> c_ref1 = c;
110+
ref_type_for_opencl_t<Eigen::Ref<Eigen::MatrixXd>&> c_ref2 = c;
111+
ref_type_for_opencl_t<Eigen::Ref<Eigen::MatrixXd>&&> c_ref3 = std::move(c2);
112+
113+
EXPECT_MATRIX_EQ(a_ref1, a);
114+
EXPECT_MATRIX_EQ(a_ref2, a);
115+
EXPECT_MATRIX_EQ(a_ref3, a);
116+
117+
EXPECT_MATRIX_EQ(b_ref1, b);
118+
EXPECT_MATRIX_EQ(b_ref2, b);
119+
EXPECT_MATRIX_EQ(b_ref3, b);
120+
121+
EXPECT_MATRIX_EQ(c_ref1, c);
122+
EXPECT_MATRIX_EQ(c_ref2, c);
123+
EXPECT_MATRIX_EQ(c_ref3, c);
124+
EXPECT_TRUE((std::is_same<stan::math::pinned_matrix<Eigen::MatrixXd>,
125+
ref_type_for_opencl_t<decltype(a)&&>>::value));
126+
EXPECT_TRUE((std::is_same<stan::math::pinned_matrix<Eigen::MatrixXd>,
127+
ref_type_for_opencl_t<decltype(b)&&>>::value));
128+
EXPECT_TRUE((std::is_same<stan::math::pinned_matrix<Eigen::MatrixXd>,
129+
ref_type_for_opencl_t<decltype(c)&&>>::value));
130+
}
131+
132+
TEST(MathMetaPrim, ref_type_for_opencl_eigen_expression) {
133+
using stan::plain_type_t;
134+
using stan::math::ref_type_for_opencl_t;
135+
Eigen::MatrixXd m(3, 3);
136+
m << 1, 2, 3, 4, 5, 6, 7, 8, 9;
137+
auto a = m * 3;
138+
ref_type_for_opencl_t<decltype(a)> a_ref1 = a;
139+
ref_type_for_opencl_t<decltype(a)&> a_ref2 = a;
140+
ref_type_for_opencl_t<decltype(a)&&> a_ref3 = m * 3;
141+
142+
Eigen::MatrixXd a_eval = a;
143+
EXPECT_MATRIX_EQ(a_ref1, a_eval);
144+
EXPECT_MATRIX_EQ(a_ref2, a_eval);
145+
EXPECT_MATRIX_EQ(a_ref3, a_eval);
146+
147+
EXPECT_TRUE(
148+
(std::is_same<stan::math::pinned_matrix<plain_type_t<decltype(a)>>,
149+
ref_type_for_opencl_t<decltype(a)>>::value));
150+
EXPECT_TRUE(
151+
(std::is_same<stan::math::pinned_matrix<plain_type_t<decltype(a)>>,
152+
ref_type_for_opencl_t<decltype(a)&>>::value));
153+
EXPECT_TRUE(
154+
(std::is_same<stan::math::pinned_matrix<plain_type_t<decltype(a)>>,
155+
ref_type_for_opencl_t<decltype(a)&&>>::value));
156+
}
157+
158+
#endif
+60-144
Original file line numberDiff line numberDiff line change
@@ -1,158 +1,74 @@
11
#ifdef STAN_OPENCL
22
#include <stan/math/opencl/prim.hpp>
33
#include <stan/math/prim.hpp>
4-
54
#include <test/unit/util.hpp>
65
#include <gtest/gtest.h>
7-
#include <stan/math/opencl/pinned_matrix.hpp>
8-
9-
TEST(MathMetaPrim, ref_type_for_opencl_for_opencl_non_eigen) {
10-
using stan::math::ref_type_for_opencl_t;
11-
std::vector<int> a{1, 2, 3};
12-
ref_type_for_opencl_t<std::vector<int>> a_ref1 = a;
13-
ref_type_for_opencl_t<std::vector<int>&> a_ref2 = a;
14-
ref_type_for_opencl_t<std::vector<int>&&> a_ref3 = std::vector<int>{1, 2, 3};
15-
16-
double b = 3;
17-
ref_type_for_opencl_t<double> b_ref1 = b;
18-
ref_type_for_opencl_t<double&> b_ref2 = b;
19-
ref_type_for_opencl_t<double&&> b_ref3 = 3;
20-
21-
const std::vector<double> c{0.5, 4, 0.7};
22-
ref_type_for_opencl_t<const std::vector<double>> c_ref1 = c;
23-
ref_type_for_opencl_t<const std::vector<double>&> c_ref2 = c;
24-
25-
EXPECT_STD_VECTOR_FLOAT_EQ(a_ref1, a);
26-
EXPECT_STD_VECTOR_FLOAT_EQ(a_ref2, a);
27-
EXPECT_STD_VECTOR_FLOAT_EQ(a_ref3, a);
28-
EXPECT_EQ(b_ref1, b);
29-
EXPECT_EQ(b_ref2, b);
30-
EXPECT_EQ(b_ref3, b);
31-
EXPECT_STD_VECTOR_FLOAT_EQ(c_ref1, c);
32-
EXPECT_STD_VECTOR_FLOAT_EQ(c_ref2, c);
33-
EXPECT_TRUE(std::is_lvalue_reference<ref_type_for_opencl_t<double>>::value);
34-
EXPECT_TRUE(std::is_lvalue_reference<ref_type_for_opencl_t<double&>>::value);
35-
EXPECT_FALSE(std::is_reference<ref_type_for_opencl_t<double&&>>::value);
36-
EXPECT_TRUE(std::is_lvalue_reference<
37-
ref_type_for_opencl_t<const std::vector<double>>>::value);
38-
EXPECT_TRUE(std::is_lvalue_reference<
39-
ref_type_for_opencl_t<const std::vector<double>&>>::value);
40-
EXPECT_FALSE(std::is_reference<
41-
ref_type_for_opencl_t<const std::vector<double>&&>>::value);
42-
}
43-
44-
TEST(MathMetaPrim, ref_type_for_opencl_eigen_contiguous) {
45-
using stan::math::ref_type_for_opencl_t;
46-
Eigen::MatrixXd a(3, 3);
47-
a << 1, 2, 3, 4, 5, 6, 7, 8, 9;
48-
Eigen::MatrixXd a2 = a;
49-
ref_type_for_opencl_t<Eigen::MatrixXd> a_ref1 = a;
50-
ref_type_for_opencl_t<Eigen::MatrixXd&> a_ref2 = a;
51-
ref_type_for_opencl_t<Eigen::MatrixXd&&> a_ref3 = std::move(a2);
52-
53-
auto b = a.leftCols(2);
54-
ref_type_for_opencl_t<decltype(b)> b_ref1 = b;
55-
ref_type_for_opencl_t<decltype(b)&> b_ref2 = b;
56-
ref_type_for_opencl_t<decltype(b)&&> b_ref3 = a.leftCols(2);
57-
58-
using ContiguousMap = Eigen::Map<Eigen::MatrixXd, 0, Eigen::Stride<0, 0>>;
59-
ContiguousMap c(a.data(), 3, 3);
60-
ContiguousMap c2(a.data(), 3, 3);
61-
ref_type_for_opencl_t<ContiguousMap> c_ref1 = c;
62-
ref_type_for_opencl_t<ContiguousMap&> c_ref2 = c;
63-
ref_type_for_opencl_t<ContiguousMap&&> c_ref3 = std::move(c2);
646

65-
EXPECT_MATRIX_EQ(a_ref1, a);
66-
EXPECT_MATRIX_EQ(a_ref2, a);
67-
EXPECT_MATRIX_EQ(a_ref3, a);
68-
69-
EXPECT_MATRIX_EQ(b_ref1, b);
70-
EXPECT_MATRIX_EQ(b_ref2, b);
71-
EXPECT_MATRIX_EQ(b_ref3, b);
72-
73-
EXPECT_MATRIX_EQ(c_ref1, c);
74-
EXPECT_MATRIX_EQ(c_ref2, c);
75-
EXPECT_MATRIX_EQ(c_ref3, c);
76-
EXPECT_TRUE(
77-
(std::is_same<decltype(a), ref_type_for_opencl_t<decltype(a)&&>>::value));
78-
EXPECT_TRUE(
79-
(std::is_same<decltype(b), ref_type_for_opencl_t<decltype(b)&&>>::value));
80-
EXPECT_TRUE(
81-
(std::is_same<decltype(c), ref_type_for_opencl_t<decltype(c)&&>>::value));
82-
EXPECT_TRUE(
83-
std::is_lvalue_reference<ref_type_for_opencl_t<Eigen::MatrixXd>>::value);
84-
EXPECT_TRUE(
85-
std::is_lvalue_reference<ref_type_for_opencl_t<Eigen::MatrixXd&>>::value);
86-
EXPECT_FALSE(
87-
std::is_reference<ref_type_for_opencl_t<Eigen::MatrixXd&&>>::value);
7+
TEST(MathMetaPrim, ref_type_matrix_cl) {
8+
using stan::ref_type_t;
9+
using stan::math::matrix_cl;
10+
Eigen::MatrixXd a_eig(3, 3);
11+
a_eig << 1, 2, 3, 4, 5, 6, 7, 8, 9;
12+
matrix_cl<double> a(a_eig);
13+
matrix_cl<double> a2 = a;
14+
ref_type_t<matrix_cl<double>> a_ref1 = a;
15+
ref_type_t<matrix_cl<double>&> a_ref2 = a;
16+
ref_type_t<matrix_cl<double>&&> a_ref3 = std::move(a2);
17+
18+
EXPECT_MATRIX_EQ(from_matrix_cl(a_ref1), from_matrix_cl(a));
19+
EXPECT_MATRIX_EQ(from_matrix_cl(a_ref2), from_matrix_cl(a));
20+
EXPECT_MATRIX_EQ(from_matrix_cl(a_ref3), from_matrix_cl(a));
21+
22+
EXPECT_TRUE((std::is_same<decltype(a), ref_type_t<decltype(a)&&>>::value));
23+
EXPECT_TRUE(std::is_lvalue_reference<ref_type_t<matrix_cl<double>>>::value);
24+
EXPECT_TRUE(std::is_lvalue_reference<ref_type_t<matrix_cl<double>&>>::value);
25+
EXPECT_FALSE(std::is_reference<ref_type_t<matrix_cl<double>&&>>::value);
8826
}
8927

90-
TEST(MathMetaPrim, ref_type_for_opencl_eigen_non_contiguous) {
91-
using stan::math::ref_type_for_opencl_t;
92-
Eigen::MatrixXd m(3, 3);
93-
m << 1, 2, 3, 4, 5, 6, 7, 8, 9;
94-
using RowMajorMatrixXd
95-
= Eigen::Matrix<double, Eigen::Dynamic, Eigen::Dynamic, Eigen::RowMajor>;
96-
RowMajorMatrixXd a = m;
97-
RowMajorMatrixXd a2 = m;
98-
ref_type_for_opencl_t<RowMajorMatrixXd> a_ref1 = a;
99-
ref_type_for_opencl_t<RowMajorMatrixXd&> a_ref2 = a;
100-
ref_type_for_opencl_t<RowMajorMatrixXd&&> a_ref3 = std::move(a2);
101-
102-
auto b = m.block(1, 0, 2, 2);
103-
ref_type_for_opencl_t<decltype(b)> b_ref1 = b;
104-
ref_type_for_opencl_t<decltype(b)&> b_ref2 = b;
105-
ref_type_for_opencl_t<decltype(b)&&> b_ref3 = a.block(1, 0, 2, 2);
106-
107-
Eigen::Ref<Eigen::MatrixXd> c = m;
108-
Eigen::Ref<Eigen::MatrixXd> c2 = m;
109-
ref_type_for_opencl_t<Eigen::Ref<Eigen::MatrixXd>> c_ref1 = c;
110-
ref_type_for_opencl_t<Eigen::Ref<Eigen::MatrixXd>&> c_ref2 = c;
111-
ref_type_for_opencl_t<Eigen::Ref<Eigen::MatrixXd>&&> c_ref3 = std::move(c2);
112-
113-
EXPECT_MATRIX_EQ(a_ref1, a);
114-
EXPECT_MATRIX_EQ(a_ref2, a);
115-
EXPECT_MATRIX_EQ(a_ref3, a);
116-
117-
EXPECT_MATRIX_EQ(b_ref1, b);
118-
EXPECT_MATRIX_EQ(b_ref2, b);
119-
EXPECT_MATRIX_EQ(b_ref3, b);
120-
121-
EXPECT_MATRIX_EQ(c_ref1, c);
122-
EXPECT_MATRIX_EQ(c_ref2, c);
123-
EXPECT_MATRIX_EQ(c_ref3, c);
124-
EXPECT_TRUE((std::is_same<stan::math::pinned_matrix<Eigen::MatrixXd>,
125-
ref_type_for_opencl_t<decltype(a)&&>>::value));
126-
EXPECT_TRUE((std::is_same<stan::math::pinned_matrix<Eigen::MatrixXd>,
127-
ref_type_for_opencl_t<decltype(b)&&>>::value));
128-
EXPECT_TRUE((std::is_same<stan::math::pinned_matrix<Eigen::MatrixXd>,
129-
ref_type_for_opencl_t<decltype(c)&&>>::value));
28+
TEST(MathMetaPrim, ref_type_kg_light_expression) {
29+
using stan::ref_type_t;
30+
using stan::math::matrix_cl;
31+
Eigen::MatrixXd m_eig(3, 3);
32+
m_eig << 1, 2, 3, 4, 5, 6, 7, 8, 9;
33+
matrix_cl<double> m(m_eig);
34+
auto a = stan::math::block_zero_based(m, 0, 1, 2, 2);
35+
ref_type_t<decltype(a)> a_ref1 = a;
36+
ref_type_t<decltype(a)&> a_ref2 = a;
37+
ref_type_t<decltype(a)&&> a_ref3
38+
= stan::math::block_zero_based(m, 0, 1, 2, 2);
39+
40+
matrix_cl<double> a_eval = a;
41+
EXPECT_MATRIX_EQ(from_matrix_cl(a_ref1), from_matrix_cl(a_eval));
42+
EXPECT_MATRIX_EQ(from_matrix_cl(a_ref2), from_matrix_cl(a_eval));
43+
EXPECT_MATRIX_EQ(from_matrix_cl(a_ref3), from_matrix_cl(a_eval));
44+
45+
EXPECT_FALSE((stan::is_matrix_cl<ref_type_t<decltype(a)>>::value));
46+
EXPECT_FALSE((stan::is_matrix_cl<ref_type_t<decltype(a)&>>::value));
47+
EXPECT_FALSE((stan::is_matrix_cl<ref_type_t<decltype(a)&&>>::value));
13048
}
13149

132-
TEST(MathMetaPrim, ref_type_for_opencl_eigen_expression) {
133-
using stan::plain_type_t;
134-
using stan::math::ref_type_for_opencl_t;
135-
Eigen::MatrixXd m(3, 3);
136-
m << 1, 2, 3, 4, 5, 6, 7, 8, 9;
50+
TEST(MathMetaPrim, ref_type_kg_heavy_expression) {
51+
using stan::ref_type_t;
52+
using stan::math::matrix_cl;
53+
Eigen::MatrixXd m_eig(3, 3);
54+
m_eig << 1, 2, 3, 4, 5, 6, 7, 8, 9;
55+
matrix_cl<double> m(m_eig);
13756
auto a = m * 3;
138-
ref_type_for_opencl_t<decltype(a)> a_ref1 = a;
139-
ref_type_for_opencl_t<decltype(a)&> a_ref2 = a;
140-
ref_type_for_opencl_t<decltype(a)&&> a_ref3 = m * 3;
141-
142-
Eigen::MatrixXd a_eval = a;
143-
EXPECT_MATRIX_EQ(a_ref1, a_eval);
144-
EXPECT_MATRIX_EQ(a_ref2, a_eval);
145-
EXPECT_MATRIX_EQ(a_ref3, a_eval);
146-
147-
EXPECT_TRUE(
148-
(std::is_same<stan::math::pinned_matrix<plain_type_t<decltype(a)>>,
149-
ref_type_for_opencl_t<decltype(a)>>::value));
150-
EXPECT_TRUE(
151-
(std::is_same<stan::math::pinned_matrix<plain_type_t<decltype(a)>>,
152-
ref_type_for_opencl_t<decltype(a)&>>::value));
153-
EXPECT_TRUE(
154-
(std::is_same<stan::math::pinned_matrix<plain_type_t<decltype(a)>>,
155-
ref_type_for_opencl_t<decltype(a)&&>>::value));
57+
ref_type_t<decltype(a)> a_ref1 = a;
58+
ref_type_t<decltype(a)&> a_ref2 = a;
59+
ref_type_t<decltype(a)&&> a_ref3 = m * 3;
60+
61+
matrix_cl<double> a_eval = a;
62+
EXPECT_MATRIX_EQ(from_matrix_cl(a_ref1), from_matrix_cl(a_eval));
63+
EXPECT_MATRIX_EQ(from_matrix_cl(a_ref2), from_matrix_cl(a_eval));
64+
EXPECT_MATRIX_EQ(from_matrix_cl(a_ref3), from_matrix_cl(a_eval));
65+
66+
EXPECT_TRUE((std::is_same<stan::math::matrix_cl<double>,
67+
ref_type_t<decltype(a)>>::value));
68+
EXPECT_TRUE((std::is_same<stan::math::matrix_cl<double>,
69+
ref_type_t<decltype(a)&>>::value));
70+
EXPECT_TRUE((std::is_same<stan::math::matrix_cl<double>,
71+
ref_type_t<decltype(a)&&>>::value));
15672
}
15773

15874
#endif

0 commit comments

Comments
 (0)
Please sign in to comment.