|
|
|
|
|
from transformers import Trainer, TrainingArguments
|
|
|
from torch.utils.data import DataLoader
|
|
|
import torch
|
|
|
import torch.nn.functional as F
|
|
|
|
|
|
class MTPTrainer(Trainer):
    """Hugging Face ``Trainer`` subclass for Multi-Token Prediction (MTP) training.

    The trainer keeps a ``use_mtp_training`` flag that is forwarded to the
    model's forward pass on every loss computation. The two ``train_step_with_*``
    helpers switch both the model and this flag between MTP and standard
    language-modeling mode before running a regular training step.
    """

    def __init__(self, model, args=None, train_dataset=None, eval_dataset=None, **kwargs):
        """Initialize the underlying ``Trainer`` and enable MTP mode.

        Args:
            model: Model to train; forwarded to ``Trainer``.
            args: Training arguments; forwarded to ``Trainer``.
            train_dataset: Training dataset; forwarded to ``Trainer``.
            eval_dataset: Evaluation dataset; forwarded to ``Trainer``.
            **kwargs: Any further keyword arguments accepted by ``Trainer``.
        """
        super().__init__(
            model=model,
            args=args,
            train_dataset=train_dataset,
            eval_dataset=eval_dataset,
            **kwargs,
        )
        # MTP mode is on by default; the step helpers below may toggle it.
        self.use_mtp_training = True

    def compute_loss(self, model, inputs, return_outputs=False, num_items_in_batch=None):
        """Run a forward pass and return the model-reported loss.

        Args:
            model: Callable model that accepts the batch as keyword arguments
                plus ``labels`` and ``use_mtp_training``, and returns an object
                exposing ``.loss``.
            inputs: Mapping of batch fields. ``labels`` is read from it when
                present; otherwise ``labels=None`` is passed to the model.
            return_outputs: When true, return ``(loss, outputs)`` instead of
                just the loss.
            num_items_in_batch: Accepted so that ``Trainer`` versions which pass
                this keyword to ``compute_loss`` do not fail; it is not used.

        Returns:
            ``outputs.loss``, or the tuple ``(outputs.loss, outputs)`` when
            ``return_outputs`` is true.
        """
        # Merge everything into one dict before unpacking: writing
        # ``model(**inputs, labels=...)`` raises TypeError ("multiple values for
        # keyword argument 'labels'") whenever the batch already holds labels.
        forward_kwargs = dict(inputs)
        forward_kwargs["labels"] = inputs.get("labels")
        forward_kwargs["use_mtp_training"] = self.use_mtp_training

        outputs = model(**forward_kwargs)
        loss = outputs.loss

        return (loss, outputs) if return_outputs else loss

    def train_step_with_mtp(self, model, inputs):
        """Switch to MTP mode and run one training step.

        Args:
            model: Model exposing ``set_mtp_training``.
            inputs: Batch passed through to ``training_step``.

        Returns:
            Whatever ``self.training_step(model, inputs)`` returns.
        """
        model.set_mtp_training(True)
        # Keep the flag forwarded by compute_loss in agreement with the model.
        self.use_mtp_training = True
        return self.training_step(model, inputs)

    def train_step_with_lm(self, model, inputs):
        """Switch to standard language-modeling mode and run one training step.

        Args:
            model: Model exposing ``set_mtp_training``.
            inputs: Batch passed through to ``training_step``.

        Returns:
            Whatever ``self.training_step(model, inputs)`` returns.
        """
        model.set_mtp_training(False)
        # Without this, compute_loss would still forward use_mtp_training=True
        # and contradict the mode that was just set on the model.
        self.use_mtp_training = False
        return self.training_step(model, inputs)
|
|
|
|
|
|
|
|
|
class RLTrainer:
    """Standalone trainer holder for Reinforcement Learning training.

    Stores the model, tokenizer and an RL configuration mapping. It is kept
    separate from the supervised trainer and is intended to work with the
    model's ``generate_with_logprobs`` method.
    """

    def __init__(self, model, tokenizer, rl_config=None):
        """Store the collaborators used for RL training.

        Args:
            model: Model to be trained.
            tokenizer: Tokenizer paired with the model.
            rl_config: Optional RL configuration; any falsy value is replaced
                by a fresh empty dict.
        """
        self.model = model
        self.tokenizer = tokenizer
        self.rl_config = rl_config if rl_config else {}

    def rl_training_step(self, input_ids, old_log_probs, old_action_probs, **kwargs):
        """Placeholder for a single RL training step.

        The body currently only imports the reward/PPO helpers from
        ``rl_utils`` and returns ``None``; none of the arguments are used yet.

        Args:
            input_ids: Not used by the current stub.
            old_log_probs: Not used by the current stub.
            old_action_probs: Not used by the current stub.
            **kwargs: Not used by the current stub.

        Returns:
            None.
        """
        # Imported locally, so rl_utils is only loaded when this method is called.
        from .rl_utils import (  # noqa: F401
            batch_compute_rewards,
            compute_ppo_loss,
            compute_kl_divergence,
            compute_entropy_bonus,
            AdaptiveKLController,
        )

        return None