Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

[llvm] [refactor] Link modules instead of cloning modules #5962

Merged
merged 9 commits on Sep 6, 2022
Merged
Prev Previous commit
Next Next commit
remove struct_module
lin-hitonami committed Sep 5, 2022
commit 90a62778f5f2112a215f82646718fa0ba19164ef
4 changes: 2 additions & 2 deletions taichi/codegen/cuda/codegen_cuda.cpp
Original file line number Diff line number Diff line change
@@ -725,8 +725,8 @@ FunctionType CUDAModuleToFunctionConverter::convert(
}

auto jit = tlctx_->jit.get();
auto cuda_module = jit->add_module(
std::move(mod), executor_->get_config()->gpu_max_reg);
auto cuda_module =
jit->add_module(std::move(mod), executor_->get_config()->gpu_max_reg);

return [cuda_module, kernel_name, args, offloaded_tasks = tasks,
executor = this->executor_](RuntimeContext &context) {
1 change: 0 additions & 1 deletion taichi/codegen/wasm/codegen_wasm.cpp
Original file line number Diff line number Diff line change
@@ -281,7 +281,6 @@ LLVMCompiledData KernelCodeGenWASM::compile_task(
return {name_list, std::move(gen->module), {}, {}};
}


std::vector<LLVMCompiledData> KernelCodeGenWASM::compile_kernel_to_module() {
auto *tlctx = get_llvm_program(prog)->get_llvm_context(kernel->arch);
if (!kernel->lowered()) {
6 changes: 3 additions & 3 deletions taichi/runtime/llvm/llvm_context.cpp
Original file line number Diff line number Diff line change
@@ -520,7 +520,7 @@ void TaichiLLVMContext::link_module_with_cuda_libdevice(
}

void TaichiLLVMContext::add_struct_module(std::unique_ptr<Module> module,
int tree_id) {
int tree_id) {
TI_AUTO_PROF;
TI_ASSERT(std::this_thread::get_id() == main_thread_id_);
auto this_thread_data = get_this_thread_data();
@@ -663,7 +663,8 @@ llvm::DataLayout TaichiLLVMContext::get_data_layout() {
return jit->get_data_layout();
}

JITModule *TaichiLLVMContext::create_jit_module(std::unique_ptr<llvm::Module> module) {
JITModule *TaichiLLVMContext::create_jit_module(
std::unique_ptr<llvm::Module> module) {
return jit->add_module(std::move(module));
}

@@ -884,7 +885,6 @@ TaichiLLVMContext::ThreadLocalData::ThreadLocalData(

// Tear down this thread's LLVM state. Modules are released before the
// thread-safe context is reset — presumably because every llvm::Module is
// tied to the llvm::LLVMContext it was created in and must not outlive
// it; confirm against llvm::orc::ThreadSafeContext ownership rules.
// NOTE(review): struct_module is the legacy single-module slot that this
// PR ("remove struct_module") deletes in favor of the per-tree
// struct_modules map — both resets appear here in the scraped diff.
TaichiLLVMContext::ThreadLocalData::~ThreadLocalData() {
runtime_module.reset();
struct_module.reset();
struct_modules.clear();
thread_safe_llvm_context.reset();
}
5 changes: 1 addition & 4 deletions taichi/runtime/llvm/llvm_context.h
Original file line number Diff line number Diff line change
@@ -34,8 +34,6 @@ class TaichiLLVMContext {
std::unordered_map<int, std::unique_ptr<llvm::Module>> struct_modules;
ThreadLocalData(std::unique_ptr<llvm::orc::ThreadSafeContext> ctx);
~ThreadLocalData();
std::unique_ptr<llvm::Module> struct_module{nullptr}; // TODO: To be
// deleted
};
CompileConfig *config_;

@@ -68,8 +66,7 @@ class TaichiLLVMContext {
*
* @param module Module containing the JIT compiled SNode structs.
*/
void add_struct_module(std::unique_ptr<llvm::Module> module,
int tree_id);
void add_struct_module(std::unique_ptr<llvm::Module> module, int tree_id);

/**
* Clones the LLVM module compiled from llvm/runtime.cpp
9 changes: 6 additions & 3 deletions taichi/runtime/program_impls/llvm/llvm_program.cpp
Original file line number Diff line number Diff line change
@@ -47,16 +47,19 @@ std::unique_ptr<StructCompiler> LlvmProgramImpl::compile_snode_tree_types_impl(
auto *const root = tree->root();
std::unique_ptr<StructCompiler> struct_compiler{nullptr};
if (arch_is_cpu(config->arch)) {
auto host_module = runtime_exec_->llvm_context_host_.get()->new_module("struct");
auto host_module =
runtime_exec_->llvm_context_host_.get()->new_module("struct");
struct_compiler = std::make_unique<StructCompilerLLVM>(
host_arch(), this, std::move(host_module), tree->id());
} else if (config->arch == Arch::dx12) {
auto device_module = runtime_exec_->llvm_context_device_.get()->new_module("struct");
auto device_module =
runtime_exec_->llvm_context_device_.get()->new_module("struct");
struct_compiler = std::make_unique<StructCompilerLLVM>(
Arch::dx12, this, std::move(device_module), tree->id());
} else {
TI_ASSERT(config->arch == Arch::cuda);
auto device_module = runtime_exec_->llvm_context_device_.get()->new_module("struct");
auto device_module =
runtime_exec_->llvm_context_device_.get()->new_module("struct");
struct_compiler = std::make_unique<StructCompilerLLVM>(
Arch::cuda, this, std::move(device_module), tree->id());
}