
Commit 6013a29

supriyar authored and facebook-github-bot committed on Sep 9, 2020
[quant] Support quantization of embedding lookup operators (pytorch#44207)
Summary: Pull Request resolved: pytorch#44207

Use existing embedding_bag operator but set offsets to [0, 1, .. len(indices)]

Test Plan: python test/test_quantization.py TestEmbeddingOps.test_embedding_byte

Imported from OSS

Reviewed By: vkuzo

Differential Revision: D23547385

fbshipit-source-id: ccce348bc192c6a4a65a8eca4c8b90f99f40f1b1
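The whole change rides on one observation: a plain embedding lookup is an embedding_bag in which every bag contains exactly one index, so passing offsets = [0, 1, ..., len(indices) - 1] reuses the existing bag kernel unchanged. A minimal sketch of that equivalence at the float level, using the public functional API for illustration (the commit itself works on the quantized kernels):

import torch
import torch.nn.functional as F

weight = torch.randn(10, 4)
indices = torch.tensor([3, 1, 7, 7, 0])
# One single-element bag per index.
offsets = torch.arange(indices.numel())

bagged = F.embedding_bag(indices, weight, offsets, mode='sum')
plain = F.embedding(indices, weight)
assert torch.allclose(bagged, plain)  # summing a one-row bag yields the row itself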
1 parent f27be2f commit 6013a29

File tree

4 files changed: +66 −8 lines
 

aten/src/ATen/native/quantized/cpu/qembeddingbag.cpp

+31 −4

@@ -17,8 +17,9 @@ at::Tensor PackedEmbeddingBagWeight::embeddingbag_byte(
     bool sparse,
     const c10::optional<at::Tensor>& per_sample_weights_,
     bool include_last_offset) {
-  TORCH_CHECK(offsets_in.has_value(), "embedding_bag_byte_rowwise_offsets expects offsets to be set");
+  TORCH_CHECK(
+      offsets_in.has_value(),
+      "embedding_bag_byte_rowwise_offsets expects offsets to be set");
   auto offsets = offsets_in.value();
   auto offsets_data = offsets.data_ptr<int64_t>();
   const auto indices_data = indices.data_ptr<int64_t>();
@@ -123,7 +124,9 @@ Tensor embedding_bag_byte_rowwise_offsets(
     bool include_last_offset) {
   TORCH_CHECK(weight.scalar_type() == at::kByte);
   TORCH_CHECK(weight.ndimension() == 2);
-  TORCH_CHECK(offsets_in.has_value(), "embedding_bag_byte_rowwise_offsets expects offsets to be set");
+  TORCH_CHECK(
+      offsets_in.has_value(),
+      "embedding_bag_byte_rowwise_offsets expects offsets to be set");

   auto offsets = offsets_in.value();
   auto offsets_data = offsets.data_ptr<int64_t>();
@@ -221,7 +224,9 @@ Tensor embedding_bag_4bit_rowwise_offsets(
     const c10::optional<Tensor>& per_sample_weights_,
     const c10::optional<Tensor>& compressed_indices_mapping,
     bool include_last_offset) {
-  TORCH_CHECK(offsets_in.has_value(), "embedding_bag_4bit_rowwise_offsets expects offsets to be set");
+  TORCH_CHECK(
+      offsets_in.has_value(),
+      "embedding_bag_4bit_rowwise_offsets expects offsets to be set");

   TORCH_CHECK(weight.ndimension() == 2);
   TORCH_CHECK(indices.ndimension() == 1);
@@ -423,9 +428,31 @@ class QEmbeddingBag final {
   }
 };

+template <int bit_rate>
+class QEmbedding final {
+ public:
+  static at::Tensor run(
+      const c10::intrusive_ptr<EmbeddingPackedParamsBase>& packed_weight,
+      const Tensor& indices,
+      bool sparse) {
+    const auto offsets_size = indices.numel();
+    at::Tensor offsets = at::arange(0, offsets_size, at::kLong);
+    at::Tensor output;
+    if (bit_rate == 8) {
+      return packed_weight->embeddingbag_byte(
+          indices, offsets, sparse, c10::nullopt, false);
+    } else {
+      TORCH_INTERNAL_ASSERT(
+          "Currently only support 8-bit embedding quantization");
+    }
+    return output;
+  }
+};
+
 TORCH_LIBRARY_IMPL(quantized, CPU, m) {
   // Function that works on TorchBind packed weights.
   m.impl("embedding_bag_byte", TORCH_FN(QEmbeddingBag<8>::run));
+  m.impl("embedding_byte", TORCH_FN(QEmbedding<8>::run));

   // Functions that work on at::Tensor packed weight.
   m.impl(
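In Python terms, the new QEmbedding<8>::run above boils down to the delegation below. This is an illustrative sketch only (embedding_byte_sketch is a hypothetical name); the real implementation calls the TorchBind packed weight's embeddingbag_byte method directly rather than going back through the public op:

# Rough Python rendering of QEmbedding<8>::run (illustrative sketch).
def embedding_byte_sketch(packed_weight, indices, sparse=False):
    # Build one single-element bag per index: offsets = [0, 1, ..., N-1].
    offsets = torch.arange(0, indices.numel(), dtype=torch.long)
    # Delegate to the existing byte embedding_bag op; no per_sample_weights,
    # include_last_offset=False.
    return torch.ops.quantized.embedding_bag_byte(
        packed_weight, indices, offsets, sparse=sparse)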

aten/src/ATen/native/quantized/library.cpp

+1
@@ -110,6 +110,7 @@ TORCH_LIBRARY(quantized, m) {
   m.def("embedding_bag_byte_rowwise_offsets(Tensor weight, Tensor indices, Tensor? offsets=None, bool scale_grad_by_freq=False, int mode=0, bool sparse=False, Tensor? per_sample_weights=None, bool include_last_offset=False) -> Tensor");
   m.def("embedding_bag_4bit_rowwise_offsets(Tensor weight, Tensor indices, Tensor? offsets=None, bool scale_grad_by_freq=False, int mode=0, bool sparse=False, Tensor? per_sample_weights=None, Tensor? compressed_indices_mapping=None, bool include_last_offset=False) -> Tensor");
   m.def("embedding_bag_byte(__torch__.torch.classes.quantized.EmbeddingPackedParamsBase weight, Tensor indices, Tensor? offsets=None, bool scale_grad_by_freq=False, int mode=0, bool sparse=False, Tensor? per_sample_weights=None, Tensor? compressed_indices_mapping=None, bool include_last_offset=False) -> Tensor");
+  m.def("embedding_byte(__torch__.torch.classes.quantized.EmbeddingPackedParamsBase weight, Tensor indices, bool sparse=False) -> Tensor");
   m.def("celu(Tensor self, float output_scale, int output_zero_point, Scalar alpha=1) -> Tensor");
   m.def("hardswish(Tensor input, float output_scale, int output_zero_point) -> Tensor");
   m.def("group_norm(Tensor input, int num_groups, Tensor? weight, Tensor? bias, float eps, float output_scale, int output_zero_point) -> Tensor");

test/quantization/test_quantized_op.py

+33 −3

@@ -22,6 +22,7 @@
 from torch.testing._internal.common_quantization import skipIfNoFBGEMM
 from torch.testing._internal.common_quantized import _quantize, _dequantize, _calculate_dynamic_qparams, \
     override_quantized_engine, supported_qengines, override_qengines
+from torch.quantization import PerChannelMinMaxObserver

 np_dtype = {
     torch.quint8 : np.uint8,
@@ -2716,7 +2717,7 @@ def test_qlinear_unpack(self, W, use_channelwise):


 @unittest.skipIf(sys.platform == "darwin", "Known test failure on Mac.")
-class TestQuantizedEmbeddingBag(TestCase):
+class TestQuantizedEmbeddingOps(TestCase):
     def _test_embedding_bag_unpack_fn(self, pack_fn, unpack_fn, num_embeddings, embedding_dim, bit_rate):
         weights = torch.from_numpy((np.random.random_sample((
             num_embeddings, embedding_dim)) + 1).astype(np.float32))
@@ -2727,7 +2728,6 @@ def _test_embedding_bag_unpack_fn(self, pack_fn, unpack_fn, num_embeddings, embe
         if bit_rate == 8:
             # Check numerics of prepack function that accepts qtensor as input.
             # We use min-max observer to mimic the quantization performed in the original function.
-            from torch.quantization import PerChannelMinMaxObserver
             obs = PerChannelMinMaxObserver(dtype=torch.quint8, qscheme=torch.per_channel_affine_float_qparams, ch_axis=0)
             obs(weights)
             # Get the scale and zero point for the weight tensor
@@ -2884,7 +2884,6 @@ def get_reference_result(

         if bit_rate == 8:
             # Test operator that accepts TorchBind packed weights.
-            from torch.quantization import PerChannelMinMaxObserver
             obs = PerChannelMinMaxObserver(dtype=torch.quint8, qscheme=torch.per_channel_affine_float_qparams, ch_axis=0)
             obs(weights)
             # Get the scale and zero point for the weight tensor
@@ -2931,6 +2930,37 @@ def test_embedding_bag_4bit_rowwise_offsets(self, num_embeddings,
                                             include_last_offset, atol=0.1,
                                             rtol=1e-2)

+    """ Tests the correctness of the quantized embedding lookup operator """
+    @given(num_embeddings=st.integers(10, 100),
+           embedding_dim=st.integers(5, 50).filter(lambda x: x % 4 == 0))
+    def test_embedding_byte(self, num_embeddings, embedding_dim):
+        quant_op = torch.ops.quantized.embedding_byte
+        prepack_op = torch.ops.quantized.embedding_bag_prepack
+
+        weights = torch.from_numpy((np.random.random_sample((
+            num_embeddings, embedding_dim)) + 1).astype(np.float32))
+
+        obs = PerChannelMinMaxObserver(dtype=torch.quint8, qscheme=torch.per_channel_affine_float_qparams, ch_axis=0)
+        obs(weights)
+        # Get the scale and zero point for the weight tensor
+        qparams = obs.calculate_qparams()
+
+        # Quantize the weights to 8bits
+        qweight = torch.quantize_per_channel(weights, qparams[0], qparams[1], axis=0, dtype=torch.quint8)
+        max_segments = 5
+        max_segment_length = 20
+        num_lengths = np.random.randint(1, max_segments + 1)
+        lengths = np.random.randint(1, max_segment_length + 1,
+                                    size=num_lengths).astype(np.int32)
+        num_indices = np.sum(lengths)
+        indices = torch.from_numpy(np.random.randint(
+            low=0, high=num_embeddings, size=num_indices, dtype=np.int64))
+
+        packed_weight = prepack_op(qweight)
+        qresult = quant_op(packed_weight, indices, sparse=False)
+
+        ref = torch.embedding(weights, indices, padding_idx=-1, scale_grad_by_freq=False, sparse=False)
+        torch.testing.assert_allclose(ref, qresult, atol=0.005, rtol=1e-3)

 class TestQuantizedConv(unittest.TestCase):
     def _test_qconv_unpack_impl(self, qconv_prepack_fn, qconv_unpack_fn, inputs,
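A note on the test's tolerances: the weights are drawn from [1, 2), so each row's range is below 1 and the per-channel 8-bit scale stays under 1/255; the round-trip quantization error per element is then roughly scale/2, comfortably inside atol=0.005. A standalone sketch of that bound, using the same observer settings as the test (the (10, 16) shape is an arbitrary illustration):

import numpy as np
import torch
from torch.quantization import PerChannelMinMaxObserver

weights = torch.from_numpy(
    (np.random.random_sample((10, 16)) + 1).astype(np.float32))
obs = PerChannelMinMaxObserver(dtype=torch.quint8,
                               qscheme=torch.per_channel_affine_float_qparams,
                               ch_axis=0)
obs(weights)
scales, zero_points = obs.calculate_qparams()
qweight = torch.quantize_per_channel(
    weights, scales, zero_points, axis=0, dtype=torch.quint8)
# Max round-trip error should sit well under the test's atol of 0.005.
print((qweight.dequantize() - weights).abs().max())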

test/test_quantization.py

+1 −1

@@ -13,7 +13,7 @@
 from quantization.test_quantized_op import TestDynamicQuantizedLinear  # noqa: F401
 from quantization.test_quantized_op import TestComparatorOps  # noqa: F401
 from quantization.test_quantized_op import TestPadding  # noqa: F401
-from quantization.test_quantized_op import TestQuantizedEmbeddingBag  # noqa: F401
+from quantization.test_quantized_op import TestQuantizedEmbeddingOps  # noqa: F401

 # Quantized Functional
 from quantization.test_quantized_functional import TestQuantizedFunctional  # noqa: F401
