Update app.py

app.py CHANGED
@@ -38,18 +38,14 @@ except ImportError: PIL_TESSERACT_AVAILABLE = False; print("WARNING: Pillow or P
 try: import whisper; WHISPER_AVAILABLE = True
 except ImportError: WHISPER_AVAILABLE = False; print("WARNING: OpenAI Whisper not found, Audio Transcription tool will be disabled.")
 
-# --- google-genai SDK ---
-from google import genai as google_genai_sdk
-from google.
-# For FileState enum later
-from google.ai import generativelanguage as glm
-# --- End google-genai SDK ---
-
+# Google GenAI (Used by LangChain integration AND direct client)
+from google.genai.types import HarmCategory, HarmBlockThreshold # CORRECTED IMPORT
+from google.ai import generativelanguage as glm # For FileState enum
 
 # LangChain
 from langchain_core.messages import HumanMessage, AIMessage, SystemMessage, ToolMessage
 from langchain.prompts import PromptTemplate
-from langchain.tools import BaseTool, tool as lc_tool_decorator
+from langchain.tools import BaseTool, tool as lc_tool_decorator # Use langchain.tools.tool
 from langchain_google_genai import ChatGoogleGenerativeAI
 from langchain.agents import AgentExecutor, create_react_agent
 from langchain_community.tools import DuckDuckGoSearchRun
@@ -75,7 +71,9 @@ TOOLS: List[BaseTool] = []
 LLM_INSTANCE: Optional[ChatGoogleGenerativeAI] = None
 LANGGRAPH_MEMORY_SAVER: Optional[Any] = None
 
-
+# google-genai Client SDK
+from google import genai as google_genai_sdk
+google_genai_client: Optional[google_genai_sdk.Client] = None
 
 try:
     from langgraph.graph import StateGraph, END
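The added block above declares the google-genai client as a module-level `Optional` and constructs it only once the API key is known (the initialization hunk appears further down). A minimal sketch of that deferred-construction pattern, assuming `google-genai` is installed and `GOOGLE_API_KEY` is exported; the `get_client` helper is illustrative, not part of the commit:

```python
import os
from typing import Optional

from google import genai as google_genai_sdk

# Module-level handle; stays None until credentials are available.
google_genai_client: Optional[google_genai_sdk.Client] = None

def get_client() -> google_genai_sdk.Client:
    """Illustrative helper: build the client on first use, so merely
    importing the module never requires credentials."""
    global google_genai_client
    if google_genai_client is None:
        api_key = os.environ["GOOGLE_API_KEY"]  # fail fast if unset
        google_genai_client = google_genai_sdk.Client(api_key=api_key)
    return google_genai_client
```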
@@ -105,7 +103,7 @@ try:
         LG_StateGraph, LG_END, LG_ToolInvocation, add_messages, MemorySaver_Class = (None,) * 5
         print(f"WARNING: No suitable LangGraph tool executor (ToolNode/ToolExecutor) found. LangGraph agent will be disabled.")
 
-except ImportError as e:
+except ImportError as e:
     LANGGRAPH_FLAVOR_AVAILABLE = False
     LG_StateGraph, LG_ToolExecutor_Class, LG_END, LG_ToolInvocation, add_messages, MemorySaver_Class = (None,) * 6
     print(f"WARNING: Core LangGraph components (StateGraph, END) not found or import error: {e}. LangGraph agent will be disabled.")
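This hunk only touches the `except ImportError` arm, but the surrounding block is the version-tolerance logic: newer LangGraph releases ship `ToolNode`, older ones `ToolExecutor`, and the agent is disabled if neither import succeeds. A rough sketch of that fallback pattern; the exact probe order in app.py is not fully visible in this diff, so treat the ordering here as an assumption:

```python
# Probe for a usable LangGraph tool-executor class, newest API first.
LG_ToolExecutor_Class = None
try:
    from langgraph.graph import StateGraph, END  # core API; must exist
    try:
        from langgraph.prebuilt import ToolNode as LG_ToolExecutor_Class
    except ImportError:
        try:
            from langgraph.prebuilt import ToolExecutor as LG_ToolExecutor_Class
        except ImportError:
            LG_ToolExecutor_Class = None  # no executor found: agent disabled
    LANGGRAPH_FLAVOR_AVAILABLE = LG_ToolExecutor_Class is not None
except ImportError as e:
    LANGGRAPH_FLAVOR_AVAILABLE = False
    print(f"WARNING: Core LangGraph components not found: {e}")
```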
@@ -134,7 +132,7 @@ logger = logging.getLogger(__name__)
 # --- Initialize google-genai Client SDK ---
 if GOOGLE_API_KEY:
     try:
-        google_genai_client = google_genai_sdk.Client(api_key=GOOGLE_API_KEY)
+        google_genai_client = google_genai_sdk.Client(api_key=GOOGLE_API_KEY)
         logger.info("google-genai SDK Client initialized successfully.")
     except Exception as e:
         logger.error(f"Failed to initialize google-genai SDK Client: {e}")
@@ -263,10 +261,10 @@ def _download_file(file_identifier: str, task_id_for_file: Optional[str] = None)
         logger.error(f"Download error for {file_url_to_try}: {e}", exc_info=True); return f"Error: {str(e)[:100]}"
 
 # --- Tool Function Definitions ---
-
-@lc_tool_decorator
+# Corrected: Removed 'description' from @lc_tool_decorator, use docstring
+@lc_tool_decorator
 def read_pdf_tool(action_input_json_str: str) -> str:
-
+    """Reads text content from a PDF file. Input: JSON '{\"file_identifier\": \"FILENAME_OR_URL\", \"task_id\": \"TASK_ID_IF_GAIA_FILENAME_ONLY\"}'. Returns extracted text."""
     if not PYPDF2_AVAILABLE: return "Error: PyPDF2 not installed."
     try: data = json.loads(action_input_json_str); file_id, task_id = data.get("file_identifier"), data.get("task_id")
     except Exception as e: return f"Error parsing JSON for read_pdf_tool: {e}. Input: {action_input_json_str}"
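The change here is the commit's main refactor of the tool definitions: the `description=` keyword is dropped and the description moves into the function docstring, which LangChain's `tool` decorator picks up automatically. A toy example of that behavior; `echo_tool` is hypothetical:

```python
from langchain.tools import tool as lc_tool_decorator

@lc_tool_decorator
def echo_tool(payload: str) -> str:
    """Echoes the payload back unchanged. Input: any string."""
    return payload

# The decorator derives both fields from the function itself.
print(echo_tool.name)         # -> echo_tool
print(echo_tool.description)  # -> the docstring above
```

Note that a bare `@tool` raises at import time if the function has no docstring, which is why every rewritten tool in this commit gains one.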
@@ -286,10 +284,9 @@ def read_pdf_tool(action_input_json_str: str) -> str:
         return text_content[:40000]
     except Exception as e: return f"Error reading PDF '{path}': {e}"
 
-
-@lc_tool_decorator(description=OCR_IMAGE_TOOL_DESC)
+@lc_tool_decorator
 def ocr_image_tool(action_input_json_str: str) -> str:
-
+    """Extracts text from an image using OCR. Input: JSON '{\"file_identifier\": \"FILENAME_OR_URL\", \"task_id\": \"TASK_ID_IF_GAIA_FILENAME_ONLY\"}'. Returns extracted text."""
     if not PIL_TESSERACT_AVAILABLE: return "Error: Pillow/Pytesseract not installed."
     try: data = json.loads(action_input_json_str); file_id, task_id = data.get("file_identifier"), data.get("task_id")
     except Exception as e: return f"Error parsing JSON for ocr_image_tool: {e}. Input: {action_input_json_str}"
@@ -299,10 +296,9 @@ def ocr_image_tool(action_input_json_str: str) -> str:
     try: return pytesseract.image_to_string(Image.open(path))[:40000]
     except Exception as e: return f"Error OCR'ing '{path}': {e}"
 
-
-@lc_tool_decorator(description=TRANSCRIBE_AUDIO_TOOL_DESC)
+@lc_tool_decorator
 def transcribe_audio_tool(action_input_json_str: str) -> str:
-
+    """Transcribes speech from an audio file (or YouTube URL) to text. Input: JSON '{\"file_identifier\": \"FILENAME_OR_URL_OR_YOUTUBE_URL\", \"task_id\": \"TASK_ID_IF_GAIA_FILENAME_ONLY\"}'. Returns transcript."""
     global WHISPER_MODEL
     if not WHISPER_AVAILABLE: return "Error: Whisper not installed."
     try: data = json.loads(action_input_json_str); file_id, task_id = data.get("file_identifier"), data.get("task_id")
@@ -316,48 +312,37 @@ def transcribe_audio_tool(action_input_json_str: str) -> str:
     try: result = WHISPER_MODEL.transcribe(path, fp16=False); return result["text"][:40000] # type: ignore
     except Exception as e: logger.error(f"Whisper error on '{path}': {e}", exc_info=True); return f"Error transcribing '{path}': {e}"
 
-DIRECT_MULTIMODAL_GEMINI_TOOL_DESC = (
-    "Processes an image file (URL or local path) along with a text prompt using a Gemini multimodal model (gemini-2.0-flash-exp) "
-    "for tasks like image description, Q&A about the image, or text generation based on the image. "
-    "Input: JSON '{\"file_identifier\": \"IMAGE_FILENAME_OR_URL\", \"text_prompt\": \"Your question or instruction.\", \"task_id\": \"TASK_ID\" (optional)}'. "
-    "Returns the model's text response."
-)
-@lc_tool_decorator(description=DIRECT_MULTIMODAL_GEMINI_TOOL_DESC)
+@lc_tool_decorator
 def direct_multimodal_gemini_tool(action_input_json_str: str) -> str:
-
+    """Processes an image file (URL or local path) along with a text prompt using a Gemini multimodal model (gemini-2.0-flash-exp) for tasks like image description, Q&A about the image, or text generation based on the image. Input: JSON '{\"file_identifier\": \"IMAGE_FILENAME_OR_URL\", \"text_prompt\": \"Your question or instruction related to the image.\", \"task_id\": \"TASK_ID_IF_GAIA_FILENAME_ONLY\" (optional)}'. Returns the model's text response."""
     global google_genai_client
     if not google_genai_client: return "Error: google-genai SDK client not initialized."
-    if not PIL_TESSERACT_AVAILABLE : return "Error: Pillow (PIL) library not available."
+    if not PIL_TESSERACT_AVAILABLE : return "Error: Pillow (PIL) library not available for image processing."
     try:
         data = json.loads(action_input_json_str)
         file_identifier = data.get("file_identifier")
         text_prompt = data.get("text_prompt", "Describe this image.")
         task_id = data.get("task_id")
         if not file_identifier: return "Error: 'file_identifier' for image missing."
-        logger.info(f"Direct Multimodal Tool:
+        logger.info(f"Direct Multimodal Tool: Processing image '{file_identifier}' with prompt '{text_prompt}'")
         local_image_path = _download_file(file_identifier, task_id)
-        if local_image_path.startswith("Error:"): return f"Error downloading for Direct
+        if local_image_path.startswith("Error:"): return f"Error downloading image for Direct Multimodal Tool: {local_image_path}"
         try:
             pil_image = Image.open(local_image_path)
-        except Exception as e_img_open: return f"Error opening image {local_image_path}: {str(e_img_open)}"
-
-        # For the client SDK, model names often don't need "models/" prefix if it's a tuned model or specific ID.
-        # If it's a base model, "models/" is usually required. Let's assume GEMINI_FLASH_MULTIMODAL_MODEL_NAME is a direct ID.
-        # However, to be safe with client.models.generate_content, using "models/" is more standard.
+        except Exception as e_img_open: return f"Error opening image file {local_image_path}: {str(e_img_open)}"
+
         model_id_for_client = f"models/{GEMINI_FLASH_MULTIMODAL_MODEL_NAME}" if not GEMINI_FLASH_MULTIMODAL_MODEL_NAME.startswith("models/") else GEMINI_FLASH_MULTIMODAL_MODEL_NAME
-
         response = google_genai_client.models.generate_content(
-            model=model_id_for_client,
-            contents=[pil_image, text_prompt]
+            model=model_id_for_client, contents=[pil_image, text_prompt]
         )
         logger.info(f"Direct Multimodal Tool: Response received from {model_id_for_client}.")
         return response.text[:40000]
-    except json.JSONDecodeError as e_json_mm: return f"Error parsing JSON for Direct
+    except json.JSONDecodeError as e_json_mm: return f"Error parsing JSON input for Direct Multimodal Tool: {str(e_json_mm)}. Input: {action_input_json_str}"
     except Exception as e_tool_mm:
         logger.error(f"Error in direct_multimodal_gemini_tool: {e_tool_mm}", exc_info=True)
         return f"Error executing Direct Multimodal Tool: {str(e_tool_mm)}"
 
-# --- Agent Prompts
+# --- Agent Prompts ---
 LANGGRAPH_PROMPT_TEMPLATE_STR = """You are a highly intelligent agent for the GAIA benchmark.
 Your goal is to provide an EXACT MATCH final answer. No conversational text, explanations, or markdown unless explicitly part of the answer.
 TOOLS:
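The rewritten tool body feeds a PIL image plus a text prompt straight into `client.models.generate_content`, after normalizing the `models/` prefix onto the model id. A self-contained sketch of that call path, assuming the google-genai client accepts a `PIL.Image` in `contents`; the API key and image path are placeholders:

```python
from google import genai
from PIL import Image

client = genai.Client(api_key="YOUR_API_KEY")  # placeholder credential
pil_image = Image.open("example.png")          # placeholder local image

response = client.models.generate_content(
    model="models/gemini-2.0-flash-exp",       # model named in the tool docstring
    contents=[pil_image, "Describe this image."],
)
print(response.text[:500])  # the tool itself truncates to 40000 chars
```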
@@ -396,7 +381,8 @@ def initialize_agent_and_tools(force_reinit=False):
     logger.info("Initializing agent and tools...")
     if not GOOGLE_API_KEY: raise ValueError("GOOGLE_API_KEY not set for LangChain LLM.")
 
-    #
+    # Corrected safety_settings format for ChatGoogleGenerativeAI
+    # Using INTEGER VALUES for HarmCategory keys and HarmBlockThreshold enum members for values.
     llm_safety_settings_corrected_final = {
         HarmCategory.HARM_CATEGORY_HARASSMENT.value: HarmBlockThreshold.BLOCK_NONE.value,
         HarmCategory.HARM_CATEGORY_HATE_SPEECH.value: HarmBlockThreshold.BLOCK_NONE.value,
@@ -409,7 +395,7 @@
         model=GEMINI_MODEL_NAME,
         google_api_key=GOOGLE_API_KEY,
         temperature=0.0,
-        safety_settings=llm_safety_settings_corrected_final,
+        safety_settings=llm_safety_settings_corrected_final,
         timeout=120,
         convert_system_message_to_human=True
     )
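Taken together, the two safety-settings hunks build the mapping from the enums' `.value` payloads and hand it to `ChatGoogleGenerativeAI`. A sketch mirroring the diff; the accepted key/value format (enum members vs. raw values) has shifted across `langchain-google-genai` releases, so this mirrors the commit rather than prescribing the one correct form:

```python
from google.genai.types import HarmCategory, HarmBlockThreshold
from langchain_google_genai import ChatGoogleGenerativeAI

# Same shape as llm_safety_settings_corrected_final in the diff.
safety_settings = {
    HarmCategory.HARM_CATEGORY_HARASSMENT.value: HarmBlockThreshold.BLOCK_NONE.value,
    HarmCategory.HARM_CATEGORY_HATE_SPEECH.value: HarmBlockThreshold.BLOCK_NONE.value,
}

llm = ChatGoogleGenerativeAI(
    model="gemini-1.5-pro",         # placeholder; app.py uses GEMINI_MODEL_NAME
    google_api_key="YOUR_API_KEY",  # placeholder credential
    temperature=0.0,
    safety_settings=safety_settings,
    timeout=120,
    convert_system_message_to_human=True,
)
```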
@@ -640,7 +626,7 @@ with gr.Blocks(css=".gradio-container {max-width:1280px !important;margin:auto !
     demo.load(update_ui_on_load_fn_within_context, [], [agent_status_display, missing_secrets_display])
 
 if __name__ == "__main__":
-    logger.info(f"Application starting up (v7 -
+    logger.info(f"Application starting up (v7 - Final SafetySettings Fix)...")
     if not PYPDF2_AVAILABLE: logger.warning("PyPDF2 (PDF tool) NOT AVAILABLE.")
     if not PIL_TESSERACT_AVAILABLE: logger.warning("Pillow/Pytesseract (OCR tool) NOT AVAILABLE.")
     if not WHISPER_AVAILABLE: logger.warning("Whisper (Audio tool) NOT AVAILABLE.")
@@ -668,4 +654,4 @@
 
     logger.info(f"Space ID: {os.getenv('SPACE_ID', 'Not Set')}")
     logger.info("Gradio Interface launching...")
-    demo.queue().launch(debug=os.getenv("GRADIO_DEBUG","false").lower()=="true", share=False, max_threads=20)
+    demo.queue().launch(debug=os.getenv("GRADIO_DEBUG","false").lower()=="true", share=False, max_threads=20)