Skip to content

Commit

Permalink
Merge pull request #46110 from JuliaLang/vc/fp16_bb
Browse files Browse the repository at this point in the history
Backport "Emit aliases for FP16 conversion routines" (#45649) to 1.8
  • Loading branch information
vchuravy authored Aug 2, 2022
2 parents 20ef7d9 + 7ff795c commit 4983135
Show file tree
Hide file tree
Showing 8 changed files with 137 additions and 44 deletions.
6 changes: 3 additions & 3 deletions src/APInt-C.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -316,7 +316,7 @@ void LLVMByteSwap(unsigned numbits, integerPart *pa, integerPart *pr) {
void LLVMFPtoInt(unsigned numbits, void *pa, unsigned onumbits, integerPart *pr, bool isSigned, bool *isExact) {
double Val;
if (numbits == 16)
Val = __gnu_h2f_ieee(*(uint16_t*)pa);
Val = julia__gnu_h2f_ieee(*(uint16_t*)pa);
else if (numbits == 32)
Val = *(float*)pa;
else if (numbits == 64)
Expand Down Expand Up @@ -391,7 +391,7 @@ void LLVMSItoFP(unsigned numbits, integerPart *pa, unsigned onumbits, integerPar
val = a.roundToDouble(true);
}
if (onumbits == 16)
*(uint16_t*)pr = __gnu_f2h_ieee(val);
*(uint16_t*)pr = julia__gnu_f2h_ieee(val);
else if (onumbits == 32)
*(float*)pr = val;
else if (onumbits == 64)
Expand All @@ -408,7 +408,7 @@ void LLVMUItoFP(unsigned numbits, integerPart *pa, unsigned onumbits, integerPar
val = a.roundToDouble(false);
}
if (onumbits == 16)
*(uint16_t*)pr = __gnu_f2h_ieee(val);
*(uint16_t*)pr = julia__gnu_f2h_ieee(val);
else if (onumbits == 32)
*(float*)pr = val;
else if (onumbits == 64)
Expand Down
48 changes: 44 additions & 4 deletions src/aotcompile.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -20,6 +20,8 @@
#include <llvm/Analysis/BasicAliasAnalysis.h>
#include <llvm/Analysis/TypeBasedAliasAnalysis.h>
#include <llvm/Analysis/ScopedNoAliasAA.h>
#include <llvm/IR/IRBuilder.h>
#include <llvm/IR/PassManager.h>
#include <llvm/IR/Verifier.h>
#include <llvm/Transforms/IPO.h>
#include <llvm/Transforms/Scalar.h>
Expand All @@ -31,6 +33,9 @@
#include <llvm/Transforms/InstCombine/InstCombine.h>
#include <llvm/Transforms/Scalar/InstSimplifyPass.h>
#include <llvm/Transforms/Utils/SimplifyCFGOptions.h>
#include <llvm/Transforms/Utils/ModuleUtils.h>
#include <llvm/Passes/PassBuilder.h>
#include <llvm/Passes/PassPlugin.h>
#if defined(USE_POLLY)
#include <polly/RegisterPasses.h>
#include <polly/LinkAllPasses.h>
Expand Down Expand Up @@ -431,6 +436,23 @@ static void reportWriterError(const ErrorInfoBase &E)
jl_safe_printf("ERROR: failed to emit output file %s\n", err.c_str());
}

static void injectCRTAlias(Module &M, StringRef name, StringRef alias, FunctionType *FT)
{
Function *target = M.getFunction(alias);
if (!target) {
target = Function::Create(FT, Function::ExternalLinkage, alias, M);
}
Function *interposer = Function::Create(FT, Function::WeakAnyLinkage, name, M);
appendToCompilerUsed(M, {interposer});

llvm::IRBuilder<> builder(BasicBlock::Create(M.getContext(), "top", interposer));
SmallVector<Value *, 4> CallArgs;
for (auto &arg : interposer->args())
CallArgs.push_back(&arg);
auto val = builder.CreateCall(target, CallArgs);
builder.CreateRet(val);
}


// takes the running content that has collected in the shadow module and dump it to disk
// this builds the object file portion of the sysimage files for fast startup
Expand Down Expand Up @@ -475,7 +497,7 @@ void jl_dump_native_impl(void *native_code,
CodeGenOpt::Aggressive // -O3 TODO: respect command -O0 flag?
));

legacy::PassManager PM;
legacy::PassManager PM, postopt;
addTargetPasses(&PM, TM.get());

// set up optimization passes
Expand All @@ -500,12 +522,12 @@ void jl_dump_native_impl(void *native_code,
addMachinePasses(&PM, TM.get(), jl_options.opt_level);
}
if (bc_fname)
PM.add(createBitcodeWriterPass(bc_OS));
postopt.add(createBitcodeWriterPass(bc_OS));
if (obj_fname)
if (TM->addPassesToEmitFile(PM, obj_OS, nullptr, CGFT_ObjectFile, false))
if (TM->addPassesToEmitFile(postopt, obj_OS, nullptr, CGFT_ObjectFile, false))
jl_safe_printf("ERROR: target does not support generation of object files\n");
if (asm_fname)
if (TM->addPassesToEmitFile(PM, asm_OS, nullptr, CGFT_AssemblyFile, false))
if (TM->addPassesToEmitFile(postopt, asm_OS, nullptr, CGFT_AssemblyFile, false))
jl_safe_printf("ERROR: target does not support generation of object files\n");

// Reset the target triple to make sure it matches the new target machine
Expand Down Expand Up @@ -539,6 +561,24 @@ void jl_dump_native_impl(void *native_code,
// do the actual work
auto add_output = [&] (Module &M, StringRef unopt_bc_Name, StringRef bc_Name, StringRef obj_Name, StringRef asm_Name) {
PM.run(M);

// We would like to emit an alias or an weakref alias to redirect these symbols
// but LLVM doesn't let us emit a GlobalAlias to a declaration...
// So for now we inject a definition of these functions that calls our runtime
// functions. We do so after optimization to avoid cloning these functions.
injectCRTAlias(M, "__gnu_h2f_ieee", "julia__gnu_h2f_ieee",
FunctionType::get(Type::getFloatTy(Context), { Type::getHalfTy(Context) }, false));
injectCRTAlias(M, "__extendhfsf2", "julia__gnu_h2f_ieee",
FunctionType::get(Type::getFloatTy(Context), { Type::getHalfTy(Context) }, false));
injectCRTAlias(M, "__gnu_f2h_ieee", "julia__gnu_f2h_ieee",
FunctionType::get(Type::getHalfTy(Context), { Type::getFloatTy(Context) }, false));
injectCRTAlias(M, "__truncsfhf2", "julia__gnu_f2h_ieee",
FunctionType::get(Type::getHalfTy(Context), { Type::getFloatTy(Context) }, false));
injectCRTAlias(M, "__truncdfhf2", "julia__truncdfhf2",
FunctionType::get(Type::getHalfTy(Context), { Type::getDoubleTy(Context) }, false));

postopt.run(M);

if (unopt_bc_fname)
emit_result(unopt_bc_Archive, unopt_bc_Buffer, unopt_bc_Name, outputs);
if (bc_fname)
Expand Down
18 changes: 16 additions & 2 deletions src/jitlayers.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -909,12 +909,26 @@ JuliaOJIT::JuliaOJIT(TargetMachine &TM, LLVMContext *LLVMCtx)
}

JD.addToLinkOrder(GlobalJD, orc::JITDylibLookupFlags::MatchExportedSymbolsOnly);

orc::SymbolAliasMap jl_crt = {
{ mangle("__gnu_h2f_ieee"), { mangle("julia__gnu_h2f_ieee"), JITSymbolFlags::Exported } },
{ mangle("__extendhfsf2"), { mangle("julia__gnu_h2f_ieee"), JITSymbolFlags::Exported } },
{ mangle("__gnu_f2h_ieee"), { mangle("julia__gnu_f2h_ieee"), JITSymbolFlags::Exported } },
{ mangle("__truncsfhf2"), { mangle("julia__gnu_f2h_ieee"), JITSymbolFlags::Exported } },
{ mangle("__truncdfhf2"), { mangle("julia__truncdfhf2"), JITSymbolFlags::Exported } }
};
cantFail(GlobalJD.define(orc::symbolAliases(jl_crt)));
}

void JuliaOJIT::addGlobalMapping(StringRef Name, uint64_t Addr)
orc::SymbolStringPtr JuliaOJIT::mangle(StringRef Name)
{
std::string MangleName = getMangledName(Name);
cantFail(JD.define(orc::absoluteSymbols({{ES.intern(MangleName), JITEvaluatedSymbol::fromPointer((void*)Addr)}})));
return ES.intern(MangleName);
}

void JuliaOJIT::addGlobalMapping(StringRef Name, uint64_t Addr)
{
cantFail(JD.define(orc::absoluteSymbols({{mangle(Name), JITEvaluatedSymbol::fromPointer((void*)Addr)}})));
}

void JuliaOJIT::addModule(std::unique_ptr<Module> M)
Expand Down
1 change: 1 addition & 0 deletions src/jitlayers.h
Original file line number Diff line number Diff line change
Expand Up @@ -204,6 +204,7 @@ class JuliaOJIT {
void RegisterJITEventListener(JITEventListener *L);
#endif

orc::SymbolStringPtr mangle(StringRef Name);
void addGlobalMapping(StringRef Name, uint64_t Addr);
void addModule(std::unique_ptr<Module> M);

Expand Down
6 changes: 0 additions & 6 deletions src/julia.expmap
Original file line number Diff line number Diff line change
Expand Up @@ -37,12 +37,6 @@
environ;
__progname;

/* compiler run-time intrinsics */
__gnu_h2f_ieee;
__extendhfsf2;
__gnu_f2h_ieee;
__truncdfhf2;

local:
*;
};
14 changes: 12 additions & 2 deletions src/julia_internal.h
Original file line number Diff line number Diff line change
Expand Up @@ -1520,8 +1520,18 @@ jl_sym_t *_jl_symbol(const char *str, size_t len) JL_NOTSAFEPOINT;
#define JL_GC_ASSERT_LIVE(x) (void)(x)
#endif

float __gnu_h2f_ieee(uint16_t param) JL_NOTSAFEPOINT;
uint16_t __gnu_f2h_ieee(float param) JL_NOTSAFEPOINT;
JL_DLLEXPORT float julia__gnu_h2f_ieee(uint16_t param) JL_NOTSAFEPOINT;
JL_DLLEXPORT uint16_t julia__gnu_f2h_ieee(float param) JL_NOTSAFEPOINT;
JL_DLLEXPORT uint16_t julia__truncdfhf2(double param) JL_NOTSAFEPOINT;
//JL_DLLEXPORT double julia__extendhfdf2(uint16_t n) JL_NOTSAFEPOINT;
//JL_DLLEXPORT int32_t julia__fixhfsi(uint16_t n) JL_NOTSAFEPOINT;
//JL_DLLEXPORT int64_t julia__fixhfdi(uint16_t n) JL_NOTSAFEPOINT;
//JL_DLLEXPORT uint32_t julia__fixunshfsi(uint16_t n) JL_NOTSAFEPOINT;
//JL_DLLEXPORT uint64_t julia__fixunshfdi(uint16_t n) JL_NOTSAFEPOINT;
//JL_DLLEXPORT uint16_t julia__floatsihf(int32_t n) JL_NOTSAFEPOINT;
//JL_DLLEXPORT uint16_t julia__floatdihf(int64_t n) JL_NOTSAFEPOINT;
//JL_DLLEXPORT uint16_t julia__floatunsihf(uint32_t n) JL_NOTSAFEPOINT;
//JL_DLLEXPORT uint16_t julia__floatundihf(uint64_t n) JL_NOTSAFEPOINT;

#ifdef __cplusplus
}
Expand Down
64 changes: 37 additions & 27 deletions src/runtime_intrinsics.c
Original file line number Diff line number Diff line change
Expand Up @@ -15,9 +15,6 @@
const unsigned int host_char_bit = 8;

