// Source file from the Hugging Face Space by Xenova (HF Staff),
// commit c369b38 ("Upload 10 files", verified).
import { useState, useRef, useCallback, useLayoutEffect } from "react";
import { Send, Paperclip, Brain, ChevronDown, X, Plus } from "lucide-react";
import {
AutoModelForCausalLM,
AutoTokenizer,
InterruptableStoppingCriteria,
TextStreamer,
} from "@huggingface/transformers";
import { Streamdown } from "streamdown";
import type {
PreTrainedTokenizer,
LlamaForCausalLM,
} from "@huggingface/transformers";
import type React from "react";
// Hugging Face repository holding the ONNX export of the Baguettotron model.
const MODEL_ID = "onnx-community/Baguettotron-ONNX";
// Quantization variants offered to the user, mapped to display labels with
// approximate download sizes.
const DTYPES = {
  fp32: "FP32 (~1.28 GB)",
  fp16: "FP16 (~642 MB)",
  q4: "Q4 (~329 MB)",
  q4f16: "Q4F16 (~235 MB)",
} as const;
type Dtype = keyof typeof DTYPES;
// Two or more consecutive newlines separate individual RAG sources.
const SOURCE_SEPARATOR_REGEX = /\n{2,}/g;
// Chat roles used in the message history sent to the model.
type Role = "user" | "assistant";
/**
 * Format the sources into tagged segments for the model input.
 * Splits the raw context on blank lines, trims every piece, and wraps the
 * surviving pieces in numbered <source_N>...</source_N> tags.
 * Returns the payload string (prefixed with a blank line), the number of
 * sources, and the trimmed source segments themselves.
 */
const buildSourcesPayload = (rawContext: string) => {
  const cleaned = rawContext.trim();
  if (!cleaned) {
    return { payload: "", count: 0, segments: [] };
  }
  const segments = cleaned
    .split(SOURCE_SEPARATOR_REGEX)
    .map((piece) => piece.trim())
    .filter((piece) => piece.length > 0);
  if (segments.length === 0) {
    return { payload: "", count: 0, segments: [] };
  }
  const tagged = segments.map(
    (piece, i) => `<source_${i + 1}>${piece}</source_${i + 1}>`,
  );
  return {
    payload: `\n\n${tagged.join("\n")}`,
    count: segments.length,
    segments,
  };
};
/**
 * Converts <ref>...</ref> tags in the content to superscript references.
 * Each unique source name receives one sequential number, so repeated
 * citations of the same source share a label. The cited text is placed in
 * the superscript's title attribute (with double quotes HTML-escaped).
 * Any unterminated <ref ...> tail — a streaming artifact — is stripped.
 */
const convertRefsToSuperscript = (content: string) => {
  const labels = new Map<string, number>();
  let nextLabel = 1;
  const withSups = content.replace(
    /<ref name="([^"]+)">([\s\S]*?)<\/ref>/g,
    (_match, name = "", body) => {
      let label = labels.get(name);
      if (label === undefined) {
        label = nextLabel++;
        labels.set(name, label);
      }
      const title = body.replace(/"/g, "&quot;");
      return `<sup className="cursor-pointer" title="${title}">[${label}]</sup>`;
    },
  );
  // Remove any trailing incomplete <ref> tag.
  return withSups.replace(/<ref[^>]*>[\s\S]*$/, "");
};
/**
 * Sanitizes user input by replacing angle brackets with HTML entities so
 * the text cannot be interpreted as markup when rendered.
 */
const sanitizeInput = (text: string) =>
  text.split("<").join("&lt;").split(">").join("&gt;");
/**
 * Represents a single chat message in the history.
 */
interface Message {
  // Index-derived identifier; used as the React key and for DOM lookup
  // via the data-message-id attribute.
  id: number;
  role: Role;
  // Text shown in the bubble (for assistant messages: the part after </think>).
  content: string;
  // The model's reasoning trace (text before </think>), if any.
  thinkTrace?: string;
  // Raw accumulated token stream; removed once generation finishes.
  rawStream?: string;
  // True while this assistant message is still being generated.
  isLoading?: boolean;
  // Epoch ms recorded when the assistant placeholder is created.
  timestamp?: number;
  // Epoch ms recorded when the closing </think> tag is first seen.
  thinkEndTime?: number;
}
/**
 * A simple, self-contained collapsible component.
 * Toggles its children open/closed with a max-height transition; the open
 * height is read from the content element's scrollHeight.
 */
