import gc
from threading import Thread

import streamlit as st
import torch
from PIL import Image
from transformers import AutoConfig, AutoModelForCausalLM, TextIteratorStreamer

from Qwenov3Config import Qwenov3, Qwenov3Config

# Display name -> checkpoint path; fill in the paths to your checkpoints.
MODEL_MAPPING = {
    'QwenoV3-Pretrain': '',
    'QwenoV3-SFT': '',
}


def unload_model():
    """Release the current model and free GPU memory before loading another checkpoint."""
    for key in ('model', 'tokenizer', 'processor', 'streamer'):
        if key in st.session_state:
            del st.session_state[key]
    torch.cuda.empty_cache()
    gc.collect()


def call_model(info_placeholder, messages, generated_text, message_placeholder, image=None):
    info_placeholder.markdown(f'Using {st.session_state.model_display} for this task')
    if image is not None:
        image = Image.open(image).convert('RGB')
        # Make sure the first user turn carries the image placeholder.
        # '<image>' is assumed here: the literal token was lost from the
        # original source, and this placeholder matches the expansion below.
        if '<image>' not in messages[1]['content']:
            messages[1]['content'] = '<image>\n' + messages[1]['content']
    query_text = st.session_state.tokenizer.apply_chat_template(
        messages,
        tokenize=False,
        add_generation_prompt=True,
        enable_thinking=False
    )
    # Expand the placeholder into the vision span the model expects:
    # <|vision_start|>, one <|image_pad|> per image-token slot, <|vision_end|>.
    if '<image>' in query_text:
        query_text = query_text.replace(
            '<image>',
            '<|vision_start|>'
            + '<|image_pad|>' * st.session_state.model.config.image_pad_num
            + '<|vision_end|>'
        )
    text_inputs = st.session_state.tokenizer(query_text, return_tensors="pt")
    input_ids = text_inputs['input_ids'].to(st.session_state.model.device)
    attention_mask = text_inputs['attention_mask'].to(st.session_state.model.device)
    text_embeds = st.session_state.model.llm_model.get_input_embeddings()(input_ids)
    if image is not None:
        pixel_values = st.session_state.processor(images=image, return_tensors="pt")['pixel_values'].to(
            st.session_state.model.device)
        image_embeds = st.session_state.model.vision_model(pixel_values).last_hidden_state
        # Drop the 5 prefix tokens (CLS + 4 register tokens), keeping only patch tokens.
        patch_embeds = image_embeds[:, 5:, :]
        b, num_patches, hidden_dim = patch_embeds.shape
        # Concatenate every 4 adjacent patches to shorten the vision sequence 4x.
        patch_embeds = patch_embeds.view(b, num_patches // 4, hidden_dim * 4)
        image_features = st.session_state.model.adapter(patch_embeds)
        text_embeds = text_embeds.to(image_features.dtype)
        inputs_embeds = st.session_state.model.merge_input_ids_with_image_features(
            image_features, text_embeds, input_ids)
    else:
        inputs_embeds = text_embeds
    generate_params = dict(
        inputs_embeds=inputs_embeds,
        attention_mask=attention_mask,
        max_new_tokens=st.session_state.max_new_tokens,
        min_new_tokens=st.session_state.min_new_tokens,
        do_sample=True,
        temperature=st.session_state.temperature,
        top_k=st.session_state.top_k,
        top_p=st.session_state.top_p,
        min_p=0.0,
        repetition_penalty=st.session_state.repetition_penalty,
        streamer=st.session_state.streamer,
        eos_token_id=st.session_state.tokenizer.eos_token_id
    )
    # Generate on a worker thread so the streamer can be drained incrementally here.
    thread = Thread(target=st.session_state.model.llm_model.generate, kwargs=generate_params)
    thread.start()
    for new_text in st.session_state.streamer:
        generated_text += new_text
        message_placeholder.markdown(generated_text)
    return generated_text
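# merge_input_ids_with_image_features is defined on the custom Qwenov3 class
# (in Qwenov3Config.py, not shown here). As a minimal sketch, such a merge
# typically scatters the projected vision features into the embedding slots
# occupied by <|image_pad|> tokens -- assuming one feature vector per pad token:
#
#   def merge_input_ids_with_image_features(self, image_features, text_embeds, input_ids):
#       pad_id = self.tokenizer.convert_tokens_to_ids('<|image_pad|>')
#       batch_idx, seq_idx = torch.where(input_ids == pad_id)
#       text_embeds = text_embeds.clone()
#       text_embeds[batch_idx, seq_idx] = image_features.reshape(-1, image_features.shape[-1])
#       return text_embeds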
def ini_message():
    if 'messages' not in st.session_state:
        st.session_state.messages = [
            {"role": "system", "content": "You are QwenoV3, a helpful assistant created by 天烨."},
        ]
    if 'uploaded_image' not in st.session_state:
        st.session_state.uploaded_image = None


def parameter_settings():
    with st.sidebar:
        previous_model = st.session_state.get('model_display', None)
        st.session_state.model_display = st.selectbox(
            "Select model", list(MODEL_MAPPING.keys()),
            index=len(MODEL_MAPPING) - 1, help="Choose which checkpoint to load")
        st.session_state.model_path = MODEL_MAPPING[st.session_state.model_display]
        with st.expander("Generation parameters", expanded=False):
            col1, col2 = st.columns(2)
            with col1:
                st.session_state.temperature = st.slider(
                    "Temperature", 0.0, 2.0, 0.7, 0.1,
                    help="Controls randomness; higher values give more varied replies")
                st.session_state.min_new_tokens = st.number_input(
                    "Min Tokens", min_value=0, max_value=512, value=10,
                    help="Minimum length of the generated text")
                st.session_state.max_new_tokens = st.number_input(
                    "Max Tokens", min_value=1, max_value=4096, value=512,
                    help="Maximum length of the generated text")
            with col2:
                st.session_state.top_p = st.slider(
                    "Top P", 0.0, 1.0, 0.8, 0.1,
                    help="Nucleus sampling; higher values keep a wider range of candidate tokens")
                st.session_state.top_k = st.slider(
                    "Top K", 0, 80, 20, 1,
                    help="Number of candidate tokens sampled from; higher values are more diverse")
                st.session_state.repetition_penalty = st.slider(
                    "Repetition Penalty", 0.0, 2.0, 1.05, 0.1,
                    help="Penalizes repetition; higher values reduce repeated text")
        with st.expander("Image upload", expanded=False):
            st.session_state.uploaded_image = st.file_uploader(
                "Upload an image", type=["jpg", "jpeg", "png"]
            )
            if st.session_state.uploaded_image:
                image = Image.open(st.session_state.uploaded_image)
                width, height = image.size
                # Downscale the sidebar preview only; the model still receives
                # the original file through the processor.
                if width > 256 or height > 256:
                    scale = 256 / max(height, width)
                    new_h, new_w = int(height * scale), int(width * scale)
                    image = image.resize((new_w, new_h), Image.BILINEAR)
                st.image(image, caption="Image preview")
        if st.button("New conversation", help="Starting a new conversation clears the current chat history"):
            st.session_state.uploaded_image = None
            st.session_state.messages = [
                {"role": "system", "content": "You are QwenoV3, a helpful assistant created by 天烨."},
            ]
            st.success("Started a new conversation")
            st.rerun()
        if (previous_model != st.session_state.model_display
                or 'tokenizer' not in st.session_state
                or 'model' not in st.session_state
                or 'processor' not in st.session_state):
            unload_model()
            try:
                with st.spinner('Loading model...'):
                    # Register the custom architecture so from_pretrained can resolve it.
                    AutoConfig.register("Qwenov3", Qwenov3Config)
                    AutoModelForCausalLM.register(Qwenov3Config, Qwenov3)
                    st.session_state.model = AutoModelForCausalLM.from_pretrained(
                        st.session_state.model_path,
                        torch_dtype=torch.bfloat16,
                        device_map="auto",
                        low_cpu_mem_usage=True,
                        trust_remote_code=True
                    )
                    st.session_state.tokenizer = st.session_state.model.tokenizer
                    st.session_state.processor = st.session_state.model.processor
                    st.session_state.streamer = TextIteratorStreamer(
                        st.session_state.tokenizer, skip_prompt=True, skip_special_tokens=True)
            except Exception as e:
                # st.error takes a single message string, not print-style arguments.
                st.error(f'Failed to load model: {e}')
                return
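# Qwenov3Config and Qwenov3 are imported from the local Qwenov3Config module
# (not shown here). For orientation, a config compatible with the
# AutoConfig.register() call above minimally looks like the sketch below;
# every field except model_type is an assumption, not the actual class:
#
#   from transformers import PretrainedConfig
#
#   class Qwenov3Config(PretrainedConfig):
#       model_type = "Qwenov3"  # must match the key passed to AutoConfig.register
#
#       def __init__(self, image_pad_num=64, **kwargs):  # 64 is a placeholder value
#           self.image_pad_num = image_pad_num
#           super().__init__(**kwargs)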

def main():
    st.markdown("""
        <h1 style="text-align: center;">QwenoV3 - Marrying DinoV3 With Qwen3 🫡</h1>

""", unsafe_allow_html=True) ini_message() parameter_settings() for message in st.session_state.messages: if message["role"] == "system": continue with st.chat_message(message["role"]): st.markdown(message["content"]) if user_input := st.chat_input("在这里输入您的问题:", key="chat_input"): with st.chat_message("user"): st.markdown(user_input) st.session_state.messages.append({"role": "user", "content": user_input}) with st.chat_message("assistant"): info_placeholder = st.empty() message_placeholder = st.empty() generated_text = "" try: with torch.inference_mode(): generated_text = call_model(info_placeholder, st.session_state.messages, generated_text, message_placeholder, st.session_state.uploaded_image) st.session_state.messages.append({"role": "assistant", "content": generated_text}) except Exception as e: st.error(f"生成回答时出错: {str(e)}") if __name__ == '__main__': main()