import math

import torch
import torch.nn as nn
import torch.nn.functional as F


class SinusoidalTimeEmbedding(nn.Module):
    """Sinusoidal time embedding as used in the ProtFlow paper."""

    def __init__(self, dim):
        super().__init__()
        self.dim = dim

    def forward(self, time):
        device = time.device
        half_dim = self.dim // 2

        # Standard transformer-style geometric frequency schedule.
        freqs = math.log(10000) / (half_dim - 1)
        freqs = torch.exp(torch.arange(half_dim, device=device) * -freqs)

        # Flatten time to (B,) so the outer product below yields (B, half_dim);
        # this also preserves the batch dimension when B == 1.
        if time.dim() > 1:
            time = time.reshape(time.shape[0])

        embeddings = time.unsqueeze(-1) * freqs.unsqueeze(0)
        embeddings = torch.cat((embeddings.sin(), embeddings.cos()), dim=-1)
        return embeddings
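

# Illustrative shape check (added helper, not referenced elsewhere in this
# module): a batch of scalar times in [0, 1] should map to a (batch, dim)
# embedding made of concatenated sin/cos frequency features.
def _example_time_embedding(dim=480, batch_size=4):
    time_embed = SinusoidalTimeEmbedding(dim)
    t = torch.rand(batch_size)              # (B,) times in [0, 1]
    emb = time_embed(t)                     # (B, dim)
    assert emb.shape == (batch_size, dim)
    return emb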


class LabelMLP(nn.Module):
    """
    MLP for processing class labels into embeddings.

    Labels are processed separately from the time embedding; the two are
    combined later (concatenated and projected) inside the flow model.
    """

    def __init__(self, num_classes=3, hidden_dim=480, mlp_dim=256):
        super().__init__()
        self.num_classes = num_classes

        # Embedding lookup followed by a small MLP that lifts the label
        # representation to the transformer hidden size.
        self.label_mlp = nn.Sequential(
            nn.Embedding(num_classes, mlp_dim),
            nn.Linear(mlp_dim, mlp_dim),
            nn.GELU(),
            nn.Linear(mlp_dim, hidden_dim),
            nn.GELU(),
            nn.Linear(hidden_dim, hidden_dim)
        )

        # Small-std initialization for the embedding table.
        nn.init.normal_(self.label_mlp[0].weight, std=0.02)

    def forward(self, labels):
        """
        Args:
            labels: (B,) tensor of class labels
                - 0: AMP (MIC < 100)
                - 1: Non-AMP (MIC >= 100)
                - 2: Mask (unknown MIC, used as the unconditional class)
        Returns:
            embeddings: (B, hidden_dim) tensor of processed label embeddings
        """
        return self.label_mlp(labels)
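

# Illustrative usage sketch (added helper, not referenced elsewhere): each
# class index 0 (AMP), 1 (Non-AMP), and 2 (mask / unconditional) is mapped to
# a hidden_dim-sized conditioning vector.
def _example_label_mlp(hidden_dim=480):
    label_mlp = LabelMLP(num_classes=3, hidden_dim=hidden_dim)
    labels = torch.tensor([0, 1, 2])        # one example of each class
    emb = label_mlp(labels)                 # (3, hidden_dim)
    assert emb.shape == (3, hidden_dim)
    return emb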


class AMPFlowMatcherCFGConcat(nn.Module):
    """
    Flow Matching model with Classifier-Free Guidance using a concatenation approach.
    - 12-layer transformer encoder with inter-block skip connections
    - Time embedding + MLP-processed label embedding (concatenated, then projected)
    - Optimized for peptide sequences (max length 50)
    """

    def __init__(self, hidden_dim=480, compressed_dim=30, n_layers=12, n_heads=16,
                 dim_ff=3072, dropout=0.1, max_seq_len=25, use_cfg=True):
        super().__init__()
        self.hidden_dim = hidden_dim
        self.compressed_dim = compressed_dim
        self.n_layers = n_layers
        self.max_seq_len = max_seq_len
        self.use_cfg = use_cfg

        # Time embedding: sinusoidal features refined by a small MLP.
        self.time_embed = nn.Sequential(
            SinusoidalTimeEmbedding(hidden_dim),
            nn.Linear(hidden_dim, hidden_dim),
            nn.GELU(),
            nn.Linear(hidden_dim, hidden_dim)
        )

        if use_cfg:
            # Label embedding (0=AMP, 1=Non-AMP, 2=Mask/unconditional).
            self.label_mlp = LabelMLP(num_classes=3, hidden_dim=hidden_dim)

            # Projects the concatenated [time, label] embedding back to hidden_dim.
            self.condition_proj = nn.Sequential(
                nn.Linear(hidden_dim * 2, hidden_dim),
                nn.GELU(),
                nn.Linear(hidden_dim, hidden_dim)
            )

        # Projections between the compressed latent space and the transformer width.
        self.compress_proj = nn.Linear(compressed_dim, hidden_dim)
        self.decompress_proj = nn.Linear(hidden_dim, compressed_dim)

        # Learned positional embedding for up to max_seq_len positions.
        self.pos_embed = nn.Parameter(torch.randn(1, max_seq_len, hidden_dim))

        # Transformer encoder stack.
        self.layers = nn.ModuleList([
            nn.TransformerEncoderLayer(
                d_model=hidden_dim,
                nhead=n_heads,
                dim_feedforward=dim_ff,
                dropout=dropout,
                activation='gelu',
                batch_first=True
            ) for _ in range(n_layers)
        ])

        # Linear projections applied to features carried across blocks.
        self.skip_projections = nn.ModuleList([
            nn.Linear(hidden_dim, hidden_dim) for _ in range(n_layers - 1)
        ])

        self.output_proj = nn.Linear(hidden_dim, compressed_dim)

    def forward(self, x, t, labels=None, mask=None):
        """
        Args:
            x: compressed latent (B, L, compressed_dim) - AMP embeddings
            t: time scalar (B,) or (B, 1)
            labels: class labels (B,) for CFG - 0=AMP, 1=Non-AMP, 2=Mask
            mask: key padding mask (B, L), True at padded positions, if needed
        Returns:
            Predicted velocity of shape (B, L, compressed_dim).
        """
        B, L, D = x.shape

        # Lift the compressed latent to the transformer width.
        x = self.compress_proj(x)

        # Add learned positional embeddings (defined only up to max_seq_len).
        if L <= self.max_seq_len:
            x = x + self.pos_embed[:, :L, :]

        # Normalize the time tensor to shape (B, 1).
        if t.dim() == 1:
            t = t.unsqueeze(-1)
        elif t.dim() > 2:
            t = t.reshape(B, -1)[:, :1]

        # Time embedding, broadcast across the sequence dimension.
        t_emb = self.time_embed(t)                      # (B, hidden_dim)
        t_emb = t_emb.unsqueeze(1).expand(-1, L, -1)    # (B, L, hidden_dim)

        if self.use_cfg and labels is not None:
            # Label embedding, broadcast across the sequence dimension.
            label_emb = self.label_mlp(labels)
            label_emb = label_emb.unsqueeze(1).expand(-1, L, -1)

            # Concatenate time and label embeddings, then project to hidden_dim.
            combined_emb = torch.cat([t_emb, label_emb], dim=-1)
            projected_emb = self.condition_proj(combined_emb)
        else:
            # Unconditional path: time embedding only.
            projected_emb = t_emb

        skip_features = []

        for i, layer in enumerate(self.layers):
            # Skip connection: add back the projected pre-conditioning input of
            # the previous block (the first and last blocks receive no skip).
            if 0 < i < len(self.layers) - 1:
                skip_feat = self.skip_projections[i - 1](skip_features[i - 1])
                x = x + skip_feat

            if i < len(self.layers) - 1:
                skip_features.append(x.clone())

            # Inject the conditioning signal before every block.
            x = x + projected_emb

            x = layer(x, src_key_padding_mask=mask)

        # Project back to the compressed latent dimension.
        x = self.output_proj(x)

        return x
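

# Illustrative sketch of a conditional forward pass with a padding mask (added
# helper, not referenced elsewhere). The mask follows the
# nn.TransformerEncoderLayer convention for src_key_padding_mask: True marks
# positions that attention should ignore.
def _example_masked_forward(model, batch_size=2, seq_len=20):
    x = torch.randn(batch_size, seq_len, model.compressed_dim)
    t = torch.rand(batch_size)
    labels = torch.zeros(batch_size, dtype=torch.long)        # 0 = AMP condition
    mask = torch.zeros(batch_size, seq_len, dtype=torch.bool)
    mask[:, seq_len - 5:] = True                               # treat the last 5 positions as padding
    with torch.no_grad():
        return model(x, t, labels=labels, mask=mask)           # (B, L, compressed_dim)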


class AMPProtFlowPipelineCFG:
    """
    Complete ProtFlow pipeline for AMP generation with CFG.
    """

    def __init__(self, compressor, decompressor, flow_model, device='cuda'):
        self.compressor = compressor
        self.decompressor = decompressor
        self.flow_model = flow_model
        self.device = device

        # Statistics used to undo the latent normalization applied during training.
        self.stats = torch.load('normalization_stats.pt', map_location=device)

    def generate_amps_cfg(self, num_samples=100, num_steps=25, cfg_scale=7.5,
                          condition_label=0):
        """
        Generate AMP samples using CFG.

        Args:
            num_samples: Number of samples to generate
            num_steps: Number of ODE solving steps
            cfg_scale: CFG guidance scale (higher = stronger conditioning)
            condition_label: 0=AMP, 1=Non-AMP, 2=Mask
        """
        print(f"Generating {num_samples} samples with CFG (label={condition_label}, scale={cfg_scale})...")

        batch_size = min(num_samples, 32)
        all_samples = []

        with torch.no_grad():
            for i in range(0, num_samples, batch_size):
                current_batch = min(batch_size, num_samples - i)

                # Start from Gaussian noise in the compressed latent space.
                eps = torch.randn(current_batch, self.flow_model.max_seq_len,
                                  self.flow_model.compressed_dim, device=self.device)

                # Integrate the flow ODE from t=1 (noise) to t=0 (data) with Euler steps.
                xt = eps.clone()
                for step in range(num_steps):
                    t = torch.ones(current_batch, device=self.device) * (1.0 - step / num_steps)

                    if cfg_scale > 0:
                        # Conditional velocity for the requested class.
                        vt_cond = self.flow_model(
                            xt, t,
                            labels=torch.full((current_batch,), condition_label,
                                              device=self.device))

                        # Unconditional velocity (label 2 = mask class).
                        vt_uncond = self.flow_model(
                            xt, t,
                            labels=torch.full((current_batch,), 2,
                                              device=self.device))

                        # Classifier-free guidance: move the velocity from the
                        # unconditional prediction toward the conditional one.
                        vt = vt_uncond + cfg_scale * (vt_cond - vt_uncond)
                    else:
                        # No guidance: use the unconditional (mask) prediction.
                        vt = self.flow_model(
                            xt, t,
                            labels=torch.full((current_batch,), 2,
                                              device=self.device))

                    # Euler step backwards in time.
                    dt = -1.0 / num_steps
                    xt = xt + vt * dt

                all_samples.append(xt)

        generated = torch.cat(all_samples, dim=0)

        with torch.no_grad():
            # Decompress the generated latents back to the full embedding dimension.
            decompressed = self.decompressor(generated)

            # Undo the normalization: first the min-max scaling, then the standardization.
            m, s, mn, mx = self.stats['mean'], self.stats['std'], self.stats['min'], self.stats['max']
            decompressed = decompressed * (mx - mn + 1e-8) + mn
            decompressed = decompressed * s + m

        return generated, decompressed
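

# Illustrative usage sketch (commented out): assumes trained `compressor`,
# `decompressor`, and `flow_model` objects plus a 'normalization_stats.pt'
# file are available; the variable names are placeholders, not part of this module.
#
#   pipeline = AMPProtFlowPipelineCFG(compressor, decompressor, flow_model, device='cuda')
#   latents, embeddings = pipeline.generate_amps_cfg(
#       num_samples=16, num_steps=25, cfg_scale=7.5, condition_label=0)
#   # `latents` holds the sampled compressed representations; `embeddings` the
#   # decompressed, denormalized outputs.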


if __name__ == "__main__":
    # Instantiate the flow model with the default configuration.
    flow_model = AMPFlowMatcherCFGConcat(
        hidden_dim=480,
        compressed_dim=30,
        n_layers=12,
        n_heads=16,
        dim_ff=3072,
        max_seq_len=25,
        use_cfg=True
    )

    print(f"FINAL AMP Flow Model with CFG (Concat+Proj) parameters: {sum(p.numel() for p in flow_model.parameters()):,}")

    # Smoke test on random inputs.
    batch_size = 4
    seq_len = 20
    compressed_dim = 30

    x = torch.randn(batch_size, seq_len, compressed_dim)
    t = torch.rand(batch_size)
    labels = torch.randint(0, 3, (batch_size,))

    with torch.no_grad():
        output = flow_model(x, t, labels=labels)
    print(f"Input shape: {x.shape}")
    print(f"Output shape: {output.shape}")
    print(f"Time tensor shape: {t.shape}")
    print(f"Labels: {labels}")

    print("🎯 FINAL AMP Flow Model with CFG (Concat+Proj) ready for training!")