import gradio as gr
import numpy as np
import onnxruntime as ort
import re
import os

# ----------------------------
# CONFIGURATION (UPDATE THESE)
# ----------------------------
MODEL_PATH = "tinystories_lstm.onnx"      # Path to your ONNX model
VOCAB_PATH = "vocab.txt"                  # Path to vocab file

# Download files if not present (for Hugging Face Spaces or Colab)
if not (os.path.exists(MODEL_PATH) and os.path.exists(VOCAB_PATH)):
    # Import once, outside the branches, so the helper is available
    # whichever of the two files is missing
    from huggingface_hub import hf_hub_download
if not os.path.exists(MODEL_PATH):
    MODEL_PATH = hf_hub_download(
        repo_id="phmd/TinyStories-LSTM-5.5M",  # 👈 REPLACE WITH YOUR HF MODEL ID
        filename="tinystories_lstm.onnx"
    )
if not os.path.exists(VOCAB_PATH):
    VOCAB_PATH = hf_hub_download(
        repo_id="phmd/TinyStories-LSTM-5.5M",
        filename="vocab.txt"
    )

# ----------------------------
# LOAD MODEL & VOCAB
# ----------------------------
print("Loading vocabulary...")
with open(VOCAB_PATH, "r", encoding="utf-8") as f:
    vocab = [line.strip() for line in f]
word2idx = {word: idx for idx, word in enumerate(vocab)}
idx2word = {idx: word for word, idx in word2idx.items()}

SOS_IDX = word2idx["<SOS>"]
EOS_IDX = word2idx["<EOS>"]
PAD_IDX = word2idx["<PAD>"]
UNK_IDX = word2idx["<UNK>"]

print("Loading ONNX model...")
ort_session = ort.InferenceSession(MODEL_PATH, providers=["CPUExecutionProvider"])
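# Sanity-check sketch: the run() calls below assume the exported graph names its
# input "input"; if generation fails, uncomment this to inspect the actual name.
# print([inp.name for inp in ort_session.get_inputs()])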

# ----------------------------
# HELPER FUNCTIONS
# ----------------------------
def tokenize(text):
    """Lowercase the text and split on whitespace, with . , ! ? as their own tokens.

    Example: tokenize("Hello, world!") -> ["hello", ",", "world", "!"]
    """
    text = re.sub(r'([.,!?])', r' \1 ', text.lower())
    return text.split()

def generate_story(prompt, max_new_tokens=64, temperature=0.8):
    if not prompt.strip():
        prompt = "once upon a time"
    # Gradio sliders pass floats; range() below needs an int
    max_new_tokens = int(max_new_tokens)
    
    tokens = tokenize(prompt)
    current_seq = [SOS_IDX] + [word2idx.get(t, UNK_IDX) for t in tokens]
    
    for _ in range(max_new_tokens):
        # The export assumes a fixed context window of 50 tokens:
        # keep the most recent 50, then right-pad with <PAD>
        padded = current_seq[-50:] if len(current_seq) > 50 else current_seq
        padded = padded + [PAD_IDX] * (50 - len(padded))
        
        input_tensor = np.array([padded], dtype=np.int64)
        logits = ort_session.run(None, {"input": input_tensor})[0]
        
        # Get logits at last real token position
        last_pos = min(len(current_seq) - 1, 49)
        next_token_logits = logits[0, last_pos, :] / temperature
        
        # Softmax (max-subtracted for numerical stability), then sample;
        # exact normalization keeps np.random.choice's sum-to-1 check happy
        probs = np.exp(next_token_logits - np.max(next_token_logits))
        probs = probs / np.sum(probs)
        next_token = int(np.random.choice(len(probs), p=probs))
        
        if next_token == EOS_IDX:
            break
        current_seq.append(next_token)
    
    # Decode and clean
    words = [idx2word[idx] for idx in current_seq[1:] if idx != PAD_IDX]
    story = " ".join(words)
    story = re.sub(r'\s+([.,!?])', r'\1', story)  # Fix spacing
    story = story.replace(" Ò€ ", '"').replace("Ò€", '"')  # Map mojibake quote tokens from the training data to plain quotes
    return story
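# Quick example (a sketch; output varies run to run because sampling is stochastic):
# generate_story("there was a robot who loved flowers", max_new_tokens=30, temperature=0.7)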

# ----------------------------
# GRADIO INTERFACE
# ----------------------------
with gr.Blocks(title="TinyStories LSTM") as demo:
    gr.Markdown("# πŸ“– TinyStories Word-Level LSTM")
    gr.Markdown("A **10.9 MB** LSTM that generates children's stories in seconds β€” **runs on CPU!**")
    
    with gr.Row():
        with gr.Column():
            prompt = gr.Textbox(
                label="Story Starter",
                value="once upon a time",
                placeholder="e.g., 'there was a brave little mouse...'"
            )
            with gr.Row():
                max_len = gr.Slider(10, 200, value=80, step=1, label="Max New Tokens")
                temp = gr.Slider(0.5, 1.5, value=0.8, label="Temperature")
            btn = gr.Button("Generate Story 🪄", variant="primary")
        
        with gr.Column():
            output = gr.Textbox(label="Generated Story", lines=15)
    
    btn.click(
        fn=generate_story,
        inputs=[prompt, max_len, temp],
        outputs=output
    )
    
    gr.Examples(
        examples=[
            ["once upon a time"],
            ["there was a robot who loved flowers"],
            ["in a faraway forest, a squirrel found a key"]
        ],
        inputs=prompt
    )
    
    gr.Markdown("""
    ---
    Model: [TinyStories LSTM (ONNX)](https://huggingface.co/your-username/tinystories-lstm-onnx) •
    Size: 10.9 MB •
    Runs on CPU in seconds •
    Trained on 500k TinyStories
    """)

# Launch (for Colab or local)
if __name__ == "__main__":
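    # Optional smoke-test sketch (assumes the model and vocab loaded above);
    # uncomment to print one short sample on the command line before serving:
    # print(generate_story("once upon a time", max_new_tokens=20))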
    demo.launch()