ESM-2 for RNA Binding Site Prediction
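A small RNA binding site predictor built on facebook/esm2_t6_8M_UR50D and trained on dataset "S1" from "Data of protein-RNA binding sites". It is a per-residue classifier: for each amino acid in a protein sequence, it predicts whether that residue is an RNA binding site. The dataset is also available on Hugging Face.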
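Each residue gets one of two labels (0 = not binding, 1 = binding). The ESM-2 tokenizer adds a `<cls>` token before the sequence and an `<eos>` token after it, so per-residue labels sit one position to the right of their residue index. The sketch below shows one way to lay out such labels. It assumes the common Hugging Face convention of marking special tokens with -100, the default `ignore_index` of PyTorch's cross-entropy loss. How the S1 annotations were actually encoded for training is not stated here, and `make_token_labels` is a hypothetical helper.

```python
def make_token_labels(sequence, binding_positions):
    """Build token-level labels for one protein (illustrative sketch).

    binding_positions: 1-based residue indices annotated as RNA-binding.
    Returns one label per token: -100 for <cls> and <eos>, then 0/1 per residue.
    """
    binding = set(binding_positions)
    # 1 for annotated binding residues, 0 for everything else
    residue_labels = [1 if i in binding else 0 for i in range(1, len(sequence) + 1)]
    # -100 on the special tokens so the loss ignores them
    return [-100] + residue_labels + [-100]


print(make_token_labels("MKTAY", [2, 3]))
```

The same one-token offset is why predictions have to be shifted back before they are matched to residues at inference time.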
The model reaches a validation loss of 0.12738210861297214.
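One way to read this number assumes it is the mean per-residue cross-entropy (natural logarithm) that `EsmForTokenClassification` computes when labels are supplied. Under that assumption, its negative exponential is the geometric mean of the probabilities the model assigns to the correct label on the validation set:

$$
\bar{p} = \Big(\prod_{i=1}^{N} p_i\Big)^{1/N} = \exp\!\Big(\frac{1}{N}\sum_{i=1}^{N} \log p_i\Big) = \exp(-\mathcal{L}_{\text{val}}) = \exp(-0.1274) \approx 0.880
$$

Here $p_i$ is the predicted probability of the true label of residue $i$, and $N$ is the number of labelled residues. Because this is an average over all residues, it does not by itself show how well the binding residues, usually the minority class in such data, are recovered.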
To use the model, run:
```python
import torch
from transformers import AutoTokenizer, EsmForTokenClassification

# Define the class mapping
class_mapping = {
    0: 'Not Binding Site',
    1: 'Binding Site',
}

# Load the trained model and tokenizer
model = EsmForTokenClassification.from_pretrained("AmelieSchreiber/esm2_t6_8M_UR50D_rna_binding_site_predictor")
tokenizer = AutoTokenizer.from_pretrained("facebook/esm2_t6_8M_UR50D")
model.eval()  # inference mode (disables dropout)

# Define the new sequences
new_sequences = [
    'VLSPADKTNVKAAWGKVGAHAGEYGAEALERMFLSFPTTK',
    'SQETFSDLWKLLPENNVLSPLPSQAMDDLMLSPDDIEQWF',
    # ... add more sequences here ...
]

# Iterate over the new sequences
for seq in new_sequences:
    # Tokenize one sequence; the tokenizer prepends <cls> and appends <eos>.
    # A single sequence needs no padding, and keeping the attention mask avoids
    # attending to pad tokens. Sequences longer than 1288 residues are truncated.
    inputs = tokenizer(seq, truncation=True, max_length=1290, return_tensors="pt")
    # Apply the model to get the logits
    with torch.no_grad():
        outputs = model(**inputs)
    # Get the predictions by picking the label (class) with the highest logit
    predictions = torch.argmax(outputs.logits, dim=-1)
    # Drop the <cls> (first) and <eos> (last) tokens so predictions line up with residues
    prediction_list = predictions[0, 1:-1].tolist()
    # Convert the predicted class indices to class names
    predicted_labels = [class_mapping[pred] for pred in prediction_list]
    # Pair each amino acid in the sequence with its predicted class label
    residue_to_label = list(zip(seq, predicted_labels))
    # Print out the list
    for i, (residue, predicted_label) in enumerate(residue_to_label):
        print(f"Position {i+1} - {residue}: {predicted_label}")
```
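For long proteins, a line per residue is hard to scan. The sketch below collapses a list of predicted labels, like `predicted_labels` in the loop above, into contiguous binding regions given as 1-based, inclusive (start, end) positions. `binding_regions` is a hypothetical helper. It is not part of `transformers` or of this model.

```python
def binding_regions(predicted_labels, positive="Binding Site"):
    """Group consecutive residues labelled `positive` into 1-based, inclusive (start, end) spans."""
    regions = []
    start = None
    for i, label in enumerate(predicted_labels, start=1):
        if label == positive:
            if start is None:
                start = i  # open a new region at this residue
        elif start is not None:
            regions.append((start, i - 1))  # close the region at the previous residue
            start = None
    if start is not None:
        regions.append((start, len(predicted_labels)))  # region runs to the last residue
    return regions


labels = ["Not Binding Site", "Binding Site", "Binding Site", "Not Binding Site", "Binding Site"]
print(binding_regions(labels))  # [(2, 3), (5, 5)]
```

Calling `binding_regions(predicted_labels)` inside the loop gives a compact summary per sequence. The 1-based positions match the `Position {i+1}` numbering printed above.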
