Spaces:
Runtime error
Runtime error
| import streamlit as st | |
| from transformers import ( | |
| AutoTokenizer, | |
| AutoModelForCausalLM, | |
| TrainingArguments, | |
| Trainer, | |
| DataCollatorForLanguageModeling, | |
| pipeline, | |
| ) | |
| #from transformers import AutoModelForCausalLM, AutoTokenizer, TrainingArguments, Trainer, pipeline | |
| #from llama_cpp import Llama | |
| from datasets import load_dataset | |
| import os | |
| import requests | |
# Direct image URL used as the app background (must point at the image file
# itself, not at an HTML page that embeds it).
flower_image_url = "https://i.postimg.cc/hG2FG85D/2.png"


def _inject_blurred_background(image_url):
    """Render a fixed, blurred, full-window background image behind the app.

    Emits two raw-HTML markdown elements: a <style> block that defines the
    ``.blurred-background`` class (and lifts the app container above it),
    followed by an empty <div> carrying that class.
    """
    background_css = f"""
<style>
/* Container for background */
html, body {{
margin: 0;
padding: 0;
overflow: hidden;
}}
[data-testid="stAppViewContainer"] {{
position: relative;
z-index: 1; /* Ensure UI elements are above the background */
}}
/* Blurred background image */
.blurred-background {{
position: fixed;
top: 0;
left: 0;
width: 100%;
height: 100%;
z-index: -1; /* Send background image behind all UI elements */
background-image: url("{image_url}");
background-size: cover;
background-position: center;
filter: blur(10px); /* Adjust blur ratio here */
opacity: 0.8; /* Optional: Add slight transparency for a subtle effect */
}}
</style>
"""
    st.markdown(background_css, unsafe_allow_html=True)
    # The div that actually displays the blurred image styled above.
    st.markdown('<div class="blurred-background"></div>', unsafe_allow_html=True)


_inject_blurred_background(flower_image_url)
| #""""""""""""""""""""""""" Application Code Starts here """"""""""""""""""""""""""""""""""""""""""""" | |
# Cached across Streamlit reruns: the whole script re-executes on every widget
# interaction, so an uncached loader would reload the dataset each time.
@st.cache_resource
def load_counseling_dataset(max_rows=None):
    """Load the training split of the counseling-conversations dataset.

    Args:
        max_rows: optional cap on the number of rows kept (the first
            ``max_rows`` rows of the split). ``None`` (the default) keeps the
            whole split, matching the previous behaviour.

    Returns:
        The ``train`` split as returned by ``datasets.load_dataset``,
        optionally truncated to ``max_rows`` rows.
    """
    dataset = load_dataset("Amod/mental_health_counseling_conversations", split="train")
    if max_rows is not None:
        # Clamp so asking for more rows than exist does not raise.
        dataset = dataset.select(range(min(max_rows, len(dataset))))
    return dataset
def process_dataset_in_batches(dataset, batch_size=500, seed=None):
    """Yield up to ``batch_size`` randomly chosen examples from ``dataset``.

    Despite its name, this produces a single shuffled sample, not successive
    batches covering the whole dataset.

    Args:
        dataset: object exposing ``len()``, ``shuffle(seed=...)`` and
            ``select(indices)`` (e.g. a Hugging Face ``datasets.Dataset``).
        batch_size: maximum number of examples to yield. It is clamped to the
            dataset size, so a dataset smaller than ``batch_size`` no longer
            raises an index error. Values <= 0 yield nothing.
        seed: optional shuffle seed for reproducible sampling; ``None`` keeps
            the original non-deterministic shuffle.

    Yields:
        Individual examples from the shuffled dataset.
    """
    sample_size = min(batch_size, len(dataset))
    shuffled = dataset.shuffle(seed=seed)
    yield from shuffled.select(range(sample_size))
