Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

[cuda] [type] Refine SNode with quant 6/n: Support __ldg for loading QuantFixedType and QuantFloatType #5374

Merged
merged 7 commits into from
Jul 11, 2022
Merged
Changes from 1 commit
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Next commit
[cuda] [type] Support __ldg for reading QuantFixedType
strongoier committed Jul 8, 2022
commit 401ef9ae525e1f87f0e46b1620d344caa11cf6e2
21 changes: 13 additions & 8 deletions taichi/codegen/cuda/codegen_cuda.cpp
Original file line number Diff line number Diff line change
@@ -525,6 +525,12 @@ class CodeGenLLVMCUDA : public CodeGenLLVM {
{data_ptr, tlctx->get_constant(data_type_size(dtype))});
}

// Loads a quantized integer via the intrinsic load path.
//
// `ptr` is split by load_bit_ptr into a byte pointer and a bit offset, the
// value of `physical_type` is read from the byte pointer with
// create_intrinsic_load, and the `qit` value is then extracted from that
// physical value at the bit offset.
llvm::Value *load_quant_int_with_intrinsic(llvm::Value *ptr,
                                           Type *physical_type,
                                           QuantIntType *qit) {
  auto [container_ptr, offset_in_container] = load_bit_ptr(ptr);
  return extract_quant_int(
      create_intrinsic_load(physical_type, container_ptr),
      offset_in_container, qit);
}

