chore(callback): Remove old peft saving code (#510)
src/axolotl/utils/callbacks.py CHANGED
@@ -43,29 +43,6 @@ LOG = logging.getLogger("axolotl.callbacks")
 IGNORE_INDEX = -100
 
 
-class SavePeftModelCallback(TrainerCallback):  # pylint: disable=too-few-public-methods
-    """Callback to save the PEFT adapter"""
-
-    def on_save(
-        self,
-        args: TrainingArguments,
-        state: TrainerState,
-        control: TrainerControl,
-        **kwargs,
-    ):
-        checkpoint_folder = os.path.join(
-            args.output_dir,
-            f"{PREFIX_CHECKPOINT_DIR}-{state.global_step}",
-        )
-
-        peft_model_path = os.path.join(checkpoint_folder, "adapter_model")
-        kwargs["model"].save_pretrained(
-            peft_model_path, save_safetensors=args.save_safetensors
-        )
-
-        return control
-
-
 class EvalFirstStepCallback(
     TrainerCallback
 ):  # pylint: disable=too-few-public-methods disable=unused-argument
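For reference, Trainer callbacks receive the live model via kwargs on every hook, so the deleted on_save boiled down to an adapter-only write into the current checkpoint folder. A minimal standalone sketch of that logic, assuming a peft.PeftModel (the save_adapter_checkpoint name is illustrative, not part of the codebase):

import os

from peft import PeftModel
from transformers.trainer_utils import PREFIX_CHECKPOINT_DIR


def save_adapter_checkpoint(model: PeftModel, output_dir: str, global_step: int) -> None:
    # Same layout the callback produced, e.g. output_dir/checkpoint-500/adapter_model;
    # PREFIX_CHECKPOINT_DIR resolves to "checkpoint" in transformers.
    checkpoint_folder = os.path.join(output_dir, f"{PREFIX_CHECKPOINT_DIR}-{global_step}")
    # save_pretrained on a PeftModel writes only the adapter weights and
    # adapter_config.json, not the full base model.
    model.save_pretrained(os.path.join(checkpoint_folder, "adapter_model"))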
src/axolotl/utils/trainer.py CHANGED
@@ -31,7 +31,6 @@ from axolotl.utils.callbacks import (
     EvalFirstStepCallback,
     GPUStatsCallback,
     SaveBetterTransformerModelCallback,
-    SavePeftModelCallback,
     bench_eval_callback_factory,
     log_prediction_callback_factory,
 )

@@ -711,12 +710,6 @@ def setup_trainer(cfg, train_dataset, eval_dataset, model, tokenizer, total_num_
     if cfg.relora_steps:
         callbacks.append(ReLoRACallback(cfg))
 
-    if cfg.local_rank == 0 and cfg.adapter in [
-        "lora",
-        "qlora",
-    ]:  # only save in rank 0
-        callbacks.append(SavePeftModelCallback)
-
     if hasattr(model, "use_bettertransformer") and model.use_bettertransformer is True:
         callbacks.append(SaveBetterTransformerModelCallback)
 
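The PR does not state why the removal is safe; a plausible reading (an assumption, not confirmed here) is that per-checkpoint adapter saving became redundant because a PEFT-wrapped model's save_pretrained already emits only the adapter files wherever the model is saved. A minimal runnable sketch of adapter-only saving without any callback; the model name and output path are illustrative:

from peft import LoraConfig, get_peft_model
from transformers import AutoModelForCausalLM

# Tiny illustrative base model; any causal LM works.
base = AutoModelForCausalLM.from_pretrained("sshleifer/tiny-gpt2")
peft_model = get_peft_model(base, LoraConfig(r=8, target_modules=["c_attn"]))

# Writes adapter_config.json plus the adapter weights only, which is exactly
# what the removed callback wrote into each checkpoint-<step>/adapter_model.
peft_model.save_pretrained("outputs/adapter_model")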