Migrate to config for Int8DynamicActivationIntxWeightConfig #1836
Conversation
🔗 Helpful Links: 🧪 See artifacts and rendered test results at hud.pytorch.org/pr/pytorch/ao/1836
Note: Links to docs will display an error until the docs builds have been completed.
❌ 1 New Failure, 11 Pending as of commit 3f51d82 with merge base ada4c02. NEW FAILURE: the following job has failed:
This comment was automatically generated by Dr. CI and updates every 15 minutes.
@drisspg @jerryzh168 are we OK adding tensor_impl_ctr_kwargs to from_hp_to_intx? It can be used to propagate a bias when constructing the weight tensor subclass via from_plain.
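A minimal sketch of the mechanism under discussion, assuming a heavily simplified signature (the real from_hp_to_intx takes many more parameters; the _sketch name is illustrative only):

```python
from typing import Optional

import torch


# Minimal sketch: extra kwargs supplied by the caller are forwarded verbatim to
# the tensor-impl constructor, so a layout that packs bias together with the
# quantized weights can receive it through its from_plain-style constructor.
def from_hp_to_intx_sketch(
    data: torch.Tensor,
    scale: torch.Tensor,
    zero_point: torch.Tensor,
    _layout,
    tensor_impl_ctr,
    tensor_impl_ctr_kwargs: Optional[dict] = None,
):
    # A default of None keeps every existing call site unchanged (no BC impact).
    if tensor_impl_ctr_kwargs is None:
        tensor_impl_ctr_kwargs = {}
    return tensor_impl_ctr(data, scale, zero_point, _layout, **tensor_impl_ctr_kwargs)


# Caller side, e.g.:
#   from_hp_to_intx_sketch(data, scale, zp, layout, ctr, tensor_impl_ctr_kwargs={"bias": bias})
```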
not super familiar with this code, but as long as this doesn't change the BC surface sgtm
if this is controversial, can we separate this from the config migration? I'd love to see that piece land asap.
The new tensor_impl_ctr_kwargs has a default value of None, so it shouldn't change how any existing call sites work. The CI also passes.
sgtm, let's land if no other concerns?
OK, I can land after CI passes.
I don't have concerns about the changes in torchao/experimental/*. I mostly wanted feedback from someone in torchao on the change to torchao/dtypes/affine_quantized_tensor.py.
I mean, tensor_impl_ctr_kwargs: Optional[dict] = None is pretty hard to follow. IMO it would be better to refactor the code to just pass bias directly instead of adding a layer of indirection.
However, this is nitty, not a part of BC, and I want to see the config part land, so how about we chat about ^ in parallel and if it needs fixing someone can do that in a future PR?
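For clarity, a hypothetical side-by-side of the two signatures being debated; quantize_weight_option_a/b are made-up names, not torchao APIs, and only the extra parameter is meant to differ:

```python
from typing import Optional

import torch


# Option A (this PR): a generic dict that is forwarded to the tensor-impl
# constructor; the caller writes tensor_impl_ctr_kwargs={"bias": bias}, so the
# intent is hidden behind a layer of indirection.
def quantize_weight_option_a(
    input_float: torch.Tensor,
    tensor_impl_ctr_kwargs: Optional[dict] = None,
):
    """Stub; only the signature matters for this comparison."""


# Option B (suggested above): pass bias explicitly, so the extra argument and
# its purpose are visible in the signature itself.
def quantize_weight_option_b(
    input_float: torch.Tensor,
    bias: Optional[torch.Tensor] = None,
):
    """Stub; only the signature matters for this comparison."""
```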
The reason for using tensor_impl_ctr_kwargs instead of "bias" is that I thought it would be more extensible in the future.
Currently, if you use to_affine_quantized_intx to do quantization, there is no way to forward other args to your tensor subclass's from_plain(data, scale, zero_point, _layout) method. Here we want to forward bias, but in the future someone might want to forward something else.
What is currently done is that the code in torchao/experimental has its own copy of to_affine_quantized_intx that supports bias. The downside of this is that I fear the two might drift apart going forward.
With all that said, I can refactor this PR to only contain the config change and put up the tensor_impl_ctr_kwargs change in another PR.
I don't get it, why can't we just add bias as an argument and pass it?
Let's move the discussion to a future PR. I split out the change.
e332b54 to 4b3a742 (Compare)
f138c3d to fc46e34 (Compare)
Mostly nits
Not terribly familiar with this code, but passes the gut test
```python
if tensor_impl_ctr_kwargs is None:
    tensor_impl_ctr_kwargs = {}
tensor_impl = tensor_impl_ctr(
    data, scale, zero_point, _layout, **tensor_impl_ctr_kwargs
)
```
Don't know which style AO uses, no strong pref
Suggested change:
```python
tensor_impl = tensor_impl_ctr(
    data, scale, zero_point, _layout, **(tensor_impl_ctr_kwargs or {})
)
```
I'd like to hear from @drisspg or someone from torchao on this change.
Not so much on the style preference, but more so on whether they're OK adding tensor_impl_ctr_kwargs to the to_affine_quantized_intx signature.
```python
quantized_model_reference = copy.deepcopy(model)
quantize_(
    quantized_model_reference,
    int8_dynamic_activation_intx_weight(
        weight_dtype=weight_dtype,
        granularity=granularity,
        has_weight_zeros=has_weight_zeros,
        layout=reference_layout,
    ),
)

with torch.no_grad():
    result = quantized_model(activations)
    expected_result = quantized_model_reference(activations)
```
nit: We can factor out the creation of expected_results since it's just PlainLayout in both cases (different models)
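One possible shape for that refactor, sketched with a hypothetical helper name (make_reference_model); the torchao import paths are assumptions based on the files this PR touches:

```python
import copy

# quantize_ and int8_dynamic_activation_intx_weight are the same callables
# already used in the test above; the import paths are assumed.
from torchao.experimental.quant_api import int8_dynamic_activation_intx_weight
from torchao.quantization.quant_api import quantize_


def make_reference_model(
    model, weight_dtype, granularity, has_weight_zeros, reference_layout
):
    """Build the reference-quantized copy once; its construction is identical across cases."""
    reference = copy.deepcopy(model)
    quantize_(
        reference,
        int8_dynamic_activation_intx_weight(
            weight_dtype=weight_dtype,
            granularity=granularity,
            has_weight_zeros=has_weight_zeros,
            layout=reference_layout,
        ),
    )
    return reference
```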
```python
    and layout.target == Target.ATEN
)
weight_dtype: torch.dtype = torch.int4
granularity: Union[PerRow, PerGroup] = PerRow()
```
Why not granularity: Union[PerRow, PerGroup] = PerGroup(128), like int8_dynamic_activation_intx_weight?
PerRow is a safer default because it doesn't depend on the input data size. I expect users to always specify this parameter.
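To make that trade-off concrete, a short sketch (the import path for the granularity classes is an assumption and may differ across torchao versions):

```python
# Import path assumed; adjust to your torchao version.
from torchao.quantization.granularity import PerGroup, PerRow

# PerRow puts one quantization group per output row, so it works for any weight
# shape. PerGroup(128) additionally requires the weight's reduction dimension to
# be divisible by the group size, which depends on the model being quantized.
channelwise_granularity = PerRow()
groupwise_granularity = PerGroup(128)
```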
torchao/experimental/quant_api.py (outdated)
```python
)


@register_quantize_module_handler(Int8DynamicActivationIntxWeightConfig)
def _int8_dynamic_activation_intx_weigh_transform(
```
Suggested change:
```python
def _int8_dynamic_activation_intx_weight_transform(
```
torchao/experimental/quant_api.py (outdated)
```python
tensor_impl_ctr_kwargs = None
if isinstance(layout, PackedLinearInt8DynamicActivationIntxWeightLayout):
    # We need to create a new layout object for each module because when
    # granulairty is PerRow, the layout objects cannot share the group_size
```
Suggested change:
```python
    # granularity is PerRow, the layout objects cannot share the group_size
```
```python
if weight_tensor.tensor_impl.get_layout().has_bias:
    assert (
        bias is None
    ), "bias should be None because it is already packed with the weights (has_bias=True)"
```
nit: collapse the if + assert into a single assert; also fine with leaving it as-is for legibility.
Suggested change:
```python
assert (
    not weight_tensor.tensor_impl.get_layout().has_bias or bias is None
), "bias should be None because it is already packed with the weights (has_bias=True)"
```
```python
if torch.backends.kleidiai.is_available():
    if isinstance(granularity, PerGroup):
        scale_dtype = (
            torch.bfloat16
        )  # KleidiAI kernel requires bfloat16 scale_dtype
```
Seems like we always use float32 in to_affine_quantized_intx. Is this intentional?
KleidiAI tests pass with this. The scale_dtype here was only used by the Python-based quantization that computes the qvals, scales, and zeros, not by what was passed to the kernel itself.
The ATen KleidiAI groupwise kernel requires scale_dtype to be torch.bfloat16; otherwise it falls back to the reference implementation. Also, the input to the ATen kernel needs to be float32.
I will update this then.
Is it only for groupwise that it falls back to the reference kernel? For channelwise, is FP32 fine, or should it still be bfloat16?
bfloat16 only for groupwise. float32 for channelwise
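A tiny sketch capturing that rule; kleidiai_scale_dtype is a hypothetical helper name, not a torchao function:

```python
import torch


# Hypothetical helper restating the rule above: bfloat16 scales for the ATen
# KleidiAI groupwise kernel (otherwise it falls back to the reference
# implementation), float32 scales for the channelwise kernel.
def kleidiai_scale_dtype(is_groupwise: bool) -> torch.dtype:
    return torch.bfloat16 if is_groupwise else torch.float32
```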
didn't read the code in detail, but it would be great to migrate this to config soon so we can disable the old path
please feel free to wait for a proper review if needed
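For reference, a sketch of how the config-based path could be invoked once it lands, assuming the config exposes the weight_dtype and granularity fields shown in the diff and that its other fields have usable defaults; the import paths are also assumptions:

```python
import torch

# Import paths are assumptions: the config class is defined in
# torchao/experimental/quant_api.py in the diff above, and quantize_ dispatches
# on the config type via the registered handler shown there.
from torchao.experimental.quant_api import Int8DynamicActivationIntxWeightConfig
from torchao.quantization.granularity import PerGroup
from torchao.quantization.quant_api import quantize_

model = torch.nn.Sequential(torch.nn.Linear(256, 256))

# Field names are taken from the diff; other config fields are left at their
# (assumed) defaults. in_features=256 is divisible by the group size of 128.
quantize_(
    model,
    Int8DynamicActivationIntxWeightConfig(
        weight_dtype=torch.int4,
        granularity=PerGroup(128),
    ),
)
```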
This PR: