migration of quantize_ workflow configuration from callables to configs #1690

Open
vkuzo opened this issue Feb 10, 2025 · 0 comments
vkuzo commented Feb 10, 2025

summary

This issue tracks the migration of quantize_ per-workflow configuration from callables to configs.

We are migrating the way quantize_ workflows are configured from callables (tensor subclass inserters) to direct configuration (config objects). Motivation: align with the rest of the ecosystem, enable inspection of configs after instantiation, remove a common source of confusion.

What is changing:

Specifically, here is how the signature of quantize_'s second argument will change:

#
# torchao v0.8.0 and before
#
def quantize_(
    model: torch.nn.Module,
    apply_tensor_subclass: Callable[[torch.nn.Module], torch.nn.Module],
    ...,
): ...

#
# torchao v0.9.0
#
def quantize_(
    model: torch.nn.Module,
    config: Union[AOBaseConfig, Callable[[torch.nn.Module], torch.nn.Module]],
    ...,
): ...

#
# torchao v0.10.0 or later (exact version TBD)
#
def quantize_(
    model: torch.nn.Module,
    config: AOBaseConfig,
    ...,
): ...
  1. The name of the second argument to quantize_ changes from apply_tensor_subclass to config. Since the vast majority of callsites today pass the configuration as a positional argument, this change should not affect most people.
  2. The type of the second argument to quantize_ will change from Callable[[torch.nn.Module], torch.nn.Module] to AOBaseConfig, following the deprecation process detailed below.
  3. For individual workflows, the user facing API name changes from snake case (int8_weight_only) to CamelCase (Int8WeightOnlyConfig). All argument names for each config are kept as-is. We will keep the old snake case names (int8_weight_only) around and alias them to the new names (int8_weight_only = Int8WeightOnlyConfig), to avoid breaking callsites. We plan to keep the old names forever. Here are all the workflow config name changes:
| old name (will keep working) | new name (recommended) |
| --- | --- |
| int4_weight_only | Int4WeightOnlyConfig |
| float8_dynamic_activation_float8_weight | Float8DynamicActivationFloat8WeightConfig |
| float8_static_activation_float8_weight | Float8StaticActivationFloat8WeightConfig |
| float8_weight_only | Float8WeightOnlyConfig |
| fpx_weight_only | FPXWeightOnlyConfig |
| gemlite_uintx_weight_only | GemliteUIntXWeightOnlyConfig |
| int4_dynamic_activation_int4_weight | Int4DynamicActivationInt4WeightConfig |
| int8_dynamic_activation_int4_weight | Int8DynamicActivationInt4WeightConfig |
| int8_dynamic_activation_int8_semi_sparse_weight | n/a (deprecated) |
| int8_dynamic_activation_int8_weight | Int8DynamicActivationInt8WeightConfig |
| int8_weight_only | Int8WeightOnlyConfig |
| uintx_weight_only | UIntXWeightOnlyConfig |

Configuration for prototype workflows using quantize_ will be migrated at a later time. sparsify_ will be migrated in a similar fashion at a later time.
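
To make the aliasing concrete, here is a minimal sketch of how an old snake case name can point at the new config class. The classes below are simplified stand-ins written for this example (they are not imported from torchao, and the real config has more to it); the only thing taken from this issue is the pattern `int8_weight_only = Int8WeightOnlyConfig`.

```python
import abc
from dataclasses import dataclass
from typing import Optional


# Stand-in for torchao's config base class (assumption: defined locally for the sketch).
class AOBaseConfig(abc.ABC):
    pass


# Stand-in for the real config; only the argument used in this issue is modeled.
@dataclass
class Int8WeightOnlyConfig(AOBaseConfig):
    group_size: Optional[int] = None


# The old snake case name is bound to the very same class object.
int8_weight_only = Int8WeightOnlyConfig

# Both spellings build the same kind of object with the same argument names.
old_style = int8_weight_only(group_size=128)
new_style = Int8WeightOnlyConfig(group_size=128)
```

Because the alias is just a second name for the class, existing callsites that use the old name keep constructing a valid config without any changes.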

How these changes can affect you:

  1. If you are a user of existing quantize_ API workflows and are passing in config by a positional argument (quantize_(model, int8_weight_only(group_size=128))), you are not affected. This syntax will keep working going forward. You have the option to migrate your callsite to the new config name (quantize_(model, Int8WeightOnlyConfig(group_size=128))) at your own pace.
  2. If you are a user of existing quantize_ API workflows and are passing in config by a keyword argument (quantize_(model, apply_tensor_subclass=int8_weight_only(group_size=128))), your callsite will break. You will need to change your callsite to quantize_(model, config=int8_weight_only(group_size=128)). We don't expect many people to be in this bucket.
  3. If you are a developer writing new workflows for the quantize_ API, you will need to use the new configuration system. Please see migration of quantize_ workflow configuration from callables to configs #1690 for details.
  4. If you are a user of sparsify_, you are not affected for now and a similar change will happen in a future version of torchao.
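
The difference between the positional and keyword cases can be illustrated with a small sketch. The `quantize_` below is a stand-in defined only for this example (it simply returns its config), and `model` is a placeholder object rather than a real `torch.nn.Module`; the old keyword name used is `apply_tensor_subclass`, taken from the v0.8.0 signature shown earlier in this issue.

```python
from dataclasses import dataclass
from typing import Optional


# Simplified stand-in for the real config class (assumption, not torchao code).
@dataclass
class Int8WeightOnlyConfig:
    group_size: Optional[int] = None


int8_weight_only = Int8WeightOnlyConfig


def quantize_(model, config):
    """Stand-in with the renamed second parameter; returns the config for inspection."""
    return config


model = object()  # placeholder for a torch.nn.Module

# Positional callsites do not care about the parameter name, so they keep working.
positional = quantize_(model, int8_weight_only(group_size=128))

# Keyword callsites must use the new parameter name.
keyword = quantize_(model, config=int8_weight_only(group_size=128))

# A callsite that still uses the old keyword name fails with a TypeError.
try:
    quantize_(model, apply_tensor_subclass=int8_weight_only(group_size=128))
    old_keyword_error = None
except TypeError as exc:
    old_keyword_error = exc
```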

This migration will be a two step process:

  • In torchao v0.9.0, we will enable the new syntax while starting the deprecation process for the old syntax.
  • In torchao v0.10.0 or later, we will remove the old syntax.

We will keep the old callable syntax supported by quantize_ for one release cycle, and delete it afterwards. We will keep the old names as aliases for new names going forward (example: int4_weight_only as an alias of Int4WeightOnlyConfig) to keep existing callsites working without changes.
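
As a rough sketch of what the intermediate state could look like, the helper below accepts either a config or a legacy callable and flags the callable path as deprecated. This is an illustration under assumptions, not torchao's implementation: `describe_second_argument` is a hypothetical name, and whether the real `quantize_` emits a `DeprecationWarning` (or any warning at all) is not stated in this issue.

```python
import abc
import warnings
from dataclasses import dataclass
from typing import Callable, Union


# Stand-ins defined locally for the sketch (not imported from torchao).
class AOBaseConfig(abc.ABC):
    pass


@dataclass
class Int4WeightOnlyConfig(AOBaseConfig):
    group_size: int = 128


def describe_second_argument(config: Union[AOBaseConfig, Callable]) -> str:
    """Hypothetical helper: classify quantize_'s second argument during the transition."""
    # Check for a config first: this is the path that remains after the old one is deleted.
    if isinstance(config, AOBaseConfig):
        return "config"
    # Legacy path: a callable that transforms a module, supported for one release.
    if callable(config):
        warnings.warn(
            "passing a callable is deprecated; pass an AOBaseConfig instead",
            DeprecationWarning,
            stacklevel=2,
        )
        return "callable"
    raise TypeError(f"unsupported second argument: {type(config).__name__}")
```

Once the old path is deleted, the `callable` branch would go away and only `AOBaseConfig` instances would be accepted.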

impact on API users

If you are just using the torchao quantize_ API as specified in the README, this is not BC-breaking. For example, the following syntax will keep working.

quantize_(model, int8_weight_only())

Note that the type of the object created by int8_weight_only() will change from a Callable to a config. You have the option to migrate to the explicit config creation, as follows:

quantize_(model, Int8WeightOnlyConfig())
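
One practical consequence of the type change is that the configuration can be inspected after it is created, which was one of the stated motivations. The sketch below contrasts a closure-style callable with a dataclass config; both are simplified stand-ins written for this example (`int8_weight_only_closure` is a hypothetical name), and only standard `dataclasses` functions are used.

```python
import dataclasses
from typing import Optional


# Old style (stand-in): the setting is captured in a closure and hidden from the caller.
def int8_weight_only_closure(group_size: Optional[int] = None):
    def insert(module):
        # group_size is only reachable from inside this function
        return module

    return insert


# New style (stand-in): the setting is a plain dataclass field.
@dataclasses.dataclass
class Int8WeightOnlyConfig:
    group_size: Optional[int] = None


opaque = int8_weight_only_closure(group_size=128)
config = Int8WeightOnlyConfig(group_size=128)

# The config exposes its settings as attributes and through the dataclasses helpers.
as_dict = dataclasses.asdict(config)
field_names = [f.name for f in dataclasses.fields(config)]

# The closure has no such attribute to read back.
closure_has_attribute = hasattr(opaque, "group_size")
```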

user facing API changes

signature of quantize_

#
# before
#
def quantize_(
    model: torch.nn.Module,
    apply_tensor_subclass: Callable[[torch.nn.Module], torch.nn.Module],
    ...,
): ...

#
# after - intermediate state, support both old and new for one release
#
def quantize_(
    model: torch.nn.Module,
    config: Union[AOBaseConfig, Callable[[torch.nn.Module], torch.nn.Module]],
    ...,
): ...

#
# after - long term state
#
def quantize_(
    model: torch.nn.Module,
    config: AOBaseConfig,
    ...,
): ...

usage example

An example for int4_weight_only

#
# before
#
quantize_(m, int4_weight_only(group_size=32))

#
# after, with new user facing names
#
quantize_(m, Int4WeightOnlyConfig(group_size=32))

#
# AND, after, with BC names
#
quantize_(m, int4_weight_only(group_size=32))

developer facing changes

See the PR details for examples, but they can be summarized as:

#
# old
#

# quantize_ applies the callable returned by this function to each module of the model
def int4_weight_only(group_size: int, ...) -> Callable:

    def new_callable(weight: torch.Tensor):
        # configuration is captured here via local variables
        ...
        
    # return type is a Callable
    return _get_linear_subclass_inserter(new_callable)

#
# new
#

# config base class
class AOBaseConfig(abc.ABC):
    pass

# user facing configuration of a workflow
@dataclass
class Int4WeightOnlyConfig(AOBaseConfig):
    group_size: int = 128
    ...

# not user facing transform of a module according to a workflow's configuration
@register_quantize_module_handler(Int4WeightOnlyConfig)
def _int4_weight_only_transform(
    module: torch.nn.Module, 
    config: Int4WeightOnlyConfig,
) -> torch.nn.Module:
    # map to AQT, not user facing
    ...

migration status

quantize_ non-prototype workflow configuration

quantize_ prototype workflow configuration

Grep for callsites:

grep -r "quantize_(" torchao/prototype

experimental

sparsify_

tutorials (replace with new registration API)

replace docblocks and public facing descriptions with new names

verify partner integrations still work

confirmed two out of three here: vkuzo/pytorch_scripts#28

delete old path (one version after migration)

@vkuzo vkuzo self-assigned this Feb 10, 2025
vkuzo added a commit to vkuzo/pytorch_scripts that referenced this issue Feb 12, 2025
Summary:

Testing for pytorch/ao#1690

Convenient to have this here to test on torchao main vs torchao
experiment

vkuzo added a commit to vkuzo/pytorch_scripts that referenced this issue Feb 13, 2025
@vkuzo vkuzo changed the title placeholder for migrating workflow configuration to AOBaseConfig migration of quantize_ workflow configuration from callables to configs Feb 13, 2025
vkuzo added a commit that referenced this issue Feb 26, 2025
Summary:

Thanks to investigation from @eellison, moving the reshape
to the end of the cast helps inductor fuse the cast into a single
kernel.  This doesn't yet work with fp4, but let's unblock fp8 and deal
with fp4 later.

Fixes #1690

Note: in the repro with swizzling from
#1773, we go from 3 to 2 kernels.
Further investigation is needed whether we can fuse the swizzling.

Test Plan:

```
pytest test/prototype/mx_formats/test_mx_tensor.py -x -s -k test_to_mx_inductor_single_kernel
```

@vkuzo vkuzo closed this as completed in 8d110bf Feb 26, 2025
@vkuzo vkuzo reopened this Feb 26, 2025