FlowFinal / src /cfg_dataset.py
esunAI's picture
Add cfg_dataset.py
164b12c verified
raw
history blame
17.7 kB
import json
import os
import random
import re
from typing import Any, Dict, List, Optional, Tuple

import numpy as np
import torch
import torch.nn as nn
from torch.utils.data import Dataset, DataLoader
# The 20 standard amino acids; sequences containing anything else are dropped.
_CANONICAL_AMINO_ACIDS = frozenset('ACDEFGHIKLMNPQRSTVWY')


def _iter_fasta_records(lines):
    """Yield ``(header, sequence)`` pairs from FASTA-formatted lines.

    ``header`` excludes the leading '>'. Sequence lines are stripped and
    concatenated, so records may span several lines. Records with an empty
    header or an empty sequence (including any lines before the first header)
    are skipped.
    """
    header = ""
    chunks: List[str] = []
    for raw_line in lines:
        line = raw_line.strip()
        if line.startswith('>'):
            sequence = "".join(chunks)
            if header and sequence:
                yield header, sequence
            # Start new record (remove '>' from header)
            header = line[1:]
            chunks = []
        else:
            chunks.append(line)
    # Flush the final record, which has no following header to trigger it.
    sequence = "".join(chunks)
    if header and sequence:
        yield header, sequence


def _label_from_header(header: str) -> int:
    """Map a FASTA header (without '>') to a label: 'AP' -> 0 (AMP), 'sp' -> 1 (Non-AMP).

    Any other prefix is labelled 1 (Non-AMP) and a warning is printed.
    """
    if header.startswith('AP'):
        return 0  # AMP
    if header.startswith('sp'):
        return 1  # Non-AMP
    # Unknown prefix, default to Non-AMP
    print(f"Warning: Unknown header prefix in '{header}', defaulting to Non-AMP")
    return 1


def parse_fasta_with_amp_labels(fasta_path: str, max_seq_len: int = 50,
                                mask_probability: float = 0.1) -> Dict[str, Any]:
    """
    Parse a FASTA file and assign AMP/Non-AMP labels from header prefixes.

    Label assignment strategy:
    - AMP (0): headers starting with '>AP'
    - Non-AMP (1): headers starting with '>sp'
    - Any other prefix: Non-AMP (1), with a printed warning
    - Mask (2): written into ``masked_labels`` for a random
      ``mask_probability`` fraction of the kept records (for CFG training)

    Records may span multiple sequence lines. A record is kept only if its
    upper-cased sequence has length in ``[2, max_seq_len]`` and consists solely
    of the 20 canonical amino acids.

    Masking draws from NumPy's global random state (``np.random.choice``);
    seed it for reproducible masks.

    Args:
        fasta_path: Path to FASTA file.
        max_seq_len: Maximum sequence length to include.
        mask_probability: Fraction of kept records (rounded down) whose label
            is replaced by 2 in ``masked_labels``. Must lie in [0, 1].

    Returns:
        Dictionary with keys ``sequences`` (upper-cased), ``headers`` (without
        '>'), ``labels``, ``masked_labels`` (all position-aligned lists) and
        ``mask_indices`` (positions that were masked).

    Raises:
        ValueError: If ``mask_probability`` is outside [0, 1].
    """
    if not 0.0 <= mask_probability <= 1.0:
        raise ValueError(f"mask_probability must be in [0, 1], got {mask_probability}")

    sequences: List[str] = []
    labels: List[int] = []
    headers: List[str] = []
    print(f"Parsing FASTA file: {fasta_path}")
    print("Label assignment: >AP = AMP (0), >sp = Non-AMP (1)")
    with open(fasta_path, 'r') as f:
        for header, raw_sequence in _iter_fasta_records(f):
            sequence = raw_sequence.upper()
            if not 2 <= len(sequence) <= max_seq_len:
                continue
            if not _CANONICAL_AMINO_ACIDS.issuperset(sequence):
                continue
            sequences.append(sequence)
            headers.append(header)
            labels.append(_label_from_header(header))

    # Create masked labels for CFG training
    original_labels = np.array(labels, dtype=np.int64)
    masked_labels = original_labels.copy()
    num_records = len(original_labels)
    mask_indices = np.random.choice(
        num_records,
        size=int(num_records * mask_probability),
        replace=False
    )
    masked_labels[mask_indices] = 2  # 2 = mask/unknown

    print(f"✓ Parsed {len(sequences)} valid sequences from FASTA")
    print(f" AMP sequences: {np.sum(original_labels == 0)}")
    print(f" Non-AMP sequences: {np.sum(original_labels == 1)}")
    print(f" Masked for CFG: {len(mask_indices)}")
    return {
        'sequences': sequences,
        'headers': headers,
        'labels': original_labels.tolist(),
        'masked_labels': masked_labels.tolist(),
        'mask_indices': mask_indices.tolist()
    }