// float16 intrinsics
// TODO: use LLVM's compiler-rt on all platforms (Xcode already links compiler-rt)

#if !defined(_OS_DARWIN_)

static inline float half_to_float(uint16_t ival) JL_NOTSAFEPOINT
{
Expand Down Expand Up @@ -188,22 +185,17 @@ static inline uint16_t float_to_half(float param) JL_NOTSAFEPOINT
return h;
}

JL_DLLEXPORT float __gnu_h2f_ieee(uint16_t param)
JL_DLLEXPORT float julia__gnu_h2f_ieee(uint16_t param)
{
return half_to_float(param);
}

JL_DLLEXPORT float __extendhfsf2(uint16_t param)
{
return half_to_float(param);
}

JL_DLLEXPORT uint16_t __gnu_f2h_ieee(float param)
JL_DLLEXPORT uint16_t julia__gnu_f2h_ieee(float param)
{
return float_to_half(param);
}

JL_DLLEXPORT uint16_t __truncdfhf2(double param)
JL_DLLEXPORT uint16_t julia__truncdfhf2(double param)
{
float res = (float)param;
uint32_t resi;
Expand All @@ -225,7 +217,25 @@ JL_DLLEXPORT uint16_t __truncdfhf2(double param)
return float_to_half(res);
}

