Spaces:
				
			
			
	
			
			
		Runtime error
		
	
	
	
			
			
	
	
	
	
		
		
		Runtime error
		
	Commit 
							
							·
						
						2ca0c5e
	
1
								Parent(s):
							
							d96a4ed
								
initial commit
Browse files- .gitignore +3 -0
- app.py +209 -52
- chatstate.py +94 -0
- img/bot.png +0 -0
- img/gemma.png +0 -0
- img/keras_logo_k.png +0 -0
- img/llama.png +0 -0
- img/mistral.png +0 -0
- img/usr.png +0 -0
- img/vicuna.png +0 -0
- models.py +105 -0
- requirements.txt +6 -1
    	
        .gitignore
    ADDED
    
    | @@ -0,0 +1,3 @@ | |
|  | |
|  | |
|  | 
|  | |
| 1 | 
            +
            .DS_Store
         | 
| 2 | 
            +
            .vscode
         | 
| 3 | 
            +
            __pycache__
         | 
    	
        app.py
    CHANGED
    
    | @@ -1,63 +1,220 @@ | |
|  | |
|  | |
|  | |
|  | |
| 1 | 
             
            import gradio as gr
         | 
| 2 | 
            -
            from  | 
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
| 3 |  | 
| 4 | 
            -
             | 
| 5 | 
            -
             | 
| 6 | 
            -
            "" | 
| 7 | 
            -
             | 
|  | |
|  | |
| 8 |  | 
|  | |
|  | |
|  | |
| 9 |  | 
| 10 | 
            -
             | 
|  | |
| 11 | 
             
                message,
         | 
| 12 | 
            -
                 | 
|  | |
|  | |
|  | |
| 13 | 
             
                system_message,
         | 
| 14 | 
            -
                max_tokens,
         | 
| 15 | 
            -
                temperature,
         | 
| 16 | 
            -
                top_p,
         | 
| 17 | 
             
            ):
         | 
| 18 | 
            -
                 | 
| 19 | 
            -
             | 
| 20 | 
            -
                 | 
| 21 | 
            -
             | 
| 22 | 
            -
             | 
| 23 | 
            -
             | 
| 24 | 
            -
             | 
| 25 | 
            -
             | 
| 26 | 
            -
             | 
| 27 | 
            -
             | 
| 28 | 
            -
                 | 
| 29 | 
            -
             | 
| 30 | 
            -
             | 
| 31 | 
            -
             | 
| 32 | 
            -
             | 
| 33 | 
            -
             | 
| 34 | 
            -
             | 
| 35 | 
            -
             | 
| 36 | 
            -
             | 
| 37 | 
            -
             | 
| 38 | 
            -
             | 
| 39 | 
            -
             | 
| 40 | 
            -
                     | 
| 41 | 
            -
             | 
| 42 | 
            -
             | 
| 43 | 
            -
            "" | 
| 44 | 
            -
             | 
| 45 | 
            -
            "" | 
| 46 | 
            -
             | 
| 47 | 
            -
                 | 
| 48 | 
            -
                 | 
| 49 | 
            -
             | 
| 50 | 
            -
             | 
| 51 | 
            -
             | 
| 52 | 
            -
             | 
| 53 | 
            -
             | 
| 54 | 
            -
             | 
| 55 | 
            -
             | 
| 56 | 
            -
             | 
| 57 | 
            -
             | 
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
| 58 | 
             
                    ),
         | 
| 59 | 
            -
             | 
| 60 | 
            -
             | 
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
| 61 |  | 
| 62 |  | 
| 63 | 
             
            if __name__ == "__main__":
         | 
|  | |
| 1 | 
            +
            import os
         | 
| 2 | 
            +
             | 
| 3 | 
            +
            os.environ["KERAS_BACKEND"] = "jax"
         | 
| 4 | 
            +
             | 
| 5 | 
             
            import gradio as gr
         | 
| 6 | 
            +
            from gradio import ChatMessage
         | 
| 7 | 
            +
            import keras_hub
         | 
| 8 | 
            +
             | 
| 9 | 
            +
            from chatstate import ChatState
         | 
