"""
Helion-2.5-Rnd Inference Pipeline

High-level pipeline for easy model usage.
"""

import logging
import time
from typing import Any, Dict, List, Optional, Union

import torch
from transformers import (
    AutoModelForCausalLM,
    AutoTokenizer,
    StoppingCriteria,
    StoppingCriteriaList,
)

logging.basicConfig(level=logging.INFO)
logger = logging.getLogger(__name__)


class StopOnTokens(StoppingCriteria):
    """Stop generation when specific tokens are generated."""

    def __init__(self, stop_token_ids: List[int]):
        self.stop_token_ids = stop_token_ids

    def __call__(
        self,
        input_ids: torch.LongTensor,
        scores: torch.FloatTensor,
        **kwargs
    ) -> bool:
        for stop_id in self.stop_token_ids:
            if input_ids[0][-1] == stop_id:
                return True
        return False


class HelionPipeline:
    """High-level inference pipeline for the Helion model."""

    def __init__(
        self,
        model_path: str,
        device: str = "cuda",
        torch_dtype=torch.bfloat16,
        load_in_8bit: bool = False,
        trust_remote_code: bool = True
    ):
        """
        Initialize the Helion pipeline.

        Args:
            model_path: Local path to the model or Hugging Face model ID
            device: Device to load the model on
            torch_dtype: Torch data type for the model weights
            load_in_8bit: Whether to load the model in 8-bit precision
            trust_remote_code: Whether to trust remote code from the model repo
        """
        logger.info(f"Loading Helion model from {model_path}")

        self.device = device
        self.model_path = model_path

        self.tokenizer = AutoTokenizer.from_pretrained(
            model_path,
            trust_remote_code=trust_remote_code
        )
        # Ensure a pad token exists so padding and batched generation work.
        if self.tokenizer.pad_token_id is None:
            self.tokenizer.pad_token = self.tokenizer.eos_token

        self.model = AutoModelForCausalLM.from_pretrained(
            model_path,
            torch_dtype=torch_dtype,
            device_map="auto" if device == "cuda" else None,
            load_in_8bit=load_in_8bit,
            trust_remote_code=trust_remote_code
        )

        if device != "cuda" and not load_in_8bit:
            self.model = self.model.to(device)

        self.model.eval()

        # Stop on EOS and the ChatML end-of-turn token, skipping any ID the
        # tokenizer does not have.
        self.stop_token_ids = [
            token_id
            for token_id in (
                self.tokenizer.eos_token_id,
                self.tokenizer.convert_tokens_to_ids("<|im_end|>"),
            )
            if token_id is not None
        ]

        logger.info("Model loaded successfully")

    def generate(
        self,
        prompt: str,
        max_new_tokens: int = 512,
        temperature: float = 0.7,
        top_p: float = 0.9,
        top_k: int = 50,
        repetition_penalty: float = 1.1,
        do_sample: bool = True,
        num_return_sequences: int = 1,
        **kwargs
    ) -> Union[str, List[str]]:
        """
        Generate text from a prompt.

        Args:
            prompt: Input prompt
            max_new_tokens: Maximum number of tokens to generate
            temperature: Sampling temperature
            top_p: Nucleus sampling parameter
            top_k: Top-k sampling parameter
            repetition_penalty: Repetition penalty
            do_sample: Whether to sample (greedy decoding if False)
            num_return_sequences: Number of sequences to return
            **kwargs: Additional generation parameters

        Returns:
            Generated text, or a list of texts when num_return_sequences > 1
        """
        inputs = self.tokenizer(
            prompt,
            return_tensors="pt",
            truncation=True,
            max_length=self.model.config.max_position_embeddings
        ).to(self.device)

        stopping_criteria = StoppingCriteriaList([
            StopOnTokens(self.stop_token_ids)
        ])

        with torch.no_grad():
            start_time = time.time()

            outputs = self.model.generate(
                **inputs,
                max_new_tokens=max_new_tokens,
                temperature=temperature,
                top_p=top_p,
                top_k=top_k,
                repetition_penalty=repetition_penalty,
                do_sample=do_sample,
                num_return_sequences=num_return_sequences,
                stopping_criteria=stopping_criteria,
                pad_token_id=self.tokenizer.pad_token_id,
                **kwargs
            )

            generation_time = time.time() - start_time

        # Decode only the newly generated tokens, skipping the prompt.
        generated_texts = []
        for output in outputs:
            text = self.tokenizer.decode(
                output[inputs['input_ids'].shape[1]:],
                skip_special_tokens=True
            )
            generated_texts.append(text.strip())

        logger.info(f"Generated {len(generated_texts)} sequences in {generation_time:.2f}s")

        if num_return_sequences == 1:
            return generated_texts[0]
        return generated_texts

    def chat(
        self,
        messages: List[Dict[str, str]],
        max_new_tokens: int = 512,
        temperature: float = 0.7,
        **kwargs
    ) -> str:
        """
        Chat completion.

        Args:
            messages: List of message dicts with 'role' and 'content' keys
            max_new_tokens: Maximum number of tokens to generate
            temperature: Sampling temperature
            **kwargs: Additional generation parameters

        Returns:
            Assistant response
        """
        prompt = self._format_chat_prompt(messages)

        response = self.generate(
            prompt,
            max_new_tokens=max_new_tokens,
            temperature=temperature,
            **kwargs
        )

        return response

    def _format_chat_prompt(self, messages: List[Dict[str, str]]) -> str:
        """Format messages into a ChatML-style chat prompt."""
        formatted = ""

        for msg in messages:
            role = msg.get('role', 'user')
            content = msg.get('content', '')
            formatted += f"<|im_start|>{role}\n{content}<|im_end|>\n"

        formatted += "<|im_start|>assistant\n"
        return formatted

    def batch_generate(
        self,
        prompts: List[str],
        max_new_tokens: int = 512,
        temperature: float = 0.7,
        batch_size: int = 4,
        **kwargs
    ) -> List[str]:
        """
        Generate completions for multiple prompts in batches.

        Args:
            prompts: List of input prompts
            max_new_tokens: Maximum number of tokens to generate
            temperature: Sampling temperature
            batch_size: Batch size for processing
            **kwargs: Additional generation parameters

        Returns:
            List of generated texts, in the same order as the prompts
        """
        all_outputs = []

        # Decoder-only models need left padding so generation continues from
        # real prompt tokens rather than padding.
        original_padding_side = self.tokenizer.padding_side
        self.tokenizer.padding_side = "left"

        for i in range(0, len(prompts), batch_size):
            batch = prompts[i:i + batch_size]

            inputs = self.tokenizer(
                batch,
                return_tensors="pt",
                padding=True,
                truncation=True,
                max_length=self.model.config.max_position_embeddings
            ).to(self.device)

            with torch.no_grad():
                outputs = self.model.generate(
                    **inputs,
                    max_new_tokens=max_new_tokens,
                    temperature=temperature,
                    pad_token_id=self.tokenizer.pad_token_id,
                    **kwargs
                )

            # Decode only the tokens generated after the padded prompt.
            for j, output in enumerate(outputs):
                text = self.tokenizer.decode(
                    output[inputs['input_ids'][j].shape[0]:],
                    skip_special_tokens=True
                )
                all_outputs.append(text.strip())

        self.tokenizer.padding_side = original_padding_side

        logger.info(f"Generated {len(all_outputs)} outputs")
        return all_outputs

    def stream_generate(
        self,
        prompt: str,
        max_new_tokens: int = 512,
        temperature: float = 0.7,
        **kwargs
    ):
        """
        Stream generation token by token.

        Args:
            prompt: Input prompt
            max_new_tokens: Maximum number of tokens to generate
            temperature: Sampling temperature
            **kwargs: Additional generation parameters

        Yields:
            Generated tokens as decoded strings
        """
        inputs = self.tokenizer(prompt, return_tensors="pt").to(self.device)

        stopping_criteria = StoppingCriteriaList([
            StopOnTokens(self.stop_token_ids)
        ])

        # Simple streaming loop: generate one token at a time by re-feeding the
        # full sequence. Easy to follow, but the KV cache is not reused across calls.
        with torch.no_grad():
            for _ in range(max_new_tokens):
                outputs = self.model.generate(
                    **inputs,
                    max_new_tokens=1,
                    temperature=temperature,
                    stopping_criteria=stopping_criteria,
                    pad_token_id=self.tokenizer.pad_token_id,
                    **kwargs
                )

                new_token_id = outputs[0, -1].item()

                if new_token_id in self.stop_token_ids:
                    break

                new_token = self.tokenizer.decode([new_token_id])
                yield new_token

                # Extend the context with the newly generated token.
                inputs = {
                    'input_ids': outputs,
                    'attention_mask': torch.ones_like(outputs)
                }

    def get_embeddings(self, text: str) -> torch.Tensor:
        """
        Get embeddings for text.

        Args:
            text: Input text

        Returns:
            Embedding tensor (mean-pooled last hidden state)
        """
        inputs = self.tokenizer(text, return_tensors="pt").to(self.device)

        with torch.no_grad():
            outputs = self.model(**inputs, output_hidden_states=True)
            # Mean-pool the final hidden layer over the sequence dimension.
            embeddings = outputs.hidden_states[-1].mean(dim=1)

        return embeddings

    def score_text(self, text: str) -> float:
        """
        Calculate the perplexity of a text under the model.

        Args:
            text: Input text

        Returns:
            Perplexity score (lower means more predictable text)
        """
        inputs = self.tokenizer(text, return_tensors="pt").to(self.device)

        with torch.no_grad():
            outputs = self.model(**inputs, labels=inputs['input_ids'])
            loss = outputs.loss
            perplexity = torch.exp(loss).item()

        return perplexity

    def cleanup(self):
        """Clean up resources."""
        del self.model
        del self.tokenizer
        torch.cuda.empty_cache()
        logger.info("Pipeline cleaned up")


