Skip to content

Commit aeff239

Browse files
committed
added mdivide_right_tri_low
1 parent 20df1f9 commit aeff239

File tree

6 files changed

+104
-46
lines changed

6 files changed

+104
-46
lines changed

Jenkinsfile

-4
Original file line numberDiff line numberDiff line change
@@ -238,7 +238,6 @@ pipeline {
238238
runTests("test/unit/math/opencl")
239239
runTests("test/unit/multiple_translation_units_test.cpp")
240240
runTests("test/unit/math/prim/fun/gp_exp_quad_cov_test.cpp")
241-
runTests("test/unit/math/prim/fun/mdivide_right_tri_test.cpp")
242241
runTests("test/unit/math/prim/fun/multiply_test.cpp")
243242
runTests("test/unit/math/rev/fun/multiply_test.cpp")
244243
} else {
@@ -253,7 +252,6 @@ pipeline {
253252
runTestsWin("test/unit/math/opencl", false, false)
254253
runTestsWin("test/unit/multiple_translation_units_test.cpp", false, false)
255254
runTestsWin("test/unit/math/prim/fun/gp_exp_quad_cov_test.cpp", false, false)
256-
runTestsWin("test/unit/math/prim/fun/mdivide_right_tri_test.cpp", false, false)
257255
runTestsWin("test/unit/math/prim/fun/multiply_test.cpp", false, false)
258256
runTestsWin("test/unit/math/rev/fun/multiply_test.cpp", false, false)
259257
}
@@ -274,7 +272,6 @@ pipeline {
274272
runTests("test/unit/math/opencl")
275273
runTests("test/unit/multiple_translation_units_test.cpp")
276274
runTests("test/unit/math/prim/fun/gp_exp_quad_cov_test.cpp")
277-
runTests("test/unit/math/prim/fun/mdivide_right_tri_test.cpp")
278275
runTests("test/unit/math/prim/fun/multiply_test.cpp")
279276
runTests("test/unit/math/rev/fun/multiply_test.cpp")
280277
} else {
@@ -289,7 +286,6 @@ pipeline {
289286
runTestsWin("test/unit/math/opencl", false, false)
290287
runTestsWin("test/unit/multiple_translation_units_test.cpp", false, false)
291288
runTestsWin("test/unit/math/prim/fun/gp_exp_quad_cov_test.cpp", false, false)
292-
runTestsWin("test/unit/math/prim/fun/mdivide_right_tri_test.cpp", false, false)
293289
runTestsWin("test/unit/math/prim/fun/multiply_test.cpp", false, false)
294290
runTestsWin("test/unit/math/rev/fun/multiply_test.cpp", false, false)
295291
}

stan/math/opencl/prim/mdivide_right_tri_low.hpp

+7-4
Original file line numberDiff line numberDiff line change
@@ -21,12 +21,15 @@ namespace math {
2121
* match the size of A.
2222
*/
2323
template <typename T1, typename T2,
24-
typename = require_all_floating_point_t<T1, T2>>
25-
inline matrix_cl<return_type_t<T1, T2>> mdivide_right_tri_low(
26-
const matrix_cl<T2>& b, const matrix_cl<T1>& A) {
24+
require_all_kernel_expressions_t<T1, T2>* = nullptr>
25+
inline matrix_cl<double> mdivide_right_tri_low(
26+
const T2& b, const T1& A) {
2727
check_square("mdivide_right_tri_low (OpenCL)", "A", A);
2828
check_multiplicable("mdivide_right_tri_low (OpenCL)", "b", b, "A", A);
29-
return b * tri_inverse<matrix_cl_view::Lower>(A);
29+
if (A.size() == 0 || b.size() == 0) {
30+
return matrix_cl<double>(b.rows(), A.cols());
31+
}
32+
return b * tri_inverse<matrix_cl_view::Lower>(eval(A));
3033
}
3134

3235
} // namespace math

stan/math/opencl/rev.hpp

+1
Original file line numberDiff line numberDiff line change
@@ -61,6 +61,7 @@
6161
#include <stan/math/opencl/rev/log_inv_logit_diff.hpp>
6262
#include <stan/math/opencl/rev/log_diff_exp.hpp>
6363
#include <stan/math/opencl/rev/mdivide_left_tri_low.hpp>
64+
#include <stan/math/opencl/rev/mdivide_right_tri_low.hpp>
6465
#include <stan/math/opencl/rev/multiply.hpp>
6566
#include <stan/math/opencl/rev/multiply_log.hpp>
6667
#include <stan/math/opencl/rev/operands_and_partials.hpp>
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,59 @@
1+
#ifndef STAN_MATH_OPENCL_REV_MDIVIDE_RIGHT_TRI_LOW_HPP
2+
#define STAN_MATH_OPENCL_REV_MDIVIDE_RIGHT_TRI_LOW_HPP
3+
#ifdef STAN_OPENCL
4+
5+
#include <stan/math/opencl/rev/arena_type.hpp>
6+
#include <stan/math/opencl/matrix_cl.hpp>
7+
#include <stan/math/opencl/prim/mdivide_left_tri_low.hpp>
8+
#include <stan/math/rev/core.hpp>
9+
#include <stan/math/rev/fun/adjoint_of.hpp>
10+
#include <stan/math/rev/fun/value_of.hpp>
11+
12+
namespace stan {
13+
namespace math {
14+
15+
/**
16+
* Returns the solution of the system Ax=b when A is lower triangular.
17+
*
18+
* @tparam T1 type of elements in A
19+
* @tparam T2 type of elements in b
20+
* @param A Triangular matrix.
21+
* @param b Right hand side matrix or vector.
22+
* @return x = A^-1 b, solution of the linear system.
23+
* @throws std::domain_error if A is not square or the rows of b don't
24+
* match the size of A.
25+
*/
26+
template <
27+
typename T1, typename T2,
28+
require_all_nonscalar_prim_or_rev_kernel_expression_t<T1, T2>* = nullptr,
29+
require_any_var_t<T1, T2>* = nullptr>
30+
inline var_value<matrix_cl<double>> mdivide_right_tri_low(T2&& b, T1&& A) {
31+
check_square("mdivide_right_tri_low", "A", A);
32+
check_multiplicable("mdivide_right_tri_low", "b", b, "A", A);
33+
if (A.size() == 0 || b.size() == 0) {
34+
return var_value<matrix_cl<double>>(matrix_cl<double>(b.rows(), A.cols()));
35+
}
36+
arena_t<T1> A_arena = std::forward<T1>(A);
37+
arena_t<T2> b_arena = std::forward<T2>(b);
38+
arena_matrix_cl<double> A_tri_inv
39+
= tri_inverse<matrix_cl_view::Lower>(value_of(A_arena));
40+
return make_callback_var(
41+
value_of(b_arena) * A_tri_inv,
42+
[A_arena, b_arena, A_tri_inv](const vari_value<matrix_cl<double>>& res) {
43+
matrix_cl<double> adjB = res.adj() * transpose(A_tri_inv);
44+
if (!is_constant<T1>::value) {
45+
matrix_cl<double> adjA = transpose(res.val()) * adjB;
46+
adjA.view(matrix_cl_view::Lower);
47+
adjoint_of(A_arena) -= adjA;
48+
}
49+
if (!is_constant<T2>::value) {
50+
adjoint_of(b_arena) += adjB;
51+
}
52+
});
53+
}
54+
55+
} // namespace math
56+
} // namespace stan
57+
58+
#endif
59+
#endif
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,37 @@
1+
#ifdef STAN_OPENCL
2+
#include <stan/math/opencl/rev.hpp>
3+
#include <test/unit/math/opencl/util.hpp>
4+
#include <test/unit/util.hpp>
5+
#include <gtest/gtest.h>
6+
#include <algorithm>
7+
8+
auto mdivide_right_tri_low_functor = [](const auto& a, const auto& b) {
9+
return stan::math::mdivide_right_tri_low(a, b);
10+
};
11+
12+
TEST(OpenCL_mdivide_right_tri_low, mdivide_right_tri_low_small) {
13+
Eigen::MatrixXd in1(3, 3);
14+
in1 << 0.5, 3.4, 5.2, 7.5, 1, 2, 3, 4, 5;
15+
Eigen::RowVectorXd in2(3);
16+
in2 << 3.3, 0.9, 6.7;
17+
stan::math::test::compare_cpu_opencl_prim_rev(mdivide_right_tri_low_functor,
18+
in2, in1);
19+
}
20+
21+
TEST(OpenCL_mdivide_right_tri_low, zero) {
22+
Eigen::MatrixXd in1;
23+
Eigen::RowVectorXd in2;
24+
stan::math::test::compare_cpu_opencl_prim_rev(mdivide_right_tri_low_functor,
25+
in2, in1);
26+
}
27+
28+
TEST(OpenCL_mdivide_right_tri_low, prim_rev_values_large) {
29+
int N = 71;
30+
31+
Eigen::MatrixXd a = Eigen::MatrixXd::Random(N, N);
32+
Eigen::RowVectorXd b = Eigen::RowVectorXd::Random(N);
33+
stan::math::test::compare_cpu_opencl_prim_rev(mdivide_right_tri_low_functor,
34+
b, a);
35+
}
36+
37+
#endif

test/unit/math/prim/fun/mdivide_right_tri_test.cpp

-38
Original file line numberDiff line numberDiff line change
@@ -2,11 +2,6 @@
22
#include <test/unit/util.hpp>
33
#include <gtest/gtest.h>
44

5-
#ifdef STAN_OPENCL
6-
#include <stan/math/opencl/prim.hpp>
7-
#include <boost/random/mersenne_twister.hpp>
8-
#endif
9-
105
TEST(MathMatrixPrim, mdivide_right_tri_val) {
116
using stan::math::mdivide_right_tri;
127
stan::math::matrix_d I = Eigen::MatrixXd::Identity(2, 2);
@@ -41,36 +36,3 @@ TEST(MathMatrixPrim, mdivide_right_tri_size_zero) {
4136
EXPECT_EQ(b0.rows(), I.rows());
4237
EXPECT_EQ(0, I.cols());
4338
}
44-
45-
#ifdef STAN_OPENCL
46-
47-
void mdivide_right_tri_cl_test(int size) {
48-
boost::random::mt19937 rng;
49-
stan::math::matrix_d m1(size, size);
50-
for (int i = 0; i < size; i++) {
51-
for (int j = 0; j < i; j++) {
52-
m1(i, j) = stan::math::uniform_rng(-5, 5, rng);
53-
}
54-
m1(i, i) = 20.0;
55-
for (int j = i + 1; j < size; j++) {
56-
m1(i, j) = 0.0;
57-
}
58-
}
59-
60-
stan::math::opencl_context.tuning_opts().tri_inverse_size_worth_transfer
61-
= size * 2;
62-
63-
stan::math::matrix_d m1_cpu
64-
= stan::math::mdivide_right_tri<Eigen::Lower>(m1, m1);
65-
66-
stan::math::opencl_context.tuning_opts().tri_inverse_size_worth_transfer = 0;
67-
68-
stan::math::matrix_d m1_cl
69-
= stan::math::mdivide_right_tri<Eigen::Lower>(m1, m1);
70-
71-
EXPECT_MATRIX_NEAR(m1_cpu, m1_cl, 1E-8);
72-
}
73-
TEST(MathMatrixCL, mdivide_right_tri_cl_small) { mdivide_right_tri_cl_test(3); }
74-
TEST(MathMatrixCL, mdivide_right_tri_cl_mid) { mdivide_right_tri_cl_test(100); }
75-
TEST(MathMatrixCL, mdivide_right_tri_cl_big) { mdivide_right_tri_cl_test(500); }
76-
#endif

0 commit comments

Comments
 (0)