"""
Helion-2.5-Rnd Utility Functions

Common utilities for model inference and processing
"""

import json
import logging
import re
import time
from pathlib import Path
from typing import Any, Dict, List, Optional, Tuple

import torch
import yaml
from transformers import AutoTokenizer

logger = logging.getLogger(__name__)


class ModelConfig:
    """Model configuration manager"""

    def __init__(self, config_path: str = "model_config.yaml"):
        """Load configuration from YAML file"""
        self.config_path = Path(config_path)
        self.config = self._load_config()

    def _load_config(self) -> Dict[str, Any]:
        """Load YAML configuration, falling back to defaults"""
        if not self.config_path.exists():
            logger.warning(f"Config file not found: {self.config_path}")
            return self._default_config()

        with open(self.config_path, 'r') as f:
            config = yaml.safe_load(f)

        # An empty YAML file parses to None; fall back to defaults.
        if not config:
            logger.warning(f"Config file is empty: {self.config_path}")
            return self._default_config()

        logger.info(f"Loaded configuration from {self.config_path}")
        return config

    def _default_config(self) -> Dict[str, Any]:
        """Return default configuration"""
        return {
            "model": {
                "name": "DeepXR/Helion-2.5-Rnd",
                "max_position_embeddings": 131072,
            },
            "inference": {
                "default_parameters": {
                    "temperature": 0.7,
                    "top_p": 0.9,
                    "max_new_tokens": 4096,
                }
            }
        }

    def get(self, key: str, default: Any = None) -> Any:
        """Get configuration value by dot-separated key"""
        keys = key.split('.')
        value = self.config

        for k in keys:
            if isinstance(value, dict):
                value = value.get(k)
                if value is None:
                    return default
            else:
                return default

        return value
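
# Example: dot-separated lookup against the default config above (a minimal
# sketch; the "model_config.yaml" path is illustrative and need not exist,
# in which case _default_config() is used):
#
#     config = ModelConfig("model_config.yaml")
#     config.get("model.name")                                 # "DeepXR/Helion-2.5-Rnd"
#     config.get("inference.default_parameters.temperature")   # 0.7
#     config.get("inference.missing_key", default="fallback")  # "fallback"
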
class TokenCounter:
    """Token counting utilities"""

    def __init__(self, model_name: str = "meta-llama/Meta-Llama-3.1-70B"):
        """Initialize tokenizer for counting"""
        try:
            self.tokenizer = AutoTokenizer.from_pretrained(model_name)
        except Exception as e:
            logger.warning(f"Failed to load tokenizer: {e}")
            self.tokenizer = None

    def count_tokens(self, text: str) -> int:
        """Count tokens in text"""
        if self.tokenizer is None:
            # Rough heuristic: roughly 4 characters per token for English text
            return len(text) // 4

        return len(self.tokenizer.encode(text))

    def count_messages_tokens(self, messages: List[Dict[str, str]]) -> int:
        """Count tokens in message list"""
        total = 0
        for msg in messages:
            total += self.count_tokens(msg.get('role', ''))
            total += self.count_tokens(msg.get('content', ''))
            # Fixed overhead per message for chat-template control tokens
            total += 4

        return total

    def truncate_to_tokens(
        self,
        text: str,
        max_tokens: int,
        from_end: bool = False
    ) -> str:
        """Truncate text to maximum token count"""
        if self.tokenizer is None:
            # Fall back to the same ~4 chars/token heuristic
            max_chars = max_tokens * 4
            if from_end:
                return text[-max_chars:]
            return text[:max_chars]

        tokens = self.tokenizer.encode(text)

        if len(tokens) <= max_tokens:
            return text

        if from_end:
            truncated_tokens = tokens[-max_tokens:]
        else:
            truncated_tokens = tokens[:max_tokens]

        return self.tokenizer.decode(truncated_tokens)
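
# Example: counting and truncation (a minimal sketch; loading the tokenizer
# needs network or cache access, otherwise the ~4-chars-per-token fallback
# applies; long_text is a placeholder for your own input):
#
#     counter = TokenCounter()
#     counter.count_tokens("Hello, world!")                      # tokenizer count
#     counter.truncate_to_tokens(long_text, 128)                 # keep first 128 tokens
#     counter.truncate_to_tokens(long_text, 128, from_end=True)  # keep last 128 tokens
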
class PromptTemplate:
    """Prompt templating utilities"""

    TEMPLATES = {
        # Jinja2-style template; not renderable with str.format (see format()).
        "chat": (
            "{% for message in messages %}"
            "<|im_start|>{{ message.role }}\n{{ message.content }}<|im_end|>\n"
            "{% endfor %}"
            "<|im_start|>assistant\n"
        ),
        "instruction": (
            "### Instruction:\n{instruction}\n\n"
            "### Response:\n"
        ),
        "qa": (
            "Question: {question}\n\n"
            "Answer: "
        ),
        "code": (
            "# Task: {task}\n\n"
            "```{language}\n"
        ),
        "analysis": (
            "Analyze the following:\n\n{content}\n\n"
            "Analysis:"
        )
    }

    @classmethod
    def format(cls, template_name: str, **kwargs) -> str:
        """Format a template with given arguments"""
        template = cls.TEMPLATES.get(template_name)
        if template is None:
            raise ValueError(f"Unknown template: {template_name}")

        # The "chat" template uses Jinja2 syntax, which str.format cannot
        # render; route chat messages through format_chat() instead.
        if template_name == "chat":
            raise ValueError("Use format_chat() for the 'chat' template")

        try:
            return template.format(**kwargs)
        except KeyError as e:
            raise ValueError(f"Missing required argument: {e}")

    @classmethod
    def format_chat(cls, messages: List[Dict[str, str]]) -> str:
        """Format chat messages into prompt"""
        formatted = ""
        for msg in messages:
            role = msg.get('role', 'user')
            content = msg.get('content', '')
            formatted += f"<|im_start|>{role}\n{content}<|im_end|>\n"
        formatted += "<|im_start|>assistant\n"
        return formatted
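
# Example: rendering templates (a minimal sketch):
#
#     PromptTemplate.format("qa", question="What is attention?")
#     # -> "Question: What is attention?\n\nAnswer: "
#
#     PromptTemplate.format_chat([
#         {"role": "system", "content": "You are helpful."},
#         {"role": "user", "content": "Hi!"},
#     ])
#     # -> "<|im_start|>system\nYou are helpful.<|im_end|>\n..."
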
class ResponseParser:
    """Parse and validate model responses"""

    @staticmethod
    def extract_code(response: str, language: Optional[str] = None) -> str:
        """Extract code from markdown code blocks"""
        if language:
            pattern = f"```{re.escape(language)}\n(.*?)```"
        else:
            pattern = r"```(?:\w+)?\n(.*?)```"

        matches = re.findall(pattern, response, re.DOTALL)

        if matches:
            return matches[0].strip()

        # No fenced block found; return the response as-is
        return response.strip()

    @staticmethod
    def extract_json(response: str) -> Optional[Dict]:
        """Extract and parse JSON from response"""
        # Prefer an explicit ```json fenced block
        json_pattern = r"```json\n(.*?)```"
        matches = re.findall(json_pattern, response, re.DOTALL)

        if matches:
            try:
                return json.loads(matches[0])
            except json.JSONDecodeError:
                pass

        # Fall back to parsing the whole response as JSON
        try:
            return json.loads(response)
        except json.JSONDecodeError:
            return None

    @staticmethod
    def split_sections(response: str) -> Dict[str, str]:
        """Split response into sections based on markdown headers"""
        sections = {}
        current_section = "main"
        current_content = []

        for line in response.split('\n'):
            header_match = re.match(r'^#{1,3}\s+(.+)$', line)
            if header_match:
                # Close out the previous section
                if current_content:
                    sections[current_section] = '\n'.join(current_content).strip()

                # Start a new section keyed on the normalized header text
                current_section = header_match.group(1).lower().replace(' ', '_')
                current_content = []
            else:
                current_content.append(line)

        # Save the final section
        if current_content:
            sections[current_section] = '\n'.join(current_content).strip()

        return sections
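
# Example: pulling structured output from a model response (a minimal sketch):
#
#     ResponseParser.extract_code("```python\nprint('hi')\n```", "python")
#     # -> "print('hi')"
#
#     ResponseParser.extract_json('```json\n{"ok": true}\n```')
#     # -> {"ok": True}
#
#     ResponseParser.split_sections("# Summary\ntext\n## Details\nmore")
#     # -> {"summary": "text", "details": "more"}
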
class PerformanceMonitor:
    """Monitor inference performance"""

    def __init__(self):
        self.requests = []
        self.start_time = time.time()

    def record_request(
        self,
        duration: float,
        input_tokens: int,
        output_tokens: int,
        success: bool = True
    ):
        """Record a request"""
        self.requests.append({
            'timestamp': time.time(),
            'duration': duration,
            'input_tokens': input_tokens,
            'output_tokens': output_tokens,
            'success': success,
            'tokens_per_second': output_tokens / duration if duration > 0 else 0
        })

    def get_stats(self) -> Dict[str, Any]:
        """Get performance statistics"""
        if not self.requests:
            return {
                'total_requests': 0,
                'uptime_seconds': time.time() - self.start_time
            }

        successful = [r for r in self.requests if r['success']]
        # Guard against division by zero when every request failed
        n = len(successful) or 1

        return {
            'total_requests': len(self.requests),
            'successful_requests': len(successful),
            'failed_requests': len(self.requests) - len(successful),
            'uptime_seconds': time.time() - self.start_time,
            'avg_duration': sum(r['duration'] for r in successful) / n,
            'avg_tokens_per_second': sum(r['tokens_per_second'] for r in successful) / n,
            'total_input_tokens': sum(r['input_tokens'] for r in self.requests),
            'total_output_tokens': sum(r['output_tokens'] for r in self.requests),
        }

    def reset(self):
        """Reset statistics"""
        self.requests = []
        self.start_time = time.time()
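
# Example: wrapping a generation call with timing (a minimal sketch; the
# generate() call and token counts are placeholders for your own pipeline):
#
#     monitor = PerformanceMonitor()
#     start = time.time()
#     # output = generate(prompt)          # hypothetical generation call
#     monitor.record_request(
#         duration=time.time() - start,
#         input_tokens=120,
#         output_tokens=512,
#     )
#     print(monitor.get_stats()['avg_tokens_per_second'])
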
class SafetyFilter:
    """Basic safety filtering for outputs"""

    UNSAFE_PATTERNS = [
        r'\b(kill|murder|suicide)\s+(?:yourself|myself)',
        r'\b(bomb|weapon)\s+(?:making|instructions)',
        r'\bhate\s+speech\b',
    ]

    @classmethod
    def is_safe(cls, text: str) -> Tuple[bool, Optional[str]]:
        """
        Check if text is safe

        Returns:
            (is_safe, reason)
        """
        text_lower = text.lower()

        for pattern in cls.UNSAFE_PATTERNS:
            if re.search(pattern, text_lower):
                return False, f"Matched unsafe pattern: {pattern}"

        return True, None

    @classmethod
    def filter_response(cls, text: str, replacement: str = "[FILTERED]") -> str:
        """Filter unsafe content from response"""
        is_safe, reason = cls.is_safe(text)

        if not is_safe:
            logger.warning(f"Filtered unsafe content: {reason}")
            return replacement

        return text
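
# Example: gating a model response before returning it (a minimal sketch;
# response_text is a placeholder for the raw model output):
#
#     safe, reason = SafetyFilter.is_safe(response_text)
#     cleaned = SafetyFilter.filter_response(response_text)  # "[FILTERED]" if unsafe
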
def get_gpu_info() -> Dict[str, Any]:
    """Get GPU information"""
    if not torch.cuda.is_available():
        return {"available": False}

    info = {
        "available": True,
        "count": torch.cuda.device_count(),
        "devices": []
    }

    for i in range(torch.cuda.device_count()):
        device_info = {
            "id": i,
            "name": torch.cuda.get_device_name(i),
            "memory_total": torch.cuda.get_device_properties(i).total_memory,
            "memory_allocated": torch.cuda.memory_allocated(i),
            "memory_reserved": torch.cuda.memory_reserved(i),
        }
        info["devices"].append(device_info)

    return info


def format_bytes(bytes_value: float) -> str:
    """Format a byte count as a human-readable string"""
    for unit in ['B', 'KB', 'MB', 'GB', 'TB']:
        if bytes_value < 1024.0:
            return f"{bytes_value:.2f} {unit}"
        bytes_value /= 1024.0
    return f"{bytes_value:.2f} PB"
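
# Example: reporting per-device memory (a minimal sketch; returns device
# entries only on a CUDA build of PyTorch with at least one visible GPU):
#
#     for dev in get_gpu_info().get("devices", []):
#         print(dev["name"], format_bytes(dev["memory_total"]))
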