Add cfg_dataset.py
src/cfg_dataset.py  ADDED  (+441, -0)
@@ -0,0 +1,441 @@
import torch
import torch.nn as nn
from torch.utils.data import Dataset, DataLoader
import numpy as np
import json
import os
from typing import Any, Dict, List, Tuple, Optional
import random
import re


def parse_fasta_with_amp_labels(fasta_path: str, max_seq_len: int = 50) -> Dict[str, Any]:
    """
    Parse FASTA file and assign AMP/Non-AMP labels based on header prefixes.

    Label assignment strategy:
    - AMP (0): Headers starting with '>AP'
    - Non-AMP (1): Headers starting with '>sp'
    - Mask (2): Used for CFG training (randomly assigned)

    File format:
    - Odd lines: Headers (>sp or >AP)
    - Even lines: Amino acid sequences

    Args:
        fasta_path: Path to FASTA file
        max_seq_len: Maximum sequence length to include

    Returns:
        Dictionary with sequences, labels, and metadata
    """
    sequences = []
    labels = []
    headers = []

    print(f"Parsing FASTA file: {fasta_path}")
    print("Label assignment: >AP = AMP (0), >sp = Non-AMP (1)")

    current_header = ""
    current_sequence = ""

    with open(fasta_path, 'r') as f:
        for line in f:
            line = line.strip()
            if line.startswith('>'):
                # Process the previous sequence if one exists
                if current_sequence and current_header:
                    if 2 <= len(current_sequence) <= max_seq_len:
                        # Keep only sequences made of canonical amino acids
                        canonical_aa = set('ACDEFGHIKLMNPQRSTVWY')
                        if all(aa in canonical_aa for aa in current_sequence.upper()):
                            sequences.append(current_sequence.upper())
                            headers.append(current_header)

                            # Assign label based on header prefix
                            if current_header.startswith('AP'):
                                labels.append(0)  # AMP
                            elif current_header.startswith('sp'):
                                labels.append(1)  # Non-AMP
                            else:
                                # Unknown prefix, default to Non-AMP
                                labels.append(1)
                                print(f"Warning: Unknown header prefix in '{current_header}', defaulting to Non-AMP")

                # Start a new sequence (remove '>' from the header)
                current_header = line[1:]
                current_sequence = ""
            else:
                current_sequence += line

    # Process the last sequence
    if current_sequence and current_header:
        if 2 <= len(current_sequence) <= max_seq_len:
            canonical_aa = set('ACDEFGHIKLMNPQRSTVWY')
            if all(aa in canonical_aa for aa in current_sequence.upper()):
                sequences.append(current_sequence.upper())
                headers.append(current_header)

                # Assign label based on header prefix
                if current_header.startswith('AP'):
                    labels.append(0)  # AMP
                elif current_header.startswith('sp'):
                    labels.append(1)  # Non-AMP
                else:
                    # Unknown prefix, default to Non-AMP
                    labels.append(1)
                    print(f"Warning: Unknown header prefix in '{current_header}', defaulting to Non-AMP")

    # Create masked labels for CFG training (10% masked)
    original_labels = np.array(labels)
    masked_labels = original_labels.copy()
    mask_probability = 0.1
    mask_indices = np.random.choice(
        len(original_labels),
        size=int(len(original_labels) * mask_probability),
        replace=False
    )
    masked_labels[mask_indices] = 2  # 2 = mask/unknown

    print(f"✓ Parsed {len(sequences)} valid sequences from FASTA")
    print(f"  AMP sequences: {np.sum(original_labels == 0)}")
    print(f"  Non-AMP sequences: {np.sum(original_labels == 1)}")
    print(f"  Masked for CFG: {len(mask_indices)}")

    return {
        'sequences': sequences,
        'headers': headers,
        'labels': original_labels.tolist(),
        'masked_labels': masked_labels.tolist(),
        'mask_indices': mask_indices.tolist()
    }

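# Minimal usage sketch for the parser above (illustrative only: the FASTA path
# is an assumption, not a file shipped with this commit). Wrapped in a helper so
# nothing runs at import time.
def _example_parse_fasta(fasta_path: str = "data/amps_and_uniprot.fasta") -> Dict[str, Any]:
    data = parse_fasta_with_amp_labels(fasta_path, max_seq_len=50)
    n_amp = sum(1 for label in data['labels'] if label == 0)
    n_non_amp = sum(1 for label in data['labels'] if label == 1)
    print(f"{len(data['sequences'])} sequences: {n_amp} AMP, {n_non_amp} Non-AMP, "
          f"{len(data['mask_indices'])} masked for CFG")
    return data
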
class CFGUniProtDataset(Dataset):
    """
    Dataset class for UniProt sequences with classifier-free guidance.

    This dataset:
    1. Loads processed UniProt data with AMP classifications
    2. Handles label masking for CFG training
    3. Integrates with your existing flow training pipeline
    4. Provides sequences, labels, and masking information
    """

    def __init__(self,
                 data_path: str,
                 use_masked_labels: bool = True,
                 mask_probability: float = 0.1,
                 max_seq_len: int = 50,
                 device: str = 'cuda'):

        self.data_path = data_path
        self.use_masked_labels = use_masked_labels
        self.mask_probability = mask_probability
        self.max_seq_len = max_seq_len
        self.device = device

        # Load processed data
        self._load_data()

        # Label mapping
        self.label_map = {
            0: 'amp',      # MIC < 100
            1: 'non_amp',  # MIC > 100
            2: 'mask'      # Unknown MIC
        }

        print("CFG Dataset initialized:")
        print(f"  Total sequences: {len(self.sequences)}")
        print(f"  Using masked labels: {use_masked_labels}")
        print(f"  Mask probability: {mask_probability}")
        print(f"  Label distribution: {self._get_label_distribution()}")

    def _load_data(self):
        """Load processed UniProt data."""
        if os.path.exists(self.data_path):
            with open(self.data_path, 'r') as f:
                data = json.load(f)

            self.sequences = data['sequences']
            self.original_labels = np.array(data['original_labels'])
            self.masked_labels = np.array(data['masked_labels'])
            self.mask_indices = set(data['mask_indices'])

        else:
            raise FileNotFoundError(f"Data file not found: {self.data_path}")

    def _get_label_distribution(self, labels: Optional[np.ndarray] = None) -> Dict[str, int]:
        """Get the distribution of labels (defaults to the labels used for training)."""
        if labels is None:
            labels = self.masked_labels if self.use_masked_labels else self.original_labels
        unique, counts = np.unique(labels, return_counts=True)
        return {self.label_map[int(label)]: int(count) for label, count in zip(unique, counts)}

    def __len__(self) -> int:
        return len(self.sequences)

    def __getitem__(self, idx: int) -> Dict[str, torch.Tensor]:
        """Get a single sample with sequence and label."""
        sequence = self.sequences[idx]

        # Get the appropriate label
        if self.use_masked_labels:
            label = self.masked_labels[idx]
        else:
            label = self.original_labels[idx]

        # Check whether this sample was masked
        is_masked = idx in self.mask_indices

        return {
            'sequence': sequence,
            'label': torch.tensor(label, dtype=torch.long),
            'original_label': torch.tensor(self.original_labels[idx], dtype=torch.long),
            'is_masked': torch.tensor(is_masked, dtype=torch.bool),
            'index': torch.tensor(idx, dtype=torch.long)
        }

    def get_label_statistics(self) -> Dict[str, Dict]:
        """Get detailed statistics about labels."""
        stats = {
            'original': self._get_label_distribution(self.original_labels),
            'masked': self._get_label_distribution(self.masked_labels) if self.use_masked_labels else None,
            'masking_info': {
                'total_masked': len(self.mask_indices),
                'mask_probability': self.mask_probability,
                'masked_indices': list(self.mask_indices)
            }
        }
        return stats

class CFGFlowDataset(Dataset):
    """
    Dataset that integrates CFG labels with your existing flow training pipeline.

    This dataset:
    1. Loads your existing AMP embeddings
    2. Adds CFG labels from UniProt processing
    3. Handles the integration between embeddings and labels
    4. Provides data in the format expected by your flow training
    """

    def __init__(self,
                 embeddings_path: str,
                 cfg_data_path: str,
                 use_masked_labels: bool = True,
                 max_seq_len: int = 50,
                 device: str = 'cuda'):

        self.embeddings_path = embeddings_path
        self.cfg_data_path = cfg_data_path
        self.use_masked_labels = use_masked_labels
        self.max_seq_len = max_seq_len
        self.device = device

        # Load data
        self._load_embeddings()
        self._load_cfg_data()
        self._align_data()

        print("CFG Flow Dataset initialized:")
        print(f"  AMP embeddings: {self.embeddings.shape}")
        print(f"  CFG labels: {len(self.cfg_labels)}")
        print(f"  Aligned samples: {len(self.aligned_indices)}")

    def _load_embeddings(self):
        """Load your existing AMP embeddings."""
        print(f"Loading AMP embeddings from {self.embeddings_path}...")

        # Try to load the combined embeddings file first (full dataset)
        combined_path = os.path.join(self.embeddings_path, "all_peptide_embeddings.pt")

        if os.path.exists(combined_path):
            print(f"Loading combined embeddings from {combined_path}...")
            # Load on CPU first to avoid CUDA issues with DataLoader workers
            self.embeddings = torch.load(combined_path, map_location='cpu')
            print(f"✓ Loaded all embeddings: {self.embeddings.shape}")
        else:
            print("Combined embeddings file not found, loading individual files...")
            # Fall back to individual files
            import glob

            embedding_files = glob.glob(os.path.join(self.embeddings_path, "*.pt"))
            embedding_files = [
                f for f in embedding_files
                if not f.endswith('metadata.json')
                and not f.endswith('sequence_ids.json')
                and not f.endswith('all_peptide_embeddings.pt')
            ]

            print(f"Found {len(embedding_files)} individual embedding files")

            # Load and stack all embeddings
            embeddings_list = []
            for file_path in embedding_files:
                try:
                    embedding = torch.load(file_path, map_location='cpu')
                    if embedding.dim() == 2:  # (seq_len, hidden_dim)
                        embeddings_list.append(embedding)
                    else:
                        print(f"Warning: Skipping {file_path} - unexpected shape {embedding.shape}")
                except Exception as e:
                    print(f"Warning: Could not load {file_path}: {e}")

            if not embeddings_list:
                raise ValueError("No valid embeddings found!")

            self.embeddings = torch.stack(embeddings_list)
            print(f"Loaded {len(self.embeddings)} embeddings from individual files")

    def _load_cfg_data(self):
        """Load CFG data, either from a FASTA file (with automatic AMP labeling) or from legacy JSON."""
        print(f"Loading CFG data from {self.cfg_data_path}...")

        # Check whether it is a FASTA file or a JSON file
        if self.cfg_data_path.endswith('.fasta') or self.cfg_data_path.endswith('.fa'):
            # Parse FASTA file with automatic labeling
            cfg_data = parse_fasta_with_amp_labels(self.cfg_data_path, self.max_seq_len)

            self.cfg_sequences = cfg_data['sequences']
            self.cfg_headers = cfg_data['headers']
            self.cfg_original_labels = np.array(cfg_data['labels'])
            self.cfg_masked_labels = np.array(cfg_data['masked_labels'])
            self.cfg_mask_indices = set(cfg_data['mask_indices'])

        else:
            # Legacy JSON format support
            with open(self.cfg_data_path, 'r') as f:
                cfg_data = json.load(f)

            self.cfg_sequences = cfg_data['sequences']
            self.cfg_headers = cfg_data.get('headers', [''] * len(cfg_data['sequences']))
            self.cfg_original_labels = np.array(cfg_data['labels'])

            # For CFG training we need masked labels:
            # randomly mask 10% of the labels
            self.cfg_masked_labels = self.cfg_original_labels.copy()
            mask_probability = 0.1
            mask_indices = np.random.choice(
                len(self.cfg_original_labels),
                size=int(len(self.cfg_original_labels) * mask_probability),
                replace=False
            )
            self.cfg_masked_labels[mask_indices] = 2  # 2 = mask/unknown
            self.cfg_mask_indices = set(mask_indices)

        print(f"Loaded {len(self.cfg_sequences)} CFG sequences")
        print(f"Label distribution: {np.bincount(self.cfg_original_labels)}")
        print(f"Masked {len(self.cfg_mask_indices)} labels for CFG training")

    def _align_data(self):
        """Align AMP embeddings with CFG data based on sequence matching."""
        print("Aligning AMP embeddings with CFG data...")

        # For now, use a simple approach: take the first N samples,
        # where N is the minimum of the embedding count and the CFG data count
        min_samples = min(len(self.embeddings), len(self.cfg_sequences))

        self.aligned_indices = list(range(min_samples))

        # Align labels
        if self.use_masked_labels:
            self.cfg_labels = self.cfg_masked_labels[:min_samples]
        else:
            self.cfg_labels = self.cfg_original_labels[:min_samples]

        # Align embeddings
        self.aligned_embeddings = self.embeddings[:min_samples]

        print(f"Aligned {min_samples} samples")

    def __len__(self) -> int:
        return len(self.aligned_indices)

    def __getitem__(self, idx: int) -> Dict[str, torch.Tensor]:
        """Get a single sample with embedding and CFG label."""
        # Embeddings are already on CPU
        embedding = self.aligned_embeddings[idx]
        label = self.cfg_labels[idx]
        original_label = self.cfg_original_labels[idx]
        is_masked = idx in self.cfg_mask_indices

        return {
            'embedding': embedding,
            'label': torch.tensor(label, dtype=torch.long),
            'original_label': torch.tensor(original_label, dtype=torch.long),
            'is_masked': torch.tensor(is_masked, dtype=torch.bool),
            'index': torch.tensor(idx, dtype=torch.long)
        }

    def get_embedding_stats(self) -> Dict:
        """Get statistics about the embeddings."""
        return {
            'shape': self.aligned_embeddings.shape,
            'mean': self.aligned_embeddings.mean().item(),
            'std': self.aligned_embeddings.std().item(),
            'min': self.aligned_embeddings.min().item(),
            'max': self.aligned_embeddings.max().item()
        }

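# Minimal usage sketch for CFGFlowDataset (illustrative only: both paths are
# assumptions about the local data layout, not files included in this commit).
def _example_flow_dataset() -> "CFGFlowDataset":
    dataset = CFGFlowDataset(
        embeddings_path="embeddings/",                # assumed directory of *.pt embedding files
        cfg_data_path="data/amps_and_uniprot.fasta",  # assumed FASTA with >AP / >sp headers
        use_masked_labels=True,
    )
    print(dataset.get_embedding_stats())
    return dataset
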
def create_cfg_dataloader(dataset: Dataset,
                          batch_size: int = 32,
                          shuffle: bool = True,
                          num_workers: int = 4) -> DataLoader:
    """Create a DataLoader for CFG training."""

    def collate_fn(batch):
        """Custom collate function for CFG data."""
        # Separate the different types of data
        embeddings = torch.stack([item['embedding'] for item in batch])
        labels = torch.stack([item['label'] for item in batch])
        original_labels = torch.stack([item['original_label'] for item in batch])
        is_masked = torch.stack([item['is_masked'] for item in batch])
        indices = torch.stack([item['index'] for item in batch])

        return {
            'embeddings': embeddings,
            'labels': labels,
            'original_labels': original_labels,
            'is_masked': is_masked,
            'indices': indices
        }

    return DataLoader(
        dataset,
        batch_size=batch_size,
        shuffle=shuffle,
        num_workers=num_workers,
        collate_fn=collate_fn,
        pin_memory=True
    )

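# Minimal sketch of wrapping a dataset with the CFG dataloader and inspecting a
# single batch (illustrative only; num_workers=0 keeps the example runnable
# anywhere, and the printed shape depends on how the embeddings were saved).
def _example_cfg_batch(dataset: Dataset) -> Dict[str, torch.Tensor]:
    loader = create_cfg_dataloader(dataset, batch_size=16, shuffle=True, num_workers=0)
    batch = next(iter(loader))
    print(batch['embeddings'].shape)   # e.g. (batch, seq_len, hidden_dim)
    print(batch['labels'][:8])         # 0 = AMP, 1 = Non-AMP, 2 = mask/unconditional
    print(int(batch['is_masked'].sum().item()), "masked labels in this batch")
    return batch
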
def test_cfg_dataset():
    """Test function to verify the CFG dataset works correctly."""
    print("Testing CFG Dataset...")

    # Test with a small subset
    test_data = {
        'sequences': ['MKTVRQERLKSIVRILERSKEPVSGAQLAEELSVSRQVIVQDIAYLRSLGYNIVATPRGYVLAGG',
                      'MKLLIVTFCLTFAAL',
                      'MKLLIVTFCLTFAALMKLLIVTFCLTFAAL'],
        'original_labels': [0, 1, 0],  # amp, non_amp, amp
        'masked_labels': [0, 2, 0],    # amp, mask, amp
        'mask_indices': [1]            # Only the second sequence is masked
    }

    # Save test data
    test_path = 'test_cfg_data.json'
    with open(test_path, 'w') as f:
        json.dump(test_data, f)

    # Test the dataset
    dataset = CFGUniProtDataset(test_path, use_masked_labels=True)

    print(f"Dataset length: {len(dataset)}")
    for i in range(len(dataset)):
        sample = dataset[i]
        print(f"Sample {i}:")
        print(f"  Sequence: {sample['sequence'][:20]}...")
        print(f"  Label: {sample['label'].item()}")
        print(f"  Original label: {sample['original_label'].item()}")
        print(f"  Is masked: {sample['is_masked'].item()}")

    # Clean up
    os.remove(test_path)
    print("Test completed successfully!")

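# Sketch (not part of this module's public API) of how the label convention is
# typically consumed in classifier-free guidance training: label 2 serves as the
# unconditional class, and conditioning labels can additionally be dropped to it
# on the fly. At sampling time the model would be evaluated with both the real
# label and the mask label, and the two outputs blended with a guidance weight.
def _cfg_label_dropout(labels: torch.Tensor,
                       drop_prob: float = 0.1,
                       uncond_label: int = 2) -> torch.Tensor:
    """Randomly replace a fraction of conditioning labels with the mask class."""
    labels = labels.clone()
    drop = torch.rand(labels.shape[0], device=labels.device) < drop_prob
    labels[drop] = uncond_label
    return labels
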
if __name__ == "__main__":
    test_cfg_dataset()