import gradio as gr
import numpy as np
import onnxruntime as ort
import re
import os
# ----------------------------
# CONFIGURATION (UPDATE THESE)
# ----------------------------
MODEL_PATH = "tinystories_lstm.onnx" # Path to your ONNX model
VOCAB_PATH = "vocab.txt" # Path to vocab file
# Download files from the Hub if not present locally (e.g. on Hugging Face Spaces or Colab)
if not os.path.exists(MODEL_PATH) or not os.path.exists(VOCAB_PATH):
    from huggingface_hub import hf_hub_download  # imported here so both branches below can use it
    if not os.path.exists(MODEL_PATH):
        MODEL_PATH = hf_hub_download(
            repo_id="phmd/TinyStories-LSTM-5.5M",
            filename="tinystories_lstm.onnx",
        )
    if not os.path.exists(VOCAB_PATH):
        VOCAB_PATH = hf_hub_download(
            repo_id="phmd/TinyStories-LSTM-5.5M",
            filename="vocab.txt",
        )
# ----------------------------
# LOAD MODEL & VOCAB
# ----------------------------
print("Loading vocabulary...")
with open(VOCAB_PATH, "r", encoding="utf-8") as f:
    vocab = [line.strip() for line in f]
word2idx = {word: idx for idx, word in enumerate(vocab)}
idx2word = {idx: word for word, idx in word2idx.items()}

SOS_IDX = word2idx["<SOS>"]
EOS_IDX = word2idx["<EOS>"]
PAD_IDX = word2idx["<PAD>"]
UNK_IDX = word2idx["<UNK>"]
print("Loading ONNX model...")
ort_session = ort.InferenceSession(MODEL_PATH)
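
# Optional sanity check (assumes the exported graph has a single input):
# confirms the input name ("input") and fixed sequence length (50) that
# generate_story() relies on below. Adjust those constants if your export differs.
_model_input = ort_session.get_inputs()[0]
print(f"ONNX graph input: name={_model_input.name}, shape={_model_input.shape}")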
# ----------------------------
# HELPER FUNCTIONS
# ----------------------------
def tokenize(text):
    # Lower-case, then split punctuation off as separate tokens
    text = re.sub(r'([.,!?])', r' \1 ', text.lower())
    return text.split()
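
# Example: tokenize("Once upon a time, a cat!")
#   -> ['once', 'upon', 'a', 'time', ',', 'a', 'cat', '!']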
def generate_story(prompt, max_new_tokens=64, temperature=0.8):
    if not prompt.strip():
        prompt = "once upon a time"
    max_new_tokens = int(max_new_tokens)  # Gradio sliders pass floats; range() needs an int

    tokens = tokenize(prompt)
    current_seq = [SOS_IDX] + [word2idx.get(t, UNK_IDX) for t in tokens]

    for _ in range(max_new_tokens):
        # Keep the last 50 tokens and right-pad to the model's fixed length of 50
        padded = current_seq[-50:]
        padded = padded + [PAD_IDX] * (50 - len(padded))
        input_tensor = np.array([padded], dtype=np.int64)
        logits = ort_session.run(None, {"input": input_tensor})[0]

        # Read the logits at the last real (non-pad) token position
        last_pos = min(len(current_seq) - 1, 49)
        next_token_logits = logits[0, last_pos, :] / temperature

        # Numerically stable softmax, then sample from the distribution
        probs = np.exp(next_token_logits - np.max(next_token_logits))
        probs = probs / np.sum(probs)
        next_token = int(np.random.choice(len(probs), p=probs))

        if next_token == EOS_IDX:
            break
        current_seq.append(next_token)

    # Decode (skipping <SOS> and padding) and tidy up the text
    words = [idx2word[idx] for idx in current_seq[1:] if idx != PAD_IDX]
    story = " ".join(words)
    story = re.sub(r'\s+([.,!?])', r'\1', story)  # No space before punctuation
    story = story.replace(" â€ ", '"').replace("â€", '"')  # Strip UTF-8 mojibake left over from training data
    return story
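
# For reproducible samples while debugging (generation uses np.random.choice),
# seed NumPy before calling, e.g.:
#   np.random.seed(0); generate_story("a dragon found a hat", 40, 0.8)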
# ----------------------------
# GRADIO INTERFACE
# ----------------------------
with gr.Blocks(title="TinyStories LSTM") as demo:
    gr.Markdown("# 📖 TinyStories Word-Level LSTM")
    gr.Markdown("A **10.9 MB** LSTM that generates children's stories in seconds — **runs on CPU!**")

    with gr.Row():
        with gr.Column():
            prompt = gr.Textbox(
                label="Story Starter",
                value="once upon a time",
                placeholder="e.g., 'there was a brave little mouse...'"
            )
            with gr.Row():
                max_len = gr.Slider(10, 200, value=80, step=1, label="Max New Tokens")
                temp = gr.Slider(0.5, 1.5, value=0.8, label="Temperature")
            btn = gr.Button("Generate Story 🪄", variant="primary")
        with gr.Column():
            output = gr.Textbox(label="Generated Story", lines=15)

    btn.click(
        fn=generate_story,
        inputs=[prompt, max_len, temp],
        outputs=output
    )

    gr.Examples(
        examples=[
            ["once upon a time"],
            ["there was a robot who loved flowers"],
            ["in a faraway forest, a squirrel found a key"]
        ],
        inputs=prompt
    )

    gr.Markdown("""
    ---
    Model: [TinyStories-LSTM-5.5M (ONNX)](https://huggingface.co/phmd/TinyStories-LSTM-5.5M) •
    Size: 10.9 MB •
    Runs on CPU in seconds •
    Trained on 500k TinyStories
    """)
# Launch (for Colab or local)
if __name__ == "__main__":
    demo.launch()