Upload nemotron_toolcall_parser_no_streaming.py
Browse files
    	
        nemotron_toolcall_parser_no_streaming.py
    ADDED
    
    | @@ -0,0 +1,110 @@ | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | 
|  | |
| 1 | 
            +
            # SPDX-License-Identifier: Apache-2.0
         | 
| 2 | 
            +
             | 
| 3 | 
            +
            import ast
         | 
| 4 | 
            +
            import json
         | 
| 5 | 
            +
            import re
         | 
| 6 | 
            +
            from collections.abc import Sequence
         | 
| 7 | 
            +
            from typing import Union
         | 
| 8 | 
            +
             | 
| 9 | 
            +
            import partial_json_parser
         | 
| 10 | 
            +
            from partial_json_parser.core.options import Allow
         | 
| 11 | 
            +
             | 
| 12 | 
            +
            from vllm.entrypoints.openai.protocol import (
         | 
| 13 | 
            +
                ChatCompletionRequest,
         | 
| 14 | 
            +
                DeltaFunctionCall, DeltaMessage,
         | 
| 15 | 
            +
                DeltaToolCall,
         | 
| 16 | 
            +
                ExtractedToolCallInformation,
         | 
| 17 | 
            +
                FunctionCall,
         | 
| 18 | 
            +
                ToolCall,
         | 
| 19 | 
            +
            )
         | 
| 20 | 
            +
            from vllm.entrypoints.openai.tool_parsers.abstract_tool_parser import (
         | 
| 21 | 
            +
                ToolParser,
         | 
| 22 | 
            +
                ToolParserManager,
         | 
| 23 | 
            +
            )
         | 
| 24 | 
            +
            from vllm.logger import init_logger
         | 
| 25 | 
            +
            from vllm.transformers_utils.tokenizer import AnyTokenizer
         | 
| 26 | 
            +
            from vllm.utils import random_uuid
         | 
| 27 | 
            +
             | 
| 28 | 
            +
            logger = init_logger(__name__)
         | 
| 29 | 
            +
             | 
| 30 | 
            +
             | 
| 31 | 
            +
            @ToolParserManager.register_module("nemotron_json")
            class NemotronJSONToolParser(ToolParser):
                """Non-streaming parser for ``<TOOLCALL>...</TOOLCALL>`` model output.

                The text between the two tags is treated as a JSON array of objects,
                each carrying a ``"name"`` and an ``"arguments"`` entry. The parser is
                registered with ``ToolParserManager`` under the name
                ``"nemotron_json"``. Streaming extraction is not implemented.
                """

                def __init__(self, tokenizer: AnyTokenizer):
                    """Initialise parser state and the tag-matching regex.

                    Args:
                        tokenizer: Tokenizer forwarded unchanged to ``ToolParser``.
                    """
                    super().__init__(tokenizer)

                    # NOTE(review): these four attributes are never read in this
                    # class; presumably they mirror state expected by the
                    # ToolParser base class / streaming callers -- confirm.
                    self.current_tool_name_sent: bool = False
                    self.prev_tool_call_arr: list[dict] = []
                    self.current_tool_id: int = -1
                    self.streamed_args_for_tool: list[str] = []

                    self.tool_call_start_token: str = "<TOOLCALL>"
                    self.tool_call_end_token: str = "</TOOLCALL>"

                    # Non-greedy + DOTALL: capture the shortest payload between the
                    # tags, allowing it to span multiple lines.
                    self.tool_call_regex = re.compile(
                        r"<TOOLCALL>(.*?)</TOOLCALL>", re.DOTALL)

                def extract_tool_calls(
                    self,
                    model_output: str,
                    request: ChatCompletionRequest,
                ) -> ExtractedToolCallInformation:
                    """Extract tool calls from a complete (non-streamed) model response.

                    Args:
                        model_output: Full text generated by the model.
                        request: The originating chat request (not used here).

                    Returns:
                        ``ExtractedToolCallInformation`` with ``tools_called=False``
                        and the untouched ``model_output`` as content when no
                        ``<TOOLCALL>`` tag is present or when the payload cannot be
                        parsed. Otherwise ``tools_called=True``, the successfully
                        built tool calls, and as content the text preceding the last
                        ``<TOOLCALL>`` tag (``None`` if that text is empty).
                    """
                    if self.tool_call_start_token not in model_output:
                        return ExtractedToolCallInformation(
                            tools_called=False,
                            tool_calls=[],
                            content=model_output,
                        )

                    try:
                        # Only the first complete <TOOLCALL>...</TOOLCALL> block is
                        # parsed. A missing closing tag yields no match; the
                        # resulting IndexError is handled by the fallback below.
                        payload = self.tool_call_regex.findall(model_output)[0].strip()

                        # Tolerate payloads where the surrounding list brackets were
                        # omitted. The closing bracket must be appended (it was
                        # previously prepended, which always produced invalid JSON).
                        if not payload.startswith("["):
                            payload = "[" + payload
                        if not payload.endswith("]"):
                            payload = payload + "]"

                        parsed_calls = json.loads(payload)

                        tool_calls = []
                        for raw_call in parsed_calls:
                            try:
                                raw_args = raw_call["arguments"]
                                # Dict arguments are serialised to a JSON string;
                                # any other value is passed through as-is.
                                if isinstance(raw_args, dict):
                                    arguments = json.dumps(raw_args,
                                                           ensure_ascii=False)
                                else:
                                    arguments = raw_args
                                tool_calls.append(
                                    ToolCall(
                                        type="function",
                                        function=FunctionCall(
                                            name=raw_call["name"],
                                            arguments=arguments,
                                        ),
                                    ))
                            except Exception:
                                # Best effort: one malformed entry must not discard
                                # the remaining calls, so skip it -- but leave a
                                # trace instead of swallowing it silently.
                                logger.warning(
                                    "Skipping malformed tool call entry: %r",
                                    raw_call)
                                continue

                        # Everything before the last start tag is ordinary content.
                        content = model_output[:model_output.rfind(
                            self.tool_call_start_token)]

                        return ExtractedToolCallInformation(
                            tools_called=True,
                            tool_calls=tool_calls,
                            content=content if content else None,
                        )

                    except Exception:
                        # Boundary handler: any parsing failure degrades to
                        # "no tool call" and returns the raw text to the caller.
                        logger.exception(
                            "Error in extracting tool call from response. "
                            "Response: %s", model_output)
                        return ExtractedToolCallInformation(
                            tools_called=False,
                            tool_calls=[],
                            content=model_output,
                        )

                def extract_tool_calls_streaming(
                    self,
                    previous_text: str,
                    current_text: str,
                    delta_text: str,
                    previous_token_ids: Sequence[int],
                    current_token_ids: Sequence[int],
                    delta_token_ids: Sequence[int],
                    request: ChatCompletionRequest,
                ) -> Union[DeltaMessage, None]:
                    """Streaming extraction entry point; unsupported by this parser.

                    Raises:
                        NotImplementedError: Always.
                    """
                    raise NotImplementedError(
                        "Tool calling is not supported in streaming mode!")
         | 