class CFGUniProtDataset(Dataset):
    """
    Dataset of raw amino-acid sequences with labels for classifier-free guidance.

    Loads a JSON file containing ``sequences``, ``original_labels`` (or
    ``labels``, the key written by ``parse_fasta_with_amp_labels``),
    ``masked_labels`` and ``mask_indices``. The masking is taken from the file
    as-is; it is not re-drawn here.

    Each sample is a dict with ``sequence`` (str), ``label`` (masked or original
    depending on ``use_masked_labels``), ``original_label``, ``is_masked`` and
    ``index``. Label ids: 0 = AMP, 1 = Non-AMP, 2 = mask.
    """

    def __init__(self,
                 data_path: str,
                 use_masked_labels: bool = True,
                 mask_probability: float = 0.1,
                 max_seq_len: int = 50,
                 device: str = 'cuda'):
        self.data_path = data_path
        self.use_masked_labels = use_masked_labels
        # Stored for reporting only; the actual masking comes from the data file.
        self.mask_probability = mask_probability
        # NOTE(review): max_seq_len and device are stored but not used by this class.
        self.max_seq_len = max_seq_len
        self.device = device
        # Load processed data
        self._load_data()
        # Label mapping
        self.label_map = {
            0: 'amp',      # MIC < 100
            1: 'non_amp',  # MIC > 100
            2: 'mask'      # Unknown MIC
        }
        print("CFG Dataset initialized:")
        print(f" Total sequences: {len(self.sequences)}")
        print(f" Using masked labels: {use_masked_labels}")
        print(f" Mask probability: {mask_probability}")
        print(f" Label distribution: {self._get_label_distribution()}")

    def _load_data(self):
        """Load sequences, labels and mask indices from the JSON at ``self.data_path``.

        Raises:
            FileNotFoundError: If ``self.data_path`` does not exist.
            KeyError: If a required key is missing from the JSON.
        """
        if not os.path.exists(self.data_path):
            raise FileNotFoundError(f"Data file not found: {self.data_path}")
        with open(self.data_path, 'r') as f:
            data = json.load(f)
        self.sequences = data['sequences']
        # Accept 'labels' as a fallback so output of parse_fasta_with_amp_labels
        # can be saved and loaded directly; 'original_labels' wins if both exist.
        label_key = 'labels' if ('original_labels' not in data and 'labels' in data) else 'original_labels'
        self.original_labels = np.array(data[label_key])
        self.masked_labels = np.array(data['masked_labels'])
        self.mask_indices = set(data['mask_indices'])

    def _get_label_distribution(self, labels: Optional[np.ndarray] = None) -> Dict[str, int]:
        """Return ``{label_name: count}`` for ``labels``.

        When ``labels`` is None, uses the masked labels if ``use_masked_labels``
        is set, otherwise the original labels.
        """
        if labels is None:
            labels = self.masked_labels if self.use_masked_labels else self.original_labels
        unique, counts = np.unique(labels, return_counts=True)
        # Cast to plain ints so the result is JSON-serializable.
        return {self.label_map[int(label)]: int(count) for label, count in zip(unique, counts)}

    def __len__(self) -> int:
        return len(self.sequences)

    def __getitem__(self, idx: int) -> Dict[str, torch.Tensor]:
        """Get a single sample with sequence and label."""
        sequence = self.sequences[idx]
        # Get appropriate label
        if self.use_masked_labels:
            label = self.masked_labels[idx]
        else:
            label = self.original_labels[idx]
        # Check if this sample was masked
        is_masked = idx in self.mask_indices
        return {
            'sequence': sequence,
            'label': torch.tensor(label, dtype=torch.long),
            'original_label': torch.tensor(self.original_labels[idx], dtype=torch.long),
            'is_masked': torch.tensor(is_masked, dtype=torch.bool),
            'index': torch.tensor(idx, dtype=torch.long)
        }

    def get_label_statistics(self) -> Dict[str, Dict]:
        """Return label distributions and masking details.

        ``original`` is always computed from the unmasked labels; ``masked`` is
        the masked-label distribution, or None when ``use_masked_labels`` is off.
        """
        return {
            'original': self._get_label_distribution(self.original_labels),
            'masked': self._get_label_distribution(self.masked_labels) if self.use_masked_labels else None,
            'masking_info': {
                'total_masked': len(self.mask_indices),
                'mask_probability': self.mask_probability,
                'masked_indices': sorted(self.mask_indices)
            }
        }
