Spaces:
Sleeping
Sleeping
IDAgents Developer
Fix: Orchestrator now finds subagents - Pass per-user agents config to orchestrator
e10b7d3
| """ | |
| chat_orchestrator.py | |
| ------------------- | |
| Agent orchestration, chat streaming, and related logic. | |
| """ | |
| import json | |
| from typing import Dict, cast | |
| from core.utils.llm_connector import AgentLLMConnector | |
| from core.utils.skills_registry import get_tool_by_name | |
| from core.agents.agent_utils import validate_and_reference_recommendation | |
def convert_gradio_to_dict_format(history):
    """Normalize chat history into a list of role/content message dicts.

    Accepts Gradio's pair format ``[["user1", "assistant1"], ...]`` and/or
    items that are already ``{"role": ..., "content": ...}`` dicts:

    - dict items are passed through as-is (the same objects, not copies);
    - list/tuple items with at least two elements contribute a user message
      from element 0 and then an assistant message from element 1, each only
      when truthy (extra elements are ignored);
    - any other item is skipped.

    Returns a new list; an empty or ``None`` history yields ``[]``.
    """
    messages = []
    for entry in history or []:
        if isinstance(entry, dict):
            messages.append(entry)
            continue
        if not isinstance(entry, (list, tuple)) or len(entry) < 2:
            continue
        # Empty halves (e.g. a greeting with no user turn) are dropped.
        for role, text in (("user", entry[0]), ("assistant", entry[1])):
            if text:
                messages.append({"role": role, "content": text})
    return messages
def convert_dict_to_gradio_format(dict_history):
    """Convert role/content message dicts back into Gradio ``[user, assistant]`` pairs.

    Pairing rules:

    - a user message opens a pair; if an earlier user message is still
      waiting for a reply, that one is emitted first as ``[user, ""]``;
    - an assistant message closes the open pair, or forms ``["", reply]``
      when no (truthy) user message is waiting;
    - a trailing unanswered user message is emitted as ``[user, ""]``;
    - messages with any other role, and non-dict items, are ignored.

    Returns a new list of two-element lists; empty or ``None`` input yields ``[]``.
    """
    if not dict_history:
        return []
    pairs = []
    waiting_user = ""  # user text still awaiting an assistant reply
    for message in dict_history:
        if not isinstance(message, dict):
            continue
        role = message.get("role", "")
        text = message.get("content", "")
        if role == "user":
            if waiting_user:
                pairs.append([waiting_user, ""])
            waiting_user = text
        elif role == "assistant":
            # Falsy user content (including None) is shown as an empty string.
            pairs.append([waiting_user or "", text])
            waiting_user = ""
    if waiting_user:
        pairs.append([waiting_user, ""])
    return pairs
# --- Constants ---
# Maximum number of most-recent history messages included in an LLM request.
MAX_HISTORY = 20

# --- Orchestrator state ---
# Module-level (process-wide) cache of orchestrator instances keyed by agent
# name; filled lazily when an orchestrator agent first answers.
orchestrators: Dict[str, object] = dict()
# --- Build log utility ---
def build_log(conn):
    """Format the tool invocations recorded on ``conn`` as a numbered log.

    Reads ``conn.invocations`` when that attribute exists. Returns ``''`` if
    it is missing or empty; otherwise a header line followed by one
    ``"<n>. <invocation>"`` line per entry, numbered from 1, each ending in a
    newline.
    """
    entries = getattr(conn, 'invocations', None)
    if not entries:
        return ''
    numbered = [f"{idx}. {entry}\n" for idx, entry in enumerate(entries, start=1)]
    return '--- Tool Invocation Log ---\n' + ''.join(numbered)
