from typing import Any, Dict, List, Optional, Sequence

import torch
import transformers

from .constants import IGNORE_INDEX, SENTINEL_TOKEN
from .conversation import SeparatorStyle, default_conversation
from .mm_utils import tokenizer_image_token

# A fixed dummy conversation, used by infer_stop_tokens() to probe how the
# chat template renders and terminates assistant turns.
DUMMY_CONVERSATION = [
    {"from": "human", "value": "question"},
    {"from": "gpt", "value": "answer"},
] * 10

def tokenize_conversation_legacy(
    messages: Sequence[Dict[str, str]],
    tokenizer: transformers.PreTrainedTokenizer,
    add_generation_prompt: bool = False,
    overrides: Optional[Dict[str, str]] = None,
    no_system_prompt: bool = False,
) -> torch.Tensor:
    conv = default_conversation.copy()
    roles = {"human": conv.roles[0], "gpt": conv.roles[1]}

    if no_system_prompt:
        conv.system = ""

    # Copy so the caller's message list is never mutated below.
    messages = list(messages)

    # Skip the first message if it is not from the human side.
    if messages[0]["from"] != "human":
        messages = messages[1:]

    # Append an empty assistant turn so the prompt ends where generation should start.
    if add_generation_prompt:
        messages.append({"from": "gpt", "value": None})

    conv.messages = []
    for turn, message in enumerate(messages):
        role = roles[message["from"]]
        assert role == conv.roles[turn % 2], "conversation roles must alternate between human and gpt"
        if overrides is not None and message["from"] in overrides:
            conv.append_message(role, overrides[message["from"]])
        else:
            conv.append_message(role, message["value"])

    return tokenizer_image_token(conv.get_prompt(), tokenizer, return_tensors="pt")

def tokenize_conversation(
    messages: Sequence[Dict[str, str]],
    tokenizer: transformers.PreTrainedTokenizer,
    add_generation_prompt: bool = False,
    overrides: Optional[Dict[str, str]] = None,
    no_system_prompt: bool = False,
) -> torch.Tensor:
    # Normalize the conversation in place before tokenization.
    for message in messages:
        message["value"] = message["value"].strip()

    # Use the legacy conversation-template path unless the separator style is AUTO,
    # i.e., unless we should defer to the tokenizer's own chat template.
    if default_conversation.sep_style != SeparatorStyle.AUTO:
        return tokenize_conversation_legacy(
            messages,
            tokenizer,
            add_generation_prompt=add_generation_prompt,
            overrides=overrides,
            no_system_prompt=no_system_prompt,
        )

    # Convert the {"from": ..., "value": ...} schema to the chat-template schema.
    conversation = []
    for m in messages:
        message = {}
        if m["from"] == "human":
            message["role"] = "user"
        elif m["from"] == "gpt":
            message["role"] = "assistant"
        else:
            raise ValueError(f"Unexpected sender '{m['from']}' in conversation entry.")

        message["content"] = m["value"]
        if overrides is not None and m["from"] in overrides:
            message["content"] = overrides[m["from"]]
        conversation.append(message)

    if no_system_prompt:
        conversation = [{"role": "system", "content": ""}] + conversation

    text = tokenizer.apply_chat_template(
        conversation,
        add_generation_prompt=add_generation_prompt,
        tokenize=False,
    )
    return tokenizer_image_token(text, tokenizer, return_tensors="pt")
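
# Usage sketch (not executed): how a caller might tokenize a single turn for
# generation. The checkpoint path is a placeholder, and this assumes a tokenizer
# that ships a chat template handled by the SeparatorStyle.AUTO branch above.
#
#   tokenizer = transformers.AutoTokenizer.from_pretrained("path/to/checkpoint")
#   input_ids = tokenize_conversation(
#       [{"from": "human", "value": "Describe <image> briefly."}],
#       tokenizer,
#       add_generation_prompt=True,
#   )  # 1-D LongTensor; "<image>" placeholders are expanded by tokenizer_image_token.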

def _maybe_add_sentinel_token(tokenizer: transformers.PreTrainedTokenizer) -> None:
    # Register the sentinel as a special token exactly once per tokenizer instance.
    if not hasattr(tokenizer, "sentinel_token"):
        tokenizer.add_tokens([SENTINEL_TOKEN], special_tokens=True)
        tokenizer.sentinel_token = SENTINEL_TOKEN
        tokenizer.sentinel_token_id = tokenizer.convert_tokens_to_ids(SENTINEL_TOKEN)

def preprocess_conversation(
    conversation: Sequence[Dict[str, str]],
    tokenizer: transformers.PreTrainedTokenizer,
    no_system_prompt: bool = False,
    retried: bool = False,
) -> Dict[str, Any]:
    inputs = tokenize_conversation(conversation, tokenizer, no_system_prompt=no_system_prompt)
    labels = torch.ones_like(inputs) * IGNORE_INDEX  # start with everything masked

    # Build a template by replacing every assistant response with the sentinel token.
    _maybe_add_sentinel_token(tokenizer)
    template = tokenize_conversation(
        conversation, tokenizer, overrides={"gpt": SENTINEL_TOKEN}, no_system_prompt=no_system_prompt
    )

    # Remove each sentinel (and the token that follows it) from the template.
    mask = torch.ones_like(template, dtype=torch.bool)
    for k in range(template.size(0) - 1):
        if template[k] == tokenizer.sentinel_token_id:
            mask[k : k + 2] = False
            # On retry, also drop the token before the sentinel: some tokenizers
            # merge it with the first response token, which breaks the match below.
            if k > 0 and retried:
                mask[k - 1] = False
    template = template[mask]

    # Match the tokenized conversation against the template (which now contains no
    # assistant responses). Every token that fails to match belongs to an assistant
    # response and is kept in the labels for training.
    p = 0
    for k in range(inputs.size(0)):
        if p < template.size(0) and inputs[k] == template[p]:
            p += 1
        else:
            labels[k] = inputs[k]

    # If the template was not fully matched, retry once with the extra masking above;
    # if that also fails, mask out the entire label.
    if p < template.size(0):
        if not retried:
            return preprocess_conversation(
                conversation,
                tokenizer,
                no_system_prompt=no_system_prompt,
                retried=True,
            )
        print(f"Failed to process the conversation: '{conversation}'. All tokens will be masked in the label.")
        labels[:] = IGNORE_INDEX

    return {"input_ids": inputs, "labels": labels}
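
# Illustrative sketch of the contract above (values are schematic, not real ids):
#
#   out = preprocess_conversation(
#       [{"from": "human", "value": "question"}, {"from": "gpt", "value": "answer"}],
#       tokenizer,
#   )
#   out["input_ids"]  # all tokens of the rendered conversation
#   out["labels"]     # same length; prompt/template positions hold IGNORE_INDEX,
#                     # assistant-response positions hold the matching input ids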

def infer_stop_tokens(tokenizer: transformers.PreTrainedTokenizer) -> List[str]:
    _maybe_add_sentinel_token(tokenizer)
    template = tokenize_conversation(DUMMY_CONVERSATION, tokenizer, overrides={"gpt": SENTINEL_TOKEN})

    stop_tokens = {tokenizer.eos_token}
    for k in range(template.size(0) - 1):
        if template[k] == tokenizer.sentinel_token_id:
            # The token right after a sentinel is what terminates an assistant turn,
            # so it doubles as a stop token for generation.
            stop_token = tokenizer.decode(template[k + 1])
            stop_tokens.add(stop_token)
    return list(stop_tokens)
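
if __name__ == "__main__":
    # Minimal smoke test, a sketch only: the checkpoint path is a placeholder and
    # must point at a tokenizer with a chat template for the AUTO branch to apply.
    # Run with `python -m <package>.<module>` so the relative imports resolve.
    tok = transformers.AutoTokenizer.from_pretrained("path/to/checkpoint")
    print("stop tokens:", infer_stop_tokens(tok))

    sample = [
        {"from": "human", "value": "question"},
        {"from": "gpt", "value": "answer"},
    ]
    out = preprocess_conversation(sample, tok)
    print("input_ids:", tuple(out["input_ids"].shape))
    print("masked label positions:", int((out["labels"] == IGNORE_INDEX).sum()))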