@@ -99,6 +99,7 @@ def __init__(
99
99
per_channel = False ,
100
100
reduce_range = False ,
101
101
use_external_data_format = False ,
102
+ calibration_providers = None ,
102
103
extra_options = None ,
103
104
):
104
105
"""
@@ -112,6 +113,8 @@ def __init__(
112
113
quant_format: QuantFormat{QOperator, QDQ}.
113
114
QOperator format quantizes the model with quantized operators directly.
114
115
QDQ format quantize the model by inserting QuantizeLinear/DeQuantizeLinear on the tensor.
116
+ calibration_providers: Execution providers to run the session during calibration. Default is None which uses
117
+ [ "CPUExecutionProvider" ].
115
118
extra_options:
116
119
key value pair dictionary for various options in different case. Current used:
117
120
extra.Sigmoid.nnapi = True/False (Default is False)
@@ -219,6 +222,7 @@ def __init__(
219
222
self .calibration_data_reader = calibration_data_reader
220
223
self .calibrate_method = calibrate_method
221
224
self .quant_format = quant_format
225
+ self .calibration_providers = calibration_providers
222
226
self .extra_options = extra_options or {}
223
227
224
228
@@ -473,6 +477,7 @@ def quantize_static(
473
477
nodes_to_exclude = None ,
474
478
use_external_data_format = False ,
475
479
calibrate_method = CalibrationMethod .MinMax ,
480
+ calibration_providers = None ,
476
481
extra_options = None ,
477
482
):
478
483
"""
@@ -520,6 +525,8 @@ def quantize_static(
520
525
List of nodes names to exclude. The nodes in this list will be excluded from quantization
521
526
when it is not None.
522
527
use_external_data_format: option used for large size (>2GB) model. Set to False by default.
528
+ calibration_providers: Execution providers to run the session during calibration. Default is None which uses
529
+ [ "CPUExecutionProvider" ]
523
530
extra_options:
524
531
key value pair dictionary for various options in different case. Current used:
525
532
extra.Sigmoid.nnapi = True/False (Default is False)
@@ -697,6 +704,7 @@ def inc_dataloader():
697
704
augmented_model_path = Path (quant_tmp_dir ).joinpath ("augmented_model.onnx" ).as_posix (),
698
705
calibrate_method = calibrate_method ,
699
706
use_external_data_format = use_external_data_format ,
707
+ providers = calibration_providers ,
700
708
extra_options = calib_extra_options ,
701
709
)
702
710
@@ -890,6 +898,7 @@ def quantize(
890
898
per_channel = quant_config .per_channel ,
891
899
reduce_range = quant_config .reduce_range ,
892
900
use_external_data_format = quant_config .use_external_data_format ,
901
+ calibration_providers = quant_config .calibration_providers ,
893
902
extra_options = quant_config .extra_options ,
894
903
)
895
904
0 commit comments