Skip to content

Commit

Permalink
Test wider range of float values
Browse files Browse the repository at this point in the history
This requires a few fixes:
- Cleaning up tolerances of tests
- Fix infinity -> int conversion crashes
- The RVV vrnd ukernels work by converting to int and back. This loses the "special" float values (inf, NaN). This change tries to at least preserve inf, matching NaN is a little trickier and also not quite as bad to lose (but we should find a way to do so).

PiperOrigin-RevId: 734782315
  • Loading branch information
dsharletg authored and xnnpack-bot committed Mar 8, 2025
1 parent 42229c3 commit 42d855c
Show file tree
Hide file tree
Showing 23 changed files with 252 additions and 63 deletions.
1 change: 1 addition & 0 deletions BUILD.bazel
Original file line number Diff line number Diff line change
Expand Up @@ -559,6 +559,7 @@ xnnpack_cxx_library(
":common",
":datatype",
":math",
":reference_ukernels",
":xnnpack_h",
],
)
Expand Down
3 changes: 2 additions & 1 deletion CMakeLists.txt
Original file line number Diff line number Diff line change
Expand Up @@ -870,7 +870,7 @@ IF(XNNPACK_BUILD_LIBRARY)
TARGET_LINK_LIBRARIES(operators PRIVATE xnnpack-base allocator indirection logging microkernel-utils normalization operator-utils packing reference-ukernels datatype)
TARGET_LINK_LIBRARIES(operator-run PRIVATE xnnpack-base logging)
TARGET_LINK_LIBRARIES(operator-utils PRIVATE xnnpack-base logging)
TARGET_LINK_LIBRARIES(reference-ukernels PRIVATE xnnpack-base)
TARGET_LINK_LIBRARIES(reference-ukernels PRIVATE xnnpack-base datatype)
TARGET_LINK_LIBRARIES(subgraph PRIVATE xnnpack-base allocator logging memory mutex operators operator-run datatype)
TARGET_LINK_LIBRARIES(XNNPACK PRIVATE xnnpack-base allocator cache hardware-config indirection memory microkernel-utils microparams-init mutex normalization operators operator-run operator-utils packing microkernels-prod subgraph datatype reference-ukernels)
TARGET_LINK_LIBRARIES(XNNPACK PUBLIC pthreadpool logging)
Expand Down Expand Up @@ -1754,6 +1754,7 @@ IF(XNNPACK_BUILD_TESTS)
GTest::gmock
GTest::gtest
GTest::gtest_main
datatype
hardware-config
logging
microkernels-all
Expand Down
8 changes: 8 additions & 0 deletions src/f32-vrnd/gen/f32-vrndd-rvv-u1v.c
Original file line number Diff line number Diff line change
Expand Up @@ -8,6 +8,7 @@
// LICENSE file in the root directory of this source tree.

#include <assert.h>
#include <math.h>

#include <riscv_vector.h>

