Update app.py
Browse files
app.py
CHANGED
|
@@ -4,6 +4,7 @@ import logging
|
|
| 4 |
import re
|
| 5 |
import gradio as gr
|
| 6 |
from spaces import zero # 关键:引入 zero 装饰器
|
|
|
|
| 7 |
|
| 8 |
# 不要在这里 import torch 或加载模型
|
| 9 |
# from transformers import TextIteratorStreamer, AutoTokenizer # 不再需要
|
|
@@ -29,7 +30,7 @@ def _str_to_dtype(dtype_str):
|
|
| 29 |
return dtype_str
|
| 30 |
return "float32"
|
| 31 |
|
| 32 |
-
@
|
| 33 |
def gpu_predict(model_path, device_map, torch_dtype,
|
| 34 |
prompt_cot, sys_prompt, temperature, max_new_tokens, device):
|
| 35 |
# 注意:所有 CUDA 相关 import 放在子进程函数内部
|
|
|
|
| 4 |
import re
|
| 5 |
import gradio as gr
|
| 6 |
from spaces import zero # 关键:引入 zero 装饰器
|
| 7 |
+
import spaces
|
| 8 |
|
| 9 |
# 不要在这里 import torch 或加载模型
|
| 10 |
# from transformers import TextIteratorStreamer, AutoTokenizer # 不再需要
|
|
|
|
| 30 |
return dtype_str
|
| 31 |
return "float32"
|
| 32 |
|
| 33 |
+
@spaces.GPU # 在子进程(拥有 GPU)中执行:包含模型加载与推理
|
| 34 |
def gpu_predict(model_path, device_map, torch_dtype,
|
| 35 |
prompt_cot, sys_prompt, temperature, max_new_tokens, device):
|
| 36 |
# 注意:所有 CUDA 相关 import 放在子进程函数内部
|