starsofchance committed
Commit 5c93775 · verified · 1 Parent(s): 0007115

Delete test_run_uploads/ with huggingface_hub

This view is limited to 50 files because the commit contains too many changes. See the raw diff for the full list.

Files changed (50):
  1. test_run_uploads/UnslothAlignPropTrainer.py +0 -646
  2. test_run_uploads/UnslothBCOTrainer.py +0 -1834
  3. test_run_uploads/UnslothCPOTrainer.py +0 -1566
  4. test_run_uploads/UnslothDDPOTrainer.py +0 -881
  5. test_run_uploads/UnslothDPOTrainer.py +0 -0
  6. test_run_uploads/UnslothGKDTrainer.py +0 -885
  7. test_run_uploads/UnslothGRPOTrainer.py +0 -0
  8. test_run_uploads/UnslothKTOTrainer.py +0 -1849
  9. test_run_uploads/UnslothNashMDTrainer.py +0 -969
  10. test_run_uploads/UnslothORPOTrainer.py +0 -1552
  11. test_run_uploads/UnslothOnlineDPOTrainer.py +0 -1293
  12. test_run_uploads/UnslothPPOTrainer.py +0 -1273
  13. test_run_uploads/UnslothPRMTrainer.py +0 -809
  14. test_run_uploads/UnslothRLOOTrainer.py +0 -1143
  15. test_run_uploads/UnslothRewardTrainer.py +0 -828
  16. test_run_uploads/UnslothSFTTrainer.py +0 -1102
  17. test_run_uploads/UnslothXPOTrainer.py +0 -1024
  18. test_run_uploads/__pycache__/UnslothAlignPropTrainer.cpython-311.pyc +0 -0
  19. test_run_uploads/__pycache__/UnslothBCOTrainer.cpython-311.pyc +0 -0
  20. test_run_uploads/__pycache__/UnslothCPOTrainer.cpython-311.pyc +0 -0
  21. test_run_uploads/__pycache__/UnslothDDPOTrainer.cpython-311.pyc +0 -0
  22. test_run_uploads/__pycache__/UnslothDPOTrainer.cpython-311.pyc +0 -3
  23. test_run_uploads/__pycache__/UnslothGKDTrainer.cpython-311.pyc +0 -0
  24. test_run_uploads/__pycache__/UnslothGRPOTrainer.cpython-311.pyc +0 -0
  25. test_run_uploads/__pycache__/UnslothKTOTrainer.cpython-311.pyc +0 -0
  26. test_run_uploads/__pycache__/UnslothNashMDTrainer.cpython-311.pyc +0 -0
  27. test_run_uploads/__pycache__/UnslothORPOTrainer.cpython-311.pyc +0 -0
  28. test_run_uploads/__pycache__/UnslothOnlineDPOTrainer.cpython-311.pyc +0 -0
  29. test_run_uploads/__pycache__/UnslothPPOTrainer.cpython-311.pyc +0 -0
  30. test_run_uploads/__pycache__/UnslothPRMTrainer.cpython-311.pyc +0 -0
  31. test_run_uploads/__pycache__/UnslothRLOOTrainer.cpython-311.pyc +0 -0
  32. test_run_uploads/__pycache__/UnslothRewardTrainer.cpython-311.pyc +0 -0
  33. test_run_uploads/__pycache__/UnslothSFTTrainer.cpython-311.pyc +0 -0
  34. test_run_uploads/__pycache__/UnslothXPOTrainer.cpython-311.pyc +0 -0
  35. test_run_uploads/checkpoint-50/README.md +0 -210
  36. test_run_uploads/checkpoint-50/adapter_config.json +0 -41
  37. test_run_uploads/checkpoint-50/adapter_model.safetensors +0 -3
  38. test_run_uploads/checkpoint-50/chat_template.jinja +0 -1
  39. test_run_uploads/checkpoint-50/optimizer.pt +0 -3
  40. test_run_uploads/checkpoint-50/rng_state.pth +0 -3
  41. test_run_uploads/checkpoint-50/scaler.pt +0 -3
  42. test_run_uploads/checkpoint-50/scheduler.pt +0 -3
  43. test_run_uploads/checkpoint-50/special_tokens_map.json +0 -24
  44. test_run_uploads/checkpoint-50/tokenizer.json +0 -3
  45. test_run_uploads/checkpoint-50/tokenizer_config.json +0 -0
  46. test_run_uploads/checkpoint-50/trainer_state.json +0 -77
  47. test_run_uploads/checkpoint-50/training_args.bin +0 -3
  48. test_run_uploads/checkpoint-90/README.md +0 -210
  49. test_run_uploads/checkpoint-90/adapter_config.json +0 -41
  50. test_run_uploads/checkpoint-90/adapter_model.safetensors +0 -3
test_run_uploads/UnslothAlignPropTrainer.py DELETED
@@ -1,646 +0,0 @@
- """
- 2025.7.11
- 2025.7.11
- 4.54.1
- 0.16.1
- __UNSLOTH_VERSIONING__
- """
- from torch import Tensor
- import torch
- import torch.nn as nn
- from torch.nn import functional as F
- from typing import Any, List, Optional, Tuple, Union, Dict, Set, Callable
- from trl.trainer.alignprop_trainer import (Accelerator, AlignPropConfig, AlignPropTrainer, Any, Callable, DDPOStableDiffusionPipeline, Optional, ProjectConfiguration, PyTorchModelHubMixin, Union, defaultdict, generate_model_card, get_comet_experiment_url, is_wandb_available, logger, os, set_seed, textwrap, torch, wandb, warn)
-
-
- import os
- from typing import *
- from dataclasses import dataclass, field
- from packaging.version import Version
- import torch
- import numpy as np
- from contextlib import nullcontext
- from torch.nn import functional as F
- from transformers import DataCollatorForSeq2Seq, DataCollatorForLanguageModeling as TransformersDataCollatorForLanguageModeling
-
- torch_compile_options = {
-     "epilogue_fusion" : True,
-     "max_autotune" : False,
-     "shape_padding" : True,
-     "trace.enabled" : False,
-     "triton.cudagraphs" : False,
- }
-
- @torch.compile(dynamic = True, fullgraph = True, options = torch_compile_options,)
- def chunked_selective_log_softmax(logits, index):
-     # Split into 4 chunks only
-     chunked_logits = torch.chunk(logits.reshape(-1, logits.shape[-1]), chunks = 4, dim = 0)
-     chunked_index = torch.chunk(index.reshape(-1), chunks = 4, dim = 0)
-     all_per_token_logps = []
-     # Below loop does the same as selective_log_softmax(chunk_logits, chunk_index)
-     for chunk_logits, chunk_index in zip(chunked_logits, chunked_index):
-         chunk_logits = chunk_logits.to(torch.float32)
-         selected_logits = torch.gather(chunk_logits, dim = -1, index = chunk_index.unsqueeze(-1)).squeeze(-1)
-         logsumexp_values = torch.logsumexp(chunk_logits, dim = -1)
-         per_token_logps = selected_logits - logsumexp_values
-         all_per_token_logps.append(per_token_logps)
-     pass
-     all_per_token_logps = torch.concat(all_per_token_logps)
-     all_per_token_logps = all_per_token_logps.reshape((logits.shape[0], logits.shape[1]))
-     return all_per_token_logps
- @dataclass
- class UnslothAlignPropConfig(AlignPropConfig):
-     """
-
-     Configuration class for the [`AlignPropTrainer`].
-
-     Using [`~transformers.HfArgumentParser`] we can turn this class into
-     [argparse](https://docs.python.org/3/library/argparse#module-argparse) arguments that can be specified on the
-     command line.
-
-     Parameters:
-         exp_name (`str`, *optional*, defaults to `os.path.basename(sys.argv[0])[: -len(".py")]`):
-             Name of this experiment (defaults to the file name without the extension).
-         run_name (`str`, *optional*, defaults to `""`):
-             Name of this run.
-         seed (`int`, *optional*, defaults to `0`):
-             Random seed for reproducibility.
-         log_with (`str` or `None`, *optional*, defaults to `None`):
-             Log with either `"wandb"` or `"tensorboard"`. Check
-             [tracking](https://huggingface.co/docs/accelerate/usage_guides/tracking) for more details.
-         log_image_freq (`int`, *optional*, defaults to `1`):
-             Frequency for logging images.
-         tracker_kwargs (`dict[str, Any]`, *optional*, defaults to `{}`):
-             Keyword arguments for the tracker (e.g., `wandb_project`).
-         accelerator_kwargs (`dict[str, Any]`, *optional*, defaults to `{}`):
-             Keyword arguments for the accelerator.
-         project_kwargs (`dict[str, Any]`, *optional*, defaults to `{}`):
-             Keyword arguments for the accelerator project config (e.g., `logging_dir`).
-         tracker_project_name (`str`, *optional*, defaults to `"trl"`):
-             Name of project to use for tracking.
-         logdir (`str`, *optional*, defaults to `"logs"`):
-             Top-level logging directory for checkpoint saving.
-         num_epochs (`int`, *optional*, defaults to `100`):
-             Number of epochs to train.
-         save_freq (`int`, *optional*, defaults to `1`):
-             Number of epochs between saving model checkpoints.
-         num_checkpoint_limit (`int`, *optional*, defaults to `5`):
-             Number of checkpoints to keep before overwriting old ones.
-         mixed_precision (`str`, *optional*, defaults to `"fp16"`):
-             Mixed precision training.
-         allow_tf32 (`bool`, *optional*, defaults to `True`):
-             Allow `tf32` on Ampere GPUs.
-         resume_from (`str`, *optional*, defaults to `""`):
-             Path to resume training from a checkpoint.
-         sample_num_steps (`int`, *optional*, defaults to `50`):
-             Number of sampler inference steps.
-         sample_eta (`float`, *optional*, defaults to `1.0`):
-             Eta parameter for the DDIM sampler.
-         sample_guidance_scale (`float`, *optional*, defaults to `5.0`):
-             Classifier-free guidance weight.
-         train_batch_size (`int`, *optional*, defaults to `1`):
-             Batch size for training.
-         train_use_8bit_adam (`bool`, *optional*, defaults to `False`):
-             Whether to use the 8bit Adam optimizer from `bitsandbytes`.
-         train_learning_rate (`float`, *optional*, defaults to `1e-3`):
-             Learning rate.
-         train_adam_beta1 (`float`, *optional*, defaults to `0.9`):
-             Beta1 for Adam optimizer.
-         train_adam_beta2 (`float`, *optional*, defaults to `0.999`):
-             Beta2 for Adam optimizer.
-         train_adam_weight_decay (`float`, *optional*, defaults to `1e-4`):
-             Weight decay for Adam optimizer.
-         train_adam_epsilon (`float`, *optional*, defaults to `1e-8`):
-             Epsilon value for Adam optimizer.
-         train_gradient_accumulation_steps (`int`, *optional*, defaults to `1`):
-             Number of gradient accumulation steps.
-         train_max_grad_norm (`float`, *optional*, defaults to `1.0`):
-             Maximum gradient norm for gradient clipping.
-         negative_prompts (`str` or `None`, *optional*, defaults to `None`):
-             Comma-separated list of prompts to use as negative examples.
-         truncated_backprop_rand (`bool`, *optional*, defaults to `True`):
-             If `True`, randomized truncation to different diffusion timesteps is used.
-         truncated_backprop_timestep (`int`, *optional*, defaults to `49`):
-             Absolute timestep to which the gradients are backpropagated. Used only if `truncated_backprop_rand=False`.
-         truncated_rand_backprop_minmax (`tuple[int, int]`, *optional*, defaults to `(0, 50)`):
-             Range of diffusion timesteps for randomized truncated backpropagation.
-         push_to_hub (`bool`, *optional*, defaults to `False`):
-             Whether to push the final model to the Hub.
-
-     """
-     vllm_sampling_params: Optional[Any] = field(
-         default = None,
-         metadata = {'help': 'vLLM SamplingParams'},
-     )
-     unsloth_num_chunks : Optional[int] = field(
-         default = -1,
-         metadata = {'help': 'Chunk size to reduce memory usage. -1 is most efficient.'},
-     )
-     def __init__(
-         self,
-         exp_name = 'colab_kernel_launcher',
-         run_name = '',
-         seed = 3407,
-         log_with = None,
-         log_image_freq = 1,
-         tracker_project_name = 'trl',
-         logdir = 'logs',
-         num_epochs = 100,
-         save_freq = 1,
-         num_checkpoint_limit = 5,
-         mixed_precision = 'fp16',
-         allow_tf32 = True,
-         resume_from = '',
-         sample_num_steps = 50,
-         sample_eta = 1.0,
-         sample_guidance_scale = 5.0,
-         train_batch_size = 1,
-         train_use_8bit_adam = False,
-         train_learning_rate = 5e-05,
-         train_adam_beta1 = 0.9,
-         train_adam_beta2 = 0.999,
-         train_adam_weight_decay = 0.01,
-         train_adam_epsilon = 1e-08,
-         train_gradient_accumulation_steps = 2,
-         train_max_grad_norm = 1.0,
-         negative_prompts = None,
-         truncated_backprop_rand = True,
-         truncated_backprop_timestep = 49,
-         push_to_hub = False,
-         vllm_sampling_params = None,
-         unsloth_num_chunks = -1,
-         **kwargs,
-     ):
-
-         super().__init__(
-             exp_name = exp_name,
-             run_name = run_name,
-             seed = seed,
-             log_with = log_with,
-             log_image_freq = log_image_freq,
-             tracker_project_name = tracker_project_name,
-             logdir = logdir,
-             num_epochs = num_epochs,
-             save_freq = save_freq,
-             num_checkpoint_limit = num_checkpoint_limit,
-             mixed_precision = mixed_precision,
-             allow_tf32 = allow_tf32,
-             resume_from = resume_from,
-             sample_num_steps = sample_num_steps,
-             sample_eta = sample_eta,
-             sample_guidance_scale = sample_guidance_scale,
-             train_batch_size = train_batch_size,
-             train_use_8bit_adam = train_use_8bit_adam,
-             train_learning_rate = train_learning_rate,
-             train_adam_beta1 = train_adam_beta1,
-             train_adam_beta2 = train_adam_beta2,
-             train_adam_weight_decay = train_adam_weight_decay,
-             train_adam_epsilon = train_adam_epsilon,
-             train_gradient_accumulation_steps = train_gradient_accumulation_steps,
-             train_max_grad_norm = train_max_grad_norm,
-             negative_prompts = negative_prompts,
-             truncated_backprop_rand = truncated_backprop_rand,
-             truncated_backprop_timestep = truncated_backprop_timestep,
-             push_to_hub = push_to_hub,**kwargs)
-         self.vllm_sampling_params = vllm_sampling_params
-         self.unsloth_num_chunks = unsloth_num_chunks
- pass
-
- class _UnslothAlignPropTrainer(PyTorchModelHubMixin):
-     """"""
-
-     _tag_names = ["trl", "alignprop"]
-
-     def __init__(
-         self,
-         config: AlignPropConfig,
-         reward_function: Callable[[torch.Tensor, tuple[str], tuple[Any]], torch.Tensor],
-         prompt_function: Callable[[], tuple[str, Any]],
-         sd_pipeline: DDPOStableDiffusionPipeline,
-         image_samples_hook: Optional[Callable[[Any, Any, Any], Any]] = None,
-     ):
-         if image_samples_hook is None:
-             warn("No image_samples_hook provided; no images will be logged")
-
-         self.prompt_fn = prompt_function
-         self.reward_fn = reward_function
-         self.config = config
-         self.image_samples_callback = image_samples_hook
-
-         accelerator_project_config = ProjectConfiguration(**self.config.project_kwargs)
-
-         if self.config.resume_from:
-             self.config.resume_from = os.path.normpath(os.path.expanduser(self.config.resume_from))
-             if "checkpoint_" not in os.path.basename(self.config.resume_from):
-                 # get the most recent checkpoint in this directory
-                 checkpoints = list(
-                     filter(
-                         lambda x: "checkpoint_" in x,
-                         os.listdir(self.config.resume_from),
-                     )
-                 )
-                 if len(checkpoints) == 0:
-                     raise ValueError(f"No checkpoints found in {self.config.resume_from}")
-                 checkpoint_numbers = sorted([int(x.split("_")[-1]) for x in checkpoints])
-                 self.config.resume_from = os.path.join(
-                     self.config.resume_from,
-                     f"checkpoint_{checkpoint_numbers[-1]}",
-                 )
-
-                 accelerator_project_config.iteration = checkpoint_numbers[-1] + 1
-
-         self.accelerator = Accelerator(
-             log_with=self.config.log_with,
-             mixed_precision=self.config.mixed_precision,
-             project_config=accelerator_project_config,
-             # we always accumulate gradients across timesteps; we want config.train.gradient_accumulation_steps to be the
-             # number of *samples* we accumulate across, so we need to multiply by the number of training timesteps to get
-             # the total number of optimizer steps to accumulate across.
-             gradient_accumulation_steps=self.config.train_gradient_accumulation_steps,
-             **self.config.accelerator_kwargs,
-         )
-
-         is_using_tensorboard = config.log_with is not None and config.log_with == "tensorboard"
-
-         if self.accelerator.is_main_process:
-             self.accelerator.init_trackers(
-                 self.config.tracker_project_name,
-                 config=dict(alignprop_trainer_config=config.to_dict())
-                 if not is_using_tensorboard
-                 else config.to_dict(),
-                 init_kwargs=self.config.tracker_kwargs,
-             )
-
-         logger.info(f"\n{config}")
-
-         set_seed(self.config.seed, device_specific=True)
-
-         self.sd_pipeline = sd_pipeline
-
-         self.sd_pipeline.set_progress_bar_config(
-             position=1,
-             disable=not self.accelerator.is_local_main_process,
-             leave=False,
-             desc="Timestep",
-             dynamic_ncols=True,
-         )
-
-         # For mixed precision training we cast all non-trainable weights [vae, non-lora text_encoder and non-lora unet] to half-precision
-         # as these weights are only used for inference, keeping weights in full precision is not required.
-         if self.accelerator.mixed_precision == "fp16":
-             inference_dtype = torch.float16
-         elif self.accelerator.mixed_precision == "bf16":
-             inference_dtype = torch.bfloat16
-         else:
-             inference_dtype = torch.float32
-
-         self.sd_pipeline.vae.to(self.accelerator.device, dtype=inference_dtype)
-         self.sd_pipeline.text_encoder.to(self.accelerator.device, dtype=inference_dtype)
-         self.sd_pipeline.unet.to(self.accelerator.device, dtype=inference_dtype)
-
-         trainable_layers = self.sd_pipeline.get_trainable_layers()
-
-         self.accelerator.register_save_state_pre_hook(self._save_model_hook)
-         self.accelerator.register_load_state_pre_hook(self._load_model_hook)
-
-         # Enable TF32 for faster training on Ampere GPUs,
-         # cf https://pytorch.org/docs/stable/notes/cuda.html#tensorfloat-32-tf32-on-ampere-devices
-         if self.config.allow_tf32:
-             torch.backends.cuda.matmul.allow_tf32 = True
-
-         self.optimizer = self._setup_optimizer(
-             trainable_layers.parameters() if not isinstance(trainable_layers, list) else trainable_layers
-         )
-
-         self.neg_prompt_embed = self.sd_pipeline.text_encoder(
-             self.sd_pipeline.tokenizer(
-                 [""] if self.config.negative_prompts is None else self.config.negative_prompts,
-                 return_tensors="pt",
-                 padding="max_length",
-                 truncation=True,
-                 max_length=self.sd_pipeline.tokenizer.model_max_length,
-             ).input_ids.to(self.accelerator.device)
-         )[0]
-
-         # NOTE: for some reason, autocast is necessary for non-lora training but for lora training it isn't necessary and it uses
-         # more memory
-         self.autocast = self.sd_pipeline.autocast or self.accelerator.autocast
-
-         if hasattr(self.sd_pipeline, "use_lora") and self.sd_pipeline.use_lora:
-             unet, self.optimizer = self.accelerator.prepare(trainable_layers, self.optimizer)
-             self.trainable_layers = list(filter(lambda p: p.requires_grad, unet.parameters()))
-         else:
-             self.trainable_layers, self.optimizer = self.accelerator.prepare(trainable_layers, self.optimizer)
-
-         if config.resume_from:
-             logger.info(f"Resuming from {config.resume_from}")
-             self.accelerator.load_state(config.resume_from)
-             self.first_epoch = int(config.resume_from.split("_")[-1]) + 1
-         else:
-             self.first_epoch = 0
-
-     def compute_rewards(self, prompt_image_pairs):
-         reward, reward_metadata = self.reward_fn(
-             prompt_image_pairs["images"], prompt_image_pairs["prompts"], prompt_image_pairs["prompt_metadata"]
-         )
-         return reward
-
-     def step(self, epoch: int, global_step: int):
-         """
-         Perform a single step of training.
-
-         Args:
-             epoch (int): The current epoch.
-             global_step (int): The current global step.
-
-         Side Effects:
-             - Model weights are updated
-             - Logs the statistics to the accelerator trackers.
-             - If `self.image_samples_callback` is not None, it will be called with the prompt_image_pairs, global_step, and the accelerator tracker.
-
-         Returns:
-             global_step (int): The updated global step.
-         """
-         info = defaultdict(list)
-
-         self.sd_pipeline.unet.train()
-
-         for _ in range(self.config.train_gradient_accumulation_steps):
-             with self.accelerator.accumulate(self.sd_pipeline.unet), self.autocast(), torch.enable_grad():
-                 prompt_image_pairs = self._generate_samples(
-                     batch_size=self.config.train_batch_size,
-                 )
-
-                 rewards = self.compute_rewards(prompt_image_pairs)
-
-                 prompt_image_pairs["rewards"] = rewards
-
-                 rewards_vis = self.accelerator.gather(rewards).detach().cpu().numpy()
-
-                 loss = self.calculate_loss(rewards)
-
-                 self.accelerator.backward(loss)
-
-                 if self.accelerator.sync_gradients:
-                     self.accelerator.clip_grad_norm_(
-                         self.trainable_layers.parameters()
-                         if not isinstance(self.trainable_layers, list)
-                         else self.trainable_layers,
-                         self.config.train_max_grad_norm,
-                     )
-
-                 self.optimizer.step()
-                 self.optimizer.zero_grad()
-
-             info["reward_mean"].append(rewards_vis.mean())
-             info["reward_std"].append(rewards_vis.std())
-             info["loss"].append(loss.item())
-
-         # Checks if the accelerator has performed an optimization step behind the scenes
-         if self.accelerator.sync_gradients:
-             # log training-related stuff
-             info = {k: torch.mean(torch.tensor(v)) for k, v in info.items()}
-             info = self.accelerator.reduce(info, reduction="mean")
-             info.update({"epoch": epoch})
-             self.accelerator.log(info, step=global_step)
-             global_step += 1
-             info = defaultdict(list)
-         else:
-             raise ValueError(
-                 "Optimization step should have been performed by this point. Please check calculated gradient accumulation settings."
-             )
-         # Logs generated images
-         if self.image_samples_callback is not None and global_step % self.config.log_image_freq == 0:
-             self.image_samples_callback(prompt_image_pairs, global_step, self.accelerator.trackers[0])
-
-         if epoch != 0 and epoch % self.config.save_freq == 0 and self.accelerator.is_main_process:
-             self.accelerator.save_state()
-
-         return global_step
-
-     def calculate_loss(self, rewards):
-         """
-         Calculate the loss for a batch of an unpacked sample
-
-         Args:
-             rewards (torch.Tensor):
-                 Differentiable reward scalars for each generated image, shape: [batch_size]
-
-         Returns:
-             loss (torch.Tensor)
-             (all of these are of shape (1,))
-         """
-         # Loss is specific to Aesthetic Reward function used in AlignProp (https://huggingface.co/papers/2310.03739)
-         loss = 10.0 - (rewards).mean()
-         return loss
-
-     def loss(
-         self,
-         advantages: torch.Tensor,
-         clip_range: float,
-         ratio: torch.Tensor,
-     ):
-         unclipped_loss = -advantages * ratio
-         clipped_loss = -advantages * torch.clamp(
-             ratio,
-             1.0 - clip_range,
-             1.0 + clip_range,
-         )
-         return torch.mean(torch.maximum(unclipped_loss, clipped_loss))
-
-     def _setup_optimizer(self, trainable_layers_parameters):
-         if self.config.train_use_8bit_adam:
-             import bitsandbytes
-
-             optimizer_cls = bitsandbytes.optim.AdamW8bit
-         else:
-             optimizer_cls = torch.optim.AdamW
-
-         return optimizer_cls(
-             trainable_layers_parameters,
-             lr=self.config.train_learning_rate,
-             betas=(self.config.train_adam_beta1, self.config.train_adam_beta2),
-             weight_decay=self.config.train_adam_weight_decay,
-             eps=self.config.train_adam_epsilon,
-         )
-
-     def _save_model_hook(self, models, weights, output_dir):
-         self.sd_pipeline.save_checkpoint(models, weights, output_dir)
-         weights.pop()  # ensures that accelerate doesn't try to handle saving of the model
-
-     def _load_model_hook(self, models, input_dir):
-         self.sd_pipeline.load_checkpoint(models, input_dir)
-         models.pop()  # ensures that accelerate doesn't try to handle loading of the model
-
-     def _generate_samples(self, batch_size, with_grad=True, prompts=None):
-         """
-         Generate samples from the model
-
-         Args:
-             batch_size (int): Batch size to use for sampling
-             with_grad (bool): Whether the generated RGBs should have gradients attached to it.
-
-         Returns:
-             prompt_image_pairs (dict[Any])
-         """
-         prompt_image_pairs = {}
-
-         sample_neg_prompt_embeds = self.neg_prompt_embed.repeat(batch_size, 1, 1)
-
-         if prompts is None:
-             prompts, prompt_metadata = zip(*[self.prompt_fn() for _ in range(batch_size)])
-         else:
-             prompt_metadata = [{} for _ in range(batch_size)]
-
-         prompt_ids = self.sd_pipeline.tokenizer(
-             prompts,
-             return_tensors="pt",
-             padding="max_length",
-             truncation=True,
-             max_length=self.sd_pipeline.tokenizer.model_max_length,
-         ).input_ids.to(self.accelerator.device)
-
-         prompt_embeds = self.sd_pipeline.text_encoder(prompt_ids)[0]
-
-         if with_grad:
-             sd_output = self.sd_pipeline.rgb_with_grad(
-                 prompt_embeds=prompt_embeds,
-                 negative_prompt_embeds=sample_neg_prompt_embeds,
-                 num_inference_steps=self.config.sample_num_steps,
-                 guidance_scale=self.config.sample_guidance_scale,
-                 eta=self.config.sample_eta,
-                 truncated_backprop_rand=self.config.truncated_backprop_rand,
-                 truncated_backprop_timestep=self.config.truncated_backprop_timestep,
-                 truncated_rand_backprop_minmax=self.config.truncated_rand_backprop_minmax,
-                 output_type="pt",
-             )
-         else:
-             sd_output = self.sd_pipeline(
-                 prompt_embeds=prompt_embeds,
-                 negative_prompt_embeds=sample_neg_prompt_embeds,
-                 num_inference_steps=self.config.sample_num_steps,
-                 guidance_scale=self.config.sample_guidance_scale,
-                 eta=self.config.sample_eta,
-                 output_type="pt",
-             )
-
-         images = sd_output.images
-
-         prompt_image_pairs["images"] = images
-         prompt_image_pairs["prompts"] = prompts
-         prompt_image_pairs["prompt_metadata"] = prompt_metadata
-
-         return prompt_image_pairs
-
-     def train(self, epochs: Optional[int] = None):
-         """
-         Train the model for a given number of epochs
-         """
-         global_step = 0
-         if epochs is None:
-             epochs = self.config.num_epochs
-         for epoch in range(self.first_epoch, epochs):
-             global_step = self.step(epoch, global_step)
-
-     def _save_pretrained(self, save_directory):
-         self.sd_pipeline.save_pretrained(save_directory)
-         self.create_model_card()
-
-     def create_model_card(
-         self,
-         model_name: Optional[str] = None,
-         dataset_name: Optional[str] = None,
-         tags: Union[str, list[str], None] = None,
-     ):
-         """
-         Creates a draft of a model card using the information available to the `Trainer`.
-
-         Args:
-             model_name (`str` or `None`, *optional*, defaults to `None`):
-                 Name of the model.
-             dataset_name (`str` or `None`, *optional*, defaults to `None`):
-                 Name of the dataset used for training.
-             tags (`str`, `list[str]` or `None`, *optional*, defaults to `None`):
-                 Tags to be associated with the model card.
-         """
-         if not self.is_world_process_zero():
-             return
-
-         if hasattr(self.model.config, "_name_or_path") and not os.path.isdir(self.model.config._name_or_path):
-             base_model = self.model.config._name_or_path
-         else:
-             base_model = None
-
-         tags = tags or []
-         if isinstance(tags, str):
-             tags = [tags]
-
-         if hasattr(self.model.config, "unsloth_version"):
-             tags.append("unsloth")
-
-         citation = textwrap.dedent("""\
-         @article{prabhudesai2024aligning,
-             title = {{Aligning Text-to-Image Diffusion Models with Reward Backpropagation}},
-             author = {Mihir Prabhudesai and Anirudh Goyal and Deepak Pathak and Katerina Fragkiadaki},
-             year = 2024,
-             eprint = {arXiv:2310.03739}
-         }""")
-
-         model_card = generate_model_card(
-             base_model=base_model,
-             model_name=model_name,
-             hub_model_id=self.hub_model_id,
-             dataset_name=dataset_name,
-             tags=tags,
-             wandb_url=wandb.run.get_url() if is_wandb_available() and wandb.run is not None else None,
-             comet_url=get_comet_experiment_url(),
-             trainer_name="AlignProp",
-             trainer_citation=citation,
-             paper_title="Aligning Text-to-Image Diffusion Models with Reward Backpropagation",
-             paper_id="2310.03739",
-         )
-
-         model_card.save(os.path.join(self.args.output_dir, "README.md"))
- class UnslothAlignPropTrainer(_UnslothAlignPropTrainer):
-     """
-
-     The AlignPropTrainer uses Deep Diffusion Policy Optimization to optimise diffusion models.
-     Note, this trainer is heavily inspired by the work here: https://github.com/mihirp1998/AlignProp/
-     As of now only Stable Diffusion based pipelines are supported
-
-     Attributes:
-         config (`AlignPropConfig`):
-             Configuration object for AlignPropTrainer. Check the documentation of `PPOConfig` for more details.
-         reward_function (`Callable[[torch.Tensor, tuple[str], tuple[Any]], torch.Tensor]`):
-             Reward function to be used
-         prompt_function (`Callable[[], tuple[str, Any]]`):
-             Function to generate prompts to guide model
-         sd_pipeline (`DDPOStableDiffusionPipeline`):
-             Stable Diffusion pipeline to be used for training.
-         image_samples_hook (`Optional[Callable[[Any, Any, Any], Any]]`):
-             Hook to be called to log images
-
-     """
-     def __init__(
-         self,
-         config,
-         reward_function,
-         prompt_function,
-         sd_pipeline,
-         image_samples_hook = None,
-         **kwargs
-     ):
-         if args is None: args = UnslothAlignPropConfig()
-         other_metrics = []
-
-         from unsloth_zoo.logging_utils import PatchRLStatistics
-         PatchRLStatistics('alignprop_trainer', other_metrics)
-
-         super().__init__(
-             config = config,
-             reward_function = reward_function,
-             prompt_function = prompt_function,
-             sd_pipeline = sd_pipeline,
-             image_samples_hook = image_samples_hook,**kwargs)
-
- pass
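
Each of the deleted trainer files opens with the same `chunked_selective_log_softmax` helper: it flattens the logits to `(batch * seq_len, vocab)`, splits them into 4 chunks, and runs the gather and logsumexp in float32 one chunk at a time, so the fully upcast logits matrix is never materialised at once. A minimal sketch of what the helper computes, assuming the definition shown in the diff above is in scope; the shapes and tensors are illustrative, not taken from the deleted files:

import torch

batch, seq_len, vocab = 2, 8, 128
logits = torch.randn(batch, seq_len, vocab)          # raw model scores
index  = torch.randint(0, vocab, (batch, seq_len))   # target token ids

# Reference: full log-softmax over the vocab, then pick out the
# log-probability of each target token.
reference = torch.log_softmax(logits.float(), dim=-1).gather(
    -1, index.unsqueeze(-1)
).squeeze(-1)

# The chunked helper returns the same (batch, seq_len) tensor of
# per-token log-probabilities, computed four row-chunks at a time.
per_token_logps = chunked_selective_log_softmax(logits, index)
assert torch.allclose(per_token_logps, reference, atol=1e-5)
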
test_run_uploads/UnslothBCOTrainer.py DELETED
@@ -1,1834 +0,0 @@
1
- """
2
- 2025.7.11
3
- 2025.7.11
4
- 4.54.1
5
- 0.16.1
6
- __UNSLOTH_VERSIONING__
7
- """
8
- from torch import Tensor
9
- import torch
10
- import torch.nn as nn
11
- from torch.nn import functional as F
12
- from typing import Any, List, Optional, Tuple, Union, Dict, Set, Callable
13
- from trl.trainer.bco_trainer import (Any, AutoModelForCausalLM, BCOConfig, BCOTrainer, BaseImageProcessor, CLF_NAME, Callable, DPODataCollatorWithPadding, DataCollator, DataLoader, Dataset, EvalLoopOutput, F, FeatureExtractionMixin, Literal, LogisticRegression, Optional, PartialState, PeftModel, PreTrainedModel, PreTrainedModelWrapper, PreTrainedTokenizerBase, ProcessorMixin, RUNNING_NAME, RunningMoments, SequentialSampler, Trainer, TrainerCallback, TrainingArguments, Union, _process_tokens, _tokenize, amp, contextmanager, create_reference_model, deepcopy, defaultdict, disable_dropout_in_model, generate_model_card, get_comet_experiment_url, has_length, inspect, is_comet_available, is_peft_available, is_sklearn_available, is_wandb_available, itemgetter, log_table_to_comet_experiment, maybe_apply_chat_template, nn, np, nullcontext, os, pad_to_length, pd, peft_module_casting_to_bf16, prepare_model_for_kbit_training, random, selective_log_softmax, textwrap, torch, tqdm, transformers, version, wandb, warnings, F, Optional, PeftModel, PreTrainedModel, Trainer, is_peft_available, os, torch)
14
-
15
-
16
- import os
17
- from typing import *
18
- from dataclasses import dataclass, field
19
- from packaging.version import Version
20
- import torch
21
- import numpy as np
22
- from contextlib import nullcontext
23
- from torch.nn import functional as F
24
- from transformers import DataCollatorForSeq2Seq, DataCollatorForLanguageModeling as TransformersDataCollatorForLanguageModeling
25
-
26
- torch_compile_options = {
27
- "epilogue_fusion" : True,
28
- "max_autotune" : False,
29
- "shape_padding" : True,
30
- "trace.enabled" : False,
31
- "triton.cudagraphs" : False,
32
- }
33
-
34
- @torch.compile(dynamic = True, fullgraph = True, options = torch_compile_options,)
35
- def chunked_selective_log_softmax(logits, index):
36
- # Split into 4 chunks only
37
- chunked_logits = torch.chunk(logits.reshape(-1, logits.shape[-1]), chunks = 4, dim = 0)
38
- chunked_index = torch.chunk(index.reshape(-1), chunks = 4, dim = 0)
39
- all_per_token_logps = []
40
- # Below loop does the same as selective_log_softmax(chunk_logits, chunk_index)
41
- for chunk_logits, chunk_index in zip(chunked_logits, chunked_index):
42
- chunk_logits = chunk_logits.to(torch.float32)
43
- selected_logits = torch.gather(chunk_logits, dim = -1, index = chunk_index.unsqueeze(-1)).squeeze(-1)
44
- logsumexp_values = torch.logsumexp(chunk_logits, dim = -1)
45
- per_token_logps = selected_logits - logsumexp_values
46
- all_per_token_logps.append(per_token_logps)
47
- pass
48
- all_per_token_logps = torch.concat(all_per_token_logps)
49
- all_per_token_logps = all_per_token_logps.reshape((logits.shape[0], logits.shape[1]))
50
- return all_per_token_logps
51
- @dataclass
52
- class UnslothBCOConfig(BCOConfig):
53
- """
54
-
55
- Configuration class for the [`BCOTrainer`].
56
-
57
- Using [`~transformers.HfArgumentParser`] we can turn this class into
58
- [argparse](https://docs.python.org/3/library/argparse#module-argparse) arguments that can be specified on the
59
- command line.
60
-
61
- Parameters:
62
- max_length (`int` or `None`, *optional*, defaults to `1024`):
63
- Maximum length of the sequences (prompt + completion) in the batch. This argument is required if you want
64
- to use the default data collator.
65
- max_prompt_length (`int` or `None`, *optional*, defaults to `512`):
66
- Maximum length of the prompt. This argument is required if you want to use the default data collator.
67
- max_completion_length (`int` or `None`, *optional*, defaults to `None`):
68
- Maximum length of the completion. This argument is required if you want to use the default data collator
69
- and your model is an encoder-decoder.
70
- beta (`float`, *optional*, defaults to `0.1`):
71
- Parameter controlling the deviation from the reference model. Higher β means less deviation from the
72
- reference model.
73
- label_pad_token_id (`int`, *optional*, defaults to `-100`):
74
- Label pad token id. This argument is required if you want to use the default data collator.
75
- padding_value (`int` or `None`, *optional*, defaults to `None`):
76
- Padding value to use. If `None`, the padding value of the tokenizer is used.
77
- truncation_mode (`str`, *optional*, defaults to `"keep_end"`):
78
- Truncation mode to use when the prompt is too long. Possible values are `"keep_end"` or `"keep_start"`.
79
- This argument is required if you want to use the default data collator.
80
- disable_dropout (`bool`, *optional*, defaults to `True`):
81
- Whether to disable dropout in the model and reference model.
82
- generate_during_eval (`bool`, *optional*, defaults to `False`):
83
- If `True`, generates and logs completions from both the model and the reference model to W&B or Comet during
84
- evaluation.
85
- is_encoder_decoder (`bool` or `None`, *optional*, defaults to `None`):
86
- When using the `model_init` argument (callable) to instantiate the model instead of the `model` argument,
87
- you need to specify if the model returned by the callable is an encoder-decoder model.
88
- precompute_ref_log_probs (`bool`, *optional*, defaults to `False`):
89
- Whether to precompute reference model log probabilities for training and evaluation datasets. This is
90
- useful when training without the reference model to reduce the total GPU memory needed.
91
- model_init_kwargs (`dict[str, Any]` or `None`, *optional*, defaults to `None`):
92
- Keyword arguments to pass to `AutoModelForCausalLM.from_pretrained` when instantiating the model from a
93
- string.
94
- ref_model_init_kwargs (`dict[str, Any]` or `None`, *optional*, defaults to `None`):
95
- Keyword arguments to pass to `AutoModelForCausalLM.from_pretrained` when instantiating the reference model
96
- from a string.
97
- dataset_num_proc (`int` or `None`, *optional*, defaults to `None`):
98
- Number of processes to use for processing the dataset.
99
- prompt_sample_size (`int`, *optional*, defaults to `1024`):
100
- Number of prompts that are fed to density ratio classifier.
101
- min_density_ratio (`float`, *optional*, defaults to `0.5`):
102
- Minimum value of the density ratio. The estimated density ratio is clamped to this value.
103
- max_density_ratio (`float`, *optional*, defaults to `10.0`):
104
- Maximum value of the density ratio. The estimated density ratio is clamped to this value.
105
-
106
- """
107
- vllm_sampling_params: Optional[Any] = field(
108
- default = None,
109
- metadata = {'help': 'vLLM SamplingParams'},
110
- )
111
- unsloth_num_chunks : Optional[int] = field(
112
- default = -1,
113
- metadata = {'help': 'Chunk size to reduce memory usage. -1 is most efficient.'},
114
- )
115
- def __init__(
116
- self,
117
- output_dir = None,
118
- overwrite_output_dir = None,
119
- do_train = False,
120
- do_eval = False,
121
- do_predict = False,
122
- eval_strategy = 'no',
123
- prediction_loss_only = False,
124
- per_device_train_batch_size = 4,
125
- per_device_eval_batch_size = 4,
126
- per_gpu_train_batch_size = None,
127
- per_gpu_eval_batch_size = None,
128
- gradient_accumulation_steps = 2,
129
- eval_accumulation_steps = 2,
130
- eval_delay = 0,
131
- torch_empty_cache_steps = 250,
132
- learning_rate = 5e-05,
133
- weight_decay = 0.01,
134
- adam_beta1 = 0.9,
135
- adam_beta2 = 0.999,
136
- adam_epsilon = 1e-08,
137
- max_grad_norm = 1.0,
138
- num_train_epochs = 3.0,
139
- max_steps = -1,
140
- lr_scheduler_type = 'linear',
141
- warmup_ratio = 0.1,
142
- warmup_steps = 0,
143
- log_level = 'passive',
144
- log_level_replica = 'warning',
145
- log_on_each_node = True,
146
- logging_dir = None,
147
- logging_strategy = 'steps',
148
- logging_first_step = False,
149
- logging_steps = 1,
150
- logging_nan_inf_filter = False,
151
- save_strategy = 'steps',
152
- save_steps = 500,
153
- save_total_limit = None,
154
- save_safetensors = True,
155
- save_on_each_node = False,
156
- save_only_model = False,
157
- restore_callback_states_from_checkpoint = False,
158
- no_cuda = False,
159
- use_cpu = False,
160
- use_mps_device = False,
161
- seed = 3407,
162
- data_seed = 3407,
163
- jit_mode_eval = False,
164
- use_ipex = False,
165
- bf16 = False,
166
- fp16 = False,
167
- fp16_opt_level = 'O1',
168
- half_precision_backend = 'auto',
169
- bf16_full_eval = False,
170
- fp16_full_eval = False,
171
- tf32 = None,
172
- local_rank = -1,
173
- ddp_backend = None,
174
- tpu_num_cores = None,
175
- tpu_metrics_debug = False,
176
- debug = '',
177
- dataloader_drop_last = False,
178
- eval_steps = None,
179
- dataloader_num_workers = 0,
180
- dataloader_prefetch_factor = None,
181
- past_index = -1,
182
- run_name = None,
183
- disable_tqdm = None,
184
- remove_unused_columns = True,
185
- label_names = None,
186
- load_best_model_at_end = False,
187
- metric_for_best_model = None,
188
- greater_is_better = None,
189
- ignore_data_skip = False,
190
- fsdp = '',
191
- fsdp_min_num_params = 0,
192
- fsdp_config = None,
193
- fsdp_transformer_layer_cls_to_wrap = None,
194
- accelerator_config = None,
195
- deepspeed = None,
196
- label_smoothing_factor = 0.0,
197
- optim = 'adamw_8bit',
198
- optim_args = None,
199
- adafactor = False,
200
- group_by_length = False,
201
- length_column_name = 'length',
202
- report_to = None,
203
- ddp_find_unused_parameters = None,
204
- ddp_bucket_cap_mb = None,
205
- ddp_broadcast_buffers = None,
206
- dataloader_pin_memory = True,
207
- dataloader_persistent_workers = False,
208
- skip_memory_metrics = True,
209
- use_legacy_prediction_loop = False,
210
- push_to_hub = False,
211
- resume_from_checkpoint = None,
212
- hub_model_id = None,
213
- hub_strategy = 'every_save',
214
- hub_token = None,
215
- hub_private_repo = None,
216
- hub_always_push = False,
217
- hub_revision = None,
218
- gradient_checkpointing = False,
219
- gradient_checkpointing_kwargs = None,
220
- include_inputs_for_metrics = False,
221
- eval_do_concat_batches = True,
222
- fp16_backend = 'auto',
223
- push_to_hub_model_id = None,
224
- push_to_hub_organization = None,
225
- push_to_hub_token = None,
226
- mp_parameters = '',
227
- auto_find_batch_size = True,
228
- full_determinism = False,
229
- torchdynamo = None,
230
- ray_scope = 'last',
231
- ddp_timeout = 1800,
232
- torch_compile = False,
233
- torch_compile_backend = None,
234
- torch_compile_mode = None,
235
- include_tokens_per_second = False,
236
- include_num_input_tokens_seen = False,
237
- neftune_noise_alpha = None,
238
- optim_target_modules = None,
239
- batch_eval_metrics = False,
240
- eval_on_start = False,
241
- use_liger_kernel = False,
242
- liger_kernel_config = None,
243
- eval_use_gather_object = False,
244
- average_tokens_across_devices = True,
245
- max_length = 1024,
246
- max_prompt_length = 512,
247
- max_completion_length = None,
248
- beta = 0.1,
249
- label_pad_token_id = -100,
250
- padding_value = None,
251
- truncation_mode = 'keep_end',
252
- disable_dropout = True,
253
- generate_during_eval = False,
254
- is_encoder_decoder = None,
255
- precompute_ref_log_probs = False,
256
- model_init_kwargs = None,
257
- ref_model_init_kwargs = None,
258
- dataset_num_proc = None,
259
- prompt_sample_size = 1024,
260
- min_density_ratio = 0.5,
261
- max_density_ratio = 10.0,
262
- vllm_sampling_params = None,
263
- unsloth_num_chunks = -1,
264
- **kwargs,
265
- ):
266
- if learning_rate < 1e-7: raise FloatingPointError(f'Unsloth: Your learning rate of `{learning_rate}` is too small and less than 1e-7! Consider increasing it, otherwise gradient updates will be close to 0!')
267
- if learning_rate > 1: raise OverflowError(f'Unsloth: Your learning rate of `{learning_rate}` is way too larger > 1! Consider decreasing it to 1e-1, otherwise gradient updates will explode!')
268
- if output_dir is None and save_strategy == 'steps' and save_steps == 500:
269
- output_dir = 'unsloth_training_checkpoints'
270
- save_strategy = 'no'
271
- if dataset_num_proc is None:
272
- from multiprocessing import cpu_count
273
- dataset_num_proc = min(cpu_count()*2, 2)
274
-
275
- super().__init__(
276
- output_dir = output_dir,
277
- overwrite_output_dir = overwrite_output_dir,
278
- do_train = do_train,
279
- do_eval = do_eval,
280
- do_predict = do_predict,
281
- eval_strategy = eval_strategy,
282
- prediction_loss_only = prediction_loss_only,
283
- per_device_train_batch_size = per_device_train_batch_size,
284
- per_device_eval_batch_size = per_device_eval_batch_size,
285
- per_gpu_train_batch_size = per_gpu_train_batch_size,
286
- per_gpu_eval_batch_size = per_gpu_eval_batch_size,
287
- gradient_accumulation_steps = gradient_accumulation_steps,
288
- eval_accumulation_steps = eval_accumulation_steps,
289
- eval_delay = eval_delay,
290
- torch_empty_cache_steps = torch_empty_cache_steps,
291
- learning_rate = learning_rate,
292
- weight_decay = weight_decay,
293
- adam_beta1 = adam_beta1,
294
- adam_beta2 = adam_beta2,
295
- adam_epsilon = adam_epsilon,
296
- max_grad_norm = max_grad_norm,
297
- num_train_epochs = num_train_epochs,
298
- max_steps = max_steps,
299
- lr_scheduler_type = lr_scheduler_type,
300
- warmup_ratio = warmup_ratio,
301
- warmup_steps = warmup_steps,
302
- log_level = log_level,
303
- log_level_replica = log_level_replica,
304
- log_on_each_node = log_on_each_node,
305
- logging_dir = logging_dir,
306
- logging_strategy = logging_strategy,
307
- logging_first_step = logging_first_step,
308
- logging_steps = logging_steps,
309
- logging_nan_inf_filter = logging_nan_inf_filter,
310
- save_strategy = save_strategy,
311
- save_steps = save_steps,
312
- save_total_limit = save_total_limit,
313
- save_safetensors = save_safetensors,
314
- save_on_each_node = save_on_each_node,
315
- save_only_model = save_only_model,
316
- restore_callback_states_from_checkpoint = restore_callback_states_from_checkpoint,
317
- no_cuda = no_cuda,
318
- use_cpu = use_cpu,
319
- use_mps_device = use_mps_device,
320
- seed = seed,
321
- data_seed = data_seed,
322
- jit_mode_eval = jit_mode_eval,
323
- use_ipex = use_ipex,
324
- bf16 = bf16,
325
- fp16 = fp16,
326
- fp16_opt_level = fp16_opt_level,
327
- half_precision_backend = half_precision_backend,
328
- bf16_full_eval = bf16_full_eval,
329
- fp16_full_eval = fp16_full_eval,
330
- tf32 = tf32,
331
- local_rank = local_rank,
332
- ddp_backend = ddp_backend,
333
- tpu_num_cores = tpu_num_cores,
334
- tpu_metrics_debug = tpu_metrics_debug,
335
- debug = debug,
336
- dataloader_drop_last = dataloader_drop_last,
337
- eval_steps = eval_steps,
338
- dataloader_num_workers = dataloader_num_workers,
339
- dataloader_prefetch_factor = dataloader_prefetch_factor,
340
- past_index = past_index,
341
- run_name = run_name,
342
- disable_tqdm = disable_tqdm,
343
- remove_unused_columns = remove_unused_columns,
344
- label_names = label_names,
345
- load_best_model_at_end = load_best_model_at_end,
346
- metric_for_best_model = metric_for_best_model,
347
- greater_is_better = greater_is_better,
348
- ignore_data_skip = ignore_data_skip,
349
- fsdp = fsdp,
350
- fsdp_min_num_params = fsdp_min_num_params,
351
- fsdp_config = fsdp_config,
352
- fsdp_transformer_layer_cls_to_wrap = fsdp_transformer_layer_cls_to_wrap,
353
- accelerator_config = accelerator_config,
354
- deepspeed = deepspeed,
355
- label_smoothing_factor = label_smoothing_factor,
356
- optim = optim,
357
- optim_args = optim_args,
358
- adafactor = adafactor,
359
- group_by_length = group_by_length,
360
- length_column_name = length_column_name,
361
- report_to = report_to,
362
- ddp_find_unused_parameters = ddp_find_unused_parameters,
363
- ddp_bucket_cap_mb = ddp_bucket_cap_mb,
364
- ddp_broadcast_buffers = ddp_broadcast_buffers,
365
- dataloader_pin_memory = dataloader_pin_memory,
366
- dataloader_persistent_workers = dataloader_persistent_workers,
367
- skip_memory_metrics = skip_memory_metrics,
368
- use_legacy_prediction_loop = use_legacy_prediction_loop,
369
- push_to_hub = push_to_hub,
370
- resume_from_checkpoint = resume_from_checkpoint,
371
- hub_model_id = hub_model_id,
372
- hub_strategy = hub_strategy,
373
- hub_token = hub_token,
374
- hub_private_repo = hub_private_repo,
375
- hub_always_push = hub_always_push,
376
- hub_revision = hub_revision,
377
- gradient_checkpointing = gradient_checkpointing,
378
- gradient_checkpointing_kwargs = gradient_checkpointing_kwargs,
379
- include_inputs_for_metrics = include_inputs_for_metrics,
380
- eval_do_concat_batches = eval_do_concat_batches,
381
- fp16_backend = fp16_backend,
382
- push_to_hub_model_id = push_to_hub_model_id,
383
- push_to_hub_organization = push_to_hub_organization,
384
- push_to_hub_token = push_to_hub_token,
385
- mp_parameters = mp_parameters,
386
- auto_find_batch_size = auto_find_batch_size,
387
- full_determinism = full_determinism,
388
- torchdynamo = torchdynamo,
389
- ray_scope = ray_scope,
390
- ddp_timeout = ddp_timeout,
391
- torch_compile = torch_compile,
392
- torch_compile_backend = torch_compile_backend,
393
- torch_compile_mode = torch_compile_mode,
394
- include_tokens_per_second = include_tokens_per_second,
395
- include_num_input_tokens_seen = include_num_input_tokens_seen,
396
- neftune_noise_alpha = neftune_noise_alpha,
397
- optim_target_modules = optim_target_modules,
398
- batch_eval_metrics = batch_eval_metrics,
399
- eval_on_start = eval_on_start,
400
- use_liger_kernel = use_liger_kernel,
401
- liger_kernel_config = liger_kernel_config,
402
- eval_use_gather_object = eval_use_gather_object,
403
- average_tokens_across_devices = average_tokens_across_devices,
404
- max_length = max_length,
405
- max_prompt_length = max_prompt_length,
406
- max_completion_length = max_completion_length,
407
- beta = beta,
408
- label_pad_token_id = label_pad_token_id,
409
- padding_value = padding_value,
410
- truncation_mode = truncation_mode,
411
- disable_dropout = disable_dropout,
412
- generate_during_eval = generate_during_eval,
413
- is_encoder_decoder = is_encoder_decoder,
414
- precompute_ref_log_probs = precompute_ref_log_probs,
415
- model_init_kwargs = model_init_kwargs,
416
- ref_model_init_kwargs = ref_model_init_kwargs,
417
- dataset_num_proc = dataset_num_proc,
418
- prompt_sample_size = prompt_sample_size,
419
- min_density_ratio = min_density_ratio,
420
- max_density_ratio = max_density_ratio,**kwargs)
421
- self.vllm_sampling_params = vllm_sampling_params
422
- self.unsloth_num_chunks = unsloth_num_chunks
423
- pass
424
-
425
- class _UnslothBCOTrainer(Trainer):
426
- r""""""
427
-
428
- _tag_names = ["trl", "bco"]
429
-
430
- def __init__(
431
- self,
432
- model: Union[PreTrainedModel, nn.Module, str] = None,
433
- ref_model: Optional[Union[PreTrainedModel, nn.Module, str]] = None,
434
- args: BCOConfig = None,
435
- train_dataset: Optional[Dataset] = None,
436
- eval_dataset: Optional[Union[Dataset, dict[str, Dataset]]] = None,
437
- processing_class: Optional[
438
- Union[PreTrainedTokenizerBase, BaseImageProcessor, FeatureExtractionMixin, ProcessorMixin]
439
- ] = None,
440
- data_collator: Optional[DataCollator] = None,
441
- model_init: Optional[Callable[[], PreTrainedModel]] = None,
442
- callbacks: Optional[list[TrainerCallback]] = None,
443
- optimizers: tuple[torch.optim.Optimizer, torch.optim.lr_scheduler.LambdaLR] = (None, None),
444
- preprocess_logits_for_metrics: Optional[Callable[[torch.Tensor, torch.Tensor], torch.Tensor]] = None,
445
- peft_config: Optional[dict] = None,
446
- compute_metrics: Optional[Callable[[EvalLoopOutput], dict]] = None,
447
- model_adapter_name: Optional[str] = None,
448
- ref_adapter_name: Optional[str] = None,
449
- embedding_func: Optional[Callable] = None,
450
- embedding_tokenizer: Optional[PreTrainedTokenizerBase] = None,
451
- ):
452
- if not is_sklearn_available():
453
- raise ImportError(
454
- "BCOTrainer requires the scikit-learn library. Please install it with `pip install scikit-learn`."
455
- )
456
-
457
- if type(args) is TrainingArguments:
458
- raise ValueError("Please use `BCOConfig` instead `TrainingArguments`.")
459
-
460
- if not isinstance(model, str) and ref_model is model:
461
- raise ValueError(
462
- "`model` and `ref_model` cannot be the same object. If you want `ref_model` to be the "
463
- "same as `model`, you must mass a copy of it, or `None` if you use peft."
464
- )
465
-
466
- if args.model_init_kwargs is None:
467
- model_init_kwargs = {}
468
- elif not isinstance(model, str):
469
- raise ValueError("You passed model_kwargs to the BCOTrainer. But your model is already instantiated.")
470
- else:
471
- model_init_kwargs = args.model_init_kwargs
472
- torch_dtype = model_init_kwargs.get("torch_dtype")
473
- if torch_dtype is not None:
474
- # Convert to `torch.dtype` if an str is passed
475
- if isinstance(torch_dtype, str) and torch_dtype != "auto":
476
- torch_dtype = getattr(torch, torch_dtype)
477
- if torch_dtype != "auto" and not isinstance(torch_dtype, torch.dtype):
478
- raise ValueError(
479
- f"Invalid `torch_dtype` passed to the BCOConfig. Expected a string with either `torch.dtype` or 'auto', but got {torch_dtype}."
480
- )
481
- model_init_kwargs["torch_dtype"] = torch_dtype
482
-
483
- if args.ref_model_init_kwargs is None:
484
- ref_model_init_kwargs = {}
485
- elif not isinstance(ref_model, str):
486
- raise ValueError(
487
- "You passed ref_model_kwargs to the BCOTrainer. But your ref_model is already instantiated."
488
- )
489
- else:
490
- ref_model_init_kwargs = args.ref_model_init_kwargs
491
- torch_dtype = ref_model_init_kwargs.get("torch_dtype")
492
- if torch_dtype is not None:
493
- # Convert to `torch.dtype` if an str is passed
494
- if isinstance(torch_dtype, str) and torch_dtype != "auto":
495
- torch_dtype = getattr(torch, torch_dtype)
496
- if torch_dtype != "auto" and not isinstance(torch_dtype, torch.dtype):
497
- raise ValueError(
498
- f"Invalid `torch_dtype` passed to the BCOConfig. Expected a string with either `torch.dtype` or 'auto', but got {torch_dtype}."
499
- )
500
- ref_model_init_kwargs["torch_dtype"] = torch_dtype
501
-
502
- if isinstance(model, str):
503
- model = AutoModelForCausalLM.from_pretrained(model, **model_init_kwargs)
504
-
505
- if isinstance(ref_model, str):
506
- ref_model = AutoModelForCausalLM.from_pretrained(ref_model, **ref_model_init_kwargs)
507
-
508
- # Initialize this variable to False. This helps tracking the case when `peft_module_casting_to_bf16`
509
- # has been called in order to properly call autocast if needed.
510
- self._peft_has_been_casted_to_bf16 = False
511
-
512
- if not is_peft_available() and peft_config is not None:
513
- raise ValueError(
514
- "PEFT is not installed and you passed a `peft_config` in the trainer's kwargs, please install it with `pip install peft` to use the PEFT models"
515
- )
516
- elif is_peft_available() and peft_config is not None:
517
- # if model is a peft model and we have a peft_config, we merge and unload it first
518
- if isinstance(model, PeftModel):
519
- model = model.merge_and_unload()
520
-
521
- if getattr(model, "is_loaded_in_8bit", False) or getattr(model, "is_loaded_in_4bit", False):
522
- _support_gc_kwargs = hasattr(
523
- args, "gradient_checkpointing_kwargs"
524
- ) and "gradient_checkpointing_kwargs" in list(
525
- inspect.signature(prepare_model_for_kbit_training).parameters
526
- )
527
-
528
- prepare_model_kwargs = {"use_gradient_checkpointing": args.gradient_checkpointing}
529
-
530
- if _support_gc_kwargs:
531
- prepare_model_kwargs["gradient_checkpointing_kwargs"] = args.gradient_checkpointing_kwargs
532
-
533
- model = prepare_model_for_kbit_training(model, **prepare_model_kwargs)
534
- elif getattr(args, "gradient_checkpointing", False):
535
- # For backward compatibility with older versions of transformers
536
- if hasattr(model, "enable_input_require_grads"):
537
- model.enable_input_require_grads()
538
- else:
539
-
540
- def make_inputs_require_grad(module, input, output):
541
- output.requires_grad_(True)
542
-
543
- model.get_input_embeddings().register_forward_hook(make_inputs_require_grad)
544
-
545
- # get peft model with the given config
546
- model = model
547
- if args.bf16 and getattr(model, "is_loaded_in_4bit", False):
548
- peft_module_casting_to_bf16(model)
549
- # If args.bf16 we need to explicitly call `generate` with torch amp autocast context manager
550
- self._peft_has_been_casted_to_bf16 = True
551
-
552
- # For models that use gradient_checkpointing, we need to attach a hook that enables input
553
- # to explicitly have `requires_grad=True`, otherwise training will either silently
554
- # fail or completely fail.
555
- elif getattr(args, "gradient_checkpointing", False):
556
- # For backward compatibility with older versions of transformers
557
- if hasattr(model, "enable_input_require_grads"):
558
- model.enable_input_require_grads()
559
- else:
560
-
561
- def make_inputs_require_grad(module, input, output):
562
- output.requires_grad_(True)
563
-
564
- model.get_input_embeddings().register_forward_hook(make_inputs_require_grad)
565
-
566
- if args.generate_during_eval and not (is_wandb_available() or is_comet_available()):
567
- raise ValueError(
568
- "`generate_during_eval=True` requires Weights and Biases or Comet to be installed."
569
- " Please install `wandb` or `comet-ml` to resolve."
570
- )
571
-
572
- if model is not None:
573
- self.is_encoder_decoder = model.config.is_encoder_decoder
574
- elif args.is_encoder_decoder is None:
575
- raise ValueError("When no model is provided, you need to pass the parameter is_encoder_decoder.")
576
- else:
577
- self.is_encoder_decoder = args.is_encoder_decoder
578
-
579
- self.is_peft_model = is_peft_available() and isinstance(model, PeftModel)
580
- self.model_adapter_name = model_adapter_name
581
- self.ref_adapter_name = ref_adapter_name
582
-
583
- if ref_model:
584
- self.ref_model = ref_model
585
- elif self.is_peft_model or args.precompute_ref_log_probs:
586
- # The `model` with adapters turned off will be used as the reference model
587
- self.ref_model = None
588
- else:
589
- self.ref_model = create_reference_model(model)
590
-
591
- if processing_class is None:
592
- raise ValueError(
593
- "max_length or a processing_class must be specified when using the default DPODataCollatorWithPadding"
594
- )
595
- if args.max_length is None:
596
- warnings.warn(
597
- "When using DPODataCollatorWithPadding, you should set `max_length` in the `BCOConfig`. "
598
- "It will be set to `512` by default, but you should do it yourself in the future.",
599
- UserWarning,
600
- )
601
- max_length = 512
602
- if args.max_length is not None:
603
- max_length = args.max_length
604
-
605
- if args.max_prompt_length is None:
606
- warnings.warn(
607
- "When using DPODataCollatorWithPadding, you should set `max_prompt_length` in the `BCOConfig`. "
608
- "It will be set to `128` by default, but you should do it yourself in the future.",
609
- UserWarning,
610
- )
611
- max_prompt_length = 128
612
- if args.max_prompt_length is not None:
613
- max_prompt_length = args.max_prompt_length
614
-
615
- max_completion_length = None
616
- if args.max_completion_length is None and self.is_encoder_decoder:
617
- warnings.warn(
618
- "When using DPODataCollatorWithPadding with an encoder decoder architecture, you should set `max_completion_length` in the BCOTrainer's init"
619
- " it will be set to `128` by default, but you should do it yourself in the future.",
620
- UserWarning,
621
- )
622
- max_completion_length = 128
623
- if args.max_completion_length is not None and self.is_encoder_decoder:
624
- max_completion_length = args.max_completion_length
625
-
626
- if data_collator is None:
627
- data_collator = DPODataCollatorWithPadding(
628
- pad_token_id=processing_class.pad_token_id,
629
- label_pad_token_id=args.label_pad_token_id,
630
- is_encoder_decoder=self.is_encoder_decoder,
631
- )
632
-
633
- if args.remove_unused_columns:
634
- args.remove_unused_columns = False
635
- # warn users
636
- warnings.warn(
637
- "When using DPODataCollatorWithPadding, you should set `remove_unused_columns=False` in your BCOConfig"
638
- " we have set it for you, but you should do it yourself in the future.",
639
- UserWarning,
640
- )
641
-
642
- self.use_dpo_data_collator = True
643
- else:
644
- self.use_dpo_data_collator = False
645
-
646
- # Disable dropout in the model and reference model
647
- if args.disable_dropout:
648
- disable_dropout_in_model(model)
649
- if self.ref_model is not None:
650
- disable_dropout_in_model(self.ref_model)
651
-
652
- self.max_length = max_length
653
- self.generate_during_eval = args.generate_during_eval
654
- self.label_pad_token_id = args.label_pad_token_id
655
- self.padding_value = args.padding_value if args.padding_value is not None else processing_class.pad_token_id
656
- self.max_prompt_length = max_prompt_length
657
- self.truncation_mode = args.truncation_mode
658
- self.max_completion_length = max_completion_length
659
- self.precompute_ref_log_probs = args.precompute_ref_log_probs
660
-
661
- # Since ref log probs are precomputed on the first call to get_train/eval_dataloader,
662
- # keep track of whether that first call has happened to avoid recomputing them on later calls
663
- self._precomputed_train_ref_log_probs = False
664
- self._precomputed_eval_ref_log_probs = False
665
-
666
- # metric
667
- self._stored_metrics = defaultdict(lambda: defaultdict(list))
668
-
669
- # BCO parameter
670
- self.beta = args.beta
671
- self.aux_loss_enabled = getattr(model.config, "output_router_logits", False)
672
- self.aux_loss_coef = getattr(model.config, "router_aux_loss_coef", 0.0)
673
- if self.aux_loss_enabled and self.aux_loss_coef == 0.0:
674
- warnings.warn(
675
- "You set `output_router_logits` to `True` in the model config, but `router_aux_loss_coef` is set to "
676
- "`0.0`, meaning the auxiliary loss will not be used. Either set `router_aux_loss_coef` to a value "
677
- "greater than `0.0`, or set `output_router_logits` to `False` if you don't want to use the auxiliary "
678
- "loss.",
679
- UserWarning,
680
- )
681
-
682
- # Underlying Distribution Matching argument
683
- self.embedding_func = embedding_func
684
- self.embedding_tokenizer = embedding_tokenizer
685
-
686
- # The trainer estimates the number of FLOPs [floating-point operations] using the number of elements in the
687
- # input tensor associated with the key "input_ids". However, in BCO, the sampled data does not include the
688
- # "input_ids" key. Instead, the available keys are "prompt_input_ids" and "completion_input_ids". As a result,
689
- # the trainer issues the warning: "Could not estimate the number of tokens of the input, floating-point
690
- # operations will not be computed." To suppress this warning, we set the "estimate_tokens" key in the model's
691
- # "warnings_issued" dictionary to True. This acts as a flag to indicate that the warning has already been
692
- # issued.
693
- model.warnings_issued["estimate_tokens"] = True
694
-
695
- with PartialState().main_process_first():
696
- # Apply the chat template if needed
697
- train_dataset = train_dataset.map(
698
- maybe_apply_chat_template, fn_kwargs={"tokenizer": processing_class}, num_proc=args.dataset_num_proc
699
- )
700
- if eval_dataset is not None:
701
- eval_dataset = eval_dataset.map(
702
- maybe_apply_chat_template,
703
- fn_kwargs={"tokenizer": processing_class},
704
- num_proc=args.dataset_num_proc,
705
- )
706
- # Shuffle the datasets
707
- train_dataset = train_dataset.shuffle(seed=args.data_seed)
708
- if eval_dataset is not None:
709
- eval_dataset = eval_dataset.shuffle(seed=args.data_seed)
710
- # Tokenize and prepare the training datasets
711
- train_dataset = train_dataset.map(
712
- _tokenize,
713
- batched=True,
714
- fn_kwargs={"tokenizer": processing_class, "embedding_tokenizer": self.embedding_tokenizer},
715
- num_proc=args.dataset_num_proc,
716
- desc="Tokenizing train dataset",
717
- )
718
-
719
- # Prepare the datasets
720
- fn_kwargs = {
721
- "prefix": "",
722
- "is_encoder_decoder": self.is_encoder_decoder,
723
- "tokenizer": processing_class,
724
- "max_length": self.max_length,
725
- "truncation_mode": self.truncation_mode,
726
- "label_pad_token_id": self.label_pad_token_id,
727
- "max_prompt_length": self.max_prompt_length,
728
- "max_completion_length": self.max_completion_length,
729
- }
730
- train_dataset = train_dataset.map(
731
- _process_tokens,
732
- fn_kwargs=fn_kwargs,
733
- num_proc=args.dataset_num_proc,
734
- desc="Processing tokenized train dataset",
735
- )
736
-
737
- if eval_dataset is not None:
738
- # Tokenize
739
- eval_dataset = eval_dataset.map(
740
- _tokenize,
741
- fn_kwargs={"tokenizer": processing_class, "embedding_tokenizer": self.embedding_tokenizer},
742
- batched=True,
743
- num_proc=args.dataset_num_proc,
744
- desc="Tokenizing eval dataset",
745
- )
746
-
747
- # Process
748
- fn_kwargs = {
749
- "prefix": "",
750
- "is_encoder_decoder": self.is_encoder_decoder,
751
- "tokenizer": processing_class,
752
- "max_length": self.max_length,
753
- "truncation_mode": self.truncation_mode,
754
- "label_pad_token_id": self.label_pad_token_id,
755
- "max_prompt_length": self.max_prompt_length,
756
- "max_completion_length": self.max_completion_length,
757
- }
758
- eval_dataset = eval_dataset.map(
759
- _process_tokens,
760
- fn_kwargs=fn_kwargs,
761
- num_proc=args.dataset_num_proc,
762
- desc="Processing tokenized eval dataset",
763
- )
764
-
765
- desirable = train_dataset.filter(
766
- lambda x: x["label"], num_proc=args.dataset_num_proc, desc="Filtering desirable examples"
767
- )
768
- undesirable = train_dataset.filter(
769
- lambda x: not x["label"], num_proc=args.dataset_num_proc, desc="Filtering undesirable examples"
770
- )
771
-
772
- desirable = desirable.shuffle(seed=args.data_seed)
773
- undesirable = undesirable.shuffle(seed=args.data_seed)
774
-
775
- super().__init__(
776
- model=model,
777
- args=args,
778
- data_collator=data_collator,
779
- train_dataset=train_dataset,
780
- eval_dataset=eval_dataset,
781
- processing_class=processing_class,
782
- model_init=model_init,
783
- compute_metrics=compute_metrics,
784
- callbacks=callbacks,
785
- optimizers=optimizers,
786
- preprocess_logits_for_metrics=preprocess_logits_for_metrics,
787
- )
788
-
789
- # Gradient accumulation requires scaled loss. Normally, loss scaling in the parent class depends on whether the
790
- # model accepts loss-related kwargs. Since we compute our own loss, this check is irrelevant. We set
791
- # self.model_accepts_loss_kwargs to False to enable scaling.
792
- self.model_accepts_loss_kwargs = False
793
-
794
- # Add tags for models that have been loaded with the correct transformers version
795
- if hasattr(self.model, "add_model_tags"):
796
- self.model.add_model_tags(self._tag_names)
797
-
798
- if not hasattr(self, "accelerator"):
799
- raise AttributeError(
800
- "Your `Trainer` does not have an `accelerator` object. Consider upgrading `transformers`."
801
- )
802
-
803
- # Deepspeed Zero-3 does not support precompute_ref_log_probs
804
- if self.is_deepspeed_enabled:
805
- if self.accelerator.state.deepspeed_plugin.zero_stage == 3 and self.precompute_ref_log_probs:
806
- raise ValueError(
807
- "You cannot use `precompute_ref_log_probs=True` with Deepspeed ZeRO-3. Please set `precompute_ref_log_probs=False`."
808
- )
809
-
810
- if self.ref_model is None:
811
- if not (self.is_peft_model or self.precompute_ref_log_probs):
812
- raise ValueError(
813
- "No reference model and model is not a Peft model. Try setting `precompute_ref_log_probs=True`"
814
- )
815
- else:
816
- if self.is_deepspeed_enabled:
817
- self.ref_model = self._prepare_deepspeed(self.ref_model)
818
- else:
819
- self.ref_model = self.accelerator.prepare_model(self.ref_model, evaluation_mode=True)
820
-
821
- self.running = RunningMoments(accelerator=self.accelerator)
822
-
823
- if self.embedding_func is None:
824
- return
825
-
826
- chosen_embeddings = self._get_sample_prompt_embeddings(desirable, sample_size=self.args.prompt_sample_size)
827
- rejected_embeddings = self._get_sample_prompt_embeddings(undesirable, sample_size=self.args.prompt_sample_size)
828
-
829
- embeddings = torch.cat((chosen_embeddings, rejected_embeddings), dim=0)
830
- labels = torch.cat(
831
- (torch.ones_like(chosen_embeddings[:, 0]), torch.zeros_like(rejected_embeddings[:, 0])), dim=0
832
- )
833
-
834
- self.clf = LogisticRegression(class_weight="balanced").fit(
835
- embeddings.cpu().float().numpy(), labels.cpu().numpy()
836
- )
837
-
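The logistic-regression fit above is what later supplies the density ratio for underlying distribution matching (UDM). A minimal sketch of the same fit on synthetic embeddings; the shapes and sample count here are illustrative, not taken from the trainer:

    import numpy as np
    from sklearn.linear_model import LogisticRegression

    chosen = np.random.randn(512, 768)    # sampled desirable prompt embeddings
    rejected = np.random.randn(512, 768)  # sampled undesirable prompt embeddings
    X = np.concatenate([chosen, rejected])
    y = np.concatenate([np.ones(512), np.zeros(512)])
    clf = LogisticRegression(class_weight="balanced").fit(X, y)
    p_desirable = clf.predict_proba(X)[:, 1]  # P(prompt comes from the desirable set)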
838
- @property
839
- def match_underlying_distribution(self):
840
- return self.embedding_func is not None and self.embedding_tokenizer is not None
841
-
842
- def _get_chosen_prob(self, prompt_embeddings: torch.FloatTensor) -> torch.FloatTensor:
843
- """
844
- Calculates the probability that the given prompt embedding comes from the desirable dataset.
845
- The probability is computed on each process and then ensembled across processes.
846
- """
847
- dtype = prompt_embeddings.dtype
848
- device = prompt_embeddings.device
849
- rank = self.accelerator.process_index
850
-
851
- padded_prompt_embeddings = self.accelerator.pad_across_processes(
852
- prompt_embeddings, pad_index=self.embedding_tokenizer.pad_token_id
853
- )
854
- sample_size = padded_prompt_embeddings.shape[0]
855
- nonzero = padded_prompt_embeddings.mean(dim=1) != self.embedding_tokenizer.pad_token_id
856
- prompt_embeddings = self.accelerator.gather(padded_prompt_embeddings)
857
-
858
- # cannot predict for all empty values
859
- if prompt_embeddings.shape[0] == 0:
860
- return torch.tensor([], device=device, dtype=dtype)
861
-
862
- prob = self.clf.predict_proba(prompt_embeddings.cpu().float().numpy())[:, 1]
863
- prob = torch.as_tensor(prob, dtype=dtype, device=device)
864
- prob = self.accelerator.reduce(prob, reduction="mean")
865
-
866
- prob = prob[sample_size * rank : sample_size * (rank + 1)]
867
- prob = prob[nonzero]
868
-
869
- return prob
870
-
871
- def _vectorize_prompt(self, input_ids: torch.LongTensor, attention_mask: torch.LongTensor) -> torch.FloatTensor:
872
- """
873
- Replaces processing_class.pad_token_id with embedding_tokenizer.pad_token_id
874
- and applies self.embedding_func.
875
- """
876
- input_ids = torch.where(
877
- input_ids == self.processing_class.pad_token_id,
878
- self.embedding_tokenizer.pad_token_id,
879
- input_ids,
880
- )
881
-
882
- with torch.no_grad():
883
- embeddings = self.embedding_func(
884
- input_ids=input_ids,
885
- attention_mask=attention_mask,
886
- )
887
-
888
- return embeddings
889
-
890
- def _get_prompt_embeddings(
891
- self, batch: dict[str, Union[list, torch.LongTensor]]
892
- ) -> tuple[torch.FloatTensor, torch.FloatTensor]:
893
- """Extract embeddings from frozen embedding model"""
894
-
895
- if not self.match_underlying_distribution:
896
- return None, None
897
-
898
- embeddings = self._vectorize_prompt(
899
- input_ids=batch["embedding_input_ids"],
900
- attention_mask=batch["embedding_attention_mask"],
901
- )
902
-
903
- chosen_idx = [i for i in range(len(batch["label"])) if batch["label"][i] is True]
904
- rejected_idx = [i for i in range(len(batch["label"])) if batch["label"][i] is False]
905
-
906
- chosen_embeddings = embeddings[chosen_idx, ...]
907
- rejected_embeddings = embeddings[rejected_idx, ...]
908
-
909
- return (chosen_embeddings, rejected_embeddings)
910
-
911
- def _get_sample_prompt_embeddings(self, dataset: Dataset, sample_size: int = 512) -> torch.FloatTensor:
912
- """
913
- Sample instances from dataset and get prompt embeddings.
914
- Used for density ratio classifier training.
915
- """
916
- n_samples = min(len(dataset), sample_size)
917
- rand_indices = np.random.choice(len(dataset), size=(n_samples,))
918
-
919
- embedding_dataset = dataset.select(rand_indices)
920
-
921
- dataloader_params = {
922
- "batch_size": self.args.per_device_train_batch_size,
923
- "collate_fn": self.data_collator,
924
- "num_workers": self.args.dataloader_num_workers,
925
- "pin_memory": self.args.dataloader_pin_memory,
926
- "shuffle": False,
927
- }
928
-
929
- # prepare dataloader
930
- data_loader = self.accelerator.prepare(DataLoader(embedding_dataset, **dataloader_params))
931
-
932
- with torch.no_grad():
933
- all_embeddings = torch.empty(0)
934
- for padded_batch in tqdm(iterable=data_loader, desc="Building sample prompt embeddings"):
935
- embeddings = self._vectorize_prompt(
936
- input_ids=padded_batch["embedding_input_ids"],
937
- attention_mask=padded_batch["embedding_attention_mask"],
938
- )
939
- embeddings = self.accelerator.gather_for_metrics(embeddings)
940
- all_embeddings = torch.cat((all_embeddings, embeddings.cpu()))
941
-
942
- return all_embeddings
943
-
944
- def _prepare_deepspeed(self, model: PreTrainedModelWrapper):
945
- # Adapted from accelerate: https://github.com/huggingface/accelerate/blob/739b135f8367becb67ffaada12fe76e3aa60fefd/src/accelerate/accelerator.py#L1473
946
- deepspeed_plugin = self.accelerator.state.deepspeed_plugin
947
- config_kwargs = deepcopy(deepspeed_plugin.deepspeed_config)
948
-
949
- if model is not None:
950
- if hasattr(model, "config"):
951
- hidden_size = (
952
- max(model.config.hidden_sizes)
953
- if getattr(model.config, "hidden_sizes", None)
954
- else getattr(model.config, "hidden_size", None)
955
- )
956
- if hidden_size is not None and config_kwargs["zero_optimization"]["stage"] == 3:
957
- # Note that `stage3_prefetch_bucket_size` can produce DeepSpeed messages like: `Invalidate trace cache @ step 0: expected module 1, but got module 0`
958
- # This is expected and is not an error, see: https://github.com/microsoft/DeepSpeed/discussions/4081
959
- config_kwargs.update(
960
- {
961
- "zero_optimization.reduce_bucket_size": hidden_size * hidden_size,
962
- "zero_optimization.stage3_param_persistence_threshold": 10 * hidden_size,
963
- "zero_optimization.stage3_prefetch_bucket_size": 0.9 * hidden_size * hidden_size,
964
- }
965
- )
966
-
967
- # If ZeRO-3 is used, we shard both the active and reference model.
968
- # Otherwise, we assume the reference model fits in memory and is initialized on each device with ZeRO disabled (stage 0)
969
- if config_kwargs["zero_optimization"]["stage"] != 3:
970
- config_kwargs["zero_optimization"]["stage"] = 0
971
- model, *_ = deepspeed.initialize(model=model, config=config_kwargs)
972
- model.eval()
973
- return model
974
-
975
- def _save_optimizer_and_scheduler(self, output_dir):
976
- super()._save_optimizer_and_scheduler(output_dir)
977
-
978
- # When saving the optimizer and scheduler to a checkpoint, also save the running delta object.
979
- output_dir = output_dir if output_dir is not None else self.args.output_dir
980
-
981
- self.running.save_to_json(os.path.join(output_dir, RUNNING_NAME))
982
-
983
- if self.match_underlying_distribution:
984
- torch.save(self.clf.get_params(), os.path.join(output_dir, CLF_NAME))
985
-
986
- def _load_optimizer_and_scheduler(self, checkpoint):
987
- super()._load_optimizer_and_scheduler(checkpoint)
988
-
989
- if checkpoint is None:
990
- return
991
- # when loading optimizer and scheduler from checkpoint, also load the running delta object.
992
- running_file = os.path.join(checkpoint, RUNNING_NAME)
993
- if os.path.isfile(running_file):
994
- self.running = RunningMoments.load_from_json(self.accelerator, running_file)
995
-
996
- if self.match_underlying_distribution:
997
- clf_file = os.path.join(checkpoint, CLF_NAME)
998
- if os.path.isfile(clf_file):
999
- self.clf.set_params(**torch.load(clf_file, weights_only=True, map_location="cpu"))
1000
-
1001
- @contextmanager
1002
- def null_ref_context(self):
1003
- """Context manager for handling null reference model (that is, peft adapter manipulation)."""
1004
- with (
1005
- self.accelerator.unwrap_model(self.model).disable_adapter()
1006
- if self.is_peft_model and not self.ref_adapter_name
1007
- else nullcontext()
1008
- ):
1009
- if self.ref_adapter_name:
1010
- self.model.set_adapter(self.ref_adapter_name)
1011
- yield
1012
- if self.ref_adapter_name:
1013
- self.model.set_adapter(self.model_adapter_name or "default")
1014
-
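When no explicit reference model is passed, the PEFT model's frozen base weights stand in for it; the context manager above simply turns the adapters off (or swaps to a dedicated reference adapter). A minimal sketch of that idea, assuming a peft PeftModel; the helper name is hypothetical:

    import torch
    from contextlib import nullcontext

    def reference_logits(model, batch, is_peft_model):
        # Disabling the adapters makes the frozen base weights act as the
        # reference model, so no second model copy is needed in memory.
        ctx = model.disable_adapter() if is_peft_model else nullcontext()
        with ctx, torch.no_grad():
            return model(**batch).logits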
1015
- def get_train_dataloader(self) -> DataLoader:
1016
- """
1017
- Returns the training [`~torch.utils.data.DataLoader`].
1018
-
1019
- Overrides `transformers.Trainer.get_train_dataloader` to precompute `ref_log_probs`.
1020
- """
1021
-
1022
- if self.precompute_ref_log_probs and not self._precomputed_train_ref_log_probs:
1023
- dataloader_params = {
1024
- "batch_size": self.args.per_device_train_batch_size,
1025
- "collate_fn": self.data_collator,
1026
- "num_workers": self.args.dataloader_num_workers,
1027
- "pin_memory": self.args.dataloader_pin_memory,
1028
- "shuffle": False,
1029
- }
1030
-
1031
- # prepare dataloader
1032
- data_loader = self.accelerator.prepare(DataLoader(self.train_dataset, **dataloader_params))
1033
- reference_completion_logps = []
1034
-
1035
- for padded_batch in tqdm(iterable=data_loader, desc="Train dataset reference log probs"):
1036
- reference_completion_logp = self.compute_reference_log_probs(padded_batch)
1037
-
1038
- reference_completion_logp = self.accelerator.gather_for_metrics(reference_completion_logp)
1039
- reference_completion_logps.append(reference_completion_logp.cpu())
1040
-
1041
- self.train_dataset = self.train_dataset.add_column(
1042
- name="reference_logps", column=torch.cat(reference_completion_logps).float().numpy()
1043
- )
1044
-
1045
- self._precomputed_train_ref_log_probs = True
1046
-
1047
- return super().get_train_dataloader()
1048
-
1049
- def get_eval_dataloader(self, eval_dataset: Optional[Dataset] = None) -> DataLoader:
1050
- """
1051
- Returns the evaluation [`~torch.utils.data.DataLoader`].
1052
-
1053
- Overrides `transformers.Trainer.get_eval_dataloader` to precompute `ref_log_probs`.
1054
-
1055
- Args:
1056
- eval_dataset (`torch.utils.data.Dataset`, *optional*):
1057
- If provided, will override `self.eval_dataset`. If it is a [`~datasets.Dataset`], columns not accepted
1058
- by the `model.forward()` method are automatically removed. It must implement `__len__`.
1059
- """
1060
- if eval_dataset is None and self.eval_dataset is None:
1061
- raise ValueError("Trainer: evaluation requires an eval_dataset.")
1062
- eval_dataset = eval_dataset if eval_dataset is not None else self.eval_dataset
1063
-
1064
- if self.precompute_ref_log_probs and not self._precomputed_eval_ref_log_probs:
1065
- dataloader_params = {
1066
- "batch_size": self.args.per_device_eval_batch_size,
1067
- "collate_fn": self.data_collator,
1068
- "num_workers": self.args.dataloader_num_workers,
1069
- "pin_memory": self.args.dataloader_pin_memory,
1070
- "shuffle": False,
1071
- }
1072
-
1073
- # prepare dataloader
1074
- data_loader = self.accelerator.prepare(DataLoader(eval_dataset, **dataloader_params))
1075
-
1076
- reference_completion_logps = []
1077
-
1078
- for padded_batch in tqdm(iterable=data_loader, desc="Eval dataset reference log probs"):
1079
- reference_completion_logp = self.compute_reference_log_probs(padded_batch)
1080
-
1081
- reference_completion_logp = self.accelerator.gather_for_metrics(reference_completion_logp)
1082
- reference_completion_logps.append(reference_completion_logp.cpu())
1083
-
1084
- eval_dataset = eval_dataset.add_column(
1085
- name="reference_logps", column=torch.cat(reference_completion_logps).float().numpy()
1086
- )
1087
-
1088
- # Save calculated reference_chosen_logps and reference_rejected_logps to the eval_dataset for subsequent runs
1089
- if self.eval_dataset is not None:
1090
- self.eval_dataset = eval_dataset
1091
- self._precomputed_eval_ref_log_probs = True
1092
-
1093
- return super().get_eval_dataloader(eval_dataset=eval_dataset)
1094
-
1095
- def compute_reference_log_probs(self, padded_batch: dict) -> dict:
1096
- """Computes log probabilities of the reference model for a single padded batch of a BCO specific dataset."""
1097
- with torch.no_grad():
1098
- if self.ref_model is None:
1099
- with self.null_ref_context():
1100
- if self.is_encoder_decoder:
1101
- completion_logits = self.model(
1102
- padded_batch["prompt_input_ids"],
1103
- attention_mask=padded_batch["prompt_attention_mask"],
1104
- decoder_input_ids=padded_batch.get("completion_decoder_input_ids"),
1105
- labels=padded_batch["completion_labels"],
1106
- ).logits
1107
-
1108
- else:
1109
- completion_logits = self.model(
1110
- padded_batch["completion_input_ids"],
1111
- attention_mask=padded_batch["completion_attention_mask"],
1112
- ).logits
1113
-
1114
- else:
1115
- if self.is_encoder_decoder:
1116
- completion_logits = self.ref_model(
1117
- padded_batch["prompt_input_ids"],
1118
- attention_mask=padded_batch["prompt_attention_mask"],
1119
- decoder_input_ids=padded_batch.get("completion_decoder_input_ids"),
1120
- labels=padded_batch["completion_labels"],
1121
- ).logits
1122
-
1123
- else:
1124
- completion_logits = self.ref_model(
1125
- padded_batch["completion_input_ids"], attention_mask=padded_batch["completion_attention_mask"]
1126
- ).logits
1127
-
1128
- completion_logps = self.get_batch_logps(
1129
- completion_logits,
1130
- padded_batch["completion_labels"],
1131
- average_log_prob=False,
1132
- is_encoder_decoder=self.is_encoder_decoder,
1133
- label_pad_token_id=self.label_pad_token_id,
1134
- )
1135
-
1136
- return completion_logps
1137
-
1138
- @staticmethod
1139
- def get_batch_logps(
1140
- logits: torch.FloatTensor,
1141
- labels: torch.LongTensor,
1142
- average_log_prob: bool = False,
1143
- label_pad_token_id: int = -100,
1144
- is_encoder_decoder: bool = False,
1145
- ) -> torch.FloatTensor:
1146
- """Compute the log probabilities of the given labels under the given logits.
1147
-
1148
- Args:
1149
- logits: Logits of the model (unnormalized). Shape: (batch_size, sequence_length, vocab_size)
1150
- labels: Labels for which to compute the log probabilities. Label tokens with a value of label_pad_token_id are ignored. Shape: (batch_size, sequence_length)
1151
- average_log_prob: If True, return the average log probability per (non-masked) token. Otherwise, return the sum of the log probabilities of the (non-masked) tokens.
1152
-
1153
- Returns:
1154
- A tensor of shape (batch_size,) containing the average/sum log probabilities of the given labels under the given logits.
1155
- """
1156
- if logits.shape[:-1] != labels.shape:
1157
- raise ValueError("Logits (batch and sequence length dim) and labels must have the same shape.")
1158
-
1159
- if not is_encoder_decoder:
1160
- labels = labels[:, 1:].clone()
1161
- logits = logits[:, :-1, :]
1162
- else:
1163
- # Fixes enc-dec RuntimeError
1164
- labels = labels.clone()
1165
-
1166
- loss_mask = labels != label_pad_token_id
1167
-
1168
- # dummy token; we'll ignore the losses on these tokens later
1169
- labels[labels == label_pad_token_id] = 0
1170
-
1171
- per_token_logps = selective_log_softmax(logits, labels)
1172
-
1173
- if average_log_prob:
1174
- return (per_token_logps * loss_mask).sum(-1) / loss_mask.sum(-1)
1175
- else:
1176
- return (per_token_logps * loss_mask).sum(-1)
1177
-
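The reduction above amounts to a masked gather over the vocabulary dimension. An equivalent sketch without selective_log_softmax, assuming a decoder-only model (so labels are shifted by one position):

    import torch

    def batch_logps_sketch(logits, labels, pad_id=-100):
        labels = labels[:, 1:].clone()         # token t is predicted at position t - 1
        logits = logits[:, :-1, :]
        mask = labels != pad_id
        labels = labels.masked_fill(~mask, 0)  # dummy index; masked out below
        per_token = torch.log_softmax(logits, dim=-1).gather(2, labels.unsqueeze(2)).squeeze(2)
        return (per_token * mask).sum(-1)      # summed log-prob per sequence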
1178
- def forward(
1179
- self, model: nn.Module, batch: dict[str, Union[list, torch.LongTensor]]
1180
- ) -> tuple[torch.FloatTensor, torch.FloatTensor, torch.FloatTensor, torch.FloatTensor]:
1181
- model_kwargs = (
1182
- {
1183
- "labels": batch["completion_labels"],
1184
- "decoder_input_ids": batch.get("completion_decoder_input_ids"),
1185
- }
1186
- if self.is_encoder_decoder
1187
- else {}
1188
- )
1189
- if self.aux_loss_enabled:
1190
- model_kwargs["output_router_logits"] = True
1191
-
1192
- outputs = model(
1193
- batch["completion_input_ids"],
1194
- attention_mask=batch["completion_attention_mask"],
1195
- **model_kwargs,
1196
- )
1197
- completion_logits = outputs.logits
1198
-
1199
- completion_logps = self.get_batch_logps(
1200
- completion_logits,
1201
- batch["completion_labels"],
1202
- average_log_prob=False,
1203
- is_encoder_decoder=self.is_encoder_decoder,
1204
- label_pad_token_id=self.label_pad_token_id,
1205
- )
1206
-
1207
- if completion_logps.shape[0] != len(batch["label"]):
1208
- raise ValueError(
1209
- "There is a mismatch between the number of examples in this batch and the number of "
1210
- "examples for which an output sequence was predicted."
1211
- )
1212
-
1213
- chosen_idx = [i for i in range(completion_logps.shape[0]) if batch["label"][i] is True]
1214
- rejected_idx = [i for i in range(completion_logps.shape[0]) if batch["label"][i] is False]
1215
-
1216
- chosen_logps = completion_logps[chosen_idx, ...]
1217
- rejected_logps = completion_logps[rejected_idx, ...]
1218
-
1219
- chosen_logits = completion_logits[chosen_idx, ...]
1220
- rejected_logits = completion_logits[rejected_idx, ...]
1221
-
1222
- if self.aux_loss_enabled:
1223
- return (chosen_logps, rejected_logps, chosen_logits, rejected_logits, outputs.aux_loss)
1224
- else:
1225
- return (chosen_logps, rejected_logps, chosen_logits, rejected_logits)
1226
-
1227
- def _get_udm_weight(self, rejected_embeddings: torch.FloatTensor) -> torch.FloatTensor:
1228
- prob_desirable = self._get_chosen_prob(rejected_embeddings)
1229
- min_ratio = self.args.min_density_ratio
1230
- max_ratio = self.args.max_density_ratio
1231
-
1232
- weight = (prob_desirable / (1 - prob_desirable + 1e-8)).clamp(min=min_ratio, max=max_ratio)
1233
-
1234
- return weight
1235
-
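The weight above is the standard density-ratio trick: if p is the classifier's probability that an undesirable prompt comes from the desirable distribution, the importance weight is p / (1 - p), clamped to [min_density_ratio, max_density_ratio]. A quick numeric sketch with illustrative clamp bounds:

    import torch

    prob_desirable = torch.tensor([0.2, 0.5, 0.9])
    weight = (prob_desirable / (1 - prob_desirable + 1e-8)).clamp(min=0.5, max=10.0)
    # -> tensor([0.5000, 1.0000, 9.0000])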
1236
- def bco_loss(
1237
- self,
1238
- policy_chosen_logps: torch.FloatTensor,
1239
- policy_rejected_logps: torch.FloatTensor,
1240
- reference_chosen_logps: torch.FloatTensor,
1241
- reference_rejected_logps: torch.FloatTensor,
1242
- chosen_embeddings: Optional[torch.FloatTensor],
1243
- rejected_embeddings: Optional[torch.FloatTensor],
1244
- ) -> tuple[torch.FloatTensor, torch.FloatTensor, torch.FloatTensor, torch.FloatTensor]:
1245
- """Compute the BCO loss for a batch of policy and reference model log probabilities.
1246
-
1247
- Args:
1248
- policy_chosen_logps: Log probabilities of the policy model for the chosen responses. Shape: (num(chosen) in batch_size,)
1249
- policy_rejected_logps: Log probabilities of the policy model for the rejected responses. Shape: (num(rejected) in batch_size,)
1250
- reference_chosen_logps: Log probabilities of the reference model for the chosen responses. Shape: (num(chosen) in batch_size,)
1251
- reference_rejected_logps: Log probabilities of the reference model for the rejected responses. Shape: (num(rejected) in batch_size,)
1252
- chosen_embeddings: embeddings of desirable prompts
1253
- rejected_embeddings: embeddings of undesirable prompts
1254
-
1255
- Returns:
1256
- A tuple of four tensors: (losses, chosen_rewards, rejected_rewards, delta).
1257
- The losses tensor contains the BCO loss for each example in the batch.
1258
- The chosen_rewards and rejected_rewards tensors contain the rewards for the chosen and rejected responses, respectively.
1259
- The delta value contains the moving average of all implicit rewards.
1260
- """
1261
-
1262
- if policy_chosen_logps.shape[0] != 0 or reference_chosen_logps.shape[0] != 0:
1263
- chosen_logratios = policy_chosen_logps - reference_chosen_logps
1264
- chosen_rewards = self.beta * chosen_logratios
1265
- else:
1266
- # lists can't be empty -- if they are, then accelerate.gather will hang
1267
- chosen_losses = torch.Tensor([]).to(self.accelerator.device)
1268
- chosen_rewards = torch.Tensor([]).to(self.accelerator.device)
1269
-
1270
- if policy_rejected_logps.shape[0] != 0 or reference_rejected_logps.shape[0] != 0:
1271
- rejected_logratios = policy_rejected_logps - reference_rejected_logps
1272
- rejected_rewards = self.beta * rejected_logratios
1273
- else:
1274
- # lists can't be empty -- if they are, then accelerate.gather will hang
1275
- rejected_losses = torch.Tensor([]).to(self.accelerator.device)
1276
- rejected_rewards = torch.Tensor([]).to(self.accelerator.device)
1277
-
1278
- rewards = torch.cat((chosen_rewards, rejected_rewards), 0).mean().detach()
1279
- self.running.update(rewards)
1280
- delta = self.running.mean
1281
-
1282
- if policy_chosen_logps.shape[0] != 0 or reference_chosen_logps.shape[0] != 0:
1283
- chosen_losses = -F.logsigmoid(chosen_rewards - delta)
1284
-
1285
- if policy_rejected_logps.shape[0] != 0 or reference_rejected_logps.shape[0] != 0:
1286
- rejected_losses = -F.logsigmoid(-(rejected_rewards - delta))
1287
-
1288
- if self.match_underlying_distribution:
1289
- chosen_weight = torch.ones_like(chosen_losses)
1290
- rejected_weight = self._get_udm_weight(rejected_embeddings)
1291
-
1292
- losses = torch.cat((chosen_weight * chosen_losses, rejected_weight * rejected_losses), dim=0)
1293
- else:
1294
- losses = torch.cat((chosen_losses, rejected_losses), dim=0)
1295
-
1296
- return losses, chosen_rewards, rejected_rewards, torch.as_tensor(delta)
1297
-
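Stripped of the distributed bookkeeping, the objective above is: implicit reward r = beta * (log pi(y|x) - log pi_ref(y|x)), baseline delta = running mean of all rewards, loss -logsigmoid(r - delta) for desirable completions and -logsigmoid(-(r - delta)) for undesirable ones. A compact sketch assuming plain tensors and a boolean labels mask; the function name is illustrative:

    import torch
    import torch.nn.functional as F

    def bco_loss_sketch(policy_logps, ref_logps, labels, beta, delta):
        rewards = beta * (policy_logps - ref_logps)            # implicit rewards
        chosen = -F.logsigmoid(rewards[labels] - delta)        # push above the baseline
        rejected = -F.logsigmoid(-(rewards[~labels] - delta))  # push below the baseline
        return torch.cat((chosen, rejected)).mean()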
1298
- def get_batch_loss_metrics(
1299
- self,
1300
- model,
1301
- batch: dict[str, Union[list, torch.LongTensor]],
1302
- ):
1303
- """Compute the BCO loss and other metrics for the given batch of inputs for train or test."""
1304
- metrics = {}
1305
- batch = {k: (v.to(self.accelerator.device) if isinstance(v, torch.Tensor) else v) for k, v in batch.items()}
1306
-
1307
- forward_output = self.forward(model, batch)
1308
- (
1309
- policy_chosen_logps,
1310
- policy_rejected_logps,
1311
- policy_chosen_logits,
1312
- policy_rejected_logits,
1313
- ) = forward_output[:4]
1314
- if self.aux_loss_enabled:
1315
- aux_loss = forward_output[4]
1316
-
1317
- # if reference_logps in batch use them, otherwise use the reference model
1318
- if "reference_logps" in batch:
1319
- chosen_idx = [i for i in range(batch["reference_logps"].shape[0]) if batch["label"][i] is True]
1320
- rejected_idx = [i for i in range(batch["reference_logps"].shape[0]) if batch["label"][i] is False]
1321
-
1322
- reference_chosen_logps = batch["reference_logps"][chosen_idx, ...]
1323
- reference_rejected_logps = batch["reference_logps"][rejected_idx, ...]
1324
- else:
1325
- with torch.no_grad():
1326
- if self.ref_model is None:
1327
- with self.null_ref_context():
1328
- (
1329
- reference_chosen_logps,
1330
- reference_rejected_logps,
1331
- _,
1332
- _,
1333
- ) = self.forward(self.model, batch)[:4]
1334
- else:
1335
- (
1336
- reference_chosen_logps,
1337
- reference_rejected_logps,
1338
- _,
1339
- _,
1340
- ) = self.forward(self.ref_model, batch)[:4]
1341
-
1342
- chosen_embeddings, rejected_embeddings = self._get_prompt_embeddings(batch)
1343
-
1344
- losses, chosen_rewards, rejected_rewards, delta = self.bco_loss(
1345
- policy_chosen_logps,
1346
- policy_rejected_logps,
1347
- reference_chosen_logps,
1348
- reference_rejected_logps,
1349
- chosen_embeddings,
1350
- rejected_embeddings,
1351
- )
1352
- metrics["delta"] = self.accelerator.gather_for_metrics(delta).mean().item()
1353
-
1354
- num_chosen = torch.Tensor([len(chosen_rewards)]).to(self.accelerator.device)
1355
- num_rejected = torch.Tensor([len(rejected_rewards)]).to(self.accelerator.device)
1356
-
1357
- all_num_chosen = self.accelerator.gather_for_metrics(num_chosen).sum().item()
1358
- all_num_rejected = self.accelerator.gather_for_metrics(num_rejected).sum().item()
1359
-
1360
- if all_num_chosen > 0:
1361
- metrics["rewards/chosen_sum"] = (
1362
- self.accelerator.gather_for_metrics(chosen_rewards.nansum()).nansum().item()
1363
- )
1364
- metrics["logps/chosen_sum"] = (
1365
- self.accelerator.gather_for_metrics(policy_chosen_logps.nansum()).nansum().item()
1366
- )
1367
- metrics["logits/chosen_sum"] = (
1368
- self.accelerator.gather_for_metrics(policy_chosen_logits.nansum()).nansum().item()
1369
- )
1370
- metrics["count/chosen"] = all_num_chosen
1371
-
1372
- if all_num_rejected > 0:
1373
- metrics["rewards/rejected_sum"] = (
1374
- self.accelerator.gather_for_metrics(rejected_rewards.nansum()).nansum().item()
1375
- )
1376
- metrics["logps/rejected_sum"] = (
1377
- self.accelerator.gather_for_metrics(policy_rejected_logps.nansum()).nansum().item()
1378
- )
1379
- metrics["logits/rejected_sum"] = (
1380
- self.accelerator.gather_for_metrics(policy_rejected_logits.nansum()).nansum().item()
1381
- )
1382
- metrics["count/rejected"] = all_num_rejected
1383
-
1384
- loss = losses.nanmean()
1385
- if self.aux_loss_enabled:
1386
- loss += self.aux_loss_coef * aux_loss
1387
-
1388
- return loss, metrics
1389
-
1390
- def compute_loss(
1391
- self,
1392
- model: Union[PreTrainedModel, nn.Module],
1393
- inputs: dict[str, Union[torch.Tensor, Any]],
1394
- return_outputs=False,
1395
- num_items_in_batch=None,
1396
- ) -> Union[torch.Tensor, tuple[torch.Tensor, dict[str, torch.Tensor]]]:
1397
- compute_loss_context_manager = amp.autocast("cuda") if self._peft_has_been_casted_to_bf16 else nullcontext()
1398
-
1399
- with compute_loss_context_manager:
1400
- loss, metrics = self.get_batch_loss_metrics(model, inputs)
1401
-
1402
- # Make sure to move the loss to the device the original accumulating loss is at back in the `Trainer` class:
1403
- loss = loss.to(self.args.device)
1404
- # force log the metrics
1405
- if self.accelerator.is_main_process:
1406
- self.store_metrics(metrics, train_eval="train")
1407
-
1408
- if return_outputs:
1409
- return (loss, metrics)
1410
- return loss
1411
-
1412
- def store_metrics(self, metrics: dict[str, float], train_eval: Literal["train", "eval"] = "train") -> None:
1413
- for key, value in metrics.items():
1414
- self._stored_metrics[train_eval][key].append(value)
1415
-
1416
- def _get_train_sampler(self) -> Optional[torch.utils.data.Sampler]:
1417
- if self.train_dataset is None or not has_length(self.train_dataset):
1418
- return None
1419
- return SequentialSampler(self.train_dataset)
1420
-
1421
- def generate_from_model_and_ref(self, model, batch: dict[str, torch.LongTensor]) -> tuple[str, str]:
1422
- """Generate samples from the model and reference model for the given batch of inputs."""
1423
-
1424
- # If one uses `generate_during_eval` with peft + bf16, we need to explicitly call generate with
1425
- # the torch cuda amp context manager as some hidden states are silently casted to full precision.
1426
- generate_context_manager = amp.autocast("cuda") if self._peft_has_been_casted_to_bf16 else nullcontext()
1427
- with generate_context_manager:
1428
- policy_output = model.generate(
1429
- input_ids=batch["prompt_input_ids"],
1430
- attention_mask=batch["prompt_attention_mask"],
1431
- max_length=self.max_length,
1432
- do_sample=True,
1433
- pad_token_id=self.processing_class.pad_token_id,
1434
- )
1435
-
1436
- # if reference_output in batch use that otherwise use the reference model
1437
- if "reference_output" in batch:
1438
- reference_output = batch["reference_output"]
1439
- else:
1440
- if self.ref_model is None:
1441
- with self.null_ref_context():
1442
- reference_output = self.model.generate(
1443
- input_ids=batch["prompt_input_ids"],
1444
- attention_mask=batch["prompt_attention_mask"],
1445
- max_length=self.max_length,
1446
- do_sample=True,
1447
- pad_token_id=self.processing_class.pad_token_id,
1448
- )
1449
- else:
1450
- reference_output = self.ref_model.generate(
1451
- input_ids=batch["prompt_input_ids"],
1452
- attention_mask=batch["prompt_attention_mask"],
1453
- max_length=self.max_length,
1454
- do_sample=True,
1455
- pad_token_id=self.processing_class.pad_token_id,
1456
- )
1457
-
1458
- policy_output = pad_to_length(policy_output, self.max_length, self.processing_class.pad_token_id)
1459
- policy_output_decoded = self.processing_class.batch_decode(policy_output, skip_special_tokens=True)
1460
-
1461
- reference_output = pad_to_length(reference_output, self.max_length, self.processing_class.pad_token_id)
1462
- reference_output_decoded = self.processing_class.batch_decode(reference_output, skip_special_tokens=True)
1463
-
1464
- return policy_output_decoded, reference_output_decoded
1465
-
1466
- def prediction_step(
1467
- self,
1468
- model: Union[PreTrainedModel, nn.Module],
1469
- inputs: dict[str, Union[torch.Tensor, Any]],
1470
- prediction_loss_only: bool,
1471
- ignore_keys: Optional[list[str]] = None,
1472
- ):
1473
- if ignore_keys is None:
1474
- if hasattr(model, "config"):
1475
- ignore_keys = getattr(model.config, "keys_to_ignore_at_inference", [])
1476
- else:
1477
- ignore_keys = []
1478
-
1479
- prediction_context_manager = amp.autocast("cuda") if self._peft_has_been_casted_to_bf16 else nullcontext()
1480
- with torch.no_grad(), prediction_context_manager:
1481
- loss, metrics = self.get_batch_loss_metrics(model, inputs)
1482
-
1483
- # force log the metrics
1484
- if self.accelerator.is_main_process:
1485
- self.store_metrics(metrics, train_eval="eval")
1486
-
1487
- if prediction_loss_only:
1488
- return (loss.detach(), None, None)
1489
-
1490
- # logits for the chosen and rejected samples from model
1491
- logits_dict = {}
1492
- if "logits/chosen_sum" in metrics:
1493
- logits_dict["eval_logits/chosen"] = metrics["logits/chosen_sum"]
1494
- if "logits/rejected_sum" in metrics:
1495
- logits_dict["eval_logits/rejected"] = metrics["logits/rejected_sum"]
1496
- logits = [v for k, v in logits_dict.items() if k not in ignore_keys]
1497
- logits = torch.tensor(logits, device=self.accelerator.device)
1498
- labels = torch.zeros(logits.shape[0], device=self.accelerator.device)
1499
-
1500
- return (loss.detach(), logits, labels)
1501
-
1502
- def evaluation_loop(
1503
- self,
1504
- dataloader: DataLoader,
1505
- description: str,
1506
- prediction_loss_only: Optional[bool] = None,
1507
- ignore_keys: Optional[list[str]] = None,
1508
- metric_key_prefix: str = "eval",
1509
- ) -> EvalLoopOutput:
1510
- """
1511
- Overriding built-in evaluation loop to store metrics for each batch.
1512
- Prediction/evaluation loop, shared by `Trainer.evaluate()` and `Trainer.predict()`.
1513
-
1514
- Works both with or without labels.
1515
- """
1516
-
1517
- # Sample and save to game log if requested (for one batch to save time)
1518
- if self.generate_during_eval:
1519
- # Generate random indices within the range of the total number of samples
1520
- num_samples = len(dataloader.dataset)
1521
- random_indices = random.sample(range(num_samples), k=self.args.eval_batch_size)
1522
-
1523
- # Use dataloader.dataset.select to get the random batch without iterating over the DataLoader
1524
- random_batch_dataset = dataloader.dataset.select(random_indices)
1525
- random_batch = self.data_collator(random_batch_dataset)
1526
- random_batch = self._prepare_inputs(random_batch)
1527
-
1528
- target_indices = [i for i in range(len(random_batch["label"])) if random_batch["label"][i] is False]
1529
- target_batch = {
1530
- "prompt_input_ids": random_batch["prompt_input_ids"][target_indices],
1531
- "prompt_attention_mask": random_batch["prompt_attention_mask"][target_indices],
1532
- "prompt": itemgetter(*target_indices)(random_batch["prompt"]),
1533
- }
1534
- policy_output_decoded, ref_output_decoded = self.generate_from_model_and_ref(self.model, target_batch)
1535
-
1536
- table = pd.DataFrame(
1537
- columns=["Prompt", "Policy", "Ref Model"],
1538
- data=[
1539
- [prompt, pol[len(prompt) :], ref[len(prompt) :]]
1540
- for prompt, pol, ref in zip(target_batch["prompt"], policy_output_decoded, ref_output_decoded)
1541
- ],
1542
- )
1543
- if "wandb" in self.args.report_to:
1544
- wandb.log({"game_log": wandb.Table(data=table)})
1545
-
1546
- if "comet_ml" in self.args.report_to:
1547
- log_table_to_comet_experiment(
1548
- name="game_log.csv",
1549
- table=table,
1550
- )
1551
-
1552
- # Base evaluation
1553
- initial_output = super().evaluation_loop(
1554
- dataloader, description, prediction_loss_only, ignore_keys, metric_key_prefix
1555
- )
1556
-
1557
- return initial_output
1558
-
1559
- def log(self, logs: dict[str, float], start_time: Optional[float] = None) -> None:
1560
- """
1561
- Log `logs` on the various objects watching training, including stored metrics.
1562
-
1563
- Args:
1564
- logs (`dict[str, float]`):
1565
- The values to log.
1566
- start_time (`float` or `None`, *optional*, defaults to `None`):
1567
- Start time of the training.
1568
- """
1569
- # logs has either 'loss' or 'eval_loss'
1570
- train_eval = "train" if "loss" in logs else "eval"
1571
- # train metrics should have no prefix, eval should have 'eval_'
1572
- prefix = "eval_" if train_eval == "eval" else ""
1573
- # accumulate average metrics from sums and lengths
1574
- for split in ["chosen", "rejected"]:
1575
- if f"count/{split}" in self._stored_metrics[train_eval]:
1576
- count_sum = torch.Tensor(self._stored_metrics[train_eval][f"count/{split}"]).sum().item()
1577
- for metric in ["rewards", "logps", "logits"]:
1578
- logs[f"{prefix}{metric}/{split}"] = (
1579
- torch.Tensor(self._stored_metrics[train_eval][f"{metric}/{split}_sum"]).sum().item()
1580
- / count_sum
1581
- )
1582
- # delete obsolete metric
1583
- del self._stored_metrics[train_eval][f"{metric}/{split}_sum"]
1584
- del self._stored_metrics[train_eval][f"count/{split}"]
1585
- # calculate reward margin
1586
- if f"{prefix}rewards/chosen" in logs and f"{prefix}rewards/rejected" in logs:
1587
- logs[f"{prefix}rewards/margins"] = logs[f"{prefix}rewards/chosen"] - logs[f"{prefix}rewards/rejected"]
1588
- # Add averaged stored metrics to logs
1589
- for key, metrics in self._stored_metrics[train_eval].items():
1590
- logs[f"{prefix}{key}"] = torch.Tensor(metrics).mean().item()
1591
- del self._stored_metrics[train_eval]
1592
-
1593
- if version.parse(transformers.__version__) >= version.parse("4.47.0.dev0"):
1594
- return super().log(logs, start_time)
1595
- else: # transformers<=4.46
1596
- return super().log(logs)
1597
-
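The averaging above reconstructs per-example means from the sums and counts accumulated across steps, rather than averaging per-batch means (which would be biased when batch sizes differ). A tiny sketch of that reduction with illustrative values:

    import torch

    stored = {"rewards/chosen_sum": [12.0, 8.0], "count/chosen": [4.0, 4.0]}
    count = torch.tensor(stored["count/chosen"]).sum()
    mean_chosen_reward = torch.tensor(stored["rewards/chosen_sum"]).sum() / count  # 2.5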
1598
- def create_model_card(
1599
- self,
1600
- model_name: Optional[str] = None,
1601
- dataset_name: Optional[str] = None,
1602
- tags: Union[str, list[str], None] = None,
1603
- ):
1604
- """
1605
- Creates a draft of a model card using the information available to the `Trainer`.
1606
-
1607
- Args:
1608
- model_name (`str` or `None`, *optional*, defaults to `None`):
1609
- Name of the model.
1610
- dataset_name (`str` or `None`, *optional*, defaults to `None`):
1611
- Name of the dataset used for training.
1612
- tags (`str`, `list[str]` or `None`, *optional*, defaults to `None`):
1613
- Tags to be associated with the model card.
1614
- """
1615
- if not self.is_world_process_zero():
1616
- return
1617
-
1618
- if hasattr(self.model.config, "_name_or_path") and not os.path.isdir(self.model.config._name_or_path):
1619
- base_model = self.model.config._name_or_path
1620
- else:
1621
- base_model = None
1622
-
1623
- tags = tags or []
1624
- if isinstance(tags, str):
1625
- tags = [tags]
1626
-
1627
- if hasattr(self.model.config, "unsloth_version"):
1628
- tags.append("unsloth")
1629
-
1630
- citation = textwrap.dedent("""\
1631
- @article{jung2024binary,
1632
- title = {{Binary Classifier Optimization for Large Language Model Alignment}},
1633
- author = {Seungjae Jung and Gunsoo Han and Daniel Wontae Nam and Kyoung{-}Woon On},
1634
- year = 2024,
1635
- eprint = {arXiv:2404.04656}
1636
- }""")
1637
-
1638
- model_card = generate_model_card(
1639
- base_model=base_model,
1640
- model_name=model_name,
1641
- hub_model_id=self.hub_model_id,
1642
- dataset_name=dataset_name,
1643
- tags=tags,
1644
- wandb_url=wandb.run.get_url() if is_wandb_available() and wandb.run is not None else None,
1645
- comet_url=get_comet_experiment_url(),
1646
- trainer_name="BCO",
1647
- trainer_citation=citation,
1648
- paper_title="Binary Classifier Optimization for Large Language Model Alignment",
1649
- paper_id="2404.04656",
1650
- )
1651
-
1652
- model_card.save(os.path.join(self.args.output_dir, "README.md"))
1653
- class UnslothBCOTrainer(_UnslothBCOTrainer):
1654
- """
1655
-
1656
- Initialize BCOTrainer from [BCO](https://huggingface.co/papers/2404.04656) paper.
1657
-
1658
- Args:
1659
- model (`transformers.PreTrainedModel`):
1660
- The model to train, preferably an `AutoModelForCausalLM`.
1661
- ref_model (`PreTrainedModelWrapper`):
1662
- Hugging Face transformer model with a causal language modeling head. Used for implicit reward computation and loss. If no
1663
- reference model is provided, the trainer will create a reference model with the same architecture as the model to be optimized.
1664
- args (`BCOConfig`):
1665
- The arguments to use for training.
1666
- train_dataset (`datasets.Dataset`):
1667
- The dataset to use for training.
1668
- eval_dataset (`datasets.Dataset`):
1669
- The dataset to use for evaluation.
1670
- processing_class (`PreTrainedTokenizerBase` or `BaseImageProcessor` or `FeatureExtractionMixin` or `ProcessorMixin`, *optional*):
1671
- Processing class used to process the data. If provided, will be used to automatically process the inputs
1672
- for the model, and it will be saved along the model to make it easier to rerun an interrupted training or
1673
- reuse the fine-tuned model.
1674
- data_collator (`transformers.DataCollator`, *optional*, defaults to `None`):
1675
- The data collator to use for training. If None is specified, the default data collator (`DPODataCollatorWithPadding`) will be used
1676
- which will pad the sequences to the maximum length of the sequences in the batch, given a dataset of paired sequences.
1677
- model_init (`Callable[[], transformers.PreTrainedModel]`):
1678
- The model initializer to use for training. If None is specified, the default model initializer will be used.
1679
- callbacks (`list[transformers.TrainerCallback]`):
1680
- The callbacks to use for training.
1681
- optimizers (`tuple[torch.optim.Optimizer, torch.optim.lr_scheduler.LambdaLR]`):
1682
- The optimizer and scheduler to use for training.
1683
- preprocess_logits_for_metrics (`Callable[[torch.Tensor, torch.Tensor], torch.Tensor]`):
1684
- The function to use to preprocess the logits before computing the metrics.
1685
- peft_config (`dict`, defaults to `None`):
1686
- The PEFT configuration to use for training. If you pass a PEFT configuration, the model will be wrapped in a PEFT model.
1687
- compute_metrics (`Callable[[EvalPrediction], dict]`, *optional*):
1688
- The function to use to compute the metrics. Must take a `EvalPrediction` and return
1689
- a dictionary string to metric values.
1690
- model_adapter_name (`str`, defaults to `None`):
1691
- Name of the train target PEFT adapter, when using LoRA with multiple adapters.
1692
- ref_adapter_name (`str`, defaults to `None`):
1693
- Name of the reference PEFT adapter, when using LoRA with multiple adapters.
1694
-
1695
- """
1696
- def __init__(
1697
- self,
1698
- model = None,
1699
- ref_model = None,
1700
- args = None,
1701
- train_dataset = None,
1702
- eval_dataset = None,
1703
- processing_class = None,
1704
- data_collator = None,
1705
- model_init = None,
1706
- callbacks = None,
1707
- preprocess_logits_for_metrics = None,
1708
- peft_config = None,
1709
- compute_metrics = None,
1710
- model_adapter_name = None,
1711
- ref_adapter_name = None,
1712
- embedding_func = None,
1713
- embedding_tokenizer = None,
1714
- **kwargs
1715
- ):
1716
- if args is None: args = UnslothBCOConfig()
1717
- use_bf16 = getattr(args, 'bf16', False)
1718
- if type(use_bf16) is not bool: use_bf16 = False
1719
- use_fp16 = getattr(args, 'fp16', False)
1720
- if type(use_fp16) is not bool: use_fp16 = False
1721
- force_float32 = False
1722
- if os.environ.get('UNSLOTH_FORCE_FLOAT32', '0') == '1':
1723
- print('Unsloth: Switching to float32 training since model cannot work with float16')
1724
- force_float32 = True
1725
- mixed_precision_dtype = os.environ.get('UNSLOTH_MIXED_PRECISION', 'float32')
1726
- dtype = getattr(model.config, 'torch_dtype', None)
1727
- if dtype is None: dtype = model.get_input_embeddings().dtype
1728
- from unsloth_zoo.utils import _get_dtype
1729
- dtype = _get_dtype(dtype)
1730
- float16 = dtype == torch.float16
1731
- if not force_float32 and (float16 and use_bf16): raise TypeError('Unsloth: Model is in float16 precision but you want to use bfloat16 precision. Set fp16 to `True` and bf16 to `False`')
1732
- if not force_float32 and (not float16 and use_fp16): raise TypeError('Unsloth: Model is in bfloat16 precision but you want to use float16 precision. Set fp16 to `False` and bf16 to `True`')
1733
- if force_float32:
1734
- args.fp16 = False
1735
- args.bf16 = False
1736
- os.environ['ACCELERATE_MIXED_PRECISION'] = 'no'
1737
- elif (not use_bf16 and not use_fp16) and mixed_precision_dtype == 'float32':
1738
- args.fp16 = float16
1739
- args.bf16 = not float16
1740
- os.environ['ACCELERATE_MIXED_PRECISION'] = 'fp16' if float16 else 'bf16'
1741
- if getattr(args, 'eval_dataset', None) is not None and getattr(args, 'eval_strategy', 'no') == 'no':
1742
- args.eval_strategy = 'steps'
1743
- if getattr(args, 'eval_steps', None) is None: args.eval_steps = 0.1
1744
- ga_steps = getattr(args, 'gradient_accumulation_steps', None)
1745
- if ga_steps is not None and ga_steps > 1:
1746
- from transformers import __version__ as transformers_version
1747
- if Version(transformers_version) <= Version('4.45.2'):
1748
- print('**** Unsloth: Please use our fixed gradient_accumulation_steps by updating transformers, TRL and Unsloth!\n'
1749
- '`pip install --upgrade --no-cache-dir --force-reinstall --no-deps unsloth transformers trl unsloth_zoo`')
1750
- if getattr(args, 'eval_strategy', 'no') != 'no':
1751
- eval_bsz = getattr(args, 'per_device_eval_batch_size', 8)
1752
- if eval_bsz == 8 and args.per_device_train_batch_size < eval_bsz: args.per_device_eval_batch_size = args.per_device_train_batch_size
1753
- if getattr(args, 'eval_accumulation_steps', None) is None and ga_steps is not None: args.eval_accumulation_steps = ga_steps
1754
- fp16_full_eval = getattr(args, 'fp16_full_eval', False)
1755
- if type(fp16_full_eval) is not bool: fp16_full_eval = False
1756
- bf16_full_eval = getattr(args, 'bf16_full_eval', False)
1757
- if type(bf16_full_eval) is not bool: bf16_full_eval = False
1758
- if args.fp16 and bf16_full_eval: args.bf16_full_eval = False; args.fp16_full_eval = True
1759
- if args.bf16 and fp16_full_eval: args.bf16_full_eval = True; args.fp16_full_eval = False
1760
- if force_float32:
1761
- args.bf16_full_eval = False
1762
- args.fp16_full_eval = False
1763
- elif os.environ.get('UNSLOTH_MIXED_PRECISION', 'float32') == 'bfloat16':
1764
- args.bf16_full_eval = True
1765
- args.fp16_full_eval = False
1766
- elif not bf16_full_eval and not fp16_full_eval:
1767
- args.bf16_full_eval = args.bf16
1768
- args.fp16_full_eval = args.fp16
1769
- _output_logits = False
1770
- if locals().get('compute_metrics', None) is not None: _output_logits = True
1771
- if locals().get('preprocess_logits_for_metrics', None) is not None: _output_logits = True
1772
- if _output_logits:
1773
- os.environ['UNSLOTH_RETURN_LOGITS'] = '1'
1774
- if 'max_seq_length' not in locals() and not hasattr(args, 'max_seq_length'):
1775
- pass
1776
- else:
1777
- model_max_seq_length = getattr(model, 'max_seq_length', None)
1778
- args_max_seq_length = getattr(args, 'max_seq_length', None)
1779
- if args_max_seq_length is None and model_max_seq_length is not None:
1780
- max_seq_length = model.max_seq_length
1781
- if hasattr(args, 'max_seq_length'): args.max_seq_length = max_seq_length
1782
- if model is not None and hasattr(model, 'for_training'):
1783
- model.for_training()
1784
- if 'tokenizer' in locals() and hasattr(tokenizer, 'padding_side'): tokenizer.padding_side = 'right'
1785
- if 'processing_class' in locals():
1786
- if hasattr(processing_class, 'padding_side'): processing_class.padding_side = 'right'
1787
- if hasattr(processing_class, 'tokenizer') and hasattr(processing_class.tokenizer, 'padding_side'): processing_class.tokenizer.padding_side = 'right'
1788
- __tokenizer = processing_class if 'processing_class' in locals() else tokenizer
1789
- from unsloth_zoo.vision_utils import UnslothVisionDataCollator
1790
- if not isinstance(data_collator, UnslothVisionDataCollator):
1791
- if isinstance(data_collator, DataCollatorForSeq2Seq) and 'labels' not in train_dataset.column_names:
1792
- data_collator = TransformersDataCollatorForLanguageModeling(__tokenizer, mlm = False, mlm_probability = 0.0)
1793
- elif isinstance(data_collator, TransformersDataCollatorForLanguageModeling) and 'labels' in train_dataset.column_names:
1794
- data_collator = DataCollatorForSeq2Seq(__tokenizer)
1795
- else:
1796
- if hasattr(args, 'remove_unused_columns'): args.remove_unused_columns = False
1797
- if hasattr(args, 'dataset_text_field'): args.dataset_text_field = ''
1798
- if hasattr(args, 'dataset_kwargs'): args.dataset_kwargs = {'skip_prepare_dataset': True}
1799
- if not isinstance(data_collator, UnslothVisionDataCollator):
1800
- if not hasattr(__tokenizer, 'pad') and hasattr(__tokenizer, 'tokenizer'):
1801
- if isinstance(data_collator, DataCollatorForSeq2Seq):
1802
- data_collator = DataCollatorForSeq2Seq(__tokenizer.tokenizer)
1803
- else:
1804
- data_collator = TransformersDataCollatorForLanguageModeling(__tokenizer.tokenizer, mlm = False, mlm_probability = 0.0)
1805
- other_metrics = []
1806
-
1807
- from unsloth_zoo.logging_utils import PatchRLStatistics
1808
- PatchRLStatistics('bco_trainer', other_metrics)
1809
-
1810
- super().__init__(
1811
- model = model,
1812
- ref_model = ref_model,
1813
- args = args,
1814
- train_dataset = train_dataset,
1815
- eval_dataset = eval_dataset,
1816
- processing_class = processing_class,
1817
- data_collator = data_collator,
1818
- model_init = model_init,
1819
- callbacks = callbacks,
1820
- preprocess_logits_for_metrics = preprocess_logits_for_metrics,
1821
- peft_config = peft_config,
1822
- compute_metrics = compute_metrics,
1823
- model_adapter_name = model_adapter_name,
1824
- ref_adapter_name = ref_adapter_name,
1825
- embedding_func = embedding_func,
1826
- embedding_tokenizer = embedding_tokenizer,**kwargs)
1827
- if hasattr(self, 'neftune_hook_handle'):
1828
- self.neftune_hook_handle.remove()
1829
- if hasattr(self, 'neftune_hook_handle'): del self.neftune_hook_handle
1830
- if getattr(args, 'neftune_noise_alpha', None) is not None:
1831
- model.get_input_embeddings().neftune_noise_alpha = self.neftune_noise_alpha
1832
- pass
1833
-
1834
- pass
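For context, the deleted wrapper was invoked exactly like TRL's BCOTrainer. A hypothetical minimal usage; model and tokenizer are assumed to come from an earlier Unsloth FastLanguageModel.from_pretrained call, and the dataset name is only illustrative:

    from datasets import load_dataset
    from trl import BCOConfig

    train_dataset = load_dataset("trl-lib/kto-mix-14k", split="train")  # prompt/completion/label columns
    args = BCOConfig(output_dir="bco-out", per_device_train_batch_size=2, beta=0.1)
    trainer = UnslothBCOTrainer(
        model=model,                 # assumed: LoRA-wrapped model from Unsloth
        args=args,
        train_dataset=train_dataset,
        processing_class=tokenizer,  # assumed: the matching tokenizer
    )
    trainer.train()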
 
test_run_uploads/UnslothCPOTrainer.py DELETED
@@ -1,1566 +0,0 @@
1
- """
2
- 2025.7.11
3
- 2025.7.11
4
- 4.54.1
5
- 0.16.1
6
- __UNSLOTH_VERSIONING__
7
- """
8
- from torch import Tensor
9
- import torch
10
- import torch.nn as nn
11
- from torch.nn import functional as F
12
- from typing import Any, List, Optional, Tuple, Union, Dict, Set, Callable
13
- from trl.trainer.cpo_trainer import (Any, AutoModelForCausalLM, BaseImageProcessor, CPOConfig, CPOTrainer, Callable, DPODataCollatorWithPadding, DataCollator, DataLoader, Dataset, EvalLoopOutput, F, FeatureExtractionMixin, Literal, Optional, PartialState, PeftModel, PreTrainedModel, PreTrainedTokenizerBase, ProcessorMixin, Trainer, TrainerCallback, Union, add_bos_token_if_needed, add_eos_token_if_needed, amp, defaultdict, disable_dropout_in_model, generate_model_card, get_comet_experiment_url, inspect, is_comet_available, is_peft_available, is_torch_fx_proxy, is_wandb_available, log_table_to_comet_experiment, maybe_apply_chat_template, maybe_extract_prompt, nn, np, nullcontext, os, pad_to_length, pd, peft_module_casting_to_bf16, prepare_model_for_kbit_training, random, selective_log_softmax, textwrap, torch, transformers, version, wandb, warnings)
14
-
15
-
16
- import os
17
- from typing import *
18
- from dataclasses import dataclass, field
19
- from packaging.version import Version
20
- import torch
21
- import numpy as np
22
- from contextlib import nullcontext
23
- from torch.nn import functional as F
24
- from transformers import DataCollatorForSeq2Seq, DataCollatorForLanguageModeling as TransformersDataCollatorForLanguageModeling
25
-
26
- torch_compile_options = {
27
- "epilogue_fusion" : True,
28
- "max_autotune" : False,
29
- "shape_padding" : True,
30
- "trace.enabled" : False,
31
- "triton.cudagraphs" : False,
32
- }
33
-
34
- @torch.compile(dynamic = True, fullgraph = True, options = torch_compile_options,)
35
- def chunked_selective_log_softmax(logits, index):
36
- # Split into 4 chunks only
37
- chunked_logits = torch.chunk(logits.reshape(-1, logits.shape[-1]), chunks = 4, dim = 0)
38
- chunked_index = torch.chunk(index.reshape(-1), chunks = 4, dim = 0)
39
- all_per_token_logps = []
40
- # Below loop does the same as selective_log_softmax(chunk_logits, chunk_index)
41
- for chunk_logits, chunk_index in zip(chunked_logits, chunked_index):
42
- chunk_logits = chunk_logits.to(torch.float32)
43
- selected_logits = torch.gather(chunk_logits, dim = -1, index = chunk_index.unsqueeze(-1)).squeeze(-1)
44
- logsumexp_values = torch.logsumexp(chunk_logits, dim = -1)
45
- per_token_logps = selected_logits - logsumexp_values
46
- all_per_token_logps.append(per_token_logps)
47
- pass
48
- all_per_token_logps = torch.concat(all_per_token_logps)
49
- all_per_token_logps = all_per_token_logps.reshape((logits.shape[0], logits.shape[1]))
50
- return all_per_token_logps
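A minimal sketch of what `chunked_selective_log_softmax` computes, assuming illustrative tensor shapes; the chunked result should match a direct selective log-softmax up to float32 rounding.

import torch

def reference_selective_log_softmax(logits, index):
    # log p(index) = logit[index] - logsumexp(logits), per position
    logps = torch.log_softmax(logits.float(), dim=-1)
    return torch.gather(logps, dim=-1, index=index.unsqueeze(-1)).squeeze(-1)

logits = torch.randn(2, 8, 32)           # (batch, seq, vocab); splits into 4 chunks
index  = torch.randint(0, 32, (2, 8))    # token ids whose log-probs we want
expected = reference_selective_log_softmax(logits, index)
# chunked_selective_log_softmax(logits, index) should equal `expected`
# up to float32 rounding; chunking only bounds peak memory.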
51
- @dataclass
52
- class UnslothCPOConfig(CPOConfig):
53
- """
54
-
55
- Configuration class for the [`CPOTrainer`].
56
-
57
- Using [`~transformers.HfArgumentParser`] we can turn this class into
58
- [argparse](https://docs.python.org/3/library/argparse#module-argparse) arguments that can be specified on the
59
- command line.
60
-
61
- Parameters:
62
- learning_rate (`float`, *optional*, defaults to `1e-6`):
63
- Initial learning rate for [`AdamW`] optimizer. The default value replaces that of
64
- [`~transformers.TrainingArguments`].
65
- max_length (`int` or `None`, *optional*, defaults to `1024`):
66
- Maximum length of the sequences (prompt + completion) in the batch. This argument is required if you want
67
- to use the default data collator.
68
- max_prompt_length (`int` or `None`, *optional*, defaults to `512`):
69
- Maximum length of the prompt. This argument is required if you want to use the default data collator.
70
- max_completion_length (`int` or `None`, *optional*, defaults to `None`):
71
- Maximum length of the completion. This argument is required if you want to use the default data collator
72
- and your model is an encoder-decoder.
73
- beta (`float`, *optional*, defaults to `0.1`):
74
- Parameter controlling the deviation from the reference model. Higher β means less deviation from the
75
- reference model. For the IPO loss (`loss_type="ipo"`), β is the regularization parameter denoted by τ in
76
- the [paper](https://huggingface.co/papers/2310.12036).
77
- label_smoothing (`float`, *optional*, defaults to `0.0`):
78
- Label smoothing factor. This argument is required if you want to use the default data collator.
79
- loss_type (`str`, *optional*, defaults to `"sigmoid"`):
80
- Type of loss to use. Possible values are:
81
-
82
- - `"sigmoid"`: sigmoid loss from the original [DPO](https://huggingface.co/papers/2305.18290) paper.
83
- - `"hinge"`: hinge loss on the normalized likelihood from the [SLiC](https://huggingface.co/papers/2305.10425) paper.
84
- - `"ipo"`: IPO loss from the [IPO](https://huggingface.co/papers/2310.12036) paper.
85
- - `"simpo"`: SimPO loss from the [SimPO](https://huggingface.co/papers/2405.14734) paper.
86
-
87
- disable_dropout (`bool`, *optional*, defaults to `True`):
88
- Whether to disable dropout in the model.
89
- cpo_alpha (`float`, *optional*, defaults to `1.0`):
90
- Weight of the BC regularizer in CPO training.
91
- simpo_gamma (`float`, *optional*, defaults to `0.5`):
92
- Target reward margin for the SimPO loss, used only when the `loss_type="simpo"`.
93
- label_pad_token_id (`int`, *optional*, defaults to `-100`):
94
- Label pad token id. This argument is required if you want to use the default data collator.
95
- padding_value (`int` or `None`, *optional*, defaults to `None`):
96
- Padding value to use. If `None`, the padding value of the tokenizer is used.
97
- truncation_mode (`str`, *optional*, defaults to `"keep_end"`):
98
- Truncation mode to use when the prompt is too long. Possible values are `"keep_end"` or `"keep_start"`.
99
- This argument is required if you want to use the default data collator.
100
- generate_during_eval (`bool`, *optional*, defaults to `False`):
101
- If `True`, generates and logs completions from the model to W&B or Comet during evaluation.
102
- is_encoder_decoder (`bool` or `None`, *optional*, defaults to `None`):
103
- When using the `model_init` argument (callable) to instantiate the model instead of the `model` argument,
104
- you need to specify if the model returned by the callable is an encoder-decoder model.
105
- model_init_kwargs (`dict[str, Any]` or `None`, *optional*, defaults to `None`):
106
- Keyword arguments to pass to `AutoModelForCausalLM.from_pretrained` when instantiating the model from a
107
- string.
108
- dataset_num_proc (`int` or `None`, *optional*, defaults to `None`):
109
- Number of processes to use for processing the dataset.
110
-
111
- """
112
- vllm_sampling_params: Optional[Any] = field(
113
- default = None,
114
- metadata = {'help': 'vLLM SamplingParams'},
115
- )
116
- unsloth_num_chunks : Optional[int] = field(
117
- default = -1,
118
- metadata = {'help': 'Chunk size to reduce memory usage. -1 is most efficient.'},
119
- )
120
- def __init__(
121
- self,
122
- output_dir = None,
123
- overwrite_output_dir = None,
124
- do_train = False,
125
- do_eval = False,
126
- do_predict = False,
127
- eval_strategy = 'no',
128
- prediction_loss_only = False,
129
- per_device_train_batch_size = 4,
130
- per_device_eval_batch_size = 4,
131
- per_gpu_train_batch_size = None,
132
- per_gpu_eval_batch_size = None,
133
- gradient_accumulation_steps = 2,
134
- eval_accumulation_steps = 2,
135
- eval_delay = 0,
136
- torch_empty_cache_steps = 250,
137
- learning_rate = 5e-05,
138
- weight_decay = 0.01,
139
- adam_beta1 = 0.9,
140
- adam_beta2 = 0.999,
141
- adam_epsilon = 1e-08,
142
- max_grad_norm = 1.0,
143
- num_train_epochs = 3.0,
144
- max_steps = -1,
145
- lr_scheduler_type = 'linear',
146
- warmup_ratio = 0.1,
147
- warmup_steps = 0,
148
- log_level = 'passive',
149
- log_level_replica = 'warning',
150
- log_on_each_node = True,
151
- logging_dir = None,
152
- logging_strategy = 'steps',
153
- logging_first_step = False,
154
- logging_steps = 1,
155
- logging_nan_inf_filter = False,
156
- save_strategy = 'steps',
157
- save_steps = 500,
158
- save_total_limit = None,
159
- save_safetensors = True,
160
- save_on_each_node = False,
161
- save_only_model = False,
162
- restore_callback_states_from_checkpoint = False,
163
- no_cuda = False,
164
- use_cpu = False,
165
- use_mps_device = False,
166
- seed = 3407,
167
- data_seed = 3407,
168
- jit_mode_eval = False,
169
- use_ipex = False,
170
- bf16 = False,
171
- fp16 = False,
172
- fp16_opt_level = 'O1',
173
- half_precision_backend = 'auto',
174
- bf16_full_eval = False,
175
- fp16_full_eval = False,
176
- tf32 = None,
177
- local_rank = -1,
178
- ddp_backend = None,
179
- tpu_num_cores = None,
180
- tpu_metrics_debug = False,
181
- debug = '',
182
- dataloader_drop_last = False,
183
- eval_steps = None,
184
- dataloader_num_workers = 0,
185
- dataloader_prefetch_factor = None,
186
- past_index = -1,
187
- run_name = None,
188
- disable_tqdm = None,
189
- remove_unused_columns = True,
190
- label_names = None,
191
- load_best_model_at_end = False,
192
- metric_for_best_model = None,
193
- greater_is_better = None,
194
- ignore_data_skip = False,
195
- fsdp = '',
196
- fsdp_min_num_params = 0,
197
- fsdp_config = None,
198
- fsdp_transformer_layer_cls_to_wrap = None,
199
- accelerator_config = None,
200
- deepspeed = None,
201
- label_smoothing_factor = 0.0,
202
- optim = 'adamw_8bit',
203
- optim_args = None,
204
- adafactor = False,
205
- group_by_length = False,
206
- length_column_name = 'length',
207
- report_to = None,
208
- ddp_find_unused_parameters = None,
209
- ddp_bucket_cap_mb = None,
210
- ddp_broadcast_buffers = None,
211
- dataloader_pin_memory = True,
212
- dataloader_persistent_workers = False,
213
- skip_memory_metrics = True,
214
- use_legacy_prediction_loop = False,
215
- push_to_hub = False,
216
- resume_from_checkpoint = None,
217
- hub_model_id = None,
218
- hub_strategy = 'every_save',
219
- hub_token = None,
220
- hub_private_repo = None,
221
- hub_always_push = False,
222
- hub_revision = None,
223
- gradient_checkpointing = False,
224
- gradient_checkpointing_kwargs = None,
225
- include_inputs_for_metrics = False,
226
- eval_do_concat_batches = True,
227
- fp16_backend = 'auto',
228
- push_to_hub_model_id = None,
229
- push_to_hub_organization = None,
230
- push_to_hub_token = None,
231
- mp_parameters = '',
232
- auto_find_batch_size = True,
233
- full_determinism = False,
234
- torchdynamo = None,
235
- ray_scope = 'last',
236
- ddp_timeout = 1800,
237
- torch_compile = False,
238
- torch_compile_backend = None,
239
- torch_compile_mode = None,
240
- include_tokens_per_second = False,
241
- include_num_input_tokens_seen = False,
242
- neftune_noise_alpha = None,
243
- optim_target_modules = None,
244
- batch_eval_metrics = False,
245
- eval_on_start = False,
246
- use_liger_kernel = False,
247
- liger_kernel_config = None,
248
- eval_use_gather_object = False,
249
- average_tokens_across_devices = True,
250
- max_length = 1024,
251
- max_prompt_length = 512,
252
- max_completion_length = None,
253
- beta = 0.1,
254
- label_smoothing = 0.0,
255
- loss_type = 'sigmoid',
256
- disable_dropout = True,
257
- cpo_alpha = 1.0,
258
- simpo_gamma = 0.5,
259
- label_pad_token_id = -100,
260
- padding_value = None,
261
- truncation_mode = 'keep_end',
262
- generate_during_eval = False,
263
- is_encoder_decoder = None,
264
- model_init_kwargs = None,
265
- dataset_num_proc = None,
266
- vllm_sampling_params = None,
267
- unsloth_num_chunks = -1,
268
- **kwargs,
269
- ):
270
- if learning_rate < 1e-7: raise FloatingPointError(f'Unsloth: Your learning rate of `{learning_rate}` is too small and less than 1e-7! Consider increasing it, otherwise gradient updates will be close to 0!')
271
- if learning_rate > 1: raise OverflowError(f'Unsloth: Your learning rate of `{learning_rate}` is way too large (> 1)! Consider decreasing it to 1e-1, otherwise gradient updates will explode!')
272
- if output_dir is None and save_strategy == 'steps' and save_steps == 500:
273
- output_dir = 'unsloth_training_checkpoints'
274
- save_strategy = 'no'
275
- if dataset_num_proc is None:
276
- from multiprocessing import cpu_count
277
- dataset_num_proc = min(cpu_count()*2, 2)
278
-
279
- super().__init__(
280
- output_dir = output_dir,
281
- overwrite_output_dir = overwrite_output_dir,
282
- do_train = do_train,
283
- do_eval = do_eval,
284
- do_predict = do_predict,
285
- eval_strategy = eval_strategy,
286
- prediction_loss_only = prediction_loss_only,
287
- per_device_train_batch_size = per_device_train_batch_size,
288
- per_device_eval_batch_size = per_device_eval_batch_size,
289
- per_gpu_train_batch_size = per_gpu_train_batch_size,
290
- per_gpu_eval_batch_size = per_gpu_eval_batch_size,
291
- gradient_accumulation_steps = gradient_accumulation_steps,
292
- eval_accumulation_steps = eval_accumulation_steps,
293
- eval_delay = eval_delay,
294
- torch_empty_cache_steps = torch_empty_cache_steps,
295
- learning_rate = learning_rate,
296
- weight_decay = weight_decay,
297
- adam_beta1 = adam_beta1,
298
- adam_beta2 = adam_beta2,
299
- adam_epsilon = adam_epsilon,
300
- max_grad_norm = max_grad_norm,
301
- num_train_epochs = num_train_epochs,
302
- max_steps = max_steps,
303
- lr_scheduler_type = lr_scheduler_type,
304
- warmup_ratio = warmup_ratio,
305
- warmup_steps = warmup_steps,
306
- log_level = log_level,
307
- log_level_replica = log_level_replica,
308
- log_on_each_node = log_on_each_node,
309
- logging_dir = logging_dir,
310
- logging_strategy = logging_strategy,
311
- logging_first_step = logging_first_step,
312
- logging_steps = logging_steps,
313
- logging_nan_inf_filter = logging_nan_inf_filter,
314
- save_strategy = save_strategy,
315
- save_steps = save_steps,
316
- save_total_limit = save_total_limit,
317
- save_safetensors = save_safetensors,
318
- save_on_each_node = save_on_each_node,
319
- save_only_model = save_only_model,
320
- restore_callback_states_from_checkpoint = restore_callback_states_from_checkpoint,
321
- no_cuda = no_cuda,
322
- use_cpu = use_cpu,
323
- use_mps_device = use_mps_device,
324
- seed = seed,
325
- data_seed = data_seed,
326
- jit_mode_eval = jit_mode_eval,
327
- use_ipex = use_ipex,
328
- bf16 = bf16,
329
- fp16 = fp16,
330
- fp16_opt_level = fp16_opt_level,
331
- half_precision_backend = half_precision_backend,
332
- bf16_full_eval = bf16_full_eval,
333
- fp16_full_eval = fp16_full_eval,
334
- tf32 = tf32,
335
- local_rank = local_rank,
336
- ddp_backend = ddp_backend,
337
- tpu_num_cores = tpu_num_cores,
338
- tpu_metrics_debug = tpu_metrics_debug,
339
- debug = debug,
340
- dataloader_drop_last = dataloader_drop_last,
341
- eval_steps = eval_steps,
342
- dataloader_num_workers = dataloader_num_workers,
343
- dataloader_prefetch_factor = dataloader_prefetch_factor,
344
- past_index = past_index,
345
- run_name = run_name,
346
- disable_tqdm = disable_tqdm,
347
- remove_unused_columns = remove_unused_columns,
348
- label_names = label_names,
349
- load_best_model_at_end = load_best_model_at_end,
350
- metric_for_best_model = metric_for_best_model,
351
- greater_is_better = greater_is_better,
352
- ignore_data_skip = ignore_data_skip,
353
- fsdp = fsdp,
354
- fsdp_min_num_params = fsdp_min_num_params,
355
- fsdp_config = fsdp_config,
356
- fsdp_transformer_layer_cls_to_wrap = fsdp_transformer_layer_cls_to_wrap,
357
- accelerator_config = accelerator_config,
358
- deepspeed = deepspeed,
359
- label_smoothing_factor = label_smoothing_factor,
360
- optim = optim,
361
- optim_args = optim_args,
362
- adafactor = adafactor,
363
- group_by_length = group_by_length,
364
- length_column_name = length_column_name,
365
- report_to = report_to,
366
- ddp_find_unused_parameters = ddp_find_unused_parameters,
367
- ddp_bucket_cap_mb = ddp_bucket_cap_mb,
368
- ddp_broadcast_buffers = ddp_broadcast_buffers,
369
- dataloader_pin_memory = dataloader_pin_memory,
370
- dataloader_persistent_workers = dataloader_persistent_workers,
371
- skip_memory_metrics = skip_memory_metrics,
372
- use_legacy_prediction_loop = use_legacy_prediction_loop,
373
- push_to_hub = push_to_hub,
374
- resume_from_checkpoint = resume_from_checkpoint,
375
- hub_model_id = hub_model_id,
376
- hub_strategy = hub_strategy,
377
- hub_token = hub_token,
378
- hub_private_repo = hub_private_repo,
379
- hub_always_push = hub_always_push,
380
- hub_revision = hub_revision,
381
- gradient_checkpointing = gradient_checkpointing,
382
- gradient_checkpointing_kwargs = gradient_checkpointing_kwargs,
383
- include_inputs_for_metrics = include_inputs_for_metrics,
384
- eval_do_concat_batches = eval_do_concat_batches,
385
- fp16_backend = fp16_backend,
386
- push_to_hub_model_id = push_to_hub_model_id,
387
- push_to_hub_organization = push_to_hub_organization,
388
- push_to_hub_token = push_to_hub_token,
389
- mp_parameters = mp_parameters,
390
- auto_find_batch_size = auto_find_batch_size,
391
- full_determinism = full_determinism,
392
- torchdynamo = torchdynamo,
393
- ray_scope = ray_scope,
394
- ddp_timeout = ddp_timeout,
395
- torch_compile = torch_compile,
396
- torch_compile_backend = torch_compile_backend,
397
- torch_compile_mode = torch_compile_mode,
398
- include_tokens_per_second = include_tokens_per_second,
399
- include_num_input_tokens_seen = include_num_input_tokens_seen,
400
- neftune_noise_alpha = neftune_noise_alpha,
401
- optim_target_modules = optim_target_modules,
402
- batch_eval_metrics = batch_eval_metrics,
403
- eval_on_start = eval_on_start,
404
- use_liger_kernel = use_liger_kernel,
405
- liger_kernel_config = liger_kernel_config,
406
- eval_use_gather_object = eval_use_gather_object,
407
- average_tokens_across_devices = average_tokens_across_devices,
408
- max_length = max_length,
409
- max_prompt_length = max_prompt_length,
410
- max_completion_length = max_completion_length,
411
- beta = beta,
412
- label_smoothing = label_smoothing,
413
- loss_type = loss_type,
414
- disable_dropout = disable_dropout,
415
- cpo_alpha = cpo_alpha,
416
- simpo_gamma = simpo_gamma,
417
- label_pad_token_id = label_pad_token_id,
418
- padding_value = padding_value,
419
- truncation_mode = truncation_mode,
420
- generate_during_eval = generate_during_eval,
421
- is_encoder_decoder = is_encoder_decoder,
422
- model_init_kwargs = model_init_kwargs,
423
- dataset_num_proc = dataset_num_proc,**kwargs)
424
- self.vllm_sampling_params = vllm_sampling_params
425
- self.unsloth_num_chunks = unsloth_num_chunks
426
- pass
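A minimal usage sketch for the config above; the argument names come from the `__init__` signature shown in this diff, and the output directory is a hypothetical path.

config = UnslothCPOConfig(
    output_dir = "cpo_checkpoints",    # hypothetical path
    learning_rate = 5e-5,              # must lie in (1e-7, 1] per the checks above
    loss_type = "simpo",               # pairs with simpo_gamma (default 0.5)
    max_length = 1024,
    max_prompt_length = 512,
)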
427
-
428
- class _UnslothCPOTrainer(Trainer):
429
- r""""""
430
-
431
- _tag_names = ["trl", "cpo"]
432
-
433
- def __init__(
434
- self,
435
- model: Optional[Union[PreTrainedModel, nn.Module, str]] = None,
436
- args: Optional[CPOConfig] = None,
437
- data_collator: Optional[DataCollator] = None,
438
- train_dataset: Optional[Dataset] = None,
439
- eval_dataset: Optional[Union[Dataset, dict[str, Dataset]]] = None,
440
- processing_class: Optional[
441
- Union[PreTrainedTokenizerBase, BaseImageProcessor, FeatureExtractionMixin, ProcessorMixin]
442
- ] = None,
443
- model_init: Optional[Callable[[], PreTrainedModel]] = None,
444
- callbacks: Optional[list[TrainerCallback]] = None,
445
- optimizers: tuple[torch.optim.Optimizer, torch.optim.lr_scheduler.LambdaLR] = (None, None),
446
- preprocess_logits_for_metrics: Optional[Callable[[torch.Tensor, torch.Tensor], torch.Tensor]] = None,
447
- peft_config: Optional[dict] = None,
448
- compute_metrics: Optional[Callable[[EvalLoopOutput], dict]] = None,
449
- ):
450
- if args.model_init_kwargs is None:
451
- model_init_kwargs = {}
452
- elif not isinstance(model, str):
453
- raise ValueError("You passed model_kwargs to the CPOTrainer. But your model is already instantiated.")
454
- else:
455
- model_init_kwargs = args.model_init_kwargs
456
- torch_dtype = model_init_kwargs.get("torch_dtype")
457
- if torch_dtype is not None:
458
- # Convert to `torch.dtype` if an str is passed
459
- if isinstance(torch_dtype, str) and torch_dtype != "auto":
460
- torch_dtype = getattr(torch, torch_dtype)
461
- if torch_dtype != "auto" and not isinstance(torch_dtype, torch.dtype):
462
- raise ValueError(
463
- f"Invalid `torch_dtype` passed to the CPOConfig. Expected a string with either `torch.dtype` or 'auto', but got {torch_dtype}."
464
- )
465
- model_init_kwargs["torch_dtype"] = torch_dtype
466
-
467
- if isinstance(model, str):
468
- model = AutoModelForCausalLM.from_pretrained(model, **model_init_kwargs)
469
-
470
- # Initialize this variable to False. This helps tracking the case when `peft_module_casting_to_bf16`
471
- # has been called in order to properly call autocast if needed.
472
- self._peft_has_been_casted_to_bf16 = False
473
-
474
- if not is_peft_available() and peft_config is not None:
475
- raise ValueError(
476
- "PEFT is not installed and you passed a `peft_config` in the trainer's kwargs, please install it to use the PEFT models"
477
- )
478
- elif is_peft_available() and peft_config is not None:
479
- # if model is a peft model and we have a peft_config, we merge and unload it first
480
- if isinstance(model, PeftModel):
481
- model = model.merge_and_unload()
482
-
483
- if getattr(model, "is_loaded_in_8bit", False) or getattr(model, "is_loaded_in_4bit", False):
484
- _support_gc_kwargs = hasattr(
485
- args, "gradient_checkpointing_kwargs"
486
- ) and "gradient_checkpointing_kwargs" in list(
487
- inspect.signature(prepare_model_for_kbit_training).parameters
488
- )
489
-
490
- prepare_model_kwargs = {"use_gradient_checkpointing": args.gradient_checkpointing}
491
-
492
- if _support_gc_kwargs:
493
- prepare_model_kwargs["gradient_checkpointing_kwargs"] = args.gradient_checkpointing_kwargs
494
-
495
- model = prepare_model_for_kbit_training(model, **prepare_model_kwargs)
496
- elif getattr(args, "gradient_checkpointing", False):
497
- # For backward compatibility with older versions of transformers
498
- if hasattr(model, "enable_input_require_grads"):
499
- model.enable_input_require_grads()
500
- else:
501
-
502
- def make_inputs_require_grad(module, input, output):
503
- output.requires_grad_(True)
504
-
505
- model.get_input_embeddings().register_forward_hook(make_inputs_require_grad)
506
-
507
- # get peft model with the given config
508
- model = model
509
- if args.bf16 and getattr(model, "is_loaded_in_4bit", False):
510
- peft_module_casting_to_bf16(model)
511
- # If args.bf16 we need to explicitly call `generate` with torch amp autocast context manager
512
- self._peft_has_been_casted_to_bf16 = True
513
-
514
- # For models that use gradient_checkpointing, we need to attach a hook that enables input
515
- # to explicitly have `requires_grad=True`, otherwise training will either silently
516
- # fail or completely fail.
517
- elif getattr(args, "gradient_checkpointing", False):
518
- # For backward compatibility with older versions of transformers
519
- if hasattr(model, "enable_input_require_grads"):
520
- model.enable_input_require_grads()
521
- else:
522
-
523
- def make_inputs_require_grad(module, input, output):
524
- output.requires_grad_(True)
525
-
526
- model.get_input_embeddings().register_forward_hook(make_inputs_require_grad)
527
-
528
- if args.generate_during_eval and not (is_wandb_available() or is_comet_available()):
529
- raise ValueError(
530
- "`generate_during_eval=True` requires Weights and Biases or Comet to be installed."
531
- " Please install `wandb` or `comet-ml` to resolve."
532
- )
533
-
534
- if model is not None:
535
- self.is_encoder_decoder = model.config.is_encoder_decoder
536
- elif args.is_encoder_decoder is None:
537
- raise ValueError("When no model is provided, you need to pass the parameter is_encoder_decoder.")
538
- else:
539
- self.is_encoder_decoder = args.is_encoder_decoder
540
-
541
- if self.is_encoder_decoder:
542
- self.decoder_start_token_id = model.config.decoder_start_token_id
543
- self.pad_token_id = model.config.pad_token_id
544
-
545
- if processing_class is None:
546
- raise ValueError("processing_class must be specified to tokenize a CPO dataset.")
547
- if args.max_length is None:
548
- warnings.warn(
549
- "`max_length` is not set in the CPOConfig's init"
550
- " it will default to `512` by default, but you should do it yourself in the future.",
551
- UserWarning,
552
- )
553
- max_length = 512
554
- else:
555
- max_length = args.max_length
556
- if args.max_prompt_length is None:
557
- warnings.warn(
558
- "`max_prompt_length` is not set in the CPOConfig's init"
559
- " it will default to `128` by default, but you should do it yourself in the future.",
560
- UserWarning,
561
- )
562
- max_prompt_length = 128
563
- else:
564
- max_prompt_length = args.max_prompt_length
565
-
566
- if args.max_completion_length is None and self.is_encoder_decoder:
567
- warnings.warn(
568
- "When using an encoder decoder architecture, you should set `max_completion_length` in the CPOConfig's init"
569
- " it will default to `128` by default, but you should do it yourself in the future.",
570
- UserWarning,
571
- )
572
- max_completion_length = 128
573
- else:
574
- max_completion_length = args.max_completion_length
575
-
576
- if data_collator is None:
577
- data_collator = DPODataCollatorWithPadding(
578
- pad_token_id=processing_class.pad_token_id,
579
- label_pad_token_id=args.label_pad_token_id,
580
- is_encoder_decoder=self.is_encoder_decoder,
581
- )
582
-
583
- if args.remove_unused_columns:
584
- args.remove_unused_columns = False
585
- # warn users
586
- warnings.warn(
587
- "When using DPODataCollatorWithPadding, you should set `remove_unused_columns=False` in your TrainingArguments"
588
- " we have set it for you, but you should do it yourself in the future.",
589
- UserWarning,
590
- )
591
-
592
- self.use_dpo_data_collator = True
593
- else:
594
- self.use_dpo_data_collator = False
595
-
596
- # Disable dropout in the model
597
- if args.disable_dropout:
598
- disable_dropout_in_model(model)
599
-
600
- self.max_length = max_length
601
- self.generate_during_eval = args.generate_during_eval
602
- self.label_pad_token_id = args.label_pad_token_id
603
- self.padding_value = args.padding_value if args.padding_value is not None else processing_class.pad_token_id
604
- self.max_prompt_length = max_prompt_length
605
- self.truncation_mode = args.truncation_mode
606
- self.max_completion_length = max_completion_length
607
- self.processing_class = processing_class
608
-
609
- if args.loss_type in ["hinge", "ipo"] and args.label_smoothing > 0:
610
- warnings.warn(
611
- f"You are using the {args.loss_type} loss type that does not support label smoothing. The "
612
- "`label_smoothing` parameter will be ignored. Set `label_smoothing` to `0.0` to remove this warning.",
613
- UserWarning,
614
- )
615
- if args.loss_type == "kto_pair":
616
- raise ValueError("Support for kto_pair has been removed in CPOTrainer. Please use KTOTrainer.")
617
-
618
- self.beta = args.beta
619
- self.label_smoothing = args.label_smoothing
620
- self.loss_type = args.loss_type
621
- self.cpo_alpha = args.cpo_alpha
622
- self.aux_loss_enabled = getattr(model.config, "output_router_logits", False)
623
- self.aux_loss_coef = getattr(model.config, "router_aux_loss_coef", 0.0)
624
- if self.aux_loss_enabled and self.aux_loss_coef == 0.0:
625
- warnings.warn(
626
- "You set `output_router_logits` to `True` in the model config, but `router_aux_loss_coef` is set to "
627
- "`0.0`, meaning the auxiliary loss will not be used. Either set `router_aux_loss_coef` to a value "
628
- "greater than `0.0`, or set `output_router_logits` to `False` if you don't want to use the auxiliary "
629
- "loss.",
630
- UserWarning,
631
- )
632
-
633
- if args.loss_type == "simpo":
634
- self.simpo_gamma = args.simpo_gamma
635
-
636
- self._stored_metrics = defaultdict(lambda: defaultdict(list))
637
-
638
- # The trainer estimates the number of FLOPs (floating-point operations) using the number of elements in the
639
- # input tensor associated with the key "input_ids". However, in CPO, the sampled data does not include the
640
- # "input_ids" key. Instead, the available keys are "prompt_input_ids", "chosen_input_ids", and
641
- # "rejected_input_ids". As a result, the trainer issues the warning: "Could not estimate the number of tokens
642
- # of the input, floating-point operations will not be computed." To suppress this warning, we set the
643
- # "estimate_tokens" key in the model's "warnings_issued" dictionary to True. This acts as a flag to indicate
644
- # that the warning has already been issued.
645
- model.warnings_issued["estimate_tokens"] = True
646
-
647
- # Compute that only on the main process for faster data processing.
648
- # see: https://github.com/huggingface/trl/pull/1255
649
- with PartialState().main_process_first():
650
- # Extract the prompt if needed, and apply the chat template if needed
651
- train_dataset = train_dataset.map(maybe_extract_prompt, num_proc=args.dataset_num_proc)
652
- train_dataset = train_dataset.map(
653
- maybe_apply_chat_template, fn_kwargs={"tokenizer": processing_class}, num_proc=args.dataset_num_proc
654
- )
655
- if eval_dataset is not None:
656
- eval_dataset = eval_dataset.map(maybe_extract_prompt, num_proc=args.dataset_num_proc)
657
- eval_dataset = eval_dataset.map(
658
- maybe_apply_chat_template,
659
- fn_kwargs={"tokenizer": processing_class},
660
- num_proc=args.dataset_num_proc,
661
- )
662
-
663
- # tokenize the dataset
664
- train_dataset = train_dataset.map(self.tokenize_row, num_proc=args.dataset_num_proc)
665
- if eval_dataset is not None:
666
- eval_dataset = eval_dataset.map(self.tokenize_row, num_proc=args.dataset_num_proc)
667
-
668
- super().__init__(
669
- model=model,
670
- args=args,
671
- data_collator=data_collator,
672
- train_dataset=train_dataset,
673
- eval_dataset=eval_dataset,
674
- processing_class=processing_class,
675
- model_init=model_init,
676
- compute_metrics=compute_metrics,
677
- callbacks=callbacks,
678
- optimizers=optimizers,
679
- preprocess_logits_for_metrics=preprocess_logits_for_metrics,
680
- )
681
-
682
- # Gradient accumulation requires scaled loss. Normally, loss scaling in the parent class depends on whether the
683
- # model accepts loss-related kwargs. Since we compute our own loss, this check is irrelevant. We set
684
- # self.model_accepts_loss_kwargs to False to enable scaling.
685
- self.model_accepts_loss_kwargs = False
686
-
687
- # Add tags for models that have been loaded with the correct transformers version
688
- if hasattr(self.model, "add_model_tags"):
689
- self.model.add_model_tags(self._tag_names)
690
-
691
- if not hasattr(self, "accelerator"):
692
- raise AttributeError(
693
- "Your `Trainer` does not have an `accelerator` object. Consider upgrading `transformers`."
694
- )
695
-
696
- def build_tokenized_answer(self, prompt, answer):
697
- """
698
- Llama tokenizer does not satisfy `enc(a + b) = enc(a) + enc(b)`.
699
- It does ensure `enc(a + b) = enc(a) + enc(a + b)[len(enc(a)):]`.
700
- Reference:
701
- https://github.com/EleutherAI/lm-evaluation-harness/pull/531#issuecomment-1595586257
702
- """
703
-
704
- full_tokenized = self.processing_class(prompt + answer, add_special_tokens=False)
705
- prompt_input_ids = self.processing_class(prompt, add_special_tokens=False)["input_ids"]
706
-
707
- answer_input_ids = full_tokenized["input_ids"][len(prompt_input_ids) :]
708
- answer_attention_mask = full_tokenized["attention_mask"][len(prompt_input_ids) :]
709
-
710
- # Concat tokens to form `enc(a) + enc(a + b)[len(enc(a)):]`
711
- full_concat_input_ids = np.concatenate([prompt_input_ids, answer_input_ids])
712
-
713
- # Prepare input tokens for token by token comparison
714
- full_input_ids = np.array(full_tokenized["input_ids"])
715
-
716
- if len(full_input_ids) != len(full_concat_input_ids):
717
- raise ValueError("Prompt input ids and answer input ids should have the same length.")
718
-
719
- # On some tokenizers, like Llama-2 tokenizer, there are occasions where tokens
720
- # can be merged together when tokenizing prompt+answer. This could result
721
- # in the last token from the prompt being different when tokenized on its own
722
- # vs when done as prompt+answer.
723
- response_token_ids_start_idx = len(prompt_input_ids)
724
-
725
- # If the tokenized prompt differs from the prefix of prompt+answer, then the
726
- # last token has changed due to merging.
727
- if prompt_input_ids != full_tokenized["input_ids"][:response_token_ids_start_idx]:
728
- response_token_ids_start_idx -= 1
729
-
730
- prompt_input_ids = full_tokenized["input_ids"][:response_token_ids_start_idx]
731
- prompt_attention_mask = full_tokenized["attention_mask"][:response_token_ids_start_idx]
732
-
733
- if len(prompt_input_ids) != len(prompt_attention_mask):
734
- raise ValueError("Prompt input ids and attention mask should have the same length.")
735
-
736
- answer_input_ids = full_tokenized["input_ids"][response_token_ids_start_idx:]
737
- answer_attention_mask = full_tokenized["attention_mask"][response_token_ids_start_idx:]
738
-
739
- return dict(
740
- prompt_input_ids=prompt_input_ids,
741
- prompt_attention_mask=prompt_attention_mask,
742
- input_ids=answer_input_ids,
743
- attention_mask=answer_attention_mask,
744
- )
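To see why the answer is sliced out of the full tokenization rather than tokenized on its own, a minimal sketch; the model name is an illustrative assumption and may require access credentials.

from transformers import AutoTokenizer

tok = AutoTokenizer.from_pretrained("meta-llama/Llama-2-7b-hf")  # illustrative
prompt, answer = "2+2=", "4"
full   = tok(prompt + answer, add_special_tokens=False)["input_ids"]
pieces = (tok(prompt, add_special_tokens=False)["input_ids"]
          + tok(answer, add_special_tokens=False)["input_ids"])
# `full` and `pieces` can differ: the tokenizer may merge the last prompt
# token with the first answer token, so slicing `full` is the safe route.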
745
-
746
- def tokenize_row(self, feature, model: Optional[Union[PreTrainedModel, nn.Module]] = None) -> dict:
747
- """Tokenize a single row from a CPO specific dataset.
748
-
749
- At this stage, we don't convert to PyTorch tensors yet; we just handle the truncation
750
- in case the prompt + chosen or prompt + rejected responses is/are too long. First
751
- we truncate the prompt; if we're still too long, we truncate the chosen/rejected.
752
-
753
- We also create the labels for the chosen/rejected responses, which are of length equal to
754
- the sum of the length of the prompt and the chosen/rejected response, with
755
- label_pad_token_id for the prompt tokens.
756
- """
757
- batch = {}
758
- prompt = feature["prompt"]
759
- chosen = feature["chosen"]
760
- rejected = feature["rejected"]
761
-
762
- if not self.is_encoder_decoder:
763
- # Check issues below for more details
764
- # 1. https://github.com/huggingface/trl/issues/907
765
- # 2. https://github.com/EleutherAI/lm-evaluation-harness/pull/531#issuecomment-1595586257
766
- # 3. https://github.com/LianjiaTech/BELLE/issues/337
767
-
768
- if not isinstance(prompt, str):
769
- raise ValueError(f"prompt should be an str but got {type(prompt)}")
770
- prompt_tokens = self.processing_class(prompt, add_special_tokens=False)
771
- prompt_tokens = {f"prompt_{k}": v for k, v in prompt_tokens.items()}
772
-
773
- if not isinstance(chosen, str):
774
- raise ValueError(f"chosen should be an str but got {type(chosen)}")
775
- chosen_tokens = self.build_tokenized_answer(prompt, chosen)
776
-
777
- if not isinstance(rejected, str):
778
- raise ValueError(f"rejected should be an str but got {type(rejected)}")
779
- rejected_tokens = self.build_tokenized_answer(prompt, rejected)
780
-
781
- # Last prompt token might get merged by tokenizer and
782
- # it should not be included for generation if that happens
783
- prompt_len_input_ids = len(prompt_tokens["prompt_input_ids"])
784
-
785
- chosen_prompt_len_input_ids = len(chosen_tokens["prompt_input_ids"])
786
- rejected_prompt_len_input_ids = len(rejected_tokens["prompt_input_ids"])
787
- prompt_len_input_ids = min(chosen_prompt_len_input_ids, rejected_prompt_len_input_ids)
788
-
789
- for k, v in prompt_tokens.items():
790
- prompt_tokens[k] = v[:prompt_len_input_ids]
791
-
792
- # Make sure prompts only have one different token at most,
793
- # and length only differs by 1 at most
794
- num_diff_tokens = sum(
795
- [a != b for a, b in zip(chosen_tokens["prompt_input_ids"], rejected_tokens["prompt_input_ids"])]
796
- )
797
- num_diff_len = abs(chosen_prompt_len_input_ids - rejected_prompt_len_input_ids)
798
- if num_diff_tokens > 1 or num_diff_len > 1:
799
- raise ValueError(
800
- "Chosen and rejected prompt_input_ids might only differ on the "
801
- "last token due to tokenizer merge ops."
802
- )
803
-
804
- # add BOS token to head of prompt. Avoid adding if it's already there
805
- prompt_tokens, chosen_tokens, rejected_tokens = add_bos_token_if_needed(
806
- self.processing_class.bos_token_id,
807
- prompt_len_input_ids,
808
- prompt_tokens,
809
- chosen_prompt_len_input_ids,
810
- chosen_tokens,
811
- rejected_prompt_len_input_ids,
812
- rejected_tokens,
813
- )
814
-
815
- # add EOS token to end of answer. Avoid adding if it's already there
816
- chosen_tokens, rejected_tokens = add_eos_token_if_needed(
817
- self.processing_class.eos_token_id, chosen_tokens, rejected_tokens
818
- )
819
-
820
- longer_response_length = max(len(chosen_tokens["input_ids"]), len(rejected_tokens["input_ids"]))
821
-
822
- # if combined sequence is too long, truncate the prompt
823
- for answer_tokens in [chosen_tokens, rejected_tokens, prompt_tokens]:
824
- if len(answer_tokens["prompt_input_ids"]) + longer_response_length > self.max_length:
825
- if self.truncation_mode == "keep_start":
826
- for k in ["prompt_input_ids", "prompt_attention_mask"]:
827
- answer_tokens[k] = answer_tokens[k][: self.max_prompt_length]
828
- elif self.truncation_mode == "keep_end":
829
- for k in ["prompt_input_ids", "prompt_attention_mask"]:
830
- answer_tokens[k] = answer_tokens[k][-self.max_prompt_length :]
831
- else:
832
- raise ValueError(f"Unknown truncation mode: {self.truncation_mode}")
833
-
834
- # if that's still too long, truncate the response
835
- for answer_tokens in [chosen_tokens, rejected_tokens]:
836
- if len(answer_tokens["prompt_input_ids"]) + longer_response_length > self.max_length:
837
- for k in ["input_ids", "attention_mask"]:
838
- answer_tokens[k] = answer_tokens[k][: self.max_length - self.max_prompt_length]
839
-
840
- # Create labels
841
- chosen_sequence_tokens = {
842
- k: chosen_tokens[f"prompt_{k}"] + chosen_tokens[k] for k in ["input_ids", "attention_mask"]
843
- }
844
- rejected_sequence_tokens = {
845
- k: rejected_tokens[f"prompt_{k}"] + rejected_tokens[k] for k in ["input_ids", "attention_mask"]
846
- }
847
- chosen_sequence_tokens["labels"] = chosen_sequence_tokens["input_ids"][:]
848
- chosen_sequence_tokens["labels"][: len(chosen_tokens["prompt_input_ids"])] = [
849
- self.label_pad_token_id
850
- ] * len(chosen_tokens["prompt_input_ids"])
851
- rejected_sequence_tokens["labels"] = rejected_sequence_tokens["input_ids"][:]
852
- rejected_sequence_tokens["labels"][: len(rejected_tokens["prompt_input_ids"])] = [
853
- self.label_pad_token_id
854
- ] * len(rejected_tokens["prompt_input_ids"])
855
-
856
- for k, toks in {
857
- "chosen_": chosen_sequence_tokens,
858
- "rejected_": rejected_sequence_tokens,
859
- "": prompt_tokens,
860
- }.items():
861
- for type_key, tokens in toks.items():
862
- if type_key == "token_type_ids":
863
- continue
864
- batch[f"{k}{type_key}"] = tokens
865
-
866
- else:
867
- chosen_tokens = self.processing_class(
868
- chosen, truncation=True, max_length=self.max_completion_length, add_special_tokens=True
869
- )
870
- rejected_tokens = self.processing_class(
871
- rejected, truncation=True, max_length=self.max_completion_length, add_special_tokens=True
872
- )
873
- prompt_tokens = self.processing_class(
874
- prompt, truncation=True, max_length=self.max_prompt_length, add_special_tokens=True
875
- )
876
-
877
- batch["chosen_labels"] = chosen_tokens["input_ids"]
878
- batch["rejected_labels"] = rejected_tokens["input_ids"]
879
- batch["prompt_input_ids"] = prompt_tokens["input_ids"]
880
- batch["prompt_attention_mask"] = prompt_tokens["attention_mask"]
881
-
882
- if model is not None and hasattr(model, "prepare_decoder_input_ids_from_labels"):
883
- batch["rejected_decoder_input_ids"] = model.prepare_decoder_input_ids_from_labels(
884
- labels=torch.tensor(batch["rejected_labels"])
885
- )
886
- batch["chosen_decoder_input_ids"] = model.prepare_decoder_input_ids_from_labels(
887
- labels=torch.tensor(batch["chosen_labels"])
888
- )
889
-
890
- return batch
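A minimal sketch of the label construction performed above for the decoder-only branch, assuming a 3-token prompt and a 2-token answer.

prompt_ids, answer_ids = [5, 6, 7], [8, 9]
input_ids = prompt_ids + answer_ids                  # [5, 6, 7, 8, 9]
labels = input_ids[:]
labels[:len(prompt_ids)] = [-100] * len(prompt_ids)  # label_pad_token_id
# labels == [-100, -100, -100, 8, 9]: only answer tokens enter the NLL loss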
891
-
892
- @staticmethod
893
- def concatenated_inputs(
894
- batch: dict[str, Union[list, torch.LongTensor]],
895
- is_encoder_decoder: bool = False,
896
- label_pad_token_id: int = -100,
897
- padding_value: int = 0,
898
- device: Optional[torch.device] = None,
899
- ) -> dict[str, torch.LongTensor]:
900
- """Concatenate the chosen and rejected inputs into a single tensor.
901
-
902
- Args:
903
- batch: A batch of data. Must contain the keys 'chosen_input_ids' and 'rejected_input_ids', which are tensors of shape (batch_size, sequence_length).
904
- is_encoder_decoder: Whether the model is an encoder-decoder model.
905
- label_pad_token_id: The label pad token id.
906
- padding_value: The padding value to use for the concatenated inputs_ids.
907
- device: The device for the concatenated inputs.
908
-
909
- Returns:
910
- A dictionary containing the concatenated inputs under the key 'concatenated_input_ids'.
911
- """
912
- concatenated_batch = {}
913
-
914
- if is_encoder_decoder:
915
- max_length = max(batch["chosen_labels"].shape[1], batch["rejected_labels"].shape[1])
916
- else:
917
- max_length = max(batch["chosen_input_ids"].shape[1], batch["rejected_input_ids"].shape[1])
918
-
919
- for k in batch:
920
- if k.startswith("chosen") and isinstance(batch[k], torch.Tensor):
921
- if "labels" in k or is_encoder_decoder:
922
- pad_value = label_pad_token_id
923
- elif k.endswith("_input_ids"):
924
- pad_value = padding_value
925
- elif k.endswith("_attention_mask"):
926
- pad_value = 0
927
- concatenated_key = k.replace("chosen", "concatenated")
928
- concatenated_batch[concatenated_key] = pad_to_length(batch[k], max_length, pad_value=pad_value)
929
- for k in batch:
930
- if k.startswith("rejected") and isinstance(batch[k], torch.Tensor):
931
- if "labels" in k or is_encoder_decoder:
932
- pad_value = label_pad_token_id
933
- elif k.endswith("_input_ids"):
934
- pad_value = padding_value
935
- elif k.endswith("_attention_mask"):
936
- pad_value = 0
937
- concatenated_key = k.replace("rejected", "concatenated")
938
- concatenated_batch[concatenated_key] = torch.cat(
939
- (
940
- concatenated_batch[concatenated_key],
941
- pad_to_length(batch[k], max_length, pad_value=pad_value),
942
- ),
943
- dim=0,
944
- ).to(device=device)
945
-
946
- if is_encoder_decoder:
947
- concatenated_batch["concatenated_input_ids"] = batch["prompt_input_ids"].repeat(2, 1).to(device=device)
948
- concatenated_batch["concatenated_attention_mask"] = (
949
- batch["prompt_attention_mask"].repeat(2, 1).to(device=device)
950
- )
951
-
952
- return concatenated_batch
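A minimal sketch of the padding-and-stacking contract above for a decoder-only model, with toy tensors; `pad_to_length` is the TRL helper imported at the top of this file.

import torch
from trl.trainer.utils import pad_to_length

chosen   = torch.tensor([[1, 2, 3]])
rejected = torch.tensor([[4, 5]])
max_len  = max(chosen.shape[1], rejected.shape[1])
stacked  = torch.cat(
    (pad_to_length(chosen,   max_len, pad_value=0),
     pad_to_length(rejected, max_len, pad_value=0)),
    dim=0,
)
# stacked == [[1, 2, 3], [4, 5, 0]]: chosen rows first, rejected rows second,
# which is why concatenated_forward later slices logits at len_chosen.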
953
-
954
- def cpo_loss(
955
- self,
956
- policy_chosen_logps: torch.FloatTensor,
957
- policy_rejected_logps: torch.FloatTensor,
958
- ) -> tuple[torch.FloatTensor, torch.FloatTensor, torch.FloatTensor]:
959
- """Compute the CPO loss for a batch of policy and reference model log probabilities.
960
-
961
- Args:
962
- policy_chosen_logps: Log probabilities of the policy model for the chosen responses. Shape: (batch_size,)
963
- policy_rejected_logps: Log probabilities of the policy model for the rejected responses. Shape: (batch_size,)
964
-
965
- Returns:
966
- A tuple of three tensors: (losses, chosen_rewards, rejected_rewards).
967
- The losses tensor contains the CPO loss for each example in the batch.
968
- The chosen_rewards and rejected_rewards tensors contain the rewards for the chosen and rejected responses, respectively.
969
- """
970
- logits = (policy_chosen_logps - policy_rejected_logps).to(self.accelerator.device)
971
-
972
- # The beta is a temperature parameter for the CPO loss, typically something in the range of 0.1 to 0.5.
973
- # We ignore the reference model as beta -> 0. The label_smoothing parameter encodes our uncertainty about the labels and
974
- # calculates a conservative CPO loss.
975
-
976
- if self.loss_type == "simpo":
977
- gamma_logratios = self.simpo_gamma / self.beta
978
- logits = logits - gamma_logratios
979
- # This reduces to Equation 3 from the CPO paper when label_smoothing -> 0.
980
- losses = (
981
- -F.logsigmoid(self.beta * logits) * (1 - self.label_smoothing)
982
- - F.logsigmoid(-self.beta * logits) * self.label_smoothing
983
- )
984
- elif self.loss_type == "sigmoid":
985
- # This reduces to Equation 3 from the CPO paper when label_smoothing -> 0.
986
- losses = (
987
- -F.logsigmoid(self.beta * logits) * (1 - self.label_smoothing)
988
- - F.logsigmoid(-self.beta * logits) * self.label_smoothing
989
- )
990
- elif self.loss_type == "hinge":
991
- losses = torch.relu(1 - self.beta * logits)
992
- elif self.loss_type == "ipo":
993
- # eqn (17) of the paper where beta is the regularization parameter for the IPO loss, denoted by tau in the paper.
994
- losses = (logits - 1 / (2 * self.beta)) ** 2
995
- else:
996
- raise ValueError(
997
- f"Unknown loss type: {self.loss_type}. Should be one of ['sigmoid', 'hinge', 'ipo', 'simpo']"
998
- )
999
-
1000
- chosen_rewards = self.beta * (policy_chosen_logps.to(self.accelerator.device)).detach()
1001
- rejected_rewards = self.beta * (policy_rejected_logps.to(self.accelerator.device)).detach()
1002
-
1003
- return losses, chosen_rewards, rejected_rewards
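The sigmoid branch above, written out for a single preference pair; with `label_smoothing = 0` it reduces to -log sigmoid(beta * (logp_chosen - logp_rejected)). The values below are illustrative.

import torch
import torch.nn.functional as F

beta, label_smoothing = 0.1, 0.0
logp_chosen, logp_rejected = torch.tensor(-12.0), torch.tensor(-15.0)
logits = logp_chosen - logp_rejected                      # margin = 3.0
loss = (-F.logsigmoid(beta * logits) * (1 - label_smoothing)
        - F.logsigmoid(-beta * logits) * label_smoothing)
# loss ≈ -log sigmoid(0.3) ≈ 0.554; a larger margin gives a smaller loss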
1004
-
1005
- @staticmethod
1006
- def get_batch_logps(
1007
- logits: torch.FloatTensor,
1008
- labels: torch.LongTensor,
1009
- average_log_prob: bool = False,
1010
- label_pad_token_id: int = -100,
1011
- is_encoder_decoder: bool = False,
1012
- ) -> torch.FloatTensor:
1013
- """Compute the log probabilities of the given labels under the given logits.
1014
-
1015
- Args:
1016
- logits: Logits of the model (unnormalized). Shape: (batch_size, sequence_length, vocab_size)
1017
- labels: Labels for which to compute the log probabilities. Label tokens with a value of label_pad_token_id are ignored. Shape: (batch_size, sequence_length)
1018
- average_log_prob: If True, return the average log probability per (non-masked) token. Otherwise, return the sum of the log probabilities of the (non-masked) tokens.
1019
- label_pad_token_id: The label pad token id.
1020
- is_encoder_decoder: Whether the model is an encoder-decoder model.
1021
-
1022
- Returns:
1023
- A tensor of shape (batch_size,) containing the average/sum log probabilities of the given labels under the given logits.
1024
- """
1025
- if logits.shape[:-1] != labels.shape:
1026
- raise ValueError("Logits (batch and sequence length dim) and labels must have the same shape.")
1027
-
1028
- if not is_encoder_decoder:
1029
- labels = labels[:, 1:].clone()
1030
- logits = logits[:, :-1, :]
1031
- loss_mask = labels != label_pad_token_id
1032
-
1033
- # dummy token; we'll ignore the losses on these tokens later
1034
- labels[labels == label_pad_token_id] = 0
1035
-
1036
- per_token_logps = selective_log_softmax(logits, labels)
1037
-
1038
- if average_log_prob:
1039
- return (per_token_logps * loss_mask).sum(-1) / loss_mask.sum(-1)
1040
- else:
1041
- return (per_token_logps * loss_mask).sum(-1)
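A minimal sketch of the masking above, with toy numbers; positions labeled `label_pad_token_id` (-100) are excluded from both the sum and the average.

import torch

per_token_logps = torch.tensor([[-0.5, -1.0, -2.0]])
labels          = torch.tensor([[-100, 7, 9]])    # -100 masks the prompt
loss_mask = labels != -100
summed  = (per_token_logps * loss_mask).sum(-1)                       # -3.0
average = (per_token_logps * loss_mask).sum(-1) / loss_mask.sum(-1)   # -1.5
# average_log_prob=True (used for ipo/simpo) returns -1.5, otherwise -3.0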
1042
-
1043
- def concatenated_forward(
1044
- self, model: nn.Module, batch: dict[str, Union[list, torch.LongTensor]]
1045
- ) -> tuple[torch.FloatTensor, torch.FloatTensor, torch.FloatTensor, torch.FloatTensor]:
1046
- """Run the given model on the given batch of inputs, concatenating the chosen and rejected inputs together.
1047
-
1048
- We do this to avoid doing two forward passes, because it's faster for FSDP.
1049
- """
1050
- concatenated_batch = self.concatenated_inputs(
1051
- batch,
1052
- is_encoder_decoder=self.is_encoder_decoder,
1053
- label_pad_token_id=self.label_pad_token_id,
1054
- padding_value=self.padding_value,
1055
- device=self.accelerator.device,
1056
- )
1057
- len_chosen = batch["chosen_labels"].shape[0]
1058
-
1059
- model_kwargs = (
1060
- {
1061
- "decoder_input_ids": self._shift_right(concatenated_batch["concatenated_labels"]),
1062
- }
1063
- if self.is_encoder_decoder
1064
- else {}
1065
- )
1066
-
1067
- if self.aux_loss_enabled:
1068
- model_kwargs["output_router_logits"] = True
1069
-
1070
- outputs = model(
1071
- concatenated_batch["concatenated_input_ids"],
1072
- attention_mask=concatenated_batch["concatenated_attention_mask"],
1073
- use_cache=False,
1074
- **model_kwargs,
1075
- )
1076
- all_logits = outputs.logits
1077
-
1078
- def cross_entropy_loss(logits, labels):
1079
- if not self.is_encoder_decoder:
1080
- # Shift so that tokens < n predict n
1081
- logits = logits[..., :-1, :].contiguous()
1082
- labels = labels[..., 1:].contiguous()
1083
- # Flatten the tokens
1084
- loss_fct = nn.CrossEntropyLoss()
1085
- logits = logits.view(-1, logits.shape[-1])
1086
- labels = labels.view(-1)
1087
- # Enable model parallelism
1088
- labels = labels.to(logits.device)
1089
- loss = loss_fct(logits, labels)
1090
- return loss
1091
-
1092
- labels = concatenated_batch["concatenated_labels"].clone()
1093
-
1094
- if self.cpo_alpha == 0:
1095
- nll_loss = torch.tensor(0.0).to(self.accelerator.device)
1096
- else:
1097
- nll_loss = cross_entropy_loss(all_logits[:len_chosen], labels[:len_chosen])
1098
-
1099
- all_logps = self.get_batch_logps(
1100
- all_logits,
1101
- concatenated_batch["concatenated_labels"],
1102
- average_log_prob=self.loss_type in ["ipo", "simpo"],
1103
- is_encoder_decoder=self.is_encoder_decoder,
1104
- label_pad_token_id=self.label_pad_token_id,
1105
- )
1106
-
1107
- chosen_logps = all_logps[:len_chosen]
1108
- rejected_logps = all_logps[len_chosen:]
1109
-
1110
- chosen_logits = all_logits[:len_chosen]
1111
- rejected_logits = all_logits[len_chosen:]
1112
-
1113
- if self.aux_loss_enabled:
1114
- return (chosen_logps, rejected_logps, chosen_logits, rejected_logits, nll_loss, outputs.aux_loss)
1115
-
1116
- return (chosen_logps, rejected_logps, chosen_logits, rejected_logits, nll_loss)
1117
-
1118
- def get_batch_loss_metrics(
1119
- self,
1120
- model,
1121
- batch: dict[str, Union[list, torch.LongTensor]],
1122
- train_eval: Literal["train", "eval"] = "train",
1123
- ):
1124
- """Compute the CPO loss and other metrics for the given batch of inputs for train or test."""
1125
- metrics = {}
1126
-
1127
- forward_output = self.concatenated_forward(model, batch)
1128
- (
1129
- policy_chosen_logps,
1130
- policy_rejected_logps,
1131
- policy_chosen_logits,
1132
- policy_rejected_logits,
1133
- policy_nll_loss,
1134
- ) = forward_output[:5]
1135
- if self.aux_loss_enabled:
1136
- aux_loss = forward_output[5]
1137
-
1138
- losses, chosen_rewards, rejected_rewards = self.cpo_loss(
1139
- policy_chosen_logps,
1140
- policy_rejected_logps,
1141
- )
1142
-
1143
- loss = losses.mean() + self.cpo_alpha * policy_nll_loss
1144
- reward_accuracies = (chosen_rewards > rejected_rewards).float()
1145
-
1146
- prefix = "eval_" if train_eval == "eval" else ""
1147
- metrics[f"{prefix}rewards/chosen"] = self.accelerator.gather_for_metrics(chosen_rewards).mean().item()
1148
- metrics[f"{prefix}rewards/rejected"] = self.accelerator.gather_for_metrics(rejected_rewards).mean().item()
1149
- metrics[f"{prefix}rewards/accuracies"] = self.accelerator.gather_for_metrics(reward_accuracies).mean().item()
1150
- metrics[f"{prefix}rewards/margins"] = (
1151
- self.accelerator.gather_for_metrics(chosen_rewards - rejected_rewards).mean().item()
1152
- )
1153
- metrics[f"{prefix}logps/rejected"] = (
1154
- self.accelerator.gather_for_metrics(policy_rejected_logps).detach().mean().item()
1155
- )
1156
- metrics[f"{prefix}logps/chosen"] = (
1157
- self.accelerator.gather_for_metrics(policy_chosen_logps).detach().mean().item()
1158
- )
1159
- metrics[f"{prefix}logits/rejected"] = (
1160
- self.accelerator.gather_for_metrics(policy_rejected_logits).detach().mean().item()
1161
- )
1162
- metrics[f"{prefix}logits/chosen"] = (
1163
- self.accelerator.gather_for_metrics(policy_chosen_logits).detach().mean().item()
1164
- )
1165
- metrics[f"{prefix}nll_loss"] = self.accelerator.gather_for_metrics(policy_nll_loss).detach().mean().item()
1166
-
1167
- if self.aux_loss_enabled:
1168
- loss += self.aux_loss_coef * aux_loss
1169
-
1170
- return loss, metrics
1171
-
1172
- def compute_loss(
1173
- self,
1174
- model: Union[PreTrainedModel, nn.Module],
1175
- inputs: dict[str, Union[torch.Tensor, Any]],
1176
- return_outputs=False,
1177
- num_items_in_batch=None,
1178
- ) -> Union[torch.Tensor, tuple[torch.Tensor, dict[str, torch.Tensor]]]:
1179
- compute_loss_context_manager = amp.autocast("cuda") if self._peft_has_been_casted_to_bf16 else nullcontext()
1180
-
1181
- with compute_loss_context_manager:
1182
- loss, metrics = self.get_batch_loss_metrics(model, inputs, train_eval="train")
1183
-
1184
- # force log the metrics
1185
- self.store_metrics(metrics, train_eval="train")
1186
-
1187
- if return_outputs:
1188
- return (loss, metrics)
1189
- return loss
1190
-
1191
- def generate_from_model(self, model, batch: dict[str, torch.LongTensor]) -> list[str]:
1192
- """Generate samples from the model and reference model for the given batch of inputs."""
1193
-
1194
- # If one uses `generate_during_eval` with peft + bf16, we need to explicitly call generate with
1195
- # the torch cuda amp context manager as some hidden states are silently cast to full precision.
1196
- generate_context_manager = amp.autocast("cuda") if self._peft_has_been_casted_to_bf16 else nullcontext()
1197
-
1198
- with generate_context_manager:
1199
- policy_output = model.generate(
1200
- input_ids=batch["prompt_input_ids"],
1201
- attention_mask=batch["prompt_attention_mask"],
1202
- max_length=self.max_length,
1203
- do_sample=True,
1204
- pad_token_id=self.processing_class.pad_token_id,
1205
- )
1206
-
1207
- policy_output = pad_to_length(policy_output, self.max_length, self.processing_class.pad_token_id)
1208
- policy_output_decoded = self.processing_class.batch_decode(policy_output, skip_special_tokens=True)
1209
-
1210
- return policy_output_decoded
1211
-
1212
- def prediction_step(
1213
- self,
1214
- model: Union[PreTrainedModel, nn.Module],
1215
- inputs: dict[str, Union[torch.Tensor, Any]],
1216
- prediction_loss_only: bool,
1217
- ignore_keys: Optional[list[str]] = None,
1218
- ):
1219
- if ignore_keys is None:
1220
- if hasattr(model, "config"):
1221
- ignore_keys = getattr(model.config, "keys_to_ignore_at_inference", [])
1222
- else:
1223
- ignore_keys = []
1224
-
1225
- prediction_context_manager = amp.autocast("cuda") if self._peft_has_been_casted_to_bf16 else nullcontext()
1226
-
1227
- with torch.no_grad(), prediction_context_manager:
1228
- loss, metrics = self.get_batch_loss_metrics(model, inputs, train_eval="eval")
1229
-
1230
- # force log the metrics
1231
- self.store_metrics(metrics, train_eval="eval")
1232
-
1233
- if prediction_loss_only:
1234
- return (loss.detach(), None, None)
1235
-
1236
- # logits for the chosen and rejected samples from model
1237
- logits_dict = {
1238
- "eval_logits/chosen": metrics["eval_logits/chosen"],
1239
- "eval_logits/rejected": metrics["eval_logits/rejected"],
1240
- }
1241
- logits = [v for k, v in logits_dict.items() if k not in ignore_keys]
1242
- logits = torch.tensor(logits, device=self.accelerator.device)
1243
- labels = torch.zeros(logits.shape[0], device=self.accelerator.device)
1244
-
1245
- return (loss.detach(), logits, labels)
1246
-
1247
- def store_metrics(self, metrics: dict[str, float], train_eval: Literal["train", "eval"] = "train") -> None:
1248
- for key, value in metrics.items():
1249
- self._stored_metrics[train_eval][key].append(value)
1250
-
1251
- def evaluation_loop(
1252
- self,
1253
- dataloader: DataLoader,
1254
- description: str,
1255
- prediction_loss_only: Optional[bool] = None,
1256
- ignore_keys: Optional[list[str]] = None,
1257
- metric_key_prefix: str = "eval",
1258
- ) -> EvalLoopOutput:
1259
- """
1260
- Overriding built-in evaluation loop to store metrics for each batch.
1261
- Prediction/evaluation loop, shared by `Trainer.evaluate()` and `Trainer.predict()`.
1262
-
1263
- Works both with or without labels.
1264
- """
1265
-
1266
- # Sample and save to game log if requested (for one batch to save time)
1267
- if self.generate_during_eval:
1268
- # Generate random indices within the range of the total number of samples
1269
- num_samples = len(dataloader.dataset)
1270
- random_indices = random.sample(range(num_samples), k=self.args.eval_batch_size)
1271
-
1272
- # Use dataloader.dataset.select to get the random batch without iterating over the DataLoader
1273
- random_batch_dataset = dataloader.dataset.select(random_indices)
1274
- random_batch = self.data_collator(random_batch_dataset)
1275
- random_batch = self._prepare_inputs(random_batch)
1276
-
1277
- policy_output_decoded = self.generate_from_model(self.model, random_batch)
1278
-
1279
- table = pd.DataFrame(
1280
- columns=["Prompt", "Policy"],
1281
- data=[
1282
- [prompt, pol[len(prompt) :]] for prompt, pol in zip(random_batch["prompt"], policy_output_decoded)
1283
- ],
1284
- )
1285
- if "wandb" in self.args.report_to:
1286
- wandb.log({"game_log": wandb.Table(data=table)})
1287
-
1288
- if "comet_ml" in self.args.report_to:
1289
- log_table_to_comet_experiment(
1290
- name="game_log.csv",
1291
- table=table,
1292
- )
1293
-
1294
- # Base evaluation
1295
- initial_output = super().evaluation_loop(
1296
- dataloader, description, prediction_loss_only, ignore_keys, metric_key_prefix
1297
- )
1298
-
1299
- return initial_output
1300
-
1301
- def log(self, logs: dict[str, float], start_time: Optional[float] = None) -> None:
1302
- """
1303
- Log `logs` on the various objects watching training, including stored metrics.
1304
-
1305
- Args:
1306
- logs (`dict[str, float]`):
1307
- The values to log.
1308
- start_time (`float` or `None`, *optional*, defaults to `None`):
1309
- Start time of the training.
1310
- """
1311
- # logs either has 'loss' or 'eval_loss'
1312
- train_eval = "train" if "loss" in logs else "eval"
1313
- # Add averaged stored metrics to logs
1314
- for key, metrics in self._stored_metrics[train_eval].items():
1315
- logs[key] = torch.tensor(metrics).mean().item()
1316
- del self._stored_metrics[train_eval]
1317
-
1318
- if version.parse(transformers.__version__) >= version.parse("4.47.0.dev0"):
1319
- return super().log(logs, start_time)
1320
- else: # transformers<=4.46
1321
- return super().log(logs)
1322
-
1323
- def _shift_right(self, input_ids):
1324
- if self.decoder_start_token_id is None:
1325
- raise ValueError(
1326
- "model.config.decoder_start_token_id has to be defined. It is usually set to the pad_token_id."
1327
- )
1328
-
1329
- # shift inputs to the right
1330
- if is_torch_fx_proxy(input_ids):
1331
- # Item assignment is not supported natively for proxies.
1332
- shifted_input_ids = torch.full(input_ids.shape[:-1] + (1,), self.decoder_start_token_id)
1333
- shifted_input_ids = torch.cat([shifted_input_ids, input_ids[..., :-1]], dim=-1)
1334
- else:
1335
- shifted_input_ids = input_ids.new_zeros(input_ids.shape)
1336
- shifted_input_ids[..., 1:] = input_ids[..., :-1].clone()
1337
- shifted_input_ids[..., 0] = self.decoder_start_token_id
1338
-
1339
- if self.pad_token_id is None:
1340
- raise ValueError("model.config.pad_token_id has to be defined.")
1341
- # replace possible -100 values in labels by `pad_token_id`
1342
- shifted_input_ids.masked_fill_(shifted_input_ids == -100, self.pad_token_id)
1343
-
1344
- return shifted_input_ids
1345
-
1346
- def create_model_card(
1347
- self,
1348
- model_name: Optional[str] = None,
1349
- dataset_name: Optional[str] = None,
1350
- tags: Union[str, list[str], None] = None,
1351
- ):
1352
- """
1353
- Creates a draft of a model card using the information available to the `Trainer`.
1354
-
1355
- Args:
1356
- model_name (`str` or `None`, *optional*, defaults to `None`):
1357
- Name of the model.
1358
- dataset_name (`str` or `None`, *optional*, defaults to `None`):
1359
- Name of the dataset used for training.
1360
- tags (`str`, `list[str]` or `None`, *optional*, defaults to `None`):
1361
- Tags to be associated with the model card.
1362
- """
1363
- if not self.is_world_process_zero():
1364
- return
1365
-
1366
- if hasattr(self.model.config, "_name_or_path") and not os.path.isdir(self.model.config._name_or_path):
1367
- base_model = self.model.config._name_or_path
1368
- else:
1369
- base_model = None
1370
-
1371
- tags = tags or []
1372
- if isinstance(tags, str):
1373
- tags = [tags]
1374
-
1375
- if hasattr(self.model.config, "unsloth_version"):
1376
- tags.append("unsloth")
1377
-
1378
- citation = textwrap.dedent("""\
1379
- @inproceedings{xu2024contrastive,
1380
- title = {{Contrastive Preference Optimization: Pushing the Boundaries of LLM Performance in Machine Translation}},
1381
- author = {Haoran Xu and Amr Sharaf and Yunmo Chen and Weiting Tan and Lingfeng Shen and Benjamin Van Durme and Kenton Murray and Young Jin Kim},
1382
- year = 2024,
1383
- booktitle = {Forty-first International Conference on Machine Learning, {ICML} 2024, Vienna, Austria, July 21-27, 2024},
1384
- publisher = {OpenReview.net},
1385
- url = {https://openreview.net/forum?id=51iwkioZpn}
1386
- }""")
1387
-
1388
- model_card = generate_model_card(
1389
- base_model=base_model,
1390
- model_name=model_name,
1391
- hub_model_id=self.hub_model_id,
1392
- dataset_name=dataset_name,
1393
- tags=tags,
1394
- wandb_url=wandb.run.get_url() if is_wandb_available() and wandb.run is not None else None,
1395
- comet_url=get_comet_experiment_url(),
1396
- trainer_name="CPO",
1397
- trainer_citation=citation,
1398
- paper_title="Contrastive Preference Optimization: Pushing the Boundaries of LLM Performance in Machine Translation",
1399
- paper_id="2401.08417",
1400
- )
1401
- model_card.save(os.path.join(self.args.output_dir, "README.md"))
1402
- class UnslothCPOTrainer(_UnslothCPOTrainer):
1403
- """
1404
-
1405
- Initialize CPOTrainer.
1406
-
1407
- Args:
1408
- model (`transformers.PreTrainedModel`):
1409
- The model to train, preferably an `AutoModelForSequenceClassification`.
1410
- args (`CPOConfig`):
1411
- The CPO config arguments to use for training.
1412
- data_collator (`transformers.DataCollator`):
1413
- The data collator to use for training. If None is specified, the default data collator (`DPODataCollatorWithPadding`) will be used
1414
- which will pad the sequences to the maximum length of the sequences in the batch, given a dataset of paired sequences.
1415
- train_dataset (`datasets.Dataset`):
1416
- The dataset to use for training.
1417
- eval_dataset (`datasets.Dataset`):
1418
- The dataset to use for evaluation.
1419
- processing_class (`PreTrainedTokenizerBase` or `BaseImageProcessor` or `FeatureExtractionMixin` or `ProcessorMixin`, *optional*):
1420
- Processing class used to process the data. If provided, will be used to automatically process the inputs
1421
- for the model, and it will be saved along the model to make it easier to rerun an interrupted training or
1422
- reuse the fine-tuned model.
1423
- model_init (`Callable[[], transformers.PreTrainedModel]`):
1424
- The model initializer to use for training. If None is specified, the default model initializer will be used.
1425
- callbacks (`list[transformers.TrainerCallback]`):
1426
- The callbacks to use for training.
1427
- optimizers (`tuple[torch.optim.Optimizer, torch.optim.lr_scheduler.LambdaLR]`):
1428
- The optimizer and scheduler to use for training.
1429
- preprocess_logits_for_metrics (`Callable[[torch.Tensor, torch.Tensor], torch.Tensor]`):
1430
- The function to use to preprocess the logits before computing the metrics.
1431
- peft_config (`dict`, defaults to `None`):
1432
- The PEFT configuration to use for training. If you pass a PEFT configuration, the model will be wrapped in a PEFT model.
1433
- compute_metrics (`Callable[[EvalPrediction], dict]`, *optional*):
1434
- The function to use to compute the metrics. Must take a `EvalPrediction` and return
1435
- a dictionary string to metric values.
1436
-
1437
- """
1438
- def __init__(
1439
- self,
1440
- model = None,
1441
- args = None,
1442
- data_collator = None,
1443
- train_dataset = None,
1444
- eval_dataset = None,
1445
- processing_class = None,
1446
- model_init = None,
1447
- callbacks = None,
1448
- preprocess_logits_for_metrics = None,
1449
- peft_config = None,
1450
- compute_metrics = None,
1451
- **kwargs
1452
- ):
1453
- if args is None: args = UnslothCPOConfig()
1454
- use_bf16 = getattr(args, 'bf16', False)
1455
- if type(use_bf16) is not bool: use_bf16 = False
1456
- use_fp16 = getattr(args, 'fp16', False)
1457
- if type(use_fp16) is not bool: use_fp16 = False
1458
- force_float32 = False
1459
- if os.environ.get('UNSLOTH_FORCE_FLOAT32', '0') == '1':
1460
- print('Unsloth: Switching to float32 training since model cannot work with float16')
1461
- force_float32 = True
1462
- mixed_precision_dtype = os.environ.get('UNSLOTH_MIXED_PRECISION', 'float32')
1463
- dtype = getattr(model.config, 'torch_dtype', None)
1464
- if dtype is None: dtype = model.get_input_embeddings().dtype
1465
- from unsloth_zoo.utils import _get_dtype
1466
- dtype = _get_dtype(dtype)
1467
- float16 = dtype == torch.float16
1468
- if not force_float32 and (float16 and use_bf16): raise TypeError('Unsloth: Model is in float16 precision but you want to use bfloat16 precision. Set fp16 to `True` and bf16 to `False`')
1469
- if not force_float32 and (not float16 and use_fp16): raise TypeError('Unsloth: Model is in bfloat16 precision but you want to use float16 precision. Set fp16 to `False` and bf16 to `True`')
1470
- if force_float32:
1471
- args.fp16 = False
1472
- args.bf16 = False
1473
- os.environ['ACCELERATE_MIXED_PRECISION'] = 'no'
1474
- elif (not use_bf16 and not use_fp16) and mixed_precision_dtype == 'float32':
1475
- args.fp16 = float16
1476
- args.bf16 = not float16
1477
- os.environ['ACCELERATE_MIXED_PRECISION'] = 'fp16' if float16 else 'bf16'
1478
- if getattr(args, 'eval_dataset', None) is not None and getattr(args, 'eval_strategy', 'no') == 'no':
1479
- args.eval_strategy = 'steps'
1480
- if getattr(args, 'eval_steps', None) is None: args.eval_steps = 0.1
1481
- ga_steps = getattr(args, 'gradient_accumulation_steps', None)
1482
- if ga_steps is not None and ga_steps > 1:
1483
- from transformers import __version__ as transformers_version
1484
- if Version(transformers_version) <= Version('4.45.2'):
1485
- print('**** Unsloth: Please use our fixed gradient_accumulation_steps by updating transformers, TRL and Unsloth!\n'
1486
- '`pip install --upgrade --no-cache-dir --force-reinstall --no-deps unsloth transformers trl unsloth_zoo`')
1487
- if getattr(args, 'eval_strategy', 'no') != 'no':
1488
- eval_bsz = getattr(args, 'per_device_eval_batch_size', 8)
1489
- if eval_bsz == 8 and args.per_device_train_batch_size < eval_bsz: args.per_device_eval_batch_size = args.per_device_train_batch_size
1490
- if getattr(args, 'eval_accumulation_steps', None) is None and ga_steps is not None: args.eval_accumulation_steps = ga_steps
1491
- fp16_full_eval = getattr(args, 'fp16_full_eval', False)
1492
- if type(fp16_full_eval) is not bool: fp16_full_eval = False
1493
- bf16_full_eval = getattr(args, 'bf16_full_eval', False)
1494
- if type(bf16_full_eval) is not bool: bf16_full_eval = False
1495
- if args.fp16 and bf16_full_eval: args.bf16_full_eval = False; args.fp16_full_eval = True
1496
- if args.bf16 and fp16_full_eval: args.bf16_full_eval = True; args.fp16_full_eval = False
1497
- if force_float32:
1498
- args.bf16_full_eval = False
1499
- args.fp16_full_eval = False
1500
- elif os.environ.get('UNSLOTH_MIXED_PRECISION', 'float32') == 'bfloat16':
1501
- args.bf16_full_eval = True
1502
- args.fp16_full_eval = False
1503
- elif not bf16_full_eval and not fp16_full_eval:
1504
- args.bf16_full_eval = args.bf16
1505
- args.fp16_full_eval = args.fp16
1506
- _output_logits = False
1507
- if locals().get('compute_metrics', None) is not None: _output_logits = True
1508
- if locals().get('preprocess_logits_for_metrics', None) is not None: _output_logits = True
1509
- if _output_logits:
1510
- os.environ['UNSLOTH_RETURN_LOGITS'] = '1'
1511
- if 'max_seq_length' not in locals() and not hasattr(args, 'max_seq_length'):
1512
- pass
1513
- else:
1514
- model_max_seq_length = getattr(model, 'max_seq_length', None)
1515
- args_max_seq_length = getattr(args, 'max_seq_length', None)
1516
- if args_max_seq_length is None and model_max_seq_length is not None:
1517
- max_seq_length = model.max_seq_length
1518
- if hasattr(args, 'max_seq_length'): args.max_seq_length = max_seq_length
1519
- if model is not None and hasattr(model, 'for_training'):
1520
- model.for_training()
1521
- if 'tokenizer' in locals() and hasattr(tokenizer, 'padding_side'): tokenizer.padding_side = 'right'
1522
- if 'processing_class' in locals():
1523
- if hasattr(processing_class, 'padding_side'): processing_class.padding_side = 'right'
1524
- if hasattr(processing_class, 'tokenizer') and hasattr(processing_class.tokenizer, 'padding_side'): processing_class.tokenizer.padding_side = 'right'
1525
- __tokenizer = processing_class if 'processing_class' in locals() else tokenizer
1526
- from unsloth_zoo.vision_utils import UnslothVisionDataCollator
1527
- if not isinstance(data_collator, UnslothVisionDataCollator):
1528
- if isinstance(data_collator, DataCollatorForSeq2Seq) and 'labels' not in train_dataset.column_names:
1529
- data_collator = TransformersDataCollatorForLanguageModeling(__tokenizer, mlm = False, mlm_probability = 0.0)
1530
- elif isinstance(data_collator, TransformersDataCollatorForLanguageModeling) and 'labels' in train_dataset.column_names:
1531
- data_collator = DataCollatorForSeq2Seq(__tokenizer)
1532
- else:
1533
- if hasattr(args, 'remove_unused_columns'): args.remove_unused_columns = False
1534
- if hasattr(args, 'dataset_text_field'): args.dataset_text_field = ''
1535
- if hasattr(args, 'dataset_kwargs'): args.dataset_kwargs = {'skip_prepare_dataset': True}
1536
- if not isinstance(data_collator, UnslothVisionDataCollator):
1537
- if not hasattr(__tokenizer, 'pad') and hasattr(__tokenizer, 'tokenizer'):
1538
- if isinstance(data_collator, DataCollatorForSeq2Seq):
1539
- data_collator = DataCollatorForSeq2Seq(__tokenizer.tokenizer)
1540
- else:
1541
- data_collator = TransformersDataCollatorForLanguageModeling(__tokenizer.tokenizer, mlm = False, mlm_probability = 0.0)
1542
- other_metrics = []
1543
-
1544
- from unsloth_zoo.logging_utils import PatchRLStatistics
1545
- PatchRLStatistics('cpo_trainer', other_metrics)
1546
-
1547
- super().__init__(
1548
- model = model,
1549
- args = args,
1550
- data_collator = data_collator,
1551
- train_dataset = train_dataset,
1552
- eval_dataset = eval_dataset,
1553
- processing_class = processing_class,
1554
- model_init = model_init,
1555
- callbacks = callbacks,
1556
- preprocess_logits_for_metrics = preprocess_logits_for_metrics,
1557
- peft_config = peft_config,
1558
- compute_metrics = compute_metrics,**kwargs)
1559
- if hasattr(self, 'neftune_hook_handle'):
1560
- self.neftune_hook_handle.remove()
1561
- if hasattr(self, 'neftune_hook_handle'): del self.neftune_hook_handle
1562
- if getattr(args, 'neftune_noise_alpha', None) is not None:
1563
- model.get_input_embeddings().neftune_noise_alpha = self.neftune_noise_alpha
1564
- pass
1565
-
1566
- pass
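
Aside (not part of the deleted file): the decoder-input shift implemented by `_shift_right` above can be exercised on its own. A minimal sketch, assuming a standalone function — the name `shift_right` and the sample token ids below are made up for illustration:

import torch

def shift_right(input_ids: torch.Tensor, decoder_start_token_id: int, pad_token_id: int) -> torch.Tensor:
    # Prepend the decoder start token and drop the last position
    shifted = input_ids.new_zeros(input_ids.shape)
    shifted[..., 1:] = input_ids[..., :-1].clone()
    shifted[..., 0] = decoder_start_token_id
    # Labels use -100 as the ignore index; replace any that survive the shift with the pad token
    shifted.masked_fill_(shifted == -100, pad_token_id)
    return shifted

labels = torch.tensor([[15, 27, 42, -100]])
print(shift_right(labels, decoder_start_token_id=0, pad_token_id=0))
# tensor([[ 0, 15, 27, 42]])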
test_run_uploads/UnslothDDPOTrainer.py DELETED
@@ -1,881 +0,0 @@
- """
- 2025.7.11
- 2025.7.11
- 4.54.1
- 0.16.1
- __UNSLOTH_VERSIONING__
- """
- from torch import Tensor
- import torch
- import torch.nn as nn
- from torch.nn import functional as F
- from typing import Any, List, Optional, Tuple, Union, Dict, Set, Callable
- from trl.trainer.ddpo_trainer import (Accelerator, Any, Callable, DDPOConfig, DDPOStableDiffusionPipeline, DDPOTrainer, Optional, PerPromptStatTracker, ProjectConfiguration, PyTorchModelHubMixin, Union, defaultdict, futures, generate_model_card, get_comet_experiment_url, is_wandb_available, logger, os, set_seed, textwrap, torch, wandb, warn)
-
-
- import os
- from typing import *
- from dataclasses import dataclass, field
- from packaging.version import Version
- import torch
- import numpy as np
- from contextlib import nullcontext
- from torch.nn import functional as F
- from transformers import DataCollatorForSeq2Seq, DataCollatorForLanguageModeling as TransformersDataCollatorForLanguageModeling
-
- torch_compile_options = {
-     "epilogue_fusion" : True,
-     "max_autotune" : False,
-     "shape_padding" : True,
-     "trace.enabled" : False,
-     "triton.cudagraphs" : False,
- }
-
- @torch.compile(dynamic = True, fullgraph = True, options = torch_compile_options,)
- def chunked_selective_log_softmax(logits, index):
-     # Split into 4 chunks only
-     chunked_logits = torch.chunk(logits.reshape(-1, logits.shape[-1]), chunks = 4, dim = 0)
-     chunked_index = torch.chunk(index.reshape(-1), chunks = 4, dim = 0)
-     all_per_token_logps = []
-     # Below loop does the same as selective_log_softmax(chunk_logits, chunk_index)
-     for chunk_logits, chunk_index in zip(chunked_logits, chunked_index):
-         chunk_logits = chunk_logits.to(torch.float32)
-         selected_logits = torch.gather(chunk_logits, dim = -1, index = chunk_index.unsqueeze(-1)).squeeze(-1)
-         logsumexp_values = torch.logsumexp(chunk_logits, dim = -1)
-         per_token_logps = selected_logits - logsumexp_values
-         all_per_token_logps.append(per_token_logps)
-     pass
-     all_per_token_logps = torch.concat(all_per_token_logps)
-     all_per_token_logps = all_per_token_logps.reshape((logits.shape[0], logits.shape[1]))
-     return all_per_token_logps
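
Aside (not part of the deleted file): `chunked_selective_log_softmax` above is numerically equivalent to gathering target-token log-probabilities from a full log-softmax; the four-way chunking only bounds peak memory on large vocabularies. A minimal reference sketch (the function name here is hypothetical):

import torch

def selective_log_softmax_reference(logits: torch.Tensor, index: torch.Tensor) -> torch.Tensor:
    # log p(token) = logit(token) - logsumexp(logits), i.e. log_softmax followed by a gather
    logps = torch.log_softmax(logits.to(torch.float32), dim=-1)
    return logps.gather(dim=-1, index=index.unsqueeze(-1)).squeeze(-1)

logits = torch.randn(2, 8, 32)          # (batch, seq_len, vocab)
index = torch.randint(0, 32, (2, 8))    # target token ids
print(selective_log_softmax_reference(logits, index).shape)  # torch.Size([2, 8])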
- @dataclass
- class UnslothDDPOConfig(DDPOConfig):
-     """
-
-     Configuration class for the [`DDPOTrainer`].
-
-     Using [`~transformers.HfArgumentParser`] we can turn this class into
-     [argparse](https://docs.python.org/3/library/argparse#module-argparse) arguments that can be specified on the
-     command line.
-
-     Parameters:
-         exp_name (`str`, *optional*, defaults to `os.path.basename(sys.argv[0])[: -len(".py")]`):
-             Name of this experiment (by default is the file name without the extension name).
-         run_name (`str`, *optional*, defaults to `""`):
-             Name of this run.
-         seed (`int`, *optional*, defaults to `0`):
-             Random seed.
-         log_with (`Literal["wandb", "tensorboard"]]` or `None`, *optional*, defaults to `None`):
-             Log with either 'wandb' or 'tensorboard', check
-             https://huggingface.co/docs/accelerate/usage_guides/tracking for more details.
-         tracker_kwargs (`Dict`, *optional*, defaults to `{}`):
-             Keyword arguments for the tracker (e.g. wandb_project).
-         accelerator_kwargs (`Dict`, *optional*, defaults to `{}`):
-             Keyword arguments for the accelerator.
-         project_kwargs (`Dict`, *optional*, defaults to `{}`):
-             Keyword arguments for the accelerator project config (e.g. `logging_dir`).
-         tracker_project_name (`str`, *optional*, defaults to `"trl"`):
-             Name of project to use for tracking.
-         logdir (`str`, *optional*, defaults to `"logs"`):
-             Top-level logging directory for checkpoint saving.
-         num_epochs (`int`, *optional*, defaults to `100`):
-             Number of epochs to train.
-         save_freq (`int`, *optional*, defaults to `1`):
-             Number of epochs between saving model checkpoints.
-         num_checkpoint_limit (`int`, *optional*, defaults to `5`):
-             Number of checkpoints to keep before overwriting old ones.
-         mixed_precision (`str`, *optional*, defaults to `"fp16"`):
-             Mixed precision training.
-         allow_tf32 (`bool`, *optional*, defaults to `True`):
-             Allow `tf32` on Ampere GPUs.
-         resume_from (`str`, *optional*, defaults to `""`):
-             Resume training from a checkpoint.
-         sample_num_steps (`int`, *optional*, defaults to `50`):
-             Number of sampler inference steps.
-         sample_eta (`float`, *optional*, defaults to `1.0`):
-             Eta parameter for the DDIM sampler.
-         sample_guidance_scale (`float`, *optional*, defaults to `5.0`):
-             Classifier-free guidance weight.
-         sample_batch_size (`int`, *optional*, defaults to `1`):
-             Batch size (per GPU) to use for sampling.
-         sample_num_batches_per_epoch (`int`, *optional*, defaults to `2`):
-             Number of batches to sample per epoch.
-         train_batch_size (`int`, *optional*, defaults to `1`):
-             Batch size (per GPU) to use for training.
-         train_use_8bit_adam (`bool`, *optional*, defaults to `False`):
-             Use 8bit Adam optimizer from bitsandbytes.
-         train_learning_rate (`float`, *optional*, defaults to `3e-4`):
-             Learning rate.
-         train_adam_beta1 (`float`, *optional*, defaults to `0.9`):
-             Adam beta1.
-         train_adam_beta2 (`float`, *optional*, defaults to `0.999`):
-             Adam beta2.
-         train_adam_weight_decay (`float`, *optional*, defaults to `1e-4`):
-             Adam weight decay.
-         train_adam_epsilon (`float`, *optional*, defaults to `1e-8`):
-             Adam epsilon.
-         train_gradient_accumulation_steps (`int`, *optional*, defaults to `1`):
-             Number of gradient accumulation steps.
-         train_max_grad_norm (`float`, *optional*, defaults to `1.0`):
-             Maximum gradient norm for gradient clipping.
-         train_num_inner_epochs (`int`, *optional*, defaults to `1`):
-             Number of inner epochs per outer epoch.
-         train_cfg (`bool`, *optional*, defaults to `True`):
-             Whether to use classifier-free guidance during training.
-         train_adv_clip_max (`float`, *optional*, defaults to `5.0`):
-             Clip advantages to the range.
-         train_clip_range (`float`, *optional*, defaults to `1e-4`):
-             PPO clip range.
-         train_timestep_fraction (`float`, *optional*, defaults to `1.0`):
-             Fraction of timesteps to train on.
-         per_prompt_stat_tracking (`bool`, *optional*, defaults to `False`):
-             Whether to track statistics for each prompt separately.
-         per_prompt_stat_tracking_buffer_size (`int`, *optional*, defaults to `16`):
-             Number of reward values to store in the buffer for each prompt.
-         per_prompt_stat_tracking_min_count (`int`, *optional*, defaults to `16`):
-             Minimum number of reward values to store in the buffer.
-         async_reward_computation (`bool`, *optional*, defaults to `False`):
-             Whether to compute rewards asynchronously.
-         max_workers (`int`, *optional*, defaults to `2`):
-             Maximum number of workers to use for async reward computation.
-         negative_prompts (`str`, *optional*, defaults to `""`):
-             Comma-separated list of prompts to use as negative examples.
-         push_to_hub (`bool`, *optional*, defaults to `False`):
-             Whether to push the final model checkpoint to the Hub.
-
-     """
-     vllm_sampling_params: Optional[Any] = field(
-         default = None,
-         metadata = {'help': 'vLLM SamplingParams'},
-     )
-     unsloth_num_chunks : Optional[int] = field(
-         default = -1,
-         metadata = {'help': 'Chunk size to reduce memory usage. -1 is most efficient.'},
-     )
-     def __init__(
-         self,
-         exp_name = 'colab_kernel_launcher',
-         run_name = '',
-         seed = 3407,
-         log_with = None,
-         tracker_project_name = 'trl',
-         logdir = 'logs',
-         num_epochs = 100,
-         save_freq = 1,
-         num_checkpoint_limit = 5,
-         mixed_precision = 'fp16',
-         allow_tf32 = True,
-         resume_from = '',
-         sample_num_steps = 50,
-         sample_eta = 1.0,
-         sample_guidance_scale = 5.0,
-         sample_batch_size = 1,
-         sample_num_batches_per_epoch = 2,
-         train_batch_size = 1,
-         train_use_8bit_adam = False,
-         train_learning_rate = 5e-05,
-         train_adam_beta1 = 0.9,
-         train_adam_beta2 = 0.999,
-         train_adam_weight_decay = 0.01,
-         train_adam_epsilon = 1e-08,
-         train_gradient_accumulation_steps = 2,
-         train_max_grad_norm = 1.0,
-         train_num_inner_epochs = 1,
-         train_cfg = True,
-         train_adv_clip_max = 5.0,
-         train_clip_range = 0.0001,
-         train_timestep_fraction = 1.0,
-         per_prompt_stat_tracking = False,
-         per_prompt_stat_tracking_buffer_size = 16,
-         per_prompt_stat_tracking_min_count = 16,
-         async_reward_computation = False,
-         max_workers = 2,
-         negative_prompts = '',
-         push_to_hub = False,
-         vllm_sampling_params = None,
-         unsloth_num_chunks = -1,
-         **kwargs,
-     ):
-
-         super().__init__(
-             exp_name = exp_name,
-             run_name = run_name,
-             seed = seed,
-             log_with = log_with,
-             tracker_project_name = tracker_project_name,
-             logdir = logdir,
-             num_epochs = num_epochs,
-             save_freq = save_freq,
-             num_checkpoint_limit = num_checkpoint_limit,
-             mixed_precision = mixed_precision,
-             allow_tf32 = allow_tf32,
-             resume_from = resume_from,
-             sample_num_steps = sample_num_steps,
-             sample_eta = sample_eta,
-             sample_guidance_scale = sample_guidance_scale,
-             sample_batch_size = sample_batch_size,
-             sample_num_batches_per_epoch = sample_num_batches_per_epoch,
-             train_batch_size = train_batch_size,
-             train_use_8bit_adam = train_use_8bit_adam,
-             train_learning_rate = train_learning_rate,
-             train_adam_beta1 = train_adam_beta1,
-             train_adam_beta2 = train_adam_beta2,
-             train_adam_weight_decay = train_adam_weight_decay,
-             train_adam_epsilon = train_adam_epsilon,
-             train_gradient_accumulation_steps = train_gradient_accumulation_steps,
-             train_max_grad_norm = train_max_grad_norm,
-             train_num_inner_epochs = train_num_inner_epochs,
-             train_cfg = train_cfg,
-             train_adv_clip_max = train_adv_clip_max,
-             train_clip_range = train_clip_range,
-             train_timestep_fraction = train_timestep_fraction,
-             per_prompt_stat_tracking = per_prompt_stat_tracking,
-             per_prompt_stat_tracking_buffer_size = per_prompt_stat_tracking_buffer_size,
-             per_prompt_stat_tracking_min_count = per_prompt_stat_tracking_min_count,
-             async_reward_computation = async_reward_computation,
-             max_workers = max_workers,
-             negative_prompts = negative_prompts,
-             push_to_hub = push_to_hub,**kwargs)
-         self.vllm_sampling_params = vllm_sampling_params
-         self.unsloth_num_chunks = unsloth_num_chunks
- pass
-
- class _UnslothDDPOTrainer(PyTorchModelHubMixin):
-     """"""
-
-     _tag_names = ["trl", "ddpo"]
-
-     def __init__(
-         self,
-         config: DDPOConfig,
-         reward_function: Callable[[torch.Tensor, tuple[str], tuple[Any]], torch.Tensor],
-         prompt_function: Callable[[], tuple[str, Any]],
-         sd_pipeline: DDPOStableDiffusionPipeline,
-         image_samples_hook: Optional[Callable[[Any, Any, Any], Any]] = None,
-     ):
-         if image_samples_hook is None:
-             warn("No image_samples_hook provided; no images will be logged")
-
-         self.prompt_fn = prompt_function
-         self.reward_fn = reward_function
-         self.config = config
-         self.image_samples_callback = image_samples_hook
-
-         accelerator_project_config = ProjectConfiguration(**self.config.project_kwargs)
-
-         if self.config.resume_from:
-             self.config.resume_from = os.path.normpath(os.path.expanduser(self.config.resume_from))
-             if "checkpoint_" not in os.path.basename(self.config.resume_from):
-                 # get the most recent checkpoint in this directory
-                 checkpoints = list(
-                     filter(
-                         lambda x: "checkpoint_" in x,
-                         os.listdir(self.config.resume_from),
-                     )
-                 )
-                 if len(checkpoints) == 0:
-                     raise ValueError(f"No checkpoints found in {self.config.resume_from}")
-                 checkpoint_numbers = sorted([int(x.split("_")[-1]) for x in checkpoints])
-                 self.config.resume_from = os.path.join(
-                     self.config.resume_from,
-                     f"checkpoint_{checkpoint_numbers[-1]}",
-                 )
-
-                 accelerator_project_config.iteration = checkpoint_numbers[-1] + 1
-
-         # number of timesteps within each trajectory to train on
-         self.num_train_timesteps = int(self.config.sample_num_steps * self.config.train_timestep_fraction)
-
-         self.accelerator = Accelerator(
-             log_with=self.config.log_with,
-             mixed_precision=self.config.mixed_precision,
-             project_config=accelerator_project_config,
-             # we always accumulate gradients across timesteps; we want config.train.gradient_accumulation_steps to be the
-             # number of *samples* we accumulate across, so we need to multiply by the number of training timesteps to get
-             # the total number of optimizer steps to accumulate across.
-             gradient_accumulation_steps=self.config.train_gradient_accumulation_steps * self.num_train_timesteps,
-             **self.config.accelerator_kwargs,
-         )
-
-         is_okay, message = self._config_check()
-         if not is_okay:
-             raise ValueError(message)
-
-         is_using_tensorboard = config.log_with is not None and config.log_with == "tensorboard"
-
-         if self.accelerator.is_main_process:
-             self.accelerator.init_trackers(
-                 self.config.tracker_project_name,
-                 config=dict(ddpo_trainer_config=config.to_dict()) if not is_using_tensorboard else config.to_dict(),
-                 init_kwargs=self.config.tracker_kwargs,
-             )
-
-         logger.info(f"\n{config}")
-
-         set_seed(self.config.seed, device_specific=True)
-
-         self.sd_pipeline = sd_pipeline
-
-         self.sd_pipeline.set_progress_bar_config(
-             position=1,
-             disable=not self.accelerator.is_local_main_process,
-             leave=False,
-             desc="Timestep",
-             dynamic_ncols=True,
-         )
-
-         # For mixed precision training we cast all non-trainable weights [vae, non-lora text_encoder and non-lora unet] to half-precision
-         # as these weights are only used for inference, keeping weights in full precision is not required.
-         if self.accelerator.mixed_precision == "fp16":
-             inference_dtype = torch.float16
-         elif self.accelerator.mixed_precision == "bf16":
-             inference_dtype = torch.bfloat16
-         else:
-             inference_dtype = torch.float32
-
-         self.sd_pipeline.vae.to(self.accelerator.device, dtype=inference_dtype)
-         self.sd_pipeline.text_encoder.to(self.accelerator.device, dtype=inference_dtype)
-         self.sd_pipeline.unet.to(self.accelerator.device, dtype=inference_dtype)
-
-         trainable_layers = self.sd_pipeline.get_trainable_layers()
-
-         self.accelerator.register_save_state_pre_hook(self._save_model_hook)
-         self.accelerator.register_load_state_pre_hook(self._load_model_hook)
-
-         # Enable TF32 for faster training on Ampere GPUs,
-         # cf https://pytorch.org/docs/stable/notes/cuda.html#tensorfloat-32-tf32-on-ampere-devices
-         if self.config.allow_tf32:
-             torch.backends.cuda.matmul.allow_tf32 = True
-
-         self.optimizer = self._setup_optimizer(
-             trainable_layers.parameters() if not isinstance(trainable_layers, list) else trainable_layers
-         )
-
-         self.neg_prompt_embed = self.sd_pipeline.text_encoder(
-             self.sd_pipeline.tokenizer(
-                 [""] if self.config.negative_prompts is None else self.config.negative_prompts,
-                 return_tensors="pt",
-                 padding="max_length",
-                 truncation=True,
-                 max_length=self.sd_pipeline.tokenizer.model_max_length,
-             ).input_ids.to(self.accelerator.device)
-         )[0]
-
-         if config.per_prompt_stat_tracking:
-             self.stat_tracker = PerPromptStatTracker(
-                 config.per_prompt_stat_tracking_buffer_size,
-                 config.per_prompt_stat_tracking_min_count,
-             )
-
-         # NOTE: for some reason, autocast is necessary for non-lora training but for lora training it isn't necessary and it uses
-         # more memory
-         self.autocast = self.sd_pipeline.autocast or self.accelerator.autocast
-
-         if hasattr(self.sd_pipeline, "use_lora") and self.sd_pipeline.use_lora:
-             unet, self.optimizer = self.accelerator.prepare(trainable_layers, self.optimizer)
-             self.trainable_layers = list(filter(lambda p: p.requires_grad, unet.parameters()))
-         else:
-             self.trainable_layers, self.optimizer = self.accelerator.prepare(trainable_layers, self.optimizer)
-
-         if self.config.async_reward_computation:
-             self.executor = futures.ThreadPoolExecutor(max_workers=config.max_workers)
-
-         if config.resume_from:
-             logger.info(f"Resuming from {config.resume_from}")
-             self.accelerator.load_state(config.resume_from)
-             self.first_epoch = int(config.resume_from.split("_")[-1]) + 1
-         else:
-             self.first_epoch = 0
-
-     def compute_rewards(self, prompt_image_pairs, is_async=False):
-         if not is_async:
-             rewards = []
-             for images, prompts, prompt_metadata in prompt_image_pairs:
-                 reward, reward_metadata = self.reward_fn(images, prompts, prompt_metadata)
-                 rewards.append(
-                     (
-                         torch.as_tensor(reward, device=self.accelerator.device),
-                         reward_metadata,
-                     )
-                 )
-         else:
-             rewards = self.executor.map(lambda x: self.reward_fn(*x), prompt_image_pairs)
-             rewards = [
-                 (torch.as_tensor(reward.result(), device=self.accelerator.device), reward_metadata.result())
-                 for reward, reward_metadata in rewards
-             ]
-
-         return zip(*rewards)
-
-     def step(self, epoch: int, global_step: int):
-         """
-         Perform a single step of training.
-
-         Args:
-             epoch (int): The current epoch.
-             global_step (int): The current global step.
-
-         Side Effects:
-             - Model weights are updated
-             - Logs the statistics to the accelerator trackers.
-             - If `self.image_samples_callback` is not None, it will be called with the prompt_image_pairs, global_step, and the accelerator tracker.
-
-         Returns:
-             global_step (int): The updated global step.
-
-         """
-         samples, prompt_image_data = self._generate_samples(
-             iterations=self.config.sample_num_batches_per_epoch,
-             batch_size=self.config.sample_batch_size,
-         )
-
-         # collate samples into dict where each entry has shape (num_batches_per_epoch * sample.batch_size, ...)
-         samples = {k: torch.cat([s[k] for s in samples]) for k in samples[0].keys()}
-         rewards, rewards_metadata = self.compute_rewards(
-             prompt_image_data, is_async=self.config.async_reward_computation
-         )
-
-         for i, image_data in enumerate(prompt_image_data):
-             image_data.extend([rewards[i], rewards_metadata[i]])
-
-         if self.image_samples_callback is not None:
-             self.image_samples_callback(prompt_image_data, global_step, self.accelerator.trackers[0])
-
-         rewards = torch.cat(rewards)
-         rewards = self.accelerator.gather(rewards).cpu().numpy()
-
-         self.accelerator.log(
-             {
-                 "reward": rewards,
-                 "epoch": epoch,
-                 "reward_mean": rewards.mean(),
-                 "reward_std": rewards.std(),
-             },
-             step=global_step,
-         )
-
-         if self.config.per_prompt_stat_tracking:
-             # gather the prompts across processes
-             prompt_ids = self.accelerator.gather(samples["prompt_ids"]).cpu().numpy()
-             prompts = self.sd_pipeline.tokenizer.batch_decode(prompt_ids, skip_special_tokens=True)
-             advantages = self.stat_tracker.update(prompts, rewards)
-         else:
-             advantages = (rewards - rewards.mean()) / (rewards.std() + 1e-8)
-
-         # ungather advantages; keep the entries corresponding to the samples on this process
-         samples["advantages"] = (
-             torch.as_tensor(advantages)
-             .reshape(self.accelerator.num_processes, -1)[self.accelerator.process_index]
-             .to(self.accelerator.device)
-         )
-
-         del samples["prompt_ids"]
-
-         total_batch_size, num_timesteps = samples["timesteps"].shape
-
-         for inner_epoch in range(self.config.train_num_inner_epochs):
-             # shuffle samples along batch dimension
-             perm = torch.randperm(total_batch_size, device=self.accelerator.device)
-             samples = {k: v[perm] for k, v in samples.items()}
-
-             # shuffle along time dimension independently for each sample
-             # still trying to understand the code below
-             perms = torch.stack(
-                 [torch.randperm(num_timesteps, device=self.accelerator.device) for _ in range(total_batch_size)]
-             )
-
-             for key in ["timesteps", "latents", "next_latents", "log_probs"]:
-                 samples[key] = samples[key][
-                     torch.arange(total_batch_size, device=self.accelerator.device)[:, None],
-                     perms,
-                 ]
-
-             original_keys = samples.keys()
-             original_values = samples.values()
-             # rebatch them as user defined train_batch_size is different from sample_batch_size
-             reshaped_values = [v.reshape(-1, self.config.train_batch_size, *v.shape[1:]) for v in original_values]
-
-             # Transpose the list of original values
-             transposed_values = zip(*reshaped_values)
-             # Create new dictionaries for each row of transposed values
-             samples_batched = [dict(zip(original_keys, row_values)) for row_values in transposed_values]
-
-             self.sd_pipeline.unet.train()
-             global_step = self._train_batched_samples(inner_epoch, epoch, global_step, samples_batched)
-             # ensure optimization step at the end of the inner epoch
-             if not self.accelerator.sync_gradients:
-                 raise ValueError(
-                     "Optimization step should have been performed by this point. Please check calculated gradient accumulation settings."
-                 )
-
-         if epoch != 0 and epoch % self.config.save_freq == 0 and self.accelerator.is_main_process:
-             self.accelerator.save_state()
-
-         return global_step
-
-     def calculate_loss(self, latents, timesteps, next_latents, log_probs, advantages, embeds):
-         """
-         Calculate the loss for a batch of an unpacked sample
-
-         Args:
-             latents (torch.Tensor):
-                 The latents sampled from the diffusion model, shape: [batch_size, num_channels_latents, height, width]
-             timesteps (torch.Tensor):
-                 The timesteps sampled from the diffusion model, shape: [batch_size]
-             next_latents (torch.Tensor):
-                 The next latents sampled from the diffusion model, shape: [batch_size, num_channels_latents, height, width]
-             log_probs (torch.Tensor):
-                 The log probabilities of the latents, shape: [batch_size]
-             advantages (torch.Tensor):
-                 The advantages of the latents, shape: [batch_size]
-             embeds (torch.Tensor):
-                 The embeddings of the prompts, shape: [2*batch_size or batch_size, ...]
-                 Note: the "or" is because if train_cfg is True, the expectation is that negative prompts are concatenated to the embeds
-
-         Returns:
-             loss (torch.Tensor), approx_kl (torch.Tensor), clipfrac (torch.Tensor)
-             (all of these are of shape (1,))
-         """
-         with self.autocast():
-             if self.config.train_cfg:
-                 noise_pred = self.sd_pipeline.unet(
-                     torch.cat([latents] * 2),
-                     torch.cat([timesteps] * 2),
-                     embeds,
-                 ).sample
-                 noise_pred_uncond, noise_pred_text = noise_pred.chunk(2)
-                 noise_pred = noise_pred_uncond + self.config.sample_guidance_scale * (
-                     noise_pred_text - noise_pred_uncond
-                 )
-             else:
-                 noise_pred = self.sd_pipeline.unet(
-                     latents,
-                     timesteps,
-                     embeds,
-                 ).sample
-             # compute the log prob of next_latents given latents under the current model
-
-             scheduler_step_output = self.sd_pipeline.scheduler_step(
-                 noise_pred,
-                 timesteps,
-                 latents,
-                 eta=self.config.sample_eta,
-                 prev_sample=next_latents,
-             )
-
-             log_prob = scheduler_step_output.log_probs
-
-         advantages = torch.clamp(
-             advantages,
-             -self.config.train_adv_clip_max,
-             self.config.train_adv_clip_max,
-         )
-
-         ratio = torch.exp(log_prob - log_probs)
-
-         loss = self.loss(advantages, self.config.train_clip_range, ratio)
-
-         approx_kl = 0.5 * torch.mean((log_prob - log_probs) ** 2)
-
-         clipfrac = torch.mean((torch.abs(ratio - 1.0) > self.config.train_clip_range).float())
-
-         return loss, approx_kl, clipfrac
-
-     def loss(
-         self,
-         advantages: torch.Tensor,
-         clip_range: float,
-         ratio: torch.Tensor,
-     ):
-         unclipped_loss = -advantages * ratio
-         clipped_loss = -advantages * torch.clamp(
-             ratio,
-             1.0 - clip_range,
-             1.0 + clip_range,
-         )
-         return torch.mean(torch.maximum(unclipped_loss, clipped_loss))
-
-     def _setup_optimizer(self, trainable_layers_parameters):
-         if self.config.train_use_8bit_adam:
-             import bitsandbytes
-
-             optimizer_cls = bitsandbytes.optim.AdamW8bit
-         else:
-             optimizer_cls = torch.optim.AdamW
-
-         return optimizer_cls(
-             trainable_layers_parameters,
-             lr=self.config.train_learning_rate,
-             betas=(self.config.train_adam_beta1, self.config.train_adam_beta2),
-             weight_decay=self.config.train_adam_weight_decay,
-             eps=self.config.train_adam_epsilon,
-         )
-
-     def _save_model_hook(self, models, weights, output_dir):
-         self.sd_pipeline.save_checkpoint(models, weights, output_dir)
-         weights.pop() # ensures that accelerate doesn't try to handle saving of the model
-
-     def _load_model_hook(self, models, input_dir):
-         self.sd_pipeline.load_checkpoint(models, input_dir)
-         models.pop() # ensures that accelerate doesn't try to handle loading of the model
-
-     def _generate_samples(self, iterations, batch_size):
-         """
-         Generate samples from the model
-
-         Args:
-             iterations (int): Number of iterations to generate samples for
-             batch_size (int): Batch size to use for sampling
-
-         Returns:
-             samples (list[dict[str, torch.Tensor]]), prompt_image_pairs (list[list[Any]])
-         """
-         samples = []
-         prompt_image_pairs = []
-         self.sd_pipeline.unet.eval()
-
-         sample_neg_prompt_embeds = self.neg_prompt_embed.repeat(batch_size, 1, 1)
-
-         for _ in range(iterations):
-             prompts, prompt_metadata = zip(*[self.prompt_fn() for _ in range(batch_size)])
-
-             prompt_ids = self.sd_pipeline.tokenizer(
-                 prompts,
-                 return_tensors="pt",
-                 padding="max_length",
-                 truncation=True,
-                 max_length=self.sd_pipeline.tokenizer.model_max_length,
-             ).input_ids.to(self.accelerator.device)
-             prompt_embeds = self.sd_pipeline.text_encoder(prompt_ids)[0]
-
-             with self.autocast():
-                 sd_output = self.sd_pipeline(
-                     prompt_embeds=prompt_embeds,
-                     negative_prompt_embeds=sample_neg_prompt_embeds,
-                     num_inference_steps=self.config.sample_num_steps,
-                     guidance_scale=self.config.sample_guidance_scale,
-                     eta=self.config.sample_eta,
-                     output_type="pt",
-                 )
-
-                 images = sd_output.images
-                 latents = sd_output.latents
-                 log_probs = sd_output.log_probs
-
-             latents = torch.stack(latents, dim=1) # (batch_size, num_steps + 1, ...)
-             log_probs = torch.stack(log_probs, dim=1) # (batch_size, num_steps, 1)
-             timesteps = self.sd_pipeline.scheduler.timesteps.repeat(batch_size, 1) # (batch_size, num_steps)
-
-             samples.append(
-                 {
-                     "prompt_ids": prompt_ids,
-                     "prompt_embeds": prompt_embeds,
-                     "timesteps": timesteps,
-                     "latents": latents[:, :-1], # each entry is the latent before timestep t
-                     "next_latents": latents[:, 1:], # each entry is the latent after timestep t
-                     "log_probs": log_probs,
-                     "negative_prompt_embeds": sample_neg_prompt_embeds,
-                 }
-             )
-             prompt_image_pairs.append([images, prompts, prompt_metadata])
-
-         return samples, prompt_image_pairs
-
-     def _train_batched_samples(self, inner_epoch, epoch, global_step, batched_samples):
-         """
-         Train on a batch of samples. Main training segment
-
-         Args:
-             inner_epoch (int): The current inner epoch
-             epoch (int): The current epoch
-             global_step (int): The current global step
-             batched_samples (list[dict[str, torch.Tensor]]): The batched samples to train on
-
-         Side Effects:
-             - Model weights are updated
-             - Logs the statistics to the accelerator trackers.
-
-         Returns:
-             global_step (int): The updated global step
-         """
-         info = defaultdict(list)
-         for _i, sample in enumerate(batched_samples):
-             if self.config.train_cfg:
-                 # concat negative prompts to sample prompts to avoid two forward passes
-                 embeds = torch.cat([sample["negative_prompt_embeds"], sample["prompt_embeds"]])
-             else:
-                 embeds = sample["prompt_embeds"]
-
-             for j in range(self.num_train_timesteps):
-                 with self.accelerator.accumulate(self.sd_pipeline.unet):
-                     loss, approx_kl, clipfrac = self.calculate_loss(
-                         sample["latents"][:, j],
-                         sample["timesteps"][:, j],
-                         sample["next_latents"][:, j],
-                         sample["log_probs"][:, j],
-                         sample["advantages"],
-                         embeds,
-                     )
-                     info["approx_kl"].append(approx_kl)
-                     info["clipfrac"].append(clipfrac)
-                     info["loss"].append(loss)
-
-                     self.accelerator.backward(loss)
-                     if self.accelerator.sync_gradients:
-                         self.accelerator.clip_grad_norm_(
-                             self.trainable_layers.parameters()
-                             if not isinstance(self.trainable_layers, list)
-                             else self.trainable_layers,
-                             self.config.train_max_grad_norm,
-                         )
-                     self.optimizer.step()
-                     self.optimizer.zero_grad()
-
-                 # Checks if the accelerator has performed an optimization step behind the scenes
-                 if self.accelerator.sync_gradients:
-                     # log training-related stuff
-                     info = {k: torch.mean(torch.stack(v)) for k, v in info.items()}
-                     info = self.accelerator.reduce(info, reduction="mean")
-                     info.update({"epoch": epoch, "inner_epoch": inner_epoch})
-                     self.accelerator.log(info, step=global_step)
-                     global_step += 1
-                     info = defaultdict(list)
-         return global_step
-
-     def _config_check(self) -> tuple[bool, str]:
-         samples_per_epoch = (
-             self.config.sample_batch_size * self.accelerator.num_processes * self.config.sample_num_batches_per_epoch
-         )
-         total_train_batch_size = (
-             self.config.train_batch_size
-             * self.accelerator.num_processes
-             * self.config.train_gradient_accumulation_steps
-         )
-
-         if not self.config.sample_batch_size >= self.config.train_batch_size:
-             return (
-                 False,
-                 f"Sample batch size ({self.config.sample_batch_size}) must be greater than or equal to the train batch size ({self.config.train_batch_size})",
-             )
-         if not self.config.sample_batch_size % self.config.train_batch_size == 0:
-             return (
-                 False,
-                 f"Sample batch size ({self.config.sample_batch_size}) must be divisible by the train batch size ({self.config.train_batch_size})",
-             )
-         if not samples_per_epoch % total_train_batch_size == 0:
-             return (
-                 False,
-                 f"Number of samples per epoch ({samples_per_epoch}) must be divisible by the total train batch size ({total_train_batch_size})",
-             )
-         return True, ""
-
-     def train(self, epochs: Optional[int] = None):
-         """
-         Train the model for a given number of epochs
-         """
-         global_step = 0
-         if epochs is None:
-             epochs = self.config.num_epochs
-         for epoch in range(self.first_epoch, epochs):
-             global_step = self.step(epoch, global_step)
-
-     def _save_pretrained(self, save_directory):
-         self.sd_pipeline.save_pretrained(save_directory)
-         self.create_model_card()
-
-     def create_model_card(
-         self,
-         model_name: Optional[str] = None,
-         dataset_name: Optional[str] = None,
-         tags: Union[str, list[str], None] = None,
-     ):
-         """
-         Creates a draft of a model card using the information available to the `Trainer`.
-
-         Args:
-             model_name (`str` or `None`, *optional*, defaults to `None`):
-                 Name of the model.
-             dataset_name (`str` or `None`, *optional*, defaults to `None`):
-                 Name of the dataset used for training.
-             tags (`str`, `list[str]` or `None`, *optional*, defaults to `None`):
-                 Tags to be associated with the model card.
-         """
-         if not self.is_world_process_zero():
-             return
-
-         if hasattr(self.model.config, "_name_or_path") and not os.path.isdir(self.model.config._name_or_path):
-             base_model = self.model.config._name_or_path
-         else:
-             base_model = None
-
-         tags = tags or []
-         if isinstance(tags, str):
-             tags = [tags]
-
-         if hasattr(self.model.config, "unsloth_version"):
-             tags.append("unsloth")
-
-         citation = textwrap.dedent("""\
-         @inproceedings{black2024training,
-             title = {{Training Diffusion Models with Reinforcement Learning}},
-             author = {Kevin Black and Michael Janner and Yilun Du and Ilya Kostrikov and Sergey Levine},
-             year = 2024,
-             booktitle = {The Twelfth International Conference on Learning Representations, {ICLR} 2024, Vienna, Austria, May 7-11, 2024},
-             publisher = {OpenReview.net},
-             url = {https://openreview.net/forum?id=YCWjhGrJFD},
-         }""")
-
-         model_card = generate_model_card(
-             base_model=base_model,
-             model_name=model_name,
-             hub_model_id=self.hub_model_id,
-             dataset_name=dataset_name,
-             tags=tags,
-             wandb_url=wandb.run.get_url() if is_wandb_available() and wandb.run is not None else None,
-             comet_url=get_comet_experiment_url(),
-             trainer_name="DDPO",
-             trainer_citation=citation,
-             paper_title="Training Diffusion Models with Reinforcement Learning",
-             paper_id="2305.13301",
-         )
-
-         model_card.save(os.path.join(self.args.output_dir, "README.md"))
- class UnslothDDPOTrainer(_UnslothDDPOTrainer):
-     """
-
-     The DDPOTrainer uses Deep Diffusion Policy Optimization to optimise diffusion models.
-     Note, this trainer is heavily inspired by the work here: https://github.com/kvablack/ddpo-pytorch
-     As of now only Stable Diffusion based pipelines are supported
-
-     Attributes:
-         **config** (`DDPOConfig`) -- Configuration object for DDPOTrainer. Check the documentation of `PPOConfig` for more
-         details.
-         **reward_function** (Callable[[torch.Tensor, tuple[str], tuple[Any]], torch.Tensor]) -- Reward function to be used
-         **prompt_function** (Callable[[], tuple[str, Any]]) -- Function to generate prompts to guide model
-         **sd_pipeline** (`DDPOStableDiffusionPipeline`) -- Stable Diffusion pipeline to be used for training.
-         **image_samples_hook** (Optional[Callable[[Any, Any, Any], Any]]) -- Hook to be called to log images
-
-     """
-     def __init__(
-         self,
-         config,
-         reward_function,
-         prompt_function,
-         sd_pipeline,
-         image_samples_hook = None,
-         **kwargs
-     ):
-         if config is None: config = UnslothDDPOConfig()
-         other_metrics = []
-
-         from unsloth_zoo.logging_utils import PatchRLStatistics
-         PatchRLStatistics('ddpo_trainer', other_metrics)
-
-         super().__init__(
-             config = config,
-             reward_function = reward_function,
-             prompt_function = prompt_function,
-             sd_pipeline = sd_pipeline,
-             image_samples_hook = image_samples_hook,**kwargs)
-
- pass
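
Aside (not part of the deleted file): the `loss` method above is the standard PPO clipped surrogate objective, taking the elementwise maximum (the pessimistic choice) of the unclipped and clipped losses. A self-contained sketch with made-up values:

import torch

def ppo_clip_loss(advantages: torch.Tensor, ratio: torch.Tensor, clip_range: float) -> torch.Tensor:
    unclipped = -advantages * ratio
    clipped = -advantages * torch.clamp(ratio, 1.0 - clip_range, 1.0 + clip_range)
    # maximum of the two losses = minimum of the two surrogate objectives
    return torch.mean(torch.maximum(unclipped, clipped))

advantages = torch.tensor([1.0, -0.5])        # already clipped to train_adv_clip_max upstream
ratio = torch.exp(torch.tensor([0.2, -0.1]))  # exp(new_log_prob - old_log_prob)
print(ppo_clip_loss(advantages, ratio, clip_range=1e-4))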
test_run_uploads/UnslothDPOTrainer.py DELETED
The diff for this file is too large to render. See raw diff
 
test_run_uploads/UnslothGKDTrainer.py DELETED
@@ -1,885 +0,0 @@
1
- """
2
- 2025.7.11
3
- 2025.7.11
4
- 4.54.1
5
- 0.16.1
6
- __UNSLOTH_VERSIONING__
7
- """
8
- from torch import Tensor
9
- import torch
10
- import torch.nn as nn
11
- from torch.nn import functional as F
12
- from typing import Any, List, Optional, Tuple, Union, Dict, Set, Callable
13
- from trl.trainer.gkd_trainer import (Any, AutoModelForCausalLM, BaseImageProcessor, Callable, DataCollator, DataCollatorForChatML, Dataset, EvalPrediction, F, FeatureExtractionMixin, GKDConfig, GKDTrainer, GenerationConfig, Optional, PeftConfig, PreTrainedModel, PreTrainedModelWrapper, PreTrainedTokenizerBase, ProcessorMixin, SFTTrainer, TrainerCallback, Union, deepcopy, disable_dropout_in_model, empty_cache, generate_model_card, get_comet_experiment_url, is_wandb_available, nn, os, random, textwrap, torch, unwrap_model_for_generation, wandb)
14
-
15
-
16
- import os
17
- from typing import *
18
- from dataclasses import dataclass, field
19
- from packaging.version import Version
20
- import torch
21
- import numpy as np
22
- from contextlib import nullcontext
23
- from torch.nn import functional as F
24
- from transformers import DataCollatorForSeq2Seq, DataCollatorForLanguageModeling as TransformersDataCollatorForLanguageModeling
25
-
26
- torch_compile_options = {
27
- "epilogue_fusion" : True,
28
- "max_autotune" : False,
29
- "shape_padding" : True,
30
- "trace.enabled" : False,
31
- "triton.cudagraphs" : False,
32
- }
33
-
34
- @torch.compile(dynamic = True, fullgraph = True, options = torch_compile_options,)
35
- def chunked_selective_log_softmax(logits, index):
36
- # Split into 4 chunks only
37
- chunked_logits = torch.chunk(logits.reshape(-1, logits.shape[-1]), chunks = 4, dim = 0)
38
- chunked_index = torch.chunk(index.reshape(-1), chunks = 4, dim = 0)
39
- all_per_token_logps = []
40
- # Below loop does the same as selective_log_softmax(chunk_logits, chunk_index)
41
- for chunk_logits, chunk_index in zip(chunked_logits, chunked_index):
42
- chunk_logits = chunk_logits.to(torch.float32)
43
- selected_logits = torch.gather(chunk_logits, dim = -1, index = chunk_index.unsqueeze(-1)).squeeze(-1)
44
- logsumexp_values = torch.logsumexp(chunk_logits, dim = -1)
45
- per_token_logps = selected_logits - logsumexp_values
46
- all_per_token_logps.append(per_token_logps)
47
- pass
48
- all_per_token_logps = torch.concat(all_per_token_logps)
49
- all_per_token_logps = all_per_token_logps.reshape((logits.shape[0], logits.shape[1]))
50
- return all_per_token_logps
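
The chunking above exists purely to bound peak memory: upcasting the full (batch*seq, vocab) logits to float32 at once can spike VRAM, while four chunks cap the transient at roughly a quarter. A correctness sketch against the direct, unchunked computation (assuming a torch.compile-capable environment, since the helper is compiled):

# Unchunked reference the helper above should agree with (a check sketch,
# not part of the deleted file).
import torch
import torch.nn.functional as F

def reference_selective_log_softmax(logits, index):
    # Full log_softmax, then gather the log-probability of each index token.
    log_probs = F.log_softmax(logits.float(), dim = -1)
    return torch.gather(log_probs, dim = -1, index = index.unsqueeze(-1)).squeeze(-1)

logits = torch.randn(2, 8, 32)            # (batch, seq, vocab)
index  = torch.randint(0, 32, (2, 8))     # token ids to score
expected = reference_selective_log_softmax(logits, index)
actual   = chunked_selective_log_softmax(logits, index)
assert torch.allclose(expected, actual, atol = 1e-5)
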
51
- @dataclass
52
- class UnslothGKDConfig(GKDConfig):
53
- """
54
-
55
- Configuration class for [`GKDTrainer`].
56
-
57
- Args:
58
- temperature (`float`, *optional*, defaults to `0.9`):
59
- Temperature for sampling. The higher the temperature, the more random the completions.
60
- lmbda (`float`, *optional*, defaults to `0.5`):
61
- Lambda parameter that controls the student data fraction (i.e., the proportion of on-policy
62
- student-generated outputs).
63
- beta (`float`, *optional*, defaults to `0.5`):
64
- Interpolation coefficient between `0.0` and `1.0` of the Generalized Jensen-Shannon Divergence loss. When
65
- beta is `0.0`, the loss is the KL divergence. When beta is `1.0`, the loss is the Inverse KL Divergence.
66
- max_new_tokens (`int`, *optional*, defaults to `128`):
67
- Maximum number of tokens to generate per completion.
68
- teacher_model_name_or_path (`str` or `None`, *optional*, defaults to `None`):
69
- Model name or path of the teacher model. If `None`, the teacher model will be the same as the model
70
- being trained.
71
- teacher_model_init_kwargs (`dict[str, Any]]` or `None`, *optional*, defaults to `None`):
72
- Keyword arguments to pass to `AutoModelForCausalLM.from_pretrained` when instantiating the teacher model
73
- from a string.
74
- disable_dropout (`bool`, *optional*, defaults to `True`):
75
- Whether to disable dropout in the model.
76
- seq_kd (`bool`, *optional*, defaults to `False`):
77
- Whether to perform Sequence-Level KD (which can be viewed as supervised fine-tuning
78
- on teacher-generated outputs).
79
-
80
- """
81
- vllm_sampling_params: Optional[Any] = field(
82
- default = None,
83
- metadata = {'help': 'vLLM SamplingParams'},
84
- )
85
- unsloth_num_chunks : Optional[int] = field(
86
- default = -1,
87
- metadata = {'help': 'Chunk size to reduce memory usage. -1 is most efficient.'},
88
- )
89
- def __init__(
90
- self,
91
- output_dir = None,
92
- overwrite_output_dir = None,
93
- do_train = False,
94
- do_eval = False,
95
- do_predict = False,
96
- eval_strategy = 'no',
97
- prediction_loss_only = False,
98
- per_device_train_batch_size = 4,
99
- per_device_eval_batch_size = 4,
100
- per_gpu_train_batch_size = None,
101
- per_gpu_eval_batch_size = None,
102
- gradient_accumulation_steps = 2,
103
- eval_accumulation_steps = 2,
104
- eval_delay = 0,
105
- torch_empty_cache_steps = 250,
106
- learning_rate = 5e-05,
107
- weight_decay = 0.01,
108
- adam_beta1 = 0.9,
109
- adam_beta2 = 0.999,
110
- adam_epsilon = 1e-08,
111
- max_grad_norm = 1.0,
112
- num_train_epochs = 3.0,
113
- max_steps = -1,
114
- lr_scheduler_type = 'linear',
115
- warmup_ratio = 0.1,
116
- warmup_steps = 0,
117
- log_level = 'passive',
118
- log_level_replica = 'warning',
119
- log_on_each_node = True,
120
- logging_dir = None,
121
- logging_strategy = 'steps',
122
- logging_first_step = False,
123
- logging_steps = 1,
124
- logging_nan_inf_filter = False,
125
- save_strategy = 'steps',
126
- save_steps = 500,
127
- save_total_limit = None,
128
- save_safetensors = True,
129
- save_on_each_node = False,
130
- save_only_model = False,
131
- restore_callback_states_from_checkpoint = False,
132
- no_cuda = False,
133
- use_cpu = False,
134
- use_mps_device = False,
135
- seed = 3407,
136
- data_seed = 3407,
137
- jit_mode_eval = False,
138
- use_ipex = False,
139
- bf16 = False,
140
- fp16 = False,
141
- fp16_opt_level = 'O1',
142
- half_precision_backend = 'auto',
143
- bf16_full_eval = False,
144
- fp16_full_eval = False,
145
- tf32 = None,
146
- local_rank = -1,
147
- ddp_backend = None,
148
- tpu_num_cores = None,
149
- tpu_metrics_debug = False,
150
- debug = '',
151
- dataloader_drop_last = False,
152
- eval_steps = None,
153
- dataloader_num_workers = 0,
154
- dataloader_prefetch_factor = None,
155
- past_index = -1,
156
- run_name = None,
157
- disable_tqdm = None,
158
- remove_unused_columns = True,
159
- label_names = None,
160
- load_best_model_at_end = False,
161
- metric_for_best_model = None,
162
- greater_is_better = None,
163
- ignore_data_skip = False,
164
- fsdp = '',
165
- fsdp_min_num_params = 0,
166
- fsdp_config = None,
167
- fsdp_transformer_layer_cls_to_wrap = None,
168
- accelerator_config = None,
169
- deepspeed = None,
170
- label_smoothing_factor = 0.0,
171
- optim = 'adamw_8bit',
172
- optim_args = None,
173
- adafactor = False,
174
- group_by_length = False,
175
- length_column_name = 'length',
176
- report_to = None,
177
- ddp_find_unused_parameters = None,
178
- ddp_bucket_cap_mb = None,
179
- ddp_broadcast_buffers = None,
180
- dataloader_pin_memory = True,
181
- dataloader_persistent_workers = False,
182
- skip_memory_metrics = True,
183
- use_legacy_prediction_loop = False,
184
- push_to_hub = False,
185
- resume_from_checkpoint = None,
186
- hub_model_id = None,
187
- hub_strategy = 'every_save',
188
- hub_token = None,
189
- hub_private_repo = None,
190
- hub_always_push = False,
191
- hub_revision = None,
192
- gradient_checkpointing = False,
193
- gradient_checkpointing_kwargs = None,
194
- include_inputs_for_metrics = False,
195
- eval_do_concat_batches = True,
196
- fp16_backend = 'auto',
197
- push_to_hub_model_id = None,
198
- push_to_hub_organization = None,
199
- push_to_hub_token = None,
200
- mp_parameters = '',
201
- auto_find_batch_size = True,
202
- full_determinism = False,
203
- torchdynamo = None,
204
- ray_scope = 'last',
205
- ddp_timeout = 1800,
206
- torch_compile = False,
207
- torch_compile_backend = None,
208
- torch_compile_mode = None,
209
- include_tokens_per_second = False,
210
- include_num_input_tokens_seen = False,
211
- neftune_noise_alpha = None,
212
- optim_target_modules = None,
213
- batch_eval_metrics = False,
214
- eval_on_start = False,
215
- use_liger_kernel = False,
216
- liger_kernel_config = None,
217
- eval_use_gather_object = False,
218
- average_tokens_across_devices = True,
219
- model_init_kwargs = None,
220
- dataset_text_field = 'text',
221
- dataset_kwargs = None,
222
- dataset_num_proc = None,
223
- pad_token = None,
224
- max_length = 1024,
225
- packing = False,
226
- padding_free = False,
227
- eval_packing = None,
228
- dataset_batch_size = None,
229
- num_of_sequences = None,
230
- chars_per_token = None,
231
- max_seq_length = None,
232
- use_liger = None,
233
- temperature = 0.9,
234
- lmbda = 0.5,
235
- beta = 0.5,
236
- max_new_tokens = 128,
237
- teacher_model_name_or_path = None,
238
- teacher_model_init_kwargs = None,
239
- disable_dropout = True,
240
- seq_kd = False,
241
- vllm_sampling_params = None,
242
- unsloth_num_chunks = -1,
243
- **kwargs,
244
- ):
245
- if learning_rate < 1e-7: raise FloatingPointError(f'Unsloth: Your learning rate of `{learning_rate}` is too small (less than 1e-7)! Consider increasing it, otherwise gradient updates will be close to 0!')
246
- if learning_rate > 1: raise OverflowError(f'Unsloth: Your learning rate of `{learning_rate}` is way too large (> 1)! Consider decreasing it to 1e-1, otherwise gradient updates will explode!')
247
- if output_dir is None and save_strategy == 'steps' and save_steps == 500:
248
- output_dir = 'unsloth_training_checkpoints'
249
- save_strategy = 'no'
250
- if dataset_num_proc is None:
251
- from multiprocessing import cpu_count
252
- dataset_num_proc = min(cpu_count()*2, 2)
253
- if temperature <= 0:
254
- raise ValueError('Unsloth: Please set a positive non-zero temperature, otherwise your results will be wrong.')
255
- elif temperature >= 10:
256
- raise ValueError('Unsloth: Please set a temperature below 10, since sampling becomes quite erratic above that.')
257
-
258
-
259
- super().__init__(
260
- output_dir = output_dir,
261
- overwrite_output_dir = overwrite_output_dir,
262
- do_train = do_train,
263
- do_eval = do_eval,
264
- do_predict = do_predict,
265
- eval_strategy = eval_strategy,
266
- prediction_loss_only = prediction_loss_only,
267
- per_device_train_batch_size = per_device_train_batch_size,
268
- per_device_eval_batch_size = per_device_eval_batch_size,
269
- per_gpu_train_batch_size = per_gpu_train_batch_size,
270
- per_gpu_eval_batch_size = per_gpu_eval_batch_size,
271
- gradient_accumulation_steps = gradient_accumulation_steps,
272
- eval_accumulation_steps = eval_accumulation_steps,
273
- eval_delay = eval_delay,
274
- torch_empty_cache_steps = torch_empty_cache_steps,
275
- learning_rate = learning_rate,
276
- weight_decay = weight_decay,
277
- adam_beta1 = adam_beta1,
278
- adam_beta2 = adam_beta2,
279
- adam_epsilon = adam_epsilon,
280
- max_grad_norm = max_grad_norm,
281
- num_train_epochs = num_train_epochs,
282
- max_steps = max_steps,
283
- lr_scheduler_type = lr_scheduler_type,
284
- warmup_ratio = warmup_ratio,
285
- warmup_steps = warmup_steps,
286
- log_level = log_level,
287
- log_level_replica = log_level_replica,
288
- log_on_each_node = log_on_each_node,
289
- logging_dir = logging_dir,
290
- logging_strategy = logging_strategy,
291
- logging_first_step = logging_first_step,
292
- logging_steps = logging_steps,
293
- logging_nan_inf_filter = logging_nan_inf_filter,
294
- save_strategy = save_strategy,
295
- save_steps = save_steps,
296
- save_total_limit = save_total_limit,
297
- save_safetensors = save_safetensors,
298
- save_on_each_node = save_on_each_node,
299
- save_only_model = save_only_model,
300
- restore_callback_states_from_checkpoint = restore_callback_states_from_checkpoint,
301
- no_cuda = no_cuda,
302
- use_cpu = use_cpu,
303
- use_mps_device = use_mps_device,
304
- seed = seed,
305
- data_seed = data_seed,
306
- jit_mode_eval = jit_mode_eval,
307
- use_ipex = use_ipex,
308
- bf16 = bf16,
309
- fp16 = fp16,
310
- fp16_opt_level = fp16_opt_level,
311
- half_precision_backend = half_precision_backend,
312
- bf16_full_eval = bf16_full_eval,
313
- fp16_full_eval = fp16_full_eval,
314
- tf32 = tf32,
315
- local_rank = local_rank,
316
- ddp_backend = ddp_backend,
317
- tpu_num_cores = tpu_num_cores,
318
- tpu_metrics_debug = tpu_metrics_debug,
319
- debug = debug,
320
- dataloader_drop_last = dataloader_drop_last,
321
- eval_steps = eval_steps,
322
- dataloader_num_workers = dataloader_num_workers,
323
- dataloader_prefetch_factor = dataloader_prefetch_factor,
324
- past_index = past_index,
325
- run_name = run_name,
326
- disable_tqdm = disable_tqdm,
327
- remove_unused_columns = remove_unused_columns,
328
- label_names = label_names,
329
- load_best_model_at_end = load_best_model_at_end,
330
- metric_for_best_model = metric_for_best_model,
331
- greater_is_better = greater_is_better,
332
- ignore_data_skip = ignore_data_skip,
333
- fsdp = fsdp,
334
- fsdp_min_num_params = fsdp_min_num_params,
335
- fsdp_config = fsdp_config,
336
- fsdp_transformer_layer_cls_to_wrap = fsdp_transformer_layer_cls_to_wrap,
337
- accelerator_config = accelerator_config,
338
- deepspeed = deepspeed,
339
- label_smoothing_factor = label_smoothing_factor,
340
- optim = optim,
341
- optim_args = optim_args,
342
- adafactor = adafactor,
343
- group_by_length = group_by_length,
344
- length_column_name = length_column_name,
345
- report_to = report_to,
346
- ddp_find_unused_parameters = ddp_find_unused_parameters,
347
- ddp_bucket_cap_mb = ddp_bucket_cap_mb,
348
- ddp_broadcast_buffers = ddp_broadcast_buffers,
349
- dataloader_pin_memory = dataloader_pin_memory,
350
- dataloader_persistent_workers = dataloader_persistent_workers,
351
- skip_memory_metrics = skip_memory_metrics,
352
- use_legacy_prediction_loop = use_legacy_prediction_loop,
353
- push_to_hub = push_to_hub,
354
- resume_from_checkpoint = resume_from_checkpoint,
355
- hub_model_id = hub_model_id,
356
- hub_strategy = hub_strategy,
357
- hub_token = hub_token,
358
- hub_private_repo = hub_private_repo,
359
- hub_always_push = hub_always_push,
360
- hub_revision = hub_revision,
361
- gradient_checkpointing = gradient_checkpointing,
362
- gradient_checkpointing_kwargs = gradient_checkpointing_kwargs,
363
- include_inputs_for_metrics = include_inputs_for_metrics,
364
- eval_do_concat_batches = eval_do_concat_batches,
365
- fp16_backend = fp16_backend,
366
- push_to_hub_model_id = push_to_hub_model_id,
367
- push_to_hub_organization = push_to_hub_organization,
368
- push_to_hub_token = push_to_hub_token,
369
- mp_parameters = mp_parameters,
370
- auto_find_batch_size = auto_find_batch_size,
371
- full_determinism = full_determinism,
372
- torchdynamo = torchdynamo,
373
- ray_scope = ray_scope,
374
- ddp_timeout = ddp_timeout,
375
- torch_compile = torch_compile,
376
- torch_compile_backend = torch_compile_backend,
377
- torch_compile_mode = torch_compile_mode,
378
- include_tokens_per_second = include_tokens_per_second,
379
- include_num_input_tokens_seen = include_num_input_tokens_seen,
380
- neftune_noise_alpha = neftune_noise_alpha,
381
- optim_target_modules = optim_target_modules,
382
- batch_eval_metrics = batch_eval_metrics,
383
- eval_on_start = eval_on_start,
384
- use_liger_kernel = use_liger_kernel,
385
- liger_kernel_config = liger_kernel_config,
386
- eval_use_gather_object = eval_use_gather_object,
387
- average_tokens_across_devices = average_tokens_across_devices,
388
- model_init_kwargs = model_init_kwargs,
389
- dataset_text_field = dataset_text_field,
390
- dataset_kwargs = dataset_kwargs,
391
- dataset_num_proc = dataset_num_proc,
392
- pad_token = pad_token,
393
- max_length = max_length,
394
- packing = packing,
395
- padding_free = padding_free,
396
- eval_packing = eval_packing,
397
- dataset_batch_size = dataset_batch_size,
398
- num_of_sequences = num_of_sequences,
399
- chars_per_token = chars_per_token,
400
- max_seq_length = max_seq_length,
401
- use_liger = use_liger,
402
- temperature = temperature,
403
- lmbda = lmbda,
404
- beta = beta,
405
- max_new_tokens = max_new_tokens,
406
- teacher_model_name_or_path = teacher_model_name_or_path,
407
- teacher_model_init_kwargs = teacher_model_init_kwargs,
408
- disable_dropout = disable_dropout,
409
- seq_kd = seq_kd,**kwargs)
410
- self.vllm_sampling_params = vllm_sampling_params
411
- self.unsloth_num_chunks = unsloth_num_chunks
412
- pass
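
A minimal construction sketch for the config above; the output directory and teacher path are placeholders.

# Construction sketch: only the GKD-specific knobs are set; everything else
# keeps the defaults wired in above.
config = UnslothGKDConfig(
    output_dir = "gkd_checkpoints",
    temperature = 0.9,    # must satisfy 0 < temperature < 10 (checked above)
    lmbda = 0.5,          # expected fraction of on-policy (student-generated) steps
    beta = 0.5,           # JSD interpolation: 0 -> forward KL, 1 -> reverse KL
    max_new_tokens = 128,
    teacher_model_name_or_path = "teacher-model-id",  # placeholder
)
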
413
-
414
- class _UnslothGKDTrainer(SFTTrainer):
415
- _tag_names = ["trl", "gkd"]
416
-
417
- def __init__(
418
- self,
419
- model: Optional[Union[PreTrainedModel, nn.Module, str]] = None,
420
- teacher_model: Union[PreTrainedModel, nn.Module, str] = None,
421
- args: Optional[GKDConfig] = None,
422
- data_collator: Optional[DataCollator] = None, # type: ignore
423
- train_dataset: Optional[Dataset] = None,
424
- eval_dataset: Optional[Union[Dataset, dict[str, Dataset]]] = None,
425
- processing_class: Optional[
426
- Union[PreTrainedTokenizerBase, BaseImageProcessor, FeatureExtractionMixin, ProcessorMixin]
427
- ] = None,
428
- compute_metrics: Optional[Callable[[EvalPrediction], dict]] = None,
429
- callbacks: Optional[list[TrainerCallback]] = None,
430
- optimizers: tuple[torch.optim.Optimizer, torch.optim.lr_scheduler.LambdaLR] = (None, None),
431
- preprocess_logits_for_metrics: Optional[Callable[[torch.Tensor, torch.Tensor], torch.Tensor]] = None,
432
- peft_config: Optional["PeftConfig"] = None,
433
- formatting_func: Optional[Callable] = None,
434
- ):
435
- # add remove_unused_columns=False to the dataclass args
436
- args.remove_unused_columns = False
437
- data_collator = DataCollatorForChatML(tokenizer=processing_class, max_length=args.max_length)
438
-
439
- super().__init__(
440
- model,
441
- args=args,
442
- data_collator=data_collator,
443
- train_dataset=train_dataset,
444
- eval_dataset=eval_dataset,
445
- processing_class=processing_class,
446
- compute_metrics=compute_metrics,
447
- callbacks=callbacks,
448
- optimizers=optimizers,
449
- preprocess_logits_for_metrics=preprocess_logits_for_metrics,
450
- peft_config=peft_config,
451
- formatting_func=formatting_func,
452
- )
453
-
454
- if args.teacher_model_init_kwargs is None:
455
- teacher_model_init_kwargs = {}
456
- elif not isinstance(teacher_model, str):
457
- raise ValueError(
458
- "You passed teacher_model_init_kwargs to the GKDConfig, but your teacher_model is already instantiated."
459
- )
460
- else:
461
- teacher_model_init_kwargs = args.teacher_model_init_kwargs
462
- teacher_model_init_kwargs["torch_dtype"] = (
463
- teacher_model_init_kwargs["torch_dtype"]
464
- if teacher_model_init_kwargs["torch_dtype"] in ["auto", None]
465
- else getattr(torch, teacher_model_init_kwargs["torch_dtype"])
466
- )
467
-
468
- if isinstance(teacher_model, str):
469
- teacher_model = AutoModelForCausalLM.from_pretrained(teacher_model, **teacher_model_init_kwargs)
470
-
471
- # Disable dropout in the model
472
- if args.disable_dropout:
473
- disable_dropout_in_model(self.model)
474
-
475
- if self.is_deepspeed_enabled:
476
- self.teacher_model = self._prepare_deepspeed(teacher_model)
477
- else:
478
- self.teacher_model = self.accelerator.prepare_model(teacher_model, evaluation_mode=True)
479
-
480
- self.lmbda = args.lmbda
481
- self.beta = args.beta
482
- self.temperature = args.temperature
483
- self.seq_kd = args.seq_kd
484
-
485
- self.generation_config = GenerationConfig(
486
- max_new_tokens=args.max_new_tokens,
487
- temperature=args.temperature,
488
- do_sample=True,
489
- top_k=0,
490
- use_cache=False if args.gradient_checkpointing else True,
491
- pad_token_id=self.processing_class.pad_token_id,
492
- )
493
- # Set custom EOS tokens if they are specified by the model's generation
494
- # config. This is important for models with the Llama 3 chat template,
495
- # which use special tokens <|eot_id|> and <|eom_id|> to mark the end of
496
- # turns or messages.
497
- if (
498
- hasattr(self.model.generation_config, "eos_token_id")
499
- and self.model.generation_config.eos_token_id is not None
500
- ):
501
- self.generation_config.eos_token_id = self.model.generation_config.eos_token_id
502
-
503
- def _prepare_dataset(self, dataset, *args):
504
- # SFTTrainer._prepare_dataset() applies the chat template and renames the messages column to text. However, we
505
- # need to keep the messages column as it is. We use the following workaround to keep the messages column.
506
- dataset = dataset.add_column("_messages", dataset["messages"])
507
- dataset = super()._prepare_dataset(dataset, *args)
508
- dataset = dataset.rename_column("_messages", "messages")
509
- return dataset
510
-
511
- @staticmethod
512
- def generalized_jsd_loss(
513
- student_logits, teacher_logits, labels=None, beta=0.5, temperature=1.0, reduction="batchmean"
514
- ):
515
- """
516
- Compute the generalized Jensen-Shannon Divergence loss for knowledge distillation using F.kl_div. See Eq. (1)
517
- of https://huggingface.co/papers/2306.13649 for the definition.
518
-
519
- Args:
520
- student_logits: Tensor of shape (batch_size, sequence_length, vocab_size)
521
- teacher_logits: Tensor of shape (batch_size, sequence_length, vocab_size)
522
- labels: Tensor of shape (batch_size, sequence_length) with -100 for padding tokens to ignore when computing loss
523
- beta: Interpolation coefficient between 0 and 1 (default: 0.5)
524
- temperature: Softmax temperature (default: 1.0)
525
- reduction: Specifies the reduction to apply to the output (default: 'batchmean')
526
-
527
- Returns:
528
- loss: Scalar tensor with the generalized JSD loss
529
- """
530
-
531
- # Apply temperature scaling
532
- student_logits = student_logits / temperature
533
- teacher_logits = teacher_logits / temperature
534
-
535
- # Compute log probabilities for student and probabilities for teacher
536
- student_log_probs = F.log_softmax(student_logits, dim=-1)
537
- teacher_log_probs = F.log_softmax(teacher_logits, dim=-1)
538
-
539
- if beta == 0:
540
- jsd = F.kl_div(student_log_probs, teacher_log_probs, reduction="none", log_target=True)
541
- elif beta == 1:
542
- jsd = F.kl_div(teacher_log_probs, student_log_probs, reduction="none", log_target=True)
543
- else:
544
- # Compute the log of the mixture distribution
545
- # log(a + b) = log(exp(log(a)) + exp(log(b))) -> for mixture
546
- beta = torch.tensor(beta, dtype=student_log_probs.dtype)
547
- mixture_log_probs = torch.logsumexp(
548
- torch.stack([student_log_probs + torch.log(1 - beta), teacher_log_probs + torch.log(beta)]),
549
- dim=0,
550
- )
551
-
552
- # Compute KL divergences using F.kl_div
553
- # PyTorch differs from the standard mathematical definition, so the order of the probability distributions is swapped compared to that defined in the paper.
554
- kl_teacher = F.kl_div(mixture_log_probs, teacher_log_probs, reduction="none", log_target=True)
555
- kl_student = F.kl_div(mixture_log_probs, student_log_probs, reduction="none", log_target=True)
556
-
557
- # Compute the Generalized Jensen-Shannon Divergence
558
- jsd = beta * kl_teacher + (1 - beta) * kl_student
559
-
560
- # Masking
561
- if labels is not None:
562
- mask = labels != -100
563
- jsd = jsd[mask]
564
-
565
- # Apply reduction
566
- if reduction == "batchmean":
567
- return jsd.sum() / mask.sum() if labels is not None else jsd.sum() / (jsd.size(0) * jsd.size(1))
568
- elif reduction == "sum":
569
- return jsd.sum()
570
- elif reduction == "mean":
571
- return jsd.mean()
572
- else:
573
- return jsd
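
Because generalized_jsd_loss is a @staticmethod, its beta endpoints can be sanity-checked in isolation. A small sketch with random logits (shapes follow the docstring; the tensors are illustrative):

# Sanity sketch: at beta=0 the loss is KL(teacher || student); at beta=1 it is
# KL(student || teacher); intermediate beta gives the generalized JSD.
import torch

student = torch.randn(2, 4, 16)                  # (batch, seq, vocab)
teacher = torch.randn(2, 4, 16)
labels  = torch.ones(2, 4, dtype = torch.long)   # no -100s, nothing masked

fwd_kl = _UnslothGKDTrainer.generalized_jsd_loss(student, teacher, labels, beta = 0)
rev_kl = _UnslothGKDTrainer.generalized_jsd_loss(student, teacher, labels, beta = 1)
jsd    = _UnslothGKDTrainer.generalized_jsd_loss(student, teacher, labels, beta = 0.5)
print(fwd_kl, rev_kl, jsd)
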
574
-
575
- def compute_loss(self, model, inputs, return_outputs=False, num_items_in_batch=None):
576
- # compute student output
577
- outputs_student = model(
578
- input_ids=inputs["input_ids"],
579
- attention_mask=inputs["attention_mask"],
580
- )
581
-
582
- # compute teacher output in eval mode
583
- self.teacher_model.eval()
584
- with torch.no_grad():
585
- outputs_teacher = self.teacher_model(
586
- input_ids=inputs["input_ids"],
587
- attention_mask=inputs["attention_mask"],
588
- )
589
-
590
- # slice the logits for the generated tokens using the inputs["prompts"] lengths
591
- prompt_lengths = inputs["prompts"].shape[1]
592
- shifted_student_logits = outputs_student.logits[:, prompt_lengths - 1 : -1, :]
593
- shifted_teacher_logits = outputs_teacher.logits[:, prompt_lengths - 1 : -1, :]
594
- shifted_labels = inputs["labels"][:, prompt_lengths:]
595
-
596
- # compute loss
597
- loss = self.generalized_jsd_loss(
598
- student_logits=shifted_student_logits,
599
- teacher_logits=shifted_teacher_logits,
600
- labels=shifted_labels,
601
- beta=self.beta,
602
- )
603
-
604
- # empty cache
605
- empty_cache()
606
-
607
- # Return loss
608
- return (loss, outputs_student) if return_outputs else loss
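
The `prompt_lengths - 1 : -1` slice above is the usual causal-LM off-by-one: the logit at position t predicts token t+1, so logits starting at P-1 line up with completion labels starting at P. A pure-index sketch of that alignment:

# Index-only sketch of the slicing above: with prompt length P and total
# length T, logits[P-1 : T-1] predict exactly the tokens at labels[P : T].
P, T = 3, 7
predicted_by_sliced_logits = [t + 1 for t in range(P - 1, T - 1)]
label_positions = list(range(P, T))
assert predicted_by_sliced_logits == label_positions   # [3, 4, 5, 6]
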
609
-
610
- @staticmethod
611
- def generate_on_policy_outputs(model, inputs, generation_config, pad_token_id=None):
612
- # Generate output with respect to the prompt only
613
- generated_outputs = model.generate(
614
- input_ids=inputs["prompts"],
615
- attention_mask=inputs.get("prompt_attention_mask", None),
616
- generation_config=generation_config,
617
- return_dict_in_generate=True,
618
- )
619
-
620
- # Get the generated token IDs
621
- generated_tokens = generated_outputs.sequences
622
- # Calculate new attention mask
623
- new_attention_mask = torch.ones_like(generated_tokens)
624
- new_labels = generated_tokens.clone()
625
-
626
- # If there's pad_token_id, set attention mask to 0 for padding tokens
627
- if pad_token_id is not None:
628
- new_labels[new_labels == pad_token_id] = -100
629
- new_attention_mask[generated_tokens == pad_token_id] = 0
630
-
631
- return generated_tokens, new_attention_mask, new_labels
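
A toy illustration of the pad handling above; pad_token_id = 0 is an assumption made only for this sketch:

# Pad positions become -100 labels (ignored by the loss) and get attention 0.
import torch

pad_token_id = 0                                    # assumed for illustration
generated = torch.tensor([[5, 9, 2, pad_token_id, pad_token_id]])
labels = generated.clone()
attention = torch.ones_like(generated)
labels[generated == pad_token_id] = -100
attention[generated == pad_token_id] = 0
# labels    -> [[5, 9, 2, -100, -100]]
# attention -> [[1, 1, 1,    0,    0]]
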
632
-
633
- def training_step(
634
- self, model: nn.Module, inputs: dict[str, Union[torch.Tensor, Any]], num_items_in_batch: Optional[int] = None
635
- ) -> torch.Tensor:
636
- """
637
- Perform a training step for the Generalized Knowledge Distillation (GKD) model.
638
-
639
- This method implements the on-policy learning approach described in the GKD paper.
640
- With probability `self.lmbda`, it generates new responses using the student model,
641
- which are then used for training instead of the original inputs.
642
- """
643
- if self.seq_kd:
644
- with unwrap_model_for_generation(self.teacher_model, self.accelerator) as unwrapped_model:
645
- new_input_ids, new_attention_mask, new_labels = self.generate_on_policy_outputs(
646
- unwrapped_model, inputs, self.generation_config, self.processing_class.pad_token_id
647
- )
648
- inputs["input_ids"] = new_input_ids
649
- inputs["attention_mask"] = new_attention_mask
650
- inputs["labels"] = new_labels
651
- if random.random() <= self.lmbda:
652
- with unwrap_model_for_generation(model, self.accelerator) as unwrapped_model:
653
- new_input_ids, new_attention_mask, new_labels = self.generate_on_policy_outputs(
654
- unwrapped_model, inputs, self.generation_config, self.processing_class.pad_token_id
655
- )
656
- inputs["input_ids"] = new_input_ids
657
- inputs["attention_mask"] = new_attention_mask
658
- inputs["labels"] = new_labels
659
-
660
- loss = super().training_step(model, inputs, num_items_in_batch)
661
- return loss
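
The `random.random() <= self.lmbda` gate above makes on-policy generation a per-step Bernoulli(lambda) draw, so in expectation a lambda fraction of steps train on student-generated completions. A quick empirical check:

# Empirical sketch of the Bernoulli(lmbda) gate used in training_step above.
import random

lmbda, steps = 0.5, 100_000
on_policy_steps = sum(random.random() <= lmbda for _ in range(steps))
print(on_policy_steps / steps)   # ~0.5
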
662
-
663
- def _prepare_deepspeed(self, model: PreTrainedModelWrapper):
664
- # Adapted from accelerate: https://github.com/huggingface/accelerate/blob/739b135f8367becb67ffaada12fe76e3aa60fefd/src/accelerate/accelerator.py#L1473
665
- deepspeed_plugin = self.accelerator.state.deepspeed_plugin
666
- config_kwargs = deepcopy(deepspeed_plugin.deepspeed_config)
667
-
668
- if model is not None:
669
- if hasattr(model, "config"):
670
- hidden_size = (
671
- max(model.config.hidden_sizes)
672
- if getattr(model.config, "hidden_sizes", None)
673
- else getattr(model.config, "hidden_size", None)
674
- )
675
- if hidden_size is not None and config_kwargs["zero_optimization"]["stage"] == 3:
676
- # Note that `stage3_prefetch_bucket_size` can produce DeepSpeed messages like: `Invalidate trace cache @ step 0: expected module 1, but got module 0`
677
- # This is expected and is not an error, see: https://github.com/microsoft/DeepSpeed/discussions/4081
678
- config_kwargs.update(
679
- {
680
- "zero_optimization.reduce_bucket_size": hidden_size * hidden_size,
681
- "zero_optimization.stage3_param_persistence_threshold": 10 * hidden_size,
682
- "zero_optimization.stage3_prefetch_bucket_size": 0.9 * hidden_size * hidden_size,
683
- }
684
- )
685
-
686
- # If ZeRO-3 is used, we shard both the active and reference model.
687
- # Otherwise, we assume the reference model fits in memory and is initialized on each device with ZeRO disabled (stage 0)
688
- if config_kwargs["zero_optimization"]["stage"] != 3:
689
- config_kwargs["zero_optimization"]["stage"] = 0
690
- model, *_ = deepspeed.initialize(model=model, config=config_kwargs)
691
- model.eval()
692
- return model
693
-
694
- def create_model_card(
695
- self,
696
- model_name: Optional[str] = None,
697
- dataset_name: Optional[str] = None,
698
- tags: Union[str, list[str], None] = None,
699
- ):
700
- """
701
- Creates a draft of a model card using the information available to the `Trainer`.
702
-
703
- Args:
704
- model_name (`str` or `None`, *optional*, defaults to `None`):
705
- Name of the model.
706
- dataset_name (`str` or `None`, *optional*, defaults to `None`):
707
- Name of the dataset used for training.
708
- tags (`str`, `list[str]` or `None`, *optional*, defaults to `None`):
709
- Tags to be associated with the model card.
710
- """
711
- if not self.is_world_process_zero():
712
- return
713
-
714
- if hasattr(self.model.config, "_name_or_path") and not os.path.isdir(self.model.config._name_or_path):
715
- base_model = self.model.config._name_or_path
716
- else:
717
- base_model = None
718
-
719
- tags = tags or []
720
- if isinstance(tags, str):
721
- tags = [tags]
722
-
723
- if hasattr(self.model.config, "unsloth_version"):
724
- tags.append("unsloth")
725
-
726
- citation = textwrap.dedent("""\
727
- @inproceedings{agarwal2024on-policy,
728
- title = {{On-Policy Distillation of Language Models: Learning from Self-Generated Mistakes}},
729
- author = {Rishabh Agarwal and Nino Vieillard and Yongchao Zhou and Piotr Stanczyk and Sabela Ramos Garea and Matthieu Geist and Olivier Bachem},
730
- year = 2024,
731
- booktitle = {The Twelfth International Conference on Learning Representations, {ICLR} 2024, Vienna, Austria, May 7-11, 2024},
732
- publisher = {OpenReview.net},
733
- url = {https://openreview.net/forum?id=3zKtaqxLhW},
734
- }""")
735
-
736
- model_card = generate_model_card(
737
- base_model=base_model,
738
- model_name=model_name,
739
- hub_model_id=self.hub_model_id,
740
- dataset_name=dataset_name,
741
- tags=tags,
742
- wandb_url=wandb.run.get_url() if is_wandb_available() and wandb.run is not None else None,
743
- comet_url=get_comet_experiment_url(),
744
- trainer_name="GKD",
745
- trainer_citation=citation,
746
- paper_title="On-Policy Distillation of Language Models: Learning from Self-Generated Mistakes",
747
- paper_id="2306.13649",
748
- )
749
-
750
- model_card.save(os.path.join(self.args.output_dir, "README.md"))
751
- class UnslothGKDTrainer(_UnslothGKDTrainer):
752
- """
753
- Unsloth-patched GKD trainer: reconciles mixed-precision flags with the model dtype, swaps in a suitable data collator, and registers RL statistics logging before deferring to the TRL implementation.
754
- """
755
- def __init__(
756
- self,
757
- model = None,
758
- teacher_model = None,
759
- args = None,
760
- data_collator = None,
761
- train_dataset = None,
762
- eval_dataset = None,
763
- processing_class = None,
764
- compute_metrics = None,
765
- callbacks = None,
766
- preprocess_logits_for_metrics = None,
767
- peft_config = None,
768
- formatting_func = None,
769
- **kwargs
770
- ):
771
- if args is None: args = UnslothGKDConfig()
772
- use_bf16 = getattr(args, 'bf16', False)
773
- if type(use_bf16) is not bool: use_bf16 = False
774
- use_fp16 = getattr(args, 'fp16', False)
775
- if type(use_fp16) is not bool: use_fp16 = False
776
- force_float32 = False
777
- if os.environ.get('UNSLOTH_FORCE_FLOAT32', '0') == '1':
778
- print('Unsloth: Switching to float32 training since model cannot work with float16')
779
- force_float32 = True
780
- mixed_precision_dtype = os.environ.get('UNSLOTH_MIXED_PRECISION', 'float32')
781
- dtype = getattr(model.config, 'torch_dtype', None)
782
- if dtype is None: dtype = model.get_input_embeddings().dtype
783
- from unsloth_zoo.utils import _get_dtype
784
- dtype = _get_dtype(dtype)
785
- float16 = dtype == torch.float16
786
- if not force_float32 and (float16 and use_bf16): raise TypeError('Unsloth: Model is in float16 precision but you want to use bfloat16 precision. Set fp16 to `True` and bf16 to `False`')
787
- if not force_float32 and (not float16 and use_fp16): raise TypeError('Unsloth: Model is in bfloat16 precision but you want to use float16 precision. Set fp16 to `False` and bf16 to `True`')
788
- if force_float32:
789
- args.fp16 = False
790
- args.bf16 = False
791
- os.environ['ACCELERATE_MIXED_PRECISION'] = 'no'
792
- elif (not use_bf16 and not use_fp16) and mixed_precision_dtype == 'float32':
793
- args.fp16 = float16
794
- args.bf16 = not float16
795
- os.environ['ACCELERATE_MIXED_PRECISION'] = 'fp16' if float16 else 'bf16'
796
- if getattr(args, 'eval_dataset', None) is not None and getattr(args, 'eval_strategy', 'no') == 'no':
797
- args.eval_strategy = 'steps'
798
- if getattr(args, 'eval_steps', None) is None: args.eval_steps = 0.1
799
- ga_steps = getattr(args, 'gradient_accumulation_steps', None)
800
- if ga_steps is not None and ga_steps > 1:
801
- from transformers import __version__ as transformers_version
802
- if Version(transformers_version) <= Version('4.45.2'):
803
- print('**** Unsloth: Please use our fixed gradient_accumulation_steps by updating transformers, TRL and Unsloth!\n'
804
- '`pip install --upgrade --no-cache-dir --force-reinstall --no-deps unsloth transformers trl unsloth_zoo`')
805
- if getattr(args, 'eval_strategy', 'no') != 'no':
806
- eval_bsz = getattr(args, 'per_device_eval_batch_size', 8)
807
- if eval_bsz == 8 and args.per_device_train_batch_size < eval_bsz: args.per_device_eval_batch_size = args.per_device_train_batch_size
808
- if getattr(args, 'eval_accumulation_steps', None) is None and ga_steps is not None: args.eval_accumulation_steps = ga_steps
809
- fp16_full_eval = getattr(args, 'fp16_full_eval', False)
810
- if type(fp16_full_eval) is not bool: fp16_full_eval = False
811
- bf16_full_eval = getattr(args, 'bf16_full_eval', False)
812
- if type(bf16_full_eval) is not bool: bf16_full_eval = False
813
- if args.fp16 and bf16_full_eval: args.bf16_full_eval = False; args.fp16_full_eval = True
814
- if args.bf16 and fp16_full_eval: args.bf16_full_eval = True; args.fp16_full_eval = False
815
- if force_float32:
816
- args.bf16_full_eval = False
817
- args.fp16_full_eval = False
818
- elif os.environ.get('UNSLOTH_MIXED_PRECISION', 'float32') == 'bfloat16':
819
- args.bf16_full_eval = True
820
- args.fp16_full_eval = False
821
- elif not bf16_full_eval and not fp16_full_eval:
822
- args.bf16_full_eval = args.bf16
823
- args.fp16_full_eval = args.fp16
824
- _output_logits = False
825
- if locals().get('compute_metrics', None) is not None: _output_logits = True
826
- if locals().get('preprocess_logits_for_metrics', None) is not None: _output_logits = True
827
- if _output_logits:
828
- os.environ['UNSLOTH_RETURN_LOGITS'] = '1'
829
- if 'max_seq_length' not in locals() and not hasattr(args, 'max_seq_length'):
830
- pass
831
- else:
832
- model_max_seq_length = getattr(model, 'max_seq_length', None)
833
- args_max_seq_length = getattr(args, 'max_seq_length', None)
834
- if args_max_seq_length is None and model_max_seq_length is not None:
835
- max_seq_length = model.max_seq_length
836
- if hasattr(args, 'max_seq_length'): args.max_seq_length = max_seq_length
837
- if model is not None and hasattr(model, 'for_training'):
838
- model.for_training()
839
- if 'tokenizer' in locals() and hasattr(tokenizer, 'padding_side'): tokenizer.padding_side = 'right'
840
- if 'processing_class' in locals():
841
- if hasattr(processing_class, 'padding_side'): processing_class.padding_side = 'right'
842
- if hasattr(processing_class, 'tokenizer') and hasattr(processing_class.tokenizer, 'padding_side'): processing_class.tokenizer.padding_side = 'right'
843
- __tokenizer = processing_class if 'processing_class' in locals() else tokenizer
844
- from unsloth_zoo.vision_utils import UnslothVisionDataCollator
845
- if not isinstance(data_collator, UnslothVisionDataCollator):
846
- if isinstance(data_collator, DataCollatorForSeq2Seq) and 'labels' not in train_dataset.column_names:
847
- data_collator = TransformersDataCollatorForLanguageModeling(__tokenizer, mlm = False, mlm_probability = 0.0)
848
- elif isinstance(data_collator, TransformersDataCollatorForLanguageModeling) and 'labels' in train_dataset.column_names:
849
- data_collator = DataCollatorForSeq2Seq(__tokenizer)
850
- else:
851
- if hasattr(args, 'remove_unused_columns'): args.remove_unused_columns = False
852
- if hasattr(args, 'dataset_text_field'): args.dataset_text_field = ''
853
- if hasattr(args, 'dataset_kwargs'): args.dataset_kwargs = {'skip_prepare_dataset': True}
854
- if not isinstance(data_collator, UnslothVisionDataCollator):
855
- if not hasattr(__tokenizer, 'pad') and hasattr(__tokenizer, 'tokenizer'):
856
- if isinstance(data_collator, DataCollatorForSeq2Seq):
857
- data_collator = DataCollatorForSeq2Seq(__tokenizer.tokenizer)
858
- else:
859
- data_collator = TransformersDataCollatorForLanguageModeling(__tokenizer.tokenizer, mlm = False, mlm_probability = 0.0)
860
- other_metrics = []
861
-
862
- from unsloth_zoo.logging_utils import PatchRLStatistics
863
- PatchRLStatistics('gkd_trainer', other_metrics)
864
-
865
- super().__init__(
866
- model = model,
867
- teacher_model = teacher_model,
868
- args = args,
869
- data_collator = data_collator,
870
- train_dataset = train_dataset,
871
- eval_dataset = eval_dataset,
872
- processing_class = processing_class,
873
- compute_metrics = compute_metrics,
874
- callbacks = callbacks,
875
- preprocess_logits_for_metrics = preprocess_logits_for_metrics,
876
- peft_config = peft_config,
877
- formatting_func = formatting_func,**kwargs)
878
- if hasattr(self, 'neftune_hook_handle'):
879
- self.neftune_hook_handle.remove()
880
- if hasattr(self, 'neftune_hook_handle'): del self.neftune_hook_handle
881
- if getattr(args, 'neftune_noise_alpha', None) is not None:
882
- model.get_input_embeddings().neftune_noise_alpha = self.neftune_noise_alpha
883
- pass
884
-
885
- pass
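
A minimal end-to-end sketch for the wrapper above; `model` and `tokenizer` are assumed to already be loaded (e.g. via Unsloth's FastLanguageModel), the teacher id is a placeholder, and the dataset carries a "messages" column because of the DataCollatorForChatML wiring in _UnslothGKDTrainer.__init__.

# End-to-end sketch; `model`/`tokenizer` and the teacher id are assumptions.
from datasets import Dataset

train_dataset = Dataset.from_list([
    {"messages": [{"role": "user", "content": "What is 2 + 2?"},
                  {"role": "assistant", "content": "4"}]},
])
trainer = UnslothGKDTrainer(
    model = model,                              # student, already loaded
    teacher_model = "teacher-model-id",         # placeholder hub id or path
    args = UnslothGKDConfig(output_dir = "gkd_out", max_steps = 10),
    train_dataset = train_dataset,
    processing_class = tokenizer,
)
trainer.train()
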
test_run_uploads/UnslothGRPOTrainer.py DELETED
The diff for this file is too large to render. See raw diff
 
test_run_uploads/UnslothKTOTrainer.py DELETED
@@ -1,1849 +0,0 @@
1
- """
2
- 2025.7.11
3
- 2025.7.11
4
- 4.54.1
5
- 0.16.1
6
- __UNSLOTH_VERSIONING__
7
- """
8
- from torch import Tensor
9
- import torch
10
- import torch.nn as nn
11
- from torch.nn import functional as F
12
- from typing import Any, List, Optional, Tuple, Union, Dict, Set, Callable
13
- from trl.trainer.kto_trainer import (Any, AutoModelForCausalLM, BaseImageProcessor, Callable, DPODataCollatorWithPadding, DataCollator, DataLoader, Dataset, EvalLoopOutput, F, FeatureExtractionMixin, KTOConfig, KTOTrainer, Literal, Optional, PartialState, PeftModel, PreTrainedModel, PreTrainedModelWrapper, PreTrainedTokenizerBase, ProcessorMixin, SequentialSampler, Trainer, TrainerCallback, TrainingArguments, Union, _get_kl_dataset, _process_tokens, _tokenize, amp, concatenate_datasets, contextmanager, create_reference_model, deepcopy, defaultdict, disable_dropout_in_model, generate_model_card, get_comet_experiment_url, has_length, inspect, is_comet_available, is_peft_available, is_wandb_available, itemgetter, log_table_to_comet_experiment, maybe_apply_chat_template, maybe_extract_prompt, maybe_unpair_preference_dataset, nn, np, nullcontext, os, pad_to_length, pd, peft_module_casting_to_bf16, prepare_model_for_kbit_training, random, selective_log_softmax, textwrap, torch, tqdm, transformers, version, wandb, warnings)
14
-
15
-
16
- import os
17
- from typing import *
18
- from dataclasses import dataclass, field
19
- from packaging.version import Version
20
- import torch
21
- import numpy as np
22
- from contextlib import nullcontext
23
- from torch.nn import functional as F
24
- from transformers import DataCollatorForSeq2Seq, DataCollatorForLanguageModeling as TransformersDataCollatorForLanguageModeling
25
-
26
- torch_compile_options = {
27
- "epilogue_fusion" : True,
28
- "max_autotune" : False,
29
- "shape_padding" : True,
30
- "trace.enabled" : False,
31
- "triton.cudagraphs" : False,
32
- }
33
-
34
- @torch.compile(dynamic = True, fullgraph = True, options = torch_compile_options,)
35
- def chunked_selective_log_softmax(logits, index):
36
- # Split into 4 chunks only
37
- chunked_logits = torch.chunk(logits.reshape(-1, logits.shape[-1]), chunks = 4, dim = 0)
38
- chunked_index = torch.chunk(index.reshape(-1), chunks = 4, dim = 0)
39
- all_per_token_logps = []
40
- # Below loop does the same as selective_log_softmax(chunk_logits, chunk_index)
41
- for chunk_logits, chunk_index in zip(chunked_logits, chunked_index):
42
- chunk_logits = chunk_logits.to(torch.float32)
43
- selected_logits = torch.gather(chunk_logits, dim = -1, index = chunk_index.unsqueeze(-1)).squeeze(-1)
44
- logsumexp_values = torch.logsumexp(chunk_logits, dim = -1)
45
- per_token_logps = selected_logits - logsumexp_values
46
- all_per_token_logps.append(per_token_logps)
47
- pass
48
- all_per_token_logps = torch.concat(all_per_token_logps)
49
- all_per_token_logps = all_per_token_logps.reshape((logits.shape[0], logits.shape[1]))
50
- return all_per_token_logps
51
- @dataclass
52
- class UnslothKTOConfig(KTOConfig):
53
- """
54
-
55
- Configuration class for the [`KTOTrainer`].
56
-
57
- Using [`~transformers.HfArgumentParser`] we can turn this class into
58
- [argparse](https://docs.python.org/3/library/argparse#module-argparse) arguments that can be specified on the
59
- command line.
60
-
61
- Parameters:
62
- learning_rate (`float`, *optional*, defaults to `1e-6`):
63
- Initial learning rate for [`AdamW`] optimizer. The default value replaces that of
64
- [`~transformers.TrainingArguments`].
65
- max_length (`int` or `None`, *optional*, defaults to `1024`):
66
- Maximum length of the sequences (prompt + completion) in the batch. This argument is required if you want
67
- to use the default data collator.
68
- max_prompt_length (`int` or `None`, *optional*, defaults to `512`):
69
- Maximum length of the prompt. This argument is required if you want to use the default data collator.
70
- max_completion_length (`int` or `None`, *optional*, defaults to `None`):
71
- Maximum length of the completion. This argument is required if you want to use the default data collator
72
- and your model is an encoder-decoder.
73
- beta (`float`, *optional*, defaults to `0.1`):
74
- Parameter controlling the deviation from the reference model. Higher β means less deviation from the
75
- reference model.
76
- loss_type (`str`, *optional*, defaults to `"kto"`):
77
- Type of loss to use. Possible values are:
78
-
79
- - `"kto"`: KTO loss from the [KTO](https://huggingface.co/papers/2402.01306) paper.
80
- - `"apo_zero_unpaired"`: Unpaired variant of APO-zero loss from the [APO](https://huggingface.co/papers/2408.06266) paper.
81
-
82
- desirable_weight (`float`, *optional*, defaults to `1.0`):
83
- Desirable losses are weighted by this factor to counter an unequal number of desirable and undesirable pairs.
84
- undesirable_weight (`float`, *optional*, defaults to `1.0`):
85
- Undesirable losses are weighted by this factor to counter an unequal number of desirable and undesirable pairs.
86
- label_pad_token_id (`int`, *optional*, defaults to `-100`):
87
- Label pad token id. This argument is required if you want to use the default data collator.
88
- padding_value (`int` or `None`, *optional*, defaults to `None`):
89
- Padding value to use. If `None`, the padding value of the tokenizer is used.
90
- truncation_mode (`str`, *optional*, defaults to `"keep_end"`):
91
- Truncation mode to use when the prompt is too long. Possible values are `"keep_end"` or `"keep_start"`.
92
- This argument is required if you want to use the default data collator.
93
- generate_during_eval (`bool`, *optional*, defaults to `False`):
94
- If `True`, generates and logs completions from both the model and the reference model to W&B or Comet during
95
- evaluation.
96
- is_encoder_decoder (`bool` or `None`, *optional*, defaults to `None`):
97
- When using the `model_init` argument (callable) to instantiate the model instead of the `model` argument,
98
- you need to specify if the model returned by the callable is an encoder-decoder model.
99
- precompute_ref_log_probs (`bool`, *optional*, defaults to `False`):
100
- Whether to precompute reference model log probabilities for training and evaluation datasets. This is
101
- useful when training without the reference model to reduce the total GPU memory needed.
102
- model_init_kwargs (`dict[str, Any]` or `None`, *optional*, defaults to `None`):
103
- Keyword arguments to pass to `AutoModelForCausalLM.from_pretrained` when instantiating the model from a
104
- string.
105
- ref_model_init_kwargs (`dict[str, Any]` or `None`, *optional*, defaults to `None`):
106
- Keyword arguments to pass to `AutoModelForCausalLM.from_pretrained` when instantiating the reference model
107
- from a string.
108
- dataset_num_proc: (`int` or `None`, *optional*, defaults to `None`):
109
- Number of processes to use for processing the dataset.
110
- disable_dropout (`bool`, *optional*, defaults to `True`):
111
- Whether to disable dropout in the model and reference model.
112
-
113
- """
114
- vllm_sampling_params: Optional[Any] = field(
115
- default = None,
116
- metadata = {'help': 'vLLM SamplingParams'},
117
- )
118
- unsloth_num_chunks : Optional[int] = field(
119
- default = -1,
120
- metadata = {'help': 'Chunk size to reduce memory usage. -1 is most efficient.'},
121
- )
122
- def __init__(
123
- self,
124
- output_dir = None,
125
- overwrite_output_dir = None,
126
- do_train = False,
127
- do_eval = False,
128
- do_predict = False,
129
- eval_strategy = 'no',
130
- prediction_loss_only = False,
131
- per_device_train_batch_size = 4,
132
- per_device_eval_batch_size = 4,
133
- per_gpu_train_batch_size = None,
134
- per_gpu_eval_batch_size = None,
135
- gradient_accumulation_steps = 2,
136
- eval_accumulation_steps = 2,
137
- eval_delay = 0,
138
- torch_empty_cache_steps = 250,
139
- learning_rate = 5e-05,
140
- weight_decay = 0.01,
141
- adam_beta1 = 0.9,
142
- adam_beta2 = 0.999,
143
- adam_epsilon = 1e-08,
144
- max_grad_norm = 1.0,
145
- num_train_epochs = 3.0,
146
- max_steps = -1,
147
- lr_scheduler_type = 'linear',
148
- warmup_ratio = 0.1,
149
- warmup_steps = 0,
150
- log_level = 'passive',
151
- log_level_replica = 'warning',
152
- log_on_each_node = True,
153
- logging_dir = None,
154
- logging_strategy = 'steps',
155
- logging_first_step = False,
156
- logging_steps = 1,
157
- logging_nan_inf_filter = False,
158
- save_strategy = 'steps',
159
- save_steps = 500,
160
- save_total_limit = None,
161
- save_safetensors = True,
162
- save_on_each_node = False,
163
- save_only_model = False,
164
- restore_callback_states_from_checkpoint = False,
165
- no_cuda = False,
166
- use_cpu = False,
167
- use_mps_device = False,
168
- seed = 3407,
169
- data_seed = 3407,
170
- jit_mode_eval = False,
171
- use_ipex = False,
172
- bf16 = False,
173
- fp16 = False,
174
- fp16_opt_level = 'O1',
175
- half_precision_backend = 'auto',
176
- bf16_full_eval = False,
177
- fp16_full_eval = False,
178
- tf32 = None,
179
- local_rank = -1,
180
- ddp_backend = None,
181
- tpu_num_cores = None,
182
- tpu_metrics_debug = False,
183
- debug = '',
184
- dataloader_drop_last = False,
185
- eval_steps = None,
186
- dataloader_num_workers = 0,
187
- dataloader_prefetch_factor = None,
188
- past_index = -1,
189
- run_name = None,
190
- disable_tqdm = None,
191
- remove_unused_columns = True,
192
- label_names = None,
193
- load_best_model_at_end = False,
194
- metric_for_best_model = None,
195
- greater_is_better = None,
196
- ignore_data_skip = False,
197
- fsdp = '',
198
- fsdp_min_num_params = 0,
199
- fsdp_config = None,
200
- fsdp_transformer_layer_cls_to_wrap = None,
201
- accelerator_config = None,
202
- deepspeed = None,
203
- label_smoothing_factor = 0.0,
204
- optim = 'adamw_8bit',
205
- optim_args = None,
206
- adafactor = False,
207
- group_by_length = False,
208
- length_column_name = 'length',
209
- report_to = None,
210
- ddp_find_unused_parameters = None,
211
- ddp_bucket_cap_mb = None,
212
- ddp_broadcast_buffers = None,
213
- dataloader_pin_memory = True,
214
- dataloader_persistent_workers = False,
215
- skip_memory_metrics = True,
216
- use_legacy_prediction_loop = False,
217
- push_to_hub = False,
218
- resume_from_checkpoint = None,
219
- hub_model_id = None,
220
- hub_strategy = 'every_save',
221
- hub_token = None,
222
- hub_private_repo = None,
223
- hub_always_push = False,
224
- hub_revision = None,
225
- gradient_checkpointing = False,
226
- gradient_checkpointing_kwargs = None,
227
- include_inputs_for_metrics = False,
228
- eval_do_concat_batches = True,
229
- fp16_backend = 'auto',
230
- push_to_hub_model_id = None,
231
- push_to_hub_organization = None,
232
- push_to_hub_token = None,
233
- mp_parameters = '',
234
- auto_find_batch_size = True,
235
- full_determinism = False,
236
- torchdynamo = None,
237
- ray_scope = 'last',
238
- ddp_timeout = 1800,
239
- torch_compile = False,
240
- torch_compile_backend = None,
241
- torch_compile_mode = None,
242
- include_tokens_per_second = False,
243
- include_num_input_tokens_seen = False,
244
- neftune_noise_alpha = None,
245
- optim_target_modules = None,
246
- batch_eval_metrics = False,
247
- eval_on_start = False,
248
- use_liger_kernel = False,
249
- liger_kernel_config = None,
250
- eval_use_gather_object = False,
251
- average_tokens_across_devices = True,
252
- max_length = 1024,
253
- max_prompt_length = 512,
254
- max_completion_length = None,
255
- beta = 0.1,
256
- loss_type = 'kto',
257
- desirable_weight = 1.0,
258
- undesirable_weight = 1.0,
259
- label_pad_token_id = -100,
260
- padding_value = None,
261
- truncation_mode = 'keep_end',
262
- generate_during_eval = False,
263
- is_encoder_decoder = None,
264
- disable_dropout = True,
265
- precompute_ref_log_probs = False,
266
- model_init_kwargs = None,
267
- ref_model_init_kwargs = None,
268
- dataset_num_proc = None,
269
- vllm_sampling_params = None,
270
- unsloth_num_chunks = -1,
271
- **kwargs,
272
- ):
273
- if learning_rate < 1e-7: raise FloatingPointError(f'Unsloth: Your learning rate of `{learning_rate}` is too small (less than 1e-7)! Consider increasing it, otherwise gradient updates will be close to 0!')
274
- if learning_rate > 1: raise OverflowError(f'Unsloth: Your learning rate of `{learning_rate}` is way too large (> 1)! Consider decreasing it to 1e-1, otherwise gradient updates will explode!')
275
- if output_dir is None and save_strategy == 'steps' and save_steps == 500:
276
- output_dir = 'unsloth_training_checkpoints'
277
- save_strategy = 'no'
278
- if dataset_num_proc is None:
279
- from multiprocessing import cpu_count
280
- dataset_num_proc = min(cpu_count()*2, 2)
281
-
282
- super().__init__(
283
- output_dir = output_dir,
284
- overwrite_output_dir = overwrite_output_dir,
285
- do_train = do_train,
286
- do_eval = do_eval,
287
- do_predict = do_predict,
288
- eval_strategy = eval_strategy,
289
- prediction_loss_only = prediction_loss_only,
290
- per_device_train_batch_size = per_device_train_batch_size,
291
- per_device_eval_batch_size = per_device_eval_batch_size,
292
- per_gpu_train_batch_size = per_gpu_train_batch_size,
293
- per_gpu_eval_batch_size = per_gpu_eval_batch_size,
294
- gradient_accumulation_steps = gradient_accumulation_steps,
295
- eval_accumulation_steps = eval_accumulation_steps,
296
- eval_delay = eval_delay,
297
- torch_empty_cache_steps = torch_empty_cache_steps,
298
- learning_rate = learning_rate,
299
- weight_decay = weight_decay,
300
- adam_beta1 = adam_beta1,
301
- adam_beta2 = adam_beta2,
302
- adam_epsilon = adam_epsilon,
303
- max_grad_norm = max_grad_norm,
304
- num_train_epochs = num_train_epochs,
305
- max_steps = max_steps,
306
- lr_scheduler_type = lr_scheduler_type,
307
- warmup_ratio = warmup_ratio,
308
- warmup_steps = warmup_steps,
309
- log_level = log_level,
310
- log_level_replica = log_level_replica,
311
- log_on_each_node = log_on_each_node,
312
- logging_dir = logging_dir,
313
- logging_strategy = logging_strategy,
314
- logging_first_step = logging_first_step,
315
- logging_steps = logging_steps,
316
- logging_nan_inf_filter = logging_nan_inf_filter,
317
- save_strategy = save_strategy,
318
- save_steps = save_steps,
319
- save_total_limit = save_total_limit,
320
- save_safetensors = save_safetensors,
321
- save_on_each_node = save_on_each_node,
322
- save_only_model = save_only_model,
323
- restore_callback_states_from_checkpoint = restore_callback_states_from_checkpoint,
324
- no_cuda = no_cuda,
325
- use_cpu = use_cpu,
326
- use_mps_device = use_mps_device,
327
- seed = seed,
328
- data_seed = data_seed,
329
- jit_mode_eval = jit_mode_eval,
330
- use_ipex = use_ipex,
331
- bf16 = bf16,
332
- fp16 = fp16,
333
- fp16_opt_level = fp16_opt_level,
334
- half_precision_backend = half_precision_backend,
335
- bf16_full_eval = bf16_full_eval,
336
- fp16_full_eval = fp16_full_eval,
337
- tf32 = tf32,
338
- local_rank = local_rank,
339
- ddp_backend = ddp_backend,
340
- tpu_num_cores = tpu_num_cores,
341
- tpu_metrics_debug = tpu_metrics_debug,
342
- debug = debug,
343
- dataloader_drop_last = dataloader_drop_last,
344
- eval_steps = eval_steps,
345
- dataloader_num_workers = dataloader_num_workers,
346
- dataloader_prefetch_factor = dataloader_prefetch_factor,
347
- past_index = past_index,
348
- run_name = run_name,
349
- disable_tqdm = disable_tqdm,
350
- remove_unused_columns = remove_unused_columns,
351
- label_names = label_names,
352
- load_best_model_at_end = load_best_model_at_end,
353
- metric_for_best_model = metric_for_best_model,
354
- greater_is_better = greater_is_better,
355
- ignore_data_skip = ignore_data_skip,
356
- fsdp = fsdp,
357
- fsdp_min_num_params = fsdp_min_num_params,
358
- fsdp_config = fsdp_config,
359
- fsdp_transformer_layer_cls_to_wrap = fsdp_transformer_layer_cls_to_wrap,
360
- accelerator_config = accelerator_config,
361
- deepspeed = deepspeed,
362
- label_smoothing_factor = label_smoothing_factor,
363
- optim = optim,
364
- optim_args = optim_args,
365
- adafactor = adafactor,
366
- group_by_length = group_by_length,
367
- length_column_name = length_column_name,
368
- report_to = report_to,
369
- ddp_find_unused_parameters = ddp_find_unused_parameters,
370
- ddp_bucket_cap_mb = ddp_bucket_cap_mb,
371
- ddp_broadcast_buffers = ddp_broadcast_buffers,
372
- dataloader_pin_memory = dataloader_pin_memory,
373
- dataloader_persistent_workers = dataloader_persistent_workers,
374
- skip_memory_metrics = skip_memory_metrics,
375
- use_legacy_prediction_loop = use_legacy_prediction_loop,
376
- push_to_hub = push_to_hub,
377
- resume_from_checkpoint = resume_from_checkpoint,
378
- hub_model_id = hub_model_id,
379
- hub_strategy = hub_strategy,
380
- hub_token = hub_token,
381
- hub_private_repo = hub_private_repo,
382
- hub_always_push = hub_always_push,
383
- hub_revision = hub_revision,
384
- gradient_checkpointing = gradient_checkpointing,
385
- gradient_checkpointing_kwargs = gradient_checkpointing_kwargs,
386
- include_inputs_for_metrics = include_inputs_for_metrics,
387
- eval_do_concat_batches = eval_do_concat_batches,
388
- fp16_backend = fp16_backend,
389
- push_to_hub_model_id = push_to_hub_model_id,
390
- push_to_hub_organization = push_to_hub_organization,
391
- push_to_hub_token = push_to_hub_token,
392
- mp_parameters = mp_parameters,
393
- auto_find_batch_size = auto_find_batch_size,
394
- full_determinism = full_determinism,
395
- torchdynamo = torchdynamo,
396
- ray_scope = ray_scope,
397
- ddp_timeout = ddp_timeout,
398
- torch_compile = torch_compile,
399
- torch_compile_backend = torch_compile_backend,
400
- torch_compile_mode = torch_compile_mode,
401
- include_tokens_per_second = include_tokens_per_second,
402
- include_num_input_tokens_seen = include_num_input_tokens_seen,
403
- neftune_noise_alpha = neftune_noise_alpha,
404
- optim_target_modules = optim_target_modules,
405
- batch_eval_metrics = batch_eval_metrics,
406
- eval_on_start = eval_on_start,
407
- use_liger_kernel = use_liger_kernel,
408
- liger_kernel_config = liger_kernel_config,
409
- eval_use_gather_object = eval_use_gather_object,
410
- average_tokens_across_devices = average_tokens_across_devices,
411
- max_length = max_length,
412
- max_prompt_length = max_prompt_length,
413
- max_completion_length = max_completion_length,
414
- beta = beta,
415
- loss_type = loss_type,
416
- desirable_weight = desirable_weight,
417
- undesirable_weight = undesirable_weight,
418
- label_pad_token_id = label_pad_token_id,
419
- padding_value = padding_value,
420
- truncation_mode = truncation_mode,
421
- generate_during_eval = generate_during_eval,
422
- is_encoder_decoder = is_encoder_decoder,
423
- disable_dropout = disable_dropout,
424
- precompute_ref_log_probs = precompute_ref_log_probs,
425
- model_init_kwargs = model_init_kwargs,
426
- ref_model_init_kwargs = ref_model_init_kwargs,
427
- dataset_num_proc = dataset_num_proc,**kwargs)
428
- self.vllm_sampling_params = vllm_sampling_params
429
- self.unsloth_num_chunks = unsloth_num_chunks
430
- pass
431
-
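For reference, the guard clauses at the top of the body above behave as follows. This is a minimal sketch, not part of the diff; it assumes the signature belongs to the generated `UnslothKTOConfig` class declared earlier in this file, which is not visible in this hunk.

# Hypothetical usage -- the class name is an assumption, not confirmed by this hunk.
config = UnslothKTOConfig(output_dir="outputs", learning_rate=5e-6)      # accepted
# UnslothKTOConfig(output_dir="outputs", learning_rate=1e-8)  # raises FloatingPointError (lr < 1e-7)
# UnslothKTOConfig(output_dir="outputs", learning_rate=2.0)   # raises OverflowError (lr > 1)
# With output_dir=None and the default save settings, checkpoints are redirected to
# 'unsloth_training_checkpoints' and checkpoint saving is switched off ('no').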
- class _UnslothKTOTrainer(Trainer):
-     r""""""
- 
-     _tag_names = ["trl", "kto"]
- 
-     def __init__(
-         self,
-         model: Union[PreTrainedModel, nn.Module, str] = None,
-         ref_model: Optional[Union[PreTrainedModel, nn.Module, str]] = None,
-         args: KTOConfig = None,
-         train_dataset: Optional[Dataset] = None,
-         eval_dataset: Optional[Union[Dataset, dict[str, Dataset]]] = None,
-         processing_class: Optional[
-             Union[PreTrainedTokenizerBase, BaseImageProcessor, FeatureExtractionMixin, ProcessorMixin]
-         ] = None,
-         data_collator: Optional[DataCollator] = None,
-         model_init: Optional[Callable[[], PreTrainedModel]] = None,
-         callbacks: Optional[list[TrainerCallback]] = None,
-         optimizers: tuple[torch.optim.Optimizer, torch.optim.lr_scheduler.LambdaLR] = (None, None),
-         preprocess_logits_for_metrics: Optional[Callable[[torch.Tensor, torch.Tensor], torch.Tensor]] = None,
-         peft_config: Optional[dict] = None,
-         compute_metrics: Optional[Callable[[EvalLoopOutput], dict]] = None,
-         model_adapter_name: Optional[str] = None,
-         ref_adapter_name: Optional[str] = None,
-     ):
-         if type(args) is TrainingArguments:
-             raise ValueError("Please use `KTOConfig` instead of `TrainingArguments`.")
- 
-         if not isinstance(model, str) and ref_model is model:
-             raise ValueError(
-                 "`model` and `ref_model` cannot be the same object. If you want `ref_model` to be the "
-                 "same as `model`, you must pass a copy of it, or `None` if you use peft."
-             )
- 
-         if args.model_init_kwargs is None:
-             model_init_kwargs = {}
-         elif not isinstance(model, str):
-             raise ValueError("You passed model_kwargs to the KTOTrainer. But your model is already instantiated.")
-         else:
-             model_init_kwargs = args.model_init_kwargs
-             torch_dtype = model_init_kwargs.get("torch_dtype")
-             if torch_dtype is not None:
-                 # Convert to `torch.dtype` if a str is passed
-                 if isinstance(torch_dtype, str) and torch_dtype != "auto":
-                     torch_dtype = getattr(torch, torch_dtype)
-                 if torch_dtype != "auto" and not isinstance(torch_dtype, torch.dtype):
-                     raise ValueError(
-                         f"Invalid `torch_dtype` passed to the KTOConfig. Expected a string with either `torch.dtype` or 'auto', but got {torch_dtype}."
-                     )
-                 model_init_kwargs["torch_dtype"] = torch_dtype
- 
-         if args.ref_model_init_kwargs is None:
-             ref_model_init_kwargs = {}
-         elif not isinstance(ref_model, str):
-             raise ValueError(
-                 "You passed ref_model_kwargs to the KTOTrainer. But your ref_model is already instantiated."
-             )
-         else:
-             ref_model_init_kwargs = args.ref_model_init_kwargs
-             torch_dtype = ref_model_init_kwargs.get("torch_dtype")
-             if torch_dtype is not None:
-                 # Convert to `torch.dtype` if a str is passed
-                 if isinstance(torch_dtype, str) and torch_dtype != "auto":
-                     torch_dtype = getattr(torch, torch_dtype)
-                 if torch_dtype != "auto" and not isinstance(torch_dtype, torch.dtype):
-                     raise ValueError(
-                         f"Invalid `torch_dtype` passed to the KTOConfig. Expected a string with either `torch.dtype` or 'auto', but got {torch_dtype}."
-                     )
-                 ref_model_init_kwargs["torch_dtype"] = torch_dtype
- 
-         if isinstance(model, str):
-             model = AutoModelForCausalLM.from_pretrained(model, **model_init_kwargs)
- 
-         if isinstance(ref_model, str):
-             ref_model = AutoModelForCausalLM.from_pretrained(ref_model, **ref_model_init_kwargs)
- 
-         # Initialize this variable to False. This helps tracking the case when `peft_module_casting_to_bf16`
-         # has been called in order to properly call autocast if needed.
-         self._peft_has_been_casted_to_bf16 = False
- 
-         if not is_peft_available() and peft_config is not None:
-             raise ValueError(
-                 "PEFT is not installed and you passed a `peft_config` in the trainer's kwargs, please install it with `pip install peft` to use the PEFT models"
-             )
-         elif is_peft_available() and peft_config is not None:
-             # if model is a peft model and we have a peft_config, we merge and unload it first
-             if isinstance(model, PeftModel):
-                 model = model.merge_and_unload()
- 
-             if getattr(model, "is_loaded_in_8bit", False) or getattr(model, "is_loaded_in_4bit", False):
-                 _support_gc_kwargs = hasattr(
-                     args, "gradient_checkpointing_kwargs"
-                 ) and "gradient_checkpointing_kwargs" in list(
-                     inspect.signature(prepare_model_for_kbit_training).parameters
-                 )
- 
-                 prepare_model_kwargs = {"use_gradient_checkpointing": args.gradient_checkpointing}
- 
-                 if _support_gc_kwargs:
-                     prepare_model_kwargs["gradient_checkpointing_kwargs"] = args.gradient_checkpointing_kwargs
- 
-                 model = prepare_model_for_kbit_training(model, **prepare_model_kwargs)
-             elif getattr(args, "gradient_checkpointing", False):
-                 # For backward compatibility with older versions of transformers
-                 if hasattr(model, "enable_input_require_grads"):
-                     model.enable_input_require_grads()
-                 else:
- 
-                     def make_inputs_require_grad(module, input, output):
-                         output.requires_grad_(True)
- 
-                     model.get_input_embeddings().register_forward_hook(make_inputs_require_grad)
- 
-             # get peft model with the given config
-             model = model  # no-op in this generated file; the original get_peft_model call was removed
-             if args.bf16 and getattr(model, "is_loaded_in_4bit", False):
-                 peft_module_casting_to_bf16(model)
-                 # If args.bf16 we need to explicitly call `generate` with torch amp autocast context manager
-                 self._peft_has_been_casted_to_bf16 = True
- 
-         # For models that use gradient_checkpointing, we need to attach a hook that enables input
-         # to explicitly have `requires_grad=True`, otherwise training will either silently
-         # fail or completely fail.
-         elif getattr(args, "gradient_checkpointing", False):
-             # For backward compatibility with older versions of transformers
-             if hasattr(model, "enable_input_require_grads"):
-                 model.enable_input_require_grads()
-             else:
- 
-                 def make_inputs_require_grad(module, input, output):
-                     output.requires_grad_(True)
- 
-                 model.get_input_embeddings().register_forward_hook(make_inputs_require_grad)
- 
-         if args.generate_during_eval and not (is_wandb_available() or is_comet_available()):
-             raise ValueError(
-                 "`generate_during_eval=True` requires Weights and Biases or Comet to be installed."
-                 " Please install `wandb` or `comet-ml` to resolve."
-             )
- 
-         if model is not None:
-             self.is_encoder_decoder = model.config.is_encoder_decoder
-         elif args.is_encoder_decoder is None:
-             raise ValueError("When no model is provided, you need to pass the parameter is_encoder_decoder.")
-         else:
-             self.is_encoder_decoder = args.is_encoder_decoder
- 
-         self.is_peft_model = is_peft_available() and isinstance(model, PeftModel)
-         self.model_adapter_name = model_adapter_name
-         self.ref_adapter_name = ref_adapter_name
- 
-         if ref_model:
-             self.ref_model = ref_model
-         elif self.is_peft_model or args.precompute_ref_log_probs:
-             # The `model` with adapters turned off will be used as the reference model
-             self.ref_model = None
-         else:
-             self.ref_model = create_reference_model(model)
- 
-         if processing_class is None:
-             raise ValueError(
-                 "max_length or a processing_class must be specified when using the default DPODataCollatorWithPadding"
-             )
-         if args.max_length is None:
-             warnings.warn(
-                 "When using DPODataCollatorWithPadding, you should set `max_length` in the KTOTrainer's init."
-                 " It will be set to `512` by default, but you should do it yourself in the future.",
-                 UserWarning,
-             )
-             max_length = 512
-         if args.max_length is not None:
-             max_length = args.max_length
- 
-         if args.max_prompt_length is None:
-             warnings.warn(
-                 "When using DPODataCollatorWithPadding, you should set `max_prompt_length` in the KTOTrainer's init."
-                 " It will be set to `128` by default, but you should do it yourself in the future.",
-                 UserWarning,
-             )
-             max_prompt_length = 128
-         if args.max_prompt_length is not None:
-             max_prompt_length = args.max_prompt_length
- 
-         max_completion_length = None
-         if args.max_completion_length is None and self.is_encoder_decoder:
-             warnings.warn(
-                 "When using DPODataCollatorWithPadding with an encoder-decoder architecture, you should set `max_completion_length` in the KTOTrainer's init."
-                 " It will be set to `128` by default, but you should do it yourself in the future.",
-                 UserWarning,
-             )
-             max_completion_length = 128
-         if args.max_completion_length is not None and self.is_encoder_decoder:
-             max_completion_length = args.max_completion_length
- 
-         if data_collator is None:
-             data_collator = DPODataCollatorWithPadding(
-                 pad_token_id=processing_class.pad_token_id,
-                 label_pad_token_id=args.label_pad_token_id,
-                 is_encoder_decoder=self.is_encoder_decoder,
-             )
- 
-             if args.remove_unused_columns:
-                 args.remove_unused_columns = False
-                 # warn users
-                 warnings.warn(
-                     "When using DPODataCollatorWithPadding, you should set `remove_unused_columns=False` in your KTOConfig;"
-                     " we have set it for you, but you should do it yourself in the future.",
-                     UserWarning,
-                 )
- 
-             self.use_dpo_data_collator = True
-         else:
-             self.use_dpo_data_collator = False
- 
-         # Disable dropout in the model and reference model
-         if args.disable_dropout:
-             disable_dropout_in_model(model)
-             if self.ref_model is not None:
-                 disable_dropout_in_model(self.ref_model)
- 
-         self.loss_type = args.loss_type
-         self.max_length = max_length
-         self.generate_during_eval = args.generate_during_eval
-         self.label_pad_token_id = args.label_pad_token_id
-         self.padding_value = args.padding_value if args.padding_value is not None else processing_class.pad_token_id
-         self.max_prompt_length = max_prompt_length
-         self.truncation_mode = args.truncation_mode
-         self.max_completion_length = max_completion_length
-         self.processing_class = processing_class
-         self.precompute_ref_log_probs = args.precompute_ref_log_probs
- 
-         # Not all losses require a KL calculation
-         self.calculate_KL = True
-         if self.loss_type in ["apo_zero_unpaired"]:
-             self.calculate_KL = False
- 
-         # Since ref_logps are precomputed on the first call to get_train/eval_dataloader,
-         # keep track of the first call to avoid recomputation on subsequent calls
-         self._precomputed_train_ref_log_probs = False
-         self._precomputed_eval_ref_log_probs = False
- 
-         # metric
-         self._stored_metrics = defaultdict(lambda: defaultdict(list))
- 
-         # KTO parameter
-         self.beta = args.beta
-         self.desirable_weight = args.desirable_weight
-         self.undesirable_weight = args.undesirable_weight
-         self.aux_loss_enabled = getattr(model.config, "output_router_logits", False)
-         self.aux_loss_coef = getattr(model.config, "router_aux_loss_coef", 0.0)
-         if self.aux_loss_enabled and self.aux_loss_coef == 0.0:
-             warnings.warn(
-                 "You set `output_router_logits` to `True` in the model config, but `router_aux_loss_coef` is set to "
-                 "`0.0`, meaning the auxiliary loss will not be used. Either set `router_aux_loss_coef` to a value "
-                 "greater than `0.0`, or set `output_router_logits` to `False` if you don't want to use the auxiliary "
-                 "loss.",
-                 UserWarning,
-             )
- 
-         # The trainer estimates the number of FLOPs (floating-point operations) using the number of elements in the
-         # input tensor associated with the key "input_ids". However, in KTO, the sampled data does not include the
-         # "input_ids" key. Instead, the available keys are "prompt_input_ids" and "completion_input_ids". As a result,
-         # the trainer issues the warning: "Could not estimate the number of tokens of the input, floating-point
-         # operations will not be computed." To suppress this warning, we set the "estimate_tokens" key in the model's
-         # "warnings_issued" dictionary to True. This acts as a flag to indicate that the warning has already been
-         # issued.
-         model.warnings_issued["estimate_tokens"] = True
- 
-         # Compute that only on the main process for faster data processing.
-         # see: https://github.com/huggingface/trl/pull/1255
-         with PartialState().main_process_first():
-             # Extract the prompt if needed
-             train_dataset = train_dataset.map(
-                 maybe_extract_prompt, num_proc=args.dataset_num_proc, desc="Extracting prompt from train dataset"
-             )
-             # Unpair the dataset if needed
-             train_dataset = maybe_unpair_preference_dataset(
-                 train_dataset, args.dataset_num_proc, desc="Unpairing train dataset"
-             )
-             # Apply the chat template if needed
-             train_dataset = train_dataset.map(
-                 maybe_apply_chat_template,
-                 fn_kwargs={"tokenizer": processing_class},
-                 num_proc=args.dataset_num_proc,
-                 desc="Applying chat template to train dataset",
-             )
-             if eval_dataset is not None:
-                 eval_dataset = eval_dataset.map(
-                     maybe_extract_prompt, num_proc=args.dataset_num_proc, desc="Extracting prompt from eval dataset"
-                 )
-                 eval_dataset = maybe_unpair_preference_dataset(
-                     eval_dataset, args.dataset_num_proc, desc="Unpairing eval dataset"
-                 )
-                 eval_dataset = eval_dataset.map(
-                     maybe_apply_chat_template,
-                     fn_kwargs={"tokenizer": processing_class},
-                     num_proc=args.dataset_num_proc,
-                     desc="Applying chat template to eval dataset",
-                 )
- 
-             # Tokenize and prepare the training datasets
-             train_dataset = train_dataset.map(
-                 _tokenize,
-                 batched=True,
-                 fn_kwargs={"tokenizer": self.processing_class},
-                 num_proc=args.dataset_num_proc,
-                 desc="Tokenizing train dataset",
-             )
- 
-             fn_kwargs = {
-                 "prefix": "",
-                 "is_encoder_decoder": self.is_encoder_decoder,
-                 "tokenizer": self.processing_class,
-                 "max_length": self.max_length,
-                 "truncation_mode": self.truncation_mode,
-                 "label_pad_token_id": self.label_pad_token_id,
-                 "max_prompt_length": self.max_prompt_length,
-                 "max_completion_length": self.max_completion_length,
-             }
- 
-             train_dataset = train_dataset.map(
-                 _process_tokens,
-                 fn_kwargs=fn_kwargs,
-                 num_proc=args.dataset_num_proc,
-                 desc="Processing tokenized train dataset",
-             )
- 
-             # Tokenize and prepare the eval datasets
-             if eval_dataset is not None:
-                 eval_dataset = eval_dataset.map(
-                     _tokenize,
-                     fn_kwargs={"tokenizer": self.processing_class},
-                     batched=True,
-                     num_proc=args.dataset_num_proc,
-                     desc="Tokenizing eval dataset",
-                 )
- 
-                 eval_dataset = eval_dataset.map(
-                     _process_tokens,
-                     fn_kwargs=fn_kwargs,
-                     num_proc=args.dataset_num_proc,
-                     desc="Processing tokenized eval dataset",
-                 )
- 
-             # Get KL datasets if needed
-             if self.calculate_KL:
-                 if args.per_device_train_batch_size <= 1:
-                     raise ValueError(
-                         "Actual (not effective) batch size must be > 1. KTO will not work properly because the KL term will be equivalent to the implied reward."
-                     )
- 
-                 # create pairs for estimating the KL term by flipping the matched pairs in each batch of size total_batch_size
-                 # i.e., [x_1, y_1], ..., [x_n, y_n] --> [x_1, y_n], ..., [x_n, y_1] = [x'_1, y'_1], ..., [x'_n, y'_n]
-                 train_kl_dataset = train_dataset.map(
-                     _get_kl_dataset,
-                     batched=True,
-                     batch_size=args.per_device_train_batch_size,
-                     num_proc=args.dataset_num_proc,
-                     desc="Extracting KL train dataset",
-                 )
- 
-                 fn_kwargs["prefix"] = "KL_"
-                 train_kl_dataset = train_kl_dataset.map(
-                     _process_tokens,
-                     fn_kwargs=fn_kwargs,
-                     num_proc=args.dataset_num_proc,
-                     remove_columns=[c for c in train_kl_dataset.column_names if c in train_dataset.column_names],
-                     desc="Processing tokenized train KL dataset",
-                 )
- 
-                 # merge the datasets
-                 train_dataset = concatenate_datasets([train_dataset, train_kl_dataset], axis=1)
- 
-                 if eval_dataset is not None:
-                     # Get KL dataset
-                     eval_kl_dataset = eval_dataset.map(
-                         _get_kl_dataset,
-                         batched=True,
-                         batch_size=args.per_device_train_batch_size,
-                         num_proc=args.dataset_num_proc,
-                         desc="Extracting eval KL dataset",
-                     )
- 
-                     eval_kl_dataset = eval_kl_dataset.map(
-                         _process_tokens,
-                         fn_kwargs=fn_kwargs,
-                         num_proc=args.dataset_num_proc,
-                         remove_columns=[c for c in eval_kl_dataset.column_names if c in eval_dataset.column_names],
-                         desc="Processing tokenized eval KL dataset",
-                     )
- 
-                     # merge the datasets
-                     eval_dataset = concatenate_datasets([eval_dataset, eval_kl_dataset], axis=1)
- 
-             # calculate dataset desirability balance
-             num_desirable = max(sum(train_dataset["label"]), 1)
-             num_undesirable = max(len(train_dataset["label"]) - num_desirable, 1)  # "label" is binary
- 
-             if num_desirable != num_undesirable:
-                 # The lower and upper bounds come from Eq. (8) of https://huggingface.co/papers/2402.01306
-                 des_weight_lower_bound = round((num_undesirable * self.undesirable_weight / num_desirable) * 1, 2)
-                 des_weight_upper_bound = round((num_undesirable * self.undesirable_weight / num_desirable) * 1.33, 2)
-                 und_weight_lower_bound = round((num_desirable * self.desirable_weight / num_undesirable) / 1.33, 2)
-                 und_weight_upper_bound = round((num_desirable * self.desirable_weight / num_undesirable) / 1, 2)
- 
-                 des_weight_in_range = des_weight_lower_bound <= self.desirable_weight <= des_weight_upper_bound
-                 und_weight_in_range = und_weight_lower_bound <= self.undesirable_weight <= und_weight_upper_bound
- 
-                 if not (des_weight_in_range or und_weight_in_range):
-                     warnings.warn(
-                         "You have different amounts of desirable/positive and undesirable/negative examples but the "
-                         "weights on the desirable and undesirable losses don't seem to be in an ideal range. Based "
-                         f"on your data, we recommend EITHER "
-                         f"desirable_weight in [{des_weight_lower_bound}, {des_weight_upper_bound}] or "
-                         f"undesirable_weight in [{und_weight_lower_bound}, {und_weight_upper_bound}] (but NOT BOTH). "
-                         "See the documentation on how to optimally set these weights.",
-                         UserWarning,
-                     )
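A short worked example of the Eq. (8) bounds computed above (numbers invented for illustration): with 300 desirable and 100 undesirable examples and both weights left at their defaults of 1.0,

# des_weight range: [round(100/300 * 1.0, 2), round(100/300 * 1.33, 2)] = [0.33, 0.44]
# und_weight range: [round(300/100 / 1.33, 2), round(300/100 / 1.0, 2)] = [2.26, 3.0]
# Neither default weight (1.0) falls inside its range, so the warning fires;
# setting undesirable_weight=3.0 while leaving desirable_weight=1.0 silences it.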
- 
-         super().__init__(
-             model=model,
-             args=args,
-             data_collator=data_collator,
-             train_dataset=train_dataset,
-             eval_dataset=eval_dataset,
-             processing_class=processing_class,
-             model_init=model_init,
-             compute_metrics=compute_metrics,
-             callbacks=callbacks,
-             optimizers=optimizers,
-             preprocess_logits_for_metrics=preprocess_logits_for_metrics,
-         )
- 
-         # Gradient accumulation requires scaled loss. Normally, loss scaling in the parent class depends on whether the
-         # model accepts loss-related kwargs. Since we compute our own loss, this check is irrelevant. We set
-         # self.model_accepts_loss_kwargs to False to enable scaling.
-         self.model_accepts_loss_kwargs = False
- 
-         # Add tags for models that have been loaded with the correct transformers version
-         if hasattr(self.model, "add_model_tags"):
-             self.model.add_model_tags(self._tag_names)
- 
-         if not hasattr(self, "accelerator"):
-             raise AttributeError(
-                 "Your `Trainer` does not have an `accelerator` object. Consider upgrading `transformers`."
-             )
- 
-         # Deepspeed ZeRO-3 does not support precompute_ref_log_probs
-         if self.is_deepspeed_enabled:
-             if self.accelerator.state.deepspeed_plugin.zero_stage == 3 and self.precompute_ref_log_probs:
-                 raise ValueError(
-                     "You cannot use `precompute_ref_log_probs=True` with Deepspeed ZeRO-3. Please set `precompute_ref_log_probs=False`."
-                 )
- 
-         if self.ref_model is None:
-             if not (self.is_peft_model or self.precompute_ref_log_probs):
-                 raise ValueError(
-                     "No reference model and model is not a Peft model. Try setting `precompute_ref_log_probs=True`"
-                 )
-         else:
-             if self.is_deepspeed_enabled:
-                 self.ref_model = self._prepare_deepspeed(self.ref_model)
-             else:
-                 self.ref_model = self.accelerator.prepare_model(self.ref_model, evaluation_mode=True)
- 
-     def _prepare_deepspeed(self, model: PreTrainedModelWrapper):
-         # Adapted from accelerate: https://github.com/huggingface/accelerate/blob/739b135f8367becb67ffaada12fe76e3aa60fefd/src/accelerate/accelerator.py#L1473
-         deepspeed_plugin = self.accelerator.state.deepspeed_plugin
-         config_kwargs = deepcopy(deepspeed_plugin.deepspeed_config)
- 
-         if model is not None:
-             if hasattr(model, "config"):
-                 hidden_size = (
-                     max(model.config.hidden_sizes)
-                     if getattr(model.config, "hidden_sizes", None)
-                     else getattr(model.config, "hidden_size", None)
-                 )
-                 if hidden_size is not None and config_kwargs["zero_optimization"]["stage"] == 3:
-                     # Note that `stage3_prefetch_bucket_size` can produce DeepSpeed messages like: `Invalidate trace cache @ step 0: expected module 1, but got module 0`
-                     # This is expected and is not an error, see: https://github.com/microsoft/DeepSpeed/discussions/4081
-                     config_kwargs.update(
-                         {
-                             "zero_optimization.reduce_bucket_size": hidden_size * hidden_size,
-                             "zero_optimization.stage3_param_persistence_threshold": 10 * hidden_size,
-                             "zero_optimization.stage3_prefetch_bucket_size": 0.9 * hidden_size * hidden_size,
-                         }
-                     )
- 
-         # If ZeRO-3 is used, we shard both the active and reference model.
-         # Otherwise, we assume the reference model fits in memory and is initialized on each device with ZeRO disabled (stage 0)
-         if config_kwargs["zero_optimization"]["stage"] != 3:
-             config_kwargs["zero_optimization"]["stage"] = 0
-         model, *_ = deepspeed.initialize(model=model, config=config_kwargs)
-         model.eval()
-         return model
- 
-     @contextmanager
-     def null_ref_context(self):
-         """Context manager for handling null reference model (that is, peft adapter manipulation)."""
-         with (
-             self.accelerator.unwrap_model(self.model).disable_adapter()
-             if self.is_peft_model and not self.ref_adapter_name
-             else nullcontext()
-         ):
-             if self.ref_adapter_name:
-                 self.model.set_adapter(self.ref_adapter_name)
-             yield
-             if self.ref_adapter_name:
-                 self.model.set_adapter(self.model_adapter_name or "default")
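The context manager above is what allows a single PEFT model to act as both policy and reference; a hedged usage sketch (`trainer` is assumed to be an instance of this class with a LoRA adapter attached):

with trainer.null_ref_context():
    # adapters are disabled (or swapped to ref_adapter_name) inside this block,
    # so the forward pass reproduces the frozen reference model
    ref_logits = trainer.model(input_ids, attention_mask=attention_mask).logits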
- 
-     def get_train_dataloader(self) -> DataLoader:
-         """
-         Returns the training [`~torch.utils.data.DataLoader`].
- 
-         Overrides `transformers.Trainer.get_train_dataloader` to precompute `ref_log_probs`.
-         """
- 
-         if self.precompute_ref_log_probs and not self._precomputed_train_ref_log_probs:
-             dataloader_params = {
-                 "batch_size": self.args.per_device_train_batch_size,
-                 "collate_fn": self.data_collator,
-                 "num_workers": self.args.dataloader_num_workers,
-                 "pin_memory": self.args.dataloader_pin_memory,
-                 "shuffle": False,
-             }
- 
-             # prepare dataloader
-             data_loader = self.accelerator.prepare(DataLoader(self.train_dataset, **dataloader_params))
-             reference_completion_logps = []
-             reference_KL_logps = []
- 
-             for padded_batch in tqdm(iterable=data_loader, desc="Train dataset reference log probs"):
-                 reference_completion_logp, reference_KL_logp = self.compute_reference_log_probs(padded_batch)
- 
-                 reference_completion_logp = self.accelerator.gather_for_metrics(reference_completion_logp)
-                 reference_completion_logps.append(reference_completion_logp.cpu())
- 
-                 if self.calculate_KL:
-                     reference_KL_logp = self.accelerator.gather_for_metrics(reference_KL_logp)
-                     reference_KL_logps.append(reference_KL_logp.cpu())
- 
-             self.train_dataset = self.train_dataset.add_column(
-                 name="reference_logps", column=torch.cat(reference_completion_logps).float().numpy()
-             )
- 
-             if self.calculate_KL:
-                 self.train_dataset = self.train_dataset.add_column(
-                     name="reference_KL_logps", column=torch.cat(reference_KL_logps).float().numpy()
-                 )
- 
-             self._precomputed_train_ref_log_probs = True
- 
-         return super().get_train_dataloader()
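Since the reference log-probs are attached to the dataset once and cached, enabling this path lets the reference model be released from memory during training; a minimal configuration sketch, assuming TRL's `KTOConfig` fields shown earlier in this file:

args = KTOConfig(output_dir="outputs", precompute_ref_log_probs=True)
# The first call to get_train_dataloader() then adds a "reference_logps" column
# (plus "reference_KL_logps" when the loss type requires the KL term) and sets
# _precomputed_train_ref_log_probs so later calls skip the extra pass.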
- 
-     def get_eval_dataloader(self, eval_dataset: Optional[Dataset] = None) -> DataLoader:
-         """
-         Returns the evaluation [`~torch.utils.data.DataLoader`].
- 
-         Overrides `transformers.Trainer.get_eval_dataloader` to precompute `ref_log_probs`.
- 
-         Args:
-             eval_dataset (`torch.utils.data.Dataset`, *optional*):
-                 If provided, will override `self.eval_dataset`. If it is a [`~datasets.Dataset`], columns not accepted
-                 by the `model.forward()` method are automatically removed. It must implement `__len__`.
-         """
-         if eval_dataset is None and self.eval_dataset is None:
-             raise ValueError("Trainer: evaluation requires an eval_dataset.")
-         eval_dataset = eval_dataset if eval_dataset is not None else self.eval_dataset
- 
-         if self.precompute_ref_log_probs and not self._precomputed_eval_ref_log_probs:
-             dataloader_params = {
-                 "batch_size": self.args.per_device_eval_batch_size,
-                 "collate_fn": self.data_collator,
-                 "num_workers": self.args.dataloader_num_workers,
-                 "pin_memory": self.args.dataloader_pin_memory,
-                 "shuffle": False,
-             }
- 
-             # prepare dataloader
-             data_loader = self.accelerator.prepare(DataLoader(eval_dataset, **dataloader_params))
- 
-             reference_completion_logps = []
-             reference_KL_logps = []
- 
-             for padded_batch in tqdm(iterable=data_loader, desc="Eval dataset reference log probs"):
-                 reference_completion_logp, reference_KL_logp = self.compute_reference_log_probs(padded_batch)
- 
-                 reference_completion_logp = self.accelerator.gather_for_metrics(reference_completion_logp)
-                 reference_completion_logps.append(reference_completion_logp.cpu())
- 
-                 if self.calculate_KL:
-                     reference_KL_logp = self.accelerator.gather_for_metrics(reference_KL_logp)
-                     reference_KL_logps.append(reference_KL_logp.cpu())
- 
-             eval_dataset = eval_dataset.add_column(
-                 name="reference_logps", column=torch.cat(reference_completion_logps).float().numpy()
-             )
-             if self.calculate_KL:
-                 eval_dataset = eval_dataset.add_column(
-                     name="reference_KL_logps", column=torch.cat(reference_KL_logps).float().numpy()
-                 )
- 
-             # Save calculated reference_chosen_logps and reference_rejected_logps to the eval_dataset for subsequent runs
-             if self.eval_dataset is not None:
-                 self.eval_dataset = eval_dataset
-             self._precomputed_eval_ref_log_probs = True
- 
-         return super().get_eval_dataloader(eval_dataset=eval_dataset)
- 
-     def compute_reference_log_probs(self, padded_batch: dict) -> dict:
-         """Computes log probabilities of the reference model for a single padded batch of a KTO specific dataset."""
-         with torch.no_grad():
-             if self.ref_model is None:
-                 with self.null_ref_context():
-                     if self.is_encoder_decoder:
-                         completion_logits = self.model(
-                             padded_batch["prompt_input_ids"],
-                             attention_mask=padded_batch["prompt_attention_mask"],
-                             decoder_input_ids=padded_batch.get("completion_decoder_input_ids"),
-                             labels=padded_batch["completion_labels"],
-                         ).logits
- 
-                         if self.calculate_KL:
-                             KL_logits = self.model(
-                                 padded_batch["KL_prompt_input_ids"],
-                                 attention_mask=padded_batch["KL_prompt_attention_mask"],
-                                 decoder_input_ids=padded_batch.get("KL_completion_decoder_input_ids"),
-                                 labels=padded_batch["KL_completion_labels"],
-                             ).logits
-                     else:
-                         completion_logits = self.model(
-                             padded_batch["completion_input_ids"],
-                             attention_mask=padded_batch["completion_attention_mask"],
-                         ).logits
- 
-                         if self.calculate_KL:
-                             KL_logits = self.model(
-                                 padded_batch["KL_completion_input_ids"],
-                                 attention_mask=padded_batch["KL_completion_attention_mask"],
-                             ).logits
-             else:
-                 if self.is_encoder_decoder:
-                     completion_logits = self.ref_model(
-                         padded_batch["prompt_input_ids"],
-                         attention_mask=padded_batch["prompt_attention_mask"],
-                         decoder_input_ids=padded_batch.get("completion_decoder_input_ids"),
-                         labels=padded_batch["completion_labels"],
-                     ).logits
- 
-                     if self.calculate_KL:
-                         KL_logits = self.ref_model(
-                             padded_batch["KL_prompt_input_ids"],
-                             attention_mask=padded_batch["KL_prompt_attention_mask"],
-                             decoder_input_ids=padded_batch.get("KL_completion_decoder_input_ids"),
-                             labels=padded_batch["KL_completion_labels"],
-                         ).logits
-                 else:
-                     completion_logits = self.ref_model(
-                         padded_batch["completion_input_ids"], attention_mask=padded_batch["completion_attention_mask"]
-                     ).logits
- 
-                     if self.calculate_KL:
-                         KL_logits = self.ref_model(
-                             padded_batch["KL_completion_input_ids"],
-                             attention_mask=padded_batch["KL_completion_attention_mask"],
-                         ).logits
- 
-         completion_logps = self.get_batch_logps(
-             completion_logits,
-             padded_batch["completion_labels"],
-             average_log_prob=False,
-             is_encoder_decoder=self.is_encoder_decoder,
-             label_pad_token_id=self.label_pad_token_id,
-         )
- 
-         if self.calculate_KL:
-             KL_logps = self.get_batch_logps(
-                 KL_logits,
-                 padded_batch["KL_completion_labels"],
-                 average_log_prob=False,
-                 is_encoder_decoder=self.is_encoder_decoder,
-                 label_pad_token_id=self.label_pad_token_id,
-             )
-         else:
-             KL_logps = None
- 
-         return completion_logps, KL_logps
- 
-     @staticmethod
-     def get_batch_logps(
-         logits: torch.FloatTensor,
-         labels: torch.LongTensor,
-         average_log_prob: bool = False,
-         label_pad_token_id: int = -100,
-         is_encoder_decoder: bool = False,
-     ) -> torch.FloatTensor:
-         """Compute the log probabilities of the given labels under the given logits.
- 
-         Args:
-             logits: Logits of the model (unnormalized). Shape: (batch_size, sequence_length, vocab_size)
-             labels: Labels for which to compute the log probabilities. Label tokens with a value of label_pad_token_id are ignored. Shape: (batch_size, sequence_length)
-             average_log_prob: If True, return the average log probability per (non-masked) token. Otherwise, return the sum of the log probabilities of the (non-masked) tokens.
- 
-         Returns:
-             A tensor of shape (batch_size,) containing the average/sum log probabilities of the given labels under the given logits.
-         """
-         if logits.shape[:-1] != labels.shape:
-             raise ValueError("Logits (batch and sequence length dim) and labels must have the same shape.")
- 
-         if not is_encoder_decoder:
-             labels = labels[:, 1:].clone()
-             logits = logits[:, :-1, :]
-         else:
-             # Fixes enc-dec RuntimeError
-             labels = labels.clone()
- 
-         loss_mask = labels != label_pad_token_id
- 
-         # dummy token; we'll ignore the losses on these tokens later
-         labels[labels == label_pad_token_id] = 0
- 
-         per_token_logps = selective_log_softmax(logits, labels)
- 
-         if average_log_prob:
-             return (per_token_logps * loss_mask).sum(-1) / loss_mask.sum(-1)
-         else:
-             return (per_token_logps * loss_mask).sum(-1)
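A toy illustration of the masking convention in `get_batch_logps` (values invented for the example):

# labels (after the shift):  [-100, -100,  42,   7, -100]   # prompt/pad positions carry label_pad_token_id
# loss_mask:                 [   0,    0,   1,   1,    0]
# per_token_logps (masked):  [   -,    -, -1.2, -0.8,   -]
# average_log_prob=False ->  sum  = -2.0   (the variant this trainer uses)
# average_log_prob=True  ->  mean = -1.0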
- 
-     def forward(
-         self, model: nn.Module, batch: dict[str, Union[list, torch.LongTensor]]
-     ) -> tuple[torch.FloatTensor, torch.FloatTensor, torch.FloatTensor, torch.FloatTensor]:
-         if self.calculate_KL:
-             KL_logps = None
-             KL_model_kwargs = (
-                 {
-                     "input_ids": batch["KL_prompt_input_ids"],
-                     "attention_mask": batch["KL_prompt_attention_mask"],
-                     "labels": batch["KL_completion_labels"],
-                     "decoder_input_ids": batch.get("KL_completion_decoder_input_ids"),
-                 }
-                 if self.is_encoder_decoder
-                 else {
-                     "input_ids": batch["KL_completion_input_ids"],
-                     "attention_mask": batch["KL_completion_attention_mask"],
-                 }
-             )
-             with torch.no_grad():
-                 KL_logits = model(
-                     **KL_model_kwargs,
-                 ).logits
- 
-             KL_logps = self.get_batch_logps(
-                 KL_logits,
-                 batch["KL_completion_labels"],
-                 average_log_prob=False,
-                 is_encoder_decoder=self.is_encoder_decoder,
-                 label_pad_token_id=self.label_pad_token_id,
-             )
-         else:
-             KL_logps = None
- 
-         model_kwargs = (
-             {
-                 "labels": batch["completion_labels"],
-                 "decoder_input_ids": batch.get("completion_decoder_input_ids"),
-             }
-             if self.is_encoder_decoder
-             else {}
-         )
-         if self.aux_loss_enabled:
-             model_kwargs["output_router_logits"] = True
- 
-         outputs = model(
-             batch["completion_input_ids"],
-             attention_mask=batch["completion_attention_mask"],
-             **model_kwargs,
-         )
-         completion_logits = outputs.logits
- 
-         completion_logps = self.get_batch_logps(
-             completion_logits,
-             batch["completion_labels"],
-             average_log_prob=False,
-             is_encoder_decoder=self.is_encoder_decoder,
-             label_pad_token_id=self.label_pad_token_id,
-         )
- 
-         if completion_logps.shape[0] != len(batch["label"]):
-             raise ValueError(
-                 "There is a mismatch between the number of examples in this batch and the number of "
-                 "examples for which an output sequence was predicted."
-             )
- 
-         chosen_idx = [i for i in range(completion_logps.shape[0]) if batch["label"][i] is True]
-         rejected_idx = [i for i in range(completion_logps.shape[0]) if batch["label"][i] is False]
- 
-         chosen_logps = completion_logps[chosen_idx, ...]
-         rejected_logps = completion_logps[rejected_idx, ...]
- 
-         chosen_logits = completion_logits[chosen_idx, ...]
-         rejected_logits = completion_logits[rejected_idx, ...]
- 
-         if self.aux_loss_enabled:
-             return (chosen_logps, rejected_logps, chosen_logits, rejected_logits, KL_logps, outputs.aux_loss)
-         else:
-             return (chosen_logps, rejected_logps, chosen_logits, rejected_logits, KL_logps)
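The unpaired chosen/rejected split above relies only on the boolean `label` column; a toy illustration:

# batch["label"] = [True, False, True]
# -> chosen_idx = [0, 2], rejected_idx = [1]
# completion_logps[[0, 2]] feeds the desirable branch of the loss and
# completion_logps[[1]] the undesirable branch; either side may be empty.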
- 
-     def kto_loss(
-         self,
-         policy_chosen_logps: torch.FloatTensor,
-         policy_rejected_logps: torch.FloatTensor,
-         policy_KL_logps: torch.FloatTensor,
-         reference_chosen_logps: torch.FloatTensor,
-         reference_rejected_logps: torch.FloatTensor,
-         reference_KL_logps: torch.FloatTensor,
-     ) -> tuple[torch.FloatTensor, torch.FloatTensor, torch.FloatTensor, torch.FloatTensor]:
-         """Compute the KTO loss for a batch of policy and reference model log probabilities.
- 
-         Args:
-             policy_chosen_logps: Log probabilities of the policy model for the chosen responses. Shape: (num(chosen) in batch_size,)
-             policy_rejected_logps: Log probabilities of the policy model for the rejected responses. Shape: (num(rejected) in batch_size,)
-             policy_KL_logps: Log probabilities of the policy model for the KL responses. Shape: (batch_size,)
-             reference_chosen_logps: Log probabilities of the reference model for the chosen responses. Shape: (num(chosen) in batch_size,)
-             reference_rejected_logps: Log probabilities of the reference model for the rejected responses. Shape: (num(rejected) in batch_size,)
-             reference_KL_logps: Log probabilities of the reference model for the KL responses. Shape: (batch_size,)
- 
-         Returns:
-             A tuple of four tensors: (losses, chosen_rewards, rejected_rewards, KL).
-             The losses tensor contains the KTO loss for each example in the batch.
-             The chosen_rewards and rejected_rewards tensors contain the rewards for the chosen and rejected responses, respectively.
-             The KL tensor contains the detached KL divergence estimate between the policy and reference models.
-         """
-         if self.calculate_KL:
-             kl = (policy_KL_logps - reference_KL_logps).mean().detach()
-             kl = self.accelerator.gather_for_metrics(kl).mean().clamp(min=0)
-         else:
-             kl = torch.zeros(1).to(policy_chosen_logps.device)
- 
-         # Chosen losses
-         if policy_chosen_logps.shape[0] != 0 or reference_chosen_logps.shape[0] != 0:
-             chosen_logratios = policy_chosen_logps - reference_chosen_logps
- 
-             if self.loss_type == "kto":
-                 # Eqn (7) of the KTO paper (https://huggingface.co/papers/2402.01306)
-                 chosen_losses = 1 - F.sigmoid(self.beta * (chosen_logratios - kl))
-             elif self.loss_type == "apo_zero_unpaired":
-                 # Unpaired variant of Eqn (7) of the APO paper (https://huggingface.co/papers/2408.06266)
-                 # Use this loss when you believe the chosen outputs are better than your model's default output
-                 chosen_losses = 1 - F.sigmoid(self.beta * chosen_logratios)
- 
-             chosen_rewards = self.beta * chosen_logratios.detach()
- 
-         else:
-             # lists can't be empty -- if they are, then accelerate.gather will hang
-             chosen_losses = torch.Tensor([]).to(self.accelerator.device)
-             chosen_rewards = torch.Tensor([]).to(self.accelerator.device)
- 
-         # Rejected losses
-         if policy_rejected_logps.shape[0] != 0 or reference_rejected_logps.shape[0] != 0:
-             rejected_logratios = policy_rejected_logps - reference_rejected_logps
- 
-             if self.loss_type == "kto":
-                 rejected_losses = 1 - F.sigmoid(self.beta * (kl - rejected_logratios))
-             elif self.loss_type == "apo_zero_unpaired":
-                 rejected_losses = F.sigmoid(self.beta * rejected_logratios)
- 
-             rejected_rewards = self.beta * rejected_logratios.detach()
-         else:
-             # lists can't be empty -- if they are, then accelerate.gather will hang
-             rejected_losses = torch.Tensor([]).to(self.accelerator.device)
-             rejected_rewards = torch.Tensor([]).to(self.accelerator.device)
- 
-         losses = torch.cat(
-             (self.desirable_weight * chosen_losses, self.undesirable_weight * rejected_losses),
-             0,
-         )
- 
-         return losses, chosen_rewards, rejected_rewards, kl
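In equation form, the `kto` branch above implements Eq. (7) of the KTO paper: writing $r_\theta = \log \pi_\theta(y\mid x) - \log \pi_{\mathrm{ref}}(y\mid x)$ for the log-ratio and $z$ for the clamped KL estimate, the per-example losses are $L_{\text{chosen}} = 1 - \sigma(\beta (r_\theta - z))$ and $L_{\text{rejected}} = 1 - \sigma(\beta (z - r_\theta))$, scaled by `desirable_weight` and `undesirable_weight` respectively. For example, with $\beta = 0.1$, $r_\theta = 2$ and $z = 0$, the chosen loss is $1 - \sigma(0.2) \approx 0.45$.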
- 
-     def get_batch_loss_metrics(
-         self,
-         model,
-         batch: dict[str, Union[list, torch.LongTensor]],
-     ):
-         """Compute the KTO loss and other metrics for the given batch of inputs for train or test."""
-         metrics = {}
-         batch = {k: (v.to(self.accelerator.device) if isinstance(v, torch.Tensor) else v) for k, v in batch.items()}
- 
-         forward_output = self.forward(model, batch)
-         (
-             policy_chosen_logps,
-             policy_rejected_logps,
-             policy_chosen_logits,
-             policy_rejected_logits,
-             policy_KL_logps,
-         ) = forward_output[:5]
-         if self.aux_loss_enabled:
-             aux_loss = forward_output[5]
- 
-         # if reference_logps are in the batch use them, otherwise use the reference model
-         if "reference_logps" in batch:
-             chosen_idx = [i for i in range(batch["reference_logps"].shape[0]) if batch["label"][i] is True]
-             rejected_idx = [i for i in range(batch["reference_logps"].shape[0]) if batch["label"][i] is False]
- 
-             reference_chosen_logps = batch["reference_logps"][chosen_idx, ...]
-             reference_rejected_logps = batch["reference_logps"][rejected_idx, ...]
-             if self.calculate_KL:
-                 reference_KL_logps = batch["reference_KL_logps"]
-             else:
-                 reference_KL_logps = None
-         else:
-             with torch.no_grad():
-                 if self.ref_model is None:
-                     with self.null_ref_context():
-                         (
-                             reference_chosen_logps,
-                             reference_rejected_logps,
-                             _,
-                             _,
-                             reference_KL_logps,
-                         ) = self.forward(self.model, batch)[:5]
-                 else:
-                     (
-                         reference_chosen_logps,
-                         reference_rejected_logps,
-                         _,
-                         _,
-                         reference_KL_logps,
-                     ) = self.forward(self.ref_model, batch)[:5]
- 
-         losses, chosen_rewards, rejected_rewards, kl = self.kto_loss(
-             policy_chosen_logps,
-             policy_rejected_logps,
-             policy_KL_logps,
-             reference_chosen_logps,
-             reference_rejected_logps,
-             reference_KL_logps,
-         )
-         metrics["kl"] = kl.item()
- 
-         num_chosen = torch.Tensor([len(chosen_rewards)]).to(self.accelerator.device)
-         num_rejected = torch.Tensor([len(rejected_rewards)]).to(self.accelerator.device)
- 
-         all_num_chosen = self.accelerator.gather_for_metrics(num_chosen).sum().item()
-         all_num_rejected = self.accelerator.gather_for_metrics(num_rejected).sum().item()
- 
-         if all_num_chosen > 0:
-             metrics["rewards/chosen_sum"] = (
-                 self.accelerator.gather_for_metrics(chosen_rewards.nansum()).nansum().item()
-             )
-             metrics["logps/chosen_sum"] = (
-                 self.accelerator.gather_for_metrics(policy_chosen_logps.nansum()).nansum().item()
-             )
-             metrics["logits/chosen_sum"] = (
-                 self.accelerator.gather_for_metrics(policy_chosen_logits.nansum()).nansum().item()
-             )
-             metrics["count/chosen"] = all_num_chosen
- 
-         if all_num_rejected > 0:
-             metrics["rewards/rejected_sum"] = (
-                 self.accelerator.gather_for_metrics(rejected_rewards.nansum()).nansum().item()
-             )
-             metrics["logps/rejected_sum"] = (
-                 self.accelerator.gather_for_metrics(policy_rejected_logps.nansum()).nansum().item()
-             )
-             metrics["logits/rejected_sum"] = (
-                 self.accelerator.gather_for_metrics(policy_rejected_logits.nansum()).nansum().item()
-             )
-             metrics["count/rejected"] = all_num_rejected
- 
-         loss = losses.nanmean()
-         if self.aux_loss_enabled:
-             loss += self.aux_loss_coef * aux_loss
- 
-         return loss, metrics
- 
-     def compute_loss(
-         self,
-         model: Union[PreTrainedModel, nn.Module],
-         inputs: dict[str, Union[torch.Tensor, Any]],
-         return_outputs=False,
-         num_items_in_batch=None,
-     ) -> Union[torch.Tensor, tuple[torch.Tensor, dict[str, torch.Tensor]]]:
-         compute_loss_context_manager = amp.autocast("cuda") if self._peft_has_been_casted_to_bf16 else nullcontext()
- 
-         with compute_loss_context_manager:
-             loss, metrics = self.get_batch_loss_metrics(model, inputs)
- 
-         # Make sure to move the loss to the device the original accumulating loss is at back in the `Trainer` class:
-         loss = loss.to(self.args.device)
-         # force log the metrics
-         if self.accelerator.is_main_process:
-             self.store_metrics(metrics, train_eval="train")
- 
-         if return_outputs:
-             return (loss, metrics)
-         return loss
- 
-     def store_metrics(self, metrics: dict[str, float], train_eval: Literal["train", "eval"] = "train") -> None:
-         for key, value in metrics.items():
-             self._stored_metrics[train_eval][key].append(value)
- 
-     def _get_train_sampler(self) -> Optional[torch.utils.data.Sampler]:
-         if self.train_dataset is None or not has_length(self.train_dataset):
-             return None
-         return SequentialSampler(self.train_dataset)
- 
-     def generate_from_model_and_ref(self, model, batch: dict[str, torch.LongTensor]) -> tuple[str, str]:
-         """Generate samples from the model and reference model for the given batch of inputs."""
- 
-         # If one uses `generate_during_eval` with peft + bf16, we need to explicitly call generate with
-         # the torch cuda amp context manager as some hidden states are silently casted to full precision.
-         generate_context_manager = amp.autocast("cuda") if self._peft_has_been_casted_to_bf16 else nullcontext()
- 
-         with generate_context_manager:
-             policy_output = model.generate(
-                 input_ids=batch["prompt_input_ids"],
-                 attention_mask=batch["prompt_attention_mask"],
-                 max_length=self.max_length,
-                 do_sample=True,
-                 pad_token_id=self.processing_class.pad_token_id,
-             )
- 
-             # if reference_output is in the batch use it, otherwise use the reference model
-             if "reference_output" in batch:
-                 reference_output = batch["reference_output"]
-             else:
-                 if self.ref_model is None:
-                     with self.null_ref_context():
-                         reference_output = self.model.generate(
-                             input_ids=batch["prompt_input_ids"],
-                             attention_mask=batch["prompt_attention_mask"],
-                             max_length=self.max_length,
-                             do_sample=True,
-                             pad_token_id=self.processing_class.pad_token_id,
-                         )
-                 else:
-                     reference_output = self.ref_model.generate(
-                         input_ids=batch["prompt_input_ids"],
-                         attention_mask=batch["prompt_attention_mask"],
-                         max_length=self.max_length,
-                         do_sample=True,
-                         pad_token_id=self.processing_class.pad_token_id,
-                     )
- 
-         policy_output = pad_to_length(policy_output, self.max_length, self.processing_class.pad_token_id)
-         policy_output_decoded = self.processing_class.batch_decode(policy_output, skip_special_tokens=True)
- 
-         reference_output = pad_to_length(reference_output, self.max_length, self.processing_class.pad_token_id)
-         reference_output_decoded = self.processing_class.batch_decode(reference_output, skip_special_tokens=True)
- 
-         return policy_output_decoded, reference_output_decoded
- 
-     def prediction_step(
-         self,
-         model: Union[PreTrainedModel, nn.Module],
-         inputs: dict[str, Union[torch.Tensor, Any]],
-         prediction_loss_only: bool,
-         ignore_keys: Optional[list[str]] = None,
-     ):
-         if ignore_keys is None:
-             if hasattr(model, "config"):
-                 ignore_keys = getattr(model.config, "keys_to_ignore_at_inference", [])
-             else:
-                 ignore_keys = []
- 
-         prediction_context_manager = amp.autocast("cuda") if self._peft_has_been_casted_to_bf16 else nullcontext()
-         with torch.no_grad(), prediction_context_manager:
-             loss, metrics = self.get_batch_loss_metrics(model, inputs)
- 
-         # force log the metrics
-         if self.accelerator.is_main_process:
-             self.store_metrics(metrics, train_eval="eval")
- 
-         if prediction_loss_only:
-             return (loss.detach(), None, None)
- 
-         # logits for the chosen and rejected samples from model
-         logits_dict = {}
-         if "logits/chosen_sum" in metrics:
-             logits_dict["eval_logits/chosen"] = metrics["logits/chosen_sum"]
-         if "logits/rejected_sum" in metrics:
-             logits_dict["eval_logits/rejected"] = metrics["logits/rejected_sum"]
-         logits = [v for k, v in logits_dict.items() if k not in ignore_keys]
-         logits = torch.tensor(logits, device=self.accelerator.device)
-         labels = torch.zeros(logits.shape[0], device=self.accelerator.device)
- 
-         return (loss.detach(), logits, labels)
- 
-     def evaluation_loop(
-         self,
-         dataloader: DataLoader,
-         description: str,
-         prediction_loss_only: Optional[bool] = None,
-         ignore_keys: Optional[list[str]] = None,
-         metric_key_prefix: str = "eval",
-     ) -> EvalLoopOutput:
-         """
-         Overrides the built-in evaluation loop to store metrics for each batch.
-         Prediction/evaluation loop, shared by `Trainer.evaluate()` and `Trainer.predict()`.
- 
-         Works both with and without labels.
-         """
- 
-         # Sample and save to game log if requested (for one batch to save time)
-         if self.generate_during_eval:
-             # Generate random indices within the range of the total number of samples
-             num_samples = len(dataloader.dataset)
-             random_indices = random.sample(range(num_samples), k=self.args.eval_batch_size)
- 
-             # Use dataloader.dataset.select to get the random batch without iterating over the DataLoader
-             random_batch_dataset = dataloader.dataset.select(random_indices)
-             random_batch = self.data_collator(random_batch_dataset)
-             random_batch = self._prepare_inputs(random_batch)
- 
-             target_indices = [i for i in range(len(random_batch["label"])) if random_batch["label"][i] is False]
-             target_batch = {
-                 "prompt_input_ids": random_batch["prompt_input_ids"][target_indices],
-                 "prompt_attention_mask": random_batch["prompt_attention_mask"][target_indices],
-                 "prompt": itemgetter(*target_indices)(random_batch["prompt"]),
-             }
-             policy_output_decoded, ref_output_decoded = self.generate_from_model_and_ref(self.model, target_batch)
- 
-             table = pd.DataFrame(
-                 columns=["Prompt", "Policy", "Ref Model"],
-                 data=[
-                     [prompt, pol[len(prompt) :], ref[len(prompt) :]]
-                     for prompt, pol, ref in zip(target_batch["prompt"], policy_output_decoded, ref_output_decoded)
-                 ],
-             )
-             if "wandb" in self.args.report_to:
-                 wandb.log({"game_log": wandb.Table(data=table)})
- 
-             if "comet_ml" in self.args.report_to:
-                 log_table_to_comet_experiment(
-                     name="game_log.csv",
-                     table=table,
-                 )
- 
-         # Base evaluation
-         initial_output = super().evaluation_loop(
-             dataloader, description, prediction_loss_only, ignore_keys, metric_key_prefix
-         )
- 
-         return initial_output
- 
-     def log(self, logs: dict[str, float], start_time: Optional[float] = None) -> None:
-         """
-         Log `logs` on the various objects watching training, including stored metrics.
- 
-         Args:
-             logs (`dict[str, float]`):
-                 The values to log.
-             start_time (`float` or `None`, *optional*, defaults to `None`):
-                 Start time of the training.
-         """
-         # logs either has 'loss' or 'eval_loss'
-         train_eval = "train" if "loss" in logs else "eval"
-         # train metrics should have no prefix, eval should have 'eval_'
-         prefix = "eval_" if train_eval == "eval" else ""
-         # accumulate average metrics from sums and counts
-         for split in ["chosen", "rejected"]:
-             if f"count/{split}" in self._stored_metrics[train_eval]:
-                 count_sum = torch.Tensor(self._stored_metrics[train_eval][f"count/{split}"]).sum().item()
-                 for metric in ["rewards", "logps", "logits"]:
-                     logs[f"{prefix}{metric}/{split}"] = (
-                         torch.Tensor(self._stored_metrics[train_eval][f"{metric}/{split}_sum"]).sum().item()
-                         / count_sum
-                     )
-                     # delete obsolete metric
-                     del self._stored_metrics[train_eval][f"{metric}/{split}_sum"]
-                 del self._stored_metrics[train_eval][f"count/{split}"]
-         # calculate reward margin
-         if f"{prefix}rewards/chosen" in logs and f"{prefix}rewards/rejected" in logs:
-             logs[f"{prefix}rewards/margins"] = logs[f"{prefix}rewards/chosen"] - logs[f"{prefix}rewards/rejected"]
-         # Add averaged stored metrics to logs
-         for key, metrics in self._stored_metrics[train_eval].items():
-             logs[f"{prefix}{key}"] = torch.Tensor(metrics).mean().item()
-         del self._stored_metrics[train_eval]
- 
-         if version.parse(transformers.__version__) >= version.parse("4.47.0.dev0"):
-             return super().log(logs, start_time)
-         else:  # transformers<=4.46
-             return super().log(logs)
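The sums-and-counts bookkeeping above makes the logged averages exact even when batches contribute unequal numbers of chosen and rejected examples; a toy illustration:

# stored over two logging micro-batches:
#   rewards/chosen_sum = [1.2, 0.6], count/chosen = [3, 1]
# logged rewards/chosen = (1.2 + 0.6) / (3 + 1) = 0.45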
- 
-     def create_model_card(
-         self,
-         model_name: Optional[str] = None,
-         dataset_name: Optional[str] = None,
-         tags: Union[str, list[str], None] = None,
-     ):
-         """
-         Creates a draft of a model card using the information available to the `Trainer`.
- 
-         Args:
-             model_name (`str` or `None`, *optional*, defaults to `None`):
-                 Name of the model.
-             dataset_name (`str` or `None`, *optional*, defaults to `None`):
-                 Name of the dataset used for training.
-             tags (`str`, `list[str]` or `None`, *optional*, defaults to `None`):
-                 Tags to be associated with the model card.
-         """
-         if not self.is_world_process_zero():
-             return
- 
-         if hasattr(self.model.config, "_name_or_path") and not os.path.isdir(self.model.config._name_or_path):
-             base_model = self.model.config._name_or_path
-         else:
-             base_model = None
- 
-         tags = tags or []
-         if isinstance(tags, str):
-             tags = [tags]
- 
-         if hasattr(self.model.config, "unsloth_version"):
-             tags.append("unsloth")
- 
-         citation = textwrap.dedent("""\
-         @article{ethayarajh2024kto,
-             title  = {{KTO: Model Alignment as Prospect Theoretic Optimization}},
-             author = {Kawin Ethayarajh and Winnie Xu and Niklas Muennighoff and Dan Jurafsky and Douwe Kiela},
-             year   = 2024,
-             eprint = {arXiv:2402.01306},
-         }""")
- 
-         model_card = generate_model_card(
-             base_model=base_model,
-             model_name=model_name,
-             hub_model_id=self.hub_model_id,
-             dataset_name=dataset_name,
-             tags=tags,
-             wandb_url=wandb.run.get_url() if is_wandb_available() and wandb.run is not None else None,
-             comet_url=get_comet_experiment_url(),
-             trainer_name="KTO",
-             trainer_citation=citation,
-             paper_title="KTO: Model Alignment as Prospect Theoretic Optimization",
-             paper_id="2402.01306",
-         )
- 
-         model_card.save(os.path.join(self.args.output_dir, "README.md"))
- 
- class UnslothKTOTrainer(_UnslothKTOTrainer):
-     """
- 
-     Initialize KTOTrainer.
- 
-     Args:
-         model (`transformers.PreTrainedModel`):
-             The model to train, preferably an `AutoModelForCausalLM`.
-         ref_model (`PreTrainedModelWrapper`):
-             Hugging Face transformer model with a causal language modelling head. Used for implicit reward computation and loss. If no
-             reference model is provided, the trainer will create a reference model with the same architecture as the model to be optimized.
-         args (`KTOConfig`):
-             The arguments to use for training.
-         train_dataset (`datasets.Dataset`):
-             The dataset to use for training.
-         eval_dataset (`datasets.Dataset`):
-             The dataset to use for evaluation.
-         processing_class (`PreTrainedTokenizerBase` or `BaseImageProcessor` or `FeatureExtractionMixin` or `ProcessorMixin`, *optional*):
-             Processing class used to process the data. If provided, will be used to automatically process the inputs
-             for the model, and it will be saved along the model to make it easier to rerun an interrupted training or
-             reuse the fine-tuned model.
-         data_collator (`transformers.DataCollator`, *optional*, defaults to `None`):
-             The data collator to use for training. If None is specified, the default data collator (`DPODataCollatorWithPadding`) will be used,
-             which will pad the sequences to the maximum length of the sequences in the batch, given a dataset of paired sequences.
-         model_init (`Callable[[], transformers.PreTrainedModel]`):
-             The model initializer to use for training. If None is specified, the default model initializer will be used.
-         callbacks (`list[transformers.TrainerCallback]`):
-             The callbacks to use for training.
-         optimizers (`tuple[torch.optim.Optimizer, torch.optim.lr_scheduler.LambdaLR]`):
-             The optimizer and scheduler to use for training.
-         preprocess_logits_for_metrics (`Callable[[torch.Tensor, torch.Tensor], torch.Tensor]`):
-             The function to use to preprocess the logits before computing the metrics.
-         peft_config (`dict`, defaults to `None`):
-             The PEFT configuration to use for training. If you pass a PEFT configuration, the model will be wrapped in a PEFT model.
-         compute_metrics (`Callable[[EvalPrediction], dict]`, *optional*):
-             The function to use to compute the metrics. Must take an `EvalPrediction` and return
-             a dictionary mapping strings to metric values.
-         model_adapter_name (`str`, defaults to `None`):
-             Name of the train target PEFT adapter, when using LoRA with multiple adapters.
-         ref_adapter_name (`str`, defaults to `None`):
-             Name of the reference PEFT adapter, when using LoRA with multiple adapters.
- 
-     """
-     def __init__(
-         self,
-         model = None,
-         ref_model = None,
1719
- args = None,
1720
- train_dataset = None,
1721
- eval_dataset = None,
1722
- processing_class = None,
1723
- data_collator = None,
1724
- model_init = None,
1725
- callbacks = None,
1726
- preprocess_logits_for_metrics = None,
1727
- peft_config = None,
1728
- compute_metrics = None,
1729
- model_adapter_name = None,
1730
- ref_adapter_name = None,
1731
- **kwargs
1732
- ):
1733
- if args is None: args = UnslothKTOConfig()
1734
- use_bf16 = getattr(args, 'bf16', False)
1735
- if type(use_bf16) is not bool: use_bf16 = False
1736
- use_fp16 = getattr(args, 'fp16', False)
1737
- if type(use_fp16) is not bool: use_fp16 = False
1738
- force_float32 = False
1739
- if os.environ.get('UNSLOTH_FORCE_FLOAT32', '0') == '1':
1740
- print('Unsloth: Switching to float32 training since model cannot work with float16')
1741
- force_float32 = True
1742
- mixed_precision_dtype = os.environ.get('UNSLOTH_MIXED_PRECISION', 'float32')
1743
- dtype = getattr(model.config, 'torch_dtype', None)
1744
- if dtype is None: dtype = model.get_input_embeddings().dtype
1745
- from unsloth_zoo.utils import _get_dtype
1746
- dtype = _get_dtype(dtype)
1747
- float16 = dtype == torch.float16
1748
- if not force_float32 and (float16 and use_bf16): raise TypeError('Unsloth: Model is in float16 precision but you want to use bfloat16 precision. Set fp16 to `True` and bf16 to `False`')
1749
- if not force_float32 and (not float16 and use_fp16): raise TypeError('Unsloth: Model is in bfloat16 precision but you want to use float16 precision. Set fp16 to `False` and bf16 to `True`')
1750
- if force_float32:
1751
- args.fp16 = False
1752
- args.bf16 = False
1753
- os.environ['ACCELERATE_MIXED_PRECISION'] = 'no'
1754
- elif (not use_bf16 and not use_fp16) and mixed_precision_dtype == 'float32':
1755
- args.fp16 = float16
1756
- args.bf16 = not float16
1757
- os.environ['ACCELERATE_MIXED_PRECISION'] = 'fp16' if float16 else 'bf16'
1758
- if getattr(args, 'eval_dataset', None) is not None and getattr(args, 'eval_strategy', 'no') == 'no':
1759
- args.eval_strategy = 'steps'
1760
- if getattr(args, 'eval_steps', None) is None: args.eval_steps = 0.1
1761
- ga_steps = getattr(args, 'gradient_accumulation_steps', None)
1762
- if ga_steps is not None and ga_steps > 1:
1763
- from transformers import __version__ as transformers_version
1764
- if Version(transformers_version) <= Version('4.45.2'):
1765
- print('**** Unsloth: Please use our fixed gradient_accumulation_steps by updating transformers, TRL and Unsloth!\n'
1766
- '`pip install --upgrade --no-cache-dir --force-reinstall --no-deps unsloth transformers trl unsloth_zoo`')
1767
- if getattr(args, 'eval_strategy', 'no') != 'no':
1768
- eval_bsz = getattr(args, 'per_device_eval_batch_size', 8)
1769
- if eval_bsz == 8 and args.per_device_train_batch_size < eval_bsz: args.per_device_eval_batch_size = args.per_device_train_batch_size
1770
- if getattr(args, 'eval_accumulation_steps', None) is None and ga_steps is not None: args.eval_accumulation_steps = ga_steps
1771
- fp16_full_eval = getattr(args, 'fp16_full_eval', False)
1772
- if type(fp16_full_eval) is not bool: fp16_full_eval = False
1773
- bf16_full_eval = getattr(args, 'bf16_full_eval', False)
1774
- if type(bf16_full_eval) is not bool: bf16_full_eval = False
1775
- if args.fp16 and bf16_full_eval: args.bf16_full_eval = False; args.fp16_full_eval = True
1776
- if args.bf16 and fp16_full_eval: args.bf16_full_eval = True; args.fp16_full_eval = False
1777
- if force_float32:
1778
- args.bf16_full_eval = False
1779
- args.fp16_full_eval = False
1780
- elif os.environ.get('UNSLOTH_MIXED_PRECISION', 'float32') == 'bfloat16':
1781
- args.bf16_full_eval = True
1782
- args.fp16_full_eval = False
1783
- elif not bf16_full_eval and not fp16_full_eval:
1784
- args.bf16_full_eval = args.bf16
1785
- args.fp16_full_eval = args.fp16
1786
- _output_logits = False
1787
- if locals().get('compute_metrics', None) is not None: _output_logits = True
1788
- if locals().get('preprocess_logits_for_metrics', None) is not None: _output_logits = True
1789
- if _output_logits:
1790
- os.environ['UNSLOTH_RETURN_LOGITS'] = '1'
1791
- if 'max_seq_length' not in locals() and not hasattr(args, 'max_seq_length'):
1792
- pass
1793
- else:
1794
- model_max_seq_length = getattr(model, 'max_seq_length', None)
1795
- args_max_seq_length = getattr(args, 'max_seq_length', None)
1796
- if args_max_seq_length is None and model_max_seq_length is not None:
1797
- max_seq_length = model.max_seq_length
1798
- if hasattr(args, 'max_seq_length'): args.max_seq_length = max_seq_length
1799
- if model is not None and hasattr(model, 'for_training'):
1800
- model.for_training()
1801
- if 'tokenizer' in locals() and hasattr(tokenizer, 'padding_side'): tokenizer.padding_side = 'right'
1802
- if 'processing_class' in locals():
1803
- if hasattr(processing_class, 'padding_side'): processing_class.padding_side = 'right'
1804
- if hasattr(processing_class, 'tokenizer') and hasattr(processing_class.tokenizer, 'padding_side'): processing_class.tokenizer.padding_side = 'right'
1805
- __tokenizer = processing_class if 'processing_class' in locals() else tokenizer
1806
- from unsloth_zoo.vision_utils import UnslothVisionDataCollator
1807
- if not isinstance(data_collator, UnslothVisionDataCollator):
1808
- if isinstance(data_collator, DataCollatorForSeq2Seq) and 'labels' not in train_dataset.column_names:
1809
- data_collator = TransformersDataCollatorForLanguageModeling(__tokenizer, mlm = False, mlm_probability = 0.0)
1810
- elif isinstance(data_collator, TransformersDataCollatorForLanguageModeling) and 'labels' in train_dataset.column_names:
1811
- data_collator = DataCollatorForSeq2Seq(__tokenizer)
1812
- else:
1813
- if hasattr(args, 'remove_unused_columns'): args.remove_unused_columns = False
1814
- if hasattr(args, 'dataset_text_field'): args.dataset_text_field = ''
1815
- if hasattr(args, 'dataset_kwargs'): args.dataset_kwargs = {'skip_prepare_dataset': True}
1816
- if not isinstance(data_collator, UnslothVisionDataCollator):
1817
- if not hasattr(__tokenizer, 'pad') and hasattr(__tokenizer, 'tokenizer'):
1818
- if isinstance(data_collator, DataCollatorForSeq2Seq):
1819
- data_collator = DataCollatorForSeq2Seq(__tokenizer.tokenizer)
1820
- else:
1821
- data_collator = TransformersDataCollatorForLanguageModeling(__tokenizer.tokenizer, mlm = False, mlm_probability = 0.0)
1822
- other_metrics = []
1823
-
1824
- from unsloth_zoo.logging_utils import PatchRLStatistics
1825
- PatchRLStatistics('kto_trainer', other_metrics)
1826
-
1827
- super().__init__(
1828
- model = model,
1829
- ref_model = ref_model,
1830
- args = args,
1831
- train_dataset = train_dataset,
1832
- eval_dataset = eval_dataset,
1833
- processing_class = processing_class,
1834
- data_collator = data_collator,
1835
- model_init = model_init,
1836
- callbacks = callbacks,
1837
- preprocess_logits_for_metrics = preprocess_logits_for_metrics,
1838
- peft_config = peft_config,
1839
- compute_metrics = compute_metrics,
1840
- model_adapter_name = model_adapter_name,
1841
- ref_adapter_name = ref_adapter_name,**kwargs)
1842
- if hasattr(self, 'neftune_hook_handle'):
1843
- self.neftune_hook_handle.remove()
1844
- if hasattr(self, 'neftune_hook_handle'): del self.neftune_hook_handle
1845
- if getattr(args, 'neftune_noise_alpha', None) is not None:
1846
- model.get_input_embeddings().neftune_noise_alpha = self.neftune_noise_alpha
1847
- pass
1848
-
1849
- pass
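
For reference, a minimal usage sketch for the `UnslothKTOTrainer` deleted above. Everything here is an assumption drawn from the signature in the diff, not code from this repository: `model` and `tokenizer` stand in for objects returned by unsloth's `FastLanguageModel.from_pretrained`, and the toy dataset just follows the unpaired KTO column format (`prompt`, `completion`, boolean `label`).

from datasets import Dataset

# Hypothetical unpaired KTO data: `label` marks desirable vs. undesirable completions.
train_dataset = Dataset.from_dict({
    "prompt":     ["What is 2+2?", "What is 2+2?"],
    "completion": ["4", "5"],
    "label":      [True, False],
})

# trainer = UnslothKTOTrainer(
#     model=model,                    # placeholder: a PEFT/LoRA causal LM
#     ref_model=None,                 # a reference copy is created when omitted
#     args=UnslothKTOConfig(per_device_train_batch_size=2, max_steps=10),
#     train_dataset=train_dataset,
#     processing_class=tokenizer,     # placeholder tokenizer
# )
# trainer.train()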
 
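The deleted `log` method converts per-batch metric sums and counts into averages before handing them to the base `Trainer.log`. A self-contained sketch of that bookkeeping with made-up numbers (the dictionary mimics `self._stored_metrics["train"]`):

import torch

stored = {
    "rewards/chosen_sum":   [4.0, 6.0], "count/chosen":   [2, 2],
    "rewards/rejected_sum": [1.0, 3.0], "count/rejected": [2, 2],
}
logs = {}
for split in ["chosen", "rejected"]:
    count_sum = torch.tensor(stored[f"count/{split}"]).sum().item()
    # average reward per example = total of per-batch sums / total count
    logs[f"rewards/{split}"] = torch.tensor(stored[f"rewards/{split}_sum"]).sum().item() / count_sum
logs["rewards/margins"] = logs["rewards/chosen"] - logs["rewards/rejected"]
print(logs)  # {'rewards/chosen': 2.5, 'rewards/rejected': 1.0, 'rewards/margins': 1.5}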
test_run_uploads/UnslothNashMDTrainer.py DELETED
@@ -1,969 +0,0 @@
- """
- 2025.7.11
- 2025.7.11
- 4.54.1
- 0.16.1
- __UNSLOTH_VERSIONING__
- """
- from torch import Tensor
- import torch
- import torch.nn as nn
- from torch.nn import functional as F
- from typing import Any, List, Optional, Tuple, Union, Dict, Set, Callable
- from trl.trainer.nash_md_trainer import (Any, BaseImageProcessor, BasePairwiseJudge, Callable, Dataset, EvalPrediction, F, FeatureExtractionMixin, GeometricMixtureWrapper, IterableDataset, NashMDConfig, NashMDTrainer, OnlineDPOTrainer, OptimizerNames, Optional, PreTrainedModel, PreTrainedTokenizerBase, ProcessorMixin, SIMPLE_CHAT_TEMPLATE, TrainerCallback, Union, empty_cache, generate_model_card, get_comet_experiment_url, get_reward, is_conversational, is_wandb_available, jinja2, maybe_apply_chat_template, nn, os, selective_log_softmax, textwrap, torch, truncate_right, unwrap_model_for_generation, wandb)
-
-
- import os
- from typing import *
- from dataclasses import dataclass, field
- from packaging.version import Version
- import torch
- import numpy as np
- from contextlib import nullcontext
- from torch.nn import functional as F
- from transformers import DataCollatorForSeq2Seq, DataCollatorForLanguageModeling as TransformersDataCollatorForLanguageModeling
-
- torch_compile_options = {
-     "epilogue_fusion" : True,
-     "max_autotune" : False,
-     "shape_padding" : True,
-     "trace.enabled" : False,
-     "triton.cudagraphs" : False,
- }
-
- @torch.compile(dynamic = True, fullgraph = True, options = torch_compile_options,)
- def chunked_selective_log_softmax(logits, index):
-     # Split into 4 chunks only
-     chunked_logits = torch.chunk(logits.reshape(-1, logits.shape[-1]), chunks = 4, dim = 0)
-     chunked_index = torch.chunk(index.reshape(-1), chunks = 4, dim = 0)
-     all_per_token_logps = []
-     # Below loop does the same as selective_log_softmax(chunk_logits, chunk_index)
-     for chunk_logits, chunk_index in zip(chunked_logits, chunked_index):
-         chunk_logits = chunk_logits.to(torch.float32)
-         selected_logits = torch.gather(chunk_logits, dim = -1, index = chunk_index.unsqueeze(-1)).squeeze(-1)
-         logsumexp_values = torch.logsumexp(chunk_logits, dim = -1)
-         per_token_logps = selected_logits - logsumexp_values
-         all_per_token_logps.append(per_token_logps)
-     pass
-     all_per_token_logps = torch.concat(all_per_token_logps)
-     all_per_token_logps = all_per_token_logps.reshape((logits.shape[0], logits.shape[1]))
-     return all_per_token_logps
- @dataclass
- class UnslothNashMDConfig(NashMDConfig):
-     """
-
-     Configuration class for the [`NashMDTrainer`].
-
-     Subclass of [`OnlineDPOConfig`]; we can use all its arguments and add the following:
-
-     Parameters:
-         mixture_coef (`float` or `list[float]`, *optional*, defaults to `0.5`):
-             Logit mixture coefficient for the model and reference model. If a list of floats is provided then the
-             mixture coefficient is selected for each new epoch and the last coefficient is used for the rest of the
-             epochs.
-
-     """
-     vllm_sampling_params: Optional[Any] = field(
-         default = None,
-         metadata = {'help': 'vLLM SamplingParams'},
-     )
-     unsloth_num_chunks : Optional[int] = field(
-         default = -1,
-         metadata = {'help': 'Chunk size to reduce memory usage. -1 is most efficient.'},
-     )
-     def __init__(
-         self,
-         output_dir = None,
-         overwrite_output_dir = None,
-         do_train = False,
-         do_eval = False,
-         do_predict = False,
-         eval_strategy = 'no',
-         prediction_loss_only = False,
-         per_device_train_batch_size = 4,
-         per_device_eval_batch_size = 4,
-         per_gpu_train_batch_size = None,
-         per_gpu_eval_batch_size = None,
-         gradient_accumulation_steps = 2,
-         eval_accumulation_steps = 2,
-         eval_delay = 0,
-         torch_empty_cache_steps = 250,
-         learning_rate = 5e-05,
-         weight_decay = 0.01,
-         adam_beta1 = 0.9,
-         adam_beta2 = 0.999,
-         adam_epsilon = 1e-08,
-         max_grad_norm = 1.0,
-         num_train_epochs = 3.0,
-         max_steps = -1,
-         lr_scheduler_type = 'linear',
-         warmup_ratio = 0.1,
-         warmup_steps = 0,
-         log_level = 'passive',
-         log_level_replica = 'warning',
-         log_on_each_node = True,
-         logging_dir = None,
-         logging_strategy = 'steps',
-         logging_first_step = False,
-         logging_steps = 1,
-         logging_nan_inf_filter = False,
-         save_strategy = 'steps',
-         save_steps = 500,
-         save_total_limit = None,
-         save_safetensors = True,
-         save_on_each_node = False,
-         save_only_model = False,
-         restore_callback_states_from_checkpoint = False,
-         no_cuda = False,
-         use_cpu = False,
-         use_mps_device = False,
-         seed = 3407,
-         data_seed = 3407,
-         jit_mode_eval = False,
-         use_ipex = False,
-         bf16 = False,
-         fp16 = False,
-         fp16_opt_level = 'O1',
-         half_precision_backend = 'auto',
-         bf16_full_eval = False,
-         fp16_full_eval = False,
-         tf32 = None,
-         local_rank = -1,
-         ddp_backend = None,
-         tpu_num_cores = None,
-         tpu_metrics_debug = False,
-         debug = '',
-         dataloader_drop_last = False,
-         eval_steps = None,
-         dataloader_num_workers = 0,
-         dataloader_prefetch_factor = None,
-         past_index = -1,
-         run_name = None,
-         disable_tqdm = None,
-         remove_unused_columns = True,
-         label_names = None,
-         load_best_model_at_end = False,
-         metric_for_best_model = None,
-         greater_is_better = None,
-         ignore_data_skip = False,
-         fsdp = '',
-         fsdp_min_num_params = 0,
-         fsdp_config = None,
-         fsdp_transformer_layer_cls_to_wrap = None,
-         accelerator_config = None,
-         deepspeed = None,
-         label_smoothing_factor = 0.0,
-         optim = 'adamw_8bit',
-         optim_args = None,
-         adafactor = False,
-         group_by_length = False,
-         length_column_name = 'length',
-         report_to = None,
-         ddp_find_unused_parameters = None,
-         ddp_bucket_cap_mb = None,
-         ddp_broadcast_buffers = None,
-         dataloader_pin_memory = True,
-         dataloader_persistent_workers = False,
-         skip_memory_metrics = True,
-         use_legacy_prediction_loop = False,
-         push_to_hub = False,
-         resume_from_checkpoint = None,
-         hub_model_id = None,
-         hub_strategy = 'every_save',
-         hub_token = None,
-         hub_private_repo = None,
-         hub_always_push = False,
-         hub_revision = None,
-         gradient_checkpointing = False,
-         gradient_checkpointing_kwargs = None,
-         include_inputs_for_metrics = False,
-         eval_do_concat_batches = True,
-         fp16_backend = 'auto',
-         push_to_hub_model_id = None,
-         push_to_hub_organization = None,
-         push_to_hub_token = None,
-         mp_parameters = '',
-         auto_find_batch_size = True,
-         full_determinism = False,
-         torchdynamo = None,
-         ray_scope = 'last',
-         ddp_timeout = 1800,
-         torch_compile = False,
-         torch_compile_backend = None,
-         torch_compile_mode = None,
-         include_tokens_per_second = False,
-         include_num_input_tokens_seen = False,
-         neftune_noise_alpha = None,
-         optim_target_modules = None,
-         batch_eval_metrics = False,
-         eval_on_start = False,
-         use_liger_kernel = False,
-         liger_kernel_config = None,
-         eval_use_gather_object = False,
-         average_tokens_across_devices = True,
-         reward_model_path = None,
-         judge = None,
-         max_new_tokens = 64,
-         max_length = 512,
-         temperature = 0.9,
-         missing_eos_penalty = None,
-         loss_type = 'sigmoid',
-         dataset_num_proc = None,
-         disable_dropout = True,
-         use_vllm = False,
-         ds3_gather_for_generation = True,
-         vllm_sampling_params = None,
-         unsloth_num_chunks = -1,
-         **kwargs,
-     ):
-         if learning_rate < 1e-7: raise FloatingPointError(f'Unsloth: Your learning rate of `{learning_rate}` is too small and less than 1e-7! Consider increasing it, otherwise gradient updates will be close to 0!')
-         if learning_rate > 1: raise OverflowError(f'Unsloth: Your learning rate of `{learning_rate}` is way too large > 1! Consider decreasing it to 1e-1, otherwise gradient updates will explode!')
-         if output_dir is None and save_strategy == 'steps' and save_steps == 500:
-             output_dir = 'unsloth_training_checkpoints'
-             save_strategy = 'no'
-         if dataset_num_proc is None:
-             from multiprocessing import cpu_count
-             dataset_num_proc = min(cpu_count()*2, 2)
-         if temperature <= 0:
-             raise ValueError('Unsloth: Please set a positive non-zero temperature since your results will be wrong.')
-         elif temperature >= 10:
-             raise ValueError('Unsloth: Please set a positive non-zero temperature less than 10, since sampling will be quite erratic.')
-
-
-         super().__init__(
-             output_dir = output_dir,
-             overwrite_output_dir = overwrite_output_dir,
-             do_train = do_train,
-             do_eval = do_eval,
-             do_predict = do_predict,
-             eval_strategy = eval_strategy,
-             prediction_loss_only = prediction_loss_only,
-             per_device_train_batch_size = per_device_train_batch_size,
-             per_device_eval_batch_size = per_device_eval_batch_size,
-             per_gpu_train_batch_size = per_gpu_train_batch_size,
-             per_gpu_eval_batch_size = per_gpu_eval_batch_size,
-             gradient_accumulation_steps = gradient_accumulation_steps,
-             eval_accumulation_steps = eval_accumulation_steps,
-             eval_delay = eval_delay,
-             torch_empty_cache_steps = torch_empty_cache_steps,
-             learning_rate = learning_rate,
-             weight_decay = weight_decay,
-             adam_beta1 = adam_beta1,
-             adam_beta2 = adam_beta2,
-             adam_epsilon = adam_epsilon,
-             max_grad_norm = max_grad_norm,
-             num_train_epochs = num_train_epochs,
-             max_steps = max_steps,
-             lr_scheduler_type = lr_scheduler_type,
-             warmup_ratio = warmup_ratio,
-             warmup_steps = warmup_steps,
-             log_level = log_level,
-             log_level_replica = log_level_replica,
-             log_on_each_node = log_on_each_node,
-             logging_dir = logging_dir,
-             logging_strategy = logging_strategy,
-             logging_first_step = logging_first_step,
-             logging_steps = logging_steps,
-             logging_nan_inf_filter = logging_nan_inf_filter,
-             save_strategy = save_strategy,
-             save_steps = save_steps,
-             save_total_limit = save_total_limit,
-             save_safetensors = save_safetensors,
-             save_on_each_node = save_on_each_node,
-             save_only_model = save_only_model,
-             restore_callback_states_from_checkpoint = restore_callback_states_from_checkpoint,
-             no_cuda = no_cuda,
-             use_cpu = use_cpu,
-             use_mps_device = use_mps_device,
-             seed = seed,
-             data_seed = data_seed,
-             jit_mode_eval = jit_mode_eval,
-             use_ipex = use_ipex,
-             bf16 = bf16,
-             fp16 = fp16,
-             fp16_opt_level = fp16_opt_level,
-             half_precision_backend = half_precision_backend,
-             bf16_full_eval = bf16_full_eval,
-             fp16_full_eval = fp16_full_eval,
-             tf32 = tf32,
-             local_rank = local_rank,
-             ddp_backend = ddp_backend,
-             tpu_num_cores = tpu_num_cores,
-             tpu_metrics_debug = tpu_metrics_debug,
-             debug = debug,
-             dataloader_drop_last = dataloader_drop_last,
-             eval_steps = eval_steps,
-             dataloader_num_workers = dataloader_num_workers,
-             dataloader_prefetch_factor = dataloader_prefetch_factor,
-             past_index = past_index,
-             run_name = run_name,
-             disable_tqdm = disable_tqdm,
-             remove_unused_columns = remove_unused_columns,
-             label_names = label_names,
-             load_best_model_at_end = load_best_model_at_end,
-             metric_for_best_model = metric_for_best_model,
-             greater_is_better = greater_is_better,
-             ignore_data_skip = ignore_data_skip,
-             fsdp = fsdp,
-             fsdp_min_num_params = fsdp_min_num_params,
-             fsdp_config = fsdp_config,
-             fsdp_transformer_layer_cls_to_wrap = fsdp_transformer_layer_cls_to_wrap,
-             accelerator_config = accelerator_config,
-             deepspeed = deepspeed,
-             label_smoothing_factor = label_smoothing_factor,
-             optim = optim,
-             optim_args = optim_args,
-             adafactor = adafactor,
-             group_by_length = group_by_length,
-             length_column_name = length_column_name,
-             report_to = report_to,
-             ddp_find_unused_parameters = ddp_find_unused_parameters,
-             ddp_bucket_cap_mb = ddp_bucket_cap_mb,
-             ddp_broadcast_buffers = ddp_broadcast_buffers,
-             dataloader_pin_memory = dataloader_pin_memory,
-             dataloader_persistent_workers = dataloader_persistent_workers,
-             skip_memory_metrics = skip_memory_metrics,
-             use_legacy_prediction_loop = use_legacy_prediction_loop,
-             push_to_hub = push_to_hub,
-             resume_from_checkpoint = resume_from_checkpoint,
-             hub_model_id = hub_model_id,
-             hub_strategy = hub_strategy,
-             hub_token = hub_token,
-             hub_private_repo = hub_private_repo,
-             hub_always_push = hub_always_push,
-             hub_revision = hub_revision,
-             gradient_checkpointing = gradient_checkpointing,
-             gradient_checkpointing_kwargs = gradient_checkpointing_kwargs,
-             include_inputs_for_metrics = include_inputs_for_metrics,
-             eval_do_concat_batches = eval_do_concat_batches,
-             fp16_backend = fp16_backend,
-             push_to_hub_model_id = push_to_hub_model_id,
-             push_to_hub_organization = push_to_hub_organization,
-             push_to_hub_token = push_to_hub_token,
-             mp_parameters = mp_parameters,
-             auto_find_batch_size = auto_find_batch_size,
-             full_determinism = full_determinism,
-             torchdynamo = torchdynamo,
-             ray_scope = ray_scope,
-             ddp_timeout = ddp_timeout,
-             torch_compile = torch_compile,
-             torch_compile_backend = torch_compile_backend,
-             torch_compile_mode = torch_compile_mode,
-             include_tokens_per_second = include_tokens_per_second,
-             include_num_input_tokens_seen = include_num_input_tokens_seen,
-             neftune_noise_alpha = neftune_noise_alpha,
-             optim_target_modules = optim_target_modules,
-             batch_eval_metrics = batch_eval_metrics,
-             eval_on_start = eval_on_start,
-             use_liger_kernel = use_liger_kernel,
-             liger_kernel_config = liger_kernel_config,
-             eval_use_gather_object = eval_use_gather_object,
-             average_tokens_across_devices = average_tokens_across_devices,
-             reward_model_path = reward_model_path,
-             judge = judge,
-             max_new_tokens = max_new_tokens,
-             max_length = max_length,
-             temperature = temperature,
-             missing_eos_penalty = missing_eos_penalty,
-             loss_type = loss_type,
-             dataset_num_proc = dataset_num_proc,
-             disable_dropout = disable_dropout,
-             use_vllm = use_vllm,
-             ds3_gather_for_generation = ds3_gather_for_generation,**kwargs)
-         self.vllm_sampling_params = vllm_sampling_params
-         self.unsloth_num_chunks = unsloth_num_chunks
- pass
-
- class _UnslothNashMDTrainer(OnlineDPOTrainer):
-     r""""""
-
-     _tag_names = ["trl", "nash-md"]
-
-     def __init__(
-         self,
-         model: Union[PreTrainedModel, nn.Module] = None,
-         ref_model: Union[PreTrainedModel, nn.Module] = None,
-         reward_model: Union[PreTrainedModel, nn.Module, None] = None,
-         judge: Optional[BasePairwiseJudge] = None,
-         args: Optional[NashMDConfig] = None,
-         data_collator: Optional[Callable] = None,
-         train_dataset: Optional[Union[Dataset, IterableDataset]] = None,
-         eval_dataset: Optional[Union[Dataset, dict[str, Dataset]]] = None,
-         processing_class: Optional[
-             Union[PreTrainedTokenizerBase, BaseImageProcessor, FeatureExtractionMixin, ProcessorMixin]
-         ] = None,
-         peft_config: Optional[dict] = None,
-         compute_metrics: Optional[Callable[[EvalPrediction], dict]] = None,
-         callbacks: Optional[list[TrainerCallback]] = None,
-         optimizers: tuple[torch.optim.Optimizer, torch.optim.lr_scheduler.LambdaLR] = (None, None),
-         preprocess_logits_for_metrics: Optional[Callable[[torch.Tensor, torch.Tensor], torch.Tensor]] = None,
-     ) -> None:
-         super().__init__(
-             model=model,
-             ref_model=ref_model,
-             reward_model=reward_model,
-             judge=judge,
-             args=args,
-             data_collator=data_collator,
-             train_dataset=train_dataset,
-             eval_dataset=eval_dataset,
-             processing_class=processing_class,
-             reward_processing_class=processing_class,  # for now, NashMDTrainer can't use any reward model
-             peft_config=peft_config,
-             compute_metrics=compute_metrics,
-             callbacks=callbacks,
-             optimizers=optimizers,
-             preprocess_logits_for_metrics=preprocess_logits_for_metrics,
-         )
-
-         self._mixture_coef = self.args.mixture_coef
-
-         # Overwrite the stats dictionary to include NashMD specific statistics
-         self.stats = {
-             # Remove "non_score_reward", "rlhf_reward", "scores_margin"
-             # Add "mixture_coef"
-             "loss/kl": [],
-             "objective/entropy": [],
-             "loss/score": [],
-             "rewards/probabilities": [],
-             "rewards/accuracies": [],
-             "rewards/margins": [],
-             "logps/chosen": [],
-             "logps/rejected": [],
-             "val/model_contain_eos_token": [],
-             "val/ref_contain_eos_token": [],
-             "beta": [],
-             "mixture_coef": [],
-         }
-         if self.reward_model is not None:
-             self.stats["rewards/chosen"] = []
-             self.stats["rewards/rejected"] = []
-
-     @property
-     def mixture_coef(self):
-         if isinstance(self._mixture_coef, list):
-             epoch = self.state.epoch
-             return self._mixture_coef[epoch] if epoch < len(self._mixture_coef) else self._mixture_coef[-1]
-         else:
-             return self._mixture_coef
-
-     def _generate_completions(self, model, prompts):
-         with unwrap_model_for_generation(model, self.accelerator) as unwrapped_model:
-             model_output = unwrapped_model.generate(
-                 input_ids=prompts["input_ids"],
-                 attention_mask=prompts["attention_mask"],
-                 generation_config=self.generation_config,
-             )
-
-             ref_model = model if self.ref_model is None else self.ref_model
-             with torch.no_grad(), unwrap_model_for_generation(ref_model, self.accelerator) as unwrapped_ref_model:
-                 mixture_model = GeometricMixtureWrapper(
-                     model=unwrapped_model,
-                     ref_model=unwrapped_ref_model,
-                     generation_config=self.generation_config,
-                     mixture_coef=self.mixture_coef,
-                     device=self.accelerator.device,
-                 )
-
-                 mixture_output = mixture_model.generate(
-                     input_ids=prompts["input_ids"],
-                     attention_mask=prompts["attention_mask"],
-                     generation_config=self.generation_config,
-                 )
-
-         return model_output, mixture_output
-
-     def _process_completions(self, model_output, mixture_output, prompts):
-         context_length = prompts["input_ids"].shape[1]
-
-         # Process model completions
-         model_completion_ids = model_output[:, context_length:]
-         model_completion_ids, model_completion_mask = truncate_right(
-             model_completion_ids, self.processing_class.eos_token_id, self.processing_class.pad_token_id
-         )
-         model_data = {
-             "input_ids": torch.cat((prompts["input_ids"], model_completion_ids), dim=1),
-             "attention_mask": torch.cat((prompts["attention_mask"], model_completion_mask), dim=1),
-             "raw": prompts["raw"],
-         }
-
-         # Process reference model completions
-         mixture_completion_ids = mixture_output[:, context_length:]
-         mixture_completion_ids, mixture_completion_mask = truncate_right(
-             mixture_completion_ids, self.processing_class.eos_token_id, self.processing_class.pad_token_id
-         )
-         mixture_data = {
-             "input_ids": torch.cat((prompts["input_ids"], mixture_completion_ids), dim=1),
-             "attention_mask": torch.cat((prompts["attention_mask"], mixture_completion_mask), dim=1),
-             "raw": prompts["raw"],
-         }
-
-         return model_data, mixture_data
-
-     def _compute_rewards(self, model_data, mixture_data, context_length):
-         with torch.no_grad():
-             _, model_scores, _ = get_reward(
-                 self.reward_model, model_data["input_ids"], self.processing_class.pad_token_id, context_length
-             )
-             _, mixture_scores, _ = get_reward(
-                 self.reward_model, mixture_data["input_ids"], self.processing_class.pad_token_id, context_length
-             )
-
-         # Apply EOS penalty if needed
-         if self.args.missing_eos_penalty is not None:
-             model_contain_eos = torch.any(model_data["input_ids"] == self.processing_class.eos_token_id, dim=-1)
-             mixture_contain_eos = torch.any(mixture_data["input_ids"] == self.processing_class.eos_token_id, dim=-1)
-             model_scores[~model_contain_eos] -= self.args.missing_eos_penalty
-             mixture_scores[~mixture_contain_eos] -= self.args.missing_eos_penalty
-
-         return model_scores, mixture_scores
-
-     def _compute_judge(self, model_data, mixture_data, context_length):
-         prompts = model_data["raw"]
-         model_data_completions = self.processing_class.batch_decode(
-             model_data["input_ids"][:, context_length:], skip_special_tokens=True
-         )
-         model_data_completions = [completion.strip() for completion in model_data_completions]
-
-         mixture_data_completions = self.processing_class.batch_decode(
-             mixture_data["input_ids"][:, context_length:], skip_special_tokens=True
-         )
-         mixture_data_completions = [completion.strip() for completion in mixture_data_completions]
-         if is_conversational({"prompt": prompts[0]}):
-             model_data_completions = [
-                 [{"role": "assistant", "content": completion}] for completion in model_data_completions
-             ]
-             environment = jinja2.Environment()
-             template = environment.from_string(SIMPLE_CHAT_TEMPLATE)
-             prompts = [template.render(messages=message) for message in prompts]
-             model_data_completions = [template.render(messages=completion) for completion in model_data_completions]
-
-             mixture_data_completions = [
-                 [{"role": "assistant", "content": completion}] for completion in mixture_data_completions
-             ]
-             mixture_data_completions = [
-                 template.render(messages=completion) for completion in mixture_data_completions
-             ]
-
-         probability = self.judge.judge(
-             prompts,
-             list(zip(model_data_completions, mixture_data_completions)),
-             return_scores=True,
-         )
-         return torch.tensor(probability, device=model_data["input_ids"].device)
-
-     def _compute_logprobs(self, model, model_data, context_length):
-         def compute_logprobs_for_data(m, data):
-             output = m(data["input_ids"], attention_mask=data["attention_mask"])
-             logits = output.logits[:, context_length - 1 : -1]
-             token_logprobs = selective_log_softmax(logits, data["input_ids"][:, context_length:])
-             return token_logprobs
-
-         # Compute logprobs for model completions under the model
-         model_logprobs_model_data = compute_logprobs_for_data(model, model_data)
-
-         # Compute logprobs of model completions under the reference model
-         with torch.no_grad():
-             if self.ref_model is None:
-                 with model.disable_adapter():
-                     ref_logprobs_model_data = compute_logprobs_for_data(model, model_data)
-             else:
-                 ref_logprobs_model_data = compute_logprobs_for_data(self.ref_model, model_data)
-
-         # Mask padding tokens
-         model_padding_mask = model_data["attention_mask"][:, context_length:] == 0
-         model_logprobs_model_data = model_logprobs_model_data.masked_fill(model_padding_mask, 0.0)
-         ref_logprobs_model_data = ref_logprobs_model_data.masked_fill(model_padding_mask, 0.0)
-
-         return (model_logprobs_model_data, ref_logprobs_model_data)
-
-     def _compute_losses(
-         self,
-         model_logprobs_model_data,
-         ref_logprobs_model_data,
-         probability,
-     ):
-         # reinforce score where 0.5 is a control variate
-         score = (probability - 0.5) * model_logprobs_model_data.sum(1)
-
-         # kl divergence via reinforce
-         with torch.no_grad():
-             log_ratio = model_logprobs_model_data - ref_logprobs_model_data
-             kl_div_log = log_ratio.sum(1)
-         kl_div_loss = (log_ratio * model_logprobs_model_data).sum(1)
-
-         # final loss
-         loss = self.beta * kl_div_loss - score
-
-         return loss.mean(), score, kl_div_log
-
-     def _log_statistics(
-         self,
-         model_data,
-         mixture_data,
-         model_logprobs_model_data,
-         ref_logprobs_model_data,
-         probability,
-         score,
-         kl_div,
-         context_length,
-         model_scores=None,
-         mixture_scores=None,
-     ):
-         # Helper function to gather and compute mean
-         def gather_mean(tensor):
-             return self.accelerator.gather_for_metrics(tensor).mean().item()
-
-         # Log score
-         self.stats["loss/score"].append(gather_mean(score))
-         # Log KL divergence
-         self.stats["loss/kl"].append(gather_mean(kl_div))
-
-         # Log logprobs
-         model_logprobs_model_data_sum = model_logprobs_model_data.sum(1)
-         ref_logprobs_model_data_sum = ref_logprobs_model_data.sum(1)
-
-         self.stats["logps/chosen"].append(gather_mean(model_logprobs_model_data_sum))
-         self.stats["logps/rejected"].append(gather_mean(ref_logprobs_model_data_sum))
-
-         # Log rewards
-         if self.reward_model is not None:
-             self.stats["rewards/chosen"].append(gather_mean(model_scores))
-             self.stats["rewards/rejected"].append(gather_mean(mixture_scores))
-
-         # Log probabilities
-         self.stats["rewards/probabilities"].append(gather_mean(probability))
-
-         # Calculate entropy for model data
-         entropy_model_data = -model_logprobs_model_data.sum(1)
-         self.stats["objective/entropy"].append(gather_mean(entropy_model_data))
-
-         # Calculate margins
-         margin = model_logprobs_model_data_sum - ref_logprobs_model_data_sum
-         self.stats["rewards/margins"].append(gather_mean(margin))
-
-         # Calculate accuracy
-         accuracy = (margin > 0).float()
-         self.stats["rewards/accuracies"].append(gather_mean(accuracy))
-
-         # Log EOS token statistics
-         model_eos = (model_data["input_ids"][:, context_length:] == self.processing_class.eos_token_id).any(dim=1)
-         mixture_eos = (mixture_data["input_ids"][:, context_length:] == self.processing_class.eos_token_id).any(dim=1)
-         self.stats["val/model_contain_eos_token"].append(gather_mean(model_eos.float()))
-         self.stats["val/ref_contain_eos_token"].append(gather_mean(mixture_eos.float()))
-
-         # Log beta and mixture coef
-         self.stats["beta"].append(self.beta)
-         self.stats["mixture_coef"].append(self.mixture_coef)
-
-     def training_step(
-         self, model: nn.Module, inputs: dict[str, Union[torch.Tensor, Any]], num_items_in_batch: Optional[int] = None
-     ) -> torch.Tensor:
-         model.train()
-
-         # Apply chat template and tokenize the input
-         batch_size = len(next(iter(inputs.values())))
-         prompts = inputs["prompt"]
-         inputs = [{k: v[i] for k, v in inputs.items()} for i in range(batch_size)]
-         inputs = [maybe_apply_chat_template(x, self.processing_class) for x in inputs]
-         inputs = [self.tokenize_row(x, self.model.config.is_encoder_decoder, self.processing_class) for x in inputs]
-         inputs = self.data_collator(inputs)
-
-         # need the prompt_ only
-         inputs = self._prepare_inputs(inputs)
-         context_length = inputs["prompt_input_ids"].shape[1]
-         prompts = {
-             "input_ids": inputs["prompt_input_ids"],
-             "attention_mask": inputs["prompt_attention_mask"],
-             "raw": prompts,
-         }
-         del inputs
-
-         # Sample completions from both the model and the reference model
-         model_output, mixture_output = self._generate_completions(model, prompts)
-
-         # Process model completions
-         model_data, mixture_data = self._process_completions(model_output, mixture_output, prompts)
-
-         # Compute rewards
-         if self.reward_model is not None:
-             model_scores, mixture_scores = self._compute_rewards(model_data, mixture_data, context_length)
-             # probability of the model data vs the mixture data
-             probability = F.sigmoid(model_scores - mixture_scores)
-         else:
-             model_scores, mixture_scores = None, None
-             probability = self._compute_judge(model_data, mixture_data, context_length)
-
-         # Compute logprobs
-         model_logprobs_model_data, ref_logprobs_model_data = self._compute_logprobs(model, model_data, context_length)
-
-         # Compute loss
-         loss, score, kl_div = self._compute_losses(model_logprobs_model_data, ref_logprobs_model_data, probability)
-
-         # Log everything
-         self._log_statistics(
-             model_data,
-             mixture_data,
-             model_logprobs_model_data.detach(),
-             ref_logprobs_model_data,
-             probability,
-             score.detach(),
-             kl_div.detach(),
-             context_length,
-             model_scores,
-             mixture_scores,
-         )
-
-         if (
-             self.args.torch_empty_cache_steps is not None
-             and self.state.global_step % self.args.torch_empty_cache_steps == 0
-         ):
-             empty_cache()
-
-         kwargs = {}
-         # For LOMO optimizers you need to explicitly use the learning rate
-         if self.args.optim in [OptimizerNames.LOMO, OptimizerNames.ADALOMO]:
-             kwargs["learning_rate"] = self._get_learning_rate()
-
-         if self.args.n_gpu > 1:
-             loss = loss.mean()  # mean() to average on multi-gpu parallel training
-
-         if self.use_apex:
-             with amp.scale_loss(loss, self.optimizer) as scaled_loss:
-                 scaled_loss.backward()
-         else:
-             self.accelerator.backward(loss, **kwargs)
-
-         return loss.detach() / self.args.gradient_accumulation_steps
-
-     def create_model_card(
-         self,
-         model_name: Optional[str] = None,
-         dataset_name: Optional[str] = None,
-         tags: Union[str, list[str], None] = None,
-     ):
-         """
-         Creates a draft of a model card using the information available to the `Trainer`.
-
-         Args:
-             model_name (`str` or `None`, *optional*, defaults to `None`):
-                 Name of the model.
-             dataset_name (`str` or `None`, *optional*, defaults to `None`):
-                 Name of the dataset used for training.
-             tags (`str`, `list[str]` or `None`, *optional*, defaults to `None`):
-                 Tags to be associated with the model card.
-         """
-         if not self.is_world_process_zero():
-             return
-
-         if hasattr(self.model.config, "_name_or_path") and not os.path.isdir(self.model.config._name_or_path):
-             base_model = self.model.config._name_or_path
-         else:
-             base_model = None
-
-         tags = tags or []
-         if isinstance(tags, str):
-             tags = [tags]
-
-         if hasattr(self.model.config, "unsloth_version"):
-             tags.append("unsloth")
-
-         citation = textwrap.dedent("""\
-         @inproceedings{munos2024nash,
-             title = {{Nash Learning from Human Feedback}},
-             author = {R{\'{e}}mi Munos and Michal Valko and Daniele Calandriello and Mohammad Gheshlaghi Azar and Mark Rowland and Zhaohan Daniel Guo and Yunhao Tang and Matthieu Geist and Thomas Mesnard and C{\\^{o}}me Fiegel and Andrea Michi and Marco Selvi and Sertan Girgin and Nikola Momchev and Olivier Bachem and Daniel J. Mankowitz and Doina Precup and Bilal Piot},
-             year = 2024,
-             booktitle = {Forty-first International Conference on Machine Learning, {ICML} 2024, Vienna, Austria, July 21-27, 2024},
-             publisher = {OpenReview.net},
-             url = {https://openreview.net/forum?id=Y5AmNYiyCQ}
-         }""")
-
-         model_card = generate_model_card(
-             base_model=base_model,
-             model_name=model_name,
-             hub_model_id=self.hub_model_id,
-             dataset_name=dataset_name,
-             tags=tags,
-             wandb_url=wandb.run.get_url() if is_wandb_available() and wandb.run is not None else None,
-             comet_url=get_comet_experiment_url(),
-             trainer_name="Nash-MD",
-             trainer_citation=citation,
-             paper_title="Nash Learning from Human Feedback",
-             paper_id="2312.00886",
-         )
-
-         model_card.save(os.path.join(self.args.output_dir, "README.md"))
- class UnslothNashMDTrainer(_UnslothNashMDTrainer):
-     """
-
-     Initialize NashMDTrainer as a subclass of [`OnlineDPOTrainer`].
-
-     Args:
-         model (`transformers.PreTrainedModel`):
-             The model to train, preferably an `AutoModelForCausalLM`.
-         ref_model (`PreTrainedModelWrapper`):
-             Hugging Face transformer model with a causal language modeling head. Used for implicit reward computation and loss. If no
-             reference model is provided, the trainer will create a reference model with the same architecture as the model to be optimized.
-         reward_model (`transformers.PreTrainedModel`):
-             The reward model to score completions with, preferably an `AutoModelForSequenceClassification`.
-         judge (`BasePairwiseJudge`):
-             The judge to use for pairwise comparison of model completions.
-         args (`NashMDConfig`):
-             The NashMD config arguments to use for training.
-         data_collator (`transformers.DataCollator`):
-             The data collator to use for training. If None is specified, the default data collator (`DPODataCollatorWithPadding`) will be used
-             which will pad the sequences to the maximum length of the sequences in the batch, given a dataset of paired sequences.
-         train_dataset (`datasets.Dataset`):
-             The dataset to use for training.
-         eval_dataset (`datasets.Dataset`):
-             The dataset to use for evaluation.
-         processing_class (`PreTrainedTokenizerBase` or `BaseImageProcessor` or `FeatureExtractionMixin` or `ProcessorMixin`, *optional*):
-             Processing class used to process the data. If provided, will be used to automatically process the inputs
-             for the model, and it will be saved along the model to make it easier to rerun an interrupted training or
-             reuse the fine-tuned model.
-         peft_config (`dict`):
-             The peft config to use for training.
-         compute_metrics (`Callable[[EvalPrediction], dict]`, *optional*):
-             The function to use to compute the metrics. Must take an `EvalPrediction` and return
-             a dictionary mapping strings to metric values.
-         callbacks (`list[transformers.TrainerCallback]`):
-             The callbacks to use for training.
-         optimizers (`tuple[torch.optim.Optimizer, torch.optim.lr_scheduler.LambdaLR]`):
-             The optimizer and scheduler to use for training.
-         preprocess_logits_for_metrics (`Callable[[torch.Tensor, torch.Tensor], torch.Tensor]`):
-             The function to use to preprocess the logits before computing the metrics.
-
-     """
-     def __init__(
-         self,
-         model = None,
-         ref_model = None,
-         reward_model = None,
-         judge = None,
-         args = None,
-         data_collator = None,
-         train_dataset = None,
-         eval_dataset = None,
-         processing_class = None,
-         peft_config = None,
-         compute_metrics = None,
-         callbacks = None,
-         preprocess_logits_for_metrics = None,
-         **kwargs
-     ):
-         if args is None: args = UnslothNashMDConfig()
-         use_bf16 = getattr(args, 'bf16', False)
-         if type(use_bf16) is not bool: use_bf16 = False
-         use_fp16 = getattr(args, 'fp16', False)
-         if type(use_fp16) is not bool: use_fp16 = False
-         force_float32 = False
-         if os.environ.get('UNSLOTH_FORCE_FLOAT32', '0') == '1':
-             print('Unsloth: Switching to float32 training since model cannot work with float16')
-             force_float32 = True
-         mixed_precision_dtype = os.environ.get('UNSLOTH_MIXED_PRECISION', 'float32')
-         dtype = getattr(model.config, 'torch_dtype', None)
-         if dtype is None: dtype = model.get_input_embeddings().dtype
-         from unsloth_zoo.utils import _get_dtype
-         dtype = _get_dtype(dtype)
-         float16 = dtype == torch.float16
-         if not force_float32 and (float16 and use_bf16): raise TypeError('Unsloth: Model is in float16 precision but you want to use bfloat16 precision. Set fp16 to `True` and bf16 to `False`')
-         if not force_float32 and (not float16 and use_fp16): raise TypeError('Unsloth: Model is in bfloat16 precision but you want to use float16 precision. Set fp16 to `False` and bf16 to `True`')
-         if force_float32:
-             args.fp16 = False
-             args.bf16 = False
-             os.environ['ACCELERATE_MIXED_PRECISION'] = 'no'
-         elif (not use_bf16 and not use_fp16) and mixed_precision_dtype == 'float32':
-             args.fp16 = float16
-             args.bf16 = not float16
-             os.environ['ACCELERATE_MIXED_PRECISION'] = 'fp16' if float16 else 'bf16'
-         if getattr(args, 'eval_dataset', None) is not None and getattr(args, 'eval_strategy', 'no') == 'no':
-             args.eval_strategy = 'steps'
-             if getattr(args, 'eval_steps', None) is None: args.eval_steps = 0.1
-         ga_steps = getattr(args, 'gradient_accumulation_steps', None)
-         if ga_steps is not None and ga_steps > 1:
-             from transformers import __version__ as transformers_version
-             if Version(transformers_version) <= Version('4.45.2'):
-                 print('**** Unsloth: Please use our fixed gradient_accumulation_steps by updating transformers, TRL and Unsloth!\n'
-                       '`pip install --upgrade --no-cache-dir --force-reinstall --no-deps unsloth transformers trl unsloth_zoo`')
-         if getattr(args, 'eval_strategy', 'no') != 'no':
-             eval_bsz = getattr(args, 'per_device_eval_batch_size', 8)
-             if eval_bsz == 8 and args.per_device_train_batch_size < eval_bsz: args.per_device_eval_batch_size = args.per_device_train_batch_size
-             if getattr(args, 'eval_accumulation_steps', None) is None and ga_steps is not None: args.eval_accumulation_steps = ga_steps
-         fp16_full_eval = getattr(args, 'fp16_full_eval', False)
-         if type(fp16_full_eval) is not bool: fp16_full_eval = False
-         bf16_full_eval = getattr(args, 'bf16_full_eval', False)
-         if type(bf16_full_eval) is not bool: bf16_full_eval = False
-         if args.fp16 and bf16_full_eval: args.bf16_full_eval = False; args.fp16_full_eval = True
-         if args.bf16 and fp16_full_eval: args.bf16_full_eval = True; args.fp16_full_eval = False
-         if force_float32:
-             args.bf16_full_eval = False
-             args.fp16_full_eval = False
-         elif os.environ.get('UNSLOTH_MIXED_PRECISION', 'float32') == 'bfloat16':
-             args.bf16_full_eval = True
-             args.fp16_full_eval = False
-         elif not bf16_full_eval and not fp16_full_eval:
-             args.bf16_full_eval = args.bf16
-             args.fp16_full_eval = args.fp16
-         _output_logits = False
-         if locals().get('compute_metrics', None) is not None: _output_logits = True
-         if locals().get('preprocess_logits_for_metrics', None) is not None: _output_logits = True
-         if _output_logits:
-             os.environ['UNSLOTH_RETURN_LOGITS'] = '1'
-         if 'max_seq_length' not in locals() and not hasattr(args, 'max_seq_length'):
-             pass
-         else:
-             model_max_seq_length = getattr(model, 'max_seq_length', None)
-             args_max_seq_length = getattr(args, 'max_seq_length', None)
-             if args_max_seq_length is None and model_max_seq_length is not None:
-                 max_seq_length = model.max_seq_length
-                 if hasattr(args, 'max_seq_length'): args.max_seq_length = max_seq_length
-         if model is not None and hasattr(model, 'for_training'):
-             model.for_training()
-         if 'tokenizer' in locals() and hasattr(tokenizer, 'padding_side'): tokenizer.padding_side = 'right'
-         if 'processing_class' in locals():
-             if hasattr(processing_class, 'padding_side'): processing_class.padding_side = 'right'
-             if hasattr(processing_class, 'tokenizer') and hasattr(processing_class.tokenizer, 'padding_side'): processing_class.tokenizer.padding_side = 'right'
-         __tokenizer = processing_class if 'processing_class' in locals() else tokenizer
-         from unsloth_zoo.vision_utils import UnslothVisionDataCollator
-         if not isinstance(data_collator, UnslothVisionDataCollator):
-             if isinstance(data_collator, DataCollatorForSeq2Seq) and 'labels' not in train_dataset.column_names:
-                 data_collator = TransformersDataCollatorForLanguageModeling(__tokenizer, mlm = False, mlm_probability = 0.0)
-             elif isinstance(data_collator, TransformersDataCollatorForLanguageModeling) and 'labels' in train_dataset.column_names:
-                 data_collator = DataCollatorForSeq2Seq(__tokenizer)
-         else:
-             if hasattr(args, 'remove_unused_columns'): args.remove_unused_columns = False
-             if hasattr(args, 'dataset_text_field'): args.dataset_text_field = ''
-             if hasattr(args, 'dataset_kwargs'): args.dataset_kwargs = {'skip_prepare_dataset': True}
-         if not isinstance(data_collator, UnslothVisionDataCollator):
-             if not hasattr(__tokenizer, 'pad') and hasattr(__tokenizer, 'tokenizer'):
-                 if isinstance(data_collator, DataCollatorForSeq2Seq):
-                     data_collator = DataCollatorForSeq2Seq(__tokenizer.tokenizer)
-                 else:
-                     data_collator = TransformersDataCollatorForLanguageModeling(__tokenizer.tokenizer, mlm = False, mlm_probability = 0.0)
-         other_metrics = []
-
-         from unsloth_zoo.logging_utils import PatchRLStatistics
-         PatchRLStatistics('nash_md_trainer', other_metrics)
-
-         super().__init__(
-             model = model,
-             ref_model = ref_model,
-             reward_model = reward_model,
-             judge = judge,
-             args = args,
-             data_collator = data_collator,
-             train_dataset = train_dataset,
-             eval_dataset = eval_dataset,
-             processing_class = processing_class,
-             peft_config = peft_config,
-             compute_metrics = compute_metrics,
-             callbacks = callbacks,
-             preprocess_logits_for_metrics = preprocess_logits_for_metrics,**kwargs)
-         if hasattr(self, 'neftune_hook_handle'):
-             self.neftune_hook_handle.remove()
-         if hasattr(self, 'neftune_hook_handle'): del self.neftune_hook_handle
-         if getattr(args, 'neftune_noise_alpha', None) is not None:
-             model.get_input_embeddings().neftune_noise_alpha = self.neftune_noise_alpha
-         pass
-
- pass
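
The `chunked_selective_log_softmax` helper in the file above rests on the identity log_softmax(x)[y] = x[y] - logsumexp(x); the chunking only bounds peak memory and does not change the result. A standalone check of that identity (shapes here are arbitrary):

import torch

logits = torch.randn(2, 5, 11)           # (batch, seq, vocab)
index = torch.randint(0, 11, (2, 5))     # token ids whose logprobs we want

# gather-the-logit-then-subtract-logsumexp path, as in the deleted helper
selected = torch.gather(logits, dim=-1, index=index.unsqueeze(-1)).squeeze(-1)
per_token_logps = selected - torch.logsumexp(logits, dim=-1)

# reference: index directly into the full log_softmax
reference = torch.log_softmax(logits, dim=-1).gather(-1, index.unsqueeze(-1)).squeeze(-1)
assert torch.allclose(per_token_logps, reference, atol=1e-6)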
 
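The deleted `_compute_losses` combines a REINFORCE score term, with 0.5 as a control variate on the judge probability, and a REINFORCE-style KL penalty weighted by `beta`. A toy restatement with random tensors; the shapes and the `beta` value are assumptions, and `log_ratio` is treated as detached, as the `torch.no_grad()` block above suggests:

import torch

beta = 0.1                                            # assumed KL penalty weight
model_logps = torch.randn(4, 7, requires_grad=True)   # per-token logprobs (batch, completion)
ref_logps = torch.randn(4, 7)
probability = torch.rand(4)                           # P(model completion beats mixture)

score = (probability - 0.5) * model_logps.sum(1)      # REINFORCE with 0.5 control variate
with torch.no_grad():
    log_ratio = model_logps - ref_logps               # detached "reward" for the KL estimate
kl_loss = (log_ratio * model_logps).sum(1)            # REINFORCE estimator of the KL term
loss = (beta * kl_loss - score).mean()
loss.backward()                                       # gradients flow only through model_logps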
test_run_uploads/UnslothORPOTrainer.py DELETED
@@ -1,1552 +0,0 @@
1
- """
2
- 2025.7.11
3
- 2025.7.11
4
- 4.54.1
5
- 0.16.1
6
- __UNSLOTH_VERSIONING__
7
- """
8
- from torch import Tensor
9
- import torch
10
- import torch.nn as nn
11
- from torch.nn import functional as F
12
- from typing import Any, List, Optional, Tuple, Union, Dict, Set, Callable
13
- from trl.trainer.orpo_trainer import (Any, AutoModelForCausalLM, BaseImageProcessor, Callable, DPODataCollatorWithPadding, DataCollator, DataLoader, Dataset, EvalLoopOutput, F, FeatureExtractionMixin, Literal, ORPOConfig, ORPOTrainer, Optional, PartialState, PeftModel, PreTrainedModel, PreTrainedModelWrapper, PreTrainedTokenizerBase, ProcessorMixin, Trainer, TrainerCallback, Union, add_bos_token_if_needed, add_eos_token_if_needed, amp, deepcopy, defaultdict, disable_dropout_in_model, generate_model_card, get_comet_experiment_url, inspect, is_comet_available, is_peft_available, is_torch_fx_proxy, is_torch_xla_available, is_wandb_available, log_table_to_comet_experiment, maybe_apply_chat_template, maybe_extract_prompt, nn, np, nullcontext, os, pad_to_length, pd, peft_module_casting_to_bf16, prepare_model_for_kbit_training, random, selective_log_softmax, textwrap, torch, transformers, version, wandb, warnings)
14
-
15
-
16
- import os
17
- from typing import *
18
- from dataclasses import dataclass, field
19
- from packaging.version import Version
20
- import torch
21
- import numpy as np
22
- from contextlib import nullcontext
23
- from torch.nn import functional as F
24
- from transformers import DataCollatorForSeq2Seq, DataCollatorForLanguageModeling as TransformersDataCollatorForLanguageModeling
25
-
26
- torch_compile_options = {
27
- "epilogue_fusion" : True,
28
- "max_autotune" : False,
29
- "shape_padding" : True,
30
- "trace.enabled" : False,
31
- "triton.cudagraphs" : False,
32
- }
33
-
34
- @torch.compile(dynamic = True, fullgraph = True, options = torch_compile_options,)
35
- def chunked_selective_log_softmax(logits, index):
36
- # Split into 4 chunks only
37
- chunked_logits = torch.chunk(logits.reshape(-1, logits.shape[-1]), chunks = 4, dim = 0)
38
- chunked_index = torch.chunk(index.reshape(-1), chunks = 4, dim = 0)
39
- all_per_token_logps = []
40
- # Below loop does the same as selective_log_softmax(chunk_logits, chunk_index)
41
- for chunk_logits, chunk_index in zip(chunked_logits, chunked_index):
42
- chunk_logits = chunk_logits.to(torch.float32)
43
- selected_logits = torch.gather(chunk_logits, dim = -1, index = chunk_index.unsqueeze(-1)).squeeze(-1)
44
- logsumexp_values = torch.logsumexp(chunk_logits, dim = -1)
45
- per_token_logps = selected_logits - logsumexp_values
46
- all_per_token_logps.append(per_token_logps)
47
- pass
48
- all_per_token_logps = torch.concat(all_per_token_logps)
49
- all_per_token_logps = all_per_token_logps.reshape((logits.shape[0], logits.shape[1]))
50
- return all_per_token_logps
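As a sanity check, the chunked routine above should agree with a plain log-softmax gather over the full tensor; a minimal reference sketch, assuming small illustrative shapes (the compiled function itself needs a torch.compile-capable setup):

import torch

def selective_log_softmax_ref(logits, index):
    # log-softmax over the vocab dimension, then gather the target token ids
    logps = torch.log_softmax(logits.to(torch.float32), dim = -1)
    return logps.gather(-1, index.unsqueeze(-1)).squeeze(-1)

logits = torch.randn(2, 8, 32)          # (batch, seq, vocab) - illustrative sizes
index  = torch.randint(0, 32, (2, 8))   # target token ids
ref = selective_log_softmax_ref(logits, index)
out = chunked_selective_log_softmax(logits, index)
print(torch.allclose(ref, out, atol = 1e-5))  # expected: True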
51
- @dataclass
52
- class UnslothORPOConfig(ORPOConfig):
53
- """
54
-
55
- Configuration class for the [`ORPOTrainer`].
56
-
57
- Using [`~transformers.HfArgumentParser`] we can turn this class into
58
- [argparse](https://docs.python.org/3/library/argparse#module-argparse) arguments that can be specified on the
59
- command line.
60
-
61
- Parameters:
62
- learning_rate (`float`, *optional*, defaults to `1e-6`):
63
- Initial learning rate for [`AdamW`] optimizer. The default value replaces that of
64
- [`~transformers.TrainingArguments`].
65
- max_length (`int` or `None`, *optional*, defaults to `1024`):
66
- Maximum length of the sequences (prompt + completion) in the batch. This argument is required if you want
67
- to use the default data collator.
68
- max_prompt_length (`int` or `None`, *optional*, defaults to `512`):
69
- Maximum length of the prompt. This argument is required if you want to use the default data collator.
70
- max_completion_length (`int` or `None`, *optional*, defaults to `None`):
71
- Maximum length of the completion. This argument is required if you want to use the default data collator
72
- and your model is an encoder-decoder.
73
- beta (`float`, *optional*, defaults to `0.1`):
74
- Parameter controlling the relative ratio loss weight in the ORPO loss. In the [paper](https://huggingface.co/papers/2403.07691),
75
- it is denoted by λ. In the [code](https://github.com/xfactlab/orpo), it is denoted by `alpha`.
76
- disable_dropout (`bool`, *optional*, defaults to `True`):
77
- Whether to disable dropout in the model.
78
- label_pad_token_id (`int`, *optional*, defaults to `-100`):
79
- Label pad token id. This argument is required if you want to use the default data collator.
80
- padding_value (`int` or `None`, *optional*, defaults to `None`):
81
- Padding value to use. If `None`, the padding value of the tokenizer is used.
82
- truncation_mode (`str`, *optional*, defaults to `"keep_end"`):
83
- Truncation mode to use when the prompt is too long. Possible values are `"keep_end"` or `"keep_start"`.
84
- This argument is required if you want to use the default data collator.
85
- generate_during_eval (`bool`, *optional*, defaults to `False`):
86
- If `True`, generates and logs completions from the model to W&B or Comet during evaluation.
87
- is_encoder_decoder (`bool` or `None`, *optional*, defaults to `None`):
88
- When using the `model_init` argument (callable) to instantiate the model instead of the `model` argument,
89
- you need to specify if the model returned by the callable is an encoder-decoder model.
90
- model_init_kwargs (`dict[str, Any]` or `None`, *optional*, defaults to `None`):
91
- Keyword arguments to pass to `AutoModelForCausalLM.from_pretrained` when instantiating the model from a
92
- string.
93
- dataset_num_proc (`int` or `None`, *optional*, defaults to `None`):
94
- Number of processes to use for processing the dataset.
95
-
96
- """
97
- vllm_sampling_params: Optional[Any] = field(
98
- default = None,
99
- metadata = {'help': 'vLLM SamplingParams'},
100
- )
101
- unsloth_num_chunks : Optional[int] = field(
102
- default = -1,
103
- metadata = {'help': 'Chunk size to reduce memory usage. -1 is most efficient.'},
104
- )
105
- def __init__(
106
- self,
107
- output_dir = None,
108
- overwrite_output_dir = None,
109
- do_train = False,
110
- do_eval = False,
111
- do_predict = False,
112
- eval_strategy = 'no',
113
- prediction_loss_only = False,
114
- per_device_train_batch_size = 4,
115
- per_device_eval_batch_size = 4,
116
- per_gpu_train_batch_size = None,
117
- per_gpu_eval_batch_size = None,
118
- gradient_accumulation_steps = 2,
119
- eval_accumulation_steps = 2,
120
- eval_delay = 0,
121
- torch_empty_cache_steps = 250,
122
- learning_rate = 5e-05,
123
- weight_decay = 0.01,
124
- adam_beta1 = 0.9,
125
- adam_beta2 = 0.999,
126
- adam_epsilon = 1e-08,
127
- max_grad_norm = 1.0,
128
- num_train_epochs = 3.0,
129
- max_steps = -1,
130
- lr_scheduler_type = 'linear',
131
- warmup_ratio = 0.1,
132
- warmup_steps = 0,
133
- log_level = 'passive',
134
- log_level_replica = 'warning',
135
- log_on_each_node = True,
136
- logging_dir = None,
137
- logging_strategy = 'steps',
138
- logging_first_step = False,
139
- logging_steps = 1,
140
- logging_nan_inf_filter = False,
141
- save_strategy = 'steps',
142
- save_steps = 500,
143
- save_total_limit = None,
144
- save_safetensors = True,
145
- save_on_each_node = False,
146
- save_only_model = False,
147
- restore_callback_states_from_checkpoint = False,
148
- no_cuda = False,
149
- use_cpu = False,
150
- use_mps_device = False,
151
- seed = 3407,
152
- data_seed = 3407,
153
- jit_mode_eval = False,
154
- use_ipex = False,
155
- bf16 = False,
156
- fp16 = False,
157
- fp16_opt_level = 'O1',
158
- half_precision_backend = 'auto',
159
- bf16_full_eval = False,
160
- fp16_full_eval = False,
161
- tf32 = None,
162
- local_rank = -1,
163
- ddp_backend = None,
164
- tpu_num_cores = None,
165
- tpu_metrics_debug = False,
166
- debug = '',
167
- dataloader_drop_last = False,
168
- eval_steps = None,
169
- dataloader_num_workers = 0,
170
- dataloader_prefetch_factor = None,
171
- past_index = -1,
172
- run_name = None,
173
- disable_tqdm = None,
174
- remove_unused_columns = True,
175
- label_names = None,
176
- load_best_model_at_end = False,
177
- metric_for_best_model = None,
178
- greater_is_better = None,
179
- ignore_data_skip = False,
180
- fsdp = '',
181
- fsdp_min_num_params = 0,
182
- fsdp_config = None,
183
- fsdp_transformer_layer_cls_to_wrap = None,
184
- accelerator_config = None,
185
- deepspeed = None,
186
- label_smoothing_factor = 0.0,
187
- optim = 'adamw_8bit',
188
- optim_args = None,
189
- adafactor = False,
190
- group_by_length = False,
191
- length_column_name = 'length',
192
- report_to = None,
193
- ddp_find_unused_parameters = None,
194
- ddp_bucket_cap_mb = None,
195
- ddp_broadcast_buffers = None,
196
- dataloader_pin_memory = True,
197
- dataloader_persistent_workers = False,
198
- skip_memory_metrics = True,
199
- use_legacy_prediction_loop = False,
200
- push_to_hub = False,
201
- resume_from_checkpoint = None,
202
- hub_model_id = None,
203
- hub_strategy = 'every_save',
204
- hub_token = None,
205
- hub_private_repo = None,
206
- hub_always_push = False,
207
- hub_revision = None,
208
- gradient_checkpointing = False,
209
- gradient_checkpointing_kwargs = None,
210
- include_inputs_for_metrics = False,
211
- eval_do_concat_batches = True,
212
- fp16_backend = 'auto',
213
- push_to_hub_model_id = None,
214
- push_to_hub_organization = None,
215
- push_to_hub_token = None,
216
- mp_parameters = '',
217
- auto_find_batch_size = True,
218
- full_determinism = False,
219
- torchdynamo = None,
220
- ray_scope = 'last',
221
- ddp_timeout = 1800,
222
- torch_compile = False,
223
- torch_compile_backend = None,
224
- torch_compile_mode = None,
225
- include_tokens_per_second = False,
226
- include_num_input_tokens_seen = False,
227
- neftune_noise_alpha = None,
228
- optim_target_modules = None,
229
- batch_eval_metrics = False,
230
- eval_on_start = False,
231
- use_liger_kernel = False,
232
- liger_kernel_config = None,
233
- eval_use_gather_object = False,
234
- average_tokens_across_devices = True,
235
- max_length = 1024,
236
- max_prompt_length = 512,
237
- max_completion_length = None,
238
- beta = 0.1,
239
- disable_dropout = True,
240
- label_pad_token_id = -100,
241
- padding_value = None,
242
- truncation_mode = 'keep_end',
243
- generate_during_eval = False,
244
- is_encoder_decoder = None,
245
- model_init_kwargs = None,
246
- dataset_num_proc = None,
247
- vllm_sampling_params = None,
248
- unsloth_num_chunks = -1,
249
- **kwargs,
250
- ):
251
- if learning_rate < 1e-7: raise FloatingPointError(f'Unsloth: Your learning rate of `{learning_rate}` is too small (less than 1e-7)! Consider increasing it, otherwise gradient updates will be close to 0!')
252
- if learning_rate > 1: raise OverflowError(f'Unsloth: Your learning rate of `{learning_rate}` is way too large (greater than 1)! Consider decreasing it to 1e-1, otherwise gradient updates will explode!')
253
- if output_dir is None and save_strategy == 'steps' and save_steps == 500:
254
- output_dir = 'unsloth_training_checkpoints'
255
- save_strategy = 'no'
256
- if dataset_num_proc is None:
257
- from multiprocessing import cpu_count
258
- dataset_num_proc = min(cpu_count()*2, 2)
259
-
260
- super().__init__(
261
- output_dir = output_dir,
262
- overwrite_output_dir = overwrite_output_dir,
263
- do_train = do_train,
264
- do_eval = do_eval,
265
- do_predict = do_predict,
266
- eval_strategy = eval_strategy,
267
- prediction_loss_only = prediction_loss_only,
268
- per_device_train_batch_size = per_device_train_batch_size,
269
- per_device_eval_batch_size = per_device_eval_batch_size,
270
- per_gpu_train_batch_size = per_gpu_train_batch_size,
271
- per_gpu_eval_batch_size = per_gpu_eval_batch_size,
272
- gradient_accumulation_steps = gradient_accumulation_steps,
273
- eval_accumulation_steps = eval_accumulation_steps,
274
- eval_delay = eval_delay,
275
- torch_empty_cache_steps = torch_empty_cache_steps,
276
- learning_rate = learning_rate,
277
- weight_decay = weight_decay,
278
- adam_beta1 = adam_beta1,
279
- adam_beta2 = adam_beta2,
280
- adam_epsilon = adam_epsilon,
281
- max_grad_norm = max_grad_norm,
282
- num_train_epochs = num_train_epochs,
283
- max_steps = max_steps,
284
- lr_scheduler_type = lr_scheduler_type,
285
- warmup_ratio = warmup_ratio,
286
- warmup_steps = warmup_steps,
287
- log_level = log_level,
288
- log_level_replica = log_level_replica,
289
- log_on_each_node = log_on_each_node,
290
- logging_dir = logging_dir,
291
- logging_strategy = logging_strategy,
292
- logging_first_step = logging_first_step,
293
- logging_steps = logging_steps,
294
- logging_nan_inf_filter = logging_nan_inf_filter,
295
- save_strategy = save_strategy,
296
- save_steps = save_steps,
297
- save_total_limit = save_total_limit,
298
- save_safetensors = save_safetensors,
299
- save_on_each_node = save_on_each_node,
300
- save_only_model = save_only_model,
301
- restore_callback_states_from_checkpoint = restore_callback_states_from_checkpoint,
302
- no_cuda = no_cuda,
303
- use_cpu = use_cpu,
304
- use_mps_device = use_mps_device,
305
- seed = seed,
306
- data_seed = data_seed,
307
- jit_mode_eval = jit_mode_eval,
308
- use_ipex = use_ipex,
309
- bf16 = bf16,
310
- fp16 = fp16,
311
- fp16_opt_level = fp16_opt_level,
312
- half_precision_backend = half_precision_backend,
313
- bf16_full_eval = bf16_full_eval,
314
- fp16_full_eval = fp16_full_eval,
315
- tf32 = tf32,
316
- local_rank = local_rank,
317
- ddp_backend = ddp_backend,
318
- tpu_num_cores = tpu_num_cores,
319
- tpu_metrics_debug = tpu_metrics_debug,
320
- debug = debug,
321
- dataloader_drop_last = dataloader_drop_last,
322
- eval_steps = eval_steps,
323
- dataloader_num_workers = dataloader_num_workers,
324
- dataloader_prefetch_factor = dataloader_prefetch_factor,
325
- past_index = past_index,
326
- run_name = run_name,
327
- disable_tqdm = disable_tqdm,
328
- remove_unused_columns = remove_unused_columns,
329
- label_names = label_names,
330
- load_best_model_at_end = load_best_model_at_end,
331
- metric_for_best_model = metric_for_best_model,
332
- greater_is_better = greater_is_better,
333
- ignore_data_skip = ignore_data_skip,
334
- fsdp = fsdp,
335
- fsdp_min_num_params = fsdp_min_num_params,
336
- fsdp_config = fsdp_config,
337
- fsdp_transformer_layer_cls_to_wrap = fsdp_transformer_layer_cls_to_wrap,
338
- accelerator_config = accelerator_config,
339
- deepspeed = deepspeed,
340
- label_smoothing_factor = label_smoothing_factor,
341
- optim = optim,
342
- optim_args = optim_args,
343
- adafactor = adafactor,
344
- group_by_length = group_by_length,
345
- length_column_name = length_column_name,
346
- report_to = report_to,
347
- ddp_find_unused_parameters = ddp_find_unused_parameters,
348
- ddp_bucket_cap_mb = ddp_bucket_cap_mb,
349
- ddp_broadcast_buffers = ddp_broadcast_buffers,
350
- dataloader_pin_memory = dataloader_pin_memory,
351
- dataloader_persistent_workers = dataloader_persistent_workers,
352
- skip_memory_metrics = skip_memory_metrics,
353
- use_legacy_prediction_loop = use_legacy_prediction_loop,
354
- push_to_hub = push_to_hub,
355
- resume_from_checkpoint = resume_from_checkpoint,
356
- hub_model_id = hub_model_id,
357
- hub_strategy = hub_strategy,
358
- hub_token = hub_token,
359
- hub_private_repo = hub_private_repo,
360
- hub_always_push = hub_always_push,
361
- hub_revision = hub_revision,
362
- gradient_checkpointing = gradient_checkpointing,
363
- gradient_checkpointing_kwargs = gradient_checkpointing_kwargs,
364
- include_inputs_for_metrics = include_inputs_for_metrics,
365
- eval_do_concat_batches = eval_do_concat_batches,
366
- fp16_backend = fp16_backend,
367
- push_to_hub_model_id = push_to_hub_model_id,
368
- push_to_hub_organization = push_to_hub_organization,
369
- push_to_hub_token = push_to_hub_token,
370
- mp_parameters = mp_parameters,
371
- auto_find_batch_size = auto_find_batch_size,
372
- full_determinism = full_determinism,
373
- torchdynamo = torchdynamo,
374
- ray_scope = ray_scope,
375
- ddp_timeout = ddp_timeout,
376
- torch_compile = torch_compile,
377
- torch_compile_backend = torch_compile_backend,
378
- torch_compile_mode = torch_compile_mode,
379
- include_tokens_per_second = include_tokens_per_second,
380
- include_num_input_tokens_seen = include_num_input_tokens_seen,
381
- neftune_noise_alpha = neftune_noise_alpha,
382
- optim_target_modules = optim_target_modules,
383
- batch_eval_metrics = batch_eval_metrics,
384
- eval_on_start = eval_on_start,
385
- use_liger_kernel = use_liger_kernel,
386
- liger_kernel_config = liger_kernel_config,
387
- eval_use_gather_object = eval_use_gather_object,
388
- average_tokens_across_devices = average_tokens_across_devices,
389
- max_length = max_length,
390
- max_prompt_length = max_prompt_length,
391
- max_completion_length = max_completion_length,
392
- beta = beta,
393
- disable_dropout = disable_dropout,
394
- label_pad_token_id = label_pad_token_id,
395
- padding_value = padding_value,
396
- truncation_mode = truncation_mode,
397
- generate_during_eval = generate_during_eval,
398
- is_encoder_decoder = is_encoder_decoder,
399
- model_init_kwargs = model_init_kwargs,
400
- dataset_num_proc = dataset_num_proc, **kwargs)
401
- self.vllm_sampling_params = vllm_sampling_params
402
- self.unsloth_num_chunks = unsloth_num_chunks
403
- pass
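A minimal instantiation of the config above might look as follows; every value is illustrative rather than a recommendation:

config = UnslothORPOConfig(
    output_dir = "orpo_out",
    per_device_train_batch_size = 4,
    gradient_accumulation_steps = 2,
    learning_rate = 5e-5,
    beta = 0.1,               # λ in the ORPO paper
    max_length = 1024,
    max_prompt_length = 512,
)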
404
-
405
- class _UnslothORPOTrainer(Trainer):
406
- r""""""
407
-
408
- _tag_names = ["trl", "orpo"]
409
-
410
- def __init__(
411
- self,
412
- model: Optional[Union[PreTrainedModel, nn.Module, str]] = None,
413
- args: Optional[ORPOConfig] = None,
414
- data_collator: Optional[DataCollator] = None,
415
- train_dataset: Optional[Dataset] = None,
416
- eval_dataset: Optional[Union[Dataset, dict[str, Dataset]]] = None,
417
- processing_class: Optional[
418
- Union[PreTrainedTokenizerBase, BaseImageProcessor, FeatureExtractionMixin, ProcessorMixin]
419
- ] = None,
420
- model_init: Optional[Callable[[], PreTrainedModel]] = None,
421
- callbacks: Optional[list[TrainerCallback]] = None,
422
- optimizers: tuple[torch.optim.Optimizer, torch.optim.lr_scheduler.LambdaLR] = (None, None),
423
- preprocess_logits_for_metrics: Optional[Callable[[torch.Tensor, torch.Tensor], torch.Tensor]] = None,
424
- peft_config: Optional[dict] = None,
425
- compute_metrics: Optional[Callable[[EvalLoopOutput], dict]] = None,
426
- ):
427
- if args.model_init_kwargs is None:
428
- model_init_kwargs = {}
429
- elif not isinstance(model, str):
430
- raise ValueError("You passed model_kwargs to the ORPOTrainer. But your model is already instantiated.")
431
- else:
432
- model_init_kwargs = args.model_init_kwargs
433
- torch_dtype = model_init_kwargs.get("torch_dtype")
434
- if torch_dtype is not None:
435
- # Convert to `torch.dtype` if a str is passed
436
- if isinstance(torch_dtype, str) and torch_dtype != "auto":
437
- torch_dtype = getattr(torch, torch_dtype)
438
- if torch_dtype != "auto" and not isinstance(torch_dtype, torch.dtype):
439
- raise ValueError(
440
- f"Invalid `torch_dtype` passed to the ORPOConfig. Expected a string with either `torch.dtype` or 'auto', but got {torch_dtype}."
441
- )
442
- model_init_kwargs["torch_dtype"] = torch_dtype
443
-
444
- if isinstance(model, str):
445
- model = AutoModelForCausalLM.from_pretrained(model, **model_init_kwargs)
446
-
447
- # Initialize this variable to False. This helps tracking the case when `peft_module_casting_to_bf16`
448
- # has been called in order to properly call autocast if needed.
449
- self._peft_has_been_casted_to_bf16 = False
450
-
451
- if not is_peft_available() and peft_config is not None:
452
- raise ValueError(
453
- "PEFT is not installed and you passed a `peft_config` in the trainer's kwargs, please install it to use the PEFT models"
454
- )
455
- elif is_peft_available() and peft_config is not None:
456
- # if model is a peft model and we have a peft_config, we merge and unload it first
457
- if isinstance(model, PeftModel):
458
- model = model.merge_and_unload()
459
-
460
- if getattr(model, "is_loaded_in_8bit", False) or getattr(model, "is_loaded_in_4bit", False):
461
- _support_gc_kwargs = hasattr(
462
- args, "gradient_checkpointing_kwargs"
463
- ) and "gradient_checkpointing_kwargs" in list(
464
- inspect.signature(prepare_model_for_kbit_training).parameters
465
- )
466
-
467
- prepare_model_kwargs = {"use_gradient_checkpointing": args.gradient_checkpointing}
468
-
469
- if _support_gc_kwargs:
470
- prepare_model_kwargs["gradient_checkpointing_kwargs"] = args.gradient_checkpointing_kwargs
471
-
472
- model = prepare_model_for_kbit_training(model, **prepare_model_kwargs)
473
- elif getattr(args, "gradient_checkpointing", False):
474
- # For backward compatibility with older versions of transformers
475
- if hasattr(model, "enable_input_require_grads"):
476
- model.enable_input_require_grads()
477
- else:
478
-
479
- def make_inputs_require_grad(module, input, output):
480
- output.requires_grad_(True)
481
-
482
- model.get_input_embeddings().register_forward_hook(make_inputs_require_grad)
483
-
484
- # get peft model with the given config
485
- model = model
486
- if args.bf16 and getattr(model, "is_loaded_in_4bit", False):
487
- peft_module_casting_to_bf16(model)
488
- # If args.bf16 we need to explicitly call `generate` with torch amp autocast context manager
489
- self._peft_has_been_casted_to_bf16 = True
490
-
491
- # For models that use gradient_checkpointing, we need to attach a hook that enables input
492
- # to explicitly have `requires_grad=True`, otherwise training will either silently
493
- # fail or completely fail.
494
- elif getattr(args, "gradient_checkpointing", False):
495
- # For backward compatibility with older versions of transformers
496
- if hasattr(model, "enable_input_require_grads"):
497
- model.enable_input_require_grads()
498
- else:
499
-
500
- def make_inputs_require_grad(module, input, output):
501
- output.requires_grad_(True)
502
-
503
- model.get_input_embeddings().register_forward_hook(make_inputs_require_grad)
504
-
505
- if args.generate_during_eval and not (is_wandb_available() or is_comet_available()):
506
- raise ValueError(
507
- "`generate_during_eval=True` requires Weights and Biases or Comet to be installed."
508
- " Please install `wandb` or `comet-ml` to resolve."
509
- )
510
-
511
- if model is not None:
512
- self.is_encoder_decoder = model.config.is_encoder_decoder
513
- elif args.is_encoder_decoder is None:
514
- raise ValueError("When no model is provided, you need to pass the parameter is_encoder_decoder.")
515
- else:
516
- self.is_encoder_decoder = args.is_encoder_decoder
517
-
518
- if self.is_encoder_decoder:
519
- self.decoder_start_token_id = model.config.decoder_start_token_id
520
- self.pad_token_id = model.config.pad_token_id
521
-
522
- if processing_class is None:
523
- raise ValueError("processing_class must be specified to tokenize a ORPO dataset.")
524
- if args.max_length is None:
525
- warnings.warn(
526
- "`max_length` is not set in the ORPOConfig's init"
527
- " it will default to `512` by default, but you should do it yourself in the future.",
528
- UserWarning,
529
- )
530
- max_length = 512
531
- else:
532
- max_length = args.max_length
533
- if args.max_prompt_length is None:
534
- warnings.warn(
535
- "`max_prompt_length` is not set in the ORPOConfig's init"
536
- " it will default to `128` by default, but you should do it yourself in the future.",
537
- UserWarning,
538
- )
539
- max_prompt_length = 128
540
- else:
541
- max_prompt_length = args.max_prompt_length
542
-
543
- if args.max_completion_length is None and self.is_encoder_decoder:
544
- warnings.warn(
545
- "When using an encoder decoder architecture, you should set `max_completion_length` in the ORPOConfig's init"
546
- " it will default to `128` by default, but you should do it yourself in the future.",
547
- UserWarning,
548
- )
549
- self.max_completion_length = 128
550
- else:
551
- self.max_completion_length = args.max_completion_length
552
-
553
- if data_collator is None:
554
- data_collator = DPODataCollatorWithPadding(
555
- pad_token_id=processing_class.pad_token_id,
556
- label_pad_token_id=args.label_pad_token_id,
557
- is_encoder_decoder=self.is_encoder_decoder,
558
- )
559
-
560
- if args.remove_unused_columns:
561
- args.remove_unused_columns = False
562
- # warn users
563
- warnings.warn(
564
- "When using DPODataCollatorWithPadding, you should set `remove_unused_columns=False` in your TrainingArguments"
565
- " we have set it for you, but you should do it yourself in the future.",
566
- UserWarning,
567
- )
568
-
569
- self.use_dpo_data_collator = True
570
- else:
571
- self.use_dpo_data_collator = False
572
-
573
- # Disable dropout in the model and reference model
574
- if args.disable_dropout:
575
- disable_dropout_in_model(model)
576
-
577
- self.max_length = max_length
578
- self.generate_during_eval = args.generate_during_eval
579
- self.label_pad_token_id = args.label_pad_token_id
580
- self.padding_value = args.padding_value if args.padding_value is not None else processing_class.pad_token_id
581
- self.max_prompt_length = max_prompt_length
582
- self.truncation_mode = args.truncation_mode
583
- self.processing_class = processing_class
584
-
585
- self.beta = args.beta
586
- self.aux_loss_enabled = getattr(model.config, "output_router_logits", False)
587
- self.aux_loss_coef = getattr(model.config, "router_aux_loss_coef", 0.0)
588
- if self.aux_loss_enabled and self.aux_loss_coef == 0.0:
589
- warnings.warn(
590
- "You set `output_router_logits` to `True` in the model config, but `router_aux_loss_coef` is set to "
591
- "`0.0`, meaning the auxiliary loss will not be used. Either set `router_aux_loss_coef` to a value "
592
- "greater than `0.0`, or set `output_router_logits` to `False` if you don't want to use the auxiliary "
593
- "loss.",
594
- UserWarning,
595
- )
596
-
597
- self._stored_metrics = defaultdict(lambda: defaultdict(list))
598
-
599
- # The trainer estimates the number of FLOPs [floating-point operations] using the number of elements in the
600
- # input tensor associated with the key "input_ids". However, in ORPO, the sampled data does not include the
601
- # "input_ids" key. Instead, the available keys are "prompt_input_ids", "chosen_input_ids", and
602
- # "rejected_input_ids". As a result, the trainer issues the warning: "Could not estimate the number of tokens
603
- # of the input, floating-point operations will not be computed." To suppress this warning, we set the
604
- # "estimate_tokens" key in the model's "warnings_issued" dictionary to True. This acts as a flag to indicate
605
- # that the warning has already been issued.
606
- model.warnings_issued["estimate_tokens"] = True
607
-
608
- # Compute that only on the main process for faster data processing.
609
- # see: https://github.com/huggingface/trl/pull/1255
610
- with PartialState().main_process_first():
611
- # Extract the prompt if needed, and apply the chat template if needed
612
- train_dataset = train_dataset.map(maybe_extract_prompt, num_proc=args.dataset_num_proc)
613
- train_dataset = train_dataset.map(
614
- maybe_apply_chat_template, fn_kwargs={"tokenizer": processing_class}, num_proc=args.dataset_num_proc
615
- )
616
- train_dataset = train_dataset.map(self.tokenize_row, num_proc=args.dataset_num_proc)
617
- if eval_dataset is not None:
618
- eval_dataset = eval_dataset.map(maybe_extract_prompt, num_proc=args.dataset_num_proc)
619
- eval_dataset = eval_dataset.map(
620
- maybe_apply_chat_template,
621
- fn_kwargs={"tokenizer": processing_class},
622
- num_proc=args.dataset_num_proc,
623
- )
624
- eval_dataset = eval_dataset.map(self.tokenize_row, num_proc=args.dataset_num_proc)
625
-
626
- super().__init__(
627
- model=model,
628
- args=args,
629
- data_collator=data_collator,
630
- train_dataset=train_dataset,
631
- eval_dataset=eval_dataset,
632
- processing_class=processing_class,
633
- model_init=model_init,
634
- compute_metrics=compute_metrics,
635
- callbacks=callbacks,
636
- optimizers=optimizers,
637
- preprocess_logits_for_metrics=preprocess_logits_for_metrics,
638
- )
639
-
640
- # Add tags for models that have been loaded with the correct transformers version
641
- if hasattr(self.model, "add_model_tags"):
642
- self.model.add_model_tags(self._tag_names)
643
-
644
- if not hasattr(self, "accelerator"):
645
- raise AttributeError(
646
- "Your `Trainer` does not have an `accelerator` object. Consider upgrading `transformers`."
647
- )
648
-
649
- def _prepare_deepspeed(self, model: PreTrainedModelWrapper):
650
- # Adapted from accelerate: https://github.com/huggingface/accelerate/blob/739b135f8367becb67ffaada12fe76e3aa60fefd/src/accelerate/accelerator.py#L1473
651
- deepspeed_plugin = self.accelerator.state.deepspeed_plugin
652
- config_kwargs = deepcopy(deepspeed_plugin.deepspeed_config)
653
-
654
- if model is not None:
655
- if hasattr(model, "config"):
656
- hidden_size = (
657
- max(model.config.hidden_sizes)
658
- if getattr(model.config, "hidden_sizes", None)
659
- else getattr(model.config, "hidden_size", None)
660
- )
661
- if hidden_size is not None and config_kwargs["zero_optimization"]["stage"] == 3:
662
- # Note that `stage3_prefetch_bucket_size` can produce DeepSpeed messages like: `Invalidate trace cache @ step 0: expected module 1, but got module 0`
663
- # This is expected and is not an error, see: https://github.com/microsoft/DeepSpeed/discussions/4081
664
- config_kwargs.update(
665
- {
666
- "zero_optimization.reduce_bucket_size": hidden_size * hidden_size,
667
- "zero_optimization.stage3_param_persistence_threshold": 10 * hidden_size,
668
- "zero_optimization.stage3_prefetch_bucket_size": 0.9 * hidden_size * hidden_size,
669
- }
670
- )
671
-
672
- # If ZeRO-3 is used, we shard both the active and reference model.
673
- # Otherwise, we assume the reference model fits in memory and is initialized on each device with ZeRO disabled (stage 0)
674
- if config_kwargs["zero_optimization"]["stage"] != 3:
675
- config_kwargs["zero_optimization"]["stage"] = 0
676
- model, *_ = deepspeed.initialize(model=model, config=config_kwargs)
677
- model.eval()
678
- return model
679
-
680
- def build_tokenized_answer(self, prompt, answer):
681
- """
682
- Llama tokenizer does not satisfy `enc(a + b) = enc(a) + enc(b)`.
683
- It does ensure `enc(a + b) = enc(a) + enc(a + b)[len(enc(a)):]`.
684
- Reference:
685
- https://github.com/EleutherAI/lm-evaluation-harness/pull/531#issuecomment-1595586257
686
- """
687
-
688
- full_tokenized = self.processing_class(prompt + answer, add_special_tokens=False)
689
- prompt_input_ids = self.processing_class(prompt, add_special_tokens=False)["input_ids"]
690
-
691
- answer_input_ids = full_tokenized["input_ids"][len(prompt_input_ids) :]
692
- answer_attention_mask = full_tokenized["attention_mask"][len(prompt_input_ids) :]
693
-
694
- # Concat tokens to form `enc(a) + enc(a + b)[len(enc(a)):]`
695
- full_concat_input_ids = np.concatenate([prompt_input_ids, answer_input_ids])
696
-
697
- # Prepare input tokens for token by token comparison
698
- full_input_ids = np.array(full_tokenized["input_ids"])
699
-
700
- if len(full_input_ids) != len(full_concat_input_ids):
701
- raise ValueError("Prompt input ids and answer input ids should have the same length.")
702
-
703
- # On some tokenizers, like Llama-2 tokenizer, there are occasions where tokens
704
- # can be merged together when tokenizing prompt+answer. This could result
705
- # on the last token from the prompt being different when tokenized on its own
706
- # vs when done as prompt+answer.
707
- response_token_ids_start_idx = len(prompt_input_ids)
708
-
709
- # If tokenized prompt is different than both prompt+answer, then it means the
710
- # last token has changed due to merging.
711
- if prompt_input_ids != full_tokenized["input_ids"][:response_token_ids_start_idx]:
712
- response_token_ids_start_idx -= 1
713
-
714
- prompt_input_ids = full_tokenized["input_ids"][:response_token_ids_start_idx]
715
- prompt_attention_mask = full_tokenized["attention_mask"][:response_token_ids_start_idx]
716
-
717
- if len(prompt_input_ids) != len(prompt_attention_mask):
718
- raise ValueError("Prompt input ids and attention mask should have the same length.")
719
-
720
- answer_input_ids = full_tokenized["input_ids"][response_token_ids_start_idx:]
721
- answer_attention_mask = full_tokenized["attention_mask"][response_token_ids_start_idx:]
722
-
723
- return dict(
724
- prompt_input_ids=prompt_input_ids,
725
- prompt_attention_mask=prompt_attention_mask,
726
- input_ids=answer_input_ids,
727
- attention_mask=answer_attention_mask,
728
- )
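The split-and-recheck above exists because some tokenizers merge the prompt's final token with the start of the answer; one quick way to observe whether a given tokenizer does this (the gpt2 checkpoint is only an example):

from transformers import AutoTokenizer

tok = AutoTokenizer.from_pretrained("gpt2")
prompt, answer = "Hello", " world"
ids_prompt = tok(prompt, add_special_tokens = False)["input_ids"]
ids_full = tok(prompt + answer, add_special_tokens = False)["input_ids"]
# If this prints False, the boundary token was merged and build_tokenized_answer
# moves the split point back by one token.
print(ids_full[:len(ids_prompt)] == ids_prompt)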
729
-
730
- def tokenize_row(self, feature, model: Optional[Union[PreTrainedModel, nn.Module]] = None) -> dict:
731
- """Tokenize a single row from a ORPO specific dataset.
732
-
733
- At this stage, we don't convert to PyTorch tensors yet; we just handle the truncation
734
- in case the prompt + chosen or prompt + rejected responses is/are too long. First
735
- we truncate the prompt; if we're still too long, we truncate the chosen/rejected.
736
-
737
- We also create the labels for the chosen/rejected responses, which are of length equal to
738
- the sum of the length of the prompt and the chosen/rejected response, with
739
- label_pad_token_id for the prompt tokens.
740
- """
741
- batch = {}
742
- prompt = feature["prompt"]
743
- chosen = feature["chosen"]
744
- rejected = feature["rejected"]
745
-
746
- if not self.is_encoder_decoder:
747
- # Check issues below for more details
748
- # 1. https://github.com/huggingface/trl/issues/907
749
- # 2. https://github.com/EleutherAI/lm-evaluation-harness/pull/531#issuecomment-1595586257
750
- # 3. https://github.com/LianjiaTech/BELLE/issues/337
751
-
752
- if not isinstance(prompt, str):
753
- raise ValueError(f"prompt should be an str but got {type(prompt)}")
754
- prompt_tokens = self.processing_class(prompt, add_special_tokens=False)
755
- prompt_tokens = {f"prompt_{k}": v for k, v in prompt_tokens.items()}
756
-
757
- if not isinstance(chosen, str):
758
- raise ValueError(f"chosen should be an str but got {type(chosen)}")
759
- chosen_tokens = self.build_tokenized_answer(prompt, chosen)
760
-
761
- if not isinstance(rejected, str):
762
- raise ValueError(f"rejected should be an str but got {type(rejected)}")
763
- rejected_tokens = self.build_tokenized_answer(prompt, rejected)
764
-
765
- # Last prompt token might get merged by tokenizer and
766
- # it should not be included for generation if that happens
767
- prompt_len_input_ids = len(prompt_tokens["prompt_input_ids"])
768
-
769
- chosen_prompt_len_input_ids = len(chosen_tokens["prompt_input_ids"])
770
- rejected_prompt_len_input_ids = len(rejected_tokens["prompt_input_ids"])
771
- prompt_len_input_ids = min(chosen_prompt_len_input_ids, rejected_prompt_len_input_ids)
772
-
773
- for k, v in prompt_tokens.items():
774
- prompt_tokens[k] = v[:prompt_len_input_ids]
775
-
776
- # Make sure prompts only have one different token at most,
777
- # and that their lengths only differ by 1 at most
778
- num_diff_tokens = sum(
779
- [a != b for a, b in zip(chosen_tokens["prompt_input_ids"], rejected_tokens["prompt_input_ids"])]
780
- )
781
- num_diff_len = abs(chosen_prompt_len_input_ids - rejected_prompt_len_input_ids)
782
- if num_diff_tokens > 1 or num_diff_len > 1:
783
- raise ValueError(
784
- "Chosen and rejected prompt_input_ids might only differ on the "
785
- "last token due to tokenizer merge ops."
786
- )
787
-
788
- # add BOS token to head of prompt. Avoid adding if it's already there
789
- prompt_tokens, chosen_tokens, rejected_tokens = add_bos_token_if_needed(
790
- self.processing_class.bos_token_id,
791
- prompt_len_input_ids,
792
- prompt_tokens,
793
- chosen_prompt_len_input_ids,
794
- chosen_tokens,
795
- rejected_prompt_len_input_ids,
796
- rejected_tokens,
797
- )
798
-
799
- # add EOS token to end of answer. Avoid adding if it's already there
800
- chosen_tokens, rejected_tokens = add_eos_token_if_needed(
801
- self.processing_class.eos_token_id, chosen_tokens, rejected_tokens
802
- )
803
-
804
- longer_response_length = max(len(chosen_tokens["input_ids"]), len(rejected_tokens["input_ids"]))
805
-
806
- # if combined sequence is too long, truncate the prompt
807
- for answer_tokens in [chosen_tokens, rejected_tokens, prompt_tokens]:
808
- if len(answer_tokens["prompt_input_ids"]) + longer_response_length > self.max_length:
809
- if self.truncation_mode == "keep_start":
810
- for k in ["prompt_input_ids", "prompt_attention_mask"]:
811
- answer_tokens[k] = answer_tokens[k][: self.max_prompt_length]
812
- elif self.truncation_mode == "keep_end":
813
- for k in ["prompt_input_ids", "prompt_attention_mask"]:
814
- answer_tokens[k] = answer_tokens[k][-self.max_prompt_length :]
815
- else:
816
- raise ValueError(f"Unknown truncation mode: {self.truncation_mode}")
817
-
818
- # if that's still too long, truncate the response
819
- for answer_tokens in [chosen_tokens, rejected_tokens]:
820
- if len(answer_tokens["prompt_input_ids"]) + longer_response_length > self.max_length:
821
- for k in ["input_ids", "attention_mask"]:
822
- answer_tokens[k] = answer_tokens[k][: self.max_length - self.max_prompt_length]
823
-
824
- # Create labels
825
- chosen_sequence_tokens = {
826
- k: chosen_tokens[f"prompt_{k}"] + chosen_tokens[k] for k in ["input_ids", "attention_mask"]
827
- }
828
- rejected_sequence_tokens = {
829
- k: rejected_tokens[f"prompt_{k}"] + rejected_tokens[k] for k in ["input_ids", "attention_mask"]
830
- }
831
- chosen_sequence_tokens["labels"] = chosen_sequence_tokens["input_ids"][:]
832
- chosen_sequence_tokens["labels"][: len(chosen_tokens["prompt_input_ids"])] = [
833
- self.label_pad_token_id
834
- ] * len(chosen_tokens["prompt_input_ids"])
835
- rejected_sequence_tokens["labels"] = rejected_sequence_tokens["input_ids"][:]
836
- rejected_sequence_tokens["labels"][: len(rejected_tokens["prompt_input_ids"])] = [
837
- self.label_pad_token_id
838
- ] * len(rejected_tokens["prompt_input_ids"])
839
-
840
- for k, toks in {
841
- "chosen_": chosen_sequence_tokens,
842
- "rejected_": rejected_sequence_tokens,
843
- "": prompt_tokens,
844
- }.items():
845
- for type_key, tokens in toks.items():
846
- if type_key == "token_type_ids":
847
- continue
848
- batch[f"{k}{type_key}"] = tokens
849
-
850
- else:
851
- chosen_tokens = self.processing_class(
852
- chosen, truncation=True, max_length=self.max_completion_length, add_special_tokens=True
853
- )
854
- rejected_tokens = self.processing_class(
855
- rejected, truncation=True, max_length=self.max_completion_length, add_special_tokens=True
856
- )
857
- prompt_tokens = self.processing_class(
858
- prompt, truncation=True, max_length=self.max_prompt_length, add_special_tokens=True
859
- )
860
-
861
- batch["chosen_labels"] = chosen_tokens["input_ids"]
862
- batch["rejected_labels"] = rejected_tokens["input_ids"]
863
- batch["prompt_input_ids"] = prompt_tokens["input_ids"]
864
- batch["prompt_attention_mask"] = prompt_tokens["attention_mask"]
865
-
866
- if model is not None and hasattr(model, "prepare_decoder_input_ids_from_labels"):
867
- batch["rejected_decoder_input_ids"] = model.prepare_decoder_input_ids_from_labels(
868
- labels=torch.tensor(batch["rejected_labels"])
869
- )
870
- batch["chosen_decoder_input_ids"] = model.prepare_decoder_input_ids_from_labels(
871
- labels=torch.tensor(batch["chosen_labels"])
872
- )
873
-
874
- if is_torch_xla_available():
875
- # Pad the sequences to global max_length to avoid TorchXLA recompilation
876
- for k in batch:
877
- if "labels" in k or self.is_encoder_decoder:
878
- pad_value = self.label_pad_token_id
879
- elif k.endswith("_input_ids"):
880
- pad_value = self.padding_value
881
- elif k.endswith("_attention_mask"):
882
- pad_value = 0
883
- batch[k] = batch[k] + [pad_value] * (self.max_length - len(batch[k]))
884
- return batch
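The labels built above mask the prompt with label_pad_token_id (-100) so that only response tokens contribute to the sequence log probability; in miniature, with invented token ids:

prompt_ids = [101, 2009]            # hypothetical prompt token ids
answer_ids = [2003, 2204, 102]      # hypothetical chosen-answer token ids
chosen_input_ids = prompt_ids + answer_ids
chosen_labels = [-100] * len(prompt_ids) + answer_ids
# batch["chosen_input_ids"] == [101, 2009, 2003, 2204, 102]
# batch["chosen_labels"]    == [-100, -100, 2003, 2204, 102]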
885
-
886
- @staticmethod
887
- def concatenated_inputs(
888
- batch: dict[str, Union[list, torch.LongTensor]],
889
- is_encoder_decoder: bool = False,
890
- label_pad_token_id: int = -100,
891
- padding_value: int = 0,
892
- device: Optional[torch.device] = None,
893
- ) -> dict[str, torch.LongTensor]:
894
- """Concatenate the chosen and rejected inputs into a single tensor.
895
-
896
- Args:
897
- batch: A batch of data. Must contain the keys 'chosen_input_ids' and 'rejected_input_ids', which are tensors of shape (batch_size, sequence_length).
898
- is_encoder_decoder: Whether the model is an encoder-decoder model.
899
- label_pad_token_id: The label pad token id.
900
- padding_value: The padding value to use for the concatenated inputs_ids.
901
- device: The device for the concatenated inputs.
902
-
903
- Returns:
904
- A dictionary containing the concatenated inputs under the key 'concatenated_input_ids'.
905
- """
906
- concatenated_batch = {}
907
-
908
- if is_encoder_decoder:
909
- max_length = max(batch["chosen_labels"].shape[1], batch["rejected_labels"].shape[1])
910
- else:
911
- max_length = max(batch["chosen_input_ids"].shape[1], batch["rejected_input_ids"].shape[1])
912
-
913
- for k in batch:
914
- if k.startswith("chosen") and isinstance(batch[k], torch.Tensor):
915
- if "labels" in k or is_encoder_decoder:
916
- pad_value = label_pad_token_id
917
- elif k.endswith("_input_ids"):
918
- pad_value = padding_value
919
- elif k.endswith("_attention_mask"):
920
- pad_value = 0
921
- concatenated_key = k.replace("chosen", "concatenated")
922
- concatenated_batch[concatenated_key] = pad_to_length(batch[k], max_length, pad_value=pad_value)
923
- for k in batch:
924
- if k.startswith("rejected") and isinstance(batch[k], torch.Tensor):
925
- if "labels" in k or is_encoder_decoder:
926
- pad_value = label_pad_token_id
927
- elif k.endswith("_input_ids"):
928
- pad_value = padding_value
929
- elif k.endswith("_attention_mask"):
930
- pad_value = 0
931
- concatenated_key = k.replace("rejected", "concatenated")
932
- concatenated_batch[concatenated_key] = torch.cat(
933
- (
934
- concatenated_batch[concatenated_key],
935
- pad_to_length(batch[k], max_length, pad_value=pad_value),
936
- ),
937
- dim=0,
938
- ).to(device=device)
939
-
940
- if is_encoder_decoder:
941
- concatenated_batch["concatenated_input_ids"] = batch["prompt_input_ids"].repeat(2, 1).to(device=device)
942
- concatenated_batch["concatenated_attention_mask"] = (
943
- batch["prompt_attention_mask"].repeat(2, 1).to(device=device)
944
- )
945
-
946
- return concatenated_batch
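Shape-wise, the method above right-pads the chosen and rejected sequences to a common length and stacks the chosen rows first; a small sketch with illustrative sizes (padding with 0 for brevity):

import torch
import torch.nn.functional as F

chosen   = torch.ones(2, 5, dtype = torch.long)   # (batch, chosen_seq_len)
rejected = torch.ones(2, 7, dtype = torch.long)   # (batch, rejected_seq_len)
max_len = max(chosen.shape[1], rejected.shape[1])
pad_right = lambda t: F.pad(t, (0, max_len - t.shape[1]))
concatenated = torch.cat([pad_right(chosen), pad_right(rejected)], dim = 0)
print(concatenated.shape)  # torch.Size([4, 7]): rows 0-1 chosen, rows 2-3 rejected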
947
-
948
- def odds_ratio_loss(
949
- self,
950
- policy_chosen_logps: torch.FloatTensor,
951
- policy_rejected_logps: torch.FloatTensor,
952
- ) -> tuple[torch.FloatTensor, torch.FloatTensor, torch.FloatTensor, torch.FloatTensor, torch.FloatTensor]:
953
- """Compute ORPO's odds ratio (OR) loss for a batch of policy and reference model log probabilities.
954
-
955
- Args:
956
- policy_chosen_logps: Log probabilities of the policy model for the chosen responses. Shape: (batch_size,)
957
- policy_rejected_logps: Log probabilities of the policy model for the rejected responses. Shape: (batch_size,)
958
-
959
- Returns:
960
- A tuple of three tensors: (losses, chosen_rewards, rejected_rewards).
961
- The losses tensor contains the ORPO loss for each example in the batch.
962
- The chosen_rewards and rejected_rewards tensors contain the rewards for the chosen and rejected responses, respectively.
963
- The log odds ratio of the chosen responses over the rejected responses ratio for logging purposes.
964
- The `log(sigmoid(log_odds_chosen))` for logging purposes.
965
- """
966
-
967
- # Derived from Eqs. (4) and (7) from https://huggingface.co/papers/2403.07691 by using log identities and exp(log(P(y|x))) = P(y|x)
968
- log_odds = (policy_chosen_logps - policy_rejected_logps) - (
969
- torch.log1p(-torch.exp(policy_chosen_logps)) - torch.log1p(-torch.exp(policy_rejected_logps))
970
- )
971
- ratio = F.logsigmoid(log_odds)
972
- losses = self.beta * ratio
973
-
974
- chosen_rewards = self.beta * (policy_chosen_logps.to(self.accelerator.device)).detach()
975
- rejected_rewards = self.beta * (policy_rejected_logps.to(self.accelerator.device)).detach()
976
-
977
- return losses, chosen_rewards, rejected_rewards, torch.mean(ratio), torch.mean(log_odds)
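Numerically, log_odds above is log of odds(chosen) / odds(rejected), where odds(y) = P(y|x) / (1 - P(y|x)); a worked sketch with made-up sequence probabilities:

import torch
import torch.nn.functional as F

beta = 0.1
logp_chosen = torch.log(torch.tensor([0.6]))      # assume P(chosen|x)   = 0.6
logp_rejected = torch.log(torch.tensor([0.2]))    # assume P(rejected|x) = 0.2
# log1p(-exp(logp)) == log(1 - p)
log_odds = (logp_chosen - logp_rejected) - (
    torch.log1p(-torch.exp(logp_chosen)) - torch.log1p(-torch.exp(logp_rejected))
)
ratio = F.logsigmoid(log_odds)
losses = beta * ratio    # later subtracted from the NLL, i.e. the ratio is maximized
print(log_odds.item())   # log((0.6/0.4) / (0.2/0.8)) = log(6) ≈ 1.792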
978
-
979
- @staticmethod
980
- def get_batch_logps(
981
- logits: torch.FloatTensor,
982
- labels: torch.LongTensor,
983
- average_log_prob: bool = False,
984
- label_pad_token_id: int = -100,
985
- is_encoder_decoder: bool = False,
986
- ) -> torch.FloatTensor:
987
- """Compute the log probabilities of the given labels under the given logits.
988
-
989
- Args:
990
- logits: Logits of the model (unnormalized). Shape: (batch_size, sequence_length, vocab_size)
991
- labels: Labels for which to compute the log probabilities. Label tokens with a value of label_pad_token_id are ignored. Shape: (batch_size, sequence_length)
992
- average_log_prob: If True, return the average log probability per (non-masked) token. Otherwise, return the sum of the log probabilities of the (non-masked) tokens.
993
- label_pad_token_id: The label pad token id.
994
- is_encoder_decoder: Whether the model is an encoder-decoder model.
995
-
996
- Returns:
997
- A tensor of shape (batch_size,) containing the average/sum log probabilities of the given labels under the given logits.
998
- """
999
- if logits.shape[:-1] != labels.shape:
1000
- raise ValueError("Logits (batch and sequence length dim) and labels must have the same shape.")
1001
-
1002
- if not is_encoder_decoder:
1003
- labels = labels[:, 1:].clone()
1004
- logits = logits[:, :-1, :]
1005
- loss_mask = labels != label_pad_token_id
1006
-
1007
- # dummy token; we'll ignore the losses on these tokens later
1008
- labels = torch.where(labels == label_pad_token_id, 0, labels)
1009
-
1010
- per_token_logps = selective_log_softmax(logits, labels)
1011
-
1012
- if average_log_prob:
1013
- return (per_token_logps * loss_mask).sum(-1) / loss_mask.sum(-1)
1014
- else:
1015
- return (per_token_logps * loss_mask).sum(-1)
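The method above first applies the usual causal shift (logits at step t score the label at step t + 1) and then masks padded labels; a toy replay of that masking, with invented tensor contents:

import torch

label_pad = -100
logits = torch.randn(1, 4, 10)                          # (batch, seq, vocab)
labels = torch.tensor([[label_pad, 5, 2, label_pad]])   # prompt and padding masked

shift_labels = labels[:, 1:].clone()
shift_logits = logits[:, :-1, :]
loss_mask = shift_labels != label_pad
shift_labels[~loss_mask] = 0                            # dummy ids on masked positions
per_token = torch.log_softmax(shift_logits.float(), -1).gather(
    -1, shift_labels.unsqueeze(-1)).squeeze(-1)
avg_logp = (per_token * loss_mask).sum(-1) / loss_mask.sum(-1)  # average_log_prob=True
print(avg_logp.shape)  # torch.Size([1])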
1016
-
1017
- def concatenated_forward(
1018
- self, model: nn.Module, batch: dict[str, Union[list, torch.LongTensor]]
1019
- ) -> tuple[torch.FloatTensor, torch.FloatTensor, torch.FloatTensor, torch.FloatTensor]:
1020
- """Run the given model on the given batch of inputs, concatenating the chosen and rejected inputs together.
1021
-
1022
- We do this to avoid doing two forward passes, because it's faster for FSDP.
1023
- """
1024
- concatenated_batch = self.concatenated_inputs(
1025
- batch,
1026
- is_encoder_decoder=self.is_encoder_decoder,
1027
- label_pad_token_id=self.label_pad_token_id,
1028
- padding_value=self.padding_value,
1029
- device=self.accelerator.device,
1030
- )
1031
- len_chosen = batch["chosen_labels"].shape[0]
1032
-
1033
- model_kwargs = (
1034
- {
1035
- "decoder_input_ids": self._shift_right(concatenated_batch["concatenated_labels"]),
1036
- }
1037
- if self.is_encoder_decoder
1038
- else {}
1039
- )
1040
-
1041
- if self.aux_loss_enabled:
1042
- model_kwargs["output_router_logits"] = True
1043
-
1044
- outputs = model(
1045
- concatenated_batch["concatenated_input_ids"],
1046
- attention_mask=concatenated_batch["concatenated_attention_mask"],
1047
- use_cache=False,
1048
- **model_kwargs,
1049
- )
1050
- all_logits = outputs.logits
1051
-
1052
- def cross_entropy_loss(logits, labels):
1053
- if not self.is_encoder_decoder:
1054
- # Shift so that tokens < n predict n
1055
- logits = logits[..., :-1, :].contiguous()
1056
- labels = labels[..., 1:].contiguous()
1057
- # Flatten the tokens
1058
- loss_fct = nn.CrossEntropyLoss()
1059
- logits = logits.view(-1, logits.shape[-1])
1060
- labels = labels.view(-1)
1061
- # Enable model parallelism
1062
- labels = labels.to(logits.device)
1063
- loss = loss_fct(logits, labels)
1064
- return loss
1065
-
1066
- if self.is_encoder_decoder:
1067
- labels = concatenated_batch["concatenated_labels"].clone()
1068
- else:
1069
- labels = concatenated_batch["concatenated_input_ids"].clone()
1070
- attention_mask = concatenated_batch["concatenated_attention_mask"]
1071
- labels = torch.where(attention_mask == 1, labels, self.label_pad_token_id)
1072
- # orpo chosen nll loss is computed over the full prompt and response
1073
- chosen_nll_loss = cross_entropy_loss(all_logits[:len_chosen], labels[:len_chosen])
1074
-
1075
- all_logps = self.get_batch_logps(
1076
- all_logits,
1077
- concatenated_batch["concatenated_labels"],
1078
- average_log_prob=True,
1079
- is_encoder_decoder=self.is_encoder_decoder,
1080
- label_pad_token_id=self.label_pad_token_id,
1081
- )
1082
-
1083
- chosen_logps = all_logps[:len_chosen]
1084
- rejected_logps = all_logps[len_chosen:]
1085
-
1086
- if not self.is_encoder_decoder:
1087
- chosen_logits = all_logits[:len_chosen, :-1, :]
1088
- rejected_logits = all_logits[len_chosen:, :-1, :]
1089
- else:
1090
- chosen_logits = all_logits[:len_chosen]
1091
- rejected_logits = all_logits[len_chosen:]
1092
-
1093
- if self.aux_loss_enabled:
1094
- return (chosen_logps, rejected_logps, chosen_logits, rejected_logits, chosen_nll_loss, outputs.aux_loss)
1095
-
1096
- return (chosen_logps, rejected_logps, chosen_logits, rejected_logits, chosen_nll_loss)
1097
-
1098
- def get_batch_loss_metrics(
1099
- self,
1100
- model,
1101
- batch: dict[str, Union[list, torch.LongTensor]],
1102
- train_eval: Literal["train", "eval"] = "train",
1103
- ):
1104
- """Compute the ORPO loss and other metrics for the given batch of inputs for train or test."""
1105
- metrics = {}
1106
-
1107
- forward_output = self.concatenated_forward(model, batch)
1108
- (
1109
- policy_chosen_logps,
1110
- policy_rejected_logps,
1111
- policy_chosen_logits,
1112
- policy_rejected_logits,
1113
- policy_nll_loss,
1114
- ) = forward_output[:5]
1115
- if self.aux_loss_enabled:
1116
- aux_loss = forward_output[5]
1117
-
1118
- losses, chosen_rewards, rejected_rewards, log_odds_ratio, log_odds_chosen = self.odds_ratio_loss(
1119
- policy_chosen_logps, policy_rejected_logps
1120
- )
1121
- # full ORPO loss
1122
- loss = policy_nll_loss - losses.mean()
1123
-
1124
- reward_accuracies = (chosen_rewards > rejected_rewards).float()
1125
-
1126
- prefix = "eval_" if train_eval == "eval" else ""
1127
- metrics[f"{prefix}rewards/chosen"] = self.accelerator.gather_for_metrics(chosen_rewards).mean()
1128
- metrics[f"{prefix}rewards/rejected"] = self.accelerator.gather_for_metrics(rejected_rewards).mean()
1129
- metrics[f"{prefix}rewards/accuracies"] = self.accelerator.gather_for_metrics(reward_accuracies).mean()
1130
- metrics[f"{prefix}rewards/margins"] = self.accelerator.gather_for_metrics(
1131
- chosen_rewards - rejected_rewards
1132
- ).mean()
1133
- metrics[f"{prefix}logps/rejected"] = self.accelerator.gather_for_metrics(policy_rejected_logps).detach().mean()
1134
- metrics[f"{prefix}logps/chosen"] = self.accelerator.gather_for_metrics(policy_chosen_logps).detach().mean()
1135
- metrics[f"{prefix}logits/rejected"] = (
1136
- self.accelerator.gather_for_metrics(policy_rejected_logits).detach().mean()
1137
- )
1138
- metrics[f"{prefix}logits/chosen"] = self.accelerator.gather_for_metrics(policy_chosen_logits).detach().mean()
1139
- metrics[f"{prefix}nll_loss"] = self.accelerator.gather_for_metrics(policy_nll_loss).detach().mean()
1140
- metrics[f"{prefix}log_odds_ratio"] = self.accelerator.gather_for_metrics(log_odds_ratio).mean()
1141
- metrics[f"{prefix}log_odds_chosen"] = self.accelerator.gather_for_metrics(log_odds_chosen).mean()
1142
- if is_torch_xla_available():
1143
- xm.mark_step() # needed because .item() calls
1144
- for k, v in metrics.items():
1145
- metrics[k] = v.item()
1146
- if self.aux_loss_enabled:
1147
- loss += self.aux_loss_coef * aux_loss
1148
-
1149
- return loss, metrics
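Assembled, the scalar minimized above is the chosen-sequence NLL minus the weighted odds-ratio term (beta here plays the role of λ in the ORPO paper), with the router auxiliary loss added only when enabled:

loss = policy_nll_loss - mean(beta * logsigmoid(log_odds)) [+ aux_loss_coef * aux_loss]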
1150
-
1151
- def compute_loss(
1152
- self,
1153
- model: Union[PreTrainedModel, nn.Module],
1154
- inputs: dict[str, Union[torch.Tensor, Any]],
1155
- return_outputs=False,
1156
- num_items_in_batch=None,
1157
- ) -> Union[torch.Tensor, tuple[torch.Tensor, dict[str, torch.Tensor]]]:
1158
- compute_loss_context_manager = amp.autocast("cuda") if self._peft_has_been_casted_to_bf16 else nullcontext()
1159
-
1160
- with compute_loss_context_manager:
1161
- loss, metrics = self.get_batch_loss_metrics(model, inputs, train_eval="train")
1162
-
1163
- # Make sure to move the loss to the device the original accumulating loss is at back in the `Trainer` class:
1164
- loss = loss.to(self.args.device)
1165
-
1166
- # force log the metrics
1167
- self.store_metrics(metrics, train_eval="train")
1168
-
1169
- if return_outputs:
1170
- return (loss, metrics)
1171
- return loss
1172
-
1173
- def generate_from_model(self, model, batch: dict[str, torch.LongTensor]) -> str:
1174
- """Generate samples from the model and reference model for the given batch of inputs."""
1175
-
1176
- # If one uses `generate_during_eval` with peft + bf16, we need to explicitly call generate with
1177
- # the torch cuda amp context manager as some hidden states are silently casted to full precision.
1178
- generate_context_manager = amp.autocast("cuda") if self._peft_has_been_casted_to_bf16 else nullcontext()
1179
-
1180
- with generate_context_manager:
1181
- policy_output = model.generate(
1182
- input_ids=batch["prompt_input_ids"],
1183
- attention_mask=batch["prompt_attention_mask"],
1184
- max_length=self.max_length,
1185
- do_sample=True,
1186
- pad_token_id=self.processing_class.pad_token_id,
1187
- )
1188
-
1189
- policy_output = pad_to_length(policy_output, self.max_length, self.processing_class.pad_token_id)
1190
- policy_output_decoded = self.processing_class.batch_decode(policy_output, skip_special_tokens=True)
1191
-
1192
- return policy_output_decoded
1193
-
1194
- def prediction_step(
1195
- self,
1196
- model: Union[PreTrainedModel, nn.Module],
1197
- inputs: dict[str, Union[torch.Tensor, Any]],
1198
- prediction_loss_only: bool,
1199
- ignore_keys: Optional[list[str]] = None,
1200
- ):
1201
- if not self.use_dpo_data_collator:
1202
- warnings.warn(
1203
- "prediction_step is only implemented for DPODataCollatorWithPadding, and you passed a datacollator that is different than "
1204
- "DPODataCollatorWithPadding - you might see unexpected behavior. Alternatively, you can implement your own prediction_step method if you are using a custom data collator"
1205
- )
1206
- if ignore_keys is None:
1207
- if hasattr(model, "config"):
1208
- ignore_keys = getattr(model.config, "keys_to_ignore_at_inference", [])
1209
- else:
1210
- ignore_keys = []
1211
-
1212
- prediction_context_manager = amp.autocast("cuda") if self._peft_has_been_casted_to_bf16 else nullcontext()
1213
-
1214
- with torch.no_grad(), prediction_context_manager:
1215
- loss, metrics = self.get_batch_loss_metrics(model, inputs, train_eval="eval")
1216
-
1217
- # force log the metrics
1218
- self.store_metrics(metrics, train_eval="eval")
1219
-
1220
- if prediction_loss_only:
1221
- return (loss.detach(), None, None)
1222
-
1223
- # logits for the chosen and rejected samples from model
1224
- logits_dict = {
1225
- "eval_logits/chosen": metrics["eval_logits/chosen"],
1226
- "eval_logits/rejected": metrics["eval_logits/rejected"],
1227
- }
1228
- logits = [v for k, v in logits_dict.items() if k not in ignore_keys]
1229
- logits = torch.tensor(logits, device=self.accelerator.device)
1230
- labels = torch.zeros(logits.shape[0], device=self.accelerator.device)
1231
-
1232
- return (loss.detach(), logits, labels)
1233
-
1234
- def store_metrics(self, metrics: dict[str, float], train_eval: Literal["train", "eval"] = "train") -> None:
1235
- for key, value in metrics.items():
1236
- self._stored_metrics[train_eval][key].append(value)
1237
-
1238
- def evaluation_loop(
1239
- self,
1240
- dataloader: DataLoader,
1241
- description: str,
1242
- prediction_loss_only: Optional[bool] = None,
1243
- ignore_keys: Optional[list[str]] = None,
1244
- metric_key_prefix: str = "eval",
1245
- ) -> EvalLoopOutput:
1246
- """
1247
- Overriding built-in evaluation loop to store metrics for each batch.
1248
- Prediction/evaluation loop, shared by `Trainer.evaluate()` and `Trainer.predict()`.
1249
-
1250
- Works both with or without labels.
1251
- """
1252
-
1253
- # Sample and save to game log if requested (for one batch to save time)
1254
- if self.generate_during_eval:
1255
- # Generate random indices within the range of the total number of samples
1256
- num_samples = len(dataloader.dataset)
1257
- random_indices = random.sample(range(num_samples), k=self.args.eval_batch_size)
1258
-
1259
- # Use dataloader.dataset.select to get the random batch without iterating over the DataLoader
1260
- random_batch_dataset = dataloader.dataset.select(random_indices)
1261
- random_batch = self.data_collator(random_batch_dataset)
1262
- random_batch = self._prepare_inputs(random_batch)
1263
-
1264
- policy_output_decoded = self.generate_from_model(self.model, random_batch)
1265
-
1266
- table = pd.DataFrame(
1267
- columns=["Prompt", "Policy"],
1268
- data=[
1269
- [prompt, pol[len(prompt) :]] for prompt, pol in zip(random_batch["prompt"], policy_output_decoded)
1270
- ],
1271
- )
1272
- if "wandb" in self.args.report_to:
1273
- wandb.log({"game_log": wandb.Table(data=table)})
1274
-
1275
- if "comet_ml" in self.args.report_to:
1276
- log_table_to_comet_experiment(
1277
- name="game_log.csv",
1278
- table=table,
1279
- )
1280
-
1281
- # Base evaluation
1282
- initial_output = super().evaluation_loop(
1283
- dataloader, description, prediction_loss_only, ignore_keys, metric_key_prefix
1284
- )
1285
-
1286
- return initial_output
1287
-
1288
- def log(self, logs: dict[str, float], start_time: Optional[float] = None) -> None:
1289
- """
1290
- Log `logs` on the various objects watching training, including stored metrics.
1291
-
1292
- Args:
1293
- logs (`dict[str, float]`):
1294
- The values to log.
1295
- start_time (`float` or `None`, *optional*, defaults to `None`):
1296
- Start time of the training.
1297
- """
1298
- # logs either has 'loss' or 'eval_loss'
1299
- train_eval = "train" if "loss" in logs else "eval"
1300
- # Add averaged stored metrics to logs
1301
- for key, metrics in self._stored_metrics[train_eval].items():
1302
- logs[key] = torch.tensor(metrics).mean().item()
1303
- del self._stored_metrics[train_eval]
1304
-
1305
- if version.parse(transformers.__version__) >= version.parse("4.47.0.dev0"):
1306
- return super().log(logs, start_time)
1307
- else: # transformers<=4.46
1308
- return super().log(logs)
1309
-
1310
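The store-then-average pattern used by `store_metrics` and `log` above is easiest to see in isolation. A minimal standalone sketch follows; the `defaultdict` layout is an assumption consistent with how the methods index `self._stored_metrics`, not code from this file:

from collections import defaultdict
import torch

stored = {"train": defaultdict(list), "eval": defaultdict(list)}

def store_metrics(metrics, train_eval="train"):
    for key, value in metrics.items():
        stored[train_eval][key].append(value)

# Three training steps each record a per-step metric value...
for v in (0.9, 1.1, 1.0):
    store_metrics({"rewards/margins": v})

# ...and at log time the lists collapse to one mean per key, then reset.
logs = {key: torch.tensor(vals).mean().item() for key, vals in stored["train"].items()}
stored["train"].clear()
print(logs)  # {'rewards/margins': 1.0}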
-    def _shift_right(self, input_ids):
-        if self.decoder_start_token_id is None:
-            raise ValueError(
-                "model.config.decoder_start_token_id has to be defined. It is usually set to the pad_token_id."
-            )
-
-        # shift inputs to the right
-        if is_torch_fx_proxy(input_ids):
-            # Item assignment is not supported natively for proxies.
-            shifted_input_ids = torch.full(input_ids.shape[:-1] + (1,), self.decoder_start_token_id)
-            shifted_input_ids = torch.cat([shifted_input_ids, input_ids[..., :-1]], dim=-1)
-        else:
-            shifted_input_ids = input_ids.new_zeros(input_ids.shape)
-            shifted_input_ids[..., 1:] = input_ids[..., :-1].clone()
-            shifted_input_ids[..., 0] = self.decoder_start_token_id
-
-        if self.pad_token_id is None:
-            raise ValueError("model.config.pad_token_id has to be defined.")
-        # replace possible -100 values in labels by `pad_token_id`
-        shifted_input_ids.masked_fill_(shifted_input_ids == -100, self.pad_token_id)
-
-        return shifted_input_ids
-
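A worked example of the right-shift above, outside the trainer. With decoder_start_token_id=0 and pad_token_id=1 (arbitrary values for illustration), labels [[5, 6, 7, -100]] become decoder inputs [[0, 5, 6, 7]]: the last label falls off the end, a start token is prepended, and any surviving -100 ignore-index would be replaced by the pad id.

import torch

decoder_start_token_id, pad_token_id = 0, 1
labels = torch.tensor([[5, 6, 7, -100]])

shifted = labels.new_zeros(labels.shape)
shifted[..., 1:] = labels[..., :-1].clone()   # drop last, move right by one
shifted[..., 0] = decoder_start_token_id      # start-of-sequence token in front
shifted.masked_fill_(shifted == -100, pad_token_id)
print(shifted)  # tensor([[0, 5, 6, 7]])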
-    def create_model_card(
-        self,
-        model_name: Optional[str] = None,
-        dataset_name: Optional[str] = None,
-        tags: Union[str, list[str], None] = None,
-    ):
-        """
-        Creates a draft of a model card using the information available to the `Trainer`.
-
-        Args:
-            model_name (`str` or `None`, *optional*, defaults to `None`):
-                Name of the model.
-            dataset_name (`str` or `None`, *optional*, defaults to `None`):
-                Name of the dataset used for training.
-            tags (`str`, `list[str]` or `None`, *optional*, defaults to `None`):
-                Tags to be associated with the model card.
-        """
-        if not self.is_world_process_zero():
-            return
-
-        if hasattr(self.model.config, "_name_or_path") and not os.path.isdir(self.model.config._name_or_path):
-            base_model = self.model.config._name_or_path
-        else:
-            base_model = None
-
-        tags = tags or []
-        if isinstance(tags, str):
-            tags = [tags]
-
-        if hasattr(self.model.config, "unsloth_version"):
-            tags.append("unsloth")
-
-        citation = textwrap.dedent("""\
-            @article{hong2024orpo,
-                title        = {{ORPO: Monolithic Preference Optimization without Reference Model}},
-                author       = {Jiwoo Hong and Noah Lee and James Thorne},
-                year         = 2024,
-                eprint       = {arXiv:2403.07691}
-            }""")
-
-        model_card = generate_model_card(
-            base_model=base_model,
-            model_name=model_name,
-            hub_model_id=self.hub_model_id,
-            dataset_name=dataset_name,
-            tags=tags,
-            wandb_url=wandb.run.get_url() if is_wandb_available() and wandb.run is not None else None,
-            comet_url=get_comet_experiment_url(),
-            trainer_name="ORPO",
-            trainer_citation=citation,
-            paper_title="ORPO: Monolithic Preference Optimization without Reference Model",
-            paper_id="2403.07691",
-        )
-
-        model_card.save(os.path.join(self.args.output_dir, "README.md"))
-class UnslothORPOTrainer(_UnslothORPOTrainer):
-    """
-
-    Initialize ORPOTrainer.
-
-    Args:
-        model (`transformers.PreTrainedModel`):
-            The model to train, preferably an `AutoModelForSequenceClassification`.
-        args (`ORPOConfig`):
-            The ORPO config arguments to use for training.
-        data_collator (`transformers.DataCollator`):
-            The data collator to use for training. If None is specified, the default data collator (`DPODataCollatorWithPadding`) will be used
-            which will pad the sequences to the maximum length of the sequences in the batch, given a dataset of paired sequences.
-        train_dataset (`datasets.Dataset`):
-            The dataset to use for training.
-        eval_dataset (`datasets.Dataset`):
-            The dataset to use for evaluation.
-        processing_class (`PreTrainedTokenizerBase` or `BaseImageProcessor` or `FeatureExtractionMixin` or `ProcessorMixin`, *optional*):
-            Processing class used to process the data. If provided, will be used to automatically process the inputs
-            for the model, and it will be saved along the model to make it easier to rerun an interrupted training or
-            reuse the fine-tuned model.
-        model_init (`Callable[[], transformers.PreTrainedModel]`):
-            The model initializer to use for training. If None is specified, the default model initializer will be used.
-        callbacks (`list[transformers.TrainerCallback]`):
-            The callbacks to use for training.
-        optimizers (`tuple[torch.optim.Optimizer, torch.optim.lr_scheduler.LambdaLR]`):
-            The optimizer and scheduler to use for training.
-        preprocess_logits_for_metrics (`Callable[[torch.Tensor, torch.Tensor], torch.Tensor]`):
-            The function to use to preprocess the logits before computing the metrics.
-        peft_config (`dict`, defaults to `None`):
-            The PEFT configuration to use for training. If you pass a PEFT configuration, the model will be wrapped in a PEFT model.
-        compute_metrics (`Callable[[EvalPrediction], dict]`, *optional*):
-            The function to use to compute the metrics. Must take a `EvalPrediction` and return
-            a dictionary string to metric values.
-
-    """
-    def __init__(
-        self,
-        model = None,
-        args = None,
-        data_collator = None,
-        train_dataset = None,
-        eval_dataset = None,
-        processing_class = None,
-        model_init = None,
-        callbacks = None,
-        preprocess_logits_for_metrics = None,
-        peft_config = None,
-        compute_metrics = None,
-        **kwargs
-    ):
-        if args is None: args = UnslothORPOConfig()
-        use_bf16 = getattr(args, 'bf16', False)
-        if type(use_bf16) is not bool: use_bf16 = False
-        use_fp16 = getattr(args, 'fp16', False)
-        if type(use_fp16) is not bool: use_fp16 = False
-        force_float32 = False
-        if os.environ.get('UNSLOTH_FORCE_FLOAT32', '0') == '1':
-            print('Unsloth: Switching to float32 training since model cannot work with float16')
-            force_float32 = True
-        mixed_precision_dtype = os.environ.get('UNSLOTH_MIXED_PRECISION', 'float32')
-        dtype = getattr(model.config, 'torch_dtype', None)
-        if dtype is None: dtype = model.get_input_embeddings().dtype
-        from unsloth_zoo.utils import _get_dtype
-        dtype = _get_dtype(dtype)
-        float16 = dtype == torch.float16
-        if not force_float32 and (float16 and use_bf16): raise TypeError('Unsloth: Model is in float16 precision but you want to use bfloat16 precision. Set fp16 to `True` and bf16 to `False`')
-        if not force_float32 and (not float16 and use_fp16): raise TypeError('Unsloth: Model is in bfloat16 precision but you want to use float16 precision. Set fp16 to `False` and bf16 to `True`')
-        if force_float32:
-            args.fp16 = False
-            args.bf16 = False
-            os.environ['ACCELERATE_MIXED_PRECISION'] = 'no'
-        elif (not use_bf16 and not use_fp16) and mixed_precision_dtype == 'float32':
-            args.fp16 = float16
-            args.bf16 = not float16
-            os.environ['ACCELERATE_MIXED_PRECISION'] = 'fp16' if float16 else 'bf16'
-        if getattr(args, 'eval_dataset', None) is not None and getattr(args, 'eval_strategy', 'no') == 'no':
-            args.eval_strategy = 'steps'
-            if getattr(args, 'eval_steps', None) is None: args.eval_steps = 0.1
-        ga_steps = getattr(args, 'gradient_accumulation_steps', None)
-        if ga_steps is not None and ga_steps > 1:
-            from transformers import __version__ as transformers_version
-            if Version(transformers_version) <= Version('4.45.2'):
-                print('**** Unsloth: Please use our fixed gradient_accumulation_steps by updating transformers, TRL and Unsloth!\n'
-                      '`pip install --upgrade --no-cache-dir --force-reinstall --no-deps unsloth transformers trl unsloth_zoo`')
-        if getattr(args, 'eval_strategy', 'no') != 'no':
-            eval_bsz = getattr(args, 'per_device_eval_batch_size', 8)
-            if eval_bsz == 8 and args.per_device_train_batch_size < eval_bsz: args.per_device_eval_batch_size = args.per_device_train_batch_size
-            if getattr(args, 'eval_accumulation_steps', None) is None and ga_steps is not None: args.eval_accumulation_steps = ga_steps
-        fp16_full_eval = getattr(args, 'fp16_full_eval', False)
-        if type(fp16_full_eval) is not bool: fp16_full_eval = False
-        bf16_full_eval = getattr(args, 'bf16_full_eval', False)
-        if type(bf16_full_eval) is not bool: bf16_full_eval = False
-        if args.fp16 and bf16_full_eval: args.bf16_full_eval = False; args.fp16_full_eval = True
-        if args.bf16 and fp16_full_eval: args.bf16_full_eval = True; args.fp16_full_eval = False
-        if force_float32:
-            args.bf16_full_eval = False
-            args.fp16_full_eval = False
-        elif os.environ.get('UNSLOTH_MIXED_PRECISION', 'float32') == 'bfloat16':
-            args.bf16_full_eval = True
-            args.fp16_full_eval = False
-        elif not bf16_full_eval and not fp16_full_eval:
-            args.bf16_full_eval = args.bf16
-            args.fp16_full_eval = args.fp16
-        _output_logits = False
-        if locals().get('compute_metrics', None) is not None: _output_logits = True
-        if locals().get('preprocess_logits_for_metrics', None) is not None: _output_logits = True
-        if _output_logits:
-            os.environ['UNSLOTH_RETURN_LOGITS'] = '1'
-        if 'max_seq_length' not in locals() and not hasattr(args, 'max_seq_length'):
-            pass
-        else:
-            model_max_seq_length = getattr(model, 'max_seq_length', None)
-            args_max_seq_length = getattr(args, 'max_seq_length', None)
-            if args_max_seq_length is None and model_max_seq_length is not None:
-                max_seq_length = model.max_seq_length
-                if hasattr(args, 'max_seq_length'): args.max_seq_length = max_seq_length
-        if model is not None and hasattr(model, 'for_training'):
-            model.for_training()
-        if 'tokenizer' in locals() and hasattr(tokenizer, 'padding_side'): tokenizer.padding_side = 'right'
-        if 'processing_class' in locals():
-            if hasattr(processing_class, 'padding_side'): processing_class.padding_side = 'right'
-            if hasattr(processing_class, 'tokenizer') and hasattr(processing_class.tokenizer, 'padding_side'): processing_class.tokenizer.padding_side = 'right'
-        __tokenizer = processing_class if 'processing_class' in locals() else tokenizer
-        from unsloth_zoo.vision_utils import UnslothVisionDataCollator
-        if not isinstance(data_collator, UnslothVisionDataCollator):
-            if isinstance(data_collator, DataCollatorForSeq2Seq) and 'labels' not in train_dataset.column_names:
-                data_collator = TransformersDataCollatorForLanguageModeling(__tokenizer, mlm = False, mlm_probability = 0.0)
-            elif isinstance(data_collator, TransformersDataCollatorForLanguageModeling) and 'labels' in train_dataset.column_names:
-                data_collator = DataCollatorForSeq2Seq(__tokenizer)
-        else:
-            if hasattr(args, 'remove_unused_columns'): args.remove_unused_columns = False
-            if hasattr(args, 'dataset_text_field'): args.dataset_text_field = ''
-            if hasattr(args, 'dataset_kwargs'): args.dataset_kwargs = {'skip_prepare_dataset': True}
-        if not isinstance(data_collator, UnslothVisionDataCollator):
-            if not hasattr(__tokenizer, 'pad') and hasattr(__tokenizer, 'tokenizer'):
-                if isinstance(data_collator, DataCollatorForSeq2Seq):
-                    data_collator = DataCollatorForSeq2Seq(__tokenizer.tokenizer)
-                else:
-                    data_collator = TransformersDataCollatorForLanguageModeling(__tokenizer.tokenizer, mlm = False, mlm_probability = 0.0)
-        other_metrics = []
-
-        from unsloth_zoo.logging_utils import PatchRLStatistics
-        PatchRLStatistics('orpo_trainer', other_metrics)
-
-        super().__init__(
-            model = model,
-            args = args,
-            data_collator = data_collator,
-            train_dataset = train_dataset,
-            eval_dataset = eval_dataset,
-            processing_class = processing_class,
-            model_init = model_init,
-            callbacks = callbacks,
-            preprocess_logits_for_metrics = preprocess_logits_for_metrics,
-            peft_config = peft_config,
-            compute_metrics = compute_metrics,**kwargs)
-        if hasattr(self, 'neftune_hook_handle'):
-            self.neftune_hook_handle.remove()
-            if hasattr(self, 'neftune_hook_handle'): del self.neftune_hook_handle
-        if getattr(args, 'neftune_noise_alpha', None) is not None:
-            model.get_input_embeddings().neftune_noise_alpha = self.neftune_noise_alpha
-        pass
-
-pass
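For context, a minimal usage sketch of the wrapper above. It is illustrative only: the checkpoint name and dataset are placeholders, and the FastLanguageModel calls follow the public Unsloth documentation rather than anything in this file.

from unsloth import FastLanguageModel
from datasets import load_dataset

model, tokenizer = FastLanguageModel.from_pretrained("unsloth/llama-3-8b-bnb-4bit", max_seq_length=2048)
model = FastLanguageModel.get_peft_model(model)

# Placeholder preference dataset with chosen/rejected pairs.
dataset = load_dataset("trl-lib/ultrafeedback_binarized", split="train")

trainer = UnslothORPOTrainer(
    model=model,
    args=UnslothORPOConfig(per_device_train_batch_size=2, max_steps=50, output_dir="orpo_out"),
    train_dataset=dataset,
    processing_class=tokenizer,
)
trainer.train()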
test_run_uploads/UnslothOnlineDPOTrainer.py DELETED
@@ -1,1293 +0,0 @@
-"""
-2025.7.11
-2025.7.11
-4.54.1
-0.16.1
-__UNSLOTH_VERSIONING__
-"""
-from torch import Tensor
-import torch
-import torch.nn as nn
-from torch.nn import functional as F
-from typing import Any, List, Optional, Tuple, Union, Dict, Set, Callable
-from trl.trainer.online_dpo_trainer import (Any, BaseImageProcessor, BasePairwiseJudge, Callable, DPODataCollatorWithPadding, DataCollator, DataLoader, Dataset, EvalPrediction, F, FeatureExtractionMixin, GenerationConfig, IterableDataset, OnlineDPOConfig, OnlineDPOTrainer, OptimizerNames, Optional, PREFIX_CHECKPOINT_DIR, PeftModel, PreTrainedModel, PreTrainedTokenizerBase, ProcessorMixin, SIMPLE_CHAT_TEMPLATE, Trainer, TrainerCallback, Union, apply_chat_template, create_reference_model, datasets, disable_dropout_in_model, empty_cache, generate_model_card, get_comet_experiment_url, get_reward, is_conversational, is_peft_available, is_wandb_available, jinja2, logging, maybe_apply_chat_template, nn, np, os, prepare_deepspeed, seed_worker, textwrap, torch, transformers, truncate_right, unwrap_model_for_generation, version, wandb, warnings, wraps, F, is_conversational, os, torch)
-
-
-import os
-from typing import *
-from dataclasses import dataclass, field
-from packaging.version import Version
-import torch
-import numpy as np
-from contextlib import nullcontext
-from torch.nn import functional as F
-from transformers import DataCollatorForSeq2Seq, DataCollatorForLanguageModeling as TransformersDataCollatorForLanguageModeling
-
-torch_compile_options = {
-    "epilogue_fusion"   : True,
-    "max_autotune"      : False,
-    "shape_padding"     : True,
-    "trace.enabled"     : False,
-    "triton.cudagraphs" : False,
-}
-
-@torch.compile(dynamic = True, fullgraph = True, options = torch_compile_options,)
-def chunked_selective_log_softmax(logits, index):
-    # Split into 4 chunks only
-    chunked_logits = torch.chunk(logits.reshape(-1, logits.shape[-1]), chunks = 4, dim = 0)
-    chunked_index  = torch.chunk(index.reshape(-1), chunks = 4, dim = 0)
-    all_per_token_logps = []
-    # Below loop does the same as selective_log_softmax(chunk_logits, chunk_index)
-    for chunk_logits, chunk_index in zip(chunked_logits, chunked_index):
-        chunk_logits = chunk_logits.to(torch.float32)
-        selected_logits = torch.gather(chunk_logits, dim = -1, index = chunk_index.unsqueeze(-1)).squeeze(-1)
-        logsumexp_values = torch.logsumexp(chunk_logits, dim = -1)
-        per_token_logps = selected_logits - logsumexp_values
-        all_per_token_logps.append(per_token_logps)
-    pass
-    all_per_token_logps = torch.concat(all_per_token_logps)
-    all_per_token_logps = all_per_token_logps.reshape((logits.shape[0], logits.shape[1]))
-    return all_per_token_logps
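A small self-check (illustrative, not from this file; requires PyTorch 2.x for the compiled function) that the chunked routine above matches the direct per-token log-softmax gather it replaces: log p(token) = logit(token) - logsumexp(all logits) at each position.

import torch

def selective_log_softmax_reference(logits, index):
    # Direct, unchunked version for comparison.
    logps = logits.float().log_softmax(dim=-1)
    return torch.gather(logps, dim=-1, index=index.unsqueeze(-1)).squeeze(-1)

logits = torch.randn(2, 8, 32)           # (batch, seq, vocab)
index = torch.randint(0, 32, (2, 8))     # token ids whose log-probs we want

expected = selective_log_softmax_reference(logits, index)
actual = chunked_selective_log_softmax(logits, index)
assert torch.allclose(actual, expected, atol=1e-5)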
-def vLLMSamplingParams(**kwargs):
-    from vllm import SamplingParams
-    sampling_params = SamplingParams(**kwargs)
-    sampling_params._set_kwargs = kwargs
-    return sampling_params
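The wrapper simply records which kwargs were set, so the trainer can later merge only those overrides into its own SamplingParams (see the `_set_kwargs` lookup in the trainer below). An illustrative call, assuming vLLM is installed:

params = vLLMSamplingParams(top_p=0.9, max_tokens=128)
print(params._set_kwargs)  # {'top_p': 0.9, 'max_tokens': 128}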
-@dataclass
-class UnslothOnlineDPOConfig(OnlineDPOConfig):
-    """
-
-    Configuration class for the [`OnlineDPOTrainer`].
-
-    Using [`~transformers.HfArgumentParser`] we can turn this class into
-    [argparse](https://docs.python.org/3/library/argparse#module-argparse) arguments that can be specified on the
-    command line.
-
-    Parameters:
-        learning_rate (`float`, *optional*, defaults to `5e-7`):
-            Initial learning rate for [`AdamW`] optimizer. The default value replaces that of
-            [`~transformers.TrainingArguments`].
-        reward_model_path (`str` or `None`, *optional*, defaults to `None`):
-            Path to the reward model. Either `judge` or `reward_model_path` must be set, but not both.
-        judge (`str` or `None`, *optional*, defaults to `None`):
-            Name of the judge to use. Either `judge` or `reward_model_path` must be set, but not both.
-        max_new_tokens (`int`, *optional*, defaults to `64`):
-            Maximum number of tokens to generate per completion.
-        max_length (`int`, *optional*, defaults to `256`):
-            Maximum total length of the sequence (prompt + completion) used to compute log probabilities. If the
-            sequence exceeds this limit, the leftmost tokens will be truncated to preserve as much of the completion as
-            possible.
-        temperature (`float`, *optional*, defaults to `0.9`):
-            Temperature for sampling. The higher the temperature, the more random the completions.
-        missing_eos_penalty (`float` or `None`, *optional*, defaults to `None`):
-            Penalty applied to the score when the model fails to generate an EOS token. This is useful to encourage
-            the model to generate completions shorter than the maximum length (`max_new_tokens`). The penalty must be
-            a positive value.
-        beta (`float` or `list[float]`, *optional*, defaults to `0.1`):
-            Parameter controlling the deviation from the reference model. Higher β means less deviation from the
-            reference model. For the IPO loss (`loss_type="ipo"`), β is the regularization parameter denoted by τ in
-            the [paper](https://huggingface.co/papers/2310.12036). If a list of floats is provided then the β is
-            selected for each new epoch and the last β is used for the rest of the epochs.
-        loss_type (`str`, *optional*, defaults to `"sigmoid"`):
-            Type of loss to use. Possible values are:
-
-                - `"sigmoid"`: sigmoid loss from the original [DPO](https://huggingface.co/papers/2305.18290) paper.
-                - `"ipo"`: IPO loss from the [IPO](https://huggingface.co/papers/2310.12036) paper.
-
-        dataset_num_proc (`int` or `None`, *optional*, defaults to `None`):
-            Number of processes to use for processing the dataset.
-        disable_dropout (`bool`, *optional*, defaults to `True`):
-            Whether to disable dropout in the model and reference model.
-        use_vllm (`bool`, *optional*, defaults to `False`):
-            Whether to use vLLM for generating completions. Requires vLLM to be installed (`pip install vllm`).
-        ds3_gather_for_generation (`bool`, *optional*, defaults to `True`):
-            This setting applies to DeepSpeed ZeRO-3. If enabled, the policy model weights are gathered for generation,
-            improving generation speed. However, disabling this option allows training models that exceed the VRAM
-            capacity of a single GPU, albeit at the cost of slower generation.
-
-    """
-    vllm_sampling_params: Optional[Any] = field(
-        default = None,
-        metadata = {'help': 'vLLM SamplingParams'},
-    )
-    unsloth_num_chunks : Optional[int] = field(
-        default = -1,
-        metadata = {'help': 'Chunk size to reduce memory usage. -1 is most efficient.'},
-    )
-    def __init__(
-        self,
-        output_dir = None,
-        overwrite_output_dir = None,
-        do_train = False,
-        do_eval = False,
-        do_predict = False,
-        eval_strategy = 'no',
-        prediction_loss_only = False,
-        per_device_train_batch_size = 4,
-        per_device_eval_batch_size = 4,
-        per_gpu_train_batch_size = None,
-        per_gpu_eval_batch_size = None,
-        gradient_accumulation_steps = 2,
-        eval_accumulation_steps = 2,
-        eval_delay = 0,
-        torch_empty_cache_steps = 250,
-        learning_rate = 5e-05,
-        weight_decay = 0.01,
-        adam_beta1 = 0.9,
-        adam_beta2 = 0.999,
-        adam_epsilon = 1e-08,
-        max_grad_norm = 1.0,
-        num_train_epochs = 3.0,
-        max_steps = -1,
-        lr_scheduler_type = 'linear',
-        warmup_ratio = 0.1,
-        warmup_steps = 0,
-        log_level = 'passive',
-        log_level_replica = 'warning',
-        log_on_each_node = True,
-        logging_dir = None,
-        logging_strategy = 'steps',
-        logging_first_step = False,
-        logging_steps = 1,
-        logging_nan_inf_filter = False,
-        save_strategy = 'steps',
-        save_steps = 500,
-        save_total_limit = None,
-        save_safetensors = True,
-        save_on_each_node = False,
-        save_only_model = False,
-        restore_callback_states_from_checkpoint = False,
-        no_cuda = False,
-        use_cpu = False,
-        use_mps_device = False,
-        seed = 3407,
-        data_seed = 3407,
-        jit_mode_eval = False,
-        use_ipex = False,
-        bf16 = False,
-        fp16 = False,
-        fp16_opt_level = 'O1',
-        half_precision_backend = 'auto',
-        bf16_full_eval = False,
-        fp16_full_eval = False,
-        tf32 = None,
-        local_rank = -1,
-        ddp_backend = None,
-        tpu_num_cores = None,
-        tpu_metrics_debug = False,
-        debug = '',
-        dataloader_drop_last = False,
-        eval_steps = None,
-        dataloader_num_workers = 0,
-        dataloader_prefetch_factor = None,
-        past_index = -1,
-        run_name = None,
-        disable_tqdm = None,
-        remove_unused_columns = True,
-        label_names = None,
-        load_best_model_at_end = False,
-        metric_for_best_model = None,
-        greater_is_better = None,
-        ignore_data_skip = False,
-        fsdp = '',
-        fsdp_min_num_params = 0,
-        fsdp_config = None,
-        fsdp_transformer_layer_cls_to_wrap = None,
-        accelerator_config = None,
-        deepspeed = None,
-        label_smoothing_factor = 0.0,
-        optim = 'adamw_8bit',
-        optim_args = None,
-        adafactor = False,
-        group_by_length = False,
-        length_column_name = 'length',
-        report_to = None,
-        ddp_find_unused_parameters = None,
-        ddp_bucket_cap_mb = None,
-        ddp_broadcast_buffers = None,
-        dataloader_pin_memory = True,
-        dataloader_persistent_workers = False,
-        skip_memory_metrics = True,
-        use_legacy_prediction_loop = False,
-        push_to_hub = False,
-        resume_from_checkpoint = None,
-        hub_model_id = None,
-        hub_strategy = 'every_save',
-        hub_token = None,
-        hub_private_repo = None,
-        hub_always_push = False,
-        hub_revision = None,
-        gradient_checkpointing = False,
-        gradient_checkpointing_kwargs = None,
-        include_inputs_for_metrics = False,
-        eval_do_concat_batches = True,
-        fp16_backend = 'auto',
-        push_to_hub_model_id = None,
-        push_to_hub_organization = None,
-        push_to_hub_token = None,
-        mp_parameters = '',
-        auto_find_batch_size = True,
-        full_determinism = False,
-        torchdynamo = None,
-        ray_scope = 'last',
-        ddp_timeout = 1800,
-        torch_compile = False,
-        torch_compile_backend = None,
-        torch_compile_mode = None,
-        include_tokens_per_second = False,
-        include_num_input_tokens_seen = False,
-        neftune_noise_alpha = None,
-        optim_target_modules = None,
-        batch_eval_metrics = False,
-        eval_on_start = False,
-        use_liger_kernel = False,
-        liger_kernel_config = None,
-        eval_use_gather_object = False,
-        average_tokens_across_devices = True,
-        reward_model_path = None,
-        judge = None,
-        max_new_tokens = 64,
-        max_length = 512,
-        temperature = 0.9,
-        missing_eos_penalty = None,
-        loss_type = 'sigmoid',
-        dataset_num_proc = None,
-        disable_dropout = True,
-        use_vllm = False,
-        ds3_gather_for_generation = True,
-        vllm_sampling_params = None,
-        unsloth_num_chunks = -1,
-        **kwargs,
-    ):
-        if learning_rate < 1e-7: raise FloatingPointError(f'Unsloth: Your learning rate of `{learning_rate}` is too small and less than 1e-7! Consider increasing it, otherwise gradient updates will be close to 0!')
-        if learning_rate > 1: raise OverflowError(f'Unsloth: Your learning rate of `{learning_rate}` is way too large (> 1)! Consider decreasing it to 1e-1, otherwise gradient updates will explode!')
-        if output_dir is None and save_strategy == 'steps' and save_steps == 500:
-            output_dir = 'unsloth_training_checkpoints'
-            save_strategy = 'no'
-        if dataset_num_proc is None:
-            from multiprocessing import cpu_count
-            dataset_num_proc = min(cpu_count()*2, 2)
-        if temperature <= 0:
-            raise ValueError('Unsloth: Please set a positive non-zero temperature since your results will be wrong.')
-        elif temperature >= 10:
-            raise ValueError('Unsloth: Please set a positive non-zero temperature less than 10, since sampling will be quite erratic.')
-
-        super().__init__(
-            output_dir = output_dir,
-            overwrite_output_dir = overwrite_output_dir,
-            do_train = do_train,
-            do_eval = do_eval,
-            do_predict = do_predict,
-            eval_strategy = eval_strategy,
-            prediction_loss_only = prediction_loss_only,
-            per_device_train_batch_size = per_device_train_batch_size,
-            per_device_eval_batch_size = per_device_eval_batch_size,
-            per_gpu_train_batch_size = per_gpu_train_batch_size,
-            per_gpu_eval_batch_size = per_gpu_eval_batch_size,
-            gradient_accumulation_steps = gradient_accumulation_steps,
-            eval_accumulation_steps = eval_accumulation_steps,
-            eval_delay = eval_delay,
-            torch_empty_cache_steps = torch_empty_cache_steps,
-            learning_rate = learning_rate,
-            weight_decay = weight_decay,
-            adam_beta1 = adam_beta1,
-            adam_beta2 = adam_beta2,
-            adam_epsilon = adam_epsilon,
-            max_grad_norm = max_grad_norm,
-            num_train_epochs = num_train_epochs,
-            max_steps = max_steps,
-            lr_scheduler_type = lr_scheduler_type,
-            warmup_ratio = warmup_ratio,
-            warmup_steps = warmup_steps,
-            log_level = log_level,
-            log_level_replica = log_level_replica,
-            log_on_each_node = log_on_each_node,
-            logging_dir = logging_dir,
-            logging_strategy = logging_strategy,
-            logging_first_step = logging_first_step,
-            logging_steps = logging_steps,
-            logging_nan_inf_filter = logging_nan_inf_filter,
-            save_strategy = save_strategy,
-            save_steps = save_steps,
-            save_total_limit = save_total_limit,
-            save_safetensors = save_safetensors,
-            save_on_each_node = save_on_each_node,
-            save_only_model = save_only_model,
-            restore_callback_states_from_checkpoint = restore_callback_states_from_checkpoint,
-            no_cuda = no_cuda,
-            use_cpu = use_cpu,
-            use_mps_device = use_mps_device,
-            seed = seed,
-            data_seed = data_seed,
-            jit_mode_eval = jit_mode_eval,
-            use_ipex = use_ipex,
-            bf16 = bf16,
-            fp16 = fp16,
-            fp16_opt_level = fp16_opt_level,
-            half_precision_backend = half_precision_backend,
-            bf16_full_eval = bf16_full_eval,
-            fp16_full_eval = fp16_full_eval,
-            tf32 = tf32,
-            local_rank = local_rank,
-            ddp_backend = ddp_backend,
-            tpu_num_cores = tpu_num_cores,
-            tpu_metrics_debug = tpu_metrics_debug,
-            debug = debug,
-            dataloader_drop_last = dataloader_drop_last,
-            eval_steps = eval_steps,
-            dataloader_num_workers = dataloader_num_workers,
-            dataloader_prefetch_factor = dataloader_prefetch_factor,
-            past_index = past_index,
-            run_name = run_name,
-            disable_tqdm = disable_tqdm,
-            remove_unused_columns = remove_unused_columns,
-            label_names = label_names,
-            load_best_model_at_end = load_best_model_at_end,
-            metric_for_best_model = metric_for_best_model,
-            greater_is_better = greater_is_better,
-            ignore_data_skip = ignore_data_skip,
-            fsdp = fsdp,
-            fsdp_min_num_params = fsdp_min_num_params,
-            fsdp_config = fsdp_config,
-            fsdp_transformer_layer_cls_to_wrap = fsdp_transformer_layer_cls_to_wrap,
-            accelerator_config = accelerator_config,
-            deepspeed = deepspeed,
-            label_smoothing_factor = label_smoothing_factor,
-            optim = optim,
-            optim_args = optim_args,
-            adafactor = adafactor,
-            group_by_length = group_by_length,
-            length_column_name = length_column_name,
-            report_to = report_to,
-            ddp_find_unused_parameters = ddp_find_unused_parameters,
-            ddp_bucket_cap_mb = ddp_bucket_cap_mb,
-            ddp_broadcast_buffers = ddp_broadcast_buffers,
-            dataloader_pin_memory = dataloader_pin_memory,
-            dataloader_persistent_workers = dataloader_persistent_workers,
-            skip_memory_metrics = skip_memory_metrics,
-            use_legacy_prediction_loop = use_legacy_prediction_loop,
-            push_to_hub = push_to_hub,
-            resume_from_checkpoint = resume_from_checkpoint,
-            hub_model_id = hub_model_id,
-            hub_strategy = hub_strategy,
-            hub_token = hub_token,
-            hub_private_repo = hub_private_repo,
-            hub_always_push = hub_always_push,
-            hub_revision = hub_revision,
-            gradient_checkpointing = gradient_checkpointing,
-            gradient_checkpointing_kwargs = gradient_checkpointing_kwargs,
-            include_inputs_for_metrics = include_inputs_for_metrics,
-            eval_do_concat_batches = eval_do_concat_batches,
-            fp16_backend = fp16_backend,
-            push_to_hub_model_id = push_to_hub_model_id,
-            push_to_hub_organization = push_to_hub_organization,
-            push_to_hub_token = push_to_hub_token,
-            mp_parameters = mp_parameters,
-            auto_find_batch_size = auto_find_batch_size,
-            full_determinism = full_determinism,
-            torchdynamo = torchdynamo,
-            ray_scope = ray_scope,
-            ddp_timeout = ddp_timeout,
-            torch_compile = torch_compile,
-            torch_compile_backend = torch_compile_backend,
-            torch_compile_mode = torch_compile_mode,
-            include_tokens_per_second = include_tokens_per_second,
-            include_num_input_tokens_seen = include_num_input_tokens_seen,
-            neftune_noise_alpha = neftune_noise_alpha,
-            optim_target_modules = optim_target_modules,
-            batch_eval_metrics = batch_eval_metrics,
-            eval_on_start = eval_on_start,
-            use_liger_kernel = use_liger_kernel,
-            liger_kernel_config = liger_kernel_config,
-            eval_use_gather_object = eval_use_gather_object,
-            average_tokens_across_devices = average_tokens_across_devices,
-            reward_model_path = reward_model_path,
-            judge = judge,
-            max_new_tokens = max_new_tokens,
-            max_length = max_length,
-            temperature = temperature,
-            missing_eos_penalty = missing_eos_penalty,
-            loss_type = loss_type,
-            dataset_num_proc = dataset_num_proc,
-            disable_dropout = disable_dropout,
-            use_vllm = use_vllm,
-            ds3_gather_for_generation = ds3_gather_for_generation,**kwargs)
-        self.vllm_sampling_params = vllm_sampling_params
-        self.unsloth_num_chunks = unsloth_num_chunks
-    pass
-
421
- r""""""
422
-
423
- _tag_names = ["trl", "online-dpo"]
424
-
425
- def __init__(
426
- self,
427
- model: Union[PreTrainedModel, nn.Module],
428
- ref_model: Union[PreTrainedModel, nn.Module, None] = None,
429
- reward_model: Union[PreTrainedModel, nn.Module, None] = None,
430
- judge: Optional[BasePairwiseJudge] = None,
431
- args: Optional[OnlineDPOConfig] = None,
432
- data_collator: Optional[DataCollator] = None,
433
- train_dataset: Optional[Union[Dataset, IterableDataset, "datasets.Dataset"]] = None,
434
- eval_dataset: Optional[Union[Dataset, dict[str, Dataset], "datasets.Dataset"]] = None,
435
- processing_class: Optional[
436
- Union[PreTrainedTokenizerBase, BaseImageProcessor, FeatureExtractionMixin, ProcessorMixin]
437
- ] = None,
438
- reward_processing_class: Optional[PreTrainedTokenizerBase] = None,
439
- peft_config: Optional[dict] = None,
440
- compute_metrics: Optional[Callable[[EvalPrediction], dict]] = None,
441
- callbacks: Optional[list[TrainerCallback]] = None,
442
- optimizers: tuple[torch.optim.Optimizer, torch.optim.lr_scheduler.LambdaLR] = (None, None),
443
- preprocess_logits_for_metrics: Optional[Callable[[torch.Tensor, torch.Tensor], torch.Tensor]] = None,
444
- ) -> None:
445
-
446
- if hasattr(model, 'vllm_engine') and hasattr(args, 'use_vllm'):
447
- if (getattr(args, 'use_vllm', False) == False):
448
- args.use_vllm = True
449
- if ref_model is model:
450
- raise ValueError(
451
- "`model` and `ref_model` cannot be the same object. If you want `ref_model` to be the "
452
- "same as `model`, either omit the `ref_model` argument or pass `None`."
453
- )
454
-
455
- self.ref_model = ref_model
456
-
457
- if reward_model is not None and judge is not None:
458
- warnings.warn(
459
- "Both `reward_model` and `judge` are provided. Please choose provide only one of them. "
460
- "Ignoring `judge` and using `reward_model`.",
461
- UserWarning,
462
- )
463
- judge = None
464
- elif reward_model is None and judge is None:
465
- raise ValueError("Either `reward_model` or `judge` must be provided.")
466
-
467
- self.reward_model = reward_model
468
- self.reward_processing_class = reward_processing_class
469
- self.judge = judge
470
-
471
- if args.missing_eos_penalty is not None and judge is not None:
472
- raise ValueError("`missing_eos_penalty` is not supported when `judge` is provided.")
473
-
474
- if args is None:
475
- raise ValueError("`args` must be provided.")
476
-
477
- # Check that the processing_class is provided
478
- if processing_class is None:
479
- raise ValueError("`processing_class` must be provided.")
480
-
481
- # Convert to PEFT model if peft_config is provided
482
- if False:
483
- # Check if PEFT is available
484
- if not is_peft_available():
485
- raise ImportError(
486
- "PEFT is not available and passed `peft_config`. Please install PEFT with "
487
- "`pip install peft` to use it."
488
- )
489
-
490
- # If the model is already a PeftModel, we need to merge and unload it.
491
- # Further information here: https://huggingface.co/docs/trl/dpo_trainer#reference-model-considerations-with-peft
492
- if isinstance(model, PeftModel):
493
- model = model.merge_and_unload()
494
-
495
- # Get peft model with the given config
496
- model = model
497
-
498
- # Disable dropout in the model and reference model
499
- if args.disable_dropout:
500
- disable_dropout_in_model(model)
501
- if self.ref_model is not None:
502
- disable_dropout_in_model(self.ref_model)
503
-
504
- # Handle the ref_model
505
- # Usually, the user wants the ref model to be the initial version of the model. When using PEFT, it's easy to
506
- # get the ref model, as it's just the model with a disabled adapter. When not using PEFT, we need to create
507
- # the ref model from the model by copying it and disable the gradients and set it in evaluation mode.
508
- if ref_model is None: # No ref model provided, the most common case
509
- if False:
510
- self.ref_model = create_reference_model(model) # copy, disable gradients, set eval mode
511
- else:
512
- self.ref_model = None # we don't need a ref model here, we can just disable the adapter.
513
- else: # rare case, the user provided a ref model
514
- self.ref_model = ref_model
515
- self.ref_model.eval()
516
-
517
- # Disable the gradient and set the reward model in eval mode
518
- if self.reward_model is not None:
519
- self.reward_model.eval()
520
-
521
- # Define the collator is not provided
522
- if data_collator is None:
523
- data_collator = DPODataCollatorWithPadding(pad_token_id=processing_class.pad_token_id)
524
-
525
- self.max_length = args.max_length
526
-
527
- self.stats = {
528
- "objective/kl": [],
529
- "objective/entropy": [],
530
- "objective/non_score_reward": [],
531
- "rewards/chosen": [],
532
- "rewards/rejected": [],
533
- "rewards/accuracies": [],
534
- "rewards/margins": [],
535
- "logps/chosen": [],
536
- "logps/rejected": [],
537
- "val/contain_eos_token": [],
538
- "beta": [],
539
- }
540
- if self.reward_model is not None:
541
- self.stats["objective/rlhf_reward"] = []
542
- self.stats["objective/scores_margin"] = []
543
- self.stats["objective/scores"] = []
544
-
545
- if args.use_vllm:
546
- self.llm = model.vllm_engine; self._last_loaded_step = 0; self.generation_config = SamplingParams(
547
- n=2,
548
- max_tokens=args.max_new_tokens,
549
- temperature=args.temperature,
550
- top_k=50,
551
- top_p=1.0,
552
- detokenize=False,
553
- **getattr(getattr(args, 'vllm_sampling_params', vLLMSamplingParams()), '_set_kwargs', {}),
554
- )
555
- else:
556
- self.generation_config = GenerationConfig(
557
- max_new_tokens=args.max_new_tokens,
558
- temperature=args.temperature,
559
- top_k=50,
560
- top_p=1.0,
561
- do_sample=True,
562
- use_cache=False if args.gradient_checkpointing else True,
563
- )
564
-
565
- # The trainer estimates the number of FLOPs [floating-point operations] using the number of elements in the
566
- # input tensor associated with the key "input_ids". However, in Online DPO, the sampled data does not include
567
- # the "input_ids" key. As a result, the trainer issues the warning: "Could not estimate the number of tokens
568
- # of the input, floating-point operations will not be computed." To suppress this warning, we set the
569
- # "estimate_tokens" key in the model's "warnings_issued" dictionary to True. This acts as a flag to indicate
570
- # that the warning has already been issued.
571
- model.warnings_issued["estimate_tokens"] = True
572
-
573
- super().__init__(
574
- model=model,
575
- args=args,
576
- data_collator=data_collator,
577
- train_dataset=train_dataset,
578
- eval_dataset=eval_dataset,
579
- processing_class=processing_class,
580
- compute_metrics=compute_metrics,
581
- callbacks=callbacks,
582
- optimizers=optimizers,
583
- preprocess_logits_for_metrics=preprocess_logits_for_metrics,
584
- )
585
-
586
- # Add tags for models that have been loaded with the correct transformers version
587
- if hasattr(self.model, "add_model_tags"):
588
- self.model.add_model_tags(self._tag_names)
589
-
590
- self._beta = args.beta
591
-
592
- # Placed after the super[].__init__ because we need self.is_deepspeed_enabled and self.accelerator
593
- if self.is_deepspeed_enabled:
594
- if self.reward_model is not None:
595
- self.reward_model = prepare_deepspeed(
596
- self.reward_model, args.per_device_train_batch_size, args.fp16, args.bf16
597
- )
598
- if self.ref_model is not None:
599
- self.ref_model = prepare_deepspeed(
600
- self.ref_model, args.per_device_train_batch_size, args.fp16, args.bf16
601
- )
602
- else:
603
- if self.ref_model is not None:
604
- self.ref_model = self.ref_model.to(self.accelerator.device)
605
- if self.reward_model is not None:
606
- self.reward_model = self.reward_model.to(self.accelerator.device)
607
-
608
- @property
609
- def beta(self):
610
- if isinstance(self._beta, list):
611
- epoch = self.state.epoch
612
- return self._beta[epoch] if epoch < len(self._beta) else self._beta[-1]
613
- else:
614
- return self._beta
615
-
616
- @staticmethod
617
- def tokenize_row(feature, is_encoder_decoder: bool, tokenizer: PreTrainedTokenizerBase) -> dict[str, Any]:
618
- """Tokenize a single row from a DPO specific dataset."""
619
- if not is_encoder_decoder:
620
- batch = tokenizer(feature["prompt"], add_special_tokens=False)
621
- # Add BOS token to head of prompt. Avoid adding if it's already there
622
- if tokenizer.bos_token_id is not None:
623
- prompt_len_input_ids = len(batch["input_ids"])
624
- if prompt_len_input_ids == 0 or tokenizer.bos_token_id != batch["input_ids"][0]:
625
- batch["input_ids"] = [tokenizer.bos_token_id] + batch["input_ids"]
626
- batch["attention_mask"] = [1] + batch["attention_mask"]
627
- else:
628
- batch = tokenizer(feature["prompt"], add_special_tokens=True)
629
- batch = {f"prompt_{key}": value for key, value in batch.items()}
630
- return batch
631
-
632
- # Same as Trainer.get_train_dataloader but skip the "remove_unused_columns".
633
- @wraps(Trainer.get_train_dataloader)
634
- def get_train_dataloader(self) -> DataLoader:
635
- if self.train_dataset is None:
636
- raise ValueError("Trainer: training requires a train_dataset.")
637
-
638
- train_dataset = self.train_dataset
639
- data_collator = self.data_collator
640
- dataloader_params = {
641
- "batch_size": self._train_batch_size,
642
- "collate_fn": data_collator,
643
- "num_workers": self.args.dataloader_num_workers,
644
- "pin_memory": self.args.dataloader_pin_memory,
645
- "persistent_workers": self.args.dataloader_persistent_workers,
646
- }
647
-
648
- if not isinstance(train_dataset, torch.utils.data.IterableDataset):
649
- dataloader_params["sampler"] = self._get_train_sampler()
650
- dataloader_params["drop_last"] = self.args.dataloader_drop_last
651
- dataloader_params["worker_init_fn"] = seed_worker
652
- dataloader_params["prefetch_factor"] = self.args.dataloader_prefetch_factor
653
-
654
- return self.accelerator.prepare(DataLoader(train_dataset, **dataloader_params))
655
-
656
- # Same as Trainer.get_eval_dataloader but skip the "remove_unused_columns".
657
- @wraps(Trainer.get_eval_dataloader)
658
- def get_eval_dataloader(self, eval_dataset: Optional[Union[str, Dataset]] = None) -> DataLoader:
659
- if eval_dataset is None and self.eval_dataset is None:
660
- raise ValueError("Trainer: evaluation requires an eval_dataset.")
661
-
662
- # If we have persistent workers, don't do a fork bomb especially as eval datasets
663
- # don't change during training
664
- dataloader_key = eval_dataset if isinstance(eval_dataset, str) else "eval"
665
- if (
666
- hasattr(self, "_eval_dataloaders")
667
- and dataloader_key in self._eval_dataloaders
668
- and self.args.dataloader_persistent_workers
669
- ):
670
- return self.accelerator.prepare(self._eval_dataloaders[dataloader_key])
671
-
672
- eval_dataset = (
673
- self.eval_dataset[eval_dataset]
674
- if isinstance(eval_dataset, str)
675
- else eval_dataset
676
- if eval_dataset is not None
677
- else self.eval_dataset
678
- )
679
- data_collator = self.data_collator
680
-
681
- dataloader_params = {
682
- "batch_size": self.args.eval_batch_size,
683
- "collate_fn": data_collator,
684
- "num_workers": self.args.dataloader_num_workers,
685
- "pin_memory": self.args.dataloader_pin_memory,
686
- "persistent_workers": self.args.dataloader_persistent_workers,
687
- }
688
-
689
- if not isinstance(eval_dataset, torch.utils.data.IterableDataset):
690
- dataloader_params["sampler"] = self._get_eval_sampler(eval_dataset)
691
- dataloader_params["drop_last"] = self.args.dataloader_drop_last
692
- dataloader_params["prefetch_factor"] = self.args.dataloader_prefetch_factor
693
-
694
- # accelerator.free_memory() will destroy the references, so
695
- # we need to store the non-prepared version
696
- eval_dataloader = DataLoader(eval_dataset, **dataloader_params)
697
- if self.args.dataloader_persistent_workers:
698
- if hasattr(self, "_eval_dataloaders"):
699
- self._eval_dataloaders[dataloader_key] = eval_dataloader
700
- else:
701
- self._eval_dataloaders = {dataloader_key: eval_dataloader}
702
-
703
- return self.accelerator.prepare(eval_dataloader)
704
-
705
- def _generate_vllm(self, model, prompts):
706
- eos_token_id = self.processing_class.eos_token_id
707
- pad_token_id = self.processing_class.pad_token_id
708
-
709
- # Load the latest weights
710
-
711
- pass
712
-
713
- pass
714
-
715
- if is_conversational({"prompt": prompts[0]}):
716
- outputs = self.llm.chat(prompts, self.generation_config, use_tqdm=False, lora_request = self.model.load_lora('online_dpo_trainer_lora_model', load_tensors = True))
717
- else:
718
- outputs = self.llm.generate(prompts, self.generation_config, use_tqdm=False, lora_request = self.model.load_lora('online_dpo_trainer_lora_model', load_tensors = True))
719
-
720
- completion_ids = [list(output.outputs[i].token_ids) for i in range(2) for output in outputs]
721
- prompt_ids = [list(output.prompt_token_ids) for _ in range(2) for output in outputs]
722
-
723
- # Create mask and pad the prompt and completion
724
- max_prompt_length = max(len(ids) for ids in prompt_ids)
725
- prompt_mask = [[0] * (max_prompt_length - len(ids)) + [1] * len(ids) for ids in prompt_ids]
726
- prompt_ids = [[pad_token_id] * (max_prompt_length - len(ids)) + ids for ids in prompt_ids]
727
- max_tokens = self.generation_config.max_tokens
728
- completion_mask = [[1] * len(ids) + [0] * (max_tokens - len(ids)) for ids in completion_ids]
729
- completion_ids = [
730
- ids + [eos_token_id] if ids[-1] != eos_token_id and len(ids) < max_tokens else ids
731
- for ids in completion_ids
732
- ]
733
- completion_ids = [ids + [pad_token_id] * (max_tokens - len(ids)) for ids in completion_ids]
734
-
735
- # Convert to tensors
736
- prompt_ids = torch.tensor(prompt_ids, device=self.accelerator.device)
737
- prompt_mask = torch.tensor(prompt_mask, device=self.accelerator.device)
738
- completion_ids = torch.tensor(completion_ids, device=self.accelerator.device)
739
- completion_mask = torch.tensor(completion_mask, device=self.accelerator.device)
740
-
741
- return prompt_ids, prompt_mask, completion_ids, completion_mask
742
-
743
- def _generate(self, model, prompts):
744
- eos_token_id = self.processing_class.eos_token_id
745
- pad_token_id = self.processing_class.pad_token_id
746
-
747
- # Apply chat template and tokenize the input. We do this on-the-fly to enable the use of reward models and
748
- # policies with different tokenizers / chat templates.
749
- inputs = [{"prompt": prompt} for prompt in prompts]
750
- inputs = [maybe_apply_chat_template(x, self.processing_class) for x in inputs]
751
- inputs = [self.tokenize_row(x, model.config.is_encoder_decoder, self.processing_class) for x in inputs]
752
- inputs = self.data_collator(inputs)
753
-
754
- # Sample 2 completions per prompt of size `max_new_tokens` from the model
755
- inputs = self._prepare_inputs(inputs)
756
- prompt_ids = inputs["prompt_input_ids"].repeat(2, 1)
757
- prompt_mask = inputs["prompt_attention_mask"].repeat(2, 1)
758
- with unwrap_model_for_generation(
759
- model, self.accelerator, gather_deepspeed3_params=self.args.ds3_gather_for_generation
760
- ) as unwrapped_model:
761
- output = unwrapped_model.generate(
762
- input_ids=prompt_ids,
763
- attention_mask=prompt_mask,
764
- generation_config=self.generation_config,
765
- )
766
-
767
- completion_ids = output[:, prompt_ids.size(1) :]
768
- completion_ids, completion_mask = truncate_right(completion_ids, eos_token_id, pad_token_id)
769
-
770
- return prompt_ids, prompt_mask, completion_ids, completion_mask
771
-
772
- def _forward(self, model, prompt_ids, prompt_mask, completion_ids, completion_mask):
773
- # Get the number of tokens to truncate from prompt
774
- num_tokens_to_truncate = max(prompt_ids.size(1) + completion_ids.size(1) - self.max_length, 0)
775
-
776
- # Truncate left to avoid oom
777
- prompt_ids = prompt_ids[:, num_tokens_to_truncate:]
778
- prompt_mask = prompt_mask[:, num_tokens_to_truncate:]
779
-
780
- # Concat the prompt and completion
781
- prompt_completion_ids = torch.cat((prompt_ids, completion_ids), dim=1)
782
- prompt_completion_mask = torch.cat((prompt_mask, completion_mask), dim=1)
783
-
784
- # Get the logprobs of the completions from the model
785
- output = model(prompt_completion_ids, attention_mask=prompt_completion_mask)
786
-
787
- # There is 1 offset, because the model predict the next token
788
- logits = output.logits[:, prompt_ids.size(1) - 1 : -1]
789
-
790
- # Take the completion tokens logprob
791
- logprobs = torch.take_along_dim(logits.log_softmax(dim=-1), completion_ids.unsqueeze(-1), dim=2).squeeze(-1)
792
- return logprobs
793
-
794
- def training_step(
795
- self, model: nn.Module, inputs: dict[str, Union[torch.Tensor, Any]], num_items_in_batch: Optional[int] = None
796
- ) -> torch.Tensor:
797
- model.train()
798
-
799
- prompts = inputs["prompt"]
800
- batch_size = len(prompts)
801
-
802
- if self.args.use_vllm:
803
- prompt_ids, prompt_mask, completion_ids, completion_mask = self._generate_vllm(model, prompts)
804
- else:
805
- prompt_ids, prompt_mask, completion_ids, completion_mask = self._generate(model, prompts)
806
-
807
- contain_eos_token = torch.any(completion_ids == self.processing_class.eos_token_id, dim=-1)
808
-
809
- logprobs = self._forward(model, prompt_ids, prompt_mask, completion_ids, completion_mask)
810
- with torch.no_grad():
811
- if self.ref_model is not None:
812
- ref_logprobs = self._forward(self.ref_model, prompt_ids, prompt_mask, completion_ids, completion_mask)
813
- else: # peft case: we just need to disable the adapter
814
- with self.model.disable_adapter():
815
- ref_logprobs = self._forward(self.model, prompt_ids, prompt_mask, completion_ids, completion_mask)
816
-
817
- # Decode the completions, and format them if the input is conversational
818
- device = logprobs.device
819
- completions = self.processing_class.batch_decode(completion_ids, skip_special_tokens=True)
820
- if is_conversational({"prompt": prompts[0]}):
821
- completions = [[{"role": "assistant", "content": completion}] for completion in completions]
822
-
823
- # Get the reward from the reward model or judge
824
- if self.judge is not None:
825
- # Once formatted, conversational data may contain special tokens (such as <|im_start|>) that are not
826
- # directly understandable by the judge and could alter its judgment. To avoid this and make the judge
827
- # independent of the model's chat template, we use the raw conversation data, and apply our own chat
828
- # template to it.
829
- if is_conversational({"prompt": prompts[0]}):
830
- environment = jinja2.Environment()
831
- template = environment.from_string(SIMPLE_CHAT_TEMPLATE)
832
- prompts = [template.render(messages=prompt) for prompt in prompts]
833
- completions = [template.render(messages=completion) for completion in completions]
834
-
835
- ranks_of_first_completion = self.judge.judge(
836
- prompts, list(zip(completions[:batch_size], completions[batch_size:]))
837
- )
838
-
839
- # convert ranks to a True/False mask:
840
- # when rank == 0, it means the first completion is the best
841
- # when rank == 1, it means the second completion is the best
842
- mask = torch.tensor([rank == 0 for rank in ranks_of_first_completion], device=device)
843
- else:
844
- # The reward model may not have the same chat template or tokenizer as the model, so we need to use the
845
- # raw data (string), apply the chat template (if needed), and tokenize it with the reward processing class.
846
- prompts = 2 * prompts # repeat the prompt: [prompt0, prompt1] -> [prompt0, prompt1, prompt0, prompt1]
847
- if is_conversational({"prompt": prompts[0]}):
848
- examples = [{"prompt": p, "completion": c} for p, c in zip(prompts, completions)]
849
- examples = [apply_chat_template(example, self.reward_processing_class) for example in examples]
850
- prompts = [example["prompt"] for example in examples]
851
- completions = [example["completion"] for example in examples]
852
-
853
- # Tokenize the prompts
854
- prompts_ids = self.reward_processing_class(
855
- prompts, padding=True, return_tensors="pt", padding_side="left"
856
- )["input_ids"].to(device)
857
- context_length = prompts_ids.shape[1]
858
-
859
- # Tokenize the completions
860
- completions_ids = self.reward_processing_class(
861
- completions, padding=True, return_tensors="pt", padding_side="right"
862
- )["input_ids"].to(device)
863
-
864
- # Concatenate the prompts and completions and get the reward
865
- prompt_completion_ids = torch.cat((prompts_ids, completions_ids), dim=1)
866
- with torch.inference_mode():
867
- _, scores, _ = get_reward(
868
- self.reward_model, prompt_completion_ids, self.reward_processing_class.pad_token_id, context_length
869
- )
870
-
871
- # Filter completions. Ensure that each sample contains stop_token_id
872
- # Completions not passing that filter will receive a lower score.
873
- if self.args.missing_eos_penalty is not None:
874
- scores[~contain_eos_token] -= self.args.missing_eos_penalty
875
-
876
- # Split the scores in 2 (the prompts of the first half are the same as the second half)
877
- first_half, second_half = scores.split(batch_size)
878
-
879
- # Get the indices of the chosen and rejected examples
880
- mask = first_half >= second_half
881
-
882
- batch_range = torch.arange(batch_size, device=device)
883
- chosen_indices = batch_range + (~mask * batch_size)
884
- rejected_indices = batch_range + (mask * batch_size)
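# Worked example of the index trick above (illustrative values): with
# batch_size = 2 and mask = [True, False], batch_range = [0, 1], so
# chosen_indices = [0, 3] and rejected_indices = [2, 1] -- whenever the
# second completion wins (mask is False), its copy living in the second
# half of the batch (offset by batch_size) is picked as the chosen one.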
885
-
886
- # Build tensor so that the first half is the chosen examples and the second half the rejected examples
887
- cr_indices = torch.cat((chosen_indices, rejected_indices), dim=0) # cr = chosen and rejected
888
- cr_logprobs = logprobs[cr_indices]
889
- cr_ref_logprobs = ref_logprobs[cr_indices]
890
-
891
- # mask out the padding tokens
892
- padding_mask = ~completion_mask.bool()
893
- cr_padding_mask = padding_mask[cr_indices]
894
-
895
- cr_logprobs_sum = (cr_logprobs * ~cr_padding_mask).sum(1)
896
- cr_ref_logprobs_sum = (cr_ref_logprobs * ~cr_padding_mask).sum(1)
897
-
898
- # Split the chosen and rejected examples
899
- chosen_logprobs_sum, rejected_logprobs_sum = torch.split(cr_logprobs_sum, batch_size)
900
- chosen_ref_logprobs_sum, rejected_ref_logprobs_sum = torch.split(cr_ref_logprobs_sum, batch_size)
901
- pi_logratios = chosen_logprobs_sum - rejected_logprobs_sum
902
- ref_logratios = chosen_ref_logprobs_sum - rejected_ref_logprobs_sum
903
-
904
- logits = pi_logratios - ref_logratios
905
-
906
- if self.args.loss_type == "sigmoid":
907
- losses = -F.logsigmoid(self.beta * logits)
908
- elif self.args.loss_type == "ipo":
909
- losses = (logits - 1 / (2 * self.beta)) ** 2
910
- else:
911
- raise NotImplementedError(f"invalid loss type {self.loss_type}")
912
-
913
- loss = losses.mean()
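# A small numeric sketch of the two preference losses above (beta and the
# log-ratio margins are made-up values for illustration):
import torch
import torch.nn.functional as F

beta = 0.1
logits = torch.tensor([2.0, -1.0])  # pi_logratios - ref_logratios, per pair
sigmoid_loss = -F.logsigmoid(beta * logits).mean()   # DPO-style "sigmoid" loss
ipo_loss = ((logits - 1 / (2 * beta)) ** 2).mean()   # IPO squared-margin loss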
914
-
915
- # Log everything
916
- if self.reward_model is not None:
917
- scores_margin = scores[chosen_indices] - scores[rejected_indices]
918
- self.stats["objective/scores_margin"].append(
919
- self.accelerator.gather_for_metrics(scores_margin.mean()).mean().item()
920
- )
921
- self.stats["objective/scores"].append(self.accelerator.gather_for_metrics(scores.mean()).mean().item())
922
- self.stats["val/contain_eos_token"].append(contain_eos_token.float().mean().item())
923
- self.stats["logps/chosen"].append(self.accelerator.gather_for_metrics(chosen_logprobs_sum).mean().item())
924
- self.stats["logps/rejected"].append(self.accelerator.gather_for_metrics(rejected_logprobs_sum).mean().item())
925
-
926
- kl = logprobs - ref_logprobs
927
- mean_kl = kl.sum(1).mean()
928
- self.stats["objective/kl"].append(self.accelerator.gather_for_metrics(mean_kl).mean().item())
929
- non_score_reward = (-self.beta * kl).sum(1)
930
- mean_non_score_reward = non_score_reward.mean()
931
- self.stats["objective/non_score_reward"].append(
932
- self.accelerator.gather_for_metrics(mean_non_score_reward).mean().item()
933
- )
934
- if self.reward_model is not None:
935
- rlhf_reward = scores + non_score_reward
936
- self.stats["objective/rlhf_reward"].append(self.accelerator.gather_for_metrics(rlhf_reward).mean().item())
937
- mean_entropy = -logprobs.sum(1).mean()
938
- self.stats["objective/entropy"].append(self.accelerator.gather_for_metrics(mean_entropy).mean().item())
939
- chosen_rewards = self.beta * (chosen_logprobs_sum - chosen_ref_logprobs_sum)
940
- gathered_chosen_rewards = self.accelerator.gather_for_metrics(chosen_rewards)
941
- self.stats["rewards/chosen"].append(gathered_chosen_rewards.mean().item())
942
- rejected_rewards = self.beta * (rejected_logprobs_sum - rejected_ref_logprobs_sum)
943
- gathered_rejected_rewards = self.accelerator.gather_for_metrics(rejected_rewards)
944
- self.stats["rewards/rejected"].append(gathered_rejected_rewards.mean().item())
945
- margin = gathered_chosen_rewards - gathered_rejected_rewards
946
- self.stats["rewards/margins"].append(margin.mean().item())
947
- accuracy = margin > 0
948
- self.stats["rewards/accuracies"].append(accuracy.float().mean().item())
949
- self.stats["beta"].append(self.beta)
950
-
951
- if (
952
- self.args.torch_empty_cache_steps is not None
953
- and self.state.global_step % self.args.torch_empty_cache_steps == 0
954
- ):
955
- empty_cache()
956
-
957
- kwargs = {}
958
-
959
- # For LOMO optimizers you need to explicitly use the learning rate
960
- if self.args.optim in [OptimizerNames.LOMO, OptimizerNames.ADALOMO]:
961
- kwargs["learning_rate"] = self._get_learning_rate()
962
-
963
- if self.args.n_gpu > 1:
964
- loss = loss.mean() # mean() to average on multi-gpu parallel training
965
-
966
- if self.use_apex:
967
- with amp.scale_loss(loss, self.optimizer) as scaled_loss:
968
- scaled_loss.backward()
969
- else:
970
- self.accelerator.backward(loss, **kwargs)
971
-
972
- return loss.detach() / self.args.gradient_accumulation_steps
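# Note: the detached loss is divided by gradient_accumulation_steps so the
# values the Trainer accumulates across micro-steps sum to the average loss
# of the full batch.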
973
-
974
- # Same as Trainer._maybe_log_save_evaluate but log our metrics
975
- # start_time defaults to None to allow compatibility with transformers<=4.46
976
- def _maybe_log_save_evaluate(
977
- self, tr_loss, grad_norm, model, trial, epoch, ignore_keys_for_eval, start_time=None, learning_rate=None
978
- ):
979
- if self.control.should_log and self.state.global_step > self._globalstep_last_logged:
980
- logs: dict[str, float] = {}
981
-
982
- # all_gather + mean() to get average loss over all processes
983
- tr_loss_scalar = self._nested_gather(tr_loss).mean().item()
984
-
985
- # reset tr_loss to zero
986
- tr_loss -= tr_loss
987
-
988
- logs["loss"] = round(tr_loss_scalar / (self.state.global_step - self._globalstep_last_logged), 4)
989
- if grad_norm is not None:
990
- logs["grad_norm"] = grad_norm.detach().item() if isinstance(grad_norm, torch.Tensor) else grad_norm
991
- if learning_rate is not None:
992
- logs["learning_rate"] = learning_rate
993
- else:
994
- logs["learning_rate"] = self._get_learning_rate()
995
-
996
- # Add our metrics
997
- for key, val in self.stats.items():
998
- logs[key] = sum(val) / len(val)
999
- self.stats = {key: [] for key in self.stats} # reset stats
1000
-
1001
- self._total_loss_scalar += tr_loss_scalar
1002
- self._globalstep_last_logged = self.state.global_step
1003
- self.store_flos()
1004
-
1005
- if version.parse(transformers.__version__) >= version.parse("4.47.0.dev0"):
1006
- self.log(logs, start_time)
1007
- else: # transformers<=4.46
1008
- self.log(logs)
1009
-
1010
- metrics = None
1011
- if self.control.should_evaluate:
1012
- metrics = self._evaluate(trial, ignore_keys_for_eval)
1013
- is_new_best_metric = self._determine_best_metric(metrics=metrics, trial=trial)
1014
-
1015
- if self.args.save_strategy == "best":
1016
- self.control.should_save = is_new_best_metric
1017
-
1018
- if self.control.should_save:
1019
- self._save_checkpoint(model, trial)
1020
- self.control = self.callback_handler.on_save(self.args, self.state, self.control)
1021
-
1022
- # Copy-pasted from transformers.Trainer to maintain compatibility with earlier versions.
1023
- # This can be removed once the minimum transformers version is updated to 4.47.
1024
- # Refer to https://github.com/huggingface/trl/pull/2288 for more details.
1025
- def _determine_best_metric(self, metrics, trial):
1026
- """
1027
- Determine if the model should be saved based on the evaluation metrics.
1028
- If args.metric_for_best_model is not set, the loss is used.
1029
- Returns:
1030
- bool: True if a new best metric was found, else False
1031
- """
1032
- is_new_best_metric = False
1033
-
1034
- if self.args.metric_for_best_model is not None:
1035
- metric_to_check = self.args.metric_for_best_model
1036
-
1037
- if not metric_to_check.startswith("eval_"):
1038
- metric_to_check = f"eval_{metric_to_check}"
1039
-
1040
- try:
1041
- metric_value = metrics[metric_to_check]
1042
- except KeyError as exc:
1043
- raise KeyError(
1044
- f"The `metric_for_best_model` training argument is set to '{metric_to_check}', which is not found in the evaluation metrics. "
1045
- f"The available evaluation metrics are: {list(metrics.keys())}. Consider changing the `metric_for_best_model` via the TrainingArguments."
1046
- ) from exc
1047
-
1048
- operator = np.greater if self.args.greater_is_better else np.less
1049
-
1050
- if self.state.best_metric is None:
1051
- self.state.best_metric = float("-inf") if self.args.greater_is_better else float("inf")
1052
-
1053
- if operator(metric_value, self.state.best_metric):
1054
- run_dir = self._get_output_dir(trial=trial)
1055
- checkpoint_folder = f"{PREFIX_CHECKPOINT_DIR}-{self.state.global_step}"
1056
- output_dir = os.path.join(run_dir, checkpoint_folder)
1057
- self.state.best_metric = metric_value
1058
- self.state.best_model_checkpoint = output_dir
1059
-
1060
- is_new_best_metric = True
1061
-
1062
- return is_new_best_metric
1063
-
1064
- def create_model_card(
1065
- self,
1066
- model_name: Optional[str] = None,
1067
- dataset_name: Optional[str] = None,
1068
- tags: Union[str, list[str], None] = None,
1069
- ):
1070
- """
1071
- Creates a draft of a model card using the information available to the `Trainer`.
1072
-
1073
- Args:
1074
- model_name (`str` or `None`, *optional*, defaults to `None`):
1075
- Name of the model.
1076
- dataset_name (`str` or `None`, *optional*, defaults to `None`):
1077
- Name of the dataset used for training.
1078
- tags (`str`, `list[str]` or `None`, *optional*, defaults to `None`):
1079
- Tags to be associated with the model card.
1080
- """
1081
- if not self.is_world_process_zero():
1082
- return
1083
-
1084
- if hasattr(self.model.config, "_name_or_path") and not os.path.isdir(self.model.config._name_or_path):
1085
- base_model = self.model.config._name_or_path
1086
- else:
1087
- base_model = None
1088
-
1089
- tags = tags or []
1090
- if isinstance(tags, str):
1091
- tags = [tags]
1092
-
1093
- if hasattr(self.model.config, "unsloth_version"):
1094
- tags.append("unsloth")
1095
-
1096
- citation = textwrap.dedent("""\
1097
- @article{guo2024direct,
1098
- title = {{Direct Language Model Alignment from Online AI Feedback}},
1099
- author = {Shangmin Guo and Biao Zhang and Tianlin Liu and Tianqi Liu and Misha Khalman and Felipe Llinares and Alexandre Ram{\'{e}} and Thomas Mesnard and Yao Zhao and Bilal Piot and Johan Ferret and Mathieu Blondel},
1100
- year = 2024,
1101
- eprint = {arXiv:2402.04792}
1102
- }""")
1103
-
1104
- model_card = generate_model_card(
1105
- base_model=base_model,
1106
- model_name=model_name,
1107
- hub_model_id=self.hub_model_id,
1108
- dataset_name=dataset_name,
1109
- tags=tags,
1110
- wandb_url=wandb.run.get_url() if is_wandb_available() and wandb.run is not None else None,
1111
- comet_url=get_comet_experiment_url(),
1112
- trainer_name="Online DPO",
1113
- trainer_citation=citation,
1114
- paper_title="Direct Language Model Alignment from Online AI Feedback",
1115
- paper_id="2402.04792",
1116
- )
1117
- model_card.save(os.path.join(self.args.output_dir, "README.md"))
1118
- class UnslothOnlineDPOTrainer(_UnslothOnlineDPOTrainer):
1119
- """
1120
-
1121
- Initialize OnlineDPOTrainer.
1122
-
1123
- Args:
1124
- model (`transformers.PreTrainedModel` or `torch.nn.Module`):
1125
- The model to train, preferably an `AutoModelForCausalLM`.
1126
- ref_model (`transformers.PreTrainedModel` or `torch.nn.Module` or `None`):
1127
- The reference model to use for training. If None is specified, the reference model will be created from
1128
- the model.
1129
- reward_model (`transformers.PreTrainedModel` or `torch.nn.Module` or `None`):
1130
- The reward model to score completions with, preferably an `AutoModelForSequenceClassification`.
1131
- judge (`BasePairwiseJudge`):
1132
- The judge to use for pairwise comparison of model completions.
1133
- args (`OnlineDPOConfig`):
1134
- The online DPO config arguments to use for training.
1135
- data_collator (`transformers.DataCollator`):
1136
- The data collator to use for training. If None is specified, the default data collator (`DPODataCollatorWithPadding`) will be used
1137
- which will pad the sequences to the maximum length of the sequences in the batch, given a dataset of paired sequences.
1138
- train_dataset (`datasets.Dataset`):
1139
- The dataset to use for training.
1140
- eval_dataset (`datasets.Dataset`):
1141
- The dataset to use for evaluation.
1142
- processing_class (`PreTrainedTokenizerBase` or `BaseImageProcessor` or `FeatureExtractionMixin` or `ProcessorMixin`, *optional*):
1143
- Processing class used to process the data. If provided, will be used to automatically process the inputs
1144
- for the model, and it will be saved along the model to make it easier to rerun an interrupted training or
1145
- reuse the fine-tuned model.
1146
- peft_config (`dict`):
1147
- The peft config to use for training.
1148
- compute_metrics (`Callable[[EvalPrediction], dict]`, *optional*):
1149
- The function to use to compute the metrics. Must take an `EvalPrediction` and return
1150
- a dictionary mapping metric names (strings) to metric values.
1151
- callbacks (`list[transformers.TrainerCallback]`):
1152
- The callbacks to use for training.
1153
- optimizers (`tuple[torch.optim.Optimizer, torch.optim.lr_scheduler.LambdaLR]`):
1154
- The optimizer and scheduler to use for training.
1155
- preprocess_logits_for_metrics (`Callable[[torch.Tensor, torch.Tensor], torch.Tensor]`):
1156
- The function to use to preprocess the logits before computing the metrics.
1157
-
1158
- """
1159
- def __init__(
1160
- self,
1161
- model,
1162
- ref_model = None,
1163
- reward_model = None,
1164
- judge = None,
1165
- args = None,
1166
- data_collator = None,
1167
- train_dataset = None,
1168
- eval_dataset = None,
1169
- processing_class = None,
1170
- reward_processing_class = None,
1171
- peft_config = None,
1172
- compute_metrics = None,
1173
- callbacks = None,
1174
- preprocess_logits_for_metrics = None,
1175
- **kwargs
1176
- ):
1177
- if args is None: args = UnslothOnlineDPOConfig()
1178
- use_bf16 = getattr(args, 'bf16', False)
1179
- if type(use_bf16) is not bool: use_bf16 = False
1180
- use_fp16 = getattr(args, 'fp16', False)
1181
- if type(use_fp16) is not bool: use_fp16 = False
1182
- force_float32 = False
1183
- if os.environ.get('UNSLOTH_FORCE_FLOAT32', '0') == '1':
1184
- print('Unsloth: Switching to float32 training since model cannot work with float16')
1185
- force_float32 = True
1186
- mixed_precision_dtype = os.environ.get('UNSLOTH_MIXED_PRECISION', 'float32')
1187
- dtype = getattr(model.config, 'torch_dtype', None)
1188
- if dtype is None: dtype = model.get_input_embeddings().dtype
1189
- from unsloth_zoo.utils import _get_dtype
1190
- dtype = _get_dtype(dtype)
1191
- float16 = dtype == torch.float16
1192
- if not force_float32 and (float16 and use_bf16): raise TypeError('Unsloth: Model is in float16 precision but you want to use bfloat16 precision. Set fp16 to `True` and bf16 to `False`')
1193
- if not force_float32 and (not float16 and use_fp16): raise TypeError('Unsloth: Model is in bfloat16 precision but you want to use float16 precision. Set fp16 to `False` and bf16 to `True`')
1194
- if force_float32:
1195
- args.fp16 = False
1196
- args.bf16 = False
1197
- os.environ['ACCELERATE_MIXED_PRECISION'] = 'no'
1198
- elif (not use_bf16 and not use_fp16) and mixed_precision_dtype == 'float32':
1199
- args.fp16 = float16
1200
- args.bf16 = not float16
1201
- os.environ['ACCELERATE_MIXED_PRECISION'] = 'fp16' if float16 else 'bf16'
1202
- if getattr(args, 'eval_dataset', None) is not None and getattr(args, 'eval_strategy', 'no') == 'no':
1203
- args.eval_strategy = 'steps'
1204
- if getattr(args, 'eval_steps', None) is None: args.eval_steps = 0.1
1205
- ga_steps = getattr(args, 'gradient_accumulation_steps', None)
1206
- if ga_steps is not None and ga_steps > 1:
1207
- from transformers import __version__ as transformers_version
1208
- if Version(transformers_version) <= Version('4.45.2'):
1209
- print('**** Unsloth: Please use our fixed gradient_accumulation_steps by updating transformers, TRL and Unsloth!\n'
1210
- '`pip install --upgrade --no-cache-dir --force-reinstall --no-deps unsloth transformers trl unsloth_zoo`')
1211
- if getattr(args, 'eval_strategy', 'no') != 'no':
1212
- eval_bsz = getattr(args, 'per_device_eval_batch_size', 8)
1213
- if eval_bsz == 8 and args.per_device_train_batch_size < eval_bsz: args.per_device_eval_batch_size = args.per_device_train_batch_size
1214
- if getattr(args, 'eval_accumulation_steps', None) is None and ga_steps is not None: args.eval_accumulation_steps = ga_steps
1215
- fp16_full_eval = getattr(args, 'fp16_full_eval', False)
1216
- if type(fp16_full_eval) is not bool: fp16_full_eval = False
1217
- bf16_full_eval = getattr(args, 'bf16_full_eval', False)
1218
- if type(bf16_full_eval) is not bool: bf16_full_eval = False
1219
- if args.fp16 and bf16_full_eval: args.bf16_full_eval = False; args.fp16_full_eval = True
1220
- if args.bf16 and fp16_full_eval: args.bf16_full_eval = True; args.fp16_full_eval = False
1221
- if force_float32:
1222
- args.bf16_full_eval = False
1223
- args.fp16_full_eval = False
1224
- elif os.environ.get('UNSLOTH_MIXED_PRECISION', 'float32') == 'bfloat16':
1225
- args.bf16_full_eval = True
1226
- args.fp16_full_eval = False
1227
- elif not bf16_full_eval and not fp16_full_eval:
1228
- args.bf16_full_eval = args.bf16
1229
- args.fp16_full_eval = args.fp16
1230
- _output_logits = False
1231
- if locals().get('compute_metrics', None) is not None: _output_logits = True
1232
- if locals().get('preprocess_logits_for_metrics', None) is not None: _output_logits = True
1233
- if _output_logits:
1234
- os.environ['UNSLOTH_RETURN_LOGITS'] = '1'
1235
- if 'max_seq_length' not in locals() and not hasattr(args, 'max_seq_length'):
1236
- pass
1237
- else:
1238
- model_max_seq_length = getattr(model, 'max_seq_length', None)
1239
- args_max_seq_length = getattr(args, 'max_seq_length', None)
1240
- if args_max_seq_length is None and model_max_seq_length is not None:
1241
- max_seq_length = model.max_seq_length
1242
- if hasattr(args, 'max_seq_length'): args.max_seq_length = max_seq_length
1243
- if model is not None and hasattr(model, 'for_training'):
1244
- model.for_training()
1245
- if 'tokenizer' in locals() and hasattr(tokenizer, 'padding_side'): tokenizer.padding_side = 'right'
1246
- if 'processing_class' in locals():
1247
- if hasattr(processing_class, 'padding_side'): processing_class.padding_side = 'right'
1248
- if hasattr(processing_class, 'tokenizer') and hasattr(processing_class.tokenizer, 'padding_side'): processing_class.tokenizer.padding_side = 'right'
1249
- __tokenizer = processing_class if 'processing_class' in locals() else tokenizer
1250
- from unsloth_zoo.vision_utils import UnslothVisionDataCollator
1251
- if not isinstance(data_collator, UnslothVisionDataCollator):
1252
- if isinstance(data_collator, DataCollatorForSeq2Seq) and 'labels' not in train_dataset.column_names:
1253
- data_collator = TransformersDataCollatorForLanguageModeling(__tokenizer, mlm = False, mlm_probability = 0.0)
1254
- elif isinstance(data_collator, TransformersDataCollatorForLanguageModeling) and 'labels' in train_dataset.column_names:
1255
- data_collator = DataCollatorForSeq2Seq(__tokenizer)
1256
- else:
1257
- if hasattr(args, 'remove_unused_columns'): args.remove_unused_columns = False
1258
- if hasattr(args, 'dataset_text_field'): args.dataset_text_field = ''
1259
- if hasattr(args, 'dataset_kwargs'): args.dataset_kwargs = {'skip_prepare_dataset': True}
1260
- if not isinstance(data_collator, UnslothVisionDataCollator):
1261
- if not hasattr(__tokenizer, 'pad') and hasattr(__tokenizer, 'tokenizer'):
1262
- if isinstance(data_collator, DataCollatorForSeq2Seq):
1263
- data_collator = DataCollatorForSeq2Seq(__tokenizer.tokenizer)
1264
- else:
1265
- data_collator = TransformersDataCollatorForLanguageModeling(__tokenizer.tokenizer, mlm = False, mlm_probability = 0.0)
1266
- other_metrics = []
1267
-
1268
- from unsloth_zoo.logging_utils import PatchRLStatistics
1269
- PatchRLStatistics('online_dpo_trainer', other_metrics)
1270
-
1271
- super().__init__(
1272
- model = model,
1273
- ref_model = ref_model,
1274
- reward_model = reward_model,
1275
- judge = judge,
1276
- args = args,
1277
- data_collator = data_collator,
1278
- train_dataset = train_dataset,
1279
- eval_dataset = eval_dataset,
1280
- processing_class = processing_class,
1281
- reward_processing_class = reward_processing_class,
1282
- peft_config = peft_config,
1283
- compute_metrics = compute_metrics,
1284
- callbacks = callbacks,
1285
- preprocess_logits_for_metrics = preprocess_logits_for_metrics,**kwargs)
1286
- if hasattr(self, 'neftune_hook_handle'):
1287
- self.neftune_hook_handle.remove()
1288
- if hasattr(self, 'neftune_hook_handle'): del self.neftune_hook_handle
1289
- if getattr(args, 'neftune_noise_alpha', None) is not None:
1290
- model.get_input_embeddings().neftune_noise_alpha = self.neftune_noise_alpha
1291
- pass
1292
-
1293
- pass
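# A hypothetical minimal usage sketch for the trainer defined above; the
# model/dataset names and the judge are placeholders, not part of this file.
from datasets import load_dataset
from transformers import AutoModelForCausalLM, AutoTokenizer

model = AutoModelForCausalLM.from_pretrained("my-org/my-sft-model")   # placeholder
tokenizer = AutoTokenizer.from_pretrained("my-org/my-sft-model")      # placeholder
prompts = load_dataset("my-org/my-prompt-dataset", split="train")     # placeholder

trainer = UnslothOnlineDPOTrainer(
    model = model,
    judge = my_pairwise_judge,  # any TRL BasePairwiseJudge; hypothetical
    args = UnslothOnlineDPOConfig(output_dir = "outputs"),
    train_dataset = prompts,
    processing_class = tokenizer,
)
trainer.train()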
test_run_uploads/UnslothPPOTrainer.py DELETED
@@ -1,1273 +0,0 @@
1
- """
2
- 2025.7.11
3
- 2025.7.11
4
- 4.54.1
5
- 0.16.1
6
- __UNSLOTH_VERSIONING__
7
- """
8
- from torch import Tensor
9
- import torch
10
- import torch.nn as nn
11
- from torch.nn import functional as F
12
- from typing import Any, List, Optional, Tuple, Union, Dict, Set, Callable
13
- from trl.trainer.ppo_trainer import (Accelerator, BaseImageProcessor, CallbackHandler, DEFAULT_CALLBACKS, DEFAULT_PROGRESS_CALLBACK, DataCollatorWithPadding, DataLoader, Dataset, ExportableState, FeatureExtractionMixin, GenerationConfig, INVALID_LOGPROB, OnlineTrainerState, Optional, PPOConfig, PPOTrainer, PeftConfig, PeftModel, PolicyAndValueWrapper, PreTrainedTokenizerBase, PrinterCallback, ProcessorMixin, Trainer, TrainerCallback, TrainerControl, Union, batch_generation, broadcast, contextmanager, create_reference_model, defaultdict, disable_dropout_in_model, exact_div, first_true_indices, forward, gather_object, gc, generate_model_card, get_comet_experiment_url, get_peft_model, get_reporting_integration_callbacks, get_reward, is_peft_available, is_wandb_available, log_table_to_comet_experiment, masked_mean, masked_whiten, math, nn, np, nullcontext, os, pd, peft_module_casting_to_bf16, prepare_deepspeed, print_rich_table, selective_log_softmax, textwrap, time, torch, truncate_response, unwrap_model_for_generation, wandb)
14
-
15
-
16
- import os
17
- from typing import *
18
- from dataclasses import dataclass, field
19
- from packaging.version import Version
20
- import torch
21
- import numpy as np
22
- from contextlib import nullcontext
23
- from torch.nn import functional as F
24
- from transformers import DataCollatorForSeq2Seq, DataCollatorForLanguageModeling as TransformersDataCollatorForLanguageModeling
25
-
26
- torch_compile_options = {
27
- "epilogue_fusion" : True,
28
- "max_autotune" : False,
29
- "shape_padding" : True,
30
- "trace.enabled" : False,
31
- "triton.cudagraphs" : False,
32
- }
33
-
34
- @torch.compile(dynamic = True, fullgraph = True, options = torch_compile_options,)
35
- def chunked_selective_log_softmax(logits, index):
36
- # Split into 4 chunks only
37
- chunked_logits = torch.chunk(logits.reshape(-1, logits.shape[-1]), chunks = 4, dim = 0)
38
- chunked_index = torch.chunk(index.reshape(-1), chunks = 4, dim = 0)
39
- all_per_token_logps = []
40
- # Below loop does the same as selective_log_softmax(chunk_logits, chunk_index)
41
- for chunk_logits, chunk_index in zip(chunked_logits, chunked_index):
42
- chunk_logits = chunk_logits.to(torch.float32)
43
- selected_logits = torch.gather(chunk_logits, dim = -1, index = chunk_index.unsqueeze(-1)).squeeze(-1)
44
- logsumexp_values = torch.logsumexp(chunk_logits, dim = -1)
45
- per_token_logps = selected_logits - logsumexp_values
46
- all_per_token_logps.append(per_token_logps)
47
- pass
48
- all_per_token_logps = torch.concat(all_per_token_logps)
49
- all_per_token_logps = all_per_token_logps.reshape((logits.shape[0], logits.shape[1]))
50
- return all_per_token_logps
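# A self-contained reference check (toy shapes) for the chunked routine above:
# a direct float32 log-softmax gather should agree with
# chunked_selective_log_softmax(logits, index) up to float32 rounding.
import torch

logits = torch.randn(2, 8, 16)            # (batch, seq, vocab); toy sizes
index = torch.randint(16, (2, 8))
direct = torch.gather(
    logits.float().log_softmax(dim=-1), -1, index.unsqueeze(-1)
).squeeze(-1)                             # shape (2, 8), matches the chunked output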
51
- @dataclass
52
- class UnslothPPOConfig(PPOConfig):
53
- """
54
-
55
- Configuration class for the [`PPOTrainer`].
56
-
57
- Using [`~transformers.HfArgumentParser`] we can turn this class into
58
- [argparse](https://docs.python.org/3/library/argparse#module-argparse) arguments that can be specified on the
59
- command line.
60
-
61
- Parameters:
62
- exp_name (`str`, *optional*, defaults to `os.path.basename(__file__)[:-3]`):
63
- Name of this experiment.
64
- reward_model_path (`str`, *optional*, defaults to `"EleutherAI/pythia-160m"`):
65
- Path to the reward model.
66
- model_adapter_name (`str` or `None`, *optional*, defaults to `None`):
67
- Name of the train target PEFT adapter, when using LoRA with multiple adapters.
68
- ref_adapter_name (`str` or `None`, *optional*, defaults to `None`):
69
- Name of the reference PEFT adapter, when using LoRA with multiple adapters.
70
- num_ppo_epochs (`int`, *optional*, defaults to `4`):
71
- Number of epochs to train.
72
- whiten_rewards (`bool`, *optional*, defaults to `False`):
73
- Whether to whiten the rewards.
74
- kl_coef (`float`, *optional*, defaults to `0.05`):
75
- KL coefficient.
76
- cliprange (`float`, *optional*, defaults to `0.2`):
77
- Clip range.
78
- vf_coef (`float`, *optional*, defaults to `0.1`):
79
- Value function coefficient.
80
- cliprange_value (`float`, *optional*, defaults to `0.2`):
81
- Clip range for the value function.
82
- gamma (`float`, *optional*, defaults to `1.0`):
83
- Discount factor.
84
- lam (`float`, *optional*, defaults to `0.95`):
85
- Lambda value for GAE.
86
- ds3_gather_for_generation (`bool`, *optional*, defaults to `True`):
87
- This setting applies to DeepSpeed ZeRO-3. If enabled, the policy model weights are gathered for generation,
88
- improving generation speed. However, disabling this option allows training models that exceed the VRAM
89
- capacity of a single GPU, albeit at the cost of slower generation.
90
-
91
- """
92
- vllm_sampling_params: Optional[Any] = field(
93
- default = None,
94
- metadata = {'help': 'vLLM SamplingParams'},
95
- )
96
- unsloth_num_chunks : Optional[int] = field(
97
- default = -1,
98
- metadata = {'help': 'Chunk size to reduce memory usage. -1 is most efficient.'},
99
- )
100
- def __init__(
101
- self,
102
- output_dir = None,
103
- overwrite_output_dir = None,
104
- do_train = False,
105
- do_eval = False,
106
- do_predict = False,
107
- eval_strategy = 'no',
108
- prediction_loss_only = False,
109
- per_device_train_batch_size = 4,
110
- per_device_eval_batch_size = 4,
111
- per_gpu_train_batch_size = None,
112
- per_gpu_eval_batch_size = None,
113
- gradient_accumulation_steps = 2,
114
- eval_accumulation_steps = 2,
115
- eval_delay = 0,
116
- torch_empty_cache_steps = 250,
117
- learning_rate = 5e-05,
118
- weight_decay = 0.01,
119
- adam_beta1 = 0.9,
120
- adam_beta2 = 0.999,
121
- adam_epsilon = 1e-08,
122
- max_grad_norm = 1.0,
123
- num_train_epochs = 3.0,
124
- max_steps = -1,
125
- lr_scheduler_type = 'linear',
126
- warmup_ratio = 0.1,
127
- warmup_steps = 0,
128
- log_level = 'passive',
129
- log_level_replica = 'warning',
130
- log_on_each_node = True,
131
- logging_dir = None,
132
- logging_strategy = 'steps',
133
- logging_first_step = False,
134
- logging_steps = 1,
135
- logging_nan_inf_filter = False,
136
- save_strategy = 'steps',
137
- save_steps = 500,
138
- save_total_limit = None,
139
- save_safetensors = True,
140
- save_on_each_node = False,
141
- save_only_model = False,
142
- restore_callback_states_from_checkpoint = False,
143
- no_cuda = False,
144
- use_cpu = False,
145
- use_mps_device = False,
146
- seed = 3407,
147
- data_seed = 3407,
148
- jit_mode_eval = False,
149
- use_ipex = False,
150
- bf16 = False,
151
- fp16 = False,
152
- fp16_opt_level = 'O1',
153
- half_precision_backend = 'auto',
154
- bf16_full_eval = False,
155
- fp16_full_eval = False,
156
- tf32 = None,
157
- local_rank = -1,
158
- ddp_backend = None,
159
- tpu_num_cores = None,
160
- tpu_metrics_debug = False,
161
- debug = '',
162
- dataloader_drop_last = False,
163
- eval_steps = None,
164
- dataloader_num_workers = 0,
165
- dataloader_prefetch_factor = None,
166
- past_index = -1,
167
- run_name = None,
168
- disable_tqdm = None,
169
- remove_unused_columns = True,
170
- label_names = None,
171
- load_best_model_at_end = False,
172
- metric_for_best_model = None,
173
- greater_is_better = None,
174
- ignore_data_skip = False,
175
- fsdp = '',
176
- fsdp_min_num_params = 0,
177
- fsdp_config = None,
178
- fsdp_transformer_layer_cls_to_wrap = None,
179
- accelerator_config = None,
180
- deepspeed = None,
181
- label_smoothing_factor = 0.0,
182
- optim = 'adamw_8bit',
183
- optim_args = None,
184
- adafactor = False,
185
- group_by_length = False,
186
- length_column_name = 'length',
187
- report_to = None,
188
- ddp_find_unused_parameters = None,
189
- ddp_bucket_cap_mb = None,
190
- ddp_broadcast_buffers = None,
191
- dataloader_pin_memory = True,
192
- dataloader_persistent_workers = False,
193
- skip_memory_metrics = True,
194
- use_legacy_prediction_loop = False,
195
- push_to_hub = False,
196
- resume_from_checkpoint = None,
197
- hub_model_id = None,
198
- hub_strategy = 'every_save',
199
- hub_token = None,
200
- hub_private_repo = None,
201
- hub_always_push = False,
202
- hub_revision = None,
203
- gradient_checkpointing = False,
204
- gradient_checkpointing_kwargs = None,
205
- include_inputs_for_metrics = False,
206
- eval_do_concat_batches = True,
207
- fp16_backend = 'auto',
208
- push_to_hub_model_id = None,
209
- push_to_hub_organization = None,
210
- push_to_hub_token = None,
211
- mp_parameters = '',
212
- auto_find_batch_size = True,
213
- full_determinism = False,
214
- torchdynamo = None,
215
- ray_scope = 'last',
216
- ddp_timeout = 1800,
217
- torch_compile = False,
218
- torch_compile_backend = None,
219
- torch_compile_mode = None,
220
- include_tokens_per_second = False,
221
- include_num_input_tokens_seen = False,
222
- neftune_noise_alpha = None,
223
- optim_target_modules = None,
224
- batch_eval_metrics = False,
225
- eval_on_start = False,
226
- use_liger_kernel = False,
227
- liger_kernel_config = None,
228
- eval_use_gather_object = False,
229
- average_tokens_across_devices = True,
230
- dataset_num_proc = None,
231
- num_mini_batches = 1,
232
- total_episodes = None,
233
- local_rollout_forward_batch_size = 64,
234
- num_sample_generations = 10,
235
- response_length = 53,
236
- stop_token = None,
237
- stop_token_id = None,
238
- temperature = 0.7,
239
- missing_eos_penalty = None,
240
- sft_model_path = 'EleutherAI/pythia-160m',
241
- world_size = None,
242
- num_total_batches = None,
243
- micro_batch_size = None,
244
- local_batch_size = None,
245
- batch_size = None,
246
- local_mini_batch_size = None,
247
- mini_batch_size = None,
248
- exp_name = 'ppo_config',
249
- reward_model_path = 'EleutherAI/pythia-160m',
250
- model_adapter_name = None,
251
- ref_adapter_name = None,
252
- num_ppo_epochs = 4,
253
- whiten_rewards = False,
254
- kl_coef = 0.05,
255
- cliprange = 0.2,
256
- vf_coef = 0.1,
257
- cliprange_value = 0.2,
258
- gamma = 1.0,
259
- lam = 0.95,
260
- ds3_gather_for_generation = True,
261
- vllm_sampling_params = None,
262
- unsloth_num_chunks = -1,
263
- **kwargs,
264
- ):
265
- if learning_rate < 1e-7: raise FloatingPointError(f'Unsloth: Your learning rate of `{learning_rate}` is too small (less than 1e-7)! Consider increasing it, otherwise gradient updates will be close to 0!')
266
- if learning_rate > 1: raise OverflowError(f'Unsloth: Your learning rate of `{learning_rate}` is far too large (> 1)! Consider decreasing it to 1e-1, otherwise gradient updates will explode!')
267
- if output_dir is None and save_strategy == 'steps' and save_steps == 500:
268
- output_dir = 'unsloth_training_checkpoints'
269
- save_strategy = 'no'
270
- if dataset_num_proc is None:
271
- from multiprocessing import cpu_count
272
- dataset_num_proc = min(cpu_count()*2, 2)
273
- if temperature <= 0:
274
- raise ValueError('Unsloth: Please set a positive non-zero temperature since your results will be wrong.')
275
- elif temperature >= 10:
276
- raise ValueError('Unsloth: Please set a positive non-zero temperature less than 10, since sampling will be quite erratic.')
277
-
278
-
279
- super().__init__(
280
- output_dir = output_dir,
281
- overwrite_output_dir = overwrite_output_dir,
282
- do_train = do_train,
283
- do_eval = do_eval,
284
- do_predict = do_predict,
285
- eval_strategy = eval_strategy,
286
- prediction_loss_only = prediction_loss_only,
287
- per_device_train_batch_size = per_device_train_batch_size,
288
- per_device_eval_batch_size = per_device_eval_batch_size,
289
- per_gpu_train_batch_size = per_gpu_train_batch_size,
290
- per_gpu_eval_batch_size = per_gpu_eval_batch_size,
291
- gradient_accumulation_steps = gradient_accumulation_steps,
292
- eval_accumulation_steps = eval_accumulation_steps,
293
- eval_delay = eval_delay,
294
- torch_empty_cache_steps = torch_empty_cache_steps,
295
- learning_rate = learning_rate,
296
- weight_decay = weight_decay,
297
- adam_beta1 = adam_beta1,
298
- adam_beta2 = adam_beta2,
299
- adam_epsilon = adam_epsilon,
300
- max_grad_norm = max_grad_norm,
301
- num_train_epochs = num_train_epochs,
302
- max_steps = max_steps,
303
- lr_scheduler_type = lr_scheduler_type,
304
- warmup_ratio = warmup_ratio,
305
- warmup_steps = warmup_steps,
306
- log_level = log_level,
307
- log_level_replica = log_level_replica,
308
- log_on_each_node = log_on_each_node,
309
- logging_dir = logging_dir,
310
- logging_strategy = logging_strategy,
311
- logging_first_step = logging_first_step,
312
- logging_steps = logging_steps,
313
- logging_nan_inf_filter = logging_nan_inf_filter,
314
- save_strategy = save_strategy,
315
- save_steps = save_steps,
316
- save_total_limit = save_total_limit,
317
- save_safetensors = save_safetensors,
318
- save_on_each_node = save_on_each_node,
319
- save_only_model = save_only_model,
320
- restore_callback_states_from_checkpoint = restore_callback_states_from_checkpoint,
321
- no_cuda = no_cuda,
322
- use_cpu = use_cpu,
323
- use_mps_device = use_mps_device,
324
- seed = seed,
325
- data_seed = data_seed,
326
- jit_mode_eval = jit_mode_eval,
327
- use_ipex = use_ipex,
328
- bf16 = bf16,
329
- fp16 = fp16,
330
- fp16_opt_level = fp16_opt_level,
331
- half_precision_backend = half_precision_backend,
332
- bf16_full_eval = bf16_full_eval,
333
- fp16_full_eval = fp16_full_eval,
334
- tf32 = tf32,
335
- local_rank = local_rank,
336
- ddp_backend = ddp_backend,
337
- tpu_num_cores = tpu_num_cores,
338
- tpu_metrics_debug = tpu_metrics_debug,
339
- debug = debug,
340
- dataloader_drop_last = dataloader_drop_last,
341
- eval_steps = eval_steps,
342
- dataloader_num_workers = dataloader_num_workers,
343
- dataloader_prefetch_factor = dataloader_prefetch_factor,
344
- past_index = past_index,
345
- run_name = run_name,
346
- disable_tqdm = disable_tqdm,
347
- remove_unused_columns = remove_unused_columns,
348
- label_names = label_names,
349
- load_best_model_at_end = load_best_model_at_end,
350
- metric_for_best_model = metric_for_best_model,
351
- greater_is_better = greater_is_better,
352
- ignore_data_skip = ignore_data_skip,
353
- fsdp = fsdp,
354
- fsdp_min_num_params = fsdp_min_num_params,
355
- fsdp_config = fsdp_config,
356
- fsdp_transformer_layer_cls_to_wrap = fsdp_transformer_layer_cls_to_wrap,
357
- accelerator_config = accelerator_config,
358
- deepspeed = deepspeed,
359
- label_smoothing_factor = label_smoothing_factor,
360
- optim = optim,
361
- optim_args = optim_args,
362
- adafactor = adafactor,
363
- group_by_length = group_by_length,
364
- length_column_name = length_column_name,
365
- report_to = report_to,
366
- ddp_find_unused_parameters = ddp_find_unused_parameters,
367
- ddp_bucket_cap_mb = ddp_bucket_cap_mb,
368
- ddp_broadcast_buffers = ddp_broadcast_buffers,
369
- dataloader_pin_memory = dataloader_pin_memory,
370
- dataloader_persistent_workers = dataloader_persistent_workers,
371
- skip_memory_metrics = skip_memory_metrics,
372
- use_legacy_prediction_loop = use_legacy_prediction_loop,
373
- push_to_hub = push_to_hub,
374
- resume_from_checkpoint = resume_from_checkpoint,
375
- hub_model_id = hub_model_id,
376
- hub_strategy = hub_strategy,
377
- hub_token = hub_token,
378
- hub_private_repo = hub_private_repo,
379
- hub_always_push = hub_always_push,
380
- hub_revision = hub_revision,
381
- gradient_checkpointing = gradient_checkpointing,
382
- gradient_checkpointing_kwargs = gradient_checkpointing_kwargs,
383
- include_inputs_for_metrics = include_inputs_for_metrics,
384
- eval_do_concat_batches = eval_do_concat_batches,
385
- fp16_backend = fp16_backend,
386
- push_to_hub_model_id = push_to_hub_model_id,
387
- push_to_hub_organization = push_to_hub_organization,
388
- push_to_hub_token = push_to_hub_token,
389
- mp_parameters = mp_parameters,
390
- auto_find_batch_size = auto_find_batch_size,
391
- full_determinism = full_determinism,
392
- torchdynamo = torchdynamo,
393
- ray_scope = ray_scope,
394
- ddp_timeout = ddp_timeout,
395
- torch_compile = torch_compile,
396
- torch_compile_backend = torch_compile_backend,
397
- torch_compile_mode = torch_compile_mode,
398
- include_tokens_per_second = include_tokens_per_second,
399
- include_num_input_tokens_seen = include_num_input_tokens_seen,
400
- neftune_noise_alpha = neftune_noise_alpha,
401
- optim_target_modules = optim_target_modules,
402
- batch_eval_metrics = batch_eval_metrics,
403
- eval_on_start = eval_on_start,
404
- use_liger_kernel = use_liger_kernel,
405
- liger_kernel_config = liger_kernel_config,
406
- eval_use_gather_object = eval_use_gather_object,
407
- average_tokens_across_devices = average_tokens_across_devices,
408
- dataset_num_proc = dataset_num_proc,
409
- num_mini_batches = num_mini_batches,
410
- total_episodes = total_episodes,
411
- local_rollout_forward_batch_size = local_rollout_forward_batch_size,
412
- num_sample_generations = num_sample_generations,
413
- response_length = response_length,
414
- stop_token = stop_token,
415
- stop_token_id = stop_token_id,
416
- temperature = temperature,
417
- missing_eos_penalty = missing_eos_penalty,
418
- sft_model_path = sft_model_path,
419
- world_size = world_size,
420
- num_total_batches = num_total_batches,
421
- micro_batch_size = micro_batch_size,
422
- local_batch_size = local_batch_size,
423
- batch_size = batch_size,
424
- local_mini_batch_size = local_mini_batch_size,
425
- mini_batch_size = mini_batch_size,
426
- exp_name = exp_name,
427
- reward_model_path = reward_model_path,
428
- model_adapter_name = model_adapter_name,
429
- ref_adapter_name = ref_adapter_name,
430
- num_ppo_epochs = num_ppo_epochs,
431
- whiten_rewards = whiten_rewards,
432
- kl_coef = kl_coef,
433
- cliprange = cliprange,
434
- vf_coef = vf_coef,
435
- cliprange_value = cliprange_value,
436
- gamma = gamma,
437
- lam = lam,
438
- ds3_gather_for_generation = ds3_gather_for_generation,**kwargs)
439
- self.vllm_sampling_params = vllm_sampling_params
440
- self.unsloth_num_chunks = unsloth_num_chunks
441
- pass
442
-
443
- class _UnslothPPOTrainer(Trainer):
444
- _tag_names = ["trl", "ppo"]
445
-
446
- def __init__(
447
- self,
448
- args: PPOConfig,
449
- processing_class: Optional[
450
- Union[PreTrainedTokenizerBase, BaseImageProcessor, FeatureExtractionMixin, ProcessorMixin]
451
- ],
452
- model: nn.Module,
453
- ref_model: Optional[nn.Module],
454
- reward_model: nn.Module,
455
- train_dataset: Dataset,
456
- value_model: Optional[nn.Module] = None,
457
- data_collator: Optional[DataCollatorWithPadding] = None,
458
- eval_dataset: Optional[Union[Dataset, dict[str, Dataset]]] = None,
459
- # less commonly used
460
- optimizers: tuple[torch.optim.Optimizer, torch.optim.lr_scheduler.LambdaLR] = (None, None),
461
- callbacks: Optional[list[TrainerCallback]] = None,
462
- peft_config: Optional["PeftConfig"] = None,
463
- ) -> None:
464
- if ref_model is model:
465
- raise ValueError(
466
- "`model` and `ref_model` cannot be the same object. If you want `ref_model` to be the "
467
- "same as `model`, you must make a copy of it, or `None` if you use peft."
468
- )
469
-
470
- self.args = args
471
- self.processing_class = processing_class
472
- self.policy_model = model
473
-
474
- # Define the collator if not provided
475
- if data_collator is None:
476
- data_collator = DataCollatorWithPadding(self.processing_class)
477
-
478
- # Handle stop token settings: update policy model's generation_config to use provided stop token
479
- if args.stop_token and args.stop_token_id:
480
- raise ValueError("You cannot set both `stop_token` and `stop_token_id`.")
481
- elif args.stop_token:
482
- if args.stop_token == "eos":
483
- self.policy_model.generation_config.eos_token_id = self.stop_token_id = processing_class.eos_token_id
484
- else:
485
- raise ValueError(
486
- f"Unknown `stop_token` {args.stop_token}. Allowed values are: `'eos'` and `None` (no stop token)."
487
- )
488
- else:
489
- self.policy_model.generation_config.eos_token_id = self.stop_token_id = args.stop_token_id # None or int
490
-
491
- # peft support
492
- if not is_peft_available() and peft_config is not None:
493
- raise ImportError(
494
- "PEFT is not installed and you passed a `peft_config` in the trainer's kwargs, please install it to use the PEFT models"
495
- )
496
- elif is_peft_available() and peft_config is not None:
497
- # if the model is a PEFT model and we have a peft_config, we merge and unload it first
498
- if isinstance(self.policy_model, PeftModel):
499
- self.policy_model = self.policy_model.merge_and_unload()
500
-
501
- # get peft model with the given config
502
- self.policy_model = get_peft_model(self.policy_model, peft_config)
503
- if args.bf16 and getattr(self.policy_model, "is_loaded_in_4bit", False):
504
- peft_module_casting_to_bf16(self.policy_model)
505
-
506
- self.is_peft_model = is_peft_available() and isinstance(self.policy_model, PeftModel)
507
- self.model_adapter_name = args.model_adapter_name
508
- self.ref_adapter_name = args.ref_adapter_name
509
-
510
- if ref_model:
511
- self.ref_model = ref_model
512
- elif self.is_peft_model:
513
- self.ref_model = None
514
- else:
515
- self.ref_model = create_reference_model(self.policy_model)
516
-
517
- self.reward_model = reward_model
518
- self.train_dataset = train_dataset
519
- self.train_dataset_len = len(train_dataset)
520
- self.value_model = value_model
521
- self.data_collator = data_collator
522
- self.eval_dataset = eval_dataset
523
- self.optimizer, self.lr_scheduler = optimizers
524
- self.optimizer_cls_and_kwargs = None # needed for transformers >= 4.47
525
-
526
- #########
527
- # calculate various batch sizes
528
- #########
529
- if args.total_episodes is None: # allow the users to define episodes in terms of epochs.
530
- args.total_episodes = int(args.num_train_epochs * self.train_dataset_len)
531
- accelerator = Accelerator(gradient_accumulation_steps=args.gradient_accumulation_steps)
532
- self.accelerator = accelerator
533
- args.world_size = accelerator.num_processes
534
- args.local_batch_size = (
535
- args.per_device_train_batch_size * args.gradient_accumulation_steps * args.num_mini_batches
536
- )
537
- args.micro_batch_size = int(args.per_device_train_batch_size * args.world_size)
538
- args.batch_size = int(args.local_batch_size * args.world_size)
539
- args.mini_batch_size = exact_div(
540
- args.batch_size, args.num_mini_batches, "`batch_size` must be a multiple of `num_mini_batches`"
541
- )
542
- args.local_mini_batch_size = exact_div(
543
- args.local_batch_size, args.num_mini_batches, "`local_batch_size` must be a multiple of `num_mini_batches`"
544
- )
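# Worked example of the batch-size bookkeeping above (illustrative numbers):
# per_device_train_batch_size = 4, gradient_accumulation_steps = 2,
# num_mini_batches = 2, world_size = 2 gives
#   local_batch_size      = 4 * 2 * 2 = 16
#   micro_batch_size      = 4 * 2     = 8
#   batch_size            = 16 * 2    = 32
#   mini_batch_size       = 32 / 2    = 16
#   local_mini_batch_size = 16 / 2    = 8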
545
- if args.whiten_rewards:
546
- assert args.local_mini_batch_size >= 8, (
547
- f"Per-rank minibatch size {args.local_mini_batch_size} is insufficient for whitening"
548
- )
549
- # `per_rank_rollout_batch_size` is our `args.local_batch_size`
550
- # `per_rank_minibatch_size` is our `args.local_mini_batch_size`
551
- args.num_total_batches = math.ceil(
552
- args.total_episodes / args.batch_size
553
- ) # we may train for more than `total_episodes`
554
- time_tensor = torch.tensor(int(time.time()), device=accelerator.device)
555
- time_int = broadcast(time_tensor, 0).item() # avoid different timestamps across processes
556
- args.run_name = f"{args.exp_name}__{args.seed}__{time_int}"
557
- self.local_seed = args.seed + accelerator.process_index * 100003 # Prime
558
- if args.num_sample_generations > 0:
559
- self.sample_generations_freq = max(1, args.num_total_batches // args.num_sample_generations)
560
- self.local_dataloader_batch_size = args.local_batch_size
561
-
562
- #########
563
- # setup model, optimizer, and others
564
- #########
565
- for module in [self.policy_model, self.ref_model, self.value_model, self.reward_model]:
566
- if module is not None:
567
- disable_dropout_in_model(module)
568
- self.model = PolicyAndValueWrapper(self.policy_model, self.value_model)
569
- self.model.config = self.policy_model.config # needed for pushing to hub
570
- self.create_optimizer_and_scheduler(
571
- num_training_steps=args.num_total_batches
572
- ) # note that we are calling `self.lr_scheduler.step[]` manually only at the batch level
573
-
574
- #########
575
- ### trainer specifics
576
- #########
577
- default_callbacks = DEFAULT_CALLBACKS + get_reporting_integration_callbacks(self.args.report_to)
578
- self.callbacks = default_callbacks if callbacks is None else default_callbacks + callbacks
579
- self.callback_handler = CallbackHandler(
580
- self.callbacks, self.model, self.processing_class, self.optimizer, self.lr_scheduler
581
- )
582
- self.add_callback(PrinterCallback if self.args.disable_tqdm else DEFAULT_PROGRESS_CALLBACK)
583
- self.control = TrainerControl()
584
- self.state = OnlineTrainerState(
585
- is_local_process_zero=self.is_local_process_zero(),
586
- is_world_process_zero=self.is_world_process_zero(),
587
- stateful_callbacks=[
588
- cb for cb in self.callback_handler.callbacks + [self.control] if isinstance(cb, ExportableState)
589
- ],
590
- )
591
- self.current_flos = 0
592
- self.hp_search_backend = None
593
- self.is_deepspeed_enabled = getattr(self.accelerator.state, "deepspeed_plugin", None) is not None
594
- self.is_fsdp_enabled = getattr(self.accelerator.state, "fsdp_plugin", None) is not None
595
- # Create distant repo and output directory if needed
596
- self.hub_model_id = None
597
- if self.args.push_to_hub:
598
- self.init_hf_repo()
599
- if self.args.should_save:
600
- os.makedirs(self.args.output_dir, exist_ok=True)
601
-
602
- # Add tags for models that have been loaded with the correct transformers version
603
- if hasattr(self.model, "add_model_tags"):
604
- self.model.add_model_tags(self._tag_names)
605
-
606
- #########
607
- ### setup dataloader
608
- #########
609
- self.dataloader = DataLoader(
610
- self.train_dataset,
611
- batch_size=self.local_dataloader_batch_size,
612
- shuffle=True,
613
- collate_fn=self.data_collator,
614
- drop_last=True, # needed; otherwise the last batch will be of ragged shape
615
- )
616
- # sync random states for DataLoader[shuffle=True] before `accelerator.prepare`
617
- # see https://gist.github.com/vwxyzjn/2581bff1e48e185e0b85b6dfe1def79c
618
- torch.manual_seed(args.seed)
619
- self.model, self.optimizer, self.dataloader = accelerator.prepare(self.model, self.optimizer, self.dataloader)
620
- torch.manual_seed(self.local_seed) # reset the local seed again
621
-
622
- self.eval_dataloader = DataLoader(
623
- self.eval_dataset,
624
- batch_size=args.per_device_eval_batch_size,
625
- collate_fn=self.data_collator,
626
- drop_last=True,
627
- ) # no need to shuffle eval dataset
628
- self.eval_dataloader = accelerator.prepare(self.eval_dataloader)
629
-
630
- if self.is_deepspeed_enabled:
631
- self.reward_model = prepare_deepspeed(
632
- self.reward_model, args.per_device_train_batch_size, args.fp16, args.bf16
633
- )
634
-
635
- if self.ref_model is None:
636
- if not self.is_peft_model:
637
- raise ValueError("No reference model and model is not a Peft model.")
638
- else:
639
- self.ref_model = prepare_deepspeed(
640
- self.ref_model, args.per_device_train_batch_size, args.fp16, args.bf16
641
- )
642
- else:
643
- if self.ref_model is None:
644
- if not self.is_peft_model:
645
- raise ValueError("No reference model and model is not a Peft model.")
646
- else:
647
- self.ref_model = self.ref_model.to(self.accelerator.device)
648
- self.reward_model = self.reward_model.to(self.accelerator.device)
649
-
650
- def get_train_dataloader(self) -> DataLoader:
651
- return self.dataloader
652
-
653
- def get_eval_dataloader(self) -> DataLoader:
654
- return self.eval_dataloader
655
-
656
- @contextmanager
657
- def null_ref_context(self):
658
- """Context manager for handling null reference model (that is, peft adapter manipulation)."""
659
- with (
660
- self.accelerator.unwrap_model(self.model.policy).disable_adapter()
661
- if self.is_peft_model and not self.ref_adapter_name
662
- else nullcontext()
663
- ):
664
- if self.ref_adapter_name:
665
- self.model.policy.set_adapter(self.ref_adapter_name)
666
- yield
667
-            if self.ref_adapter_name:
-                self.model.policy.set_adapter(self.model_adapter_name or "default")
-
-    def save_model(self, output_dir: Optional[str] = None, _internal_call: bool = False):
-        backup_model = self.model
-        self.model = self.model.policy  # save only the policy
-
-        if self.is_deepspeed_enabled:
-            backup_deepspeed = self.deepspeed
-            self.deepspeed = self.model
-
-        super().save_model(output_dir, _internal_call)
-
-        self.model = backup_model
-
-        if self.is_deepspeed_enabled:
-            self.deepspeed = backup_deepspeed
-
-    def train(self):
-        args = self.args
-        accelerator = self.accelerator
-        optimizer = self.optimizer
-        model = self.model
-        ref_policy = self.ref_model
-        reward_model = self.reward_model
-        processing_class = self.processing_class
-        dataloader = self.dataloader
-        device = accelerator.device
-
-        def repeat_generator():
-            while True:
-                yield from dataloader
-
-        iter_dataloader = iter(repeat_generator())
-        generation_config = GenerationConfig(
-            max_new_tokens=args.response_length,
-            temperature=(args.temperature + 1e-7),
-            top_k=0.0,
-            top_p=1.0,
-            do_sample=True,
-        )
-
-        accelerator.print("===training policy===")
-        start_time = time.time()
-        stats_shape = (args.num_ppo_epochs, args.num_mini_batches, args.gradient_accumulation_steps)
-        approxkl_stats = torch.zeros(stats_shape, device=device)
-        pg_clipfrac_stats = torch.zeros(stats_shape, device=device)
-        pg_loss_stats = torch.zeros(stats_shape, device=device)
-        vf_loss_stats = torch.zeros(stats_shape, device=device)
-        vf_clipfrac_stats = torch.zeros(stats_shape, device=device)
-        entropy_stats = torch.zeros(stats_shape, device=device)
-        ratio_stats = torch.zeros(stats_shape, device=device)
-        model.train()
-
-        # trainer state initialization
-        self.state.global_step = 0
-        self.state.episode = 0
-        self.state.max_steps = args.num_total_batches * args.num_mini_batches
-        self.state.num_train_epochs = args.total_episodes / self.train_dataset_len
-        # Compute absolute values for logging, eval, and save if given as ratio
-        if args.logging_steps is not None:
-            if args.logging_steps < 1:
-                self.state.logging_steps = math.ceil(self.state.max_steps * args.logging_steps)
-            else:
-                self.state.logging_steps = args.logging_steps
-        if args.eval_steps is not None:
-            if args.eval_steps < 1:
-                self.state.eval_steps = math.ceil(self.state.max_steps * args.eval_steps)
-            else:
-                self.state.eval_steps = args.eval_steps
-        if args.save_steps is not None:
-            if args.save_steps < 1:
-                self.state.save_steps = math.ceil(self.state.max_steps * args.save_steps)
-            else:
-                self.state.save_steps = args.save_steps
-        self.control = self.callback_handler.on_train_begin(args, self.state, self.control)
-
-        # backward compatibility
-        if self.is_deepspeed_enabled:
-            self.deepspeed = self.model
-            self.model_wrapped = self.model
-
-        for update in range(1, args.num_total_batches + 1):
-            self.state.episode += 1 * args.batch_size
-            data = next(iter_dataloader)
-            with torch.no_grad():
-                queries = data["input_ids"].to(device)
-                context_length = queries.shape[1]
-                responses = []
-                postprocessed_responses = []
-                logprobs = []
-                ref_logprobs = []
-                scores = []
-                sequence_lengths = []
-                values = []
-                with unwrap_model_for_generation(
-                    self.model, self.accelerator, gather_deepspeed3_params=self.args.ds3_gather_for_generation
-                ) as unwrapped_model:
-                    query_responses, logitss = batch_generation(
-                        unwrapped_model.policy,
-                        queries,
-                        args.local_rollout_forward_batch_size,
-                        processing_class.pad_token_id,
-                        generation_config,
-                    )
-
-                for i in range(0, queries.shape[0], args.local_rollout_forward_batch_size):
-                    query = queries[i : i + args.local_rollout_forward_batch_size]
-                    query_response = query_responses[i : i + args.local_rollout_forward_batch_size]
-                    response = query_response[:, context_length:]
-                    logits = logitss[i : i + args.local_rollout_forward_batch_size]
-                    logprob = selective_log_softmax(logits, response)
-                    del logits
-                    torch.cuda.empty_cache()
-
-                    if ref_policy is None:
-                        with self.null_ref_context():
-                            ref_output = forward(model.policy, query_response, processing_class.pad_token_id)
-                    else:
-                        ref_output = forward(ref_policy, query_response, processing_class.pad_token_id)
-                    ref_logits = ref_output.logits[:, context_length - 1 : -1]
-                    ref_logits /= args.temperature + 1e-7
-                    ref_logprob = selective_log_softmax(ref_logits, response)
-                    del ref_output, ref_logits
-                    torch.cuda.empty_cache()
-
-                    # Response Processing 1. truncate response after the first occurrence of `stop_token_id`
-                    postprocessed_response = response
-                    if self.stop_token_id is not None:  # handle the edge case when stop_token_id exists but is 0
-                        postprocessed_response = truncate_response(
-                            self.stop_token_id, processing_class.pad_token_id, response
-                        )
-
-                    # Response Processing 2. run reward model on the truncated responses
-                    postprocessed_query_response = torch.cat((query, postprocessed_response), 1)
-                    sequence_length = first_true_indices(postprocessed_response == processing_class.pad_token_id) - 1
-                    unwrapped_value_model = accelerator.unwrap_model(model).value_model
-                    full_value, _, _ = get_reward(
-                        unwrapped_value_model, query_response, processing_class.pad_token_id, context_length
-                    )
-                    value = full_value[:, context_length - 1 : -1].squeeze(-1)
-                    _, score, _ = get_reward(
-                        reward_model, postprocessed_query_response, processing_class.pad_token_id, context_length
-                    )
-
-                    responses.append(response)
-                    postprocessed_responses.append(postprocessed_response)
-                    logprobs.append(logprob)
-                    ref_logprobs.append(ref_logprob)
-                    sequence_lengths.append(sequence_length)
-                    scores.append(score)
-                    values.append(value)
-                responses = torch.cat(responses, 0)
-                postprocessed_responses = torch.cat(postprocessed_responses, 0)
-                logprobs = torch.cat(logprobs, 0)
-                ref_logprobs = torch.cat(ref_logprobs, 0)
-                sequence_lengths = torch.cat(sequence_lengths, 0)
-                scores = torch.cat(scores, 0)
-                values = torch.cat(values, 0)
-                del (logprob, ref_logprob, full_value, value, score, unwrapped_model)
-                torch.cuda.empty_cache()
-                gc.collect()
-
-                # Response Processing 3. Filter completion. Ensure that the sample contains stop_token_id
-                # Completions not passing that filter will receive a lower score.
-                contain_eos_token = torch.any(postprocessed_responses == self.processing_class.eos_token_id, dim=-1)
-                if self.args.missing_eos_penalty is not None:
-                    scores[~contain_eos_token] -= self.args.missing_eos_penalty
-                # accelerator.print(f"{scores=}, {(contain_eos_token.sum() / len(contain_eos_token))=}")
-
-                # be very careful with `padding_mask_p1`; see https://excalidraw.com/#json=LWnzG4w2k5DjF_EOL_xPt,e2w3a-hFJ_gX5vOfeyXGTw
-                response_idxs = torch.arange(responses.shape[1], device=responses.device).repeat(responses.shape[0], 1)
-                padding_mask = response_idxs > sequence_lengths.unsqueeze(1)
-                logprobs = torch.masked_fill(logprobs, padding_mask, INVALID_LOGPROB)
-                ref_logprobs = torch.masked_fill(ref_logprobs, padding_mask, INVALID_LOGPROB)
-                sequence_lengths_p1 = sequence_lengths + 1
-                padding_mask_p1 = response_idxs > (sequence_lengths_p1.unsqueeze(1))
-                values = torch.masked_fill(values, padding_mask_p1, 0)
-
-                # 4. compute rewards
-                kl = logprobs - ref_logprobs
-                non_score_reward = -args.kl_coef * kl
-                rewards = non_score_reward.clone()
-                actual_start = torch.arange(rewards.size(0), device=rewards.device)
-                actual_end = torch.where(sequence_lengths_p1 < rewards.size(1), sequence_lengths_p1, sequence_lengths)
-                rewards[[actual_start, actual_end]] += scores
-
-                # 5. whiten rewards
-                if args.whiten_rewards:
-                    rewards = masked_whiten(rewards, mask=~padding_mask_p1, shift_mean=False)
-                    rewards = torch.masked_fill(rewards, padding_mask_p1, 0)
-
-                # 6. compute advantages and returns
-                lastgaelam = 0
-                advantages_reversed = []
-                gen_length = responses.shape[1]
-                for t in reversed(range(gen_length)):
-                    nextvalues = values[:, t + 1] if t < gen_length - 1 else 0.0
-                    delta = rewards[:, t] + args.gamma * nextvalues - values[:, t]
-                    lastgaelam = delta + args.gamma * args.lam * lastgaelam
-                    advantages_reversed.append(lastgaelam)
-                advantages = torch.stack(advantages_reversed[::-1], axis=1)
-                returns = advantages + values
-                advantages = masked_whiten(advantages, ~padding_mask)
-                advantages = torch.masked_fill(advantages, padding_mask, 0)
-                torch.cuda.empty_cache()
-
-            # Do multiple epochs of PPO training, with a fresh random shuffle in each epoch
-            for ppo_epoch_idx in range(args.num_ppo_epochs):
-                b_inds = np.random.permutation(args.local_batch_size)
-                minibatch_idx = 0
-                for mini_batch_start in range(0, args.local_batch_size, args.local_mini_batch_size):
-                    mini_batch_end = mini_batch_start + args.local_mini_batch_size
-                    mini_batch_inds = b_inds[mini_batch_start:mini_batch_end]
-                    gradient_accumulation_idx = 0
-                    for micro_batch_start in range(0, args.local_mini_batch_size, args.per_device_train_batch_size):
-                        with accelerator.accumulate(model):
-                            micro_batch_end = micro_batch_start + args.per_device_train_batch_size
-                            micro_batch_inds = mini_batch_inds[micro_batch_start:micro_batch_end]
-                            mb_advantage = advantages[micro_batch_inds]
-                            mb_responses = responses[micro_batch_inds]
-                            mb_query_responses = query_responses[micro_batch_inds]
-                            mb_logprobs = logprobs[micro_batch_inds]
-                            mb_return = returns[micro_batch_inds]
-                            mb_values = values[micro_batch_inds]
-
-                            output, vpred_temp = forward(model, mb_query_responses, processing_class.pad_token_id)
-                            logits = output.logits[:, context_length - 1 : -1]
-                            logits /= args.temperature + 1e-7
-                            new_logprobs = selective_log_softmax(logits, mb_responses)
-                            new_logprobs = torch.masked_fill(
-                                new_logprobs, padding_mask[micro_batch_inds], INVALID_LOGPROB
-                            )
-                            vpred = vpred_temp[:, context_length - 1 : -1].squeeze(-1)
-                            vpred = torch.masked_fill(vpred, padding_mask_p1[micro_batch_inds], 0)
-                            vpredclipped = torch.clamp(
-                                vpred,
-                                mb_values - args.cliprange_value,
-                                mb_values + args.cliprange_value,
-                            )
-                            vf_losses1 = torch.square(vpred - mb_return)
-                            vf_losses2 = torch.square(vpredclipped - mb_return)
-                            vf_loss_max = torch.max(vf_losses1, vf_losses2)
-                            vf_loss = 0.5 * masked_mean(vf_loss_max, ~padding_mask_p1[micro_batch_inds])
-                            vf_clipfrac = masked_mean(
-                                (vf_losses2 > vf_losses1).float(), ~padding_mask_p1[micro_batch_inds]
-                            )
-                            logprobs_diff = new_logprobs - mb_logprobs
-                            ratio = torch.exp(logprobs_diff)
-                            pg_losses = -mb_advantage * ratio
-                            pg_losses2 = -mb_advantage * torch.clamp(ratio, 1.0 - args.cliprange, 1.0 + args.cliprange)
-                            pg_loss_max = torch.max(pg_losses, pg_losses2)
-                            pg_loss = masked_mean(pg_loss_max, ~padding_mask[micro_batch_inds])
-                            loss = pg_loss + args.vf_coef * vf_loss
-                            accelerator.backward(loss)
-                            optimizer.step()
-                            optimizer.zero_grad()
-                            with torch.no_grad():
-                                pg_clipfrac = masked_mean(
-                                    (pg_losses2 > pg_losses).float(), ~padding_mask[micro_batch_inds]
-                                )
-                                prob_dist = torch.nn.functional.softmax(logits, dim=-1)
-                                entropy = torch.logsumexp(logits, dim=-1) - torch.sum(prob_dist * logits, dim=-1)
-                                approxkl = 0.5 * (logprobs_diff**2).mean()
-                                approxkl_stats[ppo_epoch_idx, minibatch_idx, gradient_accumulation_idx] = approxkl
-                                pg_clipfrac_stats[ppo_epoch_idx, minibatch_idx, gradient_accumulation_idx] = (
-                                    pg_clipfrac
-                                )
-                                pg_loss_stats[ppo_epoch_idx, minibatch_idx, gradient_accumulation_idx] = pg_loss
-                                vf_loss_stats[ppo_epoch_idx, minibatch_idx, gradient_accumulation_idx] = vf_loss
-                                vf_clipfrac_stats[ppo_epoch_idx, minibatch_idx, gradient_accumulation_idx] = (
-                                    vf_clipfrac
-                                )
-                                entropy_stats[ppo_epoch_idx, minibatch_idx, gradient_accumulation_idx] = entropy.mean()
-                                ratio_stats[ppo_epoch_idx, minibatch_idx, gradient_accumulation_idx] = ratio.mean()
-                        gradient_accumulation_idx += 1
-                    minibatch_idx += 1
-                    # del everything and empty cache
-                    # fmt: off
-                    del (
-                        output, vpred_temp, logits, new_logprobs, vpred, vpredclipped,
-                        vf_losses1, vf_losses2, vf_loss, vf_clipfrac, logprobs_diff, ratio, pg_losses, pg_losses2, pg_loss_max,
-                        pg_loss, loss, pg_clipfrac, prob_dist, entropy, approxkl, mb_return,
-                        mb_advantage, mb_values, mb_responses, mb_query_responses, mb_logprobs,
-                    )
-                    # fmt: on
-                    torch.cuda.empty_cache()
-            with torch.no_grad():
-                mean_kl = kl.sum(1).mean()
-                mean_entropy = (-logprobs).sum(1).mean()
-                mean_non_score_reward = non_score_reward.sum(1).mean()
-                rlhf_reward = mean_non_score_reward + scores.mean()
-                eps = int(self.state.episode / (time.time() - start_time))
-                metrics = {}
-                metrics["eps"] = eps
-                metrics["objective/kl"] = self.accelerator.gather_for_metrics(mean_kl).mean().item()
-                metrics["objective/entropy"] = self.accelerator.gather_for_metrics(mean_entropy).mean().item()
-                metrics["objective/non_score_reward"] = (
-                    self.accelerator.gather_for_metrics(mean_non_score_reward).mean().item()
-                )
-                metrics["objective/rlhf_reward"] = self.accelerator.gather_for_metrics(rlhf_reward).mean().item()
-                metrics["objective/scores"] = self.accelerator.gather_for_metrics(scores.mean()).mean().item()
-                metrics["policy/approxkl_avg"] = self.accelerator.gather_for_metrics(approxkl_stats).mean().item()
-                metrics["policy/clipfrac_avg"] = self.accelerator.gather_for_metrics(pg_clipfrac_stats).mean().item()
-                metrics["loss/policy_avg"] = self.accelerator.gather_for_metrics(pg_loss_stats).mean().item()
-                metrics["loss/value_avg"] = self.accelerator.gather_for_metrics(vf_loss_stats).mean().item()
-                metrics["val/clipfrac_avg"] = self.accelerator.gather_for_metrics(vf_clipfrac_stats).mean().item()
-                metrics["policy/entropy_avg"] = self.accelerator.gather_for_metrics(entropy_stats).mean().item()
-                metrics["val/ratio"] = self.accelerator.gather_for_metrics(ratio_stats).mean().item()
-                metrics["val/ratio_var"] = self.accelerator.gather_for_metrics(ratio_stats).var().item()
-                metrics["val/num_eos_tokens"] = (responses == processing_class.eos_token_id).sum().item()
-                metrics["lr"] = self.lr_scheduler.get_last_lr()[0]
-                metrics["episode"] = self.state.episode
-                self.state.epoch = self.state.episode / self.train_dataset_len  # used by self.log
-                self.state.global_step += 1
-                self.log(metrics)
-
-            self.lr_scheduler.step()
-            self.control = self.callback_handler.on_step_end(args, self.state, self.control)
-            if self.control.should_save:
-                self._save_checkpoint(model, trial=None)
-                self.control = self.callback_handler.on_save(self.args, self.state, self.control)
-            del kl, mean_kl, mean_entropy, mean_non_score_reward, scores, metrics, non_score_reward
-            torch.cuda.empty_cache()
-            gc.collect()
-
-            if args.num_sample_generations > 0 and (update - 1) % self.sample_generations_freq == 0:
-                self.generate_completions(sampling=True)
-                torch.cuda.empty_cache()
-            del (
-                query_responses,
-                responses,
-                postprocessed_responses,
-                logprobs,
-                ref_logprobs,
-                values,
-                sequence_lengths,
-                contain_eos_token,
-                sequence_lengths_p1,
-                response_idxs,
-                padding_mask,
-                padding_mask_p1,
-                rewards,
-                actual_start,
-                actual_end,
-                advantages,
-                returns,
-            )
-            torch.cuda.empty_cache()
-
-        # HF trainer specifics
-        self.control = self.callback_handler.on_train_end(args, self.state, self.control)
-        if self.control.should_save:
-            self._save_checkpoint(model, trial=None, metrics=None)
-            self.control = self.callback_handler.on_save(self.args, self.state, self.control)
-
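Step 6 in the deleted `train` loop above is a backward Generalized Advantage Estimation (GAE) scan: delta_t = r_t + gamma * V_{t+1} - V_t, then A_t = delta_t + gamma * lam * A_{t+1}. A minimal standalone sketch of the same recursion — the tensor shapes and reward values below are illustrative, not taken from the trainer:

```python
import torch

def gae(rewards: torch.Tensor, values: torch.Tensor, gamma: float = 1.0, lam: float = 0.95):
    """Backward GAE scan over (batch, gen_length) tensors, mirroring step 6 above."""
    lastgaelam = 0.0
    advantages_reversed = []
    gen_length = rewards.shape[1]
    for t in reversed(range(gen_length)):
        nextvalues = values[:, t + 1] if t < gen_length - 1 else 0.0  # V at the end is bootstrapped as 0
        delta = rewards[:, t] + gamma * nextvalues - values[:, t]
        lastgaelam = delta + gamma * lam * lastgaelam
        advantages_reversed.append(lastgaelam)
    advantages = torch.stack(advantages_reversed[::-1], dim=1)
    returns = advantages + values  # regression targets for the value head
    return advantages, returns

# toy rollout: 2 sequences, 4 response tokens, terminal-only reward
rewards = torch.tensor([[0.0, 0.0, 0.0, 1.0], [0.0, 0.0, 0.5, 0.0]])
values = torch.zeros_like(rewards)
advantages, returns = gae(rewards, values)
print(advantages.shape, returns.shape)  # torch.Size([2, 4]) torch.Size([2, 4])
```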
-    def generate_completions(self, sampling: bool = False):
-        args = self.args
-        processing_class = self.processing_class
-        generation_config = GenerationConfig(
-            max_new_tokens=self.args.response_length,
-            temperature=(0.01 + 1e-7),
-            top_k=0.0,
-            top_p=1.0,
-            do_sample=True,
-        )
-
-        table = defaultdict(list)
-        with unwrap_model_for_generation(
-            self.model, self.accelerator, gather_deepspeed3_params=self.args.ds3_gather_for_generation
-        ) as unwrapped_model:
-            for batch in self.eval_dataloader:
-                query = batch["input_ids"]
-                with torch.no_grad():
-                    context_length = query.shape[1]
-                    query_response, _ = batch_generation(
-                        unwrapped_model.policy,
-                        query,
-                        query.shape[0],
-                        processing_class.pad_token_id,
-                        generation_config,
-                    )
-                    response = query_response[:, context_length:]
-                    postprocessed_response = response
-                    if self.stop_token_id is not None:  # handle the edge case when stop_token_id exists but is 0
-                        postprocessed_response = truncate_response(
-                            self.stop_token_id, processing_class.pad_token_id, response
-                        )
-                    table["query"].extend(
-                        gather_object(processing_class.batch_decode(query, skip_special_tokens=True))
-                    )
-                    table["model response"].extend(
-                        gather_object(processing_class.batch_decode(postprocessed_response))
-                    )
-
-                    postprocessed_query_response = torch.cat((query, postprocessed_response), 1)
-                    _, score, _ = get_reward(
-                        self.reward_model, postprocessed_query_response, processing_class.pad_token_id, context_length
-                    )
-                    table["score"].extend(self.accelerator.gather_for_metrics(score).float().cpu().numpy())
-
-                if sampling:
-                    break
-        df = pd.DataFrame(table)
-
-        if self.accelerator.is_main_process:
-            print_rich_table(df.iloc[0 : 0 + 5])
-            if "wandb" in args.report_to:
-                import wandb
-
-                if wandb.run is not None:
-                    wandb.log({"completions": wandb.Table(dataframe=df)})
-
-            if "comet_ml" in args.report_to:
-                log_table_to_comet_experiment(
-                    name="completions.csv",
-                    table=df,
-                )
-
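Both the rollout path in `train` and the eval loop above post-process sampled responses with `truncate_response`, keeping tokens up to and including the first stop token and padding everything after it. A functionally similar sketch (illustrative only, not TRL's exact implementation):

```python
import torch

def truncate_after_stop(responses: torch.Tensor, stop_token_id: int, pad_token_id: int) -> torch.Tensor:
    """Replace every token strictly after the first `stop_token_id` with padding."""
    seq_len = responses.shape[1]
    is_stop = responses == stop_token_id
    # argmax over int gives the first True index; rows without a stop token fall back to seq_len
    first_stop = torch.where(is_stop.any(dim=1), is_stop.int().argmax(dim=1), torch.tensor(seq_len))
    idxs = torch.arange(seq_len).expand_as(responses)
    return responses.masked_fill(idxs > first_stop.unsqueeze(1), pad_token_id)

resp = torch.tensor([[5, 7, 2, 9, 9],    # 2 = stop token, present
                     [5, 7, 8, 9, 9]])   # stop token absent -> row unchanged
print(truncate_after_stop(resp, stop_token_id=2, pad_token_id=0))
# tensor([[5, 7, 2, 0, 0],
#         [5, 7, 8, 9, 9]])
```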
-    def create_model_card(
-        self,
-        model_name: Optional[str] = None,
-        dataset_name: Optional[str] = None,
-        tags: Union[str, list[str], None] = None,
-    ):
-        """
-        Creates a draft of a model card using the information available to the `Trainer`.
-
-        Args:
-            model_name (`str` or `None`, *optional*, defaults to `None`):
-                Name of the model.
-            dataset_name (`str` or `None`, *optional*, defaults to `None`):
-                Name of the dataset used for training.
-            tags (`str`, `list[str]` or `None`, *optional*, defaults to `None`):
-                Tags to be associated with the model card.
-        """
-        if not self.is_world_process_zero():
-            return
-
-        if hasattr(self.model.config, "_name_or_path") and not os.path.isdir(self.model.config._name_or_path):
-            base_model = self.model.config._name_or_path
-        else:
-            base_model = None
-
-        tags = tags or []
-        if isinstance(tags, str):
-            tags = [tags]
-
-        if hasattr(self.model.config, "unsloth_version"):
-            tags.append("unsloth")
-
-        citation = textwrap.dedent("""\
-        @article{mziegler2019fine-tuning,
-            title = {{Fine-Tuning Language Models from Human Preferences}},
-            author = {Daniel M. Ziegler and Nisan Stiennon and Jeffrey Wu and Tom B. Brown and Alec Radford and Dario Amodei and Paul F. Christiano and Geoffrey Irving},
-            year = 2019,
-            eprint = {arXiv:1909.08593}
-        }""")
-
-        model_card = generate_model_card(
-            base_model=base_model,
-            model_name=model_name,
-            hub_model_id=self.hub_model_id,
-            dataset_name=dataset_name,
-            tags=tags,
-            wandb_url=wandb.run.get_url() if is_wandb_available() and wandb.run is not None else None,
-            comet_url=get_comet_experiment_url(),
-            trainer_name="PPO",
-            trainer_citation=citation,
-            paper_title="Fine-Tuning Language Models from Human Preferences",
-            paper_id="1909.08593",
-        )
-
-        model_card.save(os.path.join(self.args.output_dir, "README.md"))
- class UnslothPPOTrainer(_UnslothPPOTrainer):
-    """
-
-    """
-    def __init__(
-        self,
-        args,
-        processing_class,
-        model,
-        ref_model,
-        reward_model,
-        train_dataset,
-        value_model = None,
-        data_collator = None,
-        eval_dataset = None,
-        callbacks = None,
-        peft_config = None,
-        **kwargs
-    ):
-        if args is None: args = UnslothPPOConfig()
-        use_bf16 = getattr(args, 'bf16', False)
-        if type(use_bf16) is not bool: use_bf16 = False
-        use_fp16 = getattr(args, 'fp16', False)
-        if type(use_fp16) is not bool: use_fp16 = False
-        force_float32 = False
-        if os.environ.get('UNSLOTH_FORCE_FLOAT32', '0') == '1':
-            print('Unsloth: Switching to float32 training since model cannot work with float16')
-            force_float32 = True
-        mixed_precision_dtype = os.environ.get('UNSLOTH_MIXED_PRECISION', 'float32')
-        dtype = getattr(model.config, 'torch_dtype', None)
-        if dtype is None: dtype = model.get_input_embeddings().dtype
-        from unsloth_zoo.utils import _get_dtype
-        dtype = _get_dtype(dtype)
-        float16 = dtype == torch.float16
-        if not force_float32 and (float16 and use_bf16): raise TypeError('Unsloth: Model is in float16 precision but you want to use bfloat16 precision. Set fp16 to `True` and bf16 to `False`')
-        if not force_float32 and (not float16 and use_fp16): raise TypeError('Unsloth: Model is in bfloat16 precision but you want to use float16 precision. Set fp16 to `False` and bf16 to `True`')
-        if force_float32:
-            args.fp16 = False
-            args.bf16 = False
-            os.environ['ACCELERATE_MIXED_PRECISION'] = 'no'
-        elif (not use_bf16 and not use_fp16) and mixed_precision_dtype == 'float32':
-            args.fp16 = float16
-            args.bf16 = not float16
-            os.environ['ACCELERATE_MIXED_PRECISION'] = 'fp16' if float16 else 'bf16'
-        if getattr(args, 'eval_dataset', None) is not None and getattr(args, 'eval_strategy', 'no') == 'no':
-            args.eval_strategy = 'steps'
-            if getattr(args, 'eval_steps', None) is None: args.eval_steps = 0.1
-        ga_steps = getattr(args, 'gradient_accumulation_steps', None)
-        if ga_steps is not None and ga_steps > 1:
-            from transformers import __version__ as transformers_version
-            if Version(transformers_version) <= Version('4.45.2'):
-                print('**** Unsloth: Please use our fixed gradient_accumulation_steps by updating transformers, TRL and Unsloth!\n'
-                      '`pip install --upgrade --no-cache-dir --force-reinstall --no-deps unsloth transformers trl unsloth_zoo`')
-        if getattr(args, 'eval_strategy', 'no') != 'no':
-            eval_bsz = getattr(args, 'per_device_eval_batch_size', 8)
-            if eval_bsz == 8 and args.per_device_train_batch_size < eval_bsz: args.per_device_eval_batch_size = args.per_device_train_batch_size
-            if getattr(args, 'eval_accumulation_steps', None) is None and ga_steps is not None: args.eval_accumulation_steps = ga_steps
-        fp16_full_eval = getattr(args, 'fp16_full_eval', False)
-        if type(fp16_full_eval) is not bool: fp16_full_eval = False
-        bf16_full_eval = getattr(args, 'bf16_full_eval', False)
-        if type(bf16_full_eval) is not bool: bf16_full_eval = False
-        if args.fp16 and bf16_full_eval: args.bf16_full_eval = False; args.fp16_full_eval = True
-        if args.bf16 and fp16_full_eval: args.bf16_full_eval = True; args.fp16_full_eval = False
-        if force_float32:
-            args.bf16_full_eval = False
-            args.fp16_full_eval = False
-        elif os.environ.get('UNSLOTH_MIXED_PRECISION', 'float32') == 'bfloat16':
-            args.bf16_full_eval = True
-            args.fp16_full_eval = False
-        elif not bf16_full_eval and not fp16_full_eval:
-            args.bf16_full_eval = args.bf16
-            args.fp16_full_eval = args.fp16
-        _output_logits = False
-        if locals().get('compute_metrics', None) is not None: _output_logits = True
-        if locals().get('preprocess_logits_for_metrics', None) is not None: _output_logits = True
-        if _output_logits:
-            os.environ['UNSLOTH_RETURN_LOGITS'] = '1'
-        if 'max_seq_length' not in locals() and not hasattr(args, 'max_seq_length'):
-            pass
-        else:
-            model_max_seq_length = getattr(model, 'max_seq_length', None)
-            args_max_seq_length = getattr(args, 'max_seq_length', None)
-            if args_max_seq_length is None and model_max_seq_length is not None:
-                max_seq_length = model.max_seq_length
-                if hasattr(args, 'max_seq_length'): args.max_seq_length = max_seq_length
-        if model is not None and hasattr(model, 'for_training'):
-            model.for_training()
-        if 'tokenizer' in locals() and hasattr(tokenizer, 'padding_side'): tokenizer.padding_side = 'right'
-        if 'processing_class' in locals():
-            if hasattr(processing_class, 'padding_side'): processing_class.padding_side = 'right'
-            if hasattr(processing_class, 'tokenizer') and hasattr(processing_class.tokenizer, 'padding_side'): processing_class.tokenizer.padding_side = 'right'
-        __tokenizer = processing_class if 'processing_class' in locals() else tokenizer
-        from unsloth_zoo.vision_utils import UnslothVisionDataCollator
-        if not isinstance(data_collator, UnslothVisionDataCollator):
-            if isinstance(data_collator, DataCollatorForSeq2Seq) and 'labels' not in train_dataset.column_names:
-                data_collator = TransformersDataCollatorForLanguageModeling(__tokenizer, mlm = False, mlm_probability = 0.0)
-            elif isinstance(data_collator, TransformersDataCollatorForLanguageModeling) and 'labels' in train_dataset.column_names:
-                data_collator = DataCollatorForSeq2Seq(__tokenizer)
-        else:
-            if hasattr(args, 'remove_unused_columns'): args.remove_unused_columns = False
-            if hasattr(args, 'dataset_text_field'): args.dataset_text_field = ''
-            if hasattr(args, 'dataset_kwargs'): args.dataset_kwargs = {'skip_prepare_dataset': True}
-        if not isinstance(data_collator, UnslothVisionDataCollator):
-            if not hasattr(__tokenizer, 'pad') and hasattr(__tokenizer, 'tokenizer'):
-                if isinstance(data_collator, DataCollatorForSeq2Seq):
-                    data_collator = DataCollatorForSeq2Seq(__tokenizer.tokenizer)
-                else:
-                    data_collator = TransformersDataCollatorForLanguageModeling(__tokenizer.tokenizer, mlm = False, mlm_probability = 0.0)
-        other_metrics = []
-
-        from unsloth_zoo.logging_utils import PatchRLStatistics
-        PatchRLStatistics('ppo_trainer', other_metrics)
-
-        super().__init__(
-            args = args,
-            processing_class = processing_class,
-            model = model,
-            ref_model = ref_model,
-            reward_model = reward_model,
-            train_dataset = train_dataset,
-            value_model = value_model,
-            data_collator = data_collator,
-            eval_dataset = eval_dataset,
-            callbacks = callbacks,
-            peft_config = peft_config,**kwargs)
-        if hasattr(self, 'neftune_hook_handle'):
-            self.neftune_hook_handle.remove()
-        if hasattr(self, 'neftune_hook_handle'): del self.neftune_hook_handle
-        if getattr(args, 'neftune_noise_alpha', None) is not None:
-            model.get_input_embeddings().neftune_noise_alpha = self.neftune_noise_alpha
-        pass
-
- pass
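For reference, the micro-batch update deleted above combines a clipped policy surrogate with a clipped value loss, `loss = pg_loss + vf_coef * vf_loss`. A condensed sketch of the two objectives, using a plain `.mean()` where the trainer uses `masked_mean` over padding (tensor values and defaults below are illustrative):

```python
import torch

def ppo_losses(new_logprobs, old_logprobs, advantages, vpred, old_values, returns,
               cliprange=0.2, cliprange_value=0.2, vf_coef=0.1):
    """Clipped PPO surrogate plus clipped value loss (unmasked sketch)."""
    ratio = torch.exp(new_logprobs - old_logprobs)       # pi_new / pi_old per token
    pg_losses1 = -advantages * ratio
    pg_losses2 = -advantages * torch.clamp(ratio, 1.0 - cliprange, 1.0 + cliprange)
    pg_loss = torch.max(pg_losses1, pg_losses2).mean()   # pessimistic of clipped/unclipped

    vpred_clipped = torch.clamp(vpred, old_values - cliprange_value, old_values + cliprange_value)
    vf_loss = 0.5 * torch.max((vpred - returns) ** 2, (vpred_clipped - returns) ** 2).mean()
    return pg_loss + vf_coef * vf_loss

# toy tensors: batch of 2 sequences, 3 response tokens each
new_lp = torch.randn(2, 3, requires_grad=True)
old_lp = new_lp.detach() - 0.05
adv, values = torch.randn(2, 3), torch.randn(2, 3)
loss = ppo_losses(new_lp, old_lp, adv, vpred=values, old_values=values, returns=adv + values)
loss.backward()  # gradients flow exactly as in accelerator.backward(loss) above
```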
test_run_uploads/UnslothPRMTrainer.py DELETED
@@ -1,809 +0,0 @@
- """
- 2025.7.11
- 2025.7.11
- 4.54.1
- 0.16.1
- __UNSLOTH_VERSIONING__
- """
- from torch import Tensor
- import torch
- import torch.nn as nn
- from torch.nn import functional as F
- from typing import Any, List, Optional, Tuple, Union, Dict, Set, Callable
- from trl.trainer.prm_trainer import (BaseImageProcessor, Callable, DataCollator, DataCollatorForTokenClassification, Dataset, EvalPrediction, FeatureExtractionMixin, Optional, PRMConfig, PRMTrainer, PartialState, PeftModel, PreTrainedModel, PreTrainedTokenizerBase, ProcessorMixin, Trainer, TrainerCallback, Union, chain, compute_accuracy, disable_dropout_in_model, features, generate_model_card, inspect, is_peft_available, is_wandb_available, nn, os, prepare_model_for_kbit_training, textwrap, torch, wandb, warnings)
-
-
- import os
- from typing import *
- from dataclasses import dataclass, field
- from packaging.version import Version
- import torch
- import numpy as np
- from contextlib import nullcontext
- from torch.nn import functional as F
- from transformers import DataCollatorForSeq2Seq, DataCollatorForLanguageModeling as TransformersDataCollatorForLanguageModeling
-
- torch_compile_options = {
-     "epilogue_fusion" : True,
-     "max_autotune" : False,
-     "shape_padding" : True,
-     "trace.enabled" : False,
-     "triton.cudagraphs" : False,
- }
-
- @torch.compile(dynamic = True, fullgraph = True, options = torch_compile_options,)
- def chunked_selective_log_softmax(logits, index):
-     # Split into 4 chunks only
-     chunked_logits = torch.chunk(logits.reshape(-1, logits.shape[-1]), chunks = 4, dim = 0)
-     chunked_index = torch.chunk(index.reshape(-1), chunks = 4, dim = 0)
-     all_per_token_logps = []
-     # Below loop does the same as selective_log_softmax(chunk_logits, chunk_index)
-     for chunk_logits, chunk_index in zip(chunked_logits, chunked_index):
-         chunk_logits = chunk_logits.to(torch.float32)
-         selected_logits = torch.gather(chunk_logits, dim = -1, index = chunk_index.unsqueeze(-1)).squeeze(-1)
-         logsumexp_values = torch.logsumexp(chunk_logits, dim = -1)
-         per_token_logps = selected_logits - logsumexp_values
-         all_per_token_logps.append(per_token_logps)
-     pass
-     all_per_token_logps = torch.concat(all_per_token_logps)
-     all_per_token_logps = all_per_token_logps.reshape((logits.shape[0], logits.shape[1]))
-     return all_per_token_logps
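`chunked_selective_log_softmax` above splits the batch into four chunks purely to cap peak memory; each chunk still computes log p(token) = logit[token] - logsumexp(logits). An unchunked reference it should agree with, up to the float32 cast:

```python
import torch

def selective_log_softmax_reference(logits: torch.Tensor, index: torch.Tensor) -> torch.Tensor:
    """Log-softmax evaluated only at the chosen token ids, without chunking."""
    logits = logits.to(torch.float32)
    selected = torch.gather(logits, dim=-1, index=index.unsqueeze(-1)).squeeze(-1)
    return selected - torch.logsumexp(logits, dim=-1)

logits = torch.randn(2, 5, 11)        # (batch, seq, vocab) -- illustrative shapes
index = torch.randint(0, 11, (2, 5))  # chosen token ids
print(selective_log_softmax_reference(logits, index).shape)  # torch.Size([2, 5])
```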
- @dataclass
- class UnslothPRMConfig(PRMConfig):
-     """
-
-     Configuration class for the [`PRMTrainer`].
-
-     Using [`~transformers.HfArgumentParser`] we can turn this class into
-     [argparse](https://docs.python.org/3/library/argparse#module-argparse) arguments that can be specified on the
-     command line.
-
-     Parameters:
-         learning_rate (`float`, *optional*, defaults to `1e-5`):
-             Initial learning rate for [`AdamW`] optimizer. The default value replaces that of
-             [`~transformers.TrainingArguments`].
-         max_length (`int` or `None`, *optional*, defaults to `1024`):
-             Maximum length of the sequences (prompt + completion) used for truncation.
-         max_prompt_length (`int` or `None`, *optional*, defaults to `512`):
-             Maximum length of the prompt used for truncation.
-         max_completion_length (`int` or `None`, *optional*, defaults to `None`):
-             Maximum length of the completion used for truncation. The completion is the concatenation of the steps.
-         disable_dropout (`bool`, *optional*, defaults to `True`):
-             Whether to disable dropout in the model.
-         step_separator (`str`, *optional*, defaults to `"\n"`):
-             Separator used to separate each step of the reasoning process.
-         train_on_last_step_only (`bool`, *optional*, defaults to `False`):
-             Whether to train only on the last step.
-         dataset_num_proc (`int`, *optional*, defaults to `None`):
-             Number of processes to use for processing the dataset.
-
-     """
-     vllm_sampling_params: Optional[Any] = field(
-         default = None,
-         metadata = {'help': 'vLLM SamplingParams'},
-     )
-     unsloth_num_chunks : Optional[int] = field(
-         default = -1,
-         metadata = {'help': 'Chunk size to reduce memory usage. -1 is most efficient.'},
-     )
-     def __init__(
-         self,
-         output_dir = None,
-         overwrite_output_dir = None,
-         do_train = False,
-         do_eval = False,
-         do_predict = False,
-         eval_strategy = 'no',
-         prediction_loss_only = False,
-         per_device_train_batch_size = 4,
-         per_device_eval_batch_size = 4,
-         per_gpu_train_batch_size = None,
-         per_gpu_eval_batch_size = None,
-         gradient_accumulation_steps = 2,
-         eval_accumulation_steps = 2,
-         eval_delay = 0,
-         torch_empty_cache_steps = 250,
-         learning_rate = 5e-05,
-         weight_decay = 0.01,
-         adam_beta1 = 0.9,
-         adam_beta2 = 0.999,
-         adam_epsilon = 1e-08,
-         max_grad_norm = 1.0,
-         num_train_epochs = 3.0,
-         max_steps = -1,
-         lr_scheduler_type = 'linear',
-         warmup_ratio = 0.1,
-         warmup_steps = 0,
-         log_level = 'passive',
-         log_level_replica = 'warning',
-         log_on_each_node = True,
-         logging_dir = None,
-         logging_strategy = 'steps',
-         logging_first_step = False,
-         logging_steps = 1,
-         logging_nan_inf_filter = False,
-         save_strategy = 'steps',
-         save_steps = 500,
-         save_total_limit = None,
-         save_safetensors = True,
-         save_on_each_node = False,
-         save_only_model = False,
-         restore_callback_states_from_checkpoint = False,
-         no_cuda = False,
-         use_cpu = False,
-         use_mps_device = False,
-         seed = 3407,
-         data_seed = 3407,
-         jit_mode_eval = False,
-         use_ipex = False,
-         bf16 = False,
-         fp16 = False,
-         fp16_opt_level = 'O1',
-         half_precision_backend = 'auto',
-         bf16_full_eval = False,
-         fp16_full_eval = False,
-         tf32 = None,
-         local_rank = -1,
-         ddp_backend = None,
-         tpu_num_cores = None,
-         tpu_metrics_debug = False,
-         debug = '',
-         dataloader_drop_last = False,
-         eval_steps = None,
-         dataloader_num_workers = 0,
-         dataloader_prefetch_factor = None,
-         past_index = -1,
-         run_name = None,
-         disable_tqdm = None,
-         remove_unused_columns = True,
-         label_names = None,
-         load_best_model_at_end = False,
-         metric_for_best_model = None,
-         greater_is_better = None,
-         ignore_data_skip = False,
-         fsdp = '',
-         fsdp_min_num_params = 0,
-         fsdp_config = None,
-         fsdp_transformer_layer_cls_to_wrap = None,
-         accelerator_config = None,
-         deepspeed = None,
-         label_smoothing_factor = 0.0,
-         optim = 'adamw_8bit',
-         optim_args = None,
-         adafactor = False,
-         group_by_length = False,
-         length_column_name = 'length',
-         report_to = None,
-         ddp_find_unused_parameters = None,
-         ddp_bucket_cap_mb = None,
-         ddp_broadcast_buffers = None,
-         dataloader_pin_memory = True,
-         dataloader_persistent_workers = False,
-         skip_memory_metrics = True,
-         use_legacy_prediction_loop = False,
-         push_to_hub = False,
-         resume_from_checkpoint = None,
-         hub_model_id = None,
-         hub_strategy = 'every_save',
-         hub_token = None,
-         hub_private_repo = None,
-         hub_always_push = False,
-         hub_revision = None,
-         gradient_checkpointing = False,
-         gradient_checkpointing_kwargs = None,
-         include_inputs_for_metrics = False,
-         eval_do_concat_batches = True,
-         fp16_backend = 'auto',
-         push_to_hub_model_id = None,
-         push_to_hub_organization = None,
-         push_to_hub_token = None,
-         mp_parameters = '',
-         auto_find_batch_size = True,
-         full_determinism = False,
-         torchdynamo = None,
-         ray_scope = 'last',
-         ddp_timeout = 1800,
-         torch_compile = False,
-         torch_compile_backend = None,
-         torch_compile_mode = None,
-         include_tokens_per_second = False,
-         include_num_input_tokens_seen = False,
-         neftune_noise_alpha = None,
-         optim_target_modules = None,
-         batch_eval_metrics = False,
-         eval_on_start = False,
-         use_liger_kernel = False,
-         liger_kernel_config = None,
-         eval_use_gather_object = False,
-         average_tokens_across_devices = True,
-         max_length = 1024,
-         max_prompt_length = 512,
-         max_completion_length = None,
-         disable_dropout = True,
-         step_separator = '\n',
-         train_on_last_step_only = False,
-         dataset_num_proc = None,
-         vllm_sampling_params = None,
-         unsloth_num_chunks = -1,
-         **kwargs,
-     ):
-         if learning_rate < 1e-7: raise FloatingPointError(f'Unsloth: Your learning rate of `{learning_rate}` is too small and less than 1e-7! Consider increasing it, otherwise gradient updates will be close to 0!')
-         if learning_rate > 1: raise OverflowError(f'Unsloth: Your learning rate of `{learning_rate}` is way too large (> 1)! Consider decreasing it to 1e-1, otherwise gradient updates will explode!')
-         if output_dir is None and save_strategy == 'steps' and save_steps == 500:
-             output_dir = 'unsloth_training_checkpoints'
-             save_strategy = 'no'
-         if dataset_num_proc is None:
-             from multiprocessing import cpu_count
-             dataset_num_proc = min(cpu_count()*2, 2)
-
-         super().__init__(
-             output_dir = output_dir,
-             overwrite_output_dir = overwrite_output_dir,
-             do_train = do_train,
-             do_eval = do_eval,
-             do_predict = do_predict,
-             eval_strategy = eval_strategy,
-             prediction_loss_only = prediction_loss_only,
-             per_device_train_batch_size = per_device_train_batch_size,
-             per_device_eval_batch_size = per_device_eval_batch_size,
-             per_gpu_train_batch_size = per_gpu_train_batch_size,
-             per_gpu_eval_batch_size = per_gpu_eval_batch_size,
-             gradient_accumulation_steps = gradient_accumulation_steps,
-             eval_accumulation_steps = eval_accumulation_steps,
-             eval_delay = eval_delay,
-             torch_empty_cache_steps = torch_empty_cache_steps,
-             learning_rate = learning_rate,
-             weight_decay = weight_decay,
-             adam_beta1 = adam_beta1,
-             adam_beta2 = adam_beta2,
-             adam_epsilon = adam_epsilon,
-             max_grad_norm = max_grad_norm,
-             num_train_epochs = num_train_epochs,
-             max_steps = max_steps,
-             lr_scheduler_type = lr_scheduler_type,
-             warmup_ratio = warmup_ratio,
-             warmup_steps = warmup_steps,
-             log_level = log_level,
-             log_level_replica = log_level_replica,
-             log_on_each_node = log_on_each_node,
-             logging_dir = logging_dir,
-             logging_strategy = logging_strategy,
-             logging_first_step = logging_first_step,
-             logging_steps = logging_steps,
-             logging_nan_inf_filter = logging_nan_inf_filter,
-             save_strategy = save_strategy,
-             save_steps = save_steps,
-             save_total_limit = save_total_limit,
-             save_safetensors = save_safetensors,
-             save_on_each_node = save_on_each_node,
-             save_only_model = save_only_model,
-             restore_callback_states_from_checkpoint = restore_callback_states_from_checkpoint,
-             no_cuda = no_cuda,
-             use_cpu = use_cpu,
-             use_mps_device = use_mps_device,
-             seed = seed,
-             data_seed = data_seed,
-             jit_mode_eval = jit_mode_eval,
-             use_ipex = use_ipex,
-             bf16 = bf16,
-             fp16 = fp16,
-             fp16_opt_level = fp16_opt_level,
-             half_precision_backend = half_precision_backend,
-             bf16_full_eval = bf16_full_eval,
-             fp16_full_eval = fp16_full_eval,
-             tf32 = tf32,
-             local_rank = local_rank,
-             ddp_backend = ddp_backend,
-             tpu_num_cores = tpu_num_cores,
-             tpu_metrics_debug = tpu_metrics_debug,
-             debug = debug,
-             dataloader_drop_last = dataloader_drop_last,
-             eval_steps = eval_steps,
-             dataloader_num_workers = dataloader_num_workers,
-             dataloader_prefetch_factor = dataloader_prefetch_factor,
-             past_index = past_index,
-             run_name = run_name,
-             disable_tqdm = disable_tqdm,
-             remove_unused_columns = remove_unused_columns,
-             label_names = label_names,
-             load_best_model_at_end = load_best_model_at_end,
-             metric_for_best_model = metric_for_best_model,
-             greater_is_better = greater_is_better,
-             ignore_data_skip = ignore_data_skip,
-             fsdp = fsdp,
-             fsdp_min_num_params = fsdp_min_num_params,
-             fsdp_config = fsdp_config,
-             fsdp_transformer_layer_cls_to_wrap = fsdp_transformer_layer_cls_to_wrap,
-             accelerator_config = accelerator_config,
-             deepspeed = deepspeed,
-             label_smoothing_factor = label_smoothing_factor,
-             optim = optim,
-             optim_args = optim_args,
-             adafactor = adafactor,
-             group_by_length = group_by_length,
-             length_column_name = length_column_name,
-             report_to = report_to,
-             ddp_find_unused_parameters = ddp_find_unused_parameters,
-             ddp_bucket_cap_mb = ddp_bucket_cap_mb,
-             ddp_broadcast_buffers = ddp_broadcast_buffers,
-             dataloader_pin_memory = dataloader_pin_memory,
-             dataloader_persistent_workers = dataloader_persistent_workers,
-             skip_memory_metrics = skip_memory_metrics,
-             use_legacy_prediction_loop = use_legacy_prediction_loop,
-             push_to_hub = push_to_hub,
-             resume_from_checkpoint = resume_from_checkpoint,
-             hub_model_id = hub_model_id,
-             hub_strategy = hub_strategy,
-             hub_token = hub_token,
-             hub_private_repo = hub_private_repo,
-             hub_always_push = hub_always_push,
-             hub_revision = hub_revision,
-             gradient_checkpointing = gradient_checkpointing,
-             gradient_checkpointing_kwargs = gradient_checkpointing_kwargs,
-             include_inputs_for_metrics = include_inputs_for_metrics,
-             eval_do_concat_batches = eval_do_concat_batches,
-             fp16_backend = fp16_backend,
-             push_to_hub_model_id = push_to_hub_model_id,
-             push_to_hub_organization = push_to_hub_organization,
-             push_to_hub_token = push_to_hub_token,
-             mp_parameters = mp_parameters,
-             auto_find_batch_size = auto_find_batch_size,
-             full_determinism = full_determinism,
-             torchdynamo = torchdynamo,
-             ray_scope = ray_scope,
-             ddp_timeout = ddp_timeout,
-             torch_compile = torch_compile,
-             torch_compile_backend = torch_compile_backend,
-             torch_compile_mode = torch_compile_mode,
-             include_tokens_per_second = include_tokens_per_second,
-             include_num_input_tokens_seen = include_num_input_tokens_seen,
-             neftune_noise_alpha = neftune_noise_alpha,
-             optim_target_modules = optim_target_modules,
-             batch_eval_metrics = batch_eval_metrics,
-             eval_on_start = eval_on_start,
-             use_liger_kernel = use_liger_kernel,
-             liger_kernel_config = liger_kernel_config,
-             eval_use_gather_object = eval_use_gather_object,
-             average_tokens_across_devices = average_tokens_across_devices,
-             max_length = max_length,
-             max_prompt_length = max_prompt_length,
-             max_completion_length = max_completion_length,
-             disable_dropout = disable_dropout,
-             step_separator = step_separator,
-             train_on_last_step_only = train_on_last_step_only,
-             dataset_num_proc = dataset_num_proc,**kwargs)
-         self.vllm_sampling_params = vllm_sampling_params
-         self.unsloth_num_chunks = unsloth_num_chunks
-     pass
-
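For context, the config wrapper above validates the learning rate, swaps in a safe default `output_dir`/`save_strategy`, and caps `dataset_num_proc`. A hypothetical instantiation — the import path merely mirrors this deleted file's name and is an assumption:

```python
# A minimal sketch, assuming this deleted module were importable as written.
from UnslothPRMTrainer import UnslothPRMConfig  # hypothetical import path

config = UnslothPRMConfig(
    output_dir="prm_checkpoints",  # hypothetical directory
    learning_rate=5e-5,            # guarded: must lie within [1e-7, 1]
    max_length=1024,
    max_prompt_length=512,
    step_separator="\n",
    train_on_last_step_only=False,
)

# Out-of-range learning rates fail fast rather than silently stalling training:
try:
    UnslothPRMConfig(learning_rate=1e-9)
except FloatingPointError as err:
    print(err)
```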
- class _UnslothPRMTrainer(Trainer):
-     """"""
-
-     _tag_names = ["trl", "prm"]
-
-     def __init__(
-         self,
-         model: Optional[Union[PreTrainedModel, nn.Module]] = None,
-         args: Optional[PRMConfig] = None,
-         data_collator: Optional[DataCollator] = None,
-         train_dataset: Optional[Dataset] = None,
-         eval_dataset: Optional[Union[Dataset, dict[str, Dataset]]] = None,
-         processing_class: Optional[
-             Union[PreTrainedTokenizerBase, BaseImageProcessor, FeatureExtractionMixin, ProcessorMixin]
-         ] = None,
-         model_init: Optional[Callable[[], PreTrainedModel]] = None,
-         compute_metrics: Optional[Callable[[EvalPrediction], dict]] = None,
-         callbacks: Optional[list[TrainerCallback]] = None,
-         optimizers: tuple[torch.optim.Optimizer, torch.optim.lr_scheduler.LambdaLR] = (
-             None,
-             None,
-         ),
-         preprocess_logits_for_metrics: Optional[Callable[[torch.Tensor, torch.Tensor], torch.Tensor]] = None,
-         peft_config: Optional[dict] = None,
-     ):
-         if not is_peft_available() and peft_config is not None:
-             raise ValueError(
-                 "PEFT is not installed and you passed a `peft_config` in the trainer's kwargs, please install it to use the PEFT models"
-             )
-         elif is_peft_available() and peft_config is not None:
-             if not isinstance(model, PeftModel):
-                 if getattr(model, "is_loaded_in_8bit", False) or getattr(model, "is_quantized", False):
-                     _supports_gc_kwargs = "gradient_checkpointing_kwargs" in list(
-                         inspect.signature(prepare_model_for_kbit_training).parameters
-                     )
-
-                     prepare_model_kwargs = {"use_gradient_checkpointing": args.gradient_checkpointing}
-
-                     if not _supports_gc_kwargs and args.gradient_checkpointing_kwargs is not None:
-                         warnings.warn(
-                             "You passed `gradient_checkpointing_kwargs` in the trainer's kwargs, but your peft version does not support it. "
-                             "please update to the latest version of peft to use `gradient_checkpointing_kwargs`."
-                         )
-                     elif _supports_gc_kwargs and args.gradient_checkpointing_kwargs is not None:
-                         prepare_model_kwargs["gradient_checkpointing_kwargs"] = args.gradient_checkpointing_kwargs
-
-                     model = prepare_model_for_kbit_training(model, **prepare_model_kwargs)
-
-                 model = model
-
-         # Disable dropout in the model
-         if args.disable_dropout:
-             disable_dropout_in_model(model)
-
-         if compute_metrics is None:
-             compute_metrics = compute_accuracy
-
-         if data_collator is None:
-             if processing_class is None:
-                 raise ValueError(
-                     "A processing_class must be specified when using the default DataCollatorForTokenClassification"
-                 )
-             data_collator = DataCollatorForTokenClassification(processing_class, max_length=args.max_length)
-
-         if "input_ids" not in train_dataset.column_names:
-             with PartialState().main_process_first():
-                 fn_kwargs = {
-                     "tokenizer": processing_class,
-                     "step_separator": args.step_separator,
-                     "max_length": args.max_length,
-                     "max_prompt_length": args.max_prompt_length,
-                     "max_completion_length": args.max_completion_length,
-                     "train_on_last_step_only": args.train_on_last_step_only,
-                 }
-                 train_fn_kwargs = {**fn_kwargs, "is_eval": False}
-                 train_dataset = train_dataset.map(
-                     self.tokenize_row,
-                     fn_kwargs=train_fn_kwargs,
-                     num_proc=args.dataset_num_proc,
-                     remove_columns=train_dataset.features,
-                     desc="Tokenizing train dataset",
-                     features=features.Features(  # needed to avoid map to cast labels to bool
-                         {
-                             "labels": features.Sequence(features.Value("int64")),
-                             "input_ids": features.Sequence(features.Value("int64")),
-                         }
-                     ),
-                 )
-
-                 eval_fn_kwargs = {**fn_kwargs, "is_eval": True}
-                 if eval_dataset is not None:
-                     eval_dataset = eval_dataset.map(
-                         self.tokenize_row,
-                         fn_kwargs=eval_fn_kwargs,
-                         num_proc=args.dataset_num_proc,
-                         remove_columns=eval_dataset.features,
-                         desc="Tokenizing eval dataset",
-                         features=features.Features(  # needed to avoid map to cast labels to bool
-                             {
-                                 "labels": features.Sequence(features.Value("int64")),
-                                 "input_ids": features.Sequence(features.Value("int64")),
-                             }
-                         ),
-                     )
-
-         super().__init__(
-             model=model,
-             args=args,
-             data_collator=data_collator,
-             train_dataset=train_dataset,
-             eval_dataset=eval_dataset,
-             processing_class=processing_class,
-             model_init=model_init,
-             compute_metrics=compute_metrics,
-             callbacks=callbacks,
-             optimizers=optimizers,
-             preprocess_logits_for_metrics=preprocess_logits_for_metrics,
-         )
-
-         # Add tags for models that have been loaded with the correct transformers version
-         if hasattr(self.model, "add_model_tags"):
-             self.model.add_model_tags(self._tag_names)
-
-     @staticmethod
-     def tokenize_row(
-         features,
-         tokenizer,
-         step_separator,
-         max_length,
-         max_prompt_length,
-         max_completion_length,
-         train_on_last_step_only,
-         is_eval,
-     ):
-         r"""
-         Tokenize a row of the dataset.
-
-         Args:
-             features (`dict[str, str]`):
-                 Row of the dataset, should contain the keys `"prompt"`, `"completions"`, and `"labels"`.
-             tokenizer (`PreTrainedTokenizerBase`):
-                 Tokenizer used to process the data.
-             step_separator (`str`):
-                 Separator between steps in the completion.
-             max_length (`int` or `None`):
-                 Maximum length of the sequences (prompt + completion). If `None`, the sequences are not truncated.
-             max_prompt_length (`int` or `None`):
-                 Maximum length of the prompt. If `None`, the prompt is not truncated.
-             max_completion_length (`int` or `None`):
-                 Maximum length of the completion sequences. If `None`, the completion sequences are not truncated.
-             train_on_last_step_only (`bool`):
-                 Whether to train only on the last step. If `True`, the labels are `-100` for all tokens except the last
-                 token of the completion.
-             is_eval (`bool`):
-                 Whether the function is used to tokenize samples from a training or an evaluation dataset. Used only if `train_on_last_step_only` is set to `True`.
-
-         Returns:
-             `dict[str, list[int]]`:
-                 Tokenized sequences with the keys `"input_ids"` and `"labels"`.
-
-         Example:
-         ```python
-         >>> from transformers import AutoTokenizer
-         >>> tokenizer = AutoTokenizer.from_pretrained("Qwen/Qwen2.5-0.5B")
-         >>> features = {"prompt": "Which number is larger, 9.8 or 9.11?",
-         ...             "completions": ["11 is greater than 8.",
-         ...                             "Hence, 9.11 > 9.8."],
-         ...             "labels": [True, False]}
-         >>> PRMTrainer.tokenize_row(features, tokenizer, "\n", max_completion_length=None, train_on_last_step_only=False, is_eval=False)
-         {'input_ids': [23085, 1372, 374, 8131, 11, 220, 24, 13, 23, 476, 220, 24, 13, 16, 16, 30, 16, 16, 374, 7046, 1091, 220, 23, 13, 198, 39, 763, 11, 220, 24, 13, 16, 16, 861, 220, 24, 13, 23, 13, 198],
-          'labels': [-100, -100, -100, -100, -100, -100, -100, -100, 1, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, 0]}
-         ```
-         """
-         # Tokenize the prompt and completions
-         prompt_ids = tokenizer(features["prompt"], add_special_tokens=False)["input_ids"]
-         completions_ids = [
-             tokenizer(completion, add_special_tokens=False)["input_ids"] for completion in features["completions"]
-         ]
-         if train_on_last_step_only and not is_eval:
-             labels = [-100] * (len(features["labels"]) - 1) + [int(features["labels"][-1])]
-         else:
-             labels = [int(label) for label in features["labels"]]
-
-         # Get the ID of the separator token and add it to the completions
-         separator_ids = tokenizer.encode(step_separator, add_special_tokens=False)
-         completions_ids = [completion + separator_ids for completion in completions_ids]
-
-         # Create the label
-         labels = [[-100] * (len(completion) - 1) + [label] for completion, label in zip(completions_ids, labels)]
-
-         # Join the completions and labels steps
-         completion_ids = list(chain(*completions_ids))
-         labels = list(chain(*labels))
-
-         if tokenizer.bos_token_id is not None:
-             prompt_ids = [tokenizer.bos_token_id] + prompt_ids
-
-         # Truncate prompt and completion sequences
-         if max_prompt_length is not None:
-             prompt_ids = prompt_ids[-max_prompt_length:]
-         if max_completion_length is not None:
-             completion_ids = completion_ids[:max_completion_length]
-             labels = labels[:max_completion_length]
-
-         input_ids = prompt_ids + completion_ids
-         labels = [-100] * len(prompt_ids) + labels
-
-         if max_length is not None:
-             input_ids = input_ids[:max_length]
-             labels = labels[:max_length]
-
-         return {"input_ids": input_ids, "labels": labels}
-
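`tokenize_row` above supervises only the last token of each reasoning step (the separator position) and masks every other position with -100. The core label construction, isolated as a pure-Python sketch with hand-made token ids:

```python
from itertools import chain

def build_step_labels(completions_ids, step_labels):
    """One label per step, placed on that step's final (separator) token; -100 elsewhere."""
    per_step = [[-100] * (len(ids) - 1) + [int(label)]
                for ids, label in zip(completions_ids, step_labels)]
    return list(chain(*per_step))

# toy token ids for two steps, each already ending in a separator token (99)
steps = [[11, 12, 99], [21, 99]]
print(build_step_labels(steps, [True, False]))
# [-100, -100, 1, -100, 0]
```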
- def create_model_card(
594
- self,
595
- model_name: Optional[str] = None,
596
- dataset_name: Optional[str] = None,
597
- tags: Union[str, list[str], None] = None,
598
- ):
599
- """
600
- Creates a draft of a model card using the information available to the `Trainer`.
601
-
602
- Args:
603
- model_name (`str` or `None`, *optional*, defaults to `None`):
604
- Name of the model.
605
- dataset_name (`str` or `None`, *optional*, defaults to `None`):
606
- Name of the dataset used for training.
607
- tags (`str`, `list[str]` or `None`, *optional*, defaults to `None`):
608
- Tags to be associated with the model card.
609
- """
610
- if not self.is_world_process_zero():
611
- return
612
-
613
- if hasattr(self.model.config, "_name_or_path") and not os.path.isdir(self.model.config._name_or_path):
614
- base_model = self.model.config._name_or_path
615
- else:
616
- base_model = None
617
-
618
- tags = tags or []
619
- if isinstance(tags, str):
620
- tags = [tags]
621
-
622
- if hasattr(self.model.config, "unsloth_version"):
623
- tags.append("unsloth")
624
-
625
- citation = textwrap.dedent("""\
626
- @article{uesato2022solving,
627
- title = {{Solving Math Word Problems With Process- and Outcome-Based Feedback}},
628
- author = {Uesato, Jonathan and Kushman, Nate and Kumar, Ramana and Song, Francis and Siegel, Noah and Wang, Lisa and Creswell, Antonia and Irving, Geoffrey and Higgins, Irina},
629
- year = 2022,
630
- journal = {arXiv preprint arXiv:2211.14275}
631
- }""")
632
-
633
- model_card = generate_model_card(
634
- base_model=base_model,
635
- model_name=model_name,
636
- hub_model_id=self.hub_model_id,
637
- dataset_name=dataset_name,
638
- tags=tags,
639
- wandb_url=wandb.run.get_url() if is_wandb_available() and wandb.run is not None else None,
640
- trainer_name="PRM",
641
- trainer_citation=citation,
642
- paper_title="Solving math word problems with process-and outcome-based feedback",
643
- )
644
-
645
-         model_card.save(os.path.join(self.args.output_dir, "README.md"))
- class UnslothPRMTrainer(_UnslothPRMTrainer):
-     """
- 
-     Initialize PRMTrainer.
- 
-     Args:
-         model (`transformers.PreTrainedModel`):
-             The model to train, preferably an `AutoModelForTokenClassification`.
-         args (`PRMConfig`):
-             The arguments to use for training.
-         data_collator (`transformers.DataCollator`):
-             The data collator to use for training. If None is specified, the default data collator
-             (`DataCollatorForTokenClassification`) will be used, which will pad the sequences to the maximum
-             length of the sequences in the batch, given a dataset of paired sequences.
-         train_dataset (`datasets.Dataset`):
-             The dataset to use for training.
-         eval_dataset (`datasets.Dataset`):
-             The dataset to use for evaluation.
-         processing_class (`PreTrainedTokenizerBase` or `BaseImageProcessor` or `FeatureExtractionMixin` or `ProcessorMixin`, *optional*):
-             Processing class used to process the data. If provided, will be used to automatically process the inputs
-             for the model, and it will be saved along with the model to make it easier to rerun an interrupted
-             training or reuse the fine-tuned model.
-         model_init (`Callable[[], transformers.PreTrainedModel]`):
-             The model initializer to use for training. If None is specified, the default model initializer will be used.
-         compute_metrics (`Callable[[transformers.EvalPrediction], dict]`, *optional*, defaults to `compute_accuracy`):
-             The metrics to use for evaluation. If no metrics are specified, the default metric (`compute_accuracy`) will be used.
-         callbacks (`list[transformers.TrainerCallback]`):
-             The callbacks to use for training.
-         optimizers (`tuple[torch.optim.Optimizer, torch.optim.lr_scheduler.LambdaLR]`):
-             The optimizer and scheduler to use for training.
-         preprocess_logits_for_metrics (`Callable[[torch.Tensor, torch.Tensor], torch.Tensor]`):
-             The function to use to preprocess the logits before computing the metrics.
-         peft_config (`dict`, defaults to `None`):
-             The PEFT configuration to use for training. If you pass a PEFT configuration, the model will be wrapped
-             in a PEFT model.
- 
-     """
-     def __init__(
-         self,
-         model = None,
-         args = None,
-         data_collator = None,
-         train_dataset = None,
-         eval_dataset = None,
-         processing_class = None,
-         model_init = None,
-         compute_metrics = None,
-         callbacks = None,
-         preprocess_logits_for_metrics = None,
-         peft_config = None,
-         **kwargs
-     ):
-         if args is None: args = UnslothPRMConfig()
-         use_bf16 = getattr(args, 'bf16', False)
-         if type(use_bf16) is not bool: use_bf16 = False
-         use_fp16 = getattr(args, 'fp16', False)
-         if type(use_fp16) is not bool: use_fp16 = False
-         force_float32 = False
-         if os.environ.get('UNSLOTH_FORCE_FLOAT32', '0') == '1':
-             print('Unsloth: Switching to float32 training since model cannot work with float16')
-             force_float32 = True
-         mixed_precision_dtype = os.environ.get('UNSLOTH_MIXED_PRECISION', 'float32')
-         dtype = getattr(model.config, 'torch_dtype', None)
-         if dtype is None: dtype = model.get_input_embeddings().dtype
-         from unsloth_zoo.utils import _get_dtype
-         dtype = _get_dtype(dtype)
-         float16 = dtype == torch.float16
-         if not force_float32 and (float16 and use_bf16): raise TypeError('Unsloth: Model is in float16 precision but you want to use bfloat16 precision. Set fp16 to `True` and bf16 to `False`')
-         if not force_float32 and (not float16 and use_fp16): raise TypeError('Unsloth: Model is in bfloat16 precision but you want to use float16 precision. Set fp16 to `False` and bf16 to `True`')
-         if force_float32:
-             args.fp16 = False
-             args.bf16 = False
-             os.environ['ACCELERATE_MIXED_PRECISION'] = 'no'
-         elif (not use_bf16 and not use_fp16) and mixed_precision_dtype == 'float32':
-             args.fp16 = float16
-             args.bf16 = not float16
-             os.environ['ACCELERATE_MIXED_PRECISION'] = 'fp16' if float16 else 'bf16'
-         if eval_dataset is not None and getattr(args, 'eval_strategy', 'no') == 'no':
-             args.eval_strategy = 'steps'
-             if getattr(args, 'eval_steps', None) is None: args.eval_steps = 0.1
-         ga_steps = getattr(args, 'gradient_accumulation_steps', None)
-         if ga_steps is not None and ga_steps > 1:
-             from transformers import __version__ as transformers_version
-             if Version(transformers_version) <= Version('4.45.2'):
-                 print('**** Unsloth: Please use our fixed gradient_accumulation_steps by updating transformers, TRL and Unsloth!\n'
-                       '`pip install --upgrade --no-cache-dir --force-reinstall --no-deps unsloth transformers trl unsloth_zoo`')
-         if getattr(args, 'eval_strategy', 'no') != 'no':
-             eval_bsz = getattr(args, 'per_device_eval_batch_size', 8)
-             if eval_bsz == 8 and args.per_device_train_batch_size < eval_bsz: args.per_device_eval_batch_size = args.per_device_train_batch_size
-             if getattr(args, 'eval_accumulation_steps', None) is None and ga_steps is not None: args.eval_accumulation_steps = ga_steps
-         fp16_full_eval = getattr(args, 'fp16_full_eval', False)
-         if type(fp16_full_eval) is not bool: fp16_full_eval = False
-         bf16_full_eval = getattr(args, 'bf16_full_eval', False)
-         if type(bf16_full_eval) is not bool: bf16_full_eval = False
-         if args.fp16 and bf16_full_eval: args.bf16_full_eval = False; args.fp16_full_eval = True
-         if args.bf16 and fp16_full_eval: args.bf16_full_eval = True; args.fp16_full_eval = False
-         if force_float32:
-             args.bf16_full_eval = False
-             args.fp16_full_eval = False
-         elif os.environ.get('UNSLOTH_MIXED_PRECISION', 'float32') == 'bfloat16':
-             args.bf16_full_eval = True
-             args.fp16_full_eval = False
-         elif not bf16_full_eval and not fp16_full_eval:
-             args.bf16_full_eval = args.bf16
-             args.fp16_full_eval = args.fp16
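        # Annotation (not part of the original file): a concrete reading of the
        # precision resolution above. With a bfloat16 base model and bf16/fp16
        # left unset, it ends up with args.bf16 = True, args.fp16 = False and
        # ACCELERATE_MIXED_PRECISION='bf16'; a float16 model flips all three the
        # other way; and UNSLOTH_FORCE_FLOAT32='1' disables mixed precision for
        # both training and full evaluation.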
-         _output_logits = False
-         if locals().get('compute_metrics', None) is not None: _output_logits = True
-         if locals().get('preprocess_logits_for_metrics', None) is not None: _output_logits = True
-         if _output_logits:
-             os.environ['UNSLOTH_RETURN_LOGITS'] = '1'
-         if 'max_seq_length' not in locals() and not hasattr(args, 'max_seq_length'):
-             pass
-         else:
-             model_max_seq_length = getattr(model, 'max_seq_length', None)
-             args_max_seq_length = getattr(args, 'max_seq_length', None)
-             if args_max_seq_length is None and model_max_seq_length is not None:
-                 max_seq_length = model.max_seq_length
-                 if hasattr(args, 'max_seq_length'): args.max_seq_length = max_seq_length
-         if model is not None and hasattr(model, 'for_training'):
-             model.for_training()
-         if 'tokenizer' in locals() and hasattr(tokenizer, 'padding_side'): tokenizer.padding_side = 'right'
-         if 'processing_class' in locals():
-             if hasattr(processing_class, 'padding_side'): processing_class.padding_side = 'right'
-             if hasattr(processing_class, 'tokenizer') and hasattr(processing_class.tokenizer, 'padding_side'): processing_class.tokenizer.padding_side = 'right'
-         __tokenizer = processing_class if 'processing_class' in locals() else tokenizer
-         from unsloth_zoo.vision_utils import UnslothVisionDataCollator
-         if not isinstance(data_collator, UnslothVisionDataCollator):
-             if isinstance(data_collator, DataCollatorForSeq2Seq) and 'labels' not in train_dataset.column_names:
-                 data_collator = TransformersDataCollatorForLanguageModeling(__tokenizer, mlm = False, mlm_probability = 0.0)
-             elif isinstance(data_collator, TransformersDataCollatorForLanguageModeling) and 'labels' in train_dataset.column_names:
-                 data_collator = DataCollatorForSeq2Seq(__tokenizer)
-         else:
-             if hasattr(args, 'remove_unused_columns'): args.remove_unused_columns = False
-             if hasattr(args, 'dataset_text_field'): args.dataset_text_field = ''
-             if hasattr(args, 'dataset_kwargs'): args.dataset_kwargs = {'skip_prepare_dataset': True}
-         if not isinstance(data_collator, UnslothVisionDataCollator):
-             if not hasattr(__tokenizer, 'pad') and hasattr(__tokenizer, 'tokenizer'):
-                 if isinstance(data_collator, DataCollatorForSeq2Seq):
-                     data_collator = DataCollatorForSeq2Seq(__tokenizer.tokenizer)
-                 else:
-                     data_collator = TransformersDataCollatorForLanguageModeling(__tokenizer.tokenizer, mlm = False, mlm_probability = 0.0)
-         other_metrics = []
- 
-         from unsloth_zoo.logging_utils import PatchRLStatistics
-         PatchRLStatistics('prm_trainer', other_metrics)
- 
-         super().__init__(
-             model = model,
-             args = args,
-             data_collator = data_collator,
-             train_dataset = train_dataset,
-             eval_dataset = eval_dataset,
-             processing_class = processing_class,
-             model_init = model_init,
-             compute_metrics = compute_metrics,
-             callbacks = callbacks,
-             preprocess_logits_for_metrics = preprocess_logits_for_metrics,
-             peft_config = peft_config,
-             **kwargs)
-         if hasattr(self, 'neftune_hook_handle'):
-             self.neftune_hook_handle.remove()
-         if hasattr(self, 'neftune_hook_handle'): del self.neftune_hook_handle
-         if getattr(args, 'neftune_noise_alpha', None) is not None:
-             model.get_input_embeddings().neftune_noise_alpha = self.neftune_noise_alpha
-         pass
- 
- pass
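For orientation, a minimal usage sketch of the wrapper above (illustrative only: the checkpoint and dataset names are placeholders, `UnslothPRMConfig` is assumed to be defined earlier in this file, and the dataset is expected to follow TRL's stepwise PRM format):

    from transformers import AutoModelForTokenClassification, AutoTokenizer
    from datasets import load_dataset

    model = AutoModelForTokenClassification.from_pretrained("Qwen/Qwen2-0.5B-Instruct", num_labels = 2)
    tokenizer = AutoTokenizer.from_pretrained("Qwen/Qwen2-0.5B-Instruct")
    train_dataset = load_dataset("trl-lib/math_shepherd", split = "train[:1%]")

    trainer = UnslothPRMTrainer(
        model = model,
        args = UnslothPRMConfig(output_dir = "prm_test", per_device_train_batch_size = 2),
        processing_class = tokenizer,
        train_dataset = train_dataset,
    )
    trainer.train()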
test_run_uploads/UnslothRLOOTrainer.py DELETED
@@ -1,1143 +0,0 @@
- """
- 2025.7.11
- 2025.7.11
- 4.54.1
- 0.16.1
- __UNSLOTH_VERSIONING__
- """
- from torch import Tensor
- import torch
- import torch.nn as nn
- from torch.nn import functional as F
- from typing import Any, List, Optional, Tuple, Union, Dict, Set, Callable
- from trl.trainer.rloo_trainer import (Accelerator, BaseImageProcessor, Callable, CallbackHandler, DEFAULT_CALLBACKS, DEFAULT_PROGRESS_CALLBACK, DataCollatorWithPadding, DataLoader, Dataset, ExportableState, FeatureExtractionMixin, GenerationConfig, INVALID_LOGPROB, OnlineTrainerState, Optional, PreTrainedTokenizerBase, PrinterCallback, ProcessorMixin, RLOOConfig, RLOOTrainer, Trainer, TrainerCallback, TrainerControl, Union, batch_generation, broadcast, defaultdict, disable_dropout_in_model, exact_div, first_true_indices, forward, gather_object, gc, generate_model_card, get_comet_experiment_url, get_reporting_integration_callbacks, get_reward, is_wandb_available, log_table_to_comet_experiment, math, nn, np, os, pd, prepare_deepspeed, print_rich_table, selective_log_softmax, textwrap, time, torch, truncate_response, unwrap_model_for_generation, wandb)
- 
- 
- import os
- from typing import *
- from dataclasses import dataclass, field
- from packaging.version import Version
- import torch
- import numpy as np
- from contextlib import nullcontext
- from torch.nn import functional as F
- from transformers import DataCollatorForSeq2Seq, DataCollatorForLanguageModeling as TransformersDataCollatorForLanguageModeling
- 
- torch_compile_options = {
-     "epilogue_fusion" : True,
-     "max_autotune" : False,
-     "shape_padding" : True,
-     "trace.enabled" : False,
-     "triton.cudagraphs" : False,
- }
- 
- @torch.compile(dynamic = True, fullgraph = True, options = torch_compile_options,)
- def chunked_selective_log_softmax(logits, index):
-     # Split into 4 chunks only
-     chunked_logits = torch.chunk(logits.reshape(-1, logits.shape[-1]), chunks = 4, dim = 0)
-     chunked_index  = torch.chunk(index.reshape(-1), chunks = 4, dim = 0)
-     all_per_token_logps = []
-     # Below loop does the same as selective_log_softmax(chunk_logits, chunk_index)
-     for chunk_logits, chunk_index in zip(chunked_logits, chunked_index):
-         chunk_logits = chunk_logits.to(torch.float32)
-         selected_logits = torch.gather(chunk_logits, dim = -1, index = chunk_index.unsqueeze(-1)).squeeze(-1)
-         logsumexp_values = torch.logsumexp(chunk_logits, dim = -1)
-         per_token_logps = selected_logits - logsumexp_values
-         all_per_token_logps.append(per_token_logps)
-     pass
-     all_per_token_logps = torch.concat(all_per_token_logps)
-     all_per_token_logps = all_per_token_logps.reshape((logits.shape[0], logits.shape[1]))
-     return all_per_token_logps
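For intuition, a quick equivalence check (illustrative only; the shapes are arbitrary): the chunked routine above should agree with a plain log-softmax gather up to float32 rounding, while peaking at roughly a quarter of the memory.

    logits = torch.randn(2, 8, 32)             # (batch, seq_len, vocab_size)
    index  = torch.randint(0, 32, (2, 8))      # chosen token id per position
    direct = torch.log_softmax(logits.float(), dim = -1)
    direct = direct.gather(-1, index.unsqueeze(-1)).squeeze(-1)
    assert torch.allclose(chunked_selective_log_softmax(logits, index), direct, atol = 1e-5)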
- @dataclass
- class UnslothRLOOConfig(RLOOConfig):
-     """
- 
-     Configuration class for the [`RLOOTrainer`].
- 
-     Using [`~transformers.HfArgumentParser`] we can turn this class into
-     [argparse](https://docs.python.org/3/library/argparse#module-argparse) arguments that can be specified on the
-     command line.
- 
-     Parameters:
-         exp_name (`str`, *optional*, defaults to `os.path.basename(__file__)[: -len(".py")]`):
-             Name of this experiment.
-         reward_model_path (`str`, *optional*, defaults to `"EleutherAI/pythia-160m"`):
-             Path to the reward model.
-         num_ppo_epochs (`int`, *optional*, defaults to `4`):
-             Number of epochs to train.
-         whiten_rewards (`bool`, *optional*, defaults to `False`):
-             Whether to whiten the rewards.
-         kl_coef (`float`, *optional*, defaults to `0.05`):
-             KL coefficient.
-         cliprange (`float`, *optional*, defaults to `0.2`):
-             Clip range.
-         rloo_k (`int`, *optional*, defaults to `2`):
-             REINFORCE Leave-One-Out (RLOO) number of online samples per prompt.
-         normalize_reward (`bool`, *optional*, defaults to `False`):
-             Whether to normalize rewards.
-         reward_clip_range (`float`, *optional*, defaults to `10.0`):
-             Clip range for rewards.
-         normalize_advantage (`bool`, *optional*, defaults to `False`):
-             Whether to normalize advantages.
-         token_level_kl (`bool`, *optional*, defaults to `False`):
-             Whether to use token-level KL penalty or sequence-level KL penalty.
-         ds3_gather_for_generation (`bool`, *optional*, defaults to `True`):
-             This setting applies to DeepSpeed ZeRO-3. If enabled, the policy model weights are gathered for
-             generation, improving generation speed. However, disabling this option allows training models that
-             exceed the VRAM capacity of a single GPU, albeit at the cost of slower generation.
- 
-     """
-     vllm_sampling_params: Optional[Any] = field(
-         default = None,
-         metadata = {'help': 'vLLM SamplingParams'},
-     )
-     unsloth_num_chunks : Optional[int] = field(
-         default = -1,
-         metadata = {'help': 'Chunk size to reduce memory usage. -1 is most efficient.'},
-     )
-     def __init__(
-         self,
-         output_dir = None,
-         overwrite_output_dir = None,
-         do_train = False,
-         do_eval = False,
-         do_predict = False,
-         eval_strategy = 'no',
-         prediction_loss_only = False,
-         per_device_train_batch_size = 4,
-         per_device_eval_batch_size = 4,
-         per_gpu_train_batch_size = None,
-         per_gpu_eval_batch_size = None,
-         gradient_accumulation_steps = 2,
-         eval_accumulation_steps = 2,
-         eval_delay = 0,
-         torch_empty_cache_steps = 250,
-         learning_rate = 5e-05,
-         weight_decay = 0.01,
-         adam_beta1 = 0.9,
-         adam_beta2 = 0.999,
-         adam_epsilon = 1e-08,
-         max_grad_norm = 1.0,
-         num_train_epochs = 3.0,
-         max_steps = -1,
-         lr_scheduler_type = 'linear',
-         warmup_ratio = 0.1,
-         warmup_steps = 0,
-         log_level = 'passive',
-         log_level_replica = 'warning',
-         log_on_each_node = True,
-         logging_dir = None,
-         logging_strategy = 'steps',
-         logging_first_step = False,
-         logging_steps = 1,
-         logging_nan_inf_filter = False,
-         save_strategy = 'steps',
-         save_steps = 500,
-         save_total_limit = None,
-         save_safetensors = True,
-         save_on_each_node = False,
-         save_only_model = False,
-         restore_callback_states_from_checkpoint = False,
-         no_cuda = False,
-         use_cpu = False,
-         use_mps_device = False,
-         seed = 3407,
-         data_seed = 3407,
-         jit_mode_eval = False,
-         use_ipex = False,
-         bf16 = False,
-         fp16 = False,
-         fp16_opt_level = 'O1',
-         half_precision_backend = 'auto',
-         bf16_full_eval = False,
-         fp16_full_eval = False,
-         tf32 = None,
-         local_rank = -1,
-         ddp_backend = None,
-         tpu_num_cores = None,
-         tpu_metrics_debug = False,
-         debug = '',
-         dataloader_drop_last = False,
-         eval_steps = None,
-         dataloader_num_workers = 0,
-         dataloader_prefetch_factor = None,
-         past_index = -1,
-         run_name = None,
-         disable_tqdm = None,
-         remove_unused_columns = True,
-         label_names = None,
-         load_best_model_at_end = False,
-         metric_for_best_model = None,
-         greater_is_better = None,
-         ignore_data_skip = False,
-         fsdp = '',
-         fsdp_min_num_params = 0,
-         fsdp_config = None,
-         fsdp_transformer_layer_cls_to_wrap = None,
-         accelerator_config = None,
-         deepspeed = None,
-         label_smoothing_factor = 0.0,
-         optim = 'adamw_8bit',
-         optim_args = None,
-         adafactor = False,
-         group_by_length = False,
-         length_column_name = 'length',
-         report_to = None,
-         ddp_find_unused_parameters = None,
-         ddp_bucket_cap_mb = None,
-         ddp_broadcast_buffers = None,
-         dataloader_pin_memory = True,
-         dataloader_persistent_workers = False,
-         skip_memory_metrics = True,
-         use_legacy_prediction_loop = False,
-         push_to_hub = False,
-         resume_from_checkpoint = None,
-         hub_model_id = None,
-         hub_strategy = 'every_save',
-         hub_token = None,
-         hub_private_repo = None,
-         hub_always_push = False,
-         hub_revision = None,
-         gradient_checkpointing = False,
-         gradient_checkpointing_kwargs = None,
-         include_inputs_for_metrics = False,
-         eval_do_concat_batches = True,
-         fp16_backend = 'auto',
-         push_to_hub_model_id = None,
-         push_to_hub_organization = None,
-         push_to_hub_token = None,
-         mp_parameters = '',
-         auto_find_batch_size = True,
-         full_determinism = False,
-         torchdynamo = None,
-         ray_scope = 'last',
-         ddp_timeout = 1800,
-         torch_compile = False,
-         torch_compile_backend = None,
-         torch_compile_mode = None,
-         include_tokens_per_second = False,
-         include_num_input_tokens_seen = False,
-         neftune_noise_alpha = None,
-         optim_target_modules = None,
-         batch_eval_metrics = False,
-         eval_on_start = False,
-         use_liger_kernel = False,
-         liger_kernel_config = None,
-         eval_use_gather_object = False,
-         average_tokens_across_devices = True,
-         dataset_num_proc = None,
-         num_mini_batches = 1,
-         total_episodes = None,
-         local_rollout_forward_batch_size = 64,
-         num_sample_generations = 10,
-         response_length = 53,
-         stop_token = None,
-         stop_token_id = None,
-         temperature = 0.7,
-         missing_eos_penalty = None,
-         sft_model_path = 'EleutherAI/pythia-160m',
-         world_size = None,
-         num_total_batches = None,
-         micro_batch_size = None,
-         local_batch_size = None,
-         batch_size = None,
-         local_mini_batch_size = None,
-         mini_batch_size = None,
-         exp_name = 'rloo_config',
-         reward_model_path = 'EleutherAI/pythia-160m',
-         num_ppo_epochs = 4,
-         whiten_rewards = False,
-         kl_coef = 0.05,
-         cliprange = 0.2,
-         rloo_k = 2,
-         normalize_reward = False,
-         reward_clip_range = 10.0,
-         normalize_advantage = False,
-         token_level_kl = False,
-         ds3_gather_for_generation = True,
-         vllm_sampling_params = None,
-         unsloth_num_chunks = -1,
-         **kwargs,
-     ):
-         if learning_rate < 1e-7: raise FloatingPointError(f'Unsloth: Your learning rate of `{learning_rate}` is too small and less than 1e-7! Consider increasing it, otherwise gradient updates will be close to 0!')
-         if learning_rate > 1: raise OverflowError(f'Unsloth: Your learning rate of `{learning_rate}` is way too large (> 1)! Consider decreasing it to 1e-1, otherwise gradient updates will explode!')
-         if output_dir is None and save_strategy == 'steps' and save_steps == 500:
-             output_dir = 'unsloth_training_checkpoints'
-             save_strategy = 'no'
-         if dataset_num_proc is None:
-             from multiprocessing import cpu_count
-             dataset_num_proc = min(cpu_count()*2, 2)
-         if temperature <= 0:
-             raise ValueError('Unsloth: Please set a positive non-zero temperature since your results will be wrong.')
-         elif temperature >= 10:
-             raise ValueError('Unsloth: Please set a positive non-zero temperature less than 10, since sampling will be quite erratic.')
- 
- 
-         super().__init__(
-             output_dir = output_dir,
-             overwrite_output_dir = overwrite_output_dir,
-             do_train = do_train,
-             do_eval = do_eval,
-             do_predict = do_predict,
-             eval_strategy = eval_strategy,
-             prediction_loss_only = prediction_loss_only,
-             per_device_train_batch_size = per_device_train_batch_size,
-             per_device_eval_batch_size = per_device_eval_batch_size,
-             per_gpu_train_batch_size = per_gpu_train_batch_size,
-             per_gpu_eval_batch_size = per_gpu_eval_batch_size,
-             gradient_accumulation_steps = gradient_accumulation_steps,
-             eval_accumulation_steps = eval_accumulation_steps,
-             eval_delay = eval_delay,
-             torch_empty_cache_steps = torch_empty_cache_steps,
-             learning_rate = learning_rate,
-             weight_decay = weight_decay,
-             adam_beta1 = adam_beta1,
-             adam_beta2 = adam_beta2,
-             adam_epsilon = adam_epsilon,
-             max_grad_norm = max_grad_norm,
-             num_train_epochs = num_train_epochs,
-             max_steps = max_steps,
-             lr_scheduler_type = lr_scheduler_type,
-             warmup_ratio = warmup_ratio,
-             warmup_steps = warmup_steps,
-             log_level = log_level,
-             log_level_replica = log_level_replica,
-             log_on_each_node = log_on_each_node,
-             logging_dir = logging_dir,
-             logging_strategy = logging_strategy,
-             logging_first_step = logging_first_step,
-             logging_steps = logging_steps,
-             logging_nan_inf_filter = logging_nan_inf_filter,
-             save_strategy = save_strategy,
-             save_steps = save_steps,
-             save_total_limit = save_total_limit,
-             save_safetensors = save_safetensors,
-             save_on_each_node = save_on_each_node,
-             save_only_model = save_only_model,
-             restore_callback_states_from_checkpoint = restore_callback_states_from_checkpoint,
-             no_cuda = no_cuda,
-             use_cpu = use_cpu,
-             use_mps_device = use_mps_device,
-             seed = seed,
-             data_seed = data_seed,
-             jit_mode_eval = jit_mode_eval,
-             use_ipex = use_ipex,
-             bf16 = bf16,
-             fp16 = fp16,
-             fp16_opt_level = fp16_opt_level,
-             half_precision_backend = half_precision_backend,
-             bf16_full_eval = bf16_full_eval,
-             fp16_full_eval = fp16_full_eval,
-             tf32 = tf32,
-             local_rank = local_rank,
-             ddp_backend = ddp_backend,
-             tpu_num_cores = tpu_num_cores,
-             tpu_metrics_debug = tpu_metrics_debug,
-             debug = debug,
-             dataloader_drop_last = dataloader_drop_last,
-             eval_steps = eval_steps,
-             dataloader_num_workers = dataloader_num_workers,
-             dataloader_prefetch_factor = dataloader_prefetch_factor,
-             past_index = past_index,
-             run_name = run_name,
-             disable_tqdm = disable_tqdm,
-             remove_unused_columns = remove_unused_columns,
-             label_names = label_names,
-             load_best_model_at_end = load_best_model_at_end,
-             metric_for_best_model = metric_for_best_model,
-             greater_is_better = greater_is_better,
-             ignore_data_skip = ignore_data_skip,
-             fsdp = fsdp,
-             fsdp_min_num_params = fsdp_min_num_params,
-             fsdp_config = fsdp_config,
-             fsdp_transformer_layer_cls_to_wrap = fsdp_transformer_layer_cls_to_wrap,
-             accelerator_config = accelerator_config,
-             deepspeed = deepspeed,
-             label_smoothing_factor = label_smoothing_factor,
-             optim = optim,
-             optim_args = optim_args,
-             adafactor = adafactor,
-             group_by_length = group_by_length,
-             length_column_name = length_column_name,
-             report_to = report_to,
-             ddp_find_unused_parameters = ddp_find_unused_parameters,
-             ddp_bucket_cap_mb = ddp_bucket_cap_mb,
-             ddp_broadcast_buffers = ddp_broadcast_buffers,
-             dataloader_pin_memory = dataloader_pin_memory,
-             dataloader_persistent_workers = dataloader_persistent_workers,
-             skip_memory_metrics = skip_memory_metrics,
-             use_legacy_prediction_loop = use_legacy_prediction_loop,
-             push_to_hub = push_to_hub,
-             resume_from_checkpoint = resume_from_checkpoint,
-             hub_model_id = hub_model_id,
-             hub_strategy = hub_strategy,
-             hub_token = hub_token,
-             hub_private_repo = hub_private_repo,
-             hub_always_push = hub_always_push,
-             hub_revision = hub_revision,
-             gradient_checkpointing = gradient_checkpointing,
-             gradient_checkpointing_kwargs = gradient_checkpointing_kwargs,
-             include_inputs_for_metrics = include_inputs_for_metrics,
-             eval_do_concat_batches = eval_do_concat_batches,
-             fp16_backend = fp16_backend,
-             push_to_hub_model_id = push_to_hub_model_id,
-             push_to_hub_organization = push_to_hub_organization,
-             push_to_hub_token = push_to_hub_token,
-             mp_parameters = mp_parameters,
-             auto_find_batch_size = auto_find_batch_size,
-             full_determinism = full_determinism,
-             torchdynamo = torchdynamo,
-             ray_scope = ray_scope,
-             ddp_timeout = ddp_timeout,
-             torch_compile = torch_compile,
-             torch_compile_backend = torch_compile_backend,
-             torch_compile_mode = torch_compile_mode,
-             include_tokens_per_second = include_tokens_per_second,
-             include_num_input_tokens_seen = include_num_input_tokens_seen,
-             neftune_noise_alpha = neftune_noise_alpha,
-             optim_target_modules = optim_target_modules,
-             batch_eval_metrics = batch_eval_metrics,
-             eval_on_start = eval_on_start,
-             use_liger_kernel = use_liger_kernel,
-             liger_kernel_config = liger_kernel_config,
-             eval_use_gather_object = eval_use_gather_object,
-             average_tokens_across_devices = average_tokens_across_devices,
-             dataset_num_proc = dataset_num_proc,
-             num_mini_batches = num_mini_batches,
-             total_episodes = total_episodes,
-             local_rollout_forward_batch_size = local_rollout_forward_batch_size,
-             num_sample_generations = num_sample_generations,
-             response_length = response_length,
-             stop_token = stop_token,
-             stop_token_id = stop_token_id,
-             temperature = temperature,
-             missing_eos_penalty = missing_eos_penalty,
-             sft_model_path = sft_model_path,
-             world_size = world_size,
-             num_total_batches = num_total_batches,
-             micro_batch_size = micro_batch_size,
-             local_batch_size = local_batch_size,
-             batch_size = batch_size,
-             local_mini_batch_size = local_mini_batch_size,
-             mini_batch_size = mini_batch_size,
-             exp_name = exp_name,
-             reward_model_path = reward_model_path,
-             num_ppo_epochs = num_ppo_epochs,
-             whiten_rewards = whiten_rewards,
-             kl_coef = kl_coef,
-             cliprange = cliprange,
-             rloo_k = rloo_k,
-             normalize_reward = normalize_reward,
-             reward_clip_range = reward_clip_range,
-             normalize_advantage = normalize_advantage,
-             token_level_kl = token_level_kl,
-             ds3_gather_for_generation = ds3_gather_for_generation,
-             **kwargs)
-         self.vllm_sampling_params = vllm_sampling_params
-         self.unsloth_num_chunks = unsloth_num_chunks
-         pass
- 
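As the class docstring notes, the config doubles as an argparse schema; a minimal sketch (assuming this module is importable so `UnslothRLOOConfig` can be referenced, and that required `TrainingArguments` fields are supplied on the command line):

    from transformers import HfArgumentParser

    parser = HfArgumentParser(UnslothRLOOConfig)
    (rloo_config,) = parser.parse_args_into_dataclasses()
    # e.g. `python train.py --output_dir rloo_out --rloo_k 4 --kl_coef 0.02`
    print(rloo_config.rloo_k, rloo_config.kl_coef)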
- class _UnslothRLOOTrainer(Trainer):
-     _tag_names = ["trl", "rloo"]
- 
-     def __init__(
-         self,
-         config: RLOOConfig,
-         processing_class: Optional[
-             Union[PreTrainedTokenizerBase, BaseImageProcessor, FeatureExtractionMixin, ProcessorMixin]
-         ],
-         policy: nn.Module,
-         ref_policy: nn.Module,
-         reward_model: Union[nn.Module, Callable[[list[str]], list[float]]],
-         train_dataset: Dataset,
-         data_collator: Optional[DataCollatorWithPadding] = None,
-         eval_dataset: Optional[Union[Dataset, dict[str, Dataset]]] = None,
-         # less commonly used
-         optimizers: tuple[torch.optim.Optimizer, torch.optim.lr_scheduler.LambdaLR] = (None, None),
-         callbacks: Optional[list[TrainerCallback]] = None,
-     ) -> None:
-         if ref_policy is policy:
-             raise ValueError(
-                 "`policy` and `ref_policy` cannot be the same object. If you want `ref_policy` to be the "
-                 "same as `policy`, you must pass a copy of it, or `None` if you use peft."
-             )
- 
-         self.args = config
-         args = config
-         self.processing_class = processing_class
-         self.policy = policy
- 
-         # Define the collator if not provided
-         if data_collator is None:
-             data_collator = DataCollatorWithPadding(self.processing_class)
- 
-         self.policy.generation_config.eos_token_id = (
-             None  # disable `pad_token_id` and `eos_token_id` because we just want to
-         )
-         self.policy.generation_config.pad_token_id = None  # generate tokens without truncation / padding
- 
-         self.ref_policy = ref_policy
-         self.reward_model = reward_model
-         self.train_dataset = train_dataset
-         self.train_dataset_len = len(train_dataset)
-         self.data_collator = data_collator
-         self.eval_dataset = eval_dataset
-         self.optimizer, self.lr_scheduler = optimizers
-         self.optimizer_cls_and_kwargs = None  # needed for transformers >= 4.47
- 
-         #########
-         # calculate various batch sizes
-         #########
-         if args.total_episodes is None:  # allow the users to define episodes in terms of epochs.
-             args.total_episodes = int(args.num_train_epochs * self.train_dataset_len)
-         accelerator = Accelerator(gradient_accumulation_steps=args.gradient_accumulation_steps)
-         self.accelerator = accelerator
-         args.world_size = accelerator.num_processes
-         args.local_batch_size = (
-             args.per_device_train_batch_size * args.gradient_accumulation_steps * args.num_mini_batches
-         )
-         args.micro_batch_size = int(args.per_device_train_batch_size * args.world_size)
-         args.batch_size = int(args.local_batch_size * args.world_size)
-         args.mini_batch_size = exact_div(
-             args.batch_size, args.num_mini_batches, "`batch_size` must be a multiple of `num_mini_batches`"
-         )
-         args.local_mini_batch_size = exact_div(
-             args.local_batch_size, args.num_mini_batches, "`local_batch_size` must be a multiple of `num_mini_batches`"
-         )
-         args.num_total_batches = math.ceil(
-             args.total_episodes / args.batch_size
-         )  # we may train for more than `total_episodes`
-         time_tensor = torch.tensor(int(time.time()), device=accelerator.device)
-         time_int = broadcast(time_tensor, 0).item()  # avoid different timestamps across processes
-         args.run_name = f"{args.exp_name}__{args.seed}__{time_int}"
-         self.local_seed = args.seed + accelerator.process_index * 100003  # Prime
-         if args.num_sample_generations > 0:
-             self.sample_generations_freq = max(1, args.num_total_batches // args.num_sample_generations)
-         self.local_dataloader_batch_size = exact_div(
-             args.local_batch_size, args.rloo_k, "`local_batch_size` must be a multiple of rloo_k"
-         )  # RLOO logic: needed because RLOO repeats the same prompt args.rloo_k times
- 
-         #########
-         # setup model, optimizer, and others
-         #########
-         for module in [policy, ref_policy, reward_model]:
-             if isinstance(module, nn.Module):
-                 disable_dropout_in_model(module)
-         if args.stop_token and args.stop_token == "eos":
-             args.stop_token_id = self.processing_class.eos_token_id
-         self.model = policy
-         self.create_optimizer_and_scheduler(
-             num_training_steps=args.num_total_batches
-         )  # note that we are calling `self.lr_scheduler.step()` manually only at the batch level
- 
-         #########
-         ### trainer specifics
-         #########
-         default_callbacks = DEFAULT_CALLBACKS + get_reporting_integration_callbacks(self.args.report_to)
-         self.callbacks = default_callbacks if callbacks is None else default_callbacks + callbacks
-         self.callback_handler = CallbackHandler(
-             self.callbacks, self.model, self.processing_class, self.optimizer, self.lr_scheduler
-         )
-         self.add_callback(PrinterCallback if self.args.disable_tqdm else DEFAULT_PROGRESS_CALLBACK)
-         self.control = TrainerControl()
-         self.state = OnlineTrainerState(
-             is_local_process_zero=self.is_local_process_zero(),
-             is_world_process_zero=self.is_world_process_zero(),
-             stateful_callbacks=[
-                 cb for cb in self.callback_handler.callbacks + [self.control] if isinstance(cb, ExportableState)
-             ],
-         )
- 
-         self.current_flos = 0
-         self.hp_search_backend = None
-         self.is_deepspeed_enabled = getattr(self.accelerator.state, "deepspeed_plugin", None) is not None
-         self.is_fsdp_enabled = getattr(self.accelerator.state, "fsdp_plugin", None) is not None
-         # Create distant repo and output directory if needed
-         self.hub_model_id = None
-         if self.args.push_to_hub:
-             self.init_hf_repo()
-         if self.args.should_save:
-             os.makedirs(self.args.output_dir, exist_ok=True)
-         self.backup_model = None
- 
-         # Add tags for models that have been loaded with the correct transformers version
-         if hasattr(self.model, "add_model_tags"):
-             self.model.add_model_tags(self._tag_names)
- 
-         #########
-         ### setup dataloader
-         #########
-         self.dataloader = DataLoader(
-             self.train_dataset,
-             batch_size=self.local_dataloader_batch_size,
-             shuffle=True,
-             collate_fn=self.data_collator,
-             drop_last=True,  # needed; otherwise the last batch will be of ragged shape
-         )
-         # sync random states for DataLoader(shuffle=True) before `accelerator.prepare`
-         # see https://gist.github.com/vwxyzjn/2581bff1e48e185e0b85b6dfe1def79c
-         torch.manual_seed(args.seed)
-         self.model, self.optimizer, self.dataloader = accelerator.prepare(self.model, self.optimizer, self.dataloader)
-         torch.manual_seed(self.local_seed)  # reset the local seed again
- 
-         self.eval_dataloader = DataLoader(
-             self.eval_dataset,
-             batch_size=args.per_device_eval_batch_size,
-             collate_fn=self.data_collator,
-             drop_last=True,
-         )  # no need to shuffle eval dataset
-         self.eval_dataloader = accelerator.prepare(self.eval_dataloader)
- 
-         if self.is_deepspeed_enabled:
-             if isinstance(self.reward_model, nn.Module):
-                 self.reward_model = prepare_deepspeed(
-                     self.reward_model, args.per_device_train_batch_size, args.fp16, args.bf16
-                 )
-             self.ref_policy = prepare_deepspeed(
-                 self.ref_policy, args.per_device_train_batch_size, args.fp16, args.bf16
-             )
-             self.deepspeed = self.model
-         else:
-             self.ref_policy = self.ref_policy.to(self.accelerator.device)
-             if isinstance(self.reward_model, nn.Module):
-                 self.reward_model = self.reward_model.to(self.accelerator.device)
- 
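    # Annotation (not part of the original file): a concrete reading of the
    # batch-size arithmetic in __init__, with hypothetical single-process
    # numbers. With per_device_train_batch_size=4, gradient_accumulation_steps=2,
    # num_mini_batches=1 and rloo_k=2:
    #   local_batch_size = 4 * 2 * 1 = 8 completions per update,
    #   micro_batch_size = 4, batch_size = mini_batch_size = 8 (world_size=1),
    #   local_dataloader_batch_size = 8 // rloo_k = 4 unique prompts,
    # and each prompt is repeated rloo_k times during the rollout.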
-     def get_train_dataloader(self) -> DataLoader:
-         return self.dataloader
- 
-     def get_eval_dataloader(self) -> DataLoader:
-         return self.eval_dataloader
- 
-     def train(self):
-         args = self.args
-         accelerator = self.accelerator
-         optimizer = self.optimizer
-         model = self.model
-         self.model_wrapped = self.model
-         ref_policy = self.ref_policy
-         reward_model = self.reward_model
-         processing_class = self.processing_class
-         dataloader = self.dataloader
-         device = accelerator.device
- 
-         def repeat_generator():
-             while True:
-                 yield from dataloader
- 
-         iter_dataloader = iter(repeat_generator())
-         generation_config = GenerationConfig(
-             max_new_tokens=args.response_length,
-             temperature=(args.temperature + 1e-7),
-             top_k=0.0,
-             top_p=1.0,
-             do_sample=True,
-         )
- 
-         accelerator.print("===training policy===")
-         start_time = time.time()
-         stats_shape = (args.num_ppo_epochs, args.num_mini_batches, args.gradient_accumulation_steps)
-         approxkl_stats = torch.zeros(stats_shape, device=device)
-         pg_clipfrac_stats = torch.zeros(stats_shape, device=device)
-         pg_loss_stats = torch.zeros(stats_shape, device=device)
-         vf_clipfrac_stats = torch.zeros(stats_shape, device=device)
-         entropy_stats = torch.zeros(stats_shape, device=device)
-         ratio_stats = torch.zeros(stats_shape, device=device)
-         model.train()
- 
-         # trainer state initialization
-         self.state.global_step = 0
-         self.state.episode = 0
-         self.state.max_steps = (args.num_total_batches * args.num_mini_batches) // 2
-         self.state.num_train_epochs = args.total_episodes / self.train_dataset_len
-         # Compute absolute values for logging, eval, and save if given as ratio
-         if args.logging_steps is not None:
-             if args.logging_steps < 1:
-                 self.state.logging_steps = math.ceil(self.state.max_steps * args.logging_steps)
-             else:
-                 self.state.logging_steps = args.logging_steps
-         if args.eval_steps is not None:
-             if args.eval_steps < 1:
-                 self.state.eval_steps = math.ceil(self.state.max_steps * args.eval_steps)
-             else:
-                 self.state.eval_steps = args.eval_steps
-         if args.save_steps is not None:
-             if args.save_steps < 1:
-                 self.state.save_steps = math.ceil(self.state.max_steps * args.save_steps)
-             else:
-                 self.state.save_steps = args.save_steps
-         self.control = self.callback_handler.on_train_begin(args, self.state, self.control)
- 
-         for update in range(1, args.num_total_batches + 1):
-             self.state.episode += 1 * args.batch_size
-             data = next(iter_dataloader)
-             with torch.no_grad():
-                 queries = data["input_ids"].to(device)
-                 queries = queries.repeat(args.rloo_k, 1)
-                 context_length = queries.shape[1]
-                 responses = []
-                 postprocessed_responses = []
-                 logprobs = []
-                 ref_logprobs = []
-                 scores = []
-                 sequence_lengths = []
- 
-                 # Generate responses and compute logprobs
-                 with unwrap_model_for_generation(
-                     self.model, self.accelerator, gather_deepspeed3_params=self.args.ds3_gather_for_generation
-                 ) as unwrapped_model:
-                     query_responses, logitss = batch_generation(
-                         unwrapped_model,
-                         queries,
-                         args.local_rollout_forward_batch_size,
-                         processing_class.pad_token_id,
-                         generation_config,
-                     )
- 
-                 # Process responses in batches
-                 for i in range(0, queries.shape[0], args.local_rollout_forward_batch_size):
-                     query = queries[i : i + args.local_rollout_forward_batch_size]
-                     query_response = query_responses[i : i + args.local_rollout_forward_batch_size]
-                     response = query_response[:, context_length:]
-                     logits = logitss[i : i + args.local_rollout_forward_batch_size]
-                     logprob = selective_log_softmax(logits, response)
-                     del logits
-                     torch.cuda.empty_cache()
- 
-                     ref_output = forward(ref_policy, query_response, processing_class.pad_token_id)
-                     ref_logits = ref_output.logits[:, context_length - 1 : -1]
-                     ref_logits /= args.temperature + 1e-7
-                     ref_logprob = selective_log_softmax(ref_logits, response)
-                     del ref_output, ref_logits
-                     torch.cuda.empty_cache()
- 
-                     # Response Processing 1. truncate response after the first occurrence of `stop_token_id`
-                     postprocessed_response = response
-                     if args.stop_token_id is not None:  # handle the edge case when stop_token_id exists but is 0
-                         postprocessed_response = truncate_response(
-                             args.stop_token_id, processing_class.pad_token_id, response
-                         )
- 
-                     # Response Processing 2. run reward model on the truncated responses
-                     postprocessed_query_response = torch.cat((query, postprocessed_response), 1)
-                     sequence_length = first_true_indices(postprocessed_response == processing_class.pad_token_id) - 1
- 
-                     if isinstance(reward_model, nn.Module):
-                         _, score, _ = get_reward(
-                             reward_model, postprocessed_query_response, processing_class.pad_token_id, context_length
-                         )
-                     else:
-                         score = torch.tensor(
-                             reward_model(
-                                 processing_class.batch_decode(postprocessed_query_response, skip_special_tokens=True)
-                             ),
-                             dtype=torch.float,
-                         ).to(device)
- 
-                     # Store batch results
-                     responses.append(response)
-                     postprocessed_responses.append(postprocessed_response)
-                     logprobs.append(logprob)
-                     ref_logprobs.append(ref_logprob)
-                     sequence_lengths.append(sequence_length)
-                     scores.append(score)
- 
-                 # Concatenate all batched results
-                 responses = torch.cat(responses, 0)
-                 postprocessed_responses = torch.cat(postprocessed_responses, 0)
-                 logprobs = torch.cat(logprobs, 0)
-                 ref_logprobs = torch.cat(ref_logprobs, 0)
-                 sequence_lengths = torch.cat(sequence_lengths, 0)
-                 scores = torch.cat(scores, 0)
-                 del (logprob, ref_logprob, score)
-                 torch.cuda.empty_cache()
-                 gc.collect()
- 
-                 # Response Processing 3. filter response. Ensure that the sample contains stop_token_id
-                 # responses not passing that filter will receive a low (fixed) score
-                 # only query humans on responses that pass that filter
-                 contain_eos_token = torch.any(postprocessed_responses == processing_class.eos_token_id, dim=-1)
-                 if args.missing_eos_penalty is not None:
-                     scores[~contain_eos_token] -= self.args.missing_eos_penalty
-                 # accelerator.print(f"{scores=}, {(contain_eos_token.sum() / len(contain_eos_token))=}")
- 
-                 # be very careful with `padding_mask_p1`; see https://excalidraw.com/#json=LWnzG4w2k5DjF_EOL_xPt,e2w3a-hFJ_gX5vOfeyXGTw
-                 response_idxs = torch.arange(responses.shape[1], device=responses.device).repeat(responses.shape[0], 1)
-                 padding_mask = response_idxs > sequence_lengths.unsqueeze(1)
-                 logprobs = torch.masked_fill(logprobs, padding_mask, INVALID_LOGPROB)
-                 ref_logprobs = torch.masked_fill(ref_logprobs, padding_mask, INVALID_LOGPROB)
- 
-                 # 4. compute rewards
-                 # Compute KL divergence
-                 kl = logprobs - ref_logprobs
- 
-                 # Normalize rewards
-                 if args.normalize_reward:
-                     scores = (scores - scores.mean()) / (scores.std() + 1e-8)
-                     scores = torch.clamp(scores, -args.reward_clip_range, args.reward_clip_range)
- 
-                 # Compute total reward with KL penalty
-                 if args.token_level_kl:
-                     # Token-level KL penalty: apply KL penalty per token
-                     kl_reward = -args.kl_coef * kl
- 
-                     # Get the index of the last non-padded token for each sequence
-                     eos_indices = padding_mask.size(1) - 1 - padding_mask.long().fliplr().argmax(dim=1, keepdim=True)
-                     last_reward = torch.zeros_like(kl)
-                     # Ensure scores has correct shape and type
-                     scores_shaped = scores.reshape(-1, 1).to(kl.dtype)
-                     last_reward.scatter_(dim=1, index=eos_indices, src=scores_shaped)
- 
-                     # Combine KL reward and last reward
-                     non_score_reward = kl_reward.sum(1)  # Keep this for logging
-                     reward = last_reward + kl_reward
-                     rlhf_reward = reward.sum(1)  # Sum across sequence length
-                 else:
-                     # Sequence-level KL penalty: sum KL across tokens first
-                     sequence_kl = kl.sum(1)
-                     non_score_reward = -args.kl_coef * sequence_kl
-                     rlhf_reward = non_score_reward + scores
- 
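                # Annotation (not part of the original file): token-level KL
                # penalises every generated token by -kl_coef * (logp - ref_logp)
                # and injects the scalar reward-model score at the last
                # non-padded token; the sequence-level branch sums the KL first.
                # Numerically both branches reduce to the same per-sequence
                # total here, rlhf_reward = score - kl_coef * sum_t KL_t, since
                # only the per-sequence sum is used below.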
-                 # vectorized RLOO advantages implementation
-                 rlhf_reward = rlhf_reward.reshape(args.rloo_k, -1)
-                 baseline = (rlhf_reward.sum(0) - rlhf_reward) / (args.rloo_k - 1)
-                 advantages = rlhf_reward - baseline
-                 advantages = advantages.flatten()
- 
-                 # Normalize advantages
-                 if args.normalize_advantage:
-                     advantages = (advantages - advantages.mean()) / (advantages.std() + 1e-8)
- 
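                # Annotation (not part of the original file): a worked example
                # of the leave-one-out baseline with rloo_k = 2 and two prompts.
                # After the reshape, rows are samples and columns are prompts:
                #   rlhf_reward = [[1.0, 3.0],
                #                  [2.0, 5.0]]
                #   baseline    = (column_sum - reward) / (k - 1)
                #               = [[2.0, 5.0],
                #                  [1.0, 3.0]]
                #   advantages  = reward - baseline = [[-1.0, -2.0],
                #                                      [ 1.0,  2.0]]
                # i.e. each completion is scored against the mean reward of the
                # other k - 1 completions for the same prompt, which keeps the
                # estimator unbiased without a learned value network.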
-                 torch.cuda.empty_cache()
- 
-             # Do multiple epochs of PPO training, with a fresh random shuffle in each epoch
-             for ppo_epoch_idx in range(args.num_ppo_epochs):
-                 b_inds = np.random.permutation(args.local_batch_size)
-                 minibatch_idx = 0
-                 for mini_batch_start in range(0, args.local_batch_size, args.local_mini_batch_size):
-                     mini_batch_end = mini_batch_start + args.local_mini_batch_size
-                     mini_batch_inds = b_inds[mini_batch_start:mini_batch_end]
-                     gradient_accumulation_idx = 0
-                     for micro_batch_start in range(0, args.local_mini_batch_size, args.per_device_train_batch_size):
-                         with accelerator.accumulate(model):
-                             micro_batch_end = micro_batch_start + args.per_device_train_batch_size
-                             micro_batch_inds = mini_batch_inds[micro_batch_start:micro_batch_end]
- 
-                             # Get batch data
-                             mb_advantage = advantages[micro_batch_inds]
-                             mb_responses = responses[micro_batch_inds]
-                             mb_query_responses = query_responses[micro_batch_inds]
-                             mb_logprobs = logprobs[micro_batch_inds]
- 
-                             # Forward pass
-                             output = forward(model, mb_query_responses, processing_class.pad_token_id)
-                             logits = output.logits[:, context_length - 1 : -1]
-                             logits /= args.temperature + 1e-7
- 
-                             # Compute new logprobs
-                             new_logprobs = selective_log_softmax(logits, mb_responses)
-                             new_logprobs = torch.masked_fill(
-                                 new_logprobs, padding_mask[micro_batch_inds], INVALID_LOGPROB
-                             )
- 
-                             # Compute probability ratios
-                             new_ratio = (new_logprobs - mb_logprobs).exp()
-                             new_logprobs = new_logprobs.sum(1)
-                             mb_logprobs = mb_logprobs.sum(1)
-                             logprobs_diff = new_logprobs - mb_logprobs
-                             ratio = torch.exp(logprobs_diff)
- 
-                             # PPO clipped loss
-                             pg_losses = -mb_advantage * ratio
-                             pg_losses2 = -mb_advantage * torch.clamp(ratio, 1.0 - args.cliprange, 1.0 + args.cliprange)
-                             pg_loss_max = torch.max(pg_losses, pg_losses2)
-                             pg_loss = pg_loss_max.mean()
- 
-                             # Final loss
-                             loss = pg_loss
- 
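                            # Annotation (not part of the original file): this
                            # is the standard PPO clipped surrogate,
                            #   loss = mean(max(-A*r, -A*clip(r, 1-eps, 1+eps))),
                            # with r = exp(new_logprob - old_logprob) taken at
                            # the sequence level and eps = args.cliprange. The
                            # `approxkl` logged below is the 0.5*mean(dlogp^2)
                            # estimator of the size of the policy update.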
-                             # Optimization step
-                             accelerator.backward(loss)
-                             optimizer.step()
-                             optimizer.zero_grad()
- 
-                             with torch.no_grad():
-                                 pg_clipfrac = (pg_losses2 > pg_losses).float().mean()
-                                 prob_dist = torch.nn.functional.softmax(logits, dim=-1)
-                                 entropy = torch.logsumexp(logits, dim=-1) - torch.sum(prob_dist * logits, dim=-1)
-                                 approxkl = 0.5 * (logprobs_diff**2).mean()
-                                 approxkl_stats[ppo_epoch_idx, minibatch_idx, gradient_accumulation_idx] = approxkl
-                                 pg_clipfrac_stats[ppo_epoch_idx, minibatch_idx, gradient_accumulation_idx] = (
-                                     pg_clipfrac
-                                 )
-                                 pg_loss_stats[ppo_epoch_idx, minibatch_idx, gradient_accumulation_idx] = pg_loss
-                                 entropy_stats[ppo_epoch_idx, minibatch_idx, gradient_accumulation_idx] = entropy.mean()
-                                 ratio_stats[ppo_epoch_idx, minibatch_idx, gradient_accumulation_idx] = new_ratio.mean()
-                         gradient_accumulation_idx += 1
-                     minibatch_idx += 1
- 
-                 # del everything and empty cache
-                 # fmt: off
-                 del (
-                     output, logits, new_logprobs, logprobs_diff, ratio, pg_losses,
-                     pg_losses2, pg_loss, loss, pg_clipfrac, prob_dist, entropy, approxkl,
-                     mb_advantage, mb_responses, mb_query_responses, mb_logprobs,
-                 )
-                 # fmt: on
-                 torch.cuda.empty_cache()
- 
-             # Compute metrics
-             with torch.no_grad():
-                 mean_kl = kl.sum(1).mean()
-                 mean_entropy = (-logprobs).sum(1).mean()
-                 mean_non_score_reward = non_score_reward.mean()
-                 eps = int(self.state.episode / (time.time() - start_time))
-                 metrics = {}
-                 metrics["eps"] = eps
-                 metrics["objective/kl"] = self.accelerator.gather_for_metrics(mean_kl).mean().item()
-                 metrics["objective/entropy"] = self.accelerator.gather_for_metrics(mean_entropy).mean().item()
-                 metrics["objective/non_score_reward"] = (
-                     self.accelerator.gather_for_metrics(mean_non_score_reward).mean().item()
-                 )
-                 metrics["objective/rlhf_reward"] = self.accelerator.gather_for_metrics(rlhf_reward).mean().item()
-                 metrics["objective/scores"] = self.accelerator.gather_for_metrics(scores.mean()).mean().item()
-                 metrics["policy/approxkl_avg"] = self.accelerator.gather_for_metrics(approxkl_stats).mean().item()
-                 metrics["policy/clipfrac_avg"] = self.accelerator.gather_for_metrics(pg_clipfrac_stats).mean().item()
-                 metrics["loss/policy_avg"] = self.accelerator.gather_for_metrics(pg_loss_stats).mean().item()
-                 metrics["val/clipfrac_avg"] = self.accelerator.gather_for_metrics(vf_clipfrac_stats).mean().item()
-                 metrics["policy/entropy_avg"] = self.accelerator.gather_for_metrics(entropy_stats).mean().item()
-                 metrics["val/ratio"] = self.accelerator.gather_for_metrics(ratio_stats).mean().item()
-                 metrics["val/ratio_var"] = self.accelerator.gather_for_metrics(ratio_stats).var().item()
-                 metrics["val/num_eos_tokens"] = (responses == processing_class.eos_token_id).sum().item()
-                 metrics["lr"] = self.lr_scheduler.get_last_lr()[0]
-                 metrics["episode"] = self.state.episode
-                 self.state.epoch = self.state.episode / (args.rloo_k * self.train_dataset_len)  # used by self.log
-                 self.log(metrics)
-             del kl, mean_kl, mean_entropy, scores
- 
-             self.lr_scheduler.step()
-             self.state.global_step += 1
-             self.control = self.callback_handler.on_step_end(args, self.state, self.control)
-             if self.control.should_save:
-                 self._save_checkpoint(model, trial=None)
-                 self.control = self.callback_handler.on_save(self.args, self.state, self.control)
-             torch.cuda.empty_cache()
-             gc.collect()
- 
-             if args.num_sample_generations > 0 and (update - 1) % self.sample_generations_freq == 0:
-                 self.generate_completions(sampling=True)
- 
-         # HF trainer specifics
-         self.control = self.callback_handler.on_train_end(args, self.state, self.control)
-         if self.control.should_save:
-             self._save_checkpoint(model, trial=None, metrics=None)
-             self.control = self.callback_handler.on_save(self.args, self.state, self.control)
- 
-     def generate_completions(self, sampling: bool = False):
-         args = self.args
-         processing_class = self.processing_class
-         generation_config = GenerationConfig(
-             max_new_tokens=self.args.response_length,
-             temperature=(0.01 + 1e-7),
-             top_k=0.0,
-             top_p=1.0,
-             do_sample=True,
-         )
- 
-         table = defaultdict(list)
-         with unwrap_model_for_generation(
-             self.model, self.accelerator, gather_deepspeed3_params=self.args.ds3_gather_for_generation
-         ) as unwrapped_model:
-             for batch in self.eval_dataloader:
-                 query = batch["input_ids"]
-                 with torch.no_grad():
-                     context_length = query.shape[1]
-                     query_response, _ = batch_generation(
-                         unwrapped_model,
-                         query,
-                         query.shape[0],
-                         processing_class.pad_token_id,
-                         generation_config,
-                     )
-                     response = query_response[:, context_length:]
-                     postprocessed_response = response
-                     if args.stop_token_id is not None:  # handle the edge case when stop_token_id exists but is 0
-                         postprocessed_response = truncate_response(
-                             args.stop_token_id, processing_class.pad_token_id, response
-                         )
-                     table["query"].extend(
-                         gather_object(processing_class.batch_decode(query, skip_special_tokens=True))
-                     )
-                     table["model response"].extend(
-                         gather_object(processing_class.batch_decode(postprocessed_response))
-                     )
- 
-                     postprocessed_query_response = torch.cat((query, postprocessed_response), 1)
- 
-                     if isinstance(self.reward_model, nn.Module):
-                         _, score, _ = get_reward(
-                             self.reward_model,
-                             postprocessed_query_response,
-                             processing_class.pad_token_id,
-                             context_length,
-                         )
-                     else:
-                         score = torch.tensor(
-                             self.reward_model(
-                                 processing_class.batch_decode(postprocessed_query_response, skip_special_tokens=True)
-                             ),
-                             dtype=torch.float,
-                         ).to(postprocessed_query_response.device)
-                     table["score"].extend(self.accelerator.gather_for_metrics(score).float().cpu().numpy())
- 
-                 if sampling:
-                     break
-         df = pd.DataFrame(table)
- 
-         if self.accelerator.is_main_process:
-             print_rich_table(df.iloc[0 : 0 + 5])
-             if "wandb" in args.report_to:
-                 import wandb
- 
-                 if wandb.run is not None:
-                     wandb.log({"completions": wandb.Table(dataframe=df)})
- 
-             if "comet_ml" in args.report_to:
-                 log_table_to_comet_experiment(
-                     name="completions.csv",
-                     table=df,
-                 )
- 
-     def create_model_card(
-         self,
-         model_name: Optional[str] = None,
-         dataset_name: Optional[str] = None,
-         tags: Union[str, list[str], None] = None,
-     ):
-         """
-         Creates a draft of a model card using the information available to the `Trainer`.
- 
-         Args:
-             model_name (`str` or `None`, *optional*, defaults to `None`):
-                 Name of the model.
-             dataset_name (`str` or `None`, *optional*, defaults to `None`):
-                 Name of the dataset used for training.
-             tags (`str`, `list[str]` or `None`, *optional*, defaults to `None`):
-                 Tags to be associated with the model card.
-         """
-         if not self.is_world_process_zero():
-             return
- 
-         if hasattr(self.model.config, "_name_or_path") and not os.path.isdir(self.model.config._name_or_path):
-             base_model = self.model.config._name_or_path
-         else:
-             base_model = None
- 
-         tags = tags or []
-         if isinstance(tags, str):
-             tags = [tags]
- 
-         if hasattr(self.model.config, "unsloth_version"):
-             tags.append("unsloth")
- 
-         citation = textwrap.dedent("""\
-         @inproceedings{ahmadian2024back,
-             title = {{Back to Basics: Revisiting REINFORCE-Style Optimization for Learning from Human Feedback in LLMs}},
-             author = {Arash Ahmadian and Chris Cremer and Matthias Gall{\'{e}} and Marzieh Fadaee and Julia Kreutzer and Olivier Pietquin and Ahmet {\"{U}}st{\"{u}}n and Sara Hooker},
-             year = 2024,
-             booktitle = {Proceedings of the 62nd Annual Meeting of the Association for Computational Linguistics (Volume 1: Long Papers), {ACL} 2024, Bangkok, Thailand, August 11-16, 2024},
-             publisher = {Association for Computational Linguistics},
-             pages = {12248--12267},
-             editor = {Lun{-}Wei Ku and Andre Martins and Vivek Srikumar},
-         }""")
- 
-         model_card = generate_model_card(
-             base_model=base_model,
-             model_name=model_name,
-             hub_model_id=self.hub_model_id,
-             dataset_name=dataset_name,
-             tags=tags,
-             wandb_url=wandb.run.get_url() if is_wandb_available() and wandb.run is not None else None,
-             comet_url=get_comet_experiment_url(),
-             trainer_name="RLOO",
-             trainer_citation=citation,
-             paper_title="Back to Basics: Revisiting REINFORCE-Style Optimization for Learning from Human Feedback in LLMs",
-             paper_id="2402.14740",
-         )
- 
-         model_card.save(os.path.join(self.args.output_dir, "README.md"))
- class UnslothRLOOTrainer(_UnslothRLOOTrainer):
-     """
- 
-     """
-     def __init__(
-         self,
-         config,
-         processing_class,
-         policy,
-         ref_policy,
-         reward_model,
-         train_dataset,
-         data_collator = None,
-         eval_dataset = None,
-         callbacks = None,
-         **kwargs
-     ):
-         if config is None: config = UnslothRLOOConfig()
-         args = config   # the generated checks below expect an `args` alias
-         model = policy  # and a `model` alias for the policy being trained
-         _output_logits = False
-         if locals().get('compute_metrics', None) is not None: _output_logits = True
-         if locals().get('preprocess_logits_for_metrics', None) is not None: _output_logits = True
-         if _output_logits:
-             os.environ['UNSLOTH_RETURN_LOGITS'] = '1'
-         if 'max_seq_length' not in locals() and not hasattr(args, 'max_seq_length'):
-             pass
-         else:
-             model_max_seq_length = getattr(model, 'max_seq_length', None)
-             args_max_seq_length = getattr(args, 'max_seq_length', None)
-             if args_max_seq_length is None and model_max_seq_length is not None:
-                 max_seq_length = model.max_seq_length
-                 if hasattr(args, 'max_seq_length'): args.max_seq_length = max_seq_length
-         if model is not None and hasattr(model, 'for_training'):
-             model.for_training()
-         if 'tokenizer' in locals() and hasattr(tokenizer, 'padding_side'): tokenizer.padding_side = 'right'
-         if 'processing_class' in locals():
-             if hasattr(processing_class, 'padding_side'): processing_class.padding_side = 'right'
-             if hasattr(processing_class, 'tokenizer') and hasattr(processing_class.tokenizer, 'padding_side'): processing_class.tokenizer.padding_side = 'right'
-         __tokenizer = processing_class if 'processing_class' in locals() else tokenizer
-         from unsloth_zoo.vision_utils import UnslothVisionDataCollator
-         if not isinstance(data_collator, UnslothVisionDataCollator):
-             if isinstance(data_collator, DataCollatorForSeq2Seq) and 'labels' not in train_dataset.column_names:
-                 data_collator = TransformersDataCollatorForLanguageModeling(__tokenizer, mlm = False, mlm_probability = 0.0)
-             elif isinstance(data_collator, TransformersDataCollatorForLanguageModeling) and 'labels' in train_dataset.column_names:
-                 data_collator = DataCollatorForSeq2Seq(__tokenizer)
-         else:
-             if hasattr(args, 'remove_unused_columns'): args.remove_unused_columns = False
-             if hasattr(args, 'dataset_text_field'): args.dataset_text_field = ''
-             if hasattr(args, 'dataset_kwargs'): args.dataset_kwargs = {'skip_prepare_dataset': True}
-         if not isinstance(data_collator, UnslothVisionDataCollator):
-             if not hasattr(__tokenizer, 'pad') and hasattr(__tokenizer, 'tokenizer'):
-                 if isinstance(data_collator, DataCollatorForSeq2Seq):
-                     data_collator = DataCollatorForSeq2Seq(__tokenizer.tokenizer)
-                 else:
-                     data_collator = TransformersDataCollatorForLanguageModeling(__tokenizer.tokenizer, mlm = False, mlm_probability = 0.0)
-         other_metrics = []
- 
-         from unsloth_zoo.logging_utils import PatchRLStatistics
-         PatchRLStatistics('rloo_trainer', other_metrics)
- 
-         super().__init__(
-             config = config,
-             processing_class = processing_class,
-             policy = policy,
-             ref_policy = ref_policy,
-             reward_model = reward_model,
-             train_dataset = train_dataset,
-             data_collator = data_collator,
-             eval_dataset = eval_dataset,
-             callbacks = callbacks,
-             **kwargs)
-         if hasattr(self, 'neftune_hook_handle'):
-             self.neftune_hook_handle.remove()
-         if hasattr(self, 'neftune_hook_handle'): del self.neftune_hook_handle
-         if getattr(args, 'neftune_noise_alpha', None) is not None:
-             model.get_input_embeddings().neftune_noise_alpha = self.neftune_noise_alpha
-         pass
- 
- pass
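A minimal end-to-end sketch of the wrapper (illustrative only: the checkpoint names reuse the config defaults, and `train_dataset`/`eval_dataset` are assumed to be pre-tokenized prompt datasets with an `input_ids` column, which is what the trainer's dataloaders expect):

    from transformers import AutoModelForCausalLM, AutoModelForSequenceClassification, AutoTokenizer

    tokenizer = AutoTokenizer.from_pretrained("EleutherAI/pythia-160m", padding_side = "left")
    tokenizer.pad_token = tokenizer.eos_token   # pythia ships without a pad token
    policy       = AutoModelForCausalLM.from_pretrained("EleutherAI/pythia-160m")
    ref_policy   = AutoModelForCausalLM.from_pretrained("EleutherAI/pythia-160m")
    reward_model = AutoModelForSequenceClassification.from_pretrained("EleutherAI/pythia-160m", num_labels = 1)

    trainer = UnslothRLOOTrainer(
        config = UnslothRLOOConfig(output_dir = "rloo_test", total_episodes = 256),
        processing_class = tokenizer,
        policy = policy,
        ref_policy = ref_policy,
        reward_model = reward_model,
        train_dataset = train_dataset,   # assumed: pre-tokenized prompts
        eval_dataset = eval_dataset,     # assumed: pre-tokenized prompts
    )
    trainer.train()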
test_run_uploads/UnslothRewardTrainer.py DELETED
@@ -1,828 +0,0 @@
- """
- 2025.7.11
- 2025.7.11
- 4.54.1
- 0.16.1
- __UNSLOTH_VERSIONING__
- """
- from torch import Tensor
- import torch
- import torch.nn as nn
- from torch.nn import functional as F
- from typing import Any, List, Optional, Tuple, Union, Dict, Set, Callable
- from trl.trainer.reward_trainer import (Any, BaseImageProcessor, Callable, DataCollator, Dataset, EvalPrediction, FeatureExtractionMixin, FrozenInstanceError, Optional, PartialState, PeftModel, PreTrainedModel, PreTrainedTokenizerBase, ProcessorMixin, RewardConfig, RewardDataCollatorWithPadding, RewardTrainer, Trainer, TrainerCallback, Union, _tokenize, compute_accuracy, decode_and_strip_padding, defaultdict, disable_dropout_in_model, gather_object, generate_model_card, get_comet_experiment_url, inspect, is_peft_available, is_wandb_available, log_table_to_comet_experiment, maybe_apply_chat_template, nested_detach, nn, os, pd, prepare_model_for_kbit_training, print_rich_table, replace, torch, wandb, warnings)
-
-
- import os
- from typing import *
- from dataclasses import dataclass, field
- from packaging.version import Version
- import torch
- import numpy as np
- from contextlib import nullcontext
- from torch.nn import functional as F
- from transformers import DataCollatorForSeq2Seq, DataCollatorForLanguageModeling as TransformersDataCollatorForLanguageModeling
-
- torch_compile_options = {
-     "epilogue_fusion" : True,
-     "max_autotune" : False,
-     "shape_padding" : True,
-     "trace.enabled" : False,
-     "triton.cudagraphs" : False,
- }
-
- @torch.compile(dynamic = True, fullgraph = True, options = torch_compile_options,)
- def chunked_selective_log_softmax(logits, index):
-     # Split into 4 chunks only
-     chunked_logits = torch.chunk(logits.reshape(-1, logits.shape[-1]), chunks = 4, dim = 0)
-     chunked_index = torch.chunk(index.reshape(-1), chunks = 4, dim = 0)
-     all_per_token_logps = []
-     # Below loop does the same as selective_log_softmax(chunk_logits, chunk_index)
-     for chunk_logits, chunk_index in zip(chunked_logits, chunked_index):
-         chunk_logits = chunk_logits.to(torch.float32)
-         selected_logits = torch.gather(chunk_logits, dim = -1, index = chunk_index.unsqueeze(-1)).squeeze(-1)
-         logsumexp_values = torch.logsumexp(chunk_logits, dim = -1)
-         per_token_logps = selected_logits - logsumexp_values
-         all_per_token_logps.append(per_token_logps)
-     pass
-     all_per_token_logps = torch.concat(all_per_token_logps)
-     all_per_token_logps = all_per_token_logps.reshape((logits.shape[0], logits.shape[1]))
-     return all_per_token_logps
- @dataclass
- class UnslothRewardConfig(RewardConfig):
-     """
-
-     Configuration class for the [`RewardTrainer`].
-
-     Using [`~transformers.HfArgumentParser`] we can turn this class into
-     [argparse](https://docs.python.org/3/library/argparse#module-argparse) arguments that can be specified on the
-     command line.
-
-     Parameters:
-         max_length (`int` or `None`, *optional*, defaults to `1024`):
-             Maximum length of the sequences (prompt + completion) in the batch, filters out entries that exceed the
-             limit. This argument is required if you want to use the default data collator.
-         disable_dropout (`bool`, *optional*, defaults to `True`):
-             Whether to disable dropout in the model.
-         dataset_num_proc (`int`, *optional*, defaults to `None`):
-             Number of processes to use for processing the dataset.
-         center_rewards_coefficient (`float`, *optional*, defaults to `None`):
-             Coefficient to incentivize the reward model to output mean-zero rewards (proposed by
-             https://huggingface.co/papers/2312.09244, Eq. 2). Recommended value: `0.01`.
-         remove_unused_columns (`bool`, *optional*, defaults to `False`):
-             Whether to remove the columns that are not used by the model's forward pass. Can be `True` only if
-             the dataset is pretokenized.
-
-     """
-     vllm_sampling_params: Optional[Any] = field(
-         default = None,
-         metadata = {'help': 'vLLM SamplingParams'},
-     )
-     unsloth_num_chunks : Optional[int] = field(
-         default = -1,
-         metadata = {'help': 'Chunk size to reduce memory usage. -1 is most efficient.'},
-     )
-     def __init__(
-         self,
-         output_dir = None,
-         overwrite_output_dir = None,
-         do_train = False,
-         do_eval = False,
-         do_predict = False,
-         eval_strategy = 'no',
-         prediction_loss_only = False,
-         per_device_train_batch_size = 4,
-         per_device_eval_batch_size = 4,
-         per_gpu_train_batch_size = None,
-         per_gpu_eval_batch_size = None,
-         gradient_accumulation_steps = 2,
-         eval_accumulation_steps = 2,
-         eval_delay = 0,
-         torch_empty_cache_steps = 250,
-         learning_rate = 5e-05,
-         weight_decay = 0.01,
-         adam_beta1 = 0.9,
-         adam_beta2 = 0.999,
-         adam_epsilon = 1e-08,
-         max_grad_norm = 1.0,
-         num_train_epochs = 3.0,
-         max_steps = -1,
-         lr_scheduler_type = 'linear',
-         warmup_ratio = 0.1,
-         warmup_steps = 0,
-         log_level = 'passive',
-         log_level_replica = 'warning',
-         log_on_each_node = True,
-         logging_dir = None,
-         logging_strategy = 'steps',
-         logging_first_step = False,
-         logging_steps = 1,
-         logging_nan_inf_filter = False,
-         save_strategy = 'steps',
-         save_steps = 500,
-         save_total_limit = None,
-         save_safetensors = True,
-         save_on_each_node = False,
-         save_only_model = False,
-         restore_callback_states_from_checkpoint = False,
-         no_cuda = False,
-         use_cpu = False,
-         use_mps_device = False,
-         seed = 3407,
-         data_seed = 3407,
-         jit_mode_eval = False,
-         use_ipex = False,
-         bf16 = False,
-         fp16 = False,
-         fp16_opt_level = 'O1',
-         half_precision_backend = 'auto',
-         bf16_full_eval = False,
-         fp16_full_eval = False,
-         tf32 = None,
-         local_rank = -1,
-         ddp_backend = None,
-         tpu_num_cores = None,
-         tpu_metrics_debug = False,
-         debug = '',
-         dataloader_drop_last = False,
-         eval_steps = None,
-         dataloader_num_workers = 0,
-         dataloader_prefetch_factor = None,
-         past_index = -1,
-         run_name = None,
-         disable_tqdm = None,
-         remove_unused_columns = False,
-         label_names = None,
-         load_best_model_at_end = False,
-         metric_for_best_model = None,
-         greater_is_better = None,
-         ignore_data_skip = False,
-         fsdp = '',
-         fsdp_min_num_params = 0,
-         fsdp_config = None,
-         fsdp_transformer_layer_cls_to_wrap = None,
-         accelerator_config = None,
-         deepspeed = None,
-         label_smoothing_factor = 0.0,
-         optim = 'adamw_8bit',
-         optim_args = None,
-         adafactor = False,
-         group_by_length = False,
-         length_column_name = 'length',
-         report_to = None,
-         ddp_find_unused_parameters = None,
-         ddp_bucket_cap_mb = None,
-         ddp_broadcast_buffers = None,
-         dataloader_pin_memory = True,
-         dataloader_persistent_workers = False,
-         skip_memory_metrics = True,
-         use_legacy_prediction_loop = False,
-         push_to_hub = False,
-         resume_from_checkpoint = None,
-         hub_model_id = None,
-         hub_strategy = 'every_save',
-         hub_token = None,
-         hub_private_repo = None,
-         hub_always_push = False,
-         hub_revision = None,
-         gradient_checkpointing = False,
-         gradient_checkpointing_kwargs = None,
-         include_inputs_for_metrics = False,
-         eval_do_concat_batches = True,
-         fp16_backend = 'auto',
-         push_to_hub_model_id = None,
-         push_to_hub_organization = None,
-         push_to_hub_token = None,
-         mp_parameters = '',
-         auto_find_batch_size = True,
-         full_determinism = False,
-         torchdynamo = None,
-         ray_scope = 'last',
-         ddp_timeout = 1800,
-         torch_compile = False,
-         torch_compile_backend = None,
-         torch_compile_mode = None,
-         include_tokens_per_second = False,
-         include_num_input_tokens_seen = False,
-         neftune_noise_alpha = None,
-         optim_target_modules = None,
-         batch_eval_metrics = False,
-         eval_on_start = False,
-         use_liger_kernel = False,
-         liger_kernel_config = None,
-         eval_use_gather_object = False,
-         average_tokens_across_devices = True,
-         max_length = 1024,
-         disable_dropout = True,
-         dataset_num_proc = None,
-         center_rewards_coefficient = None,
-         vllm_sampling_params = None,
-         unsloth_num_chunks = -1,
-         **kwargs,
-     ):
-         if learning_rate < 1e-7: raise FloatingPointError(f'Unsloth: Your learning rate of `{learning_rate}` is too small and less than 1e-7! Consider increasing it, otherwise gradient updates will be close to 0!')
-         if learning_rate > 1: raise OverflowError(f'Unsloth: Your learning rate of `{learning_rate}` is way too larger > 1! Consider decreasing it to 1e-1, otherwise gradient updates will explode!')
-         if output_dir is None and save_strategy == 'steps' and save_steps == 500:
-             output_dir = 'unsloth_training_checkpoints'
-             save_strategy = 'no'
-         if dataset_num_proc is None:
-             from multiprocessing import cpu_count
-             dataset_num_proc = min(cpu_count()*2, 2)
-
-         super().__init__(
-             output_dir = output_dir,
-             overwrite_output_dir = overwrite_output_dir,
-             do_train = do_train,
-             do_eval = do_eval,
-             do_predict = do_predict,
-             eval_strategy = eval_strategy,
-             prediction_loss_only = prediction_loss_only,
-             per_device_train_batch_size = per_device_train_batch_size,
-             per_device_eval_batch_size = per_device_eval_batch_size,
-             per_gpu_train_batch_size = per_gpu_train_batch_size,
-             per_gpu_eval_batch_size = per_gpu_eval_batch_size,
-             gradient_accumulation_steps = gradient_accumulation_steps,
-             eval_accumulation_steps = eval_accumulation_steps,
-             eval_delay = eval_delay,
-             torch_empty_cache_steps = torch_empty_cache_steps,
-             learning_rate = learning_rate,
-             weight_decay = weight_decay,
-             adam_beta1 = adam_beta1,
-             adam_beta2 = adam_beta2,
-             adam_epsilon = adam_epsilon,
-             max_grad_norm = max_grad_norm,
-             num_train_epochs = num_train_epochs,
-             max_steps = max_steps,
-             lr_scheduler_type = lr_scheduler_type,
-             warmup_ratio = warmup_ratio,
-             warmup_steps = warmup_steps,
-             log_level = log_level,
-             log_level_replica = log_level_replica,
-             log_on_each_node = log_on_each_node,
-             logging_dir = logging_dir,
-             logging_strategy = logging_strategy,
-             logging_first_step = logging_first_step,
-             logging_steps = logging_steps,
-             logging_nan_inf_filter = logging_nan_inf_filter,
-             save_strategy = save_strategy,
-             save_steps = save_steps,
-             save_total_limit = save_total_limit,
-             save_safetensors = save_safetensors,
-             save_on_each_node = save_on_each_node,
-             save_only_model = save_only_model,
-             restore_callback_states_from_checkpoint = restore_callback_states_from_checkpoint,
-             no_cuda = no_cuda,
-             use_cpu = use_cpu,
-             use_mps_device = use_mps_device,
-             seed = seed,
-             data_seed = data_seed,
-             jit_mode_eval = jit_mode_eval,
-             use_ipex = use_ipex,
-             bf16 = bf16,
-             fp16 = fp16,
-             fp16_opt_level = fp16_opt_level,
-             half_precision_backend = half_precision_backend,
-             bf16_full_eval = bf16_full_eval,
-             fp16_full_eval = fp16_full_eval,
-             tf32 = tf32,
-             local_rank = local_rank,
-             ddp_backend = ddp_backend,
-             tpu_num_cores = tpu_num_cores,
-             tpu_metrics_debug = tpu_metrics_debug,
-             debug = debug,
-             dataloader_drop_last = dataloader_drop_last,
-             eval_steps = eval_steps,
-             dataloader_num_workers = dataloader_num_workers,
-             dataloader_prefetch_factor = dataloader_prefetch_factor,
-             past_index = past_index,
-             run_name = run_name,
-             disable_tqdm = disable_tqdm,
-             remove_unused_columns = remove_unused_columns,
-             label_names = label_names,
-             load_best_model_at_end = load_best_model_at_end,
-             metric_for_best_model = metric_for_best_model,
-             greater_is_better = greater_is_better,
-             ignore_data_skip = ignore_data_skip,
-             fsdp = fsdp,
-             fsdp_min_num_params = fsdp_min_num_params,
-             fsdp_config = fsdp_config,
-             fsdp_transformer_layer_cls_to_wrap = fsdp_transformer_layer_cls_to_wrap,
-             accelerator_config = accelerator_config,
-             deepspeed = deepspeed,
-             label_smoothing_factor = label_smoothing_factor,
-             optim = optim,
-             optim_args = optim_args,
-             adafactor = adafactor,
-             group_by_length = group_by_length,
-             length_column_name = length_column_name,
-             report_to = report_to,
-             ddp_find_unused_parameters = ddp_find_unused_parameters,
-             ddp_bucket_cap_mb = ddp_bucket_cap_mb,
-             ddp_broadcast_buffers = ddp_broadcast_buffers,
-             dataloader_pin_memory = dataloader_pin_memory,
-             dataloader_persistent_workers = dataloader_persistent_workers,
-             skip_memory_metrics = skip_memory_metrics,
-             use_legacy_prediction_loop = use_legacy_prediction_loop,
-             push_to_hub = push_to_hub,
-             resume_from_checkpoint = resume_from_checkpoint,
-             hub_model_id = hub_model_id,
-             hub_strategy = hub_strategy,
-             hub_token = hub_token,
-             hub_private_repo = hub_private_repo,
-             hub_always_push = hub_always_push,
-             hub_revision = hub_revision,
-             gradient_checkpointing = gradient_checkpointing,
-             gradient_checkpointing_kwargs = gradient_checkpointing_kwargs,
-             include_inputs_for_metrics = include_inputs_for_metrics,
-             eval_do_concat_batches = eval_do_concat_batches,
-             fp16_backend = fp16_backend,
-             push_to_hub_model_id = push_to_hub_model_id,
-             push_to_hub_organization = push_to_hub_organization,
-             push_to_hub_token = push_to_hub_token,
-             mp_parameters = mp_parameters,
-             auto_find_batch_size = auto_find_batch_size,
-             full_determinism = full_determinism,
-             torchdynamo = torchdynamo,
-             ray_scope = ray_scope,
-             ddp_timeout = ddp_timeout,
-             torch_compile = torch_compile,
-             torch_compile_backend = torch_compile_backend,
-             torch_compile_mode = torch_compile_mode,
-             include_tokens_per_second = include_tokens_per_second,
-             include_num_input_tokens_seen = include_num_input_tokens_seen,
-             neftune_noise_alpha = neftune_noise_alpha,
-             optim_target_modules = optim_target_modules,
-             batch_eval_metrics = batch_eval_metrics,
-             eval_on_start = eval_on_start,
-             use_liger_kernel = use_liger_kernel,
-             liger_kernel_config = liger_kernel_config,
-             eval_use_gather_object = eval_use_gather_object,
-             average_tokens_across_devices = average_tokens_across_devices,
-             max_length = max_length,
-             disable_dropout = disable_dropout,
-             dataset_num_proc = dataset_num_proc,
-             center_rewards_coefficient = center_rewards_coefficient,**kwargs)
-         self.vllm_sampling_params = vllm_sampling_params
-         self.unsloth_num_chunks = unsloth_num_chunks
-     pass
-
- class _UnslothRewardTrainer(Trainer):
-     _tag_names = ["trl", "reward-trainer"]
-
-     def __init__(
-         self,
-         model: Optional[Union[PreTrainedModel, nn.Module]] = None,
-         args: Optional[RewardConfig] = None,
-         data_collator: Optional[DataCollator] = None,
-         train_dataset: Optional[Dataset] = None,
-         eval_dataset: Optional[Union[Dataset, dict[str, Dataset]]] = None,
-         processing_class: Optional[
-             Union[PreTrainedTokenizerBase, BaseImageProcessor, FeatureExtractionMixin, ProcessorMixin]
-         ] = None,
-         model_init: Optional[Callable[[], PreTrainedModel]] = None,
-         compute_metrics: Optional[Callable[[EvalPrediction], dict]] = None,
-         callbacks: Optional[list[TrainerCallback]] = None,
-         optimizers: tuple[torch.optim.Optimizer, torch.optim.lr_scheduler.LambdaLR] = (
-             None,
-             None,
-         ),
-         preprocess_logits_for_metrics: Optional[Callable[[torch.Tensor, torch.Tensor], torch.Tensor]] = None,
-         peft_config: Optional[dict] = None,
-     ):
-         """
-         Initialize RewardTrainer.
-
-         Args:
-             model (`transformers.PreTrainedModel`):
-                 The model to train, preferably an `AutoModelForSequenceClassification`.
-             args (`RewardConfig`):
-                 The arguments to use for training.
-             data_collator (`transformers.DataCollator`):
-                 The data collator to use for training. If None is specified, the default data collator (`RewardDataCollatorWithPadding`) will be used
-                 which will pad the sequences to the maximum length of the sequences in the batch, given a dataset of paired sequences.
-             train_dataset (`datasets.Dataset`):
-                 The dataset to use for training.
-             eval_dataset (`datasets.Dataset`):
-                 The dataset to use for evaluation.
-             processing_class (`PreTrainedTokenizerBase` or `BaseImageProcessor` or `FeatureExtractionMixin` or `ProcessorMixin`, *optional*):
-                 Processing class used to process the data. If provided, will be used to automatically process the inputs
-                 for the model, and it will be saved along the model to make it easier to rerun an interrupted training or
-                 reuse the fine-tuned model.
-             model_init (`Callable[[], transformers.PreTrainedModel]`):
-                 The model initializer to use for training. If None is specified, the default model initializer will be used.
-             compute_metrics (`Callable[[transformers.EvalPrediction], dict]`, *optional* defaults to `compute_accuracy`):
-                 The metrics to use for evaluation. If no metrics are specified, the default metric (`compute_accuracy`) will be used.
-             callbacks (`list[transformers.TrainerCallback]`):
-                 The callbacks to use for training.
-             optimizers (`tuple[torch.optim.Optimizer, torch.optim.lr_scheduler.LambdaLR]`):
-                 The optimizer and scheduler to use for training.
-             preprocess_logits_for_metrics (`Callable[[torch.Tensor, torch.Tensor], torch.Tensor]`):
-                 The function to use to preprocess the logits before computing the metrics.
-             peft_config (`dict`, defaults to `None`):
-                 The PEFT configuration to use for training. If you pass a PEFT configuration, the model will be wrapped in a PEFT model.
-         """
-         if not is_peft_available() and peft_config is not None:
-             raise ValueError(
-                 "PEFT is not installed and you passed a `peft_config` in the trainer's kwargs, please install it to use the PEFT models"
-             )
-         elif is_peft_available() and peft_config is not None:
-             if not isinstance(model, PeftModel):
-                 if getattr(model, "is_loaded_in_8bit", False) or getattr(model, "is_quantized", False):
-                     _supports_gc_kwargs = "gradient_checkpointing_kwargs" in list(
-                         inspect.signature(prepare_model_for_kbit_training).parameters
-                     )
-
-                     prepare_model_kwargs = {"use_gradient_checkpointing": args.gradient_checkpointing}
-
-                     if not _supports_gc_kwargs and args.gradient_checkpointing_kwargs is not None:
-                         warnings.warn(
-                             "You passed `gradient_checkpointing_kwargs` in the trainer's kwargs, but your peft version does not support it. "
-                             "please update to the latest version of peft to use `gradient_checkpointing_kwargs`.",
-                             UserWarning,
-                         )
-                     elif _supports_gc_kwargs and args.gradient_checkpointing_kwargs is not None:
-                         prepare_model_kwargs["gradient_checkpointing_kwargs"] = args.gradient_checkpointing_kwargs
-
-                     model = prepare_model_for_kbit_training(model, **prepare_model_kwargs)
-
-                 model = model
-
-         # Disable dropout in the model
-         if args.disable_dropout:
-             disable_dropout_in_model(model)
-
-         if compute_metrics is None:
-             compute_metrics = compute_accuracy
-
-         if data_collator is None:
-             if processing_class is None:
-                 raise ValueError(
-                     "A processing_class must be specified when using the default RewardDataCollatorWithPadding"
-                 )
-
-             max_length = args.max_length
-
-             data_collator = RewardDataCollatorWithPadding(processing_class)
-
-             if args.remove_unused_columns:
-                 try:  # for bc before https://github.com/huggingface/transformers/pull/25435
-                     args.remove_unused_columns = False
-                 except FrozenInstanceError:
-                     args = replace(args, remove_unused_columns=False)
-                 # warn users
-                 warnings.warn(
-                     "When using RewardDataCollatorWithPadding, you should set `remove_unused_columns=False` in your RewardConfig"
-                     " we have set it for you, but you should do it yourself in the future.",
-                     UserWarning,
-                 )
-
-             self.use_reward_data_collator = True
-         else:
-             self.use_reward_data_collator = False
-
-         # The trainer estimates the number of FLOPs [floating-point operations] using the number of elements in the
-         # input tensor associated with the key "input_ids". However, in Reward, the sampled data does not include the
-         # "input_ids" key. Instead, the available keys are "input_ids_chosen" and "input_ids_rejected". As a result,
-         # the trainer issues the warning: "Could not estimate the number of tokens of the input, floating-point
-         # operations will not be computed." To suppress this warning, we set the "estimate_tokens" key in the model's
-         # "warnings_issued" dictionary to True. This acts as a flag to indicate that the warning has already been
-         # issued.
-         model.warnings_issued["estimate_tokens"] = True
-
-         if "input_ids_chosen" not in train_dataset.column_names:
-             with PartialState().main_process_first():
-                 fn_kwargs = {"tokenizer": processing_class}
-                 train_dataset = train_dataset.map(maybe_apply_chat_template, fn_kwargs={"tokenizer": processing_class})
-                 train_dataset = train_dataset.map(
-                     _tokenize,
-                     batched=True,
-                     fn_kwargs=fn_kwargs,
-                     num_proc=args.dataset_num_proc,
-                 )
-                 # This filter is important because otherwise you get samples that exceed the model's context length and
-                 # get truncated => noisy signal the chosen/rejected label gets lost. The downside is that the
-                 # user might get surprised if N samples are missing from training.
-                 train_dataset = train_dataset.filter(
-                     lambda x: len(x["input_ids_chosen"]) <= max_length and len(x["input_ids_rejected"]) <= max_length,
-                     num_proc=args.dataset_num_proc,
-                 )
-                 if eval_dataset is not None:
-                     eval_dataset = eval_dataset.map(
-                         maybe_apply_chat_template, fn_kwargs={"tokenizer": processing_class}
-                     )
-                     eval_dataset = eval_dataset.map(
-                         _tokenize,
-                         fn_kwargs=fn_kwargs,
-                         batched=True,
-                         num_proc=args.dataset_num_proc,
-                     )
-                     # This filter is important because otherwise you get samples that exceed the model's context length and
-                     # get truncated => noisy signal the chosen/rejected label gets lost. The downside is that the
-                     # user might get surprised if N samples are missing from training.
-                     eval_dataset = eval_dataset.filter(
-                         lambda x: len(x["input_ids_chosen"]) <= max_length
-                         and len(x["input_ids_rejected"]) <= max_length,
-                         num_proc=args.dataset_num_proc,
-                     )
-
-         super().__init__(
-             model=model,
-             args=args,
-             data_collator=data_collator,
-             train_dataset=train_dataset,
-             eval_dataset=eval_dataset,
-             processing_class=processing_class,
-             model_init=model_init,
-             compute_metrics=compute_metrics,
-             callbacks=callbacks,
-             optimizers=optimizers,
-             preprocess_logits_for_metrics=preprocess_logits_for_metrics,
-         )
-
-         # Add tags for models that have been loaded with the correct transformers version
-         if hasattr(self.model, "add_model_tags"):
-             self.model.add_model_tags(self._tag_names)
-
-     def compute_loss(
-         self,
-         model: Union[PreTrainedModel, nn.Module],
-         inputs: dict[str, Union[torch.Tensor, Any]],
-         return_outputs=False,
-         num_items_in_batch=None,
-     ) -> Union[torch.Tensor, tuple[torch.Tensor, dict[str, torch.Tensor]]]:
-         rewards_chosen = model(
-             input_ids=inputs["input_ids_chosen"],
-             attention_mask=inputs["attention_mask_chosen"],
-             return_dict=True,
-         )["logits"]
-         rewards_rejected = model(
-             input_ids=inputs["input_ids_rejected"],
-             attention_mask=inputs["attention_mask_rejected"],
-             return_dict=True,
-         )["logits"]
-         # calculate loss, optionally modulate with margin
-         if "margin" in inputs:
-             loss = -nn.functional.logsigmoid(rewards_chosen - rewards_rejected - inputs["margin"]).mean()
-         else:
-             loss = -nn.functional.logsigmoid(rewards_chosen - rewards_rejected).mean()
-
-         if self.args.center_rewards_coefficient is not None:
-             loss += self.args.center_rewards_coefficient * torch.mean((rewards_chosen + rewards_rejected) ** 2)
-
-         if return_outputs:
-             return loss, {
-                 "rewards_chosen": rewards_chosen,
-                 "rewards_rejected": rewards_rejected,
-             }
-         return loss
-
-     def prediction_step(
-         self,
-         model: Union[PreTrainedModel, nn.Module],
-         inputs: dict[str, Union[torch.Tensor, Any]],
-         prediction_loss_only: bool,
-         ignore_keys: Optional[list[str]] = None,
-     ) -> tuple[Optional[torch.Tensor], Optional[torch.Tensor], Optional[torch.Tensor]]:
-         inputs = self._prepare_inputs(inputs)
-         if ignore_keys is None:
-             if hasattr(self.model, "config"):
-                 ignore_keys = getattr(self.model.config, "keys_to_ignore_at_inference", [])
-             else:
-                 ignore_keys = []
-
-         with torch.no_grad():
-             loss, logits_dict = self.compute_loss(model, inputs, return_outputs=True)
-
-         if prediction_loss_only:
-             return (loss, None, None)
-
-         loss = loss.detach()
-         logits = tuple(v for k, v in logits_dict.items() if k not in ignore_keys)
-         logits = nested_detach(logits)
-         # Stack accepted against rejected, mean over logits
-         # and softmax to get preferences between accepted and rejected to sum to 1
-         logits = torch.stack(logits).mean(dim=2).softmax(dim=0).T
-
-         labels = torch.zeros(logits.shape[0])
-         labels = self._prepare_inputs(labels)
-
-         return loss, logits, labels
-
-     def evaluate(self, *args, **kwargs):
-         num_print_samples = kwargs.pop("num_print_samples", 4)
-         self.visualize_samples(num_print_samples)
-         return super().evaluate(*args, **kwargs)
-
-     def visualize_samples(self, num_print_samples: int):
-         """
-         Visualize the reward model logits prediction
-
-         Args:
-             num_print_samples (`int`, defaults to `4`):
-                 The number of samples to print. Set to `-1` to print all samples.
-         """
-         eval_dataloader = self.get_eval_dataloader()
-         table = defaultdict(list)
-         for _, inputs in enumerate(eval_dataloader):
-             _, logits, _ = self.prediction_step(self.model, inputs, prediction_loss_only=False)
-             chosen_text = decode_and_strip_padding(inputs["input_ids_chosen"], self.processing_class)
-             rejected_text = decode_and_strip_padding(inputs["input_ids_rejected"], self.processing_class)
-             table["chosen_text"].extend(gather_object(chosen_text))
-             table["rejected_text"].extend(gather_object(rejected_text))
-             table["logits"].extend(
-                 gather_object([[round(inner_item, 4) for inner_item in item] for item in logits.tolist()])
-             )
-             if num_print_samples >= 0 and len(table["chosen_text"]) >= num_print_samples:
-                 break
-         df = pd.DataFrame(table)
-         if self.accelerator.process_index == 0:
-             print_rich_table(df[:num_print_samples])
-             if "wandb" in self.args.report_to:
-                 import wandb
-
-                 if wandb.run is not None:
-                     wandb.log({"completions": wandb.Table(dataframe=df)})
-
-             if "comet_ml" in self.args.report_to:
-                 log_table_to_comet_experiment(
-                     name="completions.csv",
-                     table=df,
-                 )
-
-     def create_model_card(
-         self,
-         model_name: Optional[str] = None,
-         dataset_name: Optional[str] = None,
-         tags: Union[str, list[str], None] = None,
-     ):
-         """
-         Creates a draft of a model card using the information available to the `Trainer`.
-
-         Args:
-             model_name (`str` or `None`, *optional*, defaults to `None`):
-                 Name of the model.
-             dataset_name (`str` or `None`, *optional*, defaults to `None`):
-                 Name of the dataset used for training.
-             tags (`str`, `list[str]` or `None`, *optional*, defaults to `None`):
-                 Tags to be associated with the model card.
-         """
-         if not self.is_world_process_zero():
-             return
-
-         if hasattr(self.model.config, "_name_or_path") and not os.path.isdir(self.model.config._name_or_path):
-             base_model = self.model.config._name_or_path
-         else:
-             base_model = None
-
-         tags = tags or []
-         if isinstance(tags, str):
-             tags = [tags]
-
-         if hasattr(self.model.config, "unsloth_version"):
-             tags.append("unsloth")
-
-         model_card = generate_model_card(
-             base_model=base_model,
-             model_name=model_name,
-             hub_model_id=self.hub_model_id,
-             dataset_name=dataset_name,
-             tags=tags,
-             wandb_url=wandb.run.get_url() if is_wandb_available() and wandb.run is not None else None,
-             comet_url=get_comet_experiment_url(),
-             trainer_name="Reward",
-         )
-
-         model_card.save(os.path.join(self.args.output_dir, "README.md"))
- class UnslothRewardTrainer(_UnslothRewardTrainer):
-     """
-
-     """
-     def __init__(
-         self,
-         model = None,
-         args = None,
-         data_collator = None,
-         train_dataset = None,
-         eval_dataset = None,
-         processing_class = None,
-         model_init = None,
-         compute_metrics = None,
-         callbacks = None,
-         preprocess_logits_for_metrics = None,
-         peft_config = None,
-         **kwargs
-     ):
-         if args is None: args = UnslothRewardConfig()
-         use_bf16 = getattr(args, 'bf16', False)
-         if type(use_bf16) is not bool: use_bf16 = False
-         use_fp16 = getattr(args, 'fp16', False)
-         if type(use_fp16) is not bool: use_fp16 = False
-         force_float32 = False
-         if os.environ.get('UNSLOTH_FORCE_FLOAT32', '0') == '1':
-             print('Unsloth: Switching to float32 training since model cannot work with float16')
-             force_float32 = True
-         mixed_precision_dtype = os.environ.get('UNSLOTH_MIXED_PRECISION', 'float32')
-         dtype = getattr(model.config, 'torch_dtype', None)
-         if dtype is None: dtype = model.get_input_embeddings().dtype
-         from unsloth_zoo.utils import _get_dtype
-         dtype = _get_dtype(dtype)
-         float16 = dtype == torch.float16
-         if not force_float32 and (float16 and use_bf16): raise TypeError('Unsloth: Model is in float16 precision but you want to use bfloat16 precision. Set fp16 to `True` and bf16 to `False`')
-         if not force_float32 and (not float16 and use_fp16): raise TypeError('Unsloth: Model is in bfloat16 precision but you want to use float16 precision. Set fp16 to `False` and bf16 to `True`')
-         if force_float32:
-             args.fp16 = False
-             args.bf16 = False
-             os.environ['ACCELERATE_MIXED_PRECISION'] = 'no'
-         elif (not use_bf16 and not use_fp16) and mixed_precision_dtype == 'float32':
-             args.fp16 = float16
-             args.bf16 = not float16
-             os.environ['ACCELERATE_MIXED_PRECISION'] = 'fp16' if float16 else 'bf16'
-         if getattr(args, 'eval_dataset', None) is not None and getattr(args, 'eval_strategy', 'no') == 'no':
-             args.eval_strategy = 'steps'
-             if getattr(args, 'eval_steps', None) is None: args.eval_steps = 0.1
-         ga_steps = getattr(args, 'gradient_accumulation_steps', None)
-         if ga_steps is not None and ga_steps > 1:
-             from transformers import __version__ as transformers_version
-             if Version(transformers_version) <= Version('4.45.2'):
-                 print('**** Unsloth: Please use our fixed gradient_accumulation_steps by updating transformers, TRL and Unsloth!\n'
-                     '`pip install --upgrade --no-cache-dir --force-reinstall --no-deps unsloth transformers trl unsloth_zoo`')
-         if getattr(args, 'eval_strategy', 'no') != 'no':
-             eval_bsz = getattr(args, 'per_device_eval_batch_size', 8)
-             if eval_bsz == 8 and args.per_device_train_batch_size < eval_bsz: args.per_device_eval_batch_size = args.per_device_train_batch_size
-             if getattr(args, 'eval_accumulation_steps', None) is None and ga_steps is not None: args.eval_accumulation_steps = ga_steps
-         fp16_full_eval = getattr(args, 'fp16_full_eval', False)
-         if type(fp16_full_eval) is not bool: fp16_full_eval = False
-         bf16_full_eval = getattr(args, 'bf16_full_eval', False)
-         if type(bf16_full_eval) is not bool: bf16_full_eval = False
-         if args.fp16 and bf16_full_eval: args.bf16_full_eval = False; args.fp16_full_eval = True
-         if args.bf16 and fp16_full_eval: args.bf16_full_eval = True; args.fp16_full_eval = False
-         if force_float32:
-             args.bf16_full_eval = False
-             args.fp16_full_eval = False
-         elif os.environ.get('UNSLOTH_MIXED_PRECISION', 'float32') == 'bfloat16':
-             args.bf16_full_eval = True
-             args.fp16_full_eval = False
-         elif not bf16_full_eval and not fp16_full_eval:
-             args.bf16_full_eval = args.bf16
-             args.fp16_full_eval = args.fp16
-         _output_logits = False
-         if locals().get('compute_metrics', None) is not None: _output_logits = True
-         if locals().get('preprocess_logits_for_metrics', None) is not None: _output_logits = True
-         if _output_logits:
-             os.environ['UNSLOTH_RETURN_LOGITS'] = '1'
-         if 'max_seq_length' not in locals() and not hasattr(args, 'max_seq_length'):
-             pass
-         else:
-             model_max_seq_length = getattr(model, 'max_seq_length', None)
-             args_max_seq_length = getattr(args, 'max_seq_length', None)
-             if args_max_seq_length is None and model_max_seq_length is not None:
-                 max_seq_length = model.max_seq_length
-                 if hasattr(args, 'max_seq_length'): args.max_seq_length = max_seq_length
-         if model is not None and hasattr(model, 'for_training'):
-             model.for_training()
-         if 'tokenizer' in locals() and hasattr(tokenizer, 'padding_side'): tokenizer.padding_side = 'right'
-         if 'processing_class' in locals():
-             if hasattr(processing_class, 'padding_side'): processing_class.padding_side = 'right'
-             if hasattr(processing_class, 'tokenizer') and hasattr(processing_class.tokenizer, 'padding_side'): processing_class.tokenizer.padding_side = 'right'
-         __tokenizer = processing_class if 'processing_class' in locals() else tokenizer
-         from unsloth_zoo.vision_utils import UnslothVisionDataCollator
-         if not isinstance(data_collator, UnslothVisionDataCollator):
-             if isinstance(data_collator, DataCollatorForSeq2Seq) and 'labels' not in train_dataset.column_names:
-                 data_collator = TransformersDataCollatorForLanguageModeling(__tokenizer, mlm = False, mlm_probability = 0.0)
-             elif isinstance(data_collator, TransformersDataCollatorForLanguageModeling) and 'labels' in train_dataset.column_names:
-                 data_collator = DataCollatorForSeq2Seq(__tokenizer)
-         else:
-             if hasattr(args, 'remove_unused_columns'): args.remove_unused_columns = False
-             if hasattr(args, 'dataset_text_field'): args.dataset_text_field = ''
-             if hasattr(args, 'dataset_kwargs'): args.dataset_kwargs = {'skip_prepare_dataset': True}
-         if not isinstance(data_collator, UnslothVisionDataCollator):
-             if not hasattr(__tokenizer, 'pad') and hasattr(__tokenizer, 'tokenizer'):
-                 if isinstance(data_collator, DataCollatorForSeq2Seq):
-                     data_collator = DataCollatorForSeq2Seq(__tokenizer.tokenizer)
-                 else:
-                     data_collator = TransformersDataCollatorForLanguageModeling(__tokenizer.tokenizer, mlm = False, mlm_probability = 0.0)
-         other_metrics = []
-
-         from unsloth_zoo.logging_utils import PatchRLStatistics
-         PatchRLStatistics('reward_trainer', other_metrics)
-
-         super().__init__(
-             model = model,
-             args = args,
-             data_collator = data_collator,
-             train_dataset = train_dataset,
-             eval_dataset = eval_dataset,
-             processing_class = processing_class,
-             model_init = model_init,
-             compute_metrics = compute_metrics,
-             callbacks = callbacks,
-             preprocess_logits_for_metrics = preprocess_logits_for_metrics,
-             peft_config = peft_config,**kwargs)
-         if hasattr(self, 'neftune_hook_handle'):
-             self.neftune_hook_handle.remove()
-             if hasattr(self, 'neftune_hook_handle'): del self.neftune_hook_handle
-         if getattr(args, 'neftune_noise_alpha', None) is not None:
-             model.get_input_embeddings().neftune_noise_alpha = self.neftune_noise_alpha
-         pass
-
-     pass
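
The `compute_loss` above is the standard pairwise reward-modeling objective: maximize the log-sigmoid of the chosen-minus-rejected score gap, optionally shifted by a per-pair margin and regularized toward mean-zero rewards. A minimal numerical sketch with toy tensors (hypothetical values, not the trainer itself):

import torch
import torch.nn.functional as F

rewards_chosen   = torch.tensor([[1.2], [0.3]])  # toy scores for preferred completions
rewards_rejected = torch.tensor([[0.4], [0.9]])  # toy scores for rejected completions
margin           = torch.tensor([[0.0], [0.5]])  # optional per-pair margin, as in compute_loss

loss = -F.logsigmoid(rewards_chosen - rewards_rejected - margin).mean()
# Optional mean-zero regularizer (https://huggingface.co/papers/2312.09244, Eq. 2),
# scaled by center_rewards_coefficient; the config docstring recommends 0.01:
loss = loss + 0.01 * torch.mean((rewards_chosen + rewards_rejected) ** 2)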
 
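Each of these generated files also ships the same `chunked_selective_log_softmax` helper, which computes per-token log-probabilities in four float32 chunks to cap peak memory. An unchunked reference one could sanity-check it against (a sketch; the comparison is left as a comment because the helper above is wrapped in `torch.compile`):

import torch

def selective_log_softmax_reference(logits, index):
    # Log-softmax over the vocabulary, gathered at the target token ids.
    logps = torch.log_softmax(logits.float(), dim = -1)
    return torch.gather(logps, dim = -1, index = index.unsqueeze(-1)).squeeze(-1)

logits = torch.randn(2, 8, 32)            # toy (batch, seq, vocab) shapes
index  = torch.randint(0, 32, (2, 8))
# torch.allclose(chunked_selective_log_softmax(logits, index),
#                selective_log_softmax_reference(logits, index), atol = 1e-5)
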
test_run_uploads/UnslothSFTTrainer.py DELETED
@@ -1,1102 +0,0 @@
1
- """
2
- 2025.7.11
3
- 2025.7.11
4
- 4.54.1
5
- 0.16.1
6
- __UNSLOTH_VERSIONING__
7
- """
8
- from torch import Tensor
9
- import torch
10
- import torch.nn as nn
11
- from torch.nn import functional as F
12
- from typing import Any, List, Optional, Tuple, Union, Dict, Set, Callable
13
- from trl.trainer.sft_trainer import (Any, AutoModelForCausalLM, AutoTokenizer, BaseImageProcessor, Callable, ConstantLengthDataset, DataCollator, DataCollatorForLanguageModeling, DataCollatorWithFlattening, Dataset, EvalPrediction, FeatureExtractionMixin, IterableDataset, Optional, PeftConfig, PeftModel, PreTrainedModel, PreTrainedTokenizerBase, ProcessorMixin, SFTConfig, SFTTrainer, Trainer, TrainerCallback, TrainingArguments, Type, Union, dataclass, dataclasses, defaultdict, generate_model_card, get_comet_experiment_url, get_peft_model, is_peft_available, is_wandb_available, nn, os, pad, peft, peft_module_casting_to_bf16, prepare_model_for_kbit_training, torch, transformers, version, wandb, warnings, Callable, ConstantLengthDataset, DataCollator, DataCollatorForLanguageModeling, Dataset, IterableDataset, Optional, Union, os, pad, transformers, os)
14
-
15
-
16
- import os
17
- from typing import *
18
- from dataclasses import dataclass, field
19
- from packaging.version import Version
20
- import torch
21
- import numpy as np
22
- from contextlib import nullcontext
23
- from torch.nn import functional as F
24
- from transformers import DataCollatorForSeq2Seq, DataCollatorForLanguageModeling as TransformersDataCollatorForLanguageModeling
25
-
26
- torch_compile_options = {
27
- "epilogue_fusion" : True,
28
- "max_autotune" : False,
29
- "shape_padding" : True,
30
- "trace.enabled" : False,
31
- "triton.cudagraphs" : False,
32
- }
33
-
34
- @torch.compile(dynamic = True, fullgraph = True, options = torch_compile_options,)
35
- def chunked_selective_log_softmax(logits, index):
36
- # Split into 4 chunks only
37
- chunked_logits = torch.chunk(logits.reshape(-1, logits.shape[-1]), chunks = 4, dim = 0)
38
- chunked_index = torch.chunk(index.reshape(-1), chunks = 4, dim = 0)
39
- all_per_token_logps = []
40
- # Below loop does the same as selective_log_softmax(chunk_logits, chunk_index)
41
- for chunk_logits, chunk_index in zip(chunked_logits, chunked_index):
42
- chunk_logits = chunk_logits.to(torch.float32)
43
- selected_logits = torch.gather(chunk_logits, dim = -1, index = chunk_index.unsqueeze(-1)).squeeze(-1)
44
- logsumexp_values = torch.logsumexp(chunk_logits, dim = -1)
45
- per_token_logps = selected_logits - logsumexp_values
46
- all_per_token_logps.append(per_token_logps)
47
- pass
48
- all_per_token_logps = torch.concat(all_per_token_logps)
49
- all_per_token_logps = all_per_token_logps.reshape((logits.shape[0], logits.shape[1]))
50
- return all_per_token_logps
51
- @dataclass
52
- class UnslothSFTConfig(SFTConfig):
53
- """
54
-
55
- Configuration class for the [`SFTTrainer`].
56
-
57
- Only the parameters specific to SFT training are listed here. For details on other parameters, refer to the
58
- [`~transformers.TrainingArguments`] documentation.
59
-
60
- Using [`~transformers.HfArgumentParser`] we can turn this class into
61
- [argparse](https://docs.python.org/3/library/argparse#module-argparse) arguments that can be specified on the
62
- command line.
63
-
64
- Parameters:
65
- > Parameters that control the model
66
-
67
- model_init_kwargs (`dict[str, Any]` or `None`, *optional*, defaults to `None`):
68
- Keyword arguments for [`~transformers.AutoModelForCausalLM.from_pretrained`], used when the `model`
69
- argument of the [`SFTTrainer`] is provided as a string.
70
-
71
- > Parameters that control the data preprocessing
72
-
73
- dataset_text_field (`str`, *optional*, defaults to `"text"`):
74
- Name of the column that contains text data in the dataset.
75
- dataset_kwargs (`dict[str, Any]` or `None`, *optional*, defaults to `None`):
76
- Dictionary of optional keyword arguments for the dataset preparation. The only supported key is
77
- `skip_prepare_dataset`.
78
- dataset_num_proc (`int` or `None`, *optional*, defaults to `None`):
79
- Number of processes to use for processing the dataset.
80
- pad_token (`str` or `None`, *optional*, defaults to `None`):
81
- Token used for padding. If `None`, it defaults to `processing_class.pad_token`, or if that is also `None`,
82
- it falls back to `processing_class.eos_token`.
83
- max_length (`int` or `None`, *optional*, defaults to `1024`):
84
- Maximum length of the tokenized sequence. Sequences longer than `max_length` are truncated from the right.
85
- If `None`, no truncation is applied. When packing is enabled, this value sets the sequence length.
86
- packing (`bool`, *optional*, defaults to `False`):
87
- Whether to pack multiple sequences into a fixed-length format. Uses `max_length` to define sequence length.
88
- padding_free (`bool`, *optional*, defaults to `False`):
89
- Whether to perform forward passes without padding by flattening all sequences in the batch into a single
90
- continuous sequence. This reduces memory usage by eliminating padding overhead. Currently, this is only
91
- supported with the `flash_attention_2` attention implementation, which can efficiently handle the flattened
92
- batch structure.
93
- eval_packing (`bool` or `None`, *optional*, defaults to `None`):
94
- Whether to pack the eval dataset. If `None`, uses the same value as `packing`.
95
-
96
- > Parameters that control the training
97
-
98
- learning_rate (`float`, *optional*, defaults to `2e-5`):
99
- Initial learning rate for [`AdamW`] optimizer. The default value replaces that of
100
- [`~transformers.TrainingArguments`].
101
-
102
- """
103
- vllm_sampling_params: Optional[Any] = field(
104
- default = None,
105
- metadata = {'help': 'vLLM SamplingParams'},
106
- )
107
- unsloth_num_chunks : Optional[int] = field(
108
- default = -1,
109
- metadata = {'help': 'Chunk size to reduce memory usage. -1 is most efficient.'},
110
- )
111
- def __init__(
112
- self,
113
- output_dir = None,
114
- overwrite_output_dir = None,
115
- do_train = False,
116
- do_eval = False,
117
- do_predict = False,
118
- eval_strategy = 'no',
119
- prediction_loss_only = False,
120
- per_device_train_batch_size = 4,
121
- per_device_eval_batch_size = 4,
122
- per_gpu_train_batch_size = None,
123
- per_gpu_eval_batch_size = None,
124
- gradient_accumulation_steps = 2,
125
- eval_accumulation_steps = 2,
126
- eval_delay = 0,
127
- torch_empty_cache_steps = 250,
128
- learning_rate = 5e-05,
129
- weight_decay = 0.01,
130
- adam_beta1 = 0.9,
131
- adam_beta2 = 0.999,
132
- adam_epsilon = 1e-08,
133
- max_grad_norm = 1.0,
134
- num_train_epochs = 3.0,
135
- max_steps = -1,
136
- lr_scheduler_type = 'linear',
137
- warmup_ratio = 0.1,
138
- warmup_steps = 0,
139
- log_level = 'passive',
140
- log_level_replica = 'warning',
141
- log_on_each_node = True,
142
- logging_dir = None,
143
- logging_strategy = 'steps',
144
- logging_first_step = False,
145
- logging_steps = 1,
146
- logging_nan_inf_filter = False,
147
- save_strategy = 'steps',
148
- save_steps = 500,
149
- save_total_limit = None,
150
- save_safetensors = True,
151
- save_on_each_node = False,
152
- save_only_model = False,
153
- restore_callback_states_from_checkpoint = False,
154
- no_cuda = False,
155
- use_cpu = False,
156
- use_mps_device = False,
157
- seed = 3407,
158
- data_seed = 3407,
159
- jit_mode_eval = False,
160
- use_ipex = False,
161
- bf16 = False,
162
- fp16 = False,
163
- fp16_opt_level = 'O1',
164
- half_precision_backend = 'auto',
165
- bf16_full_eval = False,
166
- fp16_full_eval = False,
167
- tf32 = None,
168
- local_rank = -1,
169
- ddp_backend = None,
170
- tpu_num_cores = None,
171
- tpu_metrics_debug = False,
172
- debug = '',
173
- dataloader_drop_last = False,
174
- eval_steps = None,
175
- dataloader_num_workers = 0,
176
- dataloader_prefetch_factor = None,
177
- past_index = -1,
178
- run_name = None,
179
- disable_tqdm = None,
180
- remove_unused_columns = True,
181
- label_names = None,
182
- load_best_model_at_end = False,
183
- metric_for_best_model = None,
184
- greater_is_better = None,
185
- ignore_data_skip = False,
186
- fsdp = '',
187
- fsdp_min_num_params = 0,
188
- fsdp_config = None,
189
- fsdp_transformer_layer_cls_to_wrap = None,
190
- accelerator_config = None,
191
- deepspeed = None,
192
- label_smoothing_factor = 0.0,
193
- optim = 'adamw_8bit',
194
- optim_args = None,
195
- adafactor = False,
196
- group_by_length = False,
197
- length_column_name = 'length',
198
- report_to = None,
199
- ddp_find_unused_parameters = None,
200
- ddp_bucket_cap_mb = None,
201
- ddp_broadcast_buffers = None,
202
- dataloader_pin_memory = True,
203
- dataloader_persistent_workers = False,
204
- skip_memory_metrics = True,
205
- use_legacy_prediction_loop = False,
206
- push_to_hub = False,
207
- resume_from_checkpoint = None,
208
- hub_model_id = None,
209
- hub_strategy = 'every_save',
210
- hub_token = None,
211
- hub_private_repo = None,
212
- hub_always_push = False,
213
- hub_revision = None,
214
- gradient_checkpointing = False,
215
- gradient_checkpointing_kwargs = None,
216
- include_inputs_for_metrics = False,
217
- eval_do_concat_batches = True,
218
- fp16_backend = 'auto',
219
- push_to_hub_model_id = None,
220
- push_to_hub_organization = None,
221
- push_to_hub_token = None,
222
- mp_parameters = '',
223
- auto_find_batch_size = True,
224
- full_determinism = False,
225
- torchdynamo = None,
226
- ray_scope = 'last',
227
- ddp_timeout = 1800,
228
- torch_compile = False,
229
- torch_compile_backend = None,
230
- torch_compile_mode = None,
231
- include_tokens_per_second = False,
232
- include_num_input_tokens_seen = False,
233
- neftune_noise_alpha = None,
234
- optim_target_modules = None,
235
- batch_eval_metrics = False,
236
- eval_on_start = False,
237
- use_liger_kernel = False,
238
- liger_kernel_config = None,
239
- eval_use_gather_object = False,
240
- average_tokens_across_devices = True,
241
- model_init_kwargs = None,
242
- dataset_text_field = 'text',
243
- dataset_kwargs = None,
244
- dataset_num_proc = None,
245
- pad_token = None,
246
- max_length = 1024,
247
- packing = False,
248
- padding_free = False,
249
- eval_packing = None,
250
- dataset_batch_size = None,
251
- num_of_sequences = None,
252
- chars_per_token = None,
253
- max_seq_length = None,
254
- use_liger = None,
255
- vllm_sampling_params = None,
256
- unsloth_num_chunks = -1,
257
- **kwargs,
258
- ):
259
- if learning_rate < 1e-7: raise FloatingPointError(f'Unsloth: Your learning rate of `{learning_rate}` is too small and less than 1e-7! Consider increasing it, otherwise gradient updates will be close to 0!')
260
- if learning_rate > 1: raise OverflowError(f'Unsloth: Your learning rate of `{learning_rate}` is way too larger > 1! Consider decreasing it to 1e-1, otherwise gradient updates will explode!')
261
- if output_dir is None and save_strategy == 'steps' and save_steps == 500:
262
- output_dir = 'unsloth_training_checkpoints'
263
- save_strategy = 'no'
264
- if dataset_num_proc is None:
265
- from multiprocessing import cpu_count
266
- dataset_num_proc = min(cpu_count()*2, 2)
267
-
268
- super().__init__(
269
- output_dir = output_dir,
270
- overwrite_output_dir = overwrite_output_dir,
271
- do_train = do_train,
272
- do_eval = do_eval,
273
- do_predict = do_predict,
274
- eval_strategy = eval_strategy,
275
- prediction_loss_only = prediction_loss_only,
276
- per_device_train_batch_size = per_device_train_batch_size,
277
- per_device_eval_batch_size = per_device_eval_batch_size,
278
- per_gpu_train_batch_size = per_gpu_train_batch_size,
279
- per_gpu_eval_batch_size = per_gpu_eval_batch_size,
280
- gradient_accumulation_steps = gradient_accumulation_steps,
281
- eval_accumulation_steps = eval_accumulation_steps,
282
- eval_delay = eval_delay,
283
- torch_empty_cache_steps = torch_empty_cache_steps,
284
- learning_rate = learning_rate,
285
- weight_decay = weight_decay,
286
- adam_beta1 = adam_beta1,
287
- adam_beta2 = adam_beta2,
288
- adam_epsilon = adam_epsilon,
289
- max_grad_norm = max_grad_norm,
290
- num_train_epochs = num_train_epochs,
291
- max_steps = max_steps,
292
- lr_scheduler_type = lr_scheduler_type,
293
- warmup_ratio = warmup_ratio,
294
- warmup_steps = warmup_steps,
295
- log_level = log_level,
296
- log_level_replica = log_level_replica,
297
- log_on_each_node = log_on_each_node,
298
- logging_dir = logging_dir,
299
- logging_strategy = logging_strategy,
300
- logging_first_step = logging_first_step,
301
- logging_steps = logging_steps,
302
- logging_nan_inf_filter = logging_nan_inf_filter,
303
- save_strategy = save_strategy,
304
- save_steps = save_steps,
305
- save_total_limit = save_total_limit,
306
- save_safetensors = save_safetensors,
307
- save_on_each_node = save_on_each_node,
308
- save_only_model = save_only_model,
309
- restore_callback_states_from_checkpoint = restore_callback_states_from_checkpoint,
310
- no_cuda = no_cuda,
311
- use_cpu = use_cpu,
312
- use_mps_device = use_mps_device,
- seed = seed,
- data_seed = data_seed,
- jit_mode_eval = jit_mode_eval,
- use_ipex = use_ipex,
- bf16 = bf16,
- fp16 = fp16,
- fp16_opt_level = fp16_opt_level,
- half_precision_backend = half_precision_backend,
- bf16_full_eval = bf16_full_eval,
- fp16_full_eval = fp16_full_eval,
- tf32 = tf32,
- local_rank = local_rank,
- ddp_backend = ddp_backend,
- tpu_num_cores = tpu_num_cores,
- tpu_metrics_debug = tpu_metrics_debug,
- debug = debug,
- dataloader_drop_last = dataloader_drop_last,
- eval_steps = eval_steps,
- dataloader_num_workers = dataloader_num_workers,
- dataloader_prefetch_factor = dataloader_prefetch_factor,
- past_index = past_index,
- run_name = run_name,
- disable_tqdm = disable_tqdm,
- remove_unused_columns = remove_unused_columns,
- label_names = label_names,
- load_best_model_at_end = load_best_model_at_end,
- metric_for_best_model = metric_for_best_model,
- greater_is_better = greater_is_better,
- ignore_data_skip = ignore_data_skip,
- fsdp = fsdp,
- fsdp_min_num_params = fsdp_min_num_params,
- fsdp_config = fsdp_config,
- fsdp_transformer_layer_cls_to_wrap = fsdp_transformer_layer_cls_to_wrap,
- accelerator_config = accelerator_config,
- deepspeed = deepspeed,
- label_smoothing_factor = label_smoothing_factor,
- optim = optim,
- optim_args = optim_args,
- adafactor = adafactor,
- group_by_length = group_by_length,
- length_column_name = length_column_name,
- report_to = report_to,
- ddp_find_unused_parameters = ddp_find_unused_parameters,
- ddp_bucket_cap_mb = ddp_bucket_cap_mb,
- ddp_broadcast_buffers = ddp_broadcast_buffers,
- dataloader_pin_memory = dataloader_pin_memory,
- dataloader_persistent_workers = dataloader_persistent_workers,
- skip_memory_metrics = skip_memory_metrics,
- use_legacy_prediction_loop = use_legacy_prediction_loop,
- push_to_hub = push_to_hub,
- resume_from_checkpoint = resume_from_checkpoint,
- hub_model_id = hub_model_id,
- hub_strategy = hub_strategy,
- hub_token = hub_token,
- hub_private_repo = hub_private_repo,
- hub_always_push = hub_always_push,
- hub_revision = hub_revision,
- gradient_checkpointing = gradient_checkpointing,
- gradient_checkpointing_kwargs = gradient_checkpointing_kwargs,
- include_inputs_for_metrics = include_inputs_for_metrics,
- eval_do_concat_batches = eval_do_concat_batches,
- fp16_backend = fp16_backend,
- push_to_hub_model_id = push_to_hub_model_id,
- push_to_hub_organization = push_to_hub_organization,
- push_to_hub_token = push_to_hub_token,
- mp_parameters = mp_parameters,
- auto_find_batch_size = auto_find_batch_size,
- full_determinism = full_determinism,
- torchdynamo = torchdynamo,
- ray_scope = ray_scope,
- ddp_timeout = ddp_timeout,
- torch_compile = torch_compile,
- torch_compile_backend = torch_compile_backend,
- torch_compile_mode = torch_compile_mode,
- include_tokens_per_second = include_tokens_per_second,
- include_num_input_tokens_seen = include_num_input_tokens_seen,
- neftune_noise_alpha = neftune_noise_alpha,
- optim_target_modules = optim_target_modules,
- batch_eval_metrics = batch_eval_metrics,
- eval_on_start = eval_on_start,
- use_liger_kernel = use_liger_kernel,
- liger_kernel_config = liger_kernel_config,
- eval_use_gather_object = eval_use_gather_object,
- average_tokens_across_devices = average_tokens_across_devices,
- model_init_kwargs = model_init_kwargs,
- dataset_text_field = dataset_text_field,
- dataset_kwargs = dataset_kwargs,
- dataset_num_proc = dataset_num_proc,
- pad_token = pad_token,
- max_length = max_length,
- packing = packing,
- padding_free = padding_free,
- eval_packing = eval_packing,
- dataset_batch_size = dataset_batch_size,
- num_of_sequences = num_of_sequences,
- chars_per_token = chars_per_token,
- max_seq_length = max_seq_length,
- use_liger = use_liger,**kwargs)
- self.vllm_sampling_params = vllm_sampling_params
- self.unsloth_num_chunks = unsloth_num_chunks
- pass
-
- class _UnslothSFTTrainer(Trainer):
- """"""
-
- _tag_names = ["trl", "sft"]
-
- def __init__(
- self,
- model: Union[str, nn.Module, PreTrainedModel],
- args: Optional[Union[SFTConfig, TrainingArguments]] = None,
- data_collator: Optional[DataCollator] = None, # type: ignore
- train_dataset: Optional[Union[Dataset, IterableDataset]] = None,
- eval_dataset: Optional[Union[Dataset, dict[str, Dataset]]] = None,
- processing_class: Optional[
- Union[PreTrainedTokenizerBase, BaseImageProcessor, FeatureExtractionMixin, ProcessorMixin]
- ] = None,
- compute_loss_func: Optional[Callable] = None,
- compute_metrics: Optional[Callable[[EvalPrediction], dict]] = None,
- callbacks: Optional[list[TrainerCallback]] = None,
- optimizers: tuple[Optional[torch.optim.Optimizer], Optional[torch.optim.lr_scheduler.LambdaLR]] = (None, None),
- optimizer_cls_and_kwargs: Optional[tuple[Type[torch.optim.Optimizer], dict[str, Any]]] = None,
- preprocess_logits_for_metrics: Optional[Callable[[torch.Tensor, torch.Tensor], torch.Tensor]] = None,
- peft_config: Optional["PeftConfig"] = None,
- formatting_func: Optional[Union[Callable[[dict], str], Callable[[dict], list[str]]]] = None,
- ):
- # Args
- model_id = model if isinstance(model, str) else model.config._name_or_path
- if args is None:
- model_name = model_id.split("/")[-1]
- args = SFTConfig(f"{model_name}-SFT")
- elif isinstance(args, TrainingArguments) and not isinstance(args, SFTConfig):
- dict_args = args.to_dict()
- dict_args["hub_token"] = args.hub_token # to_dict hides the hub_token
- dict_args.pop("push_to_hub_token")
- args = SFTConfig(**dict_args)
-
- # Handle the tokenizer
- if processing_class is None:
- processing_class = AutoTokenizer.from_pretrained(model_id)
-
- # Data collator
- if args.padding_free:
- if data_collator is not None:
- raise ValueError("Passing a custom data collator is not supported when using padding-free.")
- if args.packing:
- warnings.warn(
- "You are passing `packing=True` and `padding_free=True` which is not recommended. Please refer "
- "to the documentation to understand why this is not recommended."
- )
- if model.config._attn_implementation != "flash_attention_2":
- warnings.warn(
- "Padding-free training is enabled, but the attention implementation is not set to "
- "'flash_attention_2'. Padding-free training flattens batches into a single sequence, and "
- "'flash_attention_2' is the only known attention mechanism that reliably supports this. Using "
- "other implementations may lead to unexpected behavior. To ensure compatibility, set "
- "`attn_implementation='flash_attention_2'` in the model configuration, or verify that your "
- "attention mechanism can handle flattened sequences."
- )
- if args.per_device_train_batch_size == 1:
- warnings.warn(
- "You are using a per_device_train_batch_size of 1 with padding-free training. Using a batch size "
- "of 1 annihilates the benefits of padding-free training. Please consider increasing the batch "
- "size to at least 2."
- )
- data_collator = DataCollatorWithFlattening()
-
- if data_collator is None:
- # Get the pad token: if not provided, use the one from the processing class or the eos token
- # if the processing class does not have a pad token.
- pad_token = args.pad_token or processing_class.pad_token or processing_class.eos_token
- pad_token_id = processing_class.convert_tokens_to_ids(pad_token)
- if pad_token_id is None:
- raise ValueError(
- f"The specified `pad_token` ('{pad_token}') is not found in the vocabulary of the given "
- f"`processing_class` ({processing_class.__class__.__name__}). Ensure that the `pad_token` exists "
- "in the vocabulary before using it as a padding token."
- )
- data_collator = DataCollatorForLanguageModeling(pad_token_id)
-
- # Model
- if args.model_init_kwargs is not None and not isinstance(model, str):
- warnings.warn(
- "You passed model_init_kwargs to the `SFTConfig`, but your model is already instantiated. "
- "The `model_init_kwargs` will be ignored."
- )
- if isinstance(model, str):
- model = self._create_model_from_path(model, args)
-
- # PEFT configuration and model wrapping
- if False:
- pass
-
- # Dataset
- preprocess_dataset = args.dataset_kwargs is None or not args.dataset_kwargs.get("skip_prepare_dataset", False)
- if preprocess_dataset:
- train_dataset = self._prepare_dataset(
- train_dataset, processing_class, args, args.packing, formatting_func, "train"
- )
- if eval_dataset is not None:
- packing = args.packing if args.eval_packing is None else args.eval_packing
- if isinstance(eval_dataset, dict):
- eval_dataset = {
- key: self._prepare_dataset(dataset, processing_class, args, packing, formatting_func, key)
- for key, dataset in eval_dataset.items()
- }
- else:
- eval_dataset = self._prepare_dataset(
- eval_dataset, processing_class, args, packing, formatting_func, "eval"
- )
-
- # Initialize the metrics
- self._metrics = {"train": defaultdict(list), "eval": defaultdict(list)}
- self._total_train_tokens = 0
-
- # Initialize the Trainer. Parent class will handle:
- # - DeepSpeed configuration [through create_accelerator_and_postprocess]
- # - FSDP setup
- # - Distributed training setup
- # - Optimizer and scheduler creation
- # Some arguments are only available for transformers>=4.47.0. Can be removed when the min version is bumped.
- super_init_kwargs = {}
- if version.parse(transformers.__version__) >= version.parse("4.47.0.dev0"):
- super_init_kwargs["optimizer_cls_and_kwargs"] = optimizer_cls_and_kwargs
- else:
- if optimizer_cls_and_kwargs is not None:
- warnings.warn(
- "The `optimizer_cls_and_kwargs` argument is only available for `transformers>=4.47.0`. "
- "The default optimizer will be used. "
- "Remove the `optimizer_cls_and_kwargs` or upgrade to `transformers>=4.47.0`."
- )
- super().__init__(
- model=model,
- args=args,
- data_collator=data_collator,
- train_dataset=train_dataset,
- eval_dataset=eval_dataset,
- processing_class=processing_class,
- compute_loss_func=compute_loss_func,
- compute_metrics=compute_metrics,
- callbacks=callbacks,
- optimizers=optimizers,
- preprocess_logits_for_metrics=preprocess_logits_for_metrics,
- **super_init_kwargs,
- )
-
- # Add tags for models that have been loaded with the correct transformers version
- if hasattr(self.model, "add_model_tags"):
- self.model.add_model_tags(self._tag_names)
-
- def _create_model_from_path(self, model_path: str, args: SFTConfig) -> PreTrainedModel:
- """Creates a model from a path or model identifier."""
- model_init_kwargs = args.model_init_kwargs or {}
- # Handle torch dtype
- torch_dtype = model_init_kwargs.get("torch_dtype")
- if isinstance(torch_dtype, torch.dtype) or torch_dtype == "auto" or torch_dtype is None:
- pass # torch_dtype is already a torch.dtype or "auto" or None
- elif isinstance(torch_dtype, str): # it's a str, but not "auto"
- torch_dtype = getattr(torch, torch_dtype)
- model_init_kwargs["torch_dtype"] = torch_dtype
- else:
- raise ValueError(
- "Invalid `torch_dtype` passed to `SFTConfig`. Expected either 'auto' or a string representing "
- f"a `torch.dtype` (e.g., 'float32'), but got {torch_dtype}."
- )
- # Disable caching if gradient checkpointing is enabled (not supported)
- if args.gradient_checkpointing:
- model_init_kwargs["use_cache"] = False
-
- # Create model
- model = AutoModelForCausalLM.from_pretrained(model_path, **model_init_kwargs)
- return model
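For context, a minimal sketch (not part of the deleted file) of the string-to-dtype conversion the branch above performs; the helper name is hypothetical:

```python
import torch

def resolve_torch_dtype(torch_dtype):
    # Mirrors the branch above: accept a torch.dtype, "auto", None, or a dtype name.
    if isinstance(torch_dtype, torch.dtype) or torch_dtype in ("auto", None):
        return torch_dtype
    if isinstance(torch_dtype, str):
        return getattr(torch, torch_dtype)  # e.g. "bfloat16" -> torch.bfloat16
    raise ValueError(f"Invalid torch_dtype: {torch_dtype}")

assert resolve_torch_dtype("float16") is torch.float16
```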
-
- def _prepare_peft_model(self, model: PreTrainedModel, peft_config: Any, args: SFTConfig) -> PreTrainedModel:
- """Prepares a model for PEFT training."""
- if not is_peft_available():
- raise ImportError("To use PeftModel, you need to install the `peft` library.")
-
- if not isinstance(peft_config, PeftConfig):
- raise ValueError(
- f"Expected PeftConfig object but got {type(peft_config)}. If you want to use the PeftModel, you need "
- "to pass a PeftConfig object to the SFTTrainer."
- )
-
- if isinstance(model, PeftModel):
- return model
-
- # Handle quantized models (QLoRA)
- is_qlora = getattr(model, "is_loaded_in_4bit", False) or getattr(model, "is_loaded_in_8bit", False)
-
- is_sharded_qlora = False
- if getattr(model, "is_loaded_in_4bit", False):
- # Check if model is sharded (FSDP/DS-Zero3)
- for _, param in model.named_parameters():
- if param.__class__.__name__ == "Params4bit":
- is_sharded_qlora = param.data.device.type in {"cpu", "meta"}
- break
-
- # Prepare model for kbit training if needed
- if is_qlora and not is_sharded_qlora:
- model = self._prepare_model_for_kbit_training(model, args)
- # Disable gradient checkpointing as it's handled by prepare_model_for_kbit_training
- args = dataclasses.replace(args, gradient_checkpointing=False)
- elif args.gradient_checkpointing:
- model = self._enable_gradient_checkpointing(model, args)
-
- # Create PEFT model
- if (
- version.parse(peft.__version__) >= version.parse("0.12") # autocast_adapter_dtype introduced in 0.12
- and getattr(model, "is_loaded_in_4bit", False)
- and is_sharded_qlora
- ):
- model = get_peft_model(model, peft_config, autocast_adapter_dtype=False)
- else:
- model = get_peft_model(model, peft_config)
-
- # Handle bf16 casting for 4-bit models
- if args.bf16 and getattr(model, "is_loaded_in_4bit", False) and not is_sharded_qlora:
- peft_module_casting_to_bf16(model)
-
- return model
-
- def _prepare_model_for_kbit_training(self, model: PreTrainedModel, args: SFTConfig) -> PreTrainedModel:
- """Prepares a quantized model for kbit training."""
- prepare_model_kwargs = {
- "use_gradient_checkpointing": args.gradient_checkpointing,
- "gradient_checkpointing_kwargs": args.gradient_checkpointing_kwargs or {},
- }
-
- return prepare_model_for_kbit_training(model, **prepare_model_kwargs)
-
- def _enable_gradient_checkpointing(self, model: PreTrainedModel, args: SFTConfig) -> PreTrainedModel:
- """Enables gradient checkpointing for the model."""
- gradient_checkpointing_kwargs = args.gradient_checkpointing_kwargs or {}
- use_reentrant = (
- "use_reentrant" not in gradient_checkpointing_kwargs or gradient_checkpointing_kwargs["use_reentrant"]
- )
-
- if use_reentrant:
- if hasattr(model, "enable_input_require_grads"):
- model.enable_input_require_grads()
- else:
-
- def make_inputs_require_grad(module, input, output):
- output.requires_grad_(True)
-
- model.get_input_embeddings().register_forward_hook(make_inputs_require_grad)
-
- return model
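A brief note on the reentrant branch above: with reentrant checkpointing, at least one checkpointed input must require grad, or nothing flows back to trainable adapters. A self-contained sketch of the fallback hook, using a stand-in embedding (names illustrative, assuming a frozen base model):

```python
import torch
import torch.nn as nn

embedding = nn.Embedding(10, 4)          # stand-in for model.get_input_embeddings()
embedding.weight.requires_grad_(False)   # e.g. a frozen / quantized base model

def make_inputs_require_grad(module, input, output):
    # Force the embedding output into the autograd graph so reentrant
    # checkpointing sees at least one input that requires grad.
    output.requires_grad_(True)

handle = embedding.register_forward_hook(make_inputs_require_grad)
out = embedding(torch.tensor([[1, 2, 3]]))
assert out.requires_grad
handle.remove()
```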
-
- def _prepare_dataset(
- self,
- dataset: Union[Dataset, IterableDataset],
- processing_class,
- args,
- packing: bool,
- formatting_func: Optional[Callable[[dict], str]],
- dataset_name: str,
- ) -> Union[Dataset, IterableDataset]:
- # All Unsloth Zoo code licensed under LGPLv3
- try:
- if isinstance(dataset, ConstantLengthDataset): return dataset
- except:
- pass
-
- map_kwargs = {}
- use_desc = isinstance(dataset, Dataset)
- is_vlm = hasattr(processing_class, "tokenizer")
- tokenizer = processing_class
- if is_vlm: tokenizer = processing_class.tokenizer
-
- # Get max length
- max_seq_length = getattr(args, "max_length", 0)
- if max_seq_length == 0: max_seq_length = getattr(args, "max_seq_length", 0)
- if max_seq_length == 0: max_seq_length = getattr(self, "max_seq_length", 0)
- if max_seq_length == 0: max_seq_length = getattr(self, "max_seq", 0)
- if max_seq_length == 0: raise RuntimeError("Unsloth: max_seq_length is 0! Please specify one!")
- dataset_text_field = getattr(args, "dataset_text_field", "text")
- do_truncation = max_seq_length != 0
- do_formatting_func = False
- do_tokenize = True
-
- # Get correct column names
- column_names = set(next(iter(dataset)).keys())
- used_column_names = ["input_ids"]
- if "attention_mask" in column_names:
- used_column_names.append("attention_mask")
-
- # Check if already tokenized so skip
- from transformers import DataCollatorForSeq2Seq, DataCollatorForLanguageModeling
- if "labels" in column_names:
- # Most likely forgot data collator!
- if is_vlm and not hasattr(tokenizer, "pad"):
- # Check if processing_class has a .pad, if not, use tokenizer.tokenizer
- raise RuntimeError(f"Unsloth: {processing_class.__class__} does not have .pad!")
- self.data_collator = DataCollatorForSeq2Seq(tokenizer)
- used_column_names.append("labels")
- do_tokenize = False
- elif "input_ids" in column_names:
- # Skip dataset prep, and set data collator
- if is_vlm and not hasattr(tokenizer, "pad"):
- # Check if processing_class has a .pad, if not, use tokenizer.tokenizer
- raise RuntimeError(f"Unsloth: {processing_class.__class__} does not have .pad!")
- self.data_collator = DataCollatorForLanguageModeling(tokenizer, mlm = False)
- do_tokenize = False
- elif dataset_text_field not in column_names:
- do_formatting_func = True
- if formatting_func is None:
- raise RuntimeError("Unsloth: You must specify a `formatting_func`")
- pass
-
- if do_tokenize:
- # Check double BOS tokens
- if do_formatting_func:
- test_text = formatting_func(next(iter(dataset)))
- if not isinstance(test_text, list):
- raise ValueError(
- "Unsloth: The `formatting_func` should return a list of processed strings."
- )
- test_text = test_text[0]
- else:
- test_text = next(iter(dataset))[dataset_text_field][0]
-
- # Get chat template
- chat_template = getattr(processing_class, 'chat_template', '')
- if chat_template == '' and is_vlm:
- chat_template = getattr(tokenizer, 'chat_template', '')
- if chat_template is None:
- chat_template = ''
-
- # Get bos_token
- add_special_tokens = True
- bos_token_1 = getattr(processing_class, 'bos_token', None)
- bos_token_2 = getattr(tokenizer, 'bos_token', None)
- bos_token = bos_token_1 or bos_token_2
-
- if bos_token is not None:
- if test_text.startswith(bos_token) or bos_token in chat_template:
- add_special_tokens = False
- print("Unsloth: We found double BOS tokens - we shall remove one automatically.")
- pass
-
- # Create tokenize function
- def _tokenize(example):
- return tokenizer(
- example[dataset_text_field] if not do_formatting_func else formatting_func(example),
- truncation = do_truncation,
- max_length = max_seq_length,
- return_token_type_ids = False,
- add_special_tokens = add_special_tokens,
- )
- pass
-
- if not isinstance(dataset, IterableDataset):
- map_kwargs["num_proc"] = getattr(args, "dataset_num_proc", 2)
- else:
- map_kwargs["batch_size"] = dataset._ex_iterable.batch_size
-
- if use_desc: map_kwargs["desc"] = f'Unsloth: Tokenizing ["{dataset_text_field}"]'
- dataset = dataset.map(_tokenize, batched = True, **map_kwargs)
-
- # If VLM, switch data collator since .pad is needed!
- if is_vlm and not hasattr(processing_class, "pad"):
- data_collator = DataCollatorForLanguageModeling(tokenizer, mlm = False)
- self.data_collator = data_collator
- pass
- pass
- if packing:
- print("Unsloth: Hugging Face's packing is currently buggy - we're disabling it for now!")
- return dataset
-
- # NOTE: with the early return above, the packing path below is unreachable
- if max_seq_length == 0:
- raise ValueError("When packing is enabled, `max_seq_length` can't be `None`.")
-
- if use_desc: map_kwargs["desc"] = f"Unsloth: Packing {dataset_name} dataset"
- dataset = dataset.select_columns(used_column_names).map(
- pack_examples,
- batched = True,
- fn_kwargs = {"seq_length": max_seq_length,},
- **map_kwargs,
- )
- pass
- return dataset
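For orientation, a rough sketch (not from the deleted file) of what `pack_examples`-style packing does: concatenate tokenized rows and reslice them into fixed-length windows. TRL's actual implementation may differ in details such as remainder handling:

```python
def pack_examples_sketch(examples, seq_length):
    # examples: a columnar batch, e.g. {"input_ids": [[...], [...], ...]}
    packed = {}
    for key, rows in examples.items():
        flat = [tok for row in rows for tok in row]
        packed[key] = [flat[i : i + seq_length] for i in range(0, len(flat), seq_length)]
    return packed

batch = {"input_ids": [[1, 2, 3], [4, 5], [6, 7, 8, 9]]}
print(pack_examples_sketch(batch, seq_length=8))
# {'input_ids': [[1, 2, 3, 4, 5, 6, 7, 8], [9]]}
```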
-
- def compute_loss(self, model, inputs, return_outputs = False, num_items_in_batch = None):
- outputs = super().compute_loss(
- model,
- inputs,
- return_outputs = return_outputs,
- num_items_in_batch = num_items_in_batch,
- )
- return outputs
-
- def log(self, logs: dict[str, float], start_time: Optional[float] = None) -> None:
- mode = "eval" if self.control.should_evaluate else "train"
- metrics = {key: sum(val) / len(val) for key, val in self._metrics[mode].items()} # average the metrics
-
- # This method can be called both in training and evaluation. When called in evaluation, the keys in `logs`
- # start with "eval_". We need to add the prefix "eval_" to the keys in `metrics` to match the format.
- if mode == "eval":
- metrics = {f"eval_{key}": val for key, val in metrics.items()}
-
- logs = {**logs, **metrics}
- if version.parse(transformers.__version__) >= version.parse("4.47.0.dev0"):
- super().log(logs, start_time)
- else: # transformers<=4.46
- super().log(logs)
- self._metrics[mode].clear()
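A small illustration (not from the deleted file) of the averaging in `log` above: each metric key accumulates per-step values, and each logging call reports their mean, then clears the buffer:

```python
from collections import defaultdict

_metrics = {"train": defaultdict(list), "eval": defaultdict(list)}
_metrics["train"]["mean_token_accuracy"] += [0.5, 0.6, 0.7]

mode = "train"
averaged = {key: sum(val) / len(val) for key, val in _metrics[mode].items()}
print(averaged)  # {'mean_token_accuracy': 0.6} (up to float rounding)
_metrics[mode].clear()
```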
-
- def create_model_card(
- self,
- model_name: Optional[str] = None,
- dataset_name: Optional[str] = None,
- tags: Union[str, list[str], None] = None,
- ):
- """
- Creates a draft of a model card using the information available to the `Trainer`.
-
- Args:
- model_name (`str` or `None`, *optional*, defaults to `None`):
- Name of the model.
- dataset_name (`str` or `None`, *optional*, defaults to `None`):
- Name of the dataset used for training.
- tags (`str`, `list[str]` or `None`, *optional*, defaults to `None`):
- Tags to be associated with the model card.
- """
- if not self.is_world_process_zero():
- return
-
- if hasattr(self.model.config, "_name_or_path") and not os.path.isdir(self.model.config._name_or_path):
- base_model = self.model.config._name_or_path
- else:
- base_model = None
-
- tags = tags or []
- if isinstance(tags, str):
- tags = [tags]
-
- if hasattr(self.model.config, "unsloth_version"):
- tags.append("unsloth")
-
- model_card = generate_model_card(
- base_model=base_model,
- model_name=model_name,
- hub_model_id=self.hub_model_id,
- dataset_name=dataset_name,
- tags=tags,
- wandb_url=wandb.run.get_url() if is_wandb_available() and wandb.run is not None else None,
- comet_url=get_comet_experiment_url(),
- trainer_name="SFT",
- )
-
- model_card.save(os.path.join(self.args.output_dir, "README.md"))
- class UnslothSFTTrainer(_UnslothSFTTrainer):
- """
-
- Trainer for the Supervised Fine-Tuning (SFT) method.
-
- This class is a wrapper around the [`transformers.Trainer`] class and inherits all of its attributes and methods.
-
- Example:
-
- ```python
- from datasets import load_dataset
- from trl import SFTTrainer
-
- dataset = load_dataset("roneneldan/TinyStories", split="train[:1%]")
-
- trainer = SFTTrainer(model="Qwen/Qwen2-0.5B-Instruct", train_dataset=dataset)
- trainer.train()
- ```
-
- Args:
- model (`Union[str, PreTrainedModel]`):
- Model to be trained. Can be either:
-
- - A string, being the *model id* of a pretrained model hosted inside a model repo on huggingface.co, or
- a path to a *directory* containing model weights saved using
- [`~transformers.PreTrainedModel.save_pretrained`], e.g., `'./my_model_directory/'`. The model is
- loaded using [`~transformers.AutoModelForCausalLM.from_pretrained`] with the keyword arguments
- in `args.model_init_kwargs`.
- - A [`~transformers.PreTrainedModel`] object. Only causal language models are supported.
- args ([`SFTConfig`], *optional*, defaults to `None`):
- Configuration for this trainer. If `None`, a default configuration is used.
- data_collator (`DataCollator`, *optional*):
- Function to use to form a batch from a list of elements of the processed `train_dataset` or `eval_dataset`.
- Will default to [`~transformers.default_data_collator`] if no `processing_class` is provided, an instance
- of [`~transformers.DataCollatorWithPadding`] otherwise if the processing_class is a feature extractor or
- tokenizer.
- train_dataset ([`~datasets.Dataset`] or [`~datasets.IterableDataset`]):
- Dataset to use for training. SFT supports both [language modeling](#language-modeling) type and
- [prompt-completion](#prompt-completion) type. The format of the samples can be either:
-
- - [Standard](dataset_formats#standard): Each sample contains plain text.
- - [Conversational](dataset_formats#conversational): Each sample contains structured messages (e.g., role
- and content).
-
- The trainer also supports processed datasets (tokenized) as long as they contain an `input_ids` field.
- eval_dataset ([`~datasets.Dataset`], [`~datasets.IterableDataset`] or `dict[str, Union[Dataset, IterableDataset]]`):
- Dataset to use for evaluation. It must meet the same requirements as `train_dataset`.
- processing_class ([`~transformers.PreTrainedTokenizerBase`], *optional*, defaults to `None`):
- Processing class used to process the data. If `None`, the processing class is loaded from the model's name
- with [`~transformers.AutoTokenizer.from_pretrained`].
- callbacks (list of [`~transformers.TrainerCallback`], *optional*, defaults to `None`):
- List of callbacks to customize the training loop. Will add those to the list of default callbacks
- detailed [here](https://huggingface.co/docs/transformers/main_classes/callback).
-
- If you want to remove one of the default callbacks used, use the [`~transformers.Trainer.remove_callback`]
- method.
- optimizers (`tuple[torch.optim.Optimizer, torch.optim.lr_scheduler.LambdaLR]`, *optional*, defaults to `(None, None)`):
- A tuple containing the optimizer and the scheduler to use. Will default to an instance of [`AdamW`] on your
- model and a scheduler given by [`get_linear_schedule_with_warmup`] controlled by `args`.
- optimizer_cls_and_kwargs (`Tuple[Type[torch.optim.Optimizer], Dict[str, Any]]`, *optional*, defaults to `None`):
- A tuple containing the optimizer class and keyword arguments to use.
- Overrides `optim` and `optim_args` in `args`. Incompatible with the `optimizers` argument.
-
- Unlike `optimizers`, this argument avoids the need to place model parameters on the correct devices before initializing the Trainer.
- preprocess_logits_for_metrics (`Callable[[torch.Tensor, torch.Tensor], torch.Tensor]`, *optional*, defaults to `None`):
- A function that preprocesses the logits right before caching them at each evaluation step. Must take two
- tensors, the logits and the labels, and return the logits once processed as desired. The modifications made
- by this function will be reflected in the predictions received by `compute_metrics`.
-
- Note that the labels (second parameter) will be `None` if the dataset does not have them.
- peft_config ([`~peft.PeftConfig`], *optional*, defaults to `None`):
- PEFT configuration used to wrap the model. If `None`, the model is not wrapped.
- formatting_func (`Optional[Callable]`):
- Formatting function applied to the dataset before tokenization.
-
- """
- def __init__(
- self,
- model,
- args = None,
- data_collator = None,
- train_dataset = None,
- eval_dataset = None,
- processing_class = None,
- compute_loss_func = None,
- compute_metrics = None,
- callbacks = None,
- optimizer_cls_and_kwargs = None,
- preprocess_logits_for_metrics = None,
- peft_config = None,
- formatting_func = None,
- **kwargs
- ):
- if args is None: args = UnslothSFTConfig()
- use_bf16 = getattr(args, 'bf16', False)
- if type(use_bf16) is not bool: use_bf16 = False
- use_fp16 = getattr(args, 'fp16', False)
- if type(use_fp16) is not bool: use_fp16 = False
- force_float32 = False
- if os.environ.get('UNSLOTH_FORCE_FLOAT32', '0') == '1':
- print('Unsloth: Switching to float32 training since model cannot work with float16')
- force_float32 = True
- mixed_precision_dtype = os.environ.get('UNSLOTH_MIXED_PRECISION', 'float32')
- dtype = getattr(model.config, 'torch_dtype', None)
- if dtype is None: dtype = model.get_input_embeddings().dtype
- from unsloth_zoo.utils import _get_dtype
- dtype = _get_dtype(dtype)
- float16 = dtype == torch.float16
- if not force_float32 and (float16 and use_bf16): raise TypeError('Unsloth: Model is in float16 precision but you want to use bfloat16 precision. Set fp16 to `True` and bf16 to `False`')
- if not force_float32 and (not float16 and use_fp16): raise TypeError('Unsloth: Model is in bfloat16 precision but you want to use float16 precision. Set fp16 to `False` and bf16 to `True`')
- if force_float32:
- args.fp16 = False
- args.bf16 = False
- os.environ['ACCELERATE_MIXED_PRECISION'] = 'no'
- elif (not use_bf16 and not use_fp16) and mixed_precision_dtype == 'float32':
- args.fp16 = float16
- args.bf16 = not float16
- os.environ['ACCELERATE_MIXED_PRECISION'] = 'fp16' if float16 else 'bf16'
- if getattr(args, 'eval_dataset', None) is not None and getattr(args, 'eval_strategy', 'no') == 'no':
- args.eval_strategy = 'steps'
- if getattr(args, 'eval_steps', None) is None: args.eval_steps = 0.1
- ga_steps = getattr(args, 'gradient_accumulation_steps', None)
- if ga_steps is not None and ga_steps > 1:
- from transformers import __version__ as transformers_version
- if Version(transformers_version) <= Version('4.45.2'):
- print('**** Unsloth: Please use our fixed gradient_accumulation_steps by updating transformers, TRL and Unsloth!\n'
- '`pip install --upgrade --no-cache-dir --force-reinstall --no-deps unsloth transformers trl unsloth_zoo`')
- if getattr(args, 'eval_strategy', 'no') != 'no':
- eval_bsz = getattr(args, 'per_device_eval_batch_size', 8)
- if eval_bsz == 8 and args.per_device_train_batch_size < eval_bsz: args.per_device_eval_batch_size = args.per_device_train_batch_size
- if getattr(args, 'eval_accumulation_steps', None) is None and ga_steps is not None: args.eval_accumulation_steps = ga_steps
- fp16_full_eval = getattr(args, 'fp16_full_eval', False)
- if type(fp16_full_eval) is not bool: fp16_full_eval = False
- bf16_full_eval = getattr(args, 'bf16_full_eval', False)
- if type(bf16_full_eval) is not bool: bf16_full_eval = False
- if args.fp16 and bf16_full_eval: args.bf16_full_eval = False; args.fp16_full_eval = True
- if args.bf16 and fp16_full_eval: args.bf16_full_eval = True; args.fp16_full_eval = False
- if force_float32:
- args.bf16_full_eval = False
- args.fp16_full_eval = False
- elif os.environ.get('UNSLOTH_MIXED_PRECISION', 'float32') == 'bfloat16':
- args.bf16_full_eval = True
- args.fp16_full_eval = False
- elif not bf16_full_eval and not fp16_full_eval:
- args.bf16_full_eval = args.bf16
- args.fp16_full_eval = args.fp16
- _output_logits = False
- if locals().get('compute_metrics', None) is not None: _output_logits = True
- if locals().get('preprocess_logits_for_metrics', None) is not None: _output_logits = True
- if _output_logits:
- os.environ['UNSLOTH_RETURN_LOGITS'] = '1'
- if 'max_seq_length' not in locals() and not hasattr(args, 'max_seq_length'):
- pass
- else:
- model_max_seq_length = getattr(model, 'max_seq_length', None)
- args_max_seq_length = getattr(args, 'max_seq_length', None)
- if args_max_seq_length is None and model_max_seq_length is not None:
- max_seq_length = model.max_seq_length
- if hasattr(args, 'max_seq_length'): args.max_seq_length = max_seq_length
- if 'max_length' not in locals() and not hasattr(args, 'max_length'):
- pass
- else:
- if hasattr(args, 'max_seq_length') and args.max_seq_length is not None and args.max_seq_length > 0:
- if hasattr(args, 'max_length'):
- args.max_length = args.max_seq_length
- max_length = args.max_length
- else:
- model_max_length = getattr(model, 'max_seq_length', None)
- # print(model_max_length, 'mml1')
- if model_max_length is None: model_max_length = getattr(model, 'max_length', None)
- # print(model_max_length, 'mml2')
- if model_max_length is not None:
- args.max_length = model_max_length
- max_length = args.max_length
- elif hasattr(args, 'max_length') and args.max_length is not None:
- max_length = args.max_length
- # if we are here, then we are in a weird case where max_length is set but max_seq_length is not set
- setattr(model, 'max_seq_length', max_length)
- else:
- print('Unsloth: We did not find `max_seq_length` or `max_length` in the model or args. We will set it to 1024.')
- args.max_length = 1024
- if model is not None and hasattr(model, 'for_training'):
- model.for_training()
- if 'tokenizer' in locals() and hasattr(tokenizer, 'padding_side'): tokenizer.padding_side = 'right'
- if 'processing_class' in locals():
- if hasattr(processing_class, 'padding_side'): processing_class.padding_side = 'right'
- if hasattr(processing_class, 'tokenizer') and hasattr(processing_class.tokenizer, 'padding_side'): processing_class.tokenizer.padding_side = 'right'
- __tokenizer = processing_class if 'processing_class' in locals() else tokenizer
- from unsloth_zoo.vision_utils import UnslothVisionDataCollator
- if not isinstance(data_collator, UnslothVisionDataCollator):
- if isinstance(data_collator, DataCollatorForSeq2Seq) and 'labels' not in train_dataset.column_names:
- data_collator = TransformersDataCollatorForLanguageModeling(__tokenizer, mlm = False, mlm_probability = 0.0)
- elif isinstance(data_collator, TransformersDataCollatorForLanguageModeling) and 'labels' in train_dataset.column_names:
- data_collator = DataCollatorForSeq2Seq(__tokenizer)
- else:
- if hasattr(args, 'remove_unused_columns'): args.remove_unused_columns = False
- if hasattr(args, 'dataset_text_field'): args.dataset_text_field = ''
- if hasattr(args, 'dataset_kwargs'): args.dataset_kwargs = {'skip_prepare_dataset': True}
- if not isinstance(data_collator, UnslothVisionDataCollator):
- if not hasattr(__tokenizer, 'pad') and hasattr(__tokenizer, 'tokenizer'):
- if isinstance(data_collator, DataCollatorForSeq2Seq):
- data_collator = DataCollatorForSeq2Seq(__tokenizer.tokenizer)
- else:
- data_collator = TransformersDataCollatorForLanguageModeling(__tokenizer.tokenizer, mlm = False, mlm_probability = 0.0)
- other_metrics = []
-
- from unsloth_zoo.logging_utils import PatchRLStatistics
- PatchRLStatistics('sft_trainer', other_metrics)
- IGNORED_TOKENIZER_NAMES = os.environ.get('UNSLOTH_IGNORED_TOKENIZER_NAMES', '').split('\n')
- from unsloth_zoo.tokenizer_utils import fix_untrained_tokens
- from unsloth_zoo.training_utils import fix_zero_training_loss
- if 'tokenizer' not in locals(): tokenizer = processing_class
- fix_untrained_tokens(model, tokenizer, train_dataset, IGNORED_TOKENIZER_NAMES, eps = 1e-16)
- fix_zero_training_loss(model, tokenizer, train_dataset)
-
- super().__init__(
- model = model,
- args = args,
- data_collator = data_collator,
- train_dataset = train_dataset,
- eval_dataset = eval_dataset,
- processing_class = processing_class,
- compute_loss_func = compute_loss_func,
- compute_metrics = compute_metrics,
- callbacks = callbacks,
- optimizer_cls_and_kwargs = optimizer_cls_and_kwargs,
- preprocess_logits_for_metrics = preprocess_logits_for_metrics,
- peft_config = peft_config,
- formatting_func = formatting_func,**kwargs)
- if hasattr(self, 'neftune_hook_handle'):
- self.neftune_hook_handle.remove()
- if hasattr(self, 'neftune_hook_handle'): del self.neftune_hook_handle
- if getattr(args, 'neftune_noise_alpha', None) is not None:
- model.get_input_embeddings().neftune_noise_alpha = self.neftune_noise_alpha
- pass
-
- pass
test_run_uploads/UnslothXPOTrainer.py DELETED
@@ -1,1024 +0,0 @@
- """
- 2025.7.11
- 2025.7.11
- 4.54.1
- 0.16.1
- __UNSLOTH_VERSIONING__
- """
- from torch import Tensor
- import torch
- import torch.nn as nn
- from torch.nn import functional as F
- from typing import Any, List, Optional, Tuple, Union, Dict, Set, Callable
- from trl.trainer.xpo_trainer import (Any, BaseImageProcessor, BasePairwiseJudge, Callable, Dataset, EvalPrediction, F, FeatureExtractionMixin, IterableDataset, OnlineDPOTrainer, OptimizerNames, Optional, PreTrainedModel, PreTrainedTokenizerBase, ProcessorMixin, SIMPLE_CHAT_TEMPLATE, TrainerCallback, Union, XPOConfig, XPOTrainer, empty_cache, generate_model_card, get_comet_experiment_url, get_reward, is_conversational, is_wandb_available, jinja2, maybe_apply_chat_template, nn, os, selective_log_softmax, textwrap, torch, truncate_right, unwrap_model_for_generation, wandb)
-
-
- import os
- from typing import *
- from dataclasses import dataclass, field
- from packaging.version import Version
- import torch
- import numpy as np
- from contextlib import nullcontext
- from torch.nn import functional as F
- from transformers import DataCollatorForSeq2Seq, DataCollatorForLanguageModeling as TransformersDataCollatorForLanguageModeling
-
- torch_compile_options = {
- "epilogue_fusion" : True,
- "max_autotune" : False,
- "shape_padding" : True,
- "trace.enabled" : False,
- "triton.cudagraphs" : False,
- }
-
- @torch.compile(dynamic = True, fullgraph = True, options = torch_compile_options,)
- def chunked_selective_log_softmax(logits, index):
- # Split into 4 chunks only
- chunked_logits = torch.chunk(logits.reshape(-1, logits.shape[-1]), chunks = 4, dim = 0)
- chunked_index = torch.chunk(index.reshape(-1), chunks = 4, dim = 0)
- all_per_token_logps = []
- # Below loop does the same as selective_log_softmax(chunk_logits, chunk_index)
- for chunk_logits, chunk_index in zip(chunked_logits, chunked_index):
- chunk_logits = chunk_logits.to(torch.float32)
- selected_logits = torch.gather(chunk_logits, dim = -1, index = chunk_index.unsqueeze(-1)).squeeze(-1)
- logsumexp_values = torch.logsumexp(chunk_logits, dim = -1)
- per_token_logps = selected_logits - logsumexp_values
- all_per_token_logps.append(per_token_logps)
- pass
- all_per_token_logps = torch.concat(all_per_token_logps)
- all_per_token_logps = all_per_token_logps.reshape((logits.shape[0], logits.shape[1]))
- return all_per_token_logps
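For reference, a minimal sketch (not part of the deleted file) of the quantity `chunked_selective_log_softmax` computes, namely the per-token log-probabilities of the chosen token ids; chunking only bounds the peak float32 softmax workspace, the numbers match the direct computation. Shapes below are hypothetical:

```python
import torch

logits = torch.randn(2, 5, 32)            # (batch, seq, vocab)
index = torch.randint(0, 32, (2, 5))      # (batch, seq) chosen token ids

expected = (
    torch.log_softmax(logits.float(), dim=-1)
    .gather(-1, index.unsqueeze(-1))
    .squeeze(-1)
)
# chunked = chunked_selective_log_softmax(logits, index)
# torch.testing.assert_close(chunked, expected)
```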
- @dataclass
- class UnslothXPOConfig(XPOConfig):
- """
-
- Configuration class for the [`XPOTrainer`].
-
- Subclass of [`OnlineDPOConfig`]; we can use all of its arguments and add the following:
-
- Parameters:
- alpha (`float` or `list[float]`, *optional*, defaults to `1e-5`):
- Weight of the XPO loss term. If a list of floats is provided, the alpha for each epoch is taken from
- the list in order, and the last value is reused for any remaining epochs.
-
- """
- vllm_sampling_params: Optional[Any] = field(
- default = None,
- metadata = {'help': 'vLLM SamplingParams'},
- )
- unsloth_num_chunks : Optional[int] = field(
- default = -1,
- metadata = {'help': 'Chunk size to reduce memory usage. -1 is most efficient.'},
- )
- def __init__(
- self,
- output_dir = None,
- overwrite_output_dir = None,
- do_train = False,
- do_eval = False,
- do_predict = False,
- eval_strategy = 'no',
- prediction_loss_only = False,
- per_device_train_batch_size = 4,
- per_device_eval_batch_size = 4,
- per_gpu_train_batch_size = None,
- per_gpu_eval_batch_size = None,
- gradient_accumulation_steps = 2,
- eval_accumulation_steps = 2,
- eval_delay = 0,
- torch_empty_cache_steps = 250,
- learning_rate = 5e-05,
- weight_decay = 0.01,
- adam_beta1 = 0.9,
- adam_beta2 = 0.999,
- adam_epsilon = 1e-08,
- max_grad_norm = 1.0,
- num_train_epochs = 3.0,
- max_steps = -1,
- lr_scheduler_type = 'linear',
- warmup_ratio = 0.1,
- warmup_steps = 0,
- log_level = 'passive',
- log_level_replica = 'warning',
- log_on_each_node = True,
- logging_dir = None,
- logging_strategy = 'steps',
- logging_first_step = False,
- logging_steps = 1,
- logging_nan_inf_filter = False,
- save_strategy = 'steps',
- save_steps = 500,
- save_total_limit = None,
- save_safetensors = True,
- save_on_each_node = False,
- save_only_model = False,
- restore_callback_states_from_checkpoint = False,
- no_cuda = False,
- use_cpu = False,
- use_mps_device = False,
- seed = 3407,
- data_seed = 3407,
- jit_mode_eval = False,
- use_ipex = False,
- bf16 = False,
- fp16 = False,
- fp16_opt_level = 'O1',
- half_precision_backend = 'auto',
- bf16_full_eval = False,
- fp16_full_eval = False,
- tf32 = None,
- local_rank = -1,
- ddp_backend = None,
- tpu_num_cores = None,
- tpu_metrics_debug = False,
- debug = '',
- dataloader_drop_last = False,
- eval_steps = None,
- dataloader_num_workers = 0,
- dataloader_prefetch_factor = None,
- past_index = -1,
- run_name = None,
- disable_tqdm = None,
- remove_unused_columns = True,
- label_names = None,
- load_best_model_at_end = False,
- metric_for_best_model = None,
- greater_is_better = None,
- ignore_data_skip = False,
- fsdp = '',
- fsdp_min_num_params = 0,
- fsdp_config = None,
- fsdp_transformer_layer_cls_to_wrap = None,
- accelerator_config = None,
- deepspeed = None,
- label_smoothing_factor = 0.0,
- optim = 'adamw_8bit',
- optim_args = None,
- adafactor = False,
- group_by_length = False,
- length_column_name = 'length',
- report_to = None,
- ddp_find_unused_parameters = None,
- ddp_bucket_cap_mb = None,
- ddp_broadcast_buffers = None,
- dataloader_pin_memory = True,
- dataloader_persistent_workers = False,
- skip_memory_metrics = True,
- use_legacy_prediction_loop = False,
- push_to_hub = False,
- resume_from_checkpoint = None,
- hub_model_id = None,
- hub_strategy = 'every_save',
- hub_token = None,
- hub_private_repo = None,
- hub_always_push = False,
- hub_revision = None,
- gradient_checkpointing = False,
- gradient_checkpointing_kwargs = None,
- include_inputs_for_metrics = False,
- eval_do_concat_batches = True,
- fp16_backend = 'auto',
- push_to_hub_model_id = None,
- push_to_hub_organization = None,
- push_to_hub_token = None,
- mp_parameters = '',
- auto_find_batch_size = True,
- full_determinism = False,
- torchdynamo = None,
- ray_scope = 'last',
- ddp_timeout = 1800,
- torch_compile = False,
- torch_compile_backend = None,
- torch_compile_mode = None,
- include_tokens_per_second = False,
- include_num_input_tokens_seen = False,
- neftune_noise_alpha = None,
- optim_target_modules = None,
- batch_eval_metrics = False,
- eval_on_start = False,
- use_liger_kernel = False,
- liger_kernel_config = None,
- eval_use_gather_object = False,
- average_tokens_across_devices = True,
- reward_model_path = None,
- judge = None,
- max_new_tokens = 64,
- max_length = 512,
- temperature = 0.9,
- missing_eos_penalty = None,
- loss_type = 'sigmoid',
- dataset_num_proc = None,
- disable_dropout = True,
- use_vllm = False,
- ds3_gather_for_generation = True,
- vllm_sampling_params = None,
- unsloth_num_chunks = -1,
- **kwargs,
- ):
- if learning_rate < 1e-7: raise FloatingPointError(f'Unsloth: Your learning rate of `{learning_rate}` is too small and less than 1e-7! Consider increasing it, otherwise gradient updates will be close to 0!')
- if learning_rate > 1: raise OverflowError(f'Unsloth: Your learning rate of `{learning_rate}` is way too large (> 1)! Consider decreasing it to 1e-1, otherwise gradient updates will explode!')
- if output_dir is None and save_strategy == 'steps' and save_steps == 500:
- output_dir = 'unsloth_training_checkpoints'
- save_strategy = 'no'
- if dataset_num_proc is None:
- from multiprocessing import cpu_count
- dataset_num_proc = min(cpu_count()*2, 2)
- if temperature <= 0:
- raise ValueError('Unsloth: Please set a positive non-zero temperature since your results will be wrong.') # was MathError, which is not defined anywhere
- elif temperature >= 10:
- raise ValueError('Unsloth: Please set a positive non-zero temperature less than 10, since sampling will be quite erratic.') # was MathError, which is not defined anywhere
-
-
- super().__init__(
- output_dir = output_dir,
- overwrite_output_dir = overwrite_output_dir,
- do_train = do_train,
- do_eval = do_eval,
- do_predict = do_predict,
- eval_strategy = eval_strategy,
- prediction_loss_only = prediction_loss_only,
- per_device_train_batch_size = per_device_train_batch_size,
- per_device_eval_batch_size = per_device_eval_batch_size,
- per_gpu_train_batch_size = per_gpu_train_batch_size,
- per_gpu_eval_batch_size = per_gpu_eval_batch_size,
- gradient_accumulation_steps = gradient_accumulation_steps,
- eval_accumulation_steps = eval_accumulation_steps,
- eval_delay = eval_delay,
- torch_empty_cache_steps = torch_empty_cache_steps,
- learning_rate = learning_rate,
- weight_decay = weight_decay,
- adam_beta1 = adam_beta1,
- adam_beta2 = adam_beta2,
- adam_epsilon = adam_epsilon,
- max_grad_norm = max_grad_norm,
- num_train_epochs = num_train_epochs,
- max_steps = max_steps,
- lr_scheduler_type = lr_scheduler_type,
- warmup_ratio = warmup_ratio,
- warmup_steps = warmup_steps,
- log_level = log_level,
- log_level_replica = log_level_replica,
- log_on_each_node = log_on_each_node,
- logging_dir = logging_dir,
- logging_strategy = logging_strategy,
- logging_first_step = logging_first_step,
- logging_steps = logging_steps,
- logging_nan_inf_filter = logging_nan_inf_filter,
- save_strategy = save_strategy,
- save_steps = save_steps,
- save_total_limit = save_total_limit,
- save_safetensors = save_safetensors,
- save_on_each_node = save_on_each_node,
- save_only_model = save_only_model,
- restore_callback_states_from_checkpoint = restore_callback_states_from_checkpoint,
- no_cuda = no_cuda,
- use_cpu = use_cpu,
- use_mps_device = use_mps_device,
- seed = seed,
- data_seed = data_seed,
- jit_mode_eval = jit_mode_eval,
- use_ipex = use_ipex,
- bf16 = bf16,
- fp16 = fp16,
- fp16_opt_level = fp16_opt_level,
- half_precision_backend = half_precision_backend,
- bf16_full_eval = bf16_full_eval,
- fp16_full_eval = fp16_full_eval,
- tf32 = tf32,
- local_rank = local_rank,
- ddp_backend = ddp_backend,
- tpu_num_cores = tpu_num_cores,
- tpu_metrics_debug = tpu_metrics_debug,
- debug = debug,
- dataloader_drop_last = dataloader_drop_last,
- eval_steps = eval_steps,
- dataloader_num_workers = dataloader_num_workers,
- dataloader_prefetch_factor = dataloader_prefetch_factor,
- past_index = past_index,
- run_name = run_name,
- disable_tqdm = disable_tqdm,
- remove_unused_columns = remove_unused_columns,
- label_names = label_names,
- load_best_model_at_end = load_best_model_at_end,
- metric_for_best_model = metric_for_best_model,
- greater_is_better = greater_is_better,
- ignore_data_skip = ignore_data_skip,
- fsdp = fsdp,
- fsdp_min_num_params = fsdp_min_num_params,
- fsdp_config = fsdp_config,
- fsdp_transformer_layer_cls_to_wrap = fsdp_transformer_layer_cls_to_wrap,
- accelerator_config = accelerator_config,
- deepspeed = deepspeed,
- label_smoothing_factor = label_smoothing_factor,
- optim = optim,
- optim_args = optim_args,
- adafactor = adafactor,
- group_by_length = group_by_length,
- length_column_name = length_column_name,
- report_to = report_to,
- ddp_find_unused_parameters = ddp_find_unused_parameters,
- ddp_bucket_cap_mb = ddp_bucket_cap_mb,
- ddp_broadcast_buffers = ddp_broadcast_buffers,
- dataloader_pin_memory = dataloader_pin_memory,
- dataloader_persistent_workers = dataloader_persistent_workers,
- skip_memory_metrics = skip_memory_metrics,
- use_legacy_prediction_loop = use_legacy_prediction_loop,
- push_to_hub = push_to_hub,
- resume_from_checkpoint = resume_from_checkpoint,
- hub_model_id = hub_model_id,
- hub_strategy = hub_strategy,
- hub_token = hub_token,
- hub_private_repo = hub_private_repo,
- hub_always_push = hub_always_push,
- hub_revision = hub_revision,
- gradient_checkpointing = gradient_checkpointing,
- gradient_checkpointing_kwargs = gradient_checkpointing_kwargs,
- include_inputs_for_metrics = include_inputs_for_metrics,
- eval_do_concat_batches = eval_do_concat_batches,
- fp16_backend = fp16_backend,
- push_to_hub_model_id = push_to_hub_model_id,
- push_to_hub_organization = push_to_hub_organization,
- push_to_hub_token = push_to_hub_token,
- mp_parameters = mp_parameters,
- auto_find_batch_size = auto_find_batch_size,
- full_determinism = full_determinism,
- torchdynamo = torchdynamo,
- ray_scope = ray_scope,
- ddp_timeout = ddp_timeout,
- torch_compile = torch_compile,
- torch_compile_backend = torch_compile_backend,
- torch_compile_mode = torch_compile_mode,
- include_tokens_per_second = include_tokens_per_second,
- include_num_input_tokens_seen = include_num_input_tokens_seen,
- neftune_noise_alpha = neftune_noise_alpha,
- optim_target_modules = optim_target_modules,
- batch_eval_metrics = batch_eval_metrics,
- eval_on_start = eval_on_start,
- use_liger_kernel = use_liger_kernel,
- liger_kernel_config = liger_kernel_config,
- eval_use_gather_object = eval_use_gather_object,
- average_tokens_across_devices = average_tokens_across_devices,
- reward_model_path = reward_model_path,
- judge = judge,
- max_new_tokens = max_new_tokens,
- max_length = max_length,
- temperature = temperature,
- missing_eos_penalty = missing_eos_penalty,
- loss_type = loss_type,
- dataset_num_proc = dataset_num_proc,
- disable_dropout = disable_dropout,
- use_vllm = use_vllm,
- ds3_gather_for_generation = ds3_gather_for_generation,**kwargs)
- self.vllm_sampling_params = vllm_sampling_params
- self.unsloth_num_chunks = unsloth_num_chunks
- pass
-
- class _UnslothXPOTrainer(OnlineDPOTrainer):
- r""""""
-
- _tag_names = ["trl", "xpo"]
-
- def __init__(
- self,
- model: Union[PreTrainedModel, nn.Module] = None,
- ref_model: Union[PreTrainedModel, nn.Module] = None,
- reward_model: Optional[nn.Module] = None,
- judge: Optional[BasePairwiseJudge] = None,
- args: Optional[XPOConfig] = None,
- data_collator: Optional[Callable] = None,
- train_dataset: Optional[Union[Dataset, IterableDataset]] = None,
- eval_dataset: Optional[Union[Dataset, dict[str, Dataset]]] = None,
- processing_class: Optional[
- Union[PreTrainedTokenizerBase, BaseImageProcessor, FeatureExtractionMixin, ProcessorMixin]
- ] = None,
- peft_config: Optional[dict] = None,
- compute_metrics: Optional[Callable[[EvalPrediction], dict]] = None,
- callbacks: Optional[list[TrainerCallback]] = None,
- optimizers: tuple[torch.optim.Optimizer, torch.optim.lr_scheduler.LambdaLR] = (None, None),
- preprocess_logits_for_metrics: Optional[Callable[[torch.Tensor, torch.Tensor], torch.Tensor]] = None,
- ) -> None:
- super().__init__(
- model=model,
- ref_model=ref_model,
- judge=judge,
- reward_model=reward_model,
- args=args,
- data_collator=data_collator,
- train_dataset=train_dataset,
- eval_dataset=eval_dataset,
- processing_class=processing_class,
- reward_processing_class=processing_class, # for now, XPOTrainer can't use any reward model
- peft_config=peft_config,
- compute_metrics=compute_metrics,
- callbacks=callbacks,
- optimizers=optimizers,
- preprocess_logits_for_metrics=preprocess_logits_for_metrics,
- )
-
- self._alpha = self.args.alpha
-
- # Overwrite the stats dictionary to include XPO specific statistics
- self.stats = {
- # Remove "non_score_reward", "rlhf_reward", "scores"
- # Add "loss/dpo", "loss/xpo"
- "loss/dpo": [],
- "loss/xpo": [],
- "objective/kl": [],
- "objective/entropy": [],
- "rewards/chosen": [],
- "rewards/rejected": [],
- "rewards/accuracies": [],
- "rewards/margins": [],
- "logps/chosen": [],
- "logps/rejected": [],
- # Replace "contain_eos_token" by "model_contain_eos_token" and "ref_contain_eos_token"
- "val/model_contain_eos_token": [],
- "val/ref_contain_eos_token": [],
- "alpha": [],
- "beta": [],
- }
- if self.reward_model is not None:
- # Replace "scores" by "model_scores" and "ref_scores"
- self.stats["objective/model_scores"] = []
- self.stats["objective/ref_scores"] = []
- self.stats["objective/scores_margin"] = []
-
- @property
- def alpha(self):
- if isinstance(self._alpha, list):
- epoch = int(self.state.epoch) # state.epoch is a float; cast before indexing the list
- return self._alpha[epoch] if epoch < len(self._alpha) else self._alpha[-1]
- else:
- return self._alpha
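An illustration (not part of the deleted file) of the per-epoch `alpha` schedule the property implements, clamping to the last value once the list runs out:

```python
def alpha_for_epoch(alpha, epoch):
    # alpha may be a float or a per-epoch list; the last entry is reused
    # once training runs past the end of the list.
    if isinstance(alpha, list):
        epoch = int(epoch)  # Trainer's state.epoch is a float
        return alpha[epoch] if epoch < len(alpha) else alpha[-1]
    return alpha

assert alpha_for_epoch([1e-5, 5e-6, 1e-6], 1) == 5e-6
assert alpha_for_epoch([1e-5, 5e-6, 1e-6], 7) == 1e-6
assert alpha_for_epoch(1e-5, 7) == 1e-5
```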
-
- def _generate_completions(self, prompts, model):
- with unwrap_model_for_generation(model, self.accelerator) as unwrapped_model:
- model_output = unwrapped_model.generate(
- input_ids=prompts["input_ids"],
- attention_mask=prompts["attention_mask"],
- generation_config=self.generation_config,
- )
-
- ref_model = model if self.ref_model is None else self.ref_model
- with torch.no_grad(), unwrap_model_for_generation(ref_model, self.accelerator) as unwrapped_ref_model:
- ref_output = unwrapped_ref_model.generate(
- input_ids=prompts["input_ids"],
- attention_mask=prompts["attention_mask"],
- generation_config=self.generation_config,
- )
-
- return model_output, ref_output
-
- def _process_completions(self, model_output, ref_output, prompts):
- context_length = prompts["input_ids"].shape[1]
-
- # Process model completions
- model_completion_ids = model_output[:, context_length:]
- model_completion_ids, model_completion_mask = truncate_right(
- model_completion_ids, self.processing_class.eos_token_id, self.processing_class.pad_token_id
- )
- model_data = {
- "input_ids": torch.cat((prompts["input_ids"], model_completion_ids), dim=1),
- "attention_mask": torch.cat((prompts["attention_mask"], model_completion_mask), dim=1),
- "raw": prompts["raw"],
- }
-
- # Process reference model completions
- ref_completion_ids = ref_output[:, context_length:]
- ref_completion_ids, ref_completion_mask = truncate_right(
- ref_completion_ids, self.processing_class.eos_token_id, self.processing_class.pad_token_id
- )
- ref_data = {
- "input_ids": torch.cat((prompts["input_ids"], ref_completion_ids), dim=1),
- "attention_mask": torch.cat((prompts["attention_mask"], ref_completion_mask), dim=1),
- "raw": prompts["raw"],
- }
-
- return model_data, ref_data
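For orientation, a rough sketch (not from the deleted file, and only an approximation of TRL's `truncate_right`) of the behavior assumed above: keep tokens up to and including the first EOS, pad the rest, and build a matching mask:

```python
import torch

def truncate_right_sketch(ids, eos_id, pad_id):
    # Roughly what `truncate_right` does: everything after the first EOS in
    # each row is replaced by padding and masked out.
    ids = ids.clone()
    mask = torch.ones_like(ids)
    for row in range(ids.shape[0]):
        eos_positions = (ids[row] == eos_id).nonzero()
        if len(eos_positions) > 0:
            cut = eos_positions[0].item() + 1
            ids[row, cut:] = pad_id
            mask[row, cut:] = 0
    return ids, mask

ids = torch.tensor([[5, 6, 2, 9, 9]])  # 2 = EOS, trailing junk after it
print(truncate_right_sketch(ids, eos_id=2, pad_id=0))
# (tensor([[5, 6, 2, 0, 0]]), tensor([[1, 1, 1, 0, 0]]))
```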
498
-
499
- def _compute_rewards(self, model_data, ref_data, context_length):
500
- with torch.no_grad():
501
- _, model_scores, _ = get_reward(
502
- self.reward_model, model_data["input_ids"], self.processing_class.pad_token_id, context_length
503
- )
504
- _, ref_scores, _ = get_reward(
505
- self.reward_model, ref_data["input_ids"], self.processing_class.pad_token_id, context_length
506
- )
507
-
508
- # Apply EOS penalty if needed
509
- if self.args.missing_eos_penalty is not None:
510
- model_contain_eos = torch.any(model_data["input_ids"] == self.processing_class.eos_token_id, dim=-1)
511
- ref_contain_eos = torch.any(ref_data["input_ids"] == self.processing_class.eos_token_id, dim=-1)
512
- model_scores[~model_contain_eos] -= self.args.missing_eos_penalty
513
- ref_scores[~ref_contain_eos] -= self.args.missing_eos_penalty
514
-
515
- return model_scores, ref_scores
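A tiny illustration (not from the deleted file) of the EOS-penalty masking above: completions that never emit EOS get their reward reduced in place via boolean indexing:

```python
import torch

scores = torch.tensor([1.0, 2.0, 3.0])
contain_eos = torch.tensor([True, False, True])
missing_eos_penalty = 1.0

scores[~contain_eos] -= missing_eos_penalty  # penalize only the second row
print(scores)  # tensor([1., 1., 3.])
```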
516
-
517
- def _compute_judge(self, model_data, ref_data, context_length):
518
- prompts = model_data["raw"]
519
- model_data_completions = self.processing_class.batch_decode(
520
- model_data["input_ids"][:, context_length:], skip_special_tokens=True
521
- )
522
- model_data_completions = [completion.strip() for completion in model_data_completions]
523
-
524
- ref_data_completions = self.processing_class.batch_decode(
525
- ref_data["input_ids"][:, context_length:], skip_special_tokens=True
526
- )
527
- ref_data_completions = [completion.strip() for completion in ref_data_completions]
528
-
529
- if is_conversational({"prompt": prompts[0]}):
530
- model_data_completions = [
531
- [{"role": "assistant", "content": completion}] for completion in model_data_completions
532
- ]
533
- environment = jinja2.Environment()
534
- template = environment.from_string(SIMPLE_CHAT_TEMPLATE)
535
- prompts = [template.render(messages=message) for message in prompts]
536
- model_data_completions = [template.render(messages=completion) for completion in model_data_completions]
537
-
538
- ref_data_completions = [
539
- [{"role": "assistant", "content": completion}] for completion in ref_data_completions
540
- ]
541
- ref_data_completions = [template.render(messages=completion) for completion in ref_data_completions]
542
-
543
- ranks_of_first_completion = self.judge.judge(
544
- prompts,
545
- list(zip(model_data_completions, ref_data_completions)),
546
- )
547
- # convert ranks to a True/False mask:
548
- # when rank == 0, it means the first completion is the best
549
- # when rank == 1, it means the second completion is the best
550
- return torch.tensor([rank == 0 for rank in ranks_of_first_completion], device=model_data["input_ids"].device)
551
-
552
- def _compute_logprobs(self, model, model_data, ref_data, context_length):
553
- def compute_logprobs_for_data(m, data):
554
- output = m(data["input_ids"], attention_mask=data["attention_mask"])
555
- logits = output.logits[:, context_length - 1 : -1]
556
- token_logprobs = selective_log_softmax(logits, data["input_ids"][:, context_length:])
557
- return token_logprobs
558
-
559
- # Compute logprobs for model completions
560
- model_logprobs_model_data = compute_logprobs_for_data(model, model_data)
561
- # Compute logprobs for model on reference completions (for XPO loss)
562
- model_logprobs_ref_data = compute_logprobs_for_data(model, ref_data)
563
-
564
- # Compute logprobs for reference model completions
565
- with torch.no_grad():
566
- if self.ref_model is None:
567
- with model.disable_adapter():
568
- ref_logprobs_model_data = compute_logprobs_for_data(model, model_data)
569
- ref_logprobs_ref_data = compute_logprobs_for_data(model, ref_data)
570
- else:
571
- ref_logprobs_model_data = compute_logprobs_for_data(self.ref_model, model_data)
572
- ref_logprobs_ref_data = compute_logprobs_for_data(self.ref_model, ref_data)
573
-
574
- # Mask padding tokens
575
- model_padding_mask = model_data["attention_mask"][:, context_length:] == 0
576
- ref_padding_mask = ref_data["attention_mask"][:, context_length:] == 0
577
- model_logprobs_model_data = model_logprobs_model_data.masked_fill(model_padding_mask, 0.0)
578
- model_logprobs_ref_data = model_logprobs_ref_data.masked_fill(ref_padding_mask, 0.0)
579
- ref_logprobs_ref_data = ref_logprobs_ref_data.masked_fill(ref_padding_mask, 0.0)
580
- ref_logprobs_model_data = ref_logprobs_model_data.masked_fill(model_padding_mask, 0.0)
581
-
582
- return model_logprobs_model_data, model_logprobs_ref_data, ref_logprobs_ref_data, ref_logprobs_model_data
583
-
584
- def _compute_losses(
585
- self,
586
- model_logprobs_model_data,
587
- model_logprobs_ref_data,
588
- ref_logprobs_ref_data,
589
- ref_logprobs_model_data,
590
- chosen_mask,
591
- ):
592
- # Compute log probs
593
- model_logprobs_model_data_sum = model_logprobs_model_data.sum(1)
594
- model_logprobs_ref_data_sum = model_logprobs_ref_data.sum(1)
595
- ref_logprobs_ref_data_sum = ref_logprobs_ref_data.sum(1)
596
- ref_logprobs_model_data_sum = ref_logprobs_model_data.sum(1)
597
-
598
- chosen_model_logprobs = torch.where(chosen_mask, model_logprobs_model_data_sum, model_logprobs_ref_data_sum)
599
- chosen_ref_logprobs = torch.where(chosen_mask, ref_logprobs_model_data_sum, ref_logprobs_ref_data_sum)
600
- chosen_log_ratios = chosen_model_logprobs - chosen_ref_logprobs
601
-
602
- rejected_model_logprobs = torch.where(~chosen_mask, model_logprobs_model_data_sum, model_logprobs_ref_data_sum)
603
- rejected_ref_logprobs = torch.where(~chosen_mask, ref_logprobs_model_data_sum, ref_logprobs_ref_data_sum)
604
- rejected_log_ratios = rejected_model_logprobs - rejected_ref_logprobs
605
-
606
- # Compute logits as the difference between chosen and rejected log ratios
607
- logits = chosen_log_ratios - rejected_log_ratios
608
-
609
- if self.args.loss_type == "sigmoid":
610
- dpo_losses = -F.logsigmoid(self.beta * logits)
611
- elif self.args.loss_type == "ipo":
612
- dpo_losses = (logits - 1 / (2 * self.beta)) ** 2
613
- else:
614
- raise NotImplementedError(f"invalid loss type {self.args.loss_type}")
615
-
616
- # Compute XPO specific loss
617
- xpo_losses = self.alpha * model_logprobs_ref_data_sum
618
-
619
- # Total loss
620
- loss = (dpo_losses + xpo_losses).mean()
621
-
622
- return loss, dpo_losses, xpo_losses
623
-
624
- def _log_statistics(
625
- self,
626
- model_data,
627
- ref_data,
628
- model_logprobs_model_data,
629
- model_logprobs_ref_data,
630
- ref_logprobs_ref_data,
631
- ref_logprobs_model_data,
632
- chosen_mask,
633
- dpo_losses,
634
- xpo_losses,
635
- context_length,
636
- model_scores=None,
637
- ref_scores=None,
638
- ):
639
- # Helper function to gather and compute mean
640
- def gather_mean(tensor):
641
- return self.accelerator.gather_for_metrics(tensor).mean().item()
642
-
643
- # Log losses
644
- self.stats["loss/dpo"].append(gather_mean(dpo_losses))
645
- self.stats["loss/xpo"].append(gather_mean(xpo_losses))
646
-
647
- # Log scores
648
- if self.reward_model is not None:
649
- self.stats["objective/model_scores"].append(gather_mean(model_scores))
650
- self.stats["objective/ref_scores"].append(gather_mean(ref_scores))
651
- self.stats["objective/scores_margin"].append(gather_mean(model_scores - ref_scores))
652
-
653
- # Log logprobs
654
- model_logprobs_model_data_sum = model_logprobs_model_data.sum(1)
655
- model_logprobs_ref_data_sum = model_logprobs_ref_data.sum(1)
656
- ref_logprobs_ref_data_sum = ref_logprobs_ref_data.sum(1)
657
- ref_logprobs_model_data_sum = ref_logprobs_model_data.sum(1)
658
-
659
- chosen_model_logprobs = torch.where(chosen_mask, model_logprobs_model_data_sum, model_logprobs_ref_data_sum)
660
- chosen_ref_logprobs = torch.where(chosen_mask, ref_logprobs_model_data_sum, ref_logprobs_ref_data_sum)
661
- chosen_log_ratios = chosen_model_logprobs - chosen_ref_logprobs
662
-
663
- rejected_model_logprobs = torch.where(~chosen_mask, model_logprobs_model_data_sum, model_logprobs_ref_data_sum)
664
- rejected_ref_logprobs = torch.where(~chosen_mask, ref_logprobs_model_data_sum, ref_logprobs_ref_data_sum)
665
- rejected_log_ratios = rejected_model_logprobs - rejected_ref_logprobs
666
-
667
- self.stats["logps/chosen"].append(gather_mean(chosen_model_logprobs.mean() + chosen_ref_logprobs.mean()))
668
- self.stats["logps/rejected"].append(gather_mean(rejected_model_logprobs.mean() + rejected_ref_logprobs.mean()))
669
-
670
- # Log rewards
671
- # Compute various statistics
672
- chosen_rewards = chosen_log_ratios * self.beta
673
- rejected_rewards = rejected_log_ratios * self.beta
674
- self.stats["rewards/chosen"].append(gather_mean(chosen_rewards.mean()))
675
- self.stats["rewards/rejected"].append(gather_mean(rejected_rewards.mean()))
676
-
677
- # Calculate KL divergence for model and ref data
678
- kl_model_data = model_logprobs_model_data - ref_logprobs_model_data
679
- kl_ref_data = model_logprobs_ref_data - ref_logprobs_ref_data
680
- mean_kl = (kl_model_data.sum(1) + kl_ref_data.sum(1)).mean() / 2
681
- self.stats["objective/kl"].append(gather_mean(mean_kl))
682
-
683
- # Calculate entropy for model and ref data
684
- entropy_model_data = -model_logprobs_model_data.sum(1)
685
- entropy_ref_data = -model_logprobs_ref_data.sum(1)
686
- mean_entropy = (entropy_model_data.mean() + entropy_ref_data.mean()) / 2
687
- self.stats["objective/entropy"].append(gather_mean(mean_entropy))
688
-
689
- # Calculate margins
690
- margin = chosen_rewards - rejected_rewards
691
- self.stats["rewards/margins"].append(gather_mean(margin.mean()))
692
-
693
- # Calculate accuracy
694
- accuracy = (margin > 0).float()
695
- self.stats["rewards/accuracies"].append(gather_mean(accuracy.mean()))
696
-
697
- # Log EOS token statistics
698
- model_eos = (model_data["input_ids"][:, context_length:] == self.processing_class.eos_token_id).any(dim=1)
699
- ref_eos = (ref_data["input_ids"][:, context_length:] == self.processing_class.eos_token_id).any(dim=1)
700
- self.stats["val/model_contain_eos_token"].append(gather_mean(model_eos.float()))
701
- self.stats["val/ref_contain_eos_token"].append(gather_mean(ref_eos.float()))
702
-
703
- # Log alpha and beta
704
- self.stats["alpha"].append(self.alpha)
705
- self.stats["beta"].append(self.beta)
706
-
707
- def training_step(
708
- self, model: nn.Module, inputs: dict[str, Union[torch.Tensor, Any]], num_items_in_batch: Optional[int] = None
709
- ) -> torch.Tensor:
710
- model.train()
711
-
712
- # Apply chat template and tokenize the input
713
- batch_size = len(next(iter(inputs.values())))
714
- prompts = inputs["prompt"]
715
- inputs = [{k: v[i] for k, v in inputs.items()} for i in range(batch_size)]
716
- inputs = [maybe_apply_chat_template(x, self.processing_class) for x in inputs]
717
- inputs = [self.tokenize_row(x, self.model.config.is_encoder_decoder, self.processing_class) for x in inputs]
718
- inputs = self.data_collator(inputs)
719
-
720
- # need the prompt_ only
721
- inputs = self._prepare_inputs(inputs)
722
- context_length = inputs["prompt_input_ids"].shape[1]
723
- prompts = {
724
- "input_ids": inputs["prompt_input_ids"],
725
- "attention_mask": inputs["prompt_attention_mask"],
726
- "raw": prompts,
727
- }
728
- del inputs
729
-
730
- # Sample completions from both the model and the reference model
731
- model_output, ref_output = self._generate_completions(prompts, model)
732
-
733
- # Process model completions
734
- model_data, ref_data = self._process_completions(model_output, ref_output, prompts)
735
-
736
- # Compute rewards
737
- if self.reward_model is not None:
738
- model_scores, ref_scores = self._compute_rewards(model_data, ref_data, context_length)
739
- chosen_mask = model_scores >= ref_scores
740
- else:
741
- model_scores, ref_scores = None, None
742
- chosen_mask = self._compute_judge(model_data, ref_data, context_length)
743
-
744
- # Compute logprobs
745
- model_logprobs_model_data, model_logprobs_ref_data, ref_logprobs_ref_data, ref_logprobs_model_data = (
746
- self._compute_logprobs(model, model_data, ref_data, context_length)
747
- )
748
-
749
- # Compute loss
750
- loss, dpo_losses, xpo_losses = self._compute_losses(
751
- model_logprobs_model_data,
752
- model_logprobs_ref_data,
753
- ref_logprobs_ref_data,
754
- ref_logprobs_model_data,
755
- chosen_mask,
756
- )
757
-
758
- # Log everything
759
- self._log_statistics(
760
- model_data,
761
- ref_data,
762
- model_logprobs_model_data.detach(),
763
- model_logprobs_ref_data.detach(),
764
- ref_logprobs_ref_data,
765
- ref_logprobs_model_data,
766
- chosen_mask,
767
- dpo_losses.detach(),
768
- xpo_losses.detach(),
769
- context_length,
770
- model_scores,
771
- ref_scores,
772
- )
773
-
774
- if (
775
- self.args.torch_empty_cache_steps is not None
776
- and self.state.global_step % self.args.torch_empty_cache_steps == 0
777
- ):
778
- empty_cache()
779
-
780
- kwargs = {}
781
- # For LOMO optimizers you need to explicitly use the learning rate
782
- if self.args.optim in [OptimizerNames.LOMO, OptimizerNames.ADALOMO]:
783
- kwargs["learning_rate"] = self._get_learning_rate()
784
-
785
- if self.args.n_gpu > 1:
786
- loss = loss.mean() # mean() to average on multi-gpu parallel training
787
-
788
- if self.use_apex:
789
- with amp.scale_loss(loss, self.optimizer) as scaled_loss:
790
- scaled_loss.backward()
791
- else:
792
- self.accelerator.backward(loss, **kwargs)
793
-
794
- return loss.detach() / self.args.gradient_accumulation_steps
795
-
796
- def create_model_card(
797
- self,
798
- model_name: Optional[str] = None,
799
- dataset_name: Optional[str] = None,
800
- tags: Union[str, list[str], None] = None,
801
- ):
802
- """
803
- Creates a draft of a model card using the information available to the `Trainer`.
804
-
805
- Args:
806
- model_name (`str` or `None`, *optional*, defaults to `None`):
807
- Name of the model.
808
- dataset_name (`str` or `None`, *optional*, defaults to `None`):
809
- Name of the dataset used for training.
810
- tags (`str`, `list[str]` or `None`, *optional*, defaults to `None`):
811
- Tags to be associated with the model card.
812
- """
813
- if not self.is_world_process_zero():
814
- return
815
-
816
- if hasattr(self.model.config, "_name_or_path") and not os.path.isdir(self.model.config._name_or_path):
817
- base_model = self.model.config._name_or_path
818
- else:
819
- base_model = None
820
-
821
- tags = tags or []
822
- if isinstance(tags, str):
823
- tags = [tags]
824
-
825
- if hasattr(self.model.config, "unsloth_version"):
826
- tags.append("unsloth")
827
-
828
- citation = textwrap.dedent("""\
829
- @article{jung2024binary,
830
- title = {{Exploratory Preference Optimization: Harnessing Implicit Q*-Approximation for Sample-Efficient RLHF}},
831
- author = {Tengyang Xie and Dylan J. Foster and Akshay Krishnamurthy and Corby Rosset and Ahmed Awadallah and Alexander Rakhlin},
832
- year = 2024,
833
- eprint = {arXiv:2405.21046}
834
- }""")
835
-
836
- model_card = generate_model_card(
837
- base_model=base_model,
838
- model_name=model_name,
839
- hub_model_id=self.hub_model_id,
840
- dataset_name=dataset_name,
841
- tags=tags,
842
- wandb_url=wandb.run.get_url() if is_wandb_available() and wandb.run is not None else None,
843
- comet_url=get_comet_experiment_url(),
844
- trainer_name="XPO",
845
- trainer_citation=citation,
846
- paper_title="Exploratory Preference Optimization: Harnessing Implicit Q*-Approximation for Sample-Efficient RLHF",
847
- paper_id="2405.21046",
848
- )
849
-
850
- model_card.save(os.path.join(self.args.output_dir, "README.md"))
851
- class UnslothXPOTrainer(_UnslothXPOTrainer):
-     """
- 
-     Initialize XPOTrainer as a subclass of [`OnlineDPOTrainer`].
- 
-     Args:
-         model (`transformers.PreTrainedModel`):
-             The model to train, preferably an `AutoModelForCausalLM`.
-         ref_model (`PreTrainedModelWrapper`):
-             Hugging Face transformer model with a causal language modelling head. Used for implicit reward computation and loss. If no
-             reference model is provided, the trainer will create a reference model with the same architecture as the model to be optimized.
-         reward_model (`transformers.PreTrainedModel`):
-             The reward model to score completions with, preferably an `AutoModelForSequenceClassification`.
-         judge (`BasePairwiseJudge`):
-             The judge to use for pairwise comparison of model completions.
-         args (`XPOConfig`):
-             The XPO config arguments to use for training.
-         data_collator (`transformers.DataCollator`):
-             The data collator to use for training. If None is specified, the default data collator (`DPODataCollatorWithPadding`) will be used,
-             which will pad the sequences to the maximum length of the sequences in the batch, given a dataset of paired sequences.
-         train_dataset (`datasets.Dataset`):
-             The dataset to use for training.
-         eval_dataset (`datasets.Dataset`):
-             The dataset to use for evaluation.
-         processing_class (`PreTrainedTokenizerBase` or `BaseImageProcessor` or `FeatureExtractionMixin` or `ProcessorMixin`, *optional*):
-             Processing class used to process the data. If provided, will be used to automatically process the inputs
-             for the model, and it will be saved along the model to make it easier to rerun an interrupted training or
-             reuse the fine-tuned model.
-         peft_config (`dict`):
-             The peft config to use for training.
-         compute_metrics (`Callable[[EvalPrediction], dict]`, *optional*):
-             The function to use to compute the metrics. Must take an `EvalPrediction` and return
-             a dictionary mapping metric names to metric values.
-         callbacks (`list[transformers.TrainerCallback]`):
-             The callbacks to use for training.
-         optimizers (`tuple[torch.optim.Optimizer, torch.optim.lr_scheduler.LambdaLR]`):
-             The optimizer and scheduler to use for training.
-         preprocess_logits_for_metrics (`Callable[[torch.Tensor, torch.Tensor], torch.Tensor]`):
-             The function to use to preprocess the logits before computing the metrics.
- 
-     """
-     def __init__(
-         self,
-         model = None,
-         ref_model = None,
-         reward_model = None,
-         judge = None,
-         args = None,
-         data_collator = None,
-         train_dataset = None,
-         eval_dataset = None,
-         processing_class = None,
-         peft_config = None,
-         compute_metrics = None,
-         callbacks = None,
-         preprocess_logits_for_metrics = None,
-         **kwargs
-     ):
-         if args is None: args = UnslothXPOConfig()
-         use_bf16 = getattr(args, 'bf16', False)
-         if type(use_bf16) is not bool: use_bf16 = False
-         use_fp16 = getattr(args, 'fp16', False)
-         if type(use_fp16) is not bool: use_fp16 = False
-         force_float32 = False
-         if os.environ.get('UNSLOTH_FORCE_FLOAT32', '0') == '1':
-             print('Unsloth: Switching to float32 training since model cannot work with float16')
-             force_float32 = True
-         mixed_precision_dtype = os.environ.get('UNSLOTH_MIXED_PRECISION', 'float32')
-         dtype = getattr(model.config, 'torch_dtype', None)
-         if dtype is None: dtype = model.get_input_embeddings().dtype
-         from unsloth_zoo.utils import _get_dtype
-         dtype = _get_dtype(dtype)
-         float16 = dtype == torch.float16
-         if not force_float32 and (float16 and use_bf16): raise TypeError('Unsloth: Model is in float16 precision but you want to use bfloat16 precision. Set fp16 to `True` and bf16 to `False`')
-         if not force_float32 and (not float16 and use_fp16): raise TypeError('Unsloth: Model is in bfloat16 precision but you want to use float16 precision. Set fp16 to `False` and bf16 to `True`')
-         if force_float32:
-             args.fp16 = False
-             args.bf16 = False
-             os.environ['ACCELERATE_MIXED_PRECISION'] = 'no'
-         elif (not use_bf16 and not use_fp16) and mixed_precision_dtype == 'float32':
-             args.fp16 = float16
-             args.bf16 = not float16
-             os.environ['ACCELERATE_MIXED_PRECISION'] = 'fp16' if float16 else 'bf16'
-         if getattr(args, 'eval_dataset', None) is not None and getattr(args, 'eval_strategy', 'no') == 'no':
-             args.eval_strategy = 'steps'
-             if getattr(args, 'eval_steps', None) is None: args.eval_steps = 0.1
-         ga_steps = getattr(args, 'gradient_accumulation_steps', None)
-         if ga_steps is not None and ga_steps > 1:
-             from transformers import __version__ as transformers_version
-             if Version(transformers_version) <= Version('4.45.2'):
-                 print('**** Unsloth: Please use our fixed gradient_accumulation_steps by updating transformers, TRL and Unsloth!\n'
-                     '`pip install --upgrade --no-cache-dir --force-reinstall --no-deps unsloth transformers trl unsloth_zoo`')
-         if getattr(args, 'eval_strategy', 'no') != 'no':
-             eval_bsz = getattr(args, 'per_device_eval_batch_size', 8)
-             if eval_bsz == 8 and args.per_device_train_batch_size < eval_bsz: args.per_device_eval_batch_size = args.per_device_train_batch_size
-             if getattr(args, 'eval_accumulation_steps', None) is None and ga_steps is not None: args.eval_accumulation_steps = ga_steps
-         fp16_full_eval = getattr(args, 'fp16_full_eval', False)
-         if type(fp16_full_eval) is not bool: fp16_full_eval = False
-         bf16_full_eval = getattr(args, 'bf16_full_eval', False)
-         if type(bf16_full_eval) is not bool: bf16_full_eval = False
-         if args.fp16 and bf16_full_eval: args.bf16_full_eval = False; args.fp16_full_eval = True
-         if args.bf16 and fp16_full_eval: args.bf16_full_eval = True; args.fp16_full_eval = False
-         if force_float32:
-             args.bf16_full_eval = False
-             args.fp16_full_eval = False
-         elif os.environ.get('UNSLOTH_MIXED_PRECISION', 'float32') == 'bfloat16':
-             args.bf16_full_eval = True
-             args.fp16_full_eval = False
-         elif not bf16_full_eval and not fp16_full_eval:
-             args.bf16_full_eval = args.bf16
-             args.fp16_full_eval = args.fp16
-         _output_logits = False
-         if locals().get('compute_metrics', None) is not None: _output_logits = True
-         if locals().get('preprocess_logits_for_metrics', None) is not None: _output_logits = True
-         if _output_logits:
-             os.environ['UNSLOTH_RETURN_LOGITS'] = '1'
-         if 'max_seq_length' not in locals() and not hasattr(args, 'max_seq_length'):
-             pass
-         else:
-             model_max_seq_length = getattr(model, 'max_seq_length', None)
-             args_max_seq_length = getattr(args, 'max_seq_length', None)
-             if args_max_seq_length is None and model_max_seq_length is not None:
-                 max_seq_length = model.max_seq_length
-                 if hasattr(args, 'max_seq_length'): args.max_seq_length = max_seq_length
-         if model is not None and hasattr(model, 'for_training'):
-             model.for_training()
-         if 'tokenizer' in locals() and hasattr(tokenizer, 'padding_side'): tokenizer.padding_side = 'right'
-         if 'processing_class' in locals():
-             if hasattr(processing_class, 'padding_side'): processing_class.padding_side = 'right'
-             if hasattr(processing_class, 'tokenizer') and hasattr(processing_class.tokenizer, 'padding_side'): processing_class.tokenizer.padding_side = 'right'
-         __tokenizer = processing_class if 'processing_class' in locals() else tokenizer
-         from unsloth_zoo.vision_utils import UnslothVisionDataCollator
-         if not isinstance(data_collator, UnslothVisionDataCollator):
-             if isinstance(data_collator, DataCollatorForSeq2Seq) and 'labels' not in train_dataset.column_names:
-                 data_collator = TransformersDataCollatorForLanguageModeling(__tokenizer, mlm = False, mlm_probability = 0.0)
-             elif isinstance(data_collator, TransformersDataCollatorForLanguageModeling) and 'labels' in train_dataset.column_names:
-                 data_collator = DataCollatorForSeq2Seq(__tokenizer)
-         else:
-             if hasattr(args, 'remove_unused_columns'): args.remove_unused_columns = False
-             if hasattr(args, 'dataset_text_field'): args.dataset_text_field = ''
-             if hasattr(args, 'dataset_kwargs'): args.dataset_kwargs = {'skip_prepare_dataset': True}
-         if not isinstance(data_collator, UnslothVisionDataCollator):
-             if not hasattr(__tokenizer, 'pad') and hasattr(__tokenizer, 'tokenizer'):
-                 if isinstance(data_collator, DataCollatorForSeq2Seq):
-                     data_collator = DataCollatorForSeq2Seq(__tokenizer.tokenizer)
-                 else:
-                     data_collator = TransformersDataCollatorForLanguageModeling(__tokenizer.tokenizer, mlm = False, mlm_probability = 0.0)
-         other_metrics = []
- 
-         from unsloth_zoo.logging_utils import PatchRLStatistics
-         PatchRLStatistics('xpo_trainer', other_metrics)
- 
-         super().__init__(
-             model = model,
-             ref_model = ref_model,
-             reward_model = reward_model,
-             judge = judge,
-             args = args,
-             data_collator = data_collator,
-             train_dataset = train_dataset,
-             eval_dataset = eval_dataset,
-             processing_class = processing_class,
-             peft_config = peft_config,
-             compute_metrics = compute_metrics,
-             callbacks = callbacks,
-             preprocess_logits_for_metrics = preprocess_logits_for_metrics, **kwargs)
-         if hasattr(self, 'neftune_hook_handle'):
-             self.neftune_hook_handle.remove()
-             if hasattr(self, 'neftune_hook_handle'): del self.neftune_hook_handle
-         if getattr(args, 'neftune_noise_alpha', None) is not None:
-             model.get_input_embeddings().neftune_noise_alpha = self.neftune_noise_alpha
-         pass
- 
- pass
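
The deleted `_compute_losses` above implements a standard DPO preference loss plus the XPO exploration bonus. As a reference sketch in the code's own notation (β = `self.beta`, α = `self.alpha`, and sequence-level log-ratios r = log π_θ(y|x) − log π_ref(y|x) for the chosen and rejected completions):

```latex
% Per-example objective implemented by _compute_losses (sketch)
\mathcal{L}_{\mathrm{DPO}} =
\begin{cases}
  -\log \sigma\!\bigl(\beta\,(r_{\mathrm{chosen}} - r_{\mathrm{rejected}})\bigr) & \text{loss\_type = sigmoid} \\[4pt]
  \bigl(r_{\mathrm{chosen}} - r_{\mathrm{rejected}} - \tfrac{1}{2\beta}\bigr)^{2} & \text{loss\_type = ipo}
\end{cases}
\qquad
\mathcal{L}_{\mathrm{XPO}} = \alpha \,\log \pi_{\theta}(y_{\mathrm{ref}} \mid x)
```

The total loss is the batch mean of these two terms, matching `loss = (dpo_losses + xpo_losses).mean()` in the code.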
 
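Since this commit removes the file, a minimal usage sketch may help readers of the diff. It is reconstructed from the constructor signature above; the model name, dataset, and judge choice are illustrative assumptions, not values taken from this repository:

```python
# Hypothetical driver for the deleted UnslothXPOTrainer (sketch only).
# Assumptions: placeholder model/dataset names; PairRMJudge is one possible
# BasePairwiseJudge from TRL, used here because no reward_model is passed.
from datasets import load_dataset
from transformers import AutoModelForCausalLM, AutoTokenizer
from trl import PairRMJudge

model = AutoModelForCausalLM.from_pretrained("my-org/my-base-model")  # placeholder
tokenizer = AutoTokenizer.from_pretrained("my-org/my-base-model")     # placeholder

trainer = UnslothXPOTrainer(
    model=model,
    ref_model=None,       # per _compute_logprobs above, falls back to the model with adapters disabled
    reward_model=None,    # either a reward model ...
    judge=PairRMJudge(),  # ... or a pairwise judge must be provided
    args=UnslothXPOConfig(output_dir="xpo-out"),  # defaults from the deleted config class
    train_dataset=load_dataset("my-org/my-prompt-dataset", split="train"),  # placeholder prompt-only dataset
    processing_class=tokenizer,
)
trainer.train()
```
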
test_run_uploads/__pycache__/UnslothAlignPropTrainer.cpython-311.pyc DELETED
Binary file (33.9 kB)
 
test_run_uploads/__pycache__/UnslothBCOTrainer.cpython-311.pyc DELETED
Binary file (92.8 kB)
 
test_run_uploads/__pycache__/UnslothCPOTrainer.cpython-311.pyc DELETED
Binary file (76.7 kB)
 
test_run_uploads/__pycache__/UnslothDDPOTrainer.cpython-311.pyc DELETED
Binary file (46.5 kB)
 
test_run_uploads/__pycache__/UnslothDPOTrainer.cpython-311.pyc DELETED
@@ -1,3 +0,0 @@
- version https://git-lfs.github.com/spec/v1
- oid sha256:8c20178043e78b3057a4eec21c41cb84e543aa7e03cab7996894ab8e7904e768
- size 104591
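
Three-line stubs like this one are Git LFS pointer files: the repository tracks only the small pointer, while the binary payload lives in LFS storage. The general pointer format is:

```
version https://git-lfs.github.com/spec/v1
oid sha256:<64-hex SHA-256 digest of the file contents>
size <file size in bytes>
```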
 
test_run_uploads/__pycache__/UnslothGKDTrainer.cpython-311.pyc DELETED
Binary file (39.6 kB)
 
test_run_uploads/__pycache__/UnslothGRPOTrainer.cpython-311.pyc DELETED
Binary file (97.6 kB)
 
test_run_uploads/__pycache__/UnslothKTOTrainer.cpython-311.pyc DELETED
Binary file (88.7 kB)
 
test_run_uploads/__pycache__/UnslothNashMDTrainer.cpython-311.pyc DELETED
Binary file (49 kB)
 
test_run_uploads/__pycache__/UnslothORPOTrainer.cpython-311.pyc DELETED
Binary file (76.7 kB)
 
test_run_uploads/__pycache__/UnslothOnlineDPOTrainer.cpython-311.pyc DELETED
Binary file (68.8 kB)
 
test_run_uploads/__pycache__/UnslothPPOTrainer.cpython-311.pyc DELETED
Binary file (64.4 kB)
 
test_run_uploads/__pycache__/UnslothPRMTrainer.cpython-311.pyc DELETED
Binary file (37.7 kB)
 
test_run_uploads/__pycache__/UnslothRLOOTrainer.cpython-311.pyc DELETED
Binary file (55.6 kB)
 
test_run_uploads/__pycache__/UnslothRewardTrainer.cpython-311.pyc DELETED
Binary file (40.2 kB)
 
test_run_uploads/__pycache__/UnslothSFTTrainer.cpython-311.pyc DELETED
Binary file (52.4 kB)
 
test_run_uploads/__pycache__/UnslothXPOTrainer.cpython-311.pyc DELETED
Binary file (51.6 kB)
 
test_run_uploads/checkpoint-50/README.md DELETED
@@ -1,210 +0,0 @@
- ---
- base_model: mistralai/Ministral-8B-Instruct-2410
- library_name: peft
- pipeline_tag: text-generation
- tags:
- - base_model:adapter:mistralai/Ministral-8B-Instruct-2410
- - lora
- - sft
- - transformers
- - trl
- - unsloth
- ---
- 
- # Model Card for Model ID
- 
- <!-- Provide a quick summary of what the model is/does. -->
- 
- 
- 
- ## Model Details
- 
- ### Model Description
- 
- <!-- Provide a longer summary of what this model is. -->
- 
- 
- 
- - **Developed by:** [More Information Needed]
- - **Funded by [optional]:** [More Information Needed]
- - **Shared by [optional]:** [More Information Needed]
- - **Model type:** [More Information Needed]
- - **Language(s) (NLP):** [More Information Needed]
- - **License:** [More Information Needed]
- - **Finetuned from model [optional]:** [More Information Needed]
- 
- ### Model Sources [optional]
- 
- <!-- Provide the basic links for the model. -->
- 
- - **Repository:** [More Information Needed]
- - **Paper [optional]:** [More Information Needed]
- - **Demo [optional]:** [More Information Needed]
- 
- ## Uses
- 
- <!-- Address questions around how the model is intended to be used, including the foreseeable users of the model and those affected by the model. -->
- 
- ### Direct Use
- 
- <!-- This section is for the model use without fine-tuning or plugging into a larger ecosystem/app. -->
- 
- [More Information Needed]
- 
- ### Downstream Use [optional]
- 
- <!-- This section is for the model use when fine-tuned for a task, or when plugged into a larger ecosystem/app -->
- 
- [More Information Needed]
- 
- ### Out-of-Scope Use
- 
- <!-- This section addresses misuse, malicious use, and uses that the model will not work well for. -->
- 
- [More Information Needed]
- 
- ## Bias, Risks, and Limitations
- 
- <!-- This section is meant to convey both technical and sociotechnical limitations. -->
- 
- [More Information Needed]
- 
- ### Recommendations
- 
- <!-- This section is meant to convey recommendations with respect to the bias, risk, and technical limitations. -->
- 
- Users (both direct and downstream) should be made aware of the risks, biases and limitations of the model. More information needed for further recommendations.
- 
- ## How to Get Started with the Model
- 
- Use the code below to get started with the model.
- 
- [More Information Needed]
- 
- ## Training Details
- 
- ### Training Data
- 
- <!-- This should link to a Dataset Card, perhaps with a short stub of information on what the training data is all about as well as documentation related to data pre-processing or additional filtering. -->
- 
- [More Information Needed]
- 
- ### Training Procedure
- 
- <!-- This relates heavily to the Technical Specifications. Content here should link to that section when it is relevant to the training procedure. -->
- 
- #### Preprocessing [optional]
- 
- [More Information Needed]
- 
- 
- #### Training Hyperparameters
- 
- - **Training regime:** [More Information Needed] <!--fp32, fp16 mixed precision, bf16 mixed precision, bf16 non-mixed precision, fp16 non-mixed precision, fp8 mixed precision -->
- 
- #### Speeds, Sizes, Times [optional]
- 
- <!-- This section provides information about throughput, start/end time, checkpoint size if relevant, etc. -->
- 
- [More Information Needed]
- 
- ## Evaluation
- 
- <!-- This section describes the evaluation protocols and provides the results. -->
- 
- ### Testing Data, Factors & Metrics
- 
- #### Testing Data
- 
- <!-- This should link to a Dataset Card if possible. -->
- 
- [More Information Needed]
- 
- #### Factors
- 
- <!-- These are the things the evaluation is disaggregating by, e.g., subpopulations or domains. -->
- 
- [More Information Needed]
- 
- #### Metrics
- 
- <!-- These are the evaluation metrics being used, ideally with a description of why. -->
- 
- [More Information Needed]
- 
- ### Results
- 
- [More Information Needed]
- 
- #### Summary
- 
- 
- 
- ## Model Examination [optional]
- 
- <!-- Relevant interpretability work for the model goes here -->
- 
- [More Information Needed]
- 
- ## Environmental Impact
- 
- <!-- Total emissions (in grams of CO2eq) and additional considerations, such as electricity usage, go here. Edit the suggested text below accordingly -->
- 
- Carbon emissions can be estimated using the [Machine Learning Impact calculator](https://mlco2.github.io/impact#compute) presented in [Lacoste et al. (2019)](https://arxiv.org/abs/1910.09700).
- 
- - **Hardware Type:** [More Information Needed]
- - **Hours used:** [More Information Needed]
- - **Cloud Provider:** [More Information Needed]
- - **Compute Region:** [More Information Needed]
- - **Carbon Emitted:** [More Information Needed]
- 
- ## Technical Specifications [optional]
- 
- ### Model Architecture and Objective
- 
- [More Information Needed]
- 
- ### Compute Infrastructure
- 
- [More Information Needed]
- 
- #### Hardware
- 
- [More Information Needed]
- 
- #### Software
- 
- [More Information Needed]
- 
- ## Citation [optional]
- 
- <!-- If there is a paper or blog post introducing the model, the APA and Bibtex information for that should go in this section. -->
- 
- **BibTeX:**
- 
- [More Information Needed]
- 
- **APA:**
- 
- [More Information Needed]
- 
- ## Glossary [optional]
- 
- <!-- If relevant, include terms and calculations in this section that can help readers understand the model or model card. -->
- 
- [More Information Needed]
- 
- ## More Information [optional]
- 
- [More Information Needed]
- 
- ## Model Card Authors [optional]
- 
- [More Information Needed]
- 
- ## Model Card Contact
- 
- [More Information Needed]
- ### Framework versions
- 
- - PEFT 0.16.0
 
test_run_uploads/checkpoint-50/adapter_config.json DELETED
@@ -1,41 +0,0 @@
- {
-   "alpha_pattern": {},
-   "auto_mapping": null,
-   "base_model_name_or_path": "mistralai/Ministral-8B-Instruct-2410",
-   "bias": "none",
-   "corda_config": null,
-   "eva_config": null,
-   "exclude_modules": null,
-   "fan_in_fan_out": false,
-   "inference_mode": true,
-   "init_lora_weights": true,
-   "layer_replication": null,
-   "layers_pattern": null,
-   "layers_to_transform": null,
-   "loftq_config": {},
-   "lora_alpha": 64,
-   "lora_bias": false,
-   "lora_dropout": 0,
-   "megatron_config": null,
-   "megatron_core": "megatron.core",
-   "modules_to_save": null,
-   "peft_type": "LORA",
-   "qalora_group_size": 16,
-   "r": 32,
-   "rank_pattern": {},
-   "revision": null,
-   "target_modules": [
-     "up_proj",
-     "gate_proj",
-     "q_proj",
-     "o_proj",
-     "v_proj",
-     "down_proj",
-     "k_proj"
-   ],
-   "task_type": "CAUSAL_LM",
-   "trainable_token_indices": null,
-   "use_dora": false,
-   "use_qalora": false,
-   "use_rslora": false
- }
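
The adapter config above describes a fairly standard LoRA setup: rank 32, alpha 64, no dropout, targeting all attention and MLP projection matrices. For reference, a sketch of the equivalent `peft.LoraConfig` reconstructed from the JSON (the original training script is not part of this commit):

```python
# LoraConfig reconstructed from checkpoint-50/adapter_config.json (sketch).
from peft import LoraConfig

lora_config = LoraConfig(
    r=32,                   # "r": 32
    lora_alpha=64,          # "lora_alpha": 64
    lora_dropout=0.0,       # "lora_dropout": 0
    bias="none",            # "bias": "none"
    task_type="CAUSAL_LM",  # "task_type": "CAUSAL_LM"
    target_modules=[
        "up_proj", "gate_proj", "q_proj", "o_proj",
        "v_proj", "down_proj", "k_proj",
    ],
)
```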
 
test_run_uploads/checkpoint-50/adapter_model.safetensors DELETED
@@ -1,3 +0,0 @@
- version https://git-lfs.github.com/spec/v1
- oid sha256:7c22732e7777cc816d7e7503316cbac9b3806566322e1f6bab5d429ea8766f00
- size 349243752
 
test_run_uploads/checkpoint-50/chat_template.jinja DELETED
@@ -1 +0,0 @@
- {{ bos_token }}{% if messages[0]['role'] == 'system' %}{% if messages[1]['role'] == 'user' %}{{ '[INST] ' + messages[0]['content'] + ' ' + messages[1]['content'] + ' [/INST]' }}{% set loop_messages = messages[2:] %}{% else %}{{ '[INST] ' + messages[0]['content'] + ' [/INST]' }}{% set loop_messages = messages[1:] %}{% endif %}{% else %}{% set loop_messages = messages %}{% endif %}{% for message in loop_messages %}{% if message['role'] == 'user' %}{{ '[INST] ' + message['content'] + ' [/INST]' }}{% elif message['role'] == 'assistant' %}{{ message['content'] + eos_token }}{% else %}{{ raise_exception('Only user and assistant roles are supported!') }}{% endif %}{% endfor %}
 
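The deleted template implements the Mistral `[INST] ... [/INST]` chat format, folding an optional leading system message into the first user turn and appending `eos_token` after each assistant turn. A sketch of how it is exercised through the standard tokenizers API (messages are illustrative; assumes `tokenizer` was loaded from this checkpoint):

```python
# Render a conversation through the chat template shown above (sketch).
messages = [
    {"role": "system", "content": "You are a helpful assistant."},
    {"role": "user", "content": "Hello!"},
    {"role": "assistant", "content": "Hi there."},
]
text = tokenizer.apply_chat_template(messages, tokenize=False)
# -> "<s>[INST] You are a helpful assistant. Hello! [/INST]Hi there.</s>"
```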
 
test_run_uploads/checkpoint-50/optimizer.pt DELETED
@@ -1,3 +0,0 @@
- version https://git-lfs.github.com/spec/v1
- oid sha256:9eb84082a1889da4e199d66b7cabd1c0dbbee3a7097bc5f7aebb331e4786a6d6
- size 177918917
 
test_run_uploads/checkpoint-50/rng_state.pth DELETED
@@ -1,3 +0,0 @@
- version https://git-lfs.github.com/spec/v1
- oid sha256:181c5f0270cf39930062ddfa3767a2481d0c360f120b11f8e25dbf533a1cdaba
- size 14645
 
test_run_uploads/checkpoint-50/scaler.pt DELETED
@@ -1,3 +0,0 @@
- version https://git-lfs.github.com/spec/v1
- oid sha256:5cd0e9d505fbc3f97feb166d29026132bdf14eb3e5c7ff77beebc303ee666f96
- size 1383
 
test_run_uploads/checkpoint-50/scheduler.pt DELETED
@@ -1,3 +0,0 @@
- version https://git-lfs.github.com/spec/v1
- oid sha256:3f43a6155628947732c83ac3165bbc211721c396e9e3b246bdecdaaf19583e1c
- size 1465
 
test_run_uploads/checkpoint-50/special_tokens_map.json DELETED
@@ -1,24 +0,0 @@
- {
-   "bos_token": {
-     "content": "<s>",
-     "lstrip": false,
-     "normalized": false,
-     "rstrip": false,
-     "single_word": false
-   },
-   "eos_token": {
-     "content": "</s>",
-     "lstrip": false,
-     "normalized": false,
-     "rstrip": false,
-     "single_word": false
-   },
-   "pad_token": "<pad>",
-   "unk_token": {
-     "content": "<unk>",
-     "lstrip": false,
-     "normalized": false,
-     "rstrip": false,
-     "single_word": false
-   }
- }
 
test_run_uploads/checkpoint-50/tokenizer.json DELETED
@@ -1,3 +0,0 @@
- version https://git-lfs.github.com/spec/v1
- oid sha256:a7fc0f8e08693e6deb5bbb0cd3ab7431131567cc69bd3a67fd6da0e3c7ee58e4
- size 17078391
 
test_run_uploads/checkpoint-50/tokenizer_config.json DELETED
The diff for this file is too large to render. See raw diff
 
test_run_uploads/checkpoint-50/trainer_state.json DELETED
@@ -1,77 +0,0 @@
- {
-   "best_global_step": null,
-   "best_metric": Infinity,
-   "best_model_checkpoint": null,
-   "epoch": 0.007038783698176955,
-   "eval_steps": 50,
-   "global_step": 50,
-   "is_hyper_param_search": false,
-   "is_local_process_zero": true,
-   "is_world_process_zero": true,
-   "log_history": [
-     {
-       "epoch": 0.001407756739635391,
-       "grad_norm": 4.217477321624756,
-       "learning_rate": 1.8e-06,
-       "loss": 1.9593,
-       "step": 10
-     },
-     {
-       "epoch": 0.002815513479270782,
-       "grad_norm": 6.792465686798096,
-       "learning_rate": 3.8e-06,
-       "loss": 1.8226,
-       "step": 20
-     },
-     {
-       "epoch": 0.004223270218906173,
-       "grad_norm": 3.987929344177246,
-       "learning_rate": 5.8e-06,
-       "loss": 1.5628,
-       "step": 30
-     },
-     {
-       "epoch": 0.005631026958541564,
-       "grad_norm": 3.203339099884033,
-       "learning_rate": 7.8e-06,
-       "loss": 1.2142,
-       "step": 40
-     },
-     {
-       "epoch": 0.007038783698176955,
-       "grad_norm": 4.646796226501465,
-       "learning_rate": 9.800000000000001e-06,
-       "loss": 0.8943,
-       "step": 50
-     },
-     {
-       "epoch": 0.007038783698176955,
-       "eval_loss": NaN,
-       "eval_runtime": 3184.6841,
-       "eval_samples_per_second": 1.093,
-       "eval_steps_per_second": 0.182,
-       "step": 50
-     }
-   ],
-   "logging_steps": 10,
-   "max_steps": 90,
-   "num_input_tokens_seen": 0,
-   "num_train_epochs": 1,
-   "save_steps": 50,
-   "stateful_callbacks": {
-     "TrainerControl": {
-       "args": {
-         "should_epoch_stop": false,
-         "should_evaluate": false,
-         "should_log": false,
-         "should_save": true,
-         "should_training_stop": false
-       },
-       "attributes": {}
-     }
-   },
-   "total_flos": 9110440274558976.0,
-   "train_batch_size": 2,
-   "trial_name": null,
-   "trial_params": null
- }
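
Two details of this state are worth flagging: the training loss drops steadily (1.96 → 0.89 over 50 steps) while `eval_loss` is `NaN`, which, together with the presence of `scaler.pt`, suggests a numerical problem (e.g., fp16 overflow) during evaluation rather than during training. A quick way to inspect such a log, assuming the deleted files are still available locally:

```python
# Print the training/eval log from a saved checkpoint (sketch; path assumes
# the files deleted in this commit exist locally). Python's json module
# accepts the non-standard NaN/Infinity literals this file contains.
import json

with open("test_run_uploads/checkpoint-50/trainer_state.json") as f:
    state = json.load(f)

for entry in state["log_history"]:
    loss = entry.get("loss", entry.get("eval_loss"))
    print(f"step {entry['step']:3d}: loss={loss}")
```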
 
test_run_uploads/checkpoint-50/training_args.bin DELETED
@@ -1,3 +0,0 @@
- version https://git-lfs.github.com/spec/v1
- oid sha256:97c4848f189bc8ef55d633cdc5629ac09caf902f18ecbe802fee52f91633580d
- size 6097
 
test_run_uploads/checkpoint-90/README.md DELETED
@@ -1,210 +0,0 @@
- ---
- base_model: mistralai/Ministral-8B-Instruct-2410
- library_name: peft
- pipeline_tag: text-generation
- tags:
- - base_model:adapter:mistralai/Ministral-8B-Instruct-2410
- - lora
- - sft
- - transformers
- - trl
- - unsloth
- ---
- 
- # Model Card for Model ID
- 
- <!-- Provide a quick summary of what the model is/does. -->
- 
- 
- 
- ## Model Details
- 
- ### Model Description
- 
- <!-- Provide a longer summary of what this model is. -->
- 
- 
- 
- - **Developed by:** [More Information Needed]
- - **Funded by [optional]:** [More Information Needed]
- - **Shared by [optional]:** [More Information Needed]
- - **Model type:** [More Information Needed]
- - **Language(s) (NLP):** [More Information Needed]
- - **License:** [More Information Needed]
- - **Finetuned from model [optional]:** [More Information Needed]
- 
- ### Model Sources [optional]
- 
- <!-- Provide the basic links for the model. -->
- 
- - **Repository:** [More Information Needed]
- - **Paper [optional]:** [More Information Needed]
- - **Demo [optional]:** [More Information Needed]
- 
- ## Uses
- 
- <!-- Address questions around how the model is intended to be used, including the foreseeable users of the model and those affected by the model. -->
- 
- ### Direct Use
- 
- <!-- This section is for the model use without fine-tuning or plugging into a larger ecosystem/app. -->
- 
- [More Information Needed]
- 
- ### Downstream Use [optional]
- 
- <!-- This section is for the model use when fine-tuned for a task, or when plugged into a larger ecosystem/app -->
- 
- [More Information Needed]
- 
- ### Out-of-Scope Use
- 
- <!-- This section addresses misuse, malicious use, and uses that the model will not work well for. -->
- 
- [More Information Needed]
- 
- ## Bias, Risks, and Limitations
- 
- <!-- This section is meant to convey both technical and sociotechnical limitations. -->
- 
- [More Information Needed]
- 
- ### Recommendations
- 
- <!-- This section is meant to convey recommendations with respect to the bias, risk, and technical limitations. -->
- 
- Users (both direct and downstream) should be made aware of the risks, biases and limitations of the model. More information needed for further recommendations.
- 
- ## How to Get Started with the Model
- 
- Use the code below to get started with the model.
- 
- [More Information Needed]
- 
- ## Training Details
- 
- ### Training Data
- 
- <!-- This should link to a Dataset Card, perhaps with a short stub of information on what the training data is all about as well as documentation related to data pre-processing or additional filtering. -->
- 
- [More Information Needed]
- 
- ### Training Procedure
- 
- <!-- This relates heavily to the Technical Specifications. Content here should link to that section when it is relevant to the training procedure. -->
- 
- #### Preprocessing [optional]
- 
- [More Information Needed]
- 
- 
- #### Training Hyperparameters
- 
- - **Training regime:** [More Information Needed] <!--fp32, fp16 mixed precision, bf16 mixed precision, bf16 non-mixed precision, fp16 non-mixed precision, fp8 mixed precision -->
- 
- #### Speeds, Sizes, Times [optional]
- 
- <!-- This section provides information about throughput, start/end time, checkpoint size if relevant, etc. -->
- 
- [More Information Needed]
- 
- ## Evaluation
- 
- <!-- This section describes the evaluation protocols and provides the results. -->
- 
- ### Testing Data, Factors & Metrics
- 
- #### Testing Data
- 
- <!-- This should link to a Dataset Card if possible. -->
- 
- [More Information Needed]
- 
- #### Factors
- 
- <!-- These are the things the evaluation is disaggregating by, e.g., subpopulations or domains. -->
- 
- [More Information Needed]
- 
- #### Metrics
- 
- <!-- These are the evaluation metrics being used, ideally with a description of why. -->
- 
- [More Information Needed]
- 
- ### Results
- 
- [More Information Needed]
- 
- #### Summary
- 
- 
- 
- ## Model Examination [optional]
- 
- <!-- Relevant interpretability work for the model goes here -->
- 
- [More Information Needed]
- 
- ## Environmental Impact
- 
- <!-- Total emissions (in grams of CO2eq) and additional considerations, such as electricity usage, go here. Edit the suggested text below accordingly -->
- 
- Carbon emissions can be estimated using the [Machine Learning Impact calculator](https://mlco2.github.io/impact#compute) presented in [Lacoste et al. (2019)](https://arxiv.org/abs/1910.09700).
- 
- - **Hardware Type:** [More Information Needed]
- - **Hours used:** [More Information Needed]
- - **Cloud Provider:** [More Information Needed]
- - **Compute Region:** [More Information Needed]
- - **Carbon Emitted:** [More Information Needed]
- 
- ## Technical Specifications [optional]
- 
- ### Model Architecture and Objective
- 
- [More Information Needed]
- 
- ### Compute Infrastructure
- 
- [More Information Needed]
- 
- #### Hardware
- 
- [More Information Needed]
- 
- #### Software
- 
- [More Information Needed]
- 
- ## Citation [optional]
- 
- <!-- If there is a paper or blog post introducing the model, the APA and Bibtex information for that should go in this section. -->
- 
- **BibTeX:**
- 
- [More Information Needed]
- 
- **APA:**
- 
- [More Information Needed]
- 
- ## Glossary [optional]
- 
- <!-- If relevant, include terms and calculations in this section that can help readers understand the model or model card. -->
- 
- [More Information Needed]
- 
- ## More Information [optional]
- 
- [More Information Needed]
- 
- ## Model Card Authors [optional]
- 
- [More Information Needed]
- 
- ## Model Card Contact
- 
- [More Information Needed]
- ### Framework versions
- 
- - PEFT 0.16.0
 
test_run_uploads/checkpoint-90/adapter_config.json DELETED
@@ -1,41 +0,0 @@
- {
-   "alpha_pattern": {},
-   "auto_mapping": null,
-   "base_model_name_or_path": "mistralai/Ministral-8B-Instruct-2410",
-   "bias": "none",
-   "corda_config": null,
-   "eva_config": null,
-   "exclude_modules": null,
-   "fan_in_fan_out": false,
-   "inference_mode": true,
-   "init_lora_weights": true,
-   "layer_replication": null,
-   "layers_pattern": null,
-   "layers_to_transform": null,
-   "loftq_config": {},
-   "lora_alpha": 64,
-   "lora_bias": false,
-   "lora_dropout": 0,
-   "megatron_config": null,
-   "megatron_core": "megatron.core",
-   "modules_to_save": null,
-   "peft_type": "LORA",
-   "qalora_group_size": 16,
-   "r": 32,
-   "rank_pattern": {},
-   "revision": null,
-   "target_modules": [
-     "up_proj",
-     "gate_proj",
-     "q_proj",
-     "o_proj",
-     "v_proj",
-     "down_proj",
-     "k_proj"
-   ],
-   "task_type": "CAUSAL_LM",
-   "trainable_token_indices": null,
-   "use_dora": false,
-   "use_qalora": false,
-   "use_rslora": false
- }
 
test_run_uploads/checkpoint-90/adapter_model.safetensors DELETED
@@ -1,3 +0,0 @@
- version https://git-lfs.github.com/spec/v1
- oid sha256:ef1e45c7329d233afbb23de1796591557b08e47a395529287f0ddf873bd719d9
- size 349243752