Skip to content

Commit 02dc560

Browse files
Merge pull request #2554 from nickdidio/complex-numbers
Complex numbers get_real(), get_imag(), to_complex
2 parents 62d9a8d + 899bc77 commit 02dc560

12 files changed

+125
-21
lines changed

stan/math/prim/fun.hpp

+3
Original file line numberDiff line numberDiff line change
@@ -111,7 +111,9 @@
111111
#include <stan/math/prim/fun/get.hpp>
112112
#include <stan/math/prim/fun/get_base1.hpp>
113113
#include <stan/math/prim/fun/get_base1_lhs.hpp>
114+
#include <stan/math/prim/fun/get_imag.hpp>
114115
#include <stan/math/prim/fun/get_lp.hpp>
116+
#include <stan/math/prim/fun/get_real.hpp>
115117
#include <stan/math/prim/fun/gp_dot_prod_cov.hpp>
116118
#include <stan/math/prim/fun/gp_exponential_cov.hpp>
117119
#include <stan/math/prim/fun/gp_matern32_cov.hpp>
@@ -325,6 +327,7 @@
325327
#include <stan/math/prim/fun/tgamma.hpp>
326328
#include <stan/math/prim/fun/to_array_1d.hpp>
327329
#include <stan/math/prim/fun/to_array_2d.hpp>
330+
#include <stan/math/prim/fun/to_complex.hpp>
328331
#include <stan/math/prim/fun/to_matrix.hpp>
329332
#include <stan/math/prim/fun/to_ref.hpp>
330333
#include <stan/math/prim/fun/to_row_vector.hpp>

stan/math/prim/fun/abs.hpp

+3-1
Original file line numberDiff line numberDiff line change
@@ -1,11 +1,13 @@
11
#ifndef STAN_MATH_PRIM_FUN_ABS_HPP
22
#define STAN_MATH_PRIM_FUN_ABS_HPP
33

4+
#include <stan/math/prim/core.hpp>
45
#include <stan/math/prim/meta.hpp>
56
#include <stan/math/prim/fun/hypot.hpp>
67
#include <stan/math/prim/functor/apply_scalar_unary.hpp>
78
#include <stan/math/prim/functor/apply_vector_unary.hpp>
89
#include <cmath>
10+
#include <complex>
911

