Spaces:
Runtime error
Runtime error
| import os | |
| from fastapi import FastAPI, Request, HTTPException | |
| from fastapi.responses import JSONResponse, HTMLResponse | |
| from fastapi.staticfiles import StaticFiles | |
| from fastapi.templating import Jinja2Templates | |
| from huggingface_hub import InferenceClient | |
| from datasets import load_dataset | |
| import markdown2 | |
| # Set up Hugging Face cache | |
| os.environ["HF_HOME"] = "/app/.cache" | |
| # Initialize FastAPI application | |
| app = FastAPI() | |
| # Set up templates and static file serving | |
| app.mount("/static", StaticFiles(directory="static"), name="static") | |
| templates = Jinja2Templates(directory="templates") | |
| # Hugging Face API token | |
| hf_token = os.getenv("HF_TOKEN") | |
| # Load datasets | |
| chat_doctor_dataset = load_dataset("avaliev/chat_doctor") | |
| mental_health_dataset = load_dataset("Amod/mental_health_counseling_conversations") | |
| # Set up Hugging Face Inference Client | |
| client = InferenceClient( | |
| "meta-llama/Meta-Llama-3-8B-Instruct", | |
| token=hf_token, | |
| ) | |
| def select_relevant_context(user_input: str) -> str: | |
| """Select relevant context from the datasets based on user input keywords.""" | |
| mental_health_keywords = [ | |
| "anxious", "depressed", "stress", "mental health", "counseling", | |
| "therapy", "feelings", "worthless", "suicidal", "panic", "anxiety" | |
| ] | |
| medical_keywords = [ | |
| "symptoms", "diagnosis", "treatment", "doctor", "prescription", "medication", | |
| "pain", "illness", "disease", "infection", "surgery" | |
| ] | |
| if any(keyword in user_input.lower() for keyword in mental_health_keywords): | |
| example = mental_health_dataset['train'][0] | |
| context = f"Counselor: {example['Response']}\nUser: {example['Context']}" | |
| elif any(keyword in user_input.lower() for keyword in medical_keywords): | |
| example = chat_doctor_dataset['train'][0] | |
| context = f"Doctor: {example['input']}\nPatient: {example['output']}" | |
| else: | |
| context = "You are a general assistant. Respond to the user's query in a helpful manner." | |
| return context | |
| def create_prompt(context: str, user_input: str) -> str: | |
| """Create the final prompt based on the context and user input.""" | |
| return f"{context}\n\nUser: {user_input}\nAssistant:" | |
| def render_markdown(text: str) -> str: | |
| """Render Markdown into HTML.""" | |
| return markdown2.markdown(text) | |
| async def index(request: Request): | |
| """Render the homepage.""" | |
| return templates.TemplateResponse("index.html", {"request": request}) | |
| async def chat(request: Request): | |
| """Handle the chat route and process user input.""" | |
| try: | |
| data = await request.json() | |
| user_input = data["message"] | |
| context = select_relevant_context(user_input) | |
| prompt = create_prompt(context, user_input) | |
| response = "" | |
| for message in client.chat_completion( | |
| messages=[{"role": "user", "content": prompt}], | |
| max_tokens=500, | |
| stream=True, | |
| ): | |
| response += message.choices[0].delta.content | |
| formatted_response = render_markdown(response) | |
| return JSONResponse({"response": formatted_response}) | |
| except Exception as e: | |
| raise HTTPException(status_code=500, detail=f"Error processing chat: {str(e)}") | |