Spaces:
Running
Running
| import json | |
| import logging | |
| import os | |
| from collections import defaultdict | |
| from typing import Dict, List, Tuple | |
| import mols2grid | |
| import pandas as pd | |
| from gt4sd.algorithms import ( | |
| RegressionTransformerMolecules, | |
| RegressionTransformerProteins, | |
| ) | |
| from gt4sd.algorithms.core import AlgorithmConfiguration | |
| from rdkit import Chem | |
| from terminator.selfies import decoder | |
| logger = logging.getLogger(__name__) | |
| logger.addHandler(logging.NullHandler()) | |
| def get_application(application: str) -> AlgorithmConfiguration: | |
| """ | |
| Convert application name to AlgorithmConfiguration. | |
| Args: | |
| application: Molecules or Proteins | |
| Returns: | |
| The corresponding AlgorithmConfiguration | |
| """ | |
| if application == "Molecules": | |
| application = RegressionTransformerMolecules | |
| elif application == "Proteins": | |
| application = RegressionTransformerProteins | |
| else: | |
| raise ValueError( | |
| "Currently only models for molecules and proteins are supported" | |
| ) | |
| return application | |
| def get_inference_dict( | |
| application: AlgorithmConfiguration, algorithm_version: str | |
| ) -> Dict: | |
| """ | |
| Get inference dictionary for a given application and algorithm version. | |
| Args: | |
| application: algorithm application (Molecules or Proteins) | |
| algorithm_version: algorithm version (e.g. qed) | |
| Returns: | |
| A dictionary with the inference parameters. | |
| """ | |
| config = application(algorithm_version=algorithm_version) | |
| with open(os.path.join(config.ensure_artifacts(), "inference.json"), "r") as f: | |
| data = json.load(f) | |
| return data | |
| def get_rt_name(x: Dict) -> str: | |
| """ | |
| Get the UI display name of the regression transformer. | |
| Args: | |
| x: dictionary with the inference parameters | |
| Returns: | |
| The display name | |
| """ | |
| return ( | |
| x["algorithm_application"].split("Transformer")[-1] | |
| + ": " | |
| + x["algorithm_version"].capitalize() | |
| ) | |
| def draw_grid_predict(prediction: str, target: str, domain: str) -> str: | |
| """ | |
| Uses mols2grid to draw a HTML grid for the prediction | |
| Args: | |
| prediction: Predicted sequence. | |
| target: Target molecule | |
| domain: Domain of the prediction (molecules or proteins) | |
| Returns: | |
| HTML to display | |
| """ | |
| if domain not in ["Molecules", "Proteins"]: | |
| raise ValueError(f"Unsupported domain {domain}") | |
| seq = target.split("|")[-1] | |
| converter = ( | |
| decoder | |
| if domain == "Molecules" | |
| else lambda x: Chem.MolToSmiles(Chem.MolFromFASTA(x)) | |
| ) | |
| try: | |
| seq = converter(seq) | |
| except Exception: | |
| logger.warning(f"Could not draw sequence {seq}") | |
| result = {"SMILES": [seq], "Name": ["Target"]} | |
| # Add properties | |
| for prop in prediction.split("<")[1:]: | |
| result[ | |
| prop.split(">")[0] | |
| ] = f"{prop.split('>')[0].capitalize()} = {prop.split('>')[1]}" | |
| result_df = pd.DataFrame(result) | |
| obj = mols2grid.display( | |
| result_df, | |
| tooltip=list(result.keys()), | |
| height=900, | |
| n_cols=1, | |
| name="Results", | |
| size=(600, 700), | |
| ) | |
| return obj.data | |
| def draw_grid_generate( | |
| samples: List[Tuple[str]], domain: str, n_cols: int = 5, size=(140, 200) | |
| ) -> str: | |
| """ | |
| Uses mols2grid to draw a HTML grid for the generated molecules | |
| Args: | |
| samples: The generated samples (with properties) | |
| domain: Domain of the prediction (molecules or proteins) | |
| n_cols: Number of columns in grid. Defaults to 5. | |
| size: Size of molecule in grid. Defaults to (140, 200). | |
| Returns: | |
| HTML to display | |
| """ | |
| if domain not in ["Molecules", "Proteins"]: | |
| raise ValueError(f"Unsupported domain {domain}") | |
| if domain == "Proteins": | |
| try: | |
| smis = list( | |
| map(lambda x: Chem.MolToSmiles(Chem.MolFromFASTA(x[0])), samples) | |
| ) | |
| except Exception: | |
| logger.warning(f"Could not convert some sequences {samples}") | |
| else: | |
| smis = [s[0] for s in samples] | |
| result = defaultdict(list) | |
| result.update({"SMILES": smis, "Name": [f"sample_{i}" for i in range(len(smis))]}) | |
| # Create properties | |
| properties = [s.split("<")[1] for s in samples[0][1].split(">")[:-1]] | |
| # Fill properties | |
| for sample in samples: | |
| for prop in properties: | |
| value = float(sample[1].split(prop)[-1][1:].split("<")[0]) | |
| result[prop].append(f"{prop} = {value}") | |
| result_df = pd.DataFrame(result) | |
| obj = mols2grid.display( | |
| result_df, | |
| tooltip=list(result.keys()), | |
| height=1100, | |
| n_cols=n_cols, | |
| name="Results", | |
| size=size, | |
| ) | |
| return obj.data | |