import os
import json
import re
import gradio as gr
import asyncio
import logging
import torch
import random
import tempfile
import zipfile
from serpapi import GoogleSearch
from pydantic import BaseModel
from autogen_agentchat.agents import AssistantAgent
from autogen_agentchat.conditions import HandoffTermination, TextMentionTermination
from autogen_agentchat.teams import Swarm
from autogen_agentchat.ui import Console
from autogen_agentchat.messages import TextMessage, HandoffMessage, StructuredMessage
from autogen_ext.models.anthropic import AnthropicChatCompletionClient
from autogen_ext.models.openai import OpenAIChatCompletionClient
from autogen_ext.models.ollama import OllamaChatCompletionClient
from markdown_pdf import MarkdownPdf, Section
import traceback
import soundfile as sf
import shutil
from pydub import AudioSegment
from TTS.api import TTS
from gradio_pdf import PDF

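# Overview: this app builds a spoken lecture end-to-end. An AutoGen Swarm of four agents
# (research -> slides -> scripts -> review) produces slide JSON and narration scripts,
# which are rendered to a PDF with markdown_pdf and voiced with Coqui XTTS v2 using an
# optional speaker reference sample; all outputs are bundled into a downloadable ZIP.
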
# Set up logging
logging.basicConfig(
    level=logging.DEBUG,
    format="%(asctime)s - %(levelname)s - %(message)s",
    handlers=[
        logging.FileHandler("lecture_generation.log"),
        logging.StreamHandler()
    ]
)
logger = logging.getLogger(__name__)

# Set up environment: accept the Coqui TTS license non-interactively so the XTTS model can be downloaded
os.environ["COQUI_TOS_AGREED"] = "1"

# Define Pydantic models for slide data
class Slide(BaseModel):
    title: str
    content: str


class SlidesOutput(BaseModel):
    slides: list[Slide]

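# Illustrative example (not executed): the slide JSON emitted by the slide agent is expected
# to validate against these models, e.g.
#   SlidesOutput(slides=[Slide(title="Introduction", content="What is AI?")])
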
# Define search_web tool using SerpApi
def search_web(query: str, serpapi_key: str) -> str:
    try:
        params = {
            "q": query,
            "engine": "google",
            "api_key": serpapi_key,
            "num": 5
        }
        search = GoogleSearch(params)
        results = search.get_dict()
        if "error" in results:
            logger.error("SerpApi error: %s", results["error"])
            return f"Error during search: {results['error']}"
        if "organic_results" not in results or not results["organic_results"]:
            logger.info("No search results found for query: %s", query)
            return f"No results found for query: {query}"
        formatted_results = []
        for item in results["organic_results"][:5]:
            title = item.get("title", "No title")
            snippet = item.get("snippet", "No snippet")
            link = item.get("link", "No link")
            formatted_results.append(f"Title: {title}\nSnippet: {snippet}\nLink: {link}\n")
        formatted_output = "\n".join(formatted_results)
        logger.info("Successfully retrieved search results for query: %s", query)
        return f"Search results for {query}:\n{formatted_output}"
    except Exception as e:
        logger.error("Unexpected error during search: %s", str(e))
        return f"Unexpected error during search: {str(e)}"

# Function to get model client based on selected service
def get_model_client(service, api_key):
    if service == "OpenAI-gpt-4o-2024-08-06":
        return OpenAIChatCompletionClient(model="gpt-4o-2024-08-06", api_key=api_key)
    elif service == "Anthropic-claude-3-sonnet-20240229":
        return AnthropicChatCompletionClient(model="claude-3-sonnet-20240229", api_key=api_key)
    elif service == "Google-gemini-1.5-flash":
        return OpenAIChatCompletionClient(model="gemini-1.5-flash", api_key=api_key)
    elif service == "Ollama-llama3.2":
        return OllamaChatCompletionClient(model="llama3.2")
    else:
        raise ValueError("Invalid service")

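# Note: the Gemini option above is routed through the OpenAI-compatible client, which assumes
# the installed autogen-ext version accepts Gemini model names (an API key is still required).
# The Ollama option assumes a local Ollama server is running with the llama3.2 model pulled.
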
# Helper function to clean script text and make it natural
def clean_script_text(script):
    if not script or not isinstance(script, str):
        logger.error("Invalid script input: %s", script)
        return None

    # Minimal cleaning to preserve natural language
    script = re.sub(r"\*\*Slide \d+:.*?\*\*", "", script)  # Remove slide headers
    script = re.sub(r"\[.*?\]", "", script)  # Remove bracketed content
    script = re.sub(r"Title:.*?\n|Content:.*?\n", "", script)  # Remove metadata
    script = script.replace("humanlike", "human-like").replace("problemsolving", "problem-solving")
    script = re.sub(r"\s+", " ", script).strip()  # Normalize whitespace

    # Convert bullet points to spoken cues
    script = re.sub(r"^\s*-\s*", "So, ", script, flags=re.MULTILINE)

    # Add non-verbal words randomly (e.g., "um," "you know," "like")
    non_verbal = ["um, ", "you know, ", "like, "]
    words = script.split()
    for i in range(len(words) - 1, -1, -1):
        if random.random() < 0.1:  # 10% chance per word
            words.insert(i, random.choice(non_verbal))
    script = " ".join(words)

    # Basic validation
    if len(script) < 10:
        logger.error("Cleaned script too short (%d characters): %s", len(script), script)
        return None

    logger.info("Cleaned and naturalized script: %s", script)
    return script

# Helper function to validate and convert speaker audio (MP3 or WAV)
async def validate_and_convert_speaker_audio(speaker_audio, temp_dir):
    if not os.path.exists(speaker_audio):
        logger.error("Speaker audio file does not exist: %s", speaker_audio)
        return None
    try:
        # Check file extension
        ext = os.path.splitext(speaker_audio)[1].lower()
        if ext == ".mp3":
            logger.info("Converting MP3 to WAV: %s", speaker_audio)
            audio = AudioSegment.from_mp3(speaker_audio)
            # Convert to mono, 22050 Hz
            audio = audio.set_channels(1).set_frame_rate(22050)
            speaker_wav = os.path.join(temp_dir, "speaker_converted.wav")
            audio.export(speaker_wav, format="wav")
        elif ext == ".wav":
            speaker_wav = speaker_audio
        else:
            logger.error("Unsupported audio format: %s", ext)
            return None

        # Validate WAV file
        data, samplerate = sf.read(speaker_wav)
        if samplerate < 16000 or samplerate > 48000:
            logger.error("Invalid sample rate for %s: %d Hz", speaker_wav, samplerate)
            return None
        if len(data) < 16000:
            logger.error("Speaker audio too short: %d frames", len(data))
            return None
        if data.ndim == 2:
            logger.info("Converting stereo WAV to mono: %s", speaker_wav)
            data = data.mean(axis=1)
            mono_wav = os.path.join(temp_dir, "speaker_mono.wav")
            sf.write(mono_wav, data, samplerate)
            speaker_wav = mono_wav

        logger.info("Validated speaker audio: %s", speaker_wav)
        return speaker_wav
    except Exception as e:
        logger.error("Failed to validate or convert speaker audio %s: %s", speaker_audio, str(e))
        return None

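# Note: the checks above only enforce a rough lower bound (16,000 frames, i.e. about one
# second or less depending on sample rate). In practice, XTTS voice cloning tends to work
# best with several seconds of clean, single-speaker reference audio.
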
# Helper function to generate audio using Coqui TTS API
def generate_xtts_audio(tts, text, speaker_wav, output_path):
    if not tts:
        logger.error("TTS model not initialized")
        return False
    try:
        tts.tts_to_file(text=text, speaker_wav=speaker_wav, language="en", file_path=output_path)
        logger.info("Generated audio for %s", output_path)
        return True
    except Exception as e:
        logger.error("Failed to generate audio for %s: %s", output_path, str(e))
        return False

# Helper function to extract JSON from messages
def extract_json_from_message(message):
    if isinstance(message, TextMessage):
        content = message.content
        logger.debug("Extracting JSON from TextMessage: %s", content)
        if not isinstance(content, str):
            logger.warning("TextMessage content is not a string: %s", content)
            return None

        # Try standard JSON block
        pattern = r"```json\s*(.*?)\s*```"
        match = re.search(pattern, content, re.DOTALL)
        if match:
            try:
                return json.loads(match.group(1))
            except json.JSONDecodeError as e:
                logger.error("Failed to parse JSON from TextMessage: %s, Content: %s", e, content)

        # Fallback: Try raw JSON array
        json_pattern = r"\[\s*\{.*?\}\s*\]"
        match = re.search(json_pattern, content, re.DOTALL)
        if match:
            try:
                return json.loads(match.group(0))
            except json.JSONDecodeError as e:
                logger.error("Failed to parse fallback JSON from TextMessage: %s, Content: %s", e, content)

        # Fallback: Try any JSON-like structure
        try:
            parsed = json.loads(content)
            if isinstance(parsed, (list, dict)):
                logger.info("Parsed JSON from raw content: %s", parsed)
                return parsed
        except json.JSONDecodeError:
            pass

        logger.warning("No JSON found in TextMessage content: %s", content)
        return None

    elif isinstance(message, StructuredMessage):
        content = message.content
        logger.debug("Extracting JSON from StructuredMessage: %s", content)
        try:
            if isinstance(content, BaseModel):
                content_dict = content.dict()
                return content_dict.get("slides", content_dict)
            return content
        except Exception as e:
            logger.error("Failed to extract JSON from StructuredMessage: %s, Content: %s", e, content)
            return None

    elif isinstance(message, HandoffMessage):
        logger.debug("Extracting JSON from HandoffMessage context")
        for ctx_msg in message.context:
            if hasattr(ctx_msg, "content"):
                content = ctx_msg.content
                logger.debug("Handoff context message content: %s", content)
                if isinstance(content, str):
                    pattern = r"```json\s*(.*?)\s*```"
                    match = re.search(pattern, content, re.DOTALL)
                    if match:
                        try:
                            return json.loads(match.group(1))
                        except json.JSONDecodeError as e:
                            logger.error("Failed to parse JSON from HandoffMessage context: %s, Content: %s", e, content)
                    json_pattern = r"\[\s*\{.*?\}\s*\]"
                    match = re.search(json_pattern, content, re.DOTALL)
                    if match:
                        try:
                            return json.loads(match.group(0))
                        except json.JSONDecodeError as e:
                            logger.error("Failed to parse fallback JSON from HandoffMessage context: %s, Content: %s", e, content)
                    try:
                        parsed = json.loads(content)
                        if isinstance(parsed, (list, dict)):
                            logger.info("Parsed JSON from raw HandoffMessage context: %s", parsed)
                            return parsed
                    except json.JSONDecodeError:
                        pass
                elif isinstance(content, dict):
                    return content.get("slides", content)
        logger.warning("No JSON found in HandoffMessage context")
        return None

    logger.warning("Unsupported message type for JSON extraction: %s", type(message))
    return None

# Function to generate Markdown and convert to PDF (portrait, centered)
def generate_slides_pdf(slides, temp_dir):
    pdf = MarkdownPdf()
    for slide in slides:
        content_lines = slide['content'].replace('\n', '\n\n')
        markdown_content = f"""
<div style="display: flex; flex-direction: column; justify-content: center; align-items: center; height: 100%; text-align: center; padding: 20px;">

# {slide['title']}

*Prof. AI Feynman*
*Princeton University, April 26th, 2025*

{content_lines}

</div>

---
"""
        pdf.add_section(Section(markdown_content, toc=False))
    pdf_file = os.path.join(temp_dir, "slides.pdf")
    pdf.save(pdf_file)
    logger.info("Generated PDF slides (portrait): %s", pdf_file)
    return pdf_file

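# Note: each Section added above is rendered by markdown_pdf as its own section (in practice,
# one slide per page). How faithfully the inline HTML/CSS centering is honored depends on the
# markdown_pdf renderer, so the layout should be treated as approximate.
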
# Helper function to create ZIP file of outputs
def create_outputs_zip(temp_dir, slides, audio_files, scripts):
    zip_path = os.path.join(temp_dir, "lecture_outputs.zip")
    with zipfile.ZipFile(zip_path, 'w', zipfile.ZIP_DEFLATED) as zipf:
        # Add slides PDF
        pdf_file = os.path.join(temp_dir, "slides.pdf")
        if os.path.exists(pdf_file):
            zipf.write(pdf_file, "slides.pdf")

        # Add audio files
        for i, audio_file in enumerate(audio_files):
            if audio_file and os.path.exists(audio_file):
                zipf.write(audio_file, f"slide_{i+1}.wav")

        # Add raw and cleaned scripts
        for i in range(len(slides)):
            raw_script_file = os.path.join(temp_dir, f"slide_{i+1}_raw_script.txt")
            cleaned_script_file = os.path.join(temp_dir, f"slide_{i+1}_script.txt")
            if os.path.exists(raw_script_file):
                zipf.write(raw_script_file, f"slide_{i+1}_raw_script.txt")
            if os.path.exists(cleaned_script_file):
                zipf.write(cleaned_script_file, f"slide_{i+1}_script.txt")
    logger.info("Created ZIP file: %s", zip_path)
    return zip_path

# Helper function for progress HTML
def html_with_progress(label, progress):
    return f"""
<div style="display: flex; flex-direction: column; justify-content: center; align-items: center; height: 100%; min-height: 700px; padding: 20px; text-align: center; border: 1px solid #ddd; border-radius: 8px;">
    <div style="width: 100%; background-color: #FFFFFF; border-radius: 10px; overflow: hidden; margin-bottom: 20px;">
        <div style="width: {progress}%; height: 30px; background-color: #4CAF50; border-radius: 10px;"></div>
    </div>
    <h2 style="font-style: italic; color: #555;">{label}</h2>
</div>
"""

# Async function to update audio preview
async def update_audio_preview(audio_file):
    if audio_file:
        logger.info("Updating audio preview for file: %s", audio_file)
        return audio_file
    return None

# Async function to generate lecture materials and audio
async def on_generate(api_service, api_key, serpapi_key, title, topic, instructions, lecture_type, speaker_audio, num_slides):
    # Each yield is a 3-tuple matching the click handler's outputs (slide HTML, slides PDF, outputs ZIP);
    # during progress updates only the HTML is populated.
    if not serpapi_key:
        yield html_with_progress("SerpApi key required. Please provide a valid key.", 0), None, None
        return

    # Create temporary directory
    with tempfile.TemporaryDirectory() as temp_dir:
        # Initialize TTS model
        tts = None
        try:
            device = "cuda" if torch.cuda.is_available() else "cpu"
            tts = TTS("tts_models/multilingual/multi-dataset/xtts_v2").to(device)
            logger.info("TTS model initialized on %s", device)
        except Exception as e:
            logger.error("Failed to initialize TTS model: %s", str(e))
            yield html_with_progress(f"TTS model initialization failed: {str(e)}", 0), None, None
            return

        model_client = get_model_client(api_service, api_key)

        research_agent = AssistantAgent(
            name="research_agent",
            model_client=model_client,
            handoffs=["slide_agent"],
            system_message="You are a Research Agent. Use the search_web tool to gather information on the topic and keywords from the initial message. Summarize the findings concisely in a single message, then use the handoff_to_slide_agent tool to pass the task to the Slide Agent. Do not produce any other output.",
            tools=[search_web]
        )
        slide_agent = AssistantAgent(
            name="slide_agent",
            model_client=model_client,
            handoffs=["script_agent"],
            system_message=f"""
You are a Slide Agent. Using the research from the conversation history, generate EXACTLY {num_slides} content slides on the topic, plus 1 quiz slide, 1 assignment slide, and 1 thank-you slide, for a TOTAL of {num_slides + 3} slides. Output ONLY a JSON array wrapped in ```json ... ``` in a TextMessage, with each slide as an object with 'title' and 'content' keys. Ensure the JSON is valid and contains precisely {num_slides + 3} slides. If the slide count is incorrect, adjust the output to meet this requirement before proceeding. Do not include explanatory text or comments. After outputting the JSON, use the handoff_to_script_agent tool.
Example for 2 content slides:
```json
[
    {{"title": "Slide 1", "content": "Content for slide 1"}},
    {{"title": "Slide 2", "content": "Content for slide 2"}},
    {{"title": "Quiz", "content": "Quiz questions"}},
    {{"title": "Assignment", "content": "Assignment details"}},
    {{"title": "Thank You", "content": "Thank you message"}}
]
```""",
            output_content_type=None,
            reflect_on_tool_use=False
        )
        script_agent = AssistantAgent(
            name="script_agent",
            model_client=model_client,
            handoffs=["feynman_agent"],
            system_message=f"""
You are a Script Agent. Access the JSON array of {num_slides + 3} slides from the conversation history. Generate a narration script (1-2 sentences) for each of the {num_slides + 3} slides, summarizing its content in a natural, conversational tone as a speaker would, including occasional non-verbal words (e.g., "um," "you know," "like"). Output ONLY a JSON array wrapped in ```json ... ``` with exactly {num_slides + 3} strings, one script per slide, in the same order. Ensure the JSON is valid and complete. After outputting, use the handoff_to_feynman_agent tool. If scripts cannot be generated, retry once.
Example for 1 content slide:
```json
[
    "So, this slide, um, covers the main topic in a fun way.",
    "Alright, you know, answer these quiz questions.",
    "Here's your, like, assignment to complete.",
    "Thanks for, um, attending today!"
]
```""",
            output_content_type=None,
            reflect_on_tool_use=False
        )
        feynman_agent = AssistantAgent(
            name="feynman_agent",
            model_client=model_client,
            handoffs=[],
            system_message=f"""
You are Agent Feynman. Review the slides and scripts from the conversation history to ensure coherence, completeness, and that exactly {num_slides + 3} slides and {num_slides + 3} scripts are received. Output a confirmation message summarizing the number of slides and scripts received. If slides or scripts are missing, invalid, or do not match the expected count ({num_slides + 3}), report the issue clearly. Use 'TERMINATE' to signal completion.
Example: 'Received {num_slides + 3} slides and {num_slides + 3} scripts. Lecture is coherent. TERMINATE'
""")

        swarm = Swarm(
            participants=[research_agent, slide_agent, script_agent, feynman_agent],
            termination_condition=HandoffTermination(target="user") | TextMentionTermination("TERMINATE")
        )

        progress = 0
        label = "Research: in progress..."
        yield html_with_progress(label, progress), None, None
        await asyncio.sleep(0.1)

        initial_message = f"""
Lecture Title: {title}
Topic: {topic}
Additional Instructions: {instructions}
Audience: {lecture_type}
Number of Content Slides: {num_slides}
Please start by researching the topic.
"""
        logger.info("Starting lecture generation for topic: %s", topic)

        slides = None
        scripts = None
        max_slide_retries = 2
        slide_retry_count = 0
        while slide_retry_count <= max_slide_retries:
            try:
                logger.info("Research Agent starting (Slide attempt %d/%d)", slide_retry_count + 1, max_slide_retries)
                task_result = await Console(swarm.run_stream(task=initial_message))
                logger.info("Swarm execution completed")

                script_retry_count = 0
                max_script_retries = 2
                for message in task_result.messages:
                    source = getattr(message, 'source', getattr(message, 'sender', None))
                    logger.debug("Processing message from %s, type: %s, content: %s", source, type(message), message.to_text() if hasattr(message, 'to_text') else str(message))

                    if isinstance(message, HandoffMessage):
                        logger.info("Handoff from %s to %s, Context: %s", source, message.target, message.context)
                        if source == "research_agent" and message.target == "slide_agent":
                            progress = 25
                            label = "Slides: generating..."
                            yield html_with_progress(label, progress), None, None
                            await asyncio.sleep(0.1)
                        elif source == "slide_agent" and message.target == "script_agent":
                            if slides is None:
                                logger.warning("Slide Agent handoff without slides JSON")
                                extracted_json = extract_json_from_message(message)
                                if extracted_json:
                                    slides = extracted_json
                                    logger.info("Extracted slides JSON from HandoffMessage context: %s", slides)
                                if slides is None:
                                    label = "Slides: failed to generate..."
                                    yield html_with_progress(label, progress), None, None
                                    await asyncio.sleep(0.1)
                            progress = 50
                            label = "Scripts: generating..."
                            yield html_with_progress(label, progress), None, None
                            await asyncio.sleep(0.1)
                        elif source == "script_agent" and message.target == "feynman_agent":
                            if scripts is None:
                                logger.warning("Script Agent handoff without scripts JSON")
                                extracted_json = extract_json_from_message(message)
                                if extracted_json:
                                    scripts = extracted_json
                                    logger.info("Extracted scripts JSON from HandoffMessage context: %s", scripts)
                            progress = 75
                            label = "Review: in progress..."
                            yield html_with_progress(label, progress), None, None
                            await asyncio.sleep(0.1)
                    elif source == "research_agent" and isinstance(message, TextMessage) and "handoff_to_slide_agent" in message.content:
                        logger.info("Research Agent completed research")
                        progress = 25
                        label = "Slides: generating..."
                        yield html_with_progress(label, progress), None, None
                        await asyncio.sleep(0.1)
                    elif source == "slide_agent" and isinstance(message, (TextMessage, StructuredMessage)):
                        logger.debug("Slide Agent message received: %s", message.to_text())
                        extracted_json = extract_json_from_message(message)
                        if extracted_json:
                            slides = extracted_json
                            logger.info("Slide Agent generated %d slides: %s", len(slides), slides)
                            expected_slide_count = num_slides + 3
                            if len(slides) != expected_slide_count:
                                logger.warning("Generated %d slides, expected %d. Retrying...", len(slides), expected_slide_count)
                                slide_retry_count += 1
                                if slide_retry_count <= max_slide_retries:
                                    # Re-prompt slide agent
                                    retry_message = TextMessage(
                                        content=f"Please generate EXACTLY {num_slides} content slides plus 1 quiz, 1 assignment, and 1 thank-you slide (total {num_slides + 3}).",
                                        source="user",
                                        recipient="slide_agent"
                                    )
                                    task_result.messages.append(retry_message)
                                    slides = None
                                    continue
                                else:
                                    yield html_with_progress(f"Failed to generate correct number of slides after {max_slide_retries} retries. Expected {expected_slide_count}, got {len(slides)}.", progress), None, None
                                    return
                            # Save slide content to individual files
                            for i, slide in enumerate(slides):
                                content_file = os.path.join(temp_dir, f"slide_{i+1}_content.txt")
                                try:
                                    with open(content_file, "w", encoding="utf-8") as f:
                                        f.write(slide["content"])
                                    logger.info("Saved slide content to %s: %s", content_file, slide["content"])
                                except Exception as e:
                                    logger.error("Error saving slide content to %s: %s", content_file, str(e))
                            progress = 50
                            label = "Scripts: generating..."
                            yield html_with_progress(label, progress), None, None
                            await asyncio.sleep(0.1)
                        else:
                            logger.warning("No JSON extracted from slide_agent message: %s", message.to_text())
                    elif source == "script_agent" and isinstance(message, (TextMessage, StructuredMessage)):
                        logger.debug("Script Agent message received: %s", message.to_text())
                        extracted_json = extract_json_from_message(message)
                        if extracted_json:
                            scripts = extracted_json
                            logger.info("Script Agent generated scripts for %d slides: %s", len(scripts), scripts)
                            # Save raw scripts to individual files
                            for i, script in enumerate(scripts):
                                script_file = os.path.join(temp_dir, f"slide_{i+1}_raw_script.txt")
                                try:
                                    with open(script_file, "w", encoding="utf-8") as f:
                                        f.write(script)
                                    logger.info("Saved raw script to %s: %s", script_file, script)
                                except Exception as e:
                                    logger.error("Error saving raw script to %s: %s", script_file, str(e))
                            progress = 75
                            label = "Scripts generated and saved. Reviewing..."
                            yield html_with_progress(label, progress), None, None
                            await asyncio.sleep(0.1)
                        else:
                            logger.warning("No JSON extracted from script_agent message: %s", message.to_text())
                            if script_retry_count < max_script_retries:
                                script_retry_count += 1
                                logger.info("Retrying script generation (attempt %d/%d)", script_retry_count, max_script_retries)
                                # Re-prompt script agent
                                retry_message = TextMessage(
                                    content="Please generate scripts for the slides as per your instructions.",
                                    source="user",
                                    recipient="script_agent"
                                )
                                task_result.messages.append(retry_message)
                                continue
                    elif source == "feynman_agent" and isinstance(message, TextMessage) and "TERMINATE" in message.content:
                        logger.info("Feynman Agent completed lecture review: %s", message.content)
                        progress = 90
                        label = "Lecture materials ready. Generating audio..."
                        yield html_with_progress(label, progress), None, None
                        await asyncio.sleep(0.1)

                logger.info("Slides state: %s", "Generated" if slides else "None")
                logger.info("Scripts state: %s", "Generated" if scripts else "None")

                if not slides or not scripts:
                    error_message = f"Failed to generate {'slides and scripts' if not slides and not scripts else 'slides' if not slides else 'scripts'}"
                    error_message += f". Received {len(slides) if slides else 0} slides and {len(scripts) if scripts else 0} scripts."
                    logger.error("%s", error_message)
                    yield html_with_progress(error_message, progress), None, None
                    return

                expected_slide_count = num_slides + 3
                if len(slides) != expected_slide_count:
                    logger.error("Final validation failed: Expected %d slides, received %d", expected_slide_count, len(slides))
                    yield html_with_progress(f"Incorrect number of slides. Expected {expected_slide_count}, got {len(slides)}.", progress), None, None
                    return
                if not isinstance(scripts, list) or not all(isinstance(s, str) for s in scripts):
                    logger.error("Scripts are not a list of strings: %s", scripts)
                    yield html_with_progress("Invalid script format. Scripts must be a list of strings.", progress), None, None
                    return
                if len(scripts) != expected_slide_count:
                    logger.error("Mismatch between number of slides (%d) and scripts (%d)", len(slides), len(scripts))
                    yield html_with_progress(f"Mismatch in slides and scripts. Generated {len(slides)} slides but {len(scripts)} scripts.", progress), None, None
                    return

                # Generate PDF from slides
                pdf_file = generate_slides_pdf(slides, temp_dir)

                audio_files = []
                speaker_audio = speaker_audio if speaker_audio else "feynman.mp3"
                validated_speaker_wav = await validate_and_convert_speaker_audio(speaker_audio, temp_dir)
                if not validated_speaker_wav:
                    logger.error("Invalid speaker audio after conversion, skipping TTS")
                    yield html_with_progress("Invalid speaker audio. Please upload a valid MP3 or WAV file.", progress), None, None
                    return

                # Process audio generation sequentially with retries
                for i, script in enumerate(scripts):
                    cleaned_script = clean_script_text(script)
                    audio_file = os.path.join(temp_dir, f"slide_{i+1}.wav")
                    script_file = os.path.join(temp_dir, f"slide_{i+1}_script.txt")

                    # Save cleaned script
                    try:
                        with open(script_file, "w", encoding="utf-8") as f:
                            f.write(cleaned_script or "")
                        logger.info("Saved cleaned script to %s: %s", script_file, cleaned_script)
                    except Exception as e:
                        logger.error("Error saving cleaned script to %s: %s", script_file, str(e))

                    if not cleaned_script:
                        logger.error("Skipping audio for slide %d due to empty or invalid script", i + 1)
                        audio_files.append(None)
                        progress = 90 + ((i + 1) / len(scripts)) * 10
                        label = f"Generated audio for slide {i + 1}/{len(scripts)}..."
                        yield html_with_progress(label, progress), None, None
                        await asyncio.sleep(0.1)
                        continue

                    max_retries = 2
                    for attempt in range(max_retries + 1):
                        try:
                            current_text = cleaned_script
                            if attempt > 0:
                                # On retry, shorten the script to its first two sentences to reduce TTS failures
                                sentences = re.split(r"[.!?]+", cleaned_script)
                                sentences = [s.strip() for s in sentences if s.strip()][:2]
                                current_text = ". ".join(sentences) + "."
                                logger.info("Retry %d for slide %d with simplified text: %s", attempt, i + 1, current_text)
                            success = generate_xtts_audio(tts, current_text, validated_speaker_wav, audio_file)
                            if not success:
                                raise RuntimeError("TTS generation failed")
                            logger.info("Generated audio for slide %d: %s", i + 1, audio_file)
                            audio_files.append(audio_file)
                            progress = 90 + ((i + 1) / len(scripts)) * 10
                            label = f"Generated audio for slide {i + 1}/{len(scripts)}..."
                            yield html_with_progress(label, progress), None, None
                            await asyncio.sleep(0.1)
                            break
                        except Exception as e:
                            logger.error("Error generating audio for slide %d (attempt %d): %s\n%s", i + 1, attempt, str(e), traceback.format_exc())
                            if attempt == max_retries:
                                logger.error("Max retries reached for slide %d, skipping", i + 1)
                                audio_files.append(None)
                                progress = 90 + ((i + 1) / len(scripts)) * 10
                                label = f"Generated audio for slide {i + 1}/{len(scripts)}..."
                                yield html_with_progress(label, progress), None, None
                                await asyncio.sleep(0.1)
                                break

                # Create ZIP file of all outputs
                zip_path = create_outputs_zip(temp_dir, slides, audio_files, scripts)

                # Prepare UI output
                slides_info = json.dumps({"slides": [
                    {"title": slide["title"], "content": slide["content"]}
                    for slide in slides
                ], "audioFiles": audio_files})
                html_output = f"""
<div id="lecture-container" style="height: 700px; border: 1px solid #ddd; border-radius: 8px; display: flex; flex-direction: column; justify-content: space-between;">
    <div id="slide-content" style="flex: 1; overflow: auto;">
        <div id="pdf-viewer"></div>
    </div>
    <div style="padding: 20px;">
        <div id="progress-bar" style="width: 100%; height: 5px; background-color: #ddd; border-radius: 2px; margin-bottom: 10px;">
            <div id="progress-fill" style="width: {(1/len(slides)*100)}%; height: 100%; background-color: #4CAF50; border-radius: 2px;"></div>
        </div>
        <div style="display: flex; justify-content: center; margin-bottom: 10px;">
            <button onclick="prevSlide()" style="border-radius: 50%; width: 40px; height: 40px; margin: 0 5px; font-size: 1.2em; cursor: pointer;">⏮</button>
            <button onclick="togglePlay()" style="border-radius: 50%; width: 40px; height: 40px; margin: 0 5px; font-size: 1.2em; cursor: pointer;">⏯</button>
            <button onclick="nextSlide()" style="border-radius: 50%; width: 40px; height: 40px; margin: 0 5px; font-size: 1.2em; cursor: pointer;">⏭</button>
        </div>
        <p id="slide-counter" style="text-align: center;">Slide 1 of {len(slides)}</p>
    </div>
</div>
<script>
    const lectureData = {slides_info};
    let currentSlide = 0;
    const totalSlides = lectureData.slides.length;
    const slideCounter = document.getElementById('slide-counter');
    const progressFill = document.getElementById('progress-fill');
    let audioElements = [];
    let currentAudio = null;
    for (let i = 0; i < totalSlides; i++) {{
        if (lectureData.audioFiles && lectureData.audioFiles[i]) {{
            const audio = new Audio('file://' + lectureData.audioFiles[i]);
            audioElements.push(audio);
        }} else {{
            audioElements.push(null);
        }}
    }}
    function updateSlide() {{
        slideCounter.textContent = `Slide ${{currentSlide + 1}} of ${{totalSlides}}`;
        progressFill.style.width = `${{(currentSlide + 1) / totalSlides * 100}}%`;
        if (currentAudio) {{
            currentAudio.pause();
            currentAudio.currentTime = 0;
        }}
        if (audioElements[currentSlide]) {{
            currentAudio = audioElements[currentSlide];
            currentAudio.play().catch(e => console.error('Audio play failed:', e));
        }} else {{
            currentAudio = null;
        }}
    }}
    function prevSlide() {{
        if (currentSlide > 0) {{
            currentSlide--;
            updateSlide();
        }}
    }}
    function nextSlide() {{
        if (currentSlide < totalSlides - 1) {{
            currentSlide++;
            updateSlide();
        }}
    }}
    function togglePlay() {{
        if (!audioElements[currentSlide]) return;
        if (currentAudio.paused) {{
            currentAudio.play().catch(e => console.error('Audio play failed:', e));
        }} else {{
            currentAudio.pause();
        }}
    }}
    audioElements.forEach((audio, index) => {{
        if (audio) {{
            audio.addEventListener('ended', () => {{
                if (index < totalSlides - 1) {{
                    nextSlide();
                }}
            }});
        }}
    }});
</script>
"""
                # Final update: slide player HTML, slides PDF, and outputs ZIP, in the click handler's output order
                yield html_output, pdf_file, zip_path
                return
            except Exception as e:
                logger.error("Error during lecture generation: %s\n%s", str(e), traceback.format_exc())
                yield html_with_progress(f"Error during lecture generation: {str(e)}", progress), None, None
                return


# Gradio interface
with gr.Blocks(title="Agent Feynman") as demo:
    gr.Markdown("# <center>Learn Anything With Professor AI Feynman</center>")
    with gr.Row():
        with gr.Column(scale=1):
            with gr.Group():
                title = gr.Textbox(label="Lecture Title", placeholder="e.g. Introduction to AI")
                topic = gr.Textbox(label="Topic", placeholder="e.g. Artificial Intelligence")
                instructions = gr.Textbox(label="Additional Instructions", placeholder="e.g. Focus on recent advancements")
                lecture_type = gr.Dropdown(["Conference", "University", "High school"], label="Audience", value="University")
                api_service = gr.Dropdown(
                    choices=[
                        "OpenAI-gpt-4o-2024-08-06",
                        "Anthropic-claude-3-sonnet-20240229",
                        "Google-gemini-1.5-flash",
                        "Ollama-llama3.2"
                    ],
                    label="Model",
                    value="Google-gemini-1.5-flash"
                )
                api_key = gr.Textbox(label="Model Provider API Key", type="password", placeholder="Not required for Ollama")
                serpapi_key = gr.Textbox(label="SerpApi Key", type="password", placeholder="Enter your SerpApi key")
                num_slides = gr.Slider(1, 20, step=1, label="Number of Content Slides", value=3)
                speaker_audio = gr.Audio(label="Speaker sample audio (MP3 or WAV)", type="filepath", elem_id="speaker-audio")
                generate_btn = gr.Button("Generate Lecture")
        with gr.Column(scale=2):
            default_slide_html = """
            <div style="display: flex; flex-direction: column; justify-content: center; align-items: center; height: 100%; min-height: 700px; padding: 20px; text-align: center; border: 1px solid #ddd; border-radius: 8px;">
                <h2 style="font-style: italic; color: #555;">Waiting for lecture content...</h2>
                <p style="margin-top: 10px; font-size: 16px;">Please generate the lecture content via the form on the left before the lecture begins.</p>
            </div>
            """
            slide_display = gr.HTML(label="Lecture Slides", value=default_slide_html)
            pdf_display = PDF(label="Lecture Slides PDF")  # PDF viewer component from gradio_pdf (imported above)
            outputs_zip = gr.File(label="Download Outputs (PDF, Audio, Scripts)")

    speaker_audio.change(
        fn=update_audio_preview,
        inputs=speaker_audio,
        outputs=speaker_audio
    )
    generate_btn.click(
        fn=on_generate,
        inputs=[api_service, api_key, serpapi_key, title, topic, instructions, lecture_type, speaker_audio, num_slides],
        outputs=[slide_display, pdf_display, outputs_zip]
    )

if __name__ == "__main__":
    demo.launch()