
[Mixtral / Awq] Add mixtral fused modules for Awq #28240

Merged

Conversation

@younesbelkada (Contributor) commented Dec 25, 2023

What does this PR do?

Adds Mixtral + AWQ fused modules for blazing fast text generation!

from transformers import MixtralForCausalLM, AwqConfig, AutoTokenizer

model_path = "casperhansen/mixtral-instruct-awq"

# Enable AWQ fused modules; fuse_max_seq_len sets the maximum sequence length the fused kernels support.
quantization_config = AwqConfig(
    do_fuse=True,
    fuse_max_seq_len=1024,
)

model = MixtralForCausalLM.from_pretrained(model_path, quantization_config=quantization_config, device_map="auto")
tokenizer = AutoTokenizer.from_pretrained(model_path)

# The Mixtral tokenizer has no pad token by default; reuse the EOS token for batched padding.
tokenizer.pad_token = tokenizer.eos_token

inputs = ["Here are the top 10 useful Hindi phrases for your upcoming trip to India:\n1. ", "Hello my name is"]

inputs = tokenizer(inputs, return_tensors="pt", padding=True).to(0)
outputs = model.generate(**inputs, max_new_tokens=100, do_sample=False)
print(tokenizer.batch_decode(outputs, skip_special_tokens=True))

I introduced the same changes in modeling_utils as #28239 for a small issue where modules_to_not_convert was not handled correctly for fused modules.
Users need autoawq>=0.1.8 to use this feature.
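
For illustration only (this is not the actual check performed by the library), a minimal sketch of a version guard for that requirement, assuming autoawq is installed under the distribution name "autoawq":

# Illustrative sketch only: guard the Mixtral fused-module path on the installed autoawq version.
import importlib.metadata
from packaging import version

awq_version = version.parse(importlib.metadata.version("autoawq"))
if awq_version < version.parse("0.1.8"):
    raise ImportError(f"Mixtral fused modules require autoawq>=0.1.8, found {awq_version}")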

cc @casper-hansen

@@ -328,6 +335,8 @@ def _fuse_awq_attention_layers(model, module, modules_to_fuse, current_module_na
previous_device,
modules_to_fuse["max_seq_len"],
use_alibi=modules_to_fuse["use_alibi"],
# The default value in autoawq is set to 10000.0
rope_theta=modules_to_fuse.get("rope_theta", 10000.0),
Contributor Author

This specifically addresses: casper-hansen/AutoAWQ#251 (comment)

Collaborator

Good to have the option to configure. As a general note, matching the default of another library is brittle - it can be changed without us knowing.

Contributor Author

Yes correct, let's keep that in mind, cc @casper-hansen for visibility
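
For illustration, a minimal sketch of the alternative discussed above - reading rope_theta from the Transformers config rather than mirroring autoawq's internal default. This is not the merged implementation; it reuses the Mixtral checkpoint from the PR description and keeps 10000.0 only as a last-resort fallback:

from transformers import AutoConfig

# Sketch only: source rope_theta from the model config so a silent change to
# autoawq's default cannot affect the fused kernels.
config = AutoConfig.from_pretrained("casperhansen/mixtral-instruct-awq")
modules_to_fuse = {"max_seq_len": 1024, "use_alibi": False}
modules_to_fuse["rope_theta"] = getattr(config, "rope_theta", 10000.0)
print(modules_to_fuse["rope_theta"])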

@amyeroberts (Collaborator) left a comment

Thanks for adding this!

Just some questions and comments about the model-specific elements of this PR.


# In case a user passes a `AwqConfig` with `do_fuse=True` for models that have
# a `modules_to_not_convert` attribute we need to manually set that attribute into the
# passed `quantization_config`
elif (
Collaborator

It's not obvious how this change relates to Mixtral here - either from the AWQ fuse mapping or the test. If it's addressing a general bug, we should have a test to cover it.

Contributor Author

It is something that was addressed recently in #28239 - it covers a bug where AWQ does not handle fused modules + modules_to_not_convert properly. I think having the Mixtral and LLaVA tests (the LLaVA test is already there) should be sufficient, as they cover most of the use cases of modules_to_not_convert + fused modules. What do you think?
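
To make the behaviour concrete, here is a hedged sketch of the intent using plain dicts rather than the actual modeling_utils code: when the user passes an AwqConfig with do_fuse=True, the modules_to_not_convert recorded in the checkpoint's quantization config - for Mixtral, typically the MoE gate layers - has to be carried over so the fusing step skips those modules:

def merge_modules_to_not_convert(user_cfg: dict, checkpoint_cfg: dict) -> dict:
    # Sketch only: copy the checkpoint's exclusions into the user-provided config when fusing.
    merged = dict(user_cfg)
    if merged.get("do_fuse") and merged.get("modules_to_not_convert") is None:
        merged["modules_to_not_convert"] = checkpoint_cfg.get("modules_to_not_convert")
    return merged

print(merge_modules_to_not_convert(
    {"do_fuse": True, "fuse_max_seq_len": 1024, "modules_to_not_convert": None},
    {"quant_method": "awq", "modules_to_not_convert": ["gate"]},
))
# -> {'do_fuse': True, 'fuse_max_seq_len': 1024, 'modules_to_not_convert': ['gate']}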

Comment on lines +391 to +394
def test_generation_mixtral_fused(self):
    """
    Text generation test for Mixtral + AWQ + fused
    """
Collaborator

Do we need a model-specific test here? We don't want to have to add tests for every model we cover. It would be better to have tests which cover different functional properties, e.g. A, B, C. Then if any model uses A & C, we know it works.

Contributor Author

This test should be generalizable to all Mixtral models, as the main thing to verify is the interaction between modules_to_not_convert and fused modules for Mixtral!
I can also add a smaller test with a tiny model in addition to this one - if we know that the tiny model is correctly loaded, then other models should load correctly as well - wdyt? I would say this test is also good to have in general, as the underlying things it checks are (see the sketch after this list):

1- Correct conversion of Mixtral to Mixtral fused modules (with modules_to_not_convert being properly set)
2- Generation correctness for Mixtral + fused modules
3- Batched generation correctness for Mixtral + fused modules
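
A rough, hedged sketch of what those three checks could look like (this is not the test added in the PR; it assumes a CUDA GPU, autoawq>=0.1.8, and the checkpoint from the PR description, and it only checks that generation runs rather than asserting an exact expected output):

import unittest
from transformers import AutoTokenizer, AwqConfig, MixtralForCausalLM

class MixtralAwqFusedSketch(unittest.TestCase):
    model_id = "casperhansen/mixtral-instruct-awq"

    def test_fused_conversion_and_batched_generation(self):
        # Item 1: loading with do_fuse=True exercises the conversion to fused modules;
        # mishandling modules_to_not_convert would already surface here.
        quant_config = AwqConfig(do_fuse=True, fuse_max_seq_len=512)
        model = MixtralForCausalLM.from_pretrained(
            self.model_id, quantization_config=quant_config, device_map="auto"
        )
        tokenizer = AutoTokenizer.from_pretrained(self.model_id)
        tokenizer.pad_token = tokenizer.eos_token

        # Items 2 and 3: greedy generation on a batch of prompts runs end to end
        # and returns one decoded sequence per prompt.
        prompts = ["Hello my name is", "The capital of France is"]
        inputs = tokenizer(prompts, return_tensors="pt", padding=True).to(model.device)
        outputs = model.generate(**inputs, max_new_tokens=10, do_sample=False)
        decoded = tokenizer.batch_decode(outputs, skip_special_tokens=True)
        self.assertEqual(len(decoded), len(prompts))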

Collaborator

Yes please - let's add a more general test for a tiny model to make sure the code works generally: we don't want to overfit to specifics of mixtral but also want to make sure mixtral works.

Contributor Author

Perfect, will do!

Contributor Author

Done!

@younesbelkada (Contributor Author)

Thanks for your review @amyeroberts! I left a few comments and open questions, let me know wdyt! 🙏

@HuggingFaceDocBuilderDev

The docs for this PR live here. All of your documentation changes will be reflected on that endpoint. The docs are available until 30 days after the last update.

@amyeroberts (Collaborator) left a comment

Thanks for the explanations and iterating on this!

I would like to see a general test for a tiny model to be added. Happy for you to merge once that's committed :)


@younesbelkada (Contributor Author)

Thanks @amyeroberts for all your reviews! I just added the more general test with a tiny model! I will merge the PR and address potential comments in a follow-up PR! 🙏
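
For context, a minimal sketch of what such a general tiny-model smoke test could look like; the checkpoint name below is a placeholder, not the identifier used in the actual test suite, and the assertions are illustrative only:

import unittest
from transformers import AutoModelForCausalLM, AutoTokenizer, AwqConfig

TINY_AWQ_MODEL_ID = "<placeholder-tiny-awq-checkpoint>"  # hypothetical; the real suite defines its own tiny model

class TinyAwqFusedSmokeSketch(unittest.TestCase):
    def test_fused_load_and_generate(self):
        # Fusing on a tiny model verifies the generic code path without needing Mixtral-sized weights.
        quant_config = AwqConfig(do_fuse=True, fuse_max_seq_len=128)
        model = AutoModelForCausalLM.from_pretrained(
            TINY_AWQ_MODEL_ID, quantization_config=quant_config, device_map="auto"
        )
        tokenizer = AutoTokenizer.from_pretrained(TINY_AWQ_MODEL_ID)
        inputs = tokenizer("Hello", return_tensors="pt").to(model.device)
        outputs = model.generate(**inputs, max_new_tokens=5)
        self.assertEqual(outputs.shape[0], 1)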

@younesbelkada younesbelkada merged commit 266c67b into huggingface:main Jan 12, 2024
@younesbelkada younesbelkada deleted the add-mixtral-fused-modules branch January 12, 2024 13:29