ibibrahim Xenova HF Staff committed on
Commit
2379782
·
verified ·
1 Parent(s): e35a862

Add WebGPU demo files (#1)

Browse files

- [WIP] Add demo files (b9224a933ed61040af9c1054defa7629fe85ccd4)
- Update README.md (716521d5ce913ca8cae48764b32dc8326b9f5263)
- Update README.md (0aa1108f8968cb5b493f60c6af346b5ebec240c6)
- Update package.json (57dea5890a937808971bd015838545869b307ceb)


Co-authored-by: Joshua <[email protected]>

.gitattributes CHANGED
@@ -33,3 +33,4 @@ saved_model/**/* filter=lfs diff=lfs merge=lfs -text
33
  *.zip filter=lfs diff=lfs merge=lfs -text
34
  *.zst filter=lfs diff=lfs merge=lfs -text
35
  *tfevents* filter=lfs diff=lfs merge=lfs -text
 
 
33
  *.zip filter=lfs diff=lfs merge=lfs -text
34
  *.zst filter=lfs diff=lfs merge=lfs -text
35
  *tfevents* filter=lfs diff=lfs merge=lfs -text
36
+ public/banner.png filter=lfs diff=lfs merge=lfs -text
README.md CHANGED
@@ -1,13 +1,19 @@
1
  ---
2
- title: Granite 4.0 WebGPU
3
- emoji: 🐠
4
  colorFrom: indigo
5
  colorTo: green
6
- sdk: gradio
7
- sdk_version: 5.48.0
8
- app_file: app.py
9
  pinned: false
10
  license: apache-2.0
 
 
 
 
 
 
 
 
11
  ---
12
 
13
- Check out the configuration reference at https://huggingface.co/docs/hub/spaces-config-reference
 
1
  ---
2
+ title: Granite-4.0 WebGPU
3
+ emoji: 🪨
4
  colorFrom: indigo
5
  colorTo: green
6
+ sdk: static
 
 
7
  pinned: false
8
  license: apache-2.0
9
+ app_build_command: npm run build
10
+ app_file: dist/index.html
11
+ thumbnail: >-
12
+ https://huggingface.co/spaces/ibm-granite/Granite-4.0-WebGPU/resolve/main/public/banner.png
13
+ short_description: Run Granite-4.0-Micro locally in your browser on WebGPU
14
+ models:
15
+ - onnx-community/granite-4.0-micro-ONNX-web
16
+ - onnx-community/granite-4.0-micro-ONNX
17
  ---
18
 
19
+ Check out the configuration reference at https://huggingface.co/docs/hub/spaces-config-reference
eslint.config.js ADDED
@@ -0,0 +1,38 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import js from "@eslint/js";
2
+ import globals from "globals";
3
+ import react from "eslint-plugin-react";
4
+ import reactHooks from "eslint-plugin-react-hooks";
5
+ import reactRefresh from "eslint-plugin-react-refresh";
6
+
7
// Parser and environment settings shared by every linted source file.
const languageOptions = {
  ecmaVersion: 2020,
  globals: globals.browser,
  parserOptions: {
    ecmaVersion: "latest",
    ecmaFeatures: { jsx: true },
    sourceType: "module",
  },
};

// Plugins registered under the prefixes used by the rule names below.
const plugins = {
  react,
  "react-hooks": reactHooks,
  "react-refresh": reactRefresh,
};

// Recommended presets first, then the project-specific overrides.
const rules = {
  ...js.configs.recommended.rules,
  ...react.configs.recommended.rules,
  ...react.configs["jsx-runtime"].rules,
  ...reactHooks.configs.recommended.rules,
  "react/jsx-no-target-blank": "off",
  "react-refresh/only-export-components": [
    "warn",
    { allowConstantExport: true },
  ],
};

export default [
  // Never lint the build output.
  { ignores: ["dist"] },
  {
    files: ["**/*.{js,jsx}"],
    languageOptions,
    settings: { react: { version: "18.3" } },
    plugins,
    rules,
  },
];
index.html ADDED
@@ -0,0 +1,32 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ <!doctype html>
2
+ <html lang="en">
3
+ <head>
4
+ <meta charset="UTF-8" />
5
+ <link rel="icon" type="image/png" href="/logo.png" />
6
+ <meta name="viewport" content="width=device-width, initial-scale=1.0" />
7
+ <title>Granite-4.0 WebGPU</title>
8
+ </head>
9
+
10
+ <body>
11
+ <div id="root"></div>
12
+
13
+ <script>
14
+ window.MathJax = {
15
+ tex: {
16
+ inlineMath: [
17
+ ["$", "$"],
18
+ ["\\(", "\\)"],
19
+ ],
20
+ },
21
+ svg: {
22
+ fontCache: "global",
23
+ },
24
+ };
25
+ </script>
26
+ <script
27
+ id="MathJax-script"
28
+ src="https://cdn.jsdelivr.net/npm/mathjax@3/es5/tex-mml-chtml.js"
29
+ ></script>
30
+ <script type="module" src="/src/main.jsx"></script>
31
+ </body>
32
+ </html>
package.json ADDED
@@ -0,0 +1,34 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ {
2
+ "name": "granite-4.0-webgpu",
3
+ "private": true,
4
+ "version": "0.0.0",
5
+ "type": "module",
6
+ "scripts": {
7
+ "dev": "vite",
8
+ "build": "vite build",
9
+ "lint": "eslint .",
10
+ "preview": "vite preview"
11
+ },
12
+ "dependencies": {
13
+ "@huggingface/transformers": "3.7.5",
14
+ "dompurify": "^3.1.2",
15
+ "marked": "^12.0.2",
16
+ "react": "^18.3.1",
17
+ "react-dom": "^18.3.1"
18
+ },
19
+ "devDependencies": {
20
+ "@eslint/js": "^9.9.0",
21
+ "@types/react": "^18.3.3",
22
+ "@types/react-dom": "^18.3.0",
23
+ "@vitejs/plugin-react": "^4.3.1",
24
+ "autoprefixer": "^10.4.20",
25
+ "eslint": "^9.9.0",
26
+ "eslint-plugin-react": "^7.35.0",
27
+ "eslint-plugin-react-hooks": "^5.1.0-rc.0",
28
+ "eslint-plugin-react-refresh": "^0.4.9",
29
+ "globals": "^15.9.0",
30
+ "postcss": "^8.4.41",
31
+ "tailwindcss": "^3.4.10",
32
+ "vite": "^6.2.0"
33
+ }
34
+ }
postcss.config.js ADDED
@@ -0,0 +1,6 @@
 
 
 
 
 
 
 
1
// PostCSS pipeline: Tailwind first, then vendor prefixing, both with defaults.
const plugins = {
  tailwindcss: {},
  autoprefixer: {},
};

export default { plugins };
public/banner.png ADDED

Git LFS Details

  • SHA256: cca1c6874368c593763d536822910d6e398287c42a89634a74646edd4ce98e19
  • Pointer size: 131 Bytes
  • Size of remote file: 162 kB
public/logo.png ADDED
src/App.jsx ADDED
@@ -0,0 +1,384 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import { useEffect, useState, useRef } from "react";
2
+
3
+ import Chat from "./components/Chat";
4
+ import ArrowRightIcon from "./components/icons/ArrowRightIcon";
5
+ import StopIcon from "./components/icons/StopIcon";
6
+ import Progress from "./components/Progress";
7
+
8
+ const IS_WEBGPU_AVAILABLE = !!navigator.gpu;
9
+ const STICKY_SCROLL_THRESHOLD = 120;
10
+ const EXAMPLES = [
11
+ "Give me some tips to improve my time management skills.",
12
+ "What is the difference between AI and ML?",
13
+ "Write python code to compute the nth fibonacci number.",
14
+ ];
15
+
16
+ function App() {
17
+ // Create a reference to the worker object.
18
+ const worker = useRef(null);
19
+
20
+ const textareaRef = useRef(null);
21
+ const chatContainerRef = useRef(null);
22
+
23
+ // Model loading and progress
24
+ const [status, setStatus] = useState(null);
25
+ const [error, setError] = useState(null);
26
+ const [loadingMessage, setLoadingMessage] = useState("");
27
+ const [progressItems, setProgressItems] = useState([]);
28
+ const [isRunning, setIsRunning] = useState(false);
29
+
30
+ // Inputs and outputs
31
+ const [input, setInput] = useState("");
32
+ const [messages, setMessages] = useState([]);
33
+ const [tps, setTps] = useState(null);
34
+ const [numTokens, setNumTokens] = useState(null);
35
+
36
+ function onEnter(message) {
37
+ setMessages((prev) => [...prev, { role: "user", content: message }]);
38
+ setTps(null);
39
+ setIsRunning(true);
40
+ setInput("");
41
+ }
42
+
43
+ function onInterrupt() {
44
+ // NOTE: We do not set isRunning to false here because the worker
45
+ // will send a 'complete' message when it is done.
46
+ worker.current.postMessage({ type: "interrupt" });
47
+ }
48
+
49
+ useEffect(() => {
50
+ resizeInput();
51
+ }, [input]);
52
+
53
+ function resizeInput() {
54
+ if (!textareaRef.current) return;
55
+
56
+ const target = textareaRef.current;
57
+ target.style.height = "auto";
58
+ const newHeight = Math.min(Math.max(target.scrollHeight, 24), 200);
59
+ target.style.height = `${newHeight}px`;
60
+ }
61
+
62
+ // We use the `useEffect` hook to setup the worker as soon as the `App` component is mounted.
63
+ useEffect(() => {
64
+ // Create the worker if it does not yet exist.
65
+ if (!worker.current) {
66
+ worker.current = new Worker(new URL("./worker.js", import.meta.url), {
67
+ type: "module",
68
+ });
69
+ worker.current.postMessage({ type: "check" }); // Do a feature check
70
+ }
71
+
72
+ // Create a callback function for messages from the worker thread.
73
+ const onMessageReceived = (e) => {
74
+ switch (e.data.status) {
75
+ case "loading":
76
+ // Model file start load: add a new progress item to the list.
77
+ setStatus("loading");
78
+ setLoadingMessage(e.data.data);
79
+ break;
80
+
81
+ case "initiate":
82
+ setProgressItems((prev) => [...prev, e.data]);
83
+ break;
84
+
85
+ case "progress":
86
+ // Model file progress: update one of the progress items.
87
+ setProgressItems((prev) =>
88
+ prev.map((item) => {
89
+ if (item.file === e.data.file) {
90
+ return { ...item, ...e.data };
91
+ }
92
+ return item;
93
+ }),
94
+ );
95
+ break;
96
+
97
+ case "done":
98
+ // Model file loaded: remove the progress item from the list.
99
+ setProgressItems((prev) =>
100
+ prev.filter((item) => item.file !== e.data.file),
101
+ );
102
+ break;
103
+
104
+ case "ready":
105
+ // Pipeline ready: the worker is ready to accept messages.
106
+ setStatus("ready");
107
+ break;
108
+
109
+ case "start":
110
+ {
111
+ // Start generation
112
+ setMessages((prev) => [
113
+ ...prev,
114
+ { role: "assistant", content: "" },
115
+ ]);
116
+ }
117
+ break;
118
+
119
+ case "update":
120
+ {
121
+ // Generation update: update the output text.
122
+ // Parse messages
123
+ const { output, tps, numTokens } = e.data;
124
+ setTps(tps);
125
+ setNumTokens(numTokens);
126
+ setMessages((prev) => {
127
+ const cloned = [...prev];
128
+ const last = cloned.at(-1);
129
+ cloned[cloned.length - 1] = {
130
+ ...last,
131
+ content: last.content + output,
132
+ };
133
+ return cloned;
134
+ });
135
+ }
136
+ break;
137
+
138
+ case "complete":
139
+ // Generation complete: re-enable the "Generate" button
140
+ setIsRunning(false);
141
+ break;
142
+
143
+ case "error":
144
+ setError(e.data.data);
145
+ break;
146
+ }
147
+ };
148
+
149
+ const onErrorReceived = (e) => {
150
+ console.error("Worker error:", e);
151
+ };
152
+
153
+ // Attach the callback function as an event listener.
154
+ worker.current.addEventListener("message", onMessageReceived);
155
+ worker.current.addEventListener("error", onErrorReceived);
156
+
157
+ // Define a cleanup function for when the component is unmounted.
158
+ return () => {
159
+ worker.current.removeEventListener("message", onMessageReceived);
160
+ worker.current.removeEventListener("error", onErrorReceived);
161
+ };
162
+ }, []);
163
+
164
+ // Send the messages to the worker thread whenever the `messages` state changes.
165
+ useEffect(() => {
166
+ if (messages.filter((x) => x.role === "user").length === 0) {
167
+ // No user messages yet: do nothing.
168
+ return;
169
+ }
170
+ if (messages.at(-1).role === "assistant") {
171
+ // Do not update if the last message is from the assistant
172
+ return;
173
+ }
174
+ setTps(null);
175
+ worker.current.postMessage({ type: "generate", data: messages });
176
+ }, [messages, isRunning]);
177
+
178
+ useEffect(() => {
179
+ if (!chatContainerRef.current || !isRunning) return;
180
+ const element = chatContainerRef.current;
181
+ if (
182
+ element.scrollHeight - element.scrollTop - element.clientHeight <
183
+ STICKY_SCROLL_THRESHOLD
184
+ ) {
185
+ element.scrollTop = element.scrollHeight;
186
+ }
187
+ }, [messages, isRunning]);
188
+
189
+ return IS_WEBGPU_AVAILABLE ? (
190
+ <div className="flex flex-col h-screen mx-auto items justify-end text-gray-800 dark:text-gray-200 bg-white dark:bg-gray-900">
191
+ {status === null && messages.length === 0 && (
192
+ <div className="h-full overflow-auto scrollbar-thin flex justify-center items-center flex-col relative">
193
+ <div className="flex flex-col items-center mb-1 max-w-[360px] text-center">
194
+ <img
195
+ src="logo.png"
196
+ width="80%"
197
+ height="auto"
198
+ className="block rounded-2xl mb-6"
199
+ ></img>
200
+ <h1 className="text-4xl font-bold mb-2">Granite-4.0 WebGPU</h1>
201
+ <h2 className="text-xl">
202
+ A reliable and efficient AI chatbot <br />that runs locally in your
203
+ browser.
204
+ </h2>
205
+ </div>
206
+
207
+ <div className="flex flex-col items-center px-4 text-lg">
208
+ <p className="max-w-[514px] mb-4">
209
+ <br />
210
+ You are about to load{" "}
211
+ <a
212
+ href="https://huggingface.co/onnx-community/granite-4.0-micro-ONNX-web"
213
+ target="_blank"
214
+ rel="noreferrer"
215
+ className="font-medium underline"
216
+ >
217
+ Granite-4.0 Micro
218
+ </a>
219
+ , a 3.4B parameter long-context instruct model optimized for in-browser inference.
220
+ Everything runs entirely in your browser with{" "}
221
+ <a
222
+ href="https://huggingface.co/docs/transformers.js"
223
+ target="_blank"
224
+ rel="noreferrer"
225
+ className="underline"
226
+ >
227
+ 🤗&nbsp;Transformers.js
228
+ </a>{" "}
229
+ and ONNX Runtime Web, meaning no data is sent to a server. Once
230
+ loaded (≈ 2.3 GB), it can even be used offline.
231
+ </p>
232
+
233
+ {error && (
234
+ <div className="text-red-500 text-center mb-2">
235
+ <p className="mb-1">
236
+ Unable to load model due to the following error:
237
+ </p>
238
+ <p className="text-sm">{error}</p>
239
+ </div>
240
+ )}
241
+
242
+ <button
243
+ className="border px-4 py-2 rounded-lg bg-blue-400 text-white hover:bg-blue-500 disabled:bg-blue-100 disabled:cursor-not-allowed select-none"
244
+ onClick={() => {
245
+ worker.current.postMessage({ type: "load" });
246
+ setStatus("loading");
247
+ }}
248
+ disabled={status !== null || error !== null}
249
+ >
250
+ Load model
251
+ </button>
252
+ </div>
253
+ </div>
254
+ )}
255
+ {status === "loading" && (
256
+ <>
257
+ <div className="w-full max-w-[500px] text-left mx-auto p-4 bottom-0 mt-auto">
258
+ <p className="text-center mb-1">{loadingMessage}</p>
259
+ {progressItems.map(({ file, progress, total }, i) => (
260
+ <Progress
261
+ key={i}
262
+ text={file}
263
+ percentage={progress}
264
+ total={total}
265
+ />
266
+ ))}
267
+ </div>
268
+ </>
269
+ )}
270
+
271
+ {status === "ready" && (
272
+ <div
273
+ ref={chatContainerRef}
274
+ className="overflow-y-auto scrollbar-thin w-full flex flex-col items-center h-full"
275
+ >
276
+ <Chat messages={messages} />
277
+ {messages.length === 0 && (
278
+ <div>
279
+ {EXAMPLES.map((msg, i) => (
280
+ <div
281
+ key={i}
282
+ className="m-1 border dark:border-gray-600 rounded-md p-2 bg-gray-100 dark:bg-gray-700 cursor-pointer"
283
+ onClick={() => onEnter(msg)}
284
+ >
285
+ {msg}
286
+ </div>
287
+ ))}
288
+ </div>
289
+ )}
290
+ <p className="text-center text-sm min-h-6 text-gray-500 dark:text-gray-300">
291
+ {tps && messages.length > 0 && (
292
+ <>
293
+ {!isRunning && (
294
+ <span>
295
+ Generated {numTokens} tokens in{" "}
296
+ {(numTokens / tps).toFixed(2)} seconds&nbsp;&#40;
297
+ </span>
298
+ )}
299
+ {
300
+ <>
301
+ <span className="font-medium text-center mr-1 text-black dark:text-white">
302
+ {tps.toFixed(2)}
303
+ </span>
304
+ <span className="text-gray-500 dark:text-gray-300">
305
+ tokens/second
306
+ </span>
307
+ </>
308
+ }
309
+ {!isRunning && (
310
+ <>
311
+ <span className="mr-1">&#41;.</span>
312
+ <span
313
+ className="underline cursor-pointer"
314
+ onClick={() => {
315
+ worker.current.postMessage({ type: "reset" });
316
+ setMessages([]);
317
+ }}
318
+ >
319
+ Reset
320
+ </span>
321
+ </>
322
+ )}
323
+ </>
324
+ )}
325
+ </p>
326
+ </div>
327
+ )}
328
+
329
+ <div className="mt-2 border dark:bg-gray-700 rounded-lg w-[600px] max-w-[80%] max-h-[200px] mx-auto relative mb-3 flex">
330
+ <textarea
331
+ ref={textareaRef}
332
+ className="scrollbar-thin w-[550px] dark:bg-gray-700 px-3 py-4 rounded-lg bg-transparent border-none outline-none text-gray-800 disabled:text-gray-400 dark:text-gray-200 placeholder-gray-500 dark:placeholder-gray-400 disabled:placeholder-gray-200 resize-none disabled:cursor-not-allowed"
333
+ placeholder="Type your message..."
334
+ type="text"
335
+ rows={1}
336
+ value={input}
337
+ disabled={status !== "ready"}
338
+ title={status === "ready" ? "Model is ready" : "Model not loaded yet"}
339
+ onKeyDown={(e) => {
340
+ if (
341
+ input.length > 0 &&
342
+ !isRunning &&
343
+ e.key === "Enter" &&
344
+ !e.shiftKey
345
+ ) {
346
+ e.preventDefault(); // Prevent default behavior of Enter key
347
+ onEnter(input);
348
+ }
349
+ }}
350
+ onInput={(e) => setInput(e.target.value)}
351
+ />
352
+ {isRunning ? (
353
+ <div className="cursor-pointer" onClick={onInterrupt}>
354
+ <StopIcon className="h-8 w-8 p-1 rounded-md text-gray-800 dark:text-gray-100 absolute right-3 bottom-3" />
355
+ </div>
356
+ ) : input.length > 0 ? (
357
+ <div className="cursor-pointer" onClick={() => onEnter(input)}>
358
+ <ArrowRightIcon
359
+ className={`h-8 w-8 p-1 bg-gray-800 dark:bg-gray-100 text-white dark:text-black rounded-md absolute right-3 bottom-3`}
360
+ />
361
+ </div>
362
+ ) : (
363
+ <div>
364
+ <ArrowRightIcon
365
+ className={`h-8 w-8 p-1 bg-gray-200 dark:bg-gray-600 text-gray-50 dark:text-gray-800 rounded-md absolute right-3 bottom-3`}
366
+ />
367
+ </div>
368
+ )}
369
+ </div>
370
+
371
+ <p className="text-xs text-gray-400 text-center mb-3">
372
+ Disclaimer: Generated content may be inaccurate or false.
373
+ </p>
374
+ </div>
375
+ ) : (
376
+ <div className="fixed w-screen h-screen bg-black z-10 bg-opacity-[92%] text-white text-2xl font-semibold flex justify-center items-center text-center">
377
+ WebGPU is not supported
378
+ <br />
379
+ by this browser :&#40;
380
+ </div>
381
+ );
382
+ }
383
+
384
+ export default App;
src/components/Chat.css ADDED
@@ -0,0 +1,129 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ @scope (.markdown) {
2
+ /* Code blocks */
3
+ pre {
4
+ margin: 0.5rem 0;
5
+ white-space: break-spaces;
6
+ }
7
+
8
+ code {
9
+ padding: 0.2em 0.4em;
10
+ border-radius: 4px;
11
+ font-family: Consolas, Monaco, "Andale Mono", "Ubuntu Mono", monospace;
12
+ font-size: 0.9em;
13
+ }
14
+
15
+ pre,
16
+ code {
17
+ background-color: #f2f2f2;
18
+ }
19
+
20
+ @media (prefers-color-scheme: dark) {
21
+ pre,
22
+ code {
23
+ background-color: #333;
24
+ }
25
+ }
26
+
27
+ pre:has(code) {
28
+ padding: 1rem 0.5rem;
29
+ }
30
+
31
+ pre > code {
32
+ padding: 0;
33
+ }
34
+
35
+ /* Headings */
36
+ h1,
37
+ h2,
38
+ h3,
39
+ h4,
40
+ h5,
41
+ h6 {
42
+ font-weight: 600;
43
+ line-height: 1.2;
44
+ }
45
+
46
+ h1 {
47
+ font-size: 2em;
48
+ margin: 1rem 0;
49
+ }
50
+
51
+ h2 {
52
+ font-size: 1.5em;
53
+ margin: 0.83rem 0;
54
+ }
55
+
56
+ h3 {
57
+ font-size: 1.25em;
58
+ margin: 0.67rem 0;
59
+ }
60
+
61
+ h4 {
62
+ font-size: 1em;
63
+ margin: 0.5rem 0;
64
+ }
65
+
66
+ h5 {
67
+ font-size: 0.875em;
68
+ margin: 0.33rem 0;
69
+ }
70
+
71
+ h6 {
72
+ font-size: 0.75em;
73
+ margin: 0.25rem 0;
74
+ }
75
+
76
+ h1,
77
+ h2,
78
+ h3,
79
+ h4,
80
+ h5,
81
+ h6:first-child {
82
+ margin-top: 0;
83
+ }
84
+
85
+ /* Unordered List */
86
+ ul {
87
+ list-style-type: disc;
88
+ margin-left: 1.5rem;
89
+ }
90
+
91
+ /* Ordered List */
92
+ ol {
93
+ list-style-type: decimal;
94
+ margin-left: 1.5rem;
95
+ }
96
+
97
+ /* List Items */
98
+ li {
99
+ margin: 0.25rem 0;
100
+ }
101
+
102
+ p:not(:first-child) {
103
+ margin-top: 0.75rem;
104
+ }
105
+
106
+ p:not(:last-child) {
107
+ margin-bottom: 0.75rem;
108
+ }
109
+
110
+ ul > li {
111
+ margin-left: 1rem;
112
+ }
113
+
114
+ /* Table */
115
+ table,
116
+ th,
117
+ td {
118
+ border: 1px solid lightgray;
119
+ padding: 0.25rem;
120
+ }
121
+
122
+ @media (prefers-color-scheme: dark) {
123
+ table,
124
+ th,
125
+ td {
126
+ border: 1px solid #f2f2f2;
127
+ }
128
+ }
129
+ }
src/components/Chat.jsx ADDED
@@ -0,0 +1,67 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import { marked } from "marked";
2
+ import DOMPurify from "dompurify";
3
+
4
+ import BotIcon from "./icons/BotIcon";
5
+ import UserIcon from "./icons/UserIcon";
6
+
7
+ import "./Chat.css";
8
+ import { useEffect } from "react";
9
+
10
/**
 * Convert Markdown text into sanitized HTML.
 *
 * @param {string} text - Markdown source.
 * @returns {string} The HTML from `marked.parse`, passed through `DOMPurify.sanitize`.
 */
function render(text) {
  const rawHtml = marked.parse(text);
  const safeHtml = DOMPurify.sanitize(rawHtml);
  return safeHtml;
}
13
+
14
+ export default function Chat({ messages }) {
15
+ const empty = messages.length === 0;
16
+
17
+ useEffect(() => {
18
+ window.MathJax.typeset();
19
+ }, [messages]);
20
+
21
+ return (
22
+ <div
23
+ className={`flex-1 p-6 max-w-[960px] w-full ${empty ? "flex flex-col items-center justify-end" : "space-y-4"}`}
24
+ >
25
+ {empty ? (
26
+ <div className="text-xl">Ready!</div>
27
+ ) : (
28
+ messages.map((msg, i) => (
29
+ <div key={`message-${i}`} className="flex items-start space-x-4">
30
+ {msg.role === "assistant" ? (
31
+ <>
32
+ <BotIcon className="h-6 w-6 min-h-6 min-w-6 my-3 text-gray-500 dark:text-gray-300" />
33
+ <div className="bg-gray-200 dark:bg-gray-700 rounded-lg p-4">
34
+ <p className="min-h-6 text-gray-800 dark:text-gray-200 overflow-wrap-anywhere">
35
+ {msg.content.length > 0 ? (
36
+ <span
37
+ className="markdown"
38
+ dangerouslySetInnerHTML={{
39
+ __html: render(msg.content),
40
+ }}
41
+ />
42
+ ) : (
43
+ <span className="h-6 flex items-center gap-1">
44
+ <span className="w-2.5 h-2.5 bg-gray-600 dark:bg-gray-300 rounded-full animate-pulse"></span>
45
+ <span className="w-2.5 h-2.5 bg-gray-600 dark:bg-gray-300 rounded-full animate-pulse animation-delay-200"></span>
46
+ <span className="w-2.5 h-2.5 bg-gray-600 dark:bg-gray-300 rounded-full animate-pulse animation-delay-400"></span>
47
+ </span>
48
+ )}
49
+ </p>
50
+ </div>
51
+ </>
52
+ ) : (
53
+ <>
54
+ <UserIcon className="h-6 w-6 min-h-6 min-w-6 my-3 text-gray-500 dark:text-gray-300" />
55
+ <div className="bg-blue-500 text-white rounded-lg p-4">
56
+ <p className="min-h-6 overflow-wrap-anywhere">
57
+ {msg.content}
58
+ </p>
59
+ </div>
60
+ </>
61
+ )}
62
+ </div>
63
+ ))
64
+ )}
65
+ </div>
66
+ );
67
+ }
src/components/Progress.jsx ADDED
@@ -0,0 +1,22 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
/**
 * Format a byte count as a short human-readable string such as "1.5kB".
 *
 * The value is scaled by powers of 1024, rounded to at most two decimals
 * (trailing zeros dropped) and suffixed with B, kB, MB, GB or TB. The unit
 * is clamped to that list, so values below one byte stay in "B" and values
 * of 1024 TB or more stay in "TB".
 *
 * @param {number} size - Number of bytes.
 * @returns {string} Formatted size; "0B" for zero, negative or non-finite input.
 */
function formatBytes(size) {
  const units = ["B", "kB", "MB", "GB", "TB"];
  if (!Number.isFinite(size) || size <= 0) {
    return `0${units[0]}`;
  }

  // log2(size) / 10 equals log1024(size) and is exact at powers of 1024.
  const rawExponent = Math.floor(Math.log2(size) / 10);
  const exponent = Math.min(Math.max(rawExponent, 0), units.length - 1);

  // Number(...) drops trailing zeros left by toFixed ("1.50" -> 1.5).
  const scaled = Number((size / 1024 ** exponent).toFixed(2));
  return `${scaled}${units[exponent]}`;
}
8
+
9
+ export default function Progress({ text, percentage, total }) {
10
+ percentage ??= 0;
11
+ return (
12
+ <div className="w-full bg-gray-100 dark:bg-gray-700 text-left rounded-lg overflow-hidden mb-0.5">
13
+ <div
14
+ className="bg-blue-400 whitespace-nowrap px-1 text-sm"
15
+ style={{ width: `${percentage}%` }}
16
+ >
17
+ {text} ({percentage.toFixed(2)}%
18
+ {isNaN(total) ? "" : ` of ${formatBytes(total)}`})
19
+ </div>
20
+ </div>
21
+ );
22
+ }
src/components/icons/ArrowRightIcon.jsx ADDED
@@ -0,0 +1,19 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ export default function ArrowRightIcon(props) {
2
+ return (
3
+ <svg
4
+ {...props}
5
+ xmlns="http://www.w3.org/2000/svg"
6
+ width="24"
7
+ height="24"
8
+ viewBox="0 0 24 24"
9
+ fill="none"
10
+ stroke="currentColor"
11
+ strokeWidth="2"
12
+ strokeLinecap="round"
13
+ strokeLinejoin="round"
14
+ >
15
+ <path d="M5 12h14" />
16
+ <path d="m12 5 7 7-7 7" />
17
+ </svg>
18
+ );
19
+ }
src/components/icons/BotIcon.jsx ADDED
@@ -0,0 +1,23 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ export default function BotIcon(props) {
2
+ return (
3
+ <svg
4
+ {...props}
5
+ xmlns="http://www.w3.org/2000/svg"
6
+ width="24"
7
+ height="24"
8
+ viewBox="0 0 24 24"
9
+ fill="none"
10
+ stroke="currentColor"
11
+ strokeWidth="2"
12
+ strokeLinecap="round"
13
+ strokeLinejoin="round"
14
+ >
15
+ <path d="M12 8V4H8" />
16
+ <rect width="16" height="12" x="4" y="8" rx="2" />
17
+ <path d="M2 14h2" />
18
+ <path d="M20 14h2" />
19
+ <path d="M15 13v2" />
20
+ <path d="M9 13v2" />
21
+ </svg>
22
+ );
23
+ }
src/components/icons/StopIcon.jsx ADDED
@@ -0,0 +1,22 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ export default function StopIcon(props) {
2
+ return (
3
+ <svg
4
+ {...props}
5
+ xmlns="http://www.w3.org/2000/svg"
6
+ width="24"
7
+ height="24"
8
+ viewBox="0 0 24 24"
9
+ fill="none"
10
+ stroke="currentColor"
11
+ strokeWidth="2"
12
+ strokeLinecap="round"
13
+ strokeLinejoin="round"
14
+ >
15
+ <path d="M21 12a9 9 0 1 1-18 0 9 9 0 0 1 18 0Z" />
16
+ <path
17
+ fill="currentColor"
18
+ d="M9 9.563C9 9.252 9.252 9 9.563 9h4.874c.311 0 .563.252.563.563v4.874c0 .311-.252.563-.563.563H9.564A.562.562 0 0 1 9 14.437V9.564Z"
19
+ />
20
+ </svg>
21
+ );
22
+ }
src/components/icons/UserIcon.jsx ADDED
@@ -0,0 +1,19 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ export default function UserIcon(props) {
2
+ return (
3
+ <svg
4
+ {...props}
5
+ xmlns="http://www.w3.org/2000/svg"
6
+ width="24"
7
+ height="24"
8
+ viewBox="0 0 24 24"
9
+ fill="none"
10
+ stroke="currentColor"
11
+ strokeWidth="2"
12
+ strokeLinecap="round"
13
+ strokeLinejoin="round"
14
+ >
15
+ <path d="M19 21v-2a4 4 0 0 0-4-4H9a4 4 0 0 0-4 4v2" />
16
+ <circle cx="12" cy="7" r="4" />
17
+ </svg>
18
+ );
19
+ }
src/index.css ADDED
@@ -0,0 +1,32 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ @tailwind base;
2
+ @tailwind components;
3
+ @tailwind utilities;
4
+
5
+ @layer utilities {
6
+ .scrollbar-thin::-webkit-scrollbar {
7
+ @apply w-2;
8
+ }
9
+
10
+ .scrollbar-thin::-webkit-scrollbar-track {
11
+ @apply rounded-full bg-gray-100 dark:bg-gray-700;
12
+ }
13
+
14
+ .scrollbar-thin::-webkit-scrollbar-thumb {
15
+ @apply rounded-full bg-gray-300 dark:bg-gray-600;
16
+ }
17
+
18
+ .scrollbar-thin::-webkit-scrollbar-thumb:hover {
19
+ @apply bg-gray-500;
20
+ }
21
+
22
+ .animation-delay-200 {
23
+ animation-delay: 200ms;
24
+ }
25
+ .animation-delay-400 {
26
+ animation-delay: 400ms;
27
+ }
28
+
29
+ .overflow-wrap-anywhere {
30
+ overflow-wrap: anywhere;
31
+ }
32
+ }
src/main.jsx ADDED
@@ -0,0 +1,10 @@
 
 
 
 
 
 
 
 
 
 
 
1
+ import React from "react";
2
+ import ReactDOM from "react-dom/client";
3
+ import App from "./App.jsx";
4
+ import "./index.css";
5
+
6
+ ReactDOM.createRoot(document.getElementById("root")).render(
7
+ <React.StrictMode>
8
+ <App />
9
+ </React.StrictMode>,
10
+ );
src/worker.js ADDED
@@ -0,0 +1,167 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import {
2
+ AutoTokenizer,
3
+ AutoModelForCausalLM,
4
+ TextStreamer,
5
+ InterruptableStoppingCriteria,
6
+ } from "@huggingface/transformers";
7
+
8
/**
 * Feature-detect WebGPU support required by the model.
 *
 * Posts `{ status: "error", data: <message> }` to the main thread when
 * WebGPU is unavailable, no adapter can be obtained, or the adapter lacks
 * the "shader-f16" feature. Posts nothing when every check passes.
 *
 * @returns {Promise<void>}
 */
async function check() {
  try {
    // Test for the API explicitly: otherwise a missing `navigator.gpu`
    // reaches the user as "Cannot read properties of undefined".
    if (!navigator.gpu) {
      throw new Error("WebGPU is not supported (navigator.gpu is unavailable)");
    }
    const adapter = await navigator.gpu.requestAdapter();
    if (!adapter) {
      throw new Error("WebGPU is not supported (no adapter found)");
    }
    if (!adapter.features.has("shader-f16")) {
      throw new Error("WebGPU adapter does not support shader-f16");
    }
  } catch (e) {
    self.postMessage({
      status: "error",
      data: e.toString(),
    });
  }
}
27
+
28
/**
 * Lazily loads the tokenizer and model once and shares them between callers.
 *
 * The pending load promises are cached on the class so repeated calls reuse
 * the same download. A load that fails is dropped from the cache, so the
 * next call starts a fresh attempt instead of returning the old rejection.
 */
class TextGenerationPipeline {
  static model_id = "onnx-community/granite-4.0-micro-ONNX-web";
  static tokenizer = null;
  static model = null;

  /**
   * Return the cached promise stored in `slot`, starting the load if needed.
   *
   * @param {"tokenizer" | "model"} slot - Static field that holds the promise.
   * @param {() => Promise<unknown>} loader - Starts the load.
   * @returns {Promise<unknown>}
   */
  static #loadOnce(slot, loader) {
    if (this[slot] == null) {
      const pending = Promise.resolve(loader()).catch((error) => {
        // Forget the failed attempt, unless a newer one already replaced it.
        if (this[slot] === pending) {
          this[slot] = null;
        }
        throw error;
      });
      this[slot] = pending;
    }
    return this[slot];
  }

  /**
   * @param {Function | null} [progress_callback] - Passed to both
   *   `from_pretrained` calls; only used by the call that starts a load.
   * @returns {Promise<[unknown, unknown]>} Resolves to `[tokenizer, model]`.
   */
  static async getInstance(progress_callback = null) {
    const tokenizer = this.#loadOnce("tokenizer", () =>
      AutoTokenizer.from_pretrained(this.model_id, {
        progress_callback,
      }),
    );

    const model = this.#loadOnce("model", () =>
      AutoModelForCausalLM.from_pretrained(this.model_id, {
        dtype: "q4f16",
        device: "webgpu",
        progress_callback,
      }),
    );

    return Promise.all([tokenizer, model]);
  }
}
48
+
49
+ const stopping_criteria = new InterruptableStoppingCriteria();
50
+
51
// Key/value cache returned by the previous generation; passed back into the
// next `model.generate` call and cleared by the "reset" message.
let past_key_values_cache = null;

