KartikB34 committed on
Commit
b52d4b1
·
1 Parent(s): 2fc53ff
Files changed (5) hide show
  1. app.py +107 -0
  2. hdbscan_model.pkl +3 -0
  3. plots/heatmap.png +0 -0
  4. plots/scatter.png +0 -0
  5. requirements.txt +17 -0
app.py ADDED
@@ -0,0 +1,107 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import gradio as gr
2
+ import torch
3
+ import joblib
4
+ import numpy as np
5
+ from transformers import AutoTokenizer, AutoModelForMaskedLM
6
+ from Bio import SeqIO
7
+ import io
8
+ from sklearn.metrics import silhouette_score, silhouette_samples
9
+ import matplotlib.pyplot as plt
10
+ import seaborn as sns
11
+ import os
12
+
13
# Hugging Face checkpoint used for embeddings and the pickled clusterer file.
MODEL_NAME = "InstaDeepAI/nucleotide-transformer-v2-50m-multi-species"
HDBSCAN_MODEL_PATH = "hdbscan_model.pkl"
# Passed to the tokenizer as max_length; longer inputs are truncated to it.
MAX_LENGTH = 20

# Directory for plot images; created at import time if it does not exist.
PLOTS_DIR = "plots"
os.makedirs(PLOTS_DIR, exist_ok=True)

print("Loading Transformer...")
# trust_remote_code=True allows the checkpoint's bundled Python code to run,
# so MODEL_NAME must point at a repository that is trusted.
tokenizer = AutoTokenizer.from_pretrained(MODEL_NAME, trust_remote_code=True)
model = AutoModelForMaskedLM.from_pretrained(MODEL_NAME, trust_remote_code=True)
if torch.cuda.is_available():
    device = torch.device("cuda")
else:
    device = torch.device("cpu")
model.to(device)
model.eval()  # inference only: disables dropout and similar training behavior
print("Transformer loaded.")

print("Loading HDBSCAN...")
# joblib.load unpickles the file, which can execute arbitrary code; only load
# a model artifact that ships with (or is trusted by) this application.
clusterer = joblib.load(HDBSCAN_MODEL_PATH)
print("HDBSCAN loaded.")
30
+
31
def seq_to_kmers(seq, k=6):
    """Return the overlapping k-mers of *seq* as one space-separated string.

    The sequence is upper-cased first. A window of width ``k`` slides one
    character at a time, so a sequence of length ``n`` yields ``n - k + 1``
    k-mers; a sequence shorter than ``k`` yields an empty string.

    Args:
        seq: Sequence text (any case).
        k: Width of each k-mer window.

    Returns:
        The k-mers joined by single spaces.
    """
    upper = seq.upper()
    window_count = len(upper) - k + 1
    kmers = (upper[start:start + k] for start in range(window_count))
    return " ".join(kmers)
34
+
35
# Static plot images returned alongside every response.
_SCATTER_PLOT = "plots/scatter.png"
_HEATMAP_PLOT = "plots/heatmap.png"


def _fallback_response(note):
    """Build the placeholder (payload, scatter, heatmap) triple carrying *note*."""
    payload = {
        "overall_silhouette": 0,
        "results": [{"id": "N/A", "cluster": -1, "confidence": 0, "note": note}],
    }
    return payload, _SCATTER_PLOT, _HEATMAP_PLOT


def _masked_mean(last_hidden, attention_mask):
    """Average token embeddings per sequence, skipping padded positions."""
    if attention_mask is None:
        return last_hidden.mean(dim=1)
    weights = attention_mask.unsqueeze(-1).to(last_hidden.dtype)
    # clamp guards against division by zero for an all-padding row.
    token_counts = weights.sum(dim=1).clamp(min=1.0)
    return (last_hidden * weights).sum(dim=1) / token_counts


