import glob
import json
import os
import random
import re
import tempfile
from typing import Any, Dict, List, Optional, Tuple

import numpy as np
import torch
import torch.nn as nn
from torch.utils.data import DataLoader, Dataset


# Integer label encoding shared by the FASTA parser and both datasets.
AMP_LABEL = 0
NON_AMP_LABEL = 1
MASK_LABEL = 2  # placeholder "mask/unknown" label used for CFG training

# The 20 canonical amino acids. Records whose (upper-cased) sequence contains
# any other letter are discarded by the FASTA parser.
CANONICAL_AMINO_ACIDS = frozenset('ACDEFGHIKLMNPQRSTVWY')


def _make_cfg_mask(labels: np.ndarray,
                   mask_probability: float = 0.1) -> Tuple[np.ndarray, np.ndarray]:
    """Randomly replace a fraction of ``labels`` with ``MASK_LABEL``.

    Exactly ``int(len(labels) * mask_probability)`` distinct positions are
    drawn without replacement from NumPy's global random state, so results
    are reproducible via ``np.random.seed``.

    Args:
        labels: 1-D integer label array. It is not modified.
        mask_probability: Fraction of labels to mask, in ``[0, 1]``.

    Returns:
        ``(masked_labels, mask_indices)``: a copy of ``labels`` with the chosen
        positions set to ``MASK_LABEL``, and the chosen positions themselves.

    Raises:
        ValueError: If ``mask_probability`` is outside ``[0, 1]``.
    """
    if not 0.0 <= mask_probability <= 1.0:
        raise ValueError(
            f"mask_probability must be in [0, 1], got {mask_probability}")
    masked_labels = labels.copy()
    num_masked = int(len(labels) * mask_probability)
    mask_indices = np.random.choice(len(labels), size=num_masked, replace=False)
    masked_labels[mask_indices] = MASK_LABEL
    return masked_labels, mask_indices


def _iter_fasta_records(fasta_path: str):
    """Yield ``(header, sequence)`` pairs from a FASTA file.

    The leading '>' is removed from headers and multi-line sequences are
    joined. Lines appearing before the first header are ignored.
    """
    header: Optional[str] = None
    chunks: List[str] = []
    with open(fasta_path, 'r') as handle:
        for raw_line in handle:
            line = raw_line.strip()
            if line.startswith('>'):
                if header is not None:
                    yield header, ''.join(chunks)
                header = line[1:].strip()
                chunks = []
            else:
                chunks.append(line)
    if header is not None:
        yield header, ''.join(chunks)


def _label_from_header(header: str) -> int:
    """Map a FASTA header (without '>') to an AMP / non-AMP label."""
    if header.startswith('AP'):
        return AMP_LABEL
    if header.startswith('sp'):
        return NON_AMP_LABEL
    print(f"Warning: Unknown header prefix in '{header}', defaulting to Non-AMP")
    return NON_AMP_LABEL


