# R-4B / app.py
import gradio as gr
import requests
from PIL import Image
import torch
from transformers import AutoModel, AutoProcessor
import spaces
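
# Load the R-4B multimodal model and its processor. trust_remote_code is
# required because the checkpoint ships custom modeling code; on ZeroGPU
# Spaces the `spaces` import makes the startup .to("cuda") call safe.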
model_path = "YannQi/R-4B"
model = AutoModel.from_pretrained(
model_path,
torch_dtype=torch.float32,
trust_remote_code=True,
).to("cuda")
processor = AutoProcessor.from_pretrained(model_path, trust_remote_code=True)
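
# ZeroGPU: attach a GPU to each call for at most 120 seconds.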
@spaces.GPU(duration=120)
def generate_response(message, history, thinking_mode):
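    """Run one chat turn: rebuild the history into chat-template messages, attach any uploaded images, and generate a reply."""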
if not message:
return "", history
messages = []
all_images = []
for user_msg, asst_msg in history:
# Process user message
if isinstance(user_msg, str):
user_content = [{"type": "text", "text": user_msg}]
else:
            text = user_msg.get('text', '')
            files = user_msg.get('files', [])
            # Files may arrive as dicts ({'path': ...}) or as plain path strings
            # depending on the Gradio version; normalize them to paths.
            img_paths = [f.get('path', str(f)) if isinstance(f, dict) else str(f) for f in files]
            user_content = []
for path in img_paths:
try:
img = Image.open(path)
all_images.append(img)
user_content.append({"type": "image", "image": path})
                except Exception:
                    pass  # skip files PIL cannot open as images
if text:
user_content.append({"type": "text", "text": text})
messages.append({"role": "user", "content": user_content})
# Process assistant message
        asst_text = asst_msg if isinstance(asst_msg, str) else (asst_msg or {}).get('text', '')
messages.append({"role": "assistant", "content": [{"type": "text", "text": asst_text}]})
# Current user message
if isinstance(message, str):
curr_text = message
curr_files = []
else:
curr_text = message.get('text', '')
curr_files = message.get('files', [])
curr_user_content = []
curr_images = []
    curr_file_paths = [f.get('path', str(f)) if isinstance(f, dict) else str(f) for f in curr_files]
for path in curr_file_paths:
if path.lower().endswith(('.png', '.jpg', '.jpeg', '.gif', '.bmp')):
try:
img = Image.open(path)
curr_images.append(img)
curr_user_content.append({"type": "image", "image": path})
            except Exception:
                pass  # skip files PIL cannot open as images
if curr_text:
curr_user_content.append({"type": "text", "text": curr_text})
if not curr_user_content:
return "", history
messages.append({"role": "user", "content": curr_user_content})
    # Apply the chat template; thinking_mode ("auto"/"long"/"short") controls
    # whether R-4B emits an explicit reasoning trace before answering.
text = processor.apply_chat_template(
messages,
tokenize=False,
add_generation_prompt=True,
thinking_mode=thinking_mode
)
    # Combine history images with the current turn's images, in prompt order
all_images += curr_images
# Process inputs
inputs = processor(
images=all_images if all_images else None,
text=text,
return_tensors="pt"
).to("cuda")
    # Generate the reply (sampling, temperature 0.7)
with torch.no_grad():
generated_ids = model.generate(
**inputs,
max_new_tokens=512,
do_sample=True,
temperature=0.7
)
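    # Keep only the newly generated tokens; the prompt occupies the first len(input_ids) positions.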
output_ids = generated_ids[0][len(inputs.input_ids[0]):]
output_text = processor.decode(
output_ids,
skip_special_tokens=True,
clean_up_tokenization_spaces=False
)
    # For the tuple-style chatbot, dicts from the MultimodalTextbox do not
    # render directly, so display the text part (or a placeholder for images).
    user_display = message if isinstance(message, str) else (curr_text or "[image]")
new_history = history + [(user_display, output_text)]
return "", new_history
with gr.Blocks(title="Transformers Chat") as demo:
gr.Markdown("# Using 🤗 Transformers to Chat")
gr.Markdown("Select thinking mode: auto (auto-thinking), long (thinking), short (non-thinking). Default: auto.")
chatbot = gr.Chatbot(type="tuples", height=500, label="Chat")
with gr.Row():
msg = gr.MultimodalTextbox(
placeholder="Type your message or upload images...",
file_types=[".jpg", ".jpeg", ".png", ".gif", ".bmp"],
file_count="multiple",
label="Message"
)
mode = gr.Dropdown(
choices=["auto", "long", "short"],
value="auto",
label="Thinking Mode",
interactive=True
)
with gr.Row():
submit_btn = gr.Button("Send", variant="primary", scale=3)
clear_btn = gr.Button("Clear", scale=1)
submit_btn.click(generate_response, [msg, chatbot, mode], [msg, chatbot])
msg.submit(generate_response, [msg, chatbot, mode], [msg, chatbot])
    clear_btn.click(lambda: ([], None), None, [chatbot, msg], queue=False)
if __name__ == "__main__":
demo.launch()