Spaces:
Sleeping
Sleeping
File size: 4,497 Bytes
d9df210 |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 102 103 104 105 106 107 108 109 110 111 112 113 114 115 116 117 118 119 120 121 122 123 124 125 126 127 128 129 130 131 132 133 134 135 136 137 138 |
import argparse
import datetime
import os
import sys
sys.path.insert(0, os.path.abspath(os.path.join(os.path.dirname(__file__), '..')))
from rdkit import RDLogger
import pytorch_lightning as pl
from pytorch_lightning import Trainer
from pytorch_lightning.callbacks.early_stopping import EarlyStopping
from mvp.data.data_module import ContrastiveDataModule
from mvp.definitions import TEST_RESULTS_DIR
import yaml
from mvp.data.datasets import ContrastiveDataset
from functools import partial
from mvp.utils.data import get_ms_dataset, get_spec_featurizer, get_mol_featurizer
from mvp.utils.models import get_model
# Suppress RDKit warnings and errors (noisy during molecule featurization)
lg = RDLogger.logger()
lg.setLevel(RDLogger.CRITICAL)

# CLI: the only argument is the path to the YAML hyperparameter file.
parser = argparse.ArgumentParser()
parser.add_argument("--param_pth", type=str, default="params_formSpec.yaml")
def main(params):
    """Train a contrastive spectrum/molecule model from a hyperparameter dict.

    Expected keys (read directly below): seed, debug, spectra_view,
    molecule_view, spec_enc, mask_peak_ratio, aug_cands, split_pth,
    batch_size, num_workers, model, no_wandb, experiment_dir, run_name,
    project_name, early_stopping_patience, accelerator, devices,
    max_epochs, log_every_n_steps, val_check_interval.
    """
    # Seed everything (python/numpy/torch) for reproducibility
    pl.seed_everything(params['seed'])

    # In debug mode, fall back to the small bundled sample dataset
    if params['debug']:
        params['dataset_pth'] = "../data/sample/data.tsv"
        params['candidates_pth'] = None
        params['split_pth'] = None

    # Build featurizers and the paired (spectrum, molecule) dataset
    spec_featurizer = get_spec_featurizer(params['spectra_view'], params)
    mol_featurizer = get_mol_featurizer(params['molecule_view'], params)
    dataset = get_ms_dataset(
        params['spectra_view'], params['molecule_view'],
        spec_featurizer, mol_featurizer, params,
    )

    # Bind collate options once; the data module calls this per batch
    collate_fn = partial(
        ContrastiveDataset.collate_fn,
        spec_enc=params['spec_enc'],
        spectra_view=params['spectra_view'],
        mask_peak_ratio=params['mask_peak_ratio'],
        aug_cands=params['aug_cands'],
    )
    data_module = ContrastiveDataModule(
        dataset=dataset,
        collate_fn=collate_fn,
        split_pth=params['split_pth'],
        batch_size=params['batch_size'],
        num_workers=params['num_workers'],
    )

    model = get_model(params['model'], params)

    # Init logger (Weights & Biases unless explicitly disabled)
    if params['no_wandb']:
        logger = None
    else:
        # NOTE(review): the original also passed log_dir=..., which
        # WandbLogger forwards to wandb.init(); wandb.init() has no
        # log_dir parameter and raises TypeError on it, so it is dropped.
        # 'dir' is a valid wandb.init() kwarg and is kept.
        logger = pl.loggers.WandbLogger(
            save_dir=params['experiment_dir'],
            dir=params['experiment_dir'],
            name=params['run_name'],
            project=params['project_name'],
            log_model=False,
            config=model.hparams,
        )

    # Checkpointing: one default checkpoint plus one per monitored metric
    # reported by the model; only the first monitor also keeps "last".
    callbacks = [pl.callbacks.ModelCheckpoint(save_last=False)]
    for i, monitor in enumerate(model.get_checkpoint_monitors()):
        monitor_name = monitor['monitor']
        checkpoint = pl.callbacks.ModelCheckpoint(
            monitor=monitor_name,
            save_top_k=1,
            mode=monitor['mode'],
            dirpath=params['experiment_dir'],
            # Doubled braces keep '{epoch}' / '{<metric>:.2f}' literal so
            # Lightning fills them in at save time
            filename=f'{{epoch}}-{{{monitor_name}:.2f}}',
            auto_insert_metric_name=True,
            save_last=(i == 0),
        )
        callbacks.append(checkpoint)
        if monitor.get('early_stopping', False):
            early_stopping = EarlyStopping(
                monitor=monitor_name,
                mode=monitor['mode'],
                verbose=True,
                patience=params['early_stopping_patience'],
            )
            callbacks.append(early_stopping)

    # Init trainer
    trainer = Trainer(
        accelerator=params['accelerator'],
        devices=params['devices'],
        max_epochs=params['max_epochs'],
        logger=logger,
        log_every_n_steps=params['log_every_n_steps'],
        val_check_interval=params['val_check_interval'],
        callbacks=callbacks,
        default_root_dir=params['experiment_dir'],
    )

    # Prepare the data module explicitly so validation can run pre-training
    data_module.prepare_data()
    data_module.setup()

    # Validate once before training to record an untrained baseline, then fit
    trainer.validate(model, datamodule=data_module)
    trainer.fit(model, datamodule=data_module)
if __name__ == "__main__":
    # When run interactively (no __file__, e.g. in a notebook), parse an
    # empty argv so argparse falls back to its defaults instead of crashing
    # on the interpreter's own arguments.
    args = parser.parse_args([] if "__file__" not in globals() else None)

    # Date stamp used to namespace the experiment directory
    now = datetime.datetime.now()
    now_formatted = now.strftime("%Y%m%d")

    # Load hyperparameters from the YAML file
    with open(args.param_pth) as f:
        params = yaml.load(f, Loader=yaml.FullLoader)

    experiment_dir = str(TEST_RESULTS_DIR / f"{now_formatted}_{params['run_name']}")
    params['experiment_dir'] = experiment_dir

    # Default the test-result pickle into the experiment directory when unset
    if not params['df_test_path']:
        params['df_test_path'] = os.path.join(experiment_dir, "result.pkl")

    main(params)
|