def parse_fasta_with_amp_labels(fasta_path: str, max_seq_len: int = 50,
                                mask_probability: float = 0.1) -> Dict[str, Any]:
    """
    Parse a FASTA file and assign AMP/Non-AMP labels from header prefixes.

    Label assignment strategy:
    - AMP (0): headers starting with '>AP'
    - Non-AMP (1): headers starting with '>sp'
    - Any other prefix: Non-AMP (1), with a printed warning
    - Mask (2): a random ``mask_probability`` fraction, used for CFG training

    Sequences may span several lines. A record is kept only when its header is
    non-empty, its sequence length is in ``[2, max_seq_len]``, and every
    residue (after upper-casing) is one of the 20 canonical amino acids.

    Args:
        fasta_path: Path to the FASTA file.
        max_seq_len: Maximum sequence length to include.
        mask_probability: Fraction of kept records whose label is replaced
            by the mask label (2) in ``masked_labels``.

    Returns:
        Dictionary with keys ``sequences`` (upper-cased), ``headers``,
        ``labels``, ``masked_labels`` and ``mask_indices`` (all plain lists).
    """
    sequences: List[str] = []
    labels: List[int] = []
    headers: List[str] = []

    print(f"Parsing FASTA file: {fasta_path}")
    print("Label assignment: >AP = AMP (0), >sp = Non-AMP (1)")

    for header, raw_sequence in _iter_fasta_records(fasta_path):
        if not header:
            continue
        sequence = raw_sequence.upper()
        if not 2 <= len(sequence) <= max_seq_len:
            continue
        if not set(sequence) <= CANONICAL_AMINO_ACIDS:
            continue
        sequences.append(sequence)
        headers.append(header)
        labels.append(_label_from_header(header))

    # int64 even when empty, so downstream np.bincount / indexing never sees float64.
    original_labels = np.array(labels, dtype=np.int64)
    masked_labels, mask_indices = _make_cfg_mask(original_labels, mask_probability)

    print(f"✓ Parsed {len(sequences)} valid sequences from FASTA")
    print(f"  AMP sequences: {np.sum(original_labels == AMP_LABEL)}")
    print(f"  Non-AMP sequences: {np.sum(original_labels == NON_AMP_LABEL)}")
    print(f"  Masked for CFG: {len(mask_indices)}")

    return {
        'sequences': sequences,
        'headers': headers,
        'labels': original_labels.tolist(),
        'masked_labels': masked_labels.tolist(),
        'mask_indices': mask_indices.tolist(),
    }


class CFGUniProtDataset(Dataset):
    """
    Dataset of pre-processed sequences with labels for classifier-free guidance.

    Loads a JSON file with keys ``sequences``, ``original_labels``,
    ``masked_labels`` and ``mask_indices``. Each item carries the raw
    sequence string, the selected label (masked or original), the original
    label, and whether the item's index is in ``mask_indices``.
    """

    def __init__(self, data_path: str, use_masked_labels: bool = True,
                 mask_probability: float = 0.1, max_seq_len: int = 50,
                 device: str = 'cuda'):
        self.data_path = data_path
        self.use_masked_labels = use_masked_labels
        # Reported in statistics only; the masking itself is read from the file.
        self.mask_probability = mask_probability
        # NOTE(review): max_seq_len and device are stored but unused by this class.
        self.max_seq_len = max_seq_len
        self.device = device

        self._load_data()

        self.label_map = {
            AMP_LABEL: 'amp',
            NON_AMP_LABEL: 'non_amp',
            MASK_LABEL: 'mask',
        }

        print(f"CFG Dataset initialized:")
        print(f"  Total sequences: {len(self.sequences)}")
        print(f"  Using masked labels: {use_masked_labels}")
        print(f"  Mask probability: {mask_probability}")
        print(f"  Label distribution: {self._get_label_distribution()}")

    def _load_data(self) -> None:
        """Load sequences, labels and mask indices from the JSON file.

        Raises:
            FileNotFoundError: If ``data_path`` does not exist.
        """
        try:
            with open(self.data_path, 'r') as f:
                data = json.load(f)
        except FileNotFoundError:
            raise FileNotFoundError(f"Data file not found: {self.data_path}") from None

        self.sequences = data['sequences']
        self.original_labels = np.array(data['original_labels'], dtype=np.int64)
        self.masked_labels = np.array(data['masked_labels'], dtype=np.int64)
        self.mask_indices = {int(i) for i in data['mask_indices']}

    def _get_label_distribution(self, use_masked: Optional[bool] = None) -> Dict[str, int]:
        """Count labels by name.

        Args:
            use_masked: Count masked labels if True, original labels if False;
                ``None`` (default) follows ``self.use_masked_labels``.
        """
        if use_masked is None:
            use_masked = self.use_masked_labels
        labels = self.masked_labels if use_masked else self.original_labels
        unique, counts = np.unique(labels, return_counts=True)
        return {self.label_map[int(label)]: int(count)
                for label, count in zip(unique, counts)}

    def __len__(self) -> int:
        return len(self.sequences)

    def __getitem__(self, idx: int) -> Dict[str, Any]:
        """Return the sequence, its labels and masking flag for index ``idx``."""
        sequence = self.sequences[idx]
        if self.use_masked_labels:
            label = self.masked_labels[idx]
        else:
            label = self.original_labels[idx]

        # True when idx was selected for masking in the data file, even if
        # use_masked_labels is False and the original label is returned.
        is_masked = idx in self.mask_indices

        return {
            'sequence': sequence,
            'label': torch.tensor(int(label), dtype=torch.long),
            'original_label': torch.tensor(int(self.original_labels[idx]), dtype=torch.long),
            'is_masked': torch.tensor(is_masked, dtype=torch.bool),
            'index': torch.tensor(idx, dtype=torch.long),
        }

    def get_label_statistics(self) -> Dict[str, Any]:
        """Return label distributions (original and, if enabled, masked) and masking info."""
        return {
            # Previously both entries used the masked labels when
            # use_masked_labels was True; 'original' now really is original.
            'original': self._get_label_distribution(use_masked=False),
            'masked': (self._get_label_distribution(use_masked=True)
                       if self.use_masked_labels else None),
            'masking_info': {
                'total_masked': len(self.mask_indices),
                'mask_probability': self.mask_probability,
                'masked_indices': sorted(self.mask_indices),
            },
        }


