import asyncio
import copy
import logging
from dataclasses import asdict
from typing import List, Optional, Union

import aiohttp

from lagent.llms.base_llm import AsyncLLMMixin, BaseLLM
from lagent.schema import ModelStatusCode
from lagent.utils.util import filter_suffix


class TritonClient(BaseLLM):
    """A wrapper around the Triton inference server for LLM serving.

    Args:
        tritonserver_addr (str): the address in format "ip:port" of
            the triton inference server
        model_name (str): the name of the model
        session_len (int): the context size
        log_level (str): log level
    """

    def __init__(self,
                 tritonserver_addr: str,
                 model_name: str,
                 session_len: int = 32768,
                 log_level: str = 'WARNING',
                 **kwargs):
        super().__init__(path=None, **kwargs)
        try:
            from lmdeploy.serve.turbomind.chatbot import Chatbot, StatusCode
        except Exception as e:
            logging.error(f'{e}')
            raise RuntimeError('DO NOT use turbomind.chatbot since it has '
                               'been removed from lmdeploy since v0.5.2')
        self.state_map = {
            StatusCode.TRITON_STREAM_END: ModelStatusCode.END,
            StatusCode.TRITON_SERVER_ERR: ModelStatusCode.SERVER_ERR,
            StatusCode.TRITON_SESSION_CLOSED: ModelStatusCode.SESSION_CLOSED,
            StatusCode.TRITON_STREAM_ING: ModelStatusCode.STREAM_ING,
            StatusCode.TRITON_SESSION_OUT_OF_LIMIT:
            ModelStatusCode.SESSION_OUT_OF_LIMIT,
            StatusCode.TRITON_SESSION_INVALID_ARG:
            ModelStatusCode.SESSION_INVALID_ARG,
            StatusCode.TRITON_SESSION_READY: ModelStatusCode.SESSION_READY
        }
        self.chatbot = Chatbot(
            tritonserver_addr=tritonserver_addr,
            model_name=model_name,
            session_len=session_len,
            log_level=log_level,
            **kwargs)

    def generate(self,
                 inputs: Union[str, List[str]],
                 session_id: int = 2967,
                 request_id: str = '',
                 sequence_start: bool = True,
                 sequence_end: bool = True,
                 skip_special_tokens: bool = False,
                 **kwargs):
        """Start a new round conversation of a session. Return the chat
        completions in non-stream mode.

        Args:
            inputs (str, List[str]): user's prompt(s) in this round
            session_id (int): the unique id of a session
            request_id (str): the unique id of this round conversation
            sequence_start (bool): start flag of a session
            sequence_end (bool): end flag of a session
            skip_special_tokens (bool): Whether or not to remove special
                tokens in the decoding. Default to be False.

        Returns:
            (a list of/batched) text/chat completion
        """
        from lmdeploy.serve.turbomind.chatbot import Session, get_logger
        if isinstance(inputs, str):
            inputs = [inputs]
        prompt = inputs

        assert isinstance(session_id, int), \
            f'INT session id is required, but got {type(session_id)}'

        self.chatbot.cfg = self._update_gen_params(**kwargs)
        max_new_tokens = self.chatbot.cfg.max_new_tokens

        logger = get_logger('service.ft', log_level=self.chatbot.log_level)
        logger.info(f'session {session_id}, request_id {request_id}, '
                    f'max_out_len {max_new_tokens}')

        if self.chatbot._session is None:
            sequence_start = True
            self.chatbot._session = Session(session_id=session_id)
        elif self.chatbot._session.status == 0:
            logger.error(f'session {session_id} has been ended. Please set '
                         f'`sequence_start` to True if you want to restart it')
            return ''

        self.chatbot._session.status = 1
        self.chatbot._session.request_id = request_id
        self.chatbot._session.response = ''

        status, res, _ = None, '', 0
        for status, res, _ in self.chatbot._stream_infer(
                self.chatbot._session,
                prompt,
                max_new_tokens,
                sequence_start,
                sequence_end,
                skip_special_tokens=skip_special_tokens):
            status = self.state_map.get(status)
            if status < ModelStatusCode.END:
                return ''
            elif status == ModelStatusCode.END:
                self.chatbot._session.histories = (
                    self.chatbot._session.histories +
                    self.chatbot._session.prompt +
                    self.chatbot._session.response)
                # remove stop_words
                res = filter_suffix(res, self.gen_params.get('stop_words'))
                return res

    def stream_chat(self,
                    inputs: List[dict],
                    session_id: int = 2967,
                    request_id: str = '',
                    sequence_start: bool = True,
                    sequence_end: bool = True,
                    skip_special_tokens: bool = False,
                    **kwargs):
        """Start a new round conversation of a session. Return the chat
        completions in stream mode.

        Args:
            inputs (List[dict]): user's inputs in this round conversation
            session_id (int): the unique id of a session
            request_id (str): the unique id of this round conversation
            sequence_start (bool): start flag of a session
            sequence_end (bool): end flag of a session
            skip_special_tokens (bool): Whether or not to remove special
                tokens in the decoding. Default to be False.

        Returns:
            tuple(Status, str, int): status, text/chat completion,
                generated token number
        """
        from lmdeploy.serve.turbomind.chatbot import Session, get_logger
        assert isinstance(session_id, int), \
            f'INT session id is required, but got {type(session_id)}'

        self.chatbot.cfg = self._update_gen_params(**kwargs)
        max_new_tokens = self.chatbot.cfg.max_new_tokens

        logger = get_logger('service.ft', log_level=self.chatbot.log_level)
        logger.info(f'session {session_id}, request_id {request_id}, '
                    f'max_out_len {max_new_tokens}')

        if self.chatbot._session is None:
            sequence_start = True
            self.chatbot._session = Session(session_id=session_id)
        elif self.chatbot._session.status == 0:
            logger.error(f'session {session_id} has been ended. Please set '
                         f'`sequence_start` to True if you want to restart it')
            return ModelStatusCode.SESSION_CLOSED, '', 0

        self.chatbot._session.status = 1
        self.chatbot._session.request_id = request_id
        self.chatbot._session.response = ''

        prompt = self.template_parser(inputs)
        status, res, _ = None, '', 0
        for status, res, _ in self.chatbot._stream_infer(
                self.chatbot._session,
                prompt,
                max_new_tokens,
                sequence_start,
                sequence_end,
                skip_special_tokens=skip_special_tokens):
            status = self.state_map.get(status)
            # The stop symbol also appears in the output of the last
            # STREAM_ING state.
            res = filter_suffix(res, self.gen_params.get('stop_words'))
            if status < ModelStatusCode.END:
                return status, res, _
            elif status == ModelStatusCode.END:  # remove stop_words
                self.chatbot._session.histories = (
                    self.chatbot._session.histories +
                    self.chatbot._session.prompt +
                    self.chatbot._session.response)
                yield status, res, _
                break
            else:
                yield status, res, _

    def _update_gen_params(self, **kwargs):
        import mmengine
        new_gen_params = self.update_gen_params(**kwargs)
        self.gen_params['stop_words'] = new_gen_params.pop('stop_words')
        stop_words = self.chatbot._stop_words(
            self.gen_params.get('stop_words'))
        cfg = mmengine.Config(
            dict(
                session_len=self.chatbot.model.session_len,
                stop_words=stop_words,
                bad_words=self.chatbot.cfg.bad_words,
                **new_gen_params))
        return cfg
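
# Illustrative usage sketch (not executed on import). TritonClient depends on
# lmdeploy's `turbomind.chatbot`, which only exists in lmdeploy < 0.5.2; the
# address and model name below are placeholders:
#
#   llm = TritonClient('0.0.0.0:33337', model_name='internlm2-chat-7b')
#   text = llm.generate('Hello')  # non-stream completion
#   for status, chunk, _ in llm.stream_chat(
#           [dict(role='user', content='Hello')]):
#       print(status, chunk)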


class LMDeployPipeline(BaseLLM):
    """A wrapper of LMDeploy's offline inference pipeline.

    Args:
        path (str): The path to the model.
            It could be one of the following options:
                - i) A local directory path of a turbomind model which is
                    converted by `lmdeploy convert` command or downloaded
                    from ii) and iii).
                - ii) The model_id of a lmdeploy-quantized model hosted
                    inside a model repo on huggingface.co, such as
                    "InternLM/internlm-chat-20b-4bit",
                    "lmdeploy/llama2-chat-70b-4bit", etc.
                - iii) The model_id of a model hosted inside a model repo
                    on huggingface.co, such as "internlm/internlm-chat-7b",
                    "Qwen/Qwen-7B-Chat", "baichuan-inc/Baichuan2-7B-Chat"
                    and so on.
        model_name (str): needed when model_path is a pytorch model on
            huggingface.co, such as "internlm-chat-7b",
            "Qwen-7B-Chat", "Baichuan2-7B-Chat" and so on.
        tp (int): tensor parallel
        pipeline_cfg (dict): config of pipeline
    """

    def __init__(self,
                 path: str,
                 model_name: Optional[str] = None,
                 tp: int = 1,
                 pipeline_cfg=dict(),
                 **kwargs):
        import lmdeploy
        from lmdeploy import (ChatTemplateConfig, TurbomindEngineConfig,
                              pipeline, version_info)

        self.str_version = lmdeploy.__version__
        self.version = version_info
        self.do_sample = kwargs.pop('do_sample', None)
        if self.do_sample is not None and self.version < (0, 6, 0):
            raise RuntimeError(
                '`do_sample` parameter is not supported by lmdeploy until '
                f'v0.6.0, but currently using lmdeploy {self.str_version}')
        super().__init__(path=path, **kwargs)
        backend_config = copy.deepcopy(pipeline_cfg)
        backend_config.update(tp=tp)
        backend_config = {
            k: v
            for k, v in backend_config.items()
            if hasattr(TurbomindEngineConfig, k)
        }
        backend_config = TurbomindEngineConfig(**backend_config)
        chat_template_config = ChatTemplateConfig(
            model_name=model_name) if model_name else None
        self.model = pipeline(
            model_path=self.path,
            backend_config=backend_config,
            chat_template_config=chat_template_config,
            log_level='WARNING')

    def generate(self,
                 inputs: Union[str, List[str]],
                 do_preprocess: bool = None,
                 skip_special_tokens: bool = False,
                 return_dict: bool = False,
                 **kwargs):
        """Return the chat completions in non-stream mode.

        Args:
            inputs (Union[str, List[str]]): input texts to be completed.
            do_preprocess (bool): whether pre-process the messages. Default to
                True, which means chat_template will be applied.
            skip_special_tokens (bool): Whether or not to remove special
                tokens in the decoding. Default to be False.
            return_dict (bool): whether to return the full response dict(s)
                instead of the generated text(s). Default to be False.

        Returns:
            (a list of/batched) text/chat completion
        """
        from lmdeploy.messages import GenerationConfig

        batched = True
        if isinstance(inputs, str):
            inputs = [inputs]
            batched = False
        prompt = inputs
        do_sample = kwargs.pop('do_sample', None)
        gen_params = self.update_gen_params(**kwargs)

        if do_sample is None:
            do_sample = self.do_sample
        if do_sample is not None and self.version < (0, 6, 0):
            raise RuntimeError(
                '`do_sample` parameter is not supported by lmdeploy until '
                f'v0.6.0, but currently using lmdeploy {self.str_version}')
        if self.version >= (0, 6, 0):
            if do_sample is None:
                do_sample = gen_params['top_k'] > 1 or gen_params[
                    'temperature'] > 0
            gen_params.update(do_sample=do_sample)

        gen_config = GenerationConfig(
            skip_special_tokens=skip_special_tokens, **gen_params)
        response = self.model.batch_infer(
            prompt, gen_config=gen_config, do_preprocess=do_preprocess)
        texts = [resp.text for resp in response]
        # remove stop_words
        texts = filter_suffix(texts, self.gen_params.get('stop_words'))
        for resp, text in zip(response, texts):
            resp.text = text
        if batched:
            return [asdict(resp)
                    for resp in response] if return_dict else texts
        return asdict(response[0]) if return_dict else texts[0]
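
# Illustrative usage sketch (not executed on import); the model path is a
# placeholder from huggingface.co, and the generation kwargs are forwarded to
# lagent's gen_params via `update_gen_params`:
#
#   llm = LMDeployPipeline(path='internlm/internlm2_5-7b-chat', tp=1)
#   print(llm.generate('Write a haiku about GPUs.', max_new_tokens=128))
#   # a list of prompts returns a list of completions
#   print(llm.generate(['Hi', 'Hello']))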


class LMDeployServer(BaseLLM):
    """Launch an LMDeploy api_server locally and access it via its client.

    Args:
        path (str): The path to the model.
            It could be one of the following options:
                - i) A local directory path of a turbomind model which is
                    converted by `lmdeploy convert` command or downloaded
                    from ii) and iii).
                - ii) The model_id of a lmdeploy-quantized model hosted
                    inside a model repo on huggingface.co, such as
                    "InternLM/internlm-chat-20b-4bit",
                    "lmdeploy/llama2-chat-70b-4bit", etc.
                - iii) The model_id of a model hosted inside a model repo
                    on huggingface.co, such as "internlm/internlm-chat-7b",
                    "Qwen/Qwen-7B-Chat", "baichuan-inc/Baichuan2-7B-Chat"
                    and so on.
        model_name (str): needed when model_path is a pytorch model on
            huggingface.co, such as "internlm-chat-7b",
            "Qwen-7B-Chat", "Baichuan2-7B-Chat" and so on.
        server_name (str): host ip for serving
        server_port (int): server port
        tp (int): tensor parallel
        log_level (str): set log level whose value among
            [CRITICAL, ERROR, WARNING, INFO, DEBUG]
    """

    def __init__(self,
                 path: str,
                 model_name: Optional[str] = None,
                 server_name: str = '0.0.0.0',
                 server_port: int = 23333,
                 tp: int = 1,
                 log_level: str = 'WARNING',
                 serve_cfg=dict(),
                 **kwargs):
        super().__init__(path=path, **kwargs)
        self.model_name = model_name
        # TODO get_logger issue in multi processing
        import lmdeploy
        self.client = lmdeploy.serve(
            model_path=self.path,
            model_name=model_name,
            server_name=server_name,
            server_port=server_port,
            tp=tp,
            log_level=log_level,
            **serve_cfg)

    def generate(self,
                 inputs: Union[str, List[str]],
                 session_id: int = 2967,
                 sequence_start: bool = True,
                 sequence_end: bool = True,
                 ignore_eos: bool = False,
                 skip_special_tokens: Optional[bool] = False,
                 timeout: int = 30,
                 **kwargs) -> List[str]:
        """Start a new round conversation of a session. Return the chat
        completions in non-stream mode.

        Args:
            inputs (str, List[str]): user's prompt(s) in this round
            session_id (int): the unique id of a session
            sequence_start (bool): start flag of a session
            sequence_end (bool): end flag of a session
            ignore_eos (bool): indicator for ignoring eos
            skip_special_tokens (bool): Whether or not to remove special
                tokens in the decoding. Default to be False.
            timeout (int): max time to wait for response

        Returns:
            (a list of/batched) text/chat completion
        """
        batched = True
        if isinstance(inputs, str):
            inputs = [inputs]
            batched = False

        gen_params = self.update_gen_params(**kwargs)
        max_new_tokens = gen_params.pop('max_new_tokens')
        gen_params.update(max_tokens=max_new_tokens)

        resp = [''] * len(inputs)
        for text in self.client.completions_v1(
                self.model_name,
                inputs,
                session_id=session_id,
                sequence_start=sequence_start,
                sequence_end=sequence_end,
                stream=False,
                ignore_eos=ignore_eos,
                skip_special_tokens=skip_special_tokens,
                timeout=timeout,
                **gen_params):
            resp = [
                resp[i] + item['text']
                for i, item in enumerate(text['choices'])
            ]
        # remove stop_words
        resp = filter_suffix(resp, self.gen_params.get('stop_words'))
        if not batched:
            return resp[0]
        return resp

    def stream_chat(self,
                    inputs: List[dict],
                    session_id=0,
                    sequence_start: bool = True,
                    sequence_end: bool = True,
                    stream: bool = True,
                    ignore_eos: bool = False,
                    skip_special_tokens: Optional[bool] = False,
                    timeout: int = 30,
                    **kwargs):
        """Start a new round conversation of a session. Return the chat
        completions in stream mode.

        Args:
            inputs (List[dict]): user's inputs in this round conversation
            session_id (int): the unique id of a session
            sequence_start (bool): start flag of a session
            sequence_end (bool): end flag of a session
            stream (bool): return in a streaming format if enabled
            ignore_eos (bool): indicator for ignoring eos
            skip_special_tokens (bool): Whether or not to remove special
                tokens in the decoding. Default to be False.
            timeout (int): max time to wait for response

        Returns:
            tuple(Status, str, int): status, text/chat completion,
                generated token number
        """
        gen_params = self.update_gen_params(**kwargs)
        max_new_tokens = gen_params.pop('max_new_tokens')
        gen_params.update(max_tokens=max_new_tokens)
        prompt = self.template_parser(inputs)

        resp = ''
        finished = False
        stop_words = self.gen_params.get('stop_words')
        for text in self.client.completions_v1(
                self.model_name,
                prompt,
                session_id=session_id,
                sequence_start=sequence_start,
                sequence_end=sequence_end,
                stream=stream,
                ignore_eos=ignore_eos,
                skip_special_tokens=skip_special_tokens,
                timeout=timeout,
                **gen_params):
            resp += text['choices'][0]['text']
            if not resp:
                continue
            # remove stop_words; guard against stop_words being None
            for sw in stop_words or []:
                if sw in resp:
                    resp = filter_suffix(resp, stop_words)
                    finished = True
                    break
            yield ModelStatusCode.STREAM_ING, resp, None
            if finished:
                break
        yield ModelStatusCode.END, resp, None
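
# Illustrative usage sketch (not executed on import). `LMDeployServer` starts a
# local api_server via `lmdeploy.serve`; the model path and port below are
# placeholders:
#
#   llm = LMDeployServer(path='internlm/internlm2_5-7b-chat',
#                        server_port=23333)
#   print(llm.generate('Hello'))
#   for status, text, _ in llm.stream_chat(
#           [dict(role='user', content='Hello')]):
#       print(status, text)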


class LMDeployClient(LMDeployServer):
    """Client of an already running LMDeploy api_server. It reuses
    ``LMDeployServer``'s request logic without launching a server.

    Args:
        url (str): communicating address 'http://<ip>:<port>' of
            api_server
        model_name (str): needed when model_path is a pytorch model on
            huggingface.co, such as "internlm-chat-7b",
            "Qwen-7B-Chat", "Baichuan2-7B-Chat" and so on.
    """

    def __init__(self, url: str, model_name: str, **kwargs):
        BaseLLM.__init__(self, path=url, **kwargs)
        from lmdeploy.serve.openai.api_client import APIClient
        self.client = APIClient(url)
        self.model_name = model_name
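
# Illustrative usage sketch (not executed on import); the URL and model name
# must match an api_server you have already launched, e.g. with
# `lmdeploy serve api_server <model>`:
#
#   llm = LMDeployClient(
#       url='http://127.0.0.1:23333', model_name='internlm2_5-7b-chat')
#   print(llm.generate('Hello'))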


class AsyncLMDeployPipeline(AsyncLLMMixin, LMDeployPipeline):
    """Asynchronous counterpart of ``LMDeployPipeline``.

    Args:
        path (str): The path to the model.
            It could be one of the following options:
                - i) A local directory path of a turbomind model which is
                    converted by `lmdeploy convert` command or downloaded
                    from ii) and iii).
                - ii) The model_id of a lmdeploy-quantized model hosted
                    inside a model repo on huggingface.co, such as
                    "InternLM/internlm-chat-20b-4bit",
                    "lmdeploy/llama2-chat-70b-4bit", etc.
                - iii) The model_id of a model hosted inside a model repo
                    on huggingface.co, such as "internlm/internlm-chat-7b",
                    "Qwen/Qwen-7B-Chat", "baichuan-inc/Baichuan2-7B-Chat"
                    and so on.
        model_name (str): needed when model_path is a pytorch model on
            huggingface.co, such as "internlm-chat-7b",
            "Qwen-7B-Chat", "Baichuan2-7B-Chat" and so on.
        tp (int): tensor parallel
        pipeline_cfg (dict): config of pipeline
    """

    async def generate(self,
                       inputs: Union[str, List[str]],
                       session_ids: Union[int, List[int]] = None,
                       do_preprocess: bool = None,
                       skip_special_tokens: bool = False,
                       return_dict: bool = False,
                       **kwargs):
        """Return the chat completions in non-stream mode.

        Args:
            inputs (Union[str, List[str]]): input texts to be completed.
            session_ids (Union[int, List[int]]): session id(s) matched with
                the inputs; defaults to 0, 1, 2, ... when not given.
            do_preprocess (bool): whether pre-process the messages. Default to
                True, which means chat_template will be applied.
            skip_special_tokens (bool): Whether or not to remove special
                tokens in the decoding. Default to be False.
            return_dict (bool): whether to return the full response dict(s)
                instead of the generated text(s). Default to be False.

        Returns:
            (a list of/batched) text/chat completion
        """
        from lmdeploy.messages import GenerationConfig, Response

        batched = True
        if isinstance(inputs, str):
            inputs = [inputs]
            batched = False
        if session_ids is None:
            session_ids = list(range(len(inputs)))
        elif isinstance(session_ids, (int, str)):
            session_ids = [session_ids]
        assert len(inputs) == len(session_ids)

        prompt = inputs
        gen_params = self.update_gen_params(**kwargs)
        gen_config = GenerationConfig(
            skip_special_tokens=skip_special_tokens, **gen_params)

        async def _inner_generate(uid, text):
            resp = Response('', 0, 0, uid)
            async for out in self.model.generate(
                    text,
                    uid,
                    gen_config,
                    stream_response=True,
                    sequence_start=True,
                    sequence_end=True,
                    do_preprocess=do_preprocess,
                    **kwargs):
                resp.text += out.response
                resp.generate_token_len = out.generate_token_len
                resp.input_token_len = out.input_token_len
                resp.finish_reason = out.finish_reason
                if out.token_ids:
                    resp.token_ids.extend(out.token_ids)
                if out.logprobs:
                    if resp.logprobs is None:
                        resp.logprobs = []
                    resp.logprobs.extend(out.logprobs)
            return resp

        response = await asyncio.gather(*[
            _inner_generate(sid, inp) for sid, inp in zip(session_ids, prompt)
        ])
        texts = [resp.text for resp in response]
        # remove stop_words
        texts = filter_suffix(texts, self.gen_params.get('stop_words'))
        for resp, text in zip(response, texts):
            resp.text = text
        if batched:
            return [asdict(resp)
                    for resp in response] if return_dict else texts
        return asdict(response[0]) if return_dict else texts[0]
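
# Illustrative usage sketch (not executed on import); the model path is a
# placeholder and the coroutine must run inside an event loop:
#
#   llm = AsyncLMDeployPipeline(path='internlm/internlm2_5-7b-chat', tp=1)
#   texts = asyncio.run(llm.generate(['Hi', 'Hello'], session_ids=[0, 1]))
#   print(texts)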


class AsyncLMDeployServer(AsyncLLMMixin, LMDeployServer):
    """Asynchronous counterpart of ``LMDeployServer``.

    Args:
        path (str): The path to the model.
            It could be one of the following options:
                - i) A local directory path of a turbomind model which is
                    converted by `lmdeploy convert` command or downloaded
                    from ii) and iii).
                - ii) The model_id of a lmdeploy-quantized model hosted
                    inside a model repo on huggingface.co, such as
                    "InternLM/internlm-chat-20b-4bit",
                    "lmdeploy/llama2-chat-70b-4bit", etc.
                - iii) The model_id of a model hosted inside a model repo
                    on huggingface.co, such as "internlm/internlm-chat-7b",
                    "Qwen/Qwen-7B-Chat", "baichuan-inc/Baichuan2-7B-Chat"
                    and so on.
        model_name (str): needed when model_path is a pytorch model on
            huggingface.co, such as "internlm-chat-7b",
            "Qwen-7B-Chat", "Baichuan2-7B-Chat" and so on.
        server_name (str): host ip for serving
        server_port (int): server port
        tp (int): tensor parallel
        log_level (str): set log level whose value among
            [CRITICAL, ERROR, WARNING, INFO, DEBUG]
    """

    async def generate(
        self,
        inputs: Union[str, List[str]],
        session_ids: Union[int, List[int]] = None,
        sequence_start: bool = True,
        sequence_end: bool = True,
        ignore_eos: bool = False,
        skip_special_tokens: Optional[bool] = False,
        timeout: int = 30,
        **kwargs,
    ):
        """Start a new round conversation of a session. Return the chat
        completions in non-stream mode.

        Args:
            inputs (str, List[str]): user's prompt(s) in this round
            session_ids (int, List[int]): session id(s)
            sequence_start (bool): start flag of a session
            sequence_end (bool): end flag of a session
            ignore_eos (bool): indicator for ignoring eos
            skip_special_tokens (bool): Whether or not to remove special
                tokens in the decoding. Default to be False.
            timeout (int): max time to wait for response

        Returns:
            (a list of/batched) text/chat completion
        """
        from lmdeploy.serve.openai.api_client import json_loads

        batched = True
        if isinstance(inputs, str):
            inputs = [inputs]
            batched = False

        gen_params = self.update_gen_params(**kwargs)
        max_new_tokens = gen_params.pop('max_new_tokens')
        gen_params.update(max_tokens=max_new_tokens)

        responses = [''] * len(inputs)
        pload = dict(
            model=self.model_name,
            prompt=inputs,
            sequence_start=sequence_start,
            sequence_end=sequence_end,
            stream=False,
            ignore_eos=ignore_eos,
            skip_special_tokens=skip_special_tokens,
            timeout=timeout,
            **gen_params)
        async with aiohttp.ClientSession(
                timeout=aiohttp.ClientTimeout(3 * 3600)) as session:
            async with session.post(
                    self.client.completions_v1_url,
                    headers=self.client.headers,
                    json=pload) as resp:
                async for chunk in resp.content:
                    if chunk:
                        decoded = chunk.decode('utf-8')
                        output = json_loads(decoded)
                        responses = [
                            response + item['text'] for response, item in zip(
                                responses, output['choices'])
                        ]
        # remove stop_words
        responses = filter_suffix(responses,
                                  self.gen_params.get('stop_words'))
        if not batched:
            return responses[0]
        return responses

    async def stream_chat(
        self,
        inputs: List[dict],
        session_id: int = None,
        sequence_start: bool = True,
        sequence_end: bool = True,
        stream: bool = True,
        ignore_eos: bool = False,
        skip_special_tokens: Optional[bool] = False,
        timeout: int = 30,
        **kwargs,
    ):
        """Start a new round conversation of a session. Return the chat
        completions in stream mode.

        Args:
            inputs (List[dict]): user's inputs in this round conversation
            session_id (int): session id
            sequence_start (bool): start flag of a session
            sequence_end (bool): end flag of a session
            stream (bool): return in a streaming format if enabled
            ignore_eos (bool): indicator for ignoring eos
            skip_special_tokens (bool): Whether or not to remove special
                tokens in the decoding. Default to be False.
            timeout (int): max time to wait for response

        Returns:
            tuple(Status, str, int): status, text/chat completion,
                generated token number
        """
        from lmdeploy.serve.openai.api_client import json_loads

        gen_params = self.update_gen_params(**kwargs)
        max_new_tokens = gen_params.pop('max_new_tokens')
        gen_params.update(max_tokens=max_new_tokens)
        prompt = self.template_parser(inputs)

        response = ''
        finished = False
        stop_words = self.gen_params.get('stop_words')

        pload = dict(
            model=self.model_name,
            prompt=prompt,
            sequence_start=sequence_start,
            sequence_end=sequence_end,
            stream=stream,
            ignore_eos=ignore_eos,
            skip_special_tokens=skip_special_tokens,
            timeout=timeout,
            **gen_params)
        async with aiohttp.ClientSession(
                timeout=aiohttp.ClientTimeout(3 * 3600)) as session:
            async with session.post(
                    self.client.completions_v1_url,
                    headers=self.client.headers,
                    json=pload) as resp:
                async for chunk in resp.content:
                    if chunk:
                        decoded = chunk.decode('utf-8')
                        if (not decoded.strip()
                                or decoded.rstrip() == 'data: [DONE]'):
                            continue
                        if decoded[:6] == 'data: ':
                            decoded = decoded[6:]
                        output = json_loads(decoded)
                        response += output['choices'][0]['text']
                        if not response:
                            continue
                        # remove stop_words; stop_words may be None
                        for sw in stop_words or []:
                            if sw in response:
                                response = filter_suffix(response, stop_words)
                                finished = True
                                break
                        yield ModelStatusCode.STREAM_ING, response, None
                        if finished:
                            break
        yield ModelStatusCode.END, response, None


class AsyncLMDeployClient(AsyncLMDeployServer):
    """Asynchronous client of an already running LMDeploy api_server.

    Args:
        url (str): communicating address 'http://<ip>:<port>' of
            api_server
        model_name (str): needed when model_path is a pytorch model on
            huggingface.co, such as "internlm-chat-7b",
            "Qwen-7B-Chat", "Baichuan2-7B-Chat" and so on.
    """

    def __init__(self, url: str, model_name: str, **kwargs):
        BaseLLM.__init__(self, path=url, **kwargs)
        from lmdeploy.serve.openai.api_client import APIClient
        self.client = APIClient(url)
        self.model_name = model_name
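
# Illustrative usage sketch (not executed on import); the URL and model name
# must point at an api_server you have already launched:
#
#   async def demo():
#       llm = AsyncLMDeployClient(
#           url='http://127.0.0.1:23333', model_name='internlm2_5-7b-chat')
#       async for status, text, _ in llm.stream_chat(
#               [dict(role='user', content='Hello')]):
#           print(status, text)
#
#   asyncio.run(demo())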