import glob
import json
import os

import numpy as np
import torch
import torch.nn as nn
import torch.optim as optim
from torch.optim.lr_scheduler import LinearLR, CosineAnnealingLR, SequentialLR
from torch.utils.data import Dataset, DataLoader
from tqdm import tqdm

# Model / training hyperparameters
ESM_DIM = 1280                    # dimensionality of the input ESM embeddings
COMP_RATIO = 16                   # compression ratio along the feature dimension
COMP_DIM = ESM_DIM // COMP_RATIO  # compressed feature dimension (1280 // 16 = 80)
MAX_SEQ_LEN = 50                  # expected sequence length of the pre-computed embeddings
BATCH_SIZE = 32
EPOCHS = 30
BASE_LR = 1e-3
LR_MIN = 8e-5
WARMUP_STEPS = 10_000
DEPTH = 4                         # transformer depth, split between pre- and post-pooling stacks
HEADS = 8
DIM_FF = ESM_DIM * 4
POOLING = True                    # if True, halve the sequence length by mean-pooling adjacent tokens

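# Illustrative note (not part of the original pipeline): with the defaults above,
# a (MAX_SEQ_LEN, ESM_DIM) = (50, 1280) embedding is compressed to roughly
# (MAX_SEQ_LEN // 2, COMP_DIM) = (25, 80) when POOLING is enabled, i.e. the
# feature dimension shrinks by COMP_RATIO (16x) and the length by 2x, for an
# overall ~32x reduction in stored values.
assert COMP_DIM * COMP_RATIO == ESM_DIM, "COMP_RATIO must divide ESM_DIM evenly"
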
class PrecomputedEmbeddingDataset(Dataset):
    def __init__(self, embeddings_path):
        """
        Load pre-computed embeddings produced by final_sequence_encoder.py.

        Args:
            embeddings_path: Path to the directory containing individual .pt embedding files.
        """
        print(f"Loading pre-computed embeddings from {embeddings_path}...")

        embedding_files = glob.glob(os.path.join(embeddings_path, "*.pt"))
        embedding_files = [f for f in embedding_files
                           if not f.endswith('metadata.json') and not f.endswith('sequence_ids.json')]
        print(f"Found {len(embedding_files)} embedding files")

        embeddings_list = []
        for file_path in embedding_files:
            try:
                # Load on CPU so the stacked dataset stays in host memory regardless of
                # which device the embeddings were saved from.
                embedding = torch.load(file_path, map_location='cpu')
                if embedding.dim() == 2:
                    embeddings_list.append(embedding)
                else:
                    print(f"Warning: Skipping {file_path} - unexpected shape {embedding.shape}")
            except Exception as e:
                print(f"Warning: Could not load {file_path}: {e}")

        if not embeddings_list:
            raise ValueError("No valid embeddings found!")

        self.embeddings = torch.stack(embeddings_list)
        print(f"Loaded {len(self.embeddings)} embeddings with shape {self.embeddings.shape}")

        # Sanity-check the stacked tensor: (num_sequences, seq_len, embedding_dim).
        if self.embeddings.dim() != 3:
            raise ValueError(f"Expected 3D tensor, got shape {self.embeddings.shape}")
        if self.embeddings.shape[1] != MAX_SEQ_LEN:
            print(f"Warning: Expected sequence length {MAX_SEQ_LEN}, got {self.embeddings.shape[1]}")
        if self.embeddings.shape[2] != ESM_DIM:
            print(f"Warning: Expected embedding dim {ESM_DIM}, got {self.embeddings.shape[2]}")

    def __len__(self):
        return len(self.embeddings)

    def __getitem__(self, idx):
        return self.embeddings[idx]

class Compressor(nn.Module):
    """Compress per-residue embeddings: transformer encoding, optional pairwise mean-pooling
    along the sequence, and a Tanh-bounded linear projection from in_dim to out_dim."""

    def __init__(self, in_dim=ESM_DIM, out_dim=COMP_DIM):
        super().__init__()
        self.norm = nn.LayerNorm(in_dim)
        layer = lambda: nn.TransformerEncoderLayer(
            d_model=in_dim, nhead=HEADS, dim_feedforward=DIM_FF,
            batch_first=True)

        self.pre_tr = nn.TransformerEncoder(layer(), num_layers=DEPTH // 2)
        self.post_tr = nn.TransformerEncoder(layer(), num_layers=DEPTH // 2)
        self.proj = nn.Sequential(
            nn.LayerNorm(in_dim),
            nn.Linear(in_dim, out_dim),
            nn.Tanh()
        )
        self.pooling = POOLING

    def forward(self, x, stats=None):
        if stats:
            # Per-feature z-score (clamped to +/-4 std), then min-max scaling to [0, 1].
            m, s, mn, mx = stats['mean'], stats['std'], stats['min'], stats['max']
            m = m.to(x.device)
            s = s.to(x.device)
            mn = mn.to(x.device)
            mx = mx.to(x.device)
            x = torch.clamp((x - m) / s, -4, 4)
            x = torch.clamp((x - mn) / (mx - mn + 1e-8), 0, 1)
        x = self.norm(x)
        x = self.pre_tr(x)
        if self.pooling:
            # Halve the sequence length by averaging adjacent token pairs
            # (dropping the final token if the length is odd).
            B, L, D = x.shape
            if L % 2:
                x = x[:, :-1, :]
            x = x.view(B, L // 2, 2, D).mean(2)
        x = self.post_tr(x)
        return self.proj(x)

class Decompressor(nn.Module):
    """Reconstruct full-dimensional embeddings from the compressed codes."""

    def __init__(self, in_dim=COMP_DIM, out_dim=ESM_DIM):
        super().__init__()
        self.proj = nn.Sequential(
            nn.LayerNorm(in_dim),
            nn.Linear(in_dim, out_dim)
        )
        layer = lambda: nn.TransformerEncoderLayer(
            d_model=out_dim, nhead=HEADS, dim_feedforward=DIM_FF,
            batch_first=True)
        self.decoder = nn.TransformerEncoder(layer(), num_layers=DEPTH // 2)
        self.pooling = POOLING

    def forward(self, z):
        x = self.proj(z)
        if self.pooling:
            # Undo the pairwise pooling by duplicating each position along the sequence.
            x = x.repeat_interleave(2, dim=1)
        return self.decoder(x)

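# Illustrative shape check (hypothetical helper, not part of the original training
# pipeline): verifies that a Compressor/Decompressor round trip restores the
# (batch, MAX_SEQ_LEN, ESM_DIM) shape when POOLING is enabled and MAX_SEQ_LEN is even.
def _round_trip_shape_check():
    comp, decomp = Compressor(), Decompressor()
    x = torch.randn(2, MAX_SEQ_LEN, ESM_DIM)    # dummy batch of embeddings
    with torch.no_grad():
        z = comp(x)                             # -> (2, MAX_SEQ_LEN // 2, COMP_DIM)
        xr = decomp(z)                          # -> (2, MAX_SEQ_LEN, ESM_DIM)
    assert z.shape == (2, MAX_SEQ_LEN // 2, COMP_DIM)
    assert xr.shape == x.shape
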
def train_with_precomputed_embeddings(embeddings_path, device='cuda'):
    """
    Train the compressor/decompressor pair on pre-computed embeddings
    produced by final_sequence_encoder.py.
    """
    ds = PrecomputedEmbeddingDataset(embeddings_path)

    # Per-feature normalization statistics over the whole dataset: mean/std for the
    # z-score step, and the min/max of the clamped z-scores for the [0, 1] rescaling.
    print("Computing normalization statistics...")
    flat = ds.embeddings.view(-1, ESM_DIM)
    mean = flat.mean(0)
    std = flat.std(0) + 1e-8
    z = torch.clamp((flat - mean) / std, -4, 4)
    stats = {
        'mean': mean,
        'std': std,
        'min': z.min(0)[0],
        'max': z.max(0)[0],
    }

    torch.save(stats, 'normalization_stats.pt')
    print("Saved normalization statistics to normalization_stats.pt")

    dl = DataLoader(ds, batch_size=BATCH_SIZE, shuffle=True)

    comp = Compressor().to(device)
    decomp = Decompressor().to(device)

    opt = optim.AdamW(list(comp.parameters()) + list(decomp.parameters()), lr=BASE_LR)

    # Linear warmup for the first WARMUP_STEPS optimizer steps, then cosine annealing
    # down to LR_MIN. The scheduler is stepped once per batch.
    warmup_sched = LinearLR(opt, start_factor=1e-8, end_factor=1.0, total_iters=WARMUP_STEPS)
    cosine_sched = CosineAnnealingLR(opt, T_max=EPOCHS * len(dl), eta_min=LR_MIN)
    sched = SequentialLR(opt, [warmup_sched, cosine_sched], milestones=[WARMUP_STEPS])

    print(f"Starting training for {EPOCHS} epochs...")
    print(f"Device: {device}")
    print(f"Batch size: {BATCH_SIZE}")
    print(f"Total batches per epoch: {len(dl)}")

    for epoch in range(1, EPOCHS + 1):
        total_loss = 0
        comp.train()
        decomp.train()

        for batch_idx, x in enumerate(tqdm(dl, desc=f"Epoch {epoch}/{EPOCHS}")):
            x = x.to(device)
            z = comp(x, stats)
            xr = decomp(z)
            # Reconstruction loss: MSE against the raw (unnormalized) embeddings.
            loss = (x - xr).pow(2).mean()

            opt.zero_grad()
            loss.backward()
            opt.step()
            sched.step()

            total_loss += loss.item()

            if batch_idx % 100 == 0:
                print(f"  Batch {batch_idx}/{len(dl)} - Loss: {loss.item():.6f}")

        avg_loss = total_loss / len(dl)
        print(f"Epoch {epoch}/{EPOCHS} - Average MSE: {avg_loss:.6f}")

        # Periodic checkpoint every 5 epochs.
        if epoch % 5 == 0:
            torch.save({
                'epoch': epoch,
                'compressor_state_dict': comp.state_dict(),
                'decompressor_state_dict': decomp.state_dict(),
                'optimizer_state_dict': opt.state_dict(),
                'loss': avg_loss,
            }, f'checkpoint_epoch_{epoch}.pth')

    torch.save(comp.state_dict(), 'compressor_final.pth')
    torch.save(decomp.state_dict(), 'decompressor_final.pth')
    print("Training completed! Models saved as compressor_final.pth and decompressor_final.pth")

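# Hypothetical helper (not in the original script): restore a periodic checkpoint
# written by train_with_precomputed_embeddings(). The key names match the dict
# saved above; the returned epoch and loss can be used to decide where to resume.
def load_checkpoint(path, comp, decomp, opt, device='cuda'):
    ckpt = torch.load(path, map_location=device)
    comp.load_state_dict(ckpt['compressor_state_dict'])
    decomp.load_state_dict(ckpt['decompressor_state_dict'])
    opt.load_state_dict(ckpt['optimizer_state_dict'])
    return ckpt['epoch'], ckpt['loss']
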
def load_and_test_models(compressor_path, decompressor_path, embeddings_path, device='cuda'):
    """
    Load trained models and measure reconstruction quality on the given embeddings.
    """
    print("Loading trained models...")
    comp = Compressor().to(device)
    decomp = Decompressor().to(device)

    comp.load_state_dict(torch.load(compressor_path, map_location=device))
    decomp.load_state_dict(torch.load(decompressor_path, map_location=device))

    comp.eval()
    decomp.eval()

    ds = PrecomputedEmbeddingDataset(embeddings_path)
    test_loader = DataLoader(ds, batch_size=16, shuffle=False)

    # Reuse the normalization statistics saved during training.
    stats = torch.load('normalization_stats.pt')

    print("Testing reconstruction quality...")
    total_mse = 0
    total_samples = 0

    with torch.no_grad():
        for batch in tqdm(test_loader, desc="Testing"):
            x = batch.to(device)
            z = comp(x, stats)
            xr = decomp(z)
            mse = (x - xr).pow(2).mean()
            total_mse += mse.item() * len(x)
            total_samples += len(x)

    avg_mse = total_mse / total_samples
    print(f"Average reconstruction MSE: {avg_mse:.6f}")

    return avg_mse

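# Illustrative inference sketch (hypothetical, not part of the original script):
# compress a single pre-computed embedding with the trained compressor and the
# saved normalization statistics. Assumes 'compressor_final.pth' and
# 'normalization_stats.pt' exist in the working directory.
def compress_one(embedding, device='cuda'):
    comp = Compressor().to(device)
    comp.load_state_dict(torch.load('compressor_final.pth', map_location=device))
    comp.eval()
    stats = torch.load('normalization_stats.pt')
    with torch.no_grad():
        # Add a batch dimension: (MAX_SEQ_LEN, ESM_DIM) -> (1, MAX_SEQ_LEN, ESM_DIM).
        z = comp(embedding.unsqueeze(0).to(device), stats)
    return z.squeeze(0)    # (MAX_SEQ_LEN // 2, COMP_DIM)
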
if __name__ == '__main__':
    import argparse

    parser = argparse.ArgumentParser(description='Train protein compressor with pre-computed embeddings')
    parser.add_argument('--embeddings', type=str,
                        default='/data2/edwardsun/flow_project/compressor_dataset/peptide_embeddings.pt',
                        help='Directory of pre-computed .pt embeddings from final_sequence_encoder.py')
    parser.add_argument('--device', type=str, default='cuda', help='Device to use (cuda/cpu)')
    parser.add_argument('--test', action='store_true', help='Test existing models instead of training')

    args = parser.parse_args()

    device = torch.device(args.device if torch.cuda.is_available() else 'cpu')
    print(f"Using device: {device}")

    if args.test:
        load_and_test_models('compressor_final.pth', 'decompressor_final.pth', args.embeddings, device)
    else:
        train_with_precomputed_embeddings(args.embeddings, device)