22
22
from torch .testing ._internal .common_quantization import skipIfNoFBGEMM
23
23
from torch .testing ._internal .common_quantized import _quantize , _dequantize , _calculate_dynamic_qparams , \
24
24
override_quantized_engine , supported_qengines , override_qengines
25
+ from torch .quantization import PerChannelMinMaxObserver
25
26
26
27
np_dtype = {
27
28
torch .quint8 : np .uint8 ,
@@ -2716,7 +2717,7 @@ def test_qlinear_unpack(self, W, use_channelwise):
2716
2717
2717
2718
2718
2719
@unittest .skipIf (sys .platform == "darwin" , "Known test failure on Mac." )
2719
- class TestQuantizedEmbeddingBag (TestCase ):
2720
+ class TestQuantizedEmbeddingOps (TestCase ):
2720
2721
def _test_embedding_bag_unpack_fn (self , pack_fn , unpack_fn , num_embeddings , embedding_dim , bit_rate ):
2721
2722
weights = torch .from_numpy ((np .random .random_sample ((
2722
2723
num_embeddings , embedding_dim )) + 1 ).astype (np .float32 ))
@@ -2727,7 +2728,6 @@ def _test_embedding_bag_unpack_fn(self, pack_fn, unpack_fn, num_embeddings, embe
2727
2728
if bit_rate == 8 :
2728
2729
# Check numerics of prepack function that accepts qtensor as input.
2729
2730
# We use min-max observer to mimic the quantization performed in the original function.
2730
- from torch .quantization import PerChannelMinMaxObserver
2731
2731
obs = PerChannelMinMaxObserver (dtype = torch .quint8 , qscheme = torch .per_channel_affine_float_qparams , ch_axis = 0 )
2732
2732
obs (weights )
2733
2733
# Get the scale and zero point for the weight tensor
@@ -2884,7 +2884,6 @@ def get_reference_result(
2884
2884
2885
2885
if bit_rate == 8 :
2886
2886
# Test operator that accepts TorchBind packed weights.
2887
- from torch .quantization import PerChannelMinMaxObserver
2888
2887
obs = PerChannelMinMaxObserver (dtype = torch .quint8 , qscheme = torch .per_channel_affine_float_qparams , ch_axis = 0 )
2889
2888
obs (weights )
2890
2889
# Get the scale and zero point for the weight tensor
@@ -2931,6 +2930,37 @@ def test_embedding_bag_4bit_rowwise_offsets(self, num_embeddings,
2931
2930
include_last_offset , atol = 0.1 ,
2932
2931
rtol = 1e-2 )
2933
2932
2933
+ """ Tests the correctness of the quantized embedding lookup operator """
2934
+ @given (num_embeddings = st .integers (10 , 100 ),
2935
+ embedding_dim = st .integers (5 , 50 ).filter (lambda x : x % 4 == 0 ))
2936
+ def test_embedding_byte (self , num_embeddings , embedding_dim ):
2937
+ quant_op = torch .ops .quantized .embedding_byte
2938
+ prepack_op = torch .ops .quantized .embedding_bag_prepack
2939
+
2940
+ weights = torch .from_numpy ((np .random .random_sample ((
2941
+ num_embeddings , embedding_dim )) + 1 ).astype (np .float32 ))
2942
+
2943
+ obs = PerChannelMinMaxObserver (dtype = torch .quint8 , qscheme = torch .per_channel_affine_float_qparams , ch_axis = 0 )
2944
+ obs (weights )
2945
+ # Get the scale and zero point for the weight tensor
2946
+ qparams = obs .calculate_qparams ()
2947
+
2948
+ # Quantize the weights to 8bits
2949
+ qweight = torch .quantize_per_channel (weights , qparams [0 ], qparams [1 ], axis = 0 , dtype = torch .quint8 )
2950
+ max_segments = 5
2951
+ max_segment_length = 20
2952
+ num_lengths = np .random .randint (1 , max_segments + 1 )
2953
+ lengths = np .random .randint (1 , max_segment_length + 1 ,
2954
+ size = num_lengths ).astype (np .int32 )
2955
+ num_indices = np .sum (lengths )
2956
+ indices = torch .from_numpy (np .random .randint (
2957
+ low = 0 , high = num_embeddings , size = num_indices , dtype = np .int64 ))
2958
+
2959
+ packed_weight = prepack_op (qweight )
2960
+ qresult = quant_op (packed_weight , indices , sparse = False )
2961
+
2962
+ ref = torch .embedding (weights , indices , padding_idx = - 1 , scale_grad_by_freq = False , sparse = False )
2963
+ torch .testing .assert_allclose (ref , qresult , atol = 0.005 , rtol = 1e-3 )
2934
2964
2935
2965
class TestQuantizedConv (unittest .TestCase ):
2936
2966
def _test_qconv_unpack_impl (self , qconv_prepack_fn , qconv_unpack_fn , inputs ,
0 commit comments