Spaces:
Running
Running
| import { useState, useEffect, useRef, useCallback } from "react"; | |
/**
 * Status snapshot of the LLM web worker, as tracked by the `useLLM` hook.
 */
type LLMState = {
  /** True while a model load is in flight (set by `loadModel` and by worker "progress" events). */
  isLoading: boolean;
  /** True once the worker has posted a "ready" message. */
  isReady: boolean;
  /** Most recent error message from the worker, or `null` when no error has been recorded. */
  error: string | null;
  /** Load progress: reset to 0 by `loadModel`, updated from worker "progress" events, set to 100 on "ready". */
  progress: number;
  /** Generation speed as last reported by the worker, or `null` before any "update" message. */
  tokensPerSecond: number | null;
  /** Token count as last reported by the worker's "update" messages. */
  numTokens: number;
};
| export const useLLM = (modelId?: string) => { | |
| const [state, setState] = useState<LLMState>({ | |
| isLoading: false, | |
| isReady: false, | |
| error: null, | |
| progress: 0, | |
| tokensPerSecond: null, | |
| numTokens: 0, | |
| }); | |
| const workerRef = useRef<Worker | null>(null); | |
| const onTokenCallbackRef = useRef<((token: string) => void) | null>(null); | |
| const resolveGenerationRef = useRef<((text: string) => void) | null>(null); | |
| const rejectGenerationRef = useRef<((error: Error) => void) | null>(null); | |
| // Initialize worker | |
| useEffect(() => { | |
| const worker = new Worker( | |
| new URL("../workers/llm.worker.ts", import.meta.url), | |
| { type: "module" } | |
| ); | |
| workerRef.current = worker; | |
| // Handle messages from worker | |
| worker.onmessage = (e) => { | |
| const message = e.data; | |
| switch (message.type) { | |
| case "progress": | |
| setState((prev) => ({ | |
| ...prev, | |
| progress: message.progress, | |
| isLoading: true, | |
| })); | |
| break; | |
| case "ready": | |
| setState((prev) => ({ | |
| ...prev, | |
| isLoading: false, | |
| isReady: true, | |
| progress: 100, | |
| })); | |
| break; | |
| case "update": | |
| setState((prev) => ({ | |
| ...prev, | |
| tokensPerSecond: message.tokensPerSecond, | |
| numTokens: message.numTokens, | |
| })); | |
| if (onTokenCallbackRef.current) { | |
| onTokenCallbackRef.current(message.token); | |
| } | |
| break; | |
| case "complete": | |
| if (resolveGenerationRef.current) { | |
| resolveGenerationRef.current(message.text); | |
| resolveGenerationRef.current = null; | |
| rejectGenerationRef.current = null; | |
| } | |
| break; | |
| case "error": | |
| setState((prev) => ({ | |
| ...prev, | |
| isLoading: false, | |
| error: message.error, | |
| })); | |
| if (rejectGenerationRef.current) { | |
| rejectGenerationRef.current(new Error(message.error)); | |
| resolveGenerationRef.current = null; | |
| rejectGenerationRef.current = null; | |
| } | |
| break; | |
| } | |
| }; | |
| worker.onerror = (error) => { | |
| setState((prev) => ({ | |
| ...prev, | |
| isLoading: false, | |
| error: error.message, | |
| })); | |
| }; | |
| return () => { | |
| worker.terminate(); | |
| }; | |
| }, []); | |
| const loadModel = useCallback(async () => { | |
| if (!modelId || !workerRef.current) { | |
| throw new Error("Model ID or worker not available"); | |
| } | |
| setState((prev) => ({ | |
| ...prev, | |
| isLoading: true, | |
| error: null, | |
| progress: 0, | |
| })); | |
| workerRef.current.postMessage({ | |
| type: "load", | |
| modelId, | |
| }); | |
| }, [modelId]); | |
| const generateResponse = useCallback( | |
| async ( | |
| messages: Array<{ role: string; content: string }>, | |
| tools: Array<any>, | |
| onToken?: (token: string) => void | |
| ): Promise<string> => { | |
| if (!workerRef.current) { | |
| throw new Error("Worker not initialized"); | |
| } | |
| return new Promise((resolve, reject) => { | |
| onTokenCallbackRef.current = onToken || null; | |
| resolveGenerationRef.current = resolve; | |
| rejectGenerationRef.current = reject; | |
| workerRef.current!.postMessage({ | |
| type: "generate", | |
| messages, | |
| tools, | |
| }); | |
| }); | |
| }, | |
| [] | |
| ); | |
| const interruptGeneration = useCallback(() => { | |
| if (workerRef.current) { | |
| workerRef.current.postMessage({ type: "interrupt" }); | |
| } | |
| }, []); | |
| const clearPastKeyValues = useCallback(() => { | |
| if (workerRef.current) { | |
| workerRef.current.postMessage({ type: "reset" }); | |
| } | |
| }, []); | |
| const cleanup = useCallback(() => { | |
| if (workerRef.current) { | |
| workerRef.current.terminate(); | |
| workerRef.current = null; | |
| } | |
| }, []); | |
| return { | |
| ...state, | |
| loadModel, | |
| generateResponse, | |
| clearPastKeyValues, | |
| cleanup, | |
| interruptGeneration, | |
| }; | |
| }; | |