const Collapsible: React.FC<{
  title: React.ReactNode;
  children: React.ReactNode;
}> = ({ title, children }) => {
  const [isOpen, setIsOpen] = useState(false);
  const contentRef = useRef<HTMLDivElement>(null);
  return (
    <div className="collapsible mt-2">
      <button
        onClick={() => setIsOpen(!isOpen)}
        className="flex items-center space-x-1 text-xs font-medium text-amber-700 hover:text-amber-900 transition-colors"
      >
        {title}
        <ChevronDown
          size={14}
          className={`transform transition-transform ${isOpen ? "rotate-180" : "rotate-0"}`}
        />
      </button>
      {/* Animates between 0 and the measured content height. */}
      <div
        ref={contentRef}
        style={{
          maxHeight: isOpen ? `${contentRef.current?.scrollHeight}px` : "0px",
        }}
        className="overflow-hidden transition-all duration-300 ease-in-out"
      >
        <div className="mt-2 p-2 bg-amber-50 border border-dashed border-amber-200 rounded-md text-xs text-stone-600 prose-sm">
          {children}
        </div>
      </div>
    </div>
  );
};
/**
 * A single chat message bubble.
 * User messages render right-aligned; assistant messages additionally get
 * a collapsible "thinking" section above the rendered markdown content.
 */
const MessageBubble: React.FC<{ message: Message; minHeight?: number }> = ({
  message,
  minHeight,
}) => {
  const { role, content, thinkTrace, isLoading, timestamp, thinkEndTime } =
    message;
  const isUser = role === "user";
  let thinkingText = "";
  let opacityClass = "";
  const hasDuration =
    typeof thinkEndTime === "number" && typeof timestamp === "number";
  // Whole seconds between placeholder creation and the </think> tag,
  // clamped to zero; null when either timestamp is missing.
  const durationSeconds = hasDuration
    ? Math.max(Math.round((thinkEndTime - timestamp) / 1000), 0)
    : null;
  // Collapsible header label: live "Thinking..." while the think block
  // streams, a duration once it closed, or a fallback otherwise.
  if (isLoading && !thinkEndTime) {
    thinkingText = "Thinking...";
    opacityClass = "opacity-70 hover:opacity-100";
  } else if (thinkTrace) {
    thinkingText =
      durationSeconds !== null
        ? `Thought for ${durationSeconds} seconds`
        : "Thought interrupted";
  } else {
    thinkingText = "Show Thoughts";
  }
  // Replace <ref> citations with superscript markers before rendering.
  const markdownContent = convertRefsToSuperscript(content);
  return (
    <div
      data-message-id={message.id}
      data-role={role}
      className={`message flex items-start animate-in fade-in slide-in-from-bottom-2 duration-300 py-2 ${isUser ? "justify-end" : "justify-start"}`}
      style={{
        minHeight,
      }}
    >
      <div
        className={`max-w-xl lg:max-w-2xl px-4 py-3 rounded-2xl ${
          isUser
            ? "bg-amber-500 text-white rounded-br-none"
            : "bg-white text-stone-800 rounded-bl-none shadow-sm border border-stone-200"
        }`}
      >
        {(thinkTrace || isLoading) && (
          <Collapsible
            title={
              <div className="flex items-center space-x-1.5 text-sm">
                <Brain size={16} />
                <span
                  className={`${isLoading ? "animate-glisten" : ""} ${opacityClass}`}
                >
                  {thinkingText}
                </span>
              </div>
            }
          >
            <Streamdown
              parseIncompleteMarkdown={false}
              className="text-xs text-stone-500"
              isAnimating={Boolean(isLoading && thinkEndTime)}
            >
              {thinkTrace || (isLoading ? "..." : "")}
            </Streamdown>
          </Collapsible>
        )}
        <div className={`${thinkTrace || isLoading ? "mt-2" : ""}`}>
          <Streamdown
            parseIncompleteMarkdown={false}
            className="text-sm leading-relaxed text-stone-800"
            isAnimating={Boolean(isLoading && !thinkEndTime)}
          >
            {markdownContent || (isLoading ? "" : "")}
          </Streamdown>
        </div>
      </div>
    </div>
  );
};
/**
 * Manages the model and tokenizer loading state and refs.
 * Exposes the load status, a 0-100 download-progress value, refs holding
 * the loaded tokenizer/model, and a loadModel function to start loading.
 */
