|
|
""" |
|
|
Based on LLaVA v1.5: https://github.com/haotian-liu/LLaVA |
|
|
Modified for SafeLLaVA |
|
|
|
|
|
Original LLaVA License: Apache License 2.0 |
|
|
""" |
|
|
|
|
|
""" |
|
|
Conversation prompts for SafeLLaVA. |
|
|
|
|
|
This is a simplified version containing only the llava_v1 conversation template. |
|
|
""" |
|
|
|
|
|
import dataclasses |
|
|
from enum import auto, Enum |
|
|
from typing import List |
|
|
|
|
|
|
|
|
class SeparatorStyle(Enum):
    """Separator styles understood by ``Conversation.get_prompt``.

    Only ``TWO`` exists in this simplified module: rendered turns are
    terminated alternately by ``Conversation.sep`` (even-indexed messages)
    and ``Conversation.sep2`` (odd-indexed messages).
    """

    TWO = auto()


@dataclasses.dataclass
class Conversation:
    """A class that keeps all conversation history and renders it as a prompt.

    Attributes:
        system: Text placed at the very start of every rendered prompt.
        roles: Role names for callers to use (the bundled template uses
            ``("USER", "ASSISTANT")``). Stored and copied only; ``get_prompt``
            takes each role from the message entries instead.
        messages: History of ``[role, message]`` entries (see ``get_prompt``
            for the accepted message forms). Must be a list for
            ``append_message`` to work; ``copy()`` always produces a list.
        offset: Stored and copied; not read anywhere in this module.
        sep_style: Rendering style; only ``SeparatorStyle.TWO`` is supported.
        sep: Terminator after the system text and after even-indexed turns.
        sep2: Terminator after odd-indexed turns.
        version: Stored and copied; not read anywhere in this module.
    """

    system: str
    roles: List[str]
    messages: List[List[str]]
    offset: int
    sep_style: SeparatorStyle = SeparatorStyle.TWO
    sep: str = " "
    sep2: str = "</s>"
    version: str = "v1"

    def get_prompt(self) -> str:
        """Render the conversation history into a single prompt string.

        The output is ``system + sep`` followed by one segment per entry:

        * a truthy message renders as ``"role: message"`` plus ``sep`` (even
          index) or ``sep2`` (odd index);
        * a falsy message (e.g. ``None``) renders as a bare ``"role:"`` with
          no separator, which leaves an open turn at the end of the prompt.

        If the first entry's message is a tuple, its first element is taken
        as the text. Every ``<image>`` token is removed from that text, and a
        single ``<image>`` line is prepended. Tuple messages at later
        positions must have exactly three elements, and only the first
        (the text) is rendered. ``self.messages`` is never modified; list-
        and tuple-based histories are both accepted.

        Returns:
            The rendered prompt string.

        Raises:
            ValueError: If ``sep_style`` is not ``SeparatorStyle.TWO``.
        """
        messages = self.messages

        if messages and isinstance(messages[0][1], tuple):
            # Shallow copy so the stored history is untouched. ``list(...)``
            # is used instead of ``.copy()`` because tuples (e.g. the
            # template's ``messages=()``) have no ``copy`` method.
            messages = list(messages)
            # Plain unpacking works for list or tuple rows; no copy is needed
            # because the row itself is never mutated.
            init_role, init_msg = messages[0]
            init_text = init_msg[0].replace("<image>", "").strip()
            messages[0] = (init_role, "<image>\n" + init_text)

        if self.sep_style == SeparatorStyle.TWO:
            seps = (self.sep, self.sep2)
            parts = [self.system, seps[0]]
            for i, (role, message) in enumerate(messages):
                if message:
                    if isinstance(message, tuple):
                        # Only the text element of a 3-tuple is rendered.
                        message, _, _ = message
                    parts.append(role + ": " + message + seps[i % 2])
                else:
                    # Empty/None message: open turn, no trailing separator.
                    parts.append(role + ":")
            return "".join(parts)

        raise ValueError(f"Invalid style: {self.sep_style}")

    def append_message(self, role, message) -> None:
        """Append a ``[role, message]`` entry to the conversation history.

        ``message`` may be ``None`` to leave an open turn (see ``get_prompt``).
        Requires ``self.messages`` to be a list. Call ``copy()`` first on a
        template whose ``messages`` is a tuple.
        """
        self.messages.append([role, message])

    def copy(self) -> "Conversation":
        """Return a new Conversation with an independent, list-based history.

        Each ``[role, message]`` entry is rebuilt as a fresh list. Appending to
        or editing entries of the copy therefore does not affect this
        instance. Message payloads and ``roles`` are shared, not deep-copied.
        """
        return Conversation(
            system=self.system,
            roles=self.roles,
            messages=[[role, msg] for role, msg in self.messages],
            offset=self.offset,
            sep_style=self.sep_style,
            sep=self.sep,
            sep2=self.sep2,
            version=self.version,
        )
|
|
|
|
|
|
|
|
|
|
|
# Shared LLaVA v1 prompt template. Its history is an empty *tuple*, so
# append_message() cannot be called on this object directly. Take a .copy()
# first, which yields an independent list-based history.
conv_llava_v1 = Conversation(
    system=(
        "A chat between a curious human and an artificial intelligence assistant. "
        "The assistant gives helpful, detailed, and polite answers to the human's questions."
    ),
    roles=("USER", "ASSISTANT"),
    messages=(),
    offset=0,
    # Turn terminators: " " after the system text / even turns, "</s>" after odd turns.
    sep_style=SeparatorStyle.TWO,
    sep=" ",
    sep2="</s>",
    version="v1",
)
|
|
|
|
|
|
|
|
# Registry of conversation templates, keyed by name.
conv_templates = {"llava_v1": conv_llava_v1}
# "default" is an alias for the same Conversation object, not a separate copy.
conv_templates["default"] = conv_templates["llava_v1"]
|
|
|