esunAI committed
Commit 97eb7cb · verified · 1 Parent(s): a1fa8fa

Add compressor_with_embeddings.py

Files changed (1)
  1. src/compressor_with_embeddings.py +278 -0
src/compressor_with_embeddings.py ADDED
@@ -0,0 +1,278 @@
+ import torch
+ import torch.nn as nn
+ import torch.optim as optim
+ from torch.utils.data import Dataset, DataLoader
+ from torch.optim.lr_scheduler import LinearLR, CosineAnnealingLR, SequentialLR
+ import json
+ import numpy as np
+ from tqdm import tqdm
+
+ # ---------------- Hyperparameters ----------------
+ ESM_DIM = 1280        # ESM-2 hidden dim (esm2_t33_650M_UR50D)
+ COMP_RATIO = 16       # compression factor
+ COMP_DIM = ESM_DIM // COMP_RATIO
+ MAX_SEQ_LEN = 50      # actual sequence length from final_sequence_encoder.py
+ BATCH_SIZE = 32
+ EPOCHS = 30
+ BASE_LR = 1e-3        # initial learning rate
+ LR_MIN = 8e-5         # minimum learning rate for cosine schedule
+ WARMUP_STEPS = 10_000
+ DEPTH = 4             # total transformer layers (2 pre-pool, 2 post-pool)
+ HEADS = 8             # attention heads
+ DIM_FF = ESM_DIM * 4
+ POOLING = True        # enforce ProtFlow hourglass pooling
+
+ # ---------------- Dataset for Pre-computed Embeddings ----------------
+ class PrecomputedEmbeddingDataset(Dataset):
+     def __init__(self, embeddings_path):
+         """
+         Load pre-computed embeddings from the final_sequence_encoder.py output.
+         Args:
+             embeddings_path: Path to the directory containing individual .pt embedding files
+         """
+         print(f"Loading pre-computed embeddings from {embeddings_path}...")
+
+         # Load all individual embedding files
+         import glob
+         import os
+
+         embedding_files = glob.glob(os.path.join(embeddings_path, "*.pt"))
+         embedding_files = [f for f in embedding_files if not f.endswith('metadata.json') and not f.endswith('sequence_ids.json')]
+
+         print(f"Found {len(embedding_files)} embedding files")
+
+         # Load and stack all embeddings
+         embeddings_list = []
+         for file_path in embedding_files:
+             try:
+                 embedding = torch.load(file_path)
+                 if embedding.dim() == 2:  # (seq_len, hidden_dim)
+                     embeddings_list.append(embedding)
+                 else:
+                     print(f"Warning: Skipping {file_path} - unexpected shape {embedding.shape}")
+             except Exception as e:
+                 print(f"Warning: Could not load {file_path}: {e}")
+
+         if not embeddings_list:
+             raise ValueError("No valid embeddings found!")
+
+         self.embeddings = torch.stack(embeddings_list)
+         print(f"Loaded {len(self.embeddings)} embeddings with shape {self.embeddings.shape}")
+
+         # Ensure embeddings are the right shape
+         if len(self.embeddings.shape) != 3:
+             raise ValueError(f"Expected 3D tensor, got shape {self.embeddings.shape}")
+
+         if self.embeddings.shape[1] != MAX_SEQ_LEN:
+             print(f"Warning: Expected sequence length {MAX_SEQ_LEN}, got {self.embeddings.shape[1]}")
+
+         if self.embeddings.shape[2] != ESM_DIM:
+             print(f"Warning: Expected embedding dim {ESM_DIM}, got {self.embeddings.shape[2]}")
+
+     def __len__(self):
+         return len(self.embeddings)
+
+     def __getitem__(self, idx):
+         return self.embeddings[idx]
+
+ # ---------------- Compressor ----------------
+ class Compressor(nn.Module):
+     def __init__(self, in_dim=ESM_DIM, out_dim=COMP_DIM):
+         super().__init__()
+         self.norm = nn.LayerNorm(in_dim)
+         layer = lambda: nn.TransformerEncoderLayer(
+             d_model=in_dim, nhead=HEADS, dim_feedforward=DIM_FF,
+             batch_first=True)
+         # two layers before pool, two after
+         self.pre_tr = nn.TransformerEncoder(layer(), num_layers=DEPTH//2)
+         self.post_tr = nn.TransformerEncoder(layer(), num_layers=DEPTH//2)
+         self.proj = nn.Sequential(
+             nn.LayerNorm(in_dim),
+             nn.Linear(in_dim, out_dim),
+             nn.Tanh()
+         )
+         self.pooling = POOLING
+
+     def forward(self, x, stats=None):
+         if stats:
+             m, s, mn, mx = stats['mean'], stats['std'], stats['min'], stats['max']
+             # Move stats to the same device as x
+             m = m.to(x.device)
+             s = s.to(x.device)
+             mn = mn.to(x.device)
+             mx = mx.to(x.device)
+             x = torch.clamp((x - m) / s, -4, 4)
+             x = torch.clamp((x - mn) / (mx - mn + 1e-8), 0, 1)
+         x = self.norm(x)
+         x = self.pre_tr(x)                    # [B, L, D]
+         if self.pooling:
+             B, L, D = x.shape
+             if L % 2: x = x[:, :-1, :]
+             x = x.view(B, L//2, 2, D).mean(2)  # halve sequence length
+         x = self.post_tr(x)                   # [B, L', D]
+         return self.proj(x)                   # [B, L', COMP_DIM]
+
+ # ---------------- Decompressor ----------------
+ class Decompressor(nn.Module):
+     def __init__(self, in_dim=COMP_DIM, out_dim=ESM_DIM):
+         super().__init__()
+         self.proj = nn.Sequential(
+             nn.LayerNorm(in_dim),
+             nn.Linear(in_dim, out_dim)
+         )
+         layer = lambda: nn.TransformerEncoderLayer(
+             d_model=out_dim, nhead=HEADS, dim_feedforward=DIM_FF,
+             batch_first=True)
+         self.decoder = nn.TransformerEncoder(layer(), num_layers=DEPTH//2)
+         self.pooling = POOLING
+
+     def forward(self, z):
+         x = self.proj(z)                       # [B, L', D]
+         if self.pooling:
+             x = x.repeat_interleave(2, dim=1)  # unpool to full length
+         return self.decoder(x)                 # [B, L, out_dim]
+
+ # ---------------- Training Loop ----------------
+ def train_with_precomputed_embeddings(embeddings_path, device='cuda'):
+     """
+     Train compressor using pre-computed embeddings from final_sequence_encoder.py
+     """
+     # Load dataset
+     ds = PrecomputedEmbeddingDataset(embeddings_path)
+
+     # Compute normalization statistics
+     print("Computing normalization statistics...")
+     flat = ds.embeddings.view(-1, ESM_DIM)
+     stats = {
+         'mean': flat.mean(0),
+         'std': flat.std(0) + 1e-8,
+         'min': torch.clamp((flat - flat.mean(0)) / (flat.std(0) + 1e-8), -4, 4).min(0)[0],
+         'max': torch.clamp((flat - flat.mean(0)) / (flat.std(0) + 1e-8), -4, 4).max(0)[0]
+     }
+
+     # Save statistics for later use
+     torch.save(stats, 'normalization_stats.pt')
+     print("Saved normalization statistics to normalization_stats.pt")
+
+     # Create data loader
+     dl = DataLoader(ds, batch_size=BATCH_SIZE, shuffle=True)
+
+     # Initialize models
+     comp = Compressor().to(device)
+     decomp = Decompressor().to(device)
+
+     # Initialize optimizer
+     opt = optim.AdamW(list(comp.parameters()) + list(decomp.parameters()), lr=BASE_LR)
+
+     # LR scheduling: warmup -> cosine
+     warmup_sched = LinearLR(opt, start_factor=1e-8, end_factor=1.0, total_iters=WARMUP_STEPS)
+     cosine_sched = CosineAnnealingLR(opt, T_max=EPOCHS*len(dl), eta_min=LR_MIN)
+     sched = SequentialLR(opt, [warmup_sched, cosine_sched], milestones=[WARMUP_STEPS])
+
+     print(f"Starting training for {EPOCHS} epochs...")
+     print(f"Device: {device}")
+     print(f"Batch size: {BATCH_SIZE}")
+     print(f"Total batches per epoch: {len(dl)}")
+
+     # Training loop
+     for epoch in range(1, EPOCHS+1):
+         total_loss = 0
+         comp.train()
+         decomp.train()
+
+         for batch_idx, x in enumerate(tqdm(dl, desc=f"Epoch {epoch}/{EPOCHS}")):
+             x = x.to(device)
+             z = comp(x, stats)
+             xr = decomp(z)
+             loss = (x - xr).pow(2).mean()
+
+             opt.zero_grad()
+             loss.backward()
+             opt.step()
+             sched.step()
+
+             total_loss += loss.item()
+
+             # Print progress every 100 batches
+             if batch_idx % 100 == 0:
+                 print(f"  Batch {batch_idx}/{len(dl)} - Loss: {loss.item():.6f}")
+
+         avg_loss = total_loss / len(dl)
+         print(f"Epoch {epoch}/{EPOCHS} — Average MSE: {avg_loss:.6f}")
+
+         # Save checkpoint every 5 epochs
+         if epoch % 5 == 0:
+             torch.save({
+                 'epoch': epoch,
+                 'compressor_state_dict': comp.state_dict(),
+                 'decompressor_state_dict': decomp.state_dict(),
+                 'optimizer_state_dict': opt.state_dict(),
+                 'loss': avg_loss,
+             }, f'checkpoint_epoch_{epoch}.pth')
+
+     # Save final models
+     torch.save(comp.state_dict(), 'compressor_final.pth')
+     torch.save(decomp.state_dict(), 'decompressor_final.pth')
+     print("Training completed! Models saved as compressor_final.pth and decompressor_final.pth")
+
+ # ---------------- Utility Functions ----------------
+ def load_and_test_models(compressor_path, decompressor_path, embeddings_path, device='cuda'):
+     """
+     Load trained models and test reconstruction quality
+     """
+     print("Loading trained models...")
+     comp = Compressor().to(device)
+     decomp = Decompressor().to(device)
+
+     comp.load_state_dict(torch.load(compressor_path))
+     decomp.load_state_dict(torch.load(decompressor_path))
+
+     comp.eval()
+     decomp.eval()
+
+     # Load test data
+     ds = PrecomputedEmbeddingDataset(embeddings_path)
+     test_loader = DataLoader(ds, batch_size=16, shuffle=False)
+
+     # Load normalization stats
+     stats = torch.load('normalization_stats.pt')
+
+     print("Testing reconstruction quality...")
+     total_mse = 0
+     total_samples = 0
+
+     with torch.no_grad():
+         for batch in tqdm(test_loader, desc="Testing"):
+             x = batch.to(device)
+             z = comp(x, stats)
+             xr = decomp(z)
+             mse = (x - xr).pow(2).mean()
+             total_mse += mse.item() * len(x)
+             total_samples += len(x)
+
+     avg_mse = total_mse / total_samples
+     print(f"Average reconstruction MSE: {avg_mse:.6f}")
+
+     return avg_mse
+
+ # ---------------- Entrypoint ----------------
+ if __name__ == '__main__':
+     import argparse
+
+     parser = argparse.ArgumentParser(description='Train protein compressor with pre-computed embeddings')
+     parser.add_argument('--embeddings', type=str, default='/data2/edwardsun/flow_project/compressor_dataset/peptide_embeddings.pt',
+                         help='Path to pre-computed embeddings from final_sequence_encoder.py')
+     parser.add_argument('--device', type=str, default='cuda', help='Device to use (cuda/cpu)')
+     parser.add_argument('--test', action='store_true', help='Test existing models instead of training')
+
+     args = parser.parse_args()
+
+     device = torch.device(args.device if torch.cuda.is_available() else 'cpu')
+     print(f"Using device: {device}")
+
+     if args.test:
+         # Test existing models
+         load_and_test_models('compressor_final.pth', 'decompressor_final.pth', args.embeddings, device)
+     else:
+         # Train new models
+         train_with_precomputed_embeddings(args.embeddings, device)
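
A minimal usage sketch of the committed module, assuming a training run has already produced compressor_final.pth and normalization_stats.pt in the working directory; the file some_peptide.pt is a placeholder for one per-sequence embedding written by final_sequence_encoder.py, and the import path assumes src/ is on PYTHONPATH:

# Compress one pre-computed ESM-2 embedding with the trained Compressor.
# Assumptions: compressor_final.pth and normalization_stats.pt exist (see training
# loop above); 'some_peptide.pt' is a hypothetical (seq_len, 1280) embedding file.
import torch
from compressor_with_embeddings import Compressor

device = 'cuda' if torch.cuda.is_available() else 'cpu'

comp = Compressor().to(device)
comp.load_state_dict(torch.load('compressor_final.pth', map_location=device))
comp.eval()

stats = torch.load('normalization_stats.pt')          # per-dim mean/std/min/max used at train time
emb = torch.load('some_peptide.pt')                   # (seq_len, 1280) per-residue embedding

with torch.no_grad():
    z = comp(emb.unsqueeze(0).to(device), stats)      # (1, seq_len // 2, 80): hourglass pooling halves
                                                      # the length, projection maps 1280 -> 80 dims
print(z.shape)

With the hyperparameters in this commit (MAX_SEQ_LEN=50, COMP_RATIO=16), a 50 x 1280 embedding compresses to a 25 x 80 latent, a 32x reduction in stored values.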