#!/usr/bin/env python3
"""
Helion-2.5-Rnd Evaluation Script
Comprehensive benchmark evaluation across multiple datasets
"""
import argparse
import json
import logging
import re
from pathlib import Path
from typing import Dict, Optional

import torch
from datasets import load_dataset
from tqdm import tqdm
from transformers import AutoModelForCausalLM, AutoTokenizer
logging.basicConfig(
level=logging.INFO,
format='%(asctime)s - %(name)s - %(levelname)s - %(message)s'
)
logger = logging.getLogger(__name__)
class HelionEvaluator:
"""Evaluation framework for Helion model"""
def __init__(
self,
model_path: str,
device: str = "cuda",
batch_size: int = 1,
max_length: int = 2048
):
"""
Initialize evaluator
Args:
model_path: Path to model or HuggingFace model ID
device: Device to run evaluation on
batch_size: Batch size for evaluation
max_length: Maximum sequence length
"""
logger.info(f"Loading model from {model_path}")
        self.tokenizer = AutoTokenizer.from_pretrained(model_path)
        # Fall back to the EOS token for padding if the tokenizer does not define one,
        # so that generate() receives a valid pad_token_id
        if self.tokenizer.pad_token is None:
            self.tokenizer.pad_token = self.tokenizer.eos_token
self.model = AutoModelForCausalLM.from_pretrained(
model_path,
torch_dtype=torch.bfloat16,
device_map="auto",
trust_remote_code=True
)
        # device_map="auto" already places the model weights; `device` is only used
        # to move input tensors in generate()
        self.model_path = model_path
        self.device = device
        self.batch_size = batch_size
        self.max_length = max_length
logger.info("Model loaded successfully")
def generate(
self,
prompt: str,
max_new_tokens: int = 512,
temperature: float = 0.0,
**kwargs
) -> str:
"""Generate text from prompt"""
inputs = self.tokenizer(
prompt,
return_tensors="pt",
truncation=True,
max_length=self.max_length
).to(self.device)
with torch.no_grad():
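            # temperature == 0 means greedy decoding; a dummy temperature of 1.0 is
            # passed in that case only to avoid the transformers warning about
            # temperature being set while do_sample=False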
outputs = self.model.generate(
**inputs,
max_new_tokens=max_new_tokens,
temperature=temperature if temperature > 0 else 1.0,
do_sample=temperature > 0,
pad_token_id=self.tokenizer.pad_token_id,
**kwargs
)
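        # Decode only the newly generated tokens, skipping the prompt portion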
response = self.tokenizer.decode(
outputs[0][inputs['input_ids'].shape[1]:],
skip_special_tokens=True
)
return response.strip()
def evaluate_mmlu(self, num_samples: Optional[int] = None) -> Dict:
"""Evaluate on MMLU benchmark"""
logger.info("Evaluating on MMLU...")
dataset = load_dataset("cais/mmlu", "all", split="test")
if num_samples:
dataset = dataset.select(range(min(num_samples, len(dataset))))
correct = 0
total = 0
for example in tqdm(dataset, desc="MMLU"):
question = example["question"]
choices = example["choices"]
answer = example["answer"]
# Format prompt
prompt = f"Question: {question}\n\nChoices:\n"
for i, choice in enumerate(choices):
prompt += f"{chr(65+i)}. {choice}\n"
prompt += "\nAnswer: "
# Generate response
response = self.generate(prompt, max_new_tokens=10, temperature=0.0)
            # Extract the predicted letter; guard against empty or whitespace-only generations
            stripped = response.strip()
            pred = stripped[0].upper() if stripped else ""
            correct_answer = chr(65 + answer)
if pred == correct_answer:
correct += 1
total += 1
accuracy = correct / total if total > 0 else 0
return {
"benchmark": "MMLU",
"accuracy": accuracy,
"correct": correct,
"total": total
}
def evaluate_gsm8k(self, num_samples: Optional[int] = None) -> Dict:
"""Evaluate on GSM8K mathematical reasoning"""
logger.info("Evaluating on GSM8K...")
dataset = load_dataset("gsm8k", "main", split="test")
if num_samples:
dataset = dataset.select(range(min(num_samples, len(dataset))))
correct = 0
total = 0
for example in tqdm(dataset, desc="GSM8K"):
question = example["question"]
answer = example["answer"]
            # Extract the reference numerical answer (the number after the "####" marker)
            match = re.search(r'####\s*(-?\d+(?:,\d+)*(?:\.\d+)?)', answer)
if not match:
continue
correct_answer = match.group(1).replace(',', '')
# Format prompt
prompt = f"Question: {question}\n\nLet's solve this step by step:\n"
# Generate response
response = self.generate(prompt, max_new_tokens=512, temperature=0.0)
            # Extract the predicted answer: take the last number following "answer is" or "=",
            # since earlier matches usually belong to intermediate steps
            pred_matches = re.findall(r'(?:answer is|=)\s*(-?\d+(?:,\d+)*(?:\.\d+)?)', response.lower())
            if pred_matches:
                pred_answer = pred_matches[-1].replace(',', '')
if pred_answer == correct_answer:
correct += 1
total += 1
accuracy = correct / total if total > 0 else 0
return {
"benchmark": "GSM8K",
"accuracy": accuracy,
"correct": correct,
"total": total
}
def evaluate_humaneval(self, num_samples: Optional[int] = None) -> Dict:
"""Evaluate on HumanEval code generation"""
logger.info("Evaluating on HumanEval...")
try:
dataset = load_dataset("openai_humaneval", split="test")
        except Exception as e:
            logger.warning(f"HumanEval dataset not available: {e}")
return {"benchmark": "HumanEval", "error": "Dataset not available"}
if num_samples:
dataset = dataset.select(range(min(num_samples, len(dataset))))
results = []
for example in tqdm(dataset, desc="HumanEval"):
prompt = example["prompt"]
# Generate code
full_prompt = f"Complete the following Python function:\n\n{prompt}"
response = self.generate(
full_prompt,
max_new_tokens=512,
temperature=0.0
)
# Extract code
code = prompt + response
results.append({
"task_id": example["task_id"],
"completion": code,
"test": example["test"]
})
# Note: Full evaluation requires executing code
# This is a simplified version
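        # A full pass@k score would execute each completion against example["test"]
        # in a sandboxed environment, e.g. by writing `results` to a JSONL file and
        # running OpenAI's human-eval harness (evaluate_functional_correctness);
        # that step is intentionally out of scope here.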
return {
"benchmark": "HumanEval",
"samples_generated": len(results),
"note": "Full evaluation requires code execution framework"
}
def evaluate_truthfulqa(self, num_samples: Optional[int] = None) -> Dict:
"""Evaluate on TruthfulQA"""
logger.info("Evaluating on TruthfulQA...")
dataset = load_dataset("truthful_qa", "generation", split="validation")
if num_samples:
dataset = dataset.select(range(min(num_samples, len(dataset))))
responses = []
for example in tqdm(dataset, desc="TruthfulQA"):
question = example["question"]
prompt = f"Question: {question}\n\nProvide a truthful and accurate answer:\nAnswer: "
response = self.generate(prompt, max_new_tokens=256, temperature=0.0)
responses.append({
"question": question,
"response": response,
"best_answer": example["best_answer"],
"correct_answers": example["correct_answers"],
"incorrect_answers": example["incorrect_answers"]
})
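        # Automated truthfulness scoring typically uses a judge model or the
        # multiple-choice variant of TruthfulQA; this method only collects raw
        # generations alongside the reference answers for later review.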
return {
"benchmark": "TruthfulQA",
"samples_evaluated": len(responses),
"note": "Manual review required for truthfulness assessment"
}
def evaluate_all(
self,
output_file: Optional[str] = None,
num_samples: Optional[int] = None
) -> Dict:
"""Run all evaluations"""
logger.info("Starting comprehensive evaluation...")
        results = {
            "model": self.model_path,
            "benchmarks": {}
        }
# Run evaluations
try:
results["benchmarks"]["mmlu"] = self.evaluate_mmlu(num_samples)
except Exception as e:
logger.error(f"MMLU evaluation failed: {e}")
results["benchmarks"]["mmlu"] = {"error": str(e)}
try:
results["benchmarks"]["gsm8k"] = self.evaluate_gsm8k(num_samples)
except Exception as e:
logger.error(f"GSM8K evaluation failed: {e}")
results["benchmarks"]["gsm8k"] = {"error": str(e)}
try:
results["benchmarks"]["humaneval"] = self.evaluate_humaneval(num_samples)
except Exception as e:
logger.error(f"HumanEval evaluation failed: {e}")
results["benchmarks"]["humaneval"] = {"error": str(e)}
try:
results["benchmarks"]["truthfulqa"] = self.evaluate_truthfulqa(num_samples)
except Exception as e:
logger.error(f"TruthfulQA evaluation failed: {e}")
results["benchmarks"]["truthfulqa"] = {"error": str(e)}
# Save results
if output_file:
output_path = Path(output_file)
output_path.parent.mkdir(parents=True, exist_ok=True)
with open(output_path, 'w') as f:
json.dump(results, f, indent=2)
logger.info(f"Results saved to {output_path}")
# Print summary
logger.info("\n" + "="*50)
logger.info("EVALUATION SUMMARY")
logger.info("="*50)
for benchmark, result in results["benchmarks"].items():
if "accuracy" in result:
logger.info(f"{benchmark.upper()}: {result['accuracy']:.2%}")
elif "error" in result:
logger.info(f"{benchmark.upper()}: ERROR - {result['error']}")
else:
logger.info(f"{benchmark.upper()}: {result.get('note', 'Completed')}")
return results
def main():
"""Main evaluation entry point"""
parser = argparse.ArgumentParser(description="Evaluate Helion model")
parser.add_argument(
"--model",
type=str,
required=True,
help="Model path or HuggingFace ID"
)
parser.add_argument(
"--benchmarks",
type=str,
nargs="+",
default=["all"],
choices=["all", "mmlu", "gsm8k", "humaneval", "truthfulqa"],
help="Benchmarks to run"
)
parser.add_argument(
"--output",
type=str,
default="evaluation_results.json",
help="Output file for results"
)
parser.add_argument(
"--num-samples",
type=int,
default=None,
help="Number of samples to evaluate (for quick testing)"
)
parser.add_argument(
"--device",
type=str,
default="cuda",
help="Device to use"
)
parser.add_argument(
"--batch-size",
type=int,
default=1,
help="Batch size"
)
args = parser.parse_args()
# Initialize evaluator
evaluator = HelionEvaluator(
model_path=args.model,
device=args.device,
batch_size=args.batch_size
)
# Run evaluations
if "all" in args.benchmarks:
results = evaluator.evaluate_all(
output_file=args.output,
num_samples=args.num_samples
)
else:
results = {"model": args.model, "benchmarks": {}}
if "mmlu" in args.benchmarks:
results["benchmarks"]["mmlu"] = evaluator.evaluate_mmlu(args.num_samples)
if "gsm8k" in args.benchmarks:
results["benchmarks"]["gsm8k"] = evaluator.evaluate_gsm8k(args.num_samples)
if "humaneval" in args.benchmarks:
results["benchmarks"]["humaneval"] = evaluator.evaluate_humaneval(args.num_samples)
if "truthfulqa" in args.benchmarks:
results["benchmarks"]["truthfulqa"] = evaluator.evaluate_truthfulqa(args.num_samples)
# Save results
with open(args.output, 'w') as f:
json.dump(results, f, indent=2)
logger.info(f"Results saved to {args.output}")
if __name__ == "__main__":
main()