Trouter-Library committed on
Commit d6f46cf · verified · 1 Parent(s): e727309

Create inference/data_loader.py

Files changed (1)
  1. inference/data_loader.py +510 -0
inference/data_loader.py ADDED
@@ -0,0 +1,510 @@
#!/usr/bin/env python3
"""
Helion-2.5-Rnd Advanced Data Loader
Efficient data loading and preprocessing for inference
"""

import json
import logging
from pathlib import Path
from typing import Any, Dict, Iterator, List, Optional, Union

import numpy as np
from safetensors.torch import load_file

logging.basicConfig(level=logging.INFO)
logger = logging.getLogger(__name__)


class SafeTensorsLoader:
    """Efficient SafeTensors model loading with validation"""

    def __init__(self, model_path: str, device: str = "cuda"):
        """
        Initialize SafeTensors loader

        Args:
            model_path: Path to model directory
            device: Target device for loading
        """
        self.model_path = Path(model_path)
        self.device = device
        self.index = self._load_index()
        self.loaded_shards = {}

    def _load_index(self) -> Dict:
        """Load SafeTensors index file"""
        index_path = self.model_path / "model.safetensors.index.json"

        if not index_path.exists():
            raise FileNotFoundError(f"Index file not found: {index_path}")

        with open(index_path, 'r') as f:
            index = json.load(f)

        logger.info(f"Loaded index with {len(index.get('weight_map', {}))} weight mappings")
        return index

    def get_shard_path(self, shard_name: str) -> Path:
        """Get full path to shard file"""
        return self.model_path / shard_name

    def load_shard(self, shard_name: str, lazy: bool = False) -> Dict:
        """
        Load a single SafeTensors shard

        Args:
            shard_name: Name of shard file
            lazy: If True, do not cache the loaded shard for reuse

        Returns:
            Dictionary of tensors
        """
        if shard_name in self.loaded_shards:
            logger.debug(f"Using cached shard: {shard_name}")
            return self.loaded_shards[shard_name]

        shard_path = self.get_shard_path(shard_name)

        if not shard_path.exists():
            raise FileNotFoundError(f"Shard not found: {shard_path}")

        logger.info(f"Loading shard: {shard_name}")

        try:
            tensors = load_file(str(shard_path), device=self.device)

            if not lazy:
                self.loaded_shards[shard_name] = tensors

            return tensors

        except Exception as e:
            logger.error(f"Failed to load shard {shard_name}: {e}")
            raise

    def load_weight(self, weight_name: str) -> Any:
        """
        Load a specific weight by name

        Args:
            weight_name: Name of the weight tensor

        Returns:
            Weight tensor
        """
        weight_map = self.index.get('weight_map', {})

        if weight_name not in weight_map:
            raise KeyError(f"Weight not found in index: {weight_name}")

        shard_name = weight_map[weight_name]
        tensors = self.load_shard(shard_name)

        return tensors[weight_name]

    def load_all_weights(self, progress_callback=None) -> Dict:
        """
        Load all model weights

        Args:
            progress_callback: Optional callback for progress updates

        Returns:
            Dictionary of all weights
        """
        all_weights = {}
        weight_map = self.index.get('weight_map', {})
        unique_shards = set(weight_map.values())

        logger.info(f"Loading {len(unique_shards)} shards...")

        for i, shard_name in enumerate(sorted(unique_shards)):
            tensors = self.load_shard(shard_name)
            all_weights.update(tensors)

            if progress_callback:
                progress_callback(i + 1, len(unique_shards))

        logger.info(f"Loaded {len(all_weights)} weight tensors")
        return all_weights

    def validate_checksums(self) -> Dict[str, bool]:
        """
        Validate SHA256 checksums of all shards

        Returns:
            Dictionary mapping shard names to validation status
        """
        import hashlib

        results = {}
        file_metadata = self.index.get('file_metadata', {})

        for shard_name, metadata in file_metadata.items():
            expected_hash = metadata.get('sha256')

            if not expected_hash:
                results[shard_name] = None
                continue

            shard_path = self.get_shard_path(shard_name)

            if not shard_path.exists():
                results[shard_name] = False
                continue

            sha256 = hashlib.sha256()
            with open(shard_path, 'rb') as f:
                for chunk in iter(lambda: f.read(4096), b''):
                    sha256.update(chunk)

            actual_hash = sha256.hexdigest()
            results[shard_name] = (actual_hash == expected_hash)

            status = "✓" if results[shard_name] else "✗"
            logger.info(f"{status} {shard_name}")

        return results

    def get_model_info(self) -> Dict:
        """Get model information from index"""
        metadata = self.index.get('metadata', {})

        return {
            'model_name': metadata.get('model_name', 'Unknown'),
            'version': metadata.get('version', 'Unknown'),
            'total_size_bytes': metadata.get('total_size', 0),
            'total_size_gb': metadata.get('total_size', 0) / (1024**3),
            'format': metadata.get('format', 'safetensors'),
            'precision': metadata.get('precision', 'unknown'),
            'total_shards': metadata.get('total_shards', 0),
            'parameters': metadata.get('parameters', 'Unknown')
        }

    def clear_cache(self):
        """Clear loaded shard cache"""
        self.loaded_shards.clear()
        logger.info("Cleared shard cache")

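# Illustrative usage sketch (not exercised by main() below): resolving a single weight
# through the index and loading every shard with a progress callback. The model
# directory and the "model.embed_tokens.weight" tensor name are assumptions here;
# substitute the paths and tensor names of your own checkpoint.
def _example_selective_loading(model_dir: str = "./models/helion") -> None:
    loader = SafeTensorsLoader(model_dir, device="cpu")

    # load_weight() looks up the owning shard in weight_map, loads (and caches) the
    # shard, then returns just the requested tensor.
    embeddings = loader.load_weight("model.embed_tokens.weight")
    logger.info(f"Embedding tensor shape: {tuple(embeddings.shape)}")

    # load_all_weights() walks every unique shard once; the callback receives
    # (shards_done, total_shards) after each shard finishes.
    weights = loader.load_all_weights(
        progress_callback=lambda done, total: logger.info(f"Shards loaded: {done}/{total}")
    )
    logger.info(f"Total tensors in memory: {len(weights)}")
    loader.clear_cache()
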
class DatasetPreprocessor:
    """Preprocess datasets for inference"""

    def __init__(self, tokenizer=None, max_length: int = 131072):
        """
        Initialize preprocessor

        Args:
            tokenizer: Tokenizer instance
            max_length: Maximum sequence length
        """
        self.tokenizer = tokenizer
        self.max_length = max_length

    def preprocess_text(self, text: str) -> str:
        """
        Preprocess raw text

        Args:
            text: Input text

        Returns:
            Preprocessed text
        """
        # Remove excessive whitespace
        text = ' '.join(text.split())

        # Remove control characters
        text = ''.join(char for char in text if ord(char) >= 32 or char in '\n\t')

        return text.strip()

    def preprocess_chat_messages(self, messages: List[Dict[str, str]]) -> str:
        """
        Preprocess chat messages into prompt format

        Args:
            messages: List of message dictionaries

        Returns:
            Formatted prompt string
        """
        formatted = ""

        for msg in messages:
            role = msg.get('role', 'user')
            content = self.preprocess_text(msg.get('content', ''))
            formatted += f"<|im_start|>{role}\n{content}<|im_end|>\n"

        formatted += "<|im_start|>assistant\n"
        return formatted

    def batch_preprocess(
        self,
        texts: List[str],
        add_special_tokens: bool = True,
        padding: bool = True,
        truncation: bool = True
    ) -> Dict:
        """
        Batch preprocess texts

        Args:
            texts: List of input texts
            add_special_tokens: Whether to add special tokens
            padding: Whether to pad sequences
            truncation: Whether to truncate sequences

        Returns:
            Batch of preprocessed data
        """
        if self.tokenizer is None:
            raise ValueError("Tokenizer not initialized")

        processed_texts = [self.preprocess_text(text) for text in texts]

        encodings = self.tokenizer(
            processed_texts,
            add_special_tokens=add_special_tokens,
            padding=padding,
            truncation=truncation,
            max_length=self.max_length,
            return_tensors='pt'
        )

        return encodings

    def stream_process_file(
        self,
        file_path: str,
        batch_size: int = 32
    ) -> Iterator[Dict]:
        """
        Stream process large files in batches

        Args:
            file_path: Path to input file
            batch_size: Number of samples per batch

        Yields:
            Batches of preprocessed data
        """
        path = Path(file_path)

        if path.suffix == '.jsonl':
            with open(path, 'r') as f:
                batch = []

                for line in f:
                    try:
                        data = json.loads(line)
                        text = data.get('text', '')
                        batch.append(text)

                        if len(batch) >= batch_size:
                            yield self.batch_preprocess(batch)
                            batch = []

                    except json.JSONDecodeError:
                        logger.warning("Skipping invalid JSON line")

                if batch:
                    yield self.batch_preprocess(batch)

        elif path.suffix == '.txt':
            with open(path, 'r') as f:
                batch = []

                for line in f:
                    batch.append(line.strip())

                    if len(batch) >= batch_size:
                        yield self.batch_preprocess(batch)
                        batch = []

                if batch:
                    yield self.batch_preprocess(batch)

        else:
            raise ValueError(f"Unsupported file format: {path.suffix}")

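# Illustrative usage sketch: formatting a chat exchange into the <|im_start|> prompt
# layout and tokenizing a batch. A Hugging Face-style tokenizer (e.g. one returned by
# transformers.AutoTokenizer.from_pretrained) is assumed and passed in by the caller;
# the 'data/prompts.jsonl' path is a placeholder.
def _example_preprocessing(tokenizer) -> None:
    pre = DatasetPreprocessor(tokenizer=tokenizer, max_length=4096)

    prompt = pre.preprocess_chat_messages([
        {"role": "system", "content": "You are a helpful assistant."},
        {"role": "user", "content": "Summarize the loader design in one sentence."},
    ])

    # batch_preprocess() cleans each text, then tokenizes with padding/truncation
    # and returns PyTorch tensors (return_tensors='pt').
    encodings = pre.batch_preprocess([prompt], padding=True, truncation=True)
    logger.info(f"Batch input_ids shape: {tuple(encodings['input_ids'].shape)}")

    # stream_process_file() yields tokenized batches from a .jsonl or .txt corpus.
    for batch in pre.stream_process_file("data/prompts.jsonl", batch_size=8):
        logger.info(f"Streamed batch of {batch['input_ids'].shape[0]} sequences")
        break
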
class InferenceDataCollator:
    """Collate data for efficient batch inference"""

    def __init__(self, pad_token_id: int = 128001):
        """
        Initialize data collator

        Args:
            pad_token_id: ID for padding token
        """
        self.pad_token_id = pad_token_id

    def __call__(self, features: List[Dict]) -> Dict:
        """
        Collate features into batch

        Args:
            features: List of feature dictionaries

        Returns:
            Batched features
        """
        if not features:
            return {}

        # Get maximum sequence length in batch
        max_length = max(len(f['input_ids']) for f in features)

        batch = {
            'input_ids': [],
            'attention_mask': []
        }

        for feature in features:
            input_ids = feature['input_ids']
            attention_mask = feature.get('attention_mask', [1] * len(input_ids))

            # Pad to max length
            padding_length = max_length - len(input_ids)

            input_ids = input_ids + [self.pad_token_id] * padding_length
            attention_mask = attention_mask + [0] * padding_length

            batch['input_ids'].append(input_ids)
            batch['attention_mask'].append(attention_mask)

        # Convert to numpy arrays
        batch['input_ids'] = np.array(batch['input_ids'], dtype=np.int64)
        batch['attention_mask'] = np.array(batch['attention_mask'], dtype=np.int64)

        return batch

    def dynamic_padding(self, features: List[Dict], padding_multiple: int = 8) -> Dict:
        """
        Apply dynamic padding optimized for hardware

        Args:
            features: List of feature dictionaries
            padding_multiple: Pad to multiple of this value

        Returns:
            Batched features with optimal padding
        """
        if not features:
            return {}

        max_length = max(len(f['input_ids']) for f in features)

        # Round up to nearest multiple
        padded_length = ((max_length + padding_multiple - 1) // padding_multiple) * padding_multiple

        batch = {
            'input_ids': [],
            'attention_mask': []
        }

        for feature in features:
            input_ids = feature['input_ids']
            attention_mask = feature.get('attention_mask', [1] * len(input_ids))

            padding_length = padded_length - len(input_ids)

            input_ids = input_ids + [self.pad_token_id] * padding_length
            attention_mask = attention_mask + [0] * padding_length

            batch['input_ids'].append(input_ids)
            batch['attention_mask'].append(attention_mask)

        batch['input_ids'] = np.array(batch['input_ids'], dtype=np.int64)
        batch['attention_mask'] = np.array(batch['attention_mask'], dtype=np.int64)

        return batch

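# Illustrative usage sketch: collating already-tokenized examples. The token ids below
# are made-up values, and pad_token_id=0 is chosen only so the padding is easy to spot.
def _example_collation() -> None:
    collator = InferenceDataCollator(pad_token_id=0)
    features = [
        {"input_ids": [11, 12, 13]},
        {"input_ids": [21, 22, 23, 24, 25, 26, 27, 28, 29, 30, 31]},
    ]

    # __call__ pads every row to the longest sequence in the batch (11 tokens here).
    plain = collator(features)
    logger.info(f"Plain padding -> shape {plain['input_ids'].shape}")      # (2, 11)

    # dynamic_padding rounds that length up to the next multiple of 8:
    # ((11 + 8 - 1) // 8) * 8 = 16, which tends to map better onto GPU kernels.
    aligned = collator.dynamic_padding(features, padding_multiple=8)
    logger.info(f"Aligned padding -> shape {aligned['input_ids'].shape}")  # (2, 16)
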
class CachedDataLoader:
    """Data loader with caching for repeated inference"""

    def __init__(self, cache_dir: str = "./cache"):
        """
        Initialize cached data loader

        Args:
            cache_dir: Directory for cache storage
        """
        self.cache_dir = Path(cache_dir)
        self.cache_dir.mkdir(parents=True, exist_ok=True)

    def get_cache_key(self, text: str) -> str:
        """Generate cache key from text"""
        import hashlib
        return hashlib.sha256(text.encode()).hexdigest()

    def load_from_cache(self, cache_key: str) -> Optional[Any]:
        """
        Load data from cache

        Args:
            cache_key: Cache identifier

        Returns:
            Cached data or None
        """
        cache_path = self.cache_dir / f"{cache_key}.json"

        if not cache_path.exists():
            return None

        try:
            with open(cache_path, 'r') as f:
                return json.load(f)
        except Exception as e:
            logger.warning(f"Failed to load from cache: {e}")
            return None

    def save_to_cache(self, cache_key: str, data: Any):
        """
        Save data to cache

        Args:
            cache_key: Cache identifier
            data: Data to cache
        """
        cache_path = self.cache_dir / f"{cache_key}.json"

        try:
            with open(cache_path, 'w') as f:
                json.dump(data, f)
        except Exception as e:
            logger.warning(f"Failed to save to cache: {e}")

    def clear_cache(self):
        """Clear all cached data"""
        import shutil
        shutil.rmtree(self.cache_dir)
        self.cache_dir.mkdir(parents=True, exist_ok=True)
        logger.info("Cache cleared")

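# Illustrative usage sketch: a cache round-trip keyed on the prompt text. The payload
# here is a stand-in for a real model response; whatever is cached must be
# JSON-serializable, since entries are stored as .json files on disk.
def _example_cached_results(prompt: str = "Explain sharded checkpoints.") -> None:
    cache = CachedDataLoader(cache_dir="./cache")
    key = cache.get_cache_key(prompt)

    cached = cache.load_from_cache(key)
    if cached is None:
        # Placeholder result; only the caching flow is shown here.
        result = {"prompt": prompt, "completion": "..."}
        cache.save_to_cache(key, result)
        logger.info(f"Cached new result under key {key[:12]}...")
    else:
        logger.info(f"Cache hit for key {key[:12]}...")
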
def main():
    """Example usage"""
    # SafeTensors loading
    loader = SafeTensorsLoader("./models/helion")

    # Get model info
    info = loader.get_model_info()
    print(f"Model: {info['model_name']}")
    print(f"Size: {info['total_size_gb']:.2f} GB")
    print(f"Shards: {info['total_shards']}")

    # Validate checksums
    print("\nValidating checksums...")
    results = loader.validate_checksums()
    valid_count = sum(1 for v in results.values() if v)
    print(f"Valid: {valid_count}/{len(results)}")


if __name__ == "__main__":
    main()