def fine_tune_model(output_dir="./fine_tuned_model"):
    """Fine-tune the base causal LM on the counseling dataset and save it.

    Args:
        output_dir: directory that receives checkpoints, the final model and
            the tokenizer. Defaults to the path previously hard-coded here.

    Returns:
        ``output_dir``, so it can be passed straight to
        ``pipeline(..., model=...)``.
    """
    import inspect

    import torch

    model_name = "prabureddy/Mental-Health-FineTuned-Mistral-7B-Instruct-v0.2"
    tokenizer = AutoTokenizer.from_pretrained(model_name)
    # DataCollatorForLanguageModeling pads every batch. Mistral-style
    # tokenizers ship without a pad token, so reuse EOS instead of failing
    # at collation time.
    if tokenizer.pad_token is None:
        tokenizer.pad_token = tokenizer.eos_token

    # NOTE(review): full-parameter fine-tuning of a 7B model loaded in full
    # precision needs far more memory than typical Space hardware offers;
    # consider a PEFT/LoRA or quantized setup.
    model = AutoModelForCausalLM.from_pretrained(model_name)
    # Trade compute for memory during backprop.
    model.gradient_checkpointing_enable()

    dataset = load_counseling_dataset()
    # Resolve the text columns case-insensitively.
    # NOTE(review): the Hub copy of this dataset may capitalise these column
    # names (e.g. "Context"/"Response") - confirm.
    columns = {name.lower(): name for name in dataset.column_names}
    context_col = columns.get("context", "context")
    response_col = columns.get("response", "response")

    def preprocess_function(examples):
        # With batched=True each column arrives as a list of strings, so the
        # context/response pairs must be joined row by row.
        texts = [
            context + "\n" + response
            for context, response in zip(examples[context_col], examples[response_col])
        ]
        # Bounded max_length keeps per-example memory predictable; a bare
        # truncation=True falls back to the tokenizer's (possibly huge) limit.
        return tokenizer(texts, truncation=True, max_length=512)

    # Drop the raw text columns so only token fields reach the collator.
    tokenized = dataset.map(
        preprocess_function, batched=True, remove_columns=dataset.column_names
    )
    # load_dataset(split="train") returns a single Dataset, not a DatasetDict
    # with "train"/"validation" keys, so carve out an eval split explicitly.
    splits = tokenized.train_test_split(test_size=0.1, seed=42)

    data_collator = DataCollatorForLanguageModeling(tokenizer=tokenizer, mlm=False)

    # Newer transformers releases renamed `evaluation_strategy` to
    # `eval_strategy` (and Trainer's `tokenizer` to `processing_class`);
    # pass whichever keyword the installed version accepts.
    args_params = inspect.signature(TrainingArguments.__init__).parameters
    strategy_key = "eval_strategy" if "eval_strategy" in args_params else "evaluation_strategy"
    trainer_params = inspect.signature(Trainer.__init__).parameters
    tokenizer_key = "processing_class" if "processing_class" in trainer_params else "tokenizer"

    training_args = TrainingArguments(
        output_dir=output_dir,
        learning_rate=2e-5,
        per_device_train_batch_size=5,
        per_device_eval_batch_size=5,
        num_train_epochs=3,
        weight_decay=0.01,
        # fp16 mixed precision needs a CUDA device; requesting it on a
        # CPU-only machine makes TrainingArguments raise.
        fp16=torch.cuda.is_available(),
        save_total_limit=2,
        save_steps=250,
        logging_steps=50,
        **{strategy_key: "steps"},
    )

    trainer = Trainer(
        model=model,
        args=training_args,
        train_dataset=splits["train"],
        eval_dataset=splits["test"],
        data_collator=data_collator,
        **{tokenizer_key: tokenizer},
    )
    trainer.train()

    # Persist the final weights and tokenizer side by side for pipeline().
    trainer.save_model(output_dir)
    tokenizer.save_pretrained(output_dir)
    return output_dir
# Load or fine-tune the model. Streamlit re-executes this whole script on
# every widget interaction, so fine-tuning unconditionally here would retrain
# the model on each rerun. Reuse a completed save instead: config.json is
# written by save_model() only after training finishes.
model_dir = "./fine_tuned_model"
if not os.path.isfile(os.path.join(model_dir, "config.json")):
    model_dir = fine_tune_model()


@st.cache_resource
def load_pipeline(model_dir):
    """Build a text-generation pipeline for ``model_dir``.

    Decorated with ``st.cache_resource`` so the model is loaded once per
    ``model_dir`` and reused across reruns, rather than reloaded each time.
    """
    return pipeline("text-generation", model=model_dir)


pipe = load_pipeline(model_dir)
# Streamlit App
st.title("Mental Health Support Assistant")
st.markdown("""
Welcome to the **Mental Health Support Assistant**.
This tool helps detect potential mental health concerns based on user input and provides **uplifting and positive suggestions** to boost morale.
""")

# User input for mental health concerns
user_input = st.text_area("Please share your concern:", placeholder="Type your question or concern here...")

if st.button("Get Supportive Response"):
    if user_input.strip():
        with st.spinner("Analyzing your input and generating a response..."):
            try:
                # max_new_tokens bounds only the generated continuation; the
                # old max_length also counted prompt tokens, so long inputs
                # left little or no room for a reply. return_full_text=False
                # stops the user's own text being echoed back in the answer.
                response = pipe(
                    user_input,
                    max_new_tokens=150,
                    num_return_sequences=1,
                    return_full_text=False,
                )[0]["generated_text"]
            except Exception as e:
                # Top-level UI boundary: surface any generation failure to the user.
                st.error(f"An error occurred while generating the response: {e}")
            else:
                st.subheader("Supportive Suggestion:")
                # Rendered as-is: wrapping multi-paragraph model output in
                # **...** breaks the bold markup and shows stray asterisks.
                st.markdown(response)
    else:
        st.error("Please enter a concern to receive suggestions.")
# Sidebar for additional resources.
# The US crisis line was renamed the 988 Suicide & Crisis Lifeline in 2022;
# the Mental Health Foundation's site lives under the .org.uk domain.
st.sidebar.header("Additional Resources")
st.sidebar.markdown("""
- [Mental Health Foundation](https://www.mentalhealth.org.uk)
- [Mind](https://www.mind.org.uk)
- [988 Suicide & Crisis Lifeline (US)](https://988lifeline.org)
""")
st.sidebar.info("This application is not a replacement for professional counseling. If you're in crisis, seek professional help immediately.")