# GAIA Agent Evaluator — Hugging Face Space app
# (header reconstructed from a Spaces-page scrape artifact)
| import os | |
| import gradio as gr | |
| import requests | |
| import pandas as pd | |
| import time | |
| import re | |
| from smolagents import LiteLLMModel, CodeAgent, Tool | |
| # --- Constants --- | |
| DEFAULT_API_URL = "https://agents-course-unit4-scoring.hf.space" | |
| # --- Answer Extraction Function --- | |
def extract_answer(text: str, original_question: str) -> str:
    """Extract the answer from the LLM response, being robust to various formats.

    Args:
        text: Raw model output.
        original_question: The question the model was asked; used both to
            detect echo-back responses and to choose an extraction strategy.

    Returns:
        The best-effort answer string, or "- unknown" when nothing usable
        can be recovered.
    """
    if not text:
        return "- unknown"
    cleaned = text.strip()
    question = original_question.strip()
    # If the response merely echoes the question, it's not an answer.
    if cleaned == question:
        return "- unknown"
    # Remove common boilerplate prefixes (checked in order; each prefix is
    # stripped at most once).
    prefixes_to_remove = [
        '[ANSWER]:',
        '[ANSWER]',
        'Final answer:',
        'Final Answer:',
        'Answer:',
        'answer:',
        'The answer is',
        'The final answer is',
    ]
    for prefix in prefixes_to_remove:
        if cleaned.startswith(prefix):
            cleaned = cleaned[len(prefix):].strip()
    # Counting questions: return the first number found, preferring numbers
    # that sit in answer-like positions.
    if 'how many' in original_question.lower():
        number_patterns = [
            r'(?:is|are)\s*(\d+)',  # "the answer is 7"
            r'^\s*(\d+)',           # response starts with the number
            r'\D(\d+)\D*$',         # last number in the response
            r'(\d+)',               # any number at all
        ]
        for pattern in number_patterns:
            numbers = re.findall(pattern, cleaned)
            if numbers:
                return numbers[0]
        # No digits anywhere; fall through to the generic handling below.
    # Year questions: if the question mentions a year, return the first
    # year-like token in the response.
    # BUGFIX: the alternation must be non-capturing — with a capturing
    # group, re.findall() returns only the group ("19"/"20") instead of
    # the full 4-digit year.
    if re.search(r'\b(?:19|20)\d{2}\b', original_question):
        years = re.findall(r'\b(?:19|20)\d{2}\b', cleaned)
        if years:
            return years[0]  # first year found
    # If the response still contains the full question, keep only what
    # follows it; otherwise scan for a short line that looks like an answer.
    if question in cleaned:
        parts = cleaned.split(question)
        if len(parts) > 1 and parts[1].strip():
            cleaned = parts[1].strip()
        else:
            for line in cleaned.split('\n'):
                line = line.strip()
                if line and line != question:
                    # A short line is likely the answer.
                    if len(line) < 100 or 'how many' in original_question.lower():
                        cleaned = line
                        break
    # Very long responses that embed the question: prefer a short,
    # simple-looking (alphanumeric) line as the answer.
    if len(cleaned) > 200 and question in cleaned:
        for line in cleaned.split('\n'):
            line = line.strip()
            if line and len(line) < 100 and line != question:
                if re.match(r'^[\w\s\d\-\.,]+$', line):  # simple alphanumeric answer
                    return line
    # Still very long: fall back to the last short non-empty line.
    if len(cleaned) > 200:
        for line in reversed(cleaned.split('\n')):
            line = line.strip()
            if line and len(line) < 100:
                cleaned = line
                break
    # Final guard: never return the question itself.
    if cleaned == question:
        return "- unknown"
    return cleaned if cleaned else "- unknown"
