aladdin1995 commited on
Commit
6f21ce1
·
verified ·
1 Parent(s): c24971a

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +2 -1
app.py CHANGED
@@ -4,6 +4,7 @@ import logging
4
  import re
5
  import gradio as gr
6
  from spaces import zero # 关键:引入 zero 装饰器
 
7
 
8
  # 不要在这里 import torch 或加载模型
9
  # from transformers import TextIteratorStreamer, AutoTokenizer # 不再需要
@@ -29,7 +30,7 @@ def _str_to_dtype(dtype_str):
29
  return dtype_str
30
  return "float32"
31
 
32
- @zero.gpu # 在子进程(拥有 GPU)中执行:包含模型加载与推理
33
  def gpu_predict(model_path, device_map, torch_dtype,
34
  prompt_cot, sys_prompt, temperature, max_new_tokens, device):
35
  # 注意:所有 CUDA 相关 import 放在子进程函数内部
 
4
  import re
5
  import gradio as gr
6
  from spaces import zero # 关键:引入 zero 装饰器
7
+ import spaces
8
 
9
  # 不要在这里 import torch 或加载模型
10
  # from transformers import TextIteratorStreamer, AutoTokenizer # 不再需要
 
30
  return dtype_str
31
  return "float32"
32
 
33
+ @spaces.GPU # 在子进程(拥有 GPU)中执行:包含模型加载与推理
34
  def gpu_predict(model_path, device_map, torch_dtype,
35
  prompt_cot, sys_prompt, temperature, max_new_tokens, device):
36
  # 注意:所有 CUDA 相关 import 放在子进程函数内部