import torch import torch.nn as nn import torch.nn.functional as F import torch.optim as optim from torch.utils.data import DataLoader from torch.optim.lr_scheduler import CosineAnnealingLR, LinearLR, SequentialLR import numpy as np from tqdm import tqdm import json import os import argparse import time from torch.cuda.amp import autocast, GradScaler import wandb # For logging (optional) # Import your existing components from compressor_with_embeddings import Compressor, Decompressor, PrecomputedEmbeddingDataset from final_flow_model import AMPFlowMatcherCFGConcat, SinusoidalTimeEmbedding from cfg_dataset import CFGFlowDataset, create_cfg_dataloader # ---------------- Optimized Configuration for H100 ---------------- ESM_DIM = 1280 # ESM-2 hidden dim (esm2_t33_650M_UR50D) COMP_RATIO = 16 # compression factor COMP_DIM = ESM_DIM // COMP_RATIO MAX_SEQ_LEN = 50 # Actual sequence length from final_sequence_encoder.py # OPTIMIZED H100 hyperparameters - HIGH THROUGHPUT + STABLE TRAINING BATCH_SIZE = 512 # PUSH H100 TO LIMITS - using ~70GB memory EPOCHS = 2000 # Slightly more epochs with safer LR for same 5-6 hour target BASE_LR = 8e-4 # SAFE but effective LR - 2x original, not 4x LR_MIN = 4e-4 # Conservative minimum learning rate WARMUP_STEPS = 4000 # Gentler warmup to avoid explosion GPU_ID = 0 # Use GPU 0 # Training optimizations USE_MIXED_PRECISION = True # BF16 for H100 GRADIENT_CLIP_NORM = 0.5 # TIGHTER gradient clipping for flow matching stability WEIGHT_DECAY = 0.01 # Weight decay for regularization VALIDATION_INTERVAL = 5000 # Validate every 5K steps (more frequent) CHECKPOINT_INTERVAL = 300 # Save checkpoint every 300 epochs (more frequent) NUM_WORKERS = 32 # MAXIMIZED data loading workers for H100 # CFG training parameters CFG_DROPOUT_RATE = 0.15 # 15% of batches as unconditional for CFG class AMPFlowTrainerSingleGPUFullData: """ Optimized Single GPU training pipeline for AMP generation using ProtFlow methodology. Uses ALL available data with H100-optimized settings for overnight training. """ def __init__(self, embeddings_path, cfg_data_path, use_wandb=False): self.device = torch.device(f'cuda:{GPU_ID}') self.embeddings_path = embeddings_path self.cfg_data_path = cfg_data_path self.use_wandb = use_wandb # Enable H100 optimizations torch.backends.cuda.matmul.allow_tf32 = True torch.backends.cudnn.allow_tf32 = True print(f"Using GPU {GPU_ID} for optimized H100 training") print(f"Mixed precision: {USE_MIXED_PRECISION}") print(f"Batch size: {BATCH_SIZE}") print(f"Target epochs: {EPOCHS}") print(f"Learning rate: {BASE_LR} -> {LR_MIN}") # Initialize mixed precision training if USE_MIXED_PRECISION: self.scaler = GradScaler() print("✓ Mixed precision training enabled (BF16)") # Initialize wandb if requested if self.use_wandb: wandb.init( project="amp-flow-training", config={ "batch_size": BATCH_SIZE, "epochs": EPOCHS, "base_lr": BASE_LR, "lr_min": LR_MIN, "warmup_steps": WARMUP_STEPS, "mixed_precision": USE_MIXED_PRECISION, "gradient_clip": GRADIENT_CLIP_NORM, "weight_decay": WEIGHT_DECAY } ) print(f"Loading ALL AMP embeddings from {embeddings_path}...") # Load ALL embeddings (use the combined file instead of individual files) self._load_all_embeddings() # Compute normalization statistics print("Computing preprocessing statistics...") self._compute_preprocessing_stats() # Initialize models self._initialize_models() # Initialize datasets and dataloaders self._initialize_data() # Initialize optimizer and scheduler self._initialize_optimizer() print("✓ Optimized Single GPU training setup complete with FULL DATA!") def _load_all_embeddings(self): """Load ALL peptide embeddings from the combined file.""" # Try to load the combined embeddings file first combined_path = os.path.join(self.embeddings_path, "all_peptide_embeddings.pt") if os.path.exists(combined_path): print(f"Loading combined embeddings from {combined_path}...") self.embeddings = torch.load(combined_path, map_location=self.device) print(f"✓ Loaded ALL embeddings: {self.embeddings.shape}") else: print("Combined embeddings file not found, loading individual files...") # Fallback to individual files import glob embedding_files = glob.glob(os.path.join(self.embeddings_path, "*.pt")) embedding_files = [f for f in embedding_files if not f.endswith('metadata.json') and not f.endswith('sequence_ids.json') and not f.endswith('all_peptide_embeddings.pt')] print(f"Found {len(embedding_files)} individual embedding files") # Load and stack all embeddings embeddings_list = [] for file_path in embedding_files: try: embedding = torch.load(file_path) if embedding.dim() == 2: # (seq_len, hidden_dim) embeddings_list.append(embedding) else: print(f"Warning: Skipping {file_path} - unexpected shape {embedding.shape}") except Exception as e: print(f"Warning: Could not load {file_path}: {e}") if not embeddings_list: raise ValueError("No valid embeddings found!") self.embeddings = torch.stack(embeddings_list) print(f"Loaded {len(self.embeddings)} embeddings from individual files") def _compute_preprocessing_stats(self): """Compute normalization statistics for embeddings.""" # Flatten all embeddings flat_embeddings = self.embeddings.reshape(-1, ESM_DIM) # Compute statistics mean = flat_embeddings.mean(dim=0) std = flat_embeddings.std(dim=0) min_val = flat_embeddings.min() max_val = flat_embeddings.max() self.stats = { 'mean': mean, 'std': std, 'min': min_val, 'max': max_val } # Save statistics torch.save(self.stats, 'normalization_stats.pt') print(f"✓ Statistics computed and saved:") print(f" Total embeddings: {len(self.embeddings):,}") print(f" Mean: {mean.mean():.4f} ± {mean.std():.4f}") print(f" Std: {std.mean():.4f} ± {std.std():.4f}") print(f" Range: [{min_val:.4f}, {max_val:.4f}]") def _initialize_models(self): """Initialize compressor, decompressor, and flow model.""" print("Initializing models...") # Load pre-trained compressor and decompressor self.compressor = Compressor().to(self.device) self.decompressor = Decompressor().to(self.device) self.compressor.load_state_dict(torch.load('final_compressor_model.pth', map_location=self.device)) self.decompressor.load_state_dict(torch.load('final_decompressor_model.pth', map_location=self.device)) # Initialize flow model with CFG self.flow_model = AMPFlowMatcherCFGConcat( hidden_dim=480, compressed_dim=COMP_DIM, n_layers=12, n_heads=16, dim_ff=3072, max_seq_len=25, # MAX_SEQ_LEN // 2 due to pooling use_cfg=True ).to(self.device) # Compile model for PyTorch 2.x speedup (if available) try: self.flow_model = torch.compile(self.flow_model, mode="reduce-overhead") print("✓ Model compiled with torch.compile for speedup") except Exception as e: print(f"⚠️ Model compilation failed: {e}") # Set models to training mode self.compressor.train() self.decompressor.train() self.flow_model.train() print(f"✓ Models initialized:") print(f" Compressor parameters: {sum(p.numel() for p in self.compressor.parameters()):,}") print(f" Decompressor parameters: {sum(p.numel() for p in self.decompressor.parameters()):,}") print(f" Flow model parameters: {sum(p.numel() for p in self.flow_model.parameters()):,}") def _initialize_data(self): """Initialize datasets and dataloaders with FULL data.""" print("Initializing datasets with FULL data...") # Create CFG dataset with FULL UniProt data self.cfg_dataset = CFGFlowDataset( embeddings_path=self.embeddings_path, cfg_data_path=self.cfg_data_path, use_masked_labels=True, max_seq_len=MAX_SEQ_LEN, device=self.device ) # Create dataloader with optimized settings self.dataloader = create_cfg_dataloader( self.cfg_dataset, batch_size=BATCH_SIZE, shuffle=True, num_workers=NUM_WORKERS ) # Calculate total steps and validation intervals self.total_steps = len(self.dataloader) * EPOCHS self.validation_steps = VALIDATION_INTERVAL print(f"✓ Dataset initialized with FULL data:") print(f" Total samples: {len(self.cfg_dataset):,}") print(f" Batch size: {BATCH_SIZE}") print(f" Batches per epoch: {len(self.dataloader):,}") print(f" Total training steps: {self.total_steps:,}") print(f" Validation every: {self.validation_steps:,} steps") def _initialize_optimizer(self): """Initialize optimizer and learning rate scheduler.""" print("Initializing optimizer and scheduler...") # Optimizer for flow model only (compressor/decompressor are frozen) self.optimizer = optim.AdamW( self.flow_model.parameters(), lr=BASE_LR, weight_decay=WEIGHT_DECAY, betas=(0.9, 0.98), # Optimized betas for flow matching eps=1e-6 # Lower epsilon for numerical stability ) # Learning rate scheduler with proper warmup and cosine annealing warmup_scheduler = LinearLR( self.optimizer, start_factor=0.1, end_factor=1.0, total_iters=WARMUP_STEPS ) main_scheduler = CosineAnnealingLR( self.optimizer, T_max=self.total_steps - WARMUP_STEPS, eta_min=LR_MIN ) self.scheduler = SequentialLR( self.optimizer, schedulers=[warmup_scheduler, main_scheduler], milestones=[WARMUP_STEPS] ) print(f"✓ Optimizer initialized:") print(f" Base LR: {BASE_LR}") print(f" Min LR: {LR_MIN}") print(f" Warmup steps: {WARMUP_STEPS}") print(f" Weight decay: {WEIGHT_DECAY}") print(f" Gradient clip norm: {GRADIENT_CLIP_NORM}") def _preprocess_batch(self, batch): """Preprocess a batch of data for training.""" # Extract data embeddings = batch['embeddings'].to(self.device) # (B, L, ESM_DIM) labels = batch['labels'].to(self.device) # (B,) # Normalize embeddings m, s = self.stats['mean'].to(self.device), self.stats['std'].to(self.device) mn, mx = self.stats['min'].to(self.device), self.stats['max'].to(self.device) embeddings = (embeddings - m) / (s + 1e-8) embeddings = (embeddings - mn) / (mx - mn + 1e-8) # Compress embeddings with torch.no_grad(): compressed = self.compressor(embeddings) # (B, L, COMP_DIM) return compressed, labels def _compute_validation_metrics(self): """Compute validation metrics on a subset of data.""" self.flow_model.eval() val_losses = [] # Use a subset of data for validation val_samples = min(1000, len(self.cfg_dataset)) val_indices = torch.randperm(len(self.cfg_dataset))[:val_samples] with torch.no_grad(): for i in range(0, val_samples, BATCH_SIZE): batch_indices = val_indices[i:i+BATCH_SIZE] batch_data = [self.cfg_dataset[idx] for idx in batch_indices] # Collate batch embeddings = torch.stack([item['embedding'] for item in batch_data]) labels = torch.stack([item['label'] for item in batch_data]) # Preprocess compressed, labels = self._preprocess_batch({ 'embeddings': embeddings, 'labels': labels }) B, L, D = compressed.shape # Sample random time t = torch.rand(B, device=self.device) # Sample random noise eps = torch.randn_like(compressed) # Compute target xt = (1 - t.unsqueeze(-1).unsqueeze(-1)) * compressed + t.unsqueeze(-1).unsqueeze(-1) * eps # Predict vector field vt_pred = self.flow_model(xt, t, labels=labels) # Target vector field vt_target = eps - compressed # Compute loss loss = F.mse_loss(vt_pred, vt_target) val_losses.append(loss.item()) self.flow_model.train() return np.mean(val_losses) def train_flow_matching(self): """Train the flow matching model with FULL data and optimizations.""" print(f"🚀 Starting Optimized Single GPU Flow Matching Training with FULL DATA") print(f"GPU: {GPU_ID}") print(f"Total iterations: {EPOCHS}") print(f"Batch size: {BATCH_SIZE}") print(f"Total samples: {len(self.cfg_dataset):,}") print(f"Mixed precision: {USE_MIXED_PRECISION}") print(f"Estimated time: ~8-10 hours (overnight training with ALL data)") print("=" * 60) # Training loop best_loss = float('inf') losses = [] val_losses = [] global_step = 0 start_time = time.time() for epoch in tqdm(range(EPOCHS), desc="Training Flow Model"): epoch_losses = [] epoch_start_time = time.time() for batch_idx, batch in enumerate(self.dataloader): # Preprocess batch compressed, labels = self._preprocess_batch(batch) B, L, D = compressed.shape # CFG training: randomly mask some labels for unconditional training if torch.rand(1).item() < CFG_DROPOUT_RATE: labels = torch.full_like(labels, fill_value=-1) # Unconditional # Sample random time t = torch.rand(B, device=self.device) # (B,) # Sample random noise eps = torch.randn_like(compressed) # (B, L, D) # Compute target: x_t = (1-t) * x_0 + t * eps xt = (1 - t.unsqueeze(-1).unsqueeze(-1)) * compressed + t.unsqueeze(-1).unsqueeze(-1) * eps # Forward pass with mixed precision if USE_MIXED_PRECISION: with autocast(dtype=torch.bfloat16): vt_pred = self.flow_model(xt, t, labels=labels) # (B, L, D) vt_target = eps - compressed # (B, L, D) loss = F.mse_loss(vt_pred, vt_target) # Backward pass with gradient scaling self.optimizer.zero_grad() self.scaler.scale(loss).backward() # Gradient clipping self.scaler.unscale_(self.optimizer) torch.nn.utils.clip_grad_norm_(self.flow_model.parameters(), max_norm=GRADIENT_CLIP_NORM) self.scaler.step(self.optimizer) self.scaler.update() else: # Standard training vt_pred = self.flow_model(xt, t, labels=labels) # (B, L, D) vt_target = eps - compressed # (B, L, D) loss = F.mse_loss(vt_pred, vt_target) # Backward pass self.optimizer.zero_grad() loss.backward() # Gradient clipping torch.nn.utils.clip_grad_norm_(self.flow_model.parameters(), max_norm=GRADIENT_CLIP_NORM) self.optimizer.step() # Update learning rate self.scheduler.step() epoch_losses.append(loss.item()) global_step += 1 # Logging if batch_idx % 100 == 0: current_lr = self.scheduler.get_last_lr()[0] elapsed_time = time.time() - start_time steps_per_sec = global_step / elapsed_time eta_hours = (self.total_steps - global_step) / steps_per_sec / 3600 print(f"Epoch {epoch:4d} | Step {global_step:6d}/{self.total_steps:6d} | " f"Loss: {loss.item():.6f} | LR: {current_lr:.2e} | " f"Speed: {steps_per_sec:.1f} steps/s | ETA: {eta_hours:.1f}h") # Log to wandb if self.use_wandb: wandb.log({ 'train/loss': loss.item(), 'train/learning_rate': current_lr, 'train/steps_per_sec': steps_per_sec, 'train/global_step': global_step }) # Validation if global_step % self.validation_steps == 0: val_loss = self._compute_validation_metrics() val_losses.append(val_loss) print(f"Validation at step {global_step}: Loss = {val_loss:.6f}") if self.use_wandb: wandb.log({ 'val/loss': val_loss, 'val/global_step': global_step }) # Early stopping check if val_loss < best_loss: best_loss = val_loss self._save_checkpoint(epoch, val_loss, global_step, is_final=False, is_best=True) # Compute epoch statistics avg_loss = np.mean(epoch_losses) losses.append(avg_loss) epoch_time = time.time() - epoch_start_time print(f"Epoch {epoch:4d} | Avg Loss: {avg_loss:.6f} | " f"LR: {self.scheduler.get_last_lr()[0]:.2e} | " f"Time: {epoch_time:.1f}s | Samples: {len(self.cfg_dataset):,}") # Save checkpoint if (epoch + 1) % CHECKPOINT_INTERVAL == 0: self._save_checkpoint(epoch, avg_loss, global_step, is_final=True) # Save final model self._save_checkpoint(EPOCHS - 1, losses[-1], global_step, is_final=True) total_time = time.time() - start_time print("=" * 60) print("🎉 Optimized Training Complete with FULL DATA!") print(f"Best validation loss: {best_loss:.6f}") print(f"Total training time: {total_time/3600:.1f} hours") print(f"Total samples used: {len(self.cfg_dataset):,}") print(f"Final model saved as: amp_flow_model_final_optimized.pth") return losses, val_losses def _save_checkpoint(self, step, loss, global_step, is_final=False, is_best=False): """Save model checkpoint.""" # Create output directory if it doesn't exist output_dir = '/data2/edwardsun/flow_checkpoints' os.makedirs(output_dir, exist_ok=True) if is_best: filename = os.path.join(output_dir, 'amp_flow_model_best_optimized.pth') elif is_final: filename = os.path.join(output_dir, 'amp_flow_model_final_optimized.pth') else: filename = os.path.join(output_dir, f'amp_flow_checkpoint_optimized_step_{step:04d}.pth') checkpoint = { 'step': step, 'global_step': global_step, 'loss': loss, 'flow_model_state_dict': self.flow_model.state_dict(), 'optimizer_state_dict': self.optimizer.state_dict(), 'scheduler_state_dict': self.scheduler.state_dict(), 'stats': self.stats, 'total_samples': len(self.cfg_dataset), 'config': { 'batch_size': BATCH_SIZE, 'epochs': EPOCHS, 'base_lr': BASE_LR, 'lr_min': LR_MIN, 'warmup_steps': WARMUP_STEPS, 'mixed_precision': USE_MIXED_PRECISION, 'gradient_clip': GRADIENT_CLIP_NORM, 'weight_decay': WEIGHT_DECAY } } torch.save(checkpoint, filename) print(f"✓ Checkpoint saved: {filename} (loss: {loss:.6f}, step: {global_step})") def main(): """Main training function.""" global BATCH_SIZE, EPOCHS parser = argparse.ArgumentParser(description='Optimized Single GPU AMP Flow Training with FULL DATA') parser.add_argument('--embeddings', default='/data2/edwardsun/flow_project/peptide_embeddings/', help='Path to peptide embeddings directory') parser.add_argument('--cfg_data', default='/data2/edwardsun/flow_project/test_uniprot_processed/uniprot_processed_data.json', help='Path to FULL CFG data file') parser.add_argument('--use_wandb', action='store_true', help='Use wandb for logging') parser.add_argument('--batch_size', type=int, default=BATCH_SIZE, help='Batch size for training') parser.add_argument('--epochs', type=int, default=EPOCHS, help='Number of training epochs') args = parser.parse_args() # Update global variables if provided if args.batch_size != BATCH_SIZE: BATCH_SIZE = args.batch_size if args.epochs != EPOCHS: EPOCHS = args.epochs print(f"Starting optimized training with batch_size={BATCH_SIZE}, epochs={EPOCHS}") # Initialize trainer trainer = AMPFlowTrainerSingleGPUFullData(args.embeddings, args.cfg_data, args.use_wandb) # Start training losses, val_losses = trainer.train_flow_matching() print("Optimized training completed successfully with FULL DATA!") if __name__ == "__main__": main()