Hi everyone,
I’ve been trying to switch from LoRA to QLoRA on an Nvidia T4, but I’m running into an issue where the evaluation loss stays completely flat, while the training loss fluctuates around its initial value.
My LoRA setup works fine, but adding bnb_config, model.gradient_checkpointing_enable(), and model = prepare_model_for_kbit_training(model) causes the issue described above.
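One thing I still plan to check is whether gradients actually reach the adapter weights after a single backward pass, along these lines (a sketch; model_peft and a tokenized batch as in the script below):

out = model_peft(**batch)   # one forward pass on a single tokenized batch
out.loss.backward()
for name, p in model_peft.named_parameters():
    if p.requires_grad:     # only the LoRA adapter weights should show up here
        print(name, "no grad" if p.grad is None else p.grad.float().norm().item())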
Since the non-quantized version runs without problems, I don’t think the issue is related to the LoRA config, dataset, or formatting functions. The number of trainable parameters is non-zero for both the LoRA and QLoRA setups.
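For reference, the print_trainable_parameters helper called in the snippet is essentially this (equivalent to PEFT's built-in model.print_trainable_parameters()):

def print_trainable_parameters(m):
    trainable = sum(p.numel() for p in m.parameters() if p.requires_grad)
    total = sum(p.numel() for p in m.parameters())
    print(f"trainable params: {trainable:,} || all params: {total:,} "
          f"|| trainable%: {100 * trainable / total:.4f}")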
Below is the code I’m using for QLoRA. Any help would be appreciated!
import math
import os
from datetime import datetime

import torch
import wandb
from peft import LoraConfig, get_peft_model, prepare_model_for_kbit_training
from transformers import (
    AutoModelForCausalLM,
    AutoTokenizer,
    BitsAndBytesConfig,
    EarlyStoppingCallback,
    TrainingArguments,
)
from transformers.trainer_utils import get_last_checkpoint
from trl import SFTTrainer

# ds_train/ds_valid and construct_message_with_assistant_content are defined
# earlier in my notebook
ds_train_with_assistant_content = ds_train.map(construct_message_with_assistant_content)
ds_valid_with_assistant_content = ds_valid.map(construct_message_with_assistant_content)
bnb_config = BitsAndBytesConfig(
    load_in_4bit=True,                      # quantize base weights to 4-bit
    bnb_4bit_use_double_quant=True,         # also quantize the quantization constants
    bnb_4bit_quant_type="nf4",              # NormalFloat4 quantization
    bnb_4bit_compute_dtype=torch.bfloat16   # dtype used for de-quantized matmuls
)
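# Sanity note to myself: bfloat16 matmuls are only native on compute capability
# >= (8, 0) (Ampere); a T4 reports (7, 5). Quick check with standard PyTorch calls:
# print(torch.cuda.get_device_capability())  # -> (7, 5) on a T4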
checkpoint = "Qwen/Qwen3-0.6B"
tokenizer = AutoTokenizer.from_pretrained(checkpoint)
model = AutoModelForCausalLM.from_pretrained(
    checkpoint,
    device_map="auto",
    quantization_config=bnb_config
)
model.config.use_cache = False  # the KV cache is incompatible with gradient checkpointing
model.gradient_checkpointing_enable()
model = prepare_model_for_kbit_training(model)  # freezes base weights, upcasts norms to fp32
model.enable_input_require_grads()  # likely redundant after prepare_model_for_kbit_training
timestamp = datetime.now().strftime('%Y-%m-%d_%H-%M-%S')
RUN_NAME = f'qlora-final-model-all-linear-r64-{timestamp}'
wandb.init(
project=os.environ["WANDB_PROJECT"],
name=RUN_NAME,
# id=run_id, # resume previous run if available
resume="allow", # allows resuming crashed run
)
RESUME_TRAINING = False
OUTPUT_DIR = "./qlora-final_model_all_linear_r64-output"
PER_DEVICE_BATCH_SIZE = 2  # higher values cause OOM on the T4
optimizer = 'paged_adamw_8bit'
effective_batch_size = 16
learning_rate = 1e-5
weight_decay = 0.0
betas = (0.9, 0.9999)  # note: only logged to W&B below, never passed to the optimizer (that would need adam_beta1/adam_beta2)
warmup_ratio = 0.2
epochs = 1
gradient_accumulation_steps = int(effective_batch_size / PER_DEVICE_BATCH_SIZE)  # 16 / 2 = 8
lora_r = 16*4      # r = 64, matching the run name
lora_alpha = 64*4  # alpha = 256, so alpha/r = 4
lora_dropout = 0.01
training_args = TrainingArguments(
output_dir=OUTPUT_DIR,
per_device_train_batch_size=PER_DEVICE_BATCH_SIZE,
gradient_accumulation_steps=gradient_accumulation_steps,
learning_rate=learning_rate,
optim=optimizer,
num_train_epochs=epochs,
weight_decay=weight_decay,
lr_scheduler_type="cosine",
warmup_ratio=warmup_ratio,
save_strategy="steps",
save_steps=gradient_accumulation_steps*5,
save_total_limit=2,
eval_strategy="steps",
eval_steps=gradient_accumulation_steps*5,
logging_strategy="steps",
logging_steps=gradient_accumulation_steps*5,
report_to=['wandb'],
run_name=RUN_NAME,
bf16=True,
# fp16=True,
# fp16_full_eval=True,
metric_for_best_model="eval_loss",
greater_is_better=False,
max_grad_norm=1,
load_best_model_at_end=True,
gradient_checkpointing=True,
gradient_checkpointing_kwargs={"use_reentrant": False}
)
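# Note on the *_steps values above: Trainer counts optimizer steps, each of which
# already spans gradient_accumulation_steps micro-batches, so gradient_accumulation_steps*5
# means an eval/save/log every 8 * 5 = 40 optimizer steps (~640 training examples).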
peft_config = LoraConfig(
r=lora_r,
lora_alpha=lora_alpha,
lora_dropout=lora_dropout,
bias="none",
task_type="CAUSAL_LM",
target_modules='all-linear'
)
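# 'all-linear' makes PEFT wrap every nn.Linear except the output head (lm_head);
# on Qwen-style blocks that should be q/k/v/o_proj plus gate/up/down_proj.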
# model.requires_grad_(False) # freeze base weights (precautionary)
model_peft = get_peft_model(model, peft_config) # inject a LoRA adapter
print_trainable_parameters(model_peft)
trainer = SFTTrainer(
model=model_peft,
train_dataset=ds_train_with_assistant_content,
eval_dataset=ds_valid_with_assistant_content,
formatting_func=formatting_func,
args=training_args,
callbacks=[EarlyStoppingCallback(early_stopping_patience=25)]
)
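# early_stopping_patience=25 means 25 consecutive evals without eval_loss improving,
# i.e. 25 * 40 = 1000 optimizer steps at the eval interval set above.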
# Training setup summary
dataset_size = len(ds_train_with_assistant_content)
steps_per_epoch = dataset_size // (PER_DEVICE_BATCH_SIZE * gradient_accumulation_steps)
total_steps = steps_per_epoch * epochs
warmup_steps = int(total_steps * warmup_ratio)  # informational only; the scheduler applies warmup via warmup_ratio
print("===== Training Setup Summary =====")
print(f"Num epochs: {epochs}")
print(f"Effective batch size: {effective_batch_size}")
print(f"Per-device batch size: {PER_DEVICE_BATCH_SIZE}")
print(f"Gradient accumulation: {gradient_accumulation_steps}")
print(f"Dataset size: {dataset_size}")
print(f"Steps per epoch: {steps_per_epoch}")
print(f"Total training steps: {total_steps}")
print(f"Warmup steps: {warmup_steps}")
print(f"Logging steps: {training_args.logging_steps}")
print("===================================")
print(f"Start time: {datetime.now().strftime('%Y-%m-%d_%H-%M-%S')}")
# Training
last_checkpoint = None
if RESUME_TRAINING and os.path.isdir(OUTPUT_DIR):
last_checkpoint = get_last_checkpoint(OUTPUT_DIR)
if last_checkpoint is not None:
print(f"Resuming training from checkpoint: {last_checkpoint}")
trainer.train(resume_from_checkpoint=last_checkpoint)
else:
print("Starting fresh training run")
trainer.train()
print(f"End time: {datetime.now().strftime('%Y-%m-%d_%H-%M-%S')}")
# WandB logging of eval metrics
for log in trainer.state.log_history:
if 'eval_loss' in log:
wandb.log({
"eval_loss": log['eval_loss'],
"eval_perplexity": math.exp(log['eval_loss']),
"step": log['step'],
"learning_rate": learning_rate,
"weight_decay": weight_decay,
"betas": betas,
"warmup_ratio": warmup_ratio,
"effective_batch_size": effective_batch_size,
"optimizer": optimizer
})
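# (Constant hyperparameters like learning_rate and betas could instead be recorded once
# with wandb.config.update({...}) rather than re-logged at every eval point.)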
wandb.finish() # finish the run