Xenova HF Staff committed on
Commit
c369b38
·
verified ·
1 Parent(s): c1521a5

Upload 10 files

Browse files
Files changed (10) hide show
  1. eslint.config.js +23 -0
  2. index.html +16 -0
  3. package.json +35 -0
  4. src/App.tsx +763 -0
  5. src/index.css +54 -0
  6. src/main.tsx +10 -0
  7. tsconfig.app.json +28 -0
  8. tsconfig.json +7 -0
  9. tsconfig.node.json +26 -0
  10. vite.config.ts +8 -0
eslint.config.js ADDED
@@ -0,0 +1,23 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
// ESLint flat config for the TypeScript + React (Vite) project.
import js from "@eslint/js";
import globals from "globals";
import reactHooks from "eslint-plugin-react-hooks";
import reactRefresh from "eslint-plugin-react-refresh";
import tseslint from "typescript-eslint";
import { defineConfig, globalIgnores } from "eslint/config";

export default defineConfig([
  // Never lint build output.
  globalIgnores(["dist"]),
  {
    files: ["**/*.{ts,tsx}"],
    // Order matters: later entries can override earlier ones.
    extends: [
      js.configs.recommended,
      tseslint.configs.recommended,
      reactHooks.configs.flat.recommended,
      reactRefresh.configs.vite,
    ],
    languageOptions: {
      ecmaVersion: 2020,
      globals: globals.browser,
    },
  },
]);
index.html ADDED
@@ -0,0 +1,16 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
<!doctype html>
<html lang="en">
  <head>
    <meta charset="UTF-8" />
    <!-- Inline SVG data-URI favicon: renders the baguette emoji without shipping a binary asset. -->
    <link
      rel="icon"
      href="data:image/svg+xml,<svg xmlns=%22http://www.w3.org/2000/svg%22 viewBox=%220 0 100 100%22><text y=%22.9em%22 font-size=%2290%22>🥖</text></svg>"
    />
    <meta name="viewport" content="width=device-width, initial-scale=1.0" />
    <title>Baguettotron WebGPU</title>
  </head>
  <body>
    <div id="root"></div>
    <!-- React entry point; Vite resolves and bundles this module. -->
    <script type="module" src="/src/main.tsx"></script>
  </body>
</html>
package.json ADDED
@@ -0,0 +1,35 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ {
2
+ "name": "baguettotron-webgpu",
3
+ "private": true,
4
+ "version": "0.0.0",
5
+ "type": "module",
6
+ "scripts": {
7
+ "dev": "vite",
8
+ "build": "tsc -b && vite build",
9
+ "lint": "eslint .",
10
+ "preview": "vite preview"
11
+ },
12
+ "dependencies": {
13
+ "@huggingface/transformers": "^3.7.6",
14
+ "@tailwindcss/vite": "^4.1.17",
15
+ "lucide-react": "^0.553.0",
16
+ "react": "^19.2.0",
17
+ "react-dom": "^19.2.0",
18
+ "streamdown": "^1.4.0",
19
+ "tailwindcss": "^4.1.17"
20
+ },
21
+ "devDependencies": {
22
+ "@eslint/js": "^9.39.1",
23
+ "@types/node": "^24.10.0",
24
+ "@types/react": "^19.2.2",
25
+ "@types/react-dom": "^19.2.2",
26
+ "@vitejs/plugin-react": "^5.1.0",
27
+ "eslint": "^9.39.1",
28
+ "eslint-plugin-react-hooks": "^7.0.1",
29
+ "eslint-plugin-react-refresh": "^0.4.24",
30
+ "globals": "^16.5.0",
31
+ "typescript": "~5.9.3",
32
+ "typescript-eslint": "^8.46.3",
33
+ "vite": "^7.2.2"
34
+ }
35
+ }
src/App.tsx ADDED
@@ -0,0 +1,763 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import { useState, useRef, useCallback, useLayoutEffect } from "react";
2
+ import { Send, Paperclip, Brain, ChevronDown, X, Plus } from "lucide-react";
3
+ import {
4
+ AutoModelForCausalLM,
5
+ AutoTokenizer,
6
+ InterruptableStoppingCriteria,
7
+ TextStreamer,
8
+ } from "@huggingface/transformers";
9
+ import { Streamdown } from "streamdown";
10
+
11
+ import type {
12
+ PreTrainedTokenizer,
13
+ LlamaForCausalLM,
14
+ } from "@huggingface/transformers";
15
+ import type React from "react";
16
+
17
// Hugging Face Hub identifier of the ONNX export of Baguettotron.
const MODEL_ID = "onnx-community/Baguettotron-ONNX";

