import argparse
import datetime
import os
import sys
from functools import partial

# Make the project root importable so that the `mvp` package can be found
sys.path.insert(0, os.path.abspath(os.path.join(os.path.dirname(__file__), '..')))

import pytorch_lightning as pl
import yaml
from pytorch_lightning import Trainer
from pytorch_lightning.callbacks.early_stopping import EarlyStopping
from rdkit import RDLogger

from mvp.data.data_module import ContrastiveDataModule
from mvp.data.datasets import ContrastiveDataset
from mvp.definitions import TEST_RESULTS_DIR
from mvp.utils.data import get_ms_dataset, get_spec_featurizer, get_mol_featurizer
from mvp.utils.models import get_model

# Suppress RDKit warnings and errors
lg = RDLogger.logger()
lg.setLevel(RDLogger.CRITICAL)

parser = argparse.ArgumentParser()
parser.add_argument("--param_pth", type=str, default="params_formSpec.yaml")
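# Example invocation (the script name is an assumption, not given in the file):
#   python train.py --param_pth params_formSpec.yaml
# The YAML config supplies every hyperparameter read from `params` below
# (seed, dataset paths, featurizer/model names, trainer settings, ...).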
def main(params):
    # Seed everything
    pl.seed_everything(params['seed'])

    # Init paths to data files
    if params['debug']:
        params['dataset_pth'] = "../data/sample/data.tsv"
        params['candidates_pth'] = None
        params['split_pth'] = None

    # Load dataset
    spec_featurizer = get_spec_featurizer(params['spectra_view'], params)
    mol_featurizer = get_mol_featurizer(params['molecule_view'], params)
    dataset = get_ms_dataset(params['spectra_view'], params['molecule_view'], spec_featurizer, mol_featurizer, params)

    # Init data module
    collate_fn = partial(
        ContrastiveDataset.collate_fn,
        spec_enc=params['spec_enc'],
        spectra_view=params['spectra_view'],
        mask_peak_ratio=params['mask_peak_ratio'],
        aug_cands=params['aug_cands'],
    )
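    # The partial above pre-binds the dataset/augmentation options (spec_enc,
    # spectra_view, mask_peak_ratio, aug_cands) so the DataLoaders only need to
    # hand each batch to collate_fn.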
    data_module = ContrastiveDataModule(
        dataset=dataset,
        collate_fn=collate_fn,
        split_pth=params['split_pth'],
        batch_size=params['batch_size'],
        num_workers=params['num_workers'],
    )
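    # Build the model named in the config; get_model dispatches on params['model'],
    # and the resulting model.hparams are logged to wandb below.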
    model = get_model(params['model'], params)
    # Init logger
    if params['no_wandb']:
        logger = None
    else:
        logger = pl.loggers.WandbLogger(
            save_dir=params['experiment_dir'],
            name=params['run_name'],
            project=params['project_name'],
            log_model=False,
            config=model.hparams,
        )
    # Init callbacks for checkpointing and early stopping
    callbacks = [pl.callbacks.ModelCheckpoint(save_last=False)]
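    # The monitor-less ModelCheckpoint above keeps only the most recent epoch's
    # checkpoint; the loop below adds one best-metric checkpoint (and optional
    # early stopping) per quantity the model asks to monitor.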
    for i, monitor in enumerate(model.get_checkpoint_monitors()):
        monitor_name = monitor['monitor']
        checkpoint = pl.callbacks.ModelCheckpoint(
            monitor=monitor_name,
            save_top_k=1,
            mode=monitor['mode'],
            dirpath=params['experiment_dir'],
            filename=f'{{epoch}}-{{{monitor_name}:.2f}}',
            # filename='{epoch}-{val_loss:.2f}-{train_loss:.2f}',
            auto_insert_metric_name=True,
            save_last=(i == 0),
        )
        callbacks.append(checkpoint)
        if monitor.get('early_stopping', False):
            early_stopping = EarlyStopping(
                monitor=monitor_name,
                mode=monitor['mode'],
                verbose=True,
                patience=params['early_stopping_patience'],
            )
            callbacks.append(early_stopping)
    # Init trainer
    trainer = Trainer(
        accelerator=params['accelerator'],
        devices=params['devices'],
        max_epochs=params['max_epochs'],
        logger=logger,
        log_every_n_steps=params['log_every_n_steps'],
        val_check_interval=params['val_check_interval'],
        callbacks=callbacks,
        default_root_dir=params['experiment_dir'],
    )
    # Prepare the data module manually so it can be validated before training
    data_module.prepare_data()
    data_module.setup()

    # Validate before training
    trainer.validate(model, datamodule=data_module)

    # Train
    trainer.fit(model, datamodule=data_module)
if __name__ == "__main__":
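    # When run interactively (no __file__ defined, e.g. in a notebook), fall back
    # to the argparse defaults instead of parsing sys.argv.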
    args = parser.parse_args([] if "__file__" not in globals() else None)

    # Get the current date for naming the experiment directory
    now = datetime.datetime.now()
    now_formatted = now.strftime("%Y%m%d")
    # Load hyperparameters from the YAML config
    with open(args.param_pth) as f:
        params = yaml.load(f, Loader=yaml.FullLoader)
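    # Experiment outputs are written under TEST_RESULTS_DIR, prefixed with today's date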
    experiment_dir = str(TEST_RESULTS_DIR / f"{now_formatted}_{params['run_name']}")
    params['experiment_dir'] = experiment_dir
    if not params['df_test_path']:
        params['df_test_path'] = os.path.join(experiment_dir, "result.pkl")

    main(params)