migrates prototype/mixed_precision to configs (#1854)
* Update

[ghstack-poisoned]

* Update

[ghstack-poisoned]

* Update

[ghstack-poisoned]

* Update

[ghstack-poisoned]

* Update

[ghstack-poisoned]

* Update

[ghstack-poisoned]
vkuzo authored Mar 8, 2025
1 parent 49694e3 commit bc4f51d
Showing 2 changed files with 44 additions and 17 deletions.
test/prototype/test_mixed_precision.py (2 changes: 1 addition & 1 deletion)
@@ -13,7 +13,7 @@
 class TestWeightOnlyQuantNaive(unittest.TestCase):
     def test_quantization_intNwo(self):
         # skip test int4wo for now since it is under development in torchao
-        for quantization_bit in [2, 3, 5, 6, 8]:
+        for quantization_bit in [2, 3, 5, 6]:
             for symmetric in [False, True]:
                 with self.subTest(
                     quantization_bit=quantization_bit, symmetric=symmetric
Second changed file (43 additions & 16 deletions)
@@ -1,15 +1,20 @@
+from dataclasses import dataclass
+
 import torch
 
-from torchao.quantization import int4_weight_only, int8_weight_only
-from torchao.quantization.quant_api import _get_linear_subclass_inserter
+from torchao.core.config import AOBaseConfig
 from torchao.quantization.quant_primitives import (
     MappingType,
 )
+from torchao.quantization.transform_module import (
+    register_quantize_module_handler,
+)
 
 
-def intN_weight_only(group_size=32, n=8, symmetric=False):
+@dataclass
+class IntNWeightOnlyConfig(AOBaseConfig):
     """
-    Apply int N-bit weight only quantization to a linear layer.
+    Configuration for applying int N-bit weight only quantization to a linear layer.
     Args:
         `group_size`: parameter for quantization, controls the granularity of quantization, smaller size is more fine grained, choices are [512, 256, 128, 64, 32]
         `n`: number of bits to quantize to, choices are [8, 6, 5, 4, 3, 2]
@@ -18,6 +23,25 @@ def intN_weight_only(group_size=32, n=8, symmetric=False):
         quantize_(model, intN_weight_only(n=your_bit_choice, group_size=group_size), optional_filter_func_for_desired_layers_to_quantize)
     """
 
+    group_size: int = 32
+    n: int = 8
+    symmetric: bool = False
+
+
+# for bc
+intN_weight_only = IntNWeightOnlyConfig
+
+
+@register_quantize_module_handler(IntNWeightOnlyConfig)
+def _intN_weight_only_transform(
+    module: torch.nn.Module,
+    config: IntNWeightOnlyConfig,
+) -> torch.nn.Module:
+    group_size = config.group_size
+    n = config.n
+    symmetric = config.symmetric
+    weight = module.weight
+
     # for asymmetric quantization
     def apply_intN_weight_only_quant_asym(weight):
         # avoid circular dependency
@@ -64,16 +88,19 @@ def apply_intN_weight_only_quant_sym(weight):
             zero_point_dtype=zero_point_dtype,
         )
 
-    try:
-        assert n in [8, 6, 5, 4, 3, 2], "n must be one of [8, 6, 5, 4, 3, 2]"
-        if n == 8:
-            return int8_weight_only()
-        elif n == 4:
-            return int4_weight_only(group_size=group_size)
-        else:
-            if symmetric:
-                return _get_linear_subclass_inserter(apply_intN_weight_only_quant_sym)
-            else:
-                return _get_linear_subclass_inserter(apply_intN_weight_only_quant_asym)
-    except Exception:
-        raise
+    assert n in [8, 6, 5, 4, 3, 2], "n must be one of [8, 6, 5, 4, 3, 2]"
+    if n == 8:
+        raise AssertionError(
+            "Someone needs to refactor this code to handle int8_weight_only again"
+        )
+    elif n == 4:
+        raise AssertionError(
+            "Someone needs to refactor this code to handle int4_weight_only again"
+        )
+    else:
+        if symmetric:
+            new_weight = apply_intN_weight_only_quant_sym(weight)
+        else:
+            new_weight = apply_intN_weight_only_quant_asym(weight)
+    module.weight = torch.nn.Parameter(new_weight, requires_grad=False)
+    return module
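
For orientation, here is a minimal usage sketch of the migrated config, following the quantize_ call already quoted in the docstring above. It is not part of this commit: the toy model, the bit choice, and the import path are illustrative assumptions, and it relies on quantize_ dispatching AOBaseConfig instances to their registered transform, which is what the register_quantize_module_handler registration above provides.

import torch
from torchao.quantization import quantize_

# Assumed import path for the prototype module changed in this diff; the real
# path is not shown on this page, so adjust it to your checkout.
from torchao.prototype.quantization.mixed_precision.scripts.naive_intNwo import (
    intN_weight_only,  # backward-compatible alias for IntNWeightOnlyConfig
)

# Toy model used only for illustration.
model = torch.nn.Sequential(torch.nn.Linear(128, 128))

# Constructing a config (instead of calling a factory that returns an inserter)
# is the point of the migration; quantize_ routes the config to
# _intN_weight_only_transform via the handler registry.
quantize_(model, intN_weight_only(n=3, group_size=32, symmetric=False))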

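More broadly, the config plus registered handler split used in this commit follows a simple registry dispatch pattern: a decorator records which function knows how to transform a module for a given config class, and the quantization entry point looks the handler up by config type. The self-contained toy below illustrates that pattern only; it is not torchao's implementation, and every name in it (register_handler, _HANDLERS, ToyIntNConfig, toy_quantize_) is invented for the sketch.

from dataclasses import dataclass

import torch

_HANDLERS = {}  # maps config class -> module transform function


def register_handler(config_cls):
    # decorator that records fn as the handler for config_cls
    def decorator(fn):
        _HANDLERS[config_cls] = fn
        return fn

    return decorator


@dataclass
class ToyIntNConfig:
    group_size: int = 32
    n: int = 8
    symmetric: bool = False


@register_handler(ToyIntNConfig)
def _toy_intN_transform(module: torch.nn.Module, config: ToyIntNConfig) -> torch.nn.Module:
    # a real handler would replace module.weight with an n-bit quantized tensor;
    # here we only tag the module so the dispatch is observable
    module.toy_quantized_bits = config.n
    return module


def toy_quantize_(model: torch.nn.Module, config) -> None:
    # apply the handler registered for this config type to every Linear layer
    for module in model.modules():
        if isinstance(module, torch.nn.Linear):
            _HANDLERS[type(config)](module, config)


model = torch.nn.Sequential(torch.nn.Linear(16, 16))
toy_quantize_(model, ToyIntNConfig(n=3))
assert model[0].toy_quantized_bits == 3

Keeping the transform separate from the config is what lets the old intN_weight_only name survive as a plain alias for the config class while the actual quantization logic moves behind the registry.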