import os

import gradio as gr
import torch
from PIL import Image
from transformers import Qwen2_5_VLForConditionalGeneration, AutoProcessor, pipeline

# Load the Hugging Face token securely from the Space secrets.
HF_TOKEN = os.getenv("HF_TOKEN")

# Models are loaded lazily on first request so the Space starts quickly.
lingshu_model = None
lingshu_processor = None
medgemma_pipe = None


def load_lingshu():
    """Load the Lingshu-7B model and processor once and cache them globally."""
    global lingshu_model, lingshu_processor
    if lingshu_model is None or lingshu_processor is None:
        lingshu_model = Qwen2_5_VLForConditionalGeneration.from_pretrained(
            "lingshu-medical-mllm/Lingshu-7B",
            torch_dtype=torch.bfloat16,
            device_map="auto",
        )
        lingshu_processor = AutoProcessor.from_pretrained("lingshu-medical-mllm/Lingshu-7B")
    return lingshu_model, lingshu_processor


def load_medgemma():
    """Build the MedGemma-27B-IT pipeline once and cache it globally."""
    global medgemma_pipe
    if medgemma_pipe is None:
        medgemma_pipe = pipeline(
            "image-text-to-text",
            model="google/medgemma-27b-it",
            torch_dtype=torch.bfloat16,
            device="cuda",
            token=HF_TOKEN,  # MedGemma is gated; `use_auth_token` is deprecated in favor of `token`
        )
    return medgemma_pipe


def inference(image, question, selected_model):
    if image is None or question is None or question.strip() == "":
        return "Please upload a medical image and enter your question or prompt."

    if selected_model == "Lingshu-7B":
        model, processor = load_lingshu()
        messages = [
            {
                "role": "user",
                "content": [
                    {"type": "image", "image": image},
                    {"type": "text", "text": question},
                ],
            }
        ]
        text = processor.apply_chat_template(messages, tokenize=False, add_generation_prompt=True)
        inputs = processor(
            text=[text],
            images=[image],
            padding=True,
            return_tensors="pt",
        ).to(model.device)
        with torch.no_grad():
            generated_ids = model.generate(**inputs, max_new_tokens=128)
        # Drop the prompt tokens so only the newly generated answer is decoded.
        trim_ids = generated_ids[:, inputs.input_ids.shape[1]:]
        out_text = processor.batch_decode(trim_ids, skip_special_tokens=True)
        return out_text[0] if out_text else "No response."

    elif selected_model == "MedGemma-27B-IT":
        pipe = load_medgemma()
        messages = [
            {"role": "system", "content": [{"type": "text", "text": "You are a medical expert."}]},
            {
                "role": "user",
                "content": [
                    {"type": "text", "text": question},
                    {"type": "image", "image": image},
                ],
            },
        ]
        try:
            res = pipe(text=messages, max_new_tokens=200)
            # The pipeline returns the full chat; the assistant's reply is the last message.
            return res[0]["generated_text"][-1]["content"]
        except Exception as e:
            return f"MedGemma error: {str(e)}"

    return "Please select a valid model."


with gr.Blocks() as demo:
    gr.Markdown(
        "## 🩺 Multi-Modality Medical AI Doctor Companion\n"
        "Upload a medical image, type your question, and select a model to generate an automated analysis or report."
    )
    model_radio = gr.Radio(label="Model", choices=["Lingshu-7B", "MedGemma-27B-IT"], value="Lingshu-7B")
    image_input = gr.Image(type="pil", label="Medical Image")
    text_input = gr.Textbox(lines=2, label="Prompt", value="Describe this image.")
    outbox = gr.Textbox(lines=10, label="AI Answer / Report", interactive=False)
    run_btn = gr.Button("Run Analysis")
    run_btn.click(inference, [image_input, text_input, model_radio], outbox)

demo.launch()
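
# ---------------------------------------------------------------------------
# Setup notes (a minimal sketch under assumed defaults; adjust for your Space).
# Every package below is imported or implicitly required above: `accelerate`
# backs device_map="auto", and HF_TOKEN must be set as a Space secret because
# the MedGemma checkpoint is gated.
#
#   requirements.txt (unpinned):
#       gradio
#       torch
#       transformers
#       accelerate
#       Pillow
#
# Quick local smoke test, assuming a hypothetical sample image "sample_xray.png":
#
#       from PIL import Image
#       img = Image.open("sample_xray.png")
#       print(inference(img, "Describe this image.", "Lingshu-7B"))
# ---------------------------------------------------------------------------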