shreyask committed
Commit ebde7f4 · verified · 1 Parent(s): 9a9d1f0

feat: implement worker for LLM model loading and response generation with message handling
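
The core of the change is a message protocol between the useLLM hook on the main thread and a new Web Worker that owns the model. As an orientation aid, here is the protocol distilled from the definitions added in src/workers/llm.worker.ts below; the WorkerRequest/WorkerReply aliases are just shorthand for this summary (the file itself names them WorkerMessage and WorkerResponse):

// Requests posted from the main thread to the worker:
type WorkerRequest =
  | { type: "load"; modelId: string }
  | { type: "generate"; messages: Array<{ role: string; content: string }>; tools: Array<any> }
  | { type: "interrupt" }
  | { type: "reset" };

// Replies posted from the worker back to the main thread:
type WorkerReply =
  | { type: "progress"; progress: number; file?: string }
  | { type: "ready" }
  | { type: "update"; token: string; tokensPerSecond: number; numTokens: number }
  | { type: "complete"; text: string }
  | { type: "error"; error: string };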

src/App.tsx CHANGED

@@ -3,6 +3,7 @@ import React, {
   useEffect,
   useCallback,
   useRef,
+  useMemo,
 } from "react";
 import { openDB, type IDBPDatabase } from "idb";
 import {
@@ -128,6 +129,59 @@ const App: React.FC = () => {
     connectAll: connectAllMCPServers,
   } = useMCP();

+  // Memoize tool schemas to avoid recalculating on every render
+  const toolSchemas = useMemo(() => {
+    return tools
+      .filter((tool) => tool.enabled)
+      .map((tool) => generateSchemaFromCode(tool.code));
+  }, [tools]);
+
+  // Memoize example prompts to prevent flickering
+  const examplePrompts = useMemo(() => {
+    const enabledTools = tools.filter((tool) => tool.enabled);
+
+    // Group tools by server (MCP tools have mcpServerId in their code)
+    const toolsByServer = enabledTools.reduce((acc, tool) => {
+      const mcpServerMatch = tool.code?.match(/mcpServerId: "([^"]+)"/);
+      const serverId = mcpServerMatch ? mcpServerMatch[1] : 'local';
+      if (!acc[serverId]) acc[serverId] = [];
+      acc[serverId].push(tool);
+      return acc;
+    }, {} as Record<string, typeof enabledTools>);
+
+    // Pick one tool from each server (up to 3 servers)
+    const serverIds = Object.keys(toolsByServer).slice(0, 3);
+    const selectedTools = serverIds.map(serverId => {
+      const serverTools = toolsByServer[serverId];
+      return serverTools[Math.floor(Math.random() * serverTools.length)];
+    });
+
+    return selectedTools.map((tool) => {
+      const schema = generateSchemaFromCode(tool.code);
+      const description = schema.description || tool.name;
+
+      // Create a cleaner natural language prompt
+      let displayText = description;
+      if (description !== tool.name) {
+        // If there's a description, make it conversational
+        displayText = description.charAt(0).toUpperCase() + description.slice(1);
+        if (!displayText.endsWith('?') && !displayText.endsWith('.')) {
+          displayText += '?';
+        }
+      } else {
+        // Fallback to tool name in a readable format
+        displayText = tool.name.replace(/_/g, ' ');
+        displayText = displayText.charAt(0).toUpperCase() + displayText.slice(1);
+      }
+
+      return {
+        icon: "🛠️",
+        displayText,
+        messageText: displayText,
+      };
+    });
+  }, [tools]);
+
   const loadTools = useCallback(async (): Promise<void> => {
     const db = await getDB();
     const allTools: Tool[] = await db.getAll(STORE_NAME);
@@ -405,10 +459,6 @@ const App: React.FC = () => {
     setIsGenerating(true);

     try {
-      const toolSchemas = tools
-        .filter((tool) => tool.enabled)
-        .map((tool) => generateSchemaFromCode(tool.code));
-
       while (true) {
         const messagesForGeneration = [
           { role: "system" as const, content: systemPrompt },
@@ -574,10 +624,6 @@ const App: React.FC = () => {
     setIsGenerating(true);

     try {
-      const toolSchemas = tools
-        .filter((tool) => tool.enabled)
-        .map((tool) => generateSchemaFromCode(tool.code));
-
       while (true) {
         const messagesForGeneration = [
           { role: "system" as const, content: systemPrompt },
@@ -716,50 +762,7 @@ const App: React.FC = () => {
         >
           {messages.length === 0 && isReady ? (
             <ExamplePrompts
-              examples={(() => {
-                const enabledTools = tools.filter((tool) => tool.enabled);
-
-                // Group tools by server (MCP tools have mcpServerId in their code)
-                const toolsByServer = enabledTools.reduce((acc, tool) => {
-                  const mcpServerMatch = tool.code?.match(/mcpServerId: "([^"]+)"/);
-                  const serverId = mcpServerMatch ? mcpServerMatch[1] : 'local';
-                  if (!acc[serverId]) acc[serverId] = [];
-                  acc[serverId].push(tool);
-                  return acc;
-                }, {} as Record<string, typeof enabledTools>);
-
-                // Pick one tool from each server (up to 3 servers)
-                const serverIds = Object.keys(toolsByServer).slice(0, 3);
-                const selectedTools = serverIds.map(serverId => {
-                  const serverTools = toolsByServer[serverId];
-                  return serverTools[Math.floor(Math.random() * serverTools.length)];
-                });
-
-                return selectedTools.map((tool) => {
-                  const schema = generateSchemaFromCode(tool.code);
-                  const description = schema.description || tool.name;
-
-                  // Create a cleaner natural language prompt
-                  let displayText = description;
-                  if (description !== tool.name) {
-                    // If there's a description, make it conversational
-                    displayText = description.charAt(0).toUpperCase() + description.slice(1);
-                    if (!displayText.endsWith('?') && !displayText.endsWith('.')) {
-                      displayText += '?';
-                    }
-                  } else {
-                    // Fallback to tool name in a readable format
-                    displayText = tool.name.replace(/_/g, ' ');
-                    displayText = displayText.charAt(0).toUpperCase() + displayText.slice(1);
-                  }
-
-                  return {
-                    icon: "🛠️",
-                    displayText,
-                    messageText: displayText,
-                  };
-                });
-              })()}
+              examples={examplePrompts}
               onExampleClick={handleExampleClick}
             />
           ) : (
src/components/ExamplePrompts.tsx CHANGED

@@ -30,12 +30,18 @@ const ExamplePrompts: React.FC<ExamplePromptsProps> = ({
         <p className="text-sm text-gray-500">Click one to get started</p>
       </div>

-      <div className="grid grid-cols-1 sm:grid-cols-2 lg:grid-cols-3 gap-3 max-w-4xl w-full px-4">
+      <div className={`grid gap-3 max-w-4xl w-full px-4 ${
+        dynamicExamples.length === 1
+          ? 'grid-cols-1 justify-items-center'
+          : 'grid-cols-1 sm:grid-cols-2 lg:grid-cols-3'
+      }`}>
         {dynamicExamples.map((example, index) => (
           <button
             key={index}
             onClick={() => onExampleClick(example.messageText)}
-            className="flex items-start gap-3 p-4 bg-gray-700 hover:bg-gray-600 rounded-lg transition-colors text-left group cursor-pointer"
+            className={`flex items-start gap-3 p-4 bg-gray-700 hover:bg-gray-600 rounded-lg transition-colors text-left group cursor-pointer ${
+              dynamicExamples.length === 1 ? 'max-w-md' : ''
+            }`}
           >
             <span className="text-xl flex-shrink-0 group-hover:scale-110 transition-transform">
               {example.icon}
src/hooks/useLLM.ts CHANGED

@@ -1,9 +1,4 @@
 import { useState, useEffect, useRef, useCallback } from "react";
-import {
-  AutoModelForCausalLM,
-  AutoTokenizer,
-  TextStreamer,
-} from "@huggingface/transformers";

 interface LLMState {
   isLoading: boolean;
@@ -14,18 +9,6 @@ interface LLMState {
   numTokens: number;
 }

-interface LLMInstance {
-  model: any;
-  tokenizer: any;
-}
-
-let moduleCache: {
-  [modelId: string]: {
-    instance: LLMInstance | null;
-    loadingPromise: Promise<LLMInstance> | null;
-  };
-} = {};
-
 export const useLLM = (modelId?: string) => {
   const [state, setState] = useState<LLMState>({
     isLoading: false,
@@ -36,54 +19,92 @@ export const useLLM = (modelId?: string) => {
     numTokens: 0,
   });

-  const instanceRef = useRef<LLMInstance | null>(null);
-  const loadingPromiseRef = useRef<Promise<LLMInstance> | null>(null);
+  const workerRef = useRef<Worker | null>(null);
+  const onTokenCallbackRef = useRef<((token: string) => void) | null>(null);
+  const resolveGenerationRef = useRef<((text: string) => void) | null>(null);
+  const rejectGenerationRef = useRef<((error: Error) => void) | null>(null);

-  const abortControllerRef = useRef<AbortController | null>(null);
-  const pastKeyValuesRef = useRef<any>(null);
-  const generationAbortControllerRef = useRef<AbortController | null>(null);
-
-  const loadModel = useCallback(async () => {
-    if (!modelId) {
-      throw new Error("Model ID is required");
-    }
-
-    const MODEL_ID = modelId;
-
-    if (!moduleCache[modelId]) {
-      moduleCache[modelId] = {
-        instance: null,
-        loadingPromise: null,
-      };
-    }
-
-    const cache = moduleCache[modelId];
-
-    const existingInstance = instanceRef.current || cache.instance;
-    if (existingInstance) {
-      instanceRef.current = existingInstance;
-      cache.instance = existingInstance;
-      setState((prev) => ({ ...prev, isReady: true, isLoading: false }));
-      return existingInstance;
-    }
+  // Initialize worker
+  useEffect(() => {
+    const worker = new Worker(
+      new URL("../workers/llm.worker.ts", import.meta.url),
+      { type: "module" }
+    );
+
+    workerRef.current = worker;
+
+    // Handle messages from worker
+    worker.onmessage = (e) => {
+      const message = e.data;
+
+      switch (message.type) {
+        case "progress":
+          setState((prev) => ({
+            ...prev,
+            progress: message.progress,
+            isLoading: true,
+          }));
+          break;
+
+        case "ready":
+          setState((prev) => ({
+            ...prev,
+            isLoading: false,
+            isReady: true,
+            progress: 100,
+          }));
+          break;
+
+        case "update":
+          setState((prev) => ({
+            ...prev,
+            tokensPerSecond: message.tokensPerSecond,
+            numTokens: message.numTokens,
+          }));
+          if (onTokenCallbackRef.current) {
+            onTokenCallbackRef.current(message.token);
+          }
+          break;

-    const existingPromise = loadingPromiseRef.current || cache.loadingPromise;
-    if (existingPromise) {
-      try {
-        const instance = await existingPromise;
-        instanceRef.current = instance;
-        cache.instance = instance;
-        setState((prev) => ({ ...prev, isReady: true, isLoading: false }));
-        return instance;
-      } catch (error) {
-        setState((prev) => ({
-          ...prev,
-          isLoading: false,
-          error:
-            error instanceof Error ? error.message : "Failed to load model",
-        }));
-        throw error;
+        case "complete":
+          if (resolveGenerationRef.current) {
+            resolveGenerationRef.current(message.text);
+            resolveGenerationRef.current = null;
+            rejectGenerationRef.current = null;
+          }
+          break;
+
+        case "error":
+          setState((prev) => ({
+            ...prev,
+            isLoading: false,
+            error: message.error,
+          }));
+          if (rejectGenerationRef.current) {
+            rejectGenerationRef.current(new Error(message.error));
+            resolveGenerationRef.current = null;
+            rejectGenerationRef.current = null;
+          }
+          break;
       }
+    };
+
+    worker.onerror = (error) => {
+      setState((prev) => ({
+        ...prev,
+        isLoading: false,
+        error: error.message,
+      }));
+    };
+
+    return () => {
+      worker.terminate();
+    };
+  }, []);
+
+  const loadModel = useCallback(async () => {
+    if (!modelId || !workerRef.current) {
+      throw new Error("Model ID or worker not available");
     }

     setState((prev) => ({
@@ -93,76 +114,10 @@ export const useLLM = (modelId?: string) => {
       progress: 0,
     }));

-    abortControllerRef.current = new AbortController();
-
-    const loadingPromise = (async () => {
-      try {
-        const progressCallback = (progress: any) => {
-          // Only update progress for weights
-          if (
-            progress.status === "progress" &&
-            progress.file.endsWith(".onnx_data")
-          ) {
-            const percentage = Math.round(
-              (progress.loaded / progress.total) * 100
-            );
-            setState((prev) => ({ ...prev, progress: percentage }));
-          }
-        };
-
-        const tokenizer = await AutoTokenizer.from_pretrained(MODEL_ID, {
-          progress_callback: progressCallback,
-        });
-
-        const model = await AutoModelForCausalLM.from_pretrained(MODEL_ID, {
-          dtype: "q4f16",
-          device: "webgpu",
-          progress_callback: progressCallback,
-        });
-
-        // Pre-warm the model with a dummy input for shader compilation
-        console.log("Pre-warming model...");
-        const dummyInput = tokenizer("Hello", {
-          return_tensors: "pt",
-          padding: false,
-          truncation: false,
-        });
-        await model.generate({
-          ...dummyInput,
-          max_new_tokens: 1,
-          do_sample: false,
-        });
-        console.log("Model pre-warmed");
-
-        const instance = { model, tokenizer };
-        instanceRef.current = instance;
-        cache.instance = instance;
-        loadingPromiseRef.current = null;
-        cache.loadingPromise = null;
-
-        setState((prev) => ({
-          ...prev,
-          isLoading: false,
-          isReady: true,
-          progress: 100,
-        }));
-        return instance;
-      } catch (error) {
-        loadingPromiseRef.current = null;
-        cache.loadingPromise = null;
-        setState((prev) => ({
-          ...prev,
-          isLoading: false,
-          error:
-            error instanceof Error ? error.message : "Failed to load model",
-        }));
-        throw error;
-      }
-    })();
-
-    loadingPromiseRef.current = loadingPromise;
-    cache.loadingPromise = loadingPromise;
-    return loadingPromise;
+    workerRef.current.postMessage({
+      type: "load",
+      modelId,
+    });
   }, [modelId]);

   const generateResponse = useCallback(
@@ -171,104 +126,44 @@ export const useLLM = (modelId?: string) => {
       tools: Array<any>,
       onToken?: (token: string) => void
     ): Promise<string> => {
-      const instance = instanceRef.current;
-      if (!instance) {
-        throw new Error("Model not loaded. Call loadModel() first.");
+      if (!workerRef.current) {
+        throw new Error("Worker not initialized");
       }

-      const { model, tokenizer } = instance;
-
-      // Create abort controller for this generation
-      generationAbortControllerRef.current = new AbortController();
-
-      // Apply chat template with tools
-      const input = tokenizer.apply_chat_template(messages, {
-        tools,
-        add_generation_prompt: true,
-        return_dict: true,
-      });
-
-      // Track tokens and timing
-      const startTime = performance.now();
-      let tokenCount = 0;
-
-      const streamer = onToken
-        ? new TextStreamer(tokenizer, {
-            skip_prompt: true,
-            skip_special_tokens: false,
-            callback_function: (token: string) => {
-              tokenCount++;
-              const elapsed = (performance.now() - startTime) / 1000;
-              const tps = tokenCount / elapsed;
-              setState((prev) => ({
-                ...prev,
-                tokensPerSecond: tps,
-                numTokens: tokenCount,
-              }));
-              onToken(token);
-            },
-          })
-        : undefined;
+      return new Promise((resolve, reject) => {
+        onTokenCallbackRef.current = onToken || null;
+        resolveGenerationRef.current = resolve;
+        rejectGenerationRef.current = reject;

-      try {
-        // Generate the response
-        const { sequences, past_key_values } = await model.generate({
-          ...input,
-          past_key_values: pastKeyValuesRef.current,
-          max_new_tokens: 1024,
-          do_sample: false,
-          streamer,
-          return_dict_in_generate: true,
+        workerRef.current!.postMessage({
+          type: "generate",
+          messages,
+          tools,
         });
-        pastKeyValuesRef.current = past_key_values;
-
-        // Decode the generated text with special tokens preserved (except end tokens) for tool call detection
-        const response = tokenizer
-          .batch_decode(sequences.slice(null, [input.input_ids.dims[1], null]), {
-            skip_special_tokens: false,
-          })[0]
-          .replace(/<\|im_end\|>$/, "")
-          .replace(/<\|end_of_text\|>$/, "");
-
-        return response;
-      } finally {
-        generationAbortControllerRef.current = null;
-      }
+      });
     },
     []
   );

   const interruptGeneration = useCallback(() => {
-    if (generationAbortControllerRef.current) {
-      generationAbortControllerRef.current.abort();
+    if (workerRef.current) {
+      workerRef.current.postMessage({ type: "interrupt" });
     }
   }, []);

   const clearPastKeyValues = useCallback(() => {
-    pastKeyValuesRef.current = null;
+    if (workerRef.current) {
+      workerRef.current.postMessage({ type: "reset" });
+    }
   }, []);

   const cleanup = useCallback(() => {
-    if (abortControllerRef.current) {
-      abortControllerRef.current.abort();
+    if (workerRef.current) {
+      workerRef.current.terminate();
+      workerRef.current = null;
     }
   }, []);

-  useEffect(() => {
-    return cleanup;
-  }, [cleanup]);
-
-  useEffect(() => {
-    if (modelId && moduleCache[modelId]) {
-      const existingInstance =
-        instanceRef.current || moduleCache[modelId].instance;
-      if (existingInstance) {
-        instanceRef.current = existingInstance;
-        setState((prev) => ({ ...prev, isReady: true }));
-      }
-    }
-  }, [modelId]);
-
   return {
     ...state,
     loadModel,
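
The hook's public surface is unchanged by this commit, so existing callers keep working. For context, a minimal consumer might look like the sketch below; the component, import path, model id, and prompt are illustrative only, and it assumes the hook still returns generateResponse, isReady, and progress alongside loadModel as it did before this change:

import React, { useEffect } from "react";
import { useLLM } from "./useLLM"; // path is illustrative

// Hypothetical consumer; names and model id are placeholders, not part of this commit.
const DemoChat: React.FC = () => {
  const { isReady, progress, loadModel, generateResponse } = useLLM("your-model-id");

  useEffect(() => {
    loadModel().catch(console.error); // posts { type: "load" } to the worker
  }, [loadModel]);

  const ask = async () => {
    // Resolves when the worker posts { type: "complete" }; tokens stream via "update".
    const text = await generateResponse(
      [{ role: "user", content: "Hello!" }],
      [], // tool schemas
      (token) => console.log(token)
    );
    console.log("final:", text);
  };

  return (
    <button onClick={ask} disabled={!isReady}>
      {isReady ? "Ask" : `Loading model… ${progress}%`}
    </button>
  );
};

export default DemoChat;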
src/workers/llm.worker.ts ADDED

@@ -0,0 +1,253 @@
+import {
+  AutoModelForCausalLM,
+  AutoTokenizer,
+  TextStreamer,
+} from "@huggingface/transformers";
+
+// Worker state
+let model: any = null;
+let tokenizer: any = null;
+let pastKeyValues: any = null;
+let isGenerating = false;
+
+// Cache for loaded models
+const modelCache: {
+  [modelId: string]: {
+    model: any;
+    tokenizer: any;
+  };
+} = {};
+
+// Message types from main thread
+interface LoadMessage {
+  type: "load";
+  modelId: string;
+}
+
+interface GenerateMessage {
+  type: "generate";
+  messages: Array<{ role: string; content: string }>;
+  tools: Array<any>;
+}
+
+interface InterruptMessage {
+  type: "interrupt";
+}
+
+interface ResetMessage {
+  type: "reset";
+}
+
+type WorkerMessage = LoadMessage | GenerateMessage | InterruptMessage | ResetMessage;
+
+// Message types to main thread
+interface ProgressMessage {
+  type: "progress";
+  progress: number;
+  file?: string;
+}
+
+interface ReadyMessage {
+  type: "ready";
+}
+
+interface UpdateMessage {
+  type: "update";
+  token: string;
+  tokensPerSecond: number;
+  numTokens: number;
+}
+
+interface CompleteMessage {
+  type: "complete";
+  text: string;
+}
+
+interface ErrorMessage {
+  type: "error";
+  error: string;
+}
+
+type WorkerResponse = ProgressMessage | ReadyMessage | UpdateMessage | CompleteMessage | ErrorMessage;
+
+function postMessage(message: WorkerResponse) {
+  self.postMessage(message);
+}
+
+// Load model
+async function loadModel(modelId: string) {
+  try {
+    // Check cache first
+    if (modelCache[modelId]) {
+      model = modelCache[modelId].model;
+      tokenizer = modelCache[modelId].tokenizer;
+      postMessage({ type: "ready" });
+      return;
+    }
+
+    const progressCallback = (progress: any) => {
+      if (
+        progress.status === "progress" &&
+        progress.file.endsWith(".onnx_data")
+      ) {
+        const percentage = Math.round(
+          (progress.loaded / progress.total) * 100
+        );
+        postMessage({
+          type: "progress",
+          progress: percentage,
+          file: progress.file,
+        });
+      }
+    };
+
+    // Load tokenizer
+    tokenizer = await AutoTokenizer.from_pretrained(modelId, {
+      progress_callback: progressCallback,
+    });
+
+    // Load model
+    model = await AutoModelForCausalLM.from_pretrained(modelId, {
+      dtype: "q4f16",
+      device: "webgpu",
+      progress_callback: progressCallback,
+    });
+
+    // Pre-warm the model with a dummy input for shader compilation
+    const dummyInput = tokenizer("Hello", {
+      return_tensors: "pt",
+      padding: false,
+      truncation: false,
+    });
+    await model.generate({
+      ...dummyInput,
+      max_new_tokens: 1,
+      do_sample: false,
+    });
+
+    // Cache the loaded model
+    modelCache[modelId] = { model, tokenizer };
+
+    postMessage({ type: "ready" });
+  } catch (error) {
+    postMessage({
+      type: "error",
+      error: error instanceof Error ? error.message : "Failed to load model",
+    });
+  }
+}
+
+// Generate response
+async function generate(
+  messages: Array<{ role: string; content: string }>,
+  tools: Array<any>
+) {
+  if (!model || !tokenizer) {
+    postMessage({ type: "error", error: "Model not loaded" });
+    return;
+  }
+
+  try {
+    isGenerating = true;
+
+    // Apply chat template with tools
+    const input = tokenizer.apply_chat_template(messages, {
+      tools,
+      add_generation_prompt: true,
+      return_dict: true,
+    });
+
+    // Track tokens and timing
+    const startTime = performance.now();
+    let tokenCount = 0;
+
+    const streamer = new TextStreamer(tokenizer, {
+      skip_prompt: true,
+      skip_special_tokens: false,
+      callback_function: (token: string) => {
+        if (!isGenerating) return; // Check if interrupted
+
+        tokenCount++;
+        const elapsed = (performance.now() - startTime) / 1000;
+        const tps = tokenCount / elapsed;
+
+        postMessage({
+          type: "update",
+          token,
+          tokensPerSecond: tps,
+          numTokens: tokenCount,
+        });
+      },
+    });
+
+    // Generate the response
+    const { sequences, past_key_values } = await model.generate({
+      ...input,
+      past_key_values: pastKeyValues,
+      max_new_tokens: 1024,
+      do_sample: false,
+      streamer,
+      return_dict_in_generate: true,
+    });
+
+    pastKeyValues = past_key_values;
+
+    // Decode the generated text
+    const response = tokenizer
+      .batch_decode(sequences.slice(null, [input.input_ids.dims[1], null]), {
+        skip_special_tokens: false,
+      })[0]
+      .replace(/<\|im_end\|>$/, "")
+      .replace(/<\|end_of_text\|>$/, "");
+
+    if (isGenerating) {
+      postMessage({ type: "complete", text: response });
+    }
+
+    isGenerating = false;
+  } catch (error) {
+    isGenerating = false;
+    postMessage({
+      type: "error",
+      error: error instanceof Error ? error.message : "Generation failed",
+    });
+  }
+}
+
+// Interrupt generation
+function interrupt() {
+  isGenerating = false;
+  // Send a completion message with empty text to resolve the promise
+  postMessage({ type: "complete", text: "" });
+}
+
+// Reset past key values
+function reset() {
+  pastKeyValues = null;
+}
+
+// Handle messages from main thread
+self.onmessage = async (e: MessageEvent<WorkerMessage>) => {
+  const message = e.data;
+
+  switch (message.type) {
+    case "load":
+      await loadModel(message.modelId);
+      break;
+
+    case "generate":
+      await generate(message.messages, message.tools);
+      break;
+
+    case "interrupt":
+      interrupt();
+      break;
+
+    case "reset":
+      reset();
+      break;
+  }
+};
+
+// Export for TypeScript
+export {};
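
For completeness, the worker can also be driven directly with raw postMessage calls, which is handy for quick manual testing outside the hook. A minimal sketch, assuming it is instantiated as a module-type worker exactly as useLLM does above; the model id is a placeholder:

// Hypothetical test harness; mirrors the protocol handled by llm.worker.ts.
const worker = new Worker(new URL("./llm.worker.ts", import.meta.url), {
  type: "module",
});

worker.onmessage = (e: MessageEvent) => {
  const msg = e.data;
  switch (msg.type) {
    case "progress":
      console.log(`downloading weights: ${msg.progress}%`);
      break;
    case "ready":
      // Model is loaded and pre-warmed; kick off a generation.
      worker.postMessage({
        type: "generate",
        messages: [{ role: "user", content: "Hi there" }],
        tools: [],
      });
      break;
    case "update":
      console.log("token:", msg.token, `(${msg.tokensPerSecond.toFixed(1)} tok/s)`);
      break;
    case "complete":
      console.log("full response:", msg.text);
      break;
    case "error":
      console.error("worker error:", msg.error);
      break;
  }
};

worker.postMessage({ type: "load", modelId: "your-model-id" }); // placeholder id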