import { useState, useRef, useCallback, useLayoutEffect } from "react";
import { Send, Paperclip, Brain, ChevronDown, X, Plus } from "lucide-react";
import {
  AutoModelForCausalLM,
  AutoTokenizer,
  InterruptableStoppingCriteria,
  TextStreamer,
} from "@huggingface/transformers";
import { Streamdown } from "streamdown";
import type {
  PreTrainedTokenizer,
  LlamaForCausalLM,
} from "@huggingface/transformers";
import type React from "react";

const MODEL_ID = "onnx-community/Baguettotron-ONNX";

const DTYPES = {
  fp32: "FP32 (~1.28 GB)",
  fp16: "FP16 (~642 MB)",
  q4: "Q4 (~329 MB)",
  q4f16: "Q4F16 (~235 MB)",
} as const;

type Dtype = keyof typeof DTYPES;

const SOURCE_SEPARATOR_REGEX = /\n{2,}/g;

type Role = "user" | "assistant";

/**
 * Format the sources into tagged segments for the model input.
 */
const buildSourcesPayload = (rawContext: string) => {
  const trimmed = rawContext.trim();
  if (!trimmed) {
    return { payload: "", count: 0, segments: [] };
  }

  const segments = trimmed
    .split(SOURCE_SEPARATOR_REGEX)
    .map((segment) => segment.trim())
    .filter(Boolean);

  if (segments.length === 0) {
    return { payload: "", count: 0, segments: [] };
  }

  // NOTE: the exact wrapper tags are an assumption (`<source_N>…</source_N>`);
  // match the source format the model expects if it differs.
  const payload =
    "\n\n" +
    segments
      .map(
        (segment, index) =>
          `<source_${index + 1}>${segment}</source_${index + 1}>`,
      )
      .join("\n");

  return { payload, count: segments.length, segments };
};

/**
 * Converts <ref name="...">...</ref> tags in the content to superscript references.
 */
const convertRefsToSuperscript = (content: string) => {
  const refRegex = /<ref name="([^"]*)">([\s\S]*?)<\/ref>/g;
  const refLabelMap = new Map<string, number>();
  let refCounter = 1;

  // First, process all complete <ref>...</ref> tags.
  let result = content.replace(refRegex, (_, sourceName = "", refBody) => {
    const label =
      refLabelMap.get(sourceName) ??
      (() => {
        const assigned = refCounter++;
        refLabelMap.set(sourceName, assigned);
        return assigned;
      })();
    // Keep the quoted passage in the title attribute so it shows on hover.
    const escapedRefBody = refBody.replace(/"/g, "&quot;");
    return `<sup title="${escapedRefBody}">[${label}]</sup>`;
  });

  // Remove any trailing incomplete <ref ...> tag.
  const incompleteRefRegex = /<ref[^>]*>[\s\S]*$/;
  result = result.replace(incompleteRefRegex, "");

  return result;
};

/**
 * Sanitizes user input by escaping angle brackets.
 */
const sanitizeInput = (text: string) => {
  return text.replace(/</g, "&lt;").replace(/>/g, "&gt;");
};

/**
 * Represents a single chat message in the history.
 */
interface Message {
  id: number;
  role: Role;
  content: string;
  thinkTrace?: string;
  rawStream?: string;
  isLoading?: boolean;
  timestamp?: number;
  thinkEndTime?: number;
}

/**
 * A simple, self-contained collapsible component.
 */
const Collapsible: React.FC<{
  title: React.ReactNode;
  children: React.ReactNode;
}> = ({ title, children }) => {
  const [isOpen, setIsOpen] = useState(false);
  const contentRef = useRef<HTMLDivElement>(null);

  // Minimal markup sketch: a toggle button plus a collapsible content area
  // (styling omitted).
  return (
    <div>
      <button type="button" onClick={() => setIsOpen((open) => !open)}>
        {title}
        <ChevronDown size={16} />
      </button>
      <div ref={contentRef} hidden={!isOpen}>
        {children}
      </div>
    </div>
  );
};

/**
 * A single chat message bubble.
 */
