# Copyright (c) 2025 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
| """BotClient class for interacting with bot models.""" | |
| import os | |
| import argparse | |
| import json | |
| import logging | |
| import traceback | |
| import jieba | |
| import requests | |
| from openai import OpenAI | |


class BotClient:
    """Client for interacting with various AI models."""

    def __init__(self, args: argparse.Namespace):
        """
        Initialize the BotClient from command line arguments: retry limits,
        character constraints, model endpoints, and API credentials. Missing
        arguments fall back to defaults so the client stays usable.

        Args:
            args (argparse.Namespace): Command line arguments containing
                configuration parameters. Uses getattr() to safely retrieve
                values with fallback defaults.
        """
        self.logger = logging.getLogger(__name__)

        self.max_retry_num = getattr(args, "max_retry_num", 3)
        self.max_char = getattr(args, "max_char", 8000)

        self.model_map = getattr(args, "model_map", {})
        self.api_key = os.environ.get("API_KEY")

        self.embedding_service_url = getattr(
            args, "embedding_service_url", "embedding_service_url"
        )
        self.embedding_model = getattr(args, "embedding_model", "embedding_model")

        self.web_search_service_url = getattr(
            args, "web_search_service_url", "web_search_service_url"
        )
        self.max_search_results_num = getattr(args, "max_search_results_num", 15)
        # Currently reads the same API_KEY environment variable as self.api_key.
        self.qianfan_api_key = os.environ.get("API_KEY")
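
    # Illustrative shape of the expected configuration (the values below are
    # placeholders, not defaults shipped with this module):
    #
    #   args.model_map = {"my-chat-model": "https://example.com/v1"}
    #   os.environ["API_KEY"] = "sk-..."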

    def call_back(self, host_url: str, req_data: dict) -> dict:
        """
        Execute a non-streaming chat completion request against the given
        endpoint via the OpenAI client, convert the response to a plain
        dictionary, and log and re-raise any exception raised along the way.

        Args:
            host_url (str): The URL to send the request to.
            req_data (dict): The data to send in the request body.

        Returns:
            dict: Parsed JSON response from the server.
        """
        try:
            client = OpenAI(base_url=host_url, api_key=self.api_key)
            response = client.chat.completions.create(**req_data)
            # Convert the OpenAI response object to a plain dict
            return response.model_dump()
        except Exception as e:
            self.logger.error(f"Request failed: {e}")
            raise

    def call_back_stream(
        self, host_url: str, req_data: dict
    ) -> Generator[dict, None, None]:
        """
        Make a streaming chat completion request to the given host URL via the
        OpenAI client and yield response chunks as they arrive, logging and
        re-raising any exception that occurs during the stream.

        Args:
            host_url (str): The URL to send the request to.
            req_data (dict): The data to send in the request body.

        Yields:
            dict: Parsed JSON response chunks from the server.
        """
        try:
            client = OpenAI(base_url=host_url, api_key=self.api_key)
            response = client.chat.completions.create(
                **req_data,
                stream=True,
            )
            for chunk in response:
                if not chunk.choices:
                    continue
                # Convert each OpenAI chunk object to a plain dict
                yield chunk.model_dump()
        except Exception as e:
            self.logger.error(f"Stream request failed: {e}")
            raise

    def process(
        self,
        model_name: str,
        req_data: dict,
        max_tokens: int = 2048,
        temperature: float = 1.0,
        top_p: float = 0.7,
    ) -> dict:
        """
        Handle a chat completion request: map the model name to its endpoint,
        fill in token limits and sampling settings, truncate messages to the
        character limit, and call the API with a bounded retry loop, logging
        the full request/response cycle for debugging.

        Args:
            model_name (str): Name of the model, used to look up the model URL
                from model_map.
            req_data (dict): Dictionary containing request data, including
                information to be processed.
            max_tokens (int): Maximum number of tokens to generate.
            temperature (float): Sampling temperature to control the diversity
                of generated text.
            top_p (float): Cumulative probability threshold to control the
                diversity of generated text.

        Returns:
            dict: Dictionary containing the model's processing results. May be
                empty if every retry attempt failed.
        """
        model_url = self.model_map[model_name]
        req_data["model"] = model_name
        req_data["max_tokens"] = max_tokens
        req_data["temperature"] = temperature
        req_data["top_p"] = top_p
        req_data["messages"] = self.truncate_messages(req_data["messages"])

        res = {}
        for _ in range(self.max_retry_num):
            try:
                self.logger.info(f"[MODEL] {model_url}")
                self.logger.info("[req_data]====>")
                self.logger.info(json.dumps(req_data, ensure_ascii=False))
                res = self.call_back(model_url, req_data)
                self.logger.info("model response")
                self.logger.info(res)
                self.logger.info("-" * 30)
            except Exception as e:
                self.logger.error(e)
                self.logger.error(traceback.format_exc())
                res = {}
            if len(res) != 0 and "error" not in res:
                break
        return res
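
    # Illustrative call (model name and message are placeholders):
    #
    #   res = client.process(
    #       "my-chat-model",
    #       {"messages": [{"role": "user", "content": "Hi"}]},
    #   )
    #   text = res["choices"][0]["message"]["content"]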

    def process_stream(
        self,
        model_name: str,
        req_data: dict,
        max_tokens: int = 2048,
        temperature: float = 1.0,
        top_p: float = 0.7,
    ) -> Generator[dict, None, None]:
        """
        Handle a streaming request: map the model name to its endpoint, fill
        in the request parameters, and stream back response chunks in real
        time, retrying with logging when a streaming attempt fails.

        Args:
            model_name (str): Name of the model, used to look up the model URL
                from model_map.
            req_data (dict): Dictionary containing request data, including
                information to be processed.
            max_tokens (int): Maximum number of tokens to generate.
            temperature (float): Sampling temperature to control the diversity
                of generated text.
            top_p (float): Cumulative probability threshold to control the
                diversity of generated text.

        Yields:
            dict: Dictionary containing the model's processing results, or a
                final {"error": ...} dict if every retry attempt failed.
        """
        model_url = self.model_map[model_name]
        req_data["model"] = model_name
        req_data["max_tokens"] = max_tokens
        req_data["temperature"] = temperature
        req_data["top_p"] = top_p
        req_data["messages"] = self.truncate_messages(req_data["messages"])

        last_error = None
        for attempt in range(self.max_retry_num):
            try:
                self.logger.info(f"[MODEL] {model_url}")
                self.logger.info("[req_data]====>")
                self.logger.info(json.dumps(req_data, ensure_ascii=False))
                yield from self.call_back_stream(model_url, req_data)
                return
            except Exception as e:
                last_error = e
                self.logger.error(
                    f"Stream request failed (attempt {attempt + 1}/{self.max_retry_num}): {e}"
                )
        self.logger.error("All retry attempts failed for stream request")
        yield {"error": str(last_error)}

    def cut_chinese_english(self, text: str) -> list:
        """
        Segment mixed Chinese and English text: Jieba tokenizes the input,
        English words are kept as whole units, and any token containing a
        character in the CJK Unified Ideographs range is split into single
        characters.

        Args:
            text (str): Input string to be segmented.

        Returns:
            list: A list of segments, where each segment is either a single
                Chinese character or a whole English word.
        """
        words = jieba.lcut(text)
        en_ch_words = []
        for word in words:
            # Keep purely alphabetic, non-CJK tokens (English words) intact
            if word.isalpha() and not any(
                "\u4e00" <= char <= "\u9fff" for char in word
            ):
                en_ch_words.append(word)
            else:
                # Split everything else (Chinese words, mixed tokens) into characters
                en_ch_words.extend(list(word))
        return en_ch_words
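
    # Example: cut_chinese_english("hello 世界") keeps "hello" whole and splits
    # the Chinese word into characters -> ["hello", " ", "世", "界"]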

    def truncate_messages(self, messages: list[dict]) -> list:
        """
        Truncate conversation messages to fit within the maximum character
        limit (self.max_char) while preserving message structure. Truncation
        follows a prioritized order: historical messages first, then the
        system message, and finally the last message.

        Args:
            messages (list[dict]): List of messages to be truncated.

        Returns:
            list[dict]: Modified list of messages after truncation.
        """
        if not messages:
            return messages

        processed = []
        total_units = 0
        for msg in messages:
            # Handle two different content formats
            if isinstance(msg["content"], str):
                text_content = msg["content"]
            elif isinstance(msg["content"], list):
                text_content = msg["content"][1]["text"]
            else:
                text_content = ""

            # Count segmentation units produced by cut_chinese_english
            units = self.cut_chinese_english(text_content)
            unit_count = len(units)
            processed.append(
                {
                    "role": msg["role"],
                    "original_content": msg["content"],  # Preserve original content
                    "text_content": text_content,  # Extracted plain text
                    "units": units,
                    "unit_count": unit_count,
                }
            )
            total_units += unit_count

        if total_units <= self.max_char:
            return messages

        # Number of units to remove
        to_remove = total_units - self.max_char
        # 1. Truncate historical messages (everything between the system
        # message and the last message), oldest first
        for i in range(1, len(processed) - 1):
            if to_remove <= 0:
                break
            if processed[i]["unit_count"] <= to_remove:
                # Remove this message's content entirely
                processed[i]["text_content"] = ""
                to_remove -= processed[i]["unit_count"]
                if isinstance(processed[i]["original_content"], str):
                    processed[i]["original_content"] = ""
                elif isinstance(processed[i]["original_content"], list):
                    processed[i]["original_content"][1]["text"] = ""
            else:
                # Drop only the trailing units that exceed the budget
                kept_units = processed[i]["units"][:-to_remove]
                new_text = "".join(kept_units)
                processed[i]["text_content"] = new_text
                if isinstance(processed[i]["original_content"], str):
                    processed[i]["original_content"] = new_text
                elif isinstance(processed[i]["original_content"], list):
                    processed[i]["original_content"][1]["text"] = new_text
                to_remove = 0
        # 2. Truncate system message
        if to_remove > 0:
            system_msg = processed[0]
            if system_msg["unit_count"] <= to_remove:
                processed[0]["text_content"] = ""
                to_remove -= system_msg["unit_count"]
                if isinstance(processed[0]["original_content"], str):
                    processed[0]["original_content"] = ""
                elif isinstance(processed[0]["original_content"], list):
                    processed[0]["original_content"][1]["text"] = ""
            else:
                kept_units = system_msg["units"][:-to_remove]
                new_text = "".join(kept_units)
                processed[0]["text_content"] = new_text
                if isinstance(processed[0]["original_content"], str):
                    processed[0]["original_content"] = new_text
                elif isinstance(processed[0]["original_content"], list):
                    processed[0]["original_content"][1]["text"] = new_text
                to_remove = 0

        # 3. Truncate last message
        if to_remove > 0 and len(processed) > 1:
            last_msg = processed[-1]
            if last_msg["unit_count"] > to_remove:
                kept_units = last_msg["units"][:-to_remove]
                new_text = "".join(kept_units)
                last_msg["text_content"] = new_text
                if isinstance(last_msg["original_content"], str):
                    last_msg["original_content"] = new_text
                elif isinstance(last_msg["original_content"], list):
                    last_msg["original_content"][1]["text"] = new_text
            else:
                last_msg["text_content"] = ""
                if isinstance(last_msg["original_content"], str):
                    last_msg["original_content"] = ""
                elif isinstance(last_msg["original_content"], list):
                    last_msg["original_content"][1]["text"] = ""

        # Rebuild the message list, dropping messages whose content was
        # fully removed
        result = []
        for msg in processed:
            if msg["text_content"]:
                result.append({"role": msg["role"], "content": msg["original_content"]})
        return result
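
    # Illustrative behavior with max_char=5 and three string messages of
    # 4/4/4 units (system/history/last): the historical message is emptied
    # and dropped from the result, then the system message keeps only its
    # first unit, leaving 5 units total (1 + 4).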

    def embed_fn(self, text: str) -> list:
        """
        Generate an embedding for the given text using the QianFan API.

        Args:
            text (str): The input text to be embedded.

        Returns:
            list: A list of floats representing the embedding.
        """
        client = OpenAI(
            base_url=self.embedding_service_url, api_key=self.qianfan_api_key
        )
        response = client.embeddings.create(input=[text], model=self.embedding_model)
        return response.data[0].embedding
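
    # Illustrative call (service URL and model come from the constructor args):
    #
    #   vec = client.embed_fn("hello")  # -> list[float], length set by the model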

    def get_web_search_res(self, query_list: list) -> list:
        """
        Send a request to the AI Search service using the provided API key
        and service URL.

        Args:
            query_list (list): List of queries to send to the AI Search service.

        Returns:
            list: List of responses from the AI Search service.
        """
        headers = {
            "Authorization": "Bearer " + self.qianfan_api_key,
            "Content-Type": "application/json",
        }

        results = []
        if not query_list:
            return results
        # Split the total result budget evenly across the queries
        top_k = self.max_search_results_num // len(query_list)
        for query in query_list:
            payload = {
                "messages": [{"role": "user", "content": query}],
                "resource_type_filter": [{"type": "web", "top_k": top_k}],
            }
            response = requests.post(
                self.web_search_service_url, headers=headers, json=payload
            )

            if response.status_code == 200:
                data = response.json()
                self.logger.info(data)
                results.append(data["references"])
            else:
                self.logger.error(f"Request failed, status code: {response.status_code}")
                self.logger.error(response.text)
        return results
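

# ---------------------------------------------------------------------------
# Minimal usage sketch (illustrative only): the model name, endpoint URL, and
# message below are placeholders, not values shipped with this module. A real
# run also requires the API_KEY environment variable to be set.
# ---------------------------------------------------------------------------
if __name__ == "__main__":
    logging.basicConfig(level=logging.INFO)
    demo_args = argparse.Namespace(
        model_map={"my-chat-model": "https://example.com/v1"},  # hypothetical endpoint
    )
    bot = BotClient(demo_args)
    reply = bot.process(
        "my-chat-model",
        {"messages": [{"role": "user", "content": "Hello"}]},
    )
    if reply:
        print(reply["choices"][0]["message"]["content"])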