# Copyright 2023-present the HuggingFace Inc. team.
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import gc
import importlib
import itertools
import os
import re
import tempfile
import unittest
from collections import Counter, defaultdict
from copy import deepcopy
from dataclasses import dataclass
from pathlib import Path
from typing import Any, Union
import numpy as np
import pytest
import torch
from accelerate import infer_auto_device_map
from accelerate.test_utils.testing import run_command
from accelerate.utils import patch_environment
from accelerate.utils.imports import is_bf16_available
from accelerate.utils.memory import clear_device_cache
from accelerate.utils.versions import is_torch_version
from datasets import Audio, Dataset, DatasetDict, load_dataset
from packaging import version
from parameterized import parameterized
from torch.distributed import init_process_group
from torch.distributed.fsdp import FullyShardedDataParallel as FSDP
from torch.utils.data import DataLoader
from transformers import (
    AutoModelForCausalLM,
    AutoModelForSeq2SeqLM,
    AutoTokenizer,
    BitsAndBytesConfig,
    DataCollatorForLanguageModeling,
    Seq2SeqTrainer,
    Seq2SeqTrainingArguments,
    Trainer,
    TrainerCallback,
    TrainingArguments,
    WhisperFeatureExtractor,
    WhisperForConditionalGeneration,
    WhisperProcessor,
    WhisperTokenizer,
)
from transformers.pytorch_utils import Conv1D
from peft import (
    AdaLoraConfig,
    ArrowConfig,
    EvaConfig,
    LoftQConfig,
    LoraConfig,
    PeftModel,
    PrefixTuningConfig,
    PromptEncoderConfig,
    RandLoraConfig,
    RoadConfig,
    TaskType,
    VeraConfig,
    create_arrow_model,
    get_peft_model,
    get_peft_model_state_dict,
    initialize_lora_eva_weights,
    inject_adapter_in_model,
    prepare_model_for_kbit_training,
    replace_lora_weights_loftq,
    set_peft_model_state_dict,
)
from peft.import_utils import is_diffusers_available, is_xpu_available
from peft.tuners import boft
from peft.tuners.tuners_utils import BaseTunerLayer
from peft.utils import SAFETENSORS_WEIGHTS_NAME, infer_device
from peft.utils.hotswap import hotswap_adapter, prepare_model_for_compiled_hotswap
from peft.utils.loftq_utils import NFQuantizer
from peft.utils.other import fsdp_auto_wrap_policy
from tests.testing_utils import hub_online_once
from .testing_utils import (
    device_count,
    load_dataset_english_quotes,
    require_aqlm,
    require_auto_awq,
    require_auto_gptq,
    require_bitsandbytes,
    require_deterministic_for_xpu,
    require_eetq,
    require_hqq,
    require_non_cpu,
    require_non_xpu,
    require_optimum,
    require_torch_gpu,
    require_torch_multi_accelerator,
    require_torch_multi_gpu,
    require_torchao,
    torch_device,
)
# Explicit device maps for the multi-GPU tests. Using a fixed map (instead of ``device_map="auto"``) guarantees that
# the model is split across two devices (0 and 1). Decoder layers are split into a first half on device 0 and a
# second half on device 1.
DEVICE_MAP_MAP: dict[str, dict[str, int]] = {
    "facebook/opt-6.7b": {
        "model.decoder.embed_tokens": 0,
        "model.decoder.embed_positions": 0,
        "model.decoder.final_layer_norm": 0,
        # 32 decoder layers: 0-15 -> device 0, 16-31 -> device 1
        **{f"model.decoder.layers.{idx}": int(idx >= 16) for idx in range(32)},
        "lm_head": 0,  # tied with embed_tokens
    },
    "facebook/opt-125m": {
        "model.decoder.embed_tokens": 0,
        "model.decoder.embed_positions": 0,
        "model.decoder.final_layer_norm": 1,
        # 12 decoder layers: 0-5 -> device 0, 6-11 -> device 1
        **{f"model.decoder.layers.{idx}": int(idx >= 6) for idx in range(12)},
        "lm_head": 0,
    },
    "marcsun13/opt-350m-gptq-4bit": {
        "model.decoder.embed_tokens": 0,
        "model.decoder.embed_positions": 0,
        # 12 decoder layers: 0-5 -> device 0, 6-11 -> device 1
        **{f"model.decoder.layers.{idx}": int(idx >= 6) for idx in range(12)},
        "model.decoder.final_layer_norm": 1,
        "lm_head": 0,  # tied with embed_tokens
    },
    "google/flan-t5-base": {
        "shared": 0,
        "encoder": 0,
        "decoder": 1,
        "final_layer_norm": 1,
        "decoder.embed_tokens": 0,  # tied with encoder.embed_tokens
        "lm_head": 0,  # tied with encoder.embed_tokens
    },
}
# A full test suite that exercises all the necessary features on GPU. The tests should be
# based on the example scripts and notebooks.
@dataclass
class DataCollatorSpeechSeq2SeqWithPadding:
    r"""
    Batch collator for Whisper-style speech-to-text training.

    Audio input features and label token ids have different lengths and need different padding, so they are padded
    separately. Padded label positions are set to -100 so the loss ignores them, and a leading BOS token shared by all
    label sequences is stripped.

    Copied from:
    https://github.com/huggingface/peft/blob/main/examples/int8_training/peft_bnb_whisper_large_v2_training.ipynb
    """

    # Object exposing ``feature_extractor.pad`` as well as ``tokenizer.pad`` and ``tokenizer.bos_token_id``
    # (e.g. a ``WhisperProcessor``).
    processor: Any

    def __call__(self, features: list[dict[str, Union[list[int], torch.Tensor]]]) -> dict[str, torch.Tensor]:
        # Audio side: pad the input features and return torch tensors.
        audio_inputs = [{"input_features": sample["input_features"]} for sample in features]
        batch = self.processor.feature_extractor.pad(audio_inputs, return_tensors="pt")

        # Text side: pad the tokenized label sequences to a common length.
        tokenizer = self.processor.tokenizer
        label_inputs = [{"input_ids": sample["labels"]} for sample in features]
        padded = tokenizer.pad(label_inputs, return_tensors="pt")

        # Positions that are padding (attention mask != 1) must not contribute to the loss.
        is_padding = padded.attention_mask.ne(1)
        labels = padded["input_ids"].masked_fill(is_padding, -100)

        # If the tokenizer already prepended BOS to every sequence, drop it here because it is added again later.
        first_is_bos = (labels[:, 0] == tokenizer.bos_token_id).all().cpu().item()
        if first_is_bos:
            labels = labels[:, 1:]

        batch["labels"] = labels
        return batch