const MessageBubble: React.FC<{ message: Message; minHeight?: number }> = ({
  message,
  minHeight,
}) => {
  const { role, content, thinkTrace, isLoading, timestamp, thinkEndTime } =
    message;
  const isUser = role === "user";

  let thinkingText = "";
  let opacityClass = "";
  const hasDuration =
    typeof thinkEndTime === "number" && typeof timestamp === "number";
  const durationSeconds = hasDuration
    ? Math.max(Math.round((thinkEndTime - timestamp) / 1000), 0)
    : null;

  if (isLoading && !thinkEndTime) {
    thinkingText = "Thinking...";
    opacityClass = "opacity-70 hover:opacity-100";
  } else if (thinkTrace) {
    thinkingText =
      durationSeconds !== null
        ? `Thought for ${durationSeconds} seconds`
        : "Thought interrupted";
  } else {
    thinkingText = "Show Thoughts";
  }

  const markdownContent = convertRefsToSuperscript(content);

  // Minimal markup sketch; classes are placeholders. data-message-id is used
  // by App's layout effect to measure the previous user message.
  return (
    <div
      data-message-id={message.id}
      className={isUser ? "user-message" : "assistant-message"}
      style={minHeight !== undefined ? { minHeight } : undefined}
    >
      {(thinkTrace || isLoading) && (
        <Collapsible
          title={
            <span className={opacityClass}>
              <Brain size={14} />
              {thinkingText}
            </span>
          }
        >
          {thinkTrace || (isLoading ? "..." : "")}
        </Collapsible>
      )}
      <Streamdown>{markdownContent || (isLoading ? "" : "")}</Streamdown>
    </div>
  );
};

/**
 * Manages the model and tokenizer loading state and refs.
 */
const useLLM = () => {
  const [modelStatus, setModelStatus] = useState<
    "idle" | "loading" | "ready" | "error"
  >("idle");
  const [loadProgress, setLoadProgress] = useState(0);
  const modelRef = useRef<LlamaForCausalLM | null>(null);
  const tokenizerRef = useRef<PreTrainedTokenizer | null>(null);

  const loadModel = useCallback(
    async (dtype: Dtype) => {
      if (modelRef.current && tokenizerRef.current) {
        setModelStatus("ready");
        setLoadProgress(100);
        return;
      }
      if (modelStatus === "loading") return;

      setModelStatus("loading");
      setLoadProgress(0);

      // Track download progress of the main ONNX weights file only.
      const progress_callback = (progress: any) => {
        if (
          progress.status === "progress" &&
          typeof progress.total === "number" &&
          typeof progress.loaded === "number" &&
          typeof progress.file === "string" &&
          progress.file.endsWith(".onnx_data")
        ) {
          const percentage = Math.round(
            (progress.loaded / progress.total) * 100,
          );
          setLoadProgress(percentage);
        }
      };

      try {
        const tokenizer = await AutoTokenizer.from_pretrained(MODEL_ID, {
          progress_callback,
        });
        const model = await AutoModelForCausalLM.from_pretrained(MODEL_ID, {
          dtype,
          device: "webgpu",
          progress_callback,
        });
        tokenizerRef.current = tokenizer;
        modelRef.current = model;
        setLoadProgress(100);
        setModelStatus("ready");
      } catch (error) {
        console.error("Failed to load model", error);
        setModelStatus("error");
      }
    },
    [modelStatus],
  );

  return {
    modelStatus,
    loadProgress,
    modelRef,
    tokenizerRef,
    loadModel,
  };
};