class CFGFlowDataset(Dataset):
    """
    Pair precomputed peptide embeddings with CFG labels for flow training.

    Embeddings are read from the directory ``embeddings_path``. Labels are read
    from ``cfg_data_path``: either a FASTA file (labelled by
    ``parse_fasta_with_amp_labels``) or a legacy JSON file with ``sequences``
    and ``labels`` keys. Embeddings and labels are paired purely by position
    (sample ``i`` = embedding ``i`` + label ``i``) and truncated to the shorter
    of the two; no sequence matching is performed.

    Each sample is a dict with ``embedding``, ``label`` (masked or original
    depending on ``use_masked_labels``), ``original_label``, ``is_masked`` and
    ``index``. Label ids: 0 = AMP, 1 = Non-AMP, 2 = mask.
    """

    def __init__(self,
                 embeddings_path: str,
                 cfg_data_path: str,
                 use_masked_labels: bool = True,
                 max_seq_len: int = 50,
                 device: str = 'cuda'):
        self.embeddings_path = embeddings_path
        self.cfg_data_path = cfg_data_path
        self.use_masked_labels = use_masked_labels
        # Forwarded to the FASTA parser as its length filter.
        self.max_seq_len = max_seq_len
        # NOTE(review): stored but unused; all tensors stay on CPU.
        self.device = device
        # Load data
        self._load_embeddings()
        self._load_cfg_data()
        self._align_data()
        print("CFG Flow Dataset initialized:")
        print(f" AMP embeddings: {self.embeddings.shape}")
        print(f" CFG labels: {len(self.cfg_labels)}")
        print(f" Aligned samples: {len(self.aligned_indices)}")

    def _load_embeddings(self):
        """Load embeddings into ``self.embeddings`` (on CPU).

        Prefers ``<embeddings_path>/all_peptide_embeddings.pt``. Otherwise loads
        every other ``*.pt`` file in the directory, in sorted path order so the
        index-based pairing with labels is reproducible, keeps the 2-D tensors
        and stacks them.

        Raises:
            ValueError: If no usable embedding file is found.
        """
        print(f"Loading AMP embeddings from {self.embeddings_path}...")
        # Try to load the combined embeddings file first (FULL DATA)
        combined_path = os.path.join(self.embeddings_path, "all_peptide_embeddings.pt")
        if os.path.exists(combined_path):
            print(f"Loading combined embeddings from {combined_path} (FULL DATA)...")
            # Load on CPU first to avoid CUDA issues with DataLoader workers
            self.embeddings = torch.load(combined_path, map_location='cpu')
            print(f"✓ Loaded ALL embeddings: {self.embeddings.shape}")
            return

        print("Combined embeddings file not found, loading individual files...")
        import glob
        # glob.glob returns files in arbitrary order; sort so that embedding i
        # always comes from the same file (labels are paired by index).
        embedding_files = sorted(
            path for path in glob.glob(os.path.join(self.embeddings_path, "*.pt"))
            if not path.endswith('all_peptide_embeddings.pt')
        )
        print(f"Found {len(embedding_files)} individual embedding files")
        embeddings_list = []
        for file_path in embedding_files:
            # Best-effort: an unreadable or non-tensor file is reported and skipped.
            try:
                embedding = torch.load(file_path, map_location='cpu')
                if embedding.dim() == 2:  # presumably (seq_len, hidden_dim)
                    embeddings_list.append(embedding)
                else:
                    print(f"Warning: Skipping {file_path} - unexpected shape {embedding.shape}")
            except Exception as e:
                print(f"Warning: Could not load {file_path}: {e}")
        if not embeddings_list:
            raise ValueError("No valid embeddings found!")
        # NOTE: torch.stack requires all per-file tensors to share one shape.
        self.embeddings = torch.stack(embeddings_list)
        print(f"Loaded {len(self.embeddings)} embeddings from individual files")

    def _load_cfg_data(self):
        """Load CFG sequences and labels from a FASTA or legacy JSON file.

        Sets ``cfg_sequences``, ``cfg_headers``, ``cfg_original_labels``,
        ``cfg_masked_labels`` and ``cfg_mask_indices``. Paths ending in
        ``.fasta``/``.fa`` (case-insensitive) are parsed as FASTA. Anything else
        is read as JSON with ``sequences`` and ``labels`` keys, and 10% of its
        labels are masked to 2 here using NumPy's global random state.
        """
        is_fasta = self.cfg_data_path.lower().endswith(('.fasta', '.fa'))
        source_kind = "FASTA" if is_fasta else "JSON"
        print(f"Loading CFG data from {source_kind}: {self.cfg_data_path}...")
        if is_fasta:
            # Parse FASTA file with automatic labeling
            cfg_data = parse_fasta_with_amp_labels(self.cfg_data_path, self.max_seq_len)
            self.cfg_sequences = cfg_data['sequences']
            self.cfg_headers = cfg_data['headers']
            self.cfg_original_labels = np.array(cfg_data['labels'], dtype=np.int64)
            self.cfg_masked_labels = np.array(cfg_data['masked_labels'], dtype=np.int64)
            self.cfg_mask_indices = set(cfg_data['mask_indices'])
        else:
            # Legacy JSON format support
            with open(self.cfg_data_path, 'r') as f:
                cfg_data = json.load(f)
            self.cfg_sequences = cfg_data['sequences']
            self.cfg_headers = cfg_data.get('headers', [''] * len(cfg_data['sequences']))
            self.cfg_original_labels = np.array(cfg_data['labels'], dtype=np.int64)
            # Randomly mask 10% of labels (id 2) for CFG training
            mask_probability = 0.1
            num_labels = len(self.cfg_original_labels)
            mask_indices = np.random.choice(
                num_labels,
                size=int(num_labels * mask_probability),
                replace=False
            )
            self.cfg_masked_labels = self.cfg_original_labels.copy()
            self.cfg_masked_labels[mask_indices] = 2  # 2 = mask/unknown
            # Store plain ints, matching the FASTA branch.
            self.cfg_mask_indices = set(mask_indices.tolist())
        print(f"Loaded {len(self.cfg_sequences)} CFG sequences")
        print(f"Label distribution: {np.bincount(self.cfg_original_labels)}")
        print(f"Masked {len(self.cfg_mask_indices)} labels for CFG training")

    def _align_data(self):
        """Pair embeddings and labels by position, truncating to the shorter.

        No sequence matching is done: embedding ``i`` is assumed to correspond
        to CFG record ``i``. A warning is printed when the counts differ.
        """
        print("Aligning AMP embeddings with CFG data...")
        num_embeddings = len(self.embeddings)
        num_cfg = len(self.cfg_sequences)
        min_samples = min(num_embeddings, num_cfg)
        if num_embeddings != num_cfg:
            print(f"Warning: {num_embeddings} embeddings vs {num_cfg} CFG sequences; "
                  f"keeping the first {min_samples} of each")
        self.aligned_indices = list(range(min_samples))
        # Align labels
        source_labels = self.cfg_masked_labels if self.use_masked_labels else self.cfg_original_labels
        self.cfg_labels = source_labels[:min_samples]
        # Align embeddings
        self.aligned_embeddings = self.embeddings[:min_samples]
        print(f"Aligned {min_samples} samples")

    def __len__(self) -> int:
        return len(self.aligned_indices)

    def __getitem__(self, idx: int) -> Dict[str, torch.Tensor]:
        """Get a single sample with embedding and CFG label."""
        # Embeddings are already on CPU
        embedding = self.aligned_embeddings[idx]
        label = self.cfg_labels[idx]
        original_label = self.cfg_original_labels[idx]
        is_masked = idx in self.cfg_mask_indices
        return {
            'embedding': embedding,
            'label': torch.tensor(label, dtype=torch.long),
            'original_label': torch.tensor(original_label, dtype=torch.long),
            'is_masked': torch.tensor(is_masked, dtype=torch.bool),
            'index': torch.tensor(idx, dtype=torch.long)
        }

    def get_embedding_stats(self) -> Dict:
        """Return shape, mean, std, min and max of the aligned embeddings."""
        return {
            'shape': self.aligned_embeddings.shape,
            'mean': self.aligned_embeddings.mean().item(),
            'std': self.aligned_embeddings.std().item(),
            'min': self.aligned_embeddings.min().item(),
            'max': self.aligned_embeddings.max().item()
        }
