diff --git a/.github/workflows/nightly_tests.yml b/.github/workflows/nightly_tests.yml
index b8fbf8f54362..cc0abac6e4ab 100644
--- a/.github/workflows/nightly_tests.yml
+++ b/.github/workflows/nightly_tests.yml
@@ -357,6 +357,8 @@ jobs:
         config:
           - backend: "bitsandbytes"
             test_location: "bnb"
+          - backend: "gguf"
+            test_location: "gguf"
     runs-on:
       group: aws-g6e-xlarge-plus
     container:
diff --git a/docs/source/en/_toctree.yml b/docs/source/en/_toctree.yml
index 4edeb9fcb389..ab733054fbd3 100644
--- a/docs/source/en/_toctree.yml
+++ b/docs/source/en/_toctree.yml
@@ -157,6 +157,8 @@ title: Getting Started
       - local: quantization/bitsandbytes
         title: bitsandbytes
+      - local: quantization/gguf
+        title: gguf
       - local: quantization/torchao
         title: torchao
     title: Quantization Methods
diff --git a/docs/source/en/api/quantization.md b/docs/source/en/api/quantization.md
index 18aadf3111bd..168a9a03473f 100644
--- a/docs/source/en/api/quantization.md
+++ b/docs/source/en/api/quantization.md
@@ -28,6 +28,9 @@ Learn how to quantize models in the [Quantization](../quantization/overview) gui
[[autodoc]] BitsAndBytesConfig
+## GGUFQuantizationConfig
+
+[[autodoc]] GGUFQuantizationConfig
## TorchAoConfig
[[autodoc]] TorchAoConfig
diff --git a/docs/source/en/quantization/gguf.md b/docs/source/en/quantization/gguf.md
new file mode 100644
index 000000000000..dbcd1b1486b2
--- /dev/null
+++ b/docs/source/en/quantization/gguf.md
@@ -0,0 +1,70 @@
+
+
+# GGUF
+
+The GGUF file format is typically used to store models for inference with [GGML](https://github.com/ggerganov/ggml) and supports a variety of block-wise quantization options. Diffusers supports loading checkpoints prequantized and saved in the GGUF format via `from_single_file` loading with Model classes. Loading GGUF checkpoints via Pipelines is currently not supported.
+
+The following example will load the [FLUX.1 DEV](https://huggingface.co/black-forest-labs/FLUX.1-dev) transformer model using the GGUF Q2_K quantization variant.
+
+Before starting, please install gguf in your environment
+
+```shell
+pip install -U gguf
+```
+
+Since GGUF is a single file format, use [`~FromSingleFileMixin.from_single_file`] to load the model and pass in the [`GGUFQuantizationConfig`].
+
+When using GGUF checkpoints, the quantized weights remain in a low memory `dtype` (typically `torch.uint8`) and are dynamically dequantized and cast to the configured `compute_dtype` during each module's forward pass through the model. The `GGUFQuantizationConfig` allows you to set the `compute_dtype`.
+
+The functions used for dynamic dequantization are based on the great work done by [city96](https://github.com/city96/ComfyUI-GGUF), who created the PyTorch ports of the original [`numpy`](https://github.com/ggerganov/llama.cpp/blob/master/gguf-py/gguf/quants.py) implementation by [compilade](https://github.com/compilade).
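To make the dequantize-on-forward behavior described above concrete, here is a minimal illustrative sketch. The class and function names here are hypothetical, not the Diffusers API; the real logic lives in the `GGUFLinear` layer added in `src/diffusers/quantizers/gguf/utils.py` later in this PR. The point is that the weights stay in their packed low-memory storage, and full-precision values only exist transiently inside `forward`:

```python
import torch
import torch.nn as nn


class DequantOnForwardLinear(nn.Module):
    # Hypothetical sketch: weights stay packed (e.g. torch.uint8 GGUF blocks) and are
    # dequantized on every forward call, then cast to the configured compute dtype.
    def __init__(self, packed_weight, dequantize_fn, compute_dtype=torch.bfloat16):
        super().__init__()
        self.register_buffer("packed_weight", packed_weight)  # low-memory storage
        self.dequantize_fn = dequantize_fn  # block-wise dequantizer, e.g. a GGUF dequant function
        self.compute_dtype = compute_dtype

    def forward(self, x):
        # Materialize full-precision weights only for the duration of this call.
        weight = self.dequantize_fn(self.packed_weight).to(self.compute_dtype)
        return torch.nn.functional.linear(x.to(self.compute_dtype), weight)
```

The end-to-end example below uses the actual `from_single_file` + `GGUFQuantizationConfig` API.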
+
+```python
+import torch
+
+from diffusers import FluxPipeline, FluxTransformer2DModel, GGUFQuantizationConfig
+
+ckpt_path = (
+    "https://huggingface.co/city96/FLUX.1-dev-gguf/blob/main/flux1-dev-Q2_K.gguf"
+)
+transformer = FluxTransformer2DModel.from_single_file(
+    ckpt_path,
+    quantization_config=GGUFQuantizationConfig(compute_dtype=torch.bfloat16),
+    torch_dtype=torch.bfloat16,
+)
+pipe = FluxPipeline.from_pretrained(
+    "black-forest-labs/FLUX.1-dev",
+    transformer=transformer,
+    torch_dtype=torch.bfloat16,
+)
+pipe.enable_model_cpu_offload()
+prompt = "A cat holding a sign that says hello world"
+image = pipe(prompt, generator=torch.manual_seed(0)).images[0]
+image.save("flux-gguf.png")
+```
+
+## Supported Quantization Types
+
+- BF16
+- Q4_0
+- Q4_1
+- Q5_0
+- Q5_1
+- Q8_0
+- Q2_K
+- Q3_K
+- Q4_K
+- Q5_K
+- Q6_K
+
diff --git a/docs/source/en/quantization/overview.md b/docs/source/en/quantization/overview.md
index 151b22a607a4..6c2df7514d5e 100644
--- a/docs/source/en/quantization/overview.md
+++ b/docs/source/en/quantization/overview.md
@@ -17,7 +17,7 @@ Quantization techniques focus on representing data with less information while a
-Interested in adding a new quantization method to Transformers? Refer to the [Contribute new quantization method guide](https://huggingface.co/docs/transformers/main/en/quantization/contribute) to learn more about adding a new quantization method.
+Interested in adding a new quantization method to Diffusers? Refer to the [Contribute new quantization method guide](https://huggingface.co/docs/transformers/main/en/quantization/contribute) to learn more about adding a new quantization method.
@@ -32,4 +32,9 @@ If you are new to the quantization field, we recommend you to check out these be
## When to use what?
-Diffusers supports [bitsandbytes](https://huggingface.co/docs/bitsandbytes/main/en/index) and [torchao](https://github.com/pytorch/ao). Refer to this [table](https://huggingface.co/docs/transformers/main/en/quantization/overview#when-to-use-what) to help you determine which quantization backend to use.
\ No newline at end of file
+Diffusers currently supports the following quantization methods.
+- [BitsandBytes](./bitsandbytes)
+- [TorchAO](./torchao)
+- [GGUF](./gguf)
+
+[This resource](https://huggingface.co/docs/transformers/main/en/quantization/overview#when-to-use-what) provides a good overview of the pros and cons of different quantization techniques.
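A model loaded this way can also be converted back to regular `nn.Linear` layers. The sketch below assumes the `dequantize()` helper on the model, which the `GGUFQuantizer._dequantize` hook and the `test_dequantize_model` test further down in this PR exercise; the checkpoint path is the same FLUX.1 Q2_K file used in the docs above.

```python
import torch

from diffusers import FluxTransformer2DModel, GGUFQuantizationConfig

ckpt_path = "https://huggingface.co/city96/FLUX.1-dev-gguf/blob/main/flux1-dev-Q2_K.gguf"
transformer = FluxTransformer2DModel.from_single_file(
    ckpt_path,
    quantization_config=GGUFQuantizationConfig(compute_dtype=torch.bfloat16),
    torch_dtype=torch.bfloat16,
)

# Restore plain nn.Linear layers holding dequantized weights. This trades the GGUF
# memory savings for faster forward passes (no per-call dequantization).
transformer.dequantize()
```

Note that dequantizing a Q2_K checkpoint does not recover the original full-precision weights, only a full-precision copy of the quantized values.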
diff --git a/src/diffusers/__init__.py b/src/diffusers/__init__.py index fc7ada80a63b..e2351a0c53b8 100644 --- a/src/diffusers/__init__.py +++ b/src/diffusers/__init__.py @@ -31,7 +31,7 @@ "loaders": ["FromOriginalModelMixin"], "models": [], "pipelines": [], - "quantizers.quantization_config": ["BitsAndBytesConfig", "TorchAoConfig"], + "quantizers.quantization_config": ["BitsAndBytesConfig", "GGUFQuantizationConfig", "TorchAoConfig"], "schedulers": [], "utils": [ "OptionalDependencyNotAvailable", @@ -569,7 +569,7 @@ if TYPE_CHECKING or DIFFUSERS_SLOW_IMPORT: from .configuration_utils import ConfigMixin - from .quantizers.quantization_config import BitsAndBytesConfig, TorchAoConfig + from .quantizers.quantization_config import BitsAndBytesConfig, GGUFQuantizationConfig, TorchAoConfig try: if not is_onnx_available(): diff --git a/src/diffusers/loaders/single_file_model.py b/src/diffusers/loaders/single_file_model.py index 78ce47273d8f..9641435fa5a6 100644 --- a/src/diffusers/loaders/single_file_model.py +++ b/src/diffusers/loaders/single_file_model.py @@ -17,8 +17,10 @@ from contextlib import nullcontext from typing import Optional +import torch from huggingface_hub.utils import validate_hf_hub_args +from ..quantizers import DiffusersAutoQuantizer from ..utils import deprecate, is_accelerate_available, logging from .single_file_utils import ( SingleFileComponentError, @@ -214,6 +216,8 @@ def from_single_file(cls, pretrained_model_link_or_path_or_dict: Optional[str] = subfolder = kwargs.pop("subfolder", None) revision = kwargs.pop("revision", None) torch_dtype = kwargs.pop("torch_dtype", None) + quantization_config = kwargs.pop("quantization_config", None) + device = kwargs.pop("device", None) if isinstance(pretrained_model_link_or_path_or_dict, dict): checkpoint = pretrained_model_link_or_path_or_dict @@ -227,6 +231,12 @@ def from_single_file(cls, pretrained_model_link_or_path_or_dict: Optional[str] = local_files_only=local_files_only, revision=revision, ) + if quantization_config is not None: + hf_quantizer = DiffusersAutoQuantizer.from_config(quantization_config) + hf_quantizer.validate_environment() + + else: + hf_quantizer = None mapping_functions = SINGLE_FILE_LOADABLE_CLASSES[mapping_class_name] @@ -309,8 +319,36 @@ def from_single_file(cls, pretrained_model_link_or_path_or_dict: Optional[str] = with ctx(): model = cls.from_config(diffusers_model_config) + # Check if `_keep_in_fp32_modules` is not None + use_keep_in_fp32_modules = (cls._keep_in_fp32_modules is not None) and ( + (torch_dtype == torch.float16) or hasattr(hf_quantizer, "use_keep_in_fp32_modules") + ) + if use_keep_in_fp32_modules: + keep_in_fp32_modules = cls._keep_in_fp32_modules + if not isinstance(keep_in_fp32_modules, list): + keep_in_fp32_modules = [keep_in_fp32_modules] + + else: + keep_in_fp32_modules = [] + + if hf_quantizer is not None: + hf_quantizer.preprocess_model( + model=model, + device_map=None, + state_dict=diffusers_format_checkpoint, + keep_in_fp32_modules=keep_in_fp32_modules, + ) + if is_accelerate_available(): - unexpected_keys = load_model_dict_into_meta(model, diffusers_format_checkpoint, dtype=torch_dtype) + param_device = torch.device(device) if device else torch.device("cpu") + unexpected_keys = load_model_dict_into_meta( + model, + diffusers_format_checkpoint, + dtype=torch_dtype, + device=param_device, + hf_quantizer=hf_quantizer, + keep_in_fp32_modules=keep_in_fp32_modules, + ) else: _, unexpected_keys = model.load_state_dict(diffusers_format_checkpoint, strict=False) @@ -324,7 +362,11 @@ def 
from_single_file(cls, pretrained_model_link_or_path_or_dict: Optional[str] = f"Some weights of the model checkpoint were not used when initializing {cls.__name__}: \n {[', '.join(unexpected_keys)]}" ) - if torch_dtype is not None: + if hf_quantizer is not None: + hf_quantizer.postprocess_model(model) + model.hf_quantizer = hf_quantizer + + if torch_dtype is not None and hf_quantizer is None: model.to(torch_dtype) model.eval() diff --git a/src/diffusers/loaders/single_file_utils.py b/src/diffusers/loaders/single_file_utils.py index 21ff2841700d..4e288737fe88 100644 --- a/src/diffusers/loaders/single_file_utils.py +++ b/src/diffusers/loaders/single_file_utils.py @@ -81,8 +81,14 @@ "open_clip_sd3": "text_encoders.clip_g.transformer.text_model.embeddings.position_embedding.weight", "stable_cascade_stage_b": "down_blocks.1.0.channelwise.0.weight", "stable_cascade_stage_c": "clip_txt_mapper.weight", - "sd3": "model.diffusion_model.joint_blocks.0.context_block.adaLN_modulation.1.bias", - "sd35_large": "model.diffusion_model.joint_blocks.37.x_block.mlp.fc1.weight", + "sd3": [ + "joint_blocks.0.context_block.adaLN_modulation.1.bias", + "model.diffusion_model.joint_blocks.0.context_block.adaLN_modulation.1.bias", + ], + "sd35_large": [ + "joint_blocks.37.x_block.mlp.fc1.weight", + "model.diffusion_model.joint_blocks.37.x_block.mlp.fc1.weight", + ], "animatediff": "down_blocks.0.motion_modules.0.temporal_transformer.transformer_blocks.0.attention_blocks.0.pos_encoder.pe", "animatediff_v2": "mid_block.motion_modules.0.temporal_transformer.norm.bias", "animatediff_sdxl_beta": "up_blocks.2.motion_modules.0.temporal_transformer.norm.weight", @@ -542,13 +548,20 @@ def infer_diffusers_model_type(checkpoint): ): model_type = "stable_cascade_stage_b" - elif CHECKPOINT_KEY_NAMES["sd3"] in checkpoint and checkpoint[CHECKPOINT_KEY_NAMES["sd3"]].shape[-1] == 9216: - if checkpoint["model.diffusion_model.pos_embed"].shape[1] == 36864: + elif any(key in checkpoint for key in CHECKPOINT_KEY_NAMES["sd3"]) and any( + checkpoint[key].shape[-1] == 9216 if key in checkpoint else False for key in CHECKPOINT_KEY_NAMES["sd3"] + ): + if "model.diffusion_model.pos_embed" in checkpoint: + key = "model.diffusion_model.pos_embed" + else: + key = "pos_embed" + + if checkpoint[key].shape[1] == 36864: model_type = "sd3" - elif checkpoint["model.diffusion_model.pos_embed"].shape[1] == 147456: + elif checkpoint[key].shape[1] == 147456: model_type = "sd35_medium" - elif CHECKPOINT_KEY_NAMES["sd35_large"] in checkpoint: + elif any(key in checkpoint for key in CHECKPOINT_KEY_NAMES["sd35_large"]): model_type = "sd35_large" elif CHECKPOINT_KEY_NAMES["animatediff"] in checkpoint: diff --git a/src/diffusers/models/model_loading_utils.py b/src/diffusers/models/model_loading_utils.py index 546c0eb4d840..af1a1a5250ff 100644 --- a/src/diffusers/models/model_loading_utils.py +++ b/src/diffusers/models/model_loading_utils.py @@ -17,6 +17,7 @@ import importlib import inspect import os +from array import array from collections import OrderedDict from pathlib import Path from typing import List, Optional, Union @@ -26,6 +27,7 @@ from huggingface_hub.utils import EntryNotFoundError from ..utils import ( + GGUF_FILE_EXTENSION, SAFE_WEIGHTS_INDEX_NAME, SAFETENSORS_FILE_EXTENSION, WEIGHTS_INDEX_NAME, @@ -33,6 +35,8 @@ _get_model_file, deprecate, is_accelerate_available, + is_gguf_available, + is_torch_available, is_torch_version, logging, ) @@ -139,6 +143,8 @@ def load_state_dict(checkpoint_file: Union[str, os.PathLike], variant: Optional[ 
    file_extension = os.path.basename(checkpoint_file).split(".")[-1]
    if file_extension == SAFETENSORS_FILE_EXTENSION:
        return safetensors.torch.load_file(checkpoint_file, device="cpu")
+    elif file_extension == GGUF_FILE_EXTENSION:
+        return load_gguf_checkpoint(checkpoint_file)
    else:
        weights_only_kwarg = {"weights_only": True} if is_torch_version(">=", "1.13") else {}
        return torch.load(
@@ -211,13 +217,14 @@ def load_model_dict_into_meta(
            set_module_kwargs["dtype"] = dtype
        # bnb params are flattened.
+        # gguf quants have a different shape based on the type of quantization applied
        if empty_state_dict[param_name].shape != param.shape:
            if (
                is_quantized
                and hf_quantizer.pre_quantized
                and hf_quantizer.check_if_quantized_param(model, param, param_name, state_dict, param_device=device)
            ):
-                hf_quantizer.check_quantized_param_shape(param_name, empty_state_dict[param_name].shape, param.shape)
+                hf_quantizer.check_quantized_param_shape(param_name, empty_state_dict[param_name], param)
            else:
                model_name_or_path_str = f"{model_name_or_path} " if model_name_or_path is not None else ""
                raise ValueError(
@@ -396,3 +403,78 @@ def _fetch_index_file_legacy(
        index_file = None
    return index_file
+
+
+def _gguf_parse_value(_value, data_type):
+    if not isinstance(data_type, list):
+        data_type = [data_type]
+    if len(data_type) == 1:
+        data_type = data_type[0]
+        array_data_type = None
+    else:
+        if data_type[0] != 9:
+            raise ValueError("Received multiple types, therefore expected the first type to indicate an array.")
+        data_type, array_data_type = data_type
+
+    if data_type in [0, 1, 2, 3, 4, 5, 10, 11]:
+        _value = int(_value[0])
+    elif data_type in [6, 12]:
+        _value = float(_value[0])
+    elif data_type in [7]:
+        _value = bool(_value[0])
+    elif data_type in [8]:
+        _value = array("B", list(_value)).tobytes().decode()
+    elif data_type in [9]:
+        _value = _gguf_parse_value(_value, array_data_type)
+    return _value
+
+
+def load_gguf_checkpoint(gguf_checkpoint_path, return_tensors=False):
+    """
+    Load a GGUF file and return a dictionary of parsed parameters containing tensors, the parsed tokenizer and config
+    attributes.
+
+    Args:
+        gguf_checkpoint_path (`str`):
+            The path to the GGUF file to load
+        return_tensors (`bool`, defaults to `False`):
+            Whether to read the tensors from the file and return them. Not doing so is faster and only loads the
+            metadata in memory.
+    """
+
+    if is_gguf_available() and is_torch_available():
+        import gguf
+        from gguf import GGUFReader
+
+        from ..quantizers.gguf.utils import SUPPORTED_GGUF_QUANT_TYPES, GGUFParameter
+    else:
+        logger.error(
+            "Loading a GGUF checkpoint in PyTorch requires both PyTorch and GGUF>=0.10.0 to be installed. Please see "
+            "https://pytorch.org/ and https://github.com/ggerganov/llama.cpp/tree/master/gguf-py for installation instructions."
+        )
+        raise ImportError("Please install torch and gguf>=0.10.0 to load a GGUF checkpoint in PyTorch.")
+
+    reader = GGUFReader(gguf_checkpoint_path)
+
+    parsed_parameters = {}
+    for tensor in reader.tensors:
+        name = tensor.name
+        quant_type = tensor.tensor_type
+
+        # if the tensor is a torch supported dtype do not use GGUFParameter
+        is_gguf_quant = quant_type not in [gguf.GGMLQuantizationType.F32, gguf.GGMLQuantizationType.F16]
+        if is_gguf_quant and quant_type not in SUPPORTED_GGUF_QUANT_TYPES:
+            _supported_quants_str = "\n".join([str(type) for type in SUPPORTED_GGUF_QUANT_TYPES])
+            raise ValueError(
+                (
+                    f"{name} has a quantization type: {str(quant_type)} which is unsupported."
+ "\n\nCurrently the following quantization types are supported: \n\n" + f"{_supported_quants_str}" + "\n\nTo request support for this quantization type please open an issue here: https://github.com/huggingface/diffusers" + ) + ) + + weights = torch.from_numpy(tensor.data.copy()) + parsed_parameters[name] = GGUFParameter(weights, quant_type=quant_type) if is_gguf_quant else weights + + return parsed_parameters diff --git a/src/diffusers/models/modeling_utils.py b/src/diffusers/models/modeling_utils.py index ce5289e3dbfd..0f9c9203c926 100644 --- a/src/diffusers/models/modeling_utils.py +++ b/src/diffusers/models/modeling_utils.py @@ -1038,14 +1038,14 @@ def to(self, *args, **kwargs): dtype_present_in_args = True break - # Checks if the model has been loaded in 4-bit or 8-bit with BNB - if getattr(self, "quantization_method", None) == QuantizationMethod.BITS_AND_BYTES: + if getattr(self, "is_quantized", False): if dtype_present_in_args: raise ValueError( - "You cannot cast a bitsandbytes model in a new `dtype`. Make sure to load the model using `from_pretrained` using the" - " desired `dtype` by passing the correct `torch_dtype` argument." + "Casting a quantized model to a new `dtype` is unsupported. To set the dtype of unquantized layers, please " + "use the `torch_dtype` argument when loading the model using `from_pretrained` or `from_single_file`" ) + if getattr(self, "quantization_method", None) == QuantizationMethod.BITS_AND_BYTES: if getattr(self, "is_loaded_in_8bit", False): raise ValueError( "`.to` is not supported for `8-bit` bitsandbytes models. Please use the model as it is, since the" diff --git a/src/diffusers/models/transformers/transformer_flux.py b/src/diffusers/models/transformers/transformer_flux.py index 18527e3c46c0..8dbe49b75076 100644 --- a/src/diffusers/models/transformers/transformer_flux.py +++ b/src/diffusers/models/transformers/transformer_flux.py @@ -524,7 +524,6 @@ def custom_forward(*inputs): ) else: hidden_states = hidden_states + controlnet_block_samples[index_block // interval_control] - hidden_states = torch.cat([encoder_hidden_states, hidden_states], dim=1) for index_block, block in enumerate(self.single_transformer_blocks): diff --git a/src/diffusers/quantizers/auto.py b/src/diffusers/quantizers/auto.py index 098308ae0bdc..41173ecb8f5e 100644 --- a/src/diffusers/quantizers/auto.py +++ b/src/diffusers/quantizers/auto.py @@ -15,23 +15,33 @@ Adapted from https://github.com/huggingface/transformers/blob/c409cd81777fb27aadc043ed3d8339dbc020fb3b/src/transformers/quantizers/auto.py """ + import warnings from typing import Dict, Optional, Union from .bitsandbytes import BnB4BitDiffusersQuantizer, BnB8BitDiffusersQuantizer -from .quantization_config import BitsAndBytesConfig, QuantizationConfigMixin, QuantizationMethod, TorchAoConfig +from .gguf import GGUFQuantizer +from .quantization_config import ( + BitsAndBytesConfig, + GGUFQuantizationConfig, + QuantizationConfigMixin, + QuantizationMethod, + TorchAoConfig, +) from .torchao import TorchAoHfQuantizer AUTO_QUANTIZER_MAPPING = { "bitsandbytes_4bit": BnB4BitDiffusersQuantizer, "bitsandbytes_8bit": BnB8BitDiffusersQuantizer, + "gguf": GGUFQuantizer, "torchao": TorchAoHfQuantizer, } AUTO_QUANTIZATION_CONFIG_MAPPING = { "bitsandbytes_4bit": BitsAndBytesConfig, "bitsandbytes_8bit": BitsAndBytesConfig, + "gguf": GGUFQuantizationConfig, "torchao": TorchAoConfig, } diff --git a/src/diffusers/quantizers/bitsandbytes/bnb_quantizer.py b/src/diffusers/quantizers/bitsandbytes/bnb_quantizer.py index d5ac1611a571..f7780b66b12b 
100644 --- a/src/diffusers/quantizers/bitsandbytes/bnb_quantizer.py +++ b/src/diffusers/quantizers/bitsandbytes/bnb_quantizer.py @@ -204,7 +204,10 @@ def create_quantized_param( module._parameters[tensor_name] = new_value - def check_quantized_param_shape(self, param_name, current_param_shape, loaded_param_shape): + def check_quantized_param_shape(self, param_name, current_param, loaded_param): + current_param_shape = current_param.shape + loaded_param_shape = loaded_param.shape + n = current_param_shape.numel() inferred_shape = (n,) if "bias" in param_name else ((n + 1) // 2, 1) if loaded_param_shape != inferred_shape: diff --git a/src/diffusers/quantizers/gguf/__init__.py b/src/diffusers/quantizers/gguf/__init__.py new file mode 100644 index 000000000000..b3d9082ac803 --- /dev/null +++ b/src/diffusers/quantizers/gguf/__init__.py @@ -0,0 +1 @@ +from .gguf_quantizer import GGUFQuantizer diff --git a/src/diffusers/quantizers/gguf/gguf_quantizer.py b/src/diffusers/quantizers/gguf/gguf_quantizer.py new file mode 100644 index 000000000000..0c760e277ce4 --- /dev/null +++ b/src/diffusers/quantizers/gguf/gguf_quantizer.py @@ -0,0 +1,159 @@ +from typing import TYPE_CHECKING, Any, Dict, List, Optional, Union + +from ..base import DiffusersQuantizer + + +if TYPE_CHECKING: + from ...models.modeling_utils import ModelMixin + + +from ...utils import ( + get_module_from_name, + is_accelerate_available, + is_accelerate_version, + is_gguf_available, + is_gguf_version, + is_torch_available, + logging, +) + + +if is_torch_available() and is_gguf_available(): + import torch + + from .utils import ( + GGML_QUANT_SIZES, + GGUFParameter, + _dequantize_gguf_and_restore_linear, + _quant_shape_from_byte_shape, + _replace_with_gguf_linear, + ) + + +logger = logging.get_logger(__name__) + + +class GGUFQuantizer(DiffusersQuantizer): + use_keep_in_fp32_modules = True + + def __init__(self, quantization_config, **kwargs): + super().__init__(quantization_config, **kwargs) + + self.compute_dtype = quantization_config.compute_dtype + self.pre_quantized = quantization_config.pre_quantized + self.modules_to_not_convert = quantization_config.modules_to_not_convert + + if not isinstance(self.modules_to_not_convert, list): + self.modules_to_not_convert = [self.modules_to_not_convert] + + def validate_environment(self, *args, **kwargs): + if not is_accelerate_available() or is_accelerate_version("<", "0.26.0"): + raise ImportError( + "Loading GGUF Parameters requires `accelerate` installed in your enviroment: `pip install 'accelerate>=0.26.0'`" + ) + if not is_gguf_available() or is_gguf_version("<", "0.10.0"): + raise ImportError( + "To load GGUF format files you must have `gguf` installed in your environment: `pip install gguf>=0.10.0`" + ) + + # Copied from diffusers.quantizers.bitsandbytes.bnb_quantizer.BnB4BitDiffusersQuantizer.adjust_max_memory + def adjust_max_memory(self, max_memory: Dict[str, Union[int, str]]) -> Dict[str, Union[int, str]]: + # need more space for buffers that are created during quantization + max_memory = {key: val * 0.90 for key, val in max_memory.items()} + return max_memory + + def adjust_target_dtype(self, target_dtype: "torch.dtype") -> "torch.dtype": + if target_dtype != torch.uint8: + logger.info(f"target_dtype {target_dtype} is replaced by `torch.uint8` for GGUF quantization") + return torch.uint8 + + def update_torch_dtype(self, torch_dtype: "torch.dtype") -> "torch.dtype": + if torch_dtype is None: + torch_dtype = self.compute_dtype + return torch_dtype + + def 
check_quantized_param_shape(self, param_name, current_param, loaded_param): + loaded_param_shape = loaded_param.shape + current_param_shape = current_param.shape + quant_type = loaded_param.quant_type + + block_size, type_size = GGML_QUANT_SIZES[quant_type] + + inferred_shape = _quant_shape_from_byte_shape(loaded_param_shape, type_size, block_size) + if inferred_shape != current_param_shape: + raise ValueError( + f"{param_name} has an expected quantized shape of: {inferred_shape}, but receieved shape: {loaded_param_shape}" + ) + + return True + + def check_if_quantized_param( + self, + model: "ModelMixin", + param_value: Union["GGUFParameter", "torch.Tensor"], + param_name: str, + state_dict: Dict[str, Any], + **kwargs, + ) -> bool: + if isinstance(param_value, GGUFParameter): + return True + + return False + + def create_quantized_param( + self, + model: "ModelMixin", + param_value: Union["GGUFParameter", "torch.Tensor"], + param_name: str, + target_device: "torch.device", + state_dict: Optional[Dict[str, Any]] = None, + unexpected_keys: Optional[List[str]] = None, + ): + module, tensor_name = get_module_from_name(model, param_name) + if tensor_name not in module._parameters and tensor_name not in module._buffers: + raise ValueError(f"{module} does not have a parameter or a buffer named {tensor_name}.") + + if tensor_name in module._parameters: + module._parameters[tensor_name] = param_value.to(target_device) + if tensor_name in module._buffers: + module._buffers[tensor_name] = param_value.to(target_device) + + def _process_model_before_weight_loading( + self, + model: "ModelMixin", + device_map, + keep_in_fp32_modules: List[str] = [], + **kwargs, + ): + state_dict = kwargs.get("state_dict", None) + + self.modules_to_not_convert.extend(keep_in_fp32_modules) + self.modules_to_not_convert = [module for module in self.modules_to_not_convert if module is not None] + + _replace_with_gguf_linear( + model, self.compute_dtype, state_dict, modules_to_not_convert=self.modules_to_not_convert + ) + + def _process_model_after_weight_loading(self, model: "ModelMixin", **kwargs): + return model + + @property + def is_serializable(self): + return False + + @property + def is_trainable(self) -> bool: + return False + + def _dequantize(self, model): + is_model_on_cpu = model.device.type == "cpu" + if is_model_on_cpu: + logger.info( + "Model was found to be on CPU (could happen as a result of `enable_model_cpu_offload()`). So, moving it to GPU. After dequantization, will move the model back to CPU again to preserve the previous device." + ) + model.to(torch.cuda.current_device()) + + model = _dequantize_gguf_and_restore_linear(model, self.modules_to_not_convert) + if is_model_on_cpu: + model.to("cpu") + return model diff --git a/src/diffusers/quantizers/gguf/utils.py b/src/diffusers/quantizers/gguf/utils.py new file mode 100644 index 000000000000..35e5743fbcf0 --- /dev/null +++ b/src/diffusers/quantizers/gguf/utils.py @@ -0,0 +1,456 @@ +# Copyright 2024 The HuggingFace Team and City96. All rights reserved. +# # +# # Licensed under the Apache License, Version 2.0 (the "License"); +# # you may not use this file except in compliance with the License. +# # You may obtain a copy of the License at +# # +# # http://www.apache.org/licenses/LICENSE-2.0 +# # +# # Unless required by applicable law or agreed to in writing, software +# # distributed under the License is distributed on an "AS IS" BASIS, +# # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 
+# # See the License for the specific language governing permissions and +# # limitations under the License. + + +import inspect +from contextlib import nullcontext + +import gguf +import torch +import torch.nn as nn + +from ...utils import is_accelerate_available + + +if is_accelerate_available(): + import accelerate + from accelerate import init_empty_weights + from accelerate.hooks import add_hook_to_module, remove_hook_from_module + + +# Copied from diffusers.quantizers.bitsandbytes.utils._create_accelerate_new_hook +def _create_accelerate_new_hook(old_hook): + r""" + Creates a new hook based on the old hook. Use it only if you know what you are doing ! This method is a copy of: + https://github.com/huggingface/peft/blob/748f7968f3a31ec06a1c2b0328993319ad9a150a/src/peft/utils/other.py#L245 with + some changes + """ + old_hook_cls = getattr(accelerate.hooks, old_hook.__class__.__name__) + old_hook_attr = old_hook.__dict__ + filtered_old_hook_attr = {} + old_hook_init_signature = inspect.signature(old_hook_cls.__init__) + for k in old_hook_attr.keys(): + if k in old_hook_init_signature.parameters: + filtered_old_hook_attr[k] = old_hook_attr[k] + new_hook = old_hook_cls(**filtered_old_hook_attr) + return new_hook + + +def _replace_with_gguf_linear(model, compute_dtype, state_dict, prefix="", modules_to_not_convert=[]): + def _should_convert_to_gguf(state_dict, prefix): + weight_key = prefix + "weight" + return weight_key in state_dict and isinstance(state_dict[weight_key], GGUFParameter) + + has_children = list(model.children()) + if not has_children: + return + + for name, module in model.named_children(): + module_prefix = prefix + name + "." + _replace_with_gguf_linear(module, compute_dtype, state_dict, module_prefix, modules_to_not_convert) + + if ( + isinstance(module, nn.Linear) + and _should_convert_to_gguf(state_dict, module_prefix) + and name not in modules_to_not_convert + ): + ctx = init_empty_weights if is_accelerate_available() else nullcontext + with ctx(): + model._modules[name] = GGUFLinear( + module.in_features, + module.out_features, + module.bias is not None, + compute_dtype=compute_dtype, + ) + model._modules[name].source_cls = type(module) + # Force requires_grad to False to avoid unexpected errors + model._modules[name].requires_grad_(False) + + return model + + +def _dequantize_gguf_and_restore_linear(model, modules_to_not_convert=[]): + for name, module in model.named_children(): + if isinstance(module, GGUFLinear) and name not in modules_to_not_convert: + device = module.weight.device + bias = getattr(module, "bias", None) + + ctx = init_empty_weights if is_accelerate_available() else nullcontext + with ctx(): + new_module = nn.Linear( + module.in_features, + module.out_features, + module.bias is not None, + device=device, + ) + new_module.weight = nn.Parameter(dequantize_gguf_tensor(module.weight)) + if bias is not None: + new_module.bias = bias + + # Create a new hook and attach it in case we use accelerate + if hasattr(module, "_hf_hook"): + old_hook = module._hf_hook + new_hook = _create_accelerate_new_hook(old_hook) + + remove_hook_from_module(module) + add_hook_to_module(new_module, new_hook) + + new_module.to(device) + model._modules[name] = new_module + + has_children = list(module.children()) + if has_children: + _dequantize_gguf_and_restore_linear(module, modules_to_not_convert) + + return model + + +# dequantize operations based on torch ports of GGUF dequantize_functions +# from City96 +# more info: 
https://github.com/city96/ComfyUI-GGUF/blob/main/dequant.py + + +QK_K = 256 +K_SCALE_SIZE = 12 + + +def to_uint32(x): + x = x.view(torch.uint8).to(torch.int32) + return (x[:, 0] | x[:, 1] << 8 | x[:, 2] << 16 | x[:, 3] << 24).unsqueeze(1) + + +def split_block_dims(blocks, *args): + n_max = blocks.shape[1] + dims = list(args) + [n_max - sum(args)] + return torch.split(blocks, dims, dim=1) + + +def get_scale_min(scales): + n_blocks = scales.shape[0] + scales = scales.view(torch.uint8) + scales = scales.reshape((n_blocks, 3, 4)) + + d, m, m_d = torch.split(scales, scales.shape[-2] // 3, dim=-2) + + sc = torch.cat([d & 0x3F, (m_d & 0x0F) | ((d >> 2) & 0x30)], dim=-1) + min = torch.cat([m & 0x3F, (m_d >> 4) | ((m >> 2) & 0x30)], dim=-1) + + return (sc.reshape((n_blocks, 8)), min.reshape((n_blocks, 8))) + + +def dequantize_blocks_Q8_0(blocks, block_size, type_size, dtype=None): + d, x = split_block_dims(blocks, 2) + d = d.view(torch.float16).to(dtype) + x = x.view(torch.int8) + return d * x + + +def dequantize_blocks_Q5_1(blocks, block_size, type_size, dtype=None): + n_blocks = blocks.shape[0] + + d, m, qh, qs = split_block_dims(blocks, 2, 2, 4) + d = d.view(torch.float16).to(dtype) + m = m.view(torch.float16).to(dtype) + qh = to_uint32(qh) + + qh = qh.reshape((n_blocks, 1)) >> torch.arange(32, device=d.device, dtype=torch.int32).reshape(1, 32) + ql = qs.reshape((n_blocks, -1, 1, block_size // 2)) >> torch.tensor( + [0, 4], device=d.device, dtype=torch.uint8 + ).reshape(1, 1, 2, 1) + qh = (qh & 1).to(torch.uint8) + ql = (ql & 0x0F).reshape((n_blocks, -1)) + + qs = ql | (qh << 4) + return (d * qs) + m + + +def dequantize_blocks_Q5_0(blocks, block_size, type_size, dtype=None): + n_blocks = blocks.shape[0] + + d, qh, qs = split_block_dims(blocks, 2, 4) + d = d.view(torch.float16).to(dtype) + qh = to_uint32(qh) + + qh = qh.reshape(n_blocks, 1) >> torch.arange(32, device=d.device, dtype=torch.int32).reshape(1, 32) + ql = qs.reshape(n_blocks, -1, 1, block_size // 2) >> torch.tensor( + [0, 4], device=d.device, dtype=torch.uint8 + ).reshape(1, 1, 2, 1) + + qh = (qh & 1).to(torch.uint8) + ql = (ql & 0x0F).reshape(n_blocks, -1) + + qs = (ql | (qh << 4)).to(torch.int8) - 16 + return d * qs + + +def dequantize_blocks_Q4_1(blocks, block_size, type_size, dtype=None): + n_blocks = blocks.shape[0] + + d, m, qs = split_block_dims(blocks, 2, 2) + d = d.view(torch.float16).to(dtype) + m = m.view(torch.float16).to(dtype) + + qs = qs.reshape((n_blocks, -1, 1, block_size // 2)) >> torch.tensor( + [0, 4], device=d.device, dtype=torch.uint8 + ).reshape(1, 1, 2, 1) + qs = (qs & 0x0F).reshape(n_blocks, -1) + + return (d * qs) + m + + +def dequantize_blocks_Q4_0(blocks, block_size, type_size, dtype=None): + n_blocks = blocks.shape[0] + + d, qs = split_block_dims(blocks, 2) + d = d.view(torch.float16).to(dtype) + + qs = qs.reshape((n_blocks, -1, 1, block_size // 2)) >> torch.tensor( + [0, 4], device=d.device, dtype=torch.uint8 + ).reshape((1, 1, 2, 1)) + qs = (qs & 0x0F).reshape((n_blocks, -1)).to(torch.int8) - 8 + return d * qs + + +def dequantize_blocks_Q6_K(blocks, block_size, type_size, dtype=None): + n_blocks = blocks.shape[0] + + ( + ql, + qh, + scales, + d, + ) = split_block_dims(blocks, QK_K // 2, QK_K // 4, QK_K // 16) + + scales = scales.view(torch.int8).to(dtype) + d = d.view(torch.float16).to(dtype) + d = (d * scales).reshape((n_blocks, QK_K // 16, 1)) + + ql = ql.reshape((n_blocks, -1, 1, 64)) >> torch.tensor([0, 4], device=d.device, dtype=torch.uint8).reshape( + (1, 1, 2, 1) + ) + ql = (ql & 
0x0F).reshape((n_blocks, -1, 32)) + qh = qh.reshape((n_blocks, -1, 1, 32)) >> torch.tensor([0, 2, 4, 6], device=d.device, dtype=torch.uint8).reshape( + (1, 1, 4, 1) + ) + qh = (qh & 0x03).reshape((n_blocks, -1, 32)) + q = (ql | (qh << 4)).to(torch.int8) - 32 + q = q.reshape((n_blocks, QK_K // 16, -1)) + + return (d * q).reshape((n_blocks, QK_K)) + + +def dequantize_blocks_Q5_K(blocks, block_size, type_size, dtype=None): + n_blocks = blocks.shape[0] + + d, dmin, scales, qh, qs = split_block_dims(blocks, 2, 2, K_SCALE_SIZE, QK_K // 8) + + d = d.view(torch.float16).to(dtype) + dmin = dmin.view(torch.float16).to(dtype) + + sc, m = get_scale_min(scales) + + d = (d * sc).reshape((n_blocks, -1, 1)) + dm = (dmin * m).reshape((n_blocks, -1, 1)) + + ql = qs.reshape((n_blocks, -1, 1, 32)) >> torch.tensor([0, 4], device=d.device, dtype=torch.uint8).reshape( + (1, 1, 2, 1) + ) + qh = qh.reshape((n_blocks, -1, 1, 32)) >> torch.arange(0, 8, device=d.device, dtype=torch.uint8).reshape( + (1, 1, 8, 1) + ) + ql = (ql & 0x0F).reshape((n_blocks, -1, 32)) + qh = (qh & 0x01).reshape((n_blocks, -1, 32)) + q = ql | (qh << 4) + + return (d * q - dm).reshape((n_blocks, QK_K)) + + +def dequantize_blocks_Q4_K(blocks, block_size, type_size, dtype=None): + n_blocks = blocks.shape[0] + + d, dmin, scales, qs = split_block_dims(blocks, 2, 2, K_SCALE_SIZE) + d = d.view(torch.float16).to(dtype) + dmin = dmin.view(torch.float16).to(dtype) + + sc, m = get_scale_min(scales) + + d = (d * sc).reshape((n_blocks, -1, 1)) + dm = (dmin * m).reshape((n_blocks, -1, 1)) + + qs = qs.reshape((n_blocks, -1, 1, 32)) >> torch.tensor([0, 4], device=d.device, dtype=torch.uint8).reshape( + (1, 1, 2, 1) + ) + qs = (qs & 0x0F).reshape((n_blocks, -1, 32)) + + return (d * qs - dm).reshape((n_blocks, QK_K)) + + +def dequantize_blocks_Q3_K(blocks, block_size, type_size, dtype=None): + n_blocks = blocks.shape[0] + + hmask, qs, scales, d = split_block_dims(blocks, QK_K // 8, QK_K // 4, 12) + d = d.view(torch.float16).to(dtype) + + lscales, hscales = scales[:, :8], scales[:, 8:] + lscales = lscales.reshape((n_blocks, 1, 8)) >> torch.tensor([0, 4], device=d.device, dtype=torch.uint8).reshape( + (1, 2, 1) + ) + lscales = lscales.reshape((n_blocks, 16)) + hscales = hscales.reshape((n_blocks, 1, 4)) >> torch.tensor( + [0, 2, 4, 6], device=d.device, dtype=torch.uint8 + ).reshape((1, 4, 1)) + hscales = hscales.reshape((n_blocks, 16)) + scales = (lscales & 0x0F) | ((hscales & 0x03) << 4) + scales = scales.to(torch.int8) - 32 + + dl = (d * scales).reshape((n_blocks, 16, 1)) + + ql = qs.reshape((n_blocks, -1, 1, 32)) >> torch.tensor([0, 2, 4, 6], device=d.device, dtype=torch.uint8).reshape( + (1, 1, 4, 1) + ) + qh = hmask.reshape(n_blocks, -1, 1, 32) >> torch.arange(0, 8, device=d.device, dtype=torch.uint8).reshape( + (1, 1, 8, 1) + ) + ql = ql.reshape((n_blocks, 16, QK_K // 16)) & 3 + qh = (qh.reshape((n_blocks, 16, QK_K // 16)) & 1) ^ 1 + q = ql.to(torch.int8) - (qh << 2).to(torch.int8) + + return (dl * q).reshape((n_blocks, QK_K)) + + +def dequantize_blocks_Q2_K(blocks, block_size, type_size, dtype=None): + n_blocks = blocks.shape[0] + + scales, qs, d, dmin = split_block_dims(blocks, QK_K // 16, QK_K // 4, 2) + d = d.view(torch.float16).to(dtype) + dmin = dmin.view(torch.float16).to(dtype) + + # (n_blocks, 16, 1) + dl = (d * (scales & 0xF)).reshape((n_blocks, QK_K // 16, 1)) + ml = (dmin * (scales >> 4)).reshape((n_blocks, QK_K // 16, 1)) + + shift = torch.tensor([0, 2, 4, 6], device=d.device, dtype=torch.uint8).reshape((1, 1, 4, 1)) + + qs = 
(qs.reshape((n_blocks, -1, 1, 32)) >> shift) & 3
+    qs = qs.reshape((n_blocks, QK_K // 16, 16))
+    qs = dl * qs - ml
+
+    return qs.reshape((n_blocks, -1))
+
+
+def dequantize_blocks_BF16(blocks, block_size, type_size, dtype=None):
+    return (blocks.view(torch.int16).to(torch.int32) << 16).view(torch.float32)
+
+
+GGML_QUANT_SIZES = gguf.GGML_QUANT_SIZES
+dequantize_functions = {
+    gguf.GGMLQuantizationType.BF16: dequantize_blocks_BF16,
+    gguf.GGMLQuantizationType.Q8_0: dequantize_blocks_Q8_0,
+    gguf.GGMLQuantizationType.Q5_1: dequantize_blocks_Q5_1,
+    gguf.GGMLQuantizationType.Q5_0: dequantize_blocks_Q5_0,
+    gguf.GGMLQuantizationType.Q4_1: dequantize_blocks_Q4_1,
+    gguf.GGMLQuantizationType.Q4_0: dequantize_blocks_Q4_0,
+    gguf.GGMLQuantizationType.Q6_K: dequantize_blocks_Q6_K,
+    gguf.GGMLQuantizationType.Q5_K: dequantize_blocks_Q5_K,
+    gguf.GGMLQuantizationType.Q4_K: dequantize_blocks_Q4_K,
+    gguf.GGMLQuantizationType.Q3_K: dequantize_blocks_Q3_K,
+    gguf.GGMLQuantizationType.Q2_K: dequantize_blocks_Q2_K,
+}
+SUPPORTED_GGUF_QUANT_TYPES = list(dequantize_functions.keys())
+
+
+def _quant_shape_from_byte_shape(shape, type_size, block_size):
+    return (*shape[:-1], shape[-1] // type_size * block_size)
+
+
+def dequantize_gguf_tensor(tensor):
+    if not hasattr(tensor, "quant_type"):
+        return tensor
+
+    quant_type = tensor.quant_type
+    dequant_fn = dequantize_functions[quant_type]
+
+    block_size, type_size = GGML_QUANT_SIZES[quant_type]
+
+    tensor = tensor.view(torch.uint8)
+    shape = _quant_shape_from_byte_shape(tensor.shape, type_size, block_size)
+
+    n_blocks = tensor.numel() // type_size
+    blocks = tensor.reshape((n_blocks, type_size))
+
+    dequant = dequant_fn(blocks, block_size, type_size)
+    dequant = dequant.reshape(shape)
+
+    return dequant.as_tensor()
+
+
+class GGUFParameter(torch.nn.Parameter):
+    def __new__(cls, data, requires_grad=False, quant_type=None):
+        data = data if data is not None else torch.empty(0)
+        self = torch.Tensor._make_subclass(cls, data, requires_grad)
+        self.quant_type = quant_type
+
+        return self
+
+    def as_tensor(self):
+        return torch.Tensor._make_subclass(torch.Tensor, self, self.requires_grad)
+
+    @classmethod
+    def __torch_function__(cls, func, types, args=(), kwargs=None):
+        if kwargs is None:
+            kwargs = {}
+
+        result = super().__torch_function__(func, types, args, kwargs)
+
+        # When converting from original format checkpoints we often use splits, cats etc on tensors
+        # this method ensures that the returned tensor type from those operations remains GGUFParameter
+        # so that we preserve quant_type information
+        quant_type = None
+        for arg in args:
+            if isinstance(arg, list) and isinstance(arg[0], GGUFParameter):
+                quant_type = arg[0].quant_type
+                break
+            if isinstance(arg, GGUFParameter):
+                quant_type = arg.quant_type
+                break
+        if isinstance(result, torch.Tensor):
+            return cls(result, quant_type=quant_type)
+        # Handle tuples and lists
+        elif isinstance(result, (tuple, list)):
+            # Preserve the original type (tuple or list)
+            wrapped = [cls(x, quant_type=quant_type) if isinstance(x, torch.Tensor) else x for x in result]
+            return type(result)(wrapped)
+        else:
+            return result
+
+
+class GGUFLinear(nn.Linear):
+    def __init__(
+        self,
+        in_features,
+        out_features,
+        bias=False,
+        compute_dtype=None,
+        device=None,
+    ) -> None:
+        super().__init__(in_features, out_features, bias, device)
+        self.compute_dtype = compute_dtype
+
+    def forward(self, inputs):
+        weight = dequantize_gguf_tensor(self.weight)
+        weight = weight.to(self.compute_dtype)
+        bias = self.bias.to(self.compute_dtype) if self.bias is not None else None
+ + output = torch.nn.functional.linear(inputs, weight, bias) + return output diff --git a/src/diffusers/quantizers/quantization_config.py b/src/diffusers/quantizers/quantization_config.py index 4aeb75ab704c..3078be310719 100644 --- a/src/diffusers/quantizers/quantization_config.py +++ b/src/diffusers/quantizers/quantization_config.py @@ -43,6 +43,7 @@ class QuantizationMethod(str, Enum): BITS_AND_BYTES = "bitsandbytes" + GGUF = "gguf" TORCHAO = "torchao" @@ -394,6 +395,29 @@ def to_diff_dict(self) -> Dict[str, Any]: return serializable_config_dict +@dataclass +class GGUFQuantizationConfig(QuantizationConfigMixin): + """This is a config class for GGUF Quantization techniques. + + Args: + compute_dtype: (`torch.dtype`, defaults to `torch.float32`): + This sets the computational type which might be different than the input type. For example, inputs might be + fp32, but computation can be set to bf16 for speedups. + + """ + + def __init__(self, compute_dtype: Optional["torch.dtype"] = None): + self.quant_method = QuantizationMethod.GGUF + self.compute_dtype = compute_dtype + self.pre_quantized = True + + # TODO: (Dhruv) Add this as an init argument when we can support loading unquantized checkpoints. + self.modules_to_not_convert = None + + if self.compute_dtype is None: + self.compute_dtype = torch.float32 + + @dataclass class TorchAoConfig(QuantizationConfigMixin): """This is a config class for torchao quantization/sparsity techniques. diff --git a/src/diffusers/utils/__init__.py b/src/diffusers/utils/__init__.py index 9860ac849834..f8de48ecfc78 100644 --- a/src/diffusers/utils/__init__.py +++ b/src/diffusers/utils/__init__.py @@ -23,6 +23,7 @@ DEPRECATED_REVISION_ARGS, DIFFUSERS_DYNAMIC_MODULE_NAME, FLAX_WEIGHTS_NAME, + GGUF_FILE_EXTENSION, HF_MODULES_CACHE, HUGGINGFACE_CO_RESOLVE_ENDPOINT, MIN_PEFT_VERSION, @@ -66,6 +67,8 @@ is_bs4_available, is_flax_available, is_ftfy_available, + is_gguf_available, + is_gguf_version, is_google_colab, is_inflect_available, is_invisible_watermark_available, diff --git a/src/diffusers/utils/constants.py b/src/diffusers/utils/constants.py index 553ac5d1bb27..93b0cd847d91 100644 --- a/src/diffusers/utils/constants.py +++ b/src/diffusers/utils/constants.py @@ -34,6 +34,7 @@ SAFETENSORS_WEIGHTS_NAME = "diffusion_pytorch_model.safetensors" SAFE_WEIGHTS_INDEX_NAME = "diffusion_pytorch_model.safetensors.index.json" SAFETENSORS_FILE_EXTENSION = "safetensors" +GGUF_FILE_EXTENSION = "gguf" ONNX_EXTERNAL_WEIGHTS_NAME = "weights.pb" HUGGINGFACE_CO_RESOLVE_ENDPOINT = os.environ.get("HF_ENDPOINT", "https://huggingface.co") DIFFUSERS_DYNAMIC_MODULE_NAME = "diffusers_modules" diff --git a/src/diffusers/utils/import_utils.py b/src/diffusers/utils/import_utils.py index f325f36bddd3..3014efebc82e 100644 --- a/src/diffusers/utils/import_utils.py +++ b/src/diffusers/utils/import_utils.py @@ -339,6 +339,14 @@ def is_timm_available(): except importlib_metadata.PackageNotFoundError: _imageio_available = False +_is_gguf_available = importlib.util.find_spec("gguf") is not None +if _is_gguf_available: + try: + _gguf_version = importlib_metadata.version("gguf") + logger.debug(f"Successfully import gguf version {_gguf_version}") + except importlib_metadata.PackageNotFoundError: + _is_gguf_available = False + _is_torchao_available = importlib.util.find_spec("torchao") is not None if _is_torchao_available: @@ -469,6 +477,10 @@ def is_imageio_available(): return _imageio_available +def is_gguf_available(): + return _is_gguf_available + + def is_torchao_available(): return 
_is_torchao_available @@ -607,8 +619,13 @@ def is_torchao_available(): """ # docstyle-ignore +GGUF_IMPORT_ERROR = """ +{0} requires the gguf library but it was not found in your environment. You can install it with pip: `pip install gguf` +""" + TORCHAO_IMPORT_ERROR = """ -{0} requires the torchao library but it was not found in your environment. You can install it with pip: `pip install torchao` +{0} requires the torchao library but it was not found in your environment. You can install it with pip: `pip install +torchao` """ BACKENDS_MAPPING = OrderedDict( @@ -636,6 +653,7 @@ def is_torchao_available(): ("bitsandbytes", (is_bitsandbytes_available, BITSANDBYTES_IMPORT_ERROR)), ("sentencepiece", (is_sentencepiece_available, SENTENCEPIECE_IMPORT_ERROR)), ("imageio", (is_imageio_available, IMAGEIO_IMPORT_ERROR)), + ("gguf", (is_gguf_available, GGUF_IMPORT_ERROR)), ("torchao", (is_torchao_available, TORCHAO_IMPORT_ERROR)), ] ) @@ -793,6 +811,21 @@ def is_bitsandbytes_version(operation: str, version: str): return compare_versions(parse(_bitsandbytes_version), operation, version) +def is_gguf_version(operation: str, version: str): + """ + Compares the current Accelerate version to a given reference with an operation. + + Args: + operation (`str`): + A string representation of an operator, such as `">"` or `"<="` + version (`str`): + A version string + """ + if not _is_gguf_available: + return False + return compare_versions(parse(_gguf_version), operation, version) + + def is_k_diffusion_version(operation: str, version: str): """ Compares the current k-diffusion version to a given reference with an operation. diff --git a/src/diffusers/utils/testing_utils.py b/src/diffusers/utils/testing_utils.py index b4d3415de50e..3448b4d28d1f 100644 --- a/src/diffusers/utils/testing_utils.py +++ b/src/diffusers/utils/testing_utils.py @@ -32,6 +32,7 @@ is_bitsandbytes_available, is_compel_available, is_flax_available, + is_gguf_available, is_note_seq_available, is_onnx_available, is_opencv_available, @@ -477,6 +478,18 @@ def decorator(test_case): return decorator +def require_gguf_version_greater_or_equal(gguf_version): + def decorator(test_case): + correct_gguf_version = is_gguf_available() and version.parse( + version.parse(importlib.metadata.version("gguf")).base_version + ) >= version.parse(gguf_version) + return unittest.skipUnless( + correct_gguf_version, f"Test requires gguf with the version greater than {gguf_version}." 
+ )(test_case) + + return decorator + + def require_torchao_version_greater(torchao_version): def decorator(test_case): correct_torchao_version = is_torchao_available() and version.parse( diff --git a/tests/quantization/gguf/test_gguf.py b/tests/quantization/gguf/test_gguf.py new file mode 100644 index 000000000000..8ac4c9915c27 --- /dev/null +++ b/tests/quantization/gguf/test_gguf.py @@ -0,0 +1,379 @@ +import gc +import unittest + +import numpy as np +import torch +import torch.nn as nn + +from diffusers import ( + FluxPipeline, + FluxTransformer2DModel, + GGUFQuantizationConfig, + SD3Transformer2DModel, + StableDiffusion3Pipeline, +) +from diffusers.utils.testing_utils import ( + is_gguf_available, + nightly, + numpy_cosine_similarity_distance, + require_accelerate, + require_big_gpu_with_torch_cuda, + require_gguf_version_greater_or_equal, + torch_device, +) + + +if is_gguf_available(): + from diffusers.quantizers.gguf.utils import GGUFLinear, GGUFParameter + + +@nightly +@require_big_gpu_with_torch_cuda +@require_accelerate +@require_gguf_version_greater_or_equal("0.10.0") +class GGUFSingleFileTesterMixin: + ckpt_path = None + model_cls = None + torch_dtype = torch.bfloat16 + expected_memory_use_in_gb = 5 + + def test_gguf_parameters(self): + quant_storage_type = torch.uint8 + quantization_config = GGUFQuantizationConfig(compute_dtype=self.torch_dtype) + model = self.model_cls.from_single_file(self.ckpt_path, quantization_config=quantization_config) + + for param_name, param in model.named_parameters(): + if isinstance(param, GGUFParameter): + assert hasattr(param, "quant_type") + assert param.dtype == quant_storage_type + + def test_gguf_linear_layers(self): + quantization_config = GGUFQuantizationConfig(compute_dtype=self.torch_dtype) + model = self.model_cls.from_single_file(self.ckpt_path, quantization_config=quantization_config) + + for name, module in model.named_modules(): + if isinstance(module, torch.nn.Linear) and hasattr(module.weight, "quant_type"): + assert module.weight.dtype == torch.uint8 + assert module.bias.dtype == torch.float32 + + def test_gguf_memory_usage(self): + quantization_config = GGUFQuantizationConfig(compute_dtype=self.torch_dtype) + + model = self.model_cls.from_single_file( + self.ckpt_path, quantization_config=quantization_config, torch_dtype=self.torch_dtype + ) + model.to("cuda") + assert (model.get_memory_footprint() / 1024**3) < self.expected_memory_use_in_gb + inputs = self.get_dummy_inputs() + + torch.cuda.reset_peak_memory_stats() + torch.cuda.empty_cache() + with torch.no_grad(): + model(**inputs) + max_memory = torch.cuda.max_memory_allocated() + assert (max_memory / 1024**3) < self.expected_memory_use_in_gb + + def test_keep_modules_in_fp32(self): + r""" + A simple tests to check if the modules under `_keep_in_fp32_modules` are kept in fp32. + Also ensures if inference works. 
+ """ + _keep_in_fp32_modules = self.model_cls._keep_in_fp32_modules + self.model_cls._keep_in_fp32_modules = ["proj_out"] + + quantization_config = GGUFQuantizationConfig(compute_dtype=self.torch_dtype) + model = self.model_cls.from_single_file(self.ckpt_path, quantization_config=quantization_config) + + for name, module in model.named_modules(): + if isinstance(module, torch.nn.Linear): + if name in model._keep_in_fp32_modules: + assert module.weight.dtype == torch.float32 + self.model_cls._keep_in_fp32_modules = _keep_in_fp32_modules + + def test_dtype_assignment(self): + quantization_config = GGUFQuantizationConfig(compute_dtype=self.torch_dtype) + model = self.model_cls.from_single_file(self.ckpt_path, quantization_config=quantization_config) + + with self.assertRaises(ValueError): + # Tries with a `dtype` + model.to(torch.float16) + + with self.assertRaises(ValueError): + # Tries with a `device` and `dtype` + model.to(device="cuda:0", dtype=torch.float16) + + with self.assertRaises(ValueError): + # Tries with a cast + model.float() + + with self.assertRaises(ValueError): + # Tries with a cast + model.half() + + # This should work + model.to("cuda") + + def test_dequantize_model(self): + quantization_config = GGUFQuantizationConfig(compute_dtype=self.torch_dtype) + model = self.model_cls.from_single_file(self.ckpt_path, quantization_config=quantization_config) + model.dequantize() + + def _check_for_gguf_linear(model): + has_children = list(model.children()) + if not has_children: + return + + for name, module in model.named_children(): + if isinstance(module, nn.Linear): + assert not isinstance(module, GGUFLinear), f"{name} is still GGUFLinear" + assert not isinstance(module.weight, GGUFParameter), f"{name} weight is still GGUFParameter" + + for name, module in model.named_children(): + _check_for_gguf_linear(module) + + +class FluxGGUFSingleFileTests(GGUFSingleFileTesterMixin, unittest.TestCase): + ckpt_path = "https://huggingface.co/city96/FLUX.1-dev-gguf/blob/main/flux1-dev-Q2_K.gguf" + torch_dtype = torch.bfloat16 + model_cls = FluxTransformer2DModel + expected_memory_use_in_gb = 5 + + def setUp(self): + gc.collect() + torch.cuda.empty_cache() + + def tearDown(self): + gc.collect() + torch.cuda.empty_cache() + + def get_dummy_inputs(self): + return { + "hidden_states": torch.randn((1, 4096, 64), generator=torch.Generator("cpu").manual_seed(0)).to( + torch_device, self.torch_dtype + ), + "encoder_hidden_states": torch.randn( + (1, 512, 4096), + generator=torch.Generator("cpu").manual_seed(0), + ).to(torch_device, self.torch_dtype), + "pooled_projections": torch.randn( + (1, 768), + generator=torch.Generator("cpu").manual_seed(0), + ).to(torch_device, self.torch_dtype), + "timestep": torch.tensor([1]).to(torch_device, self.torch_dtype), + "img_ids": torch.randn((4096, 3), generator=torch.Generator("cpu").manual_seed(0)).to( + torch_device, self.torch_dtype + ), + "txt_ids": torch.randn((512, 3), generator=torch.Generator("cpu").manual_seed(0)).to( + torch_device, self.torch_dtype + ), + "guidance": torch.tensor([3.5]).to(torch_device, self.torch_dtype), + } + + def test_pipeline_inference(self): + quantization_config = GGUFQuantizationConfig(compute_dtype=self.torch_dtype) + transformer = self.model_cls.from_single_file( + self.ckpt_path, quantization_config=quantization_config, torch_dtype=self.torch_dtype + ) + pipe = FluxPipeline.from_pretrained( + "black-forest-labs/FLUX.1-dev", transformer=transformer, torch_dtype=self.torch_dtype + ) + pipe.enable_model_cpu_offload() + + 
prompt = "a cat holding a sign that says hello" + output = pipe( + prompt=prompt, num_inference_steps=2, generator=torch.Generator("cpu").manual_seed(0), output_type="np" + ).images[0] + output_slice = output[:3, :3, :].flatten() + expected_slice = np.array( + [ + 0.47265625, + 0.43359375, + 0.359375, + 0.47070312, + 0.421875, + 0.34375, + 0.46875, + 0.421875, + 0.34765625, + 0.46484375, + 0.421875, + 0.34179688, + 0.47070312, + 0.42578125, + 0.34570312, + 0.46875, + 0.42578125, + 0.3515625, + 0.45507812, + 0.4140625, + 0.33984375, + 0.4609375, + 0.41796875, + 0.34375, + 0.45898438, + 0.41796875, + 0.34375, + ] + ) + max_diff = numpy_cosine_similarity_distance(expected_slice, output_slice) + assert max_diff < 1e-4 + + +class SD35LargeGGUFSingleFileTests(GGUFSingleFileTesterMixin, unittest.TestCase): + ckpt_path = "https://huggingface.co/city96/stable-diffusion-3.5-large-gguf/blob/main/sd3.5_large-Q4_0.gguf" + torch_dtype = torch.bfloat16 + model_cls = SD3Transformer2DModel + expected_memory_use_in_gb = 5 + + def setUp(self): + gc.collect() + torch.cuda.empty_cache() + + def tearDown(self): + gc.collect() + torch.cuda.empty_cache() + + def get_dummy_inputs(self): + return { + "hidden_states": torch.randn((1, 16, 64, 64), generator=torch.Generator("cpu").manual_seed(0)).to( + torch_device, self.torch_dtype + ), + "encoder_hidden_states": torch.randn( + (1, 512, 4096), + generator=torch.Generator("cpu").manual_seed(0), + ).to(torch_device, self.torch_dtype), + "pooled_projections": torch.randn( + (1, 2048), + generator=torch.Generator("cpu").manual_seed(0), + ).to(torch_device, self.torch_dtype), + "timestep": torch.tensor([1]).to(torch_device, self.torch_dtype), + } + + def test_pipeline_inference(self): + quantization_config = GGUFQuantizationConfig(compute_dtype=self.torch_dtype) + transformer = self.model_cls.from_single_file( + self.ckpt_path, quantization_config=quantization_config, torch_dtype=self.torch_dtype + ) + pipe = StableDiffusion3Pipeline.from_pretrained( + "stabilityai/stable-diffusion-3.5-large", transformer=transformer, torch_dtype=self.torch_dtype + ) + pipe.enable_model_cpu_offload() + + prompt = "a cat holding a sign that says hello" + output = pipe( + prompt=prompt, num_inference_steps=2, generator=torch.Generator("cpu").manual_seed(0), output_type="np" + ).images[0] + output_slice = output[:3, :3, :].flatten() + expected_slice = np.array( + [ + 0.17578125, + 0.27539062, + 0.27734375, + 0.11914062, + 0.26953125, + 0.25390625, + 0.109375, + 0.25390625, + 0.25, + 0.15039062, + 0.26171875, + 0.28515625, + 0.13671875, + 0.27734375, + 0.28515625, + 0.12109375, + 0.26757812, + 0.265625, + 0.16210938, + 0.29882812, + 0.28515625, + 0.15625, + 0.30664062, + 0.27734375, + 0.14648438, + 0.29296875, + 0.26953125, + ] + ) + max_diff = numpy_cosine_similarity_distance(expected_slice, output_slice) + assert max_diff < 1e-4 + + +class SD35MediumGGUFSingleFileTests(GGUFSingleFileTesterMixin, unittest.TestCase): + ckpt_path = "https://huggingface.co/city96/stable-diffusion-3.5-medium-gguf/blob/main/sd3.5_medium-Q3_K_M.gguf" + torch_dtype = torch.bfloat16 + model_cls = SD3Transformer2DModel + expected_memory_use_in_gb = 2 + + def setUp(self): + gc.collect() + torch.cuda.empty_cache() + + def tearDown(self): + gc.collect() + torch.cuda.empty_cache() + + def get_dummy_inputs(self): + return { + "hidden_states": torch.randn((1, 16, 64, 64), generator=torch.Generator("cpu").manual_seed(0)).to( + torch_device, self.torch_dtype + ), + "encoder_hidden_states": torch.randn( + (1, 512, 4096), + 
generator=torch.Generator("cpu").manual_seed(0), + ).to(torch_device, self.torch_dtype), + "pooled_projections": torch.randn( + (1, 2048), + generator=torch.Generator("cpu").manual_seed(0), + ).to(torch_device, self.torch_dtype), + "timestep": torch.tensor([1]).to(torch_device, self.torch_dtype), + } + + def test_pipeline_inference(self): + quantization_config = GGUFQuantizationConfig(compute_dtype=self.torch_dtype) + transformer = self.model_cls.from_single_file( + self.ckpt_path, quantization_config=quantization_config, torch_dtype=self.torch_dtype + ) + pipe = StableDiffusion3Pipeline.from_pretrained( + "stabilityai/stable-diffusion-3.5-medium", transformer=transformer, torch_dtype=self.torch_dtype + ) + pipe.enable_model_cpu_offload() + + prompt = "a cat holding a sign that says hello" + output = pipe( + prompt=prompt, num_inference_steps=2, generator=torch.Generator("cpu").manual_seed(0), output_type="np" + ).images[0] + output_slice = output[:3, :3, :].flatten() + expected_slice = np.array( + [ + 0.625, + 0.6171875, + 0.609375, + 0.65625, + 0.65234375, + 0.640625, + 0.6484375, + 0.640625, + 0.625, + 0.6484375, + 0.63671875, + 0.6484375, + 0.66796875, + 0.65625, + 0.65234375, + 0.6640625, + 0.6484375, + 0.6328125, + 0.6640625, + 0.6484375, + 0.640625, + 0.67578125, + 0.66015625, + 0.62109375, + 0.671875, + 0.65625, + 0.62109375, + ] + ) + max_diff = numpy_cosine_similarity_distance(expected_slice, output_slice) + assert max_diff < 1e-4
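As a final note on the shape bookkeeping that `check_quantized_param_shape` and `_quant_shape_from_byte_shape` above rely on, the short sketch below works through the arithmetic for a Q4_0 tensor. The `(4096, 4096)` weight shape is only an illustrative assumption; the block and type sizes come from the `gguf` package, as used in this PR.

```python
import gguf

# Q4_0 packs 32 weights into an 18-byte block (a float16 scale plus 16 bytes of 4-bit values).
block_size, type_size = gguf.GGML_QUANT_SIZES[gguf.GGMLQuantizationType.Q4_0]  # (32, 18)


def quant_shape_from_byte_shape(shape, type_size, block_size):
    # Mirrors _quant_shape_from_byte_shape from src/diffusers/quantizers/gguf/utils.py
    return (*shape[:-1], shape[-1] // type_size * block_size)


# A logical (4096, 4096) Q4_0 weight stores 4096 // 32 * 18 = 2304 bytes per row, so the
# packed byte shape is (4096, 2304) and the shape check recovers the logical shape from it.
stored_byte_shape = (4096, 4096 // block_size * type_size)
print(quant_shape_from_byte_shape(stored_byte_shape, type_size, block_size))  # (4096, 4096)
```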