# Scraped page header (hosting status shown as "Sleeping"); not part of the program.
import functools
import os
import random
from concurrent.futures import ThreadPoolExecutor

# Third-party imports are kept in their original relative order so that any
# import-time side effects of the ML frameworks are unchanged.
import gradio as gr
import joblib
from transformers import AutoTokenizer, AutoModel, EsmModel
import torch
import numpy as np
import tensorflow as tf
from keras.layers import TFSMLayer
# Log the TensorFlow build at start-up so version mismatches show up in the logs.
print(f"TensorFlow Version: {tf.__version__}")

base_dir = "."

# One fixed seed shared by the NumPy, stdlib and PyTorch RNGs.
SEED = 42
for _seed_rng in (np.random.seed, random.seed, torch.manual_seed):
    _seed_rng(SEED)
if torch.cuda.is_available():
    torch.cuda.manual_seed(SEED)
    torch.cuda.manual_seed_all(SEED)

# Ask cuDNN for deterministic kernels and switch off its autotuner.
torch.backends.cudnn.deterministic = True
torch.backends.cudnn.benchmark = False
def load_model(model_path):
    """Load the TensorFlow SavedModel stored at ``model_path``.

    Args:
        model_path: Location of the exported SavedModel, passed straight to
            ``tf.saved_model.load``.

    Returns:
        Whatever ``tf.saved_model.load`` returns for that export.
    """
    print(f"Loading model from {model_path}...")
    # NOTE: earlier, commented-out variants of this function used
    # TFSMLayer(model_path, call_endpoint="serving_default") and
    # tf.keras.models.load_model(model_path) instead.
    saved_model = tf.saved_model.load(model_path)
    return saved_model
# Load every predictor once, at import time.
print("Loading models...")


def _model_entry(model, esm_model, layer):
    """Bundle a predictor with the ESM checkpoint and hidden-state index it uses."""
    return {"model": model, "esm_model": esm_model, "layer": layer}


# Plant-specific predictors (Random Forest models stored as joblib pickles).
plant_models = {
    "Specificity": _model_entry(
        joblib.load("Specificity.pkl"), "facebook/esm1b_t33_650M_UR50S", 6
    ),
    "kcatC": _model_entry(
        joblib.load("kcatC.pkl"), "facebook/esm2_t36_3B_UR50D", 11
    ),
    "KC": _model_entry(
        joblib.load("KC.pkl"), "facebook/esm1b_t33_650M_UR50S", 4
    ),
}

# General predictors, loaded through load_model() from SavedModel directories.
general_models = {
    "Specificity": _model_entry(
        load_model("Specificity"), "facebook/esm2_t33_650M_UR50D", 33
    ),
    "kcatC": _model_entry(
        load_model("kcatC"), "facebook/esm2_t12_35M_UR50D", 7
    ),
    "KC": _model_entry(
        load_model("KC"), "facebook/esm2_t30_150M_UR50D", 26
    ),
}
@functools.lru_cache(maxsize=8)
def _load_esm(esm_model_name):
    """Return the cached ``(tokenizer, model)`` pair for an ESM checkpoint.

    Instantiating the checkpoints is by far the most expensive step of a
    request, so the pair is kept after the first load and reused afterwards.
    Concurrent first calls for the same name may each perform a load; the
    cache then keeps a single pair.
    """
    tokenizer = AutoTokenizer.from_pretrained(esm_model_name)
    model = EsmModel.from_pretrained(esm_model_name, output_hidden_states=True)
    return tokenizer, model


def get_embedding(sequence, esm_model_name, layer):
    """Embed a protein sequence with one hidden layer of an ESM model.

    Args:
        sequence: Amino-acid sequence as a string. Inputs longer than 1024
            tokens are truncated by the tokenizer.
        esm_model_name: Hugging Face checkpoint name of the ESM model.
        layer: Index into the model's ``hidden_states`` tuple.

    Returns:
        NumPy array holding the selected hidden state averaged over the
        token axis (``dim=1``); no attention mask is applied to the mean.
    """
    print(f"Generating embeddings using {esm_model_name}, Layer {layer}...")
    # Reuse the already-loaded checkpoint instead of re-instantiating it
    # from disk on every request.
    tokenizer, model = _load_esm(esm_model_name)

    # Tokenize the sequence
    inputs = tokenizer(sequence, return_tensors="pt", truncation=True, max_length=1024)

    # Inference only: skip autograd bookkeeping.
    with torch.no_grad():
        outputs = model(**inputs)

    # Average pooling over the token positions of the requested hidden state.
    embedding = outputs.hidden_states[layer].mean(dim=1).numpy()
    return embedding
