adding code
Shrey Goel committed · Commit d04a061
Files changed:
- .DS_Store +0 -0
- .gitignore +24 -0
- README.md +1 -0
- __init__.py +0 -0
- setup.py +10 -0
- src/configs/guidance.yaml +74 -0
- src/configs/lm.yaml +66 -0
- src/guidance/dataloader.py +108 -0
- src/guidance/main.py +73 -0
- src/guidance/solubility_module.py +155 -0
- src/guidance/utils.py +21 -0
- src/lm/memdlm/dataloader.py +88 -0
- src/lm/memdlm/diffusion_module.py +189 -0
- src/lm/memdlm/loss.py +42 -0
- src/lm/memdlm/main.py +88 -0
- src/sampling/guided_generator.py +90 -0
- src/sampling/guided_sampler.py +291 -0
- src/sampling/unconditional_generator.py +114 -0
- src/sampling/unconditional_sampler.py +108 -0
- src/utils/generate_utils.py +151 -0
- src/utils/model_utils.py +139 -0
- src/utils/optimizer_utils.py +300 -0
.DS_Store
ADDED
Binary file (6.15 kB).
.gitignore
ADDED
@@ -0,0 +1,24 @@
# .gitignore

/checkpoints/
/data/
/results/
/build/
/src/scripts/
/src/benchmarks

/src/lm/dplm
/src/lm/evodiff
/src/lm/dplm_playground.ipynb
/src/lm/evoflow_playground.ipynb
/src/utils/ubuntu_font

/src/sampling/old_guidance.py

/MeMDLM_v2.egg-info/
*.pth
*.ckpt
*.err
*.out
*.csv
__pycache__/
README.md
ADDED
@@ -0,0 +1 @@
# MeMDLM_v2
__init__.py
ADDED
File without changes
setup.py
ADDED
@@ -0,0 +1,10 @@
from setuptools import setup, find_packages

setup(
    name='MeMDLM_v2',
    version='1.0',
    packages=find_packages(),
    install_requires=[],
    author='Shrey Goel',
    author_email='[email protected]'
)
src/configs/guidance.yaml
ADDED
@@ -0,0 +1,74 @@
seed: 42
base_dir: /scratch/sgoel/MeMDLM_v2


lm:
  pretrained_esm: facebook/esm2_t33_650M_UR50D
  pretrained_evoflow: fredzzp/EvoFlow-650M-context-3070
  pretrained_dplm: airkingbd/dplm_650m
  ft_evoflow: ft_eflow-3070-650M_steps=50k_layers=3_lr=0.00004_wd=.01_polynom_pwr=1_betas=.9-.98_bsz=8_gclip=1.0
  ft_dplm: ft_dplm-650M_steps=5k_layers=3_lr=0.00004_wd=.01_polynom_pwr=1_betas=.9-.98_bsz=32_gclip=1.0

model:
  d_model: 1280
  num_heads: 2
  dropout: 0.5
  num_layers: 4
  label_pad_value: -100

optim:
  type: adamw
  lr: 3e-5
  lr_end: 1e-5
  weight_decay: 0.01
  beta1: 0.9
  beta2: 0.98
  power: 1


training:
  mode: test # train / test
  n_layers: 4
  max_steps: 3000
  warmup_steps: 150
  log_every_n_steps: 10
  num_sanity_val_steps: 2
  val_check_interval: 250
  enable_progress_bar: true
  grad_clip_val: 1.0
  devices: [0] # list of GPU IDs from 0-7

guidance:
  n_steps: 128
  alpha: 3
  gamma: 0.3
  saliency_eps: 1e-4
  saliency_t: 2.0
  sampling_t: 0.7
  boltzmann_t: 0.3
  top_p: 0.2
  steps: 128
  prior: lm_probs # lm_probs / boltzmann

data:
  batch_size: 32
  max_seq_len: 1024
  train: ${base_dir}/data/classifier/train.csv
  test: ${base_dir}/data/classifier/test.csv
  val: ${base_dir}/data/classifier/val.csv


wandb:
  project: memdlm_guidance
  group: programmablebio
  name: new_data_cleaned_steps3k_lr3e-5_bsz32_heads2_drpt0.5_layers4
  id: ${.name}_${seed}


checkpointing:
  save_every_n_steps: 250
  save_dir: ${base_dir}/checkpoints/${wandb.name}
  resume_ckpt_path: ${checkpointing.save_dir}/last.ckpt
  best_ckpt_path: ${checkpointing.save_dir}/best_model.ckpt
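For reference, a minimal sketch (not part of the commit) of how a config like this is consumed: OmegaConf resolves the ${...} interpolations lazily on access, so ${base_dir}, ${wandb.name}, and the relative ${.name} all expand to concrete values. The load path below is an assumption; point it at wherever the file actually lives.

from omegaconf import OmegaConf

# Hypothetical local path to the config shown above
config = OmegaConf.load("src/configs/guidance.yaml")

print(config.model.d_model)           # 1280
print(config.data.train)              # /scratch/sgoel/MeMDLM_v2/data/classifier/train.csv
print(config.checkpointing.save_dir)  # /scratch/sgoel/MeMDLM_v2/checkpoints/<wandb.name>
print(config.wandb.id)                # <wandb.name>_42, from ${.name}_${seed}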
src/configs/lm.yaml
ADDED
@@ -0,0 +1,66 @@
seed: 42
base_dir: /scratch/pranamlab/sgoel/MeMDLM_v2


lm:
  pretrained_esm: facebook/esm2_t33_650M_UR50D
  pretrained_evoflow: fredzzp/EvoFlow-650M-context-3070
  pretrained_dplm: airkingbd/dplm_650m
  pretrained_progen: hugohrban/progen2-base
  num_diffusion_timesteps: 500
  weight_type: linear # constant / linear


optim:
  type: adamw
  scheduler: polynomial
  lr: 0.00004
  lr_end: 1e-5
  warmup_init_lr: 1e-07
  weight_decay: 0.01
  beta1: 0.9
  beta2: 0.98
  power: 1


training:
  mode: train # train / test / resume_from_checkpoint
  n_layers: 3
  max_steps: 5000
  warmup_steps: 25
  log_every_n_steps: 10
  num_sanity_val_steps: 2
  val_check_interval: 250
  enable_progress_bar: true
  grad_clip_val: 1.0
  devices: [0,1,2] # list of GPU IDs

sampling:
  n_steps: 128


data:
  batch_size: 8
  max_seq_len: 1024
  train: ${base_dir}/data/new/train.csv
  test: ${base_dir}/data/new/test.csv
  val: ${base_dir}/data/new/val.csv


wandb:
  project: memdlm
  group: programmablebio
  name: ft_eflow-3070-650M_steps=5k_layers=3_lr=0.00004_wd=.01_polynom_pwr=1_betas=.9-.98_bsz=8_gclip=1.0_ml=1024
  # name: ft_progen-base-764M_steps=50k_layers=2_lr=0.00004_wd=.1_cosine-to-frac_betas=.9-.999_bsz=8_gclip=0.8
  # name: ft_dplm-650M_steps=5k_layers=3_lr=0.00004_wd=.01_polynom_pwr=1_betas=.9-.98_bsz=32_gclip=1.0
  # name: ft_esm-650M_steps=3k_layers=3_lr=0.00004_wd=.01_polynom_pwr=1_betas=.9-.98_bsz=32_gclip=1.0
  id: ${.name}_${seed}


checkpointing:
  save_every_n_steps: 250
  save_dir: ${base_dir}/checkpoints/${wandb.name}
  resume_ckpt_path: ${checkpointing.save_dir}/last.ckpt
  best_ckpt_path: ${checkpointing.save_dir}/best_model.ckpt
src/guidance/dataloader.py
ADDED
@@ -0,0 +1,108 @@
import torch
import pandas as pd
import lightning.pytorch as pl

from transformers import AutoTokenizer
from torch.utils.data import Dataset, DataLoader


class MembraneDataset(Dataset):
    def __init__(self, config, data_path):
        self.config = config
        self.data = pd.read_csv(data_path)
        self.tokenizer = AutoTokenizer.from_pretrained(self.config.lm.pretrained_esm)

    def __len__(self):
        return len(self.data)

    def __getitem__(self, idx):
        sequence = self.data.iloc[idx]["Sequence"]

        tokens = self.tokenizer(
            sequence.upper(),
            return_tensors='pt',
            padding='max_length',
            truncation=True,
            max_length=self.config.data.max_seq_len,
        )

        labels = self.get_labels(sequence)

        return {
            "input_ids": tokens['input_ids'],
            "attention_mask": tokens['attention_mask'],
            "labels": labels
        }

    def get_labels(self, sequence):
        max_len = self.config.data.max_seq_len

        # Create per-residue labels from the case of each residue
        labels = torch.tensor([1 if residue.islower() else 0 for residue in sequence], dtype=torch.float)

        if len(labels) < max_len:  # Pad if the sequence is shorter than the tokenizer truncation length
            padded_labels = torch.cat(
                [labels, torch.full(size=(max_len - len(labels),), fill_value=self.config.model.label_pad_value)]
            )
        else:  # Truncate otherwise
            padded_labels = labels[:max_len]
        return padded_labels


def collate_fn(batch):
    input_ids = torch.stack([item['input_ids'].squeeze(0) for item in batch])
    masks = torch.stack([item['attention_mask'].squeeze(0) for item in batch])
    labels = torch.stack([item['labels'] for item in batch])

    return {
        'input_ids': input_ids,
        'attention_mask': masks,
        'labels': labels
    }


class MembraneDataModule(pl.LightningDataModule):
    def __init__(self, config, train_dataset, val_dataset, test_dataset, collate_fn=collate_fn):
        super().__init__()
        self.train_dataset = train_dataset
        self.val_dataset = val_dataset
        self.test_dataset = test_dataset
        self.collate_fn = collate_fn
        self.batch_size = config.data.batch_size

    def train_dataloader(self):
        return DataLoader(self.train_dataset,
                          batch_size=self.batch_size,
                          collate_fn=self.collate_fn,
                          num_workers=8,
                          pin_memory=True)

    def val_dataloader(self):
        return DataLoader(self.val_dataset,
                          batch_size=self.batch_size,
                          collate_fn=self.collate_fn,
                          num_workers=8,
                          pin_memory=True)

    def test_dataloader(self):
        return DataLoader(self.test_dataset,
                          batch_size=self.batch_size,
                          collate_fn=self.collate_fn,
                          num_workers=8,
                          pin_memory=True)


def get_datasets(config):
    """Helper method to grab datasets to quickly init data module in main.py"""
    train_dataset = MembraneDataset(config, config.data.train)
    val_dataset = MembraneDataset(config, config.data.val)
    test_dataset = MembraneDataset(config, config.data.test)

    return {
        "train": train_dataset,
        "val": val_dataset,
        "test": test_dataset
    }
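A minimal usage sketch for the pipeline above (assuming `config` is the loaded guidance.yaml): with batch_size=32 and max_seq_len=1024, every batch carries fixed-shape tensors, and padded label positions hold label_pad_value (-100).

datasets = get_datasets(config)
data_module = MembraneDataModule(config,
                                 train_dataset=datasets['train'],
                                 val_dataset=datasets['val'],
                                 test_dataset=datasets['test'])

batch = next(iter(data_module.train_dataloader()))
print(batch['input_ids'].shape)       # torch.Size([32, 1024])
print(batch['attention_mask'].shape)  # torch.Size([32, 1024])
print(batch['labels'].shape)          # torch.Size([32, 1024]); pad positions hold -100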
src/guidance/main.py
ADDED
@@ -0,0 +1,73 @@
#!/usr/bin/env python3

import os
import wandb
import lightning.pytorch as pl

from omegaconf import OmegaConf
from lightning.pytorch.strategies import DDPStrategy
from lightning.pytorch.loggers import WandbLogger
from lightning.pytorch.callbacks import ModelCheckpoint, LearningRateMonitor

from src.utils.model_utils import _print
from src.guidance.solubility_module import SolubilityClassifier
from src.guidance.dataloader import MembraneDataModule, get_datasets


config = OmegaConf.load("/scratch/sgoel/MeMDLM_v2/src/configs/guidance.yaml")
wandb.login(key='2b76a2fa2c1cdfddc5f443602c17b011fefb0a8f')

# data
datasets = get_datasets(config)
data_module = MembraneDataModule(
    config=config,
    train_dataset=datasets['train'],
    val_dataset=datasets['val'],
    test_dataset=datasets['test'],
)

# wandb logging
#wandb.init(project=config.wandb.project, name=config.wandb.name)
wandb_logger = WandbLogger(**config.wandb)

# lightning checkpoints
lr_monitor = LearningRateMonitor(logging_interval="step")
checkpoint_callback = ModelCheckpoint(
    monitor="val/loss",
    save_top_k=1,
    mode="min",
    dirpath=config.checkpointing.save_dir,
    filename="best_model",
)

# lightning trainer
trainer = pl.Trainer(
    max_steps=config.training.max_steps,
    accelerator="cuda",
    devices=1,  #config.training.devices if config.training.mode=='train' else [0],
    #strategy=DDPStrategy(find_unused_parameters=True),
    callbacks=[checkpoint_callback, lr_monitor],
    logger=wandb_logger,
    log_every_n_steps=config.training.log_every_n_steps
)

# Folder to save checkpoints
ckpt_dir = config.checkpointing.save_dir
os.makedirs(ckpt_dir, exist_ok=True)

# instantiate model
model = SolubilityClassifier(config)

# train or evaluate the model
if config.training.mode == "train":
    trainer.fit(model, datamodule=data_module)

elif config.training.mode == "test":
    ckpt_path = os.path.join(ckpt_dir, "best_model.ckpt")
    state_dict = model.get_state_dict(ckpt_path)
    model.load_state_dict(state_dict)
    trainer.test(model, datamodule=data_module, ckpt_path=ckpt_path)

else:
    raise ValueError(f"{config.training.mode} is invalid. Must be 'train' or 'test'")

wandb.finish()
src/guidance/solubility_module.py
ADDED
@@ -0,0 +1,155 @@
import gc
import torch
import torch.nn as nn
import lightning.pytorch as pl

from transformers import AutoModel
from torchmetrics.classification import BinaryAUROC, BinaryAccuracy

from src.utils.model_utils import _print
from src.guidance.utils import CosineWarmup


class SolubilityClassifier(pl.LightningModule):
    def __init__(self, config):
        super().__init__()
        self.config = config
        self.loss_fn = nn.BCEWithLogitsLoss(reduction='none')
        self.auroc = BinaryAUROC()
        self.accuracy = BinaryAccuracy()

        # Frozen ESM-2 encoder, used only to produce residue embeddings
        self.esm_model = AutoModel.from_pretrained(self.config.lm.pretrained_esm)
        for p in self.esm_model.parameters():
            p.requires_grad = False

        encoder_layer = nn.TransformerEncoderLayer(
            d_model=config.model.d_model,
            nhead=config.model.num_heads,
            dropout=config.model.dropout,
            batch_first=True
        )
        self.encoder = nn.TransformerEncoder(encoder_layer, config.model.num_layers)
        self.layer_norm = nn.LayerNorm(config.model.d_model)
        self.dropout = nn.Dropout(config.model.dropout)
        self.mlp = nn.Sequential(
            nn.Linear(config.model.d_model, config.model.d_model // 2),
            nn.ReLU(),
            nn.Dropout(config.model.dropout),
            nn.Linear(config.model.d_model // 2, 1),
        )

    # -------# Classifier step #-------- #
    def forward(self, batch):
        if 'input_ids' in batch:
            esm_embeds = self.get_esm_embeddings(batch['input_ids'], batch['attention_mask'])
        elif 'embeds' in batch:
            esm_embeds = batch['embeds']
        encodings = self.encoder(esm_embeds, src_key_padding_mask=(batch['attention_mask'] == 0))
        encodings = self.dropout(self.layer_norm(encodings))
        logits = self.mlp(encodings).squeeze(-1)
        return logits

    # -------# Training / Evaluation #-------- #
    def training_step(self, batch, batch_idx):
        train_loss, _ = self.compute_loss(batch)
        self.log(name="train/loss", value=train_loss.item(), on_step=True, on_epoch=False, logger=True, sync_dist=True)
        self.save_ckpt()
        return train_loss

    def validation_step(self, batch, batch_idx):
        val_loss, _ = self.compute_loss(batch)
        self.log(name="val/loss", value=val_loss.item(), on_step=False, on_epoch=True, logger=True, sync_dist=True)
        return val_loss

    def test_step(self, batch):
        test_loss, preds = self.compute_loss(batch)
        auroc, accuracy = self.get_metrics(batch, preds)
        self.log(name="test/loss", value=test_loss.item(), on_step=False, on_epoch=True, logger=True, sync_dist=True)
        self.log(name="test/AUROC", value=auroc.item(), on_step=False, on_epoch=True, logger=True, sync_dist=True)
        self.log(name="test/accuracy", value=accuracy.item(), on_step=False, on_epoch=True, logger=True, sync_dist=True)
        return test_loss

    def on_test_epoch_end(self):
        self.auroc.reset()
        self.accuracy.reset()

    def optimizer_step(self, *args, **kwargs):
        super().optimizer_step(*args, **kwargs)
        gc.collect()
        torch.cuda.empty_cache()

    def configure_optimizers(self):
        train_cfg = self.config.training
        optimizer = torch.optim.AdamW(self.parameters(), lr=self.config.optim.lr)
        lr_scheduler = CosineWarmup(
            optimizer,
            warmup_steps=train_cfg.warmup_steps,
            total_steps=train_cfg.max_steps,
        )
        scheduler_dict = {
            "scheduler": lr_scheduler,
            "interval": 'step',
            'frequency': 1,
            'monitor': 'val/loss',
            'name': 'learning_rate'
        }
        return [optimizer], [scheduler_dict]

    def save_ckpt(self):
        curr_step = self.global_step
        save_every = self.config.training.val_check_interval
        if curr_step % save_every == 0 and curr_step > 0:  # Save every `val_check_interval` (250) steps
            ckpt_path = f"{self.config.checkpointing.save_dir}/step={curr_step}.ckpt"
            self.trainer.save_checkpoint(ckpt_path)

    # -------# Loss and Test Set Metrics #-------- #
    @torch.no_grad()
    def get_esm_embeddings(self, input_ids, attention_mask):
        outputs = self.esm_model(input_ids=input_ids, attention_mask=attention_mask)
        embeddings = outputs.last_hidden_state
        return embeddings

    def compute_loss(self, batch):
        """Helper method to handle loss calculation"""
        labels = batch['labels']
        preds = self.forward(batch)
        loss = self.loss_fn(preds, labels)
        loss_mask = (labels != self.config.model.label_pad_value)  # only calculate loss over non-pad tokens
        loss = (loss * loss_mask).sum() / loss_mask.sum()
        return loss, preds

    def get_metrics(self, batch, preds):
        """Helper method to compute metrics"""
        labels = batch['labels']

        valid_mask = (labels != self.config.model.label_pad_value)
        labels = labels[valid_mask]
        preds = preds[valid_mask]

        _print(f"labels {labels.shape}")
        _print(f"preds {preds.shape}")

        auroc = self.auroc.forward(preds, labels)
        accuracy = self.accuracy.forward(preds, labels)
        return auroc, accuracy

    # -------# Helper Functions #-------- #
    def get_state_dict(self, ckpt_path):
        """Helper method to load a trained model's state dict from a saved checkpoint.
        Keys keep their module prefixes, since the dict is loaded back into this LightningModule."""
        checkpoint = torch.load(ckpt_path, map_location='cuda' if torch.cuda.is_available() else 'cpu')
        return checkpoint.get("state_dict", checkpoint)
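A short inference sketch for the classifier above (not part of the commit): the checkpoint path is hypothetical, and `data_module` is the guidance data module from dataloader.py. The [B, L] logits pass through a sigmoid to give per-residue solubility probabilities.

import torch
from omegaconf import OmegaConf

config = OmegaConf.load("src/configs/guidance.yaml")  # hypothetical local path
model = SolubilityClassifier(config)
state_dict = model.get_state_dict("checkpoints/<run_name>/best_model.ckpt")  # hypothetical checkpoint
model.load_state_dict(state_dict)
model.eval()

batch = next(iter(data_module.test_dataloader()))
with torch.no_grad():
    logits = model(batch)          # [B, L] per-residue logits
    probs = torch.sigmoid(logits)  # per-residue solubility scores in [0, 1]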
src/guidance/utils.py
ADDED
@@ -0,0 +1,21 @@
import numpy as np
from torch.optim.lr_scheduler import _LRScheduler

class CosineWarmup(_LRScheduler):
    def __init__(self, optimizer, warmup_steps, total_steps, eta_ratio=0.1, last_epoch=-1):
        self.warmup_steps = warmup_steps
        self.total_steps = total_steps
        self.eta_ratio = eta_ratio  # The ratio of minimum to maximum learning rate
        super(CosineWarmup, self).__init__(optimizer, last_epoch)

    def get_lr(self):
        if self.last_epoch < self.warmup_steps:
            return [base_lr * self.last_epoch / self.warmup_steps for base_lr in self.base_lrs]

        progress = (self.last_epoch - self.warmup_steps) / (self.total_steps - self.warmup_steps)
        cosine_decay = 0.5 * (1 + np.cos(np.pi * progress))
        decayed_lr = (1 - self.eta_ratio) * cosine_decay + self.eta_ratio

        return [decayed_lr * base_lr for base_lr in self.base_lrs]
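A quick sanity check of the schedule (dummy parameter and optimizer; warmup/total steps taken from guidance.yaml): the LR ramps linearly to the base value over the warmup, then follows a cosine down to eta_ratio (10%) of the base LR.

import torch

param = torch.nn.Parameter(torch.zeros(1))
optimizer = torch.optim.AdamW([param], lr=3e-5)
scheduler = CosineWarmup(optimizer, warmup_steps=150, total_steps=3000)

for step in range(3000):
    optimizer.step()
    scheduler.step()

print(scheduler.get_last_lr())  # ~[3e-6], i.e. eta_ratio * base_lr at the end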
src/lm/memdlm/dataloader.py
ADDED
@@ -0,0 +1,88 @@
import torch
import pandas as pd
import lightning.pytorch as pl

from transformers import AutoTokenizer
from torch.utils.data import Dataset, DataLoader


class MembraneDataset(Dataset):
    def __init__(self, config, data_path):
        self.config = config
        self.data = pd.read_csv(data_path)
        self.tokenizer = AutoTokenizer.from_pretrained(config.lm.pretrained_evoflow)

    def __len__(self):
        return len(self.data)

    def __getitem__(self, idx):
        sequence = self.data.iloc[idx]["Sequence"]

        tokens = self.tokenizer(
            sequence.upper(),
            return_tensors='pt',
            padding='max_length',
            truncation=True,
            max_length=self.config.data.max_seq_len
        )

        return {
            "input_ids": tokens['input_ids'].squeeze(0),
            "attention_mask": tokens['attention_mask'].squeeze(0)
        }


def collate_fn(batch):
    input_ids = torch.stack([item['input_ids'] for item in batch])
    masks = torch.stack([item['attention_mask'] for item in batch])

    return {'input_ids': input_ids, 'attention_mask': masks}


class MembraneDataModule(pl.LightningDataModule):
    def __init__(self, config, train_dataset, val_dataset, test_dataset, collate_fn=collate_fn):
        super().__init__()
        self.train_dataset = train_dataset
        self.val_dataset = val_dataset
        self.test_dataset = test_dataset
        self.collate_fn = collate_fn
        self.batch_size = config.data.batch_size
        self.tokenizer = AutoTokenizer.from_pretrained(config.lm.pretrained_evoflow)

    def train_dataloader(self):
        return DataLoader(self.train_dataset,
                          batch_size=self.batch_size,
                          collate_fn=self.collate_fn,
                          num_workers=8,
                          pin_memory=True)

    def val_dataloader(self):
        return DataLoader(self.val_dataset,
                          batch_size=self.batch_size,
                          collate_fn=self.collate_fn,
                          num_workers=8,
                          shuffle=False,
                          pin_memory=True)

    def test_dataloader(self):
        return DataLoader(self.test_dataset,
                          batch_size=self.batch_size,
                          collate_fn=self.collate_fn,
                          num_workers=8,
                          shuffle=False,
                          pin_memory=True)


def get_datasets(config):
    """Helper method to grab datasets to quickly init data module in main.py"""
    train_dataset = MembraneDataset(config, config.data.train)
    test_dataset = MembraneDataset(config, config.data.test)
    val_dataset = MembraneDataset(config, config.data.val)

    return {
        "train": train_dataset,
        "val": val_dataset,
        "test": test_dataset
    }
src/lm/memdlm/diffusion_module.py
ADDED
@@ -0,0 +1,189 @@
import os
import gc
import torch

import torch.nn.functional as F
import lightning as pl

from typing import Optional
from transformers import AutoModelForMaskedLM, AutoTokenizer

from src.utils.model_utils import _print
from src.utils.optimizer_utils import get_optimizer, get_scheduler


class MembraneDiffusion(pl.LightningModule):
    def __init__(self, config):
        """
        Args:
            config (OmegaConf): config.yaml file with all training parameters
        """
        super().__init__()
        self.config = config
        self.save_hyperparameters(logger=True)

        self.model = AutoModelForMaskedLM.from_pretrained(config.lm.pretrained_evoflow, trust_remote_code=True)
        self.tokenizer = AutoTokenizer.from_pretrained(config.lm.pretrained_evoflow)

        self.mask_id = self.tokenizer.mask_token_id
        self.pad_id = self.tokenizer.pad_token_id

    def forward(self, input_ids, attention_mask, guidance: Optional[bool] = False):
        """
        Forward pass through the language model.

        Args:
            - input_ids (torch.Tensor): [B, L], token ids
            - attention_mask (torch.Tensor): [B, L], pad/non-pad binary mask
        Returns:
            - logits (torch.Tensor): [B, L, V], unnormalized model outputs
        """
        return self.model(input_ids=input_ids, attention_mask=attention_mask).logits

    # -------# Diffusion #-------- #
    def step(self, batch):
        labels = batch['input_ids']

        # Forward diffusion
        t1 = self.sample_t(labels)  # Sample timestep
        xt, _ = self.noise_x0(labels, t1, maskable_mask=self.is_maskable(labels))  # Noise sequence
        logits = self.forward(input_ids=xt, attention_mask=batch['attention_mask'])  # Model logits

        # Loss computation
        weight = self.get_weight(t1, weight_type=self.config.lm.weight_type)  # RDM uses a weighted cross entropy loss
        loss_out = self.compute_loss(logits, labels, weight)  # Compute loss and ppl

        self.cleanup()
        return loss_out['loss'], loss_out['ppl']

    def sample_t(self, labels, rdm_coupling=False):
        """
        Sample diffusion timesteps. Non-coupling RDM only uses one timestep (t1).
        """
        timesteps = torch.randint(
            1,
            self.config.lm.num_diffusion_timesteps + 1,
            (2 if rdm_coupling else 1) * (labels.size(0),),
            device=labels.device
        )

        if rdm_coupling:
            return timesteps.chunk(2)
        return timesteps

    def noise_x0(self, x0, t1, maskable_mask):
        """
        Apply noise to the initial sequence x0: each maskable token is replaced
        by the mask id with probability t1 / num_diffusion_timesteps.
        """
        u = torch.rand_like(x0, dtype=torch.float)
        t1_mask = (u < (t1 / self.config.lm.num_diffusion_timesteps)[:, None]) & maskable_mask
        x_t1 = x0.masked_fill(t1_mask, self.mask_id)
        return x_t1, t1_mask

    def get_weight(self, t, weight_type):
        """
        Compute the weighting factor for the RDM-derived loss (weighted cross-entropy).
        """
        num_timesteps = self.config.lm.num_diffusion_timesteps
        weight = {
            "linear": (num_timesteps - (t - 1)),  # num_timesteps * (1 - (t-1)/num_timesteps)
            "constant": num_timesteps * torch.ones_like(t),
        }[weight_type][:, None].float() / num_timesteps
        return weight.squeeze()

    def compute_loss(self, logits, labels, weight):
        """
        Compute the cross entropy loss per sample.
        First, compute the per-token loss (with no reduction), then reduce over the sequence length for each sample.
        Finally, average over the batch.

        Args:
            logits (torch.Tensor): [B, L, vocab_size], unnormalized model outputs
            labels (torch.Tensor): [B, L], target labels (padding positions carry the pad token id and are ignored)
            weight (torch.Tensor): [B], per-sample weight for loss calculation
        Returns:
            dict with 'loss' (batch-averaged weighted loss) and 'ppl' (batch-averaged perplexity)
        """
        loss_token = F.cross_entropy(
            logits.view(-1, logits.size(-1)),
            labels.view(-1),
            reduction='none',
            ignore_index=self.pad_id,
        )

        loss_token = loss_token.view(labels.size(0), labels.size(1))  # Reshape to [B, L]
        valid_mask = (labels != self.pad_id)

        sample_loss = (loss_token * valid_mask.float()).sum(dim=1) / valid_mask.float().sum(dim=1).clamp(min=1)
        sample_loss *= weight  # RDM weighting
        ppl = torch.exp(sample_loss)

        return {'ppl': ppl.mean(), 'loss': sample_loss.mean()}

    # -------# Training / Evaluation #-------- #
    def training_step(self, batch):
        loss, ppl = self.step(batch)
        self.log("train/loss", loss.item(), on_step=True, on_epoch=False, prog_bar=True)
        self.log("train/ppl", ppl.item(), on_step=True, on_epoch=False, prog_bar=False)
        return loss

    def validation_step(self, batch):
        loss, ppl = self.step(batch)
        self.cleanup()
        self.log("val/loss", loss.item(), on_step=False, on_epoch=True, prog_bar=True, sync_dist=True)
        self.log("val/ppl", ppl.item(), on_step=False, on_epoch=True, prog_bar=False, sync_dist=True)
        return loss

    def test_step(self, batch):
        loss, ppl = self.step(batch)
        self.cleanup()
        self.log('test/loss', loss.item(), on_step=False, on_epoch=True, prog_bar=True, sync_dist=True)
        self.log("test/ppl", ppl.item(), on_step=False, on_epoch=True, prog_bar=False, sync_dist=True)
        return loss

    # -------# Helper methods #-------- #
    def is_maskable(self, input_ids: torch.Tensor):
        return (
            (input_ids != self.tokenizer.pad_token_id)
            & (input_ids != self.tokenizer.cls_token_id)
            & (input_ids != self.tokenizer.eos_token_id)
        )

    def configure_optimizers(self):
        """
        Choose which optimizer and lr scheduler to use.
        """
        optimizer = get_optimizer(self.config, self.model.parameters())
        lr_scheduler, extra_kwargs = get_scheduler(self.config, optimizer)  # Polynomial scheduler
        return {
            "optimizer": optimizer,
            "lr_scheduler": {"scheduler": lr_scheduler, **extra_kwargs},
        }

    def validate_config(self):
        assert os.path.isdir(self.config.checkpointing.save_dir), "invalid checkpointing path"
        assert self.config.training.mode in ["train", "test", "resume_from_checkpoint"], "invalid mode"

    def get_state_dict(self, ckpt_path):
        """Load a trained checkpoint and return its state dict. Keys keep the
        "model." prefix, since the dict is loaded back into this LightningModule."""
        checkpoint = torch.load(ckpt_path, map_location='cuda' if torch.cuda.is_available() else 'cpu')
        return checkpoint.get("state_dict", checkpoint)

    def cleanup(self):
        torch.cuda.empty_cache()
        gc.collect()
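To make the forward process above concrete, a tiny standalone check (not repo code): noise_x0 masks each maskable token with probability t / num_diffusion_timesteps, so at t = 250 of T = 500 roughly half the positions become masks; the linear loss weight (T - (t-1)) / T then down-weights these heavily-noised samples.

import torch

T = 500                                 # num_diffusion_timesteps in lm.yaml
x0 = torch.randint(4, 24, (1, 10_000))  # fake token ids, all treated as maskable
t1 = torch.tensor([250])
mask_id = 32                            # placeholder mask token id

u = torch.rand_like(x0, dtype=torch.float)
t1_mask = u < (t1 / T)[:, None]         # same masking rule as noise_x0
xt = x0.masked_fill(t1_mask, mask_id)

print(t1_mask.float().mean())           # ~0.5: about half the tokens are masked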
src/lm/memdlm/loss.py
ADDED
@@ -0,0 +1,42 @@
import torch
import torch.nn as nn
import torch.nn.functional as F

# Ignore file

class RDMCrossEntropyLoss(nn.CrossEntropyLoss):
    def __init__(self, ignore_index):
        super().__init__(ignore_index=ignore_index)

    def forward(self,
                scores: torch.Tensor,
                target: torch.Tensor,
                label_mask,
                weights,
                ) -> torch.Tensor:
        """
        Computes the RDM-derived loss (weighted cross-entropy).
        """
        sample_size = target.ne(self.ignore_index).float().sum()

        lprobs = F.log_softmax(scores, dim=-1)
        # Per-token negative log-likelihood of the target tokens, scaled by the RDM weights
        loss = -lprobs.gather(dim=-1, index=target.unsqueeze(-1)).squeeze(-1) * weights
        fullseq_loss = loss.sum() / sample_size

        # use coord masked loss for model training,
        # ignoring those positions with missing coords (as nan)
        label_mask = label_mask.float()
        sample_size = label_mask.sum()  # sample size should be set to valid coordinates
        loss = (loss * label_mask).sum() / sample_size

        ppl = torch.exp(loss)

        logging_output = {
            'ppl': ppl.data,
            'fullseq_loss': fullseq_loss.data,
            'weight_diff_loss': loss.data
        }

        return logging_output
src/lm/memdlm/main.py
ADDED
@@ -0,0 +1,88 @@
import os
import gc
import sys
import torch
import wandb
import torch.nn as nn
import lightning.pytorch as pl

from omegaconf import OmegaConf
from lightning.pytorch.strategies import DDPStrategy
from lightning.pytorch.loggers import WandbLogger
from lightning.pytorch.callbacks import ModelCheckpoint, LearningRateMonitor

from src.lm.memdlm.diffusion_module import MembraneDiffusion
from src.lm.memdlm.dataloader import MembraneDataModule, get_datasets
from src.utils.model_utils import apply_rdm_freezing

wandb.login(key='2b76a2fa2c1cdfddc5f443602c17b011fefb0a8f')


# Load yaml config
config = OmegaConf.load("/scratch/pranamlab/sgoel/MeMDLM_v2/src/configs/lm.yaml")

# Get datasets
datasets = get_datasets(config)
data_module = MembraneDataModule(
    config=config,
    train_dataset=datasets['train'],
    val_dataset=datasets['val'],
    test_dataset=datasets['test'],
)

# Initialize WandB for logging
wandb.init(project=config.wandb.project, name=config.wandb.name)
wandb_logger = WandbLogger(**config.wandb)

# PL checkpoints
lr_monitor = LearningRateMonitor(logging_interval="step")
checkpoint_callback = ModelCheckpoint(
    monitor="val/loss",
    save_top_k=1,
    mode="min",
    dirpath=config.checkpointing.save_dir,
    filename="best_model",
    every_n_train_steps=config.checkpointing.save_every_n_steps
)

# PL trainer
trainer = pl.Trainer(
    max_steps=config.training.max_steps,
    max_epochs=None,  # Ensure training is based on num steps
    accelerator="cuda" if torch.cuda.is_available() else "cpu",
    devices=config.training.devices if config.training.mode=='train' else [0],
    strategy=DDPStrategy(find_unused_parameters=True),
    callbacks=[checkpoint_callback, lr_monitor],
    logger=wandb_logger,
    log_every_n_steps=config.training.log_every_n_steps
)


# Folder to save checkpoints
ckpt_dir = config.checkpointing.save_dir
os.makedirs(ckpt_dir, exist_ok=True)

# PL Model for training
diffusion = MembraneDiffusion(config)
diffusion.validate_config()

# Start/resume training or evaluate the model
model_type = "evoflow"
if config.training.mode == "train":
    apply_rdm_freezing(diffusion.model, config.training.n_layers, model_type)
    trainer.fit(diffusion, datamodule=data_module)

elif config.training.mode == "test":
    state_dict = diffusion.get_state_dict(config.checkpointing.best_ckpt_path)
    diffusion.load_state_dict(state_dict)
    trainer.test(diffusion, datamodule=data_module, ckpt_path=config.checkpointing.best_ckpt_path)

elif config.training.mode == "resume_from_checkpoint":
    state_dict = diffusion.get_state_dict(config.checkpointing.resume_ckpt_path)
    diffusion.load_state_dict(state_dict)
    apply_rdm_freezing(diffusion.model, config.training.n_layers, model_type)
    trainer.fit(diffusion, datamodule=data_module, ckpt_path=config.checkpointing.resume_ckpt_path)

wandb.finish()
src/sampling/guided_generator.py
ADDED
@@ -0,0 +1,90 @@
#!/usr/bin/env python3

import sys
import os
import torch
import pandas as pd
from tqdm import tqdm
from datetime import datetime
from omegaconf import OmegaConf
from transformers import AutoTokenizer, AutoModelForMaskedLM

from src.lm.memdlm.diffusion_module import MembraneDiffusion
from src.utils.model_utils import _print
from src.sampling.guided_sampler import GuidedSampler
from src.utils.generate_utils import (
    mask_for_scaffold,
    calc_blosum_score,
    calc_ppl
)

config = OmegaConf.load("/home/a03-sgoel/MeMDLM_v2/src/configs/guidance.yaml")

os.chdir(f'/home/a03-sgoel/MeMDLM_v2/results/infilling/guided/{config.lm.ft_evoflow}/test_set/')
todays_date = datetime.today().strftime('%Y-%m-%d')
csv_save_path = f'./{todays_date}_boltzmann-soft_new_clf_data_cleaned/'
os.makedirs(csv_save_path, exist_ok=True)


def main():
    device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')

    tokenizer = AutoTokenizer.from_pretrained(config.lm.pretrained_esm)
    esm_model = AutoModelForMaskedLM.from_pretrained(config.lm.pretrained_esm).eval().to(device)

    diffusion = MembraneDiffusion(config).to(device)
    state_dict = diffusion.get_state_dict(f"/home/a03-sgoel/MeMDLM_v2/checkpoints/{config.lm.ft_evoflow}/best_model.ckpt")
    diffusion.load_state_dict(state_dict)
    diffusion.eval().to(device)

    sampler = GuidedSampler(config, esm_model, tokenizer, diffusion, device)

    df = pd.read_csv('/home/a03-sgoel/MeMDLM_v2/data/classifier/test.csv')
    sequences = df['Sequence'].tolist()

    gen_seqs, ppls, blosums = [], [], []

    for seq in tqdm(sequences, desc='Infilling Sequences'):
        masked_seq = mask_for_scaffold(seq, generate_type='uppercase', mask_token='<mask>')
        tokens = tokenizer(masked_seq, return_tensors='pt')
        input_ids, attn_masks = tokens['input_ids'].to(device), tokens['attention_mask'].to(device)

        soluble_idxs = [i for i in range(len(seq)) if seq[i].isupper()]
        infilled_tokens = sampler.optimize_sequence(
            input_ids=input_ids,
            attn_masks=attn_masks,
            soluble_indices=soluble_idxs,
        )
        # Strip inter-token spaces and the leading/trailing special tokens from the decoded string
        infilled_seq = tokenizer.decode(infilled_tokens).replace(" ", "")[5:-5]

        bl = calc_blosum_score(seq.upper(), infilled_seq, soluble_idxs)
        try:
            ppl = calc_ppl(esm_model, tokenizer, infilled_seq, [i for i in range(len(seq))], model_type='esm')
        except Exception:
            ppl = float('inf')

        gen_seqs.append(infilled_seq)
        ppls.append(ppl)
        blosums.append(bl)

        _print(seq)
        _print(infilled_seq)
        _print(ppl)
        _print(bl)
        _print('\n')

    df['MeMDLM Sequence'] = gen_seqs
    df['MeMDLM PPL'] = ppls
    df['MeMDLM BLOSUM'] = blosums

    _print(df)
    df.to_csv(f'./{csv_save_path}/t=0.7_new-data-cleaned_infilled_seqs.csv', index=False)


if __name__ == "__main__":
    main()
src/sampling/guided_sampler.py
ADDED
@@ -0,0 +1,291 @@
import os
import math
import torch
import torch.nn.functional as F

from src.utils.model_utils import _print
from src.guidance.solubility_module import SolubilityClassifier
from src.sampling.unconditional_sampler import UnconditionalSampler


class GuidedSampler:
    def __init__(self, config, esm_model, tokenizer, diffusion, device):
        self.config = config
        self.device = device

        self.esm = esm_model
        self.memdlm = diffusion
        self.tokenizer = tokenizer
        self.uncond_generator = UnconditionalSampler(self.tokenizer, self.memdlm)

        ckpt_path = f"/home/a03-sgoel/MeMDLM_v2/checkpoints/{config.wandb.name}/best_model.ckpt"
        self.classifier_model = SolubilityClassifier(config)
        state_dict = self.classifier_model.get_state_dict(ckpt_path)
        self.classifier_model.load_state_dict(state_dict)
        self.classifier_model.eval().to(self.device)

        self.top_p = self.config.guidance.top_p
        self.alpha = self.config.guidance.alpha
        self.gamma = self.config.guidance.gamma
        self.saliency_eps = self.config.guidance.saliency_eps
        self.saliency_t = self.config.guidance.saliency_t
        self.sampling_t = self.config.guidance.sampling_t
        self.boltzmann_t = self.config.guidance.boltzmann_t

    def embed_sequence(self, input_ids, attention_masks):
        with torch.no_grad():
            outs = self.esm(
                input_ids=input_ids,
                attention_mask=attention_masks,
                output_hidden_states=True,
                output_attentions=True
            )
        embeds = outs.hidden_states[-1]
        attn_matrix = outs.attentions
        return embeds, attn_matrix

    def sample_from_categorical(self, logits, temperature, noise_scale=1.0):
        gumbel_noise = -torch.log(-torch.log(torch.rand_like(logits) + 1e-8) + 1e-8)
        logits = (logits / temperature) + (noise_scale * gumbel_noise)

        log_probs = F.log_softmax(logits, dim=-1)
        _, tokens = log_probs.max(dim=-1)

        return tokens, log_probs

    def denoise_sequence(self, input_ids, attn_masks):
        """
        Compute the current and prior sequences' log prob distributions.
        """
        has_masks = (input_ids == self.tokenizer.mask_token_id).any()

        # Denoise the sequence if needed
        if has_masks:
            xt_prior, logits_prior = self.uncond_generator.sample_unconditional(
                xt=input_ids,
                num_steps=self.config.guidance.n_steps,
                tau=self.sampling_t,
                return_logits=True
            )
        else:
            xt_prior = input_ids
            logits_prior = self.memdlm(input_ids=input_ids, attention_mask=attn_masks)

        # Take the final sampling step
        _, logits = self.uncond_generator.sample_unconditional(
            xt=xt_prior,
            num_steps=1,  # Only need 1 sampling step
            tau=self.sampling_t,
            return_logits=True
        )

        # Get final sequence log probs (always needed)
        x0, logp_lm = self.sample_from_categorical(logits, temperature=self.sampling_t)

        return x0.squeeze(), logp_lm.squeeze(), logits_prior

    def get_prior(self, logits_prior, solubility_logits):
        if self.config.guidance.prior == "boltzmann":
            hydrophilic = ["D","E","K","R","N","Q","H","S","T","Y"]
            hydrophobic = ["L","I","V","F","W","M","A","C","G","P"]
            amino_acids = hydrophilic + hydrophobic

            tokens = list(self.tokenizer.get_vocab().keys())
            other = [tok for tok in tokens if tok not in amino_acids]

            hydrophilic_idxs = [self.tokenizer.convert_tokens_to_ids(aa) for aa in hydrophilic]
            hydrophobic_idxs = [self.tokenizer.convert_tokens_to_ids(aa) for aa in hydrophobic]
            other_idxs = [self.tokenizer.convert_tokens_to_ids(tok) for tok in other]

            bias = torch.zeros(len(tokens), device=self.device)
            bias[hydrophilic_idxs] = 1.0
            bias[hydrophobic_idxs] = -1.0
            bias[other_idxs] = 0.0

            sol_scores = torch.sigmoid(solubility_logits)
            token_bias = sol_scores.unsqueeze(-1) * bias

            lm_probs = F.softmax(logits_prior / self.sampling_t, dim=-1)
            boltz_weight = torch.exp(token_bias / self.boltzmann_t)

            p_prior = lm_probs * boltz_weight
            p_prior = p_prior / p_prior.sum(dim=-1, keepdim=True)
            logp_prior = torch.log(p_prior)

        elif self.config.guidance.prior == "lm_probs":
            _, logp_prior = self.sample_from_categorical(logits_prior, temperature=self.sampling_t)

        return logp_prior.squeeze()

    def compute_saliency_map(self, embeds, solubility_logits):
        """
        Compute a saliency map as in LaMBO-2 (https://arxiv.org/abs/2305.20009) Eq. 5
        """
        # Gradient tracking is already enabled for the embeddings
        solubility_logits.sum().backward(retain_graph=True)  # Clf gradients wrt hidden states
        grads = embeds.grad.abs().sum(dim=-1)  # Aggregate across hidden dim. Abs value for magnitude only.
        saliency = grads.pow(1.0 / self.saliency_t).clamp(min=self.saliency_eps).to(self.device)
        saliency = (saliency - saliency.min()) / (saliency.max() - saliency.min() + 1e-6)
        return saliency.squeeze()

    def determine_edit_positions(self, saliency_map, soluble_indices, solubility_logits):
        """
        Fix the insoluble residues and additional TM residues to
        maintain membrane-like protein structure.
        """
        seq_len = saliency_map.shape[0]

        # Initialize a mask to store the editable token positions
        edit_mask = torch.ones(seq_len, dtype=torch.bool, device=self.device)

        # Check for any provided soluble residues, otherwise use classifier preds
        if soluble_indices is not None and len(soluble_indices) > 0:
            edit_mask[soluble_indices] = False
        else:
            solubility_preds = F.sigmoid(solubility_logits)
            edit_mask[solubility_preds > 0.5] = False

        # Find additional TM residues
        num_conserved = max(1, int(0.1 * edit_mask.sum()))
        _, topk_idxs = torch.topk(saliency_map, num_conserved)
        edit_mask[topk_idxs] = False

        edit_idxs = edit_mask.nonzero(as_tuple=True)[0]
        return edit_idxs

    def create_neighborhood(self, edit_pos, attn_matrix, top_p):
        """
        Select a dynamic "neighborhood" of tokens for each edit position via top-p sampling.
        Attention scores find relevant tokens, avoiding blind updates of the individual token.
        """
        # Get the attention scores for the current edit position
        row = attn_matrix[edit_pos].clone().squeeze()
        row = row.index_fill(
            dim=0,
            index=torch.tensor([0, edit_pos, row.size(0)-1], device=row.device),
            value=float('-inf')
        )

        # Top-p (nucleus) sampling of tokens via normed attention scores
        temp = 1.0 / math.log(row.size(0))  # scale temp with seq len to balance
        attn_probs = F.softmax(row / temp, dim=0)
        sorted_probs, sorted_idxs = torch.sort(attn_probs, descending=True)
        cum_probs = sorted_probs.cumsum(dim=0)
        cutoff = (cum_probs <= top_p).nonzero(as_tuple=True)[0]

        # Ensure neighborhoods will always have at least 1 token
        final_idx = cutoff[-1].item() + 1 if cutoff.numel() > 0 else 1
        neighborhood = sorted_idxs[:final_idx]
        return neighborhood

    def compute_saliency_weight(self, edit_pos, attn_mat, saliency_map, neighborhood):
        """
        Blend the saliency of the neighborhood's tokens and the token at the edit position.
        """
        neighborhood_attns = attn_mat[edit_pos, neighborhood]
        neighborhood_attns /= neighborhood_attns.sum()

        neighborhood_saliencies = saliency_map[neighborhood]

        neighborhood_weight = torch.sum(neighborhood_attns * neighborhood_saliencies)
        ctxt_aware_saliency = saliency_map[edit_pos] + (self.gamma * neighborhood_weight)

        return ctxt_aware_saliency

    def compute_guidance_dist(self, logp_lm, logp_prior, saliency_weight):
        """
        Define a guidance distribution between a prior and the current LM probs.
        Compute the log probs of the "new" (optimized) token.
        """
        w = torch.sigmoid(saliency_weight * self.alpha)  # Between [0, 1] to ensure valid probs
        p_lm = torch.exp(logp_lm)
        p_prior = torch.exp(logp_prior)
        mixed_probs = (1 - w) * p_lm + w * p_prior
        guidance_dist = torch.log(mixed_probs + 1e-12)
        return guidance_dist

    def check_scaffold(self, seq1, seq2, idxs):
        changed = (seq1[idxs] != seq2[idxs])
        if changed.any():
            _print('soluble residues changed')
        else:
            _print('no soluble residue changes')

    def optimize_sequence(self, input_ids, attn_masks, soluble_indices):
        _print(f'soluble idx: {soluble_indices}')

        # Initialize token ids, logits, and log probs of sequence
        x0, logp_lm, logits_prior = self.denoise_sequence(input_ids, attn_masks)
        _print(f'og tokens: {x0}')
        _print(f'og tokens shape: {x0.shape}')
        _print(f'og log probs: {logp_lm.shape}')

        # Embeddings and attention matrix of current sequence
+
embeds, attn_mats = self.embed_sequence(x0.unsqueeze(0), attn_masks)
|
| 236 |
+
embeds = embeds.detach().clone().requires_grad_(True) # enable grad tracking for saliency map
|
| 237 |
+
attn_matrix = attn_mats[-1].mean(dim=1)[0].squeeze(0)
|
| 238 |
+
|
| 239 |
+
# Precompute logits of the classifier to avoid repeated calls
|
| 240 |
+
batch = {"embeds": embeds, "attention_mask": attn_masks}
|
| 241 |
+
solubility_logits = self.classifier_model(batch)
|
| 242 |
+
|
| 243 |
+
# Create a saliency map to determined optimal edit positions
|
| 244 |
+
saliency_map = self.compute_saliency_map(embeds, solubility_logits)
|
| 245 |
+
_print(f'saliency map: {saliency_map}')
|
| 246 |
+
edit_positions = self.determine_edit_positions(saliency_map, soluble_indices, solubility_logits)
|
| 247 |
+
_print(f'edit positions: {edit_positions}')
|
| 248 |
+
|
| 249 |
+
# Compute the log probs of the prior dist
|
| 250 |
+
logp_prior = self.get_prior(logits_prior, solubility_logits)
|
| 251 |
+
_print(f'prior log probs: {logp_prior.shape}')
|
| 252 |
+
|
| 253 |
+
# Optimize the insoluble residues
|
| 254 |
+
for edit_pos in edit_positions.tolist():
|
| 255 |
+
neighborhood = self.create_neighborhood(
|
| 256 |
+
edit_pos,
|
| 257 |
+
attn_matrix,
|
| 258 |
+
self.top_p
|
| 259 |
+
)
|
| 260 |
+
_print(f'neighborhood: {neighborhood}')
|
| 261 |
+
|
| 262 |
+
ctxt_aware_saliency = self.compute_saliency_weight(
|
| 263 |
+
edit_pos,
|
| 264 |
+
attn_matrix,
|
| 265 |
+
saliency_map,
|
| 266 |
+
neighborhood
|
| 267 |
+
)
|
| 268 |
+
_print(f'ctx aware saliency: {ctxt_aware_saliency}')
|
| 269 |
+
|
| 270 |
+
logp_lm_prime = self.compute_guidance_dist(
|
| 271 |
+
logp_lm[edit_pos],
|
| 272 |
+
logp_prior[edit_pos],
|
| 273 |
+
ctxt_aware_saliency
|
| 274 |
+
)
|
| 275 |
+
logp_lm[edit_pos] = logp_lm_prime
|
| 276 |
+
|
| 277 |
+
tot = torch.exp(logp_lm_prime).sum()
|
| 278 |
+
one = torch.tensor(1.0, dtype=tot.dtype, device=tot.device)
|
| 279 |
+
assert torch.isclose(tot, one, atol=1e-4), f"Invalid prob distribution. Sum = {tot:5f}"
|
| 280 |
+
|
| 281 |
+
# Sample new tokens
|
| 282 |
+
x0_prime = torch.distributions.Categorical(logits=logp_lm).sample()
|
| 283 |
+
|
| 284 |
+
# Check if any soluble residues have been changed
|
| 285 |
+
self.check_scaffold(x0, x0_prime, soluble_indices)
|
| 286 |
+
|
| 287 |
+
# Preserve the initial sequence scaffold by copying over the soluble tokens
|
| 288 |
+
x0_prime[soluble_indices] = x0[soluble_indices]
|
| 289 |
+
self.check_scaffold(x0, x0_prime, soluble_indices)
|
| 290 |
+
|
| 291 |
+
return x0_prime
|
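Since the gate `w` is squashed into [0, 1], the mixture in `compute_guidance_dist` is a convex combination of two categorical distributions and therefore itself a valid distribution, which is exactly what the assert in `optimize_sequence` verifies. A minimal toy sketch (all tensors illustrative, not taken from the model):

# Toy check of the saliency-gated mixture: blending two valid categorical
# distributions with a weight w in [0, 1] always yields a valid distribution.
import torch

logp_lm = torch.log_softmax(torch.randn(20), dim=-1)     # LM log probs over a toy vocab
logp_prior = torch.log_softmax(torch.randn(20), dim=-1)  # prior log probs
w = torch.sigmoid(torch.tensor(0.8))                     # saliency gate in [0, 1]

mixed = (1 - w) * logp_lm.exp() + w * logp_prior.exp()
assert torch.isclose(mixed.sum(), torch.tensor(1.0, dtype=mixed.dtype), atol=1e-6)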
src/sampling/unconditional_generator.py ADDED
@@ -0,0 +1,114 @@
#!/usr/bin/env python3

import sys
import os

import random
import torch
import pandas as pd
import numpy as np

from tqdm import tqdm
from collections import Counter
from omegaconf import OmegaConf
from datetime import datetime
from transformers import AutoTokenizer, AutoModelForMaskedLM

from src.lm.memdlm.diffusion_module import MembraneFlow
from src.sampling.unconditional_sampler import UnconditionalSampler
from src.utils.generate_utils import mask_for_de_novo, calc_ppl
from src.utils.model_utils import _print


device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
os.chdir('/home/a03-sgoel/MeMDLM_v2')
config = OmegaConf.load("./src/configs/lm.yaml")

date = datetime.now().strftime("%Y-%m-%d")


def generate_sequence(prior: str, tokenizer, generator, device):
    input_ids = tokenizer(prior, return_tensors="pt").to(device)['input_ids']
    ids = generator.sample_unconditional(
        xt=input_ids,
        num_steps=config.sampling.n_steps,
        return_logits=False,
        banned_token_ids=None
        # banned_token_ids=[tokenizer.convert_tokens_to_ids("P"), tokenizer.convert_tokens_to_ids("C")]
    )
    generated_sequence = tokenizer.decode(ids[0].squeeze())[5:-5].replace(" ", "")  # strip bos/eos tokens & spaces between residues
    return generated_sequence


def main():
    csv_save_path = f'./results/denovo/unconditional/{config.wandb.name}/{date}_tau=3.0_test-set_distribution'
    os.makedirs(csv_save_path, exist_ok=True)

    tokenizer = AutoTokenizer.from_pretrained(config.lm.pretrained_evoflow)

    flow = MembraneFlow(config).to(device)
    state_dict = flow.get_state_dict(f"./checkpoints/{config.wandb.name}/best_model.ckpt")
    flow.load_state_dict(state_dict)
    flow.eval()

    esm_pth = config.lm.pretrained_esm
    esm_model = AutoModelForMaskedLM.from_pretrained(esm_pth).to(device)
    esm_model.eval()

    generator = UnconditionalSampler(tokenizer, flow)

    # # Get random sequence lengths to generate
    # seq_lengths = [random.randint(50, 250) for _ in range(5000)]

    # # Determine lengths from positive controls
    # df = pd.read_csv(f'./results/denovo/unconditional/{config.wandb.name}/perin_pos_ctrl/raw_seqs.csv')
    # seq_lengths = [len(seq) for seq in df['Sequence'].tolist() for _ in range(500)]  # generate each length 500 times
    # _print(seq_lengths)

    # Determine lengths from the test set distribution
    df = pd.read_csv("./data/test.csv")
    seq_lengths = [len(seq) for seq in df['Sequence'].tolist()]
    length_counts = Counter(seq_lengths)            # {L1: freq, L2: freq, ...}
    total = sum(length_counts.values())             # total number of sequences
    lengths = np.array(list(length_counts.keys()))  # unique lengths
    probs = np.array([length_counts[l] / total for l in lengths])  # empirical probability of each length
    seq_lengths = np.random.choice(lengths, size=len(seq_lengths), p=probs)

    generation_results = []
    for seq_len in tqdm(seq_lengths, desc="Generating sequences: "):
        seq_res = []

        masked_seq = mask_for_de_novo(seq_len)  # Sequence of all <mask> tokens
        gen_seq = ""
        attempts = 0

        while len(gen_seq) != seq_len and attempts < 3:
            gen_seq = generate_sequence(masked_seq, tokenizer, generator, device)
            attempts += 1

        if len(gen_seq) != seq_len:
            esm_ppl, flow_ppl = None, None
        else:
            esm_ppl = calc_ppl(esm_model, tokenizer, gen_seq, [i for i in range(len(gen_seq))], model_type='esm')
            flow_ppl = calc_ppl(flow, tokenizer, gen_seq, [i for i in range(len(gen_seq))], model_type='flow')

        _print(f'gen seq: {gen_seq}')
        _print(f'esm ppl: {esm_ppl}')
        _print(f'flow ppl: {flow_ppl}')

        seq_res.append(gen_seq)
        seq_res.append(esm_ppl)
        seq_res.append(flow_ppl)

        generation_results.append(seq_res)

    df = pd.DataFrame(generation_results, columns=['Generated Sequence', 'ESM PPL', 'Flow PPL'])
    df.to_csv(csv_save_path + "/seqs_with_ppl.csv", index=False)


if __name__ == "__main__":
    main()
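The test-set length matching in main() amounts to resampling from the empirical length distribution; a standalone sketch of the same idea with made-up lengths:

import numpy as np
from collections import Counter

observed = [63, 120, 120, 87, 63, 120]           # hypothetical test-set lengths
counts = Counter(observed)                       # {63: 2, 120: 3, 87: 1}
lengths = np.array(list(counts.keys()))
probs = np.array([counts[l] for l in lengths], dtype=float)
probs /= probs.sum()                             # empirical probabilities

resampled = np.random.choice(lengths, size=1000, p=probs)
# `resampled` follows the same length distribution as `observed`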
src/sampling/unconditional_sampler.py ADDED
@@ -0,0 +1,108 @@
import sys
import torch
import random
import numpy as np
from tqdm import tqdm
from src.utils.model_utils import _print

class UnconditionalSampler:
    def __init__(self, tokenizer, model):
        self.model = model
        self.tokenizer = tokenizer

        self.device = self.model.device
        self.mask_id = self.tokenizer.mask_token_id
        self.seed_everything(seed=42)

    @torch.inference_mode()
    def sample_unconditional(self, xt, num_steps, tau=0.7, kappa_fn=lambda t: t, eta=1, alpha=1., banned_token_ids=None, return_logits=None):
        """
        Stochastic remasking sampling method for iterative refinement of sequences.

        Args:
            xt (Tensor): Initial token tensor.
            num_steps (int): Number of refinement steps.
            tau (float): Temperature parameter for softmax sampling.
            kappa_fn (callable): Function controlling the unmasking schedule.
            eta (float): Scaling factor for score adjustments.
            alpha (float): Weighting for confidence-based scoring.

        Returns:
            Tensor: Final sampled sequence tensor (with logits if return_logits is truthy).
        """
        dt = 1 / num_steps
        fix_mask = xt != self.mask_id  # tokens to retain
        attention_mask = torch.ones_like(xt).to(self.device)

        for i in range(1, num_steps + 1):
            kappa_t = kappa_fn(i * dt)
            logits = self.model(input_ids=xt, attention_mask=attention_mask)
            last_mask = xt == self.mask_id     # tokens currently masked
            unmask_t = ~last_mask & ~fix_mask  # unmasked and not fixed tokens - candidates for remasking

            x0, logp = self.stochastic_sample_from_categorical(logits, tau, banned_token_ids=banned_token_ids)  # tokens, logprobs

            # Confidence-based scoring
            entropy = torch.distributions.Categorical(logits=logits).entropy()
            score = alpha * logp + (1 - alpha) * -entropy  # alpha = 1 --> score = logp
            score = score.masked_fill(fix_mask, float('inf'))

            score[unmask_t] = score[unmask_t] * eta

            num_to_mask = ((~fix_mask).sum(1, keepdim=True).float() * (1 - kappa_t)).long()
            lowest_k_mask = self.topk_lowest_masking(score, num_to_mask)

            xt[lowest_k_mask] = self.mask_id
            mask_2_x0 = last_mask & ~lowest_k_mask
            xt[mask_2_x0] = x0[mask_2_x0]

            # print(f"Step {i}/{num_steps} | eta: {eta}, alpha: {alpha}, Stochastic remask: \n", xt[0])

        xt[xt == self.mask_id] = x0[xt == self.mask_id]
        return (xt, logits) if return_logits else xt

    def stochastic_sample_from_categorical(self, logits, temperature, noise_scale=1.0, banned_token_ids=None):
        """
        Sample from a categorical distribution with optional temperature scaling and Gumbel noise.
        """
        logits = logits.double()

        if banned_token_ids is not None:
            banned_token_mask = torch.zeros_like(logits, device=logits.device).bool()
            for token_id in banned_token_ids:
                banned_token_mask[..., token_id] = True
            logits = logits.masked_fill(banned_token_mask, float('-inf'))

        if temperature != 0:
            gumbel_noise = -torch.log(-torch.log(torch.rand_like(logits) + 1e-8) + 1e-8)
            logits = logits / temperature + noise_scale * gumbel_noise
        scores, tokens = logits.log_softmax(dim=-1).max(dim=-1)

        return tokens, scores

    def topk_lowest_masking(self, scores, cutoff_len):
        """
        scores: [b, n]
        cutoff_len: [b, 1]
        returns:
            mask: [b, n], True if the token is among the cutoff_len lowest scores, False otherwise
        """
        sorted_index = scores.sort(-1)[0]
        cutoff = sorted_index.gather(dim=-1, index=cutoff_len)
        return scores < cutoff

    def seed_everything(self, seed):
        """
        Set the seed for reproducibility across various libraries.
        """
        if seed is None:
            return
        random.seed(seed)
        np.random.seed(seed)
        torch.manual_seed(seed)
        if torch.cuda.is_available():
            torch.cuda.manual_seed(seed)
            torch.cuda.manual_seed_all(seed)  # if using multi-GPU
            torch.backends.cudnn.deterministic = True
            torch.backends.cudnn.benchmark = False
src/utils/generate_utils.py ADDED
@@ -0,0 +1,151 @@
import torch
import math
import sys

import torch.nn.functional as F
import pandas as pd
import numpy as np

from omegaconf import OmegaConf
from transformers import AutoModelForMaskedLM, AutoModel, AutoTokenizer

from src.lm.memdlm.diffusion_module import MembraneFlow
from src.lm.dplm.diffusion_module import DPLM
from src.utils.model_utils import get_latents, _print
from src.sampling.unconditional_sampler import UnconditionalSampler
from src.lm.dplm.unconditional_sampler import UnconditionalSampler as DPLMUnconditionalSampler

config = OmegaConf.load("/home/a03-sgoel/MeMDLM_v2/src/configs/lm.yaml")


# -------# Masking #-------- #
def mask_for_de_novo(sequence_length):
    return "<mask>" * sequence_length


def mask_for_scaffold(sequence, generate_type, mask_token):
    if generate_type == "uppercase":
        sequence = ''.join([mask_token if residue.isupper() else residue.upper() for residue in sequence])
    elif generate_type == "lowercase":
        sequence = ''.join([mask_token if residue.islower() else residue for residue in sequence])
    return sequence


# -------# Generation #-------- #
def memflow_infill_uncond(masked_seq, tokenizer, model: MembraneFlow):
    generator = UnconditionalSampler(tokenizer, model)  # initialize the generator object
    xt = tokenizer(masked_seq, return_tensors='pt')['input_ids'].to(model.device)
    denoised_tokens = generator.sample_unconditional(xt, config.sampling.n_steps)[0].squeeze()
    generated_sequence = tokenizer.decode(denoised_tokens).replace(" ", "")[5:-5]
    return generated_sequence


def evodiff_infill(motif_seq, tokenizer, model, device, batch_size=1):
    """
    Following the given evodiff example:
    https://github.com/microsoft/evodiff/blob/main/examples/evodiff.ipynb
    """
    # Manual masking of the infilling sequence ("#" is the mask token in the evodiff tokenizer)
    motif_seq = ''.join(["#" if aa.islower() else aa for aa in motif_seq])
    tkns = tokenizer.tokenize([motif_seq])
    sample = torch.as_tensor(tkns).to(device)

    # Create the input motif + scaffold
    loc = torch.arange(0, len(motif_seq)).to(device)[sample == tokenizer.mask_id].cpu().numpy()
    np.random.shuffle(loc)

    sample = sample.to(device).unsqueeze(0)
    # og_sample = sample.clone()

    with torch.no_grad():
        for i in loc:
            timestep = torch.tensor([0] * batch_size).to(device)  # placeholder; not used by the model
            prediction = model(sample, timestep)
            p = prediction[:, i, :len(tokenizer.all_aas) - 6]  # only canonical residues
            p = F.softmax(p, dim=1)  # softmax over logits
            p_sample = torch.multinomial(p, num_samples=1)  # sample from the categorical distribution
            sample[:, i] = p_sample.squeeze()
    output = [tokenizer.untokenize(s) for s in sample]
    return output[0]  # if batch_size == 1 else output, og_sample, loc


def dplm_infill(masked_seq, tokenizer, model: DPLM, device):
    generator = DPLMUnconditionalSampler(tokenizer, model)
    xt = tokenizer(masked_seq, return_tensors='pt')['input_ids'].to(model.device)
    denoised_tokens = generator.sample_unconditional(xt, config.sampling.n_steps)[0].squeeze()
    generated_sequence = tokenizer.decode(denoised_tokens).replace(" ", "")[5:-5]
    return generated_sequence


# -------# Metrics #-------- #
def calc_progen_ppl(model, tokenizer, target, device, fp16=True):
    """Compute causal LM perplexity for a given sequence (target is a 1-D token tensor)."""
    with torch.no_grad():
        with torch.cuda.amp.autocast(enabled=fp16):
            logits = model(
                input_ids=target,
                attention_mask=torch.ones_like(target)
            ).logits
            # Shift so that position t predicts token t+1
            logits = logits[:-1, ...]
            target = target[1:]
            loss = torch.nn.functional.cross_entropy(
                input=logits,
                target=target,
                reduction='mean'
            )
    return torch.exp(loss).item()


def calc_ppl(model, tokenizer, generated_sequence, mask_token_indices, model_type):
    total_loss = 0.0
    tensor_input = tokenizer.encode(generated_sequence, return_tensors='pt').to(model.device)
    attn_mask = torch.ones_like(tensor_input).to(model.device)

    for i in mask_token_indices:
        masked_input = tensor_input.clone()
        masked_input[0, i] = tokenizer.mask_token_id

        labels = torch.full(tensor_input.shape, -100).to(model.device)
        labels[0, i] = tensor_input[0, i]

        with torch.no_grad():
            if model_type == 'esm':
                loss = model(masked_input, labels=labels).loss.item()
            elif model_type == 'flow':
                logits = model.forward(masked_input, attention_mask=attn_mask)
                loss = F.cross_entropy(
                    logits.view(-1, logits.size(-1)),
                    labels.view(-1),
                    reduction='none',
                    ignore_index=-100,
                )[i].item()

        total_loss += loss

    # Average over the masked positions (the callers mask every residue once)
    avg_loss = total_loss / len(mask_token_indices)
    perplexity = math.exp(avg_loss)

    return perplexity


def calc_blosum_score(og_seq, gen_seq, indices):
    import blosum as bl  # lazy import so blosum stays an optional dependency
    mat = bl.BLOSUM(62)
    tot_score = 0
    for i in indices:
        og_res, gen_res = og_seq[i], gen_seq[i]
        try:
            val = mat[og_res][gen_res]
            tot_score += val
        except KeyError:
            # -4 is the lowest BLOSUM score, indicating biological implausibility
            tot_score += -4
    return tot_score / len(indices) if indices else 0


def calc_cos_sim(original_sequence, generated_sequence, tokenizer, esm_model, device):
    og_embeddings = get_latents(esm_model, tokenizer, original_sequence.upper(), device)
    new_embeddings = get_latents(esm_model, tokenizer, generated_sequence, device)
    cosine_sim = torch.nn.functional.cosine_similarity(og_embeddings, new_embeddings, dim=-1)
    cosine_sim = torch.mean(cosine_sim).item()
    return cosine_sim
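calc_ppl implements masked (pseudo-)perplexity: mask each position once, collect the per-position cross-entropy, and exponentiate the mean. A worked toy example with made-up losses:

import math

# Hypothetical per-position cross-entropy losses for a 5-residue sequence,
# one loss per masked position, as calc_ppl accumulates them.
losses = [2.1, 1.8, 2.4, 1.9, 2.3]
avg_loss = sum(losses) / len(losses)  # 2.1
pseudo_ppl = math.exp(avg_loss)       # ~8.17
print(round(pseudo_ppl, 2))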
src/utils/model_utils.py ADDED
@@ -0,0 +1,139 @@
import sys
import torch
import torch.nn as nn


def _print(s):
    print(s)
    sys.stdout.flush()


def get_latents(model, tokenizer, sequence, device):
    tokens = tokenizer(sequence, return_tensors="pt").to(device)
    with torch.no_grad():
        outputs = model(**tokens, output_hidden_states=True)
    embeds = outputs.hidden_states[-1].squeeze(0)  # last hidden states
    return embeds


# General model freezing
def freeze_model(model: nn.Module):
    # Disable parameter updates for all layers
    for param in model.parameters():
        param.requires_grad = False


# For the ProGen2 architecture
def apply_gptj_freezing(model, N_layers):
    def unfreeze_n_layers(model, N_layers):
        # Count the number of transformer blocks
        model_layers = len(model.transformer.h)
        for i, h in enumerate(model.transformer.h):
            if i >= model_layers - N_layers:
                for module in h.attn.modules():
                    for param in module.parameters():
                        param.requires_grad = True

    def check_frozen_model(model, N_layers: int):
        """
        Verify that only the last N_layers of model.transformer.h are unfrozen.
        Source: https://github.com/enijkamp/progen2/blob/main/progen/modeling_progen.py
        """
        model_layers = len(model.transformer.h)
        frozen_layers = 0
        unfrozen_layers = 0
        for i, h in enumerate(model.transformer.h):
            if i >= model_layers - N_layers:  # should be unfrozen
                if any(param.requires_grad for param in h.parameters()):
                    unfrozen_layers += 1
                else:
                    print(f"Layer {i} has all parameters frozen, but it should be unfrozen.")
            else:  # should be frozen
                if any(param.requires_grad for param in h.parameters()):
                    print(f"Layer {i} is not frozen, but it should be frozen.")
                else:
                    frozen_layers += 1

        assert frozen_layers == model_layers - N_layers and unfrozen_layers == N_layers, \
            f"frozen layers: {frozen_layers}, unfrozen layers: {unfrozen_layers}"

        print(f"frozen layers: {frozen_layers}, unfrozen layers: {unfrozen_layers}")

    freeze_model(model)
    unfreeze_n_layers(model, N_layers)
    check_frozen_model(model, N_layers)


# For RDM-based architectures
def apply_rdm_freezing(model: nn.Module, N_layers: int, model_type: str):
    """
    Freeze all layers except the last N for ESM-like architectures.

    Args:
        model (nn.Module): model to freeze
        N_layers (int): num encoder layers to unfreeze
        model_type (str): one of {"esm", "evoflow", "dplm"}
    """
    # Choose the encoder layers based on the model type
    if model_type == "dplm":
        encoder_layers = model.net.esm.encoder.layer
    elif model_type in ("esm", "evoflow"):
        encoder_layers = model.esm.encoder.layer
    else:
        raise ValueError(f"Unknown model_type: {model_type}")

    def unfreeze_n_layers(layers, N_layers: int):
        model_layers = len(layers)
        for i, layer in enumerate(layers):
            if i >= model_layers - N_layers:
                # Unfreeze only the self-attention K/Q/V projections
                for attn_proj in (layer.attention.self.key, layer.attention.self.query, layer.attention.self.value):
                    for param in attn_proj.parameters():
                        param.requires_grad = True

    def check_model(layers, N_layers: int):
        model_layers = len(layers)
        frozen_layers = 0
        unfrozen_layers = 0

        for i, layer in enumerate(layers):
            if i >= model_layers - N_layers:  # should have unfrozen K/Q/V
                layer_frozen = not any(
                    param.requires_grad
                    for attn_proj in (layer.attention.self.key, layer.attention.self.query, layer.attention.self.value)
                    for param in attn_proj.parameters()
                )
                if layer_frozen:
                    print(f"layer {i} has all parameters frozen, but it should be unfrozen.")
                else:
                    unfrozen_layers += 1
            else:  # should be frozen
                if any(param.requires_grad for param in layer.parameters()):
                    print(f"layer {i} is not frozen, but it should be frozen.")
                else:
                    frozen_layers += 1

        assert (frozen_layers == model_layers - N_layers) and (unfrozen_layers == N_layers), \
            f"frozen layers: {frozen_layers}, unfrozen layers: {unfrozen_layers}"

    freeze_model(model)
    unfreeze_n_layers(encoder_layers, N_layers)
    check_model(encoder_layers, N_layers)
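A quick sanity check for either freezing helper is to count trainable parameters before and after; a sketch assuming an HF ESM-2 checkpoint and apply_rdm_freezing from this module in scope:

# Hypothetical check; the checkpoint name is an assumption.
from transformers import AutoModelForMaskedLM

model = AutoModelForMaskedLM.from_pretrained("facebook/esm2_t33_650M_UR50D")

def n_trainable(m):
    return sum(p.numel() for p in m.parameters() if p.requires_grad)

before = n_trainable(model)
apply_rdm_freezing(model, N_layers=3, model_type="esm")
after = n_trainable(model)
print(f"trainable params: {before:,} -> {after:,}")  # only the last 3 layers' K/Q/V remain trainable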
src/utils/optimizer_utils.py ADDED
@@ -0,0 +1,300 @@
import torch
import math

from torch.optim import Optimizer
from torch.optim.lr_scheduler import LambdaLR
from torch.optim.adamw import adamw

try:
    import deepspeed
    from deepspeed.ops.adam import FusedAdam
    from deepspeed.ops.adam import DeepSpeedCPUAdam
except ImportError:
    pass


def get_optimizer(cfg, params):
    if cfg.optim.type == 'adam':
        return torch.optim.Adam(
            params=params,
            lr=cfg.optim.lr,
            weight_decay=cfg.optim.weight_decay,
            betas=(cfg.optim.beta1, cfg.optim.beta2)
        )
    elif cfg.optim.type == 'adamw':
        return AdamW(
            params=params,
            lr=cfg.optim.lr,
            weight_decay=cfg.optim.weight_decay,
            betas=(cfg.optim.beta1, cfg.optim.beta2)
        )
    elif cfg.optim.type == 'fusedadam':
        return FusedAdam(
            params=params,
            lr=cfg.optim.lr,
            weight_decay=cfg.optim.weight_decay,
            betas=(cfg.optim.beta1, cfg.optim.beta2),
        )
    else:
        raise NotImplementedError('Optimizer not supported: %s' % cfg.optim.type)


class AdamW(torch.optim.AdamW):
    """torch.optim.AdamW with the per-parameter step counters kept on CPU."""

    @torch.no_grad()
    def step(self, closure=None):
        """Performs a single optimization step.

        Args:
            closure (callable, optional): A closure that reevaluates the model
                and returns the loss.
        """
        self._cuda_graph_capture_health_check()

        loss = None
        if closure is not None:
            with torch.enable_grad():
                loss = closure()

        for group in self.param_groups:
            params_with_grad = []
            grads = []
            exp_avgs = []
            exp_avg_sqs = []
            max_exp_avg_sqs = []
            state_steps = []
            amsgrad = group['amsgrad']
            beta1, beta2 = group['betas']

            for p in group['params']:
                if p.grad is None:
                    continue
                params_with_grad.append(p)
                if p.grad.is_sparse:
                    raise RuntimeError('AdamW does not support sparse gradients')
                grads.append(p.grad)

                state = self.state[p]

                # State initialization
                if len(state) == 0:
                    state['step'] = torch.zeros((1,), dtype=torch.float, device=p.device) \
                        if self.defaults['capturable'] else torch.tensor(0.)
                    # Exponential moving average of gradient values
                    state['exp_avg'] = torch.zeros_like(p, memory_format=torch.preserve_format)
                    # Exponential moving average of squared gradient values
                    state['exp_avg_sq'] = torch.zeros_like(p, memory_format=torch.preserve_format)
                    if amsgrad:
                        # Maintains max of all exp. moving avg. of sq. grad. values
                        state['max_exp_avg_sq'] = torch.zeros_like(p, memory_format=torch.preserve_format)

                exp_avgs.append(state['exp_avg'])
                exp_avg_sqs.append(state['exp_avg_sq'])

                if amsgrad:
                    max_exp_avg_sqs.append(state['max_exp_avg_sq'])

                state_steps.append(state['step'].cpu())  # keep step counts on CPU

            adamw(params_with_grad,
                  grads,
                  exp_avgs,
                  exp_avg_sqs,
                  max_exp_avg_sqs,
                  state_steps,
                  amsgrad=amsgrad,
                  beta1=beta1,
                  beta2=beta2,
                  lr=group['lr'],
                  weight_decay=group['weight_decay'],
                  eps=group['eps'],
                  maximize=group['maximize'],
                  foreach=group['foreach'],
                  capturable=group['capturable'])

        return loss


def get_scheduler(cfg, optimizer):
    if cfg.optim.scheduler is None:
        return BlackHole()
    elif cfg.optim.scheduler == 'plateau':
        return (
            torch.optim.lr_scheduler.ReduceLROnPlateau(
                optimizer,
                mode=cfg.mode,
                factor=cfg.factor,
                patience=cfg.patience,
                min_lr=cfg.min_lr,
            ),
            {'monitor': "val/loss", 'interval': 'epoch'}
        )
    elif cfg.optim.scheduler == 'noam':
        return (
            NoamScheduler(
                optimizer,
                lr=cfg.lr,
                warmup_steps=cfg.warmup_steps,
                model_size=cfg.model_size,
                warmup_init_lr=cfg.get('warmup_init_lr')
            ),
            {'frequency': 1, 'interval': 'step'}
        )
    elif cfg.optim.scheduler == 'polynomial':
        return (
            PolyNomialLRScheduler(
                optimizer,
                total_steps=cfg.training.max_steps,
                warmup_steps=cfg.training.warmup_steps,
                lr=cfg.optim.lr,
                lr_end=cfg.optim.lr_end,
                warmup_init_lr=cfg.optim.warmup_init_lr,
                power=cfg.optim.power
            ),
            {'frequency': 1, 'interval': 'step'}
        )
    elif cfg.optim.scheduler == 'multistep':
        return torch.optim.lr_scheduler.MultiStepLR(
            optimizer,
            milestones=cfg.milestones,
            gamma=cfg.gamma,
        )
    elif cfg.optim.scheduler == 'exp':
        return torch.optim.lr_scheduler.ExponentialLR(
            optimizer,
            gamma=cfg.gamma,
        )
    elif cfg.optim.scheduler == 'progen_ft':
        sched = CosineToFrac(
            optimizer=optimizer,
            total_steps=cfg.training.max_steps,
            final_frac=0.2,  # decay to lr/5
        )
        return (sched, {'frequency': 1, 'interval': 'step'})
    else:
        raise NotImplementedError('Scheduler not supported: %s' % cfg.optim.scheduler)


class BlackHole(object):
    def __setattr__(self, name, value):
        pass

    def __call__(self, *args, **kwargs):
        return self

    def __getattr__(self, name):
        return self


# -------# DPLM Scheduler #-------- #
def polynomial_lr_schedule(step, total_steps, warmup_steps, warmup_init_lr, lr, lr_end, power):
    if step < warmup_steps:
        return warmup_init_lr + (lr - warmup_init_lr) * step / warmup_steps
    elif step > total_steps:
        return lr_end
    else:
        return lr_end + (lr - lr_end) * (1 - (step - warmup_steps) / (total_steps - warmup_steps)) ** power


class PolyNomialLRScheduler(LambdaLR):
    def __init__(
        self,
        optimizer: Optimizer,
        total_steps: int = 1000,
        warmup_steps: int = 0,
        lr: float = 0.00004,        # 5e-04
        lr_end: float = 1e-5,       # 1e-07
        warmup_init_lr: float = 1e-07,
        power: float = 1.0,
    ) -> None:

        self.warmup_init_lr = warmup_init_lr
        self.warmup_steps = warmup_steps

        def lr_lambda(step):
            return polynomial_lr_schedule(
                step, total_steps, warmup_steps, warmup_init_lr, lr, lr_end, power
            ) / lr

        super().__init__(optimizer, lr_lambda=lr_lambda)


# -------# ProGen2 Fine-Tuning Scheduler #-------- #
def cosine_frac_scheduler(step, total_steps, final_frac):
    s = min(max(step, 0), total_steps)
    cos = 0.5 * (1.0 + math.cos(math.pi * s / total_steps))  # 1 --> 0
    return final_frac + (1.0 - final_frac) * cos  # multiplier goes from 1.0 down to final_frac


class CosineToFrac(LambdaLR):
    """
    Cosine decay of the LR multiplier from 1.0 -> final_frac over total_steps (no warmup).
    For ProGen fine-tuning, final_frac=0.2 implements decay to lr/5.
    """
    def __init__(self, optimizer, total_steps, final_frac=0.2):
        self.total_steps = max(int(total_steps), 1)
        self.final_frac = float(final_frac)

        def lr_lambda(step):
            return cosine_frac_scheduler(
                step=step,
                total_steps=self.total_steps,
                final_frac=self.final_frac
            )

        super().__init__(optimizer, lr_lambda=lr_lambda)


def inverse_sqrt_lr_schedule(step, warmup_steps, warmup_init_lr, lr_step, decay_step):
    if step == 0:
        step = 1
    if step < warmup_steps:
        return warmup_init_lr + lr_step * step
    else:
        return decay_step * step ** -0.5


class InverseSqrtLRScheduler(LambdaLR):
    def __init__(
        self,
        optimizer: Optimizer,
        warmup_steps: int = 0,
        lr: float = 5e-04,
        warmup_init_lr: float = 1e-07,
    ) -> None:

        self.warmup_init_lr = warmup_init_lr
        self.warmup_steps = warmup_steps
        self.lr_step = (lr - warmup_init_lr) / warmup_steps
        self.decay_step = lr * warmup_steps ** 0.5

        def lr_lambda(step):
            return inverse_sqrt_lr_schedule(
                step, warmup_steps, warmup_init_lr, self.lr_step, self.decay_step
            ) / lr

        super().__init__(optimizer, lr_lambda=lr_lambda)


def noam_lr_schedule(step, warmup_steps, factor, model_size):
    if step == 0:
        step = 1
    return factor * (model_size ** (-0.5) * min(step ** (-0.5), step * warmup_steps ** (-1.5)))


class NoamScheduler(LambdaLR):
    def __init__(
        self,
        optimizer: Optimizer,
        lr,
        warmup_init_lr,
        model_size: int = 128,
        warmup_steps: int = 0,
        factor: int = 2,
    ) -> None:

        # dummy_lr = self.base_lrs[0]
        def lr_lambda(step):
            return noam_lr_schedule(step, warmup_steps, factor, model_size) / lr

        super().__init__(optimizer, lr_lambda=lr_lambda)
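The polynomial schedule is easy to sanity-check numerically before launching a run; a short sketch with illustrative hyperparameters close to the defaults above:

# Print the polynomial LR at a few milestones; all values are illustrative.
total_steps, warmup_steps = 50_000, 1_000
lr, lr_end, warmup_init_lr, power = 4e-5, 1e-5, 1e-7, 1.0

for step in (0, 500, 1_000, 25_000, 50_000):
    val = polynomial_lr_schedule(step, total_steps, warmup_steps, warmup_init_lr, lr, lr_end, power)
    print(f"step {step:>6}: lr = {val:.2e}")
# Ramps linearly to 4e-5 by step 1000, then decays linearly to 1e-5 at step 50000.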