def analyze_fasta(fasta_bytes):
    """Cluster the sequences of an uploaded FASTA file.

    Each record is converted to k-mers, embedded with the transformer
    (mean of the last hidden layer over non-padding tokens) and labelled by
    the loaded clusterer. A silhouette score is computed over the
    non-noise points when they span more than one cluster.

    Args:
        fasta_bytes: Raw bytes of the uploaded FASTA file. ``None`` or empty
            input is treated as a file without sequences.

    Returns:
        A tuple ``(payload, scatter_path, heatmap_path)``. ``payload`` is a
        dict with ``"overall_silhouette"`` and a ``"results"`` list holding
        one ``{"id", "cluster", "confidence"}`` entry per record (plus a
        ``"note"`` for records labelled ``-1``). The two paths point at the
        saved plot images. Any exception raised during processing is
        reported through a placeholder payload whose note starts with
        ``"Fallback: "`` instead of propagating.
    """
    # No upload (None) or an empty file: nothing to parse.
    if not fasta_bytes:
        return _fallback_response("No sequences found")

    try:
        # SeqIO needs a text handle, so decode the bytes first.
        fasta_str = fasta_bytes.decode("utf-8", errors="ignore")
        records = list(SeqIO.parse(io.StringIO(fasta_str), "fasta"))
        if not records:
            return _fallback_response("No sequences found")

        ids = [record.id for record in records]
        sequences = [str(record.seq) for record in records]

        batch_kmers = [seq_to_kmers(s) for s in sequences]
        inputs = tokenizer(
            batch_kmers,
            return_tensors="pt",
            padding=True,
            truncation=True,
            max_length=MAX_LENGTH,
        )
        inputs = {name: tensor.to(device) for name, tensor in inputs.items()}

        with torch.no_grad():
            outputs = model(**inputs, output_hidden_states=True)
            last_hidden = outputs.hidden_states[-1]
            # Padding tokens must not contribute to the sequence embedding,
            # otherwise the result depends on what else is in the batch.
            pooled = _masked_mean(last_hidden, inputs.get("attention_mask"))
            mean_embeddings = pooled.cpu().numpy()

        # NOTE(review): fit_predict re-fits the loaded clusterer on just this
        # upload, so the fitted state from the pickle is not reused and the
        # shared clusterer object is mutated on every call -- confirm this
        # is intended.
        labels = np.asarray(clusterer.fit_predict(mean_embeddings))

        # Silhouette is only defined for two or more clusters; noise (-1)
        # points are excluded from the score.
        valid_mask = labels != -1
        silhouette_avg = 0
        if np.unique(labels[valid_mask]).shape[0] > 1:
            silhouette_avg = silhouette_score(
                mean_embeddings[valid_mask], labels[valid_mask]
            )

        results = []
        for seq_id, label in zip(ids, labels):
            is_noise = label == -1
            result = {
                "id": seq_id,
                "cluster": int(label),
                # Binary confidence: 1.0 for clustered points, 0.0 for noise.
                "confidence": 0.0 if is_noise else 1.0,
            }
            if is_noise:
                result["note"] = "Potential novel/unknown sequence"
            results.append(result)

        payload = {
            "overall_silhouette": round(float(silhouette_avg), 3),
            "results": results,
        }
        return payload, _SCATTER_PLOT, _HEATMAP_PLOT

    except Exception as e:
        # UI boundary: surface the failure in the payload rather than crash.
        return _fallback_response(f"Fallback: {str(e)}")
95
+
96
+
97
# Gradio UI: one FASTA upload in; the JSON result and two plot images out.
demo = gr.Interface(
    fn=analyze_fasta,
    # type="binary" hands the uploaded file to analyze_fasta as raw bytes.
    inputs=gr.File(file_types=[".fasta"], type="binary"),
    outputs=[gr.JSON(), gr.Image(), gr.Image()],
    title="DNA Clustering Analyzer",
    description="Upload a FASTA file → Get clustering results + scatter plot + heatmap.",
)

# Start the web app only when this file is run as a script.
if __name__ == "__main__":
    demo.launch()
hdbscan_model.pkl ADDED
@@ -0,0 +1,3 @@
 
 
 
 
1
+ version https://git-lfs.github.com/spec/v1
2
+ oid sha256:8bf0df773a165a275ccbe1e2a20137c5d3d2d4dfccfb751164b21bc850630b7f
3
+ size 7861187
plots/heatmap.png ADDED
plots/scatter.png ADDED
requirements.txt ADDED
@@ -0,0 +1,17 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ fastapi
2
+ uvicorn
3
+ torch
4
+ transformers
5
+ biopython
6
+ hdbscan
7
+ joblib
8
+ python-multipart
9
+ scikit-learn
10
+ matplotlib
11
+ seaborn
12
+ gradio
13
+ numpy
14
+ biopython
15
+ gradio
16
+ seaborn
17
+ joblib