Skip to content

Commit 3a9b438

Browse files
author
Ailing Zhang
committed
[refactor] Replace make_current with cuda context guard in jit cuda
session Main goal is to get rid of the to-do in comment as it's fixed taichi-dev#5891
1 parent 3458609 commit 3a9b438

File tree

4 files changed

+5
-6
lines changed

4 files changed

+5
-6
lines changed

taichi/rhi/cuda/cuda_device.cpp

+2
Original file line numberDiff line numberDiff line change
@@ -43,6 +43,7 @@ DeviceAllocation CudaDevice::allocate_memory_runtime(
4343
caching_allocator_ = std::make_unique<CudaCachingAllocator>(this);
4444
}
4545
info.ptr = caching_allocator_->allocate(params);
46+
auto context_guard = CUDAContext::get_instance().get_guard();
4647
CUDADriver::get_instance().memset((void *)info.ptr, 0, info.size);
4748
} else {
4849
info.ptr = allocate_llvm_runtime_memory_jit(params);
@@ -117,6 +118,7 @@ DeviceAllocation CudaDevice::import_memory(void *ptr, size_t size) {
117118
}
118119

119120
uint64 CudaDevice::fetch_result_uint64(int i, uint64 *result_buffer) {
121+
auto context_guard = CUDAContext::get_instance().get_guard();
120122
CUDADriver::get_instance().stream_synchronize(nullptr);
121123
uint64 ret;
122124
CUDADriver::get_instance().memcpy_device_to_host(&ret, result_buffer + i,

taichi/runtime/cuda/jit_cuda.cpp

+1-3
Original file line numberDiff line numberDiff line change
@@ -13,9 +13,7 @@ JITModule *JITSessionCUDA ::add_module(std::unique_ptr<llvm::Module> M,
1313
"module NVPTX");
1414
writer.write(ptx);
1515
}
16-
// TODO: figure out why using the guard leads to wrong tests results
17-
// auto context_guard = CUDAContext::get_instance().get_guard();
18-
CUDAContext::get_instance().make_current();
16+
auto context_guard = CUDAContext::get_instance().get_guard();
1917
// Create module for object
2018
void *cuda_module;
2119
TI_TRACE("PTX size: {:.2f}KB", ptx.size() / 1024.0);

taichi/runtime/cuda/jit_cuda.h

+1-3
Original file line numberDiff line numberDiff line change
@@ -47,9 +47,7 @@ class JITModuleCUDA : public JITModule {
4747
}
4848

4949
void *lookup_function(const std::string &name) override {
50-
// TODO: figure out why using the guard leads to wrong tests results
51-
// auto context_guard = CUDAContext::get_instance().get_guard();
52-
CUDAContext::get_instance().make_current();
50+
auto context_guard = CUDAContext::get_instance().get_guard();
5351
void *func = nullptr;
5452
auto t = Time::get_time();
5553
auto err = CUDADriver::get_instance().module_get_function.call_with_warning(

taichi/runtime/llvm/llvm_runtime_executor.cpp

+1
Original file line numberDiff line numberDiff line change
@@ -476,6 +476,7 @@ void LlvmRuntimeExecutor::fill_ndarray(const DeviceAllocation &alloc,
476476
auto ptr = get_ndarray_alloc_info_ptr(alloc);
477477
if (config_->arch == Arch::cuda) {
478478
#if defined(TI_WITH_CUDA)
479+
auto context_guard = CUDAContext::get_instance().get_guard();
479480
CUDADriver::get_instance().memsetd32((void *)ptr, data, size);
480481
#else
481482
TI_NOT_IMPLEMENTED

0 commit comments

Comments
 (0)