Trouter-Library committed on
Commit f15baf7 · verified · 1 parent: a4013c5

Create inference/evaluate.py

Files changed (1)
  1. inference/evaluate.py +401 -0
inference/evaluate.py ADDED
@@ -0,0 +1,401 @@
+ #!/usr/bin/env python3
+ """
+ Helion-2.5-Rnd Evaluation Script
+ Comprehensive benchmark evaluation across multiple datasets
+ """
+
+ import argparse
+ import json
+ import logging
+ import re
+ from pathlib import Path
+ from typing import Dict, Optional
+
+ import torch
+ from datasets import load_dataset
+ from tqdm import tqdm
+ from transformers import AutoModelForCausalLM, AutoTokenizer
+
+ logging.basicConfig(
+     level=logging.INFO,
+     format='%(asctime)s - %(name)s - %(levelname)s - %(message)s'
+ )
+ logger = logging.getLogger(__name__)
+
+
+ class HelionEvaluator:
+     """Evaluation framework for the Helion model"""
+
+     def __init__(
+         self,
+         model_path: str,
+         device: str = "cuda",
+         batch_size: int = 1,
+         max_length: int = 2048
+     ):
+         """
+         Initialize evaluator
+
+         Args:
+             model_path: Path to model or HuggingFace model ID
+             device: Device to run evaluation on
+             batch_size: Batch size for evaluation
+             max_length: Maximum sequence length
+         """
+         logger.info(f"Loading model from {model_path}")
+
+         self.tokenizer = AutoTokenizer.from_pretrained(model_path)
+         if self.tokenizer.pad_token is None:
+             # Many causal LM tokenizers ship without a pad token; fall back to EOS.
+             self.tokenizer.pad_token = self.tokenizer.eos_token
+         self.model = AutoModelForCausalLM.from_pretrained(
+             model_path,
+             torch_dtype=torch.bfloat16,
+             device_map="auto",
+             trust_remote_code=True
+         )
+
+         self.device = device
+         self.batch_size = batch_size
+         self.max_length = max_length
+
+         logger.info("Model loaded successfully")
+
+     def generate(
+         self,
+         prompt: str,
+         max_new_tokens: int = 512,
+         temperature: float = 0.0,
+         **kwargs
+     ) -> str:
+         """Generate text from prompt"""
+         # Place inputs on the model's device (device_map="auto" decides placement).
+         inputs = self.tokenizer(
+             prompt,
+             return_tensors="pt",
+             truncation=True,
+             max_length=self.max_length
+         ).to(self.model.device)
+
+         with torch.no_grad():
+             outputs = self.model.generate(
+                 **inputs,
+                 max_new_tokens=max_new_tokens,
+                 temperature=temperature if temperature > 0 else 1.0,
+                 do_sample=temperature > 0,
+                 pad_token_id=self.tokenizer.pad_token_id,
+                 **kwargs
+             )
+
+         # Decode only the newly generated tokens.
+         response = self.tokenizer.decode(
+             outputs[0][inputs['input_ids'].shape[1]:],
+             skip_special_tokens=True
+         )
+
+         return response.strip()
+
+     def evaluate_mmlu(self, num_samples: Optional[int] = None) -> Dict:
+         """Evaluate on MMLU benchmark"""
+         logger.info("Evaluating on MMLU...")
+
+         dataset = load_dataset("cais/mmlu", "all", split="test")
+         if num_samples:
+             dataset = dataset.select(range(min(num_samples, len(dataset))))
+
+         correct = 0
+         total = 0
+
+         for example in tqdm(dataset, desc="MMLU"):
+             question = example["question"]
+             choices = example["choices"]
+             answer = example["answer"]
+
+             # Format prompt
+             prompt = f"Question: {question}\n\nChoices:\n"
+             for i, choice in enumerate(choices):
+                 prompt += f"{chr(65 + i)}. {choice}\n"
+             prompt += "\nAnswer: "
+
+             # Generate response
+             response = self.generate(prompt, max_new_tokens=10, temperature=0.0)
+
+             # Extract answer
+             pred = response.strip()[0].upper() if response else ""
+             correct_answer = chr(65 + answer)
+
+             if pred == correct_answer:
+                 correct += 1
+             total += 1
+
+         accuracy = correct / total if total > 0 else 0
+
+         return {
+             "benchmark": "MMLU",
+             "accuracy": accuracy,
+             "correct": correct,
+             "total": total
+         }
+
+     def evaluate_gsm8k(self, num_samples: Optional[int] = None) -> Dict:
+         """Evaluate on GSM8K mathematical reasoning"""
+         logger.info("Evaluating on GSM8K...")
+
+         dataset = load_dataset("gsm8k", "main", split="test")
+         if num_samples:
+             dataset = dataset.select(range(min(num_samples, len(dataset))))
+
+         correct = 0
+         total = 0
+
+         for example in tqdm(dataset, desc="GSM8K"):
+             question = example["question"]
+             answer = example["answer"]
+
+             # Extract the reference numerical answer (GSM8K marks it with "####")
+             match = re.search(r'####\s*(-?\d+(?:,\d+)*(?:\.\d+)?)', answer)
+             if not match:
+                 continue
+
+             correct_answer = match.group(1).replace(',', '')
+
+             # Format prompt
+             prompt = f"Question: {question}\n\nLet's solve this step by step:\n"
+
+             # Generate response
+             response = self.generate(prompt, max_new_tokens=512, temperature=0.0)
+
+             # Extract predicted answer (take the last number the model states)
+             pred_matches = re.findall(r'(?:answer is|=)\s*(-?\d+(?:,\d+)*(?:\.\d+)?)', response.lower())
+             if pred_matches:
+                 pred_answer = pred_matches[-1].replace(',', '')
+                 if pred_answer == correct_answer:
+                     correct += 1
+
+             total += 1
+
+         accuracy = correct / total if total > 0 else 0
+
+         return {
+             "benchmark": "GSM8K",
+             "accuracy": accuracy,
+             "correct": correct,
+             "total": total
+         }
+
+     def evaluate_humaneval(self, num_samples: Optional[int] = None) -> Dict:
+         """Evaluate on HumanEval code generation"""
+         logger.info("Evaluating on HumanEval...")
+
+         try:
+             dataset = load_dataset("openai_humaneval", split="test")
+         except Exception:
+             logger.warning("HumanEval dataset not available")
+             return {"benchmark": "HumanEval", "error": "Dataset not available"}
+
+         if num_samples:
+             dataset = dataset.select(range(min(num_samples, len(dataset))))
+
+         results = []
+
+         for example in tqdm(dataset, desc="HumanEval"):
+             prompt = example["prompt"]
+
+             # Generate code
+             full_prompt = f"Complete the following Python function:\n\n{prompt}"
+             response = self.generate(
+                 full_prompt,
+                 max_new_tokens=512,
+                 temperature=0.0
+             )
+
+             # Extract code
+             code = prompt + response
+
+             results.append({
+                 "task_id": example["task_id"],
+                 "completion": code,
+                 "test": example["test"]
+             })
+
+         # Note: Full evaluation requires executing the generated code against the tests;
+         # this simplified version only collects completions.
+         return {
+             "benchmark": "HumanEval",
+             "samples_generated": len(results),
+             "note": "Full evaluation requires code execution framework"
+         }
+
+     def evaluate_truthfulqa(self, num_samples: Optional[int] = None) -> Dict:
+         """Evaluate on TruthfulQA"""
+         logger.info("Evaluating on TruthfulQA...")
+
+         dataset = load_dataset("truthful_qa", "generation", split="validation")
+         if num_samples:
+             dataset = dataset.select(range(min(num_samples, len(dataset))))
+
+         responses = []
+
+         for example in tqdm(dataset, desc="TruthfulQA"):
+             question = example["question"]
+
+             prompt = f"Question: {question}\n\nProvide a truthful and accurate answer:\nAnswer: "
+
+             response = self.generate(prompt, max_new_tokens=256, temperature=0.0)
+
+             responses.append({
+                 "question": question,
+                 "response": response,
+                 "best_answer": example["best_answer"],
+                 "correct_answers": example["correct_answers"],
+                 "incorrect_answers": example["incorrect_answers"]
+             })
+
+         return {
+             "benchmark": "TruthfulQA",
+             "samples_evaluated": len(responses),
+             "note": "Manual review required for truthfulness assessment"
+         }
+
+     def evaluate_all(
+         self,
+         output_file: Optional[str] = None,
+         num_samples: Optional[int] = None
+     ) -> Dict:
+         """Run all evaluations"""
+         logger.info("Starting comprehensive evaluation...")
+
+         results = {
+             "model": "DeepXR/Helion-2.5-Rnd",
+             "benchmarks": {}
+         }
+
+         # Run evaluations
+         try:
+             results["benchmarks"]["mmlu"] = self.evaluate_mmlu(num_samples)
+         except Exception as e:
+             logger.error(f"MMLU evaluation failed: {e}")
+             results["benchmarks"]["mmlu"] = {"error": str(e)}
+
+         try:
+             results["benchmarks"]["gsm8k"] = self.evaluate_gsm8k(num_samples)
+         except Exception as e:
+             logger.error(f"GSM8K evaluation failed: {e}")
+             results["benchmarks"]["gsm8k"] = {"error": str(e)}
+
+         try:
+             results["benchmarks"]["humaneval"] = self.evaluate_humaneval(num_samples)
+         except Exception as e:
+             logger.error(f"HumanEval evaluation failed: {e}")
+             results["benchmarks"]["humaneval"] = {"error": str(e)}
+
+         try:
+             results["benchmarks"]["truthfulqa"] = self.evaluate_truthfulqa(num_samples)
+         except Exception as e:
+             logger.error(f"TruthfulQA evaluation failed: {e}")
+             results["benchmarks"]["truthfulqa"] = {"error": str(e)}
+
+         # Save results
+         if output_file:
+             output_path = Path(output_file)
+             output_path.parent.mkdir(parents=True, exist_ok=True)
+
+             with open(output_path, 'w') as f:
+                 json.dump(results, f, indent=2)
+
+             logger.info(f"Results saved to {output_path}")
+
+         # Print summary
+         logger.info("\n" + "="*50)
+         logger.info("EVALUATION SUMMARY")
+         logger.info("="*50)
+
+         for benchmark, result in results["benchmarks"].items():
+             if "accuracy" in result:
+                 logger.info(f"{benchmark.upper()}: {result['accuracy']:.2%}")
+             elif "error" in result:
+                 logger.info(f"{benchmark.upper()}: ERROR - {result['error']}")
+             else:
+                 logger.info(f"{benchmark.upper()}: {result.get('note', 'Completed')}")
+
+         return results
+
+
+ def main():
+     """Main evaluation entry point"""
+     parser = argparse.ArgumentParser(description="Evaluate Helion model")
+     parser.add_argument(
+         "--model",
+         type=str,
+         required=True,
+         help="Model path or HuggingFace ID"
+     )
+     parser.add_argument(
+         "--benchmarks",
+         type=str,
+         nargs="+",
+         default=["all"],
+         choices=["all", "mmlu", "gsm8k", "humaneval", "truthfulqa"],
+         help="Benchmarks to run"
+     )
+     parser.add_argument(
+         "--output",
+         type=str,
+         default="evaluation_results.json",
+         help="Output file for results"
+     )
+     parser.add_argument(
+         "--num-samples",
+         type=int,
+         default=None,
+         help="Number of samples to evaluate (for quick testing)"
+     )
+     parser.add_argument(
+         "--device",
+         type=str,
+         default="cuda",
+         help="Device to use"
+     )
+     parser.add_argument(
+         "--batch-size",
+         type=int,
+         default=1,
+         help="Batch size"
+     )
+
+     args = parser.parse_args()
+
+     # Initialize evaluator
+     evaluator = HelionEvaluator(
+         model_path=args.model,
+         device=args.device,
+         batch_size=args.batch_size
+     )
+
+     # Run evaluations
+     if "all" in args.benchmarks:
+         results = evaluator.evaluate_all(
+             output_file=args.output,
+             num_samples=args.num_samples
+         )
+     else:
+         results = {"model": args.model, "benchmarks": {}}
+
+         if "mmlu" in args.benchmarks:
+             results["benchmarks"]["mmlu"] = evaluator.evaluate_mmlu(args.num_samples)
+
+         if "gsm8k" in args.benchmarks:
+             results["benchmarks"]["gsm8k"] = evaluator.evaluate_gsm8k(args.num_samples)
+
+         if "humaneval" in args.benchmarks:
+             results["benchmarks"]["humaneval"] = evaluator.evaluate_humaneval(args.num_samples)
+
+         if "truthfulqa" in args.benchmarks:
+             results["benchmarks"]["truthfulqa"] = evaluator.evaluate_truthfulqa(args.num_samples)
+
+         # Save results (evaluate_all already writes its own file when all benchmarks run)
+         with open(args.output, 'w') as f:
+             json.dump(results, f, indent=2)
+
+         logger.info(f"Results saved to {args.output}")
+
+
+ if __name__ == "__main__":
+     main()
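
Usage sketch (illustrative, not part of the committed file): the command line below just exercises the argparse flags defined in main(), and the programmatic path assumes the repository root is on PYTHONPATH so that inference/evaluate.py is importable as a module; the 50-sample cap is only for a quick smoke test.

# CLI: python inference/evaluate.py --model DeepXR/Helion-2.5-Rnd --benchmarks mmlu gsm8k --num-samples 50
from inference.evaluate import HelionEvaluator

evaluator = HelionEvaluator(model_path="DeepXR/Helion-2.5-Rnd")
results = evaluator.evaluate_all(
    output_file="evaluation_results.json",
    num_samples=50,  # small per-benchmark subset for a smoke test
)
print(results["benchmarks"]["mmlu"])

As the script itself notes, HumanEval and TruthfulQA results still require external scoring (code execution and manual review, respectively); evaluate_all reports accuracies only for MMLU and GSM8K.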