# FLARE / app_utils/model_utils.py
# Author: yzhouchen001 -- last commit 19a4dfc ("update")
import sys
# sys.path.insert(0, "/data/yzhouc01/MassSpecGym")
# sys.path.insert(0, "/data/yzhouc01/FILIP-MS")
from rdkit import RDLogger
from flare.utils.data import get_spec_featurizer, get_mol_featurizer, get_ms_dataset
from flare.utils.models import get_model
import yaml
# Silence RDKit: any RDKit message below CRITICAL severity (warnings,
# errors) is suppressed once this module is imported.
_RDKIT_LOG_THRESHOLD = RDLogger.CRITICAL
lg = RDLogger.logger()
lg.setLevel(_RDKIT_LOG_THRESHOLD)
# Load the spectrum/molecule featurizers and the trained model
def load_model_components(param_pth=None, checkpoint_pth=None):
    """Build the spectrum featurizer, molecule featurizer and model.

    Hyperparameters are read from the YAML file at ``param_pth``. The
    checkpoint path is written into them under ``'checkpoint_pth'`` before
    ``get_model`` is called, presumably so the model is restored from that
    checkpoint (TODO confirm in ``flare.utils.models.get_model``).

    Args:
        param_pth: Path to the run's ``hparams.yaml``. When ``None`` (the
            default) the original hard-coded experiment path is used.
        checkpoint_pth: Path to the model checkpoint. When ``None`` (the
            default) the original hard-coded checkpoint path is used.

    Returns:
        Tuple ``(spec_featurizer, mol_featurizer, model)``.

    Raises:
        OSError: If ``param_pth`` cannot be opened.
        KeyError: If the hyperparameters lack ``'spectra_view'``,
            ``'molecule_view'`` or ``'model'``.
    """
    # Defaults reproduce the paths that used to be hard-coded, so existing
    # zero-argument callers behave exactly as before.
    experiment_dir = '/data/yzhouc01/FILIP-MS/experiments/20250913_optimized_filip-model'
    if param_pth is None:
        param_pth = f'{experiment_dir}/lightning_logs/version_0/hparams.yaml'
    if checkpoint_pth is None:
        checkpoint_pth = f'{experiment_dir}/epoch=1993-train_loss=0.10.ckpt'

    with open(param_pth, encoding='utf-8') as f:
        # FullLoader is kept rather than safe_load because safe_load rejects
        # Python-specific tags (e.g. !!python/tuple) that saved hparams files
        # may contain. Only load trusted files with it.
        params = yaml.load(f, Loader=yaml.FullLoader)

    spec_featurizer = get_spec_featurizer(params['spectra_view'], params)
    mol_featurizer = get_mol_featurizer(params['molecule_view'], params)

    # The model factory reads the checkpoint location from the params dict.
    params['checkpoint_pth'] = checkpoint_pth
    model = get_model(params['model'], params)
    return spec_featurizer, mol_featurizer, model