// Quantization options the user can pick, with approximate download sizes
// shown in the dtype dropdown.
const DTYPES = {
  fp32: "FP32 (~1.28 GB)",
  fp16: "FP16 (~642 MB)",
  q4: "Q4 (~329 MB)",
  q4f16: "Q4F16 (~235 MB)",
} as const;
type Dtype = keyof typeof DTYPES;

// Two or more consecutive newlines separate individual RAG sources in the
// context textarea.
const SOURCE_SEPARATOR_REGEX = /\n{2,}/g;

type Role = "user" | "assistant";
30
+
31
+ /**
32
+ * Format the sources into tagged segments for the model input.
33
+ */
34
+ const buildSourcesPayload = (rawContext: string) => {
35
+ const trimmed = rawContext.trim();
36
+ if (!trimmed) {
37
+ return { payload: "", count: 0, segments: [] };
38
+ }
39
+
40
+ const segments = trimmed
41
+ .split(SOURCE_SEPARATOR_REGEX)
42
+ .map((segment) => segment.trim())
43
+ .filter(Boolean);
44
+
45
+ if (segments.length === 0) {
46
+ return { payload: "", count: 0, segments: [] };
47
+ }
48
+
49
+ const payload =
50
+ "\n\n" +
51
+ segments
52
+ .map(
53
+ (segment, index) =>
54
+ `<source_${index + 1}>${segment}</source_${index + 1}>`,
55
+ )
56
+ .join("\n");
57
+ return { payload, count: segments.length, segments };
58
+ };
59
+
60
+ /**
61
+ * Converts <ref>...</ref> tags in the content to superscript references.
62
+ */
63
+ const convertRefsToSuperscript = (content: string) => {
64
+ const refRegex = /<ref name="([^"]+)">([\s\S]*?)<\/ref>/g;
65
+ const refLabelMap = new Map<string, number>();
66
+ let refCounter = 1;
67
+
68
+ // First, process all complete <ref>...</ref> tags
69
+ let result = content.replace(refRegex, (_, sourceName = "", refBody) => {
70
+ const label =
71
+ refLabelMap.get(sourceName) ??
72
+ (() => {
73
+ const assigned = refCounter++;
74
+ refLabelMap.set(sourceName, assigned);
75
+ return assigned;
76
+ })();
77
+ const escapedRefBody = refBody.replace(/"/g, "&quot;");
78
+ return `<sup className="cursor-pointer" title="${escapedRefBody}">[${label}]</sup>`;
79
+ });
80
+
81
+ // Remove any trailing incomplete <ref> tag
82
+ const incompleteRefRegex = /<ref[^>]*>[\s\S]*$/;
83
+ result = result.replace(incompleteRefRegex, "");
84
+
85
+ return result;
86
+ };
87
+
88
+ /**
89
+ * Sanitizes user input by replacing angle brackets.
90
+ */
91
+ const sanitizeInput = (text: string) => {
92
+ return text.replace(/</g, "&lt;").replace(/>/g, "&gt;");
93
+ };
94
+
95
/**
 * Represents a single chat message in the history.
 */
interface Message {
  id: number; // stable React key; assigned from the transcript length at creation
  role: Role;
  content: string; // visible answer text (markdown); excludes the think trace

  thinkTrace?: string; // reasoning text emitted before </think>, shown collapsed
  rawStream?: string; // raw token accumulator while streaming; removed when done
  isLoading?: boolean; // true while this assistant message is still generating
  timestamp?: number; // generation start, ms since epoch
  thinkEndTime?: number; // when </think> was first observed, ms since epoch
}
109
+
110
+ /**
111
+ * A simple, self-contained collapsible component.
112
+ */
113
+ const Collapsible: React.FC<{
114
+ title: React.ReactNode;
115
+ children: React.ReactNode;
116
+ }> = ({ title, children }) => {
117
+ const [isOpen, setIsOpen] = useState(false);
118
+ const contentRef = useRef<HTMLDivElement>(null);
119
+ return (
120
+ <div className="collapsible mt-2">
121
+ <button
122
+ onClick={() => setIsOpen(!isOpen)}
123
+ className="flex items-center space-x-1 text-xs font-medium text-amber-700 hover:text-amber-900 transition-colors"
124
+ >
125
+ {title}
126
+ <ChevronDown
127
+ size={14}
128
+ className={`transform transition-transform ${isOpen ? "rotate-180" : "rotate-0"}`}
129
+ />
130
+ </button>
131
+ <div
132
+ ref={contentRef}
133
+ style={{
134
+ maxHeight: isOpen ? `${contentRef.current?.scrollHeight}px` : "0px",
135
+ }}
136
+ className="overflow-hidden transition-all duration-300 ease-in-out"
137
+ >
138
+ <div className="mt-2 p-2 bg-amber-50 border border-dashed border-amber-200 rounded-md text-xs text-stone-600 prose-sm">
139
+ {children}
140
+ </div>
141
+ </div>
142
+ </div>
143
+ );
144
+ };
145
+
146
/**
 * A single chat message bubble.
 *
 * Assistant bubbles optionally render a collapsible "thinking" section above
 * the answer; its header label reflects the current streaming state.
 */