| 10 | 
            +
            from models import (
         | 
| 11 | 
            +
                model_presets,
         | 
| 12 | 
            +
                load_model,
         | 
| 13 | 
            +
                model_labels,
         | 
| 14 | 
            +
                preset_to_website_url,
         | 
| 15 | 
            +
                get_appropriate_chat_template,
         | 
| 16 | 
            +
            )
         | 
| 17 | 
            +
             | 
| 18 | 
            +
            model_labels_list = list(model_labels)
         | 
| 19 | 
            +
             | 
| 20 | 
            +
            # lod a warm up (compile) all the models
         | 
| 21 | 
            +
            models = []
         | 
| 22 | 
            +
            for preset in model_presets:
         | 
| 23 | 
            +
                model = load_model(preset)
         | 
| 24 | 
            +
                chat_template = get_appropriate_chat_template(preset)
         | 
| 25 | 
            +
                chat_state = ChatState(model, "", chat_template)
         | 
| 26 | 
            +
                prompt, response = chat_state.send_message("Hello")
         | 
| 27 | 
            +
                print("model " + preset + "loaded and initialized.")
         | 
| 28 | 
            +
                print("The model responded: " + response)
         | 
| 29 | 
            +
             | 
| 30 | 
            +
            models = [load_model(preset) for preset in model_presets]
         | 
| 31 | 
            +
            # model = keras_hub.models.Llama3CausalLM.from_preset(
         | 
| 32 | 
            +
            #     "hf://meta-llama/Llama-3.2-1B-Instruct", dtype="bfloat16"
         | 
| 33 | 
            +
            # )
         | 
| 34 | 
            +
            # models = [model, model]
         | 
| 35 | 
            +
             | 
| 36 | 
            +
             | 
| 37 | 
            +
            def chat_turn_assistant_1(
         | 
| 38 | 
            +
                model,
         | 
| 39 | 
            +
                message,
         | 
| 40 | 
            +
                history,
         | 
| 41 | 
            +
                system_message,
         | 
| 42 | 
            +
                preset,
         | 
| 43 | 
            +
                # max_tokens,
         | 
| 44 | 
            +
                # temperature,
         | 
| 45 | 
            +
                # top_p,
         | 
| 46 | 
            +
            ):
         | 
| 47 | 
            +
                chat_template = get_appropriate_chat_template(preset)
         | 
| 48 | 
            +
                chat_state = ChatState(model, system_message, chat_template)
         | 
| 49 |  | 
| 50 | 
            +
                for msg in history:
         | 
| 51 | 
            +
                    msg = ChatMessage(**msg)
         | 
| 52 | 
            +
                    if msg.role == "user":
         | 
| 53 | 
            +
                        chat_state.add_to_history_as_user(msg.content)
         | 
| 54 | 
            +
                    elif msg.role == "assistant":
         | 
| 55 | 
            +
                        chat_state.add_to_history_as_model(msg.content)
         | 
| 56 |  | 
| 57 | 
            +
                prompt, response = chat_state.send_message(message)
         | 
| 58 | 
            +
                history.append(ChatMessage(role="assistant", content=response))
         | 
| 59 | 
            +
                return history
         | 
| 60 |  | 
| 61 | 
            +
             | 
| 62 | 
            +
            def chat_turn_assistant(
         | 
| 63 | 
             
                message,
         | 
| 64 | 
            +
                sel1,
         | 
| 65 | 
            +
                history1,
         | 
| 66 | 
            +
                sel2,
         | 
| 67 | 
            +
                history2,
         | 
| 68 | 
             
                system_message,
         | 
| 69 | 
            +
                # max_tokens,
         | 
| 70 | 
            +
                # temperature,
         | 
| 71 | 
            +
                # top_p,
         | 
| 72 | 
             
            ):
         | 
| 73 | 
            +
                history1 = chat_turn_assistant_1(
         | 
| 74 | 
            +
                    models[sel1], message, history1, system_message, model_presets[sel1]
         | 
| 75 | 
            +
                )
         | 
| 76 | 
            +
                history2 = chat_turn_assistant_1(
         | 
| 77 | 
            +
                    models[sel2], message, history2, system_message, model_presets[sel2]
         | 
| 78 | 
            +
                )
         | 
| 79 | 
            +
                return "", history1, history2
         | 
| 80 | 
            +
             | 
| 81 | 
            +
             | 
| 82 | 
            +
            def chat_turn_user_1(message, history):
         | 
