#!/usr/bin/env python3 """ Helion-2.5-Rnd Advanced Data Loader Efficient data loading and preprocessing for inference """ import json import logging from pathlib import Path from typing import Any, Dict, Iterator, List, Optional, Union import numpy as np from safetensors.torch import load_file logging.basicConfig(level=logging.INFO) logger = logging.getLogger(__name__) class SafeTensorsLoader: """Efficient SafeTensors model loading with validation""" def __init__(self, model_path: str, device: str = "cuda"): """ Initialize SafeTensors loader Args: model_path: Path to model directory device: Target device for loading """ self.model_path = Path(model_path) self.device = device self.index = self._load_index() self.loaded_shards = {} def _load_index(self) -> Dict: """Load SafeTensors index file""" index_path = self.model_path / "model.safetensors.index.json" if not index_path.exists(): raise FileNotFoundError(f"Index file not found: {index_path}") with open(index_path, 'r') as f: index = json.load(f) logger.info(f"Loaded index with {len(index.get('weight_map', {}))} weight mappings") return index def get_shard_path(self, shard_name: str) -> Path: """Get full path to shard file""" return self.model_path / shard_name def load_shard(self, shard_name: str, lazy: bool = False) -> Dict: """ Load a single SafeTensors shard Args: shard_name: Name of shard file lazy: Whether to use lazy loading Returns: Dictionary of tensors """ if shard_name in self.loaded_shards: logger.debug(f"Using cached shard: {shard_name}") return self.loaded_shards[shard_name] shard_path = self.get_shard_path(shard_name) if not shard_path.exists(): raise FileNotFoundError(f"Shard not found: {shard_path}") logger.info(f"Loading shard: {shard_name}") try: tensors = load_file(str(shard_path), device=self.device) if not lazy: self.loaded_shards[shard_name] = tensors return tensors except Exception as e: logger.error(f"Failed to load shard {shard_name}: {e}") raise def load_weight(self, weight_name: str) -> Any: """ Load a specific weight by name Args: weight_name: Name of the weight tensor Returns: Weight tensor """ weight_map = self.index.get('weight_map', {}) if weight_name not in weight_map: raise KeyError(f"Weight not found in index: {weight_name}") shard_name = weight_map[weight_name] tensors = self.load_shard(shard_name) return tensors[weight_name] def load_all_weights(self, progress_callback=None) -> Dict: """ Load all model weights Args: progress_callback: Optional callback for progress updates Returns: Dictionary of all weights """ all_weights = {} weight_map = self.index.get('weight_map', {}) unique_shards = set(weight_map.values()) logger.info(f"Loading {len(unique_shards)} shards...") for i, shard_name in enumerate(sorted(unique_shards)): tensors = self.load_shard(shard_name) all_weights.update(tensors) if progress_callback: progress_callback(i + 1, len(unique_shards)) logger.info(f"Loaded {len(all_weights)} weight tensors") return all_weights def validate_checksums(self) -> Dict[str, bool]: """ Validate SHA256 checksums of all shards Returns: Dictionary mapping shard names to validation status """ import hashlib results = {} file_metadata = self.index.get('file_metadata', {}) for shard_name, metadata in file_metadata.items(): expected_hash = metadata.get('sha256') if not expected_hash: results[shard_name] = None continue shard_path = self.get_shard_path(shard_name) if not shard_path.exists(): results[shard_name] = False continue sha256 = hashlib.sha256() with open(shard_path, 'rb') as f: for chunk in iter(lambda: f.read(4096), b''): sha256.update(chunk) actual_hash = sha256.hexdigest() results[shard_name] = (actual_hash == expected_hash) status = "✓" if results[shard_name] else "✗" logger.info(f"{status} {shard_name}") return results def get_model_info(self) -> Dict: """Get model information from index""" metadata = self.index.get('metadata', {}) return { 'model_name': metadata.get('model_name', 'Unknown'), 'version': metadata.get('version', 'Unknown'), 'total_size_bytes': metadata.get('total_size', 0), 'total_size_gb': metadata.get('total_size', 0) / (1024**3), 'format': metadata.get('format', 'safetensors'), 'precision': metadata.get('precision', 'unknown'), 'total_shards': metadata.get('total_shards', 0), 'parameters': metadata.get('parameters', 'Unknown') } def clear_cache(self): """Clear loaded shard cache""" self.loaded_shards.clear() logger.info("Cleared shard cache") class DatasetPreprocessor: """Preprocess datasets for inference""" def __init__(self, tokenizer=None, max_length: int = 131072): """ Initialize preprocessor Args: tokenizer: Tokenizer instance max_length: Maximum sequence length """ self.tokenizer = tokenizer self.max_length = max_length def preprocess_text(self, text: str) -> str: """ Preprocess raw text Args: text: Input text Returns: Preprocessed text """ # Remove excessive whitespace text = ' '.join(text.split()) # Remove control characters text = ''.join(char for char in text if ord(char) >= 32 or char in '\n\t') return text.strip() def preprocess_chat_messages(self, messages: List[Dict[str, str]]) -> str: """ Preprocess chat messages into prompt format Args: messages: List of message dictionaries Returns: Formatted prompt string """ formatted = "" for msg in messages: role = msg.get('role', 'user') content = self.preprocess_text(msg.get('content', '')) formatted += f"<|im_start|>{role}\n{content}<|im_end|>\n" formatted += "<|im_start|>assistant\n" return formatted def batch_preprocess( self, texts: List[str], add_special_tokens: bool = True, padding: bool = True, truncation: bool = True ) -> Dict: """ Batch preprocess texts Args: texts: List of input texts add_special_tokens: Whether to add special tokens padding: Whether to pad sequences truncation: Whether to truncate sequences Returns: Batch of preprocessed data """ if self.tokenizer is None: raise ValueError("Tokenizer not initialized") processed_texts = [self.preprocess_text(text) for text in texts] encodings = self.tokenizer( processed_texts, add_special_tokens=add_special_tokens, padding=padding, truncation=truncation, max_length=self.max_length, return_tensors='pt' ) return encodings def stream_process_file( self, file_path: str, batch_size: int = 32 ) -> Iterator[Dict]: """ Stream process large files in batches Args: file_path: Path to input file batch_size: Number of samples per batch Yields: Batches of preprocessed data """ path = Path(file_path) if path.suffix == '.jsonl': with open(path, 'r') as f: batch = [] for line in f: try: data = json.loads(line) text = data.get('text', '') batch.append(text) if len(batch) >= batch_size: yield self.batch_preprocess(batch) batch = [] except json.JSONDecodeError: logger.warning(f"Skipping invalid JSON line") if batch: yield self.batch_preprocess(batch) elif path.suffix == '.txt': with open(path, 'r') as f: batch = [] for line in f: batch.append(line.strip()) if len(batch) >= batch_size: yield self.batch_preprocess(batch) batch = [] if batch: yield self.batch_preprocess(batch) else: raise ValueError(f"Unsupported file format: {path.suffix}") class InferenceDataCollator: """Collate data for efficient batch inference""" def __init__(self, pad_token_id: int = 128001): """ Initialize data collator Args: pad_token_id: ID for padding token """ self.pad_token_id = pad_token_id def __call__(self, features: List[Dict]) -> Dict: """ Collate features into batch Args: features: List of feature dictionaries Returns: Batched features """ if not features: return {} # Get maximum sequence length in batch max_length = max(len(f['input_ids']) for f in features) batch = { 'input_ids': [], 'attention_mask': [] } for feature in features: input_ids = feature['input_ids'] attention_mask = feature.get('attention_mask', [1] * len(input_ids)) # Pad to max length padding_length = max_length - len(input_ids) input_ids = input_ids + [self.pad_token_id] * padding_length attention_mask = attention_mask + [0] * padding_length batch['input_ids'].append(input_ids) batch['attention_mask'].append(attention_mask) # Convert to numpy arrays batch['input_ids'] = np.array(batch['input_ids'], dtype=np.int64) batch['attention_mask'] = np.array(batch['attention_mask'], dtype=np.int64) return batch def dynamic_padding(self, features: List[Dict], padding_multiple: int = 8) -> Dict: """ Apply dynamic padding optimized for hardware Args: features: List of feature dictionaries padding_multiple: Pad to multiple of this value Returns: Batched features with optimal padding """ if not features: return {} max_length = max(len(f['input_ids']) for f in features) # Round up to nearest multiple padded_length = ((max_length + padding_multiple - 1) // padding_multiple) * padding_multiple batch = { 'input_ids': [], 'attention_mask': [] } for feature in features: input_ids = feature['input_ids'] attention_mask = feature.get('attention_mask', [1] * len(input_ids)) padding_length = padded_length - len(input_ids) input_ids = input_ids + [self.pad_token_id] * padding_length attention_mask = attention_mask + [0] * padding_length batch['input_ids'].append(input_ids) batch['attention_mask'].append(attention_mask) batch['input_ids'] = np.array(batch['input_ids'], dtype=np.int64) batch['attention_mask'] = np.array(batch['attention_mask'], dtype=np.int64) return batch class CachedDataLoader: """Data loader with caching for repeated inference""" def __init__(self, cache_dir: str = "./cache"): """ Initialize cached data loader Args: cache_dir: Directory for cache storage """ self.cache_dir = Path(cache_dir) self.cache_dir.mkdir(parents=True, exist_ok=True) def get_cache_key(self, text: str) -> str: """Generate cache key from text""" import hashlib return hashlib.sha256(text.encode()).hexdigest() def load_from_cache(self, cache_key: str) -> Optional[Any]: """ Load data from cache Args: cache_key: Cache identifier Returns: Cached data or None """ cache_path = self.cache_dir / f"{cache_key}.json" if not cache_path.exists(): return None try: with open(cache_path, 'r') as f: return json.load(f) except Exception as e: logger.warning(f"Failed to load from cache: {e}") return None def save_to_cache(self, cache_key: str, data: Any): """ Save data to cache Args: cache_key: Cache identifier data: Data to cache """ cache_path = self.cache_dir / f"{cache_key}.json" try: with open(cache_path, 'w') as f: json.dump(data, f) except Exception as e: logger.warning(f"Failed to save to cache: {e}") def clear_cache(self): """Clear all cached data""" import shutil shutil.rmtree(self.cache_dir) self.cache_dir.mkdir(parents=True, exist_ok=True) logger.info("Cache cleared") def main(): """Example usage""" # SafeTensors loading loader = SafeTensorsLoader("./models/helion") # Get model info info = loader.get_model_info() print(f"Model: {info['model_name']}") print(f"Size: {info['total_size_gb']:.2f} GB") print(f"Shards: {info['total_shards']}") # Validate checksums print("\nValidating checksums...") results = loader.validate_checksums() valid_count = sum(1 for v in results.values() if v) print(f"Valid: {valid_count}/{len(results)}") if __name__ == "__main__": main()