legolasyiu committed on
Commit 4723961 · verified · 1 Parent(s): df0b8e7

Update app.py

Files changed (1)
  1. app.py +13 -83
app.py CHANGED
@@ -1,91 +1,21 @@
+ 
  from transformers import AutoModelForCausalLM, AutoTokenizer
- import torch
  import gradio as gr
- from gradio import ChatMessage
- from typing import Iterator
  
  checkpoint = "EpistemeAI/metatune-gpt20b-R0"
- device = "cuda" if torch.cuda.is_available() else "cpu"
- 
- # Load model + tokenizer
+ device = "cuda" # "cuda" or "cpu"
  tokenizer = AutoTokenizer.from_pretrained(checkpoint)
- model = AutoModelForCausalLM.from_pretrained(
-     checkpoint,
-     torch_dtype="auto",
- ).to(device)
- 
- def format_history_for_model(messages):
-     """Convert the message list into a single string prompt"""
-     chat_prompt = ""
-     for msg in messages:
-         role = msg["role"]
-         content = msg["content"]
-         if role == "user":
-             chat_prompt += f"User: {content}\n"
-         else:
-             chat_prompt += f"Assistant: {content}\n"
-     return chat_prompt.strip()
- 
- def stream_response(user_message: str, messages: list) -> Iterator[list]:
-     try:
-         print(f"User: {user_message}")
-         prompt = format_history_for_model(messages) + f"\nUser: {user_message}\nAssistant:"
- 
-         # Tokenize
-         inputs = tokenizer(prompt, return_tensors="pt").to(device)
- 
-         # Stream output tokens
-         generated = model.generate(
-             **inputs,
-             max_new_tokens=256,
-             temperature=0.7,
-             do_sample=True,
-             top_p=0.9,
-             repetition_penalty=1.1,
-             pad_token_id=tokenizer.eos_token_id,
-         )
-         output_text = tokenizer.decode(generated[0][inputs["input_ids"].shape[1]:], skip_special_tokens=True)
- 
-         # Send back message
-         messages.append(ChatMessage(role="assistant", content=output_text))
-         yield messages
- 
-     except Exception as e:
-         messages.append(ChatMessage(role="assistant", content=f"Error: {str(e)}"))
-         yield messages
- 
- def user_message(msg: str, history: list):
-     history.append(ChatMessage(role="user", content=msg))
-     return "", history
- 
- # --- UI ---
- with gr.Blocks(theme=gr.themes.Citrus(), fill_height=True) as demo:
-     gr.Markdown("# Chat with Metatune GPT 20B 💭")
- 
-     chatbot = gr.Chatbot(type="messages", label="Metatune 20B Chatbot", render_markdown=True)
-     with gr.Row():
-         input_box = gr.Textbox(label="Message", placeholder="Type your message here...")
-         clear_button = gr.Button("Clear")
- 
-     msg_store = gr.State("")
+ model = AutoModelForCausalLM.from_pretrained(checkpoint, torch_dtype="auto").to(device)
  
-     input_box.submit(
-         lambda msg: (msg, msg, ""),
-         inputs=[input_box],
-         outputs=[msg_store, input_box, input_box],
-         queue=False,
-     ).then(
-         user_message,
-         inputs=[msg_store, chatbot],
-         outputs=[input_box, chatbot],
-         queue=False,
-     ).then(
-         stream_response,
-         inputs=[msg_store, chatbot],
-         outputs=chatbot,
-     )
+ def predict(message, history):
+     history.append({"role": "user", "content": message})
+     input_text = tokenizer.apply_chat_template(history, tokenize=False)
+     inputs = tokenizer.encode(input_text, return_tensors="pt").to(device)
+     outputs = model.generate(inputs, max_new_tokens=3200, temperature=0.2, top_p=0.9, do_sample=True)
+     decoded = tokenizer.decode(outputs[0])
+     response = decoded.split("<|im_start|>assistant\n")[-1].split("<|im_end|>")[0]
+     return response
  
-     clear_button.click(lambda: ([], "", ""), outputs=[chatbot, input_box, msg_store])
+ demo = gr.ChatInterface(predict, type="messages")
  
- if __name__ == "__main__":
-     demo.launch(debug=True)
+ demo.launch()
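
For reference, a minimal sketch of the inputs that gr.ChatInterface(type="messages") passes to the new predict() function; the openai-style role/content dicts shown here are an assumption based on Gradio's messages format, not part of the committed file:

# Illustration only: shapes of the arguments predict(message, history) receives
message = "Hello!"
history = [
    {"role": "user", "content": "Hi"},
    {"role": "assistant", "content": "Hello! How can I help?"},
]
# predict() appends the new user turn to history, applies the tokenizer's chat
# template, generates up to 3200 new tokens, and returns only the last assistant turn.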