Trouter-Library committed on
Commit b010003 · verified · 1 Parent(s): f15baf7

Create inference/batch_inference.py

Files changed (1)
  inference/batch_inference.py +406 -0
inference/batch_inference.py ADDED
@@ -0,0 +1,406 @@
#!/usr/bin/env python3
"""
Helion-2.5-Rnd Batch Inference
Efficient batch processing for large-scale inference tasks
"""

import argparse
import json
import logging
import time
from pathlib import Path
from typing import Dict, List, Optional, Union

import pandas as pd
from tqdm import tqdm

from inference.client import HelionClient

logging.basicConfig(
    level=logging.INFO,
    format='%(asctime)s - %(name)s - %(levelname)s - %(message)s'
)
logger = logging.getLogger(__name__)


class BatchProcessor:
    """Process large batches of inference requests"""

    def __init__(
        self,
        client: HelionClient,
        batch_size: int = 10,
        max_retries: int = 3,
        retry_delay: float = 1.0
    ):
        """
        Initialize batch processor

        Args:
            client: HelionClient instance
            batch_size: Number of prompts grouped into each batch (requests within a batch are sent one at a time)
            max_retries: Maximum retry attempts for failed requests
            retry_delay: Delay between retries in seconds
        """
        self.client = client
        self.batch_size = batch_size
        self.max_retries = max_retries
        self.retry_delay = retry_delay

        self.stats = {
            'total': 0,
            'successful': 0,
            'failed': 0,
            'total_time': 0.0,
            'avg_time_per_request': 0.0
        }

    def process_prompts(
        self,
        prompts: List[str],
        temperature: float = 0.7,
        max_tokens: int = 1024,
        **kwargs
    ) -> List[Dict]:
        """
        Process a list of prompts

        Args:
            prompts: List of input prompts
            temperature: Sampling temperature
            max_tokens: Maximum tokens per response
            **kwargs: Additional generation parameters

        Returns:
            List of results with prompt, response, and metadata
        """
        results = []
        start_time = time.time()

        logger.info(f"Processing {len(prompts)} prompts...")

        for i in tqdm(range(0, len(prompts), self.batch_size)):
            batch = prompts[i:i + self.batch_size]

            for prompt in batch:
                result = self._process_single_with_retry(
                    prompt,
                    temperature,
                    max_tokens,
                    **kwargs
                )
                results.append(result)

        # Update statistics (guard against division by zero for an empty prompt list)
        self.stats['total'] = len(prompts)
        self.stats['successful'] = sum(1 for r in results if r['success'])
        self.stats['failed'] = len(prompts) - self.stats['successful']
        self.stats['total_time'] = time.time() - start_time
        self.stats['avg_time_per_request'] = (
            self.stats['total_time'] / len(prompts) if prompts else 0.0
        )

        logger.info(f"Batch processing complete. Success rate: {self.stats['successful']}/{self.stats['total']}")

        return results

    def _process_single_with_retry(
        self,
        prompt: str,
        temperature: float,
        max_tokens: int,
        **kwargs
    ) -> Dict:
        """Process a single prompt with retry logic"""
        last_error = None

        for attempt in range(self.max_retries):
            try:
                start = time.time()
                response = self.client.complete(
                    prompt=prompt,
                    temperature=temperature,
                    max_tokens=max_tokens,
                    **kwargs
                )
                duration = time.time() - start

                return {
                    'prompt': prompt,
                    'response': response,
                    'success': True,
                    'duration': duration,
                    'attempts': attempt + 1
                }

            except Exception as e:
                last_error = e
                logger.warning(f"Attempt {attempt + 1} failed: {str(e)}")

                if attempt < self.max_retries - 1:
                    time.sleep(self.retry_delay)

        # All attempts failed (or max_retries was 0): report the failure
        return {
            'prompt': prompt,
            'response': None,
            'success': False,
            'error': str(last_error),
            'attempts': self.max_retries
        }

    def process_chat_conversations(
        self,
        conversations: List[List[Dict]],
        temperature: float = 0.7,
        max_tokens: int = 1024,
        **kwargs
    ) -> List[Dict]:
        """
        Process chat conversations in batch

        Args:
            conversations: List of message lists
            temperature: Sampling temperature
            max_tokens: Maximum tokens per response
            **kwargs: Additional generation parameters

        Returns:
            List of conversation results
        """
        results = []
        start_time = time.time()

        logger.info(f"Processing {len(conversations)} conversations...")

        for conv in tqdm(conversations):
            try:
                start = time.time()
                response = self.client.chat(
                    messages=conv,
                    temperature=temperature,
                    max_tokens=max_tokens,
                    **kwargs
                )
                duration = time.time() - start

                results.append({
                    'conversation': conv,
                    'response': response,
                    'success': True,
                    'duration': duration
                })

            except Exception as e:
                logger.error(f"Conversation processing failed: {str(e)}")
                results.append({
                    'conversation': conv,
                    'response': None,
                    'success': False,
                    'error': str(e)
                })

        total_time = time.time() - start_time
        successful = sum(1 for r in results if r['success'])

        logger.info(f"Processed {successful}/{len(conversations)} conversations in {total_time:.2f}s")

        return results

    def process_file(
        self,
        input_file: str,
        output_file: str,
        prompt_column: str = "prompt",
        temperature: float = 0.7,
        max_tokens: int = 1024,
        **kwargs
    ) -> pd.DataFrame:
        """
        Process prompts from file

        Args:
            input_file: Input CSV/JSON file path
            output_file: Output file path
            prompt_column: Column name containing prompts
            temperature: Sampling temperature
            max_tokens: Maximum tokens per response
            **kwargs: Additional generation parameters

        Returns:
            DataFrame with results
        """
        # Load input file
        input_path = Path(input_file)

        if input_path.suffix == '.csv':
            df = pd.read_csv(input_path)
        elif input_path.suffix == '.json':
            df = pd.read_json(input_path)
        else:
            raise ValueError(f"Unsupported file format: {input_path.suffix}")

        if prompt_column not in df.columns:
            raise ValueError(f"Column '{prompt_column}' not found in input file")

        # Process prompts
        prompts = df[prompt_column].tolist()
        results = self.process_prompts(
            prompts,
            temperature=temperature,
            max_tokens=max_tokens,
            **kwargs
        )

        # Add results to dataframe
        df['response'] = [r['response'] for r in results]
        df['success'] = [r['success'] for r in results]
        df['duration'] = [r.get('duration', None) for r in results]
        df['error'] = [r.get('error', None) for r in results]

        # Save results
        output_path = Path(output_file)
        output_path.parent.mkdir(parents=True, exist_ok=True)

        if output_path.suffix == '.csv':
            df.to_csv(output_path, index=False)
        elif output_path.suffix == '.json':
            df.to_json(output_path, orient='records', indent=2)
        else:
            raise ValueError(f"Unsupported output format: {output_path.suffix}")

        logger.info(f"Results saved to {output_path}")

        return df

    def get_statistics(self) -> Dict:
        """Get processing statistics"""
        return self.stats.copy()


class DatasetProcessor:
    """Process specific dataset formats"""

    def __init__(self, client: HelionClient):
        self.client = client
        self.processor = BatchProcessor(client)

    def process_qa_dataset(
        self,
        questions: List[str],
        contexts: Optional[List[str]] = None,
        temperature: float = 0.3,
        max_tokens: int = 512
    ) -> List[Dict]:
        """Process question-answering dataset"""
        prompts = []

        for i, question in enumerate(questions):
            if contexts and i < len(contexts):
                prompt = f"Context: {contexts[i]}\n\nQuestion: {question}\n\nAnswer:"
            else:
                prompt = f"Question: {question}\n\nAnswer:"

            prompts.append(prompt)

        return self.processor.process_prompts(
            prompts,
            temperature=temperature,
            max_tokens=max_tokens
        )

    def process_code_dataset(
        self,
        tasks: List[str],
        languages: Optional[List[str]] = None,
        temperature: float = 0.2,
        max_tokens: int = 1024
    ) -> List[Dict]:
        """Process code generation tasks"""
        prompts = []

        for i, task in enumerate(tasks):
            lang = languages[i] if languages and i < len(languages) else "python"
            prompt = f"Write a {lang} function to: {task}\n\n```{lang}\n"
            prompts.append(prompt)

        return self.processor.process_prompts(
            prompts,
            temperature=temperature,
            max_tokens=max_tokens
        )

    def process_translation_dataset(
        self,
        texts: List[str],
        source_lang: str,
        target_lang: str,
        temperature: float = 0.3,
        max_tokens: int = 1024
    ) -> List[Dict]:
        """Process translation tasks"""
        prompts = []

        for text in texts:
            prompt = f"Translate the following text from {source_lang} to {target_lang}:\n\n{text}\n\nTranslation:"
            prompts.append(prompt)

        return self.processor.process_prompts(
            prompts,
            temperature=temperature,
            max_tokens=max_tokens
        )

    def process_summarization_dataset(
        self,
        documents: List[str],
        max_summary_length: int = 150,
        temperature: float = 0.5,
        max_tokens: int = 512
    ) -> List[Dict]:
        """Process document summarization"""
        prompts = []

        for doc in documents:
            prompt = f"Summarize the following document in {max_summary_length} words or less:\n\n{doc}\n\nSummary:"
            prompts.append(prompt)

        return self.processor.process_prompts(
            prompts,
            temperature=temperature,
            max_tokens=max_tokens
        )


def main():
    """Main batch processing entry point"""
    parser = argparse.ArgumentParser(description="Batch inference with Helion")
    parser.add_argument("--base-url", type=str, default="http://localhost:8000")
    parser.add_argument("--input", type=str, required=True, help="Input file (CSV/JSON)")
    parser.add_argument("--output", type=str, required=True, help="Output file (CSV/JSON)")
    parser.add_argument("--prompt-column", type=str, default="prompt")
    parser.add_argument("--temperature", type=float, default=0.7)
    parser.add_argument("--max-tokens", type=int, default=1024)
    parser.add_argument("--batch-size", type=int, default=10)

    args = parser.parse_args()

    # Initialize client and processor
    client = HelionClient(base_url=args.base_url)
    processor = BatchProcessor(client, batch_size=args.batch_size)

    # Process file
    df = processor.process_file(
        input_file=args.input,
        output_file=args.output,
        prompt_column=args.prompt_column,
        temperature=args.temperature,
        max_tokens=args.max_tokens
    )

    # Print statistics
    stats = processor.get_statistics()
    logger.info("\nProcessing Statistics:")
    logger.info(f"Total requests: {stats['total']}")
    logger.info(f"Successful: {stats['successful']}")
    logger.info(f"Failed: {stats['failed']}")
    logger.info(f"Total time: {stats['total_time']:.2f}s")
    logger.info(f"Avg time per request: {stats['avg_time_per_request']:.2f}s")


if __name__ == "__main__":
    main()
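
For quick reference, a minimal usage sketch of the new module. It only exercises the interfaces defined in the file above (BatchProcessor.process_prompts, get_statistics, and the CLI flags registered in main()); the server URL, file names, and prompt text below are placeholders, and the sketch assumes a Helion-compatible endpoint is already running.

# Minimal programmatic usage sketch (placeholder endpoint and prompt).
from inference.client import HelionClient
from inference.batch_inference import BatchProcessor

client = HelionClient(base_url="http://localhost:8000")  # placeholder server URL
processor = BatchProcessor(client, batch_size=10, max_retries=3)

# Process a small list of prompts and inspect the first result plus run statistics.
results = processor.process_prompts(
    ["Summarize what batch inference is in one sentence."],
    temperature=0.7,
    max_tokens=256,
)
print(results[0]["success"], results[0]["response"])
print(processor.get_statistics())

The same flow is available from the command line via the new entry point (input/output paths are placeholders; the input file needs a "prompt" column or one named via --prompt-column):

python inference/batch_inference.py --input prompts.csv --output results.csv --prompt-column prompt --batch-size 10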