feat: implement worker for LLM model loading and response generation with message handling
- src/App.tsx +55 -52
- src/components/ExamplePrompts.tsx +8 -2
- src/hooks/useLLM.ts +106 -211
- src/workers/llm.worker.ts +253 -0

src/App.tsx
CHANGED

@@ -3,6 +3,7 @@ import React, {
   useEffect,
   useCallback,
   useRef,
+  useMemo,
 } from "react";
 import { openDB, type IDBPDatabase } from "idb";
 import {

@@ -128,6 +129,59 @@ const App: React.FC = () => {
     connectAll: connectAllMCPServers,
   } = useMCP();

+  // Memoize tool schemas to avoid recalculating on every render
+  const toolSchemas = useMemo(() => {
+    return tools
+      .filter((tool) => tool.enabled)
+      .map((tool) => generateSchemaFromCode(tool.code));
+  }, [tools]);
+
+  // Memoize example prompts to prevent flickering
+  const examplePrompts = useMemo(() => {
+    const enabledTools = tools.filter((tool) => tool.enabled);
+
+    // Group tools by server (MCP tools have mcpServerId in their code)
+    const toolsByServer = enabledTools.reduce((acc, tool) => {
+      const mcpServerMatch = tool.code?.match(/mcpServerId: "([^"]+)"/);
+      const serverId = mcpServerMatch ? mcpServerMatch[1] : 'local';
+      if (!acc[serverId]) acc[serverId] = [];
+      acc[serverId].push(tool);
+      return acc;
+    }, {} as Record<string, typeof enabledTools>);
+
+    // Pick one tool from each server (up to 3 servers)
+    const serverIds = Object.keys(toolsByServer).slice(0, 3);
+    const selectedTools = serverIds.map(serverId => {
+      const serverTools = toolsByServer[serverId];
+      return serverTools[Math.floor(Math.random() * serverTools.length)];
+    });
+
+    return selectedTools.map((tool) => {
+      const schema = generateSchemaFromCode(tool.code);
+      const description = schema.description || tool.name;
+
+      // Create a cleaner natural language prompt
+      let displayText = description;
+      if (description !== tool.name) {
+        // If there's a description, make it conversational
+        displayText = description.charAt(0).toUpperCase() + description.slice(1);
+        if (!displayText.endsWith('?') && !displayText.endsWith('.')) {
+          displayText += '?';
+        }
+      } else {
+        // Fallback to tool name in a readable format
+        displayText = tool.name.replace(/_/g, ' ');
+        displayText = displayText.charAt(0).toUpperCase() + displayText.slice(1);
+      }
+
+      return {
+        icon: "🛠️",
+        displayText,
+        messageText: displayText,
+      };
+    });
+  }, [tools]);
+
   const loadTools = useCallback(async (): Promise<void> => {
     const db = await getDB();
     const allTools: Tool[] = await db.getAll(STORE_NAME);

@@ -405,10 +459,6 @@ const App: React.FC = () => {
     setIsGenerating(true);

     try {
-      const toolSchemas = tools
-        .filter((tool) => tool.enabled)
-        .map((tool) => generateSchemaFromCode(tool.code));
-
       while (true) {
         const messagesForGeneration = [
           { role: "system" as const, content: systemPrompt },

@@ -574,10 +624,6 @@ const App: React.FC = () => {
     setIsGenerating(true);

     try {
-      const toolSchemas = tools
-        .filter((tool) => tool.enabled)
-        .map((tool) => generateSchemaFromCode(tool.code));
-
       while (true) {
         const messagesForGeneration = [
           { role: "system" as const, content: systemPrompt },

@@ -716,50 +762,7 @@ const App: React.FC = () => {
         >
           {messages.length === 0 && isReady ? (
             <ExamplePrompts
-              examples={(() => {
-                const enabledTools = tools.filter((tool) => tool.enabled);
-
-                // Group tools by server (MCP tools have mcpServerId in their code)
-                const toolsByServer = enabledTools.reduce((acc, tool) => {
-                  const mcpServerMatch = tool.code?.match(/mcpServerId: "([^"]+)"/);
-                  const serverId = mcpServerMatch ? mcpServerMatch[1] : 'local';
-                  if (!acc[serverId]) acc[serverId] = [];
-                  acc[serverId].push(tool);
-                  return acc;
-                }, {} as Record<string, typeof enabledTools>);
-
-                // Pick one tool from each server (up to 3 servers)
-                const serverIds = Object.keys(toolsByServer).slice(0, 3);
-                const selectedTools = serverIds.map(serverId => {
-                  const serverTools = toolsByServer[serverId];
-                  return serverTools[Math.floor(Math.random() * serverTools.length)];
-                });
-
-                return selectedTools.map((tool) => {
-                  const schema = generateSchemaFromCode(tool.code);
-                  const description = schema.description || tool.name;
-
-                  // Create a cleaner natural language prompt
-                  let displayText = description;
-                  if (description !== tool.name) {
-                    // If there's a description, make it conversational
-                    displayText = description.charAt(0).toUpperCase() + description.slice(1);
-                    if (!displayText.endsWith('?') && !displayText.endsWith('.')) {
-                      displayText += '?';
-                    }
-                  } else {
-                    // Fallback to tool name in a readable format
-                    displayText = tool.name.replace(/_/g, ' ');
-                    displayText = displayText.charAt(0).toUpperCase() + displayText.slice(1);
-                  }
-
-                  return {
-                    icon: "🛠️",
-                    displayText,
-                    messageText: displayText,
-                  };
-                });
-              })()}
+              examples={examplePrompts}
               onExampleClick={handleExampleClick}
             />
           ) : (

src/components/ExamplePrompts.tsx
CHANGED

@@ -30,12 +30,18 @@ const ExamplePrompts: React.FC<ExamplePromptsProps> = ({
         <p className="text-sm text-gray-500">Click one to get started</p>
       </div>

-      <div className=…
+      <div className={`grid gap-3 max-w-4xl w-full px-4 ${
+        dynamicExamples.length === 1
+          ? 'grid-cols-1 justify-items-center'
+          : 'grid-cols-1 sm:grid-cols-2 lg:grid-cols-3'
+      }`}>
         {dynamicExamples.map((example, index) => (
           <button
             key={index}
             onClick={() => onExampleClick(example.messageText)}
-            className=…
+            className={`flex items-start gap-3 p-4 bg-gray-700 hover:bg-gray-600 rounded-lg transition-colors text-left group cursor-pointer ${
+              dynamicExamples.length === 1 ? 'max-w-md' : ''
+            }`}
           >
             <span className="text-xl flex-shrink-0 group-hover:scale-110 transition-transform">
               {example.icon}

src/hooks/useLLM.ts
CHANGED

@@ -1,9 +1,4 @@
 import { useState, useEffect, useRef, useCallback } from "react";
-import {
-  AutoModelForCausalLM,
-  AutoTokenizer,
-  TextStreamer,
-} from "@huggingface/transformers";

 interface LLMState {
   isLoading: boolean;

@@ -14,18 +9,6 @@ interface LLMState {
   numTokens: number;
 }

-interface LLMInstance {
-  model: any;
-  tokenizer: any;
-}
-
-let moduleCache: {
-  [modelId: string]: {
-    instance: LLMInstance | null;
-    loadingPromise: Promise<LLMInstance> | null;
-  };
-} = {};
-
 export const useLLM = (modelId?: string) => {
   const [state, setState] = useState<LLMState>({
     isLoading: false,

@@ -36,54 +19,92 @@ export const useLLM = (modelId?: string) => {
     numTokens: 0,
   });

-  …
+  const workerRef = useRef<Worker | null>(null);
+  const onTokenCallbackRef = useRef<((token: string) => void) | null>(null);
+  const resolveGenerationRef = useRef<((text: string) => void) | null>(null);
+  const rejectGenerationRef = useRef<((error: Error) => void) | null>(null);

-  …
+  // Initialize worker
+  useEffect(() => {
+    const worker = new Worker(
+      new URL("../workers/llm.worker.ts", import.meta.url),
+      { type: "module" }
+    );
+
+    workerRef.current = worker;
+
+    // Handle messages from worker
+    worker.onmessage = (e) => {
+      const message = e.data;
+
+      switch (message.type) {
+        case "progress":
+          setState((prev) => ({
+            ...prev,
+            progress: message.progress,
+            isLoading: true,
+          }));
+          break;
+
+        case "ready":
+          setState((prev) => ({
+            ...prev,
+            isLoading: false,
+            isReady: true,
+            progress: 100,
+          }));
+          break;
+
+        case "update":
+          setState((prev) => ({
+            ...prev,
+            tokensPerSecond: message.tokensPerSecond,
+            numTokens: message.numTokens,
+          }));
+          if (onTokenCallbackRef.current) {
+            onTokenCallbackRef.current(message.token);
+          }
+          break;

-  …
+        case "complete":
+          if (resolveGenerationRef.current) {
+            resolveGenerationRef.current(message.text);
+            resolveGenerationRef.current = null;
+            rejectGenerationRef.current = null;
+          }
+          break;
+
+        case "error":
+          setState((prev) => ({
+            ...prev,
+            isLoading: false,
+            error: message.error,
+          }));
+          if (rejectGenerationRef.current) {
+            rejectGenerationRef.current(new Error(message.error));
+            resolveGenerationRef.current = null;
+            rejectGenerationRef.current = null;
+          }
+          break;
      }
+    };
+
+    worker.onerror = (error) => {
+      setState((prev) => ({
+        ...prev,
+        isLoading: false,
+        error: error.message,
+      }));
+    };
+
+    return () => {
+      worker.terminate();
+    };
+  }, []);
+
+  const loadModel = useCallback(async () => {
+    if (!modelId || !workerRef.current) {
+      throw new Error("Model ID or worker not available");
    }

    setState((prev) => ({

@@ -93,76 +114,10 @@ export const useLLM = (modelId?: string) => {
      progress: 0,
    }));

-    …
-      const progressCallback = (progress: any) => {
-        // Only update progress for weights
-        if (
-          progress.status === "progress" &&
-          progress.file.endsWith(".onnx_data")
-        ) {
-          const percentage = Math.round(
-            (progress.loaded / progress.total) * 100
-          );
-          setState((prev) => ({ ...prev, progress: percentage }));
-        }
-      };
-
-      const tokenizer = await AutoTokenizer.from_pretrained(MODEL_ID, {
-        progress_callback: progressCallback,
-      });
-
-      const model = await AutoModelForCausalLM.from_pretrained(MODEL_ID, {
-        dtype: "q4f16",
-        device: "webgpu",
-        progress_callback: progressCallback,
-      });
-
-      // Pre-warm the model with a dummy input for shader compilation
-      console.log("Pre-warming model...");
-      const dummyInput = tokenizer("Hello", {
-        return_tensors: "pt",
-        padding: false,
-        truncation: false,
-      });
-      await model.generate({
-        ...dummyInput,
-        max_new_tokens: 1,
-        do_sample: false,
-      });
-      console.log("Model pre-warmed");
-
-      const instance = { model, tokenizer };
-      instanceRef.current = instance;
-      cache.instance = instance;
-      loadingPromiseRef.current = null;
-      cache.loadingPromise = null;
-
-      setState((prev) => ({
-        ...prev,
-        isLoading: false,
-        isReady: true,
-        progress: 100,
-      }));
-      return instance;
-    } catch (error) {
-      loadingPromiseRef.current = null;
-      cache.loadingPromise = null;
-      setState((prev) => ({
-        ...prev,
-        isLoading: false,
-        error:
-          error instanceof Error ? error.message : "Failed to load model",
-      }));
-      throw error;
-    }
-    })();
-
-    loadingPromiseRef.current = loadingPromise;
-    cache.loadingPromise = loadingPromise;
-    return loadingPromise;
+    workerRef.current.postMessage({
+      type: "load",
+      modelId,
+    });
  }, [modelId]);

  const generateResponse = useCallback(

@@ -171,104 +126,44 @@
      tools: Array<any>,
      onToken?: (token: string) => void
    ): Promise<string> => {
-      …
-        throw new Error("Model not loaded. Call loadModel() first.");
+      if (!workerRef.current) {
+        throw new Error("Worker not initialized");
      }

-      …
-      // Apply chat template with tools
-      const input = tokenizer.apply_chat_template(messages, {
-        tools,
-        add_generation_prompt: true,
-        return_dict: true,
-      });
-
-      // Track tokens and timing
-      const startTime = performance.now();
-      let tokenCount = 0;
-
-      const streamer = onToken
-        ? new TextStreamer(tokenizer, {
-            skip_prompt: true,
-            skip_special_tokens: false,
-            callback_function: (token: string) => {
-              tokenCount++;
-              const elapsed = (performance.now() - startTime) / 1000;
-              const tps = tokenCount / elapsed;
-              setState((prev) => ({
-                ...prev,
-                tokensPerSecond: tps,
-                numTokens: tokenCount,
-              }));
-              onToken(token);
-            },
-          })
-        : undefined;
+      return new Promise((resolve, reject) => {
+        onTokenCallbackRef.current = onToken || null;
+        resolveGenerationRef.current = resolve;
+        rejectGenerationRef.current = reject;

-      …
-        past_key_values: pastKeyValuesRef.current,
-        max_new_tokens: 1024,
-        do_sample: false,
-        streamer,
-        return_dict_in_generate: true,
+        workerRef.current!.postMessage({
+          type: "generate",
+          messages,
+          tools,
        });
-
-      // Decode the generated text with special tokens preserved (except end tokens) for tool call detection
-      const response = tokenizer
-        .batch_decode(sequences.slice(null, [input.input_ids.dims[1], null]), {
-          skip_special_tokens: false,
-        })[0]
-        .replace(/<\|im_end\|>$/, "")
-        .replace(/<\|end_of_text\|>$/, "");
-
-      return response;
-    } finally {
-      generationAbortControllerRef.current = null;
-    }
+      });
    },
    []
  );

  const interruptGeneration = useCallback(() => {
-    …
+    if (workerRef.current) {
+      workerRef.current.postMessage({ type: "interrupt" });
    }
  }, []);

  const clearPastKeyValues = useCallback(() => {
-    …
+    if (workerRef.current) {
+      workerRef.current.postMessage({ type: "reset" });
+    }
  }, []);

  const cleanup = useCallback(() => {
-    …
+    if (workerRef.current) {
+      workerRef.current.terminate();
+      workerRef.current = null;
    }
  }, []);

-  useEffect(() => {
-    return cleanup;
-  }, [cleanup]);
-
-  useEffect(() => {
-    if (modelId && moduleCache[modelId]) {
-      const existingInstance =
-        instanceRef.current || moduleCache[modelId].instance;
-      if (existingInstance) {
-        instanceRef.current = existingInstance;
-        setState((prev) => ({ ...prev, isReady: true }));
-      }
-    }
-  }, [modelId]);
-
  return {
    ...state,
    loadModel,
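
With this rewrite the hook is the only surface components touch: each streamed token surfaces through the onToken callback (driven by the worker's "update" messages), and the returned promise resolves when the worker posts "complete". A minimal consumer sketch follows; the component, import path, and model id are illustrative and not part of this commit:

import React, { useState } from "react";
import { useLLM } from "./useLLM"; // adjust the path to wherever the hook lives

// Hypothetical demo component; "example-org/example-model" is a placeholder id.
const WorkerChatDemo: React.FC = () => {
  const { isReady, isLoading, progress, loadModel, generateResponse } =
    useLLM("example-org/example-model");
  const [output, setOutput] = useState("");

  const ask = async () => {
    setOutput("");
    // Tokens arrive via the worker's "update" messages and are forwarded to
    // this callback; the promise resolves once the worker posts "complete".
    const text = await generateResponse(
      [{ role: "user", content: "Hello!" }],
      [], // no tool schemas in this sketch
      (token) => setOutput((prev) => prev + token)
    );
    console.log("final response:", text);
  };

  if (!isReady) {
    return (
      <button onClick={() => loadModel().catch(console.error)} disabled={isLoading}>
        {isLoading ? `Loading model… ${progress}%` : "Load model"}
      </button>
    );
  }

  return (
    <div>
      <button onClick={ask}>Ask</button>
      <pre>{output}</pre>
    </div>
  );
};

export default WorkerChatDemo;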

src/workers/llm.worker.ts
ADDED

@@ -0,0 +1,253 @@
+import {
+  AutoModelForCausalLM,
+  AutoTokenizer,
+  TextStreamer,
+} from "@huggingface/transformers";
+
+// Worker state
+let model: any = null;
+let tokenizer: any = null;
+let pastKeyValues: any = null;
+let isGenerating = false;
+
+// Cache for loaded models
+const modelCache: {
+  [modelId: string]: {
+    model: any;
+    tokenizer: any;
+  };
+} = {};
+
+// Message types from main thread
+interface LoadMessage {
+  type: "load";
+  modelId: string;
+}
+
+interface GenerateMessage {
+  type: "generate";
+  messages: Array<{ role: string; content: string }>;
+  tools: Array<any>;
+}
+
+interface InterruptMessage {
+  type: "interrupt";
+}
+
+interface ResetMessage {
+  type: "reset";
+}
+
+type WorkerMessage = LoadMessage | GenerateMessage | InterruptMessage | ResetMessage;
+
+// Message types to main thread
+interface ProgressMessage {
+  type: "progress";
+  progress: number;
+  file?: string;
+}
+
+interface ReadyMessage {
+  type: "ready";
+}
+
+interface UpdateMessage {
+  type: "update";
+  token: string;
+  tokensPerSecond: number;
+  numTokens: number;
+}
+
+interface CompleteMessage {
+  type: "complete";
+  text: string;
+}
+
+interface ErrorMessage {
+  type: "error";
+  error: string;
+}
+
+type WorkerResponse = ProgressMessage | ReadyMessage | UpdateMessage | CompleteMessage | ErrorMessage;
+
+function postMessage(message: WorkerResponse) {
+  self.postMessage(message);
+}
+
+// Load model
+async function loadModel(modelId: string) {
+  try {
+    // Check cache first
+    if (modelCache[modelId]) {
+      model = modelCache[modelId].model;
+      tokenizer = modelCache[modelId].tokenizer;
+      postMessage({ type: "ready" });
+      return;
+    }
+
+    const progressCallback = (progress: any) => {
+      if (
+        progress.status === "progress" &&
+        progress.file.endsWith(".onnx_data")
+      ) {
+        const percentage = Math.round(
+          (progress.loaded / progress.total) * 100
+        );
+        postMessage({
+          type: "progress",
+          progress: percentage,
+          file: progress.file,
+        });
+      }
+    };
+
+    // Load tokenizer
+    tokenizer = await AutoTokenizer.from_pretrained(modelId, {
+      progress_callback: progressCallback,
+    });
+
+    // Load model
+    model = await AutoModelForCausalLM.from_pretrained(modelId, {
+      dtype: "q4f16",
+      device: "webgpu",
+      progress_callback: progressCallback,
+    });
+
+    // Pre-warm the model with a dummy input for shader compilation
+    const dummyInput = tokenizer("Hello", {
+      return_tensors: "pt",
+      padding: false,
+      truncation: false,
+    });
+    await model.generate({
+      ...dummyInput,
+      max_new_tokens: 1,
+      do_sample: false,
+    });
+
+    // Cache the loaded model
+    modelCache[modelId] = { model, tokenizer };
+
+    postMessage({ type: "ready" });
+  } catch (error) {
+    postMessage({
+      type: "error",
+      error: error instanceof Error ? error.message : "Failed to load model",
+    });
+  }
+}
+
+// Generate response
+async function generate(
+  messages: Array<{ role: string; content: string }>,
+  tools: Array<any>
+) {
+  if (!model || !tokenizer) {
+    postMessage({ type: "error", error: "Model not loaded" });
+    return;
+  }
+
+  try {
+    isGenerating = true;
+
+    // Apply chat template with tools
+    const input = tokenizer.apply_chat_template(messages, {
+      tools,
+      add_generation_prompt: true,
+      return_dict: true,
+    });
+
+    // Track tokens and timing
+    const startTime = performance.now();
+    let tokenCount = 0;
+
+    const streamer = new TextStreamer(tokenizer, {
+      skip_prompt: true,
+      skip_special_tokens: false,
+      callback_function: (token: string) => {
+        if (!isGenerating) return; // Check if interrupted
+
+        tokenCount++;
+        const elapsed = (performance.now() - startTime) / 1000;
+        const tps = tokenCount / elapsed;
+
+        postMessage({
+          type: "update",
+          token,
+          tokensPerSecond: tps,
+          numTokens: tokenCount,
+        });
+      },
+    });
+
+    // Generate the response
+    const { sequences, past_key_values } = await model.generate({
+      ...input,
+      past_key_values: pastKeyValues,
+      max_new_tokens: 1024,
+      do_sample: false,
+      streamer,
+      return_dict_in_generate: true,
+    });
+
+    pastKeyValues = past_key_values;
+
+    // Decode the generated text
+    const response = tokenizer
+      .batch_decode(sequences.slice(null, [input.input_ids.dims[1], null]), {
+        skip_special_tokens: false,
+      })[0]
+      .replace(/<\|im_end\|>$/, "")
+      .replace(/<\|end_of_text\|>$/, "");
+
+    if (isGenerating) {
+      postMessage({ type: "complete", text: response });
+    }
+
+    isGenerating = false;
+  } catch (error) {
+    isGenerating = false;
+    postMessage({
+      type: "error",
+      error: error instanceof Error ? error.message : "Generation failed",
+    });
+  }
+}
+
+// Interrupt generation
+function interrupt() {
+  isGenerating = false;
+  // Send a completion message with empty text to resolve the promise
+  postMessage({ type: "complete", text: "" });
+}
+
+// Reset past key values
+function reset() {
+  pastKeyValues = null;
+}
+
+// Handle messages from main thread
+self.onmessage = async (e: MessageEvent<WorkerMessage>) => {
+  const message = e.data;
+
+  switch (message.type) {
+    case "load":
+      await loadModel(message.modelId);
+      break;
+
+    case "generate":
+      await generate(message.messages, message.tools);
+      break;
+
+    case "interrupt":
+      interrupt();
+      break;
+
+    case "reset":
+      reset();
+      break;
+  }
+};
+
+// Export for TypeScript
+export {};
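
For reference, the worker's message protocol can also be exercised without React. A minimal main-thread sketch is shown below; the worker URL and model id are placeholders, and the new URL(..., import.meta.url) form assumes a bundler such as Vite that supports module workers:

// Hypothetical standalone client for llm.worker.ts.
const worker = new Worker(
  new URL("./workers/llm.worker.ts", import.meta.url),
  { type: "module" }
);

let streamed = "";

worker.onmessage = (e: MessageEvent) => {
  const msg = e.data;
  switch (msg.type) {
    case "progress": // model weights downloading
      console.log(`download: ${msg.progress}%`);
      break;
    case "ready": // model loaded and pre-warmed; safe to generate
      worker.postMessage({
        type: "generate",
        messages: [{ role: "user", content: "Hello!" }],
        tools: [],
      });
      break;
    case "update": // one streamed token plus throughput stats
      streamed += msg.token;
      break;
    case "complete": // full decoded response
      console.log("streamed:", streamed, "final:", msg.text);
      break;
    case "error":
      console.error(msg.error);
      break;
  }
};

// Kick off loading; "interrupt" and "reset" messages can be posted at any time.
worker.postMessage({ type: "load", modelId: "example-org/example-model" });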