Spaces:
Sleeping
Sleeping
| import gradio as gr | |
| import numpy as np | |
| import onnxruntime as ort | |
| import re | |
| import os | |
# ----------------------------
# CONFIGURATION (UPDATE THESE)
# ----------------------------
HF_REPO_ID = "phmd/TinyStories-LSTM-5.5M"  # Replace with your Hugging Face model ID
MODEL_PATH = "tinystories_lstm.onnx"  # Path to your ONNX model
VOCAB_PATH = "vocab.txt"  # Path to vocab file


def _resolve_asset(local_path, filename):
    """Return a usable path for an asset, downloading it when it is missing.

    Args:
        local_path: Path checked first with ``os.path.exists``.
        filename: Name of the file inside ``HF_REPO_ID`` to fetch as a fallback.

    Returns:
        ``local_path`` when it exists, otherwise the path returned by
        ``hf_hub_download``.
    """
    if os.path.exists(local_path):
        return local_path
    # Imported lazily so huggingface_hub is only needed when a download
    # actually happens, and so it is in scope for *every* missing asset
    # (not just the first one that happens to be checked).
    from huggingface_hub import hf_hub_download

    return hf_hub_download(repo_id=HF_REPO_ID, filename=filename)


# Download files if not present (for Hugging Face Spaces or Colab)
MODEL_PATH = _resolve_asset(MODEL_PATH, "tinystories_lstm.onnx")
VOCAB_PATH = _resolve_asset(VOCAB_PATH, "vocab.txt")
# ----------------------------
# LOAD MODEL & VOCAB
# ----------------------------
print("Loading vocabulary...")
# Explicit encoding so the token strings do not depend on the platform's
# default locale. NOTE(review): assumes vocab.txt is UTF-8 — confirm.
with open(VOCAB_PATH, "r", encoding="utf-8") as f:
    # One token per line; the line number (0-based) is the token id.
    vocab = [line.strip() for line in f]

word2idx = {word: idx for idx, word in enumerate(vocab)}
# Built straight from the file order rather than by inverting word2idx:
# if two lines are equal after strip(), inverting would drop the earlier id
# and decoding a sampled token with that id would raise KeyError.
idx2word = dict(enumerate(vocab))

# Ids of the special tokens; a KeyError here means the vocab file lacks them.
SOS_IDX = word2idx["<SOS>"]
EOS_IDX = word2idx["<EOS>"]
PAD_IDX = word2idx["<PAD>"]
UNK_IDX = word2idx["<UNK>"]

print("Loading ONNX model...")
ort_session = ort.InferenceSession(MODEL_PATH)
| # ---------------------------- | |
| # HELPER FUNCTIONS | |
| # ---------------------------- | |
def tokenize(text):
    """Lower-case *text* and split it into word and punctuation tokens.

    Each of ``. , ! ?`` is surrounded with spaces so that it becomes a token
    of its own; the padded string is then split on whitespace.
    """
    lowered = text.lower()
    spaced = re.sub(r'([.,!?])', r' \1 ', lowered)
    tokens = spaced.split()
    return tokens
