Skip to content

Commit ffe0c1a

Browse files
zasdfgbnmfacebook-github-bot
authored andcommittedDec 15, 2019
Make test_torch.py pass cuda-memcheck (pytorch#29243)
Summary: Make the following changes: - When there are more than 10k errors, cuda-memcheck only shows 10k errors, in this case we shouldn't raise an Exception - Add UNDER_CUDA_MEMCHECK environment to allow disabling `pin_memory` tests when running cuda-memcheck. - Add a `--ci` command option, when turned on, then this script would run output to stdout instead of writing a file, and exit with an error if cuda-memcheck fails - Add a `--nohang` command option. When turned on, then hang would be treated as pass instead of error - Do simple filtering on the test to run: if `'cpu'` in the test name but not `'cuda'` is not in the test name - Add `--split` and `--rank` to allowing splitting the work (NVIDIA CI has a limitation of 3 hours, we have to split the work to satisfy this limitation) - The error summary could be `ERROR SUMMARY: 1 error`, or `ERROR SUMMARY: 2 errors`, the tail could be `error` or `errors`, it is not of the same length. The script is fixed to handle this case. - Ignore errors from `cufft` Pull Request resolved: pytorch#29243 Differential Revision: D18941701 Pulled By: mruberry fbshipit-source-id: 2048428f32b66ef50c67444c03ce4dd9491179d2
1 parent 701e05d commit ffe0c1a

File tree

4 files changed

+60
-11
lines changed

4 files changed

+60
-11
lines changed
 

‎test/common_device_type.py

+3
Original file line numberDiff line numberDiff line change
@@ -2,6 +2,7 @@
22
import threading
33
from functools import wraps
44
import unittest
5+
import os
56
import torch
67
from common_utils import TestCase, TEST_WITH_ROCM, TEST_MKL, \
78
skipCUDANonDefaultStreamIf
@@ -247,6 +248,8 @@ def setUpClass(cls):
247248
if torch.cuda.is_available():
248249
device_type_test_bases.append(CUDATestBase)
249250

251+
PYTORCH_CUDA_MEMCHECK = os.getenv('PYTORCH_CUDA_MEMCHECK', '0') == '1'
252+
250253

251254
# Adds 'instantiated' device-specific test cases to the given scope.
252255
# The tests in these test cases are derived from the generic tests in

‎test/scripts/cuda_memcheck_common.py

+10-5
Original file line numberDiff line numberDiff line change
@@ -9,15 +9,20 @@ class ParseError(Exception):
99
class Report:
1010
"""A report is a container of errors, and a summary on how many errors are found"""
1111

12-
HEAD = 'ERROR SUMMARY: '
13-
TAIL = ' errors'
14-
1512
def __init__(self, text, errors):
13+
# text is something like
14+
# ERROR SUMMARY: 1 error
15+
# or
16+
# ERROR SUMMARY: 2 errors
1617
self.text = text
17-
self.num_errors = int(text[len(self.HEAD):len(text) - len(self.TAIL)])
18+
self.num_errors = int(text.strip().split()[2])
1819
self.errors = errors
1920
if len(errors) != self.num_errors:
20-
raise ParseError("Number of errors does not match")
21+
if len(errors) == 10000 and self.num_errors > 10000:
22+
# When there are more than 10k errors, cuda-memcheck only display 10k
23+
self.num_errors = 10000
24+
else:
25+
raise ParseError("Number of errors does not match")
2126

2227

2328
class Error:

‎test/scripts/run_cuda_memcheck.py

+44-5
Original file line numberDiff line numberDiff line change
@@ -18,7 +18,8 @@
1818
import argparse
1919
import subprocess
2020
import tqdm
21-
import re
21+
import os
22+
import sys
2223
import cuda_memcheck_common as cmc
2324

2425
ALL_TESTS = []
@@ -35,6 +36,13 @@
3536
help='Number of processes running tests, default to number of cores in the system')
3637
parser.add_argument('--gpus', default='all',
3738
help='GPU assignments for each process, it could be "all", or : separated list like "1,2:3,4:5,6"')
39+
parser.add_argument('--ci', action='store_true',
40+
help='Whether this script is executed in CI. When executed inside a CI, this script fails when '
41+
'an error is detected. Also, it will not show tqdm progress bar, but directly print the error'
42+
'to stdout instead.')
43+
parser.add_argument('--nohang', action='store_true', help='Treat timeout as success')
44+
parser.add_argument('--split', type=int, default=1, help='Split the job into pieces')
45+
parser.add_argument('--rank', type=int, default=0, help='Which piece this process should pick')
3846
args = parser.parse_args()
3947

4048
# Filters that ignores cublas/cudnn errors
@@ -48,10 +56,13 @@ def is_ignored_only(output):
4856
return False
4957
count_ignored_errors = 0
5058
for e in report.errors:
51-
if 'libcublas' in ''.join(e.stack) or 'libcudnn' in ''.join(e.stack):
59+
if 'libcublas' in ''.join(e.stack) or 'libcudnn' in ''.join(e.stack) or 'libcufft' in ''.join(e.stack):
5260
count_ignored_errors += 1
5361
return count_ignored_errors == report.num_errors
5462

