# Spaces: Sleeping
# (page-status residue from the web capture; not part of the program)
import asyncio
import contextlib
import json
import mimetypes
import os
import time
import uuid
from typing import List, Dict

import aiofiles
import edge_tts
import gradio as gr
import torch
from pydub import AudioSegment
from transformers import AutoTokenizer, AutoModelForCausalLM
# Upload size limit, expressed in megabytes and in bytes.
MAX_FILE_SIZE_MB = 20
MAX_FILE_SIZE_BYTES = MAX_FILE_SIZE_MB * 1024 ** 2

# Checkpoint used for script generation.
# NOTE(review): the "-pt" suffix suggests a pretrained (not instruction-tuned)
# checkpoint -- confirm it follows the JSON instructions well enough.
MODEL_ID = "unsloth/gemma-3-1b-pt"

# Tokenizer and model are loaded once, at import time, and shared by all requests.
tokenizer = AutoTokenizer.from_pretrained(MODEL_ID)
model = AutoModelForCausalLM.from_pretrained(
    MODEL_ID,
    # Half precision when CUDA is available, full precision otherwise.
    torch_dtype=torch.float16 if torch.cuda.is_available() else torch.float32,
    device_map="auto",
)
# eval() returns the module itself; calling it separately leaves `model`
# bound to the same object as the chained form.
model.eval()
class PodcastGenerator:
    """Turn a text prompt into a two-speaker podcast audio file.

    Pipeline: build an LLM prompt, generate a JSON script with the
    module-level ``model``/``tokenizer``, synthesize each script line with
    edge-tts, then concatenate the clips into one WAV file with pydub.
    """

    # Example script embedded in the system prompt to show the expected JSON shape.
    _SCRIPT_EXAMPLE = """
{
    "topic": "AGI",
    "podcast": [
        {
            "speaker": 2,
            "line": "So, AGI, huh? Seems like everyone's talking about it these days."
        },
        {
            "speaker": 1,
            "line": "Yeah, it's definitely having a moment, isn't it?"
        },
        {
            "speaker": 2,
            "line": "It is and for good reason, right? I mean, you've been digging into this stuff, listening to the podcasts and everything. What really stood out to you? What got you hooked?"
        },
        {
            "speaker": 1,
            "line": "I like that. It really is."
        },
        {
            "speaker": 2,
            "line": "And honestly, that's a responsibility that extends beyond just the researchers and the policymakers."
        },
        {
            "speaker": 1,
            "line": "100%"
        },
        {
            "speaker": 2,
            "line": "So to everyone listening out there I'll leave you with this. As AGI continues to develop, what role do you want to play in shaping its future?"
        },
        {
            "speaker": 1,
            "line": "That's a question worth pondering."
        },
        {
            "speaker": 2,
            "line": "It certainly is and on that note, we'll wrap up this deep dive. Thanks for listening, everyone."
        },
        {
            "speaker": 1,
            "line": "Peace."
        }
    ]
}
"""

    def __init__(self):
        pass

    @classmethod
    def _build_prompt(cls, prompt: str, language: str, file_obj=None) -> str:
        """Assemble the full text prompt (instructions + user request + format reminder)."""
        if language == "Auto Detect":
            language_instruction = "- The podcast MUST be in the same language as the user input."
        else:
            language_instruction = f"- The podcast MUST be in {language} language"

        system_prompt = "\n".join([
            "",
            "You are a professional podcast generator. Your task is to generate a professional podcast script based on the user input.",
            language_instruction,
            "- The podcast should have 2 speakers.",
            "- The podcast should be long.",
            "- Do not use names for the speakers.",
            "- The podcast should be interesting, lively, and engaging, and hook the listener from the start.",
            "- The input text might be disorganized or unformatted, originating from sources like PDFs or text files. Ignore any formatting inconsistencies or irrelevant details; your task is to distill the essential points, identify key definitions, and highlight intriguing facts that would be suitable for discussion in a podcast.",
            "- The script must be in JSON format.",
            "Follow this example structure:",
            cls._SCRIPT_EXAMPLE,
            "",
        ])

        if prompt and file_obj:
            user_prompt = f"Please generate a podcast script based on the uploaded file following user input:\n{prompt}"
        elif prompt:
            user_prompt = f"Please generate a podcast script based on the following user input:\n{prompt}"
        else:
            user_prompt = "Please generate a podcast script based on the uploaded file."

        # Escape the topic so quotes/newlines in the user text cannot break the
        # JSON template, and avoid the literal "None" when no text was given.
        topic_json = json.dumps(prompt, ensure_ascii=False) if prompt else '"..."'

        return "\n".join([
            system_prompt,
            user_prompt,
            "Return the result strictly as a JSON object in the format:",
            "{",
            f'"topic": {topic_json},',
            '"podcast": [',
            '{ "speaker": 1, "line": "..." },',
            '{ "speaker": 2, "line": "..." }',
            "]",
            "}",
            "",
        ])

    @staticmethod
    def _extract_json(text: str) -> Dict:
        """Return the first JSON object in *text* that has a list under "podcast".

        The model may surround the script with extra prose, so every ``{`` is
        tried as a possible start of the object.

        Raises:
            Exception: if no such object can be decoded.
        """
        decoder = json.JSONDecoder()
        start = text.find("{")
        while start != -1:
            try:
                candidate, _ = decoder.raw_decode(text, start)
            except json.JSONDecodeError:
                candidate = None
            if isinstance(candidate, dict) and isinstance(candidate.get("podcast"), list):
                return candidate
            start = text.find("{", start + 1)
        raise Exception("The model did not return valid JSON. Please refine the prompt.")

    @staticmethod
    def _remove_files(paths) -> None:
        """Best-effort deletion of temporary files; already-missing files are ignored."""
        for path in paths:
            with contextlib.suppress(FileNotFoundError):
                os.remove(path)

    async def generate_script(self, prompt: str, language: str, api_key: str, file_obj=None, progress=None) -> Dict:
        """Generate the podcast script with the local language model.

        Args:
            prompt: User text the podcast should be based on.
            language: Target language name, or "Auto Detect".
            api_key: Unused; kept for interface compatibility.
            file_obj: Uploaded file, if any. It is NOT sent to the model; only
                a warning is printed.
            progress: Optional callable ``progress(fraction, message)``.

        Returns:
            The parsed script: a dict with a "podcast" list of
            ``{"speaker": ..., "line": ...}`` entries.

        Raises:
            Exception: if generation fails or no valid JSON script is produced.
        """
        # NOTE: file_obj cannot be passed to a text-only LLM
        if file_obj:
            print("Warning: Uploaded file is ignored in this version because external LLM does not support file input.")

        full_prompt = self._build_prompt(prompt, language, file_obj)

        def _run_model() -> str:
            inputs = tokenizer(full_prompt, return_tensors="pt").to(model.device)
            output = model.generate(**inputs, max_new_tokens=1024)
            # A causal LM returns prompt + continuation; decode only the newly
            # generated tokens, otherwise the echoed prompt precedes the JSON.
            prompt_len = inputs["input_ids"].shape[-1]
            return tokenizer.decode(output[0][prompt_len:], skip_special_tokens=True)

        if progress:
            progress(0.3, "Generating podcast script...")
        try:
            # Run the blocking generation in a worker thread so the event loop
            # (and the caller's asyncio timeout) stays responsive.
            loop = asyncio.get_running_loop()
            text = await loop.run_in_executor(None, _run_model)
        except Exception as e:
            raise Exception(f"Failed to generate podcast script: {e}") from e

        print(f"Generated podcast script:\n{text}")
        if progress:
            progress(0.4, "Script generated successfully!")
        return self._extract_json(text)

    async def _read_file_bytes(self, file_obj) -> bytes:
        """Read the content of an uploaded file, enforcing the size limit.

        Accepts a path (``str``/``os.PathLike``) or an object exposing
        ``name`` and optionally ``size``/``read``.

        Raises:
            Exception: if the file is larger than ``MAX_FILE_SIZE_MB``.
        """
        if isinstance(file_obj, (str, os.PathLike)):
            path = os.fspath(file_obj)
            file_size = os.path.getsize(path)
            readable = False
        else:
            path = None
            readable = hasattr(file_obj, 'read')
            if hasattr(file_obj, 'size'):
                file_size = file_obj.size
            else:
                file_size = os.path.getsize(file_obj.name)

        if file_size > MAX_FILE_SIZE_BYTES:
            raise Exception(f"File size exceeds the {MAX_FILE_SIZE_MB}MB limit. Please upload a smaller file.")

        if readable:
            return file_obj.read()
        async with aiofiles.open(path if path is not None else file_obj.name, 'rb') as f:
            return await f.read()

    def _get_mime_type(self, filename: str) -> str:
        """Determine the MIME type from the file extension.

        ``.pdf`` and ``.txt`` are mapped explicitly; everything else goes
        through ``mimetypes.guess_type`` with "application/octet-stream" as
        the fallback.
        """
        ext = os.path.splitext(filename)[1].lower()
        if ext == '.pdf':
            return "application/pdf"
        if ext == '.txt':
            return "text/plain"
        mime_type, _ = mimetypes.guess_type(filename)
        return mime_type or "application/octet-stream"

    async def tts_generate(self, text: str, speaker: int, speaker1: str, speaker2: str) -> str:
        """Synthesize one script line and return the path of the temporary clip.

        Args:
            text: The line to speak.
            speaker: Speaker number from the script; 1 (or "1") selects
                ``speaker1``, anything else selects ``speaker2``.
            speaker1: Voice identifier for speaker 1.
            speaker2: Voice identifier for speaker 2.

        Raises:
            Exception: if synthesis takes longer than 30 seconds; other
                errors from edge-tts propagate unchanged.
        """
        voice = speaker1 if speaker in (1, "1") else speaker2
        speech = edge_tts.Communicate(text, voice)
        # NOTE(review): edge-tts is believed to write MP3 data, hence the
        # extension -- confirm against the installed edge-tts version.
        temp_filename = f"temp_{uuid.uuid4()}.mp3"
        try:
            # Add timeout to TTS generation
            await asyncio.wait_for(speech.save(temp_filename), timeout=30)
        except asyncio.TimeoutError:
            self._remove_files([temp_filename])
            raise Exception("Text-to-speech generation timed out. Please try with a shorter text.") from None
        except BaseException:
            # Also covers cancellation, so no partial clip is left behind.
            self._remove_files([temp_filename])
            raise
        return temp_filename

    async def combine_audio_files(self, audio_files: List[str], progress=None) -> str:
        """Concatenate the clips in order into one WAV file and return its path.

        The input clips are deleted afterwards, whether or not combining
        succeeded.
        """
        if progress:
            progress(0.9, "Combining audio files...")
        try:
            combined_audio = AudioSegment.empty()
            for audio_file in audio_files:
                combined_audio += AudioSegment.from_file(audio_file)
            output_filename = f"output_{uuid.uuid4()}.wav"
            combined_audio.export(output_filename, format="wav")
        finally:
            self._remove_files(audio_files)  # Clean up temporary files
        if progress:
            progress(1.0, "Podcast generated successfully!")
        return output_filename

    async def generate_podcast(self, input_text: str, language: str, speaker1: str, speaker2: str, api_key: str, file_obj=None, progress=None) -> str:
        """Run the whole pipeline under a 10-minute overall timeout.

        Returns:
            Path of the generated WAV file.

        Raises:
            Exception: on timeout, or wrapping any error from the pipeline.
        """
        try:
            if progress:
                progress(0.1, "Starting podcast generation...")
            # Set overall timeout for the entire process
            return await asyncio.wait_for(
                self._generate_podcast_internal(input_text, language, speaker1, speaker2, api_key, file_obj, progress),
                timeout=600  # 10 minutes total timeout
            )
        except asyncio.TimeoutError:
            raise Exception("The podcast generation process timed out. Please try with shorter text or try again later.") from None
        except Exception as e:
            raise Exception(f"Error generating podcast: {str(e)}") from e

    async def _generate_podcast_internal(self, input_text: str, language: str, speaker1: str, speaker2: str, api_key: str, file_obj=None, progress=None) -> str:
        """Generate the script, synthesize it in concurrent batches, combine the clips."""
        if progress:
            progress(0.2, "Generating podcast script...")
        podcast_json = await self.generate_script(input_text, language, api_key, file_obj, progress)

        # Keep only entries that actually carry text to speak.
        raw_lines = podcast_json.get('podcast') if isinstance(podcast_json, dict) else None
        lines = [
            item for item in (raw_lines if isinstance(raw_lines, list) else [])
            if isinstance(item, dict) and isinstance(item.get('line'), str) and item['line'].strip()
        ]
        if not lines:
            raise Exception("The generated script contains no podcast lines. Please refine the prompt.")

        if progress:
            progress(0.5, "Converting text to speech...")

        audio_files: List[str] = []
        total_lines = len(lines)
        batch_size = 10  # Number of TTS requests run concurrently
        try:
            for batch_start in range(0, total_lines, batch_size):
                batch = lines[batch_start:batch_start + batch_size]
                batch_end = batch_start + len(batch)
                results = await asyncio.gather(
                    *(self.tts_generate(item['line'], item.get('speaker'), speaker1, speaker2) for item in batch),
                    return_exceptions=True,
                )
                # Record every clip that was produced BEFORE reacting to a
                # failure, so successful clips of a failed batch get deleted too.
                audio_files.extend(r for r in results if not isinstance(r, BaseException))
                failures = [r for r in results if isinstance(r, BaseException)]
                if failures:
                    raise Exception(f"Error generating speech: {str(failures[0])}") from failures[0]
                if progress:
                    current_progress = 0.5 + (0.4 * (batch_end / total_lines))
                    progress(current_progress, f"Processed {batch_end}/{total_lines} speech segments...")
        except BaseException:
            # Any failure, including cancellation by the overall timeout,
            # must not leave temporary clips on disk.
            self._remove_files(audio_files)
            raise

        return await self.combine_audio_files(audio_files, progress)
