File size: 4,309 Bytes
d9df210
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
import argparse
import datetime
import os
import sys
from functools import partial

# Make the repository root importable (for the local `mvp` package) when this
# script is run directly. This uses `os`, so it must follow the stdlib imports
# and precede the third-party/local imports below.
sys.path.insert(0, os.path.abspath(os.path.join(os.path.dirname(__file__), '..')))

from rdkit import RDLogger
import pytorch_lightning as pl
from pytorch_lightning import Trainer
from massspecgym.models.base import Stage
import yaml

from mvp.data.data_module import TestDataModule
from mvp.data.datasets import ContrastiveDataset
from mvp.utils.data import get_spec_featurizer, get_mol_featurizer, get_test_ms_dataset
from mvp.utils.models import get_model
from mvp.definitions import TEST_RESULTS_DIR
# Suppress RDKit warnings and errors: raise RDKit's logger threshold to
# CRITICAL for the whole process. `lg` is kept as a module-level handle.
lg = RDLogger.logger()
lg.setLevel(RDLogger.CRITICAL)

# Command-line options. Each one overrides or supplements the matching entry
# of the YAML parameter file given by --param_pth (see the __main__ block).
parser = argparse.ArgumentParser(
    description="Test a trained model checkpoint on the test split.",
)
parser.add_argument(
    "--param_pth", type=str, default="params_formSpec.yaml",
    help="YAML file with the experiment parameters.",
)
parser.add_argument(
    '--checkpoint_pth', type=str, default='',
    help="Checkpoint to evaluate. If empty (and not set in the YAML file), "
         "a checkpoint is searched for in the experiment directory.",
)
parser.add_argument(
    '--checkpoint_choice', type=str, default='train', choices=['train', 'val'],
    help="When searching the experiment directory, use the 'epoch*...ckpt' "
         "file whose name contains this tag.",
)
parser.add_argument(
    '--df_test_pth', type=str,
    help="Result file name, placed inside the experiment directory. Unless "
         "set here or in the YAML file, result_<candidates file stem>.pkl is used.",
)
parser.add_argument(
    '--exp_dir', type=str,
    help="Experiment directory. Defaults to the directory in TEST_RESULTS_DIR "
         "whose name ends with '_<run_name>', or a new dated one.",
)
parser.add_argument(
    '--candidates_pth', type=str,
    help="Overrides 'candidates_pth' from the parameter file.",
)


def main(params):
    """Evaluate a trained model on the test split.

    Args:
        params: parameter dict read from the YAML file and completed by the
            ``__main__`` block (``experiment_dir``, ``checkpoint_pth``,
            ``df_test_path``, ...). It is modified in place when
            ``params['debug']`` is truthy.

    Builds the spectrum/molecule featurizers, the test dataset and its data
    module, and the model, then runs ``Trainer.test`` on them.
    """
    pl.seed_everything(params['seed'])  # reproducible run

    # Debug runs read the sample data file (path relative to the working
    # directory), use no split file and write to a dedicated result file.
    if params['debug']:
        params['dataset_pth'] = "../data/sample/data.tsv"
        params['split_pth'] = None
        params['df_test_path'] = os.path.join(params['experiment_dir'], 'debug_result.pkl')

    # One featurizer per view, then the test dataset built from both.
    featurize_spec = get_spec_featurizer(params['spectra_view'], params)
    featurize_mol = get_mol_featurizer(params['molecule_view'], params)
    test_set = get_test_ms_dataset(
        params['spectra_view'],
        params['molecule_view'],
        featurize_spec,
        featurize_mol,
        params,
    )

    # Collation is pinned to the TEST stage.
    batch_collate = partial(
        ContrastiveDataset.collate_fn,
        spec_enc=params['spec_enc'],
        spectra_view=params['spectra_view'],
        stage=Stage.TEST,
    )
    loader_opts = {key: params[key] for key in ('split_pth', 'batch_size', 'num_workers')}
    dm = TestDataModule(dataset=test_set, collate_fn=batch_collate, **loader_opts)

    net = get_model(params['model'], params)
    # NOTE(review): presumably read by the model's test hooks as the result
    # file destination — confirm in get_model's model classes.
    net.df_test_path = params['df_test_path']

    runner = Trainer(
        accelerator=params['accelerator'],
        devices=params['devices'],
        default_root_dir=params['experiment_dir'],
    )

    # Prepare and set up the data explicitly, then run the test loop.
    dm.prepare_data()
    dm.setup(stage="test")
    runner.test(net, datamodule=dm)


