import torch
from torch.utils.data import Dataset, DataLoader
import numpy as np
import json
import os
from typing import Any, Dict, Optional

def parse_fasta_with_amp_labels(fasta_path: str, max_seq_len: int = 50) -> Dict[str, Any]:
    """
    Parse FASTA file and assign AMP/Non-AMP labels based on header prefixes.
    
    Label assignment strategy:
    - AMP (0): Headers starting with '>AP'
    - Non-AMP (1): Headers starting with '>sp'
    - Mask (2): Used for CFG training (randomly assigned)
    
    Expected file format (standard FASTA):
    - Header lines starting with '>AP' or '>sp'
    - One or more amino acid sequence lines per record
    
    Args:
        fasta_path: Path to FASTA file
        max_seq_len: Maximum sequence length to include
        
    Returns:
        Dictionary with sequences, labels, and metadata
    """
    sequences = []
    labels = []
    headers = []
    
    print(f"Parsing FASTA file: {fasta_path}")
    print("Label assignment: >AP = AMP (0), >sp = Non-AMP (1)")
    
    current_header = ""
    current_sequence = ""
    
    with open(fasta_path, 'r') as f:
        for line in f:
            line = line.strip()
            if line.startswith('>'):
                # Process previous sequence if exists
                if current_sequence and current_header:
                    if 2 <= len(current_sequence) <= max_seq_len:
                        # Check if sequence contains only canonical amino acids
                        canonical_aa = set('ACDEFGHIKLMNPQRSTVWY')
                        if all(aa in canonical_aa for aa in current_sequence.upper()):
                            sequences.append(current_sequence.upper())
                            headers.append(current_header)
                            
                            # Assign label based on header prefix
                            if current_header.startswith('AP'):
                                labels.append(0)  # AMP
                            elif current_header.startswith('sp'):
                                labels.append(1)  # Non-AMP
                            else:
                                # Unknown prefix, default to Non-AMP
                                labels.append(1)
                                print(f"Warning: Unknown header prefix in '{current_header}', defaulting to Non-AMP")
                
                # Start new sequence (remove '>' from header)
                current_header = line[1:]
                current_sequence = ""
            else:
                current_sequence += line
        
        # Process last sequence
        if current_sequence and current_header:
            if 2 <= len(current_sequence) <= max_seq_len:
                canonical_aa = set('ACDEFGHIKLMNPQRSTVWY')
                if all(aa in canonical_aa for aa in current_sequence.upper()):
                    sequences.append(current_sequence.upper())
                    headers.append(current_header)
                    
                    # Assign label based on header prefix
                    if current_header.startswith('AP'):
                        labels.append(0)  # AMP
                    elif current_header.startswith('sp'):
                        labels.append(1)  # Non-AMP
                    else:
                        # Unknown prefix, default to Non-AMP
                        labels.append(1)
                        print(f"Warning: Unknown header prefix in '{current_header}', defaulting to Non-AMP")
    
    # Create masked labels for CFG training (10% masked)
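    # Replacing a random 10% of labels with the null label (2) lets the model
    # also learn an unconditional distribution; classifier-free guidance later
    # interpolates between conditional and unconditional predictions.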
    original_labels = np.array(labels)
    masked_labels = original_labels.copy()
    mask_probability = 0.1
    mask_indices = np.random.choice(
        len(original_labels), 
        size=int(len(original_labels) * mask_probability), 
        replace=False
    )
    masked_labels[mask_indices] = 2  # 2 = mask/unknown
    
    print(f"✓ Parsed {len(sequences)} valid sequences from FASTA")
    print(f"  AMP sequences: {np.sum(original_labels == 0)}")
    print(f"  Non-AMP sequences: {np.sum(original_labels == 1)}")
    print(f"  Masked for CFG: {len(mask_indices)}")
    
    return {
        'sequences': sequences,
        'headers': headers,
        'labels': original_labels.tolist(),
        'masked_labels': masked_labels.tolist(),
        'mask_indices': mask_indices.tolist()
    }
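
# Example usage (illustrative; 'uniprot_amps.fasta' is a placeholder path):
#
#   cfg_data = parse_fasta_with_amp_labels('uniprot_amps.fasta', max_seq_len=50)
#   with open('cfg_data.json', 'w') as f:
#       json.dump(cfg_data, f)
#
# The saved file matches the legacy JSON format read by
# CFGFlowDataset._load_cfg_data (which expects a 'labels' key); note that
# CFGUniProtDataset below expects 'original_labels' instead.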

class CFGUniProtDataset(Dataset):
    """
    Dataset class for UniProt sequences with classifier-free guidance.
    
    This dataset:
    1. Loads processed UniProt data with AMP classifications
    2. Handles label masking for CFG training
    3. Integrates with your existing flow training pipeline
    4. Provides sequences, labels, and masking information
    """
    
    def __init__(self, 
                 data_path: str,
                 use_masked_labels: bool = True,
                 mask_probability: float = 0.1,
                 max_seq_len: int = 50,
                 device: str = 'cuda'):
        
        self.data_path = data_path
        self.use_masked_labels = use_masked_labels
        self.mask_probability = mask_probability
        self.max_seq_len = max_seq_len
        self.device = device
        
        # Load processed data
        self._load_data()
        
        # Label mapping
        self.label_map = {
            0: 'amp',      # MIC < 100
            1: 'non_amp',  # MIC > 100
            2: 'mask'      # Unknown MIC
        }
        
        print(f"CFG Dataset initialized:")
        print(f"  Total sequences: {len(self.sequences)}")
        print(f"  Using masked labels: {use_masked_labels}")
        print(f"  Mask probability: {mask_probability}")
        print(f"  Label distribution: {self._get_label_distribution()}")
    
    def _load_data(self):
        """Load processed UniProt data."""
        if os.path.exists(self.data_path):
            with open(self.data_path, 'r') as f:
                data = json.load(f)
            
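            # Expected keys: 'sequences', 'original_labels', 'masked_labels' and
            # 'mask_indices' (the format written by test_cfg_dataset below).
            # parse_fasta_with_amp_labels() writes 'labels' rather than
            # 'original_labels', so rename that key before loading its output here.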
            self.sequences = data['sequences']
            self.original_labels = np.array(data['original_labels'])
            self.masked_labels = np.array(data['masked_labels'])
            self.mask_indices = set(data['mask_indices'])
            
        else:
            raise FileNotFoundError(f"Data file not found: {self.data_path}")
    
    def _get_label_distribution(self, labels: Optional[np.ndarray] = None) -> Dict[str, int]:
        """Get the distribution of labels (defaults to the labels used for training)."""
        if labels is None:
            labels = self.masked_labels if self.use_masked_labels else self.original_labels
        unique, counts = np.unique(labels, return_counts=True)
        return {self.label_map[int(label)]: int(count) for label, count in zip(unique, counts)}
    
    def __len__(self) -> int:
        return len(self.sequences)
    
    def __getitem__(self, idx: int) -> Dict[str, torch.Tensor]:
        """Get a single sample with sequence and label."""
        sequence = self.sequences[idx]
        
        # Get appropriate label
        if self.use_masked_labels:
            label = self.masked_labels[idx]
        else:
            label = self.original_labels[idx]
        
        # Check if this sample was masked
        is_masked = idx in self.mask_indices
        
        return {
            'sequence': sequence,
            'label': torch.tensor(label, dtype=torch.long),
            'original_label': torch.tensor(self.original_labels[idx], dtype=torch.long),
            'is_masked': torch.tensor(is_masked, dtype=torch.bool),
            'index': torch.tensor(idx, dtype=torch.long)
        }
    
    def get_label_statistics(self) -> Dict[str, Dict]:
        """Get detailed statistics about labels."""
        stats = {
            'original': self._get_label_distribution(self.original_labels),
            'masked': self._get_label_distribution(self.masked_labels) if self.use_masked_labels else None,
            'masking_info': {
                'total_masked': len(self.mask_indices),
                'mask_probability': self.mask_probability,
                'masked_indices': list(self.mask_indices)
            }
        }
        return stats

class CFGFlowDataset(Dataset):
    """
    Dataset that integrates CFG labels with your existing flow training pipeline.
    
    This dataset:
    1. Loads your existing AMP embeddings
    2. Adds CFG labels from UniProt processing
    3. Handles the integration between embeddings and labels
    4. Provides data in the format expected by your flow training
    """
    
    def __init__(self, 
                 embeddings_path: str,
                 cfg_data_path: str,
                 use_masked_labels: bool = True,
                 max_seq_len: int = 50,
                 device: str = 'cuda'):
        
        self.embeddings_path = embeddings_path
        self.cfg_data_path = cfg_data_path
        self.use_masked_labels = use_masked_labels
        self.max_seq_len = max_seq_len
        self.device = device
        
        # Load data
        self._load_embeddings()
        self._load_cfg_data()
        self._align_data()
        
        print(f"CFG Flow Dataset initialized:")
        print(f"  AMP embeddings: {self.embeddings.shape}")
        print(f"  CFG labels: {len(self.cfg_labels)}")
        print(f"  Aligned samples: {len(self.aligned_indices)}")
    
    def _load_embeddings(self):
        """Load your existing AMP embeddings."""
        print(f"Loading AMP embeddings from {self.embeddings_path}...")
        
        # Try to load the combined embeddings file first (FULL DATA)
        combined_path = os.path.join(self.embeddings_path, "all_peptide_embeddings.pt")
        
        if os.path.exists(combined_path):
            print(f"Loading combined embeddings from {combined_path} (FULL DATA)...")
            # Load on CPU first to avoid CUDA issues with DataLoader workers
            self.embeddings = torch.load(combined_path, map_location='cpu')
            print(f"✓ Loaded ALL embeddings: {self.embeddings.shape}")
        else:
            print("Combined embeddings file not found, loading individual files...")
            # Fallback to individual files
            import glob
            
            embedding_files = glob.glob(os.path.join(self.embeddings_path, "*.pt"))
            embedding_files = [
                f for f in embedding_files
                if not f.endswith(('metadata.json', 'sequence_ids.json', 'all_peptide_embeddings.pt'))
            ]
            
            print(f"Found {len(embedding_files)} individual embedding files")
            
            # Load and stack all embeddings
            embeddings_list = []
            for file_path in embedding_files:
                try:
                    embedding = torch.load(file_path, map_location='cpu')
                    if embedding.dim() == 2:  # (seq_len, hidden_dim)
                        embeddings_list.append(embedding)
                    else:
                        print(f"Warning: Skipping {file_path} - unexpected shape {embedding.shape}")
                except Exception as e:
                    print(f"Warning: Could not load {file_path}: {e}")
            
            if not embeddings_list:
                raise ValueError("No valid embeddings found!")
            
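            # torch.stack requires every embedding to share the same
            # (seq_len, hidden_dim) shape; variable-length embeddings would
            # need padding or truncation before they can be stacked here.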
            self.embeddings = torch.stack(embeddings_list)
            print(f"Loaded {len(self.embeddings)} embeddings from individual files")
    
    def _load_cfg_data(self):
        """Load CFG data from FASTA file with automatic AMP labeling."""
        print(f"Loading CFG data from FASTA: {self.cfg_data_path}...")
        
        # Check if it's a FASTA file or JSON file
        if self.cfg_data_path.endswith('.fasta') or self.cfg_data_path.endswith('.fa'):
            # Parse FASTA file with automatic labeling
            cfg_data = parse_fasta_with_amp_labels(self.cfg_data_path, self.max_seq_len)
            
            self.cfg_sequences = cfg_data['sequences']
            self.cfg_headers = cfg_data['headers']
            self.cfg_original_labels = np.array(cfg_data['labels'])
            self.cfg_masked_labels = np.array(cfg_data['masked_labels'])
            self.cfg_mask_indices = set(cfg_data['mask_indices'])
            
        else:
            # Legacy JSON format support
            with open(self.cfg_data_path, 'r') as f:
                cfg_data = json.load(f)
            
            self.cfg_sequences = cfg_data['sequences']
            self.cfg_headers = cfg_data.get('headers', [''] * len(cfg_data['sequences']))
            self.cfg_original_labels = np.array(cfg_data['labels'])
            
            # For CFG training, we need to create masked labels
            # Randomly mask 10% of labels for CFG training
            self.cfg_masked_labels = self.cfg_original_labels.copy()
            mask_probability = 0.1
            mask_indices = np.random.choice(
                len(self.cfg_original_labels), 
                size=int(len(self.cfg_original_labels) * mask_probability), 
                replace=False
            )
            self.cfg_masked_labels[mask_indices] = 2  # 2 = mask/unknown
            self.cfg_mask_indices = set(mask_indices)
        
        print(f"Loaded {len(self.cfg_sequences)} CFG sequences")
        print(f"Label distribution: {np.bincount(self.cfg_original_labels)}")
        print(f"Masked {len(self.cfg_mask_indices)} labels for CFG training")
    
    def _align_data(self):
        """Align AMP embeddings with CFG data based on sequence matching."""
        print("Aligning AMP embeddings with CFG data...")
        
        # For now, we'll use a simple approach: take the first N sequences
        # where N is the minimum of embeddings and CFG data
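        # NOTE: this assumes the stored embeddings and the CFG sequences are in
        # the same order; matching on sequence identity would be a more robust
        # alignment strategy.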
        min_samples = min(len(self.embeddings), len(self.cfg_sequences))
        
        self.aligned_indices = list(range(min_samples))
        
        # Align labels
        if self.use_masked_labels:
            self.cfg_labels = self.cfg_masked_labels[:min_samples]
        else:
            self.cfg_labels = self.cfg_original_labels[:min_samples]
        
        # Align embeddings
        self.aligned_embeddings = self.embeddings[:min_samples]
        
        print(f"Aligned {min_samples} samples")
    
    def __len__(self) -> int:
        return len(self.aligned_indices)
    
    def __getitem__(self, idx: int) -> Dict[str, torch.Tensor]:
        """Get a single sample with embedding and CFG label."""
        # Embeddings are already on CPU
        embedding = self.aligned_embeddings[idx]
        label = self.cfg_labels[idx]
        original_label = self.cfg_original_labels[idx]
        is_masked = idx in self.cfg_mask_indices
        
        return {
            'embedding': embedding,
            'label': torch.tensor(label, dtype=torch.long),
            'original_label': torch.tensor(original_label, dtype=torch.long),
            'is_masked': torch.tensor(is_masked, dtype=torch.bool),
            'index': torch.tensor(idx, dtype=torch.long)
        }
    
    def get_embedding_stats(self) -> Dict:
        """Get statistics about the embeddings."""
        return {
            'shape': self.aligned_embeddings.shape,
            'mean': self.aligned_embeddings.mean().item(),
            'std': self.aligned_embeddings.std().item(),
            'min': self.aligned_embeddings.min().item(),
            'max': self.aligned_embeddings.max().item()
        }
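
# Illustrative only: a minimal sketch of how the null label (2) produced by the
# datasets above is typically consumed for classifier-free guidance at sampling
# time. `flow_model` is a placeholder for a conditional velocity model and is
# not defined in this module.
def cfg_guided_velocity(flow_model, x_t, t, labels, guidance_scale: float = 2.0):
    """Blend conditional and unconditional predictions (classifier-free guidance)."""
    uncond_labels = torch.full_like(labels, 2)  # 2 = mask / unconditional label
    v_cond = flow_model(x_t, t, labels)
    v_uncond = flow_model(x_t, t, uncond_labels)
    return v_uncond + guidance_scale * (v_cond - v_uncond)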

def create_cfg_dataloader(dataset: Dataset, 
                         batch_size: int = 32,
                         shuffle: bool = True,
                         num_workers: int = 4) -> DataLoader:
    """Create a DataLoader for CFG training."""
    
    def collate_fn(batch):
        """Custom collate function for CFG data."""
        # Separate different types of data
        embeddings = torch.stack([item['embedding'] for item in batch])
        labels = torch.stack([item['label'] for item in batch])
        original_labels = torch.stack([item['original_label'] for item in batch])
        is_masked = torch.stack([item['is_masked'] for item in batch])
        indices = torch.stack([item['index'] for item in batch])
        
        return {
            'embeddings': embeddings,
            'labels': labels,
            'original_labels': original_labels,
            'is_masked': is_masked,
            'indices': indices
        }
    
    return DataLoader(
        dataset,
        batch_size=batch_size,
        shuffle=shuffle,
        num_workers=num_workers,
        collate_fn=collate_fn,
        pin_memory=True
    )
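
# Example wiring (illustrative; both paths are placeholders):
#
#   dataset = CFGFlowDataset(
#       embeddings_path='embeddings/',        # directory containing *.pt files
#       cfg_data_path='uniprot_amps.fasta',   # FASTA or legacy JSON
#       use_masked_labels=True,
#   )
#   loader = create_cfg_dataloader(dataset, batch_size=32, shuffle=True)
#   batch = next(iter(loader))  # keys: 'embeddings', 'labels', 'original_labels', ...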

def test_cfg_dataset():
    """Test function to verify the CFG dataset works correctly."""
    print("Testing CFG Dataset...")
    
    # Test with a small subset
    test_data = {
        'sequences': ['MKTVRQERLKSIVRILERSKEPVSGAQLAEELSVSRQVIVQDIAYLRSLGYNIVATPRGYVLAGG', 
                     'MKLLIVTFCLTFAAL',
                     'MKLLIVTFCLTFAALMKLLIVTFCLTFAAL'],
        'original_labels': [0, 1, 0],  # amp, non_amp, amp
        'masked_labels': [0, 2, 0],    # amp, mask, amp
        'mask_indices': [1]  # Only second sequence is masked
    }
    
    # Save test data
    test_path = 'test_cfg_data.json'
    with open(test_path, 'w') as f:
        json.dump(test_data, f)
    
    # Test dataset
    dataset = CFGUniProtDataset(test_path, use_masked_labels=True)
    
    print(f"Dataset length: {len(dataset)}")
    for i in range(len(dataset)):
        sample = dataset[i]
        print(f"Sample {i}:")
        print(f"  Sequence: {sample['sequence'][:20]}...")
        print(f"  Label: {sample['label'].item()}")
        print(f"  Original Label: {sample['original_label'].item()}")
        print(f"  Is Masked: {sample['is_masked'].item()}")
    
    # Clean up
    os.remove(test_path)
    print("Test completed successfully!")

if __name__ == "__main__":
    test_cfg_dataset()