| 83 | 
            +
                history.append(ChatMessage(role="user", content=message))
         | 
| 84 | 
            +
                return history
         | 
| 85 | 
            +
             | 
| 86 | 
            +
             | 
| 87 | 
            +
            def chat_turn_user(message, history1, history2):
         | 
| 88 | 
            +
                history1 = chat_turn_user_1(message, history1)
         | 
| 89 | 
            +
                history2 = chat_turn_user_1(message, history2)
         | 
| 90 | 
            +
                return "", history1, history2
         | 
| 91 | 
            +
             | 
| 92 | 
            +
             | 
| 93 | 
            +
            def bot_icon_select(model_name):
         | 
| 94 | 
            +
                if "gemma" in model_name:
         | 
| 95 | 
            +
                    return "img/gemma.png"
         | 
| 96 | 
            +
                elif "llama" in model_name:
         | 
| 97 | 
            +
                    return "img/llama.png"
         | 
| 98 | 
            +
                elif "vicuna" in model_name:
         | 
| 99 | 
            +
                    return "img/vicuna.png"
         | 
| 100 | 
            +
                elif "mistral" in model_name:
         | 
| 101 | 
            +
                    return "img/mistral.png"
         | 
| 102 | 
            +
                # default
         | 
| 103 | 
            +
                return "img/bot.png"
         | 
| 104 | 
            +
             | 
| 105 | 
            +
             | 
| 106 | 
            +
            def instantiate_chatbots(sel1, sel2):
         | 
| 107 | 
            +
                model_name1 = model_presets[sel1]
         | 
| 108 | 
            +
                chatbot1 = gr.Chatbot(
         | 
| 109 | 
            +
                    type="messages",
         | 
| 110 | 
            +
                    show_label=False,
         | 
| 111 | 
            +
                    avatar_images=("img/usr.png", bot_icon_select(model_name1)),
         | 
| 112 | 
            +
                )
         | 
| 113 | 
            +
                model_name2 = model_presets[sel2]
         | 
| 114 | 
            +
                chatbot2 = gr.Chatbot(
         | 
| 115 | 
            +
                    type="messages",
         | 
| 116 | 
            +
                    show_label=False,
         | 
| 117 | 
            +
                    avatar_images=("img/usr.png", bot_icon_select(model_name2)),
         | 
| 118 | 
            +
                )
         | 
| 119 | 
            +
                return chatbot1, chatbot2
         | 
| 120 | 
            +
             | 
| 121 | 
            +
             | 
| 122 | 
            +
            def instantiate_select_boxes(sel1, sel2, model_labels):
         | 
| 123 | 
            +
                sel1 = gr.Dropdown(
         | 
| 124 | 
            +
                    choices=[(name, i) for i, name in enumerate(model_labels)],
         | 
| 125 | 
            +
                    show_label=False,
         | 
| 126 | 
            +
                    info="<span style='color:black'>Selected model 1:</span> "
         | 
| 127 | 
            +
                    + "<a href='"
         | 
| 128 | 
            +
                    + preset_to_website_url(model_presets[sel1])
         | 
| 129 | 
            +
                    + "'>"
         | 
| 130 | 
            +
                    + preset_to_website_url(model_presets[sel1])
         | 
| 131 | 
            +
                    + "</a>",
         | 
| 132 | 
            +
                    value=sel1,
         | 
| 133 | 
            +
                )
         | 
| 134 | 
            +
                sel2 = gr.Dropdown(
         | 
| 135 | 
            +
                    choices=[(name, i) for i, name in enumerate(model_labels)],
         | 
| 136 | 
            +
                    show_label=False,
         | 
| 137 | 
            +
                    info="<span style='color:black'>Selected model 2:</span> "
         | 
| 138 | 
            +
                    + "<a href='"
         | 
| 139 | 
            +
                    + preset_to_website_url(model_presets[sel2])
         | 
| 140 | 
            +
                    + "'>"
         | 
| 141 | 
            +
                    + preset_to_website_url(model_presets[sel2])
         | 
| 142 | 
            +
                    + "</a>",
         | 
| 143 | 
            +
                    value=sel2,
         | 
| 144 | 
            +
                )
         | 
| 145 | 
            +
                return sel1, sel2
         | 
