Add kernel generator cast operation #2472
Merged
5 commits
aa77ead add cast operation and use it in distributions (t4c1)
a3cdcd8 [Jenkins] auto-formatting by clang-format version 6.0.0-1ubuntu2~16.0… (stan-buildbot)
1722825 fixed cpplint and test names (t4c1)
3962c69 Merge branch 'kg_cast' of https://github.com/bstatcomp/math into kg_cast (t4c1)
fcaff77 bugfix inv_gamma_lpdf (t4c1)
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,101 @@ | ||
#ifndef STAN_MATH_OPENCL_KERNEL_GENERATOR_CAST_HPP | ||
#define STAN_MATH_OPENCL_KERNEL_GENERATOR_CAST_HPP | ||
#ifdef STAN_OPENCL | ||
|
||
#include <stan/math/prim/meta.hpp> | ||
#include <stan/math/opencl/matrix_cl_view.hpp> | ||
#include <stan/math/opencl/kernel_generator/common_return_scalar.hpp> | ||
#include <stan/math/opencl/kernel_generator/type_str.hpp> | ||
#include <stan/math/opencl/kernel_generator/name_generator.hpp> | ||
#include <stan/math/opencl/kernel_generator/operation_cl.hpp> | ||
#include <stan/math/opencl/kernel_generator/as_operation_cl.hpp> | ||
#include <array> | ||
#include <string> | ||
#include <type_traits> | ||
#include <set> | ||
#include <utility> | ||
|
||
namespace stan { | ||
namespace math { | ||
|
||
/** \addtogroup opencl_kernel_generator | ||
* @{ | ||
*/ | ||
|
||
/** | ||
* Represents a typecast of a scalar in kernel generator expressions. | ||
* @tparam Scal type of the scalar of the result | ||
* @tparam T type of the argument | ||
*/ | ||
template <typename Scal, typename T> | ||
class cast_ : public operation_cl<cast_<Scal, T>, Scal, T> { | ||
public: | ||
using Scalar = Scal; | ||
using base = operation_cl<cast_<Scal, T>, Scalar, T>; | ||
using base::var_name_; | ||
|
||
/** | ||
* Constructor | ||
* @param arg argument expression | ||
*/ | ||
explicit cast_(T&& arg) : base(std::forward<T>(arg)) {} | ||
|
||
/** | ||
* Generates kernel code for this expression. | ||
* @param row_index_name row index variable name | ||
* @param col_index_name column index variable name | ||
* @param view_handled whether the caller already handled the matrix view | ||
* @param var_name_arg variable name of the nested expression | ||
* @return part of kernel with code for this expression | ||
*/ | ||
inline kernel_parts generate(const std::string& row_index_name, | ||
const std::string& col_index_name, | ||
const bool view_handled, | ||
const std::string& var_name_arg) const { | ||
kernel_parts res{}; | ||
|
||
res.body = type_str<Scalar>() + " " + var_name_ + " = (" | ||
+ type_str<Scalar>() + ")" + var_name_arg + ";\n"; | ||
return res; | ||
} | ||
|
||
inline auto deep_copy() const { | ||
auto&& arg_copy = this->template get_arg<0>().deep_copy(); | ||
return cast_<Scalar, std::remove_reference_t<decltype(arg_copy)>>{ | ||
std::move(arg_copy)}; | ||
} | ||
}; | ||
|
||
/** | ||
* Typecast a kernel generator expression scalar. | ||
* | ||
* @tparam Scalar scalar type to cast to | ||
* @tparam T type of argument | ||
* @param a input argument | ||
* @return Typecast of given expression | ||
*/ | ||
template <typename Scalar, typename T, | ||
require_all_kernel_expressions_and_none_scalar_t<T>* = nullptr> | ||
inline auto cast(T&& a) { | ||
auto&& a_operation = as_operation_cl(std::forward<T>(a)).deep_copy(); | ||
return cast_<Scalar, std::remove_reference_t<decltype(a_operation)>>( | ||
std::move(a_operation)); | ||
} | ||
|
||
/** | ||
* Typecast a scalar. | ||
* | ||
* @tparam Scalar scalar type to cast to | ||
* @tparam T type of argument | ||
* @param a input argument | ||
* @return the input scalar converted to Scalar | ||
*/ | ||
template <typename Scalar, typename T, require_stan_scalar_t<T>* = nullptr> | ||
inline Scalar cast(T a) { | ||
return a; | ||
} | ||
|
||
/** @}*/ | ||
} // namespace math | ||
} // namespace stan | ||
#endif | ||
#endif |
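For reference, the scalar overload of `cast` above simply returns its argument converted to the requested type. Below is a minimal standalone sketch of that behaviour. The name `cast_scalar_sketch` is hypothetical, and `std::is_arithmetic` is only an assumed stand-in for `require_stan_scalar_t`, which may accept more types than this.

```cpp
#include <cassert>
#include <type_traits>

// Hypothetical stand-in for the scalar overload of cast<Scalar>(a).
// std::is_arithmetic replaces require_stan_scalar_t here (an assumption).
template <typename Scalar, typename T,
          std::enable_if_t<std::is_arithmetic<T>::value>* = nullptr>
inline Scalar cast_scalar_sketch(T a) {
  // The PR relies on the implicit conversion in `return a;`;
  // the explicit static_cast only makes the conversion visible.
  return static_cast<Scalar>(a);
}

// Small usage example: double -> int truncates toward zero,
// and int 1 -> char keeps the value 1.
inline void cast_scalar_sketch_demo() {
  assert(cast_scalar_sketch<int>(2.7) == 2);
  assert(cast_scalar_sketch<char>(1) == 1);
}
```

The expression overload differs in that it builds a `cast_` node instead of converting a value immediately.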
So this line here generates `Scalar var_name = (Scalar)input_var_name;`? If so, then I don't think I'm understanding how this works for `char`?

Yes. I don't understand what the confusion with `char` is. It works exactly the same as for any other supported scalar type.
We do have reference kernels in tests, although I am not using them for operations that generate really trivial code, such as typecasting in this PR.
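To make the generated line concrete, here is a minimal standalone sketch of the string concatenation done in `cast_::generate()`. The names `type_str_sketch` and `cast_body_sketch` are hypothetical stand-ins, not part of the kernel generator, and only a few scalar types are covered.

```cpp
#include <cassert>
#include <string>

// Hypothetical stand-in for the kernel generator's type_str<T>();
// only the types needed for this sketch are specialized.
template <typename T>
std::string type_str_sketch();
template <>
inline std::string type_str_sketch<double>() { return "double"; }
template <>
inline std::string type_str_sketch<int>() { return "int"; }
template <>
inline std::string type_str_sketch<char>() { return "char"; }

// Mirrors the concatenation in cast_::generate(): declare a new variable
// of the target type and assign the C-style cast of the argument variable.
template <typename Scalar>
std::string cast_body_sketch(const std::string& var_name,
                             const std::string& var_name_arg) {
  assert(!var_name.empty() && !var_name_arg.empty());
  return type_str_sketch<Scalar>() + " " + var_name + " = ("
         + type_str_sketch<Scalar>() + ")" + var_name_arg + ";\n";
}
```

Under these assumptions, `cast_body_sketch<char>("var2", "var1")` produces the line `char var2 = (char)var1;`, which has the same shape for every scalar type.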

So I'm looking at
And here this is turned into
Apologies if this is silly, but I'm not following how that is used in an `if` etc. Like, is it that if `input_var_name` is 0 then the `char` is 0 and the `if` is false?

Yes, exactly. `input_var_name` is in this case the result of the comparison `n < 0`, which is either 0 or 1. 0 or 1 converted into `char` is still 0 or 1. That gets written into global memory and then copied to the host. Is there still anything you are not understanding?

Aight then I think that makes sense
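
The exchange above can be illustrated with a host-side sketch. This is plain C++ rather than the generated OpenCL kernel, and the function names are hypothetical. It assumes the kernel computes `n < 0` per element, casts the result to `char`, and the host later tests the copied flags in an `if`.

```cpp
#include <cassert>
#include <vector>

// Host-side stand-in for the per-element kernel work:
// compare, cast the 0/1 result to char, store it.
inline std::vector<char> negative_flags_sketch(const std::vector<int>& n) {
  std::vector<char> flags;
  flags.reserve(n.size());
  for (int value : n) {
    int comparison = value < 0;  // either 0 or 1
    assert(comparison == 0 || comparison == 1);
    flags.push_back(static_cast<char>(comparison));  // still 0 or 1
  }
  return flags;
}

// A char flag of 0 is false in a condition; a flag of 1 is true.
inline bool any_flag_set_sketch(const std::vector<char>& flags) {
  for (char flag : flags) {
    if (flag) {
      return true;
    }
  }
  return false;
}
```

With the input `{1, -2, 0}` the flags are 0, 1, 0, so `any_flag_set_sketch` returns true; with `{0, 3}` every flag is 0 and it returns false.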