@require_non_cpu
@require_bitsandbytes
class PeftBnbGPUExampleTests(unittest.TestCase):
    r"""
    A single GPU int8 + fp4 test suite, this will test if training fits correctly on a single GPU device (1x NVIDIA T4
    16GB) using bitsandbytes.
    The tests are the following:
    - Seq2Seq model training based on:
      https://github.com/huggingface/peft/blob/main/examples/int8_training/Finetune_flan_t5_large_bnb_peft.ipynb
    - Causal LM model training based on:
      https://github.com/huggingface/peft/blob/main/examples/int8_training/Finetune_opt_bnb_peft.ipynb
    - Audio model training based on:
      https://github.com/huggingface/peft/blob/main/examples/int8_training/peft_bnb_whisper_large_v2_training.ipynb
    """
    def setUp(self):
        """Set the Hub model ids used by the tests and preload the causal LM tokenizer."""
        causal_id = "facebook/opt-6.7b"
        self.seq2seq_model_id = "google/flan-t5-base"
        self.causal_lm_model_id = causal_id
        # Shared tokenizer for tests that train the causal LM without loading their own tokenizer.
        self.tokenizer = AutoTokenizer.from_pretrained(causal_id)
        self.audio_model_id = "openai/whisper-large"
    def tearDown(self):
        r"""
        Free accelerator memory after every test: runs garbage collection and clears the device cache. Approach taken
        from https://github.com/huggingface/transformers/issues/21094
        """
        clear_device_cache(garbage_collection=True)
    def _check_inference_finite(self, model, batch):
        """Run one forward pass outside of ``Trainer`` and assert that all output logits are finite.

        The model is switched to eval mode for the check. Its previous train/eval mode is restored afterwards, even if
        the forward pass or the assertion fails, so a failing check cannot leak eval mode into later code.

        Args:
            model: model to check; must expose ``.device`` and return an output with a ``.logits`` tensor.
            batch: tokenizer output supporting ``.to(device)`` and ``**`` unpacking (e.g. a ``BatchEncoding``).
        """
        was_training = model.training
        model.eval()
        try:
            # This is only a sanity check, so skip building an autograd graph and keeping its activations in memory.
            with torch.no_grad():
                output = model(**batch.to(model.device))
            assert torch.isfinite(output.logits).all()
        finally:
            model.train(was_training)
    @pytest.mark.single_gpu_tests
    def test_causal_lm_training(self):
        r"""
        Train a LoRA adapter on an 8bit-quantized causal LM on a single GPU device.

        Converted from
        https://github.com/huggingface/peft/blob/main/examples/int8_training/Finetune_opt_bnb_peft.ipynb: `opt-6.7b` is
        trained for a few steps on the `english_quotes` dataset. The test fails if the adapters are not set up
        correctly.
        """
        model_id = self.causal_lm_model_id
        with tempfile.TemporaryDirectory() as tmp_dir:
            model = AutoModelForCausalLM.from_pretrained(
                model_id,
                quantization_config=BitsAndBytesConfig(load_in_8bit=True),
                device_map="auto",
            )
            tokenizer = AutoTokenizer.from_pretrained(model_id)
            model = prepare_model_for_kbit_training(model)

            lora_config = LoraConfig(
                r=16,
                lora_alpha=32,
                target_modules=["q_proj", "v_proj"],
                lora_dropout=0.05,
                bias="none",
                task_type="CAUSAL_LM",
            )
            model = get_peft_model(model, lora_config)

            dataset = load_dataset_english_quotes()
            dataset = dataset.map(lambda rows: tokenizer(rows["quote"]), batched=True)

            training_args = TrainingArguments(
                per_device_train_batch_size=4,
                gradient_accumulation_steps=4,
                warmup_steps=2,
                max_steps=3,
                learning_rate=2e-4,
                fp16=True,
                logging_steps=1,
                output_dir=tmp_dir,
            )
            trainer = Trainer(
                model=model,
                train_dataset=dataset["train"],
                args=training_args,
                data_collator=DataCollatorForLanguageModeling(tokenizer, mlm=False),
            )
            # disable KV caching for training
            model.config.use_cache = False
            trainer.train()

            model.cpu().save_pretrained(tmp_dir)
            saved_files = os.listdir(tmp_dir)
            assert "adapter_config.json" in saved_files
            assert SAFETENSORS_WEIGHTS_NAME in saved_files
            # a training loss must have been recorded
            assert trainer.state.log_history[-1]["train_loss"] is not None
    @pytest.mark.single_gpu_tests
    def test_causal_lm_training_4bit(self):
        r"""
        Train a LoRA adapter on a 4bit-quantized causal LM on a single GPU device.

        Converted from
        https://github.com/huggingface/peft/blob/main/examples/int8_training/Finetune_opt_bnb_peft.ipynb: `opt-6.7b` is
        trained for a few steps on the `english_quotes` dataset using a 4bit base model. The test fails if the adapters
        are not set up correctly.
        """
        model_id = self.causal_lm_model_id
        quant_config = BitsAndBytesConfig(load_in_4bit=True)
        with tempfile.TemporaryDirectory() as tmp_dir:
            model = AutoModelForCausalLM.from_pretrained(model_id, quantization_config=quant_config, device_map="auto")
            tokenizer = AutoTokenizer.from_pretrained(model_id)
            model = prepare_model_for_kbit_training(model)

            lora_config = LoraConfig(
                r=16,
                lora_alpha=32,
                target_modules=["q_proj", "v_proj"],
                lora_dropout=0.05,
                bias="none",
                task_type="CAUSAL_LM",
            )
            model = get_peft_model(model, lora_config)

            dataset = load_dataset_english_quotes()
            dataset = dataset.map(lambda rows: tokenizer(rows["quote"]), batched=True)

            training_args = TrainingArguments(
                per_device_train_batch_size=4,
                gradient_accumulation_steps=4,
                warmup_steps=2,
                max_steps=3,
                learning_rate=2e-4,
                fp16=True,
                logging_steps=1,
                output_dir=tmp_dir,
            )
            trainer = Trainer(
                model=model,
                train_dataset=dataset["train"],
                args=training_args,
                data_collator=DataCollatorForLanguageModeling(tokenizer, mlm=False),
            )
            # disable KV caching for training
            model.config.use_cache = False
            trainer.train()

            model.cpu().save_pretrained(tmp_dir)
            saved_files = os.listdir(tmp_dir)
            assert "adapter_config.json" in saved_files
            assert SAFETENSORS_WEIGHTS_NAME in saved_files
            # a training loss must have been recorded
            assert trainer.state.log_history[-1]["train_loss"] is not None
    @pytest.mark.multi_gpu_tests
    @require_torch_multi_accelerator
    def test_causal_lm_training_multi_gpu_4bit(self):
        r"""
        Test the CausalLM training on a multi-GPU device with 4bit base model. The test would simply fail if the
        adapters are not set correctly.

        The model is loaded with the fixed device map from ``DEVICE_MAP_MAP``, which places modules on devices 0 and
        1. Like the other multi-GPU tests, this test therefore requires at least two accelerators.
        """
        device_map = DEVICE_MAP_MAP[self.causal_lm_model_id]
        # Devices the fixed map places modules on. Comparing against this set instead of ``range(device_count)``
        # keeps the check valid on machines with more accelerators than the map uses.
        expected_devices = set(device_map.values())
        with tempfile.TemporaryDirectory() as tmp_dir:
            model = AutoModelForCausalLM.from_pretrained(
                self.causal_lm_model_id,
                device_map=device_map,
                quantization_config=BitsAndBytesConfig(load_in_4bit=True),
            )
            assert set(model.hf_device_map.values()) == expected_devices
            assert {p.device.index for p in model.parameters()} == expected_devices
            model = prepare_model_for_kbit_training(model)
            # let Trainer treat the model as already split across devices (model parallel)
            setattr(model, "model_parallel", True)
            setattr(model, "is_parallelizable", True)
            config = LoraConfig(
                r=16,
                lora_alpha=32,
                target_modules=["q_proj", "v_proj"],
                lora_dropout=0.05,
                bias="none",
                task_type="CAUSAL_LM",
            )
            model = get_peft_model(model, config)
            data = load_dataset_english_quotes()
            data = data.map(lambda samples: self.tokenizer(samples["quote"]), batched=True)
            trainer = Trainer(
                model=model,
                train_dataset=data["train"],
                args=TrainingArguments(
                    per_device_train_batch_size=4,
                    gradient_accumulation_steps=4,
                    warmup_steps=2,
                    max_steps=3,
                    learning_rate=2e-4,
                    fp16=True,
                    logging_steps=1,
                    output_dir=tmp_dir,
                ),
                data_collator=DataCollatorForLanguageModeling(self.tokenizer, mlm=False),
            )
            model.config.use_cache = False
            trainer.train()
            model.cpu().save_pretrained(tmp_dir)
            assert "adapter_config.json" in os.listdir(tmp_dir)
            assert SAFETENSORS_WEIGHTS_NAME in os.listdir(tmp_dir)
            # assert loss is not None
            assert trainer.state.log_history[-1]["train_loss"] is not None
    @pytest.mark.single_gpu_tests
    @require_non_cpu
    def test_4bit_adalora_causalLM(self):
        r"""
        Train an AdaLoRA adapter on a 4bit-quantized `opt-350m`, after checking that a forward pass yields finite
        logits.
        """
        model_id = "facebook/opt-350m"
        tokenizer = AutoTokenizer.from_pretrained(model_id)

        # for >3 GPUs, might need: device_map={"": "cuda:0"}
        model = AutoModelForCausalLM.from_pretrained(
            model_id, quantization_config=BitsAndBytesConfig(load_in_4bit=True)
        )
        model.gradient_checkpointing_enable()
        model = prepare_model_for_kbit_training(model)

        adalora_config = AdaLoraConfig(
            init_r=6,
            target_r=4,
            tinit=2,
            tfinal=2,
            total_step=6,
            deltaT=5,
            beta1=0.3,
            beta2=0.3,
            orth_reg_weight=0.2,
            lora_alpha=32,
            lora_dropout=0.05,
            bias="none",
            task_type="CAUSAL_LM",
        )
        model = get_peft_model(model, adalora_config)

        quotes = load_dataset_english_quotes()
        quotes = quotes.map(lambda rows: tokenizer(rows["quote"]), batched=True)

        # quick forward-pass sanity check on the first three quotes before training
        sample_batch = tokenizer(quotes["train"][:3]["quote"], return_tensors="pt", padding=True)
        self._check_inference_finite(model, sample_batch)

        class AdaLoraAllocationCallback(TrainerCallback):
            # call AdaLoRA's ``update_and_allocate`` after every optimizer step
            def on_optimizer_step(self, args, state, control, **kwargs):
                model.update_and_allocate(state.global_step)

        with tempfile.TemporaryDirectory() as tmp_dir:
            training_args = TrainingArguments(
                per_device_train_batch_size=4,
                gradient_accumulation_steps=4,
                warmup_steps=2,
                max_steps=6,
                learning_rate=2e-4,
                fp16=True,
                logging_steps=1,
                output_dir=tmp_dir,
            )
            trainer = Trainer(
                model=model,
                train_dataset=quotes["train"],
                args=training_args,
                data_collator=DataCollatorForLanguageModeling(tokenizer, mlm=False),
            )
            model.config.use_cache = False
            trainer.add_callback(AdaLoraAllocationCallback())
            trainer.train()

            model.cpu().save_pretrained(tmp_dir)
            saved_files = os.listdir(tmp_dir)
            assert "adapter_config.json" in saved_files
            assert SAFETENSORS_WEIGHTS_NAME in saved_files
            # a training loss must have been recorded
            assert trainer.state.log_history[-1]["train_loss"] is not None
    @pytest.mark.single_gpu_tests
    @require_non_cpu
    def test_8bit_adalora_causalLM(self):
        r"""
        Train an AdaLoRA adapter on an 8bit-quantized `opt-350m`, after checking that a forward pass yields finite
        logits.
        """
        model_id = "facebook/opt-350m"
        tokenizer = AutoTokenizer.from_pretrained(model_id)

        model = AutoModelForCausalLM.from_pretrained(
            model_id, quantization_config=BitsAndBytesConfig(load_in_8bit=True)
        )
        model.gradient_checkpointing_enable()
        model = prepare_model_for_kbit_training(model)

        adalora_config = AdaLoraConfig(
            init_r=6,
            target_r=4,
            tinit=2,
            tfinal=2,
            total_step=6,
            deltaT=5,
            beta1=0.3,
            beta2=0.3,
            orth_reg_weight=0.2,
            lora_alpha=32,
            lora_dropout=0.05,
            bias="none",
            task_type="CAUSAL_LM",
        )
        model = get_peft_model(model, adalora_config)

        quotes = load_dataset_english_quotes()
        quotes = quotes.map(lambda rows: tokenizer(rows["quote"]), batched=True)

        # quick forward-pass sanity check on the first three quotes before training
        sample_batch = tokenizer(quotes["train"][:3]["quote"], return_tensors="pt", padding=True)
        self._check_inference_finite(model, sample_batch)

        class AdaLoraAllocationCallback(TrainerCallback):
            # call AdaLoRA's ``update_and_allocate`` after every optimizer step
            def on_optimizer_step(self, args, state, control, **kwargs):
                model.update_and_allocate(state.global_step)

        with tempfile.TemporaryDirectory() as tmp_dir:
            training_args = TrainingArguments(
                per_device_train_batch_size=4,
                gradient_accumulation_steps=4,
                warmup_steps=2,
                max_steps=6,
                learning_rate=2e-4,
                fp16=True,
                logging_steps=1,
                output_dir=tmp_dir,
            )
            trainer = Trainer(
                model=model,
                train_dataset=quotes["train"],
                args=training_args,
                data_collator=DataCollatorForLanguageModeling(tokenizer, mlm=False),
            )
            model.config.use_cache = False
            trainer.add_callback(AdaLoraAllocationCallback())
            trainer.train()

            model.cpu().save_pretrained(tmp_dir)
            saved_files = os.listdir(tmp_dir)
            assert "adapter_config.json" in saved_files
            assert SAFETENSORS_WEIGHTS_NAME in saved_files
            # a training loss must have been recorded
            assert trainer.state.log_history[-1]["train_loss"] is not None
    @pytest.mark.multi_gpu_tests
    @require_torch_multi_accelerator
    def test_causal_lm_training_multi_gpu(self):
        r"""
        Train a LoRA adapter on an 8bit-quantized causal LM spread over several GPU devices.

        Converted from
        https://github.com/huggingface/peft/blob/main/examples/int8_training/Finetune_opt_bnb_peft.ipynb: `opt-6.7b` is
        trained for a few steps on the `english_quotes` dataset. The test fails if the adapters are not set up
        correctly.
        """
        model_id = self.causal_lm_model_id
        with tempfile.TemporaryDirectory() as tmp_dir:
            model = AutoModelForCausalLM.from_pretrained(
                model_id,
                quantization_config=BitsAndBytesConfig(load_in_8bit=True),
                device_map="auto",
            )
            placement = model.hf_device_map
            print(f"device map: {placement}")
            # with device_map="auto" the model is expected to use every available accelerator
            assert set(placement.values()) == set(range(device_count))

            tokenizer = AutoTokenizer.from_pretrained(model_id)
            model = prepare_model_for_kbit_training(model)
            # let Trainer treat the model as already split across devices (model parallel)
            model.model_parallel = True
            model.is_parallelizable = True

            lora_config = LoraConfig(
                r=16,
                lora_alpha=32,
                target_modules=["q_proj", "v_proj"],
                lora_dropout=0.05,
                bias="none",
                task_type="CAUSAL_LM",
            )
            model = get_peft_model(model, lora_config)

            dataset = load_dataset_english_quotes()
            dataset = dataset.map(lambda rows: tokenizer(rows["quote"]), batched=True)

            training_args = TrainingArguments(
                per_device_train_batch_size=4,
                gradient_accumulation_steps=4,
                warmup_steps=2,
                max_steps=3,
                learning_rate=2e-4,
                fp16=True,
                logging_steps=1,
                output_dir=tmp_dir,
            )
            trainer = Trainer(
                model=model,
                train_dataset=dataset["train"],
                args=training_args,
                data_collator=DataCollatorForLanguageModeling(tokenizer, mlm=False),
            )
            # disable KV caching for training
            model.config.use_cache = False
            trainer.train()

            model.cpu().save_pretrained(tmp_dir)
            saved_files = os.listdir(tmp_dir)
            assert "adapter_config.json" in saved_files
            assert SAFETENSORS_WEIGHTS_NAME in saved_files
            # a training loss must have been recorded
            assert trainer.state.log_history[-1]["train_loss"] is not None
    @pytest.mark.single_gpu_tests
    def test_seq2seq_lm_training_single_gpu(self):
        r"""
        Train a LoRA adapter on an 8bit-quantized seq2seq model (`flan-t5-base`) placed entirely on device 0.

        Loosely converted from
        https://github.com/huggingface/peft/blob/main/examples/int8_training/Finetune_opt_bnb_peft.ipynb: the model is
        trained for a few steps on the `english_quotes` dataset. The test fails if the adapters are not set up
        correctly.
        """
        model_id = self.seq2seq_model_id
        with tempfile.TemporaryDirectory() as tmp_dir:
            model = AutoModelForSeq2SeqLM.from_pretrained(
                model_id,
                quantization_config=BitsAndBytesConfig(load_in_8bit=True),
                device_map={"": 0},
            )
            # everything must live on the single device 0
            assert set(model.hf_device_map.values()) == {0}

            tokenizer = AutoTokenizer.from_pretrained(model_id)
            model = prepare_model_for_kbit_training(model)

            # NOTE(review): task_type is "CAUSAL_LM" although the model is seq2seq; kept as-is — confirm intent
            lora_config = LoraConfig(
                r=16,
                lora_alpha=32,
                target_modules=["q", "v"],
                lora_dropout=0.05,
                bias="none",
                task_type="CAUSAL_LM",
            )
            model = get_peft_model(model, lora_config)

            dataset = load_dataset_english_quotes()
            dataset = dataset.map(lambda rows: tokenizer(rows["quote"]), batched=True)

            training_args = TrainingArguments(
                per_device_train_batch_size=4,
                gradient_accumulation_steps=4,
                warmup_steps=2,
                max_steps=3,
                learning_rate=2e-4,
                fp16=True,
                logging_steps=1,
                output_dir=tmp_dir,
            )
            trainer = Trainer(
                model=model,
                train_dataset=dataset["train"],
                args=training_args,
                data_collator=DataCollatorForLanguageModeling(tokenizer, mlm=False),
            )
            # disable KV caching for training
            model.config.use_cache = False
            trainer.train()

            model.cpu().save_pretrained(tmp_dir)
            saved_files = os.listdir(tmp_dir)
            assert "adapter_config.json" in saved_files
            assert SAFETENSORS_WEIGHTS_NAME in saved_files
            # a training loss must have been recorded
            assert trainer.state.log_history[-1]["train_loss"] is not None
    @pytest.mark.multi_gpu_tests
    @require_torch_multi_accelerator
    def test_seq2seq_lm_training_multi_gpu(self):
        r"""
        Test the Seq2SeqLM training on a multi-GPU device. This test is a converted version of
        https://github.com/huggingface/peft/blob/main/examples/int8_training/Finetune_opt_bnb_peft.ipynb where we train
        `flan-t5-base` on `english_quotes` dataset in few steps. The test would simply fail if the adapters are not set
        correctly.

        The model is loaded with the fixed device map from ``DEVICE_MAP_MAP``. All Trainer output goes to a temporary
        directory, so nothing is left behind in the working directory.
        """
        device_map = DEVICE_MAP_MAP[self.seq2seq_model_id]
        # Devices the fixed map places modules on. Comparing against this set instead of ``range(device_count)``
        # keeps the check valid on machines with more accelerators than the map uses.
        expected_devices = set(device_map.values())
        with tempfile.TemporaryDirectory() as tmp_dir:
            model = AutoModelForSeq2SeqLM.from_pretrained(
                self.seq2seq_model_id,
                quantization_config=BitsAndBytesConfig(load_in_8bit=True),
                device_map=device_map,
            )
            assert set(model.hf_device_map.values()) == expected_devices
            assert {p.device.index for p in model.parameters()} == expected_devices
            tokenizer = AutoTokenizer.from_pretrained(self.seq2seq_model_id)
            model = prepare_model_for_kbit_training(model)
            # NOTE(review): task_type is "CAUSAL_LM" although the model is seq2seq; kept as-is — confirm intent
            config = LoraConfig(
                r=16,
                lora_alpha=32,
                target_modules=["q", "v"],
                lora_dropout=0.05,
                bias="none",
                task_type="CAUSAL_LM",
            )
            model = get_peft_model(model, config)
            data = load_dataset_english_quotes()
            data = data.map(lambda samples: tokenizer(samples["quote"]), batched=True)
            trainer = Trainer(
                model=model,
                train_dataset=data["train"],
                args=TrainingArguments(
                    per_device_train_batch_size=4,
                    gradient_accumulation_steps=4,
                    warmup_steps=2,
                    max_steps=3,
                    learning_rate=2e-4,
                    fp16=True,
                    logging_steps=1,
                    # use the temporary directory (not a cwd-relative "outputs" dir) so nothing leaks after the test
                    output_dir=tmp_dir,
                ),
                data_collator=DataCollatorForLanguageModeling(tokenizer, mlm=False),
            )
            model.config.use_cache = False
            trainer.train()
            model.cpu().save_pretrained(tmp_dir)
            assert "adapter_config.json" in os.listdir(tmp_dir)
            assert SAFETENSORS_WEIGHTS_NAME in os.listdir(tmp_dir)
            # assert loss is not None
            assert trainer.state.log_history[-1]["train_loss"] is not None
    # TODO skipping to see if this leads to single GPU tests passing
    @pytest.mark.skip
    @pytest.mark.single_gpu_tests
    def test_audio_model_training(self):
        r"""
        Train a LoRA adapter on an 8bit-quantized Whisper model on a single GPU device. Converted from
        https://github.com/huggingface/peft/blob/main/examples/int8_training/peft_bnb_whisper_large_v2_training.ipynb
        """
        model_id = self.audio_model_id
        language = "Marathi"
        task = "transcribe"
        unused_columns = [
            "accent", "age", "client_id", "down_votes", "gender", "locale", "path", "segment", "up_votes"
        ]
        with tempfile.TemporaryDirectory() as tmp_dir:
            common_voice = DatasetDict()
            common_voice["train"] = load_dataset("ybelkada/common_voice_mr_11_0_copy", split="train+validation")
            common_voice = common_voice.remove_columns(unused_columns)

            feature_extractor = WhisperFeatureExtractor.from_pretrained(model_id)
            tokenizer = WhisperTokenizer.from_pretrained(model_id, language=language, task=task)
            processor = WhisperProcessor.from_pretrained(model_id, language=language, task=task)
            # decode the audio column at 16kHz
            common_voice = common_voice.cast_column("audio", Audio(sampling_rate=16000))

            def prepare_dataset(example):
                audio = example["audio"]
                # log-Mel input features computed from the (already resampled) audio array
                extracted = feature_extractor(audio["array"], sampling_rate=audio["sampling_rate"])
                example["input_features"] = extracted.input_features[0]
                # label ids from the target transcription
                example["labels"] = tokenizer(example["sentence"]).input_ids
                return example

            common_voice = common_voice.map(
                prepare_dataset, remove_columns=common_voice.column_names["train"], num_proc=2
            )
            data_collator = DataCollatorSpeechSeq2SeqWithPadding(processor=processor)

            model = WhisperForConditionalGeneration.from_pretrained(
                model_id, quantization_config=BitsAndBytesConfig(load_in_8bit=True), device_map="auto"
            )
            model.config.forced_decoder_ids = None
            model.config.suppress_tokens = []
            model = prepare_model_for_kbit_training(model)

            # Whisper's encoder starts with a Conv layer; with checkpointing its output would not require grad,
            # so force it to require grad through a forward hook.
            def make_inputs_require_grad(module, input, output):
                output.requires_grad_(True)

            model.model.encoder.conv1.register_forward_hook(make_inputs_require_grad)

            lora_config = LoraConfig(
                r=32, lora_alpha=64, target_modules=["q_proj", "v_proj"], lora_dropout=0.05, bias="none"
            )
            model = get_peft_model(model, lora_config)
            model.print_trainable_parameters()

            training_args = Seq2SeqTrainingArguments(
                output_dir=tmp_dir,
                per_device_train_batch_size=8,
                gradient_accumulation_steps=1,  # increase by 2x for every 2x decrease in batch size
                learning_rate=1e-3,
                warmup_steps=2,
                max_steps=3,
                fp16=True,
                per_device_eval_batch_size=8,
                generation_max_length=128,
                logging_steps=25,
                # PeftModel.forward does not expose the wrapped model's signature, so keep all columns
                remove_unused_columns=False,
                label_names=["labels"],  # same reason as above
            )
            trainer = Seq2SeqTrainer(
                args=training_args,
                model=model,
                train_dataset=common_voice["train"],
                data_collator=data_collator,
                tokenizer=processor.feature_extractor,
            )
            trainer.train()

            model.cpu().save_pretrained(tmp_dir)
            saved_files = os.listdir(tmp_dir)
            assert "adapter_config.json" in saved_files
            assert SAFETENSORS_WEIGHTS_NAME in saved_files
            # a training loss must have been recorded
            assert trainer.state.log_history[-1]["train_loss"] is not None
    @pytest.mark.single_gpu_tests
    def test_4bit_non_default_adapter_name(self):
        """Check that a non-default adapter name does not change parameter counts on a 4-bit model.

        See PR 1294. The same LoRA config is applied to two freshly loaded 4-bit OPT-125m models: once with the
        default adapter name and once with ``adapter_name="other"``. Both must report identical trainable and
        total parameter counts, and the named adapter must actually have trainable parameters.
        """
        lora_config = LoraConfig(
            r=16,
            target_modules=["q_proj", "v_proj"],
            bias="none",
            task_type="CAUSAL_LM",
        )

        def trainable_and_total(**adapter_kwargs):
            # load a fresh base model on every call so the two adapters are never added to the same model
            base_model = AutoModelForCausalLM.from_pretrained(
                "facebook/opt-125m",
                device_map="auto",
                quantization_config=BitsAndBytesConfig(load_in_4bit=True),
            )
            base_model = prepare_model_for_kbit_training(base_model)
            peft_model = get_peft_model(base_model, lora_config, **adapter_kwargs)
            return peft_model.get_nb_trainable_parameters()

        n_trainable_default, n_total_default = trainable_and_total()
        n_trainable_other, n_total_other = trainable_and_total(adapter_name="other")

        assert n_trainable_other > 0
        # sanity check: both adapter names must yield the same (trainable, total) counts
        assert (n_trainable_default, n_total_default) == (n_trainable_other, n_total_other)
    @pytest.mark.single_gpu_tests
    def test_8bit_non_default_adapter_name(self):
        """Check that a non-default adapter name does not change parameter counts on an 8-bit model.

        See PR 1294. The same LoRA config is applied to two freshly loaded 8-bit OPT-125m models: once with the
        default adapter name and once with ``adapter_name="other"``. Both must report identical trainable and
        total parameter counts, and the named adapter must actually have trainable parameters.
        """
        lora_config = LoraConfig(
            r=16,
            target_modules=["q_proj", "v_proj"],
            bias="none",
            task_type="CAUSAL_LM",
        )

        def trainable_and_total(**adapter_kwargs):
            # load a fresh base model on every call so the two adapters are never added to the same model
            base_model = AutoModelForCausalLM.from_pretrained(
                "facebook/opt-125m",
                device_map="auto",
                quantization_config=BitsAndBytesConfig(load_in_8bit=True),
            )
            base_model = prepare_model_for_kbit_training(base_model)
            peft_model = get_peft_model(base_model, lora_config, **adapter_kwargs)
            return peft_model.get_nb_trainable_parameters()

        n_trainable_default, n_total_default = trainable_and_total()
        n_trainable_other, n_total_other = trainable_and_total(adapter_name="other")

        assert n_trainable_other > 0
        # sanity check: both adapter names must yield the same (trainable, total) counts
        assert (n_trainable_default, n_total_default) == (n_trainable_other, n_total_other)
    @pytest.mark.single_gpu_tests
    def test_causal_lm_training_4bit_dora(self):
        r"""
        Same as test_causal_lm_training_4bit but with DoRA

        Trains a DoRA adapter (q_proj/v_proj) on a 4-bit quantized causal LM for a few steps. Afterwards the
        adapter must be saved with a safetensors weights file, and a final training loss must have been logged.
        """
        with tempfile.TemporaryDirectory() as tmp_dir:
            base_model = AutoModelForCausalLM.from_pretrained(
                self.causal_lm_model_id,
                quantization_config=BitsAndBytesConfig(load_in_4bit=True),
                device_map="auto",
            )
            tokenizer = AutoTokenizer.from_pretrained(self.causal_lm_model_id)
            base_model = prepare_model_for_kbit_training(base_model)
            dora_config = LoraConfig(
                r=16,
                lora_alpha=32,
                target_modules=["q_proj", "v_proj"],
                lora_dropout=0.05,
                bias="none",
                task_type="CAUSAL_LM",
                use_dora=True,
            )
            model = get_peft_model(base_model, dora_config)

            quotes = load_dataset_english_quotes()
            quotes = quotes.map(lambda batch: tokenizer(batch["quote"]), batched=True)
            training_args = TrainingArguments(
                per_device_train_batch_size=4,
                gradient_accumulation_steps=4,
                warmup_steps=2,
                max_steps=3,
                learning_rate=2e-4,
                fp16=True,
                logging_steps=1,
                output_dir=tmp_dir,
            )
            trainer = Trainer(
                model=model,
                train_dataset=quotes["train"],
                args=training_args,
                data_collator=DataCollatorForLanguageModeling(tokenizer, mlm=False),
            )
            model.config.use_cache = False
            trainer.train()

            model.cpu().save_pretrained(tmp_dir)
            saved_files = os.listdir(tmp_dir)
            assert "adapter_config.json" in saved_files
            assert SAFETENSORS_WEIGHTS_NAME in saved_files
            # the final logged entry must report a training loss
            final_log = trainer.state.log_history[-1]
            assert final_log["train_loss"] is not None
    @pytest.mark.multi_gpu_tests
    def test_causal_lm_training_multi_gpu_4bit_dora(self):
        r"""
        Same as test_causal_lm_training_multi_gpu_4bit but with DoRA

        The 4-bit base model is spread over all devices using ``DEVICE_MAP_MAP``. After a few DoRA training steps,
        the adapter must be saved with a safetensors weights file and a final training loss must be logged.
        """
        with tempfile.TemporaryDirectory() as tmp_dir:
            model = AutoModelForCausalLM.from_pretrained(
                self.causal_lm_model_id,
                device_map=DEVICE_MAP_MAP[self.causal_lm_model_id],
                quantization_config=BitsAndBytesConfig(load_in_4bit=True),
            )
            # every device must hold part of the model, per the device map and per the actual parameters
            all_devices = set(range(device_count))
            assert set(model.hf_device_map.values()) == all_devices
            assert {param.device.index for param in model.parameters()} == all_devices

            model = prepare_model_for_kbit_training(model)
            # mark the model as model-parallel (presumably so Trainer keeps the multi-device placement -- confirm)
            model.model_parallel = True
            model.is_parallelizable = True
            dora_config = LoraConfig(
                r=16,
                lora_alpha=32,
                target_modules=["q_proj", "v_proj"],
                lora_dropout=0.05,
                bias="none",
                task_type="CAUSAL_LM",
                use_dora=True,
            )
            model = get_peft_model(model, dora_config)

            quotes = load_dataset_english_quotes()
            quotes = quotes.map(lambda batch: self.tokenizer(batch["quote"]), batched=True)
            training_args = TrainingArguments(
                per_device_train_batch_size=4,
                gradient_accumulation_steps=4,
                warmup_steps=2,
                max_steps=3,
                learning_rate=2e-4,
                fp16=True,
                logging_steps=1,
                output_dir=tmp_dir,
            )
            trainer = Trainer(
                model=model,
                train_dataset=quotes["train"],
                args=training_args,
                data_collator=DataCollatorForLanguageModeling(self.tokenizer, mlm=False),
            )
            model.config.use_cache = False
            trainer.train()

            model.cpu().save_pretrained(tmp_dir)
            saved_files = os.listdir(tmp_dir)
            assert "adapter_config.json" in saved_files
            assert SAFETENSORS_WEIGHTS_NAME in saved_files
            # the final logged entry must report a training loss
            final_log = trainer.state.log_history[-1]
            assert final_log["train_loss"] is not None
    @pytest.mark.single_gpu_tests
    def test_causal_lm_training_8bit_dora(self):
        r"""
        Same as test_causal_lm_training_4bit_dora but with 8bit

        Trains a DoRA adapter (q_proj/v_proj) on an 8-bit quantized causal LM for a few steps. Afterwards the
        adapter must be saved with a safetensors weights file, and a final training loss must have been logged.
        """
        with tempfile.TemporaryDirectory() as tmp_dir:
            base_model = AutoModelForCausalLM.from_pretrained(
                self.causal_lm_model_id,
                quantization_config=BitsAndBytesConfig(load_in_8bit=True),
                device_map="auto",
            )
            tokenizer = AutoTokenizer.from_pretrained(self.causal_lm_model_id)
            base_model = prepare_model_for_kbit_training(base_model)
            dora_config = LoraConfig(
                r=16,
                lora_alpha=32,
                target_modules=["q_proj", "v_proj"],
                lora_dropout=0.05,
                bias="none",
                task_type="CAUSAL_LM",
                use_dora=True,
            )
            model = get_peft_model(base_model, dora_config)

            quotes = load_dataset_english_quotes()
            quotes = quotes.map(lambda batch: tokenizer(batch["quote"]), batched=True)
            training_args = TrainingArguments(
                per_device_train_batch_size=4,
                gradient_accumulation_steps=4,
                warmup_steps=2,
                max_steps=3,
                learning_rate=2e-4,
                fp16=True,
                logging_steps=1,
                output_dir=tmp_dir,
            )
            trainer = Trainer(
                model=model,
                train_dataset=quotes["train"],
                args=training_args,
                data_collator=DataCollatorForLanguageModeling(tokenizer, mlm=False),
            )
            model.config.use_cache = False
            trainer.train()

            model.cpu().save_pretrained(tmp_dir)
            saved_files = os.listdir(tmp_dir)
            assert "adapter_config.json" in saved_files
            assert SAFETENSORS_WEIGHTS_NAME in saved_files
            # the final logged entry must report a training loss
            final_log = trainer.state.log_history[-1]
            assert final_log["train_loss"] is not None
    @pytest.mark.multi_gpu_tests
    def test_causal_lm_training_multi_gpu_8bit_dora(self):
        r"""
        Same as test_causal_lm_training_multi_gpu_4bit_dora but with 8bit

        The 8-bit base model is spread over all devices using ``DEVICE_MAP_MAP``. After a few DoRA training steps,
        the adapter must be saved with a safetensors weights file and a final training loss must be logged.
        """
        with tempfile.TemporaryDirectory() as tmp_dir:
            model = AutoModelForCausalLM.from_pretrained(
                self.causal_lm_model_id,
                device_map=DEVICE_MAP_MAP[self.causal_lm_model_id],
                quantization_config=BitsAndBytesConfig(load_in_8bit=True),
            )
            # every device must hold part of the model, per the device map and per the actual parameters
            all_devices = set(range(device_count))
            assert set(model.hf_device_map.values()) == all_devices
            assert {param.device.index for param in model.parameters()} == all_devices

            model = prepare_model_for_kbit_training(model)
            # mark the model as model-parallel (presumably so Trainer keeps the multi-device placement -- confirm)
            model.model_parallel = True
            model.is_parallelizable = True
            dora_config = LoraConfig(
                r=16,
                lora_alpha=32,
                target_modules=["q_proj", "v_proj"],
                lora_dropout=0.05,
                bias="none",
                task_type="CAUSAL_LM",
                use_dora=True,
            )
            model = get_peft_model(model, dora_config)

            quotes = load_dataset_english_quotes()
            quotes = quotes.map(lambda batch: self.tokenizer(batch["quote"]), batched=True)
            training_args = TrainingArguments(
                per_device_train_batch_size=4,
                gradient_accumulation_steps=4,
                warmup_steps=2,
                max_steps=3,
                learning_rate=2e-4,
                fp16=True,
                logging_steps=1,
                output_dir=tmp_dir,
            )
            trainer = Trainer(
                model=model,
                train_dataset=quotes["train"],
                args=training_args,
                data_collator=DataCollatorForLanguageModeling(self.tokenizer, mlm=False),
            )
            model.config.use_cache = False
            trainer.train()

            model.cpu().save_pretrained(tmp_dir)
            saved_files = os.listdir(tmp_dir)
            assert "adapter_config.json" in saved_files
            assert SAFETENSORS_WEIGHTS_NAME in saved_files
            # the final logged entry must report a training loss
            final_log = trainer.state.log_history[-1]
            assert final_log["train_loss"] is not None
    @pytest.mark.single_gpu_tests
    def test_causal_lm_training_gpt2_dora(self):
        r"""
        Train a DoRA adapter on a non-quantized GPT-2 model for a few steps.

        Unlike test_causal_lm_training_4bit, no quantization config is passed when loading the model. No
        ``target_modules`` are set either, so the adapted layers are whatever PEFT selects by default for GPT-2
        (presumably its ``Conv1D`` attention projection -- TODO confirm). Afterwards the adapter must be saved
        with a safetensors weights file, and a final training loss must have been logged.
        """
        with tempfile.TemporaryDirectory() as tmp_dir:
            model = AutoModelForCausalLM.from_pretrained("gpt2", device_map="auto")
            # NOTE(review): the tokenizer is loaded from self.causal_lm_model_id, not from "gpt2". This is
            # presumably deliberate (e.g. to get a tokenizer that works with the padding collator); confirm that
            # every produced token id fits inside GPT-2's embedding matrix.
            tokenizer = AutoTokenizer.from_pretrained(self.causal_lm_model_id)
            # applied even though this model is not quantized
            model = prepare_model_for_kbit_training(model)
            config = LoraConfig(
                r=16,
                lora_alpha=32,
                lora_dropout=0.05,
                bias="none",
                task_type="CAUSAL_LM",
                use_dora=True,
            )
            model = get_peft_model(model, config)
            data = load_dataset_english_quotes()
            data = data.map(lambda samples: tokenizer(samples["quote"]), batched=True)
            trainer = Trainer(
                model=model,
                train_dataset=data["train"],
                args=TrainingArguments(
                    per_device_train_batch_size=4,
                    gradient_accumulation_steps=4,
                    warmup_steps=2,
                    max_steps=3,
                    learning_rate=2e-4,
                    fp16=True,
                    logging_steps=1,
                    output_dir=tmp_dir,
                ),
                data_collator=DataCollatorForLanguageModeling(tokenizer, mlm=False),
            )
            model.config.use_cache = False
            trainer.train()
            model.cpu().save_pretrained(tmp_dir)
            assert "adapter_config.json" in os.listdir(tmp_dir)
            assert SAFETENSORS_WEIGHTS_NAME in os.listdir(tmp_dir)
            # assert loss is not None
            assert trainer.state.log_history[-1]["train_loss"] is not None
    @parameterized.expand(["4bit", "8bit"])
    def test_initialize_dora_with_bnb_on_cpu(self, kbit):
        """Initializing DoRA on a bnb-quantized model that lives on CPU must succeed (see #1674).

        DoRA initialization needs to dequantize the base weights, which for bnb only worked on GPU, so creating a
        DoRA model from a CPU-resident bnb model used to fail. The test also checks that no parameter leaves the
        CPU, either before or after wrapping.

        Args:
            kbit: ``"4bit"`` or ``"8bit"``; any other value raises ``ValueError``.
        """
        model_id = "facebook/opt-125m"
        # build the config lazily, so only the selected quantization config is instantiated
        config_factories = {
            "4bit": lambda: BitsAndBytesConfig(load_in_4bit=True, bnb_4bit_quant_type="nf4"),
            "8bit": lambda: BitsAndBytesConfig(load_in_8bit=True),
        }
        if kbit not in config_factories:
            raise ValueError("Only 4bit and 8bit bnb allowed")
        bnb_config = config_factories[kbit]()

        cpu = torch.device("cpu")

        def params_off_cpu(module):
            # names of all parameters that are not on the CPU device
            return [name for name, param in module.named_parameters() if param.device != cpu]

        model = AutoModelForCausalLM.from_pretrained(model_id, quantization_config=bnb_config)
        model = model.cpu()  # ensure that we're on CPU
        # sanity check before wrapping
        assert not params_off_cpu(model)

        # should not raise
        peft_model = get_peft_model(model, LoraConfig(use_dora=True))
        # DoRA initialization must not move any weight off the CPU
        assert not params_off_cpu(peft_model)
    @pytest.mark.single_gpu_tests
    def test_causal_lm_training_vera(self):
        r"""
        Same as test_causal_lm_training but with VeRA

        Trains a VeRA adapter (q_proj/v_proj) on an 8-bit quantized causal LM for a few steps. Afterwards the
        adapter must be saved with a safetensors weights file, and a final training loss must have been logged.
        """
        with tempfile.TemporaryDirectory() as tmp_dir:
            base_model = AutoModelForCausalLM.from_pretrained(
                self.causal_lm_model_id,
                quantization_config=BitsAndBytesConfig(load_in_8bit=True),
                device_map="auto",
            )
            tokenizer = AutoTokenizer.from_pretrained(self.causal_lm_model_id)
            base_model = prepare_model_for_kbit_training(base_model)
            vera_config = VeraConfig(
                r=16,
                target_modules=["q_proj", "v_proj"],
                vera_dropout=0.05,
                bias="none",
                task_type="CAUSAL_LM",
            )
            model = get_peft_model(base_model, vera_config)

            quotes = load_dataset_english_quotes()
            quotes = quotes.map(lambda batch: tokenizer(batch["quote"]), batched=True)
            training_args = TrainingArguments(
                per_device_train_batch_size=4,
                gradient_accumulation_steps=4,
                warmup_steps=2,
                max_steps=3,
                learning_rate=2e-4,
                fp16=True,
                logging_steps=1,
                output_dir=tmp_dir,
            )
            trainer = Trainer(
                model=model,
                train_dataset=quotes["train"],
                args=training_args,
                data_collator=DataCollatorForLanguageModeling(tokenizer, mlm=False),
            )
            model.config.use_cache = False
            trainer.train()

            model.cpu().save_pretrained(tmp_dir)
            saved_files = os.listdir(tmp_dir)
            assert "adapter_config.json" in saved_files
            assert SAFETENSORS_WEIGHTS_NAME in saved_files
            # the final logged entry must report a training loss
            final_log = trainer.state.log_history[-1]
            assert final_log["train_loss"] is not None
    @pytest.mark.single_gpu_tests
    def test_causal_lm_training_4bit_vera(self):
        r"""
        Same as test_causal_lm_training_4bit but with VeRA

        Trains a VeRA adapter (q_proj/v_proj) on a 4-bit quantized causal LM for a few steps. Afterwards the
        adapter must be saved with a safetensors weights file, and a final training loss must have been logged.
        """
        with tempfile.TemporaryDirectory() as tmp_dir:
            base_model = AutoModelForCausalLM.from_pretrained(
                self.causal_lm_model_id,
                quantization_config=BitsAndBytesConfig(load_in_4bit=True),
                device_map="auto",
            )
            tokenizer = AutoTokenizer.from_pretrained(self.causal_lm_model_id)
            base_model = prepare_model_for_kbit_training(base_model)
            vera_config = VeraConfig(
                r=16,
                target_modules=["q_proj", "v_proj"],
                vera_dropout=0.05,
                bias="none",
                task_type="CAUSAL_LM",
            )
            model = get_peft_model(base_model, vera_config)

            quotes = load_dataset_english_quotes()
            quotes = quotes.map(lambda batch: tokenizer(batch["quote"]), batched=True)
            training_args = TrainingArguments(
                per_device_train_batch_size=4,
                gradient_accumulation_steps=4,
                warmup_steps=2,
                max_steps=3,
                learning_rate=2e-4,
                fp16=True,
                logging_steps=1,
                output_dir=tmp_dir,
            )
            trainer = Trainer(
                model=model,
                train_dataset=quotes["train"],
                args=training_args,
                data_collator=DataCollatorForLanguageModeling(tokenizer, mlm=False),
            )
            model.config.use_cache = False
            trainer.train()

            model.cpu().save_pretrained(tmp_dir)
            saved_files = os.listdir(tmp_dir)
            assert "adapter_config.json" in saved_files
            assert SAFETENSORS_WEIGHTS_NAME in saved_files
            # the final logged entry must report a training loss
            final_log = trainer.state.log_history[-1]
            assert final_log["train_loss"] is not None
    @pytest.mark.multi_gpu_tests
    def test_causal_lm_training_multi_gpu_vera(self):
        r"""
        Same as test_causal_lm_training_multi_gpu but with VeRA

        The 8-bit base model is spread over all devices using ``DEVICE_MAP_MAP``. After a few VeRA training steps,
        the adapter must be saved with a safetensors weights file and a final training loss must be logged.
        """
        with tempfile.TemporaryDirectory() as tmp_dir:
            model = AutoModelForCausalLM.from_pretrained(
                self.causal_lm_model_id,
                device_map=DEVICE_MAP_MAP[self.causal_lm_model_id],
                quantization_config=BitsAndBytesConfig(load_in_8bit=True),
            )
            # every device must hold part of the model, per the device map and per the actual parameters
            all_devices = set(range(device_count))
            assert set(model.hf_device_map.values()) == all_devices
            assert {param.device.index for param in model.parameters()} == all_devices

            model = prepare_model_for_kbit_training(model)
            # mark the model as model-parallel (presumably so Trainer keeps the multi-device placement -- confirm)
            model.model_parallel = True
            model.is_parallelizable = True
            vera_config = VeraConfig(
                r=16,
                target_modules=["q_proj", "v_proj"],
                vera_dropout=0.05,
                bias="none",
                task_type="CAUSAL_LM",
            )
            model = get_peft_model(model, vera_config)

            quotes = load_dataset_english_quotes()
            quotes = quotes.map(lambda batch: self.tokenizer(batch["quote"]), batched=True)
            training_args = TrainingArguments(
                per_device_train_batch_size=4,
                gradient_accumulation_steps=4,
                warmup_steps=2,
                max_steps=3,
                learning_rate=2e-4,
                fp16=True,
                logging_steps=1,
                output_dir=tmp_dir,
            )
            trainer = Trainer(
                model=model,
                train_dataset=quotes["train"],
                args=training_args,
                data_collator=DataCollatorForLanguageModeling(self.tokenizer, mlm=False),
            )
            model.config.use_cache = False
            trainer.train()

            model.cpu().save_pretrained(tmp_dir)
            saved_files = os.listdir(tmp_dir)
            assert "adapter_config.json" in saved_files
            assert SAFETENSORS_WEIGHTS_NAME in saved_files
            # the final logged entry must report a training loss
            final_log = trainer.state.log_history[-1]
            assert final_log["train_loss"] is not None
    @pytest.mark.multi_gpu_tests
    def test_causal_lm_training_multi_gpu_4bit_vera(self):
        r"""
        Same as test_causal_lm_training_multi_gpu_4bit but with VeRA

        The 4-bit base model is spread over all devices using ``DEVICE_MAP_MAP``. After a few VeRA training steps,
        the adapter must be saved with a safetensors weights file and a final training loss must be logged.
        """
        with tempfile.TemporaryDirectory() as tmp_dir:
            model = AutoModelForCausalLM.from_pretrained(
                self.causal_lm_model_id,
                device_map=DEVICE_MAP_MAP[self.causal_lm_model_id],
                quantization_config=BitsAndBytesConfig(load_in_4bit=True),
            )
            # every device must hold part of the model, per the device map and per the actual parameters
            all_devices = set(range(device_count))
            assert set(model.hf_device_map.values()) == all_devices
            assert {param.device.index for param in model.parameters()} == all_devices

            model = prepare_model_for_kbit_training(model)
            # mark the model as model-parallel (presumably so Trainer keeps the multi-device placement -- confirm)
            model.model_parallel = True
            model.is_parallelizable = True
            vera_config = VeraConfig(
                r=16,
                target_modules=["q_proj", "v_proj"],
                vera_dropout=0.05,
                bias="none",
                task_type="CAUSAL_LM",
            )
            model = get_peft_model(model, vera_config)

            quotes = load_dataset_english_quotes()
            quotes = quotes.map(lambda batch: self.tokenizer(batch["quote"]), batched=True)
            training_args = TrainingArguments(
                per_device_train_batch_size=4,
                gradient_accumulation_steps=4,
                warmup_steps=2,
                max_steps=3,
                learning_rate=2e-4,
                fp16=True,
                logging_steps=1,
                output_dir=tmp_dir,
            )
            trainer = Trainer(
                model=model,
                train_dataset=quotes["train"],
                args=training_args,
                data_collator=DataCollatorForLanguageModeling(self.tokenizer, mlm=False),
            )
            model.config.use_cache = False
            trainer.train()

            model.cpu().save_pretrained(tmp_dir)
            saved_files = os.listdir(tmp_dir)
            assert "adapter_config.json" in saved_files
            assert SAFETENSORS_WEIGHTS_NAME in saved_files
            # the final logged entry must report a training loss
            final_log = trainer.state.log_history[-1]
            assert final_log["train_loss"] is not None
    @pytest.mark.single_gpu_tests
    def test_causal_lm_training_8bit_randlora(self):
        r"""
        Same as test_causal_lm_training but with RandLora

        Trains a RandLora adapter (q_proj/v_proj) on an 8-bit quantized causal LM for a few steps. Afterwards the
        adapter must be saved with a safetensors weights file, and a final training loss must have been logged.
        """
        with tempfile.TemporaryDirectory() as tmp_dir:
            model = AutoModelForCausalLM.from_pretrained(
                self.causal_lm_model_id,
                quantization_config=BitsAndBytesConfig(load_in_8bit=True),
                device_map="auto",
            )
            tokenizer = AutoTokenizer.from_pretrained(self.causal_lm_model_id)
            model = prepare_model_for_kbit_training(model)
            config = RandLoraConfig(
                r=16,
                target_modules=["q_proj", "v_proj"],
                randlora_dropout=0.05,
                bias="none",
                task_type="CAUSAL_LM",
            )
            model = get_peft_model(model, config)
            # use the shared dataset helper, like the other training tests here, instead of a direct load_dataset call
            data = load_dataset_english_quotes()
            data = data.map(lambda samples: tokenizer(samples["quote"]), batched=True)
            trainer = Trainer(
                model=model,
                train_dataset=data["train"],
                args=TrainingArguments(
                    per_device_train_batch_size=4,
                    gradient_accumulation_steps=4,
                    warmup_steps=2,
                    max_steps=3,
                    learning_rate=2e-4,
                    fp16=True,
                    logging_steps=1,
                    output_dir=tmp_dir,
                ),
                data_collator=DataCollatorForLanguageModeling(tokenizer, mlm=False),
            )
            model.config.use_cache = False
            trainer.train()
            model.cpu().save_pretrained(tmp_dir)
            saved_files = os.listdir(tmp_dir)
            assert "adapter_config.json" in saved_files
            assert SAFETENSORS_WEIGHTS_NAME in saved_files
            # assert loss is not None
            assert trainer.state.log_history[-1]["train_loss"] is not None
    @pytest.mark.single_gpu_tests
    def test_causal_lm_training_4bit_randlora(self):
        r"""
        Same as test_causal_lm_training_4bit but with RandLora

        Trains a RandLora adapter (q_proj/v_proj) on a 4-bit quantized causal LM for a few steps. Afterwards the
        adapter must be saved with a safetensors weights file, and a final training loss must have been logged.
        """
        with tempfile.TemporaryDirectory() as tmp_dir:
            model = AutoModelForCausalLM.from_pretrained(
                self.causal_lm_model_id,
                quantization_config=BitsAndBytesConfig(load_in_4bit=True),
                device_map="auto",
            )
            tokenizer = AutoTokenizer.from_pretrained(self.causal_lm_model_id)
            model = prepare_model_for_kbit_training(model)
            config = RandLoraConfig(
                r=16,
                target_modules=["q_proj", "v_proj"],
                randlora_dropout=0.05,
                bias="none",
                task_type="CAUSAL_LM",
            )
            model = get_peft_model(model, config)
            # use the shared dataset helper, like the other training tests here, instead of a direct load_dataset call
            data = load_dataset_english_quotes()
            data = data.map(lambda samples: tokenizer(samples["quote"]), batched=True)
            trainer = Trainer(
                model=model,
                train_dataset=data["train"],
                args=TrainingArguments(
                    per_device_train_batch_size=4,
                    gradient_accumulation_steps=4,
                    warmup_steps=2,
                    max_steps=3,
                    learning_rate=2e-4,
                    fp16=True,
                    logging_steps=1,
                    output_dir=tmp_dir,
                ),
                data_collator=DataCollatorForLanguageModeling(tokenizer, mlm=False),
            )
            model.config.use_cache = False
            trainer.train()
            model.cpu().save_pretrained(tmp_dir)
            saved_files = os.listdir(tmp_dir)
            assert "adapter_config.json" in saved_files
            assert SAFETENSORS_WEIGHTS_NAME in saved_files
            # assert loss is not None
            assert trainer.state.log_history[-1]["train_loss"] is not None
    @pytest.mark.multi_gpu_tests
    def test_causal_lm_training_multi_gpu_8bit_randlora(self):
        r"""
        Same as test_causal_lm_training_multi_gpu but with RandLora

        The 8-bit base model is spread over all devices using ``DEVICE_MAP_MAP``. After a few RandLora training
        steps, the adapter must be saved with a safetensors weights file and a final training loss must be logged.
        """
        with tempfile.TemporaryDirectory() as tmp_dir:
            model = AutoModelForCausalLM.from_pretrained(
                self.causal_lm_model_id,
                device_map=DEVICE_MAP_MAP[self.causal_lm_model_id],
                quantization_config=BitsAndBytesConfig(load_in_8bit=True),
            )
            # every device must hold part of the model, per the device map and per the actual parameters
            assert set(model.hf_device_map.values()) == set(range(device_count))
            assert {p.device.index for p in model.parameters()} == set(range(device_count))
            model = prepare_model_for_kbit_training(model)
            setattr(model, "model_parallel", True)
            setattr(model, "is_parallelizable", True)
            config = RandLoraConfig(
                r=16,
                target_modules=["q_proj", "v_proj"],
                randlora_dropout=0.05,
                bias="none",
                task_type="CAUSAL_LM",
            )
            model = get_peft_model(model, config)
            # use the same quotes dataset (via the shared helper) as the other multi-GPU training tests, instead of a
            # separately downloaded copy
            data = load_dataset_english_quotes()
            data = data.map(lambda samples: self.tokenizer(samples["quote"]), batched=True)
            trainer = Trainer(
                model=model,
                train_dataset=data["train"],
                args=TrainingArguments(
                    per_device_train_batch_size=4,
                    gradient_accumulation_steps=4,
                    warmup_steps=2,
                    max_steps=3,
                    learning_rate=2e-4,
                    fp16=True,
                    logging_steps=1,
                    output_dir=tmp_dir,
                ),
                data_collator=DataCollatorForLanguageModeling(self.tokenizer, mlm=False),
            )
            model.config.use_cache = False
            trainer.train()
            model.cpu().save_pretrained(tmp_dir)
            saved_files = os.listdir(tmp_dir)
            assert "adapter_config.json" in saved_files
            assert SAFETENSORS_WEIGHTS_NAME in saved_files
            # assert loss is not None
            assert trainer.state.log_history[-1]["train_loss"] is not None
    @pytest.mark.multi_gpu_tests
    def test_causal_lm_training_multi_gpu_4bit_randlora(self):
        r"""
        Same as test_causal_lm_training_multi_gpu_4bit but with RandLora

        The 4-bit base model is spread over all devices using ``DEVICE_MAP_MAP``. After a few RandLora training
        steps, the adapter must be saved with a safetensors weights file and a final training loss must be logged.
        """
        with tempfile.TemporaryDirectory() as tmp_dir:
            model = AutoModelForCausalLM.from_pretrained(
                self.causal_lm_model_id,
                device_map=DEVICE_MAP_MAP[self.causal_lm_model_id],
                quantization_config=BitsAndBytesConfig(load_in_4bit=True),
            )
            # every device must hold part of the model, per the device map and per the actual parameters
            assert set(model.hf_device_map.values()) == set(range(device_count))
            assert {p.device.index for p in model.parameters()} == set(range(device_count))
            model = prepare_model_for_kbit_training(model)
            setattr(model, "model_parallel", True)
            setattr(model, "is_parallelizable", True)
            config = RandLoraConfig(
                r=16,
                target_modules=["q_proj", "v_proj"],
                randlora_dropout=0.05,
                bias="none",
                task_type="CAUSAL_LM",
            )
            model = get_peft_model(model, config)
            # use the same quotes dataset (via the shared helper) as the other multi-GPU training tests, instead of a
            # separately downloaded copy
            data = load_dataset_english_quotes()
            data = data.map(lambda samples: self.tokenizer(samples["quote"]), batched=True)
            trainer = Trainer(
                model=model,
                train_dataset=data["train"],
                args=TrainingArguments(
                    per_device_train_batch_size=4,
                    gradient_accumulation_steps=4,
                    warmup_steps=2,
                    max_steps=3,
                    learning_rate=2e-4,
                    fp16=True,
                    logging_steps=1,
                    output_dir=tmp_dir,
                ),
                data_collator=DataCollatorForLanguageModeling(self.tokenizer, mlm=False),
            )
            model.config.use_cache = False
            trainer.train()
            model.cpu().save_pretrained(tmp_dir)
            saved_files = os.listdir(tmp_dir)
            assert "adapter_config.json" in saved_files
            assert SAFETENSORS_WEIGHTS_NAME in saved_files
            # assert loss is not None
            assert trainer.state.log_history[-1]["train_loss"] is not None
    @pytest.mark.single_gpu_tests
    def test_causal_lm_training_8bit_road(self):
        r"""
        Same as test_causal_lm_training but with RoAd

        Trains a RoAd adapter (variant ``road_1``, on q_proj/v_proj) on an 8-bit quantized causal LM for a few
        steps, with a higher learning rate (1e-3) than the LoRA-style tests. Afterwards the adapter must be saved
        with a safetensors weights file, and a final training loss must have been logged.
        """
        with tempfile.TemporaryDirectory() as tmp_dir:
            model = AutoModelForCausalLM.from_pretrained(
                self.causal_lm_model_id,
                quantization_config=BitsAndBytesConfig(load_in_8bit=True),
                device_map="auto",
            )
            tokenizer = AutoTokenizer.from_pretrained(self.causal_lm_model_id)
            model = prepare_model_for_kbit_training(model)
            config = RoadConfig(
                variant="road_1",
                target_modules=["q_proj", "v_proj"],
                task_type="CAUSAL_LM",
            )
            model = get_peft_model(model, config)
            # use the shared dataset helper, like the other training tests here, instead of a direct load_dataset call
            data = load_dataset_english_quotes()
            data = data.map(lambda samples: tokenizer(samples["quote"]), batched=True)
            trainer = Trainer(
                model=model,
                train_dataset=data["train"],
                args=TrainingArguments(
                    per_device_train_batch_size=4,
                    gradient_accumulation_steps=4,
                    warmup_steps=2,
                    max_steps=3,
                    learning_rate=1e-3,
                    fp16=True,
                    logging_steps=1,
                    output_dir=tmp_dir,
                ),
                data_collator=DataCollatorForLanguageModeling(tokenizer, mlm=False),
            )
            model.config.use_cache = False
            trainer.train()
            model.cpu().save_pretrained(tmp_dir)
            saved_files = os.listdir(tmp_dir)
            assert "adapter_config.json" in saved_files
            assert SAFETENSORS_WEIGHTS_NAME in saved_files
            # assert loss is not None
            assert trainer.state.log_history[-1]["train_loss"] is not None
    @pytest.mark.single_gpu_tests
    def test_causal_lm_training_4bit_road(self):
        r"""
        RoAd counterpart of test_causal_lm_training_4bit: train a 4-bit (bitsandbytes) causal LM with a RoAd adapter for
        three optimizer steps on one GPU, then verify the saved adapter files and the logged training loss.
        """
        quant_config = BitsAndBytesConfig(load_in_4bit=True)
        with tempfile.TemporaryDirectory() as tmp_dir:
            quantized = AutoModelForCausalLM.from_pretrained(
                self.causal_lm_model_id,
                quantization_config=quant_config,
                device_map="auto",
            )
            tokenizer = AutoTokenizer.from_pretrained(self.causal_lm_model_id)
            adapter_config = RoadConfig(
                variant="road_1",
                target_modules=["q_proj", "v_proj"],
                task_type="CAUSAL_LM",
            )
            model = get_peft_model(prepare_model_for_kbit_training(quantized), adapter_config)

            raw_dataset = load_dataset("ybelkada/english_quotes_copy")
            tokenized = raw_dataset.map(lambda batch: tokenizer(batch["quote"]), batched=True)

            trainer = Trainer(
                model=model,
                train_dataset=tokenized["train"],
                args=TrainingArguments(
                    output_dir=tmp_dir,
                    per_device_train_batch_size=4,
                    gradient_accumulation_steps=4,
                    warmup_steps=2,
                    max_steps=3,
                    learning_rate=1e-3,
                    fp16=True,
                    logging_steps=1,
                ),
                data_collator=DataCollatorForLanguageModeling(tokenizer, mlm=False),
            )
            model.config.use_cache = False
            trainer.train()
            model.cpu().save_pretrained(tmp_dir)

            checkpoint_files = set(os.listdir(tmp_dir))
            assert {"adapter_config.json", SAFETENSORS_WEIGHTS_NAME} <= checkpoint_files
            # sanity check: a training loss was recorded
            assert trainer.state.log_history[-1]["train_loss"] is not None
    @pytest.mark.multi_gpu_tests
    @require_torch_multi_gpu
    def test_causal_lm_training_multi_gpu_8bit_road(self):
        r"""
        Same as test_causal_lm_training_multi_gpu but with RoAd.

        Loads the 8-bit model split across GPUs according to `DEVICE_MAP_MAP`, trains a RoAd adapter for a few steps,
        then checks that the adapter files were saved and that a training loss was logged. Skipped (rather than
        failing on the device assertions) when fewer than two GPUs are available.
        """
        with tempfile.TemporaryDirectory() as tmp_dir:
            model = AutoModelForCausalLM.from_pretrained(
                self.causal_lm_model_id,
                device_map=DEVICE_MAP_MAP[self.causal_lm_model_id],
                quantization_config=BitsAndBytesConfig(load_in_8bit=True),
            )
            # the model must really be spread over every visible device
            assert set(model.hf_device_map.values()) == set(range(device_count))
            assert {p.device.index for p in model.parameters()} == set(range(device_count))
            model = prepare_model_for_kbit_training(model)
            setattr(model, "model_parallel", True)
            setattr(model, "is_parallelizable", True)
            config = RoadConfig(
                variant="road_1",
                target_modules=["q_proj", "v_proj"],
                task_type="CAUSAL_LM",
            )
            model = get_peft_model(model, config)
            # use the file's shared dataset helper, like the other training tests, instead of a separate hub dataset
            data = load_dataset_english_quotes()
            data = data.map(lambda samples: self.tokenizer(samples["quote"]), batched=True)
            trainer = Trainer(
                model=model,
                train_dataset=data["train"],
                args=TrainingArguments(
                    per_device_train_batch_size=4,
                    gradient_accumulation_steps=4,
                    warmup_steps=2,
                    max_steps=3,
                    learning_rate=1e-3,
                    fp16=True,
                    logging_steps=1,
                    output_dir=tmp_dir,
                ),
                data_collator=DataCollatorForLanguageModeling(self.tokenizer, mlm=False),
            )
            model.config.use_cache = False
            trainer.train()
            model.cpu().save_pretrained(tmp_dir)
            assert "adapter_config.json" in os.listdir(tmp_dir)
            assert SAFETENSORS_WEIGHTS_NAME in os.listdir(tmp_dir)
            # assert loss is not None
            assert trainer.state.log_history[-1]["train_loss"] is not None
    @pytest.mark.multi_gpu_tests
    @require_torch_multi_gpu
    def test_causal_lm_training_multi_gpu_4bit_road(self):
        r"""
        Same as test_causal_lm_training_multi_gpu_4bit but with RoAd.

        Loads the 4-bit model split across GPUs according to `DEVICE_MAP_MAP`, trains a RoAd adapter for a few steps,
        then checks that the adapter files were saved and that a training loss was logged. Skipped (rather than
        failing on the device assertions) when fewer than two GPUs are available.
        """
        with tempfile.TemporaryDirectory() as tmp_dir:
            model = AutoModelForCausalLM.from_pretrained(
                self.causal_lm_model_id,
                device_map=DEVICE_MAP_MAP[self.causal_lm_model_id],
                quantization_config=BitsAndBytesConfig(load_in_4bit=True),
            )
            # the model must really be spread over every visible device
            assert set(model.hf_device_map.values()) == set(range(device_count))
            assert {p.device.index for p in model.parameters()} == set(range(device_count))
            model = prepare_model_for_kbit_training(model)
            setattr(model, "model_parallel", True)
            setattr(model, "is_parallelizable", True)
            config = RoadConfig(
                variant="road_1",
                target_modules=["q_proj", "v_proj"],
                task_type="CAUSAL_LM",
            )
            model = get_peft_model(model, config)
            # use the file's shared dataset helper, like the other training tests, instead of a separate hub dataset
            data = load_dataset_english_quotes()
            data = data.map(lambda samples: self.tokenizer(samples["quote"]), batched=True)
            trainer = Trainer(
                model=model,
                train_dataset=data["train"],
                args=TrainingArguments(
                    per_device_train_batch_size=4,
                    gradient_accumulation_steps=4,
                    warmup_steps=2,
                    max_steps=3,
                    learning_rate=1e-3,
                    fp16=True,
                    logging_steps=1,
                    output_dir=tmp_dir,
                ),
                data_collator=DataCollatorForLanguageModeling(self.tokenizer, mlm=False),
            )
            model.config.use_cache = False
            trainer.train()
            model.cpu().save_pretrained(tmp_dir)
            assert "adapter_config.json" in os.listdir(tmp_dir)
            assert SAFETENSORS_WEIGHTS_NAME in os.listdir(tmp_dir)
            # assert loss is not None
            assert trainer.state.log_history[-1]["train_loss"] is not None
    @pytest.mark.single_gpu_tests
    def test_causal_lm_training_lora_resize_embeddings_trainable_tokens(self):
        r"""
        Test LoRA with trainable tokens on a resized embedding matrix.

        Two new special tokens are added to the tokenizer and the embedding matrix is resized to fit them. Only those two
        embedding rows are trained via `trainable_token_indices`, alongside LoRA on `q_proj`/`v_proj`. After a few
        training steps, the test checks two things:
        - the trained token rows differ from the original embedding rows;
        - the saved checkpoint is smaller than a full fp32 copy of the embedding matrix.
        """
        with tempfile.TemporaryDirectory() as tmp_dir:
            bnb_config = BitsAndBytesConfig(
                load_in_4bit=True,
                bnb_4bit_quant_type="nf4",
                bnb_4bit_compute_dtype=torch.float16,
                bnb_4bit_quant_storage=torch.float16,
                bnb_4bit_use_double_quant=True,
            )
            model = AutoModelForCausalLM.from_pretrained(
                self.causal_lm_model_id,
                quantization_config=bnb_config,
                device_map="auto",
            )
            # add 2 new tokens; they must be non-empty, distinct strings that are not yet in the vocabulary, otherwise
            # the vocab lookup below fails or the "new" rows are really pre-existing ones
            tokenizer = AutoTokenizer.from_pretrained(self.causal_lm_model_id)
            new_tokens = ["<think>", "</think>"]
            tokenizer.add_special_tokens({"additional_special_tokens": new_tokens})
            trainable_token_indices = [tokenizer.vocab[token] for token in new_tokens]
            cur_emb_size = model.model.decoder.embed_tokens.weight.shape[0]
            # `tokenizer.vocab_size` counts only the base vocabulary; `len(tokenizer)` also counts the added tokens, so
            # this guarantees a row for every new token id
            model.resize_token_embeddings(max(len(tokenizer), cur_emb_size))
            model = prepare_model_for_kbit_training(model)
            config = LoraConfig(
                r=16,
                lora_alpha=32,
                target_modules=["q_proj", "v_proj"],
                lora_dropout=0.05,
                bias="none",
                task_type="CAUSAL_LM",
                trainable_token_indices={"embed_tokens": trainable_token_indices},
            )
            model = get_peft_model(model, config)
            data = load_dataset_english_quotes()
            start_token, end_token = new_tokens

            def tokenize(samples):
                # wrap every quote in the new tokens so that their embedding rows actually receive gradients
                texts = [f"{start_token}{row}{end_token}" for row in samples["quote"]]
                return tokenizer(texts)

            data = data.map(tokenize, batched=True)
            trainer = Trainer(
                model=model,
                train_dataset=data["train"],
                args=TrainingArguments(
                    per_device_train_batch_size=4,
                    gradient_accumulation_steps=4,
                    warmup_steps=2,
                    max_steps=3,
                    # higher learning rate, as embeddings are a bit slow to update
                    learning_rate=1e-3,
                    fp16=True,
                    logging_steps=1,
                    output_dir=tmp_dir,
                ),
                data_collator=DataCollatorForLanguageModeling(tokenizer, mlm=False),
            )
            model.config.use_cache = False
            trainer.train()
            model.cpu().save_pretrained(tmp_dir)
            # ensure that the new trainable tokens have been updated
            embedding = model.base_model.model.model.decoder.embed_tokens
            tol = 1e-4
            assert not torch.allclose(
                embedding.token_adapter.trainable_tokens_delta["default"],
                embedding.original_module.weight[trainable_token_indices],
                atol=tol,
                rtol=tol,
            )
            # check size of the checkpoint, should be small since the embedding matrix does not need to be stored
            stat = os.stat(os.path.join(tmp_dir, SAFETENSORS_WEIGHTS_NAME))
            embed_params = model.base_model.model.model.decoder.embed_tokens.original_module.weight.numel()
            # fp32 -> 4x
            emb_file_size = 4 * embed_params
            assert stat.st_size < emb_file_size
            # sanity check: assert loss is not None
            assert trainer.state.log_history[-1]["train_loss"] is not None
