Delete test_run_uploads/ with huggingface_hub
Browse filesThis view is limited to 50 files because it contains too many changes.
See raw diff
- test_run_uploads/UnslothAlignPropTrainer.py +0 -646
- test_run_uploads/UnslothBCOTrainer.py +0 -1834
- test_run_uploads/UnslothCPOTrainer.py +0 -1566
- test_run_uploads/UnslothDDPOTrainer.py +0 -881
- test_run_uploads/UnslothDPOTrainer.py +0 -0
- test_run_uploads/UnslothGKDTrainer.py +0 -885
- test_run_uploads/UnslothGRPOTrainer.py +0 -0
- test_run_uploads/UnslothKTOTrainer.py +0 -1849
- test_run_uploads/UnslothNashMDTrainer.py +0 -969
- test_run_uploads/UnslothORPOTrainer.py +0 -1552
- test_run_uploads/UnslothOnlineDPOTrainer.py +0 -1293
- test_run_uploads/UnslothPPOTrainer.py +0 -1273
- test_run_uploads/UnslothPRMTrainer.py +0 -809
- test_run_uploads/UnslothRLOOTrainer.py +0 -1143
- test_run_uploads/UnslothRewardTrainer.py +0 -828
- test_run_uploads/UnslothSFTTrainer.py +0 -1102
- test_run_uploads/UnslothXPOTrainer.py +0 -1024
- test_run_uploads/__pycache__/UnslothAlignPropTrainer.cpython-311.pyc +0 -0
- test_run_uploads/__pycache__/UnslothBCOTrainer.cpython-311.pyc +0 -0
- test_run_uploads/__pycache__/UnslothCPOTrainer.cpython-311.pyc +0 -0
- test_run_uploads/__pycache__/UnslothDDPOTrainer.cpython-311.pyc +0 -0
- test_run_uploads/__pycache__/UnslothDPOTrainer.cpython-311.pyc +0 -3
- test_run_uploads/__pycache__/UnslothGKDTrainer.cpython-311.pyc +0 -0
- test_run_uploads/__pycache__/UnslothGRPOTrainer.cpython-311.pyc +0 -0
- test_run_uploads/__pycache__/UnslothKTOTrainer.cpython-311.pyc +0 -0
- test_run_uploads/__pycache__/UnslothNashMDTrainer.cpython-311.pyc +0 -0
- test_run_uploads/__pycache__/UnslothORPOTrainer.cpython-311.pyc +0 -0
- test_run_uploads/__pycache__/UnslothOnlineDPOTrainer.cpython-311.pyc +0 -0
- test_run_uploads/__pycache__/UnslothPPOTrainer.cpython-311.pyc +0 -0
- test_run_uploads/__pycache__/UnslothPRMTrainer.cpython-311.pyc +0 -0
- test_run_uploads/__pycache__/UnslothRLOOTrainer.cpython-311.pyc +0 -0
- test_run_uploads/__pycache__/UnslothRewardTrainer.cpython-311.pyc +0 -0
- test_run_uploads/__pycache__/UnslothSFTTrainer.cpython-311.pyc +0 -0
- test_run_uploads/__pycache__/UnslothXPOTrainer.cpython-311.pyc +0 -0
- test_run_uploads/checkpoint-50/README.md +0 -210
- test_run_uploads/checkpoint-50/adapter_config.json +0 -41
- test_run_uploads/checkpoint-50/adapter_model.safetensors +0 -3
- test_run_uploads/checkpoint-50/chat_template.jinja +0 -1
- test_run_uploads/checkpoint-50/optimizer.pt +0 -3
- test_run_uploads/checkpoint-50/rng_state.pth +0 -3
- test_run_uploads/checkpoint-50/scaler.pt +0 -3
- test_run_uploads/checkpoint-50/scheduler.pt +0 -3
- test_run_uploads/checkpoint-50/special_tokens_map.json +0 -24
- test_run_uploads/checkpoint-50/tokenizer.json +0 -3
- test_run_uploads/checkpoint-50/tokenizer_config.json +0 -0
- test_run_uploads/checkpoint-50/trainer_state.json +0 -77
- test_run_uploads/checkpoint-50/training_args.bin +0 -3
- test_run_uploads/checkpoint-90/README.md +0 -210
- test_run_uploads/checkpoint-90/adapter_config.json +0 -41
- test_run_uploads/checkpoint-90/adapter_model.safetensors +0 -3
test_run_uploads/UnslothAlignPropTrainer.py
DELETED
|
@@ -1,646 +0,0 @@
|
|
| 1 |
-
"""
|
| 2 |
-
2025.7.11
|
| 3 |
-
2025.7.11
|
| 4 |
-
4.54.1
|
| 5 |
-
0.16.1
|
| 6 |
-
__UNSLOTH_VERSIONING__
|
| 7 |
-
"""
|
| 8 |
-
from torch import Tensor
|
| 9 |
-
import torch
|
| 10 |
-
import torch.nn as nn
|
| 11 |
-
from torch.nn import functional as F
|
| 12 |
-
from typing import Any, List, Optional, Tuple, Union, Dict, Set, Callable
|
| 13 |
-
from trl.trainer.alignprop_trainer import (Accelerator, AlignPropConfig, AlignPropTrainer, Any, Callable, DDPOStableDiffusionPipeline, Optional, ProjectConfiguration, PyTorchModelHubMixin, Union, defaultdict, generate_model_card, get_comet_experiment_url, is_wandb_available, logger, os, set_seed, textwrap, torch, wandb, warn)
|
| 14 |
-
|
| 15 |
-
|
| 16 |
-
import os
|
| 17 |
-
from typing import *
|
| 18 |
-
from dataclasses import dataclass, field
|
| 19 |
-
from packaging.version import Version
|
| 20 |
-
import torch
|
| 21 |
-
import numpy as np
|
| 22 |
-
from contextlib import nullcontext
|
| 23 |
-
from torch.nn import functional as F
|
| 24 |
-
from transformers import DataCollatorForSeq2Seq, DataCollatorForLanguageModeling as TransformersDataCollatorForLanguageModeling
|
| 25 |
-
|
| 26 |
-
torch_compile_options = {
|
| 27 |
-
"epilogue_fusion" : True,
|
| 28 |
-
"max_autotune" : False,
|
| 29 |
-
"shape_padding" : True,
|
| 30 |
-
"trace.enabled" : False,
|
| 31 |
-
"triton.cudagraphs" : False,
|
| 32 |
-
}
|
| 33 |
-
|
| 34 |
-
@torch.compile(dynamic = True, fullgraph = True, options = torch_compile_options,)
|
| 35 |
-
def chunked_selective_log_softmax(logits, index):
|
| 36 |
-
# Split into 4 chunks only
|
| 37 |
-
chunked_logits = torch.chunk(logits.reshape(-1, logits.shape[-1]), chunks = 4, dim = 0)
|
| 38 |
-
chunked_index = torch.chunk(index.reshape(-1), chunks = 4, dim = 0)
|
| 39 |
-
all_per_token_logps = []
|
| 40 |
-
# Below loop does the same as selective_log_softmax(chunk_logits, chunk_index)
|
| 41 |
-
for chunk_logits, chunk_index in zip(chunked_logits, chunked_index):
|
| 42 |
-
chunk_logits = chunk_logits.to(torch.float32)
|
| 43 |
-
selected_logits = torch.gather(chunk_logits, dim = -1, index = chunk_index.unsqueeze(-1)).squeeze(-1)
|
| 44 |
-
logsumexp_values = torch.logsumexp(chunk_logits, dim = -1)
|
| 45 |
-
per_token_logps = selected_logits - logsumexp_values
|
| 46 |
-
all_per_token_logps.append(per_token_logps)
|
| 47 |
-
pass
|
| 48 |
-
all_per_token_logps = torch.concat(all_per_token_logps)
|
| 49 |
-
all_per_token_logps = all_per_token_logps.reshape((logits.shape[0], logits.shape[1]))
|
| 50 |
-
return all_per_token_logps
|
| 51 |
-
@dataclass
|
| 52 |
-
class UnslothAlignPropConfig(AlignPropConfig):
|
| 53 |
-
"""
|
| 54 |
-
|
| 55 |
-
Configuration class for the [`AlignPropTrainer`].
|
| 56 |
-
|
| 57 |
-
Using [`~transformers.HfArgumentParser`] we can turn this class into
|
| 58 |
-
[argparse](https://docs.python.org/3/library/argparse#module-argparse) arguments that can be specified on the
|
| 59 |
-
command line.
|
| 60 |
-
|
| 61 |
-
Parameters:
|
| 62 |
-
exp_name (`str`, *optional*, defaults to `os.path.basename(sys.argv[0])[: -len(".py")]`):
|
| 63 |
-
Name of this experiment (defaults to the file name without the extension).
|
| 64 |
-
run_name (`str`, *optional*, defaults to `""`):
|
| 65 |
-
Name of this run.
|
| 66 |
-
seed (`int`, *optional*, defaults to `0`):
|
| 67 |
-
Random seed for reproducibility.
|
| 68 |
-
log_with (`str` or `None`, *optional*, defaults to `None`):
|
| 69 |
-
Log with either `"wandb"` or `"tensorboard"`. Check
|
| 70 |
-
[tracking](https://huggingface.co/docs/accelerate/usage_guides/tracking) for more details.
|
| 71 |
-
log_image_freq (`int`, *optional*, defaults to `1`):
|
| 72 |
-
Frequency for logging images.
|
| 73 |
-
tracker_kwargs (`dict[str, Any]`, *optional*, defaults to `{}`):
|
| 74 |
-
Keyword arguments for the tracker (e.g., `wandb_project`).
|
| 75 |
-
accelerator_kwargs (`dict[str, Any]`, *optional*, defaults to `{}`):
|
| 76 |
-
Keyword arguments for the accelerator.
|
| 77 |
-
project_kwargs (`dict[str, Any]`, *optional*, defaults to `{}`):
|
| 78 |
-
Keyword arguments for the accelerator project config (e.g., `logging_dir`).
|
| 79 |
-
tracker_project_name (`str`, *optional*, defaults to `"trl"`):
|
| 80 |
-
Name of project to use for tracking.
|
| 81 |
-
logdir (`str`, *optional*, defaults to `"logs"`):
|
| 82 |
-
Top-level logging directory for checkpoint saving.
|
| 83 |
-
num_epochs (`int`, *optional*, defaults to `100`):
|
| 84 |
-
Number of epochs to train.
|
| 85 |
-
save_freq (`int`, *optional*, defaults to `1`):
|
| 86 |
-
Number of epochs between saving model checkpoints.
|
| 87 |
-
num_checkpoint_limit (`int`, *optional*, defaults to `5`):
|
| 88 |
-
Number of checkpoints to keep before overwriting old ones.
|
| 89 |
-
mixed_precision (`str`, *optional*, defaults to `"fp16"`):
|
| 90 |
-
Mixed precision training.
|
| 91 |
-
allow_tf32 (`bool`, *optional*, defaults to `True`):
|
| 92 |
-
Allow `tf32` on Ampere GPUs.
|
| 93 |
-
resume_from (`str`, *optional*, defaults to `""`):
|
| 94 |
-
Path to resume training from a checkpoint.
|
| 95 |
-
sample_num_steps (`int`, *optional*, defaults to `50`):
|
| 96 |
-
Number of sampler inference steps.
|
| 97 |
-
sample_eta (`float`, *optional*, defaults to `1.0`):
|
| 98 |
-
Eta parameter for the DDIM sampler.
|
| 99 |
-
sample_guidance_scale (`float`, *optional*, defaults to `5.0`):
|
| 100 |
-
Classifier-free guidance weight.
|
| 101 |
-
train_batch_size (`int`, *optional*, defaults to `1`):
|
| 102 |
-
Batch size for training.
|
| 103 |
-
train_use_8bit_adam (`bool`, *optional*, defaults to `False`):
|
| 104 |
-
Whether to use the 8bit Adam optimizer from `bitsandbytes`.
|
| 105 |
-
train_learning_rate (`float`, *optional*, defaults to `1e-3`):
|
| 106 |
-
Learning rate.
|
| 107 |
-
train_adam_beta1 (`float`, *optional*, defaults to `0.9`):
|
| 108 |
-
Beta1 for Adam optimizer.
|
| 109 |
-
train_adam_beta2 (`float`, *optional*, defaults to `0.999`):
|
| 110 |
-
Beta2 for Adam optimizer.
|
| 111 |
-
train_adam_weight_decay (`float`, *optional*, defaults to `1e-4`):
|
| 112 |
-
Weight decay for Adam optimizer.
|
| 113 |
-
train_adam_epsilon (`float`, *optional*, defaults to `1e-8`):
|
| 114 |
-
Epsilon value for Adam optimizer.
|
| 115 |
-
train_gradient_accumulation_steps (`int`, *optional*, defaults to `1`):
|
| 116 |
-
Number of gradient accumulation steps.
|
| 117 |
-
train_max_grad_norm (`float`, *optional*, defaults to `1.0`):
|
| 118 |
-
Maximum gradient norm for gradient clipping.
|
| 119 |
-
negative_prompts (`str` or `None`, *optional*, defaults to `None`):
|
| 120 |
-
Comma-separated list of prompts to use as negative examples.
|
| 121 |
-
truncated_backprop_rand (`bool`, *optional*, defaults to `True`):
|
| 122 |
-
If `True`, randomized truncation to different diffusion timesteps is used.
|
| 123 |
-
truncated_backprop_timestep (`int`, *optional*, defaults to `49`):
|
| 124 |
-
Absolute timestep to which the gradients are backpropagated. Used only if `truncated_backprop_rand=False`.
|
| 125 |
-
truncated_rand_backprop_minmax (`tuple[int, int]`, *optional*, defaults to `(0, 50)`):
|
| 126 |
-
Range of diffusion timesteps for randomized truncated backpropagation.
|
| 127 |
-
push_to_hub (`bool`, *optional*, defaults to `False`):
|
| 128 |
-
Whether to push the final model to the Hub.
|
| 129 |
-
|
| 130 |
-
"""
|
| 131 |
-
vllm_sampling_params: Optional[Any] = field(
|
| 132 |
-
default = None,
|
| 133 |
-
metadata = {'help': 'vLLM SamplingParams'},
|
| 134 |
-
)
|
| 135 |
-
unsloth_num_chunks : Optional[int] = field(
|
| 136 |
-
default = -1,
|
| 137 |
-
metadata = {'help': 'Chunk size to reduce memory usage. -1 is most efficient.'},
|
| 138 |
-
)
|
| 139 |
-
def __init__(
|
| 140 |
-
self,
|
| 141 |
-
exp_name = 'colab_kernel_launcher',
|
| 142 |
-
run_name = '',
|
| 143 |
-
seed = 3407,
|
| 144 |
-
log_with = None,
|
| 145 |
-
log_image_freq = 1,
|
| 146 |
-
tracker_project_name = 'trl',
|
| 147 |
-
logdir = 'logs',
|
| 148 |
-
num_epochs = 100,
|
| 149 |
-
save_freq = 1,
|
| 150 |
-
num_checkpoint_limit = 5,
|
| 151 |
-
mixed_precision = 'fp16',
|
| 152 |
-
allow_tf32 = True,
|
| 153 |
-
resume_from = '',
|
| 154 |
-
sample_num_steps = 50,
|
| 155 |
-
sample_eta = 1.0,
|
| 156 |
-
sample_guidance_scale = 5.0,
|
| 157 |
-
train_batch_size = 1,
|
| 158 |
-
train_use_8bit_adam = False,
|
| 159 |
-
train_learning_rate = 5e-05,
|
| 160 |
-
train_adam_beta1 = 0.9,
|
| 161 |
-
train_adam_beta2 = 0.999,
|
| 162 |
-
train_adam_weight_decay = 0.01,
|
| 163 |
-
train_adam_epsilon = 1e-08,
|
| 164 |
-
train_gradient_accumulation_steps = 2,
|
| 165 |
-
train_max_grad_norm = 1.0,
|
| 166 |
-
negative_prompts = None,
|
| 167 |
-
truncated_backprop_rand = True,
|
| 168 |
-
truncated_backprop_timestep = 49,
|
| 169 |
-
push_to_hub = False,
|
| 170 |
-
vllm_sampling_params = None,
|
| 171 |
-
unsloth_num_chunks = -1,
|
| 172 |
-
**kwargs,
|
| 173 |
-
):
|
| 174 |
-
|
| 175 |
-
super().__init__(
|
| 176 |
-
exp_name = exp_name,
|
| 177 |
-
run_name = run_name,
|
| 178 |
-
seed = seed,
|
| 179 |
-
log_with = log_with,
|
| 180 |
-
log_image_freq = log_image_freq,
|
| 181 |
-
tracker_project_name = tracker_project_name,
|
| 182 |
-
logdir = logdir,
|
| 183 |
-
num_epochs = num_epochs,
|
| 184 |
-
save_freq = save_freq,
|
| 185 |
-
num_checkpoint_limit = num_checkpoint_limit,
|
| 186 |
-
mixed_precision = mixed_precision,
|
| 187 |
-
allow_tf32 = allow_tf32,
|
| 188 |
-
resume_from = resume_from,
|
| 189 |
-
sample_num_steps = sample_num_steps,
|
| 190 |
-
sample_eta = sample_eta,
|
| 191 |
-
sample_guidance_scale = sample_guidance_scale,
|
| 192 |
-
train_batch_size = train_batch_size,
|
| 193 |
-
train_use_8bit_adam = train_use_8bit_adam,
|
| 194 |
-
train_learning_rate = train_learning_rate,
|
| 195 |
-
train_adam_beta1 = train_adam_beta1,
|
| 196 |
-
train_adam_beta2 = train_adam_beta2,
|
| 197 |
-
train_adam_weight_decay = train_adam_weight_decay,
|
| 198 |
-
train_adam_epsilon = train_adam_epsilon,
|
| 199 |
-
train_gradient_accumulation_steps = train_gradient_accumulation_steps,
|
| 200 |
-
train_max_grad_norm = train_max_grad_norm,
|
| 201 |
-
negative_prompts = negative_prompts,
|
| 202 |
-
truncated_backprop_rand = truncated_backprop_rand,
|
| 203 |
-
truncated_backprop_timestep = truncated_backprop_timestep,
|
| 204 |
-
push_to_hub = push_to_hub,**kwargs)
|
| 205 |
-
self.vllm_sampling_params = vllm_sampling_params
|
| 206 |
-
self.unsloth_num_chunks = unsloth_num_chunks
|
| 207 |
-
pass
|
| 208 |
-
|
| 209 |
-
class _UnslothAlignPropTrainer(PyTorchModelHubMixin):
|
| 210 |
-
""""""
|
| 211 |
-
|
| 212 |
-
_tag_names = ["trl", "alignprop"]
|
| 213 |
-
|
| 214 |
-
def __init__(
|
| 215 |
-
self,
|
| 216 |
-
config: AlignPropConfig,
|
| 217 |
-
reward_function: Callable[[torch.Tensor, tuple[str], tuple[Any]], torch.Tensor],
|
| 218 |
-
prompt_function: Callable[[], tuple[str, Any]],
|
| 219 |
-
sd_pipeline: DDPOStableDiffusionPipeline,
|
| 220 |
-
image_samples_hook: Optional[Callable[[Any, Any, Any], Any]] = None,
|
| 221 |
-
):
|
| 222 |
-
if image_samples_hook is None:
|
| 223 |
-
warn("No image_samples_hook provided; no images will be logged")
|
| 224 |
-
|
| 225 |
-
self.prompt_fn = prompt_function
|
| 226 |
-
self.reward_fn = reward_function
|
| 227 |
-
self.config = config
|
| 228 |
-
self.image_samples_callback = image_samples_hook
|
| 229 |
-
|
| 230 |
-
accelerator_project_config = ProjectConfiguration(**self.config.project_kwargs)
|
| 231 |
-
|
| 232 |
-
if self.config.resume_from:
|
| 233 |
-
self.config.resume_from = os.path.normpath(os.path.expanduser(self.config.resume_from))
|
| 234 |
-
if "checkpoint_" not in os.path.basename(self.config.resume_from):
|
| 235 |
-
# get the most recent checkpoint in this directory
|
| 236 |
-
checkpoints = list(
|
| 237 |
-
filter(
|
| 238 |
-
lambda x: "checkpoint_" in x,
|
| 239 |
-
os.listdir(self.config.resume_from),
|
| 240 |
-
)
|
| 241 |
-
)
|
| 242 |
-
if len(checkpoints) == 0:
|
| 243 |
-
raise ValueError(f"No checkpoints found in {self.config.resume_from}")
|
| 244 |
-
checkpoint_numbers = sorted([int(x.split("_")[-1]) for x in checkpoints])
|
| 245 |
-
self.config.resume_from = os.path.join(
|
| 246 |
-
self.config.resume_from,
|
| 247 |
-
f"checkpoint_{checkpoint_numbers[-1]}",
|
| 248 |
-
)
|
| 249 |
-
|
| 250 |
-
accelerator_project_config.iteration = checkpoint_numbers[-1] + 1
|
| 251 |
-
|
| 252 |
-
self.accelerator = Accelerator(
|
| 253 |
-
log_with=self.config.log_with,
|
| 254 |
-
mixed_precision=self.config.mixed_precision,
|
| 255 |
-
project_config=accelerator_project_config,
|
| 256 |
-
# we always accumulate gradients across timesteps; we want config.train.gradient_accumulation_steps to be the
|
| 257 |
-
# number of *samples* we accumulate across, so we need to multiply by the number of training timesteps to get
|
| 258 |
-
# the total number of optimizer steps to accumulate across.
|
| 259 |
-
gradient_accumulation_steps=self.config.train_gradient_accumulation_steps,
|
| 260 |
-
**self.config.accelerator_kwargs,
|
| 261 |
-
)
|
| 262 |
-
|
| 263 |
-
is_using_tensorboard = config.log_with is not None and config.log_with == "tensorboard"
|
| 264 |
-
|
| 265 |
-
if self.accelerator.is_main_process:
|
| 266 |
-
self.accelerator.init_trackers(
|
| 267 |
-
self.config.tracker_project_name,
|
| 268 |
-
config=dict(alignprop_trainer_config=config.to_dict())
|
| 269 |
-
if not is_using_tensorboard
|
| 270 |
-
else config.to_dict(),
|
| 271 |
-
init_kwargs=self.config.tracker_kwargs,
|
| 272 |
-
)
|
| 273 |
-
|
| 274 |
-
logger.info(f"\n{config}")
|
| 275 |
-
|
| 276 |
-
set_seed(self.config.seed, device_specific=True)
|
| 277 |
-
|
| 278 |
-
self.sd_pipeline = sd_pipeline
|
| 279 |
-
|
| 280 |
-
self.sd_pipeline.set_progress_bar_config(
|
| 281 |
-
position=1,
|
| 282 |
-
disable=not self.accelerator.is_local_main_process,
|
| 283 |
-
leave=False,
|
| 284 |
-
desc="Timestep",
|
| 285 |
-
dynamic_ncols=True,
|
| 286 |
-
)
|
| 287 |
-
|
| 288 |
-
# For mixed precision training we cast all non-trainable weights [vae, non-lora text_encoder and non-lora unet] to half-precision
|
| 289 |
-
# as these weights are only used for inference, keeping weights in full precision is not required.
|
| 290 |
-
if self.accelerator.mixed_precision == "fp16":
|
| 291 |
-
inference_dtype = torch.float16
|
| 292 |
-
elif self.accelerator.mixed_precision == "bf16":
|
| 293 |
-
inference_dtype = torch.bfloat16
|
| 294 |
-
else:
|
| 295 |
-
inference_dtype = torch.float32
|
| 296 |
-
|
| 297 |
-
self.sd_pipeline.vae.to(self.accelerator.device, dtype=inference_dtype)
|
| 298 |
-
self.sd_pipeline.text_encoder.to(self.accelerator.device, dtype=inference_dtype)
|
| 299 |
-
self.sd_pipeline.unet.to(self.accelerator.device, dtype=inference_dtype)
|
| 300 |
-
|
| 301 |
-
trainable_layers = self.sd_pipeline.get_trainable_layers()
|
| 302 |
-
|
| 303 |
-
self.accelerator.register_save_state_pre_hook(self._save_model_hook)
|
| 304 |
-
self.accelerator.register_load_state_pre_hook(self._load_model_hook)
|
| 305 |
-
|
| 306 |
-
# Enable TF32 for faster training on Ampere GPUs,
|
| 307 |
-
# cf https://pytorch.org/docs/stable/notes/cuda.html#tensorfloat-32-tf32-on-ampere-devices
|
| 308 |
-
if self.config.allow_tf32:
|
| 309 |
-
torch.backends.cuda.matmul.allow_tf32 = True
|
| 310 |
-
|
| 311 |
-
self.optimizer = self._setup_optimizer(
|
| 312 |
-
trainable_layers.parameters() if not isinstance(trainable_layers, list) else trainable_layers
|
| 313 |
-
)
|
| 314 |
-
|
| 315 |
-
self.neg_prompt_embed = self.sd_pipeline.text_encoder(
|
| 316 |
-
self.sd_pipeline.tokenizer(
|
| 317 |
-
[""] if self.config.negative_prompts is None else self.config.negative_prompts,
|
| 318 |
-
return_tensors="pt",
|
| 319 |
-
padding="max_length",
|
| 320 |
-
truncation=True,
|
| 321 |
-
max_length=self.sd_pipeline.tokenizer.model_max_length,
|
| 322 |
-
).input_ids.to(self.accelerator.device)
|
| 323 |
-
)[0]
|
| 324 |
-
|
| 325 |
-
# NOTE: for some reason, autocast is necessary for non-lora training but for lora training it isn't necessary and it uses
|
| 326 |
-
# more memory
|
| 327 |
-
self.autocast = self.sd_pipeline.autocast or self.accelerator.autocast
|
| 328 |
-
|
| 329 |
-
if hasattr(self.sd_pipeline, "use_lora") and self.sd_pipeline.use_lora:
|
| 330 |
-
unet, self.optimizer = self.accelerator.prepare(trainable_layers, self.optimizer)
|
| 331 |
-
self.trainable_layers = list(filter(lambda p: p.requires_grad, unet.parameters()))
|
| 332 |
-
else:
|
| 333 |
-
self.trainable_layers, self.optimizer = self.accelerator.prepare(trainable_layers, self.optimizer)
|
| 334 |
-
|
| 335 |
-
if config.resume_from:
|
| 336 |
-
logger.info(f"Resuming from {config.resume_from}")
|
| 337 |
-
self.accelerator.load_state(config.resume_from)
|
| 338 |
-
self.first_epoch = int(config.resume_from.split("_")[-1]) + 1
|
| 339 |
-
else:
|
| 340 |
-
self.first_epoch = 0
|
| 341 |
-
|
| 342 |
-
def compute_rewards(self, prompt_image_pairs):
|
| 343 |
-
reward, reward_metadata = self.reward_fn(
|
| 344 |
-
prompt_image_pairs["images"], prompt_image_pairs["prompts"], prompt_image_pairs["prompt_metadata"]
|
| 345 |
-
)
|
| 346 |
-
return reward
|
| 347 |
-
|
| 348 |
-
def step(self, epoch: int, global_step: int):
|
| 349 |
-
"""
|
| 350 |
-
Perform a single step of training.
|
| 351 |
-
|
| 352 |
-
Args:
|
| 353 |
-
epoch (int): The current epoch.
|
| 354 |
-
global_step (int): The current global step.
|
| 355 |
-
|
| 356 |
-
Side Effects:
|
| 357 |
-
- Model weights are updated
|
| 358 |
-
- Logs the statistics to the accelerator trackers.
|
| 359 |
-
- If `self.image_samples_callback` is not None, it will be called with the prompt_image_pairs, global_step, and the accelerator tracker.
|
| 360 |
-
|
| 361 |
-
Returns:
|
| 362 |
-
global_step (int): The updated global step.
|
| 363 |
-
"""
|
| 364 |
-
info = defaultdict(list)
|
| 365 |
-
|
| 366 |
-
self.sd_pipeline.unet.train()
|
| 367 |
-
|
| 368 |
-
for _ in range(self.config.train_gradient_accumulation_steps):
|
| 369 |
-
with self.accelerator.accumulate(self.sd_pipeline.unet), self.autocast(), torch.enable_grad():
|
| 370 |
-
prompt_image_pairs = self._generate_samples(
|
| 371 |
-
batch_size=self.config.train_batch_size,
|
| 372 |
-
)
|
| 373 |
-
|
| 374 |
-
rewards = self.compute_rewards(prompt_image_pairs)
|
| 375 |
-
|
| 376 |
-
prompt_image_pairs["rewards"] = rewards
|
| 377 |
-
|
| 378 |
-
rewards_vis = self.accelerator.gather(rewards).detach().cpu().numpy()
|
| 379 |
-
|
| 380 |
-
loss = self.calculate_loss(rewards)
|
| 381 |
-
|
| 382 |
-
self.accelerator.backward(loss)
|
| 383 |
-
|
| 384 |
-
if self.accelerator.sync_gradients:
|
| 385 |
-
self.accelerator.clip_grad_norm_(
|
| 386 |
-
self.trainable_layers.parameters()
|
| 387 |
-
if not isinstance(self.trainable_layers, list)
|
| 388 |
-
else self.trainable_layers,
|
| 389 |
-
self.config.train_max_grad_norm,
|
| 390 |
-
)
|
| 391 |
-
|
| 392 |
-
self.optimizer.step()
|
| 393 |
-
self.optimizer.zero_grad()
|
| 394 |
-
|
| 395 |
-
info["reward_mean"].append(rewards_vis.mean())
|
| 396 |
-
info["reward_std"].append(rewards_vis.std())
|
| 397 |
-
info["loss"].append(loss.item())
|
| 398 |
-
|
| 399 |
-
# Checks if the accelerator has performed an optimization step behind the scenes
|
| 400 |
-
if self.accelerator.sync_gradients:
|
| 401 |
-
# log training-related stuff
|
| 402 |
-
info = {k: torch.mean(torch.tensor(v)) for k, v in info.items()}
|
| 403 |
-
info = self.accelerator.reduce(info, reduction="mean")
|
| 404 |
-
info.update({"epoch": epoch})
|
| 405 |
-
self.accelerator.log(info, step=global_step)
|
| 406 |
-
global_step += 1
|
| 407 |
-
info = defaultdict(list)
|
| 408 |
-
else:
|
| 409 |
-
raise ValueError(
|
| 410 |
-
"Optimization step should have been performed by this point. Please check calculated gradient accumulation settings."
|
| 411 |
-
)
|
| 412 |
-
# Logs generated images
|
| 413 |
-
if self.image_samples_callback is not None and global_step % self.config.log_image_freq == 0:
|
| 414 |
-
self.image_samples_callback(prompt_image_pairs, global_step, self.accelerator.trackers[0])
|
| 415 |
-
|
| 416 |
-
if epoch != 0 and epoch % self.config.save_freq == 0 and self.accelerator.is_main_process:
|
| 417 |
-
self.accelerator.save_state()
|
| 418 |
-
|
| 419 |
-
return global_step
|
| 420 |
-
|
| 421 |
-
def calculate_loss(self, rewards):
|
| 422 |
-
"""
|
| 423 |
-
Calculate the loss for a batch of an unpacked sample
|
| 424 |
-
|
| 425 |
-
Args:
|
| 426 |
-
rewards (torch.Tensor):
|
| 427 |
-
Differentiable reward scalars for each generated image, shape: [batch_size]
|
| 428 |
-
|
| 429 |
-
Returns:
|
| 430 |
-
loss (torch.Tensor)
|
| 431 |
-
(all of these are of shape (1,))
|
| 432 |
-
"""
|
| 433 |
-
# Loss is specific to Aesthetic Reward function used in AlignProp (https://huggingface.co/papers/2310.03739)
|
| 434 |
-
loss = 10.0 - (rewards).mean()
|
| 435 |
-
return loss
|
| 436 |
-
|
| 437 |
-
def loss(
|
| 438 |
-
self,
|
| 439 |
-
advantages: torch.Tensor,
|
| 440 |
-
clip_range: float,
|
| 441 |
-
ratio: torch.Tensor,
|
| 442 |
-
):
|
| 443 |
-
unclipped_loss = -advantages * ratio
|
| 444 |
-
clipped_loss = -advantages * torch.clamp(
|
| 445 |
-
ratio,
|
| 446 |
-
1.0 - clip_range,
|
| 447 |
-
1.0 + clip_range,
|
| 448 |
-
)
|
| 449 |
-
return torch.mean(torch.maximum(unclipped_loss, clipped_loss))
|
| 450 |
-
|
| 451 |
-
def _setup_optimizer(self, trainable_layers_parameters):
|
| 452 |
-
if self.config.train_use_8bit_adam:
|
| 453 |
-
import bitsandbytes
|
| 454 |
-
|
| 455 |
-
optimizer_cls = bitsandbytes.optim.AdamW8bit
|
| 456 |
-
else:
|
| 457 |
-
optimizer_cls = torch.optim.AdamW
|
| 458 |
-
|
| 459 |
-
return optimizer_cls(
|
| 460 |
-
trainable_layers_parameters,
|
| 461 |
-
lr=self.config.train_learning_rate,
|
| 462 |
-
betas=(self.config.train_adam_beta1, self.config.train_adam_beta2),
|
| 463 |
-
weight_decay=self.config.train_adam_weight_decay,
|
| 464 |
-
eps=self.config.train_adam_epsilon,
|
| 465 |
-
)
|
| 466 |
-
|
| 467 |
-
def _save_model_hook(self, models, weights, output_dir):
|
| 468 |
-
self.sd_pipeline.save_checkpoint(models, weights, output_dir)
|
| 469 |
-
weights.pop() # ensures that accelerate doesn't try to handle saving of the model
|
| 470 |
-
|
| 471 |
-
def _load_model_hook(self, models, input_dir):
|
| 472 |
-
self.sd_pipeline.load_checkpoint(models, input_dir)
|
| 473 |
-
models.pop() # ensures that accelerate doesn't try to handle loading of the model
|
| 474 |
-
|
| 475 |
-
def _generate_samples(self, batch_size, with_grad=True, prompts=None):
|
| 476 |
-
"""
|
| 477 |
-
Generate samples from the model
|
| 478 |
-
|
| 479 |
-
Args:
|
| 480 |
-
batch_size (int): Batch size to use for sampling
|
| 481 |
-
with_grad (bool): Whether the generated RGBs should have gradients attached to it.
|
| 482 |
-
|
| 483 |
-
Returns:
|
| 484 |
-
prompt_image_pairs (dict[Any])
|
| 485 |
-
"""
|
| 486 |
-
prompt_image_pairs = {}
|
| 487 |
-
|
| 488 |
-
sample_neg_prompt_embeds = self.neg_prompt_embed.repeat(batch_size, 1, 1)
|
| 489 |
-
|
| 490 |
-
if prompts is None:
|
| 491 |
-
prompts, prompt_metadata = zip(*[self.prompt_fn() for _ in range(batch_size)])
|
| 492 |
-
else:
|
| 493 |
-
prompt_metadata = [{} for _ in range(batch_size)]
|
| 494 |
-
|
| 495 |
-
prompt_ids = self.sd_pipeline.tokenizer(
|
| 496 |
-
prompts,
|
| 497 |
-
return_tensors="pt",
|
| 498 |
-
padding="max_length",
|
| 499 |
-
truncation=True,
|
| 500 |
-
max_length=self.sd_pipeline.tokenizer.model_max_length,
|
| 501 |
-
).input_ids.to(self.accelerator.device)
|
| 502 |
-
|
| 503 |
-
prompt_embeds = self.sd_pipeline.text_encoder(prompt_ids)[0]
|
| 504 |
-
|
| 505 |
-
if with_grad:
|
| 506 |
-
sd_output = self.sd_pipeline.rgb_with_grad(
|
| 507 |
-
prompt_embeds=prompt_embeds,
|
| 508 |
-
negative_prompt_embeds=sample_neg_prompt_embeds,
|
| 509 |
-
num_inference_steps=self.config.sample_num_steps,
|
| 510 |
-
guidance_scale=self.config.sample_guidance_scale,
|
| 511 |
-
eta=self.config.sample_eta,
|
| 512 |
-
truncated_backprop_rand=self.config.truncated_backprop_rand,
|
| 513 |
-
truncated_backprop_timestep=self.config.truncated_backprop_timestep,
|
| 514 |
-
truncated_rand_backprop_minmax=self.config.truncated_rand_backprop_minmax,
|
| 515 |
-
output_type="pt",
|
| 516 |
-
)
|
| 517 |
-
else:
|
| 518 |
-
sd_output = self.sd_pipeline(
|
| 519 |
-
prompt_embeds=prompt_embeds,
|
| 520 |
-
negative_prompt_embeds=sample_neg_prompt_embeds,
|
| 521 |
-
num_inference_steps=self.config.sample_num_steps,
|
| 522 |
-
guidance_scale=self.config.sample_guidance_scale,
|
| 523 |
-
eta=self.config.sample_eta,
|
| 524 |
-
output_type="pt",
|
| 525 |
-
)
|
| 526 |
-
|
| 527 |
-
images = sd_output.images
|
| 528 |
-
|
| 529 |
-
prompt_image_pairs["images"] = images
|
| 530 |
-
prompt_image_pairs["prompts"] = prompts
|
| 531 |
-
prompt_image_pairs["prompt_metadata"] = prompt_metadata
|
| 532 |
-
|
| 533 |
-
return prompt_image_pairs
|
| 534 |
-
|
| 535 |
-
def train(self, epochs: Optional[int] = None):
|
| 536 |
-
"""
|
| 537 |
-
Train the model for a given number of epochs
|
| 538 |
-
"""
|
| 539 |
-
global_step = 0
|
| 540 |
-
if epochs is None:
|
| 541 |
-
epochs = self.config.num_epochs
|
| 542 |
-
for epoch in range(self.first_epoch, epochs):
|
| 543 |
-
global_step = self.step(epoch, global_step)
|
| 544 |
-
|
| 545 |
-
def _save_pretrained(self, save_directory):
|
| 546 |
-
self.sd_pipeline.save_pretrained(save_directory)
|
| 547 |
-
self.create_model_card()
|
| 548 |
-
|
| 549 |
-
def create_model_card(
|
| 550 |
-
self,
|
| 551 |
-
model_name: Optional[str] = None,
|
| 552 |
-
dataset_name: Optional[str] = None,
|
| 553 |
-
tags: Union[str, list[str], None] = None,
|
| 554 |
-
):
|
| 555 |
-
"""
|
| 556 |
-
Creates a draft of a model card using the information available to the `Trainer`.
|
| 557 |
-
|
| 558 |
-
Args:
|
| 559 |
-
model_name (`str` or `None`, *optional*, defaults to `None`):
|
| 560 |
-
Name of the model.
|
| 561 |
-
dataset_name (`str` or `None`, *optional*, defaults to `None`):
|
| 562 |
-
Name of the dataset used for training.
|
| 563 |
-
tags (`str`, `list[str]` or `None`, *optional*, defaults to `None`):
|
| 564 |
-
Tags to be associated with the model card.
|
| 565 |
-
"""
|
| 566 |
-
if not self.is_world_process_zero():
|
| 567 |
-
return
|
| 568 |
-
|
| 569 |
-
if hasattr(self.model.config, "_name_or_path") and not os.path.isdir(self.model.config._name_or_path):
|
| 570 |
-
base_model = self.model.config._name_or_path
|
| 571 |
-
else:
|
| 572 |
-
base_model = None
|
| 573 |
-
|
| 574 |
-
tags = tags or []
|
| 575 |
-
if isinstance(tags, str):
|
| 576 |
-
tags = [tags]
|
| 577 |
-
|
| 578 |
-
if hasattr(self.model.config, "unsloth_version"):
|
| 579 |
-
tags.append("unsloth")
|
| 580 |
-
|
| 581 |
-
citation = textwrap.dedent("""\
|
| 582 |
-
@article{prabhudesai2024aligning,
|
| 583 |
-
title = {{Aligning Text-to-Image Diffusion Models with Reward Backpropagation}},
|
| 584 |
-
author = {Mihir Prabhudesai and Anirudh Goyal and Deepak Pathak and Katerina Fragkiadaki},
|
| 585 |
-
year = 2024,
|
| 586 |
-
eprint = {arXiv:2310.03739}
|
| 587 |
-
}""")
|
| 588 |
-
|
| 589 |
-
model_card = generate_model_card(
|
| 590 |
-
base_model=base_model,
|
| 591 |
-
model_name=model_name,
|
| 592 |
-
hub_model_id=self.hub_model_id,
|
| 593 |
-
dataset_name=dataset_name,
|
| 594 |
-
tags=tags,
|
| 595 |
-
wandb_url=wandb.run.get_url() if is_wandb_available() and wandb.run is not None else None,
|
| 596 |
-
comet_url=get_comet_experiment_url(),
|
| 597 |
-
trainer_name="AlignProp",
|
| 598 |
-
trainer_citation=citation,
|
| 599 |
-
paper_title="Aligning Text-to-Image Diffusion Models with Reward Backpropagation",
|
| 600 |
-
paper_id="2310.03739",
|
| 601 |
-
)
|
| 602 |
-
|
| 603 |
-
model_card.save(os.path.join(self.args.output_dir, "README.md"))
|
| 604 |
-
class UnslothAlignPropTrainer(_UnslothAlignPropTrainer):
|
| 605 |
-
"""
|
| 606 |
-
|
| 607 |
-
The AlignPropTrainer uses Deep Diffusion Policy Optimization to optimise diffusion models.
|
| 608 |
-
Note, this trainer is heavily inspired by the work here: https://github.com/mihirp1998/AlignProp/
|
| 609 |
-
As of now only Stable Diffusion based pipelines are supported
|
| 610 |
-
|
| 611 |
-
Attributes:
|
| 612 |
-
config (`AlignPropConfig`):
|
| 613 |
-
Configuration object for AlignPropTrainer. Check the documentation of `PPOConfig` for more details.
|
| 614 |
-
reward_function (`Callable[[torch.Tensor, tuple[str], tuple[Any]], torch.Tensor]`):
|
| 615 |
-
Reward function to be used
|
| 616 |
-
prompt_function (`Callable[[], tuple[str, Any]]`):
|
| 617 |
-
Function to generate prompts to guide model
|
| 618 |
-
sd_pipeline (`DDPOStableDiffusionPipeline`):
|
| 619 |
-
Stable Diffusion pipeline to be used for training.
|
| 620 |
-
image_samples_hook (`Optional[Callable[[Any, Any, Any], Any]]`):
|
| 621 |
-
Hook to be called to log images
|
| 622 |
-
|
| 623 |
-
"""
|
| 624 |
-
def __init__(
|
| 625 |
-
self,
|
| 626 |
-
config,
|
| 627 |
-
reward_function,
|
| 628 |
-
prompt_function,
|
| 629 |
-
sd_pipeline,
|
| 630 |
-
image_samples_hook = None,
|
| 631 |
-
**kwargs
|
| 632 |
-
):
|
| 633 |
-
if args is None: args = UnslothAlignPropConfig()
|
| 634 |
-
other_metrics = []
|
| 635 |
-
|
| 636 |
-
from unsloth_zoo.logging_utils import PatchRLStatistics
|
| 637 |
-
PatchRLStatistics('alignprop_trainer', other_metrics)
|
| 638 |
-
|
| 639 |
-
super().__init__(
|
| 640 |
-
config = config,
|
| 641 |
-
reward_function = reward_function,
|
| 642 |
-
prompt_function = prompt_function,
|
| 643 |
-
sd_pipeline = sd_pipeline,
|
| 644 |
-
image_samples_hook = image_samples_hook,**kwargs)
|
| 645 |
-
|
| 646 |
-
pass
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
test_run_uploads/UnslothBCOTrainer.py
DELETED
|
@@ -1,1834 +0,0 @@
|
|
| 1 |
-
"""
|
| 2 |
-
2025.7.11
|
| 3 |
-
2025.7.11
|
| 4 |
-
4.54.1
|
| 5 |
-
0.16.1
|
| 6 |
-
__UNSLOTH_VERSIONING__
|
| 7 |
-
"""
|
| 8 |
-
from torch import Tensor
|
| 9 |
-
import torch
|
| 10 |
-
import torch.nn as nn
|
| 11 |
-
from torch.nn import functional as F
|
| 12 |
-
from typing import Any, List, Optional, Tuple, Union, Dict, Set, Callable
|
| 13 |
-
from trl.trainer.bco_trainer import (Any, AutoModelForCausalLM, BCOConfig, BCOTrainer, BaseImageProcessor, CLF_NAME, Callable, DPODataCollatorWithPadding, DataCollator, DataLoader, Dataset, EvalLoopOutput, F, FeatureExtractionMixin, Literal, LogisticRegression, Optional, PartialState, PeftModel, PreTrainedModel, PreTrainedModelWrapper, PreTrainedTokenizerBase, ProcessorMixin, RUNNING_NAME, RunningMoments, SequentialSampler, Trainer, TrainerCallback, TrainingArguments, Union, _process_tokens, _tokenize, amp, contextmanager, create_reference_model, deepcopy, defaultdict, disable_dropout_in_model, generate_model_card, get_comet_experiment_url, has_length, inspect, is_comet_available, is_peft_available, is_sklearn_available, is_wandb_available, itemgetter, log_table_to_comet_experiment, maybe_apply_chat_template, nn, np, nullcontext, os, pad_to_length, pd, peft_module_casting_to_bf16, prepare_model_for_kbit_training, random, selective_log_softmax, textwrap, torch, tqdm, transformers, version, wandb, warnings, F, Optional, PeftModel, PreTrainedModel, Trainer, is_peft_available, os, torch)
|
| 14 |
-
|
| 15 |
-
|
| 16 |
-
import os
|
| 17 |
-
from typing import *
|
| 18 |
-
from dataclasses import dataclass, field
|
| 19 |
-
from packaging.version import Version
|
| 20 |
-
import torch
|
| 21 |
-
import numpy as np
|
| 22 |
-
from contextlib import nullcontext
|
| 23 |
-
from torch.nn import functional as F
|
| 24 |
-
from transformers import DataCollatorForSeq2Seq, DataCollatorForLanguageModeling as TransformersDataCollatorForLanguageModeling
|
| 25 |
-
|
| 26 |
-
torch_compile_options = {
|
| 27 |
-
"epilogue_fusion" : True,
|
| 28 |
-
"max_autotune" : False,
|
| 29 |
-
"shape_padding" : True,
|
| 30 |
-
"trace.enabled" : False,
|
| 31 |
-
"triton.cudagraphs" : False,
|
| 32 |
-
}
|
| 33 |
-
|
| 34 |
-
@torch.compile(dynamic = True, fullgraph = True, options = torch_compile_options,)
|
| 35 |
-
def chunked_selective_log_softmax(logits, index):
|
| 36 |
-
# Split into 4 chunks only
|
| 37 |
-
chunked_logits = torch.chunk(logits.reshape(-1, logits.shape[-1]), chunks = 4, dim = 0)
|
| 38 |
-
chunked_index = torch.chunk(index.reshape(-1), chunks = 4, dim = 0)
|
| 39 |
-
all_per_token_logps = []
|
| 40 |
-
# Below loop does the same as selective_log_softmax(chunk_logits, chunk_index)
|
| 41 |
-
for chunk_logits, chunk_index in zip(chunked_logits, chunked_index):
|
| 42 |
-
chunk_logits = chunk_logits.to(torch.float32)
|
| 43 |
-
selected_logits = torch.gather(chunk_logits, dim = -1, index = chunk_index.unsqueeze(-1)).squeeze(-1)
|
| 44 |
-
logsumexp_values = torch.logsumexp(chunk_logits, dim = -1)
|
| 45 |
-
per_token_logps = selected_logits - logsumexp_values
|
| 46 |
-
all_per_token_logps.append(per_token_logps)
|
| 47 |
-
pass
|
| 48 |
-
all_per_token_logps = torch.concat(all_per_token_logps)
|
| 49 |
-
all_per_token_logps = all_per_token_logps.reshape((logits.shape[0], logits.shape[1]))
|
| 50 |
-
return all_per_token_logps
|
| 51 |
-
@dataclass
|
| 52 |
-
class UnslothBCOConfig(BCOConfig):
|
| 53 |
-
"""
|
| 54 |
-
|
| 55 |
-
Configuration class for the [`BCOTrainer`].
|
| 56 |
-
|
| 57 |
-
Using [`~transformers.HfArgumentParser`] we can turn this class into
|
| 58 |
-
[argparse](https://docs.python.org/3/library/argparse#module-argparse) arguments that can be specified on the
|
| 59 |
-
command line.
|
| 60 |
-
|
| 61 |
-
Parameters:
|
| 62 |
-
max_length (`int` or `None`, *optional*, defaults to `1024`):
|
| 63 |
-
Maximum length of the sequences (prompt + completion) in the batch. This argument is required if you want
|
| 64 |
-
to use the default data collator.
|
| 65 |
-
max_prompt_length (`int` or `None`, *optional*, defaults to `512`):
|
| 66 |
-
Maximum length of the prompt. This argument is required if you want to use the default data collator.
|
| 67 |
-
max_completion_length (`int` or `None`, *optional*, defaults to `None`):
|
| 68 |
-
Maximum length of the completion. This argument is required if you want to use the default data collator
|
| 69 |
-
and your model is an encoder-decoder.
|
| 70 |
-
beta (`float`, *optional*, defaults to `0.1`):
|
| 71 |
-
Parameter controlling the deviation from the reference model. Higher β means less deviation from the
|
| 72 |
-
reference model.
|
| 73 |
-
label_pad_token_id (`int`, *optional*, defaults to `-100`):
|
| 74 |
-
Label pad token id. This argument is required if you want to use the default data collator.
|
| 75 |
-
padding_value (`int` or `None`, *optional*, defaults to `None`):
|
| 76 |
-
Padding value to use. If `None`, the padding value of the tokenizer is used.
|
| 77 |
-
truncation_mode (`str`, *optional*, defaults to `"keep_end"`):
|
| 78 |
-
Truncation mode to use when the prompt is too long. Possible values are `"keep_end"` or `"keep_start"`.
|
| 79 |
-
This argument is required if you want to use the default data collator.
|
| 80 |
-
disable_dropout (`bool`, *optional*, defaults to `True`):
|
| 81 |
-
Whether to disable dropout in the model and reference model.
|
| 82 |
-
generate_during_eval (`bool`, *optional*, defaults to `False`):
|
| 83 |
-
If `True`, generates and logs completions from both the model and the reference model to W&B or Comet during
|
| 84 |
-
evaluation.
|
| 85 |
-
is_encoder_decoder (`bool` or `None`, *optional*, defaults to `None`):
|
| 86 |
-
When using the `model_init` argument (callable) to instantiate the model instead of the `model` argument,
|
| 87 |
-
you need to specify if the model returned by the callable is an encoder-decoder model.
|
| 88 |
-
precompute_ref_log_probs (`bool`, *optional*, defaults to `False`):
|
| 89 |
-
Whether to precompute reference model log probabilities for training and evaluation datasets. This is
|
| 90 |
-
useful when training without the reference model to reduce the total GPU memory needed.
|
| 91 |
-
model_init_kwargs (`dict[str, Any]` or `None`, *optional*, defaults to `None`):
|
| 92 |
-
Keyword arguments to pass to `AutoModelForCausalLM.from_pretrained` when instantiating the model from a
|
| 93 |
-
string.
|
| 94 |
-
ref_model_init_kwargs (`dict[str, Any]` or `None`, *optional*, defaults to `None`):
|
| 95 |
-
Keyword arguments to pass to `AutoModelForCausalLM.from_pretrained` when instantiating the reference model
|
| 96 |
-
from a string.
|
| 97 |
-
dataset_num_proc (`int` or `None`, *optional*, defaults to `None`):
|
| 98 |
-
Number of processes to use for processing the dataset.
|
| 99 |
-
prompt_sample_size (`int`, *optional*, defaults to `1024`):
|
| 100 |
-
Number of prompts that are fed to density ratio classifier.
|
| 101 |
-
min_density_ratio (`float`, *optional*, defaults to `0.5`):
|
| 102 |
-
Minimum value of the density ratio. The estimated density ratio is clamped to this value.
|
| 103 |
-
max_density_ratio (`float`, *optional*, defaults to `10.0`):
|
| 104 |
-
Maximum value of the density ratio. The estimated density ratio is clamped to this value.
|
| 105 |
-
|
| 106 |
-
"""
|
| 107 |
-
vllm_sampling_params: Optional[Any] = field(
|
| 108 |
-
default = None,
|
| 109 |
-
metadata = {'help': 'vLLM SamplingParams'},
|
| 110 |
-
)
|
| 111 |
-
unsloth_num_chunks : Optional[int] = field(
|
| 112 |
-
default = -1,
|
| 113 |
-
metadata = {'help': 'Chunk size to reduce memory usage. -1 is most efficient.'},
|
| 114 |
-
)
|
| 115 |
-
def __init__(
|
| 116 |
-
self,
|
| 117 |
-
output_dir = None,
|
| 118 |
-
overwrite_output_dir = None,
|
| 119 |
-
do_train = False,
|
| 120 |
-
do_eval = False,
|
| 121 |
-
do_predict = False,
|
| 122 |
-
eval_strategy = 'no',
|
| 123 |
-
prediction_loss_only = False,
|
| 124 |
-
per_device_train_batch_size = 4,
|
| 125 |
-
per_device_eval_batch_size = 4,
|
| 126 |
-
per_gpu_train_batch_size = None,
|
| 127 |
-
per_gpu_eval_batch_size = None,
|
| 128 |
-
gradient_accumulation_steps = 2,
|
| 129 |
-
eval_accumulation_steps = 2,
|
| 130 |
-
eval_delay = 0,
|
| 131 |
-
torch_empty_cache_steps = 250,
|
| 132 |
-
learning_rate = 5e-05,
|
| 133 |
-
weight_decay = 0.01,
|
| 134 |
-
adam_beta1 = 0.9,
|
| 135 |
-
adam_beta2 = 0.999,
|
| 136 |
-
adam_epsilon = 1e-08,
|
| 137 |
-
max_grad_norm = 1.0,
|
| 138 |
-
num_train_epochs = 3.0,
|
| 139 |
-
max_steps = -1,
|
| 140 |
-
lr_scheduler_type = 'linear',
|
| 141 |
-
warmup_ratio = 0.1,
|
| 142 |
-
warmup_steps = 0,
|
| 143 |
-
log_level = 'passive',
|
| 144 |
-
log_level_replica = 'warning',
|
| 145 |
-
log_on_each_node = True,
|
| 146 |
-
logging_dir = None,
|
| 147 |
-
logging_strategy = 'steps',
|
| 148 |
-
logging_first_step = False,
|
| 149 |
-
logging_steps = 1,
|
| 150 |
-
logging_nan_inf_filter = False,
|
| 151 |
-
save_strategy = 'steps',
|
| 152 |
-
save_steps = 500,
|
| 153 |
-
save_total_limit = None,
|
| 154 |
-
save_safetensors = True,
|
| 155 |
-
save_on_each_node = False,
|
| 156 |
-
save_only_model = False,
|
| 157 |
-
restore_callback_states_from_checkpoint = False,
|
| 158 |
-
no_cuda = False,
|
| 159 |
-
use_cpu = False,
|
| 160 |
-
use_mps_device = False,
|
| 161 |
-
seed = 3407,
|
| 162 |
-
data_seed = 3407,
|
| 163 |
-
jit_mode_eval = False,
|
| 164 |
-
use_ipex = False,
|
| 165 |
-
bf16 = False,
|
| 166 |
-
fp16 = False,
|
| 167 |
-
fp16_opt_level = 'O1',
|
| 168 |
-
half_precision_backend = 'auto',
|
| 169 |
-
bf16_full_eval = False,
|
| 170 |
-
fp16_full_eval = False,
|
| 171 |
-
tf32 = None,
|
| 172 |
-
local_rank = -1,
|
| 173 |
-
ddp_backend = None,
|
| 174 |
-
tpu_num_cores = None,
|
| 175 |
-
tpu_metrics_debug = False,
|
| 176 |
-
debug = '',
|
| 177 |
-
dataloader_drop_last = False,
|
| 178 |
-
eval_steps = None,
|
| 179 |
-
dataloader_num_workers = 0,
|
| 180 |
-
dataloader_prefetch_factor = None,
|
| 181 |
-
past_index = -1,
|
| 182 |
-
run_name = None,
|
| 183 |
-
disable_tqdm = None,
|
| 184 |
-
remove_unused_columns = True,
|
| 185 |
-
label_names = None,
|
| 186 |
-
load_best_model_at_end = False,
|
| 187 |
-
metric_for_best_model = None,
|
| 188 |
-
greater_is_better = None,
|
| 189 |
-
ignore_data_skip = False,
|
| 190 |
-
fsdp = '',
|
| 191 |
-
fsdp_min_num_params = 0,
|
| 192 |
-
fsdp_config = None,
|
| 193 |
-
fsdp_transformer_layer_cls_to_wrap = None,
|
| 194 |
-
accelerator_config = None,
|
| 195 |
-
deepspeed = None,
|
| 196 |
-
label_smoothing_factor = 0.0,
|
| 197 |
-
optim = 'adamw_8bit',
|
| 198 |
-
optim_args = None,
|
| 199 |
-
adafactor = False,
|
| 200 |
-
group_by_length = False,
|
| 201 |
-
length_column_name = 'length',
|
| 202 |
-
report_to = None,
|
| 203 |
-
ddp_find_unused_parameters = None,
|
| 204 |
-
ddp_bucket_cap_mb = None,
|
| 205 |
-
ddp_broadcast_buffers = None,
|
| 206 |
-
dataloader_pin_memory = True,
|
| 207 |
-
dataloader_persistent_workers = False,
|
| 208 |
-
skip_memory_metrics = True,
|
| 209 |
-
use_legacy_prediction_loop = False,
|
| 210 |
-
push_to_hub = False,
|
| 211 |
-
resume_from_checkpoint = None,
|
| 212 |
-
hub_model_id = None,
|
| 213 |
-
hub_strategy = 'every_save',
|
| 214 |
-
hub_token = None,
|
| 215 |
-
hub_private_repo = None,
|
| 216 |
-
hub_always_push = False,
|
| 217 |
-
hub_revision = None,
|
| 218 |
-
gradient_checkpointing = False,
|
| 219 |
-
gradient_checkpointing_kwargs = None,
|
| 220 |
-
include_inputs_for_metrics = False,
|
| 221 |
-
eval_do_concat_batches = True,
|
| 222 |
-
fp16_backend = 'auto',
|
| 223 |
-
push_to_hub_model_id = None,
|
| 224 |
-
push_to_hub_organization = None,
|
| 225 |
-
push_to_hub_token = None,
|
| 226 |
-
mp_parameters = '',
|
| 227 |
-
auto_find_batch_size = True,
|
| 228 |
-
full_determinism = False,
|
| 229 |
-
torchdynamo = None,
|
| 230 |
-
ray_scope = 'last',
|
| 231 |
-
ddp_timeout = 1800,
|
| 232 |
-
torch_compile = False,
|
| 233 |
-
torch_compile_backend = None,
|
| 234 |
-
torch_compile_mode = None,
|
| 235 |
-
include_tokens_per_second = False,
|
| 236 |
-
include_num_input_tokens_seen = False,
|
| 237 |
-
neftune_noise_alpha = None,
|
| 238 |
-
optim_target_modules = None,
|
| 239 |
-
batch_eval_metrics = False,
|
| 240 |
-
eval_on_start = False,
|
| 241 |
-
use_liger_kernel = False,
|
| 242 |
-
liger_kernel_config = None,
|
| 243 |
-
eval_use_gather_object = False,
|
| 244 |
-
average_tokens_across_devices = True,
|
| 245 |
-
max_length = 1024,
|
| 246 |
-
max_prompt_length = 512,
|
| 247 |
-
max_completion_length = None,
|
| 248 |
-
beta = 0.1,
|
| 249 |
-
label_pad_token_id = -100,
|
| 250 |
-
padding_value = None,
|
| 251 |
-
truncation_mode = 'keep_end',
|
| 252 |
-
disable_dropout = True,
|
| 253 |
-
generate_during_eval = False,
|
| 254 |
-
is_encoder_decoder = None,
|
| 255 |
-
precompute_ref_log_probs = False,
|
| 256 |
-
model_init_kwargs = None,
|
| 257 |
-
ref_model_init_kwargs = None,
|
| 258 |
-
dataset_num_proc = None,
|
| 259 |
-
prompt_sample_size = 1024,
|
| 260 |
-
min_density_ratio = 0.5,
|
| 261 |
-
max_density_ratio = 10.0,
|
| 262 |
-
vllm_sampling_params = None,
|
| 263 |
-
unsloth_num_chunks = -1,
|
| 264 |
-
**kwargs,
|
| 265 |
-
):
|
| 266 |
-
if learning_rate < 1e-7: raise FloatingPointError(f'Unsloth: Your learning rate of `{learning_rate}` is too small and less than 1e-7! Consider increasing it, otherwise gradient updates will be close to 0!')
|
| 267 |
-
if learning_rate > 1: raise OverflowError(f'Unsloth: Your learning rate of `{learning_rate}` is way too larger > 1! Consider decreasing it to 1e-1, otherwise gradient updates will explode!')
|
| 268 |
-
if output_dir is None and save_strategy == 'steps' and save_steps == 500:
|
| 269 |
-
output_dir = 'unsloth_training_checkpoints'
|
| 270 |
-
save_strategy = 'no'
|
| 271 |
-
if dataset_num_proc is None:
|
| 272 |
-
from multiprocessing import cpu_count
|
| 273 |
-
dataset_num_proc = min(cpu_count()*2, 2)
|
| 274 |
-
|
| 275 |
-
super().__init__(
|
| 276 |
-
output_dir = output_dir,
|
| 277 |
-
overwrite_output_dir = overwrite_output_dir,
|
| 278 |
-
do_train = do_train,
|
| 279 |
-
do_eval = do_eval,
|
| 280 |
-
do_predict = do_predict,
|
| 281 |
-
eval_strategy = eval_strategy,
|
| 282 |
-
prediction_loss_only = prediction_loss_only,
|
| 283 |
-
per_device_train_batch_size = per_device_train_batch_size,
|
| 284 |
-
per_device_eval_batch_size = per_device_eval_batch_size,
|
| 285 |
-
per_gpu_train_batch_size = per_gpu_train_batch_size,
|
| 286 |
-
per_gpu_eval_batch_size = per_gpu_eval_batch_size,
|
| 287 |
-
gradient_accumulation_steps = gradient_accumulation_steps,
|
| 288 |
-
eval_accumulation_steps = eval_accumulation_steps,
|
| 289 |
-
eval_delay = eval_delay,
|
| 290 |
-
torch_empty_cache_steps = torch_empty_cache_steps,
|
| 291 |
-
learning_rate = learning_rate,
|
| 292 |
-
weight_decay = weight_decay,
|
| 293 |
-
adam_beta1 = adam_beta1,
|
| 294 |
-
adam_beta2 = adam_beta2,
|
| 295 |
-
adam_epsilon = adam_epsilon,
|
| 296 |
-
max_grad_norm = max_grad_norm,
|
| 297 |
-
num_train_epochs = num_train_epochs,
|
| 298 |
-
max_steps = max_steps,
|
| 299 |
-
lr_scheduler_type = lr_scheduler_type,
|
| 300 |
-
warmup_ratio = warmup_ratio,
|
| 301 |
-
warmup_steps = warmup_steps,
|
| 302 |
-
log_level = log_level,
|
| 303 |
-
log_level_replica = log_level_replica,
|
| 304 |
-
log_on_each_node = log_on_each_node,
|
| 305 |
-
logging_dir = logging_dir,
|
| 306 |
-
logging_strategy = logging_strategy,
|
| 307 |
-
logging_first_step = logging_first_step,
|
| 308 |
-
logging_steps = logging_steps,
|
| 309 |
-
logging_nan_inf_filter = logging_nan_inf_filter,
|
| 310 |
-
save_strategy = save_strategy,
|
| 311 |
-
save_steps = save_steps,
|
| 312 |
-
save_total_limit = save_total_limit,
|
| 313 |
-
save_safetensors = save_safetensors,
|
| 314 |
-
save_on_each_node = save_on_each_node,
|
| 315 |
-
save_only_model = save_only_model,
|
| 316 |
-
restore_callback_states_from_checkpoint = restore_callback_states_from_checkpoint,
|
| 317 |
-
no_cuda = no_cuda,
|
| 318 |
-
use_cpu = use_cpu,
|
| 319 |
-
use_mps_device = use_mps_device,
|
| 320 |
-
seed = seed,
|
| 321 |
-
data_seed = data_seed,
|
| 322 |
-
jit_mode_eval = jit_mode_eval,
|
| 323 |
-
use_ipex = use_ipex,
|
| 324 |
-
bf16 = bf16,
|
| 325 |
-
fp16 = fp16,
|
| 326 |
-
fp16_opt_level = fp16_opt_level,
|
| 327 |
-
half_precision_backend = half_precision_backend,
|
| 328 |
-
bf16_full_eval = bf16_full_eval,
|
| 329 |
-
fp16_full_eval = fp16_full_eval,
|
| 330 |
-
tf32 = tf32,
|
| 331 |
-
local_rank = local_rank,
|
| 332 |
-
ddp_backend = ddp_backend,
|
| 333 |
-
tpu_num_cores = tpu_num_cores,
|
| 334 |
-
tpu_metrics_debug = tpu_metrics_debug,
|
| 335 |
-
debug = debug,
|
| 336 |
-
dataloader_drop_last = dataloader_drop_last,
|
| 337 |
-
eval_steps = eval_steps,
|
| 338 |
-
dataloader_num_workers = dataloader_num_workers,
|
| 339 |
-
dataloader_prefetch_factor = dataloader_prefetch_factor,
|
| 340 |
-
past_index = past_index,
|
| 341 |
-
run_name = run_name,
|
| 342 |
-
disable_tqdm = disable_tqdm,
|
| 343 |
-
remove_unused_columns = remove_unused_columns,
|
| 344 |
-
label_names = label_names,
|
| 345 |
-
load_best_model_at_end = load_best_model_at_end,
|
| 346 |
-
metric_for_best_model = metric_for_best_model,
|
| 347 |
-
greater_is_better = greater_is_better,
|
| 348 |
-
ignore_data_skip = ignore_data_skip,
|
| 349 |
-
fsdp = fsdp,
|
| 350 |
-
fsdp_min_num_params = fsdp_min_num_params,
|
| 351 |
-
fsdp_config = fsdp_config,
|
| 352 |
-
fsdp_transformer_layer_cls_to_wrap = fsdp_transformer_layer_cls_to_wrap,
|
| 353 |
-
accelerator_config = accelerator_config,
|
| 354 |
-
deepspeed = deepspeed,
|
| 355 |
-
label_smoothing_factor = label_smoothing_factor,
|
| 356 |
-
optim = optim,
|
| 357 |
-
optim_args = optim_args,
|
| 358 |
-
adafactor = adafactor,
|
| 359 |
-
group_by_length = group_by_length,
|
| 360 |
-
length_column_name = length_column_name,
|
| 361 |
-
report_to = report_to,
|
| 362 |
-
ddp_find_unused_parameters = ddp_find_unused_parameters,
|
| 363 |
-
ddp_bucket_cap_mb = ddp_bucket_cap_mb,
|
| 364 |
-
ddp_broadcast_buffers = ddp_broadcast_buffers,
|
| 365 |
-
dataloader_pin_memory = dataloader_pin_memory,
|
| 366 |
-
dataloader_persistent_workers = dataloader_persistent_workers,
|
| 367 |
-
skip_memory_metrics = skip_memory_metrics,
|
| 368 |
-
use_legacy_prediction_loop = use_legacy_prediction_loop,
|
| 369 |
-
push_to_hub = push_to_hub,
|
| 370 |
-
resume_from_checkpoint = resume_from_checkpoint,
|
| 371 |
-
hub_model_id = hub_model_id,
|
| 372 |
-
hub_strategy = hub_strategy,
|
| 373 |
-
hub_token = hub_token,
|
| 374 |
-
hub_private_repo = hub_private_repo,
|
| 375 |
-
hub_always_push = hub_always_push,
|
| 376 |
-
hub_revision = hub_revision,
|
| 377 |
-
gradient_checkpointing = gradient_checkpointing,
|
| 378 |
-
gradient_checkpointing_kwargs = gradient_checkpointing_kwargs,
|
| 379 |
-
include_inputs_for_metrics = include_inputs_for_metrics,
|
| 380 |
-
eval_do_concat_batches = eval_do_concat_batches,
|
| 381 |
-
fp16_backend = fp16_backend,
|
| 382 |
-
push_to_hub_model_id = push_to_hub_model_id,
|
| 383 |
-
push_to_hub_organization = push_to_hub_organization,
|
| 384 |
-
push_to_hub_token = push_to_hub_token,
|
| 385 |
-
mp_parameters = mp_parameters,
|
| 386 |
-
auto_find_batch_size = auto_find_batch_size,
|
| 387 |
-
full_determinism = full_determinism,
|
| 388 |
-
torchdynamo = torchdynamo,
|
| 389 |
-
ray_scope = ray_scope,
|
| 390 |
-
ddp_timeout = ddp_timeout,
|
| 391 |
-
torch_compile = torch_compile,
|
| 392 |
-
torch_compile_backend = torch_compile_backend,
|
| 393 |
-
torch_compile_mode = torch_compile_mode,
|
| 394 |
-
include_tokens_per_second = include_tokens_per_second,
|
| 395 |
-
include_num_input_tokens_seen = include_num_input_tokens_seen,
|
| 396 |
-
neftune_noise_alpha = neftune_noise_alpha,
|
| 397 |
-
optim_target_modules = optim_target_modules,
|
| 398 |
-
batch_eval_metrics = batch_eval_metrics,
|
| 399 |
-
eval_on_start = eval_on_start,
|
| 400 |
-
use_liger_kernel = use_liger_kernel,
|
| 401 |
-
liger_kernel_config = liger_kernel_config,
|
| 402 |
-
eval_use_gather_object = eval_use_gather_object,
|
| 403 |
-
average_tokens_across_devices = average_tokens_across_devices,
|
| 404 |
-
max_length = max_length,
|
| 405 |
-
max_prompt_length = max_prompt_length,
|
| 406 |
-
max_completion_length = max_completion_length,
|
| 407 |
-
beta = beta,
|
| 408 |
-
label_pad_token_id = label_pad_token_id,
|
| 409 |
-
padding_value = padding_value,
|
| 410 |
-
truncation_mode = truncation_mode,
|
| 411 |
-
disable_dropout = disable_dropout,
|
| 412 |
-
generate_during_eval = generate_during_eval,
|
| 413 |
-
is_encoder_decoder = is_encoder_decoder,
|
| 414 |
-
precompute_ref_log_probs = precompute_ref_log_probs,
|
| 415 |
-
model_init_kwargs = model_init_kwargs,
|
| 416 |
-
ref_model_init_kwargs = ref_model_init_kwargs,
|
| 417 |
-
dataset_num_proc = dataset_num_proc,
|
| 418 |
-
prompt_sample_size = prompt_sample_size,
|
| 419 |
-
min_density_ratio = min_density_ratio,
|
| 420 |
-
max_density_ratio = max_density_ratio,**kwargs)
|
| 421 |
-
self.vllm_sampling_params = vllm_sampling_params
|
| 422 |
-
self.unsloth_num_chunks = unsloth_num_chunks
|
| 423 |
-
pass
|
| 424 |
-
|
| 425 |
-
class _UnslothBCOTrainer(Trainer):
|
| 426 |
-
r""""""
|
| 427 |
-
|
| 428 |
-
_tag_names = ["trl", "bco"]
|
| 429 |
-
|
| 430 |
-
def __init__(
|
| 431 |
-
self,
|
| 432 |
-
model: Union[PreTrainedModel, nn.Module, str] = None,
|
| 433 |
-
ref_model: Optional[Union[PreTrainedModel, nn.Module, str]] = None,
|
| 434 |
-
args: BCOConfig = None,
|
| 435 |
-
train_dataset: Optional[Dataset] = None,
|
| 436 |
-
eval_dataset: Optional[Union[Dataset, dict[str, Dataset]]] = None,
|
| 437 |
-
processing_class: Optional[
|
| 438 |
-
Union[PreTrainedTokenizerBase, BaseImageProcessor, FeatureExtractionMixin, ProcessorMixin]
|
| 439 |
-
] = None,
|
| 440 |
-
data_collator: Optional[DataCollator] = None,
|
| 441 |
-
model_init: Optional[Callable[[], PreTrainedModel]] = None,
|
| 442 |
-
callbacks: Optional[list[TrainerCallback]] = None,
|
| 443 |
-
optimizers: tuple[torch.optim.Optimizer, torch.optim.lr_scheduler.LambdaLR] = (None, None),
|
| 444 |
-
preprocess_logits_for_metrics: Optional[Callable[[torch.Tensor, torch.Tensor], torch.Tensor]] = None,
|
| 445 |
-
peft_config: Optional[dict] = None,
|
| 446 |
-
compute_metrics: Optional[Callable[[EvalLoopOutput], dict]] = None,
|
| 447 |
-
model_adapter_name: Optional[str] = None,
|
| 448 |
-
ref_adapter_name: Optional[str] = None,
|
| 449 |
-
embedding_func: Optional[Callable] = None,
|
| 450 |
-
embedding_tokenizer: Optional[PreTrainedTokenizerBase] = None,
|
| 451 |
-
):
|
| 452 |
-
if not is_sklearn_available():
|
| 453 |
-
raise ImportError(
|
| 454 |
-
"BCOTrainer requires the scikit-learn library. Please install it with `pip install scikit-learn`."
|
| 455 |
-
)
|
| 456 |
-
|
| 457 |
-
if type(args) is TrainingArguments:
|
| 458 |
-
raise ValueError("Please use `BCOConfig` instead `TrainingArguments`.")
|
| 459 |
-
|
| 460 |
-
if not isinstance(model, str) and ref_model is model:
|
| 461 |
-
raise ValueError(
|
| 462 |
-
"`model` and `ref_model` cannot be the same object. If you want `ref_model` to be the "
|
| 463 |
-
"same as `model`, you must mass a copy of it, or `None` if you use peft."
|
| 464 |
-
)
|
| 465 |
-
|
| 466 |
-
if args.model_init_kwargs is None:
|
| 467 |
-
model_init_kwargs = {}
|
| 468 |
-
elif not isinstance(model, str):
|
| 469 |
-
raise ValueError("You passed model_kwargs to the BCOTrainer. But your model is already instantiated.")
|
| 470 |
-
else:
|
| 471 |
-
model_init_kwargs = args.model_init_kwargs
|
| 472 |
-
torch_dtype = model_init_kwargs.get("torch_dtype")
|
| 473 |
-
if torch_dtype is not None:
|
| 474 |
-
# Convert to `torch.dtype` if an str is passed
|
| 475 |
-
if isinstance(torch_dtype, str) and torch_dtype != "auto":
|
| 476 |
-
torch_dtype = getattr(torch, torch_dtype)
|
| 477 |
-
if torch_dtype != "auto" and not isinstance(torch_dtype, torch.dtype):
|
| 478 |
-
raise ValueError(
|
| 479 |
-
f"Invalid `torch_dtype` passed to the BCOConfig. Expected a string with either `torch.dtype` or 'auto', but got {torch_dtype}."
|
| 480 |
-
)
|
| 481 |
-
model_init_kwargs["torch_dtype"] = torch_dtype
|
| 482 |
-
|
| 483 |
-
if args.ref_model_init_kwargs is None:
|
| 484 |
-
ref_model_init_kwargs = {}
|
| 485 |
-
elif not isinstance(ref_model, str):
|
| 486 |
-
raise ValueError(
|
| 487 |
-
"You passed ref_model_kwargs to the BCOTrainer. But your ref_model is already instantiated."
|
| 488 |
-
)
|
| 489 |
-
else:
|
| 490 |
-
ref_model_init_kwargs = args.ref_model_init_kwargs
|
| 491 |
-
torch_dtype = ref_model_init_kwargs.get("torch_dtype")
|
| 492 |
-
if torch_dtype is not None:
|
| 493 |
-
# Convert to `torch.dtype` if an str is passed
|
| 494 |
-
if isinstance(torch_dtype, str) and torch_dtype != "auto":
|
| 495 |
-
torch_dtype = getattr(torch, torch_dtype)
|
| 496 |
-
if torch_dtype != "auto" and not isinstance(torch_dtype, torch.dtype):
|
| 497 |
-
raise ValueError(
|
| 498 |
-
f"Invalid `torch_dtype` passed to the BCOConfig. Expected a string with either `torch.dtype` or 'auto', but got {torch_dtype}."
|
| 499 |
-
)
|
| 500 |
-
ref_model_init_kwargs["torch_dtype"] = torch_dtype
|
| 501 |
-
|
| 502 |
-
if isinstance(model, str):
|
| 503 |
-
model = AutoModelForCausalLM.from_pretrained(model, **model_init_kwargs)
|
| 504 |
-
|
| 505 |
-
if isinstance(ref_model, str):
|
| 506 |
-
ref_model = AutoModelForCausalLM.from_pretrained(ref_model, **ref_model_init_kwargs)
|
| 507 |
-
|
| 508 |
-
# Initialize this variable to False. This helps tracking the case when `peft_module_casting_to_bf16`
|
| 509 |
-
# has been called in order to properly call autocast if needed.
|
| 510 |
-
self._peft_has_been_casted_to_bf16 = False
|
| 511 |
-
|
| 512 |
-
if not is_peft_available() and peft_config is not None:
|
| 513 |
-
raise ValueError(
|
| 514 |
-
"PEFT is not installed and you passed a `peft_config` in the trainer's kwargs, please install it with `pip install peft` to use the PEFT models"
|
| 515 |
-
)
|
| 516 |
-
elif is_peft_available() and peft_config is not None:
|
| 517 |
-
# if model is a peft model and we have a peft_config, we merge and unload it first
|
| 518 |
-
if isinstance(model, PeftModel):
|
| 519 |
-
model = model.merge_and_unload()
|
| 520 |
-
|
| 521 |
-
if getattr(model, "is_loaded_in_8bit", False) or getattr(model, "is_loaded_in_4bit", False):
|
| 522 |
-
_support_gc_kwargs = hasattr(
|
| 523 |
-
args, "gradient_checkpointing_kwargs"
|
| 524 |
-
) and "gradient_checkpointing_kwargs" in list(
|
| 525 |
-
inspect.signature(prepare_model_for_kbit_training).parameters
|
| 526 |
-
)
|
| 527 |
-
|
| 528 |
-
prepare_model_kwargs = {"use_gradient_checkpointing": args.gradient_checkpointing}
|
| 529 |
-
|
| 530 |
-
if _support_gc_kwargs:
|
| 531 |
-
prepare_model_kwargs["gradient_checkpointing_kwargs"] = args.gradient_checkpointing_kwargs
|
| 532 |
-
|
| 533 |
-
model = prepare_model_for_kbit_training(model, **prepare_model_kwargs)
|
| 534 |
-
elif getattr(args, "gradient_checkpointing", False):
|
| 535 |
-
# For backward compatibility with older versions of transformers
|
| 536 |
-
if hasattr(model, "enable_input_require_grads"):
|
| 537 |
-
model.enable_input_require_grads()
|
| 538 |
-
else:
|
| 539 |
-
|
| 540 |
-
def make_inputs_require_grad(module, input, output):
|
| 541 |
-
output.requires_grad_(True)
|
| 542 |
-
|
| 543 |
-
model.get_input_embeddings().register_forward_hook(make_inputs_require_grad)
|
| 544 |
-
|
| 545 |
-
# get peft model with the given config
|
| 546 |
-
model = model
|
| 547 |
-
if args.bf16 and getattr(model, "is_loaded_in_4bit", False):
|
| 548 |
-
peft_module_casting_to_bf16(model)
|
| 549 |
-
# If args.bf16 we need to explicitly call `generate` with torch amp autocast context manager
|
| 550 |
-
self._peft_has_been_casted_to_bf16 = True
|
| 551 |
-
|
| 552 |
-
# For models that use gradient_checkpointing, we need to attach a hook that enables input
|
| 553 |
-
# to explicitly have `requires_grad=True`, otherwise training will either silently
|
| 554 |
-
# fail or completely fail.
|
| 555 |
-
elif getattr(args, "gradient_checkpointing", False):
|
| 556 |
-
# For backward compatibility with older versions of transformers
|
| 557 |
-
if hasattr(model, "enable_input_require_grads"):
|
| 558 |
-
model.enable_input_require_grads()
|
| 559 |
-
else:
|
| 560 |
-
|
| 561 |
-
def make_inputs_require_grad(module, input, output):
|
| 562 |
-
output.requires_grad_(True)
|
| 563 |
-
|
| 564 |
-
model.get_input_embeddings().register_forward_hook(make_inputs_require_grad)
|
| 565 |
-
|
| 566 |
-
if args.generate_during_eval and not (is_wandb_available() or is_comet_available()):
|
| 567 |
-
raise ValueError(
|
| 568 |
-
"`generate_during_eval=True` requires Weights and Biases or Comet to be installed."
|
| 569 |
-
" Please install `wandb` or `comet-ml` to resolve."
|
| 570 |
-
)
|
| 571 |
-
|
| 572 |
-
if model is not None:
|
| 573 |
-
self.is_encoder_decoder = model.config.is_encoder_decoder
|
| 574 |
-
elif args.is_encoder_decoder is None:
|
| 575 |
-
raise ValueError("When no model is provided, you need to pass the parameter is_encoder_decoder.")
|
| 576 |
-
else:
|
| 577 |
-
self.is_encoder_decoder = args.is_encoder_decoder
|
| 578 |
-
|
| 579 |
-
self.is_peft_model = is_peft_available() and isinstance(model, PeftModel)
|
| 580 |
-
self.model_adapter_name = model_adapter_name
|
| 581 |
-
self.ref_adapter_name = ref_adapter_name
|
| 582 |
-
|
| 583 |
-
if ref_model:
|
| 584 |
-
self.ref_model = ref_model
|
| 585 |
-
elif self.is_peft_model or args.precompute_ref_log_probs:
|
| 586 |
-
# The `model` with adapters turned off will be used as the reference model
|
| 587 |
-
self.ref_model = None
|
| 588 |
-
else:
|
| 589 |
-
self.ref_model = create_reference_model(model)
|
| 590 |
-
|
| 591 |
-
if processing_class is None:
|
| 592 |
-
raise ValueError(
|
| 593 |
-
"max_length or a processing_class must be specified when using the default DPODataCollatorWithPadding"
|
| 594 |
-
)
|
| 595 |
-
if args.max_length is None:
|
| 596 |
-
warnings.warn(
|
| 597 |
-
"When using DPODataCollatorWithPadding, you should set `max_length` in the `BCOConfig`. "
|
| 598 |
-
"It will be set to `512` by default, but you should do it yourself in the future.",
|
| 599 |
-
UserWarning,
|
| 600 |
-
)
|
| 601 |
-
max_length = 512
|
| 602 |
-
if args.max_length is not None:
|
| 603 |
-
max_length = args.max_length
|
| 604 |
-
|
| 605 |
-
if args.max_prompt_length is None:
|
| 606 |
-
warnings.warn(
|
| 607 |
-
"When using DPODataCollatorWithPadding, you should set `max_prompt_length` in the `BCOConfig`. "
|
| 608 |
-
"It will be set to `128` by default, but you should do it yourself in the future.",
|
| 609 |
-
UserWarning,
|
| 610 |
-
)
|
| 611 |
-
max_prompt_length = 128
|
| 612 |
-
if args.max_prompt_length is not None:
|
| 613 |
-
max_prompt_length = args.max_prompt_length
|
| 614 |
-
|
| 615 |
-
max_completion_length = None
|
| 616 |
-
if args.max_completion_length is None and self.is_encoder_decoder:
|
| 617 |
-
warnings.warn(
|
| 618 |
-
"When using DPODataCollatorWithPadding with an encoder decoder architecture, you should set `max_completion_length` in the BCOTrainer's init"
|
| 619 |
-
" it will be set to `128` by default, but you should do it yourself in the future.",
|
| 620 |
-
UserWarning,
|
| 621 |
-
)
|
| 622 |
-
max_completion_length = 128
|
| 623 |
-
if args.max_completion_length is not None and self.is_encoder_decoder:
|
| 624 |
-
max_completion_length = args.max_completion_length
|
| 625 |
-
|
| 626 |
-
if data_collator is None:
|
| 627 |
-
data_collator = DPODataCollatorWithPadding(
|
| 628 |
-
pad_token_id=processing_class.pad_token_id,
|
| 629 |
-
label_pad_token_id=args.label_pad_token_id,
|
| 630 |
-
is_encoder_decoder=self.is_encoder_decoder,
|
| 631 |
-
)
|
| 632 |
-
|
| 633 |
-
if args.remove_unused_columns:
|
| 634 |
-
args.remove_unused_columns = False
|
| 635 |
-
# warn users
|
| 636 |
-
warnings.warn(
|
| 637 |
-
"When using DPODataCollatorWithPadding, you should set `remove_unused_columns=False` in your BCOConfig"
|
| 638 |
-
" we have set it for you, but you should do it yourself in the future.",
|
| 639 |
-
UserWarning,
|
| 640 |
-
)
|
| 641 |
-
|
| 642 |
-
self.use_dpo_data_collator = True
|
| 643 |
-
else:
|
| 644 |
-
self.use_dpo_data_collator = False
|
| 645 |
-
|
| 646 |
-
# Disable dropout in the model and reference model
|
| 647 |
-
if args.disable_dropout:
|
| 648 |
-
disable_dropout_in_model(model)
|
| 649 |
-
if self.ref_model is not None:
|
| 650 |
-
disable_dropout_in_model(self.ref_model)
|
| 651 |
-
|
| 652 |
-
self.max_length = max_length
|
| 653 |
-
self.generate_during_eval = args.generate_during_eval
|
| 654 |
-
self.label_pad_token_id = args.label_pad_token_id
|
| 655 |
-
self.padding_value = args.padding_value if args.padding_value is not None else processing_class.pad_token_id
|
| 656 |
-
self.max_prompt_length = max_prompt_length
|
| 657 |
-
self.truncation_mode = args.truncation_mode
|
| 658 |
-
self.max_completion_length = max_completion_length
|
| 659 |
-
self.precompute_ref_log_probs = args.precompute_ref_log_probs
|
| 660 |
-
|
| 661 |
-
# Since ref_logs are precomputed on the first call to get_train/eval_dataloader
|
| 662 |
-
# keep track of first called to avoid computation of future calls
|
| 663 |
-
self._precomputed_train_ref_log_probs = False
|
| 664 |
-
self._precomputed_eval_ref_log_probs = False
|
| 665 |
-
|
| 666 |
-
# metric
|
| 667 |
-
self._stored_metrics = defaultdict(lambda: defaultdict(list))
|
| 668 |
-
|
| 669 |
-
# BCO parameter
|
| 670 |
-
self.beta = args.beta
|
| 671 |
-
self.aux_loss_enabled = getattr(model.config, "output_router_logits", False)
|
| 672 |
-
self.aux_loss_coef = getattr(model.config, "router_aux_loss_coef", 0.0)
|
| 673 |
-
if self.aux_loss_enabled and self.aux_loss_coef == 0.0:
|
| 674 |
-
warnings.warn(
|
| 675 |
-
"You set `output_router_logits` to `True` in the model config, but `router_aux_loss_coef` is set to "
|
| 676 |
-
"`0.0`, meaning the auxiliary loss will not be used. Either set `router_aux_loss_coef` to a value "
|
| 677 |
-
"greater than `0.0`, or set `output_router_logits` to `False` if you don't want to use the auxiliary "
|
| 678 |
-
"loss.",
|
| 679 |
-
UserWarning,
|
| 680 |
-
)
|
| 681 |
-
|
| 682 |
-
# Underlying Distribution Matching argument
|
| 683 |
-
self.embedding_func = embedding_func
|
| 684 |
-
self.embedding_tokenizer = embedding_tokenizer
|
| 685 |
-
|
| 686 |
-
# The trainer estimates the number of FLOPs [floating-point operations] using the number of elements in the
|
| 687 |
-
# input tensor associated with the key "input_ids". However, in BCO, the sampled data does not include the
|
| 688 |
-
# "input_ids" key. Instead, the available keys are "prompt_input_ids" and "completion_input_ids". As a result,
|
| 689 |
-
# the trainer issues the warning: "Could not estimate the number of tokens of the input, floating-point
|
| 690 |
-
# operations will not be computed." To suppress this warning, we set the "estimate_tokens" key in the model's
|
| 691 |
-
# "warnings_issued" dictionary to True. This acts as a flag to indicate that the warning has already been
|
| 692 |
-
# issued.
|
| 693 |
-
model.warnings_issued["estimate_tokens"] = True
|
| 694 |
-
|
| 695 |
-
with PartialState().main_process_first():
|
| 696 |
-
# Apply the chat template if needed
|
| 697 |
-
train_dataset = train_dataset.map(
|
| 698 |
-
maybe_apply_chat_template, fn_kwargs={"tokenizer": processing_class}, num_proc=args.dataset_num_proc
|
| 699 |
-
)
|
| 700 |
-
if eval_dataset is not None:
|
| 701 |
-
eval_dataset = eval_dataset.map(
|
| 702 |
-
maybe_apply_chat_template,
|
| 703 |
-
fn_kwargs={"tokenizer": processing_class},
|
| 704 |
-
num_proc=args.dataset_num_proc,
|
| 705 |
-
)
|
| 706 |
-
# Shuffle the datasets
|
| 707 |
-
train_dataset = train_dataset.shuffle(seed=args.data_seed)
|
| 708 |
-
if eval_dataset is not None:
|
| 709 |
-
eval_dataset = eval_dataset.shuffle(seed=args.data_seed)
|
| 710 |
-
# Tokenize and prepare the training datasets
|
| 711 |
-
train_dataset = train_dataset.map(
|
| 712 |
-
_tokenize,
|
| 713 |
-
batched=True,
|
| 714 |
-
fn_kwargs={"tokenizer": processing_class, "embedding_tokenizer": self.embedding_tokenizer},
|
| 715 |
-
num_proc=args.dataset_num_proc,
|
| 716 |
-
desc="Tokenizing train dataset",
|
| 717 |
-
)
|
| 718 |
-
|
| 719 |
-
# Prepare the datasets
|
| 720 |
-
fn_kwargs = {
|
| 721 |
-
"prefix": "",
|
| 722 |
-
"is_encoder_decoder": self.is_encoder_decoder,
|
| 723 |
-
"tokenizer": processing_class,
|
| 724 |
-
"max_length": self.max_length,
|
| 725 |
-
"truncation_mode": self.truncation_mode,
|
| 726 |
-
"label_pad_token_id": self.label_pad_token_id,
|
| 727 |
-
"max_prompt_length": self.max_prompt_length,
|
| 728 |
-
"max_completion_length": self.max_completion_length,
|
| 729 |
-
}
|
| 730 |
-
train_dataset = train_dataset.map(
|
| 731 |
-
_process_tokens,
|
| 732 |
-
fn_kwargs=fn_kwargs,
|
| 733 |
-
num_proc=args.dataset_num_proc,
|
| 734 |
-
desc="Processing tokenized train dataset",
|
| 735 |
-
)
|
| 736 |
-
|
| 737 |
-
if eval_dataset is not None:
|
| 738 |
-
# Tokenize
|
| 739 |
-
eval_dataset = eval_dataset.map(
|
| 740 |
-
_tokenize,
|
| 741 |
-
fn_kwargs={"tokenizer": processing_class, "embedding_tokenizer": self.embedding_tokenizer},
|
| 742 |
-
batched=True,
|
| 743 |
-
num_proc=args.dataset_num_proc,
|
| 744 |
-
desc="Tokenizing eval dataset",
|
| 745 |
-
)
|
| 746 |
-
|
| 747 |
-
# Process
|
| 748 |
-
fn_kwargs = {
|
| 749 |
-
"prefix": "",
|
| 750 |
-
"is_encoder_decoder": self.is_encoder_decoder,
|
| 751 |
-
"tokenizer": processing_class,
|
| 752 |
-
"max_length": self.max_length,
|
| 753 |
-
"truncation_mode": self.truncation_mode,
|
| 754 |
-
"label_pad_token_id": self.label_pad_token_id,
|
| 755 |
-
"max_prompt_length": self.max_prompt_length,
|
| 756 |
-
"max_completion_length": self.max_completion_length,
|
| 757 |
-
}
|
| 758 |
-
eval_dataset = eval_dataset.map(
|
| 759 |
-
_process_tokens,
|
| 760 |
-
fn_kwargs=fn_kwargs,
|
| 761 |
-
num_proc=args.dataset_num_proc,
|
| 762 |
-
desc="Processing tokenized eval dataset",
|
| 763 |
-
)
|
| 764 |
-
|
| 765 |
-
desirable = train_dataset.filter(
|
| 766 |
-
lambda x: x["label"], num_proc=args.dataset_num_proc, desc="Filtering desirable examples"
|
| 767 |
-
)
|
| 768 |
-
undesirable = train_dataset.filter(
|
| 769 |
-
lambda x: not x["label"], num_proc=args.dataset_num_proc, desc="Filtering undesirable examples"
|
| 770 |
-
)
|
| 771 |
-
|
| 772 |
-
desirable = desirable.shuffle(seed=args.data_seed)
|
| 773 |
-
undesirable = undesirable.shuffle(seed=args.data_seed)
|
| 774 |
-
|
| 775 |
-
super().__init__(
|
| 776 |
-
model=model,
|
| 777 |
-
args=args,
|
| 778 |
-
data_collator=data_collator,
|
| 779 |
-
train_dataset=train_dataset,
|
| 780 |
-
eval_dataset=eval_dataset,
|
| 781 |
-
processing_class=processing_class,
|
| 782 |
-
model_init=model_init,
|
| 783 |
-
compute_metrics=compute_metrics,
|
| 784 |
-
callbacks=callbacks,
|
| 785 |
-
optimizers=optimizers,
|
| 786 |
-
preprocess_logits_for_metrics=preprocess_logits_for_metrics,
|
| 787 |
-
)
|
| 788 |
-
|
| 789 |
-
# Gradient accumulation requires scaled loss. Normally, loss scaling in the parent class depends on whether the
|
| 790 |
-
# model accepts loss-related kwargs. Since we compute our own loss, this check is irrelevant. We set
|
| 791 |
-
# self.model_accepts_loss_kwargs to False to enable scaling.
|
| 792 |
-
self.model_accepts_loss_kwargs = False
|
| 793 |
-
|
| 794 |
-
# Add tags for models that have been loaded with the correct transformers version
|
| 795 |
-
if hasattr(self.model, "add_model_tags"):
|
| 796 |
-
self.model.add_model_tags(self._tag_names)
|
| 797 |
-
|
| 798 |
-
if not hasattr(self, "accelerator"):
|
| 799 |
-
raise AttributeError(
|
| 800 |
-
"Your `Trainer` does not have an `accelerator` object. Consider upgrading `transformers`."
|
| 801 |
-
)
|
| 802 |
-
|
| 803 |
-
# Deepspeed Zero-3 does not support precompute_ref_log_probs
|
| 804 |
-
if self.is_deepspeed_enabled:
|
| 805 |
-
if self.accelerator.state.deepspeed_plugin.zero_stage == 3 and self.precompute_ref_log_probs:
|
| 806 |
-
raise ValueError(
|
| 807 |
-
"You cannot use `precompute_ref_log_probs=True` with Deepspeed ZeRO-3. Please set `precompute_ref_log_probs=False`."
|
| 808 |
-
)
|
| 809 |
-
|
| 810 |
-
if self.ref_model is None:
|
| 811 |
-
if not (self.is_peft_model or self.precompute_ref_log_probs):
|
| 812 |
-
raise ValueError(
|
| 813 |
-
"No reference model and model is not a Peft model. Try setting `precompute_ref_log_probs=True`"
|
| 814 |
-
)
|
| 815 |
-
else:
|
| 816 |
-
if self.is_deepspeed_enabled:
|
| 817 |
-
self.ref_model = self._prepare_deepspeed(self.ref_model)
|
| 818 |
-
else:
|
| 819 |
-
self.ref_model = self.accelerator.prepare_model(self.ref_model, evaluation_mode=True)
|
| 820 |
-
|
| 821 |
-
self.running = RunningMoments(accelerator=self.accelerator)
|
| 822 |
-
|
| 823 |
-
if self.embedding_func is None:
|
| 824 |
-
return
|
| 825 |
-
|
| 826 |
-
chosen_embeddings = self._get_sample_prompt_embeddings(desirable, sample_size=self.args.prompt_sample_size)
|
| 827 |
-
rejected_embeddings = self._get_sample_prompt_embeddings(undesirable, sample_size=self.args.prompt_sample_size)
|
| 828 |
-
|
| 829 |
-
embeddings = torch.cat((chosen_embeddings, rejected_embeddings), dim=0)
|
| 830 |
-
labels = torch.cat(
|
| 831 |
-
(torch.ones_like(chosen_embeddings[:, 0]), torch.zeros_like(rejected_embeddings[:, 0])), dim=0
|
| 832 |
-
)
|
| 833 |
-
|
| 834 |
-
self.clf = LogisticRegression(class_weight="balanced").fit(
|
| 835 |
-
embeddings.cpu().float().numpy(), labels.cpu().numpy()
|
| 836 |
-
)
|
| 837 |
-
|
| 838 |
-
@property
|
| 839 |
-
def match_underlying_distribution(self):
|
| 840 |
-
return self.embedding_func is not None and self.embedding_tokenizer is not None
|
| 841 |
-
|
| 842 |
-
def _get_chosen_prob(self, prompt_embeddings: torch.FloatTensor) -> torch.FloatTensor:
|
| 843 |
-
"""
|
| 844 |
-
Calculates the probability if the given prompt embedding is from desirable dataset.
|
| 845 |
-
This function calculates the probability in the process and ensemble across processes.
|
| 846 |
-
"""
|
| 847 |
-
dtype = prompt_embeddings.dtype
|
| 848 |
-
device = prompt_embeddings.device
|
| 849 |
-
rank = self.accelerator.process_index
|
| 850 |
-
|
| 851 |
-
padded_prompt_embeddings = self.accelerator.pad_across_processes(
|
| 852 |
-
prompt_embeddings, pad_index=self.embedding_tokenizer.pad_token_id
|
| 853 |
-
)
|
| 854 |
-
sample_size = padded_prompt_embeddings.shape[0]
|
| 855 |
-
nonzero = padded_prompt_embeddings.mean(dim=1) != self.embedding_tokenizer.pad_token_id
|
| 856 |
-
prompt_embeddings = self.accelerator.gather(padded_prompt_embeddings)
|
| 857 |
-
|
| 858 |
-
# cannot predict for all empty values
|
| 859 |
-
if prompt_embeddings.shape[0] == 0:
|
| 860 |
-
return torch.tensor([], device=device, dtype=dtype)
|
| 861 |
-
|
| 862 |
-
prob = self.clf.predict_proba(prompt_embeddings.cpu().float().numpy())[:, 1]
|
| 863 |
-
prob = torch.as_tensor(prob, dtype=dtype, device=device)
|
| 864 |
-
prob = self.accelerator.reduce(prob, reduction="mean")
|
| 865 |
-
|
| 866 |
-
prob = prob[sample_size * rank : sample_size * (rank + 1)]
|
| 867 |
-
prob = prob[nonzero]
|
| 868 |
-
|
| 869 |
-
return prob
|
| 870 |
-
|
| 871 |
-
def _vectorize_prompt(self, input_ids: torch.LongTensor, attention_mask: torch.LongTensor) -> torch.FloatTensor:
|
| 872 |
-
"""
|
| 873 |
-
Replaces processing_class.pad_token_id to embedding_tokenizer.pad_token_id
|
| 874 |
-
and applies self.embedding_func
|
| 875 |
-
"""
|
| 876 |
-
input_ids = torch.where(
|
| 877 |
-
input_ids == self.processing_class.pad_token_id,
|
| 878 |
-
self.embedding_tokenizer.pad_token_id,
|
| 879 |
-
input_ids,
|
| 880 |
-
)
|
| 881 |
-
|
| 882 |
-
with torch.no_grad():
|
| 883 |
-
embeddings = self.embedding_func(
|
| 884 |
-
input_ids=input_ids,
|
| 885 |
-
attention_mask=attention_mask,
|
| 886 |
-
)
|
| 887 |
-
|
| 888 |
-
return embeddings
|
| 889 |
-
|
| 890 |
-
def _get_prompt_embeddings(
|
| 891 |
-
self, batch: dict[str, Union[list, torch.LongTensor]]
|
| 892 |
-
) -> tuple[torch.FloatTensor, torch.FloatTensor]:
|
| 893 |
-
"""Extract embeddings from frozen embedding model"""
|
| 894 |
-
|
| 895 |
-
if not self.match_underlying_distribution:
|
| 896 |
-
return None, None
|
| 897 |
-
|
| 898 |
-
embeddings = self._vectorize_prompt(
|
| 899 |
-
input_ids=batch["embedding_input_ids"],
|
| 900 |
-
attention_mask=batch["embedding_attention_mask"],
|
| 901 |
-
)
|
| 902 |
-
|
| 903 |
-
chosen_idx = [i for i in range(len(batch["label"])) if batch["label"][i] is True]
|
| 904 |
-
rejected_idx = [i for i in range(len(batch["label"])) if batch["label"][i] is False]
|
| 905 |
-
|
| 906 |
-
chosen_embeddings = embeddings[chosen_idx, ...]
|
| 907 |
-
rejected_embeddings = embeddings[rejected_idx, ...]
|
| 908 |
-
|
| 909 |
-
return (chosen_embeddings, rejected_embeddings)
|
| 910 |
-
|
| 911 |
-
def _get_sample_prompt_embeddings(self, dataset: Dataset, sample_size: int = 512) -> torch.FloatTensor:
|
| 912 |
-
"""
|
| 913 |
-
Sample instances from dataset and get prompt embeddings.
|
| 914 |
-
Used for density ratio classifier training.
|
| 915 |
-
"""
|
| 916 |
-
n_samples = min(len(dataset), sample_size)
|
| 917 |
-
rand_indices = np.random.choice(len(dataset), size=(n_samples,))
|
| 918 |
-
|
| 919 |
-
embedding_dataset = dataset.select(rand_indices)
|
| 920 |
-
|
| 921 |
-
dataloader_params = {
|
| 922 |
-
"batch_size": self.args.per_device_train_batch_size,
|
| 923 |
-
"collate_fn": self.data_collator,
|
| 924 |
-
"num_workers": self.args.dataloader_num_workers,
|
| 925 |
-
"pin_memory": self.args.dataloader_pin_memory,
|
| 926 |
-
"shuffle": False,
|
| 927 |
-
}
|
| 928 |
-
|
| 929 |
-
# prepare dataloader
|
| 930 |
-
data_loader = self.accelerator.prepare(DataLoader(embedding_dataset, **dataloader_params))
|
| 931 |
-
|
| 932 |
-
with torch.no_grad():
|
| 933 |
-
all_embeddings = torch.empty(0)
|
| 934 |
-
for padded_batch in tqdm(iterable=data_loader, desc="Building sample prompt embeddings"):
|
| 935 |
-
embeddings = self._vectorize_prompt(
|
| 936 |
-
input_ids=padded_batch["embedding_input_ids"],
|
| 937 |
-
attention_mask=padded_batch["embedding_attention_mask"],
|
| 938 |
-
)
|
| 939 |
-
embeddings = self.accelerator.gather_for_metrics(embeddings)
|
| 940 |
-
all_embeddings = torch.cat((all_embeddings, embeddings.cpu()))
|
| 941 |
-
|
| 942 |
-
return all_embeddings
|
| 943 |
-
|
| 944 |
-
def _prepare_deepspeed(self, model: PreTrainedModelWrapper):
|
| 945 |
-
# Adapted from accelerate: https://github.com/huggingface/accelerate/blob/739b135f8367becb67ffaada12fe76e3aa60fefd/src/accelerate/accelerator.py#L1473
|
| 946 |
-
deepspeed_plugin = self.accelerator.state.deepspeed_plugin
|
| 947 |
-
config_kwargs = deepcopy(deepspeed_plugin.deepspeed_config)
|
| 948 |
-
|
| 949 |
-
if model is not None:
|
| 950 |
-
if hasattr(model, "config"):
|
| 951 |
-
hidden_size = (
|
| 952 |
-
max(model.config.hidden_sizes)
|
| 953 |
-
if getattr(model.config, "hidden_sizes", None)
|
| 954 |
-
else getattr(model.config, "hidden_size", None)
|
| 955 |
-
)
|
| 956 |
-
if hidden_size is not None and config_kwargs["zero_optimization"]["stage"] == 3:
|
| 957 |
-
# Note that `stage3_prefetch_bucket_size` can produce DeepSpeed messages like: `Invalidate trace cache @ step 0: expected module 1, but got module 0`
|
| 958 |
-
# This is expected and is not an error, see: https://github.com/microsoft/DeepSpeed/discussions/4081
|
| 959 |
-
config_kwargs.update(
|
| 960 |
-
{
|
| 961 |
-
"zero_optimization.reduce_bucket_size": hidden_size * hidden_size,
|
| 962 |
-
"zero_optimization.stage3_param_persistence_threshold": 10 * hidden_size,
|
| 963 |
-
"zero_optimization.stage3_prefetch_bucket_size": 0.9 * hidden_size * hidden_size,
|
| 964 |
-
}
|
| 965 |
-
)
|
| 966 |
-
|
| 967 |
-
# If ZeRO-3 is used, we shard both the active and reference model.
|
| 968 |
-
# Otherwise, we assume the reference model fits in memory and is initialized on each device with ZeRO disabled (stage 0)
|
| 969 |
-
if config_kwargs["zero_optimization"]["stage"] != 3:
|
| 970 |
-
config_kwargs["zero_optimization"]["stage"] = 0
|
| 971 |
-
model, *_ = deepspeed.initialize(model=model, config=config_kwargs)
|
| 972 |
-
model.eval()
|
| 973 |
-
return model
|
| 974 |
-
|
| 975 |
-
def _save_optimizer_and_scheduler(self, output_dir):
|
| 976 |
-
super()._save_optimizer_and_scheduler(output_dir)
|
| 977 |
-
|
| 978 |
-
# When saving optimizer and scheduler to checkpoint, save also the running delta object.
|
| 979 |
-
output_dir = output_dir if output_dir is not None else self.args.output_dir
|
| 980 |
-
|
| 981 |
-
self.running.save_to_json(os.path.join(output_dir, RUNNING_NAME))
|
| 982 |
-
|
| 983 |
-
if self.match_underlying_distribution:
|
| 984 |
-
torch.save(self.clf.get_params(), os.path.join(output_dir, CLF_NAME))
|
| 985 |
-
|
| 986 |
-
def _load_optimizer_and_scheduler(self, checkpoint):
|
| 987 |
-
super()._load_optimizer_and_scheduler(checkpoint)
|
| 988 |
-
|
| 989 |
-
if checkpoint is None:
|
| 990 |
-
return
|
| 991 |
-
# when loading optimizer and scheduler from checkpoint, also load the running delta object.
|
| 992 |
-
running_file = os.path.join(checkpoint, RUNNING_NAME)
|
| 993 |
-
if os.path.isfile(running_file):
|
| 994 |
-
self.running = RunningMoments.load_from_json(self.accelerator, running_file)
|
| 995 |
-
|
| 996 |
-
if self.match_underlying_distribution:
|
| 997 |
-
clf_file = os.path.join(checkpoint, CLF_NAME)
|
| 998 |
-
if os.path.isfile(running_file):
|
| 999 |
-
self.clf.set_params(**torch.load(clf_file, weights_only=True, map_location="cpu"))
|
| 1000 |
-
|
| 1001 |
-
@contextmanager
|
| 1002 |
-
def null_ref_context(self):
|
| 1003 |
-
"""Context manager for handling null reference model (that is, peft adapter manipulation)."""
|
| 1004 |
-
with (
|
| 1005 |
-
self.accelerator.unwrap_model(self.model).disable_adapter()
|
| 1006 |
-
if self.is_peft_model and not self.ref_adapter_name
|
| 1007 |
-
else nullcontext()
|
| 1008 |
-
):
|
| 1009 |
-
if self.ref_adapter_name:
|
| 1010 |
-
self.model.set_adapter(self.ref_adapter_name)
|
| 1011 |
-
yield
|
| 1012 |
-
if self.ref_adapter_name:
|
| 1013 |
-
self.model.set_adapter(self.model_adapter_name or "default")
|
| 1014 |
-
|
| 1015 |
-
def get_train_dataloader(self) -> DataLoader:
|
| 1016 |
-
"""
|
| 1017 |
-
Returns the training [`~torch.utils.data.DataLoader`].
|
| 1018 |
-
|
| 1019 |
-
Subclass of transformers.src.transformers.trainer.get_train_dataloader to precompute `ref_log_probs`.
|
| 1020 |
-
"""
|
| 1021 |
-
|
| 1022 |
-
if self.precompute_ref_log_probs and not self._precomputed_train_ref_log_probs:
|
| 1023 |
-
dataloader_params = {
|
| 1024 |
-
"batch_size": self.args.per_device_train_batch_size,
|
| 1025 |
-
"collate_fn": self.data_collator,
|
| 1026 |
-
"num_workers": self.args.dataloader_num_workers,
|
| 1027 |
-
"pin_memory": self.args.dataloader_pin_memory,
|
| 1028 |
-
"shuffle": False,
|
| 1029 |
-
}
|
| 1030 |
-
|
| 1031 |
-
# prepare dataloader
|
| 1032 |
-
data_loader = self.accelerator.prepare(DataLoader(self.train_dataset, **dataloader_params))
|
| 1033 |
-
reference_completion_logps = []
|
| 1034 |
-
|
| 1035 |
-
for padded_batch in tqdm(iterable=data_loader, desc="Train dataset reference log probs"):
|
| 1036 |
-
reference_completion_logp = self.compute_reference_log_probs(padded_batch)
|
| 1037 |
-
|
| 1038 |
-
reference_completion_logp = self.accelerator.gather_for_metrics(reference_completion_logp)
|
| 1039 |
-
reference_completion_logps.append(reference_completion_logp.cpu())
|
| 1040 |
-
|
| 1041 |
-
self.train_dataset = self.train_dataset.add_column(
|
| 1042 |
-
name="reference_logps", column=torch.cat(reference_completion_logps).float().numpy()
|
| 1043 |
-
)
|
| 1044 |
-
|
| 1045 |
-
self._precomputed_train_ref_log_probs = True
|
| 1046 |
-
|
| 1047 |
-
return super().get_train_dataloader()
|
| 1048 |
-
|
| 1049 |
-
def get_eval_dataloader(self, eval_dataset: Optional[Dataset] = None) -> DataLoader:
|
| 1050 |
-
"""
|
| 1051 |
-
Returns the evaluation [`~torch.utils.data.DataLoader`].
|
| 1052 |
-
|
| 1053 |
-
Subclass of transformers.src.transformers.trainer.get_eval_dataloader to precompute `ref_log_probs`.
|
| 1054 |
-
|
| 1055 |
-
Args:
|
| 1056 |
-
eval_dataset (`torch.utils.data.Dataset`, *optional*):
|
| 1057 |
-
If provided, will override `self.eval_dataset`. If it is a [`~datasets.Dataset`], columns not accepted
|
| 1058 |
-
by the `model.forward()` method are automatically removed. It must implement `__len__`.
|
| 1059 |
-
"""
|
| 1060 |
-
if eval_dataset is None and self.eval_dataset is None:
|
| 1061 |
-
raise ValueError("Trainer: evaluation requires an eval_dataset.")
|
| 1062 |
-
eval_dataset = eval_dataset if eval_dataset is not None else self.eval_dataset
|
| 1063 |
-
|
| 1064 |
-
if self.precompute_ref_log_probs and not self._precomputed_eval_ref_log_probs:
|
| 1065 |
-
dataloader_params = {
|
| 1066 |
-
"batch_size": self.args.per_device_eval_batch_size,
|
| 1067 |
-
"collate_fn": self.data_collator,
|
| 1068 |
-
"num_workers": self.args.dataloader_num_workers,
|
| 1069 |
-
"pin_memory": self.args.dataloader_pin_memory,
|
| 1070 |
-
"shuffle": False,
|
| 1071 |
-
}
|
| 1072 |
-
|
| 1073 |
-
# prepare dataloader
|
| 1074 |
-
data_loader = self.accelerator.prepare(DataLoader(eval_dataset, **dataloader_params))
|
| 1075 |
-
|
| 1076 |
-
reference_completion_logps = []
|
| 1077 |
-
|
| 1078 |
-
for padded_batch in tqdm(iterable=data_loader, desc="Eval dataset reference log probs"):
|
| 1079 |
-
reference_completion_logp = self.compute_reference_log_probs(padded_batch)
|
| 1080 |
-
|
| 1081 |
-
reference_completion_logp = self.accelerator.gather_for_metrics(reference_completion_logp)
|
| 1082 |
-
reference_completion_logps.append(reference_completion_logp.cpu())
|
| 1083 |
-
|
| 1084 |
-
eval_dataset = eval_dataset.add_column(
|
| 1085 |
-
name="reference_logps", column=torch.cat(reference_completion_logps).float().numpy()
|
| 1086 |
-
)
|
| 1087 |
-
|
| 1088 |
-
# Save calculated reference_chosen_logps and reference_rejected_logps to the eval_dataset for subsequent runs
|
| 1089 |
-
if self.eval_dataset is not None:
|
| 1090 |
-
self.eval_dataset = eval_dataset
|
| 1091 |
-
self._precomputed_eval_ref_log_probs = True
|
| 1092 |
-
|
| 1093 |
-
return super().get_eval_dataloader(eval_dataset=eval_dataset)
|
| 1094 |
-
|
| 1095 |
-
def compute_reference_log_probs(self, padded_batch: dict) -> dict:
|
| 1096 |
-
"""Computes log probabilities of the reference model for a single padded batch of a BCO specific dataset."""
|
| 1097 |
-
with torch.no_grad():
|
| 1098 |
-
if self.ref_model is None:
|
| 1099 |
-
with self.null_ref_context():
|
| 1100 |
-
if self.is_encoder_decoder:
|
| 1101 |
-
completion_logits = self.model(
|
| 1102 |
-
padded_batch["prompt_input_ids"],
|
| 1103 |
-
attention_mask=padded_batch["prompt_attention_mask"],
|
| 1104 |
-
decoder_input_ids=padded_batch.get("completion_decoder_input_ids"),
|
| 1105 |
-
labels=padded_batch["completion_labels"],
|
| 1106 |
-
).logits
|
| 1107 |
-
|
| 1108 |
-
else:
|
| 1109 |
-
completion_logits = self.model(
|
| 1110 |
-
padded_batch["completion_input_ids"],
|
| 1111 |
-
attention_mask=padded_batch["completion_attention_mask"],
|
| 1112 |
-
).logits
|
| 1113 |
-
|
| 1114 |
-
else:
|
| 1115 |
-
if self.is_encoder_decoder:
|
| 1116 |
-
completion_logits = self.ref_model(
|
| 1117 |
-
padded_batch["prompt_input_ids"],
|
| 1118 |
-
attention_mask=padded_batch["prompt_attention_mask"],
|
| 1119 |
-
decoder_input_ids=padded_batch.get("completion_decoder_input_ids"),
|
| 1120 |
-
labels=padded_batch["completion_labels"],
|
| 1121 |
-
).logits
|
| 1122 |
-
|
| 1123 |
-
else:
|
| 1124 |
-
completion_logits = self.ref_model(
|
| 1125 |
-
padded_batch["completion_input_ids"], attention_mask=padded_batch["completion_attention_mask"]
|
| 1126 |
-
).logits
|
| 1127 |
-
|
| 1128 |
-
completion_logps = self.get_batch_logps(
|
| 1129 |
-
completion_logits,
|
| 1130 |
-
padded_batch["completion_labels"],
|
| 1131 |
-
average_log_prob=False,
|
| 1132 |
-
is_encoder_decoder=self.is_encoder_decoder,
|
| 1133 |
-
label_pad_token_id=self.label_pad_token_id,
|
| 1134 |
-
)
|
| 1135 |
-
|
| 1136 |
-
return completion_logps
|
| 1137 |
-
|
| 1138 |
-
@staticmethod
|
| 1139 |
-
def get_batch_logps(
|
| 1140 |
-
logits: torch.FloatTensor,
|
| 1141 |
-
labels: torch.LongTensor,
|
| 1142 |
-
average_log_prob: bool = False,
|
| 1143 |
-
label_pad_token_id: int = -100,
|
| 1144 |
-
is_encoder_decoder: bool = False,
|
| 1145 |
-
) -> torch.FloatTensor:
|
| 1146 |
-
"""Compute the log probabilities of the given labels under the given logits.
|
| 1147 |
-
|
| 1148 |
-
Args:
|
| 1149 |
-
logits: Logits of the model (unnormalized). Shape: (batch_size, sequence_length, vocab_size)
|
| 1150 |
-
labels: Labels for which to compute the log probabilities. Label tokens with a value of label_pad_token_id are ignored. Shape: (batch_size, sequence_length)
|
| 1151 |
-
average_log_prob: If True, return the average log probability per (non-masked) token. Otherwise, return the sum of the log probabilities of the (non-masked) tokens.
|
| 1152 |
-
|
| 1153 |
-
Returns:
|
| 1154 |
-
A tensor of shape (batch_size,) containing the average/sum log probabilities of the given labels under the given logits.
|
| 1155 |
-
"""
|
| 1156 |
-
if logits.shape[:-1] != labels.shape:
|
| 1157 |
-
raise ValueError("Logits (batch and sequence length dim) and labels must have the same shape.")
|
| 1158 |
-
|
| 1159 |
-
if not is_encoder_decoder:
|
| 1160 |
-
labels = labels[:, 1:].clone()
|
| 1161 |
-
logits = logits[:, :-1, :]
|
| 1162 |
-
else:
|
| 1163 |
-
# Fixes end-dec RuntimeError
|
| 1164 |
-
labels = labels.clone()
|
| 1165 |
-
|
| 1166 |
-
loss_mask = labels != label_pad_token_id
|
| 1167 |
-
|
| 1168 |
-
# dummy token; we'll ignore the losses on these tokens later
|
| 1169 |
-
labels[labels == label_pad_token_id] = 0
|
| 1170 |
-
|
| 1171 |
-
per_token_logps = selective_log_softmax(logits, labels)
|
| 1172 |
-
|
| 1173 |
-
if average_log_prob:
|
| 1174 |
-
return (per_token_logps * loss_mask).sum(-1) / loss_mask.sum(-1)
|
| 1175 |
-
else:
|
| 1176 |
-
return (per_token_logps * loss_mask).sum(-1)
|
| 1177 |
-
|
| 1178 |
-
def forward(
|
| 1179 |
-
self, model: nn.Module, batch: dict[str, Union[list, torch.LongTensor]]
|
| 1180 |
-
) -> tuple[torch.FloatTensor, torch.FloatTensor, torch.FloatTensor, torch.FloatTensor]:
|
| 1181 |
-
model_kwargs = (
|
| 1182 |
-
{
|
| 1183 |
-
"labels": batch["completion_labels"],
|
| 1184 |
-
"decoder_input_ids": batch.get("completion_decoder_input_ids"),
|
| 1185 |
-
}
|
| 1186 |
-
if self.is_encoder_decoder
|
| 1187 |
-
else {}
|
| 1188 |
-
)
|
| 1189 |
-
if self.aux_loss_enabled:
|
| 1190 |
-
model_kwargs["output_router_logits"] = True
|
| 1191 |
-
|
| 1192 |
-
outputs = model(
|
| 1193 |
-
batch["completion_input_ids"],
|
| 1194 |
-
attention_mask=batch["completion_attention_mask"],
|
| 1195 |
-
**model_kwargs,
|
| 1196 |
-
)
|
| 1197 |
-
completion_logits = outputs.logits
|
| 1198 |
-
|
| 1199 |
-
completion_logps = self.get_batch_logps(
|
| 1200 |
-
completion_logits,
|
| 1201 |
-
batch["completion_labels"],
|
| 1202 |
-
average_log_prob=False,
|
| 1203 |
-
is_encoder_decoder=self.is_encoder_decoder,
|
| 1204 |
-
label_pad_token_id=self.label_pad_token_id,
|
| 1205 |
-
)
|
| 1206 |
-
|
| 1207 |
-
if completion_logps.shape[0] != len(batch["label"]):
|
| 1208 |
-
raise ValueError(
|
| 1209 |
-
"There is a mismatch between the number of examples in this batch and the number of "
|
| 1210 |
-
"examples for which an output sequence was predicted."
|
| 1211 |
-
)
|
| 1212 |
-
|
| 1213 |
-
chosen_idx = [i for i in range(completion_logps.shape[0]) if batch["label"][i] is True]
|
| 1214 |
-
rejected_idx = [i for i in range(completion_logps.shape[0]) if batch["label"][i] is False]
|
| 1215 |
-
|
| 1216 |
-
chosen_logps = completion_logps[chosen_idx, ...]
|
| 1217 |
-
rejected_logps = completion_logps[rejected_idx, ...]
|
| 1218 |
-
|
| 1219 |
-
chosen_logits = completion_logits[chosen_idx, ...]
|
| 1220 |
-
rejected_logits = completion_logits[rejected_idx, ...]
|
| 1221 |
-
|
| 1222 |
-
if self.aux_loss_enabled:
|
| 1223 |
-
return (chosen_logps, rejected_logps, chosen_logits, rejected_logits, outputs.aux_loss)
|
| 1224 |
-
else:
|
| 1225 |
-
return (chosen_logps, rejected_logps, chosen_logits, rejected_logits)
|
| 1226 |
-
|
| 1227 |
-
def _get_udm_weight(self, rejected_embeddings: torch.FloatTensor) -> torch.FloatTensor:
|
| 1228 |
-
prob_desirable = self._get_chosen_prob(rejected_embeddings)
|
| 1229 |
-
min_ratio = self.args.min_density_ratio
|
| 1230 |
-
max_ratio = self.args.max_density_ratio
|
| 1231 |
-
|
| 1232 |
-
weight = (prob_desirable / (1 - prob_desirable + 1e-8)).clamp(min=min_ratio, max=max_ratio)
|
| 1233 |
-
|
| 1234 |
-
return weight
|
| 1235 |
-
|
| 1236 |
-
def bco_loss(
|
| 1237 |
-
self,
|
| 1238 |
-
policy_chosen_logps: torch.FloatTensor,
|
| 1239 |
-
policy_rejected_logps: torch.FloatTensor,
|
| 1240 |
-
reference_chosen_logps: torch.FloatTensor,
|
| 1241 |
-
reference_rejected_logps: torch.FloatTensor,
|
| 1242 |
-
chosen_embeddings: Optional[torch.FloatTensor],
|
| 1243 |
-
rejected_embeddings: Optional[torch.FloatTensor],
|
| 1244 |
-
) -> tuple[torch.FloatTensor, torch.FloatTensor, torch.FloatTensor, torch.FloatTensor]:
|
| 1245 |
-
"""Compute the BCO loss for a batch of policy and reference model log probabilities.
|
| 1246 |
-
|
| 1247 |
-
Args:
|
| 1248 |
-
policy_chosen_logps: Log probabilities of the policy model for the chosen responses. Shape: (num(chosen) in batch_size,)
|
| 1249 |
-
policy_rejected_logps: Log probabilities of the policy model for the rejected responses. Shape: (num(rejected) in batch_size,)
|
| 1250 |
-
reference_chosen_logps: Log probabilities of the reference model for the chosen responses. Shape: (num(chosen) in batch_size,)
|
| 1251 |
-
reference_rejected_logps: Log probabilities of the reference model for the rejected responses. Shape: (num(rejected) in batch_size,)
|
| 1252 |
-
chosen_embeddings: embeddings of desirable prompts
|
| 1253 |
-
rejected_embeddings: embeddings of undesirable prompts
|
| 1254 |
-
|
| 1255 |
-
Returns:
|
| 1256 |
-
A tuple of four tensors: (losses, chosen_rewards, rejected_rewards, delta).
|
| 1257 |
-
The losses tensor contains the BCO loss for each example in the batch.
|
| 1258 |
-
The chosen_rewards and rejected_rewards tensors contain the rewards for the chosen and rejected responses, respectively.
|
| 1259 |
-
The delta value contains the moving average of all implicit rewards.
|
| 1260 |
-
"""
|
| 1261 |
-
|
| 1262 |
-
if policy_chosen_logps.shape[0] != 0 or reference_chosen_logps.shape[0] != 0:
|
| 1263 |
-
chosen_logratios = policy_chosen_logps - reference_chosen_logps
|
| 1264 |
-
chosen_rewards = self.beta * chosen_logratios
|
| 1265 |
-
else:
|
| 1266 |
-
# lists can't be empty -- if they are, then accelerate.gather will hang
|
| 1267 |
-
chosen_losses = torch.Tensor([]).to(self.accelerator.device)
|
| 1268 |
-
chosen_rewards = torch.Tensor([]).to(self.accelerator.device)
|
| 1269 |
-
|
| 1270 |
-
if policy_rejected_logps.shape[0] != 0 or reference_rejected_logps.shape[0] != 0:
|
| 1271 |
-
rejected_logratios = policy_rejected_logps - reference_rejected_logps
|
| 1272 |
-
rejected_rewards = self.beta * rejected_logratios
|
| 1273 |
-
else:
|
| 1274 |
-
# lists can't be empty -- if they are, then accelerate.gather will hang
|
| 1275 |
-
rejected_losses = torch.Tensor([]).to(self.accelerator.device)
|
| 1276 |
-
rejected_rewards = torch.Tensor([]).to(self.accelerator.device)
|
| 1277 |
-
|
| 1278 |
-
rewards = torch.cat((chosen_rewards, rejected_rewards), 0).mean().detach()
|
| 1279 |
-
self.running.update(rewards)
|
| 1280 |
-
delta = self.running.mean
|
| 1281 |
-
|
| 1282 |
-
if policy_chosen_logps.shape[0] != 0 or reference_chosen_logps.shape[0] != 0:
|
| 1283 |
-
chosen_losses = -F.logsigmoid(chosen_rewards - delta)
|
| 1284 |
-
|
| 1285 |
-
if policy_rejected_logps.shape[0] != 0 or reference_rejected_logps.shape[0] != 0:
|
| 1286 |
-
rejected_losses = -F.logsigmoid(-(rejected_rewards - delta))
|
| 1287 |
-
|
| 1288 |
-
if self.match_underlying_distribution:
|
| 1289 |
-
chosen_weight = torch.ones_like(chosen_losses)
|
| 1290 |
-
rejected_weight = self._get_udm_weight(rejected_embeddings)
|
| 1291 |
-
|
| 1292 |
-
losses = torch.cat((chosen_weight * chosen_losses, rejected_weight * rejected_losses), dim=0)
|
| 1293 |
-
else:
|
| 1294 |
-
losses = torch.cat((chosen_losses, rejected_losses), dim=0)
|
| 1295 |
-
|
| 1296 |
-
return losses, chosen_rewards, rejected_rewards, torch.as_tensor(delta)
|
| 1297 |
-
|
| 1298 |
-
def get_batch_loss_metrics(
|
| 1299 |
-
self,
|
| 1300 |
-
model,
|
| 1301 |
-
batch: dict[str, Union[list, torch.LongTensor]],
|
| 1302 |
-
):
|
| 1303 |
-
"""Compute the BCO loss and other metrics for the given batch of inputs for train or test."""
|
| 1304 |
-
metrics = {}
|
| 1305 |
-
batch = {k: (v.to(self.accelerator.device) if isinstance(v, torch.Tensor) else v) for k, v in batch.items()}
|
| 1306 |
-
|
| 1307 |
-
forward_output = self.forward(model, batch)
|
| 1308 |
-
(
|
| 1309 |
-
policy_chosen_logps,
|
| 1310 |
-
policy_rejected_logps,
|
| 1311 |
-
policy_chosen_logits,
|
| 1312 |
-
policy_rejected_logits,
|
| 1313 |
-
) = forward_output[:4]
|
| 1314 |
-
if self.aux_loss_enabled:
|
| 1315 |
-
aux_loss = forward_output[4]
|
| 1316 |
-
|
| 1317 |
-
# if reference_logps in batch use them, otherwise use the reference model
|
| 1318 |
-
if "reference_logps" in batch:
|
| 1319 |
-
chosen_idx = [i for i in range(batch["reference_logps"].shape[0]) if batch["label"][i] is True]
|
| 1320 |
-
rejected_idx = [i for i in range(batch["reference_logps"].shape[0]) if batch["label"][i] is False]
|
| 1321 |
-
|
| 1322 |
-
reference_chosen_logps = batch["reference_logps"][chosen_idx, ...]
|
| 1323 |
-
reference_rejected_logps = batch["reference_logps"][rejected_idx, ...]
|
| 1324 |
-
else:
|
| 1325 |
-
with torch.no_grad():
|
| 1326 |
-
if self.ref_model is None:
|
| 1327 |
-
with self.null_ref_context():
|
| 1328 |
-
(
|
| 1329 |
-
reference_chosen_logps,
|
| 1330 |
-
reference_rejected_logps,
|
| 1331 |
-
_,
|
| 1332 |
-
_,
|
| 1333 |
-
) = self.forward(self.model, batch)[:4]
|
| 1334 |
-
else:
|
| 1335 |
-
(
|
| 1336 |
-
reference_chosen_logps,
|
| 1337 |
-
reference_rejected_logps,
|
| 1338 |
-
_,
|
| 1339 |
-
_,
|
| 1340 |
-
) = self.forward(self.ref_model, batch)[:4]
|
| 1341 |
-
|
| 1342 |
-
chosen_embeddings, rejected_embeddings = self._get_prompt_embeddings(batch)
|
| 1343 |
-
|
| 1344 |
-
losses, chosen_rewards, rejected_rewards, delta = self.bco_loss(
|
| 1345 |
-
policy_chosen_logps,
|
| 1346 |
-
policy_rejected_logps,
|
| 1347 |
-
reference_chosen_logps,
|
| 1348 |
-
reference_rejected_logps,
|
| 1349 |
-
chosen_embeddings,
|
| 1350 |
-
rejected_embeddings,
|
| 1351 |
-
)
|
| 1352 |
-
metrics["delta"] = self.accelerator.gather_for_metrics(delta).mean().item()
|
| 1353 |
-
|
| 1354 |
-
num_chosen = torch.Tensor([len(chosen_rewards)]).to(self.accelerator.device)
|
| 1355 |
-
num_rejected = torch.Tensor([len(rejected_rewards)]).to(self.accelerator.device)
|
| 1356 |
-
|
| 1357 |
-
all_num_chosen = self.accelerator.gather_for_metrics(num_chosen).sum().item()
|
| 1358 |
-
all_num_rejected = self.accelerator.gather_for_metrics(num_rejected).sum().item()
|
| 1359 |
-
|
| 1360 |
-
if all_num_chosen > 0:
|
| 1361 |
-
metrics["rewards/chosen_sum"] = (
|
| 1362 |
-
self.accelerator.gather_for_metrics(chosen_rewards.nansum()).nansum().item()
|
| 1363 |
-
)
|
| 1364 |
-
metrics["logps/chosen_sum"] = (
|
| 1365 |
-
self.accelerator.gather_for_metrics(policy_chosen_logps.nansum()).nansum().item()
|
| 1366 |
-
)
|
| 1367 |
-
metrics["logits/chosen_sum"] = (
|
| 1368 |
-
self.accelerator.gather_for_metrics(policy_chosen_logits.nansum()).nansum().item()
|
| 1369 |
-
)
|
| 1370 |
-
metrics["count/chosen"] = all_num_chosen
|
| 1371 |
-
|
| 1372 |
-
if all_num_rejected > 0:
|
| 1373 |
-
metrics["rewards/rejected_sum"] = (
|
| 1374 |
-
self.accelerator.gather_for_metrics(rejected_rewards.nansum()).nansum().item()
|
| 1375 |
-
)
|
| 1376 |
-
metrics["logps/rejected_sum"] = (
|
| 1377 |
-
self.accelerator.gather_for_metrics(policy_rejected_logps.nansum()).nansum().item()
|
| 1378 |
-
)
|
| 1379 |
-
metrics["logits/rejected_sum"] = (
|
| 1380 |
-
self.accelerator.gather_for_metrics(policy_rejected_logits.nansum()).nansum().item()
|
| 1381 |
-
)
|
| 1382 |
-
metrics["count/rejected"] = all_num_rejected
|
| 1383 |
-
|
| 1384 |
-
loss = losses.nanmean()
|
| 1385 |
-
if self.aux_loss_enabled:
|
| 1386 |
-
loss += self.aux_loss_coef * aux_loss
|
| 1387 |
-
|
| 1388 |
-
return loss, metrics
|
| 1389 |
-
|
| 1390 |
-
def compute_loss(
|
| 1391 |
-
self,
|
| 1392 |
-
model: Union[PreTrainedModel, nn.Module],
|
| 1393 |
-
inputs: dict[str, Union[torch.Tensor, Any]],
|
| 1394 |
-
return_outputs=False,
|
| 1395 |
-
num_items_in_batch=None,
|
| 1396 |
-
) -> Union[torch.Tensor, tuple[torch.Tensor, dict[str, torch.Tensor]]]:
|
| 1397 |
-
compute_loss_context_manager = amp.autocast("cuda") if self._peft_has_been_casted_to_bf16 else nullcontext()
|
| 1398 |
-
|
| 1399 |
-
with compute_loss_context_manager:
|
| 1400 |
-
loss, metrics = self.get_batch_loss_metrics(model, inputs)
|
| 1401 |
-
|
| 1402 |
-
# Make sure to move the loss to the device the original accumulating loss is at back in the `Trainer` class:
|
| 1403 |
-
loss = loss.to(self.args.device)
|
| 1404 |
-
# force log the metrics
|
| 1405 |
-
if self.accelerator.is_main_process:
|
| 1406 |
-
self.store_metrics(metrics, train_eval="train")
|
| 1407 |
-
|
| 1408 |
-
if return_outputs:
|
| 1409 |
-
return (loss, metrics)
|
| 1410 |
-
return loss
|
| 1411 |
-
|
| 1412 |
-
def store_metrics(self, metrics: dict[str, float], train_eval: Literal["train", "eval"] = "train") -> None:
|
| 1413 |
-
for key, value in metrics.items():
|
| 1414 |
-
self._stored_metrics[train_eval][key].append(value)
|
| 1415 |
-
|
| 1416 |
-
def _get_train_sampler(self) -> Optional[torch.utils.data.Sampler]:
|
| 1417 |
-
if self.train_dataset is None or not has_length(self.train_dataset):
|
| 1418 |
-
return None
|
| 1419 |
-
return SequentialSampler(self.train_dataset)
|
| 1420 |
-
|
| 1421 |
-
def generate_from_model_and_ref(self, model, batch: dict[str, torch.LongTensor]) -> tuple[str, str]:
|
| 1422 |
-
"""Generate samples from the model and reference model for the given batch of inputs."""
|
| 1423 |
-
|
| 1424 |
-
# If one uses `generate_during_eval` with peft + bf16, we need to explicitly call generate with
|
| 1425 |
-
# the torch cuda amp context manager as some hidden states are silently casted to full precision.
|
| 1426 |
-
generate_context_manager = amp.autocast("cuda") if self._peft_has_been_casted_to_bf16 else nullcontext()
|
| 1427 |
-
with generate_context_manager:
|
| 1428 |
-
policy_output = model.generate(
|
| 1429 |
-
input_ids=batch["prompt_input_ids"],
|
| 1430 |
-
attention_mask=batch["prompt_attention_mask"],
|
| 1431 |
-
max_length=self.max_length,
|
| 1432 |
-
do_sample=True,
|
| 1433 |
-
pad_token_id=self.processing_class.pad_token_id,
|
| 1434 |
-
)
|
| 1435 |
-
|
| 1436 |
-
# if reference_output in batch use that otherwise use the reference model
|
| 1437 |
-
if "reference_output" in batch:
|
| 1438 |
-
reference_output = batch["reference_output"]
|
| 1439 |
-
else:
|
| 1440 |
-
if self.ref_model is None:
|
| 1441 |
-
with self.null_ref_context():
|
| 1442 |
-
reference_output = self.model.generate(
|
| 1443 |
-
input_ids=batch["prompt_input_ids"],
|
| 1444 |
-
attention_mask=batch["prompt_attention_mask"],
|
| 1445 |
-
max_length=self.max_length,
|
| 1446 |
-
do_sample=True,
|
| 1447 |
-
pad_token_id=self.processing_class.pad_token_id,
|
| 1448 |
-
)
|
| 1449 |
-
else:
|
| 1450 |
-
reference_output = self.ref_model.generate(
|
| 1451 |
-
input_ids=batch["prompt_input_ids"],
|
| 1452 |
-
attention_mask=batch["prompt_attention_mask"],
|
| 1453 |
-
max_length=self.max_length,
|
| 1454 |
-
do_sample=True,
|
| 1455 |
-
pad_token_id=self.processing_class.pad_token_id,
|
| 1456 |
-
)
|
| 1457 |
-
|
| 1458 |
-
policy_output = pad_to_length(policy_output, self.max_length, self.processing_class.pad_token_id)
|
| 1459 |
-
policy_output_decoded = self.processing_class.batch_decode(policy_output, skip_special_tokens=True)
|
| 1460 |
-
|
| 1461 |
-
reference_output = pad_to_length(reference_output, self.max_length, self.processing_class.pad_token_id)
|
| 1462 |
-
reference_output_decoded = self.processing_class.batch_decode(reference_output, skip_special_tokens=True)
|
| 1463 |
-
|
| 1464 |
-
return policy_output_decoded, reference_output_decoded
|
| 1465 |
-
|
| 1466 |
-
def prediction_step(
|
| 1467 |
-
self,
|
| 1468 |
-
model: Union[PreTrainedModel, nn.Module],
|
| 1469 |
-
inputs: dict[str, Union[torch.Tensor, Any]],
|
| 1470 |
-
prediction_loss_only: bool,
|
| 1471 |
-
ignore_keys: Optional[list[str]] = None,
|
| 1472 |
-
):
|
| 1473 |
-
if ignore_keys is None:
|
| 1474 |
-
if hasattr(model, "config"):
|
| 1475 |
-
ignore_keys = getattr(model.config, "keys_to_ignore_at_inference", [])
|
| 1476 |
-
else:
|
| 1477 |
-
ignore_keys = []
|
| 1478 |
-
|
| 1479 |
-
prediction_context_manager = amp.autocast("cuda") if self._peft_has_been_casted_to_bf16 else nullcontext()
|
| 1480 |
-
with torch.no_grad(), prediction_context_manager:
|
| 1481 |
-
loss, metrics = self.get_batch_loss_metrics(model, inputs)
|
| 1482 |
-
|
| 1483 |
-
# force log the metrics
|
| 1484 |
-
if self.accelerator.is_main_process:
|
| 1485 |
-
self.store_metrics(metrics, train_eval="eval")
|
| 1486 |
-
|
| 1487 |
-
if prediction_loss_only:
|
| 1488 |
-
return (loss.detach(), None, None)
|
| 1489 |
-
|
| 1490 |
-
# logits for the chosen and rejected samples from model
|
| 1491 |
-
logits_dict = {}
|
| 1492 |
-
if "logits/chosen_sum" in metrics:
|
| 1493 |
-
logits_dict["eval_logits/chosen"] = metrics["logits/chosen_sum"]
|
| 1494 |
-
if "logits/rejected_sum" in metrics:
|
| 1495 |
-
logits_dict["eval_logits/rejected"] = metrics["logits/rejected_sum"]
|
| 1496 |
-
logits = [v for k, v in logits_dict.items() if k not in ignore_keys]
|
| 1497 |
-
logits = torch.tensor(logits, device=self.accelerator.device)
|
| 1498 |
-
labels = torch.zeros(logits.shape[0], device=self.accelerator.device)
|
| 1499 |
-
|
| 1500 |
-
return (loss.detach(), logits, labels)
|
| 1501 |
-
|
| 1502 |
-
def evaluation_loop(
|
| 1503 |
-
self,
|
| 1504 |
-
dataloader: DataLoader,
|
| 1505 |
-
description: str,
|
| 1506 |
-
prediction_loss_only: Optional[bool] = None,
|
| 1507 |
-
ignore_keys: Optional[list[str]] = None,
|
| 1508 |
-
metric_key_prefix: str = "eval",
|
| 1509 |
-
) -> EvalLoopOutput:
|
| 1510 |
-
"""
|
| 1511 |
-
Overriding built-in evaluation loop to store metrics for each batch.
|
| 1512 |
-
Prediction/evaluation loop, shared by `Trainer.evaluate()` and `Trainer.predict()`.
|
| 1513 |
-
|
| 1514 |
-
Works both with or without labels.
|
| 1515 |
-
"""
|
| 1516 |
-
|
| 1517 |
-
# Sample and save to game log if requested (for one batch to save time)
|
| 1518 |
-
if self.generate_during_eval:
|
| 1519 |
-
# Generate random indices within the range of the total number of samples
|
| 1520 |
-
num_samples = len(dataloader.dataset)
|
| 1521 |
-
random_indices = random.sample(range(num_samples), k=self.args.eval_batch_size)
|
| 1522 |
-
|
| 1523 |
-
# Use dataloader.dataset.select to get the random batch without iterating over the DataLoader
|
| 1524 |
-
random_batch_dataset = dataloader.dataset.select(random_indices)
|
| 1525 |
-
random_batch = self.data_collator(random_batch_dataset)
|
| 1526 |
-
random_batch = self._prepare_inputs(random_batch)
|
| 1527 |
-
|
| 1528 |
-
target_indicies = [i for i in range(len(random_batch["label"])) if random_batch["label"][i] is False]
|
| 1529 |
-
target_batch = {
|
| 1530 |
-
"prompt_input_ids": random_batch["prompt_input_ids"][target_indicies],
|
| 1531 |
-
"prompt_attention_mask": random_batch["prompt_attention_mask"][target_indicies],
|
| 1532 |
-
"prompt": itemgetter(*target_indicies)(random_batch["prompt"]),
|
| 1533 |
-
}
|
| 1534 |
-
policy_output_decoded, ref_output_decoded = self.generate_from_model_and_ref(self.model, target_batch)
|
| 1535 |
-
|
| 1536 |
-
table = pd.DataFrame(
|
| 1537 |
-
columns=["Prompt", "Policy", "Ref Model"],
|
| 1538 |
-
data=[
|
| 1539 |
-
[prompt, pol[len(prompt) :], ref[len(prompt) :]]
|
| 1540 |
-
for prompt, pol, ref in zip(target_batch["prompt"], policy_output_decoded, ref_output_decoded)
|
| 1541 |
-
],
|
| 1542 |
-
)
|
| 1543 |
-
if "wandb" in self.args.report_to:
|
| 1544 |
-
wandb.log({"game_log": wandb.Table(data=table)})
|
| 1545 |
-
|
| 1546 |
-
if "comet_ml" in self.args.report_to:
|
| 1547 |
-
log_table_to_comet_experiment(
|
| 1548 |
-
name="game_log.csv",
|
| 1549 |
-
table=table,
|
| 1550 |
-
)
|
| 1551 |
-
|
| 1552 |
-
# Base evaluation
|
| 1553 |
-
initial_output = super().evaluation_loop(
|
| 1554 |
-
dataloader, description, prediction_loss_only, ignore_keys, metric_key_prefix
|
| 1555 |
-
)
|
| 1556 |
-
|
| 1557 |
-
return initial_output
|
| 1558 |
-
|
| 1559 |
-
def log(self, logs: dict[str, float], start_time: Optional[float] = None) -> None:
|
| 1560 |
-
"""
|
| 1561 |
-
Log `logs` on the various objects watching training, including stored metrics.
|
| 1562 |
-
|
| 1563 |
-
Args:
|
| 1564 |
-
logs (`dict[str, float]`):
|
| 1565 |
-
The values to log.
|
| 1566 |
-
start_time (`float` or `None`, *optional*, defaults to `None`):
|
| 1567 |
-
Start time of the training.
|
| 1568 |
-
"""
|
| 1569 |
-
# logs either has 'loss' or 'eval_loss'
|
| 1570 |
-
train_eval = "train" if "loss" in logs else "eval"
|
| 1571 |
-
# train metrics should have no prefix, eval should have 'eval_'
|
| 1572 |
-
prefix = "eval_" if train_eval == "eval" else ""
|
| 1573 |
-
# accumulate average metrics from sums and lengths
|
| 1574 |
-
for split in ["chosen", "rejected"]:
|
| 1575 |
-
if f"count/{split}" in self._stored_metrics[train_eval]:
|
| 1576 |
-
count_sum = torch.Tensor(self._stored_metrics[train_eval][f"count/{split}"]).sum().item()
|
| 1577 |
-
for metric in ["rewards", "logps", "logits"]:
|
| 1578 |
-
logs[f"{prefix}{metric}/{split}"] = (
|
| 1579 |
-
torch.Tensor(self._stored_metrics[train_eval][f"{metric}/{split}_sum"]).sum().item()
|
| 1580 |
-
/ count_sum
|
| 1581 |
-
)
|
| 1582 |
-
# delete obsolete metric
|
| 1583 |
-
del self._stored_metrics[train_eval][f"{metric}/{split}_sum"]
|
| 1584 |
-
del self._stored_metrics[train_eval][f"count/{split}"]
|
| 1585 |
-
# calculate reward margin
|
| 1586 |
-
if f"{prefix}rewards/chosen" in logs and f"{prefix}rewards/rejected" in logs:
|
| 1587 |
-
logs[f"{prefix}rewards/margins"] = logs[f"{prefix}rewards/chosen"] - logs[f"{prefix}rewards/rejected"]
|
| 1588 |
-
# Add averaged stored metrics to logs
|
| 1589 |
-
for key, metrics in self._stored_metrics[train_eval].items():
|
| 1590 |
-
logs[f"{prefix}{key}"] = torch.Tensor(metrics).mean().item()
|
| 1591 |
-
del self._stored_metrics[train_eval]
|
| 1592 |
-
|
| 1593 |
-
if version.parse(transformers.__version__) >= version.parse("4.47.0.dev0"):
|
| 1594 |
-
return super().log(logs, start_time)
|
| 1595 |
-
else: # transformers<=4.46
|
| 1596 |
-
return super().log(logs)
|
| 1597 |
-
|
| 1598 |
-
def create_model_card(
|
| 1599 |
-
self,
|
| 1600 |
-
model_name: Optional[str] = None,
|
| 1601 |
-
dataset_name: Optional[str] = None,
|
| 1602 |
-
tags: Union[str, list[str], None] = None,
|
| 1603 |
-
):
|
| 1604 |
-
"""
|
| 1605 |
-
Creates a draft of a model card using the information available to the `Trainer`.
|
| 1606 |
-
|
| 1607 |
-
Args:
|
| 1608 |
-
model_name (`str` or `None`, *optional*, defaults to `None`):
|
| 1609 |
-
Name of the model.
|
| 1610 |
-
dataset_name (`str` or `None`, *optional*, defaults to `None`):
|
| 1611 |
-
Name of the dataset used for training.
|
| 1612 |
-
tags (`str`, `list[str]` or `None`, *optional*, defaults to `None`):
|
| 1613 |
-
Tags to be associated with the model card.
|
| 1614 |
-
"""
|
| 1615 |
-
if not self.is_world_process_zero():
|
| 1616 |
-
return
|
| 1617 |
-
|
| 1618 |
-
if hasattr(self.model.config, "_name_or_path") and not os.path.isdir(self.model.config._name_or_path):
|
| 1619 |
-
base_model = self.model.config._name_or_path
|
| 1620 |
-
else:
|
| 1621 |
-
base_model = None
|
| 1622 |
-
|
| 1623 |
-
tags = tags or []
|
| 1624 |
-
if isinstance(tags, str):
|
| 1625 |
-
tags = [tags]
|
| 1626 |
-
|
| 1627 |
-
if hasattr(self.model.config, "unsloth_version"):
|
| 1628 |
-
tags.append("unsloth")
|
| 1629 |
-
|
| 1630 |
-
citation = textwrap.dedent("""\
|
| 1631 |
-
@article{jung2024binary,
|
| 1632 |
-
title = {{Binary Classifier Optimization for Large Language Model Alignment}},
|
| 1633 |
-
author = {Seungjae Jung and Gunsoo Han and Daniel Wontae Nam and Kyoung{-}Woon On},
|
| 1634 |
-
year = 2024,
|
| 1635 |
-
eprint = {arXiv:2404.04656}
|
| 1636 |
-
}""")
|
| 1637 |
-
|
| 1638 |
-
model_card = generate_model_card(
|
| 1639 |
-
base_model=base_model,
|
| 1640 |
-
model_name=model_name,
|
| 1641 |
-
hub_model_id=self.hub_model_id,
|
| 1642 |
-
dataset_name=dataset_name,
|
| 1643 |
-
tags=tags,
|
| 1644 |
-
wandb_url=wandb.run.get_url() if is_wandb_available() and wandb.run is not None else None,
|
| 1645 |
-
comet_url=get_comet_experiment_url(),
|
| 1646 |
-
trainer_name="BCO",
|
| 1647 |
-
trainer_citation=citation,
|
| 1648 |
-
paper_title="Binary Classifier Optimization for Large Language Model Alignment",
|
| 1649 |
-
paper_id="2404.04656",
|
| 1650 |
-
)
|
| 1651 |
-
|
| 1652 |
-
model_card.save(os.path.join(self.args.output_dir, "README.md"))
|
| 1653 |
-
class UnslothBCOTrainer(_UnslothBCOTrainer):
|
| 1654 |
-
"""
|
| 1655 |
-
|
| 1656 |
-
Initialize BCOTrainer from [BCO](https://huggingface.co/papers/2404.04656) paper.
|
| 1657 |
-
|
| 1658 |
-
Args:
|
| 1659 |
-
model (`transformers.PreTrainedModel`):
|
| 1660 |
-
The model to train, preferably an `AutoModelForSequenceClassification`.
|
| 1661 |
-
ref_model (`PreTrainedModelWrapper`):
|
| 1662 |
-
Hugging Face transformer model with a casual language modelling head. Used for implicit reward computation and loss. If no
|
| 1663 |
-
reference model is provided, the trainer will create a reference model with the same architecture as the model to be optimized.
|
| 1664 |
-
args (`BCOConfig`):
|
| 1665 |
-
The arguments to use for training.
|
| 1666 |
-
train_dataset (`datasets.Dataset`):
|
| 1667 |
-
The dataset to use for training.
|
| 1668 |
-
eval_dataset (`datasets.Dataset`):
|
| 1669 |
-
The dataset to use for evaluation.
|
| 1670 |
-
processing_class (`PreTrainedTokenizerBase` or `BaseImageProcessor` or `FeatureExtractionMixin` or `ProcessorMixin`, *optional*):
|
| 1671 |
-
Processing class used to process the data. If provided, will be used to automatically process the inputs
|
| 1672 |
-
for the model, and it will be saved along the model to make it easier to rerun an interrupted training or
|
| 1673 |
-
reuse the fine-tuned model.
|
| 1674 |
-
data_collator (`transformers.DataCollator`, *optional*, defaults to `None`):
|
| 1675 |
-
The data collator to use for training. If None is specified, the default data collator (`DPODataCollatorWithPadding`) will be used
|
| 1676 |
-
which will pad the sequences to the maximum length of the sequences in the batch, given a dataset of paired sequences.
|
| 1677 |
-
model_init (`Callable[[], transformers.PreTrainedModel]`):
|
| 1678 |
-
The model initializer to use for training. If None is specified, the default model initializer will be used.
|
| 1679 |
-
callbacks (`list[transformers.TrainerCallback]`):
|
| 1680 |
-
The callbacks to use for training.
|
| 1681 |
-
optimizers (`tuple[torch.optim.Optimizer, torch.optim.lr_scheduler.LambdaLR]`):
|
| 1682 |
-
The optimizer and scheduler to use for training.
|
| 1683 |
-
preprocess_logits_for_metrics (`Callable[[torch.Tensor, torch.Tensor], torch.Tensor]`):
|
| 1684 |
-
The function to use to preprocess the logits before computing the metrics.
|
| 1685 |
-
peft_config (`dict`, defaults to `None`):
|
| 1686 |
-
The PEFT configuration to use for training. If you pass a PEFT configuration, the model will be wrapped in a PEFT model.
|
| 1687 |
-
compute_metrics (`Callable[[EvalPrediction], dict]`, *optional*):
|
| 1688 |
-
The function to use to compute the metrics. Must take a `EvalPrediction` and return
|
| 1689 |
-
a dictionary string to metric values.
|
| 1690 |
-
model_adapter_name (`str`, defaults to `None`):
|
| 1691 |
-
Name of the train target PEFT adapter, when using LoRA with multiple adapters.
|
| 1692 |
-
ref_adapter_name (`str`, defaults to `None`):
|
| 1693 |
-
Name of the reference PEFT adapter, when using LoRA with multiple adapters.
|
| 1694 |
-
|
| 1695 |
-
"""
|
| 1696 |
-
def __init__(
|
| 1697 |
-
self,
|
| 1698 |
-
model = None,
|
| 1699 |
-
ref_model = None,
|
| 1700 |
-
args = None,
|
| 1701 |
-
train_dataset = None,
|
| 1702 |
-
eval_dataset = None,
|
| 1703 |
-
processing_class = None,
|
| 1704 |
-
data_collator = None,
|
| 1705 |
-
model_init = None,
|
| 1706 |
-
callbacks = None,
|
| 1707 |
-
preprocess_logits_for_metrics = None,
|
| 1708 |
-
peft_config = None,
|
| 1709 |
-
compute_metrics = None,
|
| 1710 |
-
model_adapter_name = None,
|
| 1711 |
-
ref_adapter_name = None,
|
| 1712 |
-
embedding_func = None,
|
| 1713 |
-
embedding_tokenizer = None,
|
| 1714 |
-
**kwargs
|
| 1715 |
-
):
|
| 1716 |
-
if args is None: args = UnslothBCOConfig()
|
| 1717 |
-
use_bf16 = getattr(args, 'bf16', False)
|
| 1718 |
-
if type(use_bf16) is not bool: use_bf16 = False
|
| 1719 |
-
use_fp16 = getattr(args, 'fp16', False)
|
| 1720 |
-
if type(use_fp16) is not bool: use_fp16 = False
|
| 1721 |
-
force_float32 = False
|
| 1722 |
-
if os.environ.get('UNSLOTH_FORCE_FLOAT32', '0') == '1':
|
| 1723 |
-
print('Unsloth: Switching to float32 training since model cannot work with float16')
|
| 1724 |
-
force_float32 = True
|
| 1725 |
-
mixed_precision_dtype = os.environ.get('UNSLOTH_MIXED_PRECISION', 'float32')
|
| 1726 |
-
dtype = getattr(model.config, 'torch_dtype', None)
|
| 1727 |
-
if dtype is None: dtype = model.get_input_embeddings().dtype
|
| 1728 |
-
from unsloth_zoo.utils import _get_dtype
|
| 1729 |
-
dtype = _get_dtype(dtype)
|
| 1730 |
-
float16 = dtype == torch.float16
|
| 1731 |
-
if not force_float32 and (float16 and use_bf16): raise TypeError('Unsloth: Model is in float16 precision but you want to use bfloat16 precision. Set fp16 to `True` and bf16 to `False`')
|
| 1732 |
-
if not force_float32 and (not float16 and use_fp16): raise TypeError('Unsloth: Model is in bfloat16 precision but you want to use float16 precision. Set fp16 to `False` and bf16 to `True`')
|
| 1733 |
-
if force_float32:
|
| 1734 |
-
args.fp16 = False
|
| 1735 |
-
args.bf16 = False
|
| 1736 |
-
os.environ['ACCELERATE_MIXED_PRECISION'] = 'no'
|
| 1737 |
-
elif (not use_bf16 and not use_fp16) and mixed_precision_dtype == 'float32':
|
| 1738 |
-
args.fp16 = float16
|
| 1739 |
-
args.bf16 = not float16
|
| 1740 |
-
os.environ['ACCELERATE_MIXED_PRECISION'] = 'fp16' if float16 else 'bf16'
|
| 1741 |
-
if getattr(args, 'eval_dataset', None) is not None and getattr(args, 'eval_strategy', 'no') == 'no':
|
| 1742 |
-
args.eval_strategy = 'steps'
|
| 1743 |
-
if getattr(args, 'eval_steps', None) is None: args.eval_steps = 0.1
|
| 1744 |
-
ga_steps = getattr(args, 'gradient_accumulation_steps', None)
|
| 1745 |
-
if ga_steps is not None and ga_steps > 1:
|
| 1746 |
-
from transformers import __version__ as transformers_version
|
| 1747 |
-
if Version(transformers_version) <= Version('4.45.2'):
|
| 1748 |
-
print('**** Unsloth: Please use our fixed gradient_accumulation_steps by updating transformers, TRL and Unsloth!\n'
|
| 1749 |
-
'`pip install --upgrade --no-cache-dir --force-reinstall --no-deps unsloth transformers trl unsloth_zoo`')
|
| 1750 |
-
if getattr(args, 'eval_strategy', 'no') != 'no':
|
| 1751 |
-
eval_bsz = getattr(args, 'per_device_eval_batch_size', 8)
|
| 1752 |
-
if eval_bsz == 8 and args.per_device_train_batch_size < eval_bsz: args.per_device_eval_batch_size = args.per_device_train_batch_size
|
| 1753 |
-
if getattr(args, 'eval_accumulation_steps', None) is None and ga_steps is not None: args.eval_accumulation_steps = ga_steps
|
| 1754 |
-
fp16_full_eval = getattr(args, 'fp16_full_eval', False)
|
| 1755 |
-
if type(fp16_full_eval) is not bool: fp16_full_eval = False
|
| 1756 |
-
bf16_full_eval = getattr(args, 'bf16_full_eval', False)
|
| 1757 |
-
if type(bf16_full_eval) is not bool: bf16_full_eval = False
|
| 1758 |
-
if args.fp16 and bf16_full_eval: args.bf16_full_eval = False; args.fp16_full_eval = True
|
| 1759 |
-
if args.bf16 and fp16_full_eval: args.bf16_full_eval = True; args.fp16_full_eval = False
|
| 1760 |
-
if force_float32:
|
| 1761 |
-
args.bf16_full_eval = False
|
| 1762 |
-
args.fp16_full_eval = False
|
| 1763 |
-
elif os.environ.get('UNSLOTH_MIXED_PRECISION', 'float32') == 'bfloat16':
|
| 1764 |
-
args.bf16_full_eval = True
|
| 1765 |
-
args.fp16_full_eval = False
|
| 1766 |
-
elif not bf16_full_eval and not fp16_full_eval:
|
| 1767 |
-
args.bf16_full_eval = args.bf16
|
| 1768 |
-
args.fp16_full_eval = args.fp16
|
| 1769 |
-
_output_logits = False
|
| 1770 |
-
if locals().get('compute_metrics', None) is not None: _output_logits = True
|
| 1771 |
-
if locals().get('preprocess_logits_for_metrics', None) is not None: _output_logits = True
|
| 1772 |
-
if _output_logits:
|
| 1773 |
-
os.environ['UNSLOTH_RETURN_LOGITS'] = '1'
|
| 1774 |
-
if 'max_seq_length' not in locals() and not hasattr(args, 'max_seq_length'):
|
| 1775 |
-
pass
|
| 1776 |
-
else:
|
| 1777 |
-
model_max_seq_length = getattr(model, 'max_seq_length', None)
|
| 1778 |
-
args_max_seq_length = getattr(args, 'max_seq_length', None)
|
| 1779 |
-
if args_max_seq_length is None and model_max_seq_length is not None:
|
| 1780 |
-
max_seq_length = model.max_seq_length
|
| 1781 |
-
if hasattr(args, 'max_seq_length'): args.max_seq_length = max_seq_length
|
| 1782 |
-
if model is not None and hasattr(model, 'for_training'):
|
| 1783 |
-
model.for_training()
|
| 1784 |
-
if 'tokenizer' in locals() and hasattr(tokenizer, 'padding_side'): tokenizer.padding_side = 'right'
|
| 1785 |
-
if 'processing_class' in locals():
|
| 1786 |
-
if hasattr(processing_class, 'padding_side'): processing_class.padding_side = 'right'
|
| 1787 |
-
if hasattr(processing_class, 'tokenizer') and hasattr(processing_class.tokenizer, 'padding_side'): processing_class.tokenizer.padding_side = 'right'
|
| 1788 |
-
__tokenizer = processing_class if 'processing_class' in locals() else tokenizer
|
| 1789 |
-
from unsloth_zoo.vision_utils import UnslothVisionDataCollator
|
| 1790 |
-
if not isinstance(data_collator, UnslothVisionDataCollator):
|
| 1791 |
-
if isinstance(data_collator, DataCollatorForSeq2Seq) and 'labels' not in train_dataset.column_names:
|
| 1792 |
-
data_collator = TransformersDataCollatorForLanguageModeling(__tokenizer, mlm = False, mlm_probability = 0.0)
|
| 1793 |
-
elif isinstance(data_collator, TransformersDataCollatorForLanguageModeling) and 'labels' in train_dataset.column_names:
|
| 1794 |
-
data_collator = DataCollatorForSeq2Seq(__tokenizer)
|
| 1795 |
-
else:
|
| 1796 |
-
if hasattr(args, 'remove_unused_columns'): args.remove_unused_columns = False
|
| 1797 |
-
if hasattr(args, 'dataset_text_field'): args.dataset_text_field = ''
|
| 1798 |
-
if hasattr(args, 'dataset_kwargs'): args.dataset_kwargs = {'skip_prepare_dataset': True}
|
| 1799 |
-
if not isinstance(data_collator, UnslothVisionDataCollator):
|
| 1800 |
-
if not hasattr(__tokenizer, 'pad') and hasattr(__tokenizer, 'tokenizer'):
|
| 1801 |
-
if isinstance(data_collator, DataCollatorForSeq2Seq):
|
| 1802 |
-
data_collator = DataCollatorForSeq2Seq(__tokenizer.tokenizer)
|
| 1803 |
-
else:
|
| 1804 |
-
data_collator = TransformersDataCollatorForLanguageModeling(__tokenizer.tokenizer, mlm = False, mlm_probability = 0.0)
|
| 1805 |
-
other_metrics = []
|
| 1806 |
-
|
| 1807 |
-
from unsloth_zoo.logging_utils import PatchRLStatistics
|
| 1808 |
-
PatchRLStatistics('bco_trainer', other_metrics)
|
| 1809 |
-
|
| 1810 |
-
super().__init__(
|
| 1811 |
-
model = model,
|
| 1812 |
-
ref_model = ref_model,
|
| 1813 |
-
args = args,
|
| 1814 |
-
train_dataset = train_dataset,
|
| 1815 |
-
eval_dataset = eval_dataset,
|
| 1816 |
-
processing_class = processing_class,
|
| 1817 |
-
data_collator = data_collator,
|
| 1818 |
-
model_init = model_init,
|
| 1819 |
-
callbacks = callbacks,
|
| 1820 |
-
preprocess_logits_for_metrics = preprocess_logits_for_metrics,
|
| 1821 |
-
peft_config = peft_config,
|
| 1822 |
-
compute_metrics = compute_metrics,
|
| 1823 |
-
model_adapter_name = model_adapter_name,
|
| 1824 |
-
ref_adapter_name = ref_adapter_name,
|
| 1825 |
-
embedding_func = embedding_func,
|
| 1826 |
-
embedding_tokenizer = embedding_tokenizer,**kwargs)
|
| 1827 |
-
if hasattr(self, 'neftune_hook_handle'):
|
| 1828 |
-
self.neftune_hook_handle.remove()
|
| 1829 |
-
if hasattr(self, 'neftune_hook_handle'): del self.neftune_hook_handle
|
| 1830 |
-
if getattr(args, 'neftune_noise_alpha', None) is not None:
|
| 1831 |
-
model.get_input_embeddings().neftune_noise_alpha = self.neftune_noise_alpha
|
| 1832 |
-
pass
|
| 1833 |
-
|
| 1834 |
-
pass
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
test_run_uploads/UnslothCPOTrainer.py
DELETED
|
@@ -1,1566 +0,0 @@
|
|
| 1 |
-
"""
|
| 2 |
-
2025.7.11
|
| 3 |
-
2025.7.11
|
| 4 |
-
4.54.1
|
| 5 |
-
0.16.1
|
| 6 |
-
__UNSLOTH_VERSIONING__
|
| 7 |
-
"""
|
| 8 |
-
from torch import Tensor
|
| 9 |
-
import torch
|
| 10 |
-
import torch.nn as nn
|
| 11 |
-
from torch.nn import functional as F
|
| 12 |
-
from typing import Any, List, Optional, Tuple, Union, Dict, Set, Callable
|
| 13 |
-
from trl.trainer.cpo_trainer import (Any, AutoModelForCausalLM, BaseImageProcessor, CPOConfig, CPOTrainer, Callable, DPODataCollatorWithPadding, DataCollator, DataLoader, Dataset, EvalLoopOutput, F, FeatureExtractionMixin, Literal, Optional, PartialState, PeftModel, PreTrainedModel, PreTrainedTokenizerBase, ProcessorMixin, Trainer, TrainerCallback, Union, add_bos_token_if_needed, add_eos_token_if_needed, amp, defaultdict, disable_dropout_in_model, generate_model_card, get_comet_experiment_url, inspect, is_comet_available, is_peft_available, is_torch_fx_proxy, is_wandb_available, log_table_to_comet_experiment, maybe_apply_chat_template, maybe_extract_prompt, nn, np, nullcontext, os, pad_to_length, pd, peft_module_casting_to_bf16, prepare_model_for_kbit_training, random, selective_log_softmax, textwrap, torch, transformers, version, wandb, warnings)
|
| 14 |
-
|
| 15 |
-
|
| 16 |
-
import os
|
| 17 |
-
from typing import *
|
| 18 |
-
from dataclasses import dataclass, field
|
| 19 |
-
from packaging.version import Version
|
| 20 |
-
import torch
|
| 21 |
-
import numpy as np
|
| 22 |
-
from contextlib import nullcontext
|
| 23 |
-
from torch.nn import functional as F
|
| 24 |
-
from transformers import DataCollatorForSeq2Seq, DataCollatorForLanguageModeling as TransformersDataCollatorForLanguageModeling
|
| 25 |
-
|
| 26 |
-
torch_compile_options = {
|
| 27 |
-
"epilogue_fusion" : True,
|
| 28 |
-
"max_autotune" : False,
|
| 29 |
-
"shape_padding" : True,
|
| 30 |
-
"trace.enabled" : False,
|
| 31 |
-
"triton.cudagraphs" : False,
|
| 32 |
-
}
|
| 33 |
-
|
| 34 |
-
@torch.compile(dynamic = True, fullgraph = True, options = torch_compile_options,)
|
| 35 |
-
def chunked_selective_log_softmax(logits, index):
|
| 36 |
-
# Split into 4 chunks only
|
| 37 |
-
chunked_logits = torch.chunk(logits.reshape(-1, logits.shape[-1]), chunks = 4, dim = 0)
|
| 38 |
-
chunked_index = torch.chunk(index.reshape(-1), chunks = 4, dim = 0)
|
| 39 |
-
all_per_token_logps = []
|
| 40 |
-
# Below loop does the same as selective_log_softmax(chunk_logits, chunk_index)
|
| 41 |
-
for chunk_logits, chunk_index in zip(chunked_logits, chunked_index):
|
| 42 |
-
chunk_logits = chunk_logits.to(torch.float32)
|
| 43 |
-
selected_logits = torch.gather(chunk_logits, dim = -1, index = chunk_index.unsqueeze(-1)).squeeze(-1)
|
| 44 |
-
logsumexp_values = torch.logsumexp(chunk_logits, dim = -1)
|
| 45 |
-
per_token_logps = selected_logits - logsumexp_values
|
| 46 |
-
all_per_token_logps.append(per_token_logps)
|
| 47 |
-
pass
|
| 48 |
-
all_per_token_logps = torch.concat(all_per_token_logps)
|
| 49 |
-
all_per_token_logps = all_per_token_logps.reshape((logits.shape[0], logits.shape[1]))
|
| 50 |
-
return all_per_token_logps
|
| 51 |
-
@dataclass
|
| 52 |
-
class UnslothCPOConfig(CPOConfig):
|
| 53 |
-
"""
|
| 54 |
-
|
| 55 |
-
Configuration class for the [`CPOTrainer`].
|
| 56 |
-
|
| 57 |
-
Using [`~transformers.HfArgumentParser`] we can turn this class into
|
| 58 |
-
[argparse](https://docs.python.org/3/library/argparse#module-argparse) arguments that can be specified on the
|
| 59 |
-
command line.
|
| 60 |
-
|
| 61 |
-
Parameters:
|
| 62 |
-
learning_rate (`float`, *optional*, defaults to `1e-6`):
|
| 63 |
-
Initial learning rate for [`AdamW`] optimizer. The default value replaces that of
|
| 64 |
-
[`~transformers.TrainingArguments`].
|
| 65 |
-
max_length (`int` or `None`, *optional*, defaults to `1024`):
|
| 66 |
-
Maximum length of the sequences (prompt + completion) in the batch. This argument is required if you want
|
| 67 |
-
to use the default data collator.
|
| 68 |
-
max_prompt_length (`int` or `None`, *optional*, defaults to `512`):
|
| 69 |
-
Maximum length of the prompt. This argument is required if you want to use the default data collator.
|
| 70 |
-
max_completion_length (`int` or `None`, *optional*, defaults to `None`):
|
| 71 |
-
Maximum length of the completion. This argument is required if you want to use the default data collator
|
| 72 |
-
and your model is an encoder-decoder.
|
| 73 |
-
beta (`float`, *optional*, defaults to `0.1`):
|
| 74 |
-
Parameter controlling the deviation from the reference model. Higher β means less deviation from the
|
| 75 |
-
reference model. For the IPO loss (`loss_type="ipo"`), β is the regularization parameter denoted by τ in
|
| 76 |
-
the [paper](https://huggingface.co/papers/2310.12036).
|
| 77 |
-
label_smoothing (`float`, *optional*, defaults to `0.0`):
|
| 78 |
-
Label smoothing factor. This argument is required if you want to use the default data collator.
|
| 79 |
-
loss_type (`str`, *optional*, defaults to `"sigmoid"`):
|
| 80 |
-
Type of loss to use. Possible values are:
|
| 81 |
-
|
| 82 |
-
- `"sigmoid"`: sigmoid loss from the original [DPO](https://huggingface.co/papers/2305.18290) paper.
|
| 83 |
-
- `"hinge"`: hinge loss on the normalized likelihood from the [SLiC](https://huggingface.co/papers/2305.10425) paper.
|
| 84 |
-
- `"ipo"`: IPO loss from the [IPO](https://huggingface.co/papers/2310.12036) paper.
|
| 85 |
-
- `"simpo"`: SimPO loss from the [SimPO](https://huggingface.co/papers/2405.14734) paper.
|
| 86 |
-
|
| 87 |
-
disable_dropout (`bool`, *optional*, defaults to `True`):
|
| 88 |
-
Whether to disable dropout in the model.
|
| 89 |
-
cpo_alpha (`float`, *optional*, defaults to `1.0`):
|
| 90 |
-
Weight of the BC regularizer in CPO training.
|
| 91 |
-
simpo_gamma (`float`, *optional*, defaults to `0.5`):
|
| 92 |
-
Target reward margin for the SimPO loss, used only when the `loss_type="simpo"`.
|
| 93 |
-
label_pad_token_id (`int`, *optional*, defaults to `-100`):
|
| 94 |
-
Label pad token id. This argument is required if you want to use the default data collator.
|
| 95 |
-
padding_value (`int` or `None`, *optional*, defaults to `None`):
|
| 96 |
-
Padding value to use. If `None`, the padding value of the tokenizer is used.
|
| 97 |
-
truncation_mode (`str`,*optional*, defaults to `"keep_end"`):
|
| 98 |
-
Truncation mode to use when the prompt is too long. Possible values are `"keep_end"` or `"keep_start"`.
|
| 99 |
-
This argument is required if you want to use the default data collator.
|
| 100 |
-
generate_during_eval (`bool`, *optional*, defaults to `False`):
|
| 101 |
-
If `True`, generates and logs completions from the model to W&B or Comet during evaluation.
|
| 102 |
-
is_encoder_decoder (`bool` or `None`, *optional*, defaults to `None`):
|
| 103 |
-
When using the `model_init` argument (callable) to instantiate the model instead of the `model` argument,
|
| 104 |
-
you need to specify if the model returned by the callable is an encoder-decoder model.
|
| 105 |
-
model_init_kwargs (`dict[str, Any]` or `None`, *optional*, defaults to `None`):
|
| 106 |
-
Keyword arguments to pass to `AutoModelForCausalLM.from_pretrained` when instantiating the model from a
|
| 107 |
-
string.
|
| 108 |
-
dataset_num_proc (`int` or `None`, *optional*, defaults to `None`):
|
| 109 |
-
Number of processes to use for processing the dataset.
|
| 110 |
-
|
| 111 |
-
"""
|
| 112 |
-
vllm_sampling_params: Optional[Any] = field(
|
| 113 |
-
default = None,
|
| 114 |
-
metadata = {'help': 'vLLM SamplingParams'},
|
| 115 |
-
)
|
| 116 |
-
unsloth_num_chunks : Optional[int] = field(
|
| 117 |
-
default = -1,
|
| 118 |
-
metadata = {'help': 'Chunk size to reduce memory usage. -1 is most efficient.'},
|
| 119 |
-
)
|
| 120 |
-
def __init__(
|
| 121 |
-
self,
|
| 122 |
-
output_dir = None,
|
| 123 |
-
overwrite_output_dir = None,
|
| 124 |
-
do_train = False,
|
| 125 |
-
do_eval = False,
|
| 126 |
-
do_predict = False,
|
| 127 |
-
eval_strategy = 'no',
|
| 128 |
-
prediction_loss_only = False,
|
| 129 |
-
per_device_train_batch_size = 4,
|
| 130 |
-
per_device_eval_batch_size = 4,
|
| 131 |
-
per_gpu_train_batch_size = None,
|
| 132 |
-
per_gpu_eval_batch_size = None,
|
| 133 |
-
gradient_accumulation_steps = 2,
|
| 134 |
-
eval_accumulation_steps = 2,
|
| 135 |
-
eval_delay = 0,
|
| 136 |
-
torch_empty_cache_steps = 250,
|
| 137 |
-
learning_rate = 5e-05,
|
| 138 |
-
weight_decay = 0.01,
|
| 139 |
-
adam_beta1 = 0.9,
|
| 140 |
-
adam_beta2 = 0.999,
|
| 141 |
-
adam_epsilon = 1e-08,
|
| 142 |
-
max_grad_norm = 1.0,
|
| 143 |
-
num_train_epochs = 3.0,
|
| 144 |
-
max_steps = -1,
|
| 145 |
-
lr_scheduler_type = 'linear',
|
| 146 |
-
warmup_ratio = 0.1,
|
| 147 |
-
warmup_steps = 0,
|
| 148 |
-
log_level = 'passive',
|
| 149 |
-
log_level_replica = 'warning',
|
| 150 |
-
log_on_each_node = True,
|
| 151 |
-
logging_dir = None,
|
| 152 |
-
logging_strategy = 'steps',
|
| 153 |
-
logging_first_step = False,
|
| 154 |
-
logging_steps = 1,
|
| 155 |
-
logging_nan_inf_filter = False,
|
| 156 |
-
save_strategy = 'steps',
|
| 157 |
-
save_steps = 500,
|
| 158 |
-
save_total_limit = None,
|
| 159 |
-
save_safetensors = True,
|
| 160 |
-
save_on_each_node = False,
|
| 161 |
-
save_only_model = False,
|
| 162 |
-
restore_callback_states_from_checkpoint = False,
|
| 163 |
-
no_cuda = False,
|
| 164 |
-
use_cpu = False,
|
| 165 |
-
use_mps_device = False,
|
| 166 |
-
seed = 3407,
|
| 167 |
-
data_seed = 3407,
|
| 168 |
-
jit_mode_eval = False,
|
| 169 |
-
use_ipex = False,
|
| 170 |
-
bf16 = False,
|
| 171 |
-
fp16 = False,
|
| 172 |
-
fp16_opt_level = 'O1',
|
| 173 |
-
half_precision_backend = 'auto',
|
| 174 |
-
bf16_full_eval = False,
|
| 175 |
-
fp16_full_eval = False,
|
| 176 |
-
tf32 = None,
|
| 177 |
-
local_rank = -1,
|
| 178 |
-
ddp_backend = None,
|
| 179 |
-
tpu_num_cores = None,
|
| 180 |
-
tpu_metrics_debug = False,
|
| 181 |
-
debug = '',
|
| 182 |
-
dataloader_drop_last = False,
|
| 183 |
-
eval_steps = None,
|
| 184 |
-
dataloader_num_workers = 0,
|
| 185 |
-
dataloader_prefetch_factor = None,
|
| 186 |
-
past_index = -1,
|
| 187 |
-
run_name = None,
|
| 188 |
-
disable_tqdm = None,
|
| 189 |
-
remove_unused_columns = True,
|
| 190 |
-
label_names = None,
|
| 191 |
-
load_best_model_at_end = False,
|
| 192 |
-
metric_for_best_model = None,
|
| 193 |
-
greater_is_better = None,
|
| 194 |
-
ignore_data_skip = False,
|
| 195 |
-
fsdp = '',
|
| 196 |
-
fsdp_min_num_params = 0,
|
| 197 |
-
fsdp_config = None,
|
| 198 |
-
fsdp_transformer_layer_cls_to_wrap = None,
|
| 199 |
-
accelerator_config = None,
|
| 200 |
-
deepspeed = None,
|
| 201 |
-
label_smoothing_factor = 0.0,
|
| 202 |
-
optim = 'adamw_8bit',
|
| 203 |
-
optim_args = None,
|
| 204 |
-
adafactor = False,
|
| 205 |
-
group_by_length = False,
|
| 206 |
-
length_column_name = 'length',
|
| 207 |
-
report_to = None,
|
| 208 |
-
ddp_find_unused_parameters = None,
|
| 209 |
-
ddp_bucket_cap_mb = None,
|
| 210 |
-
ddp_broadcast_buffers = None,
|
| 211 |
-
dataloader_pin_memory = True,
|
| 212 |
-
dataloader_persistent_workers = False,
|
| 213 |
-
skip_memory_metrics = True,
|
| 214 |
-
use_legacy_prediction_loop = False,
|
| 215 |
-
push_to_hub = False,
|
| 216 |
-
resume_from_checkpoint = None,
|
| 217 |
-
hub_model_id = None,
|
| 218 |
-
hub_strategy = 'every_save',
|
| 219 |
-
hub_token = None,
|
| 220 |
-
hub_private_repo = None,
|
| 221 |
-
hub_always_push = False,
|
| 222 |
-
hub_revision = None,
|
| 223 |
-
gradient_checkpointing = False,
|
| 224 |
-
gradient_checkpointing_kwargs = None,
|
| 225 |
-
include_inputs_for_metrics = False,
|
| 226 |
-
eval_do_concat_batches = True,
|
| 227 |
-
fp16_backend = 'auto',
|
| 228 |
-
push_to_hub_model_id = None,
|
| 229 |
-
push_to_hub_organization = None,
|
| 230 |
-
push_to_hub_token = None,
|
| 231 |
-
mp_parameters = '',
|
| 232 |
-
auto_find_batch_size = True,
|
| 233 |
-
full_determinism = False,
|
| 234 |
-
torchdynamo = None,
|
| 235 |
-
ray_scope = 'last',
|
| 236 |
-
ddp_timeout = 1800,
|
| 237 |
-
torch_compile = False,
|
| 238 |
-
torch_compile_backend = None,
|
| 239 |
-
torch_compile_mode = None,
|
| 240 |
-
include_tokens_per_second = False,
|
| 241 |
-
include_num_input_tokens_seen = False,
|
| 242 |
-
neftune_noise_alpha = None,
|
| 243 |
-
optim_target_modules = None,
|
| 244 |
-
batch_eval_metrics = False,
|
| 245 |
-
eval_on_start = False,
|
| 246 |
-
use_liger_kernel = False,
|
| 247 |
-
liger_kernel_config = None,
|
| 248 |
-
eval_use_gather_object = False,
|
| 249 |
-
average_tokens_across_devices = True,
|
| 250 |
-
max_length = 1024,
|
| 251 |
-
max_prompt_length = 512,
|
| 252 |
-
max_completion_length = None,
|
| 253 |
-
beta = 0.1,
|
| 254 |
-
label_smoothing = 0.0,
|
| 255 |
-
loss_type = 'sigmoid',
|
| 256 |
-
disable_dropout = True,
|
| 257 |
-
cpo_alpha = 1.0,
|
| 258 |
-
simpo_gamma = 0.5,
|
| 259 |
-
label_pad_token_id = -100,
|
| 260 |
-
padding_value = None,
|
| 261 |
-
truncation_mode = 'keep_end',
|
| 262 |
-
generate_during_eval = False,
|
| 263 |
-
is_encoder_decoder = None,
|
| 264 |
-
model_init_kwargs = None,
|
| 265 |
-
dataset_num_proc = None,
|
| 266 |
-
vllm_sampling_params = None,
|
| 267 |
-
unsloth_num_chunks = -1,
|
| 268 |
-
**kwargs,
|
| 269 |
-
):
|
| 270 |
-
if learning_rate < 1e-7: raise FloatingPointError(f'Unsloth: Your learning rate of `{learning_rate}` is too small and less than 1e-7! Consider increasing it, otherwise gradient updates will be close to 0!')
|
| 271 |
-
if learning_rate > 1: raise OverflowError(f'Unsloth: Your learning rate of `{learning_rate}` is way too larger > 1! Consider decreasing it to 1e-1, otherwise gradient updates will explode!')
|
| 272 |
-
if output_dir is None and save_strategy == 'steps' and save_steps == 500:
|
| 273 |
-
output_dir = 'unsloth_training_checkpoints'
|
| 274 |
-
save_strategy = 'no'
|
| 275 |
-
if dataset_num_proc is None:
|
| 276 |
-
from multiprocessing import cpu_count
|
| 277 |
-
dataset_num_proc = min(cpu_count()*2, 2)
|
| 278 |
-
|
| 279 |
-
super().__init__(
|
| 280 |
-
output_dir = output_dir,
|
| 281 |
-
overwrite_output_dir = overwrite_output_dir,
|
| 282 |
-
do_train = do_train,
|
| 283 |
-
do_eval = do_eval,
|
| 284 |
-
do_predict = do_predict,
|
| 285 |
-
eval_strategy = eval_strategy,
|
| 286 |
-
prediction_loss_only = prediction_loss_only,
|
| 287 |
-
per_device_train_batch_size = per_device_train_batch_size,
|
| 288 |
-
per_device_eval_batch_size = per_device_eval_batch_size,
|
| 289 |
-
per_gpu_train_batch_size = per_gpu_train_batch_size,
|
| 290 |
-
per_gpu_eval_batch_size = per_gpu_eval_batch_size,
|
| 291 |
-
gradient_accumulation_steps = gradient_accumulation_steps,
|
| 292 |
-
eval_accumulation_steps = eval_accumulation_steps,
|
| 293 |
-
eval_delay = eval_delay,
|
| 294 |
-
torch_empty_cache_steps = torch_empty_cache_steps,
|
| 295 |
-
learning_rate = learning_rate,
|
| 296 |
-
weight_decay = weight_decay,
|
| 297 |
-
adam_beta1 = adam_beta1,
|
| 298 |
-
adam_beta2 = adam_beta2,
|
| 299 |
-
adam_epsilon = adam_epsilon,
|
| 300 |
-
max_grad_norm = max_grad_norm,
|
| 301 |
-
num_train_epochs = num_train_epochs,
|
| 302 |
-
max_steps = max_steps,
|
| 303 |
-
lr_scheduler_type = lr_scheduler_type,
|
| 304 |
-
warmup_ratio = warmup_ratio,
|
| 305 |
-
warmup_steps = warmup_steps,
|
| 306 |
-
log_level = log_level,
|
| 307 |
-
log_level_replica = log_level_replica,
|
| 308 |
-
log_on_each_node = log_on_each_node,
|
| 309 |
-
logging_dir = logging_dir,
|
| 310 |
-
logging_strategy = logging_strategy,
|
| 311 |
-
logging_first_step = logging_first_step,
|
| 312 |
-
logging_steps = logging_steps,
|
| 313 |
-
logging_nan_inf_filter = logging_nan_inf_filter,
|
| 314 |
-
save_strategy = save_strategy,
|
| 315 |
-
save_steps = save_steps,
|
| 316 |
-
save_total_limit = save_total_limit,
|
| 317 |
-
save_safetensors = save_safetensors,
|
| 318 |
-
save_on_each_node = save_on_each_node,
|
| 319 |
-
save_only_model = save_only_model,
|
| 320 |
-
restore_callback_states_from_checkpoint = restore_callback_states_from_checkpoint,
|
| 321 |
-
no_cuda = no_cuda,
|
| 322 |
-
use_cpu = use_cpu,
|
| 323 |
-
use_mps_device = use_mps_device,
|
| 324 |
-
seed = seed,
|
| 325 |
-
data_seed = data_seed,
|
| 326 |
-
jit_mode_eval = jit_mode_eval,
|
| 327 |
-
use_ipex = use_ipex,
|
| 328 |
-
bf16 = bf16,
|
| 329 |
-
fp16 = fp16,
|
| 330 |
-
fp16_opt_level = fp16_opt_level,
|
| 331 |
-
half_precision_backend = half_precision_backend,
|
| 332 |
-
bf16_full_eval = bf16_full_eval,
|
| 333 |
-
fp16_full_eval = fp16_full_eval,
|
| 334 |
-
tf32 = tf32,
|
| 335 |
-
local_rank = local_rank,
|
| 336 |
-
ddp_backend = ddp_backend,
|
| 337 |
-
tpu_num_cores = tpu_num_cores,
|
| 338 |
-
tpu_metrics_debug = tpu_metrics_debug,
|
| 339 |
-
debug = debug,
|
| 340 |
-
dataloader_drop_last = dataloader_drop_last,
|
| 341 |
-
eval_steps = eval_steps,
|
| 342 |
-
dataloader_num_workers = dataloader_num_workers,
|
| 343 |
-
dataloader_prefetch_factor = dataloader_prefetch_factor,
|
| 344 |
-
past_index = past_index,
|
| 345 |
-
run_name = run_name,
|
| 346 |
-
disable_tqdm = disable_tqdm,
|
| 347 |
-
remove_unused_columns = remove_unused_columns,
|
| 348 |
-
label_names = label_names,
|
| 349 |
-
load_best_model_at_end = load_best_model_at_end,
|
| 350 |
-
metric_for_best_model = metric_for_best_model,
|
| 351 |
-
greater_is_better = greater_is_better,
|
| 352 |
-
ignore_data_skip = ignore_data_skip,
|
| 353 |
-
fsdp = fsdp,
|
| 354 |
-
fsdp_min_num_params = fsdp_min_num_params,
|
| 355 |
-
fsdp_config = fsdp_config,
|
| 356 |
-
fsdp_transformer_layer_cls_to_wrap = fsdp_transformer_layer_cls_to_wrap,
|
| 357 |
-
accelerator_config = accelerator_config,
|
| 358 |
-
deepspeed = deepspeed,
|
| 359 |
-
label_smoothing_factor = label_smoothing_factor,
|
| 360 |
-
optim = optim,
|
| 361 |
-
optim_args = optim_args,
|
| 362 |
-
adafactor = adafactor,
|
| 363 |
-
group_by_length = group_by_length,
|
| 364 |
-
length_column_name = length_column_name,
|
| 365 |
-
report_to = report_to,
|
| 366 |
-
ddp_find_unused_parameters = ddp_find_unused_parameters,
|
| 367 |
-
ddp_bucket_cap_mb = ddp_bucket_cap_mb,
|
| 368 |
-
ddp_broadcast_buffers = ddp_broadcast_buffers,
|
| 369 |
-
dataloader_pin_memory = dataloader_pin_memory,
|
| 370 |
-
dataloader_persistent_workers = dataloader_persistent_workers,
|
| 371 |
-
skip_memory_metrics = skip_memory_metrics,
|
| 372 |
-
use_legacy_prediction_loop = use_legacy_prediction_loop,
|
| 373 |
-
push_to_hub = push_to_hub,
|
| 374 |
-
resume_from_checkpoint = resume_from_checkpoint,
|
| 375 |
-
hub_model_id = hub_model_id,
|
| 376 |
-
hub_strategy = hub_strategy,
|
| 377 |
-
hub_token = hub_token,
|
| 378 |
-
hub_private_repo = hub_private_repo,
|
| 379 |
-
hub_always_push = hub_always_push,
|
| 380 |
-
hub_revision = hub_revision,
|
| 381 |
-
gradient_checkpointing = gradient_checkpointing,
|
| 382 |
-
gradient_checkpointing_kwargs = gradient_checkpointing_kwargs,
|
| 383 |
-
include_inputs_for_metrics = include_inputs_for_metrics,
|
| 384 |
-
eval_do_concat_batches = eval_do_concat_batches,
|
| 385 |
-
fp16_backend = fp16_backend,
|
| 386 |
-
push_to_hub_model_id = push_to_hub_model_id,
|
| 387 |
-
push_to_hub_organization = push_to_hub_organization,
|
| 388 |
-
push_to_hub_token = push_to_hub_token,
|
| 389 |
-
mp_parameters = mp_parameters,
|
| 390 |
-
auto_find_batch_size = auto_find_batch_size,
|
| 391 |
-
full_determinism = full_determinism,
|
| 392 |
-
torchdynamo = torchdynamo,
|
| 393 |
-
ray_scope = ray_scope,
|
| 394 |
-
ddp_timeout = ddp_timeout,
|
| 395 |
-
torch_compile = torch_compile,
|
| 396 |
-
torch_compile_backend = torch_compile_backend,
|
| 397 |
-
torch_compile_mode = torch_compile_mode,
|
| 398 |
-
include_tokens_per_second = include_tokens_per_second,
|
| 399 |
-
include_num_input_tokens_seen = include_num_input_tokens_seen,
|
| 400 |
-
neftune_noise_alpha = neftune_noise_alpha,
|
| 401 |
-
optim_target_modules = optim_target_modules,
|
| 402 |
-
batch_eval_metrics = batch_eval_metrics,
|
| 403 |
-
eval_on_start = eval_on_start,
|
| 404 |
-
use_liger_kernel = use_liger_kernel,
|
| 405 |
-
liger_kernel_config = liger_kernel_config,
|
| 406 |
-
eval_use_gather_object = eval_use_gather_object,
|
| 407 |
-
average_tokens_across_devices = average_tokens_across_devices,
|
| 408 |
-
max_length = max_length,
|
| 409 |
-
max_prompt_length = max_prompt_length,
|
| 410 |
-
max_completion_length = max_completion_length,
|
| 411 |
-
beta = beta,
|
| 412 |
-
label_smoothing = label_smoothing,
|
| 413 |
-
loss_type = loss_type,
|
| 414 |
-
disable_dropout = disable_dropout,
|
| 415 |
-
cpo_alpha = cpo_alpha,
|
| 416 |
-
simpo_gamma = simpo_gamma,
|
| 417 |
-
label_pad_token_id = label_pad_token_id,
|
| 418 |
-
padding_value = padding_value,
|
| 419 |
-
truncation_mode = truncation_mode,
|
| 420 |
-
generate_during_eval = generate_during_eval,
|
| 421 |
-
is_encoder_decoder = is_encoder_decoder,
|
| 422 |
-
model_init_kwargs = model_init_kwargs,
|
| 423 |
-
dataset_num_proc = dataset_num_proc,**kwargs)
|
| 424 |
-
self.vllm_sampling_params = vllm_sampling_params
|
| 425 |
-
self.unsloth_num_chunks = unsloth_num_chunks
|
| 426 |
-
pass
|
| 427 |
-
|
| 428 |
-
class _UnslothCPOTrainer(Trainer):
|
| 429 |
-
r""""""
|
| 430 |
-
|
| 431 |
-
_tag_names = ["trl", "cpo"]
|
| 432 |
-
|
| 433 |
-
def __init__(
|
| 434 |
-
self,
|
| 435 |
-
model: Optional[Union[PreTrainedModel, nn.Module, str]] = None,
|
| 436 |
-
args: Optional[CPOConfig] = None,
|
| 437 |
-
data_collator: Optional[DataCollator] = None,
|
| 438 |
-
train_dataset: Optional[Dataset] = None,
|
| 439 |
-
eval_dataset: Optional[Union[Dataset, dict[str, Dataset]]] = None,
|
| 440 |
-
processing_class: Optional[
|
| 441 |
-
Union[PreTrainedTokenizerBase, BaseImageProcessor, FeatureExtractionMixin, ProcessorMixin]
|
| 442 |
-
] = None,
|
| 443 |
-
model_init: Optional[Callable[[], PreTrainedModel]] = None,
|
| 444 |
-
callbacks: Optional[list[TrainerCallback]] = None,
|
| 445 |
-
optimizers: tuple[torch.optim.Optimizer, torch.optim.lr_scheduler.LambdaLR] = (None, None),
|
| 446 |
-
preprocess_logits_for_metrics: Optional[Callable[[torch.Tensor, torch.Tensor], torch.Tensor]] = None,
|
| 447 |
-
peft_config: Optional[dict] = None,
|
| 448 |
-
compute_metrics: Optional[Callable[[EvalLoopOutput], dict]] = None,
|
| 449 |
-
):
|
| 450 |
-
if args.model_init_kwargs is None:
|
| 451 |
-
model_init_kwargs = {}
|
| 452 |
-
elif not isinstance(model, str):
|
| 453 |
-
raise ValueError("You passed model_kwargs to the CPOTrainer. But your model is already instantiated.")
|
| 454 |
-
else:
|
| 455 |
-
model_init_kwargs = args.model_init_kwargs
|
| 456 |
-
torch_dtype = model_init_kwargs.get("torch_dtype")
|
| 457 |
-
if torch_dtype is not None:
|
| 458 |
-
# Convert to `torch.dtype` if an str is passed
|
| 459 |
-
if isinstance(torch_dtype, str) and torch_dtype != "auto":
|
| 460 |
-
torch_dtype = getattr(torch, torch_dtype)
|
| 461 |
-
if torch_dtype != "auto" and not isinstance(torch_dtype, torch.dtype):
|
| 462 |
-
raise ValueError(
|
| 463 |
-
f"Invalid `torch_dtype` passed to the CPOConfig. Expected a string with either `torch.dtype` or 'auto', but got {torch_dtype}."
|
| 464 |
-
)
|
| 465 |
-
model_init_kwargs["torch_dtype"] = torch_dtype
|
| 466 |
-
|
| 467 |
-
if isinstance(model, str):
|
| 468 |
-
model = AutoModelForCausalLM.from_pretrained(model, **model_init_kwargs)
|
| 469 |
-
|
| 470 |
-
# Initialize this variable to False. This helps tracking the case when `peft_module_casting_to_bf16`
|
| 471 |
-
# has been called in order to properly call autocast if needed.
|
| 472 |
-
self._peft_has_been_casted_to_bf16 = False
|
| 473 |
-
|
| 474 |
-
if not is_peft_available() and peft_config is not None:
|
| 475 |
-
raise ValueError(
|
| 476 |
-
"PEFT is not installed and you passed a `peft_config` in the trainer's kwargs, please install it to use the PEFT models"
|
| 477 |
-
)
|
| 478 |
-
elif is_peft_available() and peft_config is not None:
|
| 479 |
-
# if model is a peft model and we have a peft_config, we merge and unload it first
|
| 480 |
-
if isinstance(model, PeftModel):
|
| 481 |
-
model = model.merge_and_unload()
|
| 482 |
-
|
| 483 |
-
if getattr(model, "is_loaded_in_8bit", False) or getattr(model, "is_loaded_in_4bit", False):
|
| 484 |
-
_support_gc_kwargs = hasattr(
|
| 485 |
-
args, "gradient_checkpointing_kwargs"
|
| 486 |
-
) and "gradient_checkpointing_kwargs" in list(
|
| 487 |
-
inspect.signature(prepare_model_for_kbit_training).parameters
|
| 488 |
-
)
|
| 489 |
-
|
| 490 |
-
prepare_model_kwargs = {"use_gradient_checkpointing": args.gradient_checkpointing}
|
| 491 |
-
|
| 492 |
-
if _support_gc_kwargs:
|
| 493 |
-
prepare_model_kwargs["gradient_checkpointing_kwargs"] = args.gradient_checkpointing_kwargs
|
| 494 |
-
|
| 495 |
-
model = prepare_model_for_kbit_training(model, **prepare_model_kwargs)
|
| 496 |
-
elif getattr(args, "gradient_checkpointing", False):
|
| 497 |
-
# For backward compatibility with older versions of transformers
|
| 498 |
-
if hasattr(model, "enable_input_require_grads"):
|
| 499 |
-
model.enable_input_require_grads()
|
| 500 |
-
else:
|
| 501 |
-
|
| 502 |
-
def make_inputs_require_grad(module, input, output):
|
| 503 |
-
output.requires_grad_(True)
|
| 504 |
-
|
| 505 |
-
model.get_input_embeddings().register_forward_hook(make_inputs_require_grad)
|
| 506 |
-
|
| 507 |
-
# get peft model with the given config
|
| 508 |
-
model = model
|
| 509 |
-
if args.bf16 and getattr(model, "is_loaded_in_4bit", False):
|
| 510 |
-
peft_module_casting_to_bf16(model)
|
| 511 |
-
# If args.bf16 we need to explicitly call `generate` with torch amp autocast context manager
|
| 512 |
-
self._peft_has_been_casted_to_bf16 = True
|
| 513 |
-
|
| 514 |
-
# For models that use gradient_checkpointing, we need to attach a hook that enables input
|
| 515 |
-
# to explicitly have `requires_grad=True`, otherwise training will either silently
|
| 516 |
-
# fail or completely fail.
|
| 517 |
-
elif getattr(args, "gradient_checkpointing", False):
|
| 518 |
-
# For backward compatibility with older versions of transformers
|
| 519 |
-
if hasattr(model, "enable_input_require_grads"):
|
| 520 |
-
model.enable_input_require_grads()
|
| 521 |
-
else:
|
| 522 |
-
|
| 523 |
-
def make_inputs_require_grad(module, input, output):
|
| 524 |
-
output.requires_grad_(True)
|
| 525 |
-
|
| 526 |
-
model.get_input_embeddings().register_forward_hook(make_inputs_require_grad)
|
| 527 |
-
|
| 528 |
-
if args.generate_during_eval and not (is_wandb_available() or is_comet_available()):
|
| 529 |
-
raise ValueError(
|
| 530 |
-
"`generate_during_eval=True` requires Weights and Biases or Comet to be installed."
|
| 531 |
-
" Please install `wandb` or `comet-ml` to resolve."
|
| 532 |
-
)
|
| 533 |
-
|
| 534 |
-
if model is not None:
|
| 535 |
-
self.is_encoder_decoder = model.config.is_encoder_decoder
|
| 536 |
-
elif args.is_encoder_decoder is None:
|
| 537 |
-
raise ValueError("When no model is provided, you need to pass the parameter is_encoder_decoder.")
|
| 538 |
-
else:
|
| 539 |
-
self.is_encoder_decoder = args.is_encoder_decoder
|
| 540 |
-
|
| 541 |
-
if self.is_encoder_decoder:
|
| 542 |
-
self.decoder_start_token_id = model.config.decoder_start_token_id
|
| 543 |
-
self.pad_token_id = model.config.pad_token_id
|
| 544 |
-
|
| 545 |
-
if processing_class is None:
|
| 546 |
-
raise ValueError("processing_class must be specified to tokenize a CPO dataset.")
|
| 547 |
-
if args.max_length is None:
|
| 548 |
-
warnings.warn(
|
| 549 |
-
"`max_length` is not set in the CPOConfig's init"
|
| 550 |
-
" it will default to `512` by default, but you should do it yourself in the future.",
|
| 551 |
-
UserWarning,
|
| 552 |
-
)
|
| 553 |
-
max_length = 512
|
| 554 |
-
else:
|
| 555 |
-
max_length = args.max_length
|
| 556 |
-
if args.max_prompt_length is None:
|
| 557 |
-
warnings.warn(
|
| 558 |
-
"`max_prompt_length` is not set in the CPOConfig's init"
|
| 559 |
-
" it will default to `128` by default, but you should do it yourself in the future.",
|
| 560 |
-
UserWarning,
|
| 561 |
-
)
|
| 562 |
-
max_prompt_length = 128
|
| 563 |
-
else:
|
| 564 |
-
max_prompt_length = args.max_prompt_length
|
| 565 |
-
|
| 566 |
-
if args.max_completion_length is None and self.is_encoder_decoder:
|
| 567 |
-
warnings.warn(
|
| 568 |
-
"When using an encoder decoder architecture, you should set `max_completion_length` in the CPOConfig's init"
|
| 569 |
-
" it will default to `128` by default, but you should do it yourself in the future.",
|
| 570 |
-
UserWarning,
|
| 571 |
-
)
|
| 572 |
-
max_completion_length = 128
|
| 573 |
-
else:
|
| 574 |
-
max_completion_length = args.max_completion_length
|
| 575 |
-
|
| 576 |
-
if data_collator is None:
|
| 577 |
-
data_collator = DPODataCollatorWithPadding(
|
| 578 |
-
pad_token_id=processing_class.pad_token_id,
|
| 579 |
-
label_pad_token_id=args.label_pad_token_id,
|
| 580 |
-
is_encoder_decoder=self.is_encoder_decoder,
|
| 581 |
-
)
|
| 582 |
-
|
| 583 |
-
if args.remove_unused_columns:
|
| 584 |
-
args.remove_unused_columns = False
|
| 585 |
-
# warn users
|
| 586 |
-
warnings.warn(
|
| 587 |
-
"When using DPODataCollatorWithPadding, you should set `remove_unused_columns=False` in your TrainingArguments"
|
| 588 |
-
" we have set it for you, but you should do it yourself in the future.",
|
| 589 |
-
UserWarning,
|
| 590 |
-
)
|
| 591 |
-
|
| 592 |
-
self.use_dpo_data_collator = True
|
| 593 |
-
else:
|
| 594 |
-
self.use_dpo_data_collator = False
|
| 595 |
-
|
| 596 |
-
# Disable dropout in the model
|
| 597 |
-
if args.disable_dropout:
|
| 598 |
-
disable_dropout_in_model(model)
|
| 599 |
-
|
| 600 |
-
self.max_length = max_length
|
| 601 |
-
self.generate_during_eval = args.generate_during_eval
|
| 602 |
-
self.label_pad_token_id = args.label_pad_token_id
|
| 603 |
-
self.padding_value = args.padding_value if args.padding_value is not None else processing_class.pad_token_id
|
| 604 |
-
self.max_prompt_length = max_prompt_length
|
| 605 |
-
self.truncation_mode = args.truncation_mode
|
| 606 |
-
self.max_completion_length = max_completion_length
|
| 607 |
-
self.processing_class = processing_class
|
| 608 |
-
|
| 609 |
-
if args.loss_type in ["hinge", "ipo"] and args.label_smoothing > 0:
|
| 610 |
-
warnings.warn(
|
| 611 |
-
f"You are using the {args.loss_type} loss type that does not support label smoothing. The "
|
| 612 |
-
"`label_smoothing` parameter will be ignored. Set `label_smoothing` to `0.0` to remove this warning.",
|
| 613 |
-
UserWarning,
|
| 614 |
-
)
|
| 615 |
-
if args.loss_type == "kto_pair":
|
| 616 |
-
raise ValueError("Support for kto_pair has been removed in CPOTrainer. Please use KTOTrainer.")
|
| 617 |
-
|
| 618 |
-
self.beta = args.beta
|
| 619 |
-
self.label_smoothing = args.label_smoothing
|
| 620 |
-
self.loss_type = args.loss_type
|
| 621 |
-
self.cpo_alpha = args.cpo_alpha
|
| 622 |
-
self.aux_loss_enabled = getattr(model.config, "output_router_logits", False)
|
| 623 |
-
self.aux_loss_coef = getattr(model.config, "router_aux_loss_coef", 0.0)
|
| 624 |
-
if self.aux_loss_enabled and self.aux_loss_coef == 0.0:
|
| 625 |
-
warnings.warn(
|
| 626 |
-
"You set `output_router_logits` to `True` in the model config, but `router_aux_loss_coef` is set to "
|
| 627 |
-
"`0.0`, meaning the auxiliary loss will not be used. Either set `router_aux_loss_coef` to a value "
|
| 628 |
-
"greater than `0.0`, or set `output_router_logits` to `False` if you don't want to use the auxiliary "
|
| 629 |
-
"loss.",
|
| 630 |
-
UserWarning,
|
| 631 |
-
)
|
| 632 |
-
|
| 633 |
-
if args.loss_type == "simpo":
|
| 634 |
-
self.simpo_gamma = args.simpo_gamma
|
| 635 |
-
|
| 636 |
-
self._stored_metrics = defaultdict(lambda: defaultdict(list))
|
| 637 |
-
|
| 638 |
-
# The trainer estimates the number of FLOPs [floating-point operations] using the number of elements in the
|
| 639 |
-
# input tensor associated with the key "input_ids". However, in CPO, the sampled data does not include the
|
| 640 |
-
# "input_ids" key. Instead, the available keys are "prompt_input_ids", "chosen_input_ids", and
|
| 641 |
-
# "rejected_input_ids". As a result, the trainer issues the warning: "Could not estimate the number of tokens
|
| 642 |
-
# of the input, floating-point operations will not be computed." To suppress this warning, we set the
|
| 643 |
-
# "estimate_tokens" key in the model's "warnings_issued" dictionary to True. This acts as a flag to indicate
|
| 644 |
-
# that the warning has already been issued.
|
| 645 |
-
model.warnings_issued["estimate_tokens"] = True
|
| 646 |
-
|
| 647 |
-
# Compute that only on the main process for faster data processing.
|
| 648 |
-
# see: https://github.com/huggingface/trl/pull/1255
|
| 649 |
-
with PartialState().main_process_first():
|
| 650 |
-
# Extract the prompt if needed, and apply the chat template if needed
|
| 651 |
-
train_dataset = train_dataset.map(maybe_extract_prompt, num_proc=args.dataset_num_proc)
|
| 652 |
-
train_dataset = train_dataset.map(
|
| 653 |
-
maybe_apply_chat_template, fn_kwargs={"tokenizer": processing_class}, num_proc=args.dataset_num_proc
|
| 654 |
-
)
|
| 655 |
-
if eval_dataset is not None:
|
| 656 |
-
eval_dataset = eval_dataset.map(maybe_extract_prompt, num_proc=args.dataset_num_proc)
|
| 657 |
-
eval_dataset = eval_dataset.map(
|
| 658 |
-
maybe_apply_chat_template,
|
| 659 |
-
fn_kwargs={"tokenizer": processing_class},
|
| 660 |
-
num_proc=args.dataset_num_proc,
|
| 661 |
-
)
|
| 662 |
-
|
| 663 |
-
# tokenize the dataset
|
| 664 |
-
train_dataset = train_dataset.map(self.tokenize_row, num_proc=args.dataset_num_proc)
|
| 665 |
-
if eval_dataset is not None:
|
| 666 |
-
eval_dataset = eval_dataset.map(self.tokenize_row, num_proc=args.dataset_num_proc)
|
| 667 |
-
|
| 668 |
-
super().__init__(
|
| 669 |
-
model=model,
|
| 670 |
-
args=args,
|
| 671 |
-
data_collator=data_collator,
|
| 672 |
-
train_dataset=train_dataset,
|
| 673 |
-
eval_dataset=eval_dataset,
|
| 674 |
-
processing_class=processing_class,
|
| 675 |
-
model_init=model_init,
|
| 676 |
-
compute_metrics=compute_metrics,
|
| 677 |
-
callbacks=callbacks,
|
| 678 |
-
optimizers=optimizers,
|
| 679 |
-
preprocess_logits_for_metrics=preprocess_logits_for_metrics,
|
| 680 |
-
)
|
| 681 |
-
|
| 682 |
-
# Gradient accumulation requires scaled loss. Normally, loss scaling in the parent class depends on whether the
|
| 683 |
-
# model accepts loss-related kwargs. Since we compute our own loss, this check is irrelevant. We set
|
| 684 |
-
# self.model_accepts_loss_kwargs to False to enable scaling.
|
| 685 |
-
self.model_accepts_loss_kwargs = False
|
| 686 |
-
|
| 687 |
-
# Add tags for models that have been loaded with the correct transformers version
|
| 688 |
-
if hasattr(self.model, "add_model_tags"):
|
| 689 |
-
self.model.add_model_tags(self._tag_names)
|
| 690 |
-
|
| 691 |
-
if not hasattr(self, "accelerator"):
|
| 692 |
-
raise AttributeError(
|
| 693 |
-
"Your `Trainer` does not have an `accelerator` object. Consider upgrading `transformers`."
|
| 694 |
-
)
|
| 695 |
-
|
| 696 |
-
def build_tokenized_answer(self, prompt, answer):
|
| 697 |
-
"""
|
| 698 |
-
Llama tokenizer does satisfy `enc(a + b) = enc(a) + enc(b)`.
|
| 699 |
-
It does ensure `enc(a + b) = enc(a) + enc(a + b)[len(enc(a)):]`.
|
| 700 |
-
Reference:
|
| 701 |
-
https://github.com/EleutherAI/lm-evaluation-harness/pull/531#issuecomment-1595586257
|
| 702 |
-
"""
|
| 703 |
-
|
| 704 |
-
full_tokenized = self.processing_class(prompt + answer, add_special_tokens=False)
|
| 705 |
-
prompt_input_ids = self.processing_class(prompt, add_special_tokens=False)["input_ids"]
|
| 706 |
-
|
| 707 |
-
answer_input_ids = full_tokenized["input_ids"][len(prompt_input_ids) :]
|
| 708 |
-
answer_attention_mask = full_tokenized["attention_mask"][len(prompt_input_ids) :]
|
| 709 |
-
|
| 710 |
-
# Concat tokens to form `enc(a) + enc(a + b)[len(enc(a)):]`
|
| 711 |
-
full_concat_input_ids = np.concatenate([prompt_input_ids, answer_input_ids])
|
| 712 |
-
|
| 713 |
-
# Prepare input tokens for token by token comparison
|
| 714 |
-
full_input_ids = np.array(full_tokenized["input_ids"])
|
| 715 |
-
|
| 716 |
-
if len(full_input_ids) != len(full_concat_input_ids):
|
| 717 |
-
raise ValueError("Prompt input ids and answer input ids should have the same length.")
|
| 718 |
-
|
| 719 |
-
# On some tokenizers, like Llama-2 tokenizer, there are occasions where tokens
|
| 720 |
-
# can be merged together when tokenizing prompt+answer. This could result
|
| 721 |
-
# on the last token from the prompt being different when tokenized on its own
|
| 722 |
-
# vs when done as prompt+answer.
|
| 723 |
-
response_token_ids_start_idx = len(prompt_input_ids)
|
| 724 |
-
|
| 725 |
-
# If tokenized prompt is different than both prompt+answer, then it means the
|
| 726 |
-
# last token has changed due to merging.
|
| 727 |
-
if prompt_input_ids != full_tokenized["input_ids"][:response_token_ids_start_idx]:
|
| 728 |
-
response_token_ids_start_idx -= 1
|
| 729 |
-
|
| 730 |
-
prompt_input_ids = full_tokenized["input_ids"][:response_token_ids_start_idx]
|
| 731 |
-
prompt_attention_mask = full_tokenized["attention_mask"][:response_token_ids_start_idx]
|
| 732 |
-
|
| 733 |
-
if len(prompt_input_ids) != len(prompt_attention_mask):
|
| 734 |
-
raise ValueError("Prompt input ids and attention mask should have the same length.")
|
| 735 |
-
|
| 736 |
-
answer_input_ids = full_tokenized["input_ids"][response_token_ids_start_idx:]
|
| 737 |
-
answer_attention_mask = full_tokenized["attention_mask"][response_token_ids_start_idx:]
|
| 738 |
-
|
| 739 |
-
return dict(
|
| 740 |
-
prompt_input_ids=prompt_input_ids,
|
| 741 |
-
prompt_attention_mask=prompt_attention_mask,
|
| 742 |
-
input_ids=answer_input_ids,
|
| 743 |
-
attention_mask=answer_attention_mask,
|
| 744 |
-
)
|
| 745 |
-
|
| 746 |
-
def tokenize_row(self, feature, model: Optional[Union[PreTrainedModel, nn.Module]] = None) -> dict:
|
| 747 |
-
"""Tokenize a single row from a CPO specific dataset.
|
| 748 |
-
|
| 749 |
-
At this stage, we don't convert to PyTorch tensors yet; we just handle the truncation
|
| 750 |
-
in case the prompt + chosen or prompt + rejected responses is/are too long. First
|
| 751 |
-
we truncate the prompt; if we're still too long, we truncate the chosen/rejected.
|
| 752 |
-
|
| 753 |
-
We also create the labels for the chosen/rejected responses, which are of length equal to
|
| 754 |
-
the sum of the length of the prompt and the chosen/rejected response, with
|
| 755 |
-
label_pad_token_id for the prompt tokens.
|
| 756 |
-
"""
|
| 757 |
-
batch = {}
|
| 758 |
-
prompt = feature["prompt"]
|
| 759 |
-
chosen = feature["chosen"]
|
| 760 |
-
rejected = feature["rejected"]
|
| 761 |
-
|
| 762 |
-
if not self.is_encoder_decoder:
|
| 763 |
-
# Check issues below for more details
|
| 764 |
-
# 1. https://github.com/huggingface/trl/issues/907
|
| 765 |
-
# 2. https://github.com/EleutherAI/lm-evaluation-harness/pull/531#issuecomment-1595586257
|
| 766 |
-
# 3. https://github.com/LianjiaTech/BELLE/issues/337
|
| 767 |
-
|
| 768 |
-
if not isinstance(prompt, str):
|
| 769 |
-
raise ValueError(f"prompt should be an str but got {type(prompt)}")
|
| 770 |
-
prompt_tokens = self.processing_class(prompt, add_special_tokens=False)
|
| 771 |
-
prompt_tokens = {f"prompt_{k}": v for k, v in prompt_tokens.items()}
|
| 772 |
-
|
| 773 |
-
if not isinstance(chosen, str):
|
| 774 |
-
raise ValueError(f"chosen should be an str but got {type(chosen)}")
|
| 775 |
-
chosen_tokens = self.build_tokenized_answer(prompt, chosen)
|
| 776 |
-
|
| 777 |
-
if not isinstance(rejected, str):
|
| 778 |
-
raise ValueError(f"rejected should be an str but got {type(rejected)}")
|
| 779 |
-
rejected_tokens = self.build_tokenized_answer(prompt, rejected)
|
| 780 |
-
|
| 781 |
-
# Last prompt token might get merged by tokenizer and
|
| 782 |
-
# it should not be included for generation if that happens
|
| 783 |
-
prompt_len_input_ids = len(prompt_tokens["prompt_input_ids"])
|
| 784 |
-
|
| 785 |
-
chosen_prompt_len_input_ids = len(chosen_tokens["prompt_input_ids"])
|
| 786 |
-
rejected_prompt_len_input_ids = len(rejected_tokens["prompt_input_ids"])
|
| 787 |
-
prompt_len_input_ids = min(chosen_prompt_len_input_ids, rejected_prompt_len_input_ids)
|
| 788 |
-
|
| 789 |
-
for k, v in prompt_tokens.items():
|
| 790 |
-
prompt_tokens[k] = v[:prompt_len_input_ids]
|
| 791 |
-
|
| 792 |
-
# Make sure prompts only have one different token at most an
|
| 793 |
-
# and length only differs by 1 at most
|
| 794 |
-
num_diff_tokens = sum(
|
| 795 |
-
[a != b for a, b in zip(chosen_tokens["prompt_input_ids"], rejected_tokens["prompt_input_ids"])]
|
| 796 |
-
)
|
| 797 |
-
num_diff_len = abs(chosen_prompt_len_input_ids - rejected_prompt_len_input_ids)
|
| 798 |
-
if num_diff_tokens > 1 or num_diff_len > 1:
|
| 799 |
-
raise ValueError(
|
| 800 |
-
"Chosen and rejected prompt_input_ids might only differ on the "
|
| 801 |
-
"last token due to tokenizer merge ops."
|
| 802 |
-
)
|
| 803 |
-
|
| 804 |
-
# add BOS token to head of prompt. Avoid adding if it's already there
|
| 805 |
-
prompt_tokens, chosen_tokens, rejected_tokens = add_bos_token_if_needed(
|
| 806 |
-
self.processing_class.bos_token_id,
|
| 807 |
-
prompt_len_input_ids,
|
| 808 |
-
prompt_tokens,
|
| 809 |
-
chosen_prompt_len_input_ids,
|
| 810 |
-
chosen_tokens,
|
| 811 |
-
rejected_prompt_len_input_ids,
|
| 812 |
-
rejected_tokens,
|
| 813 |
-
)
|
| 814 |
-
|
| 815 |
-
# add EOS token to end of answer. Avoid adding if it's already there
|
| 816 |
-
chosen_tokens, rejected_tokens = add_eos_token_if_needed(
|
| 817 |
-
self.processing_class.eos_token_id, chosen_tokens, rejected_tokens
|
| 818 |
-
)
|
| 819 |
-
|
| 820 |
-
longer_response_length = max(len(chosen_tokens["input_ids"]), len(rejected_tokens["input_ids"]))
|
| 821 |
-
|
| 822 |
-
# if combined sequence is too long, truncate the prompt
|
| 823 |
-
for answer_tokens in [chosen_tokens, rejected_tokens, prompt_tokens]:
|
| 824 |
-
if len(answer_tokens["prompt_input_ids"]) + longer_response_length > self.max_length:
|
| 825 |
-
if self.truncation_mode == "keep_start":
|
| 826 |
-
for k in ["prompt_input_ids", "prompt_attention_mask"]:
|
| 827 |
-
answer_tokens[k] = answer_tokens[k][: self.max_prompt_length]
|
| 828 |
-
elif self.truncation_mode == "keep_end":
|
| 829 |
-
for k in ["prompt_input_ids", "prompt_attention_mask"]:
|
| 830 |
-
answer_tokens[k] = answer_tokens[k][-self.max_prompt_length :]
|
| 831 |
-
else:
|
| 832 |
-
raise ValueError(f"Unknown truncation mode: {self.truncation_mode}")
|
| 833 |
-
|
| 834 |
-
# if that's still too long, truncate the response
|
| 835 |
-
for answer_tokens in [chosen_tokens, rejected_tokens]:
|
| 836 |
-
if len(answer_tokens["prompt_input_ids"]) + longer_response_length > self.max_length:
|
| 837 |
-
for k in ["input_ids", "attention_mask"]:
|
| 838 |
-
answer_tokens[k] = answer_tokens[k][: self.max_length - self.max_prompt_length]
|
| 839 |
-
|
| 840 |
-
# Create labels
|
| 841 |
-
chosen_sequence_tokens = {
|
| 842 |
-
k: chosen_tokens[f"prompt_{k}"] + chosen_tokens[k] for k in ["input_ids", "attention_mask"]
|
| 843 |
-
}
|
| 844 |
-
rejected_sequence_tokens = {
|
| 845 |
-
k: rejected_tokens[f"prompt_{k}"] + rejected_tokens[k] for k in ["input_ids", "attention_mask"]
|
| 846 |
-
}
|
| 847 |
-
chosen_sequence_tokens["labels"] = chosen_sequence_tokens["input_ids"][:]
|
| 848 |
-
chosen_sequence_tokens["labels"][: len(chosen_tokens["prompt_input_ids"])] = [
|
| 849 |
-
self.label_pad_token_id
|
| 850 |
-
] * len(chosen_tokens["prompt_input_ids"])
|
| 851 |
-
rejected_sequence_tokens["labels"] = rejected_sequence_tokens["input_ids"][:]
|
| 852 |
-
rejected_sequence_tokens["labels"][: len(rejected_tokens["prompt_input_ids"])] = [
|
| 853 |
-
self.label_pad_token_id
|
| 854 |
-
] * len(rejected_tokens["prompt_input_ids"])
|
| 855 |
-
|
| 856 |
-
for k, toks in {
|
| 857 |
-
"chosen_": chosen_sequence_tokens,
|
| 858 |
-
"rejected_": rejected_sequence_tokens,
|
| 859 |
-
"": prompt_tokens,
|
| 860 |
-
}.items():
|
| 861 |
-
for type_key, tokens in toks.items():
|
| 862 |
-
if type_key == "token_type_ids":
|
| 863 |
-
continue
|
| 864 |
-
batch[f"{k}{type_key}"] = tokens
|
| 865 |
-
|
| 866 |
-
else:
|
| 867 |
-
chosen_tokens = self.processing_class(
|
| 868 |
-
chosen, truncation=True, max_length=self.max_completion_length, add_special_tokens=True
|
| 869 |
-
)
|
| 870 |
-
rejected_tokens = self.processing_class(
|
| 871 |
-
rejected, truncation=True, max_length=self.max_completion_length, add_special_tokens=True
|
| 872 |
-
)
|
| 873 |
-
prompt_tokens = self.processing_class(
|
| 874 |
-
prompt, truncation=True, max_length=self.max_prompt_length, add_special_tokens=True
|
| 875 |
-
)
|
| 876 |
-
|
| 877 |
-
batch["chosen_labels"] = chosen_tokens["input_ids"]
|
| 878 |
-
batch["rejected_labels"] = rejected_tokens["input_ids"]
|
| 879 |
-
batch["prompt_input_ids"] = prompt_tokens["input_ids"]
|
| 880 |
-
batch["prompt_attention_mask"] = prompt_tokens["attention_mask"]
|
| 881 |
-
|
| 882 |
-
if model is not None and hasattr(model, "prepare_decoder_input_ids_from_labels"):
|
| 883 |
-
batch["rejected_decoder_input_ids"] = model.prepare_decoder_input_ids_from_labels(
|
| 884 |
-
labels=torch.tensor(batch["rejected_labels"])
|
| 885 |
-
)
|
| 886 |
-
batch["chosen_decoder_input_ids"] = model.prepare_decoder_input_ids_from_labels(
|
| 887 |
-
labels=torch.tensor(batch["chosen_labels"])
|
| 888 |
-
)
|
| 889 |
-
|
| 890 |
-
return batch
|
| 891 |
-
|
| 892 |
-
@staticmethod
|
| 893 |
-
def concatenated_inputs(
|
| 894 |
-
batch: dict[str, Union[list, torch.LongTensor]],
|
| 895 |
-
is_encoder_decoder: bool = False,
|
| 896 |
-
label_pad_token_id: int = -100,
|
| 897 |
-
padding_value: int = 0,
|
| 898 |
-
device: Optional[torch.device] = None,
|
| 899 |
-
) -> dict[str, torch.LongTensor]:
|
| 900 |
-
"""Concatenate the chosen and rejected inputs into a single tensor.
|
| 901 |
-
|
| 902 |
-
Args:
|
| 903 |
-
batch: A batch of data. Must contain the keys 'chosen_input_ids' and 'rejected_input_ids', which are tensors of shape (batch_size, sequence_length).
|
| 904 |
-
is_encoder_decoder: Whether the model is an encoder-decoder model.
|
| 905 |
-
label_pad_token_id: The label pad token id.
|
| 906 |
-
padding_value: The padding value to use for the concatenated inputs_ids.
|
| 907 |
-
device: The device for the concatenated inputs.
|
| 908 |
-
|
| 909 |
-
Returns:
|
| 910 |
-
A dictionary containing the concatenated inputs under the key 'concatenated_input_ids'.
|
| 911 |
-
"""
|
| 912 |
-
concatenated_batch = {}
|
| 913 |
-
|
| 914 |
-
if is_encoder_decoder:
|
| 915 |
-
max_length = max(batch["chosen_labels"].shape[1], batch["rejected_labels"].shape[1])
|
| 916 |
-
else:
|
| 917 |
-
max_length = max(batch["chosen_input_ids"].shape[1], batch["rejected_input_ids"].shape[1])
|
| 918 |
-
|
| 919 |
-
for k in batch:
|
| 920 |
-
if k.startswith("chosen") and isinstance(batch[k], torch.Tensor):
|
| 921 |
-
if "labels" in k or is_encoder_decoder:
|
| 922 |
-
pad_value = label_pad_token_id
|
| 923 |
-
elif k.endswith("_input_ids"):
|
| 924 |
-
pad_value = padding_value
|
| 925 |
-
elif k.endswith("_attention_mask"):
|
| 926 |
-
pad_value = 0
|
| 927 |
-
concatenated_key = k.replace("chosen", "concatenated")
|
| 928 |
-
concatenated_batch[concatenated_key] = pad_to_length(batch[k], max_length, pad_value=pad_value)
|
| 929 |
-
for k in batch:
|
| 930 |
-
if k.startswith("rejected") and isinstance(batch[k], torch.Tensor):
|
| 931 |
-
if "labels" in k or is_encoder_decoder:
|
| 932 |
-
pad_value = label_pad_token_id
|
| 933 |
-
elif k.endswith("_input_ids"):
|
| 934 |
-
pad_value = padding_value
|
| 935 |
-
elif k.endswith("_attention_mask"):
|
| 936 |
-
pad_value = 0
|
| 937 |
-
concatenated_key = k.replace("rejected", "concatenated")
|
| 938 |
-
concatenated_batch[concatenated_key] = torch.cat(
|
| 939 |
-
(
|
| 940 |
-
concatenated_batch[concatenated_key],
|
| 941 |
-
pad_to_length(batch[k], max_length, pad_value=pad_value),
|
| 942 |
-
),
|
| 943 |
-
dim=0,
|
| 944 |
-
).to(device=device)
|
| 945 |
-
|
| 946 |
-
if is_encoder_decoder:
|
| 947 |
-
concatenated_batch["concatenated_input_ids"] = batch["prompt_input_ids"].repeat(2, 1).to(device=device)
|
| 948 |
-
concatenated_batch["concatenated_attention_mask"] = (
|
| 949 |
-
batch["prompt_attention_mask"].repeat(2, 1).to(device=device)
|
| 950 |
-
)
|
| 951 |
-
|
| 952 |
-
return concatenated_batch
|
| 953 |
-
|
| 954 |
-
def cpo_loss(
|
| 955 |
-
self,
|
| 956 |
-
policy_chosen_logps: torch.FloatTensor,
|
| 957 |
-
policy_rejected_logps: torch.FloatTensor,
|
| 958 |
-
) -> tuple[torch.FloatTensor, torch.FloatTensor, torch.FloatTensor]:
|
| 959 |
-
"""Compute the CPO loss for a batch of policy and reference model log probabilities.
|
| 960 |
-
|
| 961 |
-
Args:
|
| 962 |
-
policy_chosen_logps: Log probabilities of the policy model for the chosen responses. Shape: (batch_size,)
|
| 963 |
-
policy_rejected_logps: Log probabilities of the policy model for the rejected responses. Shape: (batch_size,)
|
| 964 |
-
|
| 965 |
-
Returns:
|
| 966 |
-
A tuple of three tensors: (losses, chosen_rewards, rejected_rewards).
|
| 967 |
-
The losses tensor contains the CPO loss for each example in the batch.
|
| 968 |
-
The chosen_rewards and rejected_rewards tensors contain the rewards for the chosen and rejected responses, respectively.
|
| 969 |
-
"""
|
| 970 |
-
logits = (policy_chosen_logps - policy_rejected_logps).to(self.accelerator.device)
|
| 971 |
-
|
| 972 |
-
# The beta is a temperature parameter for the CPO loss, typically something in the range of 0.1 to 0.5.
|
| 973 |
-
# We ignore the reference model as beta -> 0. The label_smoothing parameter encodes our uncertainty about the labels and
|
| 974 |
-
# calculates a conservative CPO loss.
|
| 975 |
-
|
| 976 |
-
if self.loss_type == "simpo":
|
| 977 |
-
gamma_logratios = self.simpo_gamma / self.beta
|
| 978 |
-
logits = logits - gamma_logratios
|
| 979 |
-
# This reduces to Equation 3 from the CPO paper when label_smoothing -> 0.
|
| 980 |
-
losses = (
|
| 981 |
-
-F.logsigmoid(self.beta * logits) * (1 - self.label_smoothing)
|
| 982 |
-
- F.logsigmoid(-self.beta * logits) * self.label_smoothing
|
| 983 |
-
)
|
| 984 |
-
elif self.loss_type == "sigmoid":
|
| 985 |
-
# This reduces to Equation 3 from the CPO paper when label_smoothing -> 0.
|
| 986 |
-
losses = (
|
| 987 |
-
-F.logsigmoid(self.beta * logits) * (1 - self.label_smoothing)
|
| 988 |
-
- F.logsigmoid(-self.beta * logits) * self.label_smoothing
|
| 989 |
-
)
|
| 990 |
-
elif self.loss_type == "hinge":
|
| 991 |
-
losses = torch.relu(1 - self.beta * logits)
|
| 992 |
-
elif self.loss_type == "ipo":
|
| 993 |
-
# eqn (17) of the paper where beta is the regularization parameter for the IPO loss, denoted by tau in the paper.
|
| 994 |
-
losses = (logits - 1 / (2 * self.beta)) ** 2
|
| 995 |
-
else:
|
| 996 |
-
raise ValueError(
|
| 997 |
-
f"Unknown loss type: {self.loss_type}. Should be one of ['sigmoid', 'hinge', 'ipo', 'simpo']"
|
| 998 |
-
)
|
| 999 |
-
|
| 1000 |
-
chosen_rewards = self.beta * (policy_chosen_logps.to(self.accelerator.device)).detach()
|
| 1001 |
-
rejected_rewards = self.beta * (policy_rejected_logps.to(self.accelerator.device)).detach()
|
| 1002 |
-
|
| 1003 |
-
return losses, chosen_rewards, rejected_rewards
|
| 1004 |
-
|
| 1005 |
-
@staticmethod
|
| 1006 |
-
def get_batch_logps(
|
| 1007 |
-
logits: torch.FloatTensor,
|
| 1008 |
-
labels: torch.LongTensor,
|
| 1009 |
-
average_log_prob: bool = False,
|
| 1010 |
-
label_pad_token_id: int = -100,
|
| 1011 |
-
is_encoder_decoder: bool = False,
|
| 1012 |
-
) -> torch.FloatTensor:
|
| 1013 |
-
"""Compute the log probabilities of the given labels under the given logits.
|
| 1014 |
-
|
| 1015 |
-
Args:
|
| 1016 |
-
logits: Logits of the model (unnormalized). Shape: (batch_size, sequence_length, vocab_size)
|
| 1017 |
-
labels: Labels for which to compute the log probabilities. Label tokens with a value of label_pad_token_id are ignored. Shape: (batch_size, sequence_length)
|
| 1018 |
-
average_log_prob: If True, return the average log probability per (non-masked) token. Otherwise, return the sum of the log probabilities of the (non-masked) tokens.
|
| 1019 |
-
label_pad_token_id: The label pad token id.
|
| 1020 |
-
is_encoder_decoder: Whether the model is an encoder-decoder model.
|
| 1021 |
-
|
| 1022 |
-
Returns:
|
| 1023 |
-
A tensor of shape (batch_size,) containing the average/sum log probabilities of the given labels under the given logits.
|
| 1024 |
-
"""
|
| 1025 |
-
if logits.shape[:-1] != labels.shape:
|
| 1026 |
-
raise ValueError("Logits (batch and sequence length dim) and labels must have the same shape.")
|
| 1027 |
-
|
| 1028 |
-
if not is_encoder_decoder:
|
| 1029 |
-
labels = labels[:, 1:].clone()
|
| 1030 |
-
logits = logits[:, :-1, :]
|
| 1031 |
-
loss_mask = labels != label_pad_token_id
|
| 1032 |
-
|
| 1033 |
-
# dummy token; we'll ignore the losses on these tokens later
|
| 1034 |
-
labels[labels == label_pad_token_id] = 0
|
| 1035 |
-
|
| 1036 |
-
per_token_logps = selective_log_softmax(logits, labels)
|
| 1037 |
-
|
| 1038 |
-
if average_log_prob:
|
| 1039 |
-
return (per_token_logps * loss_mask).sum(-1) / loss_mask.sum(-1)
|
| 1040 |
-
else:
|
| 1041 |
-
return (per_token_logps * loss_mask).sum(-1)
|
| 1042 |
-
|
| 1043 |
-
def concatenated_forward(
|
| 1044 |
-
self, model: nn.Module, batch: dict[str, Union[list, torch.LongTensor]]
|
| 1045 |
-
) -> tuple[torch.FloatTensor, torch.FloatTensor, torch.FloatTensor, torch.FloatTensor]:
|
| 1046 |
-
"""Run the given model on the given batch of inputs, concatenating the chosen and rejected inputs together.
|
| 1047 |
-
|
| 1048 |
-
We do this to avoid doing two forward passes, because it's faster for FSDP.
|
| 1049 |
-
"""
|
| 1050 |
-
concatenated_batch = self.concatenated_inputs(
|
| 1051 |
-
batch,
|
| 1052 |
-
is_encoder_decoder=self.is_encoder_decoder,
|
| 1053 |
-
label_pad_token_id=self.label_pad_token_id,
|
| 1054 |
-
padding_value=self.padding_value,
|
| 1055 |
-
device=self.accelerator.device,
|
| 1056 |
-
)
|
| 1057 |
-
len_chosen = batch["chosen_labels"].shape[0]
|
| 1058 |
-
|
| 1059 |
-
model_kwargs = (
|
| 1060 |
-
{
|
| 1061 |
-
"decoder_input_ids": self._shift_right(concatenated_batch["concatenated_labels"]),
|
| 1062 |
-
}
|
| 1063 |
-
if self.is_encoder_decoder
|
| 1064 |
-
else {}
|
| 1065 |
-
)
|
| 1066 |
-
|
| 1067 |
-
if self.aux_loss_enabled:
|
| 1068 |
-
model_kwargs["output_router_logits"] = True
|
| 1069 |
-
|
| 1070 |
-
outputs = model(
|
| 1071 |
-
concatenated_batch["concatenated_input_ids"],
|
| 1072 |
-
attention_mask=concatenated_batch["concatenated_attention_mask"],
|
| 1073 |
-
use_cache=False,
|
| 1074 |
-
**model_kwargs,
|
| 1075 |
-
)
|
| 1076 |
-
all_logits = outputs.logits
|
| 1077 |
-
|
| 1078 |
-
def cross_entropy_loss(logits, labels):
|
| 1079 |
-
if not self.is_encoder_decoder:
|
| 1080 |
-
# Shift so that tokens < n predict n
|
| 1081 |
-
logits = logits[..., :-1, :].contiguous()
|
| 1082 |
-
labels = labels[..., 1:].contiguous()
|
| 1083 |
-
# Flatten the tokens
|
| 1084 |
-
loss_fct = nn.CrossEntropyLoss()
|
| 1085 |
-
logits = logits.view(-1, logits.shape[-1])
|
| 1086 |
-
labels = labels.view(-1)
|
| 1087 |
-
# Enable model parallelism
|
| 1088 |
-
labels = labels.to(logits.device)
|
| 1089 |
-
loss = loss_fct(logits, labels)
|
| 1090 |
-
return loss
|
| 1091 |
-
|
| 1092 |
-
labels = concatenated_batch["concatenated_labels"].clone()
|
| 1093 |
-
|
| 1094 |
-
if self.cpo_alpha == 0:
|
| 1095 |
-
nll_loss = torch.tensor(0.0).to(self.accelerator.device)
|
| 1096 |
-
else:
|
| 1097 |
-
nll_loss = cross_entropy_loss(all_logits[:len_chosen], labels[:len_chosen])
|
| 1098 |
-
|
| 1099 |
-
all_logps = self.get_batch_logps(
|
| 1100 |
-
all_logits,
|
| 1101 |
-
concatenated_batch["concatenated_labels"],
|
| 1102 |
-
average_log_prob=self.loss_type in ["ipo", "simpo"],
|
| 1103 |
-
is_encoder_decoder=self.is_encoder_decoder,
|
| 1104 |
-
label_pad_token_id=self.label_pad_token_id,
|
| 1105 |
-
)
|
| 1106 |
-
|
| 1107 |
-
chosen_logps = all_logps[:len_chosen]
|
| 1108 |
-
rejected_logps = all_logps[len_chosen:]
|
| 1109 |
-
|
| 1110 |
-
chosen_logits = all_logits[:len_chosen]
|
| 1111 |
-
rejected_logits = all_logits[len_chosen:]
|
| 1112 |
-
|
| 1113 |
-
if self.aux_loss_enabled:
|
| 1114 |
-
return (chosen_logps, rejected_logps, chosen_logits, rejected_logits, nll_loss, outputs.aux_loss)
|
| 1115 |
-
|
| 1116 |
-
return (chosen_logps, rejected_logps, chosen_logits, rejected_logits, nll_loss)
|
| 1117 |
-
|
| 1118 |
-
def get_batch_loss_metrics(
|
| 1119 |
-
self,
|
| 1120 |
-
model,
|
| 1121 |
-
batch: dict[str, Union[list, torch.LongTensor]],
|
| 1122 |
-
train_eval: Literal["train", "eval"] = "train",
|
| 1123 |
-
):
|
| 1124 |
-
"""Compute the CPO loss and other metrics for the given batch of inputs for train or test."""
|
| 1125 |
-
metrics = {}
|
| 1126 |
-
|
| 1127 |
-
forward_output = self.concatenated_forward(model, batch)
|
| 1128 |
-
(
|
| 1129 |
-
policy_chosen_logps,
|
| 1130 |
-
policy_rejected_logps,
|
| 1131 |
-
policy_chosen_logits,
|
| 1132 |
-
policy_rejected_logits,
|
| 1133 |
-
policy_nll_loss,
|
| 1134 |
-
) = forward_output[:5]
|
| 1135 |
-
if self.aux_loss_enabled:
|
| 1136 |
-
aux_loss = forward_output[5]
|
| 1137 |
-
|
| 1138 |
-
losses, chosen_rewards, rejected_rewards = self.cpo_loss(
|
| 1139 |
-
policy_chosen_logps,
|
| 1140 |
-
policy_rejected_logps,
|
| 1141 |
-
)
|
| 1142 |
-
|
| 1143 |
-
loss = losses.mean() + self.cpo_alpha * policy_nll_loss
|
| 1144 |
-
reward_accuracies = (chosen_rewards > rejected_rewards).float()
|
| 1145 |
-
|
| 1146 |
-
prefix = "eval_" if train_eval == "eval" else ""
|
| 1147 |
-
metrics[f"{prefix}rewards/chosen"] = self.accelerator.gather_for_metrics(chosen_rewards).mean().item()
|
| 1148 |
-
metrics[f"{prefix}rewards/rejected"] = self.accelerator.gather_for_metrics(rejected_rewards).mean().item()
|
| 1149 |
-
metrics[f"{prefix}rewards/accuracies"] = self.accelerator.gather_for_metrics(reward_accuracies).mean().item()
|
| 1150 |
-
metrics[f"{prefix}rewards/margins"] = (
|
| 1151 |
-
self.accelerator.gather_for_metrics(chosen_rewards - rejected_rewards).mean().item()
|
| 1152 |
-
)
|
| 1153 |
-
metrics[f"{prefix}logps/rejected"] = (
|
| 1154 |
-
self.accelerator.gather_for_metrics(policy_rejected_logps).detach().mean().item()
|
| 1155 |
-
)
|
| 1156 |
-
metrics[f"{prefix}logps/chosen"] = (
|
| 1157 |
-
self.accelerator.gather_for_metrics(policy_chosen_logps).detach().mean().item()
|
| 1158 |
-
)
|
| 1159 |
-
metrics[f"{prefix}logits/rejected"] = (
|
| 1160 |
-
self.accelerator.gather_for_metrics(policy_rejected_logits).detach().mean().item()
|
| 1161 |
-
)
|
| 1162 |
-
metrics[f"{prefix}logits/chosen"] = (
|
| 1163 |
-
self.accelerator.gather_for_metrics(policy_chosen_logits).detach().mean().item()
|
| 1164 |
-
)
|
| 1165 |
-
metrics[f"{prefix}nll_loss"] = self.accelerator.gather_for_metrics(policy_nll_loss).detach().mean().item()
|
| 1166 |
-
|
| 1167 |
-
if self.aux_loss_enabled:
|
| 1168 |
-
loss += self.aux_loss_coef * aux_loss
|
| 1169 |
-
|
| 1170 |
-
return loss, metrics
|
| 1171 |
-
|
| 1172 |
-
def compute_loss(
|
| 1173 |
-
self,
|
| 1174 |
-
model: Union[PreTrainedModel, nn.Module],
|
| 1175 |
-
inputs: dict[str, Union[torch.Tensor, Any]],
|
| 1176 |
-
return_outputs=False,
|
| 1177 |
-
num_items_in_batch=None,
|
| 1178 |
-
) -> Union[torch.Tensor, tuple[torch.Tensor, dict[str, torch.Tensor]]]:
|
| 1179 |
-
compute_loss_context_manager = amp.autocast("cuda") if self._peft_has_been_casted_to_bf16 else nullcontext()
|
| 1180 |
-
|
| 1181 |
-
with compute_loss_context_manager:
|
| 1182 |
-
loss, metrics = self.get_batch_loss_metrics(model, inputs, train_eval="train")
|
| 1183 |
-
|
| 1184 |
-
# force log the metrics
|
| 1185 |
-
self.store_metrics(metrics, train_eval="train")
|
| 1186 |
-
|
| 1187 |
-
if return_outputs:
|
| 1188 |
-
return (loss, metrics)
|
| 1189 |
-
return loss
|
| 1190 |
-
|
| 1191 |
-
def generate_from_model(self, model, batch: dict[str, torch.LongTensor]) -> str:
|
| 1192 |
-
"""Generate samples from the model and reference model for the given batch of inputs."""
|
| 1193 |
-
|
| 1194 |
-
# If one uses `generate_during_eval` with peft + bf16, we need to explicitly call generate with
|
| 1195 |
-
# the torch cuda amp context manager as some hidden states are silently casted to full precision.
|
| 1196 |
-
generate_context_manager = amp.autocast("cuda") if self._peft_has_been_casted_to_bf16 else nullcontext()
|
| 1197 |
-
|
| 1198 |
-
with generate_context_manager:
|
| 1199 |
-
policy_output = model.generate(
|
| 1200 |
-
input_ids=batch["prompt_input_ids"],
|
| 1201 |
-
attention_mask=batch["prompt_attention_mask"],
|
| 1202 |
-
max_length=self.max_length,
|
| 1203 |
-
do_sample=True,
|
| 1204 |
-
pad_token_id=self.processing_class.pad_token_id,
|
| 1205 |
-
)
|
| 1206 |
-
|
| 1207 |
-
policy_output = pad_to_length(policy_output, self.max_length, self.processing_class.pad_token_id)
|
| 1208 |
-
policy_output_decoded = self.processing_class.batch_decode(policy_output, skip_special_tokens=True)
|
| 1209 |
-
|
| 1210 |
-
return policy_output_decoded
|
| 1211 |
-
|
| 1212 |
-
def prediction_step(
|
| 1213 |
-
self,
|
| 1214 |
-
model: Union[PreTrainedModel, nn.Module],
|
| 1215 |
-
inputs: dict[str, Union[torch.Tensor, Any]],
|
| 1216 |
-
prediction_loss_only: bool,
|
| 1217 |
-
ignore_keys: Optional[list[str]] = None,
|
| 1218 |
-
):
|
| 1219 |
-
if ignore_keys is None:
|
| 1220 |
-
if hasattr(model, "config"):
|
| 1221 |
-
ignore_keys = getattr(model.config, "keys_to_ignore_at_inference", [])
|
| 1222 |
-
else:
|
| 1223 |
-
ignore_keys = []
|
| 1224 |
-
|
| 1225 |
-
prediction_context_manager = amp.autocast("cuda") if self._peft_has_been_casted_to_bf16 else nullcontext()
|
| 1226 |
-
|
| 1227 |
-
with torch.no_grad(), prediction_context_manager:
|
| 1228 |
-
loss, metrics = self.get_batch_loss_metrics(model, inputs, train_eval="eval")
|
| 1229 |
-
|
| 1230 |
-
# force log the metrics
|
| 1231 |
-
self.store_metrics(metrics, train_eval="eval")
|
| 1232 |
-
|
| 1233 |
-
if prediction_loss_only:
|
| 1234 |
-
return (loss.detach(), None, None)
|
| 1235 |
-
|
| 1236 |
-
# logits for the chosen and rejected samples from model
|
| 1237 |
-
logits_dict = {
|
| 1238 |
-
"eval_logits/chosen": metrics["eval_logits/chosen"],
|
| 1239 |
-
"eval_logits/rejected": metrics["eval_logits/rejected"],
|
| 1240 |
-
}
|
| 1241 |
-
logits = [v for k, v in logits_dict.items() if k not in ignore_keys]
|
| 1242 |
-
logits = torch.tensor(logits, device=self.accelerator.device)
|
| 1243 |
-
labels = torch.zeros(logits.shape[0], device=self.accelerator.device)
|
| 1244 |
-
|
| 1245 |
-
return (loss.detach(), logits, labels)
|
| 1246 |
-
|
| 1247 |
-
def store_metrics(self, metrics: dict[str, float], train_eval: Literal["train", "eval"] = "train") -> None:
|
| 1248 |
-
for key, value in metrics.items():
|
| 1249 |
-
self._stored_metrics[train_eval][key].append(value)
|
| 1250 |
-
|
| 1251 |
-
def evaluation_loop(
|
| 1252 |
-
self,
|
| 1253 |
-
dataloader: DataLoader,
|
| 1254 |
-
description: str,
|
| 1255 |
-
prediction_loss_only: Optional[bool] = None,
|
| 1256 |
-
ignore_keys: Optional[list[str]] = None,
|
| 1257 |
-
metric_key_prefix: str = "eval",
|
| 1258 |
-
) -> EvalLoopOutput:
|
| 1259 |
-
"""
|
| 1260 |
-
Overriding built-in evaluation loop to store metrics for each batch.
|
| 1261 |
-
Prediction/evaluation loop, shared by `Trainer.evaluate()` and `Trainer.predict()`.
|
| 1262 |
-
|
| 1263 |
-
Works both with or without labels.
|
| 1264 |
-
"""
|
| 1265 |
-
|
| 1266 |
-
# Sample and save to game log if requested (for one batch to save time)
|
| 1267 |
-
if self.generate_during_eval:
|
| 1268 |
-
# Generate random indices within the range of the total number of samples
|
| 1269 |
-
num_samples = len(dataloader.dataset)
|
| 1270 |
-
random_indices = random.sample(range(num_samples), k=self.args.eval_batch_size)
|
| 1271 |
-
|
| 1272 |
-
# Use dataloader.dataset.select to get the random batch without iterating over the DataLoader
|
| 1273 |
-
random_batch_dataset = dataloader.dataset.select(random_indices)
|
| 1274 |
-
random_batch = self.data_collator(random_batch_dataset)
|
| 1275 |
-
random_batch = self._prepare_inputs(random_batch)
|
| 1276 |
-
|
| 1277 |
-
policy_output_decoded = self.generate_from_model(self.model, random_batch)
|
| 1278 |
-
|
| 1279 |
-
table = pd.DataFrame(
|
| 1280 |
-
columns=["Prompt", "Policy"],
|
| 1281 |
-
data=[
|
| 1282 |
-
[prompt, pol[len(prompt) :]] for prompt, pol in zip(random_batch["prompt"], policy_output_decoded)
|
| 1283 |
-
],
|
| 1284 |
-
)
|
| 1285 |
-
if "wandb" in self.args.report_to:
|
| 1286 |
-
wandb.log({"game_log": wandb.Table(data=table)})
|
| 1287 |
-
|
| 1288 |
-
if "comet_ml" in self.args.report_to:
|
| 1289 |
-
log_table_to_comet_experiment(
|
| 1290 |
-
name="game_log.csv",
|
| 1291 |
-
table=table,
|
| 1292 |
-
)
|
| 1293 |
-
|
| 1294 |
-
# Base evaluation
|
| 1295 |
-
initial_output = super().evaluation_loop(
|
| 1296 |
-
dataloader, description, prediction_loss_only, ignore_keys, metric_key_prefix
|
| 1297 |
-
)
|
| 1298 |
-
|
| 1299 |
-
return initial_output
|
| 1300 |
-
|
| 1301 |
-
def log(self, logs: dict[str, float], start_time: Optional[float] = None) -> None:
|
| 1302 |
-
"""
|
| 1303 |
-
Log `logs` on the various objects watching training, including stored metrics.
|
| 1304 |
-
|
| 1305 |
-
Args:
|
| 1306 |
-
logs (`dict[str, float]`):
|
| 1307 |
-
The values to log.
|
| 1308 |
-
start_time (`float` or `None`, *optional*, defaults to `None`):
|
| 1309 |
-
Start time of the training.
|
| 1310 |
-
"""
|
| 1311 |
-
# logs either has 'loss' or 'eval_loss'
|
| 1312 |
-
train_eval = "train" if "loss" in logs else "eval"
|
| 1313 |
-
# Add averaged stored metrics to logs
|
| 1314 |
-
for key, metrics in self._stored_metrics[train_eval].items():
|
| 1315 |
-
logs[key] = torch.tensor(metrics).mean().item()
|
| 1316 |
-
del self._stored_metrics[train_eval]
|
| 1317 |
-
|
| 1318 |
-
if version.parse(transformers.__version__) >= version.parse("4.47.0.dev0"):
|
| 1319 |
-
return super().log(logs, start_time)
|
| 1320 |
-
else: # transformers<=4.46
|
| 1321 |
-
return super().log(logs)
|
| 1322 |
-
|
| 1323 |
-
def _shift_right(self, input_ids):
|
| 1324 |
-
if self.decoder_start_token_id is None:
|
| 1325 |
-
raise ValueError(
|
| 1326 |
-
"model.config.decoder_start_token_id has to be defined. It is usually set to the pad_token_id."
|
| 1327 |
-
)
|
| 1328 |
-
|
| 1329 |
-
# shift inputs to the right
|
| 1330 |
-
if is_torch_fx_proxy(input_ids):
|
| 1331 |
-
# Item assignment is not supported natively for proxies.
|
| 1332 |
-
shifted_input_ids = torch.full(input_ids.shape[:-1] + (1,), self.decoder_start_token_id)
|
| 1333 |
-
shifted_input_ids = torch.cat([shifted_input_ids, input_ids[..., :-1]], dim=-1)
|
| 1334 |
-
else:
|
| 1335 |
-
shifted_input_ids = input_ids.new_zeros(input_ids.shape)
|
| 1336 |
-
shifted_input_ids[..., 1:] = input_ids[..., :-1].clone()
|
| 1337 |
-
shifted_input_ids[..., 0] = self.decoder_start_token_id
|
| 1338 |
-
|
| 1339 |
-
if self.pad_token_id is None:
|
| 1340 |
-
raise ValueError("model.config.pad_token_id has to be defined.")
|
| 1341 |
-
# replace possible -100 values in labels by `pad_token_id`
|
| 1342 |
-
shifted_input_ids.masked_fill_(shifted_input_ids == -100, self.pad_token_id)
|
| 1343 |
-
|
| 1344 |
-
return shifted_input_ids
|
| 1345 |
-
|
| 1346 |
-
def create_model_card(
|
| 1347 |
-
self,
|
| 1348 |
-
model_name: Optional[str] = None,
|
| 1349 |
-
dataset_name: Optional[str] = None,
|
| 1350 |
-
tags: Union[str, list[str], None] = None,
|
| 1351 |
-
):
|
| 1352 |
-
"""
|
| 1353 |
-
Creates a draft of a model card using the information available to the `Trainer`.
|
| 1354 |
-
|
| 1355 |
-
Args:
|
| 1356 |
-
model_name (`str` or `None`, *optional*, defaults to `None`):
|
| 1357 |
-
Name of the model.
|
| 1358 |
-
dataset_name (`str` or `None`, *optional*, defaults to `None`):
|
| 1359 |
-
Name of the dataset used for training.
|
| 1360 |
-
tags (`str`, `list[str]` or `None`, *optional*, defaults to `None`):
|
| 1361 |
-
Tags to be associated with the model card.
|
| 1362 |
-
"""
|
| 1363 |
-
if not self.is_world_process_zero():
|
| 1364 |
-
return
|
| 1365 |
-
|
| 1366 |
-
if hasattr(self.model.config, "_name_or_path") and not os.path.isdir(self.model.config._name_or_path):
|
| 1367 |
-
base_model = self.model.config._name_or_path
|
| 1368 |
-
else:
|
| 1369 |
-
base_model = None
|
| 1370 |
-
|
| 1371 |
-
tags = tags or []
|
| 1372 |
-
if isinstance(tags, str):
|
| 1373 |
-
tags = [tags]
|
| 1374 |
-
|
| 1375 |
-
if hasattr(self.model.config, "unsloth_version"):
|
| 1376 |
-
tags.append("unsloth")
|
| 1377 |
-
|
| 1378 |
-
citation = textwrap.dedent("""\
|
| 1379 |
-
@inproceedings{xu2024contrastive,
|
| 1380 |
-
title = {{Contrastive Preference Optimization: Pushing the Boundaries of LLM Performance in Machine Translation}},
|
| 1381 |
-
author = {Haoran Xu and Amr Sharaf and Yunmo Chen and Weiting Tan and Lingfeng Shen and Benjamin Van Durme and Kenton Murray and Young Jin Kim},
|
| 1382 |
-
year = 2024,
|
| 1383 |
-
booktitle = {Forty-first International Conference on Machine Learning, {ICML} 2024, Vienna, Austria, July 21-27, 2024},
|
| 1384 |
-
publisher = {OpenReview.net},
|
| 1385 |
-
url = {https://openreview.net/forum?id=51iwkioZpn}
|
| 1386 |
-
}""")
|
| 1387 |
-
|
| 1388 |
-
model_card = generate_model_card(
|
| 1389 |
-
base_model=base_model,
|
| 1390 |
-
model_name=model_name,
|
| 1391 |
-
hub_model_id=self.hub_model_id,
|
| 1392 |
-
dataset_name=dataset_name,
|
| 1393 |
-
tags=tags,
|
| 1394 |
-
wandb_url=wandb.run.get_url() if is_wandb_available() and wandb.run is not None else None,
|
| 1395 |
-
comet_url=get_comet_experiment_url(),
|
| 1396 |
-
trainer_name="CPO",
|
| 1397 |
-
trainer_citation=citation,
|
| 1398 |
-
paper_title="Contrastive Preference Optimization: Pushing the Boundaries of LLM Performance in Machine Translation",
|
| 1399 |
-
paper_id="2401.08417",
|
| 1400 |
-
)
|
| 1401 |
-
model_card.save(os.path.join(self.args.output_dir, "README.md"))
|
| 1402 |
-
class UnslothCPOTrainer(_UnslothCPOTrainer):
|
| 1403 |
-
"""
|
| 1404 |
-
|
| 1405 |
-
Initialize CPOTrainer.
|
| 1406 |
-
|
| 1407 |
-
Args:
|
| 1408 |
-
model (`transformers.PreTrainedModel`):
|
| 1409 |
-
The model to train, preferably an `AutoModelForSequenceClassification`.
|
| 1410 |
-
args (`CPOConfig`):
|
| 1411 |
-
The CPO config arguments to use for training.
|
| 1412 |
-
data_collator (`transformers.DataCollator`):
|
| 1413 |
-
The data collator to use for training. If None is specified, the default data collator (`DPODataCollatorWithPadding`) will be used
|
| 1414 |
-
which will pad the sequences to the maximum length of the sequences in the batch, given a dataset of paired sequences.
|
| 1415 |
-
train_dataset (`datasets.Dataset`):
|
| 1416 |
-
The dataset to use for training.
|
| 1417 |
-
eval_dataset (`datasets.Dataset`):
|
| 1418 |
-
The dataset to use for evaluation.
|
| 1419 |
-
processing_class (`PreTrainedTokenizerBase` or `BaseImageProcessor` or `FeatureExtractionMixin` or `ProcessorMixin`, *optional*):
|
| 1420 |
-
Processing class used to process the data. If provided, will be used to automatically process the inputs
|
| 1421 |
-
for the model, and it will be saved along the model to make it easier to rerun an interrupted training or
|
| 1422 |
-
reuse the fine-tuned model.
|
| 1423 |
-
model_init (`Callable[[], transformers.PreTrainedModel]`):
|
| 1424 |
-
The model initializer to use for training. If None is specified, the default model initializer will be used.
|
| 1425 |
-
callbacks (`list[transformers.TrainerCallback]`):
|
| 1426 |
-
The callbacks to use for training.
|
| 1427 |
-
optimizers (`tuple[torch.optim.Optimizer, torch.optim.lr_scheduler.LambdaLR]`):
|
| 1428 |
-
The optimizer and scheduler to use for training.
|
| 1429 |
-
preprocess_logits_for_metrics (`Callable[[torch.Tensor, torch.Tensor], torch.Tensor]`):
|
| 1430 |
-
The function to use to preprocess the logits before computing the metrics.
|
| 1431 |
-
peft_config (`dict`, defaults to `None`):
|
| 1432 |
-
The PEFT configuration to use for training. If you pass a PEFT configuration, the model will be wrapped in a PEFT model.
|
| 1433 |
-
compute_metrics (`Callable[[EvalPrediction], dict]`, *optional*):
|
| 1434 |
-
The function to use to compute the metrics. Must take a `EvalPrediction` and return
|
| 1435 |
-
a dictionary string to metric values.
|
| 1436 |
-
|
| 1437 |
-
"""
|
| 1438 |
-
def __init__(
|
| 1439 |
-
self,
|
| 1440 |
-
model = None,
|
| 1441 |
-
args = None,
|
| 1442 |
-
data_collator = None,
|
| 1443 |
-
train_dataset = None,
|
| 1444 |
-
eval_dataset = None,
|
| 1445 |
-
processing_class = None,
|
| 1446 |
-
model_init = None,
|
| 1447 |
-
callbacks = None,
|
| 1448 |
-
preprocess_logits_for_metrics = None,
|
| 1449 |
-
peft_config = None,
|
| 1450 |
-
compute_metrics = None,
|
| 1451 |
-
**kwargs
|
| 1452 |
-
):
|
| 1453 |
-
if args is None: args = UnslothCPOConfig()
|
| 1454 |
-
use_bf16 = getattr(args, 'bf16', False)
|
| 1455 |
-
if type(use_bf16) is not bool: use_bf16 = False
|
| 1456 |
-
use_fp16 = getattr(args, 'fp16', False)
|
| 1457 |
-
if type(use_fp16) is not bool: use_fp16 = False
|
| 1458 |
-
force_float32 = False
|
| 1459 |
-
if os.environ.get('UNSLOTH_FORCE_FLOAT32', '0') == '1':
|
| 1460 |
-
print('Unsloth: Switching to float32 training since model cannot work with float16')
|
| 1461 |
-
force_float32 = True
|
| 1462 |
-
mixed_precision_dtype = os.environ.get('UNSLOTH_MIXED_PRECISION', 'float32')
|
| 1463 |
-
dtype = getattr(model.config, 'torch_dtype', None)
|
| 1464 |
-
if dtype is None: dtype = model.get_input_embeddings().dtype
|
| 1465 |
-
from unsloth_zoo.utils import _get_dtype
|
| 1466 |
-
dtype = _get_dtype(dtype)
|
| 1467 |
-
float16 = dtype == torch.float16
|
| 1468 |
-
if not force_float32 and (float16 and use_bf16): raise TypeError('Unsloth: Model is in float16 precision but you want to use bfloat16 precision. Set fp16 to `True` and bf16 to `False`')
|
| 1469 |
-
if not force_float32 and (not float16 and use_fp16): raise TypeError('Unsloth: Model is in bfloat16 precision but you want to use float16 precision. Set fp16 to `False` and bf16 to `True`')
|
| 1470 |
-
if force_float32:
|
| 1471 |
-
args.fp16 = False
|
| 1472 |
-
args.bf16 = False
|
| 1473 |
-
os.environ['ACCELERATE_MIXED_PRECISION'] = 'no'
|
| 1474 |
-
elif (not use_bf16 and not use_fp16) and mixed_precision_dtype == 'float32':
|
| 1475 |
-
args.fp16 = float16
|
| 1476 |
-
args.bf16 = not float16
|
| 1477 |
-
os.environ['ACCELERATE_MIXED_PRECISION'] = 'fp16' if float16 else 'bf16'
|
| 1478 |
-
if getattr(args, 'eval_dataset', None) is not None and getattr(args, 'eval_strategy', 'no') == 'no':
|
| 1479 |
-
args.eval_strategy = 'steps'
|
| 1480 |
-
if getattr(args, 'eval_steps', None) is None: args.eval_steps = 0.1
|
| 1481 |
-
ga_steps = getattr(args, 'gradient_accumulation_steps', None)
|
| 1482 |
-
if ga_steps is not None and ga_steps > 1:
|
| 1483 |
-
from transformers import __version__ as transformers_version
|
| 1484 |
-
if Version(transformers_version) <= Version('4.45.2'):
|
| 1485 |
-
print('**** Unsloth: Please use our fixed gradient_accumulation_steps by updating transformers, TRL and Unsloth!\n'
|
| 1486 |
-
'`pip install --upgrade --no-cache-dir --force-reinstall --no-deps unsloth transformers trl unsloth_zoo`')
|
| 1487 |
-
if getattr(args, 'eval_strategy', 'no') != 'no':
|
| 1488 |
-
eval_bsz = getattr(args, 'per_device_eval_batch_size', 8)
|
| 1489 |
-
if eval_bsz == 8 and args.per_device_train_batch_size < eval_bsz: args.per_device_eval_batch_size = args.per_device_train_batch_size
|
| 1490 |
-
if getattr(args, 'eval_accumulation_steps', None) is None and ga_steps is not None: args.eval_accumulation_steps = ga_steps
|
| 1491 |
-
fp16_full_eval = getattr(args, 'fp16_full_eval', False)
|
| 1492 |
-
if type(fp16_full_eval) is not bool: fp16_full_eval = False
|
| 1493 |
-
bf16_full_eval = getattr(args, 'bf16_full_eval', False)
|
| 1494 |
-
if type(bf16_full_eval) is not bool: bf16_full_eval = False
|
| 1495 |
-
if args.fp16 and bf16_full_eval: args.bf16_full_eval = False; args.fp16_full_eval = True
|
| 1496 |
-
if args.bf16 and fp16_full_eval: args.bf16_full_eval = True; args.fp16_full_eval = False
|
| 1497 |
-
if force_float32:
|
| 1498 |
-
args.bf16_full_eval = False
|
| 1499 |
-
args.fp16_full_eval = False
|
| 1500 |
-
elif os.environ.get('UNSLOTH_MIXED_PRECISION', 'float32') == 'bfloat16':
|
| 1501 |
-
args.bf16_full_eval = True
|
| 1502 |
-
args.fp16_full_eval = False
|
| 1503 |
-
elif not bf16_full_eval and not fp16_full_eval:
|
| 1504 |
-
args.bf16_full_eval = args.bf16
|
| 1505 |
-
args.fp16_full_eval = args.fp16
|
| 1506 |
-
_output_logits = False
|
| 1507 |
-
if locals().get('compute_metrics', None) is not None: _output_logits = True
|
| 1508 |
-
if locals().get('preprocess_logits_for_metrics', None) is not None: _output_logits = True
|
| 1509 |
-
if _output_logits:
|
| 1510 |
-
os.environ['UNSLOTH_RETURN_LOGITS'] = '1'
|
| 1511 |
-
if 'max_seq_length' not in locals() and not hasattr(args, 'max_seq_length'):
|
| 1512 |
-
pass
|
| 1513 |
-
else:
|
| 1514 |
-
model_max_seq_length = getattr(model, 'max_seq_length', None)
|
| 1515 |
-
args_max_seq_length = getattr(args, 'max_seq_length', None)
|
| 1516 |
-
if args_max_seq_length is None and model_max_seq_length is not None:
|
| 1517 |
-
max_seq_length = model.max_seq_length
|
| 1518 |
-
if hasattr(args, 'max_seq_length'): args.max_seq_length = max_seq_length
|
| 1519 |
-
if model is not None and hasattr(model, 'for_training'):
|
| 1520 |
-
model.for_training()
|
| 1521 |
-
if 'tokenizer' in locals() and hasattr(tokenizer, 'padding_side'): tokenizer.padding_side = 'right'
|
| 1522 |
-
if 'processing_class' in locals():
|
| 1523 |
-
if hasattr(processing_class, 'padding_side'): processing_class.padding_side = 'right'
|
| 1524 |
-
if hasattr(processing_class, 'tokenizer') and hasattr(processing_class.tokenizer, 'padding_side'): processing_class.tokenizer.padding_side = 'right'
|
| 1525 |
-
__tokenizer = processing_class if 'processing_class' in locals() else tokenizer
|
| 1526 |
-
from unsloth_zoo.vision_utils import UnslothVisionDataCollator
|
| 1527 |
-
if not isinstance(data_collator, UnslothVisionDataCollator):
|
| 1528 |
-
if isinstance(data_collator, DataCollatorForSeq2Seq) and 'labels' not in train_dataset.column_names:
|
| 1529 |
-
data_collator = TransformersDataCollatorForLanguageModeling(__tokenizer, mlm = False, mlm_probability = 0.0)
|
| 1530 |
-
elif isinstance(data_collator, TransformersDataCollatorForLanguageModeling) and 'labels' in train_dataset.column_names:
|
| 1531 |
-
data_collator = DataCollatorForSeq2Seq(__tokenizer)
|
| 1532 |
-
else:
|
| 1533 |
-
if hasattr(args, 'remove_unused_columns'): args.remove_unused_columns = False
|
| 1534 |
-
if hasattr(args, 'dataset_text_field'): args.dataset_text_field = ''
|
| 1535 |
-
if hasattr(args, 'dataset_kwargs'): args.dataset_kwargs = {'skip_prepare_dataset': True}
|
| 1536 |
-
if not isinstance(data_collator, UnslothVisionDataCollator):
|
| 1537 |
-
if not hasattr(__tokenizer, 'pad') and hasattr(__tokenizer, 'tokenizer'):
|
| 1538 |
-
if isinstance(data_collator, DataCollatorForSeq2Seq):
|
| 1539 |
-
data_collator = DataCollatorForSeq2Seq(__tokenizer.tokenizer)
|
| 1540 |
-
else:
|
| 1541 |
-
data_collator = TransformersDataCollatorForLanguageModeling(__tokenizer.tokenizer, mlm = False, mlm_probability = 0.0)
|
| 1542 |
-
other_metrics = []
|
| 1543 |
-
|
| 1544 |
-
from unsloth_zoo.logging_utils import PatchRLStatistics
|
| 1545 |
-
PatchRLStatistics('cpo_trainer', other_metrics)
|
| 1546 |
-
|
| 1547 |
-
super().__init__(
|
| 1548 |
-
model = model,
|
| 1549 |
-
args = args,
|
| 1550 |
-
data_collator = data_collator,
|
| 1551 |
-
train_dataset = train_dataset,
|
| 1552 |
-
eval_dataset = eval_dataset,
|
| 1553 |
-
processing_class = processing_class,
|
| 1554 |
-
model_init = model_init,
|
| 1555 |
-
callbacks = callbacks,
|
| 1556 |
-
preprocess_logits_for_metrics = preprocess_logits_for_metrics,
|
| 1557 |
-
peft_config = peft_config,
|
| 1558 |
-
compute_metrics = compute_metrics,**kwargs)
|
| 1559 |
-
if hasattr(self, 'neftune_hook_handle'):
|
| 1560 |
-
self.neftune_hook_handle.remove()
|
| 1561 |
-
if hasattr(self, 'neftune_hook_handle'): del self.neftune_hook_handle
|
| 1562 |
-
if getattr(args, 'neftune_noise_alpha', None) is not None:
|
| 1563 |
-
model.get_input_embeddings().neftune_noise_alpha = self.neftune_noise_alpha
|
| 1564 |
-
pass
|
| 1565 |
-
|
| 1566 |
-
pass
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
test_run_uploads/UnslothDDPOTrainer.py
DELETED
|
@@ -1,881 +0,0 @@
|
|
| 1 |
-
"""
|
| 2 |
-
2025.7.11
|
| 3 |
-
2025.7.11
|
| 4 |
-
4.54.1
|
| 5 |
-
0.16.1
|
| 6 |
-
__UNSLOTH_VERSIONING__
|
| 7 |
-
"""
|
| 8 |
-
from torch import Tensor
|
| 9 |
-
import torch
|
| 10 |
-
import torch.nn as nn
|
| 11 |
-
from torch.nn import functional as F
|
| 12 |
-
from typing import Any, List, Optional, Tuple, Union, Dict, Set, Callable
|
| 13 |
-
from trl.trainer.ddpo_trainer import (Accelerator, Any, Callable, DDPOConfig, DDPOStableDiffusionPipeline, DDPOTrainer, Optional, PerPromptStatTracker, ProjectConfiguration, PyTorchModelHubMixin, Union, defaultdict, futures, generate_model_card, get_comet_experiment_url, is_wandb_available, logger, os, set_seed, textwrap, torch, wandb, warn)
|
| 14 |
-
|
| 15 |
-
|
| 16 |
-
import os
|
| 17 |
-
from typing import *
|
| 18 |
-
from dataclasses import dataclass, field
|
| 19 |
-
from packaging.version import Version
|
| 20 |
-
import torch
|
| 21 |
-
import numpy as np
|
| 22 |
-
from contextlib import nullcontext
|
| 23 |
-
from torch.nn import functional as F
|
| 24 |
-
from transformers import DataCollatorForSeq2Seq, DataCollatorForLanguageModeling as TransformersDataCollatorForLanguageModeling
|
| 25 |
-
|
| 26 |
-
torch_compile_options = {
|
| 27 |
-
"epilogue_fusion" : True,
|
| 28 |
-
"max_autotune" : False,
|
| 29 |
-
"shape_padding" : True,
|
| 30 |
-
"trace.enabled" : False,
|
| 31 |
-
"triton.cudagraphs" : False,
|
| 32 |
-
}
|
| 33 |
-
|
| 34 |
-
@torch.compile(dynamic = True, fullgraph = True, options = torch_compile_options,)
|
| 35 |
-
def chunked_selective_log_softmax(logits, index):
|
| 36 |
-
# Split into 4 chunks only
|
| 37 |
-
chunked_logits = torch.chunk(logits.reshape(-1, logits.shape[-1]), chunks = 4, dim = 0)
|
| 38 |
-
chunked_index = torch.chunk(index.reshape(-1), chunks = 4, dim = 0)
|
| 39 |
-
all_per_token_logps = []
|
| 40 |
-
# Below loop does the same as selective_log_softmax(chunk_logits, chunk_index)
|
| 41 |
-
for chunk_logits, chunk_index in zip(chunked_logits, chunked_index):
|
| 42 |
-
chunk_logits = chunk_logits.to(torch.float32)
|
| 43 |
-
selected_logits = torch.gather(chunk_logits, dim = -1, index = chunk_index.unsqueeze(-1)).squeeze(-1)
|
| 44 |
-
logsumexp_values = torch.logsumexp(chunk_logits, dim = -1)
|
| 45 |
-
per_token_logps = selected_logits - logsumexp_values
|
| 46 |
-
all_per_token_logps.append(per_token_logps)
|
| 47 |
-
pass
|
| 48 |
-
all_per_token_logps = torch.concat(all_per_token_logps)
|
| 49 |
-
all_per_token_logps = all_per_token_logps.reshape((logits.shape[0], logits.shape[1]))
|
| 50 |
-
return all_per_token_logps
|
| 51 |
-
@dataclass
|
| 52 |
-
class UnslothDDPOConfig(DDPOConfig):
|
| 53 |
-
"""
|
| 54 |
-
|
| 55 |
-
Configuration class for the [`DDPOTrainer`].
|
| 56 |
-
|
| 57 |
-
Using [`~transformers.HfArgumentParser`] we can turn this class into
|
| 58 |
-
[argparse](https://docs.python.org/3/library/argparse#module-argparse) arguments that can be specified on the
|
| 59 |
-
command line.
|
| 60 |
-
|
| 61 |
-
Parameters:
|
| 62 |
-
exp_name (`str`, *optional*, defaults to `os.path.basename(sys.argv[0])[: -len(".py")]`):
|
| 63 |
-
Name of this experiment (by default is the file name without the extension name).
|
| 64 |
-
run_name (`str`, *optional*, defaults to `""`):
|
| 65 |
-
Name of this run.
|
| 66 |
-
seed (`int`, *optional*, defaults to `0`):
|
| 67 |
-
Random seed.
|
| 68 |
-
log_with (`Literal["wandb", "tensorboard"]]` or `None`, *optional*, defaults to `None`):
|
| 69 |
-
Log with either 'wandb' or 'tensorboard', check
|
| 70 |
-
https://huggingface.co/docs/accelerate/usage_guides/tracking for more details.
|
| 71 |
-
tracker_kwargs (`Dict`, *optional*, defaults to `{}`):
|
| 72 |
-
Keyword arguments for the tracker (e.g. wandb_project).
|
| 73 |
-
accelerator_kwargs (`Dict`, *optional*, defaults to `{}`):
|
| 74 |
-
Keyword arguments for the accelerator.
|
| 75 |
-
project_kwargs (`Dict`, *optional*, defaults to `{}`):
|
| 76 |
-
Keyword arguments for the accelerator project config (e.g. `logging_dir`).
|
| 77 |
-
tracker_project_name (`str`, *optional*, defaults to `"trl"`):
|
| 78 |
-
Name of project to use for tracking.
|
| 79 |
-
logdir (`str`, *optional*, defaults to `"logs"`):
|
| 80 |
-
Top-level logging directory for checkpoint saving.
|
| 81 |
-
num_epochs (`int`, *optional*, defaults to `100`):
|
| 82 |
-
Number of epochs to train.
|
| 83 |
-
save_freq (`int`, *optional*, defaults to `1`):
|
| 84 |
-
Number of epochs between saving model checkpoints.
|
| 85 |
-
num_checkpoint_limit (`int`, *optional*, defaults to `5`):
|
| 86 |
-
Number of checkpoints to keep before overwriting old ones.
|
| 87 |
-
mixed_precision (`str`, *optional*, defaults to `"fp16"`):
|
| 88 |
-
Mixed precision training.
|
| 89 |
-
allow_tf32 (`bool`, *optional*, defaults to `True`):
|
| 90 |
-
Allow `tf32` on Ampere GPUs.
|
| 91 |
-
resume_from (`str`, *optional*, defaults to `""`):
|
| 92 |
-
Resume training from a checkpoint.
|
| 93 |
-
sample_num_steps (`int`, *optional*, defaults to `50`):
|
| 94 |
-
Number of sampler inference steps.
|
| 95 |
-
sample_eta (`float`, *optional*, defaults to `1.0`):
|
| 96 |
-
Eta parameter for the DDIM sampler.
|
| 97 |
-
sample_guidance_scale (`float`, *optional*, defaults to `5.0`):
|
| 98 |
-
Classifier-free guidance weight.
|
| 99 |
-
sample_batch_size (`int`, *optional*, defaults to `1`):
|
| 100 |
-
Batch size (per GPU) to use for sampling.
|
| 101 |
-
sample_num_batches_per_epoch (`int`, *optional*, defaults to `2`):
|
| 102 |
-
Number of batches to sample per epoch.
|
| 103 |
-
train_batch_size (`int`, *optional*, defaults to `1`):
|
| 104 |
-
Batch size (per GPU) to use for training.
|
| 105 |
-
train_use_8bit_adam (`bool`, *optional*, defaults to `False`):
|
| 106 |
-
Use 8bit Adam optimizer from bitsandbytes.
|
| 107 |
-
train_learning_rate (`float`, *optional*, defaults to `3e-4`):
|
| 108 |
-
Learning rate.
|
| 109 |
-
train_adam_beta1 (`float`, *optional*, defaults to `0.9`):
|
| 110 |
-
Adam beta1.
|
| 111 |
-
train_adam_beta2 (`float`, *optional*, defaults to `0.999`):
|
| 112 |
-
Adam beta2.
|
| 113 |
-
train_adam_weight_decay (`float`, *optional*, defaults to `1e-4`):
|
| 114 |
-
Adam weight decay.
|
| 115 |
-
train_adam_epsilon (`float`, *optional*, defaults to `1e-8`):
|
| 116 |
-
Adam epsilon.
|
| 117 |
-
train_gradient_accumulation_steps (`int`, *optional*, defaults to `1`):
|
| 118 |
-
Number of gradient accumulation steps.
|
| 119 |
-
train_max_grad_norm (`float`, *optional*, defaults to `1.0`):
|
| 120 |
-
Maximum gradient norm for gradient clipping.
|
| 121 |
-
train_num_inner_epochs (`int`, *optional*, defaults to `1`):
|
| 122 |
-
Number of inner epochs per outer epoch.
|
| 123 |
-
train_cfg (`bool`, *optional*, defaults to `True`):
|
| 124 |
-
Whether to use classifier-free guidance during training.
|
| 125 |
-
train_adv_clip_max (`float`, *optional*, defaults to `5.0`):
|
| 126 |
-
Clip advantages to the range.
|
| 127 |
-
train_clip_range (`float`, *optional*, defaults to `1e-4`):
|
| 128 |
-
PPO clip range.
|
| 129 |
-
train_timestep_fraction (`float`, *optional*, defaults to `1.0`):
|
| 130 |
-
Fraction of timesteps to train on.
|
| 131 |
-
per_prompt_stat_tracking (`bool`, *optional*, defaults to `False`):
|
| 132 |
-
Whether to track statistics for each prompt separately.
|
| 133 |
-
per_prompt_stat_tracking_buffer_size (`int`, *optional*, defaults to `16`):
|
| 134 |
-
Number of reward values to store in the buffer for each prompt.
|
| 135 |
-
per_prompt_stat_tracking_min_count (`int`, *optional*, defaults to `16`):
|
| 136 |
-
Minimum number of reward values to store in the buffer.
|
| 137 |
-
async_reward_computation (`bool`, *optional*, defaults to `False`):
|
| 138 |
-
Whether to compute rewards asynchronously.
|
| 139 |
-
max_workers (`int`, *optional*, defaults to `2`):
|
| 140 |
-
Maximum number of workers to use for async reward computation.
|
| 141 |
-
negative_prompts (`str`, *optional*, defaults to `""`):
|
| 142 |
-
Comma-separated list of prompts to use as negative examples.
|
| 143 |
-
push_to_hub (`bool`, *optional*, defaults to `False`):
|
| 144 |
-
Whether to push the final model checkpoint to the Hub.
|
| 145 |
-
|
| 146 |
-
"""
|
| 147 |
-
vllm_sampling_params: Optional[Any] = field(
|
| 148 |
-
default = None,
|
| 149 |
-
metadata = {'help': 'vLLM SamplingParams'},
|
| 150 |
-
)
|
| 151 |
-
unsloth_num_chunks : Optional[int] = field(
|
| 152 |
-
default = -1,
|
| 153 |
-
metadata = {'help': 'Chunk size to reduce memory usage. -1 is most efficient.'},
|
| 154 |
-
)
|
| 155 |
-
def __init__(
|
| 156 |
-
self,
|
| 157 |
-
exp_name = 'colab_kernel_launcher',
|
| 158 |
-
run_name = '',
|
| 159 |
-
seed = 3407,
|
| 160 |
-
log_with = None,
|
| 161 |
-
tracker_project_name = 'trl',
|
| 162 |
-
logdir = 'logs',
|
| 163 |
-
num_epochs = 100,
|
| 164 |
-
save_freq = 1,
|
| 165 |
-
num_checkpoint_limit = 5,
|
| 166 |
-
mixed_precision = 'fp16',
|
| 167 |
-
allow_tf32 = True,
|
| 168 |
-
resume_from = '',
|
| 169 |
-
sample_num_steps = 50,
|
| 170 |
-
sample_eta = 1.0,
|
| 171 |
-
sample_guidance_scale = 5.0,
|
| 172 |
-
sample_batch_size = 1,
|
| 173 |
-
sample_num_batches_per_epoch = 2,
|
| 174 |
-
train_batch_size = 1,
|
| 175 |
-
train_use_8bit_adam = False,
|
| 176 |
-
train_learning_rate = 5e-05,
|
| 177 |
-
train_adam_beta1 = 0.9,
|
| 178 |
-
train_adam_beta2 = 0.999,
|
| 179 |
-
train_adam_weight_decay = 0.01,
|
| 180 |
-
train_adam_epsilon = 1e-08,
|
| 181 |
-
train_gradient_accumulation_steps = 2,
|
| 182 |
-
train_max_grad_norm = 1.0,
|
| 183 |
-
train_num_inner_epochs = 1,
|
| 184 |
-
train_cfg = True,
|
| 185 |
-
train_adv_clip_max = 5.0,
|
| 186 |
-
train_clip_range = 0.0001,
|
| 187 |
-
train_timestep_fraction = 1.0,
|
| 188 |
-
per_prompt_stat_tracking = False,
|
| 189 |
-
per_prompt_stat_tracking_buffer_size = 16,
|
| 190 |
-
per_prompt_stat_tracking_min_count = 16,
|
| 191 |
-
async_reward_computation = False,
|
| 192 |
-
max_workers = 2,
|
| 193 |
-
negative_prompts = '',
|
| 194 |
-
push_to_hub = False,
|
| 195 |
-
vllm_sampling_params = None,
|
| 196 |
-
unsloth_num_chunks = -1,
|
| 197 |
-
**kwargs,
|
| 198 |
-
):
|
| 199 |
-
|
| 200 |
-
super().__init__(
|
| 201 |
-
exp_name = exp_name,
|
| 202 |
-
run_name = run_name,
|
| 203 |
-
seed = seed,
|
| 204 |
-
log_with = log_with,
|
| 205 |
-
tracker_project_name = tracker_project_name,
|
| 206 |
-
logdir = logdir,
|
| 207 |
-
num_epochs = num_epochs,
|
| 208 |
-
save_freq = save_freq,
|
| 209 |
-
num_checkpoint_limit = num_checkpoint_limit,
|
| 210 |
-
mixed_precision = mixed_precision,
|
| 211 |
-
allow_tf32 = allow_tf32,
|
| 212 |
-
resume_from = resume_from,
|
| 213 |
-
sample_num_steps = sample_num_steps,
|
| 214 |
-
sample_eta = sample_eta,
|
| 215 |
-
sample_guidance_scale = sample_guidance_scale,
|
| 216 |
-
sample_batch_size = sample_batch_size,
|
| 217 |
-
sample_num_batches_per_epoch = sample_num_batches_per_epoch,
|
| 218 |
-
train_batch_size = train_batch_size,
|
| 219 |
-
train_use_8bit_adam = train_use_8bit_adam,
|
| 220 |
-
train_learning_rate = train_learning_rate,
|
| 221 |
-
train_adam_beta1 = train_adam_beta1,
|
| 222 |
-
train_adam_beta2 = train_adam_beta2,
|
| 223 |
-
train_adam_weight_decay = train_adam_weight_decay,
|
| 224 |
-
train_adam_epsilon = train_adam_epsilon,
|
| 225 |
-
train_gradient_accumulation_steps = train_gradient_accumulation_steps,
|
| 226 |
-
train_max_grad_norm = train_max_grad_norm,
|
| 227 |
-
train_num_inner_epochs = train_num_inner_epochs,
|
| 228 |
-
train_cfg = train_cfg,
|
| 229 |
-
train_adv_clip_max = train_adv_clip_max,
|
| 230 |
-
train_clip_range = train_clip_range,
|
| 231 |
-
train_timestep_fraction = train_timestep_fraction,
|
| 232 |
-
per_prompt_stat_tracking = per_prompt_stat_tracking,
|
| 233 |
-
per_prompt_stat_tracking_buffer_size = per_prompt_stat_tracking_buffer_size,
|
| 234 |
-
per_prompt_stat_tracking_min_count = per_prompt_stat_tracking_min_count,
|
| 235 |
-
async_reward_computation = async_reward_computation,
|
| 236 |
-
max_workers = max_workers,
|
| 237 |
-
negative_prompts = negative_prompts,
|
| 238 |
-
push_to_hub = push_to_hub,**kwargs)
|
| 239 |
-
self.vllm_sampling_params = vllm_sampling_params
|
| 240 |
-
self.unsloth_num_chunks = unsloth_num_chunks
|
| 241 |
-
pass
|
| 242 |
-
|
| 243 |
-
class _UnslothDDPOTrainer(PyTorchModelHubMixin):
|
| 244 |
-
""""""
|
| 245 |
-
|
| 246 |
-
_tag_names = ["trl", "ddpo"]
|
| 247 |
-
|
| 248 |
-
def __init__(
|
| 249 |
-
self,
|
| 250 |
-
config: DDPOConfig,
|
| 251 |
-
reward_function: Callable[[torch.Tensor, tuple[str], tuple[Any]], torch.Tensor],
|
| 252 |
-
prompt_function: Callable[[], tuple[str, Any]],
|
| 253 |
-
sd_pipeline: DDPOStableDiffusionPipeline,
|
| 254 |
-
image_samples_hook: Optional[Callable[[Any, Any, Any], Any]] = None,
|
| 255 |
-
):
|
| 256 |
-
if image_samples_hook is None:
|
| 257 |
-
warn("No image_samples_hook provided; no images will be logged")
|
| 258 |
-
|
| 259 |
-
self.prompt_fn = prompt_function
|
| 260 |
-
self.reward_fn = reward_function
|
| 261 |
-
self.config = config
|
| 262 |
-
self.image_samples_callback = image_samples_hook
|
| 263 |
-
|
| 264 |
-
accelerator_project_config = ProjectConfiguration(**self.config.project_kwargs)
|
| 265 |
-
|
| 266 |
-
if self.config.resume_from:
|
| 267 |
-
self.config.resume_from = os.path.normpath(os.path.expanduser(self.config.resume_from))
|
| 268 |
-
if "checkpoint_" not in os.path.basename(self.config.resume_from):
|
| 269 |
-
# get the most recent checkpoint in this directory
|
| 270 |
-
checkpoints = list(
|
| 271 |
-
filter(
|
| 272 |
-
lambda x: "checkpoint_" in x,
|
| 273 |
-
os.listdir(self.config.resume_from),
|
| 274 |
-
)
|
| 275 |
-
)
|
| 276 |
-
if len(checkpoints) == 0:
|
| 277 |
-
raise ValueError(f"No checkpoints found in {self.config.resume_from}")
|
| 278 |
-
checkpoint_numbers = sorted([int(x.split("_")[-1]) for x in checkpoints])
|
| 279 |
-
self.config.resume_from = os.path.join(
|
| 280 |
-
self.config.resume_from,
|
| 281 |
-
f"checkpoint_{checkpoint_numbers[-1]}",
|
| 282 |
-
)
|
| 283 |
-
|
| 284 |
-
accelerator_project_config.iteration = checkpoint_numbers[-1] + 1
|
| 285 |
-
|
| 286 |
-
# number of timesteps within each trajectory to train on
|
| 287 |
-
self.num_train_timesteps = int(self.config.sample_num_steps * self.config.train_timestep_fraction)
|
| 288 |
-
|
| 289 |
-
self.accelerator = Accelerator(
|
| 290 |
-
log_with=self.config.log_with,
|
| 291 |
-
mixed_precision=self.config.mixed_precision,
|
| 292 |
-
project_config=accelerator_project_config,
|
| 293 |
-
# we always accumulate gradients across timesteps; we want config.train.gradient_accumulation_steps to be the
|
| 294 |
-
# number of *samples* we accumulate across, so we need to multiply by the number of training timesteps to get
|
| 295 |
-
# the total number of optimizer steps to accumulate across.
|
| 296 |
-
gradient_accumulation_steps=self.config.train_gradient_accumulation_steps * self.num_train_timesteps,
|
| 297 |
-
**self.config.accelerator_kwargs,
|
| 298 |
-
)
|
| 299 |
-
|
| 300 |
-
is_okay, message = self._config_check()
|
| 301 |
-
if not is_okay:
|
| 302 |
-
raise ValueError(message)
|
| 303 |
-
|
| 304 |
-
is_using_tensorboard = config.log_with is not None and config.log_with == "tensorboard"
|
| 305 |
-
|
| 306 |
-
if self.accelerator.is_main_process:
|
| 307 |
-
self.accelerator.init_trackers(
|
| 308 |
-
self.config.tracker_project_name,
|
| 309 |
-
config=dict(ddpo_trainer_config=config.to_dict()) if not is_using_tensorboard else config.to_dict(),
|
| 310 |
-
init_kwargs=self.config.tracker_kwargs,
|
| 311 |
-
)
|
| 312 |
-
|
| 313 |
-
logger.info(f"\n{config}")
|
| 314 |
-
|
| 315 |
-
set_seed(self.config.seed, device_specific=True)
|
| 316 |
-
|
| 317 |
-
self.sd_pipeline = sd_pipeline
|
| 318 |
-
|
| 319 |
-
self.sd_pipeline.set_progress_bar_config(
|
| 320 |
-
position=1,
|
| 321 |
-
disable=not self.accelerator.is_local_main_process,
|
| 322 |
-
leave=False,
|
| 323 |
-
desc="Timestep",
|
| 324 |
-
dynamic_ncols=True,
|
| 325 |
-
)
|
| 326 |
-
|
| 327 |
-
# For mixed precision training we cast all non-trainable weights [vae, non-lora text_encoder and non-lora unet] to half-precision
|
| 328 |
-
# as these weights are only used for inference, keeping weights in full precision is not required.
|
| 329 |
-
if self.accelerator.mixed_precision == "fp16":
|
| 330 |
-
inference_dtype = torch.float16
|
| 331 |
-
elif self.accelerator.mixed_precision == "bf16":
|
| 332 |
-
inference_dtype = torch.bfloat16
|
| 333 |
-
else:
|
| 334 |
-
inference_dtype = torch.float32
|
| 335 |
-
|
| 336 |
-
self.sd_pipeline.vae.to(self.accelerator.device, dtype=inference_dtype)
|
| 337 |
-
self.sd_pipeline.text_encoder.to(self.accelerator.device, dtype=inference_dtype)
|
| 338 |
-
self.sd_pipeline.unet.to(self.accelerator.device, dtype=inference_dtype)
|
| 339 |
-
|
| 340 |
-
trainable_layers = self.sd_pipeline.get_trainable_layers()
|
| 341 |
-
|
| 342 |
-
self.accelerator.register_save_state_pre_hook(self._save_model_hook)
|
| 343 |
-
self.accelerator.register_load_state_pre_hook(self._load_model_hook)
|
| 344 |
-
|
| 345 |
-
# Enable TF32 for faster training on Ampere GPUs,
|
| 346 |
-
# cf https://pytorch.org/docs/stable/notes/cuda.html#tensorfloat-32-tf32-on-ampere-devices
|
| 347 |
-
if self.config.allow_tf32:
|
| 348 |
-
torch.backends.cuda.matmul.allow_tf32 = True
|
| 349 |
-
|
| 350 |
-
self.optimizer = self._setup_optimizer(
|
| 351 |
-
trainable_layers.parameters() if not isinstance(trainable_layers, list) else trainable_layers
|
| 352 |
-
)
|
| 353 |
-
|
| 354 |
-
self.neg_prompt_embed = self.sd_pipeline.text_encoder(
|
| 355 |
-
self.sd_pipeline.tokenizer(
|
| 356 |
-
[""] if self.config.negative_prompts is None else self.config.negative_prompts,
|
| 357 |
-
return_tensors="pt",
|
| 358 |
-
padding="max_length",
|
| 359 |
-
truncation=True,
|
| 360 |
-
max_length=self.sd_pipeline.tokenizer.model_max_length,
|
| 361 |
-
).input_ids.to(self.accelerator.device)
|
| 362 |
-
)[0]
|
| 363 |
-
|
| 364 |
-
if config.per_prompt_stat_tracking:
|
| 365 |
-
self.stat_tracker = PerPromptStatTracker(
|
| 366 |
-
config.per_prompt_stat_tracking_buffer_size,
|
| 367 |
-
config.per_prompt_stat_tracking_min_count,
|
| 368 |
-
)
|
| 369 |
-
|
| 370 |
-
# NOTE: for some reason, autocast is necessary for non-lora training but for lora training it isn't necessary and it uses
|
| 371 |
-
# more memory
|
| 372 |
-
self.autocast = self.sd_pipeline.autocast or self.accelerator.autocast
|
| 373 |
-
|
| 374 |
-
if hasattr(self.sd_pipeline, "use_lora") and self.sd_pipeline.use_lora:
|
| 375 |
-
unet, self.optimizer = self.accelerator.prepare(trainable_layers, self.optimizer)
|
| 376 |
-
self.trainable_layers = list(filter(lambda p: p.requires_grad, unet.parameters()))
|
| 377 |
-
else:
|
| 378 |
-
self.trainable_layers, self.optimizer = self.accelerator.prepare(trainable_layers, self.optimizer)
|
| 379 |
-
|
| 380 |
-
if self.config.async_reward_computation:
|
| 381 |
-
self.executor = futures.ThreadPoolExecutor(max_workers=config.max_workers)
|
| 382 |
-
|
| 383 |
-
if config.resume_from:
|
| 384 |
-
logger.info(f"Resuming from {config.resume_from}")
|
| 385 |
-
self.accelerator.load_state(config.resume_from)
|
| 386 |
-
self.first_epoch = int(config.resume_from.split("_")[-1]) + 1
|
| 387 |
-
else:
|
| 388 |
-
self.first_epoch = 0
|
| 389 |
-
|
| 390 |
-
def compute_rewards(self, prompt_image_pairs, is_async=False):
|
| 391 |
-
if not is_async:
|
| 392 |
-
rewards = []
|
| 393 |
-
for images, prompts, prompt_metadata in prompt_image_pairs:
|
| 394 |
-
reward, reward_metadata = self.reward_fn(images, prompts, prompt_metadata)
|
| 395 |
-
rewards.append(
|
| 396 |
-
(
|
| 397 |
-
torch.as_tensor(reward, device=self.accelerator.device),
|
| 398 |
-
reward_metadata,
|
| 399 |
-
)
|
| 400 |
-
)
|
| 401 |
-
else:
|
| 402 |
-
rewards = self.executor.map(lambda x: self.reward_fn(*x), prompt_image_pairs)
|
| 403 |
-
rewards = [
|
| 404 |
-
(torch.as_tensor(reward.result(), device=self.accelerator.device), reward_metadata.result())
|
| 405 |
-
for reward, reward_metadata in rewards
|
| 406 |
-
]
|
| 407 |
-
|
| 408 |
-
return zip(*rewards)
|
| 409 |
-
|
| 410 |
-
def step(self, epoch: int, global_step: int):
|
| 411 |
-
"""
|
| 412 |
-
Perform a single step of training.
|
| 413 |
-
|
| 414 |
-
Args:
|
| 415 |
-
epoch (int): The current epoch.
|
| 416 |
-
global_step (int): The current global step.
|
| 417 |
-
|
| 418 |
-
Side Effects:
|
| 419 |
-
- Model weights are updated
|
| 420 |
-
- Logs the statistics to the accelerator trackers.
|
| 421 |
-
- If `self.image_samples_callback` is not None, it will be called with the prompt_image_pairs, global_step, and the accelerator tracker.
|
| 422 |
-
|
| 423 |
-
Returns:
|
| 424 |
-
global_step (int): The updated global step.
|
| 425 |
-
|
| 426 |
-
"""
|
| 427 |
-
samples, prompt_image_data = self._generate_samples(
|
| 428 |
-
iterations=self.config.sample_num_batches_per_epoch,
|
| 429 |
-
batch_size=self.config.sample_batch_size,
|
| 430 |
-
)
|
| 431 |
-
|
| 432 |
-
# collate samples into dict where each entry has shape (num_batches_per_epoch * sample.batch_size, ...)
|
| 433 |
-
samples = {k: torch.cat([s[k] for s in samples]) for k in samples[0].keys()}
|
| 434 |
-
rewards, rewards_metadata = self.compute_rewards(
|
| 435 |
-
prompt_image_data, is_async=self.config.async_reward_computation
|
| 436 |
-
)
|
| 437 |
-
|
| 438 |
-
for i, image_data in enumerate(prompt_image_data):
|
| 439 |
-
image_data.extend([rewards[i], rewards_metadata[i]])
|
| 440 |
-
|
| 441 |
-
if self.image_samples_callback is not None:
|
| 442 |
-
self.image_samples_callback(prompt_image_data, global_step, self.accelerator.trackers[0])
|
| 443 |
-
|
| 444 |
-
rewards = torch.cat(rewards)
|
| 445 |
-
rewards = self.accelerator.gather(rewards).cpu().numpy()
|
| 446 |
-
|
| 447 |
-
self.accelerator.log(
|
| 448 |
-
{
|
| 449 |
-
"reward": rewards,
|
| 450 |
-
"epoch": epoch,
|
| 451 |
-
"reward_mean": rewards.mean(),
|
| 452 |
-
"reward_std": rewards.std(),
|
| 453 |
-
},
|
| 454 |
-
step=global_step,
|
| 455 |
-
)
|
| 456 |
-
|
| 457 |
-
if self.config.per_prompt_stat_tracking:
|
| 458 |
-
# gather the prompts across processes
|
| 459 |
-
prompt_ids = self.accelerator.gather(samples["prompt_ids"]).cpu().numpy()
|
| 460 |
-
prompts = self.sd_pipeline.tokenizer.batch_decode(prompt_ids, skip_special_tokens=True)
|
| 461 |
-
advantages = self.stat_tracker.update(prompts, rewards)
|
| 462 |
-
else:
|
| 463 |
-
advantages = (rewards - rewards.mean()) / (rewards.std() + 1e-8)
|
| 464 |
-
|
| 465 |
-
# ungather advantages; keep the entries corresponding to the samples on this process
|
| 466 |
-
samples["advantages"] = (
|
| 467 |
-
torch.as_tensor(advantages)
|
| 468 |
-
.reshape(self.accelerator.num_processes, -1)[self.accelerator.process_index]
|
| 469 |
-
.to(self.accelerator.device)
|
| 470 |
-
)
|
| 471 |
-
|
| 472 |
-
del samples["prompt_ids"]
|
| 473 |
-
|
| 474 |
-
total_batch_size, num_timesteps = samples["timesteps"].shape
|
| 475 |
-
|
| 476 |
-
for inner_epoch in range(self.config.train_num_inner_epochs):
|
| 477 |
-
# shuffle samples along batch dimension
|
| 478 |
-
perm = torch.randperm(total_batch_size, device=self.accelerator.device)
|
| 479 |
-
samples = {k: v[perm] for k, v in samples.items()}
|
| 480 |
-
|
| 481 |
-
# shuffle along time dimension independently for each sample
|
| 482 |
-
# still trying to understand the code below
|
| 483 |
-
perms = torch.stack(
|
| 484 |
-
[torch.randperm(num_timesteps, device=self.accelerator.device) for _ in range(total_batch_size)]
|
| 485 |
-
)
|
| 486 |
-
|
| 487 |
-
for key in ["timesteps", "latents", "next_latents", "log_probs"]:
|
| 488 |
-
samples[key] = samples[key][
|
| 489 |
-
torch.arange(total_batch_size, device=self.accelerator.device)[:, None],
|
| 490 |
-
perms,
|
| 491 |
-
]
|
| 492 |
-
|
| 493 |
-
original_keys = samples.keys()
|
| 494 |
-
original_values = samples.values()
|
| 495 |
-
# rebatch them as user defined train_batch_size is different from sample_batch_size
|
| 496 |
-
reshaped_values = [v.reshape(-1, self.config.train_batch_size, *v.shape[1:]) for v in original_values]
|
| 497 |
-
|
| 498 |
-
# Transpose the list of original values
|
| 499 |
-
transposed_values = zip(*reshaped_values)
|
| 500 |
-
# Create new dictionaries for each row of transposed values
|
| 501 |
-
samples_batched = [dict(zip(original_keys, row_values)) for row_values in transposed_values]
|
| 502 |
-
|
| 503 |
-
self.sd_pipeline.unet.train()
|
| 504 |
-
global_step = self._train_batched_samples(inner_epoch, epoch, global_step, samples_batched)
|
| 505 |
-
# ensure optimization step at the end of the inner epoch
|
| 506 |
-
if not self.accelerator.sync_gradients:
|
| 507 |
-
raise ValueError(
|
| 508 |
-
"Optimization step should have been performed by this point. Please check calculated gradient accumulation settings."
|
| 509 |
-
)
|
| 510 |
-
|
| 511 |
-
if epoch != 0 and epoch % self.config.save_freq == 0 and self.accelerator.is_main_process:
|
| 512 |
-
self.accelerator.save_state()
|
| 513 |
-
|
| 514 |
-
return global_step
|
| 515 |
-
|
| 516 |
-
def calculate_loss(self, latents, timesteps, next_latents, log_probs, advantages, embeds):
|
| 517 |
-
"""
|
| 518 |
-
Calculate the loss for a batch of an unpacked sample
|
| 519 |
-
|
| 520 |
-
Args:
|
| 521 |
-
latents (torch.Tensor):
|
| 522 |
-
The latents sampled from the diffusion model, shape: [batch_size, num_channels_latents, height, width]
|
| 523 |
-
timesteps (torch.Tensor):
|
| 524 |
-
The timesteps sampled from the diffusion model, shape: [batch_size]
|
| 525 |
-
next_latents (torch.Tensor):
|
| 526 |
-
The next latents sampled from the diffusion model, shape: [batch_size, num_channels_latents, height, width]
|
| 527 |
-
log_probs (torch.Tensor):
|
| 528 |
-
The log probabilities of the latents, shape: [batch_size]
|
| 529 |
-
advantages (torch.Tensor):
|
| 530 |
-
The advantages of the latents, shape: [batch_size]
|
| 531 |
-
embeds (torch.Tensor):
|
| 532 |
-
The embeddings of the prompts, shape: [2*batch_size or batch_size, ...]
|
| 533 |
-
Note: the "or" is because if train_cfg is True, the expectation is that negative prompts are concatenated to the embeds
|
| 534 |
-
|
| 535 |
-
Returns:
|
| 536 |
-
loss (torch.Tensor), approx_kl (torch.Tensor), clipfrac (torch.Tensor)
|
| 537 |
-
(all of these are of shape (1,))
|
| 538 |
-
"""
|
| 539 |
-
with self.autocast():
|
| 540 |
-
if self.config.train_cfg:
|
| 541 |
-
noise_pred = self.sd_pipeline.unet(
|
| 542 |
-
torch.cat([latents] * 2),
|
| 543 |
-
torch.cat([timesteps] * 2),
|
| 544 |
-
embeds,
|
| 545 |
-
).sample
|
| 546 |
-
noise_pred_uncond, noise_pred_text = noise_pred.chunk(2)
|
| 547 |
-
noise_pred = noise_pred_uncond + self.config.sample_guidance_scale * (
|
| 548 |
-
noise_pred_text - noise_pred_uncond
|
| 549 |
-
)
|
| 550 |
-
else:
|
| 551 |
-
noise_pred = self.sd_pipeline.unet(
|
| 552 |
-
latents,
|
| 553 |
-
timesteps,
|
| 554 |
-
embeds,
|
| 555 |
-
).sample
|
| 556 |
-
# compute the log prob of next_latents given latents under the current model
|
| 557 |
-
|
| 558 |
-
scheduler_step_output = self.sd_pipeline.scheduler_step(
|
| 559 |
-
noise_pred,
|
| 560 |
-
timesteps,
|
| 561 |
-
latents,
|
| 562 |
-
eta=self.config.sample_eta,
|
| 563 |
-
prev_sample=next_latents,
|
| 564 |
-
)
|
| 565 |
-
|
| 566 |
-
log_prob = scheduler_step_output.log_probs
|
| 567 |
-
|
| 568 |
-
advantages = torch.clamp(
|
| 569 |
-
advantages,
|
| 570 |
-
-self.config.train_adv_clip_max,
|
| 571 |
-
self.config.train_adv_clip_max,
|
| 572 |
-
)
|
| 573 |
-
|
| 574 |
-
ratio = torch.exp(log_prob - log_probs)
|
| 575 |
-
|
| 576 |
-
loss = self.loss(advantages, self.config.train_clip_range, ratio)
|
| 577 |
-
|
| 578 |
-
approx_kl = 0.5 * torch.mean((log_prob - log_probs) ** 2)
|
| 579 |
-
|
| 580 |
-
clipfrac = torch.mean((torch.abs(ratio - 1.0) > self.config.train_clip_range).float())
|
| 581 |
-
|
| 582 |
-
return loss, approx_kl, clipfrac
|
| 583 |
-
|
| 584 |
-
def loss(
|
| 585 |
-
self,
|
| 586 |
-
advantages: torch.Tensor,
|
| 587 |
-
clip_range: float,
|
| 588 |
-
ratio: torch.Tensor,
|
| 589 |
-
):
|
| 590 |
-
unclipped_loss = -advantages * ratio
|
| 591 |
-
clipped_loss = -advantages * torch.clamp(
|
| 592 |
-
ratio,
|
| 593 |
-
1.0 - clip_range,
|
| 594 |
-
1.0 + clip_range,
|
| 595 |
-
)
|
| 596 |
-
return torch.mean(torch.maximum(unclipped_loss, clipped_loss))
|
| 597 |
-
|
| 598 |
-
def _setup_optimizer(self, trainable_layers_parameters):
|
| 599 |
-
if self.config.train_use_8bit_adam:
|
| 600 |
-
import bitsandbytes
|
| 601 |
-
|
| 602 |
-
optimizer_cls = bitsandbytes.optim.AdamW8bit
|
| 603 |
-
else:
|
| 604 |
-
optimizer_cls = torch.optim.AdamW
|
| 605 |
-
|
| 606 |
-
return optimizer_cls(
|
| 607 |
-
trainable_layers_parameters,
|
| 608 |
-
lr=self.config.train_learning_rate,
|
| 609 |
-
betas=(self.config.train_adam_beta1, self.config.train_adam_beta2),
|
| 610 |
-
weight_decay=self.config.train_adam_weight_decay,
|
| 611 |
-
eps=self.config.train_adam_epsilon,
|
| 612 |
-
)
|
| 613 |
-
|
| 614 |
-
def _save_model_hook(self, models, weights, output_dir):
|
| 615 |
-
self.sd_pipeline.save_checkpoint(models, weights, output_dir)
|
| 616 |
-
weights.pop() # ensures that accelerate doesn't try to handle saving of the model
|
| 617 |
-
|
| 618 |
-
def _load_model_hook(self, models, input_dir):
|
| 619 |
-
self.sd_pipeline.load_checkpoint(models, input_dir)
|
| 620 |
-
models.pop() # ensures that accelerate doesn't try to handle loading of the model
|
| 621 |
-
|
| 622 |
-
def _generate_samples(self, iterations, batch_size):
|
| 623 |
-
"""
|
| 624 |
-
Generate samples from the model
|
| 625 |
-
|
| 626 |
-
Args:
|
| 627 |
-
iterations (int): Number of iterations to generate samples for
|
| 628 |
-
batch_size (int): Batch size to use for sampling
|
| 629 |
-
|
| 630 |
-
Returns:
|
| 631 |
-
samples (list[dict[str, torch.Tensor]]), prompt_image_pairs (list[list[Any]])
|
| 632 |
-
"""
|
| 633 |
-
samples = []
|
| 634 |
-
prompt_image_pairs = []
|
| 635 |
-
self.sd_pipeline.unet.eval()
|
| 636 |
-
|
| 637 |
-
sample_neg_prompt_embeds = self.neg_prompt_embed.repeat(batch_size, 1, 1)
|
| 638 |
-
|
| 639 |
-
for _ in range(iterations):
|
| 640 |
-
prompts, prompt_metadata = zip(*[self.prompt_fn() for _ in range(batch_size)])
|
| 641 |
-
|
| 642 |
-
prompt_ids = self.sd_pipeline.tokenizer(
|
| 643 |
-
prompts,
|
| 644 |
-
return_tensors="pt",
|
| 645 |
-
padding="max_length",
|
| 646 |
-
truncation=True,
|
| 647 |
-
max_length=self.sd_pipeline.tokenizer.model_max_length,
|
| 648 |
-
).input_ids.to(self.accelerator.device)
|
| 649 |
-
prompt_embeds = self.sd_pipeline.text_encoder(prompt_ids)[0]
|
| 650 |
-
|
| 651 |
-
with self.autocast():
|
| 652 |
-
sd_output = self.sd_pipeline(
|
| 653 |
-
prompt_embeds=prompt_embeds,
|
| 654 |
-
negative_prompt_embeds=sample_neg_prompt_embeds,
|
| 655 |
-
num_inference_steps=self.config.sample_num_steps,
|
| 656 |
-
guidance_scale=self.config.sample_guidance_scale,
|
| 657 |
-
eta=self.config.sample_eta,
|
| 658 |
-
output_type="pt",
|
| 659 |
-
)
|
| 660 |
-
|
| 661 |
-
images = sd_output.images
|
| 662 |
-
latents = sd_output.latents
|
| 663 |
-
log_probs = sd_output.log_probs
|
| 664 |
-
|
| 665 |
-
latents = torch.stack(latents, dim=1) # (batch_size, num_steps + 1, ...)
|
| 666 |
-
log_probs = torch.stack(log_probs, dim=1) # (batch_size, num_steps, 1)
|
| 667 |
-
timesteps = self.sd_pipeline.scheduler.timesteps.repeat(batch_size, 1) # (batch_size, num_steps)
|
| 668 |
-
|
| 669 |
-
samples.append(
|
| 670 |
-
{
|
| 671 |
-
"prompt_ids": prompt_ids,
|
| 672 |
-
"prompt_embeds": prompt_embeds,
|
| 673 |
-
"timesteps": timesteps,
|
| 674 |
-
"latents": latents[:, :-1], # each entry is the latent before timestep t
|
| 675 |
-
"next_latents": latents[:, 1:], # each entry is the latent after timestep t
|
| 676 |
-
"log_probs": log_probs,
|
| 677 |
-
"negative_prompt_embeds": sample_neg_prompt_embeds,
|
| 678 |
-
}
|
| 679 |
-
)
|
| 680 |
-
prompt_image_pairs.append([images, prompts, prompt_metadata])
|
| 681 |
-
|
| 682 |
-
return samples, prompt_image_pairs
|
| 683 |
-
|
| 684 |
-
def _train_batched_samples(self, inner_epoch, epoch, global_step, batched_samples):
|
| 685 |
-
"""
|
| 686 |
-
Train on a batch of samples. Main training segment
|
| 687 |
-
|
| 688 |
-
Args:
|
| 689 |
-
inner_epoch (int): The current inner epoch
|
| 690 |
-
epoch (int): The current epoch
|
| 691 |
-
global_step (int): The current global step
|
| 692 |
-
batched_samples (list[dict[str, torch.Tensor]]): The batched samples to train on
|
| 693 |
-
|
| 694 |
-
Side Effects:
|
| 695 |
-
- Model weights are updated
|
| 696 |
-
- Logs the statistics to the accelerator trackers.
|
| 697 |
-
|
| 698 |
-
Returns:
|
| 699 |
-
global_step (int): The updated global step
|
| 700 |
-
"""
|
| 701 |
-
info = defaultdict(list)
|
| 702 |
-
for _i, sample in enumerate(batched_samples):
|
| 703 |
-
if self.config.train_cfg:
|
| 704 |
-
# concat negative prompts to sample prompts to avoid two forward passes
|
| 705 |
-
embeds = torch.cat([sample["negative_prompt_embeds"], sample["prompt_embeds"]])
|
| 706 |
-
else:
|
| 707 |
-
embeds = sample["prompt_embeds"]
|
| 708 |
-
|
| 709 |
-
for j in range(self.num_train_timesteps):
|
| 710 |
-
with self.accelerator.accumulate(self.sd_pipeline.unet):
|
| 711 |
-
loss, approx_kl, clipfrac = self.calculate_loss(
|
| 712 |
-
sample["latents"][:, j],
|
| 713 |
-
sample["timesteps"][:, j],
|
| 714 |
-
sample["next_latents"][:, j],
|
| 715 |
-
sample["log_probs"][:, j],
|
| 716 |
-
sample["advantages"],
|
| 717 |
-
embeds,
|
| 718 |
-
)
|
| 719 |
-
info["approx_kl"].append(approx_kl)
|
| 720 |
-
info["clipfrac"].append(clipfrac)
|
| 721 |
-
info["loss"].append(loss)
|
| 722 |
-
|
| 723 |
-
self.accelerator.backward(loss)
|
| 724 |
-
if self.accelerator.sync_gradients:
|
| 725 |
-
self.accelerator.clip_grad_norm_(
|
| 726 |
-
self.trainable_layers.parameters()
|
| 727 |
-
if not isinstance(self.trainable_layers, list)
|
| 728 |
-
else self.trainable_layers,
|
| 729 |
-
self.config.train_max_grad_norm,
|
| 730 |
-
)
|
| 731 |
-
self.optimizer.step()
|
| 732 |
-
self.optimizer.zero_grad()
|
| 733 |
-
|
| 734 |
-
# Checks if the accelerator has performed an optimization step behind the scenes
|
| 735 |
-
if self.accelerator.sync_gradients:
|
| 736 |
-
# log training-related stuff
|
| 737 |
-
info = {k: torch.mean(torch.stack(v)) for k, v in info.items()}
|
| 738 |
-
info = self.accelerator.reduce(info, reduction="mean")
|
| 739 |
-
info.update({"epoch": epoch, "inner_epoch": inner_epoch})
|
| 740 |
-
self.accelerator.log(info, step=global_step)
|
| 741 |
-
global_step += 1
|
| 742 |
-
info = defaultdict(list)
|
| 743 |
-
return global_step
|
| 744 |
-
|
| 745 |
-
def _config_check(self) -> tuple[bool, str]:
|
| 746 |
-
samples_per_epoch = (
|
| 747 |
-
self.config.sample_batch_size * self.accelerator.num_processes * self.config.sample_num_batches_per_epoch
|
| 748 |
-
)
|
| 749 |
-
total_train_batch_size = (
|
| 750 |
-
self.config.train_batch_size
|
| 751 |
-
* self.accelerator.num_processes
|
| 752 |
-
* self.config.train_gradient_accumulation_steps
|
| 753 |
-
)
|
| 754 |
-
|
| 755 |
-
if not self.config.sample_batch_size >= self.config.train_batch_size:
|
| 756 |
-
return (
|
| 757 |
-
False,
|
| 758 |
-
f"Sample batch size ({self.config.sample_batch_size}) must be greater than or equal to the train batch size ({self.config.train_batch_size})",
|
| 759 |
-
)
|
| 760 |
-
if not self.config.sample_batch_size % self.config.train_batch_size == 0:
|
| 761 |
-
return (
|
| 762 |
-
False,
|
| 763 |
-
f"Sample batch size ({self.config.sample_batch_size}) must be divisible by the train batch size ({self.config.train_batch_size})",
|
| 764 |
-
)
|
| 765 |
-
if not samples_per_epoch % total_train_batch_size == 0:
|
| 766 |
-
return (
|
| 767 |
-
False,
|
| 768 |
-
f"Number of samples per epoch ({samples_per_epoch}) must be divisible by the total train batch size ({total_train_batch_size})",
|
| 769 |
-
)
|
| 770 |
-
return True, ""
|
| 771 |
-
|
| 772 |
-
def train(self, epochs: Optional[int] = None):
|
| 773 |
-
"""
|
| 774 |
-
Train the model for a given number of epochs
|
| 775 |
-
"""
|
| 776 |
-
global_step = 0
|
| 777 |
-
if epochs is None:
|
| 778 |
-
epochs = self.config.num_epochs
|
| 779 |
-
for epoch in range(self.first_epoch, epochs):
|
| 780 |
-
global_step = self.step(epoch, global_step)
|
| 781 |
-
|
| 782 |
-
def _save_pretrained(self, save_directory):
|
| 783 |
-
self.sd_pipeline.save_pretrained(save_directory)
|
| 784 |
-
self.create_model_card()
|
| 785 |
-
|
| 786 |
-
def create_model_card(
|
| 787 |
-
self,
|
| 788 |
-
model_name: Optional[str] = None,
|
| 789 |
-
dataset_name: Optional[str] = None,
|
| 790 |
-
tags: Union[str, list[str], None] = None,
|
| 791 |
-
):
|
| 792 |
-
"""
|
| 793 |
-
Creates a draft of a model card using the information available to the `Trainer`.
|
| 794 |
-
|
| 795 |
-
Args:
|
| 796 |
-
model_name (`str` or `None`, *optional*, defaults to `None`):
|
| 797 |
-
Name of the model.
|
| 798 |
-
dataset_name (`str` or `None`, *optional*, defaults to `None`):
|
| 799 |
-
Name of the dataset used for training.
|
| 800 |
-
tags (`str`, `list[str]` or `None`, *optional*, defaults to `None`):
|
| 801 |
-
Tags to be associated with the model card.
|
| 802 |
-
"""
|
| 803 |
-
if not self.is_world_process_zero():
|
| 804 |
-
return
|
| 805 |
-
|
| 806 |
-
if hasattr(self.model.config, "_name_or_path") and not os.path.isdir(self.model.config._name_or_path):
|
| 807 |
-
base_model = self.model.config._name_or_path
|
| 808 |
-
else:
|
| 809 |
-
base_model = None
|
| 810 |
-
|
| 811 |
-
tags = tags or []
|
| 812 |
-
if isinstance(tags, str):
|
| 813 |
-
tags = [tags]
|
| 814 |
-
|
| 815 |
-
if hasattr(self.model.config, "unsloth_version"):
|
| 816 |
-
tags.append("unsloth")
|
| 817 |
-
|
| 818 |
-
citation = textwrap.dedent("""\
|
| 819 |
-
@inproceedings{black2024training,
|
| 820 |
-
title = {{Training Diffusion Models with Reinforcement Learning}},
|
| 821 |
-
author = {Kevin Black and Michael Janner and Yilun Du and Ilya Kostrikov and Sergey Levine},
|
| 822 |
-
year = 2024,
|
| 823 |
-
booktitle = {The Twelfth International Conference on Learning Representations, {ICLR} 2024, Vienna, Austria, May 7-11, 2024},
|
| 824 |
-
publisher = {OpenReview.net},
|
| 825 |
-
url = {https://openreview.net/forum?id=YCWjhGrJFD},
|
| 826 |
-
}""")
|
| 827 |
-
|
| 828 |
-
model_card = generate_model_card(
|
| 829 |
-
base_model=base_model,
|
| 830 |
-
model_name=model_name,
|
| 831 |
-
hub_model_id=self.hub_model_id,
|
| 832 |
-
dataset_name=dataset_name,
|
| 833 |
-
tags=tags,
|
| 834 |
-
wandb_url=wandb.run.get_url() if is_wandb_available() and wandb.run is not None else None,
|
| 835 |
-
comet_url=get_comet_experiment_url(),
|
| 836 |
-
trainer_name="DDPO",
|
| 837 |
-
trainer_citation=citation,
|
| 838 |
-
paper_title="Training Diffusion Models with Reinforcement Learning",
|
| 839 |
-
paper_id="2305.13301",
|
| 840 |
-
)
|
| 841 |
-
|
| 842 |
-
model_card.save(os.path.join(self.args.output_dir, "README.md"))
|
| 843 |
-
class UnslothDDPOTrainer(_UnslothDDPOTrainer):
|
| 844 |
-
"""
|
| 845 |
-
|
| 846 |
-
The DDPOTrainer uses Deep Diffusion Policy Optimization to optimise diffusion models.
|
| 847 |
-
Note, this trainer is heavily inspired by the work here: https://github.com/kvablack/ddpo-pytorch
|
| 848 |
-
As of now only Stable Diffusion based pipelines are supported
|
| 849 |
-
|
| 850 |
-
Attributes:
|
| 851 |
-
**config** (`DDPOConfig`) -- Configuration object for DDPOTrainer. Check the documentation of `PPOConfig` for more
|
| 852 |
-
details.
|
| 853 |
-
**reward_function** (Callable[[torch.Tensor, tuple[str], tuple[Any]], torch.Tensor]) -- Reward function to be used
|
| 854 |
-
**prompt_function** (Callable[[], tuple[str, Any]]) -- Function to generate prompts to guide model
|
| 855 |
-
**sd_pipeline** (`DDPOStableDiffusionPipeline`) -- Stable Diffusion pipeline to be used for training.
|
| 856 |
-
**image_samples_hook** (Optional[Callable[[Any, Any, Any], Any]]) -- Hook to be called to log images
|
| 857 |
-
|
| 858 |
-
"""
|
| 859 |
-
def __init__(
|
| 860 |
-
self,
|
| 861 |
-
config,
|
| 862 |
-
reward_function,
|
| 863 |
-
prompt_function,
|
| 864 |
-
sd_pipeline,
|
| 865 |
-
image_samples_hook = None,
|
| 866 |
-
**kwargs
|
| 867 |
-
):
|
| 868 |
-
if args is None: args = UnslothDDPOConfig()
|
| 869 |
-
other_metrics = []
|
| 870 |
-
|
| 871 |
-
from unsloth_zoo.logging_utils import PatchRLStatistics
|
| 872 |
-
PatchRLStatistics('ddpo_trainer', other_metrics)
|
| 873 |
-
|
| 874 |
-
super().__init__(
|
| 875 |
-
config = config,
|
| 876 |
-
reward_function = reward_function,
|
| 877 |
-
prompt_function = prompt_function,
|
| 878 |
-
sd_pipeline = sd_pipeline,
|
| 879 |
-
image_samples_hook = image_samples_hook,**kwargs)
|
| 880 |
-
|
| 881 |
-
pass
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
test_run_uploads/UnslothDPOTrainer.py
DELETED
|
The diff for this file is too large to render.
See raw diff
|
|
|
test_run_uploads/UnslothGKDTrainer.py
DELETED
|
@@ -1,885 +0,0 @@
|
|
| 1 |
-
"""
|
| 2 |
-
2025.7.11
|
| 3 |
-
2025.7.11
|
| 4 |
-
4.54.1
|
| 5 |
-
0.16.1
|
| 6 |
-
__UNSLOTH_VERSIONING__
|
| 7 |
-
"""
|
| 8 |
-
from torch import Tensor
|
| 9 |
-
import torch
|
| 10 |
-
import torch.nn as nn
|
| 11 |
-
from torch.nn import functional as F
|
| 12 |
-
from typing import Any, List, Optional, Tuple, Union, Dict, Set, Callable
|
| 13 |
-
from trl.trainer.gkd_trainer import (Any, AutoModelForCausalLM, BaseImageProcessor, Callable, DataCollator, DataCollatorForChatML, Dataset, EvalPrediction, F, FeatureExtractionMixin, GKDConfig, GKDTrainer, GenerationConfig, Optional, PeftConfig, PreTrainedModel, PreTrainedModelWrapper, PreTrainedTokenizerBase, ProcessorMixin, SFTTrainer, TrainerCallback, Union, deepcopy, disable_dropout_in_model, empty_cache, generate_model_card, get_comet_experiment_url, is_wandb_available, nn, os, random, textwrap, torch, unwrap_model_for_generation, wandb)
|
| 14 |
-
|
| 15 |
-
|
| 16 |
-
import os
|
| 17 |
-
from typing import *
|
| 18 |
-
from dataclasses import dataclass, field
|
| 19 |
-
from packaging.version import Version
|
| 20 |
-
import torch
|
| 21 |
-
import numpy as np
|
| 22 |
-
from contextlib import nullcontext
|
| 23 |
-
from torch.nn import functional as F
|
| 24 |
-
from transformers import DataCollatorForSeq2Seq, DataCollatorForLanguageModeling as TransformersDataCollatorForLanguageModeling
|
| 25 |
-
|
| 26 |
-
torch_compile_options = {
|
| 27 |
-
"epilogue_fusion" : True,
|
| 28 |
-
"max_autotune" : False,
|
| 29 |
-
"shape_padding" : True,
|
| 30 |
-
"trace.enabled" : False,
|
| 31 |
-
"triton.cudagraphs" : False,
|
| 32 |
-
}
|
| 33 |
-
|
| 34 |
-
@torch.compile(dynamic = True, fullgraph = True, options = torch_compile_options,)
|
| 35 |
-
def chunked_selective_log_softmax(logits, index):
|
| 36 |
-
# Split into 4 chunks only
|
| 37 |
-
chunked_logits = torch.chunk(logits.reshape(-1, logits.shape[-1]), chunks = 4, dim = 0)
|
| 38 |
-
chunked_index = torch.chunk(index.reshape(-1), chunks = 4, dim = 0)
|
| 39 |
-
all_per_token_logps = []
|
| 40 |
-
# Below loop does the same as selective_log_softmax(chunk_logits, chunk_index)
|
| 41 |
-
for chunk_logits, chunk_index in zip(chunked_logits, chunked_index):
|
| 42 |
-
chunk_logits = chunk_logits.to(torch.float32)
|
| 43 |
-
selected_logits = torch.gather(chunk_logits, dim = -1, index = chunk_index.unsqueeze(-1)).squeeze(-1)
|
| 44 |
-
logsumexp_values = torch.logsumexp(chunk_logits, dim = -1)
|
| 45 |
-
per_token_logps = selected_logits - logsumexp_values
|
| 46 |
-
all_per_token_logps.append(per_token_logps)
|
| 47 |
-
pass
|
| 48 |
-
all_per_token_logps = torch.concat(all_per_token_logps)
|
| 49 |
-
all_per_token_logps = all_per_token_logps.reshape((logits.shape[0], logits.shape[1]))
|
| 50 |
-
return all_per_token_logps
|
| 51 |
-
@dataclass
|
| 52 |
-
class UnslothGKDConfig(GKDConfig):
|
| 53 |
-
"""
|
| 54 |
-
|
| 55 |
-
Configuration class for [`GKDTrainer`].
|
| 56 |
-
|
| 57 |
-
Args:
|
| 58 |
-
temperature (`float`, *optional*, defaults to `0.9`):
|
| 59 |
-
Temperature for sampling. The higher the temperature, the more random the completions.
|
| 60 |
-
lmbda (`float`, *optional*, defaults to `0.5`):
|
| 61 |
-
Lambda parameter that controls the student data fraction (i.e., the proportion of on-policy
|
| 62 |
-
student-generated outputs).
|
| 63 |
-
beta (`float`, *optional*, defaults to `0.5`):
|
| 64 |
-
Interpolation coefficient between `0.0` and `1.0` of the Generalized Jensen-Shannon Divergence loss. When
|
| 65 |
-
beta is `0.0`, the loss is the KL divergence. When beta is `1.0`, the loss is the Inverse KL Divergence.
|
| 66 |
-
max_new_tokens (`int`, *optional*, defaults to `128`):
|
| 67 |
-
Maximum number of tokens to generate per completion.
|
| 68 |
-
teacher_model_name_or_path (`str` or `None`, *optional*, defaults to `None`):
|
| 69 |
-
Model name or path of the teacher model. If `None`, the teacher model will be the same as the model
|
| 70 |
-
being trained.
|
| 71 |
-
teacher_model_init_kwargs (`dict[str, Any]]` or `None`, *optional*, defaults to `None`):
|
| 72 |
-
Keyword arguments to pass to `AutoModelForCausalLM.from_pretrained` when instantiating the teacher model
|
| 73 |
-
from a string.
|
| 74 |
-
disable_dropout (`bool`, *optional*, defaults to `True`):
|
| 75 |
-
Whether to disable dropout in the model.
|
| 76 |
-
seq_kd (`bool`, *optional*, defaults to `False`):
|
| 77 |
-
Seq_kd parameter that controls whether to perform Sequence-Level KD (can be viewed as supervised FT
|
| 78 |
-
on teacher-generated output).
|
| 79 |
-
|
| 80 |
-
"""
|
| 81 |
-
vllm_sampling_params: Optional[Any] = field(
|
| 82 |
-
default = None,
|
| 83 |
-
metadata = {'help': 'vLLM SamplingParams'},
|
| 84 |
-
)
|
| 85 |
-
unsloth_num_chunks : Optional[int] = field(
|
| 86 |
-
default = -1,
|
| 87 |
-
metadata = {'help': 'Chunk size to reduce memory usage. -1 is most efficient.'},
|
| 88 |
-
)
|
| 89 |
-
def __init__(
|
| 90 |
-
self,
|
| 91 |
-
output_dir = None,
|
| 92 |
-
overwrite_output_dir = None,
|
| 93 |
-
do_train = False,
|
| 94 |
-
do_eval = False,
|
| 95 |
-
do_predict = False,
|
| 96 |
-
eval_strategy = 'no',
|
| 97 |
-
prediction_loss_only = False,
|
| 98 |
-
per_device_train_batch_size = 4,
|
| 99 |
-
per_device_eval_batch_size = 4,
|
| 100 |
-
per_gpu_train_batch_size = None,
|
| 101 |
-
per_gpu_eval_batch_size = None,
|
| 102 |
-
gradient_accumulation_steps = 2,
|
| 103 |
-
eval_accumulation_steps = 2,
|
| 104 |
-
eval_delay = 0,
|
| 105 |
-
torch_empty_cache_steps = 250,
|
| 106 |
-
learning_rate = 5e-05,
|
| 107 |
-
weight_decay = 0.01,
|
| 108 |
-
adam_beta1 = 0.9,
|
| 109 |
-
adam_beta2 = 0.999,
|
| 110 |
-
adam_epsilon = 1e-08,
|
| 111 |
-
max_grad_norm = 1.0,
|
| 112 |
-
num_train_epochs = 3.0,
|
| 113 |
-
max_steps = -1,
|
| 114 |
-
lr_scheduler_type = 'linear',
|
| 115 |
-
warmup_ratio = 0.1,
|
| 116 |
-
warmup_steps = 0,
|
| 117 |
-
log_level = 'passive',
|
| 118 |
-
log_level_replica = 'warning',
|
| 119 |
-
log_on_each_node = True,
|
| 120 |
-
logging_dir = None,
|
| 121 |
-
logging_strategy = 'steps',
|
| 122 |
-
logging_first_step = False,
|
| 123 |
-
logging_steps = 1,
|
| 124 |
-
logging_nan_inf_filter = False,
|
| 125 |
-
save_strategy = 'steps',
|
| 126 |
-
save_steps = 500,
|
| 127 |
-
save_total_limit = None,
|
| 128 |
-
save_safetensors = True,
|
| 129 |
-
save_on_each_node = False,
|
| 130 |
-
save_only_model = False,
|
| 131 |
-
restore_callback_states_from_checkpoint = False,
|
| 132 |
-
no_cuda = False,
|
| 133 |
-
use_cpu = False,
|
| 134 |
-
use_mps_device = False,
|
| 135 |
-
seed = 3407,
|
| 136 |
-
data_seed = 3407,
|
| 137 |
-
jit_mode_eval = False,
|
| 138 |
-
use_ipex = False,
|
| 139 |
-
bf16 = False,
|
| 140 |
-
fp16 = False,
|
| 141 |
-
fp16_opt_level = 'O1',
|
| 142 |
-
half_precision_backend = 'auto',
|
| 143 |
-
bf16_full_eval = False,
|
| 144 |
-
fp16_full_eval = False,
|
| 145 |
-
tf32 = None,
|
| 146 |
-
local_rank = -1,
|
| 147 |
-
ddp_backend = None,
|
| 148 |
-
tpu_num_cores = None,
|
| 149 |
-
tpu_metrics_debug = False,
|
| 150 |
-
debug = '',
|
| 151 |
-
dataloader_drop_last = False,
|
| 152 |
-
eval_steps = None,
|
| 153 |
-
dataloader_num_workers = 0,
|
| 154 |
-
dataloader_prefetch_factor = None,
|
| 155 |
-
past_index = -1,
|
| 156 |
-
run_name = None,
|
| 157 |
-
disable_tqdm = None,
|
| 158 |
-
remove_unused_columns = True,
|
| 159 |
-
label_names = None,
|
| 160 |
-
load_best_model_at_end = False,
|
| 161 |
-
metric_for_best_model = None,
|
| 162 |
-
greater_is_better = None,
|
| 163 |
-
ignore_data_skip = False,
|
| 164 |
-
fsdp = '',
|
| 165 |
-
fsdp_min_num_params = 0,
|
| 166 |
-
fsdp_config = None,
|
| 167 |
-
fsdp_transformer_layer_cls_to_wrap = None,
|
| 168 |
-
accelerator_config = None,
|
| 169 |
-
deepspeed = None,
|
| 170 |
-
label_smoothing_factor = 0.0,
|
| 171 |
-
optim = 'adamw_8bit',
|
| 172 |
-
optim_args = None,
|
| 173 |
-
adafactor = False,
|
| 174 |
-
group_by_length = False,
|
| 175 |
-
length_column_name = 'length',
|
| 176 |
-
report_to = None,
|
| 177 |
-
ddp_find_unused_parameters = None,
|
| 178 |
-
ddp_bucket_cap_mb = None,
|
| 179 |
-
ddp_broadcast_buffers = None,
|
| 180 |
-
dataloader_pin_memory = True,
|
| 181 |
-
dataloader_persistent_workers = False,
|
| 182 |
-
skip_memory_metrics = True,
|
| 183 |
-
use_legacy_prediction_loop = False,
|
| 184 |
-
push_to_hub = False,
|
| 185 |
-
resume_from_checkpoint = None,
|
| 186 |
-
hub_model_id = None,
|
| 187 |
-
hub_strategy = 'every_save',
|
| 188 |
-
hub_token = None,
|
| 189 |
-
hub_private_repo = None,
|
| 190 |
-
hub_always_push = False,
|
| 191 |
-
hub_revision = None,
|
| 192 |
-
gradient_checkpointing = False,
|
| 193 |
-
gradient_checkpointing_kwargs = None,
|
| 194 |
-
include_inputs_for_metrics = False,
|
| 195 |
-
eval_do_concat_batches = True,
|
| 196 |
-
fp16_backend = 'auto',
|
| 197 |
-
push_to_hub_model_id = None,
|
| 198 |
-
push_to_hub_organization = None,
|
| 199 |
-
push_to_hub_token = None,
|
| 200 |
-
mp_parameters = '',
|
| 201 |
-
auto_find_batch_size = True,
|
| 202 |
-
full_determinism = False,
|
| 203 |
-
torchdynamo = None,
|
| 204 |
-
ray_scope = 'last',
|
| 205 |
-
ddp_timeout = 1800,
|
| 206 |
-
torch_compile = False,
|
| 207 |
-
torch_compile_backend = None,
|
| 208 |
-
torch_compile_mode = None,
|
| 209 |
-
include_tokens_per_second = False,
|
| 210 |
-
include_num_input_tokens_seen = False,
|
| 211 |
-
neftune_noise_alpha = None,
|
| 212 |
-
optim_target_modules = None,
|
| 213 |
-
batch_eval_metrics = False,
|
| 214 |
-
eval_on_start = False,
|
| 215 |
-
use_liger_kernel = False,
|
| 216 |
-
liger_kernel_config = None,
|
| 217 |
-
eval_use_gather_object = False,
|
| 218 |
-
average_tokens_across_devices = True,
|
| 219 |
-
model_init_kwargs = None,
|
| 220 |
-
dataset_text_field = 'text',
|
| 221 |
-
dataset_kwargs = None,
|
| 222 |
-
dataset_num_proc = None,
|
| 223 |
-
pad_token = None,
|
| 224 |
-
max_length = 1024,
|
| 225 |
-
packing = False,
|
| 226 |
-
padding_free = False,
|
| 227 |
-
eval_packing = None,
|
| 228 |
-
dataset_batch_size = None,
|
| 229 |
-
num_of_sequences = None,
|
| 230 |
-
chars_per_token = None,
|
| 231 |
-
max_seq_length = None,
|
| 232 |
-
use_liger = None,
|
| 233 |
-
temperature = 0.9,
|
| 234 |
-
lmbda = 0.5,
|
| 235 |
-
beta = 0.5,
|
| 236 |
-
max_new_tokens = 128,
|
| 237 |
-
teacher_model_name_or_path = None,
|
| 238 |
-
teacher_model_init_kwargs = None,
|
| 239 |
-
disable_dropout = True,
|
| 240 |
-
seq_kd = False,
|
| 241 |
-
vllm_sampling_params = None,
|
| 242 |
-
unsloth_num_chunks = -1,
|
| 243 |
-
**kwargs,
|
| 244 |
-
):
|
| 245 |
-
if learning_rate < 1e-7: raise FloatingPointError(f'Unsloth: Your learning rate of `{learning_rate}` is too small and less than 1e-7! Consider increasing it, otherwise gradient updates will be close to 0!')
|
| 246 |
-
if learning_rate > 1: raise OverflowError(f'Unsloth: Your learning rate of `{learning_rate}` is way too larger > 1! Consider decreasing it to 1e-1, otherwise gradient updates will explode!')
|
| 247 |
-
if output_dir is None and save_strategy == 'steps' and save_steps == 500:
|
| 248 |
-
output_dir = 'unsloth_training_checkpoints'
|
| 249 |
-
save_strategy = 'no'
|
| 250 |
-
if dataset_num_proc is None:
|
| 251 |
-
from multiprocessing import cpu_count
|
| 252 |
-
dataset_num_proc = min(cpu_count()*2, 2)
|
| 253 |
-
if temperature <= 0:
|
| 254 |
-
raise MathError('Unsloth: Please set a positive non-zero temperature since your results will be wrong.')
|
| 255 |
-
elif temperature >= 10:
|
| 256 |
-
raise MathError('Unsloth: Please set a positive non-zero temperature less than 10, since sampling will be quite erratic.')
|
| 257 |
-
|
| 258 |
-
|
| 259 |
-
super().__init__(
|
| 260 |
-
output_dir = output_dir,
|
| 261 |
-
overwrite_output_dir = overwrite_output_dir,
|
| 262 |
-
do_train = do_train,
|
| 263 |
-
do_eval = do_eval,
|
| 264 |
-
do_predict = do_predict,
|
| 265 |
-
eval_strategy = eval_strategy,
|
| 266 |
-
prediction_loss_only = prediction_loss_only,
|
| 267 |
-
per_device_train_batch_size = per_device_train_batch_size,
|
| 268 |
-
per_device_eval_batch_size = per_device_eval_batch_size,
|
| 269 |
-
per_gpu_train_batch_size = per_gpu_train_batch_size,
|
| 270 |
-
per_gpu_eval_batch_size = per_gpu_eval_batch_size,
|
| 271 |
-
gradient_accumulation_steps = gradient_accumulation_steps,
|
| 272 |
-
eval_accumulation_steps = eval_accumulation_steps,
|
| 273 |
-
eval_delay = eval_delay,
|
| 274 |
-
torch_empty_cache_steps = torch_empty_cache_steps,
|
| 275 |
-
learning_rate = learning_rate,
|
| 276 |
-
weight_decay = weight_decay,
|
| 277 |
-
adam_beta1 = adam_beta1,
|
| 278 |
-
adam_beta2 = adam_beta2,
|
| 279 |
-
adam_epsilon = adam_epsilon,
|
| 280 |
-
max_grad_norm = max_grad_norm,
|
| 281 |
-
num_train_epochs = num_train_epochs,
|
| 282 |
-
max_steps = max_steps,
|
| 283 |
-
lr_scheduler_type = lr_scheduler_type,
|
| 284 |
-
warmup_ratio = warmup_ratio,
|
| 285 |
-
warmup_steps = warmup_steps,
|
| 286 |
-
log_level = log_level,
|
| 287 |
-
log_level_replica = log_level_replica,
|
| 288 |
-
log_on_each_node = log_on_each_node,
|
| 289 |
-
logging_dir = logging_dir,
|
| 290 |
-
logging_strategy = logging_strategy,
|
| 291 |
-
logging_first_step = logging_first_step,
|
| 292 |
-
logging_steps = logging_steps,
|
| 293 |
-
logging_nan_inf_filter = logging_nan_inf_filter,
|
| 294 |
-
save_strategy = save_strategy,
|
| 295 |
-
save_steps = save_steps,
|
| 296 |
-
save_total_limit = save_total_limit,
|
| 297 |
-
save_safetensors = save_safetensors,
|
| 298 |
-
save_on_each_node = save_on_each_node,
|
| 299 |
-
save_only_model = save_only_model,
|
| 300 |
-
restore_callback_states_from_checkpoint = restore_callback_states_from_checkpoint,
|
| 301 |
-
no_cuda = no_cuda,
|
| 302 |
-
use_cpu = use_cpu,
|
| 303 |
-
use_mps_device = use_mps_device,
|
| 304 |
-
seed = seed,
|
| 305 |
-
data_seed = data_seed,
|
| 306 |
-
jit_mode_eval = jit_mode_eval,
|
| 307 |
-
use_ipex = use_ipex,
|
| 308 |
-
bf16 = bf16,
|
| 309 |
-
fp16 = fp16,
|
| 310 |
-
fp16_opt_level = fp16_opt_level,
|
| 311 |
-
half_precision_backend = half_precision_backend,
|
| 312 |
-
bf16_full_eval = bf16_full_eval,
|
| 313 |
-
fp16_full_eval = fp16_full_eval,
|
| 314 |
-
tf32 = tf32,
|
| 315 |
-
local_rank = local_rank,
|
| 316 |
-
ddp_backend = ddp_backend,
|
| 317 |
-
tpu_num_cores = tpu_num_cores,
|
| 318 |
-
tpu_metrics_debug = tpu_metrics_debug,
|
| 319 |
-
debug = debug,
|
| 320 |
-
dataloader_drop_last = dataloader_drop_last,
|
| 321 |
-
eval_steps = eval_steps,
|
| 322 |
-
dataloader_num_workers = dataloader_num_workers,
|
| 323 |
-
dataloader_prefetch_factor = dataloader_prefetch_factor,
|
| 324 |
-
past_index = past_index,
|
| 325 |
-
run_name = run_name,
|
| 326 |
-
disable_tqdm = disable_tqdm,
|
| 327 |
-
remove_unused_columns = remove_unused_columns,
|
| 328 |
-
label_names = label_names,
|
| 329 |
-
load_best_model_at_end = load_best_model_at_end,
|
| 330 |
-
metric_for_best_model = metric_for_best_model,
|
| 331 |
-
greater_is_better = greater_is_better,
|
| 332 |
-
ignore_data_skip = ignore_data_skip,
|
| 333 |
-
fsdp = fsdp,
|
| 334 |
-
fsdp_min_num_params = fsdp_min_num_params,
|
| 335 |
-
fsdp_config = fsdp_config,
|
| 336 |
-
fsdp_transformer_layer_cls_to_wrap = fsdp_transformer_layer_cls_to_wrap,
|
| 337 |
-
accelerator_config = accelerator_config,
|
| 338 |
-
deepspeed = deepspeed,
|
| 339 |
-
label_smoothing_factor = label_smoothing_factor,
|
| 340 |
-
optim = optim,
|
| 341 |
-
optim_args = optim_args,
|
| 342 |
-
adafactor = adafactor,
|
| 343 |
-
group_by_length = group_by_length,
|
| 344 |
-
length_column_name = length_column_name,
|
| 345 |
-
report_to = report_to,
|
| 346 |
-
ddp_find_unused_parameters = ddp_find_unused_parameters,
|
| 347 |
-
ddp_bucket_cap_mb = ddp_bucket_cap_mb,
|
| 348 |
-
ddp_broadcast_buffers = ddp_broadcast_buffers,
|
| 349 |
-
dataloader_pin_memory = dataloader_pin_memory,
|
| 350 |
-
dataloader_persistent_workers = dataloader_persistent_workers,
|
| 351 |
-
skip_memory_metrics = skip_memory_metrics,
|
| 352 |
-
use_legacy_prediction_loop = use_legacy_prediction_loop,
|
| 353 |
-
push_to_hub = push_to_hub,
|
| 354 |
-
resume_from_checkpoint = resume_from_checkpoint,
|
| 355 |
-
hub_model_id = hub_model_id,
|
| 356 |
-
hub_strategy = hub_strategy,
|
| 357 |
-
hub_token = hub_token,
|
| 358 |
-
hub_private_repo = hub_private_repo,
|
| 359 |
-
hub_always_push = hub_always_push,
|
| 360 |
-
hub_revision = hub_revision,
|
| 361 |
-
gradient_checkpointing = gradient_checkpointing,
|
| 362 |
-
gradient_checkpointing_kwargs = gradient_checkpointing_kwargs,
|
| 363 |
-
include_inputs_for_metrics = include_inputs_for_metrics,
|
| 364 |
-
eval_do_concat_batches = eval_do_concat_batches,
|
| 365 |
-
fp16_backend = fp16_backend,
|
| 366 |
-
push_to_hub_model_id = push_to_hub_model_id,
|
| 367 |
-
push_to_hub_organization = push_to_hub_organization,
|
| 368 |
-
push_to_hub_token = push_to_hub_token,
|
| 369 |
-
mp_parameters = mp_parameters,
|
| 370 |
-
auto_find_batch_size = auto_find_batch_size,
|
| 371 |
-
full_determinism = full_determinism,
|
| 372 |
-
torchdynamo = torchdynamo,
|
| 373 |
-
ray_scope = ray_scope,
|
| 374 |
-
ddp_timeout = ddp_timeout,
|
| 375 |
-
torch_compile = torch_compile,
|
| 376 |
-
torch_compile_backend = torch_compile_backend,
|
| 377 |
-
torch_compile_mode = torch_compile_mode,
|
| 378 |
-
include_tokens_per_second = include_tokens_per_second,
|
| 379 |
-
include_num_input_tokens_seen = include_num_input_tokens_seen,
|
| 380 |
-
neftune_noise_alpha = neftune_noise_alpha,
|
| 381 |
-
optim_target_modules = optim_target_modules,
|
| 382 |
-
batch_eval_metrics = batch_eval_metrics,
|
| 383 |
-
eval_on_start = eval_on_start,
|
| 384 |
-
use_liger_kernel = use_liger_kernel,
|
| 385 |
-
liger_kernel_config = liger_kernel_config,
|
| 386 |
-
eval_use_gather_object = eval_use_gather_object,
|
| 387 |
-
average_tokens_across_devices = average_tokens_across_devices,
|
| 388 |
-
model_init_kwargs = model_init_kwargs,
|
| 389 |
-
dataset_text_field = dataset_text_field,
|
| 390 |
-
dataset_kwargs = dataset_kwargs,
|
| 391 |
-
dataset_num_proc = dataset_num_proc,
|
| 392 |
-
pad_token = pad_token,
|
| 393 |
-
max_length = max_length,
|
| 394 |
-
packing = packing,
|
| 395 |
-
padding_free = padding_free,
|
| 396 |
-
eval_packing = eval_packing,
|
| 397 |
-
dataset_batch_size = dataset_batch_size,
|
| 398 |
-
num_of_sequences = num_of_sequences,
|
| 399 |
-
chars_per_token = chars_per_token,
|
| 400 |
-
max_seq_length = max_seq_length,
|
| 401 |
-
use_liger = use_liger,
|
| 402 |
-
temperature = temperature,
|
| 403 |
-
lmbda = lmbda,
|
| 404 |
-
beta = beta,
|
| 405 |
-
max_new_tokens = max_new_tokens,
|
| 406 |
-
teacher_model_name_or_path = teacher_model_name_or_path,
|
| 407 |
-
teacher_model_init_kwargs = teacher_model_init_kwargs,
|
| 408 |
-
disable_dropout = disable_dropout,
|
| 409 |
-
seq_kd = seq_kd,**kwargs)
|
| 410 |
-
self.vllm_sampling_params = vllm_sampling_params
|
| 411 |
-
self.unsloth_num_chunks = unsloth_num_chunks
|
| 412 |
-
pass
|
| 413 |
-
|
| 414 |
-
class _UnslothGKDTrainer(SFTTrainer):
|
| 415 |
-
_tag_names = ["trl", "gkd"]
|
| 416 |
-
|
| 417 |
-
def __init__(
|
| 418 |
-
self,
|
| 419 |
-
model: Optional[Union[PreTrainedModel, nn.Module, str]] = None,
|
| 420 |
-
teacher_model: Union[PreTrainedModel, nn.Module, str] = None,
|
| 421 |
-
args: Optional[GKDConfig] = None,
|
| 422 |
-
data_collator: Optional[DataCollator] = None, # type: ignore
|
| 423 |
-
train_dataset: Optional[Dataset] = None,
|
| 424 |
-
eval_dataset: Optional[Union[Dataset, dict[str, Dataset]]] = None,
|
| 425 |
-
processing_class: Optional[
|
| 426 |
-
Union[PreTrainedTokenizerBase, BaseImageProcessor, FeatureExtractionMixin, ProcessorMixin]
|
| 427 |
-
] = None,
|
| 428 |
-
compute_metrics: Optional[Callable[[EvalPrediction], dict]] = None,
|
| 429 |
-
callbacks: Optional[list[TrainerCallback]] = None,
|
| 430 |
-
optimizers: tuple[torch.optim.Optimizer, torch.optim.lr_scheduler.LambdaLR] = (None, None),
|
| 431 |
-
preprocess_logits_for_metrics: Optional[Callable[[torch.Tensor, torch.Tensor], torch.Tensor]] = None,
|
| 432 |
-
peft_config: Optional["PeftConfig"] = None,
|
| 433 |
-
formatting_func: Optional[Callable] = None,
|
| 434 |
-
):
|
| 435 |
-
# add remove_unused_columns=False to the dataclass args
|
| 436 |
-
args.remove_unused_columns = False
|
| 437 |
-
data_collator = DataCollatorForChatML(tokenizer=processing_class, max_length=args.max_length)
|
| 438 |
-
|
| 439 |
-
super().__init__(
|
| 440 |
-
model,
|
| 441 |
-
args=args,
|
| 442 |
-
data_collator=data_collator,
|
| 443 |
-
train_dataset=train_dataset,
|
| 444 |
-
eval_dataset=eval_dataset,
|
| 445 |
-
processing_class=processing_class,
|
| 446 |
-
compute_metrics=compute_metrics,
|
| 447 |
-
callbacks=callbacks,
|
| 448 |
-
optimizers=optimizers,
|
| 449 |
-
preprocess_logits_for_metrics=preprocess_logits_for_metrics,
|
| 450 |
-
peft_config=peft_config,
|
| 451 |
-
formatting_func=formatting_func,
|
| 452 |
-
)
|
| 453 |
-
|
| 454 |
-
if args.teacher_model_init_kwargs is None:
|
| 455 |
-
teacher_model_init_kwargs = {}
|
| 456 |
-
elif not isinstance(teacher_model, str):
|
| 457 |
-
raise ValueError(
|
| 458 |
-
"You passed teacher_model_init_kwargs to the GKDConfig, but your teacher_model is already instantiated."
|
| 459 |
-
)
|
| 460 |
-
else:
|
| 461 |
-
teacher_model_init_kwargs = args.teacher_model_init_kwargs
|
| 462 |
-
teacher_model_init_kwargs["torch_dtype"] = (
|
| 463 |
-
teacher_model_init_kwargs["torch_dtype"]
|
| 464 |
-
if teacher_model_init_kwargs["torch_dtype"] in ["auto", None]
|
| 465 |
-
else getattr(torch, teacher_model_init_kwargs["torch_dtype"])
|
| 466 |
-
)
|
| 467 |
-
|
| 468 |
-
if isinstance(teacher_model, str):
|
| 469 |
-
teacher_model = AutoModelForCausalLM.from_pretrained(teacher_model, **teacher_model_init_kwargs)
|
| 470 |
-
|
| 471 |
-
# Disable dropout in the model
|
| 472 |
-
if args.disable_dropout:
|
| 473 |
-
disable_dropout_in_model(self.model)
|
| 474 |
-
|
| 475 |
-
if self.is_deepspeed_enabled:
|
| 476 |
-
self.teacher_model = self._prepare_deepspeed(teacher_model)
|
| 477 |
-
else:
|
| 478 |
-
self.teacher_model = self.accelerator.prepare_model(teacher_model, evaluation_mode=True)
|
| 479 |
-
|
| 480 |
-
self.lmbda = args.lmbda
|
| 481 |
-
self.beta = args.beta
|
| 482 |
-
self.temperature = args.temperature
|
| 483 |
-
self.seq_kd = args.seq_kd
|
| 484 |
-
|
| 485 |
-
self.generation_config = GenerationConfig(
|
| 486 |
-
max_new_tokens=args.max_new_tokens,
|
| 487 |
-
temperature=args.temperature,
|
| 488 |
-
do_sample=True,
|
| 489 |
-
top_k=0,
|
| 490 |
-
use_cache=False if args.gradient_checkpointing else True,
|
| 491 |
-
pad_token_id=self.processing_class.pad_token_id,
|
| 492 |
-
)
|
| 493 |
-
# Set custom EOS tokens if they are specified by the model's generation
|
| 494 |
-
# config. This is important for models with the Llama 3 chat template,
|
| 495 |
-
# which use special tokens <|eot_id|> and <|eom_id|> to mark the end of
|
| 496 |
-
# turns or messages.
|
| 497 |
-
if (
|
| 498 |
-
hasattr(self.model.generation_config, "eos_token_id")
|
| 499 |
-
and self.model.generation_config.eos_token_id is not None
|
| 500 |
-
):
|
| 501 |
-
self.generation_config.eos_token_id = self.model.generation_config.eos_token_id
|
| 502 |
-
|
| 503 |
-
def _prepare_dataset(self, dataset, *args):
|
| 504 |
-
# SFTTrainer._prepare_dataset() applies the chat template and rename the messages column to text. However, we
|
| 505 |
-
# need to keep the messages column as it is. We use the following workaround to keep the messages column.
|
| 506 |
-
dataset = dataset.add_column("_messages", dataset["messages"])
|
| 507 |
-
dataset = super()._prepare_dataset(dataset, *args)
|
| 508 |
-
dataset = dataset.rename_column("_messages", "messages")
|
| 509 |
-
return dataset
|
| 510 |
-
|
| 511 |
-
@staticmethod
|
| 512 |
-
def generalized_jsd_loss(
|
| 513 |
-
student_logits, teacher_logits, labels=None, beta=0.5, temperature=1.0, reduction="batchmean"
|
| 514 |
-
):
|
| 515 |
-
"""
|
| 516 |
-
Compute the generalized Jensen-Shannon Divergence loss for knowledge distillation using F.kl_div. See Eq. (1)
|
| 517 |
-
of https://huggingface.co/papers/2306.13649 for the definition.
|
| 518 |
-
|
| 519 |
-
Args:
|
| 520 |
-
student_logits: Tensor of shape (batch_size, sequence_length, vocab_size)
|
| 521 |
-
teacher_logits: Tensor of shape (batch_size, sequence_length, vocab_size)
|
| 522 |
-
labels: Tensor of shape (batch_size, sequence_length) with -100 for padding tokens to ignore when computing loss
|
| 523 |
-
beta: Interpolation coefficient between 0 and 1 (default: 0.5)
|
| 524 |
-
temperature: Softmax temperature (default: 1.0)
|
| 525 |
-
reduction: Specifies the reduction to apply to the output (default: 'batchmean')
|
| 526 |
-
|
| 527 |
-
Returns:
|
| 528 |
-
loss: Scalar tensor with the generalized JSD loss
|
| 529 |
-
"""
|
| 530 |
-
|
| 531 |
-
# Apply temperature scaling
|
| 532 |
-
student_logits = student_logits / temperature
|
| 533 |
-
teacher_logits = teacher_logits / temperature
|
| 534 |
-
|
| 535 |
-
# Compute log probabilities for student and probabilities for teacher
|
| 536 |
-
student_log_probs = F.log_softmax(student_logits, dim=-1)
|
| 537 |
-
teacher_log_probs = F.log_softmax(teacher_logits, dim=-1)
|
| 538 |
-
|
| 539 |
-
if beta == 0:
|
| 540 |
-
jsd = F.kl_div(student_log_probs, teacher_log_probs, reduction="none", log_target=True)
|
| 541 |
-
elif beta == 1:
|
| 542 |
-
jsd = F.kl_div(teacher_log_probs, student_log_probs, reduction="none", log_target=True)
|
| 543 |
-
else:
|
| 544 |
-
# Compute the log of the mixture distribution
|
| 545 |
-
# log(a + b) = log(exp(log(a)) + exp(log(b))) -> for mixture
|
| 546 |
-
beta = torch.tensor(beta, dtype=student_log_probs.dtype)
|
| 547 |
-
mixture_log_probs = torch.logsumexp(
|
| 548 |
-
torch.stack([student_log_probs + torch.log(1 - beta), teacher_log_probs + torch.log(beta)]),
|
| 549 |
-
dim=0,
|
| 550 |
-
)
|
| 551 |
-
|
| 552 |
-
# Compute KL divergences using F.kl_div
|
| 553 |
-
# PyTorch differs from the standard mathematical definition, so the order of the probability distributions is swapped compared to that defined in the paper.
|
| 554 |
-
kl_teacher = F.kl_div(mixture_log_probs, teacher_log_probs, reduction="none", log_target=True)
|
| 555 |
-
kl_student = F.kl_div(mixture_log_probs, student_log_probs, reduction="none", log_target=True)
|
| 556 |
-
|
| 557 |
-
# Compute the Generalized Jensen-Shannon Divergence
|
| 558 |
-
jsd = beta * kl_teacher + (1 - beta) * kl_student
|
| 559 |
-
|
| 560 |
-
# Masking
|
| 561 |
-
if labels is not None:
|
| 562 |
-
mask = labels != -100
|
| 563 |
-
jsd = jsd[mask]
|
| 564 |
-
|
| 565 |
-
# Apply reduction
|
| 566 |
-
if reduction == "batchmean":
|
| 567 |
-
return jsd.sum() / mask.sum() if labels is not None else jsd.sum() / (jsd.size(0) * jsd.size(1))
|
| 568 |
-
elif reduction == "sum":
|
| 569 |
-
return jsd.sum()
|
| 570 |
-
elif reduction == "mean":
|
| 571 |
-
return jsd.mean()
|
| 572 |
-
else:
|
| 573 |
-
return jsd
|
| 574 |
-
|
| 575 |
-
def compute_loss(self, model, inputs, return_outputs=False, num_items_in_batch=None):
|
| 576 |
-
# compute student output
|
| 577 |
-
outputs_student = model(
|
| 578 |
-
input_ids=inputs["input_ids"],
|
| 579 |
-
attention_mask=inputs["attention_mask"],
|
| 580 |
-
)
|
| 581 |
-
|
| 582 |
-
# compute teacher output in eval mode
|
| 583 |
-
self.teacher_model.eval()
|
| 584 |
-
with torch.no_grad():
|
| 585 |
-
outputs_teacher = self.teacher_model(
|
| 586 |
-
input_ids=inputs["input_ids"],
|
| 587 |
-
attention_mask=inputs["attention_mask"],
|
| 588 |
-
)
|
| 589 |
-
|
| 590 |
-
# slice the logits for the generated tokens using the inputs["prompts"] lengths
|
| 591 |
-
prompt_lengths = inputs["prompts"].shape[1]
|
| 592 |
-
shifted_student_logits = outputs_student.logits[:, prompt_lengths - 1 : -1, :]
|
| 593 |
-
shifted_teacher_logits = outputs_teacher.logits[:, prompt_lengths - 1 : -1, :]
|
| 594 |
-
shifted_labels = inputs["labels"][:, prompt_lengths:]
|
| 595 |
-
|
| 596 |
-
# compute loss
|
| 597 |
-
loss = self.generalized_jsd_loss(
|
| 598 |
-
student_logits=shifted_student_logits,
|
| 599 |
-
teacher_logits=shifted_teacher_logits,
|
| 600 |
-
labels=shifted_labels,
|
| 601 |
-
beta=self.beta,
|
| 602 |
-
)
|
| 603 |
-
|
| 604 |
-
# empty cache
|
| 605 |
-
empty_cache()
|
| 606 |
-
|
| 607 |
-
# Return loss
|
| 608 |
-
return (loss, outputs_student) if return_outputs else loss
|
| 609 |
-
|
| 610 |
-
@staticmethod
|
| 611 |
-
def generate_on_policy_outputs(model, inputs, generation_config, pad_token_id=None):
|
| 612 |
-
# Generate output with respect to the prompt only
|
| 613 |
-
generated_outputs = model.generate(
|
| 614 |
-
input_ids=inputs["prompts"],
|
| 615 |
-
attention_mask=inputs.get("prompt_attention_mask", None),
|
| 616 |
-
generation_config=generation_config,
|
| 617 |
-
return_dict_in_generate=True,
|
| 618 |
-
)
|
| 619 |
-
|
| 620 |
-
# Get the generated token IDs
|
| 621 |
-
generated_tokens = generated_outputs.sequences
|
| 622 |
-
# Calculate new attention mask
|
| 623 |
-
new_attention_mask = torch.ones_like(generated_tokens)
|
| 624 |
-
new_labels = generated_tokens.clone()
|
| 625 |
-
|
| 626 |
-
# If there's pad_token_id, set attention mask to 0 for padding tokens
|
| 627 |
-
if pad_token_id is not None:
|
| 628 |
-
new_labels[new_labels == pad_token_id] = -100
|
| 629 |
-
new_attention_mask[generated_tokens == pad_token_id] = 0
|
| 630 |
-
|
| 631 |
-
return generated_tokens, new_attention_mask, new_labels
|
| 632 |
-
|
| 633 |
-
def training_step(
|
| 634 |
-
self, model: nn.Module, inputs: dict[str, Union[torch.Tensor, Any]], num_items_in_batch: Optional[int] = None
|
| 635 |
-
) -> torch.Tensor:
|
| 636 |
-
"""
|
| 637 |
-
Perform a training step for the Generalized Knowledge Distillation (GKD) model.
|
| 638 |
-
|
| 639 |
-
This method implements the on-policy learning approach described in the GKD paper.
|
| 640 |
-
With probability `self.lmbda`, it generates new responses using the student model,
|
| 641 |
-
which are then used for training instead of the original inputs.
|
| 642 |
-
"""
|
| 643 |
-
if self.seq_kd:
|
| 644 |
-
with unwrap_model_for_generation(self.teacher_model, self.accelerator) as unwrapped_model:
|
| 645 |
-
new_input_ids, new_attention_mask, new_labels = self.generate_on_policy_outputs(
|
| 646 |
-
unwrapped_model, inputs, self.generation_config, self.processing_class.pad_token_id
|
| 647 |
-
)
|
| 648 |
-
inputs["input_ids"] = new_input_ids
|
| 649 |
-
inputs["attention_mask"] = new_attention_mask
|
| 650 |
-
inputs["labels"] = new_labels
|
| 651 |
-
if random.random() <= self.lmbda:
|
| 652 |
-
with unwrap_model_for_generation(model, self.accelerator) as unwrapped_model:
|
| 653 |
-
new_input_ids, new_attention_mask, new_labels = self.generate_on_policy_outputs(
|
| 654 |
-
unwrapped_model, inputs, self.generation_config, self.processing_class.pad_token_id
|
| 655 |
-
)
|
| 656 |
-
inputs["input_ids"] = new_input_ids
|
| 657 |
-
inputs["attention_mask"] = new_attention_mask
|
| 658 |
-
inputs["labels"] = new_labels
|
| 659 |
-
|
| 660 |
-
loss = super().training_step(model, inputs, num_items_in_batch)
|
| 661 |
-
return loss
|
| 662 |
-
|
| 663 |
-
def _prepare_deepspeed(self, model: PreTrainedModelWrapper):
|
| 664 |
-
# Adapted from accelerate: https://github.com/huggingface/accelerate/blob/739b135f8367becb67ffaada12fe76e3aa60fefd/src/accelerate/accelerator.py#L1473
|
| 665 |
-
deepspeed_plugin = self.accelerator.state.deepspeed_plugin
|
| 666 |
-
config_kwargs = deepcopy(deepspeed_plugin.deepspeed_config)
|
| 667 |
-
|
| 668 |
-
if model is not None:
|
| 669 |
-
if hasattr(model, "config"):
|
| 670 |
-
hidden_size = (
|
| 671 |
-
max(model.config.hidden_sizes)
|
| 672 |
-
if getattr(model.config, "hidden_sizes", None)
|
| 673 |
-
else getattr(model.config, "hidden_size", None)
|
| 674 |
-
)
|
| 675 |
-
if hidden_size is not None and config_kwargs["zero_optimization"]["stage"] == 3:
|
| 676 |
-
# Note that `stage3_prefetch_bucket_size` can produce DeepSpeed messages like: `Invalidate trace cache @ step 0: expected module 1, but got module 0`
|
| 677 |
-
# This is expected and is not an error, see: https://github.com/microsoft/DeepSpeed/discussions/4081
|
| 678 |
-
config_kwargs.update(
|
| 679 |
-
{
|
| 680 |
-
"zero_optimization.reduce_bucket_size": hidden_size * hidden_size,
|
| 681 |
-
"zero_optimization.stage3_param_persistence_threshold": 10 * hidden_size,
|
| 682 |
-
"zero_optimization.stage3_prefetch_bucket_size": 0.9 * hidden_size * hidden_size,
|
| 683 |
-
}
|
| 684 |
-
)
|
| 685 |
-
|
| 686 |
-
# If ZeRO-3 is used, we shard both the active and reference model.
|
| 687 |
-
# Otherwise, we assume the reference model fits in memory and is initialized on each device with ZeRO disabled (stage 0)
|
| 688 |
-
if config_kwargs["zero_optimization"]["stage"] != 3:
|
| 689 |
-
config_kwargs["zero_optimization"]["stage"] = 0
|
| 690 |
-
model, *_ = deepspeed.initialize(model=model, config=config_kwargs)
|
| 691 |
-
model.eval()
|
| 692 |
-
return model
|
| 693 |
-
|
| 694 |
-
def create_model_card(
|
| 695 |
-
self,
|
| 696 |
-
model_name: Optional[str] = None,
|
| 697 |
-
dataset_name: Optional[str] = None,
|
| 698 |
-
tags: Union[str, list[str], None] = None,
|
| 699 |
-
):
|
| 700 |
-
"""
|
| 701 |
-
Creates a draft of a model card using the information available to the `Trainer`.
|
| 702 |
-
|
| 703 |
-
Args:
|
| 704 |
-
model_name (`str` or `None`, *optional*, defaults to `None`):
|
| 705 |
-
Name of the model.
|
| 706 |
-
dataset_name (`str` or `None`, *optional*, defaults to `None`):
|
| 707 |
-
Name of the dataset used for training.
|
| 708 |
-
tags (`str`, `list[str]` or `None`, *optional*, defaults to `None`):
|
| 709 |
-
Tags to be associated with the model card.
|
| 710 |
-
"""
|
| 711 |
-
if not self.is_world_process_zero():
|
| 712 |
-
return
|
| 713 |
-
|
| 714 |
-
if hasattr(self.model.config, "_name_or_path") and not os.path.isdir(self.model.config._name_or_path):
|
| 715 |
-
base_model = self.model.config._name_or_path
|
| 716 |
-
else:
|
| 717 |
-
base_model = None
|
| 718 |
-
|
| 719 |
-
tags = tags or []
|
| 720 |
-
if isinstance(tags, str):
|
| 721 |
-
tags = [tags]
|
| 722 |
-
|
| 723 |
-
if hasattr(self.model.config, "unsloth_version"):
|
| 724 |
-
tags.append("unsloth")
|
| 725 |
-
|
| 726 |
-
citation = textwrap.dedent("""\
|
| 727 |
-
@inproceedings{agarwal2024on-policy,
|
| 728 |
-
title = {{On-Policy Distillation of Language Models: Learning from Self-Generated Mistakes}},
|
| 729 |
-
author = {Rishabh Agarwal and Nino Vieillard and Yongchao Zhou and Piotr Stanczyk and Sabela Ramos Garea and Matthieu Geist and Olivier Bachem},
|
| 730 |
-
year = 2024,
|
| 731 |
-
booktitle = {The Twelfth International Conference on Learning Representations, {ICLR} 2024, Vienna, Austria, May 7-11, 2024},
|
| 732 |
-
publisher = {OpenReview.net},
|
| 733 |
-
url = {https://openreview.net/forum?id=3zKtaqxLhW},
|
| 734 |
-
}""")
|
| 735 |
-
|
| 736 |
-
model_card = generate_model_card(
|
| 737 |
-
base_model=base_model,
|
| 738 |
-
model_name=model_name,
|
| 739 |
-
hub_model_id=self.hub_model_id,
|
| 740 |
-
dataset_name=dataset_name,
|
| 741 |
-
tags=tags,
|
| 742 |
-
wandb_url=wandb.run.get_url() if is_wandb_available() and wandb.run is not None else None,
|
| 743 |
-
comet_url=get_comet_experiment_url(),
|
| 744 |
-
trainer_name="GKD",
|
| 745 |
-
trainer_citation=citation,
|
| 746 |
-
paper_title="On-Policy Distillation of Language Models: Learning from Self-Generated Mistakes",
|
| 747 |
-
paper_id="2306.13649",
|
| 748 |
-
)
|
| 749 |
-
|
| 750 |
-
model_card.save(os.path.join(self.args.output_dir, "README.md"))
|
| 751 |
-
class UnslothGKDTrainer(_UnslothGKDTrainer):
|
| 752 |
-
"""
|
| 753 |
-
|
| 754 |
-
"""
|
| 755 |
-
def __init__(
|
| 756 |
-
self,
|
| 757 |
-
model = None,
|
| 758 |
-
teacher_model = None,
|
| 759 |
-
args = None,
|
| 760 |
-
data_collator = None,
|
| 761 |
-
train_dataset = None,
|
| 762 |
-
eval_dataset = None,
|
| 763 |
-
processing_class = None,
|
| 764 |
-
compute_metrics = None,
|
| 765 |
-
callbacks = None,
|
| 766 |
-
preprocess_logits_for_metrics = None,
|
| 767 |
-
peft_config = None,
|
| 768 |
-
formatting_func = None,
|
| 769 |
-
**kwargs
|
| 770 |
-
):
|
| 771 |
-
if args is None: args = UnslothGKDConfig()
|
| 772 |
-
use_bf16 = getattr(args, 'bf16', False)
|
| 773 |
-
if type(use_bf16) is not bool: use_bf16 = False
|
| 774 |
-
use_fp16 = getattr(args, 'fp16', False)
|
| 775 |
-
if type(use_fp16) is not bool: use_fp16 = False
|
| 776 |
-
force_float32 = False
|
| 777 |
-
if os.environ.get('UNSLOTH_FORCE_FLOAT32', '0') == '1':
|
| 778 |
-
print('Unsloth: Switching to float32 training since model cannot work with float16')
|
| 779 |
-
force_float32 = True
|
| 780 |
-
mixed_precision_dtype = os.environ.get('UNSLOTH_MIXED_PRECISION', 'float32')
|
| 781 |
-
dtype = getattr(model.config, 'torch_dtype', None)
|
| 782 |
-
if dtype is None: dtype = model.get_input_embeddings().dtype
|
| 783 |
-
from unsloth_zoo.utils import _get_dtype
|
| 784 |
-
dtype = _get_dtype(dtype)
|
| 785 |
-
float16 = dtype == torch.float16
|
| 786 |
-
if not force_float32 and (float16 and use_bf16): raise TypeError('Unsloth: Model is in float16 precision but you want to use bfloat16 precision. Set fp16 to `True` and bf16 to `False`')
|
| 787 |
-
if not force_float32 and (not float16 and use_fp16): raise TypeError('Unsloth: Model is in bfloat16 precision but you want to use float16 precision. Set fp16 to `False` and bf16 to `True`')
|
| 788 |
-
if force_float32:
|
| 789 |
-
args.fp16 = False
|
| 790 |
-
args.bf16 = False
|
| 791 |
-
os.environ['ACCELERATE_MIXED_PRECISION'] = 'no'
|
| 792 |
-
elif (not use_bf16 and not use_fp16) and mixed_precision_dtype == 'float32':
|
| 793 |
-
args.fp16 = float16
|
| 794 |
-
args.bf16 = not float16
|
| 795 |
-
os.environ['ACCELERATE_MIXED_PRECISION'] = 'fp16' if float16 else 'bf16'
|
| 796 |
-
if getattr(args, 'eval_dataset', None) is not None and getattr(args, 'eval_strategy', 'no') == 'no':
|
| 797 |
-
args.eval_strategy = 'steps'
|
| 798 |
-
if getattr(args, 'eval_steps', None) is None: args.eval_steps = 0.1
|
| 799 |
-
ga_steps = getattr(args, 'gradient_accumulation_steps', None)
|
| 800 |
-
if ga_steps is not None and ga_steps > 1:
|
| 801 |
-
from transformers import __version__ as transformers_version
|
| 802 |
-
if Version(transformers_version) <= Version('4.45.2'):
|
| 803 |
-
print('**** Unsloth: Please use our fixed gradient_accumulation_steps by updating transformers, TRL and Unsloth!\n'
|
| 804 |
-
'`pip install --upgrade --no-cache-dir --force-reinstall --no-deps unsloth transformers trl unsloth_zoo`')
|
| 805 |
-
if getattr(args, 'eval_strategy', 'no') != 'no':
|
| 806 |
-
eval_bsz = getattr(args, 'per_device_eval_batch_size', 8)
|
| 807 |
-
if eval_bsz == 8 and args.per_device_train_batch_size < eval_bsz: args.per_device_eval_batch_size = args.per_device_train_batch_size
|
| 808 |
-
if getattr(args, 'eval_accumulation_steps', None) is None and ga_steps is not None: args.eval_accumulation_steps = ga_steps
|
| 809 |
-
fp16_full_eval = getattr(args, 'fp16_full_eval', False)
|
| 810 |
-
if type(fp16_full_eval) is not bool: fp16_full_eval = False
|
| 811 |
-
bf16_full_eval = getattr(args, 'bf16_full_eval', False)
|
| 812 |
-
if type(bf16_full_eval) is not bool: bf16_full_eval = False
|
| 813 |
-
if args.fp16 and bf16_full_eval: args.bf16_full_eval = False; args.fp16_full_eval = True
|
| 814 |
-
if args.bf16 and fp16_full_eval: args.bf16_full_eval = True; args.fp16_full_eval = False
|
| 815 |
-
if force_float32:
|
| 816 |
-
args.bf16_full_eval = False
|
| 817 |
-
args.fp16_full_eval = False
|
| 818 |
-
elif os.environ.get('UNSLOTH_MIXED_PRECISION', 'float32') == 'bfloat16':
|
| 819 |
-
args.bf16_full_eval = True
|
| 820 |
-
args.fp16_full_eval = False
|
| 821 |
-
elif not bf16_full_eval and not fp16_full_eval:
|
| 822 |
-
args.bf16_full_eval = args.bf16
|
| 823 |
-
args.fp16_full_eval = args.fp16
|
| 824 |
-
_output_logits = False
|
| 825 |
-
if locals().get('compute_metrics', None) is not None: _output_logits = True
|
| 826 |
-
if locals().get('preprocess_logits_for_metrics', None) is not None: _output_logits = True
|
| 827 |
-
if _output_logits:
|
| 828 |
-
os.environ['UNSLOTH_RETURN_LOGITS'] = '1'
|
| 829 |
-
if 'max_seq_length' not in locals() and not hasattr(args, 'max_seq_length'):
|
| 830 |
-
pass
|
| 831 |
-
else:
|
| 832 |
-
model_max_seq_length = getattr(model, 'max_seq_length', None)
|
| 833 |
-
args_max_seq_length = getattr(args, 'max_seq_length', None)
|
| 834 |
-
if args_max_seq_length is None and model_max_seq_length is not None:
|
| 835 |
-
max_seq_length = model.max_seq_length
|
| 836 |
-
if hasattr(args, 'max_seq_length'): args.max_seq_length = max_seq_length
|
| 837 |
-
if model is not None and hasattr(model, 'for_training'):
|
| 838 |
-
model.for_training()
|
| 839 |
-
if 'tokenizer' in locals() and hasattr(tokenizer, 'padding_side'): tokenizer.padding_side = 'right'
|
| 840 |
-
if 'processing_class' in locals():
|
| 841 |
-
if hasattr(processing_class, 'padding_side'): processing_class.padding_side = 'right'
|
| 842 |
-
if hasattr(processing_class, 'tokenizer') and hasattr(processing_class.tokenizer, 'padding_side'): processing_class.tokenizer.padding_side = 'right'
|
| 843 |
-
__tokenizer = processing_class if 'processing_class' in locals() else tokenizer
|
| 844 |
-
from unsloth_zoo.vision_utils import UnslothVisionDataCollator
|
| 845 |
-
if not isinstance(data_collator, UnslothVisionDataCollator):
|
| 846 |
-
if isinstance(data_collator, DataCollatorForSeq2Seq) and 'labels' not in train_dataset.column_names:
|
| 847 |
-
data_collator = TransformersDataCollatorForLanguageModeling(__tokenizer, mlm = False, mlm_probability = 0.0)
|
| 848 |
-
elif isinstance(data_collator, TransformersDataCollatorForLanguageModeling) and 'labels' in train_dataset.column_names:
|
| 849 |
-
data_collator = DataCollatorForSeq2Seq(__tokenizer)
|
| 850 |
-
else:
|
| 851 |
-
if hasattr(args, 'remove_unused_columns'): args.remove_unused_columns = False
|
| 852 |
-
if hasattr(args, 'dataset_text_field'): args.dataset_text_field = ''
|
| 853 |
-
if hasattr(args, 'dataset_kwargs'): args.dataset_kwargs = {'skip_prepare_dataset': True}
|
| 854 |
-
if not isinstance(data_collator, UnslothVisionDataCollator):
|
| 855 |
-
if not hasattr(__tokenizer, 'pad') and hasattr(__tokenizer, 'tokenizer'):
|
| 856 |
-
if isinstance(data_collator, DataCollatorForSeq2Seq):
|
| 857 |
-
data_collator = DataCollatorForSeq2Seq(__tokenizer.tokenizer)
|
| 858 |
-
else:
|
| 859 |
-
data_collator = TransformersDataCollatorForLanguageModeling(__tokenizer.tokenizer, mlm = False, mlm_probability = 0.0)
|
| 860 |
-
other_metrics = []
|
| 861 |
-
|
| 862 |
-
from unsloth_zoo.logging_utils import PatchRLStatistics
|
| 863 |
-
PatchRLStatistics('gkd_trainer', other_metrics)
|
| 864 |
-
|
| 865 |
-
super().__init__(
|
| 866 |
-
model = model,
|
| 867 |
-
teacher_model = teacher_model,
|
| 868 |
-
args = args,
|
| 869 |
-
data_collator = data_collator,
|
| 870 |
-
train_dataset = train_dataset,
|
| 871 |
-
eval_dataset = eval_dataset,
|
| 872 |
-
processing_class = processing_class,
|
| 873 |
-
compute_metrics = compute_metrics,
|
| 874 |
-
callbacks = callbacks,
|
| 875 |
-
preprocess_logits_for_metrics = preprocess_logits_for_metrics,
|
| 876 |
-
peft_config = peft_config,
|
| 877 |
-
formatting_func = formatting_func,**kwargs)
|
| 878 |
-
if hasattr(self, 'neftune_hook_handle'):
|
| 879 |
-
self.neftune_hook_handle.remove()
|
| 880 |
-
if hasattr(self, 'neftune_hook_handle'): del self.neftune_hook_handle
|
| 881 |
-
if getattr(args, 'neftune_noise_alpha', None) is not None:
|
| 882 |
-
model.get_input_embeddings().neftune_noise_alpha = self.neftune_noise_alpha
|
| 883 |
-
pass
|
| 884 |
-
|
| 885 |
-
pass
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
test_run_uploads/UnslothGRPOTrainer.py
DELETED
|
The diff for this file is too large to render.
See raw diff
|
|
|
test_run_uploads/UnslothKTOTrainer.py
DELETED
|
@@ -1,1849 +0,0 @@
|
|
| 1 |
-
"""
|
| 2 |
-
2025.7.11
|
| 3 |
-
2025.7.11
|
| 4 |
-
4.54.1
|
| 5 |
-
0.16.1
|
| 6 |
-
__UNSLOTH_VERSIONING__
|
| 7 |
-
"""
|
| 8 |
-
from torch import Tensor
|
| 9 |
-
import torch
|
| 10 |
-
import torch.nn as nn
|
| 11 |
-
from torch.nn import functional as F
|
| 12 |
-
from typing import Any, List, Optional, Tuple, Union, Dict, Set, Callable
|
| 13 |
-
from trl.trainer.kto_trainer import (Any, AutoModelForCausalLM, BaseImageProcessor, Callable, DPODataCollatorWithPadding, DataCollator, DataLoader, Dataset, EvalLoopOutput, F, FeatureExtractionMixin, KTOConfig, KTOTrainer, Literal, Optional, PartialState, PeftModel, PreTrainedModel, PreTrainedModelWrapper, PreTrainedTokenizerBase, ProcessorMixin, SequentialSampler, Trainer, TrainerCallback, TrainingArguments, Union, _get_kl_dataset, _process_tokens, _tokenize, amp, concatenate_datasets, contextmanager, create_reference_model, deepcopy, defaultdict, disable_dropout_in_model, generate_model_card, get_comet_experiment_url, has_length, inspect, is_comet_available, is_peft_available, is_wandb_available, itemgetter, log_table_to_comet_experiment, maybe_apply_chat_template, maybe_extract_prompt, maybe_unpair_preference_dataset, nn, np, nullcontext, os, pad_to_length, pd, peft_module_casting_to_bf16, prepare_model_for_kbit_training, random, selective_log_softmax, textwrap, torch, tqdm, transformers, version, wandb, warnings)
|
| 14 |
-
|
| 15 |
-
|
| 16 |
-
import os
|
| 17 |
-
from typing import *
|
| 18 |
-
from dataclasses import dataclass, field
|
| 19 |
-
from packaging.version import Version
|
| 20 |
-
import torch
|
| 21 |
-
import numpy as np
|
| 22 |
-
from contextlib import nullcontext
|
| 23 |
-
from torch.nn import functional as F
|
| 24 |
-
from transformers import DataCollatorForSeq2Seq, DataCollatorForLanguageModeling as TransformersDataCollatorForLanguageModeling
|
| 25 |
-
|
| 26 |
-
torch_compile_options = {
|
| 27 |
-
"epilogue_fusion" : True,
|
| 28 |
-
"max_autotune" : False,
|
| 29 |
-
"shape_padding" : True,
|
| 30 |
-
"trace.enabled" : False,
|
| 31 |
-
"triton.cudagraphs" : False,
|
| 32 |
-
}
|
| 33 |
-
|
| 34 |
-
@torch.compile(dynamic = True, fullgraph = True, options = torch_compile_options,)
|
| 35 |
-
def chunked_selective_log_softmax(logits, index):
|
| 36 |
-
# Split into 4 chunks only
|
| 37 |
-
chunked_logits = torch.chunk(logits.reshape(-1, logits.shape[-1]), chunks = 4, dim = 0)
|
| 38 |
-
chunked_index = torch.chunk(index.reshape(-1), chunks = 4, dim = 0)
|
| 39 |
-
all_per_token_logps = []
|
| 40 |
-
# Below loop does the same as selective_log_softmax(chunk_logits, chunk_index)
|
| 41 |
-
for chunk_logits, chunk_index in zip(chunked_logits, chunked_index):
|
| 42 |
-
chunk_logits = chunk_logits.to(torch.float32)
|
| 43 |
-
selected_logits = torch.gather(chunk_logits, dim = -1, index = chunk_index.unsqueeze(-1)).squeeze(-1)
|
| 44 |
-
logsumexp_values = torch.logsumexp(chunk_logits, dim = -1)
|
| 45 |
-
per_token_logps = selected_logits - logsumexp_values
|
| 46 |
-
all_per_token_logps.append(per_token_logps)
|
| 47 |
-
pass
|
| 48 |
-
all_per_token_logps = torch.concat(all_per_token_logps)
|
| 49 |
-
all_per_token_logps = all_per_token_logps.reshape((logits.shape[0], logits.shape[1]))
|
| 50 |
-
return all_per_token_logps
|
| 51 |
-
@dataclass
|
| 52 |
-
class UnslothKTOConfig(KTOConfig):
|
| 53 |
-
"""
|
| 54 |
-
|
| 55 |
-
Configuration class for the [`KTOTrainer`].
|
| 56 |
-
|
| 57 |
-
Using [`~transformers.HfArgumentParser`] we can turn this class into
|
| 58 |
-
[argparse](https://docs.python.org/3/library/argparse#module-argparse) arguments that can be specified on the
|
| 59 |
-
command line.
|
| 60 |
-
|
| 61 |
-
Parameters:
|
| 62 |
-
learning_rate (`float`, *optional*, defaults to `1e-6`):
|
| 63 |
-
Initial learning rate for [`AdamW`] optimizer. The default value replaces that of
|
| 64 |
-
[`~transformers.TrainingArguments`].
|
| 65 |
-
max_length (`int` or `None`, *optional*, defaults to `1024`):
|
| 66 |
-
Maximum length of the sequences (prompt + completion) in the batch. This argument is required if you want
|
| 67 |
-
to use the default data collator.
|
| 68 |
-
max_prompt_length (`int` or `None`, *optional*, defaults to `512`):
|
| 69 |
-
Maximum length of the prompt. This argument is required if you want to use the default data collator.
|
| 70 |
-
max_completion_length (`int` or `None`, *optional*, defaults to `None`):
|
| 71 |
-
Maximum length of the completion. This argument is required if you want to use the default data collator
|
| 72 |
-
and your model is an encoder-decoder.
|
| 73 |
-
beta (`float`, *optional*, defaults to `0.1`):
|
| 74 |
-
Parameter controlling the deviation from the reference model. Higher β means less deviation from the
|
| 75 |
-
reference model.
|
| 76 |
-
loss_type (`str`, *optional*, defaults to `"kto"`):
|
| 77 |
-
Type of loss to use. Possible values are:
|
| 78 |
-
|
| 79 |
-
- `"kto"`: KTO loss from the [KTO](https://huggingface.co/papers/2402.01306) paper.
|
| 80 |
-
- `"apo_zero_unpaired"`: Unpaired variant of APO-zero loss from the [APO](https://huggingface.co/papers/2408.06266) paper.
|
| 81 |
-
|
| 82 |
-
desirable_weight (`float`, *optional*, defaults to `1.0`):
|
| 83 |
-
Desirable losses are weighed by this factor to counter unequal number of desirable and undesirable paris.
|
| 84 |
-
undesirable_weight (`float`, *optional*, defaults to `1.0`):
|
| 85 |
-
Undesirable losses are weighed by this factor to counter unequal number of desirable and undesirable pairs.
|
| 86 |
-
label_pad_token_id (`int`, *optional*, defaults to `-100`):
|
| 87 |
-
Label pad token id. This argument is required if you want to use the default data collator.
|
| 88 |
-
padding_value (`int` or `None`, *optional*, defaults to `None`):
|
| 89 |
-
Padding value to use. If `None`, the padding value of the tokenizer is used.
|
| 90 |
-
truncation_mode (`str`, *optional*, defaults to `"keep_end"`):
|
| 91 |
-
Truncation mode to use when the prompt is too long. Possible values are `"keep_end"` or `"keep_start"`.
|
| 92 |
-
This argument is required if you want to use the default data collator.
|
| 93 |
-
generate_during_eval (`bool`, *optional*, defaults to `False`):
|
| 94 |
-
If `True`, generates and logs completions from both the model and the reference model to W&B or Comet during
|
| 95 |
-
evaluation.
|
| 96 |
-
is_encoder_decoder (`bool` or `None`, *optional*, defaults to `None`):
|
| 97 |
-
When using the `model_init` argument (callable) to instantiate the model instead of the `model` argument,
|
| 98 |
-
you need to specify if the model returned by the callable is an encoder-decoder model.
|
| 99 |
-
precompute_ref_log_probs (`bool`, *optional*, defaults to `False`):
|
| 100 |
-
Whether to precompute reference model log probabilities for training and evaluation datasets. This is
|
| 101 |
-
useful when training without the reference model to reduce the total GPU memory needed.
|
| 102 |
-
model_init_kwargs (`dict[str, Any]` or `None`, *optional*, defaults to `None`):
|
| 103 |
-
Keyword arguments to pass to `AutoModelForCausalLM.from_pretrained` when instantiating the model from a
|
| 104 |
-
string.
|
| 105 |
-
ref_model_init_kwargs (`dict[str, Any]` or `None`, *optional*, defaults to `None`):
|
| 106 |
-
Keyword arguments to pass to `AutoModelForCausalLM.from_pretrained` when instantiating the reference model
|
| 107 |
-
from a string.
|
| 108 |
-
dataset_num_proc: (`int` or `None`, *optional*, defaults to `None`):
|
| 109 |
-
Number of processes to use for processing the dataset.
|
| 110 |
-
disable_dropout (`bool`, *optional*, defaults to `True`):
|
| 111 |
-
Whether to disable dropout in the model and reference model.
|
| 112 |
-
|
| 113 |
-
"""
|
| 114 |
-
vllm_sampling_params: Optional[Any] = field(
|
| 115 |
-
default = None,
|
| 116 |
-
metadata = {'help': 'vLLM SamplingParams'},
|
| 117 |
-
)
|
| 118 |
-
unsloth_num_chunks : Optional[int] = field(
|
| 119 |
-
default = -1,
|
| 120 |
-
metadata = {'help': 'Chunk size to reduce memory usage. -1 is most efficient.'},
|
| 121 |
-
)
|
| 122 |
-
def __init__(
|
| 123 |
-
self,
|
| 124 |
-
output_dir = None,
|
| 125 |
-
overwrite_output_dir = None,
|
| 126 |
-
do_train = False,
|
| 127 |
-
do_eval = False,
|
| 128 |
-
do_predict = False,
|
| 129 |
-
eval_strategy = 'no',
|
| 130 |
-
prediction_loss_only = False,
|
| 131 |
-
per_device_train_batch_size = 4,
|
| 132 |
-
per_device_eval_batch_size = 4,
|
| 133 |
-
per_gpu_train_batch_size = None,
|
| 134 |
-
per_gpu_eval_batch_size = None,
|
| 135 |
-
gradient_accumulation_steps = 2,
|
| 136 |
-
eval_accumulation_steps = 2,
|
| 137 |
-
eval_delay = 0,
|
| 138 |
-
torch_empty_cache_steps = 250,
|
| 139 |
-
learning_rate = 5e-05,
|
| 140 |
-
weight_decay = 0.01,
|
| 141 |
-
adam_beta1 = 0.9,
|
| 142 |
-
adam_beta2 = 0.999,
|
| 143 |
-
adam_epsilon = 1e-08,
|
| 144 |
-
max_grad_norm = 1.0,
|
| 145 |
-
num_train_epochs = 3.0,
|
| 146 |
-
max_steps = -1,
|
| 147 |
-
lr_scheduler_type = 'linear',
|
| 148 |
-
warmup_ratio = 0.1,
|
| 149 |
-
warmup_steps = 0,
|
| 150 |
-
log_level = 'passive',
|
| 151 |
-
log_level_replica = 'warning',
|
| 152 |
-
log_on_each_node = True,
|
| 153 |
-
logging_dir = None,
|
| 154 |
-
logging_strategy = 'steps',
|
| 155 |
-
logging_first_step = False,
|
| 156 |
-
logging_steps = 1,
|
| 157 |
-
logging_nan_inf_filter = False,
|
| 158 |
-
save_strategy = 'steps',
|
| 159 |
-
save_steps = 500,
|
| 160 |
-
save_total_limit = None,
|
| 161 |
-
save_safetensors = True,
|
| 162 |
-
save_on_each_node = False,
|
| 163 |
-
save_only_model = False,
|
| 164 |
-
restore_callback_states_from_checkpoint = False,
|
| 165 |
-
no_cuda = False,
|
| 166 |
-
use_cpu = False,
|
| 167 |
-
use_mps_device = False,
|
| 168 |
-
seed = 3407,
|
| 169 |
-
data_seed = 3407,
|
| 170 |
-
jit_mode_eval = False,
|
| 171 |
-
use_ipex = False,
|
| 172 |
-
bf16 = False,
|
| 173 |
-
fp16 = False,
|
| 174 |
-
fp16_opt_level = 'O1',
|
| 175 |
-
half_precision_backend = 'auto',
|
| 176 |
-
bf16_full_eval = False,
|
| 177 |
-
fp16_full_eval = False,
|
| 178 |
-
tf32 = None,
|
| 179 |
-
local_rank = -1,
|
| 180 |
-
ddp_backend = None,
|
| 181 |
-
tpu_num_cores = None,
|
| 182 |
-
tpu_metrics_debug = False,
|
| 183 |
-
debug = '',
|
| 184 |
-
dataloader_drop_last = False,
|
| 185 |
-
eval_steps = None,
|
| 186 |
-
dataloader_num_workers = 0,
|
| 187 |
-
dataloader_prefetch_factor = None,
|
| 188 |
-
past_index = -1,
|
| 189 |
-
run_name = None,
|
| 190 |
-
disable_tqdm = None,
|
| 191 |
-
remove_unused_columns = True,
|
| 192 |
-
label_names = None,
|
| 193 |
-
load_best_model_at_end = False,
|
| 194 |
-
metric_for_best_model = None,
|
| 195 |
-
greater_is_better = None,
|
| 196 |
-
ignore_data_skip = False,
|
| 197 |
-
fsdp = '',
|
| 198 |
-
fsdp_min_num_params = 0,
|
| 199 |
-
fsdp_config = None,
|
| 200 |
-
fsdp_transformer_layer_cls_to_wrap = None,
|
| 201 |
-
accelerator_config = None,
|
| 202 |
-
deepspeed = None,
|
| 203 |
-
label_smoothing_factor = 0.0,
|
| 204 |
-
optim = 'adamw_8bit',
|
| 205 |
-
optim_args = None,
|
| 206 |
-
adafactor = False,
|
| 207 |
-
group_by_length = False,
|
| 208 |
-
length_column_name = 'length',
|
| 209 |
-
report_to = None,
|
| 210 |
-
ddp_find_unused_parameters = None,
|
| 211 |
-
ddp_bucket_cap_mb = None,
|
| 212 |
-
ddp_broadcast_buffers = None,
|
| 213 |
-
dataloader_pin_memory = True,
|
| 214 |
-
dataloader_persistent_workers = False,
|
| 215 |
-
skip_memory_metrics = True,
|
| 216 |
-
use_legacy_prediction_loop = False,
|
| 217 |
-
push_to_hub = False,
|
| 218 |
-
resume_from_checkpoint = None,
|
| 219 |
-
hub_model_id = None,
|
| 220 |
-
hub_strategy = 'every_save',
|
| 221 |
-
hub_token = None,
|
| 222 |
-
hub_private_repo = None,
|
| 223 |
-
hub_always_push = False,
|
| 224 |
-
hub_revision = None,
|
| 225 |
-
gradient_checkpointing = False,
|
| 226 |
-
gradient_checkpointing_kwargs = None,
|
| 227 |
-
include_inputs_for_metrics = False,
|
| 228 |
-
eval_do_concat_batches = True,
|
| 229 |
-
fp16_backend = 'auto',
|
| 230 |
-
push_to_hub_model_id = None,
|
| 231 |
-
push_to_hub_organization = None,
|
| 232 |
-
push_to_hub_token = None,
|
| 233 |
-
mp_parameters = '',
|
| 234 |
-
auto_find_batch_size = True,
|
| 235 |
-
full_determinism = False,
|
| 236 |
-
torchdynamo = None,
|
| 237 |
-
ray_scope = 'last',
|
| 238 |
-
ddp_timeout = 1800,
|
| 239 |
-
torch_compile = False,
|
| 240 |
-
torch_compile_backend = None,
|
| 241 |
-
torch_compile_mode = None,
|
| 242 |
-
include_tokens_per_second = False,
|
| 243 |
-
include_num_input_tokens_seen = False,
|
| 244 |
-
neftune_noise_alpha = None,
|
| 245 |
-
optim_target_modules = None,
|
| 246 |
-
batch_eval_metrics = False,
|
| 247 |
-
eval_on_start = False,
|
| 248 |
-
use_liger_kernel = False,
|
| 249 |
-
liger_kernel_config = None,
|
| 250 |
-
eval_use_gather_object = False,
|
| 251 |
-
average_tokens_across_devices = True,
|
| 252 |
-
max_length = 1024,
|
| 253 |
-
max_prompt_length = 512,
|
| 254 |
-
max_completion_length = None,
|
| 255 |
-
beta = 0.1,
|
| 256 |
-
loss_type = 'kto',
|
| 257 |
-
desirable_weight = 1.0,
|
| 258 |
-
undesirable_weight = 1.0,
|
| 259 |
-
label_pad_token_id = -100,
|
| 260 |
-
padding_value = None,
|
| 261 |
-
truncation_mode = 'keep_end',
|
| 262 |
-
generate_during_eval = False,
|
| 263 |
-
is_encoder_decoder = None,
|
| 264 |
-
disable_dropout = True,
|
| 265 |
-
precompute_ref_log_probs = False,
|
| 266 |
-
model_init_kwargs = None,
|
| 267 |
-
ref_model_init_kwargs = None,
|
| 268 |
-
dataset_num_proc = None,
|
| 269 |
-
vllm_sampling_params = None,
|
| 270 |
-
unsloth_num_chunks = -1,
|
| 271 |
-
**kwargs,
|
| 272 |
-
):
|
| 273 |
-
if learning_rate < 1e-7: raise FloatingPointError(f'Unsloth: Your learning rate of `{learning_rate}` is too small and less than 1e-7! Consider increasing it, otherwise gradient updates will be close to 0!')
|
| 274 |
-
if learning_rate > 1: raise OverflowError(f'Unsloth: Your learning rate of `{learning_rate}` is way too larger > 1! Consider decreasing it to 1e-1, otherwise gradient updates will explode!')
|
| 275 |
-
if output_dir is None and save_strategy == 'steps' and save_steps == 500:
|
| 276 |
-
output_dir = 'unsloth_training_checkpoints'
|
| 277 |
-
save_strategy = 'no'
|
| 278 |
-
if dataset_num_proc is None:
|
| 279 |
-
from multiprocessing import cpu_count
|
| 280 |
-
dataset_num_proc = min(cpu_count()*2, 2)
|
| 281 |
-
|
| 282 |
-
super().__init__(
|
| 283 |
-
output_dir = output_dir,
|
| 284 |
-
overwrite_output_dir = overwrite_output_dir,
|
| 285 |
-
do_train = do_train,
|
| 286 |
-
do_eval = do_eval,
|
| 287 |
-
do_predict = do_predict,
|
| 288 |
-
eval_strategy = eval_strategy,
|
| 289 |
-
prediction_loss_only = prediction_loss_only,
|
| 290 |
-
per_device_train_batch_size = per_device_train_batch_size,
|
| 291 |
-
per_device_eval_batch_size = per_device_eval_batch_size,
|
| 292 |
-
per_gpu_train_batch_size = per_gpu_train_batch_size,
|
| 293 |
-
per_gpu_eval_batch_size = per_gpu_eval_batch_size,
|
| 294 |
-
gradient_accumulation_steps = gradient_accumulation_steps,
|
| 295 |
-
eval_accumulation_steps = eval_accumulation_steps,
|
| 296 |
-
eval_delay = eval_delay,
|
| 297 |
-
torch_empty_cache_steps = torch_empty_cache_steps,
|
| 298 |
-
learning_rate = learning_rate,
|
| 299 |
-
weight_decay = weight_decay,
|
| 300 |
-
adam_beta1 = adam_beta1,
|
| 301 |
-
adam_beta2 = adam_beta2,
|
| 302 |
-
adam_epsilon = adam_epsilon,
|
| 303 |
-
max_grad_norm = max_grad_norm,
|
| 304 |
-
num_train_epochs = num_train_epochs,
|
| 305 |
-
max_steps = max_steps,
|
| 306 |
-
lr_scheduler_type = lr_scheduler_type,
|
| 307 |
-
warmup_ratio = warmup_ratio,
|
| 308 |
-
warmup_steps = warmup_steps,
|
| 309 |
-
log_level = log_level,
|
| 310 |
-
log_level_replica = log_level_replica,
|
| 311 |
-
log_on_each_node = log_on_each_node,
|
| 312 |
-
logging_dir = logging_dir,
|
| 313 |
-
logging_strategy = logging_strategy,
|
| 314 |
-
logging_first_step = logging_first_step,
|
| 315 |
-
logging_steps = logging_steps,
|
| 316 |
-
logging_nan_inf_filter = logging_nan_inf_filter,
|
| 317 |
-
save_strategy = save_strategy,
|
| 318 |
-
save_steps = save_steps,
|
| 319 |
-
save_total_limit = save_total_limit,
|
| 320 |
-
save_safetensors = save_safetensors,
|
| 321 |
-
save_on_each_node = save_on_each_node,
|
| 322 |
-
save_only_model = save_only_model,
|
| 323 |
-
restore_callback_states_from_checkpoint = restore_callback_states_from_checkpoint,
|
| 324 |
-
no_cuda = no_cuda,
|
| 325 |
-
use_cpu = use_cpu,
|
| 326 |
-
use_mps_device = use_mps_device,
|
| 327 |
-
seed = seed,
|
| 328 |
-
data_seed = data_seed,
|
| 329 |
-
jit_mode_eval = jit_mode_eval,
|
| 330 |
-
use_ipex = use_ipex,
|
| 331 |
-
bf16 = bf16,
|
| 332 |
-
fp16 = fp16,
|
| 333 |
-
fp16_opt_level = fp16_opt_level,
|
| 334 |
-
half_precision_backend = half_precision_backend,
|
| 335 |
-
bf16_full_eval = bf16_full_eval,
|
| 336 |
-
fp16_full_eval = fp16_full_eval,
|
| 337 |
-
tf32 = tf32,
|
| 338 |
-
local_rank = local_rank,
|
| 339 |
-
ddp_backend = ddp_backend,
|
| 340 |
-
tpu_num_cores = tpu_num_cores,
|
| 341 |
-
tpu_metrics_debug = tpu_metrics_debug,
|
| 342 |
-
debug = debug,
|
| 343 |
-
dataloader_drop_last = dataloader_drop_last,
|
| 344 |
-
eval_steps = eval_steps,
|
| 345 |
-
dataloader_num_workers = dataloader_num_workers,
|
| 346 |
-
dataloader_prefetch_factor = dataloader_prefetch_factor,
|
| 347 |
-
past_index = past_index,
|
| 348 |
-
run_name = run_name,
|
| 349 |
-
disable_tqdm = disable_tqdm,
|
| 350 |
-
remove_unused_columns = remove_unused_columns,
|
| 351 |
-
label_names = label_names,
|
| 352 |
-
load_best_model_at_end = load_best_model_at_end,
|
| 353 |
-
metric_for_best_model = metric_for_best_model,
|
| 354 |
-
greater_is_better = greater_is_better,
|
| 355 |
-
ignore_data_skip = ignore_data_skip,
|
| 356 |
-
fsdp = fsdp,
|
| 357 |
-
fsdp_min_num_params = fsdp_min_num_params,
|
| 358 |
-
fsdp_config = fsdp_config,
|
| 359 |
-
fsdp_transformer_layer_cls_to_wrap = fsdp_transformer_layer_cls_to_wrap,
|
| 360 |
-
accelerator_config = accelerator_config,
|
| 361 |
-
deepspeed = deepspeed,
|
| 362 |
-
label_smoothing_factor = label_smoothing_factor,
|
| 363 |
-
optim = optim,
|
| 364 |
-
optim_args = optim_args,
|
| 365 |
-
adafactor = adafactor,
|
| 366 |
-
group_by_length = group_by_length,
|
| 367 |
-
length_column_name = length_column_name,
|
| 368 |
-
report_to = report_to,
|
| 369 |
-
ddp_find_unused_parameters = ddp_find_unused_parameters,
|
| 370 |
-
ddp_bucket_cap_mb = ddp_bucket_cap_mb,
|
| 371 |
-
ddp_broadcast_buffers = ddp_broadcast_buffers,
|
| 372 |
-
dataloader_pin_memory = dataloader_pin_memory,
|
| 373 |
-
dataloader_persistent_workers = dataloader_persistent_workers,
|
| 374 |
-
skip_memory_metrics = skip_memory_metrics,
|
| 375 |
-
use_legacy_prediction_loop = use_legacy_prediction_loop,
|
| 376 |
-
push_to_hub = push_to_hub,
|
| 377 |
-
resume_from_checkpoint = resume_from_checkpoint,
|
| 378 |
-
hub_model_id = hub_model_id,
|
| 379 |
-
hub_strategy = hub_strategy,
|
| 380 |
-
hub_token = hub_token,
|
| 381 |
-
hub_private_repo = hub_private_repo,
|
| 382 |
-
hub_always_push = hub_always_push,
|
| 383 |
-
hub_revision = hub_revision,
|
| 384 |
-
gradient_checkpointing = gradient_checkpointing,
|
| 385 |
-
gradient_checkpointing_kwargs = gradient_checkpointing_kwargs,
|
| 386 |
-
include_inputs_for_metrics = include_inputs_for_metrics,
|
| 387 |
-
eval_do_concat_batches = eval_do_concat_batches,
|
| 388 |
-
fp16_backend = fp16_backend,
|
| 389 |
-
push_to_hub_model_id = push_to_hub_model_id,
|
| 390 |
-
push_to_hub_organization = push_to_hub_organization,
|
| 391 |
-
push_to_hub_token = push_to_hub_token,
|
| 392 |
-
mp_parameters = mp_parameters,
|
| 393 |
-
auto_find_batch_size = auto_find_batch_size,
|
| 394 |
-
full_determinism = full_determinism,
|
| 395 |
-
torchdynamo = torchdynamo,
|
| 396 |
-
ray_scope = ray_scope,
|
| 397 |
-
ddp_timeout = ddp_timeout,
|
| 398 |
-
torch_compile = torch_compile,
|
| 399 |
-
torch_compile_backend = torch_compile_backend,
|
| 400 |
-
torch_compile_mode = torch_compile_mode,
|
| 401 |
-
include_tokens_per_second = include_tokens_per_second,
|
| 402 |
-
include_num_input_tokens_seen = include_num_input_tokens_seen,
|
| 403 |
-
neftune_noise_alpha = neftune_noise_alpha,
|
| 404 |
-
optim_target_modules = optim_target_modules,
|
| 405 |
-
batch_eval_metrics = batch_eval_metrics,
|
| 406 |
-
eval_on_start = eval_on_start,
|
| 407 |
-
use_liger_kernel = use_liger_kernel,
|
| 408 |
-
liger_kernel_config = liger_kernel_config,
|
| 409 |
-
eval_use_gather_object = eval_use_gather_object,
|
| 410 |
-
average_tokens_across_devices = average_tokens_across_devices,
|
| 411 |
-
max_length = max_length,
|
| 412 |
-
max_prompt_length = max_prompt_length,
|
| 413 |
-
max_completion_length = max_completion_length,
|
| 414 |
-
beta = beta,
|
| 415 |
-
loss_type = loss_type,
|
| 416 |
-
desirable_weight = desirable_weight,
|
| 417 |
-
undesirable_weight = undesirable_weight,
|
| 418 |
-
label_pad_token_id = label_pad_token_id,
|
| 419 |
-
padding_value = padding_value,
|
| 420 |
-
truncation_mode = truncation_mode,
|
| 421 |
-
generate_during_eval = generate_during_eval,
|
| 422 |
-
is_encoder_decoder = is_encoder_decoder,
|
| 423 |
-
disable_dropout = disable_dropout,
|
| 424 |
-
precompute_ref_log_probs = precompute_ref_log_probs,
|
| 425 |
-
model_init_kwargs = model_init_kwargs,
|
| 426 |
-
ref_model_init_kwargs = ref_model_init_kwargs,
|
| 427 |
-
dataset_num_proc = dataset_num_proc,**kwargs)
|
| 428 |
-
self.vllm_sampling_params = vllm_sampling_params
|
| 429 |
-
self.unsloth_num_chunks = unsloth_num_chunks
|
| 430 |
-
pass
|
| 431 |
-
|
| 432 |
-
class _UnslothKTOTrainer(Trainer):
|
| 433 |
-
r""""""
|
| 434 |
-
|
| 435 |
-
_tag_names = ["trl", "kto"]
|
| 436 |
-
|
| 437 |
-
def __init__(
|
| 438 |
-
self,
|
| 439 |
-
model: Union[PreTrainedModel, nn.Module, str] = None,
|
| 440 |
-
ref_model: Optional[Union[PreTrainedModel, nn.Module, str]] = None,
|
| 441 |
-
args: KTOConfig = None,
|
| 442 |
-
train_dataset: Optional[Dataset] = None,
|
| 443 |
-
eval_dataset: Optional[Union[Dataset, dict[str, Dataset]]] = None,
|
| 444 |
-
processing_class: Optional[
|
| 445 |
-
Union[PreTrainedTokenizerBase, BaseImageProcessor, FeatureExtractionMixin, ProcessorMixin]
|
| 446 |
-
] = None,
|
| 447 |
-
data_collator: Optional[DataCollator] = None,
|
| 448 |
-
model_init: Optional[Callable[[], PreTrainedModel]] = None,
|
| 449 |
-
callbacks: Optional[list[TrainerCallback]] = None,
|
| 450 |
-
optimizers: tuple[torch.optim.Optimizer, torch.optim.lr_scheduler.LambdaLR] = (None, None),
|
| 451 |
-
preprocess_logits_for_metrics: Optional[Callable[[torch.Tensor, torch.Tensor], torch.Tensor]] = None,
|
| 452 |
-
peft_config: Optional[dict] = None,
|
| 453 |
-
compute_metrics: Optional[Callable[[EvalLoopOutput], dict]] = None,
|
| 454 |
-
model_adapter_name: Optional[str] = None,
|
| 455 |
-
ref_adapter_name: Optional[str] = None,
|
| 456 |
-
):
|
| 457 |
-
if type(args) is TrainingArguments:
|
| 458 |
-
raise ValueError("Please use `KTOConfig` instead TrainingArguments.")
|
| 459 |
-
|
| 460 |
-
if not isinstance(model, str) and ref_model is model:
|
| 461 |
-
raise ValueError(
|
| 462 |
-
"`model` and `ref_model` cannot be the same object. If you want `ref_model` to be the "
|
| 463 |
-
"same as `model`, you must mass a copy of it, or `None` if you use peft."
|
| 464 |
-
)
|
| 465 |
-
|
| 466 |
-
if args.model_init_kwargs is None:
|
| 467 |
-
model_init_kwargs = {}
|
| 468 |
-
elif not isinstance(model, str):
|
| 469 |
-
raise ValueError("You passed model_kwargs to the KTOTrainer. But your model is already instantiated.")
|
| 470 |
-
else:
|
| 471 |
-
model_init_kwargs = args.model_init_kwargs
|
| 472 |
-
torch_dtype = model_init_kwargs.get("torch_dtype")
|
| 473 |
-
if torch_dtype is not None:
|
| 474 |
-
# Convert to `torch.dtype` if an str is passed
|
| 475 |
-
if isinstance(torch_dtype, str) and torch_dtype != "auto":
|
| 476 |
-
torch_dtype = getattr(torch, torch_dtype)
|
| 477 |
-
if torch_dtype != "auto" and not isinstance(torch_dtype, torch.dtype):
|
| 478 |
-
raise ValueError(
|
| 479 |
-
f"Invalid `torch_dtype` passed to the KTOConfig. Expected a string with either `torch.dtype` or 'auto', but got {torch_dtype}."
|
| 480 |
-
)
|
| 481 |
-
model_init_kwargs["torch_dtype"] = torch_dtype
|
| 482 |
-
|
| 483 |
-
if args.ref_model_init_kwargs is None:
|
| 484 |
-
ref_model_init_kwargs = {}
|
| 485 |
-
elif not isinstance(ref_model, str):
|
| 486 |
-
raise ValueError(
|
| 487 |
-
"You passed ref_model_kwargs to the KTOTrainer. But your ref_model is already instantiated."
|
| 488 |
-
)
|
| 489 |
-
else:
|
| 490 |
-
ref_model_init_kwargs = args.ref_model_init_kwargs
|
| 491 |
-
torch_dtype = ref_model_init_kwargs.get("torch_dtype")
|
| 492 |
-
if torch_dtype is not None:
|
| 493 |
-
# Convert to `torch.dtype` if an str is passed
|
| 494 |
-
if isinstance(torch_dtype, str) and torch_dtype != "auto":
|
| 495 |
-
torch_dtype = getattr(torch, torch_dtype)
|
| 496 |
-
if torch_dtype != "auto" and not isinstance(torch_dtype, torch.dtype):
|
| 497 |
-
raise ValueError(
|
| 498 |
-
f"Invalid `torch_dtype` passed to the KTOConfig. Expected a string with either `torch.dtype` or 'auto', but got {torch_dtype}."
|
| 499 |
-
)
|
| 500 |
-
ref_model_init_kwargs["torch_dtype"] = torch_dtype
|
| 501 |
-
|
| 502 |
-
if isinstance(model, str):
|
| 503 |
-
model = AutoModelForCausalLM.from_pretrained(model, **model_init_kwargs)
|
| 504 |
-
|
| 505 |
-
if isinstance(ref_model, str):
|
| 506 |
-
ref_model = AutoModelForCausalLM.from_pretrained(ref_model, **ref_model_init_kwargs)
|
| 507 |
-
|
| 508 |
-
# Initialize this variable to False. This helps tracking the case when `peft_module_casting_to_bf16`
|
| 509 |
-
# has been called in order to properly call autocast if needed.
|
| 510 |
-
self._peft_has_been_casted_to_bf16 = False
|
| 511 |
-
|
| 512 |
-
if not is_peft_available() and peft_config is not None:
|
| 513 |
-
raise ValueError(
|
| 514 |
-
"PEFT is not installed and you passed a `peft_config` in the trainer's kwargs, please install it with `pip install peft` to use the PEFT models"
|
| 515 |
-
)
|
| 516 |
-
elif is_peft_available() and peft_config is not None:
|
| 517 |
-
# if model is a peft model and we have a peft_config, we merge and unload it first
|
| 518 |
-
if isinstance(model, PeftModel):
|
| 519 |
-
model = model.merge_and_unload()
|
| 520 |
-
|
| 521 |
-
if getattr(model, "is_loaded_in_8bit", False) or getattr(model, "is_loaded_in_4bit", False):
|
| 522 |
-
_support_gc_kwargs = hasattr(
|
| 523 |
-
args, "gradient_checkpointing_kwargs"
|
| 524 |
-
) and "gradient_checkpointing_kwargs" in list(
|
| 525 |
-
inspect.signature(prepare_model_for_kbit_training).parameters
|
| 526 |
-
)
|
| 527 |
-
|
| 528 |
-
prepare_model_kwargs = {"use_gradient_checkpointing": args.gradient_checkpointing}
|
| 529 |
-
|
| 530 |
-
if _support_gc_kwargs:
|
| 531 |
-
prepare_model_kwargs["gradient_checkpointing_kwargs"] = args.gradient_checkpointing_kwargs
|
| 532 |
-
|
| 533 |
-
model = prepare_model_for_kbit_training(model, **prepare_model_kwargs)
|
| 534 |
-
elif getattr(args, "gradient_checkpointing", False):
|
| 535 |
-
# For backward compatibility with older versions of transformers
|
| 536 |
-
if hasattr(model, "enable_input_require_grads"):
|
| 537 |
-
model.enable_input_require_grads()
|
| 538 |
-
else:
|
| 539 |
-
|
| 540 |
-
def make_inputs_require_grad(module, input, output):
|
| 541 |
-
output.requires_grad_(True)
|
| 542 |
-
|
| 543 |
-
model.get_input_embeddings().register_forward_hook(make_inputs_require_grad)
|
| 544 |
-
|
| 545 |
-
# get peft model with the given config
|
| 546 |
-
model = model
|
| 547 |
-
if args.bf16 and getattr(model, "is_loaded_in_4bit", False):
|
| 548 |
-
peft_module_casting_to_bf16(model)
|
| 549 |
-
# If args.bf16 we need to explicitly call `generate` with torch amp autocast context manager
|
| 550 |
-
self._peft_has_been_casted_to_bf16 = True
|
| 551 |
-
|
| 552 |
-
# For models that use gradient_checkpointing, we need to attach a hook that enables input
|
| 553 |
-
# to explicitly have `requires_grad=True`, otherwise training will either silently
|
| 554 |
-
# fail or completely fail.
|
| 555 |
-
elif getattr(args, "gradient_checkpointing", False):
|
| 556 |
-
# For backward compatibility with older versions of transformers
|
| 557 |
-
if hasattr(model, "enable_input_require_grads"):
|
| 558 |
-
model.enable_input_require_grads()
|
| 559 |
-
else:
|
| 560 |
-
|
| 561 |
-
def make_inputs_require_grad(module, input, output):
|
| 562 |
-
output.requires_grad_(True)
|
| 563 |
-
|
| 564 |
-
model.get_input_embeddings().register_forward_hook(make_inputs_require_grad)
|
| 565 |
-
|
| 566 |
-
if args.generate_during_eval and not (is_wandb_available() or is_comet_available()):
|
| 567 |
-
raise ValueError(
|
| 568 |
-
"`generate_during_eval=True` requires Weights and Biases or Comet to be installed."
|
| 569 |
-
" Please install `wandb` or `comet-ml` to resolve."
|
| 570 |
-
)
|
| 571 |
-
|
| 572 |
-
if model is not None:
|
| 573 |
-
self.is_encoder_decoder = model.config.is_encoder_decoder
|
| 574 |
-
elif args.is_encoder_decoder is None:
|
| 575 |
-
raise ValueError("When no model is provided, you need to pass the parameter is_encoder_decoder.")
|
| 576 |
-
else:
|
| 577 |
-
self.is_encoder_decoder = args.is_encoder_decoder
|
| 578 |
-
|
| 579 |
-
self.is_peft_model = is_peft_available() and isinstance(model, PeftModel)
|
| 580 |
-
self.model_adapter_name = model_adapter_name
|
| 581 |
-
self.ref_adapter_name = ref_adapter_name
|
| 582 |
-
|
| 583 |
-
if ref_model:
|
| 584 |
-
self.ref_model = ref_model
|
| 585 |
-
elif self.is_peft_model or args.precompute_ref_log_probs:
|
| 586 |
-
# The `model` with adapters turned off will be used as the reference model
|
| 587 |
-
self.ref_model = None
|
| 588 |
-
else:
|
| 589 |
-
self.ref_model = create_reference_model(model)
|
| 590 |
-
|
| 591 |
-
if processing_class is None:
|
| 592 |
-
raise ValueError(
|
| 593 |
-
"max_length or a processing_class must be specified when using the default DPODataCollatorWithPadding"
|
| 594 |
-
)
|
| 595 |
-
if args.max_length is None:
|
| 596 |
-
warnings.warn(
|
| 597 |
-
"When using DPODataCollatorWithPadding, you should set `max_length` in the KTOTrainer's init"
|
| 598 |
-
" it will be set to `512` by default, but you should do it yourself in the future.",
|
| 599 |
-
UserWarning,
|
| 600 |
-
)
|
| 601 |
-
max_length = 512
|
| 602 |
-
if args.max_length is not None:
|
| 603 |
-
max_length = args.max_length
|
| 604 |
-
|
| 605 |
-
if args.max_prompt_length is None:
|
| 606 |
-
warnings.warn(
|
| 607 |
-
"When using DPODataCollatorWithPadding, you should set `max_prompt_length` in the KTOTrainer's init"
|
| 608 |
-
" it will be set to `128` by default, but you should do it yourself in the future.",
|
| 609 |
-
UserWarning,
|
| 610 |
-
)
|
| 611 |
-
max_prompt_length = 128
|
| 612 |
-
if args.max_prompt_length is not None:
|
| 613 |
-
max_prompt_length = args.max_prompt_length
|
| 614 |
-
|
| 615 |
-
max_completion_length = None
|
| 616 |
-
if args.max_completion_length is None and self.is_encoder_decoder:
|
| 617 |
-
warnings.warn(
|
| 618 |
-
"When using DPODataCollatorWithPadding with an encoder decoder architecture, you should set `max_completion_length` in the KTOTrainer's init"
|
| 619 |
-
" it will be set to `128` by default, but you should do it yourself in the future.",
|
| 620 |
-
UserWarning,
|
| 621 |
-
)
|
| 622 |
-
max_completion_length = 128
|
| 623 |
-
if args.max_completion_length is not None and self.is_encoder_decoder:
|
| 624 |
-
max_completion_length = args.max_completion_length
|
| 625 |
-
|
| 626 |
-
if data_collator is None:
|
| 627 |
-
data_collator = DPODataCollatorWithPadding(
|
| 628 |
-
pad_token_id=processing_class.pad_token_id,
|
| 629 |
-
label_pad_token_id=args.label_pad_token_id,
|
| 630 |
-
is_encoder_decoder=self.is_encoder_decoder,
|
| 631 |
-
)
|
| 632 |
-
|
| 633 |
-
if args.remove_unused_columns:
|
| 634 |
-
args.remove_unused_columns = False
|
| 635 |
-
# warn users
|
| 636 |
-
warnings.warn(
|
| 637 |
-
"When using DPODataCollatorWithPadding, you should set `remove_unused_columns=False` in your KTOConfig"
|
| 638 |
-
" we have set it for you, but you should do it yourself in the future.",
|
| 639 |
-
UserWarning,
|
| 640 |
-
)
|
| 641 |
-
|
| 642 |
-
self.use_dpo_data_collator = True
|
| 643 |
-
else:
|
| 644 |
-
self.use_dpo_data_collator = False
|
| 645 |
-
|
| 646 |
-
# Disable dropout in the model and reference model
|
| 647 |
-
if args.disable_dropout:
|
| 648 |
-
disable_dropout_in_model(model)
|
| 649 |
-
if self.ref_model is not None:
|
| 650 |
-
disable_dropout_in_model(self.ref_model)
|
| 651 |
-
|
| 652 |
-
self.loss_type = args.loss_type
|
| 653 |
-
self.max_length = max_length
|
| 654 |
-
self.generate_during_eval = args.generate_during_eval
|
| 655 |
-
self.label_pad_token_id = args.label_pad_token_id
|
| 656 |
-
self.padding_value = args.padding_value if args.padding_value is not None else processing_class.pad_token_id
|
| 657 |
-
self.max_prompt_length = max_prompt_length
|
| 658 |
-
self.truncation_mode = args.truncation_mode
|
| 659 |
-
self.max_completion_length = max_completion_length
|
| 660 |
-
self.processing_class = processing_class
|
| 661 |
-
self.precompute_ref_log_probs = args.precompute_ref_log_probs
|
| 662 |
-
|
| 663 |
-
# Not all losses require a KL calculation
|
| 664 |
-
self.calculate_KL = True
|
| 665 |
-
if self.loss_type in ["apo_zero_unpaired"]:
|
| 666 |
-
self.calculate_KL = False
|
| 667 |
-
|
| 668 |
-
# Since ref_logs are precomputed on the first call to get_train/eval_dataloader
|
| 669 |
-
# keep track of first called to avoid computation of future calls
|
| 670 |
-
self._precomputed_train_ref_log_probs = False
|
| 671 |
-
self._precomputed_eval_ref_log_probs = False
|
| 672 |
-
|
| 673 |
-
# metric
|
| 674 |
-
self._stored_metrics = defaultdict(lambda: defaultdict(list))
|
| 675 |
-
|
| 676 |
-
# KTO parameter
|
| 677 |
-
self.beta = args.beta
|
| 678 |
-
self.desirable_weight = args.desirable_weight
|
| 679 |
-
self.undesirable_weight = args.undesirable_weight
|
| 680 |
-
self.aux_loss_enabled = getattr(model.config, "output_router_logits", False)
|
| 681 |
-
self.aux_loss_coef = getattr(model.config, "router_aux_loss_coef", 0.0)
|
| 682 |
-
if self.aux_loss_enabled and self.aux_loss_coef == 0.0:
|
| 683 |
-
warnings.warn(
|
| 684 |
-
"You set `output_router_logits` to `True` in the model config, but `router_aux_loss_coef` is set to "
|
| 685 |
-
"`0.0`, meaning the auxiliary loss will not be used. Either set `router_aux_loss_coef` to a value "
|
| 686 |
-
"greater than `0.0`, or set `output_router_logits` to `False` if you don't want to use the auxiliary "
|
| 687 |
-
"loss.",
|
| 688 |
-
UserWarning,
|
| 689 |
-
)
|
| 690 |
-
|
| 691 |
-
# The trainer estimates the number of FLOPs [floating-point operations] using the number of elements in the
|
| 692 |
-
# input tensor associated with the key "input_ids". However, in KTO, the sampled data does not include the
|
| 693 |
-
# "input_ids" key. Instead, the available keys are "prompt_input_ids" and "completion_input_ids". As a result,
|
| 694 |
-
# the trainer issues the warning: "Could not estimate the number of tokens of the input, floating-point
|
| 695 |
-
# operations will not be computed." To suppress this warning, we set the "estimate_tokens" key in the model's
|
| 696 |
-
# "warnings_issued" dictionary to True. This acts as a flag to indicate that the warning has already been
|
| 697 |
-
# issued.
|
| 698 |
-
model.warnings_issued["estimate_tokens"] = True
|
| 699 |
-
|
| 700 |
-
# Compute that only on the main process for faster data processing.
|
| 701 |
-
# see: https://github.com/huggingface/trl/pull/1255
|
| 702 |
-
with PartialState().main_process_first():
|
| 703 |
-
# Extract the prompt if needed
|
| 704 |
-
train_dataset = train_dataset.map(
|
| 705 |
-
maybe_extract_prompt, num_proc=args.dataset_num_proc, desc="Extracting prompt from train dataset"
|
| 706 |
-
)
|
| 707 |
-
# Unpair the dataset if needed
|
| 708 |
-
train_dataset = maybe_unpair_preference_dataset(
|
| 709 |
-
train_dataset, args.dataset_num_proc, desc="Unpairing train dataset"
|
| 710 |
-
)
|
| 711 |
-
# Apply the chat template if needed
|
| 712 |
-
train_dataset = train_dataset.map(
|
| 713 |
-
maybe_apply_chat_template,
|
| 714 |
-
fn_kwargs={"tokenizer": processing_class},
|
| 715 |
-
num_proc=args.dataset_num_proc,
|
| 716 |
-
desc="Applying chat template to train dataset",
|
| 717 |
-
)
|
| 718 |
-
if eval_dataset is not None:
|
| 719 |
-
eval_dataset = eval_dataset.map(
|
| 720 |
-
maybe_extract_prompt, num_proc=args.dataset_num_proc, desc="Extracting prompt from eval dataset"
|
| 721 |
-
)
|
| 722 |
-
eval_dataset = maybe_unpair_preference_dataset(
|
| 723 |
-
eval_dataset, args.dataset_num_proc, desc="Unpairing eval dataset"
|
| 724 |
-
)
|
| 725 |
-
eval_dataset = eval_dataset.map(
|
| 726 |
-
maybe_apply_chat_template,
|
| 727 |
-
fn_kwargs={"tokenizer": processing_class},
|
| 728 |
-
num_proc=args.dataset_num_proc,
|
| 729 |
-
desc="Applying chat template to eval dataset",
|
| 730 |
-
)
|
| 731 |
-
|
| 732 |
-
# Tokenize and prepare the training datasets
|
| 733 |
-
train_dataset = train_dataset.map(
|
| 734 |
-
_tokenize,
|
| 735 |
-
batched=True,
|
| 736 |
-
fn_kwargs={"tokenizer": self.processing_class},
|
| 737 |
-
num_proc=args.dataset_num_proc,
|
| 738 |
-
desc="Tokenizing train dataset",
|
| 739 |
-
)
|
| 740 |
-
|
| 741 |
-
fn_kwargs = {
|
| 742 |
-
"prefix": "",
|
| 743 |
-
"is_encoder_decoder": self.is_encoder_decoder,
|
| 744 |
-
"tokenizer": self.processing_class,
|
| 745 |
-
"max_length": self.max_length,
|
| 746 |
-
"truncation_mode": self.truncation_mode,
|
| 747 |
-
"label_pad_token_id": self.label_pad_token_id,
|
| 748 |
-
"max_prompt_length": self.max_prompt_length,
|
| 749 |
-
"max_completion_length": self.max_completion_length,
|
| 750 |
-
}
|
| 751 |
-
|
| 752 |
-
train_dataset = train_dataset.map(
|
| 753 |
-
_process_tokens,
|
| 754 |
-
fn_kwargs=fn_kwargs,
|
| 755 |
-
num_proc=args.dataset_num_proc,
|
| 756 |
-
desc="Processing tokenized train dataset",
|
| 757 |
-
)
|
| 758 |
-
|
| 759 |
-
# Tokenize and prepare the eval datasets
|
| 760 |
-
if eval_dataset is not None:
|
| 761 |
-
eval_dataset = eval_dataset.map(
|
| 762 |
-
_tokenize,
|
| 763 |
-
fn_kwargs={"tokenizer": self.processing_class},
|
| 764 |
-
batched=True,
|
| 765 |
-
num_proc=args.dataset_num_proc,
|
| 766 |
-
desc="Tokenizing eval dataset",
|
| 767 |
-
)
|
| 768 |
-
|
| 769 |
-
eval_dataset = eval_dataset.map(
|
| 770 |
-
_process_tokens,
|
| 771 |
-
fn_kwargs=fn_kwargs,
|
| 772 |
-
num_proc=args.dataset_num_proc,
|
| 773 |
-
desc="Processing tokenized eval dataset",
|
| 774 |
-
)
|
| 775 |
-
|
| 776 |
-
# Get KL datasets if needed
|
| 777 |
-
if self.calculate_KL:
|
| 778 |
-
if args.per_device_train_batch_size <= 1:
|
| 779 |
-
raise ValueError(
|
| 780 |
-
"Actual (not effective) batch size must be > 1. KTO will not work properly because the KL term will be equivalent to the implied reward."
|
| 781 |
-
)
|
| 782 |
-
|
| 783 |
-
# create pairs for estimating the KL term by flipping the matched pairs in each batch of size total_batch_size
|
| 784 |
-
# i.e., [x_1, y_1], ..., [x_n, y_n] --> [x_1, y_n], ..., [x_n, y_1] = [x'_1, y'_1], ..., [x'_n, y'_n]
|
| 785 |
-
train_kl_dataset = train_dataset.map(
|
| 786 |
-
_get_kl_dataset,
|
| 787 |
-
batched=True,
|
| 788 |
-
batch_size=args.per_device_train_batch_size,
|
| 789 |
-
num_proc=args.dataset_num_proc,
|
| 790 |
-
desc="Extracting KL train dataset",
|
| 791 |
-
)
|
| 792 |
-
|
| 793 |
-
fn_kwargs["prefix"] = "KL_"
|
| 794 |
-
train_kl_dataset = train_kl_dataset.map(
|
| 795 |
-
_process_tokens,
|
| 796 |
-
fn_kwargs=fn_kwargs,
|
| 797 |
-
num_proc=args.dataset_num_proc,
|
| 798 |
-
remove_columns=[c for c in train_kl_dataset.column_names if c in train_dataset.column_names],
|
| 799 |
-
desc="Processing tokenized train KL dataset",
|
| 800 |
-
)
|
| 801 |
-
|
| 802 |
-
# merge the datasets
|
| 803 |
-
train_dataset = concatenate_datasets([train_dataset, train_kl_dataset], axis=1)
|
| 804 |
-
|
| 805 |
-
if eval_dataset is not None:
|
| 806 |
-
# Get KL dataset
|
| 807 |
-
eval_kl_dataset = eval_dataset.map(
|
| 808 |
-
_get_kl_dataset,
|
| 809 |
-
batched=True,
|
| 810 |
-
batch_size=args.per_device_train_batch_size,
|
| 811 |
-
num_proc=args.dataset_num_proc,
|
| 812 |
-
desc="Extracting eval KL dataset",
|
| 813 |
-
)
|
| 814 |
-
|
| 815 |
-
eval_kl_dataset = eval_kl_dataset.map(
|
| 816 |
-
_process_tokens,
|
| 817 |
-
fn_kwargs=fn_kwargs,
|
| 818 |
-
num_proc=args.dataset_num_proc,
|
| 819 |
-
remove_columns=[c for c in eval_kl_dataset.column_names if c in eval_dataset.column_names],
|
| 820 |
-
desc="Processing tokenized eval KL dataset",
|
| 821 |
-
)
|
| 822 |
-
|
| 823 |
-
# merge the datasets
|
| 824 |
-
eval_dataset = concatenate_datasets([eval_dataset, eval_kl_dataset], axis=1)
|
| 825 |
-
|
| 826 |
-
# calculate dataset desirability balance
|
| 827 |
-
num_desirable = max(sum(train_dataset["label"]), 1)
|
| 828 |
-
num_undesirable = max(len(train_dataset["label"]) - num_desirable, 1) # "label" is binary
|
| 829 |
-
|
| 830 |
-
if num_desirable != num_undesirable:
|
| 831 |
-
# The lower and upper bounds come from Eq. [8] of https://huggingface.co/papers/2402.01306
|
| 832 |
-
des_weight_lower_bound = round((num_undesirable * self.undesirable_weight / num_desirable) * 1, 2)
|
| 833 |
-
des_weight_upper_bound = round((num_undesirable * self.undesirable_weight / num_desirable) * 1.33, 2)
|
| 834 |
-
und_weight_lower_bound = round((num_desirable * self.desirable_weight / num_undesirable) / 1.33, 2)
|
| 835 |
-
und_weight_upper_bound = round((num_desirable * self.desirable_weight / num_undesirable) / 1, 2)
|
| 836 |
-
|
| 837 |
-
des_weight_in_range = des_weight_lower_bound <= self.desirable_weight <= des_weight_upper_bound
|
| 838 |
-
und_weight_in_range = und_weight_lower_bound <= self.undesirable_weight <= und_weight_upper_bound
|
| 839 |
-
|
| 840 |
-
if not (des_weight_in_range or und_weight_in_range):
|
| 841 |
-
warnings.warn(
|
| 842 |
-
"You have different amounts of desirable/positive and undesirable/negative examples but the "
|
| 843 |
-
"weights on the desirable and undesirable losses don't seem to be in an ideal range. Based "
|
| 844 |
-
f"on your data, we recommend EITHER "
|
| 845 |
-
f"desirable_weight in [{des_weight_lower_bound}, {des_weight_upper_bound}] or "
|
| 846 |
-
f"undesirable_weight in [{und_weight_lower_bound}, {und_weight_upper_bound}] (but NOT BOTH). "
|
| 847 |
-
"See the documentation on how to optimally set these weights.",
|
| 848 |
-
UserWarning,
|
| 849 |
-
)
|
| 850 |
-
|
| 851 |
-
super().__init__(
|
| 852 |
-
model=model,
|
| 853 |
-
args=args,
|
| 854 |
-
data_collator=data_collator,
|
| 855 |
-
train_dataset=train_dataset,
|
| 856 |
-
eval_dataset=eval_dataset,
|
| 857 |
-
processing_class=processing_class,
|
| 858 |
-
model_init=model_init,
|
| 859 |
-
compute_metrics=compute_metrics,
|
| 860 |
-
callbacks=callbacks,
|
| 861 |
-
optimizers=optimizers,
|
| 862 |
-
preprocess_logits_for_metrics=preprocess_logits_for_metrics,
|
| 863 |
-
)
|
| 864 |
-
|
| 865 |
-
# Gradient accumulation requires scaled loss. Normally, loss scaling in the parent class depends on whether the
|
| 866 |
-
# model accepts loss-related kwargs. Since we compute our own loss, this check is irrelevant. We set
|
| 867 |
-
# self.model_accepts_loss_kwargs to False to enable scaling.
|
| 868 |
-
self.model_accepts_loss_kwargs = False
|
| 869 |
-
|
| 870 |
-
# Add tags for models that have been loaded with the correct transformers version
|
| 871 |
-
if hasattr(self.model, "add_model_tags"):
|
| 872 |
-
self.model.add_model_tags(self._tag_names)
|
| 873 |
-
|
| 874 |
-
if not hasattr(self, "accelerator"):
|
| 875 |
-
raise AttributeError(
|
| 876 |
-
"Your `Trainer` does not have an `accelerator` object. Consider upgrading `transformers`."
|
| 877 |
-
)
|
| 878 |
-
|
| 879 |
-
# Deepspeed Zero-3 does not support precompute_ref_log_probs
|
| 880 |
-
if self.is_deepspeed_enabled:
|
| 881 |
-
if self.accelerator.state.deepspeed_plugin.zero_stage == 3 and self.precompute_ref_log_probs:
|
| 882 |
-
raise ValueError(
|
| 883 |
-
"You cannot use `precompute_ref_log_probs=True` with Deepspeed ZeRO-3. Please set `precompute_ref_log_probs=False`."
|
| 884 |
-
)
|
| 885 |
-
|
| 886 |
-
if self.ref_model is None:
|
| 887 |
-
if not (self.is_peft_model or self.precompute_ref_log_probs):
|
| 888 |
-
raise ValueError(
|
| 889 |
-
"No reference model and model is not a Peft model. Try setting `precompute_ref_log_probs=True`"
|
| 890 |
-
)
|
| 891 |
-
else:
|
| 892 |
-
if self.is_deepspeed_enabled:
|
| 893 |
-
self.ref_model = self._prepare_deepspeed(self.ref_model)
|
| 894 |
-
else:
|
| 895 |
-
self.ref_model = self.accelerator.prepare_model(self.ref_model, evaluation_mode=True)
|
| 896 |
-
|
| 897 |
-
def _prepare_deepspeed(self, model: PreTrainedModelWrapper):
|
| 898 |
-
# Adapted from accelerate: https://github.com/huggingface/accelerate/blob/739b135f8367becb67ffaada12fe76e3aa60fefd/src/accelerate/accelerator.py#L1473
|
| 899 |
-
deepspeed_plugin = self.accelerator.state.deepspeed_plugin
|
| 900 |
-
config_kwargs = deepcopy(deepspeed_plugin.deepspeed_config)
|
| 901 |
-
|
| 902 |
-
if model is not None:
|
| 903 |
-
if hasattr(model, "config"):
|
| 904 |
-
hidden_size = (
|
| 905 |
-
max(model.config.hidden_sizes)
|
| 906 |
-
if getattr(model.config, "hidden_sizes", None)
|
| 907 |
-
else getattr(model.config, "hidden_size", None)
|
| 908 |
-
)
|
| 909 |
-
if hidden_size is not None and config_kwargs["zero_optimization"]["stage"] == 3:
|
| 910 |
-
# Note that `stage3_prefetch_bucket_size` can produce DeepSpeed messages like: `Invalidate trace cache @ step 0: expected module 1, but got module 0`
|
| 911 |
-
# This is expected and is not an error, see: https://github.com/microsoft/DeepSpeed/discussions/4081
|
| 912 |
-
config_kwargs.update(
|
| 913 |
-
{
|
| 914 |
-
"zero_optimization.reduce_bucket_size": hidden_size * hidden_size,
|
| 915 |
-
"zero_optimization.stage3_param_persistence_threshold": 10 * hidden_size,
|
| 916 |
-
"zero_optimization.stage3_prefetch_bucket_size": 0.9 * hidden_size * hidden_size,
|
| 917 |
-
}
|
| 918 |
-
)
|
| 919 |
-
|
| 920 |
-
# If ZeRO-3 is used, we shard both the active and reference model.
|
| 921 |
-
# Otherwise, we assume the reference model fits in memory and is initialized on each device with ZeRO disabled (stage 0)
|
| 922 |
-
if config_kwargs["zero_optimization"]["stage"] != 3:
|
| 923 |
-
config_kwargs["zero_optimization"]["stage"] = 0
|
| 924 |
-
model, *_ = deepspeed.initialize(model=model, config=config_kwargs)
|
| 925 |
-
model.eval()
|
| 926 |
-
return model
|
| 927 |
-
|
| 928 |
-
@contextmanager
|
| 929 |
-
def null_ref_context(self):
|
| 930 |
-
"""Context manager for handling null reference model (that is, peft adapter manipulation)."""
|
| 931 |
-
with (
|
| 932 |
-
self.accelerator.unwrap_model(self.model).disable_adapter()
|
| 933 |
-
if self.is_peft_model and not self.ref_adapter_name
|
| 934 |
-
else nullcontext()
|
| 935 |
-
):
|
| 936 |
-
if self.ref_adapter_name:
|
| 937 |
-
self.model.set_adapter(self.ref_adapter_name)
|
| 938 |
-
yield
|
| 939 |
-
if self.ref_adapter_name:
|
| 940 |
-
self.model.set_adapter(self.model_adapter_name or "default")
|
| 941 |
-
|
| 942 |
-
def get_train_dataloader(self) -> DataLoader:
|
| 943 |
-
"""
|
| 944 |
-
Returns the training [`~torch.utils.data.DataLoader`].
|
| 945 |
-
|
| 946 |
-
Subclass of transformers.src.transformers.trainer.get_train_dataloader to precompute `ref_log_probs`.
|
| 947 |
-
"""
|
| 948 |
-
|
| 949 |
-
if self.precompute_ref_log_probs and not self._precomputed_train_ref_log_probs:
|
| 950 |
-
dataloader_params = {
|
| 951 |
-
"batch_size": self.args.per_device_train_batch_size,
|
| 952 |
-
"collate_fn": self.data_collator,
|
| 953 |
-
"num_workers": self.args.dataloader_num_workers,
|
| 954 |
-
"pin_memory": self.args.dataloader_pin_memory,
|
| 955 |
-
"shuffle": False,
|
| 956 |
-
}
|
| 957 |
-
|
| 958 |
-
# prepare dataloader
|
| 959 |
-
data_loader = self.accelerator.prepare(DataLoader(self.train_dataset, **dataloader_params))
|
| 960 |
-
reference_completion_logps = []
|
| 961 |
-
reference_KL_logps = []
|
| 962 |
-
|
| 963 |
-
for padded_batch in tqdm(iterable=data_loader, desc="Train dataset reference log probs"):
|
| 964 |
-
reference_completion_logp, reference_KL_logp = self.compute_reference_log_probs(padded_batch)
|
| 965 |
-
|
| 966 |
-
reference_completion_logp = self.accelerator.gather_for_metrics(reference_completion_logp)
|
| 967 |
-
reference_completion_logps.append(reference_completion_logp.cpu())
|
| 968 |
-
|
| 969 |
-
if self.calculate_KL:
|
| 970 |
-
reference_KL_logp = self.accelerator.gather_for_metrics(reference_KL_logp)
|
| 971 |
-
reference_KL_logps.append(reference_KL_logp.cpu())
|
| 972 |
-
|
| 973 |
-
self.train_dataset = self.train_dataset.add_column(
|
| 974 |
-
name="reference_logps", column=torch.cat(reference_completion_logps).float().numpy()
|
| 975 |
-
)
|
| 976 |
-
|
| 977 |
-
if self.calculate_KL:
|
| 978 |
-
self.train_dataset = self.train_dataset.add_column(
|
| 979 |
-
name="reference_KL_logps", column=torch.cat(reference_KL_logps).float().numpy()
|
| 980 |
-
)
|
| 981 |
-
|
| 982 |
-
self._precomputed_train_ref_log_probs = True
|
| 983 |
-
|
| 984 |
-
return super().get_train_dataloader()
|
| 985 |
-
|
| 986 |
-
def get_eval_dataloader(self, eval_dataset: Optional[Dataset] = None) -> DataLoader:
|
| 987 |
-
"""
|
| 988 |
-
Returns the evaluation [`~torch.utils.data.DataLoader`].
|
| 989 |
-
|
| 990 |
-
Subclass of transformers.src.transformers.trainer.get_eval_dataloader to precompute `ref_log_probs`.
|
| 991 |
-
|
| 992 |
-
Args:
|
| 993 |
-
eval_dataset (`torch.utils.data.Dataset`, *optional*):
|
| 994 |
-
If provided, will override `self.eval_dataset`. If it is a [`~datasets.Dataset`], columns not accepted
|
| 995 |
-
by the `model.forward()` method are automatically removed. It must implement `__len__`.
|
| 996 |
-
"""
|
| 997 |
-
if eval_dataset is None and self.eval_dataset is None:
|
| 998 |
-
raise ValueError("Trainer: evaluation requires an eval_dataset.")
|
| 999 |
-
eval_dataset = eval_dataset if eval_dataset is not None else self.eval_dataset
|
| 1000 |
-
|
| 1001 |
-
if self.precompute_ref_log_probs and not self._precomputed_eval_ref_log_probs:
|
| 1002 |
-
dataloader_params = {
|
| 1003 |
-
"batch_size": self.args.per_device_eval_batch_size,
|
| 1004 |
-
"collate_fn": self.data_collator,
|
| 1005 |
-
"num_workers": self.args.dataloader_num_workers,
|
| 1006 |
-
"pin_memory": self.args.dataloader_pin_memory,
|
| 1007 |
-
"shuffle": False,
|
| 1008 |
-
}
|
| 1009 |
-
|
| 1010 |
-
# prepare dataloader
|
| 1011 |
-
data_loader = self.accelerator.prepare(DataLoader(eval_dataset, **dataloader_params))
|
| 1012 |
-
|
| 1013 |
-
reference_completion_logps = []
|
| 1014 |
-
reference_KL_logps = []
|
| 1015 |
-
|
| 1016 |
-
for padded_batch in tqdm(iterable=data_loader, desc="Eval dataset reference log probs"):
|
| 1017 |
-
reference_completion_logp, reference_KL_logp = self.compute_reference_log_probs(padded_batch)
|
| 1018 |
-
|
| 1019 |
-
reference_completion_logp = self.accelerator.gather_for_metrics(reference_completion_logp)
|
| 1020 |
-
reference_completion_logps.append(reference_completion_logp.cpu())
|
| 1021 |
-
|
| 1022 |
-
if self.calculate_KL:
|
| 1023 |
-
reference_KL_logp = self.accelerator.gather_for_metrics(reference_KL_logp)
|
| 1024 |
-
reference_KL_logps.append(reference_KL_logp.cpu())
|
| 1025 |
-
|
| 1026 |
-
eval_dataset = eval_dataset.add_column(
|
| 1027 |
-
name="reference_logps", column=torch.cat(reference_completion_logps).float().numpy()
|
| 1028 |
-
)
|
| 1029 |
-
if self.calculate_KL:
|
| 1030 |
-
eval_dataset = eval_dataset.add_column(
|
| 1031 |
-
name="reference_KL_logps", column=torch.cat(reference_KL_logps).float().numpy()
|
| 1032 |
-
)
|
| 1033 |
-
|
| 1034 |
-
# Save calculated reference_chosen_logps and reference_rejected_logps to the eval_dataset for subsequent runs
|
| 1035 |
-
if self.eval_dataset is not None:
|
| 1036 |
-
self.eval_dataset = eval_dataset
|
| 1037 |
-
self._precomputed_eval_ref_log_probs = True
|
| 1038 |
-
|
| 1039 |
-
return super().get_eval_dataloader(eval_dataset=eval_dataset)
|
| 1040 |
-
|
| 1041 |
-
def compute_reference_log_probs(self, padded_batch: dict) -> dict:
|
| 1042 |
-
"""Computes log probabilities of the reference model for a single padded batch of a KTO specific dataset."""
|
| 1043 |
-
with torch.no_grad():
|
| 1044 |
-
if self.ref_model is None:
|
| 1045 |
-
with self.null_ref_context():
|
| 1046 |
-
if self.is_encoder_decoder:
|
| 1047 |
-
completion_logits = self.model(
|
| 1048 |
-
padded_batch["prompt_input_ids"],
|
| 1049 |
-
attention_mask=padded_batch["prompt_attention_mask"],
|
| 1050 |
-
decoder_input_ids=padded_batch.get("completion_decoder_input_ids"),
|
| 1051 |
-
labels=padded_batch["completion_labels"],
|
| 1052 |
-
).logits
|
| 1053 |
-
|
| 1054 |
-
if self.calculate_KL:
|
| 1055 |
-
KL_logits = self.model(
|
| 1056 |
-
padded_batch["KL_prompt_input_ids"],
|
| 1057 |
-
attention_mask=padded_batch["KL_prompt_attention_mask"],
|
| 1058 |
-
decoder_input_ids=padded_batch.get("KL_completion_decoder_input_ids"),
|
| 1059 |
-
labels=padded_batch["KL_completion_labels"],
|
| 1060 |
-
).logits
|
| 1061 |
-
else:
|
| 1062 |
-
completion_logits = self.model(
|
| 1063 |
-
padded_batch["completion_input_ids"],
|
| 1064 |
-
attention_mask=padded_batch["completion_attention_mask"],
|
| 1065 |
-
).logits
|
| 1066 |
-
|
| 1067 |
-
if self.calculate_KL:
|
| 1068 |
-
KL_logits = self.model(
|
| 1069 |
-
padded_batch["KL_completion_input_ids"],
|
| 1070 |
-
attention_mask=padded_batch["KL_completion_attention_mask"],
|
| 1071 |
-
).logits
|
| 1072 |
-
else:
|
| 1073 |
-
if self.is_encoder_decoder:
|
| 1074 |
-
completion_logits = self.ref_model(
|
| 1075 |
-
padded_batch["prompt_input_ids"],
|
| 1076 |
-
attention_mask=padded_batch["prompt_attention_mask"],
|
| 1077 |
-
decoder_input_ids=padded_batch.get("completion_decoder_input_ids"),
|
| 1078 |
-
labels=padded_batch["completion_labels"],
|
| 1079 |
-
).logits
|
| 1080 |
-
|
| 1081 |
-
if self.calculate_KL:
|
| 1082 |
-
KL_logits = self.ref_model(
|
| 1083 |
-
padded_batch["KL_prompt_input_ids"],
|
| 1084 |
-
attention_mask=padded_batch["KL_prompt_attention_mask"],
|
| 1085 |
-
decoder_input_ids=padded_batch.get("KL_completion_decoder_input_ids"),
|
| 1086 |
-
labels=padded_batch["KL_completion_labels"],
|
| 1087 |
-
).logits
|
| 1088 |
-
else:
|
| 1089 |
-
completion_logits = self.ref_model(
|
| 1090 |
-
padded_batch["completion_input_ids"], attention_mask=padded_batch["completion_attention_mask"]
|
| 1091 |
-
).logits
|
| 1092 |
-
|
| 1093 |
-
if self.calculate_KL:
|
| 1094 |
-
KL_logits = self.ref_model(
|
| 1095 |
-
padded_batch["KL_completion_input_ids"],
|
| 1096 |
-
attention_mask=padded_batch["KL_completion_attention_mask"],
|
| 1097 |
-
).logits
|
| 1098 |
-
|
| 1099 |
-
completion_logps = self.get_batch_logps(
|
| 1100 |
-
completion_logits,
|
| 1101 |
-
padded_batch["completion_labels"],
|
| 1102 |
-
average_log_prob=False,
|
| 1103 |
-
is_encoder_decoder=self.is_encoder_decoder,
|
| 1104 |
-
label_pad_token_id=self.label_pad_token_id,
|
| 1105 |
-
)
|
| 1106 |
-
|
| 1107 |
-
if self.calculate_KL:
|
| 1108 |
-
KL_logps = self.get_batch_logps(
|
| 1109 |
-
KL_logits,
|
| 1110 |
-
padded_batch["KL_completion_labels"],
|
| 1111 |
-
average_log_prob=False,
|
| 1112 |
-
is_encoder_decoder=self.is_encoder_decoder,
|
| 1113 |
-
label_pad_token_id=self.label_pad_token_id,
|
| 1114 |
-
)
|
| 1115 |
-
else:
|
| 1116 |
-
KL_logps = None
|
| 1117 |
-
|
| 1118 |
-
return completion_logps, KL_logps
|
| 1119 |
-
|
| 1120 |
-
@staticmethod
|
| 1121 |
-
def get_batch_logps(
|
| 1122 |
-
logits: torch.FloatTensor,
|
| 1123 |
-
labels: torch.LongTensor,
|
| 1124 |
-
average_log_prob: bool = False,
|
| 1125 |
-
label_pad_token_id: int = -100,
|
| 1126 |
-
is_encoder_decoder: bool = False,
|
| 1127 |
-
) -> torch.FloatTensor:
|
| 1128 |
-
"""Compute the log probabilities of the given labels under the given logits.
|
| 1129 |
-
|
| 1130 |
-
Args:
|
| 1131 |
-
logits: Logits of the model (unnormalized). Shape: (batch_size, sequence_length, vocab_size)
|
| 1132 |
-
labels: Labels for which to compute the log probabilities. Label tokens with a value of label_pad_token_id are ignored. Shape: (batch_size, sequence_length)
|
| 1133 |
-
average_log_prob: If True, return the average log probability per (non-masked) token. Otherwise, return the sum of the log probabilities of the (non-masked) tokens.
|
| 1134 |
-
|
| 1135 |
-
Returns:
|
| 1136 |
-
A tensor of shape (batch_size,) containing the average/sum log probabilities of the given labels under the given logits.
|
| 1137 |
-
"""
|
| 1138 |
-
if logits.shape[:-1] != labels.shape:
|
| 1139 |
-
raise ValueError("Logits (batch and sequence length dim) and labels must have the same shape.")
|
| 1140 |
-
|
| 1141 |
-
if not is_encoder_decoder:
|
| 1142 |
-
labels = labels[:, 1:].clone()
|
| 1143 |
-
logits = logits[:, :-1, :]
|
| 1144 |
-
else:
|
| 1145 |
-
# Fixes end-dec RuntimeError
|
| 1146 |
-
labels = labels.clone()
|
| 1147 |
-
|
| 1148 |
-
loss_mask = labels != label_pad_token_id
|
| 1149 |
-
|
| 1150 |
-
# dummy token; we'll ignore the losses on these tokens later
|
| 1151 |
-
labels[labels == label_pad_token_id] = 0
|
| 1152 |
-
|
| 1153 |
-
per_token_logps = selective_log_softmax(logits, labels)
|
| 1154 |
-
|
| 1155 |
-
if average_log_prob:
|
| 1156 |
-
return (per_token_logps * loss_mask).sum(-1) / loss_mask.sum(-1)
|
| 1157 |
-
else:
|
| 1158 |
-
return (per_token_logps * loss_mask).sum(-1)
|
| 1159 |
-
|
| 1160 |
-
def forward(
|
| 1161 |
-
self, model: nn.Module, batch: dict[str, Union[list, torch.LongTensor]]
|
| 1162 |
-
) -> tuple[torch.FloatTensor, torch.FloatTensor, torch.FloatTensor, torch.FloatTensor]:
|
| 1163 |
-
if self.calculate_KL:
|
| 1164 |
-
KL_logps = None
|
| 1165 |
-
KL_model_kwargs = (
|
| 1166 |
-
{
|
| 1167 |
-
"input_ids": batch["KL_prompt_input_ids"],
|
| 1168 |
-
"attention_mask": batch["KL_prompt_attention_mask"],
|
| 1169 |
-
"labels": batch["KL_completion_labels"],
|
| 1170 |
-
"decoder_input_ids": batch.get("KL_completion_decoder_input_ids"),
|
| 1171 |
-
}
|
| 1172 |
-
if self.is_encoder_decoder
|
| 1173 |
-
else {
|
| 1174 |
-
"input_ids": batch["KL_completion_input_ids"],
|
| 1175 |
-
"attention_mask": batch["KL_completion_attention_mask"],
|
| 1176 |
-
}
|
| 1177 |
-
)
|
| 1178 |
-
with torch.no_grad():
|
| 1179 |
-
KL_logits = model(
|
| 1180 |
-
**KL_model_kwargs,
|
| 1181 |
-
).logits
|
| 1182 |
-
|
| 1183 |
-
KL_logps = self.get_batch_logps(
|
| 1184 |
-
KL_logits,
|
| 1185 |
-
batch["KL_completion_labels"],
|
| 1186 |
-
average_log_prob=False,
|
| 1187 |
-
is_encoder_decoder=self.is_encoder_decoder,
|
| 1188 |
-
label_pad_token_id=self.label_pad_token_id,
|
| 1189 |
-
)
|
| 1190 |
-
else:
|
| 1191 |
-
KL_logps = None
|
| 1192 |
-
|
| 1193 |
-
model_kwargs = (
|
| 1194 |
-
{
|
| 1195 |
-
"labels": batch["completion_labels"],
|
| 1196 |
-
"decoder_input_ids": batch.get("completion_decoder_input_ids"),
|
| 1197 |
-
}
|
| 1198 |
-
if self.is_encoder_decoder
|
| 1199 |
-
else {}
|
| 1200 |
-
)
|
| 1201 |
-
if self.aux_loss_enabled:
|
| 1202 |
-
model_kwargs["output_router_logits"] = True
|
| 1203 |
-
|
| 1204 |
-
outputs = model(
|
| 1205 |
-
batch["completion_input_ids"],
|
| 1206 |
-
attention_mask=batch["completion_attention_mask"],
|
| 1207 |
-
**model_kwargs,
|
| 1208 |
-
)
|
| 1209 |
-
completion_logits = outputs.logits
|
| 1210 |
-
|
| 1211 |
-
completion_logps = self.get_batch_logps(
|
| 1212 |
-
completion_logits,
|
| 1213 |
-
batch["completion_labels"],
|
| 1214 |
-
average_log_prob=False,
|
| 1215 |
-
is_encoder_decoder=self.is_encoder_decoder,
|
| 1216 |
-
label_pad_token_id=self.label_pad_token_id,
|
| 1217 |
-
)
|
| 1218 |
-
|
| 1219 |
-
if completion_logps.shape[0] != len(batch["label"]):
|
| 1220 |
-
raise ValueError(
|
| 1221 |
-
"There is a mismatch between the number of examples in this batch and the number of "
|
| 1222 |
-
"examples for which an output sequence was predicted."
|
| 1223 |
-
)
|
| 1224 |
-
|
| 1225 |
-
chosen_idx = [i for i in range(completion_logps.shape[0]) if batch["label"][i] is True]
|
| 1226 |
-
rejected_idx = [i for i in range(completion_logps.shape[0]) if batch["label"][i] is False]
|
| 1227 |
-
|
| 1228 |
-
chosen_logps = completion_logps[chosen_idx, ...]
|
| 1229 |
-
rejected_logps = completion_logps[rejected_idx, ...]
|
| 1230 |
-
|
| 1231 |
-
chosen_logits = completion_logits[chosen_idx, ...]
|
| 1232 |
-
rejected_logits = completion_logits[rejected_idx, ...]
|
| 1233 |
-
|
| 1234 |
-
if self.aux_loss_enabled:
|
| 1235 |
-
return (chosen_logps, rejected_logps, chosen_logits, rejected_logits, KL_logps, outputs.aux_loss)
|
| 1236 |
-
else:
|
| 1237 |
-
return (chosen_logps, rejected_logps, chosen_logits, rejected_logits, KL_logps)
|
| 1238 |
-
|
| 1239 |
-
def kto_loss(
|
| 1240 |
-
self,
|
| 1241 |
-
policy_chosen_logps: torch.FloatTensor,
|
| 1242 |
-
policy_rejected_logps: torch.FloatTensor,
|
| 1243 |
-
policy_KL_logps: torch.FloatTensor,
|
| 1244 |
-
reference_chosen_logps: torch.FloatTensor,
|
| 1245 |
-
reference_rejected_logps: torch.FloatTensor,
|
| 1246 |
-
reference_KL_logps: torch.FloatTensor,
|
| 1247 |
-
) -> tuple[torch.FloatTensor, torch.FloatTensor, torch.FloatTensor, torch.FloatTensor]:
|
| 1248 |
-
"""Compute the KTO loss for a batch of policy and reference model log probabilities.
|
| 1249 |
-
|
| 1250 |
-
Args:
|
| 1251 |
-
policy_chosen_logps: Log probabilities of the policy model for the chosen responses. Shape: (num(chosen) in batch_size,)
|
| 1252 |
-
policy_rejected_logps: Log probabilities of the policy model for the rejected responses. Shape: (num(rejected) in batch_size,)
|
| 1253 |
-
policy_KL_logps: Log probabilities of the policy model for the KL responses. Shape: (batch_size,)
|
| 1254 |
-
reference_chosen_logps: Log probabilities of the reference model for the chosen responses. Shape: (num(chosen) in batch_size,)
|
| 1255 |
-
reference_rejected_logps: Log probabilities of the reference model for the rejected responses. Shape: (num(rejected) in batch_size,)
|
| 1256 |
-
reference_KL_logps: Log probabilities of the reference model for the KL responses. Shape: (batch_size,)
|
| 1257 |
-
|
| 1258 |
-
Returns:
|
| 1259 |
-
A tuple of four tensors: (losses, chosen_rewards, rejected_rewards, KL).
|
| 1260 |
-
The losses tensor contains the KTO loss for each example in the batch.
|
| 1261 |
-
The chosen_rewards and rejected_rewards tensors contain the rewards for the chosen and rejected responses, respectively.
|
| 1262 |
-
The KL tensor contains the detached KL divergence estimate between the policy and reference models.
|
| 1263 |
-
"""
|
| 1264 |
-
if self.calculate_KL:
|
| 1265 |
-
kl = (policy_KL_logps - reference_KL_logps).mean().detach()
|
| 1266 |
-
kl = self.accelerator.gather_for_metrics(kl).mean().clamp(min=0)
|
| 1267 |
-
else:
|
| 1268 |
-
kl = torch.zeros(1).to(policy_chosen_logps.device)
|
| 1269 |
-
|
| 1270 |
-
# Chosen losses
|
| 1271 |
-
if policy_chosen_logps.shape[0] != 0 or reference_chosen_logps.shape[0] != 0:
|
| 1272 |
-
chosen_logratios = policy_chosen_logps - reference_chosen_logps
|
| 1273 |
-
|
| 1274 |
-
if self.loss_type == "kto":
|
| 1275 |
-
# Eqn (7) of the KTO paper (https://huggingface.co/papers/2402.01306)
|
| 1276 |
-
chosen_losses = 1 - F.sigmoid(self.beta * (chosen_logratios - kl))
|
| 1277 |
-
elif self.loss_type == "apo_zero_unpaired":
|
| 1278 |
-
# Unpaired variant of Eqn (7) of the APO paper (https://huggingface.co/papers/2408.06266)
|
| 1279 |
-
# Use this loss when you believe the chosen outputs are better than your model's default output
|
| 1280 |
-
chosen_losses = 1 - F.sigmoid(self.beta * chosen_logratios)
|
| 1281 |
-
|
| 1282 |
-
chosen_rewards = self.beta * chosen_logratios.detach()
|
| 1283 |
-
|
| 1284 |
-
else:
|
| 1285 |
-
# lists can't be empty -- if they are, then accelerate.gather will hang
|
| 1286 |
-
chosen_losses = torch.Tensor([]).to(self.accelerator.device)
|
| 1287 |
-
chosen_rewards = torch.Tensor([]).to(self.accelerator.device)
|
| 1288 |
-
|
| 1289 |
-
# Rejected losses
|
| 1290 |
-
if policy_rejected_logps.shape[0] != 0 or reference_rejected_logps.shape[0] != 0:
|
| 1291 |
-
rejected_logratios = policy_rejected_logps - reference_rejected_logps
|
| 1292 |
-
|
| 1293 |
-
if self.loss_type == "kto":
|
| 1294 |
-
rejected_losses = 1 - F.sigmoid(self.beta * (kl - rejected_logratios))
|
| 1295 |
-
elif self.loss_type == "apo_zero_unpaired":
|
| 1296 |
-
rejected_losses = F.sigmoid(self.beta * rejected_logratios)
|
| 1297 |
-
|
| 1298 |
-
rejected_rewards = self.beta * rejected_logratios.detach()
|
| 1299 |
-
else:
|
| 1300 |
-
# lists can't be empty -- if they are, then accelerate.gather will hang
|
| 1301 |
-
rejected_losses = torch.Tensor([]).to(self.accelerator.device)
|
| 1302 |
-
rejected_rewards = torch.Tensor([]).to(self.accelerator.device)
|
| 1303 |
-
|
| 1304 |
-
losses = torch.cat(
|
| 1305 |
-
(self.desirable_weight * chosen_losses, self.undesirable_weight * rejected_losses),
|
| 1306 |
-
0,
|
| 1307 |
-
)
|
| 1308 |
-
|
| 1309 |
-
return losses, chosen_rewards, rejected_rewards, kl
|
| 1310 |
-
|
| 1311 |
-
def get_batch_loss_metrics(
|
| 1312 |
-
self,
|
| 1313 |
-
model,
|
| 1314 |
-
batch: dict[str, Union[list, torch.LongTensor]],
|
| 1315 |
-
):
|
| 1316 |
-
"""Compute the KTO loss and other metrics for the given batch of inputs for train or test."""
|
| 1317 |
-
metrics = {}
|
| 1318 |
-
batch = {k: (v.to(self.accelerator.device) if isinstance(v, torch.Tensor) else v) for k, v in batch.items()}
|
| 1319 |
-
|
| 1320 |
-
forward_output = self.forward(model, batch)
|
| 1321 |
-
(
|
| 1322 |
-
policy_chosen_logps,
|
| 1323 |
-
policy_rejected_logps,
|
| 1324 |
-
policy_chosen_logits,
|
| 1325 |
-
policy_rejected_logits,
|
| 1326 |
-
policy_KL_logps,
|
| 1327 |
-
) = forward_output[:5]
|
| 1328 |
-
if self.aux_loss_enabled:
|
| 1329 |
-
aux_loss = forward_output[5]
|
| 1330 |
-
|
| 1331 |
-
# if reference_logps in batch use them, otherwise use the reference model
|
| 1332 |
-
if "reference_logps" in batch:
|
| 1333 |
-
chosen_idx = [i for i in range(batch["reference_logps"].shape[0]) if batch["label"][i] is True]
|
| 1334 |
-
rejected_idx = [i for i in range(batch["reference_logps"].shape[0]) if batch["label"][i] is False]
|
| 1335 |
-
|
| 1336 |
-
reference_chosen_logps = batch["reference_logps"][chosen_idx, ...]
|
| 1337 |
-
reference_rejected_logps = batch["reference_logps"][rejected_idx, ...]
|
| 1338 |
-
if self.calculate_KL:
|
| 1339 |
-
reference_KL_logps = batch["reference_KL_logps"]
|
| 1340 |
-
else:
|
| 1341 |
-
reference_KL_logps = None
|
| 1342 |
-
else:
|
| 1343 |
-
with torch.no_grad():
|
| 1344 |
-
if self.ref_model is None:
|
| 1345 |
-
with self.null_ref_context():
|
| 1346 |
-
(
|
| 1347 |
-
reference_chosen_logps,
|
| 1348 |
-
reference_rejected_logps,
|
| 1349 |
-
_,
|
| 1350 |
-
_,
|
| 1351 |
-
reference_KL_logps,
|
| 1352 |
-
) = self.forward(self.model, batch)[:5]
|
| 1353 |
-
else:
|
| 1354 |
-
(
|
| 1355 |
-
reference_chosen_logps,
|
| 1356 |
-
reference_rejected_logps,
|
| 1357 |
-
_,
|
| 1358 |
-
_,
|
| 1359 |
-
reference_KL_logps,
|
| 1360 |
-
) = self.forward(self.ref_model, batch)[:5]
|
| 1361 |
-
|
| 1362 |
-
losses, chosen_rewards, rejected_rewards, kl = self.kto_loss(
|
| 1363 |
-
policy_chosen_logps,
|
| 1364 |
-
policy_rejected_logps,
|
| 1365 |
-
policy_KL_logps,
|
| 1366 |
-
reference_chosen_logps,
|
| 1367 |
-
reference_rejected_logps,
|
| 1368 |
-
reference_KL_logps,
|
| 1369 |
-
)
|
| 1370 |
-
metrics["kl"] = kl.item()
|
| 1371 |
-
|
| 1372 |
-
num_chosen = torch.Tensor([len(chosen_rewards)]).to(self.accelerator.device)
|
| 1373 |
-
num_rejected = torch.Tensor([len(rejected_rewards)]).to(self.accelerator.device)
|
| 1374 |
-
|
| 1375 |
-
all_num_chosen = self.accelerator.gather_for_metrics(num_chosen).sum().item()
|
| 1376 |
-
all_num_rejected = self.accelerator.gather_for_metrics(num_rejected).sum().item()
|
| 1377 |
-
|
| 1378 |
-
if all_num_chosen > 0:
|
| 1379 |
-
metrics["rewards/chosen_sum"] = (
|
| 1380 |
-
self.accelerator.gather_for_metrics(chosen_rewards.nansum()).nansum().item()
|
| 1381 |
-
)
|
| 1382 |
-
metrics["logps/chosen_sum"] = (
|
| 1383 |
-
self.accelerator.gather_for_metrics(policy_chosen_logps.nansum()).nansum().item()
|
| 1384 |
-
)
|
| 1385 |
-
metrics["logits/chosen_sum"] = (
|
| 1386 |
-
self.accelerator.gather_for_metrics(policy_chosen_logits.nansum()).nansum().item()
|
| 1387 |
-
)
|
| 1388 |
-
metrics["count/chosen"] = all_num_chosen
|
| 1389 |
-
|
| 1390 |
-
if all_num_rejected > 0:
|
| 1391 |
-
metrics["rewards/rejected_sum"] = (
|
| 1392 |
-
self.accelerator.gather_for_metrics(rejected_rewards.nansum()).nansum().item()
|
| 1393 |
-
)
|
| 1394 |
-
metrics["logps/rejected_sum"] = (
|
| 1395 |
-
self.accelerator.gather_for_metrics(policy_rejected_logps.nansum()).nansum().item()
|
| 1396 |
-
)
|
| 1397 |
-
metrics["logits/rejected_sum"] = (
|
| 1398 |
-
self.accelerator.gather_for_metrics(policy_rejected_logits.nansum()).nansum().item()
|
| 1399 |
-
)
|
| 1400 |
-
metrics["count/rejected"] = all_num_rejected
|
| 1401 |
-
|
| 1402 |
-
loss = losses.nanmean()
|
| 1403 |
-
if self.aux_loss_enabled:
|
| 1404 |
-
loss += self.aux_loss_coef * aux_loss
|
| 1405 |
-
|
| 1406 |
-
return loss, metrics
|
| 1407 |
-
|
| 1408 |
-
def compute_loss(
|
| 1409 |
-
self,
|
| 1410 |
-
model: Union[PreTrainedModel, nn.Module],
|
| 1411 |
-
inputs: dict[str, Union[torch.Tensor, Any]],
|
| 1412 |
-
return_outputs=False,
|
| 1413 |
-
num_items_in_batch=None,
|
| 1414 |
-
) -> Union[torch.Tensor, tuple[torch.Tensor, dict[str, torch.Tensor]]]:
|
| 1415 |
-
compute_loss_context_manager = amp.autocast("cuda") if self._peft_has_been_casted_to_bf16 else nullcontext()
|
| 1416 |
-
|
| 1417 |
-
with compute_loss_context_manager:
|
| 1418 |
-
loss, metrics = self.get_batch_loss_metrics(model, inputs)
|
| 1419 |
-
|
| 1420 |
-
# Make sure to move the loss to the device the original accumulating loss is at back in the `Trainer` class:
|
| 1421 |
-
loss = loss.to(self.args.device)
|
| 1422 |
-
# force log the metrics
|
| 1423 |
-
if self.accelerator.is_main_process:
|
| 1424 |
-
self.store_metrics(metrics, train_eval="train")
|
| 1425 |
-
|
| 1426 |
-
if return_outputs:
|
| 1427 |
-
return (loss, metrics)
|
| 1428 |
-
return loss
|
| 1429 |
-
|
| 1430 |
-
def store_metrics(self, metrics: dict[str, float], train_eval: Literal["train", "eval"] = "train") -> None:
|
| 1431 |
-
for key, value in metrics.items():
|
| 1432 |
-
self._stored_metrics[train_eval][key].append(value)
|
| 1433 |
-
|
| 1434 |
-
def _get_train_sampler(self) -> Optional[torch.utils.data.Sampler]:
|
| 1435 |
-
if self.train_dataset is None or not has_length(self.train_dataset):
|
| 1436 |
-
return None
|
| 1437 |
-
return SequentialSampler(self.train_dataset)
|
| 1438 |
-
|
| 1439 |
-
def generate_from_model_and_ref(self, model, batch: dict[str, torch.LongTensor]) -> tuple[str, str]:
|
| 1440 |
-
"""Generate samples from the model and reference model for the given batch of inputs."""
|
| 1441 |
-
|
| 1442 |
-
# If one uses `generate_during_eval` with peft + bf16, we need to explicitly call generate with
|
| 1443 |
-
# the torch cuda amp context manager as some hidden states are silently casted to full precision.
|
| 1444 |
-
generate_context_manager = amp.autocast("cuda") if self._peft_has_been_casted_to_bf16 else nullcontext()
|
| 1445 |
-
|
| 1446 |
-
with generate_context_manager:
|
| 1447 |
-
policy_output = model.generate(
|
| 1448 |
-
input_ids=batch["prompt_input_ids"],
|
| 1449 |
-
attention_mask=batch["prompt_attention_mask"],
|
| 1450 |
-
max_length=self.max_length,
|
| 1451 |
-
do_sample=True,
|
| 1452 |
-
pad_token_id=self.processing_class.pad_token_id,
|
| 1453 |
-
)
|
| 1454 |
-
|
| 1455 |
-
# if reference_output in batch use that otherwise use the reference model
|
| 1456 |
-
if "reference_output" in batch:
|
| 1457 |
-
reference_output = batch["reference_output"]
|
| 1458 |
-
else:
|
| 1459 |
-
if self.ref_model is None:
|
| 1460 |
-
with self.null_ref_context():
|
| 1461 |
-
reference_output = self.model.generate(
|
| 1462 |
-
input_ids=batch["prompt_input_ids"],
|
| 1463 |
-
attention_mask=batch["prompt_attention_mask"],
|
| 1464 |
-
max_length=self.max_length,
|
| 1465 |
-
do_sample=True,
|
| 1466 |
-
pad_token_id=self.processing_class.pad_token_id,
|
| 1467 |
-
)
|
| 1468 |
-
else:
|
| 1469 |
-
reference_output = self.ref_model.generate(
|
| 1470 |
-
input_ids=batch["prompt_input_ids"],
|
| 1471 |
-
attention_mask=batch["prompt_attention_mask"],
|
| 1472 |
-
max_length=self.max_length,
|
| 1473 |
-
do_sample=True,
|
| 1474 |
-
pad_token_id=self.processing_class.pad_token_id,
|
| 1475 |
-
)
|
| 1476 |
-
|
| 1477 |
-
policy_output = pad_to_length(policy_output, self.max_length, self.processing_class.pad_token_id)
|
| 1478 |
-
policy_output_decoded = self.processing_class.batch_decode(policy_output, skip_special_tokens=True)
|
| 1479 |
-
|
| 1480 |
-
reference_output = pad_to_length(reference_output, self.max_length, self.processing_class.pad_token_id)
|
| 1481 |
-
reference_output_decoded = self.processing_class.batch_decode(reference_output, skip_special_tokens=True)
|
| 1482 |
-
|
| 1483 |
-
return policy_output_decoded, reference_output_decoded
|
| 1484 |
-
|
| 1485 |
-
def prediction_step(
|
| 1486 |
-
self,
|
| 1487 |
-
model: Union[PreTrainedModel, nn.Module],
|
| 1488 |
-
inputs: dict[str, Union[torch.Tensor, Any]],
|
| 1489 |
-
prediction_loss_only: bool,
|
| 1490 |
-
ignore_keys: Optional[list[str]] = None,
|
| 1491 |
-
):
|
| 1492 |
-
if ignore_keys is None:
|
| 1493 |
-
if hasattr(model, "config"):
|
| 1494 |
-
ignore_keys = getattr(model.config, "keys_to_ignore_at_inference", [])
|
| 1495 |
-
else:
|
| 1496 |
-
ignore_keys = []
|
| 1497 |
-
|
| 1498 |
-
prediction_context_manager = amp.autocast("cuda") if self._peft_has_been_casted_to_bf16 else nullcontext()
|
| 1499 |
-
with torch.no_grad(), prediction_context_manager:
|
| 1500 |
-
loss, metrics = self.get_batch_loss_metrics(model, inputs)
|
| 1501 |
-
|
| 1502 |
-
# force log the metrics
|
| 1503 |
-
if self.accelerator.is_main_process:
|
| 1504 |
-
self.store_metrics(metrics, train_eval="eval")
|
| 1505 |
-
|
| 1506 |
-
if prediction_loss_only:
|
| 1507 |
-
return (loss.detach(), None, None)
|
| 1508 |
-
|
| 1509 |
-
# logits for the chosen and rejected samples from model
|
| 1510 |
-
logits_dict = {}
|
| 1511 |
-
if "logits/chosen_sum" in metrics:
|
| 1512 |
-
logits_dict["eval_logits/chosen"] = metrics["logits/chosen_sum"]
|
| 1513 |
-
if "logits/rejected_sum" in metrics:
|
| 1514 |
-
logits_dict["eval_logits/rejected"] = metrics["logits/rejected_sum"]
|
| 1515 |
-
logits = [v for k, v in logits_dict.items() if k not in ignore_keys]
|
| 1516 |
-
logits = torch.tensor(logits, device=self.accelerator.device)
|
| 1517 |
-
labels = torch.zeros(logits.shape[0], device=self.accelerator.device)
|
| 1518 |
-
|
| 1519 |
-
return (loss.detach(), logits, labels)
|
| 1520 |
-
|
| 1521 |
-
def evaluation_loop(
|
| 1522 |
-
self,
|
| 1523 |
-
dataloader: DataLoader,
|
| 1524 |
-
description: str,
|
| 1525 |
-
prediction_loss_only: Optional[bool] = None,
|
| 1526 |
-
ignore_keys: Optional[list[str]] = None,
|
| 1527 |
-
metric_key_prefix: str = "eval",
|
| 1528 |
-
) -> EvalLoopOutput:
|
| 1529 |
-
"""
|
| 1530 |
-
Overriding built-in evaluation loop to store metrics for each batch.
|
| 1531 |
-
Prediction/evaluation loop, shared by `Trainer.evaluate()` and `Trainer.predict()`.
|
| 1532 |
-
|
| 1533 |
-
Works both with or without labels.
|
| 1534 |
-
"""
|
| 1535 |
-
|
| 1536 |
-
# Sample and save to game log if requested (for one batch to save time)
|
| 1537 |
-
if self.generate_during_eval:
|
| 1538 |
-
# Generate random indices within the range of the total number of samples
|
| 1539 |
-
num_samples = len(dataloader.dataset)
|
| 1540 |
-
random_indices = random.sample(range(num_samples), k=self.args.eval_batch_size)
|
| 1541 |
-
|
| 1542 |
-
# Use dataloader.dataset.select to get the random batch without iterating over the DataLoader
|
| 1543 |
-
random_batch_dataset = dataloader.dataset.select(random_indices)
|
| 1544 |
-
random_batch = self.data_collator(random_batch_dataset)
|
| 1545 |
-
random_batch = self._prepare_inputs(random_batch)
|
| 1546 |
-
|
| 1547 |
-
target_indicies = [i for i in range(len(random_batch["label"])) if random_batch["label"][i] is False]
|
| 1548 |
-
target_batch = {
|
| 1549 |
-
"prompt_input_ids": random_batch["prompt_input_ids"][target_indicies],
|
| 1550 |
-
"prompt_attention_mask": random_batch["prompt_attention_mask"][target_indicies],
|
| 1551 |
-
"prompt": itemgetter(*target_indicies)(random_batch["prompt"]),
|
| 1552 |
-
}
|
| 1553 |
-
policy_output_decoded, ref_output_decoded = self.generate_from_model_and_ref(self.model, target_batch)
|
| 1554 |
-
|
| 1555 |
-
table = pd.DataFrame(
|
| 1556 |
-
columns=["Prompt", "Policy", "Ref Model"],
|
| 1557 |
-
data=[
|
| 1558 |
-
[prompt, pol[len(prompt) :], ref[len(prompt) :]]
|
| 1559 |
-
for prompt, pol, ref in zip(target_batch["prompt"], policy_output_decoded, ref_output_decoded)
|
| 1560 |
-
],
|
| 1561 |
-
)
|
| 1562 |
-
if "wandb" in self.args.report_to:
|
| 1563 |
-
wandb.log({"game_log": wandb.Table(data=table)})
|
| 1564 |
-
|
| 1565 |
-
if "comet_ml" in self.args.report_to:
|
| 1566 |
-
log_table_to_comet_experiment(
|
| 1567 |
-
name="game_log.csv",
|
| 1568 |
-
table=table,
|
| 1569 |
-
)
|
| 1570 |
-
|
| 1571 |
-
# Base evaluation
|
| 1572 |
-
initial_output = super().evaluation_loop(
|
| 1573 |
-
dataloader, description, prediction_loss_only, ignore_keys, metric_key_prefix
|
| 1574 |
-
)
|
| 1575 |
-
|
| 1576 |
-
return initial_output
|
| 1577 |
-
|
| 1578 |
-
def log(self, logs: dict[str, float], start_time: Optional[float] = None) -> None:
|
| 1579 |
-
"""
|
| 1580 |
-
Log `logs` on the various objects watching training, including stored metrics.
|
| 1581 |
-
|
| 1582 |
-
Args:
|
| 1583 |
-
logs (`dict[str, float]`):
|
| 1584 |
-
The values to log.
|
| 1585 |
-
start_time (`float` or `None`, *optional*, defaults to `None`):
|
| 1586 |
-
Start time of the training.
|
| 1587 |
-
"""
|
| 1588 |
-
# logs either has 'loss' or 'eval_loss'
|
| 1589 |
-
train_eval = "train" if "loss" in logs else "eval"
|
| 1590 |
-
# train metrics should have no prefix, eval should have 'eval_'
|
| 1591 |
-
prefix = "eval_" if train_eval == "eval" else ""
|
| 1592 |
-
# accumulate average metrics from sums and lengths
|
| 1593 |
-
for split in ["chosen", "rejected"]:
|
| 1594 |
-
if f"count/{split}" in self._stored_metrics[train_eval]:
|
| 1595 |
-
count_sum = torch.Tensor(self._stored_metrics[train_eval][f"count/{split}"]).sum().item()
|
| 1596 |
-
for metric in ["rewards", "logps", "logits"]:
|
| 1597 |
-
logs[f"{prefix}{metric}/{split}"] = (
|
| 1598 |
-
torch.Tensor(self._stored_metrics[train_eval][f"{metric}/{split}_sum"]).sum().item()
|
| 1599 |
-
/ count_sum
|
| 1600 |
-
)
|
| 1601 |
-
# delete obsolete metric
|
| 1602 |
-
del self._stored_metrics[train_eval][f"{metric}/{split}_sum"]
|
| 1603 |
-
del self._stored_metrics[train_eval][f"count/{split}"]
|
| 1604 |
-
# calculate reward margin
|
| 1605 |
-
if f"{prefix}rewards/chosen" in logs and f"{prefix}rewards/rejected" in logs:
|
| 1606 |
-
logs[f"{prefix}rewards/margins"] = logs[f"{prefix}rewards/chosen"] - logs[f"{prefix}rewards/rejected"]
|
| 1607 |
-
# Add averaged stored metrics to logs
|
| 1608 |
-
for key, metrics in self._stored_metrics[train_eval].items():
|
| 1609 |
-
logs[f"{prefix}{key}"] = torch.Tensor(metrics).mean().item()
|
| 1610 |
-
del self._stored_metrics[train_eval]
|
| 1611 |
-
|
| 1612 |
-
if version.parse(transformers.__version__) >= version.parse("4.47.0.dev0"):
|
| 1613 |
-
return super().log(logs, start_time)
|
| 1614 |
-
else: # transformers<=4.46
|
| 1615 |
-
return super().log(logs)
|
| 1616 |
-
|
| 1617 |
-
def create_model_card(
|
| 1618 |
-
self,
|
| 1619 |
-
model_name: Optional[str] = None,
|
| 1620 |
-
dataset_name: Optional[str] = None,
|
| 1621 |
-
tags: Union[str, list[str], None] = None,
|
| 1622 |
-
):
|
| 1623 |
-
"""
|
| 1624 |
-
Creates a draft of a model card using the information available to the `Trainer`.
|
| 1625 |
-
|
| 1626 |
-
Args:
|
| 1627 |
-
model_name (`str` or `None`, *optional*, defaults to `None`):
|
| 1628 |
-
Name of the model.
|
| 1629 |
-
dataset_name (`str` or `None`, *optional*, defaults to `None`):
|
| 1630 |
-
Name of the dataset used for training.
|
| 1631 |
-
tags (`str`, `list[str]` or `None`, *optional*, defaults to `None`):
|
| 1632 |
-
Tags to be associated with the model card.
|
| 1633 |
-
"""
|
| 1634 |
-
if not self.is_world_process_zero():
|
| 1635 |
-
return
|
| 1636 |
-
|
| 1637 |
-
if hasattr(self.model.config, "_name_or_path") and not os.path.isdir(self.model.config._name_or_path):
|
| 1638 |
-
base_model = self.model.config._name_or_path
|
| 1639 |
-
else:
|
| 1640 |
-
base_model = None
|
| 1641 |
-
|
| 1642 |
-
tags = tags or []
|
| 1643 |
-
if isinstance(tags, str):
|
| 1644 |
-
tags = [tags]
|
| 1645 |
-
|
| 1646 |
-
if hasattr(self.model.config, "unsloth_version"):
|
| 1647 |
-
tags.append("unsloth")
|
| 1648 |
-
|
| 1649 |
-
citation = textwrap.dedent("""\
|
| 1650 |
-
@article{ethayarajh2024kto,
|
| 1651 |
-
title = {{KTO: Model Alignment as Prospect Theoretic Optimization}},
|
| 1652 |
-
author = {Kawin Ethayarajh and Winnie Xu and Niklas Muennighoff and Dan Jurafsky and Douwe Kiela},
|
| 1653 |
-
year = 2024,
|
| 1654 |
-
eprint = {arXiv:2402.01306},
|
| 1655 |
-
}""")
|
| 1656 |
-
|
| 1657 |
-
model_card = generate_model_card(
|
| 1658 |
-
base_model=base_model,
|
| 1659 |
-
model_name=model_name,
|
| 1660 |
-
hub_model_id=self.hub_model_id,
|
| 1661 |
-
dataset_name=dataset_name,
|
| 1662 |
-
tags=tags,
|
| 1663 |
-
wandb_url=wandb.run.get_url() if is_wandb_available() and wandb.run is not None else None,
|
| 1664 |
-
comet_url=get_comet_experiment_url(),
|
| 1665 |
-
trainer_name="KTO",
|
| 1666 |
-
trainer_citation=citation,
|
| 1667 |
-
paper_title="KTO: Model Alignment as Prospect Theoretic Optimization",
|
| 1668 |
-
paper_id="2402.01306",
|
| 1669 |
-
)
|
| 1670 |
-
|
| 1671 |
-
model_card.save(os.path.join(self.args.output_dir, "README.md"))
|
| 1672 |
-
class UnslothKTOTrainer(_UnslothKTOTrainer):
|
| 1673 |
-
"""
|
| 1674 |
-
|
| 1675 |
-
Initialize KTOTrainer.
|
| 1676 |
-
|
| 1677 |
-
Args:
|
| 1678 |
-
model (`transformers.PreTrainedModel`):
|
| 1679 |
-
The model to train, preferably an `AutoModelForSequenceClassification`.
|
| 1680 |
-
ref_model (`PreTrainedModelWrapper`):
|
| 1681 |
-
Hugging Face transformer model with a casual language modelling head. Used for implicit reward computation and loss. If no
|
| 1682 |
-
reference model is provided, the trainer will create a reference model with the same architecture as the model to be optimized.
|
| 1683 |
-
args (`KTOConfig`):
|
| 1684 |
-
The arguments to use for training.
|
| 1685 |
-
train_dataset (`datasets.Dataset`):
|
| 1686 |
-
The dataset to use for training.
|
| 1687 |
-
eval_dataset (`datasets.Dataset`):
|
| 1688 |
-
The dataset to use for evaluation.
|
| 1689 |
-
processing_class (`PreTrainedTokenizerBase` or `BaseImageProcessor` or `FeatureExtractionMixin` or `ProcessorMixin`, *optional*):
|
| 1690 |
-
Processing class used to process the data. If provided, will be used to automatically process the inputs
|
| 1691 |
-
for the model, and it will be saved along the model to make it easier to rerun an interrupted training or
|
| 1692 |
-
reuse the fine-tuned model.
|
| 1693 |
-
data_collator (`transformers.DataCollator`, *optional*, defaults to `None`):
|
| 1694 |
-
The data collator to use for training. If None is specified, the default data collator (`DPODataCollatorWithPadding`) will be used
|
| 1695 |
-
which will pad the sequences to the maximum length of the sequences in the batch, given a dataset of paired sequences.
|
| 1696 |
-
model_init (`Callable[[], transformers.PreTrainedModel]`):
|
| 1697 |
-
The model initializer to use for training. If None is specified, the default model initializer will be used.
|
| 1698 |
-
callbacks (`list[transformers.TrainerCallback]`):
|
| 1699 |
-
The callbacks to use for training.
|
| 1700 |
-
optimizers (`tuple[torch.optim.Optimizer, torch.optim.lr_scheduler.LambdaLR]`):
|
| 1701 |
-
The optimizer and scheduler to use for training.
|
| 1702 |
-
preprocess_logits_for_metrics (`Callable[[torch.Tensor, torch.Tensor], torch.Tensor]`):
|
| 1703 |
-
The function to use to preprocess the logits before computing the metrics.
|
| 1704 |
-
peft_config (`dict`, defaults to `None`):
|
| 1705 |
-
The PEFT configuration to use for training. If you pass a PEFT configuration, the model will be wrapped in a PEFT model.
|
| 1706 |
-
compute_metrics (`Callable[[EvalPrediction], dict]`, *optional*):
|
| 1707 |
-
The function to use to compute the metrics. Must take a `EvalPrediction` and return
|
| 1708 |
-
a dictionary string to metric values.
|
| 1709 |
-
model_adapter_name (`str`, defaults to `None`):
|
| 1710 |
-
Name of the train target PEFT adapter, when using LoRA with multiple adapters.
|
| 1711 |
-
ref_adapter_name (`str`, defaults to `None`):
|
| 1712 |
-
Name of the reference PEFT adapter, when using LoRA with multiple adapters.
|
| 1713 |
-
|
| 1714 |
-
"""
|
| 1715 |
-
def __init__(
|
| 1716 |
-
self,
|
| 1717 |
-
model = None,
|
| 1718 |
-
ref_model = None,
|
| 1719 |
-
args = None,
|
| 1720 |
-
train_dataset = None,
|
| 1721 |
-
eval_dataset = None,
|
| 1722 |
-
processing_class = None,
|
| 1723 |
-
data_collator = None,
|
| 1724 |
-
model_init = None,
|
| 1725 |
-
callbacks = None,
|
| 1726 |
-
preprocess_logits_for_metrics = None,
|
| 1727 |
-
peft_config = None,
|
| 1728 |
-
compute_metrics = None,
|
| 1729 |
-
model_adapter_name = None,
|
| 1730 |
-
ref_adapter_name = None,
|
| 1731 |
-
**kwargs
|
| 1732 |
-
):
|
| 1733 |
-
if args is None: args = UnslothKTOConfig()
|
| 1734 |
-
use_bf16 = getattr(args, 'bf16', False)
|
| 1735 |
-
if type(use_bf16) is not bool: use_bf16 = False
|
| 1736 |
-
use_fp16 = getattr(args, 'fp16', False)
|
| 1737 |
-
if type(use_fp16) is not bool: use_fp16 = False
|
| 1738 |
-
force_float32 = False
|
| 1739 |
-
if os.environ.get('UNSLOTH_FORCE_FLOAT32', '0') == '1':
|
| 1740 |
-
print('Unsloth: Switching to float32 training since model cannot work with float16')
|
| 1741 |
-
force_float32 = True
|
| 1742 |
-
mixed_precision_dtype = os.environ.get('UNSLOTH_MIXED_PRECISION', 'float32')
|
| 1743 |
-
dtype = getattr(model.config, 'torch_dtype', None)
|
| 1744 |
-
if dtype is None: dtype = model.get_input_embeddings().dtype
|
| 1745 |
-
from unsloth_zoo.utils import _get_dtype
|
| 1746 |
-
dtype = _get_dtype(dtype)
|
| 1747 |
-
float16 = dtype == torch.float16
|
| 1748 |
-
if not force_float32 and (float16 and use_bf16): raise TypeError('Unsloth: Model is in float16 precision but you want to use bfloat16 precision. Set fp16 to `True` and bf16 to `False`')
|
| 1749 |
-
if not force_float32 and (not float16 and use_fp16): raise TypeError('Unsloth: Model is in bfloat16 precision but you want to use float16 precision. Set fp16 to `False` and bf16 to `True`')
|
| 1750 |
-
if force_float32:
|
| 1751 |
-
args.fp16 = False
|
| 1752 |
-
args.bf16 = False
|
| 1753 |
-
os.environ['ACCELERATE_MIXED_PRECISION'] = 'no'
|
| 1754 |
-
elif (not use_bf16 and not use_fp16) and mixed_precision_dtype == 'float32':
|
| 1755 |
-
args.fp16 = float16
|
| 1756 |
-
args.bf16 = not float16
|
| 1757 |
-
os.environ['ACCELERATE_MIXED_PRECISION'] = 'fp16' if float16 else 'bf16'
|
| 1758 |
-
if getattr(args, 'eval_dataset', None) is not None and getattr(args, 'eval_strategy', 'no') == 'no':
|
| 1759 |
-
args.eval_strategy = 'steps'
|
| 1760 |
-
if getattr(args, 'eval_steps', None) is None: args.eval_steps = 0.1
|
| 1761 |
-
ga_steps = getattr(args, 'gradient_accumulation_steps', None)
|
| 1762 |
-
if ga_steps is not None and ga_steps > 1:
|
| 1763 |
-
from transformers import __version__ as transformers_version
|
| 1764 |
-
if Version(transformers_version) <= Version('4.45.2'):
|
| 1765 |
-
print('**** Unsloth: Please use our fixed gradient_accumulation_steps by updating transformers, TRL and Unsloth!\n'
|
| 1766 |
-
'`pip install --upgrade --no-cache-dir --force-reinstall --no-deps unsloth transformers trl unsloth_zoo`')
|
| 1767 |
-
if getattr(args, 'eval_strategy', 'no') != 'no':
|
| 1768 |
-
eval_bsz = getattr(args, 'per_device_eval_batch_size', 8)
|
| 1769 |
-
if eval_bsz == 8 and args.per_device_train_batch_size < eval_bsz: args.per_device_eval_batch_size = args.per_device_train_batch_size
|
| 1770 |
-
if getattr(args, 'eval_accumulation_steps', None) is None and ga_steps is not None: args.eval_accumulation_steps = ga_steps
|
| 1771 |
-
fp16_full_eval = getattr(args, 'fp16_full_eval', False)
|
| 1772 |
-
if type(fp16_full_eval) is not bool: fp16_full_eval = False
|
| 1773 |
-
bf16_full_eval = getattr(args, 'bf16_full_eval', False)
|
| 1774 |
-
if type(bf16_full_eval) is not bool: bf16_full_eval = False
|
| 1775 |
-
if args.fp16 and bf16_full_eval: args.bf16_full_eval = False; args.fp16_full_eval = True
|
| 1776 |
-
if args.bf16 and fp16_full_eval: args.bf16_full_eval = True; args.fp16_full_eval = False
|
| 1777 |
-
if force_float32:
|
| 1778 |
-
args.bf16_full_eval = False
|
| 1779 |
-
args.fp16_full_eval = False
|
| 1780 |
-
elif os.environ.get('UNSLOTH_MIXED_PRECISION', 'float32') == 'bfloat16':
|
| 1781 |
-
args.bf16_full_eval = True
|
| 1782 |
-
args.fp16_full_eval = False
|
| 1783 |
-
elif not bf16_full_eval and not fp16_full_eval:
|
| 1784 |
-
args.bf16_full_eval = args.bf16
|
| 1785 |
-
args.fp16_full_eval = args.fp16
|
| 1786 |
-
_output_logits = False
|
| 1787 |
-
if locals().get('compute_metrics', None) is not None: _output_logits = True
|
| 1788 |
-
if locals().get('preprocess_logits_for_metrics', None) is not None: _output_logits = True
|
| 1789 |
-
if _output_logits:
|
| 1790 |
-
os.environ['UNSLOTH_RETURN_LOGITS'] = '1'
|
| 1791 |
-
if 'max_seq_length' not in locals() and not hasattr(args, 'max_seq_length'):
|
| 1792 |
-
pass
|
| 1793 |
-
else:
|
| 1794 |
-
model_max_seq_length = getattr(model, 'max_seq_length', None)
|
| 1795 |
-
args_max_seq_length = getattr(args, 'max_seq_length', None)
|
| 1796 |
-
if args_max_seq_length is None and model_max_seq_length is not None:
|
| 1797 |
-
max_seq_length = model.max_seq_length
|
| 1798 |
-
if hasattr(args, 'max_seq_length'): args.max_seq_length = max_seq_length
|
| 1799 |
-
if model is not None and hasattr(model, 'for_training'):
|
| 1800 |
-
model.for_training()
|
| 1801 |
-
if 'tokenizer' in locals() and hasattr(tokenizer, 'padding_side'): tokenizer.padding_side = 'right'
|
| 1802 |
-
if 'processing_class' in locals():
|
| 1803 |
-
if hasattr(processing_class, 'padding_side'): processing_class.padding_side = 'right'
|
| 1804 |
-
if hasattr(processing_class, 'tokenizer') and hasattr(processing_class.tokenizer, 'padding_side'): processing_class.tokenizer.padding_side = 'right'
|
| 1805 |
-
__tokenizer = processing_class if 'processing_class' in locals() else tokenizer
|
| 1806 |
-
from unsloth_zoo.vision_utils import UnslothVisionDataCollator
|
| 1807 |
-
if not isinstance(data_collator, UnslothVisionDataCollator):
|
| 1808 |
-
if isinstance(data_collator, DataCollatorForSeq2Seq) and 'labels' not in train_dataset.column_names:
|
| 1809 |
-
data_collator = TransformersDataCollatorForLanguageModeling(__tokenizer, mlm = False, mlm_probability = 0.0)
|
| 1810 |
-
elif isinstance(data_collator, TransformersDataCollatorForLanguageModeling) and 'labels' in train_dataset.column_names:
|
| 1811 |
-
data_collator = DataCollatorForSeq2Seq(__tokenizer)
|
| 1812 |
-
else:
|
| 1813 |
-
if hasattr(args, 'remove_unused_columns'): args.remove_unused_columns = False
|
| 1814 |
-
if hasattr(args, 'dataset_text_field'): args.dataset_text_field = ''
|
| 1815 |
-
if hasattr(args, 'dataset_kwargs'): args.dataset_kwargs = {'skip_prepare_dataset': True}
|
| 1816 |
-
if not isinstance(data_collator, UnslothVisionDataCollator):
|
| 1817 |
-
if not hasattr(__tokenizer, 'pad') and hasattr(__tokenizer, 'tokenizer'):
|
| 1818 |
-
if isinstance(data_collator, DataCollatorForSeq2Seq):
|
| 1819 |
-
data_collator = DataCollatorForSeq2Seq(__tokenizer.tokenizer)
|
| 1820 |
-
else:
|
| 1821 |
-
data_collator = TransformersDataCollatorForLanguageModeling(__tokenizer.tokenizer, mlm = False, mlm_probability = 0.0)
|
| 1822 |
-
other_metrics = []
|
| 1823 |
-
|
| 1824 |
-
from unsloth_zoo.logging_utils import PatchRLStatistics
|
| 1825 |
-
PatchRLStatistics('kto_trainer', other_metrics)
|
| 1826 |
-
|
| 1827 |
-
super().__init__(
|
| 1828 |
-
model = model,
|
| 1829 |
-
ref_model = ref_model,
|
| 1830 |
-
args = args,
|
| 1831 |
-
train_dataset = train_dataset,
|
| 1832 |
-
eval_dataset = eval_dataset,
|
| 1833 |
-
processing_class = processing_class,
|
| 1834 |
-
data_collator = data_collator,
|
| 1835 |
-
model_init = model_init,
|
| 1836 |
-
callbacks = callbacks,
|
| 1837 |
-
preprocess_logits_for_metrics = preprocess_logits_for_metrics,
|
| 1838 |
-
peft_config = peft_config,
|
| 1839 |
-
compute_metrics = compute_metrics,
|
| 1840 |
-
model_adapter_name = model_adapter_name,
|
| 1841 |
-
ref_adapter_name = ref_adapter_name,**kwargs)
|
| 1842 |
-
if hasattr(self, 'neftune_hook_handle'):
|
| 1843 |
-
self.neftune_hook_handle.remove()
|
| 1844 |
-
if hasattr(self, 'neftune_hook_handle'): del self.neftune_hook_handle
|
| 1845 |
-
if getattr(args, 'neftune_noise_alpha', None) is not None:
|
| 1846 |
-
model.get_input_embeddings().neftune_noise_alpha = self.neftune_noise_alpha
|
| 1847 |
-
pass
|
| 1848 |
-
|
| 1849 |
-
pass
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
test_run_uploads/UnslothNashMDTrainer.py
DELETED
|
@@ -1,969 +0,0 @@
|
|
| 1 |
-
"""
|
| 2 |
-
2025.7.11
|
| 3 |
-
2025.7.11
|
| 4 |
-
4.54.1
|
| 5 |
-
0.16.1
|
| 6 |
-
__UNSLOTH_VERSIONING__
|
| 7 |
-
"""
|
| 8 |
-
from torch import Tensor
|
| 9 |
-
import torch
|
| 10 |
-
import torch.nn as nn
|
| 11 |
-
from torch.nn import functional as F
|
| 12 |
-
from typing import Any, List, Optional, Tuple, Union, Dict, Set, Callable
|
| 13 |
-
from trl.trainer.nash_md_trainer import (Any, BaseImageProcessor, BasePairwiseJudge, Callable, Dataset, EvalPrediction, F, FeatureExtractionMixin, GeometricMixtureWrapper, IterableDataset, NashMDConfig, NashMDTrainer, OnlineDPOTrainer, OptimizerNames, Optional, PreTrainedModel, PreTrainedTokenizerBase, ProcessorMixin, SIMPLE_CHAT_TEMPLATE, TrainerCallback, Union, empty_cache, generate_model_card, get_comet_experiment_url, get_reward, is_conversational, is_wandb_available, jinja2, maybe_apply_chat_template, nn, os, selective_log_softmax, textwrap, torch, truncate_right, unwrap_model_for_generation, wandb)
|
| 14 |
-
|
| 15 |
-
|
| 16 |
-
import os
|
| 17 |
-
from typing import *
|
| 18 |
-
from dataclasses import dataclass, field
|
| 19 |
-
from packaging.version import Version
|
| 20 |
-
import torch
|
| 21 |
-
import numpy as np
|
| 22 |
-
from contextlib import nullcontext
|
| 23 |
-
from torch.nn import functional as F
|
| 24 |
-
from transformers import DataCollatorForSeq2Seq, DataCollatorForLanguageModeling as TransformersDataCollatorForLanguageModeling
|
| 25 |
-
|
| 26 |
-
torch_compile_options = {
|
| 27 |
-
"epilogue_fusion" : True,
|
| 28 |
-
"max_autotune" : False,
|
| 29 |
-
"shape_padding" : True,
|
| 30 |
-
"trace.enabled" : False,
|
| 31 |
-
"triton.cudagraphs" : False,
|
| 32 |
-
}
|
| 33 |
-
|
| 34 |
-
@torch.compile(dynamic = True, fullgraph = True, options = torch_compile_options,)
|
| 35 |
-
def chunked_selective_log_softmax(logits, index):
|
| 36 |
-
# Split into 4 chunks only
|
| 37 |
-
chunked_logits = torch.chunk(logits.reshape(-1, logits.shape[-1]), chunks = 4, dim = 0)
|
| 38 |
-
chunked_index = torch.chunk(index.reshape(-1), chunks = 4, dim = 0)
|
| 39 |
-
all_per_token_logps = []
|
| 40 |
-
# Below loop does the same as selective_log_softmax(chunk_logits, chunk_index)
|
| 41 |
-
for chunk_logits, chunk_index in zip(chunked_logits, chunked_index):
|
| 42 |
-
chunk_logits = chunk_logits.to(torch.float32)
|
| 43 |
-
selected_logits = torch.gather(chunk_logits, dim = -1, index = chunk_index.unsqueeze(-1)).squeeze(-1)
|
| 44 |
-
logsumexp_values = torch.logsumexp(chunk_logits, dim = -1)
|
| 45 |
-
per_token_logps = selected_logits - logsumexp_values
|
| 46 |
-
all_per_token_logps.append(per_token_logps)
|
| 47 |
-
pass
|
| 48 |
-
all_per_token_logps = torch.concat(all_per_token_logps)
|
| 49 |
-
all_per_token_logps = all_per_token_logps.reshape((logits.shape[0], logits.shape[1]))
|
| 50 |
-
return all_per_token_logps
|
| 51 |
-
@dataclass
|
| 52 |
-
class UnslothNashMDConfig(NashMDConfig):
|
| 53 |
-
"""
|
| 54 |
-
|
| 55 |
-
Configuration class for the [`NashMDTrainer`].
|
| 56 |
-
|
| 57 |
-
Subclass of [`OnlineDPOConfig`] we can use all its arguments and add the following:
|
| 58 |
-
|
| 59 |
-
Parameters:
|
| 60 |
-
mixture_coef (`float` or `list[float]`, *optional*, defaults to `0.5`):
|
| 61 |
-
Logit mixture coefficient for the model and reference model. If a list of floats is provided then the
|
| 62 |
-
mixture coefficient is selected for each new epoch and the last coefficient is used for the rest of the
|
| 63 |
-
epochs.
|
| 64 |
-
|
| 65 |
-
"""
|
| 66 |
-
vllm_sampling_params: Optional[Any] = field(
|
| 67 |
-
default = None,
|
| 68 |
-
metadata = {'help': 'vLLM SamplingParams'},
|
| 69 |
-
)
|
| 70 |
-
unsloth_num_chunks : Optional[int] = field(
|
| 71 |
-
default = -1,
|
| 72 |
-
metadata = {'help': 'Chunk size to reduce memory usage. -1 is most efficient.'},
|
| 73 |
-
)
|
| 74 |
-
def __init__(
|
| 75 |
-
self,
|
| 76 |
-
output_dir = None,
|
| 77 |
-
overwrite_output_dir = None,
|
| 78 |
-
do_train = False,
|
| 79 |
-
do_eval = False,
|
| 80 |
-
do_predict = False,
|
| 81 |
-
eval_strategy = 'no',
|
| 82 |
-
prediction_loss_only = False,
|
| 83 |
-
per_device_train_batch_size = 4,
|
| 84 |
-
per_device_eval_batch_size = 4,
|
| 85 |
-
per_gpu_train_batch_size = None,
|
| 86 |
-
per_gpu_eval_batch_size = None,
|
| 87 |
-
gradient_accumulation_steps = 2,
|
| 88 |
-
eval_accumulation_steps = 2,
|
| 89 |
-
eval_delay = 0,
|
| 90 |
-
torch_empty_cache_steps = 250,
|
| 91 |
-
learning_rate = 5e-05,
|
| 92 |
-
weight_decay = 0.01,
|
| 93 |
-
adam_beta1 = 0.9,
|
| 94 |
-
adam_beta2 = 0.999,
|
| 95 |
-
adam_epsilon = 1e-08,
|
| 96 |
-
max_grad_norm = 1.0,
|
| 97 |
-
num_train_epochs = 3.0,
|
| 98 |
-
max_steps = -1,
|
| 99 |
-
lr_scheduler_type = 'linear',
|
| 100 |
-
warmup_ratio = 0.1,
|
| 101 |
-
warmup_steps = 0,
|
| 102 |
-
log_level = 'passive',
|
| 103 |
-
log_level_replica = 'warning',
|
| 104 |
-
log_on_each_node = True,
|
| 105 |
-
logging_dir = None,
|
| 106 |
-
logging_strategy = 'steps',
|
| 107 |
-
logging_first_step = False,
|
| 108 |
-
logging_steps = 1,
|
| 109 |
-
logging_nan_inf_filter = False,
|
| 110 |
-
save_strategy = 'steps',
|
| 111 |
-
save_steps = 500,
|
| 112 |
-
save_total_limit = None,
|
| 113 |
-
save_safetensors = True,
|
| 114 |
-
save_on_each_node = False,
|
| 115 |
-
save_only_model = False,
|
| 116 |
-
restore_callback_states_from_checkpoint = False,
|
| 117 |
-
no_cuda = False,
|
| 118 |
-
use_cpu = False,
|
| 119 |
-
use_mps_device = False,
|
| 120 |
-
seed = 3407,
|
| 121 |
-
data_seed = 3407,
|
| 122 |
-
jit_mode_eval = False,
|
| 123 |
-
use_ipex = False,
|
| 124 |
-
bf16 = False,
|
| 125 |
-
fp16 = False,
|
| 126 |
-
fp16_opt_level = 'O1',
|
| 127 |
-
half_precision_backend = 'auto',
|
| 128 |
-
bf16_full_eval = False,
|
| 129 |
-
fp16_full_eval = False,
|
| 130 |
-
tf32 = None,
|
| 131 |
-
local_rank = -1,
|
| 132 |
-
ddp_backend = None,
|
| 133 |
-
tpu_num_cores = None,
|
| 134 |
-
tpu_metrics_debug = False,
|
| 135 |
-
debug = '',
|
| 136 |
-
dataloader_drop_last = False,
|
| 137 |
-
eval_steps = None,
|
| 138 |
-
dataloader_num_workers = 0,
|
| 139 |
-
dataloader_prefetch_factor = None,
|
| 140 |
-
past_index = -1,
|
| 141 |
-
run_name = None,
|
| 142 |
-
disable_tqdm = None,
|
| 143 |
-
remove_unused_columns = True,
|
| 144 |
-
label_names = None,
|
| 145 |
-
load_best_model_at_end = False,
|
| 146 |
-
metric_for_best_model = None,
|
| 147 |
-
greater_is_better = None,
|
| 148 |
-
ignore_data_skip = False,
|
| 149 |
-
fsdp = '',
|
| 150 |
-
fsdp_min_num_params = 0,
|
| 151 |
-
fsdp_config = None,
|
| 152 |
-
fsdp_transformer_layer_cls_to_wrap = None,
|
| 153 |
-
accelerator_config = None,
|
| 154 |
-
deepspeed = None,
|
| 155 |
-
label_smoothing_factor = 0.0,
|
| 156 |
-
optim = 'adamw_8bit',
|
| 157 |
-
optim_args = None,
|
| 158 |
-
adafactor = False,
|
| 159 |
-
group_by_length = False,
|
| 160 |
-
length_column_name = 'length',
|
| 161 |
-
report_to = None,
|
| 162 |
-
ddp_find_unused_parameters = None,
|
| 163 |
-
ddp_bucket_cap_mb = None,
|
| 164 |
-
ddp_broadcast_buffers = None,
|
| 165 |
-
dataloader_pin_memory = True,
|
| 166 |
-
dataloader_persistent_workers = False,
|
| 167 |
-
skip_memory_metrics = True,
|
| 168 |
-
use_legacy_prediction_loop = False,
|
| 169 |
-
push_to_hub = False,
|
| 170 |
-
resume_from_checkpoint = None,
|
| 171 |
-
hub_model_id = None,
|
| 172 |
-
hub_strategy = 'every_save',
|
| 173 |
-
hub_token = None,
|
| 174 |
-
hub_private_repo = None,
|
| 175 |
-
hub_always_push = False,
|
| 176 |
-
hub_revision = None,
|
| 177 |
-
gradient_checkpointing = False,
|
| 178 |
-
gradient_checkpointing_kwargs = None,
|
| 179 |
-
include_inputs_for_metrics = False,
|
| 180 |
-
eval_do_concat_batches = True,
|
| 181 |
-
fp16_backend = 'auto',
|
| 182 |
-
push_to_hub_model_id = None,
|
| 183 |
-
push_to_hub_organization = None,
|
| 184 |
-
push_to_hub_token = None,
|
| 185 |
-
mp_parameters = '',
|
| 186 |
-
auto_find_batch_size = True,
|
| 187 |
-
full_determinism = False,
|
| 188 |
-
torchdynamo = None,
|
| 189 |
-
ray_scope = 'last',
|
| 190 |
-
ddp_timeout = 1800,
|
| 191 |
-
torch_compile = False,
|
| 192 |
-
torch_compile_backend = None,
|
| 193 |
-
torch_compile_mode = None,
|
| 194 |
-
include_tokens_per_second = False,
|
| 195 |
-
include_num_input_tokens_seen = False,
|
| 196 |
-
neftune_noise_alpha = None,
|
| 197 |
-
optim_target_modules = None,
|
| 198 |
-
batch_eval_metrics = False,
|
| 199 |
-
eval_on_start = False,
|
| 200 |
-
use_liger_kernel = False,
|
| 201 |
-
liger_kernel_config = None,
|
| 202 |
-
eval_use_gather_object = False,
|
| 203 |
-
average_tokens_across_devices = True,
|
| 204 |
-
reward_model_path = None,
|
| 205 |
-
judge = None,
|
| 206 |
-
max_new_tokens = 64,
|
| 207 |
-
max_length = 512,
|
| 208 |
-
temperature = 0.9,
|
| 209 |
-
missing_eos_penalty = None,
|
| 210 |
-
loss_type = 'sigmoid',
|
| 211 |
-
dataset_num_proc = None,
|
| 212 |
-
disable_dropout = True,
|
| 213 |
-
use_vllm = False,
|
| 214 |
-
ds3_gather_for_generation = True,
|
| 215 |
-
vllm_sampling_params = None,
|
| 216 |
-
unsloth_num_chunks = -1,
|
| 217 |
-
**kwargs,
|
| 218 |
-
):
|
| 219 |
-
if learning_rate < 1e-7: raise FloatingPointError(f'Unsloth: Your learning rate of `{learning_rate}` is too small and less than 1e-7! Consider increasing it, otherwise gradient updates will be close to 0!')
|
| 220 |
-
if learning_rate > 1: raise OverflowError(f'Unsloth: Your learning rate of `{learning_rate}` is way too larger > 1! Consider decreasing it to 1e-1, otherwise gradient updates will explode!')
|
| 221 |
-
if output_dir is None and save_strategy == 'steps' and save_steps == 500:
|
| 222 |
-
output_dir = 'unsloth_training_checkpoints'
|
| 223 |
-
save_strategy = 'no'
|
| 224 |
-
if dataset_num_proc is None:
|
| 225 |
-
from multiprocessing import cpu_count
|
| 226 |
-
dataset_num_proc = min(cpu_count()*2, 2)
|
| 227 |
-
if temperature <= 0:
|
| 228 |
-
raise MathError('Unsloth: Please set a positive non-zero temperature since your results will be wrong.')
|
| 229 |
-
elif temperature >= 10:
|
| 230 |
-
raise MathError('Unsloth: Please set a positive non-zero temperature less than 10, since sampling will be quite erratic.')
|
| 231 |
-
|
| 232 |
-
|
| 233 |
-
super().__init__(
|
| 234 |
-
output_dir = output_dir,
|
| 235 |
-
overwrite_output_dir = overwrite_output_dir,
|
| 236 |
-
do_train = do_train,
|
| 237 |
-
do_eval = do_eval,
|
| 238 |
-
do_predict = do_predict,
|
| 239 |
-
eval_strategy = eval_strategy,
|
| 240 |
-
prediction_loss_only = prediction_loss_only,
|
| 241 |
-
per_device_train_batch_size = per_device_train_batch_size,
|
| 242 |
-
per_device_eval_batch_size = per_device_eval_batch_size,
|
| 243 |
-
per_gpu_train_batch_size = per_gpu_train_batch_size,
|
| 244 |
-
per_gpu_eval_batch_size = per_gpu_eval_batch_size,
|
| 245 |
-
gradient_accumulation_steps = gradient_accumulation_steps,
|
| 246 |
-
eval_accumulation_steps = eval_accumulation_steps,
|
| 247 |
-
eval_delay = eval_delay,
|
| 248 |
-
torch_empty_cache_steps = torch_empty_cache_steps,
|
| 249 |
-
learning_rate = learning_rate,
|
| 250 |
-
weight_decay = weight_decay,
|
| 251 |
-
adam_beta1 = adam_beta1,
|
| 252 |
-
adam_beta2 = adam_beta2,
|
| 253 |
-
adam_epsilon = adam_epsilon,
|
| 254 |
-
max_grad_norm = max_grad_norm,
|
| 255 |
-
num_train_epochs = num_train_epochs,
|
| 256 |
-
max_steps = max_steps,
|
| 257 |
-
lr_scheduler_type = lr_scheduler_type,
|
| 258 |
-
warmup_ratio = warmup_ratio,
|
| 259 |
-
warmup_steps = warmup_steps,
|
| 260 |
-
log_level = log_level,
|
| 261 |
-
log_level_replica = log_level_replica,
|
| 262 |
-
log_on_each_node = log_on_each_node,
|
| 263 |
-
logging_dir = logging_dir,
|
| 264 |
-
logging_strategy = logging_strategy,
|
| 265 |
-
logging_first_step = logging_first_step,
|
| 266 |
-
logging_steps = logging_steps,
|
| 267 |
-
logging_nan_inf_filter = logging_nan_inf_filter,
|
| 268 |
-
save_strategy = save_strategy,
|
| 269 |
-
save_steps = save_steps,
|
| 270 |
-
save_total_limit = save_total_limit,
|
| 271 |
-
save_safetensors = save_safetensors,
|
| 272 |
-
save_on_each_node = save_on_each_node,
|
| 273 |
-
save_only_model = save_only_model,
|
| 274 |
-
restore_callback_states_from_checkpoint = restore_callback_states_from_checkpoint,
|
| 275 |
-
no_cuda = no_cuda,
|
| 276 |
-
use_cpu = use_cpu,
|
| 277 |
-
use_mps_device = use_mps_device,
|
| 278 |
-
seed = seed,
|
| 279 |
-
data_seed = data_seed,
|
| 280 |
-
jit_mode_eval = jit_mode_eval,
|
| 281 |
-
use_ipex = use_ipex,
|
| 282 |
-
bf16 = bf16,
|
| 283 |
-
fp16 = fp16,
|
| 284 |
-
fp16_opt_level = fp16_opt_level,
|
| 285 |
-
half_precision_backend = half_precision_backend,
|
| 286 |
-
bf16_full_eval = bf16_full_eval,
|
| 287 |
-
fp16_full_eval = fp16_full_eval,
|
| 288 |
-
tf32 = tf32,
|
| 289 |
-
local_rank = local_rank,
|
| 290 |
-
ddp_backend = ddp_backend,
|
| 291 |
-
tpu_num_cores = tpu_num_cores,
|
| 292 |
-
tpu_metrics_debug = tpu_metrics_debug,
|
| 293 |
-
debug = debug,
|
| 294 |
-
dataloader_drop_last = dataloader_drop_last,
|
| 295 |
-
eval_steps = eval_steps,
|
| 296 |
-
dataloader_num_workers = dataloader_num_workers,
|
| 297 |
-
dataloader_prefetch_factor = dataloader_prefetch_factor,
|
| 298 |
-
past_index = past_index,
|
| 299 |
-
run_name = run_name,
|
| 300 |
-
disable_tqdm = disable_tqdm,
|
| 301 |
-
remove_unused_columns = remove_unused_columns,
|
| 302 |
-
label_names = label_names,
|
| 303 |
-
load_best_model_at_end = load_best_model_at_end,
|
| 304 |
-
metric_for_best_model = metric_for_best_model,
|
| 305 |
-
greater_is_better = greater_is_better,
|
| 306 |
-
ignore_data_skip = ignore_data_skip,
|
| 307 |
-
fsdp = fsdp,
|
| 308 |
-
fsdp_min_num_params = fsdp_min_num_params,
|
| 309 |
-
fsdp_config = fsdp_config,
|
| 310 |
-
fsdp_transformer_layer_cls_to_wrap = fsdp_transformer_layer_cls_to_wrap,
|
| 311 |
-
accelerator_config = accelerator_config,
|
| 312 |
-
deepspeed = deepspeed,
|
| 313 |
-
label_smoothing_factor = label_smoothing_factor,
|
| 314 |
-
optim = optim,
|
| 315 |
-
optim_args = optim_args,
|
| 316 |
-
adafactor = adafactor,
|
| 317 |
-
group_by_length = group_by_length,
|
| 318 |
-
length_column_name = length_column_name,
|
| 319 |
-
report_to = report_to,
|
| 320 |
-
ddp_find_unused_parameters = ddp_find_unused_parameters,
|
| 321 |
-
ddp_bucket_cap_mb = ddp_bucket_cap_mb,
|
| 322 |
-
ddp_broadcast_buffers = ddp_broadcast_buffers,
|
| 323 |
-
dataloader_pin_memory = dataloader_pin_memory,
|
| 324 |
-
dataloader_persistent_workers = dataloader_persistent_workers,
|
| 325 |
-
skip_memory_metrics = skip_memory_metrics,
|
| 326 |
-
use_legacy_prediction_loop = use_legacy_prediction_loop,
|
| 327 |
-
push_to_hub = push_to_hub,
|
| 328 |
-
resume_from_checkpoint = resume_from_checkpoint,
|
| 329 |
-
hub_model_id = hub_model_id,
|
| 330 |
-
hub_strategy = hub_strategy,
|
| 331 |
-
hub_token = hub_token,
|
| 332 |
-
hub_private_repo = hub_private_repo,
|
| 333 |
-
hub_always_push = hub_always_push,
|
| 334 |
-
hub_revision = hub_revision,
|
| 335 |
-
gradient_checkpointing = gradient_checkpointing,
|
| 336 |
-
gradient_checkpointing_kwargs = gradient_checkpointing_kwargs,
|
| 337 |
-
include_inputs_for_metrics = include_inputs_for_metrics,
|
| 338 |
-
eval_do_concat_batches = eval_do_concat_batches,
|
| 339 |
-
fp16_backend = fp16_backend,
|
| 340 |
-
push_to_hub_model_id = push_to_hub_model_id,
|
| 341 |
-
push_to_hub_organization = push_to_hub_organization,
|
| 342 |
-
push_to_hub_token = push_to_hub_token,
|
| 343 |
-
mp_parameters = mp_parameters,
|
| 344 |
-
auto_find_batch_size = auto_find_batch_size,
|
| 345 |
-
full_determinism = full_determinism,
|
| 346 |
-
torchdynamo = torchdynamo,
|
| 347 |
-
ray_scope = ray_scope,
|
| 348 |
-
ddp_timeout = ddp_timeout,
|
| 349 |
-
torch_compile = torch_compile,
|
| 350 |
-
torch_compile_backend = torch_compile_backend,
|
| 351 |
-
torch_compile_mode = torch_compile_mode,
|
| 352 |
-
include_tokens_per_second = include_tokens_per_second,
|
| 353 |
-
include_num_input_tokens_seen = include_num_input_tokens_seen,
|
| 354 |
-
neftune_noise_alpha = neftune_noise_alpha,
|
| 355 |
-
optim_target_modules = optim_target_modules,
|
| 356 |
-
batch_eval_metrics = batch_eval_metrics,
|
| 357 |
-
eval_on_start = eval_on_start,
|
| 358 |
-
use_liger_kernel = use_liger_kernel,
|
| 359 |
-
liger_kernel_config = liger_kernel_config,
|
| 360 |
-
eval_use_gather_object = eval_use_gather_object,
|
| 361 |
-
average_tokens_across_devices = average_tokens_across_devices,
|
| 362 |
-
reward_model_path = reward_model_path,
|
| 363 |
-
judge = judge,
|
| 364 |
-
max_new_tokens = max_new_tokens,
|
| 365 |
-
max_length = max_length,
|
| 366 |
-
temperature = temperature,
|
| 367 |
-
missing_eos_penalty = missing_eos_penalty,
|
| 368 |
-
loss_type = loss_type,
|
| 369 |
-
dataset_num_proc = dataset_num_proc,
|
| 370 |
-
disable_dropout = disable_dropout,
|
| 371 |
-
use_vllm = use_vllm,
|
| 372 |
-
ds3_gather_for_generation = ds3_gather_for_generation,**kwargs)
|
| 373 |
-
self.vllm_sampling_params = vllm_sampling_params
|
| 374 |
-
self.unsloth_num_chunks = unsloth_num_chunks
|
| 375 |
-
pass
|
| 376 |
-
|
| 377 |
-
class _UnslothNashMDTrainer(OnlineDPOTrainer):
|
| 378 |
-
r""""""
|
| 379 |
-
|
| 380 |
-
_tag_names = ["trl", "nash-md"]
|
| 381 |
-
|
| 382 |
-
def __init__(
|
| 383 |
-
self,
|
| 384 |
-
model: Union[PreTrainedModel, nn.Module] = None,
|
| 385 |
-
ref_model: Union[PreTrainedModel, nn.Module] = None,
|
| 386 |
-
reward_model: Union[PreTrainedModel, nn.Module, None] = None,
|
| 387 |
-
judge: Optional[BasePairwiseJudge] = None,
|
| 388 |
-
args: Optional[NashMDConfig] = None,
|
| 389 |
-
data_collator: Optional[Callable] = None,
|
| 390 |
-
train_dataset: Optional[Union[Dataset, IterableDataset]] = None,
|
| 391 |
-
eval_dataset: Optional[Union[Dataset, dict[str, Dataset]]] = None,
|
| 392 |
-
processing_class: Optional[
|
| 393 |
-
Union[PreTrainedTokenizerBase, BaseImageProcessor, FeatureExtractionMixin, ProcessorMixin]
|
| 394 |
-
] = None,
|
| 395 |
-
peft_config: Optional[dict] = None,
|
| 396 |
-
compute_metrics: Optional[Callable[[EvalPrediction], dict]] = None,
|
| 397 |
-
callbacks: Optional[list[TrainerCallback]] = None,
|
| 398 |
-
optimizers: tuple[torch.optim.Optimizer, torch.optim.lr_scheduler.LambdaLR] = (None, None),
|
| 399 |
-
preprocess_logits_for_metrics: Optional[Callable[[torch.Tensor, torch.Tensor], torch.Tensor]] = None,
|
| 400 |
-
) -> None:
|
| 401 |
-
super().__init__(
|
| 402 |
-
model=model,
|
| 403 |
-
ref_model=ref_model,
|
| 404 |
-
reward_model=reward_model,
|
| 405 |
-
judge=judge,
|
| 406 |
-
args=args,
|
| 407 |
-
data_collator=data_collator,
|
| 408 |
-
train_dataset=train_dataset,
|
| 409 |
-
eval_dataset=eval_dataset,
|
| 410 |
-
processing_class=processing_class,
|
| 411 |
-
reward_processing_class=processing_class, # for now, NashMDTrainer can't use any reward model
|
| 412 |
-
peft_config=peft_config,
|
| 413 |
-
compute_metrics=compute_metrics,
|
| 414 |
-
callbacks=callbacks,
|
| 415 |
-
optimizers=optimizers,
|
| 416 |
-
preprocess_logits_for_metrics=preprocess_logits_for_metrics,
|
| 417 |
-
)
|
| 418 |
-
|
| 419 |
-
self._mixture_coef = self.args.mixture_coef
|
| 420 |
-
|
| 421 |
-
# Overwrite the stats dictionary to include NashMD specific statistics
|
| 422 |
-
self.stats = {
|
| 423 |
-
# Remove "non_score_reward", "rlhf_reward", "scores_margin"
|
| 424 |
-
# Add "mixture_coef"
|
| 425 |
-
"loss/kl": [],
|
| 426 |
-
"objective/entropy": [],
|
| 427 |
-
"loss/score": [],
|
| 428 |
-
"rewards/probabilities": [],
|
| 429 |
-
"rewards/accuracies": [],
|
| 430 |
-
"rewards/margins": [],
|
| 431 |
-
"logps/chosen": [],
|
| 432 |
-
"logps/rejected": [],
|
| 433 |
-
"val/model_contain_eos_token": [],
|
| 434 |
-
"val/ref_contain_eos_token": [],
|
| 435 |
-
"beta": [],
|
| 436 |
-
"mixture_coef": [],
|
| 437 |
-
}
|
| 438 |
-
if self.reward_model is not None:
|
| 439 |
-
self.stats["rewards/chosen"] = []
|
| 440 |
-
self.stats["rewards/rejected"] = []
|
| 441 |
-
|
| 442 |
-
@property
|
| 443 |
-
def mixture_coef(self):
|
| 444 |
-
if isinstance(self._mixture_coef, list):
|
| 445 |
-
epoch = self.state.epoch
|
| 446 |
-
return self._mixture_coef[epoch] if epoch < len(self._mixture_coef) else self._mixture_coef[-1]
|
| 447 |
-
else:
|
| 448 |
-
return self._mixture_coef
|
| 449 |
-
|
| 450 |
-
def _generate_completions(self, model, prompts):
|
| 451 |
-
with unwrap_model_for_generation(model, self.accelerator) as unwrapped_model:
|
| 452 |
-
model_output = unwrapped_model.generate(
|
| 453 |
-
input_ids=prompts["input_ids"],
|
| 454 |
-
attention_mask=prompts["attention_mask"],
|
| 455 |
-
generation_config=self.generation_config,
|
| 456 |
-
)
|
| 457 |
-
|
| 458 |
-
ref_model = model if self.ref_model is None else self.ref_model
|
| 459 |
-
with torch.no_grad(), unwrap_model_for_generation(ref_model, self.accelerator) as unwrapped_ref_model:
|
| 460 |
-
mixture_model = GeometricMixtureWrapper(
|
| 461 |
-
model=unwrapped_model,
|
| 462 |
-
ref_model=unwrapped_ref_model,
|
| 463 |
-
generation_config=self.generation_config,
|
| 464 |
-
mixture_coef=self.mixture_coef,
|
| 465 |
-
device=self.accelerator.device,
|
| 466 |
-
)
|
| 467 |
-
|
| 468 |
-
mixture_output = mixture_model.generate(
|
| 469 |
-
input_ids=prompts["input_ids"],
|
| 470 |
-
attention_mask=prompts["attention_mask"],
|
| 471 |
-
generation_config=self.generation_config,
|
| 472 |
-
)
|
| 473 |
-
|
| 474 |
-
return model_output, mixture_output
|
| 475 |
-
|
| 476 |
-
def _process_completions(self, model_output, mixture_output, prompts):
|
| 477 |
-
context_length = prompts["input_ids"].shape[1]
|
| 478 |
-
|
| 479 |
-
# Process model completions
|
| 480 |
-
model_completion_ids = model_output[:, context_length:]
|
| 481 |
-
model_completion_ids, model_completion_mask = truncate_right(
|
| 482 |
-
model_completion_ids, self.processing_class.eos_token_id, self.processing_class.pad_token_id
|
| 483 |
-
)
|
| 484 |
-
model_data = {
|
| 485 |
-
"input_ids": torch.cat((prompts["input_ids"], model_completion_ids), dim=1),
|
| 486 |
-
"attention_mask": torch.cat((prompts["attention_mask"], model_completion_mask), dim=1),
|
| 487 |
-
"raw": prompts["raw"],
|
| 488 |
-
}
|
| 489 |
-
|
| 490 |
-
# Process reference model completions
|
| 491 |
-
mixture_completion_ids = mixture_output[:, context_length:]
|
| 492 |
-
mixture_completion_ids, mixture_completion_mask = truncate_right(
|
| 493 |
-
mixture_completion_ids, self.processing_class.eos_token_id, self.processing_class.pad_token_id
|
| 494 |
-
)
|
| 495 |
-
mixture_data = {
|
| 496 |
-
"input_ids": torch.cat((prompts["input_ids"], mixture_completion_ids), dim=1),
|
| 497 |
-
"attention_mask": torch.cat((prompts["attention_mask"], mixture_completion_mask), dim=1),
|
| 498 |
-
"raw": prompts["raw"],
|
| 499 |
-
}
|
| 500 |
-
|
| 501 |
-
return model_data, mixture_data
|
| 502 |
-
|
| 503 |
-
def _compute_rewards(self, model_data, mixture_data, context_length):
|
| 504 |
-
with torch.no_grad():
|
| 505 |
-
_, model_scores, _ = get_reward(
|
| 506 |
-
self.reward_model, model_data["input_ids"], self.processing_class.pad_token_id, context_length
|
| 507 |
-
)
|
| 508 |
-
_, mixture_scores, _ = get_reward(
|
| 509 |
-
self.reward_model, mixture_data["input_ids"], self.processing_class.pad_token_id, context_length
|
| 510 |
-
)
|
| 511 |
-
|
| 512 |
-
# Apply EOS penalty if needed
|
| 513 |
-
if self.args.missing_eos_penalty is not None:
|
| 514 |
-
model_contain_eos = torch.any(model_data["input_ids"] == self.processing_class.eos_token_id, dim=-1)
|
| 515 |
-
mixture_contain_eos = torch.any(mixture_data["input_ids"] == self.processing_class.eos_token_id, dim=-1)
|
| 516 |
-
model_scores[~model_contain_eos] -= self.args.missing_eos_penalty
|
| 517 |
-
mixture_scores[~mixture_contain_eos] -= self.args.missing_eos_penalty
|
| 518 |
-
|
| 519 |
-
return model_scores, mixture_scores
|
| 520 |
-
|
| 521 |
-
def _compute_judge(self, model_data, mixture_data, context_length):
|
| 522 |
-
prompts = model_data["raw"]
|
| 523 |
-
model_data_completions = self.processing_class.batch_decode(
|
| 524 |
-
model_data["input_ids"][:, context_length:], skip_special_tokens=True
|
| 525 |
-
)
|
| 526 |
-
model_data_completions = [completion.strip() for completion in model_data_completions]
|
| 527 |
-
|
| 528 |
-
mixture_data_completions = self.processing_class.batch_decode(
|
| 529 |
-
mixture_data["input_ids"][:, context_length:], skip_special_tokens=True
|
| 530 |
-
)
|
| 531 |
-
mixture_data_completions = [completion.strip() for completion in mixture_data_completions]
|
| 532 |
-
if is_conversational({"prompt": prompts[0]}):
|
| 533 |
-
model_data_completions = [
|
| 534 |
-
[{"role": "assistant", "content": completion}] for completion in model_data_completions
|
| 535 |
-
]
|
| 536 |
-
environment = jinja2.Environment()
|
| 537 |
-
template = environment.from_string(SIMPLE_CHAT_TEMPLATE)
|
| 538 |
-
prompts = [template.render(messages=message) for message in prompts]
|
| 539 |
-
model_data_completions = [template.render(messages=completion) for completion in model_data_completions]
|
| 540 |
-
|
| 541 |
-
mixture_data_completions = [
|
| 542 |
-
[{"role": "assistant", "content": completion}] for completion in mixture_data_completions
|
| 543 |
-
]
|
| 544 |
-
mixture_data_completions = [
|
| 545 |
-
template.render(messages=completion) for completion in mixture_data_completions
|
| 546 |
-
]
|
| 547 |
-
|
| 548 |
-
probability = self.judge.judge(
|
| 549 |
-
prompts,
|
| 550 |
-
list(zip(model_data_completions, mixture_data_completions)),
|
| 551 |
-
return_scores=True,
|
| 552 |
-
)
|
| 553 |
-
return torch.tensor(probability, device=model_data["input_ids"].device)
|
| 554 |
-
|
| 555 |
-
def _compute_logprobs(self, model, model_data, context_length):
|
| 556 |
-
def compute_logprobs_for_data(m, data):
|
| 557 |
-
output = m(data["input_ids"], attention_mask=data["attention_mask"])
|
| 558 |
-
logits = output.logits[:, context_length - 1 : -1]
|
| 559 |
-
token_logprobs = selective_log_softmax(logits, data["input_ids"][:, context_length:])
|
| 560 |
-
return token_logprobs
|
| 561 |
-
|
| 562 |
-
# Compute logprobs for model completions under the model
|
| 563 |
-
model_logprobs_model_data = compute_logprobs_for_data(model, model_data)
|
| 564 |
-
|
| 565 |
-
# Compute logprobs of model completions under the reference model
|
| 566 |
-
with torch.no_grad():
|
| 567 |
-
if self.ref_model is None:
|
| 568 |
-
with model.disable_adapter():
|
| 569 |
-
ref_logprobs_model_data = compute_logprobs_for_data(model, model_data)
|
| 570 |
-
else:
|
| 571 |
-
ref_logprobs_model_data = compute_logprobs_for_data(self.ref_model, model_data)
|
| 572 |
-
|
| 573 |
-
# Mask padding tokens
|
| 574 |
-
model_padding_mask = model_data["attention_mask"][:, context_length:] == 0
|
| 575 |
-
model_logprobs_model_data = model_logprobs_model_data.masked_fill(model_padding_mask, 0.0)
|
| 576 |
-
ref_logprobs_model_data = ref_logprobs_model_data.masked_fill(model_padding_mask, 0.0)
|
| 577 |
-
|
| 578 |
-
return (model_logprobs_model_data, ref_logprobs_model_data)
|
| 579 |
-
|
| 580 |
-
def _compute_losses(
|
| 581 |
-
self,
|
| 582 |
-
model_logprobs_model_data,
|
| 583 |
-
ref_logprobs_model_data,
|
| 584 |
-
probability,
|
| 585 |
-
):
|
| 586 |
-
# reinforce score where 0.5 is a control variate
|
| 587 |
-
score = (probability - 0.5) * model_logprobs_model_data.sum(1)
|
| 588 |
-
|
| 589 |
-
# kl divergence via reinforce
|
| 590 |
-
with torch.no_grad():
|
| 591 |
-
log_ratio = model_logprobs_model_data - ref_logprobs_model_data
|
| 592 |
-
kl_div_log = log_ratio.sum(1)
|
| 593 |
-
kl_div_loss = (log_ratio * model_logprobs_model_data).sum(1)
|
| 594 |
-
|
| 595 |
-
# final loss
|
| 596 |
-
loss = self.beta * kl_div_loss - score
|
| 597 |
-
|
| 598 |
-
return loss.mean(), score, kl_div_log
|
| 599 |
-
|
| 600 |
-
def _log_statistics(
|
| 601 |
-
self,
|
| 602 |
-
model_data,
|
| 603 |
-
mixture_data,
|
| 604 |
-
model_logprobs_model_data,
|
| 605 |
-
ref_logprobs_model_data,
|
| 606 |
-
probability,
|
| 607 |
-
score,
|
| 608 |
-
kl_div,
|
| 609 |
-
context_length,
|
| 610 |
-
model_scores=None,
|
| 611 |
-
mixture_scores=None,
|
| 612 |
-
):
|
| 613 |
-
# Helper function to gather and compute mean
|
| 614 |
-
def gather_mean(tensor):
|
| 615 |
-
return self.accelerator.gather_for_metrics(tensor).mean().item()
|
| 616 |
-
|
| 617 |
-
# Log score
|
| 618 |
-
self.stats["loss/score"].append(gather_mean(score))
|
| 619 |
-
# Log KL divergence
|
| 620 |
-
self.stats["loss/kl"].append(gather_mean(kl_div))
|
| 621 |
-
|
| 622 |
-
# Log logprobs
|
| 623 |
-
model_logprobs_model_data_sum = model_logprobs_model_data.sum(1)
|
| 624 |
-
ref_logprobs_model_data_sum = ref_logprobs_model_data.sum(1)
|
| 625 |
-
|
| 626 |
-
self.stats["logps/chosen"].append(gather_mean(model_logprobs_model_data_sum))
|
| 627 |
-
self.stats["logps/rejected"].append(gather_mean(ref_logprobs_model_data_sum))
|
| 628 |
-
|
| 629 |
-
# Log rewards
|
| 630 |
-
if self.reward_model is not None:
|
| 631 |
-
self.stats["rewards/chosen"].append(gather_mean(model_scores))
|
| 632 |
-
self.stats["rewards/rejected"].append(gather_mean(mixture_scores))
|
| 633 |
-
|
| 634 |
-
# Log probabilities
|
| 635 |
-
self.stats["rewards/probabilities"].append(gather_mean(probability))
|
| 636 |
-
|
| 637 |
-
# Calculate entropy for model data
|
| 638 |
-
entropy_model_data = -model_logprobs_model_data.sum(1)
|
| 639 |
-
self.stats["objective/entropy"].append(gather_mean(entropy_model_data))
|
| 640 |
-
|
| 641 |
-
# Calculate margins
|
| 642 |
-
margin = model_logprobs_model_data_sum - ref_logprobs_model_data_sum
|
| 643 |
-
self.stats["rewards/margins"].append(gather_mean(margin))
|
| 644 |
-
|
| 645 |
-
# Calculate accuracy
|
| 646 |
-
accuracy = (margin > 0).float()
|
| 647 |
-
self.stats["rewards/accuracies"].append(gather_mean(accuracy))
|
| 648 |
-
|
| 649 |
-
# Log EOS token statistics
|
| 650 |
-
model_eos = (model_data["input_ids"][:, context_length:] == self.processing_class.eos_token_id).any(dim=1)
|
| 651 |
-
mixture_eos = (mixture_data["input_ids"][:, context_length:] == self.processing_class.eos_token_id).any(dim=1)
|
| 652 |
-
self.stats["val/model_contain_eos_token"].append(gather_mean(model_eos.float()))
|
| 653 |
-
self.stats["val/ref_contain_eos_token"].append(gather_mean(mixture_eos.float()))
|
| 654 |
-
|
| 655 |
-
# Log beta and mixture coef
|
| 656 |
-
self.stats["beta"].append(self.beta)
|
| 657 |
-
self.stats["mixture_coef"].append(self.mixture_coef)
|
| 658 |
-
|
| 659 |
-
def training_step(
|
| 660 |
-
self, model: nn.Module, inputs: dict[str, Union[torch.Tensor, Any]], num_items_in_batch: Optional[int] = None
|
| 661 |
-
) -> torch.Tensor:
|
| 662 |
-
model.train()
|
| 663 |
-
|
| 664 |
-
# Apply chat template and tokenize the input
|
| 665 |
-
batch_size = len(next(iter(inputs.values())))
|
| 666 |
-
prompts = inputs["prompt"]
|
| 667 |
-
inputs = [{k: v[i] for k, v in inputs.items()} for i in range(batch_size)]
|
| 668 |
-
inputs = [maybe_apply_chat_template(x, self.processing_class) for x in inputs]
|
| 669 |
-
inputs = [self.tokenize_row(x, self.model.config.is_encoder_decoder, self.processing_class) for x in inputs]
|
| 670 |
-
inputs = self.data_collator(inputs)
|
| 671 |
-
|
| 672 |
-
# need the prompt_ only
|
| 673 |
-
inputs = self._prepare_inputs(inputs)
|
| 674 |
-
context_length = inputs["prompt_input_ids"].shape[1]
|
| 675 |
-
prompts = {
|
| 676 |
-
"input_ids": inputs["prompt_input_ids"],
|
| 677 |
-
"attention_mask": inputs["prompt_attention_mask"],
|
| 678 |
-
"raw": prompts,
|
| 679 |
-
}
|
| 680 |
-
del inputs
|
| 681 |
-
|
| 682 |
-
# Sample completions from both the model and the reference model
|
| 683 |
-
model_output, mixture_output = self._generate_completions(model, prompts)
|
| 684 |
-
|
| 685 |
-
# Process model completions
|
| 686 |
-
model_data, mixture_data = self._process_completions(model_output, mixture_output, prompts)
|
| 687 |
-
|
| 688 |
-
# Compute rewards
|
| 689 |
-
if self.reward_model is not None:
|
| 690 |
-
model_scores, mixture_scores = self._compute_rewards(model_data, mixture_data, context_length)
|
| 691 |
-
# probability of the model data vs the mixture data
|
| 692 |
-
probability = F.sigmoid(model_scores - mixture_scores)
|
| 693 |
-
else:
|
| 694 |
-
model_scores, mixture_scores = None, None
|
| 695 |
-
probability = self._compute_judge(model_data, mixture_data, context_length)
|
| 696 |
-
|
| 697 |
-
# Compute logprobs
|
| 698 |
-
model_logprobs_model_data, ref_logprobs_model_data = self._compute_logprobs(model, model_data, context_length)
|
| 699 |
-
|
| 700 |
-
# Compute loss
|
| 701 |
-
loss, score, kl_div = self._compute_losses(model_logprobs_model_data, ref_logprobs_model_data, probability)
|
| 702 |
-
|
| 703 |
-
# Log everything
|
| 704 |
-
self._log_statistics(
|
| 705 |
-
model_data,
|
| 706 |
-
mixture_data,
|
| 707 |
-
model_logprobs_model_data.detach(),
|
| 708 |
-
ref_logprobs_model_data,
|
| 709 |
-
probability,
|
| 710 |
-
score.detach(),
|
| 711 |
-
kl_div.detach(),
|
| 712 |
-
context_length,
|
| 713 |
-
model_scores,
|
| 714 |
-
mixture_scores,
|
| 715 |
-
)
|
| 716 |
-
|
| 717 |
-
if (
|
| 718 |
-
self.args.torch_empty_cache_steps is not None
|
| 719 |
-
and self.state.global_step % self.args.torch_empty_cache_steps == 0
|
| 720 |
-
):
|
| 721 |
-
empty_cache()
|
| 722 |
-
|
| 723 |
-
kwargs = {}
|
| 724 |
-
# For LOMO optimizers you need to explicitly use the learning rate
|
| 725 |
-
if self.args.optim in [OptimizerNames.LOMO, OptimizerNames.ADALOMO]:
|
| 726 |
-
kwargs["learning_rate"] = self._get_learning_rate()
|
| 727 |
-
|
| 728 |
-
if self.args.n_gpu > 1:
|
| 729 |
-
loss = loss.mean() # mean() to average on multi-gpu parallel training
|
| 730 |
-
|
| 731 |
-
if self.use_apex:
|
| 732 |
-
with amp.scale_loss(loss, self.optimizer) as scaled_loss:
|
| 733 |
-
scaled_loss.backward()
|
| 734 |
-
else:
|
| 735 |
-
self.accelerator.backward(loss, **kwargs)
|
| 736 |
-
|
| 737 |
-
return loss.detach() / self.args.gradient_accumulation_steps
|
| 738 |
-
|
| 739 |
-
def create_model_card(
|
| 740 |
-
self,
|
| 741 |
-
model_name: Optional[str] = None,
|
| 742 |
-
dataset_name: Optional[str] = None,
|
| 743 |
-
tags: Union[str, list[str], None] = None,
|
| 744 |
-
):
|
| 745 |
-
"""
|
| 746 |
-
Creates a draft of a model card using the information available to the `Trainer`.
|
| 747 |
-
|
| 748 |
-
Args:
|
| 749 |
-
model_name (`str` or `None`, *optional*, defaults to `None`):
|
| 750 |
-
Name of the model.
|
| 751 |
-
dataset_name (`str` or `None`, *optional*, defaults to `None`):
|
| 752 |
-
Name of the dataset used for training.
|
| 753 |
-
tags (`str`, `list[str]` or `None`, *optional*, defaults to `None`):
|
| 754 |
-
Tags to be associated with the model card.
|
| 755 |
-
"""
|
| 756 |
-
if not self.is_world_process_zero():
|
| 757 |
-
return
|
| 758 |
-
|
| 759 |
-
if hasattr(self.model.config, "_name_or_path") and not os.path.isdir(self.model.config._name_or_path):
|
| 760 |
-
base_model = self.model.config._name_or_path
|
| 761 |
-
else:
|
| 762 |
-
base_model = None
|
| 763 |
-
|
| 764 |
-
tags = tags or []
|
| 765 |
-
if isinstance(tags, str):
|
| 766 |
-
tags = [tags]
|
| 767 |
-
|
| 768 |
-
if hasattr(self.model.config, "unsloth_version"):
|
| 769 |
-
tags.append("unsloth")
|
| 770 |
-
|
| 771 |
-
citation = textwrap.dedent("""\
|
| 772 |
-
@inproceedings{munos2024nash,
|
| 773 |
-
title = {{Nash Learning from Human Feedback}},
|
| 774 |
-
author = {R{\'{e}}mi Munos and Michal Valko and Daniele Calandriello and Mohammad Gheshlaghi Azar and Mark Rowland and Zhaohan Daniel Guo and Yunhao Tang and Matthieu Geist and Thomas Mesnard and C{\\^{o}}me Fiegel and Andrea Michi and Marco Selvi and Sertan Girgin and Nikola Momchev and Olivier Bachem and Daniel J. Mankowitz and Doina Precup and Bilal Piot},
|
| 775 |
-
year = 2024,
|
| 776 |
-
booktitle = {Forty-first International Conference on Machine Learning, {ICML} 2024, Vienna, Austria, July 21-27, 2024},
|
| 777 |
-
publisher = {OpenReview.net},
|
| 778 |
-
url = {https://openreview.net/forum?id=Y5AmNYiyCQ}
|
| 779 |
-
}""")
|
| 780 |
-
|
| 781 |
-
model_card = generate_model_card(
|
| 782 |
-
base_model=base_model,
|
| 783 |
-
model_name=model_name,
|
| 784 |
-
hub_model_id=self.hub_model_id,
|
| 785 |
-
dataset_name=dataset_name,
|
| 786 |
-
tags=tags,
|
| 787 |
-
wandb_url=wandb.run.get_url() if is_wandb_available() and wandb.run is not None else None,
|
| 788 |
-
comet_url=get_comet_experiment_url(),
|
| 789 |
-
trainer_name="Nash-MD",
|
| 790 |
-
trainer_citation=citation,
|
| 791 |
-
paper_title="Nash Learning from Human Feedback",
|
| 792 |
-
paper_id="2312.00886",
|
| 793 |
-
)
|
| 794 |
-
|
| 795 |
-
model_card.save(os.path.join(self.args.output_dir, "README.md"))
|
| 796 |
-
class UnslothNashMDTrainer(_UnslothNashMDTrainer):
|
| 797 |
-
"""
|
| 798 |
-
|
| 799 |
-
Initialize NashMDTrainer as a subclass of [`OnlineDPOConfig`].
|
| 800 |
-
|
| 801 |
-
Args:
|
| 802 |
-
model (`transformers.PreTrainedModel`):
|
| 803 |
-
The model to train, preferably an `AutoModelForCausalLM`.
|
| 804 |
-
ref_model (`PreTrainedModelWrapper`):
|
| 805 |
-
Hugging Face transformer model with a casual language modelling head. Used for implicit reward computation and loss. If no
|
| 806 |
-
reference model is provided, the trainer will create a reference model with the same architecture as the model to be optimized.
|
| 807 |
-
reward_model (`transformers.PreTrainedModel`):
|
| 808 |
-
The reward model to score completions with, preferably an `AutoModelForSequenceClassification`.
|
| 809 |
-
judge (`BasePairwiseJudge`):
|
| 810 |
-
The judge to use for pairwise comparison of model completions.
|
| 811 |
-
args (`NashMDConfig`):
|
| 812 |
-
The NashMD config arguments to use for training.
|
| 813 |
-
data_collator (`transformers.DataCollator`):
|
| 814 |
-
The data collator to use for training. If None is specified, the default data collator (`DPODataCollatorWithPadding`) will be used
|
| 815 |
-
which will pad the sequences to the maximum length of the sequences in the batch, given a dataset of paired sequences.
|
| 816 |
-
train_dataset (`datasets.Dataset`):
|
| 817 |
-
The dataset to use for training.
|
| 818 |
-
eval_dataset (`datasets.Dataset`):
|
| 819 |
-
The dataset to use for evaluation.
|
| 820 |
-
processing_class (`PreTrainedTokenizerBase` or `BaseImageProcessor` or `FeatureExtractionMixin` or `ProcessorMixin`, *optional*):
|
| 821 |
-
Processing class used to process the data. If provided, will be used to automatically process the inputs
|
| 822 |
-
for the model, and it will be saved along the model to make it easier to rerun an interrupted training or
|
| 823 |
-
reuse the fine-tuned model.
|
| 824 |
-
peft_config (`dict`):
|
| 825 |
-
The peft config to use for training.
|
| 826 |
-
compute_metrics (`Callable[[EvalPrediction], dict]`, *optional*):
|
| 827 |
-
The function to use to compute the metrics. Must take a `EvalPrediction` and return
|
| 828 |
-
a dictionary string to metric values.
|
| 829 |
-
callbacks (`list[transformers.TrainerCallback]`):
|
| 830 |
-
The callbacks to use for training.
|
| 831 |
-
optimizers (`tuple[torch.optim.Optimizer, torch.optim.lr_scheduler.LambdaLR]`):
|
| 832 |
-
The optimizer and scheduler to use for training.
|
| 833 |
-
preprocess_logits_for_metrics (`Callable[[torch.Tensor, torch.Tensor], torch.Tensor]`):
|
| 834 |
-
The function to use to preprocess the logits before computing the metrics.
|
| 835 |
-
|
| 836 |
-
"""
|
| 837 |
-
def __init__(
|
| 838 |
-
self,
|
| 839 |
-
model = None,
|
| 840 |
-
ref_model = None,
|
| 841 |
-
reward_model = None,
|
| 842 |
-
judge = None,
|
| 843 |
-
args = None,
|
| 844 |
-
data_collator = None,
|
| 845 |
-
train_dataset = None,
|
| 846 |
-
eval_dataset = None,
|
| 847 |
-
processing_class = None,
|
| 848 |
-
peft_config = None,
|
| 849 |
-
compute_metrics = None,
|
| 850 |
-
callbacks = None,
|
| 851 |
-
preprocess_logits_for_metrics = None,
|
| 852 |
-
**kwargs
|
| 853 |
-
):
|
| 854 |
-
if args is None: args = UnslothNashMDConfig()
|
| 855 |
-
use_bf16 = getattr(args, 'bf16', False)
|
| 856 |
-
if type(use_bf16) is not bool: use_bf16 = False
|
| 857 |
-
use_fp16 = getattr(args, 'fp16', False)
|
| 858 |
-
if type(use_fp16) is not bool: use_fp16 = False
|
| 859 |
-
force_float32 = False
|
| 860 |
-
if os.environ.get('UNSLOTH_FORCE_FLOAT32', '0') == '1':
|
| 861 |
-
print('Unsloth: Switching to float32 training since model cannot work with float16')
|
| 862 |
-
force_float32 = True
|
| 863 |
-
mixed_precision_dtype = os.environ.get('UNSLOTH_MIXED_PRECISION', 'float32')
|
| 864 |
-
dtype = getattr(model.config, 'torch_dtype', None)
|
| 865 |
-
if dtype is None: dtype = model.get_input_embeddings().dtype
|
| 866 |
-
from unsloth_zoo.utils import _get_dtype
|
| 867 |
-
dtype = _get_dtype(dtype)
|
| 868 |
-
float16 = dtype == torch.float16
|
| 869 |
-
if not force_float32 and (float16 and use_bf16): raise TypeError('Unsloth: Model is in float16 precision but you want to use bfloat16 precision. Set fp16 to `True` and bf16 to `False`')
|
| 870 |
-
if not force_float32 and (not float16 and use_fp16): raise TypeError('Unsloth: Model is in bfloat16 precision but you want to use float16 precision. Set fp16 to `False` and bf16 to `True`')
|
| 871 |
-
if force_float32:
|
| 872 |
-
args.fp16 = False
|
| 873 |
-
args.bf16 = False
|
| 874 |
-
os.environ['ACCELERATE_MIXED_PRECISION'] = 'no'
|
| 875 |
-
elif (not use_bf16 and not use_fp16) and mixed_precision_dtype == 'float32':
|
| 876 |
-
args.fp16 = float16
|
| 877 |
-
args.bf16 = not float16
|
| 878 |
-
os.environ['ACCELERATE_MIXED_PRECISION'] = 'fp16' if float16 else 'bf16'
|
| 879 |
-
if getattr(args, 'eval_dataset', None) is not None and getattr(args, 'eval_strategy', 'no') == 'no':
|
| 880 |
-
args.eval_strategy = 'steps'
|
| 881 |
-
if getattr(args, 'eval_steps', None) is None: args.eval_steps = 0.1
|
| 882 |
-
ga_steps = getattr(args, 'gradient_accumulation_steps', None)
|
| 883 |
-
if ga_steps is not None and ga_steps > 1:
|
| 884 |
-
from transformers import __version__ as transformers_version
|
| 885 |
-
if Version(transformers_version) <= Version('4.45.2'):
|
| 886 |
-
print('**** Unsloth: Please use our fixed gradient_accumulation_steps by updating transformers, TRL and Unsloth!\n'
|
| 887 |
-
'`pip install --upgrade --no-cache-dir --force-reinstall --no-deps unsloth transformers trl unsloth_zoo`')
|
| 888 |
-
if getattr(args, 'eval_strategy', 'no') != 'no':
|
| 889 |
-
eval_bsz = getattr(args, 'per_device_eval_batch_size', 8)
|
| 890 |
-
if eval_bsz == 8 and args.per_device_train_batch_size < eval_bsz: args.per_device_eval_batch_size = args.per_device_train_batch_size
|
| 891 |
-
if getattr(args, 'eval_accumulation_steps', None) is None and ga_steps is not None: args.eval_accumulation_steps = ga_steps
|
| 892 |
-
fp16_full_eval = getattr(args, 'fp16_full_eval', False)
|
| 893 |
-
if type(fp16_full_eval) is not bool: fp16_full_eval = False
|
| 894 |
-
bf16_full_eval = getattr(args, 'bf16_full_eval', False)
|
| 895 |
-
if type(bf16_full_eval) is not bool: bf16_full_eval = False
|
| 896 |
-
if args.fp16 and bf16_full_eval: args.bf16_full_eval = False; args.fp16_full_eval = True
|
| 897 |
-
if args.bf16 and fp16_full_eval: args.bf16_full_eval = True; args.fp16_full_eval = False
|
| 898 |
-
if force_float32:
|
| 899 |
-
args.bf16_full_eval = False
|
| 900 |
-
args.fp16_full_eval = False
|
| 901 |
-
elif os.environ.get('UNSLOTH_MIXED_PRECISION', 'float32') == 'bfloat16':
|
| 902 |
-
args.bf16_full_eval = True
|
| 903 |
-
args.fp16_full_eval = False
|
| 904 |
-
elif not bf16_full_eval and not fp16_full_eval:
|
| 905 |
-
args.bf16_full_eval = args.bf16
|
| 906 |
-
args.fp16_full_eval = args.fp16
|
| 907 |
-
_output_logits = False
|
| 908 |
-
if locals().get('compute_metrics', None) is not None: _output_logits = True
|
| 909 |
-
if locals().get('preprocess_logits_for_metrics', None) is not None: _output_logits = True
|
| 910 |
-
if _output_logits:
|
| 911 |
-
os.environ['UNSLOTH_RETURN_LOGITS'] = '1'
|
| 912 |
-
if 'max_seq_length' not in locals() and not hasattr(args, 'max_seq_length'):
|
| 913 |
-
pass
|
| 914 |
-
else:
|
| 915 |
-
model_max_seq_length = getattr(model, 'max_seq_length', None)
|
| 916 |
-
args_max_seq_length = getattr(args, 'max_seq_length', None)
|
| 917 |
-
if args_max_seq_length is None and model_max_seq_length is not None:
|
| 918 |
-
max_seq_length = model.max_seq_length
|
| 919 |
-
if hasattr(args, 'max_seq_length'): args.max_seq_length = max_seq_length
|
| 920 |
-
if model is not None and hasattr(model, 'for_training'):
|
| 921 |
-
model.for_training()
|
| 922 |
-
if 'tokenizer' in locals() and hasattr(tokenizer, 'padding_side'): tokenizer.padding_side = 'right'
|
| 923 |
-
if 'processing_class' in locals():
|
| 924 |
-
if hasattr(processing_class, 'padding_side'): processing_class.padding_side = 'right'
|
| 925 |
-
if hasattr(processing_class, 'tokenizer') and hasattr(processing_class.tokenizer, 'padding_side'): processing_class.tokenizer.padding_side = 'right'
|
| 926 |
-
__tokenizer = processing_class if 'processing_class' in locals() else tokenizer
|
| 927 |
-
from unsloth_zoo.vision_utils import UnslothVisionDataCollator
|
| 928 |
-
if not isinstance(data_collator, UnslothVisionDataCollator):
|
| 929 |
-
if isinstance(data_collator, DataCollatorForSeq2Seq) and 'labels' not in train_dataset.column_names:
|
| 930 |
-
data_collator = TransformersDataCollatorForLanguageModeling(__tokenizer, mlm = False, mlm_probability = 0.0)
|
| 931 |
-
elif isinstance(data_collator, TransformersDataCollatorForLanguageModeling) and 'labels' in train_dataset.column_names:
|
| 932 |
-
data_collator = DataCollatorForSeq2Seq(__tokenizer)
|
| 933 |
-
else:
|
| 934 |
-
if hasattr(args, 'remove_unused_columns'): args.remove_unused_columns = False
|
| 935 |
-
if hasattr(args, 'dataset_text_field'): args.dataset_text_field = ''
|
| 936 |
-
if hasattr(args, 'dataset_kwargs'): args.dataset_kwargs = {'skip_prepare_dataset': True}
|
| 937 |
-
if not isinstance(data_collator, UnslothVisionDataCollator):
|
| 938 |
-
if not hasattr(__tokenizer, 'pad') and hasattr(__tokenizer, 'tokenizer'):
|
| 939 |
-
if isinstance(data_collator, DataCollatorForSeq2Seq):
|
| 940 |
-
data_collator = DataCollatorForSeq2Seq(__tokenizer.tokenizer)
|
| 941 |
-
else:
|
| 942 |
-
data_collator = TransformersDataCollatorForLanguageModeling(__tokenizer.tokenizer, mlm = False, mlm_probability = 0.0)
|
| 943 |
-
other_metrics = []
|
| 944 |
-
|
| 945 |
-
from unsloth_zoo.logging_utils import PatchRLStatistics
|
| 946 |
-
PatchRLStatistics('nash_md_trainer', other_metrics)
|
| 947 |
-
|
| 948 |
-
super().__init__(
|
| 949 |
-
model = model,
|
| 950 |
-
ref_model = ref_model,
|
| 951 |
-
reward_model = reward_model,
|
| 952 |
-
judge = judge,
|
| 953 |
-
args = args,
|
| 954 |
-
data_collator = data_collator,
|
| 955 |
-
train_dataset = train_dataset,
|
| 956 |
-
eval_dataset = eval_dataset,
|
| 957 |
-
processing_class = processing_class,
|
| 958 |
-
peft_config = peft_config,
|
| 959 |
-
compute_metrics = compute_metrics,
|
| 960 |
-
callbacks = callbacks,
|
| 961 |
-
preprocess_logits_for_metrics = preprocess_logits_for_metrics,**kwargs)
|
| 962 |
-
if hasattr(self, 'neftune_hook_handle'):
|
| 963 |
-
self.neftune_hook_handle.remove()
|
| 964 |
-
if hasattr(self, 'neftune_hook_handle'): del self.neftune_hook_handle
|
| 965 |
-
if getattr(args, 'neftune_noise_alpha', None) is not None:
|
| 966 |
-
model.get_input_embeddings().neftune_noise_alpha = self.neftune_noise_alpha
|
| 967 |
-
pass
|
| 968 |
-
|
| 969 |
-
pass
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
test_run_uploads/UnslothORPOTrainer.py
DELETED
|
@@ -1,1552 +0,0 @@
|
|
| 1 |
-
"""
|
| 2 |
-
2025.7.11
|
| 3 |
-
2025.7.11
|
| 4 |
-
4.54.1
|
| 5 |
-
0.16.1
|
| 6 |
-
__UNSLOTH_VERSIONING__
|
| 7 |
-
"""
|
| 8 |
-
from torch import Tensor
|
| 9 |
-
import torch
|
| 10 |
-
import torch.nn as nn
|
| 11 |
-
from torch.nn import functional as F
|
| 12 |
-
from typing import Any, List, Optional, Tuple, Union, Dict, Set, Callable
|
| 13 |
-
from trl.trainer.orpo_trainer import (Any, AutoModelForCausalLM, BaseImageProcessor, Callable, DPODataCollatorWithPadding, DataCollator, DataLoader, Dataset, EvalLoopOutput, F, FeatureExtractionMixin, Literal, ORPOConfig, ORPOTrainer, Optional, PartialState, PeftModel, PreTrainedModel, PreTrainedModelWrapper, PreTrainedTokenizerBase, ProcessorMixin, Trainer, TrainerCallback, Union, add_bos_token_if_needed, add_eos_token_if_needed, amp, deepcopy, defaultdict, disable_dropout_in_model, generate_model_card, get_comet_experiment_url, inspect, is_comet_available, is_peft_available, is_torch_fx_proxy, is_torch_xla_available, is_wandb_available, log_table_to_comet_experiment, maybe_apply_chat_template, maybe_extract_prompt, nn, np, nullcontext, os, pad_to_length, pd, peft_module_casting_to_bf16, prepare_model_for_kbit_training, random, selective_log_softmax, textwrap, torch, transformers, version, wandb, warnings)
|
| 14 |
-
|
| 15 |
-
|
| 16 |
-
import os
|
| 17 |
-
from typing import *
|
| 18 |
-
from dataclasses import dataclass, field
|
| 19 |
-
from packaging.version import Version
|
| 20 |
-
import torch
|
| 21 |
-
import numpy as np
|
| 22 |
-
from contextlib import nullcontext
|
| 23 |
-
from torch.nn import functional as F
|
| 24 |
-
from transformers import DataCollatorForSeq2Seq, DataCollatorForLanguageModeling as TransformersDataCollatorForLanguageModeling
|
| 25 |
-
|
| 26 |
-
torch_compile_options = {
|
| 27 |
-
"epilogue_fusion" : True,
|
| 28 |
-
"max_autotune" : False,
|
| 29 |
-
"shape_padding" : True,
|
| 30 |
-
"trace.enabled" : False,
|
| 31 |
-
"triton.cudagraphs" : False,
|
| 32 |
-
}
|
| 33 |
-
|
| 34 |
-
@torch.compile(dynamic = True, fullgraph = True, options = torch_compile_options,)
|
| 35 |
-
def chunked_selective_log_softmax(logits, index):
|
| 36 |
-
# Split into 4 chunks only
|
| 37 |
-
chunked_logits = torch.chunk(logits.reshape(-1, logits.shape[-1]), chunks = 4, dim = 0)
|
| 38 |
-
chunked_index = torch.chunk(index.reshape(-1), chunks = 4, dim = 0)
|
| 39 |
-
all_per_token_logps = []
|
| 40 |
-
# Below loop does the same as selective_log_softmax(chunk_logits, chunk_index)
|
| 41 |
-
for chunk_logits, chunk_index in zip(chunked_logits, chunked_index):
|
| 42 |
-
chunk_logits = chunk_logits.to(torch.float32)
|
| 43 |
-
selected_logits = torch.gather(chunk_logits, dim = -1, index = chunk_index.unsqueeze(-1)).squeeze(-1)
|
| 44 |
-
logsumexp_values = torch.logsumexp(chunk_logits, dim = -1)
|
| 45 |
-
per_token_logps = selected_logits - logsumexp_values
|
| 46 |
-
all_per_token_logps.append(per_token_logps)
|
| 47 |
-
pass
|
| 48 |
-
all_per_token_logps = torch.concat(all_per_token_logps)
|
| 49 |
-
all_per_token_logps = all_per_token_logps.reshape((logits.shape[0], logits.shape[1]))
|
| 50 |
-
return all_per_token_logps
|
| 51 |
-
@dataclass
|
| 52 |
-
class UnslothORPOConfig(ORPOConfig):
|
| 53 |
-
"""
|
| 54 |
-
|
| 55 |
-
Configuration class for the [`ORPOTrainer`].
|
| 56 |
-
|
| 57 |
-
Using [`~transformers.HfArgumentParser`] we can turn this class into
|
| 58 |
-
[argparse](https://docs.python.org/3/library/argparse#module-argparse) arguments that can be specified on the
|
| 59 |
-
command line.
|
| 60 |
-
|
| 61 |
-
Parameters:
|
| 62 |
-
learning_rate (`float`, *optional*, defaults to `1e-6`):
|
| 63 |
-
Initial learning rate for [`AdamW`] optimizer. The default value replaces that of
|
| 64 |
-
[`~transformers.TrainingArguments`].
|
| 65 |
-
max_length (`int` or `None`, *optional*, defaults to `1024`):
|
| 66 |
-
Maximum length of the sequences (prompt + completion) in the batch. This argument is required if you want
|
| 67 |
-
to use the default data collator.
|
| 68 |
-
max_prompt_length (`int` or `None`, *optional*, defaults to `512`):
|
| 69 |
-
Maximum length of the prompt. This argument is required if you want to use the default data collator.
|
| 70 |
-
max_completion_length (`int` or `None`, *optional*, defaults to `None`):
|
| 71 |
-
Maximum length of the completion. This argument is required if you want to use the default data collator
|
| 72 |
-
and your model is an encoder-decoder.
|
| 73 |
-
beta (`float`, *optional*, defaults to `0.1`):
|
| 74 |
-
Parameter controlling the relative ratio loss weight in the ORPO loss. In the [paper](https://huggingface.co/papers/2403.07691),
|
| 75 |
-
it is denoted by λ. In the [code](https://github.com/xfactlab/orpo), it is denoted by `alpha`.
|
| 76 |
-
disable_dropout (`bool`, *optional*, defaults to `True`):
|
| 77 |
-
Whether to disable dropout in the model.
|
| 78 |
-
label_pad_token_id (`int`, *optional*, defaults to `-100`):
|
| 79 |
-
Label pad token id. This argument is required if you want to use the default data collator.
|
| 80 |
-
padding_value (`int` or `None`, *optional*, defaults to `None`):
|
| 81 |
-
Padding value to use. If `None`, the padding value of the tokenizer is used.
|
| 82 |
-
truncation_mode (`str`, *optional*, defaults to `"keep_end"`):
|
| 83 |
-
Truncation mode to use when the prompt is too long. Possible values are `"keep_end"` or `"keep_start"`.
|
| 84 |
-
This argument is required if you want to use the default data collator.
|
| 85 |
-
generate_during_eval (`bool`, *optional*, defaults to `False`):
|
| 86 |
-
If `True`, generates and logs completions from the model to W&B or Comet during evaluation.
|
| 87 |
-
is_encoder_decoder (`bool` or `None`, *optional*, defaults to `None`):
|
| 88 |
-
When using the `model_init` argument (callable) to instantiate the model instead of the `model` argument,
|
| 89 |
-
you need to specify if the model returned by the callable is an encoder-decoder model.
|
| 90 |
-
model_init_kwargs (`dict[str, Any]` or `None`, *optional*, defaults to `None`):
|
| 91 |
-
Keyword arguments to pass to `AutoModelForCausalLM.from_pretrained` when instantiating the model from a
|
| 92 |
-
string.
|
| 93 |
-
dataset_num_proc (`int` or `None`, *optional*, defaults to `None`):
|
| 94 |
-
Number of processes to use for processing the dataset.
|
| 95 |
-
|
| 96 |
-
"""
|
| 97 |
-
vllm_sampling_params: Optional[Any] = field(
|
| 98 |
-
default = None,
|
| 99 |
-
metadata = {'help': 'vLLM SamplingParams'},
|
| 100 |
-
)
|
| 101 |
-
unsloth_num_chunks : Optional[int] = field(
|
| 102 |
-
default = -1,
|
| 103 |
-
metadata = {'help': 'Chunk size to reduce memory usage. -1 is most efficient.'},
|
| 104 |
-
)
|
| 105 |
-
def __init__(
|
| 106 |
-
self,
|
| 107 |
-
output_dir = None,
|
| 108 |
-
overwrite_output_dir = None,
|
| 109 |
-
do_train = False,
|
| 110 |
-
do_eval = False,
|
| 111 |
-
do_predict = False,
|
| 112 |
-
eval_strategy = 'no',
|
| 113 |
-
prediction_loss_only = False,
|
| 114 |
-
per_device_train_batch_size = 4,
|
| 115 |
-
per_device_eval_batch_size = 4,
|
| 116 |
-
per_gpu_train_batch_size = None,
|
| 117 |
-
per_gpu_eval_batch_size = None,
|
| 118 |
-
gradient_accumulation_steps = 2,
|
| 119 |
-
eval_accumulation_steps = 2,
|
| 120 |
-
eval_delay = 0,
|
| 121 |
-
torch_empty_cache_steps = 250,
|
| 122 |
-
learning_rate = 5e-05,
|
| 123 |
-
weight_decay = 0.01,
|
| 124 |
-
adam_beta1 = 0.9,
|
| 125 |
-
adam_beta2 = 0.999,
|
| 126 |
-
adam_epsilon = 1e-08,
|
| 127 |
-
max_grad_norm = 1.0,
|
| 128 |
-
num_train_epochs = 3.0,
|
| 129 |
-
max_steps = -1,
|
| 130 |
-
lr_scheduler_type = 'linear',
|
| 131 |
-
warmup_ratio = 0.1,
|
| 132 |
-
warmup_steps = 0,
|
| 133 |
-
log_level = 'passive',
|
| 134 |
-
log_level_replica = 'warning',
|
| 135 |
-
log_on_each_node = True,
|
| 136 |
-
logging_dir = None,
|
| 137 |
-
logging_strategy = 'steps',
|
| 138 |
-
logging_first_step = False,
|
| 139 |
-
logging_steps = 1,
|
| 140 |
-
logging_nan_inf_filter = False,
|
| 141 |
-
save_strategy = 'steps',
|
| 142 |
-
save_steps = 500,
|
| 143 |
-
save_total_limit = None,
|
| 144 |
-
save_safetensors = True,
|
| 145 |
-
save_on_each_node = False,
|
| 146 |
-
save_only_model = False,
|
| 147 |
-
restore_callback_states_from_checkpoint = False,
|
| 148 |
-
no_cuda = False,
|
| 149 |
-
use_cpu = False,
|
| 150 |
-
use_mps_device = False,
|
| 151 |
-
seed = 3407,
|
| 152 |
-
data_seed = 3407,
|
| 153 |
-
jit_mode_eval = False,
|
| 154 |
-
use_ipex = False,
|
| 155 |
-
bf16 = False,
|
| 156 |
-
fp16 = False,
|
| 157 |
-
fp16_opt_level = 'O1',
|
| 158 |
-
half_precision_backend = 'auto',
|
| 159 |
-
bf16_full_eval = False,
|
| 160 |
-
fp16_full_eval = False,
|
| 161 |
-
tf32 = None,
|
| 162 |
-
local_rank = -1,
|
| 163 |
-
ddp_backend = None,
|
| 164 |
-
tpu_num_cores = None,
|
| 165 |
-
tpu_metrics_debug = False,
|
| 166 |
-
debug = '',
|
| 167 |
-
dataloader_drop_last = False,
|
| 168 |
-
eval_steps = None,
|
| 169 |
-
dataloader_num_workers = 0,
|
| 170 |
-
dataloader_prefetch_factor = None,
|
| 171 |
-
past_index = -1,
|
| 172 |
-
run_name = None,
|
| 173 |
-
disable_tqdm = None,
|
| 174 |
-
remove_unused_columns = True,
|
| 175 |
-
label_names = None,
|
| 176 |
-
load_best_model_at_end = False,
|
| 177 |
-
metric_for_best_model = None,
|
| 178 |
-
greater_is_better = None,
|
| 179 |
-
ignore_data_skip = False,
|
| 180 |
-
fsdp = '',
|
| 181 |
-
fsdp_min_num_params = 0,
|
| 182 |
-
fsdp_config = None,
|
| 183 |
-
fsdp_transformer_layer_cls_to_wrap = None,
|
| 184 |
-
accelerator_config = None,
|
| 185 |
-
deepspeed = None,
|
| 186 |
-
label_smoothing_factor = 0.0,
|
| 187 |
-
optim = 'adamw_8bit',
|
| 188 |
-
optim_args = None,
|
| 189 |
-
adafactor = False,
|
| 190 |
-
group_by_length = False,
|
| 191 |
-
length_column_name = 'length',
|
| 192 |
-
report_to = None,
|
| 193 |
-
ddp_find_unused_parameters = None,
|
| 194 |
-
ddp_bucket_cap_mb = None,
|
| 195 |
-
ddp_broadcast_buffers = None,
|
| 196 |
-
dataloader_pin_memory = True,
|
| 197 |
-
dataloader_persistent_workers = False,
|
| 198 |
-
skip_memory_metrics = True,
|
| 199 |
-
use_legacy_prediction_loop = False,
|
| 200 |
-
push_to_hub = False,
|
| 201 |
-
resume_from_checkpoint = None,
|
| 202 |
-
hub_model_id = None,
|
| 203 |
-
hub_strategy = 'every_save',
|
| 204 |
-
hub_token = None,
|
| 205 |
-
hub_private_repo = None,
|
| 206 |
-
hub_always_push = False,
|
| 207 |
-
hub_revision = None,
|
| 208 |
-
gradient_checkpointing = False,
|
| 209 |
-
gradient_checkpointing_kwargs = None,
|
| 210 |
-
include_inputs_for_metrics = False,
|
| 211 |
-
eval_do_concat_batches = True,
|
| 212 |
-
fp16_backend = 'auto',
|
| 213 |
-
push_to_hub_model_id = None,
|
| 214 |
-
push_to_hub_organization = None,
|
| 215 |
-
push_to_hub_token = None,
|
| 216 |
-
mp_parameters = '',
|
| 217 |
-
auto_find_batch_size = True,
|
| 218 |
-
full_determinism = False,
|
| 219 |
-
torchdynamo = None,
|
| 220 |
-
ray_scope = 'last',
|
| 221 |
-
ddp_timeout = 1800,
|
| 222 |
-
torch_compile = False,
|
| 223 |
-
torch_compile_backend = None,
|
| 224 |
-
torch_compile_mode = None,
|
| 225 |
-
include_tokens_per_second = False,
|
| 226 |
-
include_num_input_tokens_seen = False,
|
| 227 |
-
neftune_noise_alpha = None,
|
| 228 |
-
optim_target_modules = None,
|
| 229 |
-
batch_eval_metrics = False,
|
| 230 |
-
eval_on_start = False,
|
| 231 |
-
use_liger_kernel = False,
|
| 232 |
-
liger_kernel_config = None,
|
| 233 |
-
eval_use_gather_object = False,
|
| 234 |
-
average_tokens_across_devices = True,
|
| 235 |
-
max_length = 1024,
|
| 236 |
-
max_prompt_length = 512,
|
| 237 |
-
max_completion_length = None,
|
| 238 |
-
beta = 0.1,
|
| 239 |
-
disable_dropout = True,
|
| 240 |
-
label_pad_token_id = -100,
|
| 241 |
-
padding_value = None,
|
| 242 |
-
truncation_mode = 'keep_end',
|
| 243 |
-
generate_during_eval = False,
|
| 244 |
-
is_encoder_decoder = None,
|
| 245 |
-
model_init_kwargs = None,
|
| 246 |
-
dataset_num_proc = None,
|
| 247 |
-
vllm_sampling_params = None,
|
| 248 |
-
unsloth_num_chunks = -1,
|
| 249 |
-
**kwargs,
|
| 250 |
-
):
|
| 251 |
-
if learning_rate < 1e-7: raise FloatingPointError(f'Unsloth: Your learning rate of `{learning_rate}` is too small and less than 1e-7! Consider increasing it, otherwise gradient updates will be close to 0!')
|
| 252 |
-
if learning_rate > 1: raise OverflowError(f'Unsloth: Your learning rate of `{learning_rate}` is way too larger > 1! Consider decreasing it to 1e-1, otherwise gradient updates will explode!')
|
| 253 |
-
if output_dir is None and save_strategy == 'steps' and save_steps == 500:
|
| 254 |
-
output_dir = 'unsloth_training_checkpoints'
|
| 255 |
-
save_strategy = 'no'
|
| 256 |
-
if dataset_num_proc is None:
|
| 257 |
-
from multiprocessing import cpu_count
|
| 258 |
-
dataset_num_proc = min(cpu_count()*2, 2)
|
| 259 |
-
|
| 260 |
-
super().__init__(
|
| 261 |
-
output_dir = output_dir,
|
| 262 |
-
overwrite_output_dir = overwrite_output_dir,
|
| 263 |
-
do_train = do_train,
|
| 264 |
-
do_eval = do_eval,
|
| 265 |
-
do_predict = do_predict,
|
| 266 |
-
eval_strategy = eval_strategy,
|
| 267 |
-
prediction_loss_only = prediction_loss_only,
|
| 268 |
-
per_device_train_batch_size = per_device_train_batch_size,
|
| 269 |
-
per_device_eval_batch_size = per_device_eval_batch_size,
|
| 270 |
-
per_gpu_train_batch_size = per_gpu_train_batch_size,
|
| 271 |
-
per_gpu_eval_batch_size = per_gpu_eval_batch_size,
|
| 272 |
-
gradient_accumulation_steps = gradient_accumulation_steps,
|
| 273 |
-
eval_accumulation_steps = eval_accumulation_steps,
|
| 274 |
-
eval_delay = eval_delay,
|
| 275 |
-
torch_empty_cache_steps = torch_empty_cache_steps,
|
| 276 |
-
learning_rate = learning_rate,
|
| 277 |
-
weight_decay = weight_decay,
|
| 278 |
-
adam_beta1 = adam_beta1,
|
| 279 |
-
adam_beta2 = adam_beta2,
|
| 280 |
-
adam_epsilon = adam_epsilon,
|
| 281 |
-
max_grad_norm = max_grad_norm,
|
| 282 |
-
num_train_epochs = num_train_epochs,
|
| 283 |
-
max_steps = max_steps,
|
| 284 |
-
lr_scheduler_type = lr_scheduler_type,
|
| 285 |
-
warmup_ratio = warmup_ratio,
|
| 286 |
-
warmup_steps = warmup_steps,
|
| 287 |
-
log_level = log_level,
|
| 288 |
-
log_level_replica = log_level_replica,
|
| 289 |
-
log_on_each_node = log_on_each_node,
|
| 290 |
-
logging_dir = logging_dir,
|
| 291 |
-
logging_strategy = logging_strategy,
|
| 292 |
-
logging_first_step = logging_first_step,
|
| 293 |
-
logging_steps = logging_steps,
|
| 294 |
-
logging_nan_inf_filter = logging_nan_inf_filter,
|
| 295 |
-
save_strategy = save_strategy,
|
| 296 |
-
save_steps = save_steps,
|
| 297 |
-
save_total_limit = save_total_limit,
|
| 298 |
-
save_safetensors = save_safetensors,
|
| 299 |
-
save_on_each_node = save_on_each_node,
|
| 300 |
-
save_only_model = save_only_model,
|
| 301 |
-
restore_callback_states_from_checkpoint = restore_callback_states_from_checkpoint,
|
| 302 |
-
no_cuda = no_cuda,
|
| 303 |
-
use_cpu = use_cpu,
|
| 304 |
-
use_mps_device = use_mps_device,
|
| 305 |
-
seed = seed,
|
| 306 |
-
data_seed = data_seed,
|
| 307 |
-
jit_mode_eval = jit_mode_eval,
|
| 308 |
-
use_ipex = use_ipex,
|
| 309 |
-
bf16 = bf16,
|
| 310 |
-
fp16 = fp16,
|
| 311 |
-
fp16_opt_level = fp16_opt_level,
|
| 312 |
-
half_precision_backend = half_precision_backend,
|
| 313 |
-
bf16_full_eval = bf16_full_eval,
|
| 314 |
-
fp16_full_eval = fp16_full_eval,
|
| 315 |
-
tf32 = tf32,
|
| 316 |
-
local_rank = local_rank,
|
| 317 |
-
ddp_backend = ddp_backend,
|
| 318 |
-
tpu_num_cores = tpu_num_cores,
|
| 319 |
-
tpu_metrics_debug = tpu_metrics_debug,
|
| 320 |
-
debug = debug,
|
| 321 |
-
dataloader_drop_last = dataloader_drop_last,
|
| 322 |
-
eval_steps = eval_steps,
|
| 323 |
-
dataloader_num_workers = dataloader_num_workers,
|
| 324 |
-
dataloader_prefetch_factor = dataloader_prefetch_factor,
|
| 325 |
-
past_index = past_index,
|
| 326 |
-
run_name = run_name,
|
| 327 |
-
disable_tqdm = disable_tqdm,
|
| 328 |
-
remove_unused_columns = remove_unused_columns,
|
| 329 |
-
label_names = label_names,
|
| 330 |
-
load_best_model_at_end = load_best_model_at_end,
|
| 331 |
-
metric_for_best_model = metric_for_best_model,
|
| 332 |
-
greater_is_better = greater_is_better,
|
| 333 |
-
ignore_data_skip = ignore_data_skip,
|
| 334 |
-
fsdp = fsdp,
|
| 335 |
-
fsdp_min_num_params = fsdp_min_num_params,
|
| 336 |
-
fsdp_config = fsdp_config,
|
| 337 |
-
fsdp_transformer_layer_cls_to_wrap = fsdp_transformer_layer_cls_to_wrap,
|
| 338 |
-
accelerator_config = accelerator_config,
|
| 339 |
-
deepspeed = deepspeed,
|
| 340 |
-
label_smoothing_factor = label_smoothing_factor,
|
| 341 |
-
optim = optim,
|
| 342 |
-
optim_args = optim_args,
|
| 343 |
-
adafactor = adafactor,
|
| 344 |
-
group_by_length = group_by_length,
|
| 345 |
-
length_column_name = length_column_name,
|
| 346 |
-
report_to = report_to,
|
| 347 |
-
ddp_find_unused_parameters = ddp_find_unused_parameters,
|
| 348 |
-
ddp_bucket_cap_mb = ddp_bucket_cap_mb,
|
| 349 |
-
ddp_broadcast_buffers = ddp_broadcast_buffers,
|
| 350 |
-
dataloader_pin_memory = dataloader_pin_memory,
|
| 351 |
-
dataloader_persistent_workers = dataloader_persistent_workers,
|
| 352 |
-
skip_memory_metrics = skip_memory_metrics,
|
| 353 |
-
use_legacy_prediction_loop = use_legacy_prediction_loop,
|
| 354 |
-
push_to_hub = push_to_hub,
|
| 355 |
-
resume_from_checkpoint = resume_from_checkpoint,
|
| 356 |
-
hub_model_id = hub_model_id,
|
| 357 |
-
hub_strategy = hub_strategy,
|
| 358 |
-
hub_token = hub_token,
|
| 359 |
-
hub_private_repo = hub_private_repo,
|
| 360 |
-
hub_always_push = hub_always_push,
|
| 361 |
-
hub_revision = hub_revision,
|
| 362 |
-
gradient_checkpointing = gradient_checkpointing,
|
| 363 |
-
gradient_checkpointing_kwargs = gradient_checkpointing_kwargs,
|
| 364 |
-
include_inputs_for_metrics = include_inputs_for_metrics,
|
| 365 |
-
eval_do_concat_batches = eval_do_concat_batches,
|
| 366 |
-
fp16_backend = fp16_backend,
|
| 367 |
-
push_to_hub_model_id = push_to_hub_model_id,
|
| 368 |
-
push_to_hub_organization = push_to_hub_organization,
|
| 369 |
-
push_to_hub_token = push_to_hub_token,
|
| 370 |
-
mp_parameters = mp_parameters,
|
| 371 |
-
auto_find_batch_size = auto_find_batch_size,
|
| 372 |
-
full_determinism = full_determinism,
|
| 373 |
-
torchdynamo = torchdynamo,
|
| 374 |
-
ray_scope = ray_scope,
|
| 375 |
-
ddp_timeout = ddp_timeout,
|
| 376 |
-
torch_compile = torch_compile,
|
| 377 |
-
torch_compile_backend = torch_compile_backend,
|
| 378 |
-
torch_compile_mode = torch_compile_mode,
|
| 379 |
-
include_tokens_per_second = include_tokens_per_second,
|
| 380 |
-
include_num_input_tokens_seen = include_num_input_tokens_seen,
|
| 381 |
-
neftune_noise_alpha = neftune_noise_alpha,
|
| 382 |
-
optim_target_modules = optim_target_modules,
|
| 383 |
-
batch_eval_metrics = batch_eval_metrics,
|
| 384 |
-
eval_on_start = eval_on_start,
|
| 385 |
-
use_liger_kernel = use_liger_kernel,
|
| 386 |
-
liger_kernel_config = liger_kernel_config,
|
| 387 |
-
eval_use_gather_object = eval_use_gather_object,
|
| 388 |
-
average_tokens_across_devices = average_tokens_across_devices,
|
| 389 |
-
max_length = max_length,
|
| 390 |
-
max_prompt_length = max_prompt_length,
|
| 391 |
-
max_completion_length = max_completion_length,
|
| 392 |
-
beta = beta,
|
| 393 |
-
disable_dropout = disable_dropout,
|
| 394 |
-
label_pad_token_id = label_pad_token_id,
|
| 395 |
-
padding_value = padding_value,
|
| 396 |
-
truncation_mode = truncation_mode,
|
| 397 |
-
generate_during_eval = generate_during_eval,
|
| 398 |
-
is_encoder_decoder = is_encoder_decoder,
|
| 399 |
-
model_init_kwargs = model_init_kwargs,
|
| 400 |
-
dataset_num_proc = dataset_num_proc,**kwargs)
|
| 401 |
-
self.vllm_sampling_params = vllm_sampling_params
|
| 402 |
-
self.unsloth_num_chunks = unsloth_num_chunks
|
| 403 |
-
pass
|
| 404 |
-
|
| 405 |
-
class _UnslothORPOTrainer(Trainer):
|
| 406 |
-
r""""""
|
| 407 |
-
|
| 408 |
-
_tag_names = ["trl", "orpo"]
|
| 409 |
-
|
| 410 |
-
def __init__(
|
| 411 |
-
self,
|
| 412 |
-
model: Optional[Union[PreTrainedModel, nn.Module, str]] = None,
|
| 413 |
-
args: Optional[ORPOConfig] = None,
|
| 414 |
-
data_collator: Optional[DataCollator] = None,
|
| 415 |
-
train_dataset: Optional[Dataset] = None,
|
| 416 |
-
eval_dataset: Optional[Union[Dataset, dict[str, Dataset]]] = None,
|
| 417 |
-
processing_class: Optional[
|
| 418 |
-
Union[PreTrainedTokenizerBase, BaseImageProcessor, FeatureExtractionMixin, ProcessorMixin]
|
| 419 |
-
] = None,
|
| 420 |
-
model_init: Optional[Callable[[], PreTrainedModel]] = None,
|
| 421 |
-
callbacks: Optional[list[TrainerCallback]] = None,
|
| 422 |
-
optimizers: tuple[torch.optim.Optimizer, torch.optim.lr_scheduler.LambdaLR] = (None, None),
|
| 423 |
-
preprocess_logits_for_metrics: Optional[Callable[[torch.Tensor, torch.Tensor], torch.Tensor]] = None,
|
| 424 |
-
peft_config: Optional[dict] = None,
|
| 425 |
-
compute_metrics: Optional[Callable[[EvalLoopOutput], dict]] = None,
|
| 426 |
-
):
|
| 427 |
-
if args.model_init_kwargs is None:
|
| 428 |
-
model_init_kwargs = {}
|
| 429 |
-
elif not isinstance(model, str):
|
| 430 |
-
raise ValueError("You passed model_kwargs to the ORPOTrainer. But your model is already instantiated.")
|
| 431 |
-
else:
|
| 432 |
-
model_init_kwargs = args.model_init_kwargs
|
| 433 |
-
torch_dtype = model_init_kwargs.get("torch_dtype")
|
| 434 |
-
if torch_dtype is not None:
|
| 435 |
-
# Convert to `torch.dtype` if an str is passed
|
| 436 |
-
if isinstance(torch_dtype, str) and torch_dtype != "auto":
|
| 437 |
-
torch_dtype = getattr(torch, torch_dtype)
|
| 438 |
-
if torch_dtype != "auto" and not isinstance(torch_dtype, torch.dtype):
|
| 439 |
-
raise ValueError(
|
| 440 |
-
f"Invalid `torch_dtype` passed to the ORPOConfig. Expected a string with either `torch.dtype` or 'auto', but got {torch_dtype}."
|
| 441 |
-
)
|
| 442 |
-
model_init_kwargs["torch_dtype"] = torch_dtype
|
| 443 |
-
|
| 444 |
-
if isinstance(model, str):
|
| 445 |
-
model = AutoModelForCausalLM.from_pretrained(model, **model_init_kwargs)
|
| 446 |
-
|
| 447 |
-
# Initialize this variable to False. This helps tracking the case when `peft_module_casting_to_bf16`
|
| 448 |
-
# has been called in order to properly call autocast if needed.
|
| 449 |
-
self._peft_has_been_casted_to_bf16 = False
|
| 450 |
-
|
| 451 |
-
if not is_peft_available() and peft_config is not None:
|
| 452 |
-
raise ValueError(
|
| 453 |
-
"PEFT is not installed and you passed a `peft_config` in the trainer's kwargs, please install it to use the PEFT models"
|
| 454 |
-
)
|
| 455 |
-
elif is_peft_available() and peft_config is not None:
|
| 456 |
-
# if model is a peft model and we have a peft_config, we merge and unload it first
|
| 457 |
-
if isinstance(model, PeftModel):
|
| 458 |
-
model = model.merge_and_unload()
|
| 459 |
-
|
| 460 |
-
if getattr(model, "is_loaded_in_8bit", False) or getattr(model, "is_loaded_in_4bit", False):
|
| 461 |
-
_support_gc_kwargs = hasattr(
|
| 462 |
-
args, "gradient_checkpointing_kwargs"
|
| 463 |
-
) and "gradient_checkpointing_kwargs" in list(
|
| 464 |
-
inspect.signature(prepare_model_for_kbit_training).parameters
|
| 465 |
-
)
|
| 466 |
-
|
| 467 |
-
prepare_model_kwargs = {"use_gradient_checkpointing": args.gradient_checkpointing}
|
| 468 |
-
|
| 469 |
-
if _support_gc_kwargs:
|
| 470 |
-
prepare_model_kwargs["gradient_checkpointing_kwargs"] = args.gradient_checkpointing_kwargs
|
| 471 |
-
|
| 472 |
-
model = prepare_model_for_kbit_training(model, **prepare_model_kwargs)
|
| 473 |
-
elif getattr(args, "gradient_checkpointing", False):
|
| 474 |
-
# For backward compatibility with older versions of transformers
|
| 475 |
-
if hasattr(model, "enable_input_require_grads"):
|
| 476 |
-
model.enable_input_require_grads()
|
| 477 |
-
else:
|
| 478 |
-
|
| 479 |
-
def make_inputs_require_grad(module, input, output):
|
| 480 |
-
output.requires_grad_(True)
|
| 481 |
-
|
| 482 |
-
model.get_input_embeddings().register_forward_hook(make_inputs_require_grad)
|
| 483 |
-
|
| 484 |
-
# get peft model with the given config
|
| 485 |
-
model = model
|
| 486 |
-
if args.bf16 and getattr(model, "is_loaded_in_4bit", False):
|
| 487 |
-
peft_module_casting_to_bf16(model)
|
| 488 |
-
# If args.bf16 we need to explicitly call `generate` with torch amp autocast context manager
|
| 489 |
-
self._peft_has_been_casted_to_bf16 = True
|
| 490 |
-
|
| 491 |
-
# For models that use gradient_checkpointing, we need to attach a hook that enables input
|
| 492 |
-
# to explicitly have `requires_grad=True`, otherwise training will either silently
|
| 493 |
-
# fail or completely fail.
|
| 494 |
-
elif getattr(args, "gradient_checkpointing", False):
|
| 495 |
-
# For backward compatibility with older versions of transformers
|
| 496 |
-
if hasattr(model, "enable_input_require_grads"):
|
| 497 |
-
model.enable_input_require_grads()
|
| 498 |
-
else:
|
| 499 |
-
|
| 500 |
-
def make_inputs_require_grad(module, input, output):
|
| 501 |
-
output.requires_grad_(True)
|
| 502 |
-
|
| 503 |
-
model.get_input_embeddings().register_forward_hook(make_inputs_require_grad)
|
| 504 |
-
|
| 505 |
-
if args.generate_during_eval and not (is_wandb_available() or is_comet_available()):
|
| 506 |
-
raise ValueError(
|
| 507 |
-
"`generate_during_eval=True` requires Weights and Biases or Comet to be installed."
|
| 508 |
-
" Please install `wandb` or `comet-ml` to resolve."
|
| 509 |
-
)
|
| 510 |
-
|
| 511 |
-
if model is not None:
|
| 512 |
-
self.is_encoder_decoder = model.config.is_encoder_decoder
|
| 513 |
-
elif args.is_encoder_decoder is None:
|
| 514 |
-
raise ValueError("When no model is provided, you need to pass the parameter is_encoder_decoder.")
|
| 515 |
-
else:
|
| 516 |
-
self.is_encoder_decoder = args.is_encoder_decoder
|
| 517 |
-
|
| 518 |
-
if self.is_encoder_decoder:
|
| 519 |
-
self.decoder_start_token_id = model.config.decoder_start_token_id
|
| 520 |
-
self.pad_token_id = model.config.pad_token_id
|
| 521 |
-
|
| 522 |
-
if processing_class is None:
|
| 523 |
-
raise ValueError("processing_class must be specified to tokenize a ORPO dataset.")
|
| 524 |
-
if args.max_length is None:
|
| 525 |
-
warnings.warn(
|
| 526 |
-
"`max_length` is not set in the ORPOConfig's init"
|
| 527 |
-
" it will default to `512` by default, but you should do it yourself in the future.",
|
| 528 |
-
UserWarning,
|
| 529 |
-
)
|
| 530 |
-
max_length = 512
|
| 531 |
-
else:
|
| 532 |
-
max_length = args.max_length
|
| 533 |
-
if args.max_prompt_length is None:
|
| 534 |
-
warnings.warn(
|
| 535 |
-
"`max_prompt_length` is not set in the ORPOConfig's init"
|
| 536 |
-
" it will default to `128` by default, but you should do it yourself in the future.",
|
| 537 |
-
UserWarning,
|
| 538 |
-
)
|
| 539 |
-
max_prompt_length = 128
|
| 540 |
-
else:
|
| 541 |
-
max_prompt_length = args.max_prompt_length
|
| 542 |
-
|
| 543 |
-
if args.max_completion_length is None and self.is_encoder_decoder:
|
| 544 |
-
warnings.warn(
|
| 545 |
-
"When using an encoder decoder architecture, you should set `max_completion_length` in the ORPOConfig's init"
|
| 546 |
-
" it will default to `128` by default, but you should do it yourself in the future.",
|
| 547 |
-
UserWarning,
|
| 548 |
-
)
|
| 549 |
-
self.max_completion_length = 128
|
| 550 |
-
else:
|
| 551 |
-
self.max_completion_length = args.max_completion_length
|
| 552 |
-
|
| 553 |
-
if data_collator is None:
|
| 554 |
-
data_collator = DPODataCollatorWithPadding(
|
| 555 |
-
pad_token_id=processing_class.pad_token_id,
|
| 556 |
-
label_pad_token_id=args.label_pad_token_id,
|
| 557 |
-
is_encoder_decoder=self.is_encoder_decoder,
|
| 558 |
-
)
|
| 559 |
-
|
| 560 |
-
if args.remove_unused_columns:
|
| 561 |
-
args.remove_unused_columns = False
|
| 562 |
-
# warn users
|
| 563 |
-
warnings.warn(
|
| 564 |
-
"When using DPODataCollatorWithPadding, you should set `remove_unused_columns=False` in your TrainingArguments"
|
| 565 |
-
" we have set it for you, but you should do it yourself in the future.",
|
| 566 |
-
UserWarning,
|
| 567 |
-
)
|
| 568 |
-
|
| 569 |
-
self.use_dpo_data_collator = True
|
| 570 |
-
else:
|
| 571 |
-
self.use_dpo_data_collator = False
|
| 572 |
-
|
| 573 |
-
# Disable dropout in the model and reference model
|
| 574 |
-
if args.disable_dropout:
|
| 575 |
-
disable_dropout_in_model(model)
|
| 576 |
-
|
| 577 |
-
self.max_length = max_length
|
| 578 |
-
self.generate_during_eval = args.generate_during_eval
|
| 579 |
-
self.label_pad_token_id = args.label_pad_token_id
|
| 580 |
-
self.padding_value = args.padding_value if args.padding_value is not None else processing_class.pad_token_id
|
| 581 |
-
self.max_prompt_length = max_prompt_length
|
| 582 |
-
self.truncation_mode = args.truncation_mode
|
| 583 |
-
self.processing_class = processing_class
|
| 584 |
-
|
| 585 |
-
self.beta = args.beta
|
| 586 |
-
self.aux_loss_enabled = getattr(model.config, "output_router_logits", False)
|
| 587 |
-
self.aux_loss_coef = getattr(model.config, "router_aux_loss_coef", 0.0)
|
| 588 |
-
if self.aux_loss_enabled and self.aux_loss_coef == 0.0:
|
| 589 |
-
warnings.warn(
|
| 590 |
-
"You set `output_router_logits` to `True` in the model config, but `router_aux_loss_coef` is set to "
|
| 591 |
-
"`0.0`, meaning the auxiliary loss will not be used. Either set `router_aux_loss_coef` to a value "
|
| 592 |
-
"greater than `0.0`, or set `output_router_logits` to `False` if you don't want to use the auxiliary "
|
| 593 |
-
"loss.",
|
| 594 |
-
UserWarning,
|
| 595 |
-
)
|
| 596 |
-
|
| 597 |
-
self._stored_metrics = defaultdict(lambda: defaultdict(list))
|
| 598 |
-
|
| 599 |
-
# The trainer estimates the number of FLOPs [floating-point operations] using the number of elements in the
|
| 600 |
-
# input tensor associated with the key "input_ids". However, in ORPO, the sampled data does not include the
|
| 601 |
-
# "input_ids" key. Instead, the available keys are "prompt_input_ids", "chosen_input_ids", and
|
| 602 |
-
# "rejected_input_ids". As a result, the trainer issues the warning: "Could not estimate the number of tokens
|
| 603 |
-
# of the input, floating-point operations will not be computed." To suppress this warning, we set the
|
| 604 |
-
# "estimate_tokens" key in the model's "warnings_issued" dictionary to True. This acts as a flag to indicate
|
| 605 |
-
# that the warning has already been issued.
|
| 606 |
-
model.warnings_issued["estimate_tokens"] = True
|
| 607 |
-
|
| 608 |
-
# Compute that only on the main process for faster data processing.
|
| 609 |
-
# see: https://github.com/huggingface/trl/pull/1255
|
| 610 |
-
with PartialState().main_process_first():
|
| 611 |
-
# Extract the prompt if needed, and apply the chat template if needed
|
| 612 |
-
train_dataset = train_dataset.map(maybe_extract_prompt, num_proc=args.dataset_num_proc)
|
| 613 |
-
train_dataset = train_dataset.map(
|
| 614 |
-
maybe_apply_chat_template, fn_kwargs={"tokenizer": processing_class}, num_proc=args.dataset_num_proc
|
| 615 |
-
)
|
| 616 |
-
train_dataset = train_dataset.map(self.tokenize_row, num_proc=args.dataset_num_proc)
|
| 617 |
-
if eval_dataset is not None:
|
| 618 |
-
eval_dataset = eval_dataset.map(maybe_extract_prompt, num_proc=args.dataset_num_proc)
|
| 619 |
-
eval_dataset = eval_dataset.map(
|
| 620 |
-
maybe_apply_chat_template,
|
| 621 |
-
fn_kwargs={"tokenizer": processing_class},
|
| 622 |
-
num_proc=args.dataset_num_proc,
|
| 623 |
-
)
|
| 624 |
-
eval_dataset = eval_dataset.map(self.tokenize_row, num_proc=args.dataset_num_proc)
|
| 625 |
-
|
| 626 |
-
super().__init__(
|
| 627 |
-
model=model,
|
| 628 |
-
args=args,
|
| 629 |
-
data_collator=data_collator,
|
| 630 |
-
train_dataset=train_dataset,
|
| 631 |
-
eval_dataset=eval_dataset,
|
| 632 |
-
processing_class=processing_class,
|
| 633 |
-
model_init=model_init,
|
| 634 |
-
compute_metrics=compute_metrics,
|
| 635 |
-
callbacks=callbacks,
|
| 636 |
-
optimizers=optimizers,
|
| 637 |
-
preprocess_logits_for_metrics=preprocess_logits_for_metrics,
|
| 638 |
-
)
|
| 639 |
-
|
| 640 |
-
# Add tags for models that have been loaded with the correct transformers version
|
| 641 |
-
if hasattr(self.model, "add_model_tags"):
|
| 642 |
-
self.model.add_model_tags(self._tag_names)
|
| 643 |
-
|
| 644 |
-
if not hasattr(self, "accelerator"):
|
| 645 |
-
raise AttributeError(
|
| 646 |
-
"Your `Trainer` does not have an `accelerator` object. Consider upgrading `transformers`."
|
| 647 |
-
)
|
| 648 |
-
|
| 649 |
-
def _prepare_deepspeed(self, model: PreTrainedModelWrapper):
|
| 650 |
-
# Adapted from accelerate: https://github.com/huggingface/accelerate/blob/739b135f8367becb67ffaada12fe76e3aa60fefd/src/accelerate/accelerator.py#L1473
|
| 651 |
-
deepspeed_plugin = self.accelerator.state.deepspeed_plugin
|
| 652 |
-
config_kwargs = deepcopy(deepspeed_plugin.deepspeed_config)
|
| 653 |
-
|
| 654 |
-
if model is not None:
|
| 655 |
-
if hasattr(model, "config"):
|
| 656 |
-
hidden_size = (
|
| 657 |
-
max(model.config.hidden_sizes)
|
| 658 |
-
if getattr(model.config, "hidden_sizes", None)
|
| 659 |
-
else getattr(model.config, "hidden_size", None)
|
| 660 |
-
)
|
| 661 |
-
if hidden_size is not None and config_kwargs["zero_optimization"]["stage"] == 3:
|
| 662 |
-
# Note that `stage3_prefetch_bucket_size` can produce DeepSpeed messages like: `Invalidate trace cache @ step 0: expected module 1, but got module 0`
|
| 663 |
-
# This is expected and is not an error, see: https://github.com/microsoft/DeepSpeed/discussions/4081
|
| 664 |
-
config_kwargs.update(
|
| 665 |
-
{
|
| 666 |
-
"zero_optimization.reduce_bucket_size": hidden_size * hidden_size,
|
| 667 |
-
"zero_optimization.stage3_param_persistence_threshold": 10 * hidden_size,
|
| 668 |
-
"zero_optimization.stage3_prefetch_bucket_size": 0.9 * hidden_size * hidden_size,
|
| 669 |
-
}
|
| 670 |
-
)
|
| 671 |
-
|
| 672 |
-
# If ZeRO-3 is used, we shard both the active and reference model.
|
| 673 |
-
# Otherwise, we assume the reference model fits in memory and is initialized on each device with ZeRO disabled (stage 0)
|
| 674 |
-
if config_kwargs["zero_optimization"]["stage"] != 3:
|
| 675 |
-
config_kwargs["zero_optimization"]["stage"] = 0
|
| 676 |
-
model, *_ = deepspeed.initialize(model=model, config=config_kwargs)
|
| 677 |
-
model.eval()
|
| 678 |
-
return model
|
| 679 |
-
|
| 680 |
-
def build_tokenized_answer(self, prompt, answer):
|
| 681 |
-
"""
|
| 682 |
-
Llama tokenizer does satisfy `enc(a + b) = enc(a) + enc(b)`.
|
| 683 |
-
It does ensure `enc(a + b) = enc(a) + enc(a + b)[len(enc(a)):]`.
|
| 684 |
-
Reference:
|
| 685 |
-
https://github.com/EleutherAI/lm-evaluation-harness/pull/531#issuecomment-1595586257
|
| 686 |
-
"""
|
| 687 |
-
|
| 688 |
-
full_tokenized = self.processing_class(prompt + answer, add_special_tokens=False)
|
| 689 |
-
prompt_input_ids = self.processing_class(prompt, add_special_tokens=False)["input_ids"]
|
| 690 |
-
|
| 691 |
-
answer_input_ids = full_tokenized["input_ids"][len(prompt_input_ids) :]
|
| 692 |
-
answer_attention_mask = full_tokenized["attention_mask"][len(prompt_input_ids) :]
|
| 693 |
-
|
| 694 |
-
# Concat tokens to form `enc(a) + enc(a + b)[len(enc(a)):]`
|
| 695 |
-
full_concat_input_ids = np.concatenate([prompt_input_ids, answer_input_ids])
|
| 696 |
-
|
| 697 |
-
# Prepare input tokens for token by token comparison
|
| 698 |
-
full_input_ids = np.array(full_tokenized["input_ids"])
|
| 699 |
-
|
| 700 |
-
if len(full_input_ids) != len(full_concat_input_ids):
|
| 701 |
-
raise ValueError("Prompt input ids and answer input ids should have the same length.")
|
| 702 |
-
|
| 703 |
-
# On some tokenizers, like Llama-2 tokenizer, there are occasions where tokens
|
| 704 |
-
# can be merged together when tokenizing prompt+answer. This could result
|
| 705 |
-
# on the last token from the prompt being different when tokenized on its own
|
| 706 |
-
# vs when done as prompt+answer.
|
| 707 |
-
response_token_ids_start_idx = len(prompt_input_ids)
|
| 708 |
-
|
| 709 |
-
# If tokenized prompt is different than both prompt+answer, then it means the
|
| 710 |
-
# last token has changed due to merging.
|
| 711 |
-
if prompt_input_ids != full_tokenized["input_ids"][:response_token_ids_start_idx]:
|
| 712 |
-
response_token_ids_start_idx -= 1
|
| 713 |
-
|
| 714 |
-
prompt_input_ids = full_tokenized["input_ids"][:response_token_ids_start_idx]
|
| 715 |
-
prompt_attention_mask = full_tokenized["attention_mask"][:response_token_ids_start_idx]
|
| 716 |
-
|
| 717 |
-
if len(prompt_input_ids) != len(prompt_attention_mask):
|
| 718 |
-
raise ValueError("Prompt input ids and attention mask should have the same length.")
|
| 719 |
-
|
| 720 |
-
answer_input_ids = full_tokenized["input_ids"][response_token_ids_start_idx:]
|
| 721 |
-
answer_attention_mask = full_tokenized["attention_mask"][response_token_ids_start_idx:]
|
| 722 |
-
|
| 723 |
-
return dict(
|
| 724 |
-
prompt_input_ids=prompt_input_ids,
|
| 725 |
-
prompt_attention_mask=prompt_attention_mask,
|
| 726 |
-
input_ids=answer_input_ids,
|
| 727 |
-
attention_mask=answer_attention_mask,
|
| 728 |
-
)
|
| 729 |
-
|
| 730 |
-
def tokenize_row(self, feature, model: Optional[Union[PreTrainedModel, nn.Module]] = None) -> dict:
|
| 731 |
-
"""Tokenize a single row from a ORPO specific dataset.
|
| 732 |
-
|
| 733 |
-
At this stage, we don't convert to PyTorch tensors yet; we just handle the truncation
|
| 734 |
-
in case the prompt + chosen or prompt + rejected responses is/are too long. First
|
| 735 |
-
we truncate the prompt; if we're still too long, we truncate the chosen/rejected.
|
| 736 |
-
|
| 737 |
-
We also create the labels for the chosen/rejected responses, which are of length equal to
|
| 738 |
-
the sum of the length of the prompt and the chosen/rejected response, with
|
| 739 |
-
label_pad_token_id for the prompt tokens.
|
| 740 |
-
"""
|
| 741 |
-
batch = {}
|
| 742 |
-
prompt = feature["prompt"]
|
| 743 |
-
chosen = feature["chosen"]
|
| 744 |
-
rejected = feature["rejected"]
|
| 745 |
-
|
| 746 |
-
if not self.is_encoder_decoder:
|
| 747 |
-
# Check issues below for more details
|
| 748 |
-
# 1. https://github.com/huggingface/trl/issues/907
|
| 749 |
-
# 2. https://github.com/EleutherAI/lm-evaluation-harness/pull/531#issuecomment-1595586257
|
| 750 |
-
# 3. https://github.com/LianjiaTech/BELLE/issues/337
|
| 751 |
-
|
| 752 |
-
if not isinstance(prompt, str):
|
| 753 |
-
raise ValueError(f"prompt should be an str but got {type(prompt)}")
|
| 754 |
-
prompt_tokens = self.processing_class(prompt, add_special_tokens=False)
|
| 755 |
-
prompt_tokens = {f"prompt_{k}": v for k, v in prompt_tokens.items()}
|
| 756 |
-
|
| 757 |
-
if not isinstance(chosen, str):
|
| 758 |
-
raise ValueError(f"chosen should be an str but got {type(chosen)}")
|
| 759 |
-
chosen_tokens = self.build_tokenized_answer(prompt, chosen)
|
| 760 |
-
|
| 761 |
-
if not isinstance(rejected, str):
|
| 762 |
-
raise ValueError(f"rejected should be an str but got {type(rejected)}")
|
| 763 |
-
rejected_tokens = self.build_tokenized_answer(prompt, rejected)
|
| 764 |
-
|
| 765 |
-
# Last prompt token might get merged by tokenizer and
|
| 766 |
-
# it should not be included for generation if that happens
|
| 767 |
-
prompt_len_input_ids = len(prompt_tokens["prompt_input_ids"])
|
| 768 |
-
|
| 769 |
-
chosen_prompt_len_input_ids = len(chosen_tokens["prompt_input_ids"])
|
| 770 |
-
rejected_prompt_len_input_ids = len(rejected_tokens["prompt_input_ids"])
|
| 771 |
-
prompt_len_input_ids = min(chosen_prompt_len_input_ids, rejected_prompt_len_input_ids)
|
| 772 |
-
|
| 773 |
-
for k, v in prompt_tokens.items():
|
| 774 |
-
prompt_tokens[k] = v[:prompt_len_input_ids]
|
| 775 |
-
|
| 776 |
-
# Make sure prompts only have one different token at most an
|
| 777 |
-
# and length only differs by 1 at most
|
| 778 |
-
num_diff_tokens = sum(
|
| 779 |
-
[a != b for a, b in zip(chosen_tokens["prompt_input_ids"], rejected_tokens["prompt_input_ids"])]
|
| 780 |
-
)
|
| 781 |
-
num_diff_len = abs(chosen_prompt_len_input_ids - rejected_prompt_len_input_ids)
|
| 782 |
-
if num_diff_tokens > 1 or num_diff_len > 1:
|
| 783 |
-
raise ValueError(
|
| 784 |
-
"Chosen and rejected prompt_input_ids might only differ on the "
|
| 785 |
-
"last token due to tokenizer merge ops."
|
| 786 |
-
)
|
| 787 |
-
|
| 788 |
-
# add BOS token to head of prompt. Avoid adding if it's already there
|
| 789 |
-
prompt_tokens, chosen_tokens, rejected_tokens = add_bos_token_if_needed(
|
| 790 |
-
self.processing_class.bos_token_id,
|
| 791 |
-
prompt_len_input_ids,
|
| 792 |
-
prompt_tokens,
|
| 793 |
-
chosen_prompt_len_input_ids,
|
| 794 |
-
chosen_tokens,
|
| 795 |
-
rejected_prompt_len_input_ids,
|
| 796 |
-
rejected_tokens,
|
| 797 |
-
)
|
| 798 |
-
|
| 799 |
-
# add EOS token to end of answer. Avoid adding if it's already there
|
| 800 |
-
chosen_tokens, rejected_tokens = add_eos_token_if_needed(
|
| 801 |
-
self.processing_class.eos_token_id, chosen_tokens, rejected_tokens
|
| 802 |
-
)
|
| 803 |
-
|
| 804 |
-
longer_response_length = max(len(chosen_tokens["input_ids"]), len(rejected_tokens["input_ids"]))
|
| 805 |
-
|
| 806 |
-
# if combined sequence is too long, truncate the prompt
|
| 807 |
-
for answer_tokens in [chosen_tokens, rejected_tokens, prompt_tokens]:
|
| 808 |
-
if len(answer_tokens["prompt_input_ids"]) + longer_response_length > self.max_length:
|
| 809 |
-
if self.truncation_mode == "keep_start":
|
| 810 |
-
for k in ["prompt_input_ids", "prompt_attention_mask"]:
|
| 811 |
-
answer_tokens[k] = answer_tokens[k][: self.max_prompt_length]
|
| 812 |
-
elif self.truncation_mode == "keep_end":
|
| 813 |
-
for k in ["prompt_input_ids", "prompt_attention_mask"]:
|
| 814 |
-
answer_tokens[k] = answer_tokens[k][-self.max_prompt_length :]
|
| 815 |
-
else:
|
| 816 |
-
raise ValueError(f"Unknown truncation mode: {self.truncation_mode}")
|
| 817 |
-
|
| 818 |
-
# if that's still too long, truncate the response
|
| 819 |
-
for answer_tokens in [chosen_tokens, rejected_tokens]:
|
| 820 |
-
if len(answer_tokens["prompt_input_ids"]) + longer_response_length > self.max_length:
|
| 821 |
-
for k in ["input_ids", "attention_mask"]:
|
| 822 |
-
answer_tokens[k] = answer_tokens[k][: self.max_length - self.max_prompt_length]
|
| 823 |
-
|
| 824 |
-
# Create labels
|
| 825 |
-
chosen_sequence_tokens = {
|
| 826 |
-
k: chosen_tokens[f"prompt_{k}"] + chosen_tokens[k] for k in ["input_ids", "attention_mask"]
|
| 827 |
-
}
|
| 828 |
-
rejected_sequence_tokens = {
|
| 829 |
-
k: rejected_tokens[f"prompt_{k}"] + rejected_tokens[k] for k in ["input_ids", "attention_mask"]
|
| 830 |
-
}
|
| 831 |
-
chosen_sequence_tokens["labels"] = chosen_sequence_tokens["input_ids"][:]
|
| 832 |
-
chosen_sequence_tokens["labels"][: len(chosen_tokens["prompt_input_ids"])] = [
|
| 833 |
-
self.label_pad_token_id
|
| 834 |
-
] * len(chosen_tokens["prompt_input_ids"])
|
| 835 |
-
rejected_sequence_tokens["labels"] = rejected_sequence_tokens["input_ids"][:]
|
| 836 |
-
rejected_sequence_tokens["labels"][: len(rejected_tokens["prompt_input_ids"])] = [
|
| 837 |
-
self.label_pad_token_id
|
| 838 |
-
] * len(rejected_tokens["prompt_input_ids"])
|
| 839 |
-
|
| 840 |
-
for k, toks in {
|
| 841 |
-
"chosen_": chosen_sequence_tokens,
|
| 842 |
-
"rejected_": rejected_sequence_tokens,
|
| 843 |
-
"": prompt_tokens,
|
| 844 |
-
}.items():
|
| 845 |
-
for type_key, tokens in toks.items():
|
| 846 |
-
if type_key == "token_type_ids":
|
| 847 |
-
continue
|
| 848 |
-
batch[f"{k}{type_key}"] = tokens
|
| 849 |
-
|
| 850 |
-
else:
|
| 851 |
-
chosen_tokens = self.processing_class(
|
| 852 |
-
chosen, truncation=True, max_length=self.max_completion_length, add_special_tokens=True
|
| 853 |
-
)
|
| 854 |
-
rejected_tokens = self.processing_class(
|
| 855 |
-
rejected, truncation=True, max_length=self.max_completion_length, add_special_tokens=True
|
| 856 |
-
)
|
| 857 |
-
prompt_tokens = self.processing_class(
|
| 858 |
-
prompt, truncation=True, max_length=self.max_prompt_length, add_special_tokens=True
|
| 859 |
-
)
|
| 860 |
-
|
| 861 |
-
batch["chosen_labels"] = chosen_tokens["input_ids"]
|
| 862 |
-
batch["rejected_labels"] = rejected_tokens["input_ids"]
|
| 863 |
-
batch["prompt_input_ids"] = prompt_tokens["input_ids"]
|
| 864 |
-
batch["prompt_attention_mask"] = prompt_tokens["attention_mask"]
|
| 865 |
-
|
| 866 |
-
if model is not None and hasattr(model, "prepare_decoder_input_ids_from_labels"):
|
| 867 |
-
batch["rejected_decoder_input_ids"] = model.prepare_decoder_input_ids_from_labels(
|
| 868 |
-
labels=torch.tensor(batch["rejected_labels"])
|
| 869 |
-
)
|
| 870 |
-
batch["chosen_decoder_input_ids"] = model.prepare_decoder_input_ids_from_labels(
|
| 871 |
-
labels=torch.tensor(batch["chosen_labels"])
|
| 872 |
-
)
|
| 873 |
-
|
| 874 |
-
if is_torch_xla_available():
|
| 875 |
-
# Pad the sequences to global max_length to avoid TorchXLA recompilation
|
| 876 |
-
for k in batch:
|
| 877 |
-
if "labels" in k or self.is_encoder_decoder:
|
| 878 |
-
pad_value = self.label_pad_token_id
|
| 879 |
-
elif k.endswith("_input_ids"):
|
| 880 |
-
pad_value = self.padding_value
|
| 881 |
-
elif k.endswith("_attention_mask"):
|
| 882 |
-
pad_value = 0
|
| 883 |
-
batch[k] = batch[k] + [pad_value] * (self.max_length - len(batch[k]))
|
| 884 |
-
return batch
|
| 885 |
-
|
| 886 |
-
@staticmethod
|
| 887 |
-
def concatenated_inputs(
|
| 888 |
-
batch: dict[str, Union[list, torch.LongTensor]],
|
| 889 |
-
is_encoder_decoder: bool = False,
|
| 890 |
-
label_pad_token_id: int = -100,
|
| 891 |
-
padding_value: int = 0,
|
| 892 |
-
device: Optional[torch.device] = None,
|
| 893 |
-
) -> dict[str, torch.LongTensor]:
|
| 894 |
-
"""Concatenate the chosen and rejected inputs into a single tensor.
|
| 895 |
-
|
| 896 |
-
Args:
|
| 897 |
-
batch: A batch of data. Must contain the keys 'chosen_input_ids' and 'rejected_input_ids', which are tensors of shape (batch_size, sequence_length).
|
| 898 |
-
is_encoder_decoder: Whether the model is an encoder-decoder model.
|
| 899 |
-
label_pad_token_id: The label pad token id.
|
| 900 |
-
padding_value: The padding value to use for the concatenated inputs_ids.
|
| 901 |
-
device: The device for the concatenated inputs.
|
| 902 |
-
|
| 903 |
-
Returns:
|
| 904 |
-
A dictionary containing the concatenated inputs under the key 'concatenated_input_ids'.
|
| 905 |
-
"""
|
| 906 |
-
concatenated_batch = {}
|
| 907 |
-
|
| 908 |
-
if is_encoder_decoder:
|
| 909 |
-
max_length = max(batch["chosen_labels"].shape[1], batch["rejected_labels"].shape[1])
|
| 910 |
-
else:
|
| 911 |
-
max_length = max(batch["chosen_input_ids"].shape[1], batch["rejected_input_ids"].shape[1])
|
| 912 |
-
|
| 913 |
-
for k in batch:
|
| 914 |
-
if k.startswith("chosen") and isinstance(batch[k], torch.Tensor):
|
| 915 |
-
if "labels" in k or is_encoder_decoder:
|
| 916 |
-
pad_value = label_pad_token_id
|
| 917 |
-
elif k.endswith("_input_ids"):
|
| 918 |
-
pad_value = padding_value
|
| 919 |
-
elif k.endswith("_attention_mask"):
|
| 920 |
-
pad_value = 0
|
| 921 |
-
concatenated_key = k.replace("chosen", "concatenated")
|
| 922 |
-
concatenated_batch[concatenated_key] = pad_to_length(batch[k], max_length, pad_value=pad_value)
|
| 923 |
-
for k in batch:
|
| 924 |
-
if k.startswith("rejected") and isinstance(batch[k], torch.Tensor):
|
| 925 |
-
if "labels" in k or is_encoder_decoder:
|
| 926 |
-
pad_value = label_pad_token_id
|
| 927 |
-
elif k.endswith("_input_ids"):
|
| 928 |
-
pad_value = padding_value
|
| 929 |
-
elif k.endswith("_attention_mask"):
|
| 930 |
-
pad_value = 0
|
| 931 |
-
concatenated_key = k.replace("rejected", "concatenated")
|
| 932 |
-
concatenated_batch[concatenated_key] = torch.cat(
|
| 933 |
-
(
|
| 934 |
-
concatenated_batch[concatenated_key],
|
| 935 |
-
pad_to_length(batch[k], max_length, pad_value=pad_value),
|
| 936 |
-
),
|
| 937 |
-
dim=0,
|
| 938 |
-
).to(device=device)
|
| 939 |
-
|
| 940 |
-
if is_encoder_decoder:
|
| 941 |
-
concatenated_batch["concatenated_input_ids"] = batch["prompt_input_ids"].repeat(2, 1).to(device=device)
|
| 942 |
-
concatenated_batch["concatenated_attention_mask"] = (
|
| 943 |
-
batch["prompt_attention_mask"].repeat(2, 1).to(device=device)
|
| 944 |
-
)
|
| 945 |
-
|
| 946 |
-
return concatenated_batch
|
| 947 |
-
|
| 948 |
-
def odds_ratio_loss(
|
| 949 |
-
self,
|
| 950 |
-
policy_chosen_logps: torch.FloatTensor,
|
| 951 |
-
policy_rejected_logps: torch.FloatTensor,
|
| 952 |
-
) -> tuple[torch.FloatTensor, torch.FloatTensor, torch.FloatTensor, torch.FloatTensor, torch.FloatTensor]:
|
| 953 |
-
"""Compute ORPO's odds ratio (OR) loss for a batch of policy and reference model log probabilities.
|
| 954 |
-
|
| 955 |
-
Args:
|
| 956 |
-
policy_chosen_logps: Log probabilities of the policy model for the chosen responses. Shape: (batch_size,)
|
| 957 |
-
policy_rejected_logps: Log probabilities of the policy model for the rejected responses. Shape: (batch_size,)
|
| 958 |
-
|
| 959 |
-
Returns:
|
| 960 |
-
A tuple of three tensors: (losses, chosen_rewards, rejected_rewards).
|
| 961 |
-
The losses tensor contains the ORPO loss for each example in the batch.
|
| 962 |
-
The chosen_rewards and rejected_rewards tensors contain the rewards for the chosen and rejected responses, respectively.
|
| 963 |
-
The log odds ratio of the chosen responses over the rejected responses ratio for logging purposes.
|
| 964 |
-
The `log(sigmoid(log_odds_chosen))` for logging purposes.
|
| 965 |
-
"""
|
| 966 |
-
|
| 967 |
-
# Derived from Eqs. (4) and (7) from https://huggingface.co/papers/2403.07691 by using log identities and exp(log(P(y|x)) = P(y|x)
|
| 968 |
-
log_odds = (policy_chosen_logps - policy_rejected_logps) - (
|
| 969 |
-
torch.log1p(-torch.exp(policy_chosen_logps)) - torch.log1p(-torch.exp(policy_rejected_logps))
|
| 970 |
-
)
|
| 971 |
-
ratio = F.logsigmoid(log_odds)
|
| 972 |
-
losses = self.beta * ratio
|
| 973 |
-
|
| 974 |
-
chosen_rewards = self.beta * (policy_chosen_logps.to(self.accelerator.device)).detach()
|
| 975 |
-
rejected_rewards = self.beta * (policy_rejected_logps.to(self.accelerator.device)).detach()
|
| 976 |
-
|
| 977 |
-
return losses, chosen_rewards, rejected_rewards, torch.mean(ratio), torch.mean(log_odds)
|
| 978 |
-
|
| 979 |
-
@staticmethod
|
| 980 |
-
def get_batch_logps(
|
| 981 |
-
logits: torch.FloatTensor,
|
| 982 |
-
labels: torch.LongTensor,
|
| 983 |
-
average_log_prob: bool = False,
|
| 984 |
-
label_pad_token_id: int = -100,
|
| 985 |
-
is_encoder_decoder: bool = False,
|
| 986 |
-
) -> torch.FloatTensor:
|
| 987 |
-
"""Compute the log probabilities of the given labels under the given logits.
|
| 988 |
-
|
| 989 |
-
Args:
|
| 990 |
-
logits: Logits of the model (unnormalized). Shape: (batch_size, sequence_length, vocab_size)
|
| 991 |
-
labels: Labels for which to compute the log probabilities. Label tokens with a value of label_pad_token_id are ignored. Shape: (batch_size, sequence_length)
|
| 992 |
-
average_log_prob: If True, return the average log probability per (non-masked) token. Otherwise, return the sum of the log probabilities of the (non-masked) tokens.
|
| 993 |
-
label_pad_token_id: The label pad token id.
|
| 994 |
-
is_encoder_decoder: Whether the model is an encoder-decoder model.
|
| 995 |
-
|
| 996 |
-
Returns:
|
| 997 |
-
A tensor of shape (batch_size,) containing the average/sum log probabilities of the given labels under the given logits.
|
| 998 |
-
"""
|
| 999 |
-
if logits.shape[:-1] != labels.shape:
|
| 1000 |
-
raise ValueError("Logits (batch and sequence length dim) and labels must have the same shape.")
|
| 1001 |
-
|
| 1002 |
-
if not is_encoder_decoder:
|
| 1003 |
-
labels = labels[:, 1:].clone()
|
| 1004 |
-
logits = logits[:, :-1, :]
|
| 1005 |
-
loss_mask = labels != label_pad_token_id
|
| 1006 |
-
|
| 1007 |
-
# dummy token; we'll ignore the losses on these tokens later
|
| 1008 |
-
labels = torch.where(labels == label_pad_token_id, 0, labels)
|
| 1009 |
-
|
| 1010 |
-
per_token_logps = selective_log_softmax(logits, labels)
|
| 1011 |
-
|
| 1012 |
-
if average_log_prob:
|
| 1013 |
-
return (per_token_logps * loss_mask).sum(-1) / loss_mask.sum(-1)
|
| 1014 |
-
else:
|
| 1015 |
-
return (per_token_logps * loss_mask).sum(-1)
|
| 1016 |
-
|
| 1017 |
-
def concatenated_forward(
|
| 1018 |
-
self, model: nn.Module, batch: dict[str, Union[list, torch.LongTensor]]
|
| 1019 |
-
) -> tuple[torch.FloatTensor, torch.FloatTensor, torch.FloatTensor, torch.FloatTensor]:
|
| 1020 |
-
"""Run the given model on the given batch of inputs, concatenating the chosen and rejected inputs together.
|
| 1021 |
-
|
| 1022 |
-
We do this to avoid doing two forward passes, because it's faster for FSDP.
|
| 1023 |
-
"""
|
| 1024 |
-
concatenated_batch = self.concatenated_inputs(
|
| 1025 |
-
batch,
|
| 1026 |
-
is_encoder_decoder=self.is_encoder_decoder,
|
| 1027 |
-
label_pad_token_id=self.label_pad_token_id,
|
| 1028 |
-
padding_value=self.padding_value,
|
| 1029 |
-
device=self.accelerator.device,
|
| 1030 |
-
)
|
| 1031 |
-
len_chosen = batch["chosen_labels"].shape[0]
|
| 1032 |
-
|
| 1033 |
-
model_kwargs = (
|
| 1034 |
-
{
|
| 1035 |
-
"decoder_input_ids": self._shift_right(concatenated_batch["concatenated_labels"]),
|
| 1036 |
-
}
|
| 1037 |
-
if self.is_encoder_decoder
|
| 1038 |
-
else {}
|
| 1039 |
-
)
|
| 1040 |
-
|
| 1041 |
-
if self.aux_loss_enabled:
|
| 1042 |
-
model_kwargs["output_router_logits"] = True
|
| 1043 |
-
|
| 1044 |
-
outputs = model(
|
| 1045 |
-
concatenated_batch["concatenated_input_ids"],
|
| 1046 |
-
attention_mask=concatenated_batch["concatenated_attention_mask"],
|
| 1047 |
-
use_cache=False,
|
| 1048 |
-
**model_kwargs,
|
| 1049 |
-
)
|
| 1050 |
-
all_logits = outputs.logits
|
| 1051 |
-
|
| 1052 |
-
def cross_entropy_loss(logits, labels):
|
| 1053 |
-
if not self.is_encoder_decoder:
|
| 1054 |
-
# Shift so that tokens < n predict n
|
| 1055 |
-
logits = logits[..., :-1, :].contiguous()
|
| 1056 |
-
labels = labels[..., 1:].contiguous()
|
| 1057 |
-
# Flatten the tokens
|
| 1058 |
-
loss_fct = nn.CrossEntropyLoss()
|
| 1059 |
-
logits = logits.view(-1, logits.shape[-1])
|
| 1060 |
-
labels = labels.view(-1)
|
| 1061 |
-
# Enable model parallelism
|
| 1062 |
-
labels = labels.to(logits.device)
|
| 1063 |
-
loss = loss_fct(logits, labels)
|
| 1064 |
-
return loss
|
| 1065 |
-
|
| 1066 |
-
if self.is_encoder_decoder:
|
| 1067 |
-
labels = concatenated_batch["concatenated_labels"].clone()
|
| 1068 |
-
else:
|
| 1069 |
-
labels = concatenated_batch["concatenated_input_ids"].clone()
|
| 1070 |
-
attention_mask = concatenated_batch["concatenated_attention_mask"]
|
| 1071 |
-
labels = torch.where(attention_mask == 1, labels, self.label_pad_token_id)
|
| 1072 |
-
# orpo chosen nll loss is computed over the full prompt and response
|
| 1073 |
-
chosen_nll_loss = cross_entropy_loss(all_logits[:len_chosen], labels[:len_chosen])
|
| 1074 |
-
|
| 1075 |
-
all_logps = self.get_batch_logps(
|
| 1076 |
-
all_logits,
|
| 1077 |
-
concatenated_batch["concatenated_labels"],
|
| 1078 |
-
average_log_prob=True,
|
| 1079 |
-
is_encoder_decoder=self.is_encoder_decoder,
|
| 1080 |
-
label_pad_token_id=self.label_pad_token_id,
|
| 1081 |
-
)
|
| 1082 |
-
|
| 1083 |
-
chosen_logps = all_logps[:len_chosen]
|
| 1084 |
-
rejected_logps = all_logps[len_chosen:]
|
| 1085 |
-
|
| 1086 |
-
if not self.is_encoder_decoder:
|
| 1087 |
-
chosen_logits = all_logits[:len_chosen, :-1, :]
|
| 1088 |
-
rejected_logits = all_logits[len_chosen:, :-1, :]
|
| 1089 |
-
else:
|
| 1090 |
-
chosen_logits = all_logits[:len_chosen]
|
| 1091 |
-
rejected_logits = all_logits[len_chosen:]
|
| 1092 |
-
|
| 1093 |
-
if self.aux_loss_enabled:
|
| 1094 |
-
return (chosen_logps, rejected_logps, chosen_logits, rejected_logits, chosen_nll_loss, outputs.aux_loss)
|
| 1095 |
-
|
| 1096 |
-
return (chosen_logps, rejected_logps, chosen_logits, rejected_logits, chosen_nll_loss)
|
| 1097 |
-
|
| 1098 |
-
def get_batch_loss_metrics(
|
| 1099 |
-
self,
|
| 1100 |
-
model,
|
| 1101 |
-
batch: dict[str, Union[list, torch.LongTensor]],
|
| 1102 |
-
train_eval: Literal["train", "eval"] = "train",
|
| 1103 |
-
):
|
| 1104 |
-
"""Compute the ORPO loss and other metrics for the given batch of inputs for train or test."""
|
| 1105 |
-
metrics = {}
|
| 1106 |
-
|
| 1107 |
-
forward_output = self.concatenated_forward(model, batch)
|
| 1108 |
-
(
|
| 1109 |
-
policy_chosen_logps,
|
| 1110 |
-
policy_rejected_logps,
|
| 1111 |
-
policy_chosen_logits,
|
| 1112 |
-
policy_rejected_logits,
|
| 1113 |
-
policy_nll_loss,
|
| 1114 |
-
) = forward_output[:5]
|
| 1115 |
-
if self.aux_loss_enabled:
|
| 1116 |
-
aux_loss = forward_output[5]
|
| 1117 |
-
|
| 1118 |
-
losses, chosen_rewards, rejected_rewards, log_odds_ratio, log_odds_chosen = self.odds_ratio_loss(
|
| 1119 |
-
policy_chosen_logps, policy_rejected_logps
|
| 1120 |
-
)
|
| 1121 |
-
# full ORPO loss
|
| 1122 |
-
loss = policy_nll_loss - losses.mean()
|
| 1123 |
-
|
| 1124 |
-
reward_accuracies = (chosen_rewards > rejected_rewards).float()
|
| 1125 |
-
|
| 1126 |
-
prefix = "eval_" if train_eval == "eval" else ""
|
| 1127 |
-
metrics[f"{prefix}rewards/chosen"] = self.accelerator.gather_for_metrics(chosen_rewards).mean()
|
| 1128 |
-
metrics[f"{prefix}rewards/rejected"] = self.accelerator.gather_for_metrics(rejected_rewards).mean()
|
| 1129 |
-
metrics[f"{prefix}rewards/accuracies"] = self.accelerator.gather_for_metrics(reward_accuracies).mean()
|
| 1130 |
-
metrics[f"{prefix}rewards/margins"] = self.accelerator.gather_for_metrics(
|
| 1131 |
-
chosen_rewards - rejected_rewards
|
| 1132 |
-
).mean()
|
| 1133 |
-
metrics[f"{prefix}logps/rejected"] = self.accelerator.gather_for_metrics(policy_rejected_logps).detach().mean()
|
| 1134 |
-
metrics[f"{prefix}logps/chosen"] = self.accelerator.gather_for_metrics(policy_chosen_logps).detach().mean()
|
| 1135 |
-
metrics[f"{prefix}logits/rejected"] = (
|
| 1136 |
-
self.accelerator.gather_for_metrics(policy_rejected_logits).detach().mean()
|
| 1137 |
-
)
|
| 1138 |
-
metrics[f"{prefix}logits/chosen"] = self.accelerator.gather_for_metrics(policy_chosen_logits).detach().mean()
|
| 1139 |
-
metrics[f"{prefix}nll_loss"] = self.accelerator.gather_for_metrics(policy_nll_loss).detach().mean()
|
| 1140 |
-
metrics[f"{prefix}log_odds_ratio"] = self.accelerator.gather_for_metrics(log_odds_ratio).mean()
|
| 1141 |
-
metrics[f"{prefix}log_odds_chosen"] = self.accelerator.gather_for_metrics(log_odds_chosen).mean()
|
| 1142 |
-
if is_torch_xla_available():
|
| 1143 |
-
xm.mark_step() # needed because .item() calls
|
| 1144 |
-
for k, v in metrics.items():
|
| 1145 |
-
metrics[k] = v.item()
|
| 1146 |
-
if self.aux_loss_enabled:
|
| 1147 |
-
loss += self.aux_loss_coef * aux_loss
|
| 1148 |
-
|
| 1149 |
-
return loss, metrics
|
| 1150 |
-
|
| 1151 |
-
def compute_loss(
|
| 1152 |
-
self,
|
| 1153 |
-
model: Union[PreTrainedModel, nn.Module],
|
| 1154 |
-
inputs: dict[str, Union[torch.Tensor, Any]],
|
| 1155 |
-
return_outputs=False,
|
| 1156 |
-
num_items_in_batch=None,
|
| 1157 |
-
) -> Union[torch.Tensor, tuple[torch.Tensor, dict[str, torch.Tensor]]]:
|
| 1158 |
-
compute_loss_context_manager = amp.autocast("cuda") if self._peft_has_been_casted_to_bf16 else nullcontext()
|
| 1159 |
-
|
| 1160 |
-
with compute_loss_context_manager:
|
| 1161 |
-
loss, metrics = self.get_batch_loss_metrics(model, inputs, train_eval="train")
|
| 1162 |
-
|
| 1163 |
-
# Make sure to move the loss to the device the original accumulating loss is at back in the `Trainer` class:
|
| 1164 |
-
loss = loss.to(self.args.device)
|
| 1165 |
-
|
| 1166 |
-
# force log the metrics
|
| 1167 |
-
self.store_metrics(metrics, train_eval="train")
|
| 1168 |
-
|
| 1169 |
-
if return_outputs:
|
| 1170 |
-
return (loss, metrics)
|
| 1171 |
-
return loss
|
| 1172 |
-
|
| 1173 |
-
def generate_from_model(self, model, batch: dict[str, torch.LongTensor]) -> str:
|
| 1174 |
-
"""Generate samples from the model and reference model for the given batch of inputs."""
|
| 1175 |
-
|
| 1176 |
-
# If one uses `generate_during_eval` with peft + bf16, we need to explicitly call generate with
|
| 1177 |
-
# the torch cuda amp context manager as some hidden states are silently casted to full precision.
|
| 1178 |
-
generate_context_manager = amp.autocast("cuda") if self._peft_has_been_casted_to_bf16 else nullcontext()
|
| 1179 |
-
|
| 1180 |
-
with generate_context_manager:
|
| 1181 |
-
policy_output = model.generate(
|
| 1182 |
-
input_ids=batch["prompt_input_ids"],
|
| 1183 |
-
attention_mask=batch["prompt_attention_mask"],
|
| 1184 |
-
max_length=self.max_length,
|
| 1185 |
-
do_sample=True,
|
| 1186 |
-
pad_token_id=self.processing_class.pad_token_id,
|
| 1187 |
-
)
|
| 1188 |
-
|
| 1189 |
-
policy_output = pad_to_length(policy_output, self.max_length, self.processing_class.pad_token_id)
|
| 1190 |
-
policy_output_decoded = self.processing_class.batch_decode(policy_output, skip_special_tokens=True)
|
| 1191 |
-
|
| 1192 |
-
return policy_output_decoded
|
| 1193 |
-
|
| 1194 |
-
def prediction_step(
|
| 1195 |
-
self,
|
| 1196 |
-
model: Union[PreTrainedModel, nn.Module],
|
| 1197 |
-
inputs: dict[str, Union[torch.Tensor, Any]],
|
| 1198 |
-
prediction_loss_only: bool,
|
| 1199 |
-
ignore_keys: Optional[list[str]] = None,
|
| 1200 |
-
):
|
| 1201 |
-
if not self.use_dpo_data_collator:
|
| 1202 |
-
warnings.warn(
|
| 1203 |
-
"prediction_step is only implemented for DPODataCollatorWithPadding, and you passed a datacollator that is different than "
|
| 1204 |
-
"DPODataCollatorWithPadding - you might see unexpected behavior. Alternatively, you can implement your own prediction_step method if you are using a custom data collator"
|
| 1205 |
-
)
|
| 1206 |
-
if ignore_keys is None:
|
| 1207 |
-
if hasattr(model, "config"):
|
| 1208 |
-
ignore_keys = getattr(model.config, "keys_to_ignore_at_inference", [])
|
| 1209 |
-
else:
|
| 1210 |
-
ignore_keys = []
|
| 1211 |
-
|
| 1212 |
-
prediction_context_manager = amp.autocast("cuda") if self._peft_has_been_casted_to_bf16 else nullcontext()
|
| 1213 |
-
|
| 1214 |
-
with torch.no_grad(), prediction_context_manager:
|
| 1215 |
-
loss, metrics = self.get_batch_loss_metrics(model, inputs, train_eval="eval")
|
| 1216 |
-
|
| 1217 |
-
# force log the metrics
|
| 1218 |
-
self.store_metrics(metrics, train_eval="eval")
|
| 1219 |
-
|
| 1220 |
-
if prediction_loss_only:
|
| 1221 |
-
return (loss.detach(), None, None)
|
| 1222 |
-
|
| 1223 |
-
# logits for the chosen and rejected samples from model
|
| 1224 |
-
logits_dict = {
|
| 1225 |
-
"eval_logits/chosen": metrics["eval_logits/chosen"],
|
| 1226 |
-
"eval_logits/rejected": metrics["eval_logits/rejected"],
|
| 1227 |
-
}
|
| 1228 |
-
logits = [v for k, v in logits_dict.items() if k not in ignore_keys]
|
| 1229 |
-
logits = torch.tensor(logits, device=self.accelerator.device)
|
| 1230 |
-
labels = torch.zeros(logits.shape[0], device=self.accelerator.device)
|
| 1231 |
-
|
| 1232 |
-
return (loss.detach(), logits, labels)
|
| 1233 |
-
|
| 1234 |
-
def store_metrics(self, metrics: dict[str, float], train_eval: Literal["train", "eval"] = "train") -> None:
|
| 1235 |
-
for key, value in metrics.items():
|
| 1236 |
-
self._stored_metrics[train_eval][key].append(value)
|
| 1237 |
-
|
| 1238 |
-
def evaluation_loop(
|
| 1239 |
-
self,
|
| 1240 |
-
dataloader: DataLoader,
|
| 1241 |
-
description: str,
|
| 1242 |
-
prediction_loss_only: Optional[bool] = None,
|
| 1243 |
-
ignore_keys: Optional[list[str]] = None,
|
| 1244 |
-
metric_key_prefix: str = "eval",
|
| 1245 |
-
) -> EvalLoopOutput:
|
| 1246 |
-
"""
|
| 1247 |
-
Overriding built-in evaluation loop to store metrics for each batch.
|
| 1248 |
-
Prediction/evaluation loop, shared by `Trainer.evaluate()` and `Trainer.predict()`.
|
| 1249 |
-
|
| 1250 |
-
Works both with or without labels.
|
| 1251 |
-
"""
|
| 1252 |
-
|
| 1253 |
-
# Sample and save to game log if requested (for one batch to save time)
|
| 1254 |
-
if self.generate_during_eval:
|
| 1255 |
-
# Generate random indices within the range of the total number of samples
|
| 1256 |
-
num_samples = len(dataloader.dataset)
|
| 1257 |
-
random_indices = random.sample(range(num_samples), k=self.args.eval_batch_size)
|
| 1258 |
-
|
| 1259 |
-
# Use dataloader.dataset.select to get the random batch without iterating over the DataLoader
|
| 1260 |
-
random_batch_dataset = dataloader.dataset.select(random_indices)
|
| 1261 |
-
random_batch = self.data_collator(random_batch_dataset)
|
| 1262 |
-
random_batch = self._prepare_inputs(random_batch)
|
| 1263 |
-
|
| 1264 |
-
policy_output_decoded = self.generate_from_model(self.model, random_batch)
|
| 1265 |
-
|
| 1266 |
-
table = pd.DataFrame(
|
| 1267 |
-
columns=["Prompt", "Policy"],
|
| 1268 |
-
data=[
|
| 1269 |
-
[prompt, pol[len(prompt) :]] for prompt, pol in zip(random_batch["prompt"], policy_output_decoded)
|
| 1270 |
-
],
|
| 1271 |
-
)
|
| 1272 |
-
if "wandb" in self.args.report_to:
|
| 1273 |
-
wandb.log({"game_log": wandb.Table(data=table)})
|
| 1274 |
-
|
| 1275 |
-
if "comet_ml" in self.args.report_to:
|
| 1276 |
-
log_table_to_comet_experiment(
|
| 1277 |
-
name="game_log.csv",
|
| 1278 |
-
table=table,
|
| 1279 |
-
)
|
| 1280 |
-
|
| 1281 |
-
# Base evaluation
|
| 1282 |
-
initial_output = super().evaluation_loop(
|
| 1283 |
-
dataloader, description, prediction_loss_only, ignore_keys, metric_key_prefix
|
| 1284 |
-
)
|
| 1285 |
-
|
| 1286 |
-
return initial_output
|
| 1287 |
-
|
| 1288 |
-
def log(self, logs: dict[str, float], start_time: Optional[float] = None) -> None:
|
| 1289 |
-
"""
|
| 1290 |
-
Log `logs` on the various objects watching training, including stored metrics.
|
| 1291 |
-
|
| 1292 |
-
Args:
|
| 1293 |
-
logs (`dict[str, float]`):
|
| 1294 |
-
The values to log.
|
| 1295 |
-
start_time (`float` or `None`, *optional*, defaults to `None`):
|
| 1296 |
-
Start time of the training.
|
| 1297 |
-
"""
|
| 1298 |
-
# logs either has 'loss' or 'eval_loss'
|
| 1299 |
-
train_eval = "train" if "loss" in logs else "eval"
|
| 1300 |
-
# Add averaged stored metrics to logs
|
| 1301 |
-
for key, metrics in self._stored_metrics[train_eval].items():
|
| 1302 |
-
logs[key] = torch.tensor(metrics).mean().item()
|
| 1303 |
-
del self._stored_metrics[train_eval]
|
| 1304 |
-
|
| 1305 |
-
if version.parse(transformers.__version__) >= version.parse("4.47.0.dev0"):
|
| 1306 |
-
return super().log(logs, start_time)
|
| 1307 |
-
else: # transformers<=4.46
|
| 1308 |
-
return super().log(logs)
|
| 1309 |
-
|
| 1310 |
-
def _shift_right(self, input_ids):
|
| 1311 |
-
if self.decoder_start_token_id is None:
|
| 1312 |
-
raise ValueError(
|
| 1313 |
-
"model.config.decoder_start_token_id has to be defined. It is usually set to the pad_token_id."
|
| 1314 |
-
)
|
| 1315 |
-
|
| 1316 |
-
# shift inputs to the right
|
| 1317 |
-
if is_torch_fx_proxy(input_ids):
|
| 1318 |
-
# Item assignment is not supported natively for proxies.
|
| 1319 |
-
shifted_input_ids = torch.full(input_ids.shape[:-1] + (1,), self.decoder_start_token_id)
|
| 1320 |
-
shifted_input_ids = torch.cat([shifted_input_ids, input_ids[..., :-1]], dim=-1)
|
| 1321 |
-
else:
|
| 1322 |
-
shifted_input_ids = input_ids.new_zeros(input_ids.shape)
|
| 1323 |
-
shifted_input_ids[..., 1:] = input_ids[..., :-1].clone()
|
| 1324 |
-
shifted_input_ids[..., 0] = self.decoder_start_token_id
|
| 1325 |
-
|
| 1326 |
-
if self.pad_token_id is None:
|
| 1327 |
-
raise ValueError("model.config.pad_token_id has to be defined.")
|
| 1328 |
-
# replace possible -100 values in labels by `pad_token_id`
|
| 1329 |
-
shifted_input_ids.masked_fill_(shifted_input_ids == -100, self.pad_token_id)
|
| 1330 |
-
|
| 1331 |
-
return shifted_input_ids
|
| 1332 |
-
|
| 1333 |
-
def create_model_card(
|
| 1334 |
-
self,
|
| 1335 |
-
model_name: Optional[str] = None,
|
| 1336 |
-
dataset_name: Optional[str] = None,
|
| 1337 |
-
tags: Union[str, list[str], None] = None,
|
| 1338 |
-
):
|
| 1339 |
-
"""
|
| 1340 |
-
Creates a draft of a model card using the information available to the `Trainer`.
|
| 1341 |
-
|
| 1342 |
-
Args:
|
| 1343 |
-
model_name (`str` or `None`, *optional*, defaults to `None`):
|
| 1344 |
-
Name of the model.
|
| 1345 |
-
dataset_name (`str` or `None`, *optional*, defaults to `None`):
|
| 1346 |
-
Name of the dataset used for training.
|
| 1347 |
-
tags (`str`, `list[str]` or `None`, *optional*, defaults to `None`):
|
| 1348 |
-
Tags to be associated with the model card.
|
| 1349 |
-
"""
|
| 1350 |
-
if not self.is_world_process_zero():
|
| 1351 |
-
return
|
| 1352 |
-
|
| 1353 |
-
if hasattr(self.model.config, "_name_or_path") and not os.path.isdir(self.model.config._name_or_path):
|
| 1354 |
-
base_model = self.model.config._name_or_path
|
| 1355 |
-
else:
|
| 1356 |
-
base_model = None
|
| 1357 |
-
|
| 1358 |
-
tags = tags or []
|
| 1359 |
-
if isinstance(tags, str):
|
| 1360 |
-
tags = [tags]
|
| 1361 |
-
|
| 1362 |
-
if hasattr(self.model.config, "unsloth_version"):
|
| 1363 |
-
tags.append("unsloth")
|
| 1364 |
-
|
| 1365 |
-
citation = textwrap.dedent("""\
|
| 1366 |
-
@article{hong2024orpo,
|
| 1367 |
-
title = {{ORPO: Monolithic Preference Optimization without Reference Model}},
|
| 1368 |
-
author = {Jiwoo Hong and Noah Lee and James Thorne},
|
| 1369 |
-
year = 2024,
|
| 1370 |
-
eprint = {arXiv:2403.07691}
|
| 1371 |
-
}""")
|
| 1372 |
-
|
| 1373 |
-
model_card = generate_model_card(
|
| 1374 |
-
base_model=base_model,
|
| 1375 |
-
model_name=model_name,
|
| 1376 |
-
hub_model_id=self.hub_model_id,
|
| 1377 |
-
dataset_name=dataset_name,
|
| 1378 |
-
tags=tags,
|
| 1379 |
-
wandb_url=wandb.run.get_url() if is_wandb_available() and wandb.run is not None else None,
|
| 1380 |
-
comet_url=get_comet_experiment_url(),
|
| 1381 |
-
trainer_name="ORPO",
|
| 1382 |
-
trainer_citation=citation,
|
| 1383 |
-
paper_title="ORPO: Monolithic Preference Optimization without Reference Model",
|
| 1384 |
-
paper_id="2403.07691",
|
| 1385 |
-
)
|
| 1386 |
-
|
| 1387 |
-
model_card.save(os.path.join(self.args.output_dir, "README.md"))
|
| 1388 |
-
class UnslothORPOTrainer(_UnslothORPOTrainer):
|
| 1389 |
-
"""
|
| 1390 |
-
|
| 1391 |
-
Initialize ORPOTrainer.
|
| 1392 |
-
|
| 1393 |
-
Args:
|
| 1394 |
-
model (`transformers.PreTrainedModel`):
|
| 1395 |
-
The model to train, preferably an `AutoModelForSequenceClassification`.
|
| 1396 |
-
args (`ORPOConfig`):
|
| 1397 |
-
The ORPO config arguments to use for training.
|
| 1398 |
-
data_collator (`transformers.DataCollator`):
|
| 1399 |
-
The data collator to use for training. If None is specified, the default data collator (`DPODataCollatorWithPadding`) will be used
|
| 1400 |
-
which will pad the sequences to the maximum length of the sequences in the batch, given a dataset of paired sequences.
|
| 1401 |
-
train_dataset (`datasets.Dataset`):
|
| 1402 |
-
The dataset to use for training.
|
| 1403 |
-
eval_dataset (`datasets.Dataset`):
|
| 1404 |
-
The dataset to use for evaluation.
|
| 1405 |
-
processing_class (`PreTrainedTokenizerBase` or `BaseImageProcessor` or `FeatureExtractionMixin` or `ProcessorMixin`, *optional*):
|
| 1406 |
-
Processing class used to process the data. If provided, will be used to automatically process the inputs
|
| 1407 |
-
for the model, and it will be saved along the model to make it easier to rerun an interrupted training or
|
| 1408 |
-
reuse the fine-tuned model.
|
| 1409 |
-
model_init (`Callable[[], transformers.PreTrainedModel]`):
|
| 1410 |
-
The model initializer to use for training. If None is specified, the default model initializer will be used.
|
| 1411 |
-
callbacks (`list[transformers.TrainerCallback]`):
|
| 1412 |
-
The callbacks to use for training.
|
| 1413 |
-
optimizers (`tuple[torch.optim.Optimizer, torch.optim.lr_scheduler.LambdaLR]`):
|
| 1414 |
-
The optimizer and scheduler to use for training.
|
| 1415 |
-
preprocess_logits_for_metrics (`Callable[[torch.Tensor, torch.Tensor], torch.Tensor]`):
|
| 1416 |
-
The function to use to preprocess the logits before computing the metrics.
|
| 1417 |
-
peft_config (`dict`, defaults to `None`):
|
| 1418 |
-
The PEFT configuration to use for training. If you pass a PEFT configuration, the model will be wrapped in a PEFT model.
|
| 1419 |
-
compute_metrics (`Callable[[EvalPrediction], dict]`, *optional*):
|
| 1420 |
-
The function to use to compute the metrics. Must take a `EvalPrediction` and return
|
| 1421 |
-
a dictionary string to metric values.
|
| 1422 |
-
|
| 1423 |
-
"""
|
| 1424 |
-
def __init__(
|
| 1425 |
-
self,
|
| 1426 |
-
model = None,
|
| 1427 |
-
args = None,
|
| 1428 |
-
data_collator = None,
|
| 1429 |
-
train_dataset = None,
|
| 1430 |
-
eval_dataset = None,
|
| 1431 |
-
processing_class = None,
|
| 1432 |
-
model_init = None,
|
| 1433 |
-
callbacks = None,
|
| 1434 |
-
preprocess_logits_for_metrics = None,
|
| 1435 |
-
peft_config = None,
|
| 1436 |
-
compute_metrics = None,
|
| 1437 |
-
**kwargs
|
| 1438 |
-
):
|
| 1439 |
-
if args is None: args = UnslothORPOConfig()
|
| 1440 |
-
use_bf16 = getattr(args, 'bf16', False)
|
| 1441 |
-
if type(use_bf16) is not bool: use_bf16 = False
|
| 1442 |
-
use_fp16 = getattr(args, 'fp16', False)
|
| 1443 |
-
if type(use_fp16) is not bool: use_fp16 = False
|
| 1444 |
-
force_float32 = False
|
| 1445 |
-
if os.environ.get('UNSLOTH_FORCE_FLOAT32', '0') == '1':
|
| 1446 |
-
print('Unsloth: Switching to float32 training since model cannot work with float16')
|
| 1447 |
-
force_float32 = True
|
| 1448 |
-
mixed_precision_dtype = os.environ.get('UNSLOTH_MIXED_PRECISION', 'float32')
|
| 1449 |
-
dtype = getattr(model.config, 'torch_dtype', None)
|
| 1450 |
-
if dtype is None: dtype = model.get_input_embeddings().dtype
|
| 1451 |
-
from unsloth_zoo.utils import _get_dtype
|
| 1452 |
-
dtype = _get_dtype(dtype)
|
| 1453 |
-
float16 = dtype == torch.float16
|
| 1454 |
-
if not force_float32 and (float16 and use_bf16): raise TypeError('Unsloth: Model is in float16 precision but you want to use bfloat16 precision. Set fp16 to `True` and bf16 to `False`')
|
| 1455 |
-
if not force_float32 and (not float16 and use_fp16): raise TypeError('Unsloth: Model is in bfloat16 precision but you want to use float16 precision. Set fp16 to `False` and bf16 to `True`')
|
| 1456 |
-
if force_float32:
|
| 1457 |
-
args.fp16 = False
|
| 1458 |
-
args.bf16 = False
|
| 1459 |
-
os.environ['ACCELERATE_MIXED_PRECISION'] = 'no'
|
| 1460 |
-
elif (not use_bf16 and not use_fp16) and mixed_precision_dtype == 'float32':
|
| 1461 |
-
args.fp16 = float16
|
| 1462 |
-
args.bf16 = not float16
|
| 1463 |
-
os.environ['ACCELERATE_MIXED_PRECISION'] = 'fp16' if float16 else 'bf16'
|
| 1464 |
-
if getattr(args, 'eval_dataset', None) is not None and getattr(args, 'eval_strategy', 'no') == 'no':
|
| 1465 |
-
args.eval_strategy = 'steps'
|
| 1466 |
-
if getattr(args, 'eval_steps', None) is None: args.eval_steps = 0.1
|
| 1467 |
-
ga_steps = getattr(args, 'gradient_accumulation_steps', None)
|
| 1468 |
-
if ga_steps is not None and ga_steps > 1:
|
| 1469 |
-
from transformers import __version__ as transformers_version
|
| 1470 |
-
if Version(transformers_version) <= Version('4.45.2'):
|
| 1471 |
-
print('**** Unsloth: Please use our fixed gradient_accumulation_steps by updating transformers, TRL and Unsloth!\n'
|
| 1472 |
-
'`pip install --upgrade --no-cache-dir --force-reinstall --no-deps unsloth transformers trl unsloth_zoo`')
|
| 1473 |
-
if getattr(args, 'eval_strategy', 'no') != 'no':
|
| 1474 |
-
eval_bsz = getattr(args, 'per_device_eval_batch_size', 8)
|
| 1475 |
-
if eval_bsz == 8 and args.per_device_train_batch_size < eval_bsz: args.per_device_eval_batch_size = args.per_device_train_batch_size
|
| 1476 |
-
if getattr(args, 'eval_accumulation_steps', None) is None and ga_steps is not None: args.eval_accumulation_steps = ga_steps
|
| 1477 |
-
fp16_full_eval = getattr(args, 'fp16_full_eval', False)
|
| 1478 |
-
if type(fp16_full_eval) is not bool: fp16_full_eval = False
|
| 1479 |
-
bf16_full_eval = getattr(args, 'bf16_full_eval', False)
|
| 1480 |
-
if type(bf16_full_eval) is not bool: bf16_full_eval = False
|
| 1481 |
-
if args.fp16 and bf16_full_eval: args.bf16_full_eval = False; args.fp16_full_eval = True
|
| 1482 |
-
if args.bf16 and fp16_full_eval: args.bf16_full_eval = True; args.fp16_full_eval = False
|
| 1483 |
-
if force_float32:
|
| 1484 |
-
args.bf16_full_eval = False
|
| 1485 |
-
args.fp16_full_eval = False
|
| 1486 |
-
elif os.environ.get('UNSLOTH_MIXED_PRECISION', 'float32') == 'bfloat16':
|
| 1487 |
-
args.bf16_full_eval = True
|
| 1488 |
-
args.fp16_full_eval = False
|
| 1489 |
-
elif not bf16_full_eval and not fp16_full_eval:
|
| 1490 |
-
args.bf16_full_eval = args.bf16
|
| 1491 |
-
args.fp16_full_eval = args.fp16
|
| 1492 |
-
_output_logits = False
|
| 1493 |
-
if locals().get('compute_metrics', None) is not None: _output_logits = True
|
| 1494 |
-
if locals().get('preprocess_logits_for_metrics', None) is not None: _output_logits = True
|
| 1495 |
-
if _output_logits:
|
| 1496 |
-
os.environ['UNSLOTH_RETURN_LOGITS'] = '1'
|
| 1497 |
-
if 'max_seq_length' not in locals() and not hasattr(args, 'max_seq_length'):
|
| 1498 |
-
pass
|
| 1499 |
-
else:
|
| 1500 |
-
model_max_seq_length = getattr(model, 'max_seq_length', None)
|
| 1501 |
-
args_max_seq_length = getattr(args, 'max_seq_length', None)
|
| 1502 |
-
if args_max_seq_length is None and model_max_seq_length is not None:
|
| 1503 |
-
max_seq_length = model.max_seq_length
|
| 1504 |
-
if hasattr(args, 'max_seq_length'): args.max_seq_length = max_seq_length
|
| 1505 |
-
if model is not None and hasattr(model, 'for_training'):
|
| 1506 |
-
model.for_training()
|
| 1507 |
-
if 'tokenizer' in locals() and hasattr(tokenizer, 'padding_side'): tokenizer.padding_side = 'right'
|
| 1508 |
-
if 'processing_class' in locals():
|
| 1509 |
-
if hasattr(processing_class, 'padding_side'): processing_class.padding_side = 'right'
|
| 1510 |
-
if hasattr(processing_class, 'tokenizer') and hasattr(processing_class.tokenizer, 'padding_side'): processing_class.tokenizer.padding_side = 'right'
|
| 1511 |
-
__tokenizer = processing_class if 'processing_class' in locals() else tokenizer
|
| 1512 |
-
from unsloth_zoo.vision_utils import UnslothVisionDataCollator
|
| 1513 |
-
if not isinstance(data_collator, UnslothVisionDataCollator):
|
| 1514 |
-
if isinstance(data_collator, DataCollatorForSeq2Seq) and 'labels' not in train_dataset.column_names:
|
| 1515 |
-
data_collator = TransformersDataCollatorForLanguageModeling(__tokenizer, mlm = False, mlm_probability = 0.0)
|
| 1516 |
-
elif isinstance(data_collator, TransformersDataCollatorForLanguageModeling) and 'labels' in train_dataset.column_names:
|
| 1517 |
-
data_collator = DataCollatorForSeq2Seq(__tokenizer)
|
| 1518 |
-
else:
|
| 1519 |
-
if hasattr(args, 'remove_unused_columns'): args.remove_unused_columns = False
|
| 1520 |
-
if hasattr(args, 'dataset_text_field'): args.dataset_text_field = ''
|
| 1521 |
-
if hasattr(args, 'dataset_kwargs'): args.dataset_kwargs = {'skip_prepare_dataset': True}
|
| 1522 |
-
if not isinstance(data_collator, UnslothVisionDataCollator):
|
| 1523 |
-
if not hasattr(__tokenizer, 'pad') and hasattr(__tokenizer, 'tokenizer'):
|
| 1524 |
-
if isinstance(data_collator, DataCollatorForSeq2Seq):
|
| 1525 |
-
data_collator = DataCollatorForSeq2Seq(__tokenizer.tokenizer)
|
| 1526 |
-
else:
|
| 1527 |
-
data_collator = TransformersDataCollatorForLanguageModeling(__tokenizer.tokenizer, mlm = False, mlm_probability = 0.0)
|
| 1528 |
-
other_metrics = []
|
| 1529 |
-
|
| 1530 |
-
from unsloth_zoo.logging_utils import PatchRLStatistics
|
| 1531 |
-
PatchRLStatistics('orpo_trainer', other_metrics)
|
| 1532 |
-
|
| 1533 |
-
super().__init__(
|
| 1534 |
-
model = model,
|
| 1535 |
-
args = args,
|
| 1536 |
-
data_collator = data_collator,
|
| 1537 |
-
train_dataset = train_dataset,
|
| 1538 |
-
eval_dataset = eval_dataset,
|
| 1539 |
-
processing_class = processing_class,
|
| 1540 |
-
model_init = model_init,
|
| 1541 |
-
callbacks = callbacks,
|
| 1542 |
-
preprocess_logits_for_metrics = preprocess_logits_for_metrics,
|
| 1543 |
-
peft_config = peft_config,
|
| 1544 |
-
compute_metrics = compute_metrics,**kwargs)
|
| 1545 |
-
if hasattr(self, 'neftune_hook_handle'):
|
| 1546 |
-
self.neftune_hook_handle.remove()
|
| 1547 |
-
if hasattr(self, 'neftune_hook_handle'): del self.neftune_hook_handle
|
| 1548 |
-
if getattr(args, 'neftune_noise_alpha', None) is not None:
|
| 1549 |
-
model.get_input_embeddings().neftune_noise_alpha = self.neftune_noise_alpha
|
| 1550 |
-
pass
|
| 1551 |
-
|
| 1552 |
-
pass
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
test_run_uploads/UnslothOnlineDPOTrainer.py
DELETED
|
@@ -1,1293 +0,0 @@
|
|
| 1 |
-
"""
|
| 2 |
-
2025.7.11
|
| 3 |
-
2025.7.11
|
| 4 |
-
4.54.1
|
| 5 |
-
0.16.1
|
| 6 |
-
__UNSLOTH_VERSIONING__
|
| 7 |
-
"""
|
| 8 |
-
from torch import Tensor
|
| 9 |
-
import torch
|
| 10 |
-
import torch.nn as nn
|
| 11 |
-
from torch.nn import functional as F
|
| 12 |
-
from typing import Any, List, Optional, Tuple, Union, Dict, Set, Callable
|
| 13 |
-
from trl.trainer.online_dpo_trainer import (Any, BaseImageProcessor, BasePairwiseJudge, Callable, DPODataCollatorWithPadding, DataCollator, DataLoader, Dataset, EvalPrediction, F, FeatureExtractionMixin, GenerationConfig, IterableDataset, OnlineDPOConfig, OnlineDPOTrainer, OptimizerNames, Optional, PREFIX_CHECKPOINT_DIR, PeftModel, PreTrainedModel, PreTrainedTokenizerBase, ProcessorMixin, SIMPLE_CHAT_TEMPLATE, Trainer, TrainerCallback, Union, apply_chat_template, create_reference_model, datasets, disable_dropout_in_model, empty_cache, generate_model_card, get_comet_experiment_url, get_reward, is_conversational, is_peft_available, is_wandb_available, jinja2, logging, maybe_apply_chat_template, nn, np, os, prepare_deepspeed, seed_worker, textwrap, torch, transformers, truncate_right, unwrap_model_for_generation, version, wandb, warnings, wraps, F, is_conversational, os, torch)
|
| 14 |
-
|
| 15 |
-
|
| 16 |
-
import os
|
| 17 |
-
from typing import *
|
| 18 |
-
from dataclasses import dataclass, field
|
| 19 |
-
from packaging.version import Version
|
| 20 |
-
import torch
|
| 21 |
-
import numpy as np
|
| 22 |
-
from contextlib import nullcontext
|
| 23 |
-
from torch.nn import functional as F
|
| 24 |
-
from transformers import DataCollatorForSeq2Seq, DataCollatorForLanguageModeling as TransformersDataCollatorForLanguageModeling
|
| 25 |
-
|
| 26 |
-
torch_compile_options = {
|
| 27 |
-
"epilogue_fusion" : True,
|
| 28 |
-
"max_autotune" : False,
|
| 29 |
-
"shape_padding" : True,
|
| 30 |
-
"trace.enabled" : False,
|
| 31 |
-
"triton.cudagraphs" : False,
|
| 32 |
-
}
|
| 33 |
-
|
| 34 |
-
@torch.compile(dynamic = True, fullgraph = True, options = torch_compile_options,)
|
| 35 |
-
def chunked_selective_log_softmax(logits, index):
|
| 36 |
-
# Split into 4 chunks only
|
| 37 |
-
chunked_logits = torch.chunk(logits.reshape(-1, logits.shape[-1]), chunks = 4, dim = 0)
|
| 38 |
-
chunked_index = torch.chunk(index.reshape(-1), chunks = 4, dim = 0)
|
| 39 |
-
all_per_token_logps = []
|
| 40 |
-
# Below loop does the same as selective_log_softmax(chunk_logits, chunk_index)
|
| 41 |
-
for chunk_logits, chunk_index in zip(chunked_logits, chunked_index):
|
| 42 |
-
chunk_logits = chunk_logits.to(torch.float32)
|
| 43 |
-
selected_logits = torch.gather(chunk_logits, dim = -1, index = chunk_index.unsqueeze(-1)).squeeze(-1)
|
| 44 |
-
logsumexp_values = torch.logsumexp(chunk_logits, dim = -1)
|
| 45 |
-
per_token_logps = selected_logits - logsumexp_values
|
| 46 |
-
all_per_token_logps.append(per_token_logps)
|
| 47 |
-
pass
|
| 48 |
-
all_per_token_logps = torch.concat(all_per_token_logps)
|
| 49 |
-
all_per_token_logps = all_per_token_logps.reshape((logits.shape[0], logits.shape[1]))
|
| 50 |
-
return all_per_token_logps
|
| 51 |
-
def vLLMSamplingParams(**kwargs):
|
| 52 |
-
from vllm import SamplingParams
|
| 53 |
-
sampling_params = SamplingParams(**kwargs)
|
| 54 |
-
sampling_params._set_kwargs = kwargs
|
| 55 |
-
return sampling_params
|
| 56 |
-
@dataclass
|
| 57 |
-
class UnslothOnlineDPOConfig(OnlineDPOConfig):
|
| 58 |
-
"""
|
| 59 |
-
|
| 60 |
-
Configuration class for the [`OnlineDPOTrainer`].
|
| 61 |
-
|
| 62 |
-
Using [`~transformers.HfArgumentParser`] we can turn this class into
|
| 63 |
-
[argparse](https://docs.python.org/3/library/argparse#module-argparse) arguments that can be specified on the
|
| 64 |
-
command line.
|
| 65 |
-
|
| 66 |
-
Parameters:
|
| 67 |
-
learning_rate (`float`, *optional*, defaults to `5e-7`):
|
| 68 |
-
Initial learning rate for [`AdamW`] optimizer. The default value replaces that of
|
| 69 |
-
[`~transformers.TrainingArguments`].
|
| 70 |
-
reward_model_path (`str` or `None`, *optional*, defaults to `None`):
|
| 71 |
-
Path to the reward model. Either `judge` or `reward_model_path` must be set, but not both.
|
| 72 |
-
judge (`str` or `None`, *optional*, defaults to `None`):
|
| 73 |
-
Name of the judge to use. Either `judge` or `reward_model_path` must be set, but not both.
|
| 74 |
-
max_new_tokens (`int`, *optional*, defaults to `64`):
|
| 75 |
-
Maximum number of tokens to generate per completion.
|
| 76 |
-
max_length (`int`, *optional*, defaults to `256`):
|
| 77 |
-
Maximum total length of the sequence (prompt + completion) used to compute log probabilities. If the
|
| 78 |
-
sequence exceeds this limit, the leftmost tokens will be truncated to preserve as much of the completion as
|
| 79 |
-
possible.
|
| 80 |
-
temperature (`float`, *optional*, defaults to `0.9`):
|
| 81 |
-
Temperature for sampling. The higher the temperature, the more random the completions.
|
| 82 |
-
missing_eos_penalty (`float` or `None`, *optional*, defaults to `None`):
|
| 83 |
-
Penalty applied to the score when the model fails to generate an EOS token. This is useful to encourage
|
| 84 |
-
to generate completions shorter than the maximum length (`max_new_tokens`). The penalty must be a positive
|
| 85 |
-
value.
|
| 86 |
-
beta (`float` or `list[float]`, *optional*, defaults to `0.1`):
|
| 87 |
-
Parameter controlling the deviation from the reference model. Higher β means less deviation from the
|
| 88 |
-
reference model. For the IPO loss (`loss_type="ipo"`), β is the regularization parameter denoted by τ in
|
| 89 |
-
the [paper](https://huggingface.co/papers/2310.12036). If a list of floats is provided then the β is
|
| 90 |
-
selected for each new epoch and the last β is used for the rest of the epochs.
|
| 91 |
-
loss_type (`str`, *optional*, defaults to `"sigmoid"`):
|
| 92 |
-
Type of loss to use. Possible values are:
|
| 93 |
-
|
| 94 |
-
- `"sigmoid"`: sigmoid loss from the original [DPO](https://huggingface.co/papers/2305.18290) paper.
|
| 95 |
-
- `"ipo"`: IPO loss from the [IPO](https://huggingface.co/papers/2310.12036) paper.
|
| 96 |
-
|
| 97 |
-
dataset_num_proc (`int` or `None`, *optional*, defaults to `None`):
|
| 98 |
-
Number of processes to use for processing the dataset.
|
| 99 |
-
disable_dropout (`bool`, *optional*, defaults to `True`):
|
| 100 |
-
Whether to disable dropout in the model and reference model.
|
| 101 |
-
use_vllm (`bool`, *optional*, defaults to `False`):
|
| 102 |
-
Whether to use vLLM for generating completions. Requires vLLM to be installed (`pip install vllm`).
|
| 103 |
-
ds3_gather_for_generation (`bool`, *optional*, defaults to `True`):
|
| 104 |
-
This setting applies to DeepSpeed ZeRO-3. If enabled, the policy model weights are gathered for generation,
|
| 105 |
-
improving generation speed. However, disabling this option allows training models that exceed the VRAM
|
| 106 |
-
capacity of a single GPU, albeit at the cost of slower generation.
|
| 107 |
-
|
| 108 |
-
"""
|
| 109 |
-
vllm_sampling_params: Optional[Any] = field(
|
| 110 |
-
default = None,
|
| 111 |
-
metadata = {'help': 'vLLM SamplingParams'},
|
| 112 |
-
)
|
| 113 |
-
unsloth_num_chunks : Optional[int] = field(
|
| 114 |
-
default = -1,
|
| 115 |
-
metadata = {'help': 'Chunk size to reduce memory usage. -1 is most efficient.'},
|
| 116 |
-
)
|
| 117 |
-
def __init__(
|
| 118 |
-
self,
|
| 119 |
-
output_dir = None,
|
| 120 |
-
overwrite_output_dir = None,
|
| 121 |
-
do_train = False,
|
| 122 |
-
do_eval = False,
|
| 123 |
-
do_predict = False,
|
| 124 |
-
eval_strategy = 'no',
|
| 125 |
-
prediction_loss_only = False,
|
| 126 |
-
per_device_train_batch_size = 4,
|
| 127 |
-
per_device_eval_batch_size = 4,
|
| 128 |
-
per_gpu_train_batch_size = None,
|
| 129 |
-
per_gpu_eval_batch_size = None,
|
| 130 |
-
gradient_accumulation_steps = 2,
|
| 131 |
-
eval_accumulation_steps = 2,
|
| 132 |
-
eval_delay = 0,
|
| 133 |
-
torch_empty_cache_steps = 250,
|
| 134 |
-
learning_rate = 5e-05,
|
| 135 |
-
weight_decay = 0.01,
|
| 136 |
-
adam_beta1 = 0.9,
|
| 137 |
-
adam_beta2 = 0.999,
|
| 138 |
-
adam_epsilon = 1e-08,
|
| 139 |
-
max_grad_norm = 1.0,
|
| 140 |
-
num_train_epochs = 3.0,
|
| 141 |
-
max_steps = -1,
|
| 142 |
-
lr_scheduler_type = 'linear',
|
| 143 |
-
warmup_ratio = 0.1,
|
| 144 |
-
warmup_steps = 0,
|
| 145 |
-
log_level = 'passive',
|
| 146 |
-
log_level_replica = 'warning',
|
| 147 |
-
log_on_each_node = True,
|
| 148 |
-
logging_dir = None,
|
| 149 |
-
logging_strategy = 'steps',
|
| 150 |
-
logging_first_step = False,
|
| 151 |
-
logging_steps = 1,
|
| 152 |
-
logging_nan_inf_filter = False,
|
| 153 |
-
save_strategy = 'steps',
|
| 154 |
-
save_steps = 500,
|
| 155 |
-
save_total_limit = None,
|
| 156 |
-
save_safetensors = True,
|
| 157 |
-
save_on_each_node = False,
|
| 158 |
-
save_only_model = False,
|
| 159 |
-
restore_callback_states_from_checkpoint = False,
|
| 160 |
-
no_cuda = False,
|
| 161 |
-
use_cpu = False,
|
| 162 |
-
use_mps_device = False,
|
| 163 |
-
seed = 3407,
|
| 164 |
-
data_seed = 3407,
|
| 165 |
-
jit_mode_eval = False,
|
| 166 |
-
use_ipex = False,
|
| 167 |
-
bf16 = False,
|
| 168 |
-
fp16 = False,
|
| 169 |
-
fp16_opt_level = 'O1',
|
| 170 |
-
half_precision_backend = 'auto',
|
| 171 |
-
bf16_full_eval = False,
|
| 172 |
-
fp16_full_eval = False,
|
| 173 |
-
tf32 = None,
|
| 174 |
-
local_rank = -1,
|
| 175 |
-
ddp_backend = None,
|
| 176 |
-
tpu_num_cores = None,
|
| 177 |
-
tpu_metrics_debug = False,
|
| 178 |
-
debug = '',
|
| 179 |
-
dataloader_drop_last = False,
|
| 180 |
-
eval_steps = None,
|
| 181 |
-
dataloader_num_workers = 0,
|
| 182 |
-
dataloader_prefetch_factor = None,
|
| 183 |
-
past_index = -1,
|
| 184 |
-
run_name = None,
|
| 185 |
-
disable_tqdm = None,
|
| 186 |
-
remove_unused_columns = True,
|
| 187 |
-
label_names = None,
|
| 188 |
-
load_best_model_at_end = False,
|
| 189 |
-
metric_for_best_model = None,
|
| 190 |
-
greater_is_better = None,
|
| 191 |
-
ignore_data_skip = False,
|
| 192 |
-
fsdp = '',
|
| 193 |
-
fsdp_min_num_params = 0,
|
| 194 |
-
fsdp_config = None,
|
| 195 |
-
fsdp_transformer_layer_cls_to_wrap = None,
|
| 196 |
-
accelerator_config = None,
|
| 197 |
-
deepspeed = None,
|
| 198 |
-
label_smoothing_factor = 0.0,
|
| 199 |
-
optim = 'adamw_8bit',
|
| 200 |
-
optim_args = None,
|
| 201 |
-
adafactor = False,
|
| 202 |
-
group_by_length = False,
|
| 203 |
-
length_column_name = 'length',
|
| 204 |
-
report_to = None,
|
| 205 |
-
ddp_find_unused_parameters = None,
|
| 206 |
-
ddp_bucket_cap_mb = None,
|
| 207 |
-
ddp_broadcast_buffers = None,
|
| 208 |
-
dataloader_pin_memory = True,
|
| 209 |
-
dataloader_persistent_workers = False,
|
| 210 |
-
skip_memory_metrics = True,
|
| 211 |
-
use_legacy_prediction_loop = False,
|
| 212 |
-
push_to_hub = False,
|
| 213 |
-
resume_from_checkpoint = None,
|
| 214 |
-
hub_model_id = None,
|
| 215 |
-
hub_strategy = 'every_save',
|
| 216 |
-
hub_token = None,
|
| 217 |
-
hub_private_repo = None,
|
| 218 |
-
hub_always_push = False,
|
| 219 |
-
hub_revision = None,
|
| 220 |
-
gradient_checkpointing = False,
|
| 221 |
-
gradient_checkpointing_kwargs = None,
|
| 222 |
-
include_inputs_for_metrics = False,
|
| 223 |
-
eval_do_concat_batches = True,
|
| 224 |
-
fp16_backend = 'auto',
|
| 225 |
-
push_to_hub_model_id = None,
|
| 226 |
-
push_to_hub_organization = None,
|
| 227 |
-
push_to_hub_token = None,
|
| 228 |
-
mp_parameters = '',
|
| 229 |
-
auto_find_batch_size = True,
|
| 230 |
-
full_determinism = False,
|
| 231 |
-
torchdynamo = None,
|
| 232 |
-
ray_scope = 'last',
|
| 233 |
-
ddp_timeout = 1800,
|
| 234 |
-
torch_compile = False,
|
| 235 |
-
torch_compile_backend = None,
|
| 236 |
-
torch_compile_mode = None,
|
| 237 |
-
include_tokens_per_second = False,
|
| 238 |
-
include_num_input_tokens_seen = False,
|
| 239 |
-
neftune_noise_alpha = None,
|
| 240 |
-
optim_target_modules = None,
|
| 241 |
-
batch_eval_metrics = False,
|
| 242 |
-
eval_on_start = False,
|
| 243 |
-
use_liger_kernel = False,
|
| 244 |
-
liger_kernel_config = None,
|
| 245 |
-
eval_use_gather_object = False,
|
| 246 |
-
average_tokens_across_devices = True,
|
| 247 |
-
reward_model_path = None,
|
| 248 |
-
judge = None,
|
| 249 |
-
max_new_tokens = 64,
|
| 250 |
-
max_length = 512,
|
| 251 |
-
temperature = 0.9,
|
| 252 |
-
missing_eos_penalty = None,
|
| 253 |
-
loss_type = 'sigmoid',
|
| 254 |
-
dataset_num_proc = None,
|
| 255 |
-
disable_dropout = True,
|
| 256 |
-
use_vllm = False,
|
| 257 |
-
ds3_gather_for_generation = True,
|
| 258 |
-
vllm_sampling_params = None,
|
| 259 |
-
unsloth_num_chunks = -1,
|
| 260 |
-
**kwargs,
|
| 261 |
-
):
|
| 262 |
-
if learning_rate < 1e-7: raise FloatingPointError(f'Unsloth: Your learning rate of `{learning_rate}` is too small and less than 1e-7! Consider increasing it, otherwise gradient updates will be close to 0!')
|
| 263 |
-
if learning_rate > 1: raise OverflowError(f'Unsloth: Your learning rate of `{learning_rate}` is way too larger > 1! Consider decreasing it to 1e-1, otherwise gradient updates will explode!')
|
| 264 |
-
if output_dir is None and save_strategy == 'steps' and save_steps == 500:
|
| 265 |
-
output_dir = 'unsloth_training_checkpoints'
|
| 266 |
-
save_strategy = 'no'
|
| 267 |
-
if dataset_num_proc is None:
|
| 268 |
-
from multiprocessing import cpu_count
|
| 269 |
-
dataset_num_proc = min(cpu_count()*2, 2)
|
| 270 |
-
if temperature <= 0:
|
| 271 |
-
raise MathError('Unsloth: Please set a positive non-zero temperature since your results will be wrong.')
|
| 272 |
-
elif temperature >= 10:
|
| 273 |
-
raise MathError('Unsloth: Please set a positive non-zero temperature less than 10, since sampling will be quite erratic.')
|
| 274 |
-
|
| 275 |
-
|
| 276 |
-
super().__init__(
|
| 277 |
-
output_dir = output_dir,
|
| 278 |
-
overwrite_output_dir = overwrite_output_dir,
|
| 279 |
-
do_train = do_train,
|
| 280 |
-
do_eval = do_eval,
|
| 281 |
-
do_predict = do_predict,
|
| 282 |
-
eval_strategy = eval_strategy,
|
| 283 |
-
prediction_loss_only = prediction_loss_only,
|
| 284 |
-
per_device_train_batch_size = per_device_train_batch_size,
|
| 285 |
-
per_device_eval_batch_size = per_device_eval_batch_size,
|
| 286 |
-
per_gpu_train_batch_size = per_gpu_train_batch_size,
|
| 287 |
-
per_gpu_eval_batch_size = per_gpu_eval_batch_size,
|
| 288 |
-
gradient_accumulation_steps = gradient_accumulation_steps,
|
| 289 |
-
eval_accumulation_steps = eval_accumulation_steps,
|
| 290 |
-
eval_delay = eval_delay,
|
| 291 |
-
torch_empty_cache_steps = torch_empty_cache_steps,
|
| 292 |
-
learning_rate = learning_rate,
|
| 293 |
-
weight_decay = weight_decay,
|
| 294 |
-
adam_beta1 = adam_beta1,
|
| 295 |
-
adam_beta2 = adam_beta2,
|
| 296 |
-
adam_epsilon = adam_epsilon,
|
| 297 |
-
max_grad_norm = max_grad_norm,
|
| 298 |
-
num_train_epochs = num_train_epochs,
|
| 299 |
-
max_steps = max_steps,
|
| 300 |
-
lr_scheduler_type = lr_scheduler_type,
|
| 301 |
-
warmup_ratio = warmup_ratio,
|
| 302 |
-
warmup_steps = warmup_steps,
|
| 303 |
-
log_level = log_level,
|
| 304 |
-
log_level_replica = log_level_replica,
|
| 305 |
-
log_on_each_node = log_on_each_node,
|
| 306 |
-
logging_dir = logging_dir,
|
| 307 |
-
logging_strategy = logging_strategy,
|
| 308 |
-
logging_first_step = logging_first_step,
|
| 309 |
-
logging_steps = logging_steps,
|
| 310 |
-
logging_nan_inf_filter = logging_nan_inf_filter,
|
| 311 |
-
save_strategy = save_strategy,
|
| 312 |
-
save_steps = save_steps,
|
| 313 |
-
save_total_limit = save_total_limit,
|
| 314 |
-
save_safetensors = save_safetensors,
|
| 315 |
-
save_on_each_node = save_on_each_node,
|
| 316 |
-
save_only_model = save_only_model,
|
| 317 |
-
restore_callback_states_from_checkpoint = restore_callback_states_from_checkpoint,
|
| 318 |
-
no_cuda = no_cuda,
|
| 319 |
-
use_cpu = use_cpu,
|
| 320 |
-
use_mps_device = use_mps_device,
|
| 321 |
-
seed = seed,
|
| 322 |
-
data_seed = data_seed,
|
| 323 |
-
jit_mode_eval = jit_mode_eval,
|
| 324 |
-
use_ipex = use_ipex,
|
| 325 |
-
bf16 = bf16,
|
| 326 |
-
fp16 = fp16,
|
| 327 |
-
fp16_opt_level = fp16_opt_level,
|
| 328 |
-
half_precision_backend = half_precision_backend,
|
| 329 |
-
bf16_full_eval = bf16_full_eval,
|
| 330 |
-
fp16_full_eval = fp16_full_eval,
|
| 331 |
-
tf32 = tf32,
|
| 332 |
-
local_rank = local_rank,
|
| 333 |
-
ddp_backend = ddp_backend,
|
| 334 |
-
tpu_num_cores = tpu_num_cores,
|
| 335 |
-
tpu_metrics_debug = tpu_metrics_debug,
|
| 336 |
-
debug = debug,
|
| 337 |
-
dataloader_drop_last = dataloader_drop_last,
|
| 338 |
-
eval_steps = eval_steps,
|
| 339 |
-
dataloader_num_workers = dataloader_num_workers,
|
| 340 |
-
dataloader_prefetch_factor = dataloader_prefetch_factor,
|
| 341 |
-
past_index = past_index,
|
| 342 |
-
run_name = run_name,
|
| 343 |
-
disable_tqdm = disable_tqdm,
|
| 344 |
-
remove_unused_columns = remove_unused_columns,
|
| 345 |
-
label_names = label_names,
|
| 346 |
-
load_best_model_at_end = load_best_model_at_end,
|
| 347 |
-
metric_for_best_model = metric_for_best_model,
|
| 348 |
-
greater_is_better = greater_is_better,
|
| 349 |
-
ignore_data_skip = ignore_data_skip,
|
| 350 |
-
fsdp = fsdp,
|
| 351 |
-
fsdp_min_num_params = fsdp_min_num_params,
|
| 352 |
-
fsdp_config = fsdp_config,
|
| 353 |
-
fsdp_transformer_layer_cls_to_wrap = fsdp_transformer_layer_cls_to_wrap,
|
| 354 |
-
accelerator_config = accelerator_config,
|
| 355 |
-
deepspeed = deepspeed,
|
| 356 |
-
label_smoothing_factor = label_smoothing_factor,
|
| 357 |
-
optim = optim,
|
| 358 |
-
optim_args = optim_args,
|
| 359 |
-
adafactor = adafactor,
|
| 360 |
-
group_by_length = group_by_length,
|
| 361 |
-
length_column_name = length_column_name,
|
| 362 |
-
report_to = report_to,
|
| 363 |
-
ddp_find_unused_parameters = ddp_find_unused_parameters,
|
| 364 |
-
ddp_bucket_cap_mb = ddp_bucket_cap_mb,
|
| 365 |
-
ddp_broadcast_buffers = ddp_broadcast_buffers,
|
| 366 |
-
dataloader_pin_memory = dataloader_pin_memory,
|
| 367 |
-
dataloader_persistent_workers = dataloader_persistent_workers,
|
| 368 |
-
skip_memory_metrics = skip_memory_metrics,
|
| 369 |
-
use_legacy_prediction_loop = use_legacy_prediction_loop,
|
| 370 |
-
push_to_hub = push_to_hub,
|
| 371 |
-
resume_from_checkpoint = resume_from_checkpoint,
|
| 372 |
-
hub_model_id = hub_model_id,
|
| 373 |
-
hub_strategy = hub_strategy,
|
| 374 |
-
hub_token = hub_token,
|
| 375 |
-
hub_private_repo = hub_private_repo,
|
| 376 |
-
hub_always_push = hub_always_push,
|
| 377 |
-
hub_revision = hub_revision,
|
| 378 |
-
gradient_checkpointing = gradient_checkpointing,
|
| 379 |
-
gradient_checkpointing_kwargs = gradient_checkpointing_kwargs,
|
| 380 |
-
include_inputs_for_metrics = include_inputs_for_metrics,
|
| 381 |
-
eval_do_concat_batches = eval_do_concat_batches,
|
| 382 |
-
fp16_backend = fp16_backend,
|
| 383 |
-
push_to_hub_model_id = push_to_hub_model_id,
|
| 384 |
-
push_to_hub_organization = push_to_hub_organization,
|
| 385 |
-
push_to_hub_token = push_to_hub_token,
|
| 386 |
-
mp_parameters = mp_parameters,
|
| 387 |
-
auto_find_batch_size = auto_find_batch_size,
|
| 388 |
-
full_determinism = full_determinism,
|
| 389 |
-
torchdynamo = torchdynamo,
|
| 390 |
-
ray_scope = ray_scope,
|
| 391 |
-
ddp_timeout = ddp_timeout,
|
| 392 |
-
torch_compile = torch_compile,
|
| 393 |
-
torch_compile_backend = torch_compile_backend,
|
| 394 |
-
torch_compile_mode = torch_compile_mode,
|
| 395 |
-
include_tokens_per_second = include_tokens_per_second,
|
| 396 |
-
include_num_input_tokens_seen = include_num_input_tokens_seen,
|
| 397 |
-
neftune_noise_alpha = neftune_noise_alpha,
|
| 398 |
-
optim_target_modules = optim_target_modules,
|
| 399 |
-
batch_eval_metrics = batch_eval_metrics,
|
| 400 |
-
eval_on_start = eval_on_start,
|
| 401 |
-
use_liger_kernel = use_liger_kernel,
|
| 402 |
-
liger_kernel_config = liger_kernel_config,
|
| 403 |
-
eval_use_gather_object = eval_use_gather_object,
|
| 404 |
-
average_tokens_across_devices = average_tokens_across_devices,
|
| 405 |
-
reward_model_path = reward_model_path,
|
| 406 |
-
judge = judge,
|
| 407 |
-
max_new_tokens = max_new_tokens,
|
| 408 |
-
max_length = max_length,
|
| 409 |
-
temperature = temperature,
|
| 410 |
-
missing_eos_penalty = missing_eos_penalty,
|
| 411 |
-
loss_type = loss_type,
|
| 412 |
-
dataset_num_proc = dataset_num_proc,
|
| 413 |
-
disable_dropout = disable_dropout,
|
| 414 |
-
use_vllm = use_vllm,
|
| 415 |
-
ds3_gather_for_generation = ds3_gather_for_generation,**kwargs)
|
| 416 |
-
self.vllm_sampling_params = vllm_sampling_params
|
| 417 |
-
self.unsloth_num_chunks = unsloth_num_chunks
|
| 418 |
-
pass
|
| 419 |
-
|
| 420 |
-
class _UnslothOnlineDPOTrainer(Trainer):
|
| 421 |
-
r""""""
|
| 422 |
-
|
| 423 |
-
_tag_names = ["trl", "online-dpo"]
|
| 424 |
-
|
| 425 |
-
def __init__(
|
| 426 |
-
self,
|
| 427 |
-
model: Union[PreTrainedModel, nn.Module],
|
| 428 |
-
ref_model: Union[PreTrainedModel, nn.Module, None] = None,
|
| 429 |
-
reward_model: Union[PreTrainedModel, nn.Module, None] = None,
|
| 430 |
-
judge: Optional[BasePairwiseJudge] = None,
|
| 431 |
-
args: Optional[OnlineDPOConfig] = None,
|
| 432 |
-
data_collator: Optional[DataCollator] = None,
|
| 433 |
-
train_dataset: Optional[Union[Dataset, IterableDataset, "datasets.Dataset"]] = None,
|
| 434 |
-
eval_dataset: Optional[Union[Dataset, dict[str, Dataset], "datasets.Dataset"]] = None,
|
| 435 |
-
processing_class: Optional[
|
| 436 |
-
Union[PreTrainedTokenizerBase, BaseImageProcessor, FeatureExtractionMixin, ProcessorMixin]
|
| 437 |
-
] = None,
|
| 438 |
-
reward_processing_class: Optional[PreTrainedTokenizerBase] = None,
|
| 439 |
-
peft_config: Optional[dict] = None,
|
| 440 |
-
compute_metrics: Optional[Callable[[EvalPrediction], dict]] = None,
|
| 441 |
-
callbacks: Optional[list[TrainerCallback]] = None,
|
| 442 |
-
optimizers: tuple[torch.optim.Optimizer, torch.optim.lr_scheduler.LambdaLR] = (None, None),
|
| 443 |
-
preprocess_logits_for_metrics: Optional[Callable[[torch.Tensor, torch.Tensor], torch.Tensor]] = None,
|
| 444 |
-
) -> None:
|
| 445 |
-
|
| 446 |
-
if hasattr(model, 'vllm_engine') and hasattr(args, 'use_vllm'):
|
| 447 |
-
if (getattr(args, 'use_vllm', False) == False):
|
| 448 |
-
args.use_vllm = True
|
| 449 |
-
if ref_model is model:
|
| 450 |
-
raise ValueError(
|
| 451 |
-
"`model` and `ref_model` cannot be the same object. If you want `ref_model` to be the "
|
| 452 |
-
"same as `model`, either omit the `ref_model` argument or pass `None`."
|
| 453 |
-
)
|
| 454 |
-
|
| 455 |
-
self.ref_model = ref_model
|
| 456 |
-
|
| 457 |
-
if reward_model is not None and judge is not None:
|
| 458 |
-
warnings.warn(
|
| 459 |
-
"Both `reward_model` and `judge` are provided. Please choose provide only one of them. "
|
| 460 |
-
"Ignoring `judge` and using `reward_model`.",
|
| 461 |
-
UserWarning,
|
| 462 |
-
)
|
| 463 |
-
judge = None
|
| 464 |
-
elif reward_model is None and judge is None:
|
| 465 |
-
raise ValueError("Either `reward_model` or `judge` must be provided.")
|
| 466 |
-
|
| 467 |
-
self.reward_model = reward_model
|
| 468 |
-
self.reward_processing_class = reward_processing_class
|
| 469 |
-
self.judge = judge
|
| 470 |
-
|
| 471 |
-
if args.missing_eos_penalty is not None and judge is not None:
|
| 472 |
-
raise ValueError("`missing_eos_penalty` is not supported when `judge` is provided.")
|
| 473 |
-
|
| 474 |
-
if args is None:
|
| 475 |
-
raise ValueError("`args` must be provided.")
|
| 476 |
-
|
| 477 |
-
# Check that the processing_class is provided
|
| 478 |
-
if processing_class is None:
|
| 479 |
-
raise ValueError("`processing_class` must be provided.")
|
| 480 |
-
|
| 481 |
-
# Convert to PEFT model if peft_config is provided
|
| 482 |
-
if False:
|
| 483 |
-
# Check if PEFT is available
|
| 484 |
-
if not is_peft_available():
|
| 485 |
-
raise ImportError(
|
| 486 |
-
"PEFT is not available and passed `peft_config`. Please install PEFT with "
|
| 487 |
-
"`pip install peft` to use it."
|
| 488 |
-
)
|
| 489 |
-
|
| 490 |
-
# If the model is already a PeftModel, we need to merge and unload it.
|
| 491 |
-
# Further information here: https://huggingface.co/docs/trl/dpo_trainer#reference-model-considerations-with-peft
|
| 492 |
-
if isinstance(model, PeftModel):
|
| 493 |
-
model = model.merge_and_unload()
|
| 494 |
-
|
| 495 |
-
# Get peft model with the given config
|
| 496 |
-
model = model
|
| 497 |
-
|
| 498 |
-
# Disable dropout in the model and reference model
|
| 499 |
-
if args.disable_dropout:
|
| 500 |
-
disable_dropout_in_model(model)
|
| 501 |
-
if self.ref_model is not None:
|
| 502 |
-
disable_dropout_in_model(self.ref_model)
|
| 503 |
-
|
| 504 |
-
# Handle the ref_model
|
| 505 |
-
# Usually, the user wants the ref model to be the initial version of the model. When using PEFT, it's easy to
|
| 506 |
-
# get the ref model, as it's just the model with a disabled adapter. When not using PEFT, we need to create
|
| 507 |
-
# the ref model from the model by copying it and disable the gradients and set it in evaluation mode.
|
| 508 |
-
if ref_model is None: # No ref model provided, the most common case
|
| 509 |
-
if False:
|
| 510 |
-
self.ref_model = create_reference_model(model) # copy, disable gradients, set eval mode
|
| 511 |
-
else:
|
| 512 |
-
self.ref_model = None # we don't need a ref model here, we can just disable the adapter.
|
| 513 |
-
else: # rare case, the user provided a ref model
|
| 514 |
-
self.ref_model = ref_model
|
| 515 |
-
self.ref_model.eval()
|
| 516 |
-
|
| 517 |
-
# Disable the gradient and set the reward model in eval mode
|
| 518 |
-
if self.reward_model is not None:
|
| 519 |
-
self.reward_model.eval()
|
| 520 |
-
|
| 521 |
-
# Define the collator is not provided
|
| 522 |
-
if data_collator is None:
|
| 523 |
-
data_collator = DPODataCollatorWithPadding(pad_token_id=processing_class.pad_token_id)
|
| 524 |
-
|
| 525 |
-
self.max_length = args.max_length
|
| 526 |
-
|
| 527 |
-
self.stats = {
|
| 528 |
-
"objective/kl": [],
|
| 529 |
-
"objective/entropy": [],
|
| 530 |
-
"objective/non_score_reward": [],
|
| 531 |
-
"rewards/chosen": [],
|
| 532 |
-
"rewards/rejected": [],
|
| 533 |
-
"rewards/accuracies": [],
|
| 534 |
-
"rewards/margins": [],
|
| 535 |
-
"logps/chosen": [],
|
| 536 |
-
"logps/rejected": [],
|
| 537 |
-
"val/contain_eos_token": [],
|
| 538 |
-
"beta": [],
|
| 539 |
-
}
|
| 540 |
-
if self.reward_model is not None:
|
| 541 |
-
self.stats["objective/rlhf_reward"] = []
|
| 542 |
-
self.stats["objective/scores_margin"] = []
|
| 543 |
-
self.stats["objective/scores"] = []
|
| 544 |
-
|
| 545 |
-
if args.use_vllm:
|
| 546 |
-
self.llm = model.vllm_engine; self._last_loaded_step = 0; self.generation_config = SamplingParams(
|
| 547 |
-
n=2,
|
| 548 |
-
max_tokens=args.max_new_tokens,
|
| 549 |
-
temperature=args.temperature,
|
| 550 |
-
top_k=50,
|
| 551 |
-
top_p=1.0,
|
| 552 |
-
detokenize=False,
|
| 553 |
-
**getattr(getattr(args, 'vllm_sampling_params', vLLMSamplingParams()), '_set_kwargs', {}),
|
| 554 |
-
)
|
| 555 |
-
else:
|
| 556 |
-
self.generation_config = GenerationConfig(
|
| 557 |
-
max_new_tokens=args.max_new_tokens,
|
| 558 |
-
temperature=args.temperature,
|
| 559 |
-
top_k=50,
|
| 560 |
-
top_p=1.0,
|
| 561 |
-
do_sample=True,
|
| 562 |
-
use_cache=False if args.gradient_checkpointing else True,
|
| 563 |
-
)
|
| 564 |
-
|
| 565 |
-
# The trainer estimates the number of FLOPs [floating-point operations] using the number of elements in the
|
| 566 |
-
# input tensor associated with the key "input_ids". However, in Online DPO, the sampled data does not include
|
| 567 |
-
# the "input_ids" key. As a result, the trainer issues the warning: "Could not estimate the number of tokens
|
| 568 |
-
# of the input, floating-point operations will not be computed." To suppress this warning, we set the
|
| 569 |
-
# "estimate_tokens" key in the model's "warnings_issued" dictionary to True. This acts as a flag to indicate
|
| 570 |
-
# that the warning has already been issued.
|
| 571 |
-
model.warnings_issued["estimate_tokens"] = True
|
| 572 |
-
|
| 573 |
-
super().__init__(
|
| 574 |
-
model=model,
|
| 575 |
-
args=args,
|
| 576 |
-
data_collator=data_collator,
|
| 577 |
-
train_dataset=train_dataset,
|
| 578 |
-
eval_dataset=eval_dataset,
|
| 579 |
-
processing_class=processing_class,
|
| 580 |
-
compute_metrics=compute_metrics,
|
| 581 |
-
callbacks=callbacks,
|
| 582 |
-
optimizers=optimizers,
|
| 583 |
-
preprocess_logits_for_metrics=preprocess_logits_for_metrics,
|
| 584 |
-
)
|
| 585 |
-
|
| 586 |
-
# Add tags for models that have been loaded with the correct transformers version
|
| 587 |
-
if hasattr(self.model, "add_model_tags"):
|
| 588 |
-
self.model.add_model_tags(self._tag_names)
|
| 589 |
-
|
| 590 |
-
self._beta = args.beta
|
| 591 |
-
|
| 592 |
-
# Placed after the super[].__init__ because we need self.is_deepspeed_enabled and self.accelerator
|
| 593 |
-
if self.is_deepspeed_enabled:
|
| 594 |
-
if self.reward_model is not None:
|
| 595 |
-
self.reward_model = prepare_deepspeed(
|
| 596 |
-
self.reward_model, args.per_device_train_batch_size, args.fp16, args.bf16
|
| 597 |
-
)
|
| 598 |
-
if self.ref_model is not None:
|
| 599 |
-
self.ref_model = prepare_deepspeed(
|
| 600 |
-
self.ref_model, args.per_device_train_batch_size, args.fp16, args.bf16
|
| 601 |
-
)
|
| 602 |
-
else:
|
| 603 |
-
if self.ref_model is not None:
|
| 604 |
-
self.ref_model = self.ref_model.to(self.accelerator.device)
|
| 605 |
-
if self.reward_model is not None:
|
| 606 |
-
self.reward_model = self.reward_model.to(self.accelerator.device)
|
| 607 |
-
|
| 608 |
-
@property
|
| 609 |
-
def beta(self):
|
| 610 |
-
if isinstance(self._beta, list):
|
| 611 |
-
epoch = self.state.epoch
|
| 612 |
-
return self._beta[epoch] if epoch < len(self._beta) else self._beta[-1]
|
| 613 |
-
else:
|
| 614 |
-
return self._beta
|
| 615 |
-
|
| 616 |
-
@staticmethod
|
| 617 |
-
def tokenize_row(feature, is_encoder_decoder: bool, tokenizer: PreTrainedTokenizerBase) -> dict[str, Any]:
|
| 618 |
-
"""Tokenize a single row from a DPO specific dataset."""
|
| 619 |
-
if not is_encoder_decoder:
|
| 620 |
-
batch = tokenizer(feature["prompt"], add_special_tokens=False)
|
| 621 |
-
# Add BOS token to head of prompt. Avoid adding if it's already there
|
| 622 |
-
if tokenizer.bos_token_id is not None:
|
| 623 |
-
prompt_len_input_ids = len(batch["input_ids"])
|
| 624 |
-
if prompt_len_input_ids == 0 or tokenizer.bos_token_id != batch["input_ids"][0]:
|
| 625 |
-
batch["input_ids"] = [tokenizer.bos_token_id] + batch["input_ids"]
|
| 626 |
-
batch["attention_mask"] = [1] + batch["attention_mask"]
|
| 627 |
-
else:
|
| 628 |
-
batch = tokenizer(feature["prompt"], add_special_tokens=True)
|
| 629 |
-
batch = {f"prompt_{key}": value for key, value in batch.items()}
|
| 630 |
-
return batch
|
| 631 |
-
|
| 632 |
-
# Same as Trainer.get_train_dataloader but skip the "remove_unused_columns".
|
| 633 |
-
@wraps(Trainer.get_train_dataloader)
|
| 634 |
-
def get_train_dataloader(self) -> DataLoader:
|
| 635 |
-
if self.train_dataset is None:
|
| 636 |
-
raise ValueError("Trainer: training requires a train_dataset.")
|
| 637 |
-
|
| 638 |
-
train_dataset = self.train_dataset
|
| 639 |
-
data_collator = self.data_collator
|
| 640 |
-
dataloader_params = {
|
| 641 |
-
"batch_size": self._train_batch_size,
|
| 642 |
-
"collate_fn": data_collator,
|
| 643 |
-
"num_workers": self.args.dataloader_num_workers,
|
| 644 |
-
"pin_memory": self.args.dataloader_pin_memory,
|
| 645 |
-
"persistent_workers": self.args.dataloader_persistent_workers,
|
| 646 |
-
}
|
| 647 |
-
|
| 648 |
-
if not isinstance(train_dataset, torch.utils.data.IterableDataset):
|
| 649 |
-
dataloader_params["sampler"] = self._get_train_sampler()
|
| 650 |
-
dataloader_params["drop_last"] = self.args.dataloader_drop_last
|
| 651 |
-
dataloader_params["worker_init_fn"] = seed_worker
|
| 652 |
-
dataloader_params["prefetch_factor"] = self.args.dataloader_prefetch_factor
|
| 653 |
-
|
| 654 |
-
return self.accelerator.prepare(DataLoader(train_dataset, **dataloader_params))
|
| 655 |
-
|
| 656 |
-
# Same as Trainer.get_eval_dataloader but skip the "remove_unused_columns".
|
| 657 |
-
@wraps(Trainer.get_eval_dataloader)
|
| 658 |
-
def get_eval_dataloader(self, eval_dataset: Optional[Union[str, Dataset]] = None) -> DataLoader:
|
| 659 |
-
if eval_dataset is None and self.eval_dataset is None:
|
| 660 |
-
raise ValueError("Trainer: evaluation requires an eval_dataset.")
|
| 661 |
-
|
| 662 |
-
# If we have persistent workers, don't do a fork bomb especially as eval datasets
|
| 663 |
-
# don't change during training
|
| 664 |
-
dataloader_key = eval_dataset if isinstance(eval_dataset, str) else "eval"
|
| 665 |
-
if (
|
| 666 |
-
hasattr(self, "_eval_dataloaders")
|
| 667 |
-
and dataloader_key in self._eval_dataloaders
|
| 668 |
-
and self.args.dataloader_persistent_workers
|
| 669 |
-
):
|
| 670 |
-
return self.accelerator.prepare(self._eval_dataloaders[dataloader_key])
|
| 671 |
-
|
| 672 |
-
eval_dataset = (
|
| 673 |
-
self.eval_dataset[eval_dataset]
|
| 674 |
-
if isinstance(eval_dataset, str)
|
| 675 |
-
else eval_dataset
|
| 676 |
-
if eval_dataset is not None
|
| 677 |
-
else self.eval_dataset
|
| 678 |
-
)
|
| 679 |
-
data_collator = self.data_collator
|
| 680 |
-
|
| 681 |
-
dataloader_params = {
|
| 682 |
-
"batch_size": self.args.eval_batch_size,
|
| 683 |
-
"collate_fn": data_collator,
|
| 684 |
-
"num_workers": self.args.dataloader_num_workers,
|
| 685 |
-
"pin_memory": self.args.dataloader_pin_memory,
|
| 686 |
-
"persistent_workers": self.args.dataloader_persistent_workers,
|
| 687 |
-
}
|
| 688 |
-
|
| 689 |
-
if not isinstance(eval_dataset, torch.utils.data.IterableDataset):
|
| 690 |
-
dataloader_params["sampler"] = self._get_eval_sampler(eval_dataset)
|
| 691 |
-
dataloader_params["drop_last"] = self.args.dataloader_drop_last
|
| 692 |
-
dataloader_params["prefetch_factor"] = self.args.dataloader_prefetch_factor
|
| 693 |
-
|
| 694 |
-
# accelerator.free_memory() will destroy the references, so
|
| 695 |
-
# we need to store the non-prepared version
|
| 696 |
-
eval_dataloader = DataLoader(eval_dataset, **dataloader_params)
|
| 697 |
-
if self.args.dataloader_persistent_workers:
|
| 698 |
-
if hasattr(self, "_eval_dataloaders"):
|
| 699 |
-
self._eval_dataloaders[dataloader_key] = eval_dataloader
|
| 700 |
-
else:
|
| 701 |
-
self._eval_dataloaders = {dataloader_key: eval_dataloader}
|
| 702 |
-
|
| 703 |
-
return self.accelerator.prepare(eval_dataloader)
|
| 704 |
-
|
| 705 |
-
def _generate_vllm(self, model, prompts):
|
| 706 |
-
eos_token_id = self.processing_class.eos_token_id
|
| 707 |
-
pad_token_id = self.processing_class.pad_token_id
|
| 708 |
-
|
| 709 |
-
# Load the latest weights
|
| 710 |
-
|
| 711 |
-
pass
|
| 712 |
-
|
| 713 |
-
pass
|
| 714 |
-
|
| 715 |
-
if is_conversational({"prompt": prompts[0]}):
|
| 716 |
-
outputs = self.llm.chat(prompts, self.generation_config, use_tqdm=False, lora_request = self.model.load_lora('online_dpo_trainer_lora_model', load_tensors = True))
|
| 717 |
-
else:
|
| 718 |
-
outputs = self.llm.generate(prompts, self.generation_config, use_tqdm=False, lora_request = self.model.load_lora('online_dpo_trainer_lora_model', load_tensors = True))
|
| 719 |
-
|
| 720 |
-
completion_ids = [list(output.outputs[i].token_ids) for i in range(2) for output in outputs]
|
| 721 |
-
prompt_ids = [list(output.prompt_token_ids) for _ in range(2) for output in outputs]
|
| 722 |
-
|
| 723 |
-
# Create mask and pad the prompt and completion
|
| 724 |
-
max_prompt_length = max(len(ids) for ids in prompt_ids)
|
| 725 |
-
prompt_mask = [[0] * (max_prompt_length - len(ids)) + [1] * len(ids) for ids in prompt_ids]
|
| 726 |
-
prompt_ids = [[pad_token_id] * (max_prompt_length - len(ids)) + ids for ids in prompt_ids]
|
| 727 |
-
max_tokens = self.generation_config.max_tokens
|
| 728 |
-
completion_mask = [[1] * len(ids) + [0] * (max_tokens - len(ids)) for ids in completion_ids]
|
| 729 |
-
completion_ids = [
|
| 730 |
-
ids + [eos_token_id] if ids[-1] != eos_token_id and len(ids) < max_tokens else ids
|
| 731 |
-
for ids in completion_ids
|
| 732 |
-
]
|
| 733 |
-
completion_ids = [ids + [pad_token_id] * (max_tokens - len(ids)) for ids in completion_ids]
|
| 734 |
-
|
| 735 |
-
# Convert to tensors
|
| 736 |
-
prompt_ids = torch.tensor(prompt_ids, device=self.accelerator.device)
|
| 737 |
-
prompt_mask = torch.tensor(prompt_mask, device=self.accelerator.device)
|
| 738 |
-
completion_ids = torch.tensor(completion_ids, device=self.accelerator.device)
|
| 739 |
-
completion_mask = torch.tensor(completion_mask, device=self.accelerator.device)
|
| 740 |
-
|
| 741 |
-
return prompt_ids, prompt_mask, completion_ids, completion_mask
|
| 742 |
-
|
| 743 |
-
def _generate(self, model, prompts):
|
| 744 |
-
eos_token_id = self.processing_class.eos_token_id
|
| 745 |
-
pad_token_id = self.processing_class.pad_token_id
|
| 746 |
-
|
| 747 |
-
# Apply chat template and tokenize the input. We do this on-the-fly to enable the use of reward models and
|
| 748 |
-
# policies with different tokenizers / chat templates.
|
| 749 |
-
inputs = [{"prompt": prompt} for prompt in prompts]
|
| 750 |
-
inputs = [maybe_apply_chat_template(x, self.processing_class) for x in inputs]
|
| 751 |
-
inputs = [self.tokenize_row(x, model.config.is_encoder_decoder, self.processing_class) for x in inputs]
|
| 752 |
-
inputs = self.data_collator(inputs)
|
| 753 |
-
|
| 754 |
-
# Sample 2 completions per prompt of size `max_new_tokens` from the model
|
| 755 |
-
inputs = self._prepare_inputs(inputs)
|
| 756 |
-
prompt_ids = inputs["prompt_input_ids"].repeat(2, 1)
|
| 757 |
-
prompt_mask = inputs["prompt_attention_mask"].repeat(2, 1)
|
| 758 |
-
with unwrap_model_for_generation(
|
| 759 |
-
model, self.accelerator, gather_deepspeed3_params=self.args.ds3_gather_for_generation
|
| 760 |
-
) as unwrapped_model:
|
| 761 |
-
output = unwrapped_model.generate(
|
| 762 |
-
input_ids=prompt_ids,
|
| 763 |
-
attention_mask=prompt_mask,
|
| 764 |
-
generation_config=self.generation_config,
|
| 765 |
-
)
|
| 766 |
-
|
| 767 |
-
completion_ids = output[:, prompt_ids.size(1) :]
|
| 768 |
-
completion_ids, completion_mask = truncate_right(completion_ids, eos_token_id, pad_token_id)
|
| 769 |
-
|
| 770 |
-
return prompt_ids, prompt_mask, completion_ids, completion_mask
|
| 771 |
-
|
| 772 |
-
def _forward(self, model, prompt_ids, prompt_mask, completion_ids, completion_mask):
|
| 773 |
-
# Get the number of tokens to truncate from prompt
|
| 774 |
-
num_tokens_to_truncate = max(prompt_ids.size(1) + completion_ids.size(1) - self.max_length, 0)
|
| 775 |
-
|
| 776 |
-
# Truncate left to avoid oom
|
| 777 |
-
prompt_ids = prompt_ids[:, num_tokens_to_truncate:]
|
| 778 |
-
prompt_mask = prompt_mask[:, num_tokens_to_truncate:]
|
| 779 |
-
|
| 780 |
-
# Concat the prompt and completion
|
| 781 |
-
prompt_completion_ids = torch.cat((prompt_ids, completion_ids), dim=1)
|
| 782 |
-
prompt_completion_mask = torch.cat((prompt_mask, completion_mask), dim=1)
|
| 783 |
-
|
| 784 |
-
# Get the logprobs of the completions from the model
|
| 785 |
-
output = model(prompt_completion_ids, attention_mask=prompt_completion_mask)
|
| 786 |
-
|
| 787 |
-
# There is 1 offset, because the model predict the next token
|
| 788 |
-
logits = output.logits[:, prompt_ids.size(1) - 1 : -1]
|
| 789 |
-
|
| 790 |
-
# Take the completion tokens logprob
|
| 791 |
-
logprobs = torch.take_along_dim(logits.log_softmax(dim=-1), completion_ids.unsqueeze(-1), dim=2).squeeze(-1)
|
| 792 |
-
return logprobs
|
| 793 |
-
|
| 794 |
-
def training_step(
|
| 795 |
-
self, model: nn.Module, inputs: dict[str, Union[torch.Tensor, Any]], num_items_in_batch: Optional[int] = None
|
| 796 |
-
) -> torch.Tensor:
|
| 797 |
-
model.train()
|
| 798 |
-
|
| 799 |
-
prompts = inputs["prompt"]
|
| 800 |
-
batch_size = len(prompts)
|
| 801 |
-
|
| 802 |
-
if self.args.use_vllm:
|
| 803 |
-
prompt_ids, prompt_mask, completion_ids, completion_mask = self._generate_vllm(model, prompts)
|
| 804 |
-
else:
|
| 805 |
-
prompt_ids, prompt_mask, completion_ids, completion_mask = self._generate(model, prompts)
|
| 806 |
-
|
| 807 |
-
contain_eos_token = torch.any(completion_ids == self.processing_class.eos_token_id, dim=-1)
|
| 808 |
-
|
| 809 |
-
logprobs = self._forward(model, prompt_ids, prompt_mask, completion_ids, completion_mask)
|
| 810 |
-
with torch.no_grad():
|
| 811 |
-
if self.ref_model is not None:
|
| 812 |
-
ref_logprobs = self._forward(self.ref_model, prompt_ids, prompt_mask, completion_ids, completion_mask)
|
| 813 |
-
else: # peft case: we just need to disable the adapter
|
| 814 |
-
with self.model.disable_adapter():
|
| 815 |
-
ref_logprobs = self._forward(self.model, prompt_ids, prompt_mask, completion_ids, completion_mask)
|
| 816 |
-
|
| 817 |
-
# Decode the completions, and format them if the input is conversational
|
| 818 |
-
device = logprobs.device
|
| 819 |
-
completions = self.processing_class.batch_decode(completion_ids, skip_special_tokens=True)
|
| 820 |
-
if is_conversational({"prompt": prompts[0]}):
|
| 821 |
-
completions = [[{"role": "assistant", "content": completion}] for completion in completions]
|
| 822 |
-
|
| 823 |
-
# Get the reward from the reward model or judge
|
| 824 |
-
if self.judge is not None:
|
| 825 |
-
# Once formatted, conversational data may contain special tokens (such as <|im_start|>) that are not
|
| 826 |
-
# directly understandable by the judge and could alter its judgment. To avoid this and make the judge
|
| 827 |
-
# independent of the model's chat template, we use the raw conversation data, and apply our own chat
|
| 828 |
-
# template to it.
|
| 829 |
-
if is_conversational({"prompt": prompts[0]}):
|
| 830 |
-
environment = jinja2.Environment()
|
| 831 |
-
template = environment.from_string(SIMPLE_CHAT_TEMPLATE)
|
| 832 |
-
prompts = [template.render(messages=prompt) for prompt in prompts]
|
| 833 |
-
completions = [template.render(messages=completion) for completion in completions]
|
| 834 |
-
|
| 835 |
-
ranks_of_first_completion = self.judge.judge(
|
| 836 |
-
prompts, list(zip(completions[:batch_size], completions[batch_size:]))
|
| 837 |
-
)
|
| 838 |
-
|
| 839 |
-
# convert ranks to a True/False mask:
|
| 840 |
-
# when rank == 0, it means the first completion is the best
|
| 841 |
-
# when rank == 1, it means the second completion is the best
|
| 842 |
-
mask = torch.tensor([rank == 0 for rank in ranks_of_first_completion], device=device)
|
| 843 |
-
else:
|
| 844 |
-
# The reward model may not have the same chat template or tokenizer as the model, so we need to use the
|
| 845 |
-
# raw data (string), apply the chat template (if needed), and tokenize it with the reward processing class.
|
| 846 |
-
prompts = 2 * prompts # repeat the prompt: [prompt0, prompt1] -> [prompt0, prompt1, prompt0, prompt1]
|
| 847 |
-
if is_conversational({"prompt": prompts[0]}):
|
| 848 |
-
examples = [{"prompt": p, "completion": c} for p, c in zip(prompts, completions)]
|
| 849 |
-
examples = [apply_chat_template(example, self.reward_processing_class) for example in examples]
|
| 850 |
-
prompts = [example["prompt"] for example in examples]
|
| 851 |
-
completions = [example["completion"] for example in examples]
|
| 852 |
-
|
| 853 |
-
# Tokenize the prompts
|
| 854 |
-
prompts_ids = self.reward_processing_class(
|
| 855 |
-
prompts, padding=True, return_tensors="pt", padding_side="left"
|
| 856 |
-
)["input_ids"].to(device)
|
| 857 |
-
context_length = prompts_ids.shape[1]
|
| 858 |
-
|
| 859 |
-
# Tokenize the completions
|
| 860 |
-
completions_ids = self.reward_processing_class(
|
| 861 |
-
completions, padding=True, return_tensors="pt", padding_side="right"
|
| 862 |
-
)["input_ids"].to(device)
|
| 863 |
-
|
| 864 |
-
# Concatenate the prompts and completions and get the reward
|
| 865 |
-
prompt_completion_ids = torch.cat((prompts_ids, completions_ids), dim=1)
|
| 866 |
-
with torch.inference_mode():
|
| 867 |
-
_, scores, _ = get_reward(
|
| 868 |
-
self.reward_model, prompt_completion_ids, self.reward_processing_class.pad_token_id, context_length
|
| 869 |
-
)
|
| 870 |
-
|
| 871 |
-
# Filter completion. Ensure that the sample contains stop_token_id
|
| 872 |
-
# Completions not passing that filter will receive a lower score.
|
| 873 |
-
if self.args.missing_eos_penalty is not None:
|
| 874 |
-
scores[~contain_eos_token] -= self.args.missing_eos_penalty
|
| 875 |
-
|
| 876 |
-
# Split the scores in 2 (the prompts of the first half are the same as the second half)
|
| 877 |
-
first_half, second_half = scores.split(batch_size)
|
| 878 |
-
|
| 879 |
-
# Get the indices of the chosen and rejected examples
|
| 880 |
-
mask = first_half >= second_half
|
| 881 |
-
|
| 882 |
-
batch_range = torch.arange(batch_size, device=device)
|
| 883 |
-
chosen_indices = batch_range + (~mask * batch_size)
|
| 884 |
-
rejected_indices = batch_range + (mask * batch_size)
|
| 885 |
-
|
| 886 |
-
# Build tensor so that the first half is the chosen examples and the second half the rejected examples
|
| 887 |
-
cr_indices = torch.cat((chosen_indices, rejected_indices), dim=0) # cr = chosen and rejected
|
| 888 |
-
cr_logprobs = logprobs[cr_indices]
|
| 889 |
-
cr_ref_logprobs = ref_logprobs[cr_indices]
|
| 890 |
-
|
| 891 |
-
# mask out the padding tokens
|
| 892 |
-
padding_mask = ~completion_mask.bool()
|
| 893 |
-
cr_padding_mask = padding_mask[cr_indices]
|
| 894 |
-
|
| 895 |
-
cr_logprobs_sum = (cr_logprobs * ~cr_padding_mask).sum(1)
|
| 896 |
-
cr_ref_logprobs_sum = (cr_ref_logprobs * ~cr_padding_mask).sum(1)
|
| 897 |
-
|
| 898 |
-
# Split the chosen and rejected examples
|
| 899 |
-
chosen_logprobs_sum, rejected_logprobs_sum = torch.split(cr_logprobs_sum, batch_size)
|
| 900 |
-
chosen_ref_logprobs_sum, rejected_ref_logprobs_sum = torch.split(cr_ref_logprobs_sum, batch_size)
|
| 901 |
-
pi_logratios = chosen_logprobs_sum - rejected_logprobs_sum
|
| 902 |
-
ref_logratios = chosen_ref_logprobs_sum - rejected_ref_logprobs_sum
|
| 903 |
-
|
| 904 |
-
logits = pi_logratios - ref_logratios
|
| 905 |
-
|
| 906 |
-
if self.args.loss_type == "sigmoid":
|
| 907 |
-
losses = -F.logsigmoid(self.beta * logits)
|
| 908 |
-
elif self.args.loss_type == "ipo":
|
| 909 |
-
losses = (logits - 1 / (2 * self.beta)) ** 2
|
| 910 |
-
else:
|
| 911 |
-
raise NotImplementedError(f"invalid loss type {self.loss_type}")
|
| 912 |
-
|
| 913 |
-
loss = losses.mean()
|
| 914 |
-
|
| 915 |
-
# Log everything
|
| 916 |
-
if self.reward_model is not None:
|
| 917 |
-
scores_margin = scores[chosen_indices] - scores[rejected_indices]
|
| 918 |
-
self.stats["objective/scores_margin"].append(
|
| 919 |
-
self.accelerator.gather_for_metrics(scores_margin.mean()).mean().item()
|
| 920 |
-
)
|
| 921 |
-
self.stats["objective/scores"].append(self.accelerator.gather_for_metrics(scores.mean()).mean().item())
|
| 922 |
-
self.stats["val/contain_eos_token"].append(contain_eos_token.float().mean().item())
|
| 923 |
-
self.stats["logps/chosen"].append(self.accelerator.gather_for_metrics(chosen_logprobs_sum).mean().item())
|
| 924 |
-
self.stats["logps/rejected"].append(self.accelerator.gather_for_metrics(rejected_logprobs_sum).mean().item())
|
| 925 |
-
|
| 926 |
-
kl = logprobs - ref_logprobs
|
| 927 |
-
mean_kl = kl.sum(1).mean()
|
| 928 |
-
self.stats["objective/kl"].append(self.accelerator.gather_for_metrics(mean_kl).mean().item())
|
| 929 |
-
non_score_reward = (-self.beta * kl).sum(1)
|
| 930 |
-
mean_non_score_reward = non_score_reward.mean()
|
| 931 |
-
self.stats["objective/non_score_reward"].append(
|
| 932 |
-
self.accelerator.gather_for_metrics(mean_non_score_reward).mean().item()
|
| 933 |
-
)
|
| 934 |
-
if self.reward_model is not None:
|
| 935 |
-
rlhf_reward = scores + non_score_reward
|
| 936 |
-
self.stats["objective/rlhf_reward"].append(self.accelerator.gather_for_metrics(rlhf_reward).mean().item())
|
| 937 |
-
mean_entropy = -logprobs.sum(1).mean()
|
| 938 |
-
self.stats["objective/entropy"].append(self.accelerator.gather_for_metrics(mean_entropy).mean().item())
|
| 939 |
-
chosen_rewards = self.beta * (chosen_logprobs_sum - chosen_ref_logprobs_sum)
|
| 940 |
-
gathered_chosen_rewards = self.accelerator.gather_for_metrics(chosen_rewards)
|
| 941 |
-
self.stats["rewards/chosen"].append(gathered_chosen_rewards.mean().item())
|
| 942 |
-
rejected_rewards = self.beta * (rejected_logprobs_sum - rejected_ref_logprobs_sum)
|
| 943 |
-
gathered_rejected_rewards = self.accelerator.gather_for_metrics(rejected_rewards)
|
| 944 |
-
self.stats["rewards/rejected"].append(gathered_rejected_rewards.mean().item())
|
| 945 |
-
margin = gathered_chosen_rewards - gathered_rejected_rewards
|
| 946 |
-
self.stats["rewards/margins"].append(margin.mean().item())
|
| 947 |
-
accuracy = margin > 0
|
| 948 |
-
self.stats["rewards/accuracies"].append(accuracy.float().mean().item())
|
| 949 |
-
self.stats["beta"].append(self.beta)
|
| 950 |
-
|
| 951 |
-
if (
|
| 952 |
-
self.args.torch_empty_cache_steps is not None
|
| 953 |
-
and self.state.global_step % self.args.torch_empty_cache_steps == 0
|
| 954 |
-
):
|
| 955 |
-
empty_cache()
|
| 956 |
-
|
| 957 |
-
kwargs = {}
|
| 958 |
-
|
| 959 |
-
# For LOMO optimizers you need to explicitly use the learnign rate
|
| 960 |
-
if self.args.optim in [OptimizerNames.LOMO, OptimizerNames.ADALOMO]:
|
| 961 |
-
kwargs["learning_rate"] = self._get_learning_rate()
|
| 962 |
-
|
| 963 |
-
if self.args.n_gpu > 1:
|
| 964 |
-
loss = loss.mean() # mean() to average on multi-gpu parallel training
|
| 965 |
-
|
| 966 |
-
if self.use_apex:
|
| 967 |
-
with amp.scale_loss(loss, self.optimizer) as scaled_loss:
|
| 968 |
-
scaled_loss.backward()
|
| 969 |
-
else:
|
| 970 |
-
self.accelerator.backward(loss, **kwargs)
|
| 971 |
-
|
| 972 |
-
return loss.detach() / self.args.gradient_accumulation_steps
|
| 973 |
-
|
| 974 |
-
# Same as Trainer._maybe_log_save_evaluate but log our metrics
|
| 975 |
-
# start_time defaults to None to allow compatibility with transformers<=4.46
|
| 976 |
-
def _maybe_log_save_evaluate(
|
| 977 |
-
self, tr_loss, grad_norm, model, trial, epoch, ignore_keys_for_eval, start_time=None, learning_rate=None
|
| 978 |
-
):
|
| 979 |
-
if self.control.should_log and self.state.global_step > self._globalstep_last_logged:
|
| 980 |
-
logs: dict[str, float] = {}
|
| 981 |
-
|
| 982 |
-
# all_gather + mean() to get average loss over all processes
|
| 983 |
-
tr_loss_scalar = self._nested_gather(tr_loss).mean().item()
|
| 984 |
-
|
| 985 |
-
# reset tr_loss to zero
|
| 986 |
-
tr_loss -= tr_loss
|
| 987 |
-
|
| 988 |
-
logs["loss"] = round(tr_loss_scalar / (self.state.global_step - self._globalstep_last_logged), 4)
|
| 989 |
-
if grad_norm is not None:
|
| 990 |
-
logs["grad_norm"] = grad_norm.detach().item() if isinstance(grad_norm, torch.Tensor) else grad_norm
|
| 991 |
-
if learning_rate is not None:
|
| 992 |
-
logs["learning_rate"] = learning_rate
|
| 993 |
-
else:
|
| 994 |
-
logs["learning_rate"] = self._get_learning_rate()
|
| 995 |
-
|
| 996 |
-
# Add our metrics
|
| 997 |
-
for key, val in self.stats.items():
|
| 998 |
-
logs[key] = sum(val) / len(val)
|
| 999 |
-
self.stats = {key: [] for key in self.stats} # reset stats
|
| 1000 |
-
|
| 1001 |
-
self._total_loss_scalar += tr_loss_scalar
|
| 1002 |
-
self._globalstep_last_logged = self.state.global_step
|
| 1003 |
-
self.store_flos()
|
| 1004 |
-
|
| 1005 |
-
if version.parse(transformers.__version__) >= version.parse("4.47.0.dev0"):
|
| 1006 |
-
self.log(logs, start_time)
|
| 1007 |
-
else: # transformers<=4.46
|
| 1008 |
-
self.log(logs)
|
| 1009 |
-
|
| 1010 |
-
metrics = None
|
| 1011 |
-
if self.control.should_evaluate:
|
| 1012 |
-
metrics = self._evaluate(trial, ignore_keys_for_eval)
|
| 1013 |
-
is_new_best_metric = self._determine_best_metric(metrics=metrics, trial=trial)
|
| 1014 |
-
|
| 1015 |
-
if self.args.save_strategy == "best":
|
| 1016 |
-
self.control.should_save = is_new_best_metric
|
| 1017 |
-
|
| 1018 |
-
if self.control.should_save:
|
| 1019 |
-
self._save_checkpoint(model, trial)
|
| 1020 |
-
self.control = self.callback_handler.on_save(self.args, self.state, self.control)
|
| 1021 |
-
|
| 1022 |
-
# Copy-pasted from transformers.Trainer to maintain compatibility with earlier versions.
|
| 1023 |
-
# This can be removed once the minimum transformers version is updated to 4.47.
|
| 1024 |
-
# Refer to https://github.com/huggingface/trl/pull/2288 for more details.
|
| 1025 |
-
def _determine_best_metric(self, metrics, trial):
|
| 1026 |
-
"""
|
| 1027 |
-
Determine if the model should be saved based on the evaluation metrics.
|
| 1028 |
-
If args.metric_for_best_model is not set, the loss is used.
|
| 1029 |
-
Returns:
|
| 1030 |
-
bool: True if a new best metric was found, else False
|
| 1031 |
-
"""
|
| 1032 |
-
is_new_best_metric = False
|
| 1033 |
-
|
| 1034 |
-
if self.args.metric_for_best_model is not None:
|
| 1035 |
-
metric_to_check = self.args.metric_for_best_model
|
| 1036 |
-
|
| 1037 |
-
if not metric_to_check.startswith("eval_"):
|
| 1038 |
-
metric_to_check = f"eval_{metric_to_check}"
|
| 1039 |
-
|
| 1040 |
-
try:
|
| 1041 |
-
metric_value = metrics[metric_to_check]
|
| 1042 |
-
except KeyError as exc:
|
| 1043 |
-
raise KeyError(
|
| 1044 |
-
f"The `metric_for_best_model` training argument is set to '{metric_to_check}', which is not found in the evaluation metrics. "
|
| 1045 |
-
f"The available evaluation metrics are: {list(metrics.keys())}. Consider changing the `metric_for_best_model` via the TrainingArguments."
|
| 1046 |
-
) from exc
|
| 1047 |
-
|
| 1048 |
-
operator = np.greater if self.args.greater_is_better else np.less
|
| 1049 |
-
|
| 1050 |
-
if self.state.best_metric is None:
|
| 1051 |
-
self.state.best_metric = float("-inf") if self.args.greater_is_better else float("inf")
|
| 1052 |
-
|
| 1053 |
-
if operator(metric_value, self.state.best_metric):
|
| 1054 |
-
run_dir = self._get_output_dir(trial=trial)
|
| 1055 |
-
checkpoint_folder = f"{PREFIX_CHECKPOINT_DIR}-{self.state.global_step}"
|
| 1056 |
-
output_dir = os.path.join(run_dir, checkpoint_folder)
|
| 1057 |
-
self.state.best_metric = metric_value
|
| 1058 |
-
self.state.best_model_checkpoint = output_dir
|
| 1059 |
-
|
| 1060 |
-
is_new_best_metric = True
|
| 1061 |
-
|
| 1062 |
-
return is_new_best_metric
|
| 1063 |
-
|
| 1064 |
-
def create_model_card(
|
| 1065 |
-
self,
|
| 1066 |
-
model_name: Optional[str] = None,
|
| 1067 |
-
dataset_name: Optional[str] = None,
|
| 1068 |
-
tags: Union[str, list[str], None] = None,
|
| 1069 |
-
):
|
| 1070 |
-
"""
|
| 1071 |
-
Creates a draft of a model card using the information available to the `Trainer`.
|
| 1072 |
-
|
| 1073 |
-
Args:
|
| 1074 |
-
model_name (`str` or `None`, *optional*, defaults to `None`):
|
| 1075 |
-
Name of the model.
|
| 1076 |
-
dataset_name (`str` or `None`, *optional*, defaults to `None`):
|
| 1077 |
-
Name of the dataset used for training.
|
| 1078 |
-
tags (`str`, `list[str]` or `None`, *optional*, defaults to `None`):
|
| 1079 |
-
Tags to be associated with the model card.
|
| 1080 |
-
"""
|
| 1081 |
-
if not self.is_world_process_zero():
|
| 1082 |
-
return
|
| 1083 |
-
|
| 1084 |
-
if hasattr(self.model.config, "_name_or_path") and not os.path.isdir(self.model.config._name_or_path):
|
| 1085 |
-
base_model = self.model.config._name_or_path
|
| 1086 |
-
else:
|
| 1087 |
-
base_model = None
|
| 1088 |
-
|
| 1089 |
-
tags = tags or []
|
| 1090 |
-
if isinstance(tags, str):
|
| 1091 |
-
tags = [tags]
|
| 1092 |
-
|
| 1093 |
-
if hasattr(self.model.config, "unsloth_version"):
|
| 1094 |
-
tags.append("unsloth")
|
| 1095 |
-
|
| 1096 |
-
citation = textwrap.dedent("""\
|
| 1097 |
-
@article{guo2024direct,
|
| 1098 |
-
title = {{Direct Language Model Alignment from Online AI Feedback}},
|
| 1099 |
-
author = {Shangmin Guo and Biao Zhang and Tianlin Liu and Tianqi Liu and Misha Khalman and Felipe Llinares and Alexandre Ram{\'{e}} and Thomas Mesnard and Yao Zhao and Bilal Piot and Johan Ferret and Mathieu Blondel},
|
| 1100 |
-
year = 2024,
|
| 1101 |
-
eprint = {arXiv:2402.04792}
|
| 1102 |
-
}""")
|
| 1103 |
-
|
| 1104 |
-
model_card = generate_model_card(
|
| 1105 |
-
base_model=base_model,
|
| 1106 |
-
model_name=model_name,
|
| 1107 |
-
hub_model_id=self.hub_model_id,
|
| 1108 |
-
dataset_name=dataset_name,
|
| 1109 |
-
tags=tags,
|
| 1110 |
-
wandb_url=wandb.run.get_url() if is_wandb_available() and wandb.run is not None else None,
|
| 1111 |
-
comet_url=get_comet_experiment_url(),
|
| 1112 |
-
trainer_name="Online DPO",
|
| 1113 |
-
trainer_citation=citation,
|
| 1114 |
-
paper_title="Direct Language Model Alignment from Online AI Feedback",
|
| 1115 |
-
paper_id="2402.04792",
|
| 1116 |
-
)
|
| 1117 |
-
model_card.save(os.path.join(self.args.output_dir, "README.md"))
|
| 1118 |
-
class UnslothOnlineDPOTrainer(_UnslothOnlineDPOTrainer):
|
| 1119 |
-
"""
|
| 1120 |
-
|
| 1121 |
-
Initialize OnlineDPOTrainer.
|
| 1122 |
-
|
| 1123 |
-
Args:
|
| 1124 |
-
model (`transformers.PreTrainedModel` or `torch.nn.Module`):
|
| 1125 |
-
The model to train, preferably an `AutoModelForCausalLM`.
|
| 1126 |
-
ref_model (`transformers.PreTrainedModel` or `torch.nn.Module` or `None`):
|
| 1127 |
-
The reference model to use for training. If None is specified, the reference model will be created from
|
| 1128 |
-
the model.
|
| 1129 |
-
reward_model (`transformers.PreTrainedModel` or `torch.nn.Module` or `None`):
|
| 1130 |
-
The reward model to score completions with, preferably an `AutoModelForSequenceClassification`.
|
| 1131 |
-
judge (`BasePairwiseJudge`):
|
| 1132 |
-
The judge to use for pairwise comparison of model completions.
|
| 1133 |
-
args (`OnlineDPOConfig`):
|
| 1134 |
-
The online DPO config arguments to use for training.
|
| 1135 |
-
data_collator (`transformers.DataCollator`):
|
| 1136 |
-
The data collator to use for training. If None is specified, the default data collator (`DPODataCollatorWithPadding`) will be used
|
| 1137 |
-
which will pad the sequences to the maximum length of the sequences in the batch, given a dataset of paired sequences.
|
| 1138 |
-
train_dataset (`datasets.Dataset`):
|
| 1139 |
-
The dataset to use for training.
|
| 1140 |
-
eval_dataset (`datasets.Dataset`):
|
| 1141 |
-
The dataset to use for evaluation.
|
| 1142 |
-
processing_class (`PreTrainedTokenizerBase` or `BaseImageProcessor` or `FeatureExtractionMixin` or `ProcessorMixin`, *optional*):
|
| 1143 |
-
Processing class used to process the data. If provided, will be used to automatically process the inputs
|
| 1144 |
-
for the model, and it will be saved along the model to make it easier to rerun an interrupted training or
|
| 1145 |
-
reuse the fine-tuned model.
|
| 1146 |
-
peft_config (`dict`):
|
| 1147 |
-
The peft config to use for training.
|
| 1148 |
-
compute_metrics (`Callable[[EvalPrediction], dict]`, *optional*):
|
| 1149 |
-
The function to use to compute the metrics. Must take a `EvalPrediction` and return
|
| 1150 |
-
a dictionary string to metric values.
|
| 1151 |
-
callbacks (`list[transformers.TrainerCallback]`):
|
| 1152 |
-
The callbacks to use for training.
|
| 1153 |
-
optimizers (`tuple[torch.optim.Optimizer, torch.optim.lr_scheduler.LambdaLR]`):
|
| 1154 |
-
The optimizer and scheduler to use for training.
|
| 1155 |
-
preprocess_logits_for_metrics (`Callable[[torch.Tensor, torch.Tensor], torch.Tensor]`):
|
| 1156 |
-
The function to use to preprocess the logits before computing the metrics.
|
| 1157 |
-
|
| 1158 |
-
"""
|
| 1159 |
-
def __init__(
|
| 1160 |
-
self,
|
| 1161 |
-
model,
|
| 1162 |
-
ref_model = None,
|
| 1163 |
-
reward_model = None,
|
| 1164 |
-
judge = None,
|
| 1165 |
-
args = None,
|
| 1166 |
-
data_collator = None,
|
| 1167 |
-
train_dataset = None,
|
| 1168 |
-
eval_dataset = None,
|
| 1169 |
-
processing_class = None,
|
| 1170 |
-
reward_processing_class = None,
|
| 1171 |
-
peft_config = None,
|
| 1172 |
-
compute_metrics = None,
|
| 1173 |
-
callbacks = None,
|
| 1174 |
-
preprocess_logits_for_metrics = None,
|
| 1175 |
-
**kwargs
|
| 1176 |
-
):
|
| 1177 |
-
if args is None: args = UnslothOnlineDPOConfig()
|
| 1178 |
-
use_bf16 = getattr(args, 'bf16', False)
|
| 1179 |
-
if type(use_bf16) is not bool: use_bf16 = False
|
| 1180 |
-
use_fp16 = getattr(args, 'fp16', False)
|
| 1181 |
-
if type(use_fp16) is not bool: use_fp16 = False
|
| 1182 |
-
force_float32 = False
|
| 1183 |
-
if os.environ.get('UNSLOTH_FORCE_FLOAT32', '0') == '1':
|
| 1184 |
-
print('Unsloth: Switching to float32 training since model cannot work with float16')
|
| 1185 |
-
force_float32 = True
|
| 1186 |
-
mixed_precision_dtype = os.environ.get('UNSLOTH_MIXED_PRECISION', 'float32')
|
| 1187 |
-
dtype = getattr(model.config, 'torch_dtype', None)
|
| 1188 |
-
if dtype is None: dtype = model.get_input_embeddings().dtype
|
| 1189 |
-
from unsloth_zoo.utils import _get_dtype
|
| 1190 |
-
dtype = _get_dtype(dtype)
|
| 1191 |
-
float16 = dtype == torch.float16
|
| 1192 |
-
if not force_float32 and (float16 and use_bf16): raise TypeError('Unsloth: Model is in float16 precision but you want to use bfloat16 precision. Set fp16 to `True` and bf16 to `False`')
|
| 1193 |
-
if not force_float32 and (not float16 and use_fp16): raise TypeError('Unsloth: Model is in bfloat16 precision but you want to use float16 precision. Set fp16 to `False` and bf16 to `True`')
|
| 1194 |
-
if force_float32:
|
| 1195 |
-
args.fp16 = False
|
| 1196 |
-
args.bf16 = False
|
| 1197 |
-
os.environ['ACCELERATE_MIXED_PRECISION'] = 'no'
|
| 1198 |
-
elif (not use_bf16 and not use_fp16) and mixed_precision_dtype == 'float32':
|
| 1199 |
-
args.fp16 = float16
|
| 1200 |
-
args.bf16 = not float16
|
| 1201 |
-
os.environ['ACCELERATE_MIXED_PRECISION'] = 'fp16' if float16 else 'bf16'
|
| 1202 |
-
if getattr(args, 'eval_dataset', None) is not None and getattr(args, 'eval_strategy', 'no') == 'no':
|
| 1203 |
-
args.eval_strategy = 'steps'
|
| 1204 |
-
if getattr(args, 'eval_steps', None) is None: args.eval_steps = 0.1
|
| 1205 |
-
ga_steps = getattr(args, 'gradient_accumulation_steps', None)
|
| 1206 |
-
if ga_steps is not None and ga_steps > 1:
|
| 1207 |
-
from transformers import __version__ as transformers_version
|
| 1208 |
-
if Version(transformers_version) <= Version('4.45.2'):
|
| 1209 |
-
print('**** Unsloth: Please use our fixed gradient_accumulation_steps by updating transformers, TRL and Unsloth!\n'
|
| 1210 |
-
'`pip install --upgrade --no-cache-dir --force-reinstall --no-deps unsloth transformers trl unsloth_zoo`')
|
| 1211 |
-
if getattr(args, 'eval_strategy', 'no') != 'no':
|
| 1212 |
-
eval_bsz = getattr(args, 'per_device_eval_batch_size', 8)
|
| 1213 |
-
if eval_bsz == 8 and args.per_device_train_batch_size < eval_bsz: args.per_device_eval_batch_size = args.per_device_train_batch_size
|
| 1214 |
-
if getattr(args, 'eval_accumulation_steps', None) is None and ga_steps is not None: args.eval_accumulation_steps = ga_steps
|
| 1215 |
-
fp16_full_eval = getattr(args, 'fp16_full_eval', False)
|
| 1216 |
-
if type(fp16_full_eval) is not bool: fp16_full_eval = False
|
| 1217 |
-
bf16_full_eval = getattr(args, 'bf16_full_eval', False)
|
| 1218 |
-
if type(bf16_full_eval) is not bool: bf16_full_eval = False
|
| 1219 |
-
if args.fp16 and bf16_full_eval: args.bf16_full_eval = False; args.fp16_full_eval = True
|
| 1220 |
-
if args.bf16 and fp16_full_eval: args.bf16_full_eval = True; args.fp16_full_eval = False
|
| 1221 |
-
if force_float32:
|
| 1222 |
-
args.bf16_full_eval = False
|
| 1223 |
-
args.fp16_full_eval = False
|
| 1224 |
-
elif os.environ.get('UNSLOTH_MIXED_PRECISION', 'float32') == 'bfloat16':
|
| 1225 |
-
args.bf16_full_eval = True
|
| 1226 |
-
args.fp16_full_eval = False
|
| 1227 |
-
elif not bf16_full_eval and not fp16_full_eval:
|
| 1228 |
-
args.bf16_full_eval = args.bf16
|
| 1229 |
-
args.fp16_full_eval = args.fp16
|
| 1230 |
-
_output_logits = False
|
| 1231 |
-
if locals().get('compute_metrics', None) is not None: _output_logits = True
|
| 1232 |
-
if locals().get('preprocess_logits_for_metrics', None) is not None: _output_logits = True
|
| 1233 |
-
if _output_logits:
|
| 1234 |
-
os.environ['UNSLOTH_RETURN_LOGITS'] = '1'
|
| 1235 |
-
if 'max_seq_length' not in locals() and not hasattr(args, 'max_seq_length'):
|
| 1236 |
-
pass
|
| 1237 |
-
else:
|
| 1238 |
-
model_max_seq_length = getattr(model, 'max_seq_length', None)
|
| 1239 |
-
args_max_seq_length = getattr(args, 'max_seq_length', None)
|
| 1240 |
-
if args_max_seq_length is None and model_max_seq_length is not None:
|
| 1241 |
-
max_seq_length = model.max_seq_length
|
| 1242 |
-
if hasattr(args, 'max_seq_length'): args.max_seq_length = max_seq_length
|
| 1243 |
-
if model is not None and hasattr(model, 'for_training'):
|
| 1244 |
-
model.for_training()
|
| 1245 |
-
if 'tokenizer' in locals() and hasattr(tokenizer, 'padding_side'): tokenizer.padding_side = 'right'
|
| 1246 |
-
if 'processing_class' in locals():
|
| 1247 |
-
if hasattr(processing_class, 'padding_side'): processing_class.padding_side = 'right'
|
| 1248 |
-
if hasattr(processing_class, 'tokenizer') and hasattr(processing_class.tokenizer, 'padding_side'): processing_class.tokenizer.padding_side = 'right'
|
| 1249 |
-
__tokenizer = processing_class if 'processing_class' in locals() else tokenizer
|
| 1250 |
-
from unsloth_zoo.vision_utils import UnslothVisionDataCollator
|
| 1251 |
-
if not isinstance(data_collator, UnslothVisionDataCollator):
|
| 1252 |
-
if isinstance(data_collator, DataCollatorForSeq2Seq) and 'labels' not in train_dataset.column_names:
|
| 1253 |
-
data_collator = TransformersDataCollatorForLanguageModeling(__tokenizer, mlm = False, mlm_probability = 0.0)
|
| 1254 |
-
elif isinstance(data_collator, TransformersDataCollatorForLanguageModeling) and 'labels' in train_dataset.column_names:
|
| 1255 |
-
data_collator = DataCollatorForSeq2Seq(__tokenizer)
|
| 1256 |
-
else:
|
| 1257 |
-
if hasattr(args, 'remove_unused_columns'): args.remove_unused_columns = False
|
| 1258 |
-
if hasattr(args, 'dataset_text_field'): args.dataset_text_field = ''
|
| 1259 |
-
if hasattr(args, 'dataset_kwargs'): args.dataset_kwargs = {'skip_prepare_dataset': True}
|
| 1260 |
-
if not isinstance(data_collator, UnslothVisionDataCollator):
|
| 1261 |
-
if not hasattr(__tokenizer, 'pad') and hasattr(__tokenizer, 'tokenizer'):
|
| 1262 |
-
if isinstance(data_collator, DataCollatorForSeq2Seq):
|
| 1263 |
-
data_collator = DataCollatorForSeq2Seq(__tokenizer.tokenizer)
|
| 1264 |
-
else:
|
| 1265 |
-
data_collator = TransformersDataCollatorForLanguageModeling(__tokenizer.tokenizer, mlm = False, mlm_probability = 0.0)
|
| 1266 |
-
other_metrics = []
|
| 1267 |
-
|
| 1268 |
-
from unsloth_zoo.logging_utils import PatchRLStatistics
|
| 1269 |
-
PatchRLStatistics('online_dpo_trainer', other_metrics)
|
| 1270 |
-
|
| 1271 |
-
super().__init__(
|
| 1272 |
-
model = model,
|
| 1273 |
-
ref_model = ref_model,
|
| 1274 |
-
reward_model = reward_model,
|
| 1275 |
-
judge = judge,
|
| 1276 |
-
args = args,
|
| 1277 |
-
data_collator = data_collator,
|
| 1278 |
-
train_dataset = train_dataset,
|
| 1279 |
-
eval_dataset = eval_dataset,
|
| 1280 |
-
processing_class = processing_class,
|
| 1281 |
-
reward_processing_class = reward_processing_class,
|
| 1282 |
-
peft_config = peft_config,
|
| 1283 |
-
compute_metrics = compute_metrics,
|
| 1284 |
-
callbacks = callbacks,
|
| 1285 |
-
preprocess_logits_for_metrics = preprocess_logits_for_metrics,**kwargs)
|
| 1286 |
-
if hasattr(self, 'neftune_hook_handle'):
|
| 1287 |
-
self.neftune_hook_handle.remove()
|
| 1288 |
-
if hasattr(self, 'neftune_hook_handle'): del self.neftune_hook_handle
|
| 1289 |
-
if getattr(args, 'neftune_noise_alpha', None) is not None:
|
| 1290 |
-
model.get_input_embeddings().neftune_noise_alpha = self.neftune_noise_alpha
|
| 1291 |
-
pass
|
| 1292 |
-
|
| 1293 |
-
pass
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
test_run_uploads/UnslothPPOTrainer.py
DELETED
|
@@ -1,1273 +0,0 @@
|
|
| 1 |
-
"""
|
| 2 |
-
2025.7.11
|
| 3 |
-
2025.7.11
|
| 4 |
-
4.54.1
|
| 5 |
-
0.16.1
|
| 6 |
-
__UNSLOTH_VERSIONING__
|
| 7 |
-
"""
|
| 8 |
-
from torch import Tensor
|
| 9 |
-
import torch
|
| 10 |
-
import torch.nn as nn
|
| 11 |
-
from torch.nn import functional as F
|
| 12 |
-
from typing import Any, List, Optional, Tuple, Union, Dict, Set, Callable
|
| 13 |
-
from trl.trainer.ppo_trainer import (Accelerator, BaseImageProcessor, CallbackHandler, DEFAULT_CALLBACKS, DEFAULT_PROGRESS_CALLBACK, DataCollatorWithPadding, DataLoader, Dataset, ExportableState, FeatureExtractionMixin, GenerationConfig, INVALID_LOGPROB, OnlineTrainerState, Optional, PPOConfig, PPOTrainer, PeftConfig, PeftModel, PolicyAndValueWrapper, PreTrainedTokenizerBase, PrinterCallback, ProcessorMixin, Trainer, TrainerCallback, TrainerControl, Union, batch_generation, broadcast, contextmanager, create_reference_model, defaultdict, disable_dropout_in_model, exact_div, first_true_indices, forward, gather_object, gc, generate_model_card, get_comet_experiment_url, get_peft_model, get_reporting_integration_callbacks, get_reward, is_peft_available, is_wandb_available, log_table_to_comet_experiment, masked_mean, masked_whiten, math, nn, np, nullcontext, os, pd, peft_module_casting_to_bf16, prepare_deepspeed, print_rich_table, selective_log_softmax, textwrap, time, torch, truncate_response, unwrap_model_for_generation, wandb)
|
| 14 |
-
|
| 15 |
-
|
| 16 |
-
import os
|
| 17 |
-
from typing import *
|
| 18 |
-
from dataclasses import dataclass, field
|
| 19 |
-
from packaging.version import Version
|
| 20 |
-
import torch
|
| 21 |
-
import numpy as np
|
| 22 |
-
from contextlib import nullcontext
|
| 23 |
-
from torch.nn import functional as F
|
| 24 |
-
from transformers import DataCollatorForSeq2Seq, DataCollatorForLanguageModeling as TransformersDataCollatorForLanguageModeling
|
| 25 |
-
|
| 26 |
-
torch_compile_options = {
|
| 27 |
-
"epilogue_fusion" : True,
|
| 28 |
-
"max_autotune" : False,
|
| 29 |
-
"shape_padding" : True,
|
| 30 |
-
"trace.enabled" : False,
|
| 31 |
-
"triton.cudagraphs" : False,
|
| 32 |
-
}
|
| 33 |
-
|
| 34 |
-
@torch.compile(dynamic = True, fullgraph = True, options = torch_compile_options,)
|
| 35 |
-
def chunked_selective_log_softmax(logits, index):
|
| 36 |
-
# Split into 4 chunks only
|
| 37 |
-
chunked_logits = torch.chunk(logits.reshape(-1, logits.shape[-1]), chunks = 4, dim = 0)
|
| 38 |
-
chunked_index = torch.chunk(index.reshape(-1), chunks = 4, dim = 0)
|
| 39 |
-
all_per_token_logps = []
|
| 40 |
-
# Below loop does the same as selective_log_softmax(chunk_logits, chunk_index)
|
| 41 |
-
for chunk_logits, chunk_index in zip(chunked_logits, chunked_index):
|
| 42 |
-
chunk_logits = chunk_logits.to(torch.float32)
|
| 43 |
-
selected_logits = torch.gather(chunk_logits, dim = -1, index = chunk_index.unsqueeze(-1)).squeeze(-1)
|
| 44 |
-
logsumexp_values = torch.logsumexp(chunk_logits, dim = -1)
|
| 45 |
-
per_token_logps = selected_logits - logsumexp_values
|
| 46 |
-
all_per_token_logps.append(per_token_logps)
|
| 47 |
-
pass
|
| 48 |
-
all_per_token_logps = torch.concat(all_per_token_logps)
|
| 49 |
-
all_per_token_logps = all_per_token_logps.reshape((logits.shape[0], logits.shape[1]))
|
| 50 |
-
return all_per_token_logps
|
| 51 |
-
@dataclass
|
| 52 |
-
class UnslothPPOConfig(PPOConfig):
|
| 53 |
-
"""
|
| 54 |
-
|
| 55 |
-
Configuration class for the [`PPOTrainer`].
|
| 56 |
-
|
| 57 |
-
Using [`~transformers.HfArgumentParser`] we can turn this class into
|
| 58 |
-
[argparse](https://docs.python.org/3/library/argparse#module-argparse) arguments that can be specified on the
|
| 59 |
-
command line.
|
| 60 |
-
|
| 61 |
-
Parameters:
|
| 62 |
-
exp_name (`str`, *optional*, defaults to `os.path.basename(__file__)[:-3]`):
|
| 63 |
-
Name of this experiment.
|
| 64 |
-
reward_model_path (`str`, *optional*, defaults to `"EleutherAI/pythia-160m"`):
|
| 65 |
-
Path to the reward model.
|
| 66 |
-
model_adapter_name (`str` or `None`, *optional*, defaults to `None`):
|
| 67 |
-
Name of the train target PEFT adapter, when using LoRA with multiple adapters.
|
| 68 |
-
ref_adapter_name (`str` or `None`, *optional*, defaults to `None`):
|
| 69 |
-
Name of the reference PEFT adapter, when using LoRA with multiple adapters.
|
| 70 |
-
num_ppo_epochs (`int`, *optional*, defaults to `4`):
|
| 71 |
-
Number of epochs to train.
|
| 72 |
-
whiten_rewards (`bool`, *optional*, defaults to `False`):
|
| 73 |
-
Whether to whiten the rewards.
|
| 74 |
-
kl_coef (`float`, *optional*, defaults to `0.05`):
|
| 75 |
-
KL coefficient.
|
| 76 |
-
cliprange (`float`, *optional*, defaults to `0.2`):
|
| 77 |
-
Clip range.
|
| 78 |
-
vf_coef (`float`, *optional*, defaults to `0.1`):
|
| 79 |
-
Value function coefficient.
|
| 80 |
-
cliprange_value (`float`, *optional*, defaults to `0.2`):
|
| 81 |
-
Clip range for the value function.
|
| 82 |
-
gamma (`float`, *optional*, defaults to `1.0`):
|
| 83 |
-
Discount factor.
|
| 84 |
-
lam (`float`, *optional*, defaults to `0.95`):
|
| 85 |
-
Lambda value for GAE.
|
| 86 |
-
ds3_gather_for_generation (`bool`, *optional*, defaults to `True`):
|
| 87 |
-
This setting applies to DeepSpeed ZeRO-3. If enabled, the policy model weights are gathered for generation,
|
| 88 |
-
improving generation speed. However, disabling this option allows training models that exceed the VRAM
|
| 89 |
-
capacity of a single GPU, albeit at the cost of slower generation.
|
| 90 |
-
|
| 91 |
-
"""
|
| 92 |
-
vllm_sampling_params: Optional[Any] = field(
|
| 93 |
-
default = None,
|
| 94 |
-
metadata = {'help': 'vLLM SamplingParams'},
|
| 95 |
-
)
|
| 96 |
-
unsloth_num_chunks : Optional[int] = field(
|
| 97 |
-
default = -1,
|
| 98 |
-
metadata = {'help': 'Chunk size to reduce memory usage. -1 is most efficient.'},
|
| 99 |
-
)
|
| 100 |
-
def __init__(
|
| 101 |
-
self,
|
| 102 |
-
output_dir = None,
|
| 103 |
-
overwrite_output_dir = None,
|
| 104 |
-
do_train = False,
|
| 105 |
-
do_eval = False,
|
| 106 |
-
do_predict = False,
|
| 107 |
-
eval_strategy = 'no',
|
| 108 |
-
prediction_loss_only = False,
|
| 109 |
-
per_device_train_batch_size = 4,
|
| 110 |
-
per_device_eval_batch_size = 4,
|
| 111 |
-
per_gpu_train_batch_size = None,
|
| 112 |
-
per_gpu_eval_batch_size = None,
|
| 113 |
-
gradient_accumulation_steps = 2,
|
| 114 |
-
eval_accumulation_steps = 2,
|
| 115 |
-
eval_delay = 0,
|
| 116 |
-
torch_empty_cache_steps = 250,
|
| 117 |
-
learning_rate = 5e-05,
|
| 118 |
-
weight_decay = 0.01,
|
| 119 |
-
adam_beta1 = 0.9,
|
| 120 |
-
adam_beta2 = 0.999,
|
| 121 |
-
adam_epsilon = 1e-08,
|
| 122 |
-
max_grad_norm = 1.0,
|
| 123 |
-
num_train_epochs = 3.0,
|
| 124 |
-
max_steps = -1,
|
| 125 |
-
lr_scheduler_type = 'linear',
|
| 126 |
-
warmup_ratio = 0.1,
|
| 127 |
-
warmup_steps = 0,
|
| 128 |
-
log_level = 'passive',
|
| 129 |
-
log_level_replica = 'warning',
|
| 130 |
-
log_on_each_node = True,
|
| 131 |
-
logging_dir = None,
|
| 132 |
-
logging_strategy = 'steps',
|
| 133 |
-
logging_first_step = False,
|
| 134 |
-
logging_steps = 1,
|
| 135 |
-
logging_nan_inf_filter = False,
|
| 136 |
-
save_strategy = 'steps',
|
| 137 |
-
save_steps = 500,
|
| 138 |
-
save_total_limit = None,
|
| 139 |
-
save_safetensors = True,
|
| 140 |
-
save_on_each_node = False,
|
| 141 |
-
save_only_model = False,
|
| 142 |
-
restore_callback_states_from_checkpoint = False,
|
| 143 |
-
no_cuda = False,
|
| 144 |
-
use_cpu = False,
|
| 145 |
-
use_mps_device = False,
|
| 146 |
-
seed = 3407,
|
| 147 |
-
data_seed = 3407,
|
| 148 |
-
jit_mode_eval = False,
|
| 149 |
-
use_ipex = False,
|
| 150 |
-
bf16 = False,
|
| 151 |
-
fp16 = False,
|
| 152 |
-
fp16_opt_level = 'O1',
|
| 153 |
-
half_precision_backend = 'auto',
|
| 154 |
-
bf16_full_eval = False,
|
| 155 |
-
fp16_full_eval = False,
|
| 156 |
-
tf32 = None,
|
| 157 |
-
local_rank = -1,
|
| 158 |
-
ddp_backend = None,
|
| 159 |
-
tpu_num_cores = None,
|
| 160 |
-
tpu_metrics_debug = False,
|
| 161 |
-
debug = '',
|
| 162 |
-
dataloader_drop_last = False,
|
| 163 |
-
eval_steps = None,
|
| 164 |
-
dataloader_num_workers = 0,
|
| 165 |
-
dataloader_prefetch_factor = None,
|
| 166 |
-
past_index = -1,
|
| 167 |
-
run_name = None,
|
| 168 |
-
disable_tqdm = None,
|
| 169 |
-
remove_unused_columns = True,
|
| 170 |
-
label_names = None,
|
| 171 |
-
load_best_model_at_end = False,
|
| 172 |
-
metric_for_best_model = None,
|
| 173 |
-
greater_is_better = None,
|
| 174 |
-
ignore_data_skip = False,
|
| 175 |
-
fsdp = '',
|
| 176 |
-
fsdp_min_num_params = 0,
|
| 177 |
-
fsdp_config = None,
|
| 178 |
-
fsdp_transformer_layer_cls_to_wrap = None,
|
| 179 |
-
accelerator_config = None,
|
| 180 |
-
deepspeed = None,
|
| 181 |
-
label_smoothing_factor = 0.0,
|
| 182 |
-
optim = 'adamw_8bit',
|
| 183 |
-
optim_args = None,
|
| 184 |
-
adafactor = False,
|
| 185 |
-
group_by_length = False,
|
| 186 |
-
length_column_name = 'length',
|
| 187 |
-
report_to = None,
|
| 188 |
-
ddp_find_unused_parameters = None,
|
| 189 |
-
ddp_bucket_cap_mb = None,
|
| 190 |
-
ddp_broadcast_buffers = None,
|
| 191 |
-
dataloader_pin_memory = True,
|
| 192 |
-
dataloader_persistent_workers = False,
|
| 193 |
-
skip_memory_metrics = True,
|
| 194 |
-
use_legacy_prediction_loop = False,
|
| 195 |
-
push_to_hub = False,
|
| 196 |
-
resume_from_checkpoint = None,
|
| 197 |
-
hub_model_id = None,
|
| 198 |
-
hub_strategy = 'every_save',
|
| 199 |
-
hub_token = None,
|
| 200 |
-
hub_private_repo = None,
|
| 201 |
-
hub_always_push = False,
|
| 202 |
-
hub_revision = None,
|
| 203 |
-
gradient_checkpointing = False,
|
| 204 |
-
gradient_checkpointing_kwargs = None,
|
| 205 |
-
include_inputs_for_metrics = False,
|
| 206 |
-
eval_do_concat_batches = True,
|
| 207 |
-
fp16_backend = 'auto',
|
| 208 |
-
push_to_hub_model_id = None,
|
| 209 |
-
push_to_hub_organization = None,
|
| 210 |
-
push_to_hub_token = None,
|
| 211 |
-
mp_parameters = '',
|
| 212 |
-
auto_find_batch_size = True,
|
| 213 |
-
full_determinism = False,
|
| 214 |
-
torchdynamo = None,
|
| 215 |
-
ray_scope = 'last',
|
| 216 |
-
ddp_timeout = 1800,
|
| 217 |
-
torch_compile = False,
|
| 218 |
-
torch_compile_backend = None,
|
| 219 |
-
torch_compile_mode = None,
|
| 220 |
-
include_tokens_per_second = False,
|
| 221 |
-
include_num_input_tokens_seen = False,
|
| 222 |
-
neftune_noise_alpha = None,
|
| 223 |
-
optim_target_modules = None,
|
| 224 |
-
batch_eval_metrics = False,
|
| 225 |
-
eval_on_start = False,
|
| 226 |
-
use_liger_kernel = False,
|
| 227 |
-
liger_kernel_config = None,
|
| 228 |
-
eval_use_gather_object = False,
|
| 229 |
-
average_tokens_across_devices = True,
|
| 230 |
-
dataset_num_proc = None,
|
| 231 |
-
num_mini_batches = 1,
|
| 232 |
-
total_episodes = None,
|
| 233 |
-
local_rollout_forward_batch_size = 64,
|
| 234 |
-
num_sample_generations = 10,
|
| 235 |
-
response_length = 53,
|
| 236 |
-
stop_token = None,
|
| 237 |
-
stop_token_id = None,
|
| 238 |
-
temperature = 0.7,
|
| 239 |
-
missing_eos_penalty = None,
|
| 240 |
-
sft_model_path = 'EleutherAI/pythia-160m',
|
| 241 |
-
world_size = None,
|
| 242 |
-
num_total_batches = None,
|
| 243 |
-
micro_batch_size = None,
|
| 244 |
-
local_batch_size = None,
|
| 245 |
-
batch_size = None,
|
| 246 |
-
local_mini_batch_size = None,
|
| 247 |
-
mini_batch_size = None,
|
| 248 |
-
exp_name = 'ppo_config',
|
| 249 |
-
reward_model_path = 'EleutherAI/pythia-160m',
|
| 250 |
-
model_adapter_name = None,
|
| 251 |
-
ref_adapter_name = None,
|
| 252 |
-
num_ppo_epochs = 4,
|
| 253 |
-
whiten_rewards = False,
|
| 254 |
-
kl_coef = 0.05,
|
| 255 |
-
cliprange = 0.2,
|
| 256 |
-
vf_coef = 0.1,
|
| 257 |
-
cliprange_value = 0.2,
|
| 258 |
-
gamma = 1.0,
|
| 259 |
-
lam = 0.95,
|
| 260 |
-
ds3_gather_for_generation = True,
|
| 261 |
-
vllm_sampling_params = None,
|
| 262 |
-
unsloth_num_chunks = -1,
|
| 263 |
-
**kwargs,
|
| 264 |
-
):
|
| 265 |
-
if learning_rate < 1e-7: raise FloatingPointError(f'Unsloth: Your learning rate of `{learning_rate}` is too small and less than 1e-7! Consider increasing it, otherwise gradient updates will be close to 0!')
|
| 266 |
-
if learning_rate > 1: raise OverflowError(f'Unsloth: Your learning rate of `{learning_rate}` is way too larger > 1! Consider decreasing it to 1e-1, otherwise gradient updates will explode!')
|
| 267 |
-
if output_dir is None and save_strategy == 'steps' and save_steps == 500:
|
| 268 |
-
output_dir = 'unsloth_training_checkpoints'
|
| 269 |
-
save_strategy = 'no'
|
| 270 |
-
if dataset_num_proc is None:
|
| 271 |
-
from multiprocessing import cpu_count
|
| 272 |
-
dataset_num_proc = min(cpu_count()*2, 2)
|
| 273 |
-
if temperature <= 0:
|
| 274 |
-
raise MathError('Unsloth: Please set a positive non-zero temperature since your results will be wrong.')
|
| 275 |
-
elif temperature >= 10:
|
| 276 |
-
raise MathError('Unsloth: Please set a positive non-zero temperature less than 10, since sampling will be quite erratic.')
|
| 277 |
-
|
| 278 |
-
|
| 279 |
-
super().__init__(
|
| 280 |
-
output_dir = output_dir,
|
| 281 |
-
overwrite_output_dir = overwrite_output_dir,
|
| 282 |
-
do_train = do_train,
|
| 283 |
-
do_eval = do_eval,
|
| 284 |
-
do_predict = do_predict,
|
| 285 |
-
eval_strategy = eval_strategy,
|
| 286 |
-
prediction_loss_only = prediction_loss_only,
|
| 287 |
-
per_device_train_batch_size = per_device_train_batch_size,
|
| 288 |
-
per_device_eval_batch_size = per_device_eval_batch_size,
|
| 289 |
-
per_gpu_train_batch_size = per_gpu_train_batch_size,
|
| 290 |
-
per_gpu_eval_batch_size = per_gpu_eval_batch_size,
|
| 291 |
-
gradient_accumulation_steps = gradient_accumulation_steps,
|
| 292 |
-
eval_accumulation_steps = eval_accumulation_steps,
|
| 293 |
-
eval_delay = eval_delay,
|
| 294 |
-
torch_empty_cache_steps = torch_empty_cache_steps,
|
| 295 |
-
learning_rate = learning_rate,
|
| 296 |
-
weight_decay = weight_decay,
|
| 297 |
-
adam_beta1 = adam_beta1,
|
| 298 |
-
adam_beta2 = adam_beta2,
|
| 299 |
-
adam_epsilon = adam_epsilon,
|
| 300 |
-
max_grad_norm = max_grad_norm,
|
| 301 |
-
num_train_epochs = num_train_epochs,
|
| 302 |
-
max_steps = max_steps,
|
| 303 |
-
lr_scheduler_type = lr_scheduler_type,
|
| 304 |
-
warmup_ratio = warmup_ratio,
|
| 305 |
-
warmup_steps = warmup_steps,
|
| 306 |
-
log_level = log_level,
|
| 307 |
-
log_level_replica = log_level_replica,
|
| 308 |
-
log_on_each_node = log_on_each_node,
|
| 309 |
-
logging_dir = logging_dir,
|
| 310 |
-
logging_strategy = logging_strategy,
|
| 311 |
-
logging_first_step = logging_first_step,
|
| 312 |
-
logging_steps = logging_steps,
|
| 313 |
-
logging_nan_inf_filter = logging_nan_inf_filter,
|
| 314 |
-
save_strategy = save_strategy,
|
| 315 |
-
save_steps = save_steps,
|
| 316 |
-
save_total_limit = save_total_limit,
|
| 317 |
-
save_safetensors = save_safetensors,
|
| 318 |
-
save_on_each_node = save_on_each_node,
|
| 319 |
-
save_only_model = save_only_model,
|
| 320 |
-
restore_callback_states_from_checkpoint = restore_callback_states_from_checkpoint,
|
| 321 |
-
no_cuda = no_cuda,
|
| 322 |
-
use_cpu = use_cpu,
|
| 323 |
-
use_mps_device = use_mps_device,
|
| 324 |
-
seed = seed,
|
| 325 |
-
data_seed = data_seed,
|
| 326 |
-
jit_mode_eval = jit_mode_eval,
|
| 327 |
-
use_ipex = use_ipex,
|
| 328 |
-
bf16 = bf16,
|
| 329 |
-
fp16 = fp16,
|
| 330 |
-
fp16_opt_level = fp16_opt_level,
|
| 331 |
-
half_precision_backend = half_precision_backend,
|
| 332 |
-
bf16_full_eval = bf16_full_eval,
|
| 333 |
-
fp16_full_eval = fp16_full_eval,
|
| 334 |
-
tf32 = tf32,
|
| 335 |
-
local_rank = local_rank,
|
| 336 |
-
ddp_backend = ddp_backend,
|
| 337 |
-
tpu_num_cores = tpu_num_cores,
|
| 338 |
-
tpu_metrics_debug = tpu_metrics_debug,
|
| 339 |
-
debug = debug,
|
| 340 |
-
dataloader_drop_last = dataloader_drop_last,
|
| 341 |
-
eval_steps = eval_steps,
|
| 342 |
-
dataloader_num_workers = dataloader_num_workers,
|
| 343 |
-
dataloader_prefetch_factor = dataloader_prefetch_factor,
|
| 344 |
-
past_index = past_index,
|
| 345 |
-
run_name = run_name,
|
| 346 |
-
disable_tqdm = disable_tqdm,
|
| 347 |
-
remove_unused_columns = remove_unused_columns,
|
| 348 |
-
label_names = label_names,
|
| 349 |
-
load_best_model_at_end = load_best_model_at_end,
|
| 350 |
-
metric_for_best_model = metric_for_best_model,
|
| 351 |
-
greater_is_better = greater_is_better,
|
| 352 |
-
ignore_data_skip = ignore_data_skip,
|
| 353 |
-
fsdp = fsdp,
|
| 354 |
-
fsdp_min_num_params = fsdp_min_num_params,
|
| 355 |
-
fsdp_config = fsdp_config,
|
| 356 |
-
fsdp_transformer_layer_cls_to_wrap = fsdp_transformer_layer_cls_to_wrap,
|
| 357 |
-
accelerator_config = accelerator_config,
|
| 358 |
-
deepspeed = deepspeed,
|
| 359 |
-
label_smoothing_factor = label_smoothing_factor,
|
| 360 |
-
optim = optim,
|
| 361 |
-
optim_args = optim_args,
|
| 362 |
-
adafactor = adafactor,
|
| 363 |
-
group_by_length = group_by_length,
|
| 364 |
-
length_column_name = length_column_name,
|
| 365 |
-
report_to = report_to,
|
| 366 |
-
ddp_find_unused_parameters = ddp_find_unused_parameters,
|
| 367 |
-
ddp_bucket_cap_mb = ddp_bucket_cap_mb,
|
| 368 |
-
ddp_broadcast_buffers = ddp_broadcast_buffers,
|
| 369 |
-
dataloader_pin_memory = dataloader_pin_memory,
|
| 370 |
-
dataloader_persistent_workers = dataloader_persistent_workers,
|
| 371 |
-
skip_memory_metrics = skip_memory_metrics,
|
| 372 |
-
use_legacy_prediction_loop = use_legacy_prediction_loop,
|
| 373 |
-
push_to_hub = push_to_hub,
|
| 374 |
-
resume_from_checkpoint = resume_from_checkpoint,
|
| 375 |
-
hub_model_id = hub_model_id,
|
| 376 |
-
hub_strategy = hub_strategy,
|
| 377 |
-
hub_token = hub_token,
|
| 378 |
-
hub_private_repo = hub_private_repo,
|
| 379 |
-
hub_always_push = hub_always_push,
|
| 380 |
-
hub_revision = hub_revision,
|
| 381 |
-
gradient_checkpointing = gradient_checkpointing,
|
| 382 |
-
gradient_checkpointing_kwargs = gradient_checkpointing_kwargs,
|
| 383 |
-
include_inputs_for_metrics = include_inputs_for_metrics,
|
| 384 |
-
eval_do_concat_batches = eval_do_concat_batches,
|
| 385 |
-
fp16_backend = fp16_backend,
|
| 386 |
-
push_to_hub_model_id = push_to_hub_model_id,
|
| 387 |
-
push_to_hub_organization = push_to_hub_organization,
|
| 388 |
-
push_to_hub_token = push_to_hub_token,
|
| 389 |
-
mp_parameters = mp_parameters,
|
| 390 |
-
auto_find_batch_size = auto_find_batch_size,
|
| 391 |
-
full_determinism = full_determinism,
|
| 392 |
-
torchdynamo = torchdynamo,
|
| 393 |
-
ray_scope = ray_scope,
|
| 394 |
-
ddp_timeout = ddp_timeout,
|
| 395 |
-
torch_compile = torch_compile,
|
| 396 |
-
torch_compile_backend = torch_compile_backend,
|
| 397 |
-
torch_compile_mode = torch_compile_mode,
|
| 398 |
-
include_tokens_per_second = include_tokens_per_second,
|
| 399 |
-
include_num_input_tokens_seen = include_num_input_tokens_seen,
|
| 400 |
-
neftune_noise_alpha = neftune_noise_alpha,
|
| 401 |
-
optim_target_modules = optim_target_modules,
|
| 402 |
-
batch_eval_metrics = batch_eval_metrics,
|
| 403 |
-
eval_on_start = eval_on_start,
|
| 404 |
-
use_liger_kernel = use_liger_kernel,
|
| 405 |
-
liger_kernel_config = liger_kernel_config,
|
| 406 |
-
eval_use_gather_object = eval_use_gather_object,
|
| 407 |
-
average_tokens_across_devices = average_tokens_across_devices,
|
| 408 |
-
dataset_num_proc = dataset_num_proc,
|
| 409 |
-
num_mini_batches = num_mini_batches,
|
| 410 |
-
total_episodes = total_episodes,
|
| 411 |
-
local_rollout_forward_batch_size = local_rollout_forward_batch_size,
|
| 412 |
-
num_sample_generations = num_sample_generations,
|
| 413 |
-
response_length = response_length,
|
| 414 |
-
stop_token = stop_token,
|
| 415 |
-
stop_token_id = stop_token_id,
|
| 416 |
-
temperature = temperature,
|
| 417 |
-
missing_eos_penalty = missing_eos_penalty,
|
| 418 |
-
sft_model_path = sft_model_path,
|
| 419 |
-
world_size = world_size,
|
| 420 |
-
num_total_batches = num_total_batches,
|
| 421 |
-
micro_batch_size = micro_batch_size,
|
| 422 |
-
local_batch_size = local_batch_size,
|
| 423 |
-
batch_size = batch_size,
|
| 424 |
-
local_mini_batch_size = local_mini_batch_size,
|
| 425 |
-
mini_batch_size = mini_batch_size,
|
| 426 |
-
exp_name = exp_name,
|
| 427 |
-
reward_model_path = reward_model_path,
|
| 428 |
-
model_adapter_name = model_adapter_name,
|
| 429 |
-
ref_adapter_name = ref_adapter_name,
|
| 430 |
-
num_ppo_epochs = num_ppo_epochs,
|
| 431 |
-
whiten_rewards = whiten_rewards,
|
| 432 |
-
kl_coef = kl_coef,
|
| 433 |
-
cliprange = cliprange,
|
| 434 |
-
vf_coef = vf_coef,
|
| 435 |
-
cliprange_value = cliprange_value,
|
| 436 |
-
gamma = gamma,
|
| 437 |
-
lam = lam,
|
| 438 |
-
ds3_gather_for_generation = ds3_gather_for_generation,**kwargs)
|
| 439 |
-
self.vllm_sampling_params = vllm_sampling_params
|
| 440 |
-
self.unsloth_num_chunks = unsloth_num_chunks
|
| 441 |
-
pass
|
| 442 |
-
|
| 443 |
-
class _UnslothPPOTrainer(Trainer):
|
| 444 |
-
_tag_names = ["trl", "ppo"]
|
| 445 |
-
|
| 446 |
-
def __init__(
|
| 447 |
-
self,
|
| 448 |
-
args: PPOConfig,
|
| 449 |
-
processing_class: Optional[
|
| 450 |
-
Union[PreTrainedTokenizerBase, BaseImageProcessor, FeatureExtractionMixin, ProcessorMixin]
|
| 451 |
-
],
|
| 452 |
-
model: nn.Module,
|
| 453 |
-
ref_model: Optional[nn.Module],
|
| 454 |
-
reward_model: nn.Module,
|
| 455 |
-
train_dataset: Dataset,
|
| 456 |
-
value_model: Optional[nn.Module] = None,
|
| 457 |
-
data_collator: Optional[DataCollatorWithPadding] = None,
|
| 458 |
-
eval_dataset: Optional[Union[Dataset, dict[str, Dataset]]] = None,
|
| 459 |
-
# less commonly used
|
| 460 |
-
optimizers: tuple[torch.optim.Optimizer, torch.optim.lr_scheduler.LambdaLR] = (None, None),
|
| 461 |
-
callbacks: Optional[list[TrainerCallback]] = None,
|
| 462 |
-
peft_config: Optional["PeftConfig"] = None,
|
| 463 |
-
) -> None:
|
| 464 |
-
if ref_model is model:
|
| 465 |
-
raise ValueError(
|
| 466 |
-
"`model` and `ref_model` cannot be the same object. If you want `ref_model` to be the "
|
| 467 |
-
"same as `model`, you must make a copy of it, or `None` if you use peft."
|
| 468 |
-
)
|
| 469 |
-
|
| 470 |
-
self.args = args
|
| 471 |
-
self.processing_class = processing_class
|
| 472 |
-
self.policy_model = model
|
| 473 |
-
|
| 474 |
-
# Define the collator if not provided
|
| 475 |
-
if data_collator is None:
|
| 476 |
-
data_collator = DataCollatorWithPadding(self.processing_class)
|
| 477 |
-
|
| 478 |
-
# Handle stop token settings: update policy model's generation_config to use provided stop token
|
| 479 |
-
if args.stop_token and args.stop_token_id:
|
| 480 |
-
raise ValueError("You cannot set both `stop_token` and `stop_token_id`.")
|
| 481 |
-
elif args.stop_token:
|
| 482 |
-
if args.stop_token == "eos":
|
| 483 |
-
self.policy_model.generation_config.eos_token_id = self.stop_token_id = processing_class.eos_token_id
|
| 484 |
-
else:
|
| 485 |
-
raise ValueError(
|
| 486 |
-
f"Unknown `stop_token` {args.stop_token}. Allowed values are: `'eos'` and `None` (no stop token)."
|
| 487 |
-
)
|
| 488 |
-
else:
|
| 489 |
-
self.policy_model.generation_config.eos_token_id = self.stop_token_id = args.stop_token_id # None or int
|
| 490 |
-
|
| 491 |
-
# peft support
|
| 492 |
-
if not is_peft_available() and peft_config is not None:
|
| 493 |
-
raise ImportError(
|
| 494 |
-
"PEFT is not installed and you passed a `peft_config` in the trainer's kwargs, please install it to use the PEFT models"
|
| 495 |
-
)
|
| 496 |
-
elif is_peft_available() and peft_config is not None:
|
| 497 |
-
# if model is a peft model and we have a peft_confg, we merge and unload it first
|
| 498 |
-
if isinstance(self.policy_model, PeftModel):
|
| 499 |
-
self.policy_model = self.policy_model.merge_and_unload()
|
| 500 |
-
|
| 501 |
-
# get peft model with the given config
|
| 502 |
-
self.policy_model = get_peft_model(self.policy_model, peft_config)
|
| 503 |
-
if args.bf16 and getattr(self.policy_model, "is_loaded_in_4bit", False):
|
| 504 |
-
peft_module_casting_to_bf16(self.policy_model)
|
| 505 |
-
|
| 506 |
-
self.is_peft_model = is_peft_available() and isinstance(self.policy_model, PeftModel)
|
| 507 |
-
self.model_adapter_name = args.model_adapter_name
|
| 508 |
-
self.ref_adapter_name = args.ref_adapter_name
|
| 509 |
-
|
| 510 |
-
if ref_model:
|
| 511 |
-
self.ref_model = ref_model
|
| 512 |
-
elif self.is_peft_model:
|
| 513 |
-
self.ref_model = None
|
| 514 |
-
else:
|
| 515 |
-
self.ref_model = create_reference_model(self.policy_model)
|
| 516 |
-
|
| 517 |
-
self.reward_model = reward_model
|
| 518 |
-
self.train_dataset = train_dataset
|
| 519 |
-
self.train_dataset_len = len(train_dataset)
|
| 520 |
-
self.value_model = value_model
|
| 521 |
-
self.data_collator = data_collator
|
| 522 |
-
self.eval_dataset = eval_dataset
|
| 523 |
-
self.optimizer, self.lr_scheduler = optimizers
|
| 524 |
-
self.optimizer_cls_and_kwargs = None # needed for transformers >= 4.47
|
| 525 |
-
|
| 526 |
-
#########
|
| 527 |
-
# calculate various batch sizes
|
| 528 |
-
#########
|
| 529 |
-
if args.total_episodes is None: # allow the users to define episodes in terms of epochs.
|
| 530 |
-
args.total_episodes = int(args.num_train_epochs * self.train_dataset_len)
|
| 531 |
-
accelerator = Accelerator(gradient_accumulation_steps=args.gradient_accumulation_steps)
|
| 532 |
-
self.accelerator = accelerator
|
| 533 |
-
args.world_size = accelerator.num_processes
|
| 534 |
-
args.local_batch_size = (
|
| 535 |
-
args.per_device_train_batch_size * args.gradient_accumulation_steps * args.num_mini_batches
|
| 536 |
-
)
|
| 537 |
-
args.micro_batch_size = int(args.per_device_train_batch_size * args.world_size)
|
| 538 |
-
args.batch_size = int(args.local_batch_size * args.world_size)
|
| 539 |
-
args.mini_batch_size = exact_div(
|
| 540 |
-
args.batch_size, args.num_mini_batches, "`batch_size` must be a multiple of `num_mini_batches`"
|
| 541 |
-
)
|
| 542 |
-
args.local_mini_batch_size = exact_div(
|
| 543 |
-
args.local_batch_size, args.num_mini_batches, "`local_batch_size` must be a multiple of `num_mini_batches`"
|
| 544 |
-
)
|
| 545 |
-
if args.whiten_rewards:
|
| 546 |
-
assert args.local_mini_batch_size >= 8, (
|
| 547 |
-
f"Per-rank minibatch size {args.local_mini_batch_size} is insufficient for whitening"
|
| 548 |
-
)
|
| 549 |
-
# `per_rank_rollout_batch_size` is our `args.local_batch_size`
|
| 550 |
-
# `per_rank_minibatch_size` is our `args.local_mini_batch_size`
|
| 551 |
-
args.num_total_batches = math.ceil(
|
| 552 |
-
args.total_episodes / args.batch_size
|
| 553 |
-
) # we may train for more than `total_episodes`
|
| 554 |
-
time_tensor = torch.tensor(int(time.time()), device=accelerator.device)
|
| 555 |
-
time_int = broadcast(time_tensor, 0).item() # avoid different timestamps across processes
|
| 556 |
-
args.run_name = f"{args.exp_name}__{args.seed}__{time_int}"
|
| 557 |
-
self.local_seed = args.seed + accelerator.process_index * 100003 # Prime
|
| 558 |
-
if args.num_sample_generations > 0:
|
| 559 |
-
self.sample_generations_freq = max(1, args.num_total_batches // args.num_sample_generations)
|
| 560 |
-
self.local_dataloader_batch_size = args.local_batch_size
|
| 561 |
-
|
| 562 |
-
#########
|
| 563 |
-
# setup model, optimizer, and others
|
| 564 |
-
#########
|
| 565 |
-
for module in [self.policy_model, self.ref_model, self.value_model, self.reward_model]:
|
| 566 |
-
if module is not None:
|
| 567 |
-
disable_dropout_in_model(module)
|
| 568 |
-
self.model = PolicyAndValueWrapper(self.policy_model, self.value_model)
|
| 569 |
-
self.model.config = self.policy_model.config # needed for pushing to hub
|
| 570 |
-
self.create_optimizer_and_scheduler(
|
| 571 |
-
num_training_steps=args.num_total_batches
|
| 572 |
-
) # note that we are calling `self.lr_scheduler.step[]` manually only at the batch level
|
| 573 |
-
|
| 574 |
-
#########
|
| 575 |
-
### trainer specifics
|
| 576 |
-
#########
|
| 577 |
-
default_callbacks = DEFAULT_CALLBACKS + get_reporting_integration_callbacks(self.args.report_to)
|
| 578 |
-
self.callbacks = default_callbacks if callbacks is None else default_callbacks + callbacks
|
| 579 |
-
self.callback_handler = CallbackHandler(
|
| 580 |
-
self.callbacks, self.model, self.processing_class, self.optimizer, self.lr_scheduler
|
| 581 |
-
)
|
| 582 |
-
self.add_callback(PrinterCallback if self.args.disable_tqdm else DEFAULT_PROGRESS_CALLBACK)
|
| 583 |
-
self.control = TrainerControl()
|
| 584 |
-
self.state = OnlineTrainerState(
|
| 585 |
-
is_local_process_zero=self.is_local_process_zero(),
|
| 586 |
-
is_world_process_zero=self.is_world_process_zero(),
|
| 587 |
-
stateful_callbacks=[
|
| 588 |
-
cb for cb in self.callback_handler.callbacks + [self.control] if isinstance(cb, ExportableState)
|
| 589 |
-
],
|
| 590 |
-
)
|
| 591 |
-
self.current_flos = 0
|
| 592 |
-
self.hp_search_backend = None
|
| 593 |
-
self.is_deepspeed_enabled = getattr(self.accelerator.state, "deepspeed_plugin", None) is not None
|
| 594 |
-
self.is_fsdp_enabled = getattr(self.accelerator.state, "fsdp_plugin", None) is not None
|
| 595 |
-
# Create distant repo and output directory if needed
|
| 596 |
-
self.hub_model_id = None
|
| 597 |
-
if self.args.push_to_hub:
|
| 598 |
-
self.init_hf_repo()
|
| 599 |
-
if self.args.should_save:
|
| 600 |
-
os.makedirs(self.args.output_dir, exist_ok=True)
|
| 601 |
-
|
| 602 |
-
# Add tags for models that have been loaded with the correct transformers version
|
| 603 |
-
if hasattr(self.model, "add_model_tags"):
|
| 604 |
-
self.model.add_model_tags(self._tag_names)
|
| 605 |
-
|
| 606 |
-
#########
|
| 607 |
-
### setup dataloader
|
| 608 |
-
#########
|
| 609 |
-
self.dataloader = DataLoader(
|
| 610 |
-
self.train_dataset,
|
| 611 |
-
batch_size=self.local_dataloader_batch_size,
|
| 612 |
-
shuffle=True,
|
| 613 |
-
collate_fn=self.data_collator,
|
| 614 |
-
drop_last=True, # needed; otherwise the last batch will be of ragged shape
|
| 615 |
-
)
|
| 616 |
-
# sync random states for DataLoader[shuffle=True] before `accelerator.prepare`
|
| 617 |
-
# see https://gist.github.com/vwxyzjn/2581bff1e48e185e0b85b6dfe1def79c
|
| 618 |
-
torch.manual_seed(args.seed)
|
| 619 |
-
self.model, self.optimizer, self.dataloader = accelerator.prepare(self.model, self.optimizer, self.dataloader)
|
| 620 |
-
torch.manual_seed(self.local_seed) # reset the local seed again
|
| 621 |
-
|
| 622 |
-
self.eval_dataloader = DataLoader(
|
| 623 |
-
self.eval_dataset,
|
| 624 |
-
batch_size=args.per_device_eval_batch_size,
|
| 625 |
-
collate_fn=self.data_collator,
|
| 626 |
-
drop_last=True,
|
| 627 |
-
) # no need to shuffle eval dataset
|
| 628 |
-
self.eval_dataloader = accelerator.prepare(self.eval_dataloader)
|
| 629 |
-
|
| 630 |
-
if self.is_deepspeed_enabled:
|
| 631 |
-
self.reward_model = prepare_deepspeed(
|
| 632 |
-
self.reward_model, args.per_device_train_batch_size, args.fp16, args.bf16
|
| 633 |
-
)
|
| 634 |
-
|
| 635 |
-
if self.ref_model is None:
|
| 636 |
-
if not self.is_peft_model:
|
| 637 |
-
raise ValueError("No reference model and model is not a Peft model.")
|
| 638 |
-
else:
|
| 639 |
-
self.ref_model = prepare_deepspeed(
|
| 640 |
-
self.ref_model, args.per_device_train_batch_size, args.fp16, args.bf16
|
| 641 |
-
)
|
| 642 |
-
else:
|
| 643 |
-
if self.ref_model is None:
|
| 644 |
-
if not self.is_peft_model:
|
| 645 |
-
raise ValueError("No reference model and model is not a Peft model.")
|
| 646 |
-
else:
|
| 647 |
-
self.ref_model = self.ref_model.to(self.accelerator.device)
|
| 648 |
-
self.reward_model = self.reward_model.to(self.accelerator.device)
|
| 649 |
-
|
| 650 |
-
def get_train_dataloader(self) -> DataLoader:
|
| 651 |
-
return self.dataloader
|
| 652 |
-
|
| 653 |
-
def get_eval_dataloader(self) -> DataLoader:
|
| 654 |
-
return self.eval_dataloader
|
| 655 |
-
|
| 656 |
-
@contextmanager
|
| 657 |
-
def null_ref_context(self):
|
| 658 |
-
"""Context manager for handling null reference model (that is, peft adapter manipulation)."""
|
| 659 |
-
with (
|
| 660 |
-
self.accelerator.unwrap_model(self.model.policy).disable_adapter()
|
| 661 |
-
if self.is_peft_model and not self.ref_adapter_name
|
| 662 |
-
else nullcontext()
|
| 663 |
-
):
|
| 664 |
-
if self.ref_adapter_name:
|
| 665 |
-
self.model.policy.set_adapter(self.ref_adapter_name)
|
| 666 |
-
yield
|
| 667 |
-
if self.ref_adapter_name:
|
| 668 |
-
self.model.policy.set_adapter(self.model_adapter_name or "default")
|
| 669 |
-
|
| 670 |
-
def save_model(self, output_dir: Optional[str] = None, _internal_call: bool = False):
|
| 671 |
-
backup_model = self.model
|
| 672 |
-
self.model = self.model.policy # save only the policy
|
| 673 |
-
|
| 674 |
-
if self.is_deepspeed_enabled:
|
| 675 |
-
backup_deepspeed = self.deepspeed
|
| 676 |
-
self.deepspeed = self.model
|
| 677 |
-
|
| 678 |
-
super().save_model(output_dir, _internal_call)
|
| 679 |
-
|
| 680 |
-
self.model = backup_model
|
| 681 |
-
|
| 682 |
-
if self.is_deepspeed_enabled:
|
| 683 |
-
self.deepspeed = backup_deepspeed
|
| 684 |
-
|
| 685 |
-
def train(self):
|
| 686 |
-
args = self.args
|
| 687 |
-
accelerator = self.accelerator
|
| 688 |
-
optimizer = self.optimizer
|
| 689 |
-
model = self.model
|
| 690 |
-
ref_policy = self.ref_model
|
| 691 |
-
reward_model = self.reward_model
|
| 692 |
-
processing_class = self.processing_class
|
| 693 |
-
dataloader = self.dataloader
|
| 694 |
-
device = accelerator.device
|
| 695 |
-
|
| 696 |
-
def repeat_generator():
|
| 697 |
-
while True:
|
| 698 |
-
yield from dataloader
|
| 699 |
-
|
| 700 |
-
iter_dataloader = iter(repeat_generator())
|
| 701 |
-
generation_config = GenerationConfig(
|
| 702 |
-
max_new_tokens=args.response_length,
|
| 703 |
-
temperature=(args.temperature + 1e-7),
|
| 704 |
-
top_k=0.0,
|
| 705 |
-
top_p=1.0,
|
| 706 |
-
do_sample=True,
|
| 707 |
-
)
|
| 708 |
-
|
| 709 |
-
accelerator.print("===training policy===")
|
| 710 |
-
start_time = time.time()
|
| 711 |
-
stats_shape = (args.num_ppo_epochs, args.num_mini_batches, args.gradient_accumulation_steps)
|
| 712 |
-
approxkl_stats = torch.zeros(stats_shape, device=device)
|
| 713 |
-
pg_clipfrac_stats = torch.zeros(stats_shape, device=device)
|
| 714 |
-
pg_loss_stats = torch.zeros(stats_shape, device=device)
|
| 715 |
-
vf_loss_stats = torch.zeros(stats_shape, device=device)
|
| 716 |
-
vf_clipfrac_stats = torch.zeros(stats_shape, device=device)
|
| 717 |
-
entropy_stats = torch.zeros(stats_shape, device=device)
|
| 718 |
-
ratio_stats = torch.zeros(stats_shape, device=device)
|
| 719 |
-
model.train()
|
| 720 |
-
|
| 721 |
-
# trainer state initialization
|
| 722 |
-
self.state.global_step = 0
|
| 723 |
-
self.state.episode = 0
|
| 724 |
-
self.state.max_steps = args.num_total_batches * args.num_mini_batches
|
| 725 |
-
self.state.num_train_epochs = args.total_episodes / self.train_dataset_len
|
| 726 |
-
# Compute absolute values for logging, eval, and save if given as ratio
|
| 727 |
-
if args.logging_steps is not None:
|
| 728 |
-
if args.logging_steps < 1:
|
| 729 |
-
self.state.logging_steps = math.ceil(self.state.max_steps * args.logging_steps)
|
| 730 |
-
else:
|
| 731 |
-
self.state.logging_steps = args.logging_steps
|
| 732 |
-
if args.eval_steps is not None:
|
| 733 |
-
if args.eval_steps < 1:
|
| 734 |
-
self.state.eval_steps = math.ceil(self.state.max_steps * args.eval_steps)
|
| 735 |
-
else:
|
| 736 |
-
self.state.eval_steps = args.eval_steps
|
| 737 |
-
if args.save_steps is not None:
|
| 738 |
-
if args.save_steps < 1:
|
| 739 |
-
self.state.save_steps = math.ceil(self.state.max_steps * args.save_steps)
|
| 740 |
-
else:
|
| 741 |
-
self.state.save_steps = args.save_steps
|
| 742 |
-
self.control = self.callback_handler.on_train_begin(args, self.state, self.control)
|
| 743 |
-
|
| 744 |
-
# backward compatibility
|
| 745 |
-
if self.is_deepspeed_enabled:
|
| 746 |
-
self.deepspeed = self.model
|
| 747 |
-
self.model_wrapped = self.model
|
| 748 |
-
|
| 749 |
-
for update in range(1, args.num_total_batches + 1):
|
| 750 |
-
self.state.episode += 1 * args.batch_size
|
| 751 |
-
data = next(iter_dataloader)
|
| 752 |
-
with torch.no_grad():
|
| 753 |
-
queries = data["input_ids"].to(device)
|
| 754 |
-
context_length = queries.shape[1]
|
| 755 |
-
responses = []
|
| 756 |
-
postprocessed_responses = []
|
| 757 |
-
logprobs = []
|
| 758 |
-
ref_logprobs = []
|
| 759 |
-
scores = []
|
| 760 |
-
sequence_lengths = []
|
| 761 |
-
values = []
|
| 762 |
-
with unwrap_model_for_generation(
|
| 763 |
-
self.model, self.accelerator, gather_deepspeed3_params=self.args.ds3_gather_for_generation
|
| 764 |
-
) as unwrapped_model:
|
| 765 |
-
query_responses, logitss = batch_generation(
|
| 766 |
-
unwrapped_model.policy,
|
| 767 |
-
queries,
|
| 768 |
-
args.local_rollout_forward_batch_size,
|
| 769 |
-
processing_class.pad_token_id,
|
| 770 |
-
generation_config,
|
| 771 |
-
)
|
| 772 |
-
|
| 773 |
-
for i in range(0, queries.shape[0], args.local_rollout_forward_batch_size):
|
| 774 |
-
query = queries[i : i + args.local_rollout_forward_batch_size]
|
| 775 |
-
query_response = query_responses[i : i + args.local_rollout_forward_batch_size]
|
| 776 |
-
response = query_response[:, context_length:]
|
| 777 |
-
logits = logitss[i : i + args.local_rollout_forward_batch_size]
|
| 778 |
-
logprob = selective_log_softmax(logits, response)
|
| 779 |
-
del logits
|
| 780 |
-
torch.cuda.empty_cache()
|
| 781 |
-
|
| 782 |
-
if ref_policy is None:
|
| 783 |
-
with self.null_ref_context():
|
| 784 |
-
ref_output = forward(model.policy, query_response, processing_class.pad_token_id)
|
| 785 |
-
else:
|
| 786 |
-
ref_output = forward(ref_policy, query_response, processing_class.pad_token_id)
|
| 787 |
-
ref_logits = ref_output.logits[:, context_length - 1 : -1]
|
| 788 |
-
ref_logits /= args.temperature + 1e-7
|
| 789 |
-
ref_logprob = selective_log_softmax(ref_logits, response)
|
| 790 |
-
del ref_output, ref_logits
|
| 791 |
-
torch.cuda.empty_cache()
|
| 792 |
-
|
| 793 |
-
# Response Processing 1. truncate response after the first occurrence of `stop_token_id`
|
| 794 |
-
postprocessed_response = response
|
| 795 |
-
if self.stop_token_id is not None: # handle the edge case when stop_token_id exists but is 0
|
| 796 |
-
postprocessed_response = truncate_response(
|
| 797 |
-
self.stop_token_id, processing_class.pad_token_id, response
|
| 798 |
-
)
|
| 799 |
-
|
| 800 |
-
# Response Processing 2. run reward model on the truncated responses
|
| 801 |
-
postprocessed_query_response = torch.cat((query, postprocessed_response), 1)
|
| 802 |
-
sequence_length = first_true_indices(postprocessed_response == processing_class.pad_token_id) - 1
|
| 803 |
-
unwrapped_value_model = accelerator.unwrap_model(model).value_model
|
| 804 |
-
full_value, _, _ = get_reward(
|
| 805 |
-
unwrapped_value_model, query_response, processing_class.pad_token_id, context_length
|
| 806 |
-
)
|
| 807 |
-
value = full_value[:, context_length - 1 : -1].squeeze(-1)
|
| 808 |
-
_, score, _ = get_reward(
|
| 809 |
-
reward_model, postprocessed_query_response, processing_class.pad_token_id, context_length
|
| 810 |
-
)
|
| 811 |
-
|
| 812 |
-
responses.append(response)
|
| 813 |
-
postprocessed_responses.append(postprocessed_response)
|
| 814 |
-
logprobs.append(logprob)
|
| 815 |
-
ref_logprobs.append(ref_logprob)
|
| 816 |
-
sequence_lengths.append(sequence_length)
|
| 817 |
-
scores.append(score)
|
| 818 |
-
values.append(value)
|
| 819 |
-
responses = torch.cat(responses, 0)
|
| 820 |
-
postprocessed_responses = torch.cat(postprocessed_responses, 0)
|
| 821 |
-
logprobs = torch.cat(logprobs, 0)
|
| 822 |
-
ref_logprobs = torch.cat(ref_logprobs, 0)
|
| 823 |
-
sequence_lengths = torch.cat(sequence_lengths, 0)
|
| 824 |
-
scores = torch.cat(scores, 0)
|
| 825 |
-
values = torch.cat(values, 0)
|
| 826 |
-
del (logprob, ref_logprob, full_value, value, score, unwrapped_model)
|
| 827 |
-
torch.cuda.empty_cache()
|
| 828 |
-
gc.collect()
|
| 829 |
-
|
| 830 |
-
# Response Processing 3. Filter completion. Ensure that the sample contains stop_token_id
|
| 831 |
-
# Completions not passing that filter will receive a lower score.
|
| 832 |
-
contain_eos_token = torch.any(postprocessed_responses == self.processing_class.eos_token_id, dim=-1)
|
| 833 |
-
if self.args.missing_eos_penalty is not None:
|
| 834 |
-
scores[~contain_eos_token] -= self.args.missing_eos_penalty
|
| 835 |
-
# accelerator.print(f"{scores=}, {(contain_eos_token.sum() / len(contain_eos_token))=}")
|
| 836 |
-
|
| 837 |
-
# be very careful with `padding_mask_p1`; see https://excalidraw.com/#json=LWnzG4w2k5DjF_EOL_xPt,e2w3a-hFJ_gX5vOfeyXGTw
|
| 838 |
-
response_idxs = torch.arange(responses.shape[1], device=responses.device).repeat(responses.shape[0], 1)
|
| 839 |
-
padding_mask = response_idxs > sequence_lengths.unsqueeze(1)
|
| 840 |
-
logprobs = torch.masked_fill(logprobs, padding_mask, INVALID_LOGPROB)
|
| 841 |
-
ref_logprobs = torch.masked_fill(ref_logprobs, padding_mask, INVALID_LOGPROB)
|
| 842 |
-
sequence_lengths_p1 = sequence_lengths + 1
|
| 843 |
-
padding_mask_p1 = response_idxs > (sequence_lengths_p1.unsqueeze(1))
|
| 844 |
-
values = torch.masked_fill(values, padding_mask_p1, 0)
|
| 845 |
-
|
| 846 |
-
# 4. compute rewards
|
| 847 |
-
kl = logprobs - ref_logprobs
|
| 848 |
-
non_score_reward = -args.kl_coef * kl
|
| 849 |
-
rewards = non_score_reward.clone()
|
| 850 |
-
actual_start = torch.arange(rewards.size(0), device=rewards.device)
|
| 851 |
-
actual_end = torch.where(sequence_lengths_p1 < rewards.size(1), sequence_lengths_p1, sequence_lengths)
|
| 852 |
-
rewards[[actual_start, actual_end]] += scores
|
| 853 |
-
|
| 854 |
-
# 5. whiten rewards
|
| 855 |
-
if args.whiten_rewards:
|
| 856 |
-
rewards = masked_whiten(rewards, mask=~padding_mask_p1, shift_mean=False)
|
| 857 |
-
rewards = torch.masked_fill(rewards, padding_mask_p1, 0)
|
| 858 |
-
|
| 859 |
-
# 6. compute advantages and returns
|
| 860 |
-
lastgaelam = 0
|
| 861 |
-
advantages_reversed = []
|
| 862 |
-
gen_length = responses.shape[1]
|
| 863 |
-
for t in reversed(range(gen_length)):
|
| 864 |
-
nextvalues = values[:, t + 1] if t < gen_length - 1 else 0.0
|
| 865 |
-
delta = rewards[:, t] + args.gamma * nextvalues - values[:, t]
|
| 866 |
-
lastgaelam = delta + args.gamma * args.lam * lastgaelam
|
| 867 |
-
advantages_reversed.append(lastgaelam)
|
| 868 |
-
advantages = torch.stack(advantages_reversed[::-1], axis=1)
|
| 869 |
-
returns = advantages + values
|
| 870 |
-
advantages = masked_whiten(advantages, ~padding_mask)
|
| 871 |
-
advantages = torch.masked_fill(advantages, padding_mask, 0)
|
| 872 |
-
torch.cuda.empty_cache()
|
| 873 |
-
|
| 874 |
-
# Do multiple epochs of PPO training, with a fresh random shuffle in each epoch
|
| 875 |
-
for ppo_epoch_idx in range(args.num_ppo_epochs):
|
| 876 |
-
b_inds = np.random.permutation(args.local_batch_size)
|
| 877 |
-
minibatch_idx = 0
|
| 878 |
-
for mini_batch_start in range(0, args.local_batch_size, args.local_mini_batch_size):
|
| 879 |
-
mini_batch_end = mini_batch_start + args.local_mini_batch_size
|
| 880 |
-
mini_batch_inds = b_inds[mini_batch_start:mini_batch_end]
|
| 881 |
-
gradient_accumulation_idx = 0
|
| 882 |
-
for micro_batch_start in range(0, args.local_mini_batch_size, args.per_device_train_batch_size):
|
| 883 |
-
with accelerator.accumulate(model):
|
| 884 |
-
micro_batch_end = micro_batch_start + args.per_device_train_batch_size
|
| 885 |
-
micro_batch_inds = mini_batch_inds[micro_batch_start:micro_batch_end]
|
| 886 |
-
mb_advantage = advantages[micro_batch_inds]
|
| 887 |
-
mb_responses = responses[micro_batch_inds]
|
| 888 |
-
mb_query_responses = query_responses[micro_batch_inds]
|
| 889 |
-
mb_logprobs = logprobs[micro_batch_inds]
|
| 890 |
-
mb_return = returns[micro_batch_inds]
|
| 891 |
-
mb_values = values[micro_batch_inds]
|
| 892 |
-
|
| 893 |
-
output, vpred_temp = forward(model, mb_query_responses, processing_class.pad_token_id)
|
| 894 |
-
logits = output.logits[:, context_length - 1 : -1]
|
| 895 |
-
logits /= args.temperature + 1e-7
|
| 896 |
-
new_logprobs = selective_log_softmax(logits, mb_responses)
|
| 897 |
-
new_logprobs = torch.masked_fill(
|
| 898 |
-
new_logprobs, padding_mask[micro_batch_inds], INVALID_LOGPROB
|
| 899 |
-
)
|
| 900 |
-
vpred = vpred_temp[:, context_length - 1 : -1].squeeze(-1)
|
| 901 |
-
vpred = torch.masked_fill(vpred, padding_mask_p1[micro_batch_inds], 0)
|
| 902 |
-
vpredclipped = torch.clamp(
|
| 903 |
-
vpred,
|
| 904 |
-
mb_values - args.cliprange_value,
|
| 905 |
-
mb_values + args.cliprange_value,
|
| 906 |
-
)
|
| 907 |
-
vf_losses1 = torch.square(vpred - mb_return)
|
| 908 |
-
vf_losses2 = torch.square(vpredclipped - mb_return)
|
| 909 |
-
vf_loss_max = torch.max(vf_losses1, vf_losses2)
|
| 910 |
-
vf_loss = 0.5 * masked_mean(vf_loss_max, ~padding_mask_p1[micro_batch_inds])
|
| 911 |
-
vf_clipfrac = masked_mean(
|
| 912 |
-
(vf_losses2 > vf_losses1).float(), ~padding_mask_p1[micro_batch_inds]
|
| 913 |
-
)
|
| 914 |
-
logprobs_diff = new_logprobs - mb_logprobs
|
| 915 |
-
ratio = torch.exp(logprobs_diff)
|
| 916 |
-
pg_losses = -mb_advantage * ratio
|
| 917 |
-
pg_losses2 = -mb_advantage * torch.clamp(ratio, 1.0 - args.cliprange, 1.0 + args.cliprange)
|
| 918 |
-
pg_loss_max = torch.max(pg_losses, pg_losses2)
|
| 919 |
-
pg_loss = masked_mean(pg_loss_max, ~padding_mask[micro_batch_inds])
|
| 920 |
-
loss = pg_loss + args.vf_coef * vf_loss
|
| 921 |
-
accelerator.backward(loss)
|
| 922 |
-
optimizer.step()
|
| 923 |
-
optimizer.zero_grad()
|
| 924 |
-
with torch.no_grad():
|
| 925 |
-
pg_clipfrac = masked_mean(
|
| 926 |
-
(pg_losses2 > pg_losses).float(), ~padding_mask[micro_batch_inds]
|
| 927 |
-
)
|
| 928 |
-
prob_dist = torch.nn.functional.softmax(logits, dim=-1)
|
| 929 |
-
entropy = torch.logsumexp(logits, dim=-1) - torch.sum(prob_dist * logits, dim=-1)
|
| 930 |
-
approxkl = 0.5 * (logprobs_diff**2).mean()
|
| 931 |
-
approxkl_stats[ppo_epoch_idx, minibatch_idx, gradient_accumulation_idx] = approxkl
|
| 932 |
-
pg_clipfrac_stats[ppo_epoch_idx, minibatch_idx, gradient_accumulation_idx] = (
|
| 933 |
-
pg_clipfrac
|
| 934 |
-
)
|
| 935 |
-
pg_loss_stats[ppo_epoch_idx, minibatch_idx, gradient_accumulation_idx] = pg_loss
|
| 936 |
-
vf_loss_stats[ppo_epoch_idx, minibatch_idx, gradient_accumulation_idx] = vf_loss
|
| 937 |
-
vf_clipfrac_stats[ppo_epoch_idx, minibatch_idx, gradient_accumulation_idx] = (
|
| 938 |
-
vf_clipfrac
|
| 939 |
-
)
|
| 940 |
-
entropy_stats[ppo_epoch_idx, minibatch_idx, gradient_accumulation_idx] = entropy.mean()
|
| 941 |
-
ratio_stats[ppo_epoch_idx, minibatch_idx, gradient_accumulation_idx] = ratio.mean()
|
| 942 |
-
gradient_accumulation_idx += 1
|
| 943 |
-
minibatch_idx += 1
|
| 944 |
-
# del everything and empty cache
|
| 945 |
-
# fmt: off
|
| 946 |
-
del (
|
| 947 |
-
output, vpred_temp, logits, new_logprobs, vpred, vpredclipped,
|
| 948 |
-
vf_losses1, vf_losses2, vf_loss, vf_clipfrac, logprobs_diff, ratio, pg_losses, pg_losses2, pg_loss_max,
|
| 949 |
-
pg_loss, loss, pg_clipfrac, prob_dist, entropy, approxkl, mb_return,
|
| 950 |
-
mb_advantage, mb_values, mb_responses, mb_query_responses, mb_logprobs,
|
| 951 |
-
)
|
| 952 |
-
# fmt: on
|
| 953 |
-
torch.cuda.empty_cache()
|
| 954 |
-
with torch.no_grad():
|
| 955 |
-
mean_kl = kl.sum(1).mean()
|
| 956 |
-
mean_entropy = (-logprobs).sum(1).mean()
|
| 957 |
-
mean_non_score_reward = non_score_reward.sum(1).mean()
|
| 958 |
-
rlhf_reward = mean_non_score_reward + scores.mean()
|
| 959 |
-
eps = int(self.state.episode / (time.time() - start_time))
|
| 960 |
-
metrics = {}
|
| 961 |
-
metrics["eps"] = eps
|
| 962 |
-
metrics["objective/kl"] = self.accelerator.gather_for_metrics(mean_kl).mean().item()
|
| 963 |
-
metrics["objective/entropy"] = self.accelerator.gather_for_metrics(mean_entropy).mean().item()
|
| 964 |
-
metrics["objective/non_score_reward"] = (
|
| 965 |
-
self.accelerator.gather_for_metrics(mean_non_score_reward).mean().item()
|
| 966 |
-
)
|
| 967 |
-
metrics["objective/rlhf_reward"] = self.accelerator.gather_for_metrics(rlhf_reward).mean().item()
|
| 968 |
-
metrics["objective/scores"] = self.accelerator.gather_for_metrics(scores.mean()).mean().item()
|
| 969 |
-
metrics["policy/approxkl_avg"] = self.accelerator.gather_for_metrics(approxkl_stats).mean().item()
|
| 970 |
-
metrics["policy/clipfrac_avg"] = self.accelerator.gather_for_metrics(pg_clipfrac_stats).mean().item()
|
| 971 |
-
metrics["loss/policy_avg"] = self.accelerator.gather_for_metrics(pg_loss_stats).mean().item()
|
| 972 |
-
metrics["loss/value_avg"] = self.accelerator.gather_for_metrics(vf_loss_stats).mean().item()
|
| 973 |
-
metrics["val/clipfrac_avg"] = self.accelerator.gather_for_metrics(vf_clipfrac_stats).mean().item()
|
| 974 |
-
metrics["policy/entropy_avg"] = self.accelerator.gather_for_metrics(entropy_stats).mean().item()
|
| 975 |
-
metrics["val/ratio"] = self.accelerator.gather_for_metrics(ratio_stats).mean().item()
|
| 976 |
-
metrics["val/ratio_var"] = self.accelerator.gather_for_metrics(ratio_stats).var().item()
|
| 977 |
-
metrics["val/num_eos_tokens"] = (responses == processing_class.eos_token_id).sum().item()
|
| 978 |
-
metrics["lr"] = self.lr_scheduler.get_last_lr()[0]
|
| 979 |
-
metrics["episode"] = self.state.episode
|
| 980 |
-
self.state.epoch = self.state.episode / self.train_dataset_len # used by self.log
|
| 981 |
-
self.state.global_step += 1
|
| 982 |
-
self.log(metrics)
|
| 983 |
-
|
| 984 |
-
self.lr_scheduler.step()
|
| 985 |
-
self.control = self.callback_handler.on_step_end(args, self.state, self.control)
|
| 986 |
-
if self.control.should_save:
|
| 987 |
-
self._save_checkpoint(model, trial=None)
|
| 988 |
-
self.control = self.callback_handler.on_save(self.args, self.state, self.control)
|
| 989 |
-
del kl, mean_kl, mean_entropy, mean_non_score_reward, scores, metrics, non_score_reward
|
| 990 |
-
torch.cuda.empty_cache()
|
| 991 |
-
gc.collect()
|
| 992 |
-
|
| 993 |
-
if args.num_sample_generations > 0 and (update - 1) % self.sample_generations_freq == 0:
|
| 994 |
-
self.generate_completions(sampling=True)
|
| 995 |
-
torch.cuda.empty_cache()
|
| 996 |
-
del (
|
| 997 |
-
query_responses,
|
| 998 |
-
responses,
|
| 999 |
-
postprocessed_responses,
|
| 1000 |
-
logprobs,
|
| 1001 |
-
ref_logprobs,
|
| 1002 |
-
values,
|
| 1003 |
-
sequence_lengths,
|
| 1004 |
-
contain_eos_token,
|
| 1005 |
-
sequence_lengths_p1,
|
| 1006 |
-
response_idxs,
|
| 1007 |
-
padding_mask,
|
| 1008 |
-
padding_mask_p1,
|
| 1009 |
-
rewards,
|
| 1010 |
-
actual_start,
|
| 1011 |
-
actual_end,
|
| 1012 |
-
advantages,
|
| 1013 |
-
returns,
|
| 1014 |
-
)
|
| 1015 |
-
torch.cuda.empty_cache()
|
| 1016 |
-
|
| 1017 |
-
# HF trainer specifics
|
| 1018 |
-
self.control = self.callback_handler.on_train_end(args, self.state, self.control)
|
| 1019 |
-
if self.control.should_save:
|
| 1020 |
-
self._save_checkpoint(model, trial=None, metrics=None)
|
| 1021 |
-
self.control = self.callback_handler.on_save(self.args, self.state, self.control)
|
| 1022 |
-
|
| 1023 |
-
def generate_completions(self, sampling: bool = False):
|
| 1024 |
-
args = self.args
|
| 1025 |
-
processing_class = self.processing_class
|
| 1026 |
-
generation_config = GenerationConfig(
|
| 1027 |
-
max_new_tokens=self.args.response_length,
|
| 1028 |
-
temperature=(0.01 + 1e-7),
|
| 1029 |
-
top_k=0.0,
|
| 1030 |
-
top_p=1.0,
|
| 1031 |
-
do_sample=True,
|
| 1032 |
-
)
|
| 1033 |
-
|
| 1034 |
-
table = defaultdict(list)
|
| 1035 |
-
with unwrap_model_for_generation(
|
| 1036 |
-
self.model, self.accelerator, gather_deepspeed3_params=self.args.ds3_gather_for_generation
|
| 1037 |
-
) as unwrapped_model:
|
| 1038 |
-
for batch in self.eval_dataloader:
|
| 1039 |
-
query = batch["input_ids"]
|
| 1040 |
-
with torch.no_grad():
|
| 1041 |
-
context_length = query.shape[1]
|
| 1042 |
-
query_response, _ = batch_generation(
|
| 1043 |
-
unwrapped_model.policy,
|
| 1044 |
-
query,
|
| 1045 |
-
query.shape[0],
|
| 1046 |
-
processing_class.pad_token_id,
|
| 1047 |
-
generation_config,
|
| 1048 |
-
)
|
| 1049 |
-
response = query_response[:, context_length:]
|
| 1050 |
-
postprocessed_response = response
|
| 1051 |
-
if self.stop_token_id is not None: # handle the edge case when stop_token_id exists but is 0
|
| 1052 |
-
postprocessed_response = truncate_response(
|
| 1053 |
-
self.stop_token_id, processing_class.pad_token_id, response
|
| 1054 |
-
)
|
| 1055 |
-
table["query"].extend(
|
| 1056 |
-
gather_object(processing_class.batch_decode(query, skip_special_tokens=True))
|
| 1057 |
-
)
|
| 1058 |
-
table["model response"].extend(
|
| 1059 |
-
gather_object(processing_class.batch_decode(postprocessed_response))
|
| 1060 |
-
)
|
| 1061 |
-
|
| 1062 |
-
postprocessed_query_response = torch.cat((query, postprocessed_response), 1)
|
| 1063 |
-
_, score, _ = get_reward(
|
| 1064 |
-
self.reward_model, postprocessed_query_response, processing_class.pad_token_id, context_length
|
| 1065 |
-
)
|
| 1066 |
-
table["score"].extend(self.accelerator.gather_for_metrics(score).float().cpu().numpy())
|
| 1067 |
-
|
| 1068 |
-
if sampling:
|
| 1069 |
-
break
|
| 1070 |
-
df = pd.DataFrame(table)
|
| 1071 |
-
|
| 1072 |
-
if self.accelerator.is_main_process:
|
| 1073 |
-
print_rich_table(df.iloc[0 : 0 + 5])
|
| 1074 |
-
if "wandb" in args.report_to:
|
| 1075 |
-
import wandb
|
| 1076 |
-
|
| 1077 |
-
if wandb.run is not None:
|
| 1078 |
-
wandb.log({"completions": wandb.Table(dataframe=df)})
|
| 1079 |
-
|
| 1080 |
-
if "comet_ml" in args.report_to:
|
| 1081 |
-
log_table_to_comet_experiment(
|
| 1082 |
-
name="completions.csv",
|
| 1083 |
-
table=df,
|
| 1084 |
-
)
|
| 1085 |
-
|
| 1086 |
-
def create_model_card(
|
| 1087 |
-
self,
|
| 1088 |
-
model_name: Optional[str] = None,
|
| 1089 |
-
dataset_name: Optional[str] = None,
|
| 1090 |
-
tags: Union[str, list[str], None] = None,
|
| 1091 |
-
):
|
| 1092 |
-
"""
|
| 1093 |
-
Creates a draft of a model card using the information available to the `Trainer`.
|
| 1094 |
-
|
| 1095 |
-
Args:
|
| 1096 |
-
model_name (`str` or `None`, *optional*, defaults to `None`):
|
| 1097 |
-
Name of the model.
|
| 1098 |
-
dataset_name (`str` or `None`, *optional*, defaults to `None`):
|
| 1099 |
-
Name of the dataset used for training.
|
| 1100 |
-
tags (`str`, `list[str]` or `None`, *optional*, defaults to `None`):
|
| 1101 |
-
Tags to be associated with the model card.
|
| 1102 |
-
"""
|
| 1103 |
-
if not self.is_world_process_zero():
|
| 1104 |
-
return
|
| 1105 |
-
|
| 1106 |
-
if hasattr(self.model.config, "_name_or_path") and not os.path.isdir(self.model.config._name_or_path):
|
| 1107 |
-
base_model = self.model.config._name_or_path
|
| 1108 |
-
else:
|
| 1109 |
-
base_model = None
|
| 1110 |
-
|
| 1111 |
-
tags = tags or []
|
| 1112 |
-
if isinstance(tags, str):
|
| 1113 |
-
tags = [tags]
|
| 1114 |
-
|
| 1115 |
-
if hasattr(self.model.config, "unsloth_version"):
|
| 1116 |
-
tags.append("unsloth")
|
| 1117 |
-
|
| 1118 |
-
citation = textwrap.dedent("""\
|
| 1119 |
-
@article{mziegler2019fine-tuning,
|
| 1120 |
-
title = {{Fine-Tuning Language Models from Human Preferences}},
|
| 1121 |
-
author = {Daniel M. Ziegler and Nisan Stiennon and Jeffrey Wu and Tom B. Brown and Alec Radford and Dario Amodei and Paul F. Christiano and Geoffrey Irving},
|
| 1122 |
-
year = 2019,
|
| 1123 |
-
eprint = {arXiv:1909.08593}
|
| 1124 |
-
}""")
|
| 1125 |
-
|
| 1126 |
-
model_card = generate_model_card(
|
| 1127 |
-
base_model=base_model,
|
| 1128 |
-
model_name=model_name,
|
| 1129 |
-
hub_model_id=self.hub_model_id,
|
| 1130 |
-
dataset_name=dataset_name,
|
| 1131 |
-
tags=tags,
|
| 1132 |
-
wandb_url=wandb.run.get_url() if is_wandb_available() and wandb.run is not None else None,
|
| 1133 |
-
comet_url=get_comet_experiment_url(),
|
| 1134 |
-
trainer_name="PPO",
|
| 1135 |
-
trainer_citation=citation,
|
| 1136 |
-
paper_title="Fine-Tuning Language Models from Human Preferences",
|
| 1137 |
-
paper_id="1909.08593",
|
| 1138 |
-
)
|
| 1139 |
-
|
| 1140 |
-
model_card.save(os.path.join(self.args.output_dir, "README.md"))
|
| 1141 |
-
class UnslothPPOTrainer(_UnslothPPOTrainer):
|
| 1142 |
-
"""
|
| 1143 |
-
|
| 1144 |
-
"""
|
| 1145 |
-
def __init__(
|
| 1146 |
-
self,
|
| 1147 |
-
args,
|
| 1148 |
-
processing_class,
|
| 1149 |
-
model,
|
| 1150 |
-
ref_model,
|
| 1151 |
-
reward_model,
|
| 1152 |
-
train_dataset,
|
| 1153 |
-
value_model = None,
|
| 1154 |
-
data_collator = None,
|
| 1155 |
-
eval_dataset = None,
|
| 1156 |
-
callbacks = None,
|
| 1157 |
-
peft_config = None,
|
| 1158 |
-
**kwargs
|
| 1159 |
-
):
|
| 1160 |
-
if args is None: args = UnslothPPOConfig()
|
| 1161 |
-
use_bf16 = getattr(args, 'bf16', False)
|
| 1162 |
-
if type(use_bf16) is not bool: use_bf16 = False
|
| 1163 |
-
use_fp16 = getattr(args, 'fp16', False)
|
| 1164 |
-
if type(use_fp16) is not bool: use_fp16 = False
|
| 1165 |
-
force_float32 = False
|
| 1166 |
-
if os.environ.get('UNSLOTH_FORCE_FLOAT32', '0') == '1':
|
| 1167 |
-
print('Unsloth: Switching to float32 training since model cannot work with float16')
|
| 1168 |
-
force_float32 = True
|
| 1169 |
-
mixed_precision_dtype = os.environ.get('UNSLOTH_MIXED_PRECISION', 'float32')
|
| 1170 |
-
dtype = getattr(model.config, 'torch_dtype', None)
|
| 1171 |
-
if dtype is None: dtype = model.get_input_embeddings().dtype
|
| 1172 |
-
from unsloth_zoo.utils import _get_dtype
|
| 1173 |
-
dtype = _get_dtype(dtype)
|
| 1174 |
-
float16 = dtype == torch.float16
|
| 1175 |
-
if not force_float32 and (float16 and use_bf16): raise TypeError('Unsloth: Model is in float16 precision but you want to use bfloat16 precision. Set fp16 to `True` and bf16 to `False`')
|
| 1176 |
-
if not force_float32 and (not float16 and use_fp16): raise TypeError('Unsloth: Model is in bfloat16 precision but you want to use float16 precision. Set fp16 to `False` and bf16 to `True`')
|
| 1177 |
-
if force_float32:
|
| 1178 |
-
args.fp16 = False
|
| 1179 |
-
args.bf16 = False
|
| 1180 |
-
os.environ['ACCELERATE_MIXED_PRECISION'] = 'no'
|
| 1181 |
-
elif (not use_bf16 and not use_fp16) and mixed_precision_dtype == 'float32':
|
| 1182 |
-
args.fp16 = float16
|
| 1183 |
-
args.bf16 = not float16
|
| 1184 |
-
os.environ['ACCELERATE_MIXED_PRECISION'] = 'fp16' if float16 else 'bf16'
|
| 1185 |
-
if getattr(args, 'eval_dataset', None) is not None and getattr(args, 'eval_strategy', 'no') == 'no':
|
| 1186 |
-
args.eval_strategy = 'steps'
|
| 1187 |
-
if getattr(args, 'eval_steps', None) is None: args.eval_steps = 0.1
|
| 1188 |
-
ga_steps = getattr(args, 'gradient_accumulation_steps', None)
|
| 1189 |
-
if ga_steps is not None and ga_steps > 1:
|
| 1190 |
-
from transformers import __version__ as transformers_version
|
| 1191 |
-
if Version(transformers_version) <= Version('4.45.2'):
|
| 1192 |
-
print('**** Unsloth: Please use our fixed gradient_accumulation_steps by updating transformers, TRL and Unsloth!\n'
|
| 1193 |
-
'`pip install --upgrade --no-cache-dir --force-reinstall --no-deps unsloth transformers trl unsloth_zoo`')
|
| 1194 |
-
if getattr(args, 'eval_strategy', 'no') != 'no':
|
| 1195 |
-
eval_bsz = getattr(args, 'per_device_eval_batch_size', 8)
|
| 1196 |
-
if eval_bsz == 8 and args.per_device_train_batch_size < eval_bsz: args.per_device_eval_batch_size = args.per_device_train_batch_size
|
| 1197 |
-
if getattr(args, 'eval_accumulation_steps', None) is None and ga_steps is not None: args.eval_accumulation_steps = ga_steps
|
| 1198 |
-
fp16_full_eval = getattr(args, 'fp16_full_eval', False)
|
| 1199 |
-
if type(fp16_full_eval) is not bool: fp16_full_eval = False
|
| 1200 |
-
bf16_full_eval = getattr(args, 'bf16_full_eval', False)
|
| 1201 |
-
if type(bf16_full_eval) is not bool: bf16_full_eval = False
|
| 1202 |
-
if args.fp16 and bf16_full_eval: args.bf16_full_eval = False; args.fp16_full_eval = True
|
| 1203 |
-
if args.bf16 and fp16_full_eval: args.bf16_full_eval = True; args.fp16_full_eval = False
|
| 1204 |
-
if force_float32:
|
| 1205 |
-
args.bf16_full_eval = False
|
| 1206 |
-
args.fp16_full_eval = False
|
| 1207 |
-
elif os.environ.get('UNSLOTH_MIXED_PRECISION', 'float32') == 'bfloat16':
|
| 1208 |
-
args.bf16_full_eval = True
|
| 1209 |
-
args.fp16_full_eval = False
|
| 1210 |
-
elif not bf16_full_eval and not fp16_full_eval:
|
| 1211 |
-
args.bf16_full_eval = args.bf16
|
| 1212 |
-
args.fp16_full_eval = args.fp16
|
| 1213 |
-
_output_logits = False
|
| 1214 |
-
if locals().get('compute_metrics', None) is not None: _output_logits = True
|
| 1215 |
-
if locals().get('preprocess_logits_for_metrics', None) is not None: _output_logits = True
|
| 1216 |
-
if _output_logits:
|
| 1217 |
-
os.environ['UNSLOTH_RETURN_LOGITS'] = '1'
|
| 1218 |
-
if 'max_seq_length' not in locals() and not hasattr(args, 'max_seq_length'):
|
| 1219 |
-
pass
|
| 1220 |
-
else:
|
| 1221 |
-
model_max_seq_length = getattr(model, 'max_seq_length', None)
|
| 1222 |
-
args_max_seq_length = getattr(args, 'max_seq_length', None)
|
| 1223 |
-
if args_max_seq_length is None and model_max_seq_length is not None:
|
| 1224 |
-
max_seq_length = model.max_seq_length
|
| 1225 |
-
if hasattr(args, 'max_seq_length'): args.max_seq_length = max_seq_length
|
| 1226 |
-
if model is not None and hasattr(model, 'for_training'):
|
| 1227 |
-
model.for_training()
|
| 1228 |
-
if 'tokenizer' in locals() and hasattr(tokenizer, 'padding_side'): tokenizer.padding_side = 'right'
|
| 1229 |
-
if 'processing_class' in locals():
|
| 1230 |
-
if hasattr(processing_class, 'padding_side'): processing_class.padding_side = 'right'
|
| 1231 |
-
if hasattr(processing_class, 'tokenizer') and hasattr(processing_class.tokenizer, 'padding_side'): processing_class.tokenizer.padding_side = 'right'
|
| 1232 |
-
__tokenizer = processing_class if 'processing_class' in locals() else tokenizer
|
| 1233 |
-
from unsloth_zoo.vision_utils import UnslothVisionDataCollator
|
| 1234 |
-
if not isinstance(data_collator, UnslothVisionDataCollator):
|
| 1235 |
-
if isinstance(data_collator, DataCollatorForSeq2Seq) and 'labels' not in train_dataset.column_names:
|
| 1236 |
-
data_collator = TransformersDataCollatorForLanguageModeling(__tokenizer, mlm = False, mlm_probability = 0.0)
|
| 1237 |
-
elif isinstance(data_collator, TransformersDataCollatorForLanguageModeling) and 'labels' in train_dataset.column_names:
|
| 1238 |
-
data_collator = DataCollatorForSeq2Seq(__tokenizer)
|
| 1239 |
-
else:
|
| 1240 |
-
if hasattr(args, 'remove_unused_columns'): args.remove_unused_columns = False
|
| 1241 |
-
if hasattr(args, 'dataset_text_field'): args.dataset_text_field = ''
|
| 1242 |
-
if hasattr(args, 'dataset_kwargs'): args.dataset_kwargs = {'skip_prepare_dataset': True}
|
| 1243 |
-
if not isinstance(data_collator, UnslothVisionDataCollator):
|
| 1244 |
-
if not hasattr(__tokenizer, 'pad') and hasattr(__tokenizer, 'tokenizer'):
|
| 1245 |
-
if isinstance(data_collator, DataCollatorForSeq2Seq):
|
| 1246 |
-
data_collator = DataCollatorForSeq2Seq(__tokenizer.tokenizer)
|
| 1247 |
-
else:
|
| 1248 |
-
data_collator = TransformersDataCollatorForLanguageModeling(__tokenizer.tokenizer, mlm = False, mlm_probability = 0.0)
|
| 1249 |
-
other_metrics = []
|
| 1250 |
-
|
| 1251 |
-
from unsloth_zoo.logging_utils import PatchRLStatistics
|
| 1252 |
-
PatchRLStatistics('ppo_trainer', other_metrics)
|
| 1253 |
-
|
| 1254 |
-
super().__init__(
|
| 1255 |
-
args = args,
|
| 1256 |
-
processing_class = processing_class,
|
| 1257 |
-
model = model,
|
| 1258 |
-
ref_model = ref_model,
|
| 1259 |
-
reward_model = reward_model,
|
| 1260 |
-
train_dataset = train_dataset,
|
| 1261 |
-
value_model = value_model,
|
| 1262 |
-
data_collator = data_collator,
|
| 1263 |
-
eval_dataset = eval_dataset,
|
| 1264 |
-
callbacks = callbacks,
|
| 1265 |
-
peft_config = peft_config,**kwargs)
|
| 1266 |
-
if hasattr(self, 'neftune_hook_handle'):
|
| 1267 |
-
self.neftune_hook_handle.remove()
|
| 1268 |
-
if hasattr(self, 'neftune_hook_handle'): del self.neftune_hook_handle
|
| 1269 |
-
if getattr(args, 'neftune_noise_alpha', None) is not None:
|
| 1270 |
-
model.get_input_embeddings().neftune_noise_alpha = self.neftune_noise_alpha
|
| 1271 |
-
pass
|
| 1272 |
-
|
| 1273 |
-
pass
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
test_run_uploads/UnslothPRMTrainer.py
DELETED
|
@@ -1,809 +0,0 @@
|
|
| 1 |
-
"""
|
| 2 |
-
2025.7.11
|
| 3 |
-
2025.7.11
|
| 4 |
-
4.54.1
|
| 5 |
-
0.16.1
|
| 6 |
-
__UNSLOTH_VERSIONING__
|
| 7 |
-
"""
|
| 8 |
-
from torch import Tensor
|
| 9 |
-
import torch
|
| 10 |
-
import torch.nn as nn
|
| 11 |
-
from torch.nn import functional as F
|
| 12 |
-
from typing import Any, List, Optional, Tuple, Union, Dict, Set, Callable
|
| 13 |
-
from trl.trainer.prm_trainer import (BaseImageProcessor, Callable, DataCollator, DataCollatorForTokenClassification, Dataset, EvalPrediction, FeatureExtractionMixin, Optional, PRMConfig, PRMTrainer, PartialState, PeftModel, PreTrainedModel, PreTrainedTokenizerBase, ProcessorMixin, Trainer, TrainerCallback, Union, chain, compute_accuracy, disable_dropout_in_model, features, generate_model_card, inspect, is_peft_available, is_wandb_available, nn, os, prepare_model_for_kbit_training, textwrap, torch, wandb, warnings)
|
| 14 |
-
|
| 15 |
-
|
| 16 |
-
import os
|
| 17 |
-
from typing import *
|
| 18 |
-
from dataclasses import dataclass, field
|
| 19 |
-
from packaging.version import Version
|
| 20 |
-
import torch
|
| 21 |
-
import numpy as np
|
| 22 |
-
from contextlib import nullcontext
|
| 23 |
-
from torch.nn import functional as F
|
| 24 |
-
from transformers import DataCollatorForSeq2Seq, DataCollatorForLanguageModeling as TransformersDataCollatorForLanguageModeling
|
| 25 |
-
|
| 26 |
-
torch_compile_options = {
|
| 27 |
-
"epilogue_fusion" : True,
|
| 28 |
-
"max_autotune" : False,
|
| 29 |
-
"shape_padding" : True,
|
| 30 |
-
"trace.enabled" : False,
|
| 31 |
-
"triton.cudagraphs" : False,
|
| 32 |
-
}
|
| 33 |
-
|
| 34 |
-
@torch.compile(dynamic = True, fullgraph = True, options = torch_compile_options,)
|
| 35 |
-
def chunked_selective_log_softmax(logits, index):
|
| 36 |
-
# Split into 4 chunks only
|
| 37 |
-
chunked_logits = torch.chunk(logits.reshape(-1, logits.shape[-1]), chunks = 4, dim = 0)
|
| 38 |
-
chunked_index = torch.chunk(index.reshape(-1), chunks = 4, dim = 0)
|
| 39 |
-
all_per_token_logps = []
|
| 40 |
-
# Below loop does the same as selective_log_softmax(chunk_logits, chunk_index)
|
| 41 |
-
for chunk_logits, chunk_index in zip(chunked_logits, chunked_index):
|
| 42 |
-
chunk_logits = chunk_logits.to(torch.float32)
|
| 43 |
-
selected_logits = torch.gather(chunk_logits, dim = -1, index = chunk_index.unsqueeze(-1)).squeeze(-1)
|
| 44 |
-
logsumexp_values = torch.logsumexp(chunk_logits, dim = -1)
|
| 45 |
-
per_token_logps = selected_logits - logsumexp_values
|
| 46 |
-
all_per_token_logps.append(per_token_logps)
|
| 47 |
-
pass
|
| 48 |
-
all_per_token_logps = torch.concat(all_per_token_logps)
|
| 49 |
-
all_per_token_logps = all_per_token_logps.reshape((logits.shape[0], logits.shape[1]))
|
| 50 |
-
return all_per_token_logps
|
| 51 |
-
@dataclass
|
| 52 |
-
class UnslothPRMConfig(PRMConfig):
|
| 53 |
-
"""
|
| 54 |
-
|
| 55 |
-
Configuration class for the [`PRMTrainer`].
|
| 56 |
-
|
| 57 |
-
Using [`~transformers.HfArgumentParser`] we can turn this class into
|
| 58 |
-
[argparse](https://docs.python.org/3/library/argparse#module-argparse) arguments that can be specified on the
|
| 59 |
-
command line.
|
| 60 |
-
|
| 61 |
-
Parameters:
|
| 62 |
-
learning_rate (`float`, *optional*, defaults to `1e-5`):
|
| 63 |
-
Initial learning rate for [`AdamW`] optimizer. The default value replaces that of
|
| 64 |
-
[`~transformers.TrainingArguments`].
|
| 65 |
-
max_length (`int` or `None`, *optional*, defaults to `1024`):
|
| 66 |
-
Maximum length of the sequences (prompt + completion) used for truncation.
|
| 67 |
-
max_prompt_length (`int` or `None`, *optional*, defaults to `512`):
|
| 68 |
-
Maximum length of the prompt used for truncation.
|
| 69 |
-
max_completion_length (`int` or `None`, *optional*, defaults to `None`):
|
| 70 |
-
Maximum length of the completion used for truncation. The completion is the concatenation of the steps.
|
| 71 |
-
disable_dropout (`bool`, *optional*, defaults to `True`):
|
| 72 |
-
Whether to disable dropout in the model.
|
| 73 |
-
step_separator (`str`, *optional*, defaults to `"\n"`):
|
| 74 |
-
Separator used to separate each step of the reasoning process.
|
| 75 |
-
train_on_last_step_only (`bool`, *optional*, defaults to `False`):
|
| 76 |
-
Whether to train only on the last step.
|
| 77 |
-
dataset_num_proc (`int`, *optional*, defaults to `None`):
|
| 78 |
-
Number of processes to use for processing the dataset.
|
| 79 |
-
|
| 80 |
-
"""
|
| 81 |
-
vllm_sampling_params: Optional[Any] = field(
|
| 82 |
-
default = None,
|
| 83 |
-
metadata = {'help': 'vLLM SamplingParams'},
|
| 84 |
-
)
|
| 85 |
-
unsloth_num_chunks : Optional[int] = field(
|
| 86 |
-
default = -1,
|
| 87 |
-
metadata = {'help': 'Chunk size to reduce memory usage. -1 is most efficient.'},
|
| 88 |
-
)
|
| 89 |
-
def __init__(
|
| 90 |
-
self,
|
| 91 |
-
output_dir = None,
|
| 92 |
-
overwrite_output_dir = None,
|
| 93 |
-
do_train = False,
|
| 94 |
-
do_eval = False,
|
| 95 |
-
do_predict = False,
|
| 96 |
-
eval_strategy = 'no',
|
| 97 |
-
prediction_loss_only = False,
|
| 98 |
-
per_device_train_batch_size = 4,
|
| 99 |
-
per_device_eval_batch_size = 4,
|
| 100 |
-
per_gpu_train_batch_size = None,
|
| 101 |
-
per_gpu_eval_batch_size = None,
|
| 102 |
-
gradient_accumulation_steps = 2,
|
| 103 |
-
eval_accumulation_steps = 2,
|
| 104 |
-
eval_delay = 0,
|
| 105 |
-
torch_empty_cache_steps = 250,
|
| 106 |
-
learning_rate = 5e-05,
|
| 107 |
-
weight_decay = 0.01,
|
| 108 |
-
adam_beta1 = 0.9,
|
| 109 |
-
adam_beta2 = 0.999,
|
| 110 |
-
adam_epsilon = 1e-08,
|
| 111 |
-
max_grad_norm = 1.0,
|
| 112 |
-
num_train_epochs = 3.0,
|
| 113 |
-
max_steps = -1,
|
| 114 |
-
lr_scheduler_type = 'linear',
|
| 115 |
-
warmup_ratio = 0.1,
|
| 116 |
-
warmup_steps = 0,
|
| 117 |
-
log_level = 'passive',
|
| 118 |
-
log_level_replica = 'warning',
|
| 119 |
-
log_on_each_node = True,
|
| 120 |
-
logging_dir = None,
|
| 121 |
-
logging_strategy = 'steps',
|
| 122 |
-
logging_first_step = False,
|
| 123 |
-
logging_steps = 1,
|
| 124 |
-
logging_nan_inf_filter = False,
|
| 125 |
-
save_strategy = 'steps',
|
| 126 |
-
save_steps = 500,
|
| 127 |
-
save_total_limit = None,
|
| 128 |
-
save_safetensors = True,
|
| 129 |
-
save_on_each_node = False,
|
| 130 |
-
save_only_model = False,
|
| 131 |
-
restore_callback_states_from_checkpoint = False,
|
| 132 |
-
no_cuda = False,
|
| 133 |
-
use_cpu = False,
|
| 134 |
-
use_mps_device = False,
|
| 135 |
-
seed = 3407,
|
| 136 |
-
data_seed = 3407,
|
| 137 |
-
jit_mode_eval = False,
|
| 138 |
-
use_ipex = False,
|
| 139 |
-
bf16 = False,
|
| 140 |
-
fp16 = False,
|
| 141 |
-
fp16_opt_level = 'O1',
|
| 142 |
-
half_precision_backend = 'auto',
|
| 143 |
-
bf16_full_eval = False,
|
| 144 |
-
fp16_full_eval = False,
|
| 145 |
-
tf32 = None,
|
| 146 |
-
local_rank = -1,
|
| 147 |
-
ddp_backend = None,
|
| 148 |
-
tpu_num_cores = None,
|
| 149 |
-
tpu_metrics_debug = False,
|
| 150 |
-
debug = '',
|
| 151 |
-
dataloader_drop_last = False,
|
| 152 |
-
eval_steps = None,
|
| 153 |
-
dataloader_num_workers = 0,
|
| 154 |
-
dataloader_prefetch_factor = None,
|
| 155 |
-
past_index = -1,
|
| 156 |
-
run_name = None,
|
| 157 |
-
disable_tqdm = None,
|
| 158 |
-
remove_unused_columns = True,
|
| 159 |
-
label_names = None,
|
| 160 |
-
load_best_model_at_end = False,
|
| 161 |
-
metric_for_best_model = None,
|
| 162 |
-
greater_is_better = None,
|
| 163 |
-
ignore_data_skip = False,
|
| 164 |
-
fsdp = '',
|
| 165 |
-
fsdp_min_num_params = 0,
|
| 166 |
-
fsdp_config = None,
|
| 167 |
-
fsdp_transformer_layer_cls_to_wrap = None,
|
| 168 |
-
accelerator_config = None,
|
| 169 |
-
deepspeed = None,
|
| 170 |
-
label_smoothing_factor = 0.0,
|
| 171 |
-
optim = 'adamw_8bit',
|
| 172 |
-
optim_args = None,
|
| 173 |
-
adafactor = False,
|
| 174 |
-
group_by_length = False,
|
| 175 |
-
length_column_name = 'length',
|
| 176 |
-
report_to = None,
|
| 177 |
-
ddp_find_unused_parameters = None,
|
| 178 |
-
ddp_bucket_cap_mb = None,
|
| 179 |
-
ddp_broadcast_buffers = None,
|
| 180 |
-
dataloader_pin_memory = True,
|
| 181 |
-
dataloader_persistent_workers = False,
|
| 182 |
-
skip_memory_metrics = True,
|
| 183 |
-
use_legacy_prediction_loop = False,
|
| 184 |
-
push_to_hub = False,
|
| 185 |
-
resume_from_checkpoint = None,
|
| 186 |
-
hub_model_id = None,
|
| 187 |
-
hub_strategy = 'every_save',
|
| 188 |
-
hub_token = None,
|
| 189 |
-
hub_private_repo = None,
|
| 190 |
-
hub_always_push = False,
|
| 191 |
-
hub_revision = None,
|
| 192 |
-
gradient_checkpointing = False,
|
| 193 |
-
gradient_checkpointing_kwargs = None,
|
| 194 |
-
include_inputs_for_metrics = False,
|
| 195 |
-
eval_do_concat_batches = True,
|
| 196 |
-
fp16_backend = 'auto',
|
| 197 |
-
push_to_hub_model_id = None,
|
| 198 |
-
push_to_hub_organization = None,
|
| 199 |
-
push_to_hub_token = None,
|
| 200 |
-
mp_parameters = '',
|
| 201 |
-
auto_find_batch_size = True,
|
| 202 |
-
full_determinism = False,
|
| 203 |
-
torchdynamo = None,
|
| 204 |
-
ray_scope = 'last',
|
| 205 |
-
ddp_timeout = 1800,
|
| 206 |
-
torch_compile = False,
|
| 207 |
-
torch_compile_backend = None,
|
| 208 |
-
torch_compile_mode = None,
|
| 209 |
-
include_tokens_per_second = False,
|
| 210 |
-
include_num_input_tokens_seen = False,
|
| 211 |
-
neftune_noise_alpha = None,
|
| 212 |
-
optim_target_modules = None,
|
| 213 |
-
batch_eval_metrics = False,
|
| 214 |
-
eval_on_start = False,
|
| 215 |
-
use_liger_kernel = False,
|
| 216 |
-
liger_kernel_config = None,
|
| 217 |
-
eval_use_gather_object = False,
|
| 218 |
-
average_tokens_across_devices = True,
|
| 219 |
-
max_length = 1024,
|
| 220 |
-
max_prompt_length = 512,
|
| 221 |
-
max_completion_length = None,
|
| 222 |
-
disable_dropout = True,
|
| 223 |
-
step_separator = '\
|
| 224 |
-
',
|
| 225 |
-
train_on_last_step_only = False,
|
| 226 |
-
dataset_num_proc = None,
|
| 227 |
-
vllm_sampling_params = None,
|
| 228 |
-
unsloth_num_chunks = -1,
|
| 229 |
-
**kwargs,
|
| 230 |
-
):
|
| 231 |
-
if learning_rate < 1e-7: raise FloatingPointError(f'Unsloth: Your learning rate of `{learning_rate}` is too small and less than 1e-7! Consider increasing it, otherwise gradient updates will be close to 0!')
|
| 232 |
-
if learning_rate > 1: raise OverflowError(f'Unsloth: Your learning rate of `{learning_rate}` is way too larger > 1! Consider decreasing it to 1e-1, otherwise gradient updates will explode!')
|
| 233 |
-
if output_dir is None and save_strategy == 'steps' and save_steps == 500:
|
| 234 |
-
output_dir = 'unsloth_training_checkpoints'
|
| 235 |
-
save_strategy = 'no'
|
| 236 |
-
if dataset_num_proc is None:
|
| 237 |
-
from multiprocessing import cpu_count
|
| 238 |
-
dataset_num_proc = min(cpu_count()*2, 2)
|
| 239 |
-
|
| 240 |
-
super().__init__(
|
| 241 |
-
output_dir = output_dir,
|
| 242 |
-
overwrite_output_dir = overwrite_output_dir,
|
| 243 |
-
do_train = do_train,
|
| 244 |
-
do_eval = do_eval,
|
| 245 |
-
do_predict = do_predict,
|
| 246 |
-
eval_strategy = eval_strategy,
|
| 247 |
-
prediction_loss_only = prediction_loss_only,
|
| 248 |
-
per_device_train_batch_size = per_device_train_batch_size,
|
| 249 |
-
per_device_eval_batch_size = per_device_eval_batch_size,
|
| 250 |
-
per_gpu_train_batch_size = per_gpu_train_batch_size,
|
| 251 |
-
per_gpu_eval_batch_size = per_gpu_eval_batch_size,
|
| 252 |
-
gradient_accumulation_steps = gradient_accumulation_steps,
|
| 253 |
-
eval_accumulation_steps = eval_accumulation_steps,
|
| 254 |
-
eval_delay = eval_delay,
|
| 255 |
-
torch_empty_cache_steps = torch_empty_cache_steps,
|
| 256 |
-
learning_rate = learning_rate,
|
| 257 |
-
weight_decay = weight_decay,
|
| 258 |
-
adam_beta1 = adam_beta1,
|
| 259 |
-
adam_beta2 = adam_beta2,
|
| 260 |
-
adam_epsilon = adam_epsilon,
|
| 261 |
-
max_grad_norm = max_grad_norm,
|
| 262 |
-
num_train_epochs = num_train_epochs,
|
| 263 |
-
max_steps = max_steps,
|
| 264 |
-
lr_scheduler_type = lr_scheduler_type,
|
| 265 |
-
warmup_ratio = warmup_ratio,
|
| 266 |
-
warmup_steps = warmup_steps,
|
| 267 |
-
log_level = log_level,
|
| 268 |
-
log_level_replica = log_level_replica,
|
| 269 |
-
log_on_each_node = log_on_each_node,
|
| 270 |
-
logging_dir = logging_dir,
|
| 271 |
-
logging_strategy = logging_strategy,
|
| 272 |
-
logging_first_step = logging_first_step,
|
| 273 |
-
logging_steps = logging_steps,
|
| 274 |
-
logging_nan_inf_filter = logging_nan_inf_filter,
|
| 275 |
-
save_strategy = save_strategy,
|
| 276 |
-
save_steps = save_steps,
|
| 277 |
-
save_total_limit = save_total_limit,
|
| 278 |
-
save_safetensors = save_safetensors,
|
| 279 |
-
save_on_each_node = save_on_each_node,
|
| 280 |
-
save_only_model = save_only_model,
|
| 281 |
-
restore_callback_states_from_checkpoint = restore_callback_states_from_checkpoint,
|
| 282 |
-
no_cuda = no_cuda,
|
| 283 |
-
use_cpu = use_cpu,
|
| 284 |
-
use_mps_device = use_mps_device,
|
| 285 |
-
seed = seed,
|
| 286 |
-
data_seed = data_seed,
|
| 287 |
-
jit_mode_eval = jit_mode_eval,
|
| 288 |
-
use_ipex = use_ipex,
|
| 289 |
-
bf16 = bf16,
|
| 290 |
-
fp16 = fp16,
|
| 291 |
-
fp16_opt_level = fp16_opt_level,
|
| 292 |
-
half_precision_backend = half_precision_backend,
|
| 293 |
-
bf16_full_eval = bf16_full_eval,
|
| 294 |
-
fp16_full_eval = fp16_full_eval,
|
| 295 |
-
tf32 = tf32,
|
| 296 |
-
local_rank = local_rank,
|
| 297 |
-
ddp_backend = ddp_backend,
|
| 298 |
-
tpu_num_cores = tpu_num_cores,
|
| 299 |
-
tpu_metrics_debug = tpu_metrics_debug,
|
| 300 |
-
debug = debug,
|
| 301 |
-
dataloader_drop_last = dataloader_drop_last,
|
| 302 |
-
eval_steps = eval_steps,
|
| 303 |
-
dataloader_num_workers = dataloader_num_workers,
|
| 304 |
-
dataloader_prefetch_factor = dataloader_prefetch_factor,
|
| 305 |
-
past_index = past_index,
|
| 306 |
-
run_name = run_name,
|
| 307 |
-
disable_tqdm = disable_tqdm,
|
| 308 |
-
remove_unused_columns = remove_unused_columns,
|
| 309 |
-
label_names = label_names,
|
| 310 |
-
load_best_model_at_end = load_best_model_at_end,
|
| 311 |
-
metric_for_best_model = metric_for_best_model,
|
| 312 |
-
greater_is_better = greater_is_better,
|
| 313 |
-
ignore_data_skip = ignore_data_skip,
|
| 314 |
-
fsdp = fsdp,
|
| 315 |
-
fsdp_min_num_params = fsdp_min_num_params,
|
| 316 |
-
fsdp_config = fsdp_config,
|
| 317 |
-
fsdp_transformer_layer_cls_to_wrap = fsdp_transformer_layer_cls_to_wrap,
|
| 318 |
-
accelerator_config = accelerator_config,
|
| 319 |
-
deepspeed = deepspeed,
|
| 320 |
-
label_smoothing_factor = label_smoothing_factor,
|
| 321 |
-
optim = optim,
|
| 322 |
-
optim_args = optim_args,
|
| 323 |
-
adafactor = adafactor,
|
| 324 |
-
group_by_length = group_by_length,
|
| 325 |
-
length_column_name = length_column_name,
|
| 326 |
-
report_to = report_to,
|
| 327 |
-
ddp_find_unused_parameters = ddp_find_unused_parameters,
|
| 328 |
-
ddp_bucket_cap_mb = ddp_bucket_cap_mb,
|
| 329 |
-
ddp_broadcast_buffers = ddp_broadcast_buffers,
|
| 330 |
-
dataloader_pin_memory = dataloader_pin_memory,
|
| 331 |
-
dataloader_persistent_workers = dataloader_persistent_workers,
|
| 332 |
-
skip_memory_metrics = skip_memory_metrics,
|
| 333 |
-
use_legacy_prediction_loop = use_legacy_prediction_loop,
|
| 334 |
-
push_to_hub = push_to_hub,
|
| 335 |
-
resume_from_checkpoint = resume_from_checkpoint,
|
| 336 |
-
hub_model_id = hub_model_id,
|
| 337 |
-
hub_strategy = hub_strategy,
|
| 338 |
-
hub_token = hub_token,
|
| 339 |
-
hub_private_repo = hub_private_repo,
|
| 340 |
-
hub_always_push = hub_always_push,
|
| 341 |
-
hub_revision = hub_revision,
|
| 342 |
-
gradient_checkpointing = gradient_checkpointing,
|
| 343 |
-
gradient_checkpointing_kwargs = gradient_checkpointing_kwargs,
|
| 344 |
-
include_inputs_for_metrics = include_inputs_for_metrics,
|
| 345 |
-
eval_do_concat_batches = eval_do_concat_batches,
|
| 346 |
-
fp16_backend = fp16_backend,
|
| 347 |
-
push_to_hub_model_id = push_to_hub_model_id,
|
| 348 |
-
push_to_hub_organization = push_to_hub_organization,
|
| 349 |
-
push_to_hub_token = push_to_hub_token,
|
| 350 |
-
mp_parameters = mp_parameters,
|
| 351 |
-
auto_find_batch_size = auto_find_batch_size,
|
| 352 |
-
full_determinism = full_determinism,
|
| 353 |
-
torchdynamo = torchdynamo,
|
| 354 |
-
ray_scope = ray_scope,
|
| 355 |
-
ddp_timeout = ddp_timeout,
|
| 356 |
-
torch_compile = torch_compile,
|
| 357 |
-
torch_compile_backend = torch_compile_backend,
|
| 358 |
-
torch_compile_mode = torch_compile_mode,
|
| 359 |
-
include_tokens_per_second = include_tokens_per_second,
|
| 360 |
-
include_num_input_tokens_seen = include_num_input_tokens_seen,
|
| 361 |
-
neftune_noise_alpha = neftune_noise_alpha,
|
| 362 |
-
optim_target_modules = optim_target_modules,
|
| 363 |
-
batch_eval_metrics = batch_eval_metrics,
|
| 364 |
-
eval_on_start = eval_on_start,
|
| 365 |
-
use_liger_kernel = use_liger_kernel,
|
| 366 |
-
liger_kernel_config = liger_kernel_config,
|
| 367 |
-
eval_use_gather_object = eval_use_gather_object,
|
| 368 |
-
average_tokens_across_devices = average_tokens_across_devices,
|
| 369 |
-
max_length = max_length,
|
| 370 |
-
max_prompt_length = max_prompt_length,
|
| 371 |
-
max_completion_length = max_completion_length,
|
| 372 |
-
disable_dropout = disable_dropout,
|
| 373 |
-
step_separator = step_separator,
|
| 374 |
-
train_on_last_step_only = train_on_last_step_only,
|
| 375 |
-
dataset_num_proc = dataset_num_proc,**kwargs)
|
| 376 |
-
self.vllm_sampling_params = vllm_sampling_params
|
| 377 |
-
self.unsloth_num_chunks = unsloth_num_chunks
|
| 378 |
-
pass
|
| 379 |
-
|
| 380 |
-
class _UnslothPRMTrainer(Trainer):
|
| 381 |
-
""""""
|
| 382 |
-
|
| 383 |
-
_tag_names = ["trl", "prm"]
|
| 384 |
-
|
| 385 |
-
def __init__(
|
| 386 |
-
self,
|
| 387 |
-
model: Optional[Union[PreTrainedModel, nn.Module]] = None,
|
| 388 |
-
args: Optional[PRMConfig] = None,
|
| 389 |
-
data_collator: Optional[DataCollator] = None,
|
| 390 |
-
train_dataset: Optional[Dataset] = None,
|
| 391 |
-
eval_dataset: Optional[Union[Dataset, dict[str, Dataset]]] = None,
|
| 392 |
-
processing_class: Optional[
|
| 393 |
-
Union[PreTrainedTokenizerBase, BaseImageProcessor, FeatureExtractionMixin, ProcessorMixin]
|
| 394 |
-
] = None,
|
| 395 |
-
model_init: Optional[Callable[[], PreTrainedModel]] = None,
|
| 396 |
-
compute_metrics: Optional[Callable[[EvalPrediction], dict]] = None,
|
| 397 |
-
callbacks: Optional[list[TrainerCallback]] = None,
|
| 398 |
-
optimizers: tuple[torch.optim.Optimizer, torch.optim.lr_scheduler.LambdaLR] = (
|
| 399 |
-
None,
|
| 400 |
-
None,
|
| 401 |
-
),
|
| 402 |
-
preprocess_logits_for_metrics: Optional[Callable[[torch.Tensor, torch.Tensor], torch.Tensor]] = None,
|
| 403 |
-
peft_config: Optional[dict] = None,
|
| 404 |
-
):
|
| 405 |
-
if not is_peft_available() and peft_config is not None:
|
| 406 |
-
raise ValueError(
|
| 407 |
-
"PEFT is not installed and you passed a `peft_config` in the trainer's kwargs, please install it to use the PEFT models"
|
| 408 |
-
)
|
| 409 |
-
elif is_peft_available() and peft_config is not None:
|
| 410 |
-
if not isinstance(model, PeftModel):
|
| 411 |
-
if getattr(model, "is_loaded_in_8bit", False) or getattr(model, "is_quantized", False):
|
| 412 |
-
_supports_gc_kwargs = "gradient_checkpointing_kwargs" in list(
|
| 413 |
-
inspect.signature(prepare_model_for_kbit_training).parameters
|
| 414 |
-
)
|
| 415 |
-
|
| 416 |
-
prepare_model_kwargs = {"use_gradient_checkpointing": args.gradient_checkpointing}
|
| 417 |
-
|
| 418 |
-
if not _supports_gc_kwargs and args.gradient_checkpointing_kwargs is not None:
|
| 419 |
-
warnings.warn(
|
| 420 |
-
"You passed `gradient_checkpointing_kwargs` in the trainer's kwargs, but your peft version does not support it. "
|
| 421 |
-
"please update to the latest version of peft to use `gradient_checkpointing_kwargs`."
|
| 422 |
-
)
|
| 423 |
-
elif _supports_gc_kwargs and args.gradient_checkpointing_kwargs is not None:
|
| 424 |
-
prepare_model_kwargs["gradient_checkpointing_kwargs"] = args.gradient_checkpointing_kwargs
|
| 425 |
-
|
| 426 |
-
model = prepare_model_for_kbit_training(model, **prepare_model_kwargs)
|
| 427 |
-
|
| 428 |
-
model = model
|
| 429 |
-
|
| 430 |
-
# Disable dropout in the model
|
| 431 |
-
if args.disable_dropout:
|
| 432 |
-
disable_dropout_in_model(model)
|
| 433 |
-
|
| 434 |
-
if compute_metrics is None:
|
| 435 |
-
compute_metrics = compute_accuracy
|
| 436 |
-
|
| 437 |
-
if data_collator is None:
|
| 438 |
-
if processing_class is None:
|
| 439 |
-
raise ValueError(
|
| 440 |
-
"A processing_class must be specified when using the default DataCollatorForTokenClassification"
|
| 441 |
-
)
|
| 442 |
-
data_collator = DataCollatorForTokenClassification(processing_class, max_length=args.max_length)
|
| 443 |
-
|
| 444 |
-
if "input_ids" not in train_dataset.column_names:
|
| 445 |
-
with PartialState().main_process_first():
|
| 446 |
-
fn_kwargs = {
|
| 447 |
-
"tokenizer": processing_class,
|
| 448 |
-
"step_separator": args.step_separator,
|
| 449 |
-
"max_length": args.max_length,
|
| 450 |
-
"max_prompt_length": args.max_prompt_length,
|
| 451 |
-
"max_completion_length": args.max_completion_length,
|
| 452 |
-
"train_on_last_step_only": args.train_on_last_step_only,
|
| 453 |
-
}
|
| 454 |
-
train_fn_kwargs = {**fn_kwargs, "is_eval": False}
|
| 455 |
-
train_dataset = train_dataset.map(
|
| 456 |
-
self.tokenize_row,
|
| 457 |
-
fn_kwargs=train_fn_kwargs,
|
| 458 |
-
num_proc=args.dataset_num_proc,
|
| 459 |
-
remove_columns=train_dataset.features,
|
| 460 |
-
desc="Tokenizing train dataset",
|
| 461 |
-
features=features.Features( # needed to avoid map to cast labels to bool
|
| 462 |
-
{
|
| 463 |
-
"labels": features.Sequence(features.Value("int64")),
|
| 464 |
-
"input_ids": features.Sequence(features.Value("int64")),
|
| 465 |
-
}
|
| 466 |
-
),
|
| 467 |
-
)
|
| 468 |
-
|
| 469 |
-
eval_fn_kwargs = {**fn_kwargs, "is_eval": True}
|
| 470 |
-
if eval_dataset is not None:
|
| 471 |
-
eval_dataset = eval_dataset.map(
|
| 472 |
-
self.tokenize_row,
|
| 473 |
-
fn_kwargs=eval_fn_kwargs,
|
| 474 |
-
num_proc=args.dataset_num_proc,
|
| 475 |
-
remove_columns=eval_dataset.features,
|
| 476 |
-
desc="Tokenizing eval dataset",
|
| 477 |
-
features=features.Features( # needed to avoid map to cast labels to bool
|
| 478 |
-
{
|
| 479 |
-
"labels": features.Sequence(features.Value("int64")),
|
| 480 |
-
"input_ids": features.Sequence(features.Value("int64")),
|
| 481 |
-
}
|
| 482 |
-
),
|
| 483 |
-
)
|
| 484 |
-
|
| 485 |
-
super().__init__(
|
| 486 |
-
model=model,
|
| 487 |
-
args=args,
|
| 488 |
-
data_collator=data_collator,
|
| 489 |
-
train_dataset=train_dataset,
|
| 490 |
-
eval_dataset=eval_dataset,
|
| 491 |
-
processing_class=processing_class,
|
| 492 |
-
model_init=model_init,
|
| 493 |
-
compute_metrics=compute_metrics,
|
| 494 |
-
callbacks=callbacks,
|
| 495 |
-
optimizers=optimizers,
|
| 496 |
-
preprocess_logits_for_metrics=preprocess_logits_for_metrics,
|
| 497 |
-
)
|
| 498 |
-
|
| 499 |
-
# Add tags for models that have been loaded with the correct transformers version
|
| 500 |
-
if hasattr(self.model, "add_model_tags"):
|
| 501 |
-
self.model.add_model_tags(self._tag_names)
|
| 502 |
-
|
| 503 |
-
@staticmethod
|
| 504 |
-
def tokenize_row(
|
| 505 |
-
features,
|
| 506 |
-
tokenizer,
|
| 507 |
-
step_separator,
|
| 508 |
-
max_length,
|
| 509 |
-
max_prompt_length,
|
| 510 |
-
max_completion_length,
|
| 511 |
-
train_on_last_step_only,
|
| 512 |
-
is_eval,
|
| 513 |
-
):
|
| 514 |
-
r"""
|
| 515 |
-
Tokenize a row of the dataset.
|
| 516 |
-
|
| 517 |
-
Args:
|
| 518 |
-
features (`dict[str, str]`):
|
| 519 |
-
Row of the dataset, should contain the keys `"prompt"`, `"completions"`, and `"labels"`.
|
| 520 |
-
tokenizer (`PreTrainedTokenizerBase`):
|
| 521 |
-
Tokenizer used to process the data.
|
| 522 |
-
step_separator (`str`):
|
| 523 |
-
Separator between steps in the completion.
|
| 524 |
-
max_length (`int` or `None`):
|
| 525 |
-
Maximum length of the sequences (prompt + completion). If `None`, the sequences are not truncated.
|
| 526 |
-
max_prompt_length (`int` or `None`):
|
| 527 |
-
Maximum length of the prompt. If `None`, the prompt is not truncated.
|
| 528 |
-
max_completion_length (`int` or `None`):
|
| 529 |
-
Maximum length of the completion sequences. If `None`, the completion sequences are not truncated.
|
| 530 |
-
train_on_last_step_only (`bool`):
|
| 531 |
-
Whether to train only on the last step. If `True`, the labels are `-100` for all tokens except the last
|
| 532 |
-
token of the completion.
|
| 533 |
-
is_eval (`bool`):
|
| 534 |
-
Whether the function is used to tokenize samples from a training or an evaluation dataset. Used only if `train_on_last_step_only` is set to `True`.
|
| 535 |
-
|
| 536 |
-
Returns:
|
| 537 |
-
`dict[str, list[int]]`:
|
| 538 |
-
Tokenized sequences with the keys `"input_ids"`, and `"labels".
|
| 539 |
-
|
| 540 |
-
Example:
|
| 541 |
-
```python
|
| 542 |
-
>>> from transformers import AutoTokenizer
|
| 543 |
-
>>> tokenizer = AutoTokenizer.from_pretrained("Qwen/Qwen2.5-0.5B")
|
| 544 |
-
>>> features = {"prompt": "Which number is larger, 9.8 or 9.11?",
|
| 545 |
-
... "completions": ["11 is greater than 8.",
|
| 546 |
-
... "Hence, 9.11 > 9.8."],
|
| 547 |
-
... "labels": [True, False]}
|
| 548 |
-
>>> PRMTrainer.tokenize_row(features, tokenizer, "\n", max_completion_length=None, train_on_last_step_only=False, is_eval=False)
|
| 549 |
-
{'input_ids': [23085, 1372, 374, 8131, 11, 220, 24, 13, 23, 476, 220, 24, 13, 16, 16, 30, 16, 16, 374, 7046, 1091, 220, 23, 13, 198, 39, 763, 11, 220, 24, 13, 16, 16, 861, 220, 24, 13, 23, 13, 198],
|
| 550 |
-
'labels': [-100, -100, -100, -100, -100, -100, -100, -100, 1, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, 0]}
|
| 551 |
-
```
|
| 552 |
-
"""
|
| 553 |
-
# Tokenize the prompt and completions
|
| 554 |
-
prompt_ids = tokenizer(features["prompt"], add_special_tokens=False)["input_ids"]
|
| 555 |
-
completions_ids = [
|
| 556 |
-
tokenizer(completion, add_special_tokens=False)["input_ids"] for completion in features["completions"]
|
| 557 |
-
]
|
| 558 |
-
if train_on_last_step_only and not is_eval:
|
| 559 |
-
labels = [-100] * (len(features["labels"]) - 1) + [int(features["labels"][-1])]
|
| 560 |
-
else:
|
| 561 |
-
labels = [int(label) for label in features["labels"]]
|
| 562 |
-
|
| 563 |
-
# Get the ID of the separator token and add it to the completions
|
| 564 |
-
separator_ids = tokenizer.encode(step_separator, add_special_tokens=False)
|
| 565 |
-
completions_ids = [completion + separator_ids for completion in completions_ids]
|
| 566 |
-
|
| 567 |
-
# Create the label
|
| 568 |
-
labels = [[-100] * (len(completion) - 1) + [label] for completion, label in zip(completions_ids, labels)]
|
| 569 |
-
|
| 570 |
-
# Join the completions and labels steps
|
| 571 |
-
completion_ids = list(chain(*completions_ids))
|
| 572 |
-
labels = list(chain(*labels))
|
| 573 |
-
|
| 574 |
-
if tokenizer.bos_token_id is not None:
|
| 575 |
-
prompt_ids = [tokenizer.bos_token_id] + prompt_ids
|
| 576 |
-
|
| 577 |
-
# Truncate prompt and completion sequences
|
| 578 |
-
if max_prompt_length is not None:
|
| 579 |
-
prompt_ids = prompt_ids[-max_prompt_length:]
|
| 580 |
-
if max_completion_length is not None:
|
| 581 |
-
completion_ids = completion_ids[:max_completion_length]
|
| 582 |
-
labels = labels[:max_completion_length]
|
| 583 |
-
|
| 584 |
-
input_ids = prompt_ids + completion_ids
|
| 585 |
-
labels = [-100] * len(prompt_ids) + labels
|
| 586 |
-
|
| 587 |
-
if max_length is not None:
|
| 588 |
-
input_ids = input_ids[:max_length]
|
| 589 |
-
labels = labels[:max_length]
|
| 590 |
-
|
| 591 |
-
return {"input_ids": input_ids, "labels": labels}
|
| 592 |
-
|
| 593 |
-
def create_model_card(
|
| 594 |
-
self,
|
| 595 |
-
model_name: Optional[str] = None,
|
| 596 |
-
dataset_name: Optional[str] = None,
|
| 597 |
-
tags: Union[str, list[str], None] = None,
|
| 598 |
-
):
|
| 599 |
-
"""
|
| 600 |
-
Creates a draft of a model card using the information available to the `Trainer`.
|
| 601 |
-
|
| 602 |
-
Args:
|
| 603 |
-
model_name (`str` or `None`, *optional*, defaults to `None`):
|
| 604 |
-
Name of the model.
|
| 605 |
-
dataset_name (`str` or `None`, *optional*, defaults to `None`):
|
| 606 |
-
Name of the dataset used for training.
|
| 607 |
-
tags (`str`, `list[str]` or `None`, *optional*, defaults to `None`):
|
| 608 |
-
Tags to be associated with the model card.
|
| 609 |
-
"""
|
| 610 |
-
if not self.is_world_process_zero():
|
| 611 |
-
return
|
| 612 |
-
|
| 613 |
-
if hasattr(self.model.config, "_name_or_path") and not os.path.isdir(self.model.config._name_or_path):
|
| 614 |
-
base_model = self.model.config._name_or_path
|
| 615 |
-
else:
|
| 616 |
-
base_model = None
|
| 617 |
-
|
| 618 |
-
tags = tags or []
|
| 619 |
-
if isinstance(tags, str):
|
| 620 |
-
tags = [tags]
|
| 621 |
-
|
| 622 |
-
if hasattr(self.model.config, "unsloth_version"):
|
| 623 |
-
tags.append("unsloth")
|
| 624 |
-
|
| 625 |
-
citation = textwrap.dedent("""\
|
| 626 |
-
@article{uesato2022solving,
|
| 627 |
-
title = {{Solving Math Word Problems With Process- and Outcome-Based Feedback}},
|
| 628 |
-
author = {Uesato, Jonathan and Kushman, Nate and Kumar, Ramana and Song, Francis and Siegel, Noah and Wang, Lisa and Creswell, Antonia and Irving, Geoffrey and Higgins, Irina},
|
| 629 |
-
year = 2022,
|
| 630 |
-
journal = {arXiv preprint arXiv:2211.14275}
|
| 631 |
-
}""")
|
| 632 |
-
|
| 633 |
-
model_card = generate_model_card(
|
| 634 |
-
base_model=base_model,
|
| 635 |
-
model_name=model_name,
|
| 636 |
-
hub_model_id=self.hub_model_id,
|
| 637 |
-
dataset_name=dataset_name,
|
| 638 |
-
tags=tags,
|
| 639 |
-
wandb_url=wandb.run.get_url() if is_wandb_available() and wandb.run is not None else None,
|
| 640 |
-
trainer_name="PRM",
|
| 641 |
-
trainer_citation=citation,
|
| 642 |
-
paper_title="Solving math word problems with process-and outcome-based feedback",
|
| 643 |
-
)
|
| 644 |
-
|
| 645 |
-
model_card.save(os.path.join(self.args.output_dir, "README.md"))
|
| 646 |
-
class UnslothPRMTrainer(_UnslothPRMTrainer):
|
| 647 |
-
"""
|
| 648 |
-
|
| 649 |
-
Initialize PRMTrainer.
|
| 650 |
-
|
| 651 |
-
Args:
|
| 652 |
-
model (`transformers.PreTrainedModel`):
|
| 653 |
-
The model to train, preferably an `AutoModelForTokenClassification`.
|
| 654 |
-
args (`PRMConfig`):
|
| 655 |
-
The arguments to use for training.
|
| 656 |
-
data_collator (`transformers.DataCollator`):
|
| 657 |
-
The data collator to use for training. If None is specified, the default data collator (`DataCollatorForTokenClassification`) will be used
|
| 658 |
-
which will pad the sequences to the maximum length of the sequences in the batch, given a dataset of paired sequences.
|
| 659 |
-
train_dataset (`datasets.Dataset`):
|
| 660 |
-
The dataset to use for training.
|
| 661 |
-
eval_dataset (`datasets.Dataset`):
|
| 662 |
-
The dataset to use for evaluation.
|
| 663 |
-
processing_class (`PreTrainedTokenizerBase` or `BaseImageProcessor` or `FeatureExtractionMixin` or `ProcessorMixin`, *optional*):
|
| 664 |
-
Processing class used to process the data. If provided, will be used to automatically process the inputs
|
| 665 |
-
for the model, and it will be saved along the model to make it easier to rerun an interrupted training or
|
| 666 |
-
reuse the fine-tuned model.
|
| 667 |
-
model_init (`Callable[[], transformers.PreTrainedModel]`):
|
| 668 |
-
The model initializer to use for training. If None is specified, the default model initializer will be used.
|
| 669 |
-
compute_metrics (`Callable[[transformers.EvalPrediction], dict]`, *optional* defaults to `compute_accuracy`):
|
| 670 |
-
The metrics to use for evaluation. If no metrics are specified, the default metric (`compute_accuracy`) will be used.
|
| 671 |
-
callbacks (`list[transformers.TrainerCallback]`):
|
| 672 |
-
The callbacks to use for training.
|
| 673 |
-
optimizers (`tuple[torch.optim.Optimizer, torch.optim.lr_scheduler.LambdaLR]`):
|
| 674 |
-
The optimizer and scheduler to use for training.
|
| 675 |
-
preprocess_logits_for_metrics (`Callable[[torch.Tensor, torch.Tensor], torch.Tensor]`):
|
| 676 |
-
The function to use to preprocess the logits before computing the metrics.
|
| 677 |
-
peft_config (`dict`, defaults to `None`):
|
| 678 |
-
The PEFT configuration to use for training. If you pass a PEFT configuration, the model will be wrapped in a PEFT model.
|
| 679 |
-
|
| 680 |
-
"""
|
| 681 |
-
def __init__(
|
| 682 |
-
self,
|
| 683 |
-
model = None,
|
| 684 |
-
args = None,
|
| 685 |
-
data_collator = None,
|
| 686 |
-
train_dataset = None,
|
| 687 |
-
eval_dataset = None,
|
| 688 |
-
processing_class = None,
|
| 689 |
-
model_init = None,
|
| 690 |
-
compute_metrics = None,
|
| 691 |
-
callbacks = None,
|
| 692 |
-
preprocess_logits_for_metrics = None,
|
| 693 |
-
peft_config = None,
|
| 694 |
-
**kwargs
|
| 695 |
-
):
|
| 696 |
-
if args is None: args = UnslothPRMConfig()
|
| 697 |
-
use_bf16 = getattr(args, 'bf16', False)
|
| 698 |
-
if type(use_bf16) is not bool: use_bf16 = False
|
| 699 |
-
use_fp16 = getattr(args, 'fp16', False)
|
| 700 |
-
if type(use_fp16) is not bool: use_fp16 = False
|
| 701 |
-
force_float32 = False
|
| 702 |
-
if os.environ.get('UNSLOTH_FORCE_FLOAT32', '0') == '1':
|
| 703 |
-
print('Unsloth: Switching to float32 training since model cannot work with float16')
|
| 704 |
-
force_float32 = True
|
| 705 |
-
mixed_precision_dtype = os.environ.get('UNSLOTH_MIXED_PRECISION', 'float32')
|
| 706 |
-
dtype = getattr(model.config, 'torch_dtype', None)
|
| 707 |
-
if dtype is None: dtype = model.get_input_embeddings().dtype
|
| 708 |
-
from unsloth_zoo.utils import _get_dtype
|
| 709 |
-
dtype = _get_dtype(dtype)
|
| 710 |
-
float16 = dtype == torch.float16
|
| 711 |
-
if not force_float32 and (float16 and use_bf16): raise TypeError('Unsloth: Model is in float16 precision but you want to use bfloat16 precision. Set fp16 to `True` and bf16 to `False`')
|
| 712 |
-
if not force_float32 and (not float16 and use_fp16): raise TypeError('Unsloth: Model is in bfloat16 precision but you want to use float16 precision. Set fp16 to `False` and bf16 to `True`')
|
| 713 |
-
if force_float32:
|
| 714 |
-
args.fp16 = False
|
| 715 |
-
args.bf16 = False
|
| 716 |
-
os.environ['ACCELERATE_MIXED_PRECISION'] = 'no'
|
| 717 |
-
elif (not use_bf16 and not use_fp16) and mixed_precision_dtype == 'float32':
|
| 718 |
-
args.fp16 = float16
|
| 719 |
-
args.bf16 = not float16
|
| 720 |
-
os.environ['ACCELERATE_MIXED_PRECISION'] = 'fp16' if float16 else 'bf16'
|
| 721 |
-
if getattr(args, 'eval_dataset', None) is not None and getattr(args, 'eval_strategy', 'no') == 'no':
|
| 722 |
-
args.eval_strategy = 'steps'
|
| 723 |
-
if getattr(args, 'eval_steps', None) is None: args.eval_steps = 0.1
|
| 724 |
-
ga_steps = getattr(args, 'gradient_accumulation_steps', None)
|
| 725 |
-
if ga_steps is not None and ga_steps > 1:
|
| 726 |
-
from transformers import __version__ as transformers_version
|
| 727 |
-
if Version(transformers_version) <= Version('4.45.2'):
|
| 728 |
-
print('**** Unsloth: Please use our fixed gradient_accumulation_steps by updating transformers, TRL and Unsloth!\n'
|
| 729 |
-
'`pip install --upgrade --no-cache-dir --force-reinstall --no-deps unsloth transformers trl unsloth_zoo`')
|
| 730 |
-
if getattr(args, 'eval_strategy', 'no') != 'no':
|
| 731 |
-
eval_bsz = getattr(args, 'per_device_eval_batch_size', 8)
|
| 732 |
-
if eval_bsz == 8 and args.per_device_train_batch_size < eval_bsz: args.per_device_eval_batch_size = args.per_device_train_batch_size
|
| 733 |
-
if getattr(args, 'eval_accumulation_steps', None) is None and ga_steps is not None: args.eval_accumulation_steps = ga_steps
|
| 734 |
-
fp16_full_eval = getattr(args, 'fp16_full_eval', False)
|
| 735 |
-
if type(fp16_full_eval) is not bool: fp16_full_eval = False
|
| 736 |
-
bf16_full_eval = getattr(args, 'bf16_full_eval', False)
|
| 737 |
-
if type(bf16_full_eval) is not bool: bf16_full_eval = False
|
| 738 |
-
if args.fp16 and bf16_full_eval: args.bf16_full_eval = False; args.fp16_full_eval = True
|
| 739 |
-
if args.bf16 and fp16_full_eval: args.bf16_full_eval = True; args.fp16_full_eval = False
|
| 740 |
-
if force_float32:
|
| 741 |
-
args.bf16_full_eval = False
|
| 742 |
-
args.fp16_full_eval = False
|
| 743 |
-
elif os.environ.get('UNSLOTH_MIXED_PRECISION', 'float32') == 'bfloat16':
|
| 744 |
-
args.bf16_full_eval = True
|
| 745 |
-
args.fp16_full_eval = False
|
| 746 |
-
elif not bf16_full_eval and not fp16_full_eval:
|
| 747 |
-
args.bf16_full_eval = args.bf16
|
| 748 |
-
args.fp16_full_eval = args.fp16
|
| 749 |
-
_output_logits = False
|
| 750 |
-
if locals().get('compute_metrics', None) is not None: _output_logits = True
|
| 751 |
-
if locals().get('preprocess_logits_for_metrics', None) is not None: _output_logits = True
|
| 752 |
-
if _output_logits:
|
| 753 |
-
os.environ['UNSLOTH_RETURN_LOGITS'] = '1'
|
| 754 |
-
if 'max_seq_length' not in locals() and not hasattr(args, 'max_seq_length'):
|
| 755 |
-
pass
|
| 756 |
-
else:
|
| 757 |
-
model_max_seq_length = getattr(model, 'max_seq_length', None)
|
| 758 |
-
args_max_seq_length = getattr(args, 'max_seq_length', None)
|
| 759 |
-
if args_max_seq_length is None and model_max_seq_length is not None:
|
| 760 |
-
max_seq_length = model.max_seq_length
|
| 761 |
-
if hasattr(args, 'max_seq_length'): args.max_seq_length = max_seq_length
|
| 762 |
-
if model is not None and hasattr(model, 'for_training'):
|
| 763 |
-
model.for_training()
|
| 764 |
-
if 'tokenizer' in locals() and hasattr(tokenizer, 'padding_side'): tokenizer.padding_side = 'right'
|
| 765 |
-
if 'processing_class' in locals():
|
| 766 |
-
if hasattr(processing_class, 'padding_side'): processing_class.padding_side = 'right'
|
| 767 |
-
if hasattr(processing_class, 'tokenizer') and hasattr(processing_class.tokenizer, 'padding_side'): processing_class.tokenizer.padding_side = 'right'
|
| 768 |
-
__tokenizer = processing_class if 'processing_class' in locals() else tokenizer
|
| 769 |
-
from unsloth_zoo.vision_utils import UnslothVisionDataCollator
|
| 770 |
-
if not isinstance(data_collator, UnslothVisionDataCollator):
|
| 771 |
-
if isinstance(data_collator, DataCollatorForSeq2Seq) and 'labels' not in train_dataset.column_names:
|
| 772 |
-
data_collator = TransformersDataCollatorForLanguageModeling(__tokenizer, mlm = False, mlm_probability = 0.0)
|
| 773 |
-
elif isinstance(data_collator, TransformersDataCollatorForLanguageModeling) and 'labels' in train_dataset.column_names:
|
| 774 |
-
data_collator = DataCollatorForSeq2Seq(__tokenizer)
|
| 775 |
-
else:
|
| 776 |
-
if hasattr(args, 'remove_unused_columns'): args.remove_unused_columns = False
|
| 777 |
-
if hasattr(args, 'dataset_text_field'): args.dataset_text_field = ''
|
| 778 |
-
if hasattr(args, 'dataset_kwargs'): args.dataset_kwargs = {'skip_prepare_dataset': True}
|
| 779 |
-
if not isinstance(data_collator, UnslothVisionDataCollator):
|
| 780 |
-
if not hasattr(__tokenizer, 'pad') and hasattr(__tokenizer, 'tokenizer'):
|
| 781 |
-
if isinstance(data_collator, DataCollatorForSeq2Seq):
|
| 782 |
-
data_collator = DataCollatorForSeq2Seq(__tokenizer.tokenizer)
|
| 783 |
-
else:
|
| 784 |
-
data_collator = TransformersDataCollatorForLanguageModeling(__tokenizer.tokenizer, mlm = False, mlm_probability = 0.0)
|
| 785 |
-
other_metrics = []
|
| 786 |
-
|
| 787 |
-
from unsloth_zoo.logging_utils import PatchRLStatistics
|
| 788 |
-
PatchRLStatistics('prm_trainer', other_metrics)
|
| 789 |
-
|
| 790 |
-
super().__init__(
|
| 791 |
-
model = model,
|
| 792 |
-
args = args,
|
| 793 |
-
data_collator = data_collator,
|
| 794 |
-
train_dataset = train_dataset,
|
| 795 |
-
eval_dataset = eval_dataset,
|
| 796 |
-
processing_class = processing_class,
|
| 797 |
-
model_init = model_init,
|
| 798 |
-
compute_metrics = compute_metrics,
|
| 799 |
-
callbacks = callbacks,
|
| 800 |
-
preprocess_logits_for_metrics = preprocess_logits_for_metrics,
|
| 801 |
-
peft_config = peft_config,**kwargs)
|
| 802 |
-
if hasattr(self, 'neftune_hook_handle'):
|
| 803 |
-
self.neftune_hook_handle.remove()
|
| 804 |
-
if hasattr(self, 'neftune_hook_handle'): del self.neftune_hook_handle
|
| 805 |
-
if getattr(args, 'neftune_noise_alpha', None) is not None:
|
| 806 |
-
model.get_input_embeddings().neftune_noise_alpha = self.neftune_noise_alpha
|
| 807 |
-
pass
|
| 808 |
-
|
| 809 |
-
pass
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
test_run_uploads/UnslothRLOOTrainer.py
DELETED
|
@@ -1,1143 +0,0 @@
|
|
| 1 |
-
"""
|
| 2 |
-
2025.7.11
|
| 3 |
-
2025.7.11
|
| 4 |
-
4.54.1
|
| 5 |
-
0.16.1
|
| 6 |
-
__UNSLOTH_VERSIONING__
|
| 7 |
-
"""
|
| 8 |
-
from torch import Tensor
|
| 9 |
-
import torch
|
| 10 |
-
import torch.nn as nn
|
| 11 |
-
from torch.nn import functional as F
|
| 12 |
-
from typing import Any, List, Optional, Tuple, Union, Dict, Set, Callable
|
| 13 |
-
from trl.trainer.rloo_trainer import (Accelerator, BaseImageProcessor, Callable, CallbackHandler, DEFAULT_CALLBACKS, DEFAULT_PROGRESS_CALLBACK, DataCollatorWithPadding, DataLoader, Dataset, ExportableState, FeatureExtractionMixin, GenerationConfig, INVALID_LOGPROB, OnlineTrainerState, Optional, PreTrainedTokenizerBase, PrinterCallback, ProcessorMixin, RLOOConfig, RLOOTrainer, Trainer, TrainerCallback, TrainerControl, Union, batch_generation, broadcast, defaultdict, disable_dropout_in_model, exact_div, first_true_indices, forward, gather_object, gc, generate_model_card, get_comet_experiment_url, get_reporting_integration_callbacks, get_reward, is_wandb_available, log_table_to_comet_experiment, math, nn, np, os, pd, prepare_deepspeed, print_rich_table, selective_log_softmax, textwrap, time, torch, truncate_response, unwrap_model_for_generation, wandb)
|
| 14 |
-
|
| 15 |
-
|
| 16 |
-
import os
|
| 17 |
-
from typing import *
|
| 18 |
-
from dataclasses import dataclass, field
|
| 19 |
-
from packaging.version import Version
|
| 20 |
-
import torch
|
| 21 |
-
import numpy as np
|
| 22 |
-
from contextlib import nullcontext
|
| 23 |
-
from torch.nn import functional as F
|
| 24 |
-
from transformers import DataCollatorForSeq2Seq, DataCollatorForLanguageModeling as TransformersDataCollatorForLanguageModeling
|
| 25 |
-
|
| 26 |
-
torch_compile_options = {
|
| 27 |
-
"epilogue_fusion" : True,
|
| 28 |
-
"max_autotune" : False,
|
| 29 |
-
"shape_padding" : True,
|
| 30 |
-
"trace.enabled" : False,
|
| 31 |
-
"triton.cudagraphs" : False,
|
| 32 |
-
}
|
| 33 |
-
|
| 34 |
-
@torch.compile(dynamic = True, fullgraph = True, options = torch_compile_options,)
|
| 35 |
-
def chunked_selective_log_softmax(logits, index):
|
| 36 |
-
# Split into 4 chunks only
|
| 37 |
-
chunked_logits = torch.chunk(logits.reshape(-1, logits.shape[-1]), chunks = 4, dim = 0)
|
| 38 |
-
chunked_index = torch.chunk(index.reshape(-1), chunks = 4, dim = 0)
|
| 39 |
-
all_per_token_logps = []
|
| 40 |
-
# Below loop does the same as selective_log_softmax(chunk_logits, chunk_index)
|
| 41 |
-
for chunk_logits, chunk_index in zip(chunked_logits, chunked_index):
|
| 42 |
-
chunk_logits = chunk_logits.to(torch.float32)
|
| 43 |
-
selected_logits = torch.gather(chunk_logits, dim = -1, index = chunk_index.unsqueeze(-1)).squeeze(-1)
|
| 44 |
-
logsumexp_values = torch.logsumexp(chunk_logits, dim = -1)
|
| 45 |
-
per_token_logps = selected_logits - logsumexp_values
|
| 46 |
-
all_per_token_logps.append(per_token_logps)
|
| 47 |
-
pass
|
| 48 |
-
all_per_token_logps = torch.concat(all_per_token_logps)
|
| 49 |
-
all_per_token_logps = all_per_token_logps.reshape((logits.shape[0], logits.shape[1]))
|
| 50 |
-
return all_per_token_logps
|
| 51 |
-
@dataclass
|
| 52 |
-
class UnslothRLOOConfig(RLOOConfig):
|
| 53 |
-
"""
|
| 54 |
-
|
| 55 |
-
Configuration class for the [`RLOOTrainer`].
|
| 56 |
-
|
| 57 |
-
Using [`~transformers.HfArgumentParser`] we can turn this class into
|
| 58 |
-
[argparse](https://docs.python.org/3/library/argparse#module-argparse) arguments that can be specified on the
|
| 59 |
-
command line.
|
| 60 |
-
|
| 61 |
-
Parameters:
|
| 62 |
-
exp_name (`str`, *optional*, defaults to `os.path.basename(__file__)[: -len(".py")]`):
|
| 63 |
-
Name of this experiment.
|
| 64 |
-
reward_model_path (`str`, *optional*, defaults to `"EleutherAI/pythia-160m"`):
|
| 65 |
-
Path to the reward model.
|
| 66 |
-
num_ppo_epochs (`int`, *optional*, defaults to `4`):
|
| 67 |
-
Number of epochs to train.
|
| 68 |
-
whiten_rewards (`bool`, *optional*, defaults to `False`):
|
| 69 |
-
Whether to whiten the rewards.
|
| 70 |
-
kl_coef (`float`, *optional*, defaults to `0.05`):
|
| 71 |
-
KL coefficient.
|
| 72 |
-
cliprange (`float`, *optional*, defaults to `0.2`):
|
| 73 |
-
Clip range.
|
| 74 |
-
rloo_k (`int`, *optional*, defaults to `2`):
|
| 75 |
-
REINFORCE Leave-One-Out (RLOO) number of online samples per prompt.
|
| 76 |
-
normalize_reward (`bool`, *optional*, defaults to `False`):
|
| 77 |
-
Whether to normalize rewards.
|
| 78 |
-
reward_clip_range (`float`, *optional*, defaults to `10.0`):
|
| 79 |
-
Clip range for rewards.
|
| 80 |
-
normalize_advantage (`bool`, *optional*, defaults to `False`):
|
| 81 |
-
Whether to normalize advantages.
|
| 82 |
-
token_level_kl (`bool`, *optional*, defaults to `True`):
|
| 83 |
-
Whether to use token-level KL penalty or sequence-level KL penalty.
|
| 84 |
-
ds3_gather_for_generation (`bool`, *optional*, defaults to `True`):
|
| 85 |
-
This setting applies to DeepSpeed ZeRO-3. If enabled, the policy model weights are gathered for generation,
|
| 86 |
-
improving generation speed. However, disabling this option allows training models that exceed the VRAM
|
| 87 |
-
capacity of a single GPU, albeit at the cost of slower generation.
|
| 88 |
-
|
| 89 |
-
"""
|
| 90 |
-
vllm_sampling_params: Optional[Any] = field(
|
| 91 |
-
default = None,
|
| 92 |
-
metadata = {'help': 'vLLM SamplingParams'},
|
| 93 |
-
)
|
| 94 |
-
unsloth_num_chunks : Optional[int] = field(
|
| 95 |
-
default = -1,
|
| 96 |
-
metadata = {'help': 'Chunk size to reduce memory usage. -1 is most efficient.'},
|
| 97 |
-
)
|
| 98 |
-
def __init__(
|
| 99 |
-
self,
|
| 100 |
-
output_dir = None,
|
| 101 |
-
overwrite_output_dir = None,
|
| 102 |
-
do_train = False,
|
| 103 |
-
do_eval = False,
|
| 104 |
-
do_predict = False,
|
| 105 |
-
eval_strategy = 'no',
|
| 106 |
-
prediction_loss_only = False,
|
| 107 |
-
per_device_train_batch_size = 4,
|
| 108 |
-
per_device_eval_batch_size = 4,
|
| 109 |
-
per_gpu_train_batch_size = None,
|
| 110 |
-
per_gpu_eval_batch_size = None,
|
| 111 |
-
gradient_accumulation_steps = 2,
|
| 112 |
-
eval_accumulation_steps = 2,
|
| 113 |
-
eval_delay = 0,
|
| 114 |
-
torch_empty_cache_steps = 250,
|
| 115 |
-
learning_rate = 5e-05,
|
| 116 |
-
weight_decay = 0.01,
|
| 117 |
-
adam_beta1 = 0.9,
|
| 118 |
-
adam_beta2 = 0.999,
|
| 119 |
-
adam_epsilon = 1e-08,
|
| 120 |
-
max_grad_norm = 1.0,
|
| 121 |
-
num_train_epochs = 3.0,
|
| 122 |
-
max_steps = -1,
|
| 123 |
-
lr_scheduler_type = 'linear',
|
| 124 |
-
warmup_ratio = 0.1,
|
| 125 |
-
warmup_steps = 0,
|
| 126 |
-
log_level = 'passive',
|
| 127 |
-
log_level_replica = 'warning',
|
| 128 |
-
log_on_each_node = True,
|
| 129 |
-
logging_dir = None,
|
| 130 |
-
logging_strategy = 'steps',
|
| 131 |
-
logging_first_step = False,
|
| 132 |
-
logging_steps = 1,
|
| 133 |
-
logging_nan_inf_filter = False,
|
| 134 |
-
save_strategy = 'steps',
|
| 135 |
-
save_steps = 500,
|
| 136 |
-
save_total_limit = None,
|
| 137 |
-
save_safetensors = True,
|
| 138 |
-
save_on_each_node = False,
|
| 139 |
-
save_only_model = False,
|
| 140 |
-
restore_callback_states_from_checkpoint = False,
|
| 141 |
-
no_cuda = False,
|
| 142 |
-
use_cpu = False,
|
| 143 |
-
use_mps_device = False,
|
| 144 |
-
seed = 3407,
|
| 145 |
-
data_seed = 3407,
|
| 146 |
-
jit_mode_eval = False,
|
| 147 |
-
use_ipex = False,
|
| 148 |
-
bf16 = False,
|
| 149 |
-
fp16 = False,
|
| 150 |
-
fp16_opt_level = 'O1',
|
| 151 |
-
half_precision_backend = 'auto',
|
| 152 |
-
bf16_full_eval = False,
|
| 153 |
-
fp16_full_eval = False,
|
| 154 |
-
tf32 = None,
|
| 155 |
-
local_rank = -1,
|
| 156 |
-
ddp_backend = None,
|
| 157 |
-
tpu_num_cores = None,
|
| 158 |
-
tpu_metrics_debug = False,
|
| 159 |
-
debug = '',
|
| 160 |
-
dataloader_drop_last = False,
|
| 161 |
-
eval_steps = None,
|
| 162 |
-
dataloader_num_workers = 0,
|
| 163 |
-
dataloader_prefetch_factor = None,
|
| 164 |
-
past_index = -1,
|
| 165 |
-
run_name = None,
|
| 166 |
-
disable_tqdm = None,
|
| 167 |
-
remove_unused_columns = True,
|
| 168 |
-
label_names = None,
|
| 169 |
-
load_best_model_at_end = False,
|
| 170 |
-
metric_for_best_model = None,
|
| 171 |
-
greater_is_better = None,
|
| 172 |
-
ignore_data_skip = False,
|
| 173 |
-
fsdp = '',
|
| 174 |
-
fsdp_min_num_params = 0,
|
| 175 |
-
fsdp_config = None,
|
| 176 |
-
fsdp_transformer_layer_cls_to_wrap = None,
|
| 177 |
-
accelerator_config = None,
|
| 178 |
-
deepspeed = None,
|
| 179 |
-
label_smoothing_factor = 0.0,
|
| 180 |
-
optim = 'adamw_8bit',
|
| 181 |
-
optim_args = None,
|
| 182 |
-
adafactor = False,
|
| 183 |
-
group_by_length = False,
|
| 184 |
-
length_column_name = 'length',
|
| 185 |
-
report_to = None,
|
| 186 |
-
ddp_find_unused_parameters = None,
|
| 187 |
-
ddp_bucket_cap_mb = None,
|
| 188 |
-
ddp_broadcast_buffers = None,
|
| 189 |
-
dataloader_pin_memory = True,
|
| 190 |
-
dataloader_persistent_workers = False,
|
| 191 |
-
skip_memory_metrics = True,
|
| 192 |
-
use_legacy_prediction_loop = False,
|
| 193 |
-
push_to_hub = False,
|
| 194 |
-
resume_from_checkpoint = None,
|
| 195 |
-
hub_model_id = None,
|
| 196 |
-
hub_strategy = 'every_save',
|
| 197 |
-
hub_token = None,
|
| 198 |
-
hub_private_repo = None,
|
| 199 |
-
hub_always_push = False,
|
| 200 |
-
hub_revision = None,
|
| 201 |
-
gradient_checkpointing = False,
|
| 202 |
-
gradient_checkpointing_kwargs = None,
|
| 203 |
-
include_inputs_for_metrics = False,
|
| 204 |
-
eval_do_concat_batches = True,
|
| 205 |
-
fp16_backend = 'auto',
|
| 206 |
-
push_to_hub_model_id = None,
|
| 207 |
-
push_to_hub_organization = None,
|
| 208 |
-
push_to_hub_token = None,
|
| 209 |
-
mp_parameters = '',
|
| 210 |
-
auto_find_batch_size = True,
|
| 211 |
-
full_determinism = False,
|
| 212 |
-
torchdynamo = None,
|
| 213 |
-
ray_scope = 'last',
|
| 214 |
-
ddp_timeout = 1800,
|
| 215 |
-
torch_compile = False,
|
| 216 |
-
torch_compile_backend = None,
|
| 217 |
-
torch_compile_mode = None,
|
| 218 |
-
include_tokens_per_second = False,
|
| 219 |
-
include_num_input_tokens_seen = False,
|
| 220 |
-
neftune_noise_alpha = None,
|
| 221 |
-
optim_target_modules = None,
|
| 222 |
-
batch_eval_metrics = False,
|
| 223 |
-
eval_on_start = False,
|
| 224 |
-
use_liger_kernel = False,
|
| 225 |
-
liger_kernel_config = None,
|
| 226 |
-
eval_use_gather_object = False,
|
| 227 |
-
average_tokens_across_devices = True,
|
| 228 |
-
dataset_num_proc = None,
|
| 229 |
-
num_mini_batches = 1,
|
| 230 |
-
total_episodes = None,
|
| 231 |
-
local_rollout_forward_batch_size = 64,
|
| 232 |
-
num_sample_generations = 10,
|
| 233 |
-
response_length = 53,
|
| 234 |
-
stop_token = None,
|
| 235 |
-
stop_token_id = None,
|
| 236 |
-
temperature = 0.7,
|
| 237 |
-
missing_eos_penalty = None,
|
| 238 |
-
sft_model_path = 'EleutherAI/pythia-160m',
|
| 239 |
-
world_size = None,
|
| 240 |
-
num_total_batches = None,
|
| 241 |
-
micro_batch_size = None,
|
| 242 |
-
local_batch_size = None,
|
| 243 |
-
batch_size = None,
|
| 244 |
-
local_mini_batch_size = None,
|
| 245 |
-
mini_batch_size = None,
|
| 246 |
-
exp_name = 'rloo_config',
|
| 247 |
-
reward_model_path = 'EleutherAI/pythia-160m',
|
| 248 |
-
num_ppo_epochs = 4,
|
| 249 |
-
whiten_rewards = False,
|
| 250 |
-
kl_coef = 0.05,
|
| 251 |
-
cliprange = 0.2,
|
| 252 |
-
rloo_k = 2,
|
| 253 |
-
normalize_reward = False,
|
| 254 |
-
reward_clip_range = 10.0,
|
| 255 |
-
normalize_advantage = False,
|
| 256 |
-
token_level_kl = False,
|
| 257 |
-
ds3_gather_for_generation = True,
|
| 258 |
-
vllm_sampling_params = None,
|
| 259 |
-
unsloth_num_chunks = -1,
|
| 260 |
-
**kwargs,
|
| 261 |
-
):
|
| 262 |
-
if learning_rate < 1e-7: raise FloatingPointError(f'Unsloth: Your learning rate of `{learning_rate}` is too small and less than 1e-7! Consider increasing it, otherwise gradient updates will be close to 0!')
|
| 263 |
-
if learning_rate > 1: raise OverflowError(f'Unsloth: Your learning rate of `{learning_rate}` is way too larger > 1! Consider decreasing it to 1e-1, otherwise gradient updates will explode!')
|
| 264 |
-
if output_dir is None and save_strategy == 'steps' and save_steps == 500:
|
| 265 |
-
output_dir = 'unsloth_training_checkpoints'
|
| 266 |
-
save_strategy = 'no'
|
| 267 |
-
if dataset_num_proc is None:
|
| 268 |
-
from multiprocessing import cpu_count
|
| 269 |
-
dataset_num_proc = min(cpu_count()*2, 2)
|
| 270 |
-
if temperature <= 0:
|
| 271 |
-
raise MathError('Unsloth: Please set a positive non-zero temperature since your results will be wrong.')
|
| 272 |
-
elif temperature >= 10:
|
| 273 |
-
raise MathError('Unsloth: Please set a positive non-zero temperature less than 10, since sampling will be quite erratic.')
|
| 274 |
-
|
| 275 |
-
|
| 276 |
-
super().__init__(
|
| 277 |
-
output_dir = output_dir,
|
| 278 |
-
overwrite_output_dir = overwrite_output_dir,
|
| 279 |
-
do_train = do_train,
|
| 280 |
-
do_eval = do_eval,
|
| 281 |
-
do_predict = do_predict,
|
| 282 |
-
eval_strategy = eval_strategy,
|
| 283 |
-
prediction_loss_only = prediction_loss_only,
|
| 284 |
-
per_device_train_batch_size = per_device_train_batch_size,
|
| 285 |
-
per_device_eval_batch_size = per_device_eval_batch_size,
|
| 286 |
-
per_gpu_train_batch_size = per_gpu_train_batch_size,
|
| 287 |
-
per_gpu_eval_batch_size = per_gpu_eval_batch_size,
|
| 288 |
-
gradient_accumulation_steps = gradient_accumulation_steps,
|
| 289 |
-
eval_accumulation_steps = eval_accumulation_steps,
|
| 290 |
-
eval_delay = eval_delay,
|
| 291 |
-
torch_empty_cache_steps = torch_empty_cache_steps,
|
| 292 |
-
learning_rate = learning_rate,
|
| 293 |
-
weight_decay = weight_decay,
|
| 294 |
-
adam_beta1 = adam_beta1,
|
| 295 |
-
adam_beta2 = adam_beta2,
|
| 296 |
-
adam_epsilon = adam_epsilon,
|
| 297 |
-
max_grad_norm = max_grad_norm,
|
| 298 |
-
num_train_epochs = num_train_epochs,
|
| 299 |
-
max_steps = max_steps,
|
| 300 |
-
lr_scheduler_type = lr_scheduler_type,
|
| 301 |
-
warmup_ratio = warmup_ratio,
|
| 302 |
-
warmup_steps = warmup_steps,
|
| 303 |
-
log_level = log_level,
|
| 304 |
-
log_level_replica = log_level_replica,
|
| 305 |
-
log_on_each_node = log_on_each_node,
|
| 306 |
-
logging_dir = logging_dir,
|
| 307 |
-
logging_strategy = logging_strategy,
|
| 308 |
-
logging_first_step = logging_first_step,
|
| 309 |
-
logging_steps = logging_steps,
|
| 310 |
-
logging_nan_inf_filter = logging_nan_inf_filter,
|
| 311 |
-
save_strategy = save_strategy,
|
| 312 |
-
save_steps = save_steps,
|
| 313 |
-
save_total_limit = save_total_limit,
|
| 314 |
-
save_safetensors = save_safetensors,
|
| 315 |
-
save_on_each_node = save_on_each_node,
|
| 316 |
-
save_only_model = save_only_model,
|
| 317 |
-
restore_callback_states_from_checkpoint = restore_callback_states_from_checkpoint,
|
| 318 |
-
no_cuda = no_cuda,
|
| 319 |
-
use_cpu = use_cpu,
|
| 320 |
-
use_mps_device = use_mps_device,
|
| 321 |
-
seed = seed,
|
| 322 |
-
data_seed = data_seed,
|
| 323 |
-
jit_mode_eval = jit_mode_eval,
|
| 324 |
-
use_ipex = use_ipex,
|
| 325 |
-
bf16 = bf16,
|
| 326 |
-
fp16 = fp16,
|
| 327 |
-
fp16_opt_level = fp16_opt_level,
|
| 328 |
-
half_precision_backend = half_precision_backend,
|
| 329 |
-
bf16_full_eval = bf16_full_eval,
|
| 330 |
-
fp16_full_eval = fp16_full_eval,
|
| 331 |
-
tf32 = tf32,
|
| 332 |
-
local_rank = local_rank,
|
| 333 |
-
ddp_backend = ddp_backend,
|
| 334 |
-
tpu_num_cores = tpu_num_cores,
|
| 335 |
-
tpu_metrics_debug = tpu_metrics_debug,
|
| 336 |
-
debug = debug,
|
| 337 |
-
dataloader_drop_last = dataloader_drop_last,
|
| 338 |
-
eval_steps = eval_steps,
|
| 339 |
-
dataloader_num_workers = dataloader_num_workers,
|
| 340 |
-
dataloader_prefetch_factor = dataloader_prefetch_factor,
|
| 341 |
-
past_index = past_index,
|
| 342 |
-
run_name = run_name,
|
| 343 |
-
disable_tqdm = disable_tqdm,
|
| 344 |
-
remove_unused_columns = remove_unused_columns,
|
| 345 |
-
label_names = label_names,
|
| 346 |
-
load_best_model_at_end = load_best_model_at_end,
|
| 347 |
-
metric_for_best_model = metric_for_best_model,
|
| 348 |
-
greater_is_better = greater_is_better,
|
| 349 |
-
ignore_data_skip = ignore_data_skip,
|
| 350 |
-
fsdp = fsdp,
|
| 351 |
-
fsdp_min_num_params = fsdp_min_num_params,
|
| 352 |
-
fsdp_config = fsdp_config,
|
| 353 |
-
fsdp_transformer_layer_cls_to_wrap = fsdp_transformer_layer_cls_to_wrap,
|
| 354 |
-
accelerator_config = accelerator_config,
|
| 355 |
-
deepspeed = deepspeed,
|
| 356 |
-
label_smoothing_factor = label_smoothing_factor,
|
| 357 |
-
optim = optim,
|
| 358 |
-
optim_args = optim_args,
|
| 359 |
-
adafactor = adafactor,
|
| 360 |
-
group_by_length = group_by_length,
|
| 361 |
-
length_column_name = length_column_name,
|
| 362 |
-
report_to = report_to,
|
| 363 |
-
ddp_find_unused_parameters = ddp_find_unused_parameters,
|
| 364 |
-
ddp_bucket_cap_mb = ddp_bucket_cap_mb,
|
| 365 |
-
ddp_broadcast_buffers = ddp_broadcast_buffers,
|
| 366 |
-
dataloader_pin_memory = dataloader_pin_memory,
|
| 367 |
-
dataloader_persistent_workers = dataloader_persistent_workers,
|
| 368 |
-
skip_memory_metrics = skip_memory_metrics,
|
| 369 |
-
use_legacy_prediction_loop = use_legacy_prediction_loop,
|
| 370 |
-
push_to_hub = push_to_hub,
|
| 371 |
-
resume_from_checkpoint = resume_from_checkpoint,
|
| 372 |
-
hub_model_id = hub_model_id,
|
| 373 |
-
hub_strategy = hub_strategy,
|
| 374 |
-
hub_token = hub_token,
|
| 375 |
-
hub_private_repo = hub_private_repo,
|
| 376 |
-
hub_always_push = hub_always_push,
|
| 377 |
-
hub_revision = hub_revision,
|
| 378 |
-
gradient_checkpointing = gradient_checkpointing,
|
| 379 |
-
gradient_checkpointing_kwargs = gradient_checkpointing_kwargs,
|
| 380 |
-
include_inputs_for_metrics = include_inputs_for_metrics,
|
| 381 |
-
eval_do_concat_batches = eval_do_concat_batches,
|
| 382 |
-
fp16_backend = fp16_backend,
|
| 383 |
-
push_to_hub_model_id = push_to_hub_model_id,
|
| 384 |
-
push_to_hub_organization = push_to_hub_organization,
|
| 385 |
-
push_to_hub_token = push_to_hub_token,
|
| 386 |
-
mp_parameters = mp_parameters,
|
| 387 |
-
auto_find_batch_size = auto_find_batch_size,
|
| 388 |
-
full_determinism = full_determinism,
|
| 389 |
-
torchdynamo = torchdynamo,
|
| 390 |
-
ray_scope = ray_scope,
|
| 391 |
-
ddp_timeout = ddp_timeout,
|
| 392 |
-
torch_compile = torch_compile,
|
| 393 |
-
torch_compile_backend = torch_compile_backend,
|
| 394 |
-
torch_compile_mode = torch_compile_mode,
|
| 395 |
-
include_tokens_per_second = include_tokens_per_second,
|
| 396 |
-
include_num_input_tokens_seen = include_num_input_tokens_seen,
|
| 397 |
-
neftune_noise_alpha = neftune_noise_alpha,
|
| 398 |
-
optim_target_modules = optim_target_modules,
|
| 399 |
-
batch_eval_metrics = batch_eval_metrics,
|
| 400 |
-
eval_on_start = eval_on_start,
|
| 401 |
-
use_liger_kernel = use_liger_kernel,
|
| 402 |
-
liger_kernel_config = liger_kernel_config,
|
| 403 |
-
eval_use_gather_object = eval_use_gather_object,
|
| 404 |
-
average_tokens_across_devices = average_tokens_across_devices,
|
| 405 |
-
dataset_num_proc = dataset_num_proc,
|
| 406 |
-
num_mini_batches = num_mini_batches,
|
| 407 |
-
total_episodes = total_episodes,
|
| 408 |
-
local_rollout_forward_batch_size = local_rollout_forward_batch_size,
|
| 409 |
-
num_sample_generations = num_sample_generations,
|
| 410 |
-
response_length = response_length,
|
| 411 |
-
stop_token = stop_token,
|
| 412 |
-
stop_token_id = stop_token_id,
|
| 413 |
-
temperature = temperature,
|
| 414 |
-
missing_eos_penalty = missing_eos_penalty,
|
| 415 |
-
sft_model_path = sft_model_path,
|
| 416 |
-
world_size = world_size,
|
| 417 |
-
num_total_batches = num_total_batches,
|
| 418 |
-
micro_batch_size = micro_batch_size,
|
| 419 |
-
local_batch_size = local_batch_size,
|
| 420 |
-
batch_size = batch_size,
|
| 421 |
-
local_mini_batch_size = local_mini_batch_size,
|
| 422 |
-
mini_batch_size = mini_batch_size,
|
| 423 |
-
exp_name = exp_name,
|
| 424 |
-
reward_model_path = reward_model_path,
|
| 425 |
-
num_ppo_epochs = num_ppo_epochs,
|
| 426 |
-
whiten_rewards = whiten_rewards,
|
| 427 |
-
kl_coef = kl_coef,
|
| 428 |
-
cliprange = cliprange,
|
| 429 |
-
rloo_k = rloo_k,
|
| 430 |
-
normalize_reward = normalize_reward,
|
| 431 |
-
reward_clip_range = reward_clip_range,
|
| 432 |
-
normalize_advantage = normalize_advantage,
|
| 433 |
-
token_level_kl = token_level_kl,
|
| 434 |
-
ds3_gather_for_generation = ds3_gather_for_generation,**kwargs)
|
| 435 |
-
self.vllm_sampling_params = vllm_sampling_params
|
| 436 |
-
self.unsloth_num_chunks = unsloth_num_chunks
|
| 437 |
-
pass
|
| 438 |
-
|
| 439 |
-
class _UnslothRLOOTrainer(Trainer):
|
| 440 |
-
_tag_names = ["trl", "rloo"]
|
| 441 |
-
|
| 442 |
-
def __init__(
|
| 443 |
-
self,
|
| 444 |
-
config: RLOOConfig,
|
| 445 |
-
processing_class: Optional[
|
| 446 |
-
Union[PreTrainedTokenizerBase, BaseImageProcessor, FeatureExtractionMixin, ProcessorMixin]
|
| 447 |
-
],
|
| 448 |
-
policy: nn.Module,
|
| 449 |
-
ref_policy: nn.Module,
|
| 450 |
-
reward_model: Union[nn.Module, Callable[[list[str]], list[float]]],
|
| 451 |
-
train_dataset: Dataset,
|
| 452 |
-
data_collator: Optional[DataCollatorWithPadding] = None,
|
| 453 |
-
eval_dataset: Optional[Union[Dataset, dict[str, Dataset]]] = None,
|
| 454 |
-
# less commonly used
|
| 455 |
-
optimizers: tuple[torch.optim.Optimizer, torch.optim.lr_scheduler.LambdaLR] = (None, None),
|
| 456 |
-
callbacks: Optional[list[TrainerCallback]] = None,
|
| 457 |
-
) -> None:
|
| 458 |
-
if ref_policy is policy:
|
| 459 |
-
raise ValueError(
|
| 460 |
-
"`policy` and `ref_policy` cannot be the same object. If you want `ref_policy` to be the "
|
| 461 |
-
"same as `policy`, you must mass a copy of it, or `None` if you use peft."
|
| 462 |
-
)
|
| 463 |
-
|
| 464 |
-
self.args = config
|
| 465 |
-
args = config
|
| 466 |
-
self.processing_class = processing_class
|
| 467 |
-
self.policy = policy
|
| 468 |
-
|
| 469 |
-
# Define the collator if not provided
|
| 470 |
-
if data_collator is None:
|
| 471 |
-
data_collator = DataCollatorWithPadding(self.processing_class)
|
| 472 |
-
|
| 473 |
-
self.policy.generation_config.eos_token_id = (
|
| 474 |
-
None # disable `pad_token_id` and `eos_token_id` because we just want to
|
| 475 |
-
)
|
| 476 |
-
self.policy.generation_config.pad_token_id = None # generate tokens without truncation / padding
|
| 477 |
-
|
| 478 |
-
self.ref_policy = ref_policy
|
| 479 |
-
self.reward_model = reward_model
|
| 480 |
-
self.train_dataset = train_dataset
|
| 481 |
-
self.train_dataset_len = len(train_dataset)
|
| 482 |
-
self.data_collator = data_collator
|
| 483 |
-
self.eval_dataset = eval_dataset
|
| 484 |
-
self.optimizer, self.lr_scheduler = optimizers
|
| 485 |
-
self.optimizer_cls_and_kwargs = None # needed for transformers >= 4.47
|
| 486 |
-
|
| 487 |
-
#########
|
| 488 |
-
# calculate various batch sizes
|
| 489 |
-
#########
|
| 490 |
-
if args.total_episodes is None: # allow the users to define episodes in terms of epochs.
|
| 491 |
-
args.total_episodes = int(args.num_train_epochs * self.train_dataset_len)
|
| 492 |
-
accelerator = Accelerator(gradient_accumulation_steps=args.gradient_accumulation_steps)
|
| 493 |
-
self.accelerator = accelerator
|
| 494 |
-
args.world_size = accelerator.num_processes
|
| 495 |
-
args.local_batch_size = (
|
| 496 |
-
args.per_device_train_batch_size * args.gradient_accumulation_steps * args.num_mini_batches
|
| 497 |
-
)
|
| 498 |
-
args.micro_batch_size = int(args.per_device_train_batch_size * args.world_size)
|
| 499 |
-
args.batch_size = int(args.local_batch_size * args.world_size)
|
| 500 |
-
args.mini_batch_size = exact_div(
|
| 501 |
-
args.batch_size, args.num_mini_batches, "`batch_size` must be a multiple of `num_mini_batches`"
|
| 502 |
-
)
|
| 503 |
-
args.local_mini_batch_size = exact_div(
|
| 504 |
-
args.local_batch_size, args.num_mini_batches, "`local_batch_size` must be a multiple of `num_mini_batches`"
|
| 505 |
-
)
|
| 506 |
-
args.num_total_batches = math.ceil(
|
| 507 |
-
args.total_episodes / args.batch_size
|
| 508 |
-
) # we may train for more than `total_episodes`
|
| 509 |
-
time_tensor = torch.tensor(int(time.time()), device=accelerator.device)
|
| 510 |
-
time_int = broadcast(time_tensor, 0).item() # avoid different timestamps across processes
|
| 511 |
-
args.run_name = f"{args.exp_name}__{args.seed}__{time_int}"
|
| 512 |
-
self.local_seed = args.seed + accelerator.process_index * 100003 # Prime
|
| 513 |
-
if args.num_sample_generations > 0:
|
| 514 |
-
self.sample_generations_freq = max(1, args.num_total_batches // args.num_sample_generations)
|
| 515 |
-
self.local_dataloader_batch_size = exact_div(
|
| 516 |
-
args.local_batch_size, args.rloo_k, "`local_batch_size` must be a multiple of rloo_k"
|
| 517 |
-
) # RLOO logic: needed because RLOO repeats the same prompt args.rloo_k times
|
| 518 |
-
|
| 519 |
-
#########
|
| 520 |
-
# setup model, optimizer, and others
|
| 521 |
-
#########
|
| 522 |
-
for module in [policy, ref_policy, reward_model]:
|
| 523 |
-
if isinstance(module, nn.Module):
|
| 524 |
-
disable_dropout_in_model(module)
|
| 525 |
-
if args.stop_token and args.stop_token == "eos":
|
| 526 |
-
args.stop_token_id = self.processing_class.eos_token_id
|
| 527 |
-
self.model = policy
|
| 528 |
-
self.create_optimizer_and_scheduler(
|
| 529 |
-
num_training_steps=args.num_total_batches
|
| 530 |
-
) # note that we are calling `self.lr_scheduler.step[]` manually only at the batch level
|
| 531 |
-
|
| 532 |
-
#########
|
| 533 |
-
### trainer specifics
|
| 534 |
-
#########
|
| 535 |
-
default_callbacks = DEFAULT_CALLBACKS + get_reporting_integration_callbacks(self.args.report_to)
|
| 536 |
-
self.callbacks = default_callbacks if callbacks is None else default_callbacks + callbacks
|
| 537 |
-
self.callback_handler = CallbackHandler(
|
| 538 |
-
self.callbacks, self.model, self.processing_class, self.optimizer, self.lr_scheduler
|
| 539 |
-
)
|
| 540 |
-
self.add_callback(PrinterCallback if self.args.disable_tqdm else DEFAULT_PROGRESS_CALLBACK)
|
| 541 |
-
self.control = TrainerControl()
|
| 542 |
-
self.state = OnlineTrainerState(
|
| 543 |
-
is_local_process_zero=self.is_local_process_zero(),
|
| 544 |
-
is_world_process_zero=self.is_world_process_zero(),
|
| 545 |
-
stateful_callbacks=[
|
| 546 |
-
cb for cb in self.callback_handler.callbacks + [self.control] if isinstance(cb, ExportableState)
|
| 547 |
-
],
|
| 548 |
-
)
|
| 549 |
-
|
| 550 |
-
self.current_flos = 0
|
| 551 |
-
self.hp_search_backend = None
|
| 552 |
-
self.is_deepspeed_enabled = getattr(self.accelerator.state, "deepspeed_plugin", None) is not None
|
| 553 |
-
self.is_fsdp_enabled = getattr(self.accelerator.state, "fsdp_plugin", None) is not None
|
| 554 |
-
# Create distant repo and output directory if needed
|
| 555 |
-
self.hub_model_id = None
|
| 556 |
-
if self.args.push_to_hub:
|
| 557 |
-
self.init_hf_repo()
|
| 558 |
-
if self.args.should_save:
|
| 559 |
-
os.makedirs(self.args.output_dir, exist_ok=True)
|
| 560 |
-
self.backup_model = None
|
| 561 |
-
|
| 562 |
-
# Add tags for models that have been loaded with the correct transformers version
|
| 563 |
-
if hasattr(self.model, "add_model_tags"):
|
| 564 |
-
self.model.add_model_tags(self._tag_names)
|
| 565 |
-
|
| 566 |
-
#########
|
| 567 |
-
### setup dataloader
|
| 568 |
-
#########
|
| 569 |
-
self.dataloader = DataLoader(
|
| 570 |
-
self.train_dataset,
|
| 571 |
-
batch_size=self.local_dataloader_batch_size,
|
| 572 |
-
shuffle=True,
|
| 573 |
-
collate_fn=self.data_collator,
|
| 574 |
-
drop_last=True, # needed; otherwise the last batch will be of ragged shape
|
| 575 |
-
)
|
| 576 |
-
# sync random states for DataLoader[shuffle=True] before `accelerator.prepare`
|
| 577 |
-
# see https://gist.github.com/vwxyzjn/2581bff1e48e185e0b85b6dfe1def79c
|
| 578 |
-
torch.manual_seed(args.seed)
|
| 579 |
-
self.model, self.optimizer, self.dataloader = accelerator.prepare(self.model, self.optimizer, self.dataloader)
|
| 580 |
-
torch.manual_seed(self.local_seed) # reset the local seed again
|
| 581 |
-
|
| 582 |
-
self.eval_dataloader = DataLoader(
|
| 583 |
-
self.eval_dataset,
|
| 584 |
-
batch_size=args.per_device_eval_batch_size,
|
| 585 |
-
collate_fn=self.data_collator,
|
| 586 |
-
drop_last=True,
|
| 587 |
-
) # no need to shuffle eval dataset
|
| 588 |
-
self.eval_dataloader = accelerator.prepare(self.eval_dataloader)
|
| 589 |
-
|
| 590 |
-
if self.is_deepspeed_enabled:
|
| 591 |
-
if isinstance(self.reward_model, nn.Module):
|
| 592 |
-
self.reward_model = prepare_deepspeed(
|
| 593 |
-
self.reward_model, args.per_device_train_batch_size, args.fp16, args.bf16
|
| 594 |
-
)
|
| 595 |
-
self.ref_policy = prepare_deepspeed(
|
| 596 |
-
self.ref_policy, args.per_device_train_batch_size, args.fp16, args.bf16
|
| 597 |
-
)
|
| 598 |
-
self.deepspeed = self.model
|
| 599 |
-
else:
|
| 600 |
-
self.ref_policy = self.ref_policy.to(self.accelerator.device)
|
| 601 |
-
if isinstance(self.reward_model, nn.Module):
|
| 602 |
-
self.reward_model = self.reward_model.to(self.accelerator.device)
|
| 603 |
-
|
| 604 |
-
def get_train_dataloader(self) -> DataLoader:
|
| 605 |
-
return self.dataloader
|
| 606 |
-
|
| 607 |
-
def get_eval_dataloader(self) -> DataLoader:
|
| 608 |
-
return self.eval_dataloader
|
| 609 |
-
|
| 610 |
-
def train(self):
|
| 611 |
-
args = self.args
|
| 612 |
-
accelerator = self.accelerator
|
| 613 |
-
optimizer = self.optimizer
|
| 614 |
-
model = self.model
|
| 615 |
-
self.model_wrapped = self.model
|
| 616 |
-
ref_policy = self.ref_policy
|
| 617 |
-
reward_model = self.reward_model
|
| 618 |
-
processing_class = self.processing_class
|
| 619 |
-
dataloader = self.dataloader
|
| 620 |
-
device = accelerator.device
|
| 621 |
-
|
| 622 |
-
def repeat_generator():
|
| 623 |
-
while True:
|
| 624 |
-
yield from dataloader
|
| 625 |
-
|
| 626 |
-
iter_dataloader = iter(repeat_generator())
|
| 627 |
-
generation_config = GenerationConfig(
|
| 628 |
-
max_new_tokens=args.response_length,
|
| 629 |
-
temperature=(args.temperature + 1e-7),
|
| 630 |
-
top_k=0.0,
|
| 631 |
-
top_p=1.0,
|
| 632 |
-
do_sample=True,
|
| 633 |
-
)
|
| 634 |
-
|
| 635 |
-
accelerator.print("===training policy===")
|
| 636 |
-
start_time = time.time()
|
| 637 |
-
stats_shape = (args.num_ppo_epochs, args.num_mini_batches, args.gradient_accumulation_steps)
|
| 638 |
-
approxkl_stats = torch.zeros(stats_shape, device=device)
|
| 639 |
-
pg_clipfrac_stats = torch.zeros(stats_shape, device=device)
|
| 640 |
-
pg_loss_stats = torch.zeros(stats_shape, device=device)
|
| 641 |
-
vf_clipfrac_stats = torch.zeros(stats_shape, device=device)
|
| 642 |
-
entropy_stats = torch.zeros(stats_shape, device=device)
|
| 643 |
-
ratio_stats = torch.zeros(stats_shape, device=device)
|
| 644 |
-
model.train()
|
| 645 |
-
|
| 646 |
-
# trainer state initialization
|
| 647 |
-
self.state.global_step = 0
|
| 648 |
-
self.state.episode = 0
|
| 649 |
-
self.state.max_steps = (args.num_total_batches * args.num_mini_batches) // 2
|
| 650 |
-
self.state.num_train_epochs = args.total_episodes / self.train_dataset_len
|
| 651 |
-
# Compute absolute values for logging, eval, and save if given as ratio
|
| 652 |
-
if args.logging_steps is not None:
|
| 653 |
-
if args.logging_steps < 1:
|
| 654 |
-
self.state.logging_steps = math.ceil(self.state.max_steps * args.logging_steps)
|
| 655 |
-
else:
|
| 656 |
-
self.state.logging_steps = args.logging_steps
|
| 657 |
-
if args.eval_steps is not None:
|
| 658 |
-
if args.eval_steps < 1:
|
| 659 |
-
self.state.eval_steps = math.ceil(self.state.max_steps * args.eval_steps)
|
| 660 |
-
else:
|
| 661 |
-
self.state.eval_steps = args.eval_steps
|
| 662 |
-
if args.save_steps is not None:
|
| 663 |
-
if args.save_steps < 1:
|
| 664 |
-
self.state.save_steps = math.ceil(self.state.max_steps * args.save_steps)
|
| 665 |
-
else:
|
| 666 |
-
self.state.save_steps = args.save_steps
|
| 667 |
-
self.control = self.callback_handler.on_train_begin(args, self.state, self.control)
|
| 668 |
-
|
| 669 |
-
for update in range(1, args.num_total_batches + 1):
|
| 670 |
-
self.state.episode += 1 * args.batch_size
|
| 671 |
-
data = next(iter_dataloader)
|
| 672 |
-
with torch.no_grad():
|
| 673 |
-
queries = data["input_ids"].to(device)
|
| 674 |
-
queries = queries.repeat(args.rloo_k, 1)
|
| 675 |
-
context_length = queries.shape[1]
|
| 676 |
-
responses = []
|
| 677 |
-
postprocessed_responses = []
|
| 678 |
-
logprobs = []
|
| 679 |
-
ref_logprobs = []
|
| 680 |
-
scores = []
|
| 681 |
-
sequence_lengths = []
|
| 682 |
-
|
| 683 |
-
# Generate responses and compute logprobs
|
| 684 |
-
with unwrap_model_for_generation(
|
| 685 |
-
self.model, self.accelerator, gather_deepspeed3_params=self.args.ds3_gather_for_generation
|
| 686 |
-
) as unwrapped_model:
|
| 687 |
-
query_responses, logitss = batch_generation(
|
| 688 |
-
unwrapped_model,
|
| 689 |
-
queries,
|
| 690 |
-
args.local_rollout_forward_batch_size,
|
| 691 |
-
processing_class.pad_token_id,
|
| 692 |
-
generation_config,
|
| 693 |
-
)
|
| 694 |
-
|
| 695 |
-
# Process responses in batches
|
| 696 |
-
for i in range(0, queries.shape[0], args.local_rollout_forward_batch_size):
|
| 697 |
-
query = queries[i : i + args.local_rollout_forward_batch_size]
|
| 698 |
-
query_response = query_responses[i : i + args.local_rollout_forward_batch_size]
|
| 699 |
-
response = query_response[:, context_length:]
|
| 700 |
-
logits = logitss[i : i + args.local_rollout_forward_batch_size]
|
| 701 |
-
logprob = selective_log_softmax(logits, response)
|
| 702 |
-
del logits
|
| 703 |
-
torch.cuda.empty_cache()
|
| 704 |
-
|
| 705 |
-
ref_output = forward(ref_policy, query_response, processing_class.pad_token_id)
|
| 706 |
-
ref_logits = ref_output.logits[:, context_length - 1 : -1]
|
| 707 |
-
ref_logits /= args.temperature + 1e-7
|
| 708 |
-
ref_logprob = selective_log_softmax(ref_logits, response)
|
| 709 |
-
del ref_output, ref_logits
|
| 710 |
-
torch.cuda.empty_cache()
|
| 711 |
-
|
| 712 |
-
# Response Processing 1. truncate response after the first occurrence of `stop_token_id`
|
| 713 |
-
postprocessed_response = response
|
| 714 |
-
if args.stop_token_id is not None: # handle the edge case when stop_token_id exists but is 0
|
| 715 |
-
postprocessed_response = truncate_response(
|
| 716 |
-
args.stop_token_id, processing_class.pad_token_id, response
|
| 717 |
-
)
|
| 718 |
-
|
| 719 |
-
# Response Processing 2. run reward model on the truncated responses
|
| 720 |
-
postprocessed_query_response = torch.cat((query, postprocessed_response), 1)
|
| 721 |
-
sequence_length = first_true_indices(postprocessed_response == processing_class.pad_token_id) - 1
|
| 722 |
-
|
| 723 |
-
if isinstance(reward_model, nn.Module):
|
| 724 |
-
_, score, _ = get_reward(
|
| 725 |
-
reward_model, postprocessed_query_response, processing_class.pad_token_id, context_length
|
| 726 |
-
)
|
| 727 |
-
else:
|
| 728 |
-
score = torch.tensor(
|
| 729 |
-
reward_model(
|
| 730 |
-
processing_class.batch_decode(postprocessed_query_response, skip_special_tokens=True)
|
| 731 |
-
),
|
| 732 |
-
dtype=torch.float,
|
| 733 |
-
).to(device)
|
| 734 |
-
|
| 735 |
-
# Store batch results
|
| 736 |
-
responses.append(response)
|
| 737 |
-
postprocessed_responses.append(postprocessed_response)
|
| 738 |
-
logprobs.append(logprob)
|
| 739 |
-
ref_logprobs.append(ref_logprob)
|
| 740 |
-
sequence_lengths.append(sequence_length)
|
| 741 |
-
scores.append(score)
|
| 742 |
-
|
| 743 |
-
# Concatenate all batched results
|
| 744 |
-
responses = torch.cat(responses, 0)
|
| 745 |
-
postprocessed_responses = torch.cat(postprocessed_responses, 0)
|
| 746 |
-
logprobs = torch.cat(logprobs, 0)
|
| 747 |
-
ref_logprobs = torch.cat(ref_logprobs, 0)
|
| 748 |
-
sequence_lengths = torch.cat(sequence_lengths, 0)
|
| 749 |
-
scores = torch.cat(scores, 0)
|
| 750 |
-
del (logprob, ref_logprob, score)
|
| 751 |
-
torch.cuda.empty_cache()
|
| 752 |
-
gc.collect()
|
| 753 |
-
|
| 754 |
-
# Response Processing 3. filter response. Ensure that the sample contains stop_token_id
|
| 755 |
-
# responses not passing that filter will receive a low (fixed) score
|
| 756 |
-
# only query humans on responses that pass that filter
|
| 757 |
-
contain_eos_token = torch.any(postprocessed_responses == processing_class.eos_token_id, dim=-1)
|
| 758 |
-
if args.missing_eos_penalty is not None:
|
| 759 |
-
scores[~contain_eos_token] -= self.args.missing_eos_penalty
|
| 760 |
-
# accelerator.print(f"{scores=}, {(contain_eos_token.sum() / len(contain_eos_token))=}")
|
| 761 |
-
|
| 762 |
-
# be very careful with `padding_mask_p1`; see https://excalidraw.com/#json=LWnzG4w2k5DjF_EOL_xPt,e2w3a-hFJ_gX5vOfeyXGTw
|
| 763 |
-
response_idxs = torch.arange(responses.shape[1], device=responses.device).repeat(responses.shape[0], 1)
|
| 764 |
-
padding_mask = response_idxs > sequence_lengths.unsqueeze(1)
|
| 765 |
-
logprobs = torch.masked_fill(logprobs, padding_mask, INVALID_LOGPROB)
|
| 766 |
-
ref_logprobs = torch.masked_fill(ref_logprobs, padding_mask, INVALID_LOGPROB)
|
| 767 |
-
|
| 768 |
-
# 4. compute rewards
|
| 769 |
-
# Compute KL divergence
|
| 770 |
-
kl = logprobs - ref_logprobs
|
| 771 |
-
|
| 772 |
-
# Normalize rewards
|
| 773 |
-
if args.normalize_reward:
|
| 774 |
-
scores = (scores - scores.mean()) / (scores.std() + 1e-8)
|
| 775 |
-
scores = torch.clamp(scores, -args.reward_clip_range, args.reward_clip_range)
|
| 776 |
-
|
| 777 |
-
# Compute total reward with KL penalty
|
| 778 |
-
if args.token_level_kl:
|
| 779 |
-
# Token-level KL penalty: apply KL penalty per token
|
| 780 |
-
kl_reward = -args.kl_coef * kl
|
| 781 |
-
|
| 782 |
-
# Get the index of the last non-padded token for each sequence
|
| 783 |
-
eos_indices = padding_mask.size(1) - 1 - padding_mask.long().fliplr().argmax(dim=1, keepdim=True)
|
| 784 |
-
last_reward = torch.zeros_like(kl)
|
| 785 |
-
# Ensure scores has correct shape and type
|
| 786 |
-
scores_shaped = scores.reshape(-1, 1).to(kl.dtype)
|
| 787 |
-
last_reward.scatter_(dim=1, index=eos_indices, src=scores_shaped)
|
| 788 |
-
|
| 789 |
-
# Combine KL reward and last reward
|
| 790 |
-
non_score_reward = kl_reward.sum(1) # Keep this for logging
|
| 791 |
-
reward = last_reward + kl_reward
|
| 792 |
-
rlhf_reward = reward.sum(1) # Sum across sequence length
|
| 793 |
-
else:
|
| 794 |
-
# Sequence-level KL penalty: sum KL across tokens first
|
| 795 |
-
sequence_kl = kl.sum(1)
|
| 796 |
-
non_score_reward = -args.kl_coef * sequence_kl
|
| 797 |
-
rlhf_reward = non_score_reward + scores
|
| 798 |
-
|
| 799 |
-
# vectorized RLOO advantages implementation
|
| 800 |
-
rlhf_reward = rlhf_reward.reshape(args.rloo_k, -1)
|
| 801 |
-
baseline = (rlhf_reward.sum(0) - rlhf_reward) / (args.rloo_k - 1)
|
| 802 |
-
advantages = rlhf_reward - baseline
|
| 803 |
-
advantages = advantages.flatten()
|
| 804 |
-
|
| 805 |
-
# Normalize advantages
|
| 806 |
-
if args.normalize_advantage:
|
| 807 |
-
advantages = (advantages - advantages.mean()) / (advantages.std() + 1e-8)
|
| 808 |
-
|
| 809 |
-
torch.cuda.empty_cache()
|
| 810 |
-
|
| 811 |
-
# Do multiple epochs of PPO training, with a fresh random shuffle in each epoch
|
| 812 |
-
for ppo_epoch_idx in range(args.num_ppo_epochs):
|
| 813 |
-
b_inds = np.random.permutation(args.local_batch_size)
|
| 814 |
-
minibatch_idx = 0
|
| 815 |
-
for mini_batch_start in range(0, args.local_batch_size, args.local_mini_batch_size):
|
| 816 |
-
mini_batch_end = mini_batch_start + args.local_mini_batch_size
|
| 817 |
-
mini_batch_inds = b_inds[mini_batch_start:mini_batch_end]
|
| 818 |
-
gradient_accumulation_idx = 0
|
| 819 |
-
for micro_batch_start in range(0, args.local_mini_batch_size, args.per_device_train_batch_size):
|
| 820 |
-
with accelerator.accumulate(model):
|
| 821 |
-
micro_batch_end = micro_batch_start + args.per_device_train_batch_size
|
| 822 |
-
micro_batch_inds = mini_batch_inds[micro_batch_start:micro_batch_end]
|
| 823 |
-
|
| 824 |
-
# Get batch data
|
| 825 |
-
mb_advantage = advantages[micro_batch_inds]
|
| 826 |
-
mb_responses = responses[micro_batch_inds]
|
| 827 |
-
mb_query_responses = query_responses[micro_batch_inds]
|
| 828 |
-
mb_logprobs = logprobs[micro_batch_inds]
|
| 829 |
-
|
| 830 |
-
# Forward pass
|
| 831 |
-
output = forward(model, mb_query_responses, processing_class.pad_token_id)
|
| 832 |
-
logits = output.logits[:, context_length - 1 : -1]
|
| 833 |
-
logits /= args.temperature + 1e-7
|
| 834 |
-
|
| 835 |
-
# Compute new logprobs
|
| 836 |
-
new_logprobs = selective_log_softmax(logits, mb_responses)
|
| 837 |
-
new_logprobs = torch.masked_fill(
|
| 838 |
-
new_logprobs, padding_mask[micro_batch_inds], INVALID_LOGPROB
|
| 839 |
-
)
|
| 840 |
-
|
| 841 |
-
# Compute probability ratios
|
| 842 |
-
new_ratio = (new_logprobs - mb_logprobs).exp()
|
| 843 |
-
new_logprobs = new_logprobs.sum(1)
|
| 844 |
-
mb_logprobs = mb_logprobs.sum(1)
|
| 845 |
-
logprobs_diff = new_logprobs - mb_logprobs
|
| 846 |
-
ratio = torch.exp(logprobs_diff)
|
| 847 |
-
|
| 848 |
-
# PPO clipped loss
|
| 849 |
-
pg_losses = -mb_advantage * ratio
|
| 850 |
-
pg_losses2 = -mb_advantage * torch.clamp(ratio, 1.0 - args.cliprange, 1.0 + args.cliprange)
|
| 851 |
-
pg_loss_max = torch.max(pg_losses, pg_losses2)
|
| 852 |
-
pg_loss = pg_loss_max.mean()
|
| 853 |
-
|
| 854 |
-
# Final loss
|
| 855 |
-
loss = pg_loss
|
| 856 |
-
|
| 857 |
-
# Optimization step
|
| 858 |
-
accelerator.backward(loss)
|
| 859 |
-
optimizer.step()
|
| 860 |
-
optimizer.zero_grad()
|
| 861 |
-
|
| 862 |
-
with torch.no_grad():
|
| 863 |
-
pg_clipfrac = (pg_losses2 > pg_losses).float().mean()
|
| 864 |
-
prob_dist = torch.nn.functional.softmax(logits, dim=-1)
|
| 865 |
-
entropy = torch.logsumexp(logits, dim=-1) - torch.sum(prob_dist * logits, dim=-1)
|
| 866 |
-
approxkl = 0.5 * (logprobs_diff**2).mean()
|
| 867 |
-
approxkl_stats[ppo_epoch_idx, minibatch_idx, gradient_accumulation_idx] = approxkl
|
| 868 |
-
pg_clipfrac_stats[ppo_epoch_idx, minibatch_idx, gradient_accumulation_idx] = (
|
| 869 |
-
pg_clipfrac
|
| 870 |
-
)
|
| 871 |
-
pg_loss_stats[ppo_epoch_idx, minibatch_idx, gradient_accumulation_idx] = pg_loss
|
| 872 |
-
entropy_stats[ppo_epoch_idx, minibatch_idx, gradient_accumulation_idx] = entropy.mean()
|
| 873 |
-
ratio_stats[ppo_epoch_idx, minibatch_idx, gradient_accumulation_idx] = new_ratio.mean()
|
| 874 |
-
gradient_accumulation_idx += 1
|
| 875 |
-
minibatch_idx += 1
|
| 876 |
-
|
| 877 |
-
# del everything and empty cache
|
| 878 |
-
# fmt: off
|
| 879 |
-
del (
|
| 880 |
-
output, logits, new_logprobs, logprobs_diff, ratio, pg_losses,
|
| 881 |
-
pg_losses2, pg_loss, loss, pg_clipfrac, prob_dist, entropy, approxkl,
|
| 882 |
-
mb_advantage, mb_responses, mb_query_responses, mb_logprobs,
|
| 883 |
-
)
|
| 884 |
-
# fmt: on
|
| 885 |
-
torch.cuda.empty_cache()
|
| 886 |
-
|
| 887 |
-
# Compute metrics
|
| 888 |
-
with torch.no_grad():
|
| 889 |
-
mean_kl = kl.sum(1).mean()
|
| 890 |
-
mean_entropy = (-logprobs).sum(1).mean()
|
| 891 |
-
mean_non_score_reward = non_score_reward.mean()
|
| 892 |
-
eps = int(self.state.episode / (time.time() - start_time))
|
| 893 |
-
metrics = {}
|
| 894 |
-
metrics["eps"] = eps
|
| 895 |
-
metrics["objective/kl"] = self.accelerator.gather_for_metrics(mean_kl).mean().item()
|
| 896 |
-
metrics["objective/entropy"] = self.accelerator.gather_for_metrics(mean_entropy).mean().item()
|
| 897 |
-
metrics["objective/non_score_reward"] = (
|
| 898 |
-
self.accelerator.gather_for_metrics(mean_non_score_reward).mean().item()
|
| 899 |
-
)
|
| 900 |
-
metrics["objective/rlhf_reward"] = self.accelerator.gather_for_metrics(rlhf_reward).mean().item()
|
| 901 |
-
metrics["objective/scores"] = self.accelerator.gather_for_metrics(scores.mean()).mean().item()
|
| 902 |
-
metrics["policy/approxkl_avg"] = self.accelerator.gather_for_metrics(approxkl_stats).mean().item()
|
| 903 |
-
metrics["policy/clipfrac_avg"] = self.accelerator.gather_for_metrics(pg_clipfrac_stats).mean().item()
|
| 904 |
-
metrics["loss/policy_avg"] = self.accelerator.gather_for_metrics(pg_loss_stats).mean().item()
|
| 905 |
-
metrics["val/clipfrac_avg"] = self.accelerator.gather_for_metrics(vf_clipfrac_stats).mean().item()
|
| 906 |
-
metrics["policy/entropy_avg"] = self.accelerator.gather_for_metrics(entropy_stats).mean().item()
|
| 907 |
-
metrics["val/ratio"] = self.accelerator.gather_for_metrics(ratio_stats).mean().item()
|
| 908 |
-
metrics["val/ratio_var"] = self.accelerator.gather_for_metrics(ratio_stats).var().item()
|
| 909 |
-
metrics["val/num_eos_tokens"] = (responses == processing_class.eos_token_id).sum().item()
|
| 910 |
-
metrics["lr"] = self.lr_scheduler.get_last_lr()[0]
|
| 911 |
-
metrics["episode"] = self.state.episode
|
| 912 |
-
self.state.epoch = self.state.episode / (args.rloo_k * self.train_dataset_len) # used by self.log
|
| 913 |
-
self.log(metrics)
|
| 914 |
-
del kl, mean_kl, mean_entropy, scores
|
| 915 |
-
|
| 916 |
-
self.lr_scheduler.step()
|
| 917 |
-
self.state.global_step += 1
|
| 918 |
-
self.control = self.callback_handler.on_step_end(args, self.state, self.control)
|
| 919 |
-
if self.control.should_save:
|
| 920 |
-
self._save_checkpoint(model, trial=None)
|
| 921 |
-
self.control = self.callback_handler.on_save(self.args, self.state, self.control)
|
| 922 |
-
torch.cuda.empty_cache()
|
| 923 |
-
gc.collect()
|
| 924 |
-
|
| 925 |
-
if args.num_sample_generations > 0 and (update - 1) % self.sample_generations_freq == 0:
|
| 926 |
-
self.generate_completions(sampling=True)
|
| 927 |
-
|
| 928 |
-
# HF trainer specifics
|
| 929 |
-
self.control = self.callback_handler.on_train_end(args, self.state, self.control)
|
| 930 |
-
if self.control.should_save:
|
| 931 |
-
self._save_checkpoint(model, trial=None, metrics=None)
|
| 932 |
-
self.control = self.callback_handler.on_save(self.args, self.state, self.control)
|
| 933 |
-
|
| 934 |
-
def generate_completions(self, sampling: bool = False):
|
| 935 |
-
args = self.args
|
| 936 |
-
processing_class = self.processing_class
|
| 937 |
-
generation_config = GenerationConfig(
|
| 938 |
-
max_new_tokens=self.args.response_length,
|
| 939 |
-
temperature=(0.01 + 1e-7),
|
| 940 |
-
top_k=0.0,
|
| 941 |
-
top_p=1.0,
|
| 942 |
-
do_sample=True,
|
| 943 |
-
)
|
| 944 |
-
|
| 945 |
-
table = defaultdict(list)
|
| 946 |
-
with unwrap_model_for_generation(
|
| 947 |
-
self.model, self.accelerator, gather_deepspeed3_params=self.args.ds3_gather_for_generation
|
| 948 |
-
) as unwrapped_model:
|
| 949 |
-
for batch in self.eval_dataloader:
|
| 950 |
-
query = batch["input_ids"]
|
| 951 |
-
with torch.no_grad():
|
| 952 |
-
context_length = query.shape[1]
|
| 953 |
-
query_response, _ = batch_generation(
|
| 954 |
-
unwrapped_model,
|
| 955 |
-
query,
|
| 956 |
-
query.shape[0],
|
| 957 |
-
processing_class.pad_token_id,
|
| 958 |
-
generation_config,
|
| 959 |
-
)
|
| 960 |
-
response = query_response[:, context_length:]
|
| 961 |
-
postprocessed_response = response
|
| 962 |
-
if args.stop_token_id is not None: # handle the edge case when stop_token_id exists but is 0
|
| 963 |
-
postprocessed_response = truncate_response(
|
| 964 |
-
args.stop_token_id, processing_class.pad_token_id, response
|
| 965 |
-
)
|
| 966 |
-
table["query"].extend(
|
| 967 |
-
gather_object(processing_class.batch_decode(query, skip_special_tokens=True))
|
| 968 |
-
)
|
| 969 |
-
table["model response"].extend(
|
| 970 |
-
gather_object(processing_class.batch_decode(postprocessed_response))
|
| 971 |
-
)
|
| 972 |
-
|
| 973 |
-
postprocessed_query_response = torch.cat((query, postprocessed_response), 1)
|
| 974 |
-
|
| 975 |
-
if isinstance(self.reward_model, nn.Module):
|
| 976 |
-
_, score, _ = get_reward(
|
| 977 |
-
self.reward_model,
|
| 978 |
-
postprocessed_query_response,
|
| 979 |
-
processing_class.pad_token_id,
|
| 980 |
-
context_length,
|
| 981 |
-
)
|
| 982 |
-
else:
|
| 983 |
-
score = torch.tensor(
|
| 984 |
-
self.reward_model(
|
| 985 |
-
processing_class.batch_decode(postprocessed_query_response, skip_special_tokens=True)
|
| 986 |
-
),
|
| 987 |
-
dtype=torch.float,
|
| 988 |
-
).to(postprocessed_query_response.device)
|
| 989 |
-
table["score"].extend(self.accelerator.gather_for_metrics(score).float().cpu().numpy())
|
| 990 |
-
|
| 991 |
-
if sampling:
|
| 992 |
-
break
|
| 993 |
-
df = pd.DataFrame(table)
|
| 994 |
-
|
| 995 |
-
if self.accelerator.is_main_process:
|
| 996 |
-
print_rich_table(df.iloc[0 : 0 + 5])
|
| 997 |
-
if "wandb" in args.report_to:
|
| 998 |
-
import wandb
|
| 999 |
-
|
| 1000 |
-
if wandb.run is not None:
|
| 1001 |
-
wandb.log({"completions": wandb.Table(dataframe=df)})
|
| 1002 |
-
|
| 1003 |
-
if "comet_ml" in args.report_to:
|
| 1004 |
-
log_table_to_comet_experiment(
|
| 1005 |
-
name="completions.csv",
|
| 1006 |
-
table=df,
|
| 1007 |
-
)
|
| 1008 |
-
|
| 1009 |
-
def create_model_card(
|
| 1010 |
-
self,
|
| 1011 |
-
model_name: Optional[str] = None,
|
| 1012 |
-
dataset_name: Optional[str] = None,
|
| 1013 |
-
tags: Union[str, list[str], None] = None,
|
| 1014 |
-
):
|
| 1015 |
-
"""
|
| 1016 |
-
Creates a draft of a model card using the information available to the `Trainer`.
|
| 1017 |
-
|
| 1018 |
-
Args:
|
| 1019 |
-
model_name (`str` or `None`, *optional*, defaults to `None`):
|
| 1020 |
-
Name of the model.
|
| 1021 |
-
dataset_name (`str` or `None`, *optional*, defaults to `None`):
|
| 1022 |
-
Name of the dataset used for training.
|
| 1023 |
-
tags (`str`, `list[str]` or `None`, *optional*, defaults to `None`):
|
| 1024 |
-
Tags to be associated with the model card.
|
| 1025 |
-
"""
|
| 1026 |
-
if not self.is_world_process_zero():
|
| 1027 |
-
return
|
| 1028 |
-
|
| 1029 |
-
if hasattr(self.model.config, "_name_or_path") and not os.path.isdir(self.model.config._name_or_path):
|
| 1030 |
-
base_model = self.model.config._name_or_path
|
| 1031 |
-
else:
|
| 1032 |
-
base_model = None
|
| 1033 |
-
|
| 1034 |
-
tags = tags or []
|
| 1035 |
-
if isinstance(tags, str):
|
| 1036 |
-
tags = [tags]
|
| 1037 |
-
|
| 1038 |
-
if hasattr(self.model.config, "unsloth_version"):
|
| 1039 |
-
tags.append("unsloth")
|
| 1040 |
-
|
| 1041 |
-
citation = textwrap.dedent("""\
|
| 1042 |
-
@inproceedings{ahmadian2024back,
|
| 1043 |
-
title = {{Back to Basics: Revisiting REINFORCE-Style Optimization for Learning from Human Feedback in LLMs}},
|
| 1044 |
-
author = {Arash Ahmadian and Chris Cremer and Matthias Gall{\'{e}} and Marzieh Fadaee and Julia Kreutzer and Olivier Pietquin and Ahmet {\"{U}}st{\"{u}}n and Sara Hooker},
|
| 1045 |
-
year = 2024,
|
| 1046 |
-
booktitle = {Proceedings of the 62nd Annual Meeting of the Association for Computational Linguistics (Volume 1: Long Papers), {ACL} 2024, Bangkok, Thailand, August 11-16, 2024},
|
| 1047 |
-
publisher = {Association for Computational Linguistics},
|
| 1048 |
-
pages = {12248--12267},
|
| 1049 |
-
editor = {Lun{-}Wei Ku and Andre Martins and Vivek Srikumar},
|
| 1050 |
-
}""")
|
| 1051 |
-
|
| 1052 |
-
model_card = generate_model_card(
|
| 1053 |
-
base_model=base_model,
|
| 1054 |
-
model_name=model_name,
|
| 1055 |
-
hub_model_id=self.hub_model_id,
|
| 1056 |
-
dataset_name=dataset_name,
|
| 1057 |
-
tags=tags,
|
| 1058 |
-
wandb_url=wandb.run.get_url() if is_wandb_available() and wandb.run is not None else None,
|
| 1059 |
-
comet_url=get_comet_experiment_url(),
|
| 1060 |
-
trainer_name="RLOO",
|
| 1061 |
-
trainer_citation=citation,
|
| 1062 |
-
paper_title="Back to Basics: Revisiting REINFORCE-Style Optimization for Learning from Human Feedback in LLMs",
|
| 1063 |
-
paper_id="2402.14740",
|
| 1064 |
-
)
|
| 1065 |
-
|
| 1066 |
-
model_card.save(os.path.join(self.args.output_dir, "README.md"))
|
| 1067 |
-
class UnslothRLOOTrainer(_UnslothRLOOTrainer):
|
| 1068 |
-
"""
|
| 1069 |
-
|
| 1070 |
-
"""
|
| 1071 |
-
def __init__(
|
| 1072 |
-
self,
|
| 1073 |
-
config,
|
| 1074 |
-
processing_class,
|
| 1075 |
-
policy,
|
| 1076 |
-
ref_policy,
|
| 1077 |
-
reward_model,
|
| 1078 |
-
train_dataset,
|
| 1079 |
-
data_collator = None,
|
| 1080 |
-
eval_dataset = None,
|
| 1081 |
-
callbacks = None,
|
| 1082 |
-
**kwargs
|
| 1083 |
-
):
|
| 1084 |
-
if args is None: args = UnslothRLOOConfig()
|
| 1085 |
-
_output_logits = False
|
| 1086 |
-
if locals().get('compute_metrics', None) is not None: _output_logits = True
|
| 1087 |
-
if locals().get('preprocess_logits_for_metrics', None) is not None: _output_logits = True
|
| 1088 |
-
if _output_logits:
|
| 1089 |
-
os.environ['UNSLOTH_RETURN_LOGITS'] = '1'
|
| 1090 |
-
if 'max_seq_length' not in locals() and not hasattr(args, 'max_seq_length'):
|
| 1091 |
-
pass
|
| 1092 |
-
else:
|
| 1093 |
-
model_max_seq_length = getattr(model, 'max_seq_length', None)
|
| 1094 |
-
args_max_seq_length = getattr(args, 'max_seq_length', None)
|
| 1095 |
-
if args_max_seq_length is None and model_max_seq_length is not None:
|
| 1096 |
-
max_seq_length = model.max_seq_length
|
| 1097 |
-
if hasattr(args, 'max_seq_length'): args.max_seq_length = max_seq_length
|
| 1098 |
-
if model is not None and hasattr(model, 'for_training'):
|
| 1099 |
-
model.for_training()
|
| 1100 |
-
if 'tokenizer' in locals() and hasattr(tokenizer, 'padding_side'): tokenizer.padding_side = 'right'
|
| 1101 |
-
if 'processing_class' in locals():
|
| 1102 |
-
if hasattr(processing_class, 'padding_side'): processing_class.padding_side = 'right'
|
| 1103 |
-
if hasattr(processing_class, 'tokenizer') and hasattr(processing_class.tokenizer, 'padding_side'): processing_class.tokenizer.padding_side = 'right'
|
| 1104 |
-
__tokenizer = processing_class if 'processing_class' in locals() else tokenizer
|
| 1105 |
-
from unsloth_zoo.vision_utils import UnslothVisionDataCollator
|
| 1106 |
-
if not isinstance(data_collator, UnslothVisionDataCollator):
|
| 1107 |
-
if isinstance(data_collator, DataCollatorForSeq2Seq) and 'labels' not in train_dataset.column_names:
|
| 1108 |
-
data_collator = TransformersDataCollatorForLanguageModeling(__tokenizer, mlm = False, mlm_probability = 0.0)
|
| 1109 |
-
elif isinstance(data_collator, TransformersDataCollatorForLanguageModeling) and 'labels' in train_dataset.column_names:
|
| 1110 |
-
data_collator = DataCollatorForSeq2Seq(__tokenizer)
|
| 1111 |
-
else:
|
| 1112 |
-
if hasattr(args, 'remove_unused_columns'): args.remove_unused_columns = False
|
| 1113 |
-
if hasattr(args, 'dataset_text_field'): args.dataset_text_field = ''
|
| 1114 |
-
if hasattr(args, 'dataset_kwargs'): args.dataset_kwargs = {'skip_prepare_dataset': True}
|
| 1115 |
-
if not isinstance(data_collator, UnslothVisionDataCollator):
|
| 1116 |
-
if not hasattr(__tokenizer, 'pad') and hasattr(__tokenizer, 'tokenizer'):
|
| 1117 |
-
if isinstance(data_collator, DataCollatorForSeq2Seq):
|
| 1118 |
-
data_collator = DataCollatorForSeq2Seq(__tokenizer.tokenizer)
|
| 1119 |
-
else:
|
| 1120 |
-
data_collator = TransformersDataCollatorForLanguageModeling(__tokenizer.tokenizer, mlm = False, mlm_probability = 0.0)
|
| 1121 |
-
other_metrics = []
|
| 1122 |
-
|
| 1123 |
-
from unsloth_zoo.logging_utils import PatchRLStatistics
|
| 1124 |
-
PatchRLStatistics('rloo_trainer', other_metrics)
|
| 1125 |
-
|
| 1126 |
-
super().__init__(
|
| 1127 |
-
config = config,
|
| 1128 |
-
processing_class = processing_class,
|
| 1129 |
-
policy = policy,
|
| 1130 |
-
ref_policy = ref_policy,
|
| 1131 |
-
reward_model = reward_model,
|
| 1132 |
-
train_dataset = train_dataset,
|
| 1133 |
-
data_collator = data_collator,
|
| 1134 |
-
eval_dataset = eval_dataset,
|
| 1135 |
-
callbacks = callbacks,**kwargs)
|
| 1136 |
-
if hasattr(self, 'neftune_hook_handle'):
|
| 1137 |
-
self.neftune_hook_handle.remove()
|
| 1138 |
-
if hasattr(self, 'neftune_hook_handle'): del self.neftune_hook_handle
|
| 1139 |
-
if getattr(args, 'neftune_noise_alpha', None) is not None:
|
| 1140 |
-
model.get_input_embeddings().neftune_noise_alpha = self.neftune_noise_alpha
|
| 1141 |
-
pass
|
| 1142 |
-
|
| 1143 |
-
pass
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
test_run_uploads/UnslothRewardTrainer.py
DELETED
|
@@ -1,828 +0,0 @@
|
|
| 1 |
-
"""
|
| 2 |
-
2025.7.11
|
| 3 |
-
2025.7.11
|
| 4 |
-
4.54.1
|
| 5 |
-
0.16.1
|
| 6 |
-
__UNSLOTH_VERSIONING__
|
| 7 |
-
"""
|
| 8 |
-
from torch import Tensor
|
| 9 |
-
import torch
|
| 10 |
-
import torch.nn as nn
|
| 11 |
-
from torch.nn import functional as F
|
| 12 |
-
from typing import Any, List, Optional, Tuple, Union, Dict, Set, Callable
|
| 13 |
-
from trl.trainer.reward_trainer import (Any, BaseImageProcessor, Callable, DataCollator, Dataset, EvalPrediction, FeatureExtractionMixin, FrozenInstanceError, Optional, PartialState, PeftModel, PreTrainedModel, PreTrainedTokenizerBase, ProcessorMixin, RewardConfig, RewardDataCollatorWithPadding, RewardTrainer, Trainer, TrainerCallback, Union, _tokenize, compute_accuracy, decode_and_strip_padding, defaultdict, disable_dropout_in_model, gather_object, generate_model_card, get_comet_experiment_url, inspect, is_peft_available, is_wandb_available, log_table_to_comet_experiment, maybe_apply_chat_template, nested_detach, nn, os, pd, prepare_model_for_kbit_training, print_rich_table, replace, torch, wandb, warnings)
|
| 14 |
-
|
| 15 |
-
|
| 16 |
-
import os
|
| 17 |
-
from typing import *
|
| 18 |
-
from dataclasses import dataclass, field
|
| 19 |
-
from packaging.version import Version
|
| 20 |
-
import torch
|
| 21 |
-
import numpy as np
|
| 22 |
-
from contextlib import nullcontext
|
| 23 |
-
from torch.nn import functional as F
|
| 24 |
-
from transformers import DataCollatorForSeq2Seq, DataCollatorForLanguageModeling as TransformersDataCollatorForLanguageModeling
|
| 25 |
-
|
| 26 |
-
torch_compile_options = {
|
| 27 |
-
"epilogue_fusion" : True,
|
| 28 |
-
"max_autotune" : False,
|
| 29 |
-
"shape_padding" : True,
|
| 30 |
-
"trace.enabled" : False,
|
| 31 |
-
"triton.cudagraphs" : False,
|
| 32 |
-
}
|
| 33 |
-
|
| 34 |
-
@torch.compile(dynamic = True, fullgraph = True, options = torch_compile_options,)
|
| 35 |
-
def chunked_selective_log_softmax(logits, index):
|
| 36 |
-
# Split into 4 chunks only
|
| 37 |
-
chunked_logits = torch.chunk(logits.reshape(-1, logits.shape[-1]), chunks = 4, dim = 0)
|
| 38 |
-
chunked_index = torch.chunk(index.reshape(-1), chunks = 4, dim = 0)
|
| 39 |
-
all_per_token_logps = []
|
| 40 |
-
# Below loop does the same as selective_log_softmax(chunk_logits, chunk_index)
|
| 41 |
-
for chunk_logits, chunk_index in zip(chunked_logits, chunked_index):
|
| 42 |
-
chunk_logits = chunk_logits.to(torch.float32)
|
| 43 |
-
selected_logits = torch.gather(chunk_logits, dim = -1, index = chunk_index.unsqueeze(-1)).squeeze(-1)
|
| 44 |
-
logsumexp_values = torch.logsumexp(chunk_logits, dim = -1)
|
| 45 |
-
per_token_logps = selected_logits - logsumexp_values
|
| 46 |
-
all_per_token_logps.append(per_token_logps)
|
| 47 |
-
pass
|
| 48 |
-
all_per_token_logps = torch.concat(all_per_token_logps)
|
| 49 |
-
all_per_token_logps = all_per_token_logps.reshape((logits.shape[0], logits.shape[1]))
|
| 50 |
-
return all_per_token_logps
|
| 51 |
-
@dataclass
|
| 52 |
-
class UnslothRewardConfig(RewardConfig):
|
| 53 |
-
"""
|
| 54 |
-
|
| 55 |
-
Configuration class for the [`RewardTrainer`].
|
| 56 |
-
|
| 57 |
-
Using [`~transformers.HfArgumentParser`] we can turn this class into
|
| 58 |
-
[argparse](https://docs.python.org/3/library/argparse#module-argparse) arguments that can be specified on the
|
| 59 |
-
command line.
|
| 60 |
-
|
| 61 |
-
Parameters:
|
| 62 |
-
max_length (`int` or `None`, *optional*, defaults to `1024`):
|
| 63 |
-
Maximum length of the sequences (prompt + completion) in the batch, filters out entries that exceed the
|
| 64 |
-
limit. This argument is required if you want to use the default data collator.
|
| 65 |
-
disable_dropout (`bool`, *optional*, defaults to `True`):
|
| 66 |
-
Whether to disable dropout in the model.
|
| 67 |
-
dataset_num_proc (`int`, *optional*, defaults to `None`):
|
| 68 |
-
Number of processes to use for processing the dataset.
|
| 69 |
-
center_rewards_coefficient (`float`, *optional*, defaults to `None`):
|
| 70 |
-
Coefficient to incentivize the reward model to output mean-zero rewards (proposed by
|
| 71 |
-
https://huggingface.co/papers/2312.09244, Eq. 2). Recommended value: `0.01`.
|
| 72 |
-
remove_unused_columns (`bool`, *optional*, defaults to `False`):
|
| 73 |
-
Whether to remove the columns that are not used by the model's forward pass. Can be `True` only if
|
| 74 |
-
the dataset is pretokenized.
|
| 75 |
-
|
| 76 |
-
"""
|
| 77 |
-
vllm_sampling_params: Optional[Any] = field(
|
| 78 |
-
default = None,
|
| 79 |
-
metadata = {'help': 'vLLM SamplingParams'},
|
| 80 |
-
)
|
| 81 |
-
unsloth_num_chunks : Optional[int] = field(
|
| 82 |
-
default = -1,
|
| 83 |
-
metadata = {'help': 'Chunk size to reduce memory usage. -1 is most efficient.'},
|
| 84 |
-
)
|
| 85 |
-
def __init__(
|
| 86 |
-
self,
|
| 87 |
-
output_dir = None,
|
| 88 |
-
overwrite_output_dir = None,
|
| 89 |
-
do_train = False,
|
| 90 |
-
do_eval = False,
|
| 91 |
-
do_predict = False,
|
| 92 |
-
eval_strategy = 'no',
|
| 93 |
-
prediction_loss_only = False,
|
| 94 |
-
per_device_train_batch_size = 4,
|
| 95 |
-
per_device_eval_batch_size = 4,
|
| 96 |
-
per_gpu_train_batch_size = None,
|
| 97 |
-
per_gpu_eval_batch_size = None,
|
| 98 |
-
gradient_accumulation_steps = 2,
|
| 99 |
-
eval_accumulation_steps = 2,
|
| 100 |
-
eval_delay = 0,
|
| 101 |
-
torch_empty_cache_steps = 250,
|
| 102 |
-
learning_rate = 5e-05,
|
| 103 |
-
weight_decay = 0.01,
|
| 104 |
-
adam_beta1 = 0.9,
|
| 105 |
-
adam_beta2 = 0.999,
|
| 106 |
-
adam_epsilon = 1e-08,
|
| 107 |
-
max_grad_norm = 1.0,
|
| 108 |
-
num_train_epochs = 3.0,
|
| 109 |
-
max_steps = -1,
|
| 110 |
-
lr_scheduler_type = 'linear',
|
| 111 |
-
warmup_ratio = 0.1,
|
| 112 |
-
warmup_steps = 0,
|
| 113 |
-
log_level = 'passive',
|
| 114 |
-
log_level_replica = 'warning',
|
| 115 |
-
log_on_each_node = True,
|
| 116 |
-
logging_dir = None,
|
| 117 |
-
logging_strategy = 'steps',
|
| 118 |
-
logging_first_step = False,
|
| 119 |
-
logging_steps = 1,
|
| 120 |
-
logging_nan_inf_filter = False,
|
| 121 |
-
save_strategy = 'steps',
|
| 122 |
-
save_steps = 500,
|
| 123 |
-
save_total_limit = None,
|
| 124 |
-
save_safetensors = True,
|
| 125 |
-
save_on_each_node = False,
|
| 126 |
-
save_only_model = False,
|
| 127 |
-
restore_callback_states_from_checkpoint = False,
|
| 128 |
-
no_cuda = False,
|
| 129 |
-
use_cpu = False,
|
| 130 |
-
use_mps_device = False,
|
| 131 |
-
seed = 3407,
|
| 132 |
-
data_seed = 3407,
|
| 133 |
-
jit_mode_eval = False,
|
| 134 |
-
use_ipex = False,
|
| 135 |
-
bf16 = False,
|
| 136 |
-
fp16 = False,
|
| 137 |
-
fp16_opt_level = 'O1',
|
| 138 |
-
half_precision_backend = 'auto',
|
| 139 |
-
bf16_full_eval = False,
|
| 140 |
-
fp16_full_eval = False,
|
| 141 |
-
tf32 = None,
|
| 142 |
-
local_rank = -1,
|
| 143 |
-
ddp_backend = None,
|
| 144 |
-
tpu_num_cores = None,
|
| 145 |
-
tpu_metrics_debug = False,
|
| 146 |
-
debug = '',
|
| 147 |
-
dataloader_drop_last = False,
|
| 148 |
-
eval_steps = None,
|
| 149 |
-
dataloader_num_workers = 0,
|
| 150 |
-
dataloader_prefetch_factor = None,
|
| 151 |
-
past_index = -1,
|
| 152 |
-
run_name = None,
|
| 153 |
-
disable_tqdm = None,
|
| 154 |
-
remove_unused_columns = False,
|
| 155 |
-
label_names = None,
|
| 156 |
-
load_best_model_at_end = False,
|
| 157 |
-
metric_for_best_model = None,
|
| 158 |
-
greater_is_better = None,
|
| 159 |
-
ignore_data_skip = False,
|
| 160 |
-
fsdp = '',
|
| 161 |
-
fsdp_min_num_params = 0,
|
| 162 |
-
fsdp_config = None,
|
| 163 |
-
fsdp_transformer_layer_cls_to_wrap = None,
|
| 164 |
-
accelerator_config = None,
|
| 165 |
-
deepspeed = None,
|
| 166 |
-
label_smoothing_factor = 0.0,
|
| 167 |
-
optim = 'adamw_8bit',
|
| 168 |
-
optim_args = None,
|
| 169 |
-
adafactor = False,
|
| 170 |
-
group_by_length = False,
|
| 171 |
-
length_column_name = 'length',
|
| 172 |
-
report_to = None,
|
| 173 |
-
ddp_find_unused_parameters = None,
|
| 174 |
-
ddp_bucket_cap_mb = None,
|
| 175 |
-
ddp_broadcast_buffers = None,
|
| 176 |
-
dataloader_pin_memory = True,
|
| 177 |
-
dataloader_persistent_workers = False,
|
| 178 |
-
skip_memory_metrics = True,
|
| 179 |
-
use_legacy_prediction_loop = False,
|
| 180 |
-
push_to_hub = False,
|
| 181 |
-
resume_from_checkpoint = None,
|
| 182 |
-
hub_model_id = None,
|
| 183 |
-
hub_strategy = 'every_save',
|
| 184 |
-
hub_token = None,
|
| 185 |
-
hub_private_repo = None,
|
| 186 |
-
hub_always_push = False,
|
| 187 |
-
hub_revision = None,
|
| 188 |
-
gradient_checkpointing = False,
|
| 189 |
-
gradient_checkpointing_kwargs = None,
|
| 190 |
-
include_inputs_for_metrics = False,
|
| 191 |
-
eval_do_concat_batches = True,
|
| 192 |
-
fp16_backend = 'auto',
|
| 193 |
-
push_to_hub_model_id = None,
|
| 194 |
-
push_to_hub_organization = None,
|
| 195 |
-
push_to_hub_token = None,
|
| 196 |
-
mp_parameters = '',
|
| 197 |
-
auto_find_batch_size = True,
|
| 198 |
-
full_determinism = False,
|
| 199 |
-
torchdynamo = None,
|
| 200 |
-
ray_scope = 'last',
|
| 201 |
-
ddp_timeout = 1800,
|
| 202 |
-
torch_compile = False,
|
| 203 |
-
torch_compile_backend = None,
|
| 204 |
-
torch_compile_mode = None,
|
| 205 |
-
include_tokens_per_second = False,
|
| 206 |
-
include_num_input_tokens_seen = False,
|
| 207 |
-
neftune_noise_alpha = None,
|
| 208 |
-
optim_target_modules = None,
|
| 209 |
-
batch_eval_metrics = False,
|
| 210 |
-
eval_on_start = False,
|
| 211 |
-
use_liger_kernel = False,
|
| 212 |
-
liger_kernel_config = None,
|
| 213 |
-
eval_use_gather_object = False,
|
| 214 |
-
average_tokens_across_devices = True,
|
| 215 |
-
max_length = 1024,
|
| 216 |
-
disable_dropout = True,
|
| 217 |
-
dataset_num_proc = None,
|
| 218 |
-
center_rewards_coefficient = None,
|
| 219 |
-
vllm_sampling_params = None,
|
| 220 |
-
unsloth_num_chunks = -1,
|
| 221 |
-
**kwargs,
|
| 222 |
-
):
|
| 223 |
-
if learning_rate < 1e-7: raise FloatingPointError(f'Unsloth: Your learning rate of `{learning_rate}` is too small and less than 1e-7! Consider increasing it, otherwise gradient updates will be close to 0!')
|
| 224 |
-
if learning_rate > 1: raise OverflowError(f'Unsloth: Your learning rate of `{learning_rate}` is way too larger > 1! Consider decreasing it to 1e-1, otherwise gradient updates will explode!')
|
| 225 |
-
if output_dir is None and save_strategy == 'steps' and save_steps == 500:
|
| 226 |
-
output_dir = 'unsloth_training_checkpoints'
|
| 227 |
-
save_strategy = 'no'
|
| 228 |
-
if dataset_num_proc is None:
|
| 229 |
-
from multiprocessing import cpu_count
|
| 230 |
-
dataset_num_proc = min(cpu_count()*2, 2)
|
| 231 |
-
|
| 232 |
-
super().__init__(
|
| 233 |
-
output_dir = output_dir,
|
| 234 |
-
overwrite_output_dir = overwrite_output_dir,
|
| 235 |
-
do_train = do_train,
|
| 236 |
-
do_eval = do_eval,
|
| 237 |
-
do_predict = do_predict,
|
| 238 |
-
eval_strategy = eval_strategy,
|
| 239 |
-
prediction_loss_only = prediction_loss_only,
|
| 240 |
-
per_device_train_batch_size = per_device_train_batch_size,
|
| 241 |
-
per_device_eval_batch_size = per_device_eval_batch_size,
|
| 242 |
-
per_gpu_train_batch_size = per_gpu_train_batch_size,
|
| 243 |
-
per_gpu_eval_batch_size = per_gpu_eval_batch_size,
|
| 244 |
-
gradient_accumulation_steps = gradient_accumulation_steps,
|
| 245 |
-
eval_accumulation_steps = eval_accumulation_steps,
|
| 246 |
-
eval_delay = eval_delay,
|
| 247 |
-
torch_empty_cache_steps = torch_empty_cache_steps,
|
| 248 |
-
learning_rate = learning_rate,
|
| 249 |
-
weight_decay = weight_decay,
|
| 250 |
-
adam_beta1 = adam_beta1,
|
| 251 |
-
adam_beta2 = adam_beta2,
|
| 252 |
-
adam_epsilon = adam_epsilon,
|
| 253 |
-
max_grad_norm = max_grad_norm,
|
| 254 |
-
num_train_epochs = num_train_epochs,
|
| 255 |
-
max_steps = max_steps,
|
| 256 |
-
lr_scheduler_type = lr_scheduler_type,
|
| 257 |
-
warmup_ratio = warmup_ratio,
|
| 258 |
-
warmup_steps = warmup_steps,
|
| 259 |
-
log_level = log_level,
|
| 260 |
-
log_level_replica = log_level_replica,
|
| 261 |
-
log_on_each_node = log_on_each_node,
|
| 262 |
-
logging_dir = logging_dir,
|
| 263 |
-
logging_strategy = logging_strategy,
|
| 264 |
-
logging_first_step = logging_first_step,
|
| 265 |
-
logging_steps = logging_steps,
|
| 266 |
-
logging_nan_inf_filter = logging_nan_inf_filter,
|
| 267 |
-
save_strategy = save_strategy,
|
| 268 |
-
save_steps = save_steps,
|
| 269 |
-
save_total_limit = save_total_limit,
|
| 270 |
-
save_safetensors = save_safetensors,
|
| 271 |
-
save_on_each_node = save_on_each_node,
|
| 272 |
-
save_only_model = save_only_model,
|
| 273 |
-
restore_callback_states_from_checkpoint = restore_callback_states_from_checkpoint,
|
| 274 |
-
no_cuda = no_cuda,
|
| 275 |
-
use_cpu = use_cpu,
|
| 276 |
-
use_mps_device = use_mps_device,
|
| 277 |
-
seed = seed,
|
| 278 |
-
data_seed = data_seed,
|
| 279 |
-
jit_mode_eval = jit_mode_eval,
|
| 280 |
-
use_ipex = use_ipex,
|
| 281 |
-
bf16 = bf16,
|
| 282 |
-
fp16 = fp16,
|
| 283 |
-
fp16_opt_level = fp16_opt_level,
|
| 284 |
-
half_precision_backend = half_precision_backend,
|
| 285 |
-
bf16_full_eval = bf16_full_eval,
|
| 286 |
-
fp16_full_eval = fp16_full_eval,
|
| 287 |
-
tf32 = tf32,
|
| 288 |
-
local_rank = local_rank,
|
| 289 |
-
ddp_backend = ddp_backend,
|
| 290 |
-
tpu_num_cores = tpu_num_cores,
|
| 291 |
-
tpu_metrics_debug = tpu_metrics_debug,
|
| 292 |
-
debug = debug,
|
| 293 |
-
dataloader_drop_last = dataloader_drop_last,
|
| 294 |
-
eval_steps = eval_steps,
|
| 295 |
-
dataloader_num_workers = dataloader_num_workers,
|
| 296 |
-
dataloader_prefetch_factor = dataloader_prefetch_factor,
|
| 297 |
-
past_index = past_index,
|
| 298 |
-
run_name = run_name,
|
| 299 |
-
disable_tqdm = disable_tqdm,
|
| 300 |
-
remove_unused_columns = remove_unused_columns,
|
| 301 |
-
label_names = label_names,
|
| 302 |
-
load_best_model_at_end = load_best_model_at_end,
|
| 303 |
-
metric_for_best_model = metric_for_best_model,
|
| 304 |
-
greater_is_better = greater_is_better,
|
| 305 |
-
ignore_data_skip = ignore_data_skip,
|
| 306 |
-
fsdp = fsdp,
|
| 307 |
-
fsdp_min_num_params = fsdp_min_num_params,
|
| 308 |
-
fsdp_config = fsdp_config,
|
| 309 |
-
fsdp_transformer_layer_cls_to_wrap = fsdp_transformer_layer_cls_to_wrap,
|
| 310 |
-
accelerator_config = accelerator_config,
|
| 311 |
-
deepspeed = deepspeed,
|
| 312 |
-
label_smoothing_factor = label_smoothing_factor,
|
| 313 |
-
optim = optim,
|
| 314 |
-
optim_args = optim_args,
|
| 315 |
-
adafactor = adafactor,
|
| 316 |
-
group_by_length = group_by_length,
|
| 317 |
-
length_column_name = length_column_name,
|
| 318 |
-
report_to = report_to,
|
| 319 |
-
ddp_find_unused_parameters = ddp_find_unused_parameters,
|
| 320 |
-
ddp_bucket_cap_mb = ddp_bucket_cap_mb,
|
| 321 |
-
ddp_broadcast_buffers = ddp_broadcast_buffers,
|
| 322 |
-
dataloader_pin_memory = dataloader_pin_memory,
|
| 323 |
-
dataloader_persistent_workers = dataloader_persistent_workers,
|
| 324 |
-
skip_memory_metrics = skip_memory_metrics,
|
| 325 |
-
use_legacy_prediction_loop = use_legacy_prediction_loop,
|
| 326 |
-
push_to_hub = push_to_hub,
|
| 327 |
-
resume_from_checkpoint = resume_from_checkpoint,
|
| 328 |
-
hub_model_id = hub_model_id,
|
| 329 |
-
hub_strategy = hub_strategy,
|
| 330 |
-
hub_token = hub_token,
|
| 331 |
-
hub_private_repo = hub_private_repo,
|
| 332 |
-
hub_always_push = hub_always_push,
|
| 333 |
-
hub_revision = hub_revision,
|
| 334 |
-
gradient_checkpointing = gradient_checkpointing,
|
| 335 |
-
gradient_checkpointing_kwargs = gradient_checkpointing_kwargs,
|
| 336 |
-
include_inputs_for_metrics = include_inputs_for_metrics,
|
| 337 |
-
eval_do_concat_batches = eval_do_concat_batches,
|
| 338 |
-
fp16_backend = fp16_backend,
|
| 339 |
-
push_to_hub_model_id = push_to_hub_model_id,
|
| 340 |
-
push_to_hub_organization = push_to_hub_organization,
|
| 341 |
-
push_to_hub_token = push_to_hub_token,
|
| 342 |
-
mp_parameters = mp_parameters,
|
| 343 |
-
auto_find_batch_size = auto_find_batch_size,
|
| 344 |
-
full_determinism = full_determinism,
|
| 345 |
-
torchdynamo = torchdynamo,
|
| 346 |
-
ray_scope = ray_scope,
|
| 347 |
-
ddp_timeout = ddp_timeout,
|
| 348 |
-
torch_compile = torch_compile,
|
| 349 |
-
torch_compile_backend = torch_compile_backend,
|
| 350 |
-
torch_compile_mode = torch_compile_mode,
|
| 351 |
-
include_tokens_per_second = include_tokens_per_second,
|
| 352 |
-
include_num_input_tokens_seen = include_num_input_tokens_seen,
|
| 353 |
-
neftune_noise_alpha = neftune_noise_alpha,
|
| 354 |
-
optim_target_modules = optim_target_modules,
|
| 355 |
-
batch_eval_metrics = batch_eval_metrics,
|
| 356 |
-
eval_on_start = eval_on_start,
|
| 357 |
-
use_liger_kernel = use_liger_kernel,
|
| 358 |
-
liger_kernel_config = liger_kernel_config,
|
| 359 |
-
eval_use_gather_object = eval_use_gather_object,
|
| 360 |
-
average_tokens_across_devices = average_tokens_across_devices,
|
| 361 |
-
max_length = max_length,
|
| 362 |
-
disable_dropout = disable_dropout,
|
| 363 |
-
dataset_num_proc = dataset_num_proc,
|
| 364 |
-
center_rewards_coefficient = center_rewards_coefficient,**kwargs)
|
| 365 |
-
self.vllm_sampling_params = vllm_sampling_params
|
| 366 |
-
self.unsloth_num_chunks = unsloth_num_chunks
|
| 367 |
-
pass
|
| 368 |
-
|
| 369 |
-
class _UnslothRewardTrainer(Trainer):
|
| 370 |
-
_tag_names = ["trl", "reward-trainer"]
|
| 371 |
-
|
| 372 |
-
def __init__(
|
| 373 |
-
self,
|
| 374 |
-
model: Optional[Union[PreTrainedModel, nn.Module]] = None,
|
| 375 |
-
args: Optional[RewardConfig] = None,
|
| 376 |
-
data_collator: Optional[DataCollator] = None,
|
| 377 |
-
train_dataset: Optional[Dataset] = None,
|
| 378 |
-
eval_dataset: Optional[Union[Dataset, dict[str, Dataset]]] = None,
|
| 379 |
-
processing_class: Optional[
|
| 380 |
-
Union[PreTrainedTokenizerBase, BaseImageProcessor, FeatureExtractionMixin, ProcessorMixin]
|
| 381 |
-
] = None,
|
| 382 |
-
model_init: Optional[Callable[[], PreTrainedModel]] = None,
|
| 383 |
-
compute_metrics: Optional[Callable[[EvalPrediction], dict]] = None,
|
| 384 |
-
callbacks: Optional[list[TrainerCallback]] = None,
|
| 385 |
-
optimizers: tuple[torch.optim.Optimizer, torch.optim.lr_scheduler.LambdaLR] = (
|
| 386 |
-
None,
|
| 387 |
-
None,
|
| 388 |
-
),
|
| 389 |
-
preprocess_logits_for_metrics: Optional[Callable[[torch.Tensor, torch.Tensor], torch.Tensor]] = None,
|
| 390 |
-
peft_config: Optional[dict] = None,
|
| 391 |
-
):
|
| 392 |
-
"""
|
| 393 |
-
Initialize RewardTrainer.
|
| 394 |
-
|
| 395 |
-
Args:
|
| 396 |
-
model (`transformers.PreTrainedModel`):
|
| 397 |
-
The model to train, preferably an `AutoModelForSequenceClassification`.
|
| 398 |
-
args (`RewardConfig`):
|
| 399 |
-
The arguments to use for training.
|
| 400 |
-
data_collator (`transformers.DataCollator`):
|
| 401 |
-
The data collator to use for training. If None is specified, the default data collator (`RewardDataCollatorWithPadding`) will be used
|
| 402 |
-
which will pad the sequences to the maximum length of the sequences in the batch, given a dataset of paired sequences.
|
| 403 |
-
train_dataset (`datasets.Dataset`):
|
| 404 |
-
The dataset to use for training.
|
| 405 |
-
eval_dataset (`datasets.Dataset`):
|
| 406 |
-
The dataset to use for evaluation.
|
| 407 |
-
processing_class (`PreTrainedTokenizerBase` or `BaseImageProcessor` or `FeatureExtractionMixin` or `ProcessorMixin`, *optional*):
|
| 408 |
-
Processing class used to process the data. If provided, will be used to automatically process the inputs
|
| 409 |
-
for the model, and it will be saved along the model to make it easier to rerun an interrupted training or
|
| 410 |
-
reuse the fine-tuned model.
|
| 411 |
-
model_init (`Callable[[], transformers.PreTrainedModel]`):
|
| 412 |
-
The model initializer to use for training. If None is specified, the default model initializer will be used.
|
| 413 |
-
compute_metrics (`Callable[[transformers.EvalPrediction], dict]`, *optional* defaults to `compute_accuracy`):
|
| 414 |
-
The metrics to use for evaluation. If no metrics are specified, the default metric (`compute_accuracy`) will be used.
|
| 415 |
-
callbacks (`list[transformers.TrainerCallback]`):
|
| 416 |
-
The callbacks to use for training.
|
| 417 |
-
optimizers (`tuple[torch.optim.Optimizer, torch.optim.lr_scheduler.LambdaLR]`):
|
| 418 |
-
The optimizer and scheduler to use for training.
|
| 419 |
-
preprocess_logits_for_metrics (`Callable[[torch.Tensor, torch.Tensor], torch.Tensor]`):
|
| 420 |
-
The function to use to preprocess the logits before computing the metrics.
|
| 421 |
-
peft_config (`dict`, defaults to `None`):
|
| 422 |
-
The PEFT configuration to use for training. If you pass a PEFT configuration, the model will be wrapped in a PEFT model.
|
| 423 |
-
"""
|
| 424 |
-
if not is_peft_available() and peft_config is not None:
|
| 425 |
-
raise ValueError(
|
| 426 |
-
"PEFT is not installed and you passed a `peft_config` in the trainer's kwargs, please install it to use the PEFT models"
|
| 427 |
-
)
|
| 428 |
-
elif is_peft_available() and peft_config is not None:
|
| 429 |
-
if not isinstance(model, PeftModel):
|
| 430 |
-
if getattr(model, "is_loaded_in_8bit", False) or getattr(model, "is_quantized", False):
|
| 431 |
-
_supports_gc_kwargs = "gradient_checkpointing_kwargs" in list(
|
| 432 |
-
inspect.signature(prepare_model_for_kbit_training).parameters
|
| 433 |
-
)
|
| 434 |
-
|
| 435 |
-
prepare_model_kwargs = {"use_gradient_checkpointing": args.gradient_checkpointing}
|
| 436 |
-
|
| 437 |
-
if not _supports_gc_kwargs and args.gradient_checkpointing_kwargs is not None:
|
| 438 |
-
warnings.warn(
|
| 439 |
-
"You passed `gradient_checkpointing_kwargs` in the trainer's kwargs, but your peft version does not support it. "
|
| 440 |
-
"please update to the latest version of peft to use `gradient_checkpointing_kwargs`.",
|
| 441 |
-
UserWarning,
|
| 442 |
-
)
|
| 443 |
-
elif _supports_gc_kwargs and args.gradient_checkpointing_kwargs is not None:
|
| 444 |
-
prepare_model_kwargs["gradient_checkpointing_kwargs"] = args.gradient_checkpointing_kwargs
|
| 445 |
-
|
| 446 |
-
model = prepare_model_for_kbit_training(model, **prepare_model_kwargs)
|
| 447 |
-
|
| 448 |
-
model = model
|
| 449 |
-
|
| 450 |
-
# Disable dropout in the model
|
| 451 |
-
if args.disable_dropout:
|
| 452 |
-
disable_dropout_in_model(model)
|
| 453 |
-
|
| 454 |
-
if compute_metrics is None:
|
| 455 |
-
compute_metrics = compute_accuracy
|
| 456 |
-
|
| 457 |
-
if data_collator is None:
|
| 458 |
-
if processing_class is None:
|
| 459 |
-
raise ValueError(
|
| 460 |
-
"A processing_class must be specified when using the default RewardDataCollatorWithPadding"
|
| 461 |
-
)
|
| 462 |
-
|
| 463 |
-
max_length = args.max_length
|
| 464 |
-
|
| 465 |
-
data_collator = RewardDataCollatorWithPadding(processing_class)
|
| 466 |
-
|
| 467 |
-
if args.remove_unused_columns:
|
| 468 |
-
try: # for bc before https://github.com/huggingface/transformers/pull/25435
|
| 469 |
-
args.remove_unused_columns = False
|
| 470 |
-
except FrozenInstanceError:
|
| 471 |
-
args = replace(args, remove_unused_columns=False)
|
| 472 |
-
# warn users
|
| 473 |
-
warnings.warn(
|
| 474 |
-
"When using RewardDataCollatorWithPadding, you should set `remove_unused_columns=False` in your RewardConfig"
|
| 475 |
-
" we have set it for you, but you should do it yourself in the future.",
|
| 476 |
-
UserWarning,
|
| 477 |
-
)
|
| 478 |
-
|
| 479 |
-
self.use_reward_data_collator = True
|
| 480 |
-
else:
|
| 481 |
-
self.use_reward_data_collator = False
|
| 482 |
-
|
| 483 |
-
# The trainer estimates the number of FLOPs [floating-point operations] using the number of elements in the
|
| 484 |
-
# input tensor associated with the key "input_ids". However, in Reward, the sampled data does not include the
|
| 485 |
-
# "input_ids" key. Instead, the available keys are "input_ids_chosen" and "input_ids_rejected". As a result,
|
| 486 |
-
# the trainer issues the warning: "Could not estimate the number of tokens of the input, floating-point
|
| 487 |
-
# operations will not be computed." To suppress this warning, we set the "estimate_tokens" key in the model's
|
| 488 |
-
# "warnings_issued" dictionary to True. This acts as a flag to indicate that the warning has already been
|
| 489 |
-
# issued.
|
| 490 |
-
model.warnings_issued["estimate_tokens"] = True
|
| 491 |
-
|
| 492 |
-
if "input_ids_chosen" not in train_dataset.column_names:
|
| 493 |
-
with PartialState().main_process_first():
|
| 494 |
-
fn_kwargs = {"tokenizer": processing_class}
|
| 495 |
-
train_dataset = train_dataset.map(maybe_apply_chat_template, fn_kwargs={"tokenizer": processing_class})
|
| 496 |
-
train_dataset = train_dataset.map(
|
| 497 |
-
_tokenize,
|
| 498 |
-
batched=True,
|
| 499 |
-
fn_kwargs=fn_kwargs,
|
| 500 |
-
num_proc=args.dataset_num_proc,
|
| 501 |
-
)
|
| 502 |
-
# This filter is important because otherwise you get samples that exceed the model's context length and
|
| 503 |
-
# get truncated => noisy signal the chosen/rejected label gets lost. The downside is that the
|
| 504 |
-
# user might get surprised if N samples are missing from training.
|
| 505 |
-
train_dataset = train_dataset.filter(
|
| 506 |
-
lambda x: len(x["input_ids_chosen"]) <= max_length and len(x["input_ids_rejected"]) <= max_length,
|
| 507 |
-
num_proc=args.dataset_num_proc,
|
| 508 |
-
)
|
| 509 |
-
if eval_dataset is not None:
|
| 510 |
-
eval_dataset = eval_dataset.map(
|
| 511 |
-
maybe_apply_chat_template, fn_kwargs={"tokenizer": processing_class}
|
| 512 |
-
)
|
| 513 |
-
eval_dataset = eval_dataset.map(
|
| 514 |
-
_tokenize,
|
| 515 |
-
fn_kwargs=fn_kwargs,
|
| 516 |
-
batched=True,
|
| 517 |
-
num_proc=args.dataset_num_proc,
|
| 518 |
-
)
|
| 519 |
-
# This filter is important because otherwise you get samples that exceed the model's context length and
|
| 520 |
-
# get truncated => noisy signal the chosen/rejected label gets lost. The downside is that the
|
| 521 |
-
# user might get surprised if N samples are missing from training.
|
| 522 |
-
eval_dataset = eval_dataset.filter(
|
| 523 |
-
lambda x: len(x["input_ids_chosen"]) <= max_length
|
| 524 |
-
and len(x["input_ids_rejected"]) <= max_length,
|
| 525 |
-
num_proc=args.dataset_num_proc,
|
| 526 |
-
)
|
| 527 |
-
|
| 528 |
-
super().__init__(
|
| 529 |
-
model=model,
|
| 530 |
-
args=args,
|
| 531 |
-
data_collator=data_collator,
|
| 532 |
-
train_dataset=train_dataset,
|
| 533 |
-
eval_dataset=eval_dataset,
|
| 534 |
-
processing_class=processing_class,
|
| 535 |
-
model_init=model_init,
|
| 536 |
-
compute_metrics=compute_metrics,
|
| 537 |
-
callbacks=callbacks,
|
| 538 |
-
optimizers=optimizers,
|
| 539 |
-
preprocess_logits_for_metrics=preprocess_logits_for_metrics,
|
| 540 |
-
)
|
| 541 |
-
|
| 542 |
-
# Add tags for models that have been loaded with the correct transformers version
|
| 543 |
-
if hasattr(self.model, "add_model_tags"):
|
| 544 |
-
self.model.add_model_tags(self._tag_names)
|
| 545 |
-
|
| 546 |
-
def compute_loss(
|
| 547 |
-
self,
|
| 548 |
-
model: Union[PreTrainedModel, nn.Module],
|
| 549 |
-
inputs: dict[str, Union[torch.Tensor, Any]],
|
| 550 |
-
return_outputs=False,
|
| 551 |
-
num_items_in_batch=None,
|
| 552 |
-
) -> Union[torch.Tensor, tuple[torch.Tensor, dict[str, torch.Tensor]]]:
|
| 553 |
-
rewards_chosen = model(
|
| 554 |
-
input_ids=inputs["input_ids_chosen"],
|
| 555 |
-
attention_mask=inputs["attention_mask_chosen"],
|
| 556 |
-
return_dict=True,
|
| 557 |
-
)["logits"]
|
| 558 |
-
rewards_rejected = model(
|
| 559 |
-
input_ids=inputs["input_ids_rejected"],
|
| 560 |
-
attention_mask=inputs["attention_mask_rejected"],
|
| 561 |
-
return_dict=True,
|
| 562 |
-
)["logits"]
|
| 563 |
-
# calculate loss, optionally modulate with margin
|
| 564 |
-
if "margin" in inputs:
|
| 565 |
-
loss = -nn.functional.logsigmoid(rewards_chosen - rewards_rejected - inputs["margin"]).mean()
|
| 566 |
-
else:
|
| 567 |
-
loss = -nn.functional.logsigmoid(rewards_chosen - rewards_rejected).mean()
|
| 568 |
-
|
| 569 |
-
if self.args.center_rewards_coefficient is not None:
|
| 570 |
-
loss += self.args.center_rewards_coefficient * torch.mean((rewards_chosen + rewards_rejected) ** 2)
|
| 571 |
-
|
| 572 |
-
if return_outputs:
|
| 573 |
-
return loss, {
|
| 574 |
-
"rewards_chosen": rewards_chosen,
|
| 575 |
-
"rewards_rejected": rewards_rejected,
|
| 576 |
-
}
|
| 577 |
-
return loss
|
| 578 |
-
|
| 579 |
-
def prediction_step(
|
| 580 |
-
self,
|
| 581 |
-
model: Union[PreTrainedModel, nn.Module],
|
| 582 |
-
inputs: dict[str, Union[torch.Tensor, Any]],
|
| 583 |
-
prediction_loss_only: bool,
|
| 584 |
-
ignore_keys: Optional[list[str]] = None,
|
| 585 |
-
) -> tuple[Optional[torch.Tensor], Optional[torch.Tensor], Optional[torch.Tensor]]:
|
| 586 |
-
inputs = self._prepare_inputs(inputs)
|
| 587 |
-
if ignore_keys is None:
|
| 588 |
-
if hasattr(self.model, "config"):
|
| 589 |
-
ignore_keys = getattr(self.model.config, "keys_to_ignore_at_inference", [])
|
| 590 |
-
else:
|
| 591 |
-
ignore_keys = []
|
| 592 |
-
|
| 593 |
-
with torch.no_grad():
|
| 594 |
-
loss, logits_dict = self.compute_loss(model, inputs, return_outputs=True)
|
| 595 |
-
|
| 596 |
-
if prediction_loss_only:
|
| 597 |
-
return (loss, None, None)
|
| 598 |
-
|
| 599 |
-
loss = loss.detach()
|
| 600 |
-
logits = tuple(v for k, v in logits_dict.items() if k not in ignore_keys)
|
| 601 |
-
logits = nested_detach(logits)
|
| 602 |
-
# Stack accepted against rejected, mean over logits
|
| 603 |
-
# and softmax to get preferences between accepted and rejected to sum to 1
|
| 604 |
-
logits = torch.stack(logits).mean(dim=2).softmax(dim=0).T
|
| 605 |
-
|
| 606 |
-
labels = torch.zeros(logits.shape[0])
|
| 607 |
-
labels = self._prepare_inputs(labels)
|
| 608 |
-
|
| 609 |
-
return loss, logits, labels
|
| 610 |
-
|
| 611 |
-
def evaluate(self, *args, **kwargs):
|
| 612 |
-
num_print_samples = kwargs.pop("num_print_samples", 4)
|
| 613 |
-
self.visualize_samples(num_print_samples)
|
| 614 |
-
return super().evaluate(*args, **kwargs)
|
| 615 |
-
|
| 616 |
-
def visualize_samples(self, num_print_samples: int):
|
| 617 |
-
"""
|
| 618 |
-
Visualize the reward model logits prediction
|
| 619 |
-
|
| 620 |
-
Args:
|
| 621 |
-
num_print_samples (`int`, defaults to `4`):
|
| 622 |
-
The number of samples to print. Set to `-1` to print all samples.
|
| 623 |
-
"""
|
| 624 |
-
eval_dataloader = self.get_eval_dataloader()
|
| 625 |
-
table = defaultdict(list)
|
| 626 |
-
for _, inputs in enumerate(eval_dataloader):
|
| 627 |
-
_, logits, _ = self.prediction_step(self.model, inputs, prediction_loss_only=False)
|
| 628 |
-
chosen_text = decode_and_strip_padding(inputs["input_ids_chosen"], self.processing_class)
|
| 629 |
-
rejected_text = decode_and_strip_padding(inputs["input_ids_rejected"], self.processing_class)
|
| 630 |
-
table["chosen_text"].extend(gather_object(chosen_text))
|
| 631 |
-
table["rejected_text"].extend(gather_object(rejected_text))
|
| 632 |
-
table["logits"].extend(
|
| 633 |
-
gather_object([[round(inner_item, 4) for inner_item in item] for item in logits.tolist()])
|
| 634 |
-
)
|
| 635 |
-
if num_print_samples >= 0 and len(table["chosen_text"]) >= num_print_samples:
|
| 636 |
-
break
|
| 637 |
-
df = pd.DataFrame(table)
|
| 638 |
-
if self.accelerator.process_index == 0:
|
| 639 |
-
print_rich_table(df[:num_print_samples])
|
| 640 |
-
if "wandb" in self.args.report_to:
|
| 641 |
-
import wandb
|
| 642 |
-
|
| 643 |
-
if wandb.run is not None:
|
| 644 |
-
wandb.log({"completions": wandb.Table(dataframe=df)})
|
| 645 |
-
|
| 646 |
-
if "comet_ml" in self.args.report_to:
|
| 647 |
-
log_table_to_comet_experiment(
|
| 648 |
-
name="completions.csv",
|
| 649 |
-
table=df,
|
| 650 |
-
)
|
| 651 |
-
|
| 652 |
-
def create_model_card(
|
| 653 |
-
self,
|
| 654 |
-
model_name: Optional[str] = None,
|
| 655 |
-
dataset_name: Optional[str] = None,
|
| 656 |
-
tags: Union[str, list[str], None] = None,
|
| 657 |
-
):
|
| 658 |
-
"""
|
| 659 |
-
Creates a draft of a model card using the information available to the `Trainer`.
|
| 660 |
-
|
| 661 |
-
Args:
|
| 662 |
-
model_name (`str` or `None`, *optional*, defaults to `None`):
|
| 663 |
-
Name of the model.
|
| 664 |
-
dataset_name (`str` or `None`, *optional*, defaults to `None`):
|
| 665 |
-
Name of the dataset used for training.
|
| 666 |
-
tags (`str`, `list[str]` or `None`, *optional*, defaults to `None`):
|
| 667 |
-
Tags to be associated with the model card.
|
| 668 |
-
"""
|
| 669 |
-
if not self.is_world_process_zero():
|
| 670 |
-
return
|
| 671 |
-
|
| 672 |
-
if hasattr(self.model.config, "_name_or_path") and not os.path.isdir(self.model.config._name_or_path):
|
| 673 |
-
base_model = self.model.config._name_or_path
|
| 674 |
-
else:
|
| 675 |
-
base_model = None
|
| 676 |
-
|
| 677 |
-
tags = tags or []
|
| 678 |
-
if isinstance(tags, str):
|
| 679 |
-
tags = [tags]
|
| 680 |
-
|
| 681 |
-
if hasattr(self.model.config, "unsloth_version"):
|
| 682 |
-
tags.append("unsloth")
|
| 683 |
-
|
| 684 |
-
model_card = generate_model_card(
|
| 685 |
-
base_model=base_model,
|
| 686 |
-
model_name=model_name,
|
| 687 |
-
hub_model_id=self.hub_model_id,
|
| 688 |
-
dataset_name=dataset_name,
|
| 689 |
-
tags=tags,
|
| 690 |
-
wandb_url=wandb.run.get_url() if is_wandb_available() and wandb.run is not None else None,
|
| 691 |
-
comet_url=get_comet_experiment_url(),
|
| 692 |
-
trainer_name="Reward",
|
| 693 |
-
)
|
| 694 |
-
|
| 695 |
-
model_card.save(os.path.join(self.args.output_dir, "README.md"))
|
| 696 |
-
class UnslothRewardTrainer(_UnslothRewardTrainer):
|
| 697 |
-
"""
|
| 698 |
-
|
| 699 |
-
"""
|
| 700 |
-
def __init__(
|
| 701 |
-
self,
|
| 702 |
-
model = None,
|
| 703 |
-
args = None,
|
| 704 |
-
data_collator = None,
|
| 705 |
-
train_dataset = None,
|
| 706 |
-
eval_dataset = None,
|
| 707 |
-
processing_class = None,
|
| 708 |
-
model_init = None,
|
| 709 |
-
compute_metrics = None,
|
| 710 |
-
callbacks = None,
|
| 711 |
-
preprocess_logits_for_metrics = None,
|
| 712 |
-
peft_config = None,
|
| 713 |
-
**kwargs
|
| 714 |
-
):
|
| 715 |
-
if args is None: args = UnslothRewardConfig()
|
| 716 |
-
use_bf16 = getattr(args, 'bf16', False)
|
| 717 |
-
if type(use_bf16) is not bool: use_bf16 = False
|
| 718 |
-
use_fp16 = getattr(args, 'fp16', False)
|
| 719 |
-
if type(use_fp16) is not bool: use_fp16 = False
|
| 720 |
-
force_float32 = False
|
| 721 |
-
if os.environ.get('UNSLOTH_FORCE_FLOAT32', '0') == '1':
|
| 722 |
-
print('Unsloth: Switching to float32 training since model cannot work with float16')
|
| 723 |
-
force_float32 = True
|
| 724 |
-
mixed_precision_dtype = os.environ.get('UNSLOTH_MIXED_PRECISION', 'float32')
|
| 725 |
-
dtype = getattr(model.config, 'torch_dtype', None)
|
| 726 |
-
if dtype is None: dtype = model.get_input_embeddings().dtype
|
| 727 |
-
from unsloth_zoo.utils import _get_dtype
|
| 728 |
-
dtype = _get_dtype(dtype)
|
| 729 |
-
float16 = dtype == torch.float16
|
| 730 |
-
if not force_float32 and (float16 and use_bf16): raise TypeError('Unsloth: Model is in float16 precision but you want to use bfloat16 precision. Set fp16 to `True` and bf16 to `False`')
|
| 731 |
-
if not force_float32 and (not float16 and use_fp16): raise TypeError('Unsloth: Model is in bfloat16 precision but you want to use float16 precision. Set fp16 to `False` and bf16 to `True`')
|
| 732 |
-
if force_float32:
|
| 733 |
-
args.fp16 = False
|
| 734 |
-
args.bf16 = False
|
| 735 |
-
os.environ['ACCELERATE_MIXED_PRECISION'] = 'no'
|
| 736 |
-
elif (not use_bf16 and not use_fp16) and mixed_precision_dtype == 'float32':
|
| 737 |
-
args.fp16 = float16
|
| 738 |
-
args.bf16 = not float16
|
| 739 |
-
os.environ['ACCELERATE_MIXED_PRECISION'] = 'fp16' if float16 else 'bf16'
|
| 740 |
-
if getattr(args, 'eval_dataset', None) is not None and getattr(args, 'eval_strategy', 'no') == 'no':
|
| 741 |
-
args.eval_strategy = 'steps'
|
| 742 |
-
if getattr(args, 'eval_steps', None) is None: args.eval_steps = 0.1
|
| 743 |
-
ga_steps = getattr(args, 'gradient_accumulation_steps', None)
|
| 744 |
-
if ga_steps is not None and ga_steps > 1:
|
| 745 |
-
from transformers import __version__ as transformers_version
|
| 746 |
-
if Version(transformers_version) <= Version('4.45.2'):
|
| 747 |
-
print('**** Unsloth: Please use our fixed gradient_accumulation_steps by updating transformers, TRL and Unsloth!\n'
|
| 748 |
-
'`pip install --upgrade --no-cache-dir --force-reinstall --no-deps unsloth transformers trl unsloth_zoo`')
|
| 749 |
-
if getattr(args, 'eval_strategy', 'no') != 'no':
|
| 750 |
-
eval_bsz = getattr(args, 'per_device_eval_batch_size', 8)
|
| 751 |
-
if eval_bsz == 8 and args.per_device_train_batch_size < eval_bsz: args.per_device_eval_batch_size = args.per_device_train_batch_size
|
| 752 |
-
if getattr(args, 'eval_accumulation_steps', None) is None and ga_steps is not None: args.eval_accumulation_steps = ga_steps
|
| 753 |
-
fp16_full_eval = getattr(args, 'fp16_full_eval', False)
|
| 754 |
-
if type(fp16_full_eval) is not bool: fp16_full_eval = False
|
| 755 |
-
bf16_full_eval = getattr(args, 'bf16_full_eval', False)
|
| 756 |
-
if type(bf16_full_eval) is not bool: bf16_full_eval = False
|
| 757 |
-
if args.fp16 and bf16_full_eval: args.bf16_full_eval = False; args.fp16_full_eval = True
|
| 758 |
-
if args.bf16 and fp16_full_eval: args.bf16_full_eval = True; args.fp16_full_eval = False
|
| 759 |
-
if force_float32:
|
| 760 |
-
args.bf16_full_eval = False
|
| 761 |
-
args.fp16_full_eval = False
|
| 762 |
-
elif os.environ.get('UNSLOTH_MIXED_PRECISION', 'float32') == 'bfloat16':
|
| 763 |
-
args.bf16_full_eval = True
|
| 764 |
-
args.fp16_full_eval = False
|
| 765 |
-
elif not bf16_full_eval and not fp16_full_eval:
|
| 766 |
-
args.bf16_full_eval = args.bf16
|
| 767 |
-
args.fp16_full_eval = args.fp16
|
| 768 |
-
_output_logits = False
|
| 769 |
-
if locals().get('compute_metrics', None) is not None: _output_logits = True
|
| 770 |
-
if locals().get('preprocess_logits_for_metrics', None) is not None: _output_logits = True
|
| 771 |
-
if _output_logits:
|
| 772 |
-
os.environ['UNSLOTH_RETURN_LOGITS'] = '1'
|
| 773 |
-
if 'max_seq_length' not in locals() and not hasattr(args, 'max_seq_length'):
|
| 774 |
-
pass
|
| 775 |
-
else:
|
| 776 |
-
model_max_seq_length = getattr(model, 'max_seq_length', None)
|
| 777 |
-
args_max_seq_length = getattr(args, 'max_seq_length', None)
|
| 778 |
-
if args_max_seq_length is None and model_max_seq_length is not None:
|
| 779 |
-
max_seq_length = model.max_seq_length
|
| 780 |
-
if hasattr(args, 'max_seq_length'): args.max_seq_length = max_seq_length
|
| 781 |
-
if model is not None and hasattr(model, 'for_training'):
|
| 782 |
-
model.for_training()
|
| 783 |
-
if 'tokenizer' in locals() and hasattr(tokenizer, 'padding_side'): tokenizer.padding_side = 'right'
|
| 784 |
-
if 'processing_class' in locals():
|
| 785 |
-
if hasattr(processing_class, 'padding_side'): processing_class.padding_side = 'right'
|
| 786 |
-
if hasattr(processing_class, 'tokenizer') and hasattr(processing_class.tokenizer, 'padding_side'): processing_class.tokenizer.padding_side = 'right'
|
| 787 |
-
__tokenizer = processing_class if 'processing_class' in locals() else tokenizer
|
| 788 |
-
from unsloth_zoo.vision_utils import UnslothVisionDataCollator
|
| 789 |
-
if not isinstance(data_collator, UnslothVisionDataCollator):
|
| 790 |
-
if isinstance(data_collator, DataCollatorForSeq2Seq) and 'labels' not in train_dataset.column_names:
|
| 791 |
-
data_collator = TransformersDataCollatorForLanguageModeling(__tokenizer, mlm = False, mlm_probability = 0.0)
|
| 792 |
-
elif isinstance(data_collator, TransformersDataCollatorForLanguageModeling) and 'labels' in train_dataset.column_names:
|
| 793 |
-
data_collator = DataCollatorForSeq2Seq(__tokenizer)
|
| 794 |
-
else:
|
| 795 |
-
if hasattr(args, 'remove_unused_columns'): args.remove_unused_columns = False
|
| 796 |
-
if hasattr(args, 'dataset_text_field'): args.dataset_text_field = ''
|
| 797 |
-
if hasattr(args, 'dataset_kwargs'): args.dataset_kwargs = {'skip_prepare_dataset': True}
|
| 798 |
-
if not isinstance(data_collator, UnslothVisionDataCollator):
|
| 799 |
-
if not hasattr(__tokenizer, 'pad') and hasattr(__tokenizer, 'tokenizer'):
|
| 800 |
-
if isinstance(data_collator, DataCollatorForSeq2Seq):
|
| 801 |
-
data_collator = DataCollatorForSeq2Seq(__tokenizer.tokenizer)
|
| 802 |
-
else:
|
| 803 |
-
data_collator = TransformersDataCollatorForLanguageModeling(__tokenizer.tokenizer, mlm = False, mlm_probability = 0.0)
|
| 804 |
-
other_metrics = []
|
| 805 |
-
|
| 806 |
-
from unsloth_zoo.logging_utils import PatchRLStatistics
|
| 807 |
-
PatchRLStatistics('reward_trainer', other_metrics)
|
| 808 |
-
|
| 809 |
-
super().__init__(
|
| 810 |
-
model = model,
|
| 811 |
-
args = args,
|
| 812 |
-
data_collator = data_collator,
|
| 813 |
-
train_dataset = train_dataset,
|
| 814 |
-
eval_dataset = eval_dataset,
|
| 815 |
-
processing_class = processing_class,
|
| 816 |
-
model_init = model_init,
|
| 817 |
-
compute_metrics = compute_metrics,
|
| 818 |
-
callbacks = callbacks,
|
| 819 |
-
preprocess_logits_for_metrics = preprocess_logits_for_metrics,
|
| 820 |
-
peft_config = peft_config,**kwargs)
|
| 821 |
-
if hasattr(self, 'neftune_hook_handle'):
|
| 822 |
-
self.neftune_hook_handle.remove()
|
| 823 |
-
if hasattr(self, 'neftune_hook_handle'): del self.neftune_hook_handle
|
| 824 |
-
if getattr(args, 'neftune_noise_alpha', None) is not None:
|
| 825 |
-
model.get_input_embeddings().neftune_noise_alpha = self.neftune_noise_alpha
|
| 826 |
-
pass
|
| 827 |
-
|
| 828 |
-
pass
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
test_run_uploads/UnslothSFTTrainer.py
DELETED
|
@@ -1,1102 +0,0 @@
|
|
| 1 |
-
"""
|
| 2 |
-
2025.7.11
|
| 3 |
-
2025.7.11
|
| 4 |
-
4.54.1
|
| 5 |
-
0.16.1
|
| 6 |
-
__UNSLOTH_VERSIONING__
|
| 7 |
-
"""
|
| 8 |
-
from torch import Tensor
|
| 9 |
-
import torch
|
| 10 |
-
import torch.nn as nn
|
| 11 |
-
from torch.nn import functional as F
|
| 12 |
-
from typing import Any, List, Optional, Tuple, Union, Dict, Set, Callable
|
| 13 |
-
from trl.trainer.sft_trainer import (Any, AutoModelForCausalLM, AutoTokenizer, BaseImageProcessor, Callable, ConstantLengthDataset, DataCollator, DataCollatorForLanguageModeling, DataCollatorWithFlattening, Dataset, EvalPrediction, FeatureExtractionMixin, IterableDataset, Optional, PeftConfig, PeftModel, PreTrainedModel, PreTrainedTokenizerBase, ProcessorMixin, SFTConfig, SFTTrainer, Trainer, TrainerCallback, TrainingArguments, Type, Union, dataclass, dataclasses, defaultdict, generate_model_card, get_comet_experiment_url, get_peft_model, is_peft_available, is_wandb_available, nn, os, pad, peft, peft_module_casting_to_bf16, prepare_model_for_kbit_training, torch, transformers, version, wandb, warnings, Callable, ConstantLengthDataset, DataCollator, DataCollatorForLanguageModeling, Dataset, IterableDataset, Optional, Union, os, pad, transformers, os)
|
| 14 |
-
|
| 15 |
-
|
| 16 |
-
import os
|
| 17 |
-
from typing import *
|
| 18 |
-
from dataclasses import dataclass, field
|
| 19 |
-
from packaging.version import Version
|
| 20 |
-
import torch
|
| 21 |
-
import numpy as np
|
| 22 |
-
from contextlib import nullcontext
|
| 23 |
-
from torch.nn import functional as F
|
| 24 |
-
from transformers import DataCollatorForSeq2Seq, DataCollatorForLanguageModeling as TransformersDataCollatorForLanguageModeling
|
| 25 |
-
|
| 26 |
-
torch_compile_options = {
|
| 27 |
-
"epilogue_fusion" : True,
|
| 28 |
-
"max_autotune" : False,
|
| 29 |
-
"shape_padding" : True,
|
| 30 |
-
"trace.enabled" : False,
|
| 31 |
-
"triton.cudagraphs" : False,
|
| 32 |
-
}
|
| 33 |
-
|
| 34 |
-
@torch.compile(dynamic = True, fullgraph = True, options = torch_compile_options,)
|
| 35 |
-
def chunked_selective_log_softmax(logits, index):
|
| 36 |
-
# Split into 4 chunks only
|
| 37 |
-
chunked_logits = torch.chunk(logits.reshape(-1, logits.shape[-1]), chunks = 4, dim = 0)
|
| 38 |
-
chunked_index = torch.chunk(index.reshape(-1), chunks = 4, dim = 0)
|
| 39 |
-
all_per_token_logps = []
|
| 40 |
-
# Below loop does the same as selective_log_softmax(chunk_logits, chunk_index)
|
| 41 |
-
for chunk_logits, chunk_index in zip(chunked_logits, chunked_index):
|
| 42 |
-
chunk_logits = chunk_logits.to(torch.float32)
|
| 43 |
-
selected_logits = torch.gather(chunk_logits, dim = -1, index = chunk_index.unsqueeze(-1)).squeeze(-1)
|
| 44 |
-
logsumexp_values = torch.logsumexp(chunk_logits, dim = -1)
|
| 45 |
-
per_token_logps = selected_logits - logsumexp_values
|
| 46 |
-
all_per_token_logps.append(per_token_logps)
|
| 47 |
-
pass
|
| 48 |
-
all_per_token_logps = torch.concat(all_per_token_logps)
|
| 49 |
-
all_per_token_logps = all_per_token_logps.reshape((logits.shape[0], logits.shape[1]))
|
| 50 |
-
return all_per_token_logps
|
| 51 |
-
@dataclass
|
| 52 |
-
class UnslothSFTConfig(SFTConfig):
|
| 53 |
-
"""
|
| 54 |
-
|
| 55 |
-
Configuration class for the [`SFTTrainer`].
|
| 56 |
-
|
| 57 |
-
Only the parameters specific to SFT training are listed here. For details on other parameters, refer to the
|
| 58 |
-
[`~transformers.TrainingArguments`] documentation.
|
| 59 |
-
|
| 60 |
-
Using [`~transformers.HfArgumentParser`] we can turn this class into
|
| 61 |
-
[argparse](https://docs.python.org/3/library/argparse#module-argparse) arguments that can be specified on the
|
| 62 |
-
command line.
|
| 63 |
-
|
| 64 |
-
Parameters:
|
| 65 |
-
> Parameters that control the model
|
| 66 |
-
|
| 67 |
-
model_init_kwargs (`dict[str, Any]` or `None`, *optional*, defaults to `None`):
|
| 68 |
-
Keyword arguments for [`~transformers.AutoModelForCausalLM.from_pretrained`], used when the `model`
|
| 69 |
-
argument of the [`SFTTrainer`] is provided as a string.
|
| 70 |
-
|
| 71 |
-
> Parameters that control the data preprocessing
|
| 72 |
-
|
| 73 |
-
dataset_text_field (`str`, *optional*, defaults to `"text"`):
|
| 74 |
-
Name of the column that contains text data in the dataset.
|
| 75 |
-
dataset_kwargs (`dict[str, Any]` or `None`, *optional*, defaults to `None`):
|
| 76 |
-
Dictionary of optional keyword arguments for the dataset preparation. The only supported key is
|
| 77 |
-
`skip_prepare_dataset`.
|
| 78 |
-
dataset_num_proc (`int` or `None`, *optional*, defaults to `None`):
|
| 79 |
-
Number of processes to use for processing the dataset.
|
| 80 |
-
pad_token (`str` or `None`, *optional*, defaults to `None`):
|
| 81 |
-
Token used for padding. If `None`, it defaults to `processing_class.pad_token`, or if that is also `None`,
|
| 82 |
-
it falls back to `processing_class.eos_token`.
|
| 83 |
-
max_length (`int` or `None`, *optional*, defaults to `1024`):
|
| 84 |
-
Maximum length of the tokenized sequence. Sequences longer than `max_length` are truncated from the right.
|
| 85 |
-
If `None`, no truncation is applied. When packing is enabled, this value sets the sequence length.
|
| 86 |
-
packing (`bool`, *optional*, defaults to `False`):
|
| 87 |
-
Whether to pack multiple sequences into a fixed-length format. Uses `max_length` to define sequence length.
|
| 88 |
-
padding_free (`bool`, *optional*, defaults to `False`):
|
| 89 |
-
Whether to perform forward passes without padding by flattening all sequences in the batch into a single
|
| 90 |
-
continuous sequence. This reduces memory usage by eliminating padding overhead. Currently, this is only
|
| 91 |
-
supported with the `flash_attention_2` attention implementation, which can efficiently handle the flattened
|
| 92 |
-
batch structure.
|
| 93 |
-
eval_packing (`bool` or `None`, *optional*, defaults to `None`):
|
| 94 |
-
Whether to pack the eval dataset. If `None`, uses the same value as `packing`.
|
| 95 |
-
|
| 96 |
-
> Parameters that control the training
|
| 97 |
-
|
| 98 |
-
learning_rate (`float`, *optional*, defaults to `2e-5`):
|
| 99 |
-
Initial learning rate for [`AdamW`] optimizer. The default value replaces that of
|
| 100 |
-
[`~transformers.TrainingArguments`].
|
| 101 |
-
|
| 102 |
-
"""
|
| 103 |
-
vllm_sampling_params: Optional[Any] = field(
|
| 104 |
-
default = None,
|
| 105 |
-
metadata = {'help': 'vLLM SamplingParams'},
|
| 106 |
-
)
|
| 107 |
-
unsloth_num_chunks : Optional[int] = field(
|
| 108 |
-
default = -1,
|
| 109 |
-
metadata = {'help': 'Chunk size to reduce memory usage. -1 is most efficient.'},
|
| 110 |
-
)
|
| 111 |
-
def __init__(
|
| 112 |
-
self,
|
| 113 |
-
output_dir = None,
|
| 114 |
-
overwrite_output_dir = None,
|
| 115 |
-
do_train = False,
|
| 116 |
-
do_eval = False,
|
| 117 |
-
do_predict = False,
|
| 118 |
-
eval_strategy = 'no',
|
| 119 |
-
prediction_loss_only = False,
|
| 120 |
-
per_device_train_batch_size = 4,
|
| 121 |
-
per_device_eval_batch_size = 4,
|
| 122 |
-
per_gpu_train_batch_size = None,
|
| 123 |
-
per_gpu_eval_batch_size = None,
|
| 124 |
-
gradient_accumulation_steps = 2,
|
| 125 |
-
eval_accumulation_steps = 2,
|
| 126 |
-
eval_delay = 0,
|
| 127 |
-
torch_empty_cache_steps = 250,
|
| 128 |
-
learning_rate = 5e-05,
|
| 129 |
-
weight_decay = 0.01,
|
| 130 |
-
adam_beta1 = 0.9,
|
| 131 |
-
adam_beta2 = 0.999,
|
| 132 |
-
adam_epsilon = 1e-08,
|
| 133 |
-
max_grad_norm = 1.0,
|
| 134 |
-
num_train_epochs = 3.0,
|
| 135 |
-
max_steps = -1,
|
| 136 |
-
lr_scheduler_type = 'linear',
|
| 137 |
-
warmup_ratio = 0.1,
|
| 138 |
-
warmup_steps = 0,
|
| 139 |
-
log_level = 'passive',
|
| 140 |
-
log_level_replica = 'warning',
|
| 141 |
-
log_on_each_node = True,
|
| 142 |
-
logging_dir = None,
|
| 143 |
-
logging_strategy = 'steps',
|
| 144 |
-
logging_first_step = False,
|
| 145 |
-
logging_steps = 1,
|
| 146 |
-
logging_nan_inf_filter = False,
|
| 147 |
-
save_strategy = 'steps',
|
| 148 |
-
save_steps = 500,
|
| 149 |
-
save_total_limit = None,
|
| 150 |
-
save_safetensors = True,
|
| 151 |
-
save_on_each_node = False,
|
| 152 |
-
save_only_model = False,
|
| 153 |
-
restore_callback_states_from_checkpoint = False,
|
| 154 |
-
no_cuda = False,
|
| 155 |
-
use_cpu = False,
|
| 156 |
-
use_mps_device = False,
|
| 157 |
-
seed = 3407,
|
| 158 |
-
data_seed = 3407,
|
| 159 |
-
jit_mode_eval = False,
|
| 160 |
-
use_ipex = False,
|
| 161 |
-
bf16 = False,
|
| 162 |
-
fp16 = False,
|
| 163 |
-
fp16_opt_level = 'O1',
|
| 164 |
-
half_precision_backend = 'auto',
|
| 165 |
-
bf16_full_eval = False,
|
| 166 |
-
fp16_full_eval = False,
|
| 167 |
-
tf32 = None,
|
| 168 |
-
local_rank = -1,
|
| 169 |
-
ddp_backend = None,
|
| 170 |
-
tpu_num_cores = None,
|
| 171 |
-
tpu_metrics_debug = False,
|
| 172 |
-
debug = '',
|
| 173 |
-
dataloader_drop_last = False,
|
| 174 |
-
eval_steps = None,
|
| 175 |
-
dataloader_num_workers = 0,
|
| 176 |
-
dataloader_prefetch_factor = None,
|
| 177 |
-
past_index = -1,
|
| 178 |
-
run_name = None,
|
| 179 |
-
disable_tqdm = None,
|
| 180 |
-
remove_unused_columns = True,
|
| 181 |
-
label_names = None,
|
| 182 |
-
load_best_model_at_end = False,
|
| 183 |
-
metric_for_best_model = None,
|
| 184 |
-
greater_is_better = None,
|
| 185 |
-
ignore_data_skip = False,
|
| 186 |
-
fsdp = '',
|
| 187 |
-
fsdp_min_num_params = 0,
|
| 188 |
-
fsdp_config = None,
|
| 189 |
-
fsdp_transformer_layer_cls_to_wrap = None,
|
| 190 |
-
accelerator_config = None,
|
| 191 |
-
deepspeed = None,
|
| 192 |
-
label_smoothing_factor = 0.0,
|
| 193 |
-
optim = 'adamw_8bit',
|
| 194 |
-
optim_args = None,
|
| 195 |
-
adafactor = False,
|
| 196 |
-
group_by_length = False,
|
| 197 |
-
length_column_name = 'length',
|
| 198 |
-
report_to = None,
|
| 199 |
-
ddp_find_unused_parameters = None,
|
| 200 |
-
ddp_bucket_cap_mb = None,
|
| 201 |
-
ddp_broadcast_buffers = None,
|
| 202 |
-
dataloader_pin_memory = True,
|
| 203 |
-
dataloader_persistent_workers = False,
|
| 204 |
-
skip_memory_metrics = True,
|
| 205 |
-
use_legacy_prediction_loop = False,
|
| 206 |
-
push_to_hub = False,
|
| 207 |
-
resume_from_checkpoint = None,
|
| 208 |
-
hub_model_id = None,
|
| 209 |
-
hub_strategy = 'every_save',
|
| 210 |
-
hub_token = None,
|
| 211 |
-
hub_private_repo = None,
|
| 212 |
-
hub_always_push = False,
|
| 213 |
-
hub_revision = None,
|
| 214 |
-
gradient_checkpointing = False,
|
| 215 |
-
gradient_checkpointing_kwargs = None,
|
| 216 |
-
include_inputs_for_metrics = False,
|
| 217 |
-
eval_do_concat_batches = True,
|
| 218 |
-
fp16_backend = 'auto',
|
| 219 |
-
push_to_hub_model_id = None,
|
| 220 |
-
push_to_hub_organization = None,
|
| 221 |
-
push_to_hub_token = None,
|
| 222 |
-
mp_parameters = '',
|
| 223 |
-
auto_find_batch_size = True,
|
| 224 |
-
full_determinism = False,
|
| 225 |
-
torchdynamo = None,
|
| 226 |
-
ray_scope = 'last',
|
| 227 |
-
ddp_timeout = 1800,
|
| 228 |
-
torch_compile = False,
|
| 229 |
-
torch_compile_backend = None,
|
| 230 |
-
torch_compile_mode = None,
|
| 231 |
-
include_tokens_per_second = False,
|
| 232 |
-
include_num_input_tokens_seen = False,
|
| 233 |
-
neftune_noise_alpha = None,
|
| 234 |
-
optim_target_modules = None,
|
| 235 |
-
batch_eval_metrics = False,
|
| 236 |
-
eval_on_start = False,
|
| 237 |
-
use_liger_kernel = False,
|
| 238 |
-
liger_kernel_config = None,
|
| 239 |
-
eval_use_gather_object = False,
|
| 240 |
-
average_tokens_across_devices = True,
|
| 241 |
-
model_init_kwargs = None,
|
| 242 |
-
dataset_text_field = 'text',
|
| 243 |
-
dataset_kwargs = None,
|
| 244 |
-
dataset_num_proc = None,
|
| 245 |
-
pad_token = None,
|
| 246 |
-
max_length = 1024,
|
| 247 |
-
packing = False,
|
| 248 |
-
padding_free = False,
|
| 249 |
-
eval_packing = None,
|
| 250 |
-
dataset_batch_size = None,
|
| 251 |
-
num_of_sequences = None,
|
| 252 |
-
chars_per_token = None,
|
| 253 |
-
max_seq_length = None,
|
| 254 |
-
use_liger = None,
|
| 255 |
-
vllm_sampling_params = None,
|
| 256 |
-
unsloth_num_chunks = -1,
|
| 257 |
-
**kwargs,
|
| 258 |
-
):
|
| 259 |
-
if learning_rate < 1e-7: raise FloatingPointError(f'Unsloth: Your learning rate of `{learning_rate}` is too small and less than 1e-7! Consider increasing it, otherwise gradient updates will be close to 0!')
|
| 260 |
-
if learning_rate > 1: raise OverflowError(f'Unsloth: Your learning rate of `{learning_rate}` is way too larger > 1! Consider decreasing it to 1e-1, otherwise gradient updates will explode!')
|
| 261 |
-
if output_dir is None and save_strategy == 'steps' and save_steps == 500:
|
| 262 |
-
output_dir = 'unsloth_training_checkpoints'
|
| 263 |
-
save_strategy = 'no'
|
| 264 |
-
if dataset_num_proc is None:
|
| 265 |
-
from multiprocessing import cpu_count
|
| 266 |
-
dataset_num_proc = min(cpu_count()*2, 2)
|
| 267 |
-
|
| 268 |
-
super().__init__(
|
| 269 |
-
output_dir = output_dir,
|
| 270 |
-
overwrite_output_dir = overwrite_output_dir,
|
| 271 |
-
do_train = do_train,
|
| 272 |
-
do_eval = do_eval,
|
| 273 |
-
do_predict = do_predict,
|
| 274 |
-
eval_strategy = eval_strategy,
|
| 275 |
-
prediction_loss_only = prediction_loss_only,
|
| 276 |
-
per_device_train_batch_size = per_device_train_batch_size,
|
| 277 |
-
per_device_eval_batch_size = per_device_eval_batch_size,
|
| 278 |
-
per_gpu_train_batch_size = per_gpu_train_batch_size,
|
| 279 |
-
per_gpu_eval_batch_size = per_gpu_eval_batch_size,
|
| 280 |
-
gradient_accumulation_steps = gradient_accumulation_steps,
|
| 281 |
-
eval_accumulation_steps = eval_accumulation_steps,
|
| 282 |
-
eval_delay = eval_delay,
|
| 283 |
-
torch_empty_cache_steps = torch_empty_cache_steps,
|
| 284 |
-
learning_rate = learning_rate,
|
| 285 |
-
weight_decay = weight_decay,
|
| 286 |
-
adam_beta1 = adam_beta1,
|
| 287 |
-
adam_beta2 = adam_beta2,
|
| 288 |
-
adam_epsilon = adam_epsilon,
|
| 289 |
-
max_grad_norm = max_grad_norm,
|
| 290 |
-
num_train_epochs = num_train_epochs,
|
| 291 |
-
max_steps = max_steps,
|
| 292 |
-
lr_scheduler_type = lr_scheduler_type,
|
| 293 |
-
warmup_ratio = warmup_ratio,
|
| 294 |
-
warmup_steps = warmup_steps,
|
| 295 |
-
log_level = log_level,
|
| 296 |
-
log_level_replica = log_level_replica,
|
| 297 |
-
log_on_each_node = log_on_each_node,
|
| 298 |
-
logging_dir = logging_dir,
|
| 299 |
-
logging_strategy = logging_strategy,
|
| 300 |
-
logging_first_step = logging_first_step,
|
| 301 |
-
logging_steps = logging_steps,
|
| 302 |
-
logging_nan_inf_filter = logging_nan_inf_filter,
|
| 303 |
-
save_strategy = save_strategy,
|
| 304 |
-
save_steps = save_steps,
|
| 305 |
-
save_total_limit = save_total_limit,
|
| 306 |
-
save_safetensors = save_safetensors,
|
| 307 |
-
save_on_each_node = save_on_each_node,
|
| 308 |
-
save_only_model = save_only_model,
|
| 309 |
-
restore_callback_states_from_checkpoint = restore_callback_states_from_checkpoint,
|
| 310 |
-
no_cuda = no_cuda,
|
| 311 |
-
use_cpu = use_cpu,
|
| 312 |
-
use_mps_device = use_mps_device,
|
| 313 |
-
seed = seed,
|
| 314 |
-
data_seed = data_seed,
|
| 315 |
-
jit_mode_eval = jit_mode_eval,
|
| 316 |
-
use_ipex = use_ipex,
|
| 317 |
-
bf16 = bf16,
|
| 318 |
-
fp16 = fp16,
|
| 319 |
-
fp16_opt_level = fp16_opt_level,
|
| 320 |
-
half_precision_backend = half_precision_backend,
|
| 321 |
-
bf16_full_eval = bf16_full_eval,
|
| 322 |
-
fp16_full_eval = fp16_full_eval,
|
| 323 |
-
tf32 = tf32,
|
| 324 |
-
local_rank = local_rank,
|
| 325 |
-
ddp_backend = ddp_backend,
|
| 326 |
-
tpu_num_cores = tpu_num_cores,
|
| 327 |
-
tpu_metrics_debug = tpu_metrics_debug,
|
| 328 |
-
debug = debug,
|
| 329 |
-
dataloader_drop_last = dataloader_drop_last,
|
| 330 |
-
eval_steps = eval_steps,
|
| 331 |
-
dataloader_num_workers = dataloader_num_workers,
|
| 332 |
-
dataloader_prefetch_factor = dataloader_prefetch_factor,
|
| 333 |
-
past_index = past_index,
|
| 334 |
-
run_name = run_name,
|
| 335 |
-
disable_tqdm = disable_tqdm,
|
| 336 |
-
remove_unused_columns = remove_unused_columns,
|
| 337 |
-
label_names = label_names,
|
| 338 |
-
load_best_model_at_end = load_best_model_at_end,
|
| 339 |
-
metric_for_best_model = metric_for_best_model,
|
| 340 |
-
greater_is_better = greater_is_better,
|
| 341 |
-
ignore_data_skip = ignore_data_skip,
|
| 342 |
-
fsdp = fsdp,
|
| 343 |
-
fsdp_min_num_params = fsdp_min_num_params,
|
| 344 |
-
fsdp_config = fsdp_config,
|
| 345 |
-
fsdp_transformer_layer_cls_to_wrap = fsdp_transformer_layer_cls_to_wrap,
|
| 346 |
-
accelerator_config = accelerator_config,
|
| 347 |
-
deepspeed = deepspeed,
|
| 348 |
-
label_smoothing_factor = label_smoothing_factor,
|
| 349 |
-
optim = optim,
|
| 350 |
-
optim_args = optim_args,
|
| 351 |
-
adafactor = adafactor,
|
| 352 |
-
group_by_length = group_by_length,
|
| 353 |
-
length_column_name = length_column_name,
|
| 354 |
-
report_to = report_to,
|
| 355 |
-
ddp_find_unused_parameters = ddp_find_unused_parameters,
|
| 356 |
-
ddp_bucket_cap_mb = ddp_bucket_cap_mb,
|
| 357 |
-
ddp_broadcast_buffers = ddp_broadcast_buffers,
|
| 358 |
-
dataloader_pin_memory = dataloader_pin_memory,
|
| 359 |
-
dataloader_persistent_workers = dataloader_persistent_workers,
|
| 360 |
-
skip_memory_metrics = skip_memory_metrics,
|
| 361 |
-
use_legacy_prediction_loop = use_legacy_prediction_loop,
|
| 362 |
-
push_to_hub = push_to_hub,
|
| 363 |
-
resume_from_checkpoint = resume_from_checkpoint,
|
| 364 |
-
hub_model_id = hub_model_id,
|
| 365 |
-
hub_strategy = hub_strategy,
|
| 366 |
-
hub_token = hub_token,
|
| 367 |
-
hub_private_repo = hub_private_repo,
|
| 368 |
-
hub_always_push = hub_always_push,
|
| 369 |
-
hub_revision = hub_revision,
|
| 370 |
-
gradient_checkpointing = gradient_checkpointing,
|
| 371 |
-
gradient_checkpointing_kwargs = gradient_checkpointing_kwargs,
|
| 372 |
-
include_inputs_for_metrics = include_inputs_for_metrics,
|
| 373 |
-
eval_do_concat_batches = eval_do_concat_batches,
|
| 374 |
-
fp16_backend = fp16_backend,
|
| 375 |
-
push_to_hub_model_id = push_to_hub_model_id,
|
| 376 |
-
push_to_hub_organization = push_to_hub_organization,
|
| 377 |
-
push_to_hub_token = push_to_hub_token,
|
| 378 |
-
mp_parameters = mp_parameters,
|
| 379 |
-
auto_find_batch_size = auto_find_batch_size,
|
| 380 |
-
full_determinism = full_determinism,
|
| 381 |
-
torchdynamo = torchdynamo,
|
| 382 |
-
ray_scope = ray_scope,
|
| 383 |
-
ddp_timeout = ddp_timeout,
|
| 384 |
-
torch_compile = torch_compile,
|
| 385 |
-
torch_compile_backend = torch_compile_backend,
|
| 386 |
-
torch_compile_mode = torch_compile_mode,
|
| 387 |
-
include_tokens_per_second = include_tokens_per_second,
|
| 388 |
-
include_num_input_tokens_seen = include_num_input_tokens_seen,
|
| 389 |
-
neftune_noise_alpha = neftune_noise_alpha,
|
| 390 |
-
optim_target_modules = optim_target_modules,
|
| 391 |
-
batch_eval_metrics = batch_eval_metrics,
|
| 392 |
-
eval_on_start = eval_on_start,
|
| 393 |
-
use_liger_kernel = use_liger_kernel,
|
| 394 |
-
liger_kernel_config = liger_kernel_config,
|
| 395 |
-
eval_use_gather_object = eval_use_gather_object,
|
| 396 |
-
average_tokens_across_devices = average_tokens_across_devices,
|
| 397 |
-
model_init_kwargs = model_init_kwargs,
|
| 398 |
-
dataset_text_field = dataset_text_field,
|
| 399 |
-
dataset_kwargs = dataset_kwargs,
|
| 400 |
-
dataset_num_proc = dataset_num_proc,
|
| 401 |
-
pad_token = pad_token,
|
| 402 |
-
max_length = max_length,
|
| 403 |
-
packing = packing,
|
| 404 |
-
padding_free = padding_free,
|
| 405 |
-
eval_packing = eval_packing,
|
| 406 |
-
dataset_batch_size = dataset_batch_size,
|
| 407 |
-
num_of_sequences = num_of_sequences,
|
| 408 |
-
chars_per_token = chars_per_token,
|
| 409 |
-
max_seq_length = max_seq_length,
|
| 410 |
-
use_liger = use_liger,**kwargs)
|
| 411 |
-
self.vllm_sampling_params = vllm_sampling_params
|
| 412 |
-
self.unsloth_num_chunks = unsloth_num_chunks
|
| 413 |
-
pass
|
| 414 |
-
|
| 415 |
-
class _UnslothSFTTrainer(Trainer):
|
| 416 |
-
""""""
|
| 417 |
-
|
| 418 |
-
_tag_names = ["trl", "sft"]
|
| 419 |
-
|
| 420 |
-
def __init__(
|
| 421 |
-
self,
|
| 422 |
-
model: Union[str, nn.Module, PreTrainedModel],
|
| 423 |
-
args: Optional[Union[SFTConfig, TrainingArguments]] = None,
|
| 424 |
-
data_collator: Optional[DataCollator] = None, # type: ignore
|
| 425 |
-
train_dataset: Optional[Union[Dataset, IterableDataset]] = None,
|
| 426 |
-
eval_dataset: Optional[Union[Dataset, dict[str, Dataset]]] = None,
|
| 427 |
-
processing_class: Optional[
|
| 428 |
-
Union[PreTrainedTokenizerBase, BaseImageProcessor, FeatureExtractionMixin, ProcessorMixin]
|
| 429 |
-
] = None,
|
| 430 |
-
compute_loss_func: Optional[Callable] = None,
|
| 431 |
-
compute_metrics: Optional[Callable[[EvalPrediction], dict]] = None,
|
| 432 |
-
callbacks: Optional[list[TrainerCallback]] = None,
|
| 433 |
-
optimizers: tuple[Optional[torch.optim.Optimizer], Optional[torch.optim.lr_scheduler.LambdaLR]] = (None, None),
|
| 434 |
-
optimizer_cls_and_kwargs: Optional[tuple[Type[torch.optim.Optimizer], dict[str, Any]]] = None,
|
| 435 |
-
preprocess_logits_for_metrics: Optional[Callable[[torch.Tensor, torch.Tensor], torch.Tensor]] = None,
|
| 436 |
-
peft_config: Optional["PeftConfig"] = None,
|
| 437 |
-
formatting_func: Optional[Union[Callable[[dict], str], Callable[[dict], list[str]]]] = None,
|
| 438 |
-
):
|
| 439 |
-
# Args
|
| 440 |
-
model_id = model if isinstance(model, str) else model.config._name_or_path
|
| 441 |
-
if args is None:
|
| 442 |
-
model_name = model_id.split("/")[-1]
|
| 443 |
-
args = SFTConfig(f"{model_name}-SFT")
|
| 444 |
-
elif isinstance(args, TrainingArguments) and not isinstance(args, SFTConfig):
|
| 445 |
-
dict_args = args.to_dict()
|
| 446 |
-
dict_args["hub_token"] = args.hub_token # to_dict hides the hub_token
|
| 447 |
-
dict_args.pop("push_to_hub_token")
|
| 448 |
-
args = SFTConfig(**dict_args)
|
| 449 |
-
|
| 450 |
-
# Handle the tokenizer
|
| 451 |
-
if processing_class is None:
|
| 452 |
-
processing_class = AutoTokenizer.from_pretrained(model_id)
|
| 453 |
-
|
| 454 |
-
# Data collator
|
| 455 |
-
if args.padding_free:
|
| 456 |
-
if data_collator is not None:
|
| 457 |
-
raise ValueError("Passing a custom data collator is not supported when using padding-free.")
|
| 458 |
-
if args.packing:
|
| 459 |
-
warnings.warn(
|
| 460 |
-
"You are passing `packing=True` and `padding_free=True` which is not recommended. Please refer "
|
| 461 |
-
"to the documentation to understand why this is not recommended."
|
| 462 |
-
)
|
| 463 |
-
if model.config._attn_implementation != "flash_attention_2":
|
| 464 |
-
warnings.warn(
|
| 465 |
-
"Padding-free training is enabled, but the attention implementation is not set to "
|
| 466 |
-
"'flash_attention_2'. Padding-free training flattens batches into a single sequence, and "
|
| 467 |
-
"'flash_attention_2' is the only known attention mechanism that reliably supports this. Using "
|
| 468 |
-
"other implementations may lead to unexpected behavior. To ensure compatibility, set "
|
| 469 |
-
"`attn_implementation='flash_attention_2'` in the model configuration, or verify that your "
|
| 470 |
-
"attention mechanism can handle flattened sequences."
|
| 471 |
-
)
|
| 472 |
-
if args.per_device_train_batch_size == 1:
|
| 473 |
-
warnings.warn(
|
| 474 |
-
"You are using a per_device_train_batch_size of 1 with padding-free training. Using a batch size "
|
| 475 |
-
"of 1 anihilate the benefits of padding-free training. Please consider increasing the batch size "
|
| 476 |
-
"to at least 2."
|
| 477 |
-
)
|
| 478 |
-
data_collator = DataCollatorWithFlattening()
|
| 479 |
-
|
| 480 |
-
if data_collator is None:
|
| 481 |
-
# Get the pad token: if not provided, use the one from the processing class or the eos token
|
| 482 |
-
# if the processing class does not have a pad token.
|
| 483 |
-
pad_token = args.pad_token or processing_class.pad_token or processing_class.eos_token
|
| 484 |
-
pad_token_id = processing_class.convert_tokens_to_ids(pad_token)
|
| 485 |
-
if pad_token_id is None:
|
| 486 |
-
raise ValueError(
|
| 487 |
-
f"The specified `pad_token` ('{pad_token}') is not found in the vocabulary of the given "
|
| 488 |
-
f"`processing_class` ({processing_class.__class__.__name__}). Ensure that the `pad_token` exists "
|
| 489 |
-
"in the vocabulary before using it as a padding token."
|
| 490 |
-
)
|
| 491 |
-
data_collator = DataCollatorForLanguageModeling(pad_token_id)
|
| 492 |
-
|
| 493 |
-
# Model
|
| 494 |
-
if args.model_init_kwargs is not None and not isinstance(model, str):
|
| 495 |
-
warnings.warn(
|
| 496 |
-
"You passed model_init_kwargs to the `SFTConfig`, but your model is already instantiated. "
|
| 497 |
-
"The `model_init_kwargs` will be ignored."
|
| 498 |
-
)
|
| 499 |
-
if isinstance(model, str):
|
| 500 |
-
model = self._create_model_from_path(model, args)
|
| 501 |
-
|
| 502 |
-
# PEFT configuration and model wrapping
|
| 503 |
-
if False:
|
| 504 |
-
pass
|
| 505 |
-
|
| 506 |
-
# Dataset
|
| 507 |
-
preprocess_dataset = args.dataset_kwargs is None or not args.dataset_kwargs.get("skip_prepare_dataset", False)
|
| 508 |
-
if preprocess_dataset:
|
| 509 |
-
train_dataset = self._prepare_dataset(
|
| 510 |
-
train_dataset, processing_class, args, args.packing, formatting_func, "train"
|
| 511 |
-
)
|
| 512 |
-
if eval_dataset is not None:
|
| 513 |
-
packing = args.packing if args.eval_packing is None else args.eval_packing
|
| 514 |
-
if isinstance(eval_dataset, dict):
|
| 515 |
-
eval_dataset = {
|
| 516 |
-
key: self._prepare_dataset(dataset, processing_class, args, packing, formatting_func, key)
|
| 517 |
-
for key, dataset in eval_dataset.items()
|
| 518 |
-
}
|
| 519 |
-
else:
|
| 520 |
-
eval_dataset = self._prepare_dataset(
|
| 521 |
-
eval_dataset, processing_class, args, packing, formatting_func, "eval"
|
| 522 |
-
)
|
| 523 |
-
|
| 524 |
-
# Initialize the metrics
|
| 525 |
-
self._metrics = {"train": defaultdict(list), "eval": defaultdict(list)}
|
| 526 |
-
self._total_train_tokens = 0
|
| 527 |
-
|
| 528 |
-
# Initialize the Trainer. Parent class will handle:
|
| 529 |
-
# - DeepSpeed configuration [through create_accelerator_and_postprocess]
|
| 530 |
-
# - FSDP setup
|
| 531 |
-
# - Distributed training setup
|
| 532 |
-
# - Optimizer and scheduler creation
|
| 533 |
-
# Some arguments are only available for transformers>=4.47.0. Can be removed when the min version is bumped.
|
| 534 |
-
super_init_kwargs = {}
|
| 535 |
-
if version.parse(transformers.__version__) >= version.parse("4.47.0.dev0"):
|
| 536 |
-
super_init_kwargs["optimizer_cls_and_kwargs"] = optimizer_cls_and_kwargs
|
| 537 |
-
else:
|
| 538 |
-
if optimizer_cls_and_kwargs is not None:
|
| 539 |
-
warnings.warn(
|
| 540 |
-
"The `optimizer_cls_and_kwargs` argument is only available for `transformers>=4.47.0`. "
|
| 541 |
-
"The default optimizer will be used. "
|
| 542 |
-
"Remove the `optimizer_cls_and_kwargs` or upgrade to `transformers>=4.47.0`."
|
| 543 |
-
)
|
| 544 |
-
super().__init__(
|
| 545 |
-
model=model,
|
| 546 |
-
args=args,
|
| 547 |
-
data_collator=data_collator,
|
| 548 |
-
train_dataset=train_dataset,
|
| 549 |
-
eval_dataset=eval_dataset,
|
| 550 |
-
processing_class=processing_class,
|
| 551 |
-
compute_loss_func=compute_loss_func,
|
| 552 |
-
compute_metrics=compute_metrics,
|
| 553 |
-
callbacks=callbacks,
|
| 554 |
-
optimizers=optimizers,
|
| 555 |
-
preprocess_logits_for_metrics=preprocess_logits_for_metrics,
|
| 556 |
-
**super_init_kwargs,
|
| 557 |
-
)
|
| 558 |
-
|
| 559 |
-
# Add tags for models that have been loaded with the correct transformers version
|
| 560 |
-
if hasattr(self.model, "add_model_tags"):
|
| 561 |
-
self.model.add_model_tags(self._tag_names)
|
| 562 |
-
|
| 563 |
-
def _create_model_from_path(self, model_path: str, args: SFTConfig) -> PreTrainedModel:
|
| 564 |
-
"""Creates a model from a path or model identifier."""
|
| 565 |
-
model_init_kwargs = args.model_init_kwargs or {}
|
| 566 |
-
# Handle torch dtype
|
| 567 |
-
torch_dtype = model_init_kwargs.get("torch_dtype")
|
| 568 |
-
if isinstance(torch_dtype, torch.dtype) or torch_dtype == "auto" or torch_dtype is None:
|
| 569 |
-
pass # torch_dtype is already a torch.dtype or "auto" or None
|
| 570 |
-
elif isinstance(torch_dtype, str): # it's a str, but not "auto"
|
| 571 |
-
torch_dtype = getattr(torch, torch_dtype)
|
| 572 |
-
model_init_kwargs["torch_dtype"] = torch_dtype
|
| 573 |
-
else:
|
| 574 |
-
raise ValueError(
|
| 575 |
-
"Invalid `torch_dtype` passed to `SFTConfig`. Expected either 'auto' or a string representing "
|
| 576 |
-
f"a `torch.dtype` (e.g., 'float32'), but got {torch_dtype}."
|
| 577 |
-
)
|
| 578 |
-
# Disable caching if gradient checkpointing is enabled (not supported)
|
| 579 |
-
if args.gradient_checkpointing:
|
| 580 |
-
model_init_kwargs["use_cache"] = False
|
| 581 |
-
|
| 582 |
-
# Create model
|
| 583 |
-
model = AutoModelForCausalLM.from_pretrained(model_path, **model_init_kwargs)
|
| 584 |
-
return model
|
| 585 |
-
|
| 586 |
-
def _prepare_peft_model(self, model: PreTrainedModel, peft_config: Any, args: SFTConfig) -> PreTrainedModel:
|
| 587 |
-
"""Prepares a model for PEFT training."""
|
| 588 |
-
if not is_peft_available():
|
| 589 |
-
raise ImportError("To use PeftModel, you need to install the `peft` library.")
|
| 590 |
-
|
| 591 |
-
if not isinstance(peft_config, PeftConfig):
|
| 592 |
-
raise ValueError(
|
| 593 |
-
f"Expected PeftConfig object but got {type(peft_config)}. If you want to use the PeftModel, you need "
|
| 594 |
-
"to pass a PeftConfig object to the SFTTrainer."
|
| 595 |
-
)
|
| 596 |
-
|
| 597 |
-
if isinstance(model, PeftModel):
|
| 598 |
-
return model
|
| 599 |
-
|
| 600 |
-
# Handle quantized models (QLoRA)
|
| 601 |
-
is_qlora = getattr(model, "is_loaded_in_4bit", False) or getattr(model, "is_loaded_in_8bit", False)
|
| 602 |
-
|
| 603 |
-
is_sharded_qlora = False
|
| 604 |
-
if getattr(model, "is_loaded_in_4bit", False):
|
| 605 |
-
# Check if model is sharded (FSDP/DS-Zero3)
|
| 606 |
-
for _, param in model.named_parameters():
|
| 607 |
-
if param.__class__.__name__ == "Params4bit":
|
| 608 |
-
is_sharded_qlora = param.data.device.type in {"cpu", "meta"}
|
| 609 |
-
break
|
| 610 |
-
|
| 611 |
-
# Prepare model for kbit training if needed
|
| 612 |
-
if is_qlora and not is_sharded_qlora:
|
| 613 |
-
model = self._prepare_model_for_kbit_training(model, args)
|
| 614 |
-
# Disable gradient checkpointing as it's handled by prepare_model_for_kbit_training
|
| 615 |
-
args = dataclasses.replace(args, gradient_checkpointing=False)
|
| 616 |
-
elif args.gradient_checkpointing:
|
| 617 |
-
model = self._enable_gradient_checkpointing(model, args)
|
| 618 |
-
|
| 619 |
-
# Create PEFT model
|
| 620 |
-
if (
|
| 621 |
-
version.parse(peft.__version__) >= version.parse("0.12") # autocast_adapter_dtype introduced in 0.12
|
| 622 |
-
and getattr(model, "is_loaded_in_4bit", False)
|
| 623 |
-
and is_sharded_qlora
|
| 624 |
-
):
|
| 625 |
-
model = get_peft_model(model, peft_config, autocast_adapter_dtype=False)
|
| 626 |
-
else:
|
| 627 |
-
model = get_peft_model(model, peft_config)
|
| 628 |
-
|
| 629 |
-
# Handle bf16 casting for 4-bit models
|
| 630 |
-
if args.bf16 and getattr(model, "is_loaded_in_4bit", False) and not is_sharded_qlora:
|
| 631 |
-
peft_module_casting_to_bf16(model)
|
| 632 |
-
|
| 633 |
-
return model
|
| 634 |
-
|
| 635 |
-
def _prepare_model_for_kbit_training(self, model: PreTrainedModel, args: SFTConfig) -> PreTrainedModel:
|
| 636 |
-
"""Prepares a quantized model for kbit training."""
|
| 637 |
-
prepare_model_kwargs = {
|
| 638 |
-
"use_gradient_checkpointing": args.gradient_checkpointing,
|
| 639 |
-
"gradient_checkpointing_kwargs": args.gradient_checkpointing_kwargs or {},
|
| 640 |
-
}
|
| 641 |
-
|
| 642 |
-
return prepare_model_for_kbit_training(model, **prepare_model_kwargs)
|
| 643 |
-
|
| 644 |
-
def _enable_gradient_checkpointing(self, model: PreTrainedModel, args: SFTConfig) -> PreTrainedModel:
|
| 645 |
-
"""Enables gradient checkpointing for the model."""
|
| 646 |
-
gradient_checkpointing_kwargs = args.gradient_checkpointing_kwargs or {}
|
| 647 |
-
use_reentrant = (
|
| 648 |
-
"use_reentrant" not in gradient_checkpointing_kwargs or gradient_checkpointing_kwargs["use_reentrant"]
|
| 649 |
-
)
|
| 650 |
-
|
| 651 |
-
if use_reentrant:
|
| 652 |
-
if hasattr(model, "enable_input_require_grads"):
|
| 653 |
-
model.enable_input_require_grads()
|
| 654 |
-
else:
|
| 655 |
-
|
| 656 |
-
def make_inputs_require_grad(module, input, output):
|
| 657 |
-
output.requires_grad_(True)
|
| 658 |
-
|
| 659 |
-
model.get_input_embeddings().register_forward_hook(make_inputs_require_grad)
|
| 660 |
-
|
| 661 |
-
return model
|
| 662 |
-
|
| 663 |
-
def _prepare_dataset(
|
| 664 |
-
self,
|
| 665 |
-
dataset: Union[Dataset, IterableDataset],
|
| 666 |
-
processing_class,
|
| 667 |
-
args,
|
| 668 |
-
packing: bool,
|
| 669 |
-
formatting_func: Optional[Callable[[dict], str]],
|
| 670 |
-
dataset_name: str,
|
| 671 |
-
) -> Union[Dataset, IterableDataset]:
|
| 672 |
-
# All Unsloth Zoo code licensed under LGPLv3
|
| 673 |
-
try:
|
| 674 |
-
if isinstance(dataset, ConstantLengthDataset): return dataset
|
| 675 |
-
except:
|
| 676 |
-
pass
|
| 677 |
-
|
| 678 |
-
map_kwargs = {}
|
| 679 |
-
use_desc = isinstance(dataset, Dataset)
|
| 680 |
-
is_vlm = hasattr(processing_class, "tokenizer")
|
| 681 |
-
tokenizer = processing_class
|
| 682 |
-
if is_vlm: tokenizer = processing_class.tokenizer
|
| 683 |
-
|
| 684 |
-
# Get max length
|
| 685 |
-
max_seq_length = getattr(args, "max_length", 0)
|
| 686 |
-
if max_seq_length == 0: max_seq_length = getattr(args, "max_seq_length", 0)
|
| 687 |
-
if max_seq_length == 0: max_seq_length = getattr(self, "max_seq_length", 0)
|
| 688 |
-
if max_seq_length == 0: max_seq_length = getattr(self, "max_seq", 0)
|
| 689 |
-
if max_seq_length == 0: raise RuntimeError("Unsloth: max_seq_length is 0! Please specify one!")
|
| 690 |
-
dataset_text_field = getattr(args, "dataset_text_field", "text")
|
| 691 |
-
do_truncation = max_seq_length != 0
|
| 692 |
-
do_formatting_func = False
|
| 693 |
-
do_tokenize = True
|
| 694 |
-
|
| 695 |
-
# Get correct column names
|
| 696 |
-
column_names = set(next(iter(dataset)).keys())
|
| 697 |
-
used_column_names = ["input_ids"]
|
| 698 |
-
if "attention_mask" in column_names:
|
| 699 |
-
used_column_names.append("attention_mask")
|
| 700 |
-
|
| 701 |
-
# Check if already tokenized so skip
|
| 702 |
-
from transformers import DataCollatorForSeq2Seq, DataCollatorForLanguageModeling
|
| 703 |
-
if "labels" in column_names:
|
| 704 |
-
# Most likely forgot data collator!
|
| 705 |
-
if is_vlm and not hasattr(tokenizer, "pad"):
|
| 706 |
-
# Check if processing_class has a .pad, if not, use tokenizer.tokenizer
|
| 707 |
-
raise RuntimeError(f"Unsloth: {processing_class.__class__} does not have .pad!")
|
| 708 |
-
self.data_collator = DataCollatorForSeq2Seq(tokenizer)
|
| 709 |
-
used_column_names.append("labels")
|
| 710 |
-
do_tokenize = False
|
| 711 |
-
elif "input_ids" in column_names:
|
| 712 |
-
# Skip dataset prep, and set data collator
|
| 713 |
-
if is_vlm and not hasattr(tokenizer, "pad"):
|
| 714 |
-
# Check if processing_class has a .pad, if not, use tokenizer.tokenizer
|
| 715 |
-
raise RuntimeError(f"Unsloth: {processing_class.__class__} does not have .pad!")
|
| 716 |
-
self.data_collator = DataCollatorForLanguageModeling(tokenizer, mlm = False)
|
| 717 |
-
do_tokenize = False
|
| 718 |
-
elif dataset_text_field not in column_names:
|
| 719 |
-
do_formatting_func = True
|
| 720 |
-
if formatting_func is None:
|
| 721 |
-
raise RuntimeError("Unsloth: You must specify a `formatting_func`")
|
| 722 |
-
pass
|
| 723 |
-
|
| 724 |
-
if do_tokenize:
|
| 725 |
-
# Check double BOS tokens
|
| 726 |
-
if do_formatting_func:
|
| 727 |
-
test_text = formatting_func(next(iter(dataset)))
|
| 728 |
-
if not isinstance(test_text, list):
|
| 729 |
-
raise ValueError(
|
| 730 |
-
"Unsloth: The `formatting_func` should return a list of processed strings."
|
| 731 |
-
)
|
| 732 |
-
test_text = test_text[0]
|
| 733 |
-
else:
|
| 734 |
-
test_text = next(iter(dataset))[dataset_text_field][0]
|
| 735 |
-
|
| 736 |
-
# Get chat template
|
| 737 |
-
chat_template = getattr(processing_class, 'chat_template', '')
|
| 738 |
-
if chat_template == '' and is_vlm:
|
| 739 |
-
chat_template = getattr(tokenizer, 'chat_template', '')
|
| 740 |
-
if chat_template is None:
|
| 741 |
-
chat_template = ''
|
| 742 |
-
|
| 743 |
-
# Get bos_token
|
| 744 |
-
add_special_tokens = True
|
| 745 |
-
bos_token_1 = getattr(processing_class, 'bos_token', None)
|
| 746 |
-
bos_token_2 = getattr(tokenizer, 'bos_token', None)
|
| 747 |
-
bos_token = bos_token_1 or bos_token_2
|
| 748 |
-
|
| 749 |
-
if bos_token is not None:
|
| 750 |
-
if test_text.startswith(bos_token) or bos_token in chat_template:
|
| 751 |
-
add_special_tokens = False
|
| 752 |
-
print("Unsloth: We found double BOS tokens - we shall remove one automatically.")
|
| 753 |
-
pass
|
| 754 |
-
|
| 755 |
-
# Create tokenize function
|
| 756 |
-
def _tokenize(example):
|
| 757 |
-
return tokenizer(
|
| 758 |
-
example[dataset_text_field] if not do_formatting_func else formatting_func(example),
|
| 759 |
-
truncation = do_truncation,
|
| 760 |
-
max_length = max_seq_length,
|
| 761 |
-
return_token_type_ids = False,
|
| 762 |
-
add_special_tokens = add_special_tokens,
|
| 763 |
-
)
|
| 764 |
-
pass
|
| 765 |
-
|
| 766 |
-
if not isinstance(dataset, IterableDataset):
|
| 767 |
-
map_kwargs["num_proc"] = getattr(args, "dataset_num_proc", 2)
|
| 768 |
-
else:
|
| 769 |
-
map_kwargs["batch_size"] = dataset._ex_iterable.batch_size
|
| 770 |
-
|
| 771 |
-
if use_desc: map_kwargs["desc"] = f'Unsloth: Tokenizing ["{dataset_text_field}"]'
|
| 772 |
-
dataset = dataset.map(_tokenize, batched = True, **map_kwargs)
|
| 773 |
-
|
| 774 |
-
# If VLM, switch data collator since .pad is needed!
|
| 775 |
-
if is_vlm and not hasattr(processing_class, "pad"):
|
| 776 |
-
data_collator = DataCollatorForLanguageModeling(tokenizer, mlm = False)
|
| 777 |
-
self.data_collator = data_collator
|
| 778 |
-
pass
|
| 779 |
-
pass
|
| 780 |
-
if packing:
|
| 781 |
-
print("Unsloth: Hugging Face's packing is currently buggy - we're disabling it for now!")
|
| 782 |
-
return dataset
|
| 783 |
-
|
| 784 |
-
if max_seq_length == 0:
|
| 785 |
-
raise ValueError("When packing is enabled, `max_seq_length` can't be `None`.")
|
| 786 |
-
|
| 787 |
-
if use_desc: map_kwargs["desc"] = f"Unsloth: Packing {dataset_name} dataset"
|
| 788 |
-
dataset = dataset.select_columns(used_column_names).map(
|
| 789 |
-
pack_examples,
|
| 790 |
-
batched = True,
|
| 791 |
-
fn_kwargs = {"seq_length": max_seq_length,},
|
| 792 |
-
**map_kwargs,
|
| 793 |
-
)
|
| 794 |
-
pass
|
| 795 |
-
return dataset
|
| 796 |
-
|
| 797 |
-
def compute_loss(self, model, inputs, return_outputs = False, num_items_in_batch = None):
|
| 798 |
-
outputs = super().compute_loss(
|
| 799 |
-
model,
|
| 800 |
-
inputs,
|
| 801 |
-
return_outputs = return_outputs,
|
| 802 |
-
num_items_in_batch = num_items_in_batch,
|
| 803 |
-
)
|
| 804 |
-
return outputs
|
| 805 |
-
|
| 806 |
-
def log(self, logs: dict[str, float], start_time: Optional[float] = None) -> None:
|
| 807 |
-
mode = "eval" if self.control.should_evaluate else "train"
|
| 808 |
-
metrics = {key: sum(val) / len(val) for key, val in self._metrics[mode].items()} # average the metrics
|
| 809 |
-
|
| 810 |
-
# This method can be called both in training and evaluation. When called in evaluation, the keys in `logs`
|
| 811 |
-
# start with "eval_". We need to add the prefix "eval_" to the keys in `metrics` to match the format.
|
| 812 |
-
if mode == "eval":
|
| 813 |
-
metrics = {f"eval_{key}": val for key, val in metrics.items()}
|
| 814 |
-
|
| 815 |
-
logs = {**logs, **metrics}
|
| 816 |
-
if version.parse(transformers.__version__) >= version.parse("4.47.0.dev0"):
|
| 817 |
-
super().log(logs, start_time)
|
| 818 |
-
else: # transformers<=4.46
|
| 819 |
-
super().log(logs)
|
| 820 |
-
self._metrics[mode].clear()
|
| 821 |
-
|
| 822 |
-
def create_model_card(
|
| 823 |
-
self,
|
| 824 |
-
model_name: Optional[str] = None,
|
| 825 |
-
dataset_name: Optional[str] = None,
|
| 826 |
-
tags: Union[str, list[str], None] = None,
|
| 827 |
-
):
|
| 828 |
-
"""
|
| 829 |
-
Creates a draft of a model card using the information available to the `Trainer`.
|
| 830 |
-
|
| 831 |
-
Args:
|
| 832 |
-
model_name (`str` or `None`, *optional*, defaults to `None`):
|
| 833 |
-
Name of the model.
|
| 834 |
-
dataset_name (`str` or `None`, *optional*, defaults to `None`):
|
| 835 |
-
Name of the dataset used for training.
|
| 836 |
-
tags (`str`, `list[str]` or `None`, *optional*, defaults to `None`):
|
| 837 |
-
Tags to be associated with the model card.
|
| 838 |
-
"""
|
| 839 |
-
if not self.is_world_process_zero():
|
| 840 |
-
return
|
| 841 |
-
|
| 842 |
-
if hasattr(self.model.config, "_name_or_path") and not os.path.isdir(self.model.config._name_or_path):
|
| 843 |
-
base_model = self.model.config._name_or_path
|
| 844 |
-
else:
|
| 845 |
-
base_model = None
|
| 846 |
-
|
| 847 |
-
tags = tags or []
|
| 848 |
-
if isinstance(tags, str):
|
| 849 |
-
tags = [tags]
|
| 850 |
-
|
| 851 |
-
if hasattr(self.model.config, "unsloth_version"):
|
| 852 |
-
tags.append("unsloth")
|
| 853 |
-
|
| 854 |
-
model_card = generate_model_card(
|
| 855 |
-
base_model=base_model,
|
| 856 |
-
model_name=model_name,
|
| 857 |
-
hub_model_id=self.hub_model_id,
|
| 858 |
-
dataset_name=dataset_name,
|
| 859 |
-
tags=tags,
|
| 860 |
-
wandb_url=wandb.run.get_url() if is_wandb_available() and wandb.run is not None else None,
|
| 861 |
-
comet_url=get_comet_experiment_url(),
|
| 862 |
-
trainer_name="SFT",
|
| 863 |
-
)
|
| 864 |
-
|
| 865 |
-
model_card.save(os.path.join(self.args.output_dir, "README.md"))
|
| 866 |
-
class UnslothSFTTrainer(_UnslothSFTTrainer):
|
| 867 |
-
"""
|
| 868 |
-
|
| 869 |
-
Trainer for Supervised Fine-Tuning (SFT) method.
|
| 870 |
-
|
| 871 |
-
This class is a wrapper around the [`transformers.Trainer`] class and inherits all of its attributes and methods.
|
| 872 |
-
|
| 873 |
-
Example:
|
| 874 |
-
|
| 875 |
-
```python
|
| 876 |
-
from datasets import load_dataset
|
| 877 |
-
from trl import SFTTrainer
|
| 878 |
-
|
| 879 |
-
dataset = load_dataset("roneneldan/TinyStories", split="train[:1%]")
|
| 880 |
-
|
| 881 |
-
trainer = SFTTrainer(model="Qwen/Qwen2-0.5B-Instruct", train_dataset=dataset)
|
| 882 |
-
trainer.train()
|
| 883 |
-
```
|
| 884 |
-
|
| 885 |
-
Args:
|
| 886 |
-
model (`Union[str, PreTrainedModel]`):
|
| 887 |
-
Model to be trained. Can be either:
|
| 888 |
-
|
| 889 |
-
- A string, being the *model id* of a pretrained model hosted inside a model repo on huggingface.co, or
|
| 890 |
-
a path to a *directory* containing model weights saved using
|
| 891 |
-
[`~transformers.PreTrainedModel.save_pretrained`], e.g., `'./my_model_directory/'`. The model is
|
| 892 |
-
loaded using [`~transformers.AutoModelForCausalLM.from_pretrained`] with the keywork arguments
|
| 893 |
-
in `args.model_init_kwargs`.
|
| 894 |
-
- A [`~transformers.PreTrainedModel`] object. Only causal language models are supported.
|
| 895 |
-
args ([`SFTConfig`], *optional*, defaults to `None`):
|
| 896 |
-
Configuration for this trainer. If `None`, a default configuration is used.
|
| 897 |
-
data_collator (`DataCollator`, *optional*):
|
| 898 |
-
Function to use to form a batch from a list of elements of the prcessed `train_dataset` or `eval_dataset`.
|
| 899 |
-
Will default to [`~transformers.default_data_collator`] if no `processing_class` is provided, an instance
|
| 900 |
-
of [`~transformers.DataCollatorWithPadding`] otherwise if the processing_class is a feature extractor or
|
| 901 |
-
tokenizer.
|
| 902 |
-
train_dataset ([`~datasets.Dataset`] or [`~datasets.IterableDataset`]):
|
| 903 |
-
Dataset to use for training. SFT supports both [language modeling](#language-modeling) type and
|
| 904 |
-
[prompt-completion](#prompt-completion) type. The format of the samples can be either:
|
| 905 |
-
|
| 906 |
-
- [Standard](dataset_formats#standard): Each sample contains plain text.
|
| 907 |
-
- [Conversational](dataset_formats#conversational): Each sample contains structured messages (e.g., role
|
| 908 |
-
and content).
|
| 909 |
-
|
| 910 |
-
The trainer also supports processed datasets (tokenized) as long as they contain an `input_ids` field.
|
| 911 |
-
eval_dataset ([`~datasets.Dataset`], [`~datasets.IterableDataset`] or `dict[str, Union[Dataset, IterableDataset]]`):
|
| 912 |
-
Dataset to use for evaluation. It must meet the same requirements as `train_dataset`.
|
| 913 |
-
processing_class ([`~transformers.PreTrainedTokenizerBase`], *optional*, defaults to `None`):
|
| 914 |
-
Processing class used to process the data. If `None`, the processing class is loaded from the model's name
|
| 915 |
-
with [`~transformers.AutoTokenizer.from_pretrained`].
|
| 916 |
-
callbacks (list of [`~transformers.TrainerCallback`], *optional*, defaults to `None`):
|
| 917 |
-
List of callbacks to customize the training loop. Will add those to the list of default callbacks
|
| 918 |
-
detailed in [here](https://huggingface.co/docs/transformers/main_classes/callback).
|
| 919 |
-
|
| 920 |
-
If you want to remove one of the default callbacks used, use the [`~transformers.Trainer.remove_callback`]
|
| 921 |
-
method.
|
| 922 |
-
optimizers (`tuple[torch.optim.Optimizer, torch.optim.lr_scheduler.LambdaLR]`, *optional*, defaults to `(None, None)`):
|
| 923 |
-
A tuple containing the optimizer and the scheduler to use. Will default to an instance of [`AdamW`] on your
|
| 924 |
-
model and a scheduler given by [`get_linear_schedule_with_warmup`] controlled by `args`.
|
| 925 |
-
optimizer_cls_and_kwargs (`Tuple[Type[torch.optim.Optimizer], Dict[str, Any]]`, *optional*, defaults to `None`):
|
| 926 |
-
A tuple containing the optimizer class and keyword arguments to use.
|
| 927 |
-
Overrides `optim` and `optim_args` in `args`. Incompatible with the `optimizers` argument.
|
| 928 |
-
|
| 929 |
-
Unlike `optimizers`, this argument avoids the need to place model parameters on the correct devices before initializing the Trainer.
|
| 930 |
-
preprocess_logits_for_metrics (`Callable[[torch.Tensor, torch.Tensor], torch.Tensor]`, *optional*, defaults to `None`):
|
| 931 |
-
A function that preprocess the logits right before caching them at each evaluation step. Must take two
|
| 932 |
-
tensors, the logits and the labels, and return the logits once processed as desired. The modifications made
|
| 933 |
-
by this function will be reflected in the predictions received by `compute_metrics`.
|
| 934 |
-
|
| 935 |
-
Note that the labels (second parameter) will be `None` if the dataset does not have them.
|
| 936 |
-
peft_config ([`~peft.PeftConfig`], *optional*, defaults to `None`):
|
| 937 |
-
PEFT configuration used to wrap the model. If `None`, the model is not wrapped.
|
| 938 |
-
formatting_func (`Optional[Callable]`):
|
| 939 |
-
Formatting function applied to the dataset before tokenization.
|
| 940 |
-
|
| 941 |
-
"""
|
| 942 |
-
def __init__(
|
| 943 |
-
self,
|
| 944 |
-
model,
|
| 945 |
-
args = None,
|
| 946 |
-
data_collator = None,
|
| 947 |
-
train_dataset = None,
|
| 948 |
-
eval_dataset = None,
|
| 949 |
-
processing_class = None,
|
| 950 |
-
compute_loss_func = None,
|
| 951 |
-
compute_metrics = None,
|
| 952 |
-
callbacks = None,
|
| 953 |
-
optimizer_cls_and_kwargs = None,
|
| 954 |
-
preprocess_logits_for_metrics = None,
|
| 955 |
-
peft_config = None,
|
| 956 |
-
formatting_func = None,
|
| 957 |
-
**kwargs
|
| 958 |
-
):
|
| 959 |
-
if args is None: args = UnslothSFTConfig()
|
| 960 |
-
use_bf16 = getattr(args, 'bf16', False)
|
| 961 |
-
if type(use_bf16) is not bool: use_bf16 = False
|
| 962 |
-
use_fp16 = getattr(args, 'fp16', False)
|
| 963 |
-
if type(use_fp16) is not bool: use_fp16 = False
|
| 964 |
-
force_float32 = False
|
| 965 |
-
if os.environ.get('UNSLOTH_FORCE_FLOAT32', '0') == '1':
|
| 966 |
-
print('Unsloth: Switching to float32 training since model cannot work with float16')
|
| 967 |
-
force_float32 = True
|
| 968 |
-
mixed_precision_dtype = os.environ.get('UNSLOTH_MIXED_PRECISION', 'float32')
|
| 969 |
-
dtype = getattr(model.config, 'torch_dtype', None)
|
| 970 |
-
if dtype is None: dtype = model.get_input_embeddings().dtype
|
| 971 |
-
from unsloth_zoo.utils import _get_dtype
|
| 972 |
-
dtype = _get_dtype(dtype)
|
| 973 |
-
float16 = dtype == torch.float16
|
| 974 |
-
if not force_float32 and (float16 and use_bf16): raise TypeError('Unsloth: Model is in float16 precision but you want to use bfloat16 precision. Set fp16 to `True` and bf16 to `False`')
|
| 975 |
-
if not force_float32 and (not float16 and use_fp16): raise TypeError('Unsloth: Model is in bfloat16 precision but you want to use float16 precision. Set fp16 to `False` and bf16 to `True`')
|
| 976 |
-
if force_float32:
|
| 977 |
-
args.fp16 = False
|
| 978 |
-
args.bf16 = False
|
| 979 |
-
os.environ['ACCELERATE_MIXED_PRECISION'] = 'no'
|
| 980 |
-
elif (not use_bf16 and not use_fp16) and mixed_precision_dtype == 'float32':
|
| 981 |
-
args.fp16 = float16
|
| 982 |
-
args.bf16 = not float16
|
| 983 |
-
os.environ['ACCELERATE_MIXED_PRECISION'] = 'fp16' if float16 else 'bf16'
|
| 984 |
-
if getattr(args, 'eval_dataset', None) is not None and getattr(args, 'eval_strategy', 'no') == 'no':
|
| 985 |
-
args.eval_strategy = 'steps'
|
| 986 |
-
if getattr(args, 'eval_steps', None) is None: args.eval_steps = 0.1
|
| 987 |
-
ga_steps = getattr(args, 'gradient_accumulation_steps', None)
|
| 988 |
-
if ga_steps is not None and ga_steps > 1:
|
| 989 |
-
from transformers import __version__ as transformers_version
|
| 990 |
-
if Version(transformers_version) <= Version('4.45.2'):
|
| 991 |
-
print('**** Unsloth: Please use our fixed gradient_accumulation_steps by updating transformers, TRL and Unsloth!\n'
|
| 992 |
-
'`pip install --upgrade --no-cache-dir --force-reinstall --no-deps unsloth transformers trl unsloth_zoo`')
|
| 993 |
-
if getattr(args, 'eval_strategy', 'no') != 'no':
|
| 994 |
-
eval_bsz = getattr(args, 'per_device_eval_batch_size', 8)
|
| 995 |
-
if eval_bsz == 8 and args.per_device_train_batch_size < eval_bsz: args.per_device_eval_batch_size = args.per_device_train_batch_size
|
| 996 |
-
if getattr(args, 'eval_accumulation_steps', None) is None and ga_steps is not None: args.eval_accumulation_steps = ga_steps
|
| 997 |
-
fp16_full_eval = getattr(args, 'fp16_full_eval', False)
|
| 998 |
-
if type(fp16_full_eval) is not bool: fp16_full_eval = False
|
| 999 |
-
bf16_full_eval = getattr(args, 'bf16_full_eval', False)
|
| 1000 |
-
if type(bf16_full_eval) is not bool: bf16_full_eval = False
|
| 1001 |
-
if args.fp16 and bf16_full_eval: args.bf16_full_eval = False; args.fp16_full_eval = True
|
| 1002 |
-
if args.bf16 and fp16_full_eval: args.bf16_full_eval = True; args.fp16_full_eval = False
|
| 1003 |
-
if force_float32:
|
| 1004 |
-
args.bf16_full_eval = False
|
| 1005 |
-
args.fp16_full_eval = False
|
| 1006 |
-
elif os.environ.get('UNSLOTH_MIXED_PRECISION', 'float32') == 'bfloat16':
|
| 1007 |
-
args.bf16_full_eval = True
|
| 1008 |
-
args.fp16_full_eval = False
|
| 1009 |
-
elif not bf16_full_eval and not fp16_full_eval:
|
| 1010 |
-
args.bf16_full_eval = args.bf16
|
| 1011 |
-
args.fp16_full_eval = args.fp16
|
| 1012 |
-
_output_logits = False
|
| 1013 |
-
if locals().get('compute_metrics', None) is not None: _output_logits = True
|
| 1014 |
-
if locals().get('preprocess_logits_for_metrics', None) is not None: _output_logits = True
|
| 1015 |
-
if _output_logits:
|
| 1016 |
-
os.environ['UNSLOTH_RETURN_LOGITS'] = '1'
|
| 1017 |
-
if 'max_seq_length' not in locals() and not hasattr(args, 'max_seq_length'):
|
| 1018 |
-
pass
|
| 1019 |
-
else:
|
| 1020 |
-
model_max_seq_length = getattr(model, 'max_seq_length', None)
|
| 1021 |
-
args_max_seq_length = getattr(args, 'max_seq_length', None)
|
| 1022 |
-
if args_max_seq_length is None and model_max_seq_length is not None:
|
| 1023 |
-
max_seq_length = model.max_seq_length
|
| 1024 |
-
if hasattr(args, 'max_seq_length'): args.max_seq_length = max_seq_length
|
| 1025 |
-
if 'max_length' not in locals() and not hasattr(args, 'max_length'):
|
| 1026 |
-
pass
|
| 1027 |
-
else:
|
| 1028 |
-
if hasattr(args, 'max_seq_length') and args.max_seq_length is not None and args.max_seq_length > 0:
|
| 1029 |
-
if hasattr(args, 'max_length'):
|
| 1030 |
-
args.max_length = args.max_seq_length
|
| 1031 |
-
max_length = args.max_length
|
| 1032 |
-
else:
|
| 1033 |
-
model_max_length = getattr(model, 'max_seq_length', None)
|
| 1034 |
-
# print(model_max_length, 'mml1')
|
| 1035 |
-
if model_max_length is None: model_max_length = getattr(model, 'max_length', None)
|
| 1036 |
-
# print(model_max_length, 'mml2')
|
| 1037 |
-
if model_max_length is not None:
|
| 1038 |
-
args.max_length = model_max_length
|
| 1039 |
-
max_length = args.max_length
|
| 1040 |
-
elif hasattr(args, 'max_length') and args.max_length is not None:
|
| 1041 |
-
max_length = args.max_length
|
| 1042 |
-
# if we are here, then we are in a weird case where max_length is set but max_seq_length is not set
|
| 1043 |
-
setattr(model, 'max_seq_length', max_length)
|
| 1044 |
-
else:
|
| 1045 |
-
print('Unsloth: We did not find `max_seq_length` or `max_length` in the model or args. We will set it to 1024.')
|
| 1046 |
-
args.max_length = 1024
|
| 1047 |
-
if model is not None and hasattr(model, 'for_training'):
|
| 1048 |
-
model.for_training()
|
| 1049 |
-
if 'tokenizer' in locals() and hasattr(tokenizer, 'padding_side'): tokenizer.padding_side = 'right'
|
| 1050 |
-
if 'processing_class' in locals():
|
| 1051 |
-
if hasattr(processing_class, 'padding_side'): processing_class.padding_side = 'right'
|
| 1052 |
-
if hasattr(processing_class, 'tokenizer') and hasattr(processing_class.tokenizer, 'padding_side'): processing_class.tokenizer.padding_side = 'right'
|
| 1053 |
-
__tokenizer = processing_class if 'processing_class' in locals() else tokenizer
|
| 1054 |
-
from unsloth_zoo.vision_utils import UnslothVisionDataCollator
|
| 1055 |
-
if not isinstance(data_collator, UnslothVisionDataCollator):
|
| 1056 |
-
if isinstance(data_collator, DataCollatorForSeq2Seq) and 'labels' not in train_dataset.column_names:
|
| 1057 |
-
data_collator = TransformersDataCollatorForLanguageModeling(__tokenizer, mlm = False, mlm_probability = 0.0)
|
| 1058 |
-
elif isinstance(data_collator, TransformersDataCollatorForLanguageModeling) and 'labels' in train_dataset.column_names:
|
| 1059 |
-
data_collator = DataCollatorForSeq2Seq(__tokenizer)
|
| 1060 |
-
else:
|
| 1061 |
-
if hasattr(args, 'remove_unused_columns'): args.remove_unused_columns = False
|
| 1062 |
-
if hasattr(args, 'dataset_text_field'): args.dataset_text_field = ''
|
| 1063 |
-
if hasattr(args, 'dataset_kwargs'): args.dataset_kwargs = {'skip_prepare_dataset': True}
|
| 1064 |
-
if not isinstance(data_collator, UnslothVisionDataCollator):
|
| 1065 |
-
if not hasattr(__tokenizer, 'pad') and hasattr(__tokenizer, 'tokenizer'):
|
| 1066 |
-
if isinstance(data_collator, DataCollatorForSeq2Seq):
|
| 1067 |
-
data_collator = DataCollatorForSeq2Seq(__tokenizer.tokenizer)
|
| 1068 |
-
else:
|
| 1069 |
-
data_collator = TransformersDataCollatorForLanguageModeling(__tokenizer.tokenizer, mlm = False, mlm_probability = 0.0)
|
| 1070 |
-
other_metrics = []
|
| 1071 |
-
|
| 1072 |
-
from unsloth_zoo.logging_utils import PatchRLStatistics
|
| 1073 |
-
PatchRLStatistics('sft_trainer', other_metrics)
|
| 1074 |
-
IGNORED_TOKENIZER_NAMES = os.environ.get('UNSLOTH_IGNORED_TOKENIZER_NAMES', '').split('\n')
|
| 1075 |
-
from unsloth_zoo.tokenizer_utils import fix_untrained_tokens
|
| 1076 |
-
from unsloth_zoo.training_utils import fix_zero_training_loss
|
| 1077 |
-
if 'tokenizer' not in locals(): tokenizer = processing_class
|
| 1078 |
-
fix_untrained_tokens(model, tokenizer, train_dataset, IGNORED_TOKENIZER_NAMES, eps = 1e-16)
|
| 1079 |
-
fix_zero_training_loss(model, tokenizer, train_dataset)
|
| 1080 |
-
|
| 1081 |
-
super().__init__(
|
| 1082 |
-
model = model,
|
| 1083 |
-
args = args,
|
| 1084 |
-
data_collator = data_collator,
|
| 1085 |
-
train_dataset = train_dataset,
|
| 1086 |
-
eval_dataset = eval_dataset,
|
| 1087 |
-
processing_class = processing_class,
|
| 1088 |
-
compute_loss_func = compute_loss_func,
|
| 1089 |
-
compute_metrics = compute_metrics,
|
| 1090 |
-
callbacks = callbacks,
|
| 1091 |
-
optimizer_cls_and_kwargs = optimizer_cls_and_kwargs,
|
| 1092 |
-
preprocess_logits_for_metrics = preprocess_logits_for_metrics,
|
| 1093 |
-
peft_config = peft_config,
|
| 1094 |
-
formatting_func = formatting_func,**kwargs)
|
| 1095 |
-
if hasattr(self, 'neftune_hook_handle'):
|
| 1096 |
-
self.neftune_hook_handle.remove()
|
| 1097 |
-
if hasattr(self, 'neftune_hook_handle'): del self.neftune_hook_handle
|
| 1098 |
-
if getattr(args, 'neftune_noise_alpha', None) is not None:
|
| 1099 |
-
model.get_input_embeddings().neftune_noise_alpha = self.neftune_noise_alpha
|
| 1100 |
-
pass
|
| 1101 |
-
|
| 1102 |
-
pass
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
test_run_uploads/UnslothXPOTrainer.py
DELETED
|
@@ -1,1024 +0,0 @@
|
|
| 1 |
-
"""
|
| 2 |
-
2025.7.11
|
| 3 |
-
2025.7.11
|
| 4 |
-
4.54.1
|
| 5 |
-
0.16.1
|
| 6 |
-
__UNSLOTH_VERSIONING__
|
| 7 |
-
"""
|
| 8 |
-
from torch import Tensor
|
| 9 |
-
import torch
|
| 10 |
-
import torch.nn as nn
|
| 11 |
-
from torch.nn import functional as F
|
| 12 |
-
from typing import Any, List, Optional, Tuple, Union, Dict, Set, Callable
|
| 13 |
-
from trl.trainer.xpo_trainer import (Any, BaseImageProcessor, BasePairwiseJudge, Callable, Dataset, EvalPrediction, F, FeatureExtractionMixin, IterableDataset, OnlineDPOTrainer, OptimizerNames, Optional, PreTrainedModel, PreTrainedTokenizerBase, ProcessorMixin, SIMPLE_CHAT_TEMPLATE, TrainerCallback, Union, XPOConfig, XPOTrainer, empty_cache, generate_model_card, get_comet_experiment_url, get_reward, is_conversational, is_wandb_available, jinja2, maybe_apply_chat_template, nn, os, selective_log_softmax, textwrap, torch, truncate_right, unwrap_model_for_generation, wandb)
|
| 14 |
-
|
| 15 |
-
|
| 16 |
-
import os
|
| 17 |
-
from typing import *
|
| 18 |
-
from dataclasses import dataclass, field
|
| 19 |
-
from packaging.version import Version
|
| 20 |
-
import torch
|
| 21 |
-
import numpy as np
|
| 22 |
-
from contextlib import nullcontext
|
| 23 |
-
from torch.nn import functional as F
|
| 24 |
-
from transformers import DataCollatorForSeq2Seq, DataCollatorForLanguageModeling as TransformersDataCollatorForLanguageModeling
|
| 25 |
-
|
| 26 |
-
torch_compile_options = {
|
| 27 |
-
"epilogue_fusion" : True,
|
| 28 |
-
"max_autotune" : False,
|
| 29 |
-
"shape_padding" : True,
|
| 30 |
-
"trace.enabled" : False,
|
| 31 |
-
"triton.cudagraphs" : False,
|
| 32 |
-
}
|
| 33 |
-
|
| 34 |
-
@torch.compile(dynamic = True, fullgraph = True, options = torch_compile_options,)
|
| 35 |
-
def chunked_selective_log_softmax(logits, index):
|
| 36 |
-
# Split into 4 chunks only
|
| 37 |
-
chunked_logits = torch.chunk(logits.reshape(-1, logits.shape[-1]), chunks = 4, dim = 0)
|
| 38 |
-
chunked_index = torch.chunk(index.reshape(-1), chunks = 4, dim = 0)
|
| 39 |
-
all_per_token_logps = []
|
| 40 |
-
# Below loop does the same as selective_log_softmax(chunk_logits, chunk_index)
|
| 41 |
-
for chunk_logits, chunk_index in zip(chunked_logits, chunked_index):
|
| 42 |
-
chunk_logits = chunk_logits.to(torch.float32)
|
| 43 |
-
selected_logits = torch.gather(chunk_logits, dim = -1, index = chunk_index.unsqueeze(-1)).squeeze(-1)
|
| 44 |
-
logsumexp_values = torch.logsumexp(chunk_logits, dim = -1)
|
| 45 |
-
per_token_logps = selected_logits - logsumexp_values
|
| 46 |
-
all_per_token_logps.append(per_token_logps)
|
| 47 |
-
pass
|
| 48 |
-
all_per_token_logps = torch.concat(all_per_token_logps)
|
| 49 |
-
all_per_token_logps = all_per_token_logps.reshape((logits.shape[0], logits.shape[1]))
|
| 50 |
-
return all_per_token_logps
|
| 51 |
-
@dataclass
|
| 52 |
-
class UnslothXPOConfig(XPOConfig):
|
| 53 |
-
"""
|
| 54 |
-
|
| 55 |
-
Configuration class for the [`XPOTrainer`].
|
| 56 |
-
|
| 57 |
-
Subclass of [`OnlineDPOConfig`] we can use all its arguments and add the following:
|
| 58 |
-
|
| 59 |
-
Parameters:
|
| 60 |
-
alpha (`float` or `list[float]`, *optional*, defaults to `1e-5`):
|
| 61 |
-
Weight of the XPO loss term. If a list of floats is provided then the alpha is selected for each new epoch
|
| 62 |
-
and the last alpha is used for the rest of the epochs.
|
| 63 |
-
|
| 64 |
-
"""
|
| 65 |
-
vllm_sampling_params: Optional[Any] = field(
|
| 66 |
-
default = None,
|
| 67 |
-
metadata = {'help': 'vLLM SamplingParams'},
|
| 68 |
-
)
|
| 69 |
-
unsloth_num_chunks : Optional[int] = field(
|
| 70 |
-
default = -1,
|
| 71 |
-
metadata = {'help': 'Chunk size to reduce memory usage. -1 is most efficient.'},
|
| 72 |
-
)
|
| 73 |
-
def __init__(
|
| 74 |
-
self,
|
| 75 |
-
output_dir = None,
|
| 76 |
-
overwrite_output_dir = None,
|
| 77 |
-
do_train = False,
|
| 78 |
-
do_eval = False,
|
| 79 |
-
do_predict = False,
|
| 80 |
-
eval_strategy = 'no',
|
| 81 |
-
prediction_loss_only = False,
|
| 82 |
-
per_device_train_batch_size = 4,
|
| 83 |
-
per_device_eval_batch_size = 4,
|
| 84 |
-
per_gpu_train_batch_size = None,
|
| 85 |
-
per_gpu_eval_batch_size = None,
|
| 86 |
-
gradient_accumulation_steps = 2,
|
| 87 |
-
eval_accumulation_steps = 2,
|
| 88 |
-
eval_delay = 0,
|
| 89 |
-
torch_empty_cache_steps = 250,
|
| 90 |
-
learning_rate = 5e-05,
|
| 91 |
-
weight_decay = 0.01,
|
| 92 |
-
adam_beta1 = 0.9,
|
| 93 |
-
adam_beta2 = 0.999,
|
| 94 |
-
adam_epsilon = 1e-08,
|
| 95 |
-
max_grad_norm = 1.0,
|
| 96 |
-
num_train_epochs = 3.0,
|
| 97 |
-
max_steps = -1,
|
| 98 |
-
lr_scheduler_type = 'linear',
|
| 99 |
-
warmup_ratio = 0.1,
|
| 100 |
-
warmup_steps = 0,
|
| 101 |
-
log_level = 'passive',
|
| 102 |
-
log_level_replica = 'warning',
|
| 103 |
-
log_on_each_node = True,
|
| 104 |
-
logging_dir = None,
|
| 105 |
-
logging_strategy = 'steps',
|
| 106 |
-
logging_first_step = False,
|
| 107 |
-
logging_steps = 1,
|
| 108 |
-
logging_nan_inf_filter = False,
|
| 109 |
-
save_strategy = 'steps',
|
| 110 |
-
save_steps = 500,
|
| 111 |
-
save_total_limit = None,
|
| 112 |
-
save_safetensors = True,
|
| 113 |
-
save_on_each_node = False,
|
| 114 |
-
save_only_model = False,
|
| 115 |
-
restore_callback_states_from_checkpoint = False,
|
| 116 |
-
no_cuda = False,
|
| 117 |
-
use_cpu = False,
|
| 118 |
-
use_mps_device = False,
|
| 119 |
-
seed = 3407,
|
| 120 |
-
data_seed = 3407,
|
| 121 |
-
jit_mode_eval = False,
|
| 122 |
-
use_ipex = False,
|
| 123 |
-
bf16 = False,
|
| 124 |
-
fp16 = False,
|
| 125 |
-
fp16_opt_level = 'O1',
|
| 126 |
-
half_precision_backend = 'auto',
|
| 127 |
-
bf16_full_eval = False,
|
| 128 |
-
fp16_full_eval = False,
|
| 129 |
-
tf32 = None,
|
| 130 |
-
local_rank = -1,
|
| 131 |
-
ddp_backend = None,
|
| 132 |
-
tpu_num_cores = None,
|
| 133 |
-
tpu_metrics_debug = False,
|
| 134 |
-
debug = '',
|
| 135 |
-
dataloader_drop_last = False,
|
| 136 |
-
eval_steps = None,
|
| 137 |
-
dataloader_num_workers = 0,
|
| 138 |
-
dataloader_prefetch_factor = None,
|
| 139 |
-
past_index = -1,
|
| 140 |
-
run_name = None,
|
| 141 |
-
disable_tqdm = None,
|
| 142 |
-
remove_unused_columns = True,
|
| 143 |
-
label_names = None,
|
| 144 |
-
load_best_model_at_end = False,
|
| 145 |
-
metric_for_best_model = None,
|
| 146 |
-
greater_is_better = None,
|
| 147 |
-
ignore_data_skip = False,
|
| 148 |
-
fsdp = '',
|
| 149 |
-
fsdp_min_num_params = 0,
|
| 150 |
-
fsdp_config = None,
|
| 151 |
-
fsdp_transformer_layer_cls_to_wrap = None,
|
| 152 |
-
accelerator_config = None,
|
| 153 |
-
deepspeed = None,
|
| 154 |
-
label_smoothing_factor = 0.0,
|
| 155 |
-
optim = 'adamw_8bit',
|
| 156 |
-
optim_args = None,
|
| 157 |
-
adafactor = False,
|
| 158 |
-
group_by_length = False,
|
| 159 |
-
length_column_name = 'length',
|
| 160 |
-
report_to = None,
|
| 161 |
-
ddp_find_unused_parameters = None,
|
| 162 |
-
ddp_bucket_cap_mb = None,
|
| 163 |
-
ddp_broadcast_buffers = None,
|
| 164 |
-
dataloader_pin_memory = True,
|
| 165 |
-
dataloader_persistent_workers = False,
|
| 166 |
-
skip_memory_metrics = True,
|
| 167 |
-
use_legacy_prediction_loop = False,
|
| 168 |
-
push_to_hub = False,
|
| 169 |
-
resume_from_checkpoint = None,
|
| 170 |
-
hub_model_id = None,
|
| 171 |
-
hub_strategy = 'every_save',
|
| 172 |
-
hub_token = None,
|
| 173 |
-
hub_private_repo = None,
|
| 174 |
-
hub_always_push = False,
|
| 175 |
-
hub_revision = None,
|
| 176 |
-
gradient_checkpointing = False,
|
| 177 |
-
gradient_checkpointing_kwargs = None,
|
| 178 |
-
include_inputs_for_metrics = False,
|
| 179 |
-
eval_do_concat_batches = True,
|
| 180 |
-
fp16_backend = 'auto',
|
| 181 |
-
push_to_hub_model_id = None,
|
| 182 |
-
push_to_hub_organization = None,
|
| 183 |
-
push_to_hub_token = None,
|
| 184 |
-
mp_parameters = '',
|
| 185 |
-
auto_find_batch_size = True,
|
| 186 |
-
full_determinism = False,
|
| 187 |
-
torchdynamo = None,
|
| 188 |
-
ray_scope = 'last',
|
| 189 |
-
ddp_timeout = 1800,
|
| 190 |
-
torch_compile = False,
|
| 191 |
-
torch_compile_backend = None,
|
| 192 |
-
torch_compile_mode = None,
|
| 193 |
-
include_tokens_per_second = False,
|
| 194 |
-
include_num_input_tokens_seen = False,
|
| 195 |
-
neftune_noise_alpha = None,
|
| 196 |
-
optim_target_modules = None,
|
| 197 |
-
batch_eval_metrics = False,
|
| 198 |
-
eval_on_start = False,
|
| 199 |
-
use_liger_kernel = False,
|
| 200 |
-
liger_kernel_config = None,
|
| 201 |
-
eval_use_gather_object = False,
|
| 202 |
-
average_tokens_across_devices = True,
|
| 203 |
-
reward_model_path = None,
|
| 204 |
-
judge = None,
|
| 205 |
-
max_new_tokens = 64,
|
| 206 |
-
max_length = 512,
|
| 207 |
-
temperature = 0.9,
|
| 208 |
-
missing_eos_penalty = None,
|
| 209 |
-
loss_type = 'sigmoid',
|
| 210 |
-
dataset_num_proc = None,
|
| 211 |
-
disable_dropout = True,
|
| 212 |
-
use_vllm = False,
|
| 213 |
-
ds3_gather_for_generation = True,
|
| 214 |
-
vllm_sampling_params = None,
|
| 215 |
-
unsloth_num_chunks = -1,
|
| 216 |
-
**kwargs,
|
| 217 |
-
):
|
| 218 |
-
if learning_rate < 1e-7: raise FloatingPointError(f'Unsloth: Your learning rate of `{learning_rate}` is too small and less than 1e-7! Consider increasing it, otherwise gradient updates will be close to 0!')
|
| 219 |
-
if learning_rate > 1: raise OverflowError(f'Unsloth: Your learning rate of `{learning_rate}` is way too larger > 1! Consider decreasing it to 1e-1, otherwise gradient updates will explode!')
|
| 220 |
-
if output_dir is None and save_strategy == 'steps' and save_steps == 500:
|
| 221 |
-
output_dir = 'unsloth_training_checkpoints'
|
| 222 |
-
save_strategy = 'no'
|
| 223 |
-
if dataset_num_proc is None:
|
| 224 |
-
from multiprocessing import cpu_count
|
| 225 |
-
dataset_num_proc = min(cpu_count()*2, 2)
|
| 226 |
-
if temperature <= 0:
|
| 227 |
-
raise MathError('Unsloth: Please set a positive non-zero temperature since your results will be wrong.')
|
| 228 |
-
elif temperature >= 10:
|
| 229 |
-
raise MathError('Unsloth: Please set a positive non-zero temperature less than 10, since sampling will be quite erratic.')
|
| 230 |
-
|
| 231 |
-
|
| 232 |
-
super().__init__(
|
| 233 |
-
output_dir = output_dir,
|
| 234 |
-
overwrite_output_dir = overwrite_output_dir,
|
| 235 |
-
do_train = do_train,
|
| 236 |
-
do_eval = do_eval,
|
| 237 |
-
do_predict = do_predict,
|
| 238 |
-
eval_strategy = eval_strategy,
|
| 239 |
-
prediction_loss_only = prediction_loss_only,
|
| 240 |
-
per_device_train_batch_size = per_device_train_batch_size,
|
| 241 |
-
per_device_eval_batch_size = per_device_eval_batch_size,
|
| 242 |
-
per_gpu_train_batch_size = per_gpu_train_batch_size,
|
| 243 |
-
per_gpu_eval_batch_size = per_gpu_eval_batch_size,
|
| 244 |
-
gradient_accumulation_steps = gradient_accumulation_steps,
|
| 245 |
-
eval_accumulation_steps = eval_accumulation_steps,
|
| 246 |
-
eval_delay = eval_delay,
|
| 247 |
-
torch_empty_cache_steps = torch_empty_cache_steps,
|
| 248 |
-
learning_rate = learning_rate,
|
| 249 |
-
weight_decay = weight_decay,
|
| 250 |
-
adam_beta1 = adam_beta1,
|
| 251 |
-
adam_beta2 = adam_beta2,
|
| 252 |
-
adam_epsilon = adam_epsilon,
|
| 253 |
-
max_grad_norm = max_grad_norm,
|
| 254 |
-
num_train_epochs = num_train_epochs,
|
| 255 |
-
max_steps = max_steps,
|
| 256 |
-
lr_scheduler_type = lr_scheduler_type,
|
| 257 |
-
warmup_ratio = warmup_ratio,
|
| 258 |
-
warmup_steps = warmup_steps,
|
| 259 |
-
log_level = log_level,
|
| 260 |
-
log_level_replica = log_level_replica,
|
| 261 |
-
log_on_each_node = log_on_each_node,
|
| 262 |
-
logging_dir = logging_dir,
|
| 263 |
-
logging_strategy = logging_strategy,
|
| 264 |
-
logging_first_step = logging_first_step,
|
| 265 |
-
logging_steps = logging_steps,
|
| 266 |
-
logging_nan_inf_filter = logging_nan_inf_filter,
|
| 267 |
-
save_strategy = save_strategy,
|
| 268 |
-
save_steps = save_steps,
|
| 269 |
-
save_total_limit = save_total_limit,
|
| 270 |
-
save_safetensors = save_safetensors,
|
| 271 |
-
save_on_each_node = save_on_each_node,
|
| 272 |
-
save_only_model = save_only_model,
|
| 273 |
-
restore_callback_states_from_checkpoint = restore_callback_states_from_checkpoint,
|
| 274 |
-
no_cuda = no_cuda,
|
| 275 |
-
use_cpu = use_cpu,
|
| 276 |
-
use_mps_device = use_mps_device,
|
| 277 |
-
seed = seed,
|
| 278 |
-
data_seed = data_seed,
|
| 279 |
-
jit_mode_eval = jit_mode_eval,
|
| 280 |
-
use_ipex = use_ipex,
|
| 281 |
-
bf16 = bf16,
|
| 282 |
-
fp16 = fp16,
|
| 283 |
-
fp16_opt_level = fp16_opt_level,
|
| 284 |
-
half_precision_backend = half_precision_backend,
|
| 285 |
-
bf16_full_eval = bf16_full_eval,
|
| 286 |
-
fp16_full_eval = fp16_full_eval,
|
| 287 |
-
tf32 = tf32,
|
| 288 |
-
local_rank = local_rank,
|
| 289 |
-
ddp_backend = ddp_backend,
|
| 290 |
-
tpu_num_cores = tpu_num_cores,
|
| 291 |
-
tpu_metrics_debug = tpu_metrics_debug,
|
| 292 |
-
debug = debug,
|
| 293 |
-
dataloader_drop_last = dataloader_drop_last,
|
| 294 |
-
eval_steps = eval_steps,
|
| 295 |
-
dataloader_num_workers = dataloader_num_workers,
|
| 296 |
-
dataloader_prefetch_factor = dataloader_prefetch_factor,
|
| 297 |
-
past_index = past_index,
|
| 298 |
-
run_name = run_name,
|
| 299 |
-
disable_tqdm = disable_tqdm,
|
| 300 |
-
remove_unused_columns = remove_unused_columns,
|
| 301 |
-
label_names = label_names,
|
| 302 |
-
load_best_model_at_end = load_best_model_at_end,
|
| 303 |
-
metric_for_best_model = metric_for_best_model,
|
| 304 |
-
greater_is_better = greater_is_better,
|
| 305 |
-
ignore_data_skip = ignore_data_skip,
|
| 306 |
-
fsdp = fsdp,
|
| 307 |
-
fsdp_min_num_params = fsdp_min_num_params,
|
| 308 |
-
fsdp_config = fsdp_config,
|
| 309 |
-
fsdp_transformer_layer_cls_to_wrap = fsdp_transformer_layer_cls_to_wrap,
|
| 310 |
-
accelerator_config = accelerator_config,
|
| 311 |
-
deepspeed = deepspeed,
|
| 312 |
-
label_smoothing_factor = label_smoothing_factor,
|
| 313 |
-
optim = optim,
|
| 314 |
-
optim_args = optim_args,
|
| 315 |
-
adafactor = adafactor,
|
| 316 |
-
group_by_length = group_by_length,
|
| 317 |
-
length_column_name = length_column_name,
|
| 318 |
-
report_to = report_to,
|
| 319 |
-
ddp_find_unused_parameters = ddp_find_unused_parameters,
|
| 320 |
-
ddp_bucket_cap_mb = ddp_bucket_cap_mb,
|
| 321 |
-
ddp_broadcast_buffers = ddp_broadcast_buffers,
|
| 322 |
-
dataloader_pin_memory = dataloader_pin_memory,
|
| 323 |
-
dataloader_persistent_workers = dataloader_persistent_workers,
|
| 324 |
-
skip_memory_metrics = skip_memory_metrics,
|
| 325 |
-
use_legacy_prediction_loop = use_legacy_prediction_loop,
|
| 326 |
-
push_to_hub = push_to_hub,
|
| 327 |
-
resume_from_checkpoint = resume_from_checkpoint,
|
| 328 |
-
hub_model_id = hub_model_id,
|
| 329 |
-
hub_strategy = hub_strategy,
|
| 330 |
-
hub_token = hub_token,
|
| 331 |
-
hub_private_repo = hub_private_repo,
|
| 332 |
-
hub_always_push = hub_always_push,
|
| 333 |
-
hub_revision = hub_revision,
|
| 334 |
-
gradient_checkpointing = gradient_checkpointing,
|
| 335 |
-
gradient_checkpointing_kwargs = gradient_checkpointing_kwargs,
|
| 336 |
-
include_inputs_for_metrics = include_inputs_for_metrics,
|
| 337 |
-
eval_do_concat_batches = eval_do_concat_batches,
|
| 338 |
-
fp16_backend = fp16_backend,
|
| 339 |
-
push_to_hub_model_id = push_to_hub_model_id,
|
| 340 |
-
push_to_hub_organization = push_to_hub_organization,
|
| 341 |
-
push_to_hub_token = push_to_hub_token,
|
| 342 |
-
mp_parameters = mp_parameters,
|
| 343 |
-
auto_find_batch_size = auto_find_batch_size,
|
| 344 |
-
full_determinism = full_determinism,
|
| 345 |
-
torchdynamo = torchdynamo,
|
| 346 |
-
ray_scope = ray_scope,
|
| 347 |
-
ddp_timeout = ddp_timeout,
|
| 348 |
-
torch_compile = torch_compile,
|
| 349 |
-
torch_compile_backend = torch_compile_backend,
|
| 350 |
-
torch_compile_mode = torch_compile_mode,
|
| 351 |
-
include_tokens_per_second = include_tokens_per_second,
|
| 352 |
-
include_num_input_tokens_seen = include_num_input_tokens_seen,
|
| 353 |
-
neftune_noise_alpha = neftune_noise_alpha,
|
| 354 |
-
optim_target_modules = optim_target_modules,
|
| 355 |
-
batch_eval_metrics = batch_eval_metrics,
|
| 356 |
-
eval_on_start = eval_on_start,
|
| 357 |
-
use_liger_kernel = use_liger_kernel,
|
| 358 |
-
liger_kernel_config = liger_kernel_config,
|
| 359 |
-
eval_use_gather_object = eval_use_gather_object,
|
| 360 |
-
average_tokens_across_devices = average_tokens_across_devices,
|
| 361 |
-
reward_model_path = reward_model_path,
|
| 362 |
-
judge = judge,
|
| 363 |
-
max_new_tokens = max_new_tokens,
|
| 364 |
-
max_length = max_length,
|
| 365 |
-
temperature = temperature,
|
| 366 |
-
missing_eos_penalty = missing_eos_penalty,
|
| 367 |
-
loss_type = loss_type,
|
| 368 |
-
dataset_num_proc = dataset_num_proc,
|
| 369 |
-
disable_dropout = disable_dropout,
|
| 370 |
-
use_vllm = use_vllm,
|
| 371 |
-
ds3_gather_for_generation = ds3_gather_for_generation,**kwargs)
|
| 372 |
-
self.vllm_sampling_params = vllm_sampling_params
|
| 373 |
-
self.unsloth_num_chunks = unsloth_num_chunks
|
| 374 |
-
pass
|
| 375 |
-
|
| 376 |
-
class _UnslothXPOTrainer(OnlineDPOTrainer):
|
| 377 |
-
r""""""
|
| 378 |
-
|
| 379 |
-
_tag_names = ["trl", "xpo"]
|
| 380 |
-
|
| 381 |
-
def __init__(
|
| 382 |
-
self,
|
| 383 |
-
model: Union[PreTrainedModel, nn.Module] = None,
|
| 384 |
-
ref_model: Union[PreTrainedModel, nn.Module] = None,
|
| 385 |
-
reward_model: Optional[nn.Module] = None,
|
| 386 |
-
judge: Optional[BasePairwiseJudge] = None,
|
| 387 |
-
args: Optional[XPOConfig] = None,
|
| 388 |
-
data_collator: Optional[Callable] = None,
|
| 389 |
-
train_dataset: Optional[Union[Dataset, IterableDataset]] = None,
|
| 390 |
-
eval_dataset: Optional[Union[Dataset, dict[str, Dataset]]] = None,
|
| 391 |
-
processing_class: Optional[
|
| 392 |
-
Union[PreTrainedTokenizerBase, BaseImageProcessor, FeatureExtractionMixin, ProcessorMixin]
|
| 393 |
-
] = None,
|
| 394 |
-
peft_config: Optional[dict] = None,
|
| 395 |
-
compute_metrics: Optional[Callable[[EvalPrediction], dict]] = None,
|
| 396 |
-
callbacks: Optional[list[TrainerCallback]] = None,
|
| 397 |
-
optimizers: tuple[torch.optim.Optimizer, torch.optim.lr_scheduler.LambdaLR] = (None, None),
|
| 398 |
-
preprocess_logits_for_metrics: Optional[Callable[[torch.Tensor, torch.Tensor], torch.Tensor]] = None,
|
| 399 |
-
) -> None:
|
| 400 |
-
super().__init__(
|
| 401 |
-
model=model,
|
| 402 |
-
ref_model=ref_model,
|
| 403 |
-
judge=judge,
|
| 404 |
-
reward_model=reward_model,
|
| 405 |
-
args=args,
|
| 406 |
-
data_collator=data_collator,
|
| 407 |
-
train_dataset=train_dataset,
|
| 408 |
-
eval_dataset=eval_dataset,
|
| 409 |
-
processing_class=processing_class,
|
| 410 |
-
reward_processing_class=processing_class, # for now, XPOTrainer can't use any reward model
|
| 411 |
-
peft_config=peft_config,
|
| 412 |
-
compute_metrics=compute_metrics,
|
| 413 |
-
callbacks=callbacks,
|
| 414 |
-
optimizers=optimizers,
|
| 415 |
-
preprocess_logits_for_metrics=preprocess_logits_for_metrics,
|
| 416 |
-
)
|
| 417 |
-
|
| 418 |
-
self._alpha = self.args.alpha
|
| 419 |
-
|
| 420 |
-
# Overwrite the stats dictionary to include XPO specific statistics
|
| 421 |
-
self.stats = {
|
| 422 |
-
# Remove "non_score_reward", "rlhf_reward", "scores"
|
| 423 |
-
# Add "loss/dpo", "loss/xpo"
|
| 424 |
-
"loss/dpo": [],
|
| 425 |
-
"loss/xpo": [],
|
| 426 |
-
"objective/kl": [],
|
| 427 |
-
"objective/entropy": [],
|
| 428 |
-
"rewards/chosen": [],
|
| 429 |
-
"rewards/rejected": [],
|
| 430 |
-
"rewards/accuracies": [],
|
| 431 |
-
"rewards/margins": [],
|
| 432 |
-
"logps/chosen": [],
|
| 433 |
-
"logps/rejected": [],
|
| 434 |
-
# Replace "contain_eos_token" by "model_contain_eos_token" and "ref_contain_eos_token"
|
| 435 |
-
"val/model_contain_eos_token": [],
|
| 436 |
-
"val/ref_contain_eos_token": [],
|
| 437 |
-
"alpha": [],
|
| 438 |
-
"beta": [],
|
| 439 |
-
}
|
| 440 |
-
if self.reward_model is not None:
|
| 441 |
-
# Replace "scores" by "model_scores" and "ref_scores"
|
| 442 |
-
self.stats["objective/model_scores"] = []
|
| 443 |
-
self.stats["objective/ref_scores"] = []
|
| 444 |
-
self.stats["objective/scores_margin"] = []
|
| 445 |
-
|
| 446 |
-
@property
|
| 447 |
-
def alpha(self):
|
| 448 |
-
if isinstance(self._alpha, list):
|
| 449 |
-
epoch = self.state.epoch
|
| 450 |
-
return self._alpha[epoch] if epoch < len(self._alpha) else self._alpha[-1]
|
| 451 |
-
else:
|
| 452 |
-
return self._alpha
|
| 453 |
-
|
| 454 |
-
def _generate_completions(self, prompts, model):
|
| 455 |
-
with unwrap_model_for_generation(model, self.accelerator) as unwrapped_model:
|
| 456 |
-
model_output = unwrapped_model.generate(
|
| 457 |
-
input_ids=prompts["input_ids"],
|
| 458 |
-
attention_mask=prompts["attention_mask"],
|
| 459 |
-
generation_config=self.generation_config,
|
| 460 |
-
)
|
| 461 |
-
|
| 462 |
-
ref_model = model if self.ref_model is None else self.ref_model
|
| 463 |
-
with torch.no_grad(), unwrap_model_for_generation(ref_model, self.accelerator) as unwrapped_ref_model:
|
| 464 |
-
ref_output = unwrapped_ref_model.generate(
|
| 465 |
-
input_ids=prompts["input_ids"],
|
| 466 |
-
attention_mask=prompts["attention_mask"],
|
| 467 |
-
generation_config=self.generation_config,
|
| 468 |
-
)
|
| 469 |
-
|
| 470 |
-
return model_output, ref_output
|
| 471 |
-
|
| 472 |
-
def _process_completions(self, model_output, ref_output, prompts):
|
| 473 |
-
context_length = prompts["input_ids"].shape[1]
|
| 474 |
-
|
| 475 |
-
# Process model completions
|
| 476 |
-
model_completion_ids = model_output[:, context_length:]
|
| 477 |
-
model_completion_ids, model_completion_mask = truncate_right(
|
| 478 |
-
model_completion_ids, self.processing_class.eos_token_id, self.processing_class.pad_token_id
|
| 479 |
-
)
|
| 480 |
-
model_data = {
|
| 481 |
-
"input_ids": torch.cat((prompts["input_ids"], model_completion_ids), dim=1),
|
| 482 |
-
"attention_mask": torch.cat((prompts["attention_mask"], model_completion_mask), dim=1),
|
| 483 |
-
"raw": prompts["raw"],
|
| 484 |
-
}
|
| 485 |
-
|
| 486 |
-
# Process reference model completions
|
| 487 |
-
ref_completion_ids = ref_output[:, context_length:]
|
| 488 |
-
ref_completion_ids, ref_completion_mask = truncate_right(
|
| 489 |
-
ref_completion_ids, self.processing_class.eos_token_id, self.processing_class.pad_token_id
|
| 490 |
-
)
|
| 491 |
-
ref_data = {
|
| 492 |
-
"input_ids": torch.cat((prompts["input_ids"], ref_completion_ids), dim=1),
|
| 493 |
-
"attention_mask": torch.cat((prompts["attention_mask"], ref_completion_mask), dim=1),
|
| 494 |
-
"raw": prompts["raw"],
|
| 495 |
-
}
|
| 496 |
-
|
| 497 |
-
return model_data, ref_data
|
| 498 |
-
|
| 499 |
-
def _compute_rewards(self, model_data, ref_data, context_length):
|
| 500 |
-
with torch.no_grad():
|
| 501 |
-
_, model_scores, _ = get_reward(
|
| 502 |
-
self.reward_model, model_data["input_ids"], self.processing_class.pad_token_id, context_length
|
| 503 |
-
)
|
| 504 |
-
_, ref_scores, _ = get_reward(
|
| 505 |
-
self.reward_model, ref_data["input_ids"], self.processing_class.pad_token_id, context_length
|
| 506 |
-
)
|
| 507 |
-
|
| 508 |
-
# Apply EOS penalty if needed
|
| 509 |
-
if self.args.missing_eos_penalty is not None:
|
| 510 |
-
model_contain_eos = torch.any(model_data["input_ids"] == self.processing_class.eos_token_id, dim=-1)
|
| 511 |
-
ref_contain_eos = torch.any(ref_data["input_ids"] == self.processing_class.eos_token_id, dim=-1)
|
| 512 |
-
model_scores[~model_contain_eos] -= self.args.missing_eos_penalty
|
| 513 |
-
ref_scores[~ref_contain_eos] -= self.args.missing_eos_penalty
|
| 514 |
-
|
| 515 |
-
return model_scores, ref_scores
|
| 516 |
-
|
| 517 |
-
def _compute_judge(self, model_data, ref_data, context_length):
|
| 518 |
-
prompts = model_data["raw"]
|
| 519 |
-
model_data_completions = self.processing_class.batch_decode(
|
| 520 |
-
model_data["input_ids"][:, context_length:], skip_special_tokens=True
|
| 521 |
-
)
|
| 522 |
-
model_data_completions = [completion.strip() for completion in model_data_completions]
|
| 523 |
-
|
| 524 |
-
ref_data_completions = self.processing_class.batch_decode(
|
| 525 |
-
ref_data["input_ids"][:, context_length:], skip_special_tokens=True
|
| 526 |
-
)
|
| 527 |
-
ref_data_completions = [completion.strip() for completion in ref_data_completions]
|
| 528 |
-
|
| 529 |
-
if is_conversational({"prompt": prompts[0]}):
|
| 530 |
-
model_data_completions = [
|
| 531 |
-
[{"role": "assistant", "content": completion}] for completion in model_data_completions
|
| 532 |
-
]
|
| 533 |
-
environment = jinja2.Environment()
|
| 534 |
-
template = environment.from_string(SIMPLE_CHAT_TEMPLATE)
|
| 535 |
-
prompts = [template.render(messages=message) for message in prompts]
|
| 536 |
-
model_data_completions = [template.render(messages=completion) for completion in model_data_completions]
|
| 537 |
-
|
| 538 |
-
ref_data_completions = [
|
| 539 |
-
[{"role": "assistant", "content": completion}] for completion in ref_data_completions
|
| 540 |
-
]
|
| 541 |
-
ref_data_completions = [template.render(messages=completion) for completion in ref_data_completions]
|
| 542 |
-
|
| 543 |
-
ranks_of_first_completion = self.judge.judge(
|
| 544 |
-
prompts,
|
| 545 |
-
list(zip(model_data_completions, ref_data_completions)),
|
| 546 |
-
)
|
| 547 |
-
# convert ranks to a True/False mask:
|
| 548 |
-
# when rank == 0, it means the first completion is the best
|
| 549 |
-
# when rank == 1, it means the second completion is the best
|
| 550 |
-
return torch.tensor([rank == 0 for rank in ranks_of_first_completion], device=model_data["input_ids"].device)
|
| 551 |
-
|
| 552 |
-
def _compute_logprobs(self, model, model_data, ref_data, context_length):
|
| 553 |
-
def compute_logprobs_for_data(m, data):
|
| 554 |
-
output = m(data["input_ids"], attention_mask=data["attention_mask"])
|
| 555 |
-
logits = output.logits[:, context_length - 1 : -1]
|
| 556 |
-
token_logprobs = selective_log_softmax(logits, data["input_ids"][:, context_length:])
|
| 557 |
-
return token_logprobs
|
| 558 |
-
|
| 559 |
-
# Compute logprobs for model completions
|
| 560 |
-
model_logprobs_model_data = compute_logprobs_for_data(model, model_data)
|
| 561 |
-
# Compute logprobs for model on reference completions (for XPO loss)
|
| 562 |
-
model_logprobs_ref_data = compute_logprobs_for_data(model, ref_data)
|
| 563 |
-
|
| 564 |
-
# Compute logprobs for reference model completions
|
| 565 |
-
with torch.no_grad():
|
| 566 |
-
if self.ref_model is None:
|
| 567 |
-
with model.disable_adapter():
|
| 568 |
-
ref_logprobs_model_data = compute_logprobs_for_data(model, model_data)
|
| 569 |
-
ref_logprobs_ref_data = compute_logprobs_for_data(model, ref_data)
|
| 570 |
-
else:
|
| 571 |
-
ref_logprobs_model_data = compute_logprobs_for_data(self.ref_model, model_data)
|
| 572 |
-
ref_logprobs_ref_data = compute_logprobs_for_data(self.ref_model, ref_data)
|
| 573 |
-
|
| 574 |
-
# Mask padding tokens
|
| 575 |
-
model_padding_mask = model_data["attention_mask"][:, context_length:] == 0
|
| 576 |
-
ref_padding_mask = ref_data["attention_mask"][:, context_length:] == 0
|
| 577 |
-
model_logprobs_model_data = model_logprobs_model_data.masked_fill(model_padding_mask, 0.0)
|
| 578 |
-
model_logprobs_ref_data = model_logprobs_ref_data.masked_fill(ref_padding_mask, 0.0)
|
| 579 |
-
ref_logprobs_ref_data = ref_logprobs_ref_data.masked_fill(ref_padding_mask, 0.0)
|
| 580 |
-
ref_logprobs_model_data = ref_logprobs_model_data.masked_fill(model_padding_mask, 0.0)
|
| 581 |
-
|
| 582 |
-
return model_logprobs_model_data, model_logprobs_ref_data, ref_logprobs_ref_data, ref_logprobs_model_data
|
| 583 |
-
|
| 584 |
-
def _compute_losses(
|
| 585 |
-
self,
|
| 586 |
-
model_logprobs_model_data,
|
| 587 |
-
model_logprobs_ref_data,
|
| 588 |
-
ref_logprobs_ref_data,
|
| 589 |
-
ref_logprobs_model_data,
|
| 590 |
-
chosen_mask,
|
| 591 |
-
):
|
| 592 |
-
# Compute log probs
|
| 593 |
-
model_logprobs_model_data_sum = model_logprobs_model_data.sum(1)
|
| 594 |
-
model_logprobs_ref_data_sum = model_logprobs_ref_data.sum(1)
|
| 595 |
-
ref_logprobs_ref_data_sum = ref_logprobs_ref_data.sum(1)
|
| 596 |
-
ref_logprobs_model_data_sum = ref_logprobs_model_data.sum(1)
|
| 597 |
-
|
| 598 |
-
chosen_model_logprobs = torch.where(chosen_mask, model_logprobs_model_data_sum, model_logprobs_ref_data_sum)
|
| 599 |
-
chosen_ref_logprobs = torch.where(chosen_mask, ref_logprobs_model_data_sum, ref_logprobs_ref_data_sum)
|
| 600 |
-
chosen_log_ratios = chosen_model_logprobs - chosen_ref_logprobs
|
| 601 |
-
|
| 602 |
-
rejected_model_logprobs = torch.where(~chosen_mask, model_logprobs_model_data_sum, model_logprobs_ref_data_sum)
|
| 603 |
-
rejected_ref_logprobs = torch.where(~chosen_mask, ref_logprobs_model_data_sum, ref_logprobs_ref_data_sum)
|
| 604 |
-
rejected_log_ratios = rejected_model_logprobs - rejected_ref_logprobs
|
| 605 |
-
|
| 606 |
-
# Compute logits as the difference between chosen and rejected log ratios
|
| 607 |
-
logits = chosen_log_ratios - rejected_log_ratios
|
| 608 |
-
|
| 609 |
-
if self.args.loss_type == "sigmoid":
|
| 610 |
-
dpo_losses = -F.logsigmoid(self.beta * logits)
|
| 611 |
-
elif self.args.loss_type == "ipo":
|
| 612 |
-
dpo_losses = (logits - 1 / (2 * self.beta)) ** 2
|
| 613 |
-
else:
|
| 614 |
-
raise NotImplementedError(f"invalid loss type {self.args.loss_type}")
|
| 615 |
-
|
| 616 |
-
# Compute XPO specific loss
|
| 617 |
-
xpo_losses = self.alpha * model_logprobs_ref_data_sum
|
| 618 |
-
|
| 619 |
-
# Total loss
|
| 620 |
-
loss = (dpo_losses + xpo_losses).mean()
|
| 621 |
-
|
| 622 |
-
return loss, dpo_losses, xpo_losses
|
| 623 |
-
|
| 624 |
-
def _log_statistics(
|
| 625 |
-
self,
|
| 626 |
-
model_data,
|
| 627 |
-
ref_data,
|
| 628 |
-
model_logprobs_model_data,
|
| 629 |
-
model_logprobs_ref_data,
|
| 630 |
-
ref_logprobs_ref_data,
|
| 631 |
-
ref_logprobs_model_data,
|
| 632 |
-
chosen_mask,
|
| 633 |
-
dpo_losses,
|
| 634 |
-
xpo_losses,
|
| 635 |
-
context_length,
|
| 636 |
-
model_scores=None,
|
| 637 |
-
ref_scores=None,
|
| 638 |
-
):
|
| 639 |
-
# Helper function to gather and compute mean
|
| 640 |
-
def gather_mean(tensor):
|
| 641 |
-
return self.accelerator.gather_for_metrics(tensor).mean().item()
|
| 642 |
-
|
| 643 |
-
# Log losses
|
| 644 |
-
self.stats["loss/dpo"].append(gather_mean(dpo_losses))
|
| 645 |
-
self.stats["loss/xpo"].append(gather_mean(xpo_losses))
|
| 646 |
-
|
| 647 |
-
# Log scores
|
| 648 |
-
if self.reward_model is not None:
|
| 649 |
-
self.stats["objective/model_scores"].append(gather_mean(model_scores))
|
| 650 |
-
self.stats["objective/ref_scores"].append(gather_mean(ref_scores))
|
| 651 |
-
self.stats["objective/scores_margin"].append(gather_mean(model_scores - ref_scores))
|
| 652 |
-
|
| 653 |
-
# Log logprobs
|
| 654 |
-
model_logprobs_model_data_sum = model_logprobs_model_data.sum(1)
|
| 655 |
-
model_logprobs_ref_data_sum = model_logprobs_ref_data.sum(1)
|
| 656 |
-
ref_logprobs_ref_data_sum = ref_logprobs_ref_data.sum(1)
|
| 657 |
-
ref_logprobs_model_data_sum = ref_logprobs_model_data.sum(1)
|
| 658 |
-
|
| 659 |
-
chosen_model_logprobs = torch.where(chosen_mask, model_logprobs_model_data_sum, model_logprobs_ref_data_sum)
|
| 660 |
-
chosen_ref_logprobs = torch.where(chosen_mask, ref_logprobs_model_data_sum, ref_logprobs_ref_data_sum)
|
| 661 |
-
chosen_log_ratios = chosen_model_logprobs - chosen_ref_logprobs
|
| 662 |
-
|
| 663 |
-
rejected_model_logprobs = torch.where(~chosen_mask, model_logprobs_model_data_sum, model_logprobs_ref_data_sum)
|
| 664 |
-
rejected_ref_logprobs = torch.where(~chosen_mask, ref_logprobs_model_data_sum, ref_logprobs_ref_data_sum)
|
| 665 |
-
rejected_log_ratios = rejected_model_logprobs - rejected_ref_logprobs
|
| 666 |
-
|
| 667 |
-
self.stats["logps/chosen"].append(gather_mean(chosen_model_logprobs.mean() + chosen_ref_logprobs.mean()))
|
| 668 |
-
self.stats["logps/rejected"].append(gather_mean(rejected_model_logprobs.mean() + rejected_ref_logprobs.mean()))
|
| 669 |
-
|
| 670 |
-
# Log rewards
|
| 671 |
-
# Compute various statistics
|
| 672 |
-
chosen_rewards = chosen_log_ratios * self.beta
|
| 673 |
-
rejected_rewards = rejected_log_ratios * self.beta
|
| 674 |
-
self.stats["rewards/chosen"].append(gather_mean(chosen_rewards.mean()))
|
| 675 |
-
self.stats["rewards/rejected"].append(gather_mean(rejected_rewards.mean()))
|
| 676 |
-
|
| 677 |
-
# Calculate KL divergence for model and ref data
|
| 678 |
-
kl_model_data = model_logprobs_model_data - ref_logprobs_model_data
|
| 679 |
-
kl_ref_data = model_logprobs_ref_data - ref_logprobs_ref_data
|
| 680 |
-
mean_kl = (kl_model_data.sum(1) + kl_ref_data.sum(1)).mean() / 2
|
| 681 |
-
self.stats["objective/kl"].append(gather_mean(mean_kl))
|
| 682 |
-
|
| 683 |
-
# Calculate entropy for model and ref data
|
| 684 |
-
entropy_model_data = -model_logprobs_model_data.sum(1)
|
| 685 |
-
entropy_ref_data = -model_logprobs_ref_data.sum(1)
|
| 686 |
-
mean_entropy = (entropy_model_data.mean() + entropy_ref_data.mean()) / 2
|
| 687 |
-
self.stats["objective/entropy"].append(gather_mean(mean_entropy))
|
| 688 |
-
|
| 689 |
-
# Calculate margins
|
| 690 |
-
margin = chosen_rewards - rejected_rewards
|
| 691 |
-
self.stats["rewards/margins"].append(gather_mean(margin.mean()))
|
| 692 |
-
|
| 693 |
-
# Calculate accuracy
|
| 694 |
-
accuracy = (margin > 0).float()
|
| 695 |
-
self.stats["rewards/accuracies"].append(gather_mean(accuracy.mean()))
|
| 696 |
-
|
| 697 |
-
# Log EOS token statistics
|
| 698 |
-
model_eos = (model_data["input_ids"][:, context_length:] == self.processing_class.eos_token_id).any(dim=1)
|
| 699 |
-
ref_eos = (ref_data["input_ids"][:, context_length:] == self.processing_class.eos_token_id).any(dim=1)
|
| 700 |
-
self.stats["val/model_contain_eos_token"].append(gather_mean(model_eos.float()))
|
| 701 |
-
self.stats["val/ref_contain_eos_token"].append(gather_mean(ref_eos.float()))
|
| 702 |
-
|
| 703 |
-
# Log alpha and beta
|
| 704 |
-
self.stats["alpha"].append(self.alpha)
|
| 705 |
-
self.stats["beta"].append(self.beta)
|
| 706 |
-
|
| 707 |
-
def training_step(
|
| 708 |
-
self, model: nn.Module, inputs: dict[str, Union[torch.Tensor, Any]], num_items_in_batch: Optional[int] = None
|
| 709 |
-
) -> torch.Tensor:
|
| 710 |
-
model.train()
|
| 711 |
-
|
| 712 |
-
# Apply chat template and tokenize the input
|
| 713 |
-
batch_size = len(next(iter(inputs.values())))
|
| 714 |
-
prompts = inputs["prompt"]
|
| 715 |
-
inputs = [{k: v[i] for k, v in inputs.items()} for i in range(batch_size)]
|
| 716 |
-
inputs = [maybe_apply_chat_template(x, self.processing_class) for x in inputs]
|
| 717 |
-
inputs = [self.tokenize_row(x, self.model.config.is_encoder_decoder, self.processing_class) for x in inputs]
|
| 718 |
-
inputs = self.data_collator(inputs)
|
| 719 |
-
|
| 720 |
-
# need the prompt_ only
|
| 721 |
-
inputs = self._prepare_inputs(inputs)
|
| 722 |
-
context_length = inputs["prompt_input_ids"].shape[1]
|
| 723 |
-
prompts = {
|
| 724 |
-
"input_ids": inputs["prompt_input_ids"],
|
| 725 |
-
"attention_mask": inputs["prompt_attention_mask"],
|
| 726 |
-
"raw": prompts,
|
| 727 |
-
}
|
| 728 |
-
del inputs
|
| 729 |
-
|
| 730 |
-
# Sample completions from both the model and the reference model
|
| 731 |
-
model_output, ref_output = self._generate_completions(prompts, model)
|
| 732 |
-
|
| 733 |
-
# Process model completions
|
| 734 |
-
model_data, ref_data = self._process_completions(model_output, ref_output, prompts)
|
| 735 |
-
|
| 736 |
-
# Compute rewards
|
| 737 |
-
if self.reward_model is not None:
|
| 738 |
-
model_scores, ref_scores = self._compute_rewards(model_data, ref_data, context_length)
|
| 739 |
-
chosen_mask = model_scores >= ref_scores
|
| 740 |
-
else:
|
| 741 |
-
model_scores, ref_scores = None, None
|
| 742 |
-
chosen_mask = self._compute_judge(model_data, ref_data, context_length)
|
| 743 |
-
|
| 744 |
-
# Compute logprobs
|
| 745 |
-
model_logprobs_model_data, model_logprobs_ref_data, ref_logprobs_ref_data, ref_logprobs_model_data = (
|
| 746 |
-
self._compute_logprobs(model, model_data, ref_data, context_length)
|
| 747 |
-
)
|
| 748 |
-
|
| 749 |
-
# Compute loss
|
| 750 |
-
loss, dpo_losses, xpo_losses = self._compute_losses(
|
| 751 |
-
model_logprobs_model_data,
|
| 752 |
-
model_logprobs_ref_data,
|
| 753 |
-
ref_logprobs_ref_data,
|
| 754 |
-
ref_logprobs_model_data,
|
| 755 |
-
chosen_mask,
|
| 756 |
-
)
|
| 757 |
-
|
| 758 |
-
# Log everything
|
| 759 |
-
self._log_statistics(
|
| 760 |
-
model_data,
|
| 761 |
-
ref_data,
|
| 762 |
-
model_logprobs_model_data.detach(),
|
| 763 |
-
model_logprobs_ref_data.detach(),
|
| 764 |
-
ref_logprobs_ref_data,
|
| 765 |
-
ref_logprobs_model_data,
|
| 766 |
-
chosen_mask,
|
| 767 |
-
dpo_losses.detach(),
|
| 768 |
-
xpo_losses.detach(),
|
| 769 |
-
context_length,
|
| 770 |
-
model_scores,
|
| 771 |
-
ref_scores,
|
| 772 |
-
)
|
| 773 |
-
|
| 774 |
-
if (
|
| 775 |
-
self.args.torch_empty_cache_steps is not None
|
| 776 |
-
and self.state.global_step % self.args.torch_empty_cache_steps == 0
|
| 777 |
-
):
|
| 778 |
-
empty_cache()
|
| 779 |
-
|
| 780 |
-
kwargs = {}
|
| 781 |
-
# For LOMO optimizers you need to explicitly use the learning rate
|
| 782 |
-
if self.args.optim in [OptimizerNames.LOMO, OptimizerNames.ADALOMO]:
|
| 783 |
-
kwargs["learning_rate"] = self._get_learning_rate()
|
| 784 |
-
|
| 785 |
-
if self.args.n_gpu > 1:
|
| 786 |
-
loss = loss.mean() # mean() to average on multi-gpu parallel training
|
| 787 |
-
|
| 788 |
-
if self.use_apex:
|
| 789 |
-
with amp.scale_loss(loss, self.optimizer) as scaled_loss:
|
| 790 |
-
scaled_loss.backward()
|
| 791 |
-
else:
|
| 792 |
-
self.accelerator.backward(loss, **kwargs)
|
| 793 |
-
|
| 794 |
-
return loss.detach() / self.args.gradient_accumulation_steps
|
| 795 |
-
|
| 796 |
-
def create_model_card(
|
| 797 |
-
self,
|
| 798 |
-
model_name: Optional[str] = None,
|
| 799 |
-
dataset_name: Optional[str] = None,
|
| 800 |
-
tags: Union[str, list[str], None] = None,
|
| 801 |
-
):
|
| 802 |
-
"""
|
| 803 |
-
Creates a draft of a model card using the information available to the `Trainer`.
|
| 804 |
-
|
| 805 |
-
Args:
|
| 806 |
-
model_name (`str` or `None`, *optional*, defaults to `None`):
|
| 807 |
-
Name of the model.
|
| 808 |
-
dataset_name (`str` or `None`, *optional*, defaults to `None`):
|
| 809 |
-
Name of the dataset used for training.
|
| 810 |
-
tags (`str`, `list[str]` or `None`, *optional*, defaults to `None`):
|
| 811 |
-
Tags to be associated with the model card.
|
| 812 |
-
"""
|
| 813 |
-
if not self.is_world_process_zero():
|
| 814 |
-
return
|
| 815 |
-
|
| 816 |
-
if hasattr(self.model.config, "_name_or_path") and not os.path.isdir(self.model.config._name_or_path):
|
| 817 |
-
base_model = self.model.config._name_or_path
|
| 818 |
-
else:
|
| 819 |
-
base_model = None
|
| 820 |
-
|
| 821 |
-
tags = tags or []
|
| 822 |
-
if isinstance(tags, str):
|
| 823 |
-
tags = [tags]
|
| 824 |
-
|
| 825 |
-
if hasattr(self.model.config, "unsloth_version"):
|
| 826 |
-
tags.append("unsloth")
|
| 827 |
-
|
| 828 |
-
citation = textwrap.dedent("""\
|
| 829 |
-
@article{jung2024binary,
|
| 830 |
-
title = {{Exploratory Preference Optimization: Harnessing Implicit Q*-Approximation for Sample-Efficient RLHF}},
|
| 831 |
-
author = {Tengyang Xie and Dylan J. Foster and Akshay Krishnamurthy and Corby Rosset and Ahmed Awadallah and Alexander Rakhlin},
|
| 832 |
-
year = 2024,
|
| 833 |
-
eprint = {arXiv:2405.21046}
|
| 834 |
-
}""")
|
| 835 |
-
|
| 836 |
-
model_card = generate_model_card(
|
| 837 |
-
base_model=base_model,
|
| 838 |
-
model_name=model_name,
|
| 839 |
-
hub_model_id=self.hub_model_id,
|
| 840 |
-
dataset_name=dataset_name,
|
| 841 |
-
tags=tags,
|
| 842 |
-
wandb_url=wandb.run.get_url() if is_wandb_available() and wandb.run is not None else None,
|
| 843 |
-
comet_url=get_comet_experiment_url(),
|
| 844 |
-
trainer_name="XPO",
|
| 845 |
-
trainer_citation=citation,
|
| 846 |
-
paper_title="Exploratory Preference Optimization: Harnessing Implicit Q*-Approximation for Sample-Efficient RLHF",
|
| 847 |
-
paper_id="2405.21046",
|
| 848 |
-
)
|
| 849 |
-
|
| 850 |
-
model_card.save(os.path.join(self.args.output_dir, "README.md"))
|
| 851 |
-
class UnslothXPOTrainer(_UnslothXPOTrainer):
|
| 852 |
-
"""
|
| 853 |
-
|
| 854 |
-
Initialize XPOTrainer as a subclass of [`OnlineDPOConfig`].
|
| 855 |
-
|
| 856 |
-
Args:
|
| 857 |
-
model (`transformers.PreTrainedModel`):
|
| 858 |
-
The model to train, preferably an `AutoModelForCausalLM`.
|
| 859 |
-
ref_model (`PreTrainedModelWrapper`):
|
| 860 |
-
Hugging Face transformer model with a casual language modelling head. Used for implicit reward computation and loss. If no
|
| 861 |
-
reference model is provided, the trainer will create a reference model with the same architecture as the model to be optimized.
|
| 862 |
-
reward_model (`transformers.PreTrainedModel`):
|
| 863 |
-
The reward model to score completions with, preferably an `AutoModelForSequenceClassification`.
|
| 864 |
-
judge (`BasePairwiseJudge`):
|
| 865 |
-
The judge to use for pairwise comparison of model completions.
|
| 866 |
-
args (`XPOConfig`):
|
| 867 |
-
The XPO config arguments to use for training.
|
| 868 |
-
data_collator (`transformers.DataCollator`):
|
| 869 |
-
The data collator to use for training. If None is specified, the default data collator (`DPODataCollatorWithPadding`) will be used
|
| 870 |
-
which will pad the sequences to the maximum length of the sequences in the batch, given a dataset of paired sequences.
|
| 871 |
-
train_dataset (`datasets.Dataset`):
|
| 872 |
-
The dataset to use for training.
|
| 873 |
-
eval_dataset (`datasets.Dataset`):
|
| 874 |
-
The dataset to use for evaluation.
|
| 875 |
-
processing_class (`PreTrainedTokenizerBase` or `BaseImageProcessor` or `FeatureExtractionMixin` or `ProcessorMixin`, *optional*):
|
| 876 |
-
Processing class used to process the data. If provided, will be used to automatically process the inputs
|
| 877 |
-
for the model, and it will be saved along the model to make it easier to rerun an interrupted training or
|
| 878 |
-
reuse the fine-tuned model.
|
| 879 |
-
peft_config (`dict`):
|
| 880 |
-
The peft config to use for training.
|
| 881 |
-
compute_metrics (`Callable[[EvalPrediction], dict]`, *optional*):
|
| 882 |
-
The function to use to compute the metrics. Must take a `EvalPrediction` and return
|
| 883 |
-
a dictionary string to metric values.
|
| 884 |
-
callbacks (`list[transformers.TrainerCallback]`):
|
| 885 |
-
The callbacks to use for training.
|
| 886 |
-
optimizers (`tuple[torch.optim.Optimizer, torch.optim.lr_scheduler.LambdaLR]`):
|
| 887 |
-
The optimizer and scheduler to use for training.
|
| 888 |
-
preprocess_logits_for_metrics (`Callable[[torch.Tensor, torch.Tensor], torch.Tensor]`):
|
| 889 |
-
The function to use to preprocess the logits before computing the metrics.
|
| 890 |
-
|
| 891 |
-
"""
|
| 892 |
-
def __init__(
|
| 893 |
-
self,
|
| 894 |
-
model = None,
|
| 895 |
-
ref_model = None,
|
| 896 |
-
reward_model = None,
|
| 897 |
-
judge = None,
|
| 898 |
-
args = None,
|
| 899 |
-
data_collator = None,
|
| 900 |
-
train_dataset = None,
|
| 901 |
-
eval_dataset = None,
|
| 902 |
-
processing_class = None,
|
| 903 |
-
peft_config = None,
|
| 904 |
-
compute_metrics = None,
|
| 905 |
-
callbacks = None,
|
| 906 |
-
preprocess_logits_for_metrics = None,
|
| 907 |
-
**kwargs
|
| 908 |
-
):
|
| 909 |
-
if args is None: args = UnslothXPOConfig()
|
| 910 |
-
use_bf16 = getattr(args, 'bf16', False)
|
| 911 |
-
if type(use_bf16) is not bool: use_bf16 = False
|
| 912 |
-
use_fp16 = getattr(args, 'fp16', False)
|
| 913 |
-
if type(use_fp16) is not bool: use_fp16 = False
|
| 914 |
-
force_float32 = False
|
| 915 |
-
if os.environ.get('UNSLOTH_FORCE_FLOAT32', '0') == '1':
|
| 916 |
-
print('Unsloth: Switching to float32 training since model cannot work with float16')
|
| 917 |
-
force_float32 = True
|
| 918 |
-
mixed_precision_dtype = os.environ.get('UNSLOTH_MIXED_PRECISION', 'float32')
|
| 919 |
-
dtype = getattr(model.config, 'torch_dtype', None)
|
| 920 |
-
if dtype is None: dtype = model.get_input_embeddings().dtype
|
| 921 |
-
from unsloth_zoo.utils import _get_dtype
|
| 922 |
-
dtype = _get_dtype(dtype)
|
| 923 |
-
float16 = dtype == torch.float16
|
| 924 |
-
if not force_float32 and (float16 and use_bf16): raise TypeError('Unsloth: Model is in float16 precision but you want to use bfloat16 precision. Set fp16 to `True` and bf16 to `False`')
|
| 925 |
-
if not force_float32 and (not float16 and use_fp16): raise TypeError('Unsloth: Model is in bfloat16 precision but you want to use float16 precision. Set fp16 to `False` and bf16 to `True`')
|
| 926 |
-
if force_float32:
|
| 927 |
-
args.fp16 = False
|
| 928 |
-
args.bf16 = False
|
| 929 |
-
os.environ['ACCELERATE_MIXED_PRECISION'] = 'no'
|
| 930 |
-
elif (not use_bf16 and not use_fp16) and mixed_precision_dtype == 'float32':
|
| 931 |
-
args.fp16 = float16
|
| 932 |
-
args.bf16 = not float16
|
| 933 |
-
os.environ['ACCELERATE_MIXED_PRECISION'] = 'fp16' if float16 else 'bf16'
|
| 934 |
-
if getattr(args, 'eval_dataset', None) is not None and getattr(args, 'eval_strategy', 'no') == 'no':
|
| 935 |
-
args.eval_strategy = 'steps'
|
| 936 |
-
if getattr(args, 'eval_steps', None) is None: args.eval_steps = 0.1
|
| 937 |
-
ga_steps = getattr(args, 'gradient_accumulation_steps', None)
|
| 938 |
-
if ga_steps is not None and ga_steps > 1:
|
| 939 |
-
from transformers import __version__ as transformers_version
|
| 940 |
-
if Version(transformers_version) <= Version('4.45.2'):
|
| 941 |
-
print('**** Unsloth: Please use our fixed gradient_accumulation_steps by updating transformers, TRL and Unsloth!\n'
|
| 942 |
-
'`pip install --upgrade --no-cache-dir --force-reinstall --no-deps unsloth transformers trl unsloth_zoo`')
|
| 943 |
-
if getattr(args, 'eval_strategy', 'no') != 'no':
|
| 944 |
-
eval_bsz = getattr(args, 'per_device_eval_batch_size', 8)
|
| 945 |
-
if eval_bsz == 8 and args.per_device_train_batch_size < eval_bsz: args.per_device_eval_batch_size = args.per_device_train_batch_size
|
| 946 |
-
if getattr(args, 'eval_accumulation_steps', None) is None and ga_steps is not None: args.eval_accumulation_steps = ga_steps
|
| 947 |
-
fp16_full_eval = getattr(args, 'fp16_full_eval', False)
|
| 948 |
-
if type(fp16_full_eval) is not bool: fp16_full_eval = False
|
| 949 |
-
bf16_full_eval = getattr(args, 'bf16_full_eval', False)
|
| 950 |
-
if type(bf16_full_eval) is not bool: bf16_full_eval = False
|
| 951 |
-
if args.fp16 and bf16_full_eval: args.bf16_full_eval = False; args.fp16_full_eval = True
|
| 952 |
-
if args.bf16 and fp16_full_eval: args.bf16_full_eval = True; args.fp16_full_eval = False
|
| 953 |
-
if force_float32:
|
| 954 |
-
args.bf16_full_eval = False
|
| 955 |
-
args.fp16_full_eval = False
|
| 956 |
-
elif os.environ.get('UNSLOTH_MIXED_PRECISION', 'float32') == 'bfloat16':
|
| 957 |
-
args.bf16_full_eval = True
|
| 958 |
-
args.fp16_full_eval = False
|
| 959 |
-
elif not bf16_full_eval and not fp16_full_eval:
|
| 960 |
-
args.bf16_full_eval = args.bf16
|
| 961 |
-
args.fp16_full_eval = args.fp16
|
| 962 |
-
_output_logits = False
|
| 963 |
-
if locals().get('compute_metrics', None) is not None: _output_logits = True
|
| 964 |
-
if locals().get('preprocess_logits_for_metrics', None) is not None: _output_logits = True
|
| 965 |
-
if _output_logits:
|
| 966 |
-
os.environ['UNSLOTH_RETURN_LOGITS'] = '1'
|
| 967 |
-
if 'max_seq_length' not in locals() and not hasattr(args, 'max_seq_length'):
|
| 968 |
-
pass
|
| 969 |
-
else:
|
| 970 |
-
model_max_seq_length = getattr(model, 'max_seq_length', None)
|
| 971 |
-
args_max_seq_length = getattr(args, 'max_seq_length', None)
|
| 972 |
-
if args_max_seq_length is None and model_max_seq_length is not None:
|
| 973 |
-
max_seq_length = model.max_seq_length
|
| 974 |
-
if hasattr(args, 'max_seq_length'): args.max_seq_length = max_seq_length
|
| 975 |
-
if model is not None and hasattr(model, 'for_training'):
|
| 976 |
-
model.for_training()
|
| 977 |
-
if 'tokenizer' in locals() and hasattr(tokenizer, 'padding_side'): tokenizer.padding_side = 'right'
|
| 978 |
-
if 'processing_class' in locals():
|
| 979 |
-
if hasattr(processing_class, 'padding_side'): processing_class.padding_side = 'right'
|
| 980 |
-
if hasattr(processing_class, 'tokenizer') and hasattr(processing_class.tokenizer, 'padding_side'): processing_class.tokenizer.padding_side = 'right'
|
| 981 |
-
__tokenizer = processing_class if 'processing_class' in locals() else tokenizer
|
| 982 |
-
from unsloth_zoo.vision_utils import UnslothVisionDataCollator
|
| 983 |
-
if not isinstance(data_collator, UnslothVisionDataCollator):
|
| 984 |
-
if isinstance(data_collator, DataCollatorForSeq2Seq) and 'labels' not in train_dataset.column_names:
|
| 985 |
-
data_collator = TransformersDataCollatorForLanguageModeling(__tokenizer, mlm = False, mlm_probability = 0.0)
|
| 986 |
-
elif isinstance(data_collator, TransformersDataCollatorForLanguageModeling) and 'labels' in train_dataset.column_names:
|
| 987 |
-
data_collator = DataCollatorForSeq2Seq(__tokenizer)
|
| 988 |
-
else:
|
| 989 |
-
if hasattr(args, 'remove_unused_columns'): args.remove_unused_columns = False
|
| 990 |
-
if hasattr(args, 'dataset_text_field'): args.dataset_text_field = ''
|
| 991 |
-
if hasattr(args, 'dataset_kwargs'): args.dataset_kwargs = {'skip_prepare_dataset': True}
|
| 992 |
-
if not isinstance(data_collator, UnslothVisionDataCollator):
|
| 993 |
-
if not hasattr(__tokenizer, 'pad') and hasattr(__tokenizer, 'tokenizer'):
|
| 994 |
-
if isinstance(data_collator, DataCollatorForSeq2Seq):
|
| 995 |
-
data_collator = DataCollatorForSeq2Seq(__tokenizer.tokenizer)
|
| 996 |
-
else:
|
| 997 |
-
data_collator = TransformersDataCollatorForLanguageModeling(__tokenizer.tokenizer, mlm = False, mlm_probability = 0.0)
|
| 998 |
-
other_metrics = []
|
| 999 |
-
|
| 1000 |
-
from unsloth_zoo.logging_utils import PatchRLStatistics
|
| 1001 |
-
PatchRLStatistics('xpo_trainer', other_metrics)
|
| 1002 |
-
|
| 1003 |
-
super().__init__(
|
| 1004 |
-
model = model,
|
| 1005 |
-
ref_model = ref_model,
|
| 1006 |
-
reward_model = reward_model,
|
| 1007 |
-
judge = judge,
|
| 1008 |
-
args = args,
|
| 1009 |
-
data_collator = data_collator,
|
| 1010 |
-
train_dataset = train_dataset,
|
| 1011 |
-
eval_dataset = eval_dataset,
|
| 1012 |
-
processing_class = processing_class,
|
| 1013 |
-
peft_config = peft_config,
|
| 1014 |
-
compute_metrics = compute_metrics,
|
| 1015 |
-
callbacks = callbacks,
|
| 1016 |
-
preprocess_logits_for_metrics = preprocess_logits_for_metrics,**kwargs)
|
| 1017 |
-
if hasattr(self, 'neftune_hook_handle'):
|
| 1018 |
-
self.neftune_hook_handle.remove()
|
| 1019 |
-
if hasattr(self, 'neftune_hook_handle'): del self.neftune_hook_handle
|
| 1020 |
-
if getattr(args, 'neftune_noise_alpha', None) is not None:
|
| 1021 |
-
model.get_input_embeddings().neftune_noise_alpha = self.neftune_noise_alpha
|
| 1022 |
-
pass
|
| 1023 |
-
|
| 1024 |
-
pass
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
test_run_uploads/__pycache__/UnslothAlignPropTrainer.cpython-311.pyc
DELETED
|
Binary file (33.9 kB)
|
|
|
test_run_uploads/__pycache__/UnslothBCOTrainer.cpython-311.pyc
DELETED
|
Binary file (92.8 kB)
|
|
|
test_run_uploads/__pycache__/UnslothCPOTrainer.cpython-311.pyc
DELETED
|
Binary file (76.7 kB)
|
|
|
test_run_uploads/__pycache__/UnslothDDPOTrainer.cpython-311.pyc
DELETED
|
Binary file (46.5 kB)
|
|
|
test_run_uploads/__pycache__/UnslothDPOTrainer.cpython-311.pyc
DELETED
|
@@ -1,3 +0,0 @@
|
|
| 1 |
-
version https://git-lfs.github.com/spec/v1
|
| 2 |
-
oid sha256:8c20178043e78b3057a4eec21c41cb84e543aa7e03cab7996894ab8e7904e768
|
| 3 |
-
size 104591
|
|
|
|
|
|
|
|
|
|
|
|
test_run_uploads/__pycache__/UnslothGKDTrainer.cpython-311.pyc
DELETED
|
Binary file (39.6 kB)
|
|
|
test_run_uploads/__pycache__/UnslothGRPOTrainer.cpython-311.pyc
DELETED
|
Binary file (97.6 kB)
|
|
|
test_run_uploads/__pycache__/UnslothKTOTrainer.cpython-311.pyc
DELETED
|
Binary file (88.7 kB)
|
|
|
test_run_uploads/__pycache__/UnslothNashMDTrainer.cpython-311.pyc
DELETED
|
Binary file (49 kB)
|
|
|
test_run_uploads/__pycache__/UnslothORPOTrainer.cpython-311.pyc
DELETED
|
Binary file (76.7 kB)
|
|
|
test_run_uploads/__pycache__/UnslothOnlineDPOTrainer.cpython-311.pyc
DELETED
|
Binary file (68.8 kB)
|
|
|
test_run_uploads/__pycache__/UnslothPPOTrainer.cpython-311.pyc
DELETED
|
Binary file (64.4 kB)
|
|
|
test_run_uploads/__pycache__/UnslothPRMTrainer.cpython-311.pyc
DELETED
|
Binary file (37.7 kB)
|
|
|
test_run_uploads/__pycache__/UnslothRLOOTrainer.cpython-311.pyc
DELETED
|
Binary file (55.6 kB)
|
|
|
test_run_uploads/__pycache__/UnslothRewardTrainer.cpython-311.pyc
DELETED
|
Binary file (40.2 kB)
|
|
|
test_run_uploads/__pycache__/UnslothSFTTrainer.cpython-311.pyc
DELETED
|
Binary file (52.4 kB)
|
|
|
test_run_uploads/__pycache__/UnslothXPOTrainer.cpython-311.pyc
DELETED
|
Binary file (51.6 kB)
|
|
|
test_run_uploads/checkpoint-50/README.md
DELETED
|
@@ -1,210 +0,0 @@
|
|
| 1 |
-
---
|
| 2 |
-
base_model: mistralai/Ministral-8B-Instruct-2410
|
| 3 |
-
library_name: peft
|
| 4 |
-
pipeline_tag: text-generation
|
| 5 |
-
tags:
|
| 6 |
-
- base_model:adapter:mistralai/Ministral-8B-Instruct-2410
|
| 7 |
-
- lora
|
| 8 |
-
- sft
|
| 9 |
-
- transformers
|
| 10 |
-
- trl
|
| 11 |
-
- unsloth
|
| 12 |
-
---
|
| 13 |
-
|
| 14 |
-
# Model Card for Model ID
|
| 15 |
-
|
| 16 |
-
<!-- Provide a quick summary of what the model is/does. -->
|
| 17 |
-
|
| 18 |
-
|
| 19 |
-
|
| 20 |
-
## Model Details
|
| 21 |
-
|
| 22 |
-
### Model Description
|
| 23 |
-
|
| 24 |
-
<!-- Provide a longer summary of what this model is. -->
|
| 25 |
-
|
| 26 |
-
|
| 27 |
-
|
| 28 |
-
- **Developed by:** [More Information Needed]
|
| 29 |
-
- **Funded by [optional]:** [More Information Needed]
|
| 30 |
-
- **Shared by [optional]:** [More Information Needed]
|
| 31 |
-
- **Model type:** [More Information Needed]
|
| 32 |
-
- **Language(s) (NLP):** [More Information Needed]
|
| 33 |
-
- **License:** [More Information Needed]
|
| 34 |
-
- **Finetuned from model [optional]:** [More Information Needed]
|
| 35 |
-
|
| 36 |
-
### Model Sources [optional]
|
| 37 |
-
|
| 38 |
-
<!-- Provide the basic links for the model. -->
|
| 39 |
-
|
| 40 |
-
- **Repository:** [More Information Needed]
|
| 41 |
-
- **Paper [optional]:** [More Information Needed]
|
| 42 |
-
- **Demo [optional]:** [More Information Needed]
|
| 43 |
-
|
| 44 |
-
## Uses
|
| 45 |
-
|
| 46 |
-
<!-- Address questions around how the model is intended to be used, including the foreseeable users of the model and those affected by the model. -->
|
| 47 |
-
|
| 48 |
-
### Direct Use
|
| 49 |
-
|
| 50 |
-
<!-- This section is for the model use without fine-tuning or plugging into a larger ecosystem/app. -->
|
| 51 |
-
|
| 52 |
-
[More Information Needed]
|
| 53 |
-
|
| 54 |
-
### Downstream Use [optional]
|
| 55 |
-
|
| 56 |
-
<!-- This section is for the model use when fine-tuned for a task, or when plugged into a larger ecosystem/app -->
|
| 57 |
-
|
| 58 |
-
[More Information Needed]
|
| 59 |
-
|
| 60 |
-
### Out-of-Scope Use
|
| 61 |
-
|
| 62 |
-
<!-- This section addresses misuse, malicious use, and uses that the model will not work well for. -->
|
| 63 |
-
|
| 64 |
-
[More Information Needed]
|
| 65 |
-
|
| 66 |
-
## Bias, Risks, and Limitations
|
| 67 |
-
|
| 68 |
-
<!-- This section is meant to convey both technical and sociotechnical limitations. -->
|
| 69 |
-
|
| 70 |
-
[More Information Needed]
|
| 71 |
-
|
| 72 |
-
### Recommendations
|
| 73 |
-
|
| 74 |
-
<!-- This section is meant to convey recommendations with respect to the bias, risk, and technical limitations. -->
|
| 75 |
-
|
| 76 |
-
Users (both direct and downstream) should be made aware of the risks, biases and limitations of the model. More information needed for further recommendations.
|
| 77 |
-
|
| 78 |
-
## How to Get Started with the Model
|
| 79 |
-
|
| 80 |
-
Use the code below to get started with the model.
|
| 81 |
-
|
| 82 |
-
[More Information Needed]
|
| 83 |
-
|
| 84 |
-
## Training Details
|
| 85 |
-
|
| 86 |
-
### Training Data
|
| 87 |
-
|
| 88 |
-
<!-- This should link to a Dataset Card, perhaps with a short stub of information on what the training data is all about as well as documentation related to data pre-processing or additional filtering. -->
|
| 89 |
-
|
| 90 |
-
[More Information Needed]
|
| 91 |
-
|
| 92 |
-
### Training Procedure
|
| 93 |
-
|
| 94 |
-
<!-- This relates heavily to the Technical Specifications. Content here should link to that section when it is relevant to the training procedure. -->
|
| 95 |
-
|
| 96 |
-
#### Preprocessing [optional]
|
| 97 |
-
|
| 98 |
-
[More Information Needed]
|
| 99 |
-
|
| 100 |
-
|
| 101 |
-
#### Training Hyperparameters
|
| 102 |
-
|
| 103 |
-
- **Training regime:** [More Information Needed] <!--fp32, fp16 mixed precision, bf16 mixed precision, bf16 non-mixed precision, fp16 non-mixed precision, fp8 mixed precision -->
|
| 104 |
-
|
| 105 |
-
#### Speeds, Sizes, Times [optional]
|
| 106 |
-
|
| 107 |
-
<!-- This section provides information about throughput, start/end time, checkpoint size if relevant, etc. -->
|
| 108 |
-
|
| 109 |
-
[More Information Needed]
|
| 110 |
-
|
| 111 |
-
## Evaluation
|
| 112 |
-
|
| 113 |
-
<!-- This section describes the evaluation protocols and provides the results. -->
|
| 114 |
-
|
| 115 |
-
### Testing Data, Factors & Metrics
|
| 116 |
-
|
| 117 |
-
#### Testing Data
|
| 118 |
-
|
| 119 |
-
<!-- This should link to a Dataset Card if possible. -->
|
| 120 |
-
|
| 121 |
-
[More Information Needed]
|
| 122 |
-
|
| 123 |
-
#### Factors
|
| 124 |
-
|
| 125 |
-
<!-- These are the things the evaluation is disaggregating by, e.g., subpopulations or domains. -->
|
| 126 |
-
|
| 127 |
-
[More Information Needed]
|
| 128 |
-
|
| 129 |
-
#### Metrics
|
| 130 |
-
|
| 131 |
-
<!-- These are the evaluation metrics being used, ideally with a description of why. -->
|
| 132 |
-
|
| 133 |
-
[More Information Needed]
|
| 134 |
-
|
| 135 |
-
### Results
|
| 136 |
-
|
| 137 |
-
[More Information Needed]
|
| 138 |
-
|
| 139 |
-
#### Summary
|
| 140 |
-
|
| 141 |
-
|
| 142 |
-
|
| 143 |
-
## Model Examination [optional]
|
| 144 |
-
|
| 145 |
-
<!-- Relevant interpretability work for the model goes here -->
|
| 146 |
-
|
| 147 |
-
[More Information Needed]
|
| 148 |
-
|
| 149 |
-
## Environmental Impact
|
| 150 |
-
|
| 151 |
-
<!-- Total emissions (in grams of CO2eq) and additional considerations, such as electricity usage, go here. Edit the suggested text below accordingly -->
|
| 152 |
-
|
| 153 |
-
Carbon emissions can be estimated using the [Machine Learning Impact calculator](https://mlco2.github.io/impact#compute) presented in [Lacoste et al. (2019)](https://arxiv.org/abs/1910.09700).
|
| 154 |
-
|
| 155 |
-
- **Hardware Type:** [More Information Needed]
|
| 156 |
-
- **Hours used:** [More Information Needed]
|
| 157 |
-
- **Cloud Provider:** [More Information Needed]
|
| 158 |
-
- **Compute Region:** [More Information Needed]
|
| 159 |
-
- **Carbon Emitted:** [More Information Needed]
|
| 160 |
-
|
| 161 |
-
## Technical Specifications [optional]
|
| 162 |
-
|
| 163 |
-
### Model Architecture and Objective
|
| 164 |
-
|
| 165 |
-
[More Information Needed]
|
| 166 |
-
|
| 167 |
-
### Compute Infrastructure
|
| 168 |
-
|
| 169 |
-
[More Information Needed]
|
| 170 |
-
|
| 171 |
-
#### Hardware
|
| 172 |
-
|
| 173 |
-
[More Information Needed]
|
| 174 |
-
|
| 175 |
-
#### Software
|
| 176 |
-
|
| 177 |
-
[More Information Needed]
|
| 178 |
-
|
| 179 |
-
## Citation [optional]
|
| 180 |
-
|
| 181 |
-
<!-- If there is a paper or blog post introducing the model, the APA and Bibtex information for that should go in this section. -->
|
| 182 |
-
|
| 183 |
-
**BibTeX:**
|
| 184 |
-
|
| 185 |
-
[More Information Needed]
|
| 186 |
-
|
| 187 |
-
**APA:**
|
| 188 |
-
|
| 189 |
-
[More Information Needed]
|
| 190 |
-
|
| 191 |
-
## Glossary [optional]
|
| 192 |
-
|
| 193 |
-
<!-- If relevant, include terms and calculations in this section that can help readers understand the model or model card. -->
|
| 194 |
-
|
| 195 |
-
[More Information Needed]
|
| 196 |
-
|
| 197 |
-
## More Information [optional]
|
| 198 |
-
|
| 199 |
-
[More Information Needed]
|
| 200 |
-
|
| 201 |
-
## Model Card Authors [optional]
|
| 202 |
-
|
| 203 |
-
[More Information Needed]
|
| 204 |
-
|
| 205 |
-
## Model Card Contact
|
| 206 |
-
|
| 207 |
-
[More Information Needed]
|
| 208 |
-
### Framework versions
|
| 209 |
-
|
| 210 |
-
- PEFT 0.16.0
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
test_run_uploads/checkpoint-50/adapter_config.json
DELETED
|
@@ -1,41 +0,0 @@
|
|
| 1 |
-
{
|
| 2 |
-
"alpha_pattern": {},
|
| 3 |
-
"auto_mapping": null,
|
| 4 |
-
"base_model_name_or_path": "mistralai/Ministral-8B-Instruct-2410",
|
| 5 |
-
"bias": "none",
|
| 6 |
-
"corda_config": null,
|
| 7 |
-
"eva_config": null,
|
| 8 |
-
"exclude_modules": null,
|
| 9 |
-
"fan_in_fan_out": false,
|
| 10 |
-
"inference_mode": true,
|
| 11 |
-
"init_lora_weights": true,
|
| 12 |
-
"layer_replication": null,
|
| 13 |
-
"layers_pattern": null,
|
| 14 |
-
"layers_to_transform": null,
|
| 15 |
-
"loftq_config": {},
|
| 16 |
-
"lora_alpha": 64,
|
| 17 |
-
"lora_bias": false,
|
| 18 |
-
"lora_dropout": 0,
|
| 19 |
-
"megatron_config": null,
|
| 20 |
-
"megatron_core": "megatron.core",
|
| 21 |
-
"modules_to_save": null,
|
| 22 |
-
"peft_type": "LORA",
|
| 23 |
-
"qalora_group_size": 16,
|
| 24 |
-
"r": 32,
|
| 25 |
-
"rank_pattern": {},
|
| 26 |
-
"revision": null,
|
| 27 |
-
"target_modules": [
|
| 28 |
-
"up_proj",
|
| 29 |
-
"gate_proj",
|
| 30 |
-
"q_proj",
|
| 31 |
-
"o_proj",
|
| 32 |
-
"v_proj",
|
| 33 |
-
"down_proj",
|
| 34 |
-
"k_proj"
|
| 35 |
-
],
|
| 36 |
-
"task_type": "CAUSAL_LM",
|
| 37 |
-
"trainable_token_indices": null,
|
| 38 |
-
"use_dora": false,
|
| 39 |
-
"use_qalora": false,
|
| 40 |
-
"use_rslora": false
|
| 41 |
-
}
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
test_run_uploads/checkpoint-50/adapter_model.safetensors
DELETED
|
@@ -1,3 +0,0 @@
|
|
| 1 |
-
version https://git-lfs.github.com/spec/v1
|
| 2 |
-
oid sha256:7c22732e7777cc816d7e7503316cbac9b3806566322e1f6bab5d429ea8766f00
|
| 3 |
-
size 349243752
|
|
|
|
|
|
|
|
|
|
|
|
test_run_uploads/checkpoint-50/chat_template.jinja
DELETED
|
@@ -1 +0,0 @@
|
|
| 1 |
-
{{ bos_token }}{% if messages[0]['role'] == 'system' %}{% if messages[1]['role'] == 'user' %}{{ '[INST] ' + messages[0]['content'] + ' ' + messages[1]['content'] + ' [/INST]' }}{% set loop_messages = messages[2:] %}{% else %}{{ '[INST] ' + messages[0]['content'] + ' [/INST]' }}{% set loop_messages = messages[1:] %}{% endif %}{% else %}{% set loop_messages = messages %}{% endif %}{% for message in loop_messages %}{% if message['role'] == 'user' %}{{ '[INST] ' + message['content'] + ' [/INST]' }}{% elif message['role'] == 'assistant' %}{{ message['content'] + eos_token }}{% else %}{{ raise_exception('Only user and assistant roles are supported!') }}{% endif %}{% endfor %}
|
|
|
|
|
|
test_run_uploads/checkpoint-50/optimizer.pt
DELETED
|
@@ -1,3 +0,0 @@
|
|
| 1 |
-
version https://git-lfs.github.com/spec/v1
|
| 2 |
-
oid sha256:9eb84082a1889da4e199d66b7cabd1c0dbbee3a7097bc5f7aebb331e4786a6d6
|
| 3 |
-
size 177918917
|
|
|
|
|
|
|
|
|
|
|
|
test_run_uploads/checkpoint-50/rng_state.pth
DELETED
|
@@ -1,3 +0,0 @@
|
|
| 1 |
-
version https://git-lfs.github.com/spec/v1
|
| 2 |
-
oid sha256:181c5f0270cf39930062ddfa3767a2481d0c360f120b11f8e25dbf533a1cdaba
|
| 3 |
-
size 14645
|
|
|
|
|
|
|
|
|
|
|
|
test_run_uploads/checkpoint-50/scaler.pt
DELETED
|
@@ -1,3 +0,0 @@
|
|
| 1 |
-
version https://git-lfs.github.com/spec/v1
|
| 2 |
-
oid sha256:5cd0e9d505fbc3f97feb166d29026132bdf14eb3e5c7ff77beebc303ee666f96
|
| 3 |
-
size 1383
|
|
|
|
|
|
|
|
|
|
|
|
test_run_uploads/checkpoint-50/scheduler.pt
DELETED
|
@@ -1,3 +0,0 @@
|
|
| 1 |
-
version https://git-lfs.github.com/spec/v1
|
| 2 |
-
oid sha256:3f43a6155628947732c83ac3165bbc211721c396e9e3b246bdecdaaf19583e1c
|
| 3 |
-
size 1465
|
|
|
|
|
|
|
|
|
|
|
|
test_run_uploads/checkpoint-50/special_tokens_map.json
DELETED
|
@@ -1,24 +0,0 @@
|
|
| 1 |
-
{
|
| 2 |
-
"bos_token": {
|
| 3 |
-
"content": "<s>",
|
| 4 |
-
"lstrip": false,
|
| 5 |
-
"normalized": false,
|
| 6 |
-
"rstrip": false,
|
| 7 |
-
"single_word": false
|
| 8 |
-
},
|
| 9 |
-
"eos_token": {
|
| 10 |
-
"content": "</s>",
|
| 11 |
-
"lstrip": false,
|
| 12 |
-
"normalized": false,
|
| 13 |
-
"rstrip": false,
|
| 14 |
-
"single_word": false
|
| 15 |
-
},
|
| 16 |
-
"pad_token": "<pad>",
|
| 17 |
-
"unk_token": {
|
| 18 |
-
"content": "<unk>",
|
| 19 |
-
"lstrip": false,
|
| 20 |
-
"normalized": false,
|
| 21 |
-
"rstrip": false,
|
| 22 |
-
"single_word": false
|
| 23 |
-
}
|
| 24 |
-
}
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
test_run_uploads/checkpoint-50/tokenizer.json
DELETED
|
@@ -1,3 +0,0 @@
|
|
| 1 |
-
version https://git-lfs.github.com/spec/v1
|
| 2 |
-
oid sha256:a7fc0f8e08693e6deb5bbb0cd3ab7431131567cc69bd3a67fd6da0e3c7ee58e4
|
| 3 |
-
size 17078391
|
|
|
|
|
|
|
|
|
|
|
|
test_run_uploads/checkpoint-50/tokenizer_config.json
DELETED
|
The diff for this file is too large to render.
See raw diff
|
|
|
test_run_uploads/checkpoint-50/trainer_state.json
DELETED
|
@@ -1,77 +0,0 @@
|
|
| 1 |
-
{
|
| 2 |
-
"best_global_step": null,
|
| 3 |
-
"best_metric": Infinity,
|
| 4 |
-
"best_model_checkpoint": null,
|
| 5 |
-
"epoch": 0.007038783698176955,
|
| 6 |
-
"eval_steps": 50,
|
| 7 |
-
"global_step": 50,
|
| 8 |
-
"is_hyper_param_search": false,
|
| 9 |
-
"is_local_process_zero": true,
|
| 10 |
-
"is_world_process_zero": true,
|
| 11 |
-
"log_history": [
|
| 12 |
-
{
|
| 13 |
-
"epoch": 0.001407756739635391,
|
| 14 |
-
"grad_norm": 4.217477321624756,
|
| 15 |
-
"learning_rate": 1.8e-06,
|
| 16 |
-
"loss": 1.9593,
|
| 17 |
-
"step": 10
|
| 18 |
-
},
|
| 19 |
-
{
|
| 20 |
-
"epoch": 0.002815513479270782,
|
| 21 |
-
"grad_norm": 6.792465686798096,
|
| 22 |
-
"learning_rate": 3.8e-06,
|
| 23 |
-
"loss": 1.8226,
|
| 24 |
-
"step": 20
|
| 25 |
-
},
|
| 26 |
-
{
|
| 27 |
-
"epoch": 0.004223270218906173,
|
| 28 |
-
"grad_norm": 3.987929344177246,
|
| 29 |
-
"learning_rate": 5.8e-06,
|
| 30 |
-
"loss": 1.5628,
|
| 31 |
-
"step": 30
|
| 32 |
-
},
|
| 33 |
-
{
|
| 34 |
-
"epoch": 0.005631026958541564,
|
| 35 |
-
"grad_norm": 3.203339099884033,
|
| 36 |
-
"learning_rate": 7.8e-06,
|
| 37 |
-
"loss": 1.2142,
|
| 38 |
-
"step": 40
|
| 39 |
-
},
|
| 40 |
-
{
|
| 41 |
-
"epoch": 0.007038783698176955,
|
| 42 |
-
"grad_norm": 4.646796226501465,
|
| 43 |
-
"learning_rate": 9.800000000000001e-06,
|
| 44 |
-
"loss": 0.8943,
|
| 45 |
-
"step": 50
|
| 46 |
-
},
|
| 47 |
-
{
|
| 48 |
-
"epoch": 0.007038783698176955,
|
| 49 |
-
"eval_loss": NaN,
|
| 50 |
-
"eval_runtime": 3184.6841,
|
| 51 |
-
"eval_samples_per_second": 1.093,
|
| 52 |
-
"eval_steps_per_second": 0.182,
|
| 53 |
-
"step": 50
|
| 54 |
-
}
|
| 55 |
-
],
|
| 56 |
-
"logging_steps": 10,
|
| 57 |
-
"max_steps": 90,
|
| 58 |
-
"num_input_tokens_seen": 0,
|
| 59 |
-
"num_train_epochs": 1,
|
| 60 |
-
"save_steps": 50,
|
| 61 |
-
"stateful_callbacks": {
|
| 62 |
-
"TrainerControl": {
|
| 63 |
-
"args": {
|
| 64 |
-
"should_epoch_stop": false,
|
| 65 |
-
"should_evaluate": false,
|
| 66 |
-
"should_log": false,
|
| 67 |
-
"should_save": true,
|
| 68 |
-
"should_training_stop": false
|
| 69 |
-
},
|
| 70 |
-
"attributes": {}
|
| 71 |
-
}
|
| 72 |
-
},
|
| 73 |
-
"total_flos": 9110440274558976.0,
|
| 74 |
-
"train_batch_size": 2,
|
| 75 |
-
"trial_name": null,
|
| 76 |
-
"trial_params": null
|
| 77 |
-
}
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
test_run_uploads/checkpoint-50/training_args.bin
DELETED
|
@@ -1,3 +0,0 @@
|
|
| 1 |
-
version https://git-lfs.github.com/spec/v1
|
| 2 |
-
oid sha256:97c4848f189bc8ef55d633cdc5629ac09caf902f18ecbe802fee52f91633580d
|
| 3 |
-
size 6097
|
|
|
|
|
|
|
|
|
|
|
|
test_run_uploads/checkpoint-90/README.md
DELETED
|
@@ -1,210 +0,0 @@
|
|
| 1 |
-
---
|
| 2 |
-
base_model: mistralai/Ministral-8B-Instruct-2410
|
| 3 |
-
library_name: peft
|
| 4 |
-
pipeline_tag: text-generation
|
| 5 |
-
tags:
|
| 6 |
-
- base_model:adapter:mistralai/Ministral-8B-Instruct-2410
|
| 7 |
-
- lora
|
| 8 |
-
- sft
|
| 9 |
-
- transformers
|
| 10 |
-
- trl
|
| 11 |
-
- unsloth
|
| 12 |
-
---
|
| 13 |
-
|
| 14 |
-
# Model Card for Model ID
|
| 15 |
-
|
| 16 |
-
<!-- Provide a quick summary of what the model is/does. -->
|
| 17 |
-
|
| 18 |
-
|
| 19 |
-
|
| 20 |
-
## Model Details
|
| 21 |
-
|
| 22 |
-
### Model Description
|
| 23 |
-
|
| 24 |
-
<!-- Provide a longer summary of what this model is. -->
|
| 25 |
-
|
| 26 |
-
|
| 27 |
-
|
| 28 |
-
- **Developed by:** [More Information Needed]
|
| 29 |
-
- **Funded by [optional]:** [More Information Needed]
|
| 30 |
-
- **Shared by [optional]:** [More Information Needed]
|
| 31 |
-
- **Model type:** [More Information Needed]
|
| 32 |
-
- **Language(s) (NLP):** [More Information Needed]
|
| 33 |
-
- **License:** [More Information Needed]
|
| 34 |
-
- **Finetuned from model [optional]:** [More Information Needed]
|
| 35 |
-
|
| 36 |
-
### Model Sources [optional]
|
| 37 |
-
|
| 38 |
-
<!-- Provide the basic links for the model. -->
|
| 39 |
-
|
| 40 |
-
- **Repository:** [More Information Needed]
|
| 41 |
-
- **Paper [optional]:** [More Information Needed]
|
| 42 |
-
- **Demo [optional]:** [More Information Needed]
|
| 43 |
-
|
| 44 |
-
## Uses
|
| 45 |
-
|
| 46 |
-
<!-- Address questions around how the model is intended to be used, including the foreseeable users of the model and those affected by the model. -->
|
| 47 |
-
|
| 48 |
-
### Direct Use
|
| 49 |
-
|
| 50 |
-
<!-- This section is for the model use without fine-tuning or plugging into a larger ecosystem/app. -->
|
| 51 |
-
|
| 52 |
-
[More Information Needed]
|
| 53 |
-
|
| 54 |
-
### Downstream Use [optional]
|
| 55 |
-
|
| 56 |
-
<!-- This section is for the model use when fine-tuned for a task, or when plugged into a larger ecosystem/app -->
|
| 57 |
-
|
| 58 |
-
[More Information Needed]
|
| 59 |
-
|
| 60 |
-
### Out-of-Scope Use
|
| 61 |
-
|
| 62 |
-
<!-- This section addresses misuse, malicious use, and uses that the model will not work well for. -->
|
| 63 |
-
|
| 64 |
-
[More Information Needed]
|
| 65 |
-
|
| 66 |
-
## Bias, Risks, and Limitations
|
| 67 |
-
|
| 68 |
-
<!-- This section is meant to convey both technical and sociotechnical limitations. -->
|
| 69 |
-
|
| 70 |
-
[More Information Needed]
|
| 71 |
-
|
| 72 |
-
### Recommendations
|
| 73 |
-
|
| 74 |
-
<!-- This section is meant to convey recommendations with respect to the bias, risk, and technical limitations. -->
|
| 75 |
-
|
| 76 |
-
Users (both direct and downstream) should be made aware of the risks, biases and limitations of the model. More information needed for further recommendations.
|
| 77 |
-
|
| 78 |
-
## How to Get Started with the Model
|
| 79 |
-
|
| 80 |
-
Use the code below to get started with the model.
|
| 81 |
-
|
| 82 |
-
[More Information Needed]
|
| 83 |
-
|
| 84 |
-
## Training Details
|
| 85 |
-
|
| 86 |
-
### Training Data
|
| 87 |
-
|
| 88 |
-
<!-- This should link to a Dataset Card, perhaps with a short stub of information on what the training data is all about as well as documentation related to data pre-processing or additional filtering. -->
|
| 89 |
-
|
| 90 |
-
[More Information Needed]
|
| 91 |
-
|
| 92 |
-
### Training Procedure
|
| 93 |
-
|
| 94 |
-
<!-- This relates heavily to the Technical Specifications. Content here should link to that section when it is relevant to the training procedure. -->
|
| 95 |
-
|
| 96 |
-
#### Preprocessing [optional]
|
| 97 |
-
|
| 98 |
-
[More Information Needed]
|
| 99 |
-
|
| 100 |
-
|
| 101 |
-
#### Training Hyperparameters
|
| 102 |
-
|
| 103 |
-
- **Training regime:** [More Information Needed] <!--fp32, fp16 mixed precision, bf16 mixed precision, bf16 non-mixed precision, fp16 non-mixed precision, fp8 mixed precision -->
|
| 104 |
-
|
| 105 |
-
#### Speeds, Sizes, Times [optional]
|
| 106 |
-
|
| 107 |
-
<!-- This section provides information about throughput, start/end time, checkpoint size if relevant, etc. -->
|
| 108 |
-
|
| 109 |
-
[More Information Needed]
|
| 110 |
-
|
| 111 |
-
## Evaluation
|
| 112 |
-
|
| 113 |
-
<!-- This section describes the evaluation protocols and provides the results. -->
|
| 114 |
-
|
| 115 |
-
### Testing Data, Factors & Metrics
|
| 116 |
-
|
| 117 |
-
#### Testing Data
|
| 118 |
-
|
| 119 |
-
<!-- This should link to a Dataset Card if possible. -->
|
| 120 |
-
|
| 121 |
-
[More Information Needed]
|
| 122 |
-
|
| 123 |
-
#### Factors
|
| 124 |
-
|
| 125 |
-
<!-- These are the things the evaluation is disaggregating by, e.g., subpopulations or domains. -->
|
| 126 |
-
|
| 127 |
-
[More Information Needed]
|
| 128 |
-
|
| 129 |
-
#### Metrics
|
| 130 |
-
|
| 131 |
-
<!-- These are the evaluation metrics being used, ideally with a description of why. -->
|
| 132 |
-
|
| 133 |
-
[More Information Needed]
|
| 134 |
-
|
| 135 |
-
### Results
|
| 136 |
-
|
| 137 |
-
[More Information Needed]
|
| 138 |
-
|
| 139 |
-
#### Summary
|
| 140 |
-
|
| 141 |
-
|
| 142 |
-
|
| 143 |
-
## Model Examination [optional]
|
| 144 |
-
|
| 145 |
-
<!-- Relevant interpretability work for the model goes here -->
|
| 146 |
-
|
| 147 |
-
[More Information Needed]
|
| 148 |
-
|
| 149 |
-
## Environmental Impact
|
| 150 |
-
|
| 151 |
-
<!-- Total emissions (in grams of CO2eq) and additional considerations, such as electricity usage, go here. Edit the suggested text below accordingly -->
|
| 152 |
-
|
| 153 |
-
Carbon emissions can be estimated using the [Machine Learning Impact calculator](https://mlco2.github.io/impact#compute) presented in [Lacoste et al. (2019)](https://arxiv.org/abs/1910.09700).
|
| 154 |
-
|
| 155 |
-
- **Hardware Type:** [More Information Needed]
|
| 156 |
-
- **Hours used:** [More Information Needed]
|
| 157 |
-
- **Cloud Provider:** [More Information Needed]
|
| 158 |
-
- **Compute Region:** [More Information Needed]
|
| 159 |
-
- **Carbon Emitted:** [More Information Needed]
|
| 160 |
-
|
| 161 |
-
## Technical Specifications [optional]
|
| 162 |
-
|
| 163 |
-
### Model Architecture and Objective
|
| 164 |
-
|
| 165 |
-
[More Information Needed]
|
| 166 |
-
|
| 167 |
-
### Compute Infrastructure
|
| 168 |
-
|
| 169 |
-
[More Information Needed]
|
| 170 |
-
|
| 171 |
-
#### Hardware
|
| 172 |
-
|
| 173 |
-
[More Information Needed]
|
| 174 |
-
|
| 175 |
-
#### Software
|
| 176 |
-
|
| 177 |
-
[More Information Needed]
|
| 178 |
-
|
| 179 |
-
## Citation [optional]
|
| 180 |
-
|
| 181 |
-
<!-- If there is a paper or blog post introducing the model, the APA and Bibtex information for that should go in this section. -->
|
| 182 |
-
|
| 183 |
-
**BibTeX:**
|
| 184 |
-
|
| 185 |
-
[More Information Needed]
|
| 186 |
-
|
| 187 |
-
**APA:**
|
| 188 |
-
|
| 189 |
-
[More Information Needed]
|
| 190 |
-
|
| 191 |
-
## Glossary [optional]
|
| 192 |
-
|
| 193 |
-
<!-- If relevant, include terms and calculations in this section that can help readers understand the model or model card. -->
|
| 194 |
-
|
| 195 |
-
[More Information Needed]
|
| 196 |
-
|
| 197 |
-
## More Information [optional]
|
| 198 |
-
|
| 199 |
-
[More Information Needed]
|
| 200 |
-
|
| 201 |
-
## Model Card Authors [optional]
|
| 202 |
-
|
| 203 |
-
[More Information Needed]
|
| 204 |
-
|
| 205 |
-
## Model Card Contact
|
| 206 |
-
|
| 207 |
-
[More Information Needed]
|
| 208 |
-
### Framework versions
|
| 209 |
-
|
| 210 |
-
- PEFT 0.16.0
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
test_run_uploads/checkpoint-90/adapter_config.json
DELETED
|
@@ -1,41 +0,0 @@
|
|
| 1 |
-
{
|
| 2 |
-
"alpha_pattern": {},
|
| 3 |
-
"auto_mapping": null,
|
| 4 |
-
"base_model_name_or_path": "mistralai/Ministral-8B-Instruct-2410",
|
| 5 |
-
"bias": "none",
|
| 6 |
-
"corda_config": null,
|
| 7 |
-
"eva_config": null,
|
| 8 |
-
"exclude_modules": null,
|
| 9 |
-
"fan_in_fan_out": false,
|
| 10 |
-
"inference_mode": true,
|
| 11 |
-
"init_lora_weights": true,
|
| 12 |
-
"layer_replication": null,
|
| 13 |
-
"layers_pattern": null,
|
| 14 |
-
"layers_to_transform": null,
|
| 15 |
-
"loftq_config": {},
|
| 16 |
-
"lora_alpha": 64,
|
| 17 |
-
"lora_bias": false,
|
| 18 |
-
"lora_dropout": 0,
|
| 19 |
-
"megatron_config": null,
|
| 20 |
-
"megatron_core": "megatron.core",
|
| 21 |
-
"modules_to_save": null,
|
| 22 |
-
"peft_type": "LORA",
|
| 23 |
-
"qalora_group_size": 16,
|
| 24 |
-
"r": 32,
|
| 25 |
-
"rank_pattern": {},
|
| 26 |
-
"revision": null,
|
| 27 |
-
"target_modules": [
|
| 28 |
-
"up_proj",
|
| 29 |
-
"gate_proj",
|
| 30 |
-
"q_proj",
|
| 31 |
-
"o_proj",
|
| 32 |
-
"v_proj",
|
| 33 |
-
"down_proj",
|
| 34 |
-
"k_proj"
|
| 35 |
-
],
|
| 36 |
-
"task_type": "CAUSAL_LM",
|
| 37 |
-
"trainable_token_indices": null,
|
| 38 |
-
"use_dora": false,
|
| 39 |
-
"use_qalora": false,
|
| 40 |
-
"use_rslora": false
|
| 41 |
-
}
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
test_run_uploads/checkpoint-90/adapter_model.safetensors
DELETED
|
@@ -1,3 +0,0 @@
|
|
| 1 |
-
version https://git-lfs.github.com/spec/v1
|
| 2 |
-
oid sha256:ef1e45c7329d233afbb23de1796591557b08e47a395529287f0ddf873bd719d9
|
| 3 |
-
size 349243752
|
|
|
|
|
|
|
|
|
|
|
|