class CFGFlowDataset(Dataset):
    """
    Dataset pairing pre-computed embeddings with CFG labels.

    Embeddings are loaded from ``embeddings_path`` (a combined
    ``all_peptide_embeddings.pt`` file if present, otherwise every ``*.pt``
    file stacked). Labels come from a FASTA file (via
    ``parse_fasta_with_amp_labels``) or a legacy JSON file. Embeddings and
    labels are paired purely by position (see ``_align_data``).
    """

    def __init__(self, embeddings_path: str, cfg_data_path: str,
                 use_masked_labels: bool = True, max_seq_len: int = 50,
                 device: str = 'cuda'):
        self.embeddings_path = embeddings_path
        self.cfg_data_path = cfg_data_path
        self.use_masked_labels = use_masked_labels
        self.max_seq_len = max_seq_len
        # NOTE(review): device is stored but unused; tensors stay on CPU.
        self.device = device

        self._load_embeddings()
        self._load_cfg_data()
        self._align_data()

        print(f"CFG Flow Dataset initialized:")
        print(f"  AMP embeddings: {self.embeddings.shape}")
        print(f"  CFG labels: {len(self.cfg_labels)}")
        print(f"  Aligned samples: {len(self.aligned_indices)}")

    def _load_embeddings(self) -> None:
        """Load embeddings onto the CPU into ``self.embeddings``.

        NOTE: ``torch.load`` unpickles data; only load trusted files.

        Raises:
            ValueError: If the fallback path finds no loadable 2-D embeddings.
        """
        print(f"Loading AMP embeddings from {self.embeddings_path}...")

        combined_path = os.path.join(self.embeddings_path, "all_peptide_embeddings.pt")
        if os.path.exists(combined_path):
            print(f"Loading combined embeddings from {combined_path} (FULL DATA)...")
            # CPU first so DataLoader workers never touch CUDA.
            self.embeddings = torch.load(combined_path, map_location='cpu')
            print(f"✓ Loaded ALL embeddings: {self.embeddings.shape}")
            return

        print("Combined embeddings file not found, loading individual files...")
        # Sorted (lexicographically) so row order is deterministic: _align_data
        # pairs embeddings with labels by position. TODO confirm this order
        # matches the order of sequences in the CFG data file.
        embedding_files = sorted(
            path for path in glob.glob(os.path.join(self.embeddings_path, "*.pt"))
            if not path.endswith('all_peptide_embeddings.pt')
        )
        print(f"Found {len(embedding_files)} individual embedding files")

        embeddings_list = []
        for file_path in embedding_files:
            # Best-effort: unreadable files are reported and skipped.
            try:
                embedding = torch.load(file_path, map_location='cpu')
            except Exception as e:
                print(f"Warning: Could not load {file_path}: {e}")
                continue
            if embedding.dim() == 2:  # presumably (seq_len, hidden_dim)
                embeddings_list.append(embedding)
            else:
                print(f"Warning: Skipping {file_path} - unexpected shape {embedding.shape}")

        if not embeddings_list:
            raise ValueError("No valid embeddings found!")

        # torch.stack requires every file to have the same shape.
        self.embeddings = torch.stack(embeddings_list)
        print(f"Loaded {len(self.embeddings)} embeddings from individual files")

    def _load_cfg_data(self) -> None:
        """Load CFG sequences/labels from a FASTA (``.fasta``/``.fa``) or JSON file."""
        print(f"Loading CFG data from FASTA: {self.cfg_data_path}...")

        if self.cfg_data_path.lower().endswith(('.fasta', '.fa')):
            cfg_data = parse_fasta_with_amp_labels(self.cfg_data_path, self.max_seq_len)
            self.cfg_sequences = cfg_data['sequences']
            self.cfg_headers = cfg_data['headers']
            self.cfg_original_labels = np.array(cfg_data['labels'], dtype=np.int64)
            self.cfg_masked_labels = np.array(cfg_data['masked_labels'], dtype=np.int64)
            self.cfg_mask_indices = {int(i) for i in cfg_data['mask_indices']}
        else:
            # Legacy JSON format: expects 'sequences' and 'labels' keys.
            with open(self.cfg_data_path, 'r') as f:
                cfg_data = json.load(f)
            self.cfg_sequences = cfg_data['sequences']
            self.cfg_headers = cfg_data.get('headers', [''] * len(cfg_data['sequences']))
            self.cfg_original_labels = np.array(cfg_data['labels'], dtype=np.int64)
            # Randomly mask 10% of labels for CFG training.
            self.cfg_masked_labels, mask_indices = _make_cfg_mask(
                self.cfg_original_labels, 0.1)
            self.cfg_mask_indices = {int(i) for i in mask_indices}

        print(f"Loaded {len(self.cfg_sequences)} CFG sequences")
        print(f"Label distribution: {np.bincount(self.cfg_original_labels)}")
        print(f"Masked {len(self.cfg_mask_indices)} labels for CFG training")

    def _align_data(self) -> None:
        """Pair embeddings with CFG labels by position.

        NOTE(review): this takes the first ``min(len(embeddings), len(cfg))``
        items of each; no sequence matching is performed, so correctness
        depends on both sources sharing the same order.
        """
        print("Aligning AMP embeddings with CFG data...")

        min_samples = min(len(self.embeddings), len(self.cfg_sequences))
        self.aligned_indices = list(range(min_samples))

        source = self.cfg_masked_labels if self.use_masked_labels else self.cfg_original_labels
        self.cfg_labels = source[:min_samples]
        self.aligned_embeddings = self.embeddings[:min_samples]

        print(f"Aligned {min_samples} samples")

    def __len__(self) -> int:
        return len(self.aligned_indices)

    def __getitem__(self, idx: int) -> Dict[str, torch.Tensor]:
        """Return the embedding, labels and masking flag for index ``idx``."""
        embedding = self.aligned_embeddings[idx]  # already on CPU
        label = self.cfg_labels[idx]
        original_label = self.cfg_original_labels[idx]
        is_masked = idx in self.cfg_mask_indices

        return {
            'embedding': embedding,
            'label': torch.tensor(int(label), dtype=torch.long),
            'original_label': torch.tensor(int(original_label), dtype=torch.long),
            'is_masked': torch.tensor(is_masked, dtype=torch.bool),
            'index': torch.tensor(idx, dtype=torch.long),
        }

    def get_embedding_stats(self) -> Dict[str, Any]:
        """Return shape and global mean/std/min/max of the aligned embeddings."""
        return {
            'shape': self.aligned_embeddings.shape,
            'mean': self.aligned_embeddings.mean().item(),
            'std': self.aligned_embeddings.std().item(),
            'min': self.aligned_embeddings.min().item(),
            'max': self.aligned_embeddings.max().item(),
        }


