migrate prototype/quantized_training to configs #1855

Open · wants to merge 35 commits into main

test/prototype/test_quantized_training.py (2 additions & 1 deletion)
@@ -193,7 +193,8 @@ def test_int8_mixed_precision_training(self, compile, config, module_swap):

         linear = nn.Linear(embed_dim, embed_dim, device=device)
         linear_int8mp = copy.deepcopy(linear)
-        apply_func = int8_mixed_precision_training(config, module_swap=module_swap)
+        config.module_swap = module_swap
+        apply_func = int8_mixed_precision_training(config)
         quantize_(linear_int8mp, apply_func, set_inductor_config=False)

         if compile:

torchao/prototype/quantized_training/bitnet.py (20 additions & 5 deletions)
@@ -12,7 +12,10 @@
 from torch.distributed._tensor import DTensor
 from torch.utils._triton import has_triton

-from torchao.quantization.quant_api import _get_linear_subclass_inserter
+from torchao.core.config import AOBaseConfig
+from torchao.quantization.transform_module import (
+    register_quantize_module_handler,
+)
 from torchao.utils import TorchAOBaseTensor

 from .int8 import quantize_int8_rowwise
@@ -232,10 +235,22 @@ def backward(ctx, grad_output):
         return grad_input, grad_weight, grad_bias


-def bitnet_training():
-    return _get_linear_subclass_inserter(
-        BitNetTrainingLinearWeight, allow_requires_grad=True
-    )
+class BitNetTrainingConfig(AOBaseConfig):
+    pass
+
+
+# for bc
+bitnet_training = BitNetTrainingConfig

Review comment (Contributor): ditto on bc break here

Reply (Contributor Author): I'd actually prefer to keep it, see previous PR for some additional context



+@register_quantize_module_handler(BitNetTrainingConfig)
+def _bitnet_training_transform(
+    module: torch.nn.Module,
+    config: BitNetTrainingConfig,
+) -> torch.nn.Module:
+    new_weight = BitNetTrainingLinearWeight(module.weight)
+    module.weight = torch.nn.Parameter(new_weight, requires_grad=True)
+    return module


 def _pack_i2_in_i8(x: Tensor):
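
For context on the bc discussion above: after this change, quantize_(model, config) dispatches on the config type to the registered _bitnet_training_transform, and the bitnet_training alias keeps old call sites working because bitnet_training() now simply constructs a BitNetTrainingConfig. A minimal usage sketch, assuming the config is re-exported from torchao.prototype.quantized_training (the toy model is illustrative, not part of the PR):

import torch.nn as nn

from torchao.prototype.quantized_training import BitNetTrainingConfig
from torchao.quantization import quantize_

# Toy two-layer model; any module tree containing nn.Linear works the same way.
model = nn.Sequential(nn.Linear(64, 64), nn.ReLU(), nn.Linear(64, 64))

# New style: pass a config instance; the registered handler replaces each
# linear weight with a trainable BitNetTrainingLinearWeight parameter.
quantize_(model, BitNetTrainingConfig())

# Old style still works through the alias, since calling the class just
# constructs the same config:
#   quantize_(model, bitnet_training())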

torchao/prototype/quantized_training/int8.py (20 additions & 5 deletions)
@@ -4,7 +4,10 @@
 from torch import Tensor
 from torch.utils._python_dispatch import return_and_correct_aliasing

-from torchao.quantization.quant_api import _get_linear_subclass_inserter
+from torchao.core.config import AOBaseConfig
+from torchao.quantization.transform_module import (
+    register_quantize_module_handler,
+)
 from torchao.utils import TorchAOBaseTensor

 aten = torch.ops.aten
@@ -293,7 +296,19 @@ def _(func, types, args, kwargs):
     return return_and_correct_aliasing(func, args, kwargs, out)


-def int8_weight_only_quantized_training():
-    return _get_linear_subclass_inserter(
-        Int8QuantizedTrainingLinearWeight.from_float, allow_requires_grad=True
-    )
+class Int8WeightOnlyQuantizedTrainingConfig(AOBaseConfig):
+    pass
+
+
+# for bc
+int8_weight_only_quantized_training = Int8WeightOnlyQuantizedTrainingConfig
+
+
+@register_quantize_module_handler(Int8WeightOnlyQuantizedTrainingConfig)
+def _int8_weight_only_quantized_training_transform(
+    module: torch.nn.Module,
+    config: Int8WeightOnlyQuantizedTrainingConfig,
+) -> torch.nn.Module:
+    new_weight = Int8QuantizedTrainingLinearWeight.from_float(module.weight)
+    module.weight = torch.nn.Parameter(new_weight, requires_grad=True)
+    return module
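
The int8 weight-only path follows the exact same config-plus-handler pattern. A hedged sketch of the resulting API, under the same import assumptions as the bitnet example above:

import torch.nn as nn

from torchao.prototype.quantized_training import (
    Int8WeightOnlyQuantizedTrainingConfig,
)
from torchao.quantization import quantize_

linear = nn.Linear(128, 128)
quantize_(linear, Int8WeightOnlyQuantizedTrainingConfig(), set_inductor_config=False)

# The handler rewired the weight in place: it is now an
# Int8QuantizedTrainingLinearWeight wrapped in a trainable nn.Parameter,
# so autograd and optimizers keep working on the quantized representation.
assert linear.weight.requires_grad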

torchao/prototype/quantized_training/int8_mixed_precision.py (26 additions & 18 deletions)
@@ -1,11 +1,15 @@
-from typing import Any, NamedTuple, Optional, Tuple, Union
+from dataclasses import dataclass
+from typing import Any, Optional, Tuple, Union

 import torch
 import torch.utils._pytree as pytree
 from torch import Tensor, nn
 from torch.utils._triton import has_triton

-from torchao.quantization.quant_api import _get_linear_subclass_inserter
+from torchao.core.config import AOBaseConfig
+from torchao.quantization.transform_module import (
+    register_quantize_module_handler,
+)
 from torchao.utils import TorchAOBaseTensor

 from .int8 import quantize_int8_rowwise
@@ -23,10 +27,16 @@ def scaled_int8_mm(
     return torch._int_mm(A, B) * col_scale.view(-1) * row_scale.view(-1, 1)


-class Int8MixedPrecisionTrainingConfig(NamedTuple):
+@dataclass
+class Int8MixedPrecisionTrainingConfig(AOBaseConfig):
     output: bool = True
     grad_input: bool = True
     grad_weight: bool = True
+    module_swap: bool = False
+
+
+# for bc
+int8_mixed_precision_training = Int8MixedPrecisionTrainingConfig


 _DEFAULT_CONFIG = Int8MixedPrecisionTrainingConfig()
@@ -265,25 +275,23 @@ def backward(ctx, grad_output):
         return grad_input, grad_weight, grad_bias, None


-def int8_mixed_precision_training(
-    config: Int8MixedPrecisionTrainingConfig = _DEFAULT_CONFIG,
-    *,
-    module_swap: bool = False,
+@register_quantize_module_handler(Int8MixedPrecisionTrainingConfig)
+def _int8_mixed_precision_training_transform(
+    module: torch.nn.Module,
+    config: Int8MixedPrecisionTrainingConfig,
 ):
+    module_swap = config.module_swap
+
     # TODO: skip small layers that don't have perf gain.
     if module_swap:
         # module swap implementation
-        def convert_linear(linear: nn.Linear):
-            linear.__class__ = Int8MixedPrecisionTrainingLinear
-            linear.config = config
-            return linear
-
-        return convert_linear
+        module.__class__ = Int8MixedPrecisionTrainingLinear
+        module.config = config
+        return module

     else:
         # tensor subclass implementation
-        return _get_linear_subclass_inserter(
-            Int8MixedPrecisionTrainingLinearWeight,
-            config=config,
-            allow_requires_grad=True,
-        )
+        new_weight = Int8MixedPrecisionTrainingLinearWeight(module.weight, config)
+        module.weight = torch.nn.Parameter(new_weight, requires_grad=True)
+        return module
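
With module_swap now carried on the config rather than passed as a keyword, both implementations are selected purely by the config you construct. A usage sketch under the same import assumptions as the examples above (layer sizes are arbitrary):

import torch.nn as nn

from torchao.prototype.quantized_training import Int8MixedPrecisionTrainingConfig
from torchao.quantization import quantize_

model = nn.Sequential(nn.Linear(256, 256), nn.ReLU(), nn.Linear(256, 256))

# Tensor-subclass path (default): each weight becomes an
# Int8MixedPrecisionTrainingLinearWeight that carries the config.
quantize_(model, Int8MixedPrecisionTrainingConfig(grad_weight=False))

# Module-swap path: the handler instead rebinds each layer's class to
# Int8MixedPrecisionTrainingLinear and attaches the config:
#   quantize_(model, Int8MixedPrecisionTrainingConfig(module_swap=True))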