def cfg_collate_fn(batch: List[Dict[str, Any]]) -> Dict[str, Any]:
    """Collate CFG samples (dicts from the CFG datasets) into one batch dict.

    Produces ``labels``, ``original_labels``, ``is_masked`` and ``indices`` as
    stacked tensors. It also produces ``embeddings`` (stacked) when samples carry
    an ``embedding``, and ``sequences`` (a list of str) when they carry a
    ``sequence``.

    Defined at module level, rather than nested, so DataLoader worker processes
    started with the 'spawn' method can pickle it.
    """
    first = batch[0]
    collated: Dict[str, Any] = {}
    if 'embedding' in first:
        collated['embeddings'] = torch.stack([item['embedding'] for item in batch])
    if 'sequence' in first:
        # Raw strings cannot be stacked into a tensor; keep them as a list.
        collated['sequences'] = [item['sequence'] for item in batch]
    collated['labels'] = torch.stack([item['label'] for item in batch])
    collated['original_labels'] = torch.stack([item['original_label'] for item in batch])
    collated['is_masked'] = torch.stack([item['is_masked'] for item in batch])
    collated['indices'] = torch.stack([item['index'] for item in batch])
    return collated


def create_cfg_dataloader(dataset: Dataset,
                          batch_size: int = 32,
                          shuffle: bool = True,
                          num_workers: int = 4) -> DataLoader:
    """Create a DataLoader for CFG training that batches with ``cfg_collate_fn``.

    Args:
        dataset: A CFG dataset (or any indexable yielding the same sample dicts).
        batch_size: Samples per batch.
        shuffle: Whether to reshuffle every epoch.
        num_workers: Number of DataLoader worker processes.

    Returns:
        A configured ``torch.utils.data.DataLoader``.
    """
    return DataLoader(
        dataset,
        batch_size=batch_size,
        shuffle=shuffle,
        num_workers=num_workers,
        collate_fn=cfg_collate_fn,
        # Pinned memory only speeds host->GPU copies; skip it on CPU-only hosts
        # (where requesting it just emits a warning).
        pin_memory=torch.cuda.is_available()
    )
