#!/usr/bin/env python3 """ Helion-2.5-Rnd Batch Inference Efficient batch processing for large-scale inference tasks """ import argparse import json import logging import time from pathlib import Path from typing import Dict, List, Optional, Union import pandas as pd from tqdm import tqdm from inference.client import HelionClient logging.basicConfig( level=logging.INFO, format='%(asctime)s - %(name)s - %(levelname)s - %(message)s' ) logger = logging.getLogger(__name__) class BatchProcessor: """Process large batches of inference requests""" def __init__( self, client: HelionClient, batch_size: int = 10, max_retries: int = 3, retry_delay: float = 1.0 ): """ Initialize batch processor Args: client: HelionClient instance batch_size: Number of requests to process concurrently max_retries: Maximum retry attempts for failed requests retry_delay: Delay between retries in seconds """ self.client = client self.batch_size = batch_size self.max_retries = max_retries self.retry_delay = retry_delay self.stats = { 'total': 0, 'successful': 0, 'failed': 0, 'total_time': 0.0, 'avg_time_per_request': 0.0 } def process_prompts( self, prompts: List[str], temperature: float = 0.7, max_tokens: int = 1024, **kwargs ) -> List[Dict]: """ Process a list of prompts Args: prompts: List of input prompts temperature: Sampling temperature max_tokens: Maximum tokens per response **kwargs: Additional generation parameters Returns: List of results with prompt, response, and metadata """ results = [] start_time = time.time() logger.info(f"Processing {len(prompts)} prompts...") for i in tqdm(range(0, len(prompts), self.batch_size)): batch = prompts[i:i + self.batch_size] for prompt in batch: result = self._process_single_with_retry( prompt, temperature, max_tokens, **kwargs ) results.append(result) # Update statistics self.stats['total'] = len(prompts) self.stats['successful'] = sum(1 for r in results if r['success']) self.stats['failed'] = len(prompts) - self.stats['successful'] self.stats['total_time'] = time.time() - start_time self.stats['avg_time_per_request'] = self.stats['total_time'] / len(prompts) logger.info(f"Batch processing complete. Success rate: {self.stats['successful']}/{self.stats['total']}") return results def _process_single_with_retry( self, prompt: str, temperature: float, max_tokens: int, **kwargs ) -> Dict: """Process single prompt with retry logic""" for attempt in range(self.max_retries): try: start = time.time() response = self.client.complete( prompt=prompt, temperature=temperature, max_tokens=max_tokens, **kwargs ) duration = time.time() - start return { 'prompt': prompt, 'response': response, 'success': True, 'duration': duration, 'attempts': attempt + 1 } except Exception as e: logger.warning(f"Attempt {attempt + 1} failed: {str(e)}") if attempt < self.max_retries - 1: time.sleep(self.retry_delay) else: return { 'prompt': prompt, 'response': None, 'success': False, 'error': str(e), 'attempts': attempt + 1 } def process_chat_conversations( self, conversations: List[List[Dict]], temperature: float = 0.7, max_tokens: int = 1024, **kwargs ) -> List[Dict]: """ Process chat conversations in batch Args: conversations: List of message lists temperature: Sampling temperature max_tokens: Maximum tokens per response **kwargs: Additional generation parameters Returns: List of conversation results """ results = [] start_time = time.time() logger.info(f"Processing {len(conversations)} conversations...") for conv in tqdm(conversations): try: start = time.time() response = self.client.chat( messages=conv, temperature=temperature, max_tokens=max_tokens, **kwargs ) duration = time.time() - start results.append({ 'conversation': conv, 'response': response, 'success': True, 'duration': duration }) except Exception as e: logger.error(f"Conversation processing failed: {str(e)}") results.append({ 'conversation': conv, 'response': None, 'success': False, 'error': str(e) }) total_time = time.time() - start_time successful = sum(1 for r in results if r['success']) logger.info(f"Processed {successful}/{len(conversations)} conversations in {total_time:.2f}s") return results def process_file( self, input_file: str, output_file: str, prompt_column: str = "prompt", temperature: float = 0.7, max_tokens: int = 1024, **kwargs ) -> pd.DataFrame: """ Process prompts from file Args: input_file: Input CSV/JSON file path output_file: Output file path prompt_column: Column name containing prompts temperature: Sampling temperature max_tokens: Maximum tokens per response **kwargs: Additional generation parameters Returns: DataFrame with results """ # Load input file input_path = Path(input_file) if input_path.suffix == '.csv': df = pd.read_csv(input_path) elif input_path.suffix == '.json': df = pd.read_json(input_path) else: raise ValueError(f"Unsupported file format: {input_path.suffix}") if prompt_column not in df.columns: raise ValueError(f"Column '{prompt_column}' not found in input file") # Process prompts prompts = df[prompt_column].tolist() results = self.process_prompts( prompts, temperature=temperature, max_tokens=max_tokens, **kwargs ) # Add results to dataframe df['response'] = [r['response'] for r in results] df['success'] = [r['success'] for r in results] df['duration'] = [r.get('duration', None) for r in results] df['error'] = [r.get('error', None) for r in results] # Save results output_path = Path(output_file) output_path.parent.mkdir(parents=True, exist_ok=True) if output_path.suffix == '.csv': df.to_csv(output_path, index=False) elif output_path.suffix == '.json': df.to_json(output_path, orient='records', indent=2) else: raise ValueError(f"Unsupported output format: {output_path.suffix}") logger.info(f"Results saved to {output_path}") return df def get_statistics(self) -> Dict: """Get processing statistics""" return self.stats.copy() class DatasetProcessor: """Process specific dataset formats""" def __init__(self, client: HelionClient): self.client = client self.processor = BatchProcessor(client) def process_qa_dataset( self, questions: List[str], contexts: Optional[List[str]] = None, temperature: float = 0.3, max_tokens: int = 512 ) -> List[Dict]: """Process question-answering dataset""" prompts = [] for i, question in enumerate(questions): if contexts and i < len(contexts): prompt = f"Context: {contexts[i]}\n\nQuestion: {question}\n\nAnswer:" else: prompt = f"Question: {question}\n\nAnswer:" prompts.append(prompt) return self.processor.process_prompts( prompts, temperature=temperature, max_tokens=max_tokens ) def process_code_dataset( self, tasks: List[str], languages: Optional[List[str]] = None, temperature: float = 0.2, max_tokens: int = 1024 ) -> List[Dict]: """Process code generation tasks""" prompts = [] for i, task in enumerate(tasks): lang = languages[i] if languages and i < len(languages) else "python" prompt = f"Write a {lang} function to: {task}\n\n```{lang}\n" prompts.append(prompt) return self.processor.process_prompts( prompts, temperature=temperature, max_tokens=max_tokens ) def process_translation_dataset( self, texts: List[str], source_lang: str, target_lang: str, temperature: float = 0.3, max_tokens: int = 1024 ) -> List[Dict]: """Process translation tasks""" prompts = [] for text in texts: prompt = f"Translate the following text from {source_lang} to {target_lang}:\n\n{text}\n\nTranslation:" prompts.append(prompt) return self.processor.process_prompts( prompts, temperature=temperature, max_tokens=max_tokens ) def process_summarization_dataset( self, documents: List[str], max_summary_length: int = 150, temperature: float = 0.5, max_tokens: int = 512 ) -> List[Dict]: """Process document summarization""" prompts = [] for doc in documents: prompt = f"Summarize the following document in {max_summary_length} words or less:\n\n{doc}\n\nSummary:" prompts.append(prompt) return self.processor.process_prompts( prompts, temperature=temperature, max_tokens=max_tokens ) def main(): """Main batch processing entry point""" parser = argparse.ArgumentParser(description="Batch inference with Helion") parser.add_argument("--base-url", type=str, default="http://localhost:8000") parser.add_argument("--input", type=str, required=True, help="Input file (CSV/JSON)") parser.add_argument("--output", type=str, required=True, help="Output file (CSV/JSON)") parser.add_argument("--prompt-column", type=str, default="prompt") parser.add_argument("--temperature", type=float, default=0.7) parser.add_argument("--max-tokens", type=int, default=1024) parser.add_argument("--batch-size", type=int, default=10) args = parser.parse_args() # Initialize client and processor client = HelionClient(base_url=args.base_url) processor = BatchProcessor(client, batch_size=args.batch_size) # Process file df = processor.process_file( input_file=args.input, output_file=args.output, prompt_column=args.prompt_column, temperature=args.temperature, max_tokens=args.max_tokens ) # Print statistics stats = processor.get_statistics() logger.info("\nProcessing Statistics:") logger.info(f"Total requests: {stats['total']}") logger.info(f"Successful: {stats['successful']}") logger.info(f"Failed: {stats['failed']}") logger.info(f"Total time: {stats['total_time']:.2f}s") logger.info(f"Avg time per request: {stats['avg_time_per_request']:.2f}s") if __name__ == "__main__": main()