Commit 9741310 · committed by GitHub Actions
Parent(s): 8c19d9e
Sync from GitHub 5ab1c441d82e182643effd2ce651c57f249719ee
Changed files: streamlit_app.py (+148 -66)
streamlit_app.py
CHANGED
@@ -248,6 +248,61 @@ st.markdown("""
         height: 8px;
     }
 
+    /* Thinking indicator animation */
+    .thinking {
+        display: flex;
+        align-items: center;
+        color: #a8a8a8 !important;
+        font-style: italic;
+        animation: thinking-pulse 2s ease-in-out infinite;
+    }
+
+    .thinking::after {
+        content: "...";
+        animation: thinking-dots 1.5s steps(4, end) infinite;
+        margin-left: 4px;
+    }
+
+    @keyframes thinking-pulse {
+        0%, 100% { opacity: 0.7; }
+        50% { opacity: 1; }
+    }
+
+    @keyframes thinking-dots {
+        0%, 20% { content: ""; }
+        40% { content: "."; }
+        60% { content: ".."; }
+        80%, 100% { content: "..."; }
+    }
+
+    /* Processing status indicators */
+    .status-indicator {
+        display: flex;
+        align-items: center;
+        gap: 8px;
+        padding: 8px 12px;
+        background: rgba(16, 163, 127, 0.1);
+        border: 1px solid rgba(16, 163, 127, 0.3);
+        border-radius: 8px;
+        margin-bottom: 10px;
+        color: #a8e6cf;
+        font-size: 14px;
+    }
+
+    .spinner {
+        width: 16px;
+        height: 16px;
+        border: 2px solid rgba(16, 163, 127, 0.3);
+        border-top: 2px solid #10a37f;
+        border-radius: 50%;
+        animation: spin 1s linear infinite;
+    }
+
+    @keyframes spin {
+        0% { transform: rotate(0deg); }
+        100% { transform: rotate(360deg); }
+    }
+
     ::-webkit-scrollbar-track {
         background: rgba(255, 255, 255, 0.02);
     }
@@ -741,9 +796,20 @@ What interests you?"""
 
     return "I'm here to help with Machine Learning research! 👋 Ask me about any ML topics or papers."
 
-# Chat input
+# Chat input - Process new query first (before displaying messages)
 query = st.chat_input("💬 Ask me anything about ML research...")
 
+if query:
+    # Add message to session state immediately
+    st.session_state["messages"].append({
+        "query": query,
+        "answer": None,
+        "context": [],
+        "processing": True
+    })
+    # Force rerun to show the user message immediately
+    st.rerun()
+
 # Display chat history
 for i, msg in enumerate(st.session_state["messages"]):
     # Show user message
@@ -767,18 +833,28 @@ for i, msg in enumerate(st.session_state["messages"]):
                 )
                 if idx < len(msg["context"]):
                     st.markdown("---")
-
-    #
+    elif msg.get("processing", False):
+        # Process this message now - show thinking indicator and generate response
        with st.chat_message("assistant", avatar="🤖"):
-
-
+            # Show immediate thinking indicator
+            status_container = st.container()
+            with status_container:
+                st.markdown(
+                    '<div class="status-indicator"><div class="spinner"></div>🔍 Searching research papers</div>',
+                    unsafe_allow_html=True
+                )
 
-            #
+            # Process the response
            if is_casual_conversation(msg["query"]):
+                # Handle casual conversation
+                time.sleep(0.5)  # Brief pause for better UX
                casual_response = get_casual_response(msg["query"])
 
-                #
+                # Clear status and show response
+                status_container.empty()
                response_placeholder = st.empty()
+
+                # Smooth streaming effect
                full_response = ""
                words = casual_response.split()
 
@@ -787,19 +863,26 @@ for i, msg in enumerate(st.session_state["messages"]):
                    response_placeholder.markdown(full_response)
                    time.sleep(0.02)
 
+                # Update session state
                st.session_state["messages"][i]["answer"] = casual_response
+                st.session_state["messages"][i]["processing"] = False
                st.rerun()
 
            else:
                # Research question - full RAG pipeline
-                rag_chain, adv_retriever = build_chain()
-
-                docs = []
-                answer_text = ""
-                error_occurred = False
-
                try:
+                    rag_chain, adv_retriever = build_chain()
+
+                    # Update status
+                    status_container.empty()
+                    with status_container:
+                        st.markdown(
+                            '<div class="status-indicator"><div class="spinner"></div>🧠 Analyzing documents</div>',
+                            unsafe_allow_html=True
+                        )
+
                    docs = adv_retriever.get_relevant_documents(msg["query"])
+                    answer_text = ""
 
                    if not docs:
                        answer_text = """I couldn't find any relevant research papers in the database that match your query.
@@ -812,8 +895,6 @@
 
 The current database focuses on ArXiv ML papers, but may not cover all research areas comprehensively."""
                    else:
-                        thinking_placeholder.markdown('<p class="thinking">🧠 Analyzing documents...</p>', unsafe_allow_html=True)
-
                        # Check relevance
                        formatted_context = format_docs(docs)
                        relevance_check_chain = {"context": RunnablePassthrough(), "question": RunnablePassthrough()} | relevance_prompt | llm
@@ -834,66 +915,67 @@ The current database focuses on ArXiv ML papers, but may not cover all research
 
 I can only provide answers based on the ArXiv papers in the database."""
                        else:
-                            #
-
+                            # Update status for generation
+                            status_container.empty()
+                            with status_container:
+                                st.markdown(
+                                    '<div class="status-indicator"><div class="spinner"></div>✍️ Generating response</div>',
+                                    unsafe_allow_html=True
+                                )
+
                            answer = rag_chain.invoke(msg["query"])
                            answer_text = answer.content if hasattr(answer, "content") else str(answer)
 
+                    # Clear status and display response with streaming
+                    status_container.empty()
+
+                    # Stream response
+                    import re
+                    response_placeholder = st.empty()
+                    parts = re.split(r'(\n\n|(?<=[.!?])\s+)', answer_text)
+
+                    full_response = ""
+                    for part in parts:
+                        full_response += part
+                        response_placeholder.markdown(full_response)
+                        time.sleep(0.03)
+
+                    # Update session state
+                    st.session_state["messages"][i]["answer"] = answer_text
+                    st.session_state["messages"][i]["context"] = docs
+                    st.session_state["messages"][i]["processing"] = False
+
+                    # Show retrieved documents
+                    if docs:
+                        with st.expander(f"📄 View {len(docs)} Retrieved Documents", expanded=False):
+                            for idx, doc in enumerate(docs, 1):
+                                st.markdown(f"**📄 Document {idx}**")
+                                st.caption(_format_metadata(doc.metadata))
+                                st.text_area(
+                                    f"Content {idx}",
+                                    doc.page_content[:800] + ("..." if len(doc.page_content) > 800 else ""),
+                                    height=150,
+                                    key=f"new_doc_{i}_{idx}",
+                                    disabled=True
+                                )
+                                if idx < len(docs):
+                                    st.markdown("---")
+
+                    st.rerun()
+
                except Exception as e:
-
+                    # Handle errors
+                    status_container.empty()
                    msg_err = str(e)
                    if "models/" in msg_err and "not found" in msg_err.lower():
                        answer_text = "⚠️ Selected model not found. Try a different model in the sidebar."
                    else:
                        answer_text = f"⚠️ An error occurred: {e}\n\nPlease try again or rebuild the index."
-
-
-
-
-
-            import re
-            response_placeholder = st.empty()
-            parts = re.split(r'(\n\n|(?<=[.!?])\s+)', answer_text)
-
-            full_response = ""
-            for part in parts:
-                full_response += part
-                response_placeholder.markdown(full_response)
-                time.sleep(0.03)
-
-            # Update session state
-            st.session_state["messages"][i]["answer"] = answer_text
-            st.session_state["messages"][i]["context"] = docs
-
-            # Show retrieved documents
-            if docs:
-                with st.expander(f"📄 View {len(docs)} Retrieved Documents", expanded=False):
-                    for idx, doc in enumerate(docs, 1):
-                        st.markdown(f"**📄 Document {idx}**")
-                        st.caption(_format_metadata(doc.metadata))
-                        st.text_area(
-                            f"Content {idx}",
-                            doc.page_content[:800] + ("..." if len(doc.page_content) > 800 else ""),
-                            height=150,
-                            key=f"new_doc_{i}_{idx}",
-                            disabled=True
-                        )
-                        if idx < len(docs):
-                            st.markdown("---")
-
-            st.rerun()
-
-# Process new query
-if query:
-    # Add message to session state immediately
-    st.session_state["messages"].append({
-        "query": query,
-        "answer": None,
-        "context": []
-    })
-
-    # Force rerun to show the user message immediately
-    st.rerun()
+
+                    st.error(answer_text)
+                    st.session_state["messages"][i]["answer"] = answer_text
+                    st.session_state["messages"][i]["processing"] = False
+                    st.rerun()
 
 # Footer with tips - only show if there are messages
 if len(st.session_state["messages"]) > 0:
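The substantive change in this commit is a two-phase chat flow: a submitted query is only appended to st.session_state (flagged "processing": True) before an immediate st.rerun(), so the user's bubble renders at once, and the expensive retrieval and generation happen on the next script run inside the history loop. A minimal, self-contained sketch of that pattern, assuming only streamlit and time (the sleep, the "_Thinking..._" placeholder, and the echo answer are illustrative stand-ins for the app's real status indicator and RAG pipeline):

import time
import streamlit as st

if "messages" not in st.session_state:
    st.session_state["messages"] = []

query = st.chat_input("Ask something...")
if query:
    # Phase 1: queue the message and rerun so it renders before any slow work.
    st.session_state["messages"].append(
        {"query": query, "answer": None, "processing": True}
    )
    st.rerun()

for i, msg in enumerate(st.session_state["messages"]):
    with st.chat_message("user"):
        st.markdown(msg["query"])
    if msg["answer"] is not None:
        # Answered on an earlier run: just replay it.
        with st.chat_message("assistant"):
            st.markdown(msg["answer"])
    elif msg.get("processing", False):
        # Phase 2: generate the answer for the queued message on this run.
        with st.chat_message("assistant"):
            status = st.empty()
            status.markdown("_Thinking..._")
            time.sleep(1.0)  # stand-in for retrieval + generation
            status.empty()
            st.session_state["messages"][i]["answer"] = f"You said: {msg['query']}"
            st.session_state["messages"][i]["processing"] = False
            st.rerun()

Clearing the flag on failure, as the commit's except branch now does, is what keeps a failing query from being retried on every subsequent rerun.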
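The sentence-level streaming added inside the try block leans on a detail of re.split: because the pattern wraps the whole delimiter in a single capturing group, the separators are returned as list items, so concatenating the parts reproduces answer_text exactly while the placeholder grows chunk by chunk. A standalone check of that behavior (the sample answer_text below is invented for illustration):

import re

answer_text = (
    "Transformers use self-attention. Each token attends to all others!\n\n"
    "Cost grows quadratically with sequence length."
)

# One capturing group => re.split returns the delimiters too.
parts = re.split(r'(\n\n|(?<=[.!?])\s+)', answer_text)

assert "".join(parts) == answer_text  # nothing is lost between chunks
print(parts)
# ['Transformers use self-attention.', ' ', 'Each token attends to all others!',
#  '\n\n', 'Cost grows quadratically with sequence length.']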
|