| # --- Agent Tools --- | |
class MathSolver(Tool):
    """Tool that evaluates simple arithmetic expressions in a restricted namespace."""
    name = "math_solver"
    description = "Safely evaluate basic math expressions."
    inputs = {"input": {"type": "string", "description": "Math expression to evaluate."}}
    output_type = "string"

    def forward(self, input: str) -> str:
        """Evaluate *input* as a math expression; return the result or an error string."""
        try:
            # BUGFIX: the original iterated __builtins__.items(), which
            # raises AttributeError when this file is imported as a module
            # (there __builtins__ is a module, not a dict). Use an explicit
            # whitelist instead.
            allowed_names = {
                'abs': abs, 'round': round, 'min': min, 'max': max,
                'sum': sum, 'pow': pow,
                'int': int, 'float': float, 'str': str,
                # Block access to the real builtins inside eval().
                '__builtins__': {},
            }
            # NOTE(review): eval() is still not safe against truly hostile
            # input; expressions here come from the agent's own tool calls.
            return str(eval(input, allowed_names))
        except Exception as e:
            return f"Math error: {e}"
class FileAttachmentQueryTool(Tool):
    """Tool that downloads a task's attached file and returns a text excerpt of it."""
    name = "run_query_with_file"
    description = "Downloads a file mentioned in a user prompt, adds it to the context, and runs a query on it."
    inputs = {
        "task_id": {
            "type": "string",
            "description": "A unique identifier for the task related to this file, used to download it.",
            "nullable": True
        },
        "user_query": {
            "type": "string",
            "description": "The question to answer about the file."
        }
    }
    output_type = "string"

    def forward(self, task_id: str | None, user_query: str) -> str:
        """Fetch the attachment for *task_id*; return its (truncated) text content."""
        if not task_id:
            return "No task_id provided for file download."
        # Reuse the module-wide scoring endpoint instead of a second
        # hard-coded copy of the same URL.
        file_url = f"{DEFAULT_API_URL}/files/{task_id}"
        try:
            # Timeout added so a stalled download cannot hang the whole run.
            file_response = requests.get(file_url, timeout=30)
            if file_response.status_code != 200:
                return f"Failed to download file: {file_response.status_code}"
            # For text-based files, return content directly, truncated to
            # keep the agent context small. Binary files decode lossily here.
            file_content = file_response.text[:2000]  # Limit content size
            return f"Relevant information from file: {file_content}"
        except Exception as e:
            return f"File download error: {e}"
class WikipediaSearchTool(Tool):
    """Tool that fetches a Wikipedia page summary via the REST API."""
    name = "wikipedia_search"
    description = "Search Wikipedia for information relevant to the user's query."
    inputs = {"query": {"type": "string", "description": "The search query."}}
    output_type = "string"

    def forward(self, query: str) -> str:
        """Return the Wikipedia summary extract for *query*, or an error string."""
        try:
            search_url = f"https://en.wikipedia.org/api/rest_v1/page/summary/{requests.utils.quote(query)}"
            # Timeout added so a slow/unreachable API cannot hang the agent.
            response = requests.get(search_url, timeout=30)
            if response.status_code == 200:
                data = response.json()
                return data.get("extract", "No summary available.")
            else:
                return f"Wikipedia search error: {response.status_code}"
        except Exception as e:
            return f"Wikipedia search exception: {e}"
| # --- Agent Implementation --- | |
def select_model(provider="groq"):
    """Select and return a LiteLLM model for the given provider.

    Args:
        provider: "groq", "hf", or anything else (falls back to Groq).

    Returns:
        A configured LiteLLMModel instance.

    Raises:
        ValueError: if the required API-key environment variable is unset.
    """
    GROQ_MODEL_NAME = "groq/llama3-70b-8192"
    HF_MODEL_NAME = "huggingfaceh4/zephyr-7b-beta"
    if provider == "hf":
        api_key = os.getenv("HF_TOKEN")
        if not api_key:
            raise ValueError("HF_TOKEN environment variable is not set")
        return LiteLLMModel(model_id=HF_MODEL_NAME, api_key=api_key)
    # "groq" and the default branch both need a Groq key.
    api_key = os.getenv("GROQ_API_KEY")
    if not api_key:
        raise ValueError("GROQ_API_KEY environment variable is not set")
    if provider == "groq":
        # NOTE(review): the explicit "groq" branch intentionally uses the
        # faster llama-3.1-8b-instant model, not GROQ_MODEL_NAME — the
        # original's unreachable fallback-to-GROQ_MODEL_NAME dead code has
        # been removed; behavior is unchanged.
        return LiteLLMModel(model_id="groq/llama-3.1-8b-instant", api_key=api_key)
    # Unknown provider: default to the larger Groq model.
    return LiteLLMModel(model_id=GROQ_MODEL_NAME, api_key=api_key)
