olcapone committed on
Commit
bc13e30
·
verified ·
1 Parent(s): e258014
Files changed (1) hide show
  1. app.py +98 -13
app.py CHANGED
@@ -3,11 +3,101 @@ import gradio as gr
3
  import requests
4
  import pandas as pd
5
  import time
 
6
  from smolagents import LiteLLMModel, CodeAgent, Tool
7
 
8
  # --- Constants ---
9
  DEFAULT_API_URL = "https://agents-course-unit4-scoring.hf.space"
10
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
11
  # --- Agent Tools ---
12
  class MathSolver(Tool):
13
  name = "math_solver"
@@ -70,13 +160,14 @@ def select_model(provider="groq"):
70
  HF_MODEL_NAME = "huggingfaceh4/zephyr-7b-beta"
71
 
72
  if provider == "groq":
73
- return LiteLLMModel(model_id="groq/llama-3.1-8b-instant",
 
 
 
74
  api_key=os.getenv("GROQ_API_KEY"))
75
  if not api_key:
76
  raise ValueError("GROQ_API_KEY environment variable is not set")
77
-
78
-
79
-
80
  return LiteLLMModel(model_id=GROQ_MODEL_NAME, api_key=api_key)
81
  elif provider == "hf":
82
  api_key = os.getenv("HF_TOKEN")
@@ -111,6 +202,7 @@ class BasicAgent:
111
  "For string answers, omit articles ('a', 'the') and use full words. "
112
  "For lists, output in comma-separated format with no conjunctions. "
113
  "If the answer is not found, say `- unknown`."
 
114
  )
115
 
116
  def __call__(self, question: str) -> str:
@@ -120,15 +212,8 @@ class BasicAgent:
120
  for attempt in range(max_retries):
121
  try:
122
  result = self.agent.run(question)
123
- # Extract only the final answer without any wrappers
124
- final_str = str(result).strip()
125
- # Remove any potential prefixes
126
- if final_str.startswith('[ANSWER]'):
127
- final_str = final_str[8:].strip()
128
- if final_str.startswith('Final answer:'):
129
- final_str = final_str[13:].strip()
130
- if final_str.startswith('Answer:'):
131
- final_str = final_str[7:].strip()
132
  return final_str
133
  except Exception as e:
134
  # Check if it's a rate limit error
 
3
  import requests
4
  import pandas as pd
5
  import time
6
+ import re
7
  from smolagents import LiteLLMModel, CodeAgent, Tool
8
 
9
  # --- Constants ---
10
  DEFAULT_API_URL = "https://agents-course-unit4-scoring.hf.space"
11
 
