migrates prototype/mixed_precision to configs #1854

Merged
merged 24 commits on Mar 8, 2025
test/prototype/test_mixed_precision.py (1 addition, 1 deletion)
@@ -13,7 +13,7 @@
 class TestWeightOnlyQuantNaive(unittest.TestCase):
     def test_quantization_intNwo(self):
         # skip test int4wo for now since it is under development in torchao
-        for quantization_bit in [2, 3, 5, 6, 8]:
+        for quantization_bit in [2, 3, 5, 6]:
             for symmetric in [False, True]:
                 with self.subTest(
                     quantization_bit=quantization_bit, symmetric=symmetric
@@ -1,15 +1,20 @@
+from dataclasses import dataclass
+
 import torch
 
-from torchao.quantization import int4_weight_only, int8_weight_only
-from torchao.quantization.quant_api import _get_linear_subclass_inserter
+from torchao.core.config import AOBaseConfig
 from torchao.quantization.quant_primitives import (
     MappingType,
 )
+from torchao.quantization.transform_module import (
+    register_quantize_module_handler,
+)
 
 
-def intN_weight_only(group_size=32, n=8, symmetric=False):
+@dataclass
+class IntNWeightOnlyConfig(AOBaseConfig):
     """
-    Apply int N-bit weight only quantization to a linear layer.
+    Configuration for applying int N-bit weight only quantization to a linear layer.
     Args:
         `group_size`: parameter for quantization, controls the granularity of quantization, smaller size is more fine grained, choices are [512, 256, 128, 64, 32]
         `n`: number of bits to quantize to, choices are [8, 6, 5, 4, 3, 2]
@@ -18,6 +23,25 @@ def intN_weight_only(group_size=32, n=8, symmetric=False):
         quantize_(model, intN_weight_only(n=your_bit_choice, group_size=group_size), optional_filter_func_for_desired_layers_to_quantize)
     """
 
+    group_size: int = 32
+    n: int = 8
+    symmetric: bool = False
+
+
+# for bc
+intN_weight_only = IntNWeightOnlyConfig
Contributor:
Imo these prototype folders we should just break BC

Contributor Author:

I agree, but it's actually easier to keep BC here (to not change callsites) - if someone wants to change the callsites of all these prototype features, that would sgtm in a separate PR.

Contributor:

I guess, but changing callsites within a repo is - I find - not that hard, and it ultimately helps create a more unified API as we move everything to configs + registrations

Contributor Author:

ok will do

Contributor Author:

@drisspg, I take it back. See how the scope of #1851 expanded from just removing the bc name. Some of these other prototypes have even more callsites throughout the codebase, some of those callsites aren't easily testable, etc - I just don't think it's worth the effort to tie removing old names to this PR stack, since we have to keep BC anyways for some of the APIs.
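Why the alias preserves BC, in a nutshell: a dataclass whose fields mirror the old function's keyword parameters is constructed with the exact same call syntax, so existing callsites keep working unchanged. A minimal self-contained sketch of that mechanism (the AOBaseConfig base is dropped here so the sketch runs without torchao; the example values are illustrative only):

    from dataclasses import dataclass

    # Stand-in for the config class in this diff, minus the AOBaseConfig base,
    # so the sketch runs anywhere.
    @dataclass
    class IntNWeightOnlyConfig:
        group_size: int = 32
        n: int = 8
        symmetric: bool = False

    # for bc: the old function name now points at the config class
    intN_weight_only = IntNWeightOnlyConfig

    # An old callsite such as intN_weight_only(n=6, group_size=64) still parses,
    # but now builds a config object instead of returning a transform closure.
    cfg = intN_weight_only(n=6, group_size=64)
    assert (cfg.n, cfg.group_size, cfg.symmetric) == (6, 64, False)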



+@register_quantize_module_handler(IntNWeightOnlyConfig)
+def _intN_weight_only_transform(
+    module: torch.nn.Module,
+    config: IntNWeightOnlyConfig,
+) -> torch.nn.Module:
+    group_size = config.group_size
+    n = config.n
+    symmetric = config.symmetric
+    weight = module.weight
 
     # for asymmetric quantization
     def apply_intN_weight_only_quant_asym(weight):
         # avoid circular dependency
@@ -64,16 +88,19 @@ def apply_intN_weight_only_quant_sym(weight):
         zero_point_dtype=zero_point_dtype,
     )
 
-    try:
-        assert n in [8, 6, 5, 4, 3, 2], "n must be one of [8, 6, 5, 4, 3, 2]"
-        if n == 8:
-            return int8_weight_only()
-        elif n == 4:
-            return int4_weight_only(group_size=group_size)
-        else:
-            if symmetric:
-                return _get_linear_subclass_inserter(apply_intN_weight_only_quant_sym)
-            else:
-                return _get_linear_subclass_inserter(apply_intN_weight_only_quant_asym)
-    except Exception:
-        raise
+    assert n in [8, 6, 5, 4, 3, 2], "n must be one of [8, 6, 5, 4, 3, 2]"
+    if n == 8:
+        raise AssertionError(
+            "Someone needs to refactor this code to handle int8_weight_only again"
+        )
+    elif n == 4:
+        raise AssertionError(
+            "Someone needs to refactor this code to handle int4_weight_only again"
+        )
+    else:
+        if symmetric:
+            new_weight = apply_intN_weight_only_quant_sym(weight)
+        else:
+            new_weight = apply_intN_weight_only_quant_asym(weight)
+    module.weight = torch.nn.Parameter(new_weight, requires_grad=False)
+    return module
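Taken together, the diff moves from "config function returns a transform" to "config dataclass plus a registered handler". Below is a hedged sketch of that dispatch pattern. The _HANDLERS dict, the quantize_ stub, and the filter function are illustrative assumptions, not torchao's actual internals; only the config class, the decorator name, and the handler name come from the diff above.

    from dataclasses import dataclass

    import torch

    # Hypothetical registry keyed by config type; torchao's real registry may differ.
    _HANDLERS = {}

    def register_quantize_module_handler(config_cls):
        def decorator(fn):
            _HANDLERS[config_cls] = fn
            return fn
        return decorator

    @dataclass
    class IntNWeightOnlyConfig:
        group_size: int = 32
        n: int = 8
        symmetric: bool = False

    @register_quantize_module_handler(IntNWeightOnlyConfig)
    def _intN_weight_only_transform(module, config):
        # The real handler quantizes the weight; this stub only freezes it,
        # mirroring the module-mutation shape of the new transform.
        module.weight = torch.nn.Parameter(module.weight.detach(), requires_grad=False)
        return module

    def quantize_(model, config, filter_fn=lambda m, name: isinstance(m, torch.nn.Linear)):
        # Dispatch on type(config): one registry entry per config class.
        handler = _HANDLERS[type(config)]
        for name, module in model.named_modules():
            if filter_fn(module, name):
                handler(module, config)

    model = torch.nn.Sequential(torch.nn.Linear(16, 16))
    quantize_(model, IntNWeightOnlyConfig(n=6, group_size=32, symmetric=True))
    assert model[0].weight.requires_grad is False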