Update app.py
app.py CHANGED

@@ -20,7 +20,7 @@ if not HF_TOKEN:
     print(" export HF_TOKEN='your_hf_token_here'")

 # =====================================================
-# Load private
+# Load private model
 # =====================================================
 def load_private_model(model_name, token):
     dtype_value = torch.float16 if torch.cuda.is_available() else torch.float32

@@ -44,13 +44,13 @@ tokenizer, model = load_private_model(GEN_MODEL, token=HF_TOKEN)
 # =====================================================
 # Dynamic token allocation
 # =====================================================
-def calculate_max_tokens(query, min_tokens=
+def calculate_max_tokens(query, min_tokens=1000, max_tokens=8192, factor=8):
     query_tokens = len(tokenizer(query)["input_ids"])
     dynamic_tokens = query_tokens * factor
     return min(max(dynamic_tokens, min_tokens), max_tokens)

 # =====================================================
-# Generate structured
+# Generate long, complete, structured answers
 # =====================================================
 def generate_answer(query, history):
     if not query.strip():

@@ -59,22 +59,35 @@ def generate_answer(query, history):
     # Correct common typos
     corrected_query = query.replace("COPP", "COPD")

-    # Step 1: Rephrase
+    # Step 1: Rephrase for precise retrieval
     rephrase_prompt = (
         "You are a medical assistant. Rephrase this query for precise retrieval:\n\n"
         f"Query: {corrected_query}\n\nRephrased query:"
     )
     inputs = tokenizer(rephrase_prompt, return_tensors="pt").to(model.device)
-    rephrased_ids = model.generate(**inputs, max_new_tokens=
-    rephrased_query = tokenizer.decode(
+    rephrased_ids = model.generate(**inputs, max_new_tokens=128, do_sample=False)
+    rephrased_query = tokenizer.decode(
+        rephrased_ids[0], skip_special_tokens=True
+    ).split("Rephrased query:")[-1].strip()

     # Step 2: Generate detailed structured answer
     max_tokens = calculate_max_tokens(rephrased_query)
     prompt = (
-        "You are a retrieval-augmented medical assistant. Provide a detailed, structured answer
-        "
-        "
-
+        "You are a retrieval-augmented medical assistant. Provide a **long, detailed, structured** medical answer "
+        "as if writing a concise clinical guideline. Use markdown headings and bullet points. "
+        "Each section should include multiple complete sentences and clear explanations.\n\n"
+        "Follow this structure:\n"
+        "### Definition / Description\n"
+        "### Epidemiology / Causes\n"
+        "### Symptoms & Signs\n"
+        "### Diagnosis / Investigations\n"
+        "### Complications\n"
+        "### Treatment & Management\n"
+        "### Prognosis / Prevention\n"
+        "### Key Notes / References\n\n"
+        "At the end, include a **🩺 Quick Summary** with 3–5 key takeaways written in plain English "
+        "that a non-medical reader could understand.\n\n"
+        f"User query: {rephrased_query}\n\nAnswer:"
     )

     inputs = tokenizer(prompt, return_tensors="pt").to(model.device)

@@ -82,35 +95,34 @@ def generate_answer(query, history):
         **inputs,
         max_new_tokens=max_tokens,
         do_sample=True,
-        temperature=0.
-
+        temperature=0.8,
+        top_p=0.9,
+        repetition_penalty=1.2,
         pad_token_id=tokenizer.eos_token_id,
     )

     output = tokenizer.decode(output_ids[0], skip_special_tokens=True)
     answer = output.split("Answer:")[-1].strip()

-    #
-
-
-    if token_str in answer:
-        answer = answer.split(token_str)[0] + token_str
-        break
+    # Clean up potential triple breaks
+    while "\n\n\n" in answer:
+        answer = answer.replace("\n\n\n", "\n\n")

     history = history + [(query, answer)]
     return history, history

 # =====================================================
-# Gradio
+# Gradio interface
 # =====================================================
 with gr.Blocks(title="Qwen3-Harrison-RAG Chatbot") as demo:
     gr.Markdown("""
-    #
-
+    # 🧠 Qwen3-Harrison-RAG Medical Chatbot
+    This model provides **guideline-style medical answers** with structured sections and a **Quick Summary**.
+    *For educational and informational purposes only — not a substitute for professional medical advice.*
     """)
-    chatbot = gr.Chatbot(height=
+    chatbot = gr.Chatbot(height=480, show_label=False)
     with gr.Row():
-        msg = gr.Textbox(placeholder="
+        msg = gr.Textbox(placeholder="Ask a detailed medical question...", scale=4)
         clear = gr.Button("Clear", scale=1)
     msg.submit(generate_answer, [msg, chatbot], [chatbot, chatbot])
     clear.click(lambda: None, None, chatbot, queue=False)

@@ -119,4 +131,4 @@ with gr.Blocks(title="Qwen3-Harrison-RAG Chatbot") as demo:
 # Launch
 # =====================================================
 if __name__ == "__main__":
-    demo.launch(server_name="0.0.0.0", server_port=int(os.environ.get("PORT", 7860)))
+    demo.launch(server_name="0.0.0.0", server_port=int(os.environ.get("PORT", 7860)), debug=True)
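For reference, the token-budget rule this commit finalizes can be checked in isolation. The sketch below reimplements the clamping math from calculate_max_tokens with the token count passed in directly instead of being computed by the app's tokenizer; that shortcut is an assumption made only for illustration.

def calculate_max_tokens(query_tokens, min_tokens=1000, max_tokens=8192, factor=8):
    # Scale the output budget with the query length, then clamp it to
    # the [min_tokens, max_tokens] range now set in app.py.
    dynamic_tokens = query_tokens * factor
    return min(max(dynamic_tokens, min_tokens), max_tokens)

print(calculate_max_tokens(12))    # 12 * 8 = 96, raised to the 1000-token floor
print(calculate_max_tokens(1500))  # 1500 * 8 = 12000, capped at 8192

In the committed version the count comes from len(tokenizer(query)["input_ids"]) on the rephrased query, so even a short question receives at least a 1000-token budget for the structured answer.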
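The answer post-processing added here (keeping only the text after the final "Answer:" marker, then collapsing runs of blank lines) can likewise be exercised on a standalone string; the raw output below is invented purely to demonstrate the behavior:

raw_output = (
    "You are a retrieval-augmented medical assistant. ...\n\n"
    "Answer: ### Definition / Description\n"
    "COPD is a chronic, progressive airflow limitation.\n\n\n\n"
    "### Key Notes / References\nSee the source text.\n"
)
answer = raw_output.split("Answer:")[-1].strip()
while "\n\n\n" in answer:
    answer = answer.replace("\n\n\n", "\n\n")
print(answer)  # headings preserved, triple or longer line breaks reduced to double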