Update app.py
app.py CHANGED
@@ -61,14 +61,14 @@ if TYPE_CHECKING:
 LANGGRAPH_FLAVOR_AVAILABLE = False
 LG_StateGraph: Optional[Type[Any]] = None
 LG_ToolExecutor_Class: Optional[Type[Any]] = None
-LG_END: Optional[Any] = None
+LG_END: Optional[Any] = None
 LG_ToolInvocation: Optional[Type[Any]] = None
 add_messages: Optional[Any] = None
 MemorySaver_Class: Optional[Type[Any]] = None

 AGENT_INSTANCE: Optional[Union[AgentExecutor, Any]] = None
 TOOLS: List[BaseTool] = []
-LLM_INSTANCE: Optional[ChatGoogleGenerativeAI] = None
+LLM_INSTANCE: Optional[ChatGoogleGenerativeAI] = None
 LANGGRAPH_MEMORY_SAVER: Optional[Any] = None

 # google-genai Client SDK
@@ -126,8 +126,8 @@ except ImportError as e:

 # --- Constants ---
 DEFAULT_API_URL = "https://agents-course-unit4-scoring.hf.space"
-GEMINI_MODEL_NAME = "gemini-
-GEMINI_FLASH_MULTIMODAL_MODEL_NAME = "gemini-
+GEMINI_MODEL_NAME = "gemini-1.5-pro-preview-0514"
+GEMINI_FLASH_MULTIMODAL_MODEL_NAME = "gemini-1.5-flash-preview-0514"
 SCORING_API_BASE_URL = os.getenv("SCORING_API_URL", DEFAULT_API_URL)
 MAX_FILE_SIZE_BYTES = 50 * 1024 * 1024
 LOCAL_FILE_STORE_PATH = "./Data"
@@ -255,7 +255,8 @@ def _download_file(file_identifier: str, task_id_for_file: Optional[str] = None)
 name_without_ext, current_ext = os.path.splitext(effective_save_path)
 if not current_ext:
     content_type_header = r.headers.get('content-type', '')
-
+    # FIX: Handle split correctly and take first part before stripping
+    content_type_val = content_type_header.split(';')[0].strip() if content_type_header else ''
     if content_type_val:
         guessed_ext = mimetypes.guess_extension(content_type_val)
         if guessed_ext: effective_save_path += guessed_ext; logger.info(f"Added guessed ext: {guessed_ext}")
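For context on the content-type fix above, here is a small illustrative sketch (not taken from app.py; the helper name is hypothetical) of how stripping the header down to the bare media type lets mimetypes.guess_extension pick a file extension:

# Illustrative sketch, not part of the diff: how the fixed header handling maps
# a Content-Type value to a file extension. Helper name is hypothetical.
import mimetypes

def guess_extension_from_header(content_type_header: str) -> str:
    # Keep only the media type, dropping parameters such as "; charset=utf-8".
    content_type_val = content_type_header.split(';')[0].strip() if content_type_header else ''
    return (mimetypes.guess_extension(content_type_val) or '') if content_type_val else ''

print(guess_extension_from_header("application/pdf"))           # .pdf
print(guess_extension_from_header("text/csv; charset=utf-8"))   # usually .csv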
@@ -324,7 +325,7 @@ def transcribe_audio_tool(action_input_json_str: str) -> str:

 @lc_tool_decorator
 def direct_multimodal_gemini_tool(action_input_json_str: str) -> str:
-    """Processes an image file (URL or local path) along with a text prompt using a Gemini multimodal model (gemini-
+    """Processes an image file (URL or local path) along with a text prompt using a Gemini multimodal model (gemini-1.5-flash-preview-0514) for tasks like image description, Q&A about the image, or text generation based on the image. Input: JSON '{\"file_identifier\": \"IMAGE_FILENAME_OR_URL\", \"text_prompt\": \"Your question or instruction related to the image.\", \"task_id\": \"TASK_ID_IF_GAIA_FILENAME_ONLY\" (optional)}'. Returns the model's text response."""
     global google_genai_client
     if not google_genai_client: return "Error: google-genai SDK client not initialized."
     if not PIL_TESSERACT_AVAILABLE : return "Error: Pillow (PIL) library not available for image processing." # Relies on PIL_TESSERACT_AVAILABLE for PIL
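A brief usage sketch of the calling convention this tool expects: a single JSON-string argument carrying the file reference, prompt, and optional task id. The file name and task id below are hypothetical placeholders, not values from the diff.

# Sketch only: how a caller would build the JSON-string input for the tool above.
import json

action_input = json.dumps({
    "file_identifier": "diagram.png",                  # hypothetical file name
    "text_prompt": "Describe what this diagram shows.",
    "task_id": "task-123",                             # hypothetical task id
})
# result = direct_multimodal_gemini_tool(action_input)  # returns the model's text response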
@@ -360,9 +361,9 @@ You have access to the following tools. Use them if necessary.
 {tools}
 TOOL USAGE:
 - To use a tool, your response must include a `tool_calls` attribute in the AIMessage. Each tool call should be a dictionary with "name", "args" (a dictionary of arguments), and "id".
-- For file tools ('read_pdf_tool', 'ocr_image_tool', 'transcribe_audio_tool', 'direct_multimodal_gemini_tool'): `args` must
+- For file tools ('read_pdf_tool', 'ocr_image_tool', 'transcribe_audio_tool', 'direct_multimodal_gemini_tool'): The `args` field must be a dictionary with a single key 'action_input_json_str' whose value is a JSON STRING. Example: {{"action_input_json_str": "{{\\"file_identifier\\": \\"file.pdf\\", \\"task_id\\": \\"123\\"}}"}}.
 - 'web_search': `args` is like '{{"query": "search query"}}'.
-- 'python_repl': `args` is like '{{"
+- 'python_repl': `args` is like '{{"query": "python code string"}}'. Use print() for output.
 RESPONSE FORMAT:
 Final AIMessage should contain ONLY the answer in 'content' and NO 'tool_calls'. If using tools, 'content' can be thought process, with 'tool_calls'.
 Begin!
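To make the convention above concrete, here is a small sketch (assuming langchain-core is installed; the file name, task id, and call id are illustrative, not from app.py) of an AIMessage whose tool_calls follow the nested JSON-string format the prompt asks for:

# Sketch, not from app.py: an AIMessage carrying a tool call in the format the
# prompt template describes (args holds one JSON-string field).
import json
from langchain_core.messages import AIMessage

inner_args = json.dumps({"file_identifier": "report.pdf", "task_id": "123"})  # hypothetical values
msg = AIMessage(
    content="I will read the attached PDF first.",
    tool_calls=[{
        "name": "read_pdf_tool",
        "args": {"action_input_json_str": inner_args},
        "id": "call_1",
    }],
)
print(msg.tool_calls[0]["args"]["action_input_json_str"])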
@@ -376,8 +377,7 @@ Process: Question -> Thought -> Action (ONE of [{tool_names}]) -> Action Input -
 Tool Inputs:
 - web_search: Your search query string.
 - python_repl: Python code string. Use print(). For Excel/CSV, use pandas: import pandas as pd; df = pd.read_excel('./Data/TASKID_filename.xlsx'); print(df.head())
-- read_pdf_tool, ocr_image_tool, transcribe_audio_tool: JSON string like '{{"file_identifier": "FILENAME_OR_URL", "task_id": "CURRENT_TASK_ID_IF_FILENAME"}}'.
-- direct_multimodal_gemini_tool: JSON string like '{{"file_identifier": "IMAGE_FILENAME_OR_URL", "text_prompt": "Your prompt for the image.", "task_id": "TASK_ID_IF_GAIA_FILENAME"}}'.
+- read_pdf_tool, ocr_image_tool, transcribe_audio_tool, direct_multimodal_gemini_tool: JSON string like '{{"file_identifier": "FILENAME_OR_URL", "task_id": "CURRENT_TASK_ID_IF_FILENAME"}}'.
 If tool fails or info missing, Final Answer: N/A. Do NOT use unlisted tools.
 Begin!
 {input}
@@ -422,28 +422,27 @@ def initialize_agent_and_tools(force_reinit=False):
 try:
     logger.info(f"Attempting LangGraph init (Tool Executor type: {LG_ToolExecutor_Class.__name__ if LG_ToolExecutor_Class else 'None'})")
     _TypedDict = getattr(__import__('typing_extensions'), 'TypedDict', dict)
-
+    # FIX: Remove 'input' key from state, only use 'messages' for conversational flow
+    class AgentState(_TypedDict):
+        messages: Annotated[List[Any], add_messages]

     # System prompt template - this describes the agent's role and tools.
     # The {input} placeholder for the actual task will be filled by the HumanMessage.
-    base_system_prompt_content_lg = LANGGRAPH_PROMPT_TEMPLATE_STR.split("{input}")[0].strip()
-
+    base_system_prompt_content_lg = LANGGRAPH_PROMPT_TEMPLATE_STR.split("{input}")[0].strip()

     def agent_node(state: AgentState):
-        current_task_query = state.get('input', '') # The specific question/task for this turn
-
         system_message_content = base_system_prompt_content_lg.format(
             tools="\n".join([f"- {t.name}: {t.description}" for t in TOOLS])
         )

+        # FIX: Construct message list from state, don't re-add original prompt
         messages_for_llm = [SystemMessage(content=system_message_content)]
-        messages_for_llm.extend(state
-        messages_for_llm.append(HumanMessage(content=current_task_query)) # Add current task as HumanMessage
+        messages_for_llm.extend(state['messages'])

         logger.debug(f"LangGraph agent_node - messages_for_llm: {messages_for_llm}")
-        if not messages_for_llm
-        logger.error("LLM call would fail in agent_node:
-        return {"messages": [AIMessage(content="[ERROR] Agent node:
+        if not messages_for_llm or not any(isinstance(m, (HumanMessage, ToolMessage)) for m in messages_for_llm):
+            logger.error("LLM call would fail in agent_node: No HumanMessage or ToolMessage found in history.")
+            return {"messages": [AIMessage(content="[ERROR] Agent node: No user input found in messages.")]}

         bound_llm = LLM_INSTANCE.bind_tools(TOOLS)
         response = bound_llm.invoke(messages_for_llm)
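The fix above relies on the add_messages reducer to accumulate conversation history in the state. A minimal sketch (assuming langgraph and langchain-core are installed; this is not code from app.py) of what that reducer does:

# Sketch, not part of the diff: add_messages appends node output to the
# existing history instead of replacing it, which is why agent_node only
# returns the new messages it produced.
from langgraph.graph.message import add_messages
from langchain_core.messages import HumanMessage, AIMessage

history = [HumanMessage(content="Task ID:1\n\nQuestion: What is 2+2?")]
update = [AIMessage(content="4")]
merged = add_messages(history, update)
print([type(m).__name__ for m in merged])  # ['HumanMessage', 'AIMessage']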
@@ -478,9 +477,6 @@ def initialize_agent_and_tools(force_reinit=False):

     workflow_lg = LG_StateGraph(AgentState) # type: ignore
     workflow_lg.add_node("agent", agent_node)
-    # If LG_ToolExecutor_Class is ToolNode, it can often be added directly as the node.
-    # workflow_lg.add_node("tools", tool_executor_instance_lg)
-    # For now, using the custom tool_node which wraps the executor instance.
     workflow_lg.add_node("tools", tool_node)
     workflow_lg.set_entry_point("agent")
     def should_continue_lg(state: AgentState): return "tools" if state['messages'][-1].tool_calls else LG_END
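For readers unfamiliar with this wiring, below is a self-contained sketch of the same agent/tools loop pattern. The tool and the stub agent node are stand-ins for the app's real LLM and tools; none of this code appears in app.py.

# Self-contained sketch of the pattern used above: an agent node, a ToolNode,
# and a conditional edge that loops while the last AI message requests tools.
from typing import Annotated, Any, List
from typing_extensions import TypedDict
from langchain_core.messages import AIMessage, HumanMessage
from langchain_core.tools import tool
from langgraph.graph import StateGraph, END
from langgraph.graph.message import add_messages
from langgraph.prebuilt import ToolNode

@tool
def add_numbers(a: int, b: int) -> int:
    """Add two integers."""
    return a + b

class AgentState(TypedDict):
    messages: Annotated[List[Any], add_messages]

def agent_node(state: AgentState):
    # A real agent would call an LLM bound to the tools here (bind_tools + invoke).
    if any(getattr(m, "name", None) == "add_numbers" for m in state["messages"]):
        return {"messages": [AIMessage(content="Done: 4")]}
    return {"messages": [AIMessage(content="", tool_calls=[
        {"name": "add_numbers", "args": {"a": 2, "b": 2}, "id": "call_1"}])]}

def should_continue(state: AgentState):
    return "tools" if state["messages"][-1].tool_calls else END

workflow = StateGraph(AgentState)
workflow.add_node("agent", agent_node)
workflow.add_node("tools", ToolNode([add_numbers]))
workflow.set_entry_point("agent")
workflow.add_conditional_edges("agent", should_continue, {"tools": "tools", END: END})
workflow.add_edge("tools", "agent")
app = workflow.compile()
print(app.invoke({"messages": [HumanMessage(content="What is 2 + 2?")]})["messages"][-1].content)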
@@ -528,7 +524,8 @@ def get_agent_response(prompt: str, task_id: Optional[str]=None, thread_id: Opti
 try:
     if is_langgraph_agent_get:
         logger.debug(f"Using LangGraph agent for thread: {thread_id_to_use}")
-
+        # FIX: The input should be a list of messages for the 'add_messages' reducer.
+        input_for_lg_get = {"messages": [HumanMessage(content=prompt)]}
         logger.debug(f"Invoking LangGraph with input: {input_for_lg_get}")
         final_state_lg_get = AGENT_INSTANCE.invoke(input_for_lg_get, {"configurable": {"thread_id": thread_id_to_use}})

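As a shape reference only: the invoke call above takes a messages payload plus a per-thread config so a checkpointer (e.g. MemorySaver) can keep state between calls. The thread id and question below are placeholders, and AGENT_INSTANCE refers to the compiled graph from the surrounding code.

# Sketch of the call shape used above; thread id and question are made up.
from langchain_core.messages import HumanMessage

payload = {"messages": [HumanMessage(content="Task ID:abc123\n\nQuestion: ...")]}
config = {"configurable": {"thread_id": "task-abc123"}}
# final_state = AGENT_INSTANCE.invoke(payload, config)
# answer = final_state["messages"][-1].content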
@@ -574,14 +571,12 @@ def get_agent_response(prompt: str, task_id: Optional[str]=None, thread_id: Opti
     return f"[ERROR] Agent execution failed: {str(e_agent_run_get)[:150]}"

 def construct_prompt_for_agent(q: Dict[str,Any]) -> str:
-    # ... (Your original construct_prompt_for_agent logic - unchanged) ...
     tid,q_str=q.get("task_id","N/A"),q.get("question",""); files=q.get("files",[])
     files_info = ("\nFiles:\n"+"\n".join([f"- {f} (task_id:{tid})"for f in files])) if files else ""
     level = f"\nLevel:{q.get('level')}" if q.get('level') else ""
     return f"Task ID:{tid}{level}{files_info}\n\nQuestion:{q_str}"

 def run_and_submit_all(profile: Optional[gr.OAuthProfile] = None):
-    # ... (Your original run_and_submit_all logic - unchanged) ...
     global AGENT_INSTANCE
     space_id = os.getenv("SPACE_ID")
     username_for_submission = None
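For illustration, assuming the construct_prompt_for_agent definition above, here is the prompt string it produces for a hypothetical GAIA-style question dict (made-up values, not real data):

# Example input/output for construct_prompt_for_agent; the task dict is a
# made-up question with one attached file.
q = {"task_id": "abc123", "question": "What does the attached report conclude?",
     "files": ["report.pdf"], "level": 1}
print(construct_prompt_for_agent(q))
# Task ID:abc123
# Level:1
# Files:
# - report.pdf (task_id:abc123)
#
# Question:What does the attached report conclude?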