# FlowFinal/src/compressor_with_embeddings.py
import torch
import torch.nn as nn
import torch.optim as optim
from torch.utils.data import Dataset, DataLoader
from torch.optim.lr_scheduler import LinearLR, CosineAnnealingLR, SequentialLR
import glob
import os
from tqdm import tqdm
# ---------------- Hyperparameters ----------------
ESM_DIM = 1280 # ESM-2 hidden dim (esm2_t33_650M_UR50D)
COMP_RATIO = 16 # compression factor
COMP_DIM = ESM_DIM // COMP_RATIO
MAX_SEQ_LEN = 50 # Actual sequence length from final_sequence_encoder.py
BATCH_SIZE = 32
EPOCHS = 30
BASE_LR = 1e-3 # peak learning rate, reached after linear warmup
LR_MIN = 8e-5 # minimum learning rate for cosine schedule
WARMUP_STEPS = 10_000 # warmup length in optimizer steps (one step per batch)
DEPTH = 4 # total transformer layers (2 pre-pool, 2 post-pool)
HEADS = 8 # attention heads
DIM_FF = ESM_DIM * 4
POOLING = True # enforce ProtFlow hourglass pooling
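# With POOLING enabled the autoencoder is an hourglass: the channel dimension
# shrinks 1280 -> 80 (16x) and the sequence length is halved (50 -> 25), so the
# embedding tensor is compressed 32x overall.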
# ---------------- Dataset for Pre-computed Embeddings ----------------
class PrecomputedEmbeddingDataset(Dataset):
def __init__(self, embeddings_path):
"""
Load pre-computed embeddings from the final_sequence_encoder.py output.
Args:
embeddings_path: Path to the directory containing individual .pt embedding files
"""
print(f"Loading pre-computed embeddings from {embeddings_path}...")
# Load all individual embedding files
        # Sort for a deterministic order; "*.pt" already excludes the metadata .json files
        embedding_files = sorted(glob.glob(os.path.join(embeddings_path, "*.pt")))
print(f"Found {len(embedding_files)} embedding files")
# Load and stack all embeddings
embeddings_list = []
for file_path in embedding_files:
try:
                embedding = torch.load(file_path, map_location='cpu')
if embedding.dim() == 2: # (seq_len, hidden_dim)
embeddings_list.append(embedding)
else:
print(f"Warning: Skipping {file_path} - unexpected shape {embedding.shape}")
except Exception as e:
print(f"Warning: Could not load {file_path}: {e}")
if not embeddings_list:
raise ValueError("No valid embeddings found!")
        # torch.stack requires every embedding to share the same (seq_len, hidden_dim) shape
        self.embeddings = torch.stack(embeddings_list)
print(f"Loaded {len(self.embeddings)} embeddings with shape {self.embeddings.shape}")
# Ensure embeddings are the right shape
if len(self.embeddings.shape) != 3:
raise ValueError(f"Expected 3D tensor, got shape {self.embeddings.shape}")
if self.embeddings.shape[1] != MAX_SEQ_LEN:
print(f"Warning: Expected sequence length {MAX_SEQ_LEN}, got {self.embeddings.shape[1]}")
if self.embeddings.shape[2] != ESM_DIM:
print(f"Warning: Expected embedding dim {ESM_DIM}, got {self.embeddings.shape[2]}")
def __len__(self):
return len(self.embeddings)
def __getitem__(self, idx):
return self.embeddings[idx]
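
# Expected on-disk layout (assumed from how the dataset globs its input; the
# file names below are illustrative):
#   <embeddings_path>/
#       peptide_0001.pt   # FloatTensor of shape (MAX_SEQ_LEN, ESM_DIM)
#       peptide_0002.pt
#       ...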
# ---------------- Compressor ----------------
class Compressor(nn.Module):
def __init__(self, in_dim=ESM_DIM, out_dim=COMP_DIM):
super().__init__()
self.norm = nn.LayerNorm(in_dim)
        # nn.TransformerEncoder deep-copies this template layer num_layers times
        def make_layer():
            return nn.TransformerEncoderLayer(
                d_model=in_dim, nhead=HEADS, dim_feedforward=DIM_FF,
                batch_first=True)
        # two layers before pooling, two after
        self.pre_tr = nn.TransformerEncoder(make_layer(), num_layers=DEPTH // 2)
        self.post_tr = nn.TransformerEncoder(make_layer(), num_layers=DEPTH // 2)
self.proj = nn.Sequential(
nn.LayerNorm(in_dim),
nn.Linear(in_dim, out_dim),
nn.Tanh()
)
self.pooling = POOLING
    def forward(self, x, stats=None):
        if stats is not None:
            # Two-stage normalization: standardize and clamp to [-4, 4], then
            # min-max scale the clamped values into [0, 1]
            m, s, mn, mx = stats['mean'], stats['std'], stats['min'], stats['max']
            # Move stats to the same device as x
            m, s = m.to(x.device), s.to(x.device)
            mn, mx = mn.to(x.device), mx.to(x.device)
            x = torch.clamp((x - m) / s, -4, 4)
            x = torch.clamp((x - mn) / (mx - mn + 1e-8), 0, 1)
        x = self.norm(x)
x = self.pre_tr(x) # [B, L, D]
        if self.pooling:
            B, L, D = x.shape
            if L % 2:  # drop the last token so pairs divide evenly
                x = x[:, :-1, :]
            x = x.view(B, L // 2, 2, D).mean(2)  # mean-pool adjacent pairs: halve sequence length
x = self.post_tr(x) # [B, L' , D]
return self.proj(x) # [B, L', COMP_DIM]
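
# Illustrative tensor shapes with the defaults above (B = batch size):
#   input  x : (B, 50, 1280)
#   pooled   : (B, 25, 1280)
#   output z : (B, 25, 80)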
# ---------------- Decompressor ----------------
class Decompressor(nn.Module):
def __init__(self, in_dim=COMP_DIM, out_dim=ESM_DIM):
super().__init__()
self.proj = nn.Sequential(
nn.LayerNorm(in_dim),
nn.Linear(in_dim, out_dim)
)
        def make_layer():
            return nn.TransformerEncoderLayer(
                d_model=out_dim, nhead=HEADS, dim_feedforward=DIM_FF,
                batch_first=True)
        self.decoder = nn.TransformerEncoder(make_layer(), num_layers=DEPTH // 2)
self.pooling = POOLING
    def forward(self, z):
        x = self.proj(z)  # [B, L', D]
        if self.pooling:
            # Unpool by duplicating each position; this recovers the original
            # length exactly only when it was even (MAX_SEQ_LEN = 50 is)
            x = x.repeat_interleave(2, dim=1)
        return self.decoder(x)  # [B, L, out_dim]
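
# A minimal round-trip shape check on random data; this helper is an added
# illustration, not part of the original training pipeline.
def _sanity_check_shapes(device='cpu'):
    x = torch.randn(2, MAX_SEQ_LEN, ESM_DIM, device=device)
    comp, decomp = Compressor().to(device), Decompressor().to(device)
    with torch.no_grad():
        z = comp(x)      # (2, 25, 80) with POOLING enabled
        xr = decomp(z)   # (2, 50, 1280)
    assert z.shape == (2, MAX_SEQ_LEN // 2, COMP_DIM), z.shape
    assert xr.shape == x.shape, (xr.shape, x.shape)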
# ---------------- Training Loop ----------------
def train_with_precomputed_embeddings(embeddings_path, device='cuda'):
"""
Train compressor using pre-computed embeddings from final_sequence_encoder.py
"""
# Load dataset
ds = PrecomputedEmbeddingDataset(embeddings_path)
# Compute normalization statistics
print("Computing normalization statistics...")
    flat = ds.embeddings.view(-1, ESM_DIM)
    mu = flat.mean(0)
    sigma = flat.std(0) + 1e-8
    # min/max are taken over the standardized-and-clamped values, matching the
    # normalization applied in Compressor.forward
    standardized = torch.clamp((flat - mu) / sigma, -4, 4)
    stats = {
        'mean': mu,
        'std': sigma,
        'min': standardized.min(0)[0],
        'max': standardized.max(0)[0],
    }
# Save statistics for later use
torch.save(stats, 'normalization_stats.pt')
print("Saved normalization statistics to normalization_stats.pt")
# Create data loader
dl = DataLoader(ds, batch_size=BATCH_SIZE, shuffle=True)
# Initialize models
comp = Compressor().to(device)
decomp = Decompressor().to(device)
# Initialize optimizer
opt = optim.AdamW(list(comp.parameters()) + list(decomp.parameters()), lr=BASE_LR)
    # LR schedule: linear warmup for WARMUP_STEPS batches, then cosine decay over
    # the remaining steps (the scheduler is stepped once per batch)
    total_steps = EPOCHS * len(dl)
    warmup_sched = LinearLR(opt, start_factor=1e-8, end_factor=1.0, total_iters=WARMUP_STEPS)
    cosine_sched = CosineAnnealingLR(opt, T_max=max(total_steps - WARMUP_STEPS, 1), eta_min=LR_MIN)
    sched = SequentialLR(opt, [warmup_sched, cosine_sched], milestones=[WARMUP_STEPS])
print(f"Starting training for {EPOCHS} epochs...")
print(f"Device: {device}")
print(f"Batch size: {BATCH_SIZE}")
print(f"Total batches per epoch: {len(dl)}")
# Training loop
for epoch in range(1, EPOCHS+1):
total_loss = 0
comp.train()
decomp.train()
for batch_idx, x in enumerate(tqdm(dl, desc=f"Epoch {epoch}/{EPOCHS}")):
x = x.to(device)
z = comp(x, stats)
xr = decomp(z)
loss = (x - xr).pow(2).mean()
opt.zero_grad()
loss.backward()
opt.step()
sched.step()
total_loss += loss.item()
# Print progress every 100 batches
if batch_idx % 100 == 0:
print(f" Batch {batch_idx}/{len(dl)} - Loss: {loss.item():.6f}")
avg_loss = total_loss / len(dl)
print(f"Epoch {epoch}/{EPOCHS} — Average MSE: {avg_loss:.6f}")
# Save checkpoint every 5 epochs
if epoch % 5 == 0:
torch.save({
'epoch': epoch,
'compressor_state_dict': comp.state_dict(),
'decompressor_state_dict': decomp.state_dict(),
'optimizer_state_dict': opt.state_dict(),
'loss': avg_loss,
}, f'checkpoint_epoch_{epoch}.pth')
# Save final models
torch.save(comp.state_dict(), 'compressor_final.pth')
torch.save(decomp.state_dict(), 'decompressor_final.pth')
print("Training completed! Models saved as compressor_final.pth and decompressor_final.pth")
# ---------------- Utility Functions ----------------
def load_and_test_models(compressor_path, decompressor_path, embeddings_path, device='cuda'):
"""
Load trained models and test reconstruction quality
"""
print("Loading trained models...")
comp = Compressor().to(device)
decomp = Decompressor().to(device)
    comp.load_state_dict(torch.load(compressor_path, map_location=device))
    decomp.load_state_dict(torch.load(decompressor_path, map_location=device))
comp.eval()
decomp.eval()
# Load test data
ds = PrecomputedEmbeddingDataset(embeddings_path)
test_loader = DataLoader(ds, batch_size=16, shuffle=False)
    # Load the normalization stats saved during training
    stats = torch.load('normalization_stats.pt', map_location='cpu')
print("Testing reconstruction quality...")
total_mse = 0
total_samples = 0
with torch.no_grad():
for batch in tqdm(test_loader, desc="Testing"):
x = batch.to(device)
z = comp(x, stats)
xr = decomp(z)
mse = (x - xr).pow(2).mean()
total_mse += mse.item() * len(x)
total_samples += len(x)
avg_mse = total_mse / total_samples
print(f"Average reconstruction MSE: {avg_mse:.6f}")
return avg_mse
# ---------------- Entrypoint ----------------
if __name__ == '__main__':
import argparse
parser = argparse.ArgumentParser(description='Train protein compressor with pre-computed embeddings')
    parser.add_argument('--embeddings', type=str, default='/data2/edwardsun/flow_project/compressor_dataset/peptide_embeddings.pt',
                        help='Directory of per-sequence .pt embedding files produced by final_sequence_encoder.py')
parser.add_argument('--device', type=str, default='cuda', help='Device to use (cuda/cpu)')
parser.add_argument('--test', action='store_true', help='Test existing models instead of training')
args = parser.parse_args()
device = torch.device(args.device if torch.cuda.is_available() else 'cpu')
print(f"Using device: {device}")
if args.test:
# Test existing models
load_and_test_models('compressor_final.pth', 'decompressor_final.pth', args.embeddings, device)
else:
# Train new models
train_with_precomputed_embeddings(args.embeddings, device)
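
# Example invocations (the embeddings directory is illustrative):
#   python compressor_with_embeddings.py --embeddings /path/to/embedding_dir
#   python compressor_with_embeddings.py --embeddings /path/to/embedding_dir --test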