File size: 4,497 Bytes
d9df210
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
import argparse
import datetime

import os
import sys
sys.path.insert(0, os.path.abspath(os.path.join(os.path.dirname(__file__), '..')))

from rdkit import RDLogger
import pytorch_lightning as pl
from pytorch_lightning import Trainer
from pytorch_lightning.callbacks.early_stopping import EarlyStopping


from mvp.data.data_module import ContrastiveDataModule

from mvp.definitions import TEST_RESULTS_DIR
import yaml
from mvp.data.datasets import ContrastiveDataset
from functools import partial

from mvp.utils.data import get_ms_dataset, get_spec_featurizer, get_mol_featurizer
from mvp.utils.models import get_model
# RDKit logging: raise the threshold of RDKit's global logger to CRITICAL so that
# its warnings and (non-critical) error messages are not printed.
lg = RDLogger.logger()
lg.setLevel(RDLogger.CRITICAL)

# Command-line interface: a single option naming the YAML parameter file.
parser = argparse.ArgumentParser()
parser.add_argument(
    "--param_pth",
    default="params_formSpec.yaml",
    type=str,
)

def _debug_dataset_path():
    """Return the path of the sample dataset used when ``params['debug']`` is set.

    The sample lives at ``../data/sample/data.tsv`` relative to this script's
    directory. Resolving it against the script location (as the ``sys.path``
    setup at the top of the file already does) makes debug runs independent of
    the current working directory. When ``__file__`` is unavailable (e.g. an
    interactive session), the original CWD-relative path is returned unchanged.
    """
    rel_path = "../data/sample/data.tsv"
    if "__file__" not in globals():
        return rel_path
    script_dir = os.path.dirname(os.path.abspath(__file__))
    return os.path.normpath(os.path.join(script_dir, rel_path))


def _build_logger(model, params):
    """Return a ``WandbLogger`` for this run, or ``None`` if ``params['no_wandb']`` is truthy.

    The logger writes under ``params['experiment_dir']``, uses ``params['run_name']``
    and ``params['project_name']``, and records ``model.hparams`` as the run config.
    """
    if params['no_wandb']:
        return None
    # NOTE(review): ``dir`` duplicates ``save_dir``. ``log_dir`` is not a named
    # WandbLogger argument, so it is forwarded to ``wandb.init`` via **kwargs.
    # Both are kept as-is; confirm the installed wandb version accepts ``log_dir``.
    return pl.loggers.WandbLogger(
        save_dir=params['experiment_dir'],
        dir=params['experiment_dir'],
        log_dir=params['experiment_dir'],
        name=params['run_name'],
        project=params['project_name'],
        log_model=False,
        config=model.hparams,
    )


def _build_callbacks(model, params):
    """Build the checkpointing and early-stopping callbacks for the Trainer.

    The list always starts with a default (unmonitored) ``ModelCheckpoint``.
    ``model.get_checkpoint_monitors()`` yields dicts with ``'monitor'``, ``'mode'``
    and an optional ``'early_stopping'`` flag. For each of them, a top-1
    ``ModelCheckpoint`` is added that saves into ``params['experiment_dir']``.
    If the flag is set, an ``EarlyStopping`` callback is also added, using
    ``params['early_stopping_patience']``. Only the first monitor's checkpoint
    also saves the last checkpoint (``save_last``).
    """
    callbacks = [pl.callbacks.ModelCheckpoint(save_last=False)]
    for i, monitor in enumerate(model.get_checkpoint_monitors()):
        monitor_name = monitor['monitor']
        callbacks.append(
            pl.callbacks.ModelCheckpoint(
                monitor=monitor_name,
                save_top_k=1,
                mode=monitor['mode'],
                dirpath=params['experiment_dir'],
                # Renders to a Lightning template such as "{epoch}-{val_loss:.2f}".
                filename=f'{{epoch}}-{{{monitor_name}:.2f}}',
                auto_insert_metric_name=True,
                save_last=(i == 0),
            )
        )
        if monitor.get('early_stopping', False):
            callbacks.append(
                EarlyStopping(
                    monitor=monitor_name,
                    mode=monitor['mode'],
                    verbose=True,
                    patience=params['early_stopping_patience'],
                )
            )
    return callbacks


