phmd committed on
Commit
0747d86
·
verified ·
1 Parent(s): dd19777

Create app.py

Browse files
Files changed (1) hide show
  1. app.py +132 -0
app.py ADDED
@@ -0,0 +1,132 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import gradio as gr
2
+ import numpy as np
3
+ import onnxruntime as ort
4
+ import re
5
+ import os
6
+
7
+ # ----------------------------
8
+ # CONFIGURATION (UPDATE THESE)
9
+ # ----------------------------
10
+ MODEL_PATH = "tinystories_lstm.onnx" # Path to your ONNX model
11
+ VOCAB_PATH = "vocab.txt" # Path to vocab file
12
+
13
+ # Download files if not present (for Hugging Face Spaces or Colab)
14
+ if not os.path.exists(MODEL_PATH):
15
+ from huggingface_hub import hf_hub_download
16
+ MODEL_PATH = hf_hub_download(
17
+ repo_id="phmd/TinyStories-LSTM-5.5M", # 👈 REPLACE WITH YOUR HF MODEL ID
18
+ filename="tinystories_lstm.onnx"
19
+ )
20
+ if not os.path.exists(VOCAB_PATH):
21
+ VOCAB_PATH = hf_hub_download(
22
+ repo_id="phmd/TinyStories-LSTM-5.5M",
23
+ filename="vocab.txt"
24
+ )
25
+
26
# ----------------------------
# LOAD MODEL & VOCAB
# ----------------------------
print("Loading vocabulary...")
# Explicit UTF-8: the platform default encoding (e.g. cp1252 on Windows)
# would corrupt or reject non-ASCII vocab entries.
with open(VOCAB_PATH, "r", encoding="utf-8") as f:
    vocab = [line.strip() for line in f]
word2idx = {word: idx for idx, word in enumerate(vocab)}
# Build the inverse map straight from the list: rebuilding it from word2idx
# would drop the earlier index of any duplicated word, and decoding that
# index later would raise KeyError.
idx2word = dict(enumerate(vocab))

# Special-token ids; assumes these four markers exist in vocab.txt
# (a missing one fails fast here with KeyError, which is desirable).
SOS_IDX = word2idx["<SOS>"]
EOS_IDX = word2idx["<EOS>"]
PAD_IDX = word2idx["<PAD>"]
UNK_IDX = word2idx["<UNK>"]

print("Loading ONNX model...")
ort_session = ort.InferenceSession(MODEL_PATH)
42
+
43
+ # ----------------------------
44
+ # HELPER FUNCTIONS
45
+ # ----------------------------
46
def tokenize(text):
    """Lowercase *text* and split it into word and punctuation tokens.

    Each of `. , ! ?` becomes its own token; all other splitting is on
    whitespace.
    """
    lowered = text.lower()
    spaced = re.sub(r'([.,!?])', r' \1 ', lowered)
    return spaced.split()
49
+
50
def generate_story(prompt, max_new_tokens=64, temperature=0.8):
    """Autoregressively sample a story continuation from *prompt*.

    Args:
        prompt: seed text; falls back to "once upon a time" when blank.
        max_new_tokens: maximum tokens to sample. Cast to int because the
            Gradio slider delivers a float, and range(float) raises TypeError.
        temperature: softmax temperature; higher = more random.

    Returns:
        The generated story as a single cleaned-up string (prompt included,
        special tokens stripped).
    """
    if not prompt.strip():
        prompt = "once upon a time"

    # Gradio sliders pass floats — range() needs an int.
    max_new_tokens = int(max_new_tokens)

    tokens = tokenize(prompt)
    current_seq = [SOS_IDX] + [word2idx.get(t, UNK_IDX) for t in tokens]

    for _ in range(max_new_tokens):
        # The model expects a fixed length of 50: keep the latest 50 tokens
        # and right-pad with <PAD>. (Slicing always copies, so current_seq
        # is never mutated here.)
        window = current_seq[-50:]
        padded = window + [PAD_IDX] * (50 - len(window))

        input_tensor = np.array([padded], dtype=np.int64)
        logits = ort_session.run(None, {"input": input_tensor})[0]

        # Logits at the last *real* token position (capped at the window end).
        last_pos = min(len(current_seq) - 1, 49)
        next_token_logits = logits[0, last_pos, :] / temperature

        # Numerically stable softmax, normalized exactly: the previous
        # "/ (sum + 1e-8)" left p summing to slightly under 1, skirting
        # np.random.choice's tolerance check.
        probs = np.exp(next_token_logits - np.max(next_token_logits))
        probs = probs / np.sum(probs)
        next_token = int(np.random.choice(len(probs), p=probs))

        if next_token == EOS_IDX:
            break
        current_seq.append(next_token)

    # Decode, skipping the leading <SOS> and any padding ids.
    words = [idx2word[idx] for idx in current_seq[1:] if idx != PAD_IDX]
    story = " ".join(words)
    story = re.sub(r'\s+([.,!?])', r'\1', story)  # Fix spacing before punctuation
    story = story.replace(" †", '"').replace("â€", '"')  # Fix encoding artifacts
    return story
84
+
85
# ----------------------------
# GRADIO INTERFACE
# ----------------------------
# Declarative UI: a prompt + sampling controls on the left, the generated
# story on the right. `demo` is launched from the __main__ guard below.
with gr.Blocks(title="TinyStories LSTM") as demo:
    gr.Markdown("# 📖 TinyStories Word-Level LSTM")
    gr.Markdown("A **10.9 MB** LSTM that generates children's stories in seconds — **runs on CPU!**")

    with gr.Row():
        # Left column: story seed and sampling parameters.
        with gr.Column():
            prompt = gr.Textbox(
                label="Story Starter",
                value="once upon a time",
                placeholder="e.g., 'there was a brave little mouse...'"
            )
            with gr.Row():
                # NOTE(review): sliders emit floats; generate_story receives
                # max_len as a float — confirm it is cast before use.
                max_len = gr.Slider(10, 200, value=80, label="Max New Tokens")
                temp = gr.Slider(0.5, 1.5, value=0.8, label="Temperature")
            btn = gr.Button("Generate Story 🪄", variant="primary")

        # Right column: generated output.
        with gr.Column():
            output = gr.Textbox(label="Generated Story", lines=15)

    # Wire the button to the generator: (prompt, max_len, temp) -> story text.
    btn.click(
        fn=generate_story,
        inputs=[prompt, max_len, temp],
        outputs=output
    )

    # One-click example prompts that fill the prompt textbox.
    gr.Examples(
        examples=[
            ["once upon a time"],
            ["there was a robot who loved flowers"],
            ["in a faraway forest, a squirrel found a key"]
        ],
        inputs=prompt
    )

    gr.Markdown("""
    ---
    Model: [TinyStories LSTM (ONNX)](https://huggingface.co/your-username/tinystories-lstm-onnx) •
    Size: 10.9 MB •
    Runs on CPU in seconds •
    Trained on 500k TinyStories
    """)
129
+
130
# Launch (for Colab or local); guarded so importing this module as a
# library does not start the web server.
if __name__ == "__main__":
    demo.launch()