import torch import torch.nn.functional as F import numpy as np from tqdm import tqdm import os from datetime import datetime #d Import torchdiffeq for proper ODE solving try: from torchdiffeq import odeint TORCHDIFFEQ_AVAILABLE = True print("✓ torchdiffeq available for proper ODE solving") except ImportError: TORCHDIFFEQ_AVAILABLE = False print("⚠️ torchdiffeq not available, using manual Euler integration") # Import your components from compressor_with_embeddings import Compressor, Decompressor from final_flow_model import AMPFlowMatcherCFGConcat, AMPProtFlowPipelineCFG class AMPGenerator: """ Generate AMP samples using trained ProtFlow model. """ def __init__(self, model_path, device='cuda'): self.device = device # Load models self._load_models(model_path) # Load preprocessing statistics self.stats = torch.load('normalization_stats.pt', map_location=device) def _load_models(self, model_path): """Load trained models.""" print("Loading trained models...") # Load compressor and decompressor self.compressor = Compressor().to(self.device) self.decompressor = Decompressor().to(self.device) self.compressor.load_state_dict(torch.load('/data2/edwardsun/flow_amp/models/final_compressor_model.pth', map_location=self.device)) self.decompressor.load_state_dict(torch.load('/data2/edwardsun/flow_amp/models/final_decompressor_model.pth', map_location=self.device)) # Load flow matching model with CFG self.flow_model = AMPFlowMatcherCFGConcat( hidden_dim=480, compressed_dim=80, # 1280 // 16 n_layers=12, n_heads=16, dim_ff=3072, max_seq_len=25, use_cfg=True ).to(self.device) checkpoint = torch.load(model_path, map_location=self.device, weights_only=False) # Handle PyTorch compilation wrapper state_dict = checkpoint['flow_model_state_dict'] new_state_dict = {} for key, value in state_dict.items(): # Remove _orig_mod prefix if present if key.startswith('_orig_mod.'): new_key = key[10:] # Remove '_orig_mod.' prefix else: new_key = key new_state_dict[new_key] = value self.flow_model.load_state_dict(new_state_dict) print(f"✓ All models loaded successfully from step {checkpoint['step']}!") print(f" Loss at checkpoint: {checkpoint['loss']:.6f}") # Initialize ODE solving capabilities if TORCHDIFFEQ_AVAILABLE: print("✓ Enhanced with proper ODE solving (torchdiffeq)") else: print("⚠️ Using fallback Euler integration") def _create_ode_func(self, cfg_scale=7.5): """Create ODE function for torchdiffeq integration.""" def ode_func(t, x): """ ODE function: dx/dt = v_theta(x, t) Args: t: scalar time (single float) x: state tensor [B*L*D] (flattened) Returns: dx/dt: derivative [B*L*D] (flattened) """ # Reshape x back to [B, L, D] batch_size, seq_len, dim = self.current_shape x = x.view(batch_size, seq_len, dim) # Create time tensor for batch t_tensor = torch.full((batch_size,), t, device=self.device, dtype=x.dtype) # Compute vector field with CFG if cfg_scale > 0: # With AMP condition amp_labels = torch.full((batch_size,), 0, device=self.device) # 0 = AMP vt_cond = self.flow_model(x, t_tensor, labels=amp_labels) # Without condition (mask) mask_labels = torch.full((batch_size,), 2, device=self.device) # 2 = Mask vt_uncond = self.flow_model(x, t_tensor, labels=mask_labels) # CFG interpolation vt = vt_uncond + cfg_scale * (vt_cond - vt_uncond) else: # No CFG, use mask label mask_labels = torch.full((batch_size,), 2, device=self.device) vt = self.flow_model(x, t_tensor, labels=mask_labels) # Return flattened derivative return vt.view(-1) return ode_func def generate_amps(self, num_samples=100, num_steps=25, batch_size=32, cfg_scale=7.5, ode_method='dopri5', rtol=1e-5, atol=1e-6): """ Generate AMP samples using flow matching with CFG and improved ODE solving. Args: num_samples: Number of AMP samples to generate num_steps: Number of ODE solving steps (25 for good quality, 1 for reflow) batch_size: Batch size for generation cfg_scale: CFG guidance scale (higher = stronger conditioning) ode_method: ODE solver method ('dopri5', 'rk4', 'euler', 'adaptive_heun') rtol: Relative tolerance for adaptive solvers atol: Absolute tolerance for adaptive solvers """ method_str = f"{ode_method} ODE solver" if TORCHDIFFEQ_AVAILABLE and ode_method != 'euler' else "manual Euler integration" print(f"Generating {num_samples} AMP samples with {method_str} (CFG scale: {cfg_scale})...") if TORCHDIFFEQ_AVAILABLE and ode_method != 'euler': print(f" Method: {ode_method}, rtol={rtol}, atol={atol}") self.flow_model.eval() self.compressor.eval() self.decompressor.eval() all_generated = [] with torch.no_grad(): for i in tqdm(range(0, num_samples, batch_size), desc="Generating with improved ODE"): current_batch = min(batch_size, num_samples - i) # Sample random noise (starting point at t=1) eps = torch.randn(current_batch, 25, 80, device=self.device) # [B, L', COMP_DIM] # Choose ODE solving method if TORCHDIFFEQ_AVAILABLE and ode_method != 'euler': # Use proper ODE solver try: # Store shape for ODE function self.current_shape = eps.shape # Create ODE function ode_func = self._create_ode_func(cfg_scale=cfg_scale) # Time span: from t=1 (noise) to t=0 (data) t_span = torch.tensor([1.0, 0.0], device=self.device, dtype=eps.dtype) # Flatten initial condition for torchdiffeq y0 = eps.view(-1) # Solve ODE with proper adaptive solver if ode_method in ['dopri5', 'adaptive_heun']: # Adaptive solvers solution = odeint( ode_func, y0, t_span, method=ode_method, rtol=rtol, atol=atol, options={'max_num_steps': 1000} ) else: # Fixed-step solvers solution = odeint( ode_func, y0, t_span, method=ode_method, options={'step_size': 0.04} # 1/25 for 25 steps ) # Get final solution (at t=0) xt = solution[-1].view(self.current_shape) except Exception as e: print(f"⚠️ ODE solving failed for batch {i//batch_size + 1}: {e}") print("Falling back to Euler method...") # Fall through to Euler method xt = self._generate_with_euler(eps, current_batch, cfg_scale, num_steps) else: # Use manual Euler integration (original method) xt = self._generate_with_euler(eps, current_batch, cfg_scale, num_steps) # Decompress to get embeddings decompressed = self.decompressor(xt) # [B, L, ESM_DIM] # Apply reverse preprocessing m, s, mn, mx = self.stats['mean'], self.stats['std'], self.stats['min'], self.stats['max'] decompressed = decompressed * (mx - mn + 1e-8) + mn decompressed = decompressed * s + m all_generated.append(decompressed.cpu()) # Concatenate all batches generated_embeddings = torch.cat(all_generated, dim=0) print(f"✓ Generated {generated_embeddings.shape[0]} AMP embeddings") print(f" Shape: {generated_embeddings.shape}") print(f" Stats - Mean: {generated_embeddings.mean():.4f}, Std: {generated_embeddings.std():.4f}") return generated_embeddings def _generate_with_euler(self, eps, current_batch, cfg_scale, num_steps): """Fallback Euler integration method (original implementation).""" xt = eps.clone() amp_labels = torch.full((current_batch,), 0, device=self.device) # 0 = AMP mask_labels = torch.full((current_batch,), 2, device=self.device) # 2 = Mask for step in range(num_steps): t = torch.ones(current_batch, device=self.device) * (1.0 - step/num_steps) # CFG: Generate with condition and without condition if cfg_scale > 0: # With AMP condition vt_cond = self.flow_model(xt, t, labels=amp_labels) # Without condition (mask) vt_uncond = self.flow_model(xt, t, labels=mask_labels) # CFG interpolation vt = vt_uncond + cfg_scale * (vt_cond - vt_uncond) else: # No CFG, use mask label vt = self.flow_model(xt, t, labels=mask_labels) # Euler step for backward integration (t: 1 -> 0) dt = -1.0 / num_steps xt = xt + vt * dt return xt def compare_ode_methods(self, num_samples=20, cfg_scale=7.5): """ Compare different ODE solving methods for quality assessment. """ if not TORCHDIFFEQ_AVAILABLE: print("⚠️ torchdiffeq not available, cannot compare ODE methods") return self.generate_amps(num_samples=num_samples, cfg_scale=cfg_scale) methods = ['euler', 'rk4', 'dopri5', 'adaptive_heun'] results = {} print("🔬 Comparing ODE solving methods...") for method in methods: print(f"\n--- Testing {method} ---") try: start_time = torch.cuda.Event(enable_timing=True) end_time = torch.cuda.Event(enable_timing=True) start_time.record() embeddings = self.generate_amps( num_samples=num_samples, batch_size=10, cfg_scale=cfg_scale, ode_method=method ) end_time.record() torch.cuda.synchronize() elapsed_time = start_time.elapsed_time(end_time) / 1000.0 # Convert to seconds results[method] = { 'embeddings': embeddings, 'time': elapsed_time, 'mean': embeddings.mean().item(), 'std': embeddings.std().item(), 'success': True } print(f"✓ {method}: {elapsed_time:.2f}s, mean={embeddings.mean():.4f}, std={embeddings.std():.4f}") except Exception as e: print(f"❌ {method} failed: {e}") results[method] = {'success': False, 'error': str(e)} return results def generate_with_reflow(self, num_samples=100): """ Generate AMP samples using 1-step reflow (if you have reflow model). """ print(f"Generating {num_samples} AMP samples with 1-step reflow...") # This would use the reflow implementation # For now, just use 1-step generation return self.generate_amps(num_samples=num_samples, num_steps=1, batch_size=32) def main(): """Main generation function.""" print("=== AMP Generation Pipeline with CFG ===") # Use the best model from training (lowest validation loss: 0.017183) model_path = '/data2/edwardsun/flow_checkpoints/amp_flow_model_best_optimized.pth' # Check if checkpoint exists try: checkpoint = torch.load(model_path, map_location='cpu', weights_only=False) print(f"✓ Found best model at step {checkpoint['step']} with loss {checkpoint['loss']:.6f}") print(f" Global step: {checkpoint['global_step']}") print(f" Total samples: {checkpoint['total_samples']:,}") except: print(f"❌ Best model not found: {model_path}") print("Please train the flow matching model first using amp_flow_training.py") return # Initialize generator generator = AMPGenerator(model_path, device='cuda') # Test ODE methods comparison if available if TORCHDIFFEQ_AVAILABLE: print("\n🔬 Comparing ODE solving methods...") comparison_results = generator.compare_ode_methods(num_samples=10, cfg_scale=7.5) # Use best method for generation best_method = 'dopri5' # Recommended method print(f"\n🚀 Using {best_method} for main generation...") else: best_method = 'euler' print("\n⚠️ Using fallback Euler integration...") # Generate samples with different CFG scales using improved ODE solving print("\n1. Generating with CFG scale 0.0 (no conditioning)...") samples_no_cfg = generator.generate_amps(num_samples=20, num_steps=25, cfg_scale=0.0, ode_method=best_method) print("\n2. Generating with CFG scale 3.0 (weak conditioning)...") samples_weak_cfg = generator.generate_amps(num_samples=20, num_steps=25, cfg_scale=3.0, ode_method=best_method) print("\n3. Generating with CFG scale 7.5 (strong conditioning)...") samples_strong_cfg = generator.generate_amps(num_samples=20, num_steps=25, cfg_scale=7.5, ode_method=best_method) print("\n4. Generating with CFG scale 15.0 (very strong conditioning)...") samples_very_strong_cfg = generator.generate_amps(num_samples=20, num_steps=25, cfg_scale=15.0, ode_method=best_method) # Create output directory if it doesn't exist output_dir = '/data2/edwardsun/generated_samples' os.makedirs(output_dir, exist_ok=True) # Get today's date for filename today = datetime.now().strftime('%Y%m%d') # Save generated samples with date torch.save(samples_no_cfg, os.path.join(output_dir, f'generated_amps_best_model_no_cfg_{today}.pt')) torch.save(samples_weak_cfg, os.path.join(output_dir, f'generated_amps_best_model_weak_cfg_{today}.pt')) torch.save(samples_strong_cfg, os.path.join(output_dir, f'generated_amps_best_model_strong_cfg_{today}.pt')) torch.save(samples_very_strong_cfg, os.path.join(output_dir, f'generated_amps_best_model_very_strong_cfg_{today}.pt')) print("\n✓ Generation complete!") print(f"Generated samples saved (Date: {today}):") print(f" - generated_amps_best_model_no_cfg_{today}.pt (no conditioning)") print(f" - generated_amps_best_model_weak_cfg_{today}.pt (weak CFG)") print(f" - generated_amps_best_model_strong_cfg_{today}.pt (strong CFG)") print(f" - generated_amps_best_model_very_strong_cfg_{today}.pt (very strong CFG)") print("\nCFG Analysis:") print(" - CFG scale 0.0: No conditioning, generates diverse sequences") print(" - CFG scale 3.0: Weak AMP conditioning") print(" - CFG scale 7.5: Strong AMP conditioning (recommended)") print(" - CFG scale 15.0: Very strong AMP conditioning (may be too restrictive)") print("\nNext steps:") print("1. Decode embeddings back to sequences using ESM-2 decoder") print("2. Evaluate with ProtFlow metrics (FPD, MMD, ESM-2 perplexity)") print("3. Compare sequences generated with different CFG scales") print("4. Evaluate AMP properties (antimicrobial activity, toxicity)") if TORCHDIFFEQ_AVAILABLE: print(f"5. ✓ Enhanced generation with {best_method} ODE solver") else: print("5. Install torchdiffeq for improved ODE solving: pip install torchdiffeq") if __name__ == "__main__": main()