zyc4975matholic committed
Commit 303c2e0
1 Parent(s): 63d28a7

Include DNA training code
Browse files
- tr2d2-dna/README.md +49 -0
- tr2d2-dna/configs_gosai/callbacks/checkpoint_every_n_steps.yaml +8 -0
- tr2d2-dna/configs_gosai/callbacks/checkpoint_monitor.yaml +10 -0
- tr2d2-dna/configs_gosai/callbacks/learning_rate_monitor.yaml +3 -0
- tr2d2-dna/configs_gosai/config_gosai.yaml +109 -0
- tr2d2-dna/configs_gosai/lr_scheduler/constant_warmup.yaml +2 -0
- tr2d2-dna/configs_gosai/lr_scheduler/cosine_decay_warmup.yaml +7 -0
- tr2d2-dna/configs_gosai/model/dnaconv.yaml +12 -0
- tr2d2-dna/configs_gosai/noise/ar.yaml +2 -0
- tr2d2-dna/configs_gosai/noise/cosine.yaml +1 -0
- tr2d2-dna/configs_gosai/noise/geometric.yaml +3 -0
- tr2d2-dna/configs_gosai/noise/linear.yaml +3 -0
- tr2d2-dna/configs_gosai/noise/loglinear.yaml +3 -0
- tr2d2-dna/configs_gosai/noise/polynomial.yaml +5 -0
- tr2d2-dna/configs_gosai/strategy/ddp.yaml +2 -0
- tr2d2-dna/configs_gosai/strategy/fsdp.yaml +3 -0
- tr2d2-dna/dataloader_gosai.py +211 -0
- tr2d2-dna/diffusion.py +1604 -0
- tr2d2-dna/diffusion_gosai_cfg.py +729 -0
- tr2d2-dna/env.sh +20 -0
- tr2d2-dna/eval_runs_batch.py +347 -0
- tr2d2-dna/eval_utils.py +29 -0
- tr2d2-dna/finetune.py +149 -0
- tr2d2-dna/finetune_dna.py +113 -0
- tr2d2-dna/finetune_utils.py +147 -0
- tr2d2-dna/mcts.py +581 -0
- tr2d2-dna/models/__init__.py +2 -0
- tr2d2-dna/models/dnaconv.py +121 -0
- tr2d2-dna/models/ema.py +97 -0
- tr2d2-dna/noise_schedule.py +151 -0
- tr2d2-dna/oracle.py +344 -0
- tr2d2-dna/run_batch_eval.sh +30 -0
- tr2d2-dna/train.sh +51 -0
- tr2d2-dna/utils.py +175 -0
tr2d2-dna/README.md
ADDED
@@ -0,0 +1,49 @@

# TR2-D2 For Enhancer DNA Design

This part of the code is for finetuning DNA sequence models to optimize DNA enhancer activity with TR2-D2.

The codebase is built upon [MDLM (Sahoo et al., 2023)](https://github.com/kuleshov-group/mdlm), [DRAKES (Wang et al., 2024)](https://github.com/ChenyuWang-Monica/DRAKES), [SEPO (Zekri et al., 2025)](https://github.com/ozekri/SEPO/tree/main), and [MDNS (Zhu et al., 2025)](https://arxiv.org/abs/2508.10684).

## Environment Installation

```
conda create -n tr2d2-dna python=3.9.18
conda activate tr2d2-dna
bash env.sh
```

## Model Pretrained Weights Download

All data and model weights can be downloaded from the link below, provided by the [DRAKES](https://arxiv.org/abs/2410.13643) authors. Save the downloaded file in `$BASE_PATH`.

https://www.dropbox.com/scl/fi/zi6egfppp0o78gr0tmbb1/DRAKES_data.zip?rlkey=yf7w0pm64tlypwsewqc01wmfq&st=xe8dzn8k&dl=0

To download from the terminal, use

```
curl -L -o dna.zip "https://www.dropbox.com/scl/fi/zi6egfppp0o78gr0tmbb1/DRAKES_data.zip?rlkey=yf7w0pm64tlypwsewqc01wmfq&st=xe8dzn8k&dl=0"
unzip dna.zip
```

## Finetune with TR2-D2

After downloading the pretrained checkpoints, fill in `base_path` in `dataloader_gosai.py`, `oracle.py`, and `finetune.sh`. Fill in `HOME_LOC` and `SAVE_PATH` in `finetune.sh` as well.

Reproduce the DNA experiments with $\alpha = 0.1$ using

```
sbatch train.sh
```

## Evaluate Saved Checkpoints

The checkpoints will be saved to `SAVE_PATH`. Fill in `RUNS_DIR` in `run_batch_eval.sh` so that it matches `SAVE_PATH`. The checkpoints can then be evaluated with

```
sbatch run_batch_eval.sh
```
tr2d2-dna/configs_gosai/callbacks/checkpoint_every_n_steps.yaml
ADDED
@@ -0,0 +1,8 @@
checkpoint_every_n_steps:
  _target_: lightning.pytorch.callbacks.ModelCheckpoint
  save_top_k: -1 # Do not save any "best" models; this callback is being used to save every n train steps
  save_last: True # save model as ${save_dir}/checkpoints/last.ckpt
  dirpath: ${checkpointing.save_dir}/checkpoints
  verbose: True
  auto_insert_metric_name: False
  every_n_train_steps: 500
tr2d2-dna/configs_gosai/callbacks/checkpoint_monitor.yaml
ADDED
@@ -0,0 +1,10 @@
checkpoint_monitor:
  _target_: lightning.pytorch.callbacks.ModelCheckpoint
  monitor: val/nll # name of the logged metric which determines when model is improving
  mode: min # can be "max" or "min"
  save_top_k: 1 # save k best models (determined by above metric)
  save_last: False # True = additionally always save model from last epoch
  dirpath: ${checkpointing.save_dir}/checkpoints
  filename: best
  auto_insert_metric_name: False
  verbose: True
tr2d2-dna/configs_gosai/callbacks/learning_rate_monitor.yaml
ADDED
@@ -0,0 +1,3 @@
learning_rate_monitor:
  _target_: lightning.pytorch.callbacks.LearningRateMonitor
  logging_interval: step
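The three callback configs above are consumed through Hydra's `_target_` mechanism. The sketch below is illustrative only (the actual wiring lives in this repo's training entry point, which is not part of this diff); `build_callbacks` and `config` are assumed names.

```python
# Minimal sketch (assumed, not from this commit): each callback block with a
# _target_ key is instantiated into a Lightning callback before the Trainer is built.
import hydra.utils


def build_callbacks(config):
  # config.callbacks maps callback names to {_target_: ..., **kwargs} blocks.
  return [hydra.utils.instantiate(cb) for cb in config.callbacks.values()]

# trainer = lightning.Trainer(callbacks=build_callbacks(config), ...)
```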
tr2d2-dna/configs_gosai/config_gosai.yaml
ADDED
@@ -0,0 +1,109 @@
defaults:
  - _self_
  - /callbacks: [checkpoint_every_n_steps, checkpoint_monitor, learning_rate_monitor]
  - /model: dnaconv
  - /strategy: ddp
  - /noise: loglinear
  - /lr_scheduler: constant_warmup

mode: train
diffusion: absorbing_state
backbone: cnn
parameterization: subs
time_conditioning: False
T: 0  # 0 (continuous time) / 1000
subs_masking: False
debug_mode: False

seed: 1

data:
  streaming: False

loader:
  global_batch_size: 512
  eval_global_batch_size: ${.global_batch_size}
  # Note: batch_size and eval_batch_size are **per machine**
  batch_size: ${div_up:${.global_batch_size}, ${eval:${trainer.devices} * ${trainer.num_nodes}}}
  eval_batch_size: ${div_up:${.eval_global_batch_size}, ${eval:${trainer.devices} * ${trainer.num_nodes}}}
  num_workers: ${eval:"len(__import__('os').sched_getaffinity(0))"}
  pin_memory: True

sampling:
  predictor: ddpm
  steps: 128
  noise_removal: True
  num_sample_batches: 2  # Total samples: `num_gpus` * `loader.eval_batch_size` * num_sample_batches
  num_sample_log: 2
  semi_ar: False
  stride_length: 1
  num_strides: 1

training:
  ema: 0.9999
  antithetic_sampling: True
  importance_sampling: False
  sampling_eps: 1e-3
  change_of_variables: False

eval:
  checkpoint_path: ''  # Used to evaluate a checkpoint after training.
  disable_ema: False
  compute_generative_perplexity: True  # False
  perplexity_batch_size: 8
  compute_perplexity_on_sanity: False
  gen_ppl_eval_model_name_or_path: gpt2-large  # gpt2-large, meta-llama/Llama-2-7b-hf
  generate_samples: True
  subset_size: 5000

optim:
  weight_decay: 0
  lr: 3e-4
  beta1: 0.9
  beta2: 0.999
  eps: 1e-8

trainer:
  _target_: lightning.Trainer
  accelerator: cuda
  num_nodes: 1
  devices: ${device_count:}
  accumulate_grad_batches: ${div_up:${loader.global_batch_size}, ${eval:${trainer.devices} * ${loader.batch_size} * ${trainer.num_nodes}}}
  gradient_clip_val: 1.0
  precision: 'bf16'
  num_sanity_val_steps: 2
  max_steps: 131500  # 100 epochs
  log_every_n_steps: 10
  limit_train_batches: 1.0  # train on full dataset, can be used to toggle quick run
  limit_val_batches: 1.0  # validate on full dataset, can be used to toggle quick run
  val_check_interval: 1000

wandb:
  project: gosai-dna
  notes: null
  group: null
  job_type: null
  name: null
  id: ${uuid:}
  tags:
    - ${noise.type}

hydra:
  run:
    dir: ${now:%Y.%m.%d}/${now:%H%M%S}
  job:
    chdir: true

checkpointing:
  # Use custom `save_dir` if, e.g., saving to S3 bucket, otherwise leave this parameter as is
  save_dir: ${cwd:}
  # Note: `checkpoints` path should correspond to `checkpoint_every_n_steps.dirpath`
  resume_from_ckpt: true
  resume_ckpt_path: ${.save_dir}/checkpoints/last.ckpt

finetuning:
  gumbel_softmax_temp: 1.0
  truncate_steps: 3

mcts:
  sampling: 0 # 0: gumbel noise, >0 top-k sampling
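The interpolations above rely on several custom OmegaConf resolvers (`div_up`, `eval`, `device_count`, `uuid`, `cwd`). Their real definitions live in this repo's `utils.py` (not shown in this diff); the registrations below are a minimal sketch, assuming the conventions of the MDLM codebase this project builds on.

```python
# Minimal sketch (assumed, not taken from this commit) of the resolvers that
# config_gosai.yaml interpolations such as ${div_up:...} and ${cwd:} depend on.
import os
import uuid

import torch
from omegaconf import OmegaConf

OmegaConf.register_new_resolver('cwd', os.getcwd)                          # checkpointing.save_dir
OmegaConf.register_new_resolver('device_count', torch.cuda.device_count)   # trainer.devices
OmegaConf.register_new_resolver('div_up', lambda x, y: (x + y - 1) // y)   # ceiling division for per-device batch sizes
OmegaConf.register_new_resolver('eval', eval)                              # evaluate arithmetic expressions in the config
OmegaConf.register_new_resolver('uuid', lambda: str(uuid.uuid4()))         # unique wandb run id
```

For example, with `global_batch_size: 512`, one node, and 4 visible GPUs, `batch_size` would resolve to `div_up(512, 4) = 128` per device, and `accumulate_grad_batches` to `div_up(512, 4 * 128 * 1) = 1`.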
tr2d2-dna/configs_gosai/lr_scheduler/constant_warmup.yaml
ADDED
@@ -0,0 +1,2 @@
_target_: transformers.get_constant_schedule_with_warmup
num_warmup_steps: 2500
tr2d2-dna/configs_gosai/lr_scheduler/cosine_decay_warmup.yaml
ADDED
@@ -0,0 +1,7 @@
_target_: utils.CosineDecayWarmupLRScheduler
t_in_epochs: False
t_initial: ${eval:${trainer.max_steps}-${.warmup_t}}
warmup_prefix: True
warmup_lr_init: 1e-6
warmup_t: ${eval:0.1*${trainer.max_steps}}
lr_min: 1e-6
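For reference, with `trainer.max_steps: 131500` from `config_gosai.yaml`, the two `${eval:...}` interpolations above work out as in this small check (illustrative numbers only):

```python
# Worked numbers for cosine_decay_warmup.yaml under max_steps = 131500.
max_steps = 131_500
warmup_t = 0.1 * max_steps        # 13150.0 warmup steps
t_initial = max_steps - warmup_t  # 118350.0 steps of cosine decay after the warmup prefix
print(warmup_t, t_initial)
```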
tr2d2-dna/configs_gosai/model/dnaconv.yaml
ADDED
@@ -0,0 +1,12 @@
name: dnaconv
type: cnn
length: 200 # for gosai
hidden_dim: 128
num_cnn_stacks: 4
dropout: 0.0
clean_data: False

cls_free_guidance: False
cls_free_threshold: 2.52
cls_free_prob: 0.3
cls_free_weight: 0.3 # weight in sampling
tr2d2-dna/configs_gosai/noise/ar.yaml
ADDED
@@ -0,0 +1,2 @@
type: ar
scale: 6.0
tr2d2-dna/configs_gosai/noise/cosine.yaml
ADDED
@@ -0,0 +1 @@
type: cosine
tr2d2-dna/configs_gosai/noise/geometric.yaml
ADDED
@@ -0,0 +1,3 @@
type: geometric
sigma_min: 1e-4
sigma_max: 20
tr2d2-dna/configs_gosai/noise/linear.yaml
ADDED
@@ -0,0 +1,3 @@
type: linear
sigma_min: 1e-3
sigma_max: 7.0
tr2d2-dna/configs_gosai/noise/loglinear.yaml
ADDED
@@ -0,0 +1,3 @@
type: loglinear
sigma_min: 1e-4
sigma_max: 20
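A quick sanity check on this default schedule (a sketch derived from `_ddpm_caching_update` in `diffusion.py`, which asserts the log-linear schedule): with total noise $\sigma(t) = -\log(1 - t)$, the masking probability $1 - e^{-\sigma(t)}$ equals $t$ itself, up to the small `sigma_min`/eps offsets handled in `noise_schedule.py`, which is why that update can use $t$ directly as the move chance.

```python
# Sketch: under the log-linear schedule, move_chance(t) == t (up to eps offsets),
# matching what _ddpm_caching_update in diffusion.py relies on.
import torch

t = torch.linspace(0.1, 0.9, 5)
sigma = -torch.log1p(-t)             # sigma(t) = -log(1 - t)
move_chance = 1 - torch.exp(-sigma)  # probability that a token is masked at time t
assert torch.allclose(move_chance, t)
```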
tr2d2-dna/configs_gosai/noise/polynomial.yaml
ADDED
@@ -0,0 +1,5 @@
type: polynomial
a: -3
b: 5
c: -4
eps: 1e-3
tr2d2-dna/configs_gosai/strategy/ddp.yaml
ADDED
@@ -0,0 +1,2 @@
_target_: lightning.pytorch.strategies.DDPStrategy
find_unused_parameters: false # TODO(yair): this seems hacky, I think if things are correct we shouldn't need this
tr2d2-dna/configs_gosai/strategy/fsdp.yaml
ADDED
@@ -0,0 +1,3 @@
# TODO(yair): Currently not compatible with grad clipping
_target_: lightning.pytorch.strategies.FSDPStrategy
sharding_strategy: SHARD_GRAD_OP
tr2d2-dna/dataloader_gosai.py
ADDED
@@ -0,0 +1,211 @@
import torch
import pandas as pd
import typing
import math
import utils
import numpy as np
import os

base_path = "" # Fill in directory of the pretrained checkpoints, e.g., "...../data_and_model/"
LOGGER = utils.get_logger(__name__)
DNA_ALPHABET = {'A': 0, 'C': 1, 'G': 2, 'T': 3} #, 'M': 4}
INDEX_TO_DNA = {v: k for k, v in DNA_ALPHABET.items()}
lookup_array = np.array([INDEX_TO_DNA[i] for i in range(len(INDEX_TO_DNA))])


def dna_detokenize(seq):
  return ''.join([list(DNA_ALPHABET.keys())[int(i)] for i in seq])

def batch_dna_detokenize(batch_seq):
  """
  batch_seq: numpy array of shape [batch_size, seq_len]
  return: list of strings
  """
  detokenized_batch = lookup_array[batch_seq]
  detokenized_batch = [''.join(seq) for seq in detokenized_batch]
  return detokenized_batch

def dna_tokenize(seq):
  return [DNA_ALPHABET[c] for c in seq]

def batch_dna_tokenize(batch_seq):
  """
  batch_seq: list of strings
  return: numpy array of shape [batch_size, seq_len]
  """
  tokenized_batch = np.array([[DNA_ALPHABET[c] for c in seq] for seq in batch_seq])
  return tokenized_batch

class GosaiDataset(torch.utils.data.Dataset):
  def __init__(self):
    data_df = pd.read_csv(os.path.join(base_path, f'mdlm/gosai_data/processed_data/gosai_all.csv'))
    self.seqs = torch.tensor(data_df['seq'].apply(lambda x: [DNA_ALPHABET[c] for c in x]).tolist())
    self.clss = torch.tensor(data_df[['hepg2', 'k562', 'sknsh']].to_numpy())
    LOGGER.info(f'Loaded data: seqs shape: {self.seqs.shape}, clss shape: {self.clss.shape}')

  def __len__(self):
    return len(self.seqs)

  def __getitem__(self, idx):
    return {'seqs': self.seqs[idx], 'clss': self.clss[idx], 'attention_mask': torch.ones(len(self.seqs[idx]))}


def get_datasets_gosai():
  return GosaiDataset()


def get_dataloaders_gosai(config, skip_valid=False, valid_seed=None):
  num_gpus = torch.cuda.device_count()
  if config.loader.global_batch_size % (
      num_gpus * config.trainer.accumulate_grad_batches) != 0:
    raise ValueError(
      f'Train Batch Size {config.training.batch_size}'
      f'not divisible by {num_gpus} gpus with accumulation '
      f'{config.trainer.accumulate_grad_batches}.')
  if config.loader.eval_global_batch_size % num_gpus != 0:
    raise ValueError(
      f'Eval Batch Size for {config.eval.batch_size} '
      f'not divisible by {num_gpus}.')
  train_set = GosaiDataset()
  # randomly sample a subset of the train_set as valid_set
  valid_set = torch.utils.data.Subset(train_set, np.random.choice(len(train_set), 40000, replace=False))
  test_set = torch.utils.data.Subset(train_set, np.random.choice(len(train_set), 40000, replace=False))

  train_loader = torch.utils.data.DataLoader(
    train_set,
    batch_size=config.loader.batch_size,
    num_workers=config.loader.num_workers,
    pin_memory=config.loader.pin_memory,
    shuffle=not config.data.streaming,
    persistent_workers=True)
  if skip_valid:
    valid_loader = None
    test_loader = None
  else:
    if valid_seed is None:
      shuffle_valid = False
      generator = None
    else:
      shuffle_valid = True
      generator = torch.Generator().manual_seed(valid_seed)
    valid_loader = torch.utils.data.DataLoader(
      valid_set,
      batch_size=config.loader.eval_batch_size,
      num_workers=config.loader.num_workers,
      pin_memory=config.loader.pin_memory,
      shuffle=shuffle_valid,
      generator=generator)
    test_loader = torch.utils.data.DataLoader(
      test_set,
      batch_size=config.loader.eval_batch_size,
      num_workers=config.loader.num_workers,
      pin_memory=config.loader.pin_memory,
      shuffle=shuffle_valid,
      generator=generator)

  return train_loader, valid_loader, test_loader


# Samplers adapted from: https://github.com/Dao-AILab/flash-attention/blob/main/training/src/datamodules/fault_tolerant_sampler.py
class RandomFaultTolerantSampler(torch.utils.data.RandomSampler):

  def __init__(self, *args, generator=None, **kwargs):
    # TD [2022-07-17]: We don't force the seed to be zero. We generate random seed,
    # which should be reproducible if pl.seed_everything was called beforehand.
    # This means that changing the seed of the experiment will also change the
    # sampling order.
    if generator is None:
      seed = int(torch.empty((), dtype=torch.int64).random_().item())
      generator = torch.Generator().manual_seed(seed)
    kwargs.pop('shuffle', None)
    super().__init__(*args, generator=generator, **kwargs)
    self.counter = 0
    self.restarting = False

  def state_dict(self):
    return {'random_state': self.generator.get_state(),
            'counter': self.counter}

  def load_state_dict(self, state_dict):
    self.generator.set_state(state_dict.get('random_state'))
    self.counter = state_dict['counter']
    # self.start_counter = self.counter
    self.restarting = True

  # TD [2022-08-28] Setting the len will cause PL to think there are only a few batches left per
  # epoch, and subsequent epoch will have very few batches.

  def __iter__(self) -> typing.Iterator[int]:
    n = len(self.data_source)

    self.state = self.generator.get_state()
    indices = torch.randperm(n, generator=self.generator).tolist()

    if not self.restarting:
      self.counter = 0
    else:
      indices = indices[self.counter:]
      self.restarting = False

    for index in indices:
      self.counter += 1
      yield index

    self.counter = 0


class FaultTolerantDistributedSampler(torch.utils.data.DistributedSampler):

  def __init__(self, *args, **kwargs):
    super().__init__(*args, **kwargs)
    self.counter = 0
    self.restarting = False

  def state_dict(self):
    return {'epoch': self.epoch, 'counter': self.counter}

  def load_state_dict(self, state_dict):
    self.epoch = state_dict['epoch']
    self.counter = state_dict['counter']
    self.restarting = True

  # TD [2022-08-28] Setting the len will cause PL to think there are only a few batches left per
  # epoch, and subsequent epoch will have very few batches.
  def __iter__(self):
    if self.shuffle:
      # deterministically shuffle based on epoch and seed
      g = torch.Generator()
      g.manual_seed(self.seed + self.epoch)
      indices = torch.randperm(len(self.dataset), generator=g).tolist()  # type: ignore[arg-type]
    else:
      indices = list(range(len(self.dataset)))  # type: ignore[arg-type]

    if not self.drop_last:
      # add extra samples to make it evenly divisible
      padding_size = self.total_size - len(indices)
      if padding_size <= len(indices):
        indices += indices[:padding_size]
      else:
        indices += (indices * math.ceil(
          padding_size / len(indices)))[:padding_size]
    else:
      # remove tail of data to make it evenly divisible.
      indices = indices[:self.total_size]
    assert len(indices) == self.total_size

    # subsample
    indices = indices[self.rank:self.total_size:self.num_replicas]
    assert len(indices) == self.num_samples

    if not self.restarting:
      self.counter = 0
    else:
      indices = indices[self.counter:]
      self.restarting = False

    for index in indices:
      self.counter += 1
      yield index

    self.counter = 0
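A hypothetical usage sketch (not part of the commit) of the helpers above, assuming `base_path` has been filled in so `GosaiDataset` can locate `gosai_all.csv`; the example sequences are made up.

```python
# Hypothetical usage of dataloader_gosai: round-trip the tokenizer helpers and
# inspect one dataset item. Assumes base_path points at the DRAKES data directory.
import dataloader_gosai

seqs = ['ACGT' * 50, 'TTTT' * 50]                    # two length-200 enhancer sequences
tokens = dataloader_gosai.batch_dna_tokenize(seqs)    # numpy array of shape (2, 200)
assert dataloader_gosai.batch_dna_detokenize(tokens) == seqs

dataset = dataloader_gosai.get_datasets_gosai()       # reads mdlm/gosai_data/processed_data/gosai_all.csv
item = dataset[0]
print(item['seqs'].shape, item['clss'].shape)         # torch.Size([200]), torch.Size([3])
```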
tr2d2-dna/diffusion.py
ADDED
|
@@ -0,0 +1,1604 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
import itertools
|
| 2 |
+
import math
|
| 3 |
+
from dataclasses import dataclass
|
| 4 |
+
|
| 5 |
+
import hydra.utils
|
| 6 |
+
import lightning as L
|
| 7 |
+
import numpy as np
|
| 8 |
+
import torch
|
| 9 |
+
import torch.nn.functional as F
|
| 10 |
+
import torchmetrics
|
| 11 |
+
from torch import Tensor
|
| 12 |
+
|
| 13 |
+
import dataloader_gosai
|
| 14 |
+
import models
|
| 15 |
+
import noise_schedule
|
| 16 |
+
import utils
|
| 17 |
+
import oracle
|
| 18 |
+
from scipy.stats import wasserstein_distance, pearsonr
|
| 19 |
+
from finetune_utils import to_one_hot
|
| 20 |
+
|
| 21 |
+
LOG2 = math.log(2)
|
| 22 |
+
LOGGER = utils.get_logger(__name__)
|
| 23 |
+
|
| 24 |
+
|
| 25 |
+
def _sample_categorical(categorical_probs):
|
| 26 |
+
gumbel_norm = (
|
| 27 |
+
1e-10
|
| 28 |
+
- (torch.rand_like(categorical_probs) + 1e-10).log())
|
| 29 |
+
return (categorical_probs / gumbel_norm).argmax(dim=-1).to(dtype=torch.long)
|
| 30 |
+
|
| 31 |
+
def _sample_categorical_gradient(categorical_probs, temp = 1.0):
|
| 32 |
+
gumbel_norm = (
|
| 33 |
+
1e-10 - (torch.rand_like(categorical_probs) + 1e-10).log())
|
| 34 |
+
output = torch.nn.functional.softmax((torch.log(categorical_probs)-torch.log(gumbel_norm))/temp, 2)
|
| 35 |
+
return output
|
| 36 |
+
|
| 37 |
+
def _unsqueeze(x, reference):
|
| 38 |
+
return x.view(
|
| 39 |
+
* x.shape,
|
| 40 |
+
* ((1,) * (len(reference.shape) - len(x.shape))))
|
| 41 |
+
|
| 42 |
+
def sample_batched_categorical(categorical_probs, batch_size):
|
| 43 |
+
"""
|
| 44 |
+
Generates `m` distinct sequences sampled from categorical probabilities
|
| 45 |
+
using the Gumbel distribution to ensure randomness while following probabilities
|
| 46 |
+
|
| 47 |
+
Args:
|
| 48 |
+
categorical_probs (torch.Tensor): tensor of shape (sequence_length, vocab_length)
|
| 49 |
+
representing categorical probabilities
|
| 50 |
+
m (int): number of distinct sequences to sample
|
| 51 |
+
|
| 52 |
+
Returns:
|
| 53 |
+
torch.Tensor: tensor of shape (m, sequence_length), where each row is a
|
| 54 |
+
distinct sequence of sampled category indices.
|
| 55 |
+
"""
|
| 56 |
+
_, sequence_length, vocab_size = categorical_probs.shape
|
| 57 |
+
|
| 58 |
+
# add Gumbel noise and sample m sequences
|
| 59 |
+
gumbel_noise = (-torch.log(-torch.log(torch.rand(batch_size, sequence_length, vocab_size) + 1e-10) + 1e-10)).to(categorical_probs.device)
|
| 60 |
+
noisy_scores = torch.log(categorical_probs) + gumbel_noise # add Gumbel noise to log probabilities
|
| 61 |
+
|
| 62 |
+
# select the highest score (most likely category after Gumbel noise)
|
| 63 |
+
sampled_sequences = noisy_scores.argmax(dim=-1).to(dtype=torch.long) # shape: (m, sequence_length)
|
| 64 |
+
|
| 65 |
+
return sampled_sequences
|
| 66 |
+
|
| 67 |
+
def sample_batched_top_k(categorical_probs, batch_size, k):
|
| 68 |
+
"""
|
| 69 |
+
Generates `m` sequences sampled from the top-k probabilities of each token
|
| 70 |
+
using Gumbel noise to ensure randomness and reduce bias towards the most likely options.
|
| 71 |
+
|
| 72 |
+
Args:
|
| 73 |
+
categorical_probs (torch.Tensor): A tensor of shape (sequence_length, vocab_length)
|
| 74 |
+
representing categorical probabilities.
|
| 75 |
+
m (int): Number of sequences to sample.
|
| 76 |
+
k (int): Number of top probabilities to consider for sampling.
|
| 77 |
+
|
| 78 |
+
Returns:
|
| 79 |
+
torch.Tensor: A tensor of shape (m, sequence_length), where each row is a
|
| 80 |
+
sampled sequence of category indices.
|
| 81 |
+
"""
|
| 82 |
+
_, sequence_length, vocab_length = categorical_probs.shape
|
| 83 |
+
|
| 84 |
+
# Add Gumbel noise to the log probabilities
|
| 85 |
+
gumbel_noise = -torch.log(-torch.log(torch.rand(batch_size, sequence_length, vocab_length) + 1e-10) + 1e-10).to(categorical_probs.device)
|
| 86 |
+
noisy_scores = torch.log(categorical_probs[None, :, :]) + gumbel_noise # Shape: (m, sequence_length, vocab_length)
|
| 87 |
+
|
| 88 |
+
# Get the top-k categories based on noisy scores
|
| 89 |
+
top_k_scores, top_k_indices = torch.topk(noisy_scores, k, dim=-1) # Shape: (m, sequence_length, k)
|
| 90 |
+
|
| 91 |
+
# Convert top-k scores back to probabilities and normalize
|
| 92 |
+
top_k_probs = torch.softmax(top_k_scores, dim=-1).to(categorical_probs.device) # Shape: (m, sequence_length, k)
|
| 93 |
+
|
| 94 |
+
# Sample randomly from the top-k probabilities
|
| 95 |
+
sampled_indices_in_top_k = torch.multinomial(top_k_probs.reshape(-1, k), num_samples=1).squeeze(-1).to(categorical_probs.device)
|
| 96 |
+
sampled_indices_in_top_k = sampled_indices_in_top_k.view(batch_size, sequence_length).to(categorical_probs.device) # Shape: (batch_size, sequence_length)
|
| 97 |
+
|
| 98 |
+
# Map sampled indices back to the original vocabulary indices
|
| 99 |
+
sampled_sequences = torch.gather(top_k_indices, -1, sampled_indices_in_top_k.unsqueeze(-1)).squeeze(-1).to(categorical_probs.device).to(dtype=torch.long)
|
| 100 |
+
|
| 101 |
+
return sampled_sequences
|
| 102 |
+
|
| 103 |
+
@dataclass
|
| 104 |
+
class Loss:
|
| 105 |
+
loss: torch.FloatTensor
|
| 106 |
+
nlls: torch.FloatTensor
|
| 107 |
+
token_mask: torch.FloatTensor
|
| 108 |
+
|
| 109 |
+
|
| 110 |
+
class NLL(torchmetrics.aggregation.MeanMetric):
|
| 111 |
+
pass
|
| 112 |
+
|
| 113 |
+
|
| 114 |
+
class BPD(NLL):
|
| 115 |
+
def compute(self) -> Tensor:
|
| 116 |
+
"""Computes the bits per dimension.
|
| 117 |
+
|
| 118 |
+
Returns:
|
| 119 |
+
bpd
|
| 120 |
+
"""
|
| 121 |
+
return self.mean_value / self.weight / LOG2
|
| 122 |
+
|
| 123 |
+
|
| 124 |
+
class Perplexity(NLL):
|
| 125 |
+
def compute(self) -> Tensor:
|
| 126 |
+
"""Computes the Perplexity.
|
| 127 |
+
|
| 128 |
+
Returns:
|
| 129 |
+
Perplexity
|
| 130 |
+
"""
|
| 131 |
+
return torch.exp(self.mean_value / self.weight)
|
| 132 |
+
|
| 133 |
+
|
| 134 |
+
class Diffusion(L.LightningModule):
|
| 135 |
+
def __init__(
|
| 136 |
+
self,
|
| 137 |
+
config,
|
| 138 |
+
eval=False):
|
| 139 |
+
|
| 140 |
+
super().__init__()
|
| 141 |
+
self.save_hyperparameters()
|
| 142 |
+
self.config = config
|
| 143 |
+
self.vocab_size = 4
|
| 144 |
+
self.sampler = self.config.sampling.predictor
|
| 145 |
+
self.antithetic_sampling = self.config.training.antithetic_sampling
|
| 146 |
+
self.importance_sampling = self.config.training.importance_sampling
|
| 147 |
+
self.change_of_variables = self.config.training.change_of_variables
|
| 148 |
+
# add mask token
|
| 149 |
+
self.mask_index = self.vocab_size
|
| 150 |
+
self.vocab_size += 1
|
| 151 |
+
self.parameterization = self.config.parameterization
|
| 152 |
+
|
| 153 |
+
# dna backbone model
|
| 154 |
+
if self.config.backbone == 'cnn':
|
| 155 |
+
self.backbone = models.dnaconv.CNNModel(
|
| 156 |
+
self.config.model, alphabet_size=self.vocab_size, num_cls=3) # num_cls is not used since classifier is always set to False
|
| 157 |
+
else:
|
| 158 |
+
raise ValueError(f'Unknown backbone: {self.config.backbone}')
|
| 159 |
+
|
| 160 |
+
self.T = self.config.T
|
| 161 |
+
self.subs_masking = self.config.subs_masking
|
| 162 |
+
|
| 163 |
+
self.softplus = torch.nn.Softplus()
|
| 164 |
+
# metrics are automatically reset at end of epoch
|
| 165 |
+
metrics = torchmetrics.MetricCollection({
|
| 166 |
+
'nll': NLL(),
|
| 167 |
+
'bpd': BPD(),
|
| 168 |
+
'ppl': Perplexity(),
|
| 169 |
+
})
|
| 170 |
+
metrics.set_dtype(torch.float64)
|
| 171 |
+
self.train_metrics = metrics.clone(prefix='train/')
|
| 172 |
+
self.valid_metrics = metrics.clone(prefix='val/')
|
| 173 |
+
self.test_metrics = metrics.clone(prefix='test/')
|
| 174 |
+
|
| 175 |
+
# generative perplexity
|
| 176 |
+
self.gen_ppl_metric = Perplexity()
|
| 177 |
+
self.noise = noise_schedule.get_noise(self.config,
|
| 178 |
+
dtype=self.dtype)
|
| 179 |
+
|
| 180 |
+
# ema
|
| 181 |
+
if self.config.training.ema > 0:
|
| 182 |
+
self.ema = models.ema.ExponentialMovingAverage(
|
| 183 |
+
itertools.chain(self.backbone.parameters(),
|
| 184 |
+
self.noise.parameters()),
|
| 185 |
+
decay=self.config.training.ema)
|
| 186 |
+
else:
|
| 187 |
+
self.ema = None
|
| 188 |
+
|
| 189 |
+
self.lr = self.config.optim.lr
|
| 190 |
+
self.sampling_eps = self.config.training.sampling_eps
|
| 191 |
+
self.time_conditioning = self.config.time_conditioning
|
| 192 |
+
self.neg_infinity = -1000000.0
|
| 193 |
+
self.fast_forward_epochs = None
|
| 194 |
+
self.fast_forward_batches = None
|
| 195 |
+
self._validate_configuration()
|
| 196 |
+
|
| 197 |
+
# subset of data for evaluation
|
| 198 |
+
if eval:
|
| 199 |
+
self.eval_sets_sp = oracle.subset_for_eval(n=config.eval.subset_size)
|
| 200 |
+
self.eval_sets_sp_clss = oracle.subset_eval_groundtruth(self.eval_sets_sp)
|
| 201 |
+
self.eval_sets_sp_preds = oracle.subset_eval_preds(self.eval_sets_sp)
|
| 202 |
+
self.eval_sets_sp_kmers = oracle.subset_eval_kmers(self.eval_sets_sp)
|
| 203 |
+
self.emb_pca = oracle.cal_emb_pca(oracle.subset_for_eval(n=40000), n_components=50)
|
| 204 |
+
self.eval_sets_sp_embs_pca = oracle.subset_eval_embs_pca(self.eval_sets_sp, self.emb_pca)
|
| 205 |
+
|
| 206 |
+
def _validate_configuration(self):
|
| 207 |
+
assert not (self.change_of_variables and self.importance_sampling)
|
| 208 |
+
assert self.parameterization == 'subs'
|
| 209 |
+
|
| 210 |
+
|
| 211 |
+
def on_load_checkpoint(self, checkpoint):
|
| 212 |
+
if self.ema:
|
| 213 |
+
self.ema.load_state_dict(checkpoint['ema'])
|
| 214 |
+
# Copied from:
|
| 215 |
+
# https://github.com/Dao-AILab/flash-attention/blob/main/training/src/datamodules/language_modeling_hf.py#L41
|
| 216 |
+
self.fast_forward_epochs = checkpoint['loops']['fit_loop']['epoch_progress']['current']['completed']
|
| 217 |
+
self.fast_forward_batches = checkpoint['loops'][
|
| 218 |
+
'fit_loop']['epoch_loop.batch_progress'][
|
| 219 |
+
'current']['completed']
|
| 220 |
+
|
| 221 |
+
|
| 222 |
+
def on_save_checkpoint(self, checkpoint):
|
| 223 |
+
if self.ema:
|
| 224 |
+
checkpoint['ema'] = self.ema.state_dict()
|
| 225 |
+
# Copied from:
|
| 226 |
+
# https://github.com/Dao-AILab/flash-attention/blob/main/training/src/tasks/seq.py
|
| 227 |
+
# ['epoch_loop.batch_progress']['total']['completed'] is 1 iteration
|
| 228 |
+
# behind, so we're using the optimizer's progress.
|
| 229 |
+
checkpoint['loops']['fit_loop'][
|
| 230 |
+
'epoch_loop.batch_progress']['total'][
|
| 231 |
+
'completed'] = checkpoint['loops']['fit_loop'][
|
| 232 |
+
'epoch_loop.automatic_optimization.optim_progress'][
|
| 233 |
+
'optimizer']['step']['total'][
|
| 234 |
+
'completed'] * self.trainer.accumulate_grad_batches
|
| 235 |
+
checkpoint['loops']['fit_loop'][
|
| 236 |
+
'epoch_loop.batch_progress']['current'][
|
| 237 |
+
'completed'] = checkpoint['loops']['fit_loop'][
|
| 238 |
+
'epoch_loop.automatic_optimization.optim_progress'][
|
| 239 |
+
'optimizer']['step']['current'][
|
| 240 |
+
'completed'] * self.trainer.accumulate_grad_batches
|
| 241 |
+
# _batches_that_stepped tracks the number of global steps, not the number
|
| 242 |
+
# of local steps, so we don't multiply with self.trainer.accumulate_grad_batches here.
|
| 243 |
+
checkpoint['loops']['fit_loop'][
|
| 244 |
+
'epoch_loop.state_dict'][
|
| 245 |
+
'_batches_that_stepped'] = checkpoint['loops']['fit_loop'][
|
| 246 |
+
'epoch_loop.automatic_optimization.optim_progress'][
|
| 247 |
+
'optimizer']['step']['total']['completed']
|
| 248 |
+
if 'sampler' not in checkpoint.keys():
|
| 249 |
+
checkpoint['sampler'] = {}
|
| 250 |
+
if hasattr(self.trainer.train_dataloader.sampler, 'state_dict'):
|
| 251 |
+
sampler_state_dict = self.trainer.train_dataloader.sampler.state_dict()
|
| 252 |
+
checkpoint['sampler']['random_state'] = sampler_state_dict.get('random_state', None)
|
| 253 |
+
else:
|
| 254 |
+
checkpoint['sampler']['random_state'] = None
|
| 255 |
+
|
| 256 |
+
def on_train_start(self):
|
| 257 |
+
if self.ema:
|
| 258 |
+
self.ema.move_shadow_params_to_device(self.device)
|
| 259 |
+
|
| 260 |
+
distributed = (
|
| 261 |
+
self.trainer._accelerator_connector.use_distributed_sampler
|
| 262 |
+
and self.trainer._accelerator_connector.is_distributed)
|
| 263 |
+
|
| 264 |
+
print('distributed:', distributed)
|
| 265 |
+
|
| 266 |
+
if distributed:
|
| 267 |
+
sampler_cls = dataloader_gosai.FaultTolerantDistributedSampler
|
| 268 |
+
else:
|
| 269 |
+
sampler_cls = dataloader_gosai.RandomFaultTolerantSampler
|
| 270 |
+
|
| 271 |
+
updated_dls = []
|
| 272 |
+
for dl in self.trainer.fit_loop._combined_loader.flattened:
|
| 273 |
+
if hasattr(dl.sampler, 'shuffle'):
|
| 274 |
+
dl_sampler = sampler_cls(dl.dataset, shuffle=dl.sampler.shuffle)
|
| 275 |
+
else:
|
| 276 |
+
dl_sampler = sampler_cls(dl.dataset)
|
| 277 |
+
if (distributed and self.fast_forward_epochs is not None
|
| 278 |
+
and self.fast_forward_batches is not None):
|
| 279 |
+
|
| 280 |
+
dl_sampler.load_state_dict({
|
| 281 |
+
'epoch': self.fast_forward_epochs,
|
| 282 |
+
'counter': (self.fast_forward_batches
|
| 283 |
+
* self.config.loader.batch_size)})
|
| 284 |
+
updated_dls.append(
|
| 285 |
+
torch.utils.data.DataLoader(
|
| 286 |
+
dl.dataset,
|
| 287 |
+
batch_size=self.config.loader.batch_size,
|
| 288 |
+
num_workers=self.config.loader.num_workers,
|
| 289 |
+
pin_memory=self.config.loader.pin_memory,
|
| 290 |
+
sampler=dl_sampler,
|
| 291 |
+
shuffle=False,
|
| 292 |
+
persistent_workers=True))
|
| 293 |
+
|
| 294 |
+
self.trainer.fit_loop._combined_loader.flattened = updated_dls
|
| 295 |
+
|
| 296 |
+
def optimizer_step(self, *args, **kwargs):
|
| 297 |
+
super().optimizer_step(*args, **kwargs)
|
| 298 |
+
if self.ema:
|
| 299 |
+
self.ema.update(itertools.chain(
|
| 300 |
+
self.backbone.parameters(),
|
| 301 |
+
self.noise.parameters()))
|
| 302 |
+
|
| 303 |
+
# subs parameterization from MDLM
|
| 304 |
+
def _subs_parameterization(self, logits, xt):
|
| 305 |
+
logits[:, :, self.mask_index] += self.neg_infinity
|
| 306 |
+
logits = logits - torch.logsumexp(logits, dim=-1, keepdim=True)
|
| 307 |
+
if xt.ndim > 2 and xt.shape[-1] == self.vocab_size:
|
| 308 |
+
# this is for finetuning setting when the input is one-hot encoded or probs
|
| 309 |
+
xt = xt.argmax(dim=-1)
|
| 310 |
+
unmasked_indices = (xt != self.mask_index)
|
| 311 |
+
logits[unmasked_indices] = self.neg_infinity
|
| 312 |
+
logits[unmasked_indices, xt[unmasked_indices]] = 0
|
| 313 |
+
return logits
|
| 314 |
+
|
| 315 |
+
def _process_sigma(self, sigma):
|
| 316 |
+
if sigma is None:
|
| 317 |
+
assert self.parameterization == 'ar'
|
| 318 |
+
return sigma
|
| 319 |
+
if sigma.ndim > 1:
|
| 320 |
+
sigma = sigma.squeeze(-1)
|
| 321 |
+
if not self.time_conditioning:
|
| 322 |
+
sigma = torch.zeros_like(sigma)
|
| 323 |
+
assert sigma.ndim == 1, sigma.shape
|
| 324 |
+
return sigma
|
| 325 |
+
|
| 326 |
+
def forward(self, x, sigma):
|
| 327 |
+
"""Returns log score."""
|
| 328 |
+
sigma = self._process_sigma(sigma)
|
| 329 |
+
|
| 330 |
+
x = x.to(dtype=torch.long)
|
| 331 |
+
|
| 332 |
+
with torch.cuda.amp.autocast(dtype=torch.float32):
|
| 333 |
+
logits = self.backbone(x, sigma)
|
| 334 |
+
|
| 335 |
+
if self.parameterization == 'subs':
|
| 336 |
+
return self._subs_parameterization(logits=logits, xt=x)
|
| 337 |
+
|
| 338 |
+
return logits
|
| 339 |
+
|
| 340 |
+
# might need changing to match wdce loss
|
| 341 |
+
def _compute_loss(self, batch, prefix):
|
| 342 |
+
|
| 343 |
+
if 'attention_mask' in batch:
|
| 344 |
+
attention_mask = batch['attention_mask']
|
| 345 |
+
else:
|
| 346 |
+
attention_mask = None
|
| 347 |
+
losses = self._loss(batch['seqs'], attention_mask)
|
| 348 |
+
loss = losses.loss
|
| 349 |
+
|
| 350 |
+
|
| 351 |
+
if prefix == 'train':
|
| 352 |
+
self.train_metrics.update(losses.nlls, losses.token_mask)
|
| 353 |
+
metrics = self.train_metrics
|
| 354 |
+
elif prefix == 'val':
|
| 355 |
+
self.valid_metrics.update(losses.nlls, losses.token_mask)
|
| 356 |
+
metrics = self.valid_metrics
|
| 357 |
+
elif prefix == 'test':
|
| 358 |
+
self.test_metrics.update(losses.nlls, losses.token_mask)
|
| 359 |
+
metrics = self.test_metrics
|
| 360 |
+
else:
|
| 361 |
+
raise ValueError(f'Invalid prefix: {prefix}')
|
| 362 |
+
|
| 363 |
+
self.log_dict(metrics, on_step=False, on_epoch=True, sync_dist=True)
|
| 364 |
+
|
| 365 |
+
return loss
|
| 366 |
+
|
| 367 |
+
def on_train_epoch_start(self):
|
| 368 |
+
self.backbone.train()
|
| 369 |
+
self.noise.train()
|
| 370 |
+
|
| 371 |
+
def training_step(self, batch, batch_idx):
|
| 372 |
+
loss = self._compute_loss(batch, prefix='train')
|
| 373 |
+
self.log(name='trainer/loss',
|
| 374 |
+
value=loss.item(),
|
| 375 |
+
on_step=True,
|
| 376 |
+
on_epoch=False,
|
| 377 |
+
sync_dist=True)
|
| 378 |
+
return loss
|
| 379 |
+
|
| 380 |
+
def on_validation_epoch_start(self):
|
| 381 |
+
if self.ema:
|
| 382 |
+
self.ema.store(itertools.chain(
|
| 383 |
+
self.backbone.parameters(),
|
| 384 |
+
self.noise.parameters()))
|
| 385 |
+
self.ema.copy_to(itertools.chain(
|
| 386 |
+
self.backbone.parameters(),
|
| 387 |
+
self.noise.parameters()))
|
| 388 |
+
self.backbone.eval()
|
| 389 |
+
self.noise.eval()
|
| 390 |
+
assert self.valid_metrics.nll.mean_value == 0
|
| 391 |
+
assert self.valid_metrics.nll.weight == 0
|
| 392 |
+
|
| 393 |
+
def validation_step(self, batch, batch_idx):
|
| 394 |
+
return self._compute_loss(batch, prefix='val')
|
| 395 |
+
|
| 396 |
+
def on_validation_epoch_end(self):
|
| 397 |
+
if ((self.config.eval.compute_perplexity_on_sanity
|
| 398 |
+
or not self.trainer.sanity_checking)
|
| 399 |
+
and self.config.eval.generate_samples
|
| 400 |
+
and not self.parameterization == 'ar'):
|
| 401 |
+
all_samples, all_detoeknized_samples = [], []
|
| 402 |
+
|
| 403 |
+
for _ in range(self.config.sampling.num_sample_batches):
|
| 404 |
+
|
| 405 |
+
samples = self._sample().detach().cpu().numpy()
|
| 406 |
+
detokenized_samples = dataloader_gosai.batch_dna_detokenize(samples)
|
| 407 |
+
all_samples.append(samples)
|
| 408 |
+
all_detoeknized_samples.extend(detokenized_samples)
|
| 409 |
+
|
| 410 |
+
all_samples = np.concatenate(all_samples, axis=0)
|
| 411 |
+
ws_distance_dict = self.cal_wasserstein_distance(all_detoeknized_samples)
|
| 412 |
+
pearsonr_list = self.cal_kmer_pearsonr(all_detoeknized_samples)
|
| 413 |
+
ws_embpca_list = self.cal_ws_distance_embpca(all_detoeknized_samples)
|
| 414 |
+
|
| 415 |
+
current_step = self.trainer.global_step
|
| 416 |
+
LOGGER.info(f'Current step: {current_step}')
|
| 417 |
+
LOGGER.info(f'Wasserstein distance: {ws_distance_dict}')
|
| 418 |
+
LOGGER.info(f'3mer Pearsonr: {pearsonr_list}')
|
| 419 |
+
LOGGER.info(f'Wasserstein distance embpca: {ws_embpca_list}')
|
| 420 |
+
self.log('val/3mer_pearsonr', pearsonr_list, on_step=False, on_epoch=True, sync_dist=True)
|
| 421 |
+
self.log('val/ws_embpca', ws_embpca_list, on_step=False, on_epoch=True, sync_dist=True)
|
| 422 |
+
|
| 423 |
+
for key in ws_distance_dict:
|
| 424 |
+
for cell_type in ws_distance_dict[key]:
|
| 425 |
+
metric_values = ws_distance_dict[key][cell_type]
|
| 426 |
+
if metric_values: # Check if the list is not empty
|
| 427 |
+
# Assuming metric_values contains [train_metric, valid_metric, test_metric]
|
| 428 |
+
self.log(f'val/{key}_{cell_type}', metric_values[0], on_step=False, on_epoch=True, sync_dist=True)
|
| 429 |
+
|
| 430 |
+
if self.ema:
|
| 431 |
+
self.ema.restore(itertools.chain(self.backbone.parameters(),
|
| 432 |
+
self.noise.parameters()))
|
| 433 |
+
|
| 434 |
+
### VALIDATION METRICS ###
|
| 435 |
+
def cal_wasserstein_distance(self, seqs):
|
| 436 |
+
generated_preds = oracle.cal_gosai_pred_new(seqs)
|
| 437 |
+
ws_distance_dict = {'truth': {'hepg2': [], 'k562': [], 'sknsh': []},
|
| 438 |
+
'preds': {'hepg2': [], 'k562': [], 'sknsh': []}}
|
| 439 |
+
ws_distance_dict['truth']['hepg2'].append(wasserstein_distance(generated_preds[:, 0], self.eval_sets_sp_clss[:, 0]))
|
| 440 |
+
ws_distance_dict['truth']['k562'].append(wasserstein_distance(generated_preds[:, 1], self.eval_sets_sp_clss[:, 1]))
|
| 441 |
+
ws_distance_dict['truth']['sknsh'].append(wasserstein_distance(generated_preds[:, 2], self.eval_sets_sp_clss[:, 2]))
|
| 442 |
+
ws_distance_dict['preds']['hepg2'].append(wasserstein_distance(generated_preds[:, 0], self.eval_sets_sp_preds[:, 0]))
|
| 443 |
+
ws_distance_dict['preds']['k562'].append(wasserstein_distance(generated_preds[:, 1], self.eval_sets_sp_preds[:, 1]))
|
| 444 |
+
ws_distance_dict['preds']['sknsh'].append(wasserstein_distance(generated_preds[:, 2], self.eval_sets_sp_preds[:, 2]))
|
| 445 |
+
return ws_distance_dict
|
| 446 |
+
|
| 447 |
+
def cal_ws_distance_embpca(self, seqs):
|
| 448 |
+
generated_embs = oracle.cal_gosai_emb(seqs)
|
| 449 |
+
generated_embs_pca = self.emb_pca.transform(generated_embs.reshape(generated_embs.shape[0], -1))
|
| 450 |
+
return oracle.get_wasserstein_dist(generated_embs_pca, self.eval_sets_sp_embs_pca)
|
| 451 |
+
|
| 452 |
+
def compare_kmer(self, kmer1, kmer2, n_sp1, n_sp2):
|
| 453 |
+
kmer_set = set(kmer1.keys()) | set(kmer2.keys())
|
| 454 |
+
counts = np.zeros((len(kmer_set), 2))
|
| 455 |
+
for i, kmer in enumerate(kmer_set):
|
| 456 |
+
if kmer in kmer1:
|
| 457 |
+
counts[i][1] = kmer1[kmer] * n_sp2 / n_sp1
|
| 458 |
+
if kmer in kmer2:
|
| 459 |
+
counts[i][0] = kmer2[kmer]
|
| 460 |
+
return pearsonr(counts[:, 0], counts[:, 1])[0]
|
| 461 |
+
|
| 462 |
+
def cal_kmer_pearsonr(self, seqs):
|
| 463 |
+
generated_kmer = oracle.count_kmers(seqs)
|
| 464 |
+
return self.compare_kmer(self.eval_sets_sp_kmers, generated_kmer, self.config.eval.subset_size, len(seqs))
|
| 465 |
+
|
| 466 |
+
def configure_optimizers(self):
|
| 467 |
+
optimizer = torch.optim.AdamW(
|
| 468 |
+
itertools.chain(self.backbone.parameters(),
|
| 469 |
+
self.noise.parameters()),
|
| 470 |
+
lr=self.config.optim.lr,
|
| 471 |
+
betas=(self.config.optim.beta1, self.config.optim.beta2),
|
| 472 |
+
eps=self.config.optim.eps,
|
| 473 |
+
weight_decay=self.config.optim.weight_decay)
|
| 474 |
+
|
| 475 |
+
scheduler = hydra.utils.instantiate(self.config.lr_scheduler, optimizer=optimizer)
|
| 476 |
+
scheduler_dict = {
|
| 477 |
+
'scheduler': scheduler,
|
| 478 |
+
'interval': 'step',
|
| 479 |
+
'monitor': 'val/loss',
|
| 480 |
+
'name': 'trainer/lr',
|
| 481 |
+
}
|
| 482 |
+
return [optimizer], [scheduler_dict]
|
| 483 |
+
|
| 484 |
+
def q_xt(self, x, move_chance):
|
| 485 |
+
"""Computes the noisy sample xt.
|
| 486 |
+
|
| 487 |
+
Args:
|
| 488 |
+
x: int torch.Tensor with shape (batch_size,
|
| 489 |
+
diffusion_model_input_length), input.
|
| 490 |
+
move_chance: float torch.Tensor with shape (batch_size, 1).
|
| 491 |
+
"""
|
| 492 |
+
move_indices = torch.rand(* x.shape, device=x.device) < move_chance
|
| 493 |
+
|
| 494 |
+
xt = torch.where(move_indices, self.mask_index, x)
|
| 495 |
+
return xt
|
| 496 |
+
|
| 497 |
+
def _sample_prior(self, *batch_dims):
|
| 498 |
+
"""
|
| 499 |
+
Returns array of fully masked sequences with same shape as input
|
| 500 |
+
"""
|
| 501 |
+
return self.mask_index * torch.ones(* batch_dims, dtype=torch.int64)
|
| 502 |
+
|
| 503 |
+
def _ddpm_caching_update(self, x, t, dt, p_x0=None):
|
| 504 |
+
assert self.config.noise.type == 'loglinear'
|
| 505 |
+
sigma_t, _ = self.noise(t)
|
| 506 |
+
if t.ndim > 1:
|
| 507 |
+
t = t.squeeze(-1)
|
| 508 |
+
assert t.ndim == 1
|
| 509 |
+
move_chance_t = t[:, None, None]
|
| 510 |
+
move_chance_s = (t - dt)[:, None, None]
|
| 511 |
+
assert move_chance_t.ndim == 3, move_chance_t.shape
|
| 512 |
+
if p_x0 is None:
|
| 513 |
+
p_x0 = self.forward(x, sigma_t).exp()
|
| 514 |
+
|
| 515 |
+
assert move_chance_t.ndim == p_x0.ndim
|
| 516 |
+
q_xs = p_x0 * (move_chance_t - move_chance_s)
|
| 517 |
+
q_xs[:, :, self.mask_index] = move_chance_s[:, :, 0]
|
| 518 |
+
_x = _sample_categorical(q_xs)
|
| 519 |
+
|
| 520 |
+
copy_flag = (x != self.mask_index).to(x.dtype)
|
| 521 |
+
return p_x0, copy_flag * x + (1 - copy_flag) * _x
|
| 522 |
+
|
| 523 |
+
def _ddpm_update(self, x, t, dt, return_process=False):
|
| 524 |
+
sigma_t, _ = self.noise(t)
|
| 525 |
+
sigma_s, _ = self.noise(t - dt)
|
| 526 |
+
if sigma_t.ndim > 1:
|
| 527 |
+
sigma_t = sigma_t.squeeze(-1)
|
| 528 |
+
if sigma_s.ndim > 1:
|
| 529 |
+
sigma_s = sigma_s.squeeze(-1)
|
| 530 |
+
assert sigma_t.ndim == 1, sigma_t.shape
|
| 531 |
+
assert sigma_s.ndim == 1, sigma_s.shape
|
| 532 |
+
move_chance_t = 1 - torch.exp(-sigma_t) # t
|
| 533 |
+
move_chance_s = 1 - torch.exp(-sigma_s)
|
| 534 |
+
move_chance_t = move_chance_t[:, None, None]
|
| 535 |
+
move_chance_s = move_chance_s[:, None, None]
|
| 536 |
+
unet_conditioning = sigma_t
|
| 537 |
+
log_p_x0 = self.forward(x, unet_conditioning)
|
| 538 |
+
assert move_chance_t.ndim == log_p_x0.ndim
|
| 539 |
+
q_xs = log_p_x0.exp() * (move_chance_t - move_chance_s)
|
| 540 |
+
q_xs[:, :, self.mask_index] = move_chance_s[:, :, 0]
|
| 541 |
+
_x = _sample_categorical(q_xs)
|
| 542 |
+
copy_flag = (x != self.mask_index).to(x.dtype)
|
| 543 |
+
if return_process:
|
| 544 |
+
return copy_flag * x + (1 - copy_flag) * _x, x, unet_conditioning, move_chance_t, copy_flag
|
| 545 |
+
else:
|
| 546 |
+
return copy_flag * x + (1 - copy_flag) * _x
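As a sanity check on the update above (a toy calculation, assuming the loglinear schedule where `move_chance ≈ t` and the SUBS parameterization, which puts zero mass on the mask class): for a single masked position the unnormalized kernel sums to `move_chance_t`, so after normalization the position stays masked with probability `s/t` and unmasks into token `v` with probability `p_v * (t - s) / t`.

```python
import torch

# Toy check, not repo code: unnormalized reverse kernel for one masked position.
t, s = 0.8, 0.7
p_x0 = torch.tensor([0.1, 0.2, 0.3, 0.4, 0.0])  # SUBS assigns zero probability to the mask class (last index)
q_xs = p_x0 * (t - s)
q_xs[4] = s                                     # weight for remaining masked
print(q_xs.sum())                               # ~= t (0.8), so normalization divides by move_chance_t
print(q_xs / q_xs.sum())                        # P(stay masked) = s/t, P(unmask to v) = p_v*(t-s)/t
```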
|
| 547 |
+
|
| 548 |
+
def _ar_sampler(self, bsz):
|
| 549 |
+
# precompute token buffer
|
| 550 |
+
num_pred_tokens = self.config.model.length - 1
|
| 551 |
+
x = torch.zeros(
|
| 552 |
+
(bsz, num_pred_tokens + 1),
|
| 553 |
+
dtype=torch.long,
|
| 554 |
+
device=self.device)
|
| 555 |
+
x[:, 0] = self.tokenizer.bos_token_id
|
| 556 |
+
# precompute noise
|
| 557 |
+
noise = (torch.distributions.Gumbel(0, 1)
|
| 558 |
+
.sample((bsz, num_pred_tokens, self.vocab_size))
|
| 559 |
+
.to(self.device))
|
| 560 |
+
for i in range(num_pred_tokens):
|
| 561 |
+
next_logits = self.forward(x[:, :i + 1], None)[:, -1]
|
| 562 |
+
y = (next_logits + noise[:, i]).argmax(-1)
|
| 563 |
+
x[:, i + 1] = y
|
| 564 |
+
return x
|
| 565 |
+
|
| 566 |
+
@torch.no_grad()
|
| 567 |
+
def _sample(self, num_steps=None, eps=1e-5, eval_sp_size=None):
|
| 568 |
+
"""Generate samples from the model."""
|
| 569 |
+
if eval_sp_size is None:
|
| 570 |
+
batch_size_per_gpu = self.config.loader.eval_batch_size
|
| 571 |
+
else:
|
| 572 |
+
batch_size_per_gpu = eval_sp_size
|
| 573 |
+
if self.parameterization == 'ar':
|
| 574 |
+
return self._ar_sampler(batch_size_per_gpu)
|
| 575 |
+
# Lightning auto-casting is not working in this method for some reason
|
| 576 |
+
if num_steps is None:
|
| 577 |
+
num_steps = self.config.sampling.steps
|
| 578 |
+
x = self._sample_prior(
|
| 579 |
+
batch_size_per_gpu,
|
| 580 |
+
self.config.model.length).to(self.device)
|
| 581 |
+
|
| 582 |
+
timesteps = torch.linspace(1, eps, num_steps + 1, device=self.device)
|
| 583 |
+
dt = (1 - eps) / num_steps
|
| 584 |
+
p_x0_cache = None
|
| 585 |
+
|
| 586 |
+
for i in range(num_steps):
|
| 587 |
+
t = timesteps[i] * torch.ones(x.shape[0], 1, device=self.device)
|
| 588 |
+
|
| 589 |
+
if self.sampler == 'ddpm':
|
| 590 |
+
x = self._ddpm_update(x, t, dt)
|
| 591 |
+
elif self.sampler == 'ddpm_cache':
|
| 592 |
+
p_x0_cache, x_next = self._ddpm_caching_update(x, t, dt, p_x0=p_x0_cache)
|
| 593 |
+
if (not torch.allclose(x_next, x) or self.time_conditioning):
|
| 594 |
+
p_x0_cache = None
|
| 595 |
+
x = x_next
|
| 596 |
+
else:
|
| 597 |
+
x = self._analytic_update(x, t, dt)
|
| 598 |
+
|
| 599 |
+
if self.config.sampling.noise_removal:
|
| 600 |
+
t = timesteps[-1] * torch.ones(x.shape[0], 1,
|
| 601 |
+
device=self.device)
|
| 602 |
+
if self.sampler == 'analytic':
|
| 603 |
+
x = self._denoiser_update(x, t)
|
| 604 |
+
else:
|
| 605 |
+
unet_conditioning = self.noise(t)[0]
|
| 606 |
+
logits = self.forward(x, unet_conditioning)
|
| 607 |
+
x = logits[:, :, :-1].argmax(dim=-1)
|
| 608 |
+
return x
|
| 609 |
+
|
| 610 |
+
### FOR THE EXPANSION AND ROLLOUT STEP ###
|
| 611 |
+
def sample_finetuned_with_rnd(self, args, reward_model, pretrained, eps=1e-5):

|
| 612 |
+
num_steps = args.total_num_steps
|
| 613 |
+
x_rollout = self._sample_prior(
|
| 614 |
+
args.batch_size,
|
| 615 |
+
args.seq_length).to(self.device)
|
| 616 |
+
|
| 617 |
+
log_rnd = torch.zeros(args.batch_size, device=self.device)
|
| 618 |
+
|
| 619 |
+
timesteps = torch.linspace(1, eps, num_steps + 1, device=self.device)
|
| 620 |
+
dt = (1 - eps) / num_steps
|
| 621 |
+
|
| 622 |
+
for i in range(num_steps):
|
| 623 |
+
t = timesteps[i] * torch.ones(x_rollout.shape[0], 1, device=self.device)
|
| 624 |
+
|
| 625 |
+
log_p, x_next, log_policy_step, log_pretrained_step = self.mcts_reverse_step(x_rollout, t=t, dt=dt, pretrained=pretrained)
|
| 626 |
+
log_rnd += log_pretrained_step - log_policy_step
|
| 627 |
+
|
| 628 |
+
x_rollout = x_next
|
| 629 |
+
|
| 630 |
+
# if mask token remains, fully unmask
|
| 631 |
+
mask_positions = (x_rollout == self.mask_index) # (B, L) bool
|
| 632 |
+
|
| 633 |
+
# does **any** mask remain in any sequence
|
| 634 |
+
any_mask_global = mask_positions.any().item() # true if mask remains
|
| 635 |
+
if any_mask_global:
|
| 636 |
+
log_p, x_next = self.single_noise_removal(x_rollout, t=t, dt=dt)
|
| 637 |
+
|
| 638 |
+
x_rollout = x_next
|
| 639 |
+
|
| 640 |
+
x_final = x_rollout
|
| 641 |
+
|
| 642 |
+
x_one_hot = to_one_hot(x_final)
|
| 643 |
+
x_one_hot_reward = torch.transpose(x_one_hot, 1, 2)
|
| 644 |
+
reward_preds = reward_model(x_one_hot_reward).squeeze(-1) # (batch_size, 4)
|
| 645 |
+
rewards = reward_preds[:, 0] # (batch_size,)
|
| 646 |
+
log_rnd = log_rnd + rewards / args.alpha
|
| 647 |
+
mean_reward = rewards.mean()
|
| 648 |
+
|
| 649 |
+
return x_final, log_rnd, rewards
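One way the accumulated `log_rnd` returned here could be consumed is as self-normalized importance weights over the rollout batch. The sketch below is hypothetical usage, not part of this file; `model`, `args`, `reward_model`, and `pretrained` are placeholders.

```python
import torch

def importance_weights(log_rnd: torch.Tensor) -> torch.Tensor:
    """Turn per-sequence log-RND terms into weights that sum to 1 (softmax for stability)."""
    return torch.softmax(log_rnd, dim=0)

# Hypothetical usage:
# x_final, log_rnd, rewards = model.sample_finetuned_with_rnd(args, reward_model, pretrained)
# w = importance_weights(log_rnd)                                  # (batch_size,)
# idx = torch.multinomial(w, num_samples=w.shape[0], replacement=True)
# x_resampled = x_final[idx]                                       # reweighted rollout batch
```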
|
| 650 |
+
|
| 651 |
+
def sample_finetuned(self, args, reward_model, eps=1e-5):
|
| 652 |
+
num_steps = args.total_num_steps
|
| 653 |
+
x_rollout = self._sample_prior(
|
| 654 |
+
args.batch_size,
|
| 655 |
+
args.seq_length).to(self.device)
|
| 656 |
+
|
| 657 |
+
timesteps = torch.linspace(1, eps, num_steps + 1, device=self.device)
|
| 658 |
+
dt = (1 - eps) / num_steps
|
| 659 |
+
|
| 660 |
+
for i in range(num_steps):
|
| 661 |
+
t = timesteps[i] * torch.ones(x_rollout.shape[0], 1, device=self.device)
|
| 662 |
+
|
| 663 |
+
log_p, x_next = self.single_reverse_step(x_rollout, t=t, dt=dt)
|
| 664 |
+
|
| 665 |
+
x_rollout = x_next
|
| 666 |
+
|
| 667 |
+
# if mask token remains, fully unmask
|
| 668 |
+
mask_positions = (x_rollout == self.mask_index) # (B, L) bool
|
| 669 |
+
|
| 670 |
+
# does **any** mask remain in any sequence
|
| 671 |
+
any_mask_global = mask_positions.any().item() # true if mask remains
|
| 672 |
+
if any_mask_global:
|
| 673 |
+
log_p, x_next = self.single_noise_removal(x_rollout, t=t, dt=dt)
|
| 674 |
+
|
| 675 |
+
x_rollout = x_next
|
| 676 |
+
|
| 677 |
+
x_final = x_rollout
|
| 678 |
+
|
| 679 |
+
x_one_hot = to_one_hot(x_final)
|
| 680 |
+
x_one_hot_reward = torch.transpose(x_one_hot, 1, 2)
|
| 681 |
+
reward_preds = reward_model(x_one_hot_reward).squeeze(-1) # (batch_size, 4)
|
| 682 |
+
rewards = reward_preds[:, 0] # (batch_size,)
|
| 683 |
+
|
| 684 |
+
mean_reward = rewards.mean()
|
| 685 |
+
|
| 686 |
+
return x_final, mean_reward
|
| 687 |
+
|
| 688 |
+
def compute_log_policy(self, token_array, x_next, t, dt):
|
| 689 |
+
sigma_t, _ = self.noise(t)
|
| 690 |
+
|
| 691 |
+
if token_array.ndim == 1:
|
| 692 |
+
token_array = token_array.unsqueeze(0)
|
| 693 |
+
|
| 694 |
+
if x_next.ndim == 1:
|
| 695 |
+
x_next = x_next.unsqueeze(0)
|
| 696 |
+
|
| 697 |
+
if t.ndim > 1:
|
| 698 |
+
t = t.squeeze(-1)
|
| 699 |
+
assert t.ndim == 1
|
| 700 |
+
|
| 701 |
+
change_prob_t = t[:, None, None]
|
| 702 |
+
change_prob_s = (t - dt)[:, None, None]
|
| 703 |
+
|
| 704 |
+
assert change_prob_t.ndim == 3, change_prob_t.shape
|
| 705 |
+
|
| 706 |
+
log_p = self.forward(token_array, sigma=sigma_t)
|
| 707 |
+
p_x0 = log_p.exp()
|
| 708 |
+
|
| 709 |
+
assert change_prob_t.ndim == p_x0.ndim
|
| 710 |
+
|
| 711 |
+
q_xs = p_x0 * (change_prob_t - change_prob_s)
|
| 712 |
+
|
| 713 |
+
# zero-masking probability
|
| 714 |
+
q_xs[:, :, self.mask_index] = change_prob_s[:, :, 0]
|
| 715 |
+
|
| 716 |
+
copy_flag = (token_array != self.mask_index)
|
| 717 |
+
|
| 718 |
+
assert copy_flag.dtype == torch.bool, "copy_flag must be bool"
|
| 719 |
+
changed_mask = (~copy_flag)
|
| 720 |
+
|
| 721 |
+
# compute the per-sequence log-probability under the pretrained model
|
| 722 |
+
log_policy_token = log_p.gather(-1, x_next.unsqueeze(-1)).squeeze(-1)
|
| 723 |
+
|
| 724 |
+
unmasked_this_step = (changed_mask & (x_next != self.mask_index)).to(log_policy_token.dtype)
|
| 725 |
+
log_policy_step = (log_policy_token * unmasked_this_step).sum(dim=-1)
|
| 726 |
+
|
| 727 |
+
# returns:
|
| 728 |
+
# log_policy_step (B, ): log probability of the x_next tokens under the policy
|
| 729 |
+
if log_policy_step.ndim == 1:
|
| 730 |
+
log_policy_step = log_policy_step.squeeze(0)
|
| 731 |
+
|
| 732 |
+
return log_policy_step
|
| 733 |
+
|
| 734 |
+
|
| 735 |
+
def single_reverse_step(self, token_array, t, dt, p_x0=None):
|
| 736 |
+
assert self.config.noise.type == 'loglinear'
|
| 737 |
+
sigma_t, _ = self.noise(t)
|
| 738 |
+
|
| 739 |
+
if t.ndim > 1:
|
| 740 |
+
t = t.squeeze(-1)
|
| 741 |
+
assert t.ndim == 1
|
| 742 |
+
|
| 743 |
+
change_prob_t = t[:, None, None]
|
| 744 |
+
change_prob_s = (t - dt)[:, None, None]
|
| 745 |
+
|
| 746 |
+
assert change_prob_t.ndim == 3, change_prob_t.shape
|
| 747 |
+
|
| 748 |
+
if p_x0 is None:
|
| 749 |
+
log_p = self.forward(token_array, sigma=sigma_t)
|
| 750 |
+
p_x0 = log_p.exp()
|
| 751 |
+
|
| 752 |
+
assert change_prob_t.ndim == p_x0.ndim
|
| 753 |
+
|
| 754 |
+
q_xs = p_x0 * (change_prob_t - change_prob_s)
|
| 755 |
+
|
| 756 |
+
# zero-masking probability
|
| 757 |
+
q_xs[:, :, self.mask_index] = change_prob_s[:, :, 0]
|
| 758 |
+
|
| 759 |
+
x_changed = _sample_categorical(q_xs)
|
| 760 |
+
|
| 761 |
+
copy_flag = (token_array != self.mask_index)
|
| 762 |
+
|
| 763 |
+
int_copy_flag = copy_flag.to(token_array.dtype)
|
| 764 |
+
x_next = int_copy_flag * token_array + (1 - int_copy_flag) * x_changed
|
| 765 |
+
|
| 766 |
+
# returns:
|
| 767 |
+
# log_p (B, L, D) log probabilities of each token under the policy model
|
| 768 |
+
# x_next (B, L) next sequences
|
| 769 |
+
return log_p, x_next
|
| 770 |
+
|
| 771 |
+
|
| 772 |
+
def single_noise_removal(self, token_array, t, dt, p_x0=None):
|
| 773 |
+
assert self.config.noise.type == 'loglinear'
|
| 774 |
+
sigma_t, _ = self.noise(t)
|
| 775 |
+
|
| 776 |
+
if t.ndim > 1:
|
| 777 |
+
t = t.squeeze(-1)
|
| 778 |
+
assert t.ndim == 1
|
| 779 |
+
|
| 780 |
+
change_prob_t = t[:, None, None]
|
| 781 |
+
change_prob_s = (t - dt)[:, None, None]
|
| 782 |
+
|
| 783 |
+
assert change_prob_t.ndim == 3, change_prob_t.shape
|
| 784 |
+
|
| 785 |
+
if p_x0 is None:
|
| 786 |
+
log_p = self.forward(token_array, sigma=sigma_t)
|
| 787 |
+
p_x0 = log_p.exp()
|
| 788 |
+
|
| 789 |
+
assert change_prob_t.ndim == p_x0.ndim
|
| 790 |
+
|
| 791 |
+
# changed for noise removal
|
| 792 |
+
p_x0 = p_x0.clone()
|
| 793 |
+
p_x0[:, :, self.mask_index] = 0.0 # prevent remaining a mask
|
| 794 |
+
p_x0 = p_x0 / p_x0.sum(dim=-1, keepdim=True).clamp_min(1e-12) # renorm over non-MASK
|
| 795 |
+
q_xs = p_x0 * (change_prob_t - change_prob_s)
|
| 796 |
+
|
| 797 |
+
x_changed = _sample_categorical(q_xs)
|
| 798 |
+
|
| 799 |
+
copy_flag = (token_array != self.mask_index)
|
| 800 |
+
|
| 801 |
+
int_copy_flag = copy_flag.to(token_array.dtype)
|
| 802 |
+
x_next = int_copy_flag * token_array + (1 - int_copy_flag) * x_changed
|
| 803 |
+
|
| 804 |
+
# returns:
|
| 805 |
+
# log_p (B, L, D) log probabilities of each token under the policy model
|
| 806 |
+
# x_next (B, L) next sequences
|
| 807 |
+
return log_p, x_next
|
| 808 |
+
|
| 809 |
+
def mcts_reverse_step(self, token_array, t, dt, pretrained, p_x0=None):
|
| 810 |
+
assert self.config.noise.type == 'loglinear'
|
| 811 |
+
sigma_t, _ = self.noise(t)
|
| 812 |
+
|
| 813 |
+
if t.ndim > 1:
|
| 814 |
+
t = t.squeeze(-1)
|
| 815 |
+
assert t.ndim == 1
|
| 816 |
+
|
| 817 |
+
change_prob_t = t[:, None, None]
|
| 818 |
+
change_prob_s = (t - dt)[:, None, None]
|
| 819 |
+
|
| 820 |
+
assert change_prob_t.ndim == 3, change_prob_t.shape
|
| 821 |
+
|
| 822 |
+
if p_x0 is None:
|
| 823 |
+
log_p = self.forward(token_array, sigma=sigma_t)
|
| 824 |
+
p_x0 = log_p.exp()
|
| 825 |
+
|
| 826 |
+
assert change_prob_t.ndim == p_x0.ndim
|
| 827 |
+
|
| 828 |
+
q_xs = p_x0 * (change_prob_t - change_prob_s)
|
| 829 |
+
|
| 830 |
+
# zero-masking probability
|
| 831 |
+
q_xs[:, :, self.mask_index] = change_prob_s[:, :, 0]
|
| 832 |
+
|
| 833 |
+
x_changed = _sample_categorical(q_xs)
|
| 834 |
+
|
| 835 |
+
copy_flag = (token_array != self.mask_index)
|
| 836 |
+
|
| 837 |
+
int_copy_flag = copy_flag.to(token_array.dtype)
|
| 838 |
+
x_next = int_copy_flag * token_array + (1 - int_copy_flag) * x_changed
|
| 839 |
+
|
| 840 |
+
# compute the log-probability under pretrained model at each step
|
| 841 |
+
with torch.no_grad():
|
| 842 |
+
# pretrained should output log-probs over vocab at each position given the *parent* (masked) input
|
| 843 |
+
log_pre = pretrained.forward(token_array, sigma=sigma_t)
|
| 844 |
+
|
| 845 |
+
# log-prob of the *sampled token* at each position
|
| 846 |
+
log_pre_token = log_pre.gather(-1, x_next.unsqueeze(-1)).squeeze(-1) # [B*batch,L]
|
| 847 |
+
|
| 848 |
+
# sum only over the sites actually sampled this step (i.e., where parent was mask)
|
| 849 |
+
|
| 850 |
+
assert copy_flag.dtype == torch.bool, "copy_flag must be bool"
|
| 851 |
+
changed_mask = (~copy_flag)
|
| 852 |
+
# mask of tokens that were unmasked in this step
|
| 853 |
+
unmasked_this_step = (changed_mask & (x_next != self.mask_index)).to(log_pre_token.dtype)
|
| 854 |
+
|
| 855 |
+
log_pretrained_step = (log_pre_token * unmasked_this_step).sum(dim=-1)
|
| 856 |
+
|
| 857 |
+
# compute the per-sequence log-probability under the pretrained model
|
| 858 |
+
log_policy_token = log_p.gather(-1, x_next.unsqueeze(-1)).squeeze(-1) # [B*batch,L]
|
| 859 |
+
log_policy_step = (log_policy_token * unmasked_this_step).sum(dim=-1)
|
| 860 |
+
|
| 861 |
+
# returns:
|
| 862 |
+
# log_p (B, L, D) log probabilities of each token under the policy model
|
| 863 |
+
# x_next (B, L) next sequences
|
| 864 |
+
# log_policy_step (B, ) log probability of all unmasked tokens under policy
|
| 865 |
+
# log_pretrained_step (B, ) log probability of all unmasked tokens under the pretrained model
|
| 866 |
+
return log_p, x_next, log_policy_step, log_pretrained_step
|
| 867 |
+
|
| 868 |
+
def mcts_noise_removal(self, token_array, t, dt, pretrained, p_x0=None):
|
| 869 |
+
assert self.config.noise.type == 'loglinear'
|
| 870 |
+
sigma_t, _ = self.noise(t)
|
| 871 |
+
|
| 872 |
+
if t.ndim > 1:
|
| 873 |
+
t = t.squeeze(-1)
|
| 874 |
+
assert t.ndim == 1
|
| 875 |
+
|
| 876 |
+
change_prob_t = t[:, None, None]
|
| 877 |
+
change_prob_s = (t - dt)[:, None, None]
|
| 878 |
+
|
| 879 |
+
assert change_prob_t.ndim == 3, change_prob_t.shape
|
| 880 |
+
|
| 881 |
+
if p_x0 is None:
|
| 882 |
+
log_p = self.forward(token_array, sigma=sigma_t)
|
| 883 |
+
p_x0 = log_p.exp()
|
| 884 |
+
|
| 885 |
+
assert change_prob_t.ndim == p_x0.ndim
|
| 886 |
+
|
| 887 |
+
# changed for noise removal
|
| 888 |
+
p_x0 = p_x0.clone()
|
| 889 |
+
p_x0[:, :, self.mask_index] = 0.0 # prevent remaining a mask
|
| 890 |
+
p_x0 = p_x0 / p_x0.sum(dim=-1, keepdim=True).clamp_min(1e-12) # renorm over non-MASK
|
| 891 |
+
q_xs = p_x0 * (change_prob_t - change_prob_s)
|
| 892 |
+
|
| 893 |
+
x_changed = _sample_categorical(q_xs)
|
| 894 |
+
|
| 895 |
+
copy_flag = (token_array != self.mask_index)
|
| 896 |
+
|
| 897 |
+
int_copy_flag = copy_flag.to(token_array.dtype)
|
| 898 |
+
x_next = int_copy_flag * token_array + (1 - int_copy_flag) * x_changed
|
| 899 |
+
|
| 900 |
+
# compute the log-probability under pretrained model at each step
|
| 901 |
+
with torch.no_grad():
|
| 902 |
+
# pretrained should output log-probs over vocab at each position given the *parent* (masked) input
|
| 903 |
+
log_pre = pretrained.forward(token_array, sigma=sigma_t)
|
| 904 |
+
|
| 905 |
+
# log-prob of the *sampled token* at each position
|
| 906 |
+
log_pre_token = log_pre.gather(-1, x_next.unsqueeze(-1)).squeeze(-1) # [B*batch,L]
|
| 907 |
+
|
| 908 |
+
# sum only over the sites actually sampled this step (i.e., where parent was mask)
|
| 909 |
+
|
| 910 |
+
assert copy_flag.dtype == torch.bool, "copy_flag must be bool"
|
| 911 |
+
changed_mask = (~copy_flag)
|
| 912 |
+
# mask of tokens that were unmasked in this step
|
| 913 |
+
unmasked_this_step = (changed_mask & (x_next != self.mask_index)).to(log_pre_token.dtype)
|
| 914 |
+
|
| 915 |
+
log_pretrained_step = (log_pre_token * unmasked_this_step).sum(dim=-1)
|
| 916 |
+
|
| 917 |
+
# compute the per-sequence log-probability under the pretrained model
|
| 918 |
+
log_policy_token = log_p.gather(-1, x_next.unsqueeze(-1)).squeeze(-1) # [B*batch,L]
|
| 919 |
+
log_policy_step = (log_policy_token * unmasked_this_step).sum(dim=-1)
|
| 920 |
+
|
| 921 |
+
# returns:
|
| 922 |
+
# log_p (B, L, D) log probabilities of each token under the policy model
|
| 923 |
+
# x_next (B, L) next sequences
|
| 924 |
+
# log_policy_step (B, ) log probability of all unmasked tokens under policy
|
| 925 |
+
# log_pretrained_step (B, ) log probability of all unmasked tokens under the pretrained model
|
| 926 |
+
return log_p, x_next, log_policy_step, log_pretrained_step
|
| 927 |
+
|
| 928 |
+
# first step in expansion
|
| 929 |
+
def batch_mcts_reverse_step(self, token_array, t, dt, batch_size, pretrained, p_x0=None):
|
| 930 |
+
|
| 931 |
+
assert self.config.noise.type == 'loglinear'
|
| 932 |
+
sigma_t, _ = self.noise(t)
|
| 933 |
+
|
| 934 |
+
if t.ndim > 1:
|
| 935 |
+
t = t.squeeze(-1)
|
| 936 |
+
assert t.ndim == 1
|
| 937 |
+
|
| 938 |
+
change_prob_t = t[:, None, None]
|
| 939 |
+
change_prob_s = (t - dt)[:, None, None]
|
| 940 |
+
|
| 941 |
+
assert change_prob_t.ndim == 3, change_prob_t.shape
|
| 942 |
+
|
| 943 |
+
if token_array.dim() == 1:
|
| 944 |
+
token_array = token_array.unsqueeze(0)
|
| 945 |
+
|
| 946 |
+
# expand to match (num_children, L)
|
| 947 |
+
|
| 948 |
+
if p_x0 is None:
|
| 949 |
+
log_p = self.forward(token_array, sigma=sigma_t)
|
| 950 |
+
p_x0 = log_p.exp()
|
| 951 |
+
|
| 952 |
+
assert change_prob_t.ndim == p_x0.ndim
|
| 953 |
+
|
| 954 |
+
q_xs = p_x0 * (change_prob_t - change_prob_s)
|
| 955 |
+
|
| 956 |
+
# zero-masking probability
|
| 957 |
+
q_xs[:, :, self.mask_index] = change_prob_s[:, :, 0]
|
| 958 |
+
|
| 959 |
+
# repeat the parent token along the first dimension which will be unmasked into distinct sequences
|
| 960 |
+
token_array_expanded = token_array.repeat(batch_size, 1)
|
| 961 |
+
|
| 962 |
+
if self.config.mcts.sampling == 0:
|
| 963 |
+
x_changed = sample_batched_categorical(q_xs.to(self.device), batch_size)
|
| 964 |
+
else:
|
| 965 |
+
x_changed = sample_batched_top_k(q_xs.to(self.device), batch_size, self.config.mcts.sampling)
|
| 966 |
+
|
| 967 |
+
copy_flag = (token_array_expanded != self.mask_index)
|
| 968 |
+
|
| 969 |
+
int_copy_flag = copy_flag.to(token_array.dtype)
|
| 970 |
+
x_children = int_copy_flag * token_array_expanded + (1 - int_copy_flag) * x_changed
|
| 971 |
+
|
| 972 |
+
|
| 973 |
+
# compute the log-probability under pretrained model at each step
|
| 974 |
+
with torch.no_grad():
|
| 975 |
+
# pretrained should output log-probs over vocab at each position given the *parent* (masked) input
|
| 976 |
+
log_pre = pretrained.forward(token_array, sigma=sigma_t)
|
| 977 |
+
|
| 978 |
+
# expand to match the shape of x_children
|
| 979 |
+
log_pre = log_pre.repeat(batch_size, 1, 1)
|
| 980 |
+
|
| 981 |
+
# log-prob of the *sampled token* at each position
|
| 982 |
+
log_pre_token = log_pre.gather(-1, x_children.unsqueeze(-1)).squeeze(-1) # [B*batch,L]
|
| 983 |
+
|
| 984 |
+
# sum only over the sites actually sampled this step (i.e., where parent was mask)
|
| 985 |
+
|
| 986 |
+
assert copy_flag.dtype == torch.bool, "copy_flag must be bool"
|
| 987 |
+
changed_mask = (~copy_flag)
|
| 988 |
+
# mask of tokens that were unmasked in this step
|
| 989 |
+
unmasked_this_step = (changed_mask & (x_children != self.mask_index)).to(log_pre_token.dtype)
|
| 990 |
+
|
| 991 |
+
log_pretrained_step = (log_pre_token * unmasked_this_step).sum(dim=-1)
|
| 992 |
+
|
| 993 |
+
# compute the per-child log-probability under the pretrained model
|
| 994 |
+
log_p = log_p.repeat(batch_size, 1, 1)
|
| 995 |
+
log_policy_token = log_p.gather(-1, x_children.unsqueeze(-1)).squeeze(-1) # (B, L) probability of each chosen token
|
| 996 |
+
#print(log_policy_token)
|
| 997 |
+
log_policy_step = (log_policy_token * unmasked_this_step).sum(dim=-1)
|
| 998 |
+
|
| 999 |
+
# returns:
|
| 1000 |
+
# log_p (B, L, D) log probabilities of each token under the policy model
|
| 1001 |
+
# x_children (B, L) child sequences
|
| 1002 |
+
# log_policy_step (B, ) log probability of all unmasked tokens under policy
|
| 1003 |
+
# log_pretrained_step (B, ) log probabiltiy of all unmasked tokens under pretrained model
|
| 1004 |
+
return log_p, x_children, log_policy_step, log_pretrained_step
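The expansion idea behind this method can be seen in a toy, standalone sketch (made-up shapes, not repo code): the parent is repeated `batch_size` times and only the still-masked positions are resampled independently for each child, using the same Gumbel-style categorical sampler as `_sample_categorical`.

```python
import torch

# Toy expansion sketch: one partially masked parent -> K candidate children.
mask_index, K, L, V = 4, 8, 10, 5
parent = torch.randint(0, 4, (1, L))
parent[0, ::2] = mask_index                        # pretend half the positions are still masked

q_xs = torch.rand(1, L, V)                         # stand-in for p_x0 * (t - s) with mask weight at index 4
parent_rep = parent.repeat(K, 1)                   # (K, L)
q_rep = q_xs.repeat(K, 1, 1)                       # every child shares the same per-position distribution

gumbel_norm = 1e-10 - (torch.rand_like(q_rep) + 1e-10).log()
draws = (q_rep / gumbel_norm).argmax(dim=-1)       # independent categorical draw per child and position

copy = (parent_rep != mask_index).long()
children = copy * parent_rep + (1 - copy) * draws  # masked sites are filled in, fixed sites are copied
```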
|
| 1005 |
+
|
| 1006 |
+
### SPECIFIC TO DRAKES? ###
|
| 1007 |
+
def _ddpm_update_finetune_gradient(self, x, t, dt, copy_flag_temp, return_process=False):
|
| 1008 |
+
|
| 1009 |
+
if x.ndim == 2 or x.shape[-1] != self.vocab_size:
|
| 1010 |
+
x = F.one_hot(x, num_classes=self.vocab_size).to(torch.float32)
|
| 1011 |
+
|
| 1012 |
+
sigma_t, _ = self.noise(t)
|
| 1013 |
+
sigma_s, _ = self.noise(t - dt)
|
| 1014 |
+
if sigma_t.ndim > 1:
|
| 1015 |
+
sigma_t = sigma_t.squeeze(-1)
|
| 1016 |
+
if sigma_s.ndim > 1:
|
| 1017 |
+
sigma_s = sigma_s.squeeze(-1)
|
| 1018 |
+
assert sigma_t.ndim == 1, sigma_t.shape
|
| 1019 |
+
assert sigma_s.ndim == 1, sigma_s.shape
|
| 1020 |
+
move_chance_t = 1 - torch.exp(-sigma_t) # (1-eps)*t
|
| 1021 |
+
move_chance_s = 1 - torch.exp(-sigma_s)
|
| 1022 |
+
move_chance_t = move_chance_t[:, None, None]
|
| 1023 |
+
move_chance_s = move_chance_s[:, None, None]
|
| 1024 |
+
unet_conditioning = sigma_t
|
| 1025 |
+
log_p_x0 = self.forward(x, unet_conditioning)
|
| 1026 |
+
assert move_chance_t.ndim == log_p_x0.ndim
|
| 1027 |
+
q_xs = log_p_x0.exp() * (move_chance_t - move_chance_s)
|
| 1028 |
+
|
| 1029 |
+
q_xs[:, :, self.mask_index] = move_chance_s[:, :, 0]
|
| 1030 |
+
_x = _sample_categorical_gradient(q_xs, temp=self.config.finetuning.gumbel_softmax_temp)
|
| 1031 |
+
|
| 1032 |
+
if copy_flag_temp is not None:
|
| 1033 |
+
copy_flag_prob = 1 - x[:, :, self.mask_index].unsqueeze(-1)
|
| 1034 |
+
soft_copy_flag = torch.nn.functional.sigmoid(copy_flag_prob/copy_flag_temp)
|
| 1035 |
+
else:
|
| 1036 |
+
soft_copy_flag = 1 - x[:, :, self.mask_index].unsqueeze(-1)
|
| 1037 |
+
|
| 1038 |
+
if return_process:
|
| 1039 |
+
return soft_copy_flag * x + (1 - soft_copy_flag) * _x, x, unet_conditioning, move_chance_t, soft_copy_flag
|
| 1040 |
+
else:
|
| 1041 |
+
return soft_copy_flag * x + (1 - soft_copy_flag) * _x
|
| 1042 |
+
|
| 1043 |
+
|
| 1044 |
+
def _sample_finetune_gradient(self, num_steps=None, eps=1e-5, eval_sp_size=None, copy_flag_temp=None):
|
| 1045 |
+
"""Generate samples from the model."""
|
| 1046 |
+
assert self.parameterization == 'subs' and self.sampler == 'ddpm'
|
| 1047 |
+
if eval_sp_size is None:
|
| 1048 |
+
batch_size_per_gpu = self.config.loader.eval_batch_size
|
| 1049 |
+
else:
|
| 1050 |
+
batch_size_per_gpu = eval_sp_size
|
| 1051 |
+
if num_steps is None:
|
| 1052 |
+
num_steps = self.config.sampling.steps
|
| 1053 |
+
x = self._sample_prior(
|
| 1054 |
+
batch_size_per_gpu,
|
| 1055 |
+
self.config.model.length).to(self.device)
|
| 1056 |
+
timesteps = torch.linspace(1, eps, num_steps + 1, device=self.device)
|
| 1057 |
+
dt = (1 - eps) / num_steps
|
| 1058 |
+
p_x0_cache = None
|
| 1059 |
+
|
| 1060 |
+
last_x_list = []
|
| 1061 |
+
condt_list = []
|
| 1062 |
+
move_chance_t_list = []
|
| 1063 |
+
copy_flag_list = []
|
| 1064 |
+
|
| 1065 |
+
for i in range(num_steps):
|
| 1066 |
+
t = timesteps[i] * torch.ones(x.shape[0], 1, device=self.device)
|
| 1067 |
+
if self.sampler == 'ddpm':
|
| 1068 |
+
if i < num_steps - self.config.finetuning.truncate_steps:
|
| 1069 |
+
x, last_x, condt, move_chance_t, copy_flag = self._ddpm_update(x, t, dt, return_process=True)
|
| 1070 |
+
x = x.detach()
|
| 1071 |
+
copy_flag = copy_flag.unsqueeze(-1)
|
| 1072 |
+
last_x = F.one_hot(last_x, num_classes=self.vocab_size).to(torch.float32).detach()
|
| 1073 |
+
else:
|
| 1074 |
+
x, last_x, condt, move_chance_t, copy_flag = self._ddpm_update_finetune_gradient(x, t, dt, copy_flag_temp, return_process=True)
|
| 1075 |
+
|
| 1076 |
+
last_x_list.append(last_x)
|
| 1077 |
+
condt_list.append(condt)
|
| 1078 |
+
move_chance_t_list.append(move_chance_t)
|
| 1079 |
+
copy_flag_list.append(copy_flag)
|
| 1080 |
+
|
| 1081 |
+
x_argmax = x[:, :, :-1].argmax(dim=-1)
|
| 1082 |
+
x_argmax = torch.nn.functional.one_hot(x_argmax, num_classes=self.vocab_size-1).to(torch.float32)
|
| 1083 |
+
return x[:, :, :-1] + (x_argmax - x[:, :, :-1]).detach(), last_x_list, condt_list, move_chance_t_list, copy_flag_list
|
| 1084 |
+
|
| 1085 |
+
@torch.no_grad()
|
| 1086 |
+
def _ddpm_update_finetune_controlled_SMC(self, x, t, dt, reward_model, alpha = 1.0):
|
| 1087 |
+
|
| 1088 |
+
sigma_t, _ = self.noise(t)
|
| 1089 |
+
sigma_s, _ = self.noise(t - dt)
|
| 1090 |
+
if sigma_t.ndim > 1:
|
| 1091 |
+
sigma_t = sigma_t.squeeze(-1)
|
| 1092 |
+
if sigma_s.ndim > 1:
|
| 1093 |
+
sigma_s = sigma_s.squeeze(-1)
|
| 1094 |
+
assert sigma_t.ndim == 1, sigma_t.shape
|
| 1095 |
+
assert sigma_s.ndim == 1, sigma_s.shape
|
| 1096 |
+
move_chance_t = 1 - torch.exp(-sigma_t)
|
| 1097 |
+
move_chance_s = 1 - torch.exp(-sigma_s)
|
| 1098 |
+
move_chance_t = move_chance_t[:, None, None]
|
| 1099 |
+
move_chance_s = move_chance_s[:, None, None]
|
| 1100 |
+
unet_conditioning = sigma_t
|
| 1101 |
+
log_p_x0 = self.forward(x, unet_conditioning)
|
| 1102 |
+
assert move_chance_t.ndim == log_p_x0.ndim
|
| 1103 |
+
q_xs = log_p_x0.exp() * (move_chance_t - move_chance_s)
|
| 1104 |
+
q_xs[:, :, self.mask_index] = move_chance_s[:, :, 0]
|
| 1105 |
+
copy_flag = (x != self.mask_index).to(x.dtype)
|
| 1106 |
+
sample = copy_flag * x + (1 - copy_flag) * _sample_categorical(q_xs)
|
| 1107 |
+
'''
|
| 1108 |
+
Calculate exp(v_{t-1}(x_{t-1})/alpha)
|
| 1109 |
+
'''
|
| 1110 |
+
expected_x0 = self.forward(sample, sigma_s) # Calculate E[x_0|x_{t-1}]
|
| 1111 |
+
expected_x0_arg = torch.argmax(expected_x0,dim=2)
|
| 1112 |
+
expected_x0_onehot = torch.nn.functional.one_hot(expected_x0_arg)
|
| 1113 |
+
reward_num = reward_model(expected_x0_onehot.float().transpose(1, 2)).detach()[:, 0][:, 0]
|
| 1114 |
+
'''
|
| 1115 |
+
Calculate exp(v_{t}(x_{t})/alpha)
|
| 1116 |
+
'''
|
| 1117 |
+
expected_x0 = self.forward(x, sigma_s) # Calculate E[x_0|x_t]
|
| 1118 |
+
expected_x0_arg = torch.argmax(expected_x0,dim=2)
|
| 1119 |
+
expected_x0_onehot = torch.nn.functional.one_hot(expected_x0_arg)
|
| 1120 |
+
reward_den = reward_model(expected_x0_onehot.float().transpose(1, 2)).detach()[:, 0][:, 0]
|
| 1121 |
+
|
| 1122 |
+
ratio = torch.exp(1.0/alpha * (reward_num - reward_den)) # Now calculate exp((v_{t-1}(x_{t-1}) - v_{t}(x_{t})) / alpha)
|
| 1123 |
+
ratio = ratio.detach().cpu().numpy()
|
| 1124 |
+
final_sample_indices = np.random.choice(reward_num.shape[0], reward_num.shape[0], p = ratio/ratio.sum() )
|
| 1125 |
+
|
| 1126 |
+
return sample[final_sample_indices]
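The resampling step above redraws particles in proportion to exp((v_{t-1}(x_{t-1}) - v_{t}(x_{t})) / alpha), with v approximated by the reward of the predicted clean sequence. Below is a standalone numerical sketch of just the reweighting (made-up rewards, not repo code).

```python
import numpy as np

# Toy SMC resampling: particles are redrawn with probability proportional to
# exp((reward_new - reward_old) / alpha), then self-normalized.
rng = np.random.default_rng(0)
reward_num = np.array([1.2, 0.4, 2.0, 0.9])  # v_{t-1}(x_{t-1}) for 4 particles
reward_den = np.array([1.0, 0.8, 1.5, 0.9])  # v_t(x_t)
alpha = 0.5

ratio = np.exp((reward_num - reward_den) / alpha)
weights = ratio / ratio.sum()
idx = rng.choice(len(weights), size=len(weights), p=weights)
# sample[idx] would then replace the particle set, as in the method above.
```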
|
| 1127 |
+
|
| 1128 |
+
def _ddpm_update_finetune_controlled_CG(self, x, t, dt, reward_model, guidance_scale):
|
| 1129 |
+
|
| 1130 |
+
sigma_t, _ = self.noise(t)
|
| 1131 |
+
sigma_s, _ = self.noise(t - dt)
|
| 1132 |
+
if sigma_t.ndim > 1:
|
| 1133 |
+
sigma_t = sigma_t.squeeze(-1)
|
| 1134 |
+
if sigma_s.ndim > 1:
|
| 1135 |
+
sigma_s = sigma_s.squeeze(-1)
|
| 1136 |
+
assert sigma_t.ndim == 1, sigma_t.shape
|
| 1137 |
+
assert sigma_s.ndim == 1, sigma_s.shape
|
| 1138 |
+
move_chance_t = 1 - torch.exp(-sigma_t)
|
| 1139 |
+
move_chance_s = 1 - torch.exp(-sigma_s)
|
| 1140 |
+
move_chance_t = move_chance_t[:, None, None]
|
| 1141 |
+
move_chance_s = move_chance_s[:, None, None]
|
| 1142 |
+
unet_conditioning = sigma_t
|
| 1143 |
+
log_p_x0 = self.forward(x, unet_conditioning)
|
| 1144 |
+
assert move_chance_t.ndim == log_p_x0.ndim
|
| 1145 |
+
q_xs = log_p_x0.exp() * (move_chance_t - move_chance_s)
|
| 1146 |
+
x_onehot = F.one_hot(x, num_classes=5).float()
|
| 1147 |
+
|
| 1148 |
+
x_grad = self.compute_gradient_CG(x_onehot, x, reward_model, sigma_s )
|
| 1149 |
+
guidance = guidance_scale * (x_grad - x_grad[:, :, self.mask_index][:, :, None])
|
| 1150 |
+
q_xs[:, :, self.mask_index] = move_chance_s[:, :, 0]
|
| 1151 |
+
q_xs = q_xs * guidance.exp()
|
| 1152 |
+
|
| 1153 |
+
_x = _sample_categorical(q_xs)
|
| 1154 |
+
copy_flag = (x != self.mask_index).to(x.dtype)
|
| 1155 |
+
return copy_flag * x + (1 - copy_flag) * _x
|
| 1156 |
+
|
| 1157 |
+
def compute_gradient_CG(self, x_onehot, x, reward_model, sigma_s):
|
| 1158 |
+
x_onehot.requires_grad_(True)
|
| 1159 |
+
expected_x0 = self.forward(x_onehot, sigma_s) # Calculate E[x_0|x_t]
|
| 1160 |
+
scores = reward_model(expected_x0.transpose(1, 2)[:,0:4,:])[:, 0]
|
| 1161 |
+
scores = scores.mean()
|
| 1162 |
+
scores.backward()
|
| 1163 |
+
x_grad = x_onehot.grad.clone()
|
| 1164 |
+
return x_grad
|
| 1165 |
+
|
| 1166 |
+
def _ddpm_update_finetune_controlled_TDS(self, x, t, dt, reward_model, alpha = 1.0, guidance_scale=1000):
|
| 1167 |
+
# SMC with the twisted proposal
|
| 1168 |
+
|
| 1169 |
+
sigma_t, _ = self.noise(t)
|
| 1170 |
+
sigma_s, _ = self.noise(t - dt)
|
| 1171 |
+
if sigma_t.ndim > 1:
|
| 1172 |
+
sigma_t = sigma_t.squeeze(-1)
|
| 1173 |
+
if sigma_s.ndim > 1:
|
| 1174 |
+
sigma_s = sigma_s.squeeze(-1)
|
| 1175 |
+
assert sigma_t.ndim == 1, sigma_t.shape
|
| 1176 |
+
assert sigma_s.ndim == 1, sigma_s.shape
|
| 1177 |
+
move_chance_t = 1 - torch.exp(-sigma_t)
|
| 1178 |
+
move_chance_s = 1 - torch.exp(-sigma_s)
|
| 1179 |
+
move_chance_t = move_chance_t[:, None, None]
|
| 1180 |
+
move_chance_s = move_chance_s[:, None, None]
|
| 1181 |
+
unet_conditioning = sigma_t
|
| 1182 |
+
log_p_x0 = self.forward(x, unet_conditioning)
|
| 1183 |
+
assert move_chance_t.ndim == log_p_x0.ndim
|
| 1184 |
+
q_xs = log_p_x0.exp() * (move_chance_t
|
| 1185 |
+
- move_chance_s)
|
| 1186 |
+
x_onehot = F.one_hot(x, num_classes=5).float()
|
| 1187 |
+
|
| 1188 |
+
x_grad = self.compute_gradient_CG(x_onehot, x, reward_model, sigma_s )
|
| 1189 |
+
guidance = guidance_scale * (x_grad - x_grad[:, :, self.mask_index][:, :, None])
|
| 1190 |
+
q_xs[:, :, self.mask_index] = move_chance_s[:, :, 0]
|
| 1191 |
+
# print(q_xs.sum(-1))
|
| 1192 |
+
q_xs = q_xs * guidance.exp()
|
| 1193 |
+
|
| 1194 |
+
_x = _sample_categorical(q_xs)
|
| 1195 |
+
copy_flag = (x != self.mask_index).to(x.dtype)
|
| 1196 |
+
sample = copy_flag * x + (1 - copy_flag) * _x
|
| 1197 |
+
prob_multiplier = (1 - copy_flag) * torch.gather(guidance.exp(), 2, _x.unsqueeze(-1)).squeeze(-1) + copy_flag * torch.ones_like(_x)
|
| 1198 |
+
'''
|
| 1199 |
+
Calculate exp(v_{t-1}(x_{t-1})/alpha)
|
| 1200 |
+
'''
|
| 1201 |
+
expected_x0 = self.forward(sample, sigma_s) # Calculate E[x_0|x_{t-1}]
|
| 1202 |
+
expected_x0_arg = torch.argmax(expected_x0,dim=2)
|
| 1203 |
+
expected_x0_onehot = torch.nn.functional.one_hot(expected_x0_arg)
|
| 1204 |
+
reward_num = reward_model(expected_x0_onehot.float().transpose(1, 2)).detach()[:, 0][:, 0]
|
| 1205 |
+
'''
|
| 1206 |
+
Calculate exp(v_{t}(x_{t})/alpha)
|
| 1207 |
+
'''
|
| 1208 |
+
expected_x0 = self.forward(x, sigma_s) # Calculate E[x_0|x_t]
|
| 1209 |
+
expected_x0_arg = torch.argmax(expected_x0,dim=2)
|
| 1210 |
+
expected_x0_onehot = torch.nn.functional.one_hot(expected_x0_arg)
|
| 1211 |
+
reward_den = reward_model(expected_x0_onehot.float().transpose(1, 2)).detach()[:, 0][:, 0]
|
| 1212 |
+
|
| 1213 |
+
# set the nan values to 1
|
| 1214 |
+
prob_multiplier[torch.isnan(prob_multiplier)] = 1
|
| 1215 |
+
ratio = torch.exp(1.0/alpha * (reward_num - reward_den)) / prob_multiplier.prod(dim=-1)
|
| 1216 |
+
ratio = ratio.detach().cpu().numpy()
|
| 1217 |
+
final_sample_indices = np.random.choice(reward_num.shape[0], reward_num.shape[0], p = ratio/ratio.sum() )
|
| 1218 |
+
|
| 1219 |
+
return sample[final_sample_indices]
|
| 1220 |
+
|
| 1221 |
+
@torch.no_grad()
|
| 1222 |
+
def controlled_sample_SMC(self, reward_model, alpha, num_steps=None, eps=1e-5, eval_sp_size=None):
|
| 1223 |
+
"""Generate samples from the model."""
|
| 1224 |
+
if eval_sp_size is None:
|
| 1225 |
+
batch_size_per_gpu = self.config.loader.eval_batch_size
|
| 1226 |
+
else:
|
| 1227 |
+
batch_size_per_gpu = eval_sp_size
|
| 1228 |
+
if self.parameterization == 'ar':
|
| 1229 |
+
return self._ar_sampler(batch_size_per_gpu)
|
| 1230 |
+
# Lightning auto-casting is not working in this method for some reason
|
| 1231 |
+
if num_steps is None:
|
| 1232 |
+
num_steps = self.config.sampling.steps
|
| 1233 |
+
x = self._sample_prior(
|
| 1234 |
+
batch_size_per_gpu,
|
| 1235 |
+
self.config.model.length).to(self.device)
|
| 1236 |
+
timesteps = torch.linspace(1, eps, num_steps + 1, device=self.device)
|
| 1237 |
+
dt = (1 - eps) / num_steps
|
| 1238 |
+
p_x0_cache = None
|
| 1239 |
+
|
| 1240 |
+
for i in range(num_steps):
|
| 1241 |
+
t = timesteps[i] * torch.ones(
|
| 1242 |
+
x.shape[0], 1, device=self.device)
|
| 1243 |
+
if self.sampler == 'ddpm':
|
| 1244 |
+
x = self._ddpm_update_finetune_controlled_SMC(x, t, dt, reward_model, alpha)
|
| 1245 |
+
else:
|
| 1246 |
+
x = self._analytic_update(x, t, dt)
|
| 1247 |
+
|
| 1248 |
+
if self.config.sampling.noise_removal:
|
| 1249 |
+
t = timesteps[-1] * torch.ones(x.shape[0], 1, device=self.device)
|
| 1250 |
+
if self.sampler == 'analytic':
|
| 1251 |
+
x = self._denoiser_update(x, t)
|
| 1252 |
+
else:
|
| 1253 |
+
unet_conditioning = self.noise(t)[0]
|
| 1254 |
+
logits = self.forward(x, unet_conditioning)
|
| 1255 |
+
x = logits[:, :, :-1].argmax(dim=-1)
|
| 1256 |
+
return x
|
| 1257 |
+
|
| 1258 |
+
def controlled_sample_CG(self, reward_model, guidance_scale, num_steps=None, eps=1e-5, eval_sp_size=None):
|
| 1259 |
+
"""Generate samples from the model."""
|
| 1260 |
+
if eval_sp_size is None:
|
| 1261 |
+
batch_size_per_gpu = self.config.loader.eval_batch_size
|
| 1262 |
+
else:
|
| 1263 |
+
batch_size_per_gpu = eval_sp_size
|
| 1264 |
+
if self.parameterization == 'ar':
|
| 1265 |
+
return self._ar_sampler(batch_size_per_gpu)
|
| 1266 |
+
# Lightning auto-casting is not working in this method for some reason
|
| 1267 |
+
if num_steps is None:
|
| 1268 |
+
num_steps = self.config.sampling.steps
|
| 1269 |
+
x = self._sample_prior(
|
| 1270 |
+
batch_size_per_gpu,
|
| 1271 |
+
self.config.model.length).to(self.device)
|
| 1272 |
+
timesteps = torch.linspace(
|
| 1273 |
+
1, eps, num_steps + 1, device=self.device)
|
| 1274 |
+
dt = (1 - eps) / num_steps
|
| 1275 |
+
p_x0_cache = None
|
| 1276 |
+
|
| 1277 |
+
for i in range(num_steps):
|
| 1278 |
+
t = timesteps[i] * torch.ones(
|
| 1279 |
+
x.shape[0], 1, device=self.device)
|
| 1280 |
+
if self.sampler == 'ddpm':
|
| 1281 |
+
x = self._ddpm_update_finetune_controlled_CG(x, t, dt, reward_model, guidance_scale)
|
| 1282 |
+
else:
|
| 1283 |
+
x = self._analytic_update(x, t, dt)
|
| 1284 |
+
|
| 1285 |
+
if self.config.sampling.noise_removal:
|
| 1286 |
+
t = timesteps[-1] * torch.ones(x.shape[0], 1,
|
| 1287 |
+
device=self.device)
|
| 1288 |
+
if self.sampler == 'analytic':
|
| 1289 |
+
x = self._denoiser_update(x, t)
|
| 1290 |
+
else:
|
| 1291 |
+
unet_conditioning = self.noise(t)[0]
|
| 1292 |
+
logits = self.forward(x, unet_conditioning)
|
| 1293 |
+
x = logits[:, :, :-1].argmax(dim=-1)
|
| 1294 |
+
return x
|
| 1295 |
+
|
| 1296 |
+
def controlled_sample_TDS(self, reward_model, alpha, guidance_scale, num_steps=None, eps=1e-5, eval_sp_size=None):
|
| 1297 |
+
"""Generate samples from the model."""
|
| 1298 |
+
if eval_sp_size is None:
|
| 1299 |
+
batch_size_per_gpu = self.config.loader.eval_batch_size
|
| 1300 |
+
else:
|
| 1301 |
+
batch_size_per_gpu = eval_sp_size
|
| 1302 |
+
|
| 1303 |
+
if self.parameterization == 'ar':
|
| 1304 |
+
return self._ar_sampler(batch_size_per_gpu)
|
| 1305 |
+
|
| 1306 |
+
if num_steps is None:
|
| 1307 |
+
num_steps = self.config.sampling.steps
|
| 1308 |
+
x = self._sample_prior(
|
| 1309 |
+
batch_size_per_gpu,
|
| 1310 |
+
self.config.model.length).to(self.device)
|
| 1311 |
+
timesteps = torch.linspace(
|
| 1312 |
+
1, eps, num_steps + 1, device=self.device)
|
| 1313 |
+
dt = (1 - eps) / num_steps
|
| 1314 |
+
p_x0_cache = None
|
| 1315 |
+
|
| 1316 |
+
for i in range(num_steps):
|
| 1317 |
+
t = timesteps[i] * torch.ones(
|
| 1318 |
+
x.shape[0], 1, device=self.device)
|
| 1319 |
+
if self.sampler == 'ddpm':
|
| 1320 |
+
x = self._ddpm_update_finetune_controlled_TDS(x, t, dt, reward_model,alpha, guidance_scale)
|
| 1321 |
+
else:
|
| 1322 |
+
x = self._analytic_update(x, t, dt)
|
| 1323 |
+
|
| 1324 |
+
if self.config.sampling.noise_removal:
|
| 1325 |
+
t = timesteps[-1] * torch.ones(x.shape[0], 1,
|
| 1326 |
+
device=self.device)
|
| 1327 |
+
if self.sampler == 'analytic':
|
| 1328 |
+
x = self._denoiser_update(x, t)
|
| 1329 |
+
else:
|
| 1330 |
+
unet_conditioning = self.noise(t)[0]
|
| 1331 |
+
logits = self.forward(x, unet_conditioning)
|
| 1332 |
+
x = logits[:, :, :-1].argmax(dim=-1)
|
| 1333 |
+
return x
|
| 1334 |
+
|
| 1335 |
+
@torch.no_grad()
|
| 1336 |
+
def get_likelihood(self, x0, num_steps=None, eps=1e-5, n_samples=1):
|
| 1337 |
+
"""Compute the likelihood of a sequence under the model.
|
| 1338 |
+
x0: int torch.Tensor with shape (batch_size,
|
| 1339 |
+
diffusion_model_input_length)
|
| 1340 |
+
"""
|
| 1341 |
+
if num_steps is None:
|
| 1342 |
+
num_steps = self.config.sampling.steps
|
| 1343 |
+
timesteps = torch.linspace(
|
| 1344 |
+
1, eps, num_steps + 1, device=self.device) # t=0 is clean data
|
| 1345 |
+
dt = (1 - eps) / num_steps
|
| 1346 |
+
log_p_sample_list = []
|
| 1347 |
+
for _ in range(n_samples):
|
| 1348 |
+
log_p_at_time_list = []
|
| 1349 |
+
for i in range(num_steps):
|
| 1350 |
+
t = timesteps[i] * torch.ones(
|
| 1351 |
+
x0.shape[0], 1, device=self.device)
|
| 1352 |
+
sigma_t, _ = self.noise(t)
|
| 1353 |
+
sigma_s, _ = self.noise(t - dt)
|
| 1354 |
+
if sigma_t.ndim > 1:
|
| 1355 |
+
sigma_t = sigma_t.squeeze(-1)
|
| 1356 |
+
if sigma_s.ndim > 1:
|
| 1357 |
+
sigma_s = sigma_s.squeeze(-1)
|
| 1358 |
+
assert sigma_t.ndim == 1, sigma_t.shape
|
| 1359 |
+
assert sigma_s.ndim == 1, sigma_s.shape
|
| 1360 |
+
move_chance_t = 1 - torch.exp(-sigma_t) # (1-eps)*t
|
| 1361 |
+
move_chance_s = 1 - torch.exp(-sigma_s)
|
| 1362 |
+
move_chance_t = move_chance_t[:, None] # [bsz, 1]
|
| 1363 |
+
move_chance_s = move_chance_s[:, None]
|
| 1364 |
+
unet_conditioning = sigma_t # [bsz]
|
| 1365 |
+
multiplier = (move_chance_t - move_chance_s)/move_chance_t # [bsz, 1]
|
| 1366 |
+
xt = self.q_xt(x0, move_chance_t) # [bsz, seq_len]
|
| 1367 |
+
# log prob, already apply subs parametrization (unmasked token remains unchanged)
|
| 1368 |
+
model_output = self.forward(xt, unet_conditioning) # [bsz, seq_len, vocab_size]
|
| 1369 |
+
# take the log prob of the token that corresponds to x0
|
| 1370 |
+
log_p_x0 = model_output.gather(-1, x0[..., None]).squeeze(-1) # [bsz, seq_len]
|
| 1371 |
+
log_p_x0 = log_p_x0 * multiplier
|
| 1372 |
+
log_p_at_time_list.append(log_p_x0)
|
| 1373 |
+
log_p_x0 = torch.stack(log_p_at_time_list, dim=0).sum(dim=0) # [bsz, seq_len]
|
| 1374 |
+
log_p_sample_list.append(log_p_x0.sum(dim=-1))
|
| 1375 |
+
log_p_sample = torch.stack(log_p_sample_list, dim=0).mean(dim=0)
|
| 1376 |
+
return log_p_sample
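A brief usage sketch (hypothetical; `model` stands for a loaded `Diffusion` checkpoint): `get_likelihood` expects integer sequences of length `config.model.length` and returns one Monte-Carlo estimate per sequence, averaged over `n_samples` noise draws.

```python
import torch

# Hypothetical usage; x0 here is random and only illustrates the expected shapes.
with torch.no_grad():
    x0 = torch.randint(0, 4, (16, model.config.model.length), device=model.device)
    log_p = model.get_likelihood(x0, num_steps=128, n_samples=4)  # shape (16,), higher = more likely
```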
|
| 1377 |
+
|
| 1378 |
+
def get_score(self, x, sigma):
|
| 1379 |
+
model_output = self.forward(x, sigma)
|
| 1380 |
+
if self.parameterization == 'subs':
|
| 1381 |
+
# score(x, t) = p_t(y) / p_t(x)
|
| 1382 |
+
# => log score(x, t) = log p_t(y) - log p_t(x)
|
| 1383 |
+
|
| 1384 |
+
# case 1: x = masked
|
| 1385 |
+
# (i) y = unmasked
|
| 1386 |
+
# log score(x, t) = log p_\theta(x)|_y + log k
|
| 1387 |
+
# where k = exp(- sigma) / (1 - exp(- sigma))
|
| 1388 |
+
# (ii) y = masked
|
| 1389 |
+
# log score(x, t) = 0
|
| 1390 |
+
|
| 1391 |
+
# case 2: x = unmasked
|
| 1392 |
+
# (i) y != masked, y != x
|
| 1393 |
+
# log score(x_i, t) = - inf
|
| 1394 |
+
# (ii) y = x
|
| 1395 |
+
# log score(x_i, t) = 0
|
| 1396 |
+
# (iii) y = masked token
|
| 1397 |
+
# log score(x_i, t) = - log k
|
| 1398 |
+
# where k = exp(- sigma) / (1 - exp(- sigma))
|
| 1399 |
+
|
| 1400 |
+
log_k = - torch.log(torch.expm1(sigma)).squeeze(-1)
|
| 1401 |
+
assert log_k.ndim == 1
|
| 1402 |
+
|
| 1403 |
+
masked_score = model_output + log_k[:, None, None]
|
| 1404 |
+
masked_score[:, :, self.mask_index] = 0
|
| 1405 |
+
|
| 1406 |
+
unmasked_score = self.neg_infinity * torch.ones_like(
|
| 1407 |
+
model_output)
|
| 1408 |
+
unmasked_score = torch.scatter(
|
| 1409 |
+
unmasked_score,
|
| 1410 |
+
-1,
|
| 1411 |
+
x[..., None],
|
| 1412 |
+
torch.zeros_like(unmasked_score[..., :1]))
|
| 1413 |
+
unmasked_score[:, :, self.mask_index] = - (
|
| 1414 |
+
log_k[:, None] * torch.ones_like(x))
|
| 1415 |
+
|
| 1416 |
+
masked_indices = (x == self.mask_index).to(
|
| 1417 |
+
model_output.dtype)[:, :, None]
|
| 1418 |
+
model_output = (
|
| 1419 |
+
masked_score * masked_indices
|
| 1420 |
+
+ unmasked_score * (1 - masked_indices))
|
| 1421 |
+
return model_output.exp()
|
| 1422 |
+
|
| 1423 |
+
def _staggered_score(self, score, dsigma):
|
| 1424 |
+
score = score.clone()
|
| 1425 |
+
extra_const = (1 - dsigma.exp()) * score.sum(dim=-1)
|
| 1426 |
+
score *= dsigma.exp()[:, None]
|
| 1427 |
+
score[..., self.mask_index] += extra_const
|
| 1428 |
+
return score
|
| 1429 |
+
|
| 1430 |
+
def _analytic_update(self, x, t, step_size):
|
| 1431 |
+
curr_sigma, _ = self.noise(t)
|
| 1432 |
+
next_sigma, _ = self.noise(t - step_size)
|
| 1433 |
+
dsigma = curr_sigma - next_sigma
|
| 1434 |
+
score = self.get_score(x, curr_sigma)
|
| 1435 |
+
stag_score = self._staggered_score(score, dsigma)
|
| 1436 |
+
probs = stag_score * self._transp_transition(x, dsigma)
|
| 1437 |
+
return _sample_categorical(probs)
|
| 1438 |
+
|
| 1439 |
+
def _denoiser_update(self, x, t):
|
| 1440 |
+
sigma, _ = self.noise(t)
|
| 1441 |
+
score = self.get_score(x, sigma)
|
| 1442 |
+
stag_score = self._staggered_score(score, sigma)
|
| 1443 |
+
probs = stag_score * self._transp_transition(x, sigma)
|
| 1444 |
+
probs[..., self.mask_index] = 0
|
| 1445 |
+
samples = _sample_categorical(probs)
|
| 1446 |
+
return samples
|
| 1447 |
+
|
| 1448 |
+
def _transp_transition(self, i, sigma):
|
| 1449 |
+
sigma = _unsqueeze(sigma, reference=i[..., None])
|
| 1450 |
+
edge = torch.exp(-sigma) * F.one_hot(
|
| 1451 |
+
i, num_classes=self.vocab_size)
|
| 1452 |
+
edge += torch.where(i == self.mask_index,
|
| 1453 |
+
1 - torch.exp(-sigma).squeeze(-1),
|
| 1454 |
+
0)[..., None]
|
| 1455 |
+
return edge
|
| 1456 |
+
|
| 1457 |
+
def _sample_t(self, n, device):
|
| 1458 |
+
_eps_t = torch.rand(n, device=device)
|
| 1459 |
+
if self.antithetic_sampling:
|
| 1460 |
+
# for variance reduction
|
| 1461 |
+
offset = torch.arange(n, device=device) / n
|
| 1462 |
+
_eps_t = (_eps_t / n + offset) % 1
|
| 1463 |
+
t = (1 - self.sampling_eps) * _eps_t + self.sampling_eps
|
| 1464 |
+
if self.importance_sampling:
|
| 1465 |
+
return self.noise.importance_sampling_transformation(t)
|
| 1466 |
+
return t
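The antithetic branch above is a stratified construction: one batch of uniforms is shifted by evenly spaced offsets so each draw lands in its own `1/n`-wide bin, which reduces the variance of the time samples. A standalone sketch of just that construction (toy `n`, not repo code):

```python
import torch

# Stratified ("antithetic") time sampling: one sample per bin [k/n, (k+1)/n).
n = 8
eps_t = torch.rand(n)
offset = torch.arange(n) / n
t_stratified = (eps_t / n + offset) % 1
print(t_stratified.sort().values)   # spread roughly evenly over [0, 1)
```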
|
| 1467 |
+
|
| 1468 |
+
def _maybe_sub_sample(self, x0, attention_mask):
|
| 1469 |
+
seqlen = x0.shape[1]
|
| 1470 |
+
if seqlen > self.config.model.length:
|
| 1471 |
+
raise NotImplementedError('Sub-sampling not implemented')
|
| 1472 |
+
elif self.parameterization == 'ar':
|
| 1473 |
+
input_tokens = x0[:, :-1]
|
| 1474 |
+
output_tokens = x0[:, 1:]
|
| 1475 |
+
new_attention_mask = attention_mask[:, 1:]
|
| 1476 |
+
else:
|
| 1477 |
+
input_tokens = x0
|
| 1478 |
+
output_tokens = None
|
| 1479 |
+
new_attention_mask = attention_mask
|
| 1480 |
+
return input_tokens, output_tokens, new_attention_mask
|
| 1481 |
+
|
| 1482 |
+
def _reconstruction_loss(self, x0):
|
| 1483 |
+
t0 = torch.zeros(x0.shape[0], dtype=self.dtype,
|
| 1484 |
+
device=self.device)
|
| 1485 |
+
assert self.config.noise.type == 'loglinear'
|
| 1486 |
+
# The above assert is for d3pm parameterization
|
| 1487 |
+
unet_conditioning = self.noise(t0)[0][:, None]
|
| 1488 |
+
model_output_t0 = self.forward(x0, unet_conditioning)
|
| 1489 |
+
return - torch.gather(input=model_output_t0,
|
| 1490 |
+
dim=-1,
|
| 1491 |
+
index=x0[:, :, None]).squeeze(-1)
|
| 1492 |
+
|
| 1493 |
+
def _forward_pass_diffusion(self, x0):
|
| 1494 |
+
t = self._sample_t(x0.shape[0], x0.device)
|
| 1495 |
+
if self.T > 0:
|
| 1496 |
+
# else ts are between 0 and 1
|
| 1497 |
+
t = (t * self.T).to(torch.int)
|
| 1498 |
+
t = t / self.T
|
| 1499 |
+
# t \in {1/T, 2/T, ..., 1}
|
| 1500 |
+
t += (1 / self.T)
|
| 1501 |
+
|
| 1502 |
+
if self.change_of_variables: # False
|
| 1503 |
+
unet_conditioning = t[:, None]
|
| 1504 |
+
f_T = torch.log1p(- torch.exp(- self.noise.sigma_max))
|
| 1505 |
+
f_0 = torch.log1p(- torch.exp(- self.noise.sigma_min))
|
| 1506 |
+
move_chance = torch.exp(f_0 + t * (f_T - f_0))
|
| 1507 |
+
move_chance = move_chance[:, None]
|
| 1508 |
+
else:
|
| 1509 |
+
sigma, dsigma = self.noise(t) # total noise, rate noise
|
| 1510 |
+
unet_conditioning = sigma[:, None]
|
| 1511 |
+
move_chance = 1 - torch.exp(-sigma[:, None])
|
| 1512 |
+
|
| 1513 |
+
xt = self.q_xt(x0, move_chance) # q(xt|x0)
|
| 1514 |
+
model_output = self.forward(xt, unet_conditioning)
|
| 1515 |
+
utils.print_nans(model_output, 'model_output')
|
| 1516 |
+
|
| 1517 |
+
if self.parameterization == 'sedd':
|
| 1518 |
+
return dsigma[:, None] * self._score_entropy(
|
| 1519 |
+
model_output, sigma[:, None], xt, x0)
|
| 1520 |
+
|
| 1521 |
+
if self.T > 0:
|
| 1522 |
+
diffusion_loss = self._d3pm_loss(
|
| 1523 |
+
model_output=model_output, xt=xt, x0=x0, t=t)
|
| 1524 |
+
if self.parameterization == 'd3pm':
|
| 1525 |
+
reconstruction_loss = self._reconstruction_loss(x0)
|
| 1526 |
+
elif self.parameterization == 'subs':
|
| 1527 |
+
reconstruction_loss = 0
|
| 1528 |
+
return reconstruction_loss + diffusion_loss
|
| 1529 |
+
|
| 1530 |
+
# SUBS parameterization, continuous time.
|
| 1531 |
+
log_p_theta = torch.gather(
|
| 1532 |
+
input=model_output,
|
| 1533 |
+
dim=-1,
|
| 1534 |
+
index=x0[:, :, None]).squeeze(-1)
|
| 1535 |
+
|
| 1536 |
+
if self.change_of_variables or self.importance_sampling:
|
| 1537 |
+
return log_p_theta * torch.log1p(
|
| 1538 |
+
- torch.exp(- self.noise.sigma_min))
|
| 1539 |
+
|
| 1540 |
+
return - log_p_theta * (
|
| 1541 |
+
dsigma / torch.expm1(sigma))[:, None]
|
| 1542 |
+
|
| 1543 |
+
def _loss(self, x0, attention_mask):
|
| 1544 |
+
(input_tokens, output_tokens, attention_mask) = self._maybe_sub_sample(
|
| 1545 |
+
x0, attention_mask)
|
| 1546 |
+
|
| 1547 |
+
if self.parameterization == 'ar':
|
| 1548 |
+
logprobs = self.backbone(input_tokens, None)
|
| 1549 |
+
loss = - logprobs.gather(
|
| 1550 |
+
-1, output_tokens[:, :, None])[:, :, 0]
|
| 1551 |
+
else:
|
| 1552 |
+
loss = self._forward_pass_diffusion(input_tokens)
|
| 1553 |
+
|
| 1554 |
+
nlls = loss * attention_mask
|
| 1555 |
+
count = attention_mask.sum()
|
| 1556 |
+
|
| 1557 |
+
batch_nll = nlls.sum()
|
| 1558 |
+
token_nll = batch_nll / count
|
| 1559 |
+
|
| 1560 |
+
return Loss(loss=token_nll,
|
| 1561 |
+
nlls=nlls,
|
| 1562 |
+
token_mask=attention_mask)
|
| 1563 |
+
|
| 1564 |
+
def _score_entropy(self, log_score, sigma, xt, x0):
|
| 1565 |
+
"""Computes the SEDD loss.
|
| 1566 |
+
|
| 1567 |
+
Args:
|
| 1568 |
+
log_score: float torch.Tensor with shape (batch_size,
|
| 1569 |
+
diffusion_model_input_length, vocab_size),
|
| 1570 |
+
log score, output of the denoising network.
|
| 1571 |
+
xt: int torch.Tensor with shape (batch_size,
|
| 1572 |
+
diffusion_model_input_length), input.
|
| 1573 |
+
x0: int torch.Tensor with shape (batch_size,
|
| 1574 |
+
diffusion_model_input_length), input.
|
| 1575 |
+
sigma: float torch.Tensor with shape (batch_size, 1).
|
| 1576 |
+
|
| 1577 |
+
Returns:
|
| 1578 |
+
loss with shape (batch_size, diffusion_model_input_length)
|
| 1579 |
+
"""
|
| 1580 |
+
# seems that it takes y=x0,xt=M case
|
| 1581 |
+
# what is the const term for, seems to be y=M,xt=x0 case and x0 is known so score estimation is precise
|
| 1582 |
+
masked_indices = xt == self.mask_index
|
| 1583 |
+
|
| 1584 |
+
expsig_minus_1 = torch.expm1(sigma).expand_as(xt)
|
| 1585 |
+
q_ratio = 1 / expsig_minus_1[masked_indices]
|
| 1586 |
+
|
| 1587 |
+
words_that_were_masked = x0[masked_indices]
|
| 1588 |
+
|
| 1589 |
+
neg_term = q_ratio * torch.gather(
|
| 1590 |
+
log_score[masked_indices],
|
| 1591 |
+
-1,
|
| 1592 |
+
words_that_were_masked[..., None]).squeeze(-1)
|
| 1593 |
+
score = log_score[masked_indices].exp()
|
| 1594 |
+
if self.mask_index == self.vocab_size - 1:
|
| 1595 |
+
pos_term = score[:, :-1].sum(dim=-1)
|
| 1596 |
+
else:
|
| 1597 |
+
pos_term = score[:, : self.mask_index].sum(
|
| 1598 |
+
dim=-1) + score[:, self.mask_index + 1:].sum(dim=-1)
|
| 1599 |
+
const = q_ratio * (q_ratio.log() - 1)
|
| 1600 |
+
|
| 1601 |
+
entropy = torch.zeros(* xt.shape, device=xt.device)
|
| 1602 |
+
entropy[masked_indices] += pos_term - neg_term + const
|
| 1603 |
+
return entropy
|
| 1604 |
+
|
tr2d2-dna/diffusion_gosai_cfg.py
ADDED
|
@@ -0,0 +1,729 @@
|
| 1 |
+
import itertools
|
| 2 |
+
import math
|
| 3 |
+
from dataclasses import dataclass
|
| 4 |
+
import hydra.utils
|
| 5 |
+
import lightning as L
|
| 6 |
+
import numpy as np
|
| 7 |
+
import torch
|
| 8 |
+
import torch.nn.functional as F
|
| 9 |
+
import torchmetrics
|
| 10 |
+
from torch import Tensor
|
| 11 |
+
import dataloader_gosai
|
| 12 |
+
import models
|
| 13 |
+
import noise_schedule
|
| 14 |
+
import utils
|
| 15 |
+
import oracle
|
| 16 |
+
|
| 17 |
+
LOG2 = math.log(2)
|
| 18 |
+
LOGGER = utils.get_logger(__name__)
|
| 19 |
+
|
| 20 |
+
|
| 21 |
+
def _sample_categorical(categorical_probs):
|
| 22 |
+
gumbel_norm = (
|
| 23 |
+
1e-10
|
| 24 |
+
- (torch.rand_like(categorical_probs) + 1e-10).log())
|
| 25 |
+
return (categorical_probs / gumbel_norm).argmax(dim=-1)
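The helper above samples from an (unnormalized) categorical with an exponential-races / Gumbel-max style trick: dividing the probabilities by `-log(U)` noise and taking the argmax picks index `i` with probability proportional to `p_i`. A small self-contained check of that claim:

```python
import torch

# Empirical check that argmax(p / noise) with noise = -log(U) samples in proportion to p.
torch.manual_seed(0)
p = torch.tensor([0.1, 0.2, 0.7])
n = 200_000

gumbel_norm = 1e-10 - (torch.rand(n, 3) + 1e-10).log()   # same noise as in _sample_categorical
samples = (p / gumbel_norm).argmax(dim=-1)
print(torch.bincount(samples, minlength=3).float() / n)  # ~ tensor([0.10, 0.20, 0.70])
```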
|
| 26 |
+
|
| 27 |
+
|
| 28 |
+
def _unsqueeze(x, reference):
|
| 29 |
+
return x.view(
|
| 30 |
+
* x.shape,
|
| 31 |
+
* ((1,) * (len(reference.shape) - len(x.shape))))
|
| 32 |
+
|
| 33 |
+
|
| 34 |
+
@dataclass
|
| 35 |
+
class Loss:
|
| 36 |
+
loss: torch.FloatTensor
|
| 37 |
+
nlls: torch.FloatTensor
|
| 38 |
+
token_mask: torch.FloatTensor
|
| 39 |
+
|
| 40 |
+
|
| 41 |
+
class NLL(torchmetrics.aggregation.MeanMetric):
|
| 42 |
+
pass
|
| 43 |
+
|
| 44 |
+
|
| 45 |
+
class BPD(NLL):
|
| 46 |
+
def compute(self) -> Tensor:
|
| 47 |
+
"""Computes the bits per dimension.
|
| 48 |
+
|
| 49 |
+
Returns:
|
| 50 |
+
bpd
|
| 51 |
+
"""
|
| 52 |
+
return self.mean_value / self.weight / LOG2
|
| 53 |
+
|
| 54 |
+
|
| 55 |
+
class Perplexity(NLL):
|
| 56 |
+
def compute(self) -> Tensor:
|
| 57 |
+
"""Computes the Perplexity.
|
| 58 |
+
|
| 59 |
+
Returns:
|
| 60 |
+
Perplexity
|
| 61 |
+
"""
|
| 62 |
+
return torch.exp(self.mean_value / self.weight)
|
| 63 |
+
|
| 64 |
+
|
| 65 |
+
class Diffusion(L.LightningModule):
|
| 66 |
+
def __init__(
|
| 67 |
+
self,
|
| 68 |
+
config,
|
| 69 |
+
eval=True):
|
| 70 |
+
super().__init__()
|
| 71 |
+
self.save_hyperparameters()
|
| 72 |
+
self.config = config
|
| 73 |
+
self.vocab_size = 4
|
| 74 |
+
self.sampler = self.config.sampling.predictor
|
| 75 |
+
self.antithetic_sampling = self.config.training.antithetic_sampling
|
| 76 |
+
self.importance_sampling = self.config.training.importance_sampling
|
| 77 |
+
self.change_of_variables = self.config.training.change_of_variables
|
| 78 |
+
self.mask_index = self.vocab_size
|
| 79 |
+
self.vocab_size += 1
|
| 80 |
+
self.parameterization = self.config.parameterization
|
| 81 |
+
if self.config.backbone == 'cnn':
|
| 82 |
+
self.backbone = models.dnaconv.CNNModel(
|
| 83 |
+
self.config.model, alphabet_size=self.vocab_size, num_cls=2)
|
| 84 |
+
else:
|
| 85 |
+
raise ValueError(
|
| 86 |
+
f'Unknown backbone: {self.config.backbone}')
|
| 87 |
+
|
| 88 |
+
self.T = self.config.T
|
| 89 |
+
self.subs_masking = self.config.subs_masking
|
| 90 |
+
|
| 91 |
+
self.softplus = torch.nn.Softplus()
|
| 92 |
+
# metrics are automatically reset at end of epoch
|
| 93 |
+
metrics = torchmetrics.MetricCollection({
|
| 94 |
+
'nll': NLL(),
|
| 95 |
+
'bpd': BPD(),
|
| 96 |
+
'ppl': Perplexity(),
|
| 97 |
+
})
|
| 98 |
+
metrics.set_dtype(torch.float64)
|
| 99 |
+
self.train_metrics = metrics.clone(prefix='train/')
|
| 100 |
+
self.valid_metrics = metrics.clone(prefix='val/')
|
| 101 |
+
self.test_metrics = metrics.clone(prefix='test/')
|
| 102 |
+
|
| 103 |
+
# generative perplexity
|
| 104 |
+
self.gen_ppl_metric = Perplexity()
|
| 105 |
+
|
| 106 |
+
self.noise = noise_schedule.get_noise(self.config,
|
| 107 |
+
dtype=self.dtype)
|
| 108 |
+
if self.config.training.ema > 0:
|
| 109 |
+
self.ema = models.ema.ExponentialMovingAverage(
|
| 110 |
+
itertools.chain(self.backbone.parameters(),
|
| 111 |
+
self.noise.parameters()),
|
| 112 |
+
decay=self.config.training.ema)
|
| 113 |
+
else:
|
| 114 |
+
self.ema = None
|
| 115 |
+
|
| 116 |
+
self.lr = self.config.optim.lr
|
| 117 |
+
self.sampling_eps = self.config.training.sampling_eps
|
| 118 |
+
self.time_conditioning = self.config.time_conditioning
|
| 119 |
+
self.neg_infinity = -1000000.0
|
| 120 |
+
self.fast_forward_epochs = None
|
| 121 |
+
self.fast_forward_batches = None
|
| 122 |
+
self._validate_configuration()
|
| 123 |
+
|
| 124 |
+
# subset of data for evaluation
|
| 125 |
+
if eval:
|
| 126 |
+
self.eval_sets_sp = oracle.subset_for_eval(n=config.eval.subset_size)
|
| 127 |
+
self.eval_sets_sp_clss = oracle.subset_eval_groundtruth(self.eval_sets_sp)
|
| 128 |
+
self.eval_sets_sp_preds = oracle.subset_eval_preds(self.eval_sets_sp)
|
| 129 |
+
self.eval_sets_sp_kmers = oracle.subset_eval_kmers(self.eval_sets_sp)
|
| 130 |
+
self.emb_pca = oracle.cal_emb_pca(oracle.subset_for_eval(n=40000), n_components=50)
|
| 131 |
+
self.eval_sets_sp_embs_pca = oracle.subset_eval_embs_pca(self.eval_sets_sp, self.emb_pca)
|
| 132 |
+
|
| 133 |
+
def _validate_configuration(self):
|
| 134 |
+
assert not (self.change_of_variables
|
| 135 |
+
and self.importance_sampling)
|
| 136 |
+
assert self.parameterization == 'subs'
|
| 137 |
+
|
| 138 |
+
def on_load_checkpoint(self, checkpoint):
|
| 139 |
+
if self.ema:
|
| 140 |
+
self.ema.load_state_dict(checkpoint['ema'])
|
| 141 |
+
# Copied from:
|
| 142 |
+
# https://github.com/Dao-AILab/flash-attention/blob/main/training/src/datamodules/language_modeling_hf.py#L41
|
| 143 |
+
self.fast_forward_epochs = checkpoint['loops'][
|
| 144 |
+
'fit_loop']['epoch_progress']['current']['completed']
|
| 145 |
+
self.fast_forward_batches = checkpoint['loops'][
|
| 146 |
+
'fit_loop']['epoch_loop.batch_progress'][
|
| 147 |
+
'current']['completed']
|
| 148 |
+
|
| 149 |
+
def on_save_checkpoint(self, checkpoint):
|
| 150 |
+
if self.ema:
|
| 151 |
+
checkpoint['ema'] = self.ema.state_dict()
|
| 152 |
+
# Copied from:
|
| 153 |
+
# https://github.com/Dao-AILab/flash-attention/blob/main/training/src/tasks/seq.py
|
| 154 |
+
# ['epoch_loop.batch_progress']['total']['completed'] is 1 iteration
|
| 155 |
+
# behind, so we're using the optimizer's progress.
|
| 156 |
+
checkpoint['loops']['fit_loop'][
|
| 157 |
+
'epoch_loop.batch_progress']['total'][
|
| 158 |
+
'completed'] = checkpoint['loops']['fit_loop'][
|
| 159 |
+
'epoch_loop.automatic_optimization.optim_progress'][
|
| 160 |
+
'optimizer']['step']['total'][
|
| 161 |
+
'completed'] * self.trainer.accumulate_grad_batches
|
| 162 |
+
checkpoint['loops']['fit_loop'][
|
| 163 |
+
'epoch_loop.batch_progress']['current'][
|
| 164 |
+
'completed'] = checkpoint['loops']['fit_loop'][
|
| 165 |
+
'epoch_loop.automatic_optimization.optim_progress'][
|
| 166 |
+
'optimizer']['step']['current'][
|
| 167 |
+
'completed'] * self.trainer.accumulate_grad_batches
|
| 168 |
+
# _batches_that_stepped tracks the number of global steps, not the number
|
| 169 |
+
# of local steps, so we don't multiply with self.trainer.accumulate_grad_batches here.
|
| 170 |
+
checkpoint['loops']['fit_loop'][
|
| 171 |
+
'epoch_loop.state_dict'][
|
| 172 |
+
'_batches_that_stepped'] = checkpoint['loops']['fit_loop'][
|
| 173 |
+
'epoch_loop.automatic_optimization.optim_progress'][
|
| 174 |
+
'optimizer']['step']['total']['completed']
|
| 175 |
+
if 'sampler' not in checkpoint.keys():
|
| 176 |
+
checkpoint['sampler'] = {}
|
| 177 |
+
if hasattr(self.trainer.train_dataloader.sampler,
|
| 178 |
+
'state_dict'):
|
| 179 |
+
sampler_state_dict = self.trainer.\
|
| 180 |
+
train_dataloader.sampler.state_dict()
|
| 181 |
+
checkpoint['sampler'][
|
| 182 |
+
'random_state'] = sampler_state_dict.get(
|
| 183 |
+
'random_state', None)
|
| 184 |
+
else:
|
| 185 |
+
checkpoint['sampler']['random_state'] = None
|
| 186 |
+
|
| 187 |
+
def on_train_start(self):
|
| 188 |
+
if self.ema:
|
| 189 |
+
self.ema.move_shadow_params_to_device(self.device)
|
| 190 |
+
# Adapted from:
|
| 191 |
+
# https://github.com/Dao-AILab/flash-attention/blob/main/training/src/datamodules/language_modeling_hf.py
|
| 192 |
+
distributed = (
|
| 193 |
+
self.trainer._accelerator_connector.use_distributed_sampler
|
| 194 |
+
and self.trainer._accelerator_connector.is_distributed)
|
| 195 |
+
|
| 196 |
+
print('distributed:', distributed)
|
| 197 |
+
if distributed:
|
| 198 |
+
sampler_cls = dataloader_gosai.FaultTolerantDistributedSampler
|
| 199 |
+
else:
|
| 200 |
+
sampler_cls = dataloader_gosai.RandomFaultTolerantSampler
|
| 201 |
+
|
| 202 |
+
updated_dls = []
|
| 203 |
+
for dl in self.trainer.fit_loop._combined_loader.flattened:
|
| 204 |
+
if hasattr(dl.sampler, 'shuffle'):
|
| 205 |
+
dl_sampler = sampler_cls(
|
| 206 |
+
dl.dataset, shuffle=dl.sampler.shuffle)
|
| 207 |
+
else:
|
| 208 |
+
dl_sampler = sampler_cls(dl.dataset)
|
| 209 |
+
if (distributed
|
| 210 |
+
and self.fast_forward_epochs is not None
|
| 211 |
+
and self.fast_forward_batches is not None):
|
| 212 |
+
dl_sampler.load_state_dict({
|
| 213 |
+
'epoch': self.fast_forward_epochs,
|
| 214 |
+
'counter': (self.fast_forward_batches
|
| 215 |
+
* self.config.loader.batch_size)})
|
| 216 |
+
updated_dls.append(
|
| 217 |
+
torch.utils.data.DataLoader(
|
| 218 |
+
dl.dataset,
|
| 219 |
+
batch_size=self.config.loader.batch_size,
|
| 220 |
+
num_workers=self.config.loader.num_workers,
|
| 221 |
+
pin_memory=self.config.loader.pin_memory,
|
| 222 |
+
sampler=dl_sampler,
|
| 223 |
+
shuffle=False,
|
| 224 |
+
persistent_workers=True))
|
| 225 |
+
self.trainer.fit_loop._combined_loader.flattened = updated_dls
|
| 226 |
+
|
| 227 |
+
def optimizer_step(self, *args, **kwargs):
|
| 228 |
+
super().optimizer_step(*args, **kwargs)
|
| 229 |
+
if self.ema:
|
| 230 |
+
self.ema.update(itertools.chain(
|
| 231 |
+
self.backbone.parameters(),
|
| 232 |
+
self.noise.parameters()))
|
| 233 |
+
|
| 234 |
+
def _subs_parameterization(self, logits, xt):
|
| 235 |
+
logits[:, :, self.mask_index] += self.neg_infinity
|
| 236 |
+
logits = logits - torch.logsumexp(logits, dim=-1,
|
| 237 |
+
keepdim=True)
|
| 238 |
+
|
| 239 |
+
unmasked_indices = (xt != self.mask_index)
|
| 240 |
+
logits[unmasked_indices] = self.neg_infinity
|
| 241 |
+
logits[unmasked_indices, xt[unmasked_indices]] = 0
|
| 242 |
+
return logits
|
| 243 |
+
|
| 244 |
+
def _process_sigma(self, sigma):
|
| 245 |
+
if sigma is None:
|
| 246 |
+
assert self.parameterization == 'ar'
|
| 247 |
+
return sigma
|
| 248 |
+
if sigma.ndim > 1:
|
| 249 |
+
sigma = sigma.squeeze(-1)
|
| 250 |
+
if not self.time_conditioning:
|
| 251 |
+
sigma = torch.zeros_like(sigma)
|
| 252 |
+
assert sigma.ndim == 1, sigma.shape
|
| 253 |
+
return sigma
|
| 254 |
+
|
| 255 |
+
def forward(self, x, sigma, binary_clss=None):
|
| 256 |
+
"""Returns log score."""
|
| 257 |
+
sigma = self._process_sigma(sigma)
|
| 258 |
+
with torch.cuda.amp.autocast(dtype=torch.float32):
|
| 259 |
+
logits = self.backbone(x, sigma, cls=binary_clss)
|
| 260 |
+
|
| 261 |
+
if self.parameterization == 'subs':
|
| 262 |
+
return self._subs_parameterization(logits=logits, xt=x)
|
| 263 |
+
|
| 264 |
+
return logits
|
| 265 |
+
|
| 266 |
+
def _compute_loss(self, batch, prefix):
|
| 267 |
+
if 'attention_mask' in batch:
|
| 268 |
+
attention_mask = batch['attention_mask']
|
| 269 |
+
else:
|
| 270 |
+
attention_mask = None
|
| 271 |
+
# classifier-free guidance
|
| 272 |
+
assert self.config.model.cls_free_guidance == True
|
| 273 |
+
binary_clss = (batch['clss'][:,0] > self.config.model.cls_free_threshold).long()
|
| 274 |
+
random_list = np.random.binomial(1, self.config.model.cls_free_prob, binary_clss.shape[0])
|
| 275 |
+
binary_clss[random_list==1] = 2
|
| 276 |
+
losses = self._loss(batch['seqs'], attention_mask, binary_clss)
|
| 277 |
+
loss = losses.loss
|
| 278 |
+
|
| 279 |
+
if prefix == 'train':
|
| 280 |
+
self.train_metrics.update(losses.nlls, losses.token_mask)
|
| 281 |
+
metrics = self.train_metrics
|
| 282 |
+
elif prefix == 'val':
|
| 283 |
+
self.valid_metrics.update(losses.nlls, losses.token_mask)
|
| 284 |
+
metrics = self.valid_metrics
|
| 285 |
+
elif prefix == 'test':
|
| 286 |
+
self.test_metrics.update(losses.nlls, losses.token_mask)
|
| 287 |
+
metrics = self.test_metrics
|
| 288 |
+
else:
|
| 289 |
+
raise ValueError(f'Invalid prefix: {prefix}')
|
| 290 |
+
|
| 291 |
+
self.log_dict(metrics,
|
| 292 |
+
on_step=False,
|
| 293 |
+
on_epoch=True,
|
| 294 |
+
sync_dist=True)
|
| 295 |
+
return loss
|
| 296 |
+
|
| 297 |
+
def on_train_epoch_start(self):
|
| 298 |
+
self.backbone.train()
|
| 299 |
+
self.noise.train()
|
| 300 |
+
|
| 301 |
+
def training_step(self, batch, batch_idx):
|
| 302 |
+
loss = self._compute_loss(batch, prefix='train')
|
| 303 |
+
self.log(name='trainer/loss',
|
| 304 |
+
value=loss.item(),
|
| 305 |
+
on_step=True,
|
| 306 |
+
on_epoch=False,
|
| 307 |
+
sync_dist=True)
|
| 308 |
+
return loss
|
| 309 |
+
|
| 310 |
+
def on_validation_epoch_start(self):
|
| 311 |
+
if self.ema:
|
| 312 |
+
self.ema.store(itertools.chain(
|
| 313 |
+
self.backbone.parameters(),
|
| 314 |
+
self.noise.parameters()))
|
| 315 |
+
self.ema.copy_to(itertools.chain(
|
| 316 |
+
self.backbone.parameters(),
|
| 317 |
+
self.noise.parameters()))
|
| 318 |
+
self.backbone.eval()
|
| 319 |
+
self.noise.eval()
|
| 320 |
+
assert self.valid_metrics.nll.mean_value == 0
|
| 321 |
+
assert self.valid_metrics.nll.weight == 0
|
| 322 |
+
|
| 323 |
+
def validation_step(self, batch, batch_idx):
|
| 324 |
+
return self._compute_loss(batch, prefix='val')
|
| 325 |
+
|
| 326 |
+
def on_validation_epoch_end(self):
|
| 327 |
+
if ((self.config.eval.compute_perplexity_on_sanity
|
| 328 |
+
or not self.trainer.sanity_checking)
|
| 329 |
+
and self.config.eval.generate_samples
|
| 330 |
+
and not self.parameterization == 'ar'):
|
| 331 |
+
all_samples, all_detokenized_samples = [], []
|
| 332 |
+
for _ in range(
|
| 333 |
+
self.config.sampling.num_sample_batches):
|
| 334 |
+
samples = self._sample(cls=1).detach().cpu().numpy()
|
| 335 |
+
detokenized_samples = dataloader_gosai.batch_dna_detokenize(samples)
|
| 336 |
+
all_samples.append(samples)
|
| 337 |
+
all_detokenized_samples.extend(detokenized_samples)
|
| 338 |
+
all_samples = np.concatenate(all_samples, axis=0)
|
| 339 |
+
generated_preds = oracle.cal_gosai_pred(all_detokenized_samples, mode='eval')[:,0]
|
| 340 |
+
avg_generated_preds = np.mean(generated_preds, axis=0)
|
| 341 |
+
|
| 342 |
+
current_step = self.trainer.global_step
|
| 343 |
+
LOGGER.info(f'Current step: {current_step}')
|
| 344 |
+
LOGGER.info(f'Generated preds: {avg_generated_preds}')
|
| 345 |
+
self.log('val/gosai_preds_avg', avg_generated_preds, on_step=False, on_epoch=True, sync_dist=True)
|
| 346 |
+
|
| 347 |
+
if self.ema:
|
| 348 |
+
self.ema.restore(
|
| 349 |
+
itertools.chain(self.backbone.parameters(),
|
| 350 |
+
self.noise.parameters()))
|
| 351 |
+
|
| 352 |
+
def configure_optimizers(self):
|
| 353 |
+
# TODO(yair): Lightning currently giving this warning when using `fp16`:
|
| 354 |
+
# "Detected call of `lr_scheduler.step()` before `optimizer.step()`. "
|
| 355 |
+
# Not clear if this is a problem or not.
|
| 356 |
+
# See: https://github.com/Lightning-AI/pytorch-lightning/issues/5558
|
| 357 |
+
optimizer = torch.optim.AdamW(
|
| 358 |
+
itertools.chain(self.backbone.parameters(),
|
| 359 |
+
self.noise.parameters()),
|
| 360 |
+
lr=self.config.optim.lr,
|
| 361 |
+
betas=(self.config.optim.beta1,
|
| 362 |
+
self.config.optim.beta2),
|
| 363 |
+
eps=self.config.optim.eps,
|
| 364 |
+
weight_decay=self.config.optim.weight_decay)
|
| 365 |
+
|
| 366 |
+
scheduler = hydra.utils.instantiate(
|
| 367 |
+
self.config.lr_scheduler, optimizer=optimizer)
|
| 368 |
+
scheduler_dict = {
|
| 369 |
+
'scheduler': scheduler,
|
| 370 |
+
'interval': 'step',
|
| 371 |
+
'monitor': 'val/loss',
|
| 372 |
+
'name': 'trainer/lr',
|
| 373 |
+
}
|
| 374 |
+
return [optimizer], [scheduler_dict]
|
| 375 |
+
|
| 376 |
+
def q_xt(self, x, move_chance):
|
| 377 |
+
"""Computes the noisy sample xt.
|
| 378 |
+
|
| 379 |
+
Args:
|
| 380 |
+
x: int torch.Tensor with shape (batch_size,
|
| 381 |
+
diffusion_model_input_length), input.
|
| 382 |
+
move_chance: float torch.Tensor with shape (batch_size, 1).
|
| 383 |
+
"""
|
| 384 |
+
move_indices = torch.rand(
|
| 385 |
+
* x.shape, device=x.device) < move_chance
|
| 386 |
+
xt = torch.where(move_indices, self.mask_index, x)
|
| 387 |
+
return xt
|
| 388 |
+
|
| 389 |
+
def _sample_prior(self, *batch_dims):
|
| 390 |
+
return self.mask_index * torch.ones(
|
| 391 |
+
* batch_dims, dtype=torch.int64)
|
| 392 |
+
|
| 393 |
+
def _ddpm_caching_update(self, x, t, dt, p_x0=None):
|
| 394 |
+
assert self.config.noise.type == 'loglinear'
|
| 395 |
+
sigma_t, _ = self.noise(t)
|
| 396 |
+
if t.ndim > 1:
|
| 397 |
+
t = t.squeeze(-1)
|
| 398 |
+
assert t.ndim == 1
|
| 399 |
+
move_chance_t = t[:, None, None]
|
| 400 |
+
move_chance_s = (t - dt)[:, None, None]
|
| 401 |
+
assert move_chance_t.ndim == 3, move_chance_t.shape
|
| 402 |
+
if p_x0 is None:
|
| 403 |
+
p_x0 = self.forward(x, sigma_t).exp()
|
| 404 |
+
|
| 405 |
+
assert move_chance_t.ndim == p_x0.ndim
|
| 406 |
+
q_xs = p_x0 * (move_chance_t - move_chance_s)
|
| 407 |
+
q_xs[:, :, self.mask_index] = move_chance_s[:, :, 0]
|
| 408 |
+
_x = _sample_categorical(q_xs)
|
| 409 |
+
|
| 410 |
+
copy_flag = (x != self.mask_index).to(x.dtype)
|
| 411 |
+
return p_x0, copy_flag * x + (1 - copy_flag) * _x
|
| 412 |
+
|
| 413 |
+
def _ddpm_update(self, x, t, dt, cls, w):
|
| 414 |
+
sigma_t, _ = self.noise(t)
|
| 415 |
+
sigma_s, _ = self.noise(t - dt)
|
| 416 |
+
if sigma_t.ndim > 1:
|
| 417 |
+
sigma_t = sigma_t.squeeze(-1)
|
| 418 |
+
if sigma_s.ndim > 1:
|
| 419 |
+
sigma_s = sigma_s.squeeze(-1)
|
| 420 |
+
assert sigma_t.ndim == 1, sigma_t.shape
|
| 421 |
+
assert sigma_s.ndim == 1, sigma_s.shape
|
| 422 |
+
move_chance_t = 1 - torch.exp(-sigma_t)
|
| 423 |
+
move_chance_s = 1 - torch.exp(-sigma_s)
|
| 424 |
+
move_chance_t = move_chance_t[:, None, None]
|
| 425 |
+
move_chance_s = move_chance_s[:, None, None]
|
| 426 |
+
unet_conditioning = sigma_t
|
| 427 |
+
uncond = (2 * torch.ones(x.shape[0], device=x.device)).long()
|
| 428 |
+
cond = (cls * torch.ones(x.shape[0], device=x.device)).long()
|
| 429 |
+
log_p_x0_uncond = self.forward(x, unet_conditioning, uncond)
|
| 430 |
+
log_p_x0_cond = self.forward(x, unet_conditioning, cond)
|
| 431 |
+
log_p_x0 = (1+w) * log_p_x0_cond - w * log_p_x0_uncond
|
| 432 |
+
assert move_chance_t.ndim == log_p_x0.ndim
|
| 433 |
+
q_xs = log_p_x0.exp() * (move_chance_t
|
| 434 |
+
- move_chance_s)
|
| 435 |
+
q_xs[:, :, self.mask_index] = move_chance_s[:, :, 0]
|
| 436 |
+
_x = _sample_categorical(q_xs)
|
| 437 |
+
|
| 438 |
+
copy_flag = (x != self.mask_index).to(x.dtype)
|
| 439 |
+
return copy_flag * x + (1 - copy_flag) * _x
|
| 440 |
+
|
| 441 |
+
def _ar_sampler(self, bsz):
|
| 442 |
+
# precompute token buffer
|
| 443 |
+
num_pred_tokens = self.config.model.length - 1
|
| 444 |
+
x = torch.zeros(
|
| 445 |
+
(bsz, num_pred_tokens + 1),
|
| 446 |
+
dtype=torch.long,
|
| 447 |
+
device=self.device)
|
| 448 |
+
x[:, 0] = self.tokenizer.bos_token_id
|
| 449 |
+
# precompute noise
|
| 450 |
+
noise = (torch.distributions.Gumbel(0, 1)
|
| 451 |
+
.sample((bsz, num_pred_tokens, self.vocab_size))
|
| 452 |
+
.to(self.device))
|
| 453 |
+
for i in range(num_pred_tokens):
|
| 454 |
+
next_logits = self.forward(x[:, :i + 1], None)[:, -1]
|
| 455 |
+
y = (next_logits + noise[:, i]).argmax(-1)
|
| 456 |
+
x[:, i + 1] = y
|
| 457 |
+
return x
|
| 458 |
+
|
| 459 |
+
@torch.no_grad()
|
| 460 |
+
def _sample(self, num_steps=None, eps=1e-5, eval_sp_size=None, cls=1, w=None):
|
| 461 |
+
"""Generate samples from the model."""
|
| 462 |
+
if w is None:
|
| 463 |
+
w = self.config.model.cls_free_weight
|
| 464 |
+
if eval_sp_size is None:
|
| 465 |
+
batch_size_per_gpu = self.config.loader.eval_batch_size
|
| 466 |
+
else:
|
| 467 |
+
batch_size_per_gpu = eval_sp_size
|
| 468 |
+
if self.parameterization == 'ar':
|
| 469 |
+
return self._ar_sampler(batch_size_per_gpu)
|
| 470 |
+
if num_steps is None:
|
| 471 |
+
num_steps = self.config.sampling.steps
|
| 472 |
+
x = self._sample_prior(
|
| 473 |
+
batch_size_per_gpu,
|
| 474 |
+
self.config.model.length).to(self.device)
|
| 475 |
+
timesteps = torch.linspace(
|
| 476 |
+
1, eps, num_steps + 1, device=self.device)
|
| 477 |
+
dt = (1 - eps) / num_steps
|
| 478 |
+
p_x0_cache = None
|
| 479 |
+
|
| 480 |
+
for i in range(num_steps):
|
| 481 |
+
t = timesteps[i] * torch.ones(
|
| 482 |
+
x.shape[0], 1, device=self.device)
|
| 483 |
+
if self.sampler == 'ddpm':
|
| 484 |
+
x = self._ddpm_update(x, t, dt, cls, w)
|
| 485 |
+
else:
|
| 486 |
+
raise NotImplementedError
|
| 487 |
+
|
| 488 |
+
if self.config.sampling.noise_removal:
|
| 489 |
+
t = timesteps[-1] * torch.ones(x.shape[0], 1,
|
| 490 |
+
device=self.device)
|
| 491 |
+
unet_conditioning = self.noise(t)[0]
|
| 492 |
+
|
| 493 |
+
uncond = (2 * torch.ones(x.shape[0], device=x.device)).long()
|
| 494 |
+
cond = (cls * torch.ones(x.shape[0], device=x.device)).long()
|
| 495 |
+
log_p_x0_uncond = self.forward(x, unet_conditioning, uncond)
|
| 496 |
+
log_p_x0_cond = self.forward(x, unet_conditioning, cond)
|
| 497 |
+
|
| 498 |
+
logits = (1+w) * log_p_x0_cond - w * log_p_x0_uncond
|
| 499 |
+
x = logits[:, :, :-1].argmax(dim=-1)
|
| 500 |
+
|
| 501 |
+
return x
|
| 502 |
+
|
| 503 |
+
def get_score(self, x, sigma):
|
| 504 |
+
model_output = self.forward(x, sigma)
|
| 505 |
+
if self.parameterization == 'subs':
|
| 506 |
+
# score(x, t) = p_t(y) / p_t(x)
|
| 507 |
+
# => log score(x, t) = log p_t(y) - log p_t(x)
|
| 508 |
+
|
| 509 |
+
# case 1: x = masked
|
| 510 |
+
# (i) y = unmasked
|
| 511 |
+
# log score(x, t) = log p_\theta(x)|_y + log k
|
| 512 |
+
# where k = exp(- sigma) / (1 - exp(- sigma))
|
| 513 |
+
# (ii) y = masked
|
| 514 |
+
# log score(x, t) = 0
|
| 515 |
+
|
| 516 |
+
# case 2: x = unmasked
|
| 517 |
+
# (i) y != masked, y != x
|
| 518 |
+
# log score(x_i, t) = - inf
|
| 519 |
+
# (ii) y = x
|
| 520 |
+
# log score(x_i, t) = 0
|
| 521 |
+
# (iii) y = masked token
|
| 522 |
+
# log score(x_i, t) = - log k
|
| 523 |
+
# where k = exp(- sigma) / (1 - exp(- sigma))
|
| 524 |
+
|
| 525 |
+
log_k = - torch.log(torch.expm1(sigma)).squeeze(-1)
|
| 526 |
+
assert log_k.ndim == 1
|
| 527 |
+
|
| 528 |
+
masked_score = model_output + log_k[:, None, None]
|
| 529 |
+
masked_score[:, :, self.mask_index] = 0
|
| 530 |
+
|
| 531 |
+
unmasked_score = self.neg_infinity * torch.ones_like(
|
| 532 |
+
model_output)
|
| 533 |
+
unmasked_score = torch.scatter(
|
| 534 |
+
unmasked_score,
|
| 535 |
+
-1,
|
| 536 |
+
x[..., None],
|
| 537 |
+
torch.zeros_like(unmasked_score[..., :1]))
|
| 538 |
+
unmasked_score[:, :, self.mask_index] = - (
|
| 539 |
+
log_k[:, None] * torch.ones_like(x))
|
| 540 |
+
|
| 541 |
+
masked_indices = (x == self.mask_index).to(
|
| 542 |
+
model_output.dtype)[:, :, None]
|
| 543 |
+
model_output = (
|
| 544 |
+
masked_score * masked_indices
|
| 545 |
+
+ unmasked_score * (1 - masked_indices))
|
| 546 |
+
return model_output.exp()
|
| 547 |
+
|
| 548 |
+
def _staggered_score(self, score, dsigma):
|
| 549 |
+
score = score.clone()
|
| 550 |
+
extra_const = (1 - dsigma.exp()) * score.sum(dim=-1)
|
| 551 |
+
score *= dsigma.exp()[:, None]
|
| 552 |
+
score[..., self.mask_index] += extra_const
|
| 553 |
+
return score
|
| 554 |
+
|
| 555 |
+
def _analytic_update(self, x, t, step_size):
|
| 556 |
+
curr_sigma, _ = self.noise(t)
|
| 557 |
+
next_sigma, _ = self.noise(t - step_size)
|
| 558 |
+
dsigma = curr_sigma - next_sigma
|
| 559 |
+
score = self.get_score(x, curr_sigma)
|
| 560 |
+
stag_score = self._staggered_score(score, dsigma)
|
| 561 |
+
probs = stag_score * self._transp_transition(x, dsigma)
|
| 562 |
+
return _sample_categorical(probs)
|
| 563 |
+
|
| 564 |
+
def _denoiser_update(self, x, t):
|
| 565 |
+
sigma, _ = self.noise(t)
|
| 566 |
+
score = self.get_score(x, sigma)
|
| 567 |
+
stag_score = self._staggered_score(score, sigma)
|
| 568 |
+
probs = stag_score * self._transp_transition(x, sigma)
|
| 569 |
+
probs[..., self.mask_index] = 0
|
| 570 |
+
samples = _sample_categorical(probs)
|
| 571 |
+
return samples
|
| 572 |
+
|
| 573 |
+
def _transp_transition(self, i, sigma):
|
| 574 |
+
sigma = _unsqueeze(sigma, reference=i[..., None])
|
| 575 |
+
edge = torch.exp(-sigma) * F.one_hot(
|
| 576 |
+
i, num_classes=self.vocab_size)
|
| 577 |
+
edge += torch.where(i == self.mask_index,
|
| 578 |
+
1 - torch.exp(-sigma).squeeze(-1),
|
| 579 |
+
0)[..., None]
|
| 580 |
+
return edge
|
| 581 |
+
|
| 582 |
+
def _sample_t(self, n, device):
|
| 583 |
+
_eps_t = torch.rand(n, device=device)
|
| 584 |
+
if self.antithetic_sampling:
|
| 585 |
+
# for variance reduction
|
| 586 |
+
offset = torch.arange(n, device=device) / n
|
| 587 |
+
_eps_t = (_eps_t / n + offset) % 1
|
| 588 |
+
t = (1 - self.sampling_eps) * _eps_t + self.sampling_eps
|
| 589 |
+
if self.importance_sampling:
|
| 590 |
+
return self.noise.importance_sampling_transformation(t)
|
| 591 |
+
return t
|
| 592 |
+
|
| 593 |
+
def _maybe_sub_sample(self, x0, attention_mask):
|
| 594 |
+
seqlen = x0.shape[1]
|
| 595 |
+
if seqlen > self.config.model.length:
|
| 596 |
+
raise NotImplementedError('Sub-sampling not implemented')
|
| 597 |
+
elif self.parameterization == 'ar':
|
| 598 |
+
input_tokens = x0[:, :-1]
|
| 599 |
+
output_tokens = x0[:, 1:]
|
| 600 |
+
new_attention_mask = attention_mask[:, 1:]
|
| 601 |
+
else:
|
| 602 |
+
input_tokens = x0
|
| 603 |
+
output_tokens = None
|
| 604 |
+
new_attention_mask = attention_mask
|
| 605 |
+
return input_tokens, output_tokens, new_attention_mask
|
| 606 |
+
|
| 607 |
+
def _reconstruction_loss(self, x0):
|
| 608 |
+
t0 = torch.zeros(x0.shape[0], dtype=self.dtype,
|
| 609 |
+
device=self.device)
|
| 610 |
+
assert self.config.noise.type == 'loglinear'
|
| 611 |
+
# The above assert is for d3pm parameterization
|
| 612 |
+
unet_conditioning = self.noise(t0)[0][:, None]
|
| 613 |
+
model_output_t0 = self.forward(x0, unet_conditioning)
|
| 614 |
+
return - torch.gather(input=model_output_t0,
|
| 615 |
+
dim=-1,
|
| 616 |
+
index=x0[:, :, None]).squeeze(-1)
|
| 617 |
+
|
| 618 |
+
def _forward_pass_diffusion(self, x0, binary_clss=None):
|
| 619 |
+
t = self._sample_t(x0.shape[0], x0.device)
|
| 620 |
+
if self.T > 0:
|
| 621 |
+
# else ts are between 0 and 1
|
| 622 |
+
t = (t * self.T).to(torch.int)
|
| 623 |
+
t = t / self.T
|
| 624 |
+
# t \in {1/T, 2/T, ..., 1}
|
| 625 |
+
t += (1 / self.T)
|
| 626 |
+
|
| 627 |
+
if self.change_of_variables: # False
|
| 628 |
+
unet_conditioning = t[:, None]
|
| 629 |
+
f_T = torch.log1p(- torch.exp(- self.noise.sigma_max))
|
| 630 |
+
f_0 = torch.log1p(- torch.exp(- self.noise.sigma_min))
|
| 631 |
+
move_chance = torch.exp(f_0 + t * (f_T - f_0))
|
| 632 |
+
move_chance = move_chance[:, None]
|
| 633 |
+
else:
|
| 634 |
+
sigma, dsigma = self.noise(t) # total noise, rate noise
|
| 635 |
+
unet_conditioning = sigma[:, None]
|
| 636 |
+
move_chance = 1 - torch.exp(-sigma[:, None])
|
| 637 |
+
|
| 638 |
+
xt = self.q_xt(x0, move_chance) # q(xt|x0)
|
| 639 |
+
model_output = self.forward(xt, unet_conditioning, binary_clss=binary_clss)
|
| 640 |
+
utils.print_nans(model_output, 'model_output')
|
| 641 |
+
|
| 642 |
+
if self.parameterization == 'sedd':
|
| 643 |
+
return dsigma[:, None] * self._score_entropy(
|
| 644 |
+
model_output, sigma[:, None], xt, x0)
|
| 645 |
+
|
| 646 |
+
if self.T > 0:
|
| 647 |
+
diffusion_loss = self._d3pm_loss(
|
| 648 |
+
model_output=model_output, xt=xt, x0=x0, t=t)
|
| 649 |
+
if self.parameterization == 'd3pm':
|
| 650 |
+
reconstruction_loss = self._reconstruction_loss(x0)
|
| 651 |
+
elif self.parameterization == 'subs':
|
| 652 |
+
reconstruction_loss = 0
|
| 653 |
+
return reconstruction_loss + diffusion_loss
|
| 654 |
+
|
| 655 |
+
# SUBS parameterization, continuous time.
|
| 656 |
+
log_p_theta = torch.gather(
|
| 657 |
+
input=model_output,
|
| 658 |
+
dim=-1,
|
| 659 |
+
index=x0[:, :, None]).squeeze(-1)
|
| 660 |
+
|
| 661 |
+
if self.change_of_variables or self.importance_sampling:
|
| 662 |
+
return log_p_theta * torch.log1p(
|
| 663 |
+
- torch.exp(- self.noise.sigma_min))
|
| 664 |
+
|
| 665 |
+
return - log_p_theta * (
|
| 666 |
+
dsigma / torch.expm1(sigma))[:, None]
|
| 667 |
+
|
| 668 |
+
def _loss(self, x0, attention_mask, binary_clss):
|
| 669 |
+
(input_tokens, output_tokens,
|
| 670 |
+
attention_mask) = self._maybe_sub_sample(
|
| 671 |
+
x0, attention_mask)
|
| 672 |
+
|
| 673 |
+
if self.parameterization == 'ar':
|
| 674 |
+
logprobs = self.backbone(input_tokens, None, cls=binary_clss)
|
| 675 |
+
loss = - logprobs.gather(
|
| 676 |
+
-1, output_tokens[:, :, None])[:, :, 0]
|
| 677 |
+
else:
|
| 678 |
+
loss = self._forward_pass_diffusion(input_tokens, binary_clss=binary_clss)
|
| 679 |
+
|
| 680 |
+
nlls = loss * attention_mask
|
| 681 |
+
count = attention_mask.sum()
|
| 682 |
+
|
| 683 |
+
batch_nll = nlls.sum()
|
| 684 |
+
token_nll = batch_nll / count
|
| 685 |
+
|
| 686 |
+
return Loss(loss=token_nll,
|
| 687 |
+
nlls=nlls,
|
| 688 |
+
token_mask=attention_mask)
|
| 689 |
+
|
| 690 |
+
def _score_entropy(self, log_score, sigma, xt, x0):
|
| 691 |
+
"""Computes the SEDD loss.
|
| 692 |
+
|
| 693 |
+
Args:
|
| 694 |
+
log_score: float torch.Tensor with shape (batch_size,
|
| 695 |
+
diffusion_model_input_length, vocab_size),
|
| 696 |
+
log score, output of the denoising network.
|
| 697 |
+
xt: int torch.Tensor with shape (batch_size,
|
| 698 |
+
diffusion_model_input_length), input.
|
| 699 |
+
x0: int torch.Tensor with shape (batch_size,
|
| 700 |
+
diffusion_model_input_length), input.
|
| 701 |
+
sigma: float torch.Tensor with shape (batch_size, 1).
|
| 702 |
+
|
| 703 |
+
Returns:
|
| 704 |
+
loss with shape (batch_size, diffusion_model_input_length)
|
| 705 |
+
"""
|
| 706 |
+
# seems that it takes y=x0,xt=M case
|
| 707 |
+
# what is the const term for, seems to be y=M,xt=x0 case and x0 is known so score estimation is precise
|
| 708 |
+
masked_indices = xt == self.mask_index
|
| 709 |
+
|
| 710 |
+
expsig_minus_1 = torch.expm1(sigma).expand_as(xt)
|
| 711 |
+
q_ratio = 1 / expsig_minus_1[masked_indices]
|
| 712 |
+
|
| 713 |
+
words_that_were_masked = x0[masked_indices]
|
| 714 |
+
|
| 715 |
+
neg_term = q_ratio * torch.gather(
|
| 716 |
+
log_score[masked_indices],
|
| 717 |
+
-1,
|
| 718 |
+
words_that_were_masked[..., None]).squeeze(-1)
|
| 719 |
+
score = log_score[masked_indices].exp()
|
| 720 |
+
if self.mask_index == self.vocab_size - 1:
|
| 721 |
+
pos_term = score[:, :-1].sum(dim=-1)
|
| 722 |
+
else:
|
| 723 |
+
pos_term = score[:, : self.mask_index].sum(
|
| 724 |
+
dim=-1) + score[:, self.mask_index + 1:].sum(dim=-1)
|
| 725 |
+
const = q_ratio * (q_ratio.log() - 1)
|
| 726 |
+
|
| 727 |
+
entropy = torch.zeros(* xt.shape, device=xt.device)
|
| 728 |
+
entropy[masked_indices] += pos_term - neg_term + const
|
| 729 |
+
return entropy
|
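The `_ddpm_update` step above mixes conditional and unconditional denoiser outputs with guidance weight `w`, then resamples only the positions that are still masked. Below is a minimal standalone sketch of that guided reverse step, reusing the same Gumbel-max trick as `_sample_categorical`; the function name `guided_reverse_step` and the tensor shapes are illustrative only, not part of the repository.

```python
import torch

def sample_categorical(probs):
    # Gumbel-max trick, mirroring _sample_categorical above: dividing the
    # probabilities by Exp(1) noise and taking the argmax draws one index
    # per position from the categorical distribution.
    gumbel_norm = 1e-10 - (torch.rand_like(probs) + 1e-10).log()
    return (probs / gumbel_norm).argmax(dim=-1)

def guided_reverse_step(log_p_cond, log_p_uncond, x, move_chance_t, move_chance_s,
                        mask_index, w):
    """One classifier-free-guided reverse step on masked tokens.

    log_p_cond / log_p_uncond: [B, L, V] log-probabilities from the denoiser.
    x: [B, L] current tokens; move_chance_t / move_chance_s: [B, 1, 1].
    """
    # sharpen the conditional prediction by guidance weight w
    log_p_x0 = (1 + w) * log_p_cond - w * log_p_uncond
    # unmasking probabilities for this step; the mass left on the mask token
    # keeps a position masked until a later step
    q_xs = log_p_x0.exp() * (move_chance_t - move_chance_s)
    q_xs[:, :, mask_index] = move_chance_s[:, :, 0]
    x_new = sample_categorical(q_xs)
    # positions that are already unmasked are carried over unchanged
    copy_flag = (x != mask_index).to(x.dtype)
    return copy_flag * x + (1 - copy_flag) * x_new
```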
tr2d2-dna/env.sh
ADDED
|
@@ -0,0 +1,20 @@
|
| 1 |
+
conda install pytorch==2.3.1 torchvision==0.18.1 torchaudio==2.3.1 pytorch-cuda=12.1 -c pytorch -c nvidia
|
| 2 |
+
pip install packaging
|
| 3 |
+
pip install ninja
|
| 4 |
+
pip install transformers
|
| 5 |
+
pip install datasets
|
| 6 |
+
pip install omegaconf
|
| 7 |
+
conda install ipykernel
|
| 8 |
+
python -m ipykernel install --user --name tr2d2 --display-name "Python (tr2d2)"
|
| 9 |
+
pip install hydra-core --upgrade
|
| 10 |
+
pip install hydra-submitit-launcher
|
| 11 |
+
|
| 12 |
+
# for mdlm
|
| 13 |
+
pip install causal-conv1d
|
| 14 |
+
pip install lightning
|
| 15 |
+
pip install timm
|
| 16 |
+
pip install rich
|
| 17 |
+
|
| 18 |
+
pip install scipy
|
| 19 |
+
pip install wandb
|
| 20 |
+
pip install gReLU
|
tr2d2-dna/eval_runs_batch.py
ADDED
|
@@ -0,0 +1,347 @@
|
| 1 |
+
#!/usr/bin/env python3
|
| 2 |
+
"""
|
| 3 |
+
Batch evaluation script for multiple runs with checkpoints.
|
| 4 |
+
|
| 5 |
+
This script:
|
| 6 |
+
1. Scans a folder containing different runs
|
| 7 |
+
2. For each run, finds checkpoints and selects the one with largest epoch number
|
| 8 |
+
3. Evaluates that checkpoint and saves results indexed by run folder name
|
| 9 |
+
"""
|
| 10 |
+
|
| 11 |
+
import os
|
| 12 |
+
import re
|
| 13 |
+
import glob
|
| 14 |
+
import argparse
|
| 15 |
+
from pathlib import Path
|
| 16 |
+
from diffusion import Diffusion
|
| 17 |
+
import dataloader_gosai
|
| 18 |
+
import matplotlib.pyplot as plt
|
| 19 |
+
import numpy as np
|
| 20 |
+
import pandas as pd
|
| 21 |
+
import oracle
|
| 22 |
+
from scipy.stats import pearsonr
|
| 23 |
+
import torch
|
| 24 |
+
from tqdm import tqdm
|
| 25 |
+
from eval_utils import get_eval_matrics
|
| 26 |
+
from hydra import initialize, compose
|
| 27 |
+
from hydra.core.global_hydra import GlobalHydra
|
| 28 |
+
from dataclasses import dataclass
|
| 29 |
+
from datetime import datetime
|
| 30 |
+
import json
|
| 31 |
+
|
| 32 |
+
|
| 33 |
+
@dataclass
|
| 34 |
+
class Args:
|
| 35 |
+
total_num_steps: int
|
| 36 |
+
batch_size: int
|
| 37 |
+
num_seeds: int
|
| 38 |
+
total_samples: int
|
| 39 |
+
seq_length: int
|
| 40 |
+
|
| 41 |
+
|
| 42 |
+
def find_latest_checkpoint(run_dir):
|
| 43 |
+
"""
|
| 44 |
+
Find the checkpoint with the largest epoch/step number in a run directory.
|
| 45 |
+
|
| 46 |
+
Args:
|
| 47 |
+
run_dir (str): Path to the run directory
|
| 48 |
+
|
| 49 |
+
Returns:
|
| 50 |
+
str or None: Path to the latest checkpoint, or None if no checkpoints found
|
| 51 |
+
"""
|
| 52 |
+
ckpt_pattern = os.path.join(run_dir, "model_*.ckpt")
|
| 53 |
+
ckpt_files = glob.glob(ckpt_pattern)
|
| 54 |
+
|
| 55 |
+
if not ckpt_files:
|
| 56 |
+
return None
|
| 57 |
+
|
| 58 |
+
# Extract step numbers from checkpoint filenames
|
| 59 |
+
step_numbers = []
|
| 60 |
+
for ckpt_file in ckpt_files:
|
| 61 |
+
filename = os.path.basename(ckpt_file)
|
| 62 |
+
match = re.search(r'model_(\d+)\.ckpt', filename)
|
| 63 |
+
if match:
|
| 64 |
+
step_numbers.append((int(match.group(1)), ckpt_file))
|
| 65 |
+
|
| 66 |
+
if not step_numbers:
|
| 67 |
+
return None
|
| 68 |
+
|
| 69 |
+
# Return checkpoint with largest step number
|
| 70 |
+
step_numbers.sort(key=lambda x: x[0], reverse=True)
|
| 71 |
+
return step_numbers[0][1]
|
| 72 |
+
|
| 73 |
+
|
| 74 |
+
def evaluate_checkpoint(checkpoint_path, args, cfg, pretrained_model, gosai_oracle,
|
| 75 |
+
cal_atac_pred_new_mdl, highexp_kmers_999, n_highexp_kmers_999, device):
|
| 76 |
+
"""
|
| 77 |
+
Evaluate a single checkpoint.
|
| 78 |
+
|
| 79 |
+
Args:
|
| 80 |
+
checkpoint_path (str): Path to the checkpoint file
|
| 81 |
+
args: Evaluation arguments
|
| 82 |
+
cfg: Configuration object
|
| 83 |
+
pretrained_model: Pretrained reference model
|
| 84 |
+
gosai_oracle: GOSAI oracle model
|
| 85 |
+
cal_atac_pred_new_mdl: ATAC prediction model
|
| 86 |
+
highexp_kmers_999: High expression k-mers
|
| 87 |
+
n_highexp_kmers_999: Number of high expression k-mers
|
| 88 |
+
device: Device to run evaluation on
|
| 89 |
+
|
| 90 |
+
Returns:
|
| 91 |
+
tuple: (eval_metrics_agg, total_rewards_agg) containing aggregated results
|
| 92 |
+
"""
|
| 93 |
+
# Load the policy model from checkpoint
|
| 94 |
+
policy_model = Diffusion(cfg).to(device)
|
| 95 |
+
policy_model.load_state_dict(torch.load(checkpoint_path, map_location=device))
|
| 96 |
+
policy_model.eval()
|
| 97 |
+
|
| 98 |
+
total_rewards_all = []
|
| 99 |
+
eval_metrics_all = []
|
| 100 |
+
|
| 101 |
+
print(f"Evaluating checkpoint: {os.path.basename(checkpoint_path)}")
|
| 102 |
+
|
| 103 |
+
for i in range(args.num_seeds):
|
| 104 |
+
iter_times = args.total_samples // args.batch_size
|
| 105 |
+
total_samples = []
|
| 106 |
+
total_rewards = []
|
| 107 |
+
range_bar = tqdm(range(iter_times), desc=f"Seed {i+1}", leave=False)
|
| 108 |
+
|
| 109 |
+
for j in range_bar:
|
| 110 |
+
x_eval, mean_reward_eval = policy_model.sample_finetuned(args, gosai_oracle)
|
| 111 |
+
total_samples.append(x_eval)
|
| 112 |
+
total_rewards.append(mean_reward_eval.item() * args.batch_size)
|
| 113 |
+
|
| 114 |
+
total_samples = torch.concat(total_samples)
|
| 115 |
+
eval_metrics = get_eval_matrics(samples=total_samples, ref_model=pretrained_model,
|
| 116 |
+
gosai_oracle=gosai_oracle, cal_atac_pred_new_mdl=cal_atac_pred_new_mdl,
|
| 117 |
+
highexp_kmers_999=highexp_kmers_999, n_highexp_kmers_999=n_highexp_kmers_999)
|
| 118 |
+
|
| 119 |
+
eval_metrics_all.append(eval_metrics)
|
| 120 |
+
total_rewards_all.append(np.sum(total_rewards) / args.total_samples)
|
| 121 |
+
|
| 122 |
+
# Aggregate results
|
| 123 |
+
eval_metrics_agg = {k: (np.mean([eval_metrics[k] for eval_metrics in eval_metrics_all]),
|
| 124 |
+
np.std([eval_metrics[k] for eval_metrics in eval_metrics_all]))
|
| 125 |
+
for k in eval_metrics_all[0].keys()}
|
| 126 |
+
total_rewards_agg = (np.mean(total_rewards_all), np.std(total_rewards_all))
|
| 127 |
+
|
| 128 |
+
return eval_metrics_agg, total_rewards_agg
|
| 129 |
+
|
| 130 |
+
|
| 131 |
+
def save_results(results, output_file):
|
| 132 |
+
"""
|
| 133 |
+
Save evaluation results to a text file.
|
| 134 |
+
|
| 135 |
+
Args:
|
| 136 |
+
results (dict): Dictionary containing results for each run
|
| 137 |
+
output_file (str): Path to output file
|
| 138 |
+
"""
|
| 139 |
+
with open(output_file, 'w') as f:
|
| 140 |
+
f.write("="*80 + "\n")
|
| 141 |
+
f.write("BATCH EVALUATION RESULTS\n")
|
| 142 |
+
f.write("="*80 + "\n")
|
| 143 |
+
f.write(f"Generated on: {datetime.now().strftime('%Y-%m-%d %H:%M:%S')}\n")
|
| 144 |
+
f.write(f"Total runs evaluated: {len(results)}\n\n")
|
| 145 |
+
|
| 146 |
+
for run_name, result in results.items():
|
| 147 |
+
if result is None:
|
| 148 |
+
f.write(f"RUN: {run_name}\n")
|
| 149 |
+
f.write("-" * 60 + "\n")
|
| 150 |
+
f.write("Status: No checkpoints found or evaluation failed\n\n")
|
| 151 |
+
continue
|
| 152 |
+
|
| 153 |
+
eval_metrics_agg, total_rewards_agg, checkpoint_path = result
|
| 154 |
+
|
| 155 |
+
f.write(f"RUN: {run_name}\n")
|
| 156 |
+
f.write("-" * 60 + "\n")
|
| 157 |
+
f.write(f"Checkpoint: {os.path.basename(checkpoint_path)}\n")
|
| 158 |
+
f.write(f"Full path: {checkpoint_path}\n\n")
|
| 159 |
+
|
| 160 |
+
f.write("📊 EVALUATION METRICS:\n")
|
| 161 |
+
for metric_name in eval_metrics_agg.keys():
|
| 162 |
+
mean_val = eval_metrics_agg[metric_name][0]
|
| 163 |
+
std_val = eval_metrics_agg[metric_name][1]
|
| 164 |
+
f.write(f" {metric_name:<20}: {mean_val:8.4f} ± {std_val:6.4f}\n")
|
| 165 |
+
|
| 166 |
+
f.write(f"\n🎯 TOTAL REWARDS:\n")
|
| 167 |
+
f.write(f" {'Mean':<20}: {total_rewards_agg[0]:8.4f}\n")
|
| 168 |
+
f.write(f" {'Std':<20}: {total_rewards_agg[1]:8.4f}\n")
|
| 169 |
+
f.write("\n")
|
| 170 |
+
|
| 171 |
+
print(f"Results saved to: {output_file}")
|
| 172 |
+
|
| 173 |
+
|
| 174 |
+
def append_single_result(run_name, result, output_file, is_first_run=False):
|
| 175 |
+
"""
|
| 176 |
+
Append a single successful run result to the output file.
|
| 177 |
+
|
| 178 |
+
Args:
|
| 179 |
+
run_name (str): Name of the run
|
| 180 |
+
result: Result tuple (eval_metrics_agg, total_rewards_agg, checkpoint_path)
|
| 181 |
+
output_file (str): Path to output file
|
| 182 |
+
is_first_run (bool): Whether this is the first successful run (write header)
|
| 183 |
+
"""
|
| 184 |
+
mode = 'w' if is_first_run else 'a'
|
| 185 |
+
|
| 186 |
+
with open(output_file, mode) as f:
|
| 187 |
+
if is_first_run:
|
| 188 |
+
f.write("="*80 + "\n")
|
| 189 |
+
f.write("BATCH EVALUATION RESULTS\n")
|
| 190 |
+
f.write("="*80 + "\n")
|
| 191 |
+
f.write(f"Started on: {datetime.now().strftime('%Y-%m-%d %H:%M:%S')}\n")
|
| 192 |
+
f.write("Results are saved incrementally as each run completes.\n")
|
| 193 |
+
f.write("Only successful evaluations are included.\n\n")
|
| 194 |
+
|
| 195 |
+
eval_metrics_agg, total_rewards_agg, checkpoint_path = result
|
| 196 |
+
|
| 197 |
+
f.write(f"RUN: {run_name}\n")
|
| 198 |
+
f.write("-" * 60 + "\n")
|
| 199 |
+
f.write(f"Checkpoint: {os.path.basename(checkpoint_path)}\n")
|
| 200 |
+
f.write(f"Full path: {checkpoint_path}\n")
|
| 201 |
+
f.write(f"Completed at: {datetime.now().strftime('%Y-%m-%d %H:%M:%S')}\n\n")
|
| 202 |
+
|
| 203 |
+
f.write("📊 EVALUATION METRICS:\n")
|
| 204 |
+
for metric_name in eval_metrics_agg.keys():
|
| 205 |
+
mean_val = eval_metrics_agg[metric_name][0]
|
| 206 |
+
std_val = eval_metrics_agg[metric_name][1]
|
| 207 |
+
f.write(f" {metric_name:<20}: {mean_val:8.4f} ± {std_val:6.4f}\n")
|
| 208 |
+
|
| 209 |
+
f.write(f"\n🎯 TOTAL REWARDS:\n")
|
| 210 |
+
f.write(f" {'Mean':<20}: {total_rewards_agg[0]:8.4f}\n")
|
| 211 |
+
f.write(f" {'Std':<20}: {total_rewards_agg[1]:8.4f}\n")
|
| 212 |
+
f.write("\n" + "="*80 + "\n\n") # Add separator line and extra spacing
|
| 213 |
+
|
| 214 |
+
|
| 215 |
+
def main():
|
| 216 |
+
parser = argparse.ArgumentParser(description="Batch evaluation of multiple runs")
|
| 217 |
+
parser.add_argument("--runs_dir", type=str, required=True,
|
| 218 |
+
help="Directory containing run folders with checkpoints")
|
| 219 |
+
parser.add_argument("--output_file", type=str, default="batch_eval_results.txt",
|
| 220 |
+
help="Output file to save results")
|
| 221 |
+
parser.add_argument("--device", type=str, default="cuda:0",
|
| 222 |
+
help="Device to run evaluation on")
|
| 223 |
+
parser.add_argument("--total_num_steps", type=int, default=128,
|
| 224 |
+
help="Total number of diffusion steps")
|
| 225 |
+
parser.add_argument("--batch_size", type=int, default=128,
|
| 226 |
+
help="Batch size for evaluation")
|
| 227 |
+
parser.add_argument("--num_seeds", type=int, default=3,
|
| 228 |
+
help="Number of random seeds for evaluation")
|
| 229 |
+
parser.add_argument("--total_samples", type=int, default=640,
|
| 230 |
+
help="Total number of samples to generate")
|
| 231 |
+
parser.add_argument("--seq_length", type=int, default=200,
|
| 232 |
+
help="Sequence length")
|
| 233 |
+
parser.add_argument("--pretrained_path", type=str,
|
| 234 |
+
default=None,
|
| 235 |
+
help="Path to pretrained model checkpoint")
|
| 236 |
+
|
| 237 |
+
args = parser.parse_args()
|
| 238 |
+
|
| 239 |
+
# Setup evaluation arguments
|
| 240 |
+
eval_args = Args(
|
| 241 |
+
total_num_steps=args.total_num_steps,
|
| 242 |
+
batch_size=args.batch_size,
|
| 243 |
+
num_seeds=args.num_seeds,
|
| 244 |
+
total_samples=args.total_samples,
|
| 245 |
+
seq_length=args.seq_length
|
| 246 |
+
)
|
| 247 |
+
|
| 248 |
+
device = args.device
|
| 249 |
+
|
| 250 |
+
# Initialize Hydra configuration
|
| 251 |
+
if GlobalHydra().is_initialized():
|
| 252 |
+
GlobalHydra.instance().clear()
|
| 253 |
+
|
| 254 |
+
initialize(config_path="configs_gosai", job_name="batch_eval")
|
| 255 |
+
cfg = compose(config_name="config_gosai.yaml")
|
| 256 |
+
|
| 257 |
+
print("Loading pretrained model and oracles...")
|
| 258 |
+
# Load pretrained model
|
| 259 |
+
pretrained_model = Diffusion.load_from_checkpoint(args.pretrained_path, config=cfg, map_location=device)
|
| 260 |
+
pretrained_model.eval()
|
| 261 |
+
|
| 262 |
+
# Load oracles
|
| 263 |
+
_, _, highexp_kmers_999, n_highexp_kmers_999, _, _, _ = oracle.cal_highexp_kmers(return_clss=True)
|
| 264 |
+
cal_atac_pred_new_mdl = oracle.get_cal_atac_orale(device=device)
|
| 265 |
+
cal_atac_pred_new_mdl.eval()
|
| 266 |
+
gosai_oracle = oracle.get_gosai_oracle(mode='eval', device=device)
|
| 267 |
+
gosai_oracle.eval()
|
| 268 |
+
|
| 269 |
+
print("Scanning for runs...")
|
| 270 |
+
# Find all run directories
|
| 271 |
+
runs_dir = Path(args.runs_dir)
|
| 272 |
+
if not runs_dir.exists():
|
| 273 |
+
print(f"Error: Directory {args.runs_dir} does not exist")
|
| 274 |
+
return
|
| 275 |
+
|
| 276 |
+
run_dirs = [d for d in runs_dir.iterdir() if d.is_dir()]
|
| 277 |
+
run_dirs.sort() # Sort for consistent ordering
|
| 278 |
+
|
| 279 |
+
print(f"Found {len(run_dirs)} run directories")
|
| 280 |
+
|
| 281 |
+
results = {}
|
| 282 |
+
successful_runs = 0
|
| 283 |
+
failed_runs = 0
|
| 284 |
+
|
| 285 |
+
# Process each run
|
| 286 |
+
for i, run_dir in enumerate(tqdm(run_dirs, desc="Processing runs")):
|
| 287 |
+
run_name = run_dir.name
|
| 288 |
+
print(f"\nProcessing run {i+1}/{len(run_dirs)}: {run_name}")
|
| 289 |
+
|
| 290 |
+
# Find latest checkpoint
|
| 291 |
+
latest_ckpt = find_latest_checkpoint(str(run_dir))
|
| 292 |
+
|
| 293 |
+
if latest_ckpt is None:
|
| 294 |
+
print(f" No checkpoints found in {run_name} - skipping")
|
| 295 |
+
failed_runs += 1
|
| 296 |
+
continue # Skip this run entirely, don't save anything to file
|
| 297 |
+
|
| 298 |
+
print(f" Found latest checkpoint: {os.path.basename(latest_ckpt)}")
|
| 299 |
+
|
| 300 |
+
try:
|
| 301 |
+
# Evaluate checkpoint
|
| 302 |
+
eval_metrics_agg, total_rewards_agg = evaluate_checkpoint(
|
| 303 |
+
latest_ckpt, eval_args, cfg, pretrained_model, gosai_oracle,
|
| 304 |
+
cal_atac_pred_new_mdl, highexp_kmers_999, n_highexp_kmers_999, device
|
| 305 |
+
)
|
| 306 |
+
|
| 307 |
+
result = (eval_metrics_agg, total_rewards_agg, latest_ckpt)
|
| 308 |
+
results[run_name] = result
|
| 309 |
+
successful_runs += 1
|
| 310 |
+
print(f" ✓ Evaluation completed successfully")
|
| 311 |
+
|
| 312 |
+
# Save result incrementally (only for successful evaluations)
|
| 313 |
+
is_first_run = (len(results) == 1) # First successful run
|
| 314 |
+
append_single_result(run_name, result, args.output_file, is_first_run=is_first_run)
|
| 315 |
+
print(f" Result saved to {args.output_file}")
|
| 316 |
+
|
| 317 |
+
except Exception as e:
|
| 318 |
+
print(f" ✗ Evaluation failed: {str(e)}")
|
| 319 |
+
failed_runs += 1
|
| 320 |
+
# Don't save failed evaluations to file either
|
| 321 |
+
|
| 322 |
+
# Add final summary to the file (only if there were successful runs)
|
| 323 |
+
if successful_runs > 0:
|
| 324 |
+
with open(args.output_file, 'a') as f:
|
| 325 |
+
f.write("="*80 + "\n")
|
| 326 |
+
f.write("FINAL SUMMARY\n")
|
| 327 |
+
f.write("="*80 + "\n")
|
| 328 |
+
f.write(f"Completed on: {datetime.now().strftime('%Y-%m-%d %H:%M:%S')}\n")
|
| 329 |
+
f.write(f"Total runs processed: {len(run_dirs)}\n")
|
| 330 |
+
f.write(f"Successful evaluations: {successful_runs}\n")
|
| 331 |
+
f.write(f"Failed/skipped runs: {failed_runs}\n")
|
| 332 |
+
else:
|
| 333 |
+
print(f"No successful evaluations - output file {args.output_file} not created")
|
| 334 |
+
|
| 335 |
+
# Print summary
|
| 336 |
+
print(f"\nFinal Summary:")
|
| 337 |
+
print(f" Total runs processed: {len(run_dirs)}")
|
| 338 |
+
print(f" Successful evaluations: {successful_runs}")
|
| 339 |
+
print(f" Failed/skipped runs: {failed_runs}")
|
| 340 |
+
if successful_runs > 0:
|
| 341 |
+
print(f" Results saved to: {args.output_file}")
|
| 342 |
+
else:
|
| 343 |
+
print(f" No output file created (no successful evaluations)")
|
| 344 |
+
|
| 345 |
+
|
| 346 |
+
if __name__ == "__main__":
|
| 347 |
+
main()
|
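`find_latest_checkpoint` above keeps only filenames matching `model_<step>.ckpt` and returns the one with the largest step. A small usage sketch follows; the run folder and checkpoint names are hypothetical, and importing `eval_runs_batch` assumes its data and oracle paths are configured.

```python
# Usage sketch only: the run folder and checkpoint filenames below are made up.
from eval_runs_batch import find_latest_checkpoint

# For a folder containing model_100.ckpt, model_500.ckpt and model_900.ckpt,
# the 900-step checkpoint is returned; None is returned when nothing matches.
latest = find_latest_checkpoint("runs/MCTS_example_run")
print(latest)  # e.g. runs/MCTS_example_run/model_900.ckpt
```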
tr2d2-dna/eval_utils.py
ADDED
|
@@ -0,0 +1,29 @@
|
| 1 |
+
import numpy as np
|
| 2 |
+
import torch
|
| 3 |
+
from scipy.stats import pearsonr
|
| 4 |
+
import dataloader_gosai
|
| 5 |
+
import oracle
|
| 6 |
+
|
| 7 |
+
|
| 8 |
+
def compare_kmer(kmer1, kmer2, n_sp1, n_sp2):
|
| 9 |
+
kmer_set = set(kmer1.keys()) | set(kmer2.keys())
|
| 10 |
+
counts = np.zeros((len(kmer_set), 2))
|
| 11 |
+
for i, kmer in enumerate(kmer_set):
|
| 12 |
+
if kmer in kmer1: counts[i][1] = kmer1[kmer] * n_sp2 / n_sp1
|
| 13 |
+
if kmer in kmer2: counts[i][0] = kmer2[kmer]
|
| 14 |
+
return pearsonr(counts[:, 0], counts[:, 1])[0]
|
| 15 |
+
|
| 16 |
+
|
| 17 |
+
def get_eval_matrics(samples, ref_model, gosai_oracle, cal_atac_pred_new_mdl, highexp_kmers_999, n_highexp_kmers_999):
|
| 18 |
+
"""samples: [B, 200]"""
|
| 19 |
+
info = {}
|
| 20 |
+
detokenized_samples = dataloader_gosai.batch_dna_detokenize(samples.detach().cpu().numpy()) # [B], strings with length 200
|
| 21 |
+
ref_log_lik = ref_model.get_likelihood(samples, num_steps=128, n_samples=1) # [B]
|
| 22 |
+
info['[log-lik-med]'] = torch.median(ref_log_lik).item()
|
| 23 |
+
preds = oracle.cal_gosai_pred_new(detokenized_samples, gosai_oracle, mode='eval')[:, 0]
|
| 24 |
+
info['[pred-activity-med]'] = np.median(preds).item()
|
| 25 |
+
atac = oracle.cal_atac_pred_new(detokenized_samples, cal_atac_pred_new_mdl)[:, 1]
|
| 26 |
+
info['[atac-acc%]'] = (atac > 0.5).sum().item() / len(samples) * 100
|
| 27 |
+
kmer = oracle.count_kmers(detokenized_samples)
|
| 28 |
+
info['[3-mer-corr]'] = compare_kmer(highexp_kmers_999, kmer, n_highexp_kmers_999, len(detokenized_samples)).item()
|
| 29 |
+
return info
|
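`compare_kmer` above aligns two k-mer count dictionaries on the union of their keys, rescales the reference counts by the sample-size ratio `n_sp2 / n_sp1`, and reports the Pearson correlation between the two count vectors. A tiny worked example follows; the counts are invented for illustration, and importing `eval_utils` assumes its data paths are configured.

```python
# Worked example for compare_kmer; all counts here are made up.
from eval_utils import compare_kmer

highexp_kmers = {"AAA": 40, "CCC": 20, "GGG": 10}  # counted over n_sp1 = 1000 reference sequences
sample_kmers = {"AAA": 4, "CCC": 2, "TTT": 1}      # counted over n_sp2 = 100 generated sequences

# Reference counts are rescaled by n_sp2 / n_sp1 before the Pearson correlation,
# so generated sequences with similar relative k-mer frequencies score close to 1.
corr = compare_kmer(highexp_kmers, sample_kmers, 1000, 100)
print(f"3-mer correlation: {corr:.3f}")
```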
tr2d2-dna/finetune.py
ADDED
|
@@ -0,0 +1,149 @@
|
| 1 |
+
# direct reward backpropagation
|
| 2 |
+
from diffusion import Diffusion
|
| 3 |
+
from hydra import initialize, compose
|
| 4 |
+
from hydra.core.global_hydra import GlobalHydra
|
| 5 |
+
import numpy as np
|
| 6 |
+
import oracle
|
| 7 |
+
from scipy.stats import pearsonr
|
| 8 |
+
import torch
|
| 9 |
+
import torch.nn.functional as F
|
| 10 |
+
import argparse
|
| 11 |
+
import wandb
|
| 12 |
+
import os
|
| 13 |
+
import datetime
|
| 14 |
+
from utils import str2bool, set_seed
|
| 15 |
+
from finetune_dna import finetune
|
| 16 |
+
from mcts import MCTS
|
| 17 |
+
|
| 18 |
+
argparser = argparse.ArgumentParser(formatter_class=argparse.ArgumentDefaultsHelpFormatter)
|
| 19 |
+
argparser.add_argument('--base_path', type=str, default="")
|
| 20 |
+
argparser.add_argument('--learning_rate', type=float, default=1e-4)
|
| 21 |
+
argparser.add_argument('--num_epochs', type=int, default=100)
|
| 22 |
+
argparser.add_argument('--num_accum_steps', type=int, default=4)
|
| 23 |
+
argparser.add_argument('--truncate_steps', type=int, default=50)
|
| 24 |
+
argparser.add_argument("--truncate_kl", type=str2bool, default=False)
|
| 25 |
+
argparser.add_argument('--gumbel_temp', type=float, default=1.0)
|
| 26 |
+
argparser.add_argument('--gradnorm_clip', type=float, default=1.0)
|
| 27 |
+
argparser.add_argument('--batch_size', type=int, default=32)
|
| 28 |
+
argparser.add_argument('--name', type=str, default='debug')
|
| 29 |
+
argparser.add_argument('--total_num_steps', type=int, default=128)
|
| 30 |
+
argparser.add_argument('--copy_flag_temp', type=float, default=None)
|
| 31 |
+
argparser.add_argument('--save_every_n_epochs', type=int, default=10)
|
| 32 |
+
argparser.add_argument('--eval_every_n_epochs', type=int, default=200)
|
| 33 |
+
argparser.add_argument('--alpha', type=float, default=0.001)
|
| 34 |
+
argparser.add_argument('--alpha_schedule_warmup', type=int, default=0)
|
| 35 |
+
argparser.add_argument("--seed", type=int, default=0)
|
| 36 |
+
# new
|
| 37 |
+
argparser.add_argument('--run_name', type=str, default='drakes')
|
| 38 |
+
argparser.add_argument("--device", default="cuda:0", type=str)
|
| 39 |
+
argparser.add_argument("--save_path_dir", default=None, type=str)
|
| 40 |
+
argparser.add_argument("--no_mcts", action='store_true', default=False)
|
| 41 |
+
argparser.add_argument("--centering", action='store_true', default=False)
|
| 42 |
+
argparser.add_argument("--reward_clip", action='store_true', default=False)
|
| 43 |
+
argparser.add_argument("--reward_clip_value", type=float, default=15.0)
|
| 44 |
+
argparser.add_argument("--select_topk", action='store_true', default=False)
|
| 45 |
+
argparser.add_argument('--select_topk_value', type=int, default=10)
|
| 46 |
+
argparser.add_argument("--restart_ckpt_path", type=str, default=None)
|
| 47 |
+
|
| 48 |
+
|
| 49 |
+
# mcts
|
| 50 |
+
argparser.add_argument('--num_sequences', type=int, default=10)
|
| 51 |
+
argparser.add_argument('--num_children', type=int, default=50)
|
| 52 |
+
argparser.add_argument('--num_iter', type=int, default=30) # iterations of mcts
|
| 53 |
+
argparser.add_argument('--seq_length', type=int, default=200)
|
| 54 |
+
argparser.add_argument('--time_conditioning', action='store_true', default=False)
|
| 55 |
+
argparser.add_argument('--mcts_sampling', type=int, default=0) # for batched categorical sampling: '0' means gumbel noise
|
| 56 |
+
argparser.add_argument('--buffer_size', type=int, default=100)
|
| 57 |
+
argparser.add_argument('--wdce_num_replicates', type=int, default=16)
|
| 58 |
+
argparser.add_argument('--noise_removal', action='store_true', default=False)
|
| 59 |
+
argparser.add_argument('--grad_clip', action='store_true', default=False)
|
| 60 |
+
argparser.add_argument('--resample_every_n_step', type=int, default=10)
|
| 61 |
+
argparser.add_argument('--exploration', type=float, default=0.1)
|
| 62 |
+
argparser.add_argument('--reset_tree', action='store_true', default=False)
|
| 63 |
+
|
| 64 |
+
# eval
|
| 65 |
+
|
| 66 |
+
args = argparser.parse_args()
|
| 67 |
+
print(args)
|
| 68 |
+
|
| 69 |
+
# pretrained model path
|
| 70 |
+
CKPT_PATH = os.path.join(args.base_path, 'mdlm/outputs_gosai/pretrained.ckpt')
|
| 71 |
+
log_base_dir = os.path.join(args.save_path_dir, 'mdlm/reward_bp_results_final')
|
| 72 |
+
|
| 73 |
+
# reinitialize Hydra
|
| 74 |
+
GlobalHydra.instance().clear()
|
| 75 |
+
|
| 76 |
+
# Initialize Hydra and compose the configuration
|
| 77 |
+
initialize(config_path="configs_gosai", job_name="load_model")
|
| 78 |
+
cfg = compose(config_name="config_gosai.yaml")
|
| 79 |
+
cfg.eval.checkpoint_path = CKPT_PATH
|
| 80 |
+
curr_time = datetime.datetime.now().strftime("%Y%m%d_%H%M%S")
|
| 81 |
+
|
| 82 |
+
if args.no_mcts:
|
| 83 |
+
run_name = f'MDNS_buffer{args.buffer_size}_alpha{args.alpha}_resample{args.resample_every_n_step}_centering{args.centering}_{curr_time}'
|
| 84 |
+
else:
|
| 85 |
+
run_name = f'MCTS_buffer{args.buffer_size}_alpha{args.alpha}_resample{args.resample_every_n_step}_num_iter{args.num_iter}_centering{args.centering}_select_topk{args.select_topk}_select_topk_value{args.select_topk_value}_{curr_time}'
|
| 86 |
+
|
| 87 |
+
args.save_path = os.path.join(args.save_path_dir, run_name)
|
| 88 |
+
os.makedirs(args.save_path, exist_ok=True)
|
| 89 |
+
# wandb init
|
| 90 |
+
wandb.init(project='search-rl', name=run_name, config=args, dir=args.save_path)
|
| 91 |
+
|
| 92 |
+
log_path = os.path.join(args.save_path, 'log.txt')
|
| 93 |
+
|
| 94 |
+
set_seed(args.seed, use_cuda=True)
|
| 95 |
+
|
| 96 |
+
# Initialize the model
|
| 97 |
+
if args.restart_ckpt_path is not None:
|
| 98 |
+
# Resume from saved ckpt
|
| 99 |
+
restart_ckpt_path = os.path.join(args.base_path, args.restart_ckpt_path)
|
| 100 |
+
restart_epoch = restart_ckpt_path.split('_')[-1].split('.')[0]
|
| 101 |
+
args.restart_epoch = restart_epoch
|
| 102 |
+
policy_model = Diffusion(cfg).to(args.device)
|
| 103 |
+
policy_model.load_state_dict(torch.load(restart_ckpt_path, map_location=args.device))
|
| 104 |
+
else:
|
| 105 |
+
# Start from pretrained model
|
| 106 |
+
policy_model = Diffusion.load_from_checkpoint(cfg.eval.checkpoint_path, config=cfg, map_location=args.device)
|
| 107 |
+
pretrained = Diffusion.load_from_checkpoint(cfg.eval.checkpoint_path, config=cfg, map_location=args.device)
|
| 108 |
+
reward_model = oracle.get_gosai_oracle(mode='train', device=args.device)
|
| 109 |
+
|
| 110 |
+
#reward_model_eval = oracle.get_gosai_oracle(mode='eval').to(args.device)
|
| 111 |
+
|
| 112 |
+
reward_model.eval()
|
| 113 |
+
pretrained.eval()
|
| 114 |
+
#reward_model_eval.eval()
|
| 115 |
+
|
| 116 |
+
# define mcts
|
| 117 |
+
mcts = MCTS(args, cfg, policy_model, pretrained, reward_model)
|
| 118 |
+
|
| 119 |
+
|
| 120 |
+
_, _, highexp_kmers_999, n_highexp_kmers_999, _, _, _ = oracle.cal_highexp_kmers(return_clss=True)
|
| 121 |
+
|
| 122 |
+
cal_atac_pred_new_mdl = oracle.get_cal_atac_orale(device=args.device)
|
| 123 |
+
cal_atac_pred_new_mdl.eval()
|
| 124 |
+
|
| 125 |
+
gosai_oracle = oracle.get_gosai_oracle(mode='eval', device=args.device)
|
| 126 |
+
gosai_oracle.eval()
|
| 127 |
+
|
| 128 |
+
|
| 129 |
+
print("args.device:", args.device)
|
| 130 |
+
print("policy_model device:", policy_model.device)
|
| 131 |
+
print("pretrained device:", pretrained.device)
|
| 132 |
+
print("reward_model device:", reward_model.device)
|
| 133 |
+
print("mcts device:", mcts.device)
|
| 134 |
+
print("gosai_oracle device:", gosai_oracle.device)
|
| 135 |
+
print("cal_atac_pred_new_mdl device:", cal_atac_pred_new_mdl.device)
|
| 136 |
+
|
| 137 |
+
eval_model_dict = {
|
| 138 |
+
"gosai_oracle": gosai_oracle,
|
| 139 |
+
"highexp_kmers_999": highexp_kmers_999,
|
| 140 |
+
"n_highexp_kmers_999": n_highexp_kmers_999,
|
| 141 |
+
"cal_atac_pred_new_mdl": cal_atac_pred_new_mdl,
|
| 142 |
+
"gosai_oracle": gosai_oracle
|
| 143 |
+
}
|
| 144 |
+
|
| 145 |
+
|
| 146 |
+
finetune(args = args, cfg = cfg, policy_model = policy_model,
|
| 147 |
+
reward_model = reward_model, mcts = mcts,
|
| 148 |
+
pretrained_model = pretrained,
|
| 149 |
+
eval_model_dict = eval_model_dict)
|
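`finetune.py` above parses boolean flags such as `--truncate_kl` through `str2bool` from `utils.py`, which is added in this commit but not shown in this excerpt. The sketch below is only an assumption of what such a helper typically looks like; the actual implementation in `tr2d2-dna/utils.py` may differ.

```python
# Illustrative assumption only: the real str2bool lives in tr2d2-dna/utils.py.
import argparse

def str2bool(v):
    # Accept common textual spellings so flags like `--truncate_kl false`
    # parse cleanly from the command line.
    if isinstance(v, bool):
        return v
    if v.lower() in ("yes", "true", "t", "y", "1"):
        return True
    if v.lower() in ("no", "false", "f", "n", "0"):
        return False
    raise argparse.ArgumentTypeError(f"boolean value expected, got {v!r}")
```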
tr2d2-dna/finetune_dna.py
ADDED
|
@@ -0,0 +1,113 @@
|
# direct reward backpropagation
from hydra import initialize, compose
from hydra.core.global_hydra import GlobalHydra
import numpy as np
import oracle
from scipy.stats import pearsonr
import torch
import torch.nn.functional as F
import argparse
import wandb
import os
import datetime
from utils import str2bool, set_seed
# imports
from finetune_utils import loss_wdce
from tqdm import tqdm

def finetune(args, cfg, policy_model, reward_model, mcts=None, pretrained_model=None, eps=1e-5):
    """
    Finetuning with the WDCE loss.
    """
    dt = (1 - eps) / args.total_num_steps

    if args.no_mcts:
        assert pretrained_model is not None, "a pretrained model is required when MCTS is disabled"
    else:
        assert mcts is not None, "an MCTS instance is required when MCTS is enabled"

    # set model to train mode
    policy_model.train()
    torch.set_grad_enabled(True)
    optim = torch.optim.AdamW(policy_model.parameters(), lr=args.learning_rate)

    # record metrics
    batch_losses = []
    batch_rewards = []

    # initialize the final seqs and log_rnd of the trajectories that generated those seqs
    x_saved, log_rnd_saved, final_rewards_saved = None, None, None

    # finetuning loop
    pbar = tqdm(range(args.num_epochs))
    for epoch in pbar:
        # store metrics
        rewards = []
        losses = []

        policy_model.train()

        with torch.no_grad():
            if x_saved is None or epoch % args.resample_every_n_step == 0:
                # compute final sequences and trajectory log_rnd
                if args.no_mcts:
                    x_final, log_rnd, final_rewards = policy_model.sample_finetuned_with_rnd(args, reward_model, pretrained_model)
                else:
                    x_final, log_rnd, final_rewards = mcts.forward(args.reset_tree)

                # save for the next iteration
                x_saved, log_rnd_saved, final_rewards_saved = x_final, log_rnd, final_rewards
            else:
                x_final, log_rnd, final_rewards = x_saved, log_rnd_saved, final_rewards_saved

        # compute WDCE loss
        loss = loss_wdce(policy_model, log_rnd, x_final, num_replicates=args.wdce_num_replicates)

        # gradient descent
        loss.backward()

        # optimizer step
        if args.grad_clip:
            torch.nn.utils.clip_grad_norm_(policy_model.parameters(), args.gradnorm_clip)
        optim.step()
        optim.zero_grad()

        pbar.set_postfix(loss=loss.item())

        losses.append(loss.item())

        # sample an eval batch with the updated policy to evaluate rewards
        x_eval, mean_reward_eval = policy_model.sample_finetuned(args, reward_model)

        batch_losses.append(loss.cpu().detach().numpy())
        batch_rewards.append(mean_reward_eval.cpu().detach().item())
        losses.append(loss.cpu().detach().numpy())

        rewards = np.array(mean_reward_eval.detach().cpu().numpy())
        losses = np.array(losses)

        mean_reward_search = final_rewards.mean().item()
        min_reward_search = final_rewards.min().item()
        max_reward_search = final_rewards.max().item()
        median_reward_search = final_rewards.median().item()

        #reward_losses = np.array(reward_losses)

        print("epoch %d" % epoch, "mean reward %f" % mean_reward_eval, "mean loss %f" % np.mean(losses))

        wandb.log({"epoch": epoch, "mean_reward": mean_reward_eval, "mean_loss": np.mean(losses),
                   "mean_reward_search": mean_reward_search, "min_reward_search": min_reward_search,
                   "max_reward_search": max_reward_search, "median_reward_search": median_reward_search})

        if (epoch + 1) % args.save_every_n_epochs == 0:
            model_path = os.path.join(args.save_path, f'model_{epoch}.ckpt')
            torch.save(policy_model.state_dict(), model_path)
            print(f"model saved at epoch {epoch}")

    wandb.finish()

    return batch_losses
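The loop above only reads a handful of fields from `args`. As an illustration of the knobs it depends on, here is a minimal `argparse.Namespace` that would drive a dry run; the values are placeholders chosen for the sketch, not the settings used in the paper or in `train.sh`.

```python
from argparse import Namespace

toy_args = Namespace(
    total_num_steps=128,        # reverse-diffusion steps per trajectory
    no_mcts=False,              # False -> samples come from mcts.forward()
    reset_tree=False,           # passed through to mcts.forward()
    resample_every_n_step=5,    # reuse cached (x_final, log_rnd) between resamples
    num_epochs=100,
    learning_rate=1e-4,
    wdce_num_replicates=16,     # R in loss_wdce
    grad_clip=True,
    gradnorm_clip=1.0,
    save_every_n_epochs=50,
    save_path="./runs/toy",     # placeholder output directory
)
```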
tr2d2-dna/finetune_utils.py
ADDED
@@ -0,0 +1,147 @@
import random
import torch
from torch.utils.data import DataLoader, TensorDataset
from utils import sample_categorical_logits
import numpy as np
from tqdm import tqdm
import torch.distributed as dist
import torch.nn.functional as F

def compute_ess(log_rnd, normalize=True):
    """
    log_rnd: [B]
    Compute the effective sample size.
    If normalize, divide the ESS by the batch size so the range is [0, 1];
    otherwise the range is [0, B].
    """
    weights = log_rnd.detach().softmax(dim=-1)
    ess = 1 / (weights ** 2).sum().item()
    return ess / log_rnd.shape[0] if normalize else ess

def to_one_hot(x_idx, num_classes=4):
    oh = F.one_hot(x_idx.long(), num_classes=num_classes)
    return oh.float()

def rnd(model, reward_model, batch_size, scale=1, device='cuda:0'):
    r"""
    Run random-order sampling and compute the RND $\log\frac{dP^*}{dP^u}$ along the trajectory.
    reward_model: r(X)

    return:
    - x: the final samples, [B, D]
    - log_rnd: the log RND along this trajectory, [B]
    """
    if hasattr(model, 'module'):
        model = model.module

    x = torch.full((batch_size, model.length), model.vocab_size-1).to(device=device, dtype=torch.int64)
    batch_arange = torch.arange(batch_size, device=device)
    jump_pos = torch.rand(x.shape, device=device).argsort(dim=-1)
    # jump_times, jump_pos = torch.rand(x.shape, device=device).sort(dim=-1)
    # jump_times: Unif[0,1] in increasing order
    # jump_pos: random permutation of range(D)
    log_rnd = torch.zeros(batch_size, device=device)  # [B]
    for d in range(model.length-1, -1, -1):
        # jump at time jump_times[:, d] at position jump_pos[:, d]
        logits = model(x)[:, :, :-1]  # [B, D, N-1]
        update = sample_categorical_logits(
            logits[batch_arange, jump_pos[:, d]])  # [B]
        if torch.is_grad_enabled():  # avoid issues with in-place operations
            x = x.clone()
        x[batch_arange, jump_pos[:, d]] = update
        log_rnd += -np.log(model.vocab_size-1) - logits[batch_arange, jump_pos[:, d], update]
    log_rnd += scale * reward_model(x)  # [B]
    return x, log_rnd


@torch.no_grad()
def sampling(model, batch_size, rounds=1, device='cuda:0'):
    """Any-order autoregressive sampling."""
    if hasattr(model, 'module'):
        model = model.module
    batch_arange = torch.arange(batch_size, device=device)
    all_samples = []
    for _ in tqdm(range(rounds), leave=False):
        x = torch.full((batch_size, model.length), model.vocab_size-1).to(device=device, dtype=torch.int64)
        jump_pos = torch.rand(x.shape, device=device).argsort(dim=-1)
        # jump_times, jump_pos = torch.rand(x.shape, device=device).sort(dim=-1)
        # jump_times: Unif[0,1] in increasing order
        # jump_pos: random permutation of range(D)
        for d in tqdm(range(model.length-1, -1, -1), leave=False):
            # jump at time jump_times[:, d] at position jump_pos[:, d]
            logits = model.logits(x)[:, :, :-1]  # [B, D, N-1], not log-softmaxed but fine
            update = sample_categorical_logits(
                logits[batch_arange, jump_pos[:, d]])  # [B]
            x[batch_arange, jump_pos[:, d]] = update
        all_samples.append(x)
    return torch.cat(all_samples)  # (rounds * B, L)


def loss_ce(log_rnd):
    """Cross-entropy loss KL(P^*||P^u)."""
    weights = log_rnd.detach().softmax(dim=-1)
    return (log_rnd * weights).sum()


def loss_lv(log_rnd):
    r"""Log-variance loss Var_{P^\bar{u}}\log\frac{dP^*}{dP^u}."""
    return log_rnd.var()


def loss_re_rf(log_rnd, const=0):
    r"""Relative entropy loss KL(P^u||P^*) with the REINFORCE trick."""
    return (-log_rnd * (-log_rnd.detach() + const)).mean()


def loss_wdce(policy_model, log_rnd, x, num_replicates=16, weight_func=lambda l: 1/l, eps=1e-3, centering=False):
    r"""
    Weighted denoising cross-entropy loss with
    X_T ~ P^u_T and weights \log\frac{dP^*}{dP^u}(X).

    log_rnd: [B]; x: [B, L] (no mask)
    num_replicates: R, number of replicates of each row in x
    weight_func: w(lambda) for each sample, 1/lambda by default
    """
    mask_index = policy_model.mask_index
    if hasattr(policy_model, 'module'):
        policy_model = policy_model.module

    batch = x.repeat_interleave(num_replicates, dim=0)  # [B*R, L]
    batch_weights = log_rnd.detach_().softmax(dim=-1)
    if centering:
        batch_weights = batch_weights - batch_weights.mean(dim=-1, keepdim=True)

    batch_weights = batch_weights.repeat_interleave(num_replicates, dim=0)  # [B*R]

    lamda = torch.rand(batch.shape[0], device=batch.device)  # [B*R]
    lamda_weights = weight_func(lamda).clamp(max=1e5)  # [B*R]

    masked_index = torch.rand(*batch.shape, device=batch.device) < lamda[..., None]  # [B*R, D]
    perturbed_batch = torch.where(masked_index, mask_index, batch)

    # add time conditioning
    t = lamda
    sigma_t = -torch.log1p(-(1 - eps) * t)

    # compute logits
    logits = policy_model(perturbed_batch, sigma_t)
    losses = torch.zeros(*batch.shape, device=batch.device, dtype=logits.dtype)  # [B*R, D]
    losses[masked_index] = torch.gather(input=logits[masked_index], dim=-1,
                                        index=batch[masked_index][..., None]).squeeze(-1)
    return - (losses.sum(dim=-1) * lamda_weights * batch_weights).mean()


def loss_dce(model, x, weight_func=lambda l: 1/l):
    r"""
    Denoising cross-entropy loss; x [B, D] are ground-truth samples.
    weight_func: w(lambda) for each sample, 1/lambda by default
    """
    lamda = torch.rand(x.shape[0], device=x.device)  # [B]
    lamda_weights = weight_func(lamda).clamp(max=1e5)  # [B]
    masked_index = torch.rand(*x.shape, device=x.device) < lamda[..., None]  # [B, D]
    perturbed_batch = torch.where(masked_index, model.vocab_size-1, x)
    logits = model(perturbed_batch)
    losses = torch.zeros(*x.shape, device=x.device, dtype=logits.dtype)  # [B, D]
    losses[masked_index] = torch.gather(input=logits[masked_index], dim=-1,
                                        index=x[masked_index][..., None]).squeeze(-1)
    return - (losses.sum(dim=-1) * lamda_weights).mean()
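`loss_wdce` only requires the policy to expose a `mask_index` attribute and to be callable as `model(x_t, sigma_t)` returning per-position log-probabilities, and `compute_ess` diagnoses how degenerate the self-normalized trajectory weights are. A toy smoke test under those assumptions (the `DummyPolicy` below is illustrative, not repo code):

```python
import torch
import torch.nn as nn
from finetune_utils import loss_wdce, compute_ess

class DummyPolicy(nn.Module):
    def __init__(self, vocab_size=5, length=20, hidden=16):
        super().__init__()
        self.mask_index = vocab_size - 1          # last token id acts as [MASK]
        self.emb = nn.Embedding(vocab_size, hidden)
        self.head = nn.Linear(hidden, vocab_size)

    def forward(self, x_t, sigma_t):
        # sigma_t is accepted but ignored in this toy model
        return self.head(self.emb(x_t)).log_softmax(dim=-1)   # [B, L, V]

policy = DummyPolicy()
B, L = 8, 20
x = torch.randint(0, 4, (B, L))                  # "clean" sequences over {A, C, G, T}
log_rnd = torch.randn(B)                         # stand-in trajectory log-RND values

print("normalized ESS:", compute_ess(log_rnd))   # near 1/B signals degenerate weights
loss = loss_wdce(policy, log_rnd, x, num_replicates=4)
loss.backward()                                  # gradients flow into the dummy policy
print("WDCE loss:", loss.item())
```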
tr2d2-dna/mcts.py
ADDED
@@ -0,0 +1,581 @@
import numpy as np
import torch
import torch.nn.functional as F
import random as rd
from finetune_utils import to_one_hot
from utils import StepTimer

import noise_schedule

### BEGINNING OF NODE CLASS ###

class Node:
    """
    Node class: partially unmasked sequence
    - parentNode: Node object at the previous time step
    - childNodes: set of M Node objects generated from sampling M distinct unmasking schemes
    - totalReward: vector of cumulative rewards for all K objectives
    - visits: number of times the node has been visited by an iteration
    - path: array of partially unmasked sequences leading to the node from the completely masked root node
    - timestep: the time step where the sequence was sampled
    """
    def __init__(self, args, tokens=None, log_rnd=None, log_policy_step=None, log_pretrained_step=None, parentNode=None, childNodes=None, totalReward=None, timestep=None):
        self.args = args
        self.parentNode = parentNode
        self.childNodes = [] if childNodes is None else childNodes

        self.log_rnd = log_rnd  # stores the log_rnd up to that step

        #self.log_p0 = 0  # stores the log probability of the unmasking step from the previous iteration
        self.log_policy_step = log_policy_step  # stores the log probability of the unmasking step under the current policy
        self.log_pretrained_step = log_pretrained_step

        # initialize total rewards to the reward of the rolled-out unmasked sequence
        self.totalReward = totalReward  # potential reward of the node based on generated children

        # set initial visits to 1
        self.visits = 1

        #self.path = path

        # set timestep (value between 0 and num_steps)
        self.timestep = timestep
        # set the sampling probability equal to the probability from the reverse posterior
        #self.sampleProb = sampleProb  # stores the probability of the sampling step under the current policy

        # dict with 'seqs' as token array and 'attention_mask'
        self.tokens = tokens

    def selectNode(self, rootNode):
        """
        Selects a node to move to among the children nodes based on the select score.
        """
        # extract the status of the current node
        nodeStatus = self.getExpandStatus()

        # if the node is a legal non-leaf node
        if (nodeStatus == 3):
            # initialize array that will store select score vectors of each child node
            selectScores = []
            selectable_children = []  # children nodes that can be selected

            for childNode in self.childNodes:
                childStatus = childNode.getExpandStatus()
                # only append child if it is a legal leaf node (expandable) or a legal non-leaf node
                if childStatus == 2 or childStatus == 3:
                    selectScore = childNode.calcSelectScore()
                    if torch.is_tensor(selectScore) and selectScore.numel() == 1:
                        selectScore = selectScore.item()

                    selectable_children.append(childNode)
                    selectScores.append(float(selectScore))

            # no selectable children
            if len(selectable_children) == 0:
                return rootNode, 3

            selectScores = np.asarray(selectScores, dtype=np.float64)

            temp = 1.0
            # compute softmax probabilities
            m = np.max(selectScores)
            exps = np.exp((selectScores - m) / temp)
            tot = exps.sum()

            if not np.isfinite(tot) or tot <= 0.0:
                probs = np.full(len(selectable_children), 1.0 / len(selectable_children))
            else:
                probs = exps / tot

            # choose child index from categorical distribution
            idx = np.random.choice(len(selectable_children), p=probs)
            selected = selectable_children[idx]

            # return selected child node and status
            return selected, selected.getExpandStatus()

        elif (nodeStatus == 2):
            return self, nodeStatus

        # if node is not a valid non-leaf node
        return rootNode, 3

    def selectNodeTopK(self, rootNode, k=3, temp=1.0):
        """
        Pick from the top-k children by select score.
        Returns: (selected_node, selected_status)
        """
        nodeStatus = self.getExpandStatus()

        # If expandable leaf, return it directly
        if nodeStatus == 2:
            return self, nodeStatus

        if nodeStatus == 3:
            selectable_children = []
            selectScores = []

            # collect candidates
            for ch in self.childNodes:
                s = ch.getExpandStatus()
                if s in (2, 3):
                    sc = ch.calcSelectScore()
                    if torch.is_tensor(sc):
                        sc = sc.item() if sc.numel() == 1 else float(sc.mean().item())
                    sc = float(sc) if np.isfinite(sc) else -np.inf  # push bad scores to -inf
                    selectable_children.append(ch)
                    selectScores.append(sc)

            if not selectable_children:
                return rootNode, 3

            scores = np.asarray(selectScores, dtype=np.float64)

            # top-k indices (largest scores)
            k_eff = min(k, len(scores))
            topk_idx = np.argpartition(-scores, kth=k_eff-1)[:k_eff]
            # sort the top-k by descending score for stability
            topk_idx = topk_idx[np.argsort(-scores[topk_idx])]

            # slice down to the top-k pool
            pool_scores = scores[topk_idx]
            pool_children = [selectable_children[i] for i in topk_idx]

            # softmax over the top-k
            m = np.max(pool_scores)
            z = (pool_scores - m) / max(1e-8, temp)
            exps = np.exp(np.clip(z, -60, 60))
            tot = exps.sum()
            if not np.isfinite(tot) or tot <= 0.0:
                idx_local = 0  # best
            else:
                probs = exps / tot
                idx_local = int(np.random.choice(len(pool_children), p=probs))

            selected = pool_children[idx_local]
            return selected, selected.getExpandStatus()

        return rootNode, 3

    def addChildNode(self, tokens, log_rnd, log_policy_step, log_pretrained_step, totalReward):
        """
        Adds a child node:
        log_rnd: log_rnd of the path up to the added child node
        log_policy_step: scalar value of the log-prob of sampling the step under the policy
        log_pretrained_step: scalar value of the log-prob of sampling the step under the pretrained model
        """
        child = Node(args=self.args,
                     tokens=tokens,
                     log_rnd=log_rnd,
                     log_policy_step=log_policy_step,
                     log_pretrained_step=log_pretrained_step,
                     parentNode=self,
                     childNodes=[],
                     totalReward=totalReward,
                     timestep=self.timestep+1)

        self.childNodes.append(child)
        return child

    def update_logrnd(self, log_policy_step, log_rnd):
        self.log_policy_step = log_policy_step
        self.log_rnd = log_rnd

    def updateNode(self, rewards):
        """
        Updates the cumulative rewards vector with the reward vector at a descendant leaf node.
        Increments the number of visits to the node.
        """
        self.visits += 1

        self.totalReward += rewards  # singleton tensor

    def calcSelectScore(self):
        """
        Calculates the select score for the node from the cumulative rewards vector and the number of visits.
        - args.exploration: determines the degree of exploration
        """
        # K-dimensional vector of normalized rewards for each objective
        normRewards = self.totalReward / self.visits

        # scales the cumulative reward by the sampling probability
        return normRewards + (self.args.exploration * self.log_policy_step * np.sqrt(self.parentNode.visits) / self.visits)

    def getExpandStatus(self):
        """
        Returns an integer indicating whether the node is a:
        1. terminal node (sequence is fully unmasked)
        2. legal leaf node (partially unmasked sequence that can be expanded)
        3. legal non-leaf node (already expanded sequence with M child nodes)
        """
        if self.timestep == self.args.total_num_steps:
            return 1
        elif (self.timestep < self.args.total_num_steps) and (len(self.childNodes) == 0):
            return 2
        return 3

### END OF NODE CLASS ###

### BEGINNING OF MCTS CLASS ###

class MCTS:
    def __init__(self, args, config, policy_model, pretrained, rewardFunc, rootNode=None):
        self.timer = StepTimer(policy_model.device)

        # debugging
        self.buf_stats = {"insert": 0, "replace": 0, "reject_worse": 0,
                          "reject_dup": 0, "reject_nonfinite": 0}
        self._seen_hashes = set()

        self.device = policy_model.device
        print(f"MCTS device: {self.device}")

        self.args = args
        self.config = config
        self.noise = noise_schedule.get_noise(config)
        self.time_conditioning = args.time_conditioning

        self.mask_index = policy_model.mask_index
        masked_seq = torch.ones((self.args.seq_length), device=self.device) * self.mask_index
        masked_tokens = {'seqs': masked_seq.to(dtype=torch.long), 'attention_mask': torch.ones_like(masked_seq).to(self.device)}
        if rootNode is None:
            self.rootNode = Node(self.args, tokens=masked_tokens,
                                 log_rnd=torch.zeros((), device=self.device),
                                 log_policy_step=torch.zeros((), device=self.device),
                                 log_pretrained_step=torch.zeros((), device=self.device),
                                 totalReward=torch.zeros((), device=self.device), timestep=0)
        else:
            self.rootNode = rootNode  # stores the root node of the tree

        # dictionary:
        # "seq": final unmasked sequence
        # "traj": list of (N_steps, L)
        # "reward": reward of the trajectory
        self.buffer = []  # List[Dict[str, Any]]

        self.buffer_size = args.buffer_size

        self.num_steps = args.total_num_steps
        self.num_sequences = args.num_sequences

        # pretrained model
        self.pretrained = pretrained

        # the policy model that we want to finetune
        self.policy_model = policy_model
        #self.tokenizer = policy_model.tokenizer

        self.sequence_length = args.seq_length

        self.num_iter = args.num_iter

        self.num_children = args.num_children

        # score functions
        self.rewardFunc = rewardFunc

        self.iter_num = 0

        self.reward_log = []
        self.logrnd_log = []

        self.policy_model.eval()
        self.pretrained.eval()
        self.rewardFunc.eval()

    def _hash_tokens(self, t):
        # t: (L,) torch.long
        return tuple(t.detach().cpu().tolist())

    def reset(self, resetTree):
        self.iter_num = 0
        self.buffer = []
        self._seen_hashes = set()  # Clear the hash set too!
        self.reward_log = []
        self.logrnd_log = []

        # add option to continue with the same tree
        if resetTree:
            masked_seq = torch.ones((self.args.seq_length), device=self.device) * self.mask_index
            masked_tokens = {'seqs': masked_seq.to(dtype=torch.long), 'attention_mask': torch.ones_like(masked_seq).to(self.device)}
            self.rootNode = Node(self.args, tokens=masked_tokens,
                                 log_rnd=torch.zeros((), device=self.device),
                                 log_policy_step=torch.zeros((), device=self.device),
                                 log_pretrained_step=torch.zeros((), device=self.device),
                                 totalReward=torch.zeros((), device=self.device), timestep=0)

    def forward(self, resetTree=False):
        self.reset(resetTree)

        while (self.iter_num < self.num_iter):
            self.iter_num += 1

            # traverse the tree from the root node until a leaf node
            with self.timer.section("select"):
                leafNode, _ = self.select(self.rootNode)

            # expand leaf node into num_children partially unmasked sequences at the next timestep
            with self.timer.section("expand"):
                self.expand(leafNode)

        final_x, log_rnd, final_rewards = self.consolidateBuffer()

        rows = self.timer.summary()
        print("\n=== Timing summary (by total time) ===")
        for name, cnt, total, mean, p50, p95 in rows:
            print(f"{name:30s} n={cnt:5d} total={total:8.3f}s mean={mean*1e3:7.2f}ms "
                  f"p50={p50*1e3:7.2f}ms p95={p95*1e3:7.2f}ms")

        # return final_seqs (B, L), log_rnd (B, ), and final rewards (B, )
        return final_x, log_rnd, final_rewards


    def updateBuffer(self, x_final, log_rnd, final_reward):
        B = x_final.shape[0]
        for i in range(B):
            # Finite check
            if not torch.isfinite(final_reward[i]) or not torch.isfinite(log_rnd[i]):
                self.buf_stats["reject_nonfinite"] += 1
                continue

            h = self._hash_tokens(x_final[i])
            if h in self._seen_hashes:
                self.buf_stats["reject_dup"] += 1
                continue

            item = {"x_final": x_final[i].clone(),
                    "log_rnd": log_rnd[i].clone(),
                    "final_reward": final_reward[i].clone()}

            if len(self.buffer) < self.buffer_size:
                self.buffer.append(item)
                self._seen_hashes.add(h)
                self.buf_stats["insert"] += 1
            else:
                # replace if strictly better, or tie-break with log_rnd
                min_idx, min_item = min(
                    enumerate(self.buffer),
                    key=lambda kv: (kv[1]["final_reward"].item(), kv[1]["log_rnd"].item())
                )
                cand_key = (final_reward[i].item(), log_rnd[i].item())
                min_key = (min_item["final_reward"].item(), min_item["log_rnd"].item())

                if cand_key > min_key:  # allow ties via 2nd key
                    # update hash set
                    old_h = self._hash_tokens(self.buffer[min_idx]["x_final"])
                    if old_h in self._seen_hashes:
                        self._seen_hashes.remove(old_h)
                    self.buffer[min_idx] = item
                    self._seen_hashes.add(h)
                    self.buf_stats["replace"] += 1
                else:
                    self.buf_stats["reject_worse"] += 1

    def print_buffer_stats(self):
        print("[BUFFER] ",
              " ".join(f"{k}={v}" for k, v in self.buf_stats.items()),
              f" size={len(self.buffer)}/{self.buffer_size}")
        if len(self.buffer):
            vals = torch.stack([b["final_reward"] for b in self.buffer]).float()
            print(f"[BUFFER] reward min/mean/max: {vals.min():.4f} {vals.mean():.4f} {vals.max():.4f}")

    def consolidateBuffer(self):
        """
        Returns x_final, log_rnd, and final_rewards as tensors.
        """
        x_final = []
        log_rnd = []
        final_rewards = []
        for item in self.buffer:
            x_final.append(item["x_final"])
            log_rnd.append(item["log_rnd"])
            final_rewards.append(item["final_reward"])

        x_final = torch.stack(x_final, dim=0)  # (B, L)
        log_rnd = torch.stack(log_rnd, dim=0).to(dtype=torch.float32)  # (B)
        final_rewards = torch.stack(final_rewards, dim=0).to(dtype=torch.float32)  # (B)

        return x_final, log_rnd, final_rewards


    def isPathEnd(self, path, maxDepth):
        """
        Checks if the node is completely unmasked (i.e. end of path)
        or if the path is at the max depth.
        """
        if (path[-1] != self.mask_index).all():
            return True
        elif len(path) >= maxDepth:
            return True
        return False

    def select(self, currNode, eps=1e-5):
        """
        Traverse the tree from the root node until reaching a legal leaf node.
        """
        updated_log_rnd = torch.zeros((), device=self.device)
        while True:
            if self.args.select_topk:
                currNode, nodeStatus = currNode.selectNodeTopK(self.rootNode, k=self.args.select_topk_value, temp=1.0)
            else:
                currNode, nodeStatus = currNode.selectNode(self.rootNode)

            if currNode.parentNode is not None:
                # compute new log_policy
                child_tokens = currNode.tokens['seqs'].to(self.device)
                attn_mask = currNode.tokens['attention_mask'].to(self.device)
                parent = currNode.parentNode
                parent_tokens = parent.tokens['seqs'].to(self.device)
                t = torch.ones(1, device=self.device)
                dt = (1 - eps) / self.num_steps
                with torch.no_grad():
                    with self.timer.section("select.compute_log_policy"):
                        updated_log_policy_step = self.policy_model.compute_log_policy(parent_tokens,
                                                                                       child_tokens,
                                                                                       t=t, dt=dt)
                updated_log_rnd += (currNode.log_pretrained_step - updated_log_policy_step)

                currNode.update_logrnd(updated_log_policy_step, updated_log_rnd)  # update log_rnd

            # node is a legal leaf node: return for expansion; terminal node: restart from the root
            if nodeStatus == 2:
                return currNode, nodeStatus
            elif nodeStatus == 1:
                currNode = self.rootNode

    def expand(self, parentNode, eps=1e-5):
        """
        Sample unmasking steps from the pre-trained MDLM and
        add num_children partially unmasked sequences to the children of the parentNode.
        """

        num_children = self.num_children
        # initialize child rewards that will be added to total rewards
        allChildReward = torch.zeros((), device=self.device)

        # compute number of rollout steps
        # if parentNode.timestep == self.num_steps then num_rollout_steps = 1
        num_rollout_steps = self.num_steps - parentNode.timestep
        # array of rollout timesteps from the timestep of the parent node to 0
        rollout_t = torch.linspace(1, eps, self.num_steps + 1, device=self.device)
        dt = (1 - eps) / self.num_steps

        # initialize x and attn_mask
        x = parentNode.tokens['seqs'].to(self.device)
        attn_mask = parentNode.tokens['attention_mask'].to(self.device)
        parent_log_rnd = parentNode.log_rnd  # stores the log_rnd up to the parent node

        t = rollout_t[parentNode.timestep] * torch.ones(1, 1, device=self.device)

        # generate (n_children, seq_length) array of sampled children nodes

        # sample M child sequences and compute their log probabilities
        with torch.no_grad():
            with self.timer.section("expand.batch_mcts_reverse_step"):
                child_log_p, x_children, child_log_policy_step, child_log_pretrained_step = \
                    self.policy_model.batch_mcts_reverse_step(token_array=x,
                                                              t=t, dt=dt,
                                                              batch_size=num_children,
                                                              pretrained=self.pretrained)

        # compute weight of the step (num_children, 1)
        child_log_rnd = (parent_log_rnd + (child_log_pretrained_step - child_log_policy_step)).to(self.device)

        x_rollout = x_children

        traj_log_rnd = child_log_rnd  # initialize log_rnd for the entire rolled-out trajectory

        # rollout under the policy and compute the log ratio at each step
        with self.timer.section("expand.rollout_total"):
            for i in range(1, num_rollout_steps):
                t = rollout_t[parentNode.timestep + i] * torch.ones(num_children, 1, device=self.device)

                with torch.no_grad():
                    log_p, x_next, log_policy_step, log_pretrained_step = \
                        self.policy_model.mcts_reverse_step(x_rollout,
                                                            t=t, dt=dt,
                                                            pretrained=self.pretrained)

                # add the rollout step
                traj_log_rnd += log_pretrained_step - log_policy_step

                x_rollout = x_next

        # if mask tokens remain, fully unmask
        mask_positions = (x_rollout == self.mask_index)  # (B, L) bool

        # does **any** mask remain in any sequence
        any_mask_global = mask_positions.any().item()  # true if a mask remains
        if any_mask_global:
            with torch.no_grad():
                with self.timer.section("expand.noise_removal"):
                    log_p, x_next, log_policy_step, log_pretrained_step = \
                        self.policy_model.mcts_noise_removal(x_rollout,
                                                             t=t, dt=dt,
                                                             pretrained=self.pretrained)

            traj_log_rnd += log_pretrained_step - log_policy_step

            x_rollout = x_next

        x_final = x_rollout  # final sequences (B, L)

        # edit? how is the reward model defined?
        #childSequences = self.tokenizer.batch_decode(x_rollout)
        #if self.args.data == "peptide":
        #    validSequences = []

        # get final rewards
        x_one_hot = to_one_hot(x_final)
        x_one_hot_reward = torch.transpose(x_one_hot, 1, 2)
        reward_preds = self.rewardFunc(x_one_hot_reward).squeeze(-1)  # (num_children, 4)
        rewards_value = reward_preds[:, 0]  # (num_children,)

        if self.args.reward_clip:
            rewards = torch.clamp(rewards_value, max=self.args.reward_clip_value)
        else:
            rewards = rewards_value

        traj_log_rnd += rewards / self.args.alpha

        self.reward_log.append(rewards.detach().cpu().numpy())
        self.logrnd_log.append(traj_log_rnd.detach().cpu().numpy())

        # update buffer
        with self.timer.section("expand.update_buffer"):
            self.updateBuffer(x_final, traj_log_rnd, rewards)

        for i in range(num_children):
            # add to the all-child reward for backprop
            allChildReward += rewards[i]

            # create a node for the sequence and add it to the children of the parent
            childTokens = {'seqs': x_children[i].to(dtype=torch.long), 'attention_mask': attn_mask}
            parentNode.addChildNode(tokens=childTokens,
                                    log_rnd=child_log_rnd[i],
                                    log_policy_step=child_log_policy_step[i],
                                    log_pretrained_step=child_log_pretrained_step[i],
                                    totalReward=rewards[i])

        # backpropagate all child rewards
        with self.timer.section("expand.backprop"):
            self.backprop(parentNode, allChildReward)


    def backprop(self, node, allChildReward):
        # backpropagate rewards through the path leading to the leaf node from the root
        while node:
            node.updateNode(allChildReward)
            node = node.parentNode
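For intuition, `Node.calcSelectScore` combines the visit-normalized reward with an exploration term scaled by the step log-probability and the parent's visit count; `selectNode` then samples among the children through a softmax over these scores. A standalone toy illustration of the same quantity (not repo code; the numbers are made up):

```python
import numpy as np

def select_score(total_reward, visits, log_policy_step, parent_visits, exploration):
    # mirrors Node.calcSelectScore: average reward plus a visit-dependent term
    norm_reward = total_reward / visits
    return norm_reward + exploration * log_policy_step * np.sqrt(parent_visits) / visits

# two hypothetical children of a parent visited 10 times
print(select_score(total_reward=3.0, visits=4, log_policy_step=-0.2,
                   parent_visits=10, exploration=1.0))   # well-visited, decent reward
print(select_score(total_reward=1.0, visits=1, log_policy_step=-0.1,
                   parent_visits=10, exploration=1.0))   # barely visited, second term dominates
```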
tr2d2-dna/models/__init__.py
ADDED
@@ -0,0 +1,2 @@
from . import ema
from . import dnaconv
tr2d2-dna/models/dnaconv.py
ADDED
@@ -0,0 +1,121 @@
import torch
import torch.nn as nn
import numpy as np
import copy
import torch.nn.functional as F


class GaussianFourierProjection(nn.Module):
    """
    Gaussian random features for encoding time steps.
    """

    def __init__(self, embed_dim, scale=30.):
        super().__init__()
        # Randomly sample weights during initialization. These weights are fixed
        # during optimization and are not trainable.
        self.W = nn.Parameter(torch.randn(embed_dim // 2) * scale, requires_grad=False)

    def forward(self, x):
        x_proj = x[:, None] * self.W[None, :] * 2 * np.pi
        return torch.cat([torch.sin(x_proj), torch.cos(x_proj)], dim=-1)


class Dense(nn.Module):
    """
    A fully connected layer that reshapes outputs to feature maps.
    """

    def __init__(self, input_dim, output_dim):
        super().__init__()
        self.dense = nn.Linear(input_dim, output_dim)

    def forward(self, x):
        return self.dense(x)[...]

# from https://github.com/HannesStark/dirichlet-flow-matching
class CNNModel(nn.Module):
    def __init__(self, args, alphabet_size, num_cls, classifier=False):
        super().__init__()
        self.alphabet_size = alphabet_size
        self.args = args
        self.classifier = classifier
        self.num_cls = num_cls

        if self.args.clean_data:
            self.linear = nn.Embedding(self.alphabet_size, embedding_dim=args.hidden_dim)
        else:
            inp_size = self.alphabet_size  # + 1
            self.linear = nn.Conv1d(inp_size, args.hidden_dim, kernel_size=9, padding=4)
            self.time_embedder = nn.Sequential(GaussianFourierProjection(embed_dim=args.hidden_dim),
                                               nn.Linear(args.hidden_dim, args.hidden_dim))

        self.num_layers = 5 * args.num_cnn_stacks
        self.convs = [nn.Conv1d(args.hidden_dim, args.hidden_dim, kernel_size=9, padding=4),
                      nn.Conv1d(args.hidden_dim, args.hidden_dim, kernel_size=9, padding=4),
                      nn.Conv1d(args.hidden_dim, args.hidden_dim, kernel_size=9, dilation=4, padding=16),
                      nn.Conv1d(args.hidden_dim, args.hidden_dim, kernel_size=9, dilation=16, padding=64),
                      nn.Conv1d(args.hidden_dim, args.hidden_dim, kernel_size=9, dilation=64, padding=256)]
        self.convs = nn.ModuleList([copy.deepcopy(layer) for layer in self.convs for i in range(args.num_cnn_stacks)])
        self.time_layers = nn.ModuleList([Dense(args.hidden_dim, args.hidden_dim) for _ in range(self.num_layers)])
        self.norms = nn.ModuleList([nn.LayerNorm(args.hidden_dim) for _ in range(self.num_layers)])
        self.final_conv = nn.Sequential(nn.Conv1d(args.hidden_dim, args.hidden_dim, kernel_size=1),
                                        nn.ReLU(),
                                        nn.Conv1d(args.hidden_dim, args.hidden_dim if classifier else self.alphabet_size, kernel_size=1))
        self.dropout = nn.Dropout(args.dropout)
        if classifier:
            self.cls_head = nn.Sequential(nn.Linear(args.hidden_dim, args.hidden_dim),
                                          nn.ReLU(),
                                          nn.Linear(args.hidden_dim, self.num_cls))

        if self.args.cls_free_guidance and not self.classifier:
            self.cls_embedder = nn.Embedding(num_embeddings=self.num_cls + 1, embedding_dim=args.hidden_dim)
            self.cls_layers = nn.ModuleList([Dense(args.hidden_dim, args.hidden_dim) for _ in range(self.num_layers)])

    def forward(self, seq, t, cls=None, return_embedding=False):
        # adapt it to support both seq indices input and one-hot input
        if not (seq.ndim > 2 and seq.shape[-1] == self.alphabet_size):
            seq = F.one_hot(seq, num_classes=self.alphabet_size).float()

        if self.args.clean_data:
            feat = self.linear(seq)
            feat = feat.permute(0, 2, 1)
        else:
            time_emb = F.relu(self.time_embedder(t))
            feat = seq.permute(0, 2, 1)
            feat = F.relu(self.linear(feat))

        if self.args.cls_free_guidance and not self.classifier:
            cls_emb = self.cls_embedder(cls)

        for i in range(self.num_layers):
            h = self.dropout(feat.clone())

            if not self.args.clean_data:
                h = h + self.time_layers[i](time_emb)[:, :, None]

            if self.args.cls_free_guidance and not self.classifier:
                h = h + self.cls_layers[i](cls_emb)[:, :, None]

            h = self.norms[i]((h).permute(0, 2, 1))
            h = F.relu(self.convs[i](h.permute(0, 2, 1)))

            if h.shape == feat.shape:
                feat = h + feat
            else:
                feat = h

        feat = self.final_conv(feat)

        feat = feat.permute(0, 2, 1)

        if self.classifier:
            feat = feat.mean(dim=1)
            if return_embedding:
                embedding = self.cls_head[:1](feat)
                return self.cls_head[1:](embedding), embedding
            else:
                return self.cls_head(feat)

        return feat
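A quick shape check of the backbone (illustrative only; the real hyperparameters come from `configs_gosai/model/dnaconv.yaml`, and `alphabet_size=5` here is an assumption meaning 4 nucleotides plus a mask token):

```python
import torch
from argparse import Namespace
from models.dnaconv import CNNModel

toy_args = Namespace(
    clean_data=False,          # noisy-input path: one-hot + time embedding
    hidden_dim=128,
    num_cnn_stacks=1,          # 5 dilated conv layers per stack
    dropout=0.0,
    cls_free_guidance=False,
)

model = CNNModel(toy_args, alphabet_size=5, num_cls=1, classifier=False)
seq = torch.randint(0, 5, (2, 200))   # (batch, length) token indices
t = torch.rand(2)                     # diffusion times in [0, 1]
out = model(seq, t)
print(out.shape)                      # torch.Size([2, 200, 5]) per-position logits
```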
tr2d2-dna/models/ema.py
ADDED
@@ -0,0 +1,97 @@
import torch


class ExponentialMovingAverage:
    """
    Maintains (exponential) moving average of a set of parameters.
    """

    def __init__(self, parameters, decay, use_num_updates=True):
        """
        Args:
          parameters: Iterable of `torch.nn.Parameter`; usually the result of
            `model.parameters()`.
          decay: The exponential decay.
          use_num_updates: Whether to use number of updates when computing
            averages.
        """
        if decay < 0.0 or decay > 1.0:
            raise ValueError('Decay must be between 0 and 1')
        self.decay = decay
        self.num_updates = 0 if use_num_updates else None
        self.shadow_params = [p.clone().detach()
                              for p in parameters if p.requires_grad]
        self.collected_params = []

    def move_shadow_params_to_device(self, device):
        self.shadow_params = [i.to(device) for i in self.shadow_params]

    def update(self, parameters):
        """
        Update currently maintained parameters.

        Call this every time the parameters are updated, such as the result of
        the `optimizer.step()` call.

        Args:
          parameters: Iterable of `torch.nn.Parameter`; usually the same set of
            parameters used to initialize this object.
        """
        decay = self.decay
        if self.num_updates is not None:
            self.num_updates += 1
            decay = min(decay, (1 + self.num_updates) /
                        (10 + self.num_updates))
        one_minus_decay = 1.0 - decay
        with torch.no_grad():
            parameters = [p for p in parameters if p.requires_grad]
            for s_param, param in zip(self.shadow_params, parameters):
                s_param.sub_(one_minus_decay * (s_param - param))

    def copy_to(self, parameters):
        """
        Copy current parameters into given collection of parameters.

        Args:
          parameters: Iterable of `torch.nn.Parameter`; the parameters to be
            updated with the stored moving averages.
        """
        parameters = [p for p in parameters if p.requires_grad]
        for s_param, param in zip(self.shadow_params, parameters):
            if param.requires_grad:
                param.data.copy_(s_param.data)

    def store(self, parameters):
        """
        Save the current parameters for restoring later.

        Args:
          parameters: Iterable of `torch.nn.Parameter`; the parameters to be
            temporarily stored.
        """
        self.collected_params = [param.clone() for param in parameters]

    def restore(self, parameters):
        """
        Restore the parameters stored with the `store` method.
        Useful to validate the model with EMA parameters without affecting the
        original optimization process. Store the parameters before the
        `copy_to` method. After validation (or model saving), use this to
        restore the former parameters.

        Args:
          parameters: Iterable of `torch.nn.Parameter`; the parameters to be
            updated with the stored parameters.
        """
        for c_param, param in zip(self.collected_params, parameters):
            param.data.copy_(c_param.data)

    def state_dict(self):
        return dict(decay=self.decay,
                    num_updates=self.num_updates,
                    shadow_params=self.shadow_params)

    def load_state_dict(self, state_dict):
        self.decay = state_dict['decay']
        self.num_updates = state_dict['num_updates']
        self.shadow_params = state_dict['shadow_params']
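Typical usage of this class, shown with a throwaway linear model for illustration: update the shadow weights after every optimizer step, then temporarily swap them in for evaluation and restore the raw weights afterwards.

```python
import torch
import torch.nn as nn
from models.ema import ExponentialMovingAverage

model = nn.Linear(16, 4)
optimizer = torch.optim.AdamW(model.parameters(), lr=1e-3)
ema = ExponentialMovingAverage(model.parameters(), decay=0.999)

for _ in range(10):
    loss = model(torch.randn(8, 16)).pow(2).mean()
    optimizer.zero_grad()
    loss.backward()
    optimizer.step()
    ema.update(model.parameters())      # track the optimization trajectory

ema.store(model.parameters())           # stash the raw weights
ema.copy_to(model.parameters())         # evaluate with the smoothed weights
with torch.no_grad():
    _ = model(torch.randn(8, 16))
ema.restore(model.parameters())         # put the raw weights back before training resumes
```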
tr2d2-dna/noise_schedule.py
ADDED
@@ -0,0 +1,151 @@
| 1 |
+
import abc
|
| 2 |
+
|
| 3 |
+
import torch
|
| 4 |
+
import torch.nn as nn
|
| 5 |
+
|
| 6 |
+
# Flags required to enable jit fusion kernels
torch._C._jit_set_profiling_mode(False)
torch._C._jit_set_profiling_executor(False)
torch._C._jit_override_can_fuse_on_cpu(True)
torch._C._jit_override_can_fuse_on_gpu(True)


def get_noise(config, dtype=torch.float32):
    if config.noise.type == 'geometric':
        return GeometricNoise(config.noise.sigma_min,
                              config.noise.sigma_max)
    elif config.noise.type == 'loglinear':
        return LogLinearNoise()
    elif config.noise.type == 'cosine':
        return CosineNoise()
    elif config.noise.type == 'cosinesqr':
        return CosineSqrNoise()
    elif config.noise.type == 'linear':
        return Linear(config.noise.sigma_min,
                      config.noise.sigma_max,
                      dtype)
    else:
        raise ValueError(f'{config.noise.type} is not a valid noise')


def binary_discretization(z):
    z_hard = torch.sign(z)
    z_soft = z / torch.norm(z, dim=-1, keepdim=True)
    return z_soft + (z_hard - z_soft).detach()


class Noise(abc.ABC, nn.Module):
    """
    Baseline forward method to get the total + rate of noise at a timestep
    """
    def forward(self, t):
        # Assume time goes from 0 to 1
        return self.total_noise(t), self.rate_noise(t)

    @abc.abstractmethod
    def rate_noise(self, t):
        """
        Rate of change of noise ie g(t)
        """
        pass

    @abc.abstractmethod
    def total_noise(self, t):
        """
        Total noise ie \int_0^t g(t) dt + g(0)
        """
        pass


class CosineNoise(Noise):
    def __init__(self, eps=1e-3):
        super().__init__()
        self.eps = eps

    def rate_noise(self, t):
        cos = (1 - self.eps) * torch.cos(t * torch.pi / 2)
        sin = (1 - self.eps) * torch.sin(t * torch.pi / 2)
        scale = torch.pi / 2
        return scale * sin / (cos + self.eps)

    def total_noise(self, t):
        cos = torch.cos(t * torch.pi / 2)
        return - torch.log(self.eps + (1 - self.eps) * cos)


class CosineSqrNoise(Noise):
    def __init__(self, eps=1e-3):
        super().__init__()
        self.eps = eps

    def rate_noise(self, t):
        cos = (1 - self.eps) * (
            torch.cos(t * torch.pi / 2) ** 2)
        sin = (1 - self.eps) * torch.sin(t * torch.pi)
        scale = torch.pi / 2
        return scale * sin / (cos + self.eps)

    def total_noise(self, t):
        cos = torch.cos(t * torch.pi / 2) ** 2
        return - torch.log(self.eps + (1 - self.eps) * cos)


class Linear(Noise):
    def __init__(self, sigma_min=0, sigma_max=10, dtype=torch.float32):
        super().__init__()
        self.sigma_min = torch.tensor(sigma_min, dtype=dtype)
        self.sigma_max = torch.tensor(sigma_max, dtype=dtype)

    def rate_noise(self, t):
        return self.sigma_max - self.sigma_min

    def total_noise(self, t):
        return self.sigma_min + t * (self.sigma_max - self.sigma_min)

    def importance_sampling_transformation(self, t):
        f_T = torch.log1p(- torch.exp(- self.sigma_max))
        f_0 = torch.log1p(- torch.exp(- self.sigma_min))
        sigma_t = - torch.log1p(- torch.exp(t * f_T + (1 - t) * f_0))
        return (sigma_t - self.sigma_min) / (
            self.sigma_max - self.sigma_min)


class GeometricNoise(Noise):
    def __init__(self, sigma_min=1e-3, sigma_max=1):
        super().__init__()
        self.sigmas = 1.0 * torch.tensor([sigma_min, sigma_max])

    def rate_noise(self, t):
        return self.sigmas[0] ** (1 - t) * self.sigmas[1] ** t * (
            self.sigmas[1].log() - self.sigmas[0].log())

    def total_noise(self, t):
        return self.sigmas[0] ** (1 - t) * self.sigmas[1] ** t


class LogLinearNoise(Noise):
    """Log Linear noise schedule.

    Built such that 1 - 1/e^(n(t)) interpolates between 0 and
    ~1 when t varies from 0 to 1. Total noise is
    -log(1 - (1 - eps) * t), so the sigma will be
    (1 - eps) * t.
    """
    def __init__(self, eps=1e-3):
        super().__init__()
        self.eps = eps
        self.sigma_max = self.total_noise(torch.tensor(1.0))
        self.sigma_min = self.eps + self.total_noise(torch.tensor(0.0))

    def rate_noise(self, t):
        return (1 - self.eps) / (1 - (1 - self.eps) * t)

    def total_noise(self, t):
        return -torch.log1p(-(1 - self.eps) * t)

    def importance_sampling_transformation(self, t):
        f_T = torch.log1p(- torch.exp(- self.sigma_max))
        f_0 = torch.log1p(- torch.exp(- self.sigma_min))
        sigma_t = - torch.log1p(- torch.exp(t * f_T + (1 - t) * f_0))
        t = - torch.expm1(- sigma_t) / (1 - self.eps)
        return t
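
A minimal usage sketch of the schedule classes above (illustrative only, not part of the commit; it assumes `LogLinearNoise` from this file is in scope):

```
# Sketch: query a LogLinearNoise schedule at a batch of timesteps in [0, 1].
import torch

noise = LogLinearNoise(eps=1e-3)       # same default eps as above
t = torch.linspace(0.0, 1.0, steps=5)
total, rate = noise(t)                 # forward() returns (total_noise(t), rate_noise(t))
masked_frac = 1 - torch.exp(-total)    # the 1 - 1/e^(n(t)) quantity the docstring says goes from 0 to ~1
print(total, rate, masked_frac)
```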
tr2d2-dna/oracle.py
ADDED
@@ -0,0 +1,344 @@
import torch
import tempfile
import grelu
import pandas as pd
import os
from grelu.lightning import LightningModel
import grelu.data.preprocess
import grelu.data.dataset
import dataloader_gosai
import numpy as np
from typing import Callable, Union, List
from scipy.linalg import sqrtm
from scipy.stats import pearsonr
import torch.nn.functional as F
import io

base_path = ""  # Fill in directory of the pretrained checkpoints, e.g., "...../data_and_model/"


def get_cal_atac_orale(device=None):
    device = torch.device("cuda" if torch.cuda.is_available() else "cpu") if device is None else device
    ckpt_path = os.path.join(base_path, 'mdlm/gosai_data/binary_atac_cell_lines.ckpt')

    ckpt = torch.load(ckpt_path, map_location="cpu")

    hp = ckpt.get("hyper_parameters", {})
    ckpt.setdefault("data_params", hp.get("data_params", {}))
    ckpt.setdefault("performance", {})

    if not ckpt["performance"]:
        ckpt["performance"] = {
            "best_step": ckpt.get("global_step", 0),
            "best_metric": None,
        }

    # Load model from in-memory checkpoint (no file I/O needed)
    buffer = io.BytesIO()
    torch.save(ckpt, buffer)
    buffer.seek(0)  # Reset buffer position to the beginning

    model_load = LightningModel.load_from_checkpoint(buffer, map_location="cpu")
    model_load.to(device)

    model_load.train_params['logger'] = None

    return model_load


def get_gosai_oracle(mode='train', device=None):
    device = torch.device("cuda" if torch.cuda.is_available() else "cpu") if device is None else device
    if mode == 'train':
        ckpt_path = os.path.join(base_path, "mdlm/outputs_gosai/lightning_logs/reward_oracle_ft.ckpt")

        ckpt = torch.load(ckpt_path, map_location="cpu")

        hp = ckpt.get("hyper_parameters", {})
        ckpt.setdefault("data_params", hp.get("data_params", {}))
        ckpt.setdefault("performance", {})

        if not ckpt["performance"]:
            ckpt["performance"] = {
                "best_step": ckpt.get("global_step", 0),
                "best_metric": None,
            }

        # Load model from in-memory checkpoint (no file I/O needed)
        buffer = io.BytesIO()
        torch.save(ckpt, buffer)
        buffer.seek(0)  # Reset buffer position to the beginning

        model_load = LightningModel.load_from_checkpoint(buffer, map_location="cpu")
        model_load.to(device)

    elif mode == 'eval':
        ckpt_path = os.path.join(base_path, "mdlm/outputs_gosai/lightning_logs/reward_oracle_eval.ckpt")
        ckpt = torch.load(ckpt_path, map_location="cpu")

        hp = ckpt.get("hyper_parameters", {})
        ckpt.setdefault("data_params", hp.get("data_params", {}))  # safe default
        ckpt.setdefault("performance", {})  # safe default
        # Optional: add minimal hints if code later reads fields
        if not ckpt["performance"]:
            ckpt["performance"] = {
                "best_step": ckpt.get("global_step", 0),
                "best_metric": None,
            }

        # Load model from in-memory checkpoint (no file I/O needed)
        buffer = io.BytesIO()
        torch.save(ckpt, buffer)
        buffer.seek(0)  # Reset buffer position to the beginning

        model_load = LightningModel.load_from_checkpoint(buffer, map_location="cpu")
        model_load.to(device)
    else:
        raise ValueError

    model_load.train_params['logger'] = None

    return model_load


def cal_gosai_pred(seqs, model=None, mode='eval'):
    """
    seqs: list of sequences (detokenized ACGT...)
    """
    if model is None:
        model = get_gosai_oracle(mode=mode)
    df_seqs = pd.DataFrame(seqs, columns=['seq'])
    pred_dataset = grelu.data.dataset.DFSeqDataset(df_seqs)
    preds = model.predict_on_dataset(pred_dataset, devices=[0])
    return preds.squeeze()  # numpy array with shape [n_seqs, 3]


def cal_gosai_pred_new(seqs, model=None, mode='eval'):
    """
    seqs: list of sequences (detokenized ACGT...)
    """
    if model is None:
        model = get_gosai_oracle(mode=mode)
    model.eval()
    tokens = dataloader_gosai.batch_dna_tokenize(seqs)
    tokens = torch.tensor(tokens).long().to(model.device)
    onehot_tokens = F.one_hot(tokens, num_classes=4).float()
    preds = model(onehot_tokens.float().transpose(1, 2)).detach().cpu().numpy()
    return preds.squeeze()


def cal_atac_pred(seqs, model=None):
    """
    seqs: list of sequences (detokenized ACGT...)
    """
    if model is None:
        model = LightningModel.load_from_checkpoint(os.path.join(base_path, 'mdlm/gosai_data/binary_atac_cell_lines.ckpt'), map_location='cuda')
    df_seqs = pd.DataFrame(seqs, columns=['seq'])
    pred_dataset = grelu.data.dataset.DFSeqDataset(df_seqs)
    preds = model.predict_on_dataset(pred_dataset, devices=[0])
    return preds.squeeze()  # numpy array with shape [n_seqs, 7]


def cal_atac_pred_new(seqs, model=None):
    """
    seqs: list of sequences (detokenized ACGT...)
    """
    if model is None:
        model = LightningModel.load_from_checkpoint(os.path.join(base_path, 'mdlm/gosai_data/binary_atac_cell_lines.ckpt'), map_location='cuda')
    model.eval()
    tokens = dataloader_gosai.batch_dna_tokenize(seqs)
    tokens = torch.tensor(tokens).long().to(model.device)
    onehot_tokens = F.one_hot(tokens, num_classes=4).float()
    preds = model(onehot_tokens.float().transpose(1, 2)).detach().cpu().numpy()
    return preds.squeeze()  # numpy array with shape [n_seqs, 7]


def count_kmers(seqs, k=3):
    counts = {}
    for seq in seqs:
        for i in range(len(seq) - k + 1):
            subseq = seq[i : i + k]
            try:
                counts[subseq] += 1
            except KeyError:
                counts[subseq] = 1
    return counts


def subset_for_eval(n=5000, seed=0):
    train_set = dataloader_gosai.get_datasets_gosai()
    np.random.seed(seed)
    torch.manual_seed(seed)
    train_set_sp = torch.utils.data.Subset(train_set, np.random.choice(len(train_set), n, replace=False))
    return train_set_sp


def subset_eval_groundtruth(sets_sp):
    train_set_sp = sets_sp
    train_set_sp_clss = train_set_sp.dataset.clss[train_set_sp.indices]
    return train_set_sp_clss


def subset_eval_preds(sets_sp, oracle_model=None):
    train_set_sp = sets_sp
    train_preds = cal_gosai_pred(
        dataloader_gosai.batch_dna_detokenize(train_set_sp.dataset.seqs[train_set_sp.indices].numpy()), oracle_model)
    return train_preds


def subset_eval_kmers(sets_sp, k=3):
    train_set_sp = sets_sp
    train_seqs = dataloader_gosai.batch_dna_detokenize(train_set_sp.dataset.seqs[train_set_sp.indices].numpy())
    train_kmers = count_kmers(train_seqs, k)
    return train_kmers


def subset_eval_embs(sets_sp, oracle_model=None):
    train_set_sp = sets_sp
    train_sp_emb = cal_gosai_emb(
        dataloader_gosai.batch_dna_detokenize(train_set_sp.dataset.seqs[train_set_sp.indices].numpy()), oracle_model)
    return train_sp_emb


def cal_emb_pca(sets_sp, n_components=50, oracle_model=None):
    train_set_sp = sets_sp
    train_sp_emb = cal_gosai_emb(
        dataloader_gosai.batch_dna_detokenize(train_set_sp.dataset.seqs[train_set_sp.indices].numpy()), oracle_model)
    from sklearn.decomposition import PCA
    pca = PCA(n_components=n_components)
    pca.fit(train_sp_emb.reshape(train_sp_emb.shape[0], -1))
    return pca


def subset_eval_embs_pca(sets_sp, pca, oracle_model=None):
    train_sp_emb = subset_eval_embs(sets_sp, oracle_model)
    train_sp_emb_pca = pca.transform(train_sp_emb.reshape(train_sp_emb.shape[0], -1))
    return train_sp_emb_pca


# https://github.com/HannesStark/dirichlet-flow-matching/blob/main/utils/flow_utils.py
def get_wasserstein_dist(embeds1, embeds2):
    if np.isnan(embeds2).any() or np.isnan(embeds1).any() or len(embeds1) == 0 or len(embeds2) == 0:
        return float('nan')
    mu1, sigma1 = embeds1.mean(axis=0), np.cov(embeds1, rowvar=False)
    mu2, sigma2 = embeds2.mean(axis=0), np.cov(embeds2, rowvar=False)
    ssdiff = np.sum((mu1 - mu2) ** 2.0)
    covmean = sqrtm(sigma1.dot(sigma2))
    if np.iscomplexobj(covmean):
        covmean = covmean.real
    dist = ssdiff + np.trace(sigma1 + sigma2 - 2.0 * covmean)
    return dist


def embed_on_dataset(
    model,
    dataset: Callable,
    devices: Union[str, int, List[int]] = "cpu",
    num_workers: int = 1,
    batch_size: int = 256,
):
    """
    Return embeddings for a dataset of sequences

    Args:
        dataset: Dataset object that yields one-hot encoded sequences
        devices: Device IDs to use
        num_workers: Number of workers for data loader
        batch_size: Batch size for data loader

    Returns:
        Numpy array of shape (B, T, L) containing embeddings.
    """
    torch.set_float32_matmul_precision("medium")

    # Make dataloader
    dataloader = model.make_predict_loader(
        dataset, num_workers=num_workers, batch_size=batch_size
    )

    # Get device
    orig_device = model.device
    device = model.parse_devices(devices)[1]
    if isinstance(device, list):
        device = device[0]
    model.to(device)

    # Get embeddings
    preds = []
    model.model = model.model.eval()
    for batch in iter(dataloader):
        batch = batch.to(device)
        preds.append(model.model.embedding(batch).detach().cpu())

    # Return to original device
    model.to(orig_device)
    return torch.vstack(preds).numpy()


def cal_gosai_emb(seqs, model=None):
    """
    seqs: list of sequences (detokenized ACGT...)
    """
    if model is None:
        model = get_gosai_oracle()
    df_seqs = pd.DataFrame(seqs, columns=['seq'])
    pred_dataset = grelu.data.dataset.DFSeqDataset(df_seqs)
    embs = embed_on_dataset(model, pred_dataset, devices=[0])
    return embs  # numpy array with shape [n_seqs, 3072, 2]


def cal_highexp_kmers(k=3, return_clss=False):
    train_set = dataloader_gosai.get_datasets_gosai()
    exp_threshold = np.quantile(train_set.clss[:, 0].numpy(), 0.99)  # 4.56
    highexp_indices = [i for i, data in enumerate(train_set) if data['clss'][0] > exp_threshold]
    highexp_set_sp = torch.utils.data.Subset(train_set, highexp_indices)
    highexp_seqs = dataloader_gosai.batch_dna_detokenize(highexp_set_sp.dataset.seqs[highexp_set_sp.indices].numpy())
    highexp_kmers_99 = count_kmers(highexp_seqs, k=k)
    n_highexp_kmers_99 = len(highexp_indices)

    exp_threshold = np.quantile(train_set.clss[:, 0].numpy(), 0.999)  # 6.27
    highexp_indices = [i for i, data in enumerate(train_set) if data['clss'][0] > exp_threshold]
    highexp_set_sp = torch.utils.data.Subset(train_set, highexp_indices)
    highexp_seqs = dataloader_gosai.batch_dna_detokenize(highexp_set_sp.dataset.seqs[highexp_set_sp.indices].numpy())
    highexp_kmers_999 = count_kmers(highexp_seqs, k=k)
    n_highexp_kmers_999 = len(highexp_indices)

    if return_clss:
        highexp_set_sp_clss_999 = highexp_set_sp.dataset.clss[highexp_set_sp.indices]
        highexp_preds_999 = cal_gosai_pred_new(
            dataloader_gosai.batch_dna_detokenize(highexp_set_sp.dataset.seqs[highexp_set_sp.indices].numpy()))
        return highexp_kmers_99, n_highexp_kmers_99, highexp_kmers_999, n_highexp_kmers_999, highexp_set_sp_clss_999, highexp_preds_999, highexp_seqs

    return highexp_kmers_99, n_highexp_kmers_99, highexp_kmers_999, n_highexp_kmers_999


def cal_kmer_corr(model, highexp_kmers, n_highexp_kmers, n_sample=128):
    model.eval()
    all_detokenized_samples = []
    for _ in range(10):
        samples = model._sample(eval_sp_size=n_sample).detach().cpu().numpy()
        detokenized_samples = dataloader_gosai.batch_dna_detokenize(samples)
        all_detokenized_samples.extend(detokenized_samples)
    generated_kmer = count_kmers(all_detokenized_samples)

    kmer_set = set(highexp_kmers.keys()) | set(generated_kmer.keys())
    counts = np.zeros((len(kmer_set), 2))
    for i, kmer in enumerate(kmer_set):
        if kmer in highexp_kmers:
            counts[i][1] = highexp_kmers[kmer] * len(generated_kmer) / n_highexp_kmers
        if kmer in generated_kmer:
            counts[i][0] = generated_kmer[kmer]

    corr = pearsonr(counts[:, 0], counts[:, 1])[0]
    return corr


def cal_avg_likelihood(model, old_model, n_sample=128):
    model.eval()
    old_model.eval()
    all_raw_samples = []
    for _ in range(10):
        samples = model._sample(eval_sp_size=n_sample)
        all_raw_samples.append(samples)
    all_raw_samples = torch.concat(all_raw_samples)
    avg_likelihood = old_model._forward_pass_diffusion(all_raw_samples).sum(-1).mean().item()
    return avg_likelihood
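
A small sketch of the intended call pattern for the reward-oracle helpers above (illustrative only, not part of the commit; it assumes `base_path` has been filled in as the README instructs and that the functions in this file are in scope):

```
# Sketch: score detokenized 200-bp sequences with the evaluation oracle.
seqs = ["ACGT" * 50, "TTAA" * 50]               # placeholder strings; in practice use
                                                # dataloader_gosai.batch_dna_detokenize(...) outputs
oracle = get_gosai_oracle(mode='eval')          # loads reward_oracle_eval.ckpt from base_path
preds = cal_gosai_pred_new(seqs, model=oracle)  # numpy array of shape [n_seqs, 3], as noted above
kmers = count_kmers(seqs, k=3)                  # 3-mer counts, as used by cal_kmer_corr
print(preds.shape, len(kmers))
```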
tr2d2-dna/run_batch_eval.sh
ADDED
@@ -0,0 +1,30 @@
#!/bin/bash
#SBATCH --job-name=dna
#SBATCH --partition=coe-gpu
#SBATCH --gres=gpu:H200:1
#SBATCH --time=16:00:00
#SBATCH --mem-per-gpu=60G
#SBATCH --cpus-per-task=2
#SBATCH --wait-all-nodes=1
#SBATCH --output=../outputs/%j.%x/.log


# Set the path to your runs directory
RUNS_DIR="" # Fill in the directory containing the checkpoints to evaluate

# Set output file name with timestamp
TIMESTAMP=$(date +"%Y%m%d_%H%M%S")
OUTPUT_FILE="batch_eval_results_${TIMESTAMP}.txt"

# Run the batch evaluation
python eval_runs_batch.py \
    --runs_dir "$RUNS_DIR" \
    --output_file "$OUTPUT_FILE" \
    --device "cuda:0" \
    --total_num_steps 128 \
    --batch_size 128 \
    --num_seeds 3 \
    --total_samples 640 \
    --seq_length 200

echo "Batch evaluation completed. Results saved to: $OUTPUT_FILE"
tr2d2-dna/train.sh
ADDED
@@ -0,0 +1,51 @@
#!/bin/bash

#SBATCH --job-name=dna_mdns
#SBATCH --partition=coe-gpu
#SBATCH --gres=gpu:H100:1
#SBATCH --time=16:00:00
# max 16 GPU hours, i.e., time <= 16h / num of GPUs
#SBATCH --mem-per-gpu=60G
# maximum GPU RAM, 141G for H200, 94G for H100
# in the current setting, 40G is enough for num_replicates=2 and 80G is enough for num_replicates=4
#SBATCH --cpus-per-task=2
#SBATCH --wait-all-nodes=1
#SBATCH --output=../outputs/%j.%x/.log


HOME_LOC=""  # Fill in directory of the repo
SAVE_PATH="" # Fill in directory to save the checkpoints
BASE_PATH="" # Fill in directory of the pretrained checkpoints, e.g., "...../data_and_model/"
SCRIPT_LOC=$HOME_LOC/tr2d2/dna
LOG_LOC=$HOME_LOC/tr2d2/dna/logs
DATE=$(date +%m_%d)


mkdir -p "$LOG_LOC"

# set 3 have skip connection
# ===================================================================
python $SCRIPT_LOC/finetune.py \
    --base_path $BASE_PATH \
    --device "cuda:0" \
    --noise_removal \
    --save_path_dir $SAVE_PATH \
    --wdce_num_replicates 16 \
    --buffer_size 160 \
    --batch_size 160 \
    --seq_length 200 \
    --num_children 32 \
    --total_num_steps 128 \
    --num_iter 5 \
    --resample_every_n_step 5 \
    --eval_every_n_epochs 10 \
    --num_epochs 60000 \
    --exploration 0.1 \
    --save_every_n_epoch 2000 \
    --alpha 0.1 \
    --centering \
    --grad_clip \
    --reward_clip \
    --reward_clip_value 15.0 \
    --reset_tree
tr2d2-dna/utils.py
ADDED
@@ -0,0 +1,175 @@
"""Console logger utilities.

Copied from https://github.com/HazyResearch/transformers/blob/master/src/utils/utils.py
Copied from https://docs.python.org/3/howto/logging-cookbook.html#using-a-context-manager-for-selective-logging
"""

import logging
import fsspec
import lightning
import torch
from timm.scheduler import CosineLRScheduler
import argparse
import numpy as np
import random
import os
import time
from collections import defaultdict
from contextlib import contextmanager


class StepTimer:
    def __init__(self, device=None):
        self.times = defaultdict(list)
        self.device = device
        self._use_cuda_sync = (
            isinstance(device, torch.device) and device.type == "cuda"
        ) or (isinstance(device, str) and "cuda" in device)

    @contextmanager
    def section(self, name):
        if self._use_cuda_sync:
            torch.cuda.synchronize()
        t0 = time.perf_counter()
        try:
            yield
        finally:
            if self._use_cuda_sync:
                torch.cuda.synchronize()
            dt = time.perf_counter() - t0
            self.times[name].append(dt)

    def summary(self, top_k=None):
        # returns (name, count, total, mean, p50, p95)
        rows = []
        for k, v in self.times.items():
            a = np.array(v, dtype=float)
            rows.append((k, len(a), a.sum(), a.mean(), np.median(a), np.percentile(a, 95)))
        rows.sort(key=lambda r: r[2], reverse=True)  # by total time
        return rows[:top_k] if top_k else rows


def sample_categorical_logits(logits, dtype=torch.float64):
    # do not require logits to be log-softmaxed
    gumbel_noise = -(1e-10 - (torch.rand_like(logits, dtype=dtype) + 1e-10).log()).log()
    return (logits + gumbel_noise).argmax(dim=-1)


def fsspec_exists(filename):
    """Check if a file exists using fsspec."""
    fs, _ = fsspec.core.url_to_fs(filename)
    return fs.exists(filename)


def fsspec_listdir(dirname):
    """Listdir in manner compatible with fsspec."""
    fs, _ = fsspec.core.url_to_fs(dirname)
    return fs.ls(dirname)


def fsspec_mkdirs(dirname, exist_ok=True):
    """Mkdirs in manner compatible with fsspec."""
    fs, _ = fsspec.core.url_to_fs(dirname)
    fs.makedirs(dirname, exist_ok=exist_ok)


def print_nans(tensor, name):
    if torch.isnan(tensor).any():
        print(name, tensor)


class CosineDecayWarmupLRScheduler(
    CosineLRScheduler,
    torch.optim.lr_scheduler._LRScheduler):
    """Wrap timm.scheduler.CosineLRScheduler
    Enables calling scheduler.step() without passing in epoch.
    Supports resuming as well.
    Adapted from:
    https://github.com/HazyResearch/hyena-dna/blob/main/src/utils/optim/schedulers.py
    """

    def __init__(self, *args, **kwargs):
        super().__init__(*args, **kwargs)
        self._last_epoch = -1
        self.step(epoch=0)

    def step(self, epoch=None):
        if epoch is None:
            self._last_epoch += 1
        else:
            self._last_epoch = epoch
        # We call either step or step_update, depending on
        # whether we're using the scheduler every epoch or every
        # step.
        # Otherwise, lightning will always call step (i.e.,
        # meant for each epoch), and if we set scheduler
        # interval to "step", then the learning rate update will
        # be wrong.
        if self.t_in_epochs:
            super().step(epoch=self._last_epoch)
        else:
            super().step_update(num_updates=self._last_epoch)


class LoggingContext:
    """Context manager for selective logging."""
    def __init__(self, logger, level=None, handler=None, close=True):
        self.logger = logger
        self.level = level
        self.handler = handler
        self.close = close

    def __enter__(self):
        if self.level is not None:
            self.old_level = self.logger.level
            self.logger.setLevel(self.level)
        if self.handler:
            self.logger.addHandler(self.handler)

    def __exit__(self, et, ev, tb):
        if self.level is not None:
            self.logger.setLevel(self.old_level)
        if self.handler:
            self.logger.removeHandler(self.handler)
        if self.handler and self.close:
            self.handler.close()


def get_logger(name=__name__, level=logging.INFO) -> logging.Logger:
    """Initializes multi-GPU-friendly python logger."""

    logger = logging.getLogger(name)
    logger.setLevel(level)

    # this ensures all logging levels get marked with the rank zero decorator
    # otherwise logs would get multiplied for each GPU process in multi-GPU setup
    for level in ('debug', 'info', 'warning', 'error',
                  'exception', 'fatal', 'critical'):
        setattr(logger,
                level,
                lightning.pytorch.utilities.rank_zero_only(
                    getattr(logger, level)))

    return logger


def str2bool(v):
    if isinstance(v, bool):
        return v
    if v.lower() in ('yes', 'true', 't', 'y', '1'):
        return True
    elif v.lower() in ('no', 'false', 'f', 'n', '0'):
        return False
    else:
        raise argparse.ArgumentTypeError('Boolean value expected.')


def set_seed(seed, use_cuda):
    os.environ['PYTHONHASHSEED'] = str(seed)
    np.random.seed(seed)
    random.seed(seed)
    torch.manual_seed(seed)
    # torch.backends.cudnn.deterministic = True
    if use_cuda:
        torch.cuda.manual_seed(seed)
        torch.cuda.manual_seed_all(seed)
    print(f'=> Seed of the run set to {seed}')
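
A short sketch of the helpers above (illustrative only, not part of the commit): `set_seed` fixes every RNG before a run, and `StepTimer.section` attributes wall-clock time to named parts of a training step.

```
# Sketch: seed the run, time a named section a few times, and print a summary.
import torch

set_seed(42, use_cuda=torch.cuda.is_available())

timer = StepTimer(device="cuda" if torch.cuda.is_available() else "cpu")
for _ in range(3):
    with timer.section("matmul"):
        _ = torch.randn(256, 256) @ torch.randn(256, 256)

for name, count, total, mean, p50, p95 in timer.summary(top_k=5):
    print(f"{name}: {count} calls, {total:.4f}s total, {mean * 1e3:.2f} ms mean")
```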