🤗 PEFT
State-of-the-art Parameter-Efficient Fine-Tuning (PEFT) methods
Fine-tuning large pretrained models is often prohibitively costly due to their scale. Parameter-Efficient Fine-Tuning (PEFT) methods adapt large pretrained models to downstream applications by fine-tuning only a small number of (extra) model parameters instead of all the model's parameters. This significantly decreases the computational and storage costs. Recent state-of-the-art PEFT techniques achieve performance comparable to fully fine-tuned models.

PEFT is integrated with Transformers for easy model training and inference, Diffusers for conveniently managing different adapters, and Accelerate for distributed training and inference for really big models.

> [!TIP]
> Visit the [PEFT](https://huggingface.co/PEFT) organization to read about the PEFT methods implemented in the library and to see notebooks demonstrating how to apply these methods to a variety of downstream tasks. Click the "Watch repos" button on the organization page to be notified of newly implemented methods and notebooks!

Check the PEFT Adapters API Reference section for a list of supported PEFT methods, and read the [Adapters](https://huggingface.co/docs/peft/en/conceptual_guides/adapter), [Soft prompts](https://huggingface.co/docs/peft/en/conceptual_guides/prompting), and [IA3](https://huggingface.co/docs/peft/en/conceptual_guides/ia3) conceptual guides to learn more about how these methods work.

## Quickstart

Install PEFT from pip:

```bash
pip install peft
```

Prepare a model for training with a PEFT method such as LoRA by wrapping the base model and PEFT configuration with `get_peft_model`. For the Qwen/Qwen2.5-3B-Instruct model in the example below, you're only training about 0.12% of the parameters!
```python
import torch
from transformers import AutoModelForCausalLM
from peft import LoraConfig, TaskType, get_peft_model

# Use the current accelerator if this PyTorch version exposes it, else fall back to CUDA
device = torch.accelerator.current_accelerator().type if hasattr(torch, "accelerator") else "cuda"
model_id = "Qwen/Qwen2.5-3B-Instruct"
model = AutoModelForCausalLM.from_pretrained(model_id, device_map=device)
peft_config = LoraConfig(
    r=16,
    lora_alpha=32,
    task_type=TaskType.CAUSAL_LM,
    # target_modules=["q_proj", "v_proj", ...]  # optionally indicate target modules
)
model = get_peft_model(model, peft_config)
model.print_trainable_parameters()
# prints: trainable params: 3,686,400 || all params: 3,089,625,088 || trainable%: 0.1193

# now perform training on your dataset, e.g. using transformers Trainer, then save the model
model.save_pretrained("qwen2.5-3b-lora")
```

To load a PEFT model for inference:

```python
import torch
from transformers import AutoModelForCausalLM, AutoTokenizer
from peft import PeftModel

device = torch.accelerator.current_accelerator().type if hasattr(torch, "accelerator") else "cuda"
model_id = "Qwen/Qwen2.5-3B-Instruct"
tokenizer = AutoTokenizer.from_pretrained(model_id)
model = AutoModelForCausalLM.from_pretrained(model_id, device_map=device)
# Attach the LoRA adapter saved in the training step to the base model
model = PeftModel.from_pretrained(model, "qwen2.5-3b-lora")

inputs = tokenizer("Preheat the oven to 350 degrees and place the cookie dough", return_tensors="pt")
outputs = model.generate(**inputs.to(device), max_new_tokens=50)
print(tokenizer.decode(outputs[0], skip_special_tokens=True))
# prints something like: Preheat the oven to 350 degrees and place the cookie dough in a baking dish [...]
```

## Why you should use PEFT

There are many benefits of using PEFT, but the main one is the huge savings in compute and storage, which makes PEFT applicable to many different use cases.
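To make the savings concrete, the trainable-parameter count printed in the Quickstart can be reproduced by hand. The sketch below is an illustration only: it assumes that LoRA is applied to the `q_proj` and `v_proj` linear layers of every decoder layer, and it assumes the model dimensions listed in the constants (hidden size, layer count, key/value heads, head size), which are not stated in this document and should be checked against the model's configuration. The helper name `lora_params` is hypothetical.

```python
# Assumed dimensions for Qwen/Qwen2.5-3B-Instruct (not given in this document; verify in the model config)
HIDDEN_SIZE = 2048
NUM_LAYERS = 36
NUM_KV_HEADS = 2
HEAD_DIM = 128
RANK = 16  # r in the Quickstart LoraConfig

# "all params" value reported by print_trainable_parameters() in the Quickstart
ALL_PARAMS = 3_089_625_088


def lora_params(in_features: int, out_features: int, rank: int) -> int:
    """Count the weights LoRA adds to one linear layer: A (rank x in) plus B (out x rank)."""
    return rank * (in_features + out_features)


# Assumption: adapters are attached to q_proj and v_proj only
q_proj = lora_params(HIDDEN_SIZE, HIDDEN_SIZE, RANK)
v_proj = lora_params(HIDDEN_SIZE, NUM_KV_HEADS * HEAD_DIM, RANK)
trainable = NUM_LAYERS * (q_proj + v_proj)

print(f"trainable params: {trainable:,} || trainable%: {100 * trainable / ALL_PARAMS:.4f}")
# prints: trainable params: 3,686,400 || trainable%: 0.1193
```

Under these assumptions, the adapter size grows linearly with the rank `r` and with the number of adapted layers, which is why the trainable fraction stays so small.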
### High performance on consumer hardware

Consider the memory requirements for training the following models on the [ought/raft/twitter_complaints](https://huggingface.co/datasets/ought/raft/viewer/twitter_complaints) dataset on an A100 80GB GPU with more than 64GB of CPU RAM.

| Model | Full Finetuning | PEFT-LoRA PyTorch | PEFT-LoRA DeepSpeed with CPU Offloading |
| --------- | ---- | ---- | ---- |
| bigscience/T0_3B (3B params) | 47.14GB GPU / 2.96GB CPU | 14.4GB GPU / 2.96GB CPU | 9.8GB GPU / 17.8GB CPU |
| bigscience/mt0-xxl (12B params) | OOM GPU | 56GB GPU / 3GB CPU | 22GB GPU / 52GB CPU |
| bigscience/bloomz-7b1 (7B params) | OOM GPU | 32GB GPU / 3.8GB CPU | 18.1GB GPU / 35GB CPU |

With LoRA you can finetune a 12B parameter model that would otherwise run out of memory on the 80GB GPU, and comfortably fit and train a 3B parameter model. The 3B parameter model's performance is comparable to that of a fully finetuned model at a fraction of the GPU memory.

| Submission Name | Accuracy |
| --------- | ---- |
| Human baseline (crowdsourced) | 0.897 |
| Flan-T5 | 0.892 |
| lora-t0-3b | 0.863 |

> [!TIP]
> The bigscience/T0_3B model performance isn't optimized in the table above. You can squeeze even more performance out of it by playing around with the input instruction templates, LoRA hyperparameters, and other training related hyperparameters. The final checkpoint size of this model is just 19MB compared to 11GB of the full bigscience/T0_3B model. Learn more about the advantages of finetuning with PEFT in this [blog post](https://www.philschmid.de/fine-tune-flan-t5-peft).

### Quantization

Quantization is another method for reducing the memory requirements of a model by representing the data in a lower precision. It can be combined with PEFT methods to make it even easier to train and load LLMs for inference.
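As a rough illustration of why lower precision helps, the sketch below estimates the memory needed just to hold a model's weights at several precisions. It is back-of-the-envelope arithmetic, not a measurement: the 7B parameter count is an example value, the helper `weight_memory_gb` is hypothetical, and the estimate ignores activations, gradients, optimizer state, and the small overhead of quantization metadata.

```python
# Bytes needed to store one weight at each precision
BYTES_PER_PARAM = {"float32": 4.0, "bfloat16": 2.0, "int8": 1.0, "4-bit": 0.5}


def weight_memory_gb(num_params: float, bytes_per_param: float) -> float:
    """Estimate weight storage in GB (1 GB = 1e9 bytes); excludes all training-time buffers."""
    return num_params * bytes_per_param / 1e9


num_params = 7e9  # example: a 7B parameter model
for name, size in BYTES_PER_PARAM.items():
    print(f"{name:>8}: {weight_memory_gb(num_params, size):.1f} GB")
```

By this estimate, the weights of a 7B parameter model shrink from about 14 GB in 16-bit precision to about 3.5 GB in 4-bit precision, which leaves room on a smaller GPU for the adapter weights and training buffers.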
* Learn how to finetune [meta-llama/Llama-2-7b-hf](https://huggingface.co/meta-llama/Llama-2-7b-hf) with QLoRA and the [TRL](https://huggingface.co/docs/trl/index) library on a 16GB GPU in the [Finetune LLMs on your own consumer hardware using tools from PyTorch and Hugging Face ecosystem](https://pytorch.org/blog/finetune-llms/) blog post.
* Learn how to finetune an [openai/whisper-large-v2](https://huggingface.co/openai/whisper-large-v2) model for multilingual automatic speech recognition with LoRA and 8-bit quantization in this [notebook](https://colab.research.google.com/drive/1DOkD_5OUjFa0r5Ik3SgywJLJtEo2qLxO?usp=sharing) (see this [notebook](https://colab.research.google.com/drive/1vhF8yueFqha3Y3CpTHN6q9EVcII9EYzs?usp=sharing) instead for an example of streaming a dataset).

### Save compute and storage

PEFT can help you save storage by avoiding full finetuning of a model for each downstream task or dataset. In many cases, you're only finetuning a very small fraction of a model's parameters, and each checkpoint is only a few MBs in size (instead of GBs). These smaller PEFT adapters demonstrate performance comparable to a fully finetuned model. If you have many datasets, you can save a lot of storage with a PEFT model and not have to worry about catastrophic forgetting or overfitting the backbone or base model.

## PEFT integrations

PEFT is widely supported across the Hugging Face ecosystem because of the massive efficiency it brings to training and inference.

### Diffusers

The iterative diffusion process consumes a lot of memory, which can make diffusion models difficult to train. PEFT can help reduce the memory requirements and the storage size of the final model checkpoint. For example, consider the memory required for training a Stable Diffusion model with LoRA on an A100 80GB GPU with more than 64GB of CPU RAM. The final model checkpoint size is only 8.8MB!
| Model | Full Finetuning | PEFT-LoRA | PEFT-LoRA with Gradient Checkpointing |
| --------- | ---- | ---- | ---- |
| CompVis/stable-diffusion-v1-4 | 27.5GB GPU / 3.97GB CPU | 15.5GB GPU / 3.84GB CPU | 8.12GB GPU / 3.77GB CPU |

> [!TIP]
> Take a look at the [examples/lora_dreambooth/train_dreambooth.py](examples/lora_dreambooth/train_dreambooth.py) training script to try training your own Stable Diffusion model with LoRA, and play around with the [smangrul/peft-lora-sd-dreambooth](https://huggingface.co/spaces/smangrul/peft-lora-sd-dreambooth) Space which is running on a T4 instance. Learn more about the PEFT integration in Diffusers in this [tutorial](https://huggingface.co/docs/peft/main/en/tutorial/peft_integrations#diffusers).

### Transformers

PEFT is directly integrated with [Transformers](https://huggingface.co/docs/transformers/main/en/peft). After loading a model, call `add_adapter` to add a new PEFT adapter to the model:

```python
from peft import LoraConfig

model = ...  # transformers model
peft_config = LoraConfig(...)
# Pass the same config object that was just created
model.add_adapter(peft_config, adapter_name="lora_1")
```

To load a trained PEFT adapter, call `load_adapter`:

```python
model = ...  # transformers model
model.load_adapter(