import torch
import torch.nn as nn
import torch.nn.functional as F
import math

class SinusoidalTimeEmbedding(nn.Module):
    """Sinusoidal time embedding as used in ProtFlow paper."""
    
    def __init__(self, dim):
        super().__init__()
        self.dim = dim
        
    def forward(self, time):
        device = time.device
        half_dim = self.dim // 2
        freqs = math.log(10000) / (half_dim - 1)
        freqs = torch.exp(torch.arange(half_dim, device=device) * -freqs)
        # Flatten time to (B,) so (B,), (B, 1), or (B, 1, 1) inputs all work,
        # then take the outer product with the frequency vector
        time = time.reshape(-1)
        embeddings = time[:, None] * freqs[None, :]  # (B, half_dim)
        embeddings = torch.cat((embeddings.sin(), embeddings.cos()), dim=-1)  # (B, dim)
        return embeddings
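
# Illustrative usage of the embedding above (shapes only; the batch size and dim
# are arbitrary examples, not fixed by the model):
#
#   emb = SinusoidalTimeEmbedding(dim=480)
#   t = torch.rand(8)            # (B,) flow-matching times in [0, 1]
#   emb(t).shape                 # torch.Size([8, 480]); a (B, 1) input gives the same result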

class LabelMLP(nn.Module):
    """
    MLP for processing class labels into embeddings.
    This approach processes labels separately from time embeddings.
    """
    def __init__(self, num_classes=3, hidden_dim=480, mlp_dim=256):
        super().__init__()
        self.num_classes = num_classes
        
        # MLP to process labels
        self.label_mlp = nn.Sequential(
            nn.Embedding(num_classes, mlp_dim),
            nn.Linear(mlp_dim, mlp_dim),
            nn.GELU(),
            nn.Linear(mlp_dim, hidden_dim),
            nn.GELU(),
            nn.Linear(hidden_dim, hidden_dim)
        )
        
        # Initialize embeddings
        nn.init.normal_(self.label_mlp[0].weight, std=0.02)
    
    def forward(self, labels):
        """
        Args:
            labels: (B,) tensor of class labels
                   - 0: AMP (MIC < 100)
                   - 1: Non-AMP (MIC >= 100)
                   - 2: Mask (Unknown MIC)
        Returns:
            embeddings: (B, hidden_dim) tensor of processed label embeddings
        """
        return self.label_mlp(labels)

class AMPFlowMatcherCFGConcat(nn.Module):
    """
    Flow Matching model with Classifier-Free Guidance using concatenation approach.
    - 12-layer transformer with long skip connections
    - Time embedding + MLP-processed label embedding (concatenated then projected)
    - Optimized for peptide sequences (max length 50)
    """
    
    def __init__(self, hidden_dim=480, compressed_dim=30, n_layers=12, n_heads=16, 
                 dim_ff=3072, dropout=0.1, max_seq_len=25, use_cfg=True):
        super().__init__()
        self.hidden_dim = hidden_dim
        self.compressed_dim = compressed_dim
        self.n_layers = n_layers
        self.max_seq_len = max_seq_len
        self.use_cfg = use_cfg
        
        # Time embedding
        self.time_embed = nn.Sequential(
            SinusoidalTimeEmbedding(hidden_dim),
            nn.Linear(hidden_dim, hidden_dim),
            nn.GELU(),
            nn.Linear(hidden_dim, hidden_dim)
        )
        
        # CFG components using concatenation approach
        if use_cfg:
            self.label_mlp = LabelMLP(num_classes=3, hidden_dim=hidden_dim)
            
            # Projection layer for concatenated time + label embeddings
            self.condition_proj = nn.Sequential(
                nn.Linear(hidden_dim * 2, hidden_dim),  # 2 for time + label
                nn.GELU(),
                nn.Linear(hidden_dim, hidden_dim)
            )
        
        # Projection layers for compressed space
        self.compress_proj = nn.Linear(compressed_dim, hidden_dim)
        self.decompress_proj = nn.Linear(hidden_dim, compressed_dim)
        
        # Positional encoding for peptide sequences
        self.pos_embed = nn.Parameter(torch.randn(1, max_seq_len, hidden_dim))
        
        # Transformer layers with long skip connections
        self.layers = nn.ModuleList([
            nn.TransformerEncoderLayer(
                d_model=hidden_dim,
                nhead=n_heads,
                dim_feedforward=dim_ff,
                dropout=dropout,
                activation='gelu',
                batch_first=True
            ) for _ in range(n_layers)
        ])
        
        # Skip projections (U-ViT-inspired): each intermediate layer adds a
        # projected copy of the previous layer's input features
        self.skip_projections = nn.ModuleList([
            nn.Linear(hidden_dim, hidden_dim) for _ in range(n_layers - 1)
        ])
        
        # Output projection
        self.output_proj = nn.Linear(hidden_dim, compressed_dim)
        
    def forward(self, x, t, labels=None, mask=None):
        """
        Args:
            x: compressed latent (B, L, compressed_dim) - AMP embeddings
            t: time scalar (B,) or (B, 1)
            labels: class labels (B,) for CFG - 0=AMP, 1=Non-AMP, 2=Mask
            mask: key padding mask (B, L); True marks padded positions to ignore
        """
        B, L, D = x.shape
        
        # Project to hidden dimension
        x = self.compress_proj(x)  # (B, L, hidden_dim)
        
        # Add positional encoding
        if L <= self.max_seq_len:
            x = x + self.pos_embed[:, :L, :]
        
        # Time embedding: SinusoidalTimeEmbedding flattens t internally,
        # so (B,) and (B, 1) inputs are both fine
        t_emb = self.time_embed(t)  # (B, hidden_dim)
        t_emb = t_emb.unsqueeze(1).expand(-1, L, -1)  # (B, L, hidden_dim)
        
        # CFG: Process label embedding if enabled
        if self.use_cfg and labels is not None:
            # Process labels through MLP
            label_emb = self.label_mlp(labels)  # (B, hidden_dim)
            label_emb = label_emb.unsqueeze(1).expand(-1, L, -1)  # (B, L, hidden_dim)
            
            # Concatenation approach: join time and label embeddings, then project back to hidden_dim
            combined_emb = torch.cat([t_emb, label_emb], dim=-1)  # (B, L, hidden_dim*2)
            projected_emb = self.condition_proj(combined_emb)  # (B, L, hidden_dim)
        else:
            projected_emb = t_emb  # Just use time embedding if no CFG
        
        # Store intermediate representations for skip connections
        skip_features = []
        
        # Pass through transformer layers with skip connections
        for i, layer in enumerate(self.layers):
            # Add skip connection from earlier layers
            if i > 0 and i < len(self.layers) - 1:
                skip_feat = skip_features[i-1]
                skip_feat = self.skip_projections[i-1](skip_feat)
                x = x + skip_feat
            
            # Store current features for future skip connections
            if i < len(self.layers) - 1:
                skip_features.append(x.clone())
            
            # Add projected condition embedding to EACH layer
            x = x + projected_emb
            
            # Apply transformer layer
            x = layer(x, src_key_padding_mask=mask)
        
        # Project back to compressed dimension
        x = self.output_proj(x)  # (B, L, compressed_dim)
        
        return x
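
# A minimal, hedged training-step sketch for the model above. It follows the
# convention used by the sampler in this file (t = 1 is noise, t = 0 is data,
# linear path x_t = t * noise + (1 - t) * data, so the velocity dx_t/dt is
# noise - data) and implements classifier-free guidance training by randomly
# replacing labels with the mask class (2). Names and hyperparameters here are
# illustrative, not part of the original pipeline.
def cfg_flow_matching_step(model, x_data, labels, optimizer, p_uncond=0.1):
    """One CFG flow-matching step on a batch of compressed latents x_data (B, L, D)."""
    B = x_data.shape[0]
    device = x_data.device

    # Randomly drop conditioning for CFG: replace labels with the mask class (2)
    drop = torch.rand(B, device=device) < p_uncond
    labels = torch.where(drop, torch.full_like(labels, 2), labels)

    # Sample a time per example and build the interpolated point on the linear path
    t = torch.rand(B, device=device)                      # (B,); 1 = noise end, 0 = data end
    noise = torch.randn_like(x_data)
    xt = t[:, None, None] * noise + (1 - t[:, None, None]) * x_data
    target_v = noise - x_data                             # dx_t/dt along the path

    # Regress the predicted velocity onto the target
    pred_v = model(xt, t, labels=labels)
    loss = F.mse_loss(pred_v, target_v)

    optimizer.zero_grad()
    loss.backward()
    optimizer.step()
    return loss.item()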

class AMPProtFlowPipelineCFG:
    """
    Complete ProtFlow pipeline for AMP generation with CFG.
    """
    
    def __init__(self, compressor, decompressor, flow_model, device='cuda'):
        self.compressor = compressor
        self.decompressor = decompressor
        self.flow_model = flow_model
        self.device = device
        
        # Load normalization stats
        self.stats = torch.load('normalization_stats.pt', map_location=device)
        
    def generate_amps_cfg(self, num_samples=100, num_steps=25, cfg_scale=7.5, 
                         condition_label=0):
        """
        Generate AMP samples using CFG.
        
        Args:
            num_samples: Number of samples to generate
            num_steps: Number of ODE solving steps
            cfg_scale: CFG guidance scale (higher = stronger conditioning)
            condition_label: 0=AMP, 1=Non-AMP, 2=Mask
        """
        print(f"Generating {num_samples} samples with CFG (label={condition_label}, scale={cfg_scale})...")
        
        # Sample random noise
        batch_size = min(num_samples, 32)  # Process in batches
        all_samples = []
        
        for i in range(0, num_samples, batch_size):
            current_batch = min(batch_size, num_samples - i)
            
            # Initialize with noise
            eps = torch.randn(current_batch, self.flow_model.max_seq_len, 
                            self.flow_model.compressed_dim, device=self.device)
            
            # ODE solving steps with CFG
            xt = eps.clone()
            for step in range(num_steps):
                t = torch.ones(current_batch, device=self.device) * (1.0 - step/num_steps)
                
                # CFG: Generate with condition and without condition
                if cfg_scale > 0:
                    # With condition
                    vt_cond = self.flow_model(xt, t, 
                                            labels=torch.full((current_batch,), condition_label, 
                                                            device=self.device))
                    
                    # Without condition (mask)
                    vt_uncond = self.flow_model(xt, t, 
                                              labels=torch.full((current_batch,), 2, 
                                                              device=self.device))
                    
                    # CFG interpolation
                    vt = vt_uncond + cfg_scale * (vt_cond - vt_uncond)
                else:
                    # No CFG, use mask label
                    vt = self.flow_model(xt, t, 
                                       labels=torch.full((current_batch,), 2, 
                                                       device=self.device))
                
                # Euler step for backward integration (t: 1 -> 0)
                # Use negative dt to integrate backward from noise to data
                dt = -1.0 / num_steps
                xt = xt + vt * dt
            
            all_samples.append(xt)
        
        # Concatenate all batches
        generated = torch.cat(all_samples, dim=0)
        
        # Decompress and decode
        with torch.no_grad():
            # Decompress
            decompressed = self.decompressor(generated)
            
            # Apply reverse normalization
            m, s, mn, mx = self.stats['mean'], self.stats['std'], self.stats['min'], self.stats['max']
            decompressed = decompressed * (mx - mn + 1e-8) + mn
            decompressed = decompressed * s + m
            
        return generated, decompressed

# Example usage
if __name__ == "__main__":
    # Initialize FINAL AMP flow model with CFG using concatenation approach
    flow_model = AMPFlowMatcherCFGConcat(
        hidden_dim=480,
        compressed_dim=30,  # 16x compression of 480
        n_layers=12,
        n_heads=16,
        dim_ff=3072,
        max_seq_len=25,  # For AMP sequences (max 50, halved by pooling)
        use_cfg=True
    )
    
    print(f"FINAL AMP Flow Model with CFG (Concat+Proj) parameters: {sum(p.numel() for p in flow_model.parameters()):,}")
    
    # Test forward pass
    batch_size = 4
    seq_len = 20
    compressed_dim = 30
    
    x = torch.randn(batch_size, seq_len, compressed_dim)
    t = torch.rand(batch_size)
    labels = torch.randint(0, 3, (batch_size,))  # Random labels
    
    with torch.no_grad():
        output = flow_model(x, t, labels=labels)
        print(f"Input shape: {x.shape}")
        print(f"Output shape: {output.shape}")
        print(f"Time embedding shape: {t.shape}")
        print(f"Labels: {labels}")
    
    print("🎯 FINAL AMP Flow Model with CFG (Concat+Proj) ready for training!")