const MessageBubble: React.FC<{ message: Message; minHeight?: number }> = ({
  message,
  minHeight,
}) => {
  const { role, content, thinkTrace, isLoading, timestamp, thinkEndTime } =
    message;
  const isUser = role === "user";

  // Pick the think-section header text:
  //   still streaming, no </think> yet      -> "Thinking..."
  //   trace present, both timestamps known  -> "Thought for N seconds"
  //   trace present, no end time            -> "Thought interrupted"
  //   otherwise                             -> "Show Thoughts"
  let thinkingText = "";
  let opacityClass = "";
  const hasDuration =
    typeof thinkEndTime === "number" && typeof timestamp === "number";
  const durationSeconds = hasDuration
    ? Math.max(Math.round((thinkEndTime - timestamp) / 1000), 0)
    : null;

  if (isLoading && !thinkEndTime) {
    thinkingText = "Thinking...";
    opacityClass = "opacity-70 hover:opacity-100";
  } else if (thinkTrace) {
    thinkingText =
      durationSeconds !== null
        ? `Thought for ${durationSeconds} seconds`
        : "Thought interrupted";
  } else {
    thinkingText = "Show Thoughts";
  }

  // Convert model citation tags into numbered superscripts before rendering.
  const markdownContent = convertRefsToSuperscript(content);

  return (
    <div
      data-message-id={message.id}
      data-role={role}
      className={`message flex items-start animate-in fade-in slide-in-from-bottom-2 duration-300 py-2 ${isUser ? "justify-end" : "justify-start"}`}
      style={{
        // Reserved space for the streaming assistant bubble (see App).
        minHeight,
      }}
    >
      <div
        className={`max-w-xl lg:max-w-2xl px-4 py-3 rounded-2xl ${
          isUser
            ? "bg-amber-500 text-white rounded-br-none"
            : "bg-white text-stone-800 rounded-bl-none shadow-sm border border-stone-200"
        }`}
      >
        {(thinkTrace || isLoading) && (
          <Collapsible
            title={
              <div className="flex items-center space-x-1.5 text-sm">
                <Brain size={16} />
                <span
                  className={`${isLoading ? "animate-glisten" : ""} ${opacityClass}`}
                >
                  {thinkingText}
                </span>
              </div>
            }
          >
            <Streamdown
              parseIncompleteMarkdown={false}
              className="text-xs text-stone-500"
              isAnimating={Boolean(isLoading && thinkEndTime)}
            >
              {thinkTrace || (isLoading ? "..." : "")}
            </Streamdown>
          </Collapsible>
        )}
        <div className={`${thinkTrace || isLoading ? "mt-2" : ""}`}>
          <Streamdown
            parseIncompleteMarkdown={false}
            className="text-sm leading-relaxed text-stone-800"
            isAnimating={Boolean(isLoading && !thinkEndTime)}
          >
            {markdownContent || (isLoading ? "" : "")}
          </Streamdown>
        </div>
      </div>
    </div>
  );
};
230
+
231
/**
 * Manages the model and tokenizer loading state and refs.
 *
 * Loads the Baguettotron ONNX model onto WebGPU via Transformers.js. The
 * model and tokenizer live in refs (not state) so generation can read them
 * without triggering re-renders; `modelStatus`/`loadProgress` drive the UI.
 */
const useLLM = () => {
  const [modelStatus, setModelStatus] = useState<
    "idle" | "loading" | "ready" | "error"
  >("idle");
  const [loadProgress, setLoadProgress] = useState(0);
  const modelRef = useRef<LlamaForCausalLM | null>(null);
  const tokenizerRef = useRef<PreTrainedTokenizer | null>(null);

  const loadModel = useCallback(
    async (dtype: Dtype) => {
      // Already loaded: just re-publish the ready state.
      if (modelRef.current && tokenizerRef.current) {
        setModelStatus("ready");
        setLoadProgress(100);
        return;
      }
      // Ignore re-entrant calls while a load is in flight.
      if (modelStatus === "loading") return;

      setModelStatus("loading");
      setLoadProgress(0);

      // NOTE(review): `any` — the Transformers.js progress event is untyped
      // here; the runtime guards below narrow every field before use.
      const progress_callback = (progress: any) => {
        // Only track the large ".onnx_data" weights file; the other files are
        // tiny and would make the percentage jump around.
        if (
          progress.status === "progress" &&
          typeof progress.total === "number" &&
          typeof progress.loaded === "number" &&
          typeof progress.file === "string" &&
          progress.file.endsWith(".onnx_data")
        ) {
          const percentage = Math.round(
            (progress.loaded / progress.total) * 100,
          );
          setLoadProgress(percentage);
        }
      };

      try {
        const tokenizer = await AutoTokenizer.from_pretrained(MODEL_ID, {
          progress_callback,
        });
        const model = await AutoModelForCausalLM.from_pretrained(MODEL_ID, {
          dtype,
          device: "webgpu",
          progress_callback,
        });
        // Publish refs before flipping status so "ready" implies usable refs.
        tokenizerRef.current = tokenizer;
        modelRef.current = model;
        setLoadProgress(100);
        setModelStatus("ready");
      } catch (error) {
        console.error("Failed to load model", error);
        setModelStatus("error");
      }
    },
    [modelStatus],
  );

  return {
    modelStatus,
    loadProgress,
    modelRef,
    tokenizerRef,
    loadModel,
  };
};
298
+
299
/**
 * Root component: landing/loader screen, chat transcript, and composer.
 * Generation runs fully in-browser via the model/tokenizer refs from useLLM.
 */