@require_torch_gpu
@require_auto_gptq
@require_optimum
class PeftGPTQGPUTests(unittest.TestCase):
    r"""
    GPTQ + peft tests
    """

    def setUp(self):
        from transformers import GPTQConfig

        self.causal_lm_model_id = "marcsun13/opt-350m-gptq-4bit"
        # TODO : check if it works for Exllamav2 kernels
        self.quantization_config = GPTQConfig(bits=4, use_exllama=False)
        self.tokenizer = AutoTokenizer.from_pretrained(self.causal_lm_model_id)

    def tearDown(self):
        r"""
        Efficient mechanism to free GPU memory after each test. Based on
        https://github.com/huggingface/transformers/issues/21094
        """
        clear_device_cache(garbage_collection=True)

    def _check_inference_finite(self, model, batch):
        """Run one eval-mode forward pass on `batch` (without the Trainer) and assert that all logits are finite.

        The model's previous train/eval mode is restored afterwards, even when the forward pass or the assertion
        raises.
        """
        was_training = model.training
        model.eval()
        try:
            output = model(**batch.to(model.device))
            assert torch.isfinite(output.logits).all()
        finally:
            model.train(was_training)

    @pytest.mark.single_gpu_tests
    def test_causal_lm_training(self):
        r"""
        Test the CausalLM training on a single GPU device. The test would simply fail if the adapters are not set
        correctly.
        """
        with tempfile.TemporaryDirectory() as tmp_dir:
            model = AutoModelForCausalLM.from_pretrained(
                self.causal_lm_model_id,
                torch_dtype=torch.float16,
                device_map="auto",
                quantization_config=self.quantization_config,
            )
            model = prepare_model_for_kbit_training(model)
            config = LoraConfig(
                r=16,
                lora_alpha=32,
                target_modules=["q_proj", "v_proj"],
                lora_dropout=0.05,
                bias="none",
                task_type="CAUSAL_LM",
            )
            model = get_peft_model(model, config)
            data = load_dataset_english_quotes()
            data = data.map(lambda samples: self.tokenizer(samples["quote"]), batched=True)
            trainer = Trainer(
                model=model,
                train_dataset=data["train"],
                args=TrainingArguments(
                    per_device_train_batch_size=4,
                    gradient_accumulation_steps=4,
                    warmup_steps=2,
                    max_steps=3,
                    learning_rate=2e-4,
                    fp16=True,
                    logging_steps=1,
                    output_dir=tmp_dir,
                ),
                data_collator=DataCollatorForLanguageModeling(self.tokenizer, mlm=False),
            )
            model.config.use_cache = False
            trainer.train()
            model.cpu().save_pretrained(tmp_dir)
            assert "adapter_config.json" in os.listdir(tmp_dir)
            assert SAFETENSORS_WEIGHTS_NAME in os.listdir(tmp_dir)
            # assert loss is not None
            assert trainer.state.log_history[-1]["train_loss"] is not None

    @pytest.mark.single_gpu_tests
    def test_adalora_causalLM(self):
        r"""
        Tests the gptq training with adalora
        """
        model = AutoModelForCausalLM.from_pretrained(
            self.causal_lm_model_id,
            torch_dtype=torch.float16,
            device_map="auto",
            quantization_config=self.quantization_config,
        )
        model = prepare_model_for_kbit_training(model)
        peft_config = AdaLoraConfig(
            init_r=6,
            target_r=4,
            tinit=2,
            tfinal=2,
            total_step=6,
            deltaT=5,
            beta1=0.3,
            beta2=0.3,
            orth_reg_weight=0.2,
            lora_alpha=32,
            lora_dropout=0.05,
            bias="none",
            task_type="CAUSAL_LM",
        )
        model = get_peft_model(model, peft_config)
        data = load_dataset_english_quotes()
        data = data.map(lambda samples: self.tokenizer(samples["quote"]), batched=True)
        # reuse the tokenizer loaded in setUp instead of loading the same one a second time
        batch = self.tokenizer(data["train"][:3]["quote"], return_tensors="pt", padding=True)
        self._check_inference_finite(model, batch)

        class OptimizerStepCallback(TrainerCallback):
            # AdaLoRA must reallocate its rank budget after every optimizer step
            def on_optimizer_step(self, args, state, control, **kwargs):
                model.update_and_allocate(state.global_step)

        step_callback = OptimizerStepCallback()
        with tempfile.TemporaryDirectory() as tmp_dir:
            trainer = Trainer(
                model=model,
                train_dataset=data["train"],
                args=TrainingArguments(
                    per_device_train_batch_size=4,
                    gradient_accumulation_steps=4,
                    warmup_steps=2,
                    max_steps=6,
                    learning_rate=2e-4,
                    fp16=True,
                    logging_steps=1,
                    output_dir=tmp_dir,
                ),
                data_collator=DataCollatorForLanguageModeling(self.tokenizer, mlm=False),
            )
            trainer.add_callback(step_callback)
            model.config.use_cache = False
            trainer.train()
            model.cpu().save_pretrained(tmp_dir)
            assert "adapter_config.json" in os.listdir(tmp_dir)
            assert SAFETENSORS_WEIGHTS_NAME in os.listdir(tmp_dir)
            # assert loss is not None
            assert trainer.state.log_history[-1]["train_loss"] is not None

    @pytest.mark.single_gpu_tests
    def test_causal_lm_training_gptq_qalora(self):
        """
        Test QALoRA with GPTQ quantization. The test would simply fail if the adapters are not set correctly.
        """
        with tempfile.TemporaryDirectory() as tmp_dir:
            model = AutoModelForCausalLM.from_pretrained(
                self.causal_lm_model_id,
                torch_dtype=torch.float16,
                device_map="auto",
                quantization_config=self.quantization_config,
            )
            model = prepare_model_for_kbit_training(model)
            config = LoraConfig(
                r=16,
                lora_alpha=32,
                target_modules=["q_proj", "v_proj"],
                lora_dropout=0.05,
                bias="none",
                task_type="CAUSAL_LM",
                use_qalora=True,
                qalora_group_size=32,
            )
            model = get_peft_model(model, config)
            data = load_dataset_english_quotes()
            data = data.map(lambda samples: self.tokenizer(samples["quote"]), batched=True)
            trainer = Trainer(
                model=model,
                train_dataset=data["train"],
                args=TrainingArguments(
                    per_device_train_batch_size=4,
                    gradient_accumulation_steps=4,
                    warmup_steps=2,
                    max_steps=3,
                    learning_rate=2e-4,
                    fp16=True,
                    logging_steps=1,
                    output_dir=tmp_dir,
                ),
                data_collator=DataCollatorForLanguageModeling(self.tokenizer, mlm=False),
            )
            model.config.use_cache = False
            trainer.train()
            model.cpu().save_pretrained(tmp_dir)
            assert "adapter_config.json" in os.listdir(tmp_dir)
            assert SAFETENSORS_WEIGHTS_NAME in os.listdir(tmp_dir)
            # assert loss is not None
            assert trainer.state.log_history[-1]["train_loss"] is not None

    @pytest.mark.multi_gpu_tests
    @require_torch_multi_gpu
    def test_causal_lm_training_multi_gpu(self):
        r"""
        Test the CausalLM training on a multi-GPU device. The test would simply fail if the adapters are not set
        correctly.
        """
        from transformers import AutoConfig

        # Embeddings, projections and the first half of the decoder layers go to GPU 0, the rest to GPU 1. The layer
        # count comes from the model config so that every decoder layer is assigned a device; a hand-written list that
        # stops short of the last layer would leave some parameters without any device.
        num_layers = AutoConfig.from_pretrained(self.causal_lm_model_id).num_hidden_layers
        split = num_layers // 2
        device_map = {
            "model.decoder.embed_tokens": 0,
            "lm_head": 0,
            "model.decoder.embed_positions": 0,
            "model.decoder.project_out": 0,
            "model.decoder.project_in": 0,
            **{f"model.decoder.layers.{idx}": 0 if idx < split else 1 for idx in range(num_layers)},
            "model.decoder.final_layer_norm": 1,
        }
        with tempfile.TemporaryDirectory() as tmp_dir:
            model = AutoModelForCausalLM.from_pretrained(
                self.causal_lm_model_id,
                torch_dtype=torch.float16,
                device_map=device_map,
                quantization_config=self.quantization_config,
            )
            assert set(model.hf_device_map.values()) == set(range(device_count))
            assert {p.device.index for p in model.parameters()} == set(range(device_count))
            model = prepare_model_for_kbit_training(model)
            setattr(model, "model_parallel", True)
            setattr(model, "is_parallelizable", True)
            config = LoraConfig(
                r=16,
                lora_alpha=32,
                target_modules=["q_proj", "v_proj"],
                lora_dropout=0.05,
                bias="none",
                task_type="CAUSAL_LM",
            )
            model = get_peft_model(model, config)
            data = load_dataset_english_quotes()
            data = data.map(lambda samples: self.tokenizer(samples["quote"]), batched=True)
            trainer = Trainer(
                model=model,
                train_dataset=data["train"],
                args=TrainingArguments(
                    per_device_train_batch_size=4,
                    gradient_accumulation_steps=4,
                    warmup_steps=2,
                    max_steps=3,
                    learning_rate=2e-4,
                    fp16=True,
                    logging_steps=1,
                    output_dir=tmp_dir,
                ),
                data_collator=DataCollatorForLanguageModeling(self.tokenizer, mlm=False),
            )
            model.config.use_cache = False
            trainer.train()
            model.cpu().save_pretrained(tmp_dir)
            assert "adapter_config.json" in os.listdir(tmp_dir)
            assert SAFETENSORS_WEIGHTS_NAME in os.listdir(tmp_dir)
            # assert loss is not None
            assert trainer.state.log_history[-1]["train_loss"] is not None

    @pytest.mark.single_gpu_tests
    def test_non_default_adapter_name(self):
        # See issue 1346
        config = LoraConfig(
            r=16,
            target_modules=["q_proj", "v_proj"],
            task_type="CAUSAL_LM",
        )
        # default adapter name
        model = AutoModelForCausalLM.from_pretrained(
            self.causal_lm_model_id,
            torch_dtype=torch.float16,
            device_map="auto",
            quantization_config=self.quantization_config,
        )
        model = prepare_model_for_kbit_training(model)
        model = get_peft_model(model, config)
        n_trainable_default, n_total_default = model.get_nb_trainable_parameters()
        # other adapter name
        model = AutoModelForCausalLM.from_pretrained(
            self.causal_lm_model_id,
            torch_dtype=torch.float16,
            device_map="auto",
            quantization_config=self.quantization_config,
        )
        model = prepare_model_for_kbit_training(model)
        model = get_peft_model(model, config, adapter_name="other")
        n_trainable_other, n_total_other = model.get_nb_trainable_parameters()
        assert n_trainable_other > 0
        # sanity check
        assert n_trainable_default == n_trainable_other
        assert n_total_default == n_total_other
