@@ -998,8 +998,7 @@ def test_binomial_log_prob(self):
998
998
def ref_log_prob (idx , x , log_prob ):
999
999
p = probs .view (- 1 )[idx ].item ()
1000
1000
expected = scipy .stats .binom (total_count , p ).logpmf (x )
1001
- self .assertAlmostEqual (log_prob , expected , places = 3 )
1002
-
1001
+ self .assertEqual (log_prob , expected , atol = 1e-3 , rtol = 0 )
1003
1002
self ._check_log_prob (Binomial (total_count , probs ), ref_log_prob )
1004
1003
logits = probs_to_logits (probs , is_binary = True )
1005
1004
self ._check_log_prob (Binomial (total_count , logits = logits ), ref_log_prob )
@@ -1023,7 +1022,7 @@ def test_binomial_log_prob_vectorized_count(self):
1023
1022
(torch .tensor ([1 , 2 , 10 ]), torch .tensor ([0. , 1. , 9. ]))]:
1024
1023
log_prob = Binomial (total_count , probs ).log_prob (sample )
1025
1024
expected = scipy .stats .binom (total_count .cpu ().numpy (), probs .cpu ().numpy ()).logpmf (sample )
1026
- self .assertAlmostEqual (log_prob , expected , places = 4 )
1025
+ self .assertEqual (log_prob , expected , atol = 1e-4 , rtol = 0 )
1027
1026
1028
1027
def test_binomial_enumerate_support (self ):
1029
1028
examples = [
@@ -1037,11 +1036,11 @@ def test_binomial_extreme_vals(self):
1037
1036
total_count = 100
1038
1037
bin0 = Binomial (total_count , 0 )
1039
1038
self .assertEqual (bin0 .sample (), 0 )
1040
- self .assertAlmostEqual (bin0 .log_prob (torch .tensor ([0. ]))[0 ], 0 , places = 3 )
1039
+ self .assertEqual (bin0 .log_prob (torch .tensor ([0. ]))[0 ], 0 , atol = 1e-3 , rtol = 0 )
1041
1040
self .assertEqual (float (bin0 .log_prob (torch .tensor ([1. ])).exp ()), 0 )
1042
1041
bin1 = Binomial (total_count , 1 )
1043
1042
self .assertEqual (bin1 .sample (), total_count )
1044
- self .assertAlmostEqual (bin1 .log_prob (torch .tensor ([float (total_count )]))[0 ], 0 , places = 3 )
1043
+ self .assertEqual (bin1 .log_prob (torch .tensor ([float (total_count )]))[0 ], 0 , atol = 1e-3 , rtol = 0 )
1045
1044
self .assertEqual (float (bin1 .log_prob (torch .tensor ([float (total_count - 1 )])).exp ()), 0 )
1046
1045
zero_counts = torch .zeros (torch .Size ((2 , 2 )))
1047
1046
bin2 = Binomial (zero_counts , 1 )
@@ -1076,7 +1075,7 @@ def test_negative_binomial_log_prob(self):
1076
1075
def ref_log_prob (idx , x , log_prob ):
1077
1076
p = probs .view (- 1 )[idx ].item ()
1078
1077
expected = scipy .stats .nbinom (total_count , 1 - p ).logpmf (x )
1079
- self .assertAlmostEqual (log_prob , expected , places = 3 )
1078
+ self .assertEqual (log_prob , expected , atol = 1e-3 , rtol = 0 )
1080
1079
1081
1080
self ._check_log_prob (NegativeBinomial (total_count , probs ), ref_log_prob )
1082
1081
logits = probs_to_logits (probs , is_binary = True )
@@ -1089,7 +1088,7 @@ def test_negative_binomial_log_prob_vectorized_count(self):
1089
1088
(torch .tensor ([1 , 2 , 10 ]), torch .tensor ([0. , 1. , 9. ]))]:
1090
1089
log_prob = NegativeBinomial (total_count , probs ).log_prob (sample )
1091
1090
expected = scipy .stats .nbinom (total_count .cpu ().numpy (), 1 - probs .cpu ().numpy ()).logpmf (sample )
1092
- self .assertAlmostEqual (log_prob , expected , places = 4 )
1091
+ self .assertEqual (log_prob , expected , atol = 1e-4 , rtol = 0 )
1093
1092
1094
1093
def test_multinomial_1d (self ):
1095
1094
total_count = 10
@@ -1238,7 +1237,7 @@ def test_poisson_log_prob(self):
1238
1237
def ref_log_prob (idx , x , log_prob ):
1239
1238
l = rate .view (- 1 )[idx ].detach ()
1240
1239
expected = scipy .stats .poisson .logpmf (x , l )
1241
- self .assertAlmostEqual (log_prob , expected , places = 3 )
1240
+ self .assertEqual (log_prob , expected , atol = 1e-3 , rtol = 0 )
1242
1241
1243
1242
set_rng_seed (0 )
1244
1243
self ._check_log_prob (Poisson (rate ), ref_log_prob )
@@ -1500,7 +1499,7 @@ def test_halfnormal_logprob(self):
1500
1499
def ref_log_prob (idx , x , log_prob ):
1501
1500
s = std .view (- 1 )[idx ].detach ()
1502
1501
expected = scipy .stats .halfnorm (scale = s ).logpdf (x )
1503
- self .assertAlmostEqual (log_prob , expected , places = 3 )
1502
+ self .assertEqual (log_prob , expected , atol = 1e-3 , rtol = 0 )
1504
1503
1505
1504
self ._check_log_prob (HalfNormal (std ), ref_log_prob )
1506
1505
@@ -1550,7 +1549,7 @@ def ref_log_prob(idx, x, log_prob):
1550
1549
m = mean .view (- 1 )[idx ].detach ()
1551
1550
s = std .view (- 1 )[idx ].detach ()
1552
1551
expected = scipy .stats .lognorm (s = s , scale = math .exp (m )).logpdf (x )
1553
- self .assertAlmostEqual (log_prob , expected , places = 3 )
1552
+ self .assertEqual (log_prob , expected , atol = 1e-3 , rtol = 0 )
1554
1553
1555
1554
self ._check_log_prob (LogNormal (mean , std ), ref_log_prob )
1556
1555
@@ -1681,7 +1680,7 @@ def ref_log_prob(idx, x, log_prob):
1681
1680
mix = scipy .stats .multinomial (1 , p )
1682
1681
comp = scipy .stats .norm (m , s )
1683
1682
expected = scipy .special .logsumexp (comp .logpdf (x ) + np .log (mix .p ))
1684
- self .assertAlmostEqual (log_prob , expected , places = 3 )
1683
+ self .assertEqual (log_prob , expected , atol = 1e-3 , rtol = 0 )
1685
1684
1686
1685
self ._check_log_prob (
1687
1686
MixtureSameFamily (Categorical (probs = probs ),
@@ -1754,7 +1753,7 @@ def ref_log_prob(idx, x, log_prob):
1754
1753
s = scale .view (- 1 )[idx ]
1755
1754
expected = (math .exp (- (x - m ) ** 2 / (2 * s ** 2 )) /
1756
1755
math .sqrt (2 * math .pi * s ** 2 ))
1757
- self .assertAlmostEqual (log_prob , math .log (expected ), places = 3 )
1756
+ self .assertEqual (log_prob , math .log (expected ), atol = 1e-3 , rtol = 0 )
1758
1757
1759
1758
self ._check_log_prob (Normal (loc , scale ), ref_log_prob )
1760
1759
@@ -1828,7 +1827,7 @@ def test_lowrank_multivariate_normal_log_prob(self):
1828
1827
x = dist1 .sample ((10 ,))
1829
1828
expected = ref_dist .logpdf (x .numpy ())
1830
1829
1831
- self .assertAlmostEqual (0.0 , np .mean ((dist1 .log_prob (x ).detach ().numpy () - expected )** 2 ), places = 3 )
1830
+ self .assertEqual (0.0 , np .mean ((dist1 .log_prob (x ).detach ().numpy () - expected )** 2 ), atol = 1e-3 , rtol = 0 )
1832
1831
1833
1832
# Double-check that batched versions behave the same as unbatched
1834
1833
mean = torch .randn (5 , 3 , requires_grad = True )
@@ -1844,7 +1843,7 @@ def test_lowrank_multivariate_normal_log_prob(self):
1844
1843
unbatched_prob = torch .stack ([dist_unbatched [i ].log_prob (x [:, i ]) for i in range (5 )]).t ()
1845
1844
1846
1845
self .assertEqual (batched_prob .shape , unbatched_prob .shape )
1847
- self .assertAlmostEqual (0.0 , (batched_prob - unbatched_prob ).abs ().max (), places = 3 )
1846
+ self .assertEqual (0.0 , (batched_prob - unbatched_prob ).abs ().max (), atol = 1e-3 , rtol = 0 )
1848
1847
1849
1848
@unittest .skipIf (not TEST_NUMPY , "NumPy not found" )
1850
1849
def test_lowrank_multivariate_normal_sample (self ):
@@ -1965,9 +1964,9 @@ def test_multivariate_normal_log_prob(self):
1965
1964
x = dist1 .sample ((10 ,))
1966
1965
expected = ref_dist .logpdf (x .numpy ())
1967
1966
1968
- self .assertAlmostEqual (0.0 , np .mean ((dist1 .log_prob (x ).detach ().numpy () - expected )** 2 ), places = 3 )
1969
- self .assertAlmostEqual (0.0 , np .mean ((dist2 .log_prob (x ).detach ().numpy () - expected )** 2 ), places = 3 )
1970
- self .assertAlmostEqual (0.0 , np .mean ((dist3 .log_prob (x ).detach ().numpy () - expected )** 2 ), places = 3 )
1967
+ self .assertEqual (0.0 , np .mean ((dist1 .log_prob (x ).detach ().numpy () - expected )** 2 ), atol = 1e-3 , rtol = 0 )
1968
+ self .assertEqual (0.0 , np .mean ((dist2 .log_prob (x ).detach ().numpy () - expected )** 2 ), atol = 1e-3 , rtol = 0 )
1969
+ self .assertEqual (0.0 , np .mean ((dist3 .log_prob (x ).detach ().numpy () - expected )** 2 ), atol = 1e-3 , rtol = 0 )
1971
1970
1972
1971
# Double-check that batched versions behave the same as unbatched
1973
1972
mean = torch .randn (5 , 3 , requires_grad = True )
@@ -1982,7 +1981,7 @@ def test_multivariate_normal_log_prob(self):
1982
1981
unbatched_prob = torch .stack ([dist_unbatched [i ].log_prob (x [:, i ]) for i in range (5 )]).t ()
1983
1982
1984
1983
self .assertEqual (batched_prob .shape , unbatched_prob .shape )
1985
- self .assertAlmostEqual (0.0 , (batched_prob - unbatched_prob ).abs ().max (), places = 3 )
1984
+ self .assertEqual (0.0 , (batched_prob - unbatched_prob ).abs ().max (), atol = 1e-3 , rtol = 0 )
1986
1985
1987
1986
@unittest .skipIf (not TEST_NUMPY , "NumPy not found" )
1988
1987
def test_multivariate_normal_sample (self ):
@@ -2048,7 +2047,7 @@ def test_exponential(self):
2048
2047
def ref_log_prob (idx , x , log_prob ):
2049
2048
m = rate .view (- 1 )[idx ]
2050
2049
expected = math .log (m ) - m * x
2051
- self .assertAlmostEqual (log_prob , expected , places = 3 )
2050
+ self .assertEqual (log_prob , expected , atol = 1e-3 , rtol = 0 )
2052
2051
2053
2052
self ._check_log_prob (Exponential (rate ), ref_log_prob )
2054
2053
@@ -2099,7 +2098,7 @@ def ref_log_prob(idx, x, log_prob):
2099
2098
m = loc .view (- 1 )[idx ]
2100
2099
s = scale .view (- 1 )[idx ]
2101
2100
expected = (- math .log (2 * s ) - abs (x - m ) / s )
2102
- self .assertAlmostEqual (log_prob , expected , places = 3 )
2101
+ self .assertEqual (log_prob , expected , atol = 1e-3 , rtol = 0 )
2103
2102
2104
2103
self ._check_log_prob (Laplace (loc , scale ), ref_log_prob )
2105
2104
@@ -2128,7 +2127,7 @@ def ref_log_prob(idx, x, log_prob):
2128
2127
a = alpha .view (- 1 )[idx ].detach ()
2129
2128
b = beta .view (- 1 )[idx ].detach ()
2130
2129
expected = scipy .stats .gamma .logpdf (x , a , scale = 1 / b )
2131
- self .assertAlmostEqual (log_prob , expected , places = 3 )
2130
+ self .assertEqual (log_prob , expected , atol = 1e-3 , rtol = 0 )
2132
2131
2133
2132
self ._check_log_prob (Gamma (alpha , beta ), ref_log_prob )
2134
2133
@@ -2150,7 +2149,7 @@ def ref_log_prob(idx, x, log_prob):
2150
2149
a = alpha .view (- 1 )[idx ].detach ().cpu ()
2151
2150
b = beta .view (- 1 )[idx ].detach ().cpu ()
2152
2151
expected = scipy .stats .gamma .logpdf (x .cpu (), a , scale = 1 / b )
2153
- self .assertAlmostEqual (log_prob , expected , places = 3 )
2152
+ self .assertEqual (log_prob , expected , atol = 1e-3 , rtol = 0 )
2154
2153
2155
2154
self ._check_log_prob (Gamma (alpha , beta ), ref_log_prob )
2156
2155
@@ -2192,7 +2191,7 @@ def ref_log_prob(idx, x, log_prob):
2192
2191
s = scale .view (- 1 )[idx ].detach ()
2193
2192
a = alpha .view (- 1 )[idx ].detach ()
2194
2193
expected = scipy .stats .pareto .logpdf (x , a , scale = s )
2195
- self .assertAlmostEqual (log_prob , expected , places = 3 )
2194
+ self .assertEqual (log_prob , expected , atol = 1e-3 , rtol = 0 )
2196
2195
2197
2196
self ._check_log_prob (Pareto (scale , alpha ), ref_log_prob )
2198
2197
@@ -2221,7 +2220,7 @@ def ref_log_prob(idx, x, log_prob):
2221
2220
l = loc .view (- 1 )[idx ].detach ()
2222
2221
s = scale .view (- 1 )[idx ].detach ()
2223
2222
expected = scipy .stats .gumbel_r .logpdf (x , loc = l , scale = s )
2224
- self .assertAlmostEqual (log_prob , expected , places = 3 )
2223
+ self .assertEqual (log_prob , expected , atol = 1e-3 , rtol = 0 )
2225
2224
2226
2225
self ._check_log_prob (Gumbel (loc , scale ), ref_log_prob )
2227
2226
@@ -2252,7 +2251,7 @@ def ref_log_prob(idx, x, log_prob):
2252
2251
f1 = df1 .view (- 1 )[idx ].detach ()
2253
2252
f2 = df2 .view (- 1 )[idx ].detach ()
2254
2253
expected = scipy .stats .f .logpdf (x , f1 , f2 )
2255
- self .assertAlmostEqual (log_prob , expected , places = 3 )
2254
+ self .assertEqual (log_prob , expected , atol = 1e-3 , rtol = 0 )
2256
2255
2257
2256
self ._check_log_prob (FisherSnedecor (df1 , df2 ), ref_log_prob )
2258
2257
@@ -2279,7 +2278,7 @@ def test_chi2_shape(self):
2279
2278
def ref_log_prob (idx , x , log_prob ):
2280
2279
d = df .view (- 1 )[idx ].detach ()
2281
2280
expected = scipy .stats .chi2 .logpdf (x , d )
2282
- self .assertAlmostEqual (log_prob , expected , places = 3 )
2281
+ self .assertEqual (log_prob , expected , atol = 1e-3 , rtol = 0 )
2283
2282
2284
2283
self ._check_log_prob (Chi2 (df ), ref_log_prob )
2285
2284
@@ -2309,7 +2308,7 @@ def test_studentT(self):
2309
2308
def ref_log_prob (idx , x , log_prob ):
2310
2309
d = df .view (- 1 )[idx ].detach ()
2311
2310
expected = scipy .stats .t .logpdf (x , d )
2312
- self .assertAlmostEqual (log_prob , expected , places = 3 )
2311
+ self .assertEqual (log_prob , expected , atol = 1e-3 , rtol = 0 )
2313
2312
2314
2313
self ._check_log_prob (StudentT (df ), ref_log_prob )
2315
2314
@@ -2331,7 +2330,7 @@ def test_studentT_log_prob(self):
2331
2330
actual_log_prob = dist .log_prob (x )
2332
2331
for i in range (num_samples ):
2333
2332
expected_log_prob = scipy .stats .t .logpdf (x [i ], df = df , loc = loc , scale = scale )
2334
- self .assertAlmostEqual (float (actual_log_prob [i ]), float (expected_log_prob ), places = 3 )
2333
+ self .assertEqual (float (actual_log_prob [i ]), float (expected_log_prob ), atol = 1e-3 , rtol = 0 )
2335
2334
2336
2335
def test_dirichlet_shape (self ):
2337
2336
alpha = torch .randn (2 , 3 ).exp ().requires_grad_ ()
@@ -2350,7 +2349,7 @@ def test_dirichlet_log_prob(self):
2350
2349
actual_log_prob = dist .log_prob (x )
2351
2350
for i in range (num_samples ):
2352
2351
expected_log_prob = scipy .stats .dirichlet .logpdf (x [i ].numpy (), alpha .numpy ())
2353
- self .assertAlmostEqual (actual_log_prob [i ], expected_log_prob , places = 3 )
2352
+ self .assertEqual (actual_log_prob [i ], expected_log_prob , atol = 1e-3 , rtol = 0 )
2354
2353
2355
2354
@unittest .skipIf (not TEST_NUMPY , "NumPy not found" )
2356
2355
def test_dirichlet_sample (self ):
@@ -2382,7 +2381,7 @@ def test_beta_log_prob(self):
2382
2381
x = dist .sample ()
2383
2382
actual_log_prob = dist .log_prob (x ).sum ()
2384
2383
expected_log_prob = scipy .stats .beta .logpdf (x , con1 , con0 )
2385
- self .assertAlmostEqual (float (actual_log_prob ), float (expected_log_prob ), places = 3 )
2384
+ self .assertEqual (float (actual_log_prob ), float (expected_log_prob ), atol = 1e-3 , rtol = 0 )
2386
2385
2387
2386
@unittest .skipIf (not TEST_NUMPY , "NumPy not found" )
2388
2387
def test_beta_sample (self ):
0 commit comments