import json
import time
from copy import deepcopy
from typing import Any

from multi_turn_eval.multi_turn_utils import (
    STATELESS_CLASSES,
    execute_multi_turn_func_call,
    is_empty_execute_response,
)

from constant import (
    DEFAULT_USER_PROMPT_FOR_ADDITIONAL_FUNCTION_FC,
    DEFAULT_USER_PROMPT_FOR_ADDITIONAL_FUNCTION_PROMPTING,
    MAXIMUM_STEP_LIMIT,
)
from model_style import ModelStyle
from overrides import final


class BaseHandler:
    model_name: str
    model_style: ModelStyle

    def __init__(self, model_name, temperature) -> None:
        self.model_name = model_name
        # Replace the slash with underscore to avoid creating subdirectories
        # Replace the dash and dot with underscore for valid variable name
        self.model_name_underline_replaced = (
            model_name.replace("/", "_").replace("-", "_").replace(".", "_")
        )
        self.temperature = temperature
        self.is_fc_model = False  # Whether the model is a function calling model

    def inference(
        self, test_entry: dict, include_input_log: bool = False, include_state_log: bool = False
    ):
        # This method is used to retrieve the model response for each model.
        return self.inference_multi_turn_FC(test_entry, include_input_log, include_state_log)

    def inference_multi_turn_FC(
        self, test_entry: dict, include_input_log: bool, include_state_log: bool
    ):
        initial_config: dict = test_entry["initial_config"]
        involved_classes: list = test_entry["involved_classes"]
        test_entry_id: str = test_entry["id"]
        test_category: str = test_entry_id.rsplit("_", 1)[0]

        # This is only for the miss-function category:
        # a mapping from turn index (as a string) to the function(s) to hold out.
        holdout_function: dict[str, list] = test_entry.get("missed_function", {})

        total_input_token_count: list[list[float]] = []
        total_output_token_count: list[list[float]] = []
        total_latency: list[list[float]] = []
        # The model responses that will be used for later evaluation
        all_model_response: list[list] = []
        # The debugging log for humans to understand
        all_inference_log: list[list[dict]] = []
        # Whether the model has been forced to quit. If True, this whole entry is marked as failed.
        force_quit = False

        # Execute no function call here; this only obtains a reference to all the
        # instances so that the initial state can be logged.
        if include_state_log:
            _, involved_instances = execute_multi_turn_func_call(
                [],
                initial_config,
                involved_classes,
                self.model_name_underline_replaced,
                test_entry_id,
                long_context=(
                    "long_context" in test_category or "composite" in test_category
                ),
                is_evaL_run=False,
            )
            state_log = []
            for class_name, class_instance in involved_instances.items():
                if class_name in STATELESS_CLASSES:
                    continue
                class_instance = deepcopy(class_instance)  # Avoid modification in future turns
                state_log.append(
                    {
                        "role": "state_info",
                        "class_name": class_name,
                        "content": {
                            key: value
                            for key, value in vars(class_instance).items()
                            if not key.startswith("_")
                        },
                    }
                )
            all_inference_log.append(state_log)

        inference_data: dict = {}
        inference_data = self._pre_query_processing_FC(inference_data, test_entry)
        inference_data = self._compile_tools(inference_data, test_entry)

        all_multi_turn_messages: list[list[dict]] = test_entry["question"]
        for turn_idx, current_turn_message in enumerate(all_multi_turn_messages):
            current_turn_message: list[dict]

            if str(turn_idx) in holdout_function:
                test_entry["function"].extend(holdout_function[str(turn_idx)])
                # Since we have added new functions, we need to recompile the tools
                inference_data = self._compile_tools(inference_data, test_entry)
                assert (
                    len(current_turn_message) == 0
                ), "Holdout turn should not have user message."
                current_turn_message = [
                    {
                        "role": "user",
                        "content": DEFAULT_USER_PROMPT_FOR_ADDITIONAL_FUNCTION_FC,
                    }
                ]

            if turn_idx == 0:
                inference_data = self.add_first_turn_message_FC(
                    inference_data, current_turn_message
                )
            else:
                assert isinstance(
                    current_turn_message, list
                ), "Current turn message is not a list"
                inference_data = self._add_next_turn_user_message_FC(
                    inference_data, current_turn_message
                )

            current_turn_response = []
            current_turn_inference_log: dict = {
                "begin_of_turn_query": current_turn_message
            }
            current_turn_input_token_count: list[float] = []
            current_turn_output_token_count: list[float] = []
            current_turn_latency: list[float] = []
            # Populated once function calls are executed in this turn
            involved_instances = None

            count = 0
            while True:
                print("-" * 100)
                print(
                    f"ID: {test_entry_id.replace('multi_turn_', '')}, Turn: {turn_idx}, Step: {count}"
                )
                current_step_inference_log: list[dict] = []
                # Add to current_turn_inference_log at the beginning of each step,
                # so we don't need to deal with it at every break statement.
                current_turn_inference_log[f"step_{count}"] = current_step_inference_log

                start_time = time.time()
                api_response = self._query_FC(inference_data)
                query_latency = time.time() - start_time

                # This part of the logging is disabled by default because it is too verbose
                # and would make the result file extremely large. It is only useful for
                # checking that the inference pipeline works as expected
                # (e.g., that all inputs are converted correctly).
                if include_input_log:
                    current_step_inference_log.append(
                        {
                            "role": "handler_log",
                            "content": inference_data.get("inference_input_log", ""),
                        }
                    )

                # Try parsing the model response
                model_response_data = self._parse_query_response_FC(api_response)
                model_responses = model_response_data["model_responses"]

                # Add the assistant message to the chat history
                inference_data = self._add_assistant_message_FC(
                    inference_data, model_response_data
                )

                # Process the metadata
                current_turn_input_token_count.append(model_response_data["input_token"])
                current_turn_output_token_count.append(model_response_data["output_token"])
                current_turn_latency.append(query_latency)
                current_turn_response.append(model_responses)
                current_step_inference_log.append(
                    {"role": "assistant", "content": model_responses}
                )

                # Try decoding the model response
                try:
                    decoded_model_responses = self.decode_execute(model_responses)
                    current_step_inference_log.append(
                        {
                            "role": "handler_log",
                            "content": "Successfully decoded model response.",
                            "model_response_decoded": decoded_model_responses,
                        }
                    )

                    if is_empty_execute_response(decoded_model_responses):
                        print("Empty response from the model. Proceed to next turn.")
                        current_step_inference_log.append(
                            {
                                "role": "handler_log",
                                "content": "Empty response from the model. Proceed to next turn.",
                                "model_response_decoded": decoded_model_responses,
                            }
                        )
                        break
                except Exception as e:
                    print("Failed to decode the model response. Proceed to next turn.")
                    current_step_inference_log.append(
                        {
                            "role": "handler_log",
                            "content": "Error decoding the model response. Proceed to next turn.",
                            "error": str(e),
                        }
                    )
                    yield ("summary", model_responses, None, self.model_name)
                    break

                # Obtain the execution results
                execution_results, involved_instances = execute_multi_turn_func_call(
                    decoded_model_responses,
                    initial_config,
                    involved_classes,
                    self.model_name_underline_replaced,
                    test_entry_id,
                    long_context=(
                        "long_context" in test_category or "composite" in test_category
                    ),
                    is_evaL_run=False,
                )

                # Add the execution results to the chat history for the next turn
                inference_data = self._add_execution_results_FC(
                    inference_data, execution_results, model_response_data
                )
                for execution_result in execution_results:
                    current_step_inference_log.append(
                        {
                            "role": "tool",
                            "content": execution_result,
                        }
                    )

                # Mark errors on a copy so the annotation doesn't leak into the chat history
                execution_results = deepcopy(execution_results)
                for i in range(len(execution_results)):
                    if "error" in execution_results[i]:
                        execution_results[i] = execution_results[i].replace("error", "error❗️")
                yield ("regular", decoded_model_responses, execution_results, self.model_name)

                count += 1
                # Force quit after too many steps
                if count > MAXIMUM_STEP_LIMIT:
                    force_quit = True
                    current_step_inference_log.append(
                        {
                            "role": "handler_log",
                            "content": f"Model has been forced to quit after {MAXIMUM_STEP_LIMIT} steps.",
                        }
                    )
                    break

            # Add to the total list
            all_model_response.append(current_turn_response)
            all_inference_log.append(current_turn_inference_log)
            total_input_token_count.append(current_turn_input_token_count)
            total_output_token_count.append(current_turn_output_token_count)
            total_latency.append(current_turn_latency)

            # involved_instances is None if the turn ended before any function call was executed
            if include_state_log and involved_instances is not None:
                state_log = []
                for class_name, class_instance in involved_instances.items():
                    if class_name in STATELESS_CLASSES:
                        continue
                    class_instance = deepcopy(class_instance)  # Avoid modification in future turns
                    state_log.append(
                        {
                            "role": "state_info",
                            "class_name": class_name,
                            "content": {
                                key: value
                                for key, value in vars(class_instance).items()
                                if not key.startswith("_")
                            },
                        }
                    )
                all_inference_log.append(state_log)

            if force_quit:
                break

        metadata = {
            "input_token_count": total_input_token_count,
            "output_token_count": total_output_token_count,
            "latency": total_latency,
            "inference_log": all_inference_log,
        }
        yield ("final", current_turn_response, inference_data, involved_instances)

    def decode_ast(self, result, language="Python"):
        # This method takes the raw model output and converts it to the standard AST checker input.
        raise NotImplementedError

    def decode_execute(self, result):
        # This method takes the raw model output and converts it to the standard execute checker input.
        raise NotImplementedError
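
    # Illustrative sketch only (not part of the interface): for a model whose raw
    # output is a list of {"name": ..., "arguments": {...}} dicts (an assumed shape,
    # not one this base class mandates), a subclass's decode_execute could turn each
    # call into an executable string roughly like this:
    #
    #     def decode_execute(self, result):
    #         # [{"name": "f", "arguments": {"x": 1}}] -> ["f(x=1)"]
    #         return [
    #             f"{call['name']}({', '.join(f'{k}={repr(v)}' for k, v in call['arguments'].items())})"
    #             for call in result
    #         ]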

    #### FC methods ####

    def _query_FC(self, inference_data: dict):
        """
        Call the model API in FC mode to get the response.
        Return the response object that can be used to feed into the decode method.
        """
        raise NotImplementedError

    def _pre_query_processing_FC(self, inference_data: dict, test_entry: dict) -> dict:
        """
        Preprocess the testset entry before sending it to the model.
        This includes transforming the input user message into the format expected by the model, and any other necessary preprocessing steps.
        The inference_data dict is updated in place and returned.
        """
        raise NotImplementedError

    def _compile_tools(self, inference_data: dict, test_entry: dict) -> dict:
        """
        Compile the tools from the test entry and add them to the inference data.
        This method is used to prepare the tools for the model query in FC mode.
        The inference_data dict is updated in place and returned.
        """
        raise NotImplementedError
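
    # Illustrative sketch (assumes an OpenAI-style tool schema; the "tools" key and
    # the wrapper shape are conventions of that API style, not requirements here):
    #
    #     def _compile_tools(self, inference_data: dict, test_entry: dict) -> dict:
    #         inference_data["tools"] = [
    #             {"type": "function", "function": func} for func in test_entry["function"]
    #         ]
    #         return inference_data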

    def _parse_query_response_FC(self, api_response: Any) -> dict:
        """
        Parse the raw response from the model API to extract the result, input token count, and output token count.

        Args:
            api_response (Any): The raw response from the model API.

        Returns:
            A dict containing the following elements:
                - model_responses (Any): The parsed result that can be directly used as input to the decode method.
                - input_token (int): The number of tokens used in the input to the model.
                - output_token (int): The number of tokens generated by the model as output.
                - tool_call_ids (list[str]): The IDs of the tool calls generated by the model. Optional.
                - Any other metadata that is specific to the model.
        """
        raise NotImplementedError
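
    # Illustrative sketch (assumes an OpenAI-style chat completions response object;
    # the attribute names below belong to that client library, not to this class):
    #
    #     def _parse_query_response_FC(self, api_response) -> dict:
    #         message = api_response.choices[0].message
    #         if message.tool_calls:
    #             # tc.function.arguments is a JSON string in this API style
    #             model_responses = [
    #                 {tc.function.name: tc.function.arguments} for tc in message.tool_calls
    #             ]
    #             tool_call_ids = [tc.id for tc in message.tool_calls]
    #         else:
    #             model_responses = message.content
    #             tool_call_ids = []
    #         return {
    #             "model_responses": model_responses,
    #             "tool_call_ids": tool_call_ids,
    #             "input_token": api_response.usage.prompt_tokens,
    #             "output_token": api_response.usage.completion_tokens,
    #         }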

    def add_first_turn_message_FC(
        self, inference_data: dict, first_turn_message: list[dict]
    ) -> dict:
        """
        Add the first turn message to the chat history.
        """
        raise NotImplementedError

    def _add_next_turn_user_message_FC(
        self, inference_data: dict, user_message: list[dict]
    ) -> dict:
        """
        [Only for multi-turn]
        Add the next turn's user message to the chat history for the query.
        user_message is a list with a single element: the user message.
        """
        raise NotImplementedError

    def _add_assistant_message_FC(
        self, inference_data: dict, model_response_data: dict
    ) -> dict:
        """
        Add the assistant message to the chat history.
        """
        raise NotImplementedError

    def _add_execution_results_FC(
        self, inference_data: dict, execution_results: list[str], model_response_data: dict
    ) -> dict:
        """
        Add the execution results to the chat history to prepare for the next turn of query.
        Some models may need to add additional information to the chat history, such as tool call IDs.
        """
        raise NotImplementedError
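
    # Illustrative sketch (assumes an OpenAI-style chat history stored under an
    # inference_data["message"] key and tool_call_ids recorded during parsing;
    # both are conventions of this example, not of the base class):
    #
    #     def _add_execution_results_FC(
    #         self, inference_data: dict, execution_results: list[str], model_response_data: dict
    #     ) -> dict:
    #         for result, tool_call_id in zip(
    #             execution_results, model_response_data.get("tool_call_ids", [])
    #         ):
    #             inference_data["message"].append(
    #                 {"role": "tool", "tool_call_id": tool_call_id, "content": result}
    #             )
    #         return inference_data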

    #### Prompting methods ####

    def _query_prompting(self, inference_data: dict):
        """
        Call the model API in prompting mode to get the response.
        Return the response object that can be used to feed into the decode method.
        """
        raise NotImplementedError

    def _pre_query_processing_prompting(self, test_entry: dict) -> dict:
        """
        Preprocess the testset entry before sending it to the model.
        Returns a dict that contains all the necessary information for the query method.
        `tools` and `message` must be included in the returned dict.
        Things like `system_prompt` and `chat_history` are optional, specific to the model.
        """
        raise NotImplementedError

    def _parse_query_response_prompting(self, api_response: Any) -> dict:
        """
        Parse the raw response from the model API to extract the result, input token count, and output token count.

        Args:
            api_response (Any): The raw response from the model API.

        Returns:
            A dict containing the following elements:
                - model_responses (Any): The parsed result that can be directly used as input to the decode method.
                - input_token (int): The number of tokens used in the input to the model.
                - output_token (int): The number of tokens generated by the model as output.
                - tool_call_ids (list[str]): The IDs of the tool calls generated by the model. Optional.
                - Any other metadata that is specific to the model.
        """
        raise NotImplementedError
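
    # Illustrative sketch (prompting mode; assumes the raw completion text lives at
    # api_response.choices[0].message.content, an OpenAI-style shape; here
    # model_responses is plain text, to be parsed later by decode_execute):
    #
    #     def _parse_query_response_prompting(self, api_response) -> dict:
    #         return {
    #             "model_responses": api_response.choices[0].message.content,
    #             "input_token": api_response.usage.prompt_tokens,
    #             "output_token": api_response.usage.completion_tokens,
    #         }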

    def add_first_turn_message_prompting(
        self, inference_data: dict, first_turn_message: list[dict]
    ) -> dict:
        """
        Add the first turn message to the chat history.
        """
        raise NotImplementedError

    def _add_next_turn_user_message_prompting(
        self, inference_data: dict, user_message: list[dict]
    ) -> dict:
        """
        [Only for multi-turn]
        Add the next turn's user message to the chat history for the query.
        user_message is a list with a single element: the user message.
        """
        raise NotImplementedError

    def _add_assistant_message_prompting(
        self, inference_data: dict, model_response_data: dict
    ) -> dict:
        """
        Add the assistant message to the chat history.
        """
        raise NotImplementedError

    def _add_execution_results_prompting(
        self, inference_data: dict, execution_results: list[str], model_response_data: dict
    ) -> dict:
        """
        Add the execution results to the chat history to prepare for the next turn of query.
        Some models may need to add additional information to the chat history, such as tool call IDs.
        """
        raise NotImplementedError
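

# Consumption sketch (hypothetical caller, shown for orientation): the multi-turn
# inference method is a generator that yields 4-tuples:
#   ("regular", decoded_calls, marked_execution_results, model_name) after each executed step,
#   ("summary", raw_model_response, None, model_name) when decoding fails, and
#   ("final", last_turn_responses, inference_data, involved_instances) once at the end.
# "SomeConcreteHandler" below is a placeholder for any subclass implementing the FC methods.
#
#     handler = SomeConcreteHandler("some-model", temperature=0.0)
#     for event, payload, extra, meta in handler.inference(test_entry):
#         if event == "regular":
#             print("calls:", payload, "results:", extra)
#         elif event == "final":
#             all_turn_responses = payload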