@require_non_cpu
class OffloadSaveTests(unittest.TestCase):
    r"""
    LoRA tests on a GPT-2 model whose modules are partly offloaded to the CPU and to disk.
    """

    def setUp(self):
        self.causal_lm_model_id = "gpt2"

    def tearDown(self):
        r"""
        Release accelerator memory after every test (approach taken from
        https://github.com/huggingface/transformers/issues/21094).
        """
        clear_device_cache(garbage_collection=True)

    def test_offload_load(self):
        r"""
        A LoRA adapter loaded on top of a CPU/disk-offloaded base model must give the same logits as the same adapter
        loaded on top of a base model that lives entirely on the CPU.
        """
        model_id = self.causal_lm_model_id
        torch.manual_seed(0)
        base = AutoModelForCausalLM.from_pretrained(model_id)
        tokenizer = AutoTokenizer.from_pretrained(model_id)
        memory_limits = {"cpu": "0.4GIB"}  # no "disk" for PeftModel.from_pretrained() compatibility
        # modules that do not fit into the CPU budget are placed on disk
        device_map = infer_auto_device_map(base, max_memory=memory_limits)
        placements = set(device_map.values())
        assert "cpu" in placements
        assert "disk" in placements

        lora_config = LoraConfig(task_type="CAUSAL_LM", init_lora_weights=False, target_modules=["c_attn"])
        peft_model = get_peft_model(base, lora_config)
        with tempfile.TemporaryDirectory() as adapter_dir:
            peft_model.save_pretrained(adapter_dir)

            # reference run: every module on the CPU
            cpu_base = AutoModelForCausalLM.from_pretrained(model_id, device_map="cpu")
            reference_model = PeftModel.from_pretrained(cpu_base, adapter_dir).eval()
            input_ids = tokenizer.encode("Four score and seven years ago", return_tensors="pt")
            expected = reference_model(input_ids)[0]

            # same adapter, but the base model is loaded with the offloading device map
            offloaded_base = AutoModelForCausalLM.from_pretrained(model_id, device_map=device_map)
            assert len({p.device for p in offloaded_base.parameters()}) == 2  # 'cpu' and 'meta'
            offloaded_model = PeftModel.from_pretrained(offloaded_base, adapter_dir, max_memory=memory_limits).eval()
            actual = offloaded_model(input_ids)[0]
        assert torch.allclose(expected, actual, atol=1e-5)

    @pytest.mark.single_gpu_tests
    def test_offload_merge(self):
        r"""
        Merging, unmerging, and merge-and-unload of a LoRA adapter on a model whose modules are spread over GPU, CPU,
        and disk must leave the model output unchanged.
        """
        model_id = self.causal_lm_model_id
        torch.manual_seed(0)
        base = AutoModelForCausalLM.from_pretrained(model_id)
        tokenizer = AutoTokenizer.from_pretrained(model_id)
        memory_limits = {0: "0.2GIB", "cpu": "0.2GIB"}  # no "disk" for PeftModel.from_pretrained() compatibility
        # roughly half of the transformer modules end up offloaded
        device_map = infer_auto_device_map(base, max_memory=memory_limits)
        placements = set(device_map.values())
        for expected_placement in (0, "cpu", "disk"):
            assert expected_placement in placements

        lora_config = LoraConfig(task_type="CAUSAL_LM", init_lora_weights=False, target_modules=["c_attn"])
        peft_model = get_peft_model(base, lora_config)
        with tempfile.TemporaryDirectory() as adapter_dir:
            peft_model.save_pretrained(adapter_dir)
            # reload the base model with the offloading device map
            offloaded_base = AutoModelForCausalLM.from_pretrained(model_id, device_map=device_map).eval()
            assert len({p.device for p in offloaded_base.parameters()}) == 2
            model = PeftModel.from_pretrained(offloaded_base, adapter_dir, max_memory=memory_limits)

        input_ids = tokenizer.encode("Four score and seven years ago", return_tensors="pt")
        model.eval()
        reference = model(input_ids)[0]

        # merging the adapter into the base weights must not change the output ...
        model.merge_adapter()
        assert torch.allclose(model(input_ids)[0], reference)

        # ... neither must undoing the merge ...
        model.unmerge_adapter()
        assert torch.allclose(model(input_ids)[0], reference)

        # ... nor merging and then stripping the LoRA layers entirely
        unloaded = model.merge_and_unload()
        assert torch.allclose(unloaded(input_ids)[0], reference)
@pytest.mark.skipif(not (torch.cuda.is_available() or is_xpu_available()), reason="test requires a GPU or XPU")
@pytest.mark.single_gpu_tests
class TestPiSSA:
    r"""
    Tests for PiSSA to ensure that it reduces the quantization error compared to normal LoRA quantization.
    """

    # The error factor indicates by how much the quantization error should be decreased when using PiSSA compared to
    # quantization without PiSSA. Thus 1.03 means that the error should be decreased by 3% at least. This is a very
    # conservative value to prevent flakiness, in practice most gains are > 1.5
    error_factor = 1.03

    def quantize_model(self, model, num_bits=4, device="cuda"):
        """Fake-quantize the weights of every Linear/Conv1D layer (except `lm_head`) of `model` in place.

        Each weight is quantized to `num_bits` with `NFQuantizer` and immediately dequantized again, so the layers
        keep full-precision tensors that carry the quantization error. Returns the same (mutated) model.
        """
        quantizer = NFQuantizer(num_bits=num_bits, device=device, method="normal", block_size=64)
        for name, module in model.named_modules():
            if isinstance(module, (torch.nn.Linear, Conv1D)) and "lm_head" not in name:
                quantized_weight, max_abs, shape = quantizer.quantize_block(module.weight.data.to(device))
                module.weight.data = quantizer.dequantize_block(quantized_weight, max_abs, shape)
        return model

    def nuclear_norm(self, base_model, quantized_model):
        """Return the summed nuclear norm of the weight differences between `base_model` and `quantized_model`.

        For every Linear/Conv1D layer of `base_model` (except `lm_head`), the singular values of
        `W_base - W_quantized` are summed; the per-layer values are then summed into a single scalar tensor.
        """
        error_list = []
        for name, module in base_model.named_modules():
            if isinstance(module, (torch.nn.Linear, Conv1D)) and "lm_head" not in name:
                quant_module = quantized_model.get_submodule(name)
                error_list.append(torch.linalg.svdvals(module.weight.data - quant_module.weight.data).sum())
        return torch.Tensor(error_list).sum()

    def get_errors(
        self,
        tmp_path,
        bits=4,
        device="cuda",
        model_id="hf-internal-testing/tiny-random-BloomForCausalLM",
    ):
        """Assert that PiSSA + quantization yields a smaller weight error than plain LoRA + quantization.

        Both errors are measured with `nuclear_norm` against the unquantized base model. The PiSSA error must be
        smaller than the plain LoRA error divided by `self.error_factor`. Returns nothing; failures surface as
        assertion errors.
        """
        cls = AutoModelForSeq2SeqLM if "t5" in str(model_id) else AutoModelForCausalLM
        base_model = cls.from_pretrained(model_id).eval().to(device)
        task_type = TaskType.SEQ_2_SEQ_LM if base_model.config.is_encoder_decoder else TaskType.CAUSAL_LM

        # error of the normally quantized LoRA model (LoRA is a no-op before training, so this is the plain
        # quantization error)
        target_modules = "all-linear" if task_type != TaskType.SEQ_2_SEQ_LM else ["o", "k", "wi", "q", "v"]
        lora_config = LoraConfig(task_type=task_type, target_modules=target_modules)
        qlora_model = self.quantize_model(cls.from_pretrained(model_id).eval().to(device), bits, device)
        qlora_model = get_peft_model(
            qlora_model,
            lora_config,
        )
        qlora_model = qlora_model.merge_and_unload()
        qlora_error = self.nuclear_norm(base_model, qlora_model)
        del qlora_model
        clear_device_cache(garbage_collection=True)

        # error of the quantized model using PiSSA
        lora_config = LoraConfig(
            task_type=task_type,
            init_lora_weights="pissa",
            target_modules=target_modules,
        )
        pissa_model = cls.from_pretrained(model_id).eval().to(device)
        pissa_model = get_peft_model(pissa_model, lora_config)

        # save LoRA weights, they should be initialized such that they minimize the quantization error; the init
        # method is reset to True so that reloading the adapter does not re-run PiSSA initialization
        pissa_model.base_model.peft_config["default"].init_lora_weights = True
        pissa_model.save_pretrained(tmp_path / "pissa_model")

        # the unloaded base model holds the residual weights left after PiSSA initialization
        pissa_model = pissa_model.unload()
        pissa_model.save_pretrained(tmp_path / "residual_model")

        del pissa_model
        clear_device_cache(garbage_collection=True)

        # now load quantized model and apply PiSSA-initialized weights on top
        qpissa_model = self.quantize_model(
            cls.from_pretrained(tmp_path / "residual_model").eval().to(device), bits, device
        )
        qpissa_model = PeftModel.from_pretrained(qpissa_model, tmp_path / "pissa_model")
        qpissa_model = qpissa_model.merge_and_unload()
        qpissa_error = self.nuclear_norm(base_model, qpissa_model)
        del qpissa_model
        clear_device_cache(garbage_collection=True)

        assert qlora_error > 0.0
        assert qpissa_error > 0.0

        # next, check that PiSSA quantization errors are smaller than LoRA errors by a certain margin
        assert qpissa_error < (qlora_error / self.error_factor)

    @pytest.mark.parametrize("device", [torch_device, "cpu"])
    def test_bloomz_pissa_4bit(self, device, tmp_path):
        # In this test, we compare the weights of the base model, the quantized LoRA model, and the quantized model
        # using PiSSA. When quantizing, we expect a certain level of error. However, we expect the PiSSA quantized
        # model to have less error than the normal LoRA quantized model. Note that when using normal LoRA, the
        # quantization error is simply the error from quantization without LoRA, as LoRA is a no-op before training.
        # We still apply LoRA for the test for consistency.
        self.get_errors(bits=4, device=device, tmp_path=tmp_path)

    @pytest.mark.parametrize("device", [torch_device, "cpu"])
    def test_bloomz_pissa_8bit(self, device, tmp_path):
        # Same test as test_bloomz_pissa_4bit but with 8 bits.
        self.get_errors(bits=8, device=device, tmp_path=tmp_path)

    @pytest.mark.parametrize("device", [torch_device, "cpu"])
    def test_t5_pissa_4bit(self, device, tmp_path):
        self.get_errors(bits=4, device=device, model_id="t5-small", tmp_path=tmp_path)

    @pytest.mark.parametrize("device", [torch_device, "cpu"])
    def test_t5_pissa_8bit(self, device, tmp_path):
        self.get_errors(bits=8, device=device, model_id="t5-small", tmp_path=tmp_path)

    @pytest.mark.parametrize("device", [torch_device, "cpu"])
    def test_gpt2_pissa_4bit(self, device, tmp_path):
        # see issue 2104 (gpt2 uses Conv1D layers, which quantize_model/nuclear_norm also handle)
        self.get_errors(bits=4, device=device, model_id="gpt2", tmp_path=tmp_path)

    @pytest.mark.parametrize("device", [torch_device, "cpu"])
    def test_gpt2_pissa_8bit(self, device, tmp_path):
        # see issue 2104 (gpt2 uses Conv1D layers, which quantize_model/nuclear_norm also handle)
        self.get_errors(bits=8, device=device, model_id="gpt2", tmp_path=tmp_path)

    @require_bitsandbytes
    def test_lora_pissa_conversion_same_output_after_loading_with_quantization(self, tmp_path):
        # A copy of the test `test_lora_pissa_conversion_same_output_after_loading` in peft/tests/test_initialization.py,
        # that would fail if bitsandbytes quantization is used because Quant(W_res) + AB != Quant(W) + \Delta(AB).
        import bitsandbytes as bnb

        torch.manual_seed(0)
        data = torch.rand(10, 1000).to(torch_device)

        class MyModule(torch.nn.Module):
            def __init__(self):
                super().__init__()
                # choose a large weight so that averages are close to expected values
                self.linear = torch.nn.Linear(1000, 1000)
                self.embed = torch.nn.Embedding(1000, 1000)
                self.conv2d = torch.nn.Conv2d(100, 100, 3)

            def forward(self, x):
                x_int = (100 * x).int()
                x_4d = x.flatten().reshape(1, 100, 10, 10)
                return self.linear(x), self.embed(x_int), self.conv2d(x_4d)

        model = MyModule().to(torch_device)
        output_base = model(data)[0]

        config = LoraConfig(init_lora_weights="pissa", target_modules=["linear"], r=8)
        peft_model = get_peft_model(deepcopy(model), config)
        # save the initial model
        peft_model.peft_config["default"].init_lora_weights = True
        peft_model.save_pretrained(tmp_path / "init-model")
        peft_model = peft_model.unload()
        torch.save(peft_model.state_dict(), tmp_path / "residual-model")
        del peft_model

        # create 4bit base model
        base_model = deepcopy(model)
        base_model.load_state_dict(torch.load(tmp_path / "residual-model"))
        # sanity check: the base model weights were indeed changed
        tol = 1e-06
        assert not torch.allclose(model.linear.weight, base_model.linear.weight, atol=tol, rtol=tol)
        # quantize the linear layer; use torch_device rather than a bare integer index, because `.to(0)` always
        # means CUDA device 0 and would fail on XPU machines, which this class explicitly allows
        linear4bit = bnb.nn.Linear4bit(base_model.linear.in_features, base_model.linear.out_features)
        linear4bit.load_state_dict(base_model.linear.state_dict())
        linear4bit.to(torch_device)
        base_model.linear = linear4bit
        peft_model = PeftModel.from_pretrained(deepcopy(base_model), tmp_path / "init-model")
        output_quantized_pissa = peft_model(data)[0]
        # sanity check
        tol = 1e-06
        assert not torch.allclose(output_base, output_quantized_pissa, atol=tol, rtol=tol)

        # modify the weights, or else the adapter performs an identity transformation
        peft_model.base_model.linear.lora_B["default"].weight.data *= 2.0
        output_finetuned_pissa = peft_model(data)[0]
        # sanity check
        tol = 1e-06
        assert not torch.allclose(output_quantized_pissa, output_finetuned_pissa, atol=tol, rtol=tol)

        # save the model normally
        peft_model.save_pretrained(tmp_path / "pissa-model")
        model_loaded = PeftModel.from_pretrained(deepcopy(base_model), tmp_path / "pissa-model")
        output_loaded = model_loaded(data)[0]

        assert torch.allclose(output_finetuned_pissa, output_loaded, atol=tol, rtol=tol)
        # sanity check: ranks should still be 8 as initially
        assert model_loaded.peft_config["default"].r == 8
        assert model_loaded.base_model.model.linear.lora_A["default"].weight.shape[0] == 8

        # save the model with conversion
        peft_model.save_pretrained(
            tmp_path / "pissa-model-converted", path_initial_model_for_weight_conversion=tmp_path / "init-model"
        )
        model_converted = PeftModel.from_pretrained(deepcopy(model), tmp_path / "pissa-model-converted")
        output_converted = model_converted(data)[0]

        # rank should be double of what it was initially
        assert model_converted.peft_config["default"].r == 16
        assert model_converted.base_model.model.linear.lora_A["default"].weight.shape[0] == 16
        # base model weights should be the same as the initial model
        assert torch.allclose(
            model.linear.weight, model_converted.base_model.model.linear.base_layer.weight, atol=tol, rtol=tol
        )
        # This check is expected to fail when using bnb
        assert not torch.allclose(output_finetuned_pissa, output_converted, atol=tol, rtol=tol)
