🩺 MedNarrate AI
Multimodal Medical Assistant
AI-powered analysis of medical records, images, and clinical notes
import os import torch import gradio as gr from transformers import ( AutoTokenizer, AutoModelForCausalLM, BlipProcessor, BlipForConditionalGeneration, pipeline, SpeechT5Processor, SpeechT5ForTextToSpeech, SpeechT5HifiGan ) from diffusers import StableDiffusionPipeline, DPMSolverMultistepScheduler from PIL import Image import PyPDF2 from datetime import datetime from datasets import load_dataset import soundfile as sf import numpy as np # Set device device = "cuda" if torch.cuda.is_available() else "cpu" print(f"Using device: {device}") # Load models with error handling and optimization print("Loading models...") try: # Use smaller, more efficient summarization model summarizer = pipeline( "summarization", model="sshleifer/distilbart-cnn-6-6", device=0 if device == "cuda" else -1 ) print("✅ Summarizer loaded") except Exception as e: print(f"❌ Error loading summarizer: {e}") summarizer = None try: # Load BLIP with memory optimization (using base model instead of large) blip_processor = BlipProcessor.from_pretrained("Salesforce/blip-image-captioning-base") blip_model = BlipForConditionalGeneration.from_pretrained( "Salesforce/blip-image-captioning-base", torch_dtype=torch.float16 if device == "cuda" else torch.float32, low_cpu_mem_usage=True ).to(device) print("✅ BLIP model loaded") except Exception as e: print(f"❌ Error loading BLIP: {e}") blip_model = None blip_processor = None try: # Load Stable Diffusion with optimizations sd_pipe = StableDiffusionPipeline.from_pretrained( "stabilityai/stable-diffusion-2-1-base", torch_dtype=torch.float16 if device == "cuda" else torch.float32, safety_checker=None, low_cpu_mem_usage=True, variant="fp16" if device == "cuda" else None ).to(device) sd_pipe.scheduler = DPMSolverMultistepScheduler.from_config(sd_pipe.scheduler.config) # Enable memory efficient attention if available if device == "cuda": try: sd_pipe.enable_attention_slicing() sd_pipe.enable_vae_slicing() except: pass print("✅ Stable Diffusion loaded") except Exception as e: print(f"❌ Error loading Stable Diffusion: {e}") sd_pipe = None try: # Load BioGPT for medical text enhancement biogpt_tokenizer = AutoTokenizer.from_pretrained("microsoft/biogpt") biogpt_model = AutoModelForCausalLM.from_pretrained( "microsoft/biogpt", low_cpu_mem_usage=True ).to(device) print("✅ BioGPT loaded") except Exception as e: print(f"❌ Error loading BioGPT: {e}") biogpt_model = None biogpt_tokenizer = None try: # Load SpeechT5 for Text-to-Speech (lightweight alternative to Bark) tts_processor = SpeechT5Processor.from_pretrained("microsoft/speecht5_tts") tts_model = SpeechT5ForTextToSpeech.from_pretrained("microsoft/speecht5_tts").to(device) tts_vocoder = SpeechT5HifiGan.from_pretrained("microsoft/speecht5_hifigan").to(device) # Load speaker embeddings embeddings_dataset = load_dataset("Matthijs/cmu-arctic-xvectors", split="validation") speaker_embeddings = torch.tensor(embeddings_dataset[7306]["xvector"]).unsqueeze(0).to(device) print("✅ SpeechT5 TTS loaded") except Exception as e: print(f"❌ Error loading TTS: {e}") tts_processor = None tts_model = None tts_vocoder = None speaker_embeddings = None print("Model loading complete!") # Helper Functions def extract_text_from_pdf(pdf_file): """Extract text from uploaded PDF file""" try: pdf_reader = PyPDF2.PdfReader(pdf_file) text = "" for page in pdf_reader.pages: text += page.extract_text() return text except Exception as e: return f"Error extracting PDF: {str(e)}" def summarize_medical_text(text, max_length=150, min_length=50): """Summarize medical text using BART""" if summarizer is None: return "Summarization model not available." try: # Split text into chunks if too long max_input_length = 1024 if len(text.split()) > max_input_length: text = ' '.join(text.split()[:max_input_length]) if not text.strip(): return "No text to summarize." summary = summarizer( text, max_length=max_length, min_length=min_length, do_sample=False ) return summary[0]['summary_text'] except Exception as e: return f"Error in summarization: {str(e)}" def analyze_medical_image(image): """Analyze medical image and generate description""" if blip_model is None or blip_processor is None: return "Image analysis model not available." try: # Process image inputs = blip_processor(image, return_tensors="pt").to( device, torch.float16 if device == "cuda" else torch.float32 ) # Generate caption out = blip_model.generate(**inputs, max_length=100) caption = blip_processor.decode(out[0], skip_special_tokens=True) # Enhance with medical context using BioGPT if biogpt_model is not None and biogpt_tokenizer is not None: medical_prompt = f"Medical imaging analysis: {caption}. Detailed clinical observations:" inputs = biogpt_tokenizer(medical_prompt, return_tensors="pt").to(device) with torch.no_grad(): outputs = biogpt_model.generate( **inputs, max_length=150, num_return_sequences=1, temperature=0.7, do_sample=True ) enhanced_description = biogpt_tokenizer.decode(outputs[0], skip_special_tokens=True) return enhanced_description else: return f"Medical Image Analysis: {caption}" except Exception as e: return f"Error analyzing image: {str(e)}" def generate_medical_visualization(prompt): """Generate medical visualization using Stable Diffusion""" if sd_pipe is None: return None try: # Enhance prompt for medical context enhanced_prompt = f"medical illustration, {prompt}, anatomical diagram, clinical style, high quality, detailed" # Generate image image = sd_pipe( enhanced_prompt, num_inference_steps=30, guidance_scale=7.5 ).images[0] return image except Exception as e: print(f"Error generating visualization: {str(e)}") return None def text_to_speech_conversion(text): """Convert text to speech using Microsoft SpeechT5""" if tts_processor is None or tts_model is None or tts_vocoder is None or speaker_embeddings is None: return None try: # Limit text length for TTS (max ~600 characters for best results) if len(text) > 600: text = text[:600] + "..." # Process input text inputs = tts_processor(text=text, return_tensors="pt").to(device) # Generate speech with torch.no_grad(): speech = tts_model.generate_speech( inputs["input_ids"], speaker_embeddings, vocoder=tts_vocoder ) # Convert to numpy and return as audio tuple (sample_rate, audio_data) speech_np = speech.cpu().numpy() sample_rate = 16000 # SpeechT5 uses 16kHz return (sample_rate, speech_np) except Exception as e: print(f"Error in TTS: {str(e)}") return None def generate_patient_report(clinical_notes, image_analysis, summary): """Generate comprehensive patient report""" report = f""" ╔══════════════════════════════════════════════════════════╗ ║ MEDNARRATE AI - PATIENT REPORT ║ ║ Generated on: {datetime.now().strftime('%Y-%m-%d %H:%M:%S')} ║ ╚══════════════════════════════════════════════════════════╝ 📋 EXECUTIVE SUMMARY: {summary} 🔬 CLINICAL NOTES ANALYSIS: {clinical_notes[:500]}{'...' if len(clinical_notes) > 500 else ''} 🏥 MEDICAL IMAGING FINDINGS: {image_analysis} 📊 RECOMMENDATION: This report is AI-generated and should be reviewed by a licensed healthcare professional. For accurate diagnosis and treatment, please consult with your physician. ════════════════════════════════════════════════════════════ """ return report def process_medical_data(clinical_text, medical_image, pdf_file, generate_viz, generate_audio): """ Main function that processes all inputs through the multimodal pipeline Args: clinical_text: Direct text input from doctor's notes medical_image: Uploaded medical scan (X-ray, MRI, etc.) pdf_file: Uploaded PDF with patient records generate_viz: Boolean to generate visualization generate_audio: Boolean to generate audio narration Returns: Tuple of (summary, image_analysis, full_report, visualization, audio) """ results = { 'summary': '', 'image_analysis': 'No image provided', 'report': '', 'visualization': None, 'audio': None } try: # STEP 1: Process Text Input print("📝 Processing text input...") full_text = "" if pdf_file is not None: print(" → Extracting text from PDF...") pdf_text = extract_text_from_pdf(pdf_file) full_text += pdf_text + "\n\n" if clinical_text and clinical_text.strip(): full_text += clinical_text if not full_text.strip(): full_text = "No clinical notes provided." # STEP 2: Generate Summary print("📊 Generating clinical summary...") summary = summarize_medical_text(full_text) results['summary'] = summary # STEP 3: Analyze Medical Image if medical_image is not None: print("🔬 Analyzing medical image...") image_analysis = analyze_medical_image(medical_image) results['image_analysis'] = image_analysis else: results['image_analysis'] = "No medical image provided for analysis." # STEP 4: Generate Comprehensive Report print("📋 Generating comprehensive report...") full_report = generate_patient_report( clinical_notes=full_text, image_analysis=results['image_analysis'], summary=results['summary'] ) results['report'] = full_report # STEP 5: Generate Visualization (Optional) if generate_viz and results['image_analysis'] != "No medical image provided for analysis.": print("🎨 Generating medical visualization...") viz_prompt = f"anatomical visualization of {results['image_analysis'][:100]}" visualization = generate_medical_visualization(viz_prompt) results['visualization'] = visualization # STEP 6: Generate Audio Narration (Optional) if generate_audio: print("🔊 Generating audio narration...") narration_text = f"Medical Report Summary. {results['summary']}" audio_output = text_to_speech_conversion(narration_text) results['audio'] = audio_output print("✅ Processing complete!") return ( results['summary'], results['image_analysis'], results['report'], results['visualization'], results['audio'] ) except Exception as e: error_msg = f"❌ Error in processing: {str(e)}" print(error_msg) return ( error_msg, "Error occurred", error_msg, None, None ) def create_interface(): """Create the Gradio web interface""" # Custom CSS for better styling custom_css = """ .gradio-container { font-family: 'Arial', sans-serif; } .main-header { text-align: center; background: linear-gradient(135deg, #667eea 0%, #764ba2 100%); padding: 20px; border-radius: 10px; color: white; margin-bottom: 20px; } """ # Create the interface with gr.Blocks(css=custom_css, title="MedNarrate AI") as demo: # Header gr.HTML("""
AI-powered analysis of medical records, images, and clinical notes
⚠️ Disclaimer: This is an AI demonstration tool for educational purposes.
Not intended for actual medical diagnosis. Always consult qualified healthcare professionals.
Powered by 🤗 Hugging Face Transformers | Built with Gradio
TTS: Microsoft SpeechT5 | Image Analysis: BLIP | Text: BioGPT