Update app.py
app.py CHANGED
@@ -29,10 +29,13 @@ if st.button("Send"):
     st.session_state.conversation_history.append(f"User: {user_input}")
 
     # BioBERT for Medical Q&A
-    context = "
+    context = """
+    A headache can have many causes, ranging from stress, dehydration, or fatigue to more severe conditions like migraines, infections, or neurological problems. Common remedies include over-the-counter pain relievers, hydration, and rest.
+    If headaches are persistent or severe, it may indicate an underlying condition such as tension headaches, cluster headaches, or even infections like sinusitis. If the headache is accompanied by other symptoms such as nausea, vision changes, or confusion, it is recommended to seek medical attention.
+    """
     inputs = medical_tokenizer.encode_plus(user_input, context, add_special_tokens=True, return_tensors="pt")
     input_ids = inputs["input_ids"].tolist()[0]
-
+
     # Perform Question Answering using BioBERT
     with torch.no_grad():
         outputs = medical_model(**inputs)
@@ -43,13 +46,17 @@ if st.button("Send"):
     answer_end = torch.argmax(answer_end_scores) + 1
     medical_answer = medical_tokenizer.convert_tokens_to_string(medical_tokenizer.convert_ids_to_tokens(input_ids[answer_start:answer_end]))
 
+    # If no medical answer is found, provide a fallback response
+    if medical_answer.strip() == "":
+        medical_answer = "I'm not sure about that. You may want to consult a medical professional."
+
     # Append medical response to the conversation history
     st.session_state.conversation_history.append(f"Bot (Medical): {medical_answer}")
 
     # DialoGPT for conversational response
     conversation_input_ids = conversation_tokenizer.encode(user_input + conversation_tokenizer.eos_token, return_tensors='pt')
     conversation_bot_input_ids = torch.cat([conversation_tokenizer.encode(convo + conversation_tokenizer.eos_token, return_tensors='pt') for convo in st.session_state.conversation_history], dim=-1)
-
+
     # Generate conversational response
     chat_history_ids = conversation_model.generate(conversation_bot_input_ids, max_length=1000, pad_token_id=conversation_tokenizer.eos_token_id)
     conversation_response = conversation_tokenizer.decode(chat_history_ids[:, conversation_bot_input_ids.shape[-1]:][0], skip_special_tokens=True)
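
For context, the changed code relies on four globals defined earlier in app.py: medical_tokenizer, medical_model, conversation_tokenizer, and conversation_model. A minimal sketch of how they might be initialized with the Hugging Face transformers library follows; the checkpoint names are assumptions (a SQuAD-fine-tuned BioBERT and DialoGPT-medium), not taken from this commit.

import streamlit as st
from transformers import (
    AutoModelForCausalLM,
    AutoModelForQuestionAnswering,
    AutoTokenizer,
)

# Assumed checkpoints; the actual app may load different ones.
MEDICAL_CHECKPOINT = "dmis-lab/biobert-base-cased-v1.1-squad"
CHAT_CHECKPOINT = "microsoft/DialoGPT-medium"

@st.cache_resource  # load each model once, not on every Streamlit rerun
def load_models():
    medical_tokenizer = AutoTokenizer.from_pretrained(MEDICAL_CHECKPOINT)
    medical_model = AutoModelForQuestionAnswering.from_pretrained(MEDICAL_CHECKPOINT)
    conversation_tokenizer = AutoTokenizer.from_pretrained(CHAT_CHECKPOINT)
    conversation_model = AutoModelForCausalLM.from_pretrained(CHAT_CHECKPOINT)
    return medical_tokenizer, medical_model, conversation_tokenizer, conversation_model

medical_tokenizer, medical_model, conversation_tokenizer, conversation_model = load_models()

Caching the loaded models this way keeps the "Send" handler responsive, since Streamlit reruns the whole script on every interaction.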
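
The fallback added in the second hunk guards against a common failure mode of extractive QA: when the model's best start position lands at or after its best end position, the slice input_ids[answer_start:answer_end] is empty and the decoded answer is blank. A standalone sketch of that decode-plus-fallback pattern, with the function name and direct use of start_logits/end_logits being illustrative rather than copied from the app:

import torch

def extract_answer(tokenizer, model, question: str, context: str) -> str:
    """Run extractive QA and fall back to a canned reply on an empty span."""
    inputs = tokenizer.encode_plus(question, context, add_special_tokens=True, return_tensors="pt")
    input_ids = inputs["input_ids"].tolist()[0]

    with torch.no_grad():
        outputs = model(**inputs)

    # The QA head scores every token as a candidate span start and end.
    answer_start = torch.argmax(outputs.start_logits)
    answer_end = torch.argmax(outputs.end_logits) + 1

    # If answer_start >= answer_end, this slice is empty and decodes to "".
    answer = tokenizer.convert_tokens_to_string(
        tokenizer.convert_ids_to_tokens(input_ids[answer_start:answer_end])
    )
    if answer.strip() == "":
        answer = "I'm not sure about that. You may want to consult a medical professional."
    return answer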