# Streamlit chatbot: Gemini + FAISS retrieval over uploaded txt/pdf files.
import streamlit as st
import google.generativeai as genai
import fitz
from langchain_community.embeddings import HuggingFaceEmbeddings
from langchain.text_splitter import RecursiveCharacterTextSplitter
import numpy as np
import faiss
def warn(*args, **kwargs):
    """No-op replacement that silently swallows every warnings.warn call."""


import warnings

# Monkey-patch warnings.warn and suppress all warning filters so noisy
# third-party libraries do not clutter the Streamlit UI.
warnings.warn = warn
warnings.filterwarnings('ignore')
from langchain_community.document_loaders import PyPDFLoader
# Seed the session-state keys the app reads, so first-run access is safe.
for _key, _default in (("messages", []), ("uploaded_files", []), ("api_key", "")):
    if _key not in st.session_state:
        st.session_state[_key] = _default
def extract_text_from_pdf(file):
    """Extract the full text of an uploaded PDF.

    Args:
        file: File-like object with ``seek``/``read`` and a ``name``
            attribute (e.g. a Streamlit ``UploadedFile``).

    Returns:
        str: Concatenated text of every page, in page order.

    Raises:
        ValueError: If the file is empty or could not be read.
        RuntimeError: If PyMuPDF cannot open the bytes as a PDF.
    """
    file.seek(0)
    pdf_bytes = file.read()
    if not pdf_bytes:
        raise ValueError(f"Le fichier {file.name} est vide ou n’a pas pu être lu.")
    try:
        doc = fitz.open(stream=pdf_bytes, filetype="pdf")
    except Exception as e:
        raise RuntimeError(f"Erreur lors de l'ouverture du fichier {file.name} : {e}")
    # Fix: close the document when done (the original leaked it) and build
    # the text with one join instead of quadratic string concatenation.
    with doc:
        return "".join(page.get_text() for page in doc)
def process_files(files):
    """Concatenate the textual content of the given uploads.

    Plain-text files are decoded as UTF-8; PDFs go through
    ``extract_text_from_pdf``. Any other MIME type is skipped silently.

    Args:
        files: Iterable of uploaded file objects exposing ``type``.

    Returns:
        str: Per-file texts joined with newlines (empty string for no files).
    """
    extracted = []
    for upload in files:
        mime = upload.type
        if mime == "text/plain":
            extracted.append(upload.getvalue().decode("utf-8"))
        elif mime == "application/pdf":
            extracted.append(extract_text_from_pdf(upload))
    return "\n".join(extracted)
def build_index(text):
    """Split text into chunks, embed them, and build a FAISS L2 index.

    Args:
        text: Full document text to index.

    Returns:
        tuple: ``(index, chunks)`` where ``index`` is a
        ``faiss.IndexFlatL2`` and ``chunks`` is the list of text chunks.
        Returns ``(None, [])`` when the text yields no chunks — the
        original crashed with IndexError (``vectors[0]``) in that case.
    """
    splitter = RecursiveCharacterTextSplitter(chunk_size=500, chunk_overlap=50)
    chunks = splitter.split_text(text)
    if not chunks:
        # Empty or whitespace-only input: nothing to embed.
        return None, []
    embeddings = HuggingFaceEmbeddings(model_name="all-MiniLM-L6-v2")
    # Fix: embed all chunks in one batched call instead of one model
    # invocation per chunk.
    vectors = np.array(embeddings.embed_documents(chunks), dtype="float32")
    index = faiss.IndexFlatL2(vectors.shape[1])
    index.add(vectors)
    return index, chunks
def retrieve_chunks(query, index, chunks, k=3):
    """Return up to ``k`` chunks most similar to ``query`` (L2 distance).

    Args:
        query: Natural-language question to embed.
        index: FAISS index built over ``chunks`` (same embedding model).
        chunks: Chunk texts, aligned with the index's insertion order.
        k: Maximum number of chunks to return.

    Returns:
        list[str]: Matching chunks, best first.
    """
    embeddings = HuggingFaceEmbeddings(model_name="all-MiniLM-L6-v2")
    query_vector = np.array([embeddings.embed_query(query)], dtype="float32")
    # Fix: clamp k to the corpus size and drop the -1 padding FAISS emits
    # when fewer than k vectors exist (chunks[-1] silently returned the
    # wrong chunk in the original).
    _distances, indices = index.search(query_vector, min(k, len(chunks)))
    return [chunks[i] for i in indices[0] if i != -1]