def main(params):
    """Train a contrastive spectrum/molecule model configured by ``params``.

    Seeds all RNGs and builds the featurizers, dataset, data module, model,
    logger and callbacks. It then runs one validation pass on the untrained
    model before calling ``trainer.fit``.

    Args:
        params: Parameter dict loaded from the YAML config, plus the
            ``experiment_dir`` key added by the caller. When ``params['debug']``
            is truthy, the dict is modified in place to point at the bundled
            sample dataset, with no candidates file and no split file.
    """
    pl.seed_everything(params['seed'])

    if params.get('debug', False):
        params['dataset_pth'] = _debug_dataset_path()
        params['candidates_pth'] = None
        params['split_pth'] = None

    # Featurizers and dataset for the configured spectrum / molecule views.
    spec_featurizer = get_spec_featurizer(params['spectra_view'], params)
    mol_featurizer = get_mol_featurizer(params['molecule_view'], params)
    dataset = get_ms_dataset(
        params['spectra_view'],
        params['molecule_view'],
        spec_featurizer,
        mol_featurizer,
        params,
    )

    # Bind the per-run collate options once so the DataLoaders can call it with a batch only.
    collate_fn = partial(
        ContrastiveDataset.collate_fn,
        spec_enc=params['spec_enc'],
        spectra_view=params['spectra_view'],
        mask_peak_ratio=params['mask_peak_ratio'],
        aug_cands=params['aug_cands'],
    )
    data_module = ContrastiveDataModule(
        dataset=dataset,
        collate_fn=collate_fn,
        split_pth=params['split_pth'],
        batch_size=params['batch_size'],
        num_workers=params['num_workers'],
    )

    model = get_model(params['model'], params)
    logger = _build_logger(model, params)
    callbacks = _build_callbacks(model, params)

    trainer = Trainer(
        accelerator=params['accelerator'],
        devices=params['devices'],
        max_epochs=params['max_epochs'],
        logger=logger,
        log_every_n_steps=params['log_every_n_steps'],
        val_check_interval=params['val_check_interval'],
        callbacks=callbacks,
        default_root_dir=params['experiment_dir'],
    )

    # Prepare the data explicitly so validation can run before training.
    # NOTE(review): Trainer.validate/fit also invoke setup() for their own stage;
    # confirm ContrastiveDataModule.setup is safe to call repeatedly.
    data_module.prepare_data()
    data_module.setup()

    # Baseline validation metrics on the untrained model, then train.
    trainer.validate(model, datamodule=data_module)
    trainer.fit(model, datamodule=data_module)
        


if __name__ == "__main__":
    # Without __file__ (interactive session) parse an empty argv so the defaults apply.
    args = parser.parse_args([] if "__file__" not in globals() else None)

    # Date stamp (YYYYMMDD) that prefixes the experiment directory name.
    now = datetime.datetime.now()
    now_formatted = now.strftime("%Y%m%d")

    # Load run parameters from the YAML config; YAML is read as UTF-8 rather than
    # the platform's locale encoding.
    # NOTE(review): FullLoader can construct some Python-specific YAML tags; only
    # load config files from trusted sources (yaml.safe_load if none are needed).
    with open(args.param_pth, encoding="utf-8") as f:
        params = yaml.load(f, Loader=yaml.FullLoader)

    experiment_dir = str(TEST_RESULTS_DIR / f"{now_formatted}_{params['run_name']}")
    params['experiment_dir'] = experiment_dir

    # Default the test-results pickle into the experiment directory when the key
    # is missing or empty.
    if not params.get('df_test_path'):
        params['df_test_path'] = os.path.join(experiment_dir, "result.pkl")

    main(params)