# RAG_chatbot/chat.py
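"""Streamlit RAG chatbot: answers questions with Google Gemini, optionally
grounded in user-uploaded TXT/PDF files through a FAISS similarity index."""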
import warnings

import faiss
import fitz  # PyMuPDF
import google.generativeai as genai
import numpy as np
import streamlit as st
from langchain.text_splitter import RecursiveCharacterTextSplitter
from langchain_community.embeddings import HuggingFaceEmbeddings

# Silence noisy deprecation warnings from the langchain/transformers stack.
warnings.warn = lambda *args, **kwargs: None
warnings.filterwarnings("ignore")
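
# The embedding model is cached so the sentence-transformers weights are
# loaded once per session instead of on every index build and every query.
@st.cache_resource(show_spinner=False)
def get_embeddings():
    return HuggingFaceEmbeddings(model_name="all-MiniLM-L6-v2")
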
# Initialize session state variables
if "messages" not in st.session_state:
st.session_state.messages = []
if "uploaded_files" not in st.session_state:
st.session_state.uploaded_files = []
if "api_key" not in st.session_state:
st.session_state.api_key = ""
def extract_text_from_pdf(file):
    """Return the full plain text of an uploaded PDF."""
    file.seek(0)
    pdf_bytes = file.read()
    if not pdf_bytes:
        raise ValueError(f"File {file.name} is empty or could not be read.")
    try:
        doc = fitz.open(stream=pdf_bytes, filetype="pdf")
    except Exception as e:
        raise RuntimeError(f"Error opening file {file.name}: {e}")
    text = ""
    for page in doc:
        text += page.get_text()
    doc.close()
    return text
def process_files(files):
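    """Concatenate the text content of the uploaded TXT and PDF files."""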
texts = []
for file in files:
if file.type == "text/plain":
content = file.getvalue().decode("utf-8")
texts.append(content)
elif file.type == "application/pdf":
content = extract_text_from_pdf(file)
texts.append(content)
return "\n".join(texts)
def build_index(text):
    """Split text into overlapping chunks, embed them, and build a FAISS L2 index."""
    splitter = RecursiveCharacterTextSplitter(chunk_size=500, chunk_overlap=50)
    chunks = splitter.split_text(text)
    if not chunks:
        raise ValueError("No text chunks to index.")
    # embed_documents batches the chunks instead of embedding them one by one
    vectors = get_embeddings().embed_documents(chunks)
    index = faiss.IndexFlatL2(len(vectors[0]))
    index.add(np.array(vectors, dtype="float32"))
    return index, chunks
def retrieve_chunks(query, index, chunks, k=3):
    """Return the k chunks closest to the query in embedding space."""
    query_vector = np.array([get_embeddings().embed_query(query)], dtype="float32")
    distances, indices = index.search(query_vector, k)
    # FAISS pads the result with -1 when the index holds fewer than k vectors
    return [chunks[i] for i in indices[0] if i != -1]
def create_sidebar():
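    """Render the sidebar: API-key entry and file uploads."""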
with st.sidebar:
st.title("🤖 Gemini Chatbot")
# API key input + validate button
api_key_input = st.text_input("Google API Key:", type="password")
if st.button("Validate API"):
if api_key_input.strip():
st.session_state.api_key = api_key_input.strip()
st.success("API Key saved ✅")
else:
st.error("Please enter a valid API key.")
# Show file uploader only if API key is set
if st.session_state.api_key:
            uploaded_files = st.file_uploader(
                "📂 Upload your files (txt or pdf)",
                type=["txt", "pdf"],
                accept_multiple_files=True,
            )
if uploaded_files:
for file in uploaded_files:
# Avoid duplicates
if file.name not in [f.name for f in st.session_state.uploaded_files]:
st.session_state.uploaded_files.append(file)
st.success(f"{len(st.session_state.uploaded_files)} files loaded.")
if st.session_state.uploaded_files:
st.markdown("**Files currently loaded:**")
for f in st.session_state.uploaded_files:
st.write(f.name)
def main():
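    """App entry point: configure Gemini, build the index, run the chat UI."""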
if "faiss_index" not in st.session_state:
st.session_state.faiss_index = None
if "chunks" not in st.session_state:
st.session_state.chunks = []
st.set_page_config(page_title="Gemini Chatbot")
create_sidebar()
if not st.session_state.api_key:
st.warning("👆 Please enter and validate your API key in the sidebar.")
return
genai.configure(api_key=st.session_state.api_key)
model = genai.GenerativeModel("gemini-2.0-flash")
    # Build the index once, after files have been uploaded
    if st.session_state.uploaded_files and st.session_state.faiss_index is None:
        full_text = process_files(st.session_state.uploaded_files)
        if full_text:
            index, chunks = build_index(full_text)
            st.session_state.faiss_index = index
            st.session_state.chunks = chunks
            st.write(f"Number of chunks: {len(chunks)}")
            st.write(f"FAISS dimension: {index.d}")
            st.write(f"FAISS index size: {index.ntotal}")
st.title("💬 Gemini Chatbot")
# Show chat history
chat_container = st.container()
with chat_container:
for msg in st.session_state.messages:
with st.chat_message(msg["role"]):
st.markdown(msg["content"])
    # Text input pinned below the chat history
    prompt = st.text_area("Type your question here...", key="input")
    if st.button("Send"):
        if not prompt.strip():
            st.warning("Please write a message before sending.")
            return
        # Append the user message and echo it in the chat
        st.session_state.messages.append({"role": "user", "content": prompt})
        with st.chat_message("user"):
            st.markdown(prompt)
with st.chat_message("assistant"):
with st.spinner("Thinking..."):
try:
                    # Retrieve the most relevant chunks from the indexed files
                    if st.session_state.faiss_index is not None:
                        relevant_chunks = retrieve_chunks(prompt, st.session_state.faiss_index, st.session_state.chunks, k=3)
                        context = "\n---\n".join(relevant_chunks)
                        prompt_with_context = f"Use the following context to answer the question:\n{context}\n\nQuestion: {prompt}"
                    else:
                        prompt_with_context = prompt  # no files indexed: raw prompt
                    response = model.generate_content(prompt_with_context)
                    reply = response.text or "Sorry, I couldn't get a response."
                    st.markdown(reply)
                    st.session_state.messages.append({"role": "assistant", "content": reply})
except Exception as e:
st.error(f"Error: {e}")
st.session_state.messages.append({"role": "assistant", "content": f"Error: {e}"})
if __name__ == "__main__":
main()