frozen8569 committed
Commit 38a08f3 · verified · 1 Parent(s): 46eee7f

Update app.py

Files changed (1):
  1. app.py +164 -123
app.py CHANGED
@@ -1,7 +1,7 @@
-import gradio as gr
+import streamlit as st
 import torch
 import fitz  # PyMuPDF
-from transformers import AutoTokenizer, AutoModelForCausalLM, pipeline
+from transformers import AutoTokenizer, pipeline, AutoModelForSeq2SeqLM  # Import for T5 model
 from langchain.text_splitter import RecursiveCharacterTextSplitter
 from langchain_community.vectorstores import FAISS
 from langchain_community.embeddings import HuggingFaceEmbeddings
@@ -9,161 +9,202 @@ from langchain_community.llms.huggingface_pipeline import HuggingFacePipeline
 from langchain.chains import ConversationalRetrievalChain
 from langchain.memory import ConversationBufferMemory
 from langchain.prompts import PromptTemplate
+
+# For Fairness Audit
 import pandas as pd
 from aif360.datasets import StandardDataset
 from aif360.metrics import BinaryLabelDatasetMetric
-import time
 
-# --- Caching (simple global variables for Gradio) ---
-_llm = None
-_qa_chain = None
+# --- Page Configuration ---
+st.set_page_config(
+    page_title="Sahay AI 🇮🇳",
+    page_icon="🤖",
+    layout="wide",
+    initial_sidebar_state="expanded"
+)
 
-# --- Core AI and Data Processing Functions (Unchanged) ---
+# --- Caching for Performance ---
+@st.cache_resource
 def load_llm():
-    """Loads the IBM Granite LLM, forcing it to use the GPU."""
-    global _llm
-    if _llm is None:
-        print("Loading LLM for the first time...")
-        llm_model_name = "ibm-granite/granite-3.3-8b-instruct"
-
-        if not torch.cuda.is_available():
-            raise RuntimeError("ZeroGPU requires a GPU. Please ensure hardware is set correctly.")
-
-        model = AutoModelForCausalLM.from_pretrained(
-            llm_model_name, torch_dtype=torch.bfloat16, load_in_4bit=True
-        )
-        tokenizer = AutoTokenizer.from_pretrained(llm_model_name)
-
-        pipe = pipeline(
-            "text-generation", model=model, tokenizer=tokenizer,
-            max_new_tokens=512, temperature=0.1, device=0
-        )
-        _llm = HuggingFacePipeline(pipeline=pipe)
-    return _llm
-
-def load_and_process_pdf(pdf_path="PMKisanSamanNidhi.PDF"):
-    """Loads and processes the PDF into a FAISS vector store."""
-    print("Loading and processing PDF...")
-    doc = fitz.open(pdf_path)
-    text = "".join(page.get_text() for page in doc)
+    """
+    Loads a smaller, CPU-friendly model (FLAN-T5-Base) for better performance
+    on the free Hugging Face Spaces hardware.
+    """
+    # Using a smaller, CPU-compatible model to ensure the app is fast and responsive.
+    llm_model_name = "google/flan-t5-base"
+
+    tokenizer = AutoTokenizer.from_pretrained(llm_model_name)
+    # Use AutoModelForSeq2SeqLM for T5 models
+    model = AutoModelForSeq2SeqLM.from_pretrained(llm_model_name)
+
+    pipe = pipeline(
+        "text2text-generation",  # T5 models use this pipeline type
+        model=model,
+        tokenizer=tokenizer,
+        max_length=512
+    )
+    return HuggingFacePipeline(pipeline=pipe)
+
+@st.cache_resource
+def load_and_process_pdf(pdf_path):
+    """Loads, chunks, and embeds the PDF into a FAISS vector store using IBM's model."""
+    try:
+        doc = fitz.open(pdf_path)
+        text = "".join(page.get_text() for page in doc)
+        if not text:
+            st.error("Could not extract text from the PDF.")
+            return None
+    except Exception as e:
+        st.error(f"Error reading PDF file: {e}. Make sure 'PMKisanSamanNidhi.PDF' is uploaded to the Space.")
+        return None
+
     text_splitter = RecursiveCharacterTextSplitter(chunk_size=1000, chunk_overlap=150)
     docs = text_splitter.create_documents([text])
-    embedding_model = HuggingFaceEmbeddings(model_name="ibm-granite/granite-embedding-278m-multilingual")
+
+    # Still using the powerful IBM embedding model for multilingual understanding
+    model_name = "ibm-granite/granite-embedding-278m-multilingual"
+    embedding_model = HuggingFaceEmbeddings(model_name=model_name)
+
     vector_db = FAISS.from_documents(docs, embedding_model)
     return vector_db
 
-def create_conversational_chain(llm, vector_db):
+# --- Conversational Chain ---
+def create_conversational_chain(_llm, _vector_db):
     """Creates the LangChain conversational retrieval chain."""
-    prompt_template = """You are a polite and professional AI assistant for the PM-KISAN scheme... (rest of prompt)"""
+    prompt_template = """You are a polite and professional AI assistant for the PM-KISAN scheme.
+    Use the following context to answer the user's question precisely.
+    If the question is not related to the provided context, you must state: "I can only answer questions related to the PM-KISAN scheme."
+    Do not make up information.
+
+    Context: {context}
+    Question: {question}
+
+    Helpful Answer:"""
+
     QA_PROMPT = PromptTemplate(template=prompt_template, input_variables=["context", "question"])
     memory = ConversationBufferMemory(memory_key="chat_history", return_messages=True, output_key='answer')
+
     chain = ConversationalRetrievalChain.from_llm(
-        llm=llm, retriever=vector_db.as_retriever(), memory=memory,
-        return_source_documents=True, combine_docs_chain_kwargs={"prompt": QA_PROMPT}
+        llm=_llm,
+        retriever=_vector_db.as_retriever(search_kwargs={'k': 3}),
+        memory=memory,
+        return_source_documents=True,
+        combine_docs_chain_kwargs={"prompt": QA_PROMPT}
     )
     return chain
 
-def get_qa_chain():
-    """Initializes and returns the QA chain."""
-    global _qa_chain
-    if _qa_chain is None:
-        llm = load_llm()
-        vector_db = load_and_process_pdf()
-        _qa_chain = create_conversational_chain(llm, vector_db)
-    return _qa_chain
-
+# --- IBM AIF360 Fairness Audit ---
 def run_fairness_audit():
-    """Performs the fairness audit and returns a formatted string."""
-    df_display = pd.DataFrame({
+    """Performs and displays a simulated fairness audit."""
+    st.subheader("🤖 IBM AIF360 - Fairness Audit")
+    st.info("""
+    This is a simulation to demonstrate how we can check for bias in our information retriever.
+    A fair system should provide equally good information to all demographic groups.
+    """)
+    test_data = {
         'query': ["loan for my farm", "help for my crops", "scheme for women", "grant for female farmer"],
         'gender_text': ['male', 'male', 'female', 'female'],
         'expected_doc': ['doc1', 'doc1', 'doc2', 'doc2']
-    })
+    }
+    df_display = pd.DataFrame(test_data)
+
     def simulate_retriever(query):
         return "doc2" if "women" in query or "female" in query else "doc1"
     df_display['retrieved_doc'] = df_display['query'].apply(simulate_retriever)
     df_display['favorable_outcome'] = (df_display['retrieved_doc'] == df_display['expected_doc']).astype(int)
-
+
     df_for_aif = pd.DataFrame()
     df_for_aif['gender'] = df_display['gender_text'].map({'male': 1, 'female': 0})
     df_for_aif['favorable_outcome'] = df_display['favorable_outcome']
 
-    aif_dataset = StandardDataset(df_for_aif, label_name='favorable_outcome', favorable_classes=[1],
-                                  protected_attribute_names=['gender'], privileged_classes=[[1]])
+    aif_dataset = StandardDataset(df_for_aif,
+                                  label_name='favorable_outcome',
+                                  favorable_classes=[1],
+                                  protected_attribute_names=['gender'],
+                                  privileged_classes=[[1]])
+
     metric = BinaryLabelDatasetMetric(aif_dataset, unprivileged_groups=[{'gender': 0}], privileged_groups=[{'gender': 1}])
     spd = metric.statistical_parity_difference()
+
+    st.markdown("---")
+    col1, col2 = st.columns(2)
+    with col1:
+        st.metric(label="**Metric: Statistical Parity Difference (SPD)**", value=f"{spd:.4f}")
+    with col2:
+        st.success("An SPD of **0.0** indicates perfect fairness in this simulation.")
 
-    report = f"""
-    ### 🤖 IBM AIF360 - Fairness Audit Results
-    **Metric: Statistical Parity Difference (SPD):** {spd:.4f}
-    **Interpretation:** An SPD of 0.0 indicates perfect fairness in this simulation.
-    ---
-    **Raw Audit Data:**
-    ```
-    {df_display.to_string()}
-    ```
-    """
-    return report
-
-# --- Gradio UI ---
-def chat_response(message, history):
-    """Handles the user's message and returns the bot's response."""
-    qa_chain = get_qa_chain()
-    result = qa_chain.invoke({"question": message})
-    response = result["answer"]
-
-    # Add sources to the response
-    source_docs = result.get("source_documents", [])
-    if source_docs:
-        response += "\n\n--- \n*Sources used to generate this answer:*"
-        for i, doc in enumerate(source_docs):
-            cleaned_content = ' '.join(doc.page_content.split())
-            response += f"\n\n> **Source {i+1}:** \"{cleaned_content[:150]}...\""
-
-    # Yield response for streaming effect
-    for i in range(len(response)):
-        time.sleep(0.005)
-        yield response[:i+1]
-
-# Initialize the AI model on startup
-print("Initializing AI Chain...")
-get_qa_chain()
-print("AI Chain Ready.")
-
-with gr.Blocks(theme=gr.themes.Soft(), title="Sahay AI") as demo:
-    gr.Markdown("# 🇮🇳 Chat with Sahay AI 💬")
-    gr.Markdown("Your trusted guide to the PM-KISAN scheme, powered by IBM Granite.")
-
-    with gr.Row():
-        with gr.Column(scale=3):
-            chatbot = gr.Chatbot(
-                value=[[None, "Welcome! Ask me a question about the PM-KISAN scheme."]],
-                label="Conversation",
-                bubble_full_width=False,
-                avatar_images=(None, "https://upload.wikimedia.org/wikipedia/commons/5/51/IBM_logo.svg")
-            )
-            msg = gr.Textbox(label="Your Question", placeholder="e.g., Who is eligible for this scheme?")
-            submit_btn = gr.Button("Send", variant="primary")
-
-        with gr.Column(scale=1):
-            gr.Markdown("### Actions & Connect")
-            audit_button = gr.Button("Run Fairness Audit")
-            audit_report = gr.Markdown(visible=False)
-
-            whatsapp_link = "https://wa.me/15551234567?text=Hello%20Sahay%20AI!"
-            gr.Markdown(f"📱 [Chat on WhatsApp]({whatsapp_link})")
-            gr.Markdown(" [View Project on GitHub](https://github.com)")
-
-    # Event handlers
-    msg.submit(chat_response, [msg, chatbot], chatbot)
-    submit_btn.click(chat_response, [msg, chatbot], chatbot)
-    msg.submit(lambda: "", None, msg)  # Clear textbox on submit
-    submit_btn.click(lambda: "", None, msg)  # Clear textbox on submit
-
-    def show_audit():
-        report = run_fairness_audit()
-        return gr.update(value=report, visible=True)
-    audit_button.click(show_audit, outputs=audit_report)
-
-if __name__ == "__main__":
-    demo.launch()
+    with st.expander("Show Raw Audit Data"):
+        st.dataframe(df_display)
+
+# --- Main Application UI ---
+if __name__ == "__main__":
+
+    with st.sidebar:
+        st.image("https://upload.wikimedia.org/wikipedia/commons/5/51/IBM_logo.svg", width=100)
+        st.title("🇮🇳 Sahay AI")
+        st.markdown("### About")
+        st.markdown("An AI assistant for the **PM-KISAN** scheme, built with IBM's multilingual embedding model.")
+        st.markdown("---")
+
+        st.markdown("### Actions")
+        if st.button("Run Fairness Audit", use_container_width=True):
+            st.session_state.run_audit = True
+        st.markdown("---")
+
+        st.markdown("### Connect")
+        st.markdown("📱 [Try the WhatsApp Bot](https://wa.me/15551234567?text=Hello%20Sahay%20AI!)")  # Replace with your number
+        st.markdown("⭐ [View Project on GitHub](https://github.com)")
+        st.markdown("---")
+
+    st.header("Chat with Sahay AI 💬")
+    st.markdown("Your trusted guide to the PM-KISAN scheme.")
+
+    if st.session_state.get('run_audit', False):
+        run_fairness_audit()
+        st.session_state.run_audit = False
+
+    if "messages" not in st.session_state:
+        st.session_state.messages = []
+        st.session_state.messages.append({
+            "role": "assistant",
+            "content": "Welcome! How can I help you understand the PM-KISAN scheme today? You can ask me questions like:\n- What is this scheme about?\n- Who is eligible?\n- *इस योजना के लिए कौन पात्र है?*"
+        })
+
+    if "qa_chain" not in st.session_state:
+        with st.spinner("🚀 Initializing Sahay AI... This may take a moment."):
+            llm = load_llm()
+            vector_db = load_and_process_pdf("PMKisanSamanNidhi.PDF")
+            if vector_db:
+                st.session_state.qa_chain = create_conversational_chain(llm, vector_db)
+            else:
+                st.error("Application could not start. Please check the PDF file is uploaded correctly.")
+                st.stop()
+
+    for message in st.session_state.messages:
+        with st.chat_message(message["role"]):
+            st.markdown(message["content"])
+
+    if prompt := st.chat_input("Ask a question about the PM-KISAN scheme..."):
+        st.session_state.messages.append({"role": "user", "content": prompt})
+        with st.chat_message("user"):
+            st.markdown(prompt)
+
+        with st.chat_message("assistant"):
+            with st.spinner("🧠 Thinking..."):
+                if "qa_chain" in st.session_state:
+                    result = st.session_state.qa_chain.invoke({"question": prompt})
+                    response = result["answer"]
+                    source_docs = result.get("source_documents", [])
+
+                    if source_docs:
+                        response += "\n\n--- \n*Sources used to generate this answer:*"
+                        for i, doc in enumerate(source_docs):
+                            cleaned_content = ' '.join(doc.page_content.split())
+                            response += f"\n\n> **Source {i+1}:** \"{cleaned_content[:150]}...\""
+
+                    st.markdown(response)
+                else:
+                    response = "Sorry, the application is not properly initialized."
+                    st.error(response)
+
+        st.session_state.messages.append({"role": "assistant", "content": response})
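
A note on the audit logic, with a small sketch for reference (not part of the commit). The statistical parity difference reported by `run_fairness_audit()` is the favorable-outcome rate of the unprivileged group minus that of the privileged group: SPD = P(favorable | unprivileged) − P(favorable | privileged). On the four-row toy dataset above, `simulate_retriever` returns the expected document for every query, so both groups have a rate of 1.0 and the SPD is 0.0. A minimal sketch reproducing that number with plain pandas, without AIF360:

```python
import pandas as pd

# Toy audit data mirroring run_fairness_audit() above:
# gender 1 = male (privileged), 0 = female (unprivileged);
# favorable_outcome is all 1s because simulate_retriever answers every query correctly.
df = pd.DataFrame({
    'gender': [1, 1, 0, 0],
    'favorable_outcome': [1, 1, 1, 1],
})

# SPD = P(favorable | unprivileged) - P(favorable | privileged)
p_unpriv = df.loc[df['gender'] == 0, 'favorable_outcome'].mean()
p_priv = df.loc[df['gender'] == 1, 'favorable_outcome'].mean()
print(f"SPD: {p_unpriv - p_priv:.4f}")  # SPD: 0.0000
```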
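The chat handler relies on the `ConversationalRetrievalChain` contract: `invoke` takes a dict with a `question` key, and because the chain is built with `return_source_documents=True`, the result carries both `answer` and `source_documents`. A hypothetical smoke test exercising the same pipeline outside the Streamlit UI (assumptions: the functions are importable from `app.py`, `PMKisanSamanNidhi.PDF` is in the working directory, and the `st.*` calls only emit "missing ScriptRunContext" warnings when no Streamlit session is running):

```python
# Hypothetical smoke test; not part of the commit.
from app import load_llm, load_and_process_pdf, create_conversational_chain

llm = load_llm()
vector_db = load_and_process_pdf("PMKisanSamanNidhi.PDF")
chain = create_conversational_chain(llm, vector_db)

# invoke() takes {"question": ...}; with return_source_documents=True the
# result dict also carries the retrieved chunks used to ground the answer.
result = chain.invoke({"question": "Who is eligible for this scheme?"})
print(result["answer"])
for i, doc in enumerate(result["source_documents"]):
    print(f"Source {i + 1}:", " ".join(doc.page_content.split())[:100], "...")
```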
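A side note on the parameter names in `create_conversational_chain(_llm, _vector_db)`: the leading underscores follow Streamlit's caching convention, in which underscore-prefixed parameters are excluded from cache-key hashing. The function is not cached in this commit, but the underscores would matter if it were, since neither the LLM wrapper nor the FAISS store is hashable. A sketch of that pattern (the `build_chain` wrapper is hypothetical, not in the commit):

```python
import streamlit as st

@st.cache_resource
def build_chain(_llm, _vector_db):
    # Streamlit skips hashing underscore-prefixed arguments, so these
    # unhashable objects can pass through a cached function unchanged.
    return create_conversational_chain(_llm, _vector_db)
```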