|
|
import { useState, useRef, useCallback, useLayoutEffect } from "react"; |
|
|
import { Send, Paperclip, Brain, ChevronDown, X, Plus } from "lucide-react"; |
|
|
import { |
|
|
AutoModelForCausalLM, |
|
|
AutoTokenizer, |
|
|
InterruptableStoppingCriteria, |
|
|
TextStreamer, |
|
|
} from "@huggingface/transformers"; |
|
|
import { Streamdown } from "streamdown"; |
|
|
|
|
|
import type { |
|
|
PreTrainedTokenizer, |
|
|
LlamaForCausalLM, |
|
|
} from "@huggingface/transformers"; |
|
|
import type React from "react"; |
|
|
|
|
|
/** Hugging Face Hub repository the tokenizer and ONNX weights are loaded from. */
const MODEL_ID = "onnx-community/Baguettotron-ONNX";

/** Speaker of a chat turn, using the role names passed to the chat template. */
type Role = "user" | "assistant";

/**
 * Selectable weight precisions mapped to the label shown in the picker
 * (labels include the approximate download size). The keys are what gets
 * passed as `dtype` when loading the model.
 */
const DTYPES = {
  fp32: "FP32 (~1.28 GB)",
  fp16: "FP16 (~642 MB)",
  q4: "Q4 (~329 MB)",
  q4f16: "Q4F16 (~235 MB)",
} as const;

/** One of the precision keys of `DTYPES`. */
type Dtype = keyof typeof DTYPES;

/** Separator between pasted RAG sources: a run of two or more newlines. */
const SOURCE_SEPARATOR_REGEX = /\n{2,}/g;
|
|
|
|
|
|
|
|
|
|
|
|
|
|
/**
 * Turns free-form RAG context into the tagged block appended to the prompt.
 *
 * The context is split on SOURCE_SEPARATOR_REGEX (two or more newlines).
 * Each non-empty, trimmed piece becomes `<source_N>…</source_N>` (1-based),
 * one per line, and the whole block is prefixed with two newlines.
 *
 * @param rawContext - Text from the context box; may be empty or whitespace.
 * @returns `payload` (the tagged string, or "" when there are no sources),
 *   `count` (number of sources) and `segments` (the trimmed source texts).
 */
const buildSourcesPayload = (rawContext: string) => {
  const segments: string[] = [];
  for (const piece of rawContext.trim().split(SOURCE_SEPARATOR_REGEX)) {
    const text = piece.trim();
    if (text !== "") {
      segments.push(text);
    }
  }

  if (segments.length === 0) {
    return { payload: "", count: 0, segments };
  }

  const taggedLines = segments.map((text, i) => {
    const tag = `source_${i + 1}`;
    return `<${tag}>${text}</${tag}>`;
  });

  return {
    payload: `\n\n${taggedLines.join("\n")}`,
    count: segments.length,
    segments,
  };
};
|
|
|
|
|
|
|
|
|
|
|
|
|
|
/**
 * Rewrites model citations of the form `<ref name="…">…</ref>` into numbered
 * superscript markers for the markdown renderer.
 *
 * Each distinct `name` gets a stable label, numbered from 1 in order of first
 * appearance. The citation body becomes the marker's `title` (hover text),
 * HTML-attribute-escaped so quotes or ampersands cannot break the markup.
 *
 * While a reply is still streaming, a trailing citation may be unfinished.
 * This covers a partial opening tag as well as a complete opening tag with no
 * closing `</ref>` yet. Everything from such a dangling `<ref` to the end is
 * dropped so it never flashes as raw text.
 *
 * @param content - Assistant answer text.
 * @returns The text with completed refs replaced and any trailing partial ref removed.
 */