class ConversationPipeline(HelionPipeline):
    """Pipeline with conversation history management."""

    def __init__(self, *args, **kwargs):
        super().__init__(*args, **kwargs)
        self.conversation_history: List[Dict[str, str]] = []
        self.system_prompt: Optional[str] = None

    def set_system_prompt(self, prompt: str):
        """Set the system prompt for the conversation."""
        self.system_prompt = prompt

    def add_message(self, role: str, content: str):
        """Add a message to the conversation history."""
        self.conversation_history.append({
            'role': role,
            'content': content
        })

    def generate_response(
        self,
        user_message: str,
        max_new_tokens: int = 512,
        temperature: float = 0.7,
        **kwargs
    ) -> str:
        """
        Generate a response in the conversation context.

        Args:
            user_message: User's message
            max_new_tokens: Maximum number of tokens to generate
            temperature: Sampling temperature
            **kwargs: Additional generation parameters

        Returns:
            Assistant response
        """
        messages = []

        if self.system_prompt:
            messages.append({
                'role': 'system',
                'content': self.system_prompt
            })

        messages.extend(self.conversation_history)
        messages.append({
            'role': 'user',
            'content': user_message
        })

        response = self.chat(
            messages,
            max_new_tokens=max_new_tokens,
            temperature=temperature,
            **kwargs
        )

        self.add_message('user', user_message)
        self.add_message('assistant', response)

        return response

    def reset_conversation(self):
        """Reset the conversation history."""
        self.conversation_history.clear()
        logger.info("Conversation history reset")