Expand All @@ -31,8 +32,15 @@ void xnn_f32_vrndd_ukernel__rvv_u1v(
do {
const size_t n = __riscv_vsetvl_e32m1(batch);
vfloat32m1_t x_f32v = __riscv_vle32_v_f32m1(input, n); input += n;
// We need to remember which values are infinity, so we can preserve them
// after rounding.
// TODO: We should also preserve NaN.
vbool32_t inf_bv = __riscv_vmfeq_vf_f32m1_b32(x_f32v, INFINITY, n);
vbool32_t ninf_bv = __riscv_vmfeq_vf_f32m1_b32(x_f32v, -INFINITY, n);
vbool32_t mask_bv = __riscv_vmor_mm_b32(inf_bv, ninf_bv, n);
vint32m1_t x_rnd_i32v = __riscv_vfcvt_x_f_v_i32m1_rm(x_f32v, __RISCV_FRM_RDN, n);
vfloat32m1_t out_f32v = __riscv_vfcvt_f_x_v_f32m1(x_rnd_i32v, n);
out_f32v = __riscv_vmerge_vvm_f32m1(out_f32v, x_f32v, mask_bv, n);
__riscv_vse32_v_f32m1(output, out_f32v, n); output += n;
batch -= n;
} while (batch != 0);
Expand Down
8 changes: 8 additions & 0 deletions src/f32-vrnd/gen/f32-vrndd-rvv-u2v.c
Original file line number Diff line number Diff line change
Expand Up @@ -8,6 +8,7 @@
// LICENSE file in the root directory of this source tree.

#include <assert.h>
#include <math.h>

#include <riscv_vector.h>

Expand All @@ -31,8 +32,15 @@ void xnn_f32_vrndd_ukernel__rvv_u2v(
do {
const size_t n = __riscv_vsetvl_e32m2(batch);
vfloat32m2_t x_f32v = __riscv_vle32_v_f32m2(input, n); input += n;
// We need to remember which values are infinity, so we can preserve them
// after rounding.
// TODO: We should also preserve NaN.
vbool16_t inf_bv = __riscv_vmfeq_vf_f32m2_b16(x_f32v, INFINITY, n);
vbool16_t ninf_bv = __riscv_vmfeq_vf_f32m2_b16(x_f32v, -INFINITY, n);
vbool16_t mask_bv = __riscv_vmor_mm_b16(inf_bv, ninf_bv, n);
vint32m2_t x_rnd_i32v = __riscv_vfcvt_x_f_v_i32m2_rm(x_f32v, __RISCV_FRM_RDN, n);
vfloat32m2_t out_f32v = __riscv_vfcvt_f_x_v_f32m2(x_rnd_i32v, n);
out_f32v = __riscv_vmerge_vvm_f32m2(out_f32v, x_f32v, mask_bv, n);
__riscv_vse32_v_f32m2(output, out_f32v, n); output += n;
batch -= n;
} while (batch != 0);
Expand Down
8 changes: 8 additions & 0 deletions src/f32-vrnd/gen/f32-vrndd-rvv-u4v.c
Original file line number Diff line number Diff line change
Expand Up @@ -8,6 +8,7 @@
// LICENSE file in the root directory of this source tree.

#include <assert.h>
#include <math.h>

#include <riscv_vector.h>

Expand All @@ -31,8 +32,15 @@ void xnn_f32_vrndd_ukernel__rvv_u4v(
do {
const size_t n = __riscv_vsetvl_e32m4(batch);
vfloat32m4_t x_f32v = __riscv_vle32_v_f32m4(input, n); input += n;
// We need to remember which values are infinity, so we can preserve them
// after rounding.
// TODO: We should also preserve NaN.
vbool8_t inf_bv = __riscv_vmfeq_vf_f32m4_b8(x_f32v, INFINITY, n);
vbool8_t ninf_bv = __riscv_vmfeq_vf_f32m4_b8(x_f32v, -INFINITY, n);
vbool8_t mask_bv = __riscv_vmor_mm_b8(inf_bv, ninf_bv, n);
vint32m4_t x_rnd_i32v = __riscv_vfcvt_x_f_v_i32m4_rm(x_f32v, __RISCV_FRM_RDN, n);
vfloat32m4_t out_f32v = __riscv_vfcvt_f_x_v_f32m4(x_rnd_i32v, n);
out_f32v = __riscv_vmerge_vvm_f32m4(out_f32v, x_f32v, mask_bv, n);
__riscv_vse32_v_f32m4(output, out_f32v, n); output += n;
batch -= n;
} while (batch != 0);
Expand Down
8 changes: 8 additions & 0 deletions src/f32-vrnd/gen/f32-vrndd-rvv-u8v.c
Original file line number Diff line number Diff line change
Expand Up @@ -8,6 +8,7 @@
// LICENSE file in the root directory of this source tree.

#include <assert.h>
#include <math.h>

#include <riscv_vector.h>

Expand All @@ -31,8 +32,15 @@ void xnn_f32_vrndd_ukernel__rvv_u8v(
do {
const size_t n = __riscv_vsetvl_e32m8(batch);
vfloat32m8_t x_f32v = __riscv_vle32_v_f32m8(input, n); input += n;
// We need to remember which values are infinity, so we can preserve them
// after rounding.
// TODO: We should also preserve NaN.
vbool4_t inf_bv = __riscv_vmfeq_vf_f32m8_b4(x_f32v, INFINITY, n);
vbool4_t ninf_bv = __riscv_vmfeq_vf_f32m8_b4(x_f32v, -INFINITY, n);
vbool4_t mask_bv = __riscv_vmor_mm_b4(inf_bv, ninf_bv, n);
vint32m8_t x_rnd_i32v = __riscv_vfcvt_x_f_v_i32m8_rm(x_f32v, __RISCV_FRM_RDN, n);
vfloat32m8_t out_f32v = __riscv_vfcvt_f_x_v_f32m8(x_rnd_i32v, n);
out_f32v = __riscv_vmerge_vvm_f32m8(out_f32v, x_f32v, mask_bv, n);
__riscv_vse32_v_f32m8(output, out_f32v, n); output += n;
batch -= n;
} while (batch != 0);
Expand Down
8 changes: 8 additions & 0 deletions src/f32-vrnd/gen/f32-vrndne-rvv-u1v.c
Original file line number Diff line number Diff line change
Expand Up @@ -8,6 +8,7 @@
// LICENSE file in the root directory of this source tree.

#include <assert.h>
#include <math.h>

#include <riscv_vector.h>

Expand All @@ -31,8 +32,15 @@ void xnn_f32_vrndne_ukernel__rvv_u1v(
do {
const size_t n = __riscv_vsetvl_e32m1(batch);
vfloat32m1_t x_f32v = __riscv_vle32_v_f32m1(input, n); input += n;
// We need to remember which values are infinity, so we can preserve them
// after rounding.
// TODO: We should also preserve NaN.
vbool32_t inf_bv = __riscv_vmfeq_vf_f32m1_b32(x_f32v, INFINITY, n);
vbool32_t ninf_bv = __riscv_vmfeq_vf_f32m1_b32(x_f32v, -INFINITY, n);
vbool32_t mask_bv = __riscv_vmor_mm_b32(inf_bv, ninf_bv, n);
vint32m1_t x_rnd_i32v = __riscv_vfcvt_x_f_v_i32m1_rm(x_f32v, __RISCV_FRM_RNE, n);
vfloat32m1_t out_f32v = __riscv_vfcvt_f_x_v_f32m1(x_rnd_i32v, n);
out_f32v = __riscv_vmerge_vvm_f32m1(out_f32v, x_f32v, mask_bv, n);
__riscv_vse32_v_f32m1(output, out_f32v, n); output += n;
batch -= n;
} while (batch != 0);
Expand Down
8 changes: 8 additions & 0 deletions src/f32-vrnd/gen/f32-vrndne-rvv-u2v.c
Original file line number Diff line number Diff line change
Expand Up @@ -8,6 +8,7 @@
// LICENSE file in the root directory of this source tree.

#include <assert.h>
#include <math.h>

#include <riscv_vector.h>

Expand All @@ -31,8 +32,15 @@ void xnn_f32_vrndne_ukernel__rvv_u2v(
do {
const size_t n = __riscv_vsetvl_e32m2(batch);
vfloat32m2_t x_f32v = __riscv_vle32_v_f32m2(input, n); input += n;
// We need to remember which values are infinity, so we can preserve them
// after rounding.
// TODO: We should also preserve NaN.
vbool16_t inf_bv = __riscv_vmfeq_vf_f32m2_b16(x_f32v, INFINITY, n);
vbool16_t ninf_bv = __riscv_vmfeq_vf_f32m2_b16(x_f32v, -INFINITY, n);
vbool16_t mask_bv = __riscv_vmor_mm_b16(inf_bv, ninf_bv, n);
vint32m2_t x_rnd_i32v = __riscv_vfcvt_x_f_v_i32m2_rm(x_f32v, __RISCV_FRM_RNE, n);
vfloat32m2_t out_f32v = __riscv_vfcvt_f_x_v_f32m2(x_rnd_i32v, n);
out_f32v = __riscv_vmerge_vvm_f32m2(out_f32v, x_f32v, mask_bv, n);
__riscv_vse32_v_f32m2(output, out_f32v, n); output += n;
batch -= n;
} while (batch != 0);
Expand Down
8 changes: 8 additions & 0 deletions src/f32-vrnd/gen/f32-vrndne-rvv-u4v.c
Original file line number Diff line number Diff line change
Expand Up @@ -8,6 +8,7 @@
// LICENSE file in the root directory of this source tree.

#include <assert.h>
#include <math.h>

#include <riscv_vector.h>

Expand All @@ -31,8 +32,15 @@ void xnn_f32_vrndne_ukernel__rvv_u4v(
do {
const size_t n = __riscv_vsetvl_e32m4(batch);
vfloat32m4_t x_f32v = __riscv_vle32_v_f32m4(input, n); input += n;
// We need to remember which values are infinity, so we can preserve them
// after rounding.
// TODO: We should also preserve NaN.
vbool8_t inf_bv = __riscv_vmfeq_vf_f32m4_b8(x_f32v, INFINITY, n);
vbool8_t ninf_bv = __riscv_vmfeq_vf_f32m4_b8(x_f32v, -INFINITY, n);
vbool8_t mask_bv = __riscv_vmor_mm_b8(inf_bv, ninf_bv, n);
vint32m4_t x_rnd_i32v = __riscv_vfcvt_x_f_v_i32m4_rm(x_f32v, __RISCV_FRM_RNE, n);
vfloat32m4_t out_f32v = __riscv_vfcvt_f_x_v_f32m4(x_rnd_i32v, n);
out_f32v = __riscv_vmerge_vvm_f32m4(out_f32v, x_f32v, mask_bv, n);
__riscv_vse32_v_f32m4(output, out_f32v, n); output += n;
batch -= n;
} while (batch != 0);
Expand Down
8 changes: 8 additions & 0 deletions src/f32-vrnd/gen/f32-vrndne-rvv-u8v.c
Original file line number Diff line number Diff line change
Expand Up @@ -8,6 +8,7 @@
// LICENSE file in the root directory of this source tree.

#include <assert.h>
#include <math.h>

#include <riscv_vector.h>

Expand All @@ -31,8 +32,15 @@ void xnn_f32_vrndne_ukernel__rvv_u8v(
do {
const size_t n = __riscv_vsetvl_e32m8(batch);
vfloat32m8_t x_f32v = __riscv_vle32_v_f32m8(input, n); input += n;
// We need to remember which values are infinity, so we can preserve them
// after rounding.
// TODO: We should also preserve NaN.
vbool4_t inf_bv = __riscv_vmfeq_vf_f32m8_b4(x_f32v, INFINITY, n);
vbool4_t ninf_bv = __riscv_vmfeq_vf_f32m8_b4(x_f32v, -INFINITY, n);
vbool4_t mask_bv = __riscv_vmor_mm_b4(inf_bv, ninf_bv, n);
vint32m8_t x_rnd_i32v = __riscv_vfcvt_x_f_v_i32m8_rm(x_f32v, __RISCV_FRM_RNE, n);
vfloat32m8_t out_f32v = __riscv_vfcvt_f_x_v_f32m8(x_rnd_i32v, n);
out_f32v = __riscv_vmerge_vvm_f32m8(out_f32v, x_f32v, mask_bv, n);
__riscv_vse32_v_f32m8(output, out_f32v, n); output += n;
batch -= n;
} while (batch != 0);
Expand Down
8 changes: 8 additions & 0 deletions src/f32-vrnd/gen/f32-vrndu-rvv-u1v.c
Original file line number Diff line number Diff line change
Expand Up @@ -8,6 +8,7 @@
// LICENSE file in the root directory of this source tree.

#include <assert.h>
#include <math.h>

#include <riscv_vector.h>

Expand All @@ -31,8 +32,15 @@ void xnn_f32_vrndu_ukernel__rvv_u1v(
do {
const size_t n = __riscv_vsetvl_e32m1(batch);
vfloat32m1_t x_f32v = __riscv_vle32_v_f32m1(input, n); input += n;
// We need to remember which values are infinity, so we can preserve them
// after rounding.
// TODO: We should also preserve NaN.
vbool32_t inf_bv = __riscv_vmfeq_vf_f32m1_b32(x_f32v, INFINITY, n);
vbool32_t ninf_bv = __riscv_vmfeq_vf_f32m1_b32(x_f32v, -INFINITY, n);
vbool32_t mask_bv = __riscv_vmor_mm_b32(inf_bv, ninf_bv, n);
vint32m1_t x_rnd_i32v = __riscv_vfcvt_x_f_v_i32m1_rm(x_f32v, __RISCV_FRM_RUP, n);
vfloat32m1_t out_f32v = __riscv_vfcvt_f_x_v_f32m1(x_rnd_i32v, n);
out_f32v = __riscv_vmerge_vvm_f32m1(out_f32v, x_f32v, mask_bv, n);
__riscv_vse32_v_f32m1(output, out_f32v, n); output += n;
batch -= n;
} while (batch != 0);
Expand Down
8 changes: 8 additions & 0 deletions src/f32-vrnd/gen/f32-vrndu-rvv-u2v.c
Original file line number Diff line number Diff line change
Expand Up @@ -8,6 +8,7 @@
// LICENSE file in the root directory of this source tree.

#include <assert.h>
#include <math.h>

#include <riscv_vector.h>

Expand All @@ -31,8 +32,15 @@ void xnn_f32_vrndu_ukernel__rvv_u2v(
do {
const size_t n = __riscv_vsetvl_e32m2(batch);
vfloat32m2_t x_f32v = __riscv_vle32_v_f32m2(input, n); input += n;
// We need to remember which values are infinity, so we can preserve them
// after rounding.
// TODO: We should also preserve NaN.
vbool16_t inf_bv = __riscv_vmfeq_vf_f32m2_b16(x_f32v, INFINITY, n);
vbool16_t ninf_bv = __riscv_vmfeq_vf_f32m2_b16(x_f32v, -INFINITY, n);
vbool16_t mask_bv = __riscv_vmor_mm_b16(inf_bv, ninf_bv, n);
vint32m2_t x_rnd_i32v = __riscv_vfcvt_x_f_v_i32m2_rm(x_f32v, __RISCV_FRM_RUP, n);
vfloat32m2_t out_f32v = __riscv_vfcvt_f_x_v_f32m2(x_rnd_i32v, n);
out_f32v = __riscv_vmerge_vvm_f32m2(out_f32v, x_f32v, mask_bv, n);
__riscv_vse32_v_f32m2(output, out_f32v, n); output += n;
batch -= n;
} while (batch != 0);
Expand Down
8 changes: 8 additions & 0 deletions src/f32-vrnd/gen/f32-vrndu-rvv-u4v.c
Original file line number Diff line number Diff line change
Expand Up @@ -8,6 +8,7 @@
// LICENSE file in the root directory of this source tree.

#include <assert.h>
#include <math.h>

#include <riscv_vector.h>

Expand All @@ -31,8 +32,15 @@ void xnn_f32_vrndu_ukernel__rvv_u4v(
do {
const size_t n = __riscv_vsetvl_e32m4(batch);
vfloat32m4_t x_f32v = __riscv_vle32_v_f32m4(input, n); input += n;
// We need to remember which values are infinity, so we can preserve them
// after rounding.
// TODO: We should also preserve NaN.
vbool8_t inf_bv = __riscv_vmfeq_vf_f32m4_b8(x_f32v, INFINITY, n);
vbool8_t ninf_bv = __riscv_vmfeq_vf_f32m4_b8(x_f32v, -INFINITY, n);
vbool8_t mask_bv = __riscv_vmor_mm_b8(inf_bv, ninf_bv, n);
vint32m4_t x_rnd_i32v = __riscv_vfcvt_x_f_v_i32m4_rm(x_f32v, __RISCV_FRM_RUP, n);
vfloat32m4_t out_f32v = __riscv_vfcvt_f_x_v_f32m4(x_rnd_i32v, n);
out_f32v = __riscv_vmerge_vvm_f32m4(out_f32v, x_f32v, mask_bv, n);
__riscv_vse32_v_f32m4(output, out_f32v, n); output += n;
batch -= n;
} while (batch != 0);
Expand Down
8 changes: 8 additions & 0 deletions src/f32-vrnd/gen/f32-vrndu-rvv-u8v.c
Original file line number Diff line number Diff line change
Expand Up @@ -8,6 +8,7 @@
// LICENSE file in the root directory of this source tree.

#include <assert.h>
#include <math.h>

#include <riscv_vector.h>

Expand All @@ -31,8 +32,15 @@ void xnn_f32_vrndu_ukernel__rvv_u8v(
do {
const size_t n = __riscv_vsetvl_e32m8(batch);
vfloat32m8_t x_f32v = __riscv_vle32_v_f32m8(input, n); input += n;
// We need to remember which values are infinity, so we can preserve them
// after rounding.
// TODO: We should also preserve NaN.
vbool4_t inf_bv = __riscv_vmfeq_vf_f32m8_b4(x_f32v, INFINITY, n);
vbool4_t ninf_bv = __riscv_vmfeq_vf_f32m8_b4(x_f32v, -INFINITY, n);
vbool4_t mask_bv = __riscv_vmor_mm_b4(inf_bv, ninf_bv, n);
vint32m8_t x_rnd_i32v = __riscv_vfcvt_x_f_v_i32m8_rm(x_f32v, __RISCV_FRM_RUP, n);
vfloat32m8_t out_f32v = __riscv_vfcvt_f_x_v_f32m8(x_rnd_i32v, n);
out_f32v = __riscv_vmerge_vvm_f32m8(out_f32v, x_f32v, mask_bv, n);
__riscv_vse32_v_f32m8(output, out_f32v, n); output += n;
batch -= n;
} while (batch != 0);
Expand Down
8 changes: 8 additions & 0 deletions src/f32-vrnd/gen/f32-vrndz-rvv-u1v.c
Original file line number Diff line number Diff line change
Expand Up @@ -8,6 +8,7 @@
// LICENSE file in the root directory of this source tree.

#include <assert.h>
#include <math.h>

#include <riscv_vector.h>

Expand All @@ -31,8 +32,15 @@ void xnn_f32_vrndz_ukernel__rvv_u1v(
do {
const size_t n = __riscv_vsetvl_e32m1(batch);
vfloat32m1_t x_f32v = __riscv_vle32_v_f32m1(input, n); input += n;
// We need to remember which values are infinity, so we can preserve them
// after rounding.
// TODO: We should also preserve NaN.
vbool32_t inf_bv = __riscv_vmfeq_vf_f32m1_b32(x_f32v, INFINITY, n);
vbool32_t ninf_bv = __riscv_vmfeq_vf_f32m1_b32(x_f32v, -INFINITY, n);
vbool32_t mask_bv = __riscv_vmor_mm_b32(inf_bv, ninf_bv, n);
vint32m1_t x_rnd_i32v = __riscv_vfcvt_x_f_v_i32m1_rm(x_f32v, __RISCV_FRM_RTZ, n);
vfloat32m1_t out_f32v = __riscv_vfcvt_f_x_v_f32m1(x_rnd_i32v, n);
out_f32v = __riscv_vmerge_vvm_f32m1(out_f32v, x_f32v, mask_bv, n);
__riscv_vse32_v_f32m1(output, out_f32v, n); output += n;
batch -= n;
} while (batch != 0);
Expand Down
8 changes: 8 additions & 0 deletions src/f32-vrnd/gen/f32-vrndz-rvv-u2v.c
Original file line number Diff line number Diff line change
Expand Up @@ -8,6 +8,7 @@
// LICENSE file in the root directory of this source tree.

#include <assert.h>
#include <math.h>

#include <riscv_vector.h>

Expand All @@ -31,8 +32,15 @@ void xnn_f32_vrndz_ukernel__rvv_u2v(
do {
const size_t n = __riscv_vsetvl_e32m2(batch);
vfloat32m2_t x_f32v = __riscv_vle32_v_f32m2(input, n); input += n;
// We need to remember which values are infinity, so we can preserve them
// after rounding.
// TODO: We should also preserve NaN.
vbool16_t inf_bv = __riscv_vmfeq_vf_f32m2_b16(x_f32v, INFINITY, n);
vbool16_t ninf_bv = __riscv_vmfeq_vf_f32m2_b16(x_f32v, -INFINITY, n);
vbool16_t mask_bv = __riscv_vmor_mm_b16(inf_bv, ninf_bv, n);
vint32m2_t x_rnd_i32v = __riscv_vfcvt_x_f_v_i32m2_rm(x_f32v, __RISCV_FRM_RTZ, n);
vfloat32m2_t out_f32v = __riscv_vfcvt_f_x_v_f32m2(x_rnd_i32v, n);
out_f32v = __riscv_vmerge_vvm_f32m2(out_f32v, x_f32v, mask_bv, n);
__riscv_vse32_v_f32m2(output, out_f32v, n); output += n;
batch -= n;
} while (batch != 0);
Expand Down
8 changes: 8 additions & 0 deletions src/f32-vrnd/gen/f32-vrndz-rvv-u4v.c
Original file line number Diff line number Diff line change
Expand Up @@ -8,6 +8,7 @@
// LICENSE file in the root directory of this source tree.

#include <assert.h>
#include <math.h>

#include <riscv_vector.h>

Expand All @@ -31,8 +32,15 @@ void xnn_f32_vrndz_ukernel__rvv_u4v(
do {
const size_t n = __riscv_vsetvl_e32m4(batch);
vfloat32m4_t x_f32v = __riscv_vle32_v_f32m4(input, n); input += n;
// We need to remember which values are infinity, so we can preserve them
// after rounding.
// TODO: We should also preserve NaN.
vbool8_t inf_bv = __riscv_vmfeq_vf_f32m4_b8(x_f32v, INFINITY, n);
vbool8_t ninf_bv = __riscv_vmfeq_vf_f32m4_b8(x_f32v, -INFINITY, n);
vbool8_t mask_bv = __riscv_vmor_mm_b8(inf_bv, ninf_bv, n);
vint32m4_t x_rnd_i32v = __riscv_vfcvt_x_f_v_i32m4_rm(x_f32v, __RISCV_FRM_RTZ, n);
vfloat32m4_t out_f32v = __riscv_vfcvt_f_x_v_f32m4(x_rnd_i32v, n);
out_f32v = __riscv_vmerge_vvm_f32m4(out_f32v, x_f32v, mask_bv, n);
__riscv_vse32_v_f32m4(output, out_f32v, n); output += n;
batch -= n;
} while (batch != 0);
Expand Down
Loading

0 comments on commit 42d855c

Please sign in to comment.