Skip to content
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.

Commit 1d01fcd

Browse files
Zafar authored and facebook-github-bot committed on Sep 9, 2020
[quant] fill_ path for quantized tensors (pytorch#43303)
Summary: Pull Request resolved: pytorch#43303 Test Plan: Imported from OSS Reviewed By: raghuramank100 Differential Revision: D23231947 Pulled By: z-a-f fbshipit-source-id: fd5110ff15a073f326ef590436f8c6e5a2608324
1 parent 4aacfab commit 1d01fcd

File tree

2 files changed

+36
-0
lines changed

2 files changed

+36
-0
lines changed
 

‎aten/src/ATen/native/Fill.cpp

+7
Original file line numberDiff line numberDiff line change
@@ -19,6 +19,13 @@ namespace {
1919
} // namespace
2020

2121
Tensor& fill_out(Tensor& self, Scalar value) {
22+
if (self.is_quantized()) {
23+
at::Tensor out = at::ones(self.sizes()).to(kFloat) * value;
24+
out = out.to(self.device());
25+
// Trust the `copy_` to handle the quantization and the boundary checks.
26+
self.copy_(out);
27+
return self;
28+
}
2229
// When filling a number to 1-element CPU tensor, we want to skip
2330
// everything but manipulate data ptr directly.
2431
// Ideally this fast pass should be implemented in TensorIterator,

‎test/quantization/test_quantized_tensor.py

+29
Original file line numberDiff line numberDiff line change
@@ -484,6 +484,35 @@ def test_qtensor_clone(self):
484484
# Check to make sure the scale and zero_point has been copied.
485485
self.assertEqual(q, q2)
486486

487+
def test_qtensor_fill(self):
488+
numel = 10
489+
scale = 0.5
490+
zero_point = 10
491+
492+
ones = torch.ones(numel).to(torch.float)
493+
494+
types = [torch.qint8, torch.quint8, torch.qint32]
495+
fills = [-1, 1, 2**32] # positive, negative, overflow
496+
497+
# `fill_` uses `copy_(float)`, which doesn't support CUDA
498+
device = 'cpu'
499+
ones = ones.to(device)
500+
for qtype, fill_with in itertools.product(types, fills):
501+
q_filled = torch._empty_affine_quantized(
502+
[numel], scale=scale, zero_point=zero_point, device=device,
503+
dtype=qtype)
504+
q_filled.fill_(fill_with)
505+
int_repr = torch.quantize_per_tensor(ones * fill_with, scale,
506+
zero_point, qtype)
507+
fill_with = int_repr.dequantize()
508+
int_repr = int_repr.int_repr()
509+
510+
self.assertEqual(q_filled.int_repr(), int_repr)
511+
self.assertEqual(q_filled.dequantize(), fill_with)
512+
# Make sure the scale and zero_point don't change
513+
self.assertEqual(q_filled.q_scale(), scale)
514+
self.assertEqual(q_filled.q_zero_point(), zero_point)
515+
487516
def test_qtensor_view(self):
488517
scale, zero_point, dtype = 1.0, 2, torch.uint8
489518
for device in get_supported_device_types():

0 commit comments

Comments
 (0)
Please sign in to comment.