Skip to content

Commit b2b8af9

Browse files
Mike Ruberryfacebook-github-bot
Mike Ruberry
authored andcommittedJul 16, 2020
Removes assertAlmostEqual (pytorch#41514)
Summary: This test function is confusing since our `assertEqual` behavior allows for tolerance to be specified, and this is a redundant mechanism. Pull Request resolved: pytorch#41514 Reviewed By: ngimel Differential Revision: D22569348 Pulled By: mruberry fbshipit-source-id: 2b2ff8aaa9625a51207941dfee8e07786181fe9f
1 parent 58244a9 commit b2b8af9

8 files changed

+112
-120
lines changed
 

‎test/quantization/test_workflow_module.py

+2-2
Original file line numberDiff line numberDiff line change
@@ -247,7 +247,7 @@ def test_per_tensor_observers(self, qdtype, qscheme, reduce_range):
247247
ref_scale = 0.0313725
248248
ref_zero_point = -128 if qdtype is torch.qint8 else 0
249249
self.assertEqual(qparams[1].item(), ref_zero_point)
250-
self.assertAlmostEqual(qparams[0].item(), ref_scale, delta=1e-5)
250+
self.assertEqual(qparams[0].item(), ref_scale, atol=1e-5, rtol=0)
251251
state_dict = myobs.state_dict()
252252
b = io.BytesIO()
253253
torch.save(state_dict, b)
@@ -474,7 +474,7 @@ def test_histogram_observer(self, qdtype, qscheme, reduce_range):
474474
ref_zero_point = -128 if qdtype is torch.qint8 else 0
475475

476476
self.assertEqual(qparams[1].item(), ref_zero_point)
477-
self.assertAlmostEqual(qparams[0].item(), ref_scale, delta=1e-5)
477+
self.assertEqual(qparams[0].item(), ref_scale, atol=1e-5, rtol=0)
478478
# Test for serializability
479479
state_dict = myobs.state_dict()
480480
b = io.BytesIO()

‎test/test_bundled_inputs.py

+2-2
Original file line numberDiff line numberDiff line change
@@ -70,8 +70,8 @@ def forward(self, arg):
7070
# This tensor is random, but with 100,000 trials,
7171
# mean and std had ranges of (-0.0154, 0.0144) and (0.9907, 1.0105).
7272
self.assertEqual(inflated[5][0].shape, (1 << 16,))
73-
self.assertAlmostEqual(inflated[5][0].mean().item(), 0, delta=0.025)
74-
self.assertAlmostEqual(inflated[5][0].std().item(), 1, delta=0.02)
73+
self.assertEqual(inflated[5][0].mean().item(), 0, atol=0.025, rtol=0)
74+
self.assertEqual(inflated[5][0].std().item(), 1, atol=0.02, rtol=0)
7575

7676

7777
def test_large_tensor_with_inflation(self):

‎test/test_distributions.py

+29-30
Original file line numberDiff line numberDiff line change
@@ -998,8 +998,7 @@ def test_binomial_log_prob(self):
998998
def ref_log_prob(idx, x, log_prob):
999999
p = probs.view(-1)[idx].item()
10001000
expected = scipy.stats.binom(total_count, p).logpmf(x)
1001-
self.assertAlmostEqual(log_prob, expected, places=3)
1002-
1001+
self.assertEqual(log_prob, expected, atol=1e-3, rtol=0)
10031002
self._check_log_prob(Binomial(total_count, probs), ref_log_prob)
10041003
logits = probs_to_logits(probs, is_binary=True)
10051004
self._check_log_prob(Binomial(total_count, logits=logits), ref_log_prob)
@@ -1023,7 +1022,7 @@ def test_binomial_log_prob_vectorized_count(self):
10231022
(torch.tensor([1, 2, 10]), torch.tensor([0., 1., 9.]))]:
10241023
log_prob = Binomial(total_count, probs).log_prob(sample)
10251024
expected = scipy.stats.binom(total_count.cpu().numpy(), probs.cpu().numpy()).logpmf(sample)
1026-
self.assertAlmostEqual(log_prob, expected, places=4)
1025+
self.assertEqual(log_prob, expected, atol=1e-4, rtol=0)
10271026

