|
|
import torch
import torch.nn as nn
from torch.utils.data import Dataset, DataLoader
import numpy as np
import json
import os
from typing import Any, Dict, List, Tuple, Optional
import random
import re
|

def parse_fasta_with_amp_labels(fasta_path: str, max_seq_len: int = 50) -> Dict[str, Any]:
    """
    Parse a FASTA file and assign AMP/Non-AMP labels based on header prefixes.

    Label assignment strategy:
    - AMP (0): headers starting with '>AP'
    - Non-AMP (1): headers starting with '>sp'
    - Mask (2): assigned randomly afterwards, for classifier-free guidance (CFG) training

    File format:
    - Header lines start with '>' ('>AP' or '>sp')
    - Sequence lines follow their header and may span multiple lines

    Args:
        fasta_path: Path to the FASTA file
        max_seq_len: Maximum sequence length to include

    Returns:
        Dictionary with sequences, headers, labels, masked labels, and mask indices
    """
    sequences = []
    labels = []
    headers = []

    print(f"Parsing FASTA file: {fasta_path}")
    print("Label assignment: >AP = AMP (0), >sp = Non-AMP (1)")

    current_header = ""
    current_sequence = ""
    canonical_aa = set('ACDEFGHIKLMNPQRSTVWY')

    with open(fasta_path, 'r') as f:
        for line in f:
            line = line.strip()
            if line.startswith('>'):
                # A new header: finalize the previous record before starting the next one.
                if current_sequence and current_header:
                    if 2 <= len(current_sequence) <= max_seq_len:
                        # Keep only sequences built from the 20 canonical amino acids.
                        if all(aa in canonical_aa for aa in current_sequence.upper()):
                            sequences.append(current_sequence.upper())
                            headers.append(current_header)

                            # Assign the label from the header prefix ('>' already stripped).
                            if current_header.startswith('AP'):
                                labels.append(0)
                            elif current_header.startswith('sp'):
                                labels.append(1)
                            else:
                                labels.append(1)
                                print(f"Warning: Unknown header prefix in '{current_header}', defaulting to Non-AMP")

                current_header = line[1:]
                current_sequence = ""
            else:
                current_sequence += line
    # Handle the final record in the file.
    if current_sequence and current_header:
        if 2 <= len(current_sequence) <= max_seq_len:
            if all(aa in canonical_aa for aa in current_sequence.upper()):
                sequences.append(current_sequence.upper())
                headers.append(current_header)

                if current_header.startswith('AP'):
                    labels.append(0)
                elif current_header.startswith('sp'):
                    labels.append(1)
                else:
                    labels.append(1)
                    print(f"Warning: Unknown header prefix in '{current_header}', defaulting to Non-AMP")

    # Randomly mask a fraction of labels (set to 2) for classifier-free guidance training.
    original_labels = np.array(labels)
    masked_labels = original_labels.copy()
    mask_probability = 0.1
    mask_indices = np.random.choice(
        len(original_labels),
        size=int(len(original_labels) * mask_probability),
        replace=False
    )
    masked_labels[mask_indices] = 2

    print(f"✓ Parsed {len(sequences)} valid sequences from FASTA")
    print(f"  AMP sequences: {np.sum(original_labels == 0)}")
    print(f"  Non-AMP sequences: {np.sum(original_labels == 1)}")
    print(f"  Masked for CFG: {len(mask_indices)}")

    return {
        'sequences': sequences,
        'headers': headers,
        'labels': original_labels.tolist(),
        'masked_labels': masked_labels.tolist(),
        'mask_indices': mask_indices.tolist()
    }
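
# Illustrative usage of the parser (a sketch; the FASTA path below is hypothetical
# and assumes a file mixing '>AP...' (AMP) and '>sp...' (UniProt/Non-AMP) records):
#
#   data = parse_fasta_with_amp_labels("data/amps_and_uniprot.fasta", max_seq_len=50)
#   # data['labels'] holds the true labels (0 = AMP, 1 = Non-AMP);
#   # data['masked_labels'] is identical except ~10% of entries are set to 2 (mask).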
|
|
|
|
|
class CFGUniProtDataset(Dataset):
    """
    Dataset class for UniProt sequences with classifier-free guidance.

    This dataset:
    1. Loads processed UniProt data with AMP classifications
    2. Handles label masking for CFG training
    3. Integrates with the existing flow training pipeline
    4. Provides sequences, labels, and masking information
    """

    def __init__(self,
                 data_path: str,
                 use_masked_labels: bool = True,
                 mask_probability: float = 0.1,
                 max_seq_len: int = 50,
                 device: str = 'cuda'):

        self.data_path = data_path
        self.use_masked_labels = use_masked_labels
        self.mask_probability = mask_probability
        self.max_seq_len = max_seq_len
        self.device = device

        self._load_data()

        # Human-readable names for the integer labels.
        self.label_map = {
            0: 'amp',
            1: 'non_amp',
            2: 'mask'
        }

        print("CFG Dataset initialized:")
        print(f"  Total sequences: {len(self.sequences)}")
        print(f"  Using masked labels: {use_masked_labels}")
        print(f"  Mask probability: {mask_probability}")
        print(f"  Label distribution: {self._get_label_distribution()}")
|
    def _load_data(self):
        """Load processed UniProt data."""
        if os.path.exists(self.data_path):
            with open(self.data_path, 'r') as f:
                data = json.load(f)

            self.sequences = data['sequences']
            self.original_labels = np.array(data['original_labels'])
            self.masked_labels = np.array(data['masked_labels'])
            self.mask_indices = set(data['mask_indices'])
        else:
            raise FileNotFoundError(f"Data file not found: {self.data_path}")

    def _get_label_distribution(self, labels: Optional[np.ndarray] = None) -> Dict[str, int]:
        """Get the distribution of labels (defaults to the labels used for training)."""
        if labels is None:
            labels = self.masked_labels if self.use_masked_labels else self.original_labels
        unique, counts = np.unique(labels, return_counts=True)
        return {self.label_map[int(label)]: int(count) for label, count in zip(unique, counts)}

    def __len__(self) -> int:
        return len(self.sequences)
    def __getitem__(self, idx: int) -> Dict[str, torch.Tensor]:
        """Get a single sample with sequence and label."""
        sequence = self.sequences[idx]

        # Use the (possibly masked) label for CFG training, or the true label otherwise.
        if self.use_masked_labels:
            label = self.masked_labels[idx]
        else:
            label = self.original_labels[idx]

        is_masked = idx in self.mask_indices

        return {
            'sequence': sequence,
            'label': torch.tensor(label, dtype=torch.long),
            'original_label': torch.tensor(self.original_labels[idx], dtype=torch.long),
            'is_masked': torch.tensor(is_masked, dtype=torch.bool),
            'index': torch.tensor(idx, dtype=torch.long)
        }

    def get_label_statistics(self) -> Dict[str, Dict]:
        """Get detailed statistics about labels."""
        stats = {
            'original': self._get_label_distribution(self.original_labels),
            'masked': self._get_label_distribution(self.masked_labels) if self.use_masked_labels else None,
            'masking_info': {
                'total_masked': len(self.mask_indices),
                'mask_probability': self.mask_probability,
                'masked_indices': list(self.mask_indices)
            }
        }
        return stats
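
# Minimal sketch of using CFGUniProtDataset (assumes a 'uniprot_cfg_data.json' produced
# beforehand with keys 'sequences', 'original_labels', 'masked_labels', 'mask_indices';
# the path is hypothetical):
#
#   dataset = CFGUniProtDataset('uniprot_cfg_data.json', use_masked_labels=True)
#   sample = dataset[0]
#   # sample['label'] may be 2 (mask) when the index was chosen for CFG label dropout,
#   # while sample['original_label'] always keeps the true AMP/Non-AMP class.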
|
|
|
|
|
class CFGFlowDataset(Dataset):
    """
    Dataset that integrates CFG labels with the existing flow training pipeline.

    This dataset:
    1. Loads the existing AMP embeddings
    2. Adds CFG labels from UniProt processing
    3. Handles the integration between embeddings and labels
    4. Provides data in the format expected by the flow training code
    """

    def __init__(self,
                 embeddings_path: str,
                 cfg_data_path: str,
                 use_masked_labels: bool = True,
                 max_seq_len: int = 50,
                 device: str = 'cuda'):

        self.embeddings_path = embeddings_path
        self.cfg_data_path = cfg_data_path
        self.use_masked_labels = use_masked_labels
        self.max_seq_len = max_seq_len
        self.device = device

        self._load_embeddings()
        self._load_cfg_data()
        self._align_data()

        print("CFG Flow Dataset initialized:")
        print(f"  AMP embeddings: {self.embeddings.shape}")
        print(f"  CFG labels: {len(self.cfg_labels)}")
        print(f"  Aligned samples: {len(self.aligned_indices)}")
|
    def _load_embeddings(self):
        """Load the existing AMP embeddings."""
        print(f"Loading AMP embeddings from {self.embeddings_path}...")

        # Prefer the single combined tensor if it exists.
        combined_path = os.path.join(self.embeddings_path, "all_peptide_embeddings.pt")

        if os.path.exists(combined_path):
            print(f"Loading combined embeddings from {combined_path} (FULL DATA)...")
            self.embeddings = torch.load(combined_path, map_location='cpu')
            print(f"✓ Loaded ALL embeddings: {self.embeddings.shape}")
        else:
            print("Combined embeddings file not found, loading individual files...")
            import glob

            # Collect per-peptide embedding files, excluding the combined tensor
            # (the '*.pt' glob already excludes the metadata and sequence-id JSON files).
            embedding_files = glob.glob(os.path.join(self.embeddings_path, "*.pt"))
            embedding_files = [f for f in embedding_files if not f.endswith('all_peptide_embeddings.pt')]

            print(f"Found {len(embedding_files)} individual embedding files")

            embeddings_list = []
            for file_path in embedding_files:
                try:
                    embedding = torch.load(file_path, map_location='cpu')
                    if embedding.dim() == 2:
                        embeddings_list.append(embedding)
                    else:
                        print(f"Warning: Skipping {file_path} - unexpected shape {embedding.shape}")
                except Exception as e:
                    print(f"Warning: Could not load {file_path}: {e}")

            if not embeddings_list:
                raise ValueError("No valid embeddings found!")

            # Note: stacking requires every per-peptide embedding to have the same shape.
            self.embeddings = torch.stack(embeddings_list)
            print(f"Loaded {len(self.embeddings)} embeddings from individual files")
|
    def _load_cfg_data(self):
        """Load CFG data from a FASTA file (with automatic AMP labeling) or a JSON file."""
        print(f"Loading CFG data from {self.cfg_data_path}...")

        if self.cfg_data_path.endswith('.fasta') or self.cfg_data_path.endswith('.fa'):
            # FASTA input: labels and CFG masking come from the parser.
            cfg_data = parse_fasta_with_amp_labels(self.cfg_data_path, self.max_seq_len)

            self.cfg_sequences = cfg_data['sequences']
            self.cfg_headers = cfg_data['headers']
            self.cfg_original_labels = np.array(cfg_data['labels'])
            self.cfg_masked_labels = np.array(cfg_data['masked_labels'])
            self.cfg_mask_indices = set(cfg_data['mask_indices'])
        else:
            # JSON input: load the labels, then apply CFG label masking here.
            with open(self.cfg_data_path, 'r') as f:
                cfg_data = json.load(f)

            self.cfg_sequences = cfg_data['sequences']
            self.cfg_headers = cfg_data.get('headers', [''] * len(cfg_data['sequences']))
            self.cfg_original_labels = np.array(cfg_data['labels'])

            self.cfg_masked_labels = self.cfg_original_labels.copy()
            mask_probability = 0.1
            mask_indices = np.random.choice(
                len(self.cfg_original_labels),
                size=int(len(self.cfg_original_labels) * mask_probability),
                replace=False
            )
            self.cfg_masked_labels[mask_indices] = 2
            self.cfg_mask_indices = set(mask_indices)

        print(f"Loaded {len(self.cfg_sequences)} CFG sequences")
        print(f"Label distribution: {np.bincount(self.cfg_original_labels)}")
        print(f"Masked {len(self.cfg_mask_indices)} labels for CFG training")
|
    def _align_data(self):
        """Align AMP embeddings with CFG data by position, truncating to the shorter set."""
        print("Aligning AMP embeddings with CFG data...")

        # The embeddings and CFG sequences are assumed to be in corresponding order,
        # so alignment is positional: keep the first min(N_embeddings, N_sequences) items.
        min_samples = min(len(self.embeddings), len(self.cfg_sequences))
        self.aligned_indices = list(range(min_samples))

        if self.use_masked_labels:
            self.cfg_labels = self.cfg_masked_labels[:min_samples]
        else:
            self.cfg_labels = self.cfg_original_labels[:min_samples]

        self.aligned_embeddings = self.embeddings[:min_samples]

        print(f"Aligned {min_samples} samples")
|
    def __len__(self) -> int:
        return len(self.aligned_indices)

    def __getitem__(self, idx: int) -> Dict[str, torch.Tensor]:
        """Get a single sample with embedding and CFG label."""
        embedding = self.aligned_embeddings[idx]
        label = self.cfg_labels[idx]
        original_label = self.cfg_original_labels[idx]
        is_masked = idx in self.cfg_mask_indices

        return {
            'embedding': embedding,
            'label': torch.tensor(label, dtype=torch.long),
            'original_label': torch.tensor(original_label, dtype=torch.long),
            'is_masked': torch.tensor(is_masked, dtype=torch.bool),
            'index': torch.tensor(idx, dtype=torch.long)
        }

    def get_embedding_stats(self) -> Dict:
        """Get statistics about the embeddings."""
        return {
            'shape': self.aligned_embeddings.shape,
            'mean': self.aligned_embeddings.mean().item(),
            'std': self.aligned_embeddings.std().item(),
            'min': self.aligned_embeddings.min().item(),
            'max': self.aligned_embeddings.max().item()
        }
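
# Minimal sketch of building the flow dataset (both paths are hypothetical; the
# embeddings directory is expected to contain 'all_peptide_embeddings.pt' or
# individual '*.pt' files, and the CFG path may be a FASTA or JSON file):
#
#   flow_dataset = CFGFlowDataset(
#       embeddings_path='embeddings/',
#       cfg_data_path='data/amps_and_uniprot.fasta',
#       use_masked_labels=True,
#   )
#   print(flow_dataset.get_embedding_stats())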
|
|
|
|
|
def create_cfg_dataloader(dataset: Dataset,
                          batch_size: int = 32,
                          shuffle: bool = True,
                          num_workers: int = 4) -> DataLoader:
    """Create a DataLoader for CFG training (expects embedding-style samples, e.g. CFGFlowDataset)."""

    def collate_fn(batch):
        """Custom collate function for CFG data."""
        embeddings = torch.stack([item['embedding'] for item in batch])
        labels = torch.stack([item['label'] for item in batch])
        original_labels = torch.stack([item['original_label'] for item in batch])
        is_masked = torch.stack([item['is_masked'] for item in batch])
        indices = torch.stack([item['index'] for item in batch])

        return {
            'embeddings': embeddings,
            'labels': labels,
            'original_labels': original_labels,
            'is_masked': is_masked,
            'indices': indices
        }

    return DataLoader(
        dataset,
        batch_size=batch_size,
        shuffle=shuffle,
        num_workers=num_workers,
        collate_fn=collate_fn,
        pin_memory=True
    )
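
# Sketch of consuming the dataloader in a CFG-style training step (illustrative only;
# 'flow_model' and its training_step are placeholders for the existing flow training
# code). The convention assumed here is that label 2 acts as the "unconditional" token,
# so the model sees both conditional and unconditional targets within the same batches:
#
#   loader = create_cfg_dataloader(flow_dataset, batch_size=32)
#   for batch in loader:
#       embeddings = batch['embeddings']   # batched peptide embeddings
#       labels = batch['labels']           # 0 = AMP, 1 = Non-AMP, 2 = mask
#       loss = flow_model.training_step(embeddings, labels)  # hypothetical API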
|
|
|
|
|
def test_cfg_dataset():
    """Test function to verify the CFG dataset works correctly."""
    print("Testing CFG Dataset...")

    # Small synthetic dataset in the JSON layout expected by CFGUniProtDataset.
    test_data = {
        'sequences': ['MKTVRQERLKSIVRILERSKEPVSGAQLAEELSVSRQVIVQDIAYLRSLGYNIVATPRGYVLAGG',
                      'MKLLIVTFCLTFAAL',
                      'MKLLIVTFCLTFAALMKLLIVTFCLTFAAL'],
        'original_labels': [0, 1, 0],
        'masked_labels': [0, 2, 0],
        'mask_indices': [1]
    }

    test_path = 'test_cfg_data.json'
    with open(test_path, 'w') as f:
        json.dump(test_data, f)

    dataset = CFGUniProtDataset(test_path, use_masked_labels=True)

    print(f"Dataset length: {len(dataset)}")
    for i in range(len(dataset)):
        sample = dataset[i]
        print(f"Sample {i}:")
        print(f"  Sequence: {sample['sequence'][:20]}...")
        print(f"  Label: {sample['label'].item()}")
        print(f"  Original Label: {sample['original_label'].item()}")
        print(f"  Is Masked: {sample['is_masked'].item()}")

    # Clean up the temporary file.
    os.remove(test_path)
    print("Test completed successfully!")
|
|
|
|
|
if __name__ == "__main__":
    test_cfg_dataset()