import gradio as gr
import numpy as np
import onnxruntime as ort
import re
import os

# ----------------------------
# CONFIGURATION (UPDATE THESE)
# ----------------------------
MODEL_PATH = "tinystories_lstm.onnx"  # Path to your ONNX model
VOCAB_PATH = "vocab.txt"              # Path to vocab file

# Download files if not present (for Hugging Face Spaces or Colab)
if not os.path.exists(MODEL_PATH):
    from huggingface_hub import hf_hub_download
    MODEL_PATH = hf_hub_download(
        repo_id="phmd/TinyStories-LSTM-5.5M",  # 👈 REPLACE WITH YOUR HF MODEL ID
        filename="tinystories_lstm.onnx"
    )
if not os.path.exists(VOCAB_PATH):
    from huggingface_hub import hf_hub_download
    VOCAB_PATH = hf_hub_download(
        repo_id="phmd/TinyStories-LSTM-5.5M",
        filename="vocab.txt"
    )

# ----------------------------
# LOAD MODEL & VOCAB
# ----------------------------
print("Loading vocabulary...")
with open(VOCAB_PATH, "r") as f:
    vocab = [line.strip() for line in f]

word2idx = {word: idx for idx, word in enumerate(vocab)}
idx2word = {idx: word for word, idx in word2idx.items()}

# Special-token strings assumed to be "<sos>", "<eos>", "<pad>", "<unk>";
# adjust these to match whatever markers your vocab file actually uses.
SOS_IDX = word2idx["<sos>"]
EOS_IDX = word2idx["<eos>"]
PAD_IDX = word2idx["<pad>"]
UNK_IDX = word2idx["<unk>"]

print("Loading ONNX model...")
ort_session = ort.InferenceSession(MODEL_PATH)

# ----------------------------
# HELPER FUNCTIONS
# ----------------------------
def tokenize(text):
    # Split punctuation off words, then whitespace-tokenize
    text = re.sub(r'([.,!?])', r' \1 ', text.lower())
    return text.split()

def generate_story(prompt, max_new_tokens=64, temperature=0.8):
    if not prompt.strip():
        prompt = "once upon a time"

    tokens = tokenize(prompt)
    current_seq = [SOS_IDX] + [word2idx.get(t, UNK_IDX) for t in tokens]

    for _ in range(int(max_new_tokens)):  # cast: Gradio sliders can pass floats
        # Keep the last 50 tokens, then pad to the model's fixed length of 50
        padded = current_seq[-50:] if len(current_seq) > 50 else current_seq
        padded = padded + [PAD_IDX] * (50 - len(padded))
        input_tensor = np.array([padded], dtype=np.int64)

        logits = ort_session.run(None, {"input": input_tensor})[0]

        # Get logits at the last real token position
        last_pos = min(len(current_seq) - 1, 49)
        next_token_logits = logits[0, last_pos, :] / temperature

        # Softmax sampling (numerically stabilized by subtracting the max)
        probs = np.exp(next_token_logits - np.max(next_token_logits))
        probs = probs / np.sum(probs)
        next_token = int(np.random.choice(len(probs), p=probs))

        if next_token == EOS_IDX:
            break
        current_seq.append(next_token)

    # Decode and clean
    words = [idx2word[idx] for idx in current_seq[1:] if idx != PAD_IDX]
    story = " ".join(words)
    story = re.sub(r'\s+([.,!?])', r'\1', story)          # Fix spacing before punctuation
    story = story.replace("â€œ", '"').replace("â€", '"')  # Fix mojibake quote characters
    return story
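# ----------------------------
# OPTIONAL: top-k sampling (a sketch, not part of the original app).
# Plain temperature sampling over the full vocabulary occasionally draws
# very unlikely words; restricting the draw to the k most probable tokens
# is a common mitigation. The helper name `sample_top_k` and the default
# k=40 are assumptions. To try it, replace the softmax-sampling lines in
# generate_story with: next_token = sample_top_k(next_token_logits)
# ----------------------------
def sample_top_k(scaled_logits, k=40):
    """Sample one token id from the k highest (already temperature-scaled) logits."""
    top_idx = np.argpartition(scaled_logits, -k)[-k:]  # indices of the k largest logits
    top_logits = scaled_logits[top_idx]
    probs = np.exp(top_logits - np.max(top_logits))    # numerically stable softmax
    probs = probs / np.sum(probs)
    return int(np.random.choice(top_idx, p=probs))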
# ----------------------------
# GRADIO INTERFACE
# ----------------------------
with gr.Blocks(title="TinyStories LSTM") as demo:
    gr.Markdown("# 📖 TinyStories Word-Level LSTM")
    gr.Markdown("A **10.9 MB** LSTM that generates children's stories in seconds — **runs on CPU!**")

    with gr.Row():
        with gr.Column():
            prompt = gr.Textbox(
                label="Story Starter",
                value="once upon a time",
                placeholder="e.g., 'there was a brave little mouse...'"
            )
            with gr.Row():
                max_len = gr.Slider(10, 200, value=80, step=1, label="Max New Tokens")
                temp = gr.Slider(0.5, 1.5, value=0.8, label="Temperature")
            btn = gr.Button("Generate Story 🪄", variant="primary")
        with gr.Column():
            output = gr.Textbox(label="Generated Story", lines=15)

    btn.click(
        fn=generate_story,
        inputs=[prompt, max_len, temp],
        outputs=output
    )

    gr.Examples(
        examples=[
            ["once upon a time"],
            ["there was a robot who loved flowers"],
            ["in a faraway forest, a squirrel found a key"]
        ],
        inputs=prompt
    )

    gr.Markdown("""
    ---
    Model: [TinyStories LSTM (ONNX)](https://huggingface.co/your-username/tinystories-lstm-onnx) •
    Size: 10.9 MB • Runs on CPU in seconds • Trained on 500k TinyStories
    """)

# Launch (for Colab or local)
if __name__ == "__main__":
    demo.launch()
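# Hypothetical usage (assumes this file is saved as app.py; the name is not
# given in the original):
#   python app.py   # serves the UI at Gradio's default http://127.0.0.1:7860
#   python -c "from app import generate_story; print(generate_story('once upon a time'))"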