9
9
import itertools
10
10
import os
11
11
import uuid
12
+ from collections .abc import Sequence
12
13
from enum import Enum
13
14
from pathlib import Path
14
- from typing import Dict , Optional , Sequence , Tuple , Union
15
15
16
16
import numpy as np
17
17
import onnx
@@ -39,7 +39,7 @@ def rel_entr(pk: np.ndarray, qk: np.ndarray) -> np.ndarray:
39
39
def entropy (
40
40
pk : np .ndarray ,
41
41
qk : np .ndarray ,
42
- base : Optional [ float ] = None ,
42
+ base : float | None = None ,
43
43
axis : int = 0 ,
44
44
) -> np .ndarray :
45
45
"""
@@ -100,7 +100,7 @@ def to_dict(self):
100
100
101
101
102
102
class TensorsData :
103
- def __init__ (self , calibration_method , data : Dict [str , Union [ TensorData , Tuple ] ]):
103
+ def __init__ (self , calibration_method , data : dict [str , TensorData | tuple ]):
104
104
self .calibration_method = calibration_method
105
105
self .data = {}
106
106
for k , v in data .items ():
@@ -187,8 +187,8 @@ def set_range(self, start_index: int, end_index: int):
187
187
class CalibraterBase :
188
188
def __init__ (
189
189
self ,
190
- model_path : Union [ str , Path ] ,
191
- op_types_to_calibrate : Optional [ Sequence [str ]] = None ,
190
+ model_path : str | Path ,
191
+ op_types_to_calibrate : Sequence [str ] | None = None ,
192
192
augmented_model_path = "augmented_model.onnx" ,
193
193
symmetric = False ,
194
194
use_external_data_format = False ,
@@ -297,8 +297,8 @@ def compute_data(self) -> TensorsData:
297
297
class MinMaxCalibrater (CalibraterBase ):
298
298
def __init__ (
299
299
self ,
300
- model_path : Union [ str , Path ] ,
301
- op_types_to_calibrate : Optional [ Sequence [str ]] = None ,
300
+ model_path : str | Path ,
301
+ op_types_to_calibrate : Sequence [str ] | None = None ,
302
302
augmented_model_path = "augmented_model.onnx" ,
303
303
symmetric = False ,
304
304
use_external_data_format = False ,
@@ -476,7 +476,8 @@ def compute_data(self) -> TensorsData:
476
476
477
477
output_names = [self .infer_session .get_outputs ()[i ].name for i in range (len (self .intermediate_outputs [0 ]))]
478
478
output_dicts_list = [
479
- dict (zip (output_names , intermediate_output )) for intermediate_output in self .intermediate_outputs
479
+ dict (zip (output_names , intermediate_output , strict = False ))
480
+ for intermediate_output in self .intermediate_outputs
480
481
]
481
482
482
483
merged_output_dict = {}
@@ -507,7 +508,9 @@ def compute_data(self) -> TensorsData:
507
508
else :
508
509
pairs .append (tuple ([min_value_array , max_value_array ]))
509
510
510
- new_calibrate_tensors_range = TensorsData (CalibrationMethod .MinMax , dict (zip (calibrate_tensor_names , pairs )))
511
+ new_calibrate_tensors_range = TensorsData (
512
+ CalibrationMethod .MinMax , dict (zip (calibrate_tensor_names , pairs , strict = False ))
513
+ )
511
514
if self .calibrate_tensors_range :
512
515
self .calibrate_tensors_range = self .merge_range (self .calibrate_tensors_range , new_calibrate_tensors_range )
513
516
else :
@@ -519,8 +522,8 @@ def compute_data(self) -> TensorsData:
519
522
class HistogramCalibrater (CalibraterBase ):
520
523
def __init__ (
521
524
self ,
522
- model_path : Union [ str , Path ] ,
523
- op_types_to_calibrate : Optional [ Sequence [str ]] = None ,
525
+ model_path : str | Path ,
526
+ op_types_to_calibrate : Sequence [str ] | None = None ,
524
527
augmented_model_path = "augmented_model.onnx" ,
525
528
use_external_data_format = False ,
526
529
method = "percentile" ,
@@ -608,7 +611,8 @@ def collect_data(self, data_reader: CalibrationDataReader):
608
611
raise ValueError ("No data is collected." )
609
612
610
613
output_dicts_list = [
611
- dict (zip (output_names , intermediate_output )) for intermediate_output in self .intermediate_outputs
614
+ dict (zip (output_names , intermediate_output , strict = False ))
615
+ for intermediate_output in self .intermediate_outputs
612
616
]
613
617
614
618
merged_dict = {}
@@ -653,8 +657,8 @@ def compute_data(self) -> TensorsData:
653
657
class EntropyCalibrater (HistogramCalibrater ):
654
658
def __init__ (
655
659
self ,
656
- model_path : Union [ str , Path ] ,
657
- op_types_to_calibrate : Optional [ Sequence [str ]] = None ,
660
+ model_path : str | Path ,
661
+ op_types_to_calibrate : Sequence [str ] | None = None ,
658
662
augmented_model_path = "augmented_model.onnx" ,
659
663
use_external_data_format = False ,
660
664
method = "entropy" ,
@@ -687,8 +691,8 @@ def __init__(
687
691
class PercentileCalibrater (HistogramCalibrater ):
688
692
def __init__ (
689
693
self ,
690
- model_path : Union [ str , Path ] ,
691
- op_types_to_calibrate : Optional [ Sequence [str ]] = None ,
694
+ model_path : str | Path ,
695
+ op_types_to_calibrate : Sequence [str ] | None = None ,
692
696
augmented_model_path = "augmented_model.onnx" ,
693
697
use_external_data_format = False ,
694
698
method = "percentile" ,
@@ -721,8 +725,8 @@ def __init__(
721
725
class DistributionCalibrater (HistogramCalibrater ):
722
726
def __init__ (
723
727
self ,
724
- model_path : Union [ str , Path ] ,
725
- op_types_to_calibrate : Optional [ Sequence [str ]] = None ,
728
+ model_path : str | Path ,
729
+ op_types_to_calibrate : Sequence [str ] | None = None ,
726
730
augmented_model_path = "augmented_model.onnx" ,
727
731
use_external_data_format = False ,
728
732
method = "distribution" ,
@@ -1168,8 +1172,8 @@ def get_entropy_threshold(self, histogram, num_quantized_bins):
1168
1172
1169
1173
1170
1174
def create_calibrator (
1171
- model : Union [ str , Path ] ,
1172
- op_types_to_calibrate : Optional [ Sequence [str ]] = None ,
1175
+ model : str | Path ,
1176
+ op_types_to_calibrate : Sequence [str ] | None = None ,
1173
1177
augmented_model_path = "augmented_model.onnx" ,
1174
1178
calibrate_method = CalibrationMethod .MinMax ,
1175
1179
use_external_data_format = False ,
0 commit comments