const useLLM = () => {
  const [modelStatus, setModelStatus] = useState<
    "idle" | "loading" | "ready" | "error"
  >("idle");
  const [loadProgress, setLoadProgress] = useState(0);
  const modelRef = useRef<LlamaForCausalLM | null>(null);
  const tokenizerRef = useRef<PreTrainedTokenizer | null>(null);
  const loadModel = useCallback(
    async (dtype: Dtype) => {
      // Reuse an already-loaded model/tokenizer pair.
      if (modelRef.current && tokenizerRef.current) {
        setModelStatus("ready");
        setLoadProgress(100);
        return;
      }
      // Ignore re-entrant calls while a load is already in flight.
      if (modelStatus === "loading") return;
      setModelStatus("loading");
      setLoadProgress(0);
      // Only progress events for the large ".onnx_data" external-data file
      // drive the percentage, so small config/tokenizer downloads do not
      // skew the bar.
      const progress_callback = (progress: any) => {
        if (
          progress.status === "progress" &&
          typeof progress.total === "number" &&
          typeof progress.loaded === "number" &&
          typeof progress.file === "string" &&
          progress.file.endsWith(".onnx_data")
        ) {
          const percentage = Math.round(
            (progress.loaded / progress.total) * 100,
          );
          setLoadProgress(percentage);
        }
      };
      try {
        const tokenizer = await AutoTokenizer.from_pretrained(MODEL_ID, {
          progress_callback,
        });
        const model = await AutoModelForCausalLM.from_pretrained(MODEL_ID, {
          dtype,
          device: "webgpu",
          progress_callback,
        });
        tokenizerRef.current = tokenizer;
        modelRef.current = model;
        setLoadProgress(100);
        setModelStatus("ready");
      } catch (error) {
        console.error("Failed to load model", error);
        setModelStatus("error");
      }
    },
    [modelStatus],
  );
  return {
    modelStatus,
    loadProgress,
    modelRef,
    tokenizerRef,
    loadModel,
  };
};
/**
 * Top-level chat application. Wires together model loading (via useLLM),
 * the streaming chat loop, the RAG context input, and the page chrome.
 */
