#!/usr/bin/env python3
"""
Helion-2.5-Rnd Inference Pipeline
High-level pipeline for easy model usage
"""
import logging
import time
from typing import Any, Dict, List, Optional, Union
import torch
from transformers import AutoModelForCausalLM, AutoTokenizer, StoppingCriteria, StoppingCriteriaList
logging.basicConfig(level=logging.INFO)
logger = logging.getLogger(__name__)
class StopOnTokens(StoppingCriteria):
"""Stop generation when specific tokens are generated"""
def __init__(self, stop_token_ids: List[int]):
self.stop_token_ids = stop_token_ids
def __call__(
self,
input_ids: torch.LongTensor,
scores: torch.FloatTensor,
**kwargs
) -> bool:
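        # Note: only the first sequence in the batch is inspected, so with
        # num_return_sequences > 1 generation stops for every sequence as soon as
        # sequence 0 emits a stop token.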
for stop_id in self.stop_token_ids:
if input_ids[0][-1] == stop_id:
return True
return False
class HelionPipeline:
"""High-level inference pipeline for Helion model"""
def __init__(
self,
model_path: str,
device: str = "cuda",
torch_dtype=torch.bfloat16,
load_in_8bit: bool = False,
trust_remote_code: bool = True
):
"""
Initialize Helion pipeline
Args:
            model_path: Local path or Hugging Face Hub ID of the model
            device: Device to load the model on (e.g. "cuda" or "cpu")
            torch_dtype: Torch dtype for the model weights
            load_in_8bit: Whether to load the model with 8-bit quantization
            trust_remote_code: Whether to trust custom code shipped with the model repo
"""
logger.info(f"Loading Helion model from {model_path}")
self.device = device
self.model_path = model_path
# Load tokenizer
self.tokenizer = AutoTokenizer.from_pretrained(
model_path,
trust_remote_code=trust_remote_code
)
# Load model
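        # Assumption: recent transformers releases expect 8-bit loading via
        # quantization_config=BitsAndBytesConfig(load_in_8bit=True); the plain
        # load_in_8bit flag is kept here for compatibility with older versions.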
self.model = AutoModelForCausalLM.from_pretrained(
model_path,
torch_dtype=torch_dtype,
device_map="auto" if device == "cuda" else None,
load_in_8bit=load_in_8bit,
trust_remote_code=trust_remote_code
)
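        # With device_map="auto" the weights are already dispatched, and 8-bit weights
        # cannot be moved with .to(); only relocate the model for plain CPU/other devices.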
if device != "cuda" and not load_in_8bit:
self.model = self.model.to(device)
self.model.eval()
        # Setup stop tokens (EOS plus the ChatML end-of-turn marker, when present)
        self.stop_token_ids = [
            self.tokenizer.eos_token_id,
            self.tokenizer.convert_tokens_to_ids("<|im_end|>"),
        ]
        self.stop_token_ids = [t for t in self.stop_token_ids if t is not None]
        # Fall back to EOS as pad token so padded/batched generation works
        if self.tokenizer.pad_token_id is None:
            self.tokenizer.pad_token_id = self.tokenizer.eos_token_id
logger.info("Model loaded successfully")
def generate(
self,
prompt: str,
max_new_tokens: int = 512,
temperature: float = 0.7,
top_p: float = 0.9,
top_k: int = 50,
repetition_penalty: float = 1.1,
do_sample: bool = True,
num_return_sequences: int = 1,
**kwargs
) -> Union[str, List[str]]:
"""
Generate text from prompt
Args:
prompt: Input prompt
max_new_tokens: Maximum tokens to generate
temperature: Sampling temperature
top_p: Nucleus sampling parameter
top_k: Top-k sampling parameter
repetition_penalty: Repetition penalty
do_sample: Whether to sample
num_return_sequences: Number of sequences to return
**kwargs: Additional generation parameters
Returns:
Generated text or list of texts
"""
# Tokenize input
inputs = self.tokenizer(
prompt,
return_tensors="pt",
truncation=True,
max_length=self.model.config.max_position_embeddings
).to(self.device)
# Setup stopping criteria
stopping_criteria = StoppingCriteriaList([
StopOnTokens(self.stop_token_ids)
])
# Generate
with torch.no_grad():
start_time = time.time()
outputs = self.model.generate(
**inputs,
max_new_tokens=max_new_tokens,
temperature=temperature,
top_p=top_p,
top_k=top_k,
repetition_penalty=repetition_penalty,
do_sample=do_sample,
num_return_sequences=num_return_sequences,
stopping_criteria=stopping_criteria,
pad_token_id=self.tokenizer.pad_token_id,
**kwargs
)
generation_time = time.time() - start_time
# Decode outputs
generated_texts = []
for output in outputs:
text = self.tokenizer.decode(
output[inputs['input_ids'].shape[1]:],
skip_special_tokens=True
)
generated_texts.append(text.strip())
logger.info(f"Generated {len(generated_texts)} sequences in {generation_time:.2f}s")
if num_return_sequences == 1:
return generated_texts[0]
return generated_texts
def chat(
self,
messages: List[Dict[str, str]],
max_new_tokens: int = 512,
temperature: float = 0.7,
**kwargs
) -> str:
"""
Chat completion
Args:
messages: List of message dictionaries
max_new_tokens: Maximum tokens to generate
temperature: Sampling temperature
**kwargs: Additional generation parameters
Returns:
Assistant response
"""
# Format chat prompt
prompt = self._format_chat_prompt(messages)
# Generate response
response = self.generate(
prompt,
max_new_tokens=max_new_tokens,
temperature=temperature,
**kwargs
)
return response
def _format_chat_prompt(self, messages: List[Dict[str, str]]) -> str:
"""Format messages into chat prompt"""
formatted = ""
for msg in messages:
role = msg.get('role', 'user')
content = msg.get('content', '')
formatted += f"<|im_start|>{role}\n{content}<|im_end|>\n"
formatted += "<|im_start|>assistant\n"
return formatted
def batch_generate(
self,
prompts: List[str],
max_new_tokens: int = 512,
temperature: float = 0.7,
batch_size: int = 4,
**kwargs
) -> List[str]:
"""
Generate for multiple prompts in batches
Args:
prompts: List of input prompts
max_new_tokens: Maximum tokens to generate
temperature: Sampling temperature
batch_size: Batch size for processing
**kwargs: Additional generation parameters
Returns:
List of generated texts
"""
all_outputs = []
for i in range(0, len(prompts), batch_size):
batch = prompts[i:i + batch_size]
            # Tokenize batch with left padding: decoder-only models generate from the
            # right edge, so right padding would leave pad tokens between prompt and output
            self.tokenizer.padding_side = "left"
            inputs = self.tokenizer(
                batch,
                return_tensors="pt",
                padding=True,
                truncation=True,
                max_length=self.model.config.max_position_embeddings
            ).to(self.device)
# Generate
with torch.no_grad():
outputs = self.model.generate(
**inputs,
max_new_tokens=max_new_tokens,
                    temperature=temperature,
                    do_sample=True,  # sampling so temperature takes effect
pad_token_id=self.tokenizer.pad_token_id,
**kwargs
)
# Decode
for j, output in enumerate(outputs):
text = self.tokenizer.decode(
output[inputs['input_ids'][j].shape[0]:],
skip_special_tokens=True
)
all_outputs.append(text.strip())
logger.info(f"Generated {len(all_outputs)} outputs")
return all_outputs
def stream_generate(
self,
prompt: str,
max_new_tokens: int = 512,
temperature: float = 0.7,
**kwargs
):
"""
Stream generation token by token
Args:
prompt: Input prompt
max_new_tokens: Maximum tokens to generate
temperature: Sampling temperature
**kwargs: Additional generation parameters
Yields:
Generated tokens
"""
        inputs = self.tokenizer(prompt, return_tensors="pt").to(self.device)
stopping_criteria = StoppingCriteriaList([
StopOnTokens(self.stop_token_ids)
])
with torch.no_grad():
for _ in range(max_new_tokens):
outputs = self.model.generate(
**inputs,
max_new_tokens=1,
                    temperature=temperature,
                    do_sample=True,  # sampling so temperature takes effect
stopping_criteria=stopping_criteria,
pad_token_id=self.tokenizer.pad_token_id,
**kwargs
)
new_token_id = outputs[0, -1].item()
# Check if stop token
if new_token_id in self.stop_token_ids:
break
# Decode and yield new token
new_token = self.tokenizer.decode([new_token_id])
yield new_token
# Update inputs for next iteration
inputs = {
'input_ids': outputs,
'attention_mask': torch.ones_like(outputs)
}
def get_embeddings(self, text: str) -> torch.Tensor:
"""
Get embeddings for text
Args:
text: Input text
Returns:
Embedding tensor
"""
inputs = self.tokenizer(text, return_tensors="pt").to(self.device)
with torch.no_grad():
outputs = self.model(**inputs, output_hidden_states=True)
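            # Mean-pool the final hidden layer over the sequence dimension: a simple
            # heuristic for a fixed-size text embedding (the model has no embedding head).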
embeddings = outputs.hidden_states[-1].mean(dim=1)
return embeddings
def score_text(self, text: str) -> float:
"""
Calculate perplexity score for text
Args:
text: Input text
Returns:
Perplexity score
"""
inputs = self.tokenizer(text, return_tensors="pt").to(self.device)
with torch.no_grad():
outputs = self.model(**inputs, labels=inputs['input_ids'])
loss = outputs.loss
perplexity = torch.exp(loss).item()
return perplexity
def cleanup(self):
"""Clean up resources"""
del self.model
del self.tokenizer
torch.cuda.empty_cache()
logger.info("Pipeline cleaned up")
class ConversationPipeline(HelionPipeline):
"""Pipeline with conversation history management"""
def __init__(self, *args, **kwargs):
super().__init__(*args, **kwargs)
self.conversation_history: List[Dict[str, str]] = []
self.system_prompt: Optional[str] = None
def set_system_prompt(self, prompt: str):
"""Set system prompt for conversation"""
self.system_prompt = prompt
def add_message(self, role: str, content: str):
"""Add message to conversation history"""
self.conversation_history.append({
'role': role,
'content': content
})
def generate_response(
self,
user_message: str,
max_new_tokens: int = 512,
temperature: float = 0.7,
**kwargs
) -> str:
"""
Generate response in conversation context
Args:
user_message: User's message
max_new_tokens: Maximum tokens to generate
temperature: Sampling temperature
**kwargs: Additional generation parameters
Returns:
Assistant response
"""
# Build messages
messages = []
if self.system_prompt:
messages.append({
'role': 'system',
'content': self.system_prompt
})
messages.extend(self.conversation_history)
messages.append({
'role': 'user',
'content': user_message
})
# Generate response
response = self.chat(
messages,
max_new_tokens=max_new_tokens,
temperature=temperature,
**kwargs
)
# Update history
self.add_message('user', user_message)
self.add_message('assistant', response)
return response
def reset_conversation(self):
"""Reset conversation history"""
self.conversation_history.clear()
logger.info("Conversation history reset")
def main():
"""Example usage"""
# Initialize pipeline
pipeline = HelionPipeline(
model_path="DeepXR/Helion-2.5-Rnd",
device="cuda"
)
# Simple generation
prompt = "Explain quantum computing in simple terms:"
response = pipeline.generate(prompt, max_new_tokens=256)
print(f"Response: {response}\n")
# Chat completion
messages = [
{"role": "system", "content": "You are a helpful assistant."},
{"role": "user", "content": "What is the capital of France?"}
]
response = pipeline.chat(messages)
print(f"Chat response: {response}\n")
# Batch generation
prompts = [
"Write a haiku about AI:",
"Explain machine learning:",
"What is Python?"
]
responses = pipeline.batch_generate(prompts, batch_size=2)
for i, resp in enumerate(responses):
print(f"Batch {i+1}: {resp}\n")
# Conversation
conv_pipeline = ConversationPipeline(
model_path="DeepXR/Helion-2.5-Rnd",
device="cuda"
)
conv_pipeline.set_system_prompt("You are a helpful coding assistant.")
response1 = conv_pipeline.generate_response("How do I sort a list in Python?")
print(f"Conv 1: {response1}\n")
response2 = conv_pipeline.generate_response("Can you show me an example?")
print(f"Conv 2: {response2}\n")
if __name__ == "__main__":
main()