yashgupta1512 committed
Commit e84d512 · verified · 1 Parent(s): f0ed8d6

Upload 3 files

Files changed (4):
  1. .gitattributes +1 -0
  2. biobert_embeddings.pt +3 -0
  3. filtered_combined.xlsx +3 -0
  4. fin.py +127 -0
.gitattributes CHANGED
@@ -33,3 +33,4 @@ saved_model/**/* filter=lfs diff=lfs merge=lfs -text
  *.zip filter=lfs diff=lfs merge=lfs -text
  *.zst filter=lfs diff=lfs merge=lfs -text
  *tfevents* filter=lfs diff=lfs merge=lfs -text
+ filtered_combined.xlsx filter=lfs diff=lfs merge=lfs -text
biobert_embeddings.pt ADDED
@@ -0,0 +1,3 @@
+ version https://git-lfs.github.com/spec/v1
+ oid sha256:e927c747db1f3ab40d738ceefd859e2aefcf354f8887cfb21d68bab4faed7488
+ size 362435795
filtered_combined.xlsx ADDED
@@ -0,0 +1,3 @@
+ version https://git-lfs.github.com/spec/v1
+ oid sha256:7e6c0d8986434b607859f786db205dc0d75129725f1fea973958c63b30a1ec8e
+ size 262863592
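
Both files above are Git LFS pointers: the spreadsheet of trials and a tensor of precomputed BioBERT embeddings (one vector per trial row) live in LFS storage, not in the repository itself, and the commit does not include the script that produced the tensor. The sketch below is one plausible way to regenerate it, assuming the vectors are CLS-token embeddings of the spreadsheet's "Combined Column" text, matching the pooling used in fin.py further down; the file name regenerate_embeddings.py, the batch size, and the float32 dtype are assumptions, not part of this commit.

# regenerate_embeddings.py -- hypothetical helper, not part of this commit
import pandas as pd
import torch
from transformers import AutoTokenizer, AutoModel

device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
tokenizer = AutoTokenizer.from_pretrained("dmis-lab/biobert-base-cased-v1.1")
model = AutoModel.from_pretrained("dmis-lab/biobert-base-cased-v1.1").to(device).eval()

df = pd.read_excel("./filtered_combined.xlsx")
texts = df["Combined Column"].fillna("").tolist()

batch_size = 32  # assumption; tune to available memory
chunks = []
with torch.no_grad():
    for start in range(0, len(texts), batch_size):
        enc = tokenizer(
            texts[start:start + batch_size],
            max_length=512,
            padding="max_length",
            truncation=True,
            return_tensors="pt",
        ).to(device)
        out = model(**enc)
        chunks.append(out.last_hidden_state[:, 0, :].cpu())  # CLS vector per text

embeddings = torch.cat(chunks, dim=0)  # shape: (num_trials, 768)
torch.save(embeddings, "./biobert_embeddings.pt")

If the saved tensor really is float32 with 768 dimensions per row, the 362,435,795-byte pointer above works out to roughly 118,000 trials.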
fin.py ADDED
@@ -0,0 +1,127 @@
import streamlit as st
import pandas as pd
import torch
from transformers import AutoTokenizer, AutoModel
from sklearn.metrics.pairwise import cosine_similarity

# Device is defined before the loaders below so they can move weights and inputs to it
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")


# Load the BioBERT model and tokenizer (cached across Streamlit reruns)
@st.cache_resource
def load_model_and_tokenizer():
    model_name = "dmis-lab/biobert-base-cased-v1.1"
    tokenizer = AutoTokenizer.from_pretrained(model_name)
    model = AutoModel.from_pretrained(model_name).to(device)
    return tokenizer, model


# Generate a CLS-token embedding for a single input text
def generate_single_embedding(text, tokenizer, model):
    model.eval()
    with torch.no_grad():
        encoding = tokenizer(
            text,
            max_length=512,
            padding="max_length",
            truncation=True,
            return_tensors="pt",
        )
        # Keep the batch dimension: the model expects inputs of shape (batch_size, seq_len)
        encoding = {key: val.to(device) for key, val in encoding.items()}
        output = model(**encoding)
        return output.last_hidden_state[:, 0, :].cpu().numpy()


# Load the dataset and the precomputed embeddings (cached across Streamlit reruns)
@st.cache_data
def load_data_and_embeddings():
    file_name = "./filtered_combined.xlsx"
    model_file = "./biobert_embeddings.pt"

    df = pd.read_excel(file_name)
    df["Combined_Text"] = df["Combined Column"].fillna("")
    # map_location keeps this working on CPU-only hosts even if the tensor was saved on GPU
    embeddings = torch.load(model_file, map_location=device)
    return df, embeddings


# Return the indices and scores of the top N most similar trials
def get_similar_trials(query_embedding, embeddings, top_n=10):
    query_embedding_cpu = query_embedding.cpu().detach().numpy()
    embeddings_cpu = embeddings.cpu().detach().numpy()
    similarities = cosine_similarity(query_embedding_cpu, embeddings_cpu)
    # Sort ascending, drop the single highest score (the query trial itself),
    # keep the next top_n indices, and flip them into descending order
    similar_indices = similarities.argsort(axis=1)[:, -top_n - 1:-1][:, ::-1]
    return similar_indices, similarities


# Load shared resources once at module level
tokenizer, model = load_model_and_tokenizer()
df, embeddings = load_data_and_embeddings()


def main():
    st.write("Model and Tokenizer Loaded Successfully!")

    # Streamlit GUI
    st.title("Clinical Trials Similarity Finder")
    st.write("Find the most similar clinical trials using BioBERT embeddings.")

    # Input method (free-text search by outcome/criteria is currently disabled)
    # option = st.radio(
    #     "Search by:",
    #     ("NCT ID", "Outcome or Criteria"),
    #     index=0,
    #     help="Choose how you want to search for similar trials."
    # )
    #
    # if option == "NCT ID":
    #     nct_id = st.text_input("Enter NCT ID:", placeholder="e.g., NCT00385736")
    # else:
    #     criteria_text = st.text_area(
    #         "Enter Outcome or Criteria:",
    #         placeholder="e.g., A study evaluating the effects of drug X on Y patients..."
    #     )
    nct_id = st.text_input("Enter NCT ID:", placeholder="e.g., NCT00385736")

    top_n = st.slider("Number of similar trials to retrieve:", min_value=1, max_value=20, value=10)

    if st.button("Find Similar Trials"):
        # if option == "NCT ID" and nct_id:
        #     # Search by NCT ID
        #     nct_id_to_index = {tid: idx for idx, tid in enumerate(df["nct_id"])}
        #     if nct_id in nct_id_to_index:
        #         query_idx = nct_id_to_index[nct_id]
        #         query_embedding = embeddings[query_idx].unsqueeze(0).to(device)
        #     else:
        #         st.error(f"NCT ID {nct_id} not found in the dataset.")
        #         st.stop()
        # elif option == "Outcome or Criteria" and criteria_text:
        #     # Search by text
        #     query_embedding = torch.tensor(generate_single_embedding(criteria_text, tokenizer, model)).to(device)
        # else:
        #     st.error("Please provide a valid input.")
        #     st.stop()
        if nct_id:
            # Search by NCT ID
            nct_id_to_index = {tid: idx for idx, tid in enumerate(df["nct_id"])}
            if nct_id in nct_id_to_index:
                query_idx = nct_id_to_index[nct_id]
                query_embedding = embeddings[query_idx].unsqueeze(0).to(device)
            else:
                st.error(f"NCT ID {nct_id} not found in the dataset.")
                st.stop()
        else:
            # Without this guard, an empty input would leave query_embedding undefined below
            st.error("Please enter an NCT ID.")
            st.stop()

        # Get similar trials
        similar_indices, similarities = get_similar_trials(query_embedding, embeddings, top_n=top_n)
        similar_trials = df.iloc[similar_indices[0]].copy()
        similar_trials["Similarity Score"] = [
            similarities[0, idx] for idx in similar_indices[0]
        ]

        # Display results
        st.write("### Top Similar Clinical Trials:")
        st.dataframe(similar_trials[["nct_id", "Study Title", "Similarity Score"]])

        # Offer the results as an Excel download
        output_file = "similar_trials_results.xlsx"
        similar_trials.to_excel(output_file, index=False)
        with open(output_file, "rb") as f:
            st.download_button("Download Results as Excel", f, file_name="similar_trials_results.xlsx")


if __name__ == "__main__":
    main()
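
One subtlety in get_similar_trials above: the slice similarities.argsort(axis=1)[:, -top_n - 1:-1][:, ::-1] discards the single highest score before taking the top N, which effectively assumes the query trial itself is always the best match. That holds for the active NCT ID path, where the query embedding comes straight from the precomputed matrix, but if the commented-out free-text path were re-enabled it would silently drop the best genuine hit. A minimal standalone check of the slice's behavior, using made-up similarity scores:

import numpy as np

# One query against five trials; the 1.0 at index 2 is the query's similarity to itself.
sims = np.array([[0.2, 0.9, 1.0, 0.5, 0.7]])
top_n = 2

order = sims.argsort(axis=1)               # ascending: [[0, 3, 4, 1, 2]]
picked = order[:, -top_n - 1:-1][:, ::-1]  # [[1, 4]]: scores 0.9 and 0.7; index 2 is skipped
print(picked)

Locally, the app should start with streamlit run fin.py; reading filtered_combined.xlsx and writing the results file via pandas also requires openpyxl to be installed.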