
Commit e08f302

Use context guard to replace make_current

1 parent 585538e commit e08f302

13 files changed: +101 −55 lines
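In short: instead of calling `CUDAContext::get_instance().make_current()` and leaving the CUDA context bound to the thread, call sites now hold a RAII guard from `CUDAContext::get_instance().get_guard()` that pushes the context on construction and pops it on destruction. Below is a minimal, self-contained sketch of the idea; the vector-backed stack and the `push_context`/`pop_context` helpers are illustrative stand-ins for the driver's per-thread context stack (cuCtxPushCurrent/cuCtxPopCurrent), not Taichi code.

// Illustrative stand-ins only; not Taichi's real driver wrappers.
#include <cassert>
#include <vector>

struct Context {};

std::vector<Context *> ctx_stack;  // models the driver's per-thread stack

void push_context(Context *ctx) { ctx_stack.push_back(ctx); }

Context *pop_context() {
  Context *top = ctx_stack.back();
  ctx_stack.pop_back();
  return top;
}

class ScopedContextGuard {
  Context *ctx_;

 public:
  explicit ScopedContextGuard(Context *ctx) : ctx_(ctx) {
    push_context(ctx_);  // context becomes current for this scope
  }
  ~ScopedContextGuard() {
    Context *popped = pop_context();  // popped even on early return
    assert(popped == ctx_);
    (void)popped;
  }
  // Non-copyable so the context is popped exactly once.
  ScopedContextGuard(const ScopedContextGuard &) = delete;
  ScopedContextGuard &operator=(const ScopedContextGuard &) = delete;
};

int main() {
  Context cuda_ctx;
  {
    ScopedContextGuard guard(&cuda_ctx);  // replaces a bare make_current()
    // ... driver calls that need a current context ...
  }  // guard destroyed: the thread's context stack is balanced again
  assert(ctx_stack.empty());
  return 0;
}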

taichi/codegen/cuda/codegen_cuda.cpp (+27 −20)

@@ -786,7 +786,6 @@ FunctionType CUDAModuleToFunctionConverter::convert(

   return [cuda_modules, kernel_name, args, offloaded_tasks,
           executor = this->executor_](RuntimeContext &context) {
-    CUDAContext::get_instance().make_current();
     std::vector<void *> arg_buffers(args.size(), nullptr);
     std::vector<void *> device_buffers(args.size(), nullptr);

@@ -804,24 +803,28 @@ FunctionType CUDAModuleToFunctionConverter::convert(
         // in shapes, e.g., shape=(0) or shape=(100, 0, 200). This makes
         // `arr_sz` zero.
         unsigned int attr_val = 0;
-        uint32_t ret_code = CUDADriver::get_instance().mem_get_attribute.call(
-            &attr_val, CU_POINTER_ATTRIBUTE_MEMORY_TYPE,
-            (void *)arg_buffers[i]);
-
-        if (ret_code != CUDA_SUCCESS || attr_val != CU_MEMORYTYPE_DEVICE) {
-          // Copy to device buffer if arg is on host
-          // - ret_code != CUDA_SUCCESS:
-          //   arg_buffers[i] is not on device
-          // - attr_val != CU_MEMORYTYPE_DEVICE:
-          //   Cuda driver is aware of arg_buffers[i] but it might be on
-          //   host.
-          // See CUDA driver API `cuPointerGetAttribute` for more details.
-          transferred = true;
-          CUDADriver::get_instance().malloc(&device_buffers[i], arr_sz);
-          CUDADriver::get_instance().memcpy_host_to_device(
-              (void *)device_buffers[i], arg_buffers[i], arr_sz);
-        } else {
-          device_buffers[i] = arg_buffers[i];
+        {
+          auto context_guard = CUDAContext::get_instance().get_guard();
+          uint32_t ret_code =
+              CUDADriver::get_instance().mem_get_attribute.call(
+                  &attr_val, CU_POINTER_ATTRIBUTE_MEMORY_TYPE,
+                  (void *)arg_buffers[i]);
+
+          if (ret_code != CUDA_SUCCESS || attr_val != CU_MEMORYTYPE_DEVICE) {
+            // Copy to device buffer if arg is on host
+            // - ret_code != CUDA_SUCCESS:
+            //   arg_buffers[i] is not on device
+            // - attr_val != CU_MEMORYTYPE_DEVICE:
+            //   Cuda driver is aware of arg_buffers[i] but it might be on
+            //   host.
+            // See CUDA driver API `cuPointerGetAttribute` for more details.
+            transferred = true;
+            CUDADriver::get_instance().malloc(&device_buffers[i], arr_sz);
+            CUDADriver::get_instance().memcpy_host_to_device(
+                (void *)device_buffers[i], arg_buffers[i], arr_sz);
+          } else {
+            device_buffers[i] = arg_buffers[i];
+          }
         }
         // device_buffers[i] saves a raw ptr on CUDA device.
         context.set_arg(i, (uint64)device_buffers[i]);
@@ -845,7 +848,10 @@ FunctionType CUDAModuleToFunctionConverter::convert(
       }
     }
     if (transferred) {
-      CUDADriver::get_instance().stream_synchronize(nullptr);
+      {
+        auto context_guard = CUDAContext::get_instance().get_guard();
+        CUDADriver::get_instance().stream_synchronize(nullptr);
+      }
     }

    for (int i = 0; i < offloaded_tasks.size(); i++) {
@@ -859,6 +865,7 @@ FunctionType CUDAModuleToFunctionConverter::convert(

    // copy data back to host
    if (transferred) {
+      auto context_guard = CUDAContext::get_instance().get_guard();
      CUDADriver::get_instance().stream_synchronize(nullptr);
      for (int i = 0; i < (int)args.size(); i++) {
        if (device_buffers[i] != arg_buffers[i]) {
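Note the extra braces around several of the rewritten call sites: because `get_guard()` now returns a `std::unique_ptr<ContextGuard>`, wrapping the guard in its own block bounds exactly how long the context stays pushed. A tiny self-contained sketch of that lifetime (names here are illustrative, not Taichi's):

// Illustrative lifetime demo; `Guard` stands in for ContextGuard.
#include <cstdio>
#include <memory>

struct Guard {
  Guard() { std::puts("context pushed"); }
  ~Guard() { std::puts("context popped"); }
};

std::unique_ptr<Guard> get_guard() { return std::make_unique<Guard>(); }

int main() {
  {
    auto context_guard = get_guard();
    std::puts("driver call runs with the context current");
  }  // unique_ptr goes out of scope here -> "context popped"
  std::puts("subsequent work runs without the context held");
  return 0;
}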

taichi/program/sparse_matrix.cpp (+7 −1)

@@ -1,4 +1,7 @@
 #include "taichi/program/sparse_matrix.h"
+#if defined(TI_WITH_CUDA)
+#include "taichi/rhi/cuda/cuda_context.h"
+#endif

 #include <sstream>
 #include <string>
@@ -204,6 +207,7 @@ void CuSparseMatrix::build_csr_from_coo(void *coo_row_ptr,
                                         int nnz) {
 #if defined(TI_WITH_CUDA)
   void *csr_row_offset_ptr = NULL;
+  auto context_guard = CUDAContext::get_instance().get_guard();
   CUDADriver::get_instance().malloc(&csr_row_offset_ptr,
                                     sizeof(int) * (rows_ + 1));
   cusparseHandle_t cusparse_handle;
@@ -269,8 +273,10 @@ void CuSparseMatrix::spmv(Program *prog, const Ndarray &x, Ndarray &y) {
       &beta, vecY, CUDA_R_32F, CUSPARSE_SPMV_CSR_ALG1, &bufferSize);

   void *dBuffer = NULL;
-  if (bufferSize > 0)
+  if (bufferSize > 0) {
+    auto context_guard = CUDAContext::get_instance().get_guard();
     CUDADriver::get_instance().malloc(&dBuffer, bufferSize);
+  }
   CUSPARSEDriver::get_instance().cpSpMV(
       cusparse_handle, CUSPARSE_OPERATION_NON_TRANSPOSE, &alpha, matrix_, vecX,
       &beta, vecY, CUDA_R_32F, CUSPARSE_SPMV_CSR_ALG1, dBuffer);
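One detail worth noting in the `spmv` hunk: the added braces are not cosmetic. An unbraced `if` governs only the next statement, so inserting the guard without braces would have made the guard declaration the entire conditional body and run the `malloc` unconditionally. A minimal illustration of the pitfall (stand-in names, not Taichi code):

// Why braces must be added along with the guard under an `if`.
#include <cassert>

static int malloc_calls = 0;
struct Guard {};                       // stand-in for the context guard
static void driver_malloc() { ++malloc_calls; }

int main() {
  int bufferSize = 0;                  // nothing to allocate
  if (bufferSize > 0) {
    Guard context_guard;               // scoped with the allocation
    driver_malloc();                   // only runs when bufferSize > 0
  }
  // Without the braces, `Guard context_guard;` alone would have been the
  // `if` body and driver_malloc() would have executed unconditionally.
  assert(malloc_calls == 0);
  return 0;
}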

taichi/python/export_lang.cpp (−8)

@@ -1100,14 +1100,6 @@ void export_lang(py::module &m) {
       }
     });

-#if defined(TI_WITH_CUDA)
-  m.def("pop_cuda_context",
-        []() { CUDADriver::get_instance().context_pop_current(NULL); });
-
-  m.def("push_cuda_context",
-        []() { CUDAContext::get_instance().make_current(); });
-#endif
-
  // Type system

  py::class_<Type>(m, "Type").def("to_string", &Type::to_string);

taichi/rhi/cuda/cuda_context.cpp (+6)

@@ -55,20 +55,26 @@ CUDAContext::CUDAContext()

 std::size_t CUDAContext::get_total_memory() {
   std::size_t ret, _;
+  this->make_current();
   driver_.mem_get_info(&_, &ret);
+  driver_.context_pop_current(nullptr);
   return ret;
 }

 std::size_t CUDAContext::get_free_memory() {
   std::size_t ret, _;
+  this->make_current();
   driver_.mem_get_info(&ret, &_);
+  driver_.context_pop_current(nullptr);
   return ret;
 }

 std::string CUDAContext::get_device_name() {
   constexpr uint32_t kMaxNameStringLength = 128;
   char name[kMaxNameStringLength];
+  this->make_current();
   driver_.device_get_name(name, kMaxNameStringLength /*=128*/, device_);
+  driver_.context_pop_current(nullptr);
   std::string str(name);
   return str;
 }

taichi/rhi/cuda/cuda_context.h (+17 −5)

@@ -78,21 +78,33 @@ class CUDAContext {
     void *new_ctx_;

    public:
-    ContextGuard(CUDAContext *new_ctx) : old_ctx_(nullptr), new_ctx_(new_ctx) {
+    ContextGuard(CUDAContext *new_ctx)
+        : old_ctx_(nullptr), new_ctx_(new_ctx->context_) {
       CUDADriver::get_instance().context_get_current(&old_ctx_);
-      if (old_ctx_ != new_ctx)
+      if (old_ctx_ != new_ctx_) {
         new_ctx->make_current();
+      }
     }

     ~ContextGuard() {
-      if (old_ctx_ != new_ctx_) {
+      // Always pop the current context in order to interoperate with
+      // third-party libs. However, this breaks with nested guards; if any
+      // context-related errors are encountered, use this logic instead:
+      // if (old_ctx_ != new_ctx_) {
+      //   CUDADriver::get_instance().context_set_current(old_ctx_);
+      // }
+      void *pop_ctx = nullptr;
+      CUDADriver::get_instance().context_pop_current(&pop_ctx);
+      TI_ASSERT(pop_ctx == new_ctx_);
+
+      if (old_ctx_ != nullptr) {
         CUDADriver::get_instance().context_set_current(old_ctx_);
       }
     }
   };

-  ContextGuard get_guard() {
-    return ContextGuard(this);
+  std::unique_ptr<ContextGuard> get_guard() {
+    return std::make_unique<ContextGuard>(this);
   }

   std::unique_lock<std::mutex> get_lock_guard() {
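The destructor's ordering is the subtle part of this change: it unconditionally pops the context (so nothing Taichi pushed lingers on the thread's stack when third-party code runs next) and only then restores the caller's previous context, at the cost of the nested-guard caveat noted in the comment. A toy model of that sequence, with a plain stack standing in for the driver's per-thread context stack (all names illustrative; `set_current` replacing the top entry is an assumption about cuCtxSetCurrent's behavior):

// Toy model of the ContextGuard push/pop/restore sequence.
#include <cassert>
#include <vector>

using Ctx = int;                      // 0 means "no context"
static std::vector<Ctx> ctx_stack;

static Ctx current() { return ctx_stack.empty() ? 0 : ctx_stack.back(); }
static void push(Ctx c) { ctx_stack.push_back(c); }
static Ctx pop() {
  Ctx c = ctx_stack.back();
  ctx_stack.pop_back();
  return c;
}
static void set_current(Ctx c) {      // assumed to replace the stack top
  if (!ctx_stack.empty()) ctx_stack.pop_back();
  ctx_stack.push_back(c);
}

class Guard {
  Ctx old_, new_;

 public:
  explicit Guard(Ctx c) : old_(current()), new_(c) {
    if (old_ != new_) push(new_);     // mirrors: new_ctx->make_current()
  }
  ~Guard() {
    Ctx popped = pop();               // always pop, for 3rd-party interop
    assert(popped == new_);           // mirrors: TI_ASSERT(pop_ctx == new_ctx_)
    (void)popped;
    if (old_ != 0) set_current(old_); // restore the caller's context
  }
};

int main() {
  push(7);                            // some third-party context is current
  {
    Guard g(42);                      // Taichi's context goes on top
    assert(current() == 42);
  }                                   // 42 popped, then 7 restored
  assert(current() == 7);
  return 0;
}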

taichi/rhi/cuda/cuda_device.cpp (+17 −1)

@@ -14,6 +14,8 @@ CudaDevice::AllocInfo CudaDevice::get_alloc_info(
 DeviceAllocation CudaDevice::allocate_memory(const AllocParams &params) {
   AllocInfo info;

+  auto context_guard = CUDAContext::get_instance().get_guard();
+
   if (params.host_read || params.host_write) {
     CUDADriver::get_instance().malloc_managed(&info.ptr, params.size,
                                               CU_MEM_ATTACH_GLOBAL);
@@ -45,7 +47,10 @@ DeviceAllocation CudaDevice::allocate_memory_runtime(
       caching_allocator_ = std::make_unique<CudaCachingAllocator>(this);
     }
     info.ptr = caching_allocator_->allocate(params);
-    CUDADriver::get_instance().memset((void *)info.ptr, 0, info.size);
+    {
+      auto context_guard = CUDAContext::get_instance().get_guard();
+      CUDADriver::get_instance().memset((void *)info.ptr, 0, info.size);
+    }
   } else {
     info.ptr = allocate_llvm_runtime_memory_jit(params);
   }
@@ -63,6 +68,9 @@ DeviceAllocation CudaDevice::allocate_memory_runtime(

 void CudaDevice::dealloc_memory(DeviceAllocation handle) {
   validate_device_alloc(handle);
+
+  auto context_guard = CUDAContext::get_instance().get_guard();
+
   AllocInfo &info = allocations_[handle.alloc_id];
   if (info.ptr == nullptr) {
     TI_ERROR("the DeviceAllocation is already deallocated");
@@ -80,6 +88,8 @@ void CudaDevice::dealloc_memory(DeviceAllocation handle) {
 }

 void *CudaDevice::map(DeviceAllocation alloc) {
+  auto context_guard = CUDAContext::get_instance().get_guard();
+
   AllocInfo &info = allocations_[alloc.alloc_id];
   size_t size = info.size;
   info.mapped = new char[size];
@@ -89,6 +99,8 @@ void *CudaDevice::map(DeviceAllocation alloc) {
 }

 void CudaDevice::unmap(DeviceAllocation alloc) {
+  auto context_guard = CUDAContext::get_instance().get_guard();
+
   AllocInfo &info = allocations_[alloc.alloc_id];
   CUDADriver::get_instance().memcpy_host_to_device(info.ptr, info.mapped,
                                                    info.size);
@@ -97,6 +109,8 @@ void CudaDevice::unmap(DeviceAllocation alloc) {
 }

 void CudaDevice::memcpy_internal(DevicePtr dst, DevicePtr src, uint64_t size) {
+  auto context_guard = CUDAContext::get_instance().get_guard();
+
   void *dst_ptr =
       static_cast<char *>(allocations_[dst.alloc_id].ptr) + dst.offset;
   void *src_ptr =
@@ -119,6 +133,8 @@ DeviceAllocation CudaDevice::import_memory(void *ptr, size_t size) {
 }

 uint64 CudaDevice::fetch_result_uint64(int i, uint64 *result_buffer) {
+  auto context_guard = CUDAContext::get_instance().get_guard();
+
   CUDADriver::get_instance().stream_synchronize(nullptr);
   uint64 ret;
   CUDADriver::get_instance().memcpy_device_to_host(&ret, result_buffer + i,

taichi/rhi/interop/vulkan_cuda_interop.cpp (+4 −1)

@@ -85,7 +85,7 @@ CUexternalMemory import_vk_memory_object_from_handle(HANDLE handle,
   if (is_dedicated) {
     desc.flags |= CUDA_EXTERNAL_MEMORY_DEDICATED;
   }
-
+  auto context_guard = CUDAContext::get_instance().get_guard();
   CUDADriver::get_instance().import_external_memory(&ext_mem, &desc);
   return ext_mem;
 }
@@ -104,6 +104,7 @@ CUexternalMemory import_vk_memory_object_from_handle(int fd,
   if (is_dedicated) {
     desc.flags |= CUDA_EXTERNAL_MEMORY_DEDICATED;
   }
+  auto context_guard = CUDAContext::get_instance().get_guard();
   CUDADriver::get_instance().import_external_memory(&ext_mem, &desc);
   return ext_mem;
 }
@@ -120,6 +121,7 @@ void *map_buffer_onto_external_memory(CUexternalMemory ext_mem,
   desc.offset = offset;
   desc.size = size;

+  auto context_guard = CUDAContext::get_instance().get_guard();
   CUDADriver::get_instance().external_memory_get_mapped_buffer(
       (CUdeviceptr *)&ptr, ext_mem, &desc);
   return ptr;
@@ -137,6 +139,7 @@ void *get_cuda_memory_pointer(VkDeviceMemory mem,
 }

 void cuda_memcpy(void *dst, void *src, size_t size) {
+  auto context_guard = CUDAContext::get_instance().get_guard();
   CUDADriver::get_instance().memcpy_device_to_device(dst, src, size);
 }


taichi/runtime/cuda/jit_cuda.cpp (+1 −3)

@@ -13,9 +13,7 @@ JITModule *JITSessionCUDA ::add_module(std::unique_ptr<llvm::Module> M,
                              "module NVPTX");
     writer.write(ptx);
   }
-  // TODO: figure out why using the guard leads to wrong tests results
-  // auto context_guard = CUDAContext::get_instance().get_guard();
-  CUDAContext::get_instance().make_current();
+  auto context_guard = CUDAContext::get_instance().get_guard();
  // Create module for object
  void *cuda_module;
  TI_TRACE("PTX size: {:.2f}KB", ptx.size() / 1024.0);

taichi/runtime/cuda/jit_cuda.h (+1 −3)

@@ -47,9 +47,7 @@ class JITModuleCUDA : public JITModule {
   }

   void *lookup_function(const std::string &name) override {
-    // TODO: figure out why using the guard leads to wrong tests results
-    // auto context_guard = CUDAContext::get_instance().get_guard();
-    CUDAContext::get_instance().make_current();
+    auto context_guard = CUDAContext::get_instance().get_guard();
     void *func = nullptr;
     auto t = Time::get_time();
     auto err = CUDADriver::get_instance().module_get_function.call_with_warning(

taichi/runtime/llvm/llvm_runtime_executor.cpp (+15 −13)

@@ -176,7 +176,7 @@ void LlvmRuntimeExecutor::print_list_manager_info(void *list_manager,
 void LlvmRuntimeExecutor::synchronize() {
   if (config_->arch == Arch::cuda) {
 #if defined(TI_WITH_CUDA)
-    CUDAContext::get_instance().make_current();
+    auto context_guard = CUDAContext::get_instance().get_guard();
     CUDADriver::get_instance().stream_synchronize(nullptr);
 #else
     TI_ERROR("No CUDA support");
@@ -191,7 +191,7 @@ uint64 LlvmRuntimeExecutor::fetch_result_uint64(int i, uint64 *result_buffer) {
   uint64 ret;
   if (config_->arch == Arch::cuda) {
 #if defined(TI_WITH_CUDA)
-    CUDAContext::get_instance().make_current();
+    auto context_guard = CUDAContext::get_instance().get_guard();
     CUDADriver::get_instance().memcpy_device_to_host(&ret, result_buffer + i,
                                                      sizeof(uint64));
 #else
@@ -373,6 +373,7 @@ void LlvmRuntimeExecutor::initialize_llvm_runtime_snodes(
                                 result_buffer);
   if (config_->arch == Arch::cuda) {
 #if defined(TI_WITH_CUDA)
+    auto context_guard = CUDAContext::get_instance().get_guard();
     CUDADriver::get_instance().memset(root_buffer, 0, rounded_size);
 #else
     TI_NOT_IMPLEMENTED
@@ -476,6 +477,7 @@ void LlvmRuntimeExecutor::fill_ndarray(const DeviceAllocation &alloc,
   auto ptr = get_ndarray_alloc_info_ptr(alloc);
   if (config_->arch == Arch::cuda) {
 #if defined(TI_WITH_CUDA)
+    auto cuda_context = CUDAContext::get_instance().get_guard();
     CUDADriver::get_instance().memsetd32((void *)ptr, data, size);
 #else
     TI_NOT_IMPLEMENTED
@@ -515,9 +517,12 @@ void LlvmRuntimeExecutor::materialize_runtime(MemoryPool *memory_pool,
   TaichiLLVMContext *tlctx = nullptr;
   if (config_->arch == Arch::cuda) {
 #if defined(TI_WITH_CUDA)
-    CUDADriver::get_instance().malloc(
-        (void **)result_buffer_ptr,
-        sizeof(uint64) * taichi_result_buffer_entries);
+    {
+      auto context_guard = CUDAContext::get_instance().get_guard();
+      CUDADriver::get_instance().malloc(
+          (void **)result_buffer_ptr,
+          sizeof(uint64) * taichi_result_buffer_entries);
+    }
     const auto total_mem = runtime_mem_info_->get_total_memory();
     if (config_->device_memory_fraction == 0) {
       TI_ASSERT(config_->device_memory_GB > 0);
@@ -537,9 +542,11 @@ void LlvmRuntimeExecutor::materialize_runtime(MemoryPool *memory_pool,
     cuda::CudaDevice::AllocInfo preallocated_device_buffer_alloc_info =
         cuda_device()->get_alloc_info(preallocated_device_buffer_alloc_);
     preallocated_device_buffer_ = preallocated_device_buffer_alloc_info.ptr;
-
-    CUDADriver::get_instance().memset(preallocated_device_buffer_, 0,
-                                      prealloc_size);
+    {
+      auto context_guard = CUDAContext::get_instance().get_guard();
+      CUDADriver::get_instance().memset(preallocated_device_buffer_, 0,
+                                        prealloc_size);
+    }
     tlctx = llvm_context_device_.get();
 #else
     TI_NOT_IMPLEMENTED
@@ -612,11 +619,6 @@ void LlvmRuntimeExecutor::materialize_runtime(MemoryPool *memory_pool,
         "LLVMRuntime_set_profiler_stop", llvm_runtime_,
         (void *)&KernelProfilerBase::profiler_stop);
   }
-#if defined(TI_WITH_CUDA)
-  if (config_->arch == Arch::cuda) {
-    CUDADriver::get_instance().context_pop_current(nullptr);
-  }
-#endif
 }

 void LlvmRuntimeExecutor::destroy_snode_tree(SNodeTree *snode_tree) {

taichi/runtime/program_impls/llvm/llvm_program.cpp (+2)

@@ -13,6 +13,7 @@
 #if defined(TI_WITH_CUDA)
 #include "taichi/runtime/cuda/aot_module_builder_impl.h"
 #include "taichi/codegen/cuda/codegen_cuda.h"
+#include "taichi/rhi/cuda/cuda_context.h"
 #endif

 #if defined(TI_WITH_DX12)
@@ -69,6 +70,7 @@ std::unique_ptr<StructCompiler> LlvmProgramImpl::compile_snode_tree_types_impl(
         Arch::dx12, this, std::move(device_module), tree->id());
   } else {
     TI_ASSERT(config->arch == Arch::cuda);
+    auto context_guard = CUDAContext::get_instance().get_guard();
     auto device_module = clone_struct_compiler_initial_context(
         has_multiple_snode_trees, runtime_exec_->llvm_context_device_.get());
     struct_compiler = std::make_unique<StructCompilerLLVM>(
