# trainer.py
from transformers import Trainer, TrainingArguments
from torch.utils.data import DataLoader
import torch
import torch.nn.functional as F


class MTPTrainer(Trainer):
    """
    Custom trainer for Multi-Token Prediction training.
    """

    def __init__(self, model, args=None, train_dataset=None, eval_dataset=None, **kwargs):
        super().__init__(
            model=model,
            args=args,
            train_dataset=train_dataset,
            eval_dataset=eval_dataset,
            **kwargs
        )
        self.use_mtp_training = True
    def compute_loss(self, model, inputs, return_outputs=False, **kwargs):
        """
        Compute loss during training - handles both MTP and standard LM training.
        `**kwargs` absorbs extra arguments (e.g. num_items_in_batch) passed by
        newer transformers versions.
        """
        # Pop labels so they are not passed twice when `inputs` is unpacked below.
        labels = inputs.pop("labels", None)
        outputs = model(**inputs, labels=labels, use_mtp_training=self.use_mtp_training)
        loss = outputs.loss
        return (loss, outputs) if return_outputs else loss
    def train_step_with_mtp(self, model, inputs):
        """
        Specialized training step for MTP training.
        """
        model.set_mtp_training(True)
        return self.training_step(model, inputs)

    def train_step_with_lm(self, model, inputs):
        """
        Standard language modeling training step.
        """
        model.set_mtp_training(False)
        return self.training_step(model, inputs)
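
# --- Illustrative usage sketch (not part of the original file) ---
# `model` is assumed to be a model that accepts a `use_mtp_training` flag in its
# forward pass and exposes `set_mtp_training`; `train_ds` is a tokenized dataset.
# Both names are placeholders, not objects defined in this repository.
#
#   training_args = TrainingArguments(
#       output_dir="./mtp-checkpoints",
#       per_device_train_batch_size=8,
#       num_train_epochs=1,
#       logging_steps=50,
#   )
#   trainer = MTPTrainer(model=model, args=training_args, train_dataset=train_ds)
#   trainer.train()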


class RLTrainer:
    """
    Separate trainer class for Reinforcement Learning training.
    This can use the generate_with_logprobs method from your model.
    """

    def __init__(self, model, tokenizer, rl_config=None):
        self.model = model
        self.tokenizer = tokenizer
        self.rl_config = rl_config or {}
    def rl_training_step(self, input_ids, old_log_probs, old_action_probs, **kwargs):
        """
        Perform an RL training step using the model's generate_with_logprobs method
        and the reward functions from rl_utils.
        """
        # Import RL utilities
        from .rl_utils import (
            batch_compute_rewards,
            compute_ppo_loss,
            compute_kl_divergence,
            compute_entropy_bonus,
            AdaptiveKLController
        )
        # This would call the generate_with_logprobs method from your model
        # and then compute RL-specific losses.
        pass
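

# --- Illustrative sketch (assumption: NOT the repository's rl_utils implementation) ---
# A minimal PPO clipped-surrogate loss in plain torch, showing the kind of objective
# rl_training_step would optimize once generation, rewards, and advantages are wired in.
def _example_ppo_clip_loss(new_log_probs: torch.Tensor,
                           old_log_probs: torch.Tensor,
                           advantages: torch.Tensor,
                           clip_eps: float = 0.2) -> torch.Tensor:
    # Probability ratio between the current policy and the policy that generated the data.
    ratio = torch.exp(new_log_probs - old_log_probs)
    # PPO clipped surrogate objective; return its negative so it can be minimized.
    unclipped = ratio * advantages
    clipped = torch.clamp(ratio, 1.0 - clip_eps, 1.0 + clip_eps) * advantages
    return -torch.min(unclipped, clipped).mean()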