# MVP / mvp/test.py
# (Repository page header: uploaded by yzhouchen001, commit d9df210 "model code", 4.31 kB.)
import argparse
import datetime
import os
import sys
from functools import partial

# Put the parent of this script's directory first on sys.path so the local
# `mvp` package can be imported when the script is run directly. This must run
# after `import os` (it uses os.path) and before the project imports below.
sys.path.insert(0, os.path.abspath(os.path.join(os.path.dirname(__file__), '..')))

import pytorch_lightning as pl
import yaml
from pytorch_lightning import Trainer
from rdkit import RDLogger
from massspecgym.models.base import Stage

from mvp.data.data_module import TestDataModule
from mvp.data.datasets import ContrastiveDataset
from mvp.definitions import TEST_RESULTS_DIR
from mvp.utils.data import get_spec_featurizer, get_mol_featurizer, get_test_ms_dataset
from mvp.utils.models import get_model
# RDKit output threshold: messages below CRITICAL (i.e. warnings and errors)
# are suppressed so they do not flood the console during evaluation.
_RDKIT_MIN_LEVEL = RDLogger.CRITICAL
lg = RDLogger.logger()
lg.setLevel(_RDKIT_MIN_LEVEL)
def _build_arg_parser():
    """Create the command-line parser for this evaluation script.

    Registered options, in order:
      --param_pth          YAML config file (default "params_formSpec.yaml").
      --checkpoint_pth     checkpoint file; the empty-string default means "not given".
      --checkpoint_choice  "train" (default) or "val".
      --df_test_pth        result file name.
      --exp_dir            experiment directory override (default None).
      --candidates_pth     candidates file override (default None).
    """
    cli = argparse.ArgumentParser()
    cli.add_argument("--param_pth", type=str, default="params_formSpec.yaml")
    cli.add_argument("--checkpoint_pth", type=str, default="")
    cli.add_argument(
        "--checkpoint_choice", type=str, default="train", choices=["train", "val"]
    )
    cli.add_argument("--df_test_pth", type=str, help="result file name")
    # The two path overrides take no default and no help text.
    for flag in ("--exp_dir", "--candidates_pth"):
        cli.add_argument(flag, type=str)
    return cli


parser = _build_arg_parser()
def _apply_debug_overrides(params):
    """Redirect ``params`` to the sample dataset for a quick debug run (in place).

    Sets ``dataset_pth`` to a relative sample path (presumably resolved
    against the current working directory — confirm), disables the split
    file, and writes results to ``debug_result.pkl`` in the experiment dir.
    """
    params['dataset_pth'] = "../data/sample/data.tsv"
    params['split_pth'] = None
    params['df_test_path'] = os.path.join(params['experiment_dir'], 'debug_result.pkl')


def _build_test_data_module(params):
    """Featurize the test data described by ``params`` and wrap it in a TestDataModule."""
    spec_featurizer = get_spec_featurizer(params['spectra_view'], params)
    mol_featurizer = get_mol_featurizer(params['molecule_view'], params)
    test_dataset = get_test_ms_dataset(
        params['spectra_view'],
        params['molecule_view'],
        spec_featurizer,
        mol_featurizer,
        params,
    )
    # Bind the test-time collate options up front so the loader calls collate_fn(batch).
    batch_collate = partial(
        ContrastiveDataset.collate_fn,
        spec_enc=params['spec_enc'],
        spectra_view=params['spectra_view'],
        stage=Stage.TEST,
    )
    return TestDataModule(
        dataset=test_dataset,
        collate_fn=batch_collate,
        split_pth=params['split_pth'],
        batch_size=params['batch_size'],
        num_workers=params['num_workers'],
    )


def main(params):
    """Run the Lightning test loop for the configuration in ``params``.

    ``params`` is the loaded YAML config dict. The ``__main__`` section of this
    script adds ``experiment_dir``, ``checkpoint_pth`` and ``df_test_path`` to it.
    It is mutated in place when ``params['debug']`` is truthy.

    Steps: seed with ``pl.seed_everything``, build the test data module,
    build the model with ``get_model``, then call ``Trainer.test``.

    NOTE(review): ``checkpoint_pth`` is not read here; presumably ``get_model``
    loads weights from ``params`` — confirm.
    """
    pl.seed_everything(params['seed'])

    if params['debug']:
        _apply_debug_overrides(params)

    data_module = _build_test_data_module(params)

    model = get_model(params['model'], params)
    # Presumably where the model stores its test results — confirm in the model class.
    model.df_test_path = params['df_test_path']

    trainer = Trainer(
        accelerator=params['accelerator'],
        devices=params['devices'],
        default_root_dir=params['experiment_dir'],
    )

    data_module.prepare_data()
    data_module.setup(stage="test")
    trainer.test(model, datamodule=data_module)
