olcapone committed on
Commit b3af0f9 · verified · 1 Parent(s): f24677e

Update app.py

Files changed (1)
  app.py +117 -2
app.py CHANGED
@@ -2,18 +2,133 @@ import os
 import gradio as gr
 import requests
 import pandas as pd
-from agent import BasicAgent
+from smolagents import LiteLLMModel, CodeAgent, Tool
 
 # --- Constants ---
 DEFAULT_API_URL = "https://agents-course-unit4-scoring.hf.space"
 
+# --- Agent Tools ---
+class MathSolver(Tool):
+    name = "math_solver"
+    description = "Safely evaluate basic math expressions."
+    inputs = {"input": {"type": "string", "description": "Math expression to evaluate."}}
+    output_type = "string"
+
+    def forward(self, input: str) -> str:
+        try:
+            # Safe evaluation of math expressions: expose only an explicit
+            # allow-list of builtins to eval (__builtins__ itself can be a
+            # module rather than a dict, so build the mapping by hand)
+            allowed_names = {
+                'abs': abs, 'round': round, 'min': min,
+                'max': max, 'sum': sum, 'pow': pow,
+            }
+            allowed_names.update({
+                'int': int, 'float': float, 'str': str,
+                '__builtins__': {}
+            })
+            return str(eval(input, allowed_names))
+        except Exception as e:
+            return f"Math error: {e}"
+
+class FileAttachmentQueryTool(Tool):
+    name = "run_query_with_file"
+    description = "Downloads a file mentioned in a user prompt, adds it to the context, and runs a query on it."
+    inputs = {
+        "task_id": {
+            "type": "string",
+            "description": "A unique identifier for the task related to this file, used to download it.",
+            "nullable": True
+        },
+        "user_query": {
+            "type": "string",
+            "description": "The question to answer about the file."
+        }
+    }
+    output_type = "string"
+
+    def forward(self, task_id: str | None, user_query: str) -> str:
+        if not task_id:
+            return "No task_id provided for file download."
+
+        file_url = f"{DEFAULT_API_URL}/files/{task_id}"
+        try:
+            file_response = requests.get(file_url)
+            if file_response.status_code != 200:
+                return f"Failed to download file: {file_response.status_code}"
+
+            # For text-based files, return content directly
+            file_content = file_response.text[:2000]  # Limit content size
+            return f"Relevant information from file: {file_content}"
+        except Exception as e:
+            return f"File download error: {e}"
+
+# --- Agent Implementation ---
+def select_model(provider="groq"):
+    """Select and return a model based on the provider."""
+    GROQ_MODEL_NAME = "groq/llama3-70b-8192"
+    HF_MODEL_NAME = "huggingfaceh4/zephyr-7b-beta"
+
+    if provider == "groq":
+        api_key = os.getenv("GROQ_API_KEY")
+        if not api_key:
+            raise ValueError("GROQ_API_KEY environment variable is not set")
+        return LiteLLMModel(model_id=GROQ_MODEL_NAME, api_key=api_key)
+    elif provider == "hf":
+        api_key = os.getenv("HF_TOKEN")
+        if not api_key:
+            raise ValueError("HF_TOKEN environment variable is not set")
+        return LiteLLMModel(model_id=HF_MODEL_NAME, api_key=api_key)
+    else:
+        # Default to Groq if no valid provider is specified
+        api_key = os.getenv("GROQ_API_KEY")
+        if not api_key:
+            raise ValueError("GROQ_API_KEY environment variable is not set")
+        return LiteLLMModel(model_id=GROQ_MODEL_NAME, api_key=api_key)
+
+class BasicAgent:
+    def __init__(self, provider="groq"):
+        model = select_model(provider)
+        tools = [
+            MathSolver(),
+            FileAttachmentQueryTool(),
+        ]
+        self.agent = CodeAgent(
+            model=model,
+            tools=tools,
+            add_base_tools=False,
+            max_steps=15,
+        )
+        # System prompt to enforce exact answer format
+        self.agent.prompt_templates["system_prompt"] = (
+            "You are a GAIA benchmark AI assistant. Your sole purpose is to output the minimal, final answer. "
+            "You must NEVER output explanations, intermediate steps, reasoning, or comments — only the answer. "
+            "For numerical answers, use digits only, e.g., `4` not `four`. "
+            "For string answers, omit articles ('a', 'the') and use full words. "
+            "For lists, output in comma-separated format with no conjunctions. "
+            "If the answer is not found, say `- unknown`."
+        )
+
+    def __call__(self, question: str) -> str:
+        result = self.agent.run(question)
+        # Extract only the final answer without any wrappers
+        final_str = str(result).strip()
+        # Remove any potential prefixes
+        if final_str.startswith('[ANSWER]'):
+            final_str = final_str[len('[ANSWER]'):].strip()
+        if final_str.startswith('Final answer:'):
+            final_str = final_str[len('Final answer:'):].strip()
+        if final_str.startswith('Answer:'):
+            final_str = final_str[len('Answer:'):].strip()
+        return final_str
+
+# --- Main Application Functions ---
 def run_and_submit_all(profile: gr.OAuthProfile | None):
     """
     Fetches all questions, runs the BasicAgent on them, submits all answers,
     and displays the results.
     """
     # --- Determine HF Space Runtime URL and Repo URL ---
-    space_id = os.getenv("SPACE_ID") # Get the SPACE_ID for sending link to the code
+    space_id = os.getenv("SPACE_ID")
 
     if profile:
         username = f"{profile.username}"