esunAI committed on
Commit a1fa8fa · verified · 1 Parent(s): e37da79

Add final_flow_model.py

Files changed (1)
  1. src/final_flow_model.py +310 -0
src/final_flow_model.py ADDED
@@ -0,0 +1,310 @@
+import torch
+import torch.nn as nn
+import torch.nn.functional as F
+import math
+
+class SinusoidalTimeEmbedding(nn.Module):
+    """Sinusoidal time embedding as used in the ProtFlow paper."""
+
+    def __init__(self, dim):
+        super().__init__()
+        self.dim = dim
+
+    def forward(self, time):
+        device = time.device
+        half_dim = self.dim // 2
+        embeddings = math.log(10000) / (half_dim - 1)
+        embeddings = torch.exp(torch.arange(half_dim, device=device) * -embeddings)
+        # Ensure time is 2D: [B, 1] and embeddings is 1D: [half_dim]
+        if time.dim() > 2:
+            time = time.squeeze()  # Remove extra dimensions
+        embeddings = time.unsqueeze(-1) * embeddings.unsqueeze(0)  # [B, half_dim] (or [B, 1, half_dim] if time was [B, 1])
+        embeddings = torch.cat((embeddings.sin(), embeddings.cos()), dim=-1)  # [B, dim]
+        # Ensure output is exactly 2D; squeeze only dim 1 so a batch of size 1 is preserved
+        if embeddings.dim() > 2:
+            embeddings = embeddings.squeeze(1)
+        return embeddings
+
+class LabelMLP(nn.Module):
+    """
+    MLP for processing class labels into embeddings.
+    This approach processes labels separately from time embeddings.
+    """
+    def __init__(self, num_classes=3, hidden_dim=480, mlp_dim=256):
+        super().__init__()
+        self.num_classes = num_classes
+
+        # MLP to process labels
+        self.label_mlp = nn.Sequential(
+            nn.Embedding(num_classes, mlp_dim),
+            nn.Linear(mlp_dim, mlp_dim),
+            nn.GELU(),
+            nn.Linear(mlp_dim, hidden_dim),
+            nn.GELU(),
+            nn.Linear(hidden_dim, hidden_dim)
+        )
+
+        # Initialize embeddings
+        nn.init.normal_(self.label_mlp[0].weight, std=0.02)
+
+    def forward(self, labels):
+        """
+        Args:
+            labels: (B,) tensor of class labels
+                - 0: AMP (MIC < 100)
+                - 1: Non-AMP (MIC >= 100)
+                - 2: Mask (unknown MIC)
+        Returns:
+            embeddings: (B, hidden_dim) tensor of processed label embeddings
+        """
+        return self.label_mlp(labels)
+
+class AMPFlowMatcherCFGConcat(nn.Module):
+    """
+    Flow Matching model with Classifier-Free Guidance using a concatenation approach.
+    - 12-layer transformer with long skip connections
+    - Time embedding + MLP-processed label embedding (concatenated, then projected)
+    - Optimized for peptide sequences (max length 50)
+    """
+
+    def __init__(self, hidden_dim=480, compressed_dim=30, n_layers=12, n_heads=16,
+                 dim_ff=3072, dropout=0.1, max_seq_len=25, use_cfg=True):
+        super().__init__()
+        self.hidden_dim = hidden_dim
+        self.compressed_dim = compressed_dim
+        self.n_layers = n_layers
+        self.max_seq_len = max_seq_len
+        self.use_cfg = use_cfg
+
+        # Time embedding
+        self.time_embed = nn.Sequential(
+            SinusoidalTimeEmbedding(hidden_dim),
+            nn.Linear(hidden_dim, hidden_dim),
+            nn.GELU(),
+            nn.Linear(hidden_dim, hidden_dim)
+        )
+
+        # CFG components using the concatenation approach
+        if use_cfg:
+            self.label_mlp = LabelMLP(num_classes=3, hidden_dim=hidden_dim)
+
+            # Projection layer for concatenated time + label embeddings
+            self.condition_proj = nn.Sequential(
+                nn.Linear(hidden_dim * 2, hidden_dim),  # 2x for time + label
+                nn.GELU(),
+                nn.Linear(hidden_dim, hidden_dim)
+            )
+
+        # Projection layers for the compressed space
+        self.compress_proj = nn.Linear(compressed_dim, hidden_dim)
+        self.decompress_proj = nn.Linear(hidden_dim, compressed_dim)
+
+        # Positional encoding for peptide sequences
+        self.pos_embed = nn.Parameter(torch.randn(1, max_seq_len, hidden_dim))
+
+        # Transformer layers with long skip connections
+        self.layers = nn.ModuleList([
+            nn.TransformerEncoderLayer(
+                d_model=hidden_dim,
+                nhead=n_heads,
+                dim_feedforward=dim_ff,
+                dropout=dropout,
+                activation='gelu',
+                batch_first=True
+            ) for _ in range(n_layers)
+        ])
+
+        # Long skip connections (U-ViT style)
+        self.skip_projections = nn.ModuleList([
+            nn.Linear(hidden_dim, hidden_dim) for _ in range(n_layers - 1)
+        ])
+
+        # Output projection
+        self.output_proj = nn.Linear(hidden_dim, compressed_dim)
+
+    def forward(self, x, t, labels=None, mask=None):
+        """
+        Args:
+            x: compressed latent (B, L, compressed_dim) - AMP embeddings
+            t: time scalar (B,) or (B, 1)
+            labels: class labels (B,) for CFG - 0=AMP, 1=Non-AMP, 2=Mask
+            mask: attention mask (B, L) if needed
+        """
+        B, L, D = x.shape
+
+        # Project to hidden dimension
+        x = self.compress_proj(x)  # (B, L, hidden_dim)
+
+        # Add positional encoding
+        if L <= self.max_seq_len:
+            x = x + self.pos_embed[:, :L, :]
+
+        # Time embedding - ensure t is 2D (B, 1)
+        if t.dim() == 1:
+            t = t.unsqueeze(-1)  # (B, 1)
+        elif t.dim() > 2:
+            t = t.squeeze()  # Remove extra dimensions
+            if t.dim() == 1:
+                t = t.unsqueeze(-1)  # (B, 1)
+
+        t_emb = self.time_embed(t)  # (B, hidden_dim)
+        # Ensure t_emb is 2D before expanding
+        if t_emb.dim() > 2:
+            t_emb = t_emb.squeeze()  # Remove extra dimensions
+        t_emb = t_emb.unsqueeze(1).expand(-1, L, -1)  # (B, L, hidden_dim)
+
+        # CFG: process the label embedding if enabled
+        if self.use_cfg and labels is not None:
+            # Process labels through the MLP
+            label_emb = self.label_mlp(labels)  # (B, hidden_dim)
+            label_emb = label_emb.unsqueeze(1).expand(-1, L, -1)  # (B, L, hidden_dim)
+
+            # Concatenate time and label embeddings, then project back to hidden_dim
+            combined_emb = torch.cat([t_emb, label_emb], dim=-1)  # (B, L, hidden_dim*2)
+            projected_emb = self.condition_proj(combined_emb)  # (B, L, hidden_dim)
+        else:
+            projected_emb = t_emb  # Just use the time embedding if no CFG
+
+        # Store intermediate representations for skip connections
+        skip_features = []
+
+        # Pass through transformer layers with skip connections
+        for i, layer in enumerate(self.layers):
+            # Add skip connection from earlier layers
+            if 0 < i < len(self.layers) - 1:
+                skip_feat = skip_features[i - 1]
+                skip_feat = self.skip_projections[i - 1](skip_feat)
+                x = x + skip_feat
+
+            # Store current features for future skip connections
+            if i < len(self.layers) - 1:
+                skip_features.append(x.clone())
+
+            # Add the projected condition embedding at EACH layer
+            x = x + projected_emb
+
+            # Apply transformer layer
+            x = layer(x, src_key_padding_mask=mask)
+
+        # Project back to the compressed dimension
+        x = self.output_proj(x)  # (B, L, compressed_dim)
+
+        return x
+
+class AMPProtFlowPipelineCFG:
+    """
+    Complete ProtFlow pipeline for AMP generation with CFG.
+    """
+
+    def __init__(self, compressor, decompressor, flow_model, device='cuda'):
+        self.compressor = compressor
+        self.decompressor = decompressor
+        self.flow_model = flow_model
+        self.device = device
+
+        # Load normalization stats
+        self.stats = torch.load('normalization_stats.pt', map_location=device)
+
+    def generate_amps_cfg(self, num_samples=100, num_steps=25, cfg_scale=7.5,
+                          condition_label=0):
+        """
+        Generate AMP samples using CFG.
+
+        Args:
+            num_samples: Number of samples to generate
+            num_steps: Number of ODE solving steps
+            cfg_scale: CFG guidance scale (higher = stronger conditioning)
+            condition_label: 0=AMP, 1=Non-AMP, 2=Mask
+        """
+        print(f"Generating {num_samples} samples with CFG (label={condition_label}, scale={cfg_scale})...")
+
+        # Sample random noise
+        batch_size = min(num_samples, 32)  # Process in batches
+        all_samples = []
+
+        for i in range(0, num_samples, batch_size):
+            current_batch = min(batch_size, num_samples - i)
+
+            # Initialize with noise
+            eps = torch.randn(current_batch, self.flow_model.max_seq_len,
+                              self.flow_model.compressed_dim, device=self.device)
+
+            # ODE solving steps with CFG
+            xt = eps.clone()
+            for step in range(num_steps):
+                t = torch.ones(current_batch, device=self.device) * (1.0 - step / num_steps)
+
+                # CFG: generate with the condition and without the condition
+                if cfg_scale > 0:
+                    # With condition
+                    vt_cond = self.flow_model(xt, t,
+                                              labels=torch.full((current_batch,), condition_label,
+                                                                device=self.device))
+
+                    # Without condition (mask label)
+                    vt_uncond = self.flow_model(xt, t,
+                                                labels=torch.full((current_batch,), 2,
+                                                                  device=self.device))
+
+                    # CFG interpolation
+                    vt = vt_uncond + cfg_scale * (vt_cond - vt_uncond)
+                else:
+                    # No CFG, use the mask label
+                    vt = self.flow_model(xt, t,
+                                         labels=torch.full((current_batch,), 2,
+                                                           device=self.device))
+
+                # Euler step for backward integration (t: 1 -> 0)
+                # Use a negative dt to integrate backward from noise to data
+                dt = -1.0 / num_steps
+                xt = xt + vt * dt
+
+            all_samples.append(xt)
+
+        # Concatenate all batches
+        generated = torch.cat(all_samples, dim=0)
+
+        # Decompress and decode
+        with torch.no_grad():
+            # Decompress
+            decompressed = self.decompressor(generated)
+
+            # Apply reverse normalization
+            m, s, mn, mx = self.stats['mean'], self.stats['std'], self.stats['min'], self.stats['max']
+            decompressed = decompressed * (mx - mn + 1e-8) + mn
+            decompressed = decompressed * s + m
+
+        return generated, decompressed
+
+# Example usage
+if __name__ == "__main__":
+    # Initialize the FINAL AMP flow model with CFG using the concatenation approach
+    flow_model = AMPFlowMatcherCFGConcat(
+        hidden_dim=480,
+        compressed_dim=30,   # 16x compression of 480
+        n_layers=12,
+        n_heads=16,
+        dim_ff=3072,
+        max_seq_len=25,      # For AMP sequences (max 50, halved by pooling)
+        use_cfg=True
+    )
+
+    print(f"FINAL AMP Flow Model with CFG (Concat+Proj) parameters: {sum(p.numel() for p in flow_model.parameters()):,}")
+
+    # Test forward pass
+    batch_size = 4
+    seq_len = 20
+    compressed_dim = 30
+
+    x = torch.randn(batch_size, seq_len, compressed_dim)
+    t = torch.rand(batch_size)
+    labels = torch.randint(0, 3, (batch_size,))  # Random labels
+
+    with torch.no_grad():
+        output = flow_model(x, t, labels=labels)
+    print(f"Input shape: {x.shape}")
+    print(f"Output shape: {output.shape}")
+    print(f"Time embedding shape: {t.shape}")
+    print(f"Labels: {labels}")
+
+    print("🎯 FINAL AMP Flow Model with CFG (Concat+Proj) ready for training!")