|
|
import os
import time
from datetime import datetime

import numpy as np
import torch
import torch.nn.functional as F
from tqdm import tqdm

# torchdiffeq is optional: without it, generation uses the manual Euler loop.
try:
    from torchdiffeq import odeint
    TORCHDIFFEQ_AVAILABLE = True
    print("โ torchdiffeq available for proper ODE solving")
except ImportError:
    TORCHDIFFEQ_AVAILABLE = False
    print("โ ๏ธ torchdiffeq not available, using manual Euler integration")

from compressor_with_embeddings import Compressor, Decompressor
from final_flow_model import AMPFlowMatcherCFGConcat, AMPProtFlowPipelineCFG
|
|
|
|
|
class AMPGenerator:
    """
    Generate AMP samples using trained ProtFlow model.

    The generator integrates the learned velocity field from t=1 (Gaussian
    noise) down to t=0, optionally with classifier-free guidance (CFG), then
    decompresses and de-normalizes the resulting latents.

    Integration uses torchdiffeq when it is installed and the requested
    method is not ``'euler'``; otherwise a manual Euler loop is used.
    """

    # Locations of the pretrained weights / statistics loaded at construction.
    COMPRESSOR_PATH = '/data2/edwardsun/flow_amp/models/final_compressor_model.pth'
    DECOMPRESSOR_PATH = '/data2/edwardsun/flow_amp/models/final_decompressor_model.pth'
    STATS_PATH = 'normalization_stats.pt'

    # Latent layout shared by the flow model configuration and the noise sampler.
    SEQ_LEN = 25
    LATENT_DIM = 80

    # Label ids fed to the flow model: 0 is used for the conditional ("amp")
    # pass and 2 for the unconditional ("mask") pass.
    # NOTE(review): the meaning of these ids comes from the training code,
    # which is not visible here -- confirm before changing.
    AMP_LABEL = 0
    MASK_LABEL = 2

    # torchdiffeq methods that take rtol/atol; any other method is treated as
    # a fixed-step solver driven by ``num_steps``.
    ADAPTIVE_METHODS = ('dopri5', 'adaptive_heun')

    def __init__(self, model_path, device='cuda'):
        """
        Load all models and the normalization statistics.

        Args:
            model_path: Path to the flow-model checkpoint.
            device: Device (string or torch.device) that models and sampled
                tensors are placed on.
        """
        self.device = device
        self._load_models(model_path)
        self.stats = torch.load(self.STATS_PATH, map_location=device)

    def _load_models(self, model_path):
        """Load trained models."""
        print("Loading trained models...")

        self.compressor = Compressor().to(self.device)
        self.decompressor = Decompressor().to(self.device)
        self.compressor.load_state_dict(
            torch.load(self.COMPRESSOR_PATH, map_location=self.device))
        self.decompressor.load_state_dict(
            torch.load(self.DECOMPRESSOR_PATH, map_location=self.device))

        self.flow_model = AMPFlowMatcherCFGConcat(
            hidden_dim=480,
            compressed_dim=self.LATENT_DIM,
            n_layers=12,
            n_heads=16,
            dim_ff=3072,
            max_seq_len=self.SEQ_LEN,
            use_cfg=True
        ).to(self.device)

        checkpoint = torch.load(model_path, map_location=self.device, weights_only=False)

        # Strip the '_orig_mod.' key prefix (presumably left by torch.compile
        # at training time) so the keys match the un-wrapped module.
        prefix = '_orig_mod.'
        state_dict = checkpoint['flow_model_state_dict']
        cleaned_state_dict = {
            (key[len(prefix):] if key.startswith(prefix) else key): value
            for key, value in state_dict.items()
        }
        self.flow_model.load_state_dict(cleaned_state_dict)

        print(f"โ All models loaded successfully from step {checkpoint['step']}!")
        print(f" Loss at checkpoint: {checkpoint['loss']:.6f}")

        if TORCHDIFFEQ_AVAILABLE:
            print("โ Enhanced with proper ODE solving (torchdiffeq)")
        else:
            print("โ ๏ธ Using fallback Euler integration")

    def _make_labels(self, batch_size):
        """Return the (conditional, unconditional) label tensors for a batch."""
        amp_labels = torch.full((batch_size,), self.AMP_LABEL, device=self.device)
        mask_labels = torch.full((batch_size,), self.MASK_LABEL, device=self.device)
        return amp_labels, mask_labels

    def _guided_velocity(self, x, t, cfg_scale, amp_labels, mask_labels):
        """
        Evaluate the velocity field, applying classifier-free guidance.

        With ``cfg_scale > 0`` the result is
        ``v_uncond + cfg_scale * (v_cond - v_uncond)``; otherwise only the
        unconditional (mask-label) prediction is returned.
        """
        if cfg_scale > 0:
            vt_cond = self.flow_model(x, t, labels=amp_labels)
            vt_uncond = self.flow_model(x, t, labels=mask_labels)
            return vt_uncond + cfg_scale * (vt_cond - vt_uncond)
        return self.flow_model(x, t, labels=mask_labels)

    @staticmethod
    def _fixed_step_options(num_steps):
        """Solver options for fixed-step methods: the unit interval in ``num_steps`` steps."""
        return {'step_size': 1.0 / num_steps}

    def _create_ode_func(self, cfg_scale=7.5, shape=None):
        """
        Create ODE function for torchdiffeq integration.

        Args:
            cfg_scale: CFG guidance scale used for every evaluation.
            shape: (batch, seq_len, dim) of the un-flattened state. When
                omitted, ``self.current_shape`` is read at call time, which
                is how earlier callers supplied the shape.

        Returns:
            A callable ``ode_func(t, x)`` mapping a flattened state to the
            flattened derivative dx/dt = v_theta(x, t).
        """

        def ode_func(t, x):
            """
            ODE function: dx/dt = v_theta(x, t)

            Args:
                t: scalar time (single float)
                x: state tensor [B*L*D] (flattened)
            Returns:
                dx/dt: derivative [B*L*D] (flattened)
            """
            batch_size, seq_len, dim = self.current_shape if shape is None else shape
            x = x.view(batch_size, seq_len, dim)

            # The model takes one time value per batch element; float() accepts
            # both a Python number and the 0-dim tensor the solver passes in.
            t_tensor = torch.full((batch_size,), float(t), device=self.device, dtype=x.dtype)

            amp_labels, mask_labels = self._make_labels(batch_size)
            vt = self._guided_velocity(x, t_tensor, cfg_scale, amp_labels, mask_labels)
            return vt.view(-1)

        return ode_func

    def _solve_with_torchdiffeq(self, eps, cfg_scale, num_steps, ode_method, rtol, atol):
        """Integrate one noise batch from t=1 to t=0 with torchdiffeq and return the final state."""
        # Kept for callers that build an ODE function without passing a shape.
        self.current_shape = eps.shape

        ode_func = self._create_ode_func(cfg_scale=cfg_scale, shape=eps.shape)
        t_span = torch.tensor([1.0, 0.0], device=self.device, dtype=eps.dtype)
        y0 = eps.view(-1)

        if ode_method in self.ADAPTIVE_METHODS:
            solution = odeint(
                ode_func, y0, t_span,
                method=ode_method,
                rtol=rtol,
                atol=atol,
                options={'max_num_steps': 1000}
            )
        else:
            # Fixed-step solvers honour num_steps (1/25 = 0.04 at the default).
            solution = odeint(
                ode_func, y0, t_span,
                method=ode_method,
                options=self._fixed_step_options(num_steps)
            )

        # solution holds one state per entry of t_span; the last is t=0.
        return solution[-1].view(eps.shape)

    def _denormalize(self, decompressed):
        """Undo min-max scaling, then standardization, using the loaded statistics."""
        mean, std = self.stats['mean'], self.stats['std']
        lowest, highest = self.stats['min'], self.stats['max']
        decompressed = decompressed * (highest - lowest + 1e-8) + lowest
        return decompressed * std + mean

    def generate_amps(self, num_samples=100, num_steps=25, batch_size=32, cfg_scale=7.5,
                      ode_method='dopri5', rtol=1e-5, atol=1e-6):
        """
        Generate AMP samples using flow matching with CFG and improved ODE solving.

        Args:
            num_samples: Number of AMP samples to generate
            num_steps: Number of ODE solving steps (25 for good quality, 1 for reflow).
                Used by the Euler loop and by fixed-step torchdiffeq methods;
                adaptive methods choose their own steps.
            batch_size: Batch size for generation
            cfg_scale: CFG guidance scale (higher = stronger conditioning)
            ode_method: ODE solver method ('dopri5', 'rk4', 'euler', 'adaptive_heun')
            rtol: Relative tolerance for adaptive solvers
            atol: Absolute tolerance for adaptive solvers

        Returns:
            CPU tensor with all de-normalized embeddings concatenated along dim 0.

        Raises:
            ValueError: If num_samples, batch_size or num_steps is smaller than 1.
        """
        if num_samples < 1 or batch_size < 1 or num_steps < 1:
            raise ValueError(
                "num_samples, batch_size and num_steps must all be >= 1 "
                f"(got {num_samples}, {batch_size}, {num_steps})"
            )

        use_solver = TORCHDIFFEQ_AVAILABLE and ode_method != 'euler'
        method_str = f"{ode_method} ODE solver" if use_solver else "manual Euler integration"
        print(f"Generating {num_samples} AMP samples with {method_str} (CFG scale: {cfg_scale})...")
        if use_solver:
            print(f" Method: {ode_method}, rtol={rtol}, atol={atol}")

        self.flow_model.eval()
        self.compressor.eval()
        self.decompressor.eval()

        all_generated = []

        with torch.no_grad():
            for i in tqdm(range(0, num_samples, batch_size), desc="Generating with improved ODE"):
                current_batch = min(batch_size, num_samples - i)

                # Start from standard Gaussian noise in the compressed latent space.
                eps = torch.randn(current_batch, self.SEQ_LEN, self.LATENT_DIM, device=self.device)

                if use_solver:
                    try:
                        xt = self._solve_with_torchdiffeq(
                            eps, cfg_scale, num_steps, ode_method, rtol, atol)
                    except Exception as e:
                        # Best effort: a solver failure degrades to Euler for this batch.
                        print(f"โ ๏ธ ODE solving failed for batch {i//batch_size + 1}: {e}")
                        print("Falling back to Euler method...")
                        xt = self._generate_with_euler(eps, current_batch, cfg_scale, num_steps)
                else:
                    xt = self._generate_with_euler(eps, current_batch, cfg_scale, num_steps)

                decompressed = self._denormalize(self.decompressor(xt))
                all_generated.append(decompressed.cpu())

        generated_embeddings = torch.cat(all_generated, dim=0)

        print(f"โ Generated {generated_embeddings.shape[0]} AMP embeddings")
        print(f" Shape: {generated_embeddings.shape}")
        print(f" Stats - Mean: {generated_embeddings.mean():.4f}, Std: {generated_embeddings.std():.4f}")

        return generated_embeddings

    def _generate_with_euler(self, eps, current_batch, cfg_scale, num_steps):
        """
        Fallback Euler integration method (original implementation).

        Takes ``num_steps`` equal steps from t=1 towards t=0, starting at a
        copy of ``eps``, and returns the final state.
        """
        xt = eps.clone()
        amp_labels, mask_labels = self._make_labels(current_batch)

        # Negative step: time runs backwards from 1 to 0.
        dt = -1.0 / num_steps if num_steps else 0.0

        for step in range(num_steps):
            t = torch.full((current_batch,), 1.0 - step / num_steps, device=self.device)
            vt = self._guided_velocity(xt, t, cfg_scale, amp_labels, mask_labels)
            xt = xt + vt * dt

        return xt

    def compare_ode_methods(self, num_samples=20, cfg_scale=7.5):
        """
        Compare different ODE solving methods for quality assessment.

        Returns:
            Dict keyed by method name. Successful entries hold 'embeddings',
            'time' (seconds), 'mean', 'std' and 'success': True; failed
            entries hold 'success': False and 'error'. When torchdiffeq is
            not installed, the embeddings tensor of a single default
            generate_amps run is returned instead of a dict.
        """
        if not TORCHDIFFEQ_AVAILABLE:
            print("โ ๏ธ torchdiffeq not available, cannot compare ODE methods")
            return self.generate_amps(num_samples=num_samples, cfg_scale=cfg_scale)

        methods = ['euler', 'rk4', 'dopri5', 'adaptive_heun']
        results = {}

        # Wall-clock timing works on any device; CUDA only needs a sync so
        # that queued kernels are included in the measurement.
        on_cuda = torch.device(self.device).type == 'cuda'

        print("๐ฌ Comparing ODE solving methods...")

        for method in methods:
            print(f"\n--- Testing {method} ---")
            try:
                if on_cuda:
                    torch.cuda.synchronize()
                started = time.perf_counter()

                embeddings = self.generate_amps(
                    num_samples=num_samples,
                    batch_size=10,
                    cfg_scale=cfg_scale,
                    ode_method=method
                )

                if on_cuda:
                    torch.cuda.synchronize()
                elapsed_time = time.perf_counter() - started

                mean_value = embeddings.mean().item()
                std_value = embeddings.std().item()
                results[method] = {
                    'embeddings': embeddings,
                    'time': elapsed_time,
                    'mean': mean_value,
                    'std': std_value,
                    'success': True
                }

                print(f"โ {method}: {elapsed_time:.2f}s, mean={mean_value:.4f}, std={std_value:.4f}")

            except Exception as e:
                print(f"โ {method} failed: {e}")
                results[method] = {'success': False, 'error': str(e)}

        return results

    def generate_with_reflow(self, num_samples=100):
        """
        Generate AMP samples using 1-step reflow (if you have reflow model).

        The Euler path is forced because adaptive solvers ignore num_steps,
        which would make this indistinguishable from regular generation.
        """
        print(f"Generating {num_samples} AMP samples with 1-step reflow...")

        return self.generate_amps(
            num_samples=num_samples, num_steps=1, batch_size=32, ode_method='euler')