@pytest.mark.skipif(not (torch.cuda.is_available() or is_xpu_available()), reason="test requires a GPU or XPU")
@pytest.mark.single_gpu_tests
class TestOLoRA:
    r"""
    Tests for OLoRA to ensure that it reduces the quantization error compared to normal LoRA quantization.
    """

    # The error factor indicates by how much the quantization error should be decreased when using OLoRA compared to
    # quantization without OLoRA: the check in `get_errors` requires the plain LoRA error to be at least
    # `error_factor` times the OLoRA error, so 1.2 means the error must shrink by roughly 17% or more.
    error_factor = 1.2

    def quantize_model(self, model, num_bits=4, device="cuda"):
        """Fake-quantize the weights of every `torch.nn.Linear` layer (except `lm_head`) of `model` in place.

        Each weight is quantized to `num_bits` with `NFQuantizer` and immediately dequantized again, so the layers
        keep full-precision tensors that carry the quantization error. Returns the same (mutated) model.
        """
        quantizer = NFQuantizer(num_bits=num_bits, device=device, method="normal", block_size=64)
        for name, module in model.named_modules():
            if isinstance(module, torch.nn.Linear) and "lm_head" not in name:
                quantized_weight, max_abs, shape = quantizer.quantize_block(module.weight.data.to(device))
                module.weight.data = quantizer.dequantize_block(quantized_weight, max_abs, shape)
        return model

    def nuclear_norm(self, base_model, quantized_model):
        """Return the summed nuclear norm of the weight differences between `base_model` and `quantized_model`.

        For every `torch.nn.Linear` layer of `base_model` (except `lm_head`), the singular values of
        `W_base - W_quantized` are summed; the per-layer values are then summed into a single scalar tensor.
        """
        error_list = []
        for name, module in base_model.named_modules():
            if isinstance(module, torch.nn.Linear) and "lm_head" not in name:
                quant_module = quantized_model.get_submodule(name)
                error_list.append(torch.linalg.svdvals(module.weight.data - quant_module.weight.data).sum())
        return torch.Tensor(error_list).sum()

    def get_errors(
        self,
        tmp_path,
        bits=4,
        device="cuda",
        model_id="hf-internal-testing/tiny-random-BloomForCausalLM",
    ):
        """Assert that OLoRA + quantization yields a smaller weight error than plain LoRA + quantization.

        Both errors are measured with `nuclear_norm` against the unquantized base model. The OLoRA error must be
        smaller than the plain LoRA error divided by `self.error_factor`. Returns nothing; failures surface as
        assertion errors.
        """
        cls = AutoModelForSeq2SeqLM if "t5" in str(model_id) else AutoModelForCausalLM
        base_model = cls.from_pretrained(model_id).eval().to(device)
        task_type = TaskType.SEQ_2_SEQ_LM if base_model.config.is_encoder_decoder else TaskType.CAUSAL_LM

        # error of the normally quantized LoRA model (LoRA is a no-op before training, so this is the plain
        # quantization error)
        target_modules = "all-linear" if task_type != TaskType.SEQ_2_SEQ_LM else ["o", "k", "wi", "q", "v"]
        lora_config = LoraConfig(task_type=task_type, target_modules=target_modules)
        qlora_model = self.quantize_model(cls.from_pretrained(model_id).eval().to(device), bits, device)
        qlora_model = get_peft_model(
            qlora_model,
            lora_config,
        )
        qlora_model = qlora_model.merge_and_unload()
        qlora_error = self.nuclear_norm(base_model, qlora_model)
        del qlora_model
        clear_device_cache(garbage_collection=True)

        # error of the quantized model using OLoRA
        lora_config = LoraConfig(
            task_type=task_type,
            init_lora_weights="olora",
            target_modules=target_modules,
        )
        olora_model = cls.from_pretrained(model_id).eval().to(device)
        olora_model = get_peft_model(olora_model, lora_config)

        # save LoRA weights, they should be initialized such that they minimize the quantization error; the init
        # method is reset to True so that reloading the adapter does not re-run OLoRA initialization
        olora_model.base_model.peft_config["default"].init_lora_weights = True
        olora_model.save_pretrained(tmp_path / "olora_model")

        # the unloaded base model holds the residual weights left after OLoRA initialization
        olora_model = olora_model.unload()
        olora_model.save_pretrained(tmp_path / "residual_model")

        del olora_model
        clear_device_cache(garbage_collection=True)

        # now load quantized model and apply OLoRA-initialized weights on top
        qolora_model = self.quantize_model(
            cls.from_pretrained(tmp_path / "residual_model").eval().to(device), bits, device
        )
        qolora_model = PeftModel.from_pretrained(qolora_model, tmp_path / "olora_model")
        qolora_model = qolora_model.merge_and_unload()
        qolora_error = self.nuclear_norm(base_model, qolora_model)
        del qolora_model
        clear_device_cache(garbage_collection=True)

        assert qlora_error > 0.0
        assert qolora_error > 0.0

        # next, check that OLoRA quantization errors are smaller than LoRA errors by a certain margin
        assert qolora_error < (qlora_error / self.error_factor)

    @pytest.mark.parametrize("device", [torch_device, "cpu"])
    def test_bloomz_olora_4bit(self, device, tmp_path):
        # In this test, we compare the weights of the base model, the quantized LoRA model, and the quantized model
        # using OLoRA. When quantizing, we expect a certain level of error. However, we expect the OLoRA quantized
        # model to have less error than the normal LoRA quantized model. Note that when using normal LoRA, the
        # quantization error is simply the error from quantization without LoRA, as LoRA is a no-op before training.
        # We still apply LoRA for the test for consistency.
        self.get_errors(bits=4, device=device, tmp_path=tmp_path)

    @pytest.mark.parametrize("device", [torch_device, "cpu"])
    def test_bloomz_olora_8bit(self, device, tmp_path):
        # Same test as test_bloomz_olora_4bit but with 8 bits.
        self.get_errors(bits=8, device=device, tmp_path=tmp_path)

    # bitsandbytes is imported and used for quantization below, so skip cleanly when it is not installed
    @require_bitsandbytes
    @pytest.mark.parametrize("bits", [4, 8])
    def test_olora_with_quantized_model(self, bits):
        import bitsandbytes as bnb

        # issue 1999
        model_id = "hf-internal-testing/tiny-random-OPTForCausalLM"
        if bits == 4:
            bnb_config = BitsAndBytesConfig(
                load_in_4bit=True,
                bnb_4bit_quant_type="nf4",
                bnb_4bit_compute_dtype=torch.float16,
                bnb_4bit_quant_storage=torch.float16,
                bnb_4bit_use_double_quant=True,
            )
        elif bits == 8:
            bnb_config = BitsAndBytesConfig(load_in_8bit=True)
        else:
            raise ValueError("bits must be 4 or 8")

        model = AutoModelForCausalLM.from_pretrained(model_id, quantization_config=bnb_config)
        model = prepare_model_for_kbit_training(model)
        config = LoraConfig(init_lora_weights="olora")
        model = get_peft_model(model, config)

        # check that the correct type is used for the weights
        base_layer = model.base_model.model.model.decoder.layers[0].self_attn.v_proj.base_layer.weight
        if bits == 4:
            assert isinstance(base_layer, bnb.nn.modules.Params4bit)
        else:
            assert isinstance(base_layer, bnb.nn.modules.Int8Params)

        inputs = torch.arange(10).unsqueeze(0).to(model.device)
        logits = model(inputs).logits  # does not raise
        assert torch.isfinite(logits).all()
@pytest.mark.skipif(
    not (torch.cuda.is_available() or is_xpu_available()), reason="test requires a hardware accelerator"
)
@pytest.mark.single_gpu_tests
@require_bitsandbytes
class TestLoftQ:
    r"""
    Tests for LoftQ to ensure that it reduces the quantization error compared to normal LoRA quantization.
    """

    def get_error_factor(self, device):
        # The error factor indicates by how much the quantization error should be decreased when using LoftQ compared to
        # quantization without LoftQ. Thus 1.03 means that the error should be decreased by 3% at least. This is a very
        # conservative value to prevent flakiness, in practice most gains are > 1.5
        error_factor = 1.005 if device in ("xpu", "cpu") else 1.03
        return error_factor

    def _assert_errors_reduced(self, mae_quantized, mse_quantized, mae_loftq, mse_loftq, error_factor):
        """Assert that all errors are positive and that LoftQ shrinks both MAE and MSE by at least `error_factor`."""
        # first, sanity check that all errors are > 0.0
        assert mae_quantized > 0.0
        assert mse_quantized > 0.0
        assert mae_loftq > 0.0
        assert mse_loftq > 0.0

        # next, check that LoftQ quantization errors are smaller than LoRA errors by a certain margin
        assert mse_loftq < (mse_quantized / error_factor)
        assert mae_loftq < (mae_quantized / error_factor)

    def get_input(self, model_id, device):
        """Tokenize a fixed prompt with the tokenizer of `model_id` and move the encoding to `device`."""
        tokenizer = AutoTokenizer.from_pretrained(model_id)
        inputs = tokenizer("All I want is", padding=True, return_tensors="pt")
        inputs = inputs.to(device)
        return inputs

    def get_base_model(self, model_id, device, **kwargs):
        """Load `model_id` in eval mode on `device`, as seq2seq for t5 ids and causal LM otherwise."""
        cls = AutoModelForSeq2SeqLM if "t5" in str(model_id) else AutoModelForCausalLM
        model = cls.from_pretrained(model_id, device_map=device, **kwargs).eval()
        return model

    def get_logits(self, model, inputs):
        """Run a forward pass and return the logits.

        Encoder-decoder models are fed `input_ids` as both encoder and decoder inputs.
        """
        if model.config.is_encoder_decoder:
            input_ids = inputs["input_ids"]
            return model(input_ids=input_ids, decoder_input_ids=input_ids).logits
        return model(**inputs).logits

    def get_errors(
        self,
        tmp_path,
        bits=4,
        loftq_iter=1,
        device="cuda",
        model_id="hf-internal-testing/tiny-random-BloomForCausalLM",
        use_dora=False,
    ):
        # Helper function that returns the quantization errors (MAE and MSE) when comparing the quantized LoRA model
        # to the base model, vs the LoftQ quantized model to the base model. We expect the LoftQ quantized model to
        # have less error than the normal LoRA quantized model. Since we compare logits, the observed error is
        # already somewhat dampened because of the softmax.
        # Returns (mae_quantized, mse_quantized, mae_loftq, mse_loftq).
        torch.manual_seed(0)
        model = self.get_base_model(model_id, device)
        task_type = TaskType.SEQ_2_SEQ_LM if model.config.is_encoder_decoder else TaskType.CAUSAL_LM
        inputs = self.get_input(model_id, device)
        # the base logits are the reference, we try to match those as closely as possible
        logits_base = self.get_logits(model, inputs)
        # clean up
        del model
        clear_device_cache(garbage_collection=True)

        # logits from the normal quantized LoRA model
        target_modules = "all-linear" if task_type != TaskType.SEQ_2_SEQ_LM else ["o", "k", "wi", "q", "v"]
        lora_config = LoraConfig(task_type=task_type, use_dora=use_dora, target_modules=target_modules)
        kwargs = {}
        if bits == 4:
            kwargs["quantization_config"] = BitsAndBytesConfig(load_in_4bit=True, bnb_4bit_quant_type="nf4")
        elif bits == 8:
            kwargs["quantization_config"] = BitsAndBytesConfig(load_in_8bit=True)
        else:
            raise ValueError("bits must be 4 or 8")

        quantized_model = get_peft_model(
            self.get_base_model(model_id, device, **kwargs),
            lora_config,
        )
        torch.manual_seed(0)
        logits_quantized = self.get_logits(quantized_model, inputs)
        del quantized_model
        clear_device_cache(garbage_collection=True)

        # logits from quantized LoRA model using LoftQ
        loftq_config = LoftQConfig(loftq_bits=bits, loftq_iter=loftq_iter)
        lora_config = LoraConfig(
            task_type=task_type,
            init_lora_weights="loftq",
            loftq_config=loftq_config,
            use_dora=use_dora,
            target_modules=target_modules,
        )
        model = self.get_base_model(model_id, device)
        if device != "cpu":
            model = model.to(device)
        loftq_model = get_peft_model(model, lora_config)
        if device != "cpu":
            loftq_model = loftq_model.to(device)

        # save LoRA weights, they should be initialized such that they minimize the quantization error; the init
        # method is reset to True so that reloading the adapter does not re-run LoftQ initialization
        loftq_model.base_model.peft_config["default"].init_lora_weights = True
        loftq_model.save_pretrained(tmp_path / "loftq_model")

        loftq_model = loftq_model.unload()
        loftq_model.save_pretrained(tmp_path / "base_model")

        del loftq_model
        clear_device_cache(garbage_collection=True)

        # now load quantized model and apply LoftQ-initialized weights on top
        base_model = self.get_base_model(tmp_path / "base_model", device=device, **kwargs, torch_dtype=torch.float32)
        loftq_model = PeftModel.from_pretrained(base_model, tmp_path / "loftq_model", is_trainable=True)

        # TODO sanity check: model is quantized

        torch.manual_seed(0)
        logits_loftq = self.get_logits(loftq_model, inputs)
        del loftq_model
        clear_device_cache(garbage_collection=True)

        mae_quantized = torch.abs(logits_base - logits_quantized).mean()
        mse_quantized = torch.pow(logits_base - logits_quantized, 2).mean()
        mae_loftq = torch.abs(logits_base - logits_loftq).mean()
        mse_loftq = torch.pow(logits_base - logits_loftq, 2).mean()
        return mae_quantized, mse_quantized, mae_loftq, mse_loftq

    @pytest.mark.parametrize("device", [torch_device, "cpu"])
    def test_bloomz_loftq_4bit(self, device, tmp_path):
        # In this test, we compare the logits of the base model, the quantized LoRA model, and the quantized model
        # using LoftQ. When quantizing, we expect a certain level of error. However, we expect the LoftQ quantized
        # model to have less error than the normal LoRA quantized model. Note that when using normal LoRA, the
        # quantization error is simply the error from quantization without LoRA, as LoRA is a no-op before training.
        # We still apply LoRA for the test for consistency.
        errors = self.get_errors(bits=4, device=device, tmp_path=tmp_path)
        # this class has no `error_factor` attribute; the factor is device dependent
        self._assert_errors_reduced(*errors, error_factor=self.get_error_factor(device))

    @pytest.mark.parametrize("device", [torch_device, "cpu"])
    def test_bloomz_loftq_4bit_iter_5(self, device, tmp_path):
        # Same test as the previous one but with 5 iterations. We should expect the error to be even smaller with more
        # iterations, but in practice the difference is not that large, at least not for this small base model.
        errors = self.get_errors(bits=4, loftq_iter=5, device=device, tmp_path=tmp_path)
        self._assert_errors_reduced(*errors, error_factor=self.get_error_factor(device))

    @pytest.mark.parametrize("device", [torch_device, "cpu"])
    def test_bloomz_loftq_8bit(self, device, tmp_path):
        # Same test as test_bloomz_loftq_4bit but with 8 bits.
        errors = self.get_errors(bits=8, device=device, tmp_path=tmp_path)
        self._assert_errors_reduced(*errors, error_factor=self.get_error_factor(device))

    @pytest.mark.parametrize("device", [torch_device, "cpu"])
    def test_bloomz_loftq_8bit_iter_5(self, device, tmp_path):
        # Same test as test_bloomz_loftq_4bit_iter_5 but with 8 bits.
        errors = self.get_errors(bits=8, loftq_iter=5, device=device, tmp_path=tmp_path)
        self._assert_errors_reduced(*errors, error_factor=self.get_error_factor(device))

    @pytest.mark.parametrize("device", [torch_device, "cpu"])
    def test_t5_loftq_4bit(self, device, tmp_path):
        errors = self.get_errors(bits=4, device=device, model_id="t5-small", tmp_path=tmp_path)
        self._assert_errors_reduced(*errors, error_factor=self.get_error_factor(device))

    @pytest.mark.parametrize("device", [torch_device, "cpu"])
    def test_t5_loftq_8bit(self, device, tmp_path):
        errors = self.get_errors(bits=8, device=device, model_id="t5-small", tmp_path=tmp_path)
        self._assert_errors_reduced(*errors, error_factor=self.get_error_factor(device))

    @pytest.mark.xfail  # failing for now, but having DoRA pass is only a nice-to-have, not a must, so we're good
    @pytest.mark.parametrize("device", [torch_device, "cpu"])
    def test_bloomz_loftq_4bit_dora(self, device, tmp_path):
        # same as test_bloomz_loftq_4bit but with DoRA, and with a much stricter margin
        errors = self.get_errors(bits=4, device=device, use_dora=True, tmp_path=tmp_path)
        self._assert_errors_reduced(*errors, error_factor=3)

    @pytest.mark.parametrize("device", [torch_device, "cpu"])
    def test_bloomz_loftq_8bit_dora(self, device, tmp_path):
        # same as test_bloomz_loftq_8bit but with DoRA
        errors = self.get_errors(bits=8, device=device, use_dora=True, tmp_path=tmp_path)
        self._assert_errors_reduced(*errors, error_factor=self.get_error_factor(device))

    def test_replace_lora_weights_with_loftq_using_callable(self):
        """
        Test replacing LoRA weights with LoftQ using a callable.

        Using the replace_lora_weights_loftq function, we replace the LoRA weights of a bnb-quantized model with LoRA
        weights initialized by LoftQ on the fly. We use a callable to decide whether to replace the weights or not.
        This callable checks, for each weight, if replacing it would actually result in logits that are closer to the
        original logits of the non-quantized model.
        """
        torch.manual_seed(0)
        model_id = "bigscience/bloomz-560m"
        device = torch_device
        tokenizer = AutoTokenizer.from_pretrained(model_id)
        inputs = tokenizer("The dog was", padding=True, return_tensors="pt").to(device)

        with tempfile.TemporaryDirectory() as tmp_dir:
            model = AutoModelForCausalLM.from_pretrained(model_id).to(device)
            logits_base = model(**inputs).logits
            model.save_pretrained(tmp_dir)

            # load in 4bit
            bnb_config = BitsAndBytesConfig(
                load_in_4bit=True,
                bnb_4bit_use_double_quant=True,
            )
            model = AutoModelForCausalLM.from_pretrained(model_id, quantization_config=bnb_config)
            model = get_peft_model(model, LoraConfig(task_type="CAUSAL_LM", target_modules="all-linear"))
            logits_lora = model(**inputs).logits

            current_mse = float("inf")
            logs = []

            def my_callback(model, module_name):
                """Callable to replace weights with LoftQ if the mse is lower than the current best one."""
                nonlocal current_mse
                logits = model(**inputs).logits
                mse = ((logits_base - logits) ** 2).mean()
                if mse < current_mse:
                    current_mse = mse
                    logs.append(True)
                    return True
                logs.append(False)
                return False

            replace_lora_weights_loftq(model, model_path=tmp_dir, callback=my_callback)
            logits_loftq = model(**inputs).logits

            mae_lora = (logits_base - logits_lora).abs().mean()
            mae_loftq = (logits_base - logits_loftq).abs().mean()
            mse_lora = ((logits_base - logits_lora) ** 2).mean()
            mse_loftq = ((logits_base - logits_loftq) ** 2).mean()

            # check that the error was reduced by a certain margin
            assert mae_loftq * 1.5 < mae_lora
            assert mse_loftq * 2.5 < mse_lora

            # check that the callback has returned some True and some False values
            assert any(logs)
            assert not all(logs)

        del model
        clear_device_cache(garbage_collection=True)

    def test_replace_lora_weights_with_local_model(self):
        # see issue 2020
        torch.manual_seed(0)
        model_id = "hf-internal-testing/tiny-random-OPTForCausalLM"
        device = torch_device

        with tempfile.TemporaryDirectory() as tmp_dir:
            # save base model locally
            model = AutoModelForCausalLM.from_pretrained(model_id).to(device)
            model.save_pretrained(tmp_dir)
            del model

            # load in 4bit
            bnb_config = BitsAndBytesConfig(
                load_in_4bit=True,
                bnb_4bit_use_double_quant=True,
            )

            # load the base model from local directory
            model = AutoModelForCausalLM.from_pretrained(tmp_dir, quantization_config=bnb_config)
            model = get_peft_model(model, LoraConfig())

            # passing the local path directly works
            replace_lora_weights_loftq(model, model_path=tmp_dir)
            del model

            # load the base model from local directory
            model = AutoModelForCausalLM.from_pretrained(tmp_dir, quantization_config=bnb_config)
            model = get_peft_model(model, LoraConfig())

            # when not passing, ensure that users are made aware of the `model_path` argument
            with pytest.raises(ValueError, match="model_path"):
                replace_lora_weights_loftq(model)

        del model
        clear_device_cache(garbage_collection=True)

    def test_config_no_loftq_init(self):
        with pytest.warns(
            UserWarning,
            match="`loftq_config` specified but will be ignored when `init_lora_weights` is not 'loftq'.",
        ):
            LoraConfig(loftq_config=LoftQConfig())

    def test_config_no_loftq_config(self):
        with pytest.raises(ValueError, match="`loftq_config` must be specified when `init_lora_weights` is 'loftq'."):
            LoraConfig(init_lora_weights="loftq")
@require_bitsandbytes
@require_non_cpu
class MultiprocessTester(unittest.TestCase):
    r"""
    Runs the `scripts/launch_notebook_mp.py` script in a separate Python process.
    """
    def test_notebook_launcher(self):
        # relative script path, resolved from the working directory of the child process
        launch_cmd = ["python", os.path.join("scripts", "launch_notebook_mp.py")]
        with patch_environment(omp_num_threads=1):
            # snapshot the environment inside the context so whatever patch_environment sets is passed to the child
            child_env = os.environ.copy()
            run_command(launch_cmd, env=child_env)
@require_non_cpu
class MixedPrecisionTests(unittest.TestCase):
    r"""
    Tests for training float16 models with fp16 mixed precision (`TrainingArguments(fp16=True)`).
    Training fails with "Attempting to unscale FP16 gradients." when the trainable LoRA weights are float16. These tests
    check that `autocast_adapter_dtype=True` (or an explicit cast) promotes the adapter weights to float32 so training
    works, for `get_peft_model`, `PeftModel.from_pretrained` and `load_adapter`.
    """
    def setUp(self):
        self.causal_lm_model_id = "facebook/opt-125m"
        self.tokenizer = AutoTokenizer.from_pretrained(self.causal_lm_model_id)
        self.config = LoraConfig(
            r=16,
            lora_alpha=32,
            task_type="CAUSAL_LM",
        )
        data = load_dataset_english_quotes()
        self.data = data.map(lambda samples: self.tokenizer(samples["quote"]), batched=True)
    def tearDown(self):
        r"""
        Efficient mechanism to free GPU memory after each test. Based on
        https://github.com/huggingface/transformers/issues/21094
        """
        clear_device_cache(garbage_collection=True)
        gc.collect()
    def _get_trainer(self, model, output_dir):
        # Trainer shared by all tests of this class: 3 steps of fp16 mixed-precision training on the quotes data.
        # fp16=True is what triggers "Attempting to unscale FP16 gradients." when the trainable weights are float16;
        # with float32 adapter weights the same setup trains without error.
        return Trainer(
            model=model,
            train_dataset=self.data["train"],
            args=TrainingArguments(
                fp16=True,
                output_dir=output_dir,
                max_steps=3,
            ),
            data_collator=DataCollatorForLanguageModeling(self.tokenizer, mlm=False),
        )
    @staticmethod
    def _get_lora_a_dtype(model, adapter_name="default"):
        # dtype of one representative LoRA weight (first decoder layer, v_proj, lora_A) of the given adapter
        return model.base_model.model.model.decoder.layers[0].self_attn.v_proj.lora_A[adapter_name].weight.dtype
    @pytest.mark.single_gpu_tests
    def test_model_using_float16_with_amp_raises(self):
        # This test shows the issue with keeping the adapter weights of a float16 model in float16
        # (autocast_adapter_dtype=False) and then training with fp16 mixed precision: the trainable weights must not be
        # float16, so training raises.
        model = AutoModelForCausalLM.from_pretrained(
            self.causal_lm_model_id,
            torch_dtype=torch.float16,
        )
        model = get_peft_model(model, self.config, autocast_adapter_dtype=False)
        with tempfile.TemporaryDirectory() as tmp_dir:
            trainer = self._get_trainer(model, tmp_dir)
            with pytest.raises(ValueError, match="Attempting to unscale FP16 gradients."):
                trainer.train()
    @pytest.mark.single_gpu_tests
    def test_model_using_float16_autocast_dtype(self):
        # Here we use autocast_adapter_dtype=True (the default) to automatically promote the adapter weights to float32.
        # No exception should be raised.
        model = AutoModelForCausalLM.from_pretrained(
            self.causal_lm_model_id,
            torch_dtype=torch.float16,
        )
        model = get_peft_model(model, self.config, autocast_adapter_dtype=True)
        with tempfile.TemporaryDirectory() as tmp_dir:
            trainer = self._get_trainer(model, tmp_dir)
            trainer.train()  # does not raise
    @pytest.mark.single_gpu_tests
    def test_model_using_float16_explicit_cast(self):
        # Same test as above but containing the fix to make it work: the trainable adapter weights are promoted to
        # float32 by hand, and we check that this yields exactly the same parameter dtypes as
        # autocast_adapter_dtype=True.
        model = AutoModelForCausalLM.from_pretrained(
            self.causal_lm_model_id,
            torch_dtype=torch.float16,
        )
        model = get_peft_model(model, self.config, autocast_adapter_dtype=False)
        # here we manually promote the adapter weights to float32
        for param in model.parameters():
            if param.requires_grad:
                param.data = param.data.float()
        dtype_counts_before = Counter(p.dtype for p in model.parameters())
        model = AutoModelForCausalLM.from_pretrained(
            self.causal_lm_model_id,
            torch_dtype=torch.float16,
        )
        model = get_peft_model(model, self.config, autocast_adapter_dtype=True)
        dtype_counts_after = Counter(p.dtype for p in model.parameters())
        assert dtype_counts_before == dtype_counts_after
        with tempfile.TemporaryDirectory() as tmp_dir:
            trainer = self._get_trainer(model, tmp_dir)
            trainer.train()  # does not raise
    @pytest.mark.single_gpu_tests
    def test_load_model_using_float16_with_amp_raises(self):
        # Same as previous tests, but loading the adapter with PeftModel.from_pretrained instead
        model = AutoModelForCausalLM.from_pretrained(
            self.causal_lm_model_id,
            torch_dtype=torch.float16,
        )
        model = get_peft_model(model, self.config, autocast_adapter_dtype=False)
        with tempfile.TemporaryDirectory() as tmp_dir:
            model.save_pretrained(tmp_dir)
            model = AutoModelForCausalLM.from_pretrained(self.causal_lm_model_id, torch_dtype=torch.float16)
            model = PeftModel.from_pretrained(model, tmp_dir, autocast_adapter_dtype=False, is_trainable=True)
            trainer = self._get_trainer(model, tmp_dir)
            with pytest.raises(ValueError, match="Attempting to unscale FP16 gradients."):
                trainer.train()
    @pytest.mark.single_gpu_tests
    def test_load_model_using_float16_autocast_dtype(self):
        # Same as previous tests, but loading the adapter with PeftModel.from_pretrained instead
        model = AutoModelForCausalLM.from_pretrained(
            self.causal_lm_model_id,
            torch_dtype=torch.float16,
        )
        # Below, we purposefully set autocast_adapter_dtype=False so that the saved adapter uses float16. We still want
        # the loaded adapter to use float32 when we load it with autocast_adapter_dtype=True.
        model = get_peft_model(model, self.config, autocast_adapter_dtype=False)
        # sanity check: this should have float16 adapter weights:
        assert self._get_lora_a_dtype(model) == torch.float16
        with tempfile.TemporaryDirectory() as tmp_dir:
            model.save_pretrained(tmp_dir)
            model = AutoModelForCausalLM.from_pretrained(self.causal_lm_model_id, torch_dtype=torch.float16)
            model = PeftModel.from_pretrained(model, tmp_dir, autocast_adapter_dtype=True, is_trainable=True)
            # sanity check: this should NOT have float16 adapter weights:
            assert self._get_lora_a_dtype(model) == torch.float32
            trainer = self._get_trainer(model, tmp_dir)
            trainer.train()  # does not raise
    @pytest.mark.single_gpu_tests
    def test_load_adapter_using_float16_autocast_dtype(self):
        # Here we test the load_adapter method with autocast_adapter_dtype. We show that autocasting is prevented when
        # calling load_adapter(..., autocast_adapter_dtype=False) and that it is enabled when calling
        # load_adapter(..., autocast_adapter_dtype=True) (the default).
        model = AutoModelForCausalLM.from_pretrained(
            self.causal_lm_model_id,
            torch_dtype=torch.float16,
        )
        # Below, we purposefully set autocast_adapter_dtype=False so that the saved adapter uses float16. We still want
        # the loaded adapter to use float32 when we load it with autocast_adapter_dtype=True.
        model = get_peft_model(model, self.config, autocast_adapter_dtype=False)
        # sanity check: this should have float16 adapter weights:
        assert self._get_lora_a_dtype(model) == torch.float16
        with tempfile.TemporaryDirectory() as tmp_dir:
            model.save_pretrained(tmp_dir)
            model = AutoModelForCausalLM.from_pretrained(self.causal_lm_model_id, torch_dtype=torch.float16)
            # the default adapter of this fresh model is created without autocasting, i.e. in float16
            model = get_peft_model(model, self.config, autocast_adapter_dtype=False)
            # sanity check: this should have float16 adapter weights:
            assert self._get_lora_a_dtype(model) == torch.float16
            # now load the saved adapter in float16 using the adapter name "loaded16"
            model.load_adapter(tmp_dir, "loaded16", autocast_adapter_dtype=False)
            assert self._get_lora_a_dtype(model, "loaded16") == torch.float16
            # now load the saved adapter in float32 using the adapter name "loaded32"
            model.load_adapter(tmp_dir, "loaded32", autocast_adapter_dtype=True)
            assert self._get_lora_a_dtype(model, "loaded32") == torch.float32
            # training with the default adapter, which is in float16, should raise
            model.set_adapter("default")
            trainer = self._get_trainer(model, tmp_dir)
            with pytest.raises(ValueError, match="Attempting to unscale FP16 gradients."):
                trainer.train()
            # training the model with the adapter "loaded16", which is in float16, should also raise
            model.set_adapter("loaded16")
            trainer = self._get_trainer(model, tmp_dir)
            with pytest.raises(ValueError, match="Attempting to unscale FP16 gradients."):
                trainer.train()
            # training the model with the adapter "loaded32", which is in float32, should not raise
            model.set_adapter("loaded32")
            trainer = self._get_trainer(model, tmp_dir)
            trainer.train()  # does not raise
