Spaces:
Paused
Paused
| import json | |
| import torch | |
| import streamlit as st | |
| from transformers import AutoModelForCausalLM, AutoTokenizer | |
| from transformers.generation.utils import GenerationConfig | |
# App chrome. Keep set_page_config as the first Streamlit call in the script;
# the same string is used for the browser tab title and the on-page heading.
_PAGE_TITLE = "Baichuan-13B-Chat"
st.set_page_config(page_title=_PAGE_TITLE)
st.title(_PAGE_TITLE)
@st.cache_resource
def init_model(model_path: str = "baichuan-inc/Baichuan-13B-Chat"):
    """Load the Baichuan chat model and its tokenizer.

    Streamlit re-executes the whole script on every user interaction, and
    ``main()`` calls this function on each of those reruns. ``st.cache_resource``
    makes the first call load the weights and every later call with the same
    ``model_path`` return the same cached objects. Without it, the
    13B-parameter model would be reloaded on every chat turn.

    Args:
        model_path: Hugging Face Hub repo id (or local directory) holding the
            model weights, generation config and tokenizer files. Defaults to
            the previously hard-coded ``"baichuan-inc/Baichuan-13B-Chat"``.

    Returns:
        A ``(model, tokenizer)`` tuple.
    """
    # trust_remote_code is needed because the repo ships its own model and
    # tokenizer classes (presumably where the non-standard `model.chat(...)`
    # used by main() comes from -- confirm against the model card).
    model = AutoModelForCausalLM.from_pretrained(
        model_path,
        torch_dtype=torch.float16,
        device_map="auto",
        trust_remote_code=True,
    )
    # Use the generation settings published alongside the weights rather
    # than the transformers defaults.
    model.generation_config = GenerationConfig.from_pretrained(model_path)
    tokenizer = AutoTokenizer.from_pretrained(
        model_path,
        use_fast=False,
        trust_remote_code=True,
    )
    return model, tokenizer
def clear_chat_history():
    """Forget the stored conversation so the next rerun starts a fresh chat.

    This is the ``on_click`` callback of the "Reset Chat" button. After it
    runs, ``init_chat_history`` finds no ``messages`` key and creates a new
    empty list. ``pop`` with a default makes the reset a no-op, rather than
    an exception, when there is no history to clear.
    """
    st.session_state.pop("messages", None)
def init_chat_history():
    """Render the greeting and any earlier turns; return the session's history.

    Returns:
        The ``st.session_state.messages`` list itself, not a copy. Callers
        append new turns to it and those turns persist across reruns. Each
        entry is a ``{"role": "user" | "assistant", "content": str}`` dict.
    """
    # Avatars are ASCII escapes so this source cannot be mis-encoded again.
    # A garbled avatar string is not an emoji, so st.chat_message would try
    # to load it as an image.
    assistant_avatar = "\U0001F916"  # robot face
    user_avatar = "\U0001F9D1\u200D\U0001F4BB"  # technologist (ZWJ sequence)

    with st.chat_message("assistant", avatar=assistant_avatar):
        st.markdown(
            "Greetings! I am the BaiChuan large language model, "
            "delighted to assist you.\U0001F970"  # smiling face with hearts
        )

    if "messages" in st.session_state:
        # The page is rebuilt from scratch on each rerun, so replay every
        # stored turn.
        for message in st.session_state.messages:
            avatar = user_avatar if message["role"] == "user" else assistant_avatar
            with st.chat_message(message["role"], avatar=avatar):
                st.markdown(message["content"])
    else:
        st.session_state.messages = []

    return st.session_state.messages
def main():
    """Run one pass of the chat page.

    Gets the model from ``init_model()`` and redraws the conversation. If the
    user submitted a prompt on this run, it streams the assistant's reply
    into the page and appends both turns to the session history.
    """
    # ASCII escapes so this source cannot be mis-encoded again. A garbled
    # avatar string is not an emoji and would be loaded as an image path.
    user_avatar = "\U0001F9D1\u200D\U0001F4BB"  # technologist (ZWJ sequence)
    assistant_avatar = "\U0001F916"  # robot face

    model, tokenizer = init_model()
    messages = init_chat_history()

    if prompt := st.chat_input("Shift + Enter for a new line, Enter to send"):
        with st.chat_message("user", avatar=user_avatar):
            st.markdown(prompt)
        messages.append({"role": "user", "content": prompt})
        print(f"[user] {prompt}", flush=True)

        # Availability cannot change mid-stream, so check it once.
        release_mps_cache = torch.backends.mps.is_available()
        # Pre-bind so an empty stream stores "" instead of raising NameError
        # at the append below.
        response = ""
        with st.chat_message("assistant", avatar=assistant_avatar):
            placeholder = st.empty()
            # Each yielded value replaces the previous one on screen, and the
            # last one is stored as the full reply. Each yield is therefore
            # presumably the cumulative text so far (confirm against
            # model.chat).
            for response in model.chat(tokenizer, messages, stream=True):
                placeholder.markdown(response)
                if release_mps_cache:
                    torch.mps.empty_cache()
        messages.append({"role": "assistant", "content": response})
        print(json.dumps(messages, ensure_ascii=False), flush=True)

        # Only rendered on runs that produced a reply.
        st.button("Reset Chat", on_click=clear_chat_history)


if __name__ == "__main__":
    # Entry point when the script runs as __main__ (e.g. via `streamlit run`).
    # A plain import of this module does not start the app.
    main()