Spaces:
Sleeping
Sleeping
import os

# Must be set before torch / NeMo are imported so their OpenMP runtime sees it.
# NOTE(review): the original comment only says this fixes "an OpenMP
# environment variable issue"; presumably the host supplied a value the
# runtime rejected — confirm. Pinning it to "1" also limits CPU-side
# OpenMP threading.
os.environ['OMP_NUM_THREADS'] = '1'

import gradio as gr
from nemo.collections.speechlm2.models import SALM
import torch
import tempfile

# Load the model following the official NVIDIA NeMo usage.
model_id = "nvidia/canary-qwen-2.5b"
print("Loading NVIDIA Canary-Qwen-2.5B model using NeMo...")
model = SALM.from_pretrained(model_id)

# Prefer the GPU when one is visible; otherwise run on CPU.
device = "cuda" if torch.cuda.is_available() else "cpu"
# .eval() switches off training-only behaviour (e.g. dropout) for inference.
# It is idempotent, so it is harmless if from_pretrained already returned an
# eval-mode model.
model = model.to(device).eval()
def generate_text(prompt, max_tokens=200, temperature=0.7, top_p=0.9):
    """Generate a text reply with the model in LLM (text-only) mode.

    Args:
        prompt: Text sent to the model as a single ``user`` chat turn.
        max_tokens: Upper bound on newly generated tokens. Coerced to ``int``
            because UI sliders may deliver the value as a float.
        temperature: Sampling temperature.
        top_p: Nucleus-sampling probability mass.

    Returns:
        The decoded reply, or a string starting with
        ``"Error generating text: "`` if anything raises. The error is
        returned rather than raised so the UI shows it instead of failing.
    """
    try:
        # LLM (text-only) mode per the official documentation: generate with
        # the LLM's adapter disabled.
        with model.llm.disable_adapter():
            answer_ids = model.generate(
                prompts=[[{"role": "user", "content": prompt}]],
                max_new_tokens=int(max_tokens),
                temperature=temperature,
                top_p=top_p,
                do_sample=True,
            )
        # Decode from a CPU tensor, as in the official usage example.
        # Moving the ids to the model's device first was unnecessary.
        return model.tokenizer.ids_to_text(answer_ids[0].cpu())
    except Exception as e:
        return f"Error generating text: {str(e)}"
def transcribe_audio(audio_file, user_prompt="Transcribe the following:",
                     max_new_tokens=128):
    """Transcribe an audio file with the model in ASR (speech-to-text) mode.

    Args:
        audio_file: Path of the audio file to transcribe, or ``None`` when
            nothing was uploaded.
        user_prompt: Instruction placed before the model's audio locator tag.
            A blank or ``None`` prompt falls back to the default prompt,
            since the UI presents this field as optional.
        max_new_tokens: Upper bound on generated tokens (default 128, the
            previously hard-coded value). Output is cut off at this length.

    Returns:
        The decoded transcript; ``"No audio file provided"`` when
        ``audio_file`` is ``None``; or a string starting with
        ``"Error transcribing audio: "`` if anything raises.
    """
    try:
        if audio_file is None:
            return "No audio file provided"
        prompt = (user_prompt or "").strip() or "Transcribe the following:"
        # ASR mode per the official documentation: the user message contains
        # the audio locator tag, and the audio path is attached under "audio".
        answer_ids = model.generate(
            prompts=[
                [{
                    "role": "user",
                    "content": f"{prompt} {model.audio_locator_tag}",
                    "audio": [audio_file],
                }]
            ],
            max_new_tokens=int(max_new_tokens),
        )
        # Decode from a CPU tensor, as in the official usage example.
        return model.tokenizer.ids_to_text(answer_ids[0].cpu())
    except Exception as e:
        return f"Error transcribing audio: {str(e)}"
def chat_interface(message, history, max_tokens, temperature, top_p):
    """Handle one chat turn for the Gradio Chatbot.

    Prior turns are flattened into a plain-text "User: ... / Assistant: ..."
    transcript and sent to ``generate_text`` as a single prompt.

    Args:
        message: The new user message.
        history: Prior turns as ``(user, assistant)`` pairs; ``None`` is
            treated as an empty history. The caller's list is not mutated.
        max_tokens, temperature, top_p: Passed through to ``generate_text``.

    Returns:
        ``("", new_history)``. The empty string clears the input textbox, and
        ``new_history`` is a fresh list with this turn appended. A blank
        message is ignored and the (copied) history is returned unchanged.
    """
    history = list(history or [])
    if not message or not message.strip():
        # Nothing to send; skip a slow, pointless model call.
        return "", history

    # Assemble the transcript with join instead of repeated string +=.
    turns = [
        f"User: {user_msg}\nAssistant: {bot_msg}\n"
        for user_msg, bot_msg in history
    ]
    turns.append(f"User: {message}\nAssistant: ")
    conversation = "".join(turns)

    response = generate_text(conversation, max_tokens, temperature, top_p)
    history.append((message, response))
    return "", history