/**
 * Generate the assistant reply for a conversation and stream it to the
 * main thread.
 *
 * Messages posted, in order: `start`, zero or more `update`
 * (`output`, `tps`, `numTokens`), then `complete`. If anything throws, an
 * `error` message is posted followed by `complete`, so the UI never stays
 * locked in the "running" state.
 *
 * @param {Array<{ role: string, content: string }>} messages - Conversation so far.
 * @returns {Promise<void>}
 */
async function generate(messages) {
  try {
    // Retrieve the text-generation pipeline.
    const [tokenizer, model] = await TextGenerationPipeline.getInstance();

    const inputs = tokenizer.apply_chat_template(messages, {
      add_generation_prompt: true,
      return_dict: true,
    });

    let startTime;
    let numTokens = 0;
    let tps;
    const token_callback_function = () => {
      startTime ??= performance.now();

      // The first token only starts the clock; the rate is computed from
      // the second token onwards.
      if (numTokens++ > 0) {
        tps = (numTokens / (performance.now() - startTime)) * 1000;
      }
    };
    const callback_function = (output) => {
      self.postMessage({
        status: "update",
        output,
        tps,
        numTokens,
      });
    };

    const streamer = new TextStreamer(tokenizer, {
      skip_prompt: true,
      skip_special_tokens: true,
      callback_function,
      token_callback_function,
    });

    // Tell the main thread we are starting
    self.postMessage({ status: "start" });

    const { past_key_values, sequences } = await model.generate({
      ...inputs,
      past_key_values: past_key_values_cache,

      max_new_tokens: 1024,
      streamer,
      stopping_criteria,
      return_dict_in_generate: true,
    });
    past_key_values_cache = past_key_values;

    const decoded = tokenizer.batch_decode(sequences, {
      skip_special_tokens: true,
    });

    // Send the output back to the main thread
    self.postMessage({
      status: "complete",
      output: decoded,
    });
  } catch (e) {
    console.error("Generation failed:", e);
    self.postMessage({
      status: "error",
      data: String(e),
    });
    // Always finish with "complete": it is the only message that lets the
    // main thread accept input again.
    self.postMessage({
      status: "complete",
      output: [],
    });
  }
}
116
+
117
/**
 * Download the tokenizer and model, warm the model up, and report progress.
 *
 * Messages posted: `loading` ("Loading model..."), the raw progress events
 * from the library, `loading` ("Compiling shaders and warming up model..."),
 * then `ready`. If any step throws, an `error` message carrying the
 * stringified error is posted instead of leaving an unhandled rejection.
 *
 * @returns {Promise<void>}
 */
