import builtins
import os
import gradio as gr
import requests
import pandas as pd
import time
import re
from smolagents import LiteLLMModel, CodeAgent, Tool
# --- Constants ---
DEFAULT_API_URL = "https://agents-course-unit4-scoring.hf.space"
# --- Answer Extraction Function ---
def extract_answer(text: str, original_question: str) -> str:
"""Extract the answer from the LLM response, being robust to various formats."""
if not text:
return "- unknown"
# Clean the text
cleaned = text.strip()
# If the response is the same as the question, it's not an answer
if cleaned == original_question.strip():
return "- unknown"
# Remove common prefixes
prefixes_to_remove = [
'[ANSWER]:',
'[ANSWER]',
'Final answer:',
'Final Answer:',
'Answer:',
'answer:',
'The answer is',
'The final answer is',
]
for prefix in prefixes_to_remove:
if cleaned.startswith(prefix):
cleaned = cleaned[len(prefix):].strip()
# If it's a "how many" question, try to extract just the number
if 'how many' in original_question.lower():
# Look for numbers in the response
number_patterns = [
r'(?:is|are)\s*(\d+)',
r'^\s*(\d+)',
r'\D(\d+)\D*$',
r'(\d+)'
]
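        # Patterns run in order from most to least specific: a number after
        # "is"/"are", a leading number, a trailing number, then any number.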
for pattern in number_patterns:
numbers = re.findall(pattern, cleaned)
if numbers:
                return numbers[0]  # Return the first number found
    # If the question mentions a year, try to extract just the year.
    # The non-capturing group keeps findall returning the full year
    # rather than just the "19"/"20" prefix.
    if re.search(r'\b(?:19|20)\d{2}\b', original_question):
        years = re.findall(r'\b(?:19|20)\d{2}\b', cleaned)
        if years:
            return years[0]  # Return the first year found
# If we still have the full question in the response, try to extract what comes after it
if original_question.strip() in cleaned:
# Split by the question and take what comes after
parts = cleaned.split(original_question.strip())
if len(parts) > 1 and parts[1].strip():
cleaned = parts[1].strip()
else:
# Try to find numbers or short answers in the response
# Look for a line that might contain the answer
lines = cleaned.split('\n')
for line in lines:
line = line.strip()
if line and line != original_question.strip():
# If it's a short line, it might be the answer
if len(line) < 100 or 'how many' in original_question.lower():
cleaned = line
break
# If the cleaned answer is still very long and contains the question,
# try to extract just the essential part
if len(cleaned) > 200 and original_question.strip() in cleaned:
# Try to find a short line that might be the answer
lines = cleaned.split('\n')
for line in lines:
line = line.strip()
if line and len(line) < 100 and line != original_question.strip():
# Check if it looks like an answer (short and possibly numeric)
if re.match(r'^[\w\s\d\-\.,]+$', line): # Simple alphanumeric answer
return line
# If we still have a very long response, try to extract just the last line
# which might be the answer
if len(cleaned) > 200:
lines = cleaned.split('\n')
# Take the last non-empty line that isn't too long
for line in reversed(lines):
line = line.strip()
if line and len(line) < 100:
cleaned = line
break
# Final fallback - if the result is still the same as the question, return unknown
if cleaned == original_question.strip():
return "- unknown"
return cleaned if cleaned else "- unknown"
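# Quick sanity check (hypothetical inputs):
#   extract_answer("Final answer: 42", "How many studio albums were released?") -> "42"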
# --- Agent Tools ---
class MathSolver(Tool):
name = "math_solver"
description = "Safely evaluate basic math expressions."
inputs = {"input": {"type": "string", "description": "Math expression to evaluate."}}
output_type = "string"
def forward(self, input: str) -> str:
try:
            # Safe-ish evaluation: expose only a whitelist of builtins.
            # Go through the builtins module because __builtins__ is a module
            # in the main script but a dict in imported modules.
            allowed_names = {
                k: getattr(builtins, k) for k in (
                    'abs', 'round', 'min', 'max', 'sum', 'pow'
                )
            }
allowed_names.update({
'int': int, 'float': float, 'str': str,
'__builtins__': {}
})
return str(eval(input, allowed_names))
except Exception as e:
return f"Math error: {e}"
class FileAttachmentQueryTool(Tool):
name = "run_query_with_file"
description = "Downloads a file mentioned in a user prompt, adds it to the context, and runs a query on it."
inputs = {
"task_id": {
"type": "string",
"description": "A unique identifier for the task related to this file, used to download it.",
"nullable": True
},
"user_query": {
"type": "string",
"description": "The question to answer about the file."
}
}
output_type = "string"
def forward(self, task_id: str | None, user_query: str) -> str:
if not task_id:
return "No task_id provided for file download."
file_url = f"https://agents-course-unit4-scoring.hf.space/files/{task_id}"
try:
file_response = requests.get(file_url)
if file_response.status_code != 200:
return f"Failed to download file: {file_response.status_code}"
# For text-based files, return content directly
file_content = file_response.text[:2000] # Limit content size
return f"Relevant information from file: {file_content}"
except Exception as e:
return f"File download error: {e}"
class WikipediaSearchTool(Tool):
name = "wikipedia_search"
description = "Search Wikipedia for information relevant to the user's query."
inputs = {"query": {"type": "string", "description": "The search query."}}
output_type = "string"
def forward(self, query: str) -> str:
try:
search_url = f"https://en.wikipedia.org/api/rest_v1/page/summary/{requests.utils.quote(query)}"
            response = requests.get(search_url, timeout=30)
if response.status_code == 200:
data = response.json()
return data.get("extract", "No summary available.")
else:
return f"Wikipedia search error: {response.status_code}"
except Exception as e:
return f"Wikipedia search exception: {e}"
# --- Agent Implementation ---
def select_model(provider="groq"):
"""Select and return a model based on the provider."""
    GROQ_MODEL_NAME = "groq/llama-3.1-8b-instant"
    # LiteLLM routes Hugging Face models via the "huggingface/" provider prefix
    HF_MODEL_NAME = "huggingface/HuggingFaceH4/zephyr-7b-beta"
    if provider == "groq":
        api_key = os.getenv("GROQ_API_KEY")
        if not api_key:
            raise ValueError("GROQ_API_KEY environment variable is not set")
        return LiteLLMModel(model_id=GROQ_MODEL_NAME, api_key=api_key)
    elif provider == "hf":
        api_key = os.getenv("HF_TOKEN")
        if not api_key:
            raise ValueError("HF_TOKEN environment variable is not set")
        return LiteLLMModel(model_id=HF_MODEL_NAME, api_key=api_key)
    else:
        # Default to Groq if no valid provider is specified
        api_key = os.getenv("GROQ_API_KEY")
        if not api_key:
            raise ValueError("GROQ_API_KEY environment variable is not set")
        return LiteLLMModel(model_id=GROQ_MODEL_NAME, api_key=api_key)
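# Usage (assumes the relevant API key is set in the environment):
#   model = select_model("groq")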
class BasicAgent:
def __init__(self, provider="groq"):
model = select_model(provider)
tools = [
MathSolver(),
FileAttachmentQueryTool(),
WikipediaSearchTool()
]
self.agent = CodeAgent(
model=model,
tools=tools,
add_base_tools=False,
max_steps=15,
)
# System prompt to enforce exact answer format
self.agent.prompt_templates["system_prompt"] = (
"You are a GAIA benchmark AI assistant. Your sole purpose is to output the minimal, final answer. "
"You must NEVER output explanations, intermediate steps, reasoning, or comments β€” only the answer. "
"For numerical answers, use digits only, e.g., `4` not `four`. "
"For string answers, omit articles ('a', 'the') and use full words. "
"For lists, output in comma-separated format with no conjunctions. "
"If the answer is not found, say `- unknown`."
"IMPORTANT: Respond with ONLY the answer, nothing else. No prefixes, no explanations."
)
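        # Caution: replacing the template wholesale may drop the code-writing
        # and tool-usage instructions that smolagents' CodeAgent normally injects.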
def __call__(self, question: str) -> str:
max_retries = 3
retry_delay = 10 # Start with 10 seconds
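        # Exponential backoff: with max_retries=3 this waits 10s, then 20s,
        # before giving up on the third rate-limited attempt.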
for attempt in range(max_retries):
try:
result = self.agent.run(question)
# Use our enhanced extraction function
final_str = extract_answer(str(result), question)
return final_str
except Exception as e:
# Check if it's a rate limit error
if "RateLimitError" in str(e) or "rate_limit_exceeded" in str(e):
if attempt < max_retries - 1: # Not the last attempt
print(f"Rate limit hit. Waiting {retry_delay} seconds before retry {attempt + 1}/{max_retries}")
time.sleep(retry_delay)
retry_delay *= 2 # Exponential backoff
continue
else:
return f"Rate limit error after {max_retries} attempts: {e}"
else:
# Not a rate limit error, re-raise
raise e
return f"Failed to get response after {max_retries} attempts"
# --- Main Application Functions ---
def run_and_submit_all(profile: gr.OAuthProfile | None):
"""
Fetches all questions, runs the BasicAgent on them, submits all answers,
and displays the results.
"""
# --- Determine HF Space Runtime URL and Repo URL ---
space_id = os.getenv("SPACE_ID")
if profile:
username = f"{profile.username}"
print(f"User logged in: {username}")
else:
print("User not logged in.")
return "Please Login to Hugging Face with the button.", None
api_url = DEFAULT_API_URL
questions_url = f"{api_url}/questions"
submit_url = f"{api_url}/submit"
# 1. Instantiate Agent
try:
agent = BasicAgent()
except Exception as e:
print(f"Error instantiating agent: {e}")
return f"Error initializing agent: {e}", None
    # When the app runs as a Hugging Face Space, this link points to your codebase
agent_code = f"https://huggingface.co/spaces/{space_id}/tree/main"
print(agent_code)
# 2. Fetch Questions
print(f"Fetching questions from: {questions_url}")
try:
response = requests.get(questions_url, timeout=30)
response.raise_for_status()
questions_data = response.json()
if not questions_data:
print("Fetched questions list is empty.")
return "Fetched questions list is empty or invalid format.", None
print(f"Fetched {len(questions_data)} questions.")
except requests.exceptions.RequestException as e:
print(f"Error fetching questions: {e}")
return f"Error fetching questions: {e}", None
except requests.exceptions.JSONDecodeError as e:
print(f"Error decoding JSON response from questions endpoint: {e}")
print(f"Response text: {response.text[:500]}")
return f"Error decoding server response for questions: {e}", None
except Exception as e:
print(f"An unexpected error occurred fetching questions: {e}")
return f"An unexpected error occurred fetching questions: {e}", None
# 3. Run your Agent
results_log = []
answers_payload = []
print(f"Running agent on {len(questions_data)} questions...")
# Progress tracking
progress_count = 0
total_questions = len(questions_data)
for item in questions_data:
task_id = item.get("task_id")
question_text = item.get("question")
if not task_id or question_text is None:
print(f"Skipping item with missing task_id or question: {item}")
continue
# Update progress
progress_count += 1
print(f"Processing question {progress_count}/{total_questions}")
        try:
            submitted_answer = agent(question_text)
            answers_payload.append({"task_id": task_id, "submitted_answer": submitted_answer})
            results_log.append({"Task ID": task_id, "Question": question_text, "Submitted Answer": submitted_answer})
            # Add a small delay between questions to help with rate limiting
            if progress_count < total_questions:  # Don't delay after the last question
                time.sleep(2)
except Exception as e:
print(f"Error running agent on task {task_id}: {e}")
results_log.append({"Task ID": task_id, "Question": question_text, "Submitted Answer": f"AGENT ERROR: {e}"})
if not answers_payload:
print("Agent did not produce any answers to submit.")
return "Agent did not produce any answers to submit.", pd.DataFrame(results_log)
# 4. Prepare Submission
submission_data = {"username": username.strip(), "agent_code": agent_code, "answers": answers_payload}
status_update = f"Agent finished. Submitting {len(answers_payload)} answers for user '{username}'..."
print(status_update)
# 5. Submit
print(f"Submitting {len(answers_payload)} answers to: {submit_url}")
try:
response = requests.post(submit_url, json=submission_data, timeout=120)
response.raise_for_status()
result_data = response.json()
final_status = (
f"βœ… Submission Successful!\n"
f"User: {result_data.get('username')}\n"
f"Overall Score: {result_data.get('score', 'N/A')}% "
f"({result_data.get('correct_count', '?')}/{result_data.get('total_attempted', '?')} correct)\n"
f"Message: {result_data.get('message', 'No message received.')}"
)
print("Submission successful.")
results_df = pd.DataFrame(results_log)
return final_status, results_df
except requests.exceptions.HTTPError as e:
error_detail = f"Server responded with status {e.response.status_code}."
try:
error_json = e.response.json()
error_detail += f" Detail: {error_json.get('detail', e.response.text)}"
except requests.exceptions.JSONDecodeError:
error_detail += f" Response: {e.response.text[:500]}"
status_message = f"❌ Submission Failed: {error_detail}"
print(status_message)
results_df = pd.DataFrame(results_log)
return status_message, results_df
except requests.exceptions.Timeout:
status_message = "❌ Submission Failed: The request timed out. Please try again."
print(status_message)
results_df = pd.DataFrame(results_log)
return status_message, results_df
except requests.exceptions.RequestException as e:
status_message = f"❌ Submission Failed: Network error - {e}"
print(status_message)
results_df = pd.DataFrame(results_log)
return status_message, results_df
except Exception as e:
status_message = f"❌ An unexpected error occurred during submission: {e}"
print(status_message)
results_df = pd.DataFrame(results_log)
return status_message, results_df
def test_agent(question: str, provider: str):
"""Test the agent with a single question."""
try:
agent = BasicAgent(provider=provider)
answer = agent(question)
return f"Question: {question}\nAnswer: {answer}"
except Exception as e:
return f"Error testing agent: {e}"
# --- Build Gradio Interface using Blocks ---
with gr.Blocks(title="GAIA Agent Evaluator") as demo:
gr.Markdown("# πŸ€– GAIA Agent Evaluator")
gr.Markdown(
"""
This interface allows you to evaluate your agent against the GAIA benchmark questions.
**Instructions:**
1. Log in to your Hugging Face account using the button below
2. Click 'Run Evaluation & Submit All Answers' to fetch questions, run your agent, and submit answers
3. View your results and score in the output panel
**For Testing:**
Use the test section below to verify your agent works correctly with sample questions.
"""
)
with gr.Tab("Evaluation"):
gr.Markdown("## πŸš€ Run Full Evaluation")
gr.LoginButton()
with gr.Row():
run_button = gr.Button("Run Evaluation & Submit All Answers", variant="primary")
        status_output = gr.Textbox(label="📊 Status / Submission Result", lines=8, interactive=False)
        results_table = gr.DataFrame(label="📋 Questions and Agent Answers", wrap=True)
run_button.click(
fn=run_and_submit_all,
outputs=[status_output, results_table]
)
with gr.Tab("Testing"):
gr.Markdown("## πŸ§ͺ Test Your Agent")
with gr.Row():
with gr.Column():
test_question = gr.Textbox(
label="Question",
placeholder="Enter a test question...",
value="What is 2+2?"
)
provider_choice = gr.Radio(
choices=["groq", "hf"],
value="groq",
label="Provider"
)
test_button = gr.Button("Test Agent")
with gr.Column():
test_output = gr.Textbox(label="Agent Response", lines=10, interactive=False)
test_button.click(
fn=test_agent,
inputs=[test_question, provider_choice],
outputs=test_output
)
if __name__ == "__main__":
print("\n" + "="*50)
print("πŸš€ GAIA Agent Evaluator Starting")
print("="*50)
# Check for SPACE_HOST and SPACE_ID at startup for information
space_host_startup = os.getenv("SPACE_HOST")
space_id_startup = os.getenv("SPACE_ID")
if space_host_startup:
print(f"βœ… SPACE_HOST found: {space_host_startup}")
print(f" Runtime URL: https://{space_host_startup}.hf.space")
else:
print("ℹ️ Running locally (SPACE_HOST not found)")
if space_id_startup:
print(f"βœ… SPACE_ID found: {space_id_startup}")
print(f" Repo URL: https://huggingface.co/spaces/{space_id_startup}")
else:
print("ℹ️ SPACE_ID not found (Repo URL cannot be determined)")
print("="*50)
print("Launching Gradio Interface...")
demo.launch(debug=True, share=False)