"""
Helion-2.5-Rnd Batch Inference

Efficient batch processing for large-scale inference tasks.
"""

import argparse
import logging
import time
from pathlib import Path
from typing import Dict, List, Optional

import pandas as pd
from tqdm import tqdm

from inference.client import HelionClient

logging.basicConfig(
    level=logging.INFO,
    format='%(asctime)s - %(name)s - %(levelname)s - %(message)s'
)
logger = logging.getLogger(__name__)


class BatchProcessor:
    """Process large batches of inference requests."""

    def __init__(
        self,
        client: HelionClient,
        batch_size: int = 10,
        max_retries: int = 3,
        retry_delay: float = 1.0
    ):
        """
        Initialize the batch processor.

        Args:
            client: HelionClient instance
            batch_size: Number of prompts grouped into each batch
            max_retries: Maximum retry attempts for failed requests
            retry_delay: Delay between retries in seconds
        """
        self.client = client
        self.batch_size = batch_size
        self.max_retries = max_retries
        self.retry_delay = retry_delay

        # Running totals, refreshed on each call to process_prompts().
        self.stats = {
            'total': 0,
            'successful': 0,
            'failed': 0,
            'total_time': 0.0,
            'avg_time_per_request': 0.0
        }

    def process_prompts(
        self,
        prompts: List[str],
        temperature: float = 0.7,
        max_tokens: int = 1024,
        **kwargs
    ) -> List[Dict]:
        """
        Process a list of prompts.

        Args:
            prompts: List of input prompts
            temperature: Sampling temperature
            max_tokens: Maximum tokens per response
            **kwargs: Additional generation parameters

        Returns:
            List of results with prompt, response, and metadata
        """
        results = []
        start_time = time.time()

        logger.info(f"Processing {len(prompts)} prompts...")

        # Prompts are grouped by batch_size; each prompt within a batch is
        # processed sequentially with per-request retry handling.
        for i in tqdm(range(0, len(prompts), self.batch_size)):
            batch = prompts[i:i + self.batch_size]

            for prompt in batch:
                result = self._process_single_with_retry(
                    prompt,
                    temperature,
                    max_tokens,
                    **kwargs
                )
                results.append(result)

        self.stats['total'] = len(prompts)
        self.stats['successful'] = sum(1 for r in results if r['success'])
        self.stats['failed'] = len(prompts) - self.stats['successful']
        self.stats['total_time'] = time.time() - start_time
        self.stats['avg_time_per_request'] = (
            self.stats['total_time'] / len(prompts) if prompts else 0.0
        )

        logger.info(
            f"Batch processing complete. "
            f"Success rate: {self.stats['successful']}/{self.stats['total']}"
        )

        return results

    def _process_single_with_retry(
        self,
        prompt: str,
        temperature: float,
        max_tokens: int,
        **kwargs
    ) -> Dict:
        """Process a single prompt with retry logic."""
        for attempt in range(self.max_retries):
            try:
                start = time.time()
                response = self.client.complete(
                    prompt=prompt,
                    temperature=temperature,
                    max_tokens=max_tokens,
                    **kwargs
                )
                duration = time.time() - start

                return {
                    'prompt': prompt,
                    'response': response,
                    'success': True,
                    'duration': duration,
                    'attempts': attempt + 1
                }

            except Exception as e:
                logger.warning(f"Attempt {attempt + 1} failed: {str(e)}")

                if attempt < self.max_retries - 1:
                    time.sleep(self.retry_delay)
                else:
                    return {
                        'prompt': prompt,
                        'response': None,
                        'success': False,
                        'error': str(e),
                        'attempts': attempt + 1
                    }

    def process_chat_conversations(
        self,
        conversations: List[List[Dict]],
        temperature: float = 0.7,
        max_tokens: int = 1024,
        **kwargs
    ) -> List[Dict]:
        """
        Process chat conversations in batch.

        Args:
            conversations: List of message lists
            temperature: Sampling temperature
            max_tokens: Maximum tokens per response
            **kwargs: Additional generation parameters

        Returns:
            List of conversation results
        """
        results = []
        start_time = time.time()

        logger.info(f"Processing {len(conversations)} conversations...")

        # Unlike process_prompts(), failed conversations are not retried.
        for conv in tqdm(conversations):
            try:
                start = time.time()
                response = self.client.chat(
                    messages=conv,
                    temperature=temperature,
                    max_tokens=max_tokens,
                    **kwargs
                )
                duration = time.time() - start

                results.append({
                    'conversation': conv,
                    'response': response,
                    'success': True,
                    'duration': duration
                })

            except Exception as e:
                logger.error(f"Conversation processing failed: {str(e)}")
                results.append({
                    'conversation': conv,
                    'response': None,
                    'success': False,
                    'error': str(e)
                })

        total_time = time.time() - start_time
        successful = sum(1 for r in results if r['success'])

        logger.info(
            f"Processed {successful}/{len(conversations)} conversations "
            f"in {total_time:.2f}s"
        )

        return results

    def process_file(
        self,
        input_file: str,
        output_file: str,
        prompt_column: str = "prompt",
        temperature: float = 0.7,
        max_tokens: int = 1024,
        **kwargs
    ) -> pd.DataFrame:
        """
        Process prompts from a file.

        Args:
            input_file: Input CSV/JSON file path
            output_file: Output file path
            prompt_column: Column name containing prompts
            temperature: Sampling temperature
            max_tokens: Maximum tokens per response
            **kwargs: Additional generation parameters

        Returns:
            DataFrame with results
        """
        # Load the input file and validate the prompt column.
        input_path = Path(input_file)

        if input_path.suffix == '.csv':
            df = pd.read_csv(input_path)
        elif input_path.suffix == '.json':
            df = pd.read_json(input_path)
        else:
            raise ValueError(f"Unsupported file format: {input_path.suffix}")

        if prompt_column not in df.columns:
            raise ValueError(f"Column '{prompt_column}' not found in input file")

        prompts = df[prompt_column].tolist()
        results = self.process_prompts(
            prompts,
            temperature=temperature,
            max_tokens=max_tokens,
            **kwargs
        )

        # Attach results to the DataFrame, preserving the input row order.
        df['response'] = [r['response'] for r in results]
        df['success'] = [r['success'] for r in results]
        df['duration'] = [r.get('duration', None) for r in results]
        df['error'] = [r.get('error', None) for r in results]

        # Write results in the format implied by the output file extension.
        output_path = Path(output_file)
        output_path.parent.mkdir(parents=True, exist_ok=True)

        if output_path.suffix == '.csv':
            df.to_csv(output_path, index=False)
        elif output_path.suffix == '.json':
            df.to_json(output_path, orient='records', indent=2)
        else:
            raise ValueError(f"Unsupported output format: {output_path.suffix}")

        logger.info(f"Results saved to {output_path}")

        return df

    def get_statistics(self) -> Dict:
        """Get processing statistics."""
        return self.stats.copy()