10281027
def test_binomial_enumerate_support(self):
10291028
examples = [
@@ -1037,11 +1036,11 @@ def test_binomial_extreme_vals(self):
10371036
total_count = 100
10381037
bin0 = Binomial(total_count, 0)
10391038
self.assertEqual(bin0.sample(), 0)
1040-
self.assertAlmostEqual(bin0.log_prob(torch.tensor([0.]))[0], 0, places=3)
1039+
self.assertEqual(bin0.log_prob(torch.tensor([0.]))[0], 0, atol=1e-3, rtol=0)
10411040
self.assertEqual(float(bin0.log_prob(torch.tensor([1.])).exp()), 0)
10421041
bin1 = Binomial(total_count, 1)
10431042
self.assertEqual(bin1.sample(), total_count)
1044-
self.assertAlmostEqual(bin1.log_prob(torch.tensor([float(total_count)]))[0], 0, places=3)
1043+
self.assertEqual(bin1.log_prob(torch.tensor([float(total_count)]))[0], 0, atol=1e-3, rtol=0)
10451044
self.assertEqual(float(bin1.log_prob(torch.tensor([float(total_count - 1)])).exp()), 0)
10461045
zero_counts = torch.zeros(torch.Size((2, 2)))
10471046
bin2 = Binomial(zero_counts, 1)
@@ -1076,7 +1075,7 @@ def test_negative_binomial_log_prob(self):
10761075
def ref_log_prob(idx, x, log_prob):
10771076
p = probs.view(-1)[idx].item()
10781077
expected = scipy.stats.nbinom(total_count, 1 - p).logpmf(x)
1079-
self.assertAlmostEqual(log_prob, expected, places=3)
1078+
self.assertEqual(log_prob, expected, atol=1e-3, rtol=0)
10801079

10811080
self._check_log_prob(NegativeBinomial(total_count, probs), ref_log_prob)
10821081
logits = probs_to_logits(probs, is_binary=True)
@@ -1089,7 +1088,7 @@ def test_negative_binomial_log_prob_vectorized_count(self):
10891088
(torch.tensor([1, 2, 10]), torch.tensor([0., 1., 9.]))]:
10901089
log_prob = NegativeBinomial(total_count, probs).log_prob(sample)
10911090
expected = scipy.stats.nbinom(total_count.cpu().numpy(), 1 - probs.cpu().numpy()).logpmf(sample)
1092-
self.assertAlmostEqual(log_prob, expected, places=4)
1091+
self.assertEqual(log_prob, expected, atol=1e-4, rtol=0)
10931092

10941093
def test_multinomial_1d(self):
10951094
total_count = 10
@@ -1238,7 +1237,7 @@ def test_poisson_log_prob(self):
12381237
def ref_log_prob(idx, x, log_prob):
12391238
l = rate.view(-1)[idx].detach()
12401239
expected = scipy.stats.poisson.logpmf(x, l)
1241-
self.assertAlmostEqual(log_prob, expected, places=3)
1240+
self.assertEqual(log_prob, expected, atol=1e-3, rtol=0)
12421241

12431242
set_rng_seed(0)
12441243
self._check_log_prob(Poisson(rate), ref_log_prob)
@@ -1500,7 +1499,7 @@ def test_halfnormal_logprob(self):
15001499
def ref_log_prob(idx, x, log_prob):
15011500
s = std.view(-1)[idx].detach()
15021501
expected = scipy.stats.halfnorm(scale=s).logpdf(x)
1503-
self.assertAlmostEqual(log_prob, expected, places=3)
1502+
self.assertEqual(log_prob, expected, atol=1e-3, rtol=0)
15041503

15051504
self._check_log_prob(HalfNormal(std), ref_log_prob)
15061505

@@ -1550,7 +1549,7 @@ def ref_log_prob(idx, x, log_prob):
15501549
m = mean.view(-1)[idx].detach()
15511550
s = std.view(-1)[idx].detach()
15521551
expected = scipy.stats.lognorm(s=s, scale=math.exp(m)).logpdf(x)
1553-
self.assertAlmostEqual(log_prob, expected, places=3)
1552+
self.assertEqual(log_prob, expected, atol=1e-3, rtol=0)
15541553

