# chat-d32-demo / app.py
# Author: burtenshaw (HF Staff)
# Last commit: 8f42a5a "Update app.py" (verified)
# Size: 1.46 kB
import spaces
import gradio as gr
import torch
from transformers import AutoModelForCausalLM, AutoTokenizer
# Run on the GPU when PyTorch can see one; otherwise fall back to the CPU.
device = torch.device("cuda") if torch.cuda.is_available() else torch.device("cpu")
def load_model(model_id="karpathy/nanochat-d32", revision="refs/pr/1"):
    """Load the tokenizer and causal-LM weights and move the model to ``device``.

    Args:
        model_id: Hugging Face Hub repository id to load from.
        revision: Git revision (branch, tag, commit or PR ref) of the repo;
            defaults to the ``refs/pr/1`` ref.

    Returns:
        A ``(tokenizer, model)`` tuple. The model is loaded in bfloat16,
        placed on the module-level ``device`` and switched to eval mode.
    """
    # trust_remote_code=False: only use model classes shipped with
    # transformers itself; never execute code from the model repository.
    tokenizer = AutoTokenizer.from_pretrained(
        model_id,
        trust_remote_code=False,
        revision=revision,
    )
    model = AutoModelForCausalLM.from_pretrained(
        model_id,
        trust_remote_code=False,
        dtype=torch.bfloat16,
        revision=revision,
    ).to(device)
    # Inference only: turn off training-time behavior such as dropout.
    model.eval()
    return tokenizer, model


# Load once at import time so every chat request reuses the same weights.
tokenizer, model = load_model()
@spaces.GPU
def generate(prompt, history):
    """Generate the assistant's reply to ``prompt`` given the earlier chat turns.

    Args:
        prompt: The new user message.
        history: Earlier turns as a list of ``{"role": ..., "content": ...}``
            dicts (the ``type="messages"`` format the ChatInterface uses).
            May be empty or ``None``.

    Returns:
        The decoded reply text with special tokens removed.
    """
    # Forward only role/content to the chat template; Gradio message dicts
    # may carry extra keys (e.g. metadata) the template does not need.
    messages = [
        {"role": turn["role"], "content": turn["content"]}
        for turn in history or []
    ]
    messages.append({"role": "user", "content": prompt})

    inputs = tokenizer.apply_chat_template(
        messages,
        add_generation_prompt=True,
        tokenize=True,
        return_tensors="pt",
        return_dict=True,
    ).to(device)

    with torch.no_grad():
        outputs = model.generate(**inputs, max_new_tokens=512)

    # generate() returns prompt + continuation; keep only the new tokens.
    prompt_length = inputs.input_ids.shape[1]
    generated_tokens = outputs[0, prompt_length:]
    return tokenizer.decode(generated_tokens, skip_special_tokens=True)
demo = gr.ChatInterface(
    fn=generate,
    type="messages",
    examples=["hello", "hola", "merhaba"],
    title="NanoChat",
)

# Start the server only when executed as a script (e.g. `python app.py`),
# so importing this module does not launch the web UI as a side effect.
if __name__ == "__main__":
    demo.launch()