const convertRefsToSuperscript = (content: string) => {
  const refRegex = /<ref name="([^"]+)">([\s\S]*?)<\/ref>/g;
  const labels = new Map<string, number>();

  const withSuperscripts = content.replace(
    refRegex,
    (_match: string, sourceName: string, refBody: string) => {
      let label = labels.get(sourceName);
      if (label === undefined) {
        label = labels.size + 1;
        labels.set(sourceName, label);
      }
      // Escape "&" first so already-escaped quotes are not double-decoded.
      const escapedRefBody = refBody
        .replace(/&/g, "&amp;")
        .replace(/"/g, "&quot;");
      return `<sup className="cursor-pointer" title="${escapedRefBody}">[${label}]</sup>`;
    },
  );

  // After the replacement above, any remaining `<ref` (followed by whitespace,
  // ">" or end of text) is an unfinished citation; drop it and everything after.
  return withSuperscripts.replace(/<ref(?=[\s>]|$)[\s\S]*$/, "");
};
|
|
|
|
|
|
|
|
|
|
|
|
|
|
/**
 * Escapes angle brackets so user-typed text is shown literally rather than
 * being interpreted as HTML by the markdown renderer.
 *
 * @param text - Raw user text.
 * @returns The text with `<` and `>` replaced by their HTML entities.
 */
const sanitizeInput = (text: string) =>
  text.replace(/</g, "&lt;").replace(/>/g, "&gt;");
|
|
|
|
|
|
|
|
|
|
|
|
|
|
/** One chat turn as stored in state and rendered in the transcript. */
interface Message {
  /** Identifier used as the React key and as `data-message-id` on the bubble. */
  id: number;
  /** Who produced the turn. */
  role: Role;
  /** Text rendered in the bubble (for the assistant: the answer after `</think>`). */
  content: string;

  // --- assistant-only bookkeeping ---

  /** Epoch milliseconds recorded when the assistant placeholder was created. */
  timestamp?: number;
  /** Epoch milliseconds recorded when `</think>` first appeared in the stream. */
  thinkEndTime?: number;
  /** Reasoning text: the streamed output before `</think>`. */
  thinkTrace?: string;
  /** Concatenation of all streamed token text received so far. */
  rawStream?: string;
  /** True while the assistant reply is still being generated. */
  isLoading?: boolean;
}
|
|
|
|
|
|
|
|
|
|
|
|
|
|
/**
 * Disclosure widget: a toggle button showing `title` and a panel holding
 * `children` that expands or collapses with a height transition. Starts collapsed.
 *
 * The panel animates `grid-template-rows` between `0fr` and `1fr`, so the open
 * height always tracks the content's real height. This holds even while the
 * children are still growing, e.g. a streaming think trace. The previous
 * approach read `ref.current.scrollHeight` during render, which reflects the
 * previous commit's DOM and can clip freshly added content.
 */
const Collapsible: React.FC<{
  title: React.ReactNode;
  children: React.ReactNode;
}> = ({ title, children }) => {
  const [isOpen, setIsOpen] = useState(false);

  return (
    <div className="collapsible mt-2">
      <button
        type="button"
        aria-expanded={isOpen}
        onClick={() => setIsOpen((open) => !open)}
        className="flex items-center space-x-1 text-xs font-medium text-amber-700 hover:text-amber-900 transition-colors"
      >
        {title}
        <ChevronDown
          size={14}
          className={`transform transition-transform ${isOpen ? "rotate-180" : "rotate-0"}`}
        />
      </button>
      <div
        aria-hidden={!isOpen}
        style={{
          display: "grid",
          gridTemplateRows: isOpen ? "1fr" : "0fr",
        }}
        className="transition-all duration-300 ease-in-out"
      >
        {/* min-h-0 + overflow-hidden let the single grid row shrink to zero. */}
        <div className="min-h-0 overflow-hidden">
          <div className="mt-2 p-2 bg-amber-50 border border-dashed border-amber-200 rounded-md text-xs text-stone-600 prose-sm">
            {children}
          </div>
        </div>
      </div>
    </div>
  );
};
|
|
|
|
|
|
|
|
|
|
|
|
|
|
/**
 * Renders one chat turn.
 *
 * User turns are a right-aligned amber bubble. Assistant turns are a white
 * bubble with an optional collapsible reasoning panel ("Thinking..." while the
 * trace streams, then "Thought for N seconds" or "Thought interrupted") above
 * the answer. Citations in the answer are turned into superscripts.
 *
 * Streaming happens in two phases: reasoning first, then, once `</think>` has
 * been seen (`thinkEndTime` set), the answer. Each Streamdown instance is only
 * marked as animating during its own phase.
 *
 * @param message - The turn to render.
 * @param minHeight - Optional CSS min-height (px) for the outer row.
 */