15551554
self._check_log_prob(LogNormal(mean, std), ref_log_prob)
15561555

@@ -1681,7 +1680,7 @@ def ref_log_prob(idx, x, log_prob):
16811680
mix = scipy.stats.multinomial(1, p)
16821681
comp = scipy.stats.norm(m, s)
16831682
expected = scipy.special.logsumexp(comp.logpdf(x) + np.log(mix.p))
1684-
self.assertAlmostEqual(log_prob, expected, places=3)
1683+
self.assertEqual(log_prob, expected, atol=1e-3, rtol=0)
16851684

16861685
self._check_log_prob(
16871686
MixtureSameFamily(Categorical(probs=probs),
@@ -1754,7 +1753,7 @@ def ref_log_prob(idx, x, log_prob):
17541753
s = scale.view(-1)[idx]
17551754
expected = (math.exp(-(x - m) ** 2 / (2 * s ** 2)) /
17561755
math.sqrt(2 * math.pi * s ** 2))
1757-
self.assertAlmostEqual(log_prob, math.log(expected), places=3)
1756+
self.assertEqual(log_prob, math.log(expected), atol=1e-3, rtol=0)
17581757

17591758
self._check_log_prob(Normal(loc, scale), ref_log_prob)
17601759

@@ -1828,7 +1827,7 @@ def test_lowrank_multivariate_normal_log_prob(self):
18281827
x = dist1.sample((10,))
18291828
expected = ref_dist.logpdf(x.numpy())
18301829

1831-
self.assertAlmostEqual(0.0, np.mean((dist1.log_prob(x).detach().numpy() - expected)**2), places=3)
1830+
self.assertEqual(0.0, np.mean((dist1.log_prob(x).detach().numpy() - expected)**2), atol=1e-3, rtol=0)
18321831

18331832
# Double-check that batched versions behave the same as unbatched
18341833
mean = torch.randn(5, 3, requires_grad=True)
@@ -1844,7 +1843,7 @@ def test_lowrank_multivariate_normal_log_prob(self):
18441843
unbatched_prob = torch.stack([dist_unbatched[i].log_prob(x[:, i]) for i in range(5)]).t()
18451844

18461845
self.assertEqual(batched_prob.shape, unbatched_prob.shape)
1847-
self.assertAlmostEqual(0.0, (batched_prob - unbatched_prob).abs().max(), places=3)
1846+
self.assertEqual(0.0, (batched_prob - unbatched_prob).abs().max(), atol=1e-3, rtol=0)
18481847

18491848
@unittest.skipIf(not TEST_NUMPY, "NumPy not found")
18501849
def test_lowrank_multivariate_normal_sample(self):
@@ -1965,9 +1964,9 @@ def test_multivariate_normal_log_prob(self):
19651964
x = dist1.sample((10,))
19661965
expected = ref_dist.logpdf(x.numpy())
19671966

1968-
self.assertAlmostEqual(0.0, np.mean((dist1.log_prob(x).detach().numpy() - expected)**2), places=3)
1969-
self.assertAlmostEqual(0.0, np.mean((dist2.log_prob(x).detach().numpy() - expected)**2), places=3)
1970-
self.assertAlmostEqual(0.0, np.mean((dist3.log_prob(x).detach().numpy() - expected)**2), places=3)
1967+
self.assertEqual(0.0, np.mean((dist1.log_prob(x).detach().numpy() - expected)**2), atol=1e-3, rtol=0)
1968+
self.assertEqual(0.0, np.mean((dist2.log_prob(x).detach().numpy() - expected)**2), atol=1e-3, rtol=0)
1969+
self.assertEqual(0.0, np.mean((dist3.log_prob(x).detach().numpy() - expected)**2), atol=1e-3, rtol=0)
19711970

19721971
# Double-check that batched versions behave the same as unbatched
19731972
mean = torch.randn(5, 3, requires_grad=True)
@@ -1982,7 +1981,7 @@ def test_multivariate_normal_log_prob(self):
19821981
unbatched_prob = torch.stack([dist_unbatched[i].log_prob(x[:, i]) for i in range(5)]).t()
19831982

19841983
self.assertEqual(batched_prob.shape, unbatched_prob.shape)
1985-
self.assertAlmostEqual(0.0, (batched_prob - unbatched_prob).abs().max(), places=3)
1984+
self.assertEqual(0.0, (batched_prob - unbatched_prob).abs().max(), atol=1e-3, rtol=0)
19861985

19871986
@unittest.skipIf(not TEST_NUMPY, "NumPy not found")
19881987
def test_multivariate_normal_sample(self):
@@ -2048,7 +2047,7 @@ def test_exponential(self):
20482047
def ref_log_prob(idx, x, log_prob):
20492048
m = rate.view(-1)[idx]
20502049
expected = math.log(m) - m * x
2051-
self.assertAlmostEqual(log_prob, expected, places=3)
2050+
self.assertEqual(log_prob, expected, atol=1e-3, rtol=0)
20522051

20532052
self._check_log_prob(Exponential(rate), ref_log_prob)
20542053

@@ -2099,7 +2098,7 @@ def ref_log_prob(idx, x, log_prob):
20992098
m = loc.view(-1)[idx]
21002099
s = scale.view(-1)[idx]
21012100
expected = (-math.log(2 * s) - abs(x - m) / s)
2102-
self.assertAlmostEqual(log_prob, expected, places=3)
2101+
self.assertEqual(log_prob, expected, atol=1e-3, rtol=0)
21032102

21042103
self._check_log_prob(Laplace(loc, scale), ref_log_prob)
21052104

@@ -2128,7 +2127,7 @@ def ref_log_prob(idx, x, log_prob):
21282127
a = alpha.view(-1)[idx].detach()
21292128
b = beta.view(-1)[idx].detach()
21302129
expected = scipy.stats.gamma.logpdf(x, a, scale=1 / b)
2131-
self.assertAlmostEqual(log_prob, expected, places=3)
2130+
self.assertEqual(log_prob, expected, atol=1e-3, rtol=0)
21322131

21332132
self._check_log_prob(Gamma(alpha, beta), ref_log_prob)
21342133

@@ -2150,7 +2149,7 @@ def ref_log_prob(idx, x, log_prob):
21502149
a = alpha.view(-1)[idx].detach().cpu()
21512150
b = beta.view(-1)[idx].detach().cpu()
21522151
expected = scipy.stats.gamma.logpdf(x.cpu(), a, scale=1 / b)
2153-
self.assertAlmostEqual(log_prob, expected, places=3)
2152+
self.assertEqual(log_prob, expected, atol=1e-3, rtol=0)
21542153

21552154
self._check_log_prob(Gamma(alpha, beta), ref_log_prob)
21562155

@@ -2192,7 +2191,7 @@ def ref_log_prob(idx, x, log_prob):
21922191
s = scale.view(-1)[idx].detach()
21932192
a = alpha.view(-1)[idx].detach()
21942193
expected = scipy.stats.pareto.logpdf(x, a, scale=s)
2195-
self.assertAlmostEqual(log_prob, expected, places=3)
2194+
self.assertEqual(log_prob, expected, atol=1e-3, rtol=0)
21962195

21972196
self._check_log_prob(Pareto(scale, alpha), ref_log_prob)
21982197

@@ -2221,7 +2220,7 @@ def ref_log_prob(idx, x, log_prob):
22212220
l = loc.view(-1)[idx].detach()
22222221
s = scale.view(-1)[idx].detach()
22232222
expected = scipy.stats.gumbel_r.logpdf(x, loc=l, scale=s)
2224-
self.assertAlmostEqual(log_prob, expected, places=3)
2223+
self.assertEqual(log_prob, expected, atol=1e-3, rtol=0)
22252224

22262225
self._check_log_prob(Gumbel(loc, scale), ref_log_prob)
22272226

@@ -2252,7 +2251,7 @@ def ref_log_prob(idx, x, log_prob):
22522251
f1 = df1.view(-1)[idx].detach()
22532252
f2 = df2.view(-1)[idx].detach()
22542253
expected = scipy.stats.f.logpdf(x, f1, f2)
2255-
self.assertAlmostEqual(log_prob, expected, places=3)
2254+
self.assertEqual(log_prob, expected, atol=1e-3, rtol=0)
22562255

22572256
self._check_log_prob(FisherSnedecor(df1, df2), ref_log_prob)
22582257

@@ -2279,7 +2278,7 @@ def test_chi2_shape(self):
22792278
def ref_log_prob(idx, x, log_prob):
22802279
d = df.view(-1)[idx].detach()
22812280
expected = scipy.stats.chi2.logpdf(x, d)
2282-
self.assertAlmostEqual(log_prob, expected, places=3)
2281+
self.assertEqual(log_prob, expected, atol=1e-3, rtol=0)
22832282

22842283
self._check_log_prob(Chi2(df), ref_log_prob)
22852284

@@ -2309,7 +2308,7 @@ def test_studentT(self):
23092308
def ref_log_prob(idx, x, log_prob):
23102309
d = df.view(-1)[idx].detach()
23112310
expected = scipy.stats.t.logpdf(x, d)
2312-
self.assertAlmostEqual(log_prob, expected, places=3)
2311+
self.assertEqual(log_prob, expected, atol=1e-3, rtol=0)
23132312

23142313
self._check_log_prob(StudentT(df), ref_log_prob)
23152314

@@ -2331,7 +2330,7 @@ def test_studentT_log_prob(self):
23312330
actual_log_prob = dist.log_prob(x)
23322331
for i in range(num_samples):
23332332
expected_log_prob = scipy.stats.t.logpdf(x[i], df=df, loc=loc, scale=scale)
2334-
self.assertAlmostEqual(float(actual_log_prob[i]), float(expected_log_prob), places=3)
2333+
self.assertEqual(float(actual_log_prob[i]), float(expected_log_prob), atol=1e-3, rtol=0)
23352334

23362335
def test_dirichlet_shape(self):
23372336
alpha = torch.randn(2, 3).exp().requires_grad_()
@@ -2350,7 +2349,7 @@ def test_dirichlet_log_prob(self):
23502349
actual_log_prob = dist.log_prob(x)
23512350
for i in range(num_samples):
23522351
expected_log_prob = scipy.stats.dirichlet.logpdf(x[i].numpy(), alpha.numpy())
2353-
self.assertAlmostEqual(actual_log_prob[i], expected_log_prob, places=3)
2352+
self.assertEqual(actual_log_prob[i], expected_log_prob, atol=1e-3, rtol=0)
23542353

23552354
@unittest.skipIf(not TEST_NUMPY, "NumPy not found")
23562355
def test_dirichlet_sample(self):
@@ -2382,7 +2381,7 @@ def test_beta_log_prob(self):
23822381
x = dist.sample()
23832382
actual_log_prob = dist.log_prob(x).sum()
23842383
expected_log_prob = scipy.stats.beta.logpdf(x, con1, con0)
2385-
self.assertAlmostEqual(float(actual_log_prob), float(expected_log_prob), places=3)
2384+
self.assertEqual(float(actual_log_prob), float(expected_log_prob), atol=1e-3, rtol=0)
23862385

23872386
@unittest.skipIf(not TEST_NUMPY, "NumPy not found")
23882387
def test_beta_sample(self):

‎test/test_jit_py3.py

+1-1
Original file line numberDiff line numberDiff line change
@@ -34,7 +34,7 @@ def func(x):
3434
with self.capture_stdout() as captured_script:
3535
out_script = func(x)
3636

37-
self.assertAlmostEqual(out, out_script)
37+
self.assertEqual(out, out_script)
3838
self.assertEqual(captured, captured_script)
3939

4040
@unittest.skipIf(sys.version_info[:2] < (3, 7), "`dataclasses` module not present on < 3.7")

0 commit comments

Comments
 (0)
Please sign in to comment.