|
|
|
|
|
def main():
    """
    Main generation function.

    Checks that the best checkpoint can be read, builds an AMPGenerator,
    optionally benchmarks the ODE solvers, generates 20 samples at each of
    four CFG scales, and saves each set as a dated .pt file.
    """
    print("=== AMP Generation Pipeline with CFG ===")

    model_path = '/data2/edwardsun/flow_checkpoints/amp_flow_model_best_optimized.pth'

    # Pre-flight read of the checkpoint so a missing or unreadable file is
    # reported before any model is built.
    try:
        checkpoint = torch.load(model_path, map_location='cpu', weights_only=False)
        print(f"โ Found best model at step {checkpoint['step']} with loss {checkpoint['loss']:.6f}")
        print(f" Global step: {checkpoint['global_step']}")
        print(f" Total samples: {checkpoint['total_samples']:,}")
    except Exception as exc:
        # Exception (not a bare except) so Ctrl-C still interrupts the script.
        print(f"โ Best model not found: {model_path}")
        print("Please train the flow matching model first using amp_flow_training.py")
        print(f" Reason: {exc!r}")
        return

    # Use the GPU when present instead of failing on CPU-only machines.
    device = 'cuda' if torch.cuda.is_available() else 'cpu'
    generator = AMPGenerator(model_path, device=device)

    if TORCHDIFFEQ_AVAILABLE:
        print("\n๐ฌ Comparing ODE solving methods...")
        # Run for its printed report; the solver choice below is fixed.
        generator.compare_ode_methods(num_samples=10, cfg_scale=7.5)

        best_method = 'dopri5'
        print(f"\n๐ Using {best_method} for main generation...")
    else:
        best_method = 'euler'
        print("\nโ ๏ธ Using fallback Euler integration...")

    # (file tag, CFG scale, progress headline, summary note) for each run.
    cfg_runs = [
        ('no_cfg', 0.0, "CFG scale 0.0 (no conditioning)", "no conditioning"),
        ('weak_cfg', 3.0, "CFG scale 3.0 (weak conditioning)", "weak CFG"),
        ('strong_cfg', 7.5, "CFG scale 7.5 (strong conditioning)", "strong CFG"),
        ('very_strong_cfg', 15.0, "CFG scale 15.0 (very strong conditioning)", "very strong CFG"),
    ]

    samples = {}
    for index, (tag, scale, headline, _note) in enumerate(cfg_runs, start=1):
        print(f"\n{index}. Generating with {headline}...")
        samples[tag] = generator.generate_amps(
            num_samples=20, num_steps=25, cfg_scale=scale, ode_method=best_method)

    output_dir = '/data2/edwardsun/generated_samples'
    os.makedirs(output_dir, exist_ok=True)

    # Date stamp shared by all output file names.
    today = datetime.now().strftime('%Y%m%d')

    for tag, _scale, _headline, _note in cfg_runs:
        torch.save(samples[tag], os.path.join(output_dir, f'generated_amps_best_model_{tag}_{today}.pt'))

    print("\nโ Generation complete!")
    print(f"Generated samples saved (Date: {today}):")
    for tag, _scale, _headline, note in cfg_runs:
        print(f" - generated_amps_best_model_{tag}_{today}.pt ({note})")

    print("\nCFG Analysis:")
    print(" - CFG scale 0.0: No conditioning, generates diverse sequences")
    print(" - CFG scale 3.0: Weak AMP conditioning")
    print(" - CFG scale 7.5: Strong AMP conditioning (recommended)")
    print(" - CFG scale 15.0: Very strong AMP conditioning (may be too restrictive)")

    print("\nNext steps:")
    print("1. Decode embeddings back to sequences using ESM-2 decoder")
    print("2. Evaluate with ProtFlow metrics (FPD, MMD, ESM-2 perplexity)")
    print("3. Compare sequences generated with different CFG scales")
    print("4. Evaluate AMP properties (antimicrobial activity, toxicity)")
    if TORCHDIFFEQ_AVAILABLE:
        print(f"5. โ Enhanced generation with {best_method} ODE solver")
    else:
        print("5. Install torchdiffeq for improved ODE solving: pip install torchdiffeq")
|
|
|
|
|
# Run the pipeline only when executed as a script, not when imported.
if __name__ == "__main__":
    main()