Spaces:
Running
Running
Added Selene / Selene Mini API
Browse files- gen_api_answer.py +50 -43
gen_api_answer.py
CHANGED
|
@@ -15,6 +15,7 @@ from prompts import (
|
|
| 15 |
FLOW_JUDGE_PROMPT
|
| 16 |
)
|
| 17 |
from transformers import AutoTokenizer
|
|
|
|
| 18 |
|
| 19 |
# Initialize clients
|
| 20 |
anthropic_client = anthropic.Anthropic()
|
|
@@ -24,6 +25,10 @@ hf_api_key = os.getenv("HF_API_KEY")
|
|
| 24 |
flow_judge_api_key = os.getenv("FLOW_JUDGE_API_KEY")
|
| 25 |
cohere_client = cohere.ClientV2(os.getenv("CO_API_KEY"))
|
| 26 |
salesforce_api_key = os.getenv("SALESFORCE_API_KEY")
|
|
|
|
|
|
|
|
|
|
|
|
|
| 27 |
def get_openai_response(model_name, prompt, system_prompt=JUDGE_SYSTEM_PROMPT, max_tokens=500, temperature=0):
|
| 28 |
"""Get response from OpenAI API"""
|
| 29 |
try:
|
|
@@ -110,42 +115,33 @@ def get_prometheus_response(model_name, prompt, system_prompt=None, max_tokens=5
|
|
| 110 |
return f"Error with Hugging Face model {model_name}: {str(e)}"
|
| 111 |
|
| 112 |
def get_atla_response(model_name, prompt, system_prompt=None, max_tokens=500, temperature=0.01):
|
| 113 |
-
"""Get response from
|
| 114 |
try:
|
| 115 |
-
|
| 116 |
-
|
| 117 |
-
|
| 118 |
-
|
| 119 |
-
|
| 120 |
-
|
| 121 |
-
#
|
| 122 |
-
|
| 123 |
-
|
| 124 |
-
|
| 125 |
-
|
| 126 |
-
|
| 127 |
-
|
| 128 |
-
|
| 129 |
-
|
| 130 |
-
|
|
|
|
|
|
|
|
|
|
| 131 |
|
| 132 |
-
|
| 133 |
-
|
| 134 |
-
"
|
| 135 |
-
|
| 136 |
-
"return_full_text": False,
|
| 137 |
-
"temperature": temperature,
|
| 138 |
-
"seed": 42,
|
| 139 |
-
"add_generation_prompt": True
|
| 140 |
-
}
|
| 141 |
}
|
| 142 |
-
|
| 143 |
-
response = requests.post(
|
| 144 |
-
"https://bkp9p28gri93egqh.us-east-1.aws.endpoints.huggingface.cloud",
|
| 145 |
-
headers=headers,
|
| 146 |
-
json=payload
|
| 147 |
-
)
|
| 148 |
-
return response.json()[0]["generated_text"]
|
| 149 |
except Exception as e:
|
| 150 |
return f"Error with Atla model {model_name}: {str(e)}"
|
| 151 |
|
|
@@ -321,9 +317,16 @@ def get_model_response(
|
|
| 321 |
api_model, final_prompt, system_prompt, max_tokens, temperature = 0.01
|
| 322 |
)
|
| 323 |
elif organization == "Atla":
|
| 324 |
-
|
| 325 |
-
api_model,
|
| 326 |
)
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 327 |
elif organization == "Cohere":
|
| 328 |
return get_cohere_response(
|
| 329 |
api_model, final_prompt, system_prompt, max_tokens, temperature
|
|
@@ -350,6 +353,10 @@ def parse_model_response(response):
|
|
| 350 |
# Debug print
|
| 351 |
print(f"Raw model response: {response}")
|
| 352 |
|
|
|
|
|
|
|
|
|
|
|
|
|
| 353 |
# If response is already a dictionary, use it directly
|
| 354 |
if isinstance(response, dict):
|
| 355 |
return str(response.get("result", "N/A")), response.get("feedback", "N/A")
|
|
@@ -359,10 +366,10 @@ def parse_model_response(response):
|
|
| 359 |
data = json.loads(response)
|
| 360 |
return str(data.get("result", "N/A")), data.get("feedback", "N/A")
|
| 361 |
except json.JSONDecodeError:
|
| 362 |
-
# If that fails, check if this is a Salesforce response
|
| 363 |
if "**Reasoning:**" in response or "**Result:**" in response:
|
| 364 |
-
# Use ATLA parser for Salesforce responses
|
| 365 |
-
return
|
| 366 |
|
| 367 |
# Otherwise try to find JSON within the response
|
| 368 |
json_match = re.search(r"{.*}", response, re.DOTALL)
|
|
@@ -443,10 +450,10 @@ def prometheus_parse_model_response(output):
|
|
| 443 |
print(f"Failed to parse response: {str(e)}")
|
| 444 |
return "Error", f"Exception during parsing: {str(e)}"
|
| 445 |
|
| 446 |
-
def
|
| 447 |
-
"""Parse response from
|
| 448 |
try:
|
| 449 |
-
print(f"Raw
|
| 450 |
output = output.strip()
|
| 451 |
|
| 452 |
# Look for the Reasoning and Result sections
|
|
@@ -458,10 +465,10 @@ def atla_parse_model_response(output):
|
|
| 458 |
score = result_match.group(1)
|
| 459 |
return str(score), feedback
|
| 460 |
|
| 461 |
-
return "Error", f"Failed to parse
|
| 462 |
|
| 463 |
except Exception as e:
|
| 464 |
-
print(f"Failed to parse
|
| 465 |
return "Error", f"Exception during parsing: {str(e)}"
|
| 466 |
|
| 467 |
def flow_judge_parse_model_response(output):
|
|
|
|
| 15 |
FLOW_JUDGE_PROMPT
|
| 16 |
)
|
| 17 |
from transformers import AutoTokenizer
|
| 18 |
+
from atla import Atla
|
| 19 |
|
| 20 |
# Initialize clients
|
| 21 |
anthropic_client = anthropic.Anthropic()
|
|
|
|
| 25 |
flow_judge_api_key = os.getenv("FLOW_JUDGE_API_KEY")
|
| 26 |
cohere_client = cohere.ClientV2(os.getenv("CO_API_KEY"))
|
| 27 |
salesforce_api_key = os.getenv("SALESFORCE_API_KEY")
|
| 28 |
+
|
| 29 |
+
# Initialize Atla client
|
| 30 |
+
atla_client = Atla()
|
| 31 |
+
|
| 32 |
def get_openai_response(model_name, prompt, system_prompt=JUDGE_SYSTEM_PROMPT, max_tokens=500, temperature=0):
|
| 33 |
"""Get response from OpenAI API"""
|
| 34 |
try:
|
|
|
|
| 115 |
return f"Error with Hugging Face model {model_name}: {str(e)}"
|
| 116 |
|
| 117 |
def get_atla_response(model_name, prompt, system_prompt=None, max_tokens=500, temperature=0.01):
|
| 118 |
+
"""Get response from Atla API"""
|
| 119 |
try:
|
| 120 |
+
# Extract components from the prompt data
|
| 121 |
+
model_input = prompt.get('human_input', '')
|
| 122 |
+
model_output = prompt.get('ai_response', '')
|
| 123 |
+
expected_output = prompt.get('ground_truth_input', '')
|
| 124 |
+
evaluation_criteria = prompt.get('eval_criteria', '')
|
| 125 |
+
|
| 126 |
+
# Set model_id based on the model name
|
| 127 |
+
if "Mini" in model_name:
|
| 128 |
+
model_id = "atla-selene-mini"
|
| 129 |
+
else:
|
| 130 |
+
model_id = "atla-selene"
|
| 131 |
+
|
| 132 |
+
response = atla_client.evaluation.create(
|
| 133 |
+
model_id=model_id,
|
| 134 |
+
model_input=model_input,
|
| 135 |
+
model_output=model_output,
|
| 136 |
+
expected_model_output=expected_output if expected_output else None,
|
| 137 |
+
evaluation_criteria=evaluation_criteria,
|
| 138 |
+
)
|
| 139 |
|
| 140 |
+
# Return the score and critique directly
|
| 141 |
+
return {
|
| 142 |
+
"score": response.result.evaluation.score,
|
| 143 |
+
"critique": response.result.evaluation.critique
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 144 |
}
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 145 |
except Exception as e:
|
| 146 |
return f"Error with Atla model {model_name}: {str(e)}"
|
| 147 |
|
|
|
|
| 317 |
api_model, final_prompt, system_prompt, max_tokens, temperature = 0.01
|
| 318 |
)
|
| 319 |
elif organization == "Atla":
|
| 320 |
+
response = get_atla_response(
|
| 321 |
+
api_model, prompt_data, system_prompt, max_tokens, temperature
|
| 322 |
)
|
| 323 |
+
# Response now contains score and critique directly
|
| 324 |
+
if isinstance(response, dict) and 'score' in response and 'critique' in response:
|
| 325 |
+
score = str(response['score'])
|
| 326 |
+
critique = response['critique']
|
| 327 |
+
return score, critique
|
| 328 |
+
else:
|
| 329 |
+
return "Error", str(response)
|
| 330 |
elif organization == "Cohere":
|
| 331 |
return get_cohere_response(
|
| 332 |
api_model, final_prompt, system_prompt, max_tokens, temperature
|
|
|
|
| 353 |
# Debug print
|
| 354 |
print(f"Raw model response: {response}")
|
| 355 |
|
| 356 |
+
# If response is already a tuple (from Atla/Salesforce), use it directly
|
| 357 |
+
if isinstance(response, tuple):
|
| 358 |
+
return response
|
| 359 |
+
|
| 360 |
# If response is already a dictionary, use it directly
|
| 361 |
if isinstance(response, dict):
|
| 362 |
return str(response.get("result", "N/A")), response.get("feedback", "N/A")
|
|
|
|
| 366 |
data = json.loads(response)
|
| 367 |
return str(data.get("result", "N/A")), data.get("feedback", "N/A")
|
| 368 |
except json.JSONDecodeError:
|
| 369 |
+
# If that fails, check if this is a Salesforce response
|
| 370 |
if "**Reasoning:**" in response or "**Result:**" in response:
|
| 371 |
+
# Use ATLA parser for Salesforce responses only
|
| 372 |
+
return salesforce_parse_model_response(response)
|
| 373 |
|
| 374 |
# Otherwise try to find JSON within the response
|
| 375 |
json_match = re.search(r"{.*}", response, re.DOTALL)
|
|
|
|
| 450 |
print(f"Failed to parse response: {str(e)}")
|
| 451 |
return "Error", f"Exception during parsing: {str(e)}"
|
| 452 |
|
| 453 |
+
def salesforce_parse_model_response(output):
|
| 454 |
+
"""Parse response from Salesforce model"""
|
| 455 |
try:
|
| 456 |
+
print(f"Raw Salesforce model response: {output}")
|
| 457 |
output = output.strip()
|
| 458 |
|
| 459 |
# Look for the Reasoning and Result sections
|
|
|
|
| 465 |
score = result_match.group(1)
|
| 466 |
return str(score), feedback
|
| 467 |
|
| 468 |
+
return "Error", f"Failed to parse Salesforce response format: {output}"
|
| 469 |
|
| 470 |
except Exception as e:
|
| 471 |
+
print(f"Failed to parse Salesforce response: {str(e)}")
|
| 472 |
return "Error", f"Exception during parsing: {str(e)}"
|
| 473 |
|
| 474 |
def flow_judge_parse_model_response(output):
|