def _cfg_collate_fn(batch: List[Dict[str, torch.Tensor]]) -> Dict[str, torch.Tensor]:
    """Stack per-sample CFG fields into batched tensors (keys pluralised).

    Defined at module level so it can be pickled when DataLoader workers are
    started with the 'spawn' method (num_workers > 0 on Windows/macOS).
    Expects items shaped like ``CFGFlowDataset.__getitem__`` output.
    """
    return {
        'embeddings': torch.stack([item['embedding'] for item in batch]),
        'labels': torch.stack([item['label'] for item in batch]),
        'original_labels': torch.stack([item['original_label'] for item in batch]),
        'is_masked': torch.stack([item['is_masked'] for item in batch]),
        'indices': torch.stack([item['index'] for item in batch]),
    }


def create_cfg_dataloader(dataset: Dataset, batch_size: int = 32,
                          shuffle: bool = True, num_workers: int = 4,
                          pin_memory: Optional[bool] = None) -> DataLoader:
    """Create a DataLoader for CFG training.

    Args:
        dataset: Dataset whose items have ``embedding``, ``label``,
            ``original_label``, ``is_masked`` and ``index`` keys.
        batch_size: Samples per batch.
        shuffle: Whether to shuffle every epoch.
        num_workers: Number of worker processes.
        pin_memory: Pin host memory for faster GPU transfer; ``None`` (default)
            enables it only when CUDA is available.
    """
    if pin_memory is None:
        pin_memory = torch.cuda.is_available()
    return DataLoader(
        dataset,
        batch_size=batch_size,
        shuffle=shuffle,
        num_workers=num_workers,
        collate_fn=_cfg_collate_fn,
        pin_memory=pin_memory,
    )