| 146 | 
            +
             | 
| 147 | 
            +
             | 
| 148 | 
            +
            def instantiate_chatbots_and_select_boxes(sel1, sel2, model_labels):
         | 
| 149 | 
            +
                chatbot1, chatbot2 = instantiate_chatbots(sel1, sel2)
         | 
| 150 | 
            +
                sel1, sel2 = instantiate_select_boxes(sel1, sel2, model_labels)
         | 
| 151 | 
            +
                return sel1, chatbot1, sel2, chatbot2
         | 
| 152 | 
            +
             | 
| 153 | 
            +
             | 
| 154 | 
            +
            with gr.Blocks(fill_width=True, title="Keras demo") as demo:
         | 
| 155 | 
            +
             | 
| 156 | 
            +
                with gr.Row():
         | 
| 157 | 
            +
                    gr.Image(
         | 
| 158 | 
            +
                        "img/keras_logo_k.png",
         | 
| 159 | 
            +
                        width=80,
         | 
| 160 | 
            +
                        height=80,
         | 
| 161 | 
            +
                        min_width=80,
         | 
| 162 | 
            +
                        show_label=False,
         | 
| 163 | 
            +
                        show_download_button=False,
         | 
| 164 | 
            +
                        show_fullscreen_button=False,
         | 
| 165 | 
            +
                        interactive=False,
         | 
| 166 | 
            +
                        scale=0.01,
         | 
| 167 | 
            +
                        container=False,
         | 
| 168 | 
            +
                    )
         | 
| 169 | 
            +
                    gr.HTML(
         | 
| 170 | 
            +
                        "<H2> Battle of the Keras chatbots on TPU</H2>"
         | 
| 171 | 
            +
                        + "All the models are loaded into the TPU memory. "
         | 
| 172 | 
            +
                        + "You can call them at will and compare their answers. <br/>"
         | 
| 173 | 
            +
                        + "The entire chat history is fed to the models at every submission."
         | 
| 174 | 
            +
                        + "This demno is runnig on a Google TPU v5e 2x4 (8 cores).",
         | 
| 175 | 
            +
                    )
         | 
| 176 | 
            +
                with gr.Row():
         | 
| 177 | 
            +
                    sel1, sel2 = instantiate_select_boxes(0, 1, model_labels_list)
         | 
| 178 | 
            +
             | 
| 179 | 
            +
                with gr.Row():
         | 
| 180 | 
            +
                    chatbot1, chatbot2 = instantiate_chatbots(sel1.value, sel2.value)
         | 
| 181 | 
            +
             | 
| 182 | 
            +
                msg = gr.Textbox(
         | 
| 183 | 
            +
                    label="Your message:",
         | 
| 184 | 
            +
                )
         | 
| 185 | 
            +
                with gr.Row():
         | 
| 186 | 
            +
                    gr.ClearButton([msg, chatbot1, chatbot2])
         | 
| 187 | 
            +
                    with gr.Accordion("Additional settings", open=False):
         | 
| 188 | 
            +
                        system_message = gr.Textbox(
         | 
| 189 | 
            +
                            label="Sytem prompt",
         | 
| 190 | 
            +
                            value="You are a helpful assistant and your name is Eliza.",
         | 
| 191 | 
            +
                        )
         | 
| 192 | 
            +
             | 
| 193 | 
            +
                sel1.select(
         | 
| 194 | 
            +
                    lambda sel1, sel2: instantiate_chatbots_and_select_boxes(
         | 
| 195 | 
            +
                        sel1, sel2, model_labels_list
         | 
| 196 | 
             
                    ),
         | 
| 197 | 
            +
                    inputs=[sel1, sel2],
         | 
| 198 | 
            +
                    outputs=[sel1, chatbot1, sel2, chatbot2],
         | 
| 199 | 
            +
                )
         | 
| 200 | 
            +
             | 
| 201 | 
            +
                sel2.select(
         | 
| 202 | 
            +
                    lambda sel1, sel2: instantiate_chatbots_and_select_boxes(
         | 
| 203 | 
            +
                        sel1, sel2, model_labels_list
         | 
| 204 | 
            +
                    ),
         | 
| 205 | 
            +
                    inputs=[sel1, sel2],
         | 
| 206 | 
            +
                    outputs=[sel1, chatbot1, sel2, chatbot2],
         | 
| 207 | 
            +
                )
         | 
