config migration: smoothquant #1851

Merged · 4 commits · Mar 8, 2025

6 changes: 3 additions & 3 deletions test/prototype/test_smoothquant.py
@@ -5,11 +5,11 @@
 import torch

 from torchao.prototype.smoothquant import (
+    SmoothQuantConfig,
     SmoothQuantObservedLinear,
     insert_smooth_quant_observer_,
     load_smooth_quant_recipe,
     save_smooth_quant_recipe,
-    smooth_quant,
 )
 from torchao.quantization import quantize_
 from torchao.quantization.utils import (
@@ -85,7 +85,7 @@ def forward(self, x):
     m(data)
     # quantize
     is_observed_linear = lambda m, fqn: isinstance(m, SmoothQuantObservedLinear)
-    quantize_(m, smooth_quant(), is_observed_linear)
+    quantize_(m, SmoothQuantConfig(), is_observed_linear)
     with torch.inference_mode():
         if TORCH_VERSION_AT_LEAST_2_5:
             m = torch.compile(m, fullgraph=True)
@@ -173,7 +173,7 @@ def test_save_load_recipe(alpha, quant_mode, device, idtype):

     # quantize
     is_observed_linear = lambda m, fqn: isinstance(m, SmoothQuantObservedLinear)
-    quantize_(m, smooth_quant(), is_observed_linear)
+    quantize_(m, SmoothQuantConfig(), is_observed_linear)
     if TORCH_VERSION_AT_LEAST_2_5:
         # earlier versions are not compatible
         m = torch.compile(m, fullgraph=True)
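The tests above only show the changed call sites. For readers migrating their own code, here is a minimal end-to-end sketch of the flow these tests exercise after this PR; the toy model and calibration data are hypothetical, but the API calls follow the diff:

```python
import torch

from torchao.prototype.smoothquant import (
    SmoothQuantConfig,
    SmoothQuantObservedLinear,
    insert_smooth_quant_observer_,
)
from torchao.quantization import quantize_

# Hypothetical stand-in for the test's model.
model = torch.nn.Sequential(
    torch.nn.Linear(64, 64),
    torch.nn.Linear(64, 64),
).eval()

# 1. Replace each nn.Linear with an observed linear that records activation stats.
insert_smooth_quant_observer_(model, alpha=0.5, quant_mode="dynamic")

# 2. Calibrate: plain forward passes feed the observers.
for _ in range(8):
    model(torch.randn(16, 64))

# 3. Quantize the observed layers; SmoothQuantConfig() replaces the old smooth_quant() call.
is_observed_linear = lambda m, fqn: isinstance(m, SmoothQuantObservedLinear)
quantize_(model, SmoothQuantConfig(), is_observed_linear)
```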
6 changes: 3 additions & 3 deletions torchao/prototype/smoothquant/README.md
@@ -27,7 +27,7 @@ python example.py -m MODLE_ID --device=<cuda or cpu> --quant-mode=<dynamic or st
 ## Usage of API
 The following APIs are provided:
 - insert_smooth_quant_observer_
-- smooth_quant
+- SmoothQuantConfig
 - save_smooth_quant_recipe (advanced)
 - load_smooth_quant_recipe (advanced)

@@ -37,11 +37,11 @@ insert_smooth_quant_observer_(model, alpha=0.5, quant_mode="dynamic")
 ```
 After insertion, run the model for calibration on a certain dataset or (advanced) load a recipe.

-`smooth_quant` applies SmoothQuant to each linear layer of the model. Use it by calling `torchao.quantization.quantize_`. For example:
+`SmoothQuantConfig` configures applying SmoothQuant to each linear layer of the model. Use it by calling `torchao.quantization.quantize_`. For example:
 ```python
 from torchao.prototype.smoothquant import SmoothQuantObservedLinear
 is_observed_linear = lambda m, fqn: isinstance(m, SmoothQuantObservedLinear)
-torchao.quantization.quantize_(model, smooth_quant(), is_observed_linear)
+torchao.quantization.quantize_(model, SmoothQuantConfig(), is_observed_linear)
 ```
 `is_observed_linear` is a filter so that we only quantize observed linear layers.

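The advanced recipe APIs listed above are not demonstrated in this README. A plausible usage sketch, assuming `load_smooth_quant_recipe` mirrors the `save_smooth_quant_recipe(model, save_path)` signature visible in api.py below; the file name and toy model are hypothetical:

```python
import torch

from torchao.prototype.smoothquant import (
    insert_smooth_quant_observer_,
    load_smooth_quant_recipe,
    save_smooth_quant_recipe,
)

# Observe and calibrate a model as described above.
model = torch.nn.Sequential(torch.nn.Linear(64, 64)).eval()
insert_smooth_quant_observer_(model, alpha=0.5, quant_mode="static")
model(torch.randn(4, 64))

# Persist the per-layer smoothing factors and scales.
save_smooth_quant_recipe(model, "smooth_quant_recipe.pt")

# On a fresh observed copy, load the recipe instead of re-calibrating.
# (Assumed signature; see load_smooth_quant_recipe in api.py.)
fresh = torch.nn.Sequential(torch.nn.Linear(64, 64)).eval()
insert_smooth_quant_observer_(fresh, alpha=0.5, quant_mode="static")
load_smooth_quant_recipe(fresh, "smooth_quant_recipe.pt")
```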
4 changes: 2 additions & 2 deletions torchao/prototype/smoothquant/__init__.py
@@ -1,15 +1,15 @@
 from .api import (
+    SmoothQuantConfig,
     insert_smooth_quant_observer_,
     load_smooth_quant_recipe,
     save_smooth_quant_recipe,
-    smooth_quant,
 )
 from .core import SmoothQuantObservedLinear

 __all__ = [
     "insert_smooth_quant_observer_",
     "load_smooth_quant_recipe",
     "save_smooth_quant_recipe",
-    "smooth_quant",
+    "SmoothQuantConfig",
     "SmoothQuantObservedLinear",
 ]
144 changes: 77 additions & 67 deletions torchao/prototype/smoothquant/api.py
@@ -1,20 +1,30 @@
+import types
+from dataclasses import dataclass
 from typing import Dict, Optional

 import torch

+from torchao.core.config import AOBaseConfig
 from torchao.dtypes import to_affine_quantized_intx, to_affine_quantized_intx_static
 from torchao.prototype.smoothquant.core import (
     SmoothQuantObservedLinear,
     SmoothQuantObserver,
 )
+from torchao.quantization import quantize_
 from torchao.quantization.linear_activation_quantized_tensor import (
     to_linear_activation_quantized,
 )
 from torchao.quantization.linear_activation_scale import (
     to_weight_tensor_with_linear_activation_scale_metadata,
 )
-from torchao.quantization.quant_api import _replace_with_custom_fn_if_matches_filter
+from torchao.quantization.quant_api import (
+    _linear_extra_repr,
+    _replace_with_custom_fn_if_matches_filter,
+)
 from torchao.quantization.quant_primitives import MappingType
+from torchao.quantization.transform_module import (
+    register_quantize_module_handler,
+)
 from torchao.quantization.utils import _get_per_token_block_size
 from torchao.quantization.weight_tensor_linear_activation_quantization import (
     to_weight_tensor_with_linear_activation_quantization_metadata,
@@ -53,32 +63,6 @@ def replace_with_observer(layer):
     _replace_with_custom_fn_if_matches_filter(model, replace_with_observer, _is_linear)


-def _observed_linear_subclass_inserter(constructor):
-    """
-    Replaces unquantized observed linear instances with quantized linear instances.
-
-    Args:
-        constructor: the function which applies quantization to the observed linear layer
-    """
-
-    def insert_subclass(observed_linear):
-        # creates the new linear layer using constructor
-        linear = torch.nn.Linear(
-            observed_linear.in_features,
-            observed_linear.out_features,
-            observed_linear.bias is not None,
-            device=observed_linear.weight.device,
-            dtype=observed_linear.weight.dtype,
-        )
-        linear.weight = torch.nn.Parameter(
-            constructor(observed_linear), requires_grad=False
-        )
-        linear.bias = observed_linear.bias
-        return linear
-
-    return insert_subclass
-
-
 def save_smooth_quant_recipe(
     model: torch.nn.Module, save_path: str
 ) -> Dict[str, torch.Tensor]:
@@ -121,7 +105,14 @@ def recurse(module: torch.nn.Module, name: str = ""):
             # act_scales is None for dynamic quantization
             if any(x is None for x in (smoothing_factor, wei_scales)):
                 return module
-            return smooth_quant(smoothing_factor, act_scales, wei_scales)(module)
+            is_observed_linear = lambda m, fqn: isinstance(m, SmoothQuantObservedLinear)
+            wrapper = torch.nn.Sequential(module)
+            quantize_(
+                wrapper,
+                SmoothQuantConfig(smoothing_factor, act_scales, wei_scales),
+                is_observed_linear,
+            )
+            return wrapper[0]

         mod_new = module

@@ -158,54 +149,73 @@ def static_quantize(self, input, scale, zero_point):
         )


-def smooth_quant(
-    smoothing_factor: Optional[torch.Tensor] = None,
-    act_scales: Optional[torch.Tensor] = None,
-    wei_scales: Optional[torch.Tensor] = None,
-):
+@dataclass
+class SmoothQuantConfig(AOBaseConfig):
     """
-    Quantizes linear layers when passed into quantize_()
+    Configuration for quantizing linear layers when passed into quantize_()

     Args:
         smoothing_factor: The smoothing factor for the layer. Acquired from the layer's observer if None.
         act_scales: The activation scales for the layer. Acquired from the layer's observer if None.
         wei_scales: The weight scales for the layer. Acquired from the layer's observer if None.
     """

-    def quantize_weight(observed_linear):
-        target_dtype = torch.int8
-        # act_scales is None for dynamic quantization thus not checked
-        if any(x is None for x in (smoothing_factor, wei_scales)):
-            factor, x_scale, w_scales = observed_linear.obs.calculate_qparams()
-            weight = observed_linear.obs.weight * factor
-        else:
-            factor, x_scale, w_scales = smoothing_factor, act_scales, wei_scales
-            weight = observed_linear.weight * factor
-        weight = weight.to(observed_linear.weight.dtype)
-        block_size = (1, weight.size(1))
-        wei_zero_points = torch.zeros_like(w_scales, dtype=torch.int64)
-        qw = to_affine_quantized_intx_static(
-            weight,
-            w_scales,
-            wei_zero_points,
-            block_size,
-            target_dtype,
-        )
+    smoothing_factor: Optional[torch.Tensor] = None
+    act_scales: Optional[torch.Tensor] = None
+    wei_scales: Optional[torch.Tensor] = None

-        if x_scale is None:
-            # dynamic quant
-            qw = to_linear_activation_quantized(
-                qw, _ActQuantizer(target_dtype).dynamic_quantize
-            )
-        else:
-            # static quant
-            x_zero_point = torch.zeros_like(x_scale, dtype=torch.int64)
-            qw = to_weight_tensor_with_linear_activation_quantization_metadata(
-                qw, _ActQuantizer(target_dtype).static_quantize, x_scale, x_zero_point
-            )

-        return to_weight_tensor_with_linear_activation_scale_metadata(
-            qw, factor.to(qw.dtype)
-        )

-    return _observed_linear_subclass_inserter(quantize_weight)
+@register_quantize_module_handler(SmoothQuantConfig)
+def _smooth_quant_transform(
+    module: torch.nn.Module,
+    config: SmoothQuantConfig,
+):
+    smoothing_factor = config.smoothing_factor
+    act_scales = config.act_scales
+    wei_scales = config.wei_scales
+    observed_linear = module
+
+    linear = torch.nn.Linear(
+        observed_linear.in_features,
+        observed_linear.out_features,
+        observed_linear.bias is not None,
+        device=observed_linear.weight.device,
+        dtype=observed_linear.weight.dtype,
+    )
+    linear.bias = observed_linear.bias
+
+    target_dtype = torch.int8
+    # act_scales is None for dynamic quantization thus not checked
+    if any(x is None for x in (smoothing_factor, wei_scales)):
+        factor, x_scale, w_scales = observed_linear.obs.calculate_qparams()
+        weight = observed_linear.obs.weight * factor
+    else:
+        factor, x_scale, w_scales = smoothing_factor, act_scales, wei_scales
+        weight = observed_linear.weight * factor
+    weight = weight.to(observed_linear.weight.dtype)
+    block_size = (1, weight.size(1))
+    wei_zero_points = torch.zeros_like(w_scales, dtype=torch.int64)
+    qw = to_affine_quantized_intx_static(
+        weight,
+        w_scales,
+        wei_zero_points,
+        block_size,
+        target_dtype,
+    )
+
+    if x_scale is None:
+        # dynamic quant
+        qw = to_linear_activation_quantized(
+            qw, _ActQuantizer(target_dtype).dynamic_quantize
+        )
+    else:
+        # static quant
+        x_zero_point = torch.zeros_like(x_scale, dtype=torch.int64)
+        qw = to_weight_tensor_with_linear_activation_quantization_metadata(
+            qw, _ActQuantizer(target_dtype).static_quantize, x_scale, x_zero_point
+        )
+
+    qw = to_weight_tensor_with_linear_activation_scale_metadata(qw, factor.to(qw.dtype))
+    linear.weight = torch.nn.Parameter(qw, requires_grad=False)
+    linear.extra_repr = types.MethodType(_linear_extra_repr, module)
+    return linear
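This hunk is the core of the config migration: `smooth_quant()` used to return a closure that rewrote each observed linear, whereas `SmoothQuantConfig` is now a plain `AOBaseConfig` dataclass whose transform is looked up by type via `register_quantize_module_handler` (note also how `load_smooth_quant_recipe` now reuses `quantize_` by wrapping a single module in `nn.Sequential`). Below is a minimal sketch of the dispatch pattern, assuming `quantize_` resolves handlers by config type as the imports above suggest; `ToyScaleConfig` and `_toy_scale_transform` are hypothetical names, not part of this PR:

```python
from dataclasses import dataclass

import torch

from torchao.core.config import AOBaseConfig
from torchao.quantization import quantize_
from torchao.quantization.transform_module import register_quantize_module_handler


@dataclass
class ToyScaleConfig(AOBaseConfig):
    """Illustrative config; holds the parameters the transform needs."""

    scale: float = 1.0


@register_quantize_module_handler(ToyScaleConfig)
def _toy_scale_transform(module: torch.nn.Module, config: ToyScaleConfig) -> torch.nn.Module:
    # quantize_ finds this handler via the config's type and calls it on every
    # module accepted by the filter; rescaling the weight stands in for real
    # quantization logic here.
    module.weight = torch.nn.Parameter(module.weight * config.scale, requires_grad=False)
    return module


model = torch.nn.Sequential(torch.nn.Linear(8, 8))
quantize_(model, ToyScaleConfig(scale=0.5), lambda m, fqn: isinstance(m, torch.nn.Linear))
```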
4 changes: 2 additions & 2 deletions torchao/prototype/smoothquant/example.py
@@ -9,9 +9,9 @@
 from transformers import AutoModelForCausalLM, AutoTokenizer

 from torchao.prototype.smoothquant import (
+    SmoothQuantConfig,
     SmoothQuantObservedLinear,
     insert_smooth_quant_observer_,
-    smooth_quant,
 )
 from torchao.quantization import quantize_

@@ -145,7 +145,7 @@ def wikitext2_ppl(
     is_observed_linear = lambda m, fqn: isinstance(m, SmoothQuantObservedLinear)
     print(f"running SmoothQuant with {quant_mode} quantization")
     t0 = time.time()
-    quantize_(model, smooth_quant(), is_observed_linear)
+    quantize_(model, SmoothQuantConfig(), is_observed_linear)
     print(f"time for quantization: {time.time() - t0:.02f} seconds")
     if model_save_path is not None:
         print(f"Saving quantized model to {model_save_path}")