class BasicAgent:
    """Wraps a smolagents CodeAgent with GAIA-oriented tools and rate-limit retries."""

    def __init__(self, provider="groq"):
        # Assemble the toolset and LLM backend for the chosen provider.
        toolset = [
            MathSolver(),
            FileAttachmentQueryTool(),
            WikipediaSearchTool(),
        ]
        self.agent = CodeAgent(
            model=select_model(provider),
            tools=toolset,
            add_base_tools=False,
            max_steps=15,
        )
        # Override the system prompt so the model emits only the bare answer
        # in GAIA's expected exact-match format.
        self.agent.prompt_templates["system_prompt"] = (
            "You are a GAIA benchmark AI assistant. Your sole purpose is to output the minimal, final answer. "
            "You must NEVER output explanations, intermediate steps, reasoning, or comments β only the answer. "
            "For numerical answers, use digits only, e.g., `4` not `four`. "
            "For string answers, omit articles ('a', 'the') and use full words. "
            "For lists, output in comma-separated format with no conjunctions. "
            "If the answer is not found, say `- unknown`."
            "IMPORTANT: Respond with ONLY the answer, nothing else. No prefixes, no explanations."
        )

    def __call__(self, question: str) -> str:
        """Run the agent on *question*, retrying with exponential backoff on rate limits."""
        max_retries = 3
        retry_delay = 10  # seconds; doubled after every rate-limited attempt
        attempt = 0
        while attempt < max_retries:
            try:
                raw = self.agent.run(question)
                # Normalize the raw model output into GAIA answer format.
                return extract_answer(str(raw), question)
            except Exception as e:
                message = str(e)
                rate_limited = (
                    "RateLimitError" in message or "rate_limit_exceeded" in message
                )
                if not rate_limited:
                    # Anything other than a rate limit propagates to the caller.
                    raise
                if attempt == max_retries - 1:
                    return f"Rate limit error after {max_retries} attempts: {e}"
                print(f"Rate limit hit. Waiting {retry_delay} seconds before retry {attempt + 1}/{max_retries}")
                time.sleep(retry_delay)
                retry_delay *= 2  # Exponential backoff
            attempt += 1
        return f"Failed to get response after {max_retries} attempts"
