"""final_flow_model.py: Flow Matching model with classifier-free guidance (CFG) for AMP generation."""

import torch
import torch.nn as nn
import torch.nn.functional as F
import math
class SinusoidalTimeEmbedding(nn.Module):
"""Sinusoidal time embedding as used in ProtFlow paper."""
def __init__(self, dim):
super().__init__()
self.dim = dim
    def forward(self, time):
        device = time.device
        half_dim = self.dim // 2
        # Frequencies: exp(-k * ln(10000) / (half_dim - 1)) for k = 0 .. half_dim - 1
        freqs = math.log(10000) / (half_dim - 1)
        freqs = torch.exp(torch.arange(half_dim, device=device) * -freqs)
        # Flatten time to 1D (B,) so the output is always exactly (B, dim); a bare
        # .squeeze() here would also drop the batch dimension when B == 1.
        time = time.reshape(-1).float()
        args = time.unsqueeze(-1) * freqs.unsqueeze(0)          # (B, half_dim)
        return torch.cat((args.sin(), args.cos()), dim=-1)      # (B, dim)
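
# Quick illustrative shape check for the embedding above (a sketch only; the batch
# size and time values below are arbitrary and not used elsewhere in this file):
#
#     emb = SinusoidalTimeEmbedding(480)(torch.rand(8))
#     assert emb.shape == (8, 480)
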
class LabelMLP(nn.Module):
"""
MLP for processing class labels into embeddings.
This approach processes labels separately from time embeddings.
"""
def __init__(self, num_classes=3, hidden_dim=480, mlp_dim=256):
super().__init__()
self.num_classes = num_classes
# MLP to process labels
self.label_mlp = nn.Sequential(
nn.Embedding(num_classes, mlp_dim),
nn.Linear(mlp_dim, mlp_dim),
nn.GELU(),
nn.Linear(mlp_dim, hidden_dim),
nn.GELU(),
nn.Linear(hidden_dim, hidden_dim)
)
# Initialize embeddings
nn.init.normal_(self.label_mlp[0].weight, std=0.02)
def forward(self, labels):
"""
Args:
labels: (B,) tensor of class labels
- 0: AMP (MIC < 100)
- 1: Non-AMP (MIC >= 100)
- 2: Mask (Unknown MIC)
Returns:
embeddings: (B, hidden_dim) tensor of processed label embeddings
"""
return self.label_mlp(labels)
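
# Illustrative LabelMLP usage (sketch only; label values follow the convention
# documented in forward() above):
#
#     lbl_emb = LabelMLP(num_classes=3, hidden_dim=480)(torch.tensor([0, 1, 2]))
#     # lbl_emb.shape == (3, 480)
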
class AMPFlowMatcherCFGConcat(nn.Module):
"""
Flow Matching model with Classifier-Free Guidance using concatenation approach.
- 12-layer transformer with long skip connections
- Time embedding + MLP-processed label embedding (concatenated then projected)
    - Optimized for peptide sequences (up to 50 residues, pooled to max_seq_len=25 latent tokens)
"""
def __init__(self, hidden_dim=480, compressed_dim=30, n_layers=12, n_heads=16,
dim_ff=3072, dropout=0.1, max_seq_len=25, use_cfg=True):
super().__init__()
self.hidden_dim = hidden_dim
self.compressed_dim = compressed_dim
self.n_layers = n_layers
self.max_seq_len = max_seq_len
self.use_cfg = use_cfg
# Time embedding
self.time_embed = nn.Sequential(
SinusoidalTimeEmbedding(hidden_dim),
nn.Linear(hidden_dim, hidden_dim),
nn.GELU(),
nn.Linear(hidden_dim, hidden_dim)
)
# CFG components using concatenation approach
if use_cfg:
self.label_mlp = LabelMLP(num_classes=3, hidden_dim=hidden_dim)
# Projection layer for concatenated time + label embeddings
self.condition_proj = nn.Sequential(
nn.Linear(hidden_dim * 2, hidden_dim), # 2 for time + label
nn.GELU(),
nn.Linear(hidden_dim, hidden_dim)
)
# Projection layers for compressed space
self.compress_proj = nn.Linear(compressed_dim, hidden_dim)
        self.decompress_proj = nn.Linear(hidden_dim, compressed_dim)  # note: unused in forward(); output_proj does the final projection
# Positional encoding for peptide sequences
self.pos_embed = nn.Parameter(torch.randn(1, max_seq_len, hidden_dim))
# Transformer layers with long skip connections
self.layers = nn.ModuleList([
nn.TransformerEncoderLayer(
d_model=hidden_dim,
nhead=n_heads,
dim_feedforward=dim_ff,
dropout=dropout,
activation='gelu',
batch_first=True
) for _ in range(n_layers)
])
        # Long skip connections (U-ViT inspired): each block's input is projected and re-added before the next block
self.skip_projections = nn.ModuleList([
nn.Linear(hidden_dim, hidden_dim) for _ in range(n_layers - 1)
])
# Output projection
self.output_proj = nn.Linear(hidden_dim, compressed_dim)
def forward(self, x, t, labels=None, mask=None):
"""
Args:
x: compressed latent (B, L, compressed_dim) - AMP embeddings
t: time scalar (B,) or (B, 1)
labels: class labels (B,) for CFG - 0=AMP, 1=Non-AMP, 2=Mask
mask: attention mask (B, L) if needed
"""
B, L, D = x.shape
# Project to hidden dimension
x = self.compress_proj(x) # (B, L, hidden_dim)
# Add positional encoding
if L <= self.max_seq_len:
x = x + self.pos_embed[:, :L, :]
        # Time embedding: SinusoidalTimeEmbedding flattens t internally, so (B,),
        # (B, 1), or higher-rank singleton shapes all yield a (B, hidden_dim) embedding.
        t_emb = self.time_embed(t)                    # (B, hidden_dim)
        t_emb = t_emb.unsqueeze(1).expand(-1, L, -1)  # (B, L, hidden_dim)
# CFG: Process label embedding if enabled
if self.use_cfg and labels is not None:
# Process labels through MLP
label_emb = self.label_mlp(labels) # (B, hidden_dim)
label_emb = label_emb.unsqueeze(1).expand(-1, L, -1) # (B, L, hidden_dim)
            # Concatenate time and label embeddings, then project back to hidden_dim
combined_emb = torch.cat([t_emb, label_emb], dim=-1) # (B, L, hidden_dim*2)
projected_emb = self.condition_proj(combined_emb) # (B, L, hidden_dim)
else:
projected_emb = t_emb # Just use time embedding if no CFG
# Store intermediate representations for skip connections
skip_features = []
# Pass through transformer layers with skip connections
for i, layer in enumerate(self.layers):
# Add skip connection from earlier layers
if i > 0 and i < len(self.layers) - 1:
skip_feat = skip_features[i-1]
skip_feat = self.skip_projections[i-1](skip_feat)
x = x + skip_feat
# Store current features for future skip connections
if i < len(self.layers) - 1:
skip_features.append(x.clone())
# Add projected condition embedding to EACH layer
x = x + projected_emb
# Apply transformer layer
x = layer(x, src_key_padding_mask=mask)
# Project back to compressed dimension
x = self.output_proj(x) # (B, L, compressed_dim)
return x
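

# AMPFlowMatcherCFGConcat predicts a velocity field v_theta(x_t, t, y). Below is a minimal
# sketch of one conditional flow-matching training step. It is not taken from this repo's
# training code: it assumes the linear path x_t = t * eps + (1 - t) * x1 (t = 1 noise,
# t = 0 data), which matches the backward Euler integration used in
# AMPProtFlowPipelineCFG.generate_amps_cfg below, and a label-dropout probability
# p_uncond so the model also learns the unconditional field needed for CFG.
def flow_matching_training_step(model, x1, labels, optimizer, p_uncond=0.1, mask_label=2):
    """Illustrative CFG flow-matching step. x1: (B, L, compressed_dim), labels: (B,)."""
    B = x1.shape[0]
    device = x1.device
    # Label dropout: replace a random fraction of labels with the mask label (2).
    drop = torch.rand(B, device=device) < p_uncond
    labels = torch.where(drop, torch.full_like(labels, mask_label), labels)
    # Sample a time per example and a noise sample, then build x_t on the linear path.
    t = torch.rand(B, device=device)                              # (B,)
    eps = torch.randn_like(x1)                                    # (B, L, D)
    xt = t.view(B, 1, 1) * eps + (1.0 - t.view(B, 1, 1)) * x1
    # Target velocity along the path: d x_t / d t = eps - x1.
    v_target = eps - x1
    v_pred = model(xt, t, labels=labels)
    loss = F.mse_loss(v_pred, v_target)
    optimizer.zero_grad()
    loss.backward()
    optimizer.step()
    return loss.item()
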
class AMPProtFlowPipelineCFG:
"""
Complete ProtFlow pipeline for AMP generation with CFG.
"""
def __init__(self, compressor, decompressor, flow_model, device='cuda'):
self.compressor = compressor
self.decompressor = decompressor
self.flow_model = flow_model
self.device = device
# Load normalization stats
self.stats = torch.load('normalization_stats.pt', map_location=device)
def generate_amps_cfg(self, num_samples=100, num_steps=25, cfg_scale=7.5,
condition_label=0):
"""
Generate AMP samples using CFG.
Args:
num_samples: Number of samples to generate
num_steps: Number of ODE solving steps
cfg_scale: CFG guidance scale (higher = stronger conditioning)
condition_label: 0=AMP, 1=Non-AMP, 2=Mask
"""
print(f"Generating {num_samples} samples with CFG (label={condition_label}, scale={cfg_scale})...")
# Sample random noise
batch_size = min(num_samples, 32) # Process in batches
all_samples = []
for i in range(0, num_samples, batch_size):
current_batch = min(batch_size, num_samples - i)
# Initialize with noise
eps = torch.randn(current_batch, self.flow_model.max_seq_len,
self.flow_model.compressed_dim, device=self.device)
# ODE solving steps with CFG
xt = eps.clone()
for step in range(num_steps):
t = torch.ones(current_batch, device=self.device) * (1.0 - step/num_steps)
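                # Note: t sweeps from 1.0 down to 1/num_steps across the loop, i.e. the
                # sample is integrated from pure noise (t = 1) toward data (t = 0).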
# CFG: Generate with condition and without condition
if cfg_scale > 0:
# With condition
vt_cond = self.flow_model(xt, t,
labels=torch.full((current_batch,), condition_label,
device=self.device))
# Without condition (mask)
vt_uncond = self.flow_model(xt, t,
labels=torch.full((current_batch,), 2,
device=self.device))
# CFG interpolation
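                    # v = v_uncond + w * (v_cond - v_uncond): w = 1 recovers the purely
                    # conditional prediction; w > 1 extrapolates away from the unconditional
                    # one for stronger guidance toward the requested label.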
vt = vt_uncond + cfg_scale * (vt_cond - vt_uncond)
else:
# No CFG, use mask label
vt = self.flow_model(xt, t,
labels=torch.full((current_batch,), 2,
device=self.device))
# Euler step for backward integration (t: 1 -> 0)
# Use negative dt to integrate backward from noise to data
dt = -1.0 / num_steps
xt = xt + vt * dt
all_samples.append(xt)
# Concatenate all batches
generated = torch.cat(all_samples, dim=0)
        # Decompress and reverse the normalization applied during training
with torch.no_grad():
# Decompress
decompressed = self.decompressor(generated)
# Apply reverse normalization
m, s, mn, mx = self.stats['mean'], self.stats['std'], self.stats['min'], self.stats['max']
decompressed = decompressed * (mx - mn + 1e-8) + mn
decompressed = decompressed * s + m
return generated, decompressed
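
# Illustrative pipeline usage (a sketch only: `compressor` and `decompressor` are external
# modules not defined in this file, and 'normalization_stats.pt' must exist on disk):
#
#     pipeline = AMPProtFlowPipelineCFG(compressor, decompressor, flow_model, device='cuda')
#     latents, embeddings = pipeline.generate_amps_cfg(num_samples=64, num_steps=25,
#                                                      cfg_scale=7.5, condition_label=0)
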
# Example usage
if __name__ == "__main__":
# Initialize FINAL AMP flow model with CFG using concatenation approach
flow_model = AMPFlowMatcherCFGConcat(
hidden_dim=480,
compressed_dim=30, # 16x compression of 480
n_layers=12,
n_heads=16,
dim_ff=3072,
max_seq_len=25, # For AMP sequences (max 50, halved by pooling)
use_cfg=True
)
print(f"FINAL AMP Flow Model with CFG (Concat+Proj) parameters: {sum(p.numel() for p in flow_model.parameters()):,}")
# Test forward pass
batch_size = 4
seq_len = 20
compressed_dim = 30
x = torch.randn(batch_size, seq_len, compressed_dim)
t = torch.rand(batch_size)
labels = torch.randint(0, 3, (batch_size,)) # Random labels
with torch.no_grad():
output = flow_model(x, t, labels=labels)
print(f"Input shape: {x.shape}")
print(f"Output shape: {output.shape}")
print(f"Time embedding shape: {t.shape}")
print(f"Labels: {labels}")
print("🎯 FINAL AMP Flow Model with CFG (Concat+Proj) ready for training!")