# Maps the voice labels shown in the UI to the voice identifiers passed to TTS.
_VOICE_NAMES = {
    "Andrew - English (United States)": "en-US-AndrewMultilingualNeural",
    "Ava - English (United States)": "en-US-AvaMultilingualNeural",
    "Brian - English (United States)": "en-US-BrianMultilingualNeural",
    "Emma - English (United States)": "en-US-EmmaMultilingualNeural",
    "Florian - German (Germany)": "de-DE-FlorianMultilingualNeural",
    "Seraphina - German (Germany)": "de-DE-SeraphinaMultilingualNeural",
    "Remy - French (France)": "fr-FR-RemyMultilingualNeural",
    "Vivienne - French (France)": "fr-FR-VivienneMultilingualNeural",
}


async def process_input(input_text: str, input_file, language: str, speaker1: str, speaker2: str, api_key: str = "", progress=None) -> str:
    """Generate a podcast from the UI inputs and return the audio file path.

    Args:
        input_text: Text the podcast should be based on.
        input_file: Uploaded file (or None); forwarded to the generator.
        language: Target language name, or "Auto Detect".
        speaker1: UI label of the voice for speaker 1.
        speaker2: UI label of the voice for speaker 2.
        api_key: Ignored; the local model needs no key.
        progress: Optional callable ``progress(fraction, message)``.

    Returns:
        Path of the generated audio file.

    Raises:
        ValueError: if a voice label is not one of the known voices.
        Exception: with a user-friendly message if generation fails.
    """
    start_time = time.time()

    unknown = [label for label in (speaker1, speaker2) if label not in _VOICE_NAMES]
    if unknown:
        raise ValueError(f"Unknown voice: {unknown[0]!r}")
    speaker1 = _VOICE_NAMES[speaker1]
    speaker2 = _VOICE_NAMES[speaker2]

    try:
        if progress:
            progress(0.05, "Processing input...")
        api_key = ""  # No API key needed for local model
        podcast_generator = PodcastGenerator()
        podcast = await podcast_generator.generate_podcast(input_text, language, speaker1, speaker2, api_key, input_file, progress)
    except Exception as e:
        # Ensure we show a user-friendly error
        error_msg = str(e)
        lowered = error_msg.lower()
        if "rate limit" in lowered:
            raise Exception("Rate limit exceeded. Please try again later or use your own API key.") from e
        # The pipeline's own timeout messages say "timed out", which the
        # substring "timeout" alone never matches.
        if "timeout" in lowered or "timed out" in lowered:
            raise Exception("The request timed out. This could be due to server load or the length of your input. Please try again with shorter text.") from e
        raise Exception(f"Error: {error_msg}") from e

    end_time = time.time()
    print(f"Total podcast generation time: {end_time - start_time:.2f} seconds")
    return podcast