const App: React.FC = () => {
  const [messages, setMessages] = useState<Message[]>([]);
  const [currentInput, setCurrentInput] = useState("");
  // Raw RAG context textarea value; sources are separated by blank lines.
  const [context, setContext] = useState("");
  const [showContext, setShowContext] = useState(false);
  // True while the assistant is generating a response.
  const [isLoading, setIsLoading] = useState(false);
  // Min-height applied to the last assistant bubble so the viewport does
  // not jump around while tokens stream in (computed in the effect below).
  const [lastMessageMinHeight, setLastMessageMinHeight] = useState<
    number | undefined
  >(undefined);
  const [selectedDtype, setSelectedDtype] = useState<Dtype>("fp16");
  const [dtypeMenuOpen, setDtypeMenuOpen] = useState(false);
  // Handle used by the stop button to interrupt an in-flight generate() call.
  const stoppingCriteriaRef = useRef<InterruptableStoppingCriteria | null>(
    null,
  );
  const mainRef = useRef<HTMLDivElement>(null);
  const { modelStatus, loadProgress, modelRef, tokenizerRef, loadModel } =
    useLLM();
  // Reserve vertical space for the incoming assistant reply so the chat
  // area does not visibly reflow as the response grows.
  useLayoutEffect(() => {
    if (!mainRef.current) return;
    const el = mainRef.current;
    // If the last message is from the assistant, calculate a min-height to prevent layout shifts.
    if (messages.at(-1)?.role === "assistant") {
      const userMessageElement = el.querySelector<HTMLDivElement>(
        `[data-message-id="${messages.at(-2)?.id}"]`,
      );
      if (userMessageElement) {
        const userMessageHeight =
          userMessageElement.getBoundingClientRect().height;
        const screenHeight = window.innerHeight;
        // 270px approximates the surrounding header/footer chrome
        // — NOTE(review): confirm against the actual layout sizes.
        const newMinHeight = Math.max(
          screenHeight - userMessageHeight - 270,
          0,
        );
        setLastMessageMinHeight(newMinHeight);
      }
    } else {
      setLastMessageMinHeight(undefined);
    }
  }, [messages.length]);
  // Scroll to the bottom whenever a message is added or the reserved
  // min-height changes; setTimeout(0) defers until after layout settles.
  useLayoutEffect(() => {
    if (mainRef.current) {
      const el = mainRef.current;
      setTimeout(() => {
        el.scrollTo({
          top: el.scrollHeight,
          behavior: "smooth",
        });
      }, 0);
    }
  }, [messages.length, lastMessageMinHeight]);
  // Appends a streamed token to the last assistant message and re-derives
  // the thinkTrace / visible-content split from the accumulated raw stream.
  const handleStreamUpdate = useCallback((newToken: string) => {
    setMessages((prev) => {
      // No assistant placeholder to stream into — ignore the token.
      if (prev.length === 0 || prev.at(-1)!.role === "user") {
        return prev;
      }
      const lastMessage = { ...prev.at(-1)! };
      lastMessage.rawStream = (lastMessage.rawStream || "") + newToken;
      const raw = lastMessage.rawStream;
      const thinkEndTag = "</think>";
      const thinkEndIndex = raw.indexOf(thinkEndTag);
      let content;
      let thinkTrace = "";
      if (thinkEndIndex !== -1) {
        // Think block is complete.
        thinkTrace = raw.substring(0, thinkEndIndex);
        const contentAfter = raw.substring(thinkEndIndex + thinkEndTag.length);
        content = contentAfter.replace("<|im_end|><|end_of_text|>", "");
        // Record when the think block first closed, for the duration label.
        if (!lastMessage.thinkEndTime) {
          lastMessage.thinkEndTime = Date.now();
        }
      } else {
        // Think block has started but not finished.
        thinkTrace = raw;
        content = "";
      }
      lastMessage.content = content.trim();
      lastMessage.thinkTrace = thinkTrace.trim();
      return [...prev.slice(0, -1), lastMessage];
    });
  }, []);
  // Interrupts generation; cleanup happens in streamAssistantResponse's
  // finally block.
  const handleStopGeneration = useCallback(() => {
    stoppingCriteriaRef.current?.interrupt();
  }, []);
  // Runs the model over the chat history and streams tokens into the
  // assistant placeholder identified by assistantMessageId.
  const streamAssistantResponse = useCallback(
    async (
      historyForModel: { role: Role; content: string }[],
      assistantMessageId: number,
    ) => {
      const tokenizer = tokenizerRef.current;
      const model = modelRef.current;
      if (!tokenizer || !model) return;
      const inputs = tokenizer.apply_chat_template(historyForModel, {
        add_generation_prompt: true,
        return_dict: true,
      }) as any;
      // skip_special_tokens stays false so </think> and end-of-text markers
      // reach handleStreamUpdate, which parses them out of the stream.
      const streamer = new TextStreamer(tokenizer, {
        skip_prompt: true,
        skip_special_tokens: false,
        callback_function: (token: string) => handleStreamUpdate(token),
      });
      const stoppingCriteria = new InterruptableStoppingCriteria();
      stoppingCriteriaRef.current = stoppingCriteria;
      try {
        await model.generate({
          ...inputs,
          max_new_tokens: 2048,
          streamer,
          stopping_criteria: stoppingCriteria,
          repetition_penalty: 1.2,
        });
      } catch (error) {
        console.error(error);
      } finally {
        // Strip the transient streaming fields from the finished message.
        stoppingCriteriaRef.current = null;
        setIsLoading(false);
        setMessages((prev) =>
          prev.map((msg) => {
            if (msg.id === assistantMessageId) {
              const { rawStream, isLoading: _, ...rest } = msg;
              return rest;
            }
            return msg;
          }),
        );
      }
    },
    [handleStreamUpdate, modelRef, tokenizerRef],
  );
  // Submits a user message. The model receives the raw input with tagged
  // sources appended; the transcript shows the sanitized input plus a
  // truncated preview of each source.
  const handleSubmit = async (
    e?: React.FormEvent,
    prompt?: string,
    sources?: string,
  ) => {
    if (e) e.preventDefault();
    if (isLoading || modelStatus !== "ready") return;
    const input = prompt || currentInput;
    if (!input.trim()) return;
    const trimmedContext = (sources || context).trim();
    const {
      payload: sourcesPayload,
      count: sourceCount,
      segments: sourceSegments,
    } = buildSourcesPayload(trimmedContext);
    const fullPrompt = `${input}${sourcesPayload}`;
    const sanitizedInput = sanitizeInput(input);
    let userMessageContent = sanitizedInput;
    if (sourceCount > 0) {
      // Show a 75-character preview of each source under the user message.
      const sourcesList = sourceSegments
        .map(
          (seg, i) =>
            `${i + 1}. ${seg.substring(0, 75)}${seg.length > 75 ? "..." : ""}`,
        )
        .join("\n");
      userMessageContent += `\n\n[Source${sourceCount > 1 ? "s" : ""}]:\n${sourcesList}`;
    }
    // Ids derive from the current history length: user gets N, assistant N+1.
    const userMessage: Message = {
      id: messages.length,
      role: "user",
      content: userMessageContent,
    };
    const assistantPlaceholder: Message = {
      id: messages.length + 1,
      role: "assistant",
      content: "",
      thinkTrace: "",
      rawStream: "",
      isLoading: true,
      timestamp: Date.now(),
    };
    setMessages((prev) => [...prev, userMessage, assistantPlaceholder]);
    setCurrentInput("");
    setContext("");
    setShowContext(false);
    setIsLoading(true);
    setLastMessageMinHeight(undefined);
    // History sent to the model: prior turns plus the raw (unsanitized)
    // prompt with the sources payload.
    const historyForModel = [
      ...messages.map(({ role, content }) => ({ role, content })),
      { role: "user" as Role, content: fullPrompt },
    ];
    await streamAssistantResponse(historyForModel, assistantPlaceholder.id);
  };
  // Clears the conversation and all transient state, stopping any
  // in-flight generation first.
  const handleNewChat = () => {
    handleStopGeneration();
    setMessages([]);
    setCurrentInput("");
    setContext("");
    setShowContext(false);
    setIsLoading(false);
    setLastMessageMinHeight(undefined);
  };
  return (
    <div className="flex flex-col h-screen bg-amber-50 font-sans text-stone-800">
      {/* Header with new-chat button; only shown once the model is ready. */}
      {modelStatus === "ready" && (
        <header className="flex-shrink-0 sticky top-0 z-10 flex items-center justify-between p-4 bg-white/90 backdrop-blur-md shadow-sm border-b border-amber-200 h-[100px]">
          <button
            onClick={handleNewChat}
            className="p-2 rounded-full text-stone-500 hover:text-amber-600 hover:bg-amber-50 transition-colors"
            title="New Chat"
          >
            <Plus size={20} />
          </button>
          <div className="flex-1 text-center">
            <h1 className="text-2xl md:text-3xl font-serif font-bold text-amber-800">
              🥖 Baguettotron WebGPU
            </h1>
            <p className="text-sm text-stone-600">
              A small but powerful reasoning model
            </p>
          </div>
        </header>
      )}
      <main ref={mainRef} className="flex-grow overflow-y-auto">
        <div className="mx-auto w-full max-w-6xl p-4 md:p-6 space-y-2 h-full">
          {modelStatus !== "ready" ? (
            <div className="flex h-full flex-col items-center justify-center gap-6 text-center text-stone-600">
              <span className="text-8xl animate-wobble">🥖</span>
              <div>
                <h1 className="text-5xl font-bold text-amber-800">
                  Baguettotron WebGPU
                </h1>
                <p className="mt-4 max-w-xl text-md">
                  You are about to load Baguettotron, a 300M parameter reasoning
                  model optimized for in-browser inference. Everything runs
                  entirely in your browser with 🤗 Transformers.js and ONNX
                  Runtime Web, meaning no data is sent to a server. Once loaded,
                  it can even be used offline.
                </p>
              </div>
              {/* Split button: load with current dtype, or open the dtype menu. */}
              <div className="relative inline-flex rounded-full shadow-sm">
                <button
                  onClick={() => loadModel(selectedDtype)}
                  disabled={modelStatus === "loading"}
                  className="rounded-l-full bg-amber-600 pl-6 pr-5 py-3 text-white font-medium transition hover:bg-amber-700 disabled:opacity-50 disabled:cursor-not-allowed"
                >
                  {modelStatus === "loading"
                    ? `Loading ${loadProgress}%`
                    : `Load model (${selectedDtype})`}
                </button>
                <button
                  onClick={() => setDtypeMenuOpen(!dtypeMenuOpen)}
                  disabled={modelStatus === "loading"}
                  className="rounded-r-full bg-amber-600 px-3 py-3 text-white transition hover:bg-amber-700 disabled:opacity-50 border-l border-amber-500"
                >
                  <ChevronDown
                    size={20}
                    className={`transform transition-transform ${dtypeMenuOpen ? "rotate-180" : ""}`}
                  />
                </button>
                {dtypeMenuOpen && (
                  <div className="absolute top-full mt-2 w-full bg-white rounded-md shadow-lg z-10 border border-stone-200">
                    {Object.entries(DTYPES).map(([dtype, label]) => (
                      <button
                        key={dtype}
                        onClick={() => {
                          setSelectedDtype(dtype as Dtype);
                          setDtypeMenuOpen(false);
                        }}
                        className="w-full text-left px-4 py-2 text-sm text-stone-700 hover:bg-amber-50"
                      >
                        {label}
                      </button>
                    ))}
                  </div>
                )}
              </div>
              {modelStatus === "error" && (
                <p className="text-sm text-red-600">
                  Model load failed. Check console for details and retry.
                </p>
              )}
            </div>
          ) : (
            <>
              {/* Empty-chat welcome screen with example prompts. */}
              {messages.length === 0 && (
                <div className="flex flex-col items-center justify-center h-full text-center text-stone-500">
                  <div className="p-8 rounded-2xl flex flex-col items-center">
                    <h2 className="text-3xl font-semibold mt-4 text-stone-700">
                      Welcome to Baguettotron
                    </h2>
                    <h3 className="max-w-xs mt-1 text-lg">
                      Ask me a question, or try one of the examples below!
                    </h3>
                  </div>
                  <div className="mt-2 flex flex-wrap justify-center gap-4">
                    <button
                      onClick={() =>
                        handleSubmit(
                          undefined,
                          "What is the capital of France? Just provide the answer.",
                        )
                      }
                      className="bg-amber-100 hover:bg-amber-200 text-amber-800 px-4 py-2 rounded-lg shadow-sm border border-amber-200 transition-colors"
                    >
                      Encyclopedic knowledge
                    </button>
                    {/* Creative-writing example only offered at full/half precision. */}
                    {["fp32", "fp16"].includes(selectedDtype) && (
                      <button
                        onClick={() =>
                          handleSubmit(
                            undefined,
                            "Write me a short poem about machine learning.",
                          )
                        }
                        className="bg-amber-100 hover:bg-amber-200 text-amber-800 px-4 py-2 rounded-lg shadow-sm border border-amber-200 transition-colors"
                      >
                        Creative writing
                      </button>
                    )}
                    <button
                      onClick={() =>
                        handleSubmit(
                          undefined,
                          "Which is wider: Australia or the Moon?",
                          "Australia is approximately 4,000 km in width from east to west, according to Geoscience Australia.\n\nThe diameter of the Moon is about 3,476 km, according to Britannica.",
                        )
                      }
                      className="bg-amber-100 hover:bg-amber-200 text-amber-800 px-4 py-2 rounded-lg shadow-sm border border-amber-200 transition-colors"
                    >
                      RAG with grounding
                    </button>
                  </div>
                </div>
              )}
              {messages.map((msg, index) => {
                // Only the trailing assistant bubble gets the reserved height.
                const isLastAssistantMessage =
                  index === messages.length - 1 && msg.role === "assistant";
                const minHeight = isLastAssistantMessage
                  ? lastMessageMinHeight
                  : undefined;
                return (
                  <MessageBubble
                    key={msg.id}
                    message={msg}
                    minHeight={minHeight}
                  />
                );
              })}
            </>
          )}
        </div>
      </main>
      {/* Input footer: RAG context drawer, message input, send/stop button. */}
      {modelStatus === "ready" && (
        <footer className="flex-shrink-0 sticky bottom-0 z-10 p-4 bg-white/90 backdrop-blur-md border-t border-amber-100">
          <form onSubmit={handleSubmit} className="max-w-3xl mx-auto">
            <div
              style={{
                maxHeight: showContext ? "120px" : "0px",
                transition: "max-height 0.3s ease-in-out",
                opacity: showContext ? 1 : 0,
              }}
              className="overflow-hidden relative"
            >
              <textarea
                value={context}
                onChange={(e) => setContext(e.target.value)}
                disabled={isLoading}
                placeholder="Add RAG context here. Separate multiple sources with two new lines."
                className="w-full h-28 p-2 mb-2 rounded-lg border border-stone-300 focus:ring-amber-500 focus:border-amber-500 text-sm resize-none"
              />
              <button
                type="button"
                onClick={() => setShowContext(false)}
                className="absolute top-2 right-2 p-1 text-stone-400 hover:text-stone-600 bg-white/50 rounded-full"
              >
                <X size={16} />
              </button>
            </div>
            <div className="flex items-center space-x-2">
              <button
                type="button"
                onClick={() => setShowContext(!showContext)}
                title="Add Context for RAG"
                className={`flex-shrink-0 p-2 rounded-full transition-colors ${
                  showContext
                    ? "bg-amber-100 text-amber-700"
                    : "text-stone-500 hover:text-amber-600 hover:bg-amber-50"
                }`}
              >
                <Paperclip size={20} />
              </button>
              <input
                type="text"
                value={currentInput}
                onChange={(e) => setCurrentInput(e.target.value)}
                placeholder="Send a message..."
                className="flex-grow px-4 py-2 rounded-full border border-stone-300 focus:ring-2 focus:ring-amber-500 focus:border-transparent outline-none transition-shadow"
                disabled={isLoading || modelStatus !== "ready"}
              />
              {/* Stop button while generating; send button otherwise. */}
              {isLoading ? (
                <button
                  type="button"
                  onClick={handleStopGeneration}
                  className="group flex h-10 w-10 flex-shrink-0 items-center justify-center rounded-full border border-stone-300 bg-white text-stone-600 hover:border-red-500"
                >
                  <span className="h-3.5 w-3.5 rounded-sm bg-stone-600 transition-colors group-hover:bg-red-500" />
                </button>
              ) : (
                <button
                  type="submit"
                  disabled={
                    isLoading || !currentInput.trim() || modelStatus !== "ready"
                  }
                  className="flex h-10 w-10 flex-shrink-0 items-center justify-center rounded-full bg-amber-600 text-white transition-all transform
                    hover:bg-amber-700 hover:scale-105 active:scale-95
                    disabled:bg-stone-300 disabled:scale-100 disabled:cursor-not-allowed"
                >
                  <Send size={20} />
                </button>
              )}
            </div>
            <p className="text-center text-xs text-stone-400 mt-2">
              ⚡ Powered by{" "}
              <a
                href="https://github.com/huggingface/transformers.js"
                target="_blank"
                rel="noopener noreferrer"
              >
                Transformers.js
              </a>{" "}
              — Runs locally in your browser on WebGPU.
            </p>
          </form>
        </footer>
      )}
    </div>
  );
};
export default App;