def generate_story(prompt, max_new_tokens=64, temperature=0.8):
    """Continue *prompt* by sampling tokens one at a time from the ONNX model.

    Args:
        prompt: Story opening. ``None`` or a blank string falls back to
            "once upon a time".
        max_new_tokens: Upper bound on the number of sampled tokens. Coerced
            with ``int()`` so a float coming from a UI slider is accepted.
        temperature: Divisor applied to the logits before the softmax.
            Values ``<= 0`` switch to greedy (argmax) decoding instead of
            dividing by zero.

    Returns:
        The prompt tokens followed by the generated tokens, joined with
        spaces, with the space before ``. , ! ?`` removed. Generation stops
        early when the end-of-sequence id is sampled.
    """
    context_len = 50  # every model call receives a window of exactly this many ids

    if prompt is None or not prompt.strip():
        prompt = "once upon a time"

    current_seq = [SOS_IDX] + [word2idx.get(t, UNK_IDX) for t in tokenize(prompt)]
    greedy = temperature <= 0

    for _ in range(int(max_new_tokens)):
        # Keep the most recent ids, then right-pad up to the fixed window size.
        window = current_seq[-context_len:]
        last_pos = len(window) - 1  # index of the last real (non-pad) token
        window = window + [PAD_IDX] * (context_len - len(window))

        input_tensor = np.array([window], dtype=np.int64)
        logits = ort_session.run(None, {"input": input_tensor})[0]

        # The first output is indexed as [batch, position, vocab]; take the
        # scores at the last real token position.
        next_token_logits = logits[0, last_pos, :].astype(np.float64)

        if greedy:
            next_token = int(np.argmax(next_token_logits))
        else:
            scaled = next_token_logits / temperature
            # Subtract the max before exp() for numerical stability; the sum
            # is then >= 1, so no epsilon is needed in the normaliser.
            probs = np.exp(scaled - np.max(scaled))
            probs /= probs.sum()
            next_token = int(np.random.choice(len(probs), p=probs))

        if next_token == EOS_IDX:
            break
        current_seq.append(next_token)

    # Decode and clean: drop the leading start token and any pad ids.
    words = [idx2word[idx] for idx in current_seq[1:] if idx != PAD_IDX]
    story = " ".join(words)
    story = re.sub(r'\s+([.,!?])', r'\1', story)  # Fix spacing
    # NOTE(review): the two search strings below look like text that went
    # through a wrong-encoding round trip; they are kept exactly as found.
    # Confirm them against the tokens actually present in the vocab file.
    story = story.replace(" Γ’β¬ ", '"').replace("Γ’β¬", '"')  # Fix encoding artifacts
    return story
# ----------------------------
# GRADIO INTERFACE
# ----------------------------
# NOTE(review): the nesting of the layout blocks below is reconstructed from
# statement order — confirm it matches the intended two-column layout.
# NOTE(review): several UI strings contain glyphs that look like mis-decoded
# emoji / dashes / bullets; they are kept exactly as found — confirm.
with gr.Blocks(title="TinyStories LSTM") as demo:
    gr.Markdown("# π TinyStories Word-Level LSTM")
    gr.Markdown("A **10.9 MB** LSTM that generates children's stories in seconds β **runs on CPU!**")

    with gr.Row():
        # Left column: prompt, sampling controls and the trigger button.
        with gr.Column():
            prompt = gr.Textbox(
                label="Story Starter",
                value="once upon a time",
                placeholder="e.g., 'there was a brave little mouse...'",
            )
            with gr.Row():
                # step=1 keeps the value integral; generate_story uses it as
                # a loop count.
                max_len = gr.Slider(10, 200, value=80, step=1, label="Max New Tokens")
                temp = gr.Slider(0.5, 1.5, value=0.8, label="Temperature")
            btn = gr.Button("Generate Story πͺ", variant="primary")
        # Right column: generated text.
        with gr.Column():
            output = gr.Textbox(label="Generated Story", lines=15)

    # Inputs are passed positionally: (prompt, max_new_tokens, temperature).
    btn.click(
        fn=generate_story,
        inputs=[prompt, max_len, temp],
        outputs=output,
    )

    # Clickable sample prompts that fill the "Story Starter" box.
    gr.Examples(
        examples=[
            ["once upon a time"],
            ["there was a robot who loved flowers"],
            ["in a faraway forest, a squirrel found a key"],
        ],
        inputs=prompt,
    )

    # Footer. The model link points at the repository the app downloads its
    # weights from instead of a placeholder account.
    gr.Markdown("""
    ---
    Model: [TinyStories LSTM (ONNX)](https://huggingface.co/phmd/TinyStories-LSTM-5.5M) β’
    Size: 10.9 MB β’
    Runs on CPU in seconds β’
    Trained on 500k TinyStories
    """)

# Launch (for Colab or local)
if __name__ == "__main__":
    demo.launch()