add Digamma op #10066

Merged
merged 32 commits into master from digamma_op_dev on Apr 12, 2023

Commits (32)
cb8ffea
digamma op dev
youxiudeshouyeren Apr 1, 2023
fec26e3
unittest
youxiudeshouyeren Apr 1, 2023
44e1340
Merge branch 'master' into digamma_op_dev
youxiudeshouyeren Apr 1, 2023
8aac0c7
refine
youxiudeshouyeren Apr 2, 2023
b1fe15b
tensor.digamma api
youxiudeshouyeren Apr 2, 2023
0e92082
flow.digamma api
youxiudeshouyeren Apr 2, 2023
a08eada
fix test
youxiudeshouyeren Apr 2, 2023
b60259f
unittest
youxiudeshouyeren Apr 2, 2023
e7a3e19
Merge branch 'digamma_op_dev' of github.com:youxiudeshouyeren/oneflow…
youxiudeshouyeren Apr 2, 2023
2e0248e
fmt
youxiudeshouyeren Apr 2, 2023
98ebc8b
auto fmt
youxiudeshouyeren Apr 2, 2023
d8e5b0f
add api psi
youxiudeshouyeren Apr 2, 2023
98cf0a7
docstr
youxiudeshouyeren Apr 2, 2023
b0702e2
fmt
youxiudeshouyeren Apr 2, 2023
c1759ec
fix docstr
youxiudeshouyeren Apr 2, 2023
4b41d24
refine
youxiudeshouyeren Apr 3, 2023
17fe31a
fmt
youxiudeshouyeren Apr 3, 2023
8b80a76
fmt
youxiudeshouyeren Apr 3, 2023
a86d041
Merge branch 'master' into digamma_op_dev
youxiudeshouyeren Apr 3, 2023
fe44442
Merge branch 'master' into digamma_op_dev
youxiudeshouyeren Apr 6, 2023
e126953
add references
youxiudeshouyeren Apr 7, 2023
e1105d9
Merge branch 'digamma_op_dev' of github.com:youxiudeshouyeren/oneflow…
youxiudeshouyeren Apr 7, 2023
809892d
fmt
youxiudeshouyeren Apr 7, 2023
7130878
fix build
youxiudeshouyeren Apr 7, 2023
d1f63ea
fmt
youxiudeshouyeren Apr 7, 2023
6af454e
Merge branch 'master' into digamma_op_dev
youxiudeshouyeren Apr 7, 2023
75ff719
Merge branch 'youxiudeshouyeren-digamma_op_dev'
youxiudeshouyeren Apr 8, 2023
6c0055f
Merge branch 'master' into digamma_op_dev
youxiudeshouyeren Apr 8, 2023
190aabb
fix
youxiudeshouyeren Apr 8, 2023
68681e2
Merge branch 'master' into digamma_op_dev
youxiudeshouyeren Apr 11, 2023
8edc5e8
Merge branch 'master' into digamma_op_dev
mergify[bot] Apr 11, 2023
3e32f4c
Merge branch 'master' into digamma_op_dev
mergify[bot] Apr 11, 2023
1 change: 1 addition & 0 deletions docs/source/special.rst
@@ -8,6 +8,7 @@ The oneflow.special module, modeled after SciPy's special module.
:toctree: generated
:nosignatures:

digamma
erf
erfc
erfinv
1 change: 1 addition & 0 deletions docs/source/tensor.rst
@@ -225,6 +225,7 @@ Tensor class reference
Tensor.div_
Tensor.double
Tensor.dtype
Tensor.digamma
Tensor.element_size
Tensor.eq
Tensor.equal
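With these docs entries, digamma is exposed both as a oneflow function and as a Tensor method (the commits also add flow.digamma and a psi alias). A minimal usage sketch — the exact printed formatting is an assumption; the values are the mathematical expectations, psi(1) = -gamma ≈ -0.5772 and psi(0.5) = -gamma - 2 ln 2 ≈ -1.9635:

import oneflow as flow

x = flow.tensor([1.0, 0.5])
print(flow.digamma(x))  # expect roughly tensor([-0.5772, -1.9635])
print(x.digamma())      # same result via the Tensor method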
2 changes: 2 additions & 0 deletions oneflow/api/python/framework/tensor_functions.cpp
@@ -200,6 +200,7 @@ PyNumberMethods PyTensorObject_as_number = {
}

UNARY_METHOD(PyTensorObject_abs, functional::Abs);
UNARY_METHOD(PyTensorObject_digamma, functional::Digamma);
UNARY_METHOD(PyTensorObject_exp, functional::Exp);
UNARY_METHOD(PyTensorObject_exp2, functional::Exp2);
UNARY_METHOD(PyTensorObject_floor, functional::Floor);
@@ -1102,6 +1103,7 @@ PyMethodDef PyTensorObject_extra_methods[] = {

// macro UNARY_METHOD
{"abs", PyTensorObject_abs, METH_NOARGS, NULL},
{"digamma", PyTensorObject_digamma, METH_NOARGS, NULL},
{"exp", PyTensorObject_exp, METH_NOARGS, NULL},
{"exp2", PyTensorObject_exp2, METH_NOARGS, NULL},
{"floor", PyTensorObject_floor, METH_NOARGS, NULL},
6 changes: 6 additions & 0 deletions oneflow/core/common/math_util.h
@@ -21,6 +21,12 @@ limitations under the License.

namespace oneflow {

/*
* math constants
*/
template<typename T>
constexpr T pi = static_cast<T>(3.141592653589793238462643383279502);

int64_t Gcd(int64_t m, int64_t n);

int64_t Lcm(int64_t m, int64_t n);
19 changes: 10 additions & 9 deletions oneflow/core/ep/common/primitive/broadcast_elementwise_binary.h
@@ -140,15 +140,16 @@ inline bool IsDimsEquals(size_t num_src0_dims, const int64_t* src0_dims, size_t
OF_PP_MAKE_TUPLE_SEQ(BinaryOp::kAtanhBackwardWithDyX) \
OF_PP_MAKE_TUPLE_SEQ(BinaryOp::kCosBackwardWithDyX)

#define BINARY_MATH_BACKWARD_OP_SEQ_1 \
OF_PP_MAKE_TUPLE_SEQ(BinaryOp::kCoshBackwardWithDyX) \
OF_PP_MAKE_TUPLE_SEQ(BinaryOp::kErfBackwardWithDyX) \
OF_PP_MAKE_TUPLE_SEQ(BinaryOp::kErfcBackwardWithDyX) \
OF_PP_MAKE_TUPLE_SEQ(BinaryOp::kExpBackwardWithDyX) \
OF_PP_MAKE_TUPLE_SEQ(BinaryOp::kExp2BackwardWithDyX) \
OF_PP_MAKE_TUPLE_SEQ(BinaryOp::kExpm1BackwardWithDyX) \
OF_PP_MAKE_TUPLE_SEQ(BinaryOp::kLgammaBackwardWithDyX) \
OF_PP_MAKE_TUPLE_SEQ(BinaryOp::kLogBackwardWithDyX) \
#define BINARY_MATH_BACKWARD_OP_SEQ_1 \
OF_PP_MAKE_TUPLE_SEQ(BinaryOp::kCoshBackwardWithDyX) \
OF_PP_MAKE_TUPLE_SEQ(BinaryOp::kErfBackwardWithDyX) \
OF_PP_MAKE_TUPLE_SEQ(BinaryOp::kErfcBackwardWithDyX) \
OF_PP_MAKE_TUPLE_SEQ(BinaryOp::kExpBackwardWithDyX) \
OF_PP_MAKE_TUPLE_SEQ(BinaryOp::kExp2BackwardWithDyX) \
OF_PP_MAKE_TUPLE_SEQ(BinaryOp::kExpm1BackwardWithDyX) \
OF_PP_MAKE_TUPLE_SEQ(BinaryOp::kLgammaBackwardWithDyX) \
OF_PP_MAKE_TUPLE_SEQ(BinaryOp::kDigammaBackwardWithDyX) \
OF_PP_MAKE_TUPLE_SEQ(BinaryOp::kLogBackwardWithDyX) \
OF_PP_MAKE_TUPLE_SEQ(BinaryOp::kLog2BackwardWithDyX)

#define BINARY_MATH_BACKWARD_OP_SEQ_2 \
1 change: 1 addition & 0 deletions oneflow/core/ep/common/primitive/elementwise_unary.h
@@ -54,6 +54,7 @@ namespace primitive {
OF_PP_MAKE_TUPLE_SEQ(UnaryOp::kCeil) \
OF_PP_MAKE_TUPLE_SEQ(UnaryOp::kCos) \
OF_PP_MAKE_TUPLE_SEQ(UnaryOp::kCosh) \
OF_PP_MAKE_TUPLE_SEQ(UnaryOp::kDigamma) \
OF_PP_MAKE_TUPLE_SEQ(UnaryOp::kErf) \
OF_PP_MAKE_TUPLE_SEQ(UnaryOp::kErfc) \
OF_PP_MAKE_TUPLE_SEQ(UnaryOp::kExp) \
10 changes: 10 additions & 0 deletions oneflow/core/ep/cpu/primitive/binary_functor.h
@@ -353,6 +353,16 @@ struct BinaryFunctor<DeviceType::kCPU, BinaryOp::kErfcBackwardWithDyX, Src, Dst>
}
};

template<typename Src, typename Dst>
struct BinaryFunctor<DeviceType::kCPU, BinaryOp::kDigammaBackwardWithDyX, Src, Dst> {
OF_DEVICE_FUNC BinaryFunctor(Scalar attr0, Scalar attr1) {}
OF_DEVICE_FUNC Dst operator()(Src dy, Src x) const {
// TODO(shijiaxing): The backward of digamma is the trigamma function; it will be implemented soon.
UNIMPLEMENTED();
return 0;
}
};
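As the TODO notes, the digamma backward is the trigamma function psi_1(x) = d/dx psi(x), so DigammaGrad will ultimately compute dy * trigamma(x). A hedged Python sketch of trigamma in the same Cephes style as the rest of this PR (reflection below 0.5, a short recurrence, an asymptotic tail) — an illustration of the math, not the kernel this PR ships:

import math

def trigamma(x: float) -> float:
    # psi_1(x) = d/dx digamma(x); the digamma grad is dy * trigamma(x)
    sign, result = 1.0, 0.0
    if x < 0.5:
        # reflection: psi_1(x) + psi_1(1 - x) = pi^2 / sin^2(pi * x)
        sign = -1.0
        s = math.sin(math.pi * x)
        result -= (math.pi * math.pi) / (s * s)
        x = 1.0 - x
    for _ in range(6):  # recurrence: psi_1(x) = psi_1(x + 1) + 1 / x^2
        result += 1.0 / (x * x)
        x += 1.0
    ixx = 1.0 / (x * x)
    # asymptotic expansion in powers of 1/x^2
    result += (1.0 + 1.0 / (2.0 * x)
               + ixx * (1.0 / 6 - ixx * (1.0 / 30 - ixx * (1.0 / 42)))) / x
    return sign * result

assert abs(trigamma(1.0) - math.pi ** 2 / 6) < 1e-8  # psi_1(1) = pi^2 / 6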

#define SPECIALIZATION_CPU_BINARY_FUNCTOR(op, type) \
template<> \
struct BinaryFunctor<DeviceType::kCPU, op, type, type> { \
135 changes: 135 additions & 0 deletions oneflow/core/ep/cpu/primitive/unary_functor.h
@@ -15,6 +15,7 @@ limitations under the License.
*/
#include "oneflow/core/ep/common/primitive/unary_functor.h"
#include "oneflow/core/ep/cpu/primitive/type_seq.h"
#include "oneflow/core/common/math_util.h"

namespace oneflow {
namespace ep {
@@ -120,6 +121,139 @@ struct UnaryFunctor<DeviceType::kCPU, UnaryOp::kRsqrt, Dst, Src> {
}
};

template<>
struct UnaryFunctor<DeviceType::kCPU, UnaryOp::kDigamma, float, float> {
OF_DEVICE_FUNC UnaryFunctor(Scalar attr0, Scalar attr1) {}

OF_DEVICE_FUNC float operator()(float src) const {
Review comment (Contributor):
Where is this implementation referenced from? Could you add a link in a comment?

Reply (Contributor Author):
Added.

// references
// https://github.com/pytorch/pytorch/blob/release/1.13/aten/src/ATen/native/Math.h#L434-L487
const auto& calc_digamma = [](float x) {
std::function<float(float)> compute;
compute = [&](float x) {
static float PSI_10 = 2.25175258906672110764f;
if (x == 0) {
// As per C++ standard for gamma related functions and SciPy,
// If the argument is ±0, ±∞ is returned
return std::copysign(INFINITY, -x);
}

bool x_is_integer = x == truncf(x);
if (x < 0) {
if (x_is_integer) {
// As per C++ standard for gamma related functions and SciPy,
// If the argument is a negative integer, NaN is returned
return std::numeric_limits<float>::quiet_NaN();
}
// Extracts the fractional part of x as r, since tan(pi * r) is more numerically
// accurate than tan(pi * x). While these operations are mathematically equivalent
// since both x and r are in radians and tan() has a periodicity of pi, in practice
// the computation of pi * x is a source of error (when |x| > 1).
double q, r;
r = std::modf(x, &q);
float pi_over_tan_pi_x = (float)(pi<double> / tan(pi<double> * r));
return compute(1 - x) - pi_over_tan_pi_x;
}

// Push x to be >= 10
float result = 0;
while (x < 10) {
result -= 1 / x;
x += 1;
}
if (x == 10) { return result + PSI_10; }

// Compute asymptotic digamma
static const float A[] = {
8.33333333333333333333E-2f, -2.10927960927960927961E-2f, 7.57575757575757575758E-3f,
-4.16666666666666666667E-3f, 3.96825396825396825397E-3f, -8.33333333333333333333E-3f,
8.33333333333333333333E-2f,
};

float y = 0;
if (x < 1.0e17f) {
float z = 1 / (x * x);
float polevl_result = 0;
for (int i = 0; i <= 6; i++) { polevl_result = polevl_result * z + A[i]; }
y = z * polevl_result;
}
return result + logf(x) - (0.5f / x) - y;
};

return compute(x);
};

return calc_digamma(src);
}
};
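The special cases commented above are worth pinning down: an argument of ±0 maps to ∓∞ via copysign(INFINITY, -x), and the negative integers are poles that return NaN. A self-contained Python mirror of just those branches:

import math

def digamma_special_cases(x: float):
    if x == 0:
        return math.copysign(math.inf, -x)  # +0 -> -inf, -0 -> +inf
    if x < 0 and x == math.trunc(x):
        return math.nan  # poles at the negative integers
    return None  # otherwise fall through to the series computation

assert digamma_special_cases(0.0) == -math.inf
assert digamma_special_cases(-0.0) == math.inf
assert math.isnan(digamma_special_cases(-4.0))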

template<>
struct UnaryFunctor<DeviceType::kCPU, UnaryOp::kDigamma, double, double> {
OF_DEVICE_FUNC UnaryFunctor(Scalar attr0, Scalar attr1) {}

OF_DEVICE_FUNC double operator()(double src) const {
// references
// https://github.com/pytorch/pytorch/blob/release/1.13/aten/src/ATen/native/Math.h#L376-L428
const auto& calc_digamma = [](double x) {
std::function<double(double)> compute;
compute = [&](double x) {
static double PSI_10 = 2.25175258906672110764;
if (x == 0) {
// As per C++ standard for gamma related functions and SciPy,
// If the argument is ±0, ±∞ is returned
return std::copysign(INFINITY, -x);
}

bool x_is_integer = x == trunc(x);
if (x < 0) {
if (x_is_integer) {
// As per C++ standard for gamma related functions and SciPy,
// If the argument is a negative integer, NaN is returned
return std::numeric_limits<double>::quiet_NaN();
}
// Extracts the fractional part of x as r, since tan(pi * r) is more numerically
// accurate than tan(pi * x). While these operations are mathematically equivalent
// since both x and r are in radians and tan() has a periodicity of pi, in practice
// the computation of pi * x is a source of error (when |x| > 1).
double q, r;
r = std::modf(x, &q);
return compute(1 - x) - pi<double> / tan(pi<double> * r);
}

// Push x to be >= 10
double result = 0;
while (x < 10) {
result -= 1 / x;
x += 1;
}
if (x == 10) { return result + PSI_10; }

// Compute asymptotic digamma
static const double A[] = {
8.33333333333333333333E-2, -2.10927960927960927961E-2, 7.57575757575757575758E-3,
-4.16666666666666666667E-3, 3.96825396825396825397E-3, -8.33333333333333333333E-3,
8.33333333333333333333E-2,
};

double y = 0;
if (x < 1.0e17) {
double z = 1.0 / (x * x);
// y = z * polevl(z, A, 6);

double polevl_result = 0;
for (int i = 0; i <= 6; i++) { polevl_result = polevl_result * z + A[i]; }
y = z * polevl_result;
}
return result + log(x) - (0.5 / x) - y;
};

return compute(x);
};

return calc_digamma(src);
}
};
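For readers who want to sanity-check the algorithm outside the C++ build, here is a near line-for-line Python port of the double-precision path above, compared against SciPy (assuming SciPy is installed; the tolerance is a judgment call):

import math
from scipy import special

def calc_digamma(x: float) -> float:
    PSI_10 = 2.25175258906672110764
    if x == 0:
        return math.copysign(math.inf, -x)  # +-0 -> -+inf
    if x < 0:
        if x == math.trunc(x):
            return math.nan  # poles at the negative integers
        r = x - math.trunc(x)  # fractional part, as std::modf
        # reflection: psi(x) = psi(1 - x) - pi / tan(pi * r)
        return calc_digamma(1 - x) - math.pi / math.tan(math.pi * r)
    result = 0.0
    while x < 10:  # recurrence pushes x up to >= 10
        result -= 1 / x
        x += 1
    if x == 10:
        return result + PSI_10
    A = [8.33333333333333333333e-2, -2.10927960927960927961e-2,
         7.57575757575757575758e-3, -4.16666666666666666667e-3,
         3.96825396825396825397e-3, -8.33333333333333333333e-3,
         8.33333333333333333333e-2]
    y = 0.0
    if x < 1.0e17:
        z = 1 / (x * x)
        p = 0.0
        for a in A:  # Horner form of the commented-out polevl(z, A, 6)
            p = p * z + a
        y = z * p
    return result + math.log(x) - 0.5 / x - y

for v in (0.5, 1.0, 3.7, -2.3):
    assert math.isclose(calc_digamma(v), float(special.digamma(v)), rel_tol=1e-10)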

template<>
struct UnaryFunctor<DeviceType::kCPU, UnaryOp::kAbs, bfloat16, bfloat16> {
OF_DEVICE_FUNC UnaryFunctor(Scalar attr0, Scalar attr1) {}
@@ -187,6 +321,7 @@ SPECIALIZATION_CPU_BFLOAT16_UNARY_FUNCTOR(UnaryOp::kReciprocalNoNan);
SPECIALIZATION_CPU_BFLOAT16_UNARY_FUNCTOR(UnaryOp::kNotEqualZero);
SPECIALIZATION_CPU_BFLOAT16_UNARY_FUNCTOR(UnaryOp::kFastGelu);
SPECIALIZATION_CPU_BFLOAT16_UNARY_FUNCTOR(UnaryOp::kQuickGelu);
SPECIALIZATION_CPU_BFLOAT16_UNARY_FUNCTOR(UnaryOp::kDigamma);

template<>
struct UnaryFunctor<DeviceType::kCPU, UnaryOp::kIsInf, bool, bfloat16> {
10 changes: 10 additions & 0 deletions oneflow/core/ep/cuda/primitive/binary_functor.cuh
@@ -240,6 +240,16 @@ struct BinaryFunctor<DeviceType::kCUDA, BinaryOp::kIsClose, Src, Dst> {
float atol, rtol;
};

template<typename Src, typename Dst>
struct BinaryFunctor<DeviceType::kCUDA, BinaryOp::kDigammaBackwardWithDyX, Src, Dst> {
OF_DEVICE_FUNC BinaryFunctor(Scalar attr0, Scalar attr1) {}
OF_DEVICE_FUNC Dst operator()(Src dy, Src x) const {
// TODO(shijiaxing): The backward of digamma is the trigamma function; it will be implemented soon.
assert(false);
return static_cast<Dst>(0.0);
}
};

#define SPECIALIZATION_INTEGRAL_CLOSENESS_BINARY_FUNCTOR(op, type) \
template<typename Dst> \
struct BinaryFunctor<DeviceType::kCUDA, op, type, Dst> { \
62 changes: 62 additions & 0 deletions oneflow/core/ep/cuda/primitive/unary_functor.cuh
@@ -18,6 +18,7 @@ limitations under the License.
#include "oneflow/core/cuda/elementwise.cuh"
#include "oneflow/core/ep/cuda/cuda_stream.h"
#include <cuda.h>
#include "oneflow/core/common/math_util.h"

namespace oneflow {
namespace ep {
@@ -223,6 +224,65 @@ struct UnaryFunctor<DeviceType::kCUDA, UnaryOp::kTrunc, double, double> {
OF_DEVICE_FUNC double operator()(double src) const { return trunc(src); }
};

template<typename Dst, typename Src>
struct UnaryFunctor<DeviceType::kCUDA, UnaryOp::kDigamma, Dst, Src> {
OF_DEVICE_FUNC UnaryFunctor(Scalar attr0, Scalar attr1) {}

OF_DEVICE_FUNC Dst operator()(Src in) const {
// references
// https://github.com/pytorch/pytorch/blob/release/1.13/aten/src/ATen/native/cuda/Math.cuh#L3029-L3090
static const double PI_f64 = 3.14159265358979323846;
const Src PSI_10 = 2.25175258906672110764;
const Src A[] = {
8.33333333333333333333E-2, -2.10927960927960927961E-2, 7.57575757575757575758E-3,
-4.16666666666666666667E-3, 3.96825396825396825397E-3, -8.33333333333333333333E-3,
8.33333333333333333333E-2,
};

Src x = static_cast<Src>(in);
if (x == static_cast<Src>(0)) {
// As per C++ standard for gamma related functions and SciPy,
// If the argument is ±0, ±∞ is returned
return std::copysign(static_cast<Src>(INFINITY), -x);
}

bool x_is_integer = x == trunc(x);
Src result = static_cast<Src>(0);
if (x < 0) {
if (x_is_integer) {
// As per C++ standard for gamma related functions and SciPy,
// If the argument is a negative integer, NaN is returned
return static_cast<Src>(NAN);
}
// Extracts the fractional part of x as r, since tan(pi * r) is more numerically
// accurate than tan(pi * x). While these operations are mathematically equivalent
// since both x and r are in radians and tan() has a periodicity of pi, in practice
// the computation of pi * x is a source of error (when |x| > 1).
double q, r;
r = modf(static_cast<double>(x), &q);
result = static_cast<Src>(-PI_f64 / tan(PI_f64 * r));
x = static_cast<Src>(1) - x;
}

while (x < 10) {
result -= static_cast<Src>(1) / x;
x += 1;
}
if (x == static_cast<Src>(10)) { return static_cast<Src>(result + PSI_10); }

Src y = 0;
if (x < 1.0e17) {
Src z = static_cast<Src>(1) / (x * x);

Src polevl_result = 0;
for (int i = 0; i <= 6; i++) { polevl_result = polevl_result * z + A[i]; }
y = z * polevl_result;
}

return static_cast<Src>(log(x) - (static_cast<Src>(0.5) / x) - y + result);
}
};

template<>
struct UnaryFunctor<DeviceType::kCUDA, UnaryOp::kAbs, half, half> {
OF_DEVICE_FUNC UnaryFunctor(Scalar attr0, Scalar attr1) {}
@@ -351,6 +411,7 @@ SPECIALIZATION_PSEUDO_HALF_UNARY_FUNCTOR(UnaryOp::kAtanh);
SPECIALIZATION_PSEUDO_HALF_UNARY_FUNCTOR(UnaryOp::kCeil);
SPECIALIZATION_PSEUDO_HALF_UNARY_FUNCTOR(UnaryOp::kCos);
SPECIALIZATION_PSEUDO_HALF_UNARY_FUNCTOR(UnaryOp::kCosh);
SPECIALIZATION_PSEUDO_HALF_UNARY_FUNCTOR(UnaryOp::kDigamma);
SPECIALIZATION_PSEUDO_HALF_UNARY_FUNCTOR(UnaryOp::kErf);
SPECIALIZATION_PSEUDO_HALF_UNARY_FUNCTOR(UnaryOp::kErfc);
SPECIALIZATION_PSEUDO_HALF_UNARY_FUNCTOR(UnaryOp::kExp);
@@ -443,6 +504,7 @@ SPECIALIZATION_PSEUDO_BFLOAT16_UNARY_FUNCTOR(UnaryOp::kNotEqualZero);
SPECIALIZATION_PSEUDO_BFLOAT16_UNARY_FUNCTOR(UnaryOp::kNanAssign);
SPECIALIZATION_PSEUDO_BFLOAT16_UNARY_FUNCTOR(UnaryOp::kFastGelu);
SPECIALIZATION_PSEUDO_BFLOAT16_UNARY_FUNCTOR(UnaryOp::kQuickGelu);
SPECIALIZATION_PSEUDO_BFLOAT16_UNARY_FUNCTOR(UnaryOp::kDigamma);
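The "pseudo" half/bfloat16 specializations suggest the compute-in-float pattern: the low-precision input is upcast, the float functor runs, and the result is cast back. A small NumPy sketch of that pattern, using SciPy's digamma as a stand-in for the float kernel (the upcast behavior here is an assumption drawn from the macro names, not shown in this diff):

import numpy as np
from scipy import special

def digamma_pseudo_half(x: np.ndarray) -> np.ndarray:
    # upcast -> compute in float32 -> round back to float16
    return special.digamma(x.astype(np.float32)).astype(np.float16)

print(digamma_pseudo_half(np.array([1.0, 0.5], dtype=np.float16)))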

template<>
struct UnaryFunctor<DeviceType::kCUDA, UnaryOp::kIsInf, bool, nv_bfloat16> {
1 change: 1 addition & 0 deletions oneflow/core/ep/include/primitive/binary_op.h
@@ -92,6 +92,7 @@ enum class BinaryOp {
kExp2BackwardWithDyX,
kExpm1BackwardWithDyX,
kLgammaBackwardWithDyX,
kDigammaBackwardWithDyX,
kLogBackwardWithDyX,
kLog2BackwardWithDyX,
kLog10BackwardWithDyX,
1 change: 1 addition & 0 deletions oneflow/core/ep/include/primitive/unary_op.h
@@ -54,6 +54,7 @@ enum class UnaryOp {
kCeil,
kCos,
kCosh,
kDigamma,
kErf,
kErfc,
kExp,
8 changes: 8 additions & 0 deletions oneflow/core/functional/functional_api.yaml
@@ -3338,3 +3338,11 @@
- name: "frac_"
signature: "Tensor (Tensor x) => FracInplace"
bind_python: True

- name: "digamma"
signature: "Tensor (Tensor x) => Digamma"
bind_python: True

- name: "digamma_grad"
signature: "Tensor (Tensor x, Tensor dy) => DigammaGrad"
bind_python: False
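Note that digamma_grad is not bound to Python (bind_python: False); it only backs the autograd graph, and per the TODOs above its trigamma kernel is still unimplemented in this PR. Once it lands, the gradient should satisfy d/dx psi(x) = psi_1(x); a finite-difference sanity check against SciPy, as a sketch of how one might verify:

import math
from scipy import special

def fd_digamma_grad(x: float, h: float = 1e-6) -> float:
    # central difference of digamma should approximate trigamma
    return float(special.digamma(x + h) - special.digamma(x - h)) / (2 * h)

for v in (0.7, 1.0, 5.5):
    assert math.isclose(fd_digamma_grad(v), float(special.polygamma(1, v)), rel_tol=1e-4)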