12
+ # --- Answer Extraction Function ---
13
def extract_answer(text: str, original_question: str) -> str:
    """Reduce a raw LLM response to the bare answer string.

    The response is stripped of boilerplate prefixes, then a few
    question-specific heuristics are applied in order:

    1. "how many" questions yield the first number found in the response
       (thousands separators removed, decimals kept).
    2. Questions that ask "what year" / "which year" yield the first
       four-digit 19xx/20xx year found in the response.
    3. If the response echoes the question, only the text following the
       echo is kept (or, failing that, the first other non-empty line).
    4. Responses still longer than 200 characters are reduced to a single
       short line.

    Args:
        text: Raw response produced by the model/agent.
        original_question: The question that was asked. May be empty or
            ``None``; echo detection is skipped in that case.

    Returns:
        The extracted answer, or the sentinel ``"- unknown"`` when the
        response is empty or merely repeats the question.
    """
    unknown = "- unknown"
    if not text:
        return unknown

    cleaned = text.strip()
    # Normalise once; tolerates a None/blank question instead of crashing.
    question = (original_question or "").strip()
    question_lower = question.lower()

    # A response identical to the question carries no answer.
    if not cleaned or cleaned == question:
        return unknown

    # Prefixes are matched case-insensitively and applied in sequence, so
    # stacked wrappers such as "[ANSWER]: Final answer: 42" are unwrapped.
    prefixes = (
        '[answer]:',
        '[answer]',
        'final answer:',
        'answer:',
        'the answer is',
        'the final answer is',
    )
    for prefix in prefixes:
        if cleaned[:len(prefix)].lower() == prefix:
            # Also drop a colon left behind by "The answer is: ...".
            cleaned = cleaned[len(prefix):].strip().lstrip(':').strip()

    # Counting questions: return just the first number. Accept "1,234" and
    # "3.5" as single numbers rather than truncating at the separator.
    if 'how many' in question_lower:
        number = re.search(
            r'\d{1,3}(?:,\d{3})+(?!\d)(?:\.\d+)?|\d+(?:\.\d+)?', cleaned
        )
        if number:
            return number.group(0).replace(',', '')

    # Year questions: trigger on the question *asking* for a year, not on
    # it merely mentioning one, and use a non-capturing group so the whole
    # four-digit year is returned rather than just "19"/"20".
    if re.search(r'\b(?:what|which)\s+year\b', question_lower):
        year = re.search(r'\b(?:19|20)\d{2}\b', cleaned)
        if year:
            return year.group(0)

    # The model echoed the question: keep what follows the echo. The
    # `question` truthiness guard matters because "" is "in" every string
    # and str.split("") raises ValueError.
    if question and question in cleaned:
        after_echo = cleaned.split(question)[1].strip()
        if after_echo:
            cleaned = after_echo
        else:
            # Nothing follows the echo; fall back to the first other line
            # that is short enough to plausibly be an answer.
            for raw_line in cleaned.split('\n'):
                line = raw_line.strip()
                if line and line != question:
                    if len(line) < 100 or 'how many' in question_lower:
                        cleaned = line
                        break

    # Still long and still containing the question: pick the first short
    # line made only of word characters, whitespace and basic punctuation.
    if len(cleaned) > 200 and question and question in cleaned:
        for raw_line in cleaned.split('\n'):
            line = raw_line.strip()
            if line and len(line) < 100 and line != question:
                if re.match(r'^[\w\s\d\-\.,]+$', line):
                    return line

    # Long free-form response: the answer is most likely on the last
    # short non-empty line.
    if len(cleaned) > 200:
        for raw_line in reversed(cleaned.split('\n')):
            line = raw_line.strip()
            if line and len(line) < 100:
                cleaned = line
                break

    # Final fallback: never hand the question (or nothing) back as answer.
    if not cleaned or cleaned == question:
        return unknown

    return cleaned
100
+
101
  # --- Agent Tools ---
102
  class MathSolver(Tool):
103
  name = "math_solver"
 
160
  HF_MODEL_NAME = "huggingfaceh4/zephyr-7b-beta"
161
 
162
  if provider == "groq":
163
+ api_key = os.getenv("GROQ_API_KEY")
164
+
165
+ if api_key:
166
+ return LiteLLMModel(model_id="groq/llama-3.1-8b-instant",
167
  api_key=os.getenv("GROQ_API_KEY"))
168
  if not api_key:
169
  raise ValueError("GROQ_API_KEY environment variable is not set")
170
+
 
 
171
  return LiteLLMModel(model_id=GROQ_MODEL_NAME, api_key=api_key)
172
  elif provider == "hf":
173
  api_key = os.getenv("HF_TOKEN")
 
202
  "For string answers, omit articles ('a', 'the') and use full words. "
203
  "For lists, output in comma-separated format with no conjunctions. "
204
  "If the answer is not found, say `- unknown`."
205
+ "IMPORTANT: Respond with ONLY the answer, nothing else. No prefixes, no explanations."
206
  )
207
 
208
  def __call__(self, question: str) -> str:
 
212
  for attempt in range(max_retries):
213
  try:
214
  result = self.agent.run(question)
215
+ # Use our enhanced extraction function
216
+ final_str = extract_answer(str(result), question)
 
 
 
 
 
 
 
217
  return final_str
218
  except Exception as e:
219
  # Check if it's a rate limit error