try: import spaces except ImportError: # Local run: define dummy decorator class spaces: @staticmethod def GPU(duration=10): def dummy(func): return func return dummy import argparse import json import time import gradio as gr from filelock import FileLock from PIL import Image import threading from utils import ( build_logger, server_error_msg, violates_moderation, moderation_msg, get_log_filename, ) from conversation import Conversation from model import ( FullSequenceStreamer, get_model, ) logger = build_logger("dimple", "dimple.log") no_change_btn = gr.Button() enable_btn = gr.Button(interactive=True) disable_btn = gr.Button(interactive=False) @spaces.GPU(duration=10) def make_zerogpu_happy(): pass def write2file(path, content): lock = FileLock(f"{path}.lock") with lock: with open(path, "a") as fout: fout.write(content) model, processor = get_model("cuda:0") get_window_url_params = """ function() { const params = new URLSearchParams(window.location.search); url_params = Object.fromEntries(params); console.log(url_params); return url_params; } """ def init_state(state=None): if state is not None: del state return Conversation() def vote_last_response(state, liked, request: gr.Request): conv_data = { "tstamp": round(time.time(), 4), "like": liked, "model": '"rp-yu/Dimple-7B"', "state": state.dict(), "ip": request.client.host, } write2file(get_log_filename(), json.dumps(conv_data) + "\n") def upvote_last_response(state, request: gr.Request): logger.info(f"upvote. ip: {request.client.host}") vote_last_response(state, True, request) textbox = gr.MultimodalTextbox(value=None, interactive=True) return (textbox,) + (disable_btn,) * 3 def downvote_last_response(state, request: gr.Request): logger.info(f"downvote. ip: {request.client.host}") vote_last_response(state, False, request) textbox = gr.MultimodalTextbox(value=None, interactive=True) return (textbox,) + (disable_btn,) * 3 def vote_selected_response( state, request: gr.Request, data: gr.LikeData ): logger.info( f"Vote: {data.liked}, index: {data.index}, value: {data.value} , ip: {request.client.host}" ) conv_data = { "tstamp": round(time.time(), 4), "like": data.liked, "index": data.index, "model": 'rp-yu/Dimple-7B', "state": state.dict(), "ip": request.client.host, } write2file(get_log_filename(), json.dumps(conv_data) + "\n") return def flag_last_response(state, request: gr.Request): logger.info(f"flag. ip: {request.client.host}") vote_last_response(state, "flag", request) textbox = gr.MultimodalTextbox(value=None, interactive=True) return (textbox,) + (disable_btn,) * 3 def regenerate(state, image_process_mode, request: gr.Request): logger.info(f"regenerate. ip: {request.client.host}") # state.messages[-1][-1] = None state.update_message(Conversation.ASSISTANT, content='', image=None, idx=-1) prev_human_msg = state.messages[-2] if type(prev_human_msg[1]) in (tuple, list): prev_human_msg[1] = (*prev_human_msg[1][:2], image_process_mode) state.skip_next = False textbox = gr.MultimodalTextbox(value=None, interactive=True) return (state, state.to_gradio_chatbot(), textbox) + (disable_btn,) * 5 def clear_history(request: gr.Request): logger.info(f"clear_history. ip: {request.client.host}") state = init_state() textbox = gr.MultimodalTextbox(value=None, interactive=True) return (state, state.to_gradio_chatbot(), textbox) + (disable_btn,) * 5 def add_text(state, message, system_prompt, request: gr.Request): print(f"state: {state}") if not state: state = init_state() images = message.get("files", []) text = message.get("text", "").strip() logger.info(f"add_text. ip: {request.client.host}. len: {len(text)}") # import pdb; pdb.set_trace() textbox = gr.MultimodalTextbox(value=None, interactive=False) if len(text) <= 0 and len(images) == 0: state.skip_next = True return (state, state.to_gradio_chatbot(), textbox) + (no_change_btn,) * 5 if args.moderate: flagged = violates_moderation(text) if flagged: state.skip_next = True textbox = gr.MultimodalTextbox( value={"text": moderation_msg}, interactive=True ) return (state, state.to_gradio_chatbot(), textbox) + (no_change_btn,) * 5 images = [Image.open(path).convert("RGB") for path in images] if len(images) > 0 and len(state.get_images(source=state.USER)) > 0: state = init_state(state) state.set_system_message(system_prompt) state.append_message(Conversation.USER, text, images) state.skip_next = False return (state, state.to_gradio_chatbot(), textbox) + ( disable_btn, ) * 5 def http_bot( state, temperature, top_p, p_threshold, alg_temp, max_new_tokens, steps, alg, ): start_tstamp = time.time() if hasattr(state, "skip_next") and state.skip_next: # This generate call is skipped due to invalid inputs yield ( state, state.to_gradio_chatbot(), gr.MultimodalTextbox(interactive=False), ) + (no_change_btn,) * 5 return all_images = state.get_images(source=state.USER) all_image_paths = [state.save_image(image) for image in all_images] if len(all_images) == 0: all_images = None messages = state.get_prompt() text = processor.apply_chat_template( messages, tokenize=False, add_generation_prompt=True, add_vision_id=False ) inputs = processor( text=text, images=all_images, videos=None, padding="longest", return_tensors="pt", ).to(model.device) input_ids = inputs.pop("input_ids") streamer = FullSequenceStreamer( processor.tokenizer, timeout=10, skip_special_tokens=True, ) def run_generate(): output = model.diffusion_generate( input_ids, max_new_tokens=int(max_new_tokens), output_history=True, return_dict_in_generate=True, steps=int(steps), temperature=float(temperature), top_p=float(top_p), alg=alg, alg_temp = float(alg_temp), use_cache=True, alg_p_threshold=float(p_threshold), use_original_confidence=True, decoding_pipeline="dim", streamer = streamer, **inputs ) thread = threading.Thread(target=run_generate) thread.start() logger.info(f"==== wait for first token ====\n") state.append_message(Conversation.ASSISTANT, state.streaming_placeholder) yield ( state, state.to_gradio_chatbot(), gr.MultimodalTextbox(interactive=False), ) + (disable_btn,) * 5 try: # Stream output for ans in streamer: if len(ans) > 1: ans = "\n".join(ans) else: ans = ans[0] state.update_message(Conversation.ASSISTANT, ans, None) yield ( state, state.to_gradio_chatbot(), gr.MultimodalTextbox(interactive=False), ) + (disable_btn,) * 5 except Exception as e: state.update_message(Conversation.ASSISTANT, server_error_msg, None) yield ( state, state.to_gradio_chatbot(), gr.MultimodalTextbox(interactive=True), ) + ( disable_btn, disable_btn, disable_btn, enable_btn, enable_btn, ) return state.end_of_current_turn() yield ( state, state.to_gradio_chatbot(), gr.MultimodalTextbox(interactive=True), ) + (enable_btn,) * 5 finish_tstamp = time.time() logger.info(f"{ans}") data = { "tstamp": round(finish_tstamp, 4), "like": None, "model": "rp-yu/Dimple-7B", "start": round(start_tstamp, 4), "finish": round(start_tstamp, 4), "state": state.dict(), "images": all_image_paths, } write2file(get_log_filename(), json.dumps(data) + "\n") title_html = """
Dimple: Discrete Diffusion Multimodal Large Language Model with Parallel Decoding
[đ Dimple Paper]