diff --git a/paddle/phi/kernels/cpu/elementwise_grad_kernel.cc b/paddle/phi/kernels/cpu/elementwise_grad_kernel.cc
index 25725596375533..bb3babadf0560f 100644
--- a/paddle/phi/kernels/cpu/elementwise_grad_kernel.cc
+++ b/paddle/phi/kernels/cpu/elementwise_grad_kernel.cc
@@ -49,6 +49,29 @@ void MinimumGradKernel(const Context& dev_ctx,
       dev_ctx, x, y, dout, dout, axis, dx, dy, MinGradDx<T>(), MinGradDy<T>());
 }
 
+template <typename T, typename Context>
+void RemainderGradKernel(const Context& dev_ctx,
+                         const DenseTensor& x,
+                         const DenseTensor& y,
+                         const DenseTensor& dout,
+                         DenseTensor* dx,
+                         DenseTensor* dy) {
+  funcs::ElementwiseGradPreProcess(dout, dx);
+  int axis = -1;
+  phi::funcs::
+      ElemwiseGradCompute<Context, T, RemainderGradDx<T>, RemainderGradDy<T>>(
+          dev_ctx,
+          x,
+          y,
+          dout,
+          dout,
+          axis,
+          dx,
+          dy,
+          RemainderGradDx<T>(),
+          RemainderGradDy<T>());
+}
+
 template <typename T, typename Context>
 void CopySignGradKernel(const Context& dev_ctx,
                         const DenseTensor& x,
@@ -111,6 +134,16 @@ PD_REGISTER_KERNEL(minimum_grad,
                    int64_t,
                    phi::dtype::bfloat16) {}
 
+PD_REGISTER_KERNEL(remainder_grad,
+                   CPU,
+                   ALL_LAYOUT,
+                   phi::RemainderGradKernel,
+                   float,
+                   double,
+                   int,
+                   int64_t,
+                   phi::dtype::bfloat16) {}
+
 PD_REGISTER_KERNEL(heaviside_grad,
                    CPU,
                    ALL_LAYOUT,
diff --git a/paddle/phi/kernels/elementwise_grad_kernel.h b/paddle/phi/kernels/elementwise_grad_kernel.h
index 0ca934cc8f35b2..74a669387c41b1 100644
--- a/paddle/phi/kernels/elementwise_grad_kernel.h
+++ b/paddle/phi/kernels/elementwise_grad_kernel.h
@@ -51,6 +51,14 @@ void MinimumGradKernel(const Context& dev_ctx,
                        DenseTensor* dx,
                        DenseTensor* dy);
 
+template <typename T, typename Context>
+void RemainderGradKernel(const Context& dev_ctx,
+                         const DenseTensor& x,
+                         const DenseTensor& y,
+                         const DenseTensor& dout,
+                         DenseTensor* dx,
+                         DenseTensor* dy);
+
 template <typename T, typename Context>
 void HeavisideGradKernel(const Context& dev_ctx,
                          const DenseTensor& x,
diff --git a/paddle/phi/kernels/funcs/elementwise_functor.h b/paddle/phi/kernels/funcs/elementwise_functor.h
index 9756dcecf857ec..5d7a6b627cb026 100644
--- a/paddle/phi/kernels/funcs/elementwise_functor.h
+++ b/paddle/phi/kernels/funcs/elementwise_functor.h
@@ -591,6 +591,96 @@ struct RemainderFunctor {
   }
 };
 
+// RemainderGradXFunctor
+template <typename T>
+struct RemainderGradXFunctor {
+  inline HOSTDEVICE T operator()(const T x, const T y, const T dout) const {
+    // dx = dout
+    return dout;
+  }
+};
+
+// RemainderGradYFunctor
+template <typename T, typename Enable = void>
+struct RemainderGradYFunctor {
+  inline HOSTDEVICE T operator()(const T x, const T y, const T dout) const {
+    // dy = -dout * (floor_div(x, y))
+    return -dout * static_cast<T>((std::floor(x / y)));
+  }
+};
+template <typename T>
+struct RemainderGradYFunctor<
+    T,
+    typename std::enable_if<!std::is_integral<T>::value>::type> {
+  inline HOSTDEVICE T operator()(const T x, const T y, const T dout) const {
+    using MPType = typename phi::dtype::MPTypeTrait<T>::Type;
+    // dy = -dout * (floor_div(x, y))
+    auto x_ = static_cast<MPType>(x);
+    auto y_ = static_cast<MPType>(y);
+    return static_cast<T>(-static_cast<MPType>(dout) *
+                          (std::floor((x_ / y_))));
+  }
+};
+template <typename T>
+struct RemainderGradYFunctor<
+    T,
+    typename std::enable_if<std::is_integral<T>::value>::type> {
+  inline HOSTDEVICE T operator()(const T x, const T y, const T dout) const {
+    // dy = -dout * (floor_div(x, y))
+    return -dout * (x / y);
+  }
+};
+
+// RemainderGradXYFunctor
+template <typename InT, typename OutT, typename Enable = void>
+struct RemainderGradXYFunctor {
+  inline HOSTDEVICE phi::Array<OutT, 2> operator()(const InT x,
+                                                   const InT y,
+                                                   const InT dout) {
+    phi::Array<OutT, 2> outs;
+    // dx = dout
+    outs[0] = static_cast<OutT>(dout);
+    // dy = -dout * (floor_div(x, y))
+    outs[1] = static_cast<OutT>(-dout * static_cast<InT>(std::floor(x / y)));
+    return outs;
+  }
+};
+template <typename InT, typename OutT>
+struct RemainderGradXYFunctor<
+    InT,
+    OutT,
+    typename std::enable_if<!std::is_integral<InT>::value>::type> {
+  inline HOSTDEVICE phi::Array<OutT, 2> operator()(const InT x,
+                                                   const InT y,
+                                                   const InT dout) {
+    phi::Array<OutT, 2> outs;
+    // dx = dout
+    outs[0] = static_cast<OutT>(dout);
+    // dy = -dout * floor(x / y)
+    using MPType = typename phi::dtype::MPTypeTrait<InT>::Type;
+    auto x_ = static_cast<MPType>(x);
+    auto y_ = static_cast<MPType>(y);
+    outs[1] =
+        static_cast<OutT>(static_cast<MPType>(-dout) * std::floor(x_ / y_));
+    return outs;
+  }
+};
+template <typename InT, typename OutT>
+struct RemainderGradXYFunctor<
+    InT,
+    OutT,
+    typename std::enable_if<std::is_integral<InT>::value>::type> {
+  inline HOSTDEVICE phi::Array<OutT, 2> operator()(const InT x,
+                                                   const InT y,
+                                                   const InT dout) {
+    phi::Array<OutT, 2> outs;
+    // dx = dout
+    outs[0] = static_cast<OutT>(dout);
+    // dy = -dout * (x / y)
+    outs[1] = static_cast<OutT>(-dout * (x / y));
+    return outs;
+  }
+};
+
 template <typename T>
 struct InverseRemainderFunctor {
   inline HOSTDEVICE T operator()(const T a, const T b) const {
diff --git a/paddle/phi/kernels/gpu/elementwise_grad_kernel.cu b/paddle/phi/kernels/gpu/elementwise_grad_kernel.cu
index ddb3d3233a0298..f89d0ff01fdd5e 100644
--- a/paddle/phi/kernels/gpu/elementwise_grad_kernel.cu
+++ b/paddle/phi/kernels/gpu/elementwise_grad_kernel.cu
@@ -210,6 +210,36 @@ void MinimumGradKernel(const Context& dev_ctx,
   }
 }
 
+template <typename T, typename Context>
+void RemainderGradKernel(const Context& dev_ctx,
+                         const DenseTensor& x,
+                         const DenseTensor& y,
+                         const DenseTensor& dout,
+                         DenseTensor* dx,
+                         DenseTensor* dy) {
+  const auto place = dev_ctx.GetPlace();
+  int axis = -1;
+  if (dx != nullptr && dy != nullptr) {
+    std::vector<const DenseTensor*> ins = {&x, &y, &dout};
+    GetGradXAndYOut<T>(dev_ctx,
+                       place,
+                       axis,
+                       ins,
+                       dout,
+                       dx,
+                       dy,
+                       funcs::RemainderGradXYFunctor<T, T>());
+  } else if (dx != nullptr && dy == nullptr) {
+    std::vector<const DenseTensor*> ins = {&x, &y, &dout};
+    GetGradXOrYOut<T>(
+        dev_ctx, place, axis, ins, dout, dx, funcs::RemainderGradXFunctor<T>());
+  } else if (dy != nullptr && dx == nullptr) {
+    std::vector<const DenseTensor*> ins = {&x, &y, &dout};
+    GetGradXOrYOut<T>(
+        dev_ctx, place, axis, ins, dout, dy, funcs::RemainderGradYFunctor<T>());
+  }
+}
+
 template <typename T, typename Context>
 void CopySignGradKernel(const Context& dev_ctx,
                         const DenseTensor& x,
@@ -295,6 +325,17 @@ PD_REGISTER_KERNEL(minimum_grad,
                    phi::dtype::float16,
                    phi::dtype::bfloat16) {}
 
+PD_REGISTER_KERNEL(remainder_grad,
+                   GPU,
+                   ALL_LAYOUT,
+                   phi::RemainderGradKernel,
+                   float,
+                   double,
+                   int,
+                   int64_t,
+                   phi::dtype::float16,
+                   phi::dtype::bfloat16) {}
+
 PD_REGISTER_KERNEL(heaviside_grad,
                    GPU,
                    ALL_LAYOUT,
diff --git a/paddle/phi/kernels/impl/elementwise_grad_kernel_impl.h b/paddle/phi/kernels/impl/elementwise_grad_kernel_impl.h
index 16b927e83aabef..604cbbfec763f0 100644
--- a/paddle/phi/kernels/impl/elementwise_grad_kernel_impl.h
+++ b/paddle/phi/kernels/impl/elementwise_grad_kernel_impl.h
@@ -1400,6 +1400,48 @@ void ElementwisePowGradKernel(const Context& dev_ctx,
       dev_ctx, x, y, dout, dout, axis, dx, dy, PowGradDX<T>(), PowGradDY<T>());
 }
 
+/*
+******************************
+    Remainder Grad
+******************************
+*/
+// RemainderGradDx
+template <typename T>
+struct RemainderGradDx {
+  HOSTDEVICE T operator()(T x, T y, T out UNUSED, T dout) const {
+    // dx = dout
+    return dout;
+  }
+};
+
+// RemainderGradDy
+template <typename T, typename Enable = void>
+struct RemainderGradDy {
+  HOSTDEVICE T operator()(T x, T y, T out UNUSED, T dout) const {
+    return -dout * (std::floor(static_cast<double>(x / y)));
+  }
+};
+template <typename T>
+struct RemainderGradDy<
+    T,
+    typename std::enable_if<!std::is_integral<T>::value>::type> {
+  HOSTDEVICE T operator()(T x, T y, T out UNUSED, T dout) const {
+    using MPType = typename phi::dtype::MPTypeTrait<T>::Type;
+    auto x_ = static_cast<MPType>(x);
+    auto y_ = static_cast<MPType>(y);
+    return static_cast<T>(-static_cast<MPType>(dout) *
+                          (std::floor((x_ / y_))));
+  }
+};
+template <typename T>
+struct RemainderGradDy<
+    T,
+    typename std::enable_if<std::is_integral<T>::value>::type> {
+  HOSTDEVICE T operator()(T x, T y, T out UNUSED, T dout) const {
+    // dy = -dout * (x / y)
+    return -dout * static_cast<T>(std::floor(static_cast<double>(x) /
+                                              static_cast<double>(y)));
+  }
+};
 /*
 ******************************
     Copysign Grad
diff --git a/paddle/phi/ops/yaml/inconsistent/dygraph_backward.yaml b/paddle/phi/ops/yaml/inconsistent/dygraph_backward.yaml
index 134fa8299d0813..7f8df1a2584a24 100755
--- a/paddle/phi/ops/yaml/inconsistent/dygraph_backward.yaml
+++ b/paddle/phi/ops/yaml/inconsistent/dygraph_backward.yaml
@@ -266,6 +266,16 @@
     func : multiply_triple_grad
   optional : fwd_grad_grad_x, fwd_grad_grad_y, grad_x_grad, grad_y_grad, grad_grad_out_grad
 
+- backward_op : remainder_grad
+  forward : remainder (Tensor x, Tensor y) -> Tensor(out)
+  args : (Tensor x, Tensor y, Tensor out_grad)
+  output : Tensor(x_grad), Tensor(y_grad)
+  infer_meta :
+    func : GeneralBinaryGradInferMeta
+    param : [x, y]
+  kernel :
+    func : remainder_grad
+
 - backward_op : set_value_grad
   forward : set_value (Tensor x, IntArray starts, IntArray ends, IntArray steps, int64_t[] axes, int64_t[] decrease_axes, int64_t[] none_axes, int64_t[] shape, Scalar[] values) -> Tensor(out)
   args : (Tensor out_grad, IntArray starts, IntArray ends, IntArray steps, int64_t[] axes, int64_t[] decrease_axes, int64_t[] none_axes)
diff --git a/paddle/phi/ops/yaml/inconsistent/dygraph_ops.yaml b/paddle/phi/ops/yaml/inconsistent/dygraph_ops.yaml
index 8a183e27f18737..7f0432cd72c02a 100755
--- a/paddle/phi/ops/yaml/inconsistent/dygraph_ops.yaml
+++ b/paddle/phi/ops/yaml/inconsistent/dygraph_ops.yaml
@@ -295,10 +295,11 @@
   output : Tensor (out)
   infer_meta :
     func : ElementwiseInferMeta
+    param: [x, y]
   kernel :
     func : remainder
   inplace : (x -> out)
-  traits : paddle::dialect::ForwardOnlyTrait
+  backward: remainder_grad
 
 - op : set_value
   args : (Tensor x, IntArray starts, IntArray ends, IntArray steps, int64_t[] axes, int64_t[] decrease_axes, int64_t[] none_axes, int64_t[] shape, Scalar[] values)
diff --git a/paddle/phi/ops/yaml/inconsistent/static_backward.yaml b/paddle/phi/ops/yaml/inconsistent/static_backward.yaml
index d35e42599707bf..dbf80dbf912cbd 100644
--- a/paddle/phi/ops/yaml/inconsistent/static_backward.yaml
+++ b/paddle/phi/ops/yaml/inconsistent/static_backward.yaml
@@ -472,6 +472,16 @@
     data_type : out_grad_in
   inplace : (out_grad_in -> out_grad_out)
 
+- backward_op : remainder_grad
+  forward : remainder (Tensor x, Tensor y) -> Tensor(out)
+  args : (Tensor x, Tensor y, Tensor out_grad)
+  output : Tensor(x_grad), Tensor(y_grad)
+  infer_meta :
+    func : GeneralBinaryGradInferMeta
+    param : [x, y]
+  kernel :
+    func : remainder_grad
+
 - backward_op : row_conv_grad
   forward: row_conv (Tensor x, Tensor filter) -> Tensor(out)
   args: (Tensor x, Tensor filter, Tensor out_grad)
diff --git a/paddle/phi/ops/yaml/inconsistent/static_ops.yaml b/paddle/phi/ops/yaml/inconsistent/static_ops.yaml
index 73292e952d526b..51539c0d781ba6 100644
--- a/paddle/phi/ops/yaml/inconsistent/static_ops.yaml
+++ b/paddle/phi/ops/yaml/inconsistent/static_ops.yaml
@@ -779,13 +779,14 @@
   output : Tensor (out)
   infer_meta :
     func : ElementwiseInferMeta
+    param: [x, y]
   kernel :
     func : remainder
   data_transform :
     support_trans_dtype : x, y
   inplace : (x -> out)
   interfaces : paddle::dialect::InferSymbolicShapeInterface
-  traits : paddle::dialect::ForwardOnlyTrait
+  backward: remainder_grad
 
 - op : row_conv
   args : (Tensor x, Tensor filter)
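Note (not part of the patch): every GradDy variant above implements the same closed form. With remainder(x, y) = x - floor(x / y) * y and floor(x / y) treated as locally constant, the derivatives are d out/dx = 1 and d out/dy = -floor(x / y), so dx = dout and dy = -dout * floor(x / y). The snippet below is an illustrative NumPy finite-difference check of that formula; the shapes, seed, and tolerance are arbitrary choices for the sketch, not values used in the patch.

```python
import numpy as np

# Check dy = -dout * floor(x / y) against a central finite difference of
# np.remainder with respect to y, away from the floor discontinuities.
rng = np.random.default_rng(0)
x = rng.uniform(1.0, 10.0, size=(3, 4))
y = rng.uniform(1.0, 10.0, size=(3, 4))
dout = rng.uniform(-1.0, 1.0, size=(3, 4))
eps = 1e-6

analytic_dy = -dout * np.floor(x / y)
numeric_dy = dout * (np.remainder(x, y + eps) - np.remainder(x, y - eps)) / (2 * eps)
print(np.max(np.abs(analytic_dy - numeric_dy)))  # small unless a sample lands near a kink
```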
diff --git a/paddle/phi/ops/yaml/legacy/static_backward.yaml b/paddle/phi/ops/yaml/legacy/static_backward.yaml
index 33f0b9c6efb881..f5d8ce2f0a8010 100755
--- a/paddle/phi/ops/yaml/legacy/static_backward.yaml
+++ b/paddle/phi/ops/yaml/legacy/static_backward.yaml
@@ -380,6 +380,16 @@
     func : prod_grad
   composite: prod_grad(x, out, out_grad, axis, keepdim, reduce_all, x_grad)
 
+- backward_op : remainder_grad
+  forward : remainder (Tensor x, Tensor y, int axis = -1) -> Tensor(out)
+  args : (Tensor x, Tensor y, Tensor out_grad)
+  output : Tensor(x_grad), Tensor(y_grad)
+  infer_meta :
+    func : GeneralBinaryGradInferMeta
+    param : [x, y]
+  kernel :
+    func : remainder_grad
+
 - backward_op : rnn_grad
   forward : rnn (Tensor x, Tensor[] pre_state, Tensor[] weight_list, Tensor sequence_length, float dropout_prob=0.0, bool is_bidirec=false, int input_size=10, int hidden_size=100, int num_layers=1, str mode="RNN_TANH", int seed=0, bool is_test=false) -> Tensor(out), Tensor(dropout_state_out), Tensor[](state), Tensor(reserve)
   args : (Tensor x, Tensor[] pre_state, Tensor[] weight_list, Tensor sequence_length, Tensor out, Tensor dropout_state_out, Tensor reserve, Tensor out_grad, Tensor[] state_grad, float dropout_prob, bool is_bidirec, int input_size, int hidden_size, int num_layers, str mode, int seed, bool is_test)
diff --git a/paddle/phi/ops/yaml/legacy/static_ops.yaml b/paddle/phi/ops/yaml/legacy/static_ops.yaml
index 81dfe419e60f40..3428f68901194e 100755
--- a/paddle/phi/ops/yaml/legacy/static_ops.yaml
+++ b/paddle/phi/ops/yaml/legacy/static_ops.yaml
@@ -739,10 +739,11 @@
   output : Tensor (out)
   infer_meta :
     func : ElementwiseRawInferMeta
+    param: [x, y]
   kernel :
     func : remainder
   inplace : (x -> out)
-  traits : paddle::dialect::ForwardOnlyTrait
+  backward: remainder_grad
 
 - op : rnn
   args: (Tensor x, Tensor[] pre_state, Tensor[] weight_list, Tensor sequence_length, float dropout_prob=0.0, bool is_bidirec=false, int input_size=10, int hidden_size=100, int num_layers=1, str mode="RNN_TANH", int seed=0, bool is_test=false)
diff --git a/paddle/phi/ops/yaml/ops.yaml b/paddle/phi/ops/yaml/ops.yaml
index 9689bfdacde926..8f3b1eb80991c5 100755
--- a/paddle/phi/ops/yaml/ops.yaml
+++ b/paddle/phi/ops/yaml/ops.yaml
@@ -3513,7 +3513,6 @@
     data_type : x
   traits : paddle::dialect::ForwardOnlyTrait
   interfaces : paddle::dialect::InferSymbolicShapeInterface
-  traits : paddle::dialect::ForwardOnlyTrait
 
 - op : multiplex
   args : (Tensor[] inputs, Tensor index)
diff --git a/test/legacy_test/test_elementwise_mod_op.py b/test/legacy_test/test_elementwise_mod_op.py
index fb115241594f85..8d21f3b45e3c81 100644
--- a/test/legacy_test/test_elementwise_mod_op.py
+++ b/test/legacy_test/test_elementwise_mod_op.py
@@ -17,9 +17,10 @@
 import numpy as np
 from op_test import OpTest, convert_float_to_uint16, convert_uint16_to_float
+from utils import dygraph_guard, static_guard
 
 import paddle
-from paddle import base, static
+from paddle import static
 from paddle.base import core
 
@@ -197,6 +198,144 @@ def init_dtype(self):
         self.dtype = np.float64
 
 
+class TestElementwiseDygraph(unittest.TestCase):
+    def test_dygraph_same_shape(self):
+        with dygraph_guard():
+            dtypes = ['int32', 'int64', 'float32', 'float64']
+            places = [paddle.CPUPlace()]
+            if core.is_compiled_with_cuda():
+                places.append(paddle.CUDAPlace(0))
+            for dtype in dtypes:
+                for place in places:
+                    shape = [1, 2, 3, 4, 5]
+                    x_np = np.random.uniform(-1000, 1000, shape).astype(dtype)
+                    y_np = np.random.uniform(-1000, 1000, shape).astype(dtype)
+                    # make sure all elements in y are non-zero
+                    y_np[np.isclose(y_np, 0)] = -1
+                    z_np = np.remainder(x_np, y_np)
+                    x = paddle.to_tensor(x_np, dtype=dtype, place=place)
+                    x.stop_gradient = False
+                    y = paddle.to_tensor(y_np, dtype=dtype, place=place)
+                    y.stop_gradient = False
+                    z = paddle.remainder(x, y)
+                    self.assertEqual(z.dtype, x.dtype)
+                    np.testing.assert_allclose(z_np, z.numpy())
+
+    def test_dygraph_broadcast_to_x(self):
+        with dygraph_guard():
+            dtypes = ['int32', 'int64', 'float32', 'float64']
+            places = [paddle.CPUPlace()]
+            if core.is_compiled_with_cuda():
+                places.append(paddle.CUDAPlace(0))
+            for dtype in dtypes:
+                for place in places:
+                    x_shape = [2, 3, 4, 5]
+                    y_shape = [1, 1, 5]
+                    x_np = np.random.uniform(-1000, 1000, x_shape).astype(dtype)
+                    y_np = np.random.uniform(-1000, 1000, y_shape).astype(dtype)
+                    # make sure all elements in y are non-zero
+                    y_np[np.isclose(y_np, 0)] = -1
+                    z_np = np.remainder(x_np, y_np)
+
+                    x = paddle.to_tensor(x_np, dtype=dtype, place=place)
+                    y = paddle.to_tensor(y_np, dtype=dtype, place=place)
+                    z = paddle.remainder(x, y)
+                    self.assertEqual(z.dtype, x.dtype)
+                    np.testing.assert_allclose(z_np, z.numpy())
+
+    def test_dygraph_broadcast_to_y(self):
+        with dygraph_guard():
+            dtypes = ['int32', 'int64', 'float32', 'float64']
+            places = [paddle.CPUPlace()]
+            if core.is_compiled_with_cuda():
+                places.append(paddle.CUDAPlace(0))
+            for dtype in dtypes:
+                for place in places:
+                    x_shape = [1, 1, 5]
+                    y_shape = [2, 3, 4, 5]
+                    x_np = np.random.uniform(-1000, 1000, x_shape).astype(dtype)
+                    y_np = np.random.uniform(-1000, 1000, y_shape).astype(dtype)
+                    # make sure all elements in y are non-zero
+                    y_np[np.isclose(y_np, 0)] = -1
+                    z_np = np.remainder(x_np, y_np)
+
+                    x = paddle.to_tensor(x_np, dtype=dtype, place=place)
+                    y = paddle.to_tensor(y_np, dtype=dtype, place=place)
+                    z = paddle.remainder(x, y)
+                    self.assertEqual(z.dtype, x.dtype)
+                    np.testing.assert_allclose(z_np, z.numpy())
+
+    def test_dygraph_broadcast_to_z(self):
+        with dygraph_guard():
+            dtypes = ['int32', 'int64', 'float32', 'float64']
+            places = [paddle.CPUPlace()]
+            if core.is_compiled_with_cuda():
+                places.append(paddle.CUDAPlace(0))
+            for dtype in dtypes:
+                for place in places:
+                    x_shape = [1, 3, 1, 5]
+                    y_shape = [2, 1, 4, 1]
+                    x_np = np.random.uniform(-1000, 1000, x_shape).astype(dtype)
+                    y_np = np.random.uniform(-1000, 1000, y_shape).astype(dtype)
+                    # make sure all elements in y are non-zero
+                    y_np[np.isclose(y_np, 0)] = -1
+                    z_np = np.remainder(x_np, y_np)
+
+                    x = paddle.to_tensor(x_np, dtype=dtype, place=place)
+                    y = paddle.to_tensor(y_np, dtype=dtype, place=place)
+                    z = paddle.remainder(x, y)
+                    self.assertEqual(z.dtype, x.dtype)
+                    np.testing.assert_allclose(z_np, z.numpy())
+
+    def test_check_grad(self):
+        with dygraph_guard():
+            dtypes = ['int32', 'int64', 'float32', 'float64']
+            places = [paddle.CPUPlace()]
+            if core.is_compiled_with_cuda():
+                places.append(paddle.CUDAPlace(0))
+            for dtype in dtypes:
+                for place in places:
+                    x_shape = [2, 1, 4, 1]
+                    y_shape = [1, 3, 1, 5]
+                    x_np = np.random.uniform(0, 1000, x_shape).astype(dtype)
+                    # make sure all elements in x are non-zero
+                    x_np[x_np == 0] = -1
+                    y_np = np.random.uniform(0, 1000, y_shape).astype(dtype)
+                    # make sure all elements in y are non-zero
+                    y_np[np.isclose(y_np, 0)] = -1
+                    z_np = np.remainder(x_np, y_np)
+
+                    x = paddle.to_tensor(
+                        x_np, dtype=dtype, place=place, stop_gradient=False
+                    )
+                    y = paddle.to_tensor(
+                        y_np, dtype=dtype, place=place, stop_gradient=False
+                    )
+                    z = paddle.remainder(x, y)
+                    self.assertEqual(z.dtype, x.dtype)
+                    np.testing.assert_allclose(z_np, z.numpy())
+
+                    v_np = np.random.uniform(-1000, 1000, z_np.shape).astype(
+                        dtype
+                    )
+                    v = paddle.to_tensor(v_np, dtype=dtype, place=place)
+                    dx = paddle.grad(z, x, v, retain_graph=True)[0]
+
+                    dx_np = v_np
+                    for dim in range(len(x_shape)):
+                        if dx_np.shape[dim] > x.shape[dim]:
+                            dx_np = dx_np.sum(axis=dim, keepdims=True)
+                    np.testing.assert_allclose(dx_np, dx.numpy(), 5e-5)
+
+                    dy = paddle.grad(z, y, v, retain_graph=True)[0]
+                    dy_np = -v_np * np.floor_divide(x_np, y_np)
+                    for dim in range(len(y_shape)):
+                        if dy_np.shape[dim] > y.shape[dim]:
+                            dy_np = dy_np.sum(axis=dim, keepdims=True)
+                    np.testing.assert_allclose(dy_np, dy.numpy(), 5e-5)
+
+
 class TestRemainderOp(unittest.TestCase):
     def setUp(self):
         self.np_x1 = np.array([2, 3, 8, 7]).astype('int64')
@@ -215,7 +354,7 @@ def _executed_api(self, x, y, name=None):
         return paddle.remainder(x, y, name)
 
     def test_dygraph(self):
-        with base.dygraph.guard():
+        with dygraph_guard():
             x = paddle.to_tensor(self.np_x1)
             y = paddle.to_tensor(self.np_y1)
             z = self._executed_api(x, y)
@@ -233,37 +372,38 @@ def test_dygraph(self):
         np.testing.assert_allclose(self.z_expected3, z.numpy(), rtol=1e-05)
 
     def test_static(self):
-        mp, sp = static.Program(), static.Program()
-        with static.program_guard(mp, sp):
-            x1 = static.data("x1", shape=[4], dtype="int64")
-            y1 = static.data("y1", shape=[4], dtype="int64")
-            z1 = self._executed_api(x1, y1)
-
-            x2 = static.data("x2", shape=[4], dtype="float64")
-            y2 = static.data("y2", shape=[4], dtype="float64")
-            z2 = self._executed_api(x2, y2)
-
-            x3 = static.data("x3", shape=[4], dtype="int64")
-            y3 = static.data("y3", shape=[4], dtype="int64")
-            z3 = self._executed_api(x3, y3)
-
-            exe = static.Executor()
-            exe.run(sp)
-            [z_np1, z_np2, z_np3] = exe.run(
-                mp,
-                feed={
-                    "x1": self.np_x1,
-                    "y1": self.np_y1,
-                    "x2": self.np_x2,
-                    "y2": self.np_y2,
-                    "x3": self.np_x3,
-                    "y3": self.np_y3,
-                },
-                fetch_list=[z1, z2, z3],
-            )
-            np.testing.assert_allclose(self.z_expected1, z_np1, rtol=1e-05)
-            np.testing.assert_allclose(self.z_expected2, z_np2, rtol=1e-05)
-            np.testing.assert_allclose(self.z_expected3, z_np3, rtol=1e-05)
+        with static_guard():
+            mp, sp = static.Program(), static.Program()
+            with static.program_guard(mp, sp):
+                x1 = static.data("x1", shape=[4], dtype="int64")
+                y1 = static.data("y1", shape=[4], dtype="int64")
+                z1 = self._executed_api(x1, y1)
+
+                x2 = static.data("x2", shape=[4], dtype="float64")
+                y2 = static.data("y2", shape=[4], dtype="float64")
+                z2 = self._executed_api(x2, y2)
+
+                x3 = static.data("x3", shape=[4], dtype="int64")
+                y3 = static.data("y3", shape=[4], dtype="int64")
+                z3 = self._executed_api(x3, y3)
+
+                exe = static.Executor()
+                exe.run(sp)
+                [z_np1, z_np2, z_np3] = exe.run(
+                    mp,
+                    feed={
+                        "x1": self.np_x1,
+                        "y1": self.np_y1,
+                        "x2": self.np_x2,
+                        "y2": self.np_y2,
+                        "x3": self.np_x3,
+                        "y3": self.np_y3,
+                    },
+                    fetch_list=[z1, z2, z3],
+                )
+                np.testing.assert_allclose(self.z_expected1, z_np1, rtol=1e-05)
+                np.testing.assert_allclose(self.z_expected2, z_np2, rtol=1e-05)
+                np.testing.assert_allclose(self.z_expected3, z_np3, rtol=1e-05)
 
 
 class TestRemainderInplaceOp(TestRemainderOp):
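Usage note (not part of the patch): with the kernels, functors, and yaml entries above in place, paddle.remainder becomes differentiable in dygraph mode instead of being forward-only. A minimal sketch, assuming a Paddle build that includes this change; the input values are arbitrary:

```python
import numpy as np
import paddle

x = paddle.to_tensor(np.array([7.0, 8.5, -3.2]), stop_gradient=False)
y = paddle.to_tensor(np.array([2.0, 3.0, 2.5]), stop_gradient=False)

z = paddle.remainder(x, y)
z.sum().backward()

# Expected from the formulas implemented above:
#   x.grad = dout = 1 for every element
#   y.grad = -dout * floor(x / y), elementwise
print(x.grad.numpy())
print(y.grad.numpy())
```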