Spaces:
Sleeping
Sleeping
File size: 4,357 Bytes
0747d86 |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 102 103 104 105 106 107 108 109 110 111 112 113 114 115 116 117 118 119 120 121 122 123 124 125 126 127 128 129 130 131 132 |
import gradio as gr
import numpy as np
import onnxruntime as ort
import re
import os
# ----------------------------
# CONFIGURATION (UPDATE THESE)
# ----------------------------
HF_REPO_ID = "phmd/TinyStories-LSTM-5.5M"  # Replace with your HF model ID
MODEL_PATH = "tinystories_lstm.onnx"  # Path to your ONNX model
VOCAB_PATH = "vocab.txt"  # Path to vocab file

# Download files if not present (for Hugging Face Spaces or Colab).
# The import sits outside the per-file checks so that `hf_hub_download` is
# bound no matter which of the two files is missing; it stays lazy so a fully
# local setup does not need `huggingface_hub` installed.
if not (os.path.exists(MODEL_PATH) and os.path.exists(VOCAB_PATH)):
    from huggingface_hub import hf_hub_download

    if not os.path.exists(MODEL_PATH):
        MODEL_PATH = hf_hub_download(
            repo_id=HF_REPO_ID,
            filename="tinystories_lstm.onnx",
        )
    if not os.path.exists(VOCAB_PATH):
        VOCAB_PATH = hf_hub_download(
            repo_id=HF_REPO_ID,
            filename="vocab.txt",
        )
# ----------------------------
# LOAD MODEL & VOCAB
# ----------------------------
print("Loading vocabulary...")
# One token per line; the line number is the token id. Decode explicitly as
# UTF-8 so the result does not depend on the platform's default encoding.
with open(VOCAB_PATH, "r", encoding="utf-8") as f:
    vocab = [line.strip() for line in f]

word2idx = {word: idx for idx, word in enumerate(vocab)}
# Built directly from file order (not by inverting word2idx) so that every
# index has an entry even if two lines strip to the same token.
idx2word = dict(enumerate(vocab))

# Special tokens; a KeyError here means the vocab file lacks a required entry.
SOS_IDX = word2idx["<SOS>"]
EOS_IDX = word2idx["<EOS>"]
PAD_IDX = word2idx["<PAD>"]
UNK_IDX = word2idx["<UNK>"]

print("Loading ONNX model...")
ort_session = ort.InferenceSession(MODEL_PATH)
# ----------------------------
# HELPER FUNCTIONS
# ----------------------------
def tokenize(text):
    """Lower-case *text* and split it into word and punctuation tokens.

    Each of ``.``, ``,``, ``!`` and ``?`` is surrounded with spaces before the
    whitespace split, so it comes out as its own token.

    Args:
        text: Input string.

    Returns:
        List of token strings; empty if *text* has no non-whitespace content.
    """
    spaced = re.sub(r'([.,!?])', r' \1 ', text.lower())
    return spaced.split()
def generate_story(prompt, max_new_tokens=64, temperature=0.8):
    """Continue *prompt* by sampling tokens from the ONNX model one at a time.

    Args:
        prompt: Story starter. ``None``, empty or whitespace-only falls back
            to ``"once upon a time"``.
        max_new_tokens: Upper bound on generated tokens. Coerced to ``int``
            so a float coming from a UI slider is accepted.
        temperature: Softmax temperature. Values ``<= 0`` select the most
            likely token instead of sampling (avoids dividing by zero).

    Returns:
        The prompt tokens plus generated tokens joined into one string, with
        spaces before ``. , ! ?`` removed. Generation stops early when the
        ``<EOS>`` token is drawn.
    """
    context_len = 50  # window size the model input is padded/truncated to
    max_new_tokens = int(max_new_tokens)

    if not prompt or not prompt.strip():
        prompt = "once upon a time"

    current_seq = [SOS_IDX] + [word2idx.get(t, UNK_IDX) for t in tokenize(prompt)]

    for _ in range(max_new_tokens):
        # Keep only the most recent tokens, then right-pad to the fixed width.
        window = current_seq[-context_len:]
        padded = window + [PAD_IDX] * (context_len - len(window))
        input_tensor = np.array([padded], dtype=np.int64)
        logits = ort_session.run(None, {"input": input_tensor})[0]

        # Logits at the last real (non-pad) position; indexed as
        # [batch, position, vocab]. float64 keeps the probabilities summing
        # to 1 within np.random.choice's tolerance.
        last_pos = len(window) - 1
        next_token_logits = logits[0, last_pos, :].astype(np.float64)

        if temperature <= 0:
            next_token = int(np.argmax(next_token_logits))
        else:
            scaled = next_token_logits / temperature
            # Subtracting the max keeps exp() from overflowing; the largest
            # term is then exp(0) == 1, so the sum is never zero.
            probs = np.exp(scaled - np.max(scaled))
            probs /= probs.sum()
            next_token = int(np.random.choice(len(probs), p=probs))

        if next_token == EOS_IDX:
            break
        current_seq.append(next_token)

    # Decode and clean (skip the leading <SOS> and any sampled <PAD>).
    words = [idx2word[idx] for idx in current_seq[1:] if idx != PAD_IDX]
    story = " ".join(words)
    story = re.sub(r'\s+([.,!?])', r'\1', story)  # Fix spacing
    # Fix encoding artifacts: "\u00e2\u20ac" is the two-character sequence
    # "a-circumflex, euro sign" left behind when a UTF-8 curly quote is
    # decoded as cp1252. Written with escapes so the literal survives
    # re-encoding of this file.
    mojibake_quote = "\u00e2\u20ac"
    story = story.replace(f" {mojibake_quote} ", '"').replace(mojibake_quote, '"')
    return story
# ----------------------------
# GRADIO INTERFACE
# ----------------------------
with gr.Blocks(title="TinyStories LSTM") as demo:
    gr.Markdown("# TinyStories Word-Level LSTM")
    gr.Markdown(
        "A **10.9 MB** LSTM that generates children's stories in seconds "
        "— **runs on CPU!**"
    )

    with gr.Row():
        # Left column: inputs and the generate button.
        with gr.Column():
            prompt = gr.Textbox(
                label="Story Starter",
                value="once upon a time",
                placeholder="e.g., 'there was a brave little mouse...'",
            )
            with gr.Row():
                # step=1 keeps the token budget a whole number.
                max_len = gr.Slider(10, 200, value=80, step=1, label="Max New Tokens")
                temp = gr.Slider(0.5, 1.5, value=0.8, label="Temperature")
            btn = gr.Button("Generate Story 🪄", variant="primary")
        # Right column: generated text.
        with gr.Column():
            output = gr.Textbox(label="Generated Story", lines=15)

    btn.click(
        fn=generate_story,
        inputs=[prompt, max_len, temp],
        outputs=output,
    )

    gr.Examples(
        examples=[
            ["once upon a time"],
            ["there was a robot who loved flowers"],
            ["in a faraway forest, a squirrel found a key"],
        ],
        inputs=prompt,
    )

    # Footer; the link targets the same repo the model files are fetched from.
    gr.Markdown(
        "\n---\n"
        "Model: [TinyStories LSTM (ONNX)]"
        "(https://huggingface.co/phmd/TinyStories-LSTM-5.5M) •\n"
        "Size: 10.9 MB •\n"
        "Runs on CPU in seconds •\n"
        "Trained on 500k TinyStories\n"
    )
# Launch (for Colab or local); skipped when the module is merely imported.
if __name__ == "__main__":
    demo.launch()