File size: 6,961 Bytes
02136c1
 
 
 
 
 
 
 
 
 
 
 
 
 
cdf3e4e
02136c1
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
import streamlit as st
import google.generativeai as genai
import fitz 
from langchain_community.embeddings import HuggingFaceEmbeddings
from langchain.text_splitter import RecursiveCharacterTextSplitter
import numpy as np
import faiss
# Globally silence Python warnings: the HuggingFace/LangChain stack emits
# noisy deprecation warnings that would clutter the Streamlit UI.
def warn(*args, **kwargs):
    # No-op replacement installed over warnings.warn below.
    pass
import warnings
# Monkeypatch warnings.warn so even libraries that call it directly stay quiet.
warnings.warn = warn
warnings.filterwarnings('ignore')

from langchain_community.document_loaders import PyPDFLoader


# Seed the Streamlit session state with the keys this app relies on,
# so later code can read them without existence checks.
for _key, _default in (("messages", []), ("uploaded_files", []), ("api_key", "")):
    if _key not in st.session_state:
        st.session_state[_key] = _default

def extract_text_from_pdf(file):
    """Extract the full plain text of an uploaded PDF.

    Args:
        file: uploaded file object (Streamlit ``UploadedFile``-like) exposing
            ``seek()``, ``read()`` and ``name``.

    Returns:
        str: the concatenated text of every page.

    Raises:
        ValueError: if the file yields no bytes.
        RuntimeError: if PyMuPDF cannot open the bytes as a PDF.
    """
    file.seek(0)  # the stream may already have been consumed by a prior read
    pdf_bytes = file.read()

    if not pdf_bytes:
        raise ValueError(f"Le fichier {file.name} est vide ou n’a pas pu être lu.")

    try:
        doc = fitz.open(stream=pdf_bytes, filetype="pdf")
    except Exception as e:
        raise RuntimeError(f"Erreur lors de l'ouverture du fichier {file.name} : {e}") from e

    # Use the Document as a context manager so the handle is closed even on
    # error — the original leaked it.
    with doc:
        return "".join(page.get_text() for page in doc)


def process_files(files):
    """Concatenate the textual content of uploaded files.

    Plain-text files are UTF-8 decoded, PDFs go through
    ``extract_text_from_pdf``; any other MIME type is silently skipped.

    Args:
        files: iterable of uploaded file objects exposing ``type``.

    Returns:
        str: all extracted texts joined by newlines ("" when nothing matched).
    """
    collected = []
    for uploaded in files:
        if uploaded.type == "application/pdf":
            collected.append(extract_text_from_pdf(uploaded))
        elif uploaded.type == "text/plain":
            collected.append(uploaded.getvalue().decode("utf-8"))
    return "\n".join(collected)

def build_index(text):
    """Split *text* into chunks, embed them, and build a FAISS L2 index.

    Args:
        text: full document text to index.

    Returns:
        tuple[faiss.IndexFlatL2, list[str]]: the index and the chunk list;
        row ``i`` of the index embeds ``chunks[i]``.

    Raises:
        ValueError: if no chunks can be produced (the original crashed with
            an opaque ``IndexError`` on ``vectors[0]`` in that case).
    """
    splitter = RecursiveCharacterTextSplitter(chunk_size=500, chunk_overlap=50)
    chunks = splitter.split_text(text)
    if not chunks:
        raise ValueError("No text chunks could be extracted; cannot build an index.")

    embeddings = HuggingFaceEmbeddings(model_name="all-MiniLM-L6-v2")
    # embed_documents batches the whole list in one call instead of invoking
    # the model once per chunk as the original did.
    vectors = np.array(embeddings.embed_documents(chunks), dtype="float32")

    index = faiss.IndexFlatL2(vectors.shape[1])
    index.add(vectors)

    return index, chunks

def retrieve_chunks(query, index, chunks, k=3):
    """Return up to *k* chunks most similar to *query* (L2 distance).

    Fix: when ``k`` exceeds the number of indexed vectors, FAISS pads the
    result with ``-1`` ids; the original then did ``chunks[-1]`` and silently
    returned the last chunk. We clamp ``k`` and drop invalid ids instead.

    Args:
        query: the user question to embed.
        index: a FAISS index built over the embeddings of *chunks*.
        chunks: chunk texts, aligned with the index rows.
        k: maximum number of neighbours to return.

    Returns:
        list[str]: the retrieved chunk texts (possibly fewer than *k*).
    """
    k = min(k, len(chunks))  # never ask FAISS for more neighbours than exist
    if k <= 0:
        return []

    embeddings = HuggingFaceEmbeddings(model_name="all-MiniLM-L6-v2")
    query_vector = np.array([embeddings.embed_query(query)], dtype="float32")
    _distances, indices = index.search(query_vector, k)
    return [chunks[i] for i in indices[0] if i >= 0]


