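"""Gradio chat demo for the YannQi/R-4B vision-language model.

Written for a Hugging Face Space; the `spaces` package provides the
ZeroGPU decorator used below.
"""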
import gradio as gr
from PIL import Image
import torch
from transformers import AutoModel, AutoProcessor
import spaces

model_path = "YannQi/R-4B"
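
# Load the model once at startup; trust_remote_code is required for R-4B's
# custom modeling code.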
model = AutoModel.from_pretrained(
    model_path,
    torch_dtype=torch.float32,
    trust_remote_code=True,
).to("cuda")

processor = AutoProcessor.from_pretrained(model_path, trust_remote_code=True)
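
# Extensions treated as images, shared by the upload widget, the history
# parser, and the current-message parser.
IMAGE_EXTENSIONS = (".png", ".jpg", ".jpeg", ".gif", ".bmp")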


@spaces.GPU(duration=120)
def generate_response(message, history, thinking_mode):
    if not message:
        return "", history

    messages = []
    all_images = []
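
    # Rebuild the full multimodal conversation from the stored history so the
    # model sees earlier turns as context.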
    for user_msg, asst_msg in history:
        if isinstance(user_msg, str):
            user_content = [{"type": "text", "text": user_msg}]
        else:
            text = user_msg.get("text", "")
            files = user_msg.get("files", [])
            # File entries may be dicts or plain path strings depending on the
            # Gradio version.
            file_paths = [
                f.get("path", str(f)) if isinstance(f, dict) else str(f)
                for f in files
            ]
            user_content = []
            for path in file_paths:
                if not path.lower().endswith(IMAGE_EXTENSIONS):
                    continue
                try:
                    all_images.append(Image.open(path))
                    user_content.append({"type": "image", "image": path})
                except OSError:
                    continue  # skip files PIL cannot open
            if text:
                user_content.append({"type": "text", "text": text})
        messages.append({"role": "user", "content": user_content})

        asst_text = asst_msg if isinstance(asst_msg, str) else asst_msg.get("text", "")
        messages.append({"role": "assistant", "content": [{"type": "text", "text": asst_text}]})
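
    # Parse the current MultimodalTextbox value: either a plain string or a
    # dict with "text" and "files" keys.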
    if isinstance(message, str):
        curr_text = message
        curr_files = []
    else:
        curr_text = message.get("text", "")
        curr_files = message.get("files", [])

    curr_user_content = []
    curr_images = []
    curr_file_paths = [
        f.get("path", str(f)) if isinstance(f, dict) else str(f)
        for f in curr_files
    ]
    for path in curr_file_paths:
        if not path.lower().endswith(IMAGE_EXTENSIONS):
            continue
        try:
            curr_images.append(Image.open(path))
            curr_user_content.append({"type": "image", "image": path})
        except OSError:
            continue  # skip files PIL cannot open
    if curr_text:
        curr_user_content.append({"type": "text", "text": curr_text})
    if not curr_user_content:
        return "", history
    messages.append({"role": "user", "content": curr_user_content})
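
    # R-4B's chat template takes a thinking_mode argument: "auto"
    # (auto-thinking), "long" (thinking), or "short" (non-thinking).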
    text = processor.apply_chat_template(
        messages,
        tokenize=False,
        add_generation_prompt=True,
        thinking_mode=thinking_mode,
    )
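
    # Keep images in prompt order: history images first, then this turn's.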
    all_images += curr_images
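
    # The processor tokenizes the templated text and preprocesses the images
    # together; pass None when the conversation is text-only.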
    inputs = processor(
        images=all_images if all_images else None,
        text=text,
        return_tensors="pt",
    ).to("cuda")

    with torch.no_grad():
        generated_ids = model.generate(
            **inputs,
            max_new_tokens=512,
            do_sample=True,
            temperature=0.7,
        )
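    # Strip the prompt tokens so only the newly generated text is decoded.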
    output_ids = generated_ids[0][len(inputs.input_ids[0]):]
    output_text = processor.decode(
        output_ids,
        skip_special_tokens=True,
        clean_up_tokenization_spaces=False,
    )

    # gr.Chatbot in "tuples" mode renders strings (or file tuples), not raw
    # dicts, so store a text summary of the user turn in the visible history.
    if isinstance(message, str):
        user_display = message
    else:
        user_display = curr_text or "[image]"
    new_history = history + [(user_display, output_text)]

    return "", new_history


with gr.Blocks(title="Transformers Chat") as demo:
    gr.Markdown("# Using 🤗 Transformers to Chat")
    gr.Markdown(
        "Select a thinking mode: auto (auto-thinking), long (thinking), "
        "short (non-thinking). Default: auto."
    )
    chatbot = gr.Chatbot(type="tuples", height=500, label="Chat")
    with gr.Row():
        msg = gr.MultimodalTextbox(
            placeholder="Type your message or upload images...",
            file_types=list(IMAGE_EXTENSIONS),
            file_count="multiple",
            label="Message",
        )
        mode = gr.Dropdown(
            choices=["auto", "long", "short"],
            value="auto",
            label="Thinking Mode",
            interactive=True,
        )
    with gr.Row():
        submit_btn = gr.Button("Send", variant="primary", scale=3)
        clear_btn = gr.Button("Clear", scale=1)

    submit_btn.click(generate_response, [msg, chatbot, mode], [msg, chatbot])
    msg.submit(generate_response, [msg, chatbot, mode], [msg, chatbot])
    clear_btn.click(lambda: ([], ""), None, [chatbot, msg], queue=False)


if __name__ == "__main__":
    demo.launch()