| # Create Gradio interface | |
# Gradio UI: ASR, chat, single-shot generation and model-info tabs, plus the
# event wiring. Listeners are attached inside the Blocks context, which
# Gradio requires.
with gr.Blocks(title="NVIDIA Canary-Qwen-2.5B Chat") as demo:
    gr.HTML("""
    <div style="text-align: center;">
        <h1>🤖 NVIDIA Canary-Qwen-2.5B</h1>
        <p>Official NeMo implementation - Speech-to-Text & Text Generation</p>
        <p><strong>Capabilities:</strong> Audio Transcription + Text Chat</p>
    </div>
    """)

    # --- Audio transcription (ASR) tab ---
    with gr.Tab("🎤 Audio Transcription (ASR)"):
        with gr.Row():
            with gr.Column():
                audio_input = gr.Audio(
                    label="Upload Audio File (.wav or .flac)",
                    type="filepath",  # transcribe_audio receives a file path
                    format="wav",
                )
                asr_prompt = gr.Textbox(
                    label="Custom Prompt (optional)",
                    value="Transcribe the following:",
                    placeholder="Enter custom transcription prompt...",
                )
                transcribe_btn = gr.Button("🎤 Transcribe Audio", variant="primary")
                transcript_output = gr.Textbox(
                    label="Transcription Result",
                    lines=8,
                    max_lines=15,
                )
        # Example prompts fill the prompt textbox only.
        gr.Examples(
            examples=[
                ["Transcribe the following:"],
                ["Please transcribe this audio in detail:"],
                ["Convert this speech to text:"],
            ],
            inputs=[asr_prompt],
        )

    # --- Multi-turn text chat tab ---
    with gr.Tab("💬 Text Chat (LLM)"):
        with gr.Row():
            with gr.Column(scale=3):
                chatbot = gr.Chatbot(height=400)
                msg = gr.Textbox(label="Your message", placeholder="Type here...")
                with gr.Row():
                    submit_btn = gr.Button("Send", variant="primary")
                    clear_btn = gr.Button("Clear Chat")
            with gr.Column(scale=1):
                gr.Markdown("### ⚙️ Settings")
                max_tokens = gr.Slider(
                    minimum=10, maximum=500, value=200, step=10,
                    label="Max Tokens",
                )
                temperature = gr.Slider(
                    minimum=0.1, maximum=2.0, value=0.7, step=0.1,
                    label="Temperature",
                )
                top_p = gr.Slider(
                    minimum=0.1, maximum=1.0, value=0.9, step=0.05,
                    label="Top-p",
                )

    # --- Single-shot generation tab ---
    with gr.Tab("📝 Single Generation"):
        with gr.Column():
            prompt_input = gr.Textbox(
                label="Prompt",
                placeholder="Enter your prompt...",
                lines=5,
            )
            generate_btn = gr.Button("Generate", variant="primary")
            output_text = gr.Textbox(
                label="Generated Text",
                lines=10,
                max_lines=20,
            )
            with gr.Row():
                # Same ranges and steps as the chat tab's settings sliders.
                single_max_tokens = gr.Slider(
                    minimum=10, maximum=500, value=200, step=10,
                    label="Max Tokens",
                )
                single_temperature = gr.Slider(
                    minimum=0.1, maximum=2.0, value=0.7, step=0.1,
                    label="Temperature",
                )
                single_top_p = gr.Slider(
                    minimum=0.1, maximum=1.0, value=0.9, step=0.05,
                    label="Top-p",
                )
            # Example prompts sit in this tab, next to the textbox they fill
            # (mirroring the ASR tab's examples).
            gr.Examples(
                examples=[
                    ["Explain quantum computing in simple terms"],
                    ["Write a short story about AI"],
                    ["What are the benefits of renewable energy?"],
                    ["How do neural networks work?"],
                    ["Summarize the key points about machine learning"],
                ],
                inputs=[prompt_input],
            )

    # --- Static model information tab ---
    with gr.Tab("ℹ️ Model Info"):
        gr.Markdown("""
        ## NVIDIA Canary-Qwen-2.5B Model Information

        ### Capabilities:
        - 🎤 **Audio Transcription (ASR)**: Convert speech to text
        - 💬 **Text Generation (LLM)**: Chat and text completion
        - 🎯 **Multimodal**: Combines audio and text processing

        ### Model Details:
        - **Size**: 2.5 billion parameters
        - **Framework**: NVIDIA NeMo
        - **Audio Input**: 16kHz mono-channel .wav or .flac files
        - **Languages**: English

        ### Usage Tips:
        1. **For Audio**: Upload .wav or .flac files (16kHz recommended)
        2. **For Text**: Use natural language prompts
        3. **Custom Prompts**: You can modify transcription prompts
        4. **Parameters**: Adjust temperature and tokens for different outputs

        ### Official Documentation:
        - [Model Card](https://huggingface.co/nvidia/canary-qwen-2.5b)
        - [NVIDIA NeMo](https://github.com/NVIDIA/NeMo)
        """)

    # --- Event wiring ---
    transcribe_btn.click(
        transcribe_audio,
        inputs=[audio_input, asr_prompt],
        outputs=[transcript_output],
    )
    # Clicking "Send" and pressing Enter in the textbox both send a message.
    chat_inputs = [msg, chatbot, max_tokens, temperature, top_p]
    submit_btn.click(chat_interface, inputs=chat_inputs, outputs=[msg, chatbot])
    msg.submit(chat_interface, inputs=chat_inputs, outputs=[msg, chatbot])
    # Reset the conversation and clear the input box.
    clear_btn.click(lambda: ([], ""), outputs=[chatbot, msg])
    generate_btn.click(
        generate_text,
        inputs=[prompt_input, single_max_tokens, single_temperature, single_top_p],
        outputs=[output_text],
    )
if __name__ == "__main__":
    # Hugging Face Spaces sets SPACE_ID and does not support share links
    # (Gradio ignores share=True there and warns). Request a public share
    # tunnel only when running outside a Space, e.g. locally.
    demo.launch(share="SPACE_ID" not in os.environ)