| # --- Main Application Functions --- | |
def run_and_submit_all(profile: gr.OAuthProfile | None):
    """
    Fetches all questions, runs the BasicAgent on them, submits all answers,
    and displays the results.

    Args:
        profile: The logged-in Hugging Face OAuth profile (None if not logged in).

    Returns:
        A (status_message, results_dataframe) tuple for the Gradio outputs.
    """
    # --- Determine HF Space Runtime URL and Repo URL ---
    space_id = os.getenv("SPACE_ID")
    if profile:
        username = f"{profile.username}"
        print(f"User logged in: {username}")
    else:
        print("User not logged in.")
        return "Please Login to Hugging Face with the button.", None
    api_url = DEFAULT_API_URL
    questions_url = f"{api_url}/questions"
    submit_url = f"{api_url}/submit"
    # 1. Instantiate Agent
    try:
        agent = BasicAgent()
    except Exception as e:
        print(f"Error instantiating agent: {e}")
        return f"Error initializing agent: {e}", None
    # In the case of an app running as a hugging Face space, this link points toward your codebase
    # NOTE(review): if SPACE_ID is unset this produces ".../spaces/None/..."
    # — harmless for local runs, but worth confirming for submission.
    agent_code = f"https://huggingface.co/spaces/{space_id}/tree/main"
    print(agent_code)
    # 2. Fetch Questions
    print(f"Fetching questions from: {questions_url}")
    try:
        response = requests.get(questions_url, timeout=30)
        response.raise_for_status()
        questions_data = response.json()
        if not questions_data:
            print("Fetched questions list is empty.")
            return "Fetched questions list is empty or invalid format.", None
        print(f"Fetched {len(questions_data)} questions.")
    except requests.exceptions.RequestException as e:
        print(f"Error fetching questions: {e}")
        return f"Error fetching questions: {e}", None
    except requests.exceptions.JSONDecodeError as e:
        print(f"Error decoding JSON response from questions endpoint: {e}")
        print(f"Response text: {response.text[:500]}")
        return f"Error decoding server response for questions: {e}", None
    except Exception as e:
        print(f"An unexpected error occurred fetching questions: {e}")
        return f"An unexpected error occurred fetching questions: {e}", None
    # 3. Run your Agent
    results_log = []
    answers_payload = []
    print(f"Running agent on {len(questions_data)} questions...")
    # Progress tracking
    progress_count = 0
    total_questions = len(questions_data)
    for item in questions_data:
        task_id = item.get("task_id")
        question_text = item.get("question")
        if not task_id or question_text is None:
            print(f"Skipping item with missing task_id or question: {item}")
            continue
        # Update progress
        progress_count += 1
        print(f"Processing question {progress_count}/{total_questions}")
        try:
            submitted_answer = agent(question_text)
            answers_payload.append({"task_id": task_id, "submitted_answer": submitted_answer})
            results_log.append({"Task ID": task_id, "Question": question_text, "Submitted Answer": submitted_answer})
            # BUGFIX: a stray `break` here previously stopped the loop after
            # the FIRST question, so only one answer was ever submitted and
            # the inter-question delay below was unreachable.
            # Add a small delay between questions to help with rate limiting
            if progress_count < total_questions:  # Don't delay after the last question
                time.sleep(2)
        except Exception as e:
            print(f"Error running agent on task {task_id}: {e}")
            results_log.append({"Task ID": task_id, "Question": question_text, "Submitted Answer": f"AGENT ERROR: {e}"})
    if not answers_payload:
        print("Agent did not produce any answers to submit.")
        return "Agent did not produce any answers to submit.", pd.DataFrame(results_log)
    # 4. Prepare Submission
    submission_data = {"username": username.strip(), "agent_code": agent_code, "answers": answers_payload}
    status_update = f"Agent finished. Submitting {len(answers_payload)} answers for user '{username}'..."
    print(status_update)
    # 5. Submit
    print(f"Submitting {len(answers_payload)} answers to: {submit_url}")
    try:
        response = requests.post(submit_url, json=submission_data, timeout=120)
        response.raise_for_status()
        result_data = response.json()
        final_status = (
            f"β Submission Successful!\n"
            f"User: {result_data.get('username')}\n"
            f"Overall Score: {result_data.get('score', 'N/A')}% "
            f"({result_data.get('correct_count', '?')}/{result_data.get('total_attempted', '?')} correct)\n"
            f"Message: {result_data.get('message', 'No message received.')}"
        )
        print("Submission successful.")
        results_df = pd.DataFrame(results_log)
        return final_status, results_df
    except requests.exceptions.HTTPError as e:
        error_detail = f"Server responded with status {e.response.status_code}."
        try:
            error_json = e.response.json()
            error_detail += f" Detail: {error_json.get('detail', e.response.text)}"
        except requests.exceptions.JSONDecodeError:
            error_detail += f" Response: {e.response.text[:500]}"
        status_message = f"β Submission Failed: {error_detail}"
        print(status_message)
        results_df = pd.DataFrame(results_log)
        return status_message, results_df
    except requests.exceptions.Timeout:
        status_message = "β Submission Failed: The request timed out. Please try again."
        print(status_message)
        results_df = pd.DataFrame(results_log)
        return status_message, results_df
    except requests.exceptions.RequestException as e:
        status_message = f"β Submission Failed: Network error - {e}"
        print(status_message)
        results_df = pd.DataFrame(results_log)
        return status_message, results_df
    except Exception as e:
        status_message = f"β An unexpected error occurred during submission: {e}"
        print(status_message)
        results_df = pd.DataFrame(results_log)
        return status_message, results_df