def predict_with_gpflow(model, X):
    """Run a loaded GPflow export on ``X`` and return flat NumPy results.

    Args:
        model: Loaded model object exposing ``predict_f_compiled``.
        X: Array-like feature matrix; converted to a ``tf.float64`` tensor.

    Returns:
        Tuple ``(mean, variance)``, each flattened to a 1-D NumPy array.
    """
    features = tf.convert_to_tensor(X, dtype=tf.float64)
    mean, variance = model.predict_f_compiled(features)
    flat_mean = mean.numpy().flatten()
    flat_variance = variance.numpy().flatten()
    return flat_mean, flat_variance
def predict(sequence, prediction_type):
    """Predict Specificity, kcatC and KC for a protein sequence.

    Args:
        sequence: Protein sequence entered by the user.
        prediction_type: ``"Plant-Specific"`` selects ``plant_models``; any
            other value selects ``general_models``.

    Returns:
        One row per target, in the order of the selected model dict.
        Plant-specific rows are ``[label, prediction]``; general rows are
        ``[label, mean, variance]``. Values are rounded to 2 decimals.

    Raises:
        gr.Error: If ``sequence`` is empty or only whitespace.
    """
    # Reject blank input up front instead of returning numbers computed from
    # an empty sequence.
    if not sequence or not sequence.strip():
        raise gr.Error("Please enter a protein sequence.")

    plant_specific = prediction_type == "Plant-Specific"
    selected_models = plant_models if plant_specific else general_models

    def process_target(target):
        """Embed the sequence and run the predictor configured for ``target``."""
        config = selected_models[target]
        embedding = get_embedding(sequence, config["esm_model"], config["layer"])
        if plant_specific:
            # Random Forest prediction
            prediction = config["model"].predict(embedding)[0]
            return target, round(prediction, 2)
        # GPflow prediction
        mean, variance = predict_with_gpflow(config["model"], embedding)
        return target, round(mean[0], 2), round(variance[0], 2)

    # Predict for all targets in parallel; map() yields results in key order.
    with ThreadPoolExecutor() as executor:
        results = list(executor.map(process_target, selected_models))

    # Label each row by the target name returned with it rather than by its
    # position in the result list.
    display_names = {
        "Specificity": "Specificity",
        "kcatC": "kcat\u1d9c",
        "KC": "K\u1d9c",
    }
    return [
        [display_names.get(target, target), *values]
        for target, *values in results
    ]
# Build the Gradio UI: a sequence box and a model-family selector feed
# predict(), whose rows are shown in a table.
print("Creating Gradio interface...")

sequence_input = gr.Textbox(label="Input Protein Sequence")  # free-text sequence
prediction_type_input = gr.Radio(
    choices=["Plant-Specific", "General"],
    label="Prediction Type",
    value="Plant-Specific",
)  # radio buttons, defaulting to the plant-specific models
results_table = gr.Dataframe(
    headers=["Target", "Prediction", "Uncertainty (for General)"],
    type="array",
)  # output table

interface = gr.Interface(
    fn=predict,
    inputs=[sequence_input, prediction_type_input],
    outputs=results_table,
    title="Rubisco Kinetics Prediction",
    description=(
        "Enter a protein sequence to predict Rubisco kinetics properties (Specificity, kcat\u1d9c, and K\u1d9c). "
        "Choose between 'Plant-Specific' (Random Forest) or 'General' (GPflow) predictions."
    ),
)

# Start the web server only when executed as a script.
if __name__ == "__main__":
    interface.launch()