[Single File] Add GGUF support #9964
Conversation
@@ -204,7 +204,10 @@ def create_quantized_param(

        module._parameters[tensor_name] = new_value

-    def check_quantized_param_shape(self, param_name, current_param_shape, loaded_param_shape):
+    def check_quantized_param_shape(self, param_name, current_param, loaded_param):
GGUF needs to access the tensor quant type to run a shape check. So this needs to change from passing in shapes to passing in params directly.
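For illustration, a rough sketch of why the params (and not just their shapes) are needed: the quant type determines how the packed byte shape maps back to the logical shape. This uses the GGML_QUANT_SIZES table from gguf and the _quant_shape_from_byte_shape helper referenced later in this PR; the exact argument order is assumed, not confirmed.

```python
from gguf import GGML_QUANT_SIZES  # maps quant type -> (block_size, type_size)

def check_quantized_param_shape(self, param_name, current_param, loaded_param):
    # quantized GGUF tensors store packed bytes, so translate the byte shape
    # back into the logical shape before comparing (sketch, not the exact implementation)
    block_size, type_size = GGML_QUANT_SIZES[loaded_param.quant_type]
    inferred_shape = _quant_shape_from_byte_shape(loaded_param.shape, type_size, block_size)
    if inferred_shape != current_param.shape:
        raise ValueError(
            f"{param_name} has an expected quantized shape of {inferred_shape}, "
            f"but received {loaded_param.shape}."
        )
    return True
```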
Why not add this method to the gguf_quantizer.py file instead of modifying this? This would be a breaking change, no?
I see you're already adding this to the GGUF quantizer class. So, maybe okay to not modify this?
Definitely makes sense to ensure this method has the same signature across all quantizers; it will be confusing otherwise.
In terms of breaking change, I think it is ok, but we can deprecate it if we want to be extra cautious.
I think no deprecation is fine since this method is called from load_model_dict_into_meta(). But let's make sure to run the tests to ensure nothing's breaking.
The docs for this PR live here. All of your documentation changes will be reflected on that endpoint. The docs are available until 30 days after the last update.
oh thanks!
I left some comments but overall looks good to me! (I didn't look into all these methods in utils.py, but I can go over them in detail if you need me to 😛)
    from ..quantizers.gguf.utils import GGUFParameter
else:
    logger.error(
        "Loading a GGUF checkpoint in PyTorch, requires both PyTorch and GGUF>=0.10.0 to be installed. Please see "
do we need to check gguf version as well? (in addition to is_gguf_available)
Agree. Let's always suggest installing the latest stable build of gguf, like we do for bitsandbytes:
if not is_bitsandbytes_available() or is_bitsandbytes_version("<", "0.43.3"):
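A hedged sketch of the analogous check for gguf; the helper name mirrors is_bitsandbytes_version and is an assumption, not the final code:

```python
if not is_gguf_available() or is_gguf_version("<", "0.10.0"):
    raise ImportError(
        "Loading GGUF checkpoints requires a recent gguf release (>=0.10.0): `pip install -U gguf`"
    )
```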
weights = torch.from_numpy(tensor.data.copy())
parsed_parameters[name] = GGUFParameter(weights, quant_type=quant_type) if is_gguf_quant else weights

if len(reader_keys) > 0:
Trying to understand this check here: if we also removed the names from reader_keys as we iterate through the tensors, this check would make sense, but I didn't see any code that removes anything. So maybe we forgot to remove them? Or is that not the case? Did I miss something?
Copied this from transformers. But these aren't tensor keys; they're metadata keys:
['GGUF.version', 'GGUF.tensor_count', 'GGUF.kv_count', 'general.architecture', 'general.quantization_version', 'general.file_type']
This can probably just be removed since the info isn't too relevant.
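For reference, those keys come from the reader's metadata fields rather than its tensor list; roughly like this (the file path is illustrative):

```python
import gguf

reader = gguf.GGUFReader("flux1-dev-Q4_0.gguf")  # hypothetical local path
print(list(reader.fields.keys()))        # metadata keys: 'GGUF.version', 'general.architecture', ...
print([t.name for t in reader.tensors])  # the actual tensor names live here
```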
super().__init__(quantization_config, **kwargs)

self.compute_dtype = quantization_config.compute_dtype
self.pre_quantized = True
So gguf will always be pre_quantized? Does it not make sense to support converting it (like we do for bnb)?
I would rather take this from the config and then default it to True. If it's otherwise, we error out unless we support saving a model quantized with GGUF.
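A minimal sketch of that suggestion; the attribute name on the config is an assumption:

```python
# read the flag from the config, defaulting to True; error out until on-the-fly
# GGUF quantization / saving is actually supported
self.pre_quantized = getattr(quantization_config, "pre_quantized", True)
if not self.pre_quantized:
    raise ValueError("Only pre-quantized GGUF checkpoints are supported at the moment.")
```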
Hmm seems like it's always going to be prequantized for this PR.
    raise ImportError(
        "Loading GGUF Parameters requires `accelerate` installed in your enviroment: `pip install 'accelerate>=0.26.0'`"
    )
if not is_gguf_available():
I saw a bit earlier you said there is a minimum version for gguf; do we need to check for that here too?
I think just checking it here should be enough, plus a version check.
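Putting both suggestions together, the environment check could look roughly like this (a sketch; the gguf version helper and the exact version strings are assumptions):

```python
def validate_environment(self, *args, **kwargs):
    if not is_accelerate_available() or is_accelerate_version("<", "0.26.0"):
        raise ImportError(
            "Loading GGUF Parameters requires `accelerate` installed in your environment: "
            "`pip install 'accelerate>=0.26.0'`"
        )
    if not is_gguf_available() or is_gguf_version("<", "0.10.0"):
        raise ImportError(
            "Loading GGUF checkpoints requires `gguf` installed in your environment: "
            "`pip install 'gguf>=0.10.0'`"
        )
```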
import gguf
import torch

from .utils import GGUFParameter, _quant_shape_from_byte_shape, _replace_with_gguf_linear
is there any particular reason to put these under conditional import as well?
When we import quantizers.quantization_config in the diffusers init, it also imports the DiffusersAutoQuantizer object, which is in the quantizers module init file. DiffusersAutoQuantizer imports the BnB and GGUF quantizers, so our dependency test complains when running without torch or gguf installed.
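In other words, the imports get wrapped roughly like this (a sketch; the exact relative import paths are assumptions):

```python
from ...utils import is_gguf_available, is_torch_available

if is_torch_available() and is_gguf_available():
    import gguf
    import torch

    from .utils import GGUFParameter, _quant_shape_from_byte_shape, _replace_with_gguf_linear
```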
@SunMarc can you do a review too?
Looking very nice! I wonder if it would be possible to support saving models quantized with GGUF. That would be amazing!
if torch_dtype is not None:
    model.to(torch_dtype)
We don't cast the model when hf_quantizer is not None, to preserve the data types set during preprocessing and postprocessing:
# When using `use_keep_in_fp32_modules` if we do a global `to()` here, then we will
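A sketch of the guard being pointed to:

```python
# only cast when there is no quantizer, so dtypes set by the quantizer's
# pre/post-processing (e.g. keep-in-fp32 modules) are preserved
if torch_dtype is not None and hf_quantizer is None:
    model.to(torch_dtype)
```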
    is_torch_version,
    logging,
)
from ..utils.import_utils import is_gguf_available
Might make sense to add the method to the __init__.py of utils?
@@ -176,11 +182,9 @@ def load_model_dict_into_meta(
     hf_quantizer=None,
     keep_in_fp32_modules=None,
 ) -> List[str]:
-    if hf_quantizer is None:
-        device = device or torch.device("cpu")
+    device = device or torch.device("cpu")
This might have some consequences. If device is passed as 0 (which is perfectly valid as a device id), then the device would be selected as "cpu", which is not what we want here, no? For bnb, we pass the param_device to be:
param_device = torch.cuda.current_device()
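A sketch of the concern: `or` treats a device id of 0 as falsy, while an explicit None check does not.

```python
device = device or torch.device("cpu")                       # device=0 silently becomes "cpu"
device = torch.device("cpu") if device is None else device   # device=0 is kept as GPU 0
```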
@a-r-r-o-w has an open PR for this #10069
Thanks for this nice implementation. I think transformers integration could benefit from what you did here to allow users to run the compressed model cc @Isotr0py. Left a few comments.
if tensor_name not in module._parameters:
    raise ValueError(f"{module} does not have a parameter or a buffer named {tensor_name}.")
You should also check for buffers in module._buffers, as they are not included in module._parameters.
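Something along these lines (sketch):

```python
if tensor_name not in module._parameters and tensor_name not in module._buffers:
    raise ValueError(f"{module} does not have a parameter or a buffer named {tensor_name}.")
```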
# dequantize operations based on torch ports of GGUF dequantize_functions
# from City96
# more info: https://github.com/city96/ComfyUI-GGUF/blob/main/dequant.py


QK_K = 256
K_SCALE_SIZE = 12


def to_uint32(x):
    # reassemble groups of 4 little-endian uint8 bytes into one uint32 value per row
    x = x.view(torch.uint8).to(torch.int32)
    return (x[:, 0] | x[:, 1] << 8 | x[:, 2] << 16 | x[:, 3] << 24).unsqueeze(1)


def split_block_dims(blocks, *args):
    # split packed blocks along dim 1 into chunks of the given widths, with the remainder last
    n_max = blocks.shape[1]
    dims = list(args) + [n_max - sum(args)]
    return torch.split(blocks, dims, dim=1)


def get_scale_min(scales):
    # unpack the 6-bit scales and mins stored in the 12-byte scale field of K-quant blocks
    n_blocks = scales.shape[0]
    scales = scales.view(torch.uint8)
    scales = scales.reshape((n_blocks, 3, 4))

    d, m, m_d = torch.split(scales, scales.shape[-2] // 3, dim=-2)

    sc = torch.cat([d & 0x3F, (m_d & 0x0F) | ((d >> 2) & 0x30)], dim=-1)
    min = torch.cat([m & 0x3F, (m_d >> 4) | ((m >> 2) & 0x30)], dim=-1)

    return (sc.reshape((n_blocks, 8)), min.reshape((n_blocks, 8)))
The gguf library has a dequantize function; is this something we can use here?
def dequantize(data: np.ndarray, qtype: GGMLQuantizationType) -> np.ndarray:
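For comparison, using that gguf-py helper would look roughly like this; note it is a CPU/numpy path (the file path is illustrative):

```python
from gguf import GGUFReader
from gguf.quants import dequantize

reader = GGUFReader("model-Q4_K_S.gguf")             # hypothetical path
tensor = reader.tensors[0]
array = dequantize(tensor.data, tensor.tensor_type)  # numpy array, dequantized on CPU
```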
Seems that these dequantize functions are used for runtime inference; if we used dequantize from gguf-py, it would cause significant performance degradation on GPU because it's a numpy implementation.
BTW, if we want a runtime-optimized dequantize function, we might need a torch-compatible GGUF kernel, just like what we do in vLLM...
Oh, makes sense. Thanks for the details!
Great to see you again on diffusers, @Isotr0py!
BTW, if we want a runtime optimized dequantize function, we might need a torch compatible GGUF kernel, just like what we do in vLLM...
That is a great point. Perhaps we could ship the first iteration without and make use of those kernels (iff possible) in a follow-up PR? WDYT?
Agreed, we can port the GGUF kernel in a follow-up PR in the future, because I'm going to update the GGUF kernel implementation in vLLM since it's pretty out of date. (Although I haven't had bandwidth to start it. 😅)
But regardless, thanks for bringing it up. Good to know this is possible.
class GGUFParameter(torch.Tensor):
    def __new__(cls, data, requires_grad=False, quant_type=None):
        data = data if data is not None else torch.empty(0)
        self = torch.Tensor._make_subclass(cls, data, requires_grad)
        self.quant_type = quant_type

        return self
Nice! Great use of tensor subclasses.
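A hypothetical usage, just to illustrate what the subclass buys us: it behaves like a normal tensor but carries its quant type for later shape checks and dequantization.

```python
import torch
from gguf import GGMLQuantizationType

packed = torch.zeros(256, dtype=torch.uint8)  # stand-in for packed quantized bytes
weight = GGUFParameter(packed, quant_type=GGMLQuantizationType.Q4_K)
print(weight.dtype, weight.quant_type)  # uint8 storage, tagged with its GGUF quant type
```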
if quantization_config is not None:
    hf_quantizer = DiffusersAutoQuantizer.from_config(quantization_config)
    hf_quantizer.validate_environment()

else:
    hf_quantizer = None
For GGUF files, I'm thinking it would be nice to allow the user to load the model without necessarily having to specify quantization_config=GGUFQuantizationConfig(compute_dtype=xxx). If we detect that this is a gguf file, we can set quantization_config = GGUFQuantizationConfig(compute_dtype=torch.float32) by default.
I'm suggesting this because usually, when you pass a quantization_config, it means either that the model is not quantized (bnb) or that the model is quantized (there is a quantization_config in the config.json) but we want to change a few arguments.
Also, what happens when the user passes a gguf file without specifying the quantization_config?
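A rough sketch of the proposed default; the argument name inside from_single_file is illustrative:

```python
# default to a GGUF config when a .gguf file is passed without an explicit quantization_config
if quantization_config is None and str(pretrained_model_link_or_path).endswith(".gguf"):
    quantization_config = GGUFQuantizationConfig(compute_dtype=torch.float32)
```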
Yeah this is a good point! I think for most users, the entrypoint for GGUF files is going to be through from_single_file(), and I agree with the logic you mentioned.
I agree that this is a nice convenience. GGUF does have all the information we need to auto-fetch the config (honestly it's possible to skip the config altogether), but it would mean that loading semantics would be different for GGUF vs other quant types, e.g.
GGUF:
model = FluxTransformer2DModel.from_single_file("<>.gguf")
BnB and TorchAO (assuming these can be supported):
model = FluxTransformer2DModel.from_single_file("<path>", quantization_config=BnBConfig)
model = FluxTransformer2DModel.from_single_file("<path>", quantization_config=TorchAOConfig)
GGUF can also be used through from_pretrained (assuming quants of diffusers-format checkpoints show up at some point) and we would have to pass a quant config in that case. I understand it's not ideal, but I feel it's better to preserve consistency across the different quant loading methods.
@SunMarc if the config isn't passed you get shape mismatch errors when you hit load_model_dict_into_meta, since the quant shapes are different from the expected shapes.
I'm suggesting this because usually, when you pass a quantization_config, it means either that the model is not quantized (bnb) or that the model is quantized (there is a quantization_config in the config.json) but we want to change a few arguments.
yeah I thought about that too, but I think the API for from_single_file and from_pretrained might just have to be different. It is a bit confusing, but I'm not sure if there is a way to make it consistent between from_single_file and from_pretrained, if we also want to make sure the same API is consistent across different quant types.
GGUF is a special case here because it has built-in config. Normally, for single-file it is just a checkpoint without config, so you will always have to pass a config (at least I think so, is it? @DN6). So for loading a regular quantized model (e.g. BnB) we can load it with from_pretrained without passing a config, but for from_single_file, we will have to manually pass a config.
So I agree with @DN6 here: I think it is more important to make the same API (from_pretrained or from_single_file) consistent for different quant types, if we have to choose one.
But if there is a way to make it consistent between from_pretrained and from_single_file and across all quant types, that would be great!
Also, want to know this: do we plan to support quantizing a model in from_single_file? @DN6
GGUF is a special case here because it has built-in config. Normally, for single-file it is just a checkpoint without config, so you will always have to pass a config (at least I think so, is it? @DN6 ). So for loading a regular quantized model (e.g. BNB) we can load it with from_pretrained without passing a config, but for from_single_file, we will have to manually pass a config
Would it make sense to at least make the user aware when the passed config and the determined config mismatch, and whether that could lead to unintended consequences?
also, want to know this: do we plan to support quantizing a model in from_single_file? @DN6
Supporting quantizing in the GGUF format (regardless of from_pretrained() or from_single_file()) would be reallllly nice.
@yiyixuxu Yeah we can definitely support quantizing a model via single file. For GGUF I can look into it in a follow-up, because we would have to port the quantize functions to torch (the gguf library uses numpy). We could use the gguf library internally to quantize, but it's quite slow since we would have to move tensors off GPU, convert to numpy and then quantize.
I think with torch AO I'm pretty sure it would work just out of the box.
You would have to save it with save_pretrained though, since we don't support serializing single-file checkpoints.
So, what I am hearing is that saving a GGUF-quantized model would be added in a follow-up PR? That is also okay, but it could be quite an enabling factor for the community.
For GGUF I can look into it in a follow-up, because we would have to port the quantize functions to torch (the gguf library uses numpy). We could use the gguf library internally to quantize, but it's quite slow since we would have to move tensors off GPU, convert to numpy and then quantize.
I think the porting option is preferable.
I think with torch AO I'm pretty sure it would work just out of the box.
You mean serializing with torchao but with quantization configs similar to the ones provided in GGUF?
assert module.weight.dtype == torch.float32
self.model_cls._keep_in_fp32_modules = _keep_in_fp32_modules

def test_dtype_assignment(self):
-    def test_dtype_assignment(self):
+    def test_device_dtype_assignment(self):
def test_gguf_linear_layers(self):
    quantization_config = GGUFQuantizationConfig(compute_dtype=self.torch_dtype)
    model = self.model_cls.from_single_file(self.ckpt_path, quantization_config=quantization_config)

    for name, module in model.named_modules():
        if isinstance(module, torch.nn.Linear) and hasattr(module.weight, "quant_type"):
            assert module.weight.dtype == torch.uint8
should this test also check if the bias and 1D params are in the respective dtypes?
1D Params don't need to be tested. They're normal FP32 parameters and are loaded as is. They won't be GGUFLinear layers or GGUFParameters.
Thanks for the changes, Dhruv! I think only documentation now, no?
docs/source/en/quantization/gguf.md (Outdated)

Since GGUF is a single file format, we will be using `from_single_file` to load the model and pass in the `GGUFQuantizationConfig` when loading the model.

When using GGUF checkpoints, the quantized weights remain in a low-memory `dtype`, typically `torch.uint8`, and are dynamically dequantized and cast to the configured `compute_dtype` when running a forward pass through each module in the model. The `GGUFQuantizationConfig` allows you to set the `compute_dtype` for the forward pass of each module. The functions used for dynamic dequantization are based on the great work done by [city96](https://github.com/city96/ComfyUI-GGUF).
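For context, the loading flow the docs describe looks roughly like this (the checkpoint URL is illustrative):

```python
import torch
from diffusers import FluxTransformer2DModel, GGUFQuantizationConfig

ckpt_path = "https://huggingface.co/city96/FLUX.1-dev-gguf/blob/main/flux1-dev-Q2_K.gguf"
transformer = FluxTransformer2DModel.from_single_file(
    ckpt_path,
    quantization_config=GGUFQuantizationConfig(compute_dtype=torch.bfloat16),
    torch_dtype=torch.bfloat16,
)
```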
A lot of the pytorch dequantization code is based on the numpy code from llama.cpp written by @compilade - I believe he should be credited here as well :)
@stevhliu could you review the doc-related changes please?
LGTM, thanks!
@Isotr0py, now that this PR is merged, would you like to help us integrate the support for launching appropriate GGUF kernels as done in vLLM?
@sayakpaul Sure, I have separated a minimal version of the GGUF kernel from vLLM here: https://github.com/Isotr0py/ggml-libtorch/tree/main/ggml-cuda
So it should be easier to integrate the GGUF kernel now. :)
Wonderful. I think we could start minimal, gauge the performance gains, and then iteratively ship. Happy to work with you on this!
Co-authored-by: Sayak Paul <[email protected]>
Co-authored-by: Steven Liu <[email protected]>
What does this PR do?
Adds support for loading GGUF checkpoints via from_single_file.
Notes:
from_pretrained. GGUF files have enough metadata that we can automatically infer everything we need from the file itself. We don't really need a quantization config, but it becomes necessary as we expand support to other quant loading methods (BnB, TorchAO, etc.)
TODOS:
Fixes #9487
Who can review?
Anyone in the community is free to review the PR once the tests have passed. Feel free to tag
members/contributors who may be interested in your PR.