Shrey Goel committed
Commit d04a061 · 0 Parent(s)

adding code
.DS_Store ADDED
Binary file (6.15 kB).
 
.gitignore ADDED
@@ -0,0 +1,24 @@
+ # .gitignore
+
+ /checkpoints/
+ /data/
+ /results/
+ /build/
+ /src/scripts/
+ /src/benchmarks
+
+ /src/lm/dplm
+ /src/lm/evodiff
+ /src/lm/dplm_playground.ipynb
+ /src/lm/evoflow_playground.ipynb
+ /src/utils/ubuntu_font
+
+ /src/sampling/old_guidance.py
+
+ /MeMDLM_v2.egg-info/
+ *.pth
+ *.ckpt
+ *.err
+ *.out
+ *.csv
+ __pycache__/
README.md ADDED
@@ -0,0 +1 @@
+ # MeMDLM_v2
__init__.py ADDED
File without changes
setup.py ADDED
@@ -0,0 +1,10 @@
+ from setuptools import setup, find_packages
+
+ setup(
+     name='MeMDLM_v2',
+     version='1.0',
+     packages=find_packages(),
+     install_requires=[],
+     author='Shrey Goel',
+     author_email='[email protected]'
+ )
src/configs/guidance.yaml ADDED
@@ -0,0 +1,74 @@
+
+
+ seed: 42
+ base_dir: /scratch/sgoel/MeMDLM_v2
+
+
+ lm:
+   pretrained_esm: facebook/esm2_t33_650M_UR50D
+   pretrained_evoflow: fredzzp/EvoFlow-650M-context-3070
+   pretrained_dplm: airkingbd/dplm_650m
+   ft_evoflow: ft_eflow-3070-650M_steps=50k_layers=3_lr=0.00004_wd=.01_polynom_pwr=1_betas=.9-.98_bsz=8_gclip=1.0
+   ft_dplm: ft_dplm-650M_steps=5k_layers=3_lr=0.00004_wd=.01_polynom_pwr=1_betas=.9-.98_bsz=32_gclip=1.0
+
+ model:
+   d_model: 1280
+   num_heads: 2
+   dropout: 0.5
+   num_layers: 4
+   label_pad_value: -100
+
+ optim:
+   type: adamw
+   lr: 3e-5
+   lr_end: 1e-5
+   weight_decay: 0.01
+   beta1: 0.9
+   beta2: 0.98
+   power: 1
+
+
+ training:
+   mode: test # train / test
+   n_layers: 4
+   max_steps: 3000
+   warmup_steps: 150
+   log_every_n_steps: 10
+   num_sanity_val_steps: 2
+   val_check_interval: 250
+   enable_progress_bar: true
+   grad_clip_val: 1.0
+   devices: [0] # list of GPU IDs from 0-7
+
+ guidance:
+   n_steps: 128
+   alpha: 3
+   gamma: 0.3
+   saliency_eps: 1e-4
+   saliency_t: 2.0
+   sampling_t: 0.7
+   boltzmann_t: 0.3
+   top_p: 0.2
+   steps: 128
+   prior: lm_probs # lm_probs / boltzmann
+
+ data:
+   batch_size: 32
+   max_seq_len: 1024
+   train: ${base_dir}/data/classifier/train.csv
+   test: ${base_dir}/data/classifier/test.csv
+   val: ${base_dir}/data/classifier/val.csv
+
+
+ wandb:
+   project: memdlm_guidance
+   group: programmablebio
+   name: new_data_cleaned_steps3k_lr3e-5_bsz32_heads2_drpt0.5_layers4
+   id: ${.name}_${seed}
+
+
+ checkpointing:
+   save_every_n_steps: 250
+   save_dir: ${base_dir}/checkpoints/${wandb.name}
+   resume_ckpt_path: ${checkpointing.save_dir}/last.ckpt
+   best_ckpt_path: ${checkpointing.save_dir}/best_model.ckpt
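Editor's note: the ${...} references above are OmegaConf interpolations; the training scripts in this commit load the file with OmegaConf.load, so values such as data.train and checkpointing.save_dir resolve against base_dir and wandb.name at access time. A minimal sketch of that resolution (the relative path below is assumed, not part of the commit):

    from omegaconf import OmegaConf

    config = OmegaConf.load("src/configs/guidance.yaml")   # hypothetical local path
    print(config.data.train)              # /scratch/sgoel/MeMDLM_v2/data/classifier/train.csv
    print(config.checkpointing.save_dir)  # /scratch/sgoel/MeMDLM_v2/checkpoints/new_data_cleaned_steps3k_lr3e-5_bsz32_heads2_drpt0.5_layers4
    print(OmegaConf.to_yaml(config, resolve=True))  # fully resolved copy of the config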
src/configs/lm.yaml ADDED
@@ -0,0 +1,66 @@
+
+
+ seed: 42
+ base_dir: /scratch/pranamlab/sgoel/MeMDLM_v2
+
+
+ lm:
+   pretrained_esm: facebook/esm2_t33_650M_UR50D
+   pretrained_evoflow: fredzzp/EvoFlow-650M-context-3070
+   pretrained_dplm: airkingbd/dplm_650m
+   pretrained_progen: hugohrban/progen2-base
+   num_diffusion_timesteps: 500
+   weight_type: linear # constant / linear
+
+
+ optim:
+   type: adamw
+   scheduler: polynomial
+   lr: 0.00004
+   lr_end: 1e-5
+   warmup_init_lr: 1e-07
+   weight_decay: 0.01
+   beta1: 0.9
+   beta2: 0.98
+   power: 1
+
+
+ training:
+   mode: train # train / test / resume_from_checkpoint
+   n_layers: 3
+   max_steps: 5000
+   warmup_steps: 25
+   log_every_n_steps: 10
+   num_sanity_val_steps: 2
+   val_check_interval: 250
+   enable_progress_bar: true
+   grad_clip_val: 1.0
+   devices: [0,1,2] # list of GPU IDs
+
+ sampling:
+   n_steps: 128
+
+
+ data:
+   batch_size: 8
+   max_seq_len: 1024
+   train: ${base_dir}/data/new/train.csv
+   test: ${base_dir}/data/new/test.csv
+   val: ${base_dir}/data/new/val.csv
+
+
+ wandb:
+   project: memdlm
+   group: programmablebio
+   name: ft_eflow-3070-650M_steps=5k_layers=3_lr=0.00004_wd=.01_polynom_pwr=1_betas=.9-.98_bsz=8_gclip=1.0_ml=1024
+   # name: ft_progen-base-764M_steps=50k_layers=2_lr=0.00004_wd=.1_cosine-to-frac_betas=.9-.999_bsz=8_gclip=0.8
+   # name: ft_dplm-650M_steps=5k_layers=3_lr=0.00004_wd=.01_polynom_pwr=1_betas=.9-.98_bsz=32_gclip=1.0
+   # name: ft_esm-650M_steps=3k_layers=3_lr=0.00004_wd=.01_polynom_pwr=1_betas=.9-.98_bsz=32_gclip=1.0
+   id: ${.name}_${seed}
+
+
+ checkpointing:
+   save_every_n_steps: 250
+   save_dir: ${base_dir}/checkpoints/${wandb.name}
+   resume_ckpt_path: ${checkpointing.save_dir}/last.ckpt
+   best_ckpt_path: ${checkpointing.save_dir}/best_model.ckpt
src/guidance/dataloader.py ADDED
@@ -0,0 +1,108 @@
+ import torch
+ import pandas as pd
+ import lightning.pytorch as pl
+
+ from transformers import AutoModel, AutoTokenizer
+ from torch.utils.data import Dataset, DataLoader
+
+
+ class MembraneDataset(Dataset):
+     def __init__(self, config, data_path):
+         self.config = config
+         self.data = pd.read_csv(data_path)
+         self.tokenizer = AutoTokenizer.from_pretrained(self.config.lm.pretrained_esm)
+
+     def __len__(self):
+         return len(self.data)
+
+     def __getitem__(self, idx):
+         sequence = self.data.iloc[idx]["Sequence"]
+
+         tokens = self.tokenizer(
+             sequence.upper(),
+             return_tensors='pt',
+             padding='max_length',
+             truncation=True,
+             max_length=self.config.data.max_seq_len,
+         )
+
+         labels = self.get_labels(sequence)
+
+         return {
+             "input_ids": tokens['input_ids'],
+             "attention_mask": tokens['attention_mask'],
+             "labels": labels
+         }
+
+     def get_labels(self, sequence):
+         max_len = self.config.data.max_seq_len
+
+         # Create per-residue labels
+         labels = torch.tensor([1 if residue.islower() else 0 for residue in sequence], dtype=torch.float)
+
+         if len(labels) < max_len:  # Padding if sequence shorter than tokenizer truncation length
+             padded_labels = torch.cat(
+                 [labels, torch.full(size=(max_len - len(labels),), fill_value=self.config.model.label_pad_value)]
+             )
+         else:  # Truncation otherwise
+             padded_labels = labels[:max_len]
+         return padded_labels
+
+
+ def collate_fn(batch):
+     input_ids = torch.stack([item['input_ids'].squeeze(0) for item in batch])
+     masks = torch.stack([item['attention_mask'].squeeze(0) for item in batch])
+     labels = torch.stack([item['labels'] for item in batch])
+
+     return {
+         'input_ids': input_ids,
+         'attention_mask': masks,
+         'labels': labels
+     }
+
+
+ class MembraneDataModule(pl.LightningDataModule):
+     def __init__(self, config, train_dataset, val_dataset, test_dataset, collate_fn=collate_fn):
+         super().__init__()
+         self.train_dataset = train_dataset
+         self.val_dataset = val_dataset
+         self.test_dataset = test_dataset
+         self.collate_fn = collate_fn
+         self.batch_size = config.data.batch_size
+
+     def train_dataloader(self):
+         return DataLoader(self.train_dataset,
+                           batch_size=self.batch_size,
+                           collate_fn=self.collate_fn,
+                           num_workers=8,
+                           pin_memory=True)
+
+     def val_dataloader(self):
+         return DataLoader(self.val_dataset,
+                           batch_size=self.batch_size,
+                           collate_fn=self.collate_fn,
+                           num_workers=8,
+                           pin_memory=True)
+
+     def test_dataloader(self):
+         return DataLoader(self.test_dataset,
+                           batch_size=self.batch_size,
+                           collate_fn=self.collate_fn,
+                           num_workers=8,
+                           pin_memory=True)
+
+
+ def get_datasets(config):
+     """Helper method to grab datasets to quickly init data module in main.py"""
+     esm_model = AutoModel.from_pretrained(config.lm.pretrained_esm)
+     tokenizer = AutoTokenizer.from_pretrained(config.lm.pretrained_esm)
+
+     train_dataset = MembraneDataset(config, config.data.train)
+     val_dataset = MembraneDataset(config, config.data.val)
+     test_dataset = MembraneDataset(config, config.data.test)
+
+     return {
+         "train": train_dataset,
+         "val": val_dataset,
+         "test": test_dataset
+     }
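Editor's note: the labeling convention in MembraneDataset.get_labels marks lowercase residues as 1 and uppercase as 0, then pads to max_seq_len with label_pad_value so padded positions can be excluded from the loss. A small illustration of that convention (toy sequence and length, not from the dataset):

    import torch

    sequence = "MKTlliAVL"                                    # toy mixed-case sequence
    labels = torch.tensor([1.0 if r.islower() else 0.0 for r in sequence])
    # labels -> tensor([0., 0., 0., 1., 1., 1., 0., 0., 0.])
    max_len, pad_value = 12, -100.0                           # stand-ins for config values
    padded = torch.cat([labels, torch.full((max_len - len(labels),), pad_value)])
    # padded.shape -> torch.Size([12]); the last 3 entries are -100 and ignored by the classifier loss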
src/guidance/main.py ADDED
@@ -0,0 +1,73 @@
+ #!/usr/bin/env python3
+
+ import os
+ import wandb
+ import lightning.pytorch as pl
+
+ from omegaconf import OmegaConf
+ from lightning.pytorch.strategies import DDPStrategy
+ from lightning.pytorch.loggers import WandbLogger
+ from lightning.pytorch.callbacks import ModelCheckpoint, LearningRateMonitor
+
+ from src.utils.model_utils import _print
+ from src.guidance.solubility_module import SolubilityClassifier
+ from src.guidance.dataloader import MembraneDataModule, get_datasets
+
+
+ config = OmegaConf.load("/scratch/sgoel/MeMDLM_v2/src/configs/guidance.yaml")
+ wandb.login(key='2b76a2fa2c1cdfddc5f443602c17b011fefb0a8f')
+
+ # data
+ datasets = get_datasets(config)
+ data_module = MembraneDataModule(
+     config=config,
+     train_dataset=datasets['train'],
+     val_dataset=datasets['val'],
+     test_dataset=datasets['test'],
+ )
+
+ # wandb logging
+ #wandb.init(project=config.wandb.project, name=config.wandb.name)
+ wandb_logger = WandbLogger(**config.wandb)
+
+ # lightning checkpoints
+ lr_monitor = LearningRateMonitor(logging_interval="step")
+ checkpoint_callback = ModelCheckpoint(
+     monitor="val/loss",
+     save_top_k=1,
+     mode="min",
+     dirpath=config.checkpointing.save_dir,
+     filename="best_model",
+ )
+
+ # lightning trainer
+ trainer = pl.Trainer(
+     max_steps=config.training.max_steps,
+     accelerator="cuda",
+     devices=1,  #config.training.devices if config.training.mode=='train' else [0],
+     #strategy=DDPStrategy(find_unused_parameters=True),
+     callbacks=[checkpoint_callback, lr_monitor],
+     logger=wandb_logger,
+     log_every_n_steps=config.training.log_every_n_steps
+ )
+
+ # Folder to save checkpoints
+ ckpt_dir = config.checkpointing.save_dir
+ os.makedirs(ckpt_dir, exist_ok=True)
+
+ # instantiate model
+ model = SolubilityClassifier(config)
+
+ # train or evaluate the model
+ if config.training.mode == "train":
+     trainer.fit(model, datamodule=data_module)
+
+ elif config.training.mode == "test":
+     ckpt_path = os.path.join(ckpt_dir, "best_model.ckpt")
+     state_dict = model.get_state_dict(ckpt_path)
+     model.load_state_dict(state_dict)
+     trainer.test(model, datamodule=data_module, ckpt_path=ckpt_path)
+ else:
+     raise ValueError(f"{config.training.mode} is invalid. Must be 'train' or 'test'")
+
+ wandb.finish()
src/guidance/solubility_module.py ADDED
@@ -0,0 +1,155 @@
+ import gc
+ import torch
+ import torch.nn as nn
+ import lightning.pytorch as pl
+
+ from omegaconf import OmegaConf
+ from transformers import AutoModel
+ from torchmetrics.classification import BinaryAUROC, BinaryAccuracy
+
+ from src.utils.model_utils import _print
+ from src.guidance.utils import CosineWarmup
+
+
+ config = OmegaConf.load("/scratch/sgoel/MeMDLM_v2/src/configs/guidance.yaml")
+
+ class SolubilityClassifier(pl.LightningModule):
+     def __init__(self, config):
+         super().__init__()
+         self.config = config
+         self.loss_fn = nn.BCEWithLogitsLoss(reduction='none')
+         self.auroc = BinaryAUROC()
+         self.accuracy = BinaryAccuracy()
+
+         self.esm_model = AutoModel.from_pretrained(self.config.lm.pretrained_esm)
+         for p in self.esm_model.parameters():
+             p.requires_grad = False
+
+         encoder_layer = nn.TransformerEncoderLayer(
+             d_model=config.model.d_model,
+             nhead=config.model.num_heads,
+             dropout=config.model.dropout,
+             batch_first=True
+         )
+         self.encoder = nn.TransformerEncoder(encoder_layer, config.model.num_layers)
+         self.layer_norm = nn.LayerNorm(config.model.d_model)
+         self.dropout = nn.Dropout(config.model.dropout)
+         self.mlp = nn.Sequential(
+             nn.Linear(config.model.d_model, config.model.d_model // 2),
+             nn.ReLU(),
+             nn.Dropout(config.model.dropout),
+             nn.Linear(config.model.d_model // 2, 1),
+         )
+
+     # -------# Classifier step #-------- #
+     def forward(self, batch):
+         if 'input_ids' in batch:
+             esm_embeds = self.get_esm_embeddings(batch['input_ids'], batch['attention_mask'])
+         elif 'embeds' in batch:
+             esm_embeds = batch['embeds']
+         encodings = self.encoder(esm_embeds, src_key_padding_mask=(batch['attention_mask'] == 0))
+         encodings = self.dropout(self.layer_norm(encodings))
+         logits = self.mlp(encodings).squeeze(-1)
+         return logits
+
+
+     # -------# Training / Evaluation #-------- #
+     def training_step(self, batch, batch_idx):
+         train_loss, _ = self.compute_loss(batch)
+         self.log(name="train/loss", value=train_loss.item(), on_step=True, on_epoch=False, logger=True, sync_dist=True)
+         self.save_ckpt()
+         return train_loss
+
+     def validation_step(self, batch, batch_idx):
+         val_loss, _ = self.compute_loss(batch)
+         self.log(name="val/loss", value=val_loss.item(), on_step=False, on_epoch=True, logger=True, sync_dist=True)
+         return val_loss
+
+     def test_step(self, batch):
+         test_loss, preds = self.compute_loss(batch)
+         auroc, accuracy = self.get_metrics(batch, preds)
+         self.log(name="test/loss", value=test_loss.item(), on_step=False, on_epoch=True, logger=True, sync_dist=True)
+         self.log(name="test/AUROC", value=auroc.item(), on_step=False, on_epoch=True, logger=True, sync_dist=True)
+         self.log(name="test/accuracy", value=accuracy.item(), on_step=False, on_epoch=True, logger=True, sync_dist=True)
+         return test_loss
+
+     def on_test_epoch_end(self):
+         self.auroc.reset()
+         self.accuracy.reset()
+
+     def optimizer_step(self, *args, **kwargs):
+         super().optimizer_step(*args, **kwargs)
+         gc.collect()
+         torch.cuda.empty_cache()
+
+     def configure_optimizers(self):
+         path = self.config.training
+         optimizer = torch.optim.AdamW(self.parameters(), lr=self.config.optim.lr)
+         lr_scheduler = CosineWarmup(
+             optimizer,
+             warmup_steps=path.warmup_steps,
+             total_steps=path.max_steps,
+         )
+         scheduler_dict = {
+             "scheduler": lr_scheduler,
+             "interval": 'step',
+             'frequency': 1,
+             'monitor': 'val/loss',
+             'name': 'learning_rate'
+         }
+         return [optimizer], [scheduler_dict]
+
+     def save_ckpt(self):
+         curr_step = self.global_step
+         save_every = self.config.training.val_check_interval
+         if curr_step % save_every == 0 and curr_step > 0:  # Save every 250 steps
+             ckpt_path = f"{self.config.checkpointing.save_dir}/step={curr_step}.ckpt"
+             self.trainer.save_checkpoint(ckpt_path)
+
+     # -------# Loss and Test Set Metrics #-------- #
+     @torch.no_grad()
+     def get_esm_embeddings(self, input_ids, attention_mask):
+         outputs = self.esm_model(input_ids=input_ids, attention_mask=attention_mask)
+         embeddings = outputs.last_hidden_state
+         return embeddings
+
+     def compute_loss(self, batch):
+         """Helper method to handle loss calculation"""
+         labels = batch['labels']
+         preds = self.forward(batch)
+         loss = self.loss_fn(preds, labels)
+         loss_mask = (labels != self.config.model.label_pad_value)  # only calculate loss over non-pad tokens
+         loss = (loss * loss_mask).sum() / loss_mask.sum()
+         return loss, preds
+
+     def get_metrics(self, batch, preds):
+         """Helper method to compute metrics"""
+         labels = batch['labels']
+
+         valid_mask = (labels != self.config.model.label_pad_value)
+         labels = labels[valid_mask]
+         preds = preds[valid_mask]
+
+         _print(f"labels {labels.shape}")
+         _print(f"preds {preds.shape}")
+
+         auroc = self.auroc.forward(preds, labels)
+         accuracy = self.accuracy.forward(preds, labels)
+         return auroc, accuracy
+
+     # -------# Helper Functions #-------- #
+     def get_state_dict(self, ckpt_path):
+         """Helper method to load and process a trained model's state dict from saved checkpoint"""
+         def remove_model_prefix(state_dict):
+             # str.replace returns a new string, so build a new dict with the stripped prefix
+             return {
+                 (k[len("model."):] if k.startswith("model.") else k): v
+                 for k, v in state_dict.items()
+             }
+
+         checkpoint = torch.load(ckpt_path, map_location='cuda' if torch.cuda.is_available() else 'cpu')
+         state_dict = checkpoint.get("state_dict", checkpoint)
+
+         if any(k.startswith("model.") for k in state_dict.keys()):
+             state_dict = remove_model_prefix(state_dict)
+
+         return state_dict
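Editor's note: compute_loss above is a per-token BCEWithLogitsLoss (reduction='none') averaged only over positions whose label is not label_pad_value. A toy illustration of that masked reduction (values invented):

    import torch
    import torch.nn as nn

    loss_fn = nn.BCEWithLogitsLoss(reduction='none')
    logits = torch.tensor([[2.0, -1.0, 0.5, 0.0]])
    labels = torch.tensor([[1.0, 0.0, -100.0, -100.0]])      # -100 marks padded positions

    per_token = loss_fn(logits, labels)                      # loss at every position, pads included
    mask = (labels != -100)
    loss = (per_token * mask).sum() / mask.sum()             # averaged over the 2 valid positions only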
src/guidance/utils.py ADDED
@@ -0,0 +1,21 @@
+
+
+ import numpy as np
+ from torch.optim.lr_scheduler import _LRScheduler
+
+ class CosineWarmup(_LRScheduler):
+     def __init__(self, optimizer, warmup_steps, total_steps, eta_ratio=0.1, last_epoch=-1):
+         self.warmup_steps = warmup_steps
+         self.total_steps = total_steps
+         self.eta_ratio = eta_ratio  # The ratio of minimum to maximum learning rate
+         super(CosineWarmup, self).__init__(optimizer, last_epoch)
+
+     def get_lr(self):
+         if self.last_epoch < self.warmup_steps:
+             return [base_lr * self.last_epoch / self.warmup_steps for base_lr in self.base_lrs]
+
+         progress = (self.last_epoch - self.warmup_steps) / (self.total_steps - self.warmup_steps)
+         cosine_decay = 0.5 * (1 + np.cos(np.pi * progress))
+         decayed_lr = (1 - self.eta_ratio) * cosine_decay + self.eta_ratio
+
+         return [decayed_lr * base_lr for base_lr in self.base_lrs]
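Editor's note: a short usage sketch for this scheduler (the optimizer, module, and step counts below are placeholders). The learning rate ramps linearly over warmup_steps, then follows a cosine decay toward eta_ratio of the base rate.

    import torch
    from src.guidance.utils import CosineWarmup

    model = torch.nn.Linear(8, 1)                            # stand-in module
    optimizer = torch.optim.AdamW(model.parameters(), lr=3e-5)
    scheduler = CosineWarmup(optimizer, warmup_steps=150, total_steps=3000)

    for step in range(3000):
        optimizer.step()                                     # backward pass omitted in this sketch
        scheduler.step()                                     # warmup, then cosine decay to eta_ratio * lr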
src/lm/memdlm/dataloader.py ADDED
@@ -0,0 +1,88 @@
+ import torch
+ import pandas as pd
+ import lightning.pytorch as pl
+
+ from transformers import AutoModel, AutoTokenizer
+ from torch.utils.data import Dataset, DataLoader
+
+
+ class MembraneDataset(Dataset):
+     def __init__(self, config, data_path):
+         self.config = config
+         self.data = pd.read_csv(data_path)
+         self.tokenizer = AutoTokenizer.from_pretrained(config.lm.pretrained_evoflow)
+
+     def __len__(self):
+         return len(self.data)
+
+     def __getitem__(self, idx):
+         sequence = self.data.iloc[idx]["Sequence"]
+
+         tokens = self.tokenizer(
+             sequence.upper(),
+             return_tensors='pt',
+             padding='max_length',
+             truncation=True,
+             max_length=self.config.data.max_seq_len
+         )
+
+         #return {"input_ids": tokens['input_ids'], "attention_mask": tokens['attention_mask']}
+
+         return {
+             "input_ids": tokens['input_ids'].squeeze(0),
+             "attention_mask": tokens['attention_mask'].squeeze(0)
+         }
+
+
+ def collate_fn(batch):
+     input_ids = torch.stack([item['input_ids'] for item in batch])  #.squeeze()
+     masks = torch.stack([item['attention_mask'] for item in batch])  #.squeeze()
+
+     return {'input_ids': input_ids, 'attention_mask': masks}
+
+
+ class MembraneDataModule(pl.LightningDataModule):
+     def __init__(self, config, train_dataset, val_dataset, test_dataset, collate_fn=collate_fn):
+         super().__init__()
+         self.train_dataset = train_dataset
+         self.val_dataset = val_dataset
+         self.test_dataset = test_dataset
+         self.collate_fn = collate_fn
+         self.batch_size = config.data.batch_size
+         self.tokenizer = AutoTokenizer.from_pretrained(config.lm.pretrained_evoflow)
+
+     def train_dataloader(self):
+         return DataLoader(self.train_dataset,
+                           batch_size=self.batch_size,
+                           collate_fn=self.collate_fn,
+                           num_workers=8,
+                           pin_memory=True)
+
+     def val_dataloader(self):
+         return DataLoader(self.val_dataset,
+                           batch_size=self.batch_size,
+                           collate_fn=self.collate_fn,
+                           num_workers=8,
+                           shuffle=False,
+                           pin_memory=True)
+
+     def test_dataloader(self):
+         return DataLoader(self.test_dataset,
+                           batch_size=self.batch_size,
+                           collate_fn=self.collate_fn,
+                           num_workers=8,
+                           shuffle=False,
+                           pin_memory=True)
+
+
+ def get_datasets(config):
+     """Helper method to grab datasets to quickly init data module in main.py"""
+     train_dataset = MembraneDataset(config, config.data.train)
+     test_dataset = MembraneDataset(config, config.data.test)
+     val_dataset = MembraneDataset(config, config.data.val)
+
+     return {
+         "train": train_dataset,
+         "val": val_dataset,
+         "test": test_dataset
+     }
src/lm/memdlm/diffusion_module.py ADDED
@@ -0,0 +1,189 @@
+ import os
+ import gc
+ import torch
+
+ import torch.nn.functional as F
+ import lightning as pl
+
+ from typing import Optional
+ from transformers import AutoModelForMaskedLM, AutoTokenizer
+
+ from src.utils.model_utils import _print
+ from src.utils.optimizer_utils import get_optimizer, get_scheduler
+
+
+ class MembraneDiffusion(pl.LightningModule):
+     def __init__(self, config):
+         """
+         Args:
+             config (OmegaConf): config.yaml file with all training parameters
+         """
+         super().__init__()
+         self.config = config
+         self.save_hyperparameters(logger=True)
+
+         self.model = AutoModelForMaskedLM.from_pretrained(config.lm.pretrained_evoflow, trust_remote_code=True)
+         self.tokenizer = AutoTokenizer.from_pretrained(config.lm.pretrained_evoflow)
+
+         self.mask_id = self.tokenizer.mask_token_id
+         self.pad_id = self.tokenizer.pad_token_id
+
+     def forward(self, input_ids, attention_mask, guidance: Optional[bool] = False):
+         """
+         Forward pass through language model.
+
+         Args:
+             - input_ids (torch.Tensor): [B, L], token ids
+             - attention_mask (torch.Tensor): [B, L], pad/non-pad binary mask
+         Returns:
+             - logits (torch.Tensor): [B, L, V], unnormalized model outputs
+         """
+         return self.model(input_ids=input_ids, attention_mask=attention_mask).logits
+
+     # -------# Diffusion #-------- #
+     def step(self, batch):
+         labels = batch['input_ids']
+
+         # Forward diffusion
+         t1 = self.sample_t(labels)  # Sample timestep
+         xt, _ = self.noise_x0(labels, t1, maskable_mask=self.is_maskable(labels))  # Noise sequence
+         logits = self.forward(input_ids=xt, attention_mask=batch['attention_mask'])  # Model logits
+
+         # Loss computation
+         weight = self.get_weight(t1, weight_type=self.config.lm.weight_type)  # RDM uses a weighted cross entropy loss
+         loss_out = self.compute_loss(logits, labels, weight)  # Compute loss and ppl
+
+         self.cleanup()
+         return loss_out['loss'], loss_out['ppl']
+
+     def sample_t(self, labels, rdm_coupling=False):
+         """
+         Sample diffusion timesteps. Non-coupling RDM only uses one timestep (t1).
+         """
+         timesteps = torch.randint(
+             1,
+             self.config.lm.num_diffusion_timesteps + 1,
+             (2 if rdm_coupling else 1) * (labels.size(0),),
+             device=labels.device
+         )
+
+         if rdm_coupling:
+             return timesteps.chunk(2)
+         return timesteps
+
+     def noise_x0(self, x0, t1, maskable_mask):
+         """
+         Apply noise to the initial sequence x0.
+         """
+         u = torch.rand_like(x0, dtype=torch.float)
+         t1_mask = (u < (t1 / self.config.lm.num_diffusion_timesteps)[:, None]) & maskable_mask
+         x_t1 = x0.masked_fill(t1_mask, self.mask_id)
+         return x_t1, t1_mask
+
+     def get_weight(self, t, weight_type):
+         """
+         Compute the weighting factor for the RDM-derived loss (weighted cross-entropy).
+         """
+         num_timesteps = self.config.lm.num_diffusion_timesteps
+         weight = {
+             "linear": (num_timesteps - (t - 1)),  # num_timesteps * (1 - (t-1)/num_timesteps)
+             "constant": num_timesteps * torch.ones_like(t),
+         }[weight_type][:, None].float() / num_timesteps
+         return weight.squeeze()
+
+     def compute_loss(self, logits, labels, weight):
+         """
+         Compute the cross entropy loss per sample.
+         First, compute the per-token loss (with no reduction), then reduce over the sequence length for each sample.
+         Finally, average over the batch.
+
+         Args:
+             logits (torch.Tensor): [B, L, vocab_size], unnormalized model outputs
+             labels (torch.Tensor): [B, L], target labels (padding positions carry the tokenizer pad id and are ignored)
+             weight (torch.Tensor): [B, 1], per-sample weight for loss calculation
+         Returns:
+             dict with 'loss' (batch-averaged weighted loss) and 'ppl' (batch-averaged perplexity)
+         """
+
+         loss_token = F.cross_entropy(
+             logits.view(-1, logits.size(-1)),
+             labels.view(-1),
+             reduction='none',
+             ignore_index=self.pad_id,
+         )
+
+         loss_token = loss_token.view(labels.size(0), labels.size(1))  # Reshape to [B, L]
+         valid_mask = (labels != self.pad_id)
+
+         sample_loss = (loss_token * valid_mask.float()).sum(dim=1) / valid_mask.float().sum(dim=1).clamp(min=1)
+         sample_loss *= weight  # RDM weighting
+         ppl = torch.exp(sample_loss)
+
+         return {'ppl': ppl.mean(), 'loss': sample_loss.mean()}
+
+
+     # -------# Training / Evaluation #-------- #
+     def training_step(self, batch):
+         loss, ppl = self.step(batch)
+         self.log("train/loss", loss.item(), on_step=True, on_epoch=False, prog_bar=True)
+         self.log("train/ppl", ppl.item(), on_step=True, on_epoch=False, prog_bar=False)
+         return loss
+
+     def validation_step(self, batch):
+         loss, ppl = self.step(batch)
+         self.cleanup()
+         self.log("val/loss", loss.item(), on_step=False, on_epoch=True, prog_bar=True, sync_dist=True)
+         self.log("val/ppl", ppl.item(), on_step=False, on_epoch=True, prog_bar=False, sync_dist=True)
+         return loss
+
+     def test_step(self, batch):
+         loss, ppl = self.step(batch)
+         self.cleanup()
+         self.log('test/loss', loss.item(), on_step=False, on_epoch=True, prog_bar=True, sync_dist=True)
+         self.log("test/ppl", ppl.item(), on_step=False, on_epoch=True, prog_bar=False, sync_dist=True)
+         return loss
+
+
+     # -------# Helper methods #-------- #
+     def is_maskable(self, input_ids: torch.Tensor):
+         return (
+             (input_ids != self.tokenizer.pad_token_id)
+             & (input_ids != self.tokenizer.cls_token_id)
+             & (input_ids != self.tokenizer.eos_token_id)
+         )
+
+     def configure_optimizers(self):
+         """
+         Choosing which optimizer and lr scheduler to use.
+         """
+         optimizer = get_optimizer(self.config, self.model.parameters())
+         lr_scheduler, extra_kwargs = get_scheduler(self.config, optimizer)  # Polynomial scheduler
+         return {
+             "optimizer": optimizer,
+             "lr_scheduler": {"scheduler": lr_scheduler, **extra_kwargs},
+         }
+
+     def validate_config(self):
+         assert os.path.isdir(self.config.checkpointing.save_dir), "invalid checkpointing path"
+         assert self.config.training.mode in ["train", "test", "resume_from_checkpoint"], "invalid mode"
+
+     def get_state_dict(self, ckpt_path):
+         def remove_model_prefix(state_dict):
+             # str.replace returns a new string, so build a new dict with the stripped prefix
+             return {
+                 (k[len("model."):] if k.startswith("model.") else k): v
+                 for k, v in state_dict.items()
+             }
+
+         checkpoint = torch.load(ckpt_path, map_location='cuda' if torch.cuda.is_available() else 'cpu')
+         state_dict = checkpoint.get("state_dict", checkpoint)
+
+         if any(k.startswith("model.") for k in state_dict.keys()):
+             state_dict = remove_model_prefix(state_dict)
+
+         return state_dict
+
+     def cleanup(self):
+         torch.cuda.empty_cache()
+         gc.collect()
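Editor's note: as a quick sanity check on the schedule used by noise_x0 and get_weight above, a timestep t drawn from [1, T] masks each maskable token with probability t/T, and the "linear" loss weight is (T - (t - 1)) / T, so heavily-masked samples are down-weighted. A small numeric sketch with the config default T = 500:

    T = 500                                                  # config.lm.num_diffusion_timesteps
    for t in (1, 125, 250, 500):
        mask_prob = t / T
        linear_weight = (T - (t - 1)) / T
        print(f"t={t:3d}  mask_prob={mask_prob:.3f}  weight={linear_weight:.3f}")
    # t=  1  mask_prob=0.002  weight=1.000
    # t=125  mask_prob=0.250  weight=0.752
    # t=250  mask_prob=0.500  weight=0.502
    # t=500  mask_prob=1.000  weight=0.002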
src/lm/memdlm/loss.py ADDED
@@ -0,0 +1,42 @@
+ import torch
+ import torch.nn as nn
+ import torch.nn.functional as F
+
+ # Ignore file
+
+ class RDMCrossEntropyLoss(nn.CrossEntropyLoss):
+     def __init__(self, ignore_index):
+         super().__init__(ignore_index=ignore_index)  # initialize the parent module before setting attributes
+         self.ignore_index = ignore_index
+
+     def forward(self,
+                 scores: torch.Tensor,
+                 target: torch.Tensor,
+                 label_mask,
+                 weights,
+                 ) -> torch.Tensor:
+         """
+         Computes the RDM-derived loss (weighted cross-entropy).
+         """
+
+         sample_size = target.ne(self.ignore_index).float().sum()
+
+         lprobs = F.log_softmax(scores, dim=-1)
+
+         loss = lprobs * weights
+         fullseq_loss = loss.sum() / sample_size
+
+         # use coord masked loss for model training,
+         # ignoring those positions with missing coords (as nan)
+         label_mask = label_mask.float()
+         sample_size = label_mask.sum()  # sample size should be set to valid coordinates
+         loss = (loss * label_mask).sum() / sample_size
+
+         ppl = torch.exp(loss)
+
+         logging_output = {
+             'ppl': ppl.data,
+             'fullseq_loss': fullseq_loss.data,
+             'weight_diff_loss': loss.data
+         }
+
+         return logging_output
src/lm/memdlm/main.py ADDED
@@ -0,0 +1,88 @@
+ import os
+ import gc
+ import sys
+ import torch
+ import wandb
+ import torch.nn as nn
+ import lightning.pytorch as pl
+
+ from omegaconf import OmegaConf
+ from lightning.pytorch.strategies import DDPStrategy
+ from lightning.pytorch.loggers import WandbLogger
+ from lightning.pytorch.callbacks import ModelCheckpoint, LearningRateMonitor
+
+ from src.lm.memdlm.diffusion_module import MembraneDiffusion
+ from src.lm.memdlm.dataloader import MembraneDataModule, get_datasets
+ from src.utils.model_utils import apply_rdm_freezing
+
+ wandb.login(key='2b76a2fa2c1cdfddc5f443602c17b011fefb0a8f')
+
+
+ # Load yaml config
+ config = OmegaConf.load("/scratch/pranamlab/sgoel/MeMDLM_v2/src/configs/lm.yaml")
+
+ # Get datasets
+ datasets = get_datasets(config)
+ data_module = MembraneDataModule(
+     config=config,
+     train_dataset=datasets['train'],
+     val_dataset=datasets['val'],
+     test_dataset=datasets['test'],
+ )
+
+ # Initialize WandB for logging
+ wandb.init(project=config.wandb.project, name=config.wandb.name)
+ wandb_logger = WandbLogger(**config.wandb)
+
+ # PL checkpoints
+ lr_monitor = LearningRateMonitor(logging_interval="step")
+ checkpoint_callback = ModelCheckpoint(
+     monitor="val/loss",
+     save_top_k=1,
+     mode="min",
+     dirpath=config.checkpointing.save_dir,
+     filename="best_model",
+     every_n_train_steps=config.checkpointing.save_every_n_steps
+ )
+
+ # PL trainer
+ trainer = pl.Trainer(
+     max_steps=config.training.max_steps,
+     max_epochs=None,  # Ensure training is based on num steps
+     accelerator="cuda" if torch.cuda.is_available() else "cpu",
+     devices=config.training.devices if config.training.mode == 'train' else [0],
+     strategy=DDPStrategy(find_unused_parameters=True),
+     callbacks=[checkpoint_callback, lr_monitor],
+     logger=wandb_logger,
+     log_every_n_steps=config.training.log_every_n_steps
+ )
+
+
+ # Folder to save checkpoints
+ ckpt_dir = config.checkpointing.save_dir
+ os.makedirs(ckpt_dir, exist_ok=True)
+
+ # PL Model for training
+ diffusion = MembraneDiffusion(config)
+ diffusion.validate_config()
+
+ # Start/resume training or evaluate the model
+ model_type = "evoflow"
+ if config.training.mode == "train":
+     apply_rdm_freezing(diffusion.model, config.training.n_layers, model_type)
+     trainer.fit(diffusion, datamodule=data_module)
+
+ elif config.training.mode == "test":
+     state_dict = diffusion.get_state_dict(config.checkpointing.best_ckpt_path)
+     diffusion.load_state_dict(state_dict)
+     trainer.test(diffusion, datamodule=data_module, ckpt_path=config.checkpointing.best_ckpt_path)
+
+ elif config.training.mode == "resume_from_checkpoint":
+     resume_ckpt = config.checkpointing.resume_ckpt_path  # the resume path is defined under `checkpointing` in lm.yaml
+     state_dict = diffusion.get_state_dict(resume_ckpt)
+     diffusion.load_state_dict(state_dict)
+     apply_rdm_freezing(diffusion.model, config.training.n_layers, model_type)
+     trainer.fit(diffusion, datamodule=data_module, ckpt_path=resume_ckpt)
+
+ wandb.finish()
src/sampling/guided_generator.py ADDED
@@ -0,0 +1,90 @@
+ #!/usr/bin/env python3
+
+ import sys
+ import os
+ import torch
+ import pandas as pd
+ from tqdm import tqdm
+ from datetime import datetime
+ from omegaconf import OmegaConf
+ from transformers import AutoTokenizer, AutoModelForMaskedLM
+
+ from src.lm.memdlm.diffusion_module import MembraneFlow
+ from src.utils.model_utils import _print
+ from src.sampling.guided_sampler import GuidedSampler
+ from src.utils.generate_utils import (
+     mask_for_scaffold,
+     calc_blosum_score,
+     calc_ppl
+ )
+
+ config = OmegaConf.load("/home/a03-sgoel/MeMDLM_v2/src/configs/guidance.yaml")
+
+ os.chdir(f'/home/a03-sgoel/MeMDLM_v2/results/infilling/guided/{config.lm.ft_evoflow}/test_set/')
+ todays_date = datetime.today().strftime('%Y-%m-%d')
+ csv_save_path = f'./{todays_date}_boltzmann-soft_new_clf_data_cleaned/'
+ os.makedirs(csv_save_path, exist_ok=True)
+
+
+ def main():
+     device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
+
+     tokenizer = AutoTokenizer.from_pretrained(config.lm.pretrained_esm)
+     esm_model = AutoModelForMaskedLM.from_pretrained(config.lm.pretrained_esm).eval().to(device)
+
+     diffusion = MembraneFlow(config).to(device)
+     state_dict = diffusion.get_state_dict(f"/home/a03-sgoel/MeMDLM_v2/checkpoints/{config.lm.ft_evoflow}/best_model.ckpt")
+     diffusion.load_state_dict(state_dict)
+     diffusion.eval().to(device)
+
+     sampler = GuidedSampler(config, esm_model, tokenizer, diffusion, device)
+
+     df = pd.read_csv('/home/a03-sgoel/MeMDLM_v2/data/classifier/test.csv')
+     sequences = df['Sequence'].tolist()
+
+     gen_seqs, ppls, blosums = [], [], []
+
+
+     for seq in tqdm(sequences, desc='Infilling Sequences'):
+         masked_seq = mask_for_scaffold(seq, generate_type='uppercase', mask_token='<mask>')
+         tokens = tokenizer(masked_seq, return_tensors='pt')
+         input_ids, attn_masks = tokens['input_ids'].to(device), tokens['attention_mask'].to(device)
+
+         soluble_idxs = [i for i in range(len(seq)) if seq[i].isupper()]
+         infilled_tokens = sampler.optimize_sequence(
+             input_ids=input_ids,
+             attn_masks=attn_masks,
+             soluble_indices=soluble_idxs,
+         )
+         infilled_seq = tokenizer.decode(infilled_tokens).replace(" ", "")[5:-5]
+
+         bl = calc_blosum_score(seq.upper(), infilled_seq, soluble_idxs)
+         try:
+             ppl = calc_ppl(esm_model, tokenizer, infilled_seq, [i for i in range(len(seq))], model_type='esm')
+         except Exception:
+             ppl = float('inf')
+
+         gen_seqs.append(infilled_seq)
+         ppls.append(ppl)
+         blosums.append(bl)
+
+         _print(seq)
+         _print(infilled_seq)
+         _print(ppl)
+         _print(bl)
+         _print('\n')
+
+
+     df['MeMDLM Sequence'] = gen_seqs
+     df['MeMDLM PPL'] = ppls
+     df['MeMDLM BLOSUM'] = blosums
+
+     _print(df)
+     df.to_csv(f'./{csv_save_path}/t=0.7_new-data-cleaned_infilled_seqs.csv', index=False)
+
+
+
+ if __name__ == "__main__":
+     main()
src/sampling/guided_sampler.py ADDED
@@ -0,0 +1,291 @@
+ import os
+ import math
+ import torch
+ import torch.nn.functional as F
+
+ from src.utils.model_utils import _print
+ from src.guidance.solubility_module import SolubilityClassifier
+ from src.sampling.unconditional_sampler import UnconditionalSampler
+
+
+ class GuidedSampler:
+     def __init__(self, config, esm_model, tokenizer, diffusion, device):
+         self.config = config
+         self.device = device
+
+         self.esm = esm_model
+         self.memdlm = diffusion
+         self.tokenizer = tokenizer
+         self.uncond_generator = UnconditionalSampler(self.tokenizer, self.memdlm)
+
+         ckpt_path = os.path.join(f"/home/a03-sgoel/MeMDLM_v2/checkpoints/{config.wandb.name}/best_model.ckpt")
+         self.classifier_model = SolubilityClassifier(config)
+         state_dict = self.classifier_model.get_state_dict(ckpt_path)
+         self.classifier_model.load_state_dict(state_dict)
+         self.classifier_model.eval().to(self.device)
+
+         self.top_p = self.config.guidance.top_p
+         self.alpha = self.config.guidance.alpha
+         self.gamma = self.config.guidance.gamma
+         self.saliency_eps = self.config.guidance.saliency_eps
+         self.saliency_t = self.config.guidance.saliency_t
+         self.sampling_t = self.config.guidance.sampling_t
+         self.boltzmann_t = self.config.guidance.boltzmann_t
+
+
+     def embed_sequence(self, input_ids, attention_masks):
+         with torch.no_grad():
+             outs = self.esm(
+                 input_ids=input_ids,
+                 attention_mask=attention_masks,
+                 output_hidden_states=True,
+                 output_attentions=True
+             )
+         embeds = outs.hidden_states[-1]
+         attn_matrix = outs.attentions
+         return embeds, attn_matrix
+
+
+     def sample_from_categorical(self, logits, temperature, noise_scale=1.0):
+         gumbel_noise = -torch.log(-torch.log(torch.rand_like(logits) + 1e-8) + 1e-8)
+         logits = (logits / temperature) + (noise_scale * gumbel_noise)
+
+         log_probs = F.log_softmax(logits, dim=-1)
+         _, tokens = log_probs.max(dim=-1)
+
+         return tokens, log_probs
+
+
+     def denoise_sequence(self, input_ids, attn_masks):
+         """
+         Compute the current and prior sequences' log prob distribution.
+         """
+         has_masks = (input_ids == self.tokenizer.mask_token_id).any()
+
+         # Denoise the sequence if needed
+         if has_masks:
+             xt_prior, logits_prior = self.uncond_generator.sample_unconditional(
+                 xt=input_ids,
+                 num_steps=self.config.guidance.n_steps,
+                 tau=self.sampling_t,
+                 return_logits=True
+             )
+         else:
+             xt_prior = input_ids
+             logits_prior = self.memdlm(input_ids=input_ids, attention_mask=attn_masks)
+
+         # Take the final sampling step
+         _, logits = self.uncond_generator.sample_unconditional(
+             xt=xt_prior,
+             num_steps=1,  # Only need 1 sampling step
+             tau=self.sampling_t,
+             return_logits=True
+         )
+
+         # Get final sequence log probs (always needed)
+         x0, logp_lm = self.sample_from_categorical(logits, temperature=self.sampling_t)
+
+         return x0.squeeze(), logp_lm.squeeze(), logits_prior
+
+
+     def get_prior(self, logits_prior, solubility_logits):
+         if self.config.guidance.prior == "boltzmann":
+             hydrophilic = ["D","E","K","R","N","Q","H","S","T","Y"]
+             hydrophobic = ["L","I","V","F","W","M","A","C","G","P"]
+             amino_acids = hydrophilic + hydrophobic
+
+             tokens = list(self.tokenizer.get_vocab().keys())
+             other = [tok for tok in tokens if tok not in amino_acids]
+
+             hydrophilic_idxs = [self.tokenizer.convert_tokens_to_ids(aa) for aa in hydrophilic]
+             hydrophobic_idxs = [self.tokenizer.convert_tokens_to_ids(aa) for aa in hydrophobic]
+             other_idxs = [self.tokenizer.convert_tokens_to_ids(tok) for tok in other]
+
+             bias = torch.zeros(len(tokens), device=self.device)
+             bias[hydrophilic_idxs] = 1.0
+             bias[hydrophobic_idxs] = -1.0
+             bias[other_idxs] = 0.0
+
+             sol_scores = torch.sigmoid(solubility_logits)
+             token_bias = sol_scores.unsqueeze(-1) * bias
+
+             lm_probs = F.softmax(logits_prior / self.sampling_t, dim=-1)
+             boltz_weight = torch.exp(token_bias / self.boltzmann_t)
+
+             p_prior = lm_probs * boltz_weight
+             p_prior = p_prior / p_prior.sum(dim=-1, keepdim=True)
+             logp_prior = torch.log(p_prior)
+
+         elif self.config.guidance.prior == "lm_probs":
+             _, logp_prior = self.sample_from_categorical(logits_prior, temperature=self.sampling_t)
+
+         return logp_prior.squeeze()
+
+
+     def compute_saliency_map(self, embeds, solubility_logits):
+         """
+         Compute a saliency map as in LaMBO-2 (https://arxiv.org/abs/2305.20009) Eq. 5
+         """
+         # Gradient tracking is already enabled for the embeddings
+         solubility_logits.sum().backward(retain_graph=True)  # Clf gradients wrt hidden states
+         grads = embeds.grad.abs().sum(dim=-1)  # Aggregate across hidden dim. Abs value for magnitude only.
+         saliency = grads.pow(1.0 / self.saliency_t).clamp(min=self.saliency_eps).to(self.device)
+         saliency = (saliency - saliency.min()) / (saliency.max() - saliency.min() + 1e-6)
+         return saliency.squeeze()
+
+
+     def determine_edit_positions(self, saliency_map, soluble_indices, solubility_logits):
+         """
+         Fix the insoluble residues and additional TM residues to
+         maintain membrane-like protein structure.
+         """
+         seq_len = saliency_map.shape[0]
+
+         # Initialize a mask to store the editable token positions
+         edit_mask = torch.ones(seq_len, dtype=torch.bool, device=self.device)
+
+         # Check for any provided soluble residues, otherwise use classifier preds
+         if soluble_indices is not None and len(soluble_indices) > 0:
+             edit_mask[soluble_indices] = False
+         else:
+             solubility_preds = F.sigmoid(solubility_logits)
+             edit_mask[solubility_preds > 0.5] = False
+
+         # Find additional TM residues
+         num_conserved = max(1, int(0.1 * edit_mask.sum()))
+         _, topk_idxs = torch.topk(saliency_map, num_conserved)
+         edit_mask[topk_idxs] = False
+
+         edit_idxs = edit_mask.nonzero(as_tuple=True)[0]
+         return edit_idxs
+
+
+     def create_neighborhood(self, edit_pos, attn_matrix, top_p):
+         """
+         Select a dynamic "neighborhood" of tokens for the edit position via top-p sampling.
+         Attention scores find relevant tokens, avoiding blind updates of the individual token.
+         """
+         # Get the attention scores for the current edit position
+         row = attn_matrix[edit_pos].clone().squeeze()
+         row = row.index_fill(
+             dim=0,
+             index=torch.tensor([0, edit_pos, row.size(0)-1], device=row.device),
+             value=float('-inf')
+         )
+
+         # Top-p (nucleus) sampling of tokens via normed attention scores
+         temp = 1.0 / math.log(row.size(0))  # scale temp with seq len to balance
+         attn_probs = F.softmax(row / temp, dim=0)
+         sorted_probs, sorted_idxs = torch.sort(attn_probs, descending=True)
+         cum_probs = sorted_probs.cumsum(dim=0)
+         cutoff = (cum_probs <= top_p).nonzero(as_tuple=True)[0]
+
+         # Ensure neighborhoods will always have 1 token
+         final_idx = cutoff[-1].item() + 1 if cutoff.numel() > 0 else 1
+         neighborhood = sorted_idxs[:final_idx]
+         return neighborhood
+
+
+     def compute_saliency_weight(self, edit_pos, attn_mat, saliency_map, neighborhood):
+         """
+         Blend the saliency of the neighborhood's tokens and the token at the edit position.
+         """
+         neighborhood_attns = attn_mat[edit_pos, neighborhood]
+         neighborhood_attns /= neighborhood_attns.sum()
+
+         neighborhood_saliencies = saliency_map[neighborhood]
+
+         neighborhood_weight = torch.sum(neighborhood_attns * neighborhood_saliencies)
+         ctxt_aware_saliency = saliency_map[edit_pos] + (self.gamma * neighborhood_weight)
+
+         return ctxt_aware_saliency
+
+
+     def compute_guidance_dist(self, logp_lm, logp_prior, saliency_weight):
+         """
+         Define a guidance distribution between a prior and the current LM probs.
+         Compute the log probs of the "new" (optimized) token.
+         """
+         w = torch.sigmoid(saliency_weight * self.alpha)  # Between [0, 1] to ensure valid probs
+         p_lm = torch.exp(logp_lm)
+         p_prior = torch.exp(logp_prior)
+         mixed_probs = (1 - w) * p_lm + w * p_prior
+         guidance_dist = torch.log(mixed_probs + 1e-12)
+         return guidance_dist
+
+
+     def check_scaffold(self, seq1, seq2, idxs):
+         changed = (seq1[idxs] != seq2[idxs])
+         if changed.any():
+             _print('soluble residues changed')
+         else:
+             _print('no soluble residue changes')
+
+
+     def optimize_sequence(self, input_ids, attn_masks, soluble_indices):
+         _print(f'soluble idx: {soluble_indices}')
+
+         # Initialize token ids, logits, and log probs of sequence
+         x0, logp_lm, logits_prior = self.denoise_sequence(input_ids, attn_masks)
+         _print(f'og tokens: {x0}')
+         _print(f'og tokens: {x0.shape}')
+         _print(f'og log probs: {logp_lm.shape}')
+
+         # Embeddings and attention matrix of current sequence
+         embeds, attn_mats = self.embed_sequence(x0.unsqueeze(0), attn_masks)
+         embeds = embeds.detach().clone().requires_grad_(True)  # enable grad tracking for saliency map
+         attn_matrix = attn_mats[-1].mean(dim=1)[0].squeeze(0)
+
+         # Precompute logits of the classifier to avoid repeated calls
+         batch = {"embeds": embeds, "attention_mask": attn_masks}
+         solubility_logits = self.classifier_model(batch)
+
+         # Create a saliency map to determine optimal edit positions
+         saliency_map = self.compute_saliency_map(embeds, solubility_logits)
+         _print(f'saliency map: {saliency_map}')
+         edit_positions = self.determine_edit_positions(saliency_map, soluble_indices, solubility_logits)
+         _print(f'edit positions: {edit_positions}')
+
+         # Compute the log probs of the prior dist
+         logp_prior = self.get_prior(logits_prior, solubility_logits)
+         _print(f'prior log probs: {logp_prior.shape}')
+
+         # Optimize the insoluble residues
+         for edit_pos in edit_positions.tolist():
+             neighborhood = self.create_neighborhood(
+                 edit_pos,
+                 attn_matrix,
+                 self.top_p
+             )
+             _print(f'neighborhood: {neighborhood}')
+
+             ctxt_aware_saliency = self.compute_saliency_weight(
+                 edit_pos,
+                 attn_matrix,
+                 saliency_map,
+                 neighborhood
+             )
+             _print(f'ctx aware saliency: {ctxt_aware_saliency}')
+
+             logp_lm_prime = self.compute_guidance_dist(
+                 logp_lm[edit_pos],
+                 logp_prior[edit_pos],
+                 ctxt_aware_saliency
+             )
+             logp_lm[edit_pos] = logp_lm_prime
+
+             tot = torch.exp(logp_lm_prime).sum()
+             one = torch.tensor(1.0, dtype=tot.dtype, device=tot.device)
+             assert torch.isclose(tot, one, atol=1e-4), f"Invalid prob distribution. Sum = {tot:5f}"
+
+         # Sample new tokens
+         x0_prime = torch.distributions.Categorical(logits=logp_lm).sample()
+
+         # Check if any soluble residues have been changed
+         self.check_scaffold(x0, x0_prime, soluble_indices)
+
+         # Preserve the initial sequence scaffold by copying over the soluble tokens
+         x0_prime[soluble_indices] = x0[soluble_indices]
+         self.check_scaffold(x0, x0_prime, soluble_indices)
+
+         return x0_prime
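Editor's note: the guidance step in compute_guidance_dist blends the language-model distribution with the prior through a gate w = sigmoid(alpha * saliency); positions with high context-aware saliency lean on the prior, while low-saliency positions keep the LM's prediction. A toy illustration over a 3-token vocabulary (numbers invented):

    import torch

    alpha = 3.0                                              # config.guidance.alpha
    saliency = torch.tensor(0.8)                             # context-aware saliency at one position
    w = torch.sigmoid(alpha * saliency)                      # ~0.92, so the prior dominates here

    p_lm = torch.tensor([0.7, 0.2, 0.1])
    p_prior = torch.tensor([0.1, 0.3, 0.6])
    mixed = (1 - w) * p_lm + w * p_prior                     # still a valid distribution (sums to 1)
    guidance_logp = torch.log(mixed + 1e-12)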
src/sampling/unconditional_generator.py ADDED
@@ -0,0 +1,114 @@
+ #!/usr/bin/env python3
+
+ import sys
+ import os
+
+ import random
+ import torch
+ import pandas as pd
+ import numpy as np
+
+ from tqdm import tqdm
+ from collections import Counter
+ from omegaconf import OmegaConf
+ from datetime import datetime
+ from transformers import AutoTokenizer, AutoModelForMaskedLM
+
+ from MeMDLM_v2.src.lm.diffusion_module import MembraneFlow
+ from src.sampling.unconditional_sampler import UnconditionalSampler
+ from src.utils.generate_utils import mask_for_de_novo, calc_ppl
+ from src.utils.model_utils import _print
+
+
+ device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
+ os.chdir('/home/a03-sgoel/MeMDLM_v2')
+ config = OmegaConf.load("./src/configs/lm.yaml")
+
+ date = datetime.now().strftime("%Y-%m-%d")
+
+
+
+ def generate_sequence(prior: str, tokenizer, generator, device):
+     input_ids = tokenizer(prior, return_tensors="pt").to(device)['input_ids']
+     ids = generator.sample_unconditional(
+         xt=input_ids,
+         num_steps=config.sampling.n_steps,
+         return_logits=False,
+         banned_token_ids=None
+         #banned_token_ids=[tokenizer.convert_tokens_to_ids("P"), tokenizer.convert_tokens_to_ids("C")]
+     )
+     generated_sequence = tokenizer.decode(ids[0].squeeze())[5:-5].replace(" ", "")  # strip bos/eos tokens & spaces between residues
+     return generated_sequence
+
+
+ def main():
+     csv_save_path = f'./results/denovo/unconditional/{config.wandb.name}/{date}_tau=3.0_test-set_distribution'
+
+     os.makedirs(csv_save_path, exist_ok=True)
+
+
+     tokenizer = AutoTokenizer.from_pretrained(config.lm.pretrained_evoflow)
+
+     flow = MembraneFlow(config).to(device)
+     state_dict = flow.get_state_dict(f"./checkpoints/{config.wandb.name}/best_model.ckpt")
+     flow.load_state_dict(state_dict)
+     flow.eval()
+
+     esm_pth = config.lm.pretrained_esm
+     esm_model = AutoModelForMaskedLM.from_pretrained(esm_pth).to(device)
+     esm_model.eval()
+
+     generator = UnconditionalSampler(tokenizer, flow)
+
+     # # Get random sequence lengths to generate
+     # seq_lengths = [random.randint(50, 250) for _ in range(5000)]
+
+     # # Determine lengths from positive controls
+     # df = pd.read_csv(f'./results/denovo/unconditional/{config.wandb.name}/perin_pos_ctrl/raw_seqs.csv')
+     # seq_lengths = [len(seq) for seq in df['Sequence'].tolist() for _ in range(500)]  # generate each length 500 times
+     # _print(seq_lengths)
+
+     # Determine lengths from test set distribution
+     df = pd.read_csv("./data/test.csv")
+     seq_lengths = [len(seq) for seq in df['Sequence'].tolist()]
+     length_counts = Counter(seq_lengths)  # {L1: freq, L2: freq, ...}
+     total = sum(length_counts.values())  # total number of sequences
+     lengths = np.array(list(length_counts.keys()))  # unique lengths
+     probs = np.array([length_counts[l] / total for l in lengths])  # empirical frequency of each length
+     seq_lengths = np.random.choice(lengths, size=len(seq_lengths), p=probs)
+
+     generation_results = []
+     for seq_len in tqdm(seq_lengths, desc=f"Generating sequences: "):
+         seq_res = []
+
+         masked_seq = mask_for_de_novo(seq_len)  # Sequence of all <mask> tokens
+         gen_seq = ""
+         attempts = 0
+
+         while len(gen_seq) != seq_len and attempts < 3:
+             gen_seq = generate_sequence(masked_seq, tokenizer, generator, device)
+             attempts += 1
+
+         if len(gen_seq) != seq_len:
+             esm_ppl, flow_ppl = None, None
+         else:
+             esm_ppl = calc_ppl(esm_model, tokenizer, gen_seq, [i for i in range(len(gen_seq))], model_type='esm')
+             flow_ppl = calc_ppl(flow, tokenizer, gen_seq, [i for i in range(len(gen_seq))], model_type='flow')
+
+         _print(f'gen seq: {gen_seq}')
+         _print(f'esm ppl: {esm_ppl}')
+         _print(f'flow ppl: {flow_ppl}')
+
+         seq_res.append(gen_seq)
+         seq_res.append(esm_ppl)
+         seq_res.append(flow_ppl)
+
+         generation_results.append(seq_res)
+
+     df = pd.DataFrame(generation_results, columns=['Generated Sequence', 'ESM PPL', 'Flow PPL'])
+     df.to_csv(csv_save_path + "/seqs_with_ppl.csv", index=False)
+
+
+ if __name__ == "__main__":
+     main()
src/sampling/unconditional_sampler.py ADDED
@@ -0,0 +1,108 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import sys
2
+ import torch
3
+ import random
4
+ import numpy as np
5
+ from tqdm import tqdm
6
+ from src.utils.model_utils import _print
7
+
8
+ class UnconditionalSampler:
9
+ def __init__(self, tokenizer, model):
10
+ self.model = model
11
+ self.tokenizer = tokenizer
12
+
13
+ self.device = self.model.device
14
+ self.mask_id = self.tokenizer.mask_token_id
15
+ self.seed_everything(seed=42)
16
+
17
+ @torch.inference_mode()
18
+ def sample_unconditional(self, xt, num_steps, tau=0.7, kappa_fn=lambda t: t, eta=1, alpha=1., banned_token_ids=None, return_logits=None):
19
+ """
20
+ Stochastic remasking sampling method for iterative refinement of sequences.
21
+
22
+ Args:
23
+ xt (Tensor): Initial token tensor.
24
+ num_steps (int): Number of refinement steps.
25
+ tau (float): Temperature parameter for softmax sampling.
26
+ kappa_fn (callable): Function controlling the unmasking schedule.
27
+ eta (float): Scaling factor for score adjustments.
28
+ alpha (float): Weighting for confidence-based scoring.
29
+
30
+ Returns:
31
+ Tensor: Final sampled sequence tensor.
32
+ """
33
+
34
+ dt = 1 / num_steps
35
+ fix_mask = xt != self.mask_id # tokens to retain
36
+ attention_mask = torch.ones_like(xt).to(self.device)
37
+
38
+ for i in range(1, num_steps + 1):
39
+ kappa_t = kappa_fn(i * dt)
40
+ logits = self.model(input_ids=xt, attention_mask=attention_mask)
41
+ last_mask = xt == self.mask_id # tokens currently masked
42
+ unmask_t = ~last_mask & ~fix_mask # unmasked and not fixed tokens - candidates for masking
43
+
44
+ x0, logp = self.stochastic_sample_from_categorical(logits, tau, banned_token_ids=banned_token_ids) # tokens, logprobs
45
+
46
+ # Confidence-based scoring
47
+ entropy = torch.distributions.Categorical(logits=logits).entropy()
48
+ score = alpha * logp + (1 - alpha) * -entropy # alpha = 1 --> score = logp
49
+ score = score.masked_fill(fix_mask, float('inf'))
50
+
51
+ score[unmask_t] = score[unmask_t] * eta
52
+
53
+ num_to_mask = ((~fix_mask).sum(1, keepdim=True).float() * (1 - kappa_t)).long()
54
+ lowest_k_mask = self.topk_lowest_masking(score, num_to_mask)
55
+
56
+ xt[lowest_k_mask] = self.mask_id
57
+ mask_2_x0 = last_mask & ~lowest_k_mask
58
+ xt[mask_2_x0] = x0[mask_2_x0]
59
+
60
+ # print(f"Step {i}/{num_steps} | eta: {eta}, alpha: {alpha}, Stochastic remask: \n", xt[0])
61
+
62
+ xt[xt == self.mask_id] = x0[xt == self.mask_id]
63
+ return xt, logits if return_logits else xt
64
+
65
+ def stochastic_sample_from_categorical(self, logits, temperature, noise_scale=1.0, banned_token_ids=None):
66
+ """
67
+ Sample from a categorical distribution with optional temperature scaling and Gumbel noise.
68
+ """
69
+ logits = logits.double()
70
+
71
+ if banned_token_ids is not None:
72
+ banned_token_mask = torch.zeros_like(logits, device=logits.device).bool()
73
+ for token_id in banned_token_ids:
74
+ banned_token_mask[..., token_id] = True
75
+ logits = logits.masked_fill(banned_token_mask, float('-inf'))
76
+
77
+ if temperature != 0:
78
+ gumbel_noise = -torch.log(-torch.log(torch.rand_like(logits) + 1e-8) + 1e-8)
79
+ logits = logits / temperature + noise_scale * gumbel_noise
80
+ scores, tokens = logits.log_softmax(dim=-1).max(dim=-1)
81
+
82
+ return tokens, scores
83
+
84
+ def topk_lowest_masking(self, scores, cutoff_len):
85
+ """
86
+ scores: [b, n]
87
+ cutoff_len: [b, 1]
88
+ returns:
89
+ mask: [b, n], with 1 if the token is in top-k lowest scores, 0 otherwise
90
+ """
91
+ sorted_index = scores.sort(-1)[0]
92
+ cutoff = sorted_index.gather(dim=-1, index=cutoff_len)
93
+ return scores < cutoff
94
+
95
+ def seed_everything(self, seed):
96
+ """
97
+ Set the seed for reproducibility across various libraries.
98
+ """
99
+ if seed is None:
100
+ return
101
+ random.seed(seed)
102
+ np.random.seed(seed)
103
+ torch.manual_seed(seed)
104
+ if torch.cuda.is_available():
105
+ torch.cuda.manual_seed(seed)
106
+ torch.cuda.manual_seed_all(seed) # if using multi-GPU
107
+ torch.backends.cudnn.deterministic = True
108
+ torch.backends.cudnn.benchmark = False
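A minimal usage sketch for the sampler above (illustrative, not part of the commit). The sampler expects a model whose forward(input_ids, attention_mask) returns raw logits and that exposes a .device attribute, as the repo's diffusion modules do; the LogitsAdapter below is a hypothetical stand-in that gives an off-the-shelf ESM-2 masked LM that interface.

    import torch
    import torch.nn as nn
    from transformers import AutoTokenizer, AutoModelForMaskedLM
    from src.sampling.unconditional_sampler import UnconditionalSampler

    class LogitsAdapter(nn.Module):
        """Hypothetical wrapper: exposes .device and a forward() that returns raw logits."""
        def __init__(self, lm):
            super().__init__()
            self.lm = lm

        @property
        def device(self):
            return next(self.lm.parameters()).device

        def forward(self, input_ids, attention_mask=None):
            return self.lm(input_ids=input_ids, attention_mask=attention_mask).logits

    name = "facebook/esm2_t33_650M_UR50D"
    tokenizer = AutoTokenizer.from_pretrained(name)
    model = LogitsAdapter(AutoModelForMaskedLM.from_pretrained(name)).eval()

    sampler = UnconditionalSampler(tokenizer, model)
    xt = tokenizer("<mask>" * 50, return_tensors="pt")["input_ids"].to(model.device)  # de novo prompt
    tokens, _ = sampler.sample_unconditional(xt, num_steps=128, tau=0.7)
    print(tokenizer.decode(tokens.squeeze()).replace(" ", ""))  # still contains <cls>/<eos>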
src/utils/generate_utils.py ADDED
@@ -0,0 +1,151 @@
1
+ import torch
2
+ import math
3
+ import sys
4
+
5
+ import torch.nn.functional as F
6
+ import pandas as pd
7
+ import numpy as np
8
+
9
+ from omegaconf import OmegaConf
10
+ from transformers import AutoModelForMaskedLM, AutoModel, AutoTokenizer
11
+
12
+ from src.lm.memdlm.diffusion_module import MembraneFlow
13
+ from src.lm.dplm.diffusion_module import DPLM
14
+ from src.utils.model_utils import get_latents, _print
15
+ from src.sampling.unconditional_sampler import UnconditionalSampler
16
+ from src.lm.dplm.unconditional_sampler import UnconditionalSampler as DPLMUnconditionalSampler
17
+
18
+ config = OmegaConf.load("/home/a03-sgoel/MeMDLM_v2/src/configs/lm.yaml")
19
+
20
+ # -------# Masking #-------- #
21
+ def mask_for_de_novo(sequence_length):
22
+ return "<mask>" * sequence_length
23
+
24
+ def mask_for_scaffold(sequence, generate_type, mask_token):
25
+ if generate_type == "uppercase":
26
+ sequence = ''.join([mask_token if residue.isupper() else residue.upper() for residue in sequence])
27
+ elif generate_type == "lowercase":
28
+ sequence = ''.join([mask_token if residue.islower() else residue for residue in sequence])
29
+ return sequence
30
+
31
+
32
+ # -------# Generation #-------- #
33
+ def memflow_infill_uncond(masked_seq, tokenizer, model: MembraneFlow):
34
+ generator = UnconditionalSampler(tokenizer, model) # initialize the generator object
35
+ xt = tokenizer(masked_seq, return_tensors='pt')['input_ids'].to(model.device)
36
+ denoised_tokens = generator.sample_unconditional(xt, config.sampling.n_steps)[0].squeeze()
37
+ generated_sequence = tokenizer.decode(denoised_tokens).replace(" ", "")[5:-5]
38
+ return generated_sequence
39
+
40
+
41
+ def evodiff_infill(motif_seq, tokenizer, model, device, batch_size=1):
42
+ """
43
+ Following the given evodiff example
44
+ https://github.com/microsoft/evodiff/blob/main/examples/evodiff.ipynb
45
+ """
46
+ # Manual masking of infilling sequence
47
+ motif_seq = ''.join(["#" if aa.islower() else aa for aa in motif_seq]) # Mask token is "#" in evodiff tokenizer
48
+ tkns = tokenizer.tokenize([motif_seq])
49
+ sample = torch.as_tensor(tkns).to(device)
50
+
51
+ # Create input motif + scaffold
52
+ loc = torch.arange(0, len(motif_seq)).to(device)[sample==tokenizer.mask_id].cpu().numpy()
53
+ np.random.shuffle(loc)
54
+
55
+ sample = sample.to(device).unsqueeze(0)
56
+ # og_sample = sample.clone()
57
+
58
+ with torch.no_grad():
59
+ for i in loc:
60
+ timestep = torch.tensor([0] * batch_size).to(device) # placeholder but not called in model
61
+ timestep = timestep.to(device)
62
+ prediction = model(sample, timestep)
63
+ p = prediction[:, i, :len(tokenizer.all_aas) - 6] # only canonical
64
+ p = F.softmax(p, dim=1) # softmax over logits
65
+ p_sample = torch.multinomial(p, num_samples=1) # sample from categorical distribution
66
+ sample[:, i] = p_sample.squeeze()
67
+ output = [tokenizer.untokenize(s) for s in sample]
68
+ return output[0] #if batch_size==1 else output, og_sample, loc
69
+
70
+
71
+ def dplm_infill(masked_seq, tokenizer, model: DPLM, device):
72
+ generator = DPLMUnconditionalSampler(tokenizer, model)
73
+ xt = tokenizer(masked_seq, return_tensors='pt')['input_ids'].to(model.device)
74
+ denoised_tokens = generator.sample_unconditional(xt, config.sampling.n_steps)[0].squeeze()
75
+ generated_sequence = tokenizer.decode(denoised_tokens).replace(" ", "")[5:-5]
76
+ return generated_sequence
77
+
78
+
79
+ # -------# Metrics #-------- #
80
+ def calc_progen_ppl(model, tokenizer, target, device, fp16=True):
81
+ """Compute causal LM cross-entropy loss for a given sequence."""
82
+ with torch.no_grad():
83
+ with torch.cuda.amp.autocast(enabled=fp16):
84
+ logits = model(
85
+ input_ids = target,
86
+ attention_mask = torch.ones_like(target)
87
+ ).logits
88
+ # Shift
89
+ logits = logits[:-1, ...]
90
+ target = target[1:]
91
+ loss = torch.nn.functional.cross_entropy(
92
+ input=logits,
93
+ target=target,
94
+ reduction='mean'
95
+ )
96
+ return torch.exp(loss).item()
97
+
98
+
99
+ def calc_ppl(model, tokenizer, generated_sequence, mask_token_indices, model_type):
100
+ total_loss = 0.0
101
+ tensor_input = tokenizer.encode(generated_sequence, return_tensors='pt').to(model.device)
102
+ attn_mask = torch.ones_like(tensor_input).to(model.device)
103
+
104
+ for i in mask_token_indices:
105
+ masked_input = tensor_input.clone()
106
+ masked_input[0, i] = tokenizer.mask_token_id
107
+
108
+ labels = torch.full(tensor_input.shape, -100).to(model.device)
109
+ labels[0, i] = tensor_input[0, i]
110
+
111
+ with torch.no_grad():
112
+ if model_type == 'esm':
113
+ loss = model(masked_input, labels=labels).loss.item()
114
+ elif model_type == 'flow':
115
+ logits = model.forward(masked_input, attention_mask=attn_mask)
116
+ loss = F.cross_entropy(
117
+ logits.view(-1, logits.size(-1)),
118
+ labels.view(-1),
119
+ reduction='none',
120
+ ignore_index=-100,
121
+ )[i].item()
122
+
123
+ total_loss += loss
124
+
125
+ avg_loss = total_loss / max(len(mask_token_indices), 1)  # average over the positions actually scored
126
+ perplexity = math.exp(avg_loss)
127
+
128
+ return perplexity
129
+
130
+
131
+ def calc_blosum_score(og_seq, gen_seq, indices):
132
+ import blosum as bl
133
+ mat = bl.BLOSUM(62)
134
+ tot_score = 0
135
+ for i in indices:
136
+ og_res, gen_res = og_seq[i], gen_seq[i]
137
+ try:
138
+ val = mat[og_res][gen_res]
139
+ tot_score += val
140
+ except KeyError:
141
+ # -4 is the lowest BLOSUM62 score, indicating biological implausibility
142
+ tot_score += -4
143
+ return tot_score / len(indices) if indices else 0
144
+
145
+
146
+ def calc_cos_sim(original_sequence, generated_sequence, tokenizer, esm_model, device):
147
+ og_embeddings = get_latents(esm_model, tokenizer, original_sequence.upper(), device)
148
+ new_embeddings = get_latents(esm_model, tokenizer, generated_sequence, device)
149
+ cosine_sim = torch.nn.functional.cosine_similarity(og_embeddings, new_embeddings, dim=-1)
150
+ cosine_sim = torch.mean(cosine_sim).item()
151
+ return cosine_sim
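A short illustration (not part of the commit) of the scaffold-masking convention and the BLOSUM metric defined above: lowercase residues mark the region to redesign, uppercase residues are the fixed scaffold. The motif string is made up, and calc_blosum_score needs the blosum package installed.

    from src.utils.generate_utils import mask_for_scaffold, calc_blosum_score

    motif = "MKTAYIAKqrqisfvkSHFSRQLEERLGLIEVQ"  # hypothetical scaffold with a lowercase infill region

    # "lowercase" masks the lowercase span for infilling; "uppercase" would mask the capitals instead.
    masked = mask_for_scaffold(motif, "lowercase", mask_token="<mask>")
    print(masked)

    # Score only the redesigned positions of a generated sequence against the original.
    generated = motif.upper()  # stand-in for a real generation of the same length
    infill_idx = [i for i, aa in enumerate(motif) if aa.islower()]
    print(calc_blosum_score(motif.upper(), generated, infill_idx))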
src/utils/model_utils.py ADDED
@@ -0,0 +1,139 @@
1
+ import sys
2
+ import torch
3
+ import torch.nn as nn
4
+
5
+
6
+ def _print(s):
7
+ print(s)
8
+ sys.stdout.flush()
9
+
10
+
11
+ def get_latents(model, tokenizer, sequence, device):
12
+ tokens = tokenizer(sequence, return_tensors="pt").to(device)
13
+ with torch.no_grad():
14
+ outputs = model(**tokens, output_hidden_states=True)  # make sure hidden states are returned
15
+ embeds = outputs.hidden_states[-1].squeeze(0) # Get last hidden states
16
+ return embeds
17
+
18
+
19
+
20
+ # General model freezing
21
+ def freeze_model(model: nn.Module):
22
+ # Disable parameter updates for all layers
23
+ for param in model.parameters():
24
+ param.requires_grad = False
25
+
26
+
27
+
28
+ # For ProGen2 architecture
29
+ def apply_gptj_freezing(model, N_layers):
30
+ def unfreeze_n_layers(model, N_layers):
31
+ # Count number of encoder layers
32
+ model_layers = len(model.transformer.h)
33
+ for i, h in enumerate(model.transformer.h):
34
+ if i >= model_layers - N_layers:
35
+ for module in h.attn.modules():
36
+ for param in module.parameters():
37
+ param.requires_grad = True
38
+
39
+ def check_frozen_model(model, N_layers: int):
40
+ """
41
+ Verify that only the last N_layers of model.transformer.h are unfrozen.
42
+ Source: https://github.com/enijkamp/progen2/blob/main/progen/modeling_progen.py
43
+ """
44
+ model_layers = len(model.transformer.h)
45
+ frozen_layers = 0
46
+ unfrozen_layers = 0
47
+ for i, h in enumerate(model.transformer.h):
48
+ if i >= model_layers - N_layers: # should be unfrozen
49
+ if any(param.requires_grad for param in h.parameters()):
50
+ unfrozen_layers += 1
51
+ else:
52
+ print(f"Layer {i} has all parameters frozen, but it should be unfrozen.")
53
+ else: # should be frozen
54
+ if any(param.requires_grad for param in h.parameters()):
55
+ print(f"Layer {i} is not frozen, but it should be frozen.")
56
+ else:
57
+ frozen_layers += 1
58
+
59
+ assert frozen_layers == model_layers - N_layers and unfrozen_layers == N_layers, \
60
+ f"frozen layers: {frozen_layers}, unfrozen layers: {unfrozen_layers}"
61
+
62
+ print(f"frozen layers: {frozen_layers}, unfrozen layers: {unfrozen_layers}")
63
+
64
+ freeze_model(model)
65
+ unfreeze_n_layers(model, N_layers)
66
+ check_frozen_model(model, N_layers)
67
+
68
+
69
+
70
+
71
+
72
+ # For RDM-based architectures
73
+ def apply_rdm_freezing(model: nn.Module, N_layers: int, model_type: str):
74
+ """
75
+ Freeze all layers except last N for esm-like architectures
76
+
77
+ Args:
78
+ model (nn.Module): model to freeze
79
+ N_layers (int): num encoder layers to unfreeze
80
+ model_type (str): one of {"esm", "evoflow", "dplm"}
81
+ """
82
+
83
+ # choose encoder layers based on the model type
84
+ if model_type == "dplm":
85
+ encoder_layers = model.net.esm.encoder.layer
86
+ elif model_type in ("esm", "evoflow"):
87
+ encoder_layers = model.esm.encoder.layer
88
+ else:
89
+ raise ValueError(f"Unknown model_type: {model_type}")
90
+
91
+ def unfreeze_n_layers(layers, N_layers: int):
92
+ model_layers = len(layers)
93
+ for i, layer in enumerate(layers):
94
+ if i >= model_layers - N_layers:
95
+ for module in layer.attention.self.key.modules():
96
+ for param in module.parameters():
97
+ param.requires_grad = True
98
+ for module in layer.attention.self.query.modules():
99
+ for param in module.parameters():
100
+ param.requires_grad = True
101
+ for module in layer.attention.self.value.modules():
102
+ for param in module.parameters():
103
+ param.requires_grad = True
104
+
105
+ def check_model(layers, N_layers: int):
106
+ model_layers = len(layers)
107
+ frozen_layers = 0
108
+ unfrozen_layers = 0
109
+
110
+ for i, layer in enumerate(layers):
111
+ if i >= model_layers - N_layers:
112
+ layer_frozen = True
113
+ for module in layer.attention.self.key.modules():
114
+ if any(param.requires_grad for param in module.parameters()):
115
+ layer_frozen = False
116
+ for module in layer.attention.self.query.modules():
117
+ if any(param.requires_grad for param in module.parameters()):
118
+ layer_frozen = False
119
+ for module in layer.attention.self.value.modules():
120
+ if any(param.requires_grad for param in module.parameters()):
121
+ layer_frozen = False
122
+
123
+ if layer_frozen:
124
+ print(f"layer {i} has all parameters frozen, but it should be unfrozen.")
125
+ else:
126
+ unfrozen_layers += 1
127
+ else:
128
+ if any(param.requires_grad for param in layer.parameters()):
129
+ print(f"layer {i} is not frozen, but it should")
130
+ else:
131
+ frozen_layers += 1
132
+
133
+ assert (frozen_layers == model_layers - N_layers) and (unfrozen_layers == N_layers), \
134
+ f"frozen layers: {frozen_layers}, unfrozen layers: {unfrozen_layers}"
135
+
136
+
137
+ freeze_model(model)
138
+ unfreeze_n_layers(encoder_layers, N_layers)
139
+ check_model(encoder_layers, N_layers)
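A quick sanity-check sketch (not part of the commit) for the freezing helpers above: it unfreezes the Q/K/V projections of the last 3 encoder layers of the ESM-2 backbone and prints the resulting trainable-parameter fraction.

    from transformers import AutoModelForMaskedLM
    from src.utils.model_utils import apply_rdm_freezing

    model = AutoModelForMaskedLM.from_pretrained("facebook/esm2_t33_650M_UR50D")
    apply_rdm_freezing(model, N_layers=3, model_type="esm")  # asserts the expected freeze pattern

    trainable = sum(p.numel() for p in model.parameters() if p.requires_grad)
    total = sum(p.numel() for p in model.parameters())
    print(f"trainable: {trainable:,} / {total:,} ({100 * trainable / total:.2f}%)")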
src/utils/optimizer_utils.py ADDED
@@ -0,0 +1,300 @@
1
+ import torch
2
+ import math
3
+
4
+ from torch.optim import Optimizer
5
+ from torch.optim.lr_scheduler import LambdaLR
6
+ from torch.optim.adamw import adamw
7
+
8
+ try:
9
+ import deepspeed
10
+ from deepspeed.ops.adam import FusedAdam
11
+ from deepspeed.ops.adam import DeepSpeedCPUAdam
12
+ except ImportError:  # DeepSpeed is optional; only needed for FusedAdam
13
+ pass
14
+
15
+
16
+ def get_optimizer(cfg, params):
17
+ if cfg.optim.type == 'adam':
18
+ return torch.optim.Adam(
19
+ params=params,
20
+ lr=cfg.optim.lr,
21
+ weight_decay=cfg.optim.weight_decay,
22
+ betas=(cfg.optim.beta1, cfg.optim.beta2)
23
+ )
24
+ elif cfg.optim.type == 'adamw':
25
+ return AdamW(
26
+ params=params,
27
+ lr=cfg.optim.lr,
28
+ weight_decay=cfg.optim.weight_decay,
29
+ betas=(cfg.optim.beta1, cfg.optim.beta2)
30
+ )
31
+ elif cfg.optim.type == 'fusedadam':
32
+ return FusedAdam(
33
+ params=params,
34
+ lr=cfg.optim.lr,
35
+ weight_decay=cfg.optim.weight_decay,
36
+ betas=(cfg.optim.beta1, cfg.optim.beta2),
37
+ )
38
+ else:
39
+ raise NotImplementedError('Optimizer not supported: %s' % cfg.optim.type)
40
+
41
+
42
+ class AdamW(torch.optim.AdamW):
43
+ @torch.no_grad()
44
+ def step(self, closure=None):
45
+ """Performs a single optimization step.
46
+
47
+ Args:
48
+ closure (callable, optional): A closure that reevaluates the model
49
+ and returns the loss.
50
+ """
51
+ self._cuda_graph_capture_health_check()
52
+
53
+ loss = None
54
+ if closure is not None:
55
+ with torch.enable_grad():
56
+ loss = closure()
57
+
58
+ for group in self.param_groups:
59
+ params_with_grad = []
60
+ grads = []
61
+ exp_avgs = []
62
+ exp_avg_sqs = []
63
+ max_exp_avg_sqs = []
64
+ state_steps = []
65
+ amsgrad = group['amsgrad']
66
+ beta1, beta2 = group['betas']
67
+
68
+ for p in group['params']:
69
+ if p.grad is None:
70
+ continue
71
+ params_with_grad.append(p)
72
+ if p.grad.is_sparse:
73
+ raise RuntimeError('AdamW does not support sparse gradients')
74
+ grads.append(p.grad)
75
+
76
+ state = self.state[p]
77
+
78
+ # State initialization
79
+ if len(state) == 0:
80
+ state['step'] = torch.zeros((1,), dtype=torch.float, device=p.device) \
81
+ if self.defaults['capturable'] else torch.tensor(0.)
82
+ # Exponential moving average of gradient values
83
+ state['exp_avg'] = torch.zeros_like(p, memory_format=torch.preserve_format)
84
+ # Exponential moving average of squared gradient values
85
+ state['exp_avg_sq'] = torch.zeros_like(p, memory_format=torch.preserve_format)
86
+ if amsgrad:
87
+ # Maintains max of all exp. moving avg. of sq. grad. values
88
+ state['max_exp_avg_sq'] = torch.zeros_like(p, memory_format=torch.preserve_format)
89
+
90
+ exp_avgs.append(state['exp_avg'])
91
+ exp_avg_sqs.append(state['exp_avg_sq'])
92
+
93
+ if amsgrad:
94
+ max_exp_avg_sqs.append(state['max_exp_avg_sq'])
95
+
96
+ state_steps.append(state['step'].cpu())
97
+
98
+ adamw(params_with_grad,
99
+ grads,
100
+ exp_avgs,
101
+ exp_avg_sqs,
102
+ max_exp_avg_sqs,
103
+ state_steps,
104
+ amsgrad=amsgrad,
105
+ beta1=beta1,
106
+ beta2=beta2,
107
+ lr=group['lr'],
108
+ weight_decay=group['weight_decay'],
109
+ eps=group['eps'],
110
+ maximize=group['maximize'],
111
+ foreach=group['foreach'],
112
+ capturable=group['capturable'])
113
+
114
+ return loss
115
+
116
+ def get_scheduler(cfg, optimizer):
117
+ if cfg.optim.scheduler is None:
118
+ return BlackHole()
119
+ elif cfg.optim.scheduler == 'plateau':
120
+ return (
121
+ torch.optim.lr_scheduler.ReduceLROnPlateau(
122
+ optimizer,
123
+ mode=cfg.mode,
124
+ factor=cfg.factor,
125
+ patience=cfg.patience,
126
+ min_lr=cfg.min_lr,
127
+ ),
128
+ {'monitor': "val/loss", 'interval': 'epoch'}
129
+ )
130
+ elif cfg.optim.scheduler == 'noam':
131
+ return (
132
+ NoamScheduler(
133
+ optimizer,
134
+ lr=cfg.lr,
135
+ warmup_steps=cfg.warmup_steps,
136
+ model_size=cfg.model_size,
137
+ warmup_init_lr=cfg.get('warmup_init_lr')
138
+ ),
139
+ {'frequency': 1, 'interval': 'step'}
140
+ )
141
+ elif cfg.optim.scheduler == 'polynomial':
142
+ return (
143
+ PolyNomialLRScheduler(
144
+ optimizer,
145
+ total_steps=cfg.training.max_steps,
146
+ warmup_steps=cfg.training.warmup_steps,
147
+ lr=cfg.optim.lr,
148
+ lr_end=cfg.optim.lr_end,
149
+ warmup_init_lr=cfg.optim.warmup_init_lr,
150
+ power=cfg.optim.power
151
+ ),
152
+ {'frequency': 1, 'interval': 'step'}
153
+ )
154
+ elif cfg.optim.scheduler == 'multistep':
155
+ return torch.optim.lr_scheduler.MultiStepLR(
156
+ optimizer,
157
+ milestones=cfg.milestones,
158
+ gamma=cfg.gamma,
159
+ )
160
+ elif cfg.optim.scheduler == 'exp':
161
+ return torch.optim.lr_scheduler.ExponentialLR(
162
+ optimizer,
163
+ gamma=cfg.gamma,
164
+ )
165
+ elif cfg.optim.scheduler == 'progen_ft':
166
+ sched = CosineToFrac(
167
+ optimizer=optimizer,
168
+ total_steps=cfg.training.max_steps,
169
+ final_frac=0.2, # decay to lr/5
170
+ )
171
+ return (sched, {'frequency': 1, 'interval': 'step'})
172
+ elif cfg.optim.scheduler is None:  # unreachable: None is already handled by the first branch above
173
+ return BlackHole()
174
+ else:
175
+ raise NotImplementedError('Scheduler not supported: %s' % cfg.optim.scheduler)
176
+
177
+
178
+ class BlackHole(object):
179
+ def __setattr__(self, name, value):
180
+ pass
181
+
182
+ def __call__(self, *args, **kwargs):
183
+ return self
184
+
185
+ def __getattr__(self, name):
186
+ return self
187
+
188
+
189
+ # -------# DPLM Scheduler #-------- #
190
+ def polynomial_lr_schedule(step, total_steps, warmup_steps, warmup_init_lr, lr, lr_end, power):
191
+ if step < warmup_steps:
192
+ return warmup_init_lr + (lr - warmup_init_lr) * step / warmup_steps
193
+ elif step > total_steps:
194
+ return lr_end
195
+ else:
196
+ return lr_end + (lr - lr_end) * (1 - (step - warmup_steps) / (total_steps - warmup_steps)) ** power
197
+
198
+ class PolyNomialLRScheduler(LambdaLR):
199
+ def __init__(
200
+ self,
201
+ optimizer: Optimizer,
202
+ total_steps: int = 1000,
203
+ warmup_steps: int = 0,
204
+ lr: float = 0.00004, # 5e-04,
205
+ lr_end: float = 1e-5, #1e-07,
206
+ warmup_init_lr: float = 1e-07, # 1e-07,
207
+ power: float = 1.0,
208
+ ) -> None:
209
+
210
+ self.warmup_init_lr = warmup_init_lr
211
+ self.warmup_steps = warmup_steps
212
+
213
+ def lr_lambda(step):
214
+ return polynomial_lr_schedule(
215
+ step, total_steps, warmup_steps, warmup_init_lr, lr, lr_end, power
216
+ ) / lr
217
+
218
+ super().__init__(optimizer, lr_lambda=lr_lambda)
219
+
220
+
221
+ # -------# ProGen2 Fine-Tuning Scheduler #-------- #
222
+ def cosine_frac_scheduler(step, total_steps, final_frac):
223
+ s = min(max(step, 0), total_steps)
224
+ cos = 0.5 * (1.0 + math.cos(math.pi * s / total_steps)) # 1 --> 0
225
+ return final_frac + (1.0 - final_frac) * cos # multiplier goes from 1.0 down to final_frac
226
+
227
+ class CosineToFrac(LambdaLR):
228
+ """
229
+ Cosine decay of the LR multiplier from 1.0 -> final_frac over total_steps (no warmup).
230
+ For ProGen fine-tuning, final_frac=0.2 implements decay to lr/5.
231
+ """
232
+ def __init__(self, optimizer, total_steps, final_frac=0.2):
233
+ self.total_steps = max(int(total_steps), 1)
234
+ self.final_frac = float(final_frac)
235
+
236
+ def lr_lambda(step):
237
+ return cosine_frac_scheduler(
238
+ step=step,
239
+ total_steps=self.total_steps,
240
+ final_frac=self.final_frac
241
+ )
242
+
243
+ super().__init__(optimizer, lr_lambda=lr_lambda)
244
+
245
+
246
+
247
+ def inverse_sqrt_lr_schedule(step, warmup_steps, warmup_init_lr, lr_step, decay_step):
248
+ if step == 0:
249
+ step = 1
250
+ if step < warmup_steps:
251
+ return warmup_init_lr + lr_step * step
252
+ else:
253
+ return decay_step * step ** -0.5
254
+
255
+
256
+ class InverseSqrtLRScheduler(LambdaLR):
257
+ def __init__(
258
+ self,
259
+ optimizer: Optimizer,
260
+ warmup_steps: int = 0,
261
+ lr: float = 5e-04,
262
+ warmup_init_lr: float = 1e-07,
263
+ ) -> None:
264
+
265
+ self.warmup_init_lr = warmup_init_lr
266
+ self.warmup_steps = warmup_steps
267
+ self.lr_step = (lr - warmup_init_lr) / warmup_steps
268
+ self.decay_step = lr * warmup_steps ** 0.5
269
+
270
+ def lr_lambda(step):
271
+ return inverse_sqrt_lr_schedule(
272
+ step, warmup_steps, warmup_init_lr, self.lr_step, self.decay_step
273
+ ) / lr
274
+
275
+ super().__init__(optimizer, lr_lambda=lr_lambda)
276
+
277
+
278
+ def noam_lr_schedule(step, warmup_steps, factor, model_size):
279
+ if step == 0:
280
+ step = 1
281
+ return factor * (model_size ** (-0.5) * min(step ** (-0.5), step * warmup_steps ** (-1.5)))
282
+
283
+
284
+ class NoamScheduler(LambdaLR):
285
+ def __init__(
286
+ self,
287
+ optimizer: Optimizer,
288
+ lr,
289
+ warmup_init_lr,
290
+ model_size: int = 128,
291
+ warmup_steps: int = 0,
292
+ factor: int = 2,
293
+ ) -> None:
294
+
295
+ # dummy_lr = self.base_lrs[0]
296
+ def lr_lambda(step):
297
+ return noam_lr_schedule(step, warmup_steps, factor, model_size) / lr
298
+
299
+ super().__init__(optimizer, lr_lambda=lr_lambda)
300
+
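To make the warmup-then-decay shape concrete, a small sketch (not part of the commit) that drives PolyNomialLRScheduler on a toy parameter group with illustrative hyperparameters and prints the learning rate at a few milestones.

    import torch
    from src.utils.optimizer_utils import PolyNomialLRScheduler

    param = torch.nn.Parameter(torch.zeros(1))
    opt = torch.optim.AdamW([param], lr=3e-5)
    sched = PolyNomialLRScheduler(
        opt,
        total_steps=3000,
        warmup_steps=150,
        lr=3e-5,
        lr_end=1e-5,
        warmup_init_lr=1e-7,
        power=1.0,
    )

    for step in range(3001):
        opt.step()
        sched.step()
        if step in (0, 149, 1500, 3000):
            # ramps up over warmup, peaks at lr, then decays linearly to lr_end
            print(step, f"{sched.get_last_lr()[0]:.2e}")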