Commit
·
6a1fd9e
1
Parent(s):
7134600
Upload h2oai_pipeline.py
Browse files- h2oai_pipeline.py +648 -2
h2oai_pipeline.py
CHANGED
|
@@ -1,8 +1,8 @@
|
|
| 1 |
from transformers import TextGenerationPipeline
|
| 2 |
from transformers.pipelines.text_generation import ReturnType
|
| 3 |
|
| 4 |
-
|
| 5 |
-
|
| 6 |
|
| 7 |
|
| 8 |
class H2OTextGenerationPipeline(TextGenerationPipeline):
|
|
@@ -126,3 +126,649 @@ class H2OTextGenerationPipeline(TextGenerationPipeline):
|
|
| 126 |
else:
|
| 127 |
raise ValueError("TF not avaialble.")
|
| 128 |
return {"generated_sequence": generated_sequence, "input_ids": input_ids, "prompt_text": prompt_text}
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
from transformers import TextGenerationPipeline
|
| 2 |
from transformers.pipelines.text_generation import ReturnType
|
| 3 |
|
| 4 |
+
|
| 5 |
+
|
| 6 |
|
| 7 |
|
| 8 |
class H2OTextGenerationPipeline(TextGenerationPipeline):
|
|
|
|
| 126 |
else:
|
| 127 |
raise ValueError("TF not avaialble.")
|
| 128 |
return {"generated_sequence": generated_sequence, "input_ids": input_ids, "prompt_text": prompt_text}
|
| 129 |
+
import torch
|
| 130 |
+
from transformers import StoppingCriteria, StoppingCriteriaList
|
| 131 |
+
|
| 132 |
+
|
| 133 |
+
|
| 134 |
+
class StoppingCriteriaSub(StoppingCriteria):
|
| 135 |
+
|
| 136 |
+
def __init__(self, stops=[], encounters=[], device="cuda"):
|
| 137 |
+
super().__init__()
|
| 138 |
+
assert len(stops) % len(encounters) == 0, "Number of stops and encounters must match"
|
| 139 |
+
self.encounters = encounters
|
| 140 |
+
self.stops = [stop.to(device) for stop in stops]
|
| 141 |
+
self.num_stops = [0] * len(stops)
|
| 142 |
+
|
| 143 |
+
def __call__(self, input_ids: torch.LongTensor, scores: torch.FloatTensor, **kwargs) -> bool:
|
| 144 |
+
for stopi, stop in enumerate(self.stops):
|
| 145 |
+
if torch.all((stop == input_ids[0][-len(stop):])).item():
|
| 146 |
+
self.num_stops[stopi] += 1
|
| 147 |
+
if self.num_stops[stopi] >= self.encounters[stopi % len(self.encounters)]:
|
| 148 |
+
# print("Stopped", flush=True)
|
| 149 |
+
return True
|
| 150 |
+
# print("Tokens: %s" % input_ids[0].cpu().numpy(), flush=True)
|
| 151 |
+
# print("Stop Tokens: %s" % [x.cpu().numpy() for x in self.stops], flush=True)
|
| 152 |
+
return False
|
| 153 |
+
|
| 154 |
+
|
| 155 |
+
def get_stopping(prompt_type, tokenizer, device, human='<human>:', bot="<bot>:"):
|
| 156 |
+
if prompt_type in [PromptType.human_bot.name, PromptType.instruct_vicuna.name, PromptType.instruct_with_end.name]:
|
| 157 |
+
if prompt_type == PromptType.human_bot.name:
|
| 158 |
+
# encounters = [prompt.count(human) + 1, prompt.count(bot) + 1]
|
| 159 |
+
# stopping only starts once output is beyond prompt
|
| 160 |
+
# 1 human is enough to trigger, but need 2 bots, because very first view back will be bot we added
|
| 161 |
+
stop_words = [human, bot, '\n' + human, '\n' + bot]
|
| 162 |
+
encounters = [1, 2]
|
| 163 |
+
elif prompt_type == PromptType.instruct_vicuna.name:
|
| 164 |
+
# even below is not enough, generic strings and many ways to encode
|
| 165 |
+
stop_words = [
|
| 166 |
+
'### Human:',
|
| 167 |
+
"""
|
| 168 |
+
### Human:""",
|
| 169 |
+
"""
|
| 170 |
+
### Human:
|
| 171 |
+
""",
|
| 172 |
+
'### Assistant:',
|
| 173 |
+
"""
|
| 174 |
+
### Assistant:""",
|
| 175 |
+
"""
|
| 176 |
+
### Assistant:
|
| 177 |
+
""",
|
| 178 |
+
]
|
| 179 |
+
encounters = [1, 2]
|
| 180 |
+
else:
|
| 181 |
+
# some instruct prompts have this as end, doesn't hurt to stop on it since not common otherwise
|
| 182 |
+
stop_words = ['### End']
|
| 183 |
+
encounters = [1]
|
| 184 |
+
stop_words_ids = [
|
| 185 |
+
tokenizer(stop_word, return_tensors='pt')['input_ids'].squeeze() for stop_word in stop_words]
|
| 186 |
+
# handle single token case
|
| 187 |
+
stop_words_ids = [x if len(x.shape) > 0 else torch.tensor([x]) for x in stop_words_ids]
|
| 188 |
+
stop_words_ids = [x for x in stop_words_ids if x.shape[0] > 0]
|
| 189 |
+
# avoid padding in front of tokens
|
| 190 |
+
if tokenizer._pad_token: # use hidden variable to avoid annoying properly logger bug
|
| 191 |
+
stop_words_ids = [x[1:] if x[0] == tokenizer.pad_token_id and len(x) > 1 else x for x in stop_words_ids]
|
| 192 |
+
# handle fake \n added
|
| 193 |
+
stop_words_ids = [x[1:] if y[0] == '\n' else x for x, y in zip(stop_words_ids, stop_words)]
|
| 194 |
+
# build stopper
|
| 195 |
+
stopping_criteria = StoppingCriteriaList(
|
| 196 |
+
[StoppingCriteriaSub(stops=stop_words_ids, encounters=encounters, device=device)])
|
| 197 |
+
else:
|
| 198 |
+
stopping_criteria = StoppingCriteriaList()
|
| 199 |
+
return stopping_criteria
|
| 200 |
+
import time
|
| 201 |
+
from enum import Enum
|
| 202 |
+
|
| 203 |
+
non_hf_types = ['gpt4all_llama', 'llama', 'gptj']
|
| 204 |
+
|
| 205 |
+
|
| 206 |
+
class PromptType(Enum):
|
| 207 |
+
plain = 0
|
| 208 |
+
instruct = 1
|
| 209 |
+
quality = 2
|
| 210 |
+
human_bot = 3
|
| 211 |
+
dai_faq = 4
|
| 212 |
+
summarize = 5
|
| 213 |
+
simple_instruct = 6
|
| 214 |
+
instruct_vicuna = 7
|
| 215 |
+
instruct_with_end = 8
|
| 216 |
+
human_bot_orig = 9
|
| 217 |
+
prompt_answer = 10
|
| 218 |
+
open_assistant = 11
|
| 219 |
+
wizard_lm = 12
|
| 220 |
+
wizard_mega = 13
|
| 221 |
+
instruct_vicuna2 = 14
|
| 222 |
+
instruct_vicuna3 = 15
|
| 223 |
+
wizard2 = 16
|
| 224 |
+
wizard3 = 17
|
| 225 |
+
|
| 226 |
+
|
| 227 |
+
prompt_type_to_model_name = {
|
| 228 |
+
'plain': [
|
| 229 |
+
'EleutherAI/gpt-j-6B',
|
| 230 |
+
'EleutherAI/pythia-6.9b',
|
| 231 |
+
'EleutherAI/pythia-12b',
|
| 232 |
+
'EleutherAI/pythia-12b-deduped',
|
| 233 |
+
'EleutherAI/gpt-neox-20b',
|
| 234 |
+
'openlm-research/open_llama_7b_700bt_preview',
|
| 235 |
+
'decapoda-research/llama-7b-hf',
|
| 236 |
+
'decapoda-research/llama-13b-hf',
|
| 237 |
+
'decapoda-research/llama-30b-hf',
|
| 238 |
+
'decapoda-research/llama-65b-hf',
|
| 239 |
+
'facebook/mbart-large-50-many-to-many-mmt',
|
| 240 |
+
'philschmid/bart-large-cnn-samsum',
|
| 241 |
+
'philschmid/flan-t5-base-samsum',
|
| 242 |
+
'gpt2',
|
| 243 |
+
'distilgpt2',
|
| 244 |
+
'mosaicml/mpt-7b-storywriter',
|
| 245 |
+
'mosaicml/mpt-7b-instruct', # internal code handles instruct
|
| 246 |
+
'mosaicml/mpt-7b-chat', # NC, internal code handles instruct
|
| 247 |
+
'gptj', # internally handles prompting
|
| 248 |
+
'llama', # plain, or need to choose prompt_type for given TheBloke model
|
| 249 |
+
'gpt4all_llama', # internally handles prompting
|
| 250 |
+
],
|
| 251 |
+
'prompt_answer': [
|
| 252 |
+
'h2oai/h2ogpt-gm-oasst1-en-1024-20b',
|
| 253 |
+
'h2oai/h2ogpt-gm-oasst1-en-1024-12b',
|
| 254 |
+
'h2oai/h2ogpt-gm-oasst1-multilang-1024-20b',
|
| 255 |
+
'h2oai/h2ogpt-gm-oasst1-en-2048-open-llama-7b-preview-300bt',
|
| 256 |
+
'h2oai/h2ogpt-gm-oasst1-en-2048-open-llama-7b-preview-300bt-v2',
|
| 257 |
+
'h2oai/h2ogpt-gm-oasst1-en-2048-open-llama-7b-preview-700bt',
|
| 258 |
+
],
|
| 259 |
+
'instruct': [],
|
| 260 |
+
'instruct_with_end': ['databricks/dolly-v2-12b'],
|
| 261 |
+
'quality': [],
|
| 262 |
+
'human_bot': [
|
| 263 |
+
'h2oai/h2ogpt-oasst1-512-12b',
|
| 264 |
+
'h2oai/h2ogpt-oasst1-512-20b',
|
| 265 |
+
'h2oai/h2ogpt-oig-oasst1-256-6_9b',
|
| 266 |
+
'h2oai/h2ogpt-oig-oasst1-512-6_9b',
|
| 267 |
+
'h2oai/h2ogpt-oig-oasst1-256-6.9b', # legacy
|
| 268 |
+
'h2oai/h2ogpt-oig-oasst1-512-6.9b', # legacy
|
| 269 |
+
'h2oai/h2ogpt-research-oasst1-512-30b',
|
| 270 |
+
'h2oai/h2ogpt-oasst1-falcon-40b',
|
| 271 |
+
],
|
| 272 |
+
'dai_faq': [],
|
| 273 |
+
'summarize': [],
|
| 274 |
+
'simple_instruct': ['t5-small', 't5-large', 'google/flan-t5', 'google/flan-t5-xxl', 'google/flan-ul2'],
|
| 275 |
+
'instruct_vicuna': ['AlekseyKorshuk/vicuna-7b', 'TheBloke/stable-vicuna-13B-HF', 'junelee/wizard-vicuna-13b'],
|
| 276 |
+
'human_bot_orig': ['togethercomputer/GPT-NeoXT-Chat-Base-20B'],
|
| 277 |
+
"open_assistant": ['OpenAssistant/oasst-sft-7-llama-30b-xor', 'oasst-sft-7-llama-30b'],
|
| 278 |
+
"wizard_lm": ['ehartford/WizardLM-7B-Uncensored', 'ehartford/WizardLM-13B-Uncensored'],
|
| 279 |
+
"wizard_mega": ['openaccess-ai-collective/wizard-mega-13b'],
|
| 280 |
+
}
|
| 281 |
+
|
| 282 |
+
inv_prompt_type_to_model_name = {v.strip(): k for k, l in prompt_type_to_model_name.items() for v in l}
|
| 283 |
+
inv_prompt_type_to_model_lower = {v.strip().lower(): k for k, l in prompt_type_to_model_name.items() for v in l}
|
| 284 |
+
|
| 285 |
+
prompt_types_strings = []
|
| 286 |
+
for p in PromptType:
|
| 287 |
+
prompt_types_strings.extend([p.name])
|
| 288 |
+
|
| 289 |
+
prompt_types = []
|
| 290 |
+
for p in PromptType:
|
| 291 |
+
prompt_types.extend([p.name, p.value, str(p.value)])
|
| 292 |
+
|
| 293 |
+
|
| 294 |
+
def get_prompt(prompt_type, chat, context, reduced):
|
| 295 |
+
if prompt_type in [PromptType.plain.value, str(PromptType.plain.value),
|
| 296 |
+
PromptType.plain.name]:
|
| 297 |
+
promptA = promptB = PreInstruct = PreInput = PreResponse = ''
|
| 298 |
+
terminate_response = []
|
| 299 |
+
chat_sep = ''
|
| 300 |
+
humanstr = ''
|
| 301 |
+
botstr = ''
|
| 302 |
+
elif prompt_type == 'simple_instruct':
|
| 303 |
+
promptA = promptB = PreInstruct = PreInput = PreResponse = None
|
| 304 |
+
terminate_response = []
|
| 305 |
+
chat_sep = '\n'
|
| 306 |
+
humanstr = ''
|
| 307 |
+
botstr = ''
|
| 308 |
+
elif prompt_type in [PromptType.instruct.value, str(PromptType.instruct.value),
|
| 309 |
+
PromptType.instruct.name] + [PromptType.instruct_with_end.value,
|
| 310 |
+
str(PromptType.instruct_with_end.value),
|
| 311 |
+
PromptType.instruct_with_end.name]:
|
| 312 |
+
promptA = 'Below is an instruction that describes a task, paired with an input that provides further context. Write a response that appropriately completes the request.\n' if not (
|
| 313 |
+
chat and reduced) else ''
|
| 314 |
+
promptB = 'Below is an instruction that describes a task. Write a response that appropriately completes the request.\n' if not (
|
| 315 |
+
chat and reduced) else ''
|
| 316 |
+
|
| 317 |
+
PreInstruct = """
|
| 318 |
+
### Instruction:
|
| 319 |
+
"""
|
| 320 |
+
|
| 321 |
+
PreInput = """
|
| 322 |
+
### Input:
|
| 323 |
+
"""
|
| 324 |
+
|
| 325 |
+
PreResponse = """
|
| 326 |
+
### Response:
|
| 327 |
+
"""
|
| 328 |
+
if prompt_type in [PromptType.instruct_with_end.value, str(PromptType.instruct_with_end.value),
|
| 329 |
+
PromptType.instruct_with_end.name]:
|
| 330 |
+
terminate_response = ['### End']
|
| 331 |
+
else:
|
| 332 |
+
terminate_response = None
|
| 333 |
+
chat_sep = '\n'
|
| 334 |
+
humanstr = PreInstruct
|
| 335 |
+
botstr = PreResponse
|
| 336 |
+
elif prompt_type in [PromptType.quality.value, str(PromptType.quality.value),
|
| 337 |
+
PromptType.quality.name]:
|
| 338 |
+
promptA = 'Write a detailed high-quality, accurate, fair, Response with about 100 words by following the Instruction as applied on the Input.\n' if not (
|
| 339 |
+
chat and reduced) else ''
|
| 340 |
+
promptB = 'Write a detailed high-quality, accurate, fair, Response with about 100 words by following the Instruction.\n' if not (
|
| 341 |
+
chat and reduced) else ''
|
| 342 |
+
|
| 343 |
+
PreInstruct = """
|
| 344 |
+
### Instruction:
|
| 345 |
+
"""
|
| 346 |
+
|
| 347 |
+
PreInput = """
|
| 348 |
+
### Input:
|
| 349 |
+
"""
|
| 350 |
+
|
| 351 |
+
PreResponse = """
|
| 352 |
+
### Response:
|
| 353 |
+
"""
|
| 354 |
+
terminate_response = None
|
| 355 |
+
chat_sep = '\n'
|
| 356 |
+
humanstr = PreInstruct # first thing human says
|
| 357 |
+
botstr = PreResponse # first thing bot says
|
| 358 |
+
elif prompt_type in [PromptType.human_bot.value, str(PromptType.human_bot.value),
|
| 359 |
+
PromptType.human_bot.name] + [PromptType.human_bot_orig.value,
|
| 360 |
+
str(PromptType.human_bot_orig.value),
|
| 361 |
+
PromptType.human_bot_orig.name]:
|
| 362 |
+
human = '<human>:'
|
| 363 |
+
bot = "<bot>:"
|
| 364 |
+
if reduced or context or prompt_type in [PromptType.human_bot.value, str(PromptType.human_bot.value),
|
| 365 |
+
PromptType.human_bot.name]:
|
| 366 |
+
preprompt = ''
|
| 367 |
+
else:
|
| 368 |
+
cur_date = time.strftime('%Y-%m-%d')
|
| 369 |
+
cur_time = time.strftime('%H:%M:%S %p %Z')
|
| 370 |
+
|
| 371 |
+
PRE_PROMPT = """\
|
| 372 |
+
Current Date: {}
|
| 373 |
+
Current Time: {}
|
| 374 |
+
|
| 375 |
+
"""
|
| 376 |
+
preprompt = PRE_PROMPT.format(cur_date, cur_time)
|
| 377 |
+
start = human
|
| 378 |
+
promptB = promptA = '%s%s ' % (preprompt, start)
|
| 379 |
+
|
| 380 |
+
PreInstruct = ""
|
| 381 |
+
|
| 382 |
+
PreInput = None
|
| 383 |
+
|
| 384 |
+
if reduced:
|
| 385 |
+
# when making context, want it to appear as-if LLM generated, which starts with space after :
|
| 386 |
+
PreResponse = bot + ' '
|
| 387 |
+
else:
|
| 388 |
+
# normally LLM adds space after this, because was how trained.
|
| 389 |
+
# if add space here, non-unique tokenization will often make LLM produce wrong output
|
| 390 |
+
PreResponse = bot
|
| 391 |
+
|
| 392 |
+
terminate_response = [start, PreResponse]
|
| 393 |
+
chat_sep = '\n'
|
| 394 |
+
humanstr = human # tag before human talks
|
| 395 |
+
botstr = bot # tag before bot talks
|
| 396 |
+
elif prompt_type in [PromptType.dai_faq.value, str(PromptType.dai_faq.value),
|
| 397 |
+
PromptType.dai_faq.name]:
|
| 398 |
+
promptA = ''
|
| 399 |
+
promptB = 'Answer the following Driverless AI question.\n'
|
| 400 |
+
|
| 401 |
+
PreInstruct = """
|
| 402 |
+
### Driverless AI frequently asked question:
|
| 403 |
+
"""
|
| 404 |
+
|
| 405 |
+
PreInput = None
|
| 406 |
+
|
| 407 |
+
PreResponse = """
|
| 408 |
+
### Driverless AI documentation answer:
|
| 409 |
+
"""
|
| 410 |
+
terminate_response = ['\n\n']
|
| 411 |
+
chat_sep = terminate_response
|
| 412 |
+
humanstr = PreInstruct
|
| 413 |
+
botstr = PreResponse
|
| 414 |
+
elif prompt_type in [PromptType.summarize.value, str(PromptType.summarize.value),
|
| 415 |
+
PromptType.summarize.name]:
|
| 416 |
+
promptA = promptB = PreInput = ''
|
| 417 |
+
PreInstruct = '## Main Text\n\n'
|
| 418 |
+
PreResponse = '\n\n## Summary\n\n'
|
| 419 |
+
terminate_response = None
|
| 420 |
+
chat_sep = '\n'
|
| 421 |
+
humanstr = PreInstruct
|
| 422 |
+
botstr = PreResponse
|
| 423 |
+
elif prompt_type in [PromptType.instruct_vicuna.value, str(PromptType.instruct_vicuna.value),
|
| 424 |
+
PromptType.instruct_vicuna.name]:
|
| 425 |
+
promptA = promptB = "A chat between a curious human and an artificial intelligence assistant. " \
|
| 426 |
+
"The assistant gives helpful, detailed, and polite answers to the human's questions." if not (
|
| 427 |
+
chat and reduced) else ''
|
| 428 |
+
|
| 429 |
+
PreInstruct = """
|
| 430 |
+
### Human:
|
| 431 |
+
"""
|
| 432 |
+
|
| 433 |
+
PreInput = None
|
| 434 |
+
|
| 435 |
+
PreResponse = """
|
| 436 |
+
### Assistant:
|
| 437 |
+
"""
|
| 438 |
+
terminate_response = [
|
| 439 |
+
'### Human:'] # but only allow terminate after prompt is found correctly, else can't terminate
|
| 440 |
+
chat_sep = '\n'
|
| 441 |
+
humanstr = PreInstruct
|
| 442 |
+
botstr = PreResponse
|
| 443 |
+
elif prompt_type in [PromptType.prompt_answer.value, str(PromptType.prompt_answer.value),
|
| 444 |
+
PromptType.prompt_answer.name]:
|
| 445 |
+
preprompt = ''
|
| 446 |
+
prompt_tokens = "<|prompt|>"
|
| 447 |
+
answer_tokens = "<|answer|>"
|
| 448 |
+
start = prompt_tokens
|
| 449 |
+
promptB = promptA = '%s%s' % (preprompt, start)
|
| 450 |
+
PreInstruct = ""
|
| 451 |
+
PreInput = None
|
| 452 |
+
PreResponse = answer_tokens
|
| 453 |
+
eos = '<|endoftext|>' # neox eos
|
| 454 |
+
terminate_response = [start, PreResponse, eos]
|
| 455 |
+
chat_sep = eos
|
| 456 |
+
humanstr = prompt_tokens
|
| 457 |
+
botstr = answer_tokens
|
| 458 |
+
elif prompt_type in [PromptType.open_assistant.value, str(PromptType.open_assistant.value),
|
| 459 |
+
PromptType.open_assistant.name]:
|
| 460 |
+
# From added_tokens.json
|
| 461 |
+
preprompt = ''
|
| 462 |
+
prompt_tokens = "<|prompter|>"
|
| 463 |
+
answer_tokens = "<|assistant|>"
|
| 464 |
+
start = prompt_tokens
|
| 465 |
+
promptB = promptA = '%s%s' % (preprompt, start)
|
| 466 |
+
PreInstruct = ""
|
| 467 |
+
PreInput = None
|
| 468 |
+
PreResponse = answer_tokens
|
| 469 |
+
pend = "<|prefix_end|>"
|
| 470 |
+
eos = "</s>"
|
| 471 |
+
terminate_response = [start, PreResponse, pend, eos]
|
| 472 |
+
chat_sep = eos
|
| 473 |
+
humanstr = prompt_tokens
|
| 474 |
+
botstr = answer_tokens
|
| 475 |
+
elif prompt_type in [PromptType.wizard_lm.value, str(PromptType.wizard_lm.value),
|
| 476 |
+
PromptType.wizard_lm.name]:
|
| 477 |
+
# https://github.com/ehartford/WizardLM/blob/main/src/train_freeform.py
|
| 478 |
+
preprompt = ''
|
| 479 |
+
start = ''
|
| 480 |
+
promptB = promptA = '%s%s' % (preprompt, start)
|
| 481 |
+
PreInstruct = ""
|
| 482 |
+
PreInput = None
|
| 483 |
+
PreResponse = "\n\n### Response\n"
|
| 484 |
+
eos = "</s>"
|
| 485 |
+
terminate_response = [PreResponse, eos]
|
| 486 |
+
chat_sep = eos
|
| 487 |
+
humanstr = promptA
|
| 488 |
+
botstr = PreResponse
|
| 489 |
+
elif prompt_type in [PromptType.wizard_mega.value, str(PromptType.wizard_mega.value),
|
| 490 |
+
PromptType.wizard_mega.name]:
|
| 491 |
+
preprompt = ''
|
| 492 |
+
start = ''
|
| 493 |
+
promptB = promptA = '%s%s' % (preprompt, start)
|
| 494 |
+
PreInstruct = """
|
| 495 |
+
### Instruction:
|
| 496 |
+
"""
|
| 497 |
+
PreInput = None
|
| 498 |
+
PreResponse = """
|
| 499 |
+
### Assistant:
|
| 500 |
+
"""
|
| 501 |
+
terminate_response = [PreResponse]
|
| 502 |
+
chat_sep = '\n'
|
| 503 |
+
humanstr = PreInstruct
|
| 504 |
+
botstr = PreResponse
|
| 505 |
+
elif prompt_type in [PromptType.instruct_vicuna2.value, str(PromptType.instruct_vicuna2.value),
|
| 506 |
+
PromptType.instruct_vicuna2.name]:
|
| 507 |
+
promptA = promptB = "" if not (
|
| 508 |
+
chat and reduced) else ''
|
| 509 |
+
|
| 510 |
+
PreInstruct = """
|
| 511 |
+
HUMAN:
|
| 512 |
+
"""
|
| 513 |
+
|
| 514 |
+
PreInput = None
|
| 515 |
+
|
| 516 |
+
PreResponse = """
|
| 517 |
+
ASSISTANT:
|
| 518 |
+
"""
|
| 519 |
+
terminate_response = [
|
| 520 |
+
'HUMAN:'] # but only allow terminate after prompt is found correctly, else can't terminate
|
| 521 |
+
chat_sep = '\n'
|
| 522 |
+
humanstr = PreInstruct
|
| 523 |
+
botstr = PreResponse
|
| 524 |
+
elif prompt_type in [PromptType.instruct_vicuna3.value, str(PromptType.instruct_vicuna3.value),
|
| 525 |
+
PromptType.instruct_vicuna3.name]:
|
| 526 |
+
promptA = promptB = "" if not (
|
| 527 |
+
chat and reduced) else ''
|
| 528 |
+
|
| 529 |
+
PreInstruct = """
|
| 530 |
+
### User:
|
| 531 |
+
"""
|
| 532 |
+
|
| 533 |
+
PreInput = None
|
| 534 |
+
|
| 535 |
+
PreResponse = """
|
| 536 |
+
### Assistant:
|
| 537 |
+
"""
|
| 538 |
+
terminate_response = [
|
| 539 |
+
'### User:'] # but only allow terminate after prompt is found correctly, else can't terminate
|
| 540 |
+
chat_sep = '\n'
|
| 541 |
+
humanstr = PreInstruct
|
| 542 |
+
botstr = PreResponse
|
| 543 |
+
elif prompt_type in [PromptType.wizard2.value, str(PromptType.wizard2.value),
|
| 544 |
+
PromptType.wizard2.name]:
|
| 545 |
+
# https://huggingface.co/TheBloke/WizardLM-7B-uncensored-GGML
|
| 546 |
+
preprompt = """Below is an instruction that describes a task. Write a response that appropriately completes the request."""
|
| 547 |
+
start = ''
|
| 548 |
+
promptB = promptA = '%s%s' % (preprompt, start)
|
| 549 |
+
PreInstruct = """
|
| 550 |
+
### Instruction:
|
| 551 |
+
"""
|
| 552 |
+
PreInput = None
|
| 553 |
+
PreResponse = """
|
| 554 |
+
### Response:
|
| 555 |
+
"""
|
| 556 |
+
terminate_response = [PreResponse]
|
| 557 |
+
chat_sep = '\n'
|
| 558 |
+
humanstr = PreInstruct
|
| 559 |
+
botstr = PreResponse
|
| 560 |
+
elif prompt_type in [PromptType.wizard3.value, str(PromptType.wizard3.value),
|
| 561 |
+
PromptType.wizard3.name]:
|
| 562 |
+
# https://huggingface.co/TheBloke/wizardLM-13B-1.0-GGML
|
| 563 |
+
preprompt = """A chat between a curious user and an artificial intelligence assistant. The assistant gives helpful, detailed, and polite answers to the user's questions."""
|
| 564 |
+
start = ''
|
| 565 |
+
promptB = promptA = '%s%s' % (preprompt, start)
|
| 566 |
+
PreInstruct = """USER: """
|
| 567 |
+
PreInput = None
|
| 568 |
+
PreResponse = """ASSISTANT: """
|
| 569 |
+
terminate_response = [PreResponse]
|
| 570 |
+
chat_sep = '\n'
|
| 571 |
+
humanstr = PreInstruct
|
| 572 |
+
botstr = PreResponse
|
| 573 |
+
|
| 574 |
+
else:
|
| 575 |
+
raise RuntimeError("No such prompt_type=%s" % prompt_type)
|
| 576 |
+
|
| 577 |
+
return promptA, promptB, PreInstruct, PreInput, PreResponse, terminate_response, chat_sep, humanstr, botstr
|
| 578 |
+
|
| 579 |
+
|
| 580 |
+
def generate_prompt(data_point, prompt_type, chat, reduced):
|
| 581 |
+
context = data_point.get('context')
|
| 582 |
+
if context is None:
|
| 583 |
+
context = ''
|
| 584 |
+
instruction = data_point.get('instruction')
|
| 585 |
+
input = data_point.get('input')
|
| 586 |
+
output = data_point.get('output')
|
| 587 |
+
prompt_type = data_point.get('prompt_type', prompt_type)
|
| 588 |
+
assert prompt_type in prompt_types, "Bad prompt type: %s" % prompt_type
|
| 589 |
+
promptA, promptB, PreInstruct, PreInput, PreResponse, \
|
| 590 |
+
terminate_response, chat_sep, humanstr, botstr = get_prompt(prompt_type, chat, context, reduced)
|
| 591 |
+
|
| 592 |
+
prompt = context if not reduced else ''
|
| 593 |
+
|
| 594 |
+
if input and promptA:
|
| 595 |
+
prompt += f"""{promptA}"""
|
| 596 |
+
elif promptB:
|
| 597 |
+
prompt += f"""{promptB}"""
|
| 598 |
+
|
| 599 |
+
if instruction and PreInstruct is not None and input and PreInput is not None:
|
| 600 |
+
prompt += f"""{PreInstruct}{instruction}{PreInput}{input}"""
|
| 601 |
+
prompt = inject_newline(prompt_type, prompt)
|
| 602 |
+
elif instruction and input and PreInstruct is None and PreInput is not None:
|
| 603 |
+
prompt += f"""{PreInput}{instruction}
|
| 604 |
+
{input}"""
|
| 605 |
+
prompt = inject_newline(prompt_type, prompt)
|
| 606 |
+
elif input and instruction and PreInput is None and PreInstruct is not None:
|
| 607 |
+
prompt += f"""{PreInstruct}{instruction}
|
| 608 |
+
{input}"""
|
| 609 |
+
prompt = inject_newline(prompt_type, prompt)
|
| 610 |
+
elif instruction and PreInstruct is not None:
|
| 611 |
+
prompt += f"""{PreInstruct}{instruction}"""
|
| 612 |
+
prompt = inject_newline(prompt_type, prompt)
|
| 613 |
+
elif input and PreInput is not None:
|
| 614 |
+
prompt += f"""{PreInput}{input}"""
|
| 615 |
+
prompt = inject_newline(prompt_type, prompt)
|
| 616 |
+
elif input and instruction and PreInput is not None:
|
| 617 |
+
prompt += f"""{PreInput}{instruction}{input}"""
|
| 618 |
+
prompt = inject_newline(prompt_type, prompt)
|
| 619 |
+
elif input and instruction and PreInstruct is not None:
|
| 620 |
+
prompt += f"""{PreInstruct}{instruction}{input}"""
|
| 621 |
+
prompt = inject_newline(prompt_type, prompt)
|
| 622 |
+
elif input and instruction:
|
| 623 |
+
# i.e. for simple_instruct
|
| 624 |
+
prompt += f"""{instruction}: {input}"""
|
| 625 |
+
prompt = inject_newline(prompt_type, prompt)
|
| 626 |
+
elif input:
|
| 627 |
+
prompt += f"""{input}"""
|
| 628 |
+
prompt = inject_newline(prompt_type, prompt)
|
| 629 |
+
elif instruction:
|
| 630 |
+
prompt += f"""{instruction}"""
|
| 631 |
+
prompt = inject_newline(prompt_type, prompt)
|
| 632 |
+
|
| 633 |
+
if PreResponse is not None:
|
| 634 |
+
prompt += f"""{PreResponse}"""
|
| 635 |
+
pre_response = PreResponse # Don't use strip
|
| 636 |
+
else:
|
| 637 |
+
pre_response = ''
|
| 638 |
+
|
| 639 |
+
if output:
|
| 640 |
+
prompt += f"""{output}"""
|
| 641 |
+
|
| 642 |
+
return prompt, pre_response, terminate_response, chat_sep
|
| 643 |
+
|
| 644 |
+
|
| 645 |
+
def inject_newline(prompt_type, prompt):
|
| 646 |
+
if prompt_type not in [-1, '-1', 'plain', 'simple_instruct']:
|
| 647 |
+
# only add new line if structured prompt, while 'plain' is just generation of next tokens from input
|
| 648 |
+
prompt += '\n'
|
| 649 |
+
return prompt
|
| 650 |
+
|
| 651 |
+
|
| 652 |
+
class Prompter(object):
|
| 653 |
+
def __init__(self, prompt_type, debug=False, chat=False, stream_output=False, repeat_penalty=True,
|
| 654 |
+
allowed_repeat_line_length=10):
|
| 655 |
+
self.prompt_type = prompt_type
|
| 656 |
+
data_point = dict(instruction='', input='', output='')
|
| 657 |
+
_, self.pre_response, self.terminate_response, self.chat_sep = \
|
| 658 |
+
generate_prompt(data_point, prompt_type, chat, False)
|
| 659 |
+
self.debug = debug
|
| 660 |
+
self.chat = chat
|
| 661 |
+
self.stream_output = stream_output
|
| 662 |
+
self.repeat_penalty = repeat_penalty
|
| 663 |
+
self.allowed_repeat_line_length = allowed_repeat_line_length
|
| 664 |
+
self.prompt = None
|
| 665 |
+
context = "" # not for chat context
|
| 666 |
+
reduced = False # not for chat context
|
| 667 |
+
self.promptA, self.promptB, self.PreInstruct, self.PreInput, self.PreResponse, \
|
| 668 |
+
self.terminate_response, self.chat_sep, self.humanstr, self.botstr = \
|
| 669 |
+
get_prompt(prompt_type, chat, context, reduced)
|
| 670 |
+
|
| 671 |
+
def generate_prompt(self, data_point):
|
| 672 |
+
reduced = False
|
| 673 |
+
prompt, _, _, _ = generate_prompt(data_point, self.prompt_type, self.chat, reduced)
|
| 674 |
+
if self.debug:
|
| 675 |
+
print("prompt: ", prompt, flush=True)
|
| 676 |
+
self.prompt = prompt
|
| 677 |
+
return prompt
|
| 678 |
+
|
| 679 |
+
def get_response(self, outputs, prompt=None, sanitize_bot_response=True):
|
| 680 |
+
if isinstance(outputs, str):
|
| 681 |
+
outputs = [outputs]
|
| 682 |
+
if self.debug:
|
| 683 |
+
print("output:\n", '\n\n'.join(outputs), flush=True)
|
| 684 |
+
if prompt is not None:
|
| 685 |
+
self.prompt = prompt
|
| 686 |
+
|
| 687 |
+
def clean_response(response):
|
| 688 |
+
meaningless_words = ['<pad>', '</s>', '<|endoftext|>']
|
| 689 |
+
for word in meaningless_words:
|
| 690 |
+
response = response.replace(word, "")
|
| 691 |
+
if sanitize_bot_response:
|
| 692 |
+
from better_profanity import profanity
|
| 693 |
+
response = profanity.censor(response)
|
| 694 |
+
response = response.strip("\n")
|
| 695 |
+
return response
|
| 696 |
+
|
| 697 |
+
def clean_repeats(response):
|
| 698 |
+
lines = response.split('\n')
|
| 699 |
+
new_lines = []
|
| 700 |
+
[new_lines.append(line) for line in lines if
|
| 701 |
+
line not in new_lines or len(line) < self.allowed_repeat_line_length]
|
| 702 |
+
if self.debug and len(lines) != len(new_lines):
|
| 703 |
+
print("cleaned repeats: %s %s" % (len(lines), len(new_lines)), flush=True)
|
| 704 |
+
response = '\n'.join(new_lines)
|
| 705 |
+
return response
|
| 706 |
+
|
| 707 |
+
multi_output = len(outputs) > 1
|
| 708 |
+
|
| 709 |
+
for oi, output in enumerate(outputs):
|
| 710 |
+
if self.prompt_type in [PromptType.plain.value, str(PromptType.plain.value), PromptType.plain.name]:
|
| 711 |
+
output = clean_response(output)
|
| 712 |
+
elif prompt is None:
|
| 713 |
+
# then use most basic parsing like pipeline
|
| 714 |
+
if self.botstr in output:
|
| 715 |
+
if self.humanstr:
|
| 716 |
+
output = clean_response(output.split(self.botstr)[1].strip().split(self.humanstr)[0].strip())
|
| 717 |
+
else:
|
| 718 |
+
# i.e. use after bot but only up to next bot
|
| 719 |
+
output = clean_response(output.split(self.botstr)[1].strip().split(self.botstr)[0].strip())
|
| 720 |
+
else:
|
| 721 |
+
# output = clean_response(output.strip())
|
| 722 |
+
# assume just not printed yet
|
| 723 |
+
output = ""
|
| 724 |
+
else:
|
| 725 |
+
# find first instance of prereponse
|
| 726 |
+
# prompt sometimes has odd characters, that mutate length,
|
| 727 |
+
# so can't go by length alone
|
| 728 |
+
if self.pre_response:
|
| 729 |
+
outputi = output.find(prompt)
|
| 730 |
+
if outputi >= 0:
|
| 731 |
+
output = output[outputi + len(prompt):]
|
| 732 |
+
allow_terminate = True
|
| 733 |
+
else:
|
| 734 |
+
# subtraction is risky due to space offsets sometimes, so only do if necessary
|
| 735 |
+
output = output[len(prompt) - len(self.pre_response):]
|
| 736 |
+
# [1] to avoid repeated pre_response, just take first (after prompt - pre_response for chat)
|
| 737 |
+
if self.pre_response in output:
|
| 738 |
+
output = output.split(self.pre_response)[1]
|
| 739 |
+
allow_terminate = True
|
| 740 |
+
else:
|
| 741 |
+
if output:
|
| 742 |
+
print("Failure of parsing or not enough output yet: %s" % output, flush=True)
|
| 743 |
+
allow_terminate = False
|
| 744 |
+
else:
|
| 745 |
+
allow_terminate = True
|
| 746 |
+
output = output[len(prompt):]
|
| 747 |
+
# clean after subtract prompt out, so correct removal of pre_response
|
| 748 |
+
output = clean_response(output).strip()
|
| 749 |
+
if self.repeat_penalty:
|
| 750 |
+
output = clean_repeats(output).strip()
|
| 751 |
+
if self.terminate_response and allow_terminate:
|
| 752 |
+
finds = []
|
| 753 |
+
for term in self.terminate_response:
|
| 754 |
+
finds.append(output.find(term))
|
| 755 |
+
finds = [x for x in finds if x >= 0]
|
| 756 |
+
if len(finds) > 0:
|
| 757 |
+
termi = finds[0]
|
| 758 |
+
output = output[:termi].strip()
|
| 759 |
+
else:
|
| 760 |
+
output = output.strip()
|
| 761 |
+
else:
|
| 762 |
+
output = output.strip()
|
| 763 |
+
if multi_output:
|
| 764 |
+
# prefix with output counter
|
| 765 |
+
output = "\n=========== Output %d\n\n" % (1 + oi) + output
|
| 766 |
+
if oi > 0:
|
| 767 |
+
# post fix outputs with seperator
|
| 768 |
+
output += '\n'
|
| 769 |
+
outputs[oi] = output
|
| 770 |
+
# join all outputs, only one extra new line between outputs
|
| 771 |
+
output = '\n'.join(outputs)
|
| 772 |
+
if self.debug:
|
| 773 |
+
print("outputclean:\n", '\n\n'.join(outputs), flush=True)
|
| 774 |
+
return output
|