import os os.environ['ANONYMIZED_TELEMETRY'] = 'False' import zipfile import chromadb from sentence_transformers import SentenceTransformer import gradio as gr from fastapi import FastAPI from pydantic import BaseModel from typing import List, Optional import re import time # Extract and load database DB_PATH = "./medqa_db" if not os.path.exists(DB_PATH) and os.path.exists("./medqa_db.zip"): print("šŸ“¦ Extracting database...") with zipfile.ZipFile("./medqa_db.zip", 'r') as z: z.extractall(".") print("āœ… Database extracted") print("šŸ“Œ Loading ChromaDB...") client = chromadb.PersistentClient(path=DB_PATH) collection = client.get_collection("medqa") print(f"āœ… Loaded {collection.count()} questions") print("🧠 Loading MedCPT model...") model = SentenceTransformer('ncbi/MedCPT-Query-Encoder') print("āœ… Model ready") # ============================================================================ # Deduplication function # ============================================================================ def deduplicate_results(results, target_count): """ Remove duplicate questions based on: 1. High text similarity (>0.92) - catches near-exact duplicates 2. Same answer + moderate similarity (>0.85) - catches conceptual duplicates """ if not results['documents'][0]: return results documents = results['documents'][0] metadatas = results['metadatas'][0] distances = results['distances'][0] selected_indices = [] for i in range(len(documents)): is_duplicate = False current_answer = metadatas[i].get('answer', '') for j in selected_indices: selected_answer = metadatas[j].get('answer', '') dist_diff = abs(distances[i] - distances[j]) if dist_diff < 0.08: is_duplicate = True break if current_answer == selected_answer and dist_diff < 0.15: is_duplicate = True break if not is_duplicate: selected_indices.append(i) if len(selected_indices) >= target_count: break return { 'documents': [[documents[i] for i in selected_indices]], 'metadatas': [[metadatas[i] for i in selected_indices]], 'distances': [[distances[i] for i in selected_indices]], 'ids': [[results['ids'][0][i] for i in selected_indices]] if 'ids' in results else None } # ============================================================================ # Search function with deduplication # ============================================================================ def search(query, num_results=3, source_filter=None): emb = model.encode(query).tolist() where_clause = None if source_filter and source_filter != "all": where_clause = {"source": source_filter} fetch_count = min(num_results * 4, 50) results = collection.query( query_embeddings=[emb], n_results=fetch_count, where=where_clause ) return deduplicate_results(results, num_results) # ============================================================================ # Parser to extract question structure # ============================================================================ def parse_question_document(doc_text, metadata): """Extract question and choices from document text - NO TRUNCATION.""" lines = doc_text.split('\n') question_lines = [] options_started = False options = {} for line in lines: line = line.strip() if not line: continue # Check if this is an option line (A., B., C., etc.) option_match = re.match(r'^([A-E])[\.\)]\s*(.+)$', line) if option_match: options_started = True letter = option_match.group(1) text = option_match.group(2).strip() options[letter] = text elif not options_started: question_lines.append(line) # Reconstruct FULL question text - no truncation question_text = ' '.join(question_lines).strip() answer_idx = metadata.get('answer_idx', 'N/A') answer_text = metadata.get('answer', 'N/A') # If answer_text is just the letter, map it to the actual option text if answer_text in options: answer_text = options[answer_text] return { 'question': question_text, 'choices': options, 'correct_answer_letter': answer_idx, 'correct_answer_text': answer_text } # ============================================================================ # Enhanced Gradio UI # ============================================================================ def ui_search(query, num_results=3, source_filter="all"): if not query.strip(): return "šŸ’” Enter a medical query to search" try: r = search(query, num_results, source_filter if source_filter != "all" else None) if not r['documents'][0]: return "āŒ No results found" out = f"šŸ” Found {len(r['documents'][0])} unique results\n\n" for i in range(len(r['documents'][0])): source = r['metadatas'][0][i].get('source', 'unknown') distance = r['distances'][0][i] similarity = 1 - distance # Source emoji if source == 'medgemini': source_icon = "šŸ”¬" source_name = "Med-Gemini" elif source.startswith('medqa_'): source_icon = "šŸ“š" split = source.replace('medqa_', '').upper() source_name = f"MedQA {split}" else: source_icon = "šŸ“„" source_name = source.upper() out += f"\n{'='*70}\n" out += f"{source_icon} Result {i+1} | {source_name} | Similarity: {similarity:.3f}\n" out += f"{'='*70}\n\n" out += r['documents'][0][i] answer = r['metadatas'][0][i].get('answer', 'N/A') out += f"\n\nāœ… CORRECT ANSWER: {answer}\n" explanation = r['metadatas'][0][i].get('explanation', '') if explanation and explanation.strip(): out += f"\nšŸ’” EXPLANATION:\n{explanation}\n" out += "\n" return out except Exception as e: return f"āŒ Error: {e}" # Create Gradio interface with gr.Blocks(theme=gr.themes.Soft(), title="MedQA Search") as demo: gr.Markdown(""" # 🄼 MedQA Semantic Search Search across **Med-Gemini** (expert explanations) and **MedQA** (USMLE questions) databases. Uses medical-specific embeddings (MedCPT) for accurate retrieval. ✨ **Features**: Automatic deduplication, structured output for AI integration """) with gr.Row(): with gr.Column(scale=3): query_input = gr.Textbox( label="Medical Query", placeholder="e.g., hyponatremia, myocardial infarction, diabetes management...", lines=2 ) with gr.Column(scale=1): num_results = gr.Slider( minimum=1, maximum=10, value=3, step=1, label="Number of Results" ) with gr.Row(): source_filter = gr.Radio( choices=["all", "medgemini", "medqa_train", "medqa_dev", "medqa_test"], value="all", label="Filter by Source" ) search_btn = gr.Button("šŸ” Search", variant="primary", size="lg") output = gr.Textbox( label="Search Results", lines=25, max_lines=50 ) search_btn.click( fn=ui_search, inputs=[query_input, num_results, source_filter], outputs=output ) query_input.submit( fn=ui_search, inputs=[query_input, num_results, source_filter], outputs=output ) gr.Markdown(""" ### šŸ“Š Database Info **Med-Gemini**: Expert-relabeled questions with detailed explanations **MedQA**: USMLE-style questions (Train/Dev/Test splits) **Total Questions**: ~10,000+ USMLE-style questions """) gr.Examples( examples=[ ["hyponatremia", 3, "all"], ["myocardial infarction treatment", 2, "medgemini"], ["diabetes complications", 3, "all"], ["antibiotics for pneumonia", 2, "medqa_train"] ], inputs=[query_input, num_results, source_filter] ) # ============================================================================ # FastAPI with structured JSON output (for OpenAI integration) # ============================================================================ app = FastAPI() class SearchRequest(BaseModel): query: str num_results: int = 3 source_filter: str = None class BatchSearchRequest(BaseModel): queries: List[str] num_results_per_query: int = 10 source_filter: Optional[str] = None @app.post("/search_medqa") def api_search(req: SearchRequest): """ Search MedQA and return structured exemplars. Returns COMPLETE question text with no truncation. """ r = search(req.query, req.num_results, req.source_filter) if not r['documents'][0]: return {"results": []} results = [] for i in range(len(r['documents'][0])): doc_text = r['documents'][0][i] metadata = r['metadatas'][0][i] # Parse the document into structured format parsed = parse_question_document(doc_text, metadata) # Build complete result object result = { "result_number": i + 1, "question": parsed['question'], # FULL question text "choices": parsed['choices'], "correct_answer": parsed['correct_answer_letter'], "correct_answer_text": parsed['correct_answer_text'], "explanation": metadata.get('explanation', ''), "has_explanation": bool(metadata.get('explanation', '').strip()), "source": metadata.get('source', 'unknown'), "exam_type": metadata.get('exam_type', 'unknown'), "split": metadata.get('split', 'unknown'), "similarity": round(1 - r['distances'][0][i], 3), "metamap_phrases": metadata.get('metamap_phrases', '') } results.append(result) return {"results": results} @app.post("/batch_search_medqa") def batch_api_search(req: BatchSearchRequest): """ NEW: Batch search for multiple learning objectives. Processes all queries, tracks duplicates, and returns organized results. Returns: - results_by_objective: List of results organized by each objective - unique_questions: Deduplicated list of all questions - statistics: Coverage and quality metrics """ start_time = time.time() # Track all questions and their objective mappings all_questions = {} # key: question_text, value: question data + objectives list results_by_objective = [] for obj_idx, query in enumerate(req.queries): objective_id = obj_idx + 1 # Search for this objective r = search(query, req.num_results_per_query, req.source_filter) objective_results = [] similarities = [] if r['documents'][0]: for i in range(len(r['documents'][0])): doc_text = r['documents'][0][i] metadata = r['metadatas'][0][i] similarity = round(1 - r['distances'][0][i], 3) similarities.append(similarity) # Parse the document parsed = parse_question_document(doc_text, metadata) # Create unique key for deduplication question_key = parsed['question'][:200] # Use first 200 chars as key # Build result object result = { "question": parsed['question'], "choices": parsed['choices'], "correct_answer": parsed['correct_answer_letter'], "correct_answer_text": parsed['correct_answer_text'], "explanation": metadata.get('explanation', ''), "has_explanation": bool(metadata.get('explanation', '').strip()), "source": metadata.get('source', 'unknown'), "similarity": similarity } # Track for global deduplication if question_key in all_questions: # This question already exists - add this objective to its list all_questions[question_key]['matches_objectives'].append(objective_id) # Update similarity if higher if similarity > all_questions[question_key]['max_similarity']: all_questions[question_key]['max_similarity'] = similarity else: # First time seeing this question all_questions[question_key] = { **result, 'matches_objectives': [objective_id], 'max_similarity': similarity, 'first_seen_at': objective_id } objective_results.append(result) # Store results for this objective results_by_objective.append({ "objective_id": objective_id, "objective_text": query, "num_results": len(objective_results), "avg_similarity": round(sum(similarities) / len(similarities), 3) if similarities else 0, "results": objective_results }) # Prepare unique questions list unique_questions = list(all_questions.values()) # Calculate statistics execution_time = round(time.time() - start_time, 2) total_retrieved = sum(obj['num_results'] for obj in results_by_objective) # Coverage analysis coverage = { "excellent": [obj for obj in results_by_objective if obj['num_results'] >= 5], "moderate": [obj for obj in results_by_objective if 2 <= obj['num_results'] < 5], "limited": [obj for obj in results_by_objective if obj['num_results'] == 1], "none": [obj for obj in results_by_objective if obj['num_results'] == 0] } # Multi-objective questions multi_objective_questions = [q for q in unique_questions if len(q['matches_objectives']) > 1] # Source distribution sources = {} for q in unique_questions: source = q['source'] sources[source] = sources.get(source, 0) + 1 # Similarity distribution all_similarities = [q['max_similarity'] for q in unique_questions] high_sim = len([s for s in all_similarities if s > 0.8]) med_sim = len([s for s in all_similarities if 0.7 <= s <= 0.8]) low_sim = len([s for s in all_similarities if s < 0.7]) statistics = { "total_objectives": len(req.queries), "total_retrieved": total_retrieved, "unique_questions": len(unique_questions), "deduplication_rate": round((total_retrieved - len(unique_questions)) / total_retrieved * 100, 1) if total_retrieved > 0 else 0, "execution_time_seconds": execution_time, "coverage": { "excellent_coverage_count": len(coverage["excellent"]), "moderate_coverage_count": len(coverage["moderate"]), "limited_coverage_count": len(coverage["limited"]), "no_coverage_count": len(coverage["none"]), "no_coverage_objectives": [obj['objective_id'] for obj in coverage["none"]] }, "cross_objective": { "multi_objective_questions": len(multi_objective_questions), "multi_objective_percentage": round(len(multi_objective_questions) / len(unique_questions) * 100, 1) if unique_questions else 0 }, "sources": sources, "similarity_distribution": { "high_similarity_count": high_sim, "medium_similarity_count": med_sim, "low_similarity_count": low_sim, "average_similarity": round(sum(all_similarities) / len(all_similarities), 3) if all_similarities else 0 } } return { "results_by_objective": results_by_objective, "unique_questions": unique_questions, "statistics": statistics } app = gr.mount_gradio_app(app, demo, path="/") if __name__ == "__main__": import uvicorn uvicorn.run(app, host="0.0.0.0", port=7860)