Spaces:
Running
Running
| import requests | |
| import gradio as gr | |
| import os | |
| import torch | |
| import json | |
| import time | |
| from transformers import AutoTokenizer, AutoModelForCausalLM | |
# Use the GPU when one is available, otherwise fall back to the CPU.
# NOTE(review): `device` is not referenced by any code shown here; all model
# inference below goes through the hosted Hugging Face Inference API.
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')

# Hosted Inference API endpoints: one used for audio classification, one for
# lyric generation.
AUDIO_API_URL = "https://api-inference.huggingface.co/models/MIT/ast-finetuned-audioset-10-10-0.4593"
LYRICS_API_URL = "https://api-inference.huggingface.co/models/EleutherAI/gpt-j-6B"

# Bearer-token auth header shared by both endpoints. It is built once at import
# time, so if HF_TOKEN is unset at that moment the value is "Bearer None".
headers = {"Authorization": f"Bearer {os.environ.get('HF_TOKEN')}"}
def format_error(message):
    """Wrap *message* in a JSON-serialisable error payload.

    Args:
        message: Human-readable description of the error.

    Returns:
        A dict of the form ``{"error": message}``.

    NOTE(review): not called by any code shown here.
    """
    return {"error": message}
def _score_to_fraction(score):
    """Convert a classifier score to a fraction (0.0-1.0).

    Accepts a percentage string such as ``"87.50%"`` (the format built by
    ``classify_and_generate``) or a plain number, which is taken to already be
    a fraction.
    """
    if isinstance(score, str):
        return float(score.strip().rstrip("%")) / 100
    return float(score)


def create_lyrics_prompt(classification_results):
    """Build a lyric-generation prompt from audio classification results.

    Args:
        classification_results: Non-empty sequence of dicts with ``'label'``
            and ``'score'`` keys, best match first. ``'score'`` may be a
            percentage string (``"87.50%"``) or a number in [0, 1]. The first
            entry is the primary style; the next (up to) two are secondary
            elements.

    Returns:
        The prompt text, ending with ``"[Verse 1]\\n"`` so the model continues
        straight into the first verse.

    Raises:
        ValueError: If *classification_results* is empty.
    """
    if not classification_results:
        raise ValueError("classification_results must contain at least one entry")

    top_result = classification_results[0]
    genre = top_result["label"]
    confidence = _score_to_fraction(top_result["score"])
    additional_elements = [r["label"] for r in classification_results[1:3]]

    requirements = [f"Create lyrics that strongly reflect the {genre} style"]
    # Only ask for secondary elements when there are any; otherwise the
    # requirement would read "Incorporate elements of " with nothing after it.
    if additional_elements:
        requirements.append(f"Incorporate elements of {' and '.join(additional_elements)}")
    requirements += [
        "Include both verses and a chorus",
        "Match the mood and atmosphere typical of this genre",
        "Use appropriate musical terminology and style",
    ]
    secondary = ", ".join(additional_elements) if additional_elements else "None"

    lines = [
        "Write creative and original song lyrics that capture the following musical elements:",
        f"Primary Style: {genre} ({confidence * 100:.1f}% confidence)",
        f"Secondary Elements: {secondary}",
        "",
        "Requirements:",
        *(f"{number}. {text}" for number, text in enumerate(requirements, start=1)),
        "",
        "Lyrics:",
        "[Verse 1]",
        "",  # trailing newline after the verse marker
    ]
    return "\n".join(lines)
def _clean_generated_text(text):
    """Tidy raw model output: strip lines, drop markdown residue, keep stanza breaks.

    Lines starting with ``###`` or a code fence are removed, runs of blank
    lines collapse to a single blank line, and leading/trailing blank lines
    are dropped.
    """
    cleaned = []
    for raw_line in text.split("\n"):
        line = raw_line.strip()
        if line.startswith(("###", "```")):
            continue
        # Keep non-empty lines; keep a blank line only right after content so
        # verses stay visually separated without piling up empty lines.
        if line or (cleaned and cleaned[-1]):
            cleaned.append(line)
    while cleaned and not cleaned[-1]:
        cleaned.pop()
    return "\n".join(cleaned)


