// MCP-WebGPU: src/hooks/useLLM.ts
import { useState, useEffect, useRef, useCallback } from "react";

interface LLMState {
  isLoading: boolean; // model download/initialization in progress
  isReady: boolean; // model loaded and ready to generate
  error: string | null;
  progress: number; // load progress reported by the worker, 0-100
  tokensPerSecond: number | null; // decode throughput of the current generation
  numTokens: number; // tokens emitted so far in the current generation
}
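
// Shape of the messages this hook receives from the worker, inferred from the
// handlers in the effect below. This is a documentation sketch of the contract
// as this hook uses it; the worker may carry extra fields on these messages.
type WorkerMessage =
  | { type: "progress"; progress: number }
  | { type: "ready" }
  | { type: "update"; token: string; tokensPerSecond: number; numTokens: number }
  | { type: "complete"; text: string }
  | { type: "error"; error: string };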

export const useLLM = (modelId?: string) => {
  const [state, setState] = useState<LLMState>({
    isLoading: false,
    isReady: false,
    error: null,
    progress: 0,
    tokensPerSecond: null,
    numTokens: 0,
  });
  const workerRef = useRef<Worker | null>(null);
  // Streaming callback for the in-flight generation, plus the resolve/reject
  // pair of its promise, so worker messages can settle it from the handler.
  const onTokenCallbackRef = useRef<((token: string) => void) | null>(null);
  const resolveGenerationRef = useRef<((text: string) => void) | null>(null);
  const rejectGenerationRef = useRef<((error: Error) => void) | null>(null);

  // Spin up the worker once on mount and tear it down on unmount.
  useEffect(() => {
    const worker = new Worker(
      new URL("../workers/llm.worker.ts", import.meta.url),
      { type: "module" }
    );
    workerRef.current = worker;

    // Handle messages from the worker.
    worker.onmessage = (e: MessageEvent) => {
      const message = e.data as WorkerMessage;
      switch (message.type) {
        case "progress":
          setState((prev) => ({
            ...prev,
            progress: message.progress,
            isLoading: true,
          }));
          break;
        case "ready":
          setState((prev) => ({
            ...prev,
            isLoading: false,
            isReady: true,
            progress: 100,
          }));
          break;
        case "update":
          // One decoded token: update throughput stats and stream it out.
          setState((prev) => ({
            ...prev,
            tokensPerSecond: message.tokensPerSecond,
            numTokens: message.numTokens,
          }));
          if (onTokenCallbackRef.current) {
            onTokenCallbackRef.current(message.token);
          }
          break;
        case "complete":
          if (resolveGenerationRef.current) {
            resolveGenerationRef.current(message.text);
            resolveGenerationRef.current = null;
            rejectGenerationRef.current = null;
            onTokenCallbackRef.current = null;
          }
          break;
        case "error":
          setState((prev) => ({
            ...prev,
            isLoading: false,
            error: message.error,
          }));
          if (rejectGenerationRef.current) {
            rejectGenerationRef.current(new Error(message.error));
            resolveGenerationRef.current = null;
            rejectGenerationRef.current = null;
            onTokenCallbackRef.current = null;
          }
          break;
      }
    };

    worker.onerror = (error) => {
      setState((prev) => ({
        ...prev,
        isLoading: false,
        error: error.message,
      }));
    };

    return () => {
      // Settle any pending generation so its promise doesn't hang forever,
      // then terminate the worker.
      if (rejectGenerationRef.current) {
        rejectGenerationRef.current(new Error("Worker terminated"));
        resolveGenerationRef.current = null;
        rejectGenerationRef.current = null;
      }
      worker.terminate();
    };
  }, []);
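
  // For reference, a minimal sketch of the counterpart handler in
  // ../workers/llm.worker.ts that would satisfy this contract (illustrative
  // only; the real worker wires these cases to the actual model runtime):
  //
  //   self.onmessage = async (e: MessageEvent) => {
  //     switch (e.data.type) {
  //       case "load":
  //         // initialize the model for e.data.modelId, posting
  //         // { type: "progress", progress } along the way, then { type: "ready" }
  //         break;
  //       case "generate":
  //         // stream { type: "update", token, tokensPerSecond, numTokens },
  //         // then finish with { type: "complete", text }
  //         break;
  //       case "interrupt": // stop the current generation
  //       case "reset":     // drop cached past key values
  //         break;
  //     }
  //   };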

  // Ask the worker to load the model. Resolves immediately; progress and
  // readiness are reported back through state via "progress"/"ready" messages.
  const loadModel = useCallback(async () => {
    if (!modelId || !workerRef.current) {
      throw new Error("Model ID or worker not available");
    }
    setState((prev) => ({
      ...prev,
      isLoading: true,
      error: null,
      progress: 0,
    }));
    workerRef.current.postMessage({
      type: "load",
      modelId,
    });
  }, [modelId]);

  // Run one generation in the worker and resolve with the full text once the
  // "complete" message arrives. Tokens stream through onToken as they decode.
  // Only one generation can be in flight at a time: starting another while one
  // is pending would overwrite these refs and orphan the earlier promise.
  const generateResponse = useCallback(
    async (
      messages: Array<{ role: string; content: string }>,
      tools: Array<any>,
      onToken?: (token: string) => void
    ): Promise<string> => {
      if (!workerRef.current) {
        throw new Error("Worker not initialized");
      }
      return new Promise((resolve, reject) => {
        onTokenCallbackRef.current = onToken || null;
        resolveGenerationRef.current = resolve;
        rejectGenerationRef.current = reject;
        workerRef.current!.postMessage({
          type: "generate",
          messages,
          tools,
        });
      });
    },
    []
  );

  // Ask the worker to stop decoding mid-generation.
  const interruptGeneration = useCallback(() => {
    if (workerRef.current) {
      workerRef.current.postMessage({ type: "interrupt" });
    }
  }, []);

  // Drop the worker's cached past key values (KV cache) so the next
  // generation starts from a fresh context.
  const clearPastKeyValues = useCallback(() => {
    if (workerRef.current) {
      workerRef.current.postMessage({ type: "reset" });
    }
  }, []);

  // Tear down the worker explicitly (the unmount effect also terminates it).
  const cleanup = useCallback(() => {
    if (workerRef.current) {
      workerRef.current.terminate();
      workerRef.current = null;
    }
  }, []);

  return {
    ...state,
    loadModel,
    generateResponse,
    clearPastKeyValues,
    cleanup,
    interruptGeneration,
  };
};
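
// Example usage in a component (illustrative sketch; the component, model id,
// and state setter below are hypothetical, not part of this module):
//
//   function Chat() {
//     const { isReady, isLoading, progress, loadModel, generateResponse } =
//       useLLM("your-model-id");
//     const [output, setOutput] = useState("");
//
//     const send = async () => {
//       const text = await generateResponse(
//         [{ role: "user", content: "Hello" }],
//         [],
//         (token) => setOutput((prev) => prev + token)
//       );
//       console.log("final:", text);
//     };
//
//     return isReady ? (
//       <button onClick={send}>Send</button>
//     ) : (
//       <button onClick={loadModel} disabled={isLoading}>
//         {isLoading ? `Loading ${progress}%` : "Load model"}
//       </button>
//     );
//   }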