@require_non_xpu
@require_torch_gpu
@require_aqlm
@unittest.skipUnless(
    version.parse(importlib.metadata.version("transformers")) >= version.parse("4.38.0"),
    "test requires `transformers>=4.38.0`",
)
class PeftAqlmGPUTests(unittest.TestCase):
    r"""
    LoRA training on top of an AQLM-quantized causal LM.
    """
    def setUp(self):
        self.causal_lm_model_id = "BlackSamorez/TinyLlama-1_1B-Chat-v1_0-AQLM-2Bit-1x16-hf"
        self.tokenizer = AutoTokenizer.from_pretrained(self.causal_lm_model_id)
    def tearDown(self):
        r"""
        Release GPU memory after every test; approach taken from
        https://github.com/huggingface/transformers/issues/21094
        """
        clear_device_cache(garbage_collection=True)
    def _check_inference_finite(self, model, batch):
        # One forward pass in eval mode outside of any Trainer; the logits must be finite. The model's previous
        # train/eval mode is restored afterwards.
        was_training = model.training
        model.eval()
        logits = model(**batch.to(model.device)).logits
        assert torch.isfinite(logits).all()
        model.train(was_training)
    @pytest.mark.single_gpu_tests
    def test_causal_lm_training_aqlm(self):
        r"""
        Train a LoRA adapter on the AQLM model on a single GPU for a few steps and save it. The test fails if the
        adapters are not set up correctly.
        """
        with tempfile.TemporaryDirectory() as tmp_dir:
            quantized_model = AutoModelForCausalLM.from_pretrained(
                self.causal_lm_model_id,
                device_map="cuda",
                torch_dtype="auto",
            )
            quantized_model = prepare_model_for_kbit_training(quantized_model)
            lora_config = LoraConfig(
                task_type="CAUSAL_LM",
                target_modules=["q_proj", "v_proj"],
                r=16,
                lora_alpha=32,
                lora_dropout=0.05,
                bias="none",
            )
            peft_model = get_peft_model(quantized_model, lora_config)
            quotes = load_dataset_english_quotes()
            quotes = quotes.map(lambda samples: self.tokenizer(samples["quote"]), batched=True)
            training_args = TrainingArguments(
                output_dir=tmp_dir,
                per_device_train_batch_size=4,
                gradient_accumulation_steps=4,
                warmup_steps=2,
                max_steps=3,
                learning_rate=2e-4,
                logging_steps=1,
                fp16=True,
            )
            trainer = Trainer(
                model=peft_model,
                train_dataset=quotes["train"],
                args=training_args,
                data_collator=DataCollatorForLanguageModeling(self.tokenizer, mlm=False),
            )
            # disable `use_cache` before training
            peft_model.config.use_cache = False
            trainer.train()
            peft_model.cpu().save_pretrained(tmp_dir)
            saved_files = os.listdir(tmp_dir)
            assert "adapter_config.json" in saved_files
            assert SAFETENSORS_WEIGHTS_NAME in saved_files
            # the last log entry must report a training loss
            assert trainer.state.log_history[-1]["train_loss"] is not None
@require_non_xpu
@require_torch_gpu
@require_hqq
@unittest.skipUnless(
    version.parse(importlib.metadata.version("transformers")) >= version.parse("4.36.1"),
    "test requires `transformers>=4.36.1`",
)
class PeftHqqGPUTests(unittest.TestCase):
    r"""
    HQQ + peft tests
    """
    def setUp(self):
        self.causal_lm_model_id = "TinyLlama/TinyLlama-1.1B-Chat-v1.0"
        self.tokenizer = AutoTokenizer.from_pretrained(self.causal_lm_model_id)
    def tearDown(self):
        r"""
        Efficient mechanism to free GPU memory after each test. Based on
        https://github.com/huggingface/transformers/issues/21094
        """
        clear_device_cache(garbage_collection=True)
    # `parameterized.expand` is the outermost decorator, the same order as in PeftTorchaoGPUTests. The concrete tests
    # are generated from the function `expand` receives, so the pytest marker has to be applied before expansion for
    # the generated tests to carry it.
    @parameterized.expand([False, True])
    @pytest.mark.single_gpu_tests
    def test_causal_lm_training_hqq(self, use_dora):
        r"""
        Test the CausalLM training on a single GPU device. The test would simply fail if the adapters are not set
        correctly.
        """
        from transformers import HqqConfig
        with tempfile.TemporaryDirectory() as tmp_dir:
            device = "cuda"
            compute_dtype = torch.float16
            quant_config = HqqConfig(nbits=4, group_size=64)
            model = AutoModelForCausalLM.from_pretrained(
                self.causal_lm_model_id,
                device_map=device,
                torch_dtype=compute_dtype,
                quantization_config=quant_config,
            )
            model = prepare_model_for_kbit_training(model)
            config = LoraConfig(
                r=16,
                lora_alpha=32,
                target_modules=["q_proj", "v_proj"],
                lora_dropout=0.05,
                bias="none",
                task_type="CAUSAL_LM",
                use_dora=use_dora,
            )
            model = get_peft_model(model, config)
            data = load_dataset_english_quotes()
            data = data.map(lambda samples: self.tokenizer(samples["quote"]), batched=True)
            trainer = Trainer(
                model=model,
                train_dataset=data["train"],
                args=TrainingArguments(
                    per_device_train_batch_size=4,
                    gradient_accumulation_steps=4,
                    warmup_steps=2,
                    max_steps=3,
                    learning_rate=2e-4,
                    logging_steps=1,
                    output_dir=tmp_dir,
                    fp16=True,
                ),
                data_collator=DataCollatorForLanguageModeling(self.tokenizer, mlm=False),
            )
            model.config.use_cache = False
            trainer.train()
            model.save_pretrained(tmp_dir)
            assert "adapter_config.json" in os.listdir(tmp_dir)
            assert SAFETENSORS_WEIGHTS_NAME in os.listdir(tmp_dir)
            # assert loss is not None
            assert trainer.state.log_history[-1]["train_loss"] is not None
    @pytest.mark.single_gpu_tests
    def test_hqq_lora_model_outputs(self):
        # Check that the outputs generated by HQQ with LoRA are similar to those without HQQ. This must also hold after
        # merging, after unmerging, after saving/loading and after merge_and_unload.
        from transformers import HqqConfig
        device = "cuda"
        compute_dtype = torch.float16
        min_correlation = 0.96
        # first load the model without HQQ
        model = AutoModelForCausalLM.from_pretrained(
            self.causal_lm_model_id,
            device_map=device,
            torch_dtype=compute_dtype,
        )
        config = LoraConfig(
            target_modules=["q_proj", "v_proj"],
            task_type="CAUSAL_LM",
            init_lora_weights=False,
        )
        # seed right before creating the adapter so that the non-HQQ and the HQQ model get the same LoRA init
        torch.manual_seed(0)
        model = get_peft_model(model, config).eval()
        inputs = self.tokenizer("The meaning of unit tests is", return_tensors="pt").to(model.device)
        with torch.inference_mode():
            output_normal = model(**inputs).logits
        assert torch.isfinite(output_normal).all()
        del model
        clear_device_cache(garbage_collection=True)
        def assert_correlated_with_normal(output):
            # HQQ outputs contain outliers, so compare via correlation with the non-HQQ output instead of equality
            cc_matrix = torch.corrcoef(torch.stack((output_normal.float().flatten(), output.float().flatten())))
            assert cc_matrix.min() > min_correlation
        # now load with HQQ
        quant_config = HqqConfig(nbits=4, group_size=64)
        model = AutoModelForCausalLM.from_pretrained(
            self.causal_lm_model_id,
            device_map=device,
            torch_dtype=compute_dtype,
            quantization_config=quant_config,
        )
        torch.manual_seed(0)
        model = get_peft_model(model, config).eval()
        with torch.inference_mode():
            output_hqq = model(**inputs).logits
        assert_correlated_with_normal(output_hqq)
        # Check that outputs are the same after merging. The adapter must actually be merged and the model re-run;
        # otherwise this only re-checks `output_hqq`, and the unmerge below would be a no-op.
        model.merge_adapter()
        with torch.inference_mode():
            output_merged = model(**inputs).logits
        assert_correlated_with_normal(output_merged)
        # check outputs are the same after unmerging
        model.unmerge_adapter()
        with torch.inference_mode():
            output_unmerged = model(**inputs).logits
        assert_correlated_with_normal(output_unmerged)
        # check that the results are the same after saving and loading
        with tempfile.TemporaryDirectory() as tmp_dir:
            model.save_pretrained(tmp_dir)
            del model
            clear_device_cache(garbage_collection=True)
            quant_config = HqqConfig(nbits=4, group_size=64)
            model = AutoModelForCausalLM.from_pretrained(
                self.causal_lm_model_id,
                device_map=device,
                torch_dtype=compute_dtype,
                quantization_config=quant_config,
            )
            model = PeftModel.from_pretrained(model, tmp_dir)
            with torch.inference_mode():
                output_loaded = model(**inputs).logits
            # for loading, we expect high precision, so check for equality and not just correlation
            atol, rtol = 1e-6, 1e-6
            assert torch.allclose(output_hqq, output_loaded, atol=atol, rtol=rtol)
        # check that outputs are the same after merge_and_unload
        model = model.merge_and_unload()
        with torch.inference_mode():
            output_merged_unloaded = model(**inputs).logits
        assert_correlated_with_normal(output_merged_unloaded)
@require_non_cpu
@require_auto_awq
class PeftAwqGPUTests(unittest.TestCase):
    r"""
    LoRA training on top of an AWQ-quantized causal LM.
    AutoAWQ is no longer maintained, see
    https://github.com/casper-hansen/AutoAWQ/blob/88e4c76b20755db275574e6a03c83c84ba3bece5/README.md
    so more of these tests are expected to break over time. Once that happens, AWQ support should be removed from PEFT.
    """
    def setUp(self):
        self.causal_lm_model_id = "peft-internal-testing/opt-125m-awq"
        self.tokenizer = AutoTokenizer.from_pretrained(self.causal_lm_model_id)
    def tearDown(self):
        r"""
        Release accelerator memory after every test; approach taken from
        https://github.com/huggingface/transformers/issues/21094
        """
        clear_device_cache(garbage_collection=True)
    def _check_inference_finite(self, model, batch):
        # One forward pass in eval mode outside of any Trainer; the logits must be finite. The model's previous
        # train/eval mode is restored afterwards.
        was_training = model.training
        model.eval()
        logits = model(**batch.to(model.device)).logits
        assert torch.isfinite(logits).all()
        model.train(was_training)
    @pytest.mark.single_gpu_tests
    def test_causal_lm_training_awq(self):
        r"""
        Train a LoRA adapter on the AWQ model on a single accelerator for a few steps and save it. The test fails if
        the adapters are not set up correctly.
        """
        with tempfile.TemporaryDirectory() as tmp_dir:
            model = AutoModelForCausalLM.from_pretrained(self.causal_lm_model_id, device_map="auto")
            model = prepare_model_for_kbit_training(model)
            lora_config = LoraConfig(
                task_type="CAUSAL_LM",
                target_modules=["q_proj", "v_proj"],
                r=16,
                lora_alpha=32,
                lora_dropout=0.05,
                bias="none",
            )
            model = get_peft_model(model, lora_config)
            quotes = load_dataset_english_quotes()
            quotes = quotes.map(lambda samples: self.tokenizer(samples["quote"]), batched=True)
            # TODO: deal correctly with this case in transformers
            model._is_quantized_training_enabled = True
            training_args = TrainingArguments(
                output_dir=tmp_dir,
                per_device_train_batch_size=4,
                gradient_accumulation_steps=4,
                warmup_steps=2,
                max_steps=3,
                learning_rate=2e-4,
                logging_steps=1,
                fp16=True,
            )
            trainer = Trainer(
                model=model,
                train_dataset=quotes["train"],
                args=training_args,
                data_collator=DataCollatorForLanguageModeling(self.tokenizer, mlm=False),
            )
            model.config.use_cache = False
            trainer.train()
            model.cpu().save_pretrained(tmp_dir)
            saved_files = os.listdir(tmp_dir)
            assert "adapter_config.json" in saved_files
            assert SAFETENSORS_WEIGHTS_NAME in saved_files
            # the last log entry must report a training loss
            assert trainer.state.log_history[-1]["train_loss"] is not None
    @pytest.mark.multi_gpu_tests
    # TODO remove marker if/once issue is resolved, most likely requiring a fix in AutoAWQ:
    # https://github.com/casper-hansen/AutoAWQ/issues/754
    @pytest.mark.xfail(
        condition=is_torch_version(">=", "2.7.0"),
        reason="Multi-GPU test currently not working with AutoAWQ and PyTorch 2.7+",
        strict=True,
    )
    @require_torch_multi_accelerator
    def test_causal_lm_training_multi_accelerator(self):
        r"""
        Train a LoRA adapter on the AWQ model split over two accelerators and save it. The test fails if the adapters
        are not set up correctly.
        """
        device_map = {
            "model.decoder.embed_tokens": 0,
            "lm_head": 0,
            "model.decoder.embed_positions": 0,
            "model.decoder.project_out": 0,
            "model.decoder.project_in": 0,
            # decoder layers 0-5 go to device 0 and layers 6-11 to device 1 (same key order as an explicit listing)
            **{f"model.decoder.layers.{idx}": 0 if idx < 6 else 1 for idx in range(12)},
            "model.decoder.final_layer_norm": 1,
        }
        with tempfile.TemporaryDirectory() as tmp_dir:
            model = AutoModelForCausalLM.from_pretrained(self.causal_lm_model_id, device_map=device_map)
            all_devices = set(range(device_count))
            assert set(model.hf_device_map.values()) == all_devices
            assert {param.device.index for param in model.parameters()} == all_devices
            model = prepare_model_for_kbit_training(model)
            model.model_parallel = True
            model.is_parallelizable = True
            lora_config = LoraConfig(
                task_type="CAUSAL_LM",
                target_modules=["q_proj", "v_proj"],
                r=16,
                lora_alpha=32,
                lora_dropout=0.05,
                bias="none",
            )
            model = get_peft_model(model, lora_config)
            quotes = load_dataset_english_quotes()
            quotes = quotes.map(lambda samples: self.tokenizer(samples["quote"]), batched=True)
            training_args = TrainingArguments(
                output_dir=tmp_dir,
                per_device_train_batch_size=4,
                gradient_accumulation_steps=4,
                warmup_steps=2,
                max_steps=3,
                learning_rate=2e-4,
                logging_steps=1,
            )
            trainer = Trainer(
                model=model,
                train_dataset=quotes["train"],
                args=training_args,
                data_collator=DataCollatorForLanguageModeling(self.tokenizer, mlm=False),
            )
            model.config.use_cache = False
            trainer.train()
            model.cpu().save_pretrained(tmp_dir)
            saved_files = os.listdir(tmp_dir)
            assert "adapter_config.json" in saved_files
            assert SAFETENSORS_WEIGHTS_NAME in saved_files
            # the last log entry must report a training loss
            assert trainer.state.log_history[-1]["train_loss"] is not None
@require_non_xpu
@require_torch_gpu
@require_eetq
class PeftEetqGPUTests(unittest.TestCase):
    r"""
    LoRA training on top of an EETQ int8-quantized causal LM.
    """
    def setUp(self):
        self.causal_lm_model_id = "facebook/opt-125m"
        self.tokenizer = AutoTokenizer.from_pretrained(self.causal_lm_model_id)
    def tearDown(self):
        r"""
        Release GPU memory after every test; approach taken from
        https://github.com/huggingface/transformers/issues/21094
        """
        clear_device_cache(garbage_collection=True)
    def _check_inference_finite(self, model, batch):
        # One forward pass in eval mode outside of any Trainer; the logits must be finite. The model's previous
        # train/eval mode is restored afterwards.
        was_training = model.training
        model.eval()
        logits = model(**batch.to(model.device)).logits
        assert torch.isfinite(logits).all()
        model.train(was_training)
    def _train_lora_and_check_saved(self, model, output_dir):
        # Shared tail of both training tests: wrap the prepared quantized `model` with a LoRA adapter, train it for a
        # few steps on the quotes dataset, save the adapter to `output_dir`, then verify the saved files and the loss.
        lora_config = LoraConfig(
            task_type="CAUSAL_LM",
            target_modules=["q_proj", "v_proj"],
            r=16,
            lora_alpha=32,
            lora_dropout=0.05,
            bias="none",
        )
        model = get_peft_model(model, lora_config)
        quotes = load_dataset_english_quotes()
        quotes = quotes.map(lambda samples: self.tokenizer(samples["quote"]), batched=True)
        training_args = TrainingArguments(
            output_dir=output_dir,
            per_device_train_batch_size=4,
            gradient_accumulation_steps=4,
            warmup_steps=2,
            max_steps=3,
            learning_rate=2e-4,
            logging_steps=1,
        )
        trainer = Trainer(
            model=model,
            train_dataset=quotes["train"],
            args=training_args,
            data_collator=DataCollatorForLanguageModeling(self.tokenizer, mlm=False),
        )
        model.config.use_cache = False
        trainer.train()
        model.cpu().save_pretrained(output_dir)
        saved_files = os.listdir(output_dir)
        assert "adapter_config.json" in saved_files
        assert SAFETENSORS_WEIGHTS_NAME in saved_files
        # the last log entry must report a training loss
        assert trainer.state.log_history[-1]["train_loss"] is not None
    @pytest.mark.single_gpu_tests
    def test_causal_lm_training_eetq(self):
        r"""
        Train a LoRA adapter on the EETQ model on a single GPU for a few steps and save it. The test fails if the
        adapters are not set up correctly.
        """
        from transformers import EetqConfig
        with tempfile.TemporaryDirectory() as tmp_dir:
            quantization_config = EetqConfig("int8")
            model = AutoModelForCausalLM.from_pretrained(
                self.causal_lm_model_id, device_map="auto", quantization_config=quantization_config
            )
            self._train_lora_and_check_saved(prepare_model_for_kbit_training(model), tmp_dir)
    @pytest.mark.multi_gpu_tests
    @require_torch_multi_gpu
    def test_causal_lm_training_multi_gpu_eetq(self):
        r"""
        Train a LoRA adapter on the EETQ model spread over multiple GPUs and save it. The test fails if the adapters
        are not set up correctly.
        """
        from transformers import EetqConfig
        with tempfile.TemporaryDirectory() as tmp_dir:
            quantization_config = EetqConfig("int8")
            model = AutoModelForCausalLM.from_pretrained(
                self.causal_lm_model_id,
                device_map=DEVICE_MAP_MAP[self.causal_lm_model_id],
                quantization_config=quantization_config,
            )
            all_devices = set(range(device_count))
            assert set(model.hf_device_map.values()) == all_devices
            assert {param.device.index for param in model.parameters()} == all_devices
            model = prepare_model_for_kbit_training(model)
            model.model_parallel = True
            model.is_parallelizable = True
            self._train_lora_and_check_saved(model, tmp_dir)
