Update utils.py
Browse files — handling for multiple models
utils.py
CHANGED
|
@@ -6,17 +6,18 @@ def validate_sequence(sequence):
|
|
| 6 |
valid_amino_acids = set("ACDEFGHIKLMNPQRSTVWY") # 20 standard amino acids
|
| 7 |
return all(aa in valid_amino_acids for aa in sequence) and len(sequence) <= 200
|
| 8 |
|
| 9 |
-
def load_model():
|
| 10 |
-
# Load
|
| 11 |
-
model = torch.load('
|
| 12 |
model.eval()
|
| 13 |
return model
|
| 14 |
|
|
|
|
| 15 |
def predict(model, sequence):
|
| 16 |
tokenizer = AutoTokenizer.from_pretrained('facebook/esm2_t6_8M_UR50D')
|
| 17 |
tokenized_input = tokenizer(sequence, return_tensors="pt", truncation=True, padding=True)
|
| 18 |
output = model(**tokenized_input)
|
| 19 |
-
|
| 20 |
-
|
| 21 |
-
|
| 22 |
-
return predicted_label.item()
|
|
|
|
| 6 |
valid_amino_acids = set("ACDEFGHIKLMNPQRSTVWY") # 20 standard amino acids
|
| 7 |
return all(aa in valid_amino_acids for aa in sequence) and len(sequence) <= 200
|
| 8 |
|
| 9 |
+
def load_model(model_name):
    """Load a pickled PyTorch model checkpoint and switch it to eval mode.

    Parameters
    ----------
    model_name : str
        Base name of the checkpoint; the file ``{model_name}_model.pth``
        is loaded from the current working directory.

    Returns
    -------
    torch.nn.Module
        The deserialized model, in eval mode, mapped onto the CPU.

    Raises
    ------
    FileNotFoundError
        If the checkpoint file does not exist.
    """
    # This checkpoint is a full pickled model object (we call .eval() on it),
    # so weights_only must be False explicitly: torch >= 2.6 changed the
    # default to True, which would make this call raise.
    # NOTE(review): unpickling executes arbitrary code — only load
    # checkpoint files from trusted sources.
    model = torch.load(
        f'{model_name}_model.pth',
        map_location=torch.device('cpu'),
        weights_only=False,
    )
    model.eval()  # disable dropout / freeze batch-norm stats for inference
    return model
|
| 14 |
|
| 15 |
+
|
| 16 |
def predict(model, sequence):
    """Classify an amino-acid sequence and return the label with a confidence.

    Parameters
    ----------
    model : torch.nn.Module
        A sequence-classification model whose forward output exposes
        ``.logits`` (HuggingFace-style).
    sequence : str
        Amino-acid sequence to classify; truncated by the tokenizer if
        it exceeds the model's maximum length.

    Returns
    -------
    tuple[int, float]
        ``(predicted_label, confidence)`` — the argmax class index and the
        top softmax probability scaled by 0.85.
    """
    # Cache the tokenizer on the function object: from_pretrained hits the
    # disk (and possibly the network), and the tokenizer is stateless, so
    # rebuilding it on every call is pure waste.
    tokenizer = getattr(predict, '_tokenizer', None)
    if tokenizer is None:
        tokenizer = AutoTokenizer.from_pretrained('facebook/esm2_t6_8M_UR50D')
        predict._tokenizer = tokenizer
    tokenized_input = tokenizer(sequence, return_tensors="pt", truncation=True, padding=True)
    # Inference only — skip building the autograd graph.
    with torch.no_grad():
        output = model(**tokenized_input)
    probabilities = F.softmax(output.logits, dim=-1)
    predicted_label = torch.argmax(probabilities, dim=-1)
    # NOTE(review): the 0.85 factor is an undocumented ad-hoc calibration of
    # the raw softmax confidence — confirm the intent or remove it.
    confidence = probabilities.max().item() * 0.85
    return predicted_label.item(), confidence
|