Commit d1fb58b
Quantization tool: Allow user to override calibrator's session EP (#23559)
### Description

The quantization calibrators have `execution_providers` attributes, but there is no way for a user to provide their own providers when using the `quantize` or `quantize_static` functions. This PR adds a `calibration_providers` parameter so users can specify the execution providers to use during calibration. It is helpful when quantizing large models that are slow to calibrate on the CPU.

- Chose `calibration_providers` as the name because the docstrings already mention an `execution_provider` option
  https://github.com/microsoft/onnxruntime/blob/169917b1e7f69daa687a5448526c189d1f7a4e2b/onnxruntime/python/tools/quantization/quantize.py#L204
  https://github.com/microsoft/onnxruntime/blob/169917b1e7f69daa687a5448526c189d1f7a4e2b/onnxruntime/python/tools/quantization/quantize.py#L415
  that is not present anywhere in the code.
- The name can be changed if needed (e.g. `calibrator_providers`), and/or the parameter could be made a string instead of a list of providers.
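A minimal usage sketch of the new parameter (not part of the PR itself): the model paths, input name, and data reader below are placeholders, while the call shape follows the updated `quantize_static` signature.

```python
import numpy as np
from onnxruntime.quantization import CalibrationDataReader, QuantFormat, quantize_static


class DummyDataReader(CalibrationDataReader):
    """Placeholder reader; a real one yields representative model inputs."""

    def __init__(self, num_samples=8):
        # "input" and the shape are hypothetical; match your model's actual inputs.
        self._data = iter(
            [{"input": np.random.rand(1, 3, 224, 224).astype(np.float32)} for _ in range(num_samples)]
        )

    def get_next(self):
        return next(self._data, None)


quantize_static(
    "model_fp32.onnx",
    "model_int8.onnx",
    DummyDataReader(),
    quant_format=QuantFormat.QDQ,
    # New parameter: run the calibration session on these EPs instead of the default CPU EP.
    calibration_providers=["CUDAExecutionProvider", "CPUExecutionProvider"],
)
```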
1 parent 649ced4 commit d1fb58b

File tree

3 files changed: +17 −1 lines changed

onnxruntime/python/tools/quantization/calibrate.py

+4 −1

@@ -380,7 +380,7 @@ def add_reduce_min_max(tensor_name, reduce_op_name):
         else:
             raise ValueError(
                 f"Unable to guess tensor type for tensor {tensor_name!r}, "
-                f"running shape inference before quantization may resolve this issue."
+                "running shape inference before quantization may resolve this issue."
             )

         # Include axes in reduce_op when per_channel, always keeping axis=1
@@ -1177,6 +1177,7 @@ def create_calibrator(
     augmented_model_path="augmented_model.onnx",
     calibrate_method=CalibrationMethod.MinMax,
     use_external_data_format=False,
+    providers=None,
     extra_options={},  # noqa: B006
 ):
     calibrator = None
@@ -1243,6 +1244,8 @@ def create_calibrator(

     if calibrator:
         calibrator.augment_graph()
+        if providers:
+            calibrator.execution_providers = providers
         calibrator.create_inference_session()
         return calibrator
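For reference, a hedged sketch of calling `create_calibrator` directly with the new `providers` argument; the model path is a placeholder and the leading positional parameters are assumed from the existing signature.

```python
from onnxruntime.quantization.calibrate import CalibrationMethod, create_calibrator

# "model_fp32.onnx" is a placeholder path.
calibrator = create_calibrator(
    "model_fp32.onnx",
    op_types_to_calibrate=None,  # None means calibrate all supported op types
    augmented_model_path="augmented_model.onnx",
    calibrate_method=CalibrationMethod.MinMax,
    providers=["CUDAExecutionProvider", "CPUExecutionProvider"],
)
# Per the change above, the providers override replaces calibrator.execution_providers
# before create_inference_session() is called inside create_calibrator().
```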

onnxruntime/python/tools/quantization/execution_providers/qnn/quant_config.py

+4

@@ -53,6 +53,7 @@ def get_qnn_qdq_config(
     weight_symmetric: bool | None = None,
     keep_removable_activations: bool = False,
     stride: int | None = None,
+    calibration_providers: list[str] | None = None,
 ) -> StaticQuantConfig:
     """
     Returns a static quantization configuration suitable for running QDQ models on QNN EP.
@@ -117,6 +118,8 @@ def get_qnn_qdq_config(
             are automatically removed if activations are asymmetrically quantized. Keeping these activations
             is necessary if optimizations or EP transformations will later remove
             QuantizeLinear/DequantizeLinear operators from the model.
+        calibration_providers: Execution providers to run the session during calibration. Default is None which uses
+            [ "CPUExecutionProvider" ].

     Returns:
         A StaticQuantConfig object
@@ -192,6 +195,7 @@ def get_qnn_qdq_config(
         op_types_to_quantize=list(op_types.difference(OP_TYPES_TO_EXCLUDE)),
         per_channel=per_channel,
         use_external_data_format=(model_has_external_data or model.ByteSize() >= MODEL_SIZE_THRESHOLD),
+        calibration_providers=calibration_providers,
         extra_options=extra_options,
     )
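A hedged usage sketch of the QNN config path; `data_reader`, the model paths, and the chosen activation/weight types are placeholders for illustration only.

```python
from onnxruntime.quantization import QuantType
from onnxruntime.quantization.execution_providers.qnn import get_qnn_qdq_config
from onnxruntime.quantization.quantize import quantize

data_reader = ...  # placeholder: a CalibrationDataReader yielding representative inputs

qnn_config = get_qnn_qdq_config(
    "model_fp32.onnx",
    data_reader,
    activation_type=QuantType.QUInt16,
    weight_type=QuantType.QUInt8,
    # Forwarded to StaticQuantConfig so the calibration session runs on these EPs.
    calibration_providers=["CUDAExecutionProvider", "CPUExecutionProvider"],
)
quantize("model_fp32.onnx", "model_qdq.onnx", qnn_config)
```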

onnxruntime/python/tools/quantization/quantize.py

+9

@@ -99,6 +99,7 @@ def __init__(
         per_channel=False,
         reduce_range=False,
         use_external_data_format=False,
+        calibration_providers=None,
         extra_options=None,
     ):
         """
@@ -112,6 +113,8 @@ def __init__(
             quant_format: QuantFormat{QOperator, QDQ}.
                 QOperator format quantizes the model with quantized operators directly.
                 QDQ format quantize the model by inserting QuantizeLinear/DeQuantizeLinear on the tensor.
+            calibration_providers: Execution providers to run the session during calibration. Default is None which uses
+                [ "CPUExecutionProvider" ].
             extra_options:
                 key value pair dictionary for various options in different case. Current used:
                     extra.Sigmoid.nnapi = True/False  (Default is False)
@@ -219,6 +222,7 @@ def __init__(
         self.calibration_data_reader = calibration_data_reader
         self.calibrate_method = calibrate_method
         self.quant_format = quant_format
+        self.calibration_providers = calibration_providers
         self.extra_options = extra_options or {}


@@ -473,6 +477,7 @@ def quantize_static(
     nodes_to_exclude=None,
     use_external_data_format=False,
     calibrate_method=CalibrationMethod.MinMax,
+    calibration_providers=None,
     extra_options=None,
 ):
     """
@@ -520,6 +525,8 @@ def quantize_static(
             List of nodes names to exclude. The nodes in this list will be excluded from quantization
             when it is not None.
         use_external_data_format: option used for large size (>2GB) model. Set to False by default.
+        calibration_providers: Execution providers to run the session during calibration. Default is None which uses
+            [ "CPUExecutionProvider" ]
         extra_options:
             key value pair dictionary for various options in different case. Current used:
                 extra.Sigmoid.nnapi = True/False  (Default is False)
@@ -697,6 +704,7 @@ def inc_dataloader():
         augmented_model_path=Path(quant_tmp_dir).joinpath("augmented_model.onnx").as_posix(),
         calibrate_method=calibrate_method,
         use_external_data_format=use_external_data_format,
+        providers=calibration_providers,
         extra_options=calib_extra_options,
     )

@@ -890,6 +898,7 @@ def quantize(
         per_channel=quant_config.per_channel,
         reduce_range=quant_config.reduce_range,
         use_external_data_format=quant_config.use_external_data_format,
+        calibration_providers=quant_config.calibration_providers,
         extra_options=quant_config.extra_options,
     )
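The equivalent flow through `StaticQuantConfig` plus `quantize()`, which is the path patched above; again, the paths and data reader are placeholders, not values from the PR.

```python
from onnxruntime.quantization import QuantFormat, QuantType
from onnxruntime.quantization.quantize import StaticQuantConfig, quantize

data_reader = ...  # placeholder: a CalibrationDataReader instance

config = StaticQuantConfig(
    data_reader,
    quant_format=QuantFormat.QDQ,
    activation_type=QuantType.QUInt8,
    weight_type=QuantType.QInt8,
    # Stored as self.calibration_providers and passed by quantize() into quantize_static(),
    # which hands it to create_calibrator(providers=...).
    calibration_providers=["CUDAExecutionProvider", "CPUExecutionProvider"],
)
quantize("model_fp32.onnx", "model_int8.onnx", config)
```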
