-
-
Notifications
You must be signed in to change notification settings - Fork 190
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
Add OpenCL/prim add_diag, diag_matrix, subtract and minor OpenCL fixes #2250
Merged
Merged
Changes from 9 commits
Commits
Show all changes
20 commits
Select commit
Hold shift + click to select a range
303e974
fix inv_cloglog header guard
rok-cesnovar 5dabb84
Merge branch 'develop' into opencl_prim_misc
rok-cesnovar 33db9ce
fix add and subtract
rok-cesnovar f4cc731
add prim test util
rok-cesnovar 34f1150
merge prim and rev cholesky_decompose tests
rok-cesnovar 7d5ad8d
use util in subtract test
rok-cesnovar 44f06bf
add add_diag
rok-cesnovar ae1cda7
expand add_diag test
rok-cesnovar c4bdbf1
add diag_matrix and remove zeros()
rok-cesnovar eda2ca2
[Jenkins] auto-formatting by clang-format version 6.0.0-1ubuntu2~16.0…
stan-buildbot 0fed529
fix headers
rok-cesnovar a99db84
Apply suggestions from code review
rok-cesnovar cf3dd28
[Jenkins] auto-formatting by clang-format version 6.0.0-1ubuntu2~16.0…
stan-buildbot a2d4ec3
apply other review suggestions
rok-cesnovar 666da6f
copy in add_diag
rok-cesnovar fc18f42
Update stan/math/opencl/prim/add_diag.hpp
rok-cesnovar ed8c822
remove remaining zeros from tests
rok-cesnovar f61b823
[Jenkins] auto-formatting by clang-format version 6.0.0-1ubuntu2~16.0…
stan-buildbot 892b39d
cleanup zeros
rok-cesnovar e87f2b3
bugfix zeros return in multiply
rok-cesnovar File filter
Filter by extension
Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
There are no files selected for viewing
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
|
@@ -185,7 +185,10 @@ class binary_operation : public operation_cl<Derived, T_res, T_a, T_b> { | |
} | ||
|
||
ADD_BINARY_OPERATION(addition_, operator+, common_scalar_t<T_a COMMA T_b>, "+"); | ||
ADD_BINARY_OPERATION(subtraction_, operator-, common_scalar_t<T_a COMMA T_b>, | ||
ADD_BINARY_OPERATION(addition_operator_, add, common_scalar_t<T_a COMMA T_b>, "+"); | ||
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. You have the class names for addition aroud - the other one is operator. Also it might be better to add a wrapper function for add to delegate to operator instead of constructing another class |
||
ADD_BINARY_OPERATION(subtraction_operator_, operator-, common_scalar_t<T_a COMMA T_b>, | ||
"-"); | ||
ADD_BINARY_OPERATION(subtraction_, subtract, common_scalar_t<T_a COMMA T_b>, | ||
"-"); | ||
ADD_BINARY_OPERATION_WITH_CUSTOM_CODE( | ||
elt_multiply_, elt_multiply, common_scalar_t<T_a COMMA T_b>, "*", | ||
|
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file was deleted.
Oops, something went wrong.
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,40 @@ | ||
#ifndef STAN_MATH_OPENCL_PRIM_ADD_DIAG_HPP | ||
#define STAN_MATH_OPENCL_PRIM_ADD_DIAG_HPP | ||
#ifdef STAN_OPENCL | ||
#include <stan/math/opencl/prim/size.hpp> | ||
#include <stan/math/opencl/kernel_generator.hpp> | ||
#include <stan/math/prim/meta.hpp> | ||
#include <stan/math/prim/err.hpp> | ||
|
||
namespace stan { | ||
namespace math { | ||
|
||
/** | ||
* Returns a Matrix with values added along the main diagonal | ||
* | ||
* @tparam T_m type of input kernel generator expression for the input matrix | ||
* @tparam T_a type of input kernel generator expression to add along the diagonal | ||
* | ||
* @param mat input kernel generator expression | ||
* @param to_add scalar value or input kernel generator expression to add along the | ||
* diagonal | ||
* @return a kernel generator expressio with to_add added along main diagonal | ||
*/ | ||
template <typename T_m, typename T_a, | ||
require_all_kernel_expressions_and_none_scalar_t<T_m>* = nullptr, | ||
require_all_kernel_expressions_t<T_a>* = nullptr> | ||
inline auto add_diag(T_m&& mat, T_a&& to_add) { // NOLINT | ||
t4c1 marked this conversation as resolved.
Show resolved
Hide resolved
|
||
if (is_vector<T_a>::value) { | ||
const size_t length_diag = std::min(mat.rows(), mat.cols()); | ||
// int a = stan::math::size(mat); | ||
check_consistent_sizes("add_diag (OpenCL)", "number of elements of to_add", to_add, "diagonal", | ||
length_diag); | ||
} | ||
diagonal(mat) = diagonal(mat) + to_add; | ||
return mat; | ||
} | ||
} // namespace math | ||
} // namespace stan | ||
|
||
#endif | ||
#endif |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,31 @@ | ||
#ifndef STAN_MATH_OPENCL_PRIM_DIAG_MATRIX_HPP | ||
#define STAN_MATH_OPENCL_PRIM_DIAG_MATRIX_HPP | ||
#ifdef STAN_OPENCL | ||
#include <stan/math/opencl/kernel_generator.hpp> | ||
#include <stan/math/prim/meta.hpp> | ||
#include <stan/math/prim/err.hpp> | ||
|
||
namespace stan { | ||
namespace math { | ||
|
||
/** | ||
* Return a square diagonal matrix with the specified vector of | ||
* coefficients as the diagonal values. | ||
* | ||
* @tparam T_x type of input kernel generator expression for the | ||
* diagonal | ||
* | ||
* @param x input kernel generator expression for the diagonal | ||
* | ||
* @return a kernel generator expression | ||
*/ | ||
template <typename T_x, | ||
require_all_kernel_expressions_and_none_scalar_t<T_x>* = nullptr> | ||
inline auto diag_matrix(T_x&& x) { // NOLINT | ||
return add_diag(constant(0,x.size(), x.size()).eval(), x); | ||
} | ||
} // namespace math | ||
} // namespace stan | ||
|
||
#endif | ||
#endif |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,40 @@ | ||
#ifdef STAN_OPENCL | ||
#include <stan/math/prim.hpp> | ||
#include <test/unit/math/opencl/util.hpp> | ||
#include <test/unit/util.hpp> | ||
#include <gtest/gtest.h> | ||
#include <algorithm> | ||
|
||
auto add_diag_functor | ||
= [](const auto& a, const auto& b) { return stan::math::add_diag(a, b); }; | ||
|
||
TEST(OpenCLPrim, add_diag_small) { | ||
stan::math::matrix_d d1(3,3); | ||
d1 << 1, 2,3,4,5,6,7,8,9; | ||
stan::math::vector_d d2(3); | ||
d2 << -1,15,500; | ||
stan::math::test::compare_cpu_opencl_prim(add_diag_functor, d1, 2); | ||
stan::math::test::compare_cpu_opencl_prim(add_diag_functor, d1, d2); | ||
|
||
stan::math::matrix_d d3(3,5); | ||
d3 << 1, 2,3,4,5,6,7,8,9,10,11,12,13,14,15; | ||
stan::math::test::compare_cpu_opencl_prim(add_diag_functor, d3, 2); | ||
stan::math::test::compare_cpu_opencl_prim(add_diag_functor, d3, d2); | ||
} | ||
|
||
TEST(OpenCLPrim, add_diag_zero) { | ||
stan::math::matrix_d d1(0,0); | ||
stan::math::vector_d d2(0); | ||
stan::math::test::compare_cpu_opencl_prim(add_diag_functor, d1, d2); | ||
} | ||
|
||
TEST(OpenCLPrim, add_diag_exception) { | ||
stan::math::matrix_d d1(3,3); | ||
stan::math::matrix_d d2(2,3); | ||
stan::math::vector_d d3(5); | ||
stan::math::vector_d d4(3); | ||
EXPECT_THROW(stan::math::add_diag(stan::math::to_matrix_cl(d1), stan::math::to_matrix_cl(d3)), std::invalid_argument); | ||
EXPECT_THROW(stan::math::add_diag(stan::math::to_matrix_cl(d2), stan::math::to_matrix_cl(d4)), std::invalid_argument); | ||
} | ||
|
||
#endif |
Oops, something went wrong.
Add this suggestion to a batch that can be applied as a single commit.
This suggestion is invalid because no changes were made to the code.
Suggestions cannot be applied while the pull request is closed.
Suggestions cannot be applied while viewing a subset of changes.
Only one suggestion per line can be applied in a batch.
Add this suggestion to a batch that can be applied as a single commit.
Applying suggestions on deleted lines is not supported.
You must change the existing code in this line in order to create a valid suggestion.
Outdated suggestions cannot be applied.
This suggestion has been applied or marked resolved.
Suggestions cannot be applied from pending reviews.
Suggestions cannot be applied on multi-line comments.
Suggestions cannot be applied while the pull request is queued to merge.
Suggestion cannot be applied right now. Please check back later.
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
[optional] This looks good, but we could completely replace this function with kernel generator implementation.