| 208 | 
            +
             | 
| 209 | 
            +
                msg.submit(
         | 
| 210 | 
            +
                    chat_turn_user,
         | 
| 211 | 
            +
                    inputs=[msg, chatbot1, chatbot2],
         | 
| 212 | 
            +
                    outputs=[msg, chatbot1, chatbot2],
         | 
| 213 | 
            +
                ).then(
         | 
| 214 | 
            +
                    chat_turn_assistant,
         | 
| 215 | 
            +
                    [msg, sel1, chatbot1, sel2, chatbot2, system_message],
         | 
| 216 | 
            +
                    outputs=[msg, chatbot1, chatbot2],
         | 
| 217 | 
            +
                )
         | 
| 218 |  | 
| 219 |  | 
| 220 | 
             
            if __name__ == "__main__":
         | 
    	
        chatstate.py
    ADDED
    
    | @@ -0,0 +1,94 @@ | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | 
|  | |
| 1 | 
            +
            # chat helper
         | 
| 2 | 
            +
            class ChatState:
         | 
| 3 | 
            +
             | 
| 4 | 
            +
                def __init__(self, model, system="", chat_template="auto"):
         | 
| 5 | 
            +
                    chat_template = (
         | 
| 6 | 
            +
                        type(model).__name__ if chat_template == "auto" else chat_template
         | 
| 7 | 
            +
                    )
         | 
| 8 | 
            +
             | 
| 9 | 
            +
                    if chat_template == "Llama3CausalLM":
         | 
| 10 | 
            +
                        self.__START_TURN_SYSTEM__ = (
         | 
| 11 | 
            +
                            "<|start_header_id|>system<|end_header_id|>\n\n"
         | 
| 12 | 
            +
                        )
         | 
| 13 | 
            +
                        self.__START_TURN_USER__ = (
         | 
| 14 | 
            +
                            "<|start_header_id|>user<|end_header_id|>\n\n"
         | 
| 15 | 
            +
                        )
         | 
| 16 | 
            +
                        self.__START_TURN_MODEL__ = (
         | 
| 17 | 
            +
                            "<|start_header_id|>assistant<|end_header_id|>\n\n"
         | 
| 18 | 
            +
                        )
         | 
| 19 | 
            +
                        self.__END_TURN_SYSTEM__ = "<|eot_id|>"
         | 
| 20 | 
            +
                        self.__END_TURN_USER__ = "<|eot_id|>"
         | 
| 21 | 
            +
                        self.__END_TURN_MODEL__ = "<|eot_id|>"
         | 
| 22 | 
            +
                        print("Using chat template for: Llama")
         | 
| 23 | 
            +
                    elif chat_template == "GemmaCausalLM":
         | 
| 24 | 
            +
                        self.__START_TURN_SYSTEM__ = ""
         | 
| 25 | 
            +
                        self.__START_TURN_USER__ = "<start_of_turn>user\n"
         | 
| 26 | 
            +
                        self.__START_TURN_MODEL__ = "<start_of_turn>model\n"
         | 
| 27 | 
            +
                        self.__END_TURN_SYSTEM__ = "\n"
         | 
| 28 | 
            +
                        self.__END_TURN_USER__ = "<end_of_turn>\n"
         | 
| 29 | 
            +
                        self.__END_TURN_MODEL__ = "<end_of_turn>\n"
         | 
| 30 | 
            +
                        print("Using chat template for: Gemma")
         | 
| 31 | 
            +
                    elif chat_template == "MistralCausalLM":
         | 
| 32 | 
            +
                        self.__START_TURN_SYSTEM__ = ""
         | 
| 33 | 
            +
                        self.__START_TURN_USER__ = "[INST]"
         | 
| 34 | 
            +
                        self.__START_TURN_MODEL__ = ""
         | 
| 35 | 
            +
                        self.__END_TURN_SYSTEM__ = "<s>"
         | 
| 36 | 
            +
                        self.__END_TURN_USER__ = "[/INST]"
         | 
| 37 | 
            +
                        self.__END_TURN_MODEL__ = "</s>"
         | 
| 38 | 
            +
                        print("Using chat template for: Mistral")
         | 