1012
namespace stan {
1113
namespace math {
@@ -62,7 +64,7 @@ template <typename Container,
6264
require_container_st<std::is_arithmetic, Container>* = nullptr>
6365
inline auto abs(const Container& x) {
6466
return apply_vector_unary<Container>::apply(
65-
x, [](const auto& v) { return v.array().abs(); });
67+
x, [&](const auto& v) { return v.array().abs(); });
6668
}
6769

6870
namespace internal {

stan/math/prim/fun/cos.hpp

+2-12
Original file line numberDiff line numberDiff line change
@@ -13,16 +13,6 @@
1313
namespace stan {
1414
namespace math {
1515

16-
/**
17-
* Arithmetic version of `cos()`
18-
* @tparam T An `Arithmetic` type.
19-
* @param x Arithmetic scalar.
20-
*/
21-
template <typename T, require_arithmetic_t<T>* = nullptr>
22-
inline auto cos(T x) {
23-
return std::cos(x);
24-
}
25-
2616
/**
2717
* Structure to wrap `cos()` so it can be vectorized.
2818
*
@@ -33,6 +23,7 @@ inline auto cos(T x) {
3323
struct cos_fun {
3424
template <typename T>
3525
static inline T fun(const T& x) {
26+
using std::cos;
3627
return cos(x);
3728
}
3829
};
@@ -47,7 +38,6 @@ struct cos_fun {
4738
*/
4839
template <typename Container,
4940
require_not_container_st<std::is_arithmetic, Container>* = nullptr,
50-
require_not_stan_scalar_t<Container>* = nullptr,
5141
require_not_var_matrix_t<Container>* = nullptr,
5242
require_all_not_nonscalar_prim_or_rev_kernel_expression_t<
5343
Container>* = nullptr>
@@ -67,7 +57,7 @@ template <typename Container,
6757
require_container_st<std::is_arithmetic, Container>* = nullptr>
6858
inline auto cos(const Container& x) {
6959
return apply_vector_unary<Container>::apply(
70-
x, [](const auto& v) { return v.array().cos(); });
60+
x, [&](const auto& v) { return v.array().cos(); });
7161
}
7262

7363
namespace internal {

stan/math/prim/fun/get_imag.hpp

+25
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,25 @@
1+
#ifndef STAN_MATH_PRIM_FUN_GET_IMAG_HPP
2+
#define STAN_MATH_PRIM_FUN_GET_IMAG_HPP
3+
4+
#include <stan/math/prim/meta.hpp>
5+
#include <complex>
6+
7+
namespace stan {
8+
namespace math {
9+
10+
/**
11+
* Return the imaginary component of the complex argument.
12+
*
13+
* @tparam T value type of complex argument
14+
* @param[in] z complex value whose imaginary component is extracted
15+
* @return imaginary component of argument
16+
*/
17+
template <typename T>
18+
T get_imag(const std::complex<T>& z) {
19+
return z.imag();
20+
}
21+
22+
} // namespace math
23+
} // namespace stan
24+
25+
#endif

stan/math/prim/fun/get_real.hpp

+25
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,25 @@
1+
#ifndef STAN_MATH_PRIM_FUN_GET_REAL_HPP
2+
#define STAN_MATH_PRIM_FUN_GET_REAL_HPP
3+
4+
#include <stan/math/prim/meta.hpp>
5+
#include <complex>
6+
7+
namespace stan {
8+
namespace math {
9+
10+
/**
11+
* Return the real component of the complex argument.
12+
*
13+
* @tparam T value type of complex argument
14+
* @param[in] z complex value whose real component is extracted
15+
* @return real component of argument
16+
*/
17+
template <typename T>
18+
T get_real(const std::complex<T>& z) {
19+
return z.real();
20+
}
21+
22+
} // namespace math
23+
} // namespace stan
24+
25+
#endif

stan/math/prim/fun/imag.hpp

+4-4
Original file line numberDiff line numberDiff line change
@@ -8,11 +8,11 @@ namespace stan {
88
namespace math {
99

1010
/**
11-
* Return the imaginary part of the complex argument.
11+
* Return the imaginary component of the complex argument.
1212
*
13-
* @tparam T value type of argument
14-
* @param[in] z argument
15-
* @return imaginary part of argument
13+
* @tparam T value type of complex argument
14+
* @param[in] z complex value whose imaginary component is extracted
15+
* @return imaginary component of argument
1616
*/
1717
template <typename T, require_autodiff_t<T>>
1818
T imag(const std::complex<T>& z) {

stan/math/prim/fun/real.hpp

+4-4
Original file line numberDiff line numberDiff line change
@@ -8,11 +8,11 @@ namespace stan {
88
namespace math {
99

1010
/**
11-
* Return the real part of the complex argument.
11+
* Return the real component of the complex argument.
1212
*
13-
* @tparam T value type of argument
14-
* @param[in] z argument
15-
* @return real part of argument
13+
* @tparam T value type of complex argument
14+
* @param[in] z complex value whose real component is extracted
15+
* @return real component of argument
1616
*/
1717
template <typename T, require_autodiff_t<T>>
1818
T real(const std::complex<T>& z) {

stan/math/prim/fun/to_complex.hpp

+30
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,30 @@
1+
#ifndef STAN_MATH_PRIM_FUN_TO_COMPLEX_HPP
2+
#define STAN_MATH_PRIM_FUN_TO_COMPLEX_HPP
3+
4+
#include <stan/math/prim/meta.hpp>
5+
#include <stan/math/prim/meta/return_type.hpp>
6+
#include <complex>
7+
8+
namespace stan {
9+
namespace math {
10+
11+
/**
12+
* Return a complex value from a real component and an imaginary component.
13+
* Default values for both components is 0.
14+
*
15+
* @tparam T type of real component
16+
* @tparam S type of imaginary component
17+
* @param[in] re real component (default = 0)
18+
* @param[in] im imaginary component (default = 0)
19+
* @return complex value with specified real and imaginary components
20+
*/
21+
template <typename T = double, typename S = double>
22+
inline std::complex<stan::real_return_t<T, S>> to_complex(const T& re = 0,
23+
const S& im = 0) {
24+
return std::complex<stan::real_return_t<T, S>>(re, im);
25+
}
26+
27+
} // namespace math
28+
} // namespace stan
29+
30+
#endif
+7
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,7 @@
1+
#include <test/unit/math/test_ad.hpp>
2+
#include <complex>
3+
4+
TEST(mathMixMatFun, get_imag) {
5+
auto f = [](const auto& z) { return stan::math::get_imag(z); };
6+
stan::test::expect_complex_common(f);
7+
}
+7
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,7 @@
1+
#include <test/unit/math/test_ad.hpp>
2+
#include <complex>
3+
4+
TEST(mathMixMatFun, get_real) {
5+
auto f = [](const auto& z) { return stan::math::get_real(z); };
6+
stan::test::expect_complex_common(f);
7+
}
+14
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,14 @@
1+
#include <test/unit/math/test_ad.hpp>
2+
#include <complex>
3+
4+
TEST(mathMixMatFun, to_complex) {
5+
auto f0 = []() { return stan::math::to_complex(); };
6+
auto f1 = [](const auto& x) { return stan::math::to_complex(x); };
7+
auto f2 = [](const auto& x, const auto& y) {
8+
return stan::math::to_complex(x, y);
9+
};
10+
11+
EXPECT_EQ(std::complex<double>(), stan::math::to_complex());
12+
stan::test::expect_common_unary(f1);
13+
stan::test::expect_common_binary(f2);
14+
}

test/unit/math/test_ad_test.cpp

+1
Original file line numberDiff line numberDiff line change
@@ -233,6 +233,7 @@ int baz_fvar = 0;
233233
int baz_complex = 0;
234234
int baz_complex_var = 0;
235235
int baz_complex_fvar = 0;
236+
236237
double baz(int x) {
237238
++baz_int;
238239
return x / 2.0;

0 commit comments

Comments
 (0)