def resolve_exp_dir(exp_dir, run_name, results_dir):
    """Return the experiment directory for this test run.

    Resolution order:
      1. ``exp_dir`` itself when non-empty (used as given, not created);
      2. an existing directory in ``results_dir`` whose name ends with
         ``"_" + run_name``. If several match, the lexicographically greatest
         name wins, which for ``YYYYMMDD_<run_name>`` names is the latest;
      3. a new ``<results_dir>/<YYYYMMDD>_<run_name>`` directory, created here
         (along with ``results_dir`` if needed).

    Raises:
        ValueError: if ``exp_dir`` is empty and ``run_name`` is not set.
    """
    if exp_dir:
        return exp_dir
    if not run_name:
        raise ValueError("run_name is required when no experiment directory is given")

    suffix = "_" + run_name
    # A missing results directory just means there is nothing to reuse yet.
    entries = os.listdir(results_dir) if os.path.isdir(results_dir) else []
    # Sorted so the pick does not depend on arbitrary os.listdir order.
    matches = sorted(
        (name for name in entries
         if name.endswith(suffix) and os.path.isdir(os.path.join(results_dir, name))),
        reverse=True,
    )
    if matches:
        return os.path.join(results_dir, matches[0])

    today = datetime.datetime.now().strftime("%Y%m%d")
    new_dir = os.path.join(results_dir, f"{today}_{run_name}")
    os.makedirs(new_dir, exist_ok=True)
    return new_dir


def find_checkpoint(exp_dir, choice):
    """Return the path of a checkpoint in ``exp_dir`` tagged with ``choice``.

    A candidate is a file name that starts with ``"epoch"``, ends with
    ``"ckpt"`` and contains ``choice`` (e.g. ``"train"`` or ``"val"``). Names
    are scanned in sorted order so the pick is reproducible when several match.
    Returns ``None`` if ``exp_dir`` does not exist or nothing matches.
    """
    if not os.path.isdir(exp_dir):
        return None
    for name in sorted(os.listdir(exp_dir)):
        if name.startswith("epoch") and name.endswith("ckpt") and choice in name:
            return os.path.join(exp_dir, name)
    return None


def default_result_path(exp_dir, candidates_pth):
    """Return ``<exp_dir>/result_<stem>.pkl``, where ``<stem>`` is the base name
    of ``candidates_pth`` up to its first ``'.'``."""
    stem = os.path.basename(candidates_pth).split('.')[0]
    return os.path.join(exp_dir, f"result_{stem}.pkl")


if __name__ == "__main__":
    # Without a __file__ (e.g. interactive use) parse no CLI arguments at all.
    args = parser.parse_args([] if "__file__" not in globals() else None)

    # Load parameters.
    # NOTE: yaml.FullLoader can construct some Python objects; only use it on
    # trusted, local parameter files (yaml.safe_load is the safer choice).
    with open(args.param_pth, encoding="utf-8") as f:
        params = yaml.load(f, Loader=yaml.FullLoader)

    # Experiment directory
    exp_dir = resolve_exp_dir(args.exp_dir, params.get('run_name'), TEST_RESULTS_DIR)
    print("EXPERIMENT directory: ", exp_dir)
    params['experiment_dir'] = exp_dir

    # Checkpoint: the CLI overrides the YAML; otherwise look inside exp_dir.
    if args.checkpoint_pth:
        params['checkpoint_pth'] = args.checkpoint_pth
    if not params.get('checkpoint_pth'):
        print("No checkpoint provided. Using the checkpoint in the experiment directory")
        params['checkpoint_pth'] = find_checkpoint(exp_dir, args.checkpoint_choice)
    if not params.get('checkpoint_pth'):
        # Explicit exit rather than `assert`, which `python -O` strips.
        parser.error(
            f"no checkpoint given and no 'epoch*...ckpt' file containing "
            f"'{args.checkpoint_choice}' found in {exp_dir}"
        )

    # Candidates and result file location.
    if args.candidates_pth:
        params['candidates_pth'] = args.candidates_pth
    if args.df_test_pth:
        params['df_test_path'] = os.path.join(exp_dir, args.df_test_pth)
    if not params.get('df_test_path'):
        if not params.get('candidates_pth'):
            parser.error("cannot name the result file: set df_test_path/--df_test_pth or candidates_pth")
        params['df_test_path'] = default_result_path(exp_dir, params['candidates_pth'])

    main(params)