const App: React.FC = () => {
  const [messages, setMessages] = useState<Message[]>([]);
  const [currentInput, setCurrentInput] = useState("");
  const [context, setContext] = useState("");
  const [showContext, setShowContext] = useState(false);
  const [isLoading, setIsLoading] = useState(false);
  const [lastMessageMinHeight, setLastMessageMinHeight] = useState<
    number | undefined
  >(undefined);
  const [selectedDtype, setSelectedDtype] = useState<Dtype>("fp16");
  const [dtypeMenuOpen, setDtypeMenuOpen] = useState(false);

  const stoppingCriteriaRef = useRef<InterruptableStoppingCriteria | null>(
    null,
  );
  const mainRef = useRef<HTMLElement | null>(null);

  const { modelStatus, loadProgress, modelRef, tokenizerRef, loadModel } =
    useLLM();

  useLayoutEffect(() => {
    if (!mainRef.current) return;
    const el = mainRef.current;

    // If the last message is from the assistant, calculate a min-height to
    // prevent layout shifts.
    if (messages.at(-1)?.role === "assistant") {
      const userMessageElement = el.querySelector(
        `[data-message-id="${messages.at(-2)?.id}"]`,
      );
      if (userMessageElement) {
        const userMessageHeight =
          userMessageElement.getBoundingClientRect().height;
        const screenHeight = window.innerHeight;
        const newMinHeight = Math.max(
          screenHeight - userMessageHeight - 270,
          0,
        );
        setLastMessageMinHeight(newMinHeight);
      }
    } else {
      setLastMessageMinHeight(undefined);
    }
  }, [messages.length]);

  useLayoutEffect(() => {
    if (mainRef.current) {
      const el = mainRef.current;
      setTimeout(() => {
        el.scrollTo({
          top: el.scrollHeight,
          behavior: "smooth",
        });
      }, 0);
    }
  }, [messages.length, lastMessageMinHeight]);

  const handleStreamUpdate = useCallback((newToken: string) => {
    setMessages((prev) => {
      if (prev.length === 0 || prev.at(-1)!.role === "user") {
        return prev;
      }
      const lastMessage = { ...prev.at(-1)! };
      lastMessage.rawStream = (lastMessage.rawStream || "") + newToken;

      const raw = lastMessage.rawStream;
      const thinkEndTag = "</think>";
      const thinkEndIndex = raw.indexOf(thinkEndTag);

      let content;
      let thinkTrace = "";

      if (thinkEndIndex !== -1) {
        // Think block is complete; everything after the closing tag is the answer.
        thinkTrace = raw.substring(0, thinkEndIndex);
        const contentAfter = raw.substring(thinkEndIndex + thinkEndTag.length);
        content = contentAfter.replace("<|im_end|><|end_of_text|>", "");
        if (!lastMessage.thinkEndTime) {
          lastMessage.thinkEndTime = Date.now();
        }
      } else {
        // Think block has started but not finished.
        thinkTrace = raw;
        content = "";
      }

      lastMessage.content = content.trim();
      lastMessage.thinkTrace = thinkTrace.trim();

      return [...prev.slice(0, -1), lastMessage];
    });
  }, []);

  const handleStopGeneration = useCallback(() => {
    stoppingCriteriaRef.current?.interrupt();
  }, []);

  const streamAssistantResponse = useCallback(
    async (
      historyForModel: { role: Role; content: string }[],
      assistantMessageId: number,
    ) => {
      const tokenizer = tokenizerRef.current;
      const model = modelRef.current;
      if (!tokenizer || !model) return;

      const inputs = tokenizer.apply_chat_template(historyForModel, {
        add_generation_prompt: true,
        return_dict: true,
      }) as any;

      const streamer = new TextStreamer(tokenizer, {
        skip_prompt: true,
        skip_special_tokens: false,
        callback_function: (token: string) => handleStreamUpdate(token),
      });

      const stoppingCriteria = new InterruptableStoppingCriteria();
      stoppingCriteriaRef.current = stoppingCriteria;

      try {
        await model.generate({
          ...inputs,
          max_new_tokens: 2048,
          streamer,
          stopping_criteria: stoppingCriteria,
          repetition_penalty: 1.2,
        });
      } catch (error) {
        console.error(error);
      } finally {
        stoppingCriteriaRef.current = null;
        setIsLoading(false);
        setMessages((prev) =>
          prev.map((msg) => {
            if (msg.id === assistantMessageId) {
              const { rawStream, isLoading: _, ...rest } = msg;
              return rest;
            }
            return msg;
          }),
        );
      }
    },
    [handleStreamUpdate, modelRef, tokenizerRef],
  );

  const handleSubmit = async (
    e?: React.FormEvent,
    prompt?: string,
    sources?: string,
  ) => {
    if (e) e.preventDefault();
    if (isLoading || modelStatus !== "ready") return;

    const input = prompt || currentInput;
    if (!input.trim()) return;

    const trimmedContext = (sources || context).trim();
    const {
      payload: sourcesPayload,
      count: sourceCount,
      segments: sourceSegments,
    } = buildSourcesPayload(trimmedContext);

    const fullPrompt = `${input}${sourcesPayload}`;

    const sanitizedInput = sanitizeInput(input);
    let userMessageContent = sanitizedInput;
    if (sourceCount > 0) {
      const sourcesList = sourceSegments
        .map(
          (seg, i) =>
            `${i + 1}. ${seg.substring(0, 75)}${seg.length > 75 ? "..." : ""}`,
        )
        .join("\n");
      userMessageContent += `\n\n[Source${sourceCount > 1 ? "s" : ""}]:\n${sourcesList}`;
    }

    const userMessage: Message = {
      id: messages.length,
      role: "user",
      content: userMessageContent,
    };

    const assistantPlaceholder: Message = {
      id: messages.length + 1,
      role: "assistant",
      content: "",
      thinkTrace: "",
      rawStream: "",
      isLoading: true,
      timestamp: Date.now(),
    };

    setMessages((prev) => [...prev, userMessage, assistantPlaceholder]);
    setCurrentInput("");
    setContext("");
    setShowContext(false);
    setIsLoading(true);
    setLastMessageMinHeight(undefined);

    const historyForModel = [
      ...messages.map(({ role, content }) => ({ role, content })),
      { role: "user" as Role, content: fullPrompt },
    ];

    await streamAssistantResponse(historyForModel, assistantPlaceholder.id);
  };

  const handleNewChat = () => {
    handleStopGeneration();
    setMessages([]);
    setCurrentInput("");
    setContext("");
    setShowContext(false);
    setIsLoading(false);
    setLastMessageMinHeight(undefined);
  };

  return (
    <div>
      {/* NOTE: the markup below is a minimal sketch; layout, classes, and
          button contents are assumptions, while the copy and the conditional
          logic follow the handlers and state defined above. */}
      {modelStatus === "ready" && (
        <header>
          <h1>🥖 Baguettotron WebGPU</h1>
          <p>A small but powerful reasoning model</p>
          <button type="button" onClick={handleNewChat}>
            <Plus size={16} />
            New chat
          </button>
        </header>
      )}

      <main ref={mainRef}>
        {modelStatus !== "ready" ? (
          <div>
            <span>🥖</span>
            <h1>Baguettotron WebGPU</h1>
            <p>
              You are about to load Baguettotron, a 300M parameter reasoning
              model optimized for in-browser inference. Everything runs
              entirely in your browser with 🤗 Transformers.js and ONNX Runtime
              Web, meaning no data is sent to a server. Once loaded, it can
              even be used offline.
            </p>

            {/* Quantization (dtype) picker. */}
            <button
              type="button"
              onClick={() => setDtypeMenuOpen((open) => !open)}
            >
              {DTYPES[selectedDtype]}
              <ChevronDown size={16} />
            </button>
            {dtypeMenuOpen && (
              <div role="menu">
                {Object.entries(DTYPES).map(([dtype, label]) => (
                  <button
                    key={dtype}
                    type="button"
                    onClick={() => {
                      setSelectedDtype(dtype as Dtype);
                      setDtypeMenuOpen(false);
                    }}
                  >
                    {label}
                  </button>
                ))}
              </div>
            )}

            {/* Load trigger with download progress. */}
            <button
              type="button"
              disabled={modelStatus === "loading"}
              onClick={() => loadModel(selectedDtype)}
            >
              {modelStatus === "loading"
                ? `Loading… ${loadProgress}%`
                : "Load model"}
            </button>

            {modelStatus === "error" && (
              <p role="alert">
                Model load failed. Check console for details and retry.
              </p>
            )}
          </div>
        ) : (
          <>
            {messages.length === 0 && (
              <div>
                <h2>Welcome to Baguettotron</h2>
                <p>Ask me a question, or try one of the examples below!</p>
                {/* Example prompt buttons (texts omitted): each calls
                    handleSubmit(undefined, prompt, sources). */}
                {["fp32", "fp16"].includes(selectedDtype) && (
                  <>{/* Extra example shown only for higher-precision dtypes. */}</>
                )}
              </div>
            )}
            {messages.map((msg, index) => {
              const isLastAssistantMessage =
                index === messages.length - 1 && msg.role === "assistant";
              const minHeight = isLastAssistantMessage
                ? lastMessageMinHeight
                : undefined;
              return (
                <MessageBubble
                  key={msg.id}
                  message={msg}
                  minHeight={minHeight}
                />
              );
            })}
          </>
        )}
      </main>

{modelStatus === "ready" && (