aladdin1995 committed on
Commit c24971a · verified · 1 Parent(s): cfb21b4

Update app.py

Files changed (1)
  1. app.py +122 -137
app.py CHANGED
@@ -1,16 +1,12 @@
-# app.py
-# Gradio UI for PromptEnhancerV2
-
 import os
-from threading import Thread
-from transformers import TextIteratorStreamer, AutoTokenizer
 import time
 import logging
 import re
-import torch
 import gradio as gr
-import spaces
+import spaces  # key: provides the ZeroGPU decorator

+# Do not import torch or load the model here
+# from transformers import TextIteratorStreamer, AutoTokenizer  # no longer needed

 # Try to import qwen_vl_utils; if it fails, provide a fallback implementation (returns empty image/video inputs)
 try:
@@ -25,120 +21,114 @@ def replace_single_quotes(text):
     replaced_text = replaced_text.replace("’", "”").replace("‘", "“")
     return replaced_text

-class PromptEnhancerV2:
-    @spaces.GPU
-    def __init__(self, models_root_path, device_map="auto", torch_dtype="bfloat16"):  # auto
-        from transformers import Qwen2_5_VLForConditionalGeneration, AutoProcessor
-        if not logging.getLogger(__name__).handlers:
-            logging.basicConfig(level=logging.INFO)
-        self.logger = logging.getLogger(__name__)
-
-        # dtype compatibility handling
-        if torch_dtype == "bfloat16":
-            dtype = torch.bfloat16
-        elif torch_dtype == "float16":
-            dtype = torch.float16
-        else:
-            dtype = torch.float32
-
-        self.model = Qwen2_5_VLForConditionalGeneration.from_pretrained(
-            models_root_path,
-            torch_dtype=dtype,
-            device_map=device_map,
-        )
-        self.processor = AutoProcessor.from_pretrained(models_root_path)
-
-    # @torch.inference_mode()
-    @spaces.GPU
-    def predict(
-        self,
-        prompt_cot,
-        sys_prompt="请根据用户的输入,生成思考过程的思维链并改写提示词:",
-        temperature=0.1,
-        top_p=1.0,
-        max_new_tokens=2048,
-        device="cuda",
-    ):
-        org_prompt_cot = prompt_cot
-        try:
-            user_prompt_format = sys_prompt + "\n" + org_prompt_cot
-            messages = [
-                {
-                    "role": "user",
-                    "content": [
-                        {"type": "text", "text": user_prompt_format},
-                    ],
-                }
-            ]
-
-            text = self.processor.apply_chat_template(
-                messages, tokenize=False, add_generation_prompt=True
-            )
-            image_inputs, video_inputs = process_vision_info(messages)
-            inputs = self.processor(
-                text=[text],
-                images=image_inputs,
-                videos=video_inputs,
-                padding=True,
-                return_tensors="pt",
-            )
-            inputs = inputs.to(device)
-
-            # Note: the original code hard-coded do_sample=False, top_k=5, top_p=0.9; kept consistent here
-            generated_ids = self.model.generate(
-                **inputs,
-                max_new_tokens=2048,  # consistent with the original code (the max_new_tokens argument was unused)
-                temperature=float(temperature),
-                do_sample=False,
-                top_k=5,
-                top_p=0.9
-            )
-            generated_ids_trimmed = [
-                out_ids[len(in_ids):]
-                for in_ids, out_ids in zip(inputs.input_ids, generated_ids)
-            ]
-            output_text = self.processor.batch_decode(
-                generated_ids_trimmed,
-                skip_special_tokens=True,
-                clean_up_tokenization_spaces=False,
-            )
-            output_res = output_text[0]
-            assert output_res.count("think>") == 2
-            prompt_cot = output_res.split("think>")[-1]
-            if prompt_cot.startswith("\n"):
-                prompt_cot = prompt_cot[1:]
-            prompt_cot = replace_single_quotes(prompt_cot)
-        except Exception as e:
-            prompt_cot = org_prompt_cot
-            print(f"✗ Re-prompting failed, so we are using the original prompt. Error: {e}")
-
-        return prompt_cot
-
-# -------------------------
-# Gradio app helpers
-# -------------------------
-
 DEFAULT_MODEL_PATH = os.environ.get("MODEL_OUTPUT_PATH", "PromptEnhancer/PromptEnhancer-32B")

-def ensure_enhancer(state, model_path, device_map, torch_dtype):
-    """
-    state: dict or None
-    Returns: (state_dict)
-    """
-    need_reload = False
-    if state is None or not isinstance(state, dict):
-        need_reload = True
-    else:
-        prev_path = state.get("model_path")
-        prev_map = state.get("device_map")
-        prev_dtype = state.get("torch_dtype")
-        if prev_path != model_path or prev_map != device_map or prev_dtype != torch_dtype:
-            need_reload = True
-
-    if need_reload:
-        enhancer = PromptEnhancerV2(model_path, device_map=device_map, torch_dtype=torch_dtype)
-        return {"enhancer": enhancer, "model_path": model_path, "device_map": device_map, "torch_dtype": torch_dtype}
-    return state
+def _str_to_dtype(dtype_str):
+    # torch is only actually used inside the subprocess; here we just validate and pass the string through
+    if dtype_str in ("bfloat16", "float16", "float32"):
+        return dtype_str
+    return "float32"
+
+@spaces.GPU  # runs in the (GPU-owning) subprocess: covers both model loading and inference
+def gpu_predict(model_path, device_map, torch_dtype,
+                prompt_cot, sys_prompt, temperature, max_new_tokens, device):
+    # Note: all CUDA-related imports stay inside the subprocess function
+    import torch
+    from transformers import Qwen2_5_VLForConditionalGeneration, AutoProcessor
+
+    # logger (optional)
+    if not logging.getLogger(__name__).handlers:
+        logging.basicConfig(level=logging.INFO)
+    logger = logging.getLogger(__name__)
+
+    # dtype
+    if torch_dtype == "bfloat16":
+        dtype = torch.bfloat16
+    elif torch_dtype == "float16":
+        dtype = torch.float16
+    else:
+        dtype = torch.float32
+
+    # Device mapping: determined by the UI's device / device_map settings
+    # ZeroGPU recommends "cuda" for GPU inference
+    target_device = "cuda" if device == "cuda" else "cpu"
+    load_device_map = "cuda" if device_map == "cuda" else "cpu"
+
+    # Load the model and processor (in the subprocess)
+    model = Qwen2_5_VLForConditionalGeneration.from_pretrained(
+        model_path,
+        torch_dtype=dtype,
+        device_map=load_device_map,
+        attn_implementation="sdpa",  # skip flash-attn for better compatibility
+    )
+    processor = AutoProcessor.from_pretrained(model_path)

+    # Assemble the messages
+    org_prompt_cot = prompt_cot
+    try:
+        user_prompt_format = sys_prompt + "\n" + org_prompt_cot
+        messages = [
+            {
+                "role": "user",
+                "content": [
+                    {"type": "text", "text": user_prompt_format},
+                ],
+            }
+        ]
+
+        text = processor.apply_chat_template(
+            messages, tokenize=False, add_generation_prompt=True
+        )
+        image_inputs, video_inputs = process_vision_info(messages)
+
+        inputs = processor(
+            text=[text],
+            images=image_inputs,
+            videos=video_inputs,
+            padding=True,
+            return_tensors="pt",
+        )
+        # Move the inputs to the target device
+        inputs = inputs.to(target_device)
+
+        # Generate
+        generated_ids = model.generate(
+            **inputs,
+            max_new_tokens=int(max_new_tokens),
+            temperature=float(temperature),
+            do_sample=False,
+            top_k=5,
+            top_p=0.9,
+        )
+        # Decode only the newly generated tokens
+        generated_ids_trimmed = [
+            out_ids[len(in_ids):]
+            for in_ids, out_ids in zip(inputs.input_ids, generated_ids)
+        ]
+        output_text = processor.batch_decode(
+            generated_ids_trimmed,
+            skip_special_tokens=True,
+            clean_up_tokenization_spaces=False,
+        )
+        output_res = output_text[0]
+        # Match the original logic: extract the content after "think>"
+        try:
+            assert output_res.count("think>") == 2
+            new_prompt = output_res.split("think>")[-1]
+            if new_prompt.startswith("\n"):
+                new_prompt = new_prompt[1:]
+            new_prompt = replace_single_quotes(new_prompt)
+        except Exception:
+            # If the output format is unexpected, fall back to the original input
+            new_prompt = org_prompt_cot
+        return new_prompt, ""
+    except Exception as e:
+        # On failure, return the original prompt together with the error message
+        return org_prompt_cot, f"推理失败:{e}"
+
+# -------------------------
+# Gradio app
+# -------------------------

 def run_single(prompt, sys_prompt, temperature, max_new_tokens, device,
                model_path, device_map, torch_dtype, state):
@@ -146,21 +136,24 @@ def run_single(prompt, sys_prompt, temperature, max_new_tokens, device,
         return "", "请先输入提示词。", state

     t0 = time.time()
-    state = ensure_enhancer(state, model_path, device_map, torch_dtype)
-    enhancer = state["enhancer"]
     try:
-        out = enhancer.predict(
+        new_prompt, err = gpu_predict(
+            model_path=model_path,
+            device_map=device_map,
+            torch_dtype=_str_to_dtype(torch_dtype),
             prompt_cot=prompt,
             sys_prompt=sys_prompt,
             temperature=temperature,
             max_new_tokens=max_new_tokens,
-            device=device
+            device=device,
         )
         dt = time.time() - t0
-        return out, f"耗时:{dt:.2f}s", state
+        if err:
+            return new_prompt, f"{err}(耗时 {dt:.2f}s)", state
+        return new_prompt, f"耗时:{dt:.2f}s", state
     except Exception as e:
-        return "", f"推理失败:{e}", state
-
+        return "", f"调用失败:{e}", state
+
 # Example data
 test_list_zh = [
     "第三人称视角,赛车在城市赛道上飞驰,左上角是小地图,地图下面是当前名次,右下角仪表盘显示当前速度。",
@@ -183,13 +176,13 @@ with gr.Blocks(title="Prompt Enhancer_V2") as demo:
         model_path = gr.Textbox(
             label="模型路径(本地或HF地址)",
             value=DEFAULT_MODEL_PATH,
-            placeholder="/apdcephfs_jn3/share_302243908/aladdinwang/model_weight/cot_taurus_v6_50/global_step0",
+            placeholder="例如:Qwen/Qwen2.5-VL-7B-Instruct",
         )
         device_map = gr.Dropdown(
             choices=["cuda", "cpu"],
             value="cuda",
             label="device_map(模型加载映射)"
-        )
+        )
         torch_dtype = gr.Dropdown(
             choices=["bfloat16", "float16", "float32"],
             value="bfloat16",
@@ -204,7 +197,7 @@ with gr.Blocks(title="Prompt Enhancer_V2") as demo:
         )
         with gr.Row():
             temperature = gr.Slider(0, 1, value=0.1, step=0.05, label="Temperature")
-            max_new_tokens = gr.Slider(16, 4096, value=2048, step=16, label="Max New Tokens(原代码未使用该参数)")
+            max_new_tokens = gr.Slider(16, 4096, value=2048, step=16, label="Max New Tokens")
             device = gr.Dropdown(choices=["cuda", "cpu"], value="cuda", label="推理device")

         state = gr.State(value=None)
@@ -223,12 +216,6 @@ with gr.Blocks(title="Prompt Enhancer_V2") as demo:
         out_text = gr.Textbox(label="重写结果", lines=10)
         out_info = gr.Markdown("准备就绪。")

-        # run_btn.click(
-        #     stream_single,
-        #     inputs=[prompt, sys_prompt, temperature, max_new_tokens, device,
-        #             model_path, device_map, torch_dtype, state],
-        #     outputs=[out_text, out_info, state]
-        # )
         run_btn.click(
             run_single,
             inputs=[prompt, sys_prompt, temperature, max_new_tokens, device,
@@ -236,12 +223,10 @@ with gr.Blocks(title="Prompt Enhancer_V2") as demo:
                     model_path, device_map, torch_dtype, state],
             outputs=[out_text, out_info, state]
         )

-        gr.Markdown(
-            "提示:如有任何问题可email联系:[email protected]"
-        )
+        gr.Markdown("提示:如有任何问题可 email 联系:[email protected]")

-    # Limit concurrency to avoid GPU OOM from parallel requests
+    # Concurrency can be limited to avoid GPU OOM (ZeroGPU itself is stateless, but limiting is still recommended)
     # demo.queue(concurrency_count=1, max_size=10)
+
 if __name__ == "__main__":
-    # demo.launch(server_name="0.0.0.0", server_port=8080, show_error=True)
     demo.launch(ssr_mode=False, show_error=True, share=True)
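The change above is essentially the ZeroGPU calling convention: instead of caching a model in Gradio session state, every CUDA-touching step (the torch import, model loading, generation) moves into one function decorated with spaces.GPU, which runs in a short-lived subprocess that actually holds the GPU. A minimal, self-contained sketch of that pattern follows; the model id and the reload-per-call behavior are illustrative placeholders mirroring this commit, not the Space's actual configuration:

    import gradio as gr
    import spaces  # import before any CUDA-related package

    @spaces.GPU  # the decorated function runs in a subprocess that owns the GPU
    def generate(prompt: str) -> str:
        # CUDA-touching imports stay inside, so the main process never initializes CUDA
        import torch
        from transformers import pipeline

        # Reloading on every call mirrors this commit's approach; ZeroGPU also
        # allows loading the model once at import time and decorating only inference.
        pipe = pipeline(
            "text-generation",
            model="Qwen/Qwen2.5-0.5B-Instruct",  # placeholder model id
            torch_dtype=torch.bfloat16,
            device_map="cuda",
        )
        return pipe(prompt, max_new_tokens=64)[0]["generated_text"]

    demo = gr.Interface(fn=generate, inputs="text", outputs="text")

    if __name__ == "__main__":
        demo.launch()

Loading inside the decorated function, as this commit does, pays the model-load cost on every click; the alternative of loading once at module scope and decorating only the inference call avoids that at the price of keeping weights resident.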
 
 
 
 
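One caveat on the commented-out queue line: demo.queue(concurrency_count=1, max_size=10) is the Gradio 3 signature. If the Space runs on Gradio 4 or later (an assumption; the diff does not pin a version), the parameter was renamed, and the equivalent cap on concurrent GPU calls would be:

    import gradio as gr

    with gr.Blocks() as demo:
        out = gr.Markdown("准备就绪。")

    # Gradio 4 renamed concurrency_count to default_concurrency_limit;
    # this caps concurrent event handlers (and hence GPU calls) at one.
    demo.queue(default_concurrency_limit=1, max_size=10)

    if __name__ == "__main__":
        demo.launch(show_error=True)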