# trainer.py

from transformers import Trainer, TrainingArguments
from torch.utils.data import DataLoader
import torch
import torch.nn.functional as F


class MTPTrainer(Trainer):
    """
    Custom trainer for Multi-Token Prediction (MTP) training.
    """

    def __init__(self, model, args=None, train_dataset=None, eval_dataset=None, **kwargs):
        super().__init__(
            model=model,
            args=args,
            train_dataset=train_dataset,
            eval_dataset=eval_dataset,
            **kwargs,
        )
        self.use_mtp_training = True

    def compute_loss(self, model, inputs, return_outputs=False, **kwargs):
        """
        Compute the loss during training - handles both MTP and standard LM training.

        Accepts **kwargs because newer versions of transformers pass extra
        arguments (e.g. num_items_in_batch) to compute_loss.
        """
        # `inputs` already contains "labels", so we must not pass them again as a
        # separate keyword argument (that would raise a duplicate-argument error).
        outputs = model(**inputs, use_mtp_training=self.use_mtp_training)
        loss = outputs.loss
        return (loss, outputs) if return_outputs else loss

    def train_step_with_mtp(self, model, inputs):
        """
        Specialized training step for MTP training.
        """
        model.set_mtp_training(True)
        return self.training_step(model, inputs)

    def train_step_with_lm(self, model, inputs):
        """
        Standard language modeling training step.
        """
        model.set_mtp_training(False)
        return self.training_step(model, inputs)


class RLTrainer:
    """
    Separate trainer class for Reinforcement Learning (RL) training.
    It can use the generate_with_logprobs method from your model.
    """

    def __init__(self, model, tokenizer, rl_config=None):
        self.model = model
        self.tokenizer = tokenizer
        self.rl_config = rl_config or {}

    def rl_training_step(self, input_ids, old_log_probs, old_action_probs, **kwargs):
        """
        Perform an RL training step using the model's generate_with_logprobs
        method and the reward functions from rl_utils.
        """
        # Import RL utilities lazily to avoid a hard dependency at module import time.
        from .rl_utils import (
            batch_compute_rewards,
            compute_ppo_loss,
            compute_kl_divergence,
            compute_entropy_bonus,
            AdaptiveKLController,
        )

        # This would call the generate_with_logprobs method from your model
        # and then compute RL-specific losses.
        pass
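

# --- Illustrative usage sketch ------------------------------------------------
# The commented-out snippet below is a minimal sketch of how MTPTrainer might be
# wired up, assuming a model class that exposes set_mtp_training() and accepts a
# use_mtp_training keyword in forward(). `MTPModel`, the checkpoint path, and the
# dataset variables are hypothetical placeholders, not part of this file's API.
#
# from .modeling import MTPModel  # hypothetical model with set_mtp_training()
#
# training_args = TrainingArguments(
#     output_dir="./mtp-checkpoints",
#     per_device_train_batch_size=4,
#     num_train_epochs=1,
#     logging_steps=10,
# )
# trainer = MTPTrainer(
#     model=MTPModel.from_pretrained("path/to/base-model"),
#     args=training_args,
#     train_dataset=train_dataset,   # tokenized dataset yielding input_ids / labels
#     eval_dataset=eval_dataset,
# )
# trainer.train()
#
# RLTrainer is used separately once rl_training_step is implemented, e.g.
# RLTrainer(model, tokenizer, rl_config={"kl_coef": 0.1}); the config keys shown
# here are assumptions about what rl_utils expects.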