def test_cfg_dataset():
    """Smoke-test CFGUniProtDataset on a tiny temporary JSON file."""
    print("Testing CFG Dataset...")

    test_data = {
        'sequences': ['MKTVRQERLKSIVRILERSKEPVSGAQLAEELSVSRQVIVQDIAYLRSLGYNIVATPRGYVLAGG',
                      'MKLLIVTFCLTFAAL',
                      'MKLLIVTFCLTFAALMKLLIVTFCLTFAAL'],
        'original_labels': [0, 1, 0],  # amp, non_amp, amp
        'masked_labels': [0, 2, 0],    # amp, mask, amp
        'mask_indices': [1],           # only the second sequence is masked
    }

    # Temporary file so the test never clobbers or leaks files in the CWD.
    fd, test_path = tempfile.mkstemp(prefix='test_cfg_data_', suffix='.json')
    try:
        with os.fdopen(fd, 'w') as f:
            json.dump(test_data, f)

        dataset = CFGUniProtDataset(test_path, use_masked_labels=True)
        print(f"Dataset length: {len(dataset)}")
        assert len(dataset) == len(test_data['sequences'])

        for i in range(len(dataset)):
            sample = dataset[i]
            print(f"Sample {i}:")
            print(f"  Sequence: {sample['sequence'][:20]}...")
            print(f"  Label: {sample['label'].item()}")
            print(f"  Original Label: {sample['original_label'].item()}")
            print(f"  Is Masked: {sample['is_masked'].item()}")
            assert sample['label'].item() == test_data['masked_labels'][i]
            assert sample['original_label'].item() == test_data['original_labels'][i]
            assert sample['is_masked'].item() == (i in test_data['mask_indices'])
    finally:
        os.remove(test_path)

    print("Test completed successfully!")


if __name__ == "__main__":
    test_cfg_dataset()