async function load() {
  try {
    self.postMessage({
      status: "loading",
      data: "Loading model...",
    });

    // Load the pipeline and save it for future use.
    const [tokenizer, model] = await TextGenerationPipeline.getInstance((x) => {
      // We also add a progress callback to the pipeline so that we can
      // track model loading.
      self.postMessage(x);
    });

    self.postMessage({
      status: "loading",
      data: "Compiling shaders and warming up model...",
    });

    // Run model with dummy input to compile shaders
    const inputs = tokenizer("a");
    await model.generate({ ...inputs, max_new_tokens: 1 });
    self.postMessage({ status: "ready" });
  } catch (e) {
    console.error("Model loading failed:", e);
    self.postMessage({
      status: "error",
      data: String(e),
    });
  }
}
140
// Route each message from the main thread to the matching worker action.
// Unknown message types are ignored.
self.addEventListener("message", async (event) => {
  const { type, data } = event.data;

  if (type === "check") {
    check();
  } else if (type === "load") {
    load();
  } else if (type === "generate") {
    // Clear any earlier interrupt before starting a new generation.
    stopping_criteria.reset();
    generate(data);
  } else if (type === "interrupt") {
    stopping_criteria.interrupt();
  } else if (type === "reset") {
    // Drop the cached keys/values along with any pending interrupt.
    past_key_values_cache = null;
    stopping_criteria.reset();
  }
});
tailwind.config.js ADDED
@@ -0,0 +1,8 @@
 
 
 
 
 
 
 
 
 
1
/** @type {import('tailwindcss').Config} */
const config = {
  // Files listed under Tailwind's `content` option.
  content: ["./index.html", "./src/**/*.{js,ts,jsx,tsx}"],
  theme: { extend: {} },
  plugins: [],
};

export default config;
vite.config.js ADDED
@@ -0,0 +1,7 @@
 
 
 
 
 
 
 
 
1
+ import { defineConfig } from "vite";
2
+ import react from "@vitejs/plugin-react";
3
+
4
// Vite configuration (reference: https://vitejs.dev/config/).
// The only customization is the React plugin.
const plugins = [react()];

export default defineConfig({ plugins });