Spaces:
Sleeping
Sleeping
| from mvp.models.spec_encoder import SpecEncMLP_BIN, SpecFormulaTransformer | |
| from mvp.models.mol_encoder import MolEnc | |
| from mvp.models.encoders import MLP | |
| from mvp.models.contrastive import ContrastiveModel, MultiViewContrastive | |
def get_spec_encoder(spec_enc:str, args):
    """Instantiate the spec encoder class selected by ``spec_enc``.

    Args:
        spec_enc: Registry key; ``"MLP_BIN"`` or ``"Transformer_Formula"``.
        args: Configuration object handed unchanged to the encoder constructor.

    Returns:
        A new instance of the selected encoder class.

    Raises:
        KeyError: If ``spec_enc`` is not one of the registered keys.
    """
    registry = {
        "MLP_BIN": SpecEncMLP_BIN,
        "Transformer_Formula": SpecFormulaTransformer,
    }
    encoder_cls = registry[spec_enc]
    return encoder_cls(args)
def get_mol_encoder(mol_enc: str, args):
    """Instantiate the molecule encoder class selected by ``mol_enc``.

    Args:
        mol_enc: Registry key; only ``'GNN'`` is registered.
        args: Configuration object handed unchanged to the encoder constructor.

    Returns:
        A new instance of the selected encoder class.

    Raises:
        KeyError: If ``mol_enc`` is not a registered key.
    """
    registry = {'GNN': MolEnc}
    encoder_cls = registry[mol_enc]
    # in_dim is hard-coded to 78 for every encoder built here.
    # NOTE(review): presumably the input feature width -- confirm against MolEnc.
    return encoder_cls(args, in_dim=78)
def get_fp_pred_model(args):
    """Build the MLP that maps ``final_embedding_dim`` inputs to ``fp_size``.

    Args:
        args: Configuration object providing ``final_embedding_dim``,
            ``fp_size`` and ``fp_dropout``.

    Returns:
        An ``MLP`` with a single hidden-dims entry of ``args.fp_size``, a
        ``'sigmoid'`` final activation and dropout ``args.fp_dropout``.
    """
    embedding_dim = args.final_embedding_dim
    layer_dims = [args.fp_size]
    dropout_rate = args.fp_dropout
    return MLP(
        in_dim=embedding_dim,
        hidden_dims=layer_dims,
        final_activation='sigmoid',
        dropout=dropout_rate,
    )
def get_fp_enc_model(args):
    """Build the MLP that maps ``fp_size`` inputs to ``final_embedding_dim``.

    Args:
        args: Configuration object providing ``fp_size`` and
            ``final_embedding_dim``.

    Returns:
        An ``MLP`` whose hidden dims are ``[d, 2 * d, d]`` with
        ``d = args.final_embedding_dim``, no final activation and zero dropout.
    """
    input_dim = args.fp_size
    emb_dim = args.final_embedding_dim
    # Widen to twice the embedding size in the middle, then come back down.
    layer_dims = [emb_dim, emb_dim * 2, emb_dim]
    return MLP(
        in_dim=input_dim,
        hidden_dims=layer_dims,
        final_activation=None,
        dropout=0.0,
    )
def get_model(model:str,
              params):
    """Build the requested model, or restore it from a checkpoint.

    The model class is resolved first, so that when a checkpoint path is
    supplied the model is loaded directly instead of first constructing a
    throwaway instance only to discover its type.

    Args:
        model: Model identifier; ``'contrastive'`` or
            ``"MultiviewContrastive"``.
        params: Mapping of keyword arguments for the model constructor. When
            it holds a ``'checkpoint_pth'`` that is neither ``None`` nor
            ``""``, the model is loaded from that path instead, and
            ``params['log_only_loss_at_stages']`` and
            ``params['df_test_path']`` are forwarded to
            ``load_from_checkpoint``.

    Returns:
        The newly constructed or checkpoint-restored model instance.

    Raises:
        ValueError: If ``model`` is not a supported identifier.
        KeyError: If a checkpoint is requested but ``params`` lacks
            ``'log_only_loss_at_stages'`` or ``'df_test_path'``.
    """
    # Resolve the class without instantiating it; an if/elif chain keeps the
    # lookup lazy so only the selected class name is ever evaluated.
    if model == 'contrastive':
        model_cls = ContrastiveModel
    elif model == "MultiviewContrastive":
        model_cls = MultiViewContrastive
    else:
        raise ValueError(f"Model {model} not implemented.")

    # A missing key is treated the same as "no checkpoint requested".
    checkpoint_pth = params.get('checkpoint_pth')
    if checkpoint_pth is None or checkpoint_pth == "":
        return model_cls(**params)

    # If checkpoint path is provided, load the model from the checkpoint instead
    restored = model_cls.load_from_checkpoint(
        checkpoint_pth,
        log_only_loss_at_stages=params['log_only_loss_at_stages'],
        df_test_path=params['df_test_path'],
    )
    print("Loaded Model from checkpoint")
    return restored