#endif
//JL_DLLEXPORT double julia__extendhfdf2(uint16_t n) { return (double)julia__gnu_h2f_ieee(n); }
//JL_DLLEXPORT int32_t julia__fixhfsi(uint16_t n) { return (int32_t)julia__gnu_h2f_ieee(n); }
//JL_DLLEXPORT int64_t julia__fixhfdi(uint16_t n) { return (int64_t)julia__gnu_h2f_ieee(n); }
//JL_DLLEXPORT uint32_t julia__fixunshfsi(uint16_t n) { return (uint32_t)julia__gnu_h2f_ieee(n); }
//JL_DLLEXPORT uint64_t julia__fixunshfdi(uint16_t n) { return (uint64_t)julia__gnu_h2f_ieee(n); }
//JL_DLLEXPORT uint16_t julia__floatsihf(int32_t n) { return julia__gnu_f2h_ieee((float)n); }
//JL_DLLEXPORT uint16_t julia__floatdihf(int64_t n) { return julia__gnu_f2h_ieee((float)n); }
//JL_DLLEXPORT uint16_t julia__floatunsihf(uint32_t n) { return julia__gnu_f2h_ieee((float)n); }
//JL_DLLEXPORT uint16_t julia__floatundihf(uint64_t n) { return julia__gnu_f2h_ieee((float)n); }
//HANDLE_LIBCALL(F16, F128, __extendhftf2)
//HANDLE_LIBCALL(F16, F80, __extendhfxf2)
//HANDLE_LIBCALL(F80, F16, __truncxfhf2)
//HANDLE_LIBCALL(F128, F16, __trunctfhf2)
//HANDLE_LIBCALL(PPCF128, F16, __trunctfhf2)
//HANDLE_LIBCALL(F16, I128, __fixhfti)
//HANDLE_LIBCALL(F16, I128, __fixunshfti)
//HANDLE_LIBCALL(I128, F16, __floattihf)
//HANDLE_LIBCALL(I128, F16, __floatuntihf)