# --- Streaming to agent (for child agents) ---
def _stream_to_agent(cfg, history, user_input, debug_flag, active_children):
    """Stream a child agent's reply as a synchronous generator of UI updates.

    Builds an ``AgentLLMConnector`` from the child agent config ``cfg``
    (agent_name, agent_mission, agent_type, api_key, web_access, skills,
    allow_fallback, trusted_links, grounded_files) and streams its reply to
    the conversation in ``history``.

    The connector's async token stream runs on a private event loop owned by
    one dedicated worker thread. ``loop.run_until_complete`` refuses to run in
    a thread whose event loop is already running, so keeping the private loop
    off the caller's thread lets this generator be consumed from sync code and
    from a worker thread alike. If consumed directly inside a running event
    loop it works but blocks that loop while waiting for tokens. The private
    loop and stream are always closed, including when the consumer stops early.

    Args:
        cfg: Child agent configuration dict.
        history: Conversation so far, in Gradio pair format or role/content
            dict format; the current user message is expected to already be in
            it (``user_input`` is not appended here).
        user_input: Current user message, used to instantiate the tools.
        debug_flag: Unused here; kept for signature compatibility.
        active_children: Echoed back unchanged in every yielded tuple.

    Yields:
        ``(gradio_history, "", invocation_log, active_children, None)`` once per
        streamed token, then once more with the whitespace-stripped reply.
    """
    import asyncio
    import concurrent.futures
    import sys
    from collections import deque

    # A new list; only the assistant placeholder appended below is mutated.
    dict_history = convert_gradio_to_dict_format(history)

    skill_objs = []
    if cfg.get("web_access", False):
        web_tool = get_tool_by_name("search_internet", {"user_query": user_input})
        if web_tool:
            skill_objs.append(web_tool)
    for skill_name in cfg.get("skills", []):
        tool = get_tool_by_name(skill_name, {"user_query": user_input})
        if tool:
            skill_objs.append(tool)

    # Reuse the app-wide RAG retriever when the main ``app`` module exposes one.
    rag_retriever = None
    try:
        if 'app' in sys.modules:
            rag_retriever = getattr(sys.modules['app'], 'rag_retriever', None)
    except Exception:  # RAG is optional; never fail the chat over it
        rag_retriever = None

    conn = AgentLLMConnector(
        api_key=cast(str, cfg.get("api_key")),
        skills=skill_objs,
        allow_fallback=cfg.get("allow_fallback", True),
        trusted_links=cfg.get("trusted_links", []),
        grounded_files=cfg.get("grounded_files", []),
        rag_retriever=rag_retriever,
    )
    model = conn.agent_model_mapping.get(cfg.get("agent_type", ""), "gpt-5-mini")

    # Build the system message with general and tool-specific guidance.
    system_content = f"You are {cfg.get('agent_name', 'Agent')}. {cfg.get('agent_mission', '')}"
    system_content += (
        "\n\nRESPONSE GUIDELINES:\n"
        "- Be concise and clinically focused\n"
        "- Provide clear, actionable recommendations\n"
        "- Avoid excessive reasoning or explanation unless specifically requested\n"
        "- Structure responses with clear sections when appropriate\n"
        "- Use bullet points or numbered lists for multiple recommendations"
    )
    if any(skill.name == "IPC_reporting" for skill in skill_objs):
        system_content += (
            "\n\nWhen users ask about reportable diseases, reporting requirements, or infection control reporting, "
            "always offer to help with the specific reporting process. Use the IPC_reporting tool to provide "
            "jurisdiction-specific requirements and generate formatted reports. "
            "IMPORTANT: When calling IPC_reporting, include ALL conversation context in the case_summary parameter, "
            "especially the specific organism/pathogen mentioned by the user (e.g., 'User asked about typhus fever reporting')."
        )
    system_msg = {"role": "system", "content": system_content}

    # Sliding window: only the last MAX_HISTORY messages reach the model.
    recent = deque(dict_history, maxlen=MAX_HISTORY)
    history_msgs = [
        {"role": m["role"], "content": m["content"]}
        for m in recent
        if m["role"] in ("user", "assistant")
    ]
    messages = [system_msg] + history_msgs

    dict_history.append({"role": "assistant", "content": ""})

    async def run_stream():
        async for token in conn.chat_with_agent_stream(model_name=model, messages=messages):
            yield token

    loop = asyncio.new_event_loop()
    token_stream = run_stream()
    buf = ""
    with concurrent.futures.ThreadPoolExecutor(max_workers=1) as worker:

        def run_on_loop(awaitable):
            # Always the same worker thread, so the loop never changes threads.
            return worker.submit(loop.run_until_complete, awaitable).result()

        try:
            while True:
                try:
                    token = run_on_loop(token_stream.__anext__())
                except StopAsyncIteration:
                    break
                buf += token
                dict_history[-1]["content"] = buf
                invocation_log = build_log(conn)
                yield convert_dict_to_gradio_format(dict_history), "", invocation_log, active_children, None
        finally:
            # Runs on exhaustion, on error, and when the consumer closes this
            # generator early: finalize the stream, then release the loop.
            try:
                run_on_loop(token_stream.aclose())
                run_on_loop(loop.shutdown_asyncgens())
            finally:
                worker.submit(loop.close).result()

    dict_history[-1]["content"] = buf.strip()
    invocation_log = build_log(conn)
    yield convert_dict_to_gradio_format(dict_history), "", invocation_log, active_children, None