| 39 | 
            +
                    elif chat_template == "Vicuna":
         | 
| 40 | 
            +
                        self.__START_TURN_SYSTEM__ = ""
         | 
| 41 | 
            +
                        self.__START_TURN_USER__ = "USER: "
         | 
| 42 | 
            +
                        self.__START_TURN_MODEL__ = "ASSISTANT: "
         | 
| 43 | 
            +
                        self.__END_TURN_SYSTEM__ = "\n\n"
         | 
| 44 | 
            +
                        self.__END_TURN_USER__ = "\n"
         | 
| 45 | 
            +
                        self.__END_TURN_MODEL__ = "</s>\n"
         | 
| 46 | 
            +
                        print("Using chat template for : Vicuna")
         | 
| 47 | 
            +
                    else:
         | 
| 48 | 
            +
                        assert (0, "Unknown turn tags for this model class")
         | 
| 49 | 
            +
             | 
| 50 | 
            +
                    self.model = model
         | 
| 51 | 
            +
                    self.system = system
         | 
| 52 | 
            +
                    self.history = []
         | 
| 53 | 
            +
             | 
| 54 | 
            +
                def add_to_history_as_user(self, message):
         | 
| 55 | 
            +
                    self.history.append(
         | 
| 56 | 
            +
                        self.__START_TURN_USER__ + message + self.__END_TURN_USER__
         | 
| 57 | 
            +
                    )
         | 
| 58 | 
            +
             | 
| 59 | 
            +
                def add_to_history_as_model(self, message):
         | 
| 60 | 
            +
                    self.history.append(
         | 
| 61 | 
            +
                        self.__START_TURN_MODEL__ + message + self.__END_TURN_MODEL__
         | 
| 62 | 
            +
                    )
         | 
| 63 | 
            +
             | 
| 64 | 
            +
                def get_history(self):
         | 
| 65 | 
            +
                    return "".join([*self.history])
         | 
| 66 | 
            +
             | 
| 67 | 
            +
                def get_full_prompt(self):
         | 
| 68 | 
            +
                    prompt = self.get_history() + self.__START_TURN_MODEL__
         | 
| 69 | 
            +
                    if len(self.system) > 0:
         | 
| 70 | 
            +
                        prompt = (
         | 
| 71 | 
            +
                            self.__START_TURN_SYSTEM__
         | 
| 72 | 
            +
                            + self.system
         | 
| 73 | 
            +
                            + self.__END_TURN_SYSTEM__
         | 
| 74 | 
            +
                            + prompt
         | 
| 75 | 
            +
                        )
         | 
| 76 | 
            +
                    return prompt
         | 
| 77 | 
            +
             | 
| 78 | 
            +
                def send_message(self, message):
         | 
| 79 | 
            +
                    """
         | 
| 80 | 
            +
                    Handles sending a user message and getting a model response.
         | 
| 81 | 
            +
             | 
| 82 | 
            +
                    Args:
         | 
| 83 | 
            +
                        message: The user's message.
         | 
| 84 | 
            +
             | 
| 85 | 
            +
                    Returns:
         | 
| 86 | 
            +
                        The model's response.
         | 
| 87 | 
            +
                    """
         | 
| 88 | 
            +
                    self.add_to_history_as_user(message)
         | 
| 89 | 
            +
                    prompt = self.get_full_prompt()
         | 
| 90 | 
            +
                    response = self.model.generate(
         | 
| 91 | 
            +
                        prompt, max_length=1024, strip_prompt=True
         | 
| 92 | 
            +
                    )
         | 
| 93 | 
            +
                    self.add_to_history_as_model(response)
         | 
| 94 | 
            +
                    return (message, response)
         | 
    	
        img/bot.png
    ADDED
    
    |   | 
    	
        img/gemma.png
    ADDED
    
    |   | 
    	
        img/keras_logo_k.png
    ADDED
    
    |   | 
    	
        img/llama.png
    ADDED
    
    |   | 
    	
        img/mistral.png
    ADDED
    
    |   | 
    	
        img/usr.png
    ADDED
    
    |   | 
    	
        img/vicuna.png
    ADDED
    
    |   | 
    	
        models.py
    ADDED
    
    | @@ -0,0 +1,105 @@ | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | 
