Manjulabalathandayutham commited on
Commit
03facfc
·
verified ·
1 Parent(s): 09bd568

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +10 -3
app.py CHANGED
@@ -29,10 +29,13 @@ if st.button("Send"):
29
  st.session_state.conversation_history.append(f"User: {user_input}")
30
 
31
  # BioBERT for Medical Q&A
32
- context = "Medical knowledge base or a long text about medical conditions..." # Insert medical reference or context
 
 
 
33
  inputs = medical_tokenizer.encode_plus(user_input, context, add_special_tokens=True, return_tensors="pt")
34
  input_ids = inputs["input_ids"].tolist()[0]
35
-
36
  # Perform Question Answering using BioBERT
37
  with torch.no_grad():
38
  outputs = medical_model(**inputs)
@@ -43,13 +46,17 @@ if st.button("Send"):
43
  answer_end = torch.argmax(answer_end_scores) + 1
44
  medical_answer = medical_tokenizer.convert_tokens_to_string(medical_tokenizer.convert_ids_to_tokens(input_ids[answer_start:answer_end]))
45
 
 
 
 
 
46
  # Append medical response to the conversation history
47
  st.session_state.conversation_history.append(f"Bot (Medical): {medical_answer}")
48
 
49
  # DialoGPT for conversational response
50
  conversation_input_ids = conversation_tokenizer.encode(user_input + conversation_tokenizer.eos_token, return_tensors='pt')
51
  conversation_bot_input_ids = torch.cat([conversation_tokenizer.encode(convo + conversation_tokenizer.eos_token, return_tensors='pt') for convo in st.session_state.conversation_history], dim=-1)
52
-
53
  # Generate conversational response
54
  chat_history_ids = conversation_model.generate(conversation_bot_input_ids, max_length=1000, pad_token_id=conversation_tokenizer.eos_token_id)
55
  conversation_response = conversation_tokenizer.decode(chat_history_ids[:, conversation_bot_input_ids.shape[-1]:][0], skip_special_tokens=True)
 
29
  st.session_state.conversation_history.append(f"User: {user_input}")
30
 
31
  # BioBERT for Medical Q&A
32
+ context = """
33
+ A headache can have many causes, ranging from stress, dehydration, or fatigue to more severe conditions like migraines, infections, or neurological problems. Common remedies include over-the-counter pain relievers, hydration, and rest.
34
+ If headaches are persistent or severe, it may indicate an underlying condition such as tension headaches, cluster headaches, or even infections like sinusitis. If the headache is accompanied by other symptoms such as nausea, vision changes, or confusion, it is recommended to seek medical attention.
35
+ """
36
  inputs = medical_tokenizer.encode_plus(user_input, context, add_special_tokens=True, return_tensors="pt")
37
  input_ids = inputs["input_ids"].tolist()[0]
38
+
39
  # Perform Question Answering using BioBERT
40
  with torch.no_grad():
41
  outputs = medical_model(**inputs)
 
46
  answer_end = torch.argmax(answer_end_scores) + 1
47
  medical_answer = medical_tokenizer.convert_tokens_to_string(medical_tokenizer.convert_ids_to_tokens(input_ids[answer_start:answer_end]))
48
 
49
+ # If no medical answer is found, provide a fallback response
50
+ if medical_answer.strip() == "":
51
+ medical_answer = "I'm not sure about that. You may want to consult a medical professional."
52
+
53
  # Append medical response to the conversation history
54
  st.session_state.conversation_history.append(f"Bot (Medical): {medical_answer}")
55
 
56
  # DialoGPT for conversational response
57
  conversation_input_ids = conversation_tokenizer.encode(user_input + conversation_tokenizer.eos_token, return_tensors='pt')
58
  conversation_bot_input_ids = torch.cat([conversation_tokenizer.encode(convo + conversation_tokenizer.eos_token, return_tensors='pt') for convo in st.session_state.conversation_history], dim=-1)
59
+
60
  # Generate conversational response
61
  chat_history_ids = conversation_model.generate(conversation_bot_input_ids, max_length=1000, pad_token_id=conversation_tokenizer.eos_token_id)
62
  conversation_response = conversation_tokenizer.decode(chat_history_ids[:, conversation_bot_input_ids.shape[-1]:][0], skip_special_tokens=True)