import gradio as gr
import torch
from transformers import AutoTokenizer, AutoModelForCausalLM, pipeline
import spaces
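# `spaces` is the Hugging Face Spaces helper package; its @spaces.GPU decorator
# requests a GPU slice for the decorated function on ZeroGPU hardware.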
# === Load the model and tokenizer ===
model_path = "nvidia/Nemotron-Mini-4B-Instruct"
tokenizer = AutoTokenizer.from_pretrained(model_path)
# Fall back to the EOS token for padding if the tokenizer ships without a pad token.
if tokenizer.pad_token_id is None:
    tokenizer.pad_token_id = tokenizer.eos_token_id
model = AutoModelForCausalLM.from_pretrained(
    model_path,
    device_map={"": "cpu"},
    torch_dtype=torch.float32,
)
pipe = pipeline(
    "text-generation",
    model=model,
    tokenizer=tokenizer,
    model_kwargs={"torch_dtype": torch.float32},
    device_map={"": "cpu"},
)
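# The UI exposes two inference paths: the `pipeline` wrapper above, or a direct
# `model.generate` call in `generate_response` below (toggled by "Use Pipeline").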
# === Prompt construction ===
def create_prompt(system_message, user_message, tool_definition="", context=""):
    # Fold the optional tool definition and context into the system section.
    extras = "".join(f"\n\n{x}" for x in (tool_definition, context) if x)
    return f"{system_message}{extras}\n\nUser: {user_message}\nAssistant:"
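# Illustrative rendering (assuming a plain system message, no tool or context):
#   You are a helpful assistant.
#
#   User: Hello!
#   Assistant: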
# === Response generation ===
@spaces.GPU(duration=94)
def generate_response(message, history, system_message, max_tokens, temperature, top_p, do_sample, use_pipeline=False, tool_definition="", context=""):
    full_prompt = create_prompt(system_message, message, tool_definition, context)
    if use_pipeline:
        response = pipe(
            full_prompt,
            max_new_tokens=max_tokens,
            temperature=temperature,
            top_p=top_p,
            do_sample=do_sample,
        )[0]["generated_text"]
    else:
        # Truncate the prompt so that prompt plus generation fits the context window;
        # clamp to at least 1 in case max_tokens exceeds the model's maximum length.
        max_model_length = getattr(model.config, "max_position_embeddings", 8192)
        max_length = max(1, max_model_length - max_tokens)
        inputs = tokenizer(full_prompt, return_tensors="pt", padding=True, truncation=True, max_length=max_length)
        input_ids = inputs["input_ids"].to(model.device)
        attention_mask = inputs["attention_mask"].to(model.device)
        with torch.no_grad():
            output_ids = model.generate(
                input_ids,
                attention_mask=attention_mask,
                max_new_tokens=max_tokens,
                temperature=temperature,
                top_p=top_p,
                do_sample=do_sample,
            )
        response = tokenizer.decode(output_ids[0], skip_special_tokens=True)
    # Both paths return the prompt followed by the completion; keep only the completion.
    assistant_response = response.split("Assistant:")[-1].strip()
    if tool_definition and "<toolcall>" in assistant_response:
        tool_call = assistant_response.split("<toolcall>")[1].split("</toolcall>")[0]
        assistant_response += f"\n\nTool Call: {tool_call}\n\nNote: This is a simulated tool call."
    return assistant_response
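# Illustrative direct call with hypothetical arguments (the Gradio callbacks
# below normally supply these values from UI state):
#   generate_response("Hello!", [], "You are a helpful assistant.",
#                     max_tokens=128, temperature=0.7, top_p=0.9, do_sample=True)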
# === Gradio callbacks ===
def user(user_message, history):
    # Append the user turn with a pending reply and clear the input box.
    return "", history + [[user_message, None]]

def bot(history, system_prompt, max_length, temperature, top_p, advanced_checkbox, use_pipeline, tool_definition):
    user_message = history[-1][0]
    # Sampling is enabled only when the advanced settings panel is open.
    do_sample = advanced_checkbox
    bot_message = generate_response(user_message, history, system_prompt, max_length, temperature, top_p, do_sample, use_pipeline, tool_definition)
    history[-1][1] = bot_message
    return history
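# `history` uses gr.Chatbot's default pair format, [[user_msg, bot_msg], ...]:
# `user` appends a pair with a pending reply, which `bot` then fills in.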
# === Gradio interface ===
with gr.Blocks() as demo:
    with gr.Row():
        with gr.Column(scale=2):
            system_prompt = gr.TextArea(label="System Prompt", placeholder="add context here...", lines=5)
            user_input = gr.TextArea(label="Input", placeholder="Talk with Nemotron!", lines=2)
            advanced_checkbox = gr.Checkbox(label="Advanced Settings", value=False)
            with gr.Column(visible=False) as advanced_settings:
                max_length = gr.Slider(label="📏Max New Tokens", minimum=12, maximum=64000, value=2048, step=1)
                temperature = gr.Slider(label="🌡️Temperature", minimum=0.01, maximum=1.0, value=0.7, step=0.01)
                top_p = gr.Slider(label="⚛️Top-p (Nucleus Sampling)", minimum=0.1, maximum=1.0, value=0.9, step=0.01)
                use_pipeline = gr.Checkbox(label="Use Pipeline", value=False)
            use_tool = gr.Checkbox(label="Use Function Calling", value=False)
            with gr.Column(visible=False) as tool_options:
                tool_definition = gr.Code(
                    label="Tool Definition (JSON)",
                    value='{\n "name": "example_tool",\n "description": "A dummy tool.",\n "parameters": {\n "param1": {"type": "string", "description": "Parameter 1"}\n },\n "required": ["param1"]\n}',
                    lines=15,
                    language="json",
                )
            generate_button = gr.Button(value="Send")
        with gr.Column(scale=2):
            chatbot = gr.Chatbot(label="Nemotron_Mini")
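    # Wire the Send button: `user` echoes the message into the chat immediately
    # (queue=False), then `bot` generates and fills in the model reply.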
    generate_button.click(
        user,
        [user_input, chatbot],
        [user_input, chatbot],
        queue=False,
    ).then(
        bot,
        [chatbot, system_prompt, max_length, temperature, top_p, advanced_checkbox, use_pipeline, tool_definition],
        chatbot,
    )
    advanced_checkbox.change(
        fn=lambda x: gr.update(visible=x),
        inputs=[advanced_checkbox],
        outputs=[advanced_settings],
    )
    use_tool.change(
        fn=lambda x: gr.update(visible=x),
        inputs=[use_tool],
        outputs=[tool_options],
    )
if __name__ == "__main__":
    demo.queue()
    demo.launch(ssr_mode=False, mcp_server=True)