Manjulabalathandayutham commited on
Commit
09bd568
·
verified ·
1 Parent(s): 1ae518c

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +48 -26
app.py CHANGED
@@ -1,42 +1,64 @@
1
  import streamlit as st
2
- from transformers import AutoTokenizer, AutoModelForQuestionAnswering
3
  import torch
4
 
5
- # Load Bio_ClinicalBERT model and tokenizer
6
- model_name = "emilyalsentzer/Bio_ClinicalBERT"
7
- tokenizer = AutoTokenizer.from_pretrained(model_name)
8
- model = AutoModelForQuestionAnswering.from_pretrained(model_name)
 
 
 
 
 
9
 
10
  # Streamlit app layout
11
  st.title("Medical Chatbot")
12
- st.write("Enter your medical query below:")
 
 
 
 
13
 
14
  # Text input
15
- user_input = st.text_area("Input Text")
16
 
17
- if st.button("Get Response"):
18
  if user_input:
19
- # Example context for medical queries
20
- context = """
21
- A fever is a temporary increase in body temperature, often due to an illness.
22
- It's a sign that something unusual is going on in your body.
23
- For adults, a fever may be uncomfortable, but usually isn’t a cause for concern unless it reaches 103 F (39.4 C) or higher.
24
- """ # This context should be based on your use case
25
-
26
- # Tokenize input
27
- inputs = tokenizer(user_input, context, return_tensors="pt")
28
 
29
- # Get model outputs
 
 
 
 
 
30
  with torch.no_grad():
31
- outputs = model(**inputs)
 
 
32
 
33
- # Get the most likely beginning and end of the answer
34
- answer_start = torch.argmax(outputs.start_logits)
35
- answer_end = torch.argmax(outputs.end_logits) + 1
 
 
 
 
 
 
 
 
 
 
 
36
 
37
- # Convert tokens to string
38
- answer = tokenizer.convert_tokens_to_string(tokenizer.convert_ids_to_tokens(inputs["input_ids"][0][answer_start:answer_end]))
39
 
40
- st.write(f"Bot: {answer}")
 
 
41
  else:
42
- st.write("Please enter a medical query.")
 
1
  import streamlit as st
2
+ from transformers import AutoTokenizer, AutoModelForQuestionAnswering, AutoModelForCausalLM
3
  import torch
4
 
5
+ # Load BioBERT for medical question answering
6
+ medical_model_name = "dmis-lab/biobert-base-cased-v1.1"
7
+ medical_tokenizer = AutoTokenizer.from_pretrained(medical_model_name)
8
+ medical_model = AutoModelForQuestionAnswering.from_pretrained(medical_model_name)
9
+
10
+ # Load DialoGPT for conversation
11
+ conversation_model_name = "microsoft/DialoGPT-small"
12
+ conversation_tokenizer = AutoTokenizer.from_pretrained(conversation_model_name)
13
+ conversation_model = AutoModelForCausalLM.from_pretrained(conversation_model_name)
14
 
15
  # Streamlit app layout
16
  st.title("Medical Chatbot")
17
+ st.write("Ask your medical-related question below:")
18
+
19
+ # Conversation history tracker
20
+ if 'conversation_history' not in st.session_state:
21
+ st.session_state.conversation_history = []
22
 
23
  # Text input
24
+ user_input = st.text_input("You:")
25
 
26
+ if st.button("Send"):
27
  if user_input:
28
+ # Append user input to the conversation history
29
+ st.session_state.conversation_history.append(f"User: {user_input}")
 
 
 
 
 
 
 
30
 
31
+ # BioBERT for Medical Q&A
32
+ context = "Medical knowledge base or a long text about medical conditions..." # Insert medical reference or context
33
+ inputs = medical_tokenizer.encode_plus(user_input, context, add_special_tokens=True, return_tensors="pt")
34
+ input_ids = inputs["input_ids"].tolist()[0]
35
+
36
+ # Perform Question Answering using BioBERT
37
  with torch.no_grad():
38
+ outputs = medical_model(**inputs)
39
+ answer_start_scores = outputs.start_logits
40
+ answer_end_scores = outputs.end_logits
41
 
42
+ answer_start = torch.argmax(answer_start_scores)
43
+ answer_end = torch.argmax(answer_end_scores) + 1
44
+ medical_answer = medical_tokenizer.convert_tokens_to_string(medical_tokenizer.convert_ids_to_tokens(input_ids[answer_start:answer_end]))
45
+
46
+ # Append medical response to the conversation history
47
+ st.session_state.conversation_history.append(f"Bot (Medical): {medical_answer}")
48
+
49
+ # DialoGPT for conversational response
50
+ conversation_input_ids = conversation_tokenizer.encode(user_input + conversation_tokenizer.eos_token, return_tensors='pt')
51
+ conversation_bot_input_ids = torch.cat([conversation_tokenizer.encode(convo + conversation_tokenizer.eos_token, return_tensors='pt') for convo in st.session_state.conversation_history], dim=-1)
52
+
53
+ # Generate conversational response
54
+ chat_history_ids = conversation_model.generate(conversation_bot_input_ids, max_length=1000, pad_token_id=conversation_tokenizer.eos_token_id)
55
+ conversation_response = conversation_tokenizer.decode(chat_history_ids[:, conversation_bot_input_ids.shape[-1]:][0], skip_special_tokens=True)
56
 
57
+ # Append conversational response to conversation history
58
+ st.session_state.conversation_history.append(f"Bot (Conversational): {conversation_response}")
59
 
60
+ # Display conversation history
61
+ for message in st.session_state.conversation_history:
62
+ st.write(message)
63
  else:
64
+ st.write("Please enter a medical question.")