# ================================
# ✅ Cache-Safe Multimodal App
# ================================
import shutil, os

# ====== Force all cache dirs to /tmp (writable in most environments) ======
CACHE_BASE = "/tmp/cache"
os.environ["HF_HOME"] = f"{CACHE_BASE}/hf_home"
os.environ["TRANSFORMERS_CACHE"] = f"{CACHE_BASE}/transformers"
os.environ["SENTENCE_TRANSFORMERS_HOME"] = f"{CACHE_BASE}/sentence_transformers"
os.environ["HF_DATASETS_CACHE"] = f"{CACHE_BASE}/hf_datasets"
os.environ["TORCH_HOME"] = f"{CACHE_BASE}/torch"
os.environ["STREAMLIT_CACHE_DIR"] = f"{CACHE_BASE}/streamlit_cache"
os.environ["STREAMLIT_STATIC_DIR"] = f"{CACHE_BASE}/streamlit_static"
os.environ["STREAMLIT_CONFIG_DIR"] = "/tmp/.streamlit"
# Create the cache directories before the ML libraries are imported
os.makedirs(os.environ["STREAMLIT_CONFIG_DIR"], exist_ok=True)
for path in os.environ.values():
    if path.startswith(CACHE_BASE):
        os.makedirs(path, exist_ok=True)
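# NOTE: on hosted platforms such as Hugging Face Spaces the default cache
# locations may not be writable, which is why everything above is redirected to
# /tmp before any of the ML libraries are imported and read these variables.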
# ====== Imports ======
import streamlit as st
import torch
from sentence_transformers import SentenceTransformer, util
from transformers import CLIPProcessor, CLIPModel
from datasets import load_dataset, get_dataset_split_names
from PIL import Image
from openai import OpenAI
import comet_llm
from opik import track
# ========== 🔑 API Keys ==========
OPENAI_API_KEY = os.getenv("OPENAI_API_KEY")
# Only forward the Opik settings when they are present; assigning None to os.environ raises a TypeError
if os.getenv("OPIK_API_KEY"):
    os.environ["OPIK_API_KEY"] = os.getenv("OPIK_API_KEY")
if os.getenv("OPIK_WORKSPACE"):
    os.environ["OPIK_WORKSPACE"] = os.getenv("OPIK_WORKSPACE")
# ========== 📥 Load Models ==========
@st.cache_resource  # load the models once per process instead of on every Streamlit rerun
def load_models():
    _clip_model = CLIPModel.from_pretrained(
        "openai/clip-vit-base-patch32",
        cache_dir=os.environ["TRANSFORMERS_CACHE"]
    )
    _clip_processor = CLIPProcessor.from_pretrained(
        "openai/clip-vit-base-patch32",
        cache_dir=os.environ["TRANSFORMERS_CACHE"]
    )
    _text_model = SentenceTransformer(
        "all-MiniLM-L6-v2",
        cache_folder=os.environ["SENTENCE_TRANSFORMERS_HOME"]
    )
    return _clip_model, _clip_processor, _text_model

clip_model, clip_processor, text_model = load_models()
# ========== 📥 Load Dataset ==========
@st.cache_resource  # keep a single copy of the dataset across reruns
def load_medical_data():
    available_splits = get_dataset_split_names("univanxx/3mdbench")
    split_to_use = "train" if "train" in available_splits else available_splits[0]
    dataset = load_dataset(
        "univanxx/3mdbench",
        split=split_to_use,
        cache_dir=os.environ["HF_DATASETS_CACHE"]
    )
    return dataset
# Cache dataset image embeddings (takes time, so cached)
@st.cache_data
def embed_dataset_images(_dataset):
    features = []
    for item in _dataset:
        # Load image from URL/path or raw bytes - adapt this if needed
        img = item["image"]
        inputs_img = clip_processor(images=img, return_tensors="pt")
        with torch.no_grad():
            feat = clip_model.get_image_features(**inputs_img)
        feat /= feat.norm(p=2, dim=-1, keepdim=True)
        features.append(feat.cpu())
    return torch.cat(features, dim=0)
data = load_medical_data()
dataset_image_features = embed_dataset_images(data)
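# NOTE: embedding every dataset image runs once per process (cached above) and
# can take a while on CPU; the resulting [num_images, 512] CLIP feature matrix
# is reused for all image searches below.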
client = OpenAI(api_key=OPENAI_API_KEY) if OPENAI_API_KEY else None

# Temporary debug display
#st.write("Dataset columns:", data.features.keys())
# Fall back to the first available column if the dataset has no "text" field:
text_field = "text" if "text" in data.features else list(data.features.keys())[0]
def prepare_combined_texts(_dataset):
    combined = []
    for gc, c in zip(_dataset["general_complaint"], _dataset["complaints"]):
        gc_str = gc if gc else ""
        c_str = c if c else ""
        combined.append(f"General complaint: {gc_str}. Additional details: {c_str}")
    return combined

combined_texts = prepare_combined_texts(data)

# Then use dynamic access:
#text_embeddings = embed_texts(data[text_field])
# ========== 🧠 Embedding Functions ==========
@st.cache_data  # the combined case texts never change, so embed them only once per process
def embed_dataset_texts(_texts):
    return text_model.encode(_texts, convert_to_tensor=True)

def embed_query_text(_query):
    return text_model.encode([_query], convert_to_tensor=True)[0]
def get_chat_completion_openai(_client, _prompt: str):
    return _client.chat.completions.create(
        model="gpt-4o",  # or "gpt-4" if you need the older GPT-4
        messages=[{"role": "user", "content": _prompt}],
        temperature=0.5,
        max_tokens=425
    )
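# NOTE: the `track` decorator imported from opik is currently unused; placing
# `@track` above get_chat_completion_openai would log each LLM call to
# Comet/Opik, assuming the OPIK_* secrets above are configured.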
def get_similar_prompt(_query):
    text_embeddings = embed_dataset_texts(combined_texts)  # cached
    query_embedding = embed_query_text(_query)  # recalculated each time
    cos_scores = util.pytorch_cos_sim(query_embedding, text_embeddings)[0]
    top_result = torch.topk(cos_scores, k=1)
    _idx = top_result.indices[0].item()
    return data[_idx]

# Pick which text column to use
TEXT_COLUMN = "complaints"  # or "general_complaint", depending on your needs
# ========== 🧑‍⚕️ App UI ==========
st.title("🩺 Multimodal Medical Chatbot")
query = st.text_input("Enter your medical question or symptom description:")
uploaded_files = st.file_uploader("Upload an image to find similar medical cases:", type=["png", "jpg", "jpeg"], accept_multiple_files=True)

# Add author info in the sidebar
with st.sidebar:
    st.markdown("## 👤 Authors")
    st.markdown("**Vasan Iyer**")
    st.markdown("**Eric J Giacomucci**")
    st.markdown("[GitHub](https://github.com/Vaiy108)")
    st.markdown("[LinkedIn](https://linkedin.com/in/vasan-iyer)")
| if st.button("Submit") and query: | |
| with st.spinner("Searching medical cases..."): | |
| # Compute similarity | |
| selected = get_similar_prompt(query) | |
| # Show Image | |
| st.image(selected['image'], caption="Most relevant medical image", use_container_width=True) | |
| # Show Text | |
| st.markdown(f"**Case Description:** {selected[TEXT_COLUMN]}") | |
| # GPT Explanation | |
| if OpenAI.api_key: | |
| prompt = f"Explain this case in plain English: {selected[TEXT_COLUMN]}" | |
| explanation = get_chat_completion_openai(client, prompt) | |
| explanation = explanation.choices[0].message.content | |
| st.markdown(f"### 🤖 Explanation by GPT:\n{explanation}") | |
| else: | |
| st.warning("OpenAI API key not found. Please set OPENAI_API_KEY as a secret environment variable.") | |
if uploaded_files:
    with st.spinner("Searching medical cases..."):
        st.write(f"Number of files: {len(uploaded_files)}")
        # Only the first uploaded image is used for the search
        uploaded_file = uploaded_files[0]
        st.write(f"Uploaded file: {uploaded_file.name}")
        query_image = Image.open(uploaded_file).convert("RGB")
        st.image(query_image, caption="Your uploaded image", use_container_width=True)
        # Embed uploaded image
        inputs = clip_processor(images=query_image, return_tensors="pt")
        with torch.no_grad():
            query_feat = clip_model.get_image_features(**inputs)
        query_feat /= query_feat.norm(p=2, dim=-1, keepdim=True)
        # Compute cosine similarity (features are L2-normalized, so the dot product is the cosine)
        similarities = (dataset_image_features @ query_feat.T).squeeze(1)  # [num_dataset_images]
        top_k = 3
        top_results = torch.topk(similarities, k=top_k)
        st.write(f"Top {top_k} similar medical cases:")
        for rank, idx in enumerate(top_results.indices):
            score = top_results.values[rank].item()
            similar_img = data[int(idx)]['image']
            st.image(similar_img, caption=f"Similarity: {score:.3f}", use_container_width=True)
            st.markdown(f"**Case description:** {data[int(idx)]['complaints']}")
st.caption("This chatbot is for educational purposes only and does not provide medical advice.")
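# To run locally (assuming this file is the Streamlit entry point, e.g. app.py):
#   streamlit run app.py
# with streamlit, torch, transformers, sentence-transformers, datasets, pillow,
# openai, comet_llm and opik installed.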