const MessageBubble: React.FC<{ message: Message; minHeight?: number }> = ({
  message,
  minHeight,
}) => {
  const { role, content, thinkTrace, isLoading, timestamp, thinkEndTime } =
    message;
  const isUser = role === "user";

  // Reasoning is streaming until `</think>` arrives; afterwards the answer is.
  const isThinking = Boolean(isLoading && !thinkEndTime);
  const isAnswering = Boolean(isLoading && thinkEndTime);
  const hasThinkingPanel = Boolean(thinkTrace || isLoading);

  const durationSeconds =
    typeof thinkEndTime === "number" && typeof timestamp === "number"
      ? Math.max(Math.round((thinkEndTime - timestamp) / 1000), 0)
      : null;

  let thinkingText: string;
  if (isThinking) {
    thinkingText = "Thinking...";
  } else if (thinkTrace) {
    thinkingText =
      durationSeconds !== null
        ? `Thought for ${durationSeconds} second${durationSeconds === 1 ? "" : "s"}`
        : "Thought interrupted";
  } else {
    thinkingText = "Show Thoughts";
  }
  const opacityClass = isThinking ? "opacity-70 hover:opacity-100" : "";

  const markdownContent = convertRefsToSuperscript(content);

  return (
    <div
      data-message-id={message.id}
      data-role={role}
      className={`message flex items-start animate-in fade-in slide-in-from-bottom-2 duration-300 py-2 ${isUser ? "justify-end" : "justify-start"}`}
      style={{ minHeight }}
    >
      <div
        className={`max-w-xl lg:max-w-2xl px-4 py-3 rounded-2xl ${
          isUser
            ? "bg-amber-500 text-white rounded-br-none"
            : "bg-white text-stone-800 rounded-bl-none shadow-sm border border-stone-200"
        }`}
      >
        {hasThinkingPanel && (
          <Collapsible
            title={
              <div className="flex items-center space-x-1.5 text-sm">
                <Brain size={16} />
                <span
                  className={`${isLoading ? "animate-glisten" : ""} ${opacityClass}`}
                >
                  {thinkingText}
                </span>
              </div>
            }
          >
            <Streamdown
              parseIncompleteMarkdown={false}
              className="text-xs text-stone-500"
              isAnimating={isThinking}
            >
              {thinkTrace || (isLoading ? "..." : "")}
            </Streamdown>
          </Collapsible>
        )}
        <div className={hasThinkingPanel ? "mt-2" : ""}>
          <Streamdown
            parseIncompleteMarkdown={false}
            className="text-sm leading-relaxed text-stone-800"
            isAnimating={isAnswering}
          >
            {markdownContent}
          </Streamdown>
        </div>
      </div>
    </div>
  );
};
|
|
|
|
|
|
|
|
|
|
|
|
|
|
/**
 * Owns the lazily loaded tokenizer/model pair for MODEL_ID.
 *
 * `loadModel(dtype)` downloads both on first call, placing the model on the
 * "webgpu" device, and stores them in refs. Later calls only re-assert
 * "ready". Calls made while a load is already running are ignored.
 * `loadProgress` is a 0–100 percentage computed from progress events for
 * files whose name ends in ".onnx_data".
 *
 * @returns Status, progress, the model/tokenizer refs and the `loadModel` action.
 */