def create_sidebar():
    """Render the sidebar: API-key entry, then (once a key is saved) the
    file uploader and the list of currently loaded files.

    Mutates ``st.session_state.api_key`` and
    ``st.session_state.uploaded_files``.
    """
    with st.sidebar:
        st.title("🤖 Gemini Chatbot")
        # API key input + validate button
        api_key_input = st.text_input("Google API Key:", type="password")
        if st.button("Validate API"):
            if api_key_input.strip():
                st.session_state.api_key = api_key_input.strip()
                st.success("API Key saved ✅")
            else:
                st.error("Please enter a valid API key.")
        # Show file uploader only if API key is set
        if st.session_state.api_key:
            uploaded_files = st.file_uploader(
                "📂 Upload your files (txt, pdf, etc.)",
                accept_multiple_files=True
            )
            if uploaded_files:
                # Fix: build the known-name set once instead of rescanning
                # the whole list for every candidate file (was O(n^2)).
                known_names = {f.name for f in st.session_state.uploaded_files}
                for file in uploaded_files:
                    # Avoid duplicates
                    if file.name not in known_names:
                        st.session_state.uploaded_files.append(file)
                        known_names.add(file.name)
                st.success(f"{len(st.session_state.uploaded_files)} files loaded.")
            if st.session_state.uploaded_files:
                st.markdown("**Files currently loaded:**")
                for f in st.session_state.uploaded_files:
                    st.write(f.name)
def main():
    """Streamlit entry point: page config, sidebar, one-time index build,
    chat history rendering, and the RAG-augmented question/answer loop.
    """
    if "faiss_index" not in st.session_state:
        st.session_state.faiss_index = None
    if "chunks" not in st.session_state:
        st.session_state.chunks = []
    st.set_page_config(page_title="Gemini Chatbot")
    create_sidebar()
    if not st.session_state.api_key:
        st.warning("👆 Please enter and validate your API key in the sidebar.")
        return
    genai.configure(api_key=st.session_state.api_key)
    model = genai.GenerativeModel("gemini-2.0-flash")
    # Build index if not done yet
    if st.session_state.uploaded_files and st.session_state.faiss_index is None:
        full_text = process_files(st.session_state.uploaded_files)
        if full_text:
            index, chunks = build_index(full_text)
            # Fix: store the index BEFORE the debug output. The original
            # inspected session_state before assigning it, so the first run
            # always warned that the index was not initialized yet.
            st.session_state.faiss_index = index
            st.session_state.chunks = chunks
            st.write(f"Nombre de chunks : {len(chunks)}")
            if st.session_state.faiss_index is not None:
                st.write(f"Dimension FAISS : {st.session_state.faiss_index.d}")
                st.write(f"Taille index FAISS : {st.session_state.faiss_index.ntotal}")
            else:
                st.warning("L'index FAISS n'est pas encore initialisé.")
    st.title("💬 Gemini Chatbot")
    # Show chat history
    chat_container = st.container()
    with chat_container:
        for msg in st.session_state.messages:
            with st.chat_message(msg["role"]):
                st.markdown(msg["content"])
    # Text field always at the bottom of the page.
    prompt = st.text_area("Type your question here...", key="input")
    if st.button("Send"):
        # Append the user message (reject empty submissions first).
        if not prompt.strip():
            st.warning("Please write a message before sending.")
            return
        st.session_state.messages.append({"role": "user", "content": prompt})
        with st.chat_message("user"):
            st.markdown(prompt)
        with st.chat_message("assistant"):
            with st.spinner("Thinking..."):
                try:
                    # Retrieve relevant chunks from the uploaded files.
                    # Fix: explicit None check — FAISS index objects are
                    # SWIG wrappers and should not rely on truthiness.
                    if st.session_state.faiss_index is not None:
                        relevant_chunks = retrieve_chunks(prompt, st.session_state.faiss_index, st.session_state.chunks, k=3)
                        context = "\n---\n".join(chunk if isinstance(chunk, str) else chunk.page_content for chunk in relevant_chunks)
                        prompt_with_context = f"Use the following context to answer the question:\n{context}\n\nQuestion: {prompt}"
                    else:
                        prompt_with_context = prompt  # fallback: no RAG context
                    response = model.generate_content(prompt_with_context)
                    if response.text:
                        st.markdown(response.text)
                        st.session_state.messages.append({"role": "assistant", "content": response.text})
                    else:
                        st.markdown("Sorry, I couldn't get a response.")
                        st.session_state.messages.append({"role": "assistant", "content": "Sorry, I couldn't get a response."})
                except Exception as e:
                    st.error(f"Error: {e}")
                    st.session_state.messages.append({"role": "assistant", "content": f"Error: {e}"})
# Script entry point (run via `streamlit run <this file>`).
if __name__ == "__main__":
    main()