
Commit cb26661

Mike Ruberry authored and facebook-github-bot committed Jun 24, 2020
Throws runtime error when torch.full would infer a float dtype from a bool or integral fill value (pytorch#40364)
Summary:

BC-breaking NOTE: In PyTorch 1.6, calls to torch.full with a bool or integral fill value must set the dtype or out keyword argument. In prior versions of PyTorch these fill values would return float tensors by default, but in PyTorch 1.7 they will return a bool or long tensor, respectively. The documentation for torch.full has been updated to reflect this.

PR NOTE: This PR causes torch.full to throw a runtime error when it would have inferred a float dtype from a bool or integer fill value. A versioned symbol for torch.full is added to preserve the behavior of already serialized Torchscript programs. Existing tests for the deprecated behavior have been updated to reflect that it is now unsupported, and a couple of new tests validate the versioned symbol behavior. The documentation of torch.full has also been updated to reflect this change.

Pull Request resolved: pytorch#40364
Differential Revision: D22176640
Pulled By: mruberry
fbshipit-source-id: b20158ebbcb4f6bf269d05a688bcf4f6c853a965
1 parent a2d4d9e commit cb26661

19 files changed: +152 −51 lines
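For callers, the migration is mechanical. A minimal before/after sketch of the new requirement (shapes and fill values here are illustrative, not taken from the PR):

import torch

# Previously inferred the default float dtype (with a deprecation warning
# since PyTorch 1.5); after this change these calls raise a RuntimeError:
#   torch.full((2, 2), 1)
#   torch.full((2, 2), True)

# Fix 1: state the intended dtype explicitly.
a = torch.full((2, 2), 1, dtype=torch.float)

# Fix 2: pass an `out` tensor, whose dtype is used instead of inference.
o = torch.empty((2, 2), dtype=torch.float)
b = torch.full((2, 2), 1, out=o)

# Float fill values are unaffected and still infer the default dtype.
c = torch.full((2, 2), 1.)  # torch.float32 unless the default was changed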
 

aten/src/ATen/native/TensorFactories.cpp (+6 −7)

@@ -355,13 +355,12 @@ TensorOptions infer_full_options(

   if (!options.has_dtype()) {
     if (fill_value.isIntegral(true)) {
-      TORCH_WARN_ONCE(
-        "Deprecation warning: In a future PyTorch release torch.full ",
-        "will no longer return tensors of floating dtype by default. ",
-        "Instead, a bool fill_value will return a tensor of torch.bool dtype, ",
-        "and an integral fill_value will return a tensor of torch.long dtype. ",
-        "Set the optional `dtype` or `out` arguments to suppress this warning."
-      );
+      TORCH_CHECK(false,
+        "Providing a bool or integral fill value without setting the optional ",
+        "`dtype` or `out` arguments is currently unsupported. In PyTorch 1.7, ",
+        "when `dtype` and `out` are not set a bool fill value will ",
+        "return a tensor of torch.bool dtype, and an integral fill value ",
+        "will return a tensor of torch.long dtype.");
     } else if (fill_value.isComplex()) {
       auto scalar_type = (get_default_dtype() == ScalarType::Double) ?
         ScalarType::ComplexDouble :
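The `isComplex()` branch kept below the new check shows the other inference rule: a complex fill value infers a complex dtype whose precision tracks the default dtype. A small illustration of the behavior implied by that branch (not part of the PR itself):

import torch

# Complex fill values infer a complex dtype that follows the default dtype.
print(torch.full((2, 2), 2 + 3j).dtype)   # torch.complex64 under the float32 default

torch.set_default_dtype(torch.double)
print(torch.full((2, 2), 2 + 3j).dtype)   # torch.complex128 under the float64 default
torch.set_default_dtype(torch.float)      # restore the default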

caffe2/serialize/inline_container.h (+3 −3)

@@ -99,9 +99,9 @@ constexpr uint64_t kMaxSupportedFileFormatVersion = 0x5L;
 // 3. Added type tags to pickle serialization of container types
 // 4. Stopped integer division using torch.div
 //    (a versioned symbol preserves the historic behavior of versions 1--3)
-// 5. (Read-only) Stops torch.full inferring a floating point dtype
-//    when given integer fill values.
-constexpr uint64_t kProducedFileFormatVersion = 0x4L;
+// 5. Stops torch.full inferring a floating point dtype
+//    when given bool or integer fill values.
+constexpr uint64_t kProducedFileFormatVersion = 0x5L;

 // Writer-specific constants
 constexpr uint64_t kFieldAlignment = 64;
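A serialized TorchScript archive records this number in a `version` entry inside its zip container, which is how loading decides whether a versioned symbol applies. A quick way to inspect it (a sketch; "model.pt" is a hypothetical saved module, not a file from this PR):

import zipfile

# TorchScript .pt files are zip archives; the 'version' entry holds the
# produced file format version as text (5 after this change, 4 before).
with zipfile.ZipFile("model.pt") as zf:
    version_entry = next(n for n in zf.namelist() if n.endswith("version"))
    print(zf.read(version_entry).decode().strip())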

test/cpp/api/autograd.cpp (+1 −1)

@@ -193,7 +193,7 @@ TEST(CustomAutogradTest, FunctionReturnsInput) {

   Variable x(torch::ones(1, torch::requires_grad()));
   MyFunction::apply(x).backward(torch::ones(1), true, true);
-  ASSERT_VARIABLE_EQ(x.grad(), torch::full(1,2));
+  ASSERT_VARIABLE_EQ(x.grad(), torch::full(1, 2.));
 }

 TEST(CustomAutogradTest, NoGradCustomFunction) {

test/distributed/test_c10d.py (+3 −3)

@@ -1528,9 +1528,9 @@ def test_round_robin(self):

         # Run a few collectives so that we have called each process group
         for _ in range(num_process_groups + 1):
-            tensor = torch.full([100, 100], self.rank)
+            tensor = torch.full([100, 100], float(self.rank))
             pg.broadcast(tensor, root=0).wait()
-            self.assertEqual(torch.full([100, 100], 0), tensor)
+            self.assertEqual(torch.full([100, 100], 0.), tensor)

     def test_round_robin_create_destroy(self):
         store = c10d.FileStore(self.file_name, self.world_size)
@@ -1551,7 +1551,7 @@ def create(num, prefix):
         for _ in range(3):
             tensor = torch.ones([10, 10])
             pg.allreduce(tensor).wait()
-            self.assertEqual(torch.full([10, 10], self.world_size), tensor)
+            self.assertEqual(torch.full([10, 10], float(self.world_size)), tensor)
         del pg

test/jit/fixtures/test_versioned_full_integer_value_v4.pt (binary file not shown)
test/jit/fixtures/test_versioned_full_preserved_v4.pt (binary file not shown)

test/jit/test_save_load.py (+72)

@@ -480,6 +480,78 @@ def _helper(m, fn):

         _helper(v3_module, current_module)

+    # NOTE: the JIT was incapable of handling boolean fill values when
+    # PyTorch produced file format versions 0-4
+    def test_versioned_full_integer_value(self):
+        class MyModule(torch.nn.Module):
+            def __init__(self):
+                super(MyModule, self).__init__()
+
+            def forward(self, int_fill: int):
+                size = torch.Size([2, 2])
+                a = torch.full(size, int_fill)
+                b = torch.full(size, 1)
+                return (a, b)
+
+        try:
+            v4_module = torch.jit.load(pytorch_test_dir + "/jit/fixtures/test_versioned_full_integer_value_v4.pt")
+        except Exception as e:
+            self.skipTest("Failed to load fixture!")
+
+        self._verify_count("aten::full", v4_module, 2)
+
+        current_module = self._save_load_module(MyModule)
+        self._verify_count("aten::full", current_module, 2)
+
+        # Verifies historic integer type inference is float
+        # NOTE: only verifies floating point, not exact dtype, due to
+        # https://github.com/pytorch/pytorch/issues/40470
+        results = v4_module(2)
+        for result in results:
+            self.assertTrue(result.is_floating_point())
+
+        # Verifies values are correct
+        a, b = results
+        self.assertTrue((a == 2.).all())
+        self.assertTrue((b == 1.).all())
+
+        with self.assertRaisesRegex(RuntimeError, ".+is currently unsupported.+"):
+            current_module(2)
+
+    # Tests that torch.full behavior which is the same from prior versions
+    # to version 5 is preserved.
+    # NOTE: while torch.full in eager PyTorch accepts a requires_grad argument,
+    # it does not in Torchscript (see https://github.com/pytorch/pytorch/issues/40363)
+    def test_versioned_full_preserved(self):
+        class MyModule(torch.nn.Module):
+            def __init__(self):
+                super(MyModule, self).__init__()
+
+            def forward(self, float_fill: float):
+                size = (2, 2)
+                a = torch.full(size, 1.)
+                b = torch.full(size, float_fill)
+                c = torch.full(size, float_fill, dtype=torch.long)
+
+                out = torch.empty(size, dtype=torch.long)
+                d = torch.full(size, float_fill, out=out)
+
+                e = torch.full(size, float_fill, dtype=torch.float16, pin_memory=None,
+                               layout=torch.strided, device='cpu')
+                return (a, b, c, d, e)
+
+        try:
+            v4_module = torch.jit.load(pytorch_test_dir + "/jit/fixtures/test_versioned_full_preserved_v4.pt")
+        except Exception as e:
+            self.skipTest("Failed to load fixture!")
+
+        self._verify_count("aten::full", v4_module, 5)
+
+        current_module = self._save_load_module(MyModule)
+        self._verify_count("aten::full", current_module, 5)
+
+        self.assertEqual(v4_module(2.), current_module(2.))
+
     def test_versioned_symbols_reserialization(self):
         """
         Tests that loading and saving serialized Torchscript with a versioned
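The pair of tests above pins down both sides of the versioned symbol: a module serialized at file format version 4 keeps the historic float inference, while newly scripted code hits the new error. A condensed sketch of that contrast (the fixture path is the one the tests use and may not exist outside the test tree):

import torch

# A version-4 module keeps the old behavior via upgraders::full_0_4.
old = torch.jit.load("test/jit/fixtures/test_versioned_full_integer_value_v4.pt")
a, b = old(2)
assert a.is_floating_point() and b.is_floating_point()

# Newly scripted code raises at run time instead of inferring float.
@torch.jit.script
def f():
    return torch.full([2, 2], 1)

try:
    f()
except RuntimeError as e:
    print(e)  # "...is currently unsupported..."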

test/jit/test_tracer.py (+1 −1)

@@ -390,7 +390,7 @@ def test_trace_arange_with_grad(self):
     # Test that a trace of torch.full(x.shape) doesn't store the shape as a constant
     def test_trace_full_dynamic_shape(self):
         def full_with_shape_like(x):
-            return torch.full(x.shape, 2)
+            return torch.full(x.shape, 2.)

         x = torch.randn(3, 4)
         ge = torch.jit.trace(full_with_shape_like, example_inputs=x)

test/onnx/test_onnx_opset.py (+1 −2)

@@ -252,10 +252,9 @@ def forward(self, x):

         ops = [{"op_name" : "Constant"},
                {"op_name" : "ConstantOfShape"},
-               {"op_name" : "Cast"},
                {"op_name" : "Add"}]
         ops = {9 : ops, 10 : ops}
-        x = torch.tensor(12)
+        x = torch.tensor(12.)
         check_onnx_opsets_operator(MyModule(), x, ops, opset_versions=[9, 10])

     def test_interpolate(self):

test/onnx/test_operators.py (+1 −1)

@@ -356,7 +356,7 @@ def test_hardtanh(self):

     def test_full(self):
         x = torch.randn(3, 4, requires_grad=True)
-        self.assertONNX(lambda x: torch.full(x.shape, 2), x)
+        self.assertONNX(lambda x: torch.full(x.shape, 2.), x)

     def test_full_like(self):
         x = torch.randn(3, 4, requires_grad=True)

test/onnx/test_pytorch_onnx_onnxruntime.py (+3 −3)

(The first two hunks below appear to differ only in trailing whitespace.)

@@ -1895,7 +1895,7 @@ def get_LstmNet_model_and_inputs(num_layers, bidirectional):

         num_layers = [1, 1, 2, 3]
         bidirectional = [True, False, True, False]
-        models_and_inputs = [get_LstmNet_model_and_inputs(n, b) for n, b in zip(num_layers, bidirectional)]
+        models_and_inputs = [get_LstmNet_model_and_inputs(n, b) for n, b in zip(num_layers, bidirectional)]
         for model, input in models_and_inputs:
             self.run_test(model, input)

@@ -1960,7 +1960,7 @@ def get_GruNet_model_and_inputs(input_size, hidden_size, num_layers, batch_size,
         num_layers = [2, 3]
         batch_size = [3, 4]
         seq_len = [5, 7]
-        bidirectional = [True, False]
+        bidirectional = [True, False]
         models_and_inputs = [get_GruNet_model_and_inputs(i, h, n, b, s, bi)
                              for i, h, n, b, s, bi in zip(input_size, hidden_size, num_layers, batch_size, seq_len, bidirectional)]
         for model, input in models_and_inputs:

@@ -2684,7 +2684,7 @@ class FullModel(torch.nn.Module):
            # add is used for exporting full
            def forward(self, x):
                return torch.full((3, 4), x)
-        x = torch.tensor(12)
+        x = torch.tensor(12.)
         self.run_test(FullModel(), x)

     def test_l1_norm(self):

test/test_autograd.py (+2 −2)

@@ -158,12 +158,12 @@ def backward(ctx, grad):
         for shape in [(1,), ()]:
             v = torch.ones(shape, requires_grad=True)
             MyFunction.apply(v).backward()
-            self.assertEqual(v.grad, torch.full(shape, 2))
+            self.assertEqual(v.grad, torch.full(shape, 2.))

             with torch.no_grad():
                 v.grad.zero_()
             MyFunction.apply(v.clone()).backward()
-            self.assertEqual(v.grad, torch.full(shape, 2))
+            self.assertEqual(v.grad, torch.full(shape, 2.))

     def test_legacy_function_deprecation_exception(self):
         # Trigger exception

test/test_multiprocessing.py (+1 −1)

@@ -406,7 +406,7 @@ def test_cuda_memory_allocation(self):
         for _ in range(5):
             t.append(q.get())
         # TODO(#38095): Replace assertEqualIgnoreType. See issue #38095
-        self.assertEqualIgnoreType(t[0], torch.full([5], 0))
+        self.assertEqualIgnoreType(t[0], torch.full([5], 0.))
         del t
         e.set()
         p.join(1)

test/test_namedtensor.py (+3 −3)

@@ -459,8 +459,8 @@ def _test(factory, device):
         # Test torch.full
         for device in torch.testing.get_all_device_types():
             names = ('N', 'T', 'D')
-            result = torch.full([1, 2, 3], 2, names=names, device=device)
-            expected = torch.full([1, 2, 3], 2, device=device).rename_(*names)
+            result = torch.full([1, 2, 3], 2., names=names, device=device)
+            expected = torch.full([1, 2, 3], 2., device=device).rename_(*names)
             self.assertTensorDataAndNamesEqual(result, expected)

     def test_tensor_from_lists(self):

@@ -1388,7 +1388,7 @@ def test_as_strided_cuda(self):

     def test_no_jit_tracer_support(self):
         def foo(x):
-            return torch.full(x.shape, 2, names=('N',))
+            return torch.full(x.shape, 2., names=('N',))

         with self.assertRaisesRegex(RuntimeError, 'not supported with the tracer'):
             x = torch.randn(3)

test/test_torch.py (+25 −20)

@@ -13089,9 +13089,9 @@ def test_reduction_empty(self, device):
         ('mode', torch.mode, None),
         ('median', torch.median, None),

-        ('prod', torch.prod, 1),
-        ('sum', torch.sum, 0),
-        ('norm', torch.norm, 0),
+        ('prod', torch.prod, 1.),
+        ('sum', torch.sum, 0.),
+        ('norm', torch.norm, 0.),
         ('mean', torch.mean, nan),
         ('var', torch.var, nan),
         ('std', torch.std, nan),

@@ -17303,25 +17303,26 @@ def test_min_mixed_devices(self, device):
         self.assertRaises(RuntimeError,
                           lambda: torch.min(a, 0, out=(values, indices)))

-    def test_full_deprecation_warning(self, device):
+    # NOTE: inferring the dtype from bool or integer fill values is
+    # disabled because the behavior is changing from PyTorch 1.5,
+    # where the default scalar type would be inferred, to PyTorch 1.7,
+    # where bool or long, respectively, will be inferred.
+    def test_full_unsupported_integer_inference(self, device):
         size = (2, 2)
         # Tests bool and integer fill_values deprecated without specific dtype set
-        with self.maybeWarnsRegex(UserWarning, 'Deprecation warning: .+'):
+        with self.assertRaisesRegex(RuntimeError, '.+is currently unsupported.+'):
             self.assertEqual(torch.full(size, True).dtype, torch.float)
-        with self.maybeWarnsRegex(UserWarning, 'Deprecation warning: .+'):
+        with self.assertRaisesRegex(RuntimeError, '.+is currently unsupported.+'):
             self.assertEqual(torch.full(size, 1).dtype, torch.float)

         # Explicitly setting the dtype doesn't warn
-        with self.maybeWarnsRegex(UserWarning, ''):
-            self.assertEqual(torch.full(size, 1, dtype=torch.long).dtype, torch.long)
-        with self.maybeWarnsRegex(UserWarning, ''):
-            self.assertEqual(torch.full(size, True, dtype=torch.bool).dtype,
-                             torch.bool)
+        self.assertEqual(torch.full(size, 1, dtype=torch.long).dtype, torch.long)
+        self.assertEqual(torch.full(size, True, dtype=torch.bool).dtype, torch.bool)

         # Performs same tests with named tensor
-        with self.maybeWarnsRegex(UserWarning, 'Deprecation warning: .+|Named tensors .+'):
+        with self.assertRaisesRegex(RuntimeError, '.+is currently unsupported.+'):
             self.assertEqual(torch.full(size, True, names=('a', 'b')).dtype, torch.float)
-        with self.maybeWarnsRegex(UserWarning, 'Deprecation warning: .+|Named tensors .+'):
+        with self.assertRaisesRegex(RuntimeError, '.+is currently unsupported.+'):
             self.assertEqual(torch.full(size, 1, names=('a', 'b')).dtype, torch.float)

         with self.maybeWarnsRegex(UserWarning, 'Named tensors .+'):

@@ -17339,15 +17340,17 @@ def test_full_inference(self, device, dtype):
         prev_default = torch.get_default_dtype()
         torch.set_default_dtype(dtype)

-        # Tests bool fill value inference
+        # Tests bool fill value inference (currently unsupported)
         # Note: in the future this will return a tensor of torch.bool dtype
-        t = torch.full(size, True)
-        self.assertEqual(t.dtype, dtype)
+        with self.assertRaisesRegex(RuntimeError, '.+is currently unsupported.+'):
+            t = torch.full(size, True)
+            self.assertEqual(t.dtype, dtype)

-        # Tests integer fill value inference
+        # Tests integer fill value inference (currently unsupported)
         # Note: in the future this will return a tensor of torch.long dtype
-        t = torch.full(size, 1)
-        self.assertEqual(t.dtype, dtype)
+        with self.assertRaisesRegex(RuntimeError, '.+is currently unsupported.+'):
+            t = torch.full(size, 1)
+            self.assertEqual(t.dtype, dtype)

         # Tests float fill value inference
         t = torch.full(size, 1.)

@@ -17372,14 +17375,16 @@ def test_full_like_inference(self, device):
                          torch.complex64)

     def test_full_out(self, device):
-        o = torch.empty((5,), device=device, dtype=torch.long)
+        size = (5,)
+        o = torch.empty(size, device=device, dtype=torch.long)

         # verifies dtype/out conflict throws a RuntimeError
         with self.assertRaises(RuntimeError):
             torch.full(o.shape, 1., dtype=torch.float, out=o)

         # verifies out dtype overrides inference
         self.assertEqual(torch.full(o.shape, 1., out=o).dtype, o.dtype)
+        self.assertEqual(torch.full(size, 1, out=o).dtype, o.dtype)

     def _float_to_int_conversion_helper(self, vals, device, dtype):
         assert TEST_NUMPY
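Note the last added assertion in test_full_out: an `out` tensor fixes the dtype even for an integer fill value, so no inference happens and no error is raised. For example:

import torch

o = torch.empty((5,), dtype=torch.long)
# An integer fill value is fine when `out` supplies the dtype.
t = torch.full((5,), 1, out=o)
print(t.dtype)  # torch.int64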

torch/_torch_docs.py (+3 −3)

@@ -6958,9 +6958,9 @@ def merge_dicts(*dicts):
 Returns a tensor of size :attr:`size` filled with :attr:`fill_value`.

 .. warning::
-    In PyTorch 1.5 a bool or integral :attr:`fill_value` will produce a warning if
-    :attr:`dtype` or :attr:`out` are not set.
-    In a future PyTorch release, when :attr:`dtype` and :attr:`out` are not set
+    Providing a bool or integral :attr:`fill_value` without setting
+    the optional :attr:`dtype` or :attr:`out` arguments is currently unsupported.
+    In PyTorch 1.7, when :attr:`dtype` and :attr:`out` are not set
     a bool :attr:`fill_value` will return a tensor of torch.bool dtype,
     and an integral :attr:`fill_value` will return a tensor of torch.long dtype.

torch/csrc/jit/frontend/builtin_functions.cpp (+24)

@@ -127,6 +127,28 @@ def div__0_3(self: Tensor, other: number) -> Tensor:
   return self.floor_divide_(other)
 )SCRIPT";

+// NOTE: torch.full would historically infer a float dtype for bool and
+// integral fill values.
+// NOTE: Torchscript does not currently support complex values
+// NOTE: Torchscript does not currently support named tensors, although
+// torch.full does have a named tensor variant
+auto full = R"SCRIPT(
+def full_0_4(size:List[int], fill_value:number, *, dtype:Optional[int]=None,
+             layout:Optional[int]=None, device:Optional[Device]=None,
+             pin_memory:Optional[bool]=None) -> Tensor:
+  if dtype is None:
+    fill_value = float(fill_value)
+
+  return torch.full(size, fill_value, dtype=dtype, layout=layout, device=device, pin_memory=pin_memory)
+)SCRIPT";
+
+// NOTE: the out variant of full works the same, but must be overridden
+// since the other variant of full is overridden
+auto full_out = R"SCRIPT(
+def full_0_4(size:List[int], fill_value:number, *, out:Tensor) -> Tensor:
+  return torch.full(size, fill_value, out=out)
+)SCRIPT";
+
 struct BuiltinFunctionRegistry {
   const std::vector<Function*>& getAllBuiltinFunctionsFor(Symbol name) {
     const static std::vector<Function*> empty;

@@ -204,6 +226,8 @@ struct BuiltinFunctionRegistry {
     loadSource(div__tensor, "upgraders");
     loadSource(div_tensor_out, "upgraders");
     loadSource(div__scalar, "upgraders");
+    loadSource(full, "upgraders");
+    loadSource(full_out, "upgraders");

     // These are under `prim` instead of `aten` since they exist to bind certain
     // tensor property getters to correpsonding methods
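The upgrader itself is TorchScript; its effect is simply to reinstate the old float inference for programs serialized before version 5. An equivalent eager-Python sketch of full_0_4's logic (a paraphrase, not the registered upgrader):

import torch
from typing import List, Optional

def full_0_4(size: List[int], fill_value, *, dtype: Optional[torch.dtype] = None):
    # Historic rule: with no explicit dtype, bool/int fill values were
    # promoted to float so the default dtype was inferred.
    if dtype is None:
        fill_value = float(fill_value)
    return torch.full(size, fill_value, dtype=dtype)

print(full_0_4([2, 2], 1).dtype)                    # default float dtype
print(full_0_4([2, 2], 1, dtype=torch.long).dtype)  # torch.int64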

torch/csrc/jit/frontend/versioned_symbols.cpp (+2)

@@ -72,6 +72,8 @@ static std::unordered_map<Symbol, SymbolRange> symbol_range_map({
      {0, 3, Symbol::fromQualString("upgraders::div_0_3")}},
     {Symbol::fromQualString("aten::div_"),
      {0, 3, Symbol::fromQualString("upgraders::div__0_3")}},
+    {Symbol::fromQualString("aten::full"),
+     {0, 4, Symbol::fromQualString("upgraders::full_0_4")}},
 });

 Symbol get_symbol_for_version(const Symbol name, const uint64_t version) {
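The new map entry reads as "for file format versions 0 through 4, replace aten::full with upgraders::full_0_4". A conceptual Python rendering of that lookup (the real implementation is the C++ above):

symbol_range_map = {
    # name: (start_version, end_version, upgrader)
    "aten::full": (0, 4, "upgraders::full_0_4"),
}

def get_symbol_for_version(name: str, version: int) -> str:
    entry = symbol_range_map.get(name)
    if entry is not None:
        start, end, upgrader = entry
        if start <= version <= end:
            return upgrader
    return name

print(get_symbol_for_version("aten::full", 4))  # upgraders::full_0_4
print(get_symbol_for_version("aten::full", 5))  # aten::full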

torch/testing/_internal/common_utils.py (+1 −1)

@@ -1609,7 +1609,7 @@ def get_int64_dtype(dtype):

         default_dtype = torch.get_default_dtype()
         check_value(torch.empty(shape), default_dtype, torch.strided, -1, None, False)
-        check_value(torch.full(shape, -5), default_dtype, torch.strided, -1, None, False)
+        check_value(torch.full(shape, -5.), default_dtype, torch.strided, -1, None, False)
         for dtype in dtypes:
             for rg in {dtype.is_floating_point, False}:
                 int64_dtype = get_int64_dtype(dtype)