def main():
    """Example usage"""
    pipeline = HelionPipeline(
        model_path="DeepXR/Helion-2.5-Rnd",
        device="cuda"
    )

    # Basic single-prompt generation
    prompt = "Explain quantum computing in simple terms:"
    response = pipeline.generate(prompt, max_new_tokens=256)
    print(f"Response: {response}\n")

    # Chat completion
    messages = [
        {"role": "system", "content": "You are a helpful assistant."},
        {"role": "user", "content": "What is the capital of France?"}
    ]
    response = pipeline.chat(messages)
    print(f"Chat response: {response}\n")

    # Batched generation
    prompts = [
        "Write a haiku about AI:",
        "Explain machine learning:",
        "What is Python?"
    ]
    responses = pipeline.batch_generate(prompts, batch_size=2)
    for i, resp in enumerate(responses):
        print(f"Batch {i+1}: {resp}\n")

    # Multi-turn conversation with history
    conv_pipeline = ConversationPipeline(
        model_path="DeepXR/Helion-2.5-Rnd",
        device="cuda"
    )
    conv_pipeline.set_system_prompt("You are a helpful coding assistant.")

    response1 = conv_pipeline.generate_response("How do I sort a list in Python?")
    print(f"Conv 1: {response1}\n")

    response2 = conv_pipeline.generate_response("Can you show me an example?")
    print(f"Conv 2: {response2}\n")


if __name__ == "__main__":
    main()