def find_experiment_dir(results_dir, run_name):
    """Locate (or name) the experiment directory for ``run_name``.

    Args:
        results_dir: directory (``str`` or ``os.PathLike``) holding experiment
            folders. It may not exist yet.
        run_name: run name from the YAML config.

    Returns:
        The first entry of ``results_dir`` (in sorted order) whose name ends
        with ``"_" + run_name``. Otherwise, a new path ``<YYYYMMDD>_<run_name>``
        inside ``results_dir``. The directory is NOT created here.
    """
    suffix = "_" + run_name
    # A missing results dir just means there is nothing to reuse yet.
    if os.path.isdir(results_dir):
        # os.listdir order is filesystem-dependent; sort so reruns pick the same dir.
        for entry in sorted(os.listdir(results_dir)):
            if entry.endswith(suffix):
                return os.path.join(results_dir, entry)
    today = datetime.datetime.now().strftime("%Y%m%d")
    return os.path.join(results_dir, f"{today}_{run_name}")


def find_checkpoint(exp_dir, choice):
    """Return the path of a checkpoint in ``exp_dir``, or None if there is none.

    A file qualifies when its name starts with ``"epoch"``, ends with
    ``"ckpt"`` and contains ``choice`` (e.g. ``"train"`` or ``"val"``). The
    first qualifying name in sorted order is returned, so the result is
    reproducible.
    """
    for fname in sorted(os.listdir(exp_dir)):
        if fname.endswith("ckpt") and fname.startswith("epoch") and choice in fname:
            return os.path.join(exp_dir, fname)
    return None


def default_result_filename(candidates_pth):
    """Build ``result_<stem>.pkl`` from a candidates file path.

    ``<stem>`` is the last ``/``-separated component cut at its first ``.``,
    e.g. ``"data/cands.json"`` -> ``"result_cands.pkl"``.
    """
    return f"result_{candidates_pth.split('/')[-1].split('.')[0]}.pkl"


if __name__ == "__main__":
    # Without __file__ (e.g. inside a notebook) parse no CLI arguments.
    args = parser.parse_args([] if "__file__" not in globals() else None)

    # Load the YAML run configuration.
    with open(args.param_pth) as f:
        params = yaml.load(f, Loader=yaml.FullLoader)

    # Experiment directory: --exp_dir wins; otherwise reuse or create one for
    # run_name. Both paths bind exp_dir before it is used.
    exp_dir = args.exp_dir or find_experiment_dir(TEST_RESULTS_DIR, params['run_name'])
    os.makedirs(exp_dir, exist_ok=True)
    print("EXPERIMENT directory: ",exp_dir)
    params['experiment_dir'] = exp_dir

    # Checkpoint priority: --checkpoint_pth > YAML checkpoint_pth > discovery in exp_dir.
    if args.checkpoint_pth:
        params['checkpoint_pth'] = args.checkpoint_pth
    if not params.get('checkpoint_pth'):
        print("No checkpoint provided. Using the checkpoint in the experiment directory")
        params['checkpoint_pth'] = find_checkpoint(exp_dir, args.checkpoint_choice)
    # Explicit check rather than `assert`: asserts are stripped under -O, and the
    # old `!= ''` test let a None checkpoint (empty YAML value) slip through.
    if not params['checkpoint_pth']:
        raise FileNotFoundError(
            f"No checkpoint given and no 'epoch*{args.checkpoint_choice}*ckpt' "
            f"file found in {exp_dir}"
        )

    if args.candidates_pth:
        params['candidates_pth'] = args.candidates_pth
    if args.df_test_pth:
        params['df_test_path'] = os.path.join(exp_dir, args.df_test_pth)
    if not params.get('df_test_path'):
        params['df_test_path'] = os.path.join(
            exp_dir, default_result_filename(params['candidates_pth'])
        )

    main(params)