|
1 | 1 | import collections
|
2 | 2 | import unittest
|
3 | 3 |
|
4 |
| -import numpy as np |
5 | 4 | import torch
|
6 | 5 | from torch.testing._internal.common_utils import (
|
7 | 6 | TestCase, run_tests, TEST_WITH_ASAN)
|
@@ -43,15 +42,15 @@ def func(self, runs):
|
43 | 42 | return last_rss
|
44 | 43 |
|
45 | 44 | def func_rss(self, runs):
|
46 |
| - last_rss = self.func(runs) |
47 |
| - # Do a least-mean-squares fit of last_rss to a line |
48 |
| - poly = np.polynomial.Polynomial.fit( |
49 |
| - range(len(last_rss)), np.array(last_rss), 1) |
50 |
| - coefs = poly.convert().coef |
51 |
| - # The coefs are (b, m) for the line y = m * x + b that fits the data. |
52 |
| - # If m == 0 it will not be present. Assert it is missing or < 1000. |
53 |
| - self.assertTrue(len(coefs) < 2 or coefs[1] < 1000, |
54 |
| - msg='memory did not stabilize, {}'.format(str(list(last_rss)))) |
| 45 | + last_rss = list(self.func(runs)) |
| 46 | + # Check that the sequence is not strictly increasing |
| 47 | + is_increasing = True |
| 48 | + for idx in range(len(last_rss)): |
| 49 | + if idx == 0: |
| 50 | + continue |
| 51 | + is_increasing = is_increasing and (last_rss[idx] > last_rss[idx - 1]) |
| 52 | + self.assertTrue(not is_increasing, |
| 53 | + msg='memory usage is increasing, {}'.format(str(last_rss))) |
55 | 54 |
|
56 | 55 | def test_one_thread(self):
|
57 | 56 | """Make sure there is no memory leak with one thread: issue gh-32284
|
|
0 commit comments