import streamlit as st
import google.generativeai as genai
import fitz  # PyMuPDF
from langchain_community.embeddings import HuggingFaceEmbeddings
from langchain.text_splitter import RecursiveCharacterTextSplitter
import numpy as np
import faiss
import warnings

# Silence all warnings, including libraries that call warnings.warn directly.
def warn(*args, **kwargs):
    pass

warnings.warn = warn
warnings.filterwarnings("ignore")
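# Assumed dependencies, inferred from the imports above (pin versions as needed):
#   pip install streamlit google-generativeai PyMuPDF langchain langchain-community
#   pip install sentence-transformers faiss-cpu numpy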
# Initialize session state variables
if "messages" not in st.session_state:
    st.session_state.messages = []
if "uploaded_files" not in st.session_state:
    st.session_state.uploaded_files = []
if "api_key" not in st.session_state:
    st.session_state.api_key = ""
def extract_text_from_pdf(file):
    """Extract the plain text of every page of an uploaded PDF."""
    file.seek(0)
    pdf_bytes = file.read()
    if not pdf_bytes:
        raise ValueError(f"File {file.name} is empty or could not be read.")
    try:
        doc = fitz.open(stream=pdf_bytes, filetype="pdf")
    except Exception as e:
        raise RuntimeError(f"Error opening file {file.name}: {e}")
    text = ""
    for page in doc:
        text += page.get_text()
    doc.close()
    return text
def process_files(files):
    """Concatenate the text content of all uploaded files into one string."""
    texts = []
    for file in files:
        if file.type == "text/plain":
            texts.append(file.getvalue().decode("utf-8"))
        elif file.type == "application/pdf":
            texts.append(extract_text_from_pdf(file))
        # Other MIME types are silently skipped.
    return "\n".join(texts)
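# Optimization sketch (not in the original code): the embedding model was
# instantiated from scratch in both build_index() and retrieve_chunks(),
# which reloads the sentence-transformers weights on every call.
# st.cache_resource keeps a single instance alive across Streamlit reruns.
@st.cache_resource
def get_embeddings():
    return HuggingFaceEmbeddings(model_name="all-MiniLM-L6-v2")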
def build_index(text):
    """Split the text into chunks, embed them, and build a FAISS L2 index."""
    splitter = RecursiveCharacterTextSplitter(chunk_size=500, chunk_overlap=50)
    chunks = splitter.split_text(text)
    embeddings = get_embeddings()
    # embed_documents() batches the chunks instead of embedding them one by one
    vectors = embeddings.embed_documents(chunks)
    dimension = len(vectors[0])
    index = faiss.IndexFlatL2(dimension)
    index.add(np.array(vectors).astype("float32"))
    return index, chunks
def retrieve_chunks(query, index, chunks, k=3):
    """Return the k chunks whose embeddings are closest to the query."""
    embeddings = get_embeddings()
    query_vector = np.array([embeddings.embed_query(query)]).astype("float32")
    distances, indices = index.search(query_vector, k)
    # FAISS pads the result with -1 when k exceeds the number of stored vectors
    return [chunks[i] for i in indices[0] if i != -1]
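# Note: IndexFlatL2 performs exact (brute-force) nearest-neighbour search,
# which is fine at this scale; approximate indexes only pay off for much
# larger corpora.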
def create_sidebar():
    with st.sidebar:
        st.title("🤖 Gemini Chatbot")

        # API key input + validate button
        api_key_input = st.text_input("Google API Key:", type="password")
        if st.button("Validate API"):
            if api_key_input.strip():
                st.session_state.api_key = api_key_input.strip()
                st.success("API Key saved ✅")
            else:
                st.error("Please enter a valid API key.")

        # Show the file uploader only once an API key is set
        if st.session_state.api_key:
            uploaded_files = st.file_uploader(
                "📂 Upload your files (txt, pdf, etc.)",
                accept_multiple_files=True,
            )
            if uploaded_files:
                for file in uploaded_files:
                    # Avoid duplicates
                    if file.name not in [f.name for f in st.session_state.uploaded_files]:
                        st.session_state.uploaded_files.append(file)
                st.success(f"{len(st.session_state.uploaded_files)} files loaded.")

            if st.session_state.uploaded_files:
                st.markdown("**Files currently loaded:**")
                for f in st.session_state.uploaded_files:
                    st.write(f.name)
def main():
    # set_page_config must be the first Streamlit call that renders anything
    st.set_page_config(page_title="Gemini Chatbot")

    if "faiss_index" not in st.session_state:
        st.session_state.faiss_index = None
    if "chunks" not in st.session_state:
        st.session_state.chunks = []

    create_sidebar()

    if not st.session_state.api_key:
        st.warning("👆 Please enter and validate your API key in the sidebar.")
        return

    genai.configure(api_key=st.session_state.api_key)
    model = genai.GenerativeModel("gemini-2.0-flash")

    # Build the index once, the first time files are available
    if st.session_state.uploaded_files and st.session_state.faiss_index is None:
        full_text = process_files(st.session_state.uploaded_files)
        if full_text:
            index, chunks = build_index(full_text)
            # Store the index *before* printing diagnostics; the original
            # read st.session_state.faiss_index here while it was still None,
            # so its "not yet initialized" warning always fired.
            st.session_state.faiss_index = index
            st.session_state.chunks = chunks
            st.write(f"Number of chunks: {len(chunks)}")
            st.write(f"FAISS dimension: {index.d}")
            st.write(f"FAISS index size: {index.ntotal}")
| st.title("💬 Gemini Chatbot") | |
| # Show chat history | |
| chat_container = st.container() | |
| with chat_container: | |
| for msg in st.session_state.messages: | |
| with st.chat_message(msg["role"]): | |
| st.markdown(msg["content"]) | |
    # Text input, always at the bottom
    prompt = st.text_area("Type your question here...", key="input")

    if st.button("Send"):
        if not prompt.strip():
            st.warning("Please write a message before sending.")
            return

        # Append and display the user message
        st.session_state.messages.append({"role": "user", "content": prompt})
        with st.chat_message("user"):
            st.markdown(prompt)
        with st.chat_message("assistant"):
            with st.spinner("Thinking..."):
                try:
                    # Retrieve relevant chunks from the uploaded files
                    if st.session_state.faiss_index is not None:
                        relevant_chunks = retrieve_chunks(
                            prompt, st.session_state.faiss_index, st.session_state.chunks, k=3
                        )
                        context = "\n---\n".join(
                            chunk if isinstance(chunk, str) else chunk.page_content
                            for chunk in relevant_chunks
                        )
                        prompt_with_context = (
                            f"Use the following context to answer the question:\n"
                            f"{context}\n\nQuestion: {prompt}"
                        )
                    else:
                        # No files indexed yet: fall back to the raw question
                        prompt_with_context = prompt
                    response = model.generate_content(prompt_with_context)
                    if response.text:
                        st.markdown(response.text)
                        st.session_state.messages.append({"role": "assistant", "content": response.text})
                    else:
                        st.markdown("Sorry, I couldn't get a response.")
                        st.session_state.messages.append({"role": "assistant", "content": "Sorry, I couldn't get a response."})
                except Exception as e:
                    st.error(f"Error: {e}")
                    st.session_state.messages.append({"role": "assistant", "content": f"Error: {e}"})
if __name__ == "__main__":
    main()
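# To run locally (assuming this file is saved as app.py):
#   streamlit run app.py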