def create_sidebar():
    """Render the sidebar: API-key entry, then (once a key is saved) the
    file uploader and the list of currently loaded files."""
    with st.sidebar:
        st.title("🤖 Gemini Chatbot")

        # API key input + validate button.
        api_key_input = st.text_input("Google API Key:", type="password")
        if st.button("Validate API"):
            key = api_key_input.strip()
            if key:
                st.session_state.api_key = key
                st.success("API Key saved ✅")
            else:
                st.error("Please enter a valid API key.")

        # The uploader only appears after a key has been validated.
        if not st.session_state.api_key:
            return

        uploaded_files = st.file_uploader(
            "📂 Upload your files (txt, pdf, etc.)", 
            accept_multiple_files=True
        )
        if uploaded_files:
            # Skip files whose name is already stored (avoid duplicates).
            known_names = {f.name for f in st.session_state.uploaded_files}
            for file in uploaded_files:
                if file.name not in known_names:
                    st.session_state.uploaded_files.append(file)
                    known_names.add(file.name)
            st.success(f"{len(st.session_state.uploaded_files)} files loaded.")

        if st.session_state.uploaded_files:
            st.markdown("**Files currently loaded:**")
            for f in st.session_state.uploaded_files:
                st.write(f.name)

def main():
    """Streamlit entry point: page setup, sidebar, one-shot index build,
    then the chat loop with retrieval-augmented prompting."""
    # Session-state defaults for the retrieval index.
    if "faiss_index" not in st.session_state:
        st.session_state.faiss_index = None
    if "chunks" not in st.session_state:
        st.session_state.chunks = []
    st.set_page_config(page_title="Gemini Chatbot")
    create_sidebar()

    # Nothing works without a validated API key.
    if not st.session_state.api_key:
        st.warning("👆 Please enter and validate your API key in the sidebar.")
        return

    genai.configure(api_key=st.session_state.api_key)
    model = genai.GenerativeModel("gemini-2.0-flash")

    # Build the FAISS index once, the first time files are available.
    if st.session_state.uploaded_files and st.session_state.faiss_index is None:
        full_text = process_files(st.session_state.uploaded_files)
        if full_text:
            index, chunks = build_index(full_text)
            # Store first, then report: the original inspected
            # st.session_state.faiss_index BEFORE assigning it, so the debug
            # stats were never shown on the first build and a spurious
            # "not initialised" warning was emitted instead.
            st.session_state.faiss_index = index
            st.session_state.chunks = chunks
            st.write(f"Nombre de chunks : {len(chunks)}")
            st.write(f"Dimension FAISS : {index.d}")
            st.write(f"Taille index FAISS : {index.ntotal}")

    st.title("💬 Gemini Chatbot")

    # Replay the conversation so far.
    with st.container():
        for msg in st.session_state.messages:
            with st.chat_message(msg["role"]):
                st.markdown(msg["content"])

    # Input field stays at the bottom of the page.
    prompt = st.text_area("Type your question here...", key="input")

    if st.button("Send"):
        if not prompt.strip():
            st.warning("Please write a message before sending.")
            return

        st.session_state.messages.append({"role": "user", "content": prompt})
        with st.chat_message("user"):
            st.markdown(prompt)

        with st.chat_message("assistant"):
            with st.spinner("Thinking..."):
                try:
                    # Augment the prompt with retrieved context when an index
                    # exists. Explicit None-check: do not rely on the
                    # truthiness of a FAISS index object.
                    if st.session_state.faiss_index is not None:
                        relevant_chunks = retrieve_chunks(
                            prompt,
                            st.session_state.faiss_index,
                            st.session_state.chunks,
                            k=3,
                        )
                        context = "\n---\n".join(
                            chunk if isinstance(chunk, str) else chunk.page_content
                            for chunk in relevant_chunks
                        )
                        prompt_with_context = f"Use the following context to answer the question:\n{context}\n\nQuestion: {prompt}"
                    else:
                        prompt_with_context = prompt  # no files indexed yet

                    response = model.generate_content(prompt_with_context)
                    # Single append path instead of two duplicated branches.
                    reply = response.text or "Sorry, I couldn't get a response."
                    st.markdown(reply)
                    st.session_state.messages.append({"role": "assistant", "content": reply})
                except Exception as e:
                    st.error(f"Error: {e}")
                    st.session_state.messages.append({"role": "assistant", "content": f"Error: {e}"})

if __name__ == "__main__":
    main()