esunAI committed
Commit 164b12c · verified · 1 Parent(s): 321da93

Add cfg_dataset.py

Files changed (1)
  1. src/cfg_dataset.py +441 -0
src/cfg_dataset.py ADDED
@@ -0,0 +1,441 @@
import torch
from torch.utils.data import Dataset, DataLoader
import numpy as np
import json
import os
from typing import Any, Dict

def parse_fasta_with_amp_labels(fasta_path: str, max_seq_len: int = 50) -> Dict[str, Any]:
    """
    Parse a FASTA file and assign AMP/Non-AMP labels based on header prefixes.

    Label assignment strategy:
    - AMP (0): headers starting with '>AP'
    - Non-AMP (1): headers starting with '>sp'
    - Mask (2): assigned randomly afterwards for CFG training

    Handles standard FASTA, including sequences wrapped over multiple lines.

    Args:
        fasta_path: Path to FASTA file
        max_seq_len: Maximum sequence length to include

    Returns:
        Dictionary with sequences, labels, and masking metadata
    """
    sequences = []
    labels = []
    headers = []
    canonical_aa = set('ACDEFGHIKLMNPQRSTVWY')

    print(f"Parsing FASTA file: {fasta_path}")
    print("Label assignment: >AP = AMP (0), >sp = Non-AMP (1)")

    def process_entry(header: str, sequence: str) -> None:
        """Validate a single FASTA entry and record it with its label."""
        if not header or not (2 <= len(sequence) <= max_seq_len):
            return
        sequence = sequence.upper()
        # Keep only sequences made of the 20 canonical amino acids
        if not all(aa in canonical_aa for aa in sequence):
            return
        sequences.append(sequence)
        headers.append(header)
        # Assign label based on header prefix ('>' has already been stripped)
        if header.startswith('AP'):
            labels.append(0)  # AMP
        elif header.startswith('sp'):
            labels.append(1)  # Non-AMP
        else:
            labels.append(1)  # Unknown prefix, default to Non-AMP
            print(f"Warning: Unknown header prefix in '{header}', defaulting to Non-AMP")

    current_header = ""
    current_sequence = ""

    with open(fasta_path, 'r') as f:
        for line in f:
            line = line.strip()
            if line.startswith('>'):
                # Process the previous entry, then start a new one
                process_entry(current_header, current_sequence)
                current_header = line[1:]  # remove '>' from the header
                current_sequence = ""
            else:
                current_sequence += line

    # Process the last entry
    process_entry(current_header, current_sequence)

    # Create masked labels for CFG training (10% masked)
    original_labels = np.array(labels)
    masked_labels = original_labels.copy()
    mask_probability = 0.1
    mask_indices = np.random.choice(
        len(original_labels),
        size=int(len(original_labels) * mask_probability),
        replace=False
    )
    masked_labels[mask_indices] = 2  # 2 = mask/unknown

    print(f"✓ Parsed {len(sequences)} valid sequences from FASTA")
    print(f"  AMP sequences: {np.sum(original_labels == 0)}")
    print(f"  Non-AMP sequences: {np.sum(original_labels == 1)}")
    print(f"  Masked for CFG: {len(mask_indices)}")

    return {
        'sequences': sequences,
        'headers': headers,
        'labels': original_labels.tolist(),
        'masked_labels': masked_labels.tolist(),
        'mask_indices': mask_indices.tolist()
    }

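# Example of the expected FASTA input (hypothetical entries): headers starting
# with '>AP' are labeled AMP (0), headers starting with '>sp' Non-AMP (1), and
# wrapped sequence lines are concatenated before validation:
#
#   >AP00001
#   GLFDIVKKVVGALGSL
#   >sp|P99999|EXAMPLE_HUMAN
#   MKLLIVTFCLTFAAL
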
class CFGUniProtDataset(Dataset):
    """
    Dataset class for UniProt sequences with classifier-free guidance (CFG).

    This dataset:
    1. Loads processed UniProt data with AMP classifications
    2. Handles label masking for CFG training
    3. Integrates with the existing flow training pipeline
    4. Provides sequences, labels, and masking information
    """

    def __init__(self,
                 data_path: str,
                 use_masked_labels: bool = True,
                 mask_probability: float = 0.1,
                 max_seq_len: int = 50,
                 device: str = 'cuda'):

        self.data_path = data_path
        self.use_masked_labels = use_masked_labels
        self.mask_probability = mask_probability
        self.max_seq_len = max_seq_len
        self.device = device

        # Load processed data
        self._load_data()

        # Label mapping
        self.label_map = {
            0: 'amp',      # MIC < 100
            1: 'non_amp',  # MIC > 100
            2: 'mask'      # Unknown MIC
        }

        print("CFG Dataset initialized:")
        print(f"  Total sequences: {len(self.sequences)}")
        print(f"  Using masked labels: {use_masked_labels}")
        print(f"  Mask probability: {mask_probability}")
        print(f"  Label distribution: {self._get_label_distribution()}")

    def _load_data(self):
        """Load processed UniProt data from JSON."""
        if not os.path.exists(self.data_path):
            raise FileNotFoundError(f"Data file not found: {self.data_path}")

        with open(self.data_path, 'r') as f:
            data = json.load(f)

        self.sequences = data['sequences']
        # Accept either key: parse_fasta_with_amp_labels() emits 'labels',
        # while preprocessed JSON files may use 'original_labels'
        labels_key = 'original_labels' if 'original_labels' in data else 'labels'
        self.original_labels = np.array(data[labels_key])
        self.masked_labels = np.array(data['masked_labels'])
        self.mask_indices = set(data['mask_indices'])

    def _label_distribution(self, labels: np.ndarray) -> Dict[str, int]:
        """Get the distribution of a given label array."""
        unique, counts = np.unique(labels, return_counts=True)
        return {self.label_map[int(label)]: int(count) for label, count in zip(unique, counts)}

    def _get_label_distribution(self) -> Dict[str, int]:
        """Get the distribution of the active (masked or original) labels."""
        labels = self.masked_labels if self.use_masked_labels else self.original_labels
        return self._label_distribution(labels)

    def __len__(self) -> int:
        return len(self.sequences)

    def __getitem__(self, idx: int) -> Dict[str, torch.Tensor]:
        """Get a single sample with sequence and label."""
        sequence = self.sequences[idx]

        # Get the appropriate label
        if self.use_masked_labels:
            label = self.masked_labels[idx]
        else:
            label = self.original_labels[idx]

        # Check whether this sample was masked
        is_masked = idx in self.mask_indices

        return {
            'sequence': sequence,
            'label': torch.tensor(label, dtype=torch.long),
            'original_label': torch.tensor(self.original_labels[idx], dtype=torch.long),
            'is_masked': torch.tensor(is_masked, dtype=torch.bool),
            'index': torch.tensor(idx, dtype=torch.long)
        }

    def get_label_statistics(self) -> Dict[str, Dict]:
        """Get detailed statistics about original and masked labels."""
        return {
            'original': self._label_distribution(self.original_labels),
            'masked': self._label_distribution(self.masked_labels) if self.use_masked_labels else None,
            'masking_info': {
                'total_masked': len(self.mask_indices),
                'mask_probability': self.mask_probability,
                'masked_indices': sorted(self.mask_indices)
            }
        }

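# Usage sketch (hypothetical path; assumes a JSON file produced by
# parse_fasta_with_amp_labels() or an equivalent preprocessing step):
#
#   dataset = CFGUniProtDataset('processed/uniprot_amp.json', use_masked_labels=True)
#   sample = dataset[0]
#   print(sample['sequence'], sample['label'].item(), sample['is_masked'].item())
#
# With use_masked_labels=True, roughly 10% of samples carry label 2 ('mask'),
# which serves as the unconditional token during CFG training.
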
class CFGFlowDataset(Dataset):
    """
    Dataset that integrates CFG labels with the existing flow training pipeline.

    This dataset:
    1. Loads the existing AMP embeddings
    2. Adds CFG labels from UniProt processing
    3. Handles the integration between embeddings and labels
    4. Provides data in the format expected by the flow training code
    """

    def __init__(self,
                 embeddings_path: str,
                 cfg_data_path: str,
                 use_masked_labels: bool = True,
                 max_seq_len: int = 50,
                 device: str = 'cuda'):

        self.embeddings_path = embeddings_path
        self.cfg_data_path = cfg_data_path
        self.use_masked_labels = use_masked_labels
        self.max_seq_len = max_seq_len
        self.device = device

        # Load data
        self._load_embeddings()
        self._load_cfg_data()
        self._align_data()

        print("CFG Flow Dataset initialized:")
        print(f"  AMP embeddings: {self.embeddings.shape}")
        print(f"  CFG labels: {len(self.cfg_labels)}")
        print(f"  Aligned samples: {len(self.aligned_indices)}")

    def _load_embeddings(self):
        """Load the precomputed AMP embeddings."""
        print(f"Loading AMP embeddings from {self.embeddings_path}...")

        # Try the combined embeddings file first (full data)
        combined_path = os.path.join(self.embeddings_path, "all_peptide_embeddings.pt")

        if os.path.exists(combined_path):
            print(f"Loading combined embeddings from {combined_path}...")
            # Load on CPU first to avoid CUDA issues with DataLoader workers
            self.embeddings = torch.load(combined_path, map_location='cpu')
            print(f"✓ Loaded all embeddings: {self.embeddings.shape}")
        else:
            print("Combined embeddings file not found, loading individual files...")
            # Fall back to individual per-peptide files
            import glob

            embedding_files = glob.glob(os.path.join(self.embeddings_path, "*.pt"))
            embedding_files = [f for f in embedding_files
                               if not f.endswith('all_peptide_embeddings.pt')]

            print(f"Found {len(embedding_files)} individual embedding files")

            # Load and stack all embeddings; torch.stack requires every
            # embedding to share the same (seq_len, hidden_dim) shape
            embeddings_list = []
            for file_path in embedding_files:
                try:
                    embedding = torch.load(file_path, map_location='cpu')
                    if embedding.dim() == 2:  # (seq_len, hidden_dim)
                        embeddings_list.append(embedding)
                    else:
                        print(f"Warning: Skipping {file_path} - unexpected shape {embedding.shape}")
                except Exception as e:
                    print(f"Warning: Could not load {file_path}: {e}")

            if not embeddings_list:
                raise ValueError("No valid embeddings found!")

            self.embeddings = torch.stack(embeddings_list)
            print(f"Loaded {len(self.embeddings)} embeddings from individual files")

    def _load_cfg_data(self):
        """Load CFG data from a FASTA file (with automatic AMP labeling) or legacy JSON."""
        print(f"Loading CFG data from {self.cfg_data_path}...")

        if self.cfg_data_path.endswith(('.fasta', '.fa')):
            # Parse FASTA file with automatic labeling
            cfg_data = parse_fasta_with_amp_labels(self.cfg_data_path, self.max_seq_len)

            self.cfg_sequences = cfg_data['sequences']
            self.cfg_headers = cfg_data['headers']
            self.cfg_original_labels = np.array(cfg_data['labels'])
            self.cfg_masked_labels = np.array(cfg_data['masked_labels'])
            self.cfg_mask_indices = set(cfg_data['mask_indices'])

        else:
            # Legacy JSON format support
            with open(self.cfg_data_path, 'r') as f:
                cfg_data = json.load(f)

            self.cfg_sequences = cfg_data['sequences']
            self.cfg_headers = cfg_data.get('headers', [''] * len(cfg_data['sequences']))
            self.cfg_original_labels = np.array(cfg_data['labels'])

            # For CFG training, randomly mask 10% of the labels
            self.cfg_masked_labels = self.cfg_original_labels.copy()
            mask_probability = 0.1
            mask_indices = np.random.choice(
                len(self.cfg_original_labels),
                size=int(len(self.cfg_original_labels) * mask_probability),
                replace=False
            )
            self.cfg_masked_labels[mask_indices] = 2  # 2 = mask/unknown
            self.cfg_mask_indices = set(mask_indices)

        print(f"Loaded {len(self.cfg_sequences)} CFG sequences")
        print(f"Label distribution: {np.bincount(self.cfg_original_labels)}")
        print(f"Masked {len(self.cfg_mask_indices)} labels for CFG training")

    def _align_data(self):
        """Align AMP embeddings with CFG data."""
        print("Aligning AMP embeddings with CFG data...")

        # Simple positional alignment: take the first N samples, where N is
        # the smaller of the two collections. This assumes embeddings and CFG
        # sequences were produced in the same order.
        min_samples = min(len(self.embeddings), len(self.cfg_sequences))

        self.aligned_indices = list(range(min_samples))

        # Align labels
        if self.use_masked_labels:
            self.cfg_labels = self.cfg_masked_labels[:min_samples]
        else:
            self.cfg_labels = self.cfg_original_labels[:min_samples]

        # Align embeddings
        self.aligned_embeddings = self.embeddings[:min_samples]

        print(f"Aligned {min_samples} samples")

    def __len__(self) -> int:
        return len(self.aligned_indices)

    def __getitem__(self, idx: int) -> Dict[str, torch.Tensor]:
        """Get a single sample with embedding and CFG label."""
        # Embeddings are already on CPU
        embedding = self.aligned_embeddings[idx]
        label = self.cfg_labels[idx]
        original_label = self.cfg_original_labels[idx]
        is_masked = idx in self.cfg_mask_indices

        return {
            'embedding': embedding,
            'label': torch.tensor(label, dtype=torch.long),
            'original_label': torch.tensor(original_label, dtype=torch.long),
            'is_masked': torch.tensor(is_masked, dtype=torch.bool),
            'index': torch.tensor(idx, dtype=torch.long)
        }

    def get_embedding_stats(self) -> Dict:
        """Get summary statistics about the aligned embeddings."""
        return {
            'shape': tuple(self.aligned_embeddings.shape),
            'mean': self.aligned_embeddings.mean().item(),
            'std': self.aligned_embeddings.std().item(),
            'min': self.aligned_embeddings.min().item(),
            'max': self.aligned_embeddings.max().item()
        }

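# Usage sketch (hypothetical paths; assumes embeddings saved as .pt tensors
# and a labeled FASTA file as described above):
#
#   flow_dataset = CFGFlowDataset(
#       embeddings_path='embeddings/',           # directory of .pt files
#       cfg_data_path='data/amp_labeled.fasta',  # >AP / >sp headers
#       use_masked_labels=True,
#   )
#   print(flow_dataset.get_embedding_stats())
#
# Note: the positional alignment in _align_data() silently truncates to the
# shorter of the two sources; it is only correct if the embeddings and the
# FASTA entries were generated in the same order.
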
def create_cfg_dataloader(dataset: Dataset,
                          batch_size: int = 32,
                          shuffle: bool = True,
                          num_workers: int = 4) -> DataLoader:
    """Create a DataLoader for CFG training (expects CFGFlowDataset-style samples)."""

    def collate_fn(batch):
        """Custom collate function for CFG data."""
        # Stack each field across the batch
        embeddings = torch.stack([item['embedding'] for item in batch])
        labels = torch.stack([item['label'] for item in batch])
        original_labels = torch.stack([item['original_label'] for item in batch])
        is_masked = torch.stack([item['is_masked'] for item in batch])
        indices = torch.stack([item['index'] for item in batch])

        return {
            'embeddings': embeddings,
            'labels': labels,
            'original_labels': original_labels,
            'is_masked': is_masked,
            'indices': indices
        }

    return DataLoader(
        dataset,
        batch_size=batch_size,
        shuffle=shuffle,
        num_workers=num_workers,
        collate_fn=collate_fn,
        pin_memory=True
    )

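# Minimal CFG training-step sketch (illustrative, not part of this module).
# Assumes 3-D embeddings of shape (batch, seq_len, hidden_dim), a linear
# (rectified-flow-style) probability path, and a conditional velocity model
# `model(x_t, t, labels)`; the model name and signature are hypothetical.
# Label 2 (mask) in batch['labels'] plays the role of the null condition:
#
#   loader = create_cfg_dataloader(flow_dataset, batch_size=32)
#   for batch in loader:
#       x1 = batch['embeddings']                 # data endpoint
#       x0 = torch.randn_like(x1)                # noise endpoint
#       t = torch.rand(x1.size(0))               # random times in [0, 1]
#       x_t = (1 - t.view(-1, 1, 1)) * x0 + t.view(-1, 1, 1) * x1
#       v_target = x1 - x0                       # linear-path velocity
#       v_pred = model(x_t, t, batch['labels'])  # labels already ~10% masked
#       loss = ((v_pred - v_target) ** 2).mean()
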
def test_cfg_dataset():
    """Test function to verify the CFG dataset works correctly."""
    print("Testing CFG Dataset...")

    # Test with a small subset
    test_data = {
        'sequences': ['MKTVRQERLKSIVRILERSKEPVSGAQLAEELSVSRQVIVQDIAYLRSLGYNIVATPRGYVLAGG',
                      'MKLLIVTFCLTFAAL',
                      'MKLLIVTFCLTFAALMKLLIVTFCLTFAAL'],
        'original_labels': [0, 1, 0],  # amp, non_amp, amp
        'masked_labels': [0, 2, 0],    # amp, mask, amp
        'mask_indices': [1]            # Only the second sequence is masked
    }

    # Save test data
    test_path = 'test_cfg_data.json'
    with open(test_path, 'w') as f:
        json.dump(test_data, f)

    # Test dataset
    dataset = CFGUniProtDataset(test_path, use_masked_labels=True)

    print(f"Dataset length: {len(dataset)}")
    for i in range(len(dataset)):
        sample = dataset[i]
        print(f"Sample {i}:")
        print(f"  Sequence: {sample['sequence'][:20]}...")
        print(f"  Label: {sample['label'].item()}")
        print(f"  Original Label: {sample['original_label'].item()}")
        print(f"  Is Masked: {sample['is_masked'].item()}")

    # Clean up
    os.remove(test_path)
    print("Test completed successfully!")


if __name__ == "__main__":
    test_cfg_dataset()
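
# Sampling-time guidance sketch (illustrative; `model` and its signature are
# hypothetical, as above). Classifier-free guidance combines the conditional
# and unconditional predictions learned via the masked label 2:
#
#   v_cond = model(x_t, t, labels)                        # e.g. labels == 0 for AMP
#   v_uncond = model(x_t, t, torch.full_like(labels, 2))  # null/mask token
#   v_guided = v_uncond + guidance_scale * (v_cond - v_uncond)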