def generate_lyrics_with_retry(prompt, max_retries=5, initial_wait=2, timeout=60):
    """Generate lyrics for *prompt* via the hosted GPT-J-6B endpoint, with retries.

    HTTP 503 responses and request/JSON-decoding errors are retried. The delay
    starts at *initial_wait* seconds and grows by 1.5x after each attempt. No
    delay follows the final attempt. Any other non-200 status is reported
    immediately.

    Args:
        prompt: Text sent as the model input.
        max_retries: Maximum number of requests to make.
        initial_wait: Seconds to wait before the second attempt.
        timeout: Per-request timeout in seconds for ``requests.post`` so a
            stalled connection cannot hang the caller indefinitely.

    Returns:
        The cleaned generated lyrics, or a human-readable error string. Request
        and decoding errors are reported as strings rather than raised.
    """
    payload = {
        "inputs": prompt,
        "parameters": {
            "max_new_tokens": 250,
            "temperature": 0.8,
            "top_p": 0.92,
            "do_sample": True,
            "return_full_text": False,
            "stop": ["[End]", "\n\n\n"],
        },
    }
    wait_time = initial_wait
    for attempt in range(max_retries):
        is_last_attempt = attempt == max_retries - 1
        try:
            response = requests.post(
                LYRICS_API_URL, headers=headers, json=payload, timeout=timeout
            )
            print(f"Response status: {response.status_code}")
            # Truncated: full bodies can be large and would flood the logs.
            print(f"Response content: {response.content[:500].decode('utf-8', errors='ignore')}")

            if response.status_code == 200:
                result = response.json()
                if isinstance(result, list) and result and isinstance(result[0], dict):
                    return _clean_generated_text(result[0].get("generated_text") or "")
                return "Error: No text generated"
            if response.status_code != 503:
                return f"Error generating lyrics: {response.text}"
            # 503 is treated as "model still loading" and retried.
            print(f"Model loading, attempt {attempt + 1}/{max_retries}.")
        except (requests.RequestException, ValueError) as e:
            # ValueError covers a malformed JSON body from response.json().
            if is_last_attempt:
                return f"Error after {max_retries} attempts: {e}"

        if not is_last_attempt:
            print(f"Waiting {wait_time} seconds before retrying...")
            time.sleep(wait_time)
            wait_time *= 1.5
    return "Failed to generate lyrics after multiple attempts. Please try again."
def format_results(classification_results, lyrics, prompt):
    """Render classification results and generated lyrics as display text.

    Args:
        classification_results: Sequence of dicts with ``'label'`` and
            ``'score'`` keys; both values are printed as-is.
        lyrics: Generated lyrics (or an error message) shown after the results.
        prompt: The prompt used for generation. Not used in the output; kept
            so the call signature stays stable for existing callers.

    Returns:
        One string: a numbered classification list, a separator line, then
        the lyrics, with no stray leading or trailing whitespace.
    """
    lines = ["Classification Results:"]
    lines.extend(
        f"{rank}. {result['label']}: {result['score']}"
        for rank, result in enumerate(classification_results, start=1)
    )
    lines += ["", "---Generated Lyrics---", "", str(lyrics)]
    return "\n".join(lines)
def classify_and_generate(audio_file):
    """Classify an uploaded audio clip and generate lyrics that match it.

    Args:
        audio_file: Path to the uploaded file (the Gradio input uses
            ``type="filepath"``), or ``None`` when nothing was uploaded.

    Returns:
        A display string. On success it holds the formatted classification
        results and the generated lyrics; otherwise it holds an error message.
        Unexpected exceptions are logged with their traceback on the server;
        the user sees only a short message, so internals are not exposed.
    """
    if audio_file is None:
        return "Please upload an audio file."

    token = os.environ.get('HF_TOKEN')
    if not token:
        return "Error: HF_TOKEN environment variable is not set. Please set your Hugging Face API token."

    try:
        with open(audio_file, "rb") as f:
            data = f.read()

        print("Sending request to Audio Classification API...")
        # Timeout so a stalled connection cannot hang the UI request forever.
        response = requests.post(AUDIO_API_URL, headers=headers, data=data, timeout=60)

        if response.status_code == 401:
            return "Error: Invalid or missing API token. Please check your Hugging Face API token."
        if response.status_code == 503:
            return "Error: Model is loading. Please try again in a few seconds."
        if response.status_code != 200:
            return f"Error: API returned status code {response.status_code}\nResponse: {response.text}"

        classification_results = response.json()
        # The prompt builder needs at least one {label, score} entry; reject
        # anything else here with a clear message instead of crashing later.
        if not isinstance(classification_results, list) or not classification_results:
            return f"Error: Unexpected classification response: {classification_results}"

        formatted_results = [
            {'label': result['label'], 'score': f"{result['score'] * 100:.2f}%"}
            for result in classification_results
        ]

        print("Generating lyrics based on classification...")
        prompt = create_lyrics_prompt(formatted_results)
        lyrics = generate_lyrics_with_retry(prompt)
        return format_results(formatted_results, lyrics, prompt)
    except Exception as e:
        # Top-level UI boundary: log the full traceback server-side only.
        import traceback
        traceback.print_exc()
        return f"Error processing request: {e}"
# Gradio UI: one audio upload (passed to the handler as a file path) in,
# one multi-line results text box out.
iface = gr.Interface(
    fn=classify_and_generate,
    inputs=gr.Audio(type="filepath", label="Upload Audio File"),
    outputs=gr.Textbox(
        label="Results",
        lines=15,
        placeholder="Upload an audio file to see classification results and generated lyrics...",
    ),
    title="Music Genre Classifier + Lyric Generator",
    description="Upload an audio file to classify its genre and generate matching lyrics using AI.",
    examples=[],
)

# Launch only when run as a script. server_name="0.0.0.0" listens on all
# network interfaces rather than just localhost.
if __name__ == "__main__":
    iface.launch(server_name="0.0.0.0", server_port=7860)