def test_agent(question: str, provider: str):
    """Run a single question through a freshly-built agent and format the result."""
    try:
        answer = BasicAgent(provider=provider)(question)
    except Exception as e:
        return f"Error testing agent: {e}"
    return f"Question: {question}\nAnswer: {answer}"
| # --- Build Gradio Interface using Blocks --- | |
# Declarative Gradio UI: statement order defines the rendered layout,
# so this builder is documented in place rather than restructured.
with gr.Blocks(title="GAIA Agent Evaluator") as demo:
    # App header and usage notes.
    gr.Markdown("# π€ GAIA Agent Evaluator")
    gr.Markdown(
        """
        This interface allows you to evaluate your agent against the GAIA benchmark questions.
        **Instructions:**
        1. Log in to your Hugging Face account using the button below
        2. Click 'Run Evaluation & Submit All Answers' to fetch questions, run your agent, and submit answers
        3. View your results and score in the output panel
        **For Testing:**
        Use the test section below to verify your agent works correctly with sample questions.
        """
    )
    # Tab 1: full benchmark run + submission (requires HF login for the username).
    with gr.Tab("Evaluation"):
        gr.Markdown("## π Run Full Evaluation")
        gr.LoginButton()
        with gr.Row():
            run_button = gr.Button("Run Evaluation & Submit All Answers", variant="primary")
        status_output = gr.Textbox(label="π Status / Submission Result", lines=8, interactive=False)
        results_table = gr.DataFrame(label="π Questions and Agent Answers", wrap=True)
        # NOTE(review): no `inputs=` — presumably Gradio injects the OAuth
        # profile into run_and_submit_all's `profile` parameter when a
        # LoginButton is present; confirm against the Gradio OAuth docs.
        run_button.click(
            fn=run_and_submit_all,
            outputs=[status_output, results_table]
        )
    # Tab 2: single-question smoke test against a chosen model provider.
    with gr.Tab("Testing"):
        gr.Markdown("## π§ͺ Test Your Agent")
        with gr.Row():
            with gr.Column():
                test_question = gr.Textbox(
                    label="Question",
                    placeholder="Enter a test question...",
                    value="What is 2+2?"
                )
                provider_choice = gr.Radio(
                    choices=["groq", "hf"],
                    value="groq",
                    label="Provider"
                )
                test_button = gr.Button("Test Agent")
            with gr.Column():
                test_output = gr.Textbox(label="Agent Response", lines=10, interactive=False)
        test_button.click(
            fn=test_agent,
            inputs=[test_question, provider_choice],
            outputs=test_output
        )
if __name__ == "__main__":
    # Startup banner plus a quick report of the Space-related environment
    # variables — useful when debugging deployments.
    banner = "=" * 50
    print("\n" + banner)
    print("π GAIA Agent Evaluator Starting")
    print(banner)
    # Check for SPACE_HOST and SPACE_ID at startup for information
    host = os.getenv("SPACE_HOST")
    repo_id = os.getenv("SPACE_ID")
    if host:
        print(f"β SPACE_HOST found: {host}")
        print(f" Runtime URL: https://{host}.hf.space")
    else:
        print("βΉοΈ Running locally (SPACE_HOST not found)")
    if repo_id:
        print(f"β SPACE_ID found: {repo_id}")
        print(f" Repo URL: https://huggingface.co/spaces/{repo_id}")
    else:
        print("βΉοΈ SPACE_ID not found (Repo URL cannot be determined)")
    print(banner)
    print("Launching Gradio Interface...")
    demo.launch(debug=True, share=False)