import torch
import torch.nn as nn
import torch.optim as optim
from torch.utils.data import Dataset, DataLoader
from torch.optim.lr_scheduler import LinearLR, CosineAnnealingLR, SequentialLR
import json
import numpy as np
from tqdm import tqdm

# ---------------- Hyperparameters ----------------
ESM_DIM = 1280            # ESM-2 hidden dim (esm2_t33_650M_UR50D)
COMP_RATIO = 16           # compression factor
COMP_DIM = ESM_DIM // COMP_RATIO
MAX_SEQ_LEN = 50          # actual sequence length from final_sequence_encoder.py
BATCH_SIZE = 32
EPOCHS = 30
BASE_LR = 1e-3            # initial learning rate
LR_MIN = 8e-5             # minimum learning rate for cosine schedule
WARMUP_STEPS = 10_000
DEPTH = 4                 # total transformer layers (2 pre-pool, 2 post-pool)
HEADS = 8                 # attention heads
DIM_FF = ESM_DIM * 4
POOLING = True            # enforce ProtFlow hourglass pooling


# ---------------- Dataset for Pre-computed Embeddings ----------------
class PrecomputedEmbeddingDataset(Dataset):
    def __init__(self, embeddings_path):
        """
        Load pre-computed embeddings from the final_sequence_encoder.py output.

        Args:
            embeddings_path: Path to the directory containing individual .pt embedding files
        """
        print(f"Loading pre-computed embeddings from {embeddings_path}...")

        # Load all individual embedding files
        import glob
        import os
        embedding_files = glob.glob(os.path.join(embeddings_path, "*.pt"))
        embedding_files = [f for f in embedding_files
                           if not f.endswith('metadata.json') and not f.endswith('sequence_ids.json')]
        print(f"Found {len(embedding_files)} embedding files")

        # Load and stack all embeddings
        embeddings_list = []
        for file_path in embedding_files:
            try:
                embedding = torch.load(file_path, map_location='cpu')
                if embedding.dim() == 2:  # (seq_len, hidden_dim)
                    embeddings_list.append(embedding)
                else:
                    print(f"Warning: Skipping {file_path} - unexpected shape {embedding.shape}")
            except Exception as e:
                print(f"Warning: Could not load {file_path}: {e}")

        if not embeddings_list:
            raise ValueError("No valid embeddings found!")

        self.embeddings = torch.stack(embeddings_list)
        print(f"Loaded {len(self.embeddings)} embeddings with shape {self.embeddings.shape}")

        # Sanity-check the stacked tensor: (num_sequences, seq_len, hidden_dim)
        if len(self.embeddings.shape) != 3:
            raise ValueError(f"Expected 3D tensor, got shape {self.embeddings.shape}")
        if self.embeddings.shape[1] != MAX_SEQ_LEN:
            print(f"Warning: Expected sequence length {MAX_SEQ_LEN}, got {self.embeddings.shape[1]}")
        if self.embeddings.shape[2] != ESM_DIM:
            print(f"Warning: Expected embedding dim {ESM_DIM}, got {self.embeddings.shape[2]}")

    def __len__(self):
        return len(self.embeddings)

    def __getitem__(self, idx):
        return self.embeddings[idx]


# ---------------- Compressor ----------------
class Compressor(nn.Module):
    def __init__(self, in_dim=ESM_DIM, out_dim=COMP_DIM):
        super().__init__()
        self.norm = nn.LayerNorm(in_dim)
        layer = lambda: nn.TransformerEncoderLayer(
            d_model=in_dim, nhead=HEADS, dim_feedforward=DIM_FF, batch_first=True)
        # two layers before the pooling step, two after
        self.pre_tr = nn.TransformerEncoder(layer(), num_layers=DEPTH // 2)
        self.post_tr = nn.TransformerEncoder(layer(), num_layers=DEPTH // 2)
        self.proj = nn.Sequential(
            nn.LayerNorm(in_dim),
            nn.Linear(in_dim, out_dim),
            nn.Tanh()
        )
        self.pooling = POOLING

    def forward(self, x, stats=None):
        if stats is not None:
            m, s, mn, mx = stats['mean'], stats['std'], stats['min'], stats['max']
            # Move stats to the same device as x
            m = m.to(x.device)
            s = s.to(x.device)
            mn = mn.to(x.device)
            mx = mx.to(x.device)
            # Clamped z-score followed by min-max scaling into [0, 1]
            x = torch.clamp((x - m) / s, -4, 4)
            x = torch.clamp((x - mn) / (mx - mn + 1e-8), 0, 1)
        x = self.norm(x)
        x = self.pre_tr(x)  # [B, L, D]
        if self.pooling:
            B, L, D = x.shape
            if L % 2:  # drop the last position so the length is divisible by 2
                x = x[:, :-1, :]
            x = x.view(B, L // 2, 2, D).mean(2)  # halve sequence length
        x = self.post_tr(x)  # [B, L', D]
        return self.proj(x)  # [B, L', COMP_DIM]
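
# A minimal, optional shape check for the hourglass Compressor above (a sketch,
# not part of the original pipeline): with the defaults ESM_DIM=1280,
# COMP_RATIO=16, MAX_SEQ_LEN=50 and POOLING=True, a [B, 50, 1280] batch of
# random stand-in embeddings should compress to [B, 25, 80].
def _sanity_check_compressor_shapes():
    comp = Compressor()
    x = torch.randn(2, MAX_SEQ_LEN, ESM_DIM)
    with torch.no_grad():
        z = comp(x)  # stats omitted; normalization is skipped in that case
    assert z.shape == (2, MAX_SEQ_LEN // 2, COMP_DIM), z.shape
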

# ---------------- Decompressor ----------------
class Decompressor(nn.Module):
    def __init__(self, in_dim=COMP_DIM, out_dim=ESM_DIM):
        super().__init__()
        self.proj = nn.Sequential(
            nn.LayerNorm(in_dim),
            nn.Linear(in_dim, out_dim)
        )
        layer = lambda: nn.TransformerEncoderLayer(
            d_model=out_dim, nhead=HEADS, dim_feedforward=DIM_FF, batch_first=True)
        self.decoder = nn.TransformerEncoder(layer(), num_layers=DEPTH // 2)
        self.pooling = POOLING

    def forward(self, z):
        x = self.proj(z)  # [B, L', D]
        if self.pooling:
            x = x.repeat_interleave(2, dim=1)  # unpool back to full length
        return self.decoder(x)  # [B, L, out_dim]


# ---------------- Training Loop ----------------
def train_with_precomputed_embeddings(embeddings_path, device='cuda'):
    """
    Train the compressor/decompressor pair on pre-computed embeddings
    produced by final_sequence_encoder.py.
    """
    # Load dataset
    ds = PrecomputedEmbeddingDataset(embeddings_path)

    # Compute per-dimension normalization statistics over the whole dataset
    print("Computing normalization statistics...")
    flat = ds.embeddings.view(-1, ESM_DIM)
    stats = {
        'mean': flat.mean(0),
        'std': flat.std(0) + 1e-8,
        'min': torch.clamp((flat - flat.mean(0)) / (flat.std(0) + 1e-8), -4, 4).min(0)[0],
        'max': torch.clamp((flat - flat.mean(0)) / (flat.std(0) + 1e-8), -4, 4).max(0)[0],
    }

    # Save statistics for later use (needed again at test/inference time)
    torch.save(stats, 'normalization_stats.pt')
    print("Saved normalization statistics to normalization_stats.pt")

    # Create data loader
    dl = DataLoader(ds, batch_size=BATCH_SIZE, shuffle=True)

    # Initialize models
    comp = Compressor().to(device)
    decomp = Decompressor().to(device)

    # Initialize optimizer
    opt = optim.AdamW(list(comp.parameters()) + list(decomp.parameters()), lr=BASE_LR)

    # LR scheduling: linear warmup -> cosine annealing, stepped once per batch
    warmup_sched = LinearLR(opt, start_factor=1e-8, end_factor=1.0, total_iters=WARMUP_STEPS)
    cosine_sched = CosineAnnealingLR(opt, T_max=EPOCHS * len(dl), eta_min=LR_MIN)
    sched = SequentialLR(opt, [warmup_sched, cosine_sched], milestones=[WARMUP_STEPS])

    print(f"Starting training for {EPOCHS} epochs...")
    print(f"Device: {device}")
    print(f"Batch size: {BATCH_SIZE}")
    print(f"Total batches per epoch: {len(dl)}")

    # Training loop
    for epoch in range(1, EPOCHS + 1):
        total_loss = 0
        comp.train()
        decomp.train()

        for batch_idx, x in enumerate(tqdm(dl, desc=f"Epoch {epoch}/{EPOCHS}")):
            x = x.to(device)
            z = comp(x, stats)
            xr = decomp(z)
            loss = (x - xr).pow(2).mean()  # reconstruction MSE

            opt.zero_grad()
            loss.backward()
            opt.step()
            sched.step()

            total_loss += loss.item()

            # Print progress every 100 batches
            if batch_idx % 100 == 0:
                print(f"  Batch {batch_idx}/{len(dl)} - Loss: {loss.item():.6f}")

        avg_loss = total_loss / len(dl)
        print(f"Epoch {epoch}/{EPOCHS} - Average MSE: {avg_loss:.6f}")

        # Save checkpoint every 5 epochs
        if epoch % 5 == 0:
            torch.save({
                'epoch': epoch,
                'compressor_state_dict': comp.state_dict(),
                'decompressor_state_dict': decomp.state_dict(),
                'optimizer_state_dict': opt.state_dict(),
                'loss': avg_loss,
            }, f'checkpoint_epoch_{epoch}.pth')

    # Save final models
    torch.save(comp.state_dict(), 'compressor_final.pth')
    torch.save(decomp.state_dict(), 'decompressor_final.pth')
    print("Training completed! Models saved as compressor_final.pth and decompressor_final.pth")
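
# Optional sketch for inspecting the warmup -> cosine schedule used above,
# without loading any data (this helper is an illustration, not part of the
# original pipeline). It is mainly useful for checking that WARMUP_STEPS
# (10,000) is smaller than the total number of optimizer steps,
# i.e. EPOCHS * len(dl); otherwise the cosine phase is never reached.
def preview_lr_schedule(total_steps, warmup_steps=WARMUP_STEPS):
    dummy_opt = optim.AdamW([nn.Parameter(torch.zeros(1))], lr=BASE_LR)
    warmup = LinearLR(dummy_opt, start_factor=1e-8, end_factor=1.0, total_iters=warmup_steps)
    cosine = CosineAnnealingLR(dummy_opt, T_max=total_steps, eta_min=LR_MIN)
    sched = SequentialLR(dummy_opt, [warmup, cosine], milestones=[warmup_steps])
    lrs = []
    for _ in range(total_steps):
        dummy_opt.step()  # step the optimizer before the scheduler, as in training
        sched.step()
        lrs.append(sched.get_last_lr()[0])
    return lrs
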

# ---------------- Utility Functions ----------------
def load_and_test_models(compressor_path, decompressor_path, embeddings_path, device='cuda'):
    """
    Load trained models and test reconstruction quality.
    """
    print("Loading trained models...")
    comp = Compressor().to(device)
    decomp = Decompressor().to(device)
    comp.load_state_dict(torch.load(compressor_path, map_location=device))
    decomp.load_state_dict(torch.load(decompressor_path, map_location=device))
    comp.eval()
    decomp.eval()

    # Load test data
    ds = PrecomputedEmbeddingDataset(embeddings_path)
    test_loader = DataLoader(ds, batch_size=16, shuffle=False)

    # Load the normalization stats saved during training
    stats = torch.load('normalization_stats.pt')

    print("Testing reconstruction quality...")
    total_mse = 0
    total_samples = 0

    with torch.no_grad():
        for batch in tqdm(test_loader, desc="Testing"):
            x = batch.to(device)
            z = comp(x, stats)
            xr = decomp(z)
            mse = (x - xr).pow(2).mean()
            total_mse += mse.item() * len(x)
            total_samples += len(x)

    avg_mse = total_mse / total_samples
    print(f"Average reconstruction MSE: {avg_mse:.6f}")
    return avg_mse


# ---------------- Entrypoint ----------------
if __name__ == '__main__':
    import argparse

    parser = argparse.ArgumentParser(description='Train protein compressor with pre-computed embeddings')
    parser.add_argument('--embeddings', type=str,
                        default='/data2/edwardsun/flow_project/compressor_dataset/peptide_embeddings.pt',
                        help='Path to pre-computed embeddings from final_sequence_encoder.py')
    parser.add_argument('--device', type=str, default='cuda',
                        help='Device to use (cuda/cpu)')
    parser.add_argument('--test', action='store_true',
                        help='Test existing models instead of training')
    args = parser.parse_args()

    device = torch.device(args.device if torch.cuda.is_available() else 'cpu')
    print(f"Using device: {device}")

    if args.test:
        # Test existing models
        load_and_test_models('compressor_final.pth', 'decompressor_final.pth', args.embeddings, device)
    else:
        # Train new models
        train_with_precomputed_embeddings(args.embeddings, device)
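
# Example invocations (the script name is illustrative; --embeddings should
# point at the output of final_sequence_encoder.py):
#   python train_compressor.py --embeddings /path/to/embedding_dir --device cuda
#   python train_compressor.py --embeddings /path/to/embedding_dir --test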