import os
import time
import logging
import re

import gradio as gr
import spaces

# qwen_vl_utils ships alongside Qwen2.5-VL and extracts image/video inputs
# from chat messages; fall back to a no-op stub when it is unavailable,
# since this demo only sends text.
try:
    from qwen_vl_utils import process_vision_info
except Exception:
    def process_vision_info(messages):
        return None, None


def replace_single_quotes(text):
    # Rewrite 'single-quoted' spans as "double-quoted" ones, then map curly
    # single quotes to their curly double-quote counterparts.
    pattern = r"\B'([^']*)'\B"
    replaced_text = re.sub(pattern, r'"\1"', text)
    replaced_text = replaced_text.replace("’", "”").replace("‘", "“")
    return replaced_text
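
# Illustrative usage:
#   replace_single_quotes("a 'cozy' cabin")  ->  'a "cozy" cabin'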

# The model path can be overridden with the MODEL_OUTPUT_PATH env var.
DEFAULT_MODEL_PATH = os.environ.get("MODEL_OUTPUT_PATH", "PromptEnhancer/PromptEnhancer-32B")


def _str_to_dtype(dtype_str):
    # Whitelist supported dtype names; anything else falls back to float32.
    if dtype_str in ("bfloat16", "float16", "float32"):
        return dtype_str
    return "float32"
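
# Illustrative: _str_to_dtype("bf16") -> "float32" (unknown names fall back),
# while _str_to_dtype("float16") passes through unchanged.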


@spaces.GPU
def gpu_predict(model_path, device_map, torch_dtype,
                prompt_cot, sys_prompt, temperature, max_new_tokens, device):
    # Heavy imports stay inside the function: on Hugging Face ZeroGPU Spaces,
    # a GPU is attached only while this decorated function runs.
    import torch
    from transformers import Qwen2_5_VLForConditionalGeneration, AutoProcessor

    if not logging.getLogger(__name__).handlers:
        logging.basicConfig(level=logging.INFO)
    logger = logging.getLogger(__name__)

    # Map the dtype string from the UI to the corresponding torch dtype.
    if torch_dtype == "bfloat16":
        dtype = torch.bfloat16
    elif torch_dtype == "float16":
        dtype = torch.float16
    else:
        dtype = torch.float32

    target_device = "cuda" if device == "cuda" else "cpu"
    load_device_map = "cuda" if device_map == "cuda" else "cpu"

    logger.info("Loading model from %s", model_path)
    model = Qwen2_5_VLForConditionalGeneration.from_pretrained(
        model_path,
        torch_dtype=dtype,
        device_map=load_device_map,
        attn_implementation="sdpa",
    )
    processor = AutoProcessor.from_pretrained(model_path)

    org_prompt_cot = prompt_cot
    try:
        # The system prompt and user prompt are concatenated into a single
        # user turn, matching the plain-text format the enhancer expects.
        user_prompt_format = sys_prompt + "\n" + org_prompt_cot
        messages = [
            {
                "role": "user",
                "content": [
                    {"type": "text", "text": user_prompt_format},
                ],
            }
        ]

        text = processor.apply_chat_template(
            messages, tokenize=False, add_generation_prompt=True
        )
        image_inputs, video_inputs = process_vision_info(messages)

        inputs = processor(
            text=[text],
            images=image_inputs,
            videos=video_inputs,
            padding=True,
            return_tensors="pt",
        )
        inputs = inputs.to(target_device)

        # Sample only when temperature > 0; otherwise decode greedily, so the
        # temperature slider actually takes effect (sampling knobs are omitted
        # for greedy decoding to avoid spurious generation warnings).
        do_sample = float(temperature) > 0
        gen_kwargs = {"max_new_tokens": int(max_new_tokens), "do_sample": do_sample}
        if do_sample:
            gen_kwargs.update(temperature=float(temperature), top_k=5, top_p=0.9)
        generated_ids = model.generate(**inputs, **gen_kwargs)

        # Drop the prompt tokens so only the newly generated text is decoded.
        generated_ids_trimmed = [
            out_ids[len(in_ids):]
            for in_ids, out_ids in zip(inputs.input_ids, generated_ids)
        ]
        output_text = processor.batch_decode(
            generated_ids_trimmed,
            skip_special_tokens=True,
            clean_up_tokenization_spaces=False,
        )
        output_res = output_text[0]

        try:
            # The model emits "<think>...</think>" followed by the rewritten
            # prompt; keep only the text after the closing tag.
            assert output_res.count("think>") == 2
            new_prompt = output_res.split("think>")[-1]
            if new_prompt.startswith("\n"):
                new_prompt = new_prompt[1:]
            new_prompt = replace_single_quotes(new_prompt)
        except Exception:
            # Fall back to the original prompt if the output is malformed.
            new_prompt = org_prompt_cot
        return new_prompt, ""
    except Exception as e:
        return org_prompt_cot, f"Inference failed: {e}"
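
# Illustrative direct call (hypothetical values; requires a GPU and the
# model weights to be available):
#
#   rewritten, err = gpu_predict(
#       model_path=DEFAULT_MODEL_PATH,
#       device_map="cuda",
#       torch_dtype="bfloat16",
#       prompt_cot="a cat surfing at sunset",
#       sys_prompt="请根据用户的输入,生成思考过程的思维链并改写提示词:",
#       temperature=0.0,
#       max_new_tokens=1024,
#       device="cuda",
#   )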


def run_single(prompt, sys_prompt, temperature, max_new_tokens, device,
               model_path, device_map, torch_dtype, state):
    # Thin UI wrapper: validates input, times the call, and formats status
    # text; returns (rewritten_prompt, status_message, state).
    if not prompt or not str(prompt).strip():
        return "", "Please enter a prompt first.", state

    t0 = time.time()
    try:
        new_prompt, err = gpu_predict(
            model_path=model_path,
            device_map=device_map,
            torch_dtype=_str_to_dtype(torch_dtype),
            prompt_cot=prompt,
            sys_prompt=sys_prompt,
            temperature=temperature,
            max_new_tokens=max_new_tokens,
            device=device,
        )
        dt = time.time() - t0
        if err:
            return new_prompt, f"{err} (elapsed {dt:.2f}s)", state
        return new_prompt, f"Elapsed: {dt:.2f}s", state
    except Exception as e:
        return "", f"Call failed: {e}", state
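
# Illustrative wiring check (hypothetical values; the GPU path is skipped
# because the prompt is empty):
#   run_single("", sys_prompt="...", temperature=0.1, max_new_tokens=256,
#              device="cuda", model_path=DEFAULT_MODEL_PATH,
#              device_map="cuda", torch_dtype="bfloat16", state=None)
#   -> ("", "Please enter a prompt first.", None)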

# Sample prompts for the Examples widget (Chinese first, then English); the
# Chinese prompts are kept verbatim since the model accepts both languages.
test_list_zh = [
    "第三人称视角,赛车在城市赛道上飞驰,左上角是小地图,地图下面是当前名次,右下角仪表盘显示当前速度。",
    "韩系插画风女生头像,粉紫色短发+透明感腮红,侧光渲染。",
    "点彩派,盛夏海滨,两位渔夫正在搬运木箱,三艘帆船停在岸边,对角线构图。",
    "一幅由梵高绘制的梦境麦田,旋转的蓝色星云与燃烧的向日葵相纠缠。",
]
test_list_en = [
    "Create a painting depicting a 30-year-old white female white-collar worker on a business trip by plane.",
    "Depicted in the anime style of Studio Ghibli, a girl stands quietly at the deck with a gentle smile.",
    "Blue background, a lone girl gazes into the distant sea; her expression is sorrowful.",
    "A blend of expressionist and vintage styles, drawing a building with colorful walls.",
    "Paint a winter scene with crystalline ice hangings from an Antarctic research station.",
]


with gr.Blocks(title="Prompt Enhancer_V2") as demo:
    gr.Markdown("## Prompt Rewriter")
    with gr.Row():
        with gr.Column(scale=2):
            model_path = gr.Textbox(
                label="Model path (local path or HF repo id)",
                value=DEFAULT_MODEL_PATH,
                placeholder="e.g. Qwen/Qwen2.5-VL-7B-Instruct",
            )
            device_map = gr.Dropdown(
                choices=["cuda", "cpu"],
                value="cuda",
                label="device_map (model loading placement)",
            )
            torch_dtype = gr.Dropdown(
                choices=["bfloat16", "float16", "float32"],
                value="bfloat16",
                label="torch_dtype",
            )

        with gr.Column(scale=3):
            # The default system prompt is kept in Chinese because the model
            # was tuned with it; it reads: "Based on the user's input,
            # generate a chain-of-thought and rewrite the prompt:".
            sys_prompt = gr.Textbox(
                label="System prompt (usually no need to change)",
                value="请根据用户的输入,生成思考过程的思维链并改写提示词:",
                lines=3,
            )
            with gr.Row():
                temperature = gr.Slider(0, 1, value=0.1, step=0.05, label="Temperature")
                max_new_tokens = gr.Slider(16, 4096, value=2048, step=16, label="Max New Tokens")
            device = gr.Dropdown(choices=["cuda", "cpu"], value="cuda", label="Inference device")

    state = gr.State(value=None)

    with gr.Tab("Inference"):
        with gr.Row():
            with gr.Column(scale=2):
                prompt = gr.Textbox(label="Input prompt", lines=6, placeholder="Paste the prompt to rewrite here...")
                run_btn = gr.Button("Rewrite", variant="primary")
                gr.Examples(
                    examples=test_list_zh + test_list_en,
                    inputs=prompt,
                    label="Examples",
                )
            with gr.Column(scale=3):
                out_text = gr.Textbox(label="Rewritten result", lines=10)
                out_info = gr.Markdown("Ready.")

        run_btn.click(
            run_single,
            inputs=[prompt, sys_prompt, temperature, max_new_tokens, device,
                    model_path, device_map, torch_dtype, state],
            outputs=[out_text, out_info, state],
        )

    gr.Markdown("Tip: for any questions, contact us by email: [email protected]")


if __name__ == "__main__":
    demo.launch(ssr_mode=False, show_error=True, share=True)
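
# Usage note: share=True additionally creates a temporary public gradio.live
# URL, which helps when testing locally; on Hugging Face Spaces the app is
# already publicly served, so share=False would also work there.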