|  | |
| 1 | 
            +
            import keras
         | 
| 2 | 
            +
            import keras_hub
         | 
| 3 | 
            +
             | 
| 4 | 
            +
            model_presets = [
         | 
| 5 | 
            +
                "hf://google/gemma-2-instruct-9b-keras",
         | 
| 6 | 
            +
                "hf://meta-llama/Llama-3.1-8B-Instruct",
         | 
| 7 | 
            +
                "hf://google/codegemma-7b-it-keras",
         | 
| 8 | 
            +
                "hf://keras/mistral_instruct_7b_en",
         | 
| 9 | 
            +
                "hf://keras/vicuna_1.5_7b_en",
         | 
| 10 | 
            +
            ]
         | 
| 11 | 
            +
             | 
| 12 | 
            +
            model_labels = map(lambda s: s.removeprefix("hf://"), model_presets)
         | 
| 13 | 
            +
            model_labels = map(lambda s: s.removeprefix("google/"), model_labels)
         | 
| 14 | 
            +
            model_labels = map(lambda s: s.removeprefix("keras/"), model_labels)
         | 
| 15 | 
            +
            model_labels = map(lambda s: s.removeprefix("meta-llama/"), model_labels)
         | 
| 16 | 
            +
             | 
| 17 | 
            +
             | 
| 18 | 
            +
            def preset_to_website_url(preset):
         | 
| 19 | 
            +
                preset = preset.removeprefix("hf://")
         | 
| 20 | 
            +
                url = "http://huggingface.co/" + preset
         | 
| 21 | 
            +
                return url
         | 
| 22 | 
            +
             | 
| 23 | 
            +
             | 
| 24 | 
            +
            def get_appropriate_chat_template(preset):
         | 
| 25 | 
            +
                return "Vicuna" if "vicuna" in preset else "auto"
         | 
| 26 | 
            +
             | 
| 27 | 
            +
             | 
| 28 | 
            +
            def get_default_layout_map(preset_name, device_mesh):
         | 
| 29 | 
            +
                # Llama's default layout map works for mistral and vicuna
         | 
| 30 | 
            +
                # because their transformer layers have the same names.
         | 
| 31 | 
            +
                if (
         | 
| 32 | 
            +
                    "Llama" in preset_name
         | 
| 33 | 
            +
                    or "mistral" in preset_name
         | 
| 34 | 
            +
                    or "vicuna" in preset_name
         | 
| 35 | 
            +
                ):
         | 
| 36 | 
            +
                    return keras_hub.models.Llama3Backbone.get_layout_map(device_mesh)
         | 
| 37 | 
            +
                elif "gemma" in preset_name:
         | 
| 38 | 
            +
                    return keras_hub.models.GemmaBackbone.get_layout_map(device_mesh)
         | 
| 39 | 
            +
             | 
| 40 | 
            +
             | 
| 41 | 
            +
            def log_applied_layout_map(model):
         | 
| 42 | 
            +
                if "Gemma" in type(model):
         | 
| 43 | 
            +
                    transformer_decoder_block_name = "decoder_block_1"
         | 
| 44 | 
            +
                elif "Llama3" in type(model) or "Mistral" in type(model):
         | 
| 45 | 
            +
                    transformer_decoder_block_name = "transformer_layer_1"
         | 
| 46 | 
            +
                else:
         | 
| 47 | 
            +
                    assert (0, "Model type not recognized. Cannot display model layout.")
         | 
| 48 | 
            +
                    # See how layer sharding was applied
         | 
| 49 | 
            +
                    embedding_layer = model.backbone.get_layer("token_embedding")
         | 
| 50 | 
            +
                    print(embedding_layer)
         | 
| 51 | 
            +
                    decoder_block = model.backbone.get_layer(transformer_decoder_block_name)
         | 
| 52 | 
            +
                    print(type(decoder_block))
         | 
| 53 | 
            +
                    for variable in embedding_layer.weights + decoder_block.weights:
         | 
| 54 | 
            +
                        print(
         | 
| 55 | 
            +
                            f"{variable.path:<58}  \
         | 
| 56 | 
            +
                              {str(variable.shape):<16}  \
         | 
| 57 | 
            +
                              {str(variable.value.sharding.spec):<35} \
         | 
| 58 | 
            +
                              {str(variable.dtype)}"
         | 
| 59 | 
            +
                        )
         | 
