diff --git a/test/prototype/test_smoothquant.py b/test/prototype/test_smoothquant.py index d90990143c..aed1f6fcd8 100644 --- a/test/prototype/test_smoothquant.py +++ b/test/prototype/test_smoothquant.py @@ -5,11 +5,11 @@ import torch from torchao.prototype.smoothquant import ( + SmoothQuantConfig, SmoothQuantObservedLinear, insert_smooth_quant_observer_, load_smooth_quant_recipe, save_smooth_quant_recipe, - smooth_quant, ) from torchao.quantization import quantize_ from torchao.quantization.utils import ( @@ -85,7 +85,7 @@ def forward(self, x): m(data) # quantize is_observed_linear = lambda m, fqn: isinstance(m, SmoothQuantObservedLinear) - quantize_(m, smooth_quant(), is_observed_linear) + quantize_(m, SmoothQuantConfig(), is_observed_linear) with torch.inference_mode(): if TORCH_VERSION_AT_LEAST_2_5: m = torch.compile(m, fullgraph=True) @@ -173,7 +173,7 @@ def test_save_load_recipe(alpha, quant_mode, device, idtype): # quantize is_observed_linear = lambda m, fqn: isinstance(m, SmoothQuantObservedLinear) - quantize_(m, smooth_quant(), is_observed_linear) + quantize_(m, SmoothQuantConfig(), is_observed_linear) if TORCH_VERSION_AT_LEAST_2_5: # earlier versions are not compatible m = torch.compile(m, fullgraph=True) diff --git a/torchao/prototype/smoothquant/README.md b/torchao/prototype/smoothquant/README.md index fa64fc4460..c268a83504 100644 --- a/torchao/prototype/smoothquant/README.md +++ b/torchao/prototype/smoothquant/README.md @@ -27,7 +27,7 @@ python example.py -m MODLE_ID --device= --quant-mode= Dict[str, torch.Tensor]: @@ -121,7 +105,14 @@ def recurse(module: torch.nn.Module, name: str = ""): # act_scales is None for dynamic quantization if any(x is None for x in (smoothing_factor, wei_scales)): return module - return smooth_quant(smoothing_factor, act_scales, wei_scales)(module) + is_observed_linear = lambda m, fqn: isinstance(m, SmoothQuantObservedLinear) + wrapper = torch.nn.Sequential(module) + quantize_( + wrapper, + 
SmoothQuantConfig(smoothing_factor, act_scales, wei_scales), + is_observed_linear, + ) + return wrapper[0] mod_new = module @@ -158,13 +149,10 @@ def static_quantize(self, input, scale, zero_point): ) -def smooth_quant( - smoothing_factor: Optional[torch.Tensor] = None, - act_scales: Optional[torch.Tensor] = None, - wei_scales: Optional[torch.Tensor] = None, -): +@dataclass +class SmoothQuantConfig(AOBaseConfig): """ - Quantizes linear layers when passed into quantize_() + Configuration for quantizing linear layers when passed into quantize_() Args: smoothing_factor: The smoothing factor for the layer. Acquired from the layer's observer if None. @@ -172,40 +160,62 @@ def smooth_quant( wei_scales: The weight scales for the layer. Acquired from the layer's observer if None. """ - def quantize_weight(observed_linear): - target_dtype = torch.int8 - # act_scales is None for dynamic quantization thus not checked - if any(x is None for x in (smoothing_factor, wei_scales)): - factor, x_scale, w_scales = observed_linear.obs.calculate_qparams() - weight = observed_linear.obs.weight * factor - else: - factor, x_scale, w_scales = smoothing_factor, act_scales, wei_scales - weight = observed_linear.weight * factor - weight = weight.to(observed_linear.weight.dtype) - block_size = (1, weight.size(1)) - wei_zero_points = torch.zeros_like(w_scales, dtype=torch.int64) - qw = to_affine_quantized_intx_static( - weight, - w_scales, - wei_zero_points, - block_size, - target_dtype, - ) + smoothing_factor: Optional[torch.Tensor] = None + act_scales: Optional[torch.Tensor] = None + wei_scales: Optional[torch.Tensor] = None - if x_scale is None: - # dynamic quant - qw = to_linear_activation_quantized( - qw, _ActQuantizer(target_dtype).dynamic_quantize - ) - else: - # static quant - x_zero_point = torch.zeros_like(x_scale, dtype=torch.int64) - qw = to_weight_tensor_with_linear_activation_quantization_metadata( - qw, _ActQuantizer(target_dtype).static_quantize, x_scale, x_zero_point - ) - 
return to_weight_tensor_with_linear_activation_scale_metadata( - qw, factor.to(qw.dtype) +@register_quantize_module_handler(SmoothQuantConfig) +def _smooth_quant_transform( + module: torch.nn.Module, + config: SmoothQuantConfig, +): + smoothing_factor = config.smoothing_factor + act_scales = config.act_scales + wei_scales = config.wei_scales + observed_linear = module + + linear = torch.nn.Linear( + observed_linear.in_features, + observed_linear.out_features, + observed_linear.bias is not None, + device=observed_linear.weight.device, + dtype=observed_linear.weight.dtype, + ) + linear.bias = observed_linear.bias + + target_dtype = torch.int8 + # act_scales is None for dynamic quantization thus not checked + if any(x is None for x in (smoothing_factor, wei_scales)): + factor, x_scale, w_scales = observed_linear.obs.calculate_qparams() + weight = observed_linear.obs.weight * factor + else: + factor, x_scale, w_scales = smoothing_factor, act_scales, wei_scales + weight = observed_linear.weight * factor + weight = weight.to(observed_linear.weight.dtype) + block_size = (1, weight.size(1)) + wei_zero_points = torch.zeros_like(w_scales, dtype=torch.int64) + qw = to_affine_quantized_intx_static( + weight, + w_scales, + wei_zero_points, + block_size, + target_dtype, + ) + + if x_scale is None: + # dynamic quant + qw = to_linear_activation_quantized( + qw, _ActQuantizer(target_dtype).dynamic_quantize + ) + else: + # static quant + x_zero_point = torch.zeros_like(x_scale, dtype=torch.int64) + qw = to_weight_tensor_with_linear_activation_quantization_metadata( + qw, _ActQuantizer(target_dtype).static_quantize, x_scale, x_zero_point ) - return _observed_linear_subclass_inserter(quantize_weight) + qw = to_weight_tensor_with_linear_activation_scale_metadata(qw, factor.to(qw.dtype)) + linear.weight = torch.nn.Parameter(qw, requires_grad=False) + linear.extra_repr = types.MethodType(_linear_extra_repr, linear) + return linear diff --git a/torchao/prototype/smoothquant/example.py 
b/torchao/prototype/smoothquant/example.py index 0075502595..788201d3fe 100644 --- a/torchao/prototype/smoothquant/example.py +++ b/torchao/prototype/smoothquant/example.py @@ -9,9 +9,9 @@ from transformers import AutoModelForCausalLM, AutoTokenizer from torchao.prototype.smoothquant import ( + SmoothQuantConfig, SmoothQuantObservedLinear, insert_smooth_quant_observer_, - smooth_quant, ) from torchao.quantization import quantize_ @@ -145,7 +145,7 @@ def wikitext2_ppl( is_observed_linear = lambda m, fqn: isinstance(m, SmoothQuantObservedLinear) print(f"running SmoothQuant with {quant_mode} quantization") t0 = time.time() - quantize_(model, smooth_quant(), is_observed_linear) + quantize_(model, SmoothQuantConfig(), is_observed_linear) print(f"time for quantization: {time.time() - t0:.02f} seconds") if model_save_path is not None: print(f"Saving quantized model to {model_save_path}")