import {
  AutoTokenizer,
  AutoModelForCausalLM,
  TextStreamer,
  InterruptableStoppingCriteria,
} from "@huggingface/transformers";
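
// This script runs in a Web Worker: it lazily loads the SmolLM3 ONNX model
// with Transformers.js, runs generation on WebGPU, and streams tokens back to
// the main thread via postMessage.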

/**
 * Helper function to perform feature detection for WebGPU
 */
async function check() {
  try {
    if (!navigator.gpu) {
      throw new Error("WebGPU is not supported by this browser");
    }
    const adapter = await navigator.gpu.requestAdapter();
    if (!adapter) {
      throw new Error("WebGPU is not supported (no adapter found)");
    }
    if (!adapter.features.has("shader-f16")) {
      throw new Error("shader-f16 is not supported in this browser");
    }
  } catch (e) {
    self.postMessage({
      status: "error",
      data: e.toString(),
    });
  }
}

/**
 * This class uses the Singleton pattern to enable lazy-loading of the pipeline
 */
class TextGenerationPipeline {
  static model_id = "HuggingFaceTB/SmolLM3-3B-ONNX";

  static async getInstance(progress_callback = null) {
    this.tokenizer ??= AutoTokenizer.from_pretrained(this.model_id, {
      progress_callback,
    });

    this.model ??= AutoModelForCausalLM.from_pretrained(this.model_id, {
      dtype: "q4f16",
      device: "webgpu",
      progress_callback,
    });

    return Promise.all([this.tokenizer, this.model]);
  }
}
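
// A single stopping-criteria instance is shared so an in-flight generation can
// be interrupted from the main thread, and the key/value cache is kept between
// turns so multi-turn chats do not re-process earlier tokens.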
const stopping_criteria = new InterruptableStoppingCriteria();

let past_key_values_cache = null;

async function generate({ messages, reasonEnabled }) {
  const [tokenizer, model] = await TextGenerationPipeline.getInstance();

  const inputs = tokenizer.apply_chat_template(messages, {
    enable_thinking: reasonEnabled,
    add_generation_prompt: true,
    return_dict: true,
  });

  const [START_THINKING_TOKEN_ID, END_THINKING_TOKEN_ID] = tokenizer.encode(
    "<think></think>",
    { add_special_tokens: false },
  );
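
  // The streamer below reports raw token ids; by watching for the <think> and
  // </think> ids we can tell whether the model is currently reasoning or
  // answering, and forward that state to the UI with each update.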
  let state = "answering"; // 'thinking' or 'answering'
  let startTime;
  let numTokens = 0;
  let tps;
  const token_callback_function = (tokens) => {
    startTime ??= performance.now();
    if (numTokens++ > 0) {
      tps = (numTokens / (performance.now() - startTime)) * 1000;
    }
    switch (Number(tokens[0])) {
      case START_THINKING_TOKEN_ID:
        state = "thinking";
        break;
      case END_THINKING_TOKEN_ID:
        state = "answering";
        break;
    }
  };
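
  // Forward each decoded chunk to the main thread, along with the current
  // throughput (tokens/s), the token count, and the reasoning state.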
  const callback_function = (output) => {
    self.postMessage({
      status: "update",
      output,
      tps,
      numTokens,
      state,
    });
  };

  const streamer = new TextStreamer(tokenizer, {
    skip_prompt: true,
    skip_special_tokens: true,
    callback_function,
    token_callback_function,
  });

  // Tell the main thread we are starting
  self.postMessage({ status: "start" });
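
  // Generate with the streamer attached; reuse any cached key/values from a
  // previous turn and allow interruption via the shared stopping criteria.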
  const { past_key_values, sequences } = await model.generate({
    ...inputs,
    past_key_values: past_key_values_cache,

    // Sampling
    do_sample: !reasonEnabled,
    repetition_penalty: reasonEnabled ? 1.1 : undefined,
    top_k: 3,

    max_new_tokens: reasonEnabled ? 4096 : 1024,
    streamer,
    stopping_criteria,
    return_dict_in_generate: true,
  });
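
  // Keep the returned key/value cache so the next request in this conversation
  // can resume from it instead of re-processing the full prompt.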
  past_key_values_cache = past_key_values;

  const decoded = tokenizer.batch_decode(sequences, {
    skip_special_tokens: true,
  });

  // Send the output back to the main thread
  self.postMessage({
    status: "complete",
    output: decoded,
  });
}

async function load() {
  self.postMessage({
    status: "loading",
    data: "Loading model...",
  });

  // Load the pipeline and save it for future use.
  const [tokenizer, model] = await TextGenerationPipeline.getInstance((x) => {
    // We also add a progress callback to the pipeline so that we can
    // track model loading.
    self.postMessage(x);
  });

  self.postMessage({
    status: "loading",
    data: "Compiling shaders and warming up model...",
  });

  // Run model with dummy input to compile shaders
  const inputs = tokenizer("a");
  await model.generate({ ...inputs, max_new_tokens: 1 });
  self.postMessage({ status: "ready" });
}

// Listen for messages from the main thread
self.addEventListener("message", async (e) => {
  const { type, data } = e.data;

  switch (type) {
    case "check":
      check();
      break;

    case "load":
      load();
      break;

    case "generate":
      stopping_criteria.reset();
      generate(data);
      break;

    case "interrupt":
      stopping_criteria.interrupt();
      break;

    case "reset":
      past_key_values_cache = null;
      stopping_criteria.reset();
      break;
  }
});
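
// ---------------------------------------------------------------------------
// Usage sketch (assumption: this file is served as a module worker named
// "worker.js"; the main-thread handling below is illustrative only and not
// part of this file):
//
//   const worker = new Worker(new URL("./worker.js", import.meta.url), {
//     type: "module",
//   });
//   worker.addEventListener("message", ({ data }) => {
//     switch (data.status) {
//       case "error":    console.error(data.data); break;
//       case "loading":  console.log(data.data); break;
//       case "ready":    console.log("Model ready"); break;
//       case "update":   console.log(data.state, data.tps, data.output); break;
//       case "complete": console.log(data.output); break;
//     }
//   });
//   worker.postMessage({ type: "check" });
//   worker.postMessage({ type: "load" });
//   // Once "ready" has been received:
//   worker.postMessage({
//     type: "generate",
//     data: { messages: [{ role: "user", content: "Hi!" }], reasonEnabled: true },
//   });
// ---------------------------------------------------------------------------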