# FlowFinal / src / generate_amps.py
# Source: Hugging Face file view (uploader: esunAI; commit 37158e8 "Add generate_amps.py"; 17.4 kB).
import torch
import torch.nn.functional as F
import numpy as np
from tqdm import tqdm
import os
from datetime import datetime
# Import torchdiffeq for proper ODE solving
# torchdiffeq is an optional dependency: when it is missing the generator
# falls back to its own manual Euler integration.
try:
    from torchdiffeq import odeint
    TORCHDIFFEQ_AVAILABLE = True
    print("✓ torchdiffeq available for proper ODE solving")
except ImportError:
    # Keep the name defined so later references fail predictably instead of
    # raising NameError; callers must check TORCHDIFFEQ_AVAILABLE first.
    odeint = None
    TORCHDIFFEQ_AVAILABLE = False
    print("⚠️ torchdiffeq not available, using manual Euler integration")
# Import your components
from compressor_with_embeddings import Compressor, Decompressor
from final_flow_model import AMPFlowMatcherCFGConcat, AMPProtFlowPipelineCFG
class AMPGenerator:
    """
    Generate AMP samples using trained ProtFlow model.

    The generator loads a compressor/decompressor pair and a CFG-enabled flow
    matching model, integrates the learned velocity field from t=1 (noise) to
    t=0 (data) in the compressed latent space, decompresses the result and
    undoes the normalisation stored in ``normalization_stats.pt``.

    Class attributes (override on a subclass or instance to relocate files):
        COMPRESSOR_PATH / DECOMPRESSOR_PATH: state-dict files of the
            compressor and decompressor.
        STATS_PATH: file holding the 'mean', 'std', 'min' and 'max' entries
            used to undo preprocessing.
        LATENT_SEQ_LEN / LATENT_DIM: shape of one compressed sample.
        AMP_LABEL / MASK_LABEL: label ids fed to the flow model for the
            conditional (AMP) and unconditional (mask) branches of CFG.
    """

    COMPRESSOR_PATH = '/data2/edwardsun/flow_amp/models/final_compressor_model.pth'
    DECOMPRESSOR_PATH = '/data2/edwardsun/flow_amp/models/final_decompressor_model.pth'
    STATS_PATH = 'normalization_stats.pt'

    LATENT_SEQ_LEN = 25
    LATENT_DIM = 80  # 1280 // 16

    AMP_LABEL = 0   # 0 = AMP
    MASK_LABEL = 2  # 2 = Mask

    # Prefix that torch.compile adds to every parameter name of a wrapped module.
    _COMPILE_PREFIX = '_orig_mod.'

    def __init__(self, model_path, device='cuda'):
        """
        Load all models and the preprocessing statistics.

        Args:
            model_path: Path of the flow-model checkpoint (see _load_models).
            device: Device on which models are placed and tensors are created.
        """
        self.device = device
        # Load models
        self._load_models(model_path)
        # Load preprocessing statistics
        # NOTE(security): torch.load unpickles the file; only load trusted files.
        self.stats = torch.load(self.STATS_PATH, map_location=device)

    @classmethod
    def _strip_compile_prefix(cls, state_dict):
        """Return a copy of ``state_dict`` whose keys have no ``_orig_mod.`` prefix."""
        prefix = cls._COMPILE_PREFIX
        return {
            (key[len(prefix):] if key.startswith(prefix) else key): value
            for key, value in state_dict.items()
        }

    def _load_models(self, model_path):
        """
        Load trained models.

        Args:
            model_path: Checkpoint file that must contain the entries
                'flow_model_state_dict', 'step' and 'loss'.
        """
        print("Loading trained models...")
        # Load compressor and decompressor
        self.compressor = Compressor().to(self.device)
        self.decompressor = Decompressor().to(self.device)
        # NOTE(security): torch.load unpickles its input; only load trusted files.
        self.compressor.load_state_dict(
            torch.load(self.COMPRESSOR_PATH, map_location=self.device))
        self.decompressor.load_state_dict(
            torch.load(self.DECOMPRESSOR_PATH, map_location=self.device))
        # Load flow matching model with CFG
        self.flow_model = AMPFlowMatcherCFGConcat(
            hidden_dim=480,
            compressed_dim=self.LATENT_DIM,
            n_layers=12,
            n_heads=16,
            dim_ff=3072,
            max_seq_len=self.LATENT_SEQ_LEN,
            use_cfg=True
        ).to(self.device)
        # NOTE(security): weights_only=False allows arbitrary pickled objects;
        # the checkpoint must come from a trusted source.
        checkpoint = torch.load(model_path, map_location=self.device, weights_only=False)
        # Handle PyTorch compilation wrapper: a compiled model saves its
        # parameters under '_orig_mod.<name>'.
        self.flow_model.load_state_dict(
            self._strip_compile_prefix(checkpoint['flow_model_state_dict']))
        print(f"✓ All models loaded successfully from step {checkpoint['step']}!")
        print(f" Loss at checkpoint: {checkpoint['loss']:.6f}")
        # Report which integration backend is going to be used
        if TORCHDIFFEQ_AVAILABLE:
            print("✓ Enhanced with proper ODE solving (torchdiffeq)")
        else:
            print("⚠️ Using fallback Euler integration")

    def _guided_velocity(self, x, t, cfg_scale, amp_labels, mask_labels):
        """
        Evaluate the velocity field, applying classifier-free guidance.

        Args:
            x: Current state, passed to the flow model unchanged.
            t: Per-sample time tensor, passed to the flow model unchanged.
            cfg_scale: Guidance scale; values <= 0 disable guidance and only
                the mask-labelled (unconditional) prediction is used.
            amp_labels: Label tensor for the conditional branch.
            mask_labels: Label tensor for the unconditional branch.
        Returns:
            The (guided) velocity with the shape returned by the flow model.
        """
        if cfg_scale > 0:
            # With AMP condition
            vt_cond = self.flow_model(x, t, labels=amp_labels)
            # Without condition (mask)
            vt_uncond = self.flow_model(x, t, labels=mask_labels)
            # CFG interpolation
            return vt_uncond + cfg_scale * (vt_cond - vt_uncond)
        # No CFG, use mask label
        return self.flow_model(x, t, labels=mask_labels)

    def _create_ode_func(self, cfg_scale=7.5):
        """
        Create ODE function for torchdiffeq integration.

        The returned callable reads ``self.current_shape`` on every call, so
        that attribute must be set to the [B, L, D] shape of the batch before
        the solver is started.

        Args:
            cfg_scale: Guidance scale forwarded to the velocity evaluation.
        Returns:
            A callable ``f(t, x)`` operating on flattened states.
        """
        def ode_func(t, x):
            """
            ODE function: dx/dt = v_theta(x, t)
            Args:
                t: scalar time (single float or 0-dim tensor)
                x: state tensor [B*L*D] (flattened)
            Returns:
                dx/dt: derivative [B*L*D] (flattened)
            """
            # Reshape x back to [B, L, D]
            batch_size, seq_len, dim = self.current_shape
            x = x.view(batch_size, seq_len, dim)
            # Create time tensor for batch; float() accepts both Python
            # numbers and the 0-dim tensors handed over by the solver.
            t_tensor = torch.full((batch_size,), float(t), device=self.device, dtype=x.dtype)
            amp_labels = torch.full((batch_size,), self.AMP_LABEL, device=self.device)
            mask_labels = torch.full((batch_size,), self.MASK_LABEL, device=self.device)
            vt = self._guided_velocity(x, t_tensor, cfg_scale, amp_labels, mask_labels)
            # Return flattened derivative (reshape also copes with
            # non-contiguous model outputs)
            return vt.reshape(-1)
        return ode_func

    @staticmethod
    def _solver_kwargs(ode_method, num_steps, rtol, atol):
        """
        Build the keyword arguments passed to ``odeint`` for ``ode_method``.

        Adaptive methods get tolerances and a step cap; every other method is
        treated as fixed-step and gets a step size of ``1 / num_steps`` so the
        unit time interval is covered in ``num_steps`` steps.
        """
        if ode_method in ('dopri5', 'adaptive_heun'):
            # Adaptive solvers
            return {
                'method': ode_method,
                'rtol': rtol,
                'atol': atol,
                'options': {'max_num_steps': 1000},
            }
        # Fixed-step solvers
        return {
            'method': ode_method,
            'options': {'step_size': 1.0 / num_steps},
        }

    def _denormalize(self, embeddings):
        """Apply reverse preprocessing: undo the min-max scaling, then the mean/std scaling."""
        m, s = self.stats['mean'], self.stats['std']
        mn, mx = self.stats['min'], self.stats['max']
        embeddings = embeddings * (mx - mn + 1e-8) + mn
        return embeddings * s + m

    def generate_amps(self, num_samples=100, num_steps=25, batch_size=32, cfg_scale=7.5,
                      ode_method='dopri5', rtol=1e-5, atol=1e-6):
        """
        Generate AMP samples using flow matching with CFG and improved ODE solving.

        Args:
            num_samples: Number of AMP samples to generate
            num_steps: Number of integration steps for Euler and the
                fixed-step solvers (25 for good quality, 1 for reflow);
                adaptive solvers choose their own steps
            batch_size: Batch size for generation
            cfg_scale: CFG guidance scale (higher = stronger conditioning)
            ode_method: ODE solver method ('dopri5', 'rk4', 'euler', 'adaptive_heun');
                'euler' always uses the built-in manual Euler integration
            rtol: Relative tolerance for adaptive solvers
            atol: Absolute tolerance for adaptive solvers
        Returns:
            CPU tensor with the de-normalised decompressor output of all
            samples concatenated along dimension 0.
        Raises:
            ValueError: If num_samples, num_steps or batch_size is < 1.
        """
        if num_samples < 1 or num_steps < 1 or batch_size < 1:
            raise ValueError(
                "num_samples, num_steps and batch_size must all be >= 1 "
                f"(got {num_samples}, {num_steps}, {batch_size})"
            )

        use_solver = TORCHDIFFEQ_AVAILABLE and ode_method != 'euler'
        method_str = f"{ode_method} ODE solver" if use_solver else "manual Euler integration"
        print(f"Generating {num_samples} AMP samples with {method_str} (CFG scale: {cfg_scale})...")
        if use_solver:
            print(f" Method: {ode_method}, rtol={rtol}, atol={atol}")

        self.flow_model.eval()
        self.compressor.eval()
        self.decompressor.eval()

        all_generated = []
        with torch.no_grad():
            for i in tqdm(range(0, num_samples, batch_size), desc="Generating with improved ODE"):
                current_batch = min(batch_size, num_samples - i)
                # Sample random noise (starting point at t=1)
                eps = torch.randn(current_batch, self.LATENT_SEQ_LEN, self.LATENT_DIM,
                                  device=self.device)  # [B, L', COMP_DIM]
                if use_solver:
                    try:
                        # Store shape for ODE function
                        self.current_shape = eps.shape
                        ode_func = self._create_ode_func(cfg_scale=cfg_scale)
                        # Time span: from t=1 (noise) to t=0 (data)
                        t_span = torch.tensor([1.0, 0.0], device=self.device, dtype=eps.dtype)
                        # Flatten initial condition for torchdiffeq
                        y0 = eps.view(-1)
                        solution = odeint(
                            ode_func, y0, t_span,
                            **self._solver_kwargs(ode_method, num_steps, rtol, atol)
                        )
                        # Get final solution (at t=0)
                        xt = solution[-1].view(self.current_shape)
                    except Exception as e:
                        # Deliberate best-effort: a solver failure must not
                        # abort the whole run, so retry this batch with Euler.
                        print(f"⚠️ ODE solving failed for batch {i//batch_size + 1}: {e}")
                        print("Falling back to Euler method...")
                        xt = self._generate_with_euler(eps, current_batch, cfg_scale, num_steps)
                else:
                    # Use manual Euler integration (original method)
                    xt = self._generate_with_euler(eps, current_batch, cfg_scale, num_steps)
                # Decompress to get embeddings
                decompressed = self.decompressor(xt)  # [B, L, ESM_DIM]
                # Apply reverse preprocessing
                decompressed = self._denormalize(decompressed)
                all_generated.append(decompressed.cpu())

        # Concatenate all batches
        generated_embeddings = torch.cat(all_generated, dim=0)
        print(f"✓ Generated {generated_embeddings.shape[0]} AMP embeddings")
        print(f" Shape: {generated_embeddings.shape}")
        print(f" Stats - Mean: {generated_embeddings.mean():.4f}, Std: {generated_embeddings.std():.4f}")
        return generated_embeddings

    def _generate_with_euler(self, eps, current_batch, cfg_scale, num_steps):
        """
        Fallback Euler integration method (original implementation).

        Args:
            eps: Initial noise batch; it is cloned, not modified.
            current_batch: Number of samples in ``eps`` (size of the time and
                label tensors).
            cfg_scale: Guidance scale; <= 0 disables guidance.
            num_steps: Number of equally sized steps from t=1 towards t=0.
        Returns:
            The integrated state, same shape as ``eps``.
        """
        xt = eps.clone()
        amp_labels = torch.full((current_batch,), self.AMP_LABEL, device=self.device)
        mask_labels = torch.full((current_batch,), self.MASK_LABEL, device=self.device)
        # Backward integration (t: 1 -> 0), hence the negative step
        dt = -1.0 / num_steps
        for step in range(num_steps):
            t = torch.ones(current_batch, device=self.device) * (1.0 - step/num_steps)
            vt = self._guided_velocity(xt, t, cfg_scale, amp_labels, mask_labels)
            xt = xt + vt * dt
        return xt

    def _synchronize(self):
        """Wait for pending GPU work so wall-clock timings are meaningful; no-op on CPU."""
        device = torch.device(self.device)
        if device.type == 'cuda':
            torch.cuda.synchronize(device)

    def compare_ode_methods(self, num_samples=20, cfg_scale=7.5):
        """
        Compare different ODE solving methods for quality assessment.

        Args:
            num_samples: Samples generated per method.
            cfg_scale: Guidance scale used for every method.
        Returns:
            Dict mapping method name to either
            ``{'embeddings', 'time', 'mean', 'std', 'success': True}`` or
            ``{'success': False, 'error': str}``. When torchdiffeq is not
            installed no comparison is possible and the embeddings of a
            single default ``generate_amps`` run are returned instead.
        """
        import time  # local import: only needed for the timing below

        if not TORCHDIFFEQ_AVAILABLE:
            print("⚠️ torchdiffeq not available, cannot compare ODE methods")
            return self.generate_amps(num_samples=num_samples, cfg_scale=cfg_scale)

        methods = ['euler', 'rk4', 'dopri5', 'adaptive_heun']
        results = {}
        print("🔬 Comparing ODE solving methods...")
        for method in methods:
            print(f"\n--- Testing {method} ---")
            try:
                # Wall-clock timing works on every device; CUDA events would
                # raise on a CPU-only setup and mark each method as failed.
                self._synchronize()
                start_time = time.perf_counter()
                embeddings = self.generate_amps(
                    num_samples=num_samples,
                    batch_size=10,
                    cfg_scale=cfg_scale,
                    ode_method=method
                )
                self._synchronize()
                elapsed_time = time.perf_counter() - start_time  # seconds
                results[method] = {
                    'embeddings': embeddings,
                    'time': elapsed_time,
                    'mean': embeddings.mean().item(),
                    'std': embeddings.std().item(),
                    'success': True
                }
                print(f"✓ {method}: {elapsed_time:.2f}s, mean={embeddings.mean():.4f}, std={embeddings.std():.4f}")
            except Exception as e:
                # One failing method must not stop the comparison of the others.
                print(f"❌ {method} failed: {e}")
                results[method] = {'success': False, 'error': str(e)}
        return results

    def generate_with_reflow(self, num_samples=100):
        """
        Generate AMP samples using 1-step reflow (if you have reflow model).

        Args:
            num_samples: Number of samples to generate.
        Returns:
            The embeddings returned by ``generate_amps``.
        """
        print(f"Generating {num_samples} AMP samples with 1-step reflow...")
        # This would use the reflow implementation
        # For now, just use 1-step generation. 'euler' is requested explicitly:
        # with the default adaptive solver num_steps would be ignored and the
        # generation would not be single-step.
        return self.generate_amps(
            num_samples=num_samples,
            num_steps=1,
            batch_size=32,
            ode_method='euler',
        )