void visit(GlobalLoadStmt *stmt) override {
if (auto get_ch = stmt->src->cast<GetChStmt>(); get_ch) {
bool should_cache_as_read_only = false;
@@ -538,17 +544,16 @@ class CodeGenLLVMCUDA : public CodeGenLLVM {
ptr_type->is_bit_pointer()) {
// Bit pointer case.
auto val_type = ptr_type->get_pointee_type();
auto physical_type = get_ch->input_snode->physical_type;
if (auto qit = val_type->cast<QuantIntType>()) {
dtype = get_ch->input_snode->physical_type;
auto [data_ptr, bit_offset] = load_bit_ptr(llvm_val[stmt->src]);
data_ptr = builder->CreateBitCast(data_ptr, llvm_ptr_type(dtype));
auto data = create_intrinsic_load(dtype, data_ptr);
llvm_val[stmt] = extract_quant_int(data, bit_offset, qit);
llvm_val[stmt] = load_quant_int_with_intrinsic(llvm_val[stmt->src], physical_type, qit);
} else if (auto qfxt = val_type->cast<QuantFixedType>()) {
auto digits = load_quant_int_with_intrinsic(llvm_val[stmt->src], physical_type, qfxt->get_digits_type()->as<QuantIntType>());
llvm_val[stmt] = reconstruct_quant_fixed(digits, qfxt);
} else {
// TODO: support __ldg
TI_ASSERT(val_type->is<QuantFixedType>() ||
val_type->is<QuantFloatType>());
llvm_val[stmt] = load_quant_fixed_or_quant_float(stmt->src);
TI_ASSERT(val_type->is<QuantFloatType>());
llvm_val[stmt] = load_quant_float(llvm_val[stmt->src], get_ch->output_snode, val_type->as<QuantFloatType>());
}
} else {
// Byte pointer case.
11 changes: 7 additions & 4 deletions taichi/codegen/llvm/codegen_llvm.cpp
Original file line number Diff line number Diff line change
@@ -1366,14 +1366,17 @@ void CodeGenLLVM::visit(GlobalLoadStmt *stmt) {
TI_ASSERT(width == 1);
auto ptr_type = stmt->src->ret_type->as<PointerType>();
if (ptr_type->is_bit_pointer()) {
TI_ASSERT(stmt->src->is<GetChStmt>());
auto val_type = ptr_type->get_pointee_type();
if (auto qit = val_type->cast<QuantIntType>()) {
llvm_val[stmt] = load_quant_int(llvm_val[stmt->src], qit);
} else if (auto qfxt = val_type->cast<QuantFixedType>()) {
auto digits = load_quant_int(llvm_val[stmt->src],
qfxt->get_digits_type()->as<QuantIntType>());
llvm_val[stmt] = reconstruct_quant_fixed(digits, qfxt);
} else {
TI_ASSERT(val_type->is<QuantFixedType>() ||
val_type->is<QuantFloatType>());
TI_ASSERT(stmt->src->is<GetChStmt>());
llvm_val[stmt] = load_quant_fixed_or_quant_float(stmt->src);
TI_ASSERT(val_type->is<QuantFloatType>());
llvm_val[stmt] = load_quant_float(llvm_val[stmt->src], stmt->src->as<GetChStmt>()->output_snode, val_type->as<QuantFloatType>());
}
} else {
llvm_val[stmt] = builder->CreateLoad(tlctx->get_data_type(stmt->ret_type),
4 changes: 2 additions & 2 deletions taichi/codegen/llvm/codegen_llvm.h
Original file line number Diff line number Diff line change
@@ -278,6 +278,8 @@ class CodeGenLLVM : public IRVisitor, public LLVMModuleBuilder {
llvm::Value *reconstruct_quant_fixed(llvm::Value *digits,
QuantFixedType *qfxt);

llvm::Value *load_quant_float(llvm::Value *digits_bit_ptr, SNode *digits_snode, QuantFloatType *qflt);

llvm::Value *load_quant_float(llvm::Value *digits_bit_ptr,
llvm::Value *exponent_bit_ptr,
QuantFloatType *qflt,
@@ -288,8 +290,6 @@ class CodeGenLLVM : public IRVisitor, public LLVMModuleBuilder {
QuantFloatType *qflt,
bool shared_exponent);

llvm::Value *load_quant_fixed_or_quant_float(Stmt *ptr_stmt);

void visit(GlobalLoadStmt *stmt) override;

void visit(ElementShuffleStmt *stmt) override;
32 changes: 10 additions & 22 deletions taichi/codegen/llvm/codegen_llvm_quant.cpp
Original file line number Diff line number Diff line change
@@ -507,6 +507,16 @@ llvm::Value *CodeGenLLVM::reconstruct_quant_fixed(llvm::Value *digits,
return builder->CreateFMul(cast, s);
}

// Convenience overload of load_quant_float: derives the bit pointer of the
// exponent from the digits' bit pointer and SNode, then forwards to the
// overload that takes both bit pointers explicitly.
llvm::Value *CodeGenLLVM::load_quant_float(llvm::Value *digits_bit_ptr,
                                           SNode *digits_snode,
                                           QuantFloatType *qflt) {
  auto exp_node = digits_snode->exp_snode;
  // The exponent is addressed relative to the digits, which requires both
  // SNodes to share the same parent.
  TI_ASSERT(digits_snode->parent == exp_node->parent);
  // Difference between the two bit offsets, applied to the digits' bit
  // pointer to obtain the exponent's bit pointer.
  auto bit_delta = exp_node->bit_offset - digits_snode->bit_offset;
  auto exp_bit_ptr = offset_bit_ptr(digits_bit_ptr, bit_delta);
  return load_quant_float(digits_bit_ptr, exp_bit_ptr, qflt,
                          digits_snode->owns_shared_exponent);
}

llvm::Value *CodeGenLLVM::load_quant_float(llvm::Value *digits_bit_ptr,
llvm::Value *exponent_bit_ptr,
QuantFloatType *qflt,
@@ -617,28 +627,6 @@ llvm::Value *CodeGenLLVM::reconstruct_quant_float(
}
}

// Loads a QuantFixedType or QuantFloatType value through a GetChStmt pointer.
// The pointee type of the statement's return type selects the path:
//   - QuantFloatType: load digits and exponent via load_quant_float.
//   - otherwise: treat the pointee as QuantFixedType, load its digits as a
//     quant int and rebuild the fixed-point value.
llvm::Value *CodeGenLLVM::load_quant_fixed_or_quant_float(Stmt *ptr_stmt) {
  auto get_ch = ptr_stmt->as<GetChStmt>();
  auto pointee_type = get_ch->ret_type->as<PointerType>()->get_pointee_type();
  auto qflt = pointee_type->cast<QuantFloatType>();

  if (!qflt) {
    // Not a quant float, so the pointee is taken to be a quant fixed.
    // NOTE(review): as<> presumably asserts on a type mismatch; confirm.
    auto qfxt = pointee_type->as<QuantFixedType>();
    auto digits_type = qfxt->get_digits_type()->as<QuantIntType>();
    auto digits = load_quant_int(llvm_val[get_ch], digits_type);
    return reconstruct_quant_fixed(digits, qfxt);
  }

  TI_ASSERT(get_ch->width() == 1);
  auto digits_ptr = llvm_val[get_ch];
  auto digits_node = get_ch->output_snode;
  auto exp_node = digits_node->exp_snode;
  // Compute the bit pointer of the exponent bits; this relies on the digits
  // and exponent SNodes sharing one parent.
  TI_ASSERT(digits_node->parent == exp_node->parent);
  auto bit_delta = exp_node->bit_offset - digits_node->bit_offset;
  auto exp_ptr = offset_bit_ptr(digits_ptr, bit_delta);
  return load_quant_float(digits_ptr, exp_ptr, qflt,
                          digits_node->owns_shared_exponent);
}

TLANG_NAMESPACE_END

#endif // #ifdef TI_WITH_LLVM