// run time version of bitcast intrinsic
JL_DLLEXPORT jl_value_t *jl_bitcast(jl_value_t *ty, jl_value_t *v)
Expand Down Expand Up @@ -551,9 +561,9 @@ static inline unsigned select_by_size(unsigned sz) JL_NOTSAFEPOINT
}

#define fp_select(a, func) \
sizeof(a) == sizeof(float) ? func##f((float)a) : func(a)
sizeof(a) <= sizeof(float) ? func##f((float)a) : func(a)
#define fp_select2(a, b, func) \
sizeof(a) == sizeof(float) ? func##f(a, b) : func(a, b)
sizeof(a) <= sizeof(float) ? func##f(a, b) : func(a, b)

// fast-function generators //

Expand Down Expand Up @@ -597,11 +607,11 @@ static inline void name(unsigned osize, void *pa, void *pr) JL_NOTSAFEPOINT \
static inline void name(unsigned osize, void *pa, void *pr) JL_NOTSAFEPOINT \
{ \
uint16_t a = *(uint16_t*)pa; \
float A = __gnu_h2f_ieee(a); \
float A = julia__gnu_h2f_ieee(a); \
if (osize == 16) { \
float R; \
OP(&R, A); \
*(uint16_t*)pr = __gnu_f2h_ieee(R); \
*(uint16_t*)pr = julia__gnu_f2h_ieee(R); \
} else { \
OP((uint16_t*)pr, A); \
} \
Expand All @@ -625,11 +635,11 @@ static void jl_##name##16(unsigned runtime_nbits, void *pa, void *pb, void *pr)
{ \
uint16_t a = *(uint16_t*)pa; \
uint16_t b = *(uint16_t*)pb; \
float A = __gnu_h2f_ieee(a); \
float B = __gnu_h2f_ieee(b); \
float A = julia__gnu_h2f_ieee(a); \
float B = julia__gnu_h2f_ieee(b); \
runtime_nbits = 16; \
float R = OP(A, B); \
*(uint16_t*)pr = __gnu_f2h_ieee(R); \
*(uint16_t*)pr = julia__gnu_f2h_ieee(R); \
}

// float or integer inputs, bool output
Expand All @@ -650,8 +660,8 @@ static int jl_##name##16(unsigned runtime_nbits, void *pa, void *pb) JL_NOTSAFEP
{ \
uint16_t a = *(uint16_t*)pa; \
uint16_t b = *(uint16_t*)pb; \
float A = __gnu_h2f_ieee(a); \
float B = __gnu_h2f_ieee(b); \
float A = julia__gnu_h2f_ieee(a); \
float B = julia__gnu_h2f_ieee(b); \
runtime_nbits = 16; \
return OP(A, B); \
}
Expand Down Expand Up @@ -691,12 +701,12 @@ static void jl_##name##16(unsigned runtime_nbits, void *pa, void *pb, void *pc,
uint16_t a = *(uint16_t*)pa; \
uint16_t b = *(uint16_t*)pb; \
uint16_t c = *(uint16_t*)pc; \
float A = __gnu_h2f_ieee(a); \
float B = __gnu_h2f_ieee(b); \
float C = __gnu_h2f_ieee(c); \
float A = julia__gnu_h2f_ieee(a); \
float B = julia__gnu_h2f_ieee(b); \
float C = julia__gnu_h2f_ieee(c); \
runtime_nbits = 16; \
float R = OP(A, B, C); \
*(uint16_t*)pr = __gnu_f2h_ieee(R); \
*(uint16_t*)pr = julia__gnu_f2h_ieee(R); \
}