const App: React.FC = () => {
  // Chat transcript, oldest first.
  const [messages, setMessages] = useState<Message[]>([]);
  const [currentInput, setCurrentInput] = useState("");
  // RAG context textarea contents; sources separated by blank lines.
  const [context, setContext] = useState("");
  const [showContext, setShowContext] = useState(false);
  const [isLoading, setIsLoading] = useState(false);
  // min-height applied to the last assistant bubble so streamed content grows
  // into reserved space instead of pushing the page around.
  const [lastMessageMinHeight, setLastMessageMinHeight] = useState<
    number | undefined
  >(undefined);
  const [selectedDtype, setSelectedDtype] = useState<Dtype>("fp16");
  const [dtypeMenuOpen, setDtypeMenuOpen] = useState(false);

  // Lets the stop button interrupt an in-flight generate() call.
  const stoppingCriteriaRef = useRef<InterruptableStoppingCriteria | null>(
    null,
  );
  const mainRef = useRef<HTMLDivElement>(null);

  const { modelStatus, loadProgress, modelRef, tokenizerRef, loadModel } =
    useLLM();

  useLayoutEffect(() => {
    if (!mainRef.current) return;
    const el = mainRef.current;

    // If the last message is from the assistant, calculate a min-height to prevent layout shifts.
    if (messages.at(-1)?.role === "assistant") {
      const userMessageElement = el.querySelector<HTMLDivElement>(
        `[data-message-id="${messages.at(-2)?.id}"]`,
      );
      if (userMessageElement) {
        const userMessageHeight =
          userMessageElement.getBoundingClientRect().height;
        const screenHeight = window.innerHeight;
        // 270 approximates header + composer chrome height — TODO confirm.
        const newMinHeight = Math.max(
          screenHeight - userMessageHeight - 270,
          0,
        );
        setLastMessageMinHeight(newMinHeight);
      }
    } else {
      setLastMessageMinHeight(undefined);
    }
    // Deliberately keyed on length only: runs when messages are added or
    // cleared, not on every streamed token update.
  }, [messages.length]);

  // Keep the newest message scrolled into view.
  useLayoutEffect(() => {
    if (mainRef.current) {
      const el = mainRef.current;
      // setTimeout(0) defers the scroll until after layout settles.
      setTimeout(() => {
        el.scrollTo({
          top: el.scrollHeight,
          behavior: "smooth",
        });
      }, 0);
    }
  }, [messages.length, lastMessageMinHeight]);

  // Append one streamed token to the last assistant message and re-split the
  // accumulated raw text into think-trace vs. visible content.
  const handleStreamUpdate = useCallback((newToken: string) => {
    setMessages((prev) => {
      // Ignore stray tokens when there is no assistant message to stream into.
      if (prev.length === 0 || prev.at(-1)!.role === "user") {
        return prev;
      }

      const lastMessage = { ...prev.at(-1)! };
      lastMessage.rawStream = (lastMessage.rawStream || "") + newToken;

      const raw = lastMessage.rawStream;
      const thinkEndTag = "</think>";
      const thinkEndIndex = raw.indexOf(thinkEndTag);

      let content;
      let thinkTrace = "";

      if (thinkEndIndex !== -1) {
        // Think block is complete.
        thinkTrace = raw.substring(0, thinkEndIndex);
        const contentAfter = raw.substring(thinkEndIndex + thinkEndTag.length);
        content = contentAfter.replace("<|im_end|><|end_of_text|>", "");
        // Record when reasoning finished, once.
        if (!lastMessage.thinkEndTime) {
          lastMessage.thinkEndTime = Date.now();
        }
      } else {
        // Think block has started but not finished.
        // NOTE(review): assumes the model always opens with a think block —
        // everything before </think> is treated as reasoning. Confirm against
        // the model's chat template.
        thinkTrace = raw;
        content = "";
      }

      lastMessage.content = content.trim();
      lastMessage.thinkTrace = thinkTrace.trim();

      return [...prev.slice(0, -1), lastMessage];
    });
  }, []);

  // Signals the active stopping criteria; no-op when idle.
  const handleStopGeneration = useCallback(() => {
    stoppingCriteriaRef.current?.interrupt();
  }, []);

  // Run generation for the given history, streaming tokens into the assistant
  // placeholder identified by `assistantMessageId`.
  const streamAssistantResponse = useCallback(
    async (
      historyForModel: { role: Role; content: string }[],
      assistantMessageId: number,
    ) => {
      const tokenizer = tokenizerRef.current;
      const model = modelRef.current;
      if (!tokenizer || !model) return;

      const inputs = tokenizer.apply_chat_template(historyForModel, {
        add_generation_prompt: true,
        return_dict: true,
      }) as any;

      // skip_special_tokens stays false so </think> and end-of-text markers
      // reach handleStreamUpdate, which parses them.
      const streamer = new TextStreamer(tokenizer, {
        skip_prompt: true,
        skip_special_tokens: false,
        callback_function: (token: string) => handleStreamUpdate(token),
      });

      const stoppingCriteria = new InterruptableStoppingCriteria();
      stoppingCriteriaRef.current = stoppingCriteria;

      try {
        await model.generate({
          ...inputs,
          max_new_tokens: 2048,
          streamer,
          stopping_criteria: stoppingCriteria,

          repetition_penalty: 1.2,
        });
      } catch (error) {
        console.error(error);
      } finally {
        // Finalize the message whether generation completed, was interrupted,
        // or threw: drop the streaming-only fields.
        stoppingCriteriaRef.current = null;
        setIsLoading(false);
        setMessages((prev) =>
          prev.map((msg) => {
            if (msg.id === assistantMessageId) {
              const { rawStream, isLoading: _, ...rest } = msg;
              return rest;
            }
            return msg;
          }),
        );
      }
    },
    [handleStreamUpdate, modelRef, tokenizerRef],
  );

  // Submit a prompt, either from the form or programmatically via the example
  // buttons (which may also supply their own RAG sources).
  const handleSubmit = async (
    e?: React.FormEvent,
    prompt?: string,
    sources?: string,
  ) => {
    if (e) e.preventDefault();
    if (isLoading || modelStatus !== "ready") return;

    const input = prompt || currentInput;
    if (!input.trim()) return;

    const trimmedContext = (sources || context).trim();
    const {
      payload: sourcesPayload,
      count: sourceCount,
      segments: sourceSegments,
    } = buildSourcesPayload(trimmedContext);

    // The model sees the question plus the tagged <source_N> payload.
    const fullPrompt = `${input}${sourcesPayload}`;

    // The transcript shows an escaped version so the user's text cannot
    // inject HTML into the rendered markdown.
    const sanitizedInput = sanitizeInput(input);
    let userMessageContent = sanitizedInput;

    // Append a truncated preview of each attached source to the user bubble.
    if (sourceCount > 0) {
      const sourcesList = sourceSegments
        .map(
          (seg, i) =>
            `${i + 1}. ${seg.substring(0, 75)}${seg.length > 75 ? "..." : ""}`,
        )
        .join("\n");
      userMessageContent += `\n\n[Source${sourceCount > 1 ? "s" : ""}]:\n${sourcesList}`;
    }

    const userMessage: Message = {
      id: messages.length,
      role: "user",
      content: userMessageContent,
    };

    const assistantPlaceholder: Message = {
      id: messages.length + 1,
      role: "assistant",
      content: "",
      thinkTrace: "",
      rawStream: "",
      isLoading: true,
      timestamp: Date.now(),
    };

    setMessages((prev) => [...prev, userMessage, assistantPlaceholder]);
    setCurrentInput("");
    setContext("");
    setShowContext(false);
    setIsLoading(true);
    setLastMessageMinHeight(undefined);

    // History sent to the model uses the raw prompt (with sources), not the
    // sanitized display version stored in `messages`.
    const historyForModel = [
      ...messages.map(({ role, content }) => ({ role, content })),
      { role: "user" as Role, content: fullPrompt },
    ];

    await streamAssistantResponse(historyForModel, assistantPlaceholder.id);
  };

  // Reset all chat state and interrupt any in-flight generation.
  const handleNewChat = () => {
    handleStopGeneration();
    setMessages([]);
    setCurrentInput("");
    setContext("");
    setShowContext(false);
    setIsLoading(false);
    setLastMessageMinHeight(undefined);
  };

  return (
    <div className="flex flex-col h-screen bg-amber-50 font-sans text-stone-800">
      {modelStatus === "ready" && (
        <header className="flex-shrink-0 sticky top-0 z-10 flex items-center justify-between p-4 bg-white/90 backdrop-blur-md shadow-sm border-b border-amber-200 h-[100px]">
          <button
            onClick={handleNewChat}
            className="p-2 rounded-full text-stone-500 hover:text-amber-600 hover:bg-amber-50 transition-colors"
            title="New Chat"
          >
            <Plus size={20} />
          </button>
          <div className="flex-1 text-center">
            <h1 className="text-2xl md:text-3xl font-serif font-bold text-amber-800">
              🥖 Baguettotron WebGPU
            </h1>
            <p className="text-sm text-stone-600">
              A small but powerful reasoning model
            </p>
          </div>
        </header>
      )}

      <main ref={mainRef} className="flex-grow overflow-y-auto">
        <div className="mx-auto w-full max-w-6xl p-4 md:p-6 space-y-2 h-full">
          {modelStatus !== "ready" ? (
            <div className="flex h-full flex-col items-center justify-center gap-6 text-center text-stone-600">
              <span className="text-8xl animate-wobble">🥖</span>
              <div>
                <h1 className="text-5xl font-bold text-amber-800">
                  Baguettotron WebGPU
                </h1>
                <p className="mt-4 max-w-xl text-md">
                  You are about to load Baguettotron, a 300M parameter reasoning
                  model optimized for in-browser inference. Everything runs
                  entirely in your browser with 🤗 Transformers.js and ONNX
                  Runtime Web, meaning no data is sent to a server. Once loaded,
                  it can even be used offline.
                </p>
              </div>
              {/* Split button: load with the current dtype, or open the dtype menu. */}
              <div className="relative inline-flex rounded-full shadow-sm">
                <button
                  onClick={() => loadModel(selectedDtype)}
                  disabled={modelStatus === "loading"}
                  className="rounded-l-full bg-amber-600 pl-6 pr-5 py-3 text-white font-medium transition hover:bg-amber-700 disabled:opacity-50 disabled:cursor-not-allowed"
                >
                  {modelStatus === "loading"
                    ? `Loading ${loadProgress}%`
                    : `Load model (${selectedDtype})`}
                </button>
                <button
                  onClick={() => setDtypeMenuOpen(!dtypeMenuOpen)}
                  disabled={modelStatus === "loading"}
                  className="rounded-r-full bg-amber-600 px-3 py-3 text-white transition hover:bg-amber-700 disabled:opacity-50 border-l border-amber-500"
                >
                  <ChevronDown
                    size={20}
                    className={`transform transition-transform ${dtypeMenuOpen ? "rotate-180" : ""}`}
                  />
                </button>
                {dtypeMenuOpen && (
                  <div className="absolute top-full mt-2 w-full bg-white rounded-md shadow-lg z-10 border border-stone-200">
                    {Object.entries(DTYPES).map(([dtype, label]) => (
                      <button
                        key={dtype}
                        onClick={() => {
                          setSelectedDtype(dtype as Dtype);
                          setDtypeMenuOpen(false);
                        }}
                        className="w-full text-left px-4 py-2 text-sm text-stone-700 hover:bg-amber-50"
                      >
                        {label}
                      </button>
                    ))}
                  </div>
                )}
              </div>
              {modelStatus === "error" && (
                <p className="text-sm text-red-600">
                  Model load failed. Check console for details and retry.
                </p>
              )}
            </div>
          ) : (
            <>
              {messages.length === 0 && (
                <div className="flex flex-col items-center justify-center h-full text-center text-stone-500">
                  <div className="p-8 rounded-2xl flex flex-col items-center">
                    <h2 className="text-3xl font-semibold mt-4 text-stone-700">
                      Welcome to Baguettotron
                    </h2>
                    <h3 className="max-w-xs mt-1 text-lg">
                      Ask me a question, or try one of the examples below!
                    </h3>
                  </div>
                  <div className="mt-2 flex flex-wrap justify-center gap-4">
                    <button
                      onClick={() =>
                        handleSubmit(
                          undefined,
                          "What is the capital of France? Just provide the answer.",
                        )
                      }
                      className="bg-amber-100 hover:bg-amber-200 text-amber-800 px-4 py-2 rounded-lg shadow-sm border border-amber-200 transition-colors"
                    >
                      Encyclopedic knowledge
                    </button>
                    {/* Creative writing is only offered on the full-precision
                        dtypes — presumably quality degrades under q4 — confirm. */}
                    {["fp32", "fp16"].includes(selectedDtype) && (
                      <button
                        onClick={() =>
                          handleSubmit(
                            undefined,
                            "Write me a short poem about machine learning.",
                          )
                        }
                        className="bg-amber-100 hover:bg-amber-200 text-amber-800 px-4 py-2 rounded-lg shadow-sm border border-amber-200 transition-colors"
                      >
                        Creative writing
                      </button>
                    )}
                    <button
                      onClick={() =>
                        handleSubmit(
                          undefined,
                          "Which is wider: Australia or the Moon?",
                          "Australia is approximately 4,000 km in width from east to west, according to Geoscience Australia.\n\nThe diameter of the Moon is about 3,476 km, according to Britannica.",
                        )
                      }
                      className="bg-amber-100 hover:bg-amber-200 text-amber-800 px-4 py-2 rounded-lg shadow-sm border border-amber-200 transition-colors"
                    >
                      RAG with grounding
                    </button>
                  </div>
                </div>
              )}
              {messages.map((msg, index) => {
                const isLastAssistantMessage =
                  index === messages.length - 1 && msg.role === "assistant";
                const minHeight = isLastAssistantMessage
                  ? lastMessageMinHeight
                  : undefined;

                return (
                  <MessageBubble
                    key={msg.id}
                    message={msg}
                    minHeight={minHeight}
                  />
                );
              })}
            </>
          )}
        </div>
      </main>

      {modelStatus === "ready" && (
        <footer className="flex-shrink-0 sticky bottom-0 z-10 p-4 bg-white/90 backdrop-blur-md border-t border-amber-100">
          <form onSubmit={handleSubmit} className="max-w-3xl mx-auto">
            {/* Collapsible RAG-context panel above the input row. */}
            <div
              style={{
                maxHeight: showContext ? "120px" : "0px",
                transition: "max-height 0.3s ease-in-out",
                opacity: showContext ? 1 : 0,
              }}
              className="overflow-hidden relative"
            >
              <textarea
                value={context}
                onChange={(e) => setContext(e.target.value)}
                disabled={isLoading}
                placeholder="Add RAG context here. Separate multiple sources with two new lines."
                className="w-full h-28 p-2 mb-2 rounded-lg border border-stone-300 focus:ring-amber-500 focus:border-amber-500 text-sm resize-none"
              />
              <button
                type="button"
                onClick={() => setShowContext(false)}
                className="absolute top-2 right-2 p-1 text-stone-400 hover:text-stone-600 bg-white/50 rounded-full"
              >
                <X size={16} />
              </button>
            </div>

            <div className="flex items-center space-x-2">
              <button
                type="button"
                onClick={() => setShowContext(!showContext)}
                title="Add Context for RAG"
                className={`flex-shrink-0 p-2 rounded-full transition-colors ${
                  showContext
                    ? "bg-amber-100 text-amber-700"
                    : "text-stone-500 hover:text-amber-600 hover:bg-amber-50"
                }`}
              >
                <Paperclip size={20} />
              </button>
              <input
                type="text"
                value={currentInput}
                onChange={(e) => setCurrentInput(e.target.value)}
                placeholder="Send a message..."
                className="flex-grow px-4 py-2 rounded-full border border-stone-300 focus:ring-2 focus:ring-amber-500 focus:border-transparent outline-none transition-shadow"
                disabled={isLoading || modelStatus !== "ready"}
              />
              {/* While generating, the send button becomes a stop button. */}
              {isLoading ? (
                <button
                  type="button"
                  onClick={handleStopGeneration}
                  className="group flex h-10 w-10 flex-shrink-0 items-center justify-center rounded-full border border-stone-300 bg-white text-stone-600 hover:border-red-500"
                >
                  <span className="h-3.5 w-3.5 rounded-sm bg-stone-600 transition-colors group-hover:bg-red-500" />
                </button>
              ) : (
                <button
                  type="submit"
                  disabled={
                    isLoading || !currentInput.trim() || modelStatus !== "ready"
                  }
                  className="flex h-10 w-10 flex-shrink-0 items-center justify-center rounded-full bg-amber-600 text-white transition-all transform
                  hover:bg-amber-700 hover:scale-105 active:scale-95
                  disabled:bg-stone-300 disabled:scale-100 disabled:cursor-not-allowed"
                >
                  <Send size={20} />
                </button>
              )}
            </div>
            <p className="text-center text-xs text-stone-400 mt-2">
              ⚡ Powered by{" "}
              <a
                href="https://github.com/huggingface/transformers.js"
                target="_blank"
                rel="noopener noreferrer"
              >
                Transformers.js
              </a>{" "}
              — Runs locally in your browser on WebGPU.
            </p>
          </form>
        </footer>
      )}
    </div>
  );
};