const useLLM = () => {
  const [modelStatus, setModelStatus] = useState<
    "idle" | "loading" | "ready" | "error"
  >("idle");
  const [loadProgress, setLoadProgress] = useState(0);
  const modelRef = useRef<LlamaForCausalLM | null>(null);
  const tokenizerRef = useRef<PreTrainedTokenizer | null>(null);
  // Synchronous in-flight flag. Checking the `modelStatus` state instead would
  // read a value captured at the last render, so two calls before the next
  // render could both start a download.
  const isLoadingRef = useRef(false);

  const loadModel = useCallback(async (dtype: Dtype) => {
    if (modelRef.current && tokenizerRef.current) {
      setModelStatus("ready");
      setLoadProgress(100);
      return;
    }
    if (isLoadingRef.current) return;
    isLoadingRef.current = true;

    setModelStatus("loading");
    setLoadProgress(0);

    // The payload shape is library-defined; narrow the fields used here.
    const progress_callback = (info: unknown) => {
      if (typeof info !== "object" || info === null) return;
      const { status, file, loaded, total } = info as Record<string, unknown>;
      if (
        status === "progress" &&
        typeof file === "string" &&
        file.endsWith(".onnx_data") &&
        typeof loaded === "number" &&
        typeof total === "number" &&
        total > 0 // avoid NaN/Infinity when the size is unknown
      ) {
        setLoadProgress(Math.min(100, Math.round((loaded / total) * 100)));
      }
    };

    try {
      const tokenizer = await AutoTokenizer.from_pretrained(MODEL_ID, {
        progress_callback,
      });
      const model = await AutoModelForCausalLM.from_pretrained(MODEL_ID, {
        dtype,
        device: "webgpu",
        progress_callback,
      });
      tokenizerRef.current = tokenizer;
      modelRef.current = model;
      setLoadProgress(100);
      setModelStatus("ready");
    } catch (error) {
      console.error("Failed to load model", error);
      setModelStatus("error");
    } finally {
      isLoadingRef.current = false;
    }
  }, []);

  return {
    modelStatus,
    loadProgress,
    modelRef,
    tokenizerRef,
    loadModel,
  };
};
|
|
|
|
|
/** Literal marker that ends the model's reasoning section. */
const THINK_END_TAG = "</think>";

/** Opening reasoning tag, in case the model re-emits it at the start of its output. */
const LEADING_THINK_TAG_REGEX = /^\s*<think>/;

/**
 * Chat-template control tokens that can appear in the decoded stream because
 * the streamer keeps special tokens (`skip_special_tokens: false`).
 */
const CHAT_CONTROL_TOKENS_REGEX = /<\|im_end\|>|<\|end_of_text\|>/g;

/**
 * Splits the raw streamed text of an assistant turn into reasoning and answer.
 *
 * @param raw - Everything streamed so far for the turn.
 * @returns `thinkTrace` (text before `</think>`, or all text if the tag has not
 *   arrived), `content` (text after the tag) and `thinkFinished` (whether the
 *   tag was seen). Both texts are trimmed and stripped of control tokens.
 */
const splitAssistantStream = (raw: string) => {
  const thinkEndIndex = raw.indexOf(THINK_END_TAG);
  const thinkFinished = thinkEndIndex !== -1;
  const reasoning = thinkFinished ? raw.slice(0, thinkEndIndex) : raw;
  const answer = thinkFinished
    ? raw.slice(thinkEndIndex + THINK_END_TAG.length)
    : "";
  return {
    thinkTrace: reasoning
      .replace(LEADING_THINK_TAG_REGEX, "")
      .replace(CHAT_CONTROL_TOKENS_REGEX, "")
      .trim(),
    content: answer.replace(CHAT_CONTROL_TOKENS_REGEX, "").trim(),
    thinkFinished,
  };
};

/**
 * Root component: model-loading splash screen, then the chat transcript with
 * a composer that supports optional RAG sources.
 */