# --- Main async chat orchestrator ---

# Fingerprints of the (agents config, API key) that each entry in
# ``orchestrators`` was built from, keyed by the same agent name. Only
# recorded for orchestrators built from a per-user ``user_agents_config``.
_orchestrator_signatures: Dict[str, str] = {}


def _orchestrator_signature(agents_config, api_key):
    """Return a hex digest identifying the inputs an orchestrator is built from."""
    import hashlib

    try:
        # default=str: values JSON cannot encode are compared by their str form.
        payload = json.dumps([agents_config, api_key], sort_keys=True, default=str)
    except (TypeError, ValueError):
        # e.g. dict keys that cannot be sorted together, or circular references
        payload = repr((agents_config, api_key))
    # Hashed so API keys are not held in plain text in yet another structure.
    return hashlib.sha256(payload.encode("utf-8", errors="backslashreplace")).hexdigest()


def _get_or_create_orchestrator(name, cfg, user_agents_config):
    """Return the cached OrchestratorAgent for ``name``, creating it when needed.

    With a per-user ``user_agents_config``, a cached instance is reused only if
    it was built from an identical config and API key. Otherwise a new instance
    replaces it, so an orchestrator never runs with subagents or an API key
    taken from a different configuration (e.g. another user's). Without one
    (legacy path), any cached instance is reused; a new one is built from the
    ``agents_config`` of the loaded ``app`` module, or ``config.agents_config``.

    Raises:
        ImportError: if the orchestrator module or the legacy config is missing.
    """
    from core.agents.orchestrator import OrchestratorAgent

    api_key = cast(str, cfg.get("api_key", ""))
    cached = orchestrators.get(name)

    if user_agents_config is not None:
        signature = _orchestrator_signature(user_agents_config, api_key)
        if cached is not None and _orchestrator_signatures.get(name) == signature:
            return cached
        orch = OrchestratorAgent(user_agents_config, api_key)
        orchestrators[name] = orch
        _orchestrator_signatures[name] = signature
        return orch

    if cached is not None:
        return cached

    # Legacy fallback: the global agents config from app.py or config.
    import sys
    if 'app' in sys.modules:
        runtime_agents_config = getattr(sys.modules['app'], 'agents_config', {})
    else:
        from config import agents_config as runtime_agents_config
    orch = OrchestratorAgent(runtime_agents_config, api_key)
    orchestrators[name] = orch
    # Built from the global config rather than a per-user one.
    _orchestrator_signatures.pop(name, None)
    return orch


async def _aiter_in_worker_thread(sync_iterable):
    """Asynchronously re-yield the items of a blocking (sync) iterable.

    Every ``next()`` runs on one dedicated worker thread, so the event loop
    stays responsive while the iterator blocks. This is what allows
    ``_stream_to_agent`` (which drives its own private event loop) to be
    consumed from this async generator: ``run_until_complete`` raises
    RuntimeError in a thread whose event loop is already running.
    """
    import asyncio
    import concurrent.futures

    iterator = iter(sync_iterable)
    loop = asyncio.get_running_loop()
    # Sentinel for exhaustion: asyncio futures cannot carry StopIteration.
    exhausted = object()
    worker = concurrent.futures.ThreadPoolExecutor(max_workers=1)
    try:
        while True:
            item = await loop.run_in_executor(worker, next, iterator, exhausted)
            if item is exhausted:
                break
            yield item
    finally:
        close = getattr(iterator, "close", None)
        if close is not None:
            # Queued behind any in-flight next() on the same single thread.
            worker.submit(close)
        worker.shutdown(wait=False)