export default App;
src/index.css ADDED
@@ -0,0 +1,54 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
@import "tailwindcss";

/* Let Tailwind scan streamdown's compiled output so the utility classes it
   uses are generated in our build. */
@source "../node_modules/streamdown/dist/index.js";

@keyframes glisten {
  0% {
    background-position: 200% 0;
  }
  100% {
    background-position: -200% 0;
  }
}

/* Shimmering text effect used on the "Thinking..." label. */
.animate-glisten {
  /* Using amber-700 and amber-300 for the gradient */
  background-image: linear-gradient(
    90deg,
    rgba(168, 85, 6, 1) 0%,
    rgba(253, 186, 116, 1) 40%,
    rgba(253, 186, 116, 1) 60%,
    rgba(168, 85, 6, 1) 100%
  );
  color: transparent;
  background-size: 200% auto;
  background-clip: text;
  -webkit-background-clip: text;
  animation: glisten 3s linear infinite;
}

/* Overrides */
.message h3[data-streamdown="heading-3"] {
  @apply mt-3;
}

.message .collapsible p {
  @apply my-1;
}

/* Preserve the model's newlines inside rendered paragraphs. */
.message p {
  @apply whitespace-pre-wrap;
}

/* Gentle rocking animation for the baguette emoji on the landing screen. */
@keyframes wobble {
  0%,
  100% {
    transform: rotate(-3deg);
  }
  50% {
    transform: rotate(3deg);
  }
}
.animate-wobble {
  animation: wobble 2s ease-in-out infinite;
}
src/main.tsx ADDED
@@ -0,0 +1,10 @@
 
 
 
 
 
 
 
 
 
 
 
