import gradio as gr
from PIL import Image
from transformers import Qwen3VLForConditionalGeneration, AutoProcessor
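
# Note (assumption, not in the original file): Qwen3VLForConditionalGeneration
# is only available in recent transformers releases; if the import above fails,
# upgrading transformers should resolve it:
#   pip install -U transformers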

# --- Configuration ---
MODEL_PATH = "Qwen/Qwen3-VL-2B-Instruct"

# --- Model and Processor Loading ---
# This will be done once when the Space starts.
# 'device_map="auto"' will correctly assign the model to the CPU in this environment.
print("Loading model and processor... This will take a few minutes on a CPU.")
processor = AutoProcessor.from_pretrained(MODEL_PATH, trust_remote_code=True)
model = Qwen3VLForConditionalGeneration.from_pretrained(
    MODEL_PATH,
    trust_remote_code=True,
    dtype="auto",       # "auto" defers to the dtype stored in the checkpoint
    device_map="auto",  # the key for CPU (and GPU) compatibility
)
print("Model and processor loaded successfully.")
# --- Inference Function ---
def process_and_generate(image_input, text_prompt):
    """
    Processes the image and text prompt and generates a response from the model.
    """
    if image_input is None or not text_prompt or not text_prompt.strip():
        return "Please provide both an image and a text prompt."

    # Convert Gradio's numpy array to a PIL Image
    pil_image = Image.fromarray(image_input)

    # Prepare the messages payload for the model
    messages = [
        {
            "role": "user",
            "content": [
                {"type": "image", "image": pil_image},
                {"type": "text", "text": text_prompt},
            ],
        }
    ]
print("Processing inputs and generating response... This will be slow.")
try:
# Preparation for inference
inputs = processor.apply_chat_template(
messages,
tokenize=True,
add_generation_prompt=True,
return_dict=True,
return_tensors="pt"
)
inputs = inputs.to(model.device)

        # Inference: generation of the output
        generated_ids = model.generate(**inputs, max_new_tokens=1024)
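
        # Assumption (not in the original file): generate() runs with the
        # model's default decoding settings here; sampling can be enabled
        # with standard transformers kwargs, e.g.:
        #   model.generate(**inputs, max_new_tokens=1024,
        #                  do_sample=True, temperature=0.7, top_p=0.9)
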
        # To get only the new tokens, trim the input IDs from the generated IDs
        generated_ids_trimmed = [
            out_ids[len(in_ids):] for in_ids, out_ids in zip(inputs.input_ids, generated_ids)
        ]

        # Decode the trimmed IDs to text
        output_text = processor.batch_decode(
            generated_ids_trimmed, skip_special_tokens=True, clean_up_tokenization_spaces=False
        )

        # batch_decode returns a list; return the first element
        return output_text[0]
    except Exception as e:
        return f"An error occurred during generation: {e}"

# --- Gradio Interface ---
with gr.Blocks() as demo:
    gr.Markdown(
        """
        # Qwen3-VL-2B-Instruct CPU Demo
        This Space runs the `Qwen/Qwen3-VL-2B-Instruct` model using the standard `transformers` library.

        **Warning:** Running this on a free CPU Space is **very slow**. Please be patient after clicking the generate button.
        """
    )
    with gr.Row():
        with gr.Column():
            image_input = gr.Image(type="numpy", label="Upload Image")
            text_prompt = gr.Textbox(label="Prompt", placeholder="e.g., Describe this image in detail.")
            submit_button = gr.Button("Generate Response")
        with gr.Column():
            output_text = gr.Textbox(label="Model Output", lines=10, interactive=False)

    submit_button.click(
        fn=process_and_generate,
        inputs=[image_input, text_prompt],
        outputs=output_text,
    )

    gr.Examples(
        examples=[
            ["https://qianwen-res.oss-cn-beijing.aliyuncs.com/Qwen-VL/assets/demo.jpeg", "Describe this image."],
            ["https://qianwen-res.oss-accelerate.aliyuncs.com/Qwen3-VL/receipt.png", "Read the text from this receipt."],
            ["https://qianwen-res.oss-accelerate.aliyuncs.com/Qwen3-VL/what_is_in_the_box.jpg", "What is inside the red box?"],
        ],
        inputs=[image_input, text_prompt],
    )
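
    # Assumption (not in the original file): for long CPU generations, enabling
    # Gradio's request queue before launching can help avoid request timeouts:
    #   demo.queue()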

if __name__ == "__main__":
    demo.launch()