class DatasetProcessor:
    """Process specific dataset formats."""

    def __init__(self, client: HelionClient):
        self.client = client
        self.processor = BatchProcessor(client)

    def process_qa_dataset(
        self,
        questions: List[str],
        contexts: Optional[List[str]] = None,
        temperature: float = 0.3,
        max_tokens: int = 512
    ) -> List[Dict]:
        """Process a question-answering dataset."""
        prompts = []

        for i, question in enumerate(questions):
            if contexts and i < len(contexts):
                prompt = f"Context: {contexts[i]}\n\nQuestion: {question}\n\nAnswer:"
            else:
                prompt = f"Question: {question}\n\nAnswer:"

            prompts.append(prompt)

        return self.processor.process_prompts(
            prompts,
            temperature=temperature,
            max_tokens=max_tokens
        )

    def process_code_dataset(
        self,
        tasks: List[str],
        languages: Optional[List[str]] = None,
        temperature: float = 0.2,
        max_tokens: int = 1024
    ) -> List[Dict]:
        """Process code generation tasks."""
        prompts = []

        for i, task in enumerate(tasks):
            lang = languages[i] if languages and i < len(languages) else "python"
            prompt = f"Write a {lang} function to: {task}\n\n```{lang}\n"
            prompts.append(prompt)

        return self.processor.process_prompts(
            prompts,
            temperature=temperature,
            max_tokens=max_tokens
        )

    def process_translation_dataset(
        self,
        texts: List[str],
        source_lang: str,
        target_lang: str,
        temperature: float = 0.3,
        max_tokens: int = 1024
    ) -> List[Dict]:
        """Process translation tasks."""
        prompts = []

        for text in texts:
            prompt = (
                f"Translate the following text from {source_lang} to {target_lang}:"
                f"\n\n{text}\n\nTranslation:"
            )
            prompts.append(prompt)

        return self.processor.process_prompts(
            prompts,
            temperature=temperature,
            max_tokens=max_tokens
        )

    def process_summarization_dataset(
        self,
        documents: List[str],
        max_summary_length: int = 150,
        temperature: float = 0.5,
        max_tokens: int = 512
    ) -> List[Dict]:
        """Process document summarization."""
        prompts = []

        for doc in documents:
            prompt = (
                f"Summarize the following document in {max_summary_length} words "
                f"or less:\n\n{doc}\n\nSummary:"
            )
            prompts.append(prompt)

        return self.processor.process_prompts(
            prompts,
            temperature=temperature,
            max_tokens=max_tokens
        )


def main():
    """Main batch processing entry point."""
    parser = argparse.ArgumentParser(description="Batch inference with Helion")
    parser.add_argument("--base-url", type=str, default="http://localhost:8000")
    parser.add_argument("--input", type=str, required=True, help="Input file (CSV/JSON)")
    parser.add_argument("--output", type=str, required=True, help="Output file (CSV/JSON)")
    parser.add_argument("--prompt-column", type=str, default="prompt")
    parser.add_argument("--temperature", type=float, default=0.7)
    parser.add_argument("--max-tokens", type=int, default=1024)
    parser.add_argument("--batch-size", type=int, default=10)

    args = parser.parse_args()

    client = HelionClient(base_url=args.base_url)
    processor = BatchProcessor(client, batch_size=args.batch_size)

    processor.process_file(
        input_file=args.input,
        output_file=args.output,
        prompt_column=args.prompt_column,
        temperature=args.temperature,
        max_tokens=args.max_tokens
    )

    stats = processor.get_statistics()
    logger.info("\nProcessing Statistics:")
    logger.info(f"Total requests: {stats['total']}")
    logger.info(f"Successful: {stats['successful']}")
    logger.info(f"Failed: {stats['failed']}")
    logger.info(f"Total time: {stats['total_time']:.2f}s")
    logger.info(f"Avg time per request: {stats['avg_time_per_request']:.2f}s")


if __name__ == "__main__":
    main()