1
// Application entry point: mounts <App /> into #root under StrictMode.
import { StrictMode } from "react";
import { createRoot } from "react-dom/client";
import "./index.css";
import App from "./App.tsx";

// Non-null assertion is safe: index.html always ships a #root element.
createRoot(document.getElementById("root")!).render(
  <StrictMode>
    <App />
  </StrictMode>,
);
tsconfig.app.json ADDED
@@ -0,0 +1,28 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
{
  /* TypeScript project for the browser app (src/). Type-check only;
     Vite performs the actual transpilation (noEmit below). */
  "compilerOptions": {
    "tsBuildInfoFile": "./node_modules/.tmp/tsconfig.app.tsbuildinfo",
    "target": "ES2022",
    "useDefineForClassFields": true,
    "lib": ["ES2022", "DOM", "DOM.Iterable"],
    "module": "ESNext",
    "types": ["vite/client"],
    "skipLibCheck": true,

    /* Bundler mode */
    "moduleResolution": "bundler",
    "allowImportingTsExtensions": true,
    "verbatimModuleSyntax": true,
    "moduleDetection": "force",
    "noEmit": true,
    "jsx": "react-jsx",

    /* Linting */
    "strict": true,
    "noUnusedLocals": true,
    "noUnusedParameters": true,
    "erasableSyntaxOnly": true,
    "noFallthroughCasesInSwitch": true,
    "noUncheckedSideEffectImports": true
  },
  "include": ["src"]
}
tsconfig.json ADDED
@@ -0,0 +1,7 @@
 
 
 
 
 
 
 
 
1
{
  /* Solution-style root config: contains no files of its own, only
     project references for the app and node (tooling) configs. */
  "files": [],
  "references": [
    { "path": "./tsconfig.app.json" },
    { "path": "./tsconfig.node.json" }
  ]
}
tsconfig.node.json ADDED
@@ -0,0 +1,26 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
{
  /* TypeScript project for Node-side tooling (vite.config.ts only). */
  "compilerOptions": {
    "tsBuildInfoFile": "./node_modules/.tmp/tsconfig.node.tsbuildinfo",
    "target": "ES2023",
    "lib": ["ES2023"],
    "module": "ESNext",
    "types": ["node"],
    "skipLibCheck": true,

    /* Bundler mode */
    "moduleResolution": "bundler",
    "allowImportingTsExtensions": true,
    "verbatimModuleSyntax": true,
    "moduleDetection": "force",
    "noEmit": true,

    /* Linting */
    "strict": true,
    "noUnusedLocals": true,
    "noUnusedParameters": true,
    "erasableSyntaxOnly": true,
    "noFallthroughCasesInSwitch": true,
    "noUncheckedSideEffectImports": true
  },
  "include": ["vite.config.ts"]
}
vite.config.ts ADDED
@@ -0,0 +1,8 @@
 
 
 
 
 
 
 
 
 
1
import { defineConfig } from "vite";
import react from "@vitejs/plugin-react";
import tailwindcss from "@tailwindcss/vite";

// https://vite.dev/config/
// React (Fast Refresh) plus Tailwind CSS v4's first-party Vite plugin.
export default defineConfig({
  plugins: [react(), tailwindcss()],
});
+ });