const App: React.FC = () => {
  const [messages, setMessages] = useState<Message[]>([]);
  const [currentInput, setCurrentInput] = useState("");
  const [context, setContext] = useState("");
  const [showContext, setShowContext] = useState(false);
  const [isLoading, setIsLoading] = useState(false);
  const [lastMessageMinHeight, setLastMessageMinHeight] = useState<
    number | undefined
  >(undefined);
  const [selectedDtype, setSelectedDtype] = useState<Dtype>("fp16");
  const [dtypeMenuOpen, setDtypeMenuOpen] = useState(false);

  const stoppingCriteriaRef = useRef<InterruptableStoppingCriteria | null>(
    null,
  );
  const mainRef = useRef<HTMLDivElement>(null);
  // Monotonic id source. Ids are never reused, even across "New Chat", so a
  // generation that outlives a reset can never address a message of the new chat.
  const nextMessageIdRef = useRef(0);
  // Exact prompt the model received for each user message, keyed by message id.
  // The displayed content is HTML-escaped and shows only a truncated source
  // list, so it must not be replayed to the model as history.
  const modelPromptsRef = useRef(new Map<number, string>());

  const { modelStatus, loadProgress, modelRef, tokenizerRef, loadModel } =
    useLLM();

  // When a new assistant turn appears, give it enough min-height that the
  // preceding user message can be scrolled to the top of the viewport.
  useLayoutEffect(() => {
    const el = mainRef.current;
    if (!el) return;

    if (messages.at(-1)?.role === "assistant") {
      const userMessageElement = el.querySelector<HTMLDivElement>(
        `[data-message-id="${messages.at(-2)?.id}"]`,
      );
      if (userMessageElement) {
        const userMessageHeight =
          userMessageElement.getBoundingClientRect().height;
        // 270px: presumably header (100px) + footer + padding — TODO confirm.
        const newMinHeight = Math.max(
          window.innerHeight - userMessageHeight - 270,
          0,
        );
        setLastMessageMinHeight(newMinHeight);
      }
    } else {
      setLastMessageMinHeight(undefined);
    }
    // Re-measure only when turns are added/removed, not on every streamed token.
    // eslint-disable-next-line react-hooks/exhaustive-deps
  }, [messages.length]);

  // Scroll to the bottom after layout settles; a pending scroll is cancelled
  // if the deps change again (or the component unmounts) before it fires.
  useLayoutEffect(() => {
    const el = mainRef.current;
    if (!el) return undefined;
    const timer = setTimeout(() => {
      el.scrollTo({
        top: el.scrollHeight,
        behavior: "smooth",
      });
    }, 0);
    return () => clearTimeout(timer);
  }, [messages.length, lastMessageMinHeight]);

  /** Appends one streamed token to the given assistant message. */
  const handleStreamUpdate = useCallback(
    (assistantMessageId: number, newToken: string) => {
      setMessages((prev) => {
        const index = prev.findIndex((msg) => msg.id === assistantMessageId);
        const target = prev[index];
        // Message gone (chat was reset): drop late tokens from the old run.
        if (!target) return prev;

        const rawStream = (target.rawStream || "") + newToken;
        const { thinkTrace, content, thinkFinished } =
          splitAssistantStream(rawStream);

        const updated: Message = { ...target, rawStream, thinkTrace, content };
        if (thinkFinished && !updated.thinkEndTime) {
          updated.thinkEndTime = Date.now();
        }

        const next = prev.slice();
        next[index] = updated;
        return next;
      });
    },
    [],
  );

  const handleStopGeneration = useCallback(() => {
    stoppingCriteriaRef.current?.interrupt();
  }, []);

  /**
   * Runs generation for `historyForModel`, streaming into the assistant message
   * `assistantMessageId`. Always clears the loading state and the message's
   * streaming bookkeeping when done, including on error or early exit.
   */
  const streamAssistantResponse = useCallback(
    async (
      historyForModel: { role: Role; content: string }[],
      assistantMessageId: number,
    ) => {
      try {
        const tokenizer = tokenizerRef.current;
        const model = modelRef.current;
        if (!tokenizer || !model) {
          console.error("Generation requested before the model was loaded");
          return;
        }

        // eslint-disable-next-line @typescript-eslint/no-explicit-any -- return_dict shape is not typed precisely by the library
        const inputs = tokenizer.apply_chat_template(historyForModel, {
          add_generation_prompt: true,
          return_dict: true,
        }) as any;

        const streamer = new TextStreamer(tokenizer, {
          skip_prompt: true,
          skip_special_tokens: false,
          callback_function: (token: string) =>
            handleStreamUpdate(assistantMessageId, token),
        });

        const stoppingCriteria = new InterruptableStoppingCriteria();
        stoppingCriteriaRef.current = stoppingCriteria;

        await model.generate({
          ...inputs,
          max_new_tokens: 2048,
          streamer,
          stopping_criteria: stoppingCriteria,
          repetition_penalty: 1.2,
        });
      } catch (error) {
        console.error(error);
      } finally {
        stoppingCriteriaRef.current = null;
        setIsLoading(false);
        setMessages((prev) =>
          prev.map((msg) => {
            if (msg.id !== assistantMessageId) return msg;
            // Streaming bookkeeping is no longer needed once generation ends.
            const { rawStream: _rawStream, isLoading: _isLoading, ...rest } =
              msg;
            return rest;
          }),
        );
      }
    },
    [handleStreamUpdate, modelRef, tokenizerRef],
  );

  /**
   * Sends a user turn. `prompt`/`sources` override the composer's input and
   * context box (used by the example buttons).
   */
  const handleSubmit = async (
    e?: React.FormEvent,
    prompt?: string,
    sources?: string,
  ) => {
    if (e) e.preventDefault();
    if (isLoading || modelStatus !== "ready") return;

    const input = prompt || currentInput;
    if (!input.trim()) return;

    const trimmedContext = (sources || context).trim();
    const {
      payload: sourcesPayload,
      count: sourceCount,
      segments: sourceSegments,
    } = buildSourcesPayload(trimmedContext);

    // What the model sees: the raw question plus tagged sources.
    const fullPrompt = `${input}${sourcesPayload}`;

    // What the user sees: escaped question plus a short preview of each source.
    let userMessageContent = sanitizeInput(input);
    if (sourceCount > 0) {
      const sourcesList = sourceSegments
        .map(
          (seg, i) =>
            // Truncate before escaping so an entity is never cut in half.
            `${i + 1}. ${sanitizeInput(seg.substring(0, 75))}${seg.length > 75 ? "..." : ""}`,
        )
        .join("\n");
      userMessageContent += `\n\n[Source${sourceCount > 1 ? "s" : ""}]:\n${sourcesList}`;
    }

    const userMessage: Message = {
      id: nextMessageIdRef.current++,
      role: "user",
      content: userMessageContent,
    };
    const assistantPlaceholder: Message = {
      id: nextMessageIdRef.current++,
      role: "assistant",
      content: "",
      thinkTrace: "",
      rawStream: "",
      isLoading: true,
      timestamp: Date.now(),
    };
    modelPromptsRef.current.set(userMessage.id, fullPrompt);

    setMessages((prev) => [...prev, userMessage, assistantPlaceholder]);
    setCurrentInput("");
    setContext("");
    setShowContext(false);
    setIsLoading(true);
    setLastMessageMinHeight(undefined);

    const historyForModel: { role: Role; content: string }[] = [
      ...messages.map(({ id, role, content }) => ({
        role,
        content:
          role === "user"
            ? (modelPromptsRef.current.get(id) ?? content)
            : content,
      })),
      { role: "user", content: fullPrompt },
    ];

    await streamAssistantResponse(historyForModel, assistantPlaceholder.id);
  };

  /**
   * Clears the conversation. A running generation is interrupted, but
   * `isLoading` is left for that generation's `finally` to clear. Otherwise a
   * new request could start on the model while the old one is still winding down.
   */
  const handleNewChat = () => {
    handleStopGeneration();
    setMessages([]);
    modelPromptsRef.current.clear();
    setCurrentInput("");
    setContext("");
    setShowContext(false);
    setLastMessageMinHeight(undefined);
  };

  const exampleButtonClass =
    "bg-amber-100 hover:bg-amber-200 text-amber-800 px-4 py-2 rounded-lg shadow-sm border border-amber-200 transition-colors";

  return (
    <div className="flex flex-col h-screen bg-amber-50 font-sans text-stone-800">
      {modelStatus === "ready" && (
        <header className="flex-shrink-0 sticky top-0 z-10 flex items-center justify-between p-4 bg-white/90 backdrop-blur-md shadow-sm border-b border-amber-200 h-[100px]">
          <button
            onClick={handleNewChat}
            className="p-2 rounded-full text-stone-500 hover:text-amber-600 hover:bg-amber-50 transition-colors"
            title="New Chat"
          >
            <Plus size={20} />
          </button>
          <div className="flex-1 text-center">
            <h1 className="text-2xl md:text-3xl font-serif font-bold text-amber-800">
              🥖 Baguettotron WebGPU
            </h1>
            <p className="text-sm text-stone-600">
              A small but powerful reasoning model
            </p>
          </div>
        </header>
      )}

      <main ref={mainRef} className="flex-grow overflow-y-auto">
        <div className="mx-auto w-full max-w-6xl p-4 md:p-6 space-y-2 h-full">
          {modelStatus !== "ready" ? (
            <div className="flex h-full flex-col items-center justify-center gap-6 text-center text-stone-600">
              <span className="text-8xl animate-wobble">🥖</span>
              <div>
                <h1 className="text-5xl font-bold text-amber-800">
                  Baguettotron WebGPU
                </h1>
                <p className="mt-4 max-w-xl text-md">
                  You are about to load Baguettotron, a 300M parameter reasoning
                  model optimized for in-browser inference. Everything runs
                  entirely in your browser with 🤗 Transformers.js and ONNX
                  Runtime Web, meaning no data is sent to a server. Once loaded,
                  it can even be used offline.
                </p>
              </div>
              <div className="relative inline-flex rounded-full shadow-sm">
                <button
                  onClick={() => loadModel(selectedDtype)}
                  disabled={modelStatus === "loading"}
                  className="rounded-l-full bg-amber-600 pl-6 pr-5 py-3 text-white font-medium transition hover:bg-amber-700 disabled:opacity-50 disabled:cursor-not-allowed"
                >
                  {modelStatus === "loading"
                    ? `Loading ${loadProgress}%`
                    : `Load model (${selectedDtype})`}
                </button>
                <button
                  onClick={() => setDtypeMenuOpen(!dtypeMenuOpen)}
                  disabled={modelStatus === "loading"}
                  className="rounded-r-full bg-amber-600 px-3 py-3 text-white transition hover:bg-amber-700 disabled:opacity-50 border-l border-amber-500"
                >
                  <ChevronDown
                    size={20}
                    className={`transform transition-transform ${dtypeMenuOpen ? "rotate-180" : ""}`}
                  />
                </button>
                {dtypeMenuOpen && (
                  <div className="absolute top-full mt-2 w-full bg-white rounded-md shadow-lg z-10 border border-stone-200">
                    {Object.entries(DTYPES).map(([dtype, label]) => (
                      <button
                        key={dtype}
                        onClick={() => {
                          setSelectedDtype(dtype as Dtype);
                          setDtypeMenuOpen(false);
                        }}
                        className="w-full text-left px-4 py-2 text-sm text-stone-700 hover:bg-amber-50"
                      >
                        {label}
                      </button>
                    ))}
                  </div>
                )}
              </div>
              {modelStatus === "error" && (
                <p className="text-sm text-red-600">
                  Model load failed. Check console for details and retry.
                </p>
              )}
            </div>
          ) : (
            <>
              {messages.length === 0 && (
                <div className="flex flex-col items-center justify-center h-full text-center text-stone-500">
                  <div className="p-8 rounded-2xl flex flex-col items-center">
                    <h2 className="text-3xl font-semibold mt-4 text-stone-700">
                      Welcome to Baguettotron
                    </h2>
                    <h3 className="max-w-xs mt-1 text-lg">
                      Ask me a question, or try one of the examples below!
                    </h3>
                  </div>
                  <div className="mt-2 flex flex-wrap justify-center gap-4">
                    <button
                      onClick={() =>
                        handleSubmit(
                          undefined,
                          "What is the capital of France? Just provide the answer.",
                        )
                      }
                      className={exampleButtonClass}
                    >
                      Encyclopedic knowledge
                    </button>
                    {["fp32", "fp16"].includes(selectedDtype) && (
                      <button
                        onClick={() =>
                          handleSubmit(
                            undefined,
                            "Write me a short poem about machine learning.",
                          )
                        }
                        className={exampleButtonClass}
                      >
                        Creative writing
                      </button>
                    )}
                    <button
                      onClick={() =>
                        handleSubmit(
                          undefined,
                          "Which is wider: Australia or the Moon?",
                          "Australia is approximately 4,000 km in width from east to west, according to Geoscience Australia.\n\nThe diameter of the Moon is about 3,476 km, according to Britannica.",
                        )
                      }
                      className={exampleButtonClass}
                    >
                      RAG with grounding
                    </button>
                  </div>
                </div>
              )}
              {messages.map((msg, index) => {
                const isLastAssistantMessage =
                  index === messages.length - 1 && msg.role === "assistant";
                return (
                  <MessageBubble
                    key={msg.id}
                    message={msg}
                    minHeight={
                      isLastAssistantMessage ? lastMessageMinHeight : undefined
                    }
                  />
                );
              })}
            </>
          )}
        </div>
      </main>

      {modelStatus === "ready" && (
        <footer className="flex-shrink-0 sticky bottom-0 z-10 p-4 bg-white/90 backdrop-blur-md border-t border-amber-100">
          <form onSubmit={handleSubmit} className="max-w-3xl mx-auto">
            <div
              style={{
                maxHeight: showContext ? "120px" : "0px",
                transition: "max-height 0.3s ease-in-out",
                opacity: showContext ? 1 : 0,
              }}
              className="overflow-hidden relative"
            >
              <textarea
                value={context}
                onChange={(e) => setContext(e.target.value)}
                disabled={isLoading}
                placeholder="Add RAG context here. Separate multiple sources with two new lines."
                className="w-full h-28 p-2 mb-2 rounded-lg border border-stone-300 focus:ring-amber-500 focus:border-amber-500 text-sm resize-none"
              />
              <button
                type="button"
                onClick={() => setShowContext(false)}
                className="absolute top-2 right-2 p-1 text-stone-400 hover:text-stone-600 bg-white/50 rounded-full"
              >
                <X size={16} />
              </button>
            </div>

            <div className="flex items-center space-x-2">
              <button
                type="button"
                onClick={() => setShowContext(!showContext)}
                title="Add Context for RAG"
                className={`flex-shrink-0 p-2 rounded-full transition-colors ${
                  showContext
                    ? "bg-amber-100 text-amber-700"
                    : "text-stone-500 hover:text-amber-600 hover:bg-amber-50"
                }`}
              >
                <Paperclip size={20} />
              </button>
              <input
                type="text"
                value={currentInput}
                onChange={(e) => setCurrentInput(e.target.value)}
                placeholder="Send a message..."
                className="flex-grow px-4 py-2 rounded-full border border-stone-300 focus:ring-2 focus:ring-amber-500 focus:border-transparent outline-none transition-shadow"
                disabled={isLoading || modelStatus !== "ready"}
              />
              {isLoading ? (
                <button
                  type="button"
                  onClick={handleStopGeneration}
                  className="group flex h-10 w-10 flex-shrink-0 items-center justify-center rounded-full border border-stone-300 bg-white text-stone-600 hover:border-red-500"
                >
                  <span className="h-3.5 w-3.5 rounded-sm bg-stone-600 transition-colors group-hover:bg-red-500" />
                </button>
              ) : (
                <button
                  type="submit"
                  disabled={
                    isLoading || !currentInput.trim() || modelStatus !== "ready"
                  }
                  className="flex h-10 w-10 flex-shrink-0 items-center justify-center rounded-full bg-amber-600 text-white transition-all transform hover:bg-amber-700 hover:scale-105 active:scale-95 disabled:bg-stone-300 disabled:scale-100 disabled:cursor-not-allowed"
                >
                  <Send size={20} />
                </button>
              )}
            </div>
            <p className="text-center text-xs text-stone-400 mt-2">
              ⚡ Powered by{" "}
              <a
                href="https://github.com/huggingface/transformers.js"
                target="_blank"
                rel="noopener noreferrer"
              >
                Transformers.js
              </a>{" "}
              — Runs locally in your browser on WebGPU.
            </p>
          </form>
        </footer>
      )}
    </div>
  );
};

export default App;
|
|
|