# Copyright 2024 The HuggingFace Team. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

import os
from typing import Callable, Dict, List, Optional, Union

import torch
from huggingface_hub.utils import validate_hf_hub_args

from diffusers.utils import (
    USE_PEFT_BACKEND,
    is_peft_available,
    is_peft_version,
    is_torch_version,
    is_transformers_available,
    is_transformers_version,
    logging,
)
from diffusers.loaders.lora_base import (  # noqa
    LoraBaseMixin,
    _fetch_state_dict,
)
from diffusers.loaders.lora_conversion_utils import (
    _convert_non_diffusers_lumina2_lora_to_diffusers,
)
_LOW_CPU_MEM_USAGE_DEFAULT_LORA = False
if is_torch_version(">=", "1.9.0"):
    if (
        is_peft_available()
        and is_peft_version(">=", "0.13.1")
        and is_transformers_available()
        and is_transformers_version(">", "4.45.2")
    ):
        _LOW_CPU_MEM_USAGE_DEFAULT_LORA = True


logger = logging.get_logger(__name__)

TRANSFORMER_NAME = "transformer"


class OmniGen2LoraLoaderMixin(LoraBaseMixin):
    r"""
    Load LoRA layers into [`OmniGen2Transformer2DModel`]. Specific to [`OmniGen2Pipeline`].
    """

    _lora_loadable_modules = ["transformer"]
    transformer_name = TRANSFORMER_NAME
    @classmethod
    @validate_hf_hub_args
    def lora_state_dict(
        cls,
        pretrained_model_name_or_path_or_dict: Union[str, Dict[str, torch.Tensor]],
        **kwargs,
    ):
| r""" | |
| Return state dict for lora weights and the network alphas. | |
| <Tip warning={true}> | |
| We support loading A1111 formatted LoRA checkpoints in a limited capacity. | |
| This function is experimental and might change in the future. | |
| </Tip> | |
| Parameters: | |
| pretrained_model_name_or_path_or_dict (`str` or `os.PathLike` or `dict`): | |
| Can be either: | |
| - A string, the *model id* (for example `google/ddpm-celebahq-256`) of a pretrained model hosted on | |
| the Hub. | |
| - A path to a *directory* (for example `./my_model_directory`) containing the model weights saved | |
| with [`ModelMixin.save_pretrained`]. | |
| - A [torch state | |
| dict](https://pytorch.org/tutorials/beginner/saving_loading_models.html#what-is-a-state-dict). | |
| cache_dir (`Union[str, os.PathLike]`, *optional*): | |
| Path to a directory where a downloaded pretrained model configuration is cached if the standard cache | |
| is not used. | |
| force_download (`bool`, *optional*, defaults to `False`): | |
| Whether or not to force the (re-)download of the model weights and configuration files, overriding the | |
| cached versions if they exist. | |
| proxies (`Dict[str, str]`, *optional*): | |
| A dictionary of proxy servers to use by protocol or endpoint, for example, `{'http': 'foo.bar:3128', | |
| 'http://hostname': 'foo.bar:4012'}`. The proxies are used on each request. | |
| local_files_only (`bool`, *optional*, defaults to `False`): | |
| Whether to only load local model weights and configuration files or not. If set to `True`, the model | |
| won't be downloaded from the Hub. | |
| token (`str` or *bool*, *optional*): | |
| The token to use as HTTP bearer authorization for remote files. If `True`, the token generated from | |
| `diffusers-cli login` (stored in `~/.huggingface`) is used. | |
| revision (`str`, *optional*, defaults to `"main"`): | |
| The specific model version to use. It can be a branch name, a tag name, a commit id, or any identifier | |
| allowed by Git. | |
| subfolder (`str`, *optional*, defaults to `""`): | |
| The subfolder location of a model file within a larger model repository on the Hub or locally. | |
| """ | |
        # Load the main state dict first which has the LoRA layers for either of
        # transformer and text encoder or both.
        cache_dir = kwargs.pop("cache_dir", None)
        force_download = kwargs.pop("force_download", False)
        proxies = kwargs.pop("proxies", None)
        local_files_only = kwargs.pop("local_files_only", None)
        token = kwargs.pop("token", None)
        revision = kwargs.pop("revision", None)
        subfolder = kwargs.pop("subfolder", None)
        weight_name = kwargs.pop("weight_name", None)
        use_safetensors = kwargs.pop("use_safetensors", None)

        allow_pickle = False
        if use_safetensors is None:
            use_safetensors = True
            allow_pickle = True

        user_agent = {
            "file_type": "attn_procs_weights",
            "framework": "pytorch",
        }

        state_dict = _fetch_state_dict(
            pretrained_model_name_or_path_or_dict=pretrained_model_name_or_path_or_dict,
            weight_name=weight_name,
            use_safetensors=use_safetensors,
            local_files_only=local_files_only,
            cache_dir=cache_dir,
            force_download=force_download,
            proxies=proxies,
            token=token,
            revision=revision,
            subfolder=subfolder,
            user_agent=user_agent,
            allow_pickle=allow_pickle,
        )

        is_dora_scale_present = any("dora_scale" in k for k in state_dict)
        if is_dora_scale_present:
            warn_msg = "It seems like you are using a DoRA checkpoint that is not compatible in Diffusers at the moment. So, we are going to filter out the keys associated to 'dora_scale` from the state dict. If you think this is a mistake please open an issue https://github.com/huggingface/diffusers/issues/new."
            logger.warning(warn_msg)
            state_dict = {k: v for k, v in state_dict.items() if "dora_scale" not in k}

        # conversion.
        non_diffusers = any(k.startswith("diffusion_model.") for k in state_dict)
        if non_diffusers:
            state_dict = _convert_non_diffusers_lumina2_lora_to_diffusers(state_dict)

        return state_dict
    # Copied from diffusers.loaders.lora_pipeline.CogVideoXLoraLoaderMixin.load_lora_weights
    def load_lora_weights(
        self, pretrained_model_name_or_path_or_dict: Union[str, Dict[str, torch.Tensor]], adapter_name=None, **kwargs
    ):
| """ | |
| Load LoRA weights specified in `pretrained_model_name_or_path_or_dict` into `self.transformer` and | |
| `self.text_encoder`. All kwargs are forwarded to `self.lora_state_dict`. See | |
| [`~loaders.StableDiffusionLoraLoaderMixin.lora_state_dict`] for more details on how the state dict is loaded. | |
| See [`~loaders.StableDiffusionLoraLoaderMixin.load_lora_into_transformer`] for more details on how the state | |
| dict is loaded into `self.transformer`. | |
| Parameters: | |
| pretrained_model_name_or_path_or_dict (`str` or `os.PathLike` or `dict`): | |
| See [`~loaders.StableDiffusionLoraLoaderMixin.lora_state_dict`]. | |
| adapter_name (`str`, *optional*): | |
| Adapter name to be used for referencing the loaded adapter model. If not specified, it will use | |
| `default_{i}` where i is the total number of adapters being loaded. | |
| low_cpu_mem_usage (`bool`, *optional*): | |
| Speed up model loading by only loading the pretrained LoRA weights and not initializing the random | |
| weights. | |
| kwargs (`dict`, *optional*): | |
| See [`~loaders.StableDiffusionLoraLoaderMixin.lora_state_dict`]. | |
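
        Example (a minimal usage sketch; the base pipeline and LoRA repository ids below are placeholders, and
        `OmniGen2Pipeline` is assumed to mix in this loader class):

        ```py
        import torch

        pipe = OmniGen2Pipeline.from_pretrained("OmniGen2/OmniGen2", torch_dtype=torch.bfloat16).to("cuda")
        pipe.load_lora_weights(
            "your-username/omnigen2-lora",  # placeholder Hub repo id
            weight_name="pytorch_lora_weights.safetensors",  # placeholder file name
            adapter_name="my_style",
        )
        ```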
| """ | |
        if not USE_PEFT_BACKEND:
            raise ValueError("PEFT backend is required for this method.")

        low_cpu_mem_usage = kwargs.pop("low_cpu_mem_usage", _LOW_CPU_MEM_USAGE_DEFAULT_LORA)
        if low_cpu_mem_usage and is_peft_version("<", "0.13.0"):
            raise ValueError(
                "`low_cpu_mem_usage=True` is not compatible with this `peft` version. Please update it with `pip install -U peft`."
            )

        # if a dict is passed, copy it instead of modifying it inplace
        if isinstance(pretrained_model_name_or_path_or_dict, dict):
            pretrained_model_name_or_path_or_dict = pretrained_model_name_or_path_or_dict.copy()

        # First, ensure that the checkpoint is a compatible one and can be successfully loaded.
        state_dict = self.lora_state_dict(pretrained_model_name_or_path_or_dict, **kwargs)

        is_correct_format = all("lora" in key for key in state_dict.keys())
        if not is_correct_format:
            raise ValueError("Invalid LoRA checkpoint.")

        self.load_lora_into_transformer(
            state_dict,
            transformer=getattr(self, self.transformer_name) if not hasattr(self, "transformer") else self.transformer,
            adapter_name=adapter_name,
            _pipeline=self,
            low_cpu_mem_usage=low_cpu_mem_usage,
        )
    @classmethod
    # Copied from diffusers.loaders.lora_pipeline.SD3LoraLoaderMixin.load_lora_into_transformer with SD3Transformer2DModel->Lumina2Transformer2DModel
    def load_lora_into_transformer(
        cls, state_dict, transformer, adapter_name=None, _pipeline=None, low_cpu_mem_usage=False, hotswap: bool = False
    ):
| """ | |
| This will load the LoRA layers specified in `state_dict` into `transformer`. | |
| Parameters: | |
| state_dict (`dict`): | |
| A standard state dict containing the lora layer parameters. The keys can either be indexed directly | |
| into the unet or prefixed with an additional `unet` which can be used to distinguish between text | |
| encoder lora layers. | |
| transformer (`Lumina2Transformer2DModel`): | |
| The Transformer model to load the LoRA layers into. | |
| adapter_name (`str`, *optional*): | |
| Adapter name to be used for referencing the loaded adapter model. If not specified, it will use | |
| `default_{i}` where i is the total number of adapters being loaded. | |
| low_cpu_mem_usage (`bool`, *optional*): | |
| Speed up model loading by only loading the pretrained LoRA weights and not initializing the random | |
| weights. | |
| hotswap : (`bool`, *optional*) | |
| Defaults to `False`. Whether to substitute an existing (LoRA) adapter with the newly loaded adapter | |
| in-place. This means that, instead of loading an additional adapter, this will take the existing | |
| adapter weights and replace them with the weights of the new adapter. This can be faster and more | |
| memory efficient. However, the main advantage of hotswapping is that when the model is compiled with | |
| torch.compile, loading the new adapter does not require recompilation of the model. When using | |
| hotswapping, the passed `adapter_name` should be the name of an already loaded adapter. | |
| If the new adapter and the old adapter have different ranks and/or LoRA alphas (i.e. scaling), you need | |
| to call an additional method before loading the adapter: | |
| ```py | |
| pipeline = ... # load diffusers pipeline | |
| max_rank = ... # the highest rank among all LoRAs that you want to load | |
| # call *before* compiling and loading the LoRA adapter | |
| pipeline.enable_lora_hotswap(target_rank=max_rank) | |
| pipeline.load_lora_weights(file_name) | |
| # optionally compile the model now | |
| ``` | |
| Note that hotswapping adapters of the text encoder is not yet supported. There are some further | |
| limitations to this technique, which are documented here: | |
| https://huggingface.co/docs/peft/main/en/package_reference/hotswap | |
| """ | |
        if low_cpu_mem_usage and is_peft_version("<", "0.13.0"):
            raise ValueError(
                "`low_cpu_mem_usage=True` is not compatible with this `peft` version. Please update it with `pip install -U peft`."
            )

        # Load the layers corresponding to transformer.
        logger.info(f"Loading {cls.transformer_name}.")
        transformer.load_lora_adapter(
            state_dict,
            network_alphas=None,
            adapter_name=adapter_name,
            _pipeline=_pipeline,
            low_cpu_mem_usage=low_cpu_mem_usage,
            hotswap=hotswap,
        )
    @classmethod
    # Copied from diffusers.loaders.lora_pipeline.CogVideoXLoraLoaderMixin.save_lora_weights
    def save_lora_weights(
        cls,
        save_directory: Union[str, os.PathLike],
        transformer_lora_layers: Dict[str, Union[torch.nn.Module, torch.Tensor]] = None,
        is_main_process: bool = True,
        weight_name: str = None,
        save_function: Callable = None,
        safe_serialization: bool = True,
    ):
| r""" | |
| Save the LoRA parameters corresponding to the UNet and text encoder. | |
| Arguments: | |
| save_directory (`str` or `os.PathLike`): | |
| Directory to save LoRA parameters to. Will be created if it doesn't exist. | |
| transformer_lora_layers (`Dict[str, torch.nn.Module]` or `Dict[str, torch.Tensor]`): | |
| State dict of the LoRA layers corresponding to the `transformer`. | |
| is_main_process (`bool`, *optional*, defaults to `True`): | |
| Whether the process calling this is the main process or not. Useful during distributed training and you | |
| need to call this function on all processes. In this case, set `is_main_process=True` only on the main | |
| process to avoid race conditions. | |
| save_function (`Callable`): | |
| The function to use to save the state dictionary. Useful during distributed training when you need to | |
| replace `torch.save` with another method. Can be configured with the environment variable | |
| `DIFFUSERS_SAVE_MODE`. | |
| safe_serialization (`bool`, *optional*, defaults to `True`): | |
| Whether to save the model using `safetensors` or the traditional PyTorch way with `pickle`. | |
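
        Example (a minimal usage sketch; assumes `transformer` is a PEFT-wrapped model that already carries trained
        LoRA layers, and the output directory is a placeholder):

        ```py
        from peft import get_peft_model_state_dict

        transformer_lora_layers = get_peft_model_state_dict(transformer)
        OmniGen2LoraLoaderMixin.save_lora_weights(
            save_directory="./omnigen2-lora",  # placeholder output directory
            transformer_lora_layers=transformer_lora_layers,
        )
        ```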
| """ | |
        state_dict = {}

        if not transformer_lora_layers:
            raise ValueError("You must pass `transformer_lora_layers`.")

        if transformer_lora_layers:
            state_dict.update(cls.pack_weights(transformer_lora_layers, cls.transformer_name))

        # Save the model
        cls.write_lora_layers(
            state_dict=state_dict,
            save_directory=save_directory,
            is_main_process=is_main_process,
            weight_name=weight_name,
            save_function=save_function,
            safe_serialization=safe_serialization,
        )
    # Copied from diffusers.loaders.lora_pipeline.SanaLoraLoaderMixin.fuse_lora
    def fuse_lora(
        self,
        components: List[str] = ["transformer"],
        lora_scale: float = 1.0,
        safe_fusing: bool = False,
        adapter_names: Optional[List[str]] = None,
        **kwargs,
    ):
| r""" | |
| Fuses the LoRA parameters into the original parameters of the corresponding blocks. | |
| <Tip warning={true}> | |
| This is an experimental API. | |
| </Tip> | |
| Args: | |
| components: (`List[str]`): List of LoRA-injectable components to fuse the LoRAs into. | |
| lora_scale (`float`, defaults to 1.0): | |
| Controls how much to influence the outputs with the LoRA parameters. | |
| safe_fusing (`bool`, defaults to `False`): | |
| Whether to check fused weights for NaN values before fusing and if values are NaN not fusing them. | |
| adapter_names (`List[str]`, *optional*): | |
| Adapter names to be used for fusing. If nothing is passed, all active adapters will be fused. | |
| Example: | |
| ```py | |
| from diffusers import DiffusionPipeline | |
| import torch | |
| pipeline = DiffusionPipeline.from_pretrained( | |
| "stabilityai/stable-diffusion-xl-base-1.0", torch_dtype=torch.float16 | |
| ).to("cuda") | |
| pipeline.load_lora_weights("nerijs/pixel-art-xl", weight_name="pixel-art-xl.safetensors", adapter_name="pixel") | |
| pipeline.fuse_lora(lora_scale=0.7) | |
| ``` | |
| """ | |
        super().fuse_lora(
            components=components,
            lora_scale=lora_scale,
            safe_fusing=safe_fusing,
            adapter_names=adapter_names,
            **kwargs,
        )
    # Copied from diffusers.loaders.lora_pipeline.SanaLoraLoaderMixin.unfuse_lora
    def unfuse_lora(self, components: List[str] = ["transformer"], **kwargs):
| r""" | |
| Reverses the effect of | |
| [`pipe.fuse_lora()`](https://huggingface.co/docs/diffusers/main/en/api/loaders#diffusers.loaders.LoraBaseMixin.fuse_lora). | |
| <Tip warning={true}> | |
| This is an experimental API. | |
| </Tip> | |
| Args: | |
| components (`List[str]`): List of LoRA-injectable components to unfuse LoRA from. | |
| unfuse_transformer (`bool`, defaults to `True`): Whether to unfuse the UNet LoRA parameters. | |
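
        Example (a minimal usage sketch; assumes `pipeline` has LoRA weights loaded and fused via a prior
        `fuse_lora()` call):

        ```py
        pipeline.fuse_lora(lora_scale=0.7)
        # ... run inference with the fused weights ...
        pipeline.unfuse_lora()  # restore the original, unfused transformer weights
        ```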
| """ | |
        super().unfuse_lora(components=components, **kwargs)