63+
# Set environment PYTORCH_CUDA_MEMCHECK=1 to allow skipping some tests
64+
os.environ['PYTORCH_CUDA_MEMCHECK'] = '1'
65+
5566
# Discover tests:
5667
# To get a list of tests, run:
5768
# pytest --setup-only test/test_torch.py
@@ -66,6 +77,21 @@ def is_ignored_only(output):
6677
line = line.replace('::', '.')
6778
ALL_TESTS.append(line)
6879

80+
# Do a simple filtering:
81+
# if 'cpu' or 'CPU' is in the name and 'cuda' or 'CUDA' is not in the name, then skip it
82+
def is_cpu_only(name):
83+
name = name.lower()
84+
return ('cpu' in name) and not ('cuda' in name)
85+
86+
ALL_TESTS = [x for x in ALL_TESTS if not is_cpu_only(x)]
87+
88+
# Split all tests into chunks, and only on the selected chunk
89+
ALL_TESTS.sort()
90+
chunk_size = (len(ALL_TESTS) + args.split - 1) // args.split
91+
start = chunk_size * args.rank
92+
end = chunk_size * (args.rank + 1)
93+
ALL_TESTS = ALL_TESTS[start:end]
94+
6995
# Run tests:
7096
# Since running cuda-memcheck on PyTorch unit tests is very slow, these tests must be run in parallel.
7197
# This is done by using the coroutine feature in new Python versions. A number of coroutines are created;
@@ -74,8 +100,17 @@ def is_ignored_only(output):
74100
# These subprocesses are balanced across different GPUs on the system by assigning one devices per process,
75101
# or as specified by the user
76102
progress = 0
77-
logfile = open('result.log', 'w')
78-
progressbar = tqdm.tqdm(total=len(ALL_TESTS))
103+
if not args.ci:
104+
logfile = open('result.log', 'w')
105+
progressbar = tqdm.tqdm(total=len(ALL_TESTS))
106+
else:
107+
logfile = sys.stdout
108+
109+
# create a fake progress bar that does not display anything
110+
class ProgressbarStub:
111+
def update(*args):
112+
return
113+
progressbar = ProgressbarStub()
79114

80115
async def run1(coroutine_id):
81116
global progress
@@ -97,6 +132,8 @@ async def run1(coroutine_id):
97132
except asyncio.TimeoutError:
98133
print('Timeout:', test, file=logfile)
99134
proc.kill()
135+
if args.ci and not args.nohang:
136+
sys.exit("Hang detected on cuda-memcheck")
100137
else:
101138
if proc.returncode == 0:
102139
print('Success:', test, file=logfile)
@@ -108,13 +145,15 @@ async def run1(coroutine_id):
108145
print('Fail:', test, file=logfile)
109146
print(stdout, file=logfile)
110147
print(stderr, file=logfile)
148+
if args.ci:
149+
sys.exit("Failure detected on cuda-memcheck")
111150
else:
112151
print('Ignored:', test, file=logfile)
113152
del proc
114153
progressbar.update(1)
115154

116155
async def main():
117-
tasks = [asyncio.create_task(run1(i)) for i in range(args.nproc)]
156+
tasks = [asyncio.ensure_future(run1(i)) for i in range(args.nproc)]
118157
for t in tasks:
119158
await t
120159

‎test/test_torch.py

+3-1
Original file line numberDiff line numberDiff line change
@@ -34,7 +34,7 @@
3434
IS_SANDCASTLE, load_tests, brute_pdist, brute_cdist, slowTest, \
3535
skipCUDANonDefaultStreamIf, skipCUDAMemoryLeakCheckIf
3636
from multiprocessing.reduction import ForkingPickler
37-
from common_device_type import instantiate_device_type_tests, \
37+
from common_device_type import instantiate_device_type_tests, PYTORCH_CUDA_MEMCHECK, \
3838
skipCPUIfNoLapack, skipCUDAIfNoMagma, skipCUDAIfRocm, onlyCUDA, onlyCPU, \
3939
dtypes, dtypesIfCUDA, deviceCountAtLeast, skipCUDAIf, precisionOverride
4040
import torch.backends.quantized
@@ -4929,6 +4929,7 @@ def test_empty_like(self):
49294929
self.assertEqual(torch.empty_like(a).shape, a.shape)
49304930
self.assertEqual(torch.empty_like(a).type(), a.type())
49314931

4932+
@unittest.skipIf(PYTORCH_CUDA_MEMCHECK, "is_pinned uses failure to detect pointer property")
49324933
def test_pin_memory(self):
49334934
x = torch.randn(3, 5)
49344935
self.assertFalse(x.is_pinned())
@@ -12673,6 +12674,7 @@ def test_dlpack_conversion(self, device):
1267312674
self.assertEqual(z, x)
1267412675

1267512676
@onlyCUDA
12677+
@unittest.skipIf(PYTORCH_CUDA_MEMCHECK, "is_pinned uses failure to detect pointer property")
1267612678
def test_pin_memory_from_constructor(self, device):
1267712679
def _get_like(t, **kwargs):
1267812680
return [

0 commit comments

Comments
 (0)
Please sign in to comment.