def test_cfg_dataset():
    """Smoke-test ``CFGUniProtDataset`` on a tiny JSON fixture.

    The fixture is written to a temporary directory that is removed even when
    loading or an assertion fails. Each sample's label, original label and
    mask flag are checked against the fixture.
    """
    import tempfile

    print("Testing CFG Dataset...")
    # Test with a small subset
    test_data = {
        'sequences': ['MKTVRQERLKSIVRILERSKEPVSGAQLAEELSVSRQVIVQDIAYLRSLGYNIVATPRGYVLAGG',
                      'MKLLIVTFCLTFAAL',
                      'MKLLIVTFCLTFAALMKLLIVTFCLTFAAL'],
        'original_labels': [0, 1, 0],  # amp, non_amp, amp
        'masked_labels': [0, 2, 0],    # amp, mask, amp
        'mask_indices': [1]            # Only second sequence is masked
    }
    # The dataset reads the whole file in its constructor, so the temporary
    # directory can be cleaned up right after construction.
    with tempfile.TemporaryDirectory() as tmp_dir:
        test_path = os.path.join(tmp_dir, 'test_cfg_data.json')
        with open(test_path, 'w') as f:
            json.dump(test_data, f)
        dataset = CFGUniProtDataset(test_path, use_masked_labels=True)

    print(f"Dataset length: {len(dataset)}")
    assert len(dataset) == len(test_data['sequences'])
    for i in range(len(dataset)):
        sample = dataset[i]
        print(f"Sample {i}:")
        print(f" Sequence: {sample['sequence'][:20]}...")
        print(f" Label: {sample['label'].item()}")
        print(f" Original Label: {sample['original_label'].item()}")
        print(f" Is Masked: {sample['is_masked'].item()}")
        assert sample['label'].item() == test_data['masked_labels'][i]
        assert sample['original_label'].item() == test_data['original_labels'][i]
        assert sample['is_masked'].item() == (i in test_data['mask_indices'])
    print("Test completed successfully!")
def _main() -> None:
    """Script entry point: run the dataset self-test."""
    test_cfg_dataset()


if __name__ == "__main__":
    _main()