def main():
    """
    Main generation function.

    Verifies that the best checkpoint can be loaded, optionally compares the
    available ODE solvers, then generates 20 samples for each of four CFG
    scales and saves every result as a dated ``.pt`` file in the output
    directory. Returns early (after printing the reason) when the checkpoint
    is missing, unreadable or lacks the expected metadata.
    """
    print("=== AMP Generation Pipeline with CFG ===")
    # Use the best model from training (lowest validation loss: 0.017183)
    model_path = '/data2/edwardsun/flow_checkpoints/amp_flow_model_best_optimized.pth'

    # Check that the checkpoint exists and is readable.
    # NOTE(security): weights_only=False unpickles arbitrary objects; only
    # load checkpoints from a trusted source.
    try:
        checkpoint = torch.load(model_path, map_location='cpu', weights_only=False)
    except FileNotFoundError:
        print(f"❌ Best model not found: {model_path}")
        print("Please train the flow matching model first using amp_flow_training.py")
        return
    except Exception as exc:
        # Top-level boundary: report the real reason instead of claiming the
        # file is missing (KeyboardInterrupt/SystemExit are not caught).
        print(f"❌ Failed to load checkpoint {model_path}: {exc}")
        return

    try:
        print(f"✓ Found best model at step {checkpoint['step']} with loss {checkpoint['loss']:.6f}")
        print(f" Global step: {checkpoint['global_step']}")
        print(f" Total samples: {checkpoint['total_samples']:,}")
    except KeyError as exc:
        print(f"❌ Checkpoint {model_path} is missing expected entry: {exc}")
        return
    # Only the metadata was needed here; the generator loads the file itself.
    del checkpoint

    # Initialize generator (fall back to CPU when no GPU is present)
    device = 'cuda' if torch.cuda.is_available() else 'cpu'
    generator = AMPGenerator(model_path, device=device)

    # Test ODE methods comparison if available
    if TORCHDIFFEQ_AVAILABLE:
        print("\n🔬 Comparing ODE solving methods...")
        generator.compare_ode_methods(num_samples=10, cfg_scale=7.5)
        # Use best method for generation
        best_method = 'dopri5'  # Recommended method
        print(f"\n🚀 Using {best_method} for main generation...")
    else:
        best_method = 'euler'
        print("\n⚠️ Using fallback Euler integration...")

    # Create output directory up front so a bad path fails before any
    # expensive generation has been done.
    output_dir = '/data2/edwardsun/generated_samples'
    os.makedirs(output_dir, exist_ok=True)
    # Get today's date for filename
    today = datetime.now().strftime('%Y%m%d')

    # (file tag, CFG scale, progress description, summary description)
    runs = [
        ('no_cfg', 0.0, 'no conditioning', 'no conditioning'),
        ('weak_cfg', 3.0, 'weak conditioning', 'weak CFG'),
        ('strong_cfg', 7.5, 'strong conditioning', 'strong CFG'),
        ('very_strong_cfg', 15.0, 'very strong conditioning', 'very strong CFG'),
    ]

    # Generate samples with different CFG scales using improved ODE solving;
    # each result is saved immediately so a later failure cannot lose it.
    saved = []
    for index, (tag, scale, progress_desc, summary_desc) in enumerate(runs, start=1):
        print(f"\n{index}. Generating with CFG scale {scale} ({progress_desc})...")
        samples = generator.generate_amps(
            num_samples=20, num_steps=25, cfg_scale=scale, ode_method=best_method)
        filename = f'generated_amps_best_model_{tag}_{today}.pt'
        torch.save(samples, os.path.join(output_dir, filename))
        saved.append((filename, summary_desc))

    print("\n✓ Generation complete!")
    print(f"Generated samples saved (Date: {today}):")
    for filename, summary_desc in saved:
        print(f" - {filename} ({summary_desc})")

    print("\nCFG Analysis:")
    print(" - CFG scale 0.0: No conditioning, generates diverse sequences")
    print(" - CFG scale 3.0: Weak AMP conditioning")
    print(" - CFG scale 7.5: Strong AMP conditioning (recommended)")
    print(" - CFG scale 15.0: Very strong AMP conditioning (may be too restrictive)")

    print("\nNext steps:")
    print("1. Decode embeddings back to sequences using ESM-2 decoder")
    print("2. Evaluate with ProtFlow metrics (FPD, MMD, ESM-2 perplexity)")
    print("3. Compare sequences generated with different CFG scales")
    print("4. Evaluate AMP properties (antimicrobial activity, toxicity)")
    if TORCHDIFFEQ_AVAILABLE:
        print(f"5. ✓ Enhanced generation with {best_method} ODE solver")
    else:
        print("5. Install torchdiffeq for improved ODE solving: pip install torchdiffeq")


# Run the pipeline only when executed as a script, not on import.
if __name__ == "__main__":
    main()