Commit d8c3845

Ilia Cherniavskii authored and facebook-github-bot committed on Jun 23, 2020
Destroy CUDA events after profiling (pytorch#39962)
Summary:
Pull Request resolved: pytorch#39962

Adds a simple reference-counted wrapper around the profiler's CUDA event and destroys the underlying CUDA event once the last copy of the wrapper is destroyed.

Test Plan: CI CUDA profiler tests

Differential Revision: D22027092

Pulled By: ilia-cher

fbshipit-source-id: e0810388aa60b2291eb010896e13af1fad92e472
1 parent a54bb4e commit d8c3845

File tree

6 files changed: +95 -37 lines changed
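The core of the change, applied in torch/csrc/autograd/profiler_cuda.cpp below, is to manage the profiler's CUDA events through a reference-counted handle. Here is a minimal standalone sketch of that technique, assuming illustrative names (CudaEventHandle, make_ref_counted_event) and omitting error checking; it is not the PyTorch source:

// Minimal sketch of the ref-counting approach (illustrative, not PyTorch source;
// error checking omitted). A raw cudaEvent_t is wrapped in a std::shared_ptr
// whose deleter calls cudaEventDestroy, so the event is released exactly once,
// when the last copy of the handle is destroyed.
#include <cuda_runtime.h>
#include <memory>

using CudaEventHandle = std::shared_ptr<CUevent_st>;   // cudaEvent_t is CUevent_st*

CudaEventHandle make_ref_counted_event() {
  cudaEvent_t raw = nullptr;
  cudaEventCreate(&raw);                                // allocate the CUDA event
  return CudaEventHandle(raw, [](CUevent_st* ptr) {
    cudaEventDestroy(ptr);                              // runs when the last copy goes away
  });
}

int main() {
  CudaEventHandle a = make_ref_counted_event();
  CudaEventHandle b = a;                                // copying only bumps the ref count
  cudaEventRecord(a.get(), /*stream=*/0);               // record on the default stream
  a.reset();                                            // event still alive, b holds it
  b.reset();                                            // last reference dropped -> cudaEventDestroy
  return 0;
}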
 
@@ -1,10 +1,9 @@
+from functools import partial
 import itertools
 import statistics
 import timeit
 import torch

-profiling_enabled = None
-profiling_tensor_size = None
 TENSOR_SIZES = [1, 32, 128, 256, 512]
 INTERNAL_ITER = 256
 PARALLEL_TASKS_NUM = 4
@@ -16,13 +15,12 @@ def loop_workload(x):
     return x

 traced_loop_workload = None
-def run_profiler_benchmark_loop():
-    x = torch.rand(profiling_tensor_size, profiling_tensor_size)
+def run_profiler_benchmark_loop(input_x, use_cuda, profiling_enabled):
     if profiling_enabled:
-        with torch.autograd.profiler.profile() as prof:
-            traced_loop_workload(x)
+        with torch.autograd.profiler.profile(use_cuda=use_cuda) as prof:
+            traced_loop_workload(input_x)
     else:
-        traced_loop_workload(x)
+        traced_loop_workload(input_x)

 def parallel_task(x):
     for i in range(int(INTERNAL_ITER / PARALLEL_TASKS_NUM)):
@@ -38,40 +36,49 @@ def parallel_workload(x):
     return x

 traced_parallel_workload = None
-def run_profiler_benchmark_parallel():
-    x = torch.rand(profiling_tensor_size, profiling_tensor_size)
+def run_profiler_benchmark_parallel(input_x, use_cuda, profiling_enabled):
     if profiling_enabled:
-        with torch.autograd.profiler.profile() as prof:
-            traced_parallel_workload(x)
+        with torch.autograd.profiler.profile(use_cuda=use_cuda) as prof:
+            traced_parallel_workload(input_x)
     else:
-        traced_parallel_workload(x)
+        traced_parallel_workload(input_x)

 if __name__ == '__main__':
     for workload_name in ["loop", "parallel"]:
         print("Payload: {}; {} iterations, N = {}\n".format(
             workload_name, INTERNAL_ITER, N))
-        for params in itertools.product(TENSOR_SIZES, [False, True]):
-            profiling_tensor_size = params[0]
-            profiling_enabled = params[1]
+        for params in itertools.product([False, True], TENSOR_SIZES, [False, True]):
+            use_cuda = params[0]
+            profiling_tensor_size = params[1]
+            profiling_enabled = params[2]

-            print("Profiling {}, tensor size {}x{}".format(
-                "enabled " if profiling_enabled else "disabled",
-                profiling_tensor_size, profiling_tensor_size))
+            if (use_cuda and not torch.cuda.is_available()):
+                continue

-            x = torch.rand(profiling_tensor_size, profiling_tensor_size)
+            print("Profiling {}, tensor size {}x{}, use cuda: {}".format(
+                "enabled" if profiling_enabled else "disabled",
+                profiling_tensor_size, profiling_tensor_size, use_cuda))
+
+            input_x = torch.rand(profiling_tensor_size, profiling_tensor_size)
+            if use_cuda:
+                input_x = input_x.cuda()
             workload = None
             if workload_name == "loop":
-                workload = run_profiler_benchmark_loop
-                traced_loop_workload = torch.jit.trace(loop_workload, x)
+                workload = partial(
+                    run_profiler_benchmark_loop, input_x, use_cuda, profiling_enabled)
+                traced_loop_workload = torch.jit.trace(loop_workload, input_x)
             elif workload_name == "parallel":
-                workload = run_profiler_benchmark_parallel
+                workload = partial(
+                    run_profiler_benchmark_parallel, input_x, use_cuda, profiling_enabled)
                 traced_parallel_workload = torch.jit.trace(
-                    parallel_workload, x)
+                    parallel_workload, input_x)

             runtimes = timeit.repeat(workload, repeat=N, number=1)
             avg_time = statistics.mean(runtimes) * 1000.0
             stddev_time = statistics.stdev(runtimes) * 1000.0
             print("\tavg. time: {:.3f} ms, stddev: {:.3f} ms".format(
                 avg_time, stddev_time))
-            print("\ttime per iteration: {:.3f} ms\n".format(
-                avg_time / INTERNAL_ITER))
+            if workload_name == "loop":
+                print("\ttime per iteration: {:.3f} ms".format(
+                    avg_time / INTERNAL_ITER))
+            print()

‎test/run_test.py

+1

@@ -70,6 +70,7 @@
     'test_jit_fuser_te',
     'test_tensorexpr',
     'test_openmp',
+    'test_profiler',
     'distributed/nn/jit/test_instantiator',
     'distributed/nn/api/test_remote_module_spawn',
     'distributed/rpc/faulty_agent/test_dist_autograd_spawn',

‎test/test_profiler.py

+45

@@ -0,0 +1,45 @@
+import collections
+import gc
+import unittest
+
+import torch
+from torch.testing._internal.common_utils import (
+    TestCase, run_tests, TEST_WITH_ASAN)
+from torch.autograd.profiler import profile
+
+try:
+    import psutil
+    HAS_PSUTIL = True
+except ImportError:
+    HAS_PSUTIL = False
+
+
+@unittest.skipIf(not HAS_PSUTIL, "Requires psutil to run")
+@unittest.skipIf(TEST_WITH_ASAN, "Cannot test with ASAN")
+@unittest.skipIf(not torch.cuda.is_available(), "CUDA is required")
+class TestProfiler_cuda(TestCase):
+    def test_mem_leak(self):
+        """Checks that there's no memory leak when using profiler with CUDA
+        """
+        t = torch.rand(1, 1).cuda()
+        p = psutil.Process()
+        last_rss = collections.deque(maxlen=5)
+        for outer_idx in range(10):
+            with profile(use_cuda=True):
+                for _ in range(1024):
+                    t = torch.mm(t, t)
+
+            gc.collect()
+            torch.cuda.empty_cache()
+            last_rss.append(p.memory_info().rss)
+
+        max_diff = -1
+        for idx in range(1, len(last_rss)):
+            max_diff = max(max_diff, last_rss[idx] - last_rss[idx - 1])
+
+        # with CUDA events leaking the increase in memory was ~7 MB,
+        # using much smaller threshold but not zero to reduce flakiness
+        self.assertTrue(max_diff < 100 * 1024)
+
+if __name__ == '__main__':
+    run_tests()

‎torch/csrc/autograd/profiler.cpp

+3 -3

@@ -264,7 +264,7 @@ struct ProfilerThreadLocalState
           thread_id,
           config_.state == ProfilerState::CUDA);
       evt.updateMemoryStats(alloc_size, device);
-      getEventList(thread_id).record(evt);
+      getEventList(thread_id).record(std::move(evt));
     }
   }

@@ -554,7 +554,7 @@ at::IValue Event::toIValue() const {
   return at::IValue(eventIValueList);
 }

-double Event::cuda_elapsed_us(const Event & e) const {
+double Event::cuda_elapsed_us(const Event& e) const {
   TORCH_CHECK(e.has_cuda() && has_cuda(), "Events were not recorded for CUDA");
   TORCH_CHECK(
       e.device() == device(),
@@ -565,7 +565,7 @@ double Event::cuda_elapsed_us(const Event & e) const {
     TORCH_INTERNAL_ASSERT(cuda_us_ >= 0 && e.cuda_us_ >= 0);
     return static_cast<double>(e.cuda_us_ - cuda_us_);
   }
-  return cuda_stubs->elapsed(cuda_event, e.cuda_event);
+  return cuda_stubs->elapsed(&cuda_event, &e.cuda_event);
 }

 CUDAStubs::~CUDAStubs() = default;

‎torch/csrc/autograd/profiler.h

+4 -3

@@ -20,7 +20,8 @@

 #include <ATen/record_function.h>

-typedef struct CUevent_st* CUDAEventStub;
+struct CUevent_st;
+typedef std::shared_ptr<CUevent_st> CUDAEventStub;

 namespace torch { namespace autograd {

@@ -32,7 +33,7 @@ struct TORCH_API CUDAStubs {
   virtual void record(int* device, CUDAEventStub* event, int64_t* cpu_ns) {
     fail();
   }
-  virtual float elapsed(CUDAEventStub event, CUDAEventStub event2) {
+  virtual float elapsed(const CUDAEventStub* event, const CUDAEventStub* event2) {
     fail();
     return 0.f;
   }
@@ -291,7 +292,7 @@ struct TORCH_API Event final {
   int64_t cpu_memory_usage_ = 0;
   int64_t cuda_memory_usage_ = 0;
   int device_ = -1;
-  struct CUevent_st* cuda_event = nullptr;
+  CUDAEventStub cuda_event = nullptr;
   int node_id_ = 0;
   bool is_remote_ = false;
   int64_t cuda_us_ = -1;
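A note on the header change: std::shared_ptr can be declared over a forward-declared (incomplete) type, and the type only needs to be complete where the owning pointer is constructed with its deleter, which here happens in profiler_cuda.cpp where the CUDA headers are included. A small sketch of the same pattern, using made-up names (Widget, make_widget):

// Sketch with hypothetical names: the "header" part only sees a forward
// declaration, mirroring how profiler.h forward-declares CUevent_st.
#include <memory>

struct Widget;                                   // incomplete type, like CUevent_st
using WidgetHandle = std::shared_ptr<Widget>;    // fine: shared_ptr allows incomplete T
WidgetHandle make_widget();

// The "source" part defines the type and supplies the deleter, mirroring
// profiler_cuda.cpp, which is the only place that calls cudaEventDestroy.
struct Widget { int value = 0; };

WidgetHandle make_widget() {
  return WidgetHandle(new Widget(), [](Widget* w) { delete w; });
}

int main() {
  WidgetHandle h = make_widget();   // deleter was type-erased at construction
  return 0;                         // h released here via the stored deleter
}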

‎torch/csrc/autograd/profiler_cuda.cpp

+10 -6

@@ -34,16 +34,20 @@ static inline void cudaCheck(cudaError_t result, const char * file, int line) {
 struct CUDAMethods : public CUDAStubs {
   void record(int* device, CUDAEventStub* event, int64_t* cpu_ns) override {
     TORCH_CUDA_CHECK(cudaGetDevice(device));
-    TORCH_CUDA_CHECK(cudaEventCreate(event));
+    CUevent_st* cuda_event_ptr;
+    TORCH_CUDA_CHECK(cudaEventCreate(&cuda_event_ptr));
+    *event = std::shared_ptr<CUevent_st>(cuda_event_ptr, [](CUevent_st* ptr) {
+      TORCH_CUDA_CHECK(cudaEventDestroy(ptr));
+    });
     auto stream = at::cuda::getCurrentCUDAStream();
     *cpu_ns = getTime();
-    TORCH_CUDA_CHECK(cudaEventRecord(*event, stream));
+    TORCH_CUDA_CHECK(cudaEventRecord(cuda_event_ptr, stream));
   }
-  float elapsed(CUDAEventStub event, CUDAEventStub event2) override {
-    TORCH_CUDA_CHECK(cudaEventSynchronize(event));
-    TORCH_CUDA_CHECK(cudaEventSynchronize(event2));
+  float elapsed(const CUDAEventStub* event, const CUDAEventStub* event2) override {
+    TORCH_CUDA_CHECK(cudaEventSynchronize(event->get()));
+    TORCH_CUDA_CHECK(cudaEventSynchronize(event2->get()));
     float ms;
-    TORCH_CUDA_CHECK(cudaEventElapsedTime(&ms, event, event2));
+    TORCH_CUDA_CHECK(cudaEventElapsedTime(&ms, event->get(), event2->get()));
     return ms*1000.0;
   }
   void nvtxMarkA(const char* name) override {
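To see the new elapsed() path end to end, here is a hedged standalone sketch (illustrative record_event helper, error checking omitted) that records two ref-counted events and reads back the elapsed time the same way CUDAMethods::elapsed does:

// Illustrative sketch, not PyTorch source: create two shared_ptr-managed events,
// synchronize, and measure elapsed time; both events are destroyed automatically.
#include <cuda_runtime.h>
#include <cstdio>
#include <memory>

using CudaEventHandle = std::shared_ptr<CUevent_st>;

// Create an event whose lifetime is managed by the shared_ptr deleter.
CudaEventHandle record_event(cudaStream_t stream) {
  cudaEvent_t raw = nullptr;
  cudaEventCreate(&raw);
  cudaEventRecord(raw, stream);
  return CudaEventHandle(raw, [](CUevent_st* ptr) { cudaEventDestroy(ptr); });
}

int main() {
  CudaEventHandle start = record_event(0);     // default stream
  CudaEventHandle stop = record_event(0);
  cudaEventSynchronize(start.get());           // same calls as CUDAMethods::elapsed
  cudaEventSynchronize(stop.get());
  float ms = 0.f;
  cudaEventElapsedTime(&ms, start.get(), stop.get());
  std::printf("elapsed: %.3f ms\n", ms);       // the profiler reports ms * 1000.0 (microseconds)
  return 0;                                    // deleters run here; no leaked cudaEvent_t
}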

0 commit comments