| 60 | 
            +
             | 
| 61 | 
            +
             | 
| 62 | 
            +
            def load_model(preset):
         | 
| 63 | 
            +
                devices = keras.distribution.list_devices()
         | 
| 64 | 
            +
                device_mesh = keras.distribution.DeviceMesh(
         | 
| 65 | 
            +
                    shape=(1, len(devices)), axis_names=["batch", "model"], devices=devices
         | 
| 66 | 
            +
                )
         | 
| 67 | 
            +
                model_parallel = keras.distribution.ModelParallel(
         | 
| 68 | 
            +
                    layout_map=get_default_layout_map(preset, device_mesh),
         | 
| 69 | 
            +
                    batch_dim_name="batch",
         | 
| 70 | 
            +
                )
         | 
| 71 | 
            +
             | 
| 72 | 
            +
                with model_parallel.scope():
         | 
| 73 | 
            +
                    # These two buggy models need this workaround to be loaded in bfloat16
         | 
| 74 | 
            +
                    if "google/gemma-2-instruct-9b-keras" in preset:
         | 
| 75 | 
            +
                        model = keras_hub.models.GemmaCausalLM(
         | 
| 76 | 
            +
                            backbone=keras_hub.models.GemmaBackbone.from_preset(
         | 
| 77 | 
            +
                                preset, dtype="bfloat16"
         | 
| 78 | 
            +
                            ),
         | 
| 79 | 
            +
                            preprocessor=keras_hub.models.GemmaCausalLMPreprocessor.from_preset(
         | 
| 80 | 
            +
                                preset
         | 
| 81 | 
            +
                            ),
         | 
| 82 | 
            +
                        )
         | 
| 83 | 
            +
                    elif "meta-llama/Llama-3.1-8B-Instruct" in preset:
         | 
| 84 | 
            +
                        model = keras_hub.models.Llama3CausalLM(
         | 
| 85 | 
            +
                            backbone=keras_hub.models.Llama3Backbone.from_preset(
         | 
| 86 | 
            +
                                preset, dtype="bfloat16"
         | 
| 87 | 
            +
                            ),
         | 
| 88 | 
            +
                            preprocessor=keras_hub.models.Llama3CausalLMPreprocessor.from_preset(
         | 
| 89 | 
            +
                                preset
         | 
| 90 | 
            +
                            ),
         | 
| 91 | 
            +
                        )
         | 
| 92 | 
            +
                    else:
         | 
| 93 | 
            +
                        model = keras_hub.models.CausalLM.from_preset(
         | 
| 94 | 
            +
                            preset, dtype="bfloat16"
         | 
| 95 | 
            +
                        )
         | 
| 96 | 
            +
             | 
| 97 | 
            +
                log_applied_layout_map(model)
         | 
| 98 | 
            +
                return model
         | 
| 99 | 
            +
             | 
| 100 | 
            +
             | 
| 101 | 
            +
            # Some small models too
         | 
| 102 | 
            +
            # model1 = keras_hub.models.CausalLM.from_preset("hf://meta-llama/Llama-3.2-1B-Instruct", dtype="bfloat16")
         | 
| 103 | 
            +
            # model2 = keras_hub.models.CausalLM.from_preset("hf://google/gemma-2b-it-keras", dtype="bfloat16")
         | 
| 104 | 
            +
            # model3 = keras_hub.models.CausalLM.from_preset("hf://meta-llama/Llama-3.2-3B-Instruct", dtype="bfloat16")
         | 
| 105 | 
            +
            # keras/gemma_1.1_instruct_7b_en
         | 
    	
        requirements.txt
    CHANGED
    
    | @@ -1 +1,6 @@ | |
| 1 | 
            -
             | 
|  | |
|  | |
|  | |
|  | |
|  | 
|  | |
| 1 | 
            +
            --find-links https://storage.googleapis.com/jax-releases/libtpu_releases.html
         | 
| 2 | 
            +
            jax[tpu]
         | 
| 3 | 
            +
            keras>=3
         | 
| 4 | 
            +
            keras-hub
         | 
| 5 | 
            +
            safetensors
         | 
| 6 | 
            +
            huggingface_hub
         | 
 
			

