diff --git a/test/prototype/test_mixed_precision.py b/test/prototype/test_mixed_precision.py
index b921860821..3af56c7bf3 100644
--- a/test/prototype/test_mixed_precision.py
+++ b/test/prototype/test_mixed_precision.py
@@ -13,7 +13,7 @@ class TestWeightOnlyQuantNaive(unittest.TestCase):
     def test_quantization_intNwo(self):
         # skip test int4wo for now since it is under development in torchao
-        for quantization_bit in [2, 3, 5, 6, 8]:
+        for quantization_bit in [2, 3, 5, 6]:
             for symmetric in [False, True]:
                 with self.subTest(
                     quantization_bit=quantization_bit, symmetric=symmetric
diff --git a/torchao/prototype/quantization/mixed_precision/scripts/naive_intNwo.py b/torchao/prototype/quantization/mixed_precision/scripts/naive_intNwo.py
index 44680b0fb1..76f8230f30 100644
--- a/torchao/prototype/quantization/mixed_precision/scripts/naive_intNwo.py
+++ b/torchao/prototype/quantization/mixed_precision/scripts/naive_intNwo.py
@@ -1,15 +1,20 @@
+from dataclasses import dataclass
+
 import torch
 
-from torchao.quantization import int4_weight_only, int8_weight_only
-from torchao.quantization.quant_api import _get_linear_subclass_inserter
+from torchao.core.config import AOBaseConfig
 from torchao.quantization.quant_primitives import (
     MappingType,
 )
+from torchao.quantization.transform_module import (
+    register_quantize_module_handler,
+)
 
 
-def intN_weight_only(group_size=32, n=8, symmetric=False):
+@dataclass
+class IntNWeightOnlyConfig(AOBaseConfig):
     """
-    Apply int N-bit weight only quantization to a linear layer.
+    Configuration for applying int N-bit weight only quantization to a linear layer.
     Args:
         `group_size`: parameter for quantization, controls the granularity of quantization, smaller size is more fine grained, choices are [512, 256, 128, 64, 32]
         `n`: number of bits to quantize to, choices are [8, 6, 5, 4, 3, 2]
@@ -18,6 +23,25 @@ def intN_weight_only(group_size=32, n=8, symmetric=False):
         quantize_(model, intN_weight_only(n=your_bit_choice, group_size=group_size), optional_filter_func_for_desired_layers_to_quantize)
     """
 
+    group_size: int = 32
+    n: int = 8
+    symmetric: bool = False
+
+
+# for bc
+intN_weight_only = IntNWeightOnlyConfig
+
+
+@register_quantize_module_handler(IntNWeightOnlyConfig)
+def _intN_weight_only_transform(
+    module: torch.nn.Module,
+    config: IntNWeightOnlyConfig,
+) -> torch.nn.Module:
+    group_size = config.group_size
+    n = config.n
+    symmetric = config.symmetric
+    weight = module.weight
+
     # for asymmetric quantization
     def apply_intN_weight_only_quant_asym(weight):
         # avoid circular dependency
@@ -64,16 +88,19 @@ def apply_intN_weight_only_quant_sym(weight):
             zero_point_dtype=zero_point_dtype,
         )
 
-    try:
-        assert n in [8, 6, 5, 4, 3, 2], "n must be one of [8, 6, 5, 4, 3, 2]"
-        if n == 8:
-            return int8_weight_only()
-        elif n == 4:
-            return int4_weight_only(group_size=group_size)
+    assert n in [8, 6, 5, 4, 3, 2], "n must be one of [8, 6, 5, 4, 3, 2]"
+    if n == 8:
+        raise AssertionError(
+            "Someone needs to refactor this code to handle int8_weight_only again"
+        )
+    elif n == 4:
+        raise AssertionError(
+            "Someone needs to refactor this code to handle int4_weight_only again"
+        )
+    else:
+        if symmetric:
+            new_weight = apply_intN_weight_only_quant_sym(weight)
         else:
-            if symmetric:
-                return _get_linear_subclass_inserter(apply_intN_weight_only_quant_sym)
-            else:
-                return _get_linear_subclass_inserter(apply_intN_weight_only_quant_asym)
-    except Exception:
-        raise
+            new_weight = apply_intN_weight_only_quant_asym(weight)
+    module.weight = torch.nn.Parameter(new_weight, requires_grad=False)
+    return module
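
A minimal usage sketch of the new config-based API (assuming a torchao build that includes this change; the toy model below is purely illustrative):

    import torch

    from torchao.prototype.quantization.mixed_precision.scripts.naive_intNwo import (
        IntNWeightOnlyConfig,
    )
    from torchao.quantization import quantize_

    # A bf16 linear stack standing in for a real model; in/out features are
    # divisible by group_size so groupwise quantization applies cleanly.
    model = torch.nn.Sequential(torch.nn.Linear(64, 64)).to(torch.bfloat16)

    # The config instance replaces the old intN_weight_only(...) call: quantize_
    # dispatches to _intN_weight_only_transform through the handler registered
    # for IntNWeightOnlyConfig and swaps each Linear weight in place.
    # n=6 avoids the n == 8 / n == 4 paths, which now raise AssertionError.
    quantize_(model, IntNWeightOnlyConfig(n=6, group_size=32, symmetric=False))

Because of the `intN_weight_only = IntNWeightOnlyConfig` alias, existing call sites that pass `intN_weight_only(n=..., group_size=...)` to `quantize_` keep working unchanged.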