Expand Down Expand Up @@ -1318,7 +1328,7 @@ static inline int fpiseq##nbits(c_type a, c_type b) JL_NOTSAFEPOINT { \
fpiseq_n(float, 32)
fpiseq_n(double, 64)
#define fpiseq(a,b) \
sizeof(a) == sizeof(float) ? fpiseq32(a, b) : fpiseq64(a, b)
sizeof(a) <= sizeof(float) ? fpiseq32(a, b) : fpiseq64(a, b)

bool_fintrinsic(eq,eq_float)
bool_fintrinsic(ne,ne_float)
Expand Down Expand Up @@ -1367,7 +1377,7 @@ cvt_iintrinsic(LLVMFPtoUI, fptoui)
if (!(osize < 8 * sizeof(a))) \
jl_error("fptrunc: output bitsize must be < input bitsize"); \
else if (osize == 16) \
*(uint16_t*)pr = __gnu_f2h_ieee(a); \
*(uint16_t*)pr = julia__gnu_f2h_ieee(a); \
else if (osize == 32) \
*(float*)pr = a; \
else if (osize == 64) \
Expand Down
24 changes: 24 additions & 0 deletions test/intrinsics.jl
Original file line number Diff line number Diff line change
Expand Up @@ -164,6 +164,30 @@ end
@test_intrinsic Core.Intrinsics.fptoui UInt Float16(3.3) UInt(3)
end

if Sys.ARCH == :aarch64
# On AArch64 we are following the `_Float16` ABI. Buthe these functions expect `Int16`.
# TODO: SHould we have `Chalf == Int16` and `Cfloat16 == Float16`?
extendhfsf2(x::Float16) = ccall("extern __extendhfsf2", llvmcall, Float32, (Int16,), reinterpret(Int16, x))
gnu_h2f_ieee(x::Float16) = ccall("extern __gnu_h2f_ieee", llvmcall, Float32, (Int16,), reinterpret(Int16, x))
truncsfhf2(x::Float32) = reinterpret(Float16, ccall("extern __truncsfhf2", llvmcall, Int16, (Float32,), x))
gnu_f2h_ieee(x::Float32) = reinterpret(Float16, ccall("extern __gnu_f2h_ieee", llvmcall, Int16, (Float32,), x))
truncdfhf2(x::Float64) = reinterpret(Float16, ccall("extern __truncdfhf2", llvmcall, Int16, (Float64,), x))
else
extendhfsf2(x::Float16) = ccall("extern __extendhfsf2", llvmcall, Float32, (Float16,), x)
gnu_h2f_ieee(x::Float16) = ccall("extern __gnu_h2f_ieee", llvmcall, Float32, (Float16,), x)
truncsfhf2(x::Float32) = ccall("extern __truncsfhf2", llvmcall, Float16, (Float32,), x)
gnu_f2h_ieee(x::Float32) = ccall("extern __gnu_f2h_ieee", llvmcall, Float16, (Float32,), x)
truncdfhf2(x::Float64) = ccall("extern __truncdfhf2", llvmcall, Float16, (Float64,), x)
end

@testset "Float16 intrinsics (crt)" begin
@test extendhfsf2(Float16(3.3)) == 3.3007812f0
@test gnu_h2f_ieee(Float16(3.3)) == 3.3007812f0
@test truncsfhf2(3.3f0) == Float16(3.3)
@test gnu_f2h_ieee(3.3f0) == Float16(3.3)
@test truncdfhf2(3.3) == Float16(3.3)
end

using Base.Experimental: @force_compile
@test_throws ConcurrencyViolationError("invalid atomic ordering") (@force_compile; Core.Intrinsics.atomic_fence(:u)) === nothing
@test_throws ConcurrencyViolationError("invalid atomic ordering") (@force_compile; Core.Intrinsics.atomic_fence(Symbol("u", "x"))) === nothing
Expand Down

0 comments on commit 4983135

Please sign in to comment.