# Gradio UI
def generate_podcast_gradio(input_text, input_file, language, speaker1, speaker2, api_key, progress=gr.Progress()):
    """Synchronous Gradio click handler around the async ``process_input``.

    Args:
        input_text: Content of the text box.
        input_file: Uploaded file, or None when nothing was uploaded.
        language: Selected language option.
        speaker1: Selected voice label for speaker 1.
        speaker2: Selected voice label for speaker 2.
        api_key: Content of the API key box; forwarded unchanged.
        progress: Gradio progress tracker used for progress updates.

    Returns:
        Path of the generated audio file.

    Raises:
        gr.Error: carrying the failure message so that the Gradio UI shows
            it to the user.
    """
    # Adapt the (value, text) callback used by the pipeline to Gradio's tracker.
    def progress_callback(value, text):
        progress(value, text)

    try:
        # Run the async pipeline to completion on a fresh event loop.
        return asyncio.run(process_input(
            input_text,
            input_file,
            language,
            speaker1,
            speaker2,
            api_key,
            progress_callback
        ))
    except Exception as err:
        raise gr.Error(str(err)) from err
def main():
    """Build the PodcastGen Gradio interface and launch it."""
    # Voices offered for each of the two speakers.
    voice_options = [
        "Andrew - English (United States)",
        "Ava - English (United States)",
        "Brian - English (United States)",
        "Emma - English (United States)",
        "Florian - German (Germany)",
        "Seraphina - German (Germany)",
        "Remy - French (France)",
        "Vivienne - French (France)",
    ]

    # Languages offered for the script; the first entry is the default.
    language_options = [
        "Auto Detect",
        "Afrikaans", "Albanian", "Amharic", "Arabic", "Armenian",
        "Azerbaijani", "Bahasa Indonesian", "Bangla", "Basque", "Bengali",
        "Bosnian", "Bulgarian", "Burmese", "Catalan", "Chinese Cantonese",
        "Chinese Mandarin", "Chinese Taiwanese", "Croatian", "Czech",
        "Danish", "Dutch", "English", "Estonian", "Filipino", "Finnish",
        "French", "Galician", "Georgian", "German", "Greek", "Hebrew",
        "Hindi", "Hungarian", "Icelandic", "Irish", "Italian", "Japanese",
        "Javanese", "Kannada", "Kazakh", "Khmer", "Korean", "Lao", "Latvian",
        "Lithuanian", "Macedonian", "Malay", "Malayalam", "Maltese",
        "Mongolian", "Nepali", "Norwegian Bokmål", "Pashto", "Persian",
        "Polish", "Portuguese", "Romanian", "Russian", "Serbian", "Sinhala",
        "Slovak", "Slovene", "Somali", "Spanish", "Sundanese", "Swahili",
        "Swedish", "Tamil", "Telugu", "Thai", "Turkish", "Ukrainian", "Urdu",
        "Uzbek", "Vietnamese", "Welsh", "Zulu",
    ]

    # Components are created in display order: header, inputs, options,
    # button, output.
    with gr.Blocks(title="PodcastGen 🎙️") as app:
        gr.Markdown("# PodcastGen 🎙️")
        gr.Markdown("Generate a 2-speaker podcast from text input or documents!")

        with gr.Row():
            with gr.Column(scale=2):
                text_box = gr.Textbox(label="Input Text", lines=10, placeholder="Enter text for podcast generation...")
            with gr.Column(scale=1):
                file_box = gr.File(label="Or Upload a PDF or TXT file", file_types=[".pdf", ".txt"])

        with gr.Row():
            with gr.Column():
                key_box = gr.Textbox(label="Your Gemini API Key (Optional)", placeholder="Enter API key here if you're getting rate limited", type="password")
                language_choice = gr.Dropdown(label="Language", choices=language_options, value="Auto Detect")
            with gr.Column():
                first_voice = gr.Dropdown(label="Speaker 1 Voice", choices=voice_options, value="Andrew - English (United States)")
                second_voice = gr.Dropdown(label="Speaker 2 Voice", choices=voice_options, value="Ava - English (United States)")

        run_button = gr.Button("Generate Podcast", variant="primary")

        with gr.Row():
            audio_out = gr.Audio(label="Generated Podcast", type="filepath", format="wav")

        # The input order must match the handler's positional parameters.
        run_button.click(
            fn=generate_podcast_gradio,
            inputs=[text_box, file_box, language_choice, first_voice, second_voice, key_box],
            outputs=[audio_out],
        )

    app.launch()


if __name__ == "__main__":
    main()