Upload predict.py with huggingface_hub
predict.py (+33 -5)
@@ -185,6 +185,8 @@ def get_conversation_template(model_path: str) -> Conversation:
     """Get the default conversation template."""
     if "aquila-v1" in model_path:
         return get_conv_template("aquila-v1")
+    elif "aquila-v2" in model_path:
+        return get_conv_template("aquila-v2")
     elif "aquila-chat" in model_path:
         return get_conv_template("aquila-chat")
     elif "aquila-legacy" in model_path:
@@ -252,6 +254,21 @@ register_conv_template(
     )
 )
 
+register_conv_template(
+    Conversation(
+        name="aquila-v2",
+        system_message="A chat between a curious human and an artificial intelligence assistant. "
+        "The assistant gives helpful, detailed, and polite answers to the human's questions.\n\n",
+        roles=("<|startofpiece|>", "<|endofpiece|>", ""),
+        messages=(),
+        offset=0,
+        sep_style=SeparatorStyle.NO_COLON_TWO,
+        sep="",
+        sep2="</s>",
+        stop_str=["</s>", "<|endoftext|>", "<|startofpiece|>", "<|endofpiece|>"],
+    )
+)
+
 
 if __name__ == "__main__":
     print("aquila template:")
@@ -294,6 +311,17 @@ if __name__ == "__main__":
 
     print("\n")
 
+    print("aquila-v2 template:")
+    conv = get_conv_template("aquila-v2")
+    conv.append_message(conv.roles[0], "Hello!")
+    conv.append_message(conv.roles[1], "Hi!")
+    conv.append_message(conv.roles[0], "How are you?")
+    conv.append_message(conv.roles[1], None)
+    print(conv.get_prompt())
+
+    print("\n")
+
+
 def set_random_seed(seed):
     """Set random seed for reproducability."""
     if seed is not None and seed > 0:
@@ -330,9 +358,9 @@ def covert_prompt_to_input_ids_with_history(text, history, tokenizer, max_token,
     return example
 
 def predict(model, text, tokenizer=None,
-        max_gen_len=200, top_p=0.
-        seed=
-        temperature=0
+        max_gen_len=200, top_p=0.9,
+        seed=123, topk=15,
+        temperature=1.0,
         sft=True, convo_template = "",
         device = "cuda",
         model_name="AquilaChat2-7B",
@@ -346,8 +374,8 @@ def predict(model, text, tokenizer=None,
 
     template_map = {"AquilaChat2-7B": "aquila-v1",
                     "AquilaChat2-34B": "aquila-legacy",
-                    "AquilaChat2-7B-16K": "aquila",
                     "AquilaChat2-70B-Expr": "aquila-v2",
+                    "AquilaChat2-7B-16K": "aquila",
                     "AquilaChat2-34B-16K": "aquila"}
     if not convo_template:
         convo_template=template_map.get(model_name, "aquila-chat")
@@ -357,7 +385,7 @@ def predict(model, text, tokenizer=None,
         topk = 1
         temperature = 1.0
     if sft:
-        tokens = covert_prompt_to_input_ids_with_history(text, history=history, tokenizer=tokenizer, max_token=
+        tokens = covert_prompt_to_input_ids_with_history(text, history=history, tokenizer=tokenizer, max_token=20480, convo_template=convo_template)
         tokens = torch.tensor(tokens)[None,].to(device)
     else :
         tokens = tokenizer.encode_plus(text)["input_ids"]