async def simulate_agent_response_stream(agent_json, history, user_input, debug_flag, active_children, user_agents_config=None):
    """Stream an agent's reply to ``user_input`` as successive UI updates.

    ``agent_json`` is a JSON-encoded agent configuration. Then:

    * orchestrator agents (``agent_type == "🎼 Orchestrator"``) delegate to a
      cached ``OrchestratorAgent`` whose subagents come from
      ``user_agents_config`` (or the legacy global config when it is ``None``);
    * if ``active_children`` is non-empty, each child config (a JSON string)
      streams its reply through ``_stream_to_agent``;
    * otherwise the agent answers directly through ``AgentLLMConnector`` using
      the last ``MAX_HISTORY`` messages. The reply is passed through
      ``validate_and_reference_recommendation``, followed by an optional
      challenger critique/refinement pass (``challenger_enabled`` in the
      config, or ``FORCE_CHALLENGE`` in the input/reply).

    Args:
        agent_json: JSON-encoded agent configuration.
        history: Conversation so far (Gradio pairs or role/content dicts).
        user_input: The new user message.
        debug_flag: Forwarded to the orchestrator (``debug=``) and child agents.
        active_children: Child agent JSON configs; echoed in every update.
        user_agents_config: Dict of the user's agent configurations, used as
            the orchestrator's subagents.

    Yields:
        ``(gradio_history, "", invocation_log, active_children, challenger_info)``.
        ``challenger_info`` is ``None`` except on the final update of a direct
        answer, where it is a dict with ``original_reply``,
        ``challenger_critique`` and ``final_reply``.
    """
    if not agent_json or not agent_json.strip():
        dict_history = convert_gradio_to_dict_format(history)
        dict_history.append({"role": "assistant", "content": "⚠️ No agent configuration found. Please Generate or Load an agent first."})
        yield convert_dict_to_gradio_format(dict_history), "", "", active_children, None
        return
    try:
        cfg = json.loads(agent_json)
    except json.JSONDecodeError:
        dict_history = convert_gradio_to_dict_format(history)
        dict_history.append({"role": "assistant", "content": "⚠️ Invalid agent configuration. Please regenerate or reload the agent."})
        yield convert_dict_to_gradio_format(dict_history), "", "", active_children, None
        return

    # Work internally in role/content dict format.
    dict_history = convert_gradio_to_dict_format(history)
    name = cfg.get("agent_name", "Agent")
    mission = cfg.get("agent_mission", "")
    agent_type = cfg.get("agent_type", "")

    if not dict_history:
        dict_history.append({"role": "assistant", "content": f"👋 Hello! I'm {name}. How can I assist today?"})
        yield convert_dict_to_gradio_format(dict_history), "", "", active_children, None

    # Show the user's message immediately.
    dict_history.append({"role": "user", "content": user_input})
    yield convert_dict_to_gradio_format(dict_history), "", "", active_children, None

    if agent_type == "🎼 Orchestrator":
        try:
            from core.agents.orchestrator import OrchestratorAgent

            # Instances are cached to keep orchestration state across turns.
            orch = _get_or_create_orchestrator(cfg.get("agent_name", "orchestrator"), cfg, user_agents_config)
            dict_history.append({"role": "assistant", "content": ""})
            orch_agent = cast(OrchestratorAgent, orch)
            async for msg in orch_agent.answer(dict_history, user_input, debug=debug_flag):
                if isinstance(msg, dict):
                    chunk = msg.get("content", "")
                    if chunk:
                        dict_history[-1]["content"] += chunk
                else:
                    dict_history[-1]["content"] += str(msg)
                invocation_log = build_log(orch) if hasattr(orch, 'invocations') and isinstance(orch, OrchestratorAgent) else ""
                yield convert_dict_to_gradio_format(dict_history), "", invocation_log, active_children, None
        except ImportError:
            dict_history.append({"role": "assistant", "content": "Orchestrator not available."})
            yield convert_dict_to_gradio_format(dict_history), "", "", active_children, None
        return

    if active_children:
        for child_json in active_children:
            child_cfg = json.loads(child_json)
            child_stream = _stream_to_agent(child_cfg, dict_history, user_input, debug_flag, active_children)
            # _stream_to_agent is blocking; never iterate it on the loop thread.
            async for output in _aiter_in_worker_thread(child_stream):
                yield output
        return

    skill_objs = []
    if cfg.get("web_access", False):
        web_tool = get_tool_by_name("search_internet", {"user_query": user_input})
        if web_tool:
            skill_objs.append(web_tool)
    for skill_name in cfg.get("skills", []):
        tool = get_tool_by_name(skill_name, {"user_query": user_input})
        if tool:
            skill_objs.append(tool)

    # Reuse the app-wide RAG retriever when the main ``app`` module exposes one.
    rag_retriever = None
    try:
        import sys
        if 'app' in sys.modules:
            rag_retriever = getattr(sys.modules['app'], 'rag_retriever', None)
    except Exception:  # RAG is optional; never fail the chat over it
        rag_retriever = None

    conn = AgentLLMConnector(
        api_key=cast(str, cfg.get("api_key", "")),
        skills=skill_objs,
        allow_fallback=cfg.get("allow_fallback", True),
        trusted_links=cfg.get("trusted_links", []),
        grounded_files=cfg.get("grounded_files", []),
        rag_retriever=rag_retriever,
    )
    model = conn.agent_model_mapping.get(agent_type, "gpt-5-mini")

    has_history_tool = any(t.name == "history_taking" for t in skill_objs)
    has_ipc_reporting = any(t.name == "IPC_reporting" for t in skill_objs)
    if agent_type == "🏥 Clinical Assistant" and has_history_tool:
        system_content = (
            f"You are {name}. {mission}\n\n"
            "Before giving any advice, gather all necessary patient history by calling the "
            "`history_taking` function. A JSON-schema has been provided with each question as a "
            "parameter description. Ask questions, wait for the user's answer, and only "
            "once every required field is filled will you then provide your final recommendation.\n\n"
            "RESPONSE FORMAT: Keep responses concise and clinical. Avoid lengthy explanations unless specifically asked. "
            "Focus on actionable recommendations and key clinical points."
        )
    else:
        system_content = f"You are {name}." + (f" {mission}" if mission else "")
    # General guidance to keep responses focused and concise.
    system_content += (
        "\n\nRESPONSE GUIDELINES:\n"
        "- Be concise and clinically focused\n"
        "- Provide clear, actionable recommendations\n"
        "- Avoid excessive reasoning or explanation unless specifically requested\n"
        "- Structure responses with clear sections when appropriate\n"
        "- Use bullet points or numbered lists for multiple recommendations"
    )
    if has_ipc_reporting:
        system_content += (
            "\n\nWhen users ask about reportable diseases, reporting requirements, or infection control reporting, "
            "always offer to help with the specific reporting process. Use the IPC_reporting tool to provide "
            "jurisdiction-specific requirements and generate formatted reports. "
            "IMPORTANT: When calling IPC_reporting, include ALL conversation context in the case_summary parameter, "
            "especially the specific organism/pathogen mentioned by the user (e.g., 'User asked about typhus fever reporting')."
        )
    system_msg = {"role": "system", "content": system_content}

    # Sliding window: only the last MAX_HISTORY messages reach the model.
    from collections import deque
    recent = deque(dict_history, maxlen=MAX_HISTORY)
    history_msgs = [
        {"role": m["role"], "content": m["content"]}
        for m in recent
        if m["role"] in ("user", "assistant")
    ]
    messages = [system_msg] + history_msgs

    dict_history.append({"role": "assistant", "content": ""})
    buf = ""
    async for token in conn.chat_with_agent_stream(model_name=model, messages=messages):
        buf += token
        dict_history[-1]["content"] = buf
        invocation_log = build_log(conn)
        yield convert_dict_to_gradio_format(dict_history), "", invocation_log, active_children, None

    # Apply clinical validation and references to the final reply.
    original_reply = validate_and_reference_recommendation(buf.strip())

    # --- Challenger step: adversarial critique when enabled or forced ---
    critique = None
    final_reply = original_reply
    force_challenge = "FORCE_CHALLENGE" in user_input or "FORCE_CHALLENGE" in original_reply
    if cfg.get("challenger_enabled", False) or force_challenge:
        try:
            from core.utils.llm_connector import challenge_agent_response, refine_final_answer
            # The history gives the challenger conversational context.
            critique = await challenge_agent_response(user_input, original_reply, dict_history)
            if not critique or critique.strip().upper() == "OK":
                critique = "OK"
                final_reply = original_reply
            else:
                # A refiner LLM produces the clean, user-facing answer.
                final_reply = await refine_final_answer(user_input, original_reply, critique)
        except Exception as e:
            critique = f"[Challenger error: {e}]"
            final_reply = original_reply

    dict_history[-1]["content"] = final_reply
    invocation_log = build_log(conn)
    yield convert_dict_to_gradio_format(dict_history), "", invocation_log, active_children, {
        "original_reply": original_reply,
        "challenger_critique": critique,
        "final_reply": final_reply,
    }