Spaces:
Sleeping
Sleeping
Update app.py
Browse files
app.py
CHANGED
|
@@ -10,8 +10,9 @@ import hashlib
|
|
| 10 |
from urllib.parse import urlparse
|
| 11 |
import mimetypes
|
| 12 |
import subprocess # For yt-dlp
|
|
|
|
| 13 |
|
| 14 |
-
from huggingface_hub import get_space_runtime
|
| 15 |
|
| 16 |
# --- Global Variables for Startup Status ---
|
| 17 |
missing_vars_startup_list_global = []
|
|
@@ -19,16 +20,17 @@ agent_pre_init_status_msg_global = "Agent status will be determined at startup."
|
|
| 19 |
|
| 20 |
# File Processing Libs
|
| 21 |
try: from PyPDF2 import PdfReader; PYPDF2_AVAILABLE = True
|
| 22 |
-
except ImportError: PYPDF2_AVAILABLE = False
|
| 23 |
-
try: from PIL import Image; import pytesseract; PIL_TESSERACT_AVAILABLE = True
|
| 24 |
-
except ImportError: PIL_TESSERACT_AVAILABLE = False
|
| 25 |
try: import whisper; WHISPER_AVAILABLE = True
|
| 26 |
-
except ImportError: WHISPER_AVAILABLE = False
|
| 27 |
|
| 28 |
-
# Google GenAI
|
| 29 |
from google.generativeai.types import HarmCategory, HarmBlockThreshold
|
|
|
|
| 30 |
# LangChain
|
| 31 |
-
from langchain_core.messages import HumanMessage, AIMessage, SystemMessage, ToolMessage
|
| 32 |
from langchain.prompts import PromptTemplate
|
| 33 |
from langchain.tools import BaseTool, tool as lc_tool_decorator
|
| 34 |
from langchain_google_genai import ChatGoogleGenerativeAI
|
|
@@ -46,47 +48,72 @@ if TYPE_CHECKING:
|
|
| 46 |
LANGGRAPH_FLAVOR_AVAILABLE = False
|
| 47 |
LG_StateGraph: Optional[Type[Any]] = None
|
| 48 |
LG_ToolExecutor: Optional[Type[Any]] = None
|
| 49 |
-
LG_END: Optional[Any] = None
|
| 50 |
LG_ToolInvocation: Optional[Type[Any]] = None
|
| 51 |
-
add_messages: Optional[Any] = None
|
| 52 |
MemorySaver_Class: Optional[Type[Any]] = None
|
| 53 |
|
| 54 |
AGENT_INSTANCE: Optional[Union[AgentExecutor, Any]] = None
|
| 55 |
TOOLS: List[BaseTool] = []
|
| 56 |
-
LLM_INSTANCE: Optional[ChatGoogleGenerativeAI] = None
|
| 57 |
LANGGRAPH_MEMORY_SAVER: Optional[Any] = None
|
| 58 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 59 |
try:
|
| 60 |
-
from langgraph.graph import StateGraph, END
|
| 61 |
from langgraph.prebuilt import ToolExecutor, ToolInvocation as LGToolInvocationActual
|
| 62 |
from langgraph.graph.message import add_messages as lg_add_messages
|
| 63 |
-
from langgraph.checkpoint.memory import MemorySaver as LGMemorySaver
|
| 64 |
LANGGRAPH_FLAVOR_AVAILABLE = True
|
| 65 |
LG_StateGraph, LG_ToolExecutor, LG_END, LG_ToolInvocation, add_messages, MemorySaver_Class = \
|
| 66 |
StateGraph, ToolExecutor, END, LGToolInvocationActual, lg_add_messages, LGMemorySaver
|
| 67 |
-
|
|
|
|
| 68 |
LANGGRAPH_FLAVOR_AVAILABLE = False
|
|
|
|
| 69 |
LG_StateGraph, LG_ToolExecutor, LG_END, LG_ToolInvocation, add_messages, MemorySaver_Class = (None,) * 6
|
|
|
|
| 70 |
|
| 71 |
# --- Constants ---
|
| 72 |
DEFAULT_API_URL = "https://agents-course-unit4-scoring.hf.space"
|
| 73 |
-
GEMINI_MODEL_NAME
|
|
|
|
|
|
|
|
|
|
|
|
|
| 74 |
SCORING_API_BASE_URL = os.getenv("SCORING_API_URL", DEFAULT_API_URL)
|
| 75 |
-
MAX_FILE_SIZE_BYTES = 50 * 1024 * 1024
|
| 76 |
-
LOCAL_FILE_STORE_PATH = "./Data"
|
|
|
|
| 77 |
|
| 78 |
# --- Global State ---
|
| 79 |
WHISPER_MODEL: Optional[Any] = None
|
| 80 |
|
| 81 |
# --- Environment Variables & API Keys ---
|
| 82 |
GOOGLE_API_KEY = os.environ.get("GOOGLE_API_KEY")
|
| 83 |
-
HUGGINGFACE_TOKEN = os.environ.get("HF_TOKEN")
|
| 84 |
|
| 85 |
# --- Setup Logging ---
|
| 86 |
logging.basicConfig(level=logging.INFO, format='%(asctime)s - %(levelname)s - %(name)s - %(module)s:%(lineno)d - %(message)s')
|
| 87 |
logger = logging.getLogger(__name__)
|
| 88 |
|
| 89 |
-
# ---
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 90 |
def _strip_exact_match_answer(text: Any) -> str:
|
| 91 |
if not isinstance(text, str): text = str(text)
|
| 92 |
text_lower_check = text.lower()
|
|
@@ -94,7 +121,7 @@ def _strip_exact_match_answer(text: Any) -> str:
|
|
| 94 |
text = text[len("final answer:"):].strip()
|
| 95 |
text = text.strip()
|
| 96 |
if text.startswith("```") and text.endswith("```"):
|
| 97 |
-
if "\n" in text:
|
| 98 |
text_content = text.split("\n", 1)[1] if len(text.split("\n", 1)) > 1 else ""
|
| 99 |
text = text_content.strip()[:-3].strip() if text_content.strip().endswith("```") else text[3:-3].strip()
|
| 100 |
else: text = text[3:-3].strip()
|
|
@@ -113,6 +140,7 @@ def _is_youtube_url(url: str) -> bool:
|
|
| 113 |
return parsed_url.netloc.lower().endswith(("youtube.com", "youtu.be"))
|
| 114 |
|
| 115 |
def _download_file(file_identifier: str, task_id_for_file: Optional[str] = None) -> str:
|
|
|
|
| 116 |
os.makedirs(LOCAL_FILE_STORE_PATH, exist_ok=True)
|
| 117 |
logger.debug(f"Download request: '{file_identifier}', task_id: {task_id_for_file}")
|
| 118 |
original_filename = os.path.basename(urlparse(file_identifier).path) if _is_full_url(file_identifier) else os.path.basename(file_identifier)
|
|
@@ -131,12 +159,12 @@ def _download_file(file_identifier: str, task_id_for_file: Optional[str] = None)
|
|
| 131 |
logger.info(f"Cached YouTube MP3: {target_mp3_path}"); return target_mp3_path
|
| 132 |
temp_output_template = os.path.join(LOCAL_FILE_STORE_PATH, yt_filename_base + "_temp.%(ext)s")
|
| 133 |
try:
|
| 134 |
-
command = ['yt-dlp', '--quiet', '--no-warnings', '-x', '--audio-format', 'mp3',
|
| 135 |
'--audio-quality', '0', '--max-filesize', str(MAX_FILE_SIZE_BYTES),
|
| 136 |
'-o', temp_output_template, file_identifier]
|
| 137 |
logger.info(f"yt-dlp command: {' '.join(command)}")
|
| 138 |
process = subprocess.run(command, capture_output=True, text=True, timeout=180, check=False)
|
| 139 |
-
downloaded_temp_file = next((os.path.join(LOCAL_FILE_STORE_PATH, f) for f in os.listdir(LOCAL_FILE_STORE_PATH)
|
| 140 |
if f.startswith(yt_filename_base + "_temp") and f.endswith(".mp3")), None)
|
| 141 |
if process.returncode == 0 and downloaded_temp_file and os.path.exists(downloaded_temp_file):
|
| 142 |
os.rename(downloaded_temp_file, target_mp3_path)
|
|
@@ -169,28 +197,28 @@ def _download_file(file_identifier: str, task_id_for_file: Optional[str] = None)
|
|
| 169 |
if cd_header:
|
| 170 |
try:
|
| 171 |
decoded_cd_header = cd_header.encode('latin-1', 'replace').decode('utf-8', 'replace')
|
| 172 |
-
_, params = requests.utils.parse_header_links(decoded_cd_header)
|
| 173 |
for key, val in params.items():
|
| 174 |
if key.lower() == 'filename*' and val.lower().startswith("utf-8''"):
|
| 175 |
filename_from_cd = requests.utils.unquote(val[len("utf-8''"):]); break
|
| 176 |
elif key.lower() == 'filename':
|
| 177 |
filename_from_cd = requests.utils.unquote(val)
|
| 178 |
if filename_from_cd.startswith('"') and filename_from_cd.endswith('"'): filename_from_cd = filename_from_cd[1:-1]
|
| 179 |
-
break
|
| 180 |
except Exception as e_cd: logger.warning(f"CD parse error '{cd_header}': {e_cd}")
|
| 181 |
if filename_from_cd:
|
| 182 |
sanitized_cd_filename = "".join(c if c.isalnum() or c in ['.', '_', '-'] else '_' for c in filename_from_cd)
|
| 183 |
effective_save_path = os.path.join(LOCAL_FILE_STORE_PATH, f"{prefix}{sanitized_cd_filename}")
|
| 184 |
logger.info(f"Using CD filename: '{sanitized_cd_filename}'. Path: {effective_save_path}")
|
| 185 |
-
|
| 186 |
name_without_ext, current_ext = os.path.splitext(effective_save_path)
|
| 187 |
-
if not current_ext:
|
| 188 |
content_type_header = r.headers.get('content-type', '')
|
| 189 |
-
content_type_val = content_type_header.split(';').strip() if content_type_header else ''
|
| 190 |
if content_type_val:
|
| 191 |
guessed_ext = mimetypes.guess_extension(content_type_val)
|
| 192 |
if guessed_ext: effective_save_path += guessed_ext; logger.info(f"Added guessed ext: {guessed_ext}")
|
| 193 |
-
|
| 194 |
if effective_save_path != tentative_local_path and os.path.exists(effective_save_path) and os.path.getsize(effective_save_path) > 0:
|
| 195 |
logger.info(f"Cached file (CD name): {effective_save_path}"); return effective_save_path
|
| 196 |
with open(effective_save_path, "wb") as f:
|
|
@@ -202,10 +230,11 @@ def _download_file(file_identifier: str, task_id_for_file: Optional[str] = None)
|
|
| 202 |
except Exception as e:
|
| 203 |
logger.error(f"Download error for {file_url_to_try}: {e}", exc_info=True); return f"Error: {str(e)[:100]}"
|
| 204 |
|
| 205 |
-
# --- Tool Function Definitions ---
|
| 206 |
-
READ_PDF_TOOL_DESC = "Reads PDF. Input: JSON '{\"file_identifier\": \"
|
| 207 |
@lc_tool_decorator(description=READ_PDF_TOOL_DESC)
|
| 208 |
def read_pdf_tool(action_input_json_str: str) -> str:
|
|
|
|
| 209 |
if not PYPDF2_AVAILABLE: return "Error: PyPDF2 not installed."
|
| 210 |
try: data = json.loads(action_input_json_str); file_id, task_id = data.get("file_identifier"), data.get("task_id")
|
| 211 |
except Exception as e: return f"Error parsing JSON for read_pdf_tool: {e}. Input: {action_input_json_str}"
|
|
@@ -213,19 +242,23 @@ def read_pdf_tool(action_input_json_str: str) -> str:
|
|
| 213 |
path = _download_file(file_id, task_id)
|
| 214 |
if path.startswith("Error:"): return path
|
| 215 |
try:
|
| 216 |
-
text = "";
|
| 217 |
-
with open(path, "rb") as f:
|
| 218 |
reader = PdfReader(f)
|
| 219 |
if reader.is_encrypted:
|
| 220 |
try: reader.decrypt('')
|
| 221 |
except: return f"Error: PDF '{path}' encrypted."
|
| 222 |
-
for
|
| 223 |
-
|
|
|
|
|
|
|
| 224 |
except Exception as e: return f"Error reading PDF '{path}': {e}"
|
| 225 |
|
| 226 |
-
|
|
|
|
| 227 |
@lc_tool_decorator(description=OCR_IMAGE_TOOL_DESC)
|
| 228 |
def ocr_image_tool(action_input_json_str: str) -> str:
|
|
|
|
| 229 |
if not PIL_TESSERACT_AVAILABLE: return "Error: Pillow/Pytesseract not installed."
|
| 230 |
try: data = json.loads(action_input_json_str); file_id, task_id = data.get("file_identifier"), data.get("task_id")
|
| 231 |
except Exception as e: return f"Error parsing JSON for ocr_image_tool: {e}. Input: {action_input_json_str}"
|
|
@@ -235,9 +268,10 @@ def ocr_image_tool(action_input_json_str: str) -> str:
|
|
| 235 |
try: return pytesseract.image_to_string(Image.open(path))[:40000]
|
| 236 |
except Exception as e: return f"Error OCR'ing '{path}': {e}"
|
| 237 |
|
| 238 |
-
TRANSCRIBE_AUDIO_TOOL_DESC = "Transcribes audio
|
| 239 |
@lc_tool_decorator(description=TRANSCRIBE_AUDIO_TOOL_DESC)
|
| 240 |
def transcribe_audio_tool(action_input_json_str: str) -> str:
|
|
|
|
| 241 |
global WHISPER_MODEL
|
| 242 |
if not WHISPER_AVAILABLE: return "Error: Whisper not installed."
|
| 243 |
try: data = json.loads(action_input_json_str); file_id, task_id = data.get("file_identifier"), data.get("task_id")
|
|
@@ -248,137 +282,221 @@ def transcribe_audio_tool(action_input_json_str: str) -> str:
|
|
| 248 |
except Exception as e: logger.error(f"Whisper load failed: {e}"); return f"Error: Whisper load: {e}"
|
| 249 |
path = _download_file(file_id, task_id)
|
| 250 |
if path.startswith("Error:"): return path
|
| 251 |
-
try: result = WHISPER_MODEL.transcribe(path, fp16=False); return result["text"][:40000]
|
| 252 |
except Exception as e: logger.error(f"Whisper error on '{path}': {e}", exc_info=True); return f"Error transcribing '{path}': {e}"
|
| 253 |
|
| 254 |
-
#
|
| 255 |
-
|
| 256 |
-
|
| 257 |
-
|
| 258 |
-
|
| 259 |
-
|
| 260 |
-
|
| 261 |
-
|
| 262 |
-
|
| 263 |
-
|
| 264 |
-
|
| 265 |
-
|
| 266 |
-
|
| 267 |
-
|
| 268 |
-
|
| 269 |
-
|
| 270 |
-
|
| 271 |
-
|
| 272 |
-
|
| 273 |
-
|
| 274 |
-
|
| 275 |
-
|
| 276 |
-
|
| 277 |
-
|
| 278 |
-
|
| 279 |
-
|
| 280 |
-
|
| 281 |
-
|
| 282 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 283 |
|
| 284 |
# --- Agent Initialization and Response Logic ---
|
| 285 |
def initialize_agent_and_tools(force_reinit=False):
|
| 286 |
-
global AGENT_INSTANCE, TOOLS, LLM_INSTANCE, LANGGRAPH_FLAVOR_AVAILABLE, LG_StateGraph, LG_ToolExecutor, LG_END, LG_ToolInvocation, add_messages, MemorySaver_Class, LANGGRAPH_MEMORY_SAVER
|
| 287 |
if AGENT_INSTANCE and not force_reinit: logger.info("Agent already initialized."); return
|
| 288 |
logger.info("Initializing agent and tools...")
|
| 289 |
-
if not GOOGLE_API_KEY: raise ValueError("GOOGLE_API_KEY not set.")
|
| 290 |
-
|
| 291 |
-
|
| 292 |
try:
|
| 293 |
-
LLM_INSTANCE = ChatGoogleGenerativeAI(model=GEMINI_MODEL_NAME, google_api_key=GOOGLE_API_KEY, temperature=0.0,
|
| 294 |
safety_settings={HarmCategory.HARM_CATEGORY_HARASSMENT: HarmBlockThreshold.BLOCK_NONE, HarmCategory.HARM_CATEGORY_HATE_SPEECH: HarmBlockThreshold.BLOCK_NONE,
|
| 295 |
HarmCategory.HARM_CATEGORY_SEXUALLY_EXPLICIT: HarmBlockThreshold.BLOCK_NONE, HarmCategory.HARM_CATEGORY_DANGEROUS_CONTENT: HarmBlockThreshold.BLOCK_NONE,},
|
| 296 |
request_timeout=120, convert_system_message_to_human=True )
|
| 297 |
-
logger.info(f"LLM initialized: {GEMINI_MODEL_NAME}")
|
| 298 |
-
except Exception as e: logger.error(f"LLM init failed: {e}", exc_info=True); raise
|
|
|
|
| 299 |
TOOLS = []
|
| 300 |
if PYPDF2_AVAILABLE: TOOLS.append(read_pdf_tool)
|
| 301 |
if PIL_TESSERACT_AVAILABLE: TOOLS.append(ocr_image_tool)
|
| 302 |
if WHISPER_AVAILABLE: TOOLS.append(transcribe_audio_tool)
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 303 |
try: search_tool = DuckDuckGoSearchRun(name="web_search"); search_tool.description = "Web search. Input: query."; TOOLS.append(search_tool)
|
| 304 |
except Exception as e: logger.warning(f"DuckDuckGoSearchRun init failed: {e}")
|
| 305 |
try: python_repl = PythonREPLTool(name="python_repl"); python_repl.description = "Python REPL. print() for output."; TOOLS.append(python_repl)
|
| 306 |
except Exception as e: logger.warning(f"PythonREPLTool init failed: {e}")
|
| 307 |
-
logger.info(f"
|
|
|
|
| 308 |
|
|
|
|
| 309 |
if LANGGRAPH_FLAVOR_AVAILABLE and all([LG_StateGraph, LG_ToolExecutor, LG_END, LLM_INSTANCE, LG_ToolInvocation, add_messages]):
|
| 310 |
-
if not LANGGRAPH_MEMORY_SAVER and MemorySaver_Class: LANGGRAPH_MEMORY_SAVER = MemorySaver_Class()
|
| 311 |
try:
|
| 312 |
logger.info(f"Attempting LangGraph init (Memory: {LANGGRAPH_MEMORY_SAVER is not None})")
|
| 313 |
-
_TypedDict =
|
| 314 |
-
class AgentState(_TypedDict): input: str; messages: Annotated[List[
|
| 315 |
-
|
| 316 |
-
|
| 317 |
-
|
| 318 |
-
|
| 319 |
-
|
| 320 |
-
|
| 321 |
-
|
| 322 |
-
|
| 323 |
-
|
| 324 |
-
|
| 325 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 326 |
continue
|
| 327 |
try:
|
| 328 |
-
logger.info(f"LG Tool Invoking: '{
|
| 329 |
-
|
| 330 |
-
|
| 331 |
-
|
| 332 |
-
|
| 333 |
-
|
| 334 |
-
|
| 335 |
-
|
| 336 |
-
|
| 337 |
-
|
| 338 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 339 |
logger.info(f"LangGraph compiled (Memory: {LANGGRAPH_MEMORY_SAVER is not None}).")
|
| 340 |
-
except Exception as
|
| 341 |
-
|
|
|
|
|
|
|
| 342 |
|
| 343 |
if not AGENT_INSTANCE:
|
| 344 |
-
logger.info("Initializing ReAct agent.")
|
| 345 |
try:
|
| 346 |
-
if not LLM_INSTANCE: raise ValueError("
|
| 347 |
-
|
| 348 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
| 349 |
logger.info("ReAct agent initialized.")
|
| 350 |
-
except Exception as
|
| 351 |
-
|
| 352 |
-
|
|
|
|
|
|
|
| 353 |
|
|
|
|
|
|
|
|
|
|
| 354 |
def get_agent_response(prompt: str, task_id: Optional[str]=None, thread_id: Optional[str]=None) -> str:
|
|
|
|
| 355 |
global AGENT_INSTANCE, LLM_INSTANCE
|
| 356 |
-
|
| 357 |
if not AGENT_INSTANCE or not LLM_INSTANCE:
|
| 358 |
-
logger.warning("
|
| 359 |
-
|
| 360 |
-
|
| 361 |
-
|
|
|
|
|
|
|
|
|
|
| 362 |
try:
|
| 363 |
-
if
|
| 364 |
-
logger.debug(f"Using LangGraph (Memory: {LANGGRAPH_MEMORY_SAVER is not None})")
|
| 365 |
-
|
| 366 |
-
|
| 367 |
-
|
| 368 |
-
|
| 369 |
-
|
| 370 |
-
|
| 371 |
-
|
| 372 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 373 |
|
| 374 |
def construct_prompt_for_agent(q: Dict[str,Any]) -> str:
|
|
|
|
| 375 |
tid,q_str=q.get("task_id","N/A"),q.get("question",""); files=q.get("files",[])
|
| 376 |
files_info = ("\nFiles:\n"+"\n".join([f"- {f} (task_id:{tid})"for f in files])) if files else ""
|
| 377 |
level = f"\nLevel:{q.get('level')}" if q.get('level') else ""
|
| 378 |
return f"Task ID:{tid}{level}{files_info}\n\nQuestion:{q_str}"
|
| 379 |
|
| 380 |
-
|
| 381 |
-
|
| 382 |
global AGENT_INSTANCE
|
| 383 |
space_id = os.getenv("SPACE_ID")
|
| 384 |
username_for_submission = None
|
|
@@ -387,80 +505,79 @@ def run_and_submit_all(profile: Optional[gr.OAuthProfile] = None): # Re-added pr
|
|
| 387 |
username_for_submission = profile.username
|
| 388 |
logger.info(f"Username from OAuth profile: {username_for_submission}")
|
| 389 |
else:
|
| 390 |
-
|
| 391 |
-
logger.warning("OAuth profile not available or username missing. Submission might fail or be attributed to a default/fallback if allowed by API.")
|
| 392 |
-
# As per strict template, we should stop if no profile.
|
| 393 |
return "Hugging Face login required. Please use the login button and try again.", None
|
| 394 |
|
| 395 |
if AGENT_INSTANCE is None:
|
| 396 |
-
try: logger.info("Agent not pre-initialized. Initializing
|
| 397 |
except Exception as e: return f"Agent on-demand initialization failed: {e}", None
|
| 398 |
-
if AGENT_INSTANCE is None: return "Agent is still None after on-demand
|
| 399 |
|
| 400 |
-
|
| 401 |
-
|
| 402 |
-
|
| 403 |
try:
|
| 404 |
-
logger.info(f"Fetching questions from {
|
| 405 |
-
|
| 406 |
-
if not
|
| 407 |
-
logger.info(f"Fetched {len(
|
| 408 |
except Exception as e:logger.error(f"Fetch questions error: {e}",exc_info=True);return f"Fetch questions error:{e}",None
|
| 409 |
|
| 410 |
-
|
| 411 |
-
logger.info(f"Running agent on {len(
|
| 412 |
-
for i,
|
| 413 |
-
|
| 414 |
-
if not
|
| 415 |
-
|
| 416 |
-
logger.info(f"Processing Q {i+1}/{len(
|
| 417 |
try:
|
| 418 |
-
|
| 419 |
-
|
| 420 |
-
|
| 421 |
except Exception as e:
|
| 422 |
-
logger.error(f"Agent error task {
|
| 423 |
-
|
| 424 |
-
|
| 425 |
|
| 426 |
-
if not
|
| 427 |
-
|
| 428 |
-
logger.info(f"Submitting {len(
|
| 429 |
-
|
| 430 |
try:
|
| 431 |
-
|
| 432 |
-
|
| 433 |
-
logger.info(f"Submission OK! {
|
| 434 |
except requests.exceptions.HTTPError as e:
|
| 435 |
-
|
| 436 |
-
except Exception as e:logger.error(f"Submit Fail unexpected:{e}",exc_info=True);return f"Submit Fail:{str(e)[:100]}",pd.DataFrame(
|
|
|
|
| 437 |
|
| 438 |
-
# --- Build Gradio Interface ---
|
| 439 |
with gr.Blocks(css=".gradio-container {max-width:1280px !important;margin:auto !important;}",theme=gr.themes.Soft()) as demo:
|
| 440 |
-
|
|
|
|
| 441 |
gr.Markdown(f"""**Instructions:**
|
| 442 |
1. **Login with Hugging Face** using the button below. Your HF username will be used for submission.
|
| 443 |
2. Click 'Run Evaluation & Submit' to process GAIA questions (typically 20).
|
| 444 |
-
3. **Goal: 30%+ (6/20).** Agent uses Gemini
|
| 445 |
4. Ensure `GOOGLE_API_KEY` and `HUGGINGFACE_TOKEN` are Space secrets.
|
| 446 |
5. Check Space logs for details. LangGraph is attempted (ReAct fallback).""")
|
| 447 |
-
|
| 448 |
-
agent_status_display = gr.Markdown("**Agent Status:** Initializing...")
|
| 449 |
-
missing_secrets_display = gr.Markdown("")
|
| 450 |
|
| 451 |
-
gr.
|
|
|
|
|
|
|
|
|
|
| 452 |
run_button = gr.Button("Run Evaluation & Submit All Answers")
|
| 453 |
status_output = gr.Textbox(label="Run Status / Submission Result", lines=7, interactive=False)
|
| 454 |
-
results_table = gr.DataFrame(label="Q&A Log", headers=["Task ID","Question","Prompt","Raw","Submitted"], wrap=True)
|
| 455 |
-
|
| 456 |
-
# The `profile` argument in `run_and_submit_all` will be populated by Gradio
|
| 457 |
-
# if the user is logged in via the `gr.LoginButton()` flow.
|
| 458 |
run_button.click(fn=run_and_submit_all, outputs=[status_output,results_table], api_name="run_evaluation")
|
| 459 |
|
| 460 |
def update_ui_on_load_fn_within_context():
|
| 461 |
-
|
|
|
|
| 462 |
secrets_msg_md = ""
|
| 463 |
-
if missing_vars_startup_list_global:
|
| 464 |
secrets_msg_md = f"<font color='red'>**⚠️ Secrets Missing:** {', '.join(missing_vars_startup_list_global)}.</font>"
|
| 465 |
env_issues = []
|
| 466 |
try: subprocess.run(['yt-dlp','--version'],check=True,stdout=subprocess.DEVNULL,stderr=subprocess.DEVNULL)
|
|
@@ -468,35 +585,34 @@ with gr.Blocks(css=".gradio-container {max-width:1280px !important;margin:auto !
|
|
| 468 |
try: subprocess.run(['ffmpeg','-version'],check=True,stdout=subprocess.DEVNULL,stderr=subprocess.DEVNULL)
|
| 469 |
except: env_issues.append("ffmpeg"); logger.warning("ffmpeg check failed (UI load).")
|
| 470 |
if env_issues: secrets_msg_md += f"<br/><font color='orange'>**Tool Deps Missing:** {', '.join(env_issues)}.</font>"
|
| 471 |
-
|
| 472 |
-
current_status_md = agent_pre_init_status_msg_global
|
| 473 |
if not LANGGRAPH_FLAVOR_AVAILABLE and "LangGraph" not in current_status_md:
|
| 474 |
current_status_md += " (LangGraph core import failed, ReAct fallback.)"
|
| 475 |
-
|
| 476 |
-
return { agent_status_display: gr.Markdown(value=current_status_md),
|
| 477 |
missing_secrets_display: gr.Markdown(value=secrets_msg_md) }
|
| 478 |
|
| 479 |
demo.load(update_ui_on_load_fn_within_context, [], [agent_status_display, missing_secrets_display])
|
| 480 |
|
| 481 |
if __name__ == "__main__":
|
| 482 |
-
|
|
|
|
| 483 |
if not PYPDF2_AVAILABLE: logger.warning("PyPDF2 (PDF tool) NOT AVAILABLE.")
|
| 484 |
-
if not PIL_TESSERACT_AVAILABLE: logger.warning("Pillow/Pytesseract (OCR tool) NOT AVAILABLE.")
|
| 485 |
if not WHISPER_AVAILABLE: logger.warning("Whisper (Audio tool) NOT AVAILABLE.")
|
| 486 |
if LANGGRAPH_FLAVOR_AVAILABLE: logger.info("Core LangGraph (StateGraph, END) loaded.")
|
| 487 |
-
else: logger.warning("Core LangGraph FAILED import. ReAct fallback. Check requirements
|
| 488 |
|
| 489 |
-
missing_vars_startup_list_global.clear()
|
| 490 |
if not GOOGLE_API_KEY: missing_vars_startup_list_global.append("GOOGLE_API_KEY")
|
| 491 |
if not HUGGINGFACE_TOKEN: missing_vars_startup_list_global.append("HUGGINGFACE_TOKEN (for GAIA API)")
|
| 492 |
-
|
| 493 |
-
|
| 494 |
try:
|
| 495 |
logger.info("Pre-initializing agent...")
|
| 496 |
-
initialize_agent_and_tools()
|
| 497 |
if AGENT_INSTANCE:
|
| 498 |
-
|
| 499 |
-
|
|
|
|
| 500 |
agent_pre_init_status_msg_global = f"Agent Pre-initialized: **LangGraph** (Memory: {LANGGRAPH_MEMORY_SAVER is not None})."
|
| 501 |
else: agent_pre_init_status_msg_global = "Agent pre-init FAILED (AGENT_INSTANCE is None)."
|
| 502 |
logger.info(agent_pre_init_status_msg_global.replace("**",""))
|
|
|
|
| 10 |
from urllib.parse import urlparse
|
| 11 |
import mimetypes
|
| 12 |
import subprocess # For yt-dlp
|
| 13 |
+
import io # For BytesIO with PIL
|
| 14 |
|
| 15 |
+
# Removed: from huggingface_hub import get_space_runtime - not used for username with OAuth
|
| 16 |
|
| 17 |
# --- Global Variables for Startup Status ---
|
| 18 |
missing_vars_startup_list_global = []
|
|
|
|
| 20 |
|
| 21 |
# File Processing Libs
|
| 22 |
try: from PyPDF2 import PdfReader; PYPDF2_AVAILABLE = True
|
| 23 |
+
except ImportError: PYPDF2_AVAILABLE = False; print("WARNING: PyPDF2 not found, PDF tool will be disabled.")
|
| 24 |
+
try: from PIL import Image; import pytesseract; PIL_TESSERACT_AVAILABLE = True # PIL is needed for new tool
|
| 25 |
+
except ImportError: PIL_TESSERACT_AVAILABLE = False; print("WARNING: Pillow or Pytesseract not found, OCR tool will be disabled.")
|
| 26 |
try: import whisper; WHISPER_AVAILABLE = True
|
| 27 |
+
except ImportError: WHISPER_AVAILABLE = False; print("WARNING: OpenAI Whisper not found, Audio Transcription tool will be disabled.")
|
| 28 |
|
| 29 |
+
# Google GenAI (Used by LangChain integration)
|
| 30 |
from google.generativeai.types import HarmCategory, HarmBlockThreshold
|
| 31 |
+
|
| 32 |
# LangChain
|
| 33 |
+
from langchain_core.messages import HumanMessage, AIMessage, SystemMessage, ToolMessage # Removed AnyMessage
|
| 34 |
from langchain.prompts import PromptTemplate
|
| 35 |
from langchain.tools import BaseTool, tool as lc_tool_decorator
|
| 36 |
from langchain_google_genai import ChatGoogleGenerativeAI
|
|
|
|
| 48 |
LANGGRAPH_FLAVOR_AVAILABLE = False
|
| 49 |
LG_StateGraph: Optional[Type[Any]] = None
|
| 50 |
LG_ToolExecutor: Optional[Type[Any]] = None
|
| 51 |
+
LG_END: Optional[Any] = None
|
| 52 |
LG_ToolInvocation: Optional[Type[Any]] = None
|
| 53 |
+
add_messages: Optional[Any] = None
|
| 54 |
MemorySaver_Class: Optional[Type[Any]] = None
|
| 55 |
|
| 56 |
AGENT_INSTANCE: Optional[Union[AgentExecutor, Any]] = None
|
| 57 |
TOOLS: List[BaseTool] = []
|
| 58 |
+
LLM_INSTANCE: Optional[ChatGoogleGenerativeAI] = None # This is the agent's "planner"
|
| 59 |
LANGGRAPH_MEMORY_SAVER: Optional[Any] = None
|
| 60 |
|
| 61 |
+
# --- google-genai Client SDK (for the new direct multimodal tool) ---
|
| 62 |
+
from google import genai as google_genai_sdk
|
| 63 |
+
google_genai_client: Optional[google_genai_sdk.Client] = None # Initialized later
|
| 64 |
+
# --- End google-genai Client SDK section ---
|
| 65 |
+
|
| 66 |
try:
|
| 67 |
+
from langgraph.graph import StateGraph, END
|
| 68 |
from langgraph.prebuilt import ToolExecutor, ToolInvocation as LGToolInvocationActual
|
| 69 |
from langgraph.graph.message import add_messages as lg_add_messages
|
| 70 |
+
from langgraph.checkpoint.memory import MemorySaver as LGMemorySaver
|
| 71 |
LANGGRAPH_FLAVOR_AVAILABLE = True
|
| 72 |
LG_StateGraph, LG_ToolExecutor, LG_END, LG_ToolInvocation, add_messages, MemorySaver_Class = \
|
| 73 |
StateGraph, ToolExecutor, END, LGToolInvocationActual, lg_add_messages, LGMemorySaver
|
| 74 |
+
print("Successfully imported LangGraph components.")
|
| 75 |
+
except ImportError as e:
|
| 76 |
LANGGRAPH_FLAVOR_AVAILABLE = False
|
| 77 |
+
# Assign None to all to prevent NameError if used before assignment
|
| 78 |
LG_StateGraph, LG_ToolExecutor, LG_END, LG_ToolInvocation, add_messages, MemorySaver_Class = (None,) * 6
|
| 79 |
+
print(f"WARNING: LangGraph components not found or import error: {e}. LangGraph agent will be disabled.")
|
| 80 |
|
| 81 |
# --- Constants ---
|
| 82 |
DEFAULT_API_URL = "https://agents-course-unit4-scoring.hf.space"
|
| 83 |
+
# GEMINI_MODEL_NAME is for the agent's planner LLM (LangChain)
|
| 84 |
+
GEMINI_MODEL_NAME = "gemini-2.5-pro-preview-05-06" # Retained from original for planner
|
| 85 |
+
# GEMINI_FLASH_MULTIMODAL_MODEL_NAME is for the new direct multimodal tool (google-genai client SDK)
|
| 86 |
+
GEMINI_FLASH_MULTIMODAL_MODEL_NAME = "gemini-2.0-flash-exp"
|
| 87 |
+
|
| 88 |
SCORING_API_BASE_URL = os.getenv("SCORING_API_URL", DEFAULT_API_URL)
|
| 89 |
+
MAX_FILE_SIZE_BYTES = 50 * 1024 * 1024
|
| 90 |
+
LOCAL_FILE_STORE_PATH = "./Data"
|
| 91 |
+
os.makedirs(LOCAL_FILE_STORE_PATH, exist_ok=True) # Create data directory at startup
|
| 92 |
|
| 93 |
# --- Global State ---
|
| 94 |
WHISPER_MODEL: Optional[Any] = None
|
| 95 |
|
| 96 |
# --- Environment Variables & API Keys ---
|
| 97 |
GOOGLE_API_KEY = os.environ.get("GOOGLE_API_KEY")
|
| 98 |
+
HUGGINGFACE_TOKEN = os.environ.get("HF_TOKEN")
|
| 99 |
|
| 100 |
# --- Setup Logging ---
|
| 101 |
logging.basicConfig(level=logging.INFO, format='%(asctime)s - %(levelname)s - %(name)s - %(module)s:%(lineno)d - %(message)s')
|
| 102 |
logger = logging.getLogger(__name__)
|
| 103 |
|
| 104 |
+
# --- Initialize google-genai Client SDK ---
|
| 105 |
+
if GOOGLE_API_KEY:
|
| 106 |
+
try:
|
| 107 |
+
google_genai_client = google_genai_sdk.Client(api_key=GOOGLE_API_KEY)
|
| 108 |
+
logger.info("google-genai SDK Client initialized successfully.")
|
| 109 |
+
except Exception as e:
|
| 110 |
+
logger.error(f"Failed to initialize google-genai SDK Client: {e}")
|
| 111 |
+
google_genai_client = None # Ensure it's None if init fails
|
| 112 |
+
else:
|
| 113 |
+
logger.warning("GOOGLE_API_KEY not found. google-genai SDK Client (for direct multimodal tool) not initialized.")
|
| 114 |
+
# --- End Initialize google-genai Client SDK ---
|
| 115 |
+
|
| 116 |
+
# --- Helper Functions (Unchanged from your original) ---
|
| 117 |
def _strip_exact_match_answer(text: Any) -> str:
|
| 118 |
if not isinstance(text, str): text = str(text)
|
| 119 |
text_lower_check = text.lower()
|
|
|
|
| 121 |
text = text[len("final answer:"):].strip()
|
| 122 |
text = text.strip()
|
| 123 |
if text.startswith("```") and text.endswith("```"):
|
| 124 |
+
if "\n" in text:
|
| 125 |
text_content = text.split("\n", 1)[1] if len(text.split("\n", 1)) > 1 else ""
|
| 126 |
text = text_content.strip()[:-3].strip() if text_content.strip().endswith("```") else text[3:-3].strip()
|
| 127 |
else: text = text[3:-3].strip()
|
|
|
|
| 140 |
return parsed_url.netloc.lower().endswith(("youtube.com", "youtu.be"))
|
| 141 |
|
| 142 |
def _download_file(file_identifier: str, task_id_for_file: Optional[str] = None) -> str:
|
| 143 |
+
# ... (Your original _download_file function - unchanged)
|
| 144 |
os.makedirs(LOCAL_FILE_STORE_PATH, exist_ok=True)
|
| 145 |
logger.debug(f"Download request: '{file_identifier}', task_id: {task_id_for_file}")
|
| 146 |
original_filename = os.path.basename(urlparse(file_identifier).path) if _is_full_url(file_identifier) else os.path.basename(file_identifier)
|
|
|
|
| 159 |
logger.info(f"Cached YouTube MP3: {target_mp3_path}"); return target_mp3_path
|
| 160 |
temp_output_template = os.path.join(LOCAL_FILE_STORE_PATH, yt_filename_base + "_temp.%(ext)s")
|
| 161 |
try:
|
| 162 |
+
command = ['yt-dlp', '--quiet', '--no-warnings', '-x', '--audio-format', 'mp3',
|
| 163 |
'--audio-quality', '0', '--max-filesize', str(MAX_FILE_SIZE_BYTES),
|
| 164 |
'-o', temp_output_template, file_identifier]
|
| 165 |
logger.info(f"yt-dlp command: {' '.join(command)}")
|
| 166 |
process = subprocess.run(command, capture_output=True, text=True, timeout=180, check=False)
|
| 167 |
+
downloaded_temp_file = next((os.path.join(LOCAL_FILE_STORE_PATH, f) for f in os.listdir(LOCAL_FILE_STORE_PATH)
|
| 168 |
if f.startswith(yt_filename_base + "_temp") and f.endswith(".mp3")), None)
|
| 169 |
if process.returncode == 0 and downloaded_temp_file and os.path.exists(downloaded_temp_file):
|
| 170 |
os.rename(downloaded_temp_file, target_mp3_path)
|
|
|
|
| 197 |
if cd_header:
|
| 198 |
try:
|
| 199 |
decoded_cd_header = cd_header.encode('latin-1', 'replace').decode('utf-8', 'replace')
|
| 200 |
+
_, params = requests.utils.parse_header_links(decoded_cd_header) # type: ignore
|
| 201 |
for key, val in params.items():
|
| 202 |
if key.lower() == 'filename*' and val.lower().startswith("utf-8''"):
|
| 203 |
filename_from_cd = requests.utils.unquote(val[len("utf-8''"):]); break
|
| 204 |
elif key.lower() == 'filename':
|
| 205 |
filename_from_cd = requests.utils.unquote(val)
|
| 206 |
if filename_from_cd.startswith('"') and filename_from_cd.endswith('"'): filename_from_cd = filename_from_cd[1:-1]
|
| 207 |
+
break
|
| 208 |
except Exception as e_cd: logger.warning(f"CD parse error '{cd_header}': {e_cd}")
|
| 209 |
if filename_from_cd:
|
| 210 |
sanitized_cd_filename = "".join(c if c.isalnum() or c in ['.', '_', '-'] else '_' for c in filename_from_cd)
|
| 211 |
effective_save_path = os.path.join(LOCAL_FILE_STORE_PATH, f"{prefix}{sanitized_cd_filename}")
|
| 212 |
logger.info(f"Using CD filename: '{sanitized_cd_filename}'. Path: {effective_save_path}")
|
| 213 |
+
|
| 214 |
name_without_ext, current_ext = os.path.splitext(effective_save_path)
|
| 215 |
+
if not current_ext:
|
| 216 |
content_type_header = r.headers.get('content-type', '')
|
| 217 |
+
content_type_val = content_type_header.split(';')[0].strip() if content_type_header else ''
|
| 218 |
if content_type_val:
|
| 219 |
guessed_ext = mimetypes.guess_extension(content_type_val)
|
| 220 |
if guessed_ext: effective_save_path += guessed_ext; logger.info(f"Added guessed ext: {guessed_ext}")
|
| 221 |
+
|
| 222 |
if effective_save_path != tentative_local_path and os.path.exists(effective_save_path) and os.path.getsize(effective_save_path) > 0:
|
| 223 |
logger.info(f"Cached file (CD name): {effective_save_path}"); return effective_save_path
|
| 224 |
with open(effective_save_path, "wb") as f:
|
|
|
|
| 230 |
except Exception as e:
|
| 231 |
logger.error(f"Download error for {file_url_to_try}: {e}", exc_info=True); return f"Error: {str(e)[:100]}"
|
| 232 |
|
| 233 |
+
# --- Tool Function Definitions (Original tools unchanged) ---
|
| 234 |
+
READ_PDF_TOOL_DESC = "Reads text content from a PDF file. Input: JSON '{\"file_identifier\": \"FILENAME_OR_URL\", \"task_id\": \"TASK_ID_IF_GAIA_FILENAME_ONLY\"}'. Returns extracted text."
|
| 235 |
@lc_tool_decorator(description=READ_PDF_TOOL_DESC)
|
| 236 |
def read_pdf_tool(action_input_json_str: str) -> str:
|
| 237 |
+
# ... (Your original read_pdf_tool logic)
|
| 238 |
if not PYPDF2_AVAILABLE: return "Error: PyPDF2 not installed."
|
| 239 |
try: data = json.loads(action_input_json_str); file_id, task_id = data.get("file_identifier"), data.get("task_id")
|
| 240 |
except Exception as e: return f"Error parsing JSON for read_pdf_tool: {e}. Input: {action_input_json_str}"
|
|
|
|
| 242 |
path = _download_file(file_id, task_id)
|
| 243 |
if path.startswith("Error:"): return path
|
| 244 |
try:
|
| 245 |
+
text = "";
|
| 246 |
+
with open(path, "rb") as f:
|
| 247 |
reader = PdfReader(f)
|
| 248 |
if reader.is_encrypted:
|
| 249 |
try: reader.decrypt('')
|
| 250 |
except: return f"Error: PDF '{path}' encrypted."
|
| 251 |
+
for page_num in range(len(reader.pages)):
|
| 252 |
+
page = reader.pages[page_num]
|
| 253 |
+
text += page.extract_text() + "\n\n"
|
| 254 |
+
return text[:40000]
|
| 255 |
except Exception as e: return f"Error reading PDF '{path}': {e}"
|
| 256 |
|
| 257 |
+
|
| 258 |
+
OCR_IMAGE_TOOL_DESC = "Extracts text from an image using OCR. Input: JSON '{\"file_identifier\": \"FILENAME_OR_URL\", \"task_id\": \"TASK_ID_IF_GAIA_FILENAME_ONLY\"}'. Returns extracted text."
|
| 259 |
@lc_tool_decorator(description=OCR_IMAGE_TOOL_DESC)
|
| 260 |
def ocr_image_tool(action_input_json_str: str) -> str:
|
| 261 |
+
# ... (Your original ocr_image_tool logic)
|
| 262 |
if not PIL_TESSERACT_AVAILABLE: return "Error: Pillow/Pytesseract not installed."
|
| 263 |
try: data = json.loads(action_input_json_str); file_id, task_id = data.get("file_identifier"), data.get("task_id")
|
| 264 |
except Exception as e: return f"Error parsing JSON for ocr_image_tool: {e}. Input: {action_input_json_str}"
|
|
|
|
| 268 |
try: return pytesseract.image_to_string(Image.open(path))[:40000]
|
| 269 |
except Exception as e: return f"Error OCR'ing '{path}': {e}"
|
| 270 |
|
| 271 |
+
TRANSCRIBE_AUDIO_TOOL_DESC = "Transcribes speech from an audio file (or YouTube URL) to text. Input: JSON '{\"file_identifier\": \"FILENAME_OR_URL_OR_YOUTUBE_URL\", \"task_id\": \"TASK_ID_IF_GAIA_FILENAME_ONLY\"}'. Returns transcript."
|
| 272 |
@lc_tool_decorator(description=TRANSCRIBE_AUDIO_TOOL_DESC)
|
| 273 |
def transcribe_audio_tool(action_input_json_str: str) -> str:
|
| 274 |
+
# ... (Your original transcribe_audio_tool logic)
|
| 275 |
global WHISPER_MODEL
|
| 276 |
if not WHISPER_AVAILABLE: return "Error: Whisper not installed."
|
| 277 |
try: data = json.loads(action_input_json_str); file_id, task_id = data.get("file_identifier"), data.get("task_id")
|
|
|
|
| 282 |
except Exception as e: logger.error(f"Whisper load failed: {e}"); return f"Error: Whisper load: {e}"
|
| 283 |
path = _download_file(file_id, task_id)
|
| 284 |
if path.startswith("Error:"): return path
|
| 285 |
+
try: result = WHISPER_MODEL.transcribe(path, fp16=False); return result["text"][:40000] # type: ignore
|
| 286 |
except Exception as e: logger.error(f"Whisper error on '{path}': {e}", exc_info=True); return f"Error transcribing '{path}': {e}"
|
| 287 |
|
| 288 |
+
# +++ NEW TOOL using google-genai Client SDK for Multimodal Prompts +++
|
| 289 |
+
DIRECT_MULTIMODAL_GEMINI_TOOL_DESC = (
    "Processes an image file (URL or local path) along with a text prompt using a Gemini multimodal model (gemini-2.0-flash-exp) "
    "for tasks like image description, answering questions about the image, or generating text based on the image. "
    "Input: JSON '{\"file_identifier\": \"IMAGE_FILENAME_OR_URL\", \"text_prompt\": \"Your question or instruction related to the image.\", \"task_id\": \"TASK_ID_IF_GAIA_FILENAME_ONLY\" (optional)}'. "
    "Returns the model's text response."
)


@lc_tool_decorator(description=DIRECT_MULTIMODAL_GEMINI_TOOL_DESC)
def direct_multimodal_gemini_tool(action_input_json_str: str) -> str:
    # Sends one image plus a text prompt to the Gemini Flash model through the
    # module-level google-genai client and returns the model's text answer.
    #
    # Input:  a JSON object string with "file_identifier" (required),
    #         "text_prompt" (optional, defaults to "Describe this image.") and
    #         "task_id" (optional, forwarded to _download_file).
    # Output: the response text truncated to 40000 characters, or a string
    #         starting with "Error" describing what went wrong. The tool never
    #         raises; every failure is reported through the returned string so
    #         the agent can read it as an observation.
    if not google_genai_client:
        return "Error: google-genai SDK client not initialized. GOOGLE_API_KEY might be missing."
    # PIL_TESSERACT_AVAILABLE is the only flag in this file that tells us
    # whether `Image` was imported, so it doubles as the Pillow check here.
    if not PIL_TESSERACT_AVAILABLE:
        return "Error: Pillow (PIL) library is not available for image processing."

    try:
        data = json.loads(action_input_json_str)
        # Valid JSON that is not an object (a list, a bare string, ...) has no
        # .get(); reject it explicitly instead of failing with AttributeError.
        if not isinstance(data, dict):
            return "Error: Direct Multimodal Tool input must be a JSON object with a 'file_identifier' key."

        file_identifier = data.get("file_identifier")
        # `or` also covers an explicit null / empty prompt in the payload.
        text_prompt = data.get("text_prompt") or "Describe this image."
        task_id = data.get("task_id")  # Optional, for _download_file if needed

        if not file_identifier:
            return "Error: 'file_identifier' for the image is missing in the input."

        logger.info(f"Direct Multimodal Tool: Processing image '{file_identifier}' with prompt '{text_prompt}'")

        # Download the file to a local path (handles URLs and GAIA files)
        local_image_path = _download_file(file_identifier, task_id)
        if local_image_path.startswith("Error:"):
            return f"Error downloading image for Direct Multimodal Tool: {local_image_path}"

        pil_image = None
        try:
            try:
                pil_image = Image.open(local_image_path)
                pil_image.thumbnail((1024, 1024))  # Shrink large images in place
            except Exception as e_img:
                logger.error(f"Error opening image at {local_image_path}: {e_img}")
                return f"Error opening image file {local_image_path}: {str(e_img)}"

            # The image must stay open until the request has been sent.
            response = google_genai_client.models.generate_content(
                model=GEMINI_FLASH_MULTIMODAL_MODEL_NAME,
                contents=[pil_image, text_prompt],
            )
        finally:
            # Release the underlying file handle on every exit path.
            if pil_image is not None:
                pil_image.close()

        logger.info(f"Direct Multimodal Tool: Response received from {GEMINI_FLASH_MULTIMODAL_MODEL_NAME}.")

        # The response may carry no text at all (for example when no candidate
        # was produced); slicing None would raise an unhelpful TypeError.
        response_text = getattr(response, "text", None)
        if response_text is None:
            logger.warning("Direct Multimodal Tool: model response contained no text.")
            return "Error: Direct Multimodal Tool received no text in the model response."
        return response_text[:40000]  # Truncate very long answers

    except json.JSONDecodeError as e_json:
        return f"Error parsing JSON input for Direct Multimodal Tool: {str(e_json)}. Input was: {action_input_json_str}"
    except Exception as e_tool:
        logger.error(f"Error in direct_multimodal_gemini_tool: {e_tool}", exc_info=True)
        return f"Error executing Direct Multimodal Tool: {str(e_tool)}"
|
| 340 |
+
# +++ END NEW TOOL +++
|
| 341 |
+
|
| 342 |
+
|
| 343 |
+
# --- Agent Prompts (Slightly updated to include the new tool name if available) ---
|
| 344 |
+
# (Agent prompts remain largely the same, the agent will learn to use tools from their descriptions)
|
| 345 |
|
| 346 |
# --- Agent Initialization and Response Logic ---
|
| 347 |
def initialize_agent_and_tools(force_reinit=False):
    """Build the planner LLM, the tool list and the agent, storing them in module globals.

    A LangGraph agent is attempted first; if its components are unavailable or
    its construction fails, a LangChain ReAct ``AgentExecutor`` is built instead.

    Args:
        force_reinit: When false and ``AGENT_INSTANCE`` is already set, the
            function logs and returns without doing anything.

    Raises:
        ValueError: If ``GOOGLE_API_KEY`` is not set.
        RuntimeError: If neither the LangGraph nor the ReAct agent could be built.
        Exception: Any error raised while constructing the planner LLM is
            logged and re-raised unchanged.
    """
    global AGENT_INSTANCE, TOOLS, LLM_INSTANCE, LANGGRAPH_FLAVOR_AVAILABLE, LG_StateGraph, LG_ToolExecutor, LG_END, LG_ToolInvocation, add_messages, MemorySaver_Class, LANGGRAPH_MEMORY_SAVER, google_genai_client
    if AGENT_INSTANCE and not force_reinit: logger.info("Agent already initialized."); return
    logger.info("Initializing agent and tools...")
    if not GOOGLE_API_KEY: raise ValueError("GOOGLE_API_KEY not set for LangChain LLM.")

    # Initialize LangChain LLM (Planner)
    try:
        LLM_INSTANCE = ChatGoogleGenerativeAI(model=GEMINI_MODEL_NAME, google_api_key=GOOGLE_API_KEY, temperature=0.0,
            safety_settings={HarmCategory.HARM_CATEGORY_HARASSMENT: HarmBlockThreshold.BLOCK_NONE, HarmCategory.HARM_CATEGORY_HATE_SPEECH: HarmBlockThreshold.BLOCK_NONE,
                HarmCategory.HARM_CATEGORY_SEXUALLY_EXPLICIT: HarmBlockThreshold.BLOCK_NONE, HarmCategory.HARM_CATEGORY_DANGEROUS_CONTENT: HarmBlockThreshold.BLOCK_NONE,},
            request_timeout=120, convert_system_message_to_human=True )
        logger.info(f"LangChain LLM (Planner) initialized: {GEMINI_MODEL_NAME}")
    except Exception as e: logger.error(f"LangChain LLM init failed: {e}", exc_info=True); raise

    # Rebuild the tool list from scratch; each file tool is only registered
    # when its optional dependency flag is set.
    TOOLS = []
    if PYPDF2_AVAILABLE: TOOLS.append(read_pdf_tool)
    if PIL_TESSERACT_AVAILABLE: TOOLS.append(ocr_image_tool)
    if WHISPER_AVAILABLE: TOOLS.append(transcribe_audio_tool)

    # Add the new direct multimodal tool if its dependencies (client, PIL) are met
    if google_genai_client and PIL_TESSERACT_AVAILABLE: # PIL_TESSERACT_AVAILABLE implies PIL is available
        TOOLS.append(direct_multimodal_gemini_tool)
        logger.info("Added 'direct_multimodal_gemini_tool'.")
    else:
        logger.warning("'direct_multimodal_gemini_tool' NOT added due to missing google_genai_client or PIL.")

    # Search and REPL tools are best-effort: a construction failure is logged
    # and the tool is simply left out.
    try: search_tool = DuckDuckGoSearchRun(name="web_search"); search_tool.description = "Web search. Input: query."; TOOLS.append(search_tool)
    except Exception as e: logger.warning(f"DuckDuckGoSearchRun init failed: {e}")
    try: python_repl = PythonREPLTool(name="python_repl"); python_repl.description = "Python REPL. print() for output."; TOOLS.append(python_repl)
    except Exception as e: logger.warning(f"PythonREPLTool init failed: {e}")
    logger.info(f"Final tools list for agent: {[t.name for t in TOOLS]}")

    # --- Preferred path: LangGraph agent ---
    # Only attempted when the import flag is set and every required LangGraph
    # component (and the LLM) is truthy.
    if LANGGRAPH_FLAVOR_AVAILABLE and all([LG_StateGraph, LG_ToolExecutor, LG_END, LLM_INSTANCE, LG_ToolInvocation, add_messages]):
        # The memory saver is created once and reused across re-initializations.
        if not LANGGRAPH_MEMORY_SAVER and MemorySaver_Class: LANGGRAPH_MEMORY_SAVER = MemorySaver_Class(); logger.info("LangGraph MemorySaver initialized.")
        try:
            logger.info(f"Attempting LangGraph init (Memory: {LANGGRAPH_MEMORY_SAVER is not None})")
            # Falls back to plain `dict` as the base class if typing_extensions
            # has no TypedDict attribute.
            _TypedDict = getattr(__import__('typing_extensions'), 'TypedDict', dict)
            class AgentState(_TypedDict): input: str; messages: Annotated[List[Any], add_messages] # Use Any for AnyMessage for broader compatibility

            # The tool list is baked into the prompt now; "{input}" is kept as
            # a literal placeholder and substituted per call in agent_node.
            prompt_content_lg = LANGGRAPH_PROMPT_TEMPLATE_STR.format(
                tools="\n".join([f"- {t.name}: {t.description}" for t in TOOLS]),
                input="{input}"
            )
            def agent_node(state: AgentState):
                # Prepends the system prompt to the accumulated messages and
                # asks the tool-bound LLM for the next message.
                current_input_lg = state.get('input', '')
                formatted_system_prompt_lg = prompt_content_lg.replace("{input}", current_input_lg)
                messages_for_llm_lg = [SystemMessage(content=formatted_system_prompt_lg)] + state.get('messages', [])
                bound_llm_for_tools_lg = LLM_INSTANCE.bind_tools(TOOLS)
                response_from_llm_lg = bound_llm_for_tools_lg.invoke(messages_for_llm_lg)
                return {"messages": [response_from_llm_lg]}

            tool_executor_lg_instance = LG_ToolExecutor(TOOLS) # type: ignore
            def tool_node(state: AgentState):
                # Executes every tool call on the last AIMessage and returns
                # one ToolMessage per call; failures become error ToolMessages
                # instead of exceptions.
                last_msg_lg = state['messages'][-1] if state.get('messages') and isinstance(state['messages'][-1], AIMessage) else None
                if not last_msg_lg or not last_msg_lg.tool_calls: return {"messages": []}
                tool_results_lg = []
                for tc_lg in last_msg_lg.tool_calls:
                    name_lg, args_lg, tc_id_lg = tc_lg.get('name'), tc_lg.get('args'), tc_lg.get('id')
                    # A call needs a name, dict arguments and an id; otherwise
                    # report it back with placeholder id/name values.
                    if not all([name_lg, isinstance(args_lg, dict), tc_id_lg]):
                        err_msg_lg=f"Invalid tool_call: {tc_lg}"; logger.error(err_msg_lg)
                        tool_results_lg.append(ToolMessage(f"Error: {err_msg_lg}", tool_call_id=tc_id_lg or "error_id", name=name_lg or "error_tool"))
                        continue
                    try:
                        logger.info(f"LG Tool Invoking: '{name_lg}' with {args_lg} (ID: {tc_id_lg})")
                        tool_invocation_obj_lg = LG_ToolInvocation(tool=name_lg, tool_input=args_lg) # type: ignore
                        output_lg = tool_executor_lg_instance.invoke(tool_invocation_obj_lg)
                        tool_results_lg.append(ToolMessage(content=str(output_lg), tool_call_id=tc_id_lg, name=name_lg))
                    except Exception as e_tool_node_lg:
                        logger.error(f"LG Tool Error ('{name_lg}'): {e_tool_node_lg}", exc_info=True)
                        tool_results_lg.append(ToolMessage(content=f"Error for tool {name_lg}: {str(e_tool_node_lg)}", tool_call_id=tc_id_lg, name=name_lg))
                return {"messages": tool_results_lg}

            # Graph wiring: agent -> (tools -> agent)* -> END.
            workflow_lg = LG_StateGraph(AgentState) # type: ignore
            workflow_lg.add_node("agent", agent_node)
            workflow_lg.add_node("tools", tool_node)
            workflow_lg.set_entry_point("agent")
            # Route to the tool node while the last message requests tool calls.
            def should_continue_lg(state: AgentState): return "tools" if state['messages'][-1].tool_calls else LG_END
            workflow_lg.add_conditional_edges("agent", should_continue_lg, {"tools": "tools", LG_END: LG_END}) # type: ignore
            workflow_lg.add_edge("tools", "agent")
            AGENT_INSTANCE = workflow_lg.compile(checkpointer=LANGGRAPH_MEMORY_SAVER) if LANGGRAPH_MEMORY_SAVER else workflow_lg.compile()
            logger.info(f"LangGraph compiled (Memory: {LANGGRAPH_MEMORY_SAVER is not None}).")
        except Exception as e_lg_init_main:
            # Any LangGraph failure clears the instance so the ReAct fallback runs.
            logger.error(f"LangGraph init error: {e_lg_init_main}. Fallback ReAct.", exc_info=True); AGENT_INSTANCE = None
    else:
        logger.info("Skipping LangGraph: core components missing or LLM not ready."); AGENT_INSTANCE = None

    # --- Fallback path: LangChain ReAct agent ---
    if not AGENT_INSTANCE:
        logger.info("Initializing ReAct agent as fallback.")
        try:
            if not LLM_INSTANCE: raise ValueError("LLM_INSTANCE is None for ReAct.")
            prompt_react_instance = PromptTemplate.from_template(REACT_PROMPT_TEMPLATE_STR).partial(
                tools="\n".join([f"- {t.name}:{t.description}" for t in TOOLS]),
                tool_names=",".join([t.name for t in TOOLS])
            )
            react_agent_runnable_instance = create_react_agent(LLM_INSTANCE, TOOLS, prompt_react_instance)
            AGENT_INSTANCE = AgentExecutor(agent=react_agent_runnable_instance, tools=TOOLS, verbose=True, handle_parsing_errors=True, max_iterations=15, early_stopping_method="force")
            logger.info("ReAct agent initialized.")
        except Exception as e_react_init_main:
            logger.error(f"ReAct agent init failed: {e_react_init_main}", exc_info=True); AGENT_INSTANCE = None

    # Both strategies failed: surface this to the caller instead of leaving a
    # silent None behind.
    if not AGENT_INSTANCE: raise RuntimeError("CRITICAL: Agent initialization completely failed.")
    logger.info(f"Agent init finished. Active agent type: {type(AGENT_INSTANCE).__name__}")
|
| 452 |
|
| 453 |
+
|
| 454 |
+
# --- get_agent_response, construct_prompt_for_agent, run_and_submit_all ---
|
| 455 |
+
# --- These functions remain UNCHANGED from your original code ---
|
| 456 |
def get_agent_response(prompt: str, task_id: Optional[str]=None, thread_id: Optional[str]=None) -> str:
    """Run the active agent on ``prompt`` and return its answer as a string.

    The conversation thread id is ``thread_id`` when given, otherwise
    ``gaia_task_<task_id>``, otherwise the first 8 hex digits of the MD5 of the
    prompt. If the agent or the LLM is not initialized, one forced
    re-initialization is attempted first.

    Failures are never raised to the caller; they are logged and reported as a
    string starting with ``[ERROR]``.
    """
    global AGENT_INSTANCE, LLM_INSTANCE

    # Resolve the thread id: explicit id > task-derived id > prompt hash.
    if thread_id:
        conversation_key = thread_id
    elif task_id:
        conversation_key = f"gaia_task_{task_id}"
    else:
        conversation_key = hashlib.md5(prompt.encode()).hexdigest()[:8]

    # Lazily (re)build the agent when either global is missing.
    if not (AGENT_INSTANCE and LLM_INSTANCE):
        logger.warning("Agent/LLM not initialized in get_agent_response. Attempting re-initialization.")
        try:
            initialize_agent_and_tools(force_reinit=True)
        except Exception as reinit_error:
            logger.error(f"Re-initialization failed: {reinit_error}")
            return f"[ERROR] Agent/LLM re-init failed: {str(reinit_error)}"
        if not (AGENT_INSTANCE and LLM_INSTANCE):
            return "[ERROR] Agent/LLM still None after re-init."

    agent_kind = type(AGENT_INSTANCE).__name__
    logger.info(f"Agent ({agent_kind}) processing. Task: {task_id or 'N/A'}. Thread: {conversation_key}.")

    # A LangGraph agent is recognised by the import flag plus the presence of
    # the 'graph' and 'config_schema' attributes on the instance.
    runs_on_langgraph = bool(
        LANGGRAPH_FLAVOR_AVAILABLE
        and AGENT_INSTANCE
        and hasattr(AGENT_INSTANCE, 'graph')
        and hasattr(AGENT_INSTANCE, 'config_schema')
    )

    try:
        if runs_on_langgraph:
            logger.debug(f"Using LangGraph agent (Memory: {LANGGRAPH_MEMORY_SAVER is not None}) for thread: {conversation_key}")
            graph_input = {"input": prompt, "messages": []}
            run_config = {"configurable": {"thread_id": conversation_key}}
            end_state = AGENT_INSTANCE.invoke(graph_input, run_config)

            if not end_state or 'messages' not in end_state or not end_state['messages']:
                logger.error("LangGraph: No final state/messages.")
                return "[ERROR] LangGraph: No final state/messages."

            # Prefer the most recent AI message that is an answer rather than
            # a request for more tool calls.
            closing_message = next(
                (
                    candidate
                    for candidate in reversed(end_state['messages'])
                    if isinstance(candidate, AIMessage) and not candidate.tool_calls
                ),
                None,
            )
            if closing_message is not None:
                return str(closing_message.content)

            logger.warning("LangGraph: No suitable final AIMessage without tool_calls.")
            if not end_state['messages']:
                return "[ERROR] LangGraph: Empty messages."
            return str(end_state['messages'][-1].content)

        if isinstance(AGENT_INSTANCE, AgentExecutor):
            logger.debug("Using ReAct agent.")
            react_result = AGENT_INSTANCE.invoke({"input": prompt})
            return str(react_result.get("output", "[ERROR] ReAct: No 'output' key."))

        logger.error(f"Unknown agent type: {agent_kind}")
        return f"[ERROR] Unknown agent type: {agent_kind}"
    except Exception as run_error:
        logger.error(f"Error during agent execution ({agent_kind}): {run_error}", exc_info=True)
        return f"[ERROR] Agent execution failed: {str(run_error)[:150]}"
|
| 490 |
|
| 491 |
def construct_prompt_for_agent(q: Dict[str,Any]) -> str:
    """Render one question record as the text prompt handed to the agent.

    Reads the ``task_id`` (default ``"N/A"``), ``question`` (default ``""``),
    ``files`` and ``level`` keys of ``q``. The ``Level`` line and the ``Files``
    listing are only included when the corresponding value is truthy.
    """
    task_ref = q.get("task_id", "N/A")
    question_text = q.get("question", "")
    attachments = q.get("files", [])
    difficulty = q.get('level')

    sections = [f"Task ID:{task_ref}"]
    if difficulty:
        sections.append(f"\nLevel:{difficulty}")
    if attachments:
        # One bullet per attached file, tagged with the task id.
        listing = "\n".join(f"- {name} (task_id:{task_ref})" for name in attachments)
        sections.append("\nFiles:\n" + listing)
    sections.append(f"\n\nQuestion:{question_text}")
    return "".join(sections)
|
| 497 |
|
| 498 |
+
def run_and_submit_all(profile: Optional[gr.OAuthProfile] = None):
|
| 499 |
+
# ... (Your original run_and_submit_all logic - unchanged) ...
|
| 500 |
global AGENT_INSTANCE
|
| 501 |
space_id = os.getenv("SPACE_ID")
|
| 502 |
username_for_submission = None
|
|
|
|
| 505 |
username_for_submission = profile.username
|
| 506 |
logger.info(f"Username from OAuth profile: {username_for_submission}")
|
| 507 |
else:
|
| 508 |
+
logger.warning("OAuth profile not available or username missing.")
|
|
|
|
|
|
|
| 509 |
return "Hugging Face login required. Please use the login button and try again.", None
|
| 510 |
|
| 511 |
if AGENT_INSTANCE is None:
|
| 512 |
+
try: logger.info("Agent not pre-initialized. Initializing for run..."); initialize_agent_and_tools()
|
| 513 |
except Exception as e: return f"Agent on-demand initialization failed: {e}", None
|
| 514 |
+
if AGENT_INSTANCE is None: return "Agent is still None after on-demand init.", None
|
| 515 |
|
| 516 |
+
agent_code_url_run=f"https://huggingface.co/spaces/{space_id}/tree/main" if space_id else "local_dev_run"
|
| 517 |
+
questions_url_run,submit_url_run=f"{DEFAULT_API_URL}/questions",f"{DEFAULT_API_URL}/submit"
|
| 518 |
+
auth_headers_run={"Authorization":f"Bearer {HUGGINGFACE_TOKEN}"} if HUGGINGFACE_TOKEN else {}
|
| 519 |
try:
|
| 520 |
+
logger.info(f"Fetching questions from {questions_url_run}")
|
| 521 |
+
response_q_run=requests.get(questions_url_run,headers=auth_headers_run,timeout=30);response_q_run.raise_for_status();questions_data_run=response_q_run.json()
|
| 522 |
+
if not questions_data_run or not isinstance(questions_data_run,list):logger.error(f"Invalid questions data: {questions_data_run}");return "Fetched questions_data invalid.",None
|
| 523 |
+
logger.info(f"Fetched {len(questions_data_run)} questions.")
|
| 524 |
except Exception as e:logger.error(f"Fetch questions error: {e}",exc_info=True);return f"Fetch questions error:{e}",None
|
| 525 |
|
| 526 |
+
results_log_run,answers_payload_run=[],[]
|
| 527 |
+
logger.info(f"Running agent on {len(questions_data_run)} questions for user '{username_for_submission}'...")
|
| 528 |
+
for i,item_run in enumerate(questions_data_run):
|
| 529 |
+
task_id_run,question_text_run=item_run.get("task_id"),item_run.get("question")
|
| 530 |
+
if not task_id_run or question_text_run is None:logger.warning(f"Skipping item: {item_run}");continue
|
| 531 |
+
prompt_run=construct_prompt_for_agent(item_run);thread_id_run=f"gaia_batch_task_{task_id_run}"
|
| 532 |
+
logger.info(f"Processing Q {i+1}/{len(questions_data_run)} - Task: {task_id_run}")
|
| 533 |
try:
|
| 534 |
+
raw_answer_run=get_agent_response(prompt_run,task_id=task_id_run,thread_id=thread_id_run);submitted_answer_run=_strip_exact_match_answer(raw_answer_run)
|
| 535 |
+
answers_payload_run.append({"task_id":task_id_run,"submitted_answer":submitted_answer_run})
|
| 536 |
+
results_log_run.append({"Task ID":task_id_run,"Question":question_text_run,"Full Agent Prompt":prompt_run,"Raw Agent Output":raw_answer_run,"Submitted Answer":submitted_answer_run})
|
| 537 |
except Exception as e:
|
| 538 |
+
logger.error(f"Agent error task {task_id_run}:{e}",exc_info=True);error_answer_run=f"AGENT ERROR:{str(e)[:100]}"
|
| 539 |
+
answers_payload_run.append({"task_id":task_id_run,"submitted_answer":"N/A [AGENT_ERROR]"})
|
| 540 |
+
results_log_run.append({"Task ID":task_id_run,"Question":question_text_run,"Full Agent Prompt":prompt_run,"Raw Agent Output":error_answer_run,"Submitted Answer":"N/A [AGENT_ERROR]"})
|
| 541 |
|
| 542 |
+
if not answers_payload_run:return "Agent produced no answers.",pd.DataFrame(results_log_run)
|
| 543 |
+
submission_payload_run={"username":username_for_submission.strip(),"agent_code":agent_code_url_run,"answers":answers_payload_run}
|
| 544 |
+
logger.info(f"Submitting {len(answers_payload_run)} answers to {submit_url_run} for user '{username_for_submission}'...")
|
| 545 |
+
submission_headers_run={"Content-Type":"application/json",**auth_headers_run}
|
| 546 |
try:
|
| 547 |
+
response_s_run=requests.post(submit_url_run,json=submission_payload_run,headers=submission_headers_run,timeout=120);response_s_run.raise_for_status();submission_result_run=response_s_run.json()
|
| 548 |
+
result_message_run=(f"User:{submission_result_run.get('username',username_for_submission)}\nScore:{submission_result_run.get('score','N/A')}% ({submission_result_run.get('correct_count','?')}/{submission_result_run.get('total_attempted','?')})\nMsg:{submission_result_run.get('message','N/A')}")
|
| 549 |
+
logger.info(f"Submission OK! {result_message_run}");return f"Submission OK!\n{result_message_run}",pd.DataFrame(results_log_run,columns=["Task ID","Question","Full Agent Prompt","Raw Agent Output","Submitted Answer"])
|
| 550 |
except requests.exceptions.HTTPError as e:
|
| 551 |
+
error_http_run=f"HTTP {e.response.status_code}. Detail:{e.response.text[:200]}"; logger.error(f"Submit Fail:{error_http_run}",exc_info=True); return f"Submit Fail:{error_http_run}",pd.DataFrame(results_log_run)
|
| 552 |
+
except Exception as e:logger.error(f"Submit Fail unexpected:{e}",exc_info=True);return f"Submit Fail:{str(e)[:100]}",pd.DataFrame(results_log_run)
|
| 553 |
+
|
| 554 |
|
| 555 |
+
# --- Build Gradio Interface (Unchanged from your original) ---
|
| 556 |
with gr.Blocks(css=".gradio-container {max-width:1280px !important;margin:auto !important;}",theme=gr.themes.Soft()) as demo:
|
| 557 |
+
# ... (Your original Gradio UI layout - unchanged) ...
|
| 558 |
+
gr.Markdown("# GAIA Agent Challenge Runner v7 (OAuth for Username)")
|
| 559 |
gr.Markdown(f"""**Instructions:**
|
| 560 |
1. **Login with Hugging Face** using the button below. Your HF username will be used for submission.
|
| 561 |
2. Click 'Run Evaluation & Submit' to process GAIA questions (typically 20).
|
| 562 |
+
3. **Goal: 30%+ (6/20).** Agent uses Gemini Pro ({GEMINI_MODEL_NAME}) as planner. Tools include Web Search, Python, PDF, OCR, Audio/YouTube, and a new Direct Multimodal tool using Gemini Flash ({GEMINI_FLASH_MULTIMODAL_MODEL_NAME}).
|
| 563 |
4. Ensure `GOOGLE_API_KEY` and `HUGGINGFACE_TOKEN` are Space secrets.
|
| 564 |
5. Check Space logs for details. LangGraph is attempted (ReAct fallback).""")
|
|
|
|
|
|
|
|
|
|
| 565 |
|
| 566 |
+
agent_status_display = gr.Markdown("**Agent Status:** Initializing...")
|
| 567 |
+
missing_secrets_display = gr.Markdown("")
|
| 568 |
+
|
| 569 |
+
gr.LoginButton()
|
| 570 |
run_button = gr.Button("Run Evaluation & Submit All Answers")
|
| 571 |
status_output = gr.Textbox(label="Run Status / Submission Result", lines=7, interactive=False)
|
| 572 |
+
results_table = gr.DataFrame(label="Q&A Log", headers=["Task ID","Question","Prompt","Raw","Submitted"], wrap=True, height=400)
|
| 573 |
+
|
|
|
|
|
|
|
| 574 |
run_button.click(fn=run_and_submit_all, outputs=[status_output,results_table], api_name="run_evaluation")
|
| 575 |
|
| 576 |
def update_ui_on_load_fn_within_context():
|
| 577 |
+
# ... (Your original update_ui_on_load_fn_within_context logic - unchanged) ...
|
| 578 |
+
global missing_vars_startup_list_global, agent_pre_init_status_msg_global
|
| 579 |
secrets_msg_md = ""
|
| 580 |
+
if missing_vars_startup_list_global:
|
| 581 |
secrets_msg_md = f"<font color='red'>**⚠️ Secrets Missing:** {', '.join(missing_vars_startup_list_global)}.</font>"
|
| 582 |
env_issues = []
|
| 583 |
try: subprocess.run(['yt-dlp','--version'],check=True,stdout=subprocess.DEVNULL,stderr=subprocess.DEVNULL)
|
|
|
|
| 585 |
try: subprocess.run(['ffmpeg','-version'],check=True,stdout=subprocess.DEVNULL,stderr=subprocess.DEVNULL)
|
| 586 |
except: env_issues.append("ffmpeg"); logger.warning("ffmpeg check failed (UI load).")
|
| 587 |
if env_issues: secrets_msg_md += f"<br/><font color='orange'>**Tool Deps Missing:** {', '.join(env_issues)}.</font>"
|
| 588 |
+
current_status_md = agent_pre_init_status_msg_global
|
|
|
|
| 589 |
if not LANGGRAPH_FLAVOR_AVAILABLE and "LangGraph" not in current_status_md:
|
| 590 |
current_status_md += " (LangGraph core import failed, ReAct fallback.)"
|
| 591 |
+
return { agent_status_display: gr.Markdown(value=current_status_md),
|
|
|
|
| 592 |
missing_secrets_display: gr.Markdown(value=secrets_msg_md) }
|
| 593 |
|
| 594 |
demo.load(update_ui_on_load_fn_within_context, [], [agent_status_display, missing_secrets_display])
|
| 595 |
|
| 596 |
if __name__ == "__main__":
|
| 597 |
+
# ... (Your original __main__ block for startup logging and pre-initialization - unchanged) ...
|
| 598 |
+
logger.info("Application starting up (v7 with Direct Multimodal Tool)...")
|
| 599 |
if not PYPDF2_AVAILABLE: logger.warning("PyPDF2 (PDF tool) NOT AVAILABLE.")
|
| 600 |
+
if not PIL_TESSERACT_AVAILABLE: logger.warning("Pillow/Pytesseract (OCR tool) NOT AVAILABLE.") # PIL also needed for new tool
|
| 601 |
if not WHISPER_AVAILABLE: logger.warning("Whisper (Audio tool) NOT AVAILABLE.")
|
| 602 |
if LANGGRAPH_FLAVOR_AVAILABLE: logger.info("Core LangGraph (StateGraph, END) loaded.")
|
| 603 |
+
else: logger.warning("Core LangGraph FAILED import. ReAct fallback. Check requirements & Space build logs.")
|
| 604 |
|
| 605 |
+
missing_vars_startup_list_global.clear()
|
| 606 |
if not GOOGLE_API_KEY: missing_vars_startup_list_global.append("GOOGLE_API_KEY")
|
| 607 |
if not HUGGINGFACE_TOKEN: missing_vars_startup_list_global.append("HUGGINGFACE_TOKEN (for GAIA API)")
|
| 608 |
+
|
|
|
|
| 609 |
try:
|
| 610 |
logger.info("Pre-initializing agent...")
|
| 611 |
+
initialize_agent_and_tools() # This will now include the new direct_multimodal_gemini_tool
|
| 612 |
if AGENT_INSTANCE:
|
| 613 |
+
agent_type_name = type(AGENT_INSTANCE).__name__
|
| 614 |
+
agent_pre_init_status_msg_global = f"Agent Pre-initialized: **{agent_type_name}**."
|
| 615 |
+
if LANGGRAPH_FLAVOR_AVAILABLE and "StateGraph" in agent_type_name: # More robust check
|
| 616 |
agent_pre_init_status_msg_global = f"Agent Pre-initialized: **LangGraph** (Memory: {LANGGRAPH_MEMORY_SAVER is not None})."
|
| 617 |
else: agent_pre_init_status_msg_global = "Agent pre-init FAILED (AGENT_INSTANCE is None)."
|
| 618 |
logger.info(agent_pre_init_status_msg_global.replace("**",""))
|