@require_non_cpu
@require_torchao
class PeftTorchaoGPUTests(unittest.TestCase):
    r"""
    torchao + peft tests

    Trains LoRA/DoRA adapters on a torchao-quantized `facebook/opt-125m` (on a single device and split across two
    devices), checks that unsupported quantization-type / adapter combinations raise, and checks merging LoRA weights
    into torchao-quantized base layers.
    """
    # Quantization types that the parameterized training tests below are expanded over.
    supported_quant_types = [
        "int8_weight_only",
        "int8_dynamic_activation_int8_weight",
        # int4_weight_only raises an error:
        # RuntimeError: derivative for aten::_weight_int4pack_mm is not implemented
        # "int4_weight_only",
    ]
    def setUp(self):
        # All tests in this class use the same small causal LM and its tokenizer.
        self.causal_lm_model_id = "facebook/opt-125m"
        self.tokenizer = AutoTokenizer.from_pretrained(self.causal_lm_model_id)
        # torchao breaks with fp16 and if a previous test uses fp16, transformers will set this env var, which affects
        # subsequent tests, therefore the env var needs to be cleared explicitly
        #
        # TODO: remove this once https://github.com/huggingface/transformers/pull/39483 is merged
        os.environ.pop("ACCELERATE_MIXED_PRECISION", None)
    def tearDown(self):
        r"""
        Efficient mechanism to free GPU memory after each test. Based on
        https://github.com/huggingface/transformers/issues/21094
        """
        clear_device_cache(garbage_collection=True)
    @parameterized.expand(supported_quant_types)
    @pytest.mark.single_gpu_tests
    def test_causal_lm_training_single_gpu_torchao(self, quant_type):
        """Train a LoRA adapter for 3 steps on a model quantized with `quant_type` on device 0.

        Afterwards the adapter is saved, and the test checks that the adapter config and safetensors weights were
        written and that a final `train_loss` was logged.
        """
        from transformers import TorchAoConfig
        device = 0
        with tempfile.TemporaryDirectory() as tmp_dir:
            quantization_config = TorchAoConfig(quant_type=quant_type)
            model = AutoModelForCausalLM.from_pretrained(
                self.causal_lm_model_id, device_map=device, quantization_config=quantization_config
            )
            model = prepare_model_for_kbit_training(model)
            config = LoraConfig(
                r=16,
                lora_alpha=32,
                target_modules=["q_proj", "v_proj"],
                lora_dropout=0.05,
                bias="none",
                task_type="CAUSAL_LM",
            )
            model = get_peft_model(model, config)
            data = load_dataset_english_quotes()
            data = data.map(lambda samples: self.tokenizer(samples["quote"]), batched=True)
            trainer = Trainer(
                model=model,
                train_dataset=data["train"],
                args=TrainingArguments(
                    per_device_train_batch_size=4,
                    gradient_accumulation_steps=4,
                    warmup_steps=2,
                    max_steps=3,
                    learning_rate=2e-4,
                    logging_steps=1,
                    output_dir=tmp_dir,
                ),
                data_collator=DataCollatorForLanguageModeling(self.tokenizer, mlm=False),
            )
            trainer.model.config.use_cache = False
            trainer.train()
            model.save_pretrained(tmp_dir)
            assert "adapter_config.json" in os.listdir(tmp_dir)
            assert SAFETENSORS_WEIGHTS_NAME in os.listdir(tmp_dir)
            # assert loss is not None
            assert trainer.state.log_history[-1]["train_loss"] is not None
    @pytest.mark.single_gpu_tests
    def test_causal_lm_training_single_gpu_torchao_dora_int8_weight_only(self):
        """Same as the single-device training test, but with DoRA (`use_dora=True`) on an int8_weight_only model."""
        from transformers import TorchAoConfig
        device = 0
        with tempfile.TemporaryDirectory() as tmp_dir:
            quantization_config = TorchAoConfig(quant_type="int8_weight_only")
            model = AutoModelForCausalLM.from_pretrained(
                self.causal_lm_model_id, device_map=device, quantization_config=quantization_config
            )
            model = prepare_model_for_kbit_training(model)
            config = LoraConfig(
                r=16,
                lora_alpha=32,
                target_modules=["q_proj", "v_proj"],
                lora_dropout=0.05,
                bias="none",
                task_type="CAUSAL_LM",
                use_dora=True,
            )
            model = get_peft_model(model, config)
            data = load_dataset_english_quotes()
            data = data.map(lambda samples: self.tokenizer(samples["quote"]), batched=True)
            trainer = Trainer(
                model=model,
                train_dataset=data["train"],
                args=TrainingArguments(
                    per_device_train_batch_size=4,
                    gradient_accumulation_steps=4,
                    warmup_steps=2,
                    max_steps=3,
                    learning_rate=2e-4,
                    logging_steps=1,
                    output_dir=tmp_dir,
                ),
                data_collator=DataCollatorForLanguageModeling(self.tokenizer, mlm=False),
            )
            trainer.model.config.use_cache = False
            trainer.train()
            model.save_pretrained(tmp_dir)
            assert "adapter_config.json" in os.listdir(tmp_dir)
            assert SAFETENSORS_WEIGHTS_NAME in os.listdir(tmp_dir)
            # assert loss is not None
            assert trainer.state.log_history[-1]["train_loss"] is not None
    @pytest.mark.single_gpu_tests
    def test_causal_lm_training_single_gpu_torchao_dora_int8_dynamic_activation_int8_weight_raises(self):
        """DoRA on an int8_dynamic_activation_int8_weight model is expected to fail in `get_peft_model`."""
        from transformers import TorchAoConfig
        device = 0
        quantization_config = TorchAoConfig(quant_type="int8_dynamic_activation_int8_weight")
        model = AutoModelForCausalLM.from_pretrained(
            self.causal_lm_model_id, device_map=device, quantization_config=quantization_config
        )
        model = prepare_model_for_kbit_training(model)
        config = LoraConfig(
            r=16,
            lora_alpha=32,
            target_modules=["q_proj", "v_proj"],
            lora_dropout=0.05,
            bias="none",
            task_type="CAUSAL_LM",
            use_dora=True,
        )
        with pytest.raises(NotImplementedError):
            get_peft_model(model, config)
    @pytest.mark.single_gpu_tests
    def test_causal_lm_training_single_gpu_torchao_int4_raises(self):
        """LoRA on an int4_weight_only model must be rejected by `get_peft_model` with an explicit ValueError."""
        # int4_weight_only raises an error:
        # RuntimeError: derivative for aten::_weight_int4pack_mm is not implemented
        # TODO: Once proper torchao support for int4 is added, remove this test and add int4 to supported_quant_types
        from transformers import TorchAoConfig
        device = 0
        quantization_config = TorchAoConfig(quant_type="int4_weight_only")
        model = AutoModelForCausalLM.from_pretrained(
            self.causal_lm_model_id, device_map=device, quantization_config=quantization_config
        )
        model = prepare_model_for_kbit_training(model)
        config = LoraConfig(
            r=16,
            lora_alpha=32,
            target_modules=["q_proj", "v_proj"],
            lora_dropout=0.05,
            bias="none",
            task_type="CAUSAL_LM",
        )
        msg = re.escape("TorchaoLoraLinear only supports int8 weights for now")
        with pytest.raises(ValueError, match=msg):
            get_peft_model(model, config)
    @parameterized.expand(supported_quant_types)
    @pytest.mark.multi_gpu_tests
    @require_torch_multi_accelerator
    def test_causal_lm_training_multi_accelerator_torchao(self, quant_type):
        """Train a LoRA adapter for 3 steps on a bf16, `quant_type`-quantized model split across devices 0 and 1.

        Checks the same artifacts as the single-device test (adapter config, safetensors weights, logged train loss).
        """
        from transformers import TorchAoConfig
        # Embeddings, lm_head and decoder layers 0-5 go to device 0; layers 6-11 and the final layer norm to device 1.
        device_map = {
            "model.decoder.embed_tokens": 0,
            "lm_head": 0,
            "model.decoder.embed_positions": 0,
            "model.decoder.project_out": 0,
            "model.decoder.project_in": 0,
            "model.decoder.layers.0": 0,
            "model.decoder.layers.1": 0,
            "model.decoder.layers.2": 0,
            "model.decoder.layers.3": 0,
            "model.decoder.layers.4": 0,
            "model.decoder.layers.5": 0,
            "model.decoder.layers.6": 1,
            "model.decoder.layers.7": 1,
            "model.decoder.layers.8": 1,
            "model.decoder.layers.9": 1,
            "model.decoder.layers.10": 1,
            "model.decoder.layers.11": 1,
            "model.decoder.final_layer_norm": 1,
        }
        with tempfile.TemporaryDirectory() as tmp_dir:
            quantization_config = TorchAoConfig(quant_type=quant_type)
            model = AutoModelForCausalLM.from_pretrained(
                self.causal_lm_model_id,
                device_map=device_map,
                quantization_config=quantization_config,
                torch_dtype=torch.bfloat16,
            )
            # `device_count` is defined elsewhere in this file (presumably the number of visible accelerators).
            # NOTE(review): device_map above only uses devices 0 and 1, so these asserts can only pass when
            # device_count == 2 — confirm that matches the intended CI setup.
            assert set(model.hf_device_map.values()) == set(range(device_count))
            assert {p.device.index for p in model.parameters()} == set(range(device_count))
            model = prepare_model_for_kbit_training(model)
            # Mark the model as already split across devices (presumably read by Trainer to avoid re-placing it).
            model.model_parallel = True
            model.is_parallelizable = True
            config = LoraConfig(
                r=16,
                lora_alpha=32,
                target_modules=["q_proj", "v_proj"],
                lora_dropout=0.05,
                bias="none",
                task_type="CAUSAL_LM",
            )
            model = get_peft_model(model, config)
            data = load_dataset_english_quotes()
            data = data.map(lambda samples: self.tokenizer(samples["quote"]), batched=True)
            trainer = Trainer(
                model=model,
                train_dataset=data["train"],
                args=TrainingArguments(
                    per_device_train_batch_size=4,
                    gradient_accumulation_steps=4,
                    warmup_steps=2,
                    max_steps=3,
                    learning_rate=2e-4,
                    logging_steps=1,
                    output_dir=tmp_dir,
                ),
                data_collator=DataCollatorForLanguageModeling(self.tokenizer, mlm=False),
            )
            trainer.model.config.use_cache = False
            trainer.train()
            model.save_pretrained(tmp_dir)
            assert "adapter_config.json" in os.listdir(tmp_dir)
            assert SAFETENSORS_WEIGHTS_NAME in os.listdir(tmp_dir)
            # assert loss is not None
            assert trainer.state.log_history[-1]["train_loss"] is not None
    @pytest.mark.multi_gpu_tests
    @require_torch_multi_accelerator
    def test_causal_lm_training_multi_accelerator_torchao_int4_raises(self):
        """Multi-device variant of the int4 test: `get_peft_model` must raise the int8-only ValueError."""
        # int4_weight_only raises an error:
        # RuntimeError: derivative for aten::_weight_int4pack_mm is not implemented
        # TODO: Once proper torchao support for int4 is added, remove this test and add int4 to supported_quant_types
        from transformers import TorchAoConfig
        # Same two-device split as in test_causal_lm_training_multi_accelerator_torchao.
        device_map = {
            "model.decoder.embed_tokens": 0,
            "lm_head": 0,
            "model.decoder.embed_positions": 0,
            "model.decoder.project_out": 0,
            "model.decoder.project_in": 0,
            "model.decoder.layers.0": 0,
            "model.decoder.layers.1": 0,
            "model.decoder.layers.2": 0,
            "model.decoder.layers.3": 0,
            "model.decoder.layers.4": 0,
            "model.decoder.layers.5": 0,
            "model.decoder.layers.6": 1,
            "model.decoder.layers.7": 1,
            "model.decoder.layers.8": 1,
            "model.decoder.layers.9": 1,
            "model.decoder.layers.10": 1,
            "model.decoder.layers.11": 1,
            "model.decoder.final_layer_norm": 1,
        }
        quantization_config = TorchAoConfig(quant_type="int4_weight_only")
        model = AutoModelForCausalLM.from_pretrained(
            self.causal_lm_model_id,
            device_map=device_map,
            quantization_config=quantization_config,
            torch_dtype=torch.bfloat16,
        )
        # NOTE(review): as above, these asserts only hold when device_count == 2.
        assert set(model.hf_device_map.values()) == set(range(device_count))
        assert {p.device.index for p in model.parameters()} == set(range(device_count))
        model = prepare_model_for_kbit_training(model)
        model.model_parallel = True
        model.is_parallelizable = True
        config = LoraConfig(
            r=16,
            lora_alpha=32,
            target_modules=["q_proj", "v_proj"],
            lora_dropout=0.05,
            bias="none",
            task_type="CAUSAL_LM",
        )
        msg = re.escape("TorchaoLoraLinear only supports int8 weights for now")
        with pytest.raises(ValueError, match=msg):
            get_peft_model(model, config)
    @pytest.mark.single_gpu_tests
    def test_torchao_merge_layers_int8_weight_only(self):
        """Merging LoRA into int8_weight_only base layers must keep outputs (approximately) unchanged.

        The logits after `merge_adapter`, `unmerge_adapter` and `merge_and_unload` are compared to the un-merged
        adapter logits, and the base layers must still hold `AffineQuantizedTensor` weights after merge/unmerge.
        """
        from torchao.dtypes import AffineQuantizedTensor
        from transformers import TorchAoConfig
        quant_type = "int8_weight_only"
        torch.manual_seed(0)
        device = 0
        dummy_input = torch.arange(10).view(-1, 1).to(device)
        quantization_config = TorchAoConfig(quant_type=quant_type)
        model = AutoModelForCausalLM.from_pretrained(
            self.causal_lm_model_id, device_map=device, quantization_config=quantization_config
        ).eval()
        logits_base = model(dummy_input)[0]
        # init_lora_weights=False gives the adapter non-trivial weights, so it actually changes the output
        # (verified by the sanity check below).
        config = LoraConfig(
            r=16,
            lora_alpha=32,
            target_modules=["q_proj", "v_proj"],
            lora_dropout=0.05,
            bias="none",
            task_type="CAUSAL_LM",
            init_lora_weights=False,
        )
        model = get_peft_model(model, config)
        model.eval()
        logits = model(dummy_input)[0]
        # sanity check: outputs changed
        # precision is quite low, so we need to use high atol and rtol
        atol, rtol = 1e-1, 1e-1
        assert not torch.allclose(logits, logits_base, atol=atol, rtol=rtol)
        model.merge_adapter()
        logits_merged = model(dummy_input)[0]
        # merged weights must have been re-quantized rather than left as plain tensors
        for name, module in model.named_modules():
            if "base_layer" in name:
                assert isinstance(module.weight, AffineQuantizedTensor)
        model.unmerge_adapter()
        logits_unmerged = model(dummy_input)[0]
        for name, module in model.named_modules():
            if "base_layer" in name:
                assert isinstance(module.weight, AffineQuantizedTensor)
        model = model.merge_and_unload()
        logits_merged_unloaded = model(dummy_input)[0]
        assert torch.allclose(logits, logits_merged, atol=atol, rtol=rtol)
        assert torch.allclose(logits, logits_unmerged, atol=atol, rtol=rtol)
        assert torch.allclose(logits, logits_merged_unloaded, atol=atol, rtol=rtol)
    @pytest.mark.single_gpu_tests
    def test_torchao_merge_layers_int8_dynamic_activation_int8_weight_raises(self):
        """`merge_adapter` on an int8_dynamic_activation_int8_weight model must raise NotImplementedError."""
        # int8_dynamic_activation_int8_weight does not support dequantize, thus merging does not work
        from transformers import TorchAoConfig
        quant_type = "int8_dynamic_activation_int8_weight"
        torch.manual_seed(0)
        device = 0
        quantization_config = TorchAoConfig(quant_type=quant_type)
        model = AutoModelForCausalLM.from_pretrained(
            self.causal_lm_model_id, device_map=device, quantization_config=quantization_config
        ).eval()
        config = LoraConfig(
            r=16,
            lora_alpha=32,
            target_modules=["q_proj", "v_proj"],
            lora_dropout=0.05,
            bias="none",
            task_type="CAUSAL_LM",
            init_lora_weights=False,
        )
        model = get_peft_model(model, config)
        msg = re.escape(
            "Weights of type LinearActivationQuantizedTensor do not support dequantization (yet), which is needed to "
            "support merging."
        )
        with pytest.raises(NotImplementedError, match=msg):
            model.merge_adapter()
# dtypes that TestAutoCast runs every model under via torch.autocast
PRECISIONS = [torch.float32, torch.float16, torch.bfloat16]
# LoRA hyperparameters shared by the TestAutoCast LoRA variants
LORA_PARAMS = dict(r=8, lora_alpha=16, lora_dropout=0.05)
class SimpleModel(torch.nn.Module):
    """Toy model: token embedding (vocab 1000, width 768) -> LayerNorm -> Linear projection to 256 features."""

    def __init__(self):
        super().__init__()
        hidden_size = 768
        self.embedding_layer = torch.nn.Embedding(1000, hidden_size)
        self.layer_norm = torch.nn.LayerNorm(hidden_size)
        self.linear_transform = torch.nn.Linear(hidden_size, 256)

    def forward(self, input_ids):
        # embed -> normalize -> project, composed in a single expression
        return self.linear_transform(self.layer_norm(self.embedding_layer(input_ids)))
class SimpleConv2DModel(torch.nn.Module):
    """Toy model: token embedding -> LayerNorm -> single-input-channel Conv2d with 256 output channels.

    For ``input_ids`` of shape (batch, seq_len), the normalized embeddings (batch, seq_len, 768) are treated as a
    one-channel image of size (seq_len, 768), so the output has shape (batch, 256, seq_len, 768).
    """

    def __init__(self):
        super().__init__()
        self.embedding_layer = torch.nn.Embedding(1000, 768)
        self.layer_norm = torch.nn.LayerNorm(768)
        self.conv2d_transform = torch.nn.Conv2d(1, 256, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))

    def forward(self, input_ids):
        embedded_output = self.embedding_layer(input_ids)
        norm_output = self.layer_norm(embedded_output)
        # Conv2d expects (batch, channels, height, width): insert a singleton *channel* dimension at dim 1.
        norm_output = norm_output.unsqueeze(1)
        conv_output = self.conv2d_transform(norm_output)
        # squeeze(1) only drops dim 1 if it has size 1. For batched input dim 1 holds the 256 output channels, so this
        # is a no-op there; it is kept so that the output shape is unchanged for every input.
        conv_output = conv_output.squeeze(1)
        return conv_output
@require_non_cpu
class TestAutoCast(unittest.TestCase):
    # Checks that the output dtype of plain and LoRA-wrapped toy models matches the dtype requested from
    # torch.autocast, for every dtype in PRECISIONS.
    device = infer_device()

    @parameterized.expand(PRECISIONS)
    def test_simple_model(self, *args, **kwargs):
        self._test_model(SimpleModel(), *args, **kwargs)

    @parameterized.expand(PRECISIONS)
    def test_simple_lora_linear_model(self, *args, **kwargs):
        self._test_model(self._with_lora(SimpleModel(), "linear_transform"), *args, **kwargs)

    @parameterized.expand(PRECISIONS)
    def test_simple_lora_embedding_model(self, *args, **kwargs):
        self._test_model(self._with_lora(SimpleModel(), "embedding_layer"), *args, **kwargs)

    @parameterized.expand(PRECISIONS)
    def test_simple_conv2d_model(self, *args, **kwargs):
        self._test_model(SimpleConv2DModel(), *args, **kwargs)

    @parameterized.expand(PRECISIONS)
    def test_simple_lora_conv2d_model(self, *args, **kwargs):
        self._test_model(self._with_lora(SimpleConv2DModel(), "conv2d_transform"), *args, **kwargs)

    @staticmethod
    def _with_lora(base_model, target_module):
        # Wrap `base_model` with a LoRA adapter (shared LORA_PARAMS) on the single named submodule.
        return get_peft_model(base_model, LoraConfig(**LORA_PARAMS, target_modules=[target_module]))

    def _test_model(self, model, precision):
        # Place the model and a small batch of random token ids on the accelerator.
        model = model.to(self.device)
        input_ids = torch.randint(0, 1000, (2, 10)).to(self.device)
        if precision == torch.bfloat16 and not is_bf16_available():
            self.skipTest("Bfloat16 not supported on this device")
        # The forward pass under autocast must produce outputs in exactly the requested dtype.
        with torch.autocast(enabled=True, dtype=precision, device_type=self.device):
            outputs = model(input_ids)
            assert outputs.dtype == precision
class TestFSDPWrap:
    """
    Test that we can successfully initialize an FSDP instance of the module.
    This is a very simple test, as it does not perform actual FSDP training. Here we just ensure that the FSDP instance
    can be created. This can fail for several reasons, e.g. int dtype from BNB or inconsistent requires_grad settings
    due to the auto wrap policy.
    """

    @pytest.mark.single_gpu_tests
    @require_bitsandbytes
    def test_bnb_4bit_wrap_fsdp(self):
        """Wrap a 4-bit bnb OPT model with a DoRA adapter in FSDP using PEFT's auto wrap policy; must not raise."""
        quant_config = BitsAndBytesConfig(
            load_in_4bit=True,
            # float32 must be used, or else FSDP will complain about mixed int and float dtypes
            bnb_4bit_compute_dtype=torch.float32,
            bnb_4bit_quant_storage=torch.float32,
            bnb_4bit_use_double_quant=True,
        )
        model = AutoModelForCausalLM.from_pretrained(
            "facebook/opt-125m",
            quantization_config=quant_config,
            torch_dtype=torch.float32,
        )
        # model = prepare_model_for_kbit_training(model)
        config = LoraConfig(
            target_modules=["q_proj", "v_proj"],
            task_type="CAUSAL_LM",
            use_dora=True,
        )
        model = get_peft_model(model, config)
        # FSDP needs a (here: single-process) default process group. The rendezvous env vars are only set for the
        # duration of this test, and the process group is destroyed afterwards, so that neither leaks into other tests
        # running in the same Python process (a leftover group would also make a second init_process_group fail).
        with patch_environment(master_addr="localhost", master_port="29501"):
            init_process_group(world_size=1, rank=0)
            try:
                # check that this does not raise:
                FSDP(
                    model, auto_wrap_policy=fsdp_auto_wrap_policy(model), use_orig_params=False, sync_module_states=True
                )
            finally:
                torch.distributed.destroy_process_group()

    def test_fsdp_auto_wrap_policy_does_not_raise_on_custom_model(self):
        # See #2167
        # Avoid raising on custom models since Trainer uses fsdp_auto_wrap_policy automatically for PEFT + FSDP
        fsdp_auto_wrap_policy(SimpleModel())  # does not raise
class TestBOFT:
    """
    Check that BOFT layers can run a forward pass after being cast to half precision (bfloat16).
    """

    device = infer_device()

    @require_non_cpu
    @pytest.mark.single_gpu_tests
    def test_boft_half_linear(self):
        # Wrap a Linear layer with BOFT, cast the wrapper to bfloat16, and feed it bfloat16 input.
        base_linear = torch.nn.Linear(160, 160).to(self.device)
        boft_linear = boft.layer.Linear(base_linear, "layer", boft_n_butterfly_factor=2).to(dtype=torch.bfloat16)
        boft_linear(torch.randn(160, 160, device=self.device, dtype=torch.bfloat16))  # must not raise

    @require_non_cpu
    @pytest.mark.single_gpu_tests
    def test_boft_half_conv(self):
        # Same check for Conv2d: the input is a single unbatched (1, 160, 160) image.
        base_conv = torch.nn.Conv2d(1, 1, 4).to(self.device)
        boft_conv = boft.layer.Conv2d(base_conv, "conv", boft_n_butterfly_factor=2).to(dtype=torch.bfloat16)
        boft_conv(torch.randn(1, 160, 160, device=self.device, dtype=torch.bfloat16))  # must not raise
class TestPTuningReproducibility:
    device = infer_device()

    @require_non_cpu
    @require_deterministic_for_xpu
    def test_p_tuning_exactly_reproducible_after_loading(self, tmp_path):
        # Regression test for https://github.com/huggingface/peft/issues/2043#issuecomment-2321522577
        # A p-tuning checkpoint, once saved and reloaded, must reproduce logits and generations *exactly* (before the
        # fix they were only almost identical).
        # The effect is only measurable with a sufficiently large model, which is why this test does not run on CPU.
        base_model_id = "facebook/opt-125m"
        prompt_ids = torch.arange(10).view(-1, 1).to(self.device)
        generate_kwargs = {"min_new_tokens": 10, "max_new_tokens": 10}

        # Reference run: fresh p-tuning model with seeded prompt encoder initialization.
        torch.manual_seed(0)
        reference = AutoModelForCausalLM.from_pretrained(base_model_id).to(self.device)
        reference = get_peft_model(
            reference, PromptEncoderConfig(task_type="CAUSAL_LM", num_virtual_tokens=20, encoder_hidden_size=128)
        ).eval()
        with torch.inference_mode():
            expected_logits = reference(prompt_ids).logits
            expected_tokens = reference.generate(prompt_ids, **generate_kwargs)
        reference.save_pretrained(tmp_path)
        del reference
        clear_device_cache(garbage_collection=True)

        # Reloaded run: same base model plus the saved adapter.
        reloaded = AutoModelForCausalLM.from_pretrained(base_model_id).to(self.device)
        reloaded = PeftModel.from_pretrained(reloaded, tmp_path)
        with torch.inference_mode():
            actual_logits = reloaded(prompt_ids).logits
            actual_tokens = reloaded.generate(prompt_ids, **generate_kwargs)

        torch.testing.assert_close(actual_logits, expected_logits)
        torch.testing.assert_close(actual_tokens, expected_tokens)
@pytest.mark.single_gpu_tests
class TestLowCpuMemUsageDifferentDevices:
    """Test for the low CPU memory usage option for loading PEFT models.

    There are already tests for low_cpu_mem_usage=True in test_initialization.py but here we want to run tests that
    require a GPU.
    """

    model_id = "hf-internal-testing/tiny-random-OPTForCausalLM"
    device = infer_device()

    @require_non_cpu
    @pytest.mark.parametrize("device_model, device_sd", [("cpu", infer_device()), (infer_device(), "cpu")])
    def test_low_cpu_mem_usage_model_model_on_gpu_state_dict_on_cpu_works(self, device_model, device_sd):
        """Load LoRA weights with low_cpu_mem_usage=True when the model and the state_dict live on different devices.

        The resulting logits must match those of the original PEFT model, and all parameters must end up on the
        model's device.
        """
        # specifically test diverging devices for the model and state_dict
        inputs = {"input_ids": torch.randint(0, 100, (1, 10)), "attention_mask": torch.ones(1, 10)}
        inputs = {k: v.to(device_model) for k, v in inputs.items()}
        model = AutoModelForCausalLM.from_pretrained(self.model_id).to(device_model)
        lora_config = LoraConfig(init_lora_weights=False, target_modules="all-linear")
        model = get_peft_model(model, lora_config)
        model.eval()
        logits_not_low_cpu_mem = model(**inputs).logits
        state_dict = get_peft_model_state_dict(model)

        # Remap the state dict so that it can be loaded into the bare base model (adapter injected directly, no
        # PeftModel wrapper), and move the weights to the other device. Fail loudly on a key without the prefix
        # instead of silently chopping off unrelated characters.
        prefix = "base_model.model."
        unprefixed_keys = [k for k in state_dict if not k.startswith(prefix)]
        assert not unprefixed_keys, f"state_dict keys without prefix {prefix!r}: {unprefixed_keys}"
        peft_model_state_dict = {k[len(prefix) :]: v.to(device_sd) for k, v in state_dict.items()}

        del model
        model = AutoModelForCausalLM.from_pretrained(self.model_id).to(device_model)
        model.eval()
        inject_adapter_in_model(lora_config, model, low_cpu_mem_usage=True)
        load_result = set_peft_model_state_dict(model, peft_model_state_dict, low_cpu_mem_usage=True)
        # sanity check: all lora keys are matched
        assert not any("lora" in k for k in load_result.missing_keys)
        assert not any("lora" in k for k in load_result.unexpected_keys)
        logits_low_cpu_mem = model(**inputs).logits
        assert torch.allclose(logits_low_cpu_mem, logits_not_low_cpu_mem)
        assert {p.device.type for p in model.parameters()} == {device_model}

    @require_bitsandbytes
    @pytest.mark.parametrize("quantization_method", ["bnb-4bit", "bnb-8bit"])
    def test_low_cpu_mem_usage_with_quantization(self, quantization_method):
        """`get_peft_model(..., low_cpu_mem_usage=True)` must not raise on a bnb-quantized model."""
        # Ensure that low_cpu_mem_usage works with quantization
        # See also https://github.com/huggingface/diffusers/issues/10550
        if quantization_method == "bnb-4bit":
            quantization_config = BitsAndBytesConfig(
                load_in_4bit=True,
                bnb_4bit_compute_dtype=torch.float32,
                bnb_4bit_quant_storage=torch.float32,
                bnb_4bit_use_double_quant=True,
            )
        elif quantization_method == "bnb-8bit":
            quantization_config = BitsAndBytesConfig(load_in_8bit=True)
        else:
            raise ValueError(f"Unknown quantization method {quantization_method}")
        model = AutoModelForCausalLM.from_pretrained(self.model_id, quantization_config=quantization_config)
        if model.device.type != self.device:
            # calling model.to("cuda") with 8 bit bnb raises an error, thus guard against it
            model = model.to(self.device)
        lora_config = LoraConfig(init_lora_weights=False, target_modules="all-linear")
        # We use get_peft_model with low_cpu_mem_usage=True here. This is not typically done in practice (the option is
        # mostly interesting for loading trained adapters), but it does the job for testing purposes.
        model = get_peft_model(model, lora_config, low_cpu_mem_usage=True)  # this should not raise
        # the freshly created LoRA weights are expected to still be on the meta device
        assert {p.device.type for p in model.parameters()} == {self.device, "meta"}
