|
|
|
|
|
""" |
|
|
Helion-2.5-Rnd Advanced Data Loader |
|
|
Efficient data loading and preprocessing for inference |
|
|
""" |
|
|
|
|
|
import json |
|
|
import logging |
|
|
from pathlib import Path |
|
|
from typing import Any, Dict, Iterator, List, Optional, Union |
|
|
|
|
|
import numpy as np |
|
|
from safetensors.torch import load_file |
|
|
|
|
|
logging.basicConfig(level=logging.INFO) |
|
|
logger = logging.getLogger(__name__) |
|
|
|
|
|
|
|
|
class SafeTensorsLoader:
    """Efficient SafeTensors model loading with validation"""

    def __init__(self, model_path: str, device: str = "cuda"):
        """
        Initialize SafeTensors loader

        Args:
            model_path: Path to model directory
            device: Target device for loading
        """
        self.model_path = Path(model_path)
        self.device = device
        self.index = self._load_index()
        self.loaded_shards = {}

    def _load_index(self) -> Dict:
        """Load SafeTensors index file"""
        index_path = self.model_path / "model.safetensors.index.json"

        if not index_path.exists():
            raise FileNotFoundError(f"Index file not found: {index_path}")

        with open(index_path, 'r') as f:
            index = json.load(f)

        logger.info(f"Loaded index with {len(index.get('weight_map', {}))} weight mappings")
        return index
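    # Index layout sketch (an assumption for illustration, not a spec):
    # model.safetensors.index.json is expected to look roughly like
    #   {
    #     "metadata":      {"model_name": "...", "total_size": 123, ...},
    #     "weight_map":    {"<tensor name>": "<shard file name>", ...},
    #     "file_metadata": {"<shard file name>": {"sha256": "..."}, ...}
    #   }
    # "metadata" and "weight_map" roughly follow the usual sharded
    # safetensors index format; "file_metadata" with per-shard SHA256
    # digests is specific to this loader (see validate_checksums below).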
|
|
|
|
|
    def get_shard_path(self, shard_name: str) -> Path:
        """Get full path to shard file"""
        return self.model_path / shard_name

    def load_shard(self, shard_name: str, lazy: bool = False) -> Dict:
        """
        Load a single SafeTensors shard

        Args:
            shard_name: Name of shard file
            lazy: If True, do not keep the loaded shard in the in-memory cache

        Returns:
            Dictionary of tensors
        """
        if shard_name in self.loaded_shards:
            logger.debug(f"Using cached shard: {shard_name}")
            return self.loaded_shards[shard_name]

        shard_path = self.get_shard_path(shard_name)

        if not shard_path.exists():
            raise FileNotFoundError(f"Shard not found: {shard_path}")

        logger.info(f"Loading shard: {shard_name}")

        try:
            tensors = load_file(str(shard_path), device=self.device)

            if not lazy:
                self.loaded_shards[shard_name] = tensors

            return tensors

        except Exception as e:
            logger.error(f"Failed to load shard {shard_name}: {e}")
            raise

    def load_weight(self, weight_name: str) -> Any:
        """
        Load a specific weight by name

        Args:
            weight_name: Name of the weight tensor

        Returns:
            Weight tensor
        """
        weight_map = self.index.get('weight_map', {})

        if weight_name not in weight_map:
            raise KeyError(f"Weight not found in index: {weight_name}")

        shard_name = weight_map[weight_name]
        tensors = self.load_shard(shard_name)

        return tensors[weight_name]

    def load_all_weights(self, progress_callback=None) -> Dict:
        """
        Load all model weights

        Args:
            progress_callback: Optional callable taking (shards_done, total_shards)

        Returns:
            Dictionary of all weights
        """
        all_weights = {}
        weight_map = self.index.get('weight_map', {})
        unique_shards = set(weight_map.values())

        logger.info(f"Loading {len(unique_shards)} shards...")

        for i, shard_name in enumerate(sorted(unique_shards)):
            tensors = self.load_shard(shard_name)
            all_weights.update(tensors)

            if progress_callback:
                progress_callback(i + 1, len(unique_shards))

        logger.info(f"Loaded {len(all_weights)} weight tensors")
        return all_weights

    def validate_checksums(self) -> Dict[str, bool]:
        """
        Validate SHA256 checksums of all shards

        Returns:
            Dictionary mapping shard names to validation status
            (True/False, or None if no checksum is recorded for a shard)
        """
        import hashlib

        results = {}
        file_metadata = self.index.get('file_metadata', {})

        for shard_name, metadata in file_metadata.items():
            expected_hash = metadata.get('sha256')

            if not expected_hash:
                results[shard_name] = None
                continue

            shard_path = self.get_shard_path(shard_name)

            if not shard_path.exists():
                results[shard_name] = False
                continue

            sha256 = hashlib.sha256()
            with open(shard_path, 'rb') as f:
                for chunk in iter(lambda: f.read(4096), b''):
                    sha256.update(chunk)

            actual_hash = sha256.hexdigest()
            results[shard_name] = (actual_hash == expected_hash)

            status = "✓" if results[shard_name] else "✗"
            logger.info(f"{status} {shard_name}")

        return results

    def get_model_info(self) -> Dict:
        """Get model information from index"""
        metadata = self.index.get('metadata', {})

        return {
            'model_name': metadata.get('model_name', 'Unknown'),
            'version': metadata.get('version', 'Unknown'),
            'total_size_bytes': metadata.get('total_size', 0),
            'total_size_gb': metadata.get('total_size', 0) / (1024**3),
            'format': metadata.get('format', 'safetensors'),
            'precision': metadata.get('precision', 'unknown'),
            'total_shards': metadata.get('total_shards', 0),
            'parameters': metadata.get('parameters', 'Unknown')
        }

    def clear_cache(self):
        """Clear loaded shard cache"""
        self.loaded_shards.clear()
        logger.info("Cleared shard cache")
|
|
|
|
|
|
|
|
class DatasetPreprocessor:
    """Preprocess datasets for inference"""

    def __init__(self, tokenizer=None, max_length: int = 131072):
        """
        Initialize preprocessor

        Args:
            tokenizer: Tokenizer instance
            max_length: Maximum sequence length
        """
        self.tokenizer = tokenizer
        self.max_length = max_length

    def preprocess_text(self, text: str) -> str:
        """
        Preprocess raw text

        Args:
            text: Input text

        Returns:
            Preprocessed text
        """
        # Collapse runs of whitespace (including newlines) into single spaces
        text = ' '.join(text.split())

        # Drop control characters, keeping printable text plus newline and tab
        text = ''.join(char for char in text if ord(char) >= 32 or char in '\n\t')

        return text.strip()

    def preprocess_chat_messages(self, messages: List[Dict[str, str]]) -> str:
        """
        Preprocess chat messages into prompt format

        Args:
            messages: List of message dictionaries with 'role' and 'content' keys

        Returns:
            Formatted prompt string
        """
        formatted = ""

        for msg in messages:
            role = msg.get('role', 'user')
            content = self.preprocess_text(msg.get('content', ''))
            formatted += f"<|im_start|>{role}\n{content}<|im_end|>\n"

        formatted += "<|im_start|>assistant\n"
        return formatted

    def batch_preprocess(
        self,
        texts: List[str],
        add_special_tokens: bool = True,
        padding: bool = True,
        truncation: bool = True
    ) -> Dict:
        """
        Batch preprocess texts

        Args:
            texts: List of input texts
            add_special_tokens: Whether to add special tokens
            padding: Whether to pad sequences
            truncation: Whether to truncate sequences

        Returns:
            Batch of preprocessed data
        """
        if self.tokenizer is None:
            raise ValueError("Tokenizer not initialized")

        processed_texts = [self.preprocess_text(text) for text in texts]

        encodings = self.tokenizer(
            processed_texts,
            add_special_tokens=add_special_tokens,
            padding=padding,
            truncation=truncation,
            max_length=self.max_length,
            return_tensors='pt'
        )

        return encodings

    def stream_process_file(
        self,
        file_path: str,
        batch_size: int = 32
    ) -> Iterator[Dict]:
        """
        Stream process large files in batches

        Args:
            file_path: Path to input file (.jsonl or .txt)
            batch_size: Number of samples per batch

        Yields:
            Batches of preprocessed data
        """
        path = Path(file_path)

        if path.suffix == '.jsonl':
            with open(path, 'r') as f:
                batch = []

                for line in f:
                    try:
                        data = json.loads(line)
                        text = data.get('text', '')
                        batch.append(text)

                        if len(batch) >= batch_size:
                            yield self.batch_preprocess(batch)
                            batch = []

                    except json.JSONDecodeError:
                        logger.warning("Skipping invalid JSON line")

                if batch:
                    yield self.batch_preprocess(batch)

        elif path.suffix == '.txt':
            with open(path, 'r') as f:
                batch = []

                for line in f:
                    batch.append(line.strip())

                    if len(batch) >= batch_size:
                        yield self.batch_preprocess(batch)
                        batch = []

                if batch:
                    yield self.batch_preprocess(batch)

        else:
            raise ValueError(f"Unsupported file format: {path.suffix}")
|
|
|
|
|
|
|
|
class InferenceDataCollator:
    """Collate data for efficient batch inference"""

    def __init__(self, pad_token_id: int = 128001):
        """
        Initialize data collator

        Args:
            pad_token_id: ID for padding token
        """
        self.pad_token_id = pad_token_id

    def __call__(self, features: List[Dict]) -> Dict:
        """
        Collate features into batch

        Args:
            features: List of feature dictionaries with 'input_ids'
                (and optionally 'attention_mask')

        Returns:
            Batched features
        """
        if not features:
            return {}

        # Pad every sequence to the longest one in the batch
        max_length = max(len(f['input_ids']) for f in features)

        batch = {
            'input_ids': [],
            'attention_mask': []
        }

        for feature in features:
            input_ids = feature['input_ids']
            attention_mask = feature.get('attention_mask', [1] * len(input_ids))

            # Right-pad input IDs and mask out the padded positions
            padding_length = max_length - len(input_ids)

            input_ids = input_ids + [self.pad_token_id] * padding_length
            attention_mask = attention_mask + [0] * padding_length

            batch['input_ids'].append(input_ids)
            batch['attention_mask'].append(attention_mask)

        batch['input_ids'] = np.array(batch['input_ids'], dtype=np.int64)
        batch['attention_mask'] = np.array(batch['attention_mask'], dtype=np.int64)

        return batch

    def dynamic_padding(self, features: List[Dict], padding_multiple: int = 8) -> Dict:
        """
        Apply dynamic padding optimized for hardware

        Args:
            features: List of feature dictionaries
            padding_multiple: Pad to multiple of this value

        Returns:
            Batched features with optimal padding
        """
        if not features:
            return {}

        max_length = max(len(f['input_ids']) for f in features)

        # Round the batch length up to the next multiple of padding_multiple
        padded_length = ((max_length + padding_multiple - 1) // padding_multiple) * padding_multiple

        batch = {
            'input_ids': [],
            'attention_mask': []
        }

        for feature in features:
            input_ids = feature['input_ids']
            attention_mask = feature.get('attention_mask', [1] * len(input_ids))

            padding_length = padded_length - len(input_ids)

            input_ids = input_ids + [self.pad_token_id] * padding_length
            attention_mask = attention_mask + [0] * padding_length

            batch['input_ids'].append(input_ids)
            batch['attention_mask'].append(attention_mask)

        batch['input_ids'] = np.array(batch['input_ids'], dtype=np.int64)
        batch['attention_mask'] = np.array(batch['attention_mask'], dtype=np.int64)

        return batch
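
# --- Usage sketch (illustrative only) ---
# Demonstrates the difference between plain collation (pad to the longest
# sequence in the batch) and dynamic_padding (pad up to a multiple of 8).
# The token IDs and pad_token_id=0 below are made up purely for illustration.
def _example_collation() -> None:
    collator = InferenceDataCollator(pad_token_id=0)
    features = [
        {'input_ids': [11, 12, 13]},
        {'input_ids': [21, 22, 23, 24, 25]},
    ]

    plain = collator(features)
    aligned = collator.dynamic_padding(features, padding_multiple=8)

    logger.info(f"plain shape:   {plain['input_ids'].shape}")    # (2, 5)
    logger.info(f"aligned shape: {aligned['input_ids'].shape}")  # (2, 8)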
|
|
|
|
|
|
|
|
class CachedDataLoader:
    """Data loader with caching for repeated inference"""

    def __init__(self, cache_dir: str = "./cache"):
        """
        Initialize cached data loader

        Args:
            cache_dir: Directory for cache storage
        """
        self.cache_dir = Path(cache_dir)
        self.cache_dir.mkdir(parents=True, exist_ok=True)

    def get_cache_key(self, text: str) -> str:
        """Generate cache key from text"""
        import hashlib
        return hashlib.sha256(text.encode()).hexdigest()

    def load_from_cache(self, cache_key: str) -> Optional[Any]:
        """
        Load data from cache

        Args:
            cache_key: Cache identifier

        Returns:
            Cached data, or None on a cache miss or read error
        """
        cache_path = self.cache_dir / f"{cache_key}.json"

        if not cache_path.exists():
            return None

        try:
            with open(cache_path, 'r') as f:
                return json.load(f)
        except Exception as e:
            logger.warning(f"Failed to load from cache: {e}")
            return None

    def save_to_cache(self, cache_key: str, data: Any):
        """
        Save data to cache

        Args:
            cache_key: Cache identifier
            data: JSON-serializable data to cache
        """
        cache_path = self.cache_dir / f"{cache_key}.json"

        try:
            with open(cache_path, 'w') as f:
                json.dump(data, f)
        except Exception as e:
            logger.warning(f"Failed to save to cache: {e}")

    def clear_cache(self):
        """Clear all cached data"""
        import shutil
        shutil.rmtree(self.cache_dir)
        self.cache_dir.mkdir(parents=True, exist_ok=True)
        logger.info("Cache cleared")
|
|
|
|
|
|
|
|
def main():
    """Example usage"""
    loader = SafeTensorsLoader("./models/helion")

    # Print basic model information from the index metadata
    info = loader.get_model_info()
    print(f"Model: {info['model_name']}")
    print(f"Size: {info['total_size_gb']:.2f} GB")
    print(f"Shards: {info['total_shards']}")

    # Verify shard integrity against the recorded SHA256 checksums
    print("\nValidating checksums...")
    results = loader.validate_checksums()
    valid_count = sum(1 for v in results.values() if v)
    print(f"Valid: {valid_count}/{len(results)}")


if __name__ == "__main__":
    main()