import os

# =====================================================
# Environment setup
# (set cache paths before transformers is imported,
# since they are read at import time)
# =====================================================
os.environ["TRANSFORMERS_CACHE"] = "/tmp/hf_cache"
os.environ["HF_HOME"] = "/tmp/hf_home"

import torch
import gradio as gr
from transformers import AutoTokenizer, AutoModelForCausalLM

# =====================================================
# Model configuration
# =====================================================
GEN_MODEL = "hackergeek/qwen3-harrison-rag"
HF_TOKEN = os.getenv("HF_TOKEN")

if not HF_TOKEN:
    print("⚠️ No Hugging Face token found. Set one using:")
    print("   export HF_TOKEN='your_hf_token_here'")

# =====================================================
# Load private model
# =====================================================
def load_private_model(model_name, token):
    # Half precision on GPU, full precision on CPU.
    dtype_value = torch.float16 if torch.cuda.is_available() else torch.float32
    load_kwargs = {
        "torch_dtype": dtype_value,  # long-standing kwarg name; newer transformers also accept `dtype`
        "cache_dir": "/tmp/hf_cache",
        "low_cpu_mem_usage": True,
    }
    # `device_map="auto"` requires accelerate; fall back gracefully without it.
    try:
        import accelerate  # noqa: F401
        load_kwargs["device_map"] = "auto"
    except ImportError:
        print("⚠️ `accelerate` not installed — default device placement used.")

    tokenizer = AutoTokenizer.from_pretrained(model_name, token=token)
    model = AutoModelForCausalLM.from_pretrained(model_name, token=token, **load_kwargs)
    return tokenizer, model

tokenizer, model = load_private_model(GEN_MODEL, token=HF_TOKEN)

# =====================================================
# Dynamic token allocation
# =====================================================
def calculate_max_tokens(query, min_tokens=1000, max_tokens=8192, factor=8):
    # Scale the generation budget with query length, clamped to [min_tokens, max_tokens].
    query_tokens = len(tokenizer(query)["input_ids"])
    dynamic_tokens = query_tokens * factor
    return min(max(dynamic_tokens, min_tokens), max_tokens)

# =====================================================
# Generate long, complete, structured answers
# =====================================================
def generate_answer(query, history):
    if not query.strip():
        return history

    # Correct common typos
    corrected_query = query.replace("COPP", "COPD")

    # Step 1: rephrase the query for precise retrieval
    rephrase_prompt = (
        "You are a medical assistant. Rephrase this query for precise retrieval:\n\n"
        f"Query: {corrected_query}\n\nRephrased query:"
    )
    inputs = tokenizer(rephrase_prompt, return_tensors="pt").to(model.device)
    rephrased_ids = model.generate(**inputs, max_new_tokens=128, do_sample=False)
    rephrased_query = tokenizer.decode(
        rephrased_ids[0], skip_special_tokens=True
    ).split("Rephrased query:")[-1].strip()

    # Step 2: generate the detailed structured answer
    max_tokens = calculate_max_tokens(rephrased_query)
    prompt = (
        "You are a retrieval-augmented medical assistant. Provide a **long, detailed, structured** medical answer "
        "as if writing a concise clinical guideline. Use markdown headings and bullet points. "
" "Each section should include multiple complete sentences and clear explanations.\n\n" "Follow this structure:\n" "### Definition / Description\n" "### Epidemiology / Causes\n" "### Symptoms & Signs\n" "### Diagnosis / Investigations\n" "### Complications\n" "### Treatment & Management\n" "### Prognosis / Prevention\n" "### Key Notes / References\n\n" "At the end, include a **🩺 Quick Summary** with 3–5 key takeaways written in plain English " "that a non-medical reader could understand.\n\n" f"User query: {rephrased_query}\n\nAnswer:" ) inputs = tokenizer(prompt, return_tensors="pt").to(model.device) output_ids = model.generate( **inputs, max_new_tokens=max_tokens, do_sample=True, temperature=0.8, top_p=0.9, repetition_penalty=1.2, pad_token_id=tokenizer.eos_token_id, ) output = tokenizer.decode(output_ids[0], skip_special_tokens=True) answer = output.split("Answer:")[-1].strip() # Clean up potential triple breaks while "\n\n\n" in answer: answer = answer.replace("\n\n\n", "\n\n") history = history + [(query, answer)] return history, history # ===================================================== # Gradio interface # ===================================================== with gr.Blocks(title="Qwen3-Harrison-RAG Chatbot") as demo: gr.Markdown(""" # 🧠 Qwen3-Harrison-RAG Medical Chatbot This model provides **guideline-style medical answers** with structured sections and a **Quick Summary**. *For educational and informational purposes only — not a substitute for professional medical advice.* """) chatbot = gr.Chatbot(height=480, show_label=False) with gr.Row(): msg = gr.Textbox(placeholder="Ask a detailed medical question...", scale=4) clear = gr.Button("Clear", scale=1) msg.submit(generate_answer, [msg, chatbot], [chatbot, chatbot]) clear.click(lambda: None, None, chatbot, queue=False) # ===================================================== # Launch # ===================================================== if __name__ == "__main__": demo.launch(server_name="0.0.0.0", server_port=int(os.environ.get("PORT", 7860)), debug=True)