zyc4975matholic committed on
Commit
303c2e0
·
1 Parent(s): 63d28a7

Include DNA training code

tr2d2-dna/README.md ADDED
@@ -0,0 +1,49 @@
1
+ # TR2-D2 For Enhancer DNA Design
2
+
3
+ This directory contains the code for fine-tuning DNA sequence models with TR2-D2 to optimize enhancer activity.
4
+
5
+ The codebase is built upon [MDLM (Sahoo et al., 2024)](https://github.com/kuleshov-group/mdlm), [DRAKES (Wang et al., 2024)](https://github.com/ChenyuWang-Monica/DRAKES), [SEPO (Zekri et al., 2025)](https://github.com/ozekri/SEPO/tree/main), and [MDNS (Zhu et al., 2025)](https://arxiv.org/abs/2508.10684).
6
+
7
+ ## Environment Installation
8
+ ```
9
+ conda create -n tr2d2-dna python=3.9.18
10
+
11
+ conda activate tr2d2-dna
12
+
13
+ bash env.sh
14
+ ```
15
+
16
+ ## Model Pretrained Weights Download
17
+
18
+ All data and model weights can be downloaded from the link below, which is provided by the [DRAKES](https://arxiv.org/abs/2410.13643) authors. Save the downloaded files in `$BASE_PATH`.
19
+
20
+ https://www.dropbox.com/scl/fi/zi6egfppp0o78gr0tmbb1/DRAKES_data.zip?rlkey=yf7w0pm64tlypwsewqc01wmfq&st=xe8dzn8k&dl=0
21
+
22
+ To download from a terminal, run:
23
+
24
+ ```
25
+ curl -L -o dna.zip "https://www.dropbox.com/scl/fi/zi6egfppp0o78gr0tmbb1/DRAKES_data.zip?rlkey=yf7w0pm64tlypwsewqc01wmfq&st=xe8dzn8k&dl=0"
26
+
27
+ unzip dna.zip
28
+ ```
29
+
30
+ ## Finetune with TR2-D2
31
+ After downloading the pretrained checkpoints, set `base_path` in `dataloader_gosai.py`, `oracle.py`, and `finetune.sh`, and set `HOME_LOC` and `SAVE_PATH` in `finetune.sh`.
32
+
33
+ Reproduce the DNA experiments with $\alpha = 0.1$ using
34
+ ```
35
+ sbatch train.sh
36
+ ```
37
+
38
+ ## Evaluate saved checkpoints
39
+ The checkpoints will be saved to `SAVE_PATH`.
40
+ Set `RUNS_DIR` in `run_batch_eval.sh` to the same value as `SAVE_PATH`. The checkpoints can then be evaluated with
41
+ ```
42
+ sbatch run_batch_eval.sh
43
+ ```
44
+
45
+
46
+
47
+
48
+
49
+
tr2d2-dna/configs_gosai/callbacks/checkpoint_every_n_steps.yaml ADDED
@@ -0,0 +1,8 @@
1
+ checkpoint_every_n_steps:
2
+ _target_: lightning.pytorch.callbacks.ModelCheckpoint
3
+ save_top_k: -1 # -1 keeps every checkpoint (no "best"-model selection); this callback saves every n train steps
4
+ save_last: True # save model as ${save_dir}/checkpoints/last.ckpt
5
+ dirpath: ${checkpointing.save_dir}/checkpoints
6
+ verbose: True
7
+ auto_insert_metric_name: False
8
+ every_n_train_steps: 500
tr2d2-dna/configs_gosai/callbacks/checkpoint_monitor.yaml ADDED
@@ -0,0 +1,10 @@
1
+ checkpoint_monitor:
2
+ _target_: lightning.pytorch.callbacks.ModelCheckpoint
3
+ monitor: val/nll # name of the logged metric which determines when model is improving
4
+ mode: min # can be "max" or "min"
5
+ save_top_k: 1 # save k best models (determined by above metric)
6
+ save_last: False # True = additionally always save model from last epoch
7
+ dirpath: ${checkpointing.save_dir}/checkpoints
8
+ filename: best
9
+ auto_insert_metric_name: False
10
+ verbose: True
tr2d2-dna/configs_gosai/callbacks/learning_rate_monitor.yaml ADDED
@@ -0,0 +1,3 @@
1
+ learning_rate_monitor:
2
+ _target_: lightning.pytorch.callbacks.LearningRateMonitor
3
+ logging_interval: step
tr2d2-dna/configs_gosai/config_gosai.yaml ADDED
@@ -0,0 +1,109 @@
1
+ defaults:
2
+ - _self_
3
+ - /callbacks: [checkpoint_every_n_steps, checkpoint_monitor, learning_rate_monitor]
4
+ - /model: dnaconv
5
+ - /strategy: ddp
6
+ - /noise: loglinear
7
+ - /lr_scheduler: constant_warmup
8
+
9
+ mode: train
10
+ diffusion: absorbing_state
11
+ backbone: cnn
12
+ parameterization: subs
13
+ time_conditioning: False
14
+ T: 0 # 0 (continuous time) / 1000
15
+ subs_masking: False
16
+ debug_mode: False
17
+
18
+ seed: 1
19
+
20
+ data:
21
+ streaming: False
22
+
23
+ loader:
24
+ global_batch_size: 512
25
+ eval_global_batch_size: ${.global_batch_size}
26
+ # Note: batch_size and eval_batch_size are **per machine**
27
+ batch_size: ${div_up:${.global_batch_size}, ${eval:${trainer.devices} * ${trainer.num_nodes}}}
28
+ eval_batch_size: ${div_up:${.eval_global_batch_size}, ${eval:${trainer.devices} * ${trainer.num_nodes}}}
29
+ num_workers: ${eval:"len(__import__('os').sched_getaffinity(0))"}
30
+ pin_memory: True
31
+
32
+ sampling:
33
+ predictor: ddpm
34
+ steps: 128
35
+ noise_removal: True
36
+ num_sample_batches: 2 # Total samples: `num_gpus` * `loader.eval_batch_size` * num_sample_batches
37
+ num_sample_log: 2
38
+ semi_ar: False
39
+ stride_length: 1
40
+ num_strides: 1
41
+
42
+ training:
43
+ ema: 0.9999
44
+ antithetic_sampling: True
45
+ importance_sampling: False
46
+ sampling_eps: 1e-3
47
+ change_of_variables: False
48
+
49
+ eval:
50
+ checkpoint_path: '' # Used to evaluate a checkpoint after training.
51
+ disable_ema: False
52
+ compute_generative_perplexity: True # False
53
+ perplexity_batch_size: 8
54
+ compute_perplexity_on_sanity: False
55
+ gen_ppl_eval_model_name_or_path: gpt2-large # gpt2-large, meta-llama/Llama-2-7b-hf
56
+ generate_samples: True
57
+ subset_size: 5000
58
+
59
+ optim:
60
+ weight_decay: 0
61
+ lr: 3e-4
62
+ beta1: 0.9
63
+ beta2: 0.999
64
+ eps: 1e-8
65
+
66
+ trainer:
67
+ _target_: lightning.Trainer
68
+ accelerator: cuda
69
+ num_nodes: 1
70
+ devices: ${device_count:}
71
+ accumulate_grad_batches: ${div_up:${loader.global_batch_size}, ${eval:${trainer.devices} * ${loader.batch_size} * ${trainer.num_nodes}}}
72
+ gradient_clip_val: 1.0
73
+ precision: 'bf16'
74
+ num_sanity_val_steps: 2
75
+ max_steps: 131500 # 100 epochs
76
+ log_every_n_steps: 10
77
+ limit_train_batches: 1.0 # train on full dataset, can be used to toggle quick run
78
+ limit_val_batches: 1.0 # validate on full dataset, can be used to toggle quick run
79
+ val_check_interval: 1000
80
+
81
+ wandb:
82
+ project: gosai-dna
83
+ notes: null
84
+ group: null
85
+ job_type: null
86
+ name: null
87
+ id: ${uuid:}
88
+ tags:
89
+ - ${noise.type}
90
+
91
+ hydra:
92
+ run:
93
+ dir: ${now:%Y.%m.%d}/${now:%H%M%S}
94
+ job:
95
+ chdir: true
96
+
97
+ checkpointing:
98
+ # Use custom `save_dir` if, e.g., saving to S3 bucket, otherwise leave this parameter as is
99
+ save_dir: ${cwd:}
100
+ # Note: `checkpoints` path should correspond to `checkpoint_every_n_steps.dirpath`
101
+ resume_from_ckpt: true
102
+ resume_ckpt_path: ${.save_dir}/checkpoints/last.ckpt
103
+
104
+ finetuning:
105
+ gumbel_softmax_temp: 1.0
106
+ truncate_steps: 3
107
+
108
+ mcts:
109
+ sampling: 0 # 0: gumbel noise, >0 top-k sampling
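The `loader` and `trainer` sections of `config_gosai.yaml` derive the per-device batch size and the gradient-accumulation factor from `global_batch_size`. The following is a minimal sketch of that arithmetic. It assumes that the `div_up` resolver performs ceiling division (the resolver is registered elsewhere and is not part of this file), and the 4-GPU, single-node setup is a hypothetical example. The helper name `derive_batch_settings` is made up for illustration.

```python
def div_up(x: int, y: int) -> int:
    """Ceiling division; assumed behaviour of the `div_up` resolver."""
    return (x + y - 1) // y


def derive_batch_settings(global_batch_size: int, devices: int, num_nodes: int):
    """Mirror the interpolations for loader.batch_size and
    trainer.accumulate_grad_batches in config_gosai.yaml."""
    world_size = devices * num_nodes
    batch_size = div_up(global_batch_size, world_size)
    accumulate = div_up(global_batch_size, world_size * batch_size)
    return batch_size, accumulate


# Hypothetical hardware: 4 GPUs on 1 node; global_batch_size = 512 as in the config.
batch_size, accumulate = derive_batch_settings(512, devices=4, num_nodes=1)
print(batch_size, accumulate)  # 128 1
```

Under this assumption, `batch_size` is already the ceiling of `global_batch_size / world_size`, so `accumulate_grad_batches` evaluates to 1 unless `loader.batch_size` is overridden with a smaller value.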
tr2d2-dna/configs_gosai/lr_scheduler/constant_warmup.yaml ADDED
@@ -0,0 +1,2 @@
1
+ _target_: transformers.get_constant_schedule_with_warmup
2
+ num_warmup_steps: 2500
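`transformers.get_constant_schedule_with_warmup` scales the base learning rate linearly from 0 to 1 over `num_warmup_steps` and then holds it constant. The following pure-Python sketch reproduces that multiplier without importing the library. The function name `constant_warmup_factor` is made up for illustration, and `base_lr` is assumed to be `optim.lr` from `config_gosai.yaml`.

```python
def constant_warmup_factor(step: int, num_warmup_steps: int = 2500) -> float:
    """Learning-rate multiplier: linear ramp from 0 to 1, then constant at 1."""
    if step < num_warmup_steps:
        return step / max(1, num_warmup_steps)
    return 1.0


base_lr = 3e-4  # assumed to be optim.lr in config_gosai.yaml

# Effective learning rate at a few training steps.
for step in (0, 1250, 2500, 10000):
    print(step, base_lr * constant_warmup_factor(step))
```

With the values in this config, the learning rate reaches `optim.lr` after 2500 optimizer steps and stays there for the rest of training.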
tr2d2-dna/configs_gosai/lr_scheduler/cosine_decay_warmup.yaml ADDED
@@ -0,0 +1,7 @@
1
+ _target_: utils.CosineDecayWarmupLRScheduler
2
+ t_in_epochs: False
3
+ t_initial: ${eval:${trainer.max_steps}-${.warmup_t}}
4
+ warmup_prefix: True
5
+ warmup_lr_init: 1e-6
6
+ warmup_t: ${eval:0.1*${trainer.max_steps}}
7
+ lr_min: 1e-6
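The two interpolations in `cosine_decay_warmup.yaml` split training into a warmup phase of `0.1 * max_steps` steps and a decay phase covering the remaining steps. The sketch below shows a generic linear-warmup plus cosine-decay curve with those phase lengths. It is not the code of `utils.CosineDecayWarmupLRScheduler`, which is not shown here; the function name `warmup_cosine_lr` is hypothetical, and `base_lr` is assumed to be `optim.lr`.

```python
import math


def warmup_cosine_lr(step, max_steps=131500, base_lr=3e-4,
                     warmup_lr_init=1e-6, lr_min=1e-6):
    """Generic warmup + cosine decay using the phase lengths from the YAML."""
    warmup_t = 0.1 * max_steps        # warmup_t: ${eval:0.1*${trainer.max_steps}}
    t_initial = max_steps - warmup_t  # t_initial: ${eval:max_steps - warmup_t}
    if step < warmup_t:
        # Linear ramp from warmup_lr_init up to base_lr.
        return warmup_lr_init + (base_lr - warmup_lr_init) * step / warmup_t
    # Cosine decay from base_lr down to lr_min over t_initial steps.
    progress = min(1.0, (step - warmup_t) / t_initial)
    return lr_min + 0.5 * (base_lr - lr_min) * (1.0 + math.cos(math.pi * progress))


for step in (0, 13150, 72325, 131500):
    print(step, warmup_cosine_lr(step))
```

With `max_steps = 131500`, warmup covers roughly the first 13150 steps and the cosine decay covers the remaining 118350 or so.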
tr2d2-dna/configs_gosai/model/dnaconv.yaml ADDED
@@ -0,0 +1,12 @@
1
+ name: dnaconv
2
+ type: cnn
3
+ length: 200 # for gosai
4
+ hidden_dim: 128
5
+ num_cnn_stacks: 4
6
+ dropout: 0.0
7
+ clean_data: False
8
+
9
+ cls_free_guidance: False
10
+ cls_free_threshold: 2.52
11
+ cls_free_prob: 0.3
12
+ cls_free_weight: 0.3 # weight in sampling
tr2d2-dna/configs_gosai/noise/ar.yaml ADDED
@@ -0,0 +1,2 @@
1
+ type: ar
2
+ scale: 6.0
tr2d2-dna/configs_gosai/noise/cosine.yaml ADDED
@@ -0,0 +1 @@
1
+ type: cosine
tr2d2-dna/configs_gosai/noise/geometric.yaml ADDED
@@ -0,0 +1,3 @@
1
+ type: geometric
2
+ sigma_min: 1e-4
3
+ sigma_max: 20
tr2d2-dna/configs_gosai/noise/linear.yaml ADDED
@@ -0,0 +1,3 @@
1
+ type: linear
2
+ sigma_min: 1e-3
3
+ sigma_max: 7.0
tr2d2-dna/configs_gosai/noise/loglinear.yaml ADDED
@@ -0,0 +1,3 @@
1
+ type: loglinear
2
+ sigma_min: 1e-4
3
+ sigma_max: 20
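`loglinear` is the noise schedule selected by default in `config_gosai.yaml`. The file `noise_schedule.py` is not shown here, so the sketch below is an assumption: it uses the log-linear form from MDLM, where the total noise is sigma(t) = -log(1 - (1 - eps) * t), so that the masking probability 1 - exp(-sigma(t)) grows linearly in t. The value `eps = 1e-3` and both function names are illustrative, and how `sigma_min` and `sigma_max` enter the actual implementation is not covered.

```python
import math


def loglinear_total_noise(t: float, eps: float = 1e-3) -> float:
    """Assumed log-linear total noise: sigma(t) = -log(1 - (1 - eps) * t)."""
    return -math.log1p(-(1.0 - eps) * t)


def mask_probability(t: float, eps: float = 1e-3) -> float:
    """Probability that a token is masked at time t: 1 - exp(-sigma(t))."""
    return 1.0 - math.exp(-loglinear_total_noise(t, eps))


# The masking probability is (1 - eps) * t, i.e. linear in t.
for t in (0.0, 0.25, 0.5, 1.0):
    print(t, round(mask_probability(t), 6))
```

Under this assumption, sampling t uniformly in [0, 1] masks a uniformly distributed fraction of the sequence, which is the reason this schedule pairs naturally with absorbing-state diffusion.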
tr2d2-dna/configs_gosai/noise/polynomial.yaml ADDED
@@ -0,0 +1,5 @@
1
+ type: polynomial
2
+ a: -3
3
+ b: 5
4
+ c: -4
5
+ eps: 1e-3
tr2d2-dna/configs_gosai/strategy/ddp.yaml ADDED
@@ -0,0 +1,2 @@
1
+ _target_: lightning.pytorch.strategies.DDPStrategy
2
+ find_unused_parameters: false # TODO(yair): this seems hacky, I think if things are correct we shouldn't need this
tr2d2-dna/configs_gosai/strategy/fsdp.yaml ADDED
@@ -0,0 +1,3 @@
1
+ # TODO(yair): Currently not compatible with grad clipping
2
+ _target_: lightning.pytorch.strategies.FSDPStrategy
3
+ sharding_strategy: SHARD_GRAD_OP
tr2d2-dna/dataloader_gosai.py ADDED
@@ -0,0 +1,211 @@
1
+ import torch
2
+ import pandas as pd
3
+ import typing
4
+ import math
5
+ import utils
6
+ import numpy as np
7
+ import os
8
+
9
+ base_path = "" # Fill in directory of the pretrained checkpoints, e.g., "...../data_and_model/"
10
+ LOGGER = utils.get_logger(__name__)
11
+ DNA_ALPHABET = {'A': 0, 'C': 1, 'G': 2, 'T': 3} #, 'M': 4}
12
+ INDEX_TO_DNA = {v: k for k, v in DNA_ALPHABET.items()}
13
+ lookup_array = np.array([INDEX_TO_DNA[i] for i in range(len(INDEX_TO_DNA))])
14
+
15
+
16
+ def dna_detokenize(seq):
17
+ return ''.join([list(DNA_ALPHABET.keys())[int(i)] for i in seq])
18
+
19
+ def batch_dna_detokenize(batch_seq):
20
+ """
21
+ batch_seq: numpy array of shape [batch_size, seq_len]
22
+ return: list of strings
23
+ """
24
+ detokenized_batch = lookup_array[batch_seq]
25
+ detokenized_batch = [''.join(seq) for seq in detokenized_batch]
26
+ return detokenized_batch
27
+
28
+ def dna_tokenize(seq):
29
+ return [DNA_ALPHABET[c] for c in seq]
30
+
31
+ def batch_dna_tokenize(batch_seq):
32
+ """
33
+ batch_seq: list of strings
34
+ return: numpy array of shape [batch_size, seq_len]
35
+ """
36
+ tokenized_batch = np.array([[DNA_ALPHABET[c] for c in seq] for seq in batch_seq])
37
+ return tokenized_batch
38
+
39
+ class GosaiDataset(torch.utils.data.Dataset):
40
+ def __init__(self):
41
+ data_df = pd.read_csv(os.path.join(base_path, f'mdlm/gosai_data/processed_data/gosai_all.csv'))
42
+ self.seqs = torch.tensor(data_df['seq'].apply(lambda x: [DNA_ALPHABET[c] for c in x]).tolist())
43
+ self.clss = torch.tensor(data_df[['hepg2', 'k562', 'sknsh']].to_numpy())
44
+ LOGGER.info(f'Loaded data: seqs shape: {self.seqs.shape}, clss shape: {self.clss.shape}')
45
+
46
+ def __len__(self):
47
+ return len(self.seqs)
48
+
49
+ def __getitem__(self, idx):
50
+ return {'seqs': self.seqs[idx], 'clss': self.clss[idx], 'attention_mask': torch.ones(len(self.seqs[idx]))}
51
+
52
+
53
+ def get_datasets_gosai():
54
+ return GosaiDataset()
55
+
56
+
57
+ def get_dataloaders_gosai(config, skip_valid=False, valid_seed=None):
58
+ num_gpus = torch.cuda.device_count()
59
+ if config.loader.global_batch_size % (
60
+ num_gpus * config.trainer.accumulate_grad_batches) != 0:
61
+ raise ValueError(
62
+ f'Train Batch Size {config.loader.global_batch_size} '
63
+ f'not divisible by {num_gpus} gpus with accumulation '
64
+ f'{config.trainer.accumulate_grad_batches}.')
65
+ if config.loader.eval_global_batch_size % num_gpus != 0:
66
+ raise ValueError(
67
+ f'Eval Batch Size {config.loader.eval_global_batch_size} '
68
+ f'not divisible by {num_gpus}.')
69
+ train_set = GosaiDataset()
70
+ # valid_set and test_set are independent random 40,000-example subsets of train_set (they overlap with the training data)
71
+ valid_set = torch.utils.data.Subset(train_set, np.random.choice(len(train_set), 40000, replace=False))
72
+ test_set = torch.utils.data.Subset(train_set, np.random.choice(len(train_set), 40000, replace=False))
73
+
74
+ train_loader = torch.utils.data.DataLoader(
75
+ train_set,
76
+ batch_size=config.loader.batch_size,
77
+ num_workers=config.loader.num_workers,
78
+ pin_memory=config.loader.pin_memory,
79
+ shuffle=not config.data.streaming,
80
+ persistent_workers=True)
81
+ if skip_valid:
82
+ valid_loader = None
83
+ test_loader = None
84
+ else:
85
+ if valid_seed is None:
86
+ shuffle_valid = False
87
+ generator = None
88
+ else:
89
+ shuffle_valid = True
90
+ generator = torch.Generator().manual_seed(valid_seed)
91
+ valid_loader = torch.utils.data.DataLoader(
92
+ valid_set,
93
+ batch_size=config.loader.eval_batch_size,
94
+ num_workers=config.loader.num_workers,
95
+ pin_memory=config.loader.pin_memory,
96
+ shuffle=shuffle_valid,
97
+ generator=generator)
98
+ test_loader = torch.utils.data.DataLoader(
99
+ test_set,
100
+ batch_size=config.loader.eval_batch_size,
101
+ num_workers=config.loader.num_workers,
102
+ pin_memory=config.loader.pin_memory,
103
+ shuffle=shuffle_valid,
104
+ generator=generator)
105
+
106
+ return train_loader, valid_loader, test_loader
107
+
108
+
109
+ # Samplers adapted from: https://github.com/Dao-AILab/flash-attention/blob/main/training/src/datamodules/fault_tolerant_sampler.py
110
+ class RandomFaultTolerantSampler(torch.utils.data.RandomSampler):
111
+
112
+ def __init__(self, *args, generator=None, **kwargs):
113
+ # TD [2022-07-17]: We don't force the seed to be zero. We generate random seed,
114
+ # which should be reproducible if pl.seed_everything was called beforehand.
115
+ # This means that changing the seed of the experiment will also change the
116
+ # sampling order.
117
+ if generator is None:
118
+ seed = int(torch.empty((), dtype=torch.int64).random_().item())
119
+ generator = torch.Generator().manual_seed(seed)
120
+ kwargs.pop('shuffle', None)
121
+ super().__init__(*args, generator=generator, **kwargs)
122
+ self.counter = 0
123
+ self.restarting = False
124
+
125
+ def state_dict(self):
126
+ return {'random_state': self.generator.get_state(),
127
+ 'counter': self.counter}
128
+
129
+ def load_state_dict(self, state_dict):
130
+ self.generator.set_state(state_dict.get('random_state'))
131
+ self.counter = state_dict['counter']
132
+ # self.start_counter = self.counter
133
+ self.restarting = True
134
+
135
+ # TD [2022-08-28] Setting the len will cause PL to think there are only a few batches left per
136
+ # epoch, and subsequent epoch will have very few batches.
137
+
138
+ def __iter__(self) -> typing.Iterator[int]:
139
+ n = len(self.data_source)
140
+
141
+ self.state = self.generator.get_state()
142
+ indices = torch.randperm(n, generator=self.generator).tolist()
143
+
144
+ if not self.restarting:
145
+ self.counter = 0
146
+ else:
147
+ indices = indices[self.counter:]
148
+ self.restarting = False
149
+
150
+ for index in indices:
151
+ self.counter += 1
152
+ yield index
153
+
154
+ self.counter = 0
155
+
156
+
157
+ class FaultTolerantDistributedSampler(torch.utils.data.DistributedSampler):
158
+
159
+ def __init__(self, *args, **kwargs):
160
+ super().__init__(*args, **kwargs)
161
+ self.counter = 0
162
+ self.restarting = False
163
+
164
+ def state_dict(self):
165
+ return {'epoch': self.epoch, 'counter': self.counter}
166
+
167
+ def load_state_dict(self, state_dict):
168
+ self.epoch = state_dict['epoch']
169
+ self.counter = state_dict['counter']
170
+ self.restarting = True
171
+
172
+ # TD [2022-08-28] Setting the len will cause PL to think there are only a few batches left per
173
+ # epoch, and subsequent epoch will have very few batches.
174
+ def __iter__(self):
175
+ if self.shuffle:
176
+ # deterministically shuffle based on epoch and seed
177
+ g = torch.Generator()
178
+ g.manual_seed(self.seed + self.epoch)
179
+ indices = torch.randperm(len(self.dataset), generator=g).tolist() # type: ignore[arg-type]
180
+ else:
181
+ indices = list(range(len(self.dataset))) # type: ignore[arg-type]
182
+
183
+ if not self.drop_last:
184
+ # add extra samples to make it evenly divisible
185
+ padding_size = self.total_size - len(indices)
186
+ if padding_size <= len(indices):
187
+ indices += indices[:padding_size]
188
+ else:
189
+ indices += (indices * math.ceil(
190
+ padding_size / len(indices)))[:padding_size]
191
+ else:
192
+ # remove tail of data to make it evenly divisible.
193
+ indices = indices[:self.total_size]
194
+ assert len(indices) == self.total_size
195
+
196
+ # subsample
197
+ indices = indices[self.rank:self.total_size:self.num_replicas]
198
+ assert len(indices) == self.num_samples
199
+
200
+ if not self.restarting:
201
+ self.counter = 0
202
+ else:
203
+ indices = indices[self.counter:]
204
+ self.restarting = False
205
+
206
+ for index in indices:
207
+ self.counter += 1
208
+ yield index
209
+
210
+ self.counter = 0
211
+
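The two samplers in `dataloader_gosai.py` make a shuffled epoch resumable by tracking a `counter` of indices already served and, on restart, regenerating the permutation and skipping the first `counter` entries. The toy sketch below shows the same idea with the standard-library `random` module instead of `torch.Generator`. The class `ResumableShuffler` is hypothetical, and unlike the code above it stores the RNG state captured at the start of the epoch, which is a design choice of this sketch so that the identical permutation can be rebuilt.

```python
import random


class ResumableShuffler:
    """Toy analogue of a fault-tolerant random sampler (hypothetical class)."""

    def __init__(self, n, seed=0):
        self.n = n
        self.rng = random.Random(seed)
        self.counter = 0
        self.restarting = False
        self._epoch_state = self.rng.getstate()

    def state_dict(self):
        # RNG state from before the current permutation was drawn,
        # plus the number of indices already consumed in this epoch.
        return {'random_state': self._epoch_state, 'counter': self.counter}

    def load_state_dict(self, state):
        self.rng.setstate(state['random_state'])
        self.counter = state['counter']
        self.restarting = True

    def __iter__(self):
        self._epoch_state = self.rng.getstate()
        indices = list(range(self.n))
        self.rng.shuffle(indices)
        if self.restarting:
            indices = indices[self.counter:]  # skip what was already served
            self.restarting = False
        else:
            self.counter = 0
        for index in indices:
            self.counter += 1
            yield index
        self.counter = 0


# Consume 4 indices, "checkpoint", then finish the epoch without interruption.
sampler = ResumableShuffler(10, seed=1)
it = iter(sampler)
first = [next(it) for _ in range(4)]
ckpt = sampler.state_dict()
rest_uninterrupted = list(it)

# A fresh sampler restored from the checkpoint yields the same remaining indices.
resumed = ResumableShuffler(10, seed=999)
resumed.load_state_dict(ckpt)
rest_resumed = list(resumed)
print(rest_resumed == rest_uninterrupted)
```

In `on_train_start`, `diffusion.py` applies the same pattern by passing `fast_forward_batches * loader.batch_size` as the `counter` when it rebuilds the distributed sampler.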
tr2d2-dna/diffusion.py ADDED
@@ -0,0 +1,1604 @@
1
+ import itertools
2
+ import math
3
+ from dataclasses import dataclass
4
+
5
+ import hydra.utils
6
+ import lightning as L
7
+ import numpy as np
8
+ import torch
9
+ import torch.nn.functional as F
10
+ import torchmetrics
11
+ from torch import Tensor
12
+
13
+ import dataloader_gosai
14
+ import models
15
+ import noise_schedule
16
+ import utils
17
+ import oracle
18
+ from scipy.stats import wasserstein_distance, pearsonr
19
+ from finetune_utils import to_one_hot
20
+
21
+ LOG2 = math.log(2)
22
+ LOGGER = utils.get_logger(__name__)
23
+
24
+
25
+ def _sample_categorical(categorical_probs):
26
+ gumbel_norm = (
27
+ 1e-10
28
+ - (torch.rand_like(categorical_probs) + 1e-10).log())
29
+ return (categorical_probs / gumbel_norm).argmax(dim=-1).to(dtype=torch.long)
30
+
31
+ def _sample_categorical_gradient(categorical_probs, temp = 1.0):
32
+ gumbel_norm = (
33
+ 1e-10 - (torch.rand_like(categorical_probs) + 1e-10).log())
34
+ output = torch.nn.functional.softmax((torch.log(categorical_probs)-torch.log(gumbel_norm))/temp, 2)
35
+ return output
36
+
37
+ def _unsqueeze(x, reference):
38
+ return x.view(
39
+ * x.shape,
40
+ * ((1,) * (len(reference.shape) - len(x.shape))))
41
+
42
+ def sample_batched_categorical(categorical_probs, batch_size):
43
+ """
44
+ Generates `batch_size` independent sequences sampled from categorical probabilities
45
+ using the Gumbel-max trick (argmax over log-probabilities plus Gumbel noise).
46
+
47
+ Args:
48
+ categorical_probs (torch.Tensor): tensor of shape (1 or batch_size, sequence_length, vocab_size)
49
+ representing categorical probabilities
50
+ batch_size (int): number of sequences to sample
51
+
52
+ Returns:
53
+ torch.Tensor: tensor of shape (batch_size, sequence_length), where each row is a
54
+ sequence of sampled category indices (rows are not guaranteed to be distinct).
55
+ """
56
+ _, sequence_length, vocab_size = categorical_probs.shape
57
+
58
+ # add Gumbel noise and sample batch_size sequences
59
+ gumbel_noise = (-torch.log(-torch.log(torch.rand(batch_size, sequence_length, vocab_size) + 1e-10) + 1e-10)).to(categorical_probs.device)
60
+ noisy_scores = torch.log(categorical_probs) + gumbel_noise # add Gumbel noise to log probabilities
61
+
62
+ # select the highest score (most likely category after Gumbel noise)
63
+ sampled_sequences = noisy_scores.argmax(dim=-1).to(dtype=torch.long) # shape: (batch_size, sequence_length)
64
+
65
+ return sampled_sequences
66
+
67
+ def sample_batched_top_k(categorical_probs, batch_size, k):
68
+ """
69
+ Generates `batch_size` sequences sampled from the top-k candidates of each token,
70
+ where candidates are ranked by log-probability plus Gumbel noise.
71
+
72
+ Args:
73
+ categorical_probs (torch.Tensor): A tensor of shape (1 or batch_size, sequence_length, vocab_length)
74
+ representing categorical probabilities.
75
+ batch_size (int): Number of sequences to sample.
76
+ k (int): Number of top candidates to consider for sampling.
77
+
78
+ Returns:
79
+ torch.Tensor: A tensor of shape (batch_size, sequence_length), where each row is a
80
+ sampled sequence of category indices.
81
+ """
82
+ _, sequence_length, vocab_length = categorical_probs.shape
83
+
84
+ # Add Gumbel noise to the log probabilities
85
+ gumbel_noise = -torch.log(-torch.log(torch.rand(batch_size, sequence_length, vocab_length) + 1e-10) + 1e-10).to(categorical_probs.device)
86
+ noisy_scores = torch.log(categorical_probs) + gumbel_noise # Shape: (batch_size, sequence_length, vocab_length); input is already 3-D, so no extra leading axis is added
87
+
88
+ # Get the top-k categories based on noisy scores
89
+ top_k_scores, top_k_indices = torch.topk(noisy_scores, k, dim=-1) # Shape: (batch_size, sequence_length, k)
90
+
91
+ # Convert top-k scores back to probabilities and normalize
92
+ top_k_probs = torch.softmax(top_k_scores, dim=-1).to(categorical_probs.device) # Shape: (batch_size, sequence_length, k)
93
+
94
+ # Sample randomly from the top-k probabilities
95
+ sampled_indices_in_top_k = torch.multinomial(top_k_probs.reshape(-1, k), num_samples=1).squeeze(-1).to(categorical_probs.device)
96
+ sampled_indices_in_top_k = sampled_indices_in_top_k.view(batch_size, sequence_length).to(categorical_probs.device) # Shape: (batch_size, sequence_length)
97
+
98
+ # Map sampled indices back to the original vocabulary indices
99
+ sampled_sequences = torch.gather(top_k_indices, -1, sampled_indices_in_top_k.unsqueeze(-1)).squeeze(-1).to(categorical_probs.device).to(dtype=torch.long)
100
+
101
+ return sampled_sequences
102
+
103
+ @dataclass
104
+ class Loss:
105
+ loss: torch.FloatTensor
106
+ nlls: torch.FloatTensor
107
+ token_mask: torch.FloatTensor
108
+
109
+
110
+ class NLL(torchmetrics.aggregation.MeanMetric):
111
+ pass
112
+
113
+
114
+ class BPD(NLL):
115
+ def compute(self) -> Tensor:
116
+ """Computes the bits per dimension.
117
+
118
+ Returns:
119
+ bpd
120
+ """
121
+ return self.mean_value / self.weight / LOG2
122
+
123
+
124
+ class Perplexity(NLL):
125
+ def compute(self) -> Tensor:
126
+ """Computes the Perplexity.
127
+
128
+ Returns:
129
+ Perplexity
130
+ """
131
+ return torch.exp(self.mean_value / self.weight)
132
+
133
+
134
+ class Diffusion(L.LightningModule):
135
+ def __init__(
136
+ self,
137
+ config,
138
+ eval=False):
139
+
140
+ super().__init__()
141
+ self.save_hyperparameters()
142
+ self.config = config
143
+ self.vocab_size = 4
144
+ self.sampler = self.config.sampling.predictor
145
+ self.antithetic_sampling = self.config.training.antithetic_sampling
146
+ self.importance_sampling = self.config.training.importance_sampling
147
+ self.change_of_variables = self.config.training.change_of_variables
148
+ # add mask token
149
+ self.mask_index = self.vocab_size
150
+ self.vocab_size += 1
151
+ self.parameterization = self.config.parameterization
152
+
153
+ # dna backbone model
154
+ if self.config.backbone == 'cnn':
155
+ self.backbone = models.dnaconv.CNNModel(
156
+ self.config.model, alphabet_size=self.vocab_size, num_cls=3) # num_cls is not used since classifier is always set to False
157
+ else:
158
+ raise ValueError(f'Unknown backbone: {self.config.backbone}')
159
+
160
+ self.T = self.config.T
161
+ self.subs_masking = self.config.subs_masking
162
+
163
+ self.softplus = torch.nn.Softplus()
164
+ # metrics are automatically reset at end of epoch
165
+ metrics = torchmetrics.MetricCollection({
166
+ 'nll': NLL(),
167
+ 'bpd': BPD(),
168
+ 'ppl': Perplexity(),
169
+ })
170
+ metrics.set_dtype(torch.float64)
171
+ self.train_metrics = metrics.clone(prefix='train/')
172
+ self.valid_metrics = metrics.clone(prefix='val/')
173
+ self.test_metrics = metrics.clone(prefix='test/')
174
+
175
+ # generative perplexity
176
+ self.gen_ppl_metric = Perplexity()
177
+ self.noise = noise_schedule.get_noise(self.config,
178
+ dtype=self.dtype)
179
+
180
+ # ema
181
+ if self.config.training.ema > 0:
182
+ self.ema = models.ema.ExponentialMovingAverage(
183
+ itertools.chain(self.backbone.parameters(),
184
+ self.noise.parameters()),
185
+ decay=self.config.training.ema)
186
+ else:
187
+ self.ema = None
188
+
189
+ self.lr = self.config.optim.lr
190
+ self.sampling_eps = self.config.training.sampling_eps
191
+ self.time_conditioning = self.config.time_conditioning
192
+ self.neg_infinity = -1000000.0
193
+ self.fast_forward_epochs = None
194
+ self.fast_forward_batches = None
195
+ self._validate_configuration()
196
+
197
+ # subset of data for evaluation
198
+ if eval:
199
+ self.eval_sets_sp = oracle.subset_for_eval(n=config.eval.subset_size)
200
+ self.eval_sets_sp_clss = oracle.subset_eval_groundtruth(self.eval_sets_sp)
201
+ self.eval_sets_sp_preds = oracle.subset_eval_preds(self.eval_sets_sp)
202
+ self.eval_sets_sp_kmers = oracle.subset_eval_kmers(self.eval_sets_sp)
203
+ self.emb_pca = oracle.cal_emb_pca(oracle.subset_for_eval(n=40000), n_components=50)
204
+ self.eval_sets_sp_embs_pca = oracle.subset_eval_embs_pca(self.eval_sets_sp, self.emb_pca)
205
+
206
+ def _validate_configuration(self):
207
+ assert not (self.change_of_variables and self.importance_sampling)
208
+ assert self.parameterization == 'subs'
209
+
210
+
211
+ def on_load_checkpoint(self, checkpoint):
212
+ if self.ema:
213
+ self.ema.load_state_dict(checkpoint['ema'])
214
+ # Copied from:
215
+ # https://github.com/Dao-AILab/flash-attention/blob/main/training/src/datamodules/language_modeling_hf.py#L41
216
+ self.fast_forward_epochs = checkpoint['loops']['fit_loop']['epoch_progress']['current']['completed']
217
+ self.fast_forward_batches = checkpoint['loops'][
218
+ 'fit_loop']['epoch_loop.batch_progress'][
219
+ 'current']['completed']
220
+
221
+
222
+ def on_save_checkpoint(self, checkpoint):
223
+ if self.ema:
224
+ checkpoint['ema'] = self.ema.state_dict()
225
+ # Copied from:
226
+ # https://github.com/Dao-AILab/flash-attention/blob/main/training/src/tasks/seq.py
227
+ # ['epoch_loop.batch_progress']['total']['completed'] is 1 iteration
228
+ # behind, so we're using the optimizer's progress.
229
+ checkpoint['loops']['fit_loop'][
230
+ 'epoch_loop.batch_progress']['total'][
231
+ 'completed'] = checkpoint['loops']['fit_loop'][
232
+ 'epoch_loop.automatic_optimization.optim_progress'][
233
+ 'optimizer']['step']['total'][
234
+ 'completed'] * self.trainer.accumulate_grad_batches
235
+ checkpoint['loops']['fit_loop'][
236
+ 'epoch_loop.batch_progress']['current'][
237
+ 'completed'] = checkpoint['loops']['fit_loop'][
238
+ 'epoch_loop.automatic_optimization.optim_progress'][
239
+ 'optimizer']['step']['current'][
240
+ 'completed'] * self.trainer.accumulate_grad_batches
241
+ # _batches_that_stepped tracks the number of global steps, not the number
242
+ # of local steps, so we don't multiply with self.trainer.accumulate_grad_batches here.
243
+ checkpoint['loops']['fit_loop'][
244
+ 'epoch_loop.state_dict'][
245
+ '_batches_that_stepped'] = checkpoint['loops']['fit_loop'][
246
+ 'epoch_loop.automatic_optimization.optim_progress'][
247
+ 'optimizer']['step']['total']['completed']
248
+ if 'sampler' not in checkpoint.keys():
249
+ checkpoint['sampler'] = {}
250
+ if hasattr(self.trainer.train_dataloader.sampler, 'state_dict'):
251
+ sampler_state_dict = self.trainer.train_dataloader.sampler.state_dict()
252
+ checkpoint['sampler']['random_state'] = sampler_state_dict.get('random_state', None)
253
+ else:
254
+ checkpoint['sampler']['random_state'] = None
255
+
256
+ def on_train_start(self):
257
+ if self.ema:
258
+ self.ema.move_shadow_params_to_device(self.device)
259
+
260
+ distributed = (
261
+ self.trainer._accelerator_connector.use_distributed_sampler
262
+ and self.trainer._accelerator_connector.is_distributed)
263
+
264
+ print('distributed:', distributed)
265
+
266
+ if distributed:
267
+ sampler_cls = dataloader_gosai.FaultTolerantDistributedSampler
268
+ else:
269
+ sampler_cls = dataloader_gosai.RandomFaultTolerantSampler
270
+
271
+ updated_dls = []
272
+ for dl in self.trainer.fit_loop._combined_loader.flattened:
273
+ if hasattr(dl.sampler, 'shuffle'):
274
+ dl_sampler = sampler_cls(dl.dataset, shuffle=dl.sampler.shuffle)
275
+ else:
276
+ dl_sampler = sampler_cls(dl.dataset)
277
+ if (distributed and self.fast_forward_epochs is not None
278
+ and self.fast_forward_batches is not None):
279
+
280
+ dl_sampler.load_state_dict({
281
+ 'epoch': self.fast_forward_epochs,
282
+ 'counter': (self.fast_forward_batches
283
+ * self.config.loader.batch_size)})
284
+ updated_dls.append(
285
+ torch.utils.data.DataLoader(
286
+ dl.dataset,
287
+ batch_size=self.config.loader.batch_size,
288
+ num_workers=self.config.loader.num_workers,
289
+ pin_memory=self.config.loader.pin_memory,
290
+ sampler=dl_sampler,
291
+ shuffle=False,
292
+ persistent_workers=True))
293
+
294
+ self.trainer.fit_loop._combined_loader.flattened = updated_dls
295
+
296
+ def optimizer_step(self, *args, **kwargs):
297
+ super().optimizer_step(*args, **kwargs)
298
+ if self.ema:
299
+ self.ema.update(itertools.chain(
300
+ self.backbone.parameters(),
301
+ self.noise.parameters()))
302
+
303
+ # subs parameterization from MDLM
304
+ def _subs_parameterization(self, logits, xt):
305
+ logits[:, :, self.mask_index] += self.neg_infinity
306
+ logits = logits - torch.logsumexp(logits, dim=-1, keepdim=True)
307
+ if xt.ndim > 2 and xt.shape[-1] == self.vocab_size:
308
+ # this is for finetuning setting when the input is one-hot encoded or probs
309
+ xt = xt.argmax(dim=-1)
310
+ unmasked_indices = (xt != self.mask_index)
311
+ logits[unmasked_indices] = self.neg_infinity
312
+ logits[unmasked_indices, xt[unmasked_indices]] = 0
313
+ return logits
314
+
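A minimal single-position sketch of what `_subs_parameterization` does, written in plain Python rather than PyTorch. The mask index `4`, the five-token vocabulary, and the `-1_000_000.0` stand-in for `self.neg_infinity` are assumptions made for illustration; they are not taken from this diff.

```python
import math

NEG_INF = -1_000_000.0  # assumed stand-in for self.neg_infinity
MASK = 4                # assumed mask index for a 5-token vocab (A, C, G, T, MASK)

def subs_log_probs(logits, token):
    """Toy single-position version of the SUBS parameterization."""
    if token != MASK:
        # Already unmasked: all probability mass stays on the observed token.
        return [0.0 if k == token else NEG_INF for k in range(len(logits))]
    # Still masked: forbid predicting MASK, then normalise with log-softmax.
    shifted = [l + (NEG_INF if k == MASK else 0.0) for k, l in enumerate(logits)]
    m = max(shifted)
    lse = m + math.log(sum(math.exp(s - m) for s in shifted))
    return [s - lse for s in shifted]

raw_logits = [0.2, 1.0, -0.5, 0.3, 2.0]
masked = subs_log_probs(raw_logits, MASK)  # distribution over A, C, G, T only
fixed = subs_log_probs(raw_logits, 1)      # position already holds token 1
```

For a masked position the four nucleotide probabilities sum to one and the mask entry is effectively zero; for an unmasked position the output is a one-hot log-distribution on the observed token.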
315
+ def _process_sigma(self, sigma):
316
+ if sigma is None:
317
+ assert self.parameterization == 'ar'
318
+ return sigma
319
+ if sigma.ndim > 1:
320
+ sigma = sigma.squeeze(-1)
321
+ if not self.time_conditioning:
322
+ sigma = torch.zeros_like(sigma)
323
+ assert sigma.ndim == 1, sigma.shape
324
+ return sigma
325
+
326
+ def forward(self, x, sigma):
327
+ """Returns log score."""
328
+ sigma = self._process_sigma(sigma)
329
+
330
+ x = x.to(dtype=torch.long)
331
+
332
+ with torch.cuda.amp.autocast(dtype=torch.float32):
333
+ logits = self.backbone(x, sigma)
334
+
335
+ if self.parameterization == 'subs':
336
+ return self._subs_parameterization(logits=logits, xt=x)
337
+
338
+ return logits
339
+
340
+ # might need changing to match wdce loss
341
+ def _compute_loss(self, batch, prefix):
342
+
343
+ if 'attention_mask' in batch:
344
+ attention_mask = batch['attention_mask']
345
+ else:
346
+ attention_mask = None
347
+ losses = self._loss(batch['seqs'], attention_mask)
348
+ loss = losses.loss
349
+
350
+
351
+ if prefix == 'train':
352
+ self.train_metrics.update(losses.nlls, losses.token_mask)
353
+ metrics = self.train_metrics
354
+ elif prefix == 'val':
355
+ self.valid_metrics.update(losses.nlls, losses.token_mask)
356
+ metrics = self.valid_metrics
357
+ elif prefix == 'test':
358
+ self.test_metrics.update(losses.nlls, losses.token_mask)
359
+ metrics = self.test_metrics
360
+ else:
361
+ raise ValueError(f'Invalid prefix: {prefix}')
362
+
363
+ self.log_dict(metrics, on_step=False, on_epoch=True, sync_dist=True)
364
+
365
+ return loss
366
+
367
+ def on_train_epoch_start(self):
368
+ self.backbone.train()
369
+ self.noise.train()
370
+
371
+ def training_step(self, batch, batch_idx):
372
+ loss = self._compute_loss(batch, prefix='train')
373
+ self.log(name='trainer/loss',
374
+ value=loss.item(),
375
+ on_step=True,
376
+ on_epoch=False,
377
+ sync_dist=True)
378
+ return loss
379
+
380
+ def on_validation_epoch_start(self):
381
+ if self.ema:
382
+ self.ema.store(itertools.chain(
383
+ self.backbone.parameters(),
384
+ self.noise.parameters()))
385
+ self.ema.copy_to(itertools.chain(
386
+ self.backbone.parameters(),
387
+ self.noise.parameters()))
388
+ self.backbone.eval()
389
+ self.noise.eval()
390
+ assert self.valid_metrics.nll.mean_value == 0
391
+ assert self.valid_metrics.nll.weight == 0
392
+
393
+ def validation_step(self, batch, batch_idx):
394
+ return self._compute_loss(batch, prefix='val')
395
+
396
+ def on_validation_epoch_end(self):
397
+ if ((self.config.eval.compute_perplexity_on_sanity
398
+ or not self.trainer.sanity_checking)
399
+ and self.config.eval.generate_samples
400
+ and not self.parameterization == 'ar'):
401
+ all_samples, all_detokenized_samples = [], []
402
+
403
+ for _ in range(self.config.sampling.num_sample_batches):
404
+
405
+ samples = self._sample().detach().cpu().numpy()
406
+ detokenized_samples = dataloader_gosai.batch_dna_detokenize(samples)
407
+ all_samples.append(samples)
408
+ all_detokenized_samples.extend(detokenized_samples)
409
+
410
+ all_samples = np.concatenate(all_samples, axis=0)
411
+ ws_distance_dict = self.cal_wasserstein_distance(all_detokenized_samples)
412
+ pearsonr_list = self.cal_kmer_pearsonr(all_detokenized_samples)
413
+ ws_embpca_list = self.cal_ws_distance_embpca(all_detokenized_samples)
414
+
415
+ current_step = self.trainer.global_step
416
+ LOGGER.info(f'Current step: {current_step}')
417
+ LOGGER.info(f'Wasserstein distance: {ws_distance_dict}')
418
+ LOGGER.info(f'3mer Pearsonr: {pearsonr_list}')
419
+ LOGGER.info(f'Wasserstein distance embpca: {ws_embpca_list}')
420
+ self.log('val/3mer_pearsonr', pearsonr_list, on_step=False, on_epoch=True, sync_dist=True)
421
+ self.log('val/ws_embpca', ws_embpca_list, on_step=False, on_epoch=True, sync_dist=True)
422
+
423
+ for key in ws_distance_dict:
424
+ for cell_type in ws_distance_dict[key]:
425
+ metric_values = ws_distance_dict[key][cell_type]
426
+ if metric_values: # Check if the list is not empty
427
+ # Each list holds the single value appended in cal_wasserstein_distance; log that value
428
+ self.log(f'val/{key}_{cell_type}', metric_values[0], on_step=False, on_epoch=True, sync_dist=True)
429
+
430
+ if self.ema:
431
+ self.ema.restore(itertools.chain(self.backbone.parameters(),
432
+ self.noise.parameters()))
433
+
434
+ ### VALIDATION METRICS ###
435
+ def cal_wasserstein_distance(self, seqs):
436
+ generated_preds = oracle.cal_gosai_pred_new(seqs)
437
+ ws_distance_dict = {'truth': {'hepg2': [], 'k562': [], 'sknsh': []},
438
+ 'preds': {'hepg2': [], 'k562': [], 'sknsh': []}}
439
+ ws_distance_dict['truth']['hepg2'].append(wasserstein_distance(generated_preds[:, 0], self.eval_sets_sp_clss[:, 0]))
440
+ ws_distance_dict['truth']['k562'].append(wasserstein_distance(generated_preds[:, 1], self.eval_sets_sp_clss[:, 1]))
441
+ ws_distance_dict['truth']['sknsh'].append(wasserstein_distance(generated_preds[:, 2], self.eval_sets_sp_clss[:, 2]))
442
+ ws_distance_dict['preds']['hepg2'].append(wasserstein_distance(generated_preds[:, 0], self.eval_sets_sp_preds[:, 0]))
443
+ ws_distance_dict['preds']['k562'].append(wasserstein_distance(generated_preds[:, 1], self.eval_sets_sp_preds[:, 1]))
444
+ ws_distance_dict['preds']['sknsh'].append(wasserstein_distance(generated_preds[:, 2], self.eval_sets_sp_preds[:, 2]))
445
+ return ws_distance_dict
446
+
447
+ def cal_ws_distance_embpca(self, seqs):
448
+ generated_embs = oracle.cal_gosai_emb(seqs)
449
+ generated_embs_pca = self.emb_pca.transform(generated_embs.reshape(generated_embs.shape[0], -1))
450
+ return oracle.get_wasserstein_dist(generated_embs_pca, self.eval_sets_sp_embs_pca)
451
+
452
+ def compare_kmer(self, kmer1, kmer2, n_sp1, n_sp2):
453
+ kmer_set = set(kmer1.keys()) | set(kmer2.keys())
454
+ counts = np.zeros((len(kmer_set), 2))
455
+ for i, kmer in enumerate(kmer_set):
456
+ if kmer in kmer1:
457
+ counts[i][1] = kmer1[kmer] * n_sp2 / n_sp1
458
+ if kmer in kmer2:
459
+ counts[i][0] = kmer2[kmer]
460
+ return pearsonr(counts[:, 0], counts[:, 1])[0]
461
+
462
+ def cal_kmer_pearsonr(self, seqs):
463
+ generated_kmer = oracle.count_kmers(seqs)
464
+ return self.compare_kmer(self.eval_sets_sp_kmers, generated_kmer, self.config.eval.subset_size, len(seqs))
465
+
466
+ def configure_optimizers(self):
467
+ optimizer = torch.optim.AdamW(
468
+ itertools.chain(self.backbone.parameters(),
469
+ self.noise.parameters()),
470
+ lr=self.config.optim.lr,
471
+ betas=(self.config.optim.beta1, self.config.optim.beta2),
472
+ eps=self.config.optim.eps,
473
+ weight_decay=self.config.optim.weight_decay)
474
+
475
+ scheduler = hydra.utils.instantiate(self.config.lr_scheduler, optimizer=optimizer)
476
+ scheduler_dict = {
477
+ 'scheduler': scheduler,
478
+ 'interval': 'step',
479
+ 'monitor': 'val/loss',
480
+ 'name': 'trainer/lr',
481
+ }
482
+ return [optimizer], [scheduler_dict]
483
+
484
+ def q_xt(self, x, move_chance):
485
+ """Computes the noisy sample xt.
486
+
487
+ Args:
488
+ x: int torch.Tensor with shape (batch_size,
489
+ diffusion_model_input_length), input.
490
+ move_chance: float torch.Tensor with shape (batch_size, 1).
491
+ """
492
+ move_indices = torch.rand(* x.shape, device=x.device) < move_chance
493
+
494
+ xt = torch.where(move_indices, self.mask_index, x)
495
+ return xt
496
+
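The forward corruption in `q_xt` can be illustrated without tensors. The sketch below is a plain-Python analogue, assuming a mask index of `4`; the helper name `mask_tokens` is hypothetical and does not appear in the repository.

```python
import random

MASK = 4  # assumed mask index for a 5-token vocabulary (A, C, G, T, MASK)

def mask_tokens(tokens, move_chance, rng):
    """Replace each token with MASK independently with probability move_chance."""
    return [MASK if rng.random() < move_chance else tok for tok in tokens]

rng = random.Random(0)
seq = [0, 1, 2, 3, 0, 1, 2, 3]
untouched = mask_tokens(seq, 0.0, rng)    # move_chance 0: nothing is masked
half_masked = mask_tokens(seq, 0.5, rng)  # each position masked with probability 0.5
all_masked = mask_tokens(seq, 1.0, rng)   # move_chance 1: every position is masked
```

As in `q_xt`, a position either keeps its original token or becomes the mask token; no other substitution occurs.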
497
+ def _sample_prior(self, *batch_dims):
498
+ """
499
+ Returns array of fully masked sequences with same shape as input
500
+ """
501
+ return self.mask_index * torch.ones(* batch_dims, dtype=torch.int64)
502
+
503
+ def _ddpm_caching_update(self, x, t, dt, p_x0=None):
504
+ assert self.config.noise.type == 'loglinear'
505
+ sigma_t, _ = self.noise(t)
506
+ if t.ndim > 1:
507
+ t = t.squeeze(-1)
508
+ assert t.ndim == 1
509
+ move_chance_t = t[:, None, None]
510
+ move_chance_s = (t - dt)[:, None, None]
511
+ assert move_chance_t.ndim == 3, move_chance_t.shape
512
+ if p_x0 is None:
513
+ p_x0 = self.forward(x, sigma_t).exp()
514
+
515
+ assert move_chance_t.ndim == p_x0.ndim
516
+ q_xs = p_x0 * (move_chance_t - move_chance_s)
517
+ q_xs[:, :, self.mask_index] = move_chance_s[:, :, 0]
518
+ _x = _sample_categorical(q_xs)
519
+
520
+ copy_flag = (x != self.mask_index).to(x.dtype)
521
+ return p_x0, copy_flag * x + (1 - copy_flag) * _x
522
+
523
+ def _ddpm_update(self, x, t, dt, return_process=False):
524
+ sigma_t, _ = self.noise(t)
525
+ sigma_s, _ = self.noise(t - dt)
526
+ if sigma_t.ndim > 1:
527
+ sigma_t = sigma_t.squeeze(-1)
528
+ if sigma_s.ndim > 1:
529
+ sigma_s = sigma_s.squeeze(-1)
530
+ assert sigma_t.ndim == 1, sigma_t.shape
531
+ assert sigma_s.ndim == 1, sigma_s.shape
532
+ move_chance_t = 1 - torch.exp(-sigma_t) # t
533
+ move_chance_s = 1 - torch.exp(-sigma_s)
534
+ move_chance_t = move_chance_t[:, None, None]
535
+ move_chance_s = move_chance_s[:, None, None]
536
+ unet_conditioning = sigma_t
537
+ log_p_x0 = self.forward(x, unet_conditioning)
538
+ assert move_chance_t.ndim == log_p_x0.ndim
539
+ q_xs = log_p_x0.exp() * (move_chance_t - move_chance_s)
540
+ q_xs[:, :, self.mask_index] = move_chance_s[:, :, 0]
541
+ _x = _sample_categorical(q_xs)
542
+ copy_flag = (x != self.mask_index).to(x.dtype)
543
+ if return_process:
544
+ return copy_flag * x + (1 - copy_flag) * _x, x, unet_conditioning, move_chance_t, copy_flag
545
+ else:
546
+ return copy_flag * x + (1 - copy_flag) * _x
547
+
548
+ def _ar_sampler(self, bsz):
549
+ # precompute token buffer
550
+ num_pred_tokens = self.config.model.length - 1
551
+ x = torch.zeros(
552
+ (bsz, num_pred_tokens + 1),
553
+ dtype=torch.long,
554
+ device=self.device)
555
+ x[:, 0] = self.tokenizer.bos_token_id
556
+ # precompute noise
557
+ noise = (torch.distributions.Gumbel(0, 1)
558
+ .sample((bsz, num_pred_tokens, self.vocab_size))
559
+ .to(self.device))
560
+ for i in range(num_pred_tokens):
561
+ next_logits = self.forward(x[:, :i + 1], None)[:, -1]
562
+ y = (next_logits + noise[:, i]).argmax(-1)
563
+ x[:, i + 1] = y
564
+ return x
565
+
566
+ @torch.no_grad()
567
+ def _sample(self, num_steps=None, eps=1e-5, eval_sp_size=None):
568
+ """Generate samples from the model."""
569
+ if eval_sp_size is None:
570
+ batch_size_per_gpu = self.config.loader.eval_batch_size
571
+ else:
572
+ batch_size_per_gpu = eval_sp_size
573
+ if self.parameterization == 'ar':
574
+ return self._ar_sampler(batch_size_per_gpu)
575
+ # Lightning auto-casting is not working in this method for some reason
576
+ if num_steps is None:
577
+ num_steps = self.config.sampling.steps
578
+ x = self._sample_prior(
579
+ batch_size_per_gpu,
580
+ self.config.model.length).to(self.device)
581
+
582
+ timesteps = torch.linspace(1, eps, num_steps + 1, device=self.device)
583
+ dt = (1 - eps) / num_steps
584
+ p_x0_cache = None
585
+
586
+ for i in range(num_steps):
587
+ t = timesteps[i] * torch.ones(x.shape[0], 1, device=self.device)
588
+
589
+ if self.sampler == 'ddpm':
590
+ x = self._ddpm_update(x, t, dt)
591
+ elif self.sampler == 'ddpm_cache':
592
+ p_x0_cache, x_next = self._ddpm_caching_update(x, t, dt, p_x0=p_x0_cache)
593
+ if (not torch.allclose(x_next, x) or self.time_conditioning):
594
+ p_x0_cache = None
595
+ x = x_next
596
+ else:
597
+ x = self._analytic_update(x, t, dt)
598
+
599
+ if self.config.sampling.noise_removal:
600
+ t = timesteps[-1] * torch.ones(x.shape[0], 1,
601
+ device=self.device)
602
+ if self.sampler == 'analytic':
603
+ x = self._denoiser_update(x, t)
604
+ else:
605
+ unet_conditioning = self.noise(t)[0]
606
+ logits = self.forward(x, unet_conditioning)
607
+ x = logits[:, :, :-1].argmax(dim=-1)
608
+ return x
609
+
610
+ ### FOR THE EXPANSION AND ROLLOUT STEP ###
611
+ def sample_finetuned_with_rnd(self, args, reward_model, pretrained, eps=1e-5):
612
+ num_steps = args.total_num_steps
613
+ x_rollout = self._sample_prior(
614
+ args.batch_size,
615
+ args.seq_length).to(self.device)
616
+
617
+ log_rnd = torch.zeros(args.batch_size, device=self.device)
618
+
619
+ timesteps = torch.linspace(1, eps, num_steps + 1, device=self.device)
620
+ dt = (1 - eps) / num_steps
621
+
622
+ for i in range(num_steps):
623
+ t = timesteps[i] * torch.ones(x_rollout.shape[0], 1, device=self.device)
624
+
625
+ log_p, x_next, log_policy_step, log_pretrained_step = self.mcts_reverse_step(x_rollout, t=t, dt=dt, pretrained=pretrained)
626
+ log_rnd += log_pretrained_step - log_policy_step
627
+
628
+ x_rollout = x_next
629
+
630
+ # if mask token remains, fully unmask
631
+ mask_positions = (x_rollout == self.mask_index) # (B, L) bool
632
+
633
+ # does **any** mask remain in any sequence
634
+ any_mask_global = mask_positions.any().item() # true if mask remains
635
+ if any_mask_global:
636
+ log_p, x_next = self.single_noise_removal(x_rollout, t=t, dt=dt)
637
+
638
+ x_rollout = x_next
639
+
640
+ x_final = x_rollout
641
+
642
+ x_one_hot = to_one_hot(x_final)
643
+ x_one_hot_reward = torch.transpose(x_one_hot, 1, 2)
644
+ reward_preds = reward_model(x_one_hot_reward).squeeze(-1) # (num_children, 4)
645
+ rewards = reward_preds[:, 0] # (num_children,)
646
+ log_rnd = log_rnd + rewards / args.alpha
647
+ mean_reward = rewards.mean()
648
+
649
+ return x_final, log_rnd, rewards
650
+
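The quantity accumulated in `log_rnd` above is the sum over reverse steps of `log_pretrained_step - log_policy_step`, plus `rewards / args.alpha` at the end. A small scalar sketch for one sequence, with made-up per-step log-probabilities and a hypothetical helper name:

```python
def log_rnd(log_pretrained_steps, log_policy_steps, reward, alpha):
    """Sum of per-step log-ratios (pretrained minus policy) plus reward / alpha."""
    log_ratio = sum(lp - lq for lp, lq in zip(log_pretrained_steps, log_policy_steps))
    return log_ratio + reward / alpha

# Two reverse steps with illustrative (not measured) log-probabilities.
value = log_rnd([-1.0, -2.0], [-0.5, -1.5], reward=0.3, alpha=0.1)
```

With these numbers the two log-ratios contribute `-1.0` in total and the reward term contributes about `3.0`, so a smaller `alpha` makes the reward dominate the weight.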
651
+ def sample_finetuned(self, args, reward_model, eps=1e-5):
652
+ num_steps = args.total_num_steps
653
+ x_rollout = self._sample_prior(
654
+ args.batch_size,
655
+ args.seq_length).to(self.device)
656
+
657
+ timesteps = torch.linspace(1, eps, num_steps + 1, device=self.device)
658
+ dt = (1 - eps) / num_steps
659
+
660
+ for i in range(num_steps):
661
+ t = timesteps[i] * torch.ones(x_rollout.shape[0], 1, device=self.device)
662
+
663
+ log_p, x_next = self.single_reverse_step(x_rollout, t=t, dt=dt)
664
+
665
+ x_rollout = x_next
666
+
667
+ # if mask token remains, fully unmask
668
+ mask_positions = (x_rollout == self.mask_index) # (B, L) bool
669
+
670
+ # does **any** mask remain in any sequence
671
+ any_mask_global = mask_positions.any().item() # true if mask remains
672
+ if any_mask_global:
673
+ log_p, x_next = self.single_noise_removal(x_rollout, t=t, dt=dt)
674
+
675
+ x_rollout = x_next
676
+
677
+ x_final = x_rollout
678
+
679
+ x_one_hot = to_one_hot(x_final)
680
+ x_one_hot_reward = torch.transpose(x_one_hot, 1, 2)
681
+ reward_preds = reward_model(x_one_hot_reward).squeeze(-1) # (num_children, 4)
682
+ rewards = reward_preds[:, 0] # (num_children,)
683
+
684
+ mean_reward = rewards.mean()
685
+
686
+ return x_final, mean_reward
687
+
688
+ def compute_log_policy(self, token_array, x_next, t, dt):
689
+ sigma_t, _ = self.noise(t)
690
+
691
+ if token_array.ndim == 1:
692
+ token_array = token_array.unsqueeze(0)
693
+
694
+ if x_next.ndim == 1:
695
+ x_next = x_next.unsqueeze(0)
696
+
697
+ if t.ndim > 1:
698
+ t = t.squeeze(-1)
699
+ assert t.ndim == 1
700
+
701
+ change_prob_t = t[:, None, None]
702
+ change_prob_s = (t - dt)[:, None, None]
703
+
704
+ assert change_prob_t.ndim == 3, change_prob_t.shape
705
+
706
+ log_p = self.forward(token_array, sigma=sigma_t)
707
+ p_x0 = log_p.exp()
708
+
709
+ assert change_prob_t.ndim == p_x0.ndim
710
+
711
+ q_xs = p_x0 * (change_prob_t - change_prob_s)
712
+
713
+ # probability mass of remaining masked after this step
714
+ q_xs[:, :, self.mask_index] = change_prob_s[:, :, 0]
715
+
716
+ copy_flag = (token_array != self.mask_index)
717
+
718
+ assert copy_flag.dtype == torch.bool, "copy_flag must be bool"
719
+ changed_mask = (~copy_flag)
720
+
721
+ # compute the per-sequence log-probability under the policy model
722
+ log_policy_token = log_p.gather(-1, x_next.unsqueeze(-1)).squeeze(-1)
723
+
724
+ unmasked_this_step = (changed_mask & (x_next != self.mask_index)).to(log_policy_token.dtype)
725
+ log_policy_step = (log_policy_token * unmasked_this_step).sum(dim=-1)
726
+
727
+ # returns:
728
+ # log_policy_step (B, ) log probability x_next tokens under policy
729
+ if log_policy_step.ndim == 1:
730
+ log_policy_step = log_policy_step.squeeze(0)
731
+
732
+ return log_policy_step
733
+
734
+
735
+ def single_reverse_step(self, token_array, t, dt, p_x0=None):
736
+ assert self.config.noise.type == 'loglinear'
737
+ sigma_t, _ = self.noise(t)
738
+
739
+ if t.ndim > 1:
740
+ t = t.squeeze(-1)
741
+ assert t.ndim == 1
742
+
743
+ change_prob_t = t[:, None, None]
744
+ change_prob_s = (t - dt)[:, None, None]
745
+
746
+ assert change_prob_t.ndim == 3, change_prob_t.shape
747
+
748
+ if p_x0 is None:
749
+ log_p = self.forward(token_array, sigma=sigma_t)
750
+ p_x0 = log_p.exp()
751
+
752
+ assert change_prob_t.ndim == p_x0.ndim
753
+
754
+ q_xs = p_x0 * (change_prob_t - change_prob_s)
755
+
756
+ # probability mass of remaining masked after this step
757
+ q_xs[:, :, self.mask_index] = change_prob_s[:, :, 0]
758
+
759
+ x_changed = _sample_categorical(q_xs)
760
+
761
+ copy_flag = (token_array != self.mask_index)
762
+
763
+ int_copy_flag = copy_flag.to(token_array.dtype)
764
+ x_next = int_copy_flag * token_array + (1 - int_copy_flag) * x_changed
765
+
766
+ # returns:
767
+ # log_p (B, L, D) log probabilities of each token under the policy model
768
+ # x_next (B, L) next sequences
769
+ return log_p, x_next
770
+
771
+
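The unnormalised vector `q_xs` built in `single_reverse_step` can be read as a categorical distribution once it is divided by its sum. The sketch below works through one masked position in plain Python, assuming a mask index of `4` and a `p_x0` that already puts zero mass on the mask token (as the SUBS parameterization does); it also assumes the sampler only depends on `q_xs` up to scale, which is not shown in this part of the diff.

```python
MASK = 4  # assumed mask index (vocabulary: A, C, G, T, MASK)

def reverse_step_probs(p_x0, t, dt):
    """Normalised one-position transition probabilities from time t to s = t - dt."""
    s = t - dt
    q = [p * (t - s) for p in p_x0]  # mass moved onto each clean token
    q[MASK] = s                      # mass that stays on MASK
    total = sum(q)                   # equals t when p_x0 sums to 1 and p_x0[MASK] == 0
    return [v / total for v in q]

probs = reverse_step_probs([0.1, 0.2, 0.3, 0.4, 0.0], t=0.5, dt=0.1)
```

Under the log-linear schedule a masked position therefore stays masked with probability `s / t` and is unmasked with probability `(t - s) / t`, split across tokens according to `p_x0`.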
772
+ def single_noise_removal(self, token_array, t, dt, p_x0=None):
773
+ assert self.config.noise.type == 'loglinear'
774
+ sigma_t, _ = self.noise(t)
775
+
776
+ if t.ndim > 1:
777
+ t = t.squeeze(-1)
778
+ assert t.ndim == 1
779
+
780
+ change_prob_t = t[:, None, None]
781
+ change_prob_s = (t - dt)[:, None, None]
782
+
783
+ assert change_prob_t.ndim == 3, change_prob_t.shape
784
+
785
+ if p_x0 is None:
786
+ log_p = self.forward(token_array, sigma=sigma_t)
787
+ p_x0 = log_p.exp()
788
+
789
+ assert change_prob_t.ndim == p_x0.ndim
790
+
791
+ # changed for noise removal
792
+ p_x0 = p_x0.clone()
793
+ p_x0[:, :, self.mask_index] = 0.0 # prevent remaining a mask
794
+ p_x0 = p_x0 / p_x0.sum(dim=-1, keepdim=True).clamp_min(1e-12) # renorm over non-MASK
795
+ q_xs = p_x0 * (change_prob_t - change_prob_s)
796
+
797
+ x_changed = _sample_categorical(q_xs)
798
+
799
+ copy_flag = (token_array != self.mask_index)
800
+
801
+ int_copy_flag = copy_flag.to(token_array.dtype)
802
+ x_next = int_copy_flag * token_array + (1 - int_copy_flag) * x_changed
803
+
804
+ # returns:
805
+ # log_p (B, L, D) log probabilities of each token under the policy model
806
+ # x_next (B, L) next sequences
807
+ return log_p, x_next
808
+
809
+ def mcts_reverse_step(self, token_array, t, dt, pretrained, p_x0=None):
810
+ assert self.config.noise.type == 'loglinear'
811
+ sigma_t, _ = self.noise(t)
812
+
813
+ if t.ndim > 1:
814
+ t = t.squeeze(-1)
815
+ assert t.ndim == 1
816
+
817
+ change_prob_t = t[:, None, None]
818
+ change_prob_s = (t - dt)[:, None, None]
819
+
820
+ assert change_prob_t.ndim == 3, change_prob_t.shape
821
+
822
+ if p_x0 is None:
823
+ log_p = self.forward(token_array, sigma=sigma_t)
824
+ p_x0 = log_p.exp()
825
+
826
+ assert change_prob_t.ndim == p_x0.ndim
827
+
828
+ q_xs = p_x0 * (change_prob_t - change_prob_s)
829
+
830
+ # probability mass of remaining masked after this step
831
+ q_xs[:, :, self.mask_index] = change_prob_s[:, :, 0]
832
+
833
+ x_changed = _sample_categorical(q_xs)
834
+
835
+ copy_flag = (token_array != self.mask_index)
836
+
837
+ int_copy_flag = copy_flag.to(token_array.dtype)
838
+ x_next = int_copy_flag * token_array + (1 - int_copy_flag) * x_changed
839
+
840
+ # compute the log-probability under pretrained model at each step
841
+ with torch.no_grad():
842
+ # pretrained should output log-probs over vocab at each position given the *parent* (masked) input
843
+ log_pre = pretrained.forward(token_array, sigma=sigma_t)
844
+
845
+ # log-prob of the *sampled token* at each position
846
+ log_pre_token = log_pre.gather(-1, x_next.unsqueeze(-1)).squeeze(-1) # [B*batch,L]
847
+
848
+ # sum only over the sites actually sampled this step (i.e., where parent was mask)
849
+
850
+ assert copy_flag.dtype == torch.bool, "copy_flag must be bool"
851
+ changed_mask = (~copy_flag)
852
+ # mask of tokens that were unmasked in this step
853
+ unmasked_this_step = (changed_mask & (x_next != self.mask_index)).to(log_pre_token.dtype)
854
+
855
+ log_pretrained_step = (log_pre_token * unmasked_this_step).sum(dim=-1)
856
+
857
+ # compute the per-sequence log-probability under the policy model
858
+ log_policy_token = log_p.gather(-1, x_next.unsqueeze(-1)).squeeze(-1) # [B*batch,L]
859
+ log_policy_step = (log_policy_token * unmasked_this_step).sum(dim=-1)
860
+
861
+ # returns:
862
+ # log_p (B, L, D) log probabilities of each token under the policy model
863
+ # x_next (B, L) next sequences
864
+ # log_policy_step (B, ) log probability of all unmasked tokens under policy
865
+ # log_pretrained_step (B, ) log probability of all unmasked tokens under pretrained model
866
+ return log_p, x_next, log_policy_step, log_pretrained_step
867
+
868
+ def mcts_noise_removal(self, token_array, t, dt, pretrained, p_x0=None):
869
+ assert self.config.noise.type == 'loglinear'
870
+ sigma_t, _ = self.noise(t)
871
+
872
+ if t.ndim > 1:
873
+ t = t.squeeze(-1)
874
+ assert t.ndim == 1
875
+
876
+ change_prob_t = t[:, None, None]
877
+ change_prob_s = (t - dt)[:, None, None]
878
+
879
+ assert change_prob_t.ndim == 3, change_prob_t.shape
880
+
881
+ if p_x0 is None:
882
+ log_p = self.forward(token_array, sigma=sigma_t)
883
+ p_x0 = log_p.exp()
884
+
885
+ assert change_prob_t.ndim == p_x0.ndim
886
+
887
+ # changed for noise removal
888
+ p_x0 = p_x0.clone()
889
+ p_x0[:, :, self.mask_index] = 0.0 # prevent remaining a mask
890
+ p_x0 = p_x0 / p_x0.sum(dim=-1, keepdim=True).clamp_min(1e-12) # renorm over non-MASK
891
+ q_xs = p_x0 * (change_prob_t - change_prob_s)
892
+
893
+ x_changed = _sample_categorical(q_xs)
894
+
895
+ copy_flag = (token_array != self.mask_index)
896
+
897
+ int_copy_flag = copy_flag.to(token_array.dtype)
898
+ x_next = int_copy_flag * token_array + (1 - int_copy_flag) * x_changed
899
+
900
+ # compute the log-probability under pretrained model at each step
901
+ with torch.no_grad():
902
+ # pretrained should output log-probs over vocab at each position given the *parent* (masked) input
903
+ log_pre = pretrained.forward(token_array, sigma=sigma_t)
904
+
905
+ # log-prob of the *sampled token* at each position
906
+ log_pre_token = log_pre.gather(-1, x_next.unsqueeze(-1)).squeeze(-1) # [B*batch,L]
907
+
908
+ # sum only over the sites actually sampled this step (i.e., where parent was mask)
909
+
910
+ assert copy_flag.dtype == torch.bool, "copy_flag must be bool"
911
+ changed_mask = (~copy_flag)
912
+ # mask of tokens that were unmasked in this step
913
+ unmasked_this_step = (changed_mask & (x_next != self.mask_index)).to(log_pre_token.dtype)
914
+
915
+ log_pretrained_step = (log_pre_token * unmasked_this_step).sum(dim=-1)
916
+
917
+ # compute the per-sequence log-probability under the policy model
918
+ log_policy_token = log_p.gather(-1, x_next.unsqueeze(-1)).squeeze(-1) # [B*batch,L]
919
+ log_policy_step = (log_policy_token * unmasked_this_step).sum(dim=-1)
920
+
921
+ # returns:
922
+ # log_p (B, L, D) log probabilities of each token under the policy model
923
+ # x_next (B, L) next sequences
924
+ # log_policy_step (B, ) log probability of all unmasked tokens under policy
925
+ # log_pretrained_step (B, ) log probability of all unmasked tokens under pretrained model
926
+ return log_p, x_next, log_policy_step, log_pretrained_step
927
+
928
+ # first step in expansion
929
+ def batch_mcts_reverse_step(self, token_array, t, dt, batch_size, pretrained, p_x0=None):
930
+
931
+ assert self.config.noise.type == 'loglinear'
932
+ sigma_t, _ = self.noise(t)
933
+
934
+ if t.ndim > 1:
935
+ t = t.squeeze(-1)
936
+ assert t.ndim == 1
937
+
938
+ change_prob_t = t[:, None, None]
939
+ change_prob_s = (t - dt)[:, None, None]
940
+
941
+ assert change_prob_t.ndim == 3, change_prob_t.shape
942
+
943
+ if token_array.dim() == 1:
944
+ token_array = token_array.unsqueeze(0)
945
+
946
+ # expand to match (num_children, L)
947
+
948
+ if p_x0 is None:
949
+ log_p = self.forward(token_array, sigma=sigma_t)
950
+ p_x0 = log_p.exp()
951
+
952
+ assert change_prob_t.ndim == p_x0.ndim
953
+
954
+ q_xs = p_x0 * (change_prob_t - change_prob_s)
955
+
956
+ # probability mass of remaining masked after this step
957
+ q_xs[:, :, self.mask_index] = change_prob_s[:, :, 0]
958
+
959
+ # repeat the parent sequence batch_size times; each copy is unmasked into a distinct child sequence
960
+ token_array_expanded = token_array.repeat(batch_size, 1)
961
+
962
+ if self.config.mcts.sampling == 0:
963
+ x_changed = sample_batched_categorical(q_xs.to(self.device), batch_size)
964
+ else:
965
+ x_changed = sample_batched_top_k(q_xs.to(self.device), batch_size, self.config.mcts.sampling)
966
+
967
+ copy_flag = (token_array_expanded != self.mask_index)
968
+
969
+ int_copy_flag = copy_flag.to(token_array.dtype)
970
+ x_children = int_copy_flag * token_array_expanded + (1 - int_copy_flag) * x_changed
971
+
972
+
973
+ # compute the log-probability under pretrained model at each step
974
+ with torch.no_grad():
975
+ # pretrained should output log-probs over vocab at each position given the *parent* (masked) input
976
+ log_pre = pretrained.forward(token_array, sigma=sigma_t)
977
+
978
+ # expand to match the shape of x_children
979
+ log_pre = log_pre.repeat(batch_size, 1, 1)
980
+
981
+ # log-prob of the *sampled token* at each position
982
+ log_pre_token = log_pre.gather(-1, x_children.unsqueeze(-1)).squeeze(-1) # [B*batch,L]
983
+
984
+ # sum only over the sites actually sampled this step (i.e., where parent was mask)
985
+
986
+ assert copy_flag.dtype == torch.bool, "copy_flag must be bool"
987
+ changed_mask = (~copy_flag)
988
+ # mask of tokens that were unmasked in this step
989
+ unmasked_this_step = (changed_mask & (x_children != self.mask_index)).to(log_pre_token.dtype)
990
+
991
+ log_pretrained_step = (log_pre_token * unmasked_this_step).sum(dim=-1)
992
+
993
+ # compute the per-child log-probability under the policy model
994
+ log_p = log_p.repeat(batch_size, 1, 1)
995
+ log_policy_token = log_p.gather(-1, x_children.unsqueeze(-1)).squeeze(-1) # (B, L) log-probability of each chosen token
996
+ #print(log_policy_token)
997
+ log_policy_step = (log_policy_token * unmasked_this_step).sum(dim=-1)
998
+
999
+ # returns:
1000
+ # log_p (B, L, D) log probabilities of each token under the policy model
1001
+ # x_children (B, L) child sequences
1002
+ # log_policy_step (B, ) log probability of all unmasked tokens under policy
1003
+ # log_pretrained_step (B, ) log probability of all unmasked tokens under pretrained model
1004
+ return log_p, x_children, log_policy_step, log_pretrained_step
1005
+
1006
+ ### SPECIFIC TO DRAKES? ###
1007
+ def _ddpm_update_finetune_gradient(self, x, t, dt, copy_flag_temp, return_process=False):
1008
+
1009
+ if x.ndim == 2 or x.shape[-1] != self.vocab_size:
1010
+ x = F.one_hot(x, num_classes=self.vocab_size).to(torch.float32)
1011
+
1012
+ sigma_t, _ = self.noise(t)
1013
+ sigma_s, _ = self.noise(t - dt)
1014
+ if sigma_t.ndim > 1:
1015
+ sigma_t = sigma_t.squeeze(-1)
1016
+ if sigma_s.ndim > 1:
1017
+ sigma_s = sigma_s.squeeze(-1)
1018
+ assert sigma_t.ndim == 1, sigma_t.shape
1019
+ assert sigma_s.ndim == 1, sigma_s.shape
1020
+ move_chance_t = 1 - torch.exp(-sigma_t) # (1-eps)*t
1021
+ move_chance_s = 1 - torch.exp(-sigma_s)
1022
+ move_chance_t = move_chance_t[:, None, None]
1023
+ move_chance_s = move_chance_s[:, None, None]
1024
+ unet_conditioning = sigma_t
1025
+ log_p_x0 = self.forward(x, unet_conditioning)
1026
+ assert move_chance_t.ndim == log_p_x0.ndim
1027
+ q_xs = log_p_x0.exp() * (move_chance_t - move_chance_s)
1028
+
1029
+ q_xs[:, :, self.mask_index] = move_chance_s[:, :, 0]
1030
+ _x = _sample_categorical_gradient(q_xs, temp=self.config.finetuning.gumbel_softmax_temp)
1031
+
1032
+ if copy_flag_temp is not None:
1033
+ copy_flag_prob = 1 - x[:, :, self.mask_index].unsqueeze(-1)
1034
+ soft_copy_flag = torch.nn.functional.sigmoid(copy_flag_prob/copy_flag_temp)
1035
+ else:
1036
+ soft_copy_flag = 1 - x[:, :, self.mask_index].unsqueeze(-1)
1037
+
1038
+ if return_process:
1039
+ return soft_copy_flag * x + (1 - soft_copy_flag) * _x, x, unet_conditioning, move_chance_t, soft_copy_flag
1040
+ else:
1041
+ return soft_copy_flag * x + (1 - soft_copy_flag) * _x
1042
+
1043
+
1044
+ def _sample_finetune_gradient(self, num_steps=None, eps=1e-5, eval_sp_size=None, copy_flag_temp=None):
1045
+ """Generate samples from the model."""
1046
+ assert self.parameterization == 'subs' and self.sampler == 'ddpm'
1047
+ if eval_sp_size is None:
1048
+ batch_size_per_gpu = self.config.loader.eval_batch_size
1049
+ else:
1050
+ batch_size_per_gpu = eval_sp_size
1051
+ if num_steps is None:
1052
+ num_steps = self.config.sampling.steps
1053
+ x = self._sample_prior(
1054
+ batch_size_per_gpu,
1055
+ self.config.model.length).to(self.device)
1056
+ timesteps = torch.linspace(1, eps, num_steps + 1, device=self.device)
1057
+ dt = (1 - eps) / num_steps
1058
+ p_x0_cache = None
1059
+
1060
+ last_x_list = []
1061
+ condt_list = []
1062
+ move_chance_t_list = []
1063
+ copy_flag_list = []
1064
+
1065
+ for i in range(num_steps):
1066
+ t = timesteps[i] * torch.ones(x.shape[0], 1, device=self.device)
1067
+ if self.sampler == 'ddpm':
1068
+ if i < num_steps - self.config.finetuning.truncate_steps:
1069
+ x, last_x, condt, move_chance_t, copy_flag = self._ddpm_update(x, t, dt, return_process=True)
1070
+ x = x.detach()
1071
+ copy_flag = copy_flag.unsqueeze(-1)
1072
+ last_x = F.one_hot(last_x, num_classes=self.vocab_size).to(torch.float32).detach()
1073
+ else:
1074
+ x, last_x, condt, move_chance_t, copy_flag = self._ddpm_update_finetune_gradient(x, t, dt, copy_flag_temp, return_process=True)
1075
+
1076
+ last_x_list.append(last_x)
1077
+ condt_list.append(condt)
1078
+ move_chance_t_list.append(move_chance_t)
1079
+ copy_flag_list.append(copy_flag)
1080
+
1081
+ x_argmax = x[:, :, :-1].argmax(dim=-1)
1082
+ x_argmax = torch.nn.functional.one_hot(x_argmax, num_classes=self.vocab_size-1).to(torch.float32)
1083
+ return x[:, :, :-1] + (x_argmax - x[:, :, :-1]).detach(), last_x_list, condt_list, move_chance_t_list, copy_flag_list
1084
+
1085
+ @torch.no_grad()
1086
+ def _ddpm_update_finetune_controlled_SMC(self, x, t, dt, reward_model, alpha=1.0):
1087
+
1088
+ sigma_t, _ = self.noise(t)
1089
+ sigma_s, _ = self.noise(t - dt)
1090
+ if sigma_t.ndim > 1:
1091
+ sigma_t = sigma_t.squeeze(-1)
1092
+ if sigma_s.ndim > 1:
1093
+ sigma_s = sigma_s.squeeze(-1)
1094
+ assert sigma_t.ndim == 1, sigma_t.shape
1095
+ assert sigma_s.ndim == 1, sigma_s.shape
1096
+ move_chance_t = 1 - torch.exp(-sigma_t)
1097
+ move_chance_s = 1 - torch.exp(-sigma_s)
1098
+ move_chance_t = move_chance_t[:, None, None]
1099
+ move_chance_s = move_chance_s[:, None, None]
1100
+ unet_conditioning = sigma_t
1101
+ log_p_x0 = self.forward(x, unet_conditioning)
1102
+ assert move_chance_t.ndim == log_p_x0.ndim
1103
+ q_xs = log_p_x0.exp() * (move_chance_t - move_chance_s)
1104
+ q_xs[:, :, self.mask_index] = move_chance_s[:, :, 0]
1105
+ copy_flag = (x != self.mask_index).to(x.dtype)
1106
+ sample = copy_flag * x + (1 - copy_flag) * _sample_categorical(q_xs)
1107
+ '''
1108
+ Calculate exp(v_{t-1}(x_{t-1})/alpha)
1109
+ '''
1110
+ expected_x0 = self.forward(sample, sigma_s) # Calculate E[x_0|x_{t-1}]
1111
+ expected_x0_arg = torch.argmax(expected_x0,dim=2)
1112
+ expected_x0_onehot = torch.nn.functional.one_hot(expected_x0_arg, num_classes=self.vocab_size - 1)
1113
+ reward_num = reward_model(expected_x0_onehot.float().transpose(1, 2)).detach()[:, 0][:, 0]
1114
+ '''
1115
+ Calculate exp(v_{t}(x_{t})/alpha)
1116
+ '''
1117
+ expected_x0 = self.forward(x, sigma_s) # Calculate E[x_0|x_t]
1118
+ expected_x0_arg = torch.argmax(expected_x0,dim=2)
1119
+ expected_x0_onehot = torch.nn.functional.one_hot(expected_x0_arg, num_classes=self.vocab_size - 1)
1120
+ reward_den = reward_model(expected_x0_onehot.float().transpose(1, 2)).detach()[:, 0][:, 0]
1121
+
1122
+ ratio = torch.exp(1.0/alpha * (reward_num - reward_den)) # exp((v_{t-1}(x_{t-1}) - v_{t}(x_{t})) / alpha)
1123
+ ratio = ratio.detach().cpu().numpy()
1124
+ final_sample_indices = np.random.choice(reward_num.shape[0], reward_num.shape[0], p = ratio/ratio.sum() )
1125
+
1126
+ return sample[final_sample_indices]
1127
+
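The resampling step that closes `_ddpm_update_finetune_controlled_SMC` can be looked at in isolation. The sketch below is a standard-library illustration, not the repository's code: `resampling_weights` and `resample_indices` are hypothetical helper names, and subtracting the maximum log-weight is an added numerical-stability step that the original `torch.exp` call does not perform (it leaves the normalized weights unchanged).

```python
import math
import random


def resampling_weights(reward_num, reward_den, alpha):
    """Normalized weights proportional to exp((r_num - r_den) / alpha)."""
    log_w = [(a - b) / alpha for a, b in zip(reward_num, reward_den)]
    shift = max(log_w)  # assumption: added for stability, cancels on normalization
    w = [math.exp(v - shift) for v in log_w]
    total = sum(w)
    return [v / total for v in w]


def resample_indices(weights, rng):
    """Draw len(weights) particle indices with replacement, as np.random.choice does."""
    n = len(weights)
    return rng.choices(range(n), weights=weights, k=n)


# Three particles: the reward estimate improved by 0, 1 and 2 after the step.
weights = resampling_weights([1.0, 2.0, 3.0], [1.0, 1.0, 1.0], alpha=0.5)
indices = resample_indices(weights, random.Random(0))
```

A smaller `alpha` sharpens the weights, so particles whose estimated reward improved the most are duplicated more often.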
1128
+ def _ddpm_update_finetune_controlled_CG(self, x, t, dt, reward_model, guidance_scale):
1129
+
1130
+ sigma_t, _ = self.noise(t)
1131
+ sigma_s, _ = self.noise(t - dt)
1132
+ if sigma_t.ndim > 1:
1133
+ sigma_t = sigma_t.squeeze(-1)
1134
+ if sigma_s.ndim > 1:
1135
+ sigma_s = sigma_s.squeeze(-1)
1136
+ assert sigma_t.ndim == 1, sigma_t.shape
1137
+ assert sigma_s.ndim == 1, sigma_s.shape
1138
+ move_chance_t = 1 - torch.exp(-sigma_t)
1139
+ move_chance_s = 1 - torch.exp(-sigma_s)
1140
+ move_chance_t = move_chance_t[:, None, None]
1141
+ move_chance_s = move_chance_s[:, None, None]
1142
+ unet_conditioning = sigma_t
1143
+ log_p_x0 = self.forward(x, unet_conditioning)
1144
+ assert move_chance_t.ndim == log_p_x0.ndim
1145
+ q_xs = log_p_x0.exp() * (move_chance_t - move_chance_s)
1146
+ x_onehot = F.one_hot(x, num_classes=5).float()
1147
+
1148
+ x_grad = self.compute_gradient_CG(x_onehot, x, reward_model, sigma_s )
1149
+ guidance = guidance_scale * (x_grad - x_grad[:, :, self.mask_index][:, :, None])
1150
+ q_xs[:, :, self.mask_index] = move_chance_s[:, :, 0]
1151
+ q_xs = q_xs * guidance.exp()
1152
+
1153
+ _x = _sample_categorical(q_xs)
1154
+ copy_flag = (x != self.mask_index).to(x.dtype)
1155
+ return copy_flag * x + (1 - copy_flag) * _x
1156
+
1157
+ def compute_gradient_CG(self, x_onehot, x, reward_model, sigma_s):
1158
+ x_onehot.requires_grad_(True)
1159
+ expected_x0 = self.forward(x_onehot, sigma_s) # Calculate E[x_0|x_t]
1160
+ scores = reward_model(expected_x0.transpose(1, 2)[:,0:4,:])[:, 0]
1161
+ scores = scores.mean()
1162
+ scores.backward()
1163
+ x_grad = x_onehot.grad.clone()
1164
+ return x_grad
1165
+
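The guidance step in `_ddpm_update_finetune_controlled_CG` multiplies each unnormalized transition probability by `exp(guidance_scale * (grad - grad_at_mask))`. A minimal single-position sketch of that tilting follows; `tilt` is a hypothetical helper, the numbers are made up for illustration, and the final normalization is added only for readability (the categorical sampler in this code takes an argmax of `probs / noise`, so it does not depend on the overall scale).

```python
import math


def tilt(q, grad, mask_index, scale):
    """Reweight q by exp(scale * (grad - grad[mask_index])) and normalize."""
    ref = grad[mask_index]  # the mask entry keeps a multiplier of exp(0) = 1
    tilted = [p * math.exp(scale * (g - ref)) for p, g in zip(q, grad)]
    z = sum(tilted)
    return [v / z for v in tilted]


# Five entries: A, C, G, T and the mask token at index 4 (illustrative values).
q = [0.1, 0.2, 0.3, 0.1, 0.3]
grad = [0.0, 1.0, 0.0, 0.0, 0.0]
guided = tilt(q, grad, mask_index=4, scale=1.0)
```

Only the entry with a positive gradient relative to the mask entry gains mass; the ratios between all other entries are unchanged.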
1166
+ def _ddpm_update_finetune_controlled_TDS(self, x, t, dt, reward_model, alpha = 1.0, guidance_scale=1000):
1167
+ # SMC with the twisted proposal
1168
+
1169
+ sigma_t, _ = self.noise(t)
1170
+ sigma_s, _ = self.noise(t - dt)
1171
+ if sigma_t.ndim > 1:
1172
+ sigma_t = sigma_t.squeeze(-1)
1173
+ if sigma_s.ndim > 1:
1174
+ sigma_s = sigma_s.squeeze(-1)
1175
+ assert sigma_t.ndim == 1, sigma_t.shape
1176
+ assert sigma_s.ndim == 1, sigma_s.shape
1177
+ move_chance_t = 1 - torch.exp(-sigma_t)
1178
+ move_chance_s = 1 - torch.exp(-sigma_s)
1179
+ move_chance_t = move_chance_t[:, None, None]
1180
+ move_chance_s = move_chance_s[:, None, None]
1181
+ unet_conditioning = sigma_t
1182
+ log_p_x0 = self.forward(x, unet_conditioning)
1183
+ assert move_chance_t.ndim == log_p_x0.ndim
1184
+ q_xs = log_p_x0.exp() * (move_chance_t
1185
+ - move_chance_s)
1186
+ x_onehot = F.one_hot(x, num_classes=5).float()
1187
+
1188
+ x_grad = self.compute_gradient_CG(x_onehot, x, reward_model, sigma_s )
1189
+ guidance = guidance_scale * (x_grad - x_grad[:, :, self.mask_index][:, :, None])
1190
+ q_xs[:, :, self.mask_index] = move_chance_s[:, :, 0]
1191
+ # print(q_xs.sum(-1))
1192
+ q_xs = q_xs * guidance.exp()
1193
+
1194
+ _x = _sample_categorical(q_xs)
1195
+ copy_flag = (x != self.mask_index).to(x.dtype)
1196
+ sample = copy_flag * x + (1 - copy_flag) * _x
1197
+ prob_multiplier = (1 - copy_flag) * torch.gather(guidance.exp(), 2, _x.unsqueeze(-1)).squeeze(-1) + copy_flag * torch.ones_like(_x)
1198
+ '''
1199
+ Calculate exp(v_{t-1}(x_{t-1})/alpha)
1200
+ '''
1201
+ expected_x0 = self.forward(sample, sigma_s) # Calculate E[x_0|x_{t-1}]
1202
+ expected_x0_arg = torch.argmax(expected_x0,dim=2)
1203
+ expected_x0_onehot = torch.nn.functional.one_hot(expected_x0_arg, num_classes=self.vocab_size - 1)
1204
+ reward_num = reward_model(expected_x0_onehot.float().transpose(1, 2)).detach()[:, 0][:, 0]
1205
+ '''
1206
+ Calculate exp(v_{t}(x_{t})/alpha)
1207
+ '''
1208
+ expected_x0 = self.forward(x, sigma_s) # Calculate E[x_0|x_t]
1209
+ expected_x0_arg = torch.argmax(expected_x0,dim=2)
1210
+ expected_x0_onehot = torch.nn.functional.one_hot(expected_x0_arg, num_classes=self.vocab_size - 1)
1211
+ reward_den = reward_model(expected_x0_onehot.float().transpose(1, 2)).detach()[:, 0][:, 0]
1212
+
1213
+ # Replace NaN multipliers with 1 so they do not affect the product below
1214
+ prob_multiplier[torch.isnan(prob_multiplier)] = 1
1215
+ ratio = torch.exp(1.0/alpha * (reward_num - reward_den)) / prob_multiplier.prod(dim=-1)
1216
+ ratio = ratio.detach().cpu().numpy()
1217
+ final_sample_indices = np.random.choice(reward_num.shape[0], reward_num.shape[0], p = ratio/ratio.sum() )
1218
+
1219
+ return sample[final_sample_indices]
1220
+
1221
+ @torch.no_grad()
1222
+ def controlled_sample_SMC(self, reward_model, alpha, num_steps=None, eps=1e-5, eval_sp_size=None):
1223
+ """Generate samples with SMC resampling weighted by the reward model."""
1224
+ if eval_sp_size is None:
1225
+ batch_size_per_gpu = self.config.loader.eval_batch_size
1226
+ else:
1227
+ batch_size_per_gpu = eval_sp_size
1228
+ if self.parameterization == 'ar':
1229
+ return self._ar_sampler(batch_size_per_gpu)
1230
+ # Lightning auto-casting is not working in this method for some reason
1231
+ if num_steps is None:
1232
+ num_steps = self.config.sampling.steps
1233
+ x = self._sample_prior(
1234
+ batch_size_per_gpu,
1235
+ self.config.model.length).to(self.device)
1236
+ timesteps = torch.linspace(1, eps, num_steps + 1, device=self.device)
1237
+ dt = (1 - eps) / num_steps
1238
+ p_x0_cache = None
1239
+
1240
+ for i in range(num_steps):
1241
+ t = timesteps[i] * torch.ones(
1242
+ x.shape[0], 1, device=self.device)
1243
+ if self.sampler == 'ddpm':
1244
+ x = self._ddpm_update_finetune_controlled_SMC(x, t, dt, reward_model, alpha)
1245
+ else:
1246
+ x = self._analytic_update(x, t, dt)
1247
+
1248
+ if self.config.sampling.noise_removal:
1249
+ t = timesteps[-1] * torch.ones(x.shape[0], 1, device=self.device)
1250
+ if self.sampler == 'analytic':
1251
+ x = self._denoiser_update(x, t)
1252
+ else:
1253
+ unet_conditioning = self.noise(t)[0]
1254
+ logits = self.forward(x, unet_conditioning)
1255
+ x = logits[:, :, :-1].argmax(dim=-1)
1256
+ return x
1257
+
1258
+ def controlled_sample_CG(self, reward_model, guidance_scale, num_steps=None, eps=1e-5, eval_sp_size=None):
1259
+ """Generate samples guided by reward-model gradients (CG)."""
1260
+ if eval_sp_size is None:
1261
+ batch_size_per_gpu = self.config.loader.eval_batch_size
1262
+ else:
1263
+ batch_size_per_gpu = eval_sp_size
1264
+ if self.parameterization == 'ar':
1265
+ return self._ar_sampler(batch_size_per_gpu)
1266
+ # Lightning auto-casting is not working in this method for some reason
1267
+ if num_steps is None:
1268
+ num_steps = self.config.sampling.steps
1269
+ x = self._sample_prior(
1270
+ batch_size_per_gpu,
1271
+ self.config.model.length).to(self.device)
1272
+ timesteps = torch.linspace(
1273
+ 1, eps, num_steps + 1, device=self.device)
1274
+ dt = (1 - eps) / num_steps
1275
+ p_x0_cache = None
1276
+
1277
+ for i in range(num_steps):
1278
+ t = timesteps[i] * torch.ones(
1279
+ x.shape[0], 1, device=self.device)
1280
+ if self.sampler == 'ddpm':
1281
+ x = self._ddpm_update_finetune_controlled_CG(x, t, dt, reward_model, guidance_scale)
1282
+ else:
1283
+ x = self._analytic_update(x, t, dt)
1284
+
1285
+ if self.config.sampling.noise_removal:
1286
+ t = timesteps[-1] * torch.ones(x.shape[0], 1,
1287
+ device=self.device)
1288
+ if self.sampler == 'analytic':
1289
+ x = self._denoiser_update(x, t)
1290
+ else:
1291
+ unet_conditioning = self.noise(t)[0]
1292
+ logits = self.forward(x, unet_conditioning)
1293
+ x = logits[:, :, :-1].argmax(dim=-1)
1294
+ return x
1295
+
1296
+ def controlled_sample_TDS(self, reward_model, alpha, guidance_scale, num_steps=None, eps=1e-5, eval_sp_size=None):
1297
+ """Generate samples with TDS: SMC resampling with a gradient-twisted proposal."""
1298
+ if eval_sp_size is None:
1299
+ batch_size_per_gpu = self.config.loader.eval_batch_size
1300
+ else:
1301
+ batch_size_per_gpu = eval_sp_size
1302
+
1303
+ if self.parameterization == 'ar':
1304
+ return self._ar_sampler(batch_size_per_gpu)
1305
+
1306
+ if num_steps is None:
1307
+ num_steps = self.config.sampling.steps
1308
+ x = self._sample_prior(
1309
+ batch_size_per_gpu,
1310
+ self.config.model.length).to(self.device)
1311
+ timesteps = torch.linspace(
1312
+ 1, eps, num_steps + 1, device=self.device)
1313
+ dt = (1 - eps) / num_steps
1314
+ p_x0_cache = None
1315
+
1316
+ for i in range(num_steps):
1317
+ t = timesteps[i] * torch.ones(
1318
+ x.shape[0], 1, device=self.device)
1319
+ if self.sampler == 'ddpm':
1320
+ x = self._ddpm_update_finetune_controlled_TDS(x, t, dt, reward_model,alpha, guidance_scale)
1321
+ else:
1322
+ x = self._analytic_update(x, t, dt)
1323
+
1324
+ if self.config.sampling.noise_removal:
1325
+ t = timesteps[-1] * torch.ones(x.shape[0], 1,
1326
+ device=self.device)
1327
+ if self.sampler == 'analytic':
1328
+ x = self._denoiser_update(x, t)
1329
+ else:
1330
+ unet_conditioning = self.noise(t)[0]
1331
+ logits = self.forward(x, unet_conditioning)
1332
+ x = logits[:, :, :-1].argmax(dim=-1)
1333
+ return x
1334
+
1335
+ @torch.no_grad()
1336
+ def get_likelihood(self, x0, num_steps=None, eps=1e-5, n_samples=1):
1337
+ """Compute the likelihood of a sequence under the model.
1338
+ x0: int torch.Tensor with shape (batch_size,
1339
+ diffusion_model_input_length)
1340
+ """
1341
+ if num_steps is None:
1342
+ num_steps = self.config.sampling.steps
1343
+ timesteps = torch.linspace(
1344
+ 1, eps, num_steps + 1, device=self.device) # t=0 is clean data
1345
+ dt = (1 - eps) / num_steps
1346
+ log_p_sample_list = []
1347
+ for _ in range(n_samples):
1348
+ log_p_at_time_list = []
1349
+ for i in range(num_steps):
1350
+ t = timesteps[i] * torch.ones(
1351
+ x0.shape[0], 1, device=self.device)
1352
+ sigma_t, _ = self.noise(t)
1353
+ sigma_s, _ = self.noise(t - dt)
1354
+ if sigma_t.ndim > 1:
1355
+ sigma_t = sigma_t.squeeze(-1)
1356
+ if sigma_s.ndim > 1:
1357
+ sigma_s = sigma_s.squeeze(-1)
1358
+ assert sigma_t.ndim == 1, sigma_t.shape
1359
+ assert sigma_s.ndim == 1, sigma_s.shape
1360
+ move_chance_t = 1 - torch.exp(-sigma_t) # (1-eps)*t
1361
+ move_chance_s = 1 - torch.exp(-sigma_s)
1362
+ move_chance_t = move_chance_t[:, None] # [bsz, 1]
1363
+ move_chance_s = move_chance_s[:, None]
1364
+ unet_conditioning = sigma_t # [bsz]
1365
+ multiplier = (move_chance_t - move_chance_s)/move_chance_t # [bsz, 1]
1366
+ xt = self.q_xt(x0, move_chance_t) # [bsz, seq_len]
1367
+ # log prob, already apply subs parametrization (unmasked token remains unchanged)
1368
+ model_output = self.forward(xt, unet_conditioning) # [bsz, seq_len, vocab_size]
1369
+ # take the log prob of the token that corresponds to x0
1370
+ log_p_x0 = model_output.gather(-1, x0[..., None]).squeeze(-1) # [bsz, seq_len]
1371
+ log_p_x0 = log_p_x0 * multiplier
1372
+ log_p_at_time_list.append(log_p_x0)
1373
+ log_p_x0 = torch.stack(log_p_at_time_list, dim=0).sum(dim=0) # [bsz, seq_len]
1374
+ log_p_sample_list.append(log_p_x0.sum(dim=-1))
1375
+ log_p_sample = torch.stack(log_p_sample_list, dim=0).mean(dim=0)
1376
+ return log_p_sample
1377
+
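The per-step weight in `get_likelihood` is `(move_chance_t - move_chance_s) / move_chance_t`. The inline comment `# (1-eps)*t` suggests a log-linear schedule in which the masking probability is linear in `t`; under that assumption the weight reduces to `dt / t`. The sketch below checks this numerically. `step_multipliers` and `noise_eps` are hypothetical names, and the schedule form is an assumption taken from that comment rather than from the noise-schedule module itself.

```python
def step_multipliers(num_steps, eps=1e-5, noise_eps=1e-3):
    """Return (timesteps, dt, multipliers), assuming move_chance(t) = (1 - noise_eps) * t."""
    dt = (1 - eps) / num_steps
    # Same grid as torch.linspace(1, eps, num_steps + 1)[:-1].
    timesteps = [1 - i * dt for i in range(num_steps)]
    multipliers = []
    for t in timesteps:
        move_t = (1 - noise_eps) * t
        move_s = (1 - noise_eps) * (t - dt)
        multipliers.append((move_t - move_s) / move_t)
    return timesteps, dt, multipliers


timesteps, dt, multipliers = step_multipliers(num_steps=8)
```

Under this assumption the weight grows as `t` approaches 0, so the late, low-noise steps contribute most to the estimate.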
1378
+ def get_score(self, x, sigma):
1379
+ model_output = self.forward(x, sigma)
1380
+ if self.parameterization == 'subs':
1381
+ # score(x, t) = p_t(y) / p_t(x)
1382
+ # => log score(x, t) = log p_t(y) - log p_t(x)
1383
+
1384
+ # case 1: x = masked
1385
+ # (i) y = unmasked
1386
+ # log score(x, t) = log p_\theta(x)|_y + log k
1387
+ # where k = exp(- sigma) / (1 - exp(- sigma))
1388
+ # (ii) y = masked
1389
+ # log score(x, t) = 0
1390
+
1391
+ # case 2: x = unmasked
1392
+ # (i) y != masked, y != x
1393
+ # log score(x_i, t) = - inf
1394
+ # (ii) y = x
1395
+ # log score(x_i, t) = 0
1396
+ # (iii) y = masked token
1397
+ # log score(x_i, t) = - log k
1398
+ # where k = exp(- sigma) / (1 - exp(- sigma))
1399
+
1400
+ log_k = - torch.log(torch.expm1(sigma)).squeeze(-1)
1401
+ assert log_k.ndim == 1
1402
+
1403
+ masked_score = model_output + log_k[:, None, None]
1404
+ masked_score[:, :, self.mask_index] = 0
1405
+
1406
+ unmasked_score = self.neg_infinity * torch.ones_like(
1407
+ model_output)
1408
+ unmasked_score = torch.scatter(
1409
+ unmasked_score,
1410
+ -1,
1411
+ x[..., None],
1412
+ torch.zeros_like(unmasked_score[..., :1]))
1413
+ unmasked_score[:, :, self.mask_index] = - (
1414
+ log_k[:, None] * torch.ones_like(x))
1415
+
1416
+ masked_indices = (x == self.mask_index).to(
1417
+ model_output.dtype)[:, :, None]
1418
+ model_output = (
1419
+ masked_score * masked_indices
1420
+ + unmasked_score * (1 - masked_indices))
1421
+ return model_output.exp()
1422
+
1423
+ def _staggered_score(self, score, dsigma):
1424
+ score = score.clone()
1425
+ extra_const = (1 - dsigma.exp()) * score.sum(dim=-1)
1426
+ score *= dsigma.exp()[:, None]
1427
+ score[..., self.mask_index] += extra_const
1428
+ return score
1429
+
1430
+ def _analytic_update(self, x, t, step_size):
1431
+ curr_sigma, _ = self.noise(t)
1432
+ next_sigma, _ = self.noise(t - step_size)
1433
+ dsigma = curr_sigma - next_sigma
1434
+ score = self.get_score(x, curr_sigma)
1435
+ stag_score = self._staggered_score(score, dsigma)
1436
+ probs = stag_score * self._transp_transition(x, dsigma)
1437
+ return _sample_categorical(probs)
1438
+
1439
+ def _denoiser_update(self, x, t):
1440
+ sigma, _ = self.noise(t)
1441
+ score = self.get_score(x, sigma)
1442
+ stag_score = self._staggered_score(score, sigma)
1443
+ probs = stag_score * self._transp_transition(x, sigma)
1444
+ probs[..., self.mask_index] = 0
1445
+ samples = _sample_categorical(probs)
1446
+ return samples
1447
+
1448
+ def _transp_transition(self, i, sigma):
1449
+ sigma = _unsqueeze(sigma, reference=i[..., None])
1450
+ edge = torch.exp(-sigma) * F.one_hot(
1451
+ i, num_classes=self.vocab_size)
1452
+ edge += torch.where(i == self.mask_index,
1453
+ 1 - torch.exp(-sigma).squeeze(-1),
1454
+ 0)[..., None]
1455
+ return edge
1456
+
1457
+ def _sample_t(self, n, device):
1458
+ _eps_t = torch.rand(n, device=device)
1459
+ if self.antithetic_sampling:
1460
+ # for variance reduction
1461
+ offset = torch.arange(n, device=device) / n
1462
+ _eps_t = (_eps_t / n + offset) % 1
1463
+ t = (1 - self.sampling_eps) * _eps_t + self.sampling_eps
1464
+ if self.importance_sampling:
1465
+ return self.noise.importance_sampling_transformation(t)
1466
+ return t
1467
+
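The `antithetic_sampling` branch of `_sample_t` maps the i-th uniform draw into the interval `[i/n, (i+1)/n)`, so each batch covers the time axis evenly (a stratified scheme, despite the flag name). A standard-library sketch of the same arithmetic, with `sample_t` as a hypothetical stand-in for the method and without the importance-sampling transformation:

```python
import random


def sample_t(n, sampling_eps, rng, antithetic=True):
    """Draw n times in [sampling_eps, 1), optionally one per stratum of width 1/n."""
    u = [rng.random() for _ in range(n)]
    if antithetic:
        u = [(u_i / n + i / n) % 1 for i, u_i in enumerate(u)]
    return [(1 - sampling_eps) * u_i + sampling_eps for u_i in u]


n = 8
eps = 1e-3
t = sample_t(n, eps, random.Random(0))
# Undo the affine map to see which stratum each draw fell into.
strata = [int((t_i - eps) / (1 - eps) * n) for t_i in t]
```

With stratification, every batch contains one time from each eighth of the unit interval, which is what reduces the variance of the loss estimate.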
1468
+ def _maybe_sub_sample(self, x0, attention_mask):
1469
+ seqlen = x0.shape[1]
1470
+ if seqlen > self.config.model.length:
1471
+ raise NotImplementedError('Sub-sampling not implemented')
1472
+ elif self.parameterization == 'ar':
1473
+ input_tokens = x0[:, :-1]
1474
+ output_tokens = x0[:, 1:]
1475
+ new_attention_mask = attention_mask[:, 1:]
1476
+ else:
1477
+ input_tokens = x0
1478
+ output_tokens = None
1479
+ new_attention_mask = attention_mask
1480
+ return input_tokens, output_tokens, new_attention_mask
1481
+
1482
+ def _reconstruction_loss(self, x0):
1483
+ t0 = torch.zeros(x0.shape[0], dtype=self.dtype,
1484
+ device=self.device)
1485
+ assert self.config.noise.type == 'loglinear'
1486
+ # The above assert is for d3pm parameterization
1487
+ unet_conditioning = self.noise(t0)[0][:, None]
1488
+ model_output_t0 = self.forward(x0, unet_conditioning)
1489
+ return - torch.gather(input=model_output_t0,
1490
+ dim=-1,
1491
+ index=x0[:, :, None]).squeeze(-1)
1492
+
1493
+ def _forward_pass_diffusion(self, x0):
1494
+ t = self._sample_t(x0.shape[0], x0.device)
1495
+ if self.T > 0:
1496
+ # else ts are between 0 and 1
1497
+ t = (t * self.T).to(torch.int)
1498
+ t = t / self.T
1499
+ # t \in {1/T, 2/T, ..., 1}
1500
+ t += (1 / self.T)
1501
+
1502
+ if self.change_of_variables: # False
1503
+ unet_conditioning = t[:, None]
1504
+ f_T = torch.log1p(- torch.exp(- self.noise.sigma_max))
1505
+ f_0 = torch.log1p(- torch.exp(- self.noise.sigma_min))
1506
+ move_chance = torch.exp(f_0 + t * (f_T - f_0))
1507
+ move_chance = move_chance[:, None]
1508
+ else:
1509
+ sigma, dsigma = self.noise(t) # total noise, rate noise
1510
+ unet_conditioning = sigma[:, None]
1511
+ move_chance = 1 - torch.exp(-sigma[:, None])
1512
+
1513
+ xt = self.q_xt(x0, move_chance) # q(xt|x0)
1514
+ model_output = self.forward(xt, unet_conditioning)
1515
+ utils.print_nans(model_output, 'model_output')
1516
+
1517
+ if self.parameterization == 'sedd':
1518
+ return dsigma[:, None] * self._score_entropy(
1519
+ model_output, sigma[:, None], xt, x0)
1520
+
1521
+ if self.T > 0:
1522
+ diffusion_loss = self._d3pm_loss(
1523
+ model_output=model_output, xt=xt, x0=x0, t=t)
1524
+ if self.parameterization == 'd3pm':
1525
+ reconstruction_loss = self._reconstruction_loss(x0)
1526
+ elif self.parameterization == 'subs':
1527
+ reconstruction_loss = 0
1528
+ return reconstruction_loss + diffusion_loss
1529
+
1530
+ # SUBS parameterization, continuous time.
1531
+ log_p_theta = torch.gather(
1532
+ input=model_output,
1533
+ dim=-1,
1534
+ index=x0[:, :, None]).squeeze(-1)
1535
+
1536
+ if self.change_of_variables or self.importance_sampling:
1537
+ return log_p_theta * torch.log1p(
1538
+ - torch.exp(- self.noise.sigma_min))
1539
+
1540
+ return - log_p_theta * (
1541
+ dsigma / torch.expm1(sigma))[:, None]
1542
+
1543
+ def _loss(self, x0, attention_mask):
1544
+ (input_tokens, output_tokens, attention_mask) = self._maybe_sub_sample(
1545
+ x0, attention_mask)
1546
+
1547
+ if self.parameterization == 'ar':
1548
+ logprobs = self.backbone(input_tokens, None)
1549
+ loss = - logprobs.gather(
1550
+ -1, output_tokens[:, :, None])[:, :, 0]
1551
+ else:
1552
+ loss = self._forward_pass_diffusion(input_tokens)
1553
+
1554
+ nlls = loss * attention_mask
1555
+ count = attention_mask.sum()
1556
+
1557
+ batch_nll = nlls.sum()
1558
+ token_nll = batch_nll / count
1559
+
1560
+ return Loss(loss=token_nll,
1561
+ nlls=nlls,
1562
+ token_mask=attention_mask)
1563
+
1564
+ def _score_entropy(self, log_score, sigma, xt, x0):
1565
+ """Computes the SEDD loss.
1566
+
1567
+ Args:
1568
+ log_score: float torch.Tensor with shape (batch_size,
1569
+ diffusion_model_input_length, vocab_size),
1570
+ log score, output of the denoising network.
1571
+ xt: int torch.Tensor with shape (batch_size,
1572
+ diffusion_model_input_length), input.
1573
+ x0: int torch.Tensor with shape (batch_size,
1574
+ diffusion_model_input_length), input.
1575
+ sigma: float torch.Tensor with shape (batch_size, 1).
1576
+
1577
+ Returns:
1578
+ loss with shape (batch_size, diffusion_model_input_length)
1579
+ """
1580
+ # Only masked positions contribute: the case xt = mask, y = x0.
1581
+ # The const term does not depend on the model; it offsets the loss so that it is zero when the score at x0 equals q_ratio.
1582
+ masked_indices = xt == self.mask_index
1583
+
1584
+ expsig_minus_1 = torch.expm1(sigma).expand_as(xt)
1585
+ q_ratio = 1 / expsig_minus_1[masked_indices]
1586
+
1587
+ words_that_were_masked = x0[masked_indices]
1588
+
1589
+ neg_term = q_ratio * torch.gather(
1590
+ log_score[masked_indices],
1591
+ -1,
1592
+ words_that_were_masked[..., None]).squeeze(-1)
1593
+ score = log_score[masked_indices].exp()
1594
+ if self.mask_index == self.vocab_size - 1:
1595
+ pos_term = score[:, :-1].sum(dim=-1)
1596
+ else:
1597
+ pos_term = score[:, : self.mask_index].sum(
1598
+ dim=-1) + score[:, self.mask_index + 1:].sum(dim=-1)
1599
+ const = q_ratio * (q_ratio.log() - 1)
1600
+
1601
+ entropy = torch.zeros(* xt.shape, device=xt.device)
1602
+ entropy[masked_indices] += pos_term - neg_term + const
1603
+ return entropy
1604
+
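For a single masked position, the value accumulated by `_score_entropy` is `pos_term - neg_term + const` with `q_ratio = 1 / (exp(sigma) - 1)`. The sketch below reproduces that expression with the standard library and checks that it vanishes when the score at `x0` equals `q_ratio` and all other non-mask scores are zero. `score_entropy_term` is a hypothetical helper written for this illustration, not a function from the repository.

```python
import math


def score_entropy_term(log_score, x0, sigma, mask_index):
    """Score-entropy value for one masked position."""
    q_ratio = 1.0 / math.expm1(sigma)
    # Sum of scores over every non-mask token.
    pos_term = sum(math.exp(v) for i, v in enumerate(log_score) if i != mask_index)
    neg_term = q_ratio * log_score[x0]
    const = q_ratio * (math.log(q_ratio) - 1.0)
    return pos_term - neg_term + const


sigma = 0.7
q = 1.0 / math.expm1(sigma)
neg_inf = float("-inf")
# Ideal log-scores: log(q) at the true token (index 2), zero score elsewhere.
ideal = [neg_inf, neg_inf, math.log(q), neg_inf, 0.0]
at_optimum = score_entropy_term(ideal, x0=2, sigma=sigma, mask_index=4)
# Doubling the score at x0 moves away from the optimum.
perturbed = [neg_inf, neg_inf, math.log(2 * q), neg_inf, 0.0]
off_optimum = score_entropy_term(perturbed, x0=2, sigma=sigma, mask_index=4)
```

In the perturbed case the value is `q * (1 - log 2)`, which is strictly positive.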
tr2d2-dna/diffusion_gosai_cfg.py ADDED
@@ -0,0 +1,729 @@
1
+ import itertools
2
+ import math
3
+ from dataclasses import dataclass
4
+ import hydra.utils
5
+ import lightning as L
6
+ import numpy as np
7
+ import torch
8
+ import torch.nn.functional as F
9
+ import torchmetrics
10
+ from torch import Tensor
11
+ import dataloader_gosai
12
+ import models
13
+ import noise_schedule
14
+ import utils
15
+ import oracle
16
+
17
+ LOG2 = math.log(2)
18
+ LOGGER = utils.get_logger(__name__)
19
+
20
+
21
+ def _sample_categorical(categorical_probs):
22
+ gumbel_norm = (
23
+ 1e-10
24
+ - (torch.rand_like(categorical_probs) + 1e-10).log())
25
+ return (categorical_probs / gumbel_norm).argmax(dim=-1)
26
+
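`_sample_categorical` draws from a categorical distribution by dividing each probability by an independent, approximately Exp(1) noise value and taking the argmax. The sketch below repeats the trick for a single distribution using only the standard library and compares empirical frequencies with the target probabilities. `sample_categorical` here is an illustrative re-implementation, not the tensor version above.

```python
import math
import random


def sample_categorical(probs, rng):
    """Return index i with probability probs[i] / sum(probs)."""
    # -log(U) is Exp(1); the 1e-10 offsets mirror the original guard against log(0).
    noise = [1e-10 - math.log(rng.random() + 1e-10) for _ in probs]
    ratios = [p / e for p, e in zip(probs, noise)]
    return max(range(len(probs)), key=ratios.__getitem__)


rng = random.Random(0)
probs = [0.2, 0.3, 0.5]
draws = 20000
counts = [0] * len(probs)
for _ in range(draws):
    counts[sample_categorical(probs, rng)] += 1
freqs = [c / draws for c in counts]
```

Because only the argmax matters, the input does not need to be normalized, which is why the update functions can pass unnormalized `q_xs` directly.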
27
+
28
+ def _unsqueeze(x, reference):
29
+ return x.view(
30
+ * x.shape,
31
+ * ((1,) * (len(reference.shape) - len(x.shape))))
32
+
33
+
34
+ @dataclass
35
+ class Loss:
36
+ loss: torch.FloatTensor
37
+ nlls: torch.FloatTensor
38
+ token_mask: torch.FloatTensor
39
+
40
+
41
+ class NLL(torchmetrics.aggregation.MeanMetric):
42
+ pass
43
+
44
+
45
+ class BPD(NLL):
46
+ def compute(self) -> Tensor:
47
+ """Computes the bits per dimension.
48
+
49
+ Returns:
50
+ bpd
51
+ """
52
+ return self.mean_value / self.weight / LOG2
53
+
54
+
55
+ class Perplexity(NLL):
56
+ def compute(self) -> Tensor:
57
+ """Computes the Perplexity.
58
+
59
+ Returns:
60
+ Perplexity
61
+ """
62
+ return torch.exp(self.mean_value / self.weight)
63
+
64
+
65
+ class Diffusion(L.LightningModule):
66
+ def __init__(
67
+ self,
68
+ config,
69
+ eval=True):
70
+ super().__init__()
71
+ self.save_hyperparameters()
72
+ self.config = config
73
+ self.vocab_size = 4
74
+ self.sampler = self.config.sampling.predictor
75
+ self.antithetic_sampling = self.config.training.antithetic_sampling
76
+ self.importance_sampling = self.config.training.importance_sampling
77
+ self.change_of_variables = self.config.training.change_of_variables
78
+ self.mask_index = self.vocab_size
79
+ self.vocab_size += 1
80
+ self.parameterization = self.config.parameterization
81
+ if self.config.backbone == 'cnn':
82
+ self.backbone = models.dnaconv.CNNModel(
83
+ self.config.model, alphabet_size=self.vocab_size, num_cls=2)
84
+ else:
85
+ raise ValueError(
86
+ f'Unknown backbone: {self.config.backbone}')
87
+
88
+ self.T = self.config.T
89
+ self.subs_masking = self.config.subs_masking
90
+
91
+ self.softplus = torch.nn.Softplus()
92
+ # metrics are automatically reset at end of epoch
93
+ metrics = torchmetrics.MetricCollection({
94
+ 'nll': NLL(),
95
+ 'bpd': BPD(),
96
+ 'ppl': Perplexity(),
97
+ })
98
+ metrics.set_dtype(torch.float64)
99
+ self.train_metrics = metrics.clone(prefix='train/')
100
+ self.valid_metrics = metrics.clone(prefix='val/')
101
+ self.test_metrics = metrics.clone(prefix='test/')
102
+
103
+ # generative perplexity
104
+ self.gen_ppl_metric = Perplexity()
105
+
106
+ self.noise = noise_schedule.get_noise(self.config,
107
+ dtype=self.dtype)
108
+ if self.config.training.ema > 0:
109
+ self.ema = models.ema.ExponentialMovingAverage(
110
+ itertools.chain(self.backbone.parameters(),
111
+ self.noise.parameters()),
112
+ decay=self.config.training.ema)
113
+ else:
114
+ self.ema = None
115
+
116
+ self.lr = self.config.optim.lr
117
+ self.sampling_eps = self.config.training.sampling_eps
118
+ self.time_conditioning = self.config.time_conditioning
119
+ self.neg_infinity = -1000000.0
120
+ self.fast_forward_epochs = None
121
+ self.fast_forward_batches = None
122
+ self._validate_configuration()
123
+
124
+ # subset of data for evaluation
125
+ if eval:
126
+ self.eval_sets_sp = oracle.subset_for_eval(n=config.eval.subset_size)
127
+ self.eval_sets_sp_clss = oracle.subset_eval_groundtruth(self.eval_sets_sp)
128
+ self.eval_sets_sp_preds = oracle.subset_eval_preds(self.eval_sets_sp)
129
+ self.eval_sets_sp_kmers = oracle.subset_eval_kmers(self.eval_sets_sp)
130
+ self.emb_pca = oracle.cal_emb_pca(oracle.subset_for_eval(n=40000), n_components=50)
131
+ self.eval_sets_sp_embs_pca = oracle.subset_eval_embs_pca(self.eval_sets_sp, self.emb_pca)
132
+
133
+ def _validate_configuration(self):
134
+ assert not (self.change_of_variables
135
+ and self.importance_sampling)
136
+ assert self.parameterization == 'subs'
137
+
138
+ def on_load_checkpoint(self, checkpoint):
139
+ if self.ema:
140
+ self.ema.load_state_dict(checkpoint['ema'])
141
+ # Copied from:
142
+ # https://github.com/Dao-AILab/flash-attention/blob/main/training/src/datamodules/language_modeling_hf.py#L41
143
+ self.fast_forward_epochs = checkpoint['loops'][
144
+ 'fit_loop']['epoch_progress']['current']['completed']
145
+ self.fast_forward_batches = checkpoint['loops'][
146
+ 'fit_loop']['epoch_loop.batch_progress'][
147
+ 'current']['completed']
148
+
149
+ def on_save_checkpoint(self, checkpoint):
150
+ if self.ema:
151
+ checkpoint['ema'] = self.ema.state_dict()
152
+ # Copied from:
153
+ # https://github.com/Dao-AILab/flash-attention/blob/main/training/src/tasks/seq.py
154
+ # ['epoch_loop.batch_progress']['total']['completed'] is 1 iteration
155
+ # behind, so we're using the optimizer's progress.
156
+ checkpoint['loops']['fit_loop'][
157
+ 'epoch_loop.batch_progress']['total'][
158
+ 'completed'] = checkpoint['loops']['fit_loop'][
159
+ 'epoch_loop.automatic_optimization.optim_progress'][
160
+ 'optimizer']['step']['total'][
161
+ 'completed'] * self.trainer.accumulate_grad_batches
162
+ checkpoint['loops']['fit_loop'][
163
+ 'epoch_loop.batch_progress']['current'][
164
+ 'completed'] = checkpoint['loops']['fit_loop'][
165
+ 'epoch_loop.automatic_optimization.optim_progress'][
166
+ 'optimizer']['step']['current'][
167
+ 'completed'] * self.trainer.accumulate_grad_batches
168
+ # _batches_that_stepped tracks the number of global steps, not the number
169
+ # of local steps, so we don't multiply with self.trainer.accumulate_grad_batches here.
170
+ checkpoint['loops']['fit_loop'][
171
+ 'epoch_loop.state_dict'][
172
+ '_batches_that_stepped'] = checkpoint['loops']['fit_loop'][
173
+ 'epoch_loop.automatic_optimization.optim_progress'][
174
+ 'optimizer']['step']['total']['completed']
175
+ if 'sampler' not in checkpoint.keys():
176
+ checkpoint['sampler'] = {}
177
+ if hasattr(self.trainer.train_dataloader.sampler,
178
+ 'state_dict'):
179
+ sampler_state_dict = self.trainer.\
180
+ train_dataloader.sampler.state_dict()
181
+ checkpoint['sampler'][
182
+ 'random_state'] = sampler_state_dict.get(
183
+ 'random_state', None)
184
+ else:
185
+ checkpoint['sampler']['random_state'] = None
186
+
187
+ def on_train_start(self):
188
+ if self.ema:
189
+ self.ema.move_shadow_params_to_device(self.device)
190
+ # Adapted from:
191
+ # https://github.com/Dao-AILab/flash-attention/blob/main/training/src/datamodules/language_modeling_hf.py
192
+ distributed = (
193
+ self.trainer._accelerator_connector.use_distributed_sampler
194
+ and self.trainer._accelerator_connector.is_distributed)
195
+
196
+ print('distributed:', distributed)
197
+ if distributed:
198
+ sampler_cls = dataloader_gosai.FaultTolerantDistributedSampler
199
+ else:
200
+ sampler_cls = dataloader_gosai.RandomFaultTolerantSampler
201
+
202
+ updated_dls = []
203
+ for dl in self.trainer.fit_loop._combined_loader.flattened:
204
+ if hasattr(dl.sampler, 'shuffle'):
205
+ dl_sampler = sampler_cls(
206
+ dl.dataset, shuffle=dl.sampler.shuffle)
207
+ else:
208
+ dl_sampler = sampler_cls(dl.dataset)
209
+ if (distributed
210
+ and self.fast_forward_epochs is not None
211
+ and self.fast_forward_batches is not None):
212
+ dl_sampler.load_state_dict({
213
+ 'epoch': self.fast_forward_epochs,
214
+ 'counter': (self.fast_forward_batches
215
+ * self.config.loader.batch_size)})
216
+ updated_dls.append(
217
+ torch.utils.data.DataLoader(
218
+ dl.dataset,
219
+ batch_size=self.config.loader.batch_size,
220
+ num_workers=self.config.loader.num_workers,
221
+ pin_memory=self.config.loader.pin_memory,
222
+ sampler=dl_sampler,
223
+ shuffle=False,
224
+ persistent_workers=True))
225
+ self.trainer.fit_loop._combined_loader.flattened = updated_dls
226
+
227
+ def optimizer_step(self, *args, **kwargs):
228
+ super().optimizer_step(*args, **kwargs)
229
+ if self.ema:
230
+ self.ema.update(itertools.chain(
231
+ self.backbone.parameters(),
232
+ self.noise.parameters()))
233
+
234
+ def _subs_parameterization(self, logits, xt):
235
+ logits[:, :, self.mask_index] += self.neg_infinity
236
+ logits = logits - torch.logsumexp(logits, dim=-1,
237
+ keepdim=True)
238
+
239
+ unmasked_indices = (xt != self.mask_index)
240
+ logits[unmasked_indices] = self.neg_infinity
241
+ logits[unmasked_indices, xt[unmasked_indices]] = 0
242
+ return logits
243
+
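`_subs_parameterization` does two things: at masked positions it forces the mask token to (near) zero probability and renormalizes, and at unmasked positions it replaces the prediction with a one-hot on the observed token. A single-position, standard-library sketch of that behavior follows; `subs_log_probs` is a hypothetical helper and the logits are made-up values.

```python
import math

NEG_INF = -1000000.0  # same sentinel value as self.neg_infinity


def subs_log_probs(logits, token, mask_index):
    """Log-probabilities for one position under the SUBS parameterization."""
    if token != mask_index:
        # Unmasked input: carry the observed token over with probability 1.
        return [0.0 if i == token else NEG_INF for i in range(len(logits))]
    shifted = list(logits)
    shifted[mask_index] += NEG_INF  # the model never predicts the mask token
    top = max(shifted)
    lse = top + math.log(sum(math.exp(v - top) for v in shifted))
    return [v - lse for v in shifted]


logits = [0.5, -1.0, 2.0, 0.0, 3.0]  # A, C, G, T, mask
masked = subs_log_probs(logits, token=4, mask_index=4)
unmasked = subs_log_probs(logits, token=2, mask_index=4)
```

This is why taking an argmax over the model output never returns the mask index, even when the raw mask logit is the largest one, as it is in this example.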
244
+ def _process_sigma(self, sigma):
245
+ if sigma is None:
246
+ assert self.parameterization == 'ar'
247
+ return sigma
248
+ if sigma.ndim > 1:
249
+ sigma = sigma.squeeze(-1)
250
+ if not self.time_conditioning:
251
+ sigma = torch.zeros_like(sigma)
252
+ assert sigma.ndim == 1, sigma.shape
253
+ return sigma
254
+
255
+ def forward(self, x, sigma, binary_clss=None):
256
+ """Returns log score."""
257
+ sigma = self._process_sigma(sigma)
258
+ with torch.cuda.amp.autocast(dtype=torch.float32):
259
+ logits = self.backbone(x, sigma, cls=binary_clss)
260
+
261
+ if self.parameterization == 'subs':
262
+ return self._subs_parameterization(logits=logits, xt=x)
263
+
264
+ return logits
265
+
266
+ def _compute_loss(self, batch, prefix):
267
+ if 'attention_mask' in batch:
268
+ attention_mask = batch['attention_mask']
269
+ else:
270
+ attention_mask = None
271
+ # classifier-free guidance
272
+ assert self.config.model.cls_free_guidance
273
+ binary_clss = (batch['clss'][:,0] > self.config.model.cls_free_threshold).long()
274
+ random_list = np.random.binomial(1, self.config.model.cls_free_prob, binary_clss.shape[0])
275
+ binary_clss[random_list==1] = 2
276
+ losses = self._loss(batch['seqs'], attention_mask, binary_clss)
277
+ loss = losses.loss
278
+
279
+ if prefix == 'train':
280
+ self.train_metrics.update(losses.nlls, losses.token_mask)
281
+ metrics = self.train_metrics
282
+ elif prefix == 'val':
283
+ self.valid_metrics.update(losses.nlls, losses.token_mask)
284
+ metrics = self.valid_metrics
285
+ elif prefix == 'test':
286
+ self.test_metrics.update(losses.nlls, losses.token_mask)
287
+ metrics = self.test_metrics
288
+ else:
289
+ raise ValueError(f'Invalid prefix: {prefix}')
290
+
291
+ self.log_dict(metrics,
292
+ on_step=False,
293
+ on_epoch=True,
294
+ sync_dist=True)
295
+ return loss
296
+
297
+ def on_train_epoch_start(self):
298
+ self.backbone.train()
299
+ self.noise.train()
300
+
301
+ def training_step(self, batch, batch_idx):
302
+ loss = self._compute_loss(batch, prefix='train')
303
+ self.log(name='trainer/loss',
304
+ value=loss.item(),
305
+ on_step=True,
306
+ on_epoch=False,
307
+ sync_dist=True)
308
+ return loss
309
+
310
+ def on_validation_epoch_start(self):
311
+ if self.ema:
312
+ self.ema.store(itertools.chain(
313
+ self.backbone.parameters(),
314
+ self.noise.parameters()))
315
+ self.ema.copy_to(itertools.chain(
316
+ self.backbone.parameters(),
317
+ self.noise.parameters()))
318
+ self.backbone.eval()
319
+ self.noise.eval()
320
+ assert self.valid_metrics.nll.mean_value == 0
321
+ assert self.valid_metrics.nll.weight == 0
322
+
323
+ def validation_step(self, batch, batch_idx):
324
+ return self._compute_loss(batch, prefix='val')
325
+
326
+ def on_validation_epoch_end(self):
327
+ if ((self.config.eval.compute_perplexity_on_sanity
328
+ or not self.trainer.sanity_checking)
329
+ and self.config.eval.generate_samples
330
+ and not self.parameterization == 'ar'):
331
+ all_samples, all_detokenized_samples = [], []
332
+ for _ in range(
333
+ self.config.sampling.num_sample_batches):
334
+ samples = self._sample(cls=1).detach().cpu().numpy()
335
+ detokenized_samples = dataloader_gosai.batch_dna_detokenize(samples)
336
+ all_samples.append(samples)
337
+ all_detokenized_samples.extend(detokenized_samples)
338
+ all_samples = np.concatenate(all_samples, axis=0)
339
+ generated_preds = oracle.cal_gosai_pred(all_detokenized_samples, mode='eval')[:,0]
340
+ avg_generated_preds = np.mean(generated_preds, axis=0)
341
+
342
+ current_step = self.trainer.global_step
343
+ LOGGER.info(f'Current step: {current_step}')
344
+ LOGGER.info(f'Generated preds: {avg_generated_preds}')
345
+ self.log('val/gosai_preds_avg', avg_generated_preds, on_step=False, on_epoch=True, sync_dist=True)
346
+
347
+ if self.ema:
348
+ self.ema.restore(
349
+ itertools.chain(self.backbone.parameters(),
350
+ self.noise.parameters()))
351
+
352
+ def configure_optimizers(self):
353
+ # TODO(yair): Lightning currently giving this warning when using `fp16`:
354
+ # "Detected call of `lr_scheduler.step()` before `optimizer.step()`. "
355
+ # Not clear if this is a problem or not.
356
+ # See: https://github.com/Lightning-AI/pytorch-lightning/issues/5558
357
+ optimizer = torch.optim.AdamW(
358
+ itertools.chain(self.backbone.parameters(),
359
+ self.noise.parameters()),
360
+ lr=self.config.optim.lr,
361
+ betas=(self.config.optim.beta1,
362
+ self.config.optim.beta2),
363
+ eps=self.config.optim.eps,
364
+ weight_decay=self.config.optim.weight_decay)
365
+
366
+ scheduler = hydra.utils.instantiate(
367
+ self.config.lr_scheduler, optimizer=optimizer)
368
+ scheduler_dict = {
369
+ 'scheduler': scheduler,
370
+ 'interval': 'step',
371
+ 'monitor': 'val/loss',
372
+ 'name': 'trainer/lr',
373
+ }
374
+ return [optimizer], [scheduler_dict]
375
+
376
+ def q_xt(self, x, move_chance):
377
+ """Computes the noisy sample xt.
378
+
379
+ Args:
380
+ x: int torch.Tensor with shape (batch_size,
381
+ diffusion_model_input_length), input.
382
+ move_chance: float torch.Tensor with shape (batch_size, 1).
383
+ """
384
+ move_indices = torch.rand(
385
+ * x.shape, device=x.device) < move_chance
386
+ xt = torch.where(move_indices, self.mask_index, x)
387
+ return xt
388
+
389
+ def _sample_prior(self, *batch_dims):
390
+ return self.mask_index * torch.ones(
391
+ * batch_dims, dtype=torch.int64)
392
+
393
+ def _ddpm_caching_update(self, x, t, dt, p_x0=None):
394
+ assert self.config.noise.type == 'loglinear'
395
+ sigma_t, _ = self.noise(t)
396
+ if t.ndim > 1:
397
+ t = t.squeeze(-1)
398
+ assert t.ndim == 1
399
+ move_chance_t = t[:, None, None]
400
+ move_chance_s = (t - dt)[:, None, None]
401
+ assert move_chance_t.ndim == 3, move_chance_t.shape
402
+ if p_x0 is None:
403
+ p_x0 = self.forward(x, sigma_t).exp()
404
+
405
+ assert move_chance_t.ndim == p_x0.ndim
406
+ q_xs = p_x0 * (move_chance_t - move_chance_s)
407
+ q_xs[:, :, self.mask_index] = move_chance_s[:, :, 0]
408
+ _x = _sample_categorical(q_xs)
409
+
410
+ copy_flag = (x != self.mask_index).to(x.dtype)
411
+ return p_x0, copy_flag * x + (1 - copy_flag) * _x
412
+
413
+ def _ddpm_update(self, x, t, dt, cls, w):
414
+ sigma_t, _ = self.noise(t)
415
+ sigma_s, _ = self.noise(t - dt)
416
+ if sigma_t.ndim > 1:
417
+ sigma_t = sigma_t.squeeze(-1)
418
+ if sigma_s.ndim > 1:
419
+ sigma_s = sigma_s.squeeze(-1)
420
+ assert sigma_t.ndim == 1, sigma_t.shape
421
+ assert sigma_s.ndim == 1, sigma_s.shape
422
+ move_chance_t = 1 - torch.exp(-sigma_t)
423
+ move_chance_s = 1 - torch.exp(-sigma_s)
424
+ move_chance_t = move_chance_t[:, None, None]
425
+ move_chance_s = move_chance_s[:, None, None]
426
+ unet_conditioning = sigma_t
427
+ uncond = (2 * torch.ones(x.shape[0], device=x.device)).long()
428
+ cond = (cls * torch.ones(x.shape[0], device=x.device)).long()
429
+ log_p_x0_uncond = self.forward(x, unet_conditioning, uncond)
430
+ log_p_x0_cond = self.forward(x, unet_conditioning, cond)
431
+ log_p_x0 = (1+w) * log_p_x0_cond - w * log_p_x0_uncond
432
+ assert move_chance_t.ndim == log_p_x0.ndim
433
+ q_xs = log_p_x0.exp() * (move_chance_t
434
+ - move_chance_s)
435
+ q_xs[:, :, self.mask_index] = move_chance_s[:, :, 0]
436
+ _x = _sample_categorical(q_xs)
437
+
438
+ copy_flag = (x != self.mask_index).to(x.dtype)
439
+ return copy_flag * x + (1 - copy_flag) * _x
440
+
441
+ def _ar_sampler(self, bsz):
442
+ # precompute token buffer
443
+ num_pred_tokens = self.config.model.length - 1
444
+ x = torch.zeros(
445
+ (bsz, num_pred_tokens + 1),
446
+ dtype=torch.long,
447
+ device=self.device)
448
+ x[:, 0] = self.tokenizer.bos_token_id
449
+ # precompute noise
450
+ noise = (torch.distributions.Gumbel(0, 1)
451
+ .sample((bsz, num_pred_tokens, self.vocab_size))
452
+ .to(self.device))
453
+ for i in range(num_pred_tokens):
454
+ next_logits = self.forward(x[:, :i + 1], None)[:, -1]
455
+ y = (next_logits + noise[:, i]).argmax(-1)
456
+ x[:, i + 1] = y
457
+ return x
458
+
459
+ @torch.no_grad()
460
+ def _sample(self, num_steps=None, eps=1e-5, eval_sp_size=None, cls=1, w=None):
461
+ """Generate samples from the model."""
462
+ if w is None:
463
+ w = self.config.model.cls_free_weight
464
+ if eval_sp_size is None:
465
+ batch_size_per_gpu = self.config.loader.eval_batch_size
466
+ else:
467
+ batch_size_per_gpu = eval_sp_size
468
+ if self.parameterization == 'ar':
469
+ return self._ar_sampler(batch_size_per_gpu)
470
+ if num_steps is None:
471
+ num_steps = self.config.sampling.steps
472
+ x = self._sample_prior(
473
+ batch_size_per_gpu,
474
+ self.config.model.length).to(self.device)
475
+ timesteps = torch.linspace(
476
+ 1, eps, num_steps + 1, device=self.device)
477
+ dt = (1 - eps) / num_steps
478
+ p_x0_cache = None
479
+
480
+ for i in range(num_steps):
481
+ t = timesteps[i] * torch.ones(
482
+ x.shape[0], 1, device=self.device)
483
+ if self.sampler == 'ddpm':
484
+ x = self._ddpm_update(x, t, dt, cls, w)
485
+ else:
486
+ raise NotImplementedError
487
+
488
+ if self.config.sampling.noise_removal:
489
+ t = timesteps[-1] * torch.ones(x.shape[0], 1,
490
+ device=self.device)
491
+ unet_conditioning = self.noise(t)[0]
492
+
493
+ uncond = (2 * torch.ones(x.shape[0], device=x.device)).long()
494
+ cond = (cls * torch.ones(x.shape[0], device=x.device)).long()
495
+ log_p_x0_uncond = self.forward(x, unet_conditioning, uncond)
496
+ log_p_x0_cond = self.forward(x, unet_conditioning, cond)
497
+
498
+ logits = (1+w) * log_p_x0_cond - w * log_p_x0_uncond
499
+ x = logits[:, :, :-1].argmax(dim=-1)
500
+
501
+ return x
502
+
503
+ def get_score(self, x, sigma):
504
+ model_output = self.forward(x, sigma)
505
+ if self.parameterization == 'subs':
506
+ # score(x, t) = p_t(y) / p_t(x)
507
+ # => log score(x, t) = log p_t(y) - log p_t(x)
508
+
509
+ # case 1: x = masked
510
+ # (i) y = unmasked
511
+ # log score(x, t) = log p_\theta(x)|_y + log k
512
+ # where k = exp(- sigma) / (1 - exp(- sigma))
513
+ # (ii) y = masked
514
+ # log score(x, t) = 0
515
+
516
+ # case 2: x = unmasked
517
+ # (i) y != masked, y != x
518
+ # log score(x_i, t) = - inf
519
+ # (ii) y = x
520
+ # log score(x_i, t) = 0
521
+ # (iii) y = masked token
522
+ # log score(x_i, t) = - log k
523
+ # where k = exp(- sigma) / (1 - exp(- sigma))
524
+
525
+ log_k = - torch.log(torch.expm1(sigma)).squeeze(-1)
526
+ assert log_k.ndim == 1
527
+
528
+ masked_score = model_output + log_k[:, None, None]
529
+ masked_score[:, :, self.mask_index] = 0
530
+
531
+ unmasked_score = self.neg_infinity * torch.ones_like(
532
+ model_output)
533
+ unmasked_score = torch.scatter(
534
+ unmasked_score,
535
+ -1,
536
+ x[..., None],
537
+ torch.zeros_like(unmasked_score[..., :1]))
538
+ unmasked_score[:, :, self.mask_index] = - (
539
+ log_k[:, None] * torch.ones_like(x))
540
+
541
+ masked_indices = (x == self.mask_index).to(
542
+ model_output.dtype)[:, :, None]
543
+ model_output = (
544
+ masked_score * masked_indices
545
+ + unmasked_score * (1 - masked_indices))
546
+ return model_output.exp()
547
+
548
+ def _staggered_score(self, score, dsigma):
549
+ score = score.clone()
550
+ extra_const = (1 - dsigma.exp()) * score.sum(dim=-1)
551
+ score *= dsigma.exp()[:, None]
552
+ score[..., self.mask_index] += extra_const
553
+ return score
554
+
555
+ def _analytic_update(self, x, t, step_size):
556
+ curr_sigma, _ = self.noise(t)
557
+ next_sigma, _ = self.noise(t - step_size)
558
+ dsigma = curr_sigma - next_sigma
559
+ score = self.get_score(x, curr_sigma)
560
+ stag_score = self._staggered_score(score, dsigma)
561
+ probs = stag_score * self._transp_transition(x, dsigma)
562
+ return _sample_categorical(probs)
563
+
564
+ def _denoiser_update(self, x, t):
565
+ sigma, _ = self.noise(t)
566
+ score = self.get_score(x, sigma)
567
+ stag_score = self._staggered_score(score, sigma)
568
+ probs = stag_score * self._transp_transition(x, sigma)
569
+ probs[..., self.mask_index] = 0
570
+ samples = _sample_categorical(probs)
571
+ return samples
572
+
573
+ def _transp_transition(self, i, sigma):
574
+ sigma = _unsqueeze(sigma, reference=i[..., None])
575
+ edge = torch.exp(-sigma) * F.one_hot(
576
+ i, num_classes=self.vocab_size)
577
+ edge += torch.where(i == self.mask_index,
578
+ 1 - torch.exp(-sigma).squeeze(-1),
579
+ 0)[..., None]
580
+ return edge
581
+
582
+ def _sample_t(self, n, device):
583
+ _eps_t = torch.rand(n, device=device)
584
+ if self.antithetic_sampling:
585
+ # for variance reduction
586
+ offset = torch.arange(n, device=device) / n
587
+ _eps_t = (_eps_t / n + offset) % 1
588
+ t = (1 - self.sampling_eps) * _eps_t + self.sampling_eps
589
+ if self.importance_sampling:
590
+ return self.noise.importance_sampling_transformation(t)
591
+ return t
592
+
593
+ def _maybe_sub_sample(self, x0, attention_mask):
594
+ seqlen = x0.shape[1]
595
+ if seqlen > self.config.model.length:
596
+ raise NotImplementedError('Sub-sampling not implemented')
597
+ elif self.parameterization == 'ar':
598
+ input_tokens = x0[:, :-1]
599
+ output_tokens = x0[:, 1:]
600
+ new_attention_mask = attention_mask[:, 1:]
601
+ else:
602
+ input_tokens = x0
603
+ output_tokens = None
604
+ new_attention_mask = attention_mask
605
+ return input_tokens, output_tokens, new_attention_mask
606
+
607
+ def _reconstruction_loss(self, x0):
608
+ t0 = torch.zeros(x0.shape[0], dtype=self.dtype,
609
+ device=self.device)
610
+ assert self.config.noise.type == 'loglinear'
611
+ # The above assert is for d3pm parameterization
612
+ unet_conditioning = self.noise(t0)[0][:, None]
613
+ model_output_t0 = self.forward(x0, unet_conditioning)
614
+ return - torch.gather(input=model_output_t0,
615
+ dim=-1,
616
+ index=x0[:, :, None]).squeeze(-1)
617
+
618
+ def _forward_pass_diffusion(self, x0, binary_clss=None):
619
+ t = self._sample_t(x0.shape[0], x0.device)
620
+ if self.T > 0:
621
+ # else ts are between 0 and 1
622
+ t = (t * self.T).to(torch.int)
623
+ t = t / self.T
624
+ # t \in {1/T, 2/T, ..., 1}
625
+ t += (1 / self.T)
626
+
627
+ if self.change_of_variables: # False
628
+ unet_conditioning = t[:, None]
629
+ f_T = torch.log1p(- torch.exp(- self.noise.sigma_max))
630
+ f_0 = torch.log1p(- torch.exp(- self.noise.sigma_min))
631
+ move_chance = torch.exp(f_0 + t * (f_T - f_0))
632
+ move_chance = move_chance[:, None]
633
+ else:
634
+ sigma, dsigma = self.noise(t) # total noise, rate noise
635
+ unet_conditioning = sigma[:, None]
636
+ move_chance = 1 - torch.exp(-sigma[:, None])
637
+
638
+ xt = self.q_xt(x0, move_chance) # q(xt|x0)
639
+ model_output = self.forward(xt, unet_conditioning, binary_clss=binary_clss)
640
+ utils.print_nans(model_output, 'model_output')
641
+
642
+ if self.parameterization == 'sedd':
643
+ return dsigma[:, None] * self._score_entropy(
644
+ model_output, sigma[:, None], xt, x0)
645
+
646
+ if self.T > 0:
647
+ diffusion_loss = self._d3pm_loss(
648
+ model_output=model_output, xt=xt, x0=x0, t=t)
649
+ if self.parameterization == 'd3pm':
650
+ reconstruction_loss = self._reconstruction_loss(x0)
651
+ elif self.parameterization == 'subs':
652
+ reconstruction_loss = 0
653
+ return reconstruction_loss + diffusion_loss
654
+
655
+ # SUBS parameterization, continuous time.
656
+ log_p_theta = torch.gather(
657
+ input=model_output,
658
+ dim=-1,
659
+ index=x0[:, :, None]).squeeze(-1)
660
+
661
+ if self.change_of_variables or self.importance_sampling:
662
+ return log_p_theta * torch.log1p(
663
+ - torch.exp(- self.noise.sigma_min))
664
+
665
+ return - log_p_theta * (
666
+ dsigma / torch.expm1(sigma))[:, None]
667
+
668
+ def _loss(self, x0, attention_mask, binary_clss):
669
+ (input_tokens, output_tokens,
670
+ attention_mask) = self._maybe_sub_sample(
671
+ x0, attention_mask)
672
+
673
+ if self.parameterization == 'ar':
674
+ logprobs = self.backbone(input_tokens, None, cls=binary_clss)
675
+ loss = - logprobs.gather(
676
+ -1, output_tokens[:, :, None])[:, :, 0]
677
+ else:
678
+ loss = self._forward_pass_diffusion(input_tokens, binary_clss=binary_clss)
679
+
680
+ nlls = loss * attention_mask
681
+ count = attention_mask.sum()
682
+
683
+ batch_nll = nlls.sum()
684
+ token_nll = batch_nll / count
685
+
686
+ return Loss(loss=token_nll,
687
+ nlls=nlls,
688
+ token_mask=attention_mask)
689
+
690
+ def _score_entropy(self, log_score, sigma, xt, x0):
691
+ """Computes the SEDD loss.
692
+
693
+ Args:
694
+ log_score: float torch.Tensor with shape (batch_size,
695
+ diffusion_model_input_length, vocab_size),
696
+ log score, output of the denoising network.
697
+ xt: int torch.Tensor with shape (batch_size,
698
+ diffusion_model_input_length), input.
699
+ x0: int torch.Tensor with shape (batch_size,
700
+ diffusion_model_input_length), input.
701
+ sigma: float torch.Tensor with shape (batch_size, 1).
702
+
703
+ Returns:
704
+ loss with shape (batch_size, diffusion_model_input_length)
705
+ """
706
+ # Only masked positions (xt = M) contribute; the target token is y = x0.
706
+ # The constant term q_ratio * (log(q_ratio) - 1) does not depend on the model; it offsets the loss so that its minimum is zero.
708
+ masked_indices = xt == self.mask_index
709
+
710
+ expsig_minus_1 = torch.expm1(sigma).expand_as(xt)
711
+ q_ratio = 1 / expsig_minus_1[masked_indices]
712
+
713
+ words_that_were_masked = x0[masked_indices]
714
+
715
+ neg_term = q_ratio * torch.gather(
716
+ log_score[masked_indices],
717
+ -1,
718
+ words_that_were_masked[..., None]).squeeze(-1)
719
+ score = log_score[masked_indices].exp()
720
+ if self.mask_index == self.vocab_size - 1:
721
+ pos_term = score[:, :-1].sum(dim=-1)
722
+ else:
723
+ pos_term = score[:, : self.mask_index].sum(
724
+ dim=-1) + score[:, self.mask_index + 1:].sum(dim=-1)
725
+ const = q_ratio * (q_ratio.log() - 1)
726
+
727
+ entropy = torch.zeros(* xt.shape, device=xt.device)
728
+ entropy[masked_indices] += pos_term - neg_term + const
729
+ return entropy
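
The guided reverse step in `_ddpm_update` above can be easier to follow on a single masked position. The following is a minimal sketch in plain Python, not the repository's implementation: the five-token vocabulary (A, C, G, T, MASK with the mask last), the `NEG_INF` stand-in for the mask logit, and the helper name `guided_reverse_probs` are all assumptions made for illustration.

```python
import math


def guided_reverse_probs(log_p_cond, log_p_uncond, w,
                         move_chance_t, move_chance_s, mask_index):
    """Unnormalised reverse-step weights for one masked position.

    log_p_cond / log_p_uncond: per-token log-probabilities of the clean token
    from the conditional and unconditional (class id 2) model calls.
    """
    # Classifier-free guidance: extrapolate away from the unconditional model.
    log_p_x0 = [(1 + w) * c - w * u
                for c, u in zip(log_p_cond, log_p_uncond)]
    # Weight for unmasking into each token during the step t -> s.
    q_xs = [math.exp(lp) * (move_chance_t - move_chance_s) for lp in log_p_x0]
    # Weight for staying masked.
    q_xs[mask_index] = move_chance_s
    return q_xs


# Hypothetical toy vocabulary: A, C, G, T, MASK (mask is index 4).
NEG_INF = -1e6  # assumed stand-in for the model's very negative mask logit
cond = [math.log(p) for p in (0.7, 0.1, 0.1, 0.1)] + [NEG_INF]
uncond = [math.log(0.25)] * 4 + [NEG_INF]

plain = guided_reverse_probs(cond, uncond, 0.0, 0.5, 0.4, mask_index=4)
guided = guided_reverse_probs(cond, uncond, 1.0, 0.5, 0.4, mask_index=4)
```

With `w = 0.0` the weights sum to `move_chance_t`. A positive `w` boosts the tokens that the conditional model prefers, and the weights are then no longer normalised, so whatever sampler consumes them has to tolerate unnormalised inputs.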
tr2d2-dna/env.sh ADDED
@@ -0,0 +1,20 @@
1
+ conda install pytorch==2.3.1 torchvision==0.18.1 torchaudio==2.3.1 pytorch-cuda=12.1 -c pytorch -c nvidia
2
+ pip install packaging
3
+ pip install ninja
4
+ pip install transformers
5
+ pip install datasets
6
+ pip install omegaconf
7
+ conda install ipykernel
8
+ python -m ipykernel install --user --name tr2d2 --display-name "Python (tr2d2)"
9
+ pip install hydra-core --upgrade
10
+ pip install hydra-submitit-launcher
11
+
12
+ # for mdlm
13
+ pip install causal-conv1d
14
+ pip install lightning
15
+ pip install timm
16
+ pip install rich
17
+
18
+ pip install scipy
19
+ pip install wandb
20
+ pip install gReLU
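
After running `env.sh`, it may help to confirm that the packages can be found before launching a job. The sketch below uses only the standard library. The import names in `REQUIRED_MODULES` are assumptions, because the pip name and the import name differ for some packages (for example `hydra-core` is imported as `hydra`), and `missing_modules` is a hypothetical helper, not part of the repository.

```python
import importlib.util

# Assumed import names for the packages that env.sh installs.
REQUIRED_MODULES = [
    "torch", "transformers", "datasets", "omegaconf", "hydra",
    "lightning", "timm", "rich", "scipy", "wandb", "grelu",
]


def missing_modules(names):
    """Return the names that cannot be located on the current Python path."""
    return [name for name in names if importlib.util.find_spec(name) is None]


if __name__ == "__main__":
    missing = missing_modules(REQUIRED_MODULES)
    if missing:
        print("Missing modules:", ", ".join(missing))
    else:
        print("All listed modules were found.")
```

`find_spec` only locates a module without importing it, so this check stays fast and does not initialise CUDA.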
tr2d2-dna/eval_runs_batch.py ADDED
@@ -0,0 +1,347 @@
1
+ #!/usr/bin/env python3
2
+ """
3
+ Batch evaluation script for multiple runs with checkpoints.
4
+
5
+ This script:
6
+ 1. Scans a folder containing different runs
7
+ 2. For each run, finds checkpoints and selects the one with the largest epoch/step number
8
+ 3. Evaluates that checkpoint and saves results indexed by run folder name
9
+ """
10
+
11
+ import os
12
+ import re
13
+ import glob
14
+ import argparse
15
+ from pathlib import Path
16
+ from diffusion import Diffusion
17
+ import dataloader_gosai
18
+ import matplotlib.pyplot as plt
19
+ import numpy as np
20
+ import pandas as pd
21
+ import oracle
22
+ from scipy.stats import pearsonr
23
+ import torch
24
+ from tqdm import tqdm
25
+ from eval_utils import get_eval_matrics
26
+ from hydra import initialize, compose
27
+ from hydra.core.global_hydra import GlobalHydra
28
+ from dataclasses import dataclass
29
+ from datetime import datetime
30
+ import json
31
+
32
+
33
+ @dataclass
34
+ class Args:
35
+ total_num_steps: int
36
+ batch_size: int
37
+ num_seeds: int
38
+ total_samples: int
39
+ seq_length: int
40
+
41
+
42
+ def find_latest_checkpoint(run_dir):
43
+ """
44
+ Find the checkpoint with the largest epoch/step number in a run directory.
45
+
46
+ Args:
47
+ run_dir (str): Path to the run directory
48
+
49
+ Returns:
50
+ str or None: Path to the latest checkpoint, or None if no checkpoints found
51
+ """
52
+ ckpt_pattern = os.path.join(run_dir, "model_*.ckpt")
53
+ ckpt_files = glob.glob(ckpt_pattern)
54
+
55
+ if not ckpt_files:
56
+ return None
57
+
58
+ # Extract step numbers from checkpoint filenames
59
+ step_numbers = []
60
+ for ckpt_file in ckpt_files:
61
+ filename = os.path.basename(ckpt_file)
62
+ match = re.search(r'model_(\d+)\.ckpt', filename)
63
+ if match:
64
+ step_numbers.append((int(match.group(1)), ckpt_file))
65
+
66
+ if not step_numbers:
67
+ return None
68
+
69
+ # Return checkpoint with largest step number
70
+ step_numbers.sort(key=lambda x: x[0], reverse=True)
71
+ return step_numbers[0][1]
72
+
73
+
74
+ def evaluate_checkpoint(checkpoint_path, args, cfg, pretrained_model, gosai_oracle,
75
+ cal_atac_pred_new_mdl, highexp_kmers_999, n_highexp_kmers_999, device):
76
+ """
77
+ Evaluate a single checkpoint.
78
+
79
+ Args:
80
+ checkpoint_path (str): Path to the checkpoint file
81
+ args: Evaluation arguments
82
+ cfg: Configuration object
83
+ pretrained_model: Pretrained reference model
84
+ gosai_oracle: GOSAI oracle model
85
+ cal_atac_pred_new_mdl: ATAC prediction model
86
+ highexp_kmers_999: High expression k-mers
87
+ n_highexp_kmers_999: Number of high expression k-mers
88
+ device: Device to run evaluation on
89
+
90
+ Returns:
91
+ tuple: (eval_metrics_agg, total_rewards_agg) containing aggregated results
92
+ """
93
+ # Load the policy model from checkpoint
94
+ policy_model = Diffusion(cfg).to(device)
95
+ policy_model.load_state_dict(torch.load(checkpoint_path, map_location=device))
96
+ policy_model.eval()
97
+
98
+ total_rewards_all = []
99
+ eval_metrics_all = []
100
+
101
+ print(f"Evaluating checkpoint: {os.path.basename(checkpoint_path)}")
102
+
103
+ for i in range(args.num_seeds):
104
+ iter_times = args.total_samples // args.batch_size
105
+ total_samples = []
106
+ total_rewards = []
107
+ range_bar = tqdm(range(iter_times), desc=f"Seed {i+1}", leave=False)
108
+
109
+ for j in range_bar:
110
+ x_eval, mean_reward_eval = policy_model.sample_finetuned(args, gosai_oracle)
111
+ total_samples.append(x_eval)
112
+ total_rewards.append(mean_reward_eval.item() * args.batch_size)
113
+
114
+ total_samples = torch.concat(total_samples)
115
+ eval_metrics = get_eval_matrics(samples=total_samples, ref_model=pretrained_model,
116
+ gosai_oracle=gosai_oracle, cal_atac_pred_new_mdl=cal_atac_pred_new_mdl,
117
+ highexp_kmers_999=highexp_kmers_999, n_highexp_kmers_999=n_highexp_kmers_999)
118
+
119
+ eval_metrics_all.append(eval_metrics)
120
+ total_rewards_all.append(np.sum(total_rewards) / args.total_samples)
121
+
122
+ # Aggregate results
123
+ eval_metrics_agg = {k: (np.mean([eval_metrics[k] for eval_metrics in eval_metrics_all]),
124
+ np.std([eval_metrics[k] for eval_metrics in eval_metrics_all]))
125
+ for k in eval_metrics_all[0].keys()}
126
+ total_rewards_agg = (np.mean(total_rewards_all), np.std(total_rewards_all))
127
+
128
+ return eval_metrics_agg, total_rewards_agg
129
+
130
+
131
+ def save_results(results, output_file):
132
+ """
133
+ Save evaluation results to a text file.
134
+
135
+ Args:
136
+ results (dict): Dictionary containing results for each run
137
+ output_file (str): Path to output file
138
+ """
139
+ with open(output_file, 'w') as f:
140
+ f.write("="*80 + "\n")
141
+ f.write("BATCH EVALUATION RESULTS\n")
142
+ f.write("="*80 + "\n")
143
+ f.write(f"Generated on: {datetime.now().strftime('%Y-%m-%d %H:%M:%S')}\n")
144
+ f.write(f"Total runs evaluated: {len(results)}\n\n")
145
+
146
+ for run_name, result in results.items():
147
+ if result is None:
148
+ f.write(f"RUN: {run_name}\n")
149
+ f.write("-" * 60 + "\n")
150
+ f.write("Status: No checkpoints found or evaluation failed\n\n")
151
+ continue
152
+
153
+ eval_metrics_agg, total_rewards_agg, checkpoint_path = result
154
+
155
+ f.write(f"RUN: {run_name}\n")
156
+ f.write("-" * 60 + "\n")
157
+ f.write(f"Checkpoint: {os.path.basename(checkpoint_path)}\n")
158
+ f.write(f"Full path: {checkpoint_path}\n\n")
159
+
160
+ f.write("📊 EVALUATION METRICS:\n")
161
+ for metric_name in eval_metrics_agg.keys():
162
+ mean_val = eval_metrics_agg[metric_name][0]
163
+ std_val = eval_metrics_agg[metric_name][1]
164
+ f.write(f" {metric_name:<20}: {mean_val:8.4f} ± {std_val:6.4f}\n")
165
+
166
+ f.write(f"\n🎯 TOTAL REWARDS:\n")
167
+ f.write(f" {'Mean':<20}: {total_rewards_agg[0]:8.4f}\n")
168
+ f.write(f" {'Std':<20}: {total_rewards_agg[1]:8.4f}\n")
169
+ f.write("\n")
170
+
171
+ print(f"Results saved to: {output_file}")
172
+
173
+
174
+ def append_single_result(run_name, result, output_file, is_first_run=False):
175
+ """
176
+ Append a single successful run result to the output file.
177
+
178
+ Args:
179
+ run_name (str): Name of the run
180
+ result: Result tuple (eval_metrics_agg, total_rewards_agg, checkpoint_path)
181
+ output_file (str): Path to output file
182
+ is_first_run (bool): Whether this is the first successful run (write header)
183
+ """
184
+ mode = 'w' if is_first_run else 'a'
185
+
186
+ with open(output_file, mode) as f:
187
+ if is_first_run:
188
+ f.write("="*80 + "\n")
189
+ f.write("BATCH EVALUATION RESULTS\n")
190
+ f.write("="*80 + "\n")
191
+ f.write(f"Started on: {datetime.now().strftime('%Y-%m-%d %H:%M:%S')}\n")
192
+ f.write("Results are saved incrementally as each run completes.\n")
193
+ f.write("Only successful evaluations are included.\n\n")
194
+
195
+ eval_metrics_agg, total_rewards_agg, checkpoint_path = result
196
+
197
+ f.write(f"RUN: {run_name}\n")
198
+ f.write("-" * 60 + "\n")
199
+ f.write(f"Checkpoint: {os.path.basename(checkpoint_path)}\n")
200
+ f.write(f"Full path: {checkpoint_path}\n")
201
+ f.write(f"Completed at: {datetime.now().strftime('%Y-%m-%d %H:%M:%S')}\n\n")
202
+
203
+ f.write("📊 EVALUATION METRICS:\n")
204
+ for metric_name in eval_metrics_agg.keys():
205
+ mean_val = eval_metrics_agg[metric_name][0]
206
+ std_val = eval_metrics_agg[metric_name][1]
207
+ f.write(f" {metric_name:<20}: {mean_val:8.4f} ± {std_val:6.4f}\n")
208
+
209
+ f.write(f"\n🎯 TOTAL REWARDS:\n")
210
+ f.write(f" {'Mean':<20}: {total_rewards_agg[0]:8.4f}\n")
211
+ f.write(f" {'Std':<20}: {total_rewards_agg[1]:8.4f}\n")
212
+ f.write("\n" + "="*80 + "\n\n") # Add separator line and extra spacing
213
+
214
+
215
+ def main():
216
+ parser = argparse.ArgumentParser(description="Batch evaluation of multiple runs")
217
+ parser.add_argument("--runs_dir", type=str, required=True,
218
+ help="Directory containing run folders with checkpoints")
219
+ parser.add_argument("--output_file", type=str, default="batch_eval_results.txt",
220
+ help="Output file to save results")
221
+ parser.add_argument("--device", type=str, default="cuda:0",
222
+ help="Device to run evaluation on")
223
+ parser.add_argument("--total_num_steps", type=int, default=128,
224
+ help="Total number of diffusion steps")
225
+ parser.add_argument("--batch_size", type=int, default=128,
226
+ help="Batch size for evaluation")
227
+ parser.add_argument("--num_seeds", type=int, default=3,
228
+ help="Number of random seeds for evaluation")
229
+ parser.add_argument("--total_samples", type=int, default=640,
230
+ help="Total number of samples to generate")
231
+ parser.add_argument("--seq_length", type=int, default=200,
232
+ help="Sequence length")
233
+ parser.add_argument("--pretrained_path", type=str,
234
+ required=True,
235
+ help="Path to pretrained model checkpoint")
236
+
237
+ args = parser.parse_args()
238
+
239
+ # Setup evaluation arguments
240
+ eval_args = Args(
241
+ total_num_steps=args.total_num_steps,
242
+ batch_size=args.batch_size,
243
+ num_seeds=args.num_seeds,
244
+ total_samples=args.total_samples,
245
+ seq_length=args.seq_length
246
+ )
247
+
248
+ device = args.device
249
+
250
+ # Initialize Hydra configuration
251
+ if GlobalHydra().is_initialized():
252
+ GlobalHydra.instance().clear()
253
+
254
+ initialize(config_path="configs_gosai", job_name="batch_eval")
255
+ cfg = compose(config_name="config_gosai.yaml")
256
+
257
+ print("Loading pretrained model and oracles...")
258
+ # Load pretrained model
259
+ pretrained_model = Diffusion.load_from_checkpoint(args.pretrained_path, config=cfg, map_location=device)
260
+ pretrained_model.eval()
261
+
262
+ # Load oracles
263
+ _, _, highexp_kmers_999, n_highexp_kmers_999, _, _, _ = oracle.cal_highexp_kmers(return_clss=True)
264
+ cal_atac_pred_new_mdl = oracle.get_cal_atac_orale(device=device)
265
+ cal_atac_pred_new_mdl.eval()
266
+ gosai_oracle = oracle.get_gosai_oracle(mode='eval', device=device)
267
+ gosai_oracle.eval()
268
+
269
+ print("Scanning for runs...")
270
+ # Find all run directories
271
+ runs_dir = Path(args.runs_dir)
272
+ if not runs_dir.exists():
273
+ print(f"Error: Directory {args.runs_dir} does not exist")
274
+ return
275
+
276
+ run_dirs = [d for d in runs_dir.iterdir() if d.is_dir()]
277
+ run_dirs.sort() # Sort for consistent ordering
278
+
279
+ print(f"Found {len(run_dirs)} run directories")
280
+
281
+ results = {}
282
+ successful_runs = 0
283
+ failed_runs = 0
284
+
285
+ # Process each run
286
+ for i, run_dir in enumerate(tqdm(run_dirs, desc="Processing runs")):
287
+ run_name = run_dir.name
288
+ print(f"\nProcessing run {i+1}/{len(run_dirs)}: {run_name}")
289
+
290
+ # Find latest checkpoint
291
+ latest_ckpt = find_latest_checkpoint(str(run_dir))
292
+
293
+ if latest_ckpt is None:
294
+ print(f" No checkpoints found in {run_name} - skipping")
295
+ failed_runs += 1
296
+ continue # Skip this run entirely, don't save anything to file
297
+
298
+ print(f" Found latest checkpoint: {os.path.basename(latest_ckpt)}")
299
+
300
+ try:
301
+ # Evaluate checkpoint
302
+ eval_metrics_agg, total_rewards_agg = evaluate_checkpoint(
303
+ latest_ckpt, eval_args, cfg, pretrained_model, gosai_oracle,
304
+ cal_atac_pred_new_mdl, highexp_kmers_999, n_highexp_kmers_999, device
305
+ )
306
+
307
+ result = (eval_metrics_agg, total_rewards_agg, latest_ckpt)
308
+ results[run_name] = result
309
+ successful_runs += 1
310
+ print(f" ✓ Evaluation completed successfully")
311
+
312
+ # Save result incrementally (only for successful evaluations)
313
+ is_first_run = (len(results) == 1) # First successful run
314
+ append_single_result(run_name, result, args.output_file, is_first_run=is_first_run)
315
+ print(f" Result saved to {args.output_file}")
316
+
317
+ except Exception as e:
318
+ print(f" ✗ Evaluation failed: {str(e)}")
319
+ failed_runs += 1
320
+ # Don't save failed evaluations to file either
321
+
322
+ # Add final summary to the file (only if there were successful runs)
323
+ if successful_runs > 0:
324
+ with open(args.output_file, 'a') as f:
325
+ f.write("="*80 + "\n")
326
+ f.write("FINAL SUMMARY\n")
327
+ f.write("="*80 + "\n")
328
+ f.write(f"Completed on: {datetime.now().strftime('%Y-%m-%d %H:%M:%S')}\n")
329
+ f.write(f"Total runs processed: {len(run_dirs)}\n")
330
+ f.write(f"Successful evaluations: {successful_runs}\n")
331
+ f.write(f"Failed/skipped runs: {failed_runs}\n")
332
+ else:
333
+ print(f"No successful evaluations - output file {args.output_file} not created")
334
+
335
+ # Print summary
336
+ print(f"\nFinal Summary:")
337
+ print(f" Total runs processed: {len(run_dirs)}")
338
+ print(f" Successful evaluations: {successful_runs}")
339
+ print(f" Failed/skipped runs: {failed_runs}")
340
+ if successful_runs > 0:
341
+ print(f" Results saved to: {args.output_file}")
342
+ else:
343
+ print(f" No output file created (no successful evaluations)")
344
+
345
+
346
+ if __name__ == "__main__":
347
+ main()
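
The checkpoint selection in `find_latest_checkpoint` depends on comparing the numeric suffix as an integer rather than as text. The following sketch reproduces that selection on a list of filenames without touching the filesystem. `pick_latest` and the example filenames are hypothetical.

```python
import re

# Same naming scheme as the script above: model_<number>.ckpt
CKPT_RE = re.compile(r"model_(\d+)\.ckpt$")


def pick_latest(filenames):
    """Return the filename with the largest numeric suffix, or None."""
    numbered = []
    for name in filenames:
        match = CKPT_RE.search(name)
        if match:
            # Convert to int so that 100 sorts after 20 and 9.
            numbered.append((int(match.group(1)), name))
    if not numbered:
        return None
    return max(numbered)[1]


names = ["model_20.ckpt", "model_100.ckpt", "model_9.ckpt", "notes.txt"]
latest = pick_latest(names)  # "model_100.ckpt"
```

A plain string comparison would pick `model_9.ckpt` here, which is why the suffix is converted to an integer first.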
tr2d2-dna/eval_utils.py ADDED
@@ -0,0 +1,29 @@
1
+ import numpy as np
2
+ import torch
3
+ from scipy.stats import pearsonr
4
+ import dataloader_gosai
5
+ import oracle
6
+
7
+
8
+ def compare_kmer(kmer1, kmer2, n_sp1, n_sp2):
9
+ kmer_set = set(kmer1.keys()) | set(kmer2.keys())
10
+ counts = np.zeros((len(kmer_set), 2))
11
+ for i, kmer in enumerate(kmer_set):
12
+ if kmer in kmer1: counts[i][1] = kmer1[kmer] * n_sp2 / n_sp1
13
+ if kmer in kmer2: counts[i][0] = kmer2[kmer]
14
+ return pearsonr(counts[:, 0], counts[:, 1])[0]
15
+
16
+
17
+ def get_eval_matrics(samples, ref_model, gosai_oracle, cal_atac_pred_new_mdl, highexp_kmers_999, n_highexp_kmers_999):
18
+ """samples: [B, 200]"""
19
+ info = {}
20
+ detokenized_samples = dataloader_gosai.batch_dna_detokenize(samples.detach().cpu().numpy()) # [B], strings with length 200
21
+ ref_log_lik = ref_model.get_likelihood(samples, num_steps=128, n_samples=1) # [B]
22
+ info['[log-lik-med]'] = torch.median(ref_log_lik).item()
23
+ preds = oracle.cal_gosai_pred_new(detokenized_samples, gosai_oracle, mode='eval')[:, 0]
24
+ info['[pred-activity-med]'] = np.median(preds).item()
25
+ atac = oracle.cal_atac_pred_new(detokenized_samples, cal_atac_pred_new_mdl)[:, 1]
26
+ info['[atac-acc%]'] = (atac > 0.5).sum().item() / len(samples) * 100
27
+ kmer = oracle.count_kmers(detokenized_samples)
28
+ info['[3-mer-corr]'] = compare_kmer(highexp_kmers_999, kmer, n_highexp_kmers_999, len(detokenized_samples)).item()
29
+ return info
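
The `[3-mer-corr]` metric rescales the reference k-mer counts to the number of generated sequences and then takes a Pearson correlation over the union of k-mers, as `compare_kmer` does above. The sketch below walks through that calculation with the standard library only. `oracle.count_kmers` is not shown in this diff, so the `count_kmers` here is a hypothetical stand-in that counts overlapping 3-mers, and the sequences are made up.

```python
import math
from collections import Counter


def count_kmers(seqs, k=3):
    """Count every overlapping k-mer across a list of sequences."""
    counts = Counter()
    for seq in seqs:
        for i in range(len(seq) - k + 1):
            counts[seq[i:i + k]] += 1
    return counts


def pearson(xs, ys):
    """Plain Pearson correlation; assumes neither input is constant."""
    n = len(xs)
    mx, my = sum(xs) / n, sum(ys) / n
    cov = sum((x - mx) * (y - my) for x, y in zip(xs, ys))
    var_x = sum((x - mx) ** 2 for x in xs)
    var_y = sum((y - my) ** 2 for y in ys)
    return cov / math.sqrt(var_x * var_y)


def kmer_correlation(ref_counts, gen_counts, n_ref, n_gen):
    """Correlate k-mer counts after rescaling the reference to n_gen sequences."""
    kmers = sorted(set(ref_counts) | set(gen_counts))
    ref = [ref_counts.get(kmer, 0) * n_gen / n_ref for kmer in kmers]
    gen = [gen_counts.get(kmer, 0) for kmer in kmers]
    return pearson(gen, ref)


reference = ["ACGTAC", "ACGTTT", "GGGACG", "ACGTAC"]
generated = ["ACGTAC", "GGGACG"]
ref_counts = count_kmers(reference)
gen_counts = count_kmers(generated)
corr = kmer_correlation(ref_counts, gen_counts, len(reference), len(generated))
```

Because Pearson correlation is unchanged by a positive rescaling of either input, the `n_gen / n_ref` factor does not alter the returned value. It only puts the two count vectors on a comparable scale.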
tr2d2-dna/finetune.py ADDED
@@ -0,0 +1,149 @@
1
+ # TR2-D2 finetuning entry point: WDCE loss with optional MCTS search
2
+ from diffusion import Diffusion
3
+ from hydra import initialize, compose
4
+ from hydra.core.global_hydra import GlobalHydra
5
+ import numpy as np
6
+ import oracle
7
+ from scipy.stats import pearsonr
8
+ import torch
9
+ import torch.nn.functional as F
10
+ import argparse
11
+ import wandb
12
+ import os
13
+ import datetime
14
+ from utils import str2bool, set_seed
15
+ from finetune_dna import finetune
16
+ from mcts import MCTS
17
+
18
+ argparser = argparse.ArgumentParser(formatter_class=argparse.ArgumentDefaultsHelpFormatter)
19
+ argparser.add_argument('--base_path', type=str, default="")
20
+ argparser.add_argument('--learning_rate', type=float, default=1e-4)
21
+ argparser.add_argument('--num_epochs', type=int, default=100)
22
+ argparser.add_argument('--num_accum_steps', type=int, default=4)
23
+ argparser.add_argument('--truncate_steps', type=int, default=50)
24
+ argparser.add_argument("--truncate_kl", type=str2bool, default=False)
25
+ argparser.add_argument('--gumbel_temp', type=float, default=1.0)
26
+ argparser.add_argument('--gradnorm_clip', type=float, default=1.0)
27
+ argparser.add_argument('--batch_size', type=int, default=32)
28
+ argparser.add_argument('--name', type=str, default='debug')
29
+ argparser.add_argument('--total_num_steps', type=int, default=128)
30
+ argparser.add_argument('--copy_flag_temp', type=float, default=None)
31
+ argparser.add_argument('--save_every_n_epochs', type=int, default=10)
32
+ argparser.add_argument('--eval_every_n_epochs', type=int, default=200)
33
+ argparser.add_argument('--alpha', type=float, default=0.001)
34
+ argparser.add_argument('--alpha_schedule_warmup', type=int, default=0)
35
+ argparser.add_argument("--seed", type=int, default=0)
36
+ # new
37
+ argparser.add_argument('--run_name', type=str, default='drakes')
38
+ argparser.add_argument("--device", default="cuda:0", type=str)
39
+ argparser.add_argument("--save_path_dir", type=str, required=True)
40
+ argparser.add_argument("--no_mcts", action='store_true', default=False)
41
+ argparser.add_argument("--centering", action='store_true', default=False)
42
+ argparser.add_argument("--reward_clip", action='store_true', default=False)
43
+ argparser.add_argument("--reward_clip_value", type=float, default=15.0)
44
+ argparser.add_argument("--select_topk", action='store_true', default=False)
45
+ argparser.add_argument('--select_topk_value', type=int, default=10)
46
+ argparser.add_argument("--restart_ckpt_path", type=str, default=None)
47
+
48
+
49
+ # mcts
50
+ argparser.add_argument('--num_sequences', type=int, default=10)
51
+ argparser.add_argument('--num_children', type=int, default=50)
52
+ argparser.add_argument('--num_iter', type=int, default=30) # iterations of mcts
53
+ argparser.add_argument('--seq_length', type=int, default=200)
54
+ argparser.add_argument('--time_conditioning', action='store_true', default=False)
55
+ argparser.add_argument('--mcts_sampling', type=int, default=0) # for batched categorical sampling: '0' means gumbel noise
56
+ argparser.add_argument('--buffer_size', type=int, default=100)
57
+ argparser.add_argument('--wdce_num_replicates', type=int, default=16)
58
+ argparser.add_argument('--noise_removal', action='store_true', default=False)
59
+ argparser.add_argument('--grad_clip', action='store_true', default=False)
60
+ argparser.add_argument('--resample_every_n_step', type=int, default=10)
61
+ argparser.add_argument('--exploration', type=float, default=0.1)
62
+ argparser.add_argument('--reset_tree', action='store_true', default=False)
63
+
64
+ # eval
65
+
66
+ args = argparser.parse_args()
67
+ print(args)
68
+
69
+ # pretrained model path
70
+ CKPT_PATH = os.path.join(args.base_path, 'mdlm/outputs_gosai/pretrained.ckpt')
71
+ log_base_dir = os.path.join(args.save_path_dir, 'mdlm/reward_bp_results_final')
72
+
73
+ # reinitialize Hydra
74
+ GlobalHydra.instance().clear()
75
+
76
+ # Initialize Hydra and compose the configuration
77
+ initialize(config_path="configs_gosai", job_name="load_model")
78
+ cfg = compose(config_name="config_gosai.yaml")
79
+ cfg.eval.checkpoint_path = CKPT_PATH
80
+ curr_time = datetime.datetime.now().strftime("%Y%m%d_%H%M%S")
81
+
82
+ if args.no_mcts:
83
+ run_name = f'MDNS_buffer{args.buffer_size}_alpha{args.alpha}_resample{args.resample_every_n_step}_centering{args.centering}_{curr_time}'
84
+ else:
85
+ run_name = f'MCTS_buffer{args.buffer_size}_alpha{args.alpha}_resample{args.resample_every_n_step}_num_iter{args.num_iter}_centering{args.centering}_select_topk{args.select_topk}_select_topk_value{args.select_topk_value}_{curr_time}'
86
+
87
+ args.save_path = os.path.join(args.save_path_dir, run_name)
88
+ os.makedirs(args.save_path, exist_ok=True)
89
+ # wandb init
90
+ wandb.init(project='search-rl', name=run_name, config=args, dir=args.save_path)
91
+
92
+ log_path = os.path.join(args.save_path, 'log.txt')
93
+
94
+ set_seed(args.seed, use_cuda=True)
95
+
96
+ # Initialize the model
97
+ if args.restart_ckpt_path is not None:
98
+ # Resume from saved ckpt
99
+ restart_ckpt_path = os.path.join(args.base_path, args.restart_ckpt_path)
100
+ restart_epoch = restart_ckpt_path.split('_')[-1].split('.')[0]
101
+ args.restart_epoch = restart_epoch
102
+ policy_model = Diffusion(cfg).to(args.device)
103
+ policy_model.load_state_dict(torch.load(restart_ckpt_path, map_location=args.device))
104
+ else:
105
+ # Start from pretrained model
106
+ policy_model = Diffusion.load_from_checkpoint(cfg.eval.checkpoint_path, config=cfg, map_location=args.device)
107
+ pretrained = Diffusion.load_from_checkpoint(cfg.eval.checkpoint_path, config=cfg, map_location=args.device)
108
+ reward_model = oracle.get_gosai_oracle(mode='train', device=args.device)
109
+
110
+ #reward_model_eval = oracle.get_gosai_oracle(mode='eval').to(args.device)
111
+
112
+ reward_model.eval()
113
+ pretrained.eval()
114
+ #reward_model_eval.eval()
115
+
116
+ # define mcts
117
+ mcts = MCTS(args, cfg, policy_model, pretrained, reward_model)
118
+
119
+
120
+ _, _, highexp_kmers_999, n_highexp_kmers_999, _, _, _ = oracle.cal_highexp_kmers(return_clss=True)
121
+
122
+ cal_atac_pred_new_mdl = oracle.get_cal_atac_orale(device=args.device)
123
+ cal_atac_pred_new_mdl.eval()
124
+
125
+ gosai_oracle = oracle.get_gosai_oracle(mode='eval', device=args.device)
126
+ gosai_oracle.eval()
127
+
128
+
129
+ print("args.device:", args.device)
130
+ print("policy_model device:", policy_model.device)
131
+ print("pretrained device:", pretrained.device)
132
+ print("reward_model device:", reward_model.device)
133
+ print("mcts device:", mcts.device)
134
+ print("gosai_oracle device:", gosai_oracle.device)
135
+ print("cal_atac_pred_new_mdl device:", cal_atac_pred_new_mdl.device)
136
+
137
+ eval_model_dict = {
138
+ "gosai_oracle": gosai_oracle,
139
+ "highexp_kmers_999": highexp_kmers_999,
140
+ "n_highexp_kmers_999": n_highexp_kmers_999,
141
+ "cal_atac_pred_new_mdl": cal_atac_pred_new_mdl,
142
+ "gosai_oracle": gosai_oracle
143
+ }
144
+
145
+
146
+ finetune(args = args, cfg = cfg, policy_model = policy_model,
147
+ reward_model = reward_model, mcts = mcts,
148
+ pretrained_model = pretrained,
149
+ eval_model_dict = eval_model_dict)
tr2d2-dna/finetune_dna.py ADDED
@@ -0,0 +1,113 @@
1
+ # finetuning loop with the WDCE loss
2
+ from hydra import initialize, compose
3
+ from hydra.core.global_hydra import GlobalHydra
4
+ import numpy as np
5
+ import oracle
6
+ from scipy.stats import pearsonr
7
+ import torch
8
+ import torch.nn.functional as F
9
+ import argparse
10
+ import wandb
11
+ import os
12
+ import datetime
13
+ from utils import str2bool, set_seed
14
+ # imports
15
+ from finetune_utils import loss_wdce
16
+ from tqdm import tqdm
17
+
18
+ def finetune(args, cfg, policy_model, reward_model, mcts = None, pretrained_model = None, eval_model_dict = None, eps=1e-5):
19
+ """
20
+ Finetuning with WDCE loss
21
+ """
22
+ dt = (1 - eps) / args.total_num_steps
23
+
24
+ if args.no_mcts:
25
+ assert pretrained_model is not None, "pretrained model is required for no mcts"
26
+ else:
27
+ assert mcts is not None, "mcts is required for mcts"
28
+
29
+ # set model to train mode
30
+ policy_model.train()
31
+ torch.set_grad_enabled(True)
32
+ optim = torch.optim.AdamW(policy_model.parameters(), lr=args.learning_rate)
33
+
34
+ # record metrics
35
+ batch_losses = []
36
+ batch_rewards = []
37
+
38
+ # initialize the final seqs and log_rnd of the trajectories that generated those seqs
39
+ x_saved, log_rnd_saved, final_rewards_saved = None, None, None
40
+
41
+ # finetuning loop
42
+ pbar = tqdm(range(args.num_epochs))
43
+ for epoch in pbar:
44
+ # store metrics
45
+ rewards = []
46
+ losses = []
47
+
48
+ policy_model.train()
49
+
50
+ with torch.no_grad():
51
+ if x_saved is None or epoch % args.resample_every_n_step == 0:
52
+ # compute final sequences and trajectory log_rnd
53
+ if args.no_mcts:
54
+ x_final, log_rnd, final_rewards = policy_model.sample_finetuned_with_rnd(args, reward_model, pretrained_model)
55
+ else:
56
+ x_final, log_rnd, final_rewards = mcts.forward(args.reset_tree)
57
+
58
+
59
+ # save for next iteration
60
+ x_saved, log_rnd_saved, final_rewards_saved = x_final, log_rnd, final_rewards
61
+ else:
62
+ x_final, log_rnd, final_rewards = x_saved, log_rnd_saved, final_rewards_saved
63
+
64
+ # compute wdce loss
65
+ loss = loss_wdce(policy_model, log_rnd, x_final, num_replicates=args.wdce_num_replicates, centering=args.centering)
66
+
67
+ # gradient descent
68
+ loss.backward()
69
+
70
+ # optimizer
71
+ if args.grad_clip:
72
+ torch.nn.utils.clip_grad_norm_(policy_model.parameters(), args.gradnorm_clip)
73
+ optim.step()
74
+ optim.zero_grad()
75
+
76
+ pbar.set_postfix(loss=loss.item())
77
+
78
+ losses.append(loss.item())
79
+
80
+ # sample an eval batch with the updated policy to evaluate rewards
81
+ x_eval, mean_reward_eval = policy_model.sample_finetuned(args, reward_model)
82
+
83
+ batch_losses.append(loss.cpu().detach().numpy())
84
+ batch_rewards.append(mean_reward_eval.cpu().detach().item())
85
+ losses.append(loss.cpu().detach().numpy())
86
+
87
+ rewards = np.array(mean_reward_eval.detach().cpu().numpy())
88
+ losses = np.array(losses)
89
+
90
+ mean_reward_search = final_rewards.mean().item()
91
+ min_reward_search = final_rewards.min().item()
92
+ max_reward_search = final_rewards.max().item()
93
+ median_reward_search = final_rewards.median().item()
94
+
95
+
96
+ #reward_losses = np.array(reward_losses)
97
+
98
+ print("epoch %d"%epoch, "mean reward %f"%mean_reward_eval, "mean loss %f"%np.mean(losses))
99
+
100
+ wandb.log({"epoch": epoch, "mean_reward": mean_reward_eval, "mean_loss": np.mean(losses),
101
+ "mean_reward_search": mean_reward_search, "min_reward_search": min_reward_search,
102
+ "max_reward_search": max_reward_search, "median_reward_search": median_reward_search})
103
+
104
+
105
+ if (epoch+1) % args.save_every_n_epochs == 0:
106
+ model_path = os.path.join(args.save_path, f'model_{epoch}.ckpt')
107
+ torch.save(policy_model.state_dict(), model_path)
108
+ print(f"model saved at epoch {epoch}")
109
+
110
+
111
+ wandb.finish()
112
+
113
+ return batch_losses
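To make the epoch bookkeeping in the loop above easier to follow, here is a minimal sketch of which epochs trigger a fresh search and which checkpoint files get written, using the default values of `--num_epochs`, `--resample_every_n_step`, and `--save_every_n_epochs`. The helper names `training_schedule` and `parse_restart_epoch` are hypothetical and exist only for this illustration; the second one mirrors the string split used in `finetune.py` to recover the epoch from a checkpoint path.

```python
def training_schedule(num_epochs, resample_every_n_step, save_every_n_epochs):
    """Return (epochs that run a fresh search, checkpoint file names)."""
    # Epoch 0 always resamples: nothing is cached yet and 0 % n == 0.
    resample_epochs = [e for e in range(num_epochs)
                       if e % resample_every_n_step == 0]
    # Checkpoints are named after the 0-based epoch index, saved when (e + 1) % n == 0.
    checkpoints = [f"model_{e}.ckpt" for e in range(num_epochs)
                   if (e + 1) % save_every_n_epochs == 0]
    return resample_epochs, checkpoints


def parse_restart_epoch(ckpt_path):
    """Recover the epoch string from a path ending in model_<epoch>.ckpt."""
    return ckpt_path.split('_')[-1].split('.')[0]


resample_epochs, checkpoints = training_schedule(
    num_epochs=100, resample_every_n_step=10, save_every_n_epochs=10)
print(resample_epochs[:3], checkpoints[:3])
print(parse_restart_epoch("some_run_dir/model_9.ckpt"))
```

Note that the parsed epoch is a string, and that the first checkpoint is `model_9.ckpt` rather than `model_10.ckpt` because the file name uses the 0-based epoch index.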
tr2d2-dna/finetune_utils.py ADDED
@@ -0,0 +1,147 @@
1
+ import random
2
+ import torch
3
+ from torch.utils.data import DataLoader, TensorDataset
4
+ from utils import sample_categorical_logits
5
+ import numpy as np
6
+ from tqdm import tqdm
7
+ import torch.distributed as dist
8
+ import torch.nn.functional as F
9
+
10
+ def compute_ess(log_rnd, normalize=True):
11
+ """
12
+ log_rnd: [B]
13
+ Compute effective sample size:
14
+ If normalize: divide ESS by batch size, so range is [1/B, 1];
15
+ otherwise, range is [1, B]
16
+ """
17
+ weights = log_rnd.detach().softmax(dim=-1)
18
+ ess = 1 / (weights ** 2).sum().item()
19
+ return ess / log_rnd.shape[0] if normalize else ess
20
+
21
+ def to_one_hot(x_idx, num_classes=4):
22
+ oh = F.one_hot(x_idx.long(), num_classes=num_classes)
23
+ return oh.float()
24
+
25
+ def rnd(model, reward_model, batch_size, scale=1, device='cuda:0'):
26
+ r"""
27
+ Run random order sampling and compute the RND $\log\frac{dP^*}{dP^u}$ along the trajectory
28
+ reward_model: r(X)
29
+
30
+ return:
31
+ - x: the final samples, [B, D]
32
+ - log_rnd: the log RND along this trajectory, [B]
33
+ """
34
+ if hasattr(model, 'module'):
35
+ model = model.module
36
+
37
+ x = torch.full((batch_size, model.length), model.vocab_size-1).to(device=device, dtype=torch.int64)
38
+ batch_arange = torch.arange(batch_size, device=device)
39
+ jump_pos = torch.rand(x.shape, device=device).argsort(dim=-1)
40
+ # jump_times, jump_pos = torch.rand(x.shape, device=device).sort(dim=-1)
41
+ # jump_times: Unif[0,1] in increasing order
42
+ # jump_pos: random permutation of range(D)
43
+ log_rnd = torch.zeros(batch_size, device=device) # [B]
44
+ for d in range(model.length-1, -1, -1):
45
+ # jump at time jump_times[:, d] at position jump_pos[:, d]
46
+ logits = model(x)[:, :, :-1] # [B, D, N-1]
47
+ update = sample_categorical_logits(
48
+ logits[batch_arange, jump_pos[:, d]]) # [B]
49
+ if torch.is_grad_enabled(): # avoid issues with in-place operations
50
+ x = x.clone()
51
+ x[batch_arange, jump_pos[:, d]] = update
52
+ log_rnd += -np.log(model.vocab_size-1) - logits[batch_arange, jump_pos[:, d], update]
53
+ log_rnd += scale * reward_model(x) # [B]
54
+ return x, log_rnd
55
+
56
+
57
+ @torch.no_grad()
58
+ def sampling(model, batch_size, rounds=1, device='cuda:0'):
59
+ """Any order autoregressive sampling"""
60
+ if hasattr(model, 'module'):
61
+ model = model.module
62
+ batch_arange = torch.arange(batch_size, device=device)
63
+ all_samples = []
64
+ for _ in tqdm(range(rounds), leave=False):
65
+ x = torch.full((batch_size, model.length), model.vocab_size-1).to(device=device, dtype=torch.int64)
66
+ jump_pos = torch.rand(x.shape, device=device).argsort(dim=-1)
67
+ # jump_times, jump_pos = torch.rand(x.shape, device=device).sort(dim=-1)
68
+ # jump_times: Unif[0,1] in increasing order
69
+ # jump_pos: random permutation of range(D)
70
+ for d in tqdm(range(model.length-1, -1, -1), leave=False):
71
+ # jump at time jump_times[:, d] at position jump_pos[:, d]
72
+ logits = model.logits(x)[:, :, :-1] # [B, D, N-1], not log-softmaxed but fine
73
+ update = sample_categorical_logits(
74
+ logits[batch_arange, jump_pos[:, d]]) # [B]
75
+ x[batch_arange, jump_pos[:, d]] = update
76
+ all_samples.append(x)
77
+ return torch.cat(all_samples) # (rounds * B, L)
78
+
79
+
80
+ def loss_ce(log_rnd):
81
+ """Cross entropy loss KL(P^*||P^u)"""
82
+ weights = log_rnd.detach().softmax(dim=-1)
83
+ return (log_rnd * weights).sum()
84
+
85
+
86
+ def loss_lv(log_rnd):
87
+ r"""Log variance loss Var_{P^\bar{u}}\log\frac{dP^*}{dP^u}"""
88
+ return log_rnd.var()
89
+
90
+
91
+ def loss_re_rf(log_rnd, const=0):
92
+ r"""Relative entropy loss KL(P^u||P^*) with REINFORCE trick"""
93
+ return (-log_rnd * (-log_rnd.detach() + const)).mean()
94
+
95
+
96
+ def loss_wdce(policy_model, log_rnd, x, num_replicates=16, weight_func=lambda l: 1/l, eps=1e-3, centering = False):
97
+ r"""
98
+ Weighted denoising cross entropy loss
99
+ X_T ~ P^u_T and weights \log\frac{dP^*}{dP^u}(X)
100
+
101
+ log_rnd: [B]; x: [B, L] (no mask)
102
+ num_replicates: R, number of replicates of each row in x
103
+ weight_func: w(lambda) for each sample, 1/lambda by default
104
+ """
105
+ if hasattr(policy_model, 'module'):
106
+ policy_model = policy_model.module
107
+ mask_index = policy_model.mask_index
108
+
109
+ batch = x.repeat_interleave(num_replicates, dim=0) # [B*R, L]
110
+ batch_weights = log_rnd.detach_().softmax(dim=-1)
111
+ if centering:
112
+ batch_weights = batch_weights - batch_weights.mean(dim=-1, keepdim=True)
113
+
114
+ batch_weights = batch_weights.repeat_interleave(num_replicates, dim=0) # [B*R]
115
+
116
+ lamda = torch.rand(batch.shape[0], device=batch.device) # [B*R]
117
+ lamda_weights = weight_func(lamda).clamp(max=1e5) # [B*R]
118
+
119
+ masked_index = torch.rand(*batch.shape, device=batch.device) < lamda[..., None] # [B*R, D]
120
+ perturbed_batch = torch.where(masked_index, mask_index, batch)
121
+
122
+ # add time conditioning
123
+ t = lamda
124
+ sigma_t = -torch.log1p(-(1 - eps) * t)
125
+
126
+ # compute logits
127
+ logits = policy_model(perturbed_batch, sigma_t)
128
+ losses = torch.zeros(*batch.shape, device=batch.device, dtype=logits.dtype) # [B*R, D]
129
+ losses[masked_index] = torch.gather(input=logits[masked_index], dim=-1,
130
+ index=batch[masked_index][..., None]).squeeze(-1)
131
+ return - (losses.sum(dim=-1) * lamda_weights * batch_weights).mean()
132
+
133
+
134
+ def loss_dce(model, x, weight_func=lambda l: 1/l):
135
+ r"""
136
+ Denoising cross entropy loss, x [B, D] are ground truth samples
137
+ weight_func: w(lambda) for each sample, 1/lambda by default
138
+ """
139
+ lamda = torch.rand(x.shape[0], device=x.device) # [B]
140
+ lamda_weights = weight_func(lamda).clamp(max=1e5) # [B]
141
+ masked_index = torch.rand(*x.shape, device=x.device) < lamda[..., None] # [B, D]
142
+ perturbed_batch = torch.where(masked_index, model.vocab_size-1, x)
143
+ logits = model(perturbed_batch)
144
+ losses = torch.zeros(*x.shape, device=x.device, dtype=logits.dtype) # [B, D]
145
+ losses[masked_index] = torch.gather(input=logits[masked_index], dim=-1,
146
+ index=x[masked_index][..., None]).squeeze(-1)
147
+ return - (losses.sum(dim=-1) * lamda_weights).mean()
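The weighting used by `compute_ess` and `loss_wdce` can be checked by hand on a tiny example. The sketch below is a plain-Python stand-in (no PyTorch) that turns log-weights into self-normalized weights with a softmax, computes the effective sample size, and shows the effect of the optional centering step. The function names and the toy inputs are assumptions for illustration, not part of the repository.

```python
import math


def softmax(log_weights):
    """Numerically stable softmax over a list of log-weights."""
    m = max(log_weights)
    exps = [math.exp(lw - m) for lw in log_weights]
    total = sum(exps)
    return [e / total for e in exps]


def effective_sample_size(log_weights, normalize=True):
    """ESS = 1 / sum_i w_i^2 for self-normalized weights w."""
    weights = softmax(log_weights)
    ess = 1.0 / sum(w * w for w in weights)
    return ess / len(log_weights) if normalize else ess


def centered_weights(log_weights):
    """Subtract the mean weight, as the `centering` option does."""
    weights = softmax(log_weights)
    mean = sum(weights) / len(weights)
    return [w - mean for w in weights]


uniform = [0.0, 0.0, 0.0, 0.0]   # all samples equally weighted
peaked = [50.0, 0.0, 0.0, 0.0]   # one sample dominates

print(effective_sample_size(uniform, normalize=False))  # close to 4
print(effective_sample_size(peaked, normalize=False))   # close to 1
print(sum(centered_weights(peaked)))                    # close to 0
```

With equal log-weights the unnormalized ESS equals the batch size, and when one sample dominates it approaches 1, which is why the normalized value lies between `1/B` and 1. After centering, the weights sum to roughly zero, so below-average samples contribute with a negative sign.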
tr2d2-dna/mcts.py ADDED
@@ -0,0 +1,581 @@
1
+ import numpy as np
2
+ import torch
3
+ import torch.nn.functional as F
4
+ import numpy as np
5
+ import random as rd
6
+ from finetune_utils import to_one_hot
7
+ from utils import StepTimer
8
+
9
+ import noise_schedule
10
+
11
+ ### BEGINNING OF NODE CLASS ###
12
+
13
+ class Node:
14
+ """
15
+ Node class: partially unmasked sequence
16
+ - parentNode: Node object at previous time step
17
+ - childNodes: set of M Node objects generated from sampling M distinct unmasking schemes
18
+ - totalReward: vector of cumulative rewards for all K objectives
19
+ - visits: number of times the node has been visited by an iteration
20
+ - path: array of partially unmasked sequences leading to the node from the completely masked root node
21
+ - timestep: the time step where the sequence was sampled
22
+ """
23
+ def __init__(self, args, tokens=None, log_rnd=None, log_policy_step=None, log_pretrained_step=None, parentNode=None, childNodes=None, totalReward=None, timestep=None):
24
+ self.args = args
25
+ self.parentNode = parentNode
26
+ self.childNodes = [] if childNodes is None else childNodes
27
+
28
+ self.log_rnd = log_rnd # stores the log_rnd up to that step
29
+
30
+ #self.log_p0 = 0 # stores the log probability of the unmasking step from the previous iteration
31
+ self.log_policy_step = log_policy_step # stores the log probability of the unmasking step under the current policy
32
+ self.log_pretrained_step = log_pretrained_step
33
+
34
+ # initialize total rewards to the reward of the roll out unmasked sequence
35
+ self.totalReward = totalReward # potential reward of the node based on generated children
36
+
37
+ # set initial visits to 1
38
+ self.visits = 1
39
+
40
+ #self.path = path
41
+
42
+ # set timestep (value between 0 and num_steps)
43
+ self.timestep = timestep
44
+ # set the sampling probability equal to the probability from the reverse posterior
45
+ #self.sampleProb = sampleProb # stores the probability of the sampling step under the current policy
46
+
47
+ # dict with 'seqs' as token array and 'attention_mask'
48
+ self.tokens = tokens
49
+
50
+ def selectNode(self, rootNode):
51
+ """
52
+ Selects a node to move to among the children nodes based on select score
53
+ """
54
+ # extract the status of the current node
55
+ nodeStatus = self.getExpandStatus()
56
+
57
+ # if the node is a legal non-leaf node
58
+ if (nodeStatus == 3):
59
+ # initialize array that will store select score vectors of each child node
60
+ selectScores = []
61
+ selectable_children = [] # children nodes that can be selected
62
+
63
+ for childNode in self.childNodes:
64
+ childStatus = childNode.getExpandStatus()
65
+ # only append child if it is legal leaf node (expandable) or legal non-leaf node
66
+ if childStatus == 2 or childStatus == 3:
67
+ selectScore = childNode.calcSelectScore()
68
+ if torch.is_tensor(selectScore) and selectScore.numel() == 1:
69
+ selectScore = selectScore.item()
70
+
71
+ selectable_children.append(childNode)
72
+ selectScores.append(float(selectScore))
73
+
74
+ # no selectable children
75
+ if len(selectable_children) == 0:
76
+ return rootNode, 3
77
+
78
+ selectScores = np.asarray(selectScores, dtype=np.float64)
79
+
80
+ temp = 1.0
81
+ # compute softmax probabilities
82
+ m = np.max(selectScores)
83
+ exps = np.exp((selectScores - m) / temp)
84
+ tot = exps.sum()
85
+
86
+ if not np.isfinite(tot) or tot <= 0.0:
87
+ probs = np.full(len(selectable_children), 1.0 / len(selectable_children))
88
+ else:
89
+ probs = exps / tot
90
+
91
+ # choose child index from categorical distribution
92
+ idx = np.random.choice(len(selectable_children), p=probs)
93
+ selected = selectable_children[idx]
94
+
95
+ # return selected child node and status
96
+ return selected, selected.getExpandStatus()
97
+
98
+ elif (nodeStatus == 2):
99
+ return self, nodeStatus
100
+
101
+ # if node is not valid non-leaf node
102
+ return rootNode, 3
103
+
104
+ def selectNodeTopK(self, rootNode, k = 3, temp = 1.0):
105
+ """
106
+ Pick from the top-k by select score.
107
+ Returns: (selected_node, selected_status)
108
+ """
109
+ nodeStatus = self.getExpandStatus()
110
+
111
+ # If expandable leaf, return it directly
112
+ if nodeStatus == 2:
113
+ return self, nodeStatus
114
+
115
+ if nodeStatus == 3:
116
+ selectable_children = []
117
+ selectScores = []
118
+
119
+ # collect candidates
120
+ for ch in self.childNodes:
121
+ s = ch.getExpandStatus()
122
+ if s in (2, 3):
123
+ sc = ch.calcSelectScore()
124
+ if torch.is_tensor(sc):
125
+ sc = sc.item() if sc.numel() == 1 else float(sc.mean().item())
126
+ sc = float(sc) if np.isfinite(sc) else -np.inf # push bad scores to -inf
127
+ selectable_children.append(ch)
128
+ selectScores.append(sc)
129
+
130
+ if not selectable_children:
131
+ return rootNode, 3
132
+
133
+ scores = np.asarray(selectScores, dtype=np.float64)
134
+
135
+ # top-k indices (largest scores)
136
+ k_eff = min(k, len(scores))
137
+ topk_idx = np.argpartition(-scores, kth=k_eff-1)[:k_eff]
138
+ # sort the top-k by score desc for stability
139
+ topk_idx = topk_idx[np.argsort(-scores[topk_idx])]
140
+
141
+ # slice down to top-k pool
142
+ pool_scores = scores[topk_idx]
143
+ pool_children = [selectable_children[i] for i in topk_idx]
144
+
145
+ # softmax over the top-k
146
+ m = np.max(pool_scores)
147
+ z = (pool_scores - m) / max(1e-8, temp)
148
+ exps = np.exp(np.clip(z, -60, 60))
149
+ tot = exps.sum()
150
+ if not np.isfinite(tot) or tot <= 0.0:
151
+ idx_local = 0 # best
152
+ else:
153
+ probs = exps / tot
154
+
155
+ idx_local = int(np.random.choice(len(pool_children), p=probs))
156
+
157
+ selected = pool_children[idx_local]
158
+ return selected, selected.getExpandStatus()
159
+
160
+ return rootNode, 3
161
+
162
+ def addChildNode(self, tokens, log_rnd, log_policy_step, log_pretrained_step, totalReward):
163
+ """
164
+ Adds a child node:
165
+ log_rnd: log_rnd of the path up to the added child node
166
+ log_policy_step: scalar value of the log-prob of sampling the step under the policy
167
+ log_pretrained_step: scalar value of the log-prob of sampling the step under the pretrained model
168
+ """
169
+ child = Node(args=self.args,
170
+ tokens=tokens,
171
+ log_rnd = log_rnd,
172
+ log_policy_step=log_policy_step,
173
+ log_pretrained_step=log_pretrained_step,
174
+ parentNode=self,
175
+ childNodes=[],
176
+ totalReward=totalReward,
177
+ timestep=self.timestep+1)
178
+
179
+ self.childNodes.append(child)
180
+ return child
181
+
182
+ def update_logrnd(self, log_policy_step, log_rnd):
183
+ self.log_policy_step = log_policy_step
184
+ self.log_rnd = log_rnd
185
+
186
+ def updateNode(self, rewards):
187
+ """
188
+ Updates the cumulative rewards vector with the reward vector at a descendent leaf node.
189
+ Increments the number of visits to the node.
190
+ """
191
+ self.visits += 1
192
+
193
+ self.totalReward += rewards # singleton tensor
194
+
195
+ def calcSelectScore(self):
196
+ """
197
+ Calculates the select score for the node from the cumulative rewards vector and number of visits.
198
+ - c: determines the degree of exploration
199
+ - minSelectScore: determines the
200
+ """
201
+ # K-dimensional vector of normalized rewards for each objective
202
+ normRewards = self.totalReward / self.visits
203
+
204
+ # scales the cumulative reward by the sampling probability
205
+
206
+ return normRewards + (self.args.exploration * self.log_policy_step * np.sqrt(self.parentNode.visits) / self.visits)
207
+
208
+ def getExpandStatus(self):
209
+ """
210
+ Returns an integer indicating whether the node is a:
211
+ 1. terminal node (sequence is fully unmasked)
212
+ 2. legal leaf node (partially unmasked sequence that can be expanded)
213
+ 3. legal non-leaf node (already expanded sequence with M child nodes)
214
+ """
215
+ if self.timestep == self.args.total_num_steps:
216
+ return 1
217
+ elif (self.timestep < self.args.total_num_steps) and (len(self.childNodes) == 0):
218
+ return 2
219
+ return 3
220
+
221
+ ### END OF NODE CLASS ###
222
+
223
+ ### BEGINNING OF MCTS CLASS ###
224
+
225
+ class MCTS:
226
+ def __init__(self, args, config, policy_model, pretrained, rewardFunc, rootNode=None):
227
+ self.timer = StepTimer(policy_model.device)
228
+
229
+ # debugging
230
+ self.buf_stats = {"insert":0, "replace":0, "reject_worse":0,
231
+ "reject_dup":0, "reject_nonfinite":0}
232
+ self._seen_hashes = set()
233
+
234
+ self.device = policy_model.device
235
+ print(f"MCTS device: {self.device}")
236
+
237
+ self.args = args
238
+ self.config = config
239
+ self.noise = noise_schedule.get_noise(config)
240
+ self.time_conditioning = args.time_conditioning
241
+
242
+ self.mask_index = policy_model.mask_index
243
+ masked_seq = torch.ones((self.args.seq_length), device = self.device) * self.mask_index
244
+ masked_tokens = {'seqs': masked_seq.to(dtype=torch.long), 'attention_mask': torch.ones_like(masked_seq).to(self.device)}
245
+ if rootNode is None:
246
+ self.rootNode = Node(self.args, tokens = masked_tokens,
247
+ log_rnd=torch.zeros((), device=self.device),
248
+ log_policy_step=torch.zeros((), device=self.device),
249
+ log_pretrained_step=torch.zeros((), device=self.device),
250
+ totalReward=torch.zeros((), device=self.device), timestep=0)
251
+ else:
252
+ self.rootNode = rootNode # stores the root node of the tree
253
+
254
+ # dictionary:
255
+ # "seq": final unmasked sequence
256
+ # "traj": list of (N_steps, L)
257
+ # "reward": reward of the trajectory
258
+ self.buffer = [] # List[Dict[str, Any]]
259
+
260
+ self.buffer_size = args.buffer_size
261
+
262
+ self.num_steps = args.total_num_steps
263
+ self.num_sequences = args.num_sequences
264
+
265
+ # pretrained model
266
+ self.pretrained = pretrained
267
+
268
+ # the policy model that we want to finetune
269
+ self.policy_model = policy_model
270
+ #self.tokenizer = policy_model.tokenizer
271
+ self.device = policy_model.device
272
+
273
+ self.sequence_length = args.seq_length
274
+
275
+ self.num_iter = args.num_iter
276
+
277
+ self.num_children = args.num_children
278
+
279
+ # score functions
280
+ self.rewardFunc = rewardFunc
281
+
282
+ self.iter_num = 0
283
+
284
+ self.reward_log = []
285
+ self.logrnd_log = []
286
+
287
+ self.policy_model.eval()
288
+ self.pretrained.eval()
289
+ self.rewardFunc.eval()
290
+
291
+ def _hash_tokens(self, t):
292
+ # t: (L,) torch.long
293
+ return tuple(t.detach().cpu().tolist())
294
+
295
+ def reset(self, resetTree):
296
+ self.iter_num = 0
297
+ self.buffer = []
298
+ self._seen_hashes = set() # Clear the hash set too!
299
+ self.reward_log = []
300
+ self.logrnd_log = []
301
+
302
+ # add option to continue with the same tree
303
+ if resetTree:
304
+ masked_seq = torch.ones((self.args.seq_length), device = self.device) * self.mask_index
305
+ masked_tokens = {'seqs': masked_seq.to(dtype=torch.long), 'attention_mask': torch.ones_like(masked_seq).to(self.device)}
306
+ self.rootNode = Node(self.args, tokens = masked_tokens,
307
+ log_rnd=torch.zeros((), device=self.device),
308
+ log_policy_step=torch.zeros((), device=self.device),
309
+ log_pretrained_step=torch.zeros((), device=self.device),
310
+ totalReward=torch.zeros((), device=self.device), timestep=0)
311
+
312
+ def forward(self, resetTree=False):
313
+
314
+ self.reset(resetTree)
315
+
316
+ while (self.iter_num < self.num_iter):
317
+ self.iter_num += 1
318
+
319
+ # traverse the tree from the root node until a leaf node
320
+ with self.timer.section("select"):
321
+ leafNode, _ = self.select(self.rootNode)
322
+
323
+ # expand leaf node into num_children partially unmasked sequences at the next timestep
324
+ with self.timer.section("expand"):
325
+ self.expand(leafNode)
326
+
327
+ final_x, log_rnd, final_rewards = self.consolidateBuffer()
328
+
329
+ rows = self.timer.summary()
330
+ print("\n=== Timing summary (by total time) ===")
331
+ for name, cnt, total, mean, p50, p95 in rows:
332
+ print(f"{name:30s} n={cnt:5d} total={total:8.3f}s mean={mean*1e3:7.2f}ms "
333
+ f"p50={p50*1e3:7.2f}ms p95={p95*1e3:7.2f}ms")
334
+
335
+ # return final_seqs (B, L), log_rnd (B, ), and final rewards (B, )
336
+ return final_x, log_rnd, final_rewards
337
+
338
+
339
+ def updateBuffer(self, x_final, log_rnd, final_reward):
340
+ B = x_final.shape[0]
341
+ for i in range(B):
342
+ # Finite check
343
+ if not torch.isfinite(final_reward[i]) or not torch.isfinite(log_rnd[i]):
344
+ self.buf_stats["reject_nonfinite"] += 1
345
+ continue
346
+
347
+ h = self._hash_tokens(x_final[i])
348
+ if h in self._seen_hashes:
349
+ self.buf_stats["reject_dup"] += 1
350
+ continue
351
+
352
+ item = {"x_final": x_final[i].clone(),
353
+ "log_rnd": log_rnd[i].clone(),
354
+ "final_reward": final_reward[i].clone()}
355
+
356
+ if len(self.buffer) < self.buffer_size:
357
+ self.buffer.append(item)
358
+ self._seen_hashes.add(h)
359
+ self.buf_stats["insert"] += 1
360
+ else:
361
+ # replace if strictly better, or tie-break with log_rnd
362
+ min_idx, min_item = min(
363
+ enumerate(self.buffer),
364
+ key=lambda kv: (kv[1]["final_reward"].item(), kv[1]["log_rnd"].item())
365
+ )
366
+ cand_key = (final_reward[i].item(), log_rnd[i].item())
367
+ min_key = (min_item["final_reward"].item(), min_item["log_rnd"].item())
368
+
369
+ if cand_key > min_key: # allow ties via 2nd key
370
+ # update hash set
371
+ old_h = self._hash_tokens(self.buffer[min_idx]["x_final"])
372
+ if old_h in self._seen_hashes:
373
+ self._seen_hashes.remove(old_h)
374
+ self.buffer[min_idx] = item
375
+ self._seen_hashes.add(h)
376
+ self.buf_stats["replace"] += 1
377
+ else:
378
+ self.buf_stats["reject_worse"] += 1
379
+
380
+ def print_buffer_stats(self):
381
+ print("[BUFFER] ",
382
+ " ".join(f"{k}={v}" for k,v in self.buf_stats.items()),
383
+ f" size={len(self.buffer)}/{self.buffer_size}")
384
+ if len(self.buffer):
385
+ vals = torch.stack([b["final_reward"] for b in self.buffer]).float()
386
+ print(f"[BUFFER] reward min/mean/max: {vals.min():.4f} {vals.mean():.4f} {vals.max():.4f}")
387
+
388
+ def consolidateBuffer(self):
389
+ """
390
+ returns x_final, log_rnd, and final_rewards in tensors
391
+ """
392
+ x_final = []
393
+ log_rnd = []
394
+ final_rewards = []
395
+ for item in self.buffer:
396
+ x_final.append(item["x_final"])
397
+ log_rnd.append(item["log_rnd"])
398
+ final_rewards.append(item["final_reward"])
399
+
400
+ x_final = torch.stack(x_final, dim=0) # (B, L)
401
+ log_rnd = torch.stack(log_rnd, dim=0).to(dtype=torch.float32) # (B)
402
+ final_rewards = torch.stack(final_rewards, dim=0).to(dtype=torch.float32) # (B)
403
+
404
+ return x_final, log_rnd, final_rewards
405
+
406
+
407
+ def isPathEnd(self, path, maxDepth):
408
+ """
409
+ Checks if the node is completely unmasked (ie. end of path)
410
+ or if the path is at the max depth
411
+ """
412
+ if (path[-1] != self.mask_index).all():
413
+ return True
414
+ elif len(path) >= maxDepth:
415
+ return True
416
+ return False
417
+
418
+ def select(self, currNode, eps=1e-5):
419
+ """
420
+ Traverse the tree from the root node until reaching a legal leaf node
421
+ """
422
+ #iter = 1
423
+ updated_log_rnd = torch.zeros((), device=self.device)
424
+ while True:
425
+ if self.args.select_topk:
426
+ currNode, nodeStatus = currNode.selectNodeTopK(self.rootNode, k=self.args.select_topk_value, temp=1.0)
427
+ else:
428
+ currNode, nodeStatus = currNode.selectNode(self.rootNode)
429
+
430
+ if currNode.parentNode is not None:
431
+ # compute new log_policy
432
+ child_tokens = currNode.tokens['seqs'].to(self.device)
433
+ attn_mask = currNode.tokens['attention_mask'].to(self.device)
434
+ parent = currNode.parentNode
435
+ parent_tokens = parent.tokens['seqs'].to(self.device)
436
+ t = torch.ones(1, device = self.device)
437
+ dt = (1 - eps) / self.num_steps
438
+ with torch.no_grad():
439
+ with self.timer.section("select.compute_log_policy"):
440
+ updated_log_policy_step = self.policy_model.compute_log_policy(parent_tokens,
441
+ child_tokens,
442
+ t=t, dt=dt)
443
+ updated_log_rnd += (currNode.log_pretrained_step - updated_log_policy_step)
444
+
445
+ currNode.update_logrnd(updated_log_policy_step, updated_log_rnd) # update log_rnd
446
+
447
+ # node is a terminal node or a legal leaf node, return for expansion
448
+ if nodeStatus == 2:
449
+ return currNode, nodeStatus
450
+ elif nodeStatus == 1:
451
+ currNode = self.rootNode
452
+
453
+ def expand(self, parentNode, eps=1e-5):
454
+ """
455
+ Sample unmasking steps from the pre-trained MDLM
456
+ adds num_children partially unmasked sequences to the children of the parentNode
457
+ """
458
+
459
+ num_children = self.num_children
460
+ # initialize child rewards that will be added to total rewards
461
+
462
+ allChildReward = torch.zeros((), device=self.device)
463
+
464
+ # compute number of rollout steps
465
+ # if parentNode.timestep == self.num_steps - 1 then num_rollout_steps = 1
466
+ num_rollout_steps = self.num_steps - parentNode.timestep
467
+ # array of rollout timesteps from the timestep of parent node to 0
468
+ rollout_t = torch.linspace(1, eps, self.num_steps + 1, device=self.device)
469
+ dt = (1 - eps) / self.num_steps
470
+
471
+ # initialize x and attn_mask
472
+ x = parentNode.tokens['seqs'].to(self.device)
473
+ attn_mask = parentNode.tokens['attention_mask'].to(self.device)
474
+ parent_log_rnd = parentNode.log_rnd # stores the log_rnd up to parent node
475
+
476
+ t = rollout_t[parentNode.timestep] * torch.ones(1, 1, device = self.device)
477
+
478
+ # generate (n_children, seq_length) array of sampled children nodes
479
+
480
+ # sample M child sequences and compute their log probabilities
481
+ with torch.no_grad():
482
+ with self.timer.section("expand.batch_mcts_reverse_step"):
483
+ child_log_p, x_children, child_log_policy_step, child_log_pretrained_step = \
484
+ self.policy_model.batch_mcts_reverse_step(token_array=x,
485
+ t=t, dt=dt,
486
+ batch_size=num_children,
487
+ pretrained=self.pretrained)
488
+
489
+ # compute weight of the step (num_children, 1)
490
+
491
+ child_log_rnd = (parent_log_rnd + (child_log_pretrained_step - child_log_policy_step)).to(self.device)
492
+
493
+ x_rollout = x_children
494
+
495
+ traj_log_rnd = child_log_rnd # initialize log_rnd for entire rolled out trajectory
496
+
497
+ # rollout under the policy and compute the log ratio at each step
498
+ with self.timer.section("expand.rollout_total"):
499
+ for i in range(1, num_rollout_steps):
500
+ t = rollout_t[parentNode.timestep + i] * torch.ones(num_children, 1, device = self.device)
501
+
502
+ with torch.no_grad():
503
+ log_p, x_next, log_policy_step, log_pretrained_step = \
504
+ self.policy_model.mcts_reverse_step(x_rollout,
505
+ t=t, dt=dt,
506
+ pretrained=self.pretrained)
507
+
508
+ # add the rollout step
509
+ traj_log_rnd += log_pretrained_step - log_policy_step
510
+
511
+ x_rollout = x_next
512
+
513
+ # if mask token remains, fully unmask
514
+ mask_positions = (x_rollout == self.mask_index) # (B, L) bool
515
+
516
+ # does **any** mask remain in any sequence
517
+ any_mask_global = mask_positions.any().item() # true if mask remains
518
+ if any_mask_global:
519
+ with torch.no_grad():
520
+ with self.timer.section("expand.noise_removal"):
521
+ log_p, x_next, log_policy_step, log_pretrained_step = \
522
+ self.policy_model.mcts_noise_removal(x_rollout,
523
+ t=t, dt=dt,
524
+ pretrained=self.pretrained)
525
+
526
+ traj_log_rnd += log_pretrained_step - log_policy_step
527
+
528
+ x_rollout = x_next
529
+
530
+ x_final = x_rollout # final sequences (B, L)
531
+
532
+ # edit? how is the reward model defined?
533
+ #childSequences = self.tokenizer.batch_decode(x_rollout)
534
+
535
+ #if self.args.data == "peptide":
536
+ #validSequences = []
537
+
538
+
539
+ # get final rewards
540
+ x_one_hot = to_one_hot(x_final)
541
+ x_one_hot_reward = torch.transpose(x_one_hot, 1, 2)
542
+ reward_preds = self.rewardFunc(x_one_hot_reward).squeeze(-1) # (num_children, 4)
543
+ rewards_value = reward_preds[:, 0] # (num_children,)
544
+
545
+ if self.args.reward_clip:
546
+ rewards = torch.clamp(rewards_value, max=self.args.reward_clip_value)
547
+ else:
548
+ rewards = rewards_value
549
+
550
+ traj_log_rnd += rewards / self.args.alpha
551
+
552
+ self.reward_log.append(rewards.detach().cpu().numpy())
553
+ self.logrnd_log.append(traj_log_rnd.detach().cpu().numpy())
554
+
555
+ # update buffer
556
+ with self.timer.section("expand.update_buffer"):
557
+ self.updateBuffer(x_final, traj_log_rnd, rewards)
558
+
559
+ for i in range(num_children):
560
+
561
+ # add to all child reward vector for backprop
562
+ allChildReward += rewards[i]
563
+
564
+ # create node for sequence and add to the children node of parent
565
+ childTokens = {'seqs': x_children[i].to(dtype=torch.long), 'attention_mask': attn_mask}
566
+ parentNode.addChildNode(tokens=childTokens,
567
+ log_rnd=child_log_rnd[i],
568
+ log_policy_step=child_log_policy_step[i],
569
+ log_pretrained_step=child_log_pretrained_step[i],
570
+ totalReward=rewards[i])
571
+
572
+ # backpropagate all child rewards
573
+ with self.timer.section("expand.backprop"):
574
+ self.backprop(parentNode, allChildReward)
575
+
576
+
577
+ def backprop(self, node, allChildReward):
578
+ # backpropagate rewards along the path from the expanded node back to the root
579
+ while node:
580
+ node.updateNode(allChildReward)
581
+ node = node.parentNode
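
The replacement rule in `updateBuffer` above can be hard to follow inside the diff, so here is a minimal sketch of the same policy in plain Python. It is not the repository's code: `TopRewardBuffer` and its method names are hypothetical, token sequences are plain tuples instead of tensors, and a tuple-based set stands in for `_hash_tokens`. What it keeps from the source is the order of checks: reject non-finite values, reject duplicates, fill until full, then replace the worst item only when the candidate's `(reward, log_rnd)` key is strictly larger.

```python
import math
from typing import Dict, List, Sequence, Tuple

# (tokens, log_rnd, reward); a hypothetical stand-in for the dict items in the diff
Item = Tuple[Tuple[int, ...], float, float]


class TopRewardBuffer:
    """Bounded buffer that keeps the highest-reward unique sequences (sketch)."""

    def __init__(self, capacity: int) -> None:
        self.capacity = capacity
        self.items: List[Item] = []
        self.seen = set()  # plays the role of _seen_hashes
        self.stats: Dict[str, int] = {
            "insert": 0, "replace": 0, "reject_dup": 0,
            "reject_worse": 0, "reject_nonfinite": 0,
        }

    def add(self, tokens: Sequence[int], log_rnd: float, reward: float) -> None:
        # 1) finite check on both the reward and the log weight
        if not (math.isfinite(reward) and math.isfinite(log_rnd)):
            self.stats["reject_nonfinite"] += 1
            return
        # 2) duplicate check on the token sequence
        key = tuple(tokens)
        if key in self.seen:
            self.stats["reject_dup"] += 1
            return
        item: Item = (key, log_rnd, reward)
        # 3) fill the buffer while there is room
        if len(self.items) < self.capacity:
            self.items.append(item)
            self.seen.add(key)
            self.stats["insert"] += 1
            return
        # 4) otherwise compare against the current worst item,
        #    ordered by reward first and log_rnd second
        min_idx = min(range(len(self.items)),
                      key=lambda i: (self.items[i][2], self.items[i][1]))
        worst = self.items[min_idx]
        if (reward, log_rnd) > (worst[2], worst[1]):
            self.seen.discard(worst[0])  # the evicted sequence may be re-added later
            self.items[min_idx] = item
            self.seen.add(key)
            self.stats["replace"] += 1
        else:
            self.stats["reject_worse"] += 1
```

One consequence of discarding the evicted key is that a sequence dropped from the buffer is no longer treated as a duplicate, which matches the `_seen_hashes.remove(old_h)` step in the diff.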
tr2d2-dna/models/__init__.py ADDED
@@ -0,0 +1,2 @@
1
+ from . import ema
2
+ from . import dnaconv
tr2d2-dna/models/dnaconv.py ADDED
@@ -0,0 +1,121 @@
1
+ import torch
2
+ import torch.nn as nn
3
+ import numpy as np
4
+ import copy
5
+ import torch.nn.functional as F
6
+
7
+
8
+ class GaussianFourierProjection(nn.Module):
9
+ """
10
+ Gaussian random features for encoding time steps.
11
+ """
12
+
13
+ def __init__(self, embed_dim, scale=30.):
14
+ super().__init__()
15
+ # Randomly sample weights during initialization. These weights are fixed
16
+ # during optimization and are not trainable.
17
+ self.W = nn.Parameter(torch.randn(embed_dim // 2) * scale, requires_grad=False)
18
+
19
+ def forward(self, x):
20
+ x_proj = x[:, None] * self.W[None, :] * 2 * np.pi
21
+ return torch.cat([torch.sin(x_proj), torch.cos(x_proj)], dim=-1)
22
+
23
+
24
+ class Dense(nn.Module):
25
+ """
26
+ A fully connected layer; callers add the trailing feature-map axis themselves.
27
+ """
28
+
29
+ def __init__(self, input_dim, output_dim):
30
+ super().__init__()
31
+ self.dense = nn.Linear(input_dim, output_dim)
32
+
33
+ def forward(self, x):
34
+ return self.dense(x)[...]
35
+
36
+ # from https://github.com/HannesStark/dirichlet-flow-matching
37
+ class CNNModel(nn.Module):
38
+ def __init__(self, args, alphabet_size, num_cls, classifier=False):
39
+ super().__init__()
40
+ self.alphabet_size = alphabet_size
41
+ self.args = args
42
+ self.classifier = classifier
43
+ self.num_cls = num_cls
44
+
45
+ if self.args.clean_data:
46
+ self.linear = nn.Embedding(self.alphabet_size, embedding_dim=args.hidden_dim)
47
+ else:
48
+ inp_size = self.alphabet_size #+ 1
49
+ self.linear = nn.Conv1d(inp_size, args.hidden_dim, kernel_size=9, padding=4)
50
+ self.time_embedder = nn.Sequential(GaussianFourierProjection(embed_dim= args.hidden_dim),nn.Linear(args.hidden_dim, args.hidden_dim))
51
+
52
+ self.num_layers = 5 * args.num_cnn_stacks
53
+ self.convs = [nn.Conv1d(args.hidden_dim, args.hidden_dim, kernel_size=9, padding=4),
54
+ nn.Conv1d(args.hidden_dim, args.hidden_dim, kernel_size=9, padding=4),
55
+ nn.Conv1d(args.hidden_dim, args.hidden_dim, kernel_size=9, dilation=4, padding=16),
56
+ nn.Conv1d(args.hidden_dim, args.hidden_dim, kernel_size=9, dilation=16, padding=64),
57
+ nn.Conv1d(args.hidden_dim, args.hidden_dim, kernel_size=9, dilation=64, padding=256)]
58
+ self.convs = nn.ModuleList([copy.deepcopy(layer) for layer in self.convs for i in range(args.num_cnn_stacks)])
59
+ self.time_layers = nn.ModuleList([Dense(args.hidden_dim, args.hidden_dim) for _ in range(self.num_layers)])
60
+ self.norms = nn.ModuleList([nn.LayerNorm(args.hidden_dim) for _ in range(self.num_layers)])
61
+ self.final_conv = nn.Sequential(nn.Conv1d(args.hidden_dim, args.hidden_dim, kernel_size=1),
62
+ nn.ReLU(),
63
+ nn.Conv1d(args.hidden_dim, args.hidden_dim if classifier else self.alphabet_size, kernel_size=1))
64
+ self.dropout = nn.Dropout(args.dropout)
65
+ if classifier:
66
+ self.cls_head = nn.Sequential(nn.Linear(args.hidden_dim, args.hidden_dim),
67
+ nn.ReLU(),
68
+ nn.Linear(args.hidden_dim, self.num_cls))
69
+
70
+ if self.args.cls_free_guidance and not self.classifier:
71
+ self.cls_embedder = nn.Embedding(num_embeddings=self.num_cls + 1, embedding_dim=args.hidden_dim)
72
+ self.cls_layers = nn.ModuleList([Dense(args.hidden_dim, args.hidden_dim) for _ in range(self.num_layers)])
73
+
74
+ def forward(self, seq, t, cls = None, return_embedding=False):
75
+ # adapt it to support both seq indices input and one-hot input
76
+ if not (seq.ndim > 2 and seq.shape[-1] == self.alphabet_size):
77
+ seq = F.one_hot(seq, num_classes=self.alphabet_size).float()
78
+
79
+ if self.args.clean_data:
80
+ feat = self.linear(seq)
81
+ feat = feat.permute(0, 2, 1)
82
+ else:
83
+ time_emb = F.relu(self.time_embedder(t))
84
+ feat = seq.permute(0, 2, 1)
85
+ feat = F.relu(self.linear(feat))
86
+
87
+
88
+ if self.args.cls_free_guidance and not self.classifier:
89
+ cls_emb = self.cls_embedder(cls)
90
+
91
+ for i in range(self.num_layers):
92
+ h = self.dropout(feat.clone())
93
+
94
+ if not self.args.clean_data:
95
+ h = h + self.time_layers[i](time_emb)[:, :, None]
96
+
97
+ if self.args.cls_free_guidance and not self.classifier:
98
+ h = h + self.cls_layers[i](cls_emb)[:, :, None]
99
+
100
+
101
+ h = self.norms[i]((h).permute(0, 2, 1))
102
+ h = F.relu(self.convs[i](h.permute(0, 2, 1)))
103
+
104
+ if h.shape == feat.shape:
105
+ feat = h + feat
106
+ else:
107
+ feat = h
108
+
109
+ feat = self.final_conv(feat)
110
+
111
+ feat = feat.permute(0, 2, 1)
112
+
113
+ if self.classifier:
114
+ feat = feat.mean(dim=1)
115
+ if return_embedding:
116
+ embedding = self.cls_head[:1](feat)
117
+ return self.cls_head[1:](embedding), embedding
118
+ else:
119
+ return self.cls_head(feat)
120
+
121
+ return feat
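
The padding values in `CNNModel.convs` above (4, 4, 16, 64, 256 for dilations 1, 1, 4, 16, 64 with `kernel_size=9`) are what keep the sequence length unchanged, so the residual check `h.shape == feat.shape` can succeed. The sketch below checks this using the standard `Conv1d` output-length formula. The helper names are hypothetical, the sequence length of 200 is an assumption chosen only for illustration, and the receptive-field figure covers a single pass through the five dilated layers (that is, `num_cnn_stacks = 1`, ignoring the input `self.linear` convolution).

```python
from typing import List, Tuple


def conv1d_out_len(length: int, kernel_size: int, dilation: int = 1,
                   padding: int = 0, stride: int = 1) -> int:
    """Output length of a 1D convolution (standard Conv1d formula)."""
    return (length + 2 * padding - dilation * (kernel_size - 1) - 1) // stride + 1


def same_padding(kernel_size: int, dilation: int) -> int:
    """Padding that preserves length for an odd kernel at stride 1."""
    return dilation * (kernel_size - 1) // 2


def receptive_field(dilations: List[int], kernel_size: int = 9) -> int:
    """Receptive field of stacked stride-1 convolutions."""
    return 1 + sum(d * (kernel_size - 1) for d in dilations)


# (dilation, padding) pairs as they appear in the diff
LAYERS: List[Tuple[int, int]] = [(1, 4), (1, 4), (4, 16), (16, 64), (64, 256)]
SEQ_LEN = 200  # assumed length, for illustration only

for dilation, padding in LAYERS:
    # every layer should map SEQ_LEN back to SEQ_LEN
    assert conv1d_out_len(SEQ_LEN, 9, dilation, padding) == SEQ_LEN
    # and its padding should equal dilation * (kernel_size - 1) / 2
    assert same_padding(9, dilation) == padding

print(receptive_field([d for d, _ in LAYERS]))
```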
tr2d2-dna/models/ema.py ADDED
@@ -0,0 +1,97 @@
1
+ import torch
2
+
3
+
4
+ class ExponentialMovingAverage:
5
+ """
6
+ Maintains (exponential) moving average of a set of parameters.
7
+ """
8
+
9
+ def __init__(self, parameters, decay, use_num_updates=True):
10
+ """
11
+ Args:
12
+ parameters: Iterable of `torch.nn.Parameter`; usually the result of
13
+ `model.parameters()`.
14
+ decay: The exponential decay.
15
+ use_num_updates: Whether to use number of updates when computing
16
+ averages.
17
+ """
18
+ if decay < 0.0 or decay > 1.0:
19
+ raise ValueError('Decay must be between 0 and 1')
20
+ self.decay = decay
21
+ self.num_updates = 0 if use_num_updates else None
22
+ self.shadow_params = [p.clone().detach()
23
+ for p in parameters if p.requires_grad]
24
+ self.collected_params = []
25
+
26
+ def move_shadow_params_to_device(self, device):
27
+ self.shadow_params = [i.to(device) for i in self.shadow_params]
28
+
29
+ def update(self, parameters):
30
+ """
31
+ Update currently maintained parameters.
32
+
33
+ Call this every time the parameters are updated, such as the result of
34
+ the `optimizer.step()` call.
35
+
36
+ Args:
37
+ parameters: Iterable of `torch.nn.Parameter`; usually the same set of
38
+ parameters used to initialize this object.
39
+ """
40
+ decay = self.decay
41
+ if self.num_updates is not None:
42
+ self.num_updates += 1
43
+ decay = min(decay, (1 + self.num_updates) /
44
+ (10 + self.num_updates))
45
+ one_minus_decay = 1.0 - decay
46
+ with torch.no_grad():
47
+ parameters = [p for p in parameters if p.requires_grad]
48
+ for s_param, param in zip(self.shadow_params, parameters):
49
+ s_param.sub_(one_minus_decay * (s_param - param))
50
+
51
+ def copy_to(self, parameters):
52
+ """
53
+ Copy current averaged parameters into given collection of parameters.
54
+
55
+ Args:
56
+ parameters: Iterable of `torch.nn.Parameter`; the parameters to be
57
+ updated with the stored moving averages.
58
+ """
59
+ parameters = [p for p in parameters if p.requires_grad]
60
+ for s_param, param in zip(self.shadow_params, parameters):
61
+ if param.requires_grad:
62
+ param.data.copy_(s_param.data)
63
+
64
+ def store(self, parameters):
65
+ """
66
+ Save the current parameters for restoring later.
67
+
68
+ Args:
69
+ parameters: Iterable of `torch.nn.Parameter`; the parameters to be
70
+ temporarily stored.
71
+ """
72
+ self.collected_params = [param.clone() for param in parameters]
73
+
74
+ def restore(self, parameters):
75
+ """
76
+ Restore the parameters stored with the `store` method.
77
+ Useful to validate the model with EMA parameters without affecting the
78
+ original optimization process. Store the parameters before the
79
+ `copy_to` method. After validation (or model saving), use this to
80
+ restore the former parameters.
81
+
82
+ Args:
83
+ parameters: Iterable of `torch.nn.Parameter`; the parameters to be
84
+ updated with the stored parameters.
85
+ """
86
+ for c_param, param in zip(self.collected_params, parameters):
87
+ param.data.copy_(c_param.data)
88
+
89
+ def state_dict(self):
90
+ return dict(decay=self.decay,
91
+ num_updates=self.num_updates,
92
+ shadow_params=self.shadow_params)
93
+
94
+ def load_state_dict(self, state_dict):
95
+ self.decay = state_dict['decay']
96
+ self.num_updates = state_dict['num_updates']
97
+ self.shadow_params = state_dict['shadow_params']
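
To make the warm-up in `ExponentialMovingAverage.update` above concrete, the following sketch reproduces the same arithmetic on plain floats. The function names are hypothetical and no tensors are involved. It only illustrates the two formulas in the diff: the effective decay `min(decay, (1 + n) / (10 + n))` and the in-place update `s -= (1 - decay) * (s - p)`.

```python
from typing import List, Tuple


def effective_decay(decay: float, num_updates: int) -> float:
    """Decay actually applied after `num_updates` updates (warm-up included)."""
    return min(decay, (1 + num_updates) / (10 + num_updates))


def ema_update(shadow: List[float], params: List[float],
               decay: float, num_updates: int) -> Tuple[List[float], int]:
    """One EMA step on plain floats; returns the new shadow values and counter."""
    num_updates += 1
    d = effective_decay(decay, num_updates)
    # s - (1 - d) * (s - p) is the same as d * s + (1 - d) * p
    new_shadow = [s - (1.0 - d) * (s - p) for s, p in zip(shadow, params)]
    return new_shadow, num_updates


shadow, n = [0.0], 0
for _ in range(3):
    shadow, n = ema_update(shadow, [1.0], decay=0.999, num_updates=n)
    print(n, round(effective_decay(0.999, n), 4), round(shadow[0], 4))
```

Early in training the effective decay is small (2/11 on the first update), so the shadow values track the parameters quickly; the configured `decay` only takes over once `(1 + n) / (10 + n)` exceeds it.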
tr2d2-dna/noise_schedule.py ADDED
@@ -0,0 +1,151 @@
1
+ import abc
2
+
3
+ import torch
4
+ import torch.nn as nn
5
+
6
+ # Flags required to enable jit fusion kernels
7
+ torch._C._jit_set_profiling_mode(False)
8
+ torch._C._jit_set_profiling_executor(False)
9
+ torch._C._jit_override_can_fuse_on_cpu(True)
10
+ torch._C._jit_override_can_fuse_on_gpu(True)
11
+
12
+
13
+ def get_noise(config, dtype=torch.float32):
14
+ if config.noise.type == 'geometric':
15
+ return GeometricNoise(config.noise.sigma_min,
16
+ config.noise.sigma_max)
17
+ elif config.noise.type == 'loglinear':
18
+ return LogLinearNoise()
19
+ elif config.noise.type == 'cosine':
20
+ return CosineNoise()
21
+ elif config.noise.type == 'cosinesqr':
22
+ return CosineSqrNoise()
23
+ elif config.noise.type == 'linear':
24
+ return Linear(config.noise.sigma_min,
25
+ config.noise.sigma_max,
26
+ dtype)
27
+ else:
28
+ raise ValueError(f'{config.noise.type} is not a valid noise')
29
+
30
+
31
+ def binary_discretization(z):
32
+ z_hard = torch.sign(z)
33
+ z_soft = z / torch.norm(z, dim=-1, keepdim=True)
34
+ return z_soft + (z_hard - z_soft).detach()
35
+
36
+
37
+ class Noise(abc.ABC, nn.Module):
38
+ """
39
+ Baseline forward method to get the total + rate of noise at a timestep
40
+ """
41
+ def forward(self, t):
42
+ # Assume time goes from 0 to 1
43
+ return self.total_noise(t), self.rate_noise(t)
44
+
45
+ @abc.abstractmethod
46
+ def rate_noise(self, t):
47
+ """
48
+ Rate of change of noise ie g(t)
49
+ """
50
+ pass
51
+
52
+ @abc.abstractmethod
53
+ def total_noise(self, t):
54
+ """
55
+ Total noise ie \\int_0^t g(s) ds + g(0)
56
+ """
57
+ pass
58
+
59
+
60
+ class CosineNoise(Noise):
61
+ def __init__(self, eps=1e-3):
62
+ super().__init__()
63
+ self.eps = eps
64
+
65
+ def rate_noise(self, t):
66
+ cos = (1 - self.eps) * torch.cos(t * torch.pi / 2)
67
+ sin = (1 - self.eps) * torch.sin(t * torch.pi / 2)
68
+ scale = torch.pi / 2
69
+ return scale * sin / (cos + self.eps)
70
+
71
+ def total_noise(self, t):
72
+ cos = torch.cos(t * torch.pi / 2)
73
+ return - torch.log(self.eps + (1 - self.eps) * cos)
74
+
75
+
76
+ class CosineSqrNoise(Noise):
77
+ def __init__(self, eps=1e-3):
78
+ super().__init__()
79
+ self.eps = eps
80
+
81
+ def rate_noise(self, t):
82
+ cos = (1 - self.eps) * (
83
+ torch.cos(t * torch.pi / 2) ** 2)
84
+ sin = (1 - self.eps) * torch.sin(t * torch.pi)
85
+ scale = torch.pi / 2
86
+ return scale * sin / (cos + self.eps)
87
+
88
+ def total_noise(self, t):
89
+ cos = torch.cos(t * torch.pi / 2) ** 2
90
+ return - torch.log(self.eps + (1 - self.eps) * cos)
91
+
92
+
93
+ class Linear(Noise):
94
+ def __init__(self, sigma_min=0, sigma_max=10, dtype=torch.float32):
95
+ super().__init__()
96
+ self.sigma_min = torch.tensor(sigma_min, dtype=dtype)
97
+ self.sigma_max = torch.tensor(sigma_max, dtype=dtype)
98
+
99
+ def rate_noise(self, t):
100
+ return self.sigma_max - self.sigma_min
101
+
102
+ def total_noise(self, t):
103
+ return self.sigma_min + t * (self.sigma_max - self.sigma_min)
104
+
105
+ def importance_sampling_transformation(self, t):
106
+ f_T = torch.log1p(- torch.exp(- self.sigma_max))
107
+ f_0 = torch.log1p(- torch.exp(- self.sigma_min))
108
+ sigma_t = - torch.log1p(- torch.exp(t * f_T + (1 - t) * f_0))
109
+ return (sigma_t - self.sigma_min) / (
110
+ self.sigma_max - self.sigma_min)
111
+
112
+
113
+ class GeometricNoise(Noise):
114
+ def __init__(self, sigma_min=1e-3, sigma_max=1):
115
+ super().__init__()
116
+ self.sigmas = 1.0 * torch.tensor([sigma_min, sigma_max])
117
+
118
+ def rate_noise(self, t):
119
+ return self.sigmas[0] ** (1 - t) * self.sigmas[1] ** t * (
120
+ self.sigmas[1].log() - self.sigmas[0].log())
121
+
122
+ def total_noise(self, t):
123
+ return self.sigmas[0] ** (1 - t) * self.sigmas[1] ** t
124
+
125
+
126
+ class LogLinearNoise(Noise):
127
+ """Log Linear noise schedule.
128
+
129
+ Built such that 1 - 1/e^(n(t)) interpolates between 0 and
130
+ ~1 when t varies from 0 to 1. Total noise is
131
+ -log(1 - (1 - eps) * t), so the sigma will be
132
+ (1 - eps) * t.
133
+ """
134
+ def __init__(self, eps=1e-3):
135
+ super().__init__()
136
+ self.eps = eps
137
+ self.sigma_max = self.total_noise(torch.tensor(1.0))
138
+ self.sigma_min = self.eps + self.total_noise(torch.tensor(0.0))
139
+
140
+ def rate_noise(self, t):
141
+ return (1 - self.eps) / (1 - (1 - self.eps) * t)
142
+
143
+ def total_noise(self, t):
144
+ return -torch.log1p(-(1 - self.eps) * t)
145
+
146
+ def importance_sampling_transformation(self, t):
147
+ f_T = torch.log1p(- torch.exp(- self.sigma_max))
148
+ f_0 = torch.log1p(- torch.exp(- self.sigma_min))
149
+ sigma_t = - torch.log1p(- torch.exp(t * f_T + (1 - t) * f_0))
150
+ t = - torch.expm1(- sigma_t) / (1 - self.eps)
151
+ return t
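
As a sanity check on `LogLinearNoise` above, the sketch below re-implements `total_noise` and `rate_noise` with the standard `math` module and verifies two properties numerically: the rate is the derivative of the total noise, and `1 - exp(-total_noise(t))` equals `(1 - eps) * t`. Reading that last quantity as a masking probability is an assumption carried over from how MDLM-style schedules are usually described, not something stated in this file; `mask_prob` is a hypothetical helper name.

```python
import math

EPS = 1e-3  # same default as LogLinearNoise.__init__


def total_noise(t: float, eps: float = EPS) -> float:
    """sigma(t) = -log(1 - (1 - eps) * t), written with log1p for accuracy."""
    return -math.log1p(-(1.0 - eps) * t)


def rate_noise(t: float, eps: float = EPS) -> float:
    """d sigma / dt = (1 - eps) / (1 - (1 - eps) * t)."""
    return (1.0 - eps) / (1.0 - (1.0 - eps) * t)


def mask_prob(t: float, eps: float = EPS) -> float:
    """1 - exp(-sigma(t)); algebraically equal to (1 - eps) * t."""
    return -math.expm1(-total_noise(t, eps))


t, h = 0.5, 1e-6
# central finite difference of total_noise around t
finite_diff = (total_noise(t + h) - total_noise(t - h)) / (2 * h)
print(abs(finite_diff - rate_noise(t)) < 1e-6)
print(abs(mask_prob(0.3) - (1 - EPS) * 0.3) < 1e-12)
```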
tr2d2-dna/oracle.py ADDED
@@ -0,0 +1,344 @@
1
+ import torch
2
+ import tempfile
3
+ import grelu
4
+ import pandas as pd
5
+ import os
6
+ from grelu.lightning import LightningModel
7
+ import grelu.data.preprocess
8
+ import grelu.data.dataset
9
+ import dataloader_gosai
10
+ import numpy as np
11
+ from typing import Callable, Union, List
12
+ from scipy.linalg import sqrtm
13
+ from scipy.stats import pearsonr
14
+ import torch.nn.functional as F
15
+ import io
16
+
17
+ base_path = "" # Fill in directory of the pretrained checkpoints, e.g., "...../data_and_model/"
18
+
19
+
20
+ def get_cal_atac_orale(device=None):
21
+ device = torch.device("cuda" if torch.cuda.is_available() else "cpu") if device is None else device
22
+ ckpt_path = os.path.join(base_path, 'mdlm/gosai_data/binary_atac_cell_lines.ckpt')
23
+
24
+ ckpt = torch.load(ckpt_path, map_location="cpu")
25
+
26
+ hp = ckpt.get("hyper_parameters", {})
27
+ ckpt.setdefault("data_params", hp.get("data_params", {}))
28
+ ckpt.setdefault("performance", {})
29
+
30
+ if not ckpt["performance"]:
31
+ ckpt["performance"] = {
32
+ "best_step": ckpt.get("global_step", 0),
33
+ "best_metric": None,
34
+ }
35
+
36
+ # Load model from in-memory checkpoint (no file I/O needed)
37
+ buffer = io.BytesIO()
38
+ torch.save(ckpt, buffer)
39
+ buffer.seek(0) # Reset buffer position to the beginning
40
+
41
+ model_load = LightningModel.load_from_checkpoint(buffer, map_location="cpu")
42
+ model_load.to(device)
43
+
44
+ model_load.train_params['logger'] = None
45
+
46
+ return model_load
47
+
48
+
49
+
50
+
51
+ def get_gosai_oracle(mode='train', device=None):
52
+ device = torch.device("cuda" if torch.cuda.is_available() else "cpu") if device is None else device
53
+ if mode == 'train':
54
+ ckpt_path = os.path.join(base_path, "mdlm/outputs_gosai/lightning_logs/reward_oracle_ft.ckpt")
55
+
56
+ ckpt = torch.load(ckpt_path, map_location="cpu")
57
+
58
+ hp = ckpt.get("hyper_parameters", {})
59
+ ckpt.setdefault("data_params", hp.get("data_params", {}))
60
+ ckpt.setdefault("performance", {})
61
+
62
+ if not ckpt["performance"]:
63
+ ckpt["performance"] = {
64
+ "best_step": ckpt.get("global_step", 0),
65
+ "best_metric": None,
66
+ }
67
+
68
+ # Load model from in-memory checkpoint (no file I/O needed)
69
+ buffer = io.BytesIO()
70
+ torch.save(ckpt, buffer)
71
+ buffer.seek(0) # Reset buffer position to the beginning
72
+
73
+ model_load = LightningModel.load_from_checkpoint(buffer, map_location="cpu")
74
+ model_load.to(device)
75
+
76
+ elif mode == 'eval':
77
+
78
+ ckpt_path = os.path.join(base_path, "mdlm/outputs_gosai/lightning_logs/reward_oracle_eval.ckpt")
79
+ ckpt = torch.load(ckpt_path, map_location="cpu")
80
+
81
+ hp = ckpt.get("hyper_parameters", {})
82
+ ckpt.setdefault("data_params", hp.get("data_params", {})) # safe default
83
+ ckpt.setdefault("performance", {}) # safe default
84
+ # Optional: add minimal hints if code later reads fields
85
+ if not ckpt["performance"]:
86
+ ckpt["performance"] = {
87
+ "best_step": ckpt.get("global_step", 0),
88
+ "best_metric": None,
89
+ }
90
+
91
+ # Load model from in-memory checkpoint (no file I/O needed)
92
+ buffer = io.BytesIO()
93
+ torch.save(ckpt, buffer)
94
+ buffer.seek(0) # Reset buffer position to the beginning
95
+
96
+ model_load = LightningModel.load_from_checkpoint(buffer, map_location="cpu")
97
+ model_load.to(device)
98
+ else:
99
+ raise ValueError(f"Unknown mode {mode!r}; expected 'train' or 'eval'")
100
+
101
+ model_load.train_params['logger'] = None
102
+
103
+ return model_load
104
+
105
+ def cal_gosai_pred(seqs, model=None, mode='eval'):
106
+ """
107
+ seqs: list of sequences (detokenized ACGT...)
108
+ """
109
+ if model is None:
110
+ model = get_gosai_oracle(mode=mode)
111
+ df_seqs = pd.DataFrame(seqs, columns=['seq'])
112
+ pred_dataset = grelu.data.dataset.DFSeqDataset(df_seqs)
113
+ preds = model.predict_on_dataset(pred_dataset, devices=[0])
114
+ return preds.squeeze() # numpy array with shape [n_seqs, 3]
115
+
116
+ def cal_gosai_pred_new(seqs, model=None, mode='eval'):
117
+ """
118
+ seqs: list of sequences (detokenized ACGT...)
119
+ """
120
+ if model is None:
121
+ model = get_gosai_oracle(mode=mode)
122
+ model.eval()
123
+ tokens = dataloader_gosai.batch_dna_tokenize(seqs)
124
+ tokens = torch.tensor(tokens).long().to(model.device)
125
+ onehot_tokens = F.one_hot(tokens, num_classes=4).float()
126
+ preds = model(onehot_tokens.float().transpose(1, 2)).detach().cpu().numpy()
127
+ return preds.squeeze()
128
+
129
+ def cal_atac_pred(seqs, model=None):
130
+ """
131
+ seqs: list of sequences (detokenized ACGT...)
132
+ """
133
+ if model is None:
134
+ model = LightningModel.load_from_checkpoint(os.path.join(base_path, 'mdlm/gosai_data/binary_atac_cell_lines.ckpt'), map_location='cuda')
135
+ df_seqs = pd.DataFrame(seqs, columns=['seq'])
136
+ pred_dataset = grelu.data.dataset.DFSeqDataset(df_seqs)
137
+ preds = model.predict_on_dataset(pred_dataset, devices=[0])
138
+ return preds.squeeze() # numpy array with shape [n_seqs, 7]
139
+
140
+
141
+ def cal_atac_pred_new(seqs, model=None):
142
+ """
143
+ seqs: list of sequences (detokenized ACGT...)
144
+ """
145
+ if model is None:
146
+ model = LightningModel.load_from_checkpoint(os.path.join(base_path, 'mdlm/gosai_data/binary_atac_cell_lines.ckpt'), map_location='cuda')
147
+ model.eval()
148
+ tokens = dataloader_gosai.batch_dna_tokenize(seqs)
149
+ tokens = torch.tensor(tokens).long().to(model.device)
150
+ onehot_tokens = F.one_hot(tokens, num_classes=4).float()
151
+ preds = model(onehot_tokens.float().transpose(1, 2)).detach().cpu().numpy()
152
+ return preds.squeeze() # numpy array with shape [n_seqs, 7]
153
+
154
+
155
+ def count_kmers(seqs, k=3):
156
+ counts = {}
157
+ for seq in seqs:
158
+ for i in range(len(seq) - k + 1):
159
+ subseq = seq[i : i + k]
160
+ try:
161
+ counts[subseq] += 1
162
+ except KeyError:
163
+ counts[subseq] = 1
164
+ return counts
165
+
166
+
167
+ def subset_for_eval(n=5000, seed=0):
168
+ train_set = dataloader_gosai.get_datasets_gosai()
169
+ np.random.seed(seed)
170
+ torch.manual_seed(seed)
171
+ train_set_sp = torch.utils.data.Subset(train_set, np.random.choice(len(train_set), n, replace=False))
172
+ return train_set_sp
173
+
174
+
175
+ def subset_eval_groundtruth(sets_sp):
176
+ train_set_sp = sets_sp
177
+ train_set_sp_clss = train_set_sp.dataset.clss[train_set_sp.indices]
178
+ return train_set_sp_clss
179
+
180
+
181
+ def subset_eval_preds(sets_sp, oracle_model=None):
182
+ train_set_sp = sets_sp
183
+ train_preds = cal_gosai_pred(
184
+ dataloader_gosai.batch_dna_detokenize(train_set_sp.dataset.seqs[train_set_sp.indices].numpy()), oracle_model)
185
+ return train_preds
186
+
187
+
188
+ def subset_eval_kmers(sets_sp, k=3):
189
+ train_set_sp = sets_sp
190
+ train_seqs = dataloader_gosai.batch_dna_detokenize(train_set_sp.dataset.seqs[train_set_sp.indices].numpy())
191
+ train_kmers = count_kmers(train_seqs, k)
192
+ return train_kmers
193
+
194
+
195
+ def subset_eval_embs(sets_sp, oracle_model=None):
196
+ train_set_sp = sets_sp
197
+ train_sp_emb = cal_gosai_emb(
198
+ dataloader_gosai.batch_dna_detokenize(train_set_sp.dataset.seqs[train_set_sp.indices].numpy()), oracle_model)
199
+ return train_sp_emb
200
+
201
+
202
+ def cal_emb_pca(sets_sp, n_components=50, oracle_model=None):
203
+ train_set_sp = sets_sp
204
+ train_sp_emb = cal_gosai_emb(
205
+ dataloader_gosai.batch_dna_detokenize(train_set_sp.dataset.seqs[train_set_sp.indices].numpy()), oracle_model)
206
+ from sklearn.decomposition import PCA
207
+ pca = PCA(n_components=n_components)
208
+ pca.fit(train_sp_emb.reshape(train_sp_emb.shape[0], -1))
209
+ return pca
210
+
211
+
212
+ def subset_eval_embs_pca(sets_sp, pca, oracle_model=None):
213
+ train_sp_emb = subset_eval_embs(sets_sp, oracle_model)
214
+ train_sp_emb_pca = pca.transform(train_sp_emb.reshape(train_sp_emb.shape[0], -1))
215
+ return train_sp_emb_pca
216
+
217
+
218
+ # https://github.com/HannesStark/dirichlet-flow-matching/blob/main/utils/flow_utils.py
219
+ def get_wasserstein_dist(embeds1, embeds2):
220
+ if np.isnan(embeds2).any() or np.isnan(embeds1).any() or len(embeds1) == 0 or len(embeds2) == 0:
221
+ return float('nan')
222
+ mu1, sigma1 = embeds1.mean(axis=0), np.cov(embeds1, rowvar=False)
223
+ mu2, sigma2 = embeds2.mean(axis=0), np.cov(embeds2, rowvar=False)
224
+ ssdiff = np.sum((mu1 - mu2) ** 2.0)
225
+ covmean = sqrtm(sigma1.dot(sigma2))
226
+ if np.iscomplexobj(covmean):
227
+ covmean = covmean.real
228
+ dist = ssdiff + np.trace(sigma1 + sigma2 - 2.0 * covmean)
229
+ return dist
+
+
+ def embed_on_dataset(
+     model,
+     dataset: Callable,
+     devices: Union[str, int, List[int]] = "cpu",
+     num_workers: int = 1,
+     batch_size: int = 256,
+ ):
+     """
+     Return embeddings for a dataset of sequences
+
+     Args:
+         model: Model whose `model.embedding` module produces the embeddings
+         dataset: Dataset object that yields one-hot encoded sequences
+         devices: Device IDs to use
+         num_workers: Number of workers for data loader
+         batch_size: Batch size for data loader
+
+     Returns:
+         Numpy array of shape (B, T, L) containing embeddings.
+     """
+     torch.set_float32_matmul_precision("medium")
+
+     # Make dataloader
+     dataloader = model.make_predict_loader(
+         dataset, num_workers=num_workers, batch_size=batch_size
+     )
+
+     # Get device
+     orig_device = model.device
+     device = model.parse_devices(devices)[1]
+     if isinstance(device, list):
+         device = device[0]
+     model.to(device)
+
+     # Get embeddings
+     preds = []
+     model.model = model.model.eval()
+     for batch in iter(dataloader):
+         batch = batch.to(device)
+         preds.append(model.model.embedding(batch).detach().cpu())
+
+     # Return to original device
+     model.to(orig_device)
+     return torch.vstack(preds).numpy()
+
+
+ def cal_gosai_emb(seqs, model=None):
+     """
+     seqs: list of sequences (detokenized ACGT...)
+     """
+     if model is None:
+         model = get_gosai_oracle()
+     df_seqs = pd.DataFrame(seqs, columns=['seq'])
+     pred_dataset = grelu.data.dataset.DFSeqDataset(df_seqs)
+     embs = embed_on_dataset(model, pred_dataset, devices=[0])
+     return embs  # numpy array with shape [n_seqs, 3072, 2]
+
+
+ def cal_highexp_kmers(k=3, return_clss=False):
+     train_set = dataloader_gosai.get_datasets_gosai()
+     exp_threshold = np.quantile(train_set.clss[:, 0].numpy(), 0.99)  # 4.56
+     highexp_indices = [i for i, data in enumerate(train_set) if data['clss'][0] > exp_threshold]
+     highexp_set_sp = torch.utils.data.Subset(train_set, highexp_indices)
+     highexp_seqs = dataloader_gosai.batch_dna_detokenize(highexp_set_sp.dataset.seqs[highexp_set_sp.indices].numpy())
+     highexp_kmers_99 = count_kmers(highexp_seqs, k=k)
+     n_highexp_kmers_99 = len(highexp_indices)
+
+     exp_threshold = np.quantile(train_set.clss[:, 0].numpy(), 0.999)  # 6.27
+     highexp_indices = [i for i, data in enumerate(train_set) if data['clss'][0] > exp_threshold]
+     highexp_set_sp = torch.utils.data.Subset(train_set, highexp_indices)
+     highexp_seqs = dataloader_gosai.batch_dna_detokenize(highexp_set_sp.dataset.seqs[highexp_set_sp.indices].numpy())
+     highexp_kmers_999 = count_kmers(highexp_seqs, k=k)
+     n_highexp_kmers_999 = len(highexp_indices)
+
+     if return_clss:
+         highexp_set_sp_clss_999 = highexp_set_sp.dataset.clss[highexp_set_sp.indices]
+         highexp_preds_999 = cal_gosai_pred_new(
+             dataloader_gosai.batch_dna_detokenize(highexp_set_sp.dataset.seqs[highexp_set_sp.indices].numpy()))
+         return highexp_kmers_99, n_highexp_kmers_99, highexp_kmers_999, n_highexp_kmers_999, highexp_set_sp_clss_999, highexp_preds_999, highexp_seqs
+
+     return highexp_kmers_99, n_highexp_kmers_99, highexp_kmers_999, n_highexp_kmers_999
+
+
+ def cal_kmer_corr(model, highexp_kmers, n_highexp_kmers, n_sample=128):
+     model.eval()
+     all_detokenized_samples = []
+     for _ in range(10):
+         samples = model._sample(eval_sp_size=n_sample).detach().cpu().numpy()
+         detokenized_samples = dataloader_gosai.batch_dna_detokenize(samples)
+         all_detokenized_samples.extend(detokenized_samples)
+     generated_kmer = count_kmers(all_detokenized_samples)
+
+     kmer_set = set(highexp_kmers.keys()) | set(generated_kmer.keys())
+     counts = np.zeros((len(kmer_set), 2))
+     for i, kmer in enumerate(kmer_set):
+         if kmer in highexp_kmers:
+             counts[i][1] = highexp_kmers[kmer] * len(generated_kmer) / n_highexp_kmers
+         if kmer in generated_kmer:
+             counts[i][0] = generated_kmer[kmer]
+
+     corr = pearsonr(counts[:, 0], counts[:, 1])[0]
+     return corr
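
Reviewer note: `cal_kmer_corr` compares k-mer counts of generated sequences against those of the high-activity reference set over the union of observed k-mers, then takes the Pearson correlation. `count_kmers` is defined elsewhere in this file and is not shown in this hunk, so the sketch below assumes it counts every overlapping k-mer; `count_kmers_sketch`, `pearson`, and `kmer_corr` are hypothetical names for illustration. One point worth noting: Pearson correlation is invariant to positive rescaling of either column, so the factor `len(generated_kmer) / n_highexp_kmers` applied to the reference counts does not change the returned value.

```python
from collections import Counter
import math


def count_kmers_sketch(seqs, k=3):
    """Count every overlapping k-mer across a list of sequences (assumed behavior)."""
    counts = Counter()
    for seq in seqs:
        for i in range(len(seq) - k + 1):
            counts[seq[i:i + k]] += 1
    return counts


def pearson(xs, ys):
    """Pearson correlation coefficient of two equal-length lists."""
    n = len(xs)
    mx, my = sum(xs) / n, sum(ys) / n
    cov = sum((x - mx) * (y - my) for x, y in zip(xs, ys))
    sx = math.sqrt(sum((x - mx) ** 2 for x in xs))
    sy = math.sqrt(sum((y - my) ** 2 for y in ys))
    return cov / (sx * sy)


def kmer_corr(ref_seqs, gen_seqs, k=3):
    ref = count_kmers_sketch(ref_seqs, k)
    gen = count_kmers_sketch(gen_seqs, k)
    kmers = sorted(set(ref) | set(gen))  # union; a Counter returns 0 for missing k-mers
    return pearson([gen[m] for m in kmers], [ref[m] for m in kmers])


ref = ["ACGTACGT", "ACGTTTGA"]
print(kmer_corr(ref, ref))        # same sequences on both sides
print(kmer_corr(ref, ref + ref))  # generated counts doubled; correlation unchanged
```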
334
+
335
+ def cal_avg_likelihood(model, old_model, n_sample=128):
336
+ model.eval()
337
+ old_model.eval()
338
+ all_raw_samples = []
339
+ for _ in range(10):
340
+ samples = model._sample(eval_sp_size=n_sample)
341
+ all_raw_samples.append(samples)
342
+ all_raw_samples = torch.concat(all_raw_samples)
343
+ avg_likelihood = old_model._forward_pass_diffusion(all_raw_samples).sum(-1).mean().item()
344
+ return avg_likelihood
tr2d2-dna/run_batch_eval.sh ADDED
@@ -0,0 +1,30 @@
+ #!/bin/bash
+ #SBATCH --job-name=dna
+ #SBATCH --partition=coe-gpu
+ #SBATCH --gres=gpu:H200:1
+ #SBATCH --time=16:00:00
+ #SBATCH --mem-per-gpu=60G
+ #SBATCH --cpus-per-task=2
+ #SBATCH --wait-all-nodes=1
+ #SBATCH --output=../outputs/%j.%x/.log
+
+
+ # Set the path to your runs directory
+ RUNS_DIR="" # Fill in the directory containing the checkpoints to evaluate
+
+ # Set output file name with timestamp
+ TIMESTAMP=$(date +"%Y%m%d_%H%M%S")
+ OUTPUT_FILE="batch_eval_results_${TIMESTAMP}.txt"
+
+ # Run the batch evaluation
+ python eval_runs_batch.py \
+     --runs_dir "$RUNS_DIR" \
+     --output_file "$OUTPUT_FILE" \
+     --device "cuda:0" \
+     --total_num_steps 128 \
+     --batch_size 128 \
+     --num_seeds 3 \
+     --total_samples 640 \
+     --seq_length 200
+
+ echo "Batch evaluation completed. Results saved to: $OUTPUT_FILE"
tr2d2-dna/train.sh ADDED
@@ -0,0 +1,51 @@
+ #!/bin/bash
+
+ #SBATCH --job-name=dna_mdns
+ #SBATCH --partition=coe-gpu
+ #SBATCH --gres=gpu:H100:1
+ #SBATCH --time=16:00:00
+ # max 16 GPU hours, i.e., time <= 16h / num of GPUs
+ #SBATCH --mem-per-gpu=60G
+ # maximum GPU RAM, 141G for H200, 94G for H100
+ # in the current setting, 40G is enough for num_replicates=2 and 80G is enough for num_replicates=4
+ #SBATCH --cpus-per-task=2
+ #SBATCH --wait-all-nodes=1
+ #SBATCH --output=../outputs/%j.%x/.log
+
+
+ # Note: bash assignments must not have spaces around "="
+ HOME_LOC="" # Fill in directory of the repo
+ SAVE_PATH="" # Fill in directory to save the checkpoints
+ BASE_PATH="" # Fill in directory of the pretrained checkpoints, e.g., "...../data_and_model/"
+ SCRIPT_LOC=$HOME_LOC/tr2d2/dna
+ LOG_LOC=$HOME_LOC/tr2d2/dna/logs
+ DATE=$(date +%m_%d)
+
+
+
+ mkdir -p "$LOG_LOC"
+
+ # set 3 have skip connection
+ # ===================================================================
+ python "$SCRIPT_LOC/finetune.py" \
+     --base_path "$BASE_PATH" \
+     --device "cuda:0" \
+     --noise_removal \
+     --save_path_dir "$SAVE_PATH" \
+     --wdce_num_replicates 16 \
+     --buffer_size 160 \
+     --batch_size 160 \
+     --seq_length 200 \
+     --num_children 32 \
+     --total_num_steps 128 \
+     --num_iter 5 \
+     --resample_every_n_step 5 \
+     --eval_every_n_epochs 10 \
+     --num_epochs 60000 \
+     --exploration 0.1 \
+     --save_every_n_epoch 2000 \
+     --alpha 0.1 \
+     --centering \
+     --grad_clip \
+     --reward_clip \
+     --reward_clip_value 15.0 \
+     --reset_tree
tr2d2-dna/utils.py ADDED
@@ -0,0 +1,175 @@
+ """Console logger utilities.
+
+ Copied from https://github.com/HazyResearch/transformers/blob/master/src/utils/utils.py
+ Copied from https://docs.python.org/3/howto/logging-cookbook.html#using-a-context-manager-for-selective-logging
+ """
+
+ import logging
+ import fsspec
+ import lightning
+ import torch
+ from timm.scheduler import CosineLRScheduler
+ import argparse
+ import numpy as np
+ import random
+ import os
+ import time
+ from collections import defaultdict
+ from contextlib import contextmanager
+
+ class StepTimer:
+     def __init__(self, device=None):
+         self.times = defaultdict(list)
+         self.device = device
+         self._use_cuda_sync = (
+             isinstance(device, torch.device) and device.type == "cuda"
+         ) or (isinstance(device, str) and "cuda" in device)
+
+     @contextmanager
+     def section(self, name):
+         if self._use_cuda_sync:
+             torch.cuda.synchronize()
+         t0 = time.perf_counter()
+         try:
+             yield
+         finally:
+             if self._use_cuda_sync:
+                 torch.cuda.synchronize()
+             dt = time.perf_counter() - t0
+             self.times[name].append(dt)
+
+     def summary(self, top_k=None):
+         # returns (name, count, total, mean, p50, p95)
+         rows = []
+         for k, v in self.times.items():
+             a = np.array(v, dtype=float)
+             rows.append((k, len(a), a.sum(), a.mean(), np.median(a), np.percentile(a, 95)))
+         rows.sort(key=lambda r: r[2], reverse=True)  # by total time
+         return rows[:top_k] if top_k else rows
+
+
+ def sample_categorical_logits(logits, dtype=torch.float64):
+     # Gumbel-max sampling; does not require logits to be log-softmaxed
+     gumbel_noise = -(1e-10 - (torch.rand_like(logits, dtype=dtype) + 1e-10).log()).log()
+     return (logits + gumbel_noise).argmax(dim=-1)
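
Reviewer note: `sample_categorical_logits` relies on the Gumbel-max trick: adding independent Gumbel(0, 1) noise to each logit and taking the argmax yields a sample from `softmax(logits)`, which is why unnormalized logits are acceptable. The sketch below is a torch-free illustration of the same idea using only the standard library; `sample_categorical_gumbel` is a hypothetical name, and the sample size and seed are arbitrary choices for the demonstration.

```python
import math
import random


def sample_categorical_gumbel(logits, rng):
    """Draw one index with probability softmax(logits) via the Gumbel-max trick."""
    best_idx, best_val = 0, -math.inf
    for i, logit in enumerate(logits):
        u = rng.random()  # U ~ Uniform[0, 1)
        # Gumbel(0, 1) noise, with the same 1e-10 guards as the tensor version.
        g = -math.log(1e-10 - math.log(u + 1e-10))
        if logit + g > best_val:
            best_idx, best_val = i, logit + g
    return best_idx


rng = random.Random(0)
logits = [0.0, 1.0, 2.0]
n = 20000
freq = [0] * len(logits)
for _ in range(n):
    freq[sample_categorical_gumbel(logits, rng)] += 1
freq = [c / n for c in freq]

# Reference distribution: softmax of the logits.
z = sum(math.exp(l) for l in logits)
softmax = [math.exp(l) / z for l in logits]

print(freq)     # empirical frequencies
print(softmax)  # target probabilities
```

The empirical frequencies should sit close to the softmax probabilities, up to sampling noise.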
+
+ def fsspec_exists(filename):
+     """Check if a file exists using fsspec."""
+     fs, _ = fsspec.core.url_to_fs(filename)
+     return fs.exists(filename)
+
+
+ def fsspec_listdir(dirname):
+     """Listdir in manner compatible with fsspec."""
+     fs, _ = fsspec.core.url_to_fs(dirname)
+     return fs.ls(dirname)
+
+
+ def fsspec_mkdirs(dirname, exist_ok=True):
+     """Mkdirs in manner compatible with fsspec."""
+     fs, _ = fsspec.core.url_to_fs(dirname)
+     fs.makedirs(dirname, exist_ok=exist_ok)
+
+
+ def print_nans(tensor, name):
+     if torch.isnan(tensor).any():
+         print(name, tensor)
+
+
+ class CosineDecayWarmupLRScheduler(
+         CosineLRScheduler,
+         torch.optim.lr_scheduler._LRScheduler):
+     """Wrap timm.scheduler.CosineLRScheduler
+     Enables calling scheduler.step() without passing in epoch.
+     Supports resuming as well.
+     Adapted from:
+         https://github.com/HazyResearch/hyena-dna/blob/main/src/utils/optim/schedulers.py
+     """
+
+     def __init__(self, *args, **kwargs):
+         super().__init__(*args, **kwargs)
+         self._last_epoch = -1
+         self.step(epoch=0)
+
+     def step(self, epoch=None):
+         if epoch is None:
+             self._last_epoch += 1
+         else:
+             self._last_epoch = epoch
+         # We call either step or step_update, depending on
+         # whether we're using the scheduler every epoch or every
+         # step.
+         # Otherwise, lightning will always call step (i.e.,
+         # meant for each epoch), and if we set scheduler
+         # interval to "step", then the learning rate update will
+         # be wrong.
+         if self.t_in_epochs:
+             super().step(epoch=self._last_epoch)
+         else:
+             super().step_update(num_updates=self._last_epoch)
+
+
+ class LoggingContext:
+     """Context manager for selective logging."""
+     def __init__(self, logger, level=None, handler=None, close=True):
+         self.logger = logger
+         self.level = level
+         self.handler = handler
+         self.close = close
+
+     def __enter__(self):
+         if self.level is not None:
+             self.old_level = self.logger.level
+             self.logger.setLevel(self.level)
+         if self.handler:
+             self.logger.addHandler(self.handler)
+
+     def __exit__(self, et, ev, tb):
+         if self.level is not None:
+             self.logger.setLevel(self.old_level)
+         if self.handler:
+             self.logger.removeHandler(self.handler)
+         if self.handler and self.close:
+             self.handler.close()
+
+
+ def get_logger(name=__name__, level=logging.INFO) -> logging.Logger:
+     """Initializes multi-GPU-friendly python logger."""
+
+     logger = logging.getLogger(name)
+     logger.setLevel(level)
+
+     # this ensures all logging levels get marked with the rank zero decorator
+     # otherwise logs would get multiplied for each GPU process in multi-GPU setup
+     for level in ('debug', 'info', 'warning', 'error',
+                   'exception', 'fatal', 'critical'):
+         setattr(logger,
+                 level,
+                 lightning.pytorch.utilities.rank_zero_only(
+                     getattr(logger, level)))
+
+     return logger
+
+
+ def str2bool(v):
+     if isinstance(v, bool):
+         return v
+     if v.lower() in ('yes', 'true', 't', 'y', '1'):
+         return True
+     elif v.lower() in ('no', 'false', 'f', 'n', '0'):
+         return False
+     else:
+         raise argparse.ArgumentTypeError('Boolean value expected.')
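
Reviewer note: `str2bool` is shaped to be passed as `type=` to `argparse`, since `type=bool` would treat any non-empty string (including `"False"`) as true. The sketch below shows one way it might be wired up; the helper is repeated so the block runs on its own, and `--use_cuda` is a hypothetical flag chosen for illustration, not an option confirmed by this commit.

```python
import argparse


def str2bool(v):
    # Same logic as the helper in utils.py, repeated so this sketch is self-contained.
    if isinstance(v, bool):
        return v
    if v.lower() in ('yes', 'true', 't', 'y', '1'):
        return True
    elif v.lower() in ('no', 'false', 'f', 'n', '0'):
        return False
    else:
        raise argparse.ArgumentTypeError('Boolean value expected.')


parser = argparse.ArgumentParser()
# Hypothetical flag. nargs='?' with const=True lets a bare `--use_cuda` mean True,
# while `--use_cuda no` or `--use_cuda 0` turns it off explicitly.
parser.add_argument('--use_cuda', type=str2bool, nargs='?', const=True, default=False)

print(parser.parse_args(['--use_cuda', 'no']).use_cuda)
print(parser.parse_args(['--use_cuda']).use_cuda)
print(parser.parse_args([]).use_cuda)
```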
+
+
+ def set_seed(seed, use_cuda):
+     os.environ['PYTHONHASHSEED'] = str(seed)
+     np.random.seed(seed)
+     random.seed(seed)
+     torch.manual_seed(seed)
+     # torch.backends.cudnn.deterministic = True
+     if use_cuda:
+         torch.cuda.manual_seed(seed)
+         torch.cuda.manual_seed_all(seed)
+     print(f'=> Seed of the run set to {seed}')