import gradio as gr
import spaces
import torch
from pydub import AudioSegment
import numpy as np
import io
from scipy.io import wavfile
from colpali_engine.models import ColQwen2_5Omni, ColQwen2_5OmniProcessor
from transformers.utils.import_utils import is_flash_attn_2_available
import base64
from scipy.io.wavfile import write
import os
# Global model variables
model = None
processor = None
def load_model():
    """Load model and processor once"""
    global model, processor
    if model is None:
        model = ColQwen2_5Omni.from_pretrained(
            "vidore/colqwen-omni-v0.1",
            torch_dtype=torch.bfloat16,
            device_map="cpu",  # Start on CPU for ZeroGPU
            attn_implementation="eager"  # ZeroGPU compatible
        ).eval()
        processor = ColQwen2_5OmniProcessor.from_pretrained("manu/colqwen-omni-v0.1")
    return model, processor
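# Note: on a ZeroGPU Space the weights stay on CPU between requests; the
# @spaces.GPU-decorated functions below move them to CUDA only for the duration
# of a call. This is a common ZeroGPU pattern, not something required by the
# colpali_engine API itself.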
def chunk_audio(audio_file, chunk_length=30):
    """Split audio into fixed-length chunks, resampled to 16 kHz mono"""
    # gr.Audio(type="filepath") passes the path as a plain string
    audio = AudioSegment.from_file(audio_file)
    audios = []
    target_rate = 16000
    chunk_length_ms = int(chunk_length * 1000)
    for i in range(0, len(audio), chunk_length_ms):
        chunk = audio[i:i + chunk_length_ms]
        chunk = chunk.set_channels(1).set_frame_rate(target_rate)
        buf = io.BytesIO()
        chunk.export(buf, format="wav")
        buf.seek(0)
        rate, data = wavfile.read(buf)
        audios.append(data)
    return audios
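# chunk_audio returns a list of 1-D numpy arrays (mono, 16 kHz, typically int16
# PCM as read back by scipy.io.wavfile), one per ~chunk_length seconds of audio.
# These raw waveforms are what ColQwen2_5OmniProcessor.process_audios consumes below.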
@spaces.GPU  # request a GPU from ZeroGPU for the duration of this call
def embed_audio_chunks(audios):
    """Embed audio chunks using GPU"""
    model, processor = load_model()
    model = model.to("cuda")
    # Process in batches
    from torch.utils.data import DataLoader
    dataloader = DataLoader(
        dataset=audios,
        batch_size=4,
        shuffle=False,
        collate_fn=lambda x: processor.process_audios(x)
    )
    embeddings = []
    for batch_doc in dataloader:
        with torch.no_grad():
            batch_doc = {k: v.to(model.device) for k, v in batch_doc.items()}
            embeddings_doc = model(**batch_doc)
        embeddings.extend(list(torch.unbind(embeddings_doc.to("cpu"))))
    # Move model back to CPU to free GPU memory
    model = model.to("cpu")
    torch.cuda.empty_cache()
    return embeddings
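# Each entry in `embeddings` is a multi-vector representation of one chunk: a
# (num_tokens, embedding_dim) tensor rather than a single pooled vector, which
# is what enables the late-interaction (MaxSim) scoring used in search_audio.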
@spaces.GPU  # request a GPU from ZeroGPU for the duration of this call
def search_audio(query, embeddings, audios, top_k=5):
    """Search for relevant audio chunks"""
    model, processor = load_model()
    model = model.to("cuda")
    # Process query
    batch_queries = processor.process_queries([query]).to(model.device)
    with torch.no_grad():
        query_embeddings = model(**batch_queries)
    # Score against all embeddings
    scores = processor.score_multi_vector(query_embeddings, embeddings)
    top_k = min(top_k, len(embeddings))  # don't ask for more chunks than exist
    top_indices = scores[0].topk(top_k).indices.tolist()
    # Move model back to CPU
    model = model.to("cpu")
    torch.cuda.empty_cache()
    return top_indices
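# For reference, the late-interaction score computed by processor.score_multi_vector
# is (roughly) the ColBERT-style MaxSim: for each query token, take the similarity
# of its best-matching chunk token, then sum over query tokens. An illustrative
# sketch (not used by the app):
def _maxsim_score(query_emb: torch.Tensor, doc_emb: torch.Tensor) -> torch.Tensor:
    """Illustrative MaxSim between one query and one chunk embedding.

    query_emb: (num_query_tokens, dim), doc_emb: (num_doc_tokens, dim)
    """
    sim = query_emb @ doc_emb.T          # pairwise token similarities
    return sim.max(dim=1).values.sum()   # best chunk token per query token, summed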
def audio_to_base64(data, rate=16000):
    """Convert audio data to base64"""
    buf = io.BytesIO()
    write(buf, rate, data)
    buf.seek(0)
    encoded_string = base64.b64encode(buf.read()).decode("utf-8")
    return encoded_string
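# The base64 string is embedded directly in the OpenAI "input_audio" message part
# below; its "format": "wav" field must match the container written here.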
def process_audio_rag(audio_file, query, chunk_length=30, use_openai=False, openai_key=None):
    """Main processing function"""
    if not audio_file:
        return "Please upload an audio file", None, None
    # Chunk audio
    audios = chunk_audio(audio_file, chunk_length)
    # Embed chunks
    embeddings = embed_audio_chunks(audios)
    # Search for relevant chunks
    top_indices = search_audio(query, embeddings, audios)
    # Prepare results
    result_text = f"Found {len(top_indices)} relevant audio chunks:\n"
    result_text += f"Chunk indices: {top_indices}\n\n"
    # Save first result as audio file
    first_chunk_path = "result_chunk.wav"
    wavfile.write(first_chunk_path, 16000, audios[top_indices[0]])
    # Optional: Use OpenAI for answer generation
    if use_openai and openai_key:
        from openai import OpenAI
        client = OpenAI(api_key=openai_key)
        content = [{"type": "text", "text": f"Answer the query using the audio files. Query: {query}"}]
        for idx in top_indices[:3]:  # Use top 3 chunks
            content.extend([
                {"type": "text", "text": f"Audio chunk #{idx}:"},
                {
                    "type": "input_audio",
                    "input_audio": {
                        "data": audio_to_base64(audios[idx]),
                        "format": "wav"
                    }
                }
            ])
        try:
            completion = client.chat.completions.create(
                model="gpt-4o-audio-preview",
                messages=[{"role": "user", "content": content}]
            )
            result_text += f"\nOpenAI Answer: {completion.choices[0].message.content}"
        except Exception as e:
            result_text += f"\nOpenAI Error: {str(e)}"
    # Create audio visualization
    import matplotlib.pyplot as plt
    fig, ax = plt.subplots(figsize=(10, 4))
    ax.plot(audios[top_indices[0]])
    ax.set_title(f"Waveform of top matching chunk (#{top_indices[0]})")
    ax.set_xlabel("Samples")
    ax.set_ylabel("Amplitude")
    plt.tight_layout()
    return result_text, first_chunk_path, fig
# Create Gradio interface
with gr.Blocks(title="AudioRAG Demo") as demo:
    gr.Markdown("# AudioRAG Demo - Semantic Audio Search")
    gr.Markdown("Upload an audio file and search through it using natural language queries!")
    with gr.Row():
        with gr.Column():
            audio_input = gr.Audio(label="Upload Audio File", type="filepath")
            query_input = gr.Textbox(label="Search Query", placeholder="What are you looking for in the audio?")
            chunk_length = gr.Slider(minimum=10, maximum=60, value=30, step=5, label="Chunk Length (seconds)")
            with gr.Accordion("OpenAI Integration (Optional)", open=False):
                use_openai = gr.Checkbox(label="Use OpenAI for answer generation")
                openai_key = gr.Textbox(label="OpenAI API Key", type="password")
            search_btn = gr.Button("Search Audio", variant="primary")
        with gr.Column():
            output_text = gr.Textbox(label="Results", lines=10)
            output_audio = gr.Audio(label="Top Matching Audio Chunk", type="filepath")
            output_plot = gr.Plot(label="Audio Waveform")
    search_btn.click(
        fn=process_audio_rag,
        inputs=[audio_input, query_input, chunk_length, use_openai, openai_key],
        outputs=[output_text, output_audio, output_plot]
    )
    gr.Examples(
        examples=[
            ["example_audio.wav", "Was Hannibal well liked by his men?", 30],
            ["podcast.mp3", "What did they say about climate change?", 20],
        ],
        inputs=[audio_input, query_input, chunk_length]
    )
| if __name__ == "__main__": | |
| # Load model on startup | |
| load_model() | |
| demo.launch() |