class TestEvaInitializationGPU:
    """GPU tests for the Eva initialization method."""

    # Constants for test configuration
    COSINE_SIMILARITY_THRESHOLD = 0.75
    NUM_SEEDS = 3
    BATCH_SIZE = 4
    MAX_LENGTH = 256
    LORA_DIM = 8
    LORA_ALPHA = 1
    DEVICE = infer_device()

    @pytest.fixture
    def tokenizer(self):
        tokenizer = AutoTokenizer.from_pretrained("openai-community/gpt2")
        # reuse EOS as the pad token so that padding="max_length" in the dataset fixture can pad
        tokenizer.pad_token = tokenizer.eos_token
        return tokenizer

    @pytest.fixture
    def dataset(self, tokenizer):
        dataset = load_dataset_english_quotes()["train"]
        # Concatenate quotes into text chunks of at least MAX_LENGTH characters. Each chunk starts with a space, and
        # the trailing, still-short chunk is never appended (only full-size chunks are kept).
        examples = []
        example = ""
        for data in dataset:
            if len(example) >= self.MAX_LENGTH:
                examples.append(example)
                example = ""
            example = example + " " + data["quote"]
        dataset = Dataset.from_dict({"text": examples})
        # tokenize (MAX_LENGTH is reused here as the token length for padding/truncation)
        dataset = dataset.map(
            lambda x: tokenizer(x["text"], padding="max_length", truncation=True, max_length=self.MAX_LENGTH),
            batched=True,
            remove_columns=dataset.column_names,
        )
        dataset.set_format(type="torch")
        return dataset

    @pytest.fixture
    def model(self):
        model = AutoModelForCausalLM.from_pretrained("openai-community/gpt2")
        model.transformer.h = model.transformer.h[:2]  # truncate to 2 layers
        return model.to(self.DEVICE)

    @pytest.fixture
    def model_bnb(self):
        bnb_config = BitsAndBytesConfig(load_in_4bit=True)
        model = AutoModelForCausalLM.from_pretrained(
            "openai-community/gpt2",
            quantization_config=bnb_config,
            attn_implementation="eager",  # gpt2 doesnt support flash attention
        )
        model.transformer.h = model.transformer.h[:2]  # truncate to 2 layers
        model = prepare_model_for_kbit_training(model)
        return model

    @pytest.fixture
    def model_fixture(self, request):
        # indirect parametrization: `request.param` is the name of the model fixture to use
        return request.getfixturevalue(request.param)

    @pytest.fixture
    def peft_config(self):
        return LoraConfig(
            r=self.LORA_DIM,
            lora_alpha=self.LORA_ALPHA,
            target_modules=["c_attn"],
            init_lora_weights="eva",
            eva_config=EvaConfig(rho=2),
        )

    def is_bnb_model(self, model):
        return hasattr(model.config, "quantization_config")

    @staticmethod
    def collate_fn(examples):
        # Stack each field of the (torch-formatted) examples into a batch tensor.
        return {k: torch.stack([v[k] for v in examples], dim=0) for k in examples[0].keys()}

    @require_non_cpu
    @require_bitsandbytes
    @pytest.mark.single_gpu_tests
    @pytest.mark.parametrize("model_fixture", ["model", "model_bnb"], indirect=True)
    def test_eva_initialization_consistency(self, model_fixture, dataset, peft_config):
        """Test that Eva initialization is consistent across different data orderings.

        For each seed the dataset is shuffled differently and Eva initialization is run on a fresh copy of the model.
        The resulting `lora_A` weights (the SVD components) are compared pairwise across seeds via the absolute cosine
        similarity of their rows; the per-layer mean must exceed COSINE_SIMILARITY_THRESHOLD.
        """
        state_dicts = []
        for seed in range(self.NUM_SEEDS):
            shuffled_dataset = dataset.shuffle(seed=seed)
            dataloader = DataLoader(
                shuffled_dataset,
                batch_size=self.BATCH_SIZE,
                collate_fn=self.collate_fn,
                shuffle=False,
            )
            peft_model = get_peft_model(deepcopy(model_fixture), peft_config)
            initialize_lora_eva_weights(peft_model, dataloader)
            state_dicts.append(
                {k: v.cpu() for k, v in peft_model.state_dict().items() if "lora_A.default.weight" in k}
            )

        cos_sims = defaultdict(list)
        for i, j in itertools.combinations(range(self.NUM_SEEDS), 2):
            for k, v1 in state_dicts[i].items():
                v2 = state_dicts[j][k]
                # The number of rows can differ between runs (presumably due to rank redistribution, cf.
                # EvaConfig(rho=2)); only compare the leading rows both runs have.
                min_size = min(v1.size(0), v2.size(0))
                cos_sims[k].extend(torch.cosine_similarity(v1[:min_size], v2[:min_size], dim=1).abs().tolist())

        mean_cosine_similarities = {k: torch.tensor(v).mean() for k, v in cos_sims.items()}
        for layer_name, mean_cosine_similarity in mean_cosine_similarities.items():
            assert mean_cosine_similarity > self.COSINE_SIMILARITY_THRESHOLD, (
                f"Mean absolute cosine similarity {mean_cosine_similarity:.4f} for layer {layer_name} "
                f"is not greater than {self.COSINE_SIMILARITY_THRESHOLD}"
            )
class TestALoRAInferenceGPU:
    """GPU inference for Activated LoRA."""
    # Constants for test configuration
    NUM_SEEDS = 3
    LORA_DIM = 8
    LORA_ALPHA = 1
    DEVICE = infer_device()
    # minimum cosine similarity between the last hidden states of the full-precision and the 4-bit model
    COSINE_SIMILARITY_THRESHOLD = 0.9
    @pytest.fixture
    def tokenizer(self):
        tokenizer = AutoTokenizer.from_pretrained("facebook/opt-125m")
        tokenizer.pad_token = tokenizer.eos_token
        return tokenizer
    @pytest.fixture
    def model(self):
        model = AutoModelForCausalLM.from_pretrained("facebook/opt-125m")
        model.model.decoder.layers = model.model.decoder.layers[:2]  # truncate to 2 layers
        return model.to(self.DEVICE)
    @pytest.fixture
    def model_bnb(self):
        bnb_config = BitsAndBytesConfig(load_in_4bit=True)
        model = AutoModelForCausalLM.from_pretrained(
            "facebook/opt-125m",
            quantization_config=bnb_config,
        )
        model.model.decoder.layers = model.model.decoder.layers[:2]  # truncate to 2 layers
        model = prepare_model_for_kbit_training(model)
        return model
    @pytest.fixture
    def peft_config(self):
        return LoraConfig(
            r=self.LORA_DIM,
            task_type="CAUSAL_LM",
            lora_alpha=self.LORA_ALPHA,
            target_modules=["q_proj"],
            alora_invocation_tokens=[2],  # id for "</s>" in the OPT tokenizer
            init_lora_weights=False,
        )
    @require_non_cpu
    @require_bitsandbytes
    @pytest.mark.single_gpu_tests
    def test_alora_forward_consistency(self, model, model_bnb, peft_config):
        """Test that the forwards of the model with adapter are similar across quantizations.
        For every seed, a randomly initialized aLoRA adapter is added to both the full-precision and the 4-bit model,
        starting from the same RNG state. The cosine similarity of the two models' last hidden states must exceed
        ``COSINE_SIMILARITY_THRESHOLD``.
        """
        input_ids = torch.tensor([[0, 1, 2, 3]]).to(self.DEVICE)
        for seed in range(self.NUM_SEEDS):
            # re-seed before each get_peft_model call so both adapters are initialized from the same RNG state
            torch.manual_seed(seed)
            np.random.seed(seed)
            peft_model = get_peft_model(deepcopy(model), peft_config)
            torch.manual_seed(seed)
            np.random.seed(seed)
            peft_model_bnb = get_peft_model(deepcopy(model_bnb), peft_config)
            peft_model.eval()
            peft_model_bnb.eval()
            with torch.no_grad():
                peft_out = peft_model(input_ids=input_ids, return_dict=True, output_hidden_states=True)
                peft_out_bnb = peft_model_bnb(input_ids=input_ids, return_dict=True, output_hidden_states=True)
            # compare on CPU in float32 so that dtype/device differences between the two models do not matter
            h_fp = peft_out.hidden_states[-1].to(torch.float32).cpu()
            h_4b = peft_out_bnb.hidden_states[-1].to(torch.float32).cpu()
            cos = torch.nn.functional.cosine_similarity(h_fp.flatten(), h_4b.flatten(), dim=0).item()
            assert cos > self.COSINE_SIMILARITY_THRESHOLD, (
                f"seed {seed}: cosine similarity {cos:.4f} between full-precision and 4-bit hidden states "
                f"is not greater than {self.COSINE_SIMILARITY_THRESHOLD}"
            )
@pytest.mark.multi_gpu_tests
class TestPrefixTuning:
    """Prefix tuning with models whose layers are dispatched across multiple accelerators (see issue 2134)."""
    device = infer_device()
    def _check_generate_with_and_without_prefix_tuning(self, model, inputs, task_type):
        """Check that ``model`` is split across devices and that ``generate`` works before and after adding a prefix
        tuning adapter with the given ``task_type``."""
        # sanity check, as the test passes trivially for a single device
        assert len({p.device for p in model.parameters()}) > 1
        # sanity check: this should work without peft
        model.generate(**inputs)  # does not raise
        peft_config = PrefixTuningConfig(num_virtual_tokens=10, task_type=task_type)
        model = get_peft_model(model, peft_config)
        model.generate(**inputs)  # does not raise
    @require_torch_multi_accelerator
    def test_prefix_tuning_multiple_devices_decoder_model(self):
        # See issue 2134
        model_id = "hf-internal-testing/tiny-random-MistralForCausalLM"
        # `padding_side` (not `padding`) is the tokenizer init argument controlling the padding side
        tokenizer = AutoTokenizer.from_pretrained(model_id, padding_side="left")
        inputs = tokenizer(["A list of colors: red, blue"], return_tensors="pt").to(self.device)
        device_map = {
            "model.embed_tokens": 0,
            "model.layers.0": 0,
            "model.layers.1": 1,
            "model.norm": 1,
            "model.rotary_emb": 1,
            "lm_head": 1,
        }
        model = AutoModelForCausalLM.from_pretrained(model_id, device_map=device_map)
        self._check_generate_with_and_without_prefix_tuning(model, inputs, task_type="CAUSAL_LM")
    @require_torch_multi_accelerator
    def test_prefix_tuning_multiple_devices_encoder_decoder_model(self):
        # See issue 2134
        model_id = "hf-internal-testing/tiny-random-T5Model"
        # `padding_side` (not `padding`) is the tokenizer init argument controlling the padding side
        tokenizer = AutoTokenizer.from_pretrained(model_id, padding_side="left")
        inputs = tokenizer(["A list of colors: red, blue"], return_tensors="pt").to(self.device)
        device_map = {
            "shared": 0,
            "encoder.embed_tokens": 0,
            "encoder.block.0": 0,
            "encoder.block.1": 0,
            "encoder.block.2": 1,
            "encoder.block.3": 1,
            "encoder.block.4": 1,
            "encoder.final_layer_norm": 1,
            "decoder.embed_tokens": 0,
            "decoder.block.0": 0,
            "decoder.block.1": 0,
            "decoder.block.2": 1,
            "decoder.block.3": 1,
            "decoder.block.4": 1,
            "decoder.final_layer_norm": 1,
            "lm_head": 0,
        }
        model = AutoModelForSeq2SeqLM.from_pretrained(model_id, device_map=device_map)
        self._check_generate_with_and_without_prefix_tuning(model, inputs, task_type="SEQ_2_SEQ_LM")
@pytest.mark.skipif(not (torch.cuda.is_available() or is_xpu_available()), reason="test requires a GPU or XPU")
@pytest.mark.single_gpu_tests
class TestHotSwapping:
    """
    Test hotswapping on compiled models.
    This test suite is only run on GPU as it is quite slow.
    """
    torch_device = infer_device()
    @pytest.fixture(scope="class", autouse=True)
    def reset_float32_matmul_precision(self):
        # Earlier tests may run torchao, which, at the time this was added, sets the float32 matmul precision to 'high'.
        # This in turn results in some models producing different outputs when compiled (but only for some seeds).
        # Therefore, we need to ensure that the precision is reset to "highest", which is the default.
        # TODO: if torchao removes the side effect, this fixture can be deleted.
        # https://github.com/pytorch/ao/blob/ffb4350640e76c7e7f449dd1e36d33f19fe384c8/torchao/quantization/utils.py#L589
        torch.set_float32_matmul_precision("highest")
    @pytest.fixture(autouse=True)
    def reset_dynamo_cache(self):
        # It is critical that the dynamo cache is reset for each test. Otherwise, if the test re-uses the same model,
        # there will be recompilation errors, as torch caches the model when run in the same process.
        yield
        torch._dynamo.reset()
    #######
    # LLM #
    #######
    def check_hotswap(self, do_hotswap, ranks, alpha_scalings):
        """
        Test hotswapping with a compiled model.
        Passing do_hotswap=False should trigger recompilation. Callers detect this by patching ``torch._dynamo.config``
        with ``error_on_recompile=True``, which makes a recompilation raise an error.
        Args:
            do_hotswap: if True, prepare the model with ``prepare_model_for_compiled_hotswap`` and swap the adapter
                weights with ``hotswap_adapter``; otherwise load the second adapter under a new name and activate it
                with ``set_adapter``.
            ranks: ``(rank0, rank1)``, the LoRA ranks of the two adapters.
            alpha_scalings: ``(alpha0, alpha1)``, the LoRA alphas of the two adapters.
        """
        torch.manual_seed(0)
        inputs = torch.arange(10).view(-1, 1).to(self.torch_device)
        model_id = "hf-internal-testing/tiny-random-OPTForCausalLM"
        model = AutoModelForCausalLM.from_pretrained(model_id).to(self.torch_device)
        rank0, rank1 = ranks
        alpha0, alpha1 = alpha_scalings
        # note that the 2nd adapter targeting a subset of the 1st adapter is okay, but not the other way round
        config0 = LoraConfig(init_lora_weights=False, r=rank0, lora_alpha=alpha0, target_modules=["q_proj", "v_proj"])
        config1 = LoraConfig(init_lora_weights=False, r=rank1, lora_alpha=alpha1, target_modules=["q_proj"])
        model = get_peft_model(model, config0, adapter_name="adapter0").eval()
        with torch.inference_mode():
            output0 = model(inputs).logits
        model.add_adapter("adapter1", config1)
        model.set_adapter("adapter1")
        with torch.inference_mode():
            output1 = model(inputs).logits
        # sanity check:
        tol = 1e-4
        assert not torch.allclose(output0, output1, atol=tol, rtol=tol)
        with tempfile.TemporaryDirectory() as tmp_dirname:
            model.save_pretrained(tmp_dirname)
            del model
            model = AutoModelForCausalLM.from_pretrained(model_id).to(self.torch_device)
            model = PeftModel.from_pretrained(model, os.path.join(tmp_dirname, "adapter0")).eval()
            if do_hotswap:
                prepare_model_for_compiled_hotswap(model, config=model.peft_config, target_rank=max(ranks))
            model = torch.compile(model, mode="reduce-overhead")
            output_after0 = model(inputs).logits
            assert torch.allclose(output0, output_after0, atol=tol, rtol=tol)
            # swap and check that we get the output from adapter1
            if do_hotswap:
                hotswap_adapter(model, os.path.join(tmp_dirname, "adapter1"), adapter_name="default")
            else:
                model.load_adapter(os.path.join(tmp_dirname, "adapter1"), adapter_name="other")
                model.set_adapter("other")
            # we need to call forward to potentially trigger recompilation
            output_after1 = model(inputs).logits
            assert torch.allclose(output1, output_after1, atol=tol, rtol=tol)
            # we need to call forward third time since cudagraphs are not recorded in first call.
            if do_hotswap:
                hotswap_adapter(model, os.path.join(tmp_dirname, "adapter0"), adapter_name="default")
                output_after2 = model(inputs).logits
                assert torch.allclose(output0, output_after2, atol=tol, rtol=tol)
    # it is important to check hotswapping small to large ranks and large to small ranks
    @pytest.mark.parametrize("ranks", [(11, 11), (7, 13), (13, 7)])
    def test_hotswapping_compiled_model_does_not_trigger_recompilation(self, ranks):
        # here we set three configs to ensure no recompilation or cudagraph re-record occurs:
        # 1. error_on_recompile: raise an error on recompilation
        # 2. inline_inbuilt_nn_modules: needed to raise an error on static input address changes instead of re-recording
        # 3. triton.cudagraph_support_input_mutation: same as above
        dynamo_config_ctx = torch._dynamo.config.patch(error_on_recompile=True, inline_inbuilt_nn_modules=False)
        inductor_config_ctx = torch._inductor.config.patch("triton.cudagraph_support_input_mutation", False)
        with dynamo_config_ctx, inductor_config_ctx:
            self.check_hotswap(do_hotswap=True, ranks=ranks, alpha_scalings=ranks)
    def test_no_hotswapping_compiled_model_triggers_recompilation(self):
        # contingency test to ensure that hotswapping is actually needed to prevent recompilation
        ranks = 7, 13
        with torch._dynamo.config.patch(error_on_recompile=True):
            with pytest.raises(torch._dynamo.exc.RecompileError):  # raise an error on recompilation
                self.check_hotswap(do_hotswap=False, ranks=ranks, alpha_scalings=ranks)
    ###################
    # DIFFUSION MODEL #
    ###################
    def get_small_unet(self):
        # from diffusers UNet2DConditionModelTests
        from diffusers import UNet2DConditionModel
        torch.manual_seed(0)
        init_dict = {
            "block_out_channels": (4, 8),
            "norm_num_groups": 4,
            "down_block_types": ("CrossAttnDownBlock2D", "DownBlock2D"),
            "up_block_types": ("UpBlock2D", "CrossAttnUpBlock2D"),
            "cross_attention_dim": 8,
            "attention_head_dim": 2,
            "out_channels": 4,
            "in_channels": 4,
            "layers_per_block": 1,
            "sample_size": 16,
        }
        model = UNet2DConditionModel(**init_dict)
        return model.to(self.torch_device)
    def get_unet_lora_config(self, lora_rank, lora_alpha, target_modules):
        # from diffusers test_models_unet_2d_condition.py
        # note that this only targets linear layers by default
        unet_lora_config = LoraConfig(
            r=lora_rank,
            lora_alpha=lora_alpha,
            target_modules=target_modules,
            init_lora_weights=False,
            use_dora=False,
        )
        return unet_lora_config
    def get_dummy_input(self):
        pipeline_inputs = {
            "prompt": "A painting of a squirrel eating a burger",
            "num_inference_steps": 5,
            "guidance_scale": 6.0,
            "output_type": "np",
            "return_dict": False,
        }
        return pipeline_inputs
    def set_lora_device(self, model, adapter_names, device):
        # copied from diffusers LoraBaseMixin.set_lora_device
        for module in model.modules():
            if isinstance(module, BaseTunerLayer):
                for adapter_name in adapter_names:
                    module.lora_A[adapter_name].to(device)
                    module.lora_B[adapter_name].to(device)
                    # this is a param, not a module, so device placement is not in-place -> re-assign
                    if hasattr(module, "lora_magnitude_vector") and module.lora_magnitude_vector is not None:
                        if adapter_name in module.lora_magnitude_vector:
                            module.lora_magnitude_vector[adapter_name] = module.lora_magnitude_vector[adapter_name].to(
                                device
                            )
    def check_hotswap_diffusion(self, ranks, alpha_scalings, target_modules):
        """
        Check that hotswapping works on a compiled pipeline.
        This is essentially the same test as:
        https://github.com/huggingface/diffusers/blob/d7dd924ece56cddf261cd8b9dd901cbfa594c62c/tests/pipelines/test_pipelines.py#L2264
        Steps:
        - create 2 LoRA adapters and save them
        - prepare the pipeline for hotswapping, load the first adapter and compile the unet
        - hotswap the second adapter, then the first one again
        - check that the outputs are correct after each step
        Note: We set rank == alpha here because save_lora_adapter does not save the alpha scalings, thus the test would
        fail if the values are different. Since rank != alpha does not matter for the purpose of this test, this is
        fine.
        """
        from diffusers import StableDiffusionPipeline
        # create 2 adapters with different ranks and alphas
        dummy_input = self.get_dummy_input()
        pipeline = StableDiffusionPipeline.from_pretrained("hf-internal-testing/tiny-sd-pipe").to(self.torch_device)
        rank0, rank1 = ranks
        alpha0, alpha1 = alpha_scalings
        max_rank = max(ranks)
        lora_config0 = self.get_unet_lora_config(rank0, alpha0, target_modules)
        lora_config1 = self.get_unet_lora_config(rank1, alpha1, target_modules)
        torch.manual_seed(0)
        pipeline.unet.add_adapter(lora_config0, adapter_name="adapter0")
        output0_before = pipeline(**dummy_input, generator=torch.manual_seed(0))[0]
        torch.manual_seed(1)
        pipeline.unet.add_adapter(lora_config1, adapter_name="adapter1")
        pipeline.unet.set_adapter("adapter1")
        output1_before = pipeline(**dummy_input, generator=torch.manual_seed(0))[0]
        # sanity check
        tol = 1e-3
        assert not np.allclose(output0_before, output1_before, atol=tol, rtol=tol)
        assert not (output0_before == 0).all()
        assert not (output1_before == 0).all()
        with tempfile.TemporaryDirectory() as tmp_dirname:
            # save the adapter checkpoints
            sd0 = get_peft_model_state_dict(pipeline.unet, adapter_name="adapter0")
            StableDiffusionPipeline.save_lora_weights(
                save_directory=os.path.join(tmp_dirname, "adapter0"), safe_serialization=True, unet_lora_layers=sd0
            )
            sd1 = get_peft_model_state_dict(pipeline.unet, adapter_name="adapter1")
            StableDiffusionPipeline.save_lora_weights(
                save_directory=os.path.join(tmp_dirname, "adapter1"), safe_serialization=True, unet_lora_layers=sd1
            )
            del pipeline
            # load the first adapter
            pipeline = StableDiffusionPipeline.from_pretrained("hf-internal-testing/tiny-sd-pipe").to(
                self.torch_device
            )
            # the unet is compiled below and the ranks may differ, so prepare for hotswapping up to the largest rank
            pipeline.enable_lora_hotswap(target_rank=max_rank)
            file_name0 = os.path.join(tmp_dirname, "adapter0", "pytorch_lora_weights.safetensors")
            file_name1 = os.path.join(tmp_dirname, "adapter1", "pytorch_lora_weights.safetensors")
            pipeline.load_lora_weights(file_name0)
            pipeline.unet = torch.compile(pipeline.unet, mode="reduce-overhead")
            output0_after = pipeline(**dummy_input, generator=torch.manual_seed(0))[0]
            # sanity check: still same result
            assert np.allclose(output0_before, output0_after, atol=tol, rtol=tol)
            # hotswap the 2nd adapter
            pipeline.load_lora_weights(file_name1, hotswap=True, adapter_name="default_0")
            output1_after = pipeline(**dummy_input, generator=torch.manual_seed(0))[0]
            # sanity check: since it's the same LoRA, the results should be identical
            assert np.allclose(output1_before, output1_after, atol=tol, rtol=tol)
            # we need to call forward third time since cudagraphs are not recorded in first call.
            pipeline.load_lora_weights(file_name0, hotswap=True, adapter_name="default_0")
            output2_after = pipeline(**dummy_input, generator=torch.manual_seed(0))[0]
            assert np.allclose(output0_before, output2_after, atol=tol, rtol=tol)
    @pytest.mark.skipif(not is_diffusers_available(), reason="Test requires diffusers to be installed")
    # it is important to check hotswapping small to large ranks and large to small ranks
    @pytest.mark.parametrize("ranks", [(11, 11), (7, 13), (13, 7)])
    @pytest.mark.parametrize(
        "target_modules",
        [
            ["to_q", "to_k", "to_v", "to_out.0"],  # Linear layers
            ["conv", "conv1", "conv2"],  # Conv2d layers
            ["to_q", "conv"],  # mix of Linear and Conv2d
        ],
    )
    def test_hotswapping_compiled_diffusers_model_does_not_trigger_recompilation(self, ranks, target_modules):
        # here we set three configs to ensure no recompilation or cudagraph re-record occurs:
        # 1. error_on_recompile: raise an error on recompilation
        # 2. inline_inbuilt_nn_modules: needed to raise an error on static input address changes instead of re-recording
        # 3. triton.cudagraph_support_input_mutation: same as above
        dynamo_config_ctx = torch._dynamo.config.patch(error_on_recompile=True, inline_inbuilt_nn_modules=False)
        inductor_config_ctx = torch._inductor.config.patch("triton.cudagraph_support_input_mutation", False)
        with dynamo_config_ctx, inductor_config_ctx:
            self.check_hotswap_diffusion(ranks=ranks, alpha_scalings=ranks, target_modules=target_modules)
# Test: 4-bit load + Arrow + generate
class TestArrowQuantized:
    """Smoke test: load OPT-125M in 4-bit, build an Arrow model from locally saved LoRA adapters and generate."""
    @pytest.fixture(scope="class")
    def workdir(self, tmp_path_factory):
        """Create and return a temp directory path for this class (no chdir)."""
        wd = tmp_path_factory.mktemp("arrow_workdir")
        return Path(wd)
    def _create_and_save_adapter_opt(self, out_dir: Path, rank: int = 4) -> None:
        """
        Build a randomly initialized LoRA adapter for OPT-125M and save it into `out_dir`.
        The pretrained OPT-125M weights are loaded (inside `hub_online_once`) and wrapped with a LoRA adapter of rank
        `rank` targeting all linear layers. `init_lora_weights=False` makes the adapter weights random rather than a
        no-op.
        """
        model_id = "facebook/opt-125m"
        # Target all linear layers so the adapter matches whatever we later quantize/load.
        lora_cfg = LoraConfig(
            r=rank,
            target_modules="all-linear",
            task_type="CAUSAL_LM",
            init_lora_weights=False,
        )
        # Load the adapter on the model and save it
        with hub_online_once(model_id):
            model = AutoModelForCausalLM.from_pretrained(model_id)
        peft_model = get_peft_model(model, lora_cfg)
        peft_model.save_pretrained(out_dir)
    @pytest.fixture(scope="class")
    def ts_adapters_opt(self, workdir: Path):
        """
        Build 3 locally-saved task-specific adapters for OPT-125M and return their absolute paths.
        """
        paths = []
        for i in range(3):
            sub = workdir / f"ts_expert_{i}"
            self._create_and_save_adapter_opt(sub)
            paths.append(str(sub))
        return paths
    @require_bitsandbytes
    @pytest.mark.single_gpu_tests
    def test_arrow_4bit_opt125m_load_and_generate_with_local_adapters(self, ts_adapters_opt):
        # bitsandbytes availability is enforced by `require_bitsandbytes`; skip here if CUDA isn't available
        if not torch.cuda.is_available():
            pytest.skip("CUDA required for 4-bit bitsandbytes test.")
        model_id = "facebook/opt-125m"
        # Quantization config (nf4, bf16 compute)
        bnb_config = BitsAndBytesConfig(
            load_in_4bit=True,
            bnb_4bit_quant_type="nf4",
            bnb_4bit_compute_dtype=torch.bfloat16,
            bnb_4bit_use_double_quant=False,
        )
        with hub_online_once(model_id):
            # Load quantized base model
            base_model = AutoModelForCausalLM.from_pretrained(
                model_id,
                torch_dtype=torch.bfloat16,
                device_map="auto",
                quantization_config=bnb_config,
            )
            with hub_online_once(model_id + "tokenizer"):
                tok = AutoTokenizer.from_pretrained(model_id, use_fast=True)
        # Build Arrow model from the locally created adapters
        arrow_cfg = ArrowConfig(top_k=2, router_temperature=1.0, rng_seed=42)
        model = create_arrow_model(
            base_model=base_model,
            task_specific_adapter_paths=ts_adapters_opt,  # local dirs (each has adapter_config.json)
            arrow_config=arrow_cfg,
        ).eval()
        # Quick generate smoke test
        inputs = tok("Hello world", return_tensors="pt")
        inputs = {k: v.to(model.device) for k, v in inputs.items()}
        with torch.no_grad():
            out = model.generate(**inputs, max_new_tokens=8)
        assert out is not None
        assert out.shape[0] == 1  # batch size 1