"""Run test-set evaluation of a trained model with PyTorch Lightning.

Parameters are loaded from a YAML file; results are written to an experiment
directory under TEST_RESULTS_DIR.
"""
import argparse
import datetime
import os
import sys

sys.path.insert(0, os.path.abspath(os.path.join(os.path.dirname(__file__), '..')))

from functools import partial

import pytorch_lightning as pl
import yaml
from pytorch_lightning import Trainer
from rdkit import RDLogger

from massspecgym.models.base import Stage
from mvp.data.data_module import TestDataModule
from mvp.data.datasets import ContrastiveDataset
from mvp.definitions import TEST_RESULTS_DIR
from mvp.utils.data import get_spec_featurizer, get_mol_featurizer, get_test_ms_dataset
from mvp.utils.models import get_model

# Suppress RDKit warnings and errors
lg = RDLogger.logger()
lg.setLevel(RDLogger.CRITICAL)

parser = argparse.ArgumentParser()
parser.add_argument("--param_pth", type=str, default="params_formSpec.yaml")
parser.add_argument('--checkpoint_pth', type=str, default='')
parser.add_argument('--checkpoint_choice', type=str, default='train', choices=['train', 'val'])
parser.add_argument('--df_test_pth', type=str, help='result file name')
parser.add_argument('--exp_dir', type=str)
parser.add_argument('--candidates_pth', type=str)


def main(params):
    # Seed everything
    pl.seed_everything(params['seed'])

    # Init paths to data files
    if params['debug']:
        params['dataset_pth'] = "../data/sample/data.tsv"
        params['split_pth'] = None
        params['df_test_path'] = os.path.join(params['experiment_dir'], 'debug_result.pkl')

    # Load dataset
    spec_featurizer = get_spec_featurizer(params['spectra_view'], params)
    mol_featurizer = get_mol_featurizer(params['molecule_view'], params)
    dataset = get_test_ms_dataset(
        params['spectra_view'], params['molecule_view'], spec_featurizer, mol_featurizer, params
    )

    # Init data module
    collate_fn = partial(
        ContrastiveDataset.collate_fn,
        spec_enc=params['spec_enc'],
        spectra_view=params['spectra_view'],
        stage=Stage.TEST,
    )
    data_module = TestDataModule(
        dataset=dataset,
        collate_fn=collate_fn,
        split_pth=params['split_pth'],
        batch_size=params['batch_size'],
        num_workers=params['num_workers'],
    )

    # Init model and tell it where to write the test results
    model = get_model(params['model'], params)
    model.df_test_path = params['df_test_path']

    # Init trainer
    trainer = Trainer(
        accelerator=params['accelerator'],
        devices=params['devices'],
        default_root_dir=params['experiment_dir'],
    )

    # Prepare data module for testing
    data_module.prepare_data()
    data_module.setup(stage="test")

    # Test
    trainer.test(model, datamodule=data_module)


if __name__ == "__main__":
    # Fall back to default args when there is no __file__ (e.g. interactive use)
    args = parser.parse_args([] if "__file__" not in globals() else None)

    # Load parameters
    with open(args.param_pth) as f:
        params = yaml.load(f, Loader=yaml.FullLoader)

    # Experiment directory
    if args.exp_dir:
        exp_dir = args.exp_dir
    else:
        run_name = params['run_name']
        exp_dir = None
        for exp in os.listdir(TEST_RESULTS_DIR):
            # Find an existing experiment directory with a matching run_name
            if exp.endswith("_" + run_name):
                exp_dir = str(TEST_RESULTS_DIR / exp)
                break
        if not exp_dir:
            # No match found: create a new, date-stamped experiment directory
            now = datetime.datetime.now().strftime("%Y%m%d")
            exp_dir = str(TEST_RESULTS_DIR / f"{now}_{params['run_name']}")
    os.makedirs(exp_dir, exist_ok=True)
    print("EXPERIMENT directory:", exp_dir)
    params['experiment_dir'] = exp_dir

    # Checkpoint path
    if args.checkpoint_pth:
        params['checkpoint_pth'] = args.checkpoint_pth
    if not params['checkpoint_pth']:
        print("No checkpoint provided. Using the checkpoint in the experiment directory")
        for f in os.listdir(exp_dir):
            if f.endswith("ckpt") and f.startswith("epoch") and args.checkpoint_choice in f:
                params['checkpoint_pth'] = os.path.join(exp_dir, f)
                break
    assert params['checkpoint_pth'] != '', "No checkpoint found"

    # Candidates path
    if args.candidates_pth:
        params['candidates_pth'] = args.candidates_pth

    # Test results path
    if args.df_test_pth:
        params['df_test_path'] = os.path.join(exp_dir, args.df_test_pth)
    if not params['df_test_path']:
        params['df_test_path'] = os.path.join(
            exp_dir, f"result_{params['candidates_pth'].split('/')[-1].split('.')[0]}.pkl"
        )

    main(params)
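
# Example invocation (a sketch; the script name, candidates file and result
# filename are illustrative placeholders, not paths shipped with the repo):
#
#   python test.py \
#       --param_pth params_formSpec.yaml \
#       --checkpoint_choice val \
#       --candidates_pth data/candidates.tsv \
#       --df_test_pth result_val.pkl
#
# The YAML passed via --param_pth is expected to define the keys read above
# (run_name, seed, debug, spectra_view, molecule_view, spec_enc, model,
# batch_size, num_workers, accelerator, devices, split_pth, checkpoint_pth,
# candidates_pth, df_test_path); concrete values are project specific.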