Spaces:
Build error
Build error
| import os | |
| import json | |
| import gradio as gr | |
| import torch | |
| from transformers import pipeline, AutoTokenizer, AutoModelForSeq2SeqLM | |
| import logging | |
| import traceback | |
| import sys | |
| from audio_processing import AudioProcessor | |
| import spaces | |
| from chunkedTranscriber import ChunkedTranscriber | |
def _build_logger() -> logging.Logger:
    """Configure the root logger to print INFO+ records to stdout; return this module's logger.

    ``logging.basicConfig`` is a no-op when the root logger already has
    handlers, so an embedding host's logging setup is left alone.
    """
    stdout_handler = logging.StreamHandler(sys.stdout)
    logging.basicConfig(
        handlers=[stdout_handler],
        format="%(asctime)s - %(name)s - %(levelname)s - %(message)s",
        level=logging.INFO,
    )
    return logging.getLogger(__name__)


logger = _build_logger()
# Module-level cache for the Q&A pipeline; populated on the first successful load.
_QA_PIPELINE = None


def load_qa_model():
    """Return the shared text-generation pipeline used for question answering.

    Loads ``hugging-quants/Meta-Llama-3.1-8B-Instruct-AWQ-INT4`` with
    ``device_map="auto"``, authenticating with the ``HF_TOKEN`` environment
    variable. A successfully loaded pipeline is cached at module level so the
    8B model is not reloaded for every question. A failed load is not cached,
    so the next call retries.

    Returns:
        The pipeline, or ``None`` if loading failed (the error is logged with
        its traceback).
    """
    global _QA_PIPELINE
    if _QA_PIPELINE is None:
        model_id = "hugging-quants/Meta-Llama-3.1-8B-Instruct-AWQ-INT4"
        try:
            _QA_PIPELINE = pipeline(
                "text-generation",
                model=model_id,
                # NOTE(review): AWQ checkpoints are commonly run in float16;
                # bfloat16 is kept as-is here — confirm for this model.
                model_kwargs={"torch_dtype": torch.bfloat16},
                device_map="auto",
                # ``token`` supersedes the deprecated ``use_auth_token`` argument.
                token=os.getenv("HF_TOKEN"),
            )
        except Exception as e:
            logger.exception("Failed to load Q&A model: %s", e)
            return None
    return _QA_PIPELINE
# Module-level cache for the summarization pipeline; populated on the first successful load.
_SUMMARIZER = None


def load_summarization_model():
    """Return the shared summarization pipeline, loading it on first use.

    Uses ``sshleifer/distilbart-cnn-12-6``. It runs on device 0 when CUDA is
    available and on CPU (``device=-1``) otherwise. A successfully loaded
    pipeline is cached at module level so each "Summarize" click does not
    reload the model. A failed load is not cached, so the next call retries.

    Returns:
        The pipeline, or ``None`` if loading failed (the error is logged with
        its traceback).
    """
    global _SUMMARIZER
    if _SUMMARIZER is None:
        try:
            _SUMMARIZER = pipeline(
                "summarization",
                model="sshleifer/distilbart-cnn-12-6",
                device=0 if torch.cuda.is_available() else -1,
            )
        except Exception as e:
            logger.exception("Failed to load summarization model: %s", e)
            return None
    return _SUMMARIZER
def process_audio(audio_file, translate=False):
    """Transcribe an uploaded audio file and format the result for the UI.

    Args:
        audio_file: Path to the audio file. The ``gr.Audio`` input is created
            with ``type="filepath"``. It is falsy when nothing was uploaded.
        translate: Forwarded to ``ChunkedTranscriber.transcribe_audio`` so the
            "Enable Translation" checkbox takes effect.

    Returns:
        A ``(transcription, full_text)`` pair, one value per output component
        wired to the "Process Audio" button.

        - ``transcription`` is the transcriber's result serialized as JSON;
          ``summarize_text`` reads this textbox back with ``json.loads``.
        - ``full_text`` joins the per-chunk text into one readable string.

    Raises:
        gr.Error: If no audio was provided or transcription fails, so the
            message is shown in the UI.
    """
    if not audio_file:
        raise gr.Error("Please upload or record an audio file first.")

    try:
        # chunk_size / overlap units are defined by ChunkedTranscriber
        # (presumably seconds — TODO confirm).
        transcriber = ChunkedTranscriber(chunk_size=5, overlap=1)
        results = transcriber.transcribe_audio(audio_file, translate=translate)
    except Exception as e:
        logger.exception("Audio processing failed")
        raise gr.Error(f"Processing failed: {e}") from e

    # NOTE(review): results is presumably a list of per-chunk dicts carrying a
    # "translated" key (the key summarize_text reads) and possibly "text" for
    # untranslated chunks — confirm against ChunkedTranscriber.
    # ensure_ascii=False keeps non-Latin scripts readable in the textbox.
    # default=str stops non-JSON values (e.g. numpy scalars) from aborting
    # serialization; the output is only displayed and re-read as text.
    transcription = json.dumps(results, ensure_ascii=False, indent=2, default=str)

    segments = results if isinstance(results, list) else []
    pieces = (
        str(seg.get("translated") or seg.get("text") or "").strip()
        for seg in segments
        if isinstance(seg, dict)
    )
    # Join with spaces; plain concatenation would fuse the last word of one
    # chunk with the first word of the next.
    full_text = " ".join(piece for piece in pieces if piece)
    return transcription, full_text
def summarize_text(text):
    """Summarize the transcription shown in the "Transcription/Translation" box.

    Args:
        text: Contents of that textbox. This is normally a JSON list of
            per-chunk dicts. Input that is not valid JSON is summarized as
            plain text.

    Returns:
        The summary, or a human-readable message in any of these cases:
        there is nothing to summarize, the model could not be loaded, or
        summarization failed.
    """
    try:
        data = json.loads(text)
    except (TypeError, ValueError):
        # Not JSON, or no input at all: use the raw value as the text.
        data = text

    if isinstance(data, list):
        # NOTE(review): items are assumed to be dicts with a "translated" key;
        # "text" is a hedged fallback for untranslated chunks — confirm.
        pieces = (
            str(item.get("translated") or item.get("text") or "").strip()
            for item in data
            if isinstance(item, dict)
        )
        # Space-join so words at chunk boundaries are not fused together.
        full_text = " ".join(piece for piece in pieces if piece)
    else:
        full_text = str(data or "").strip()

    # Bail out before loading the model when there is nothing to summarize.
    if not full_text:
        return "There is no text to summarize. Please process an audio file first."

    try:
        summarizer = load_summarization_model()
        if summarizer is None:
            return "Summarization model could not be loaded."
        logger.info("Successfully loaded summarization Model")
        logger.info("\n\nWorking on text:\n%s", full_text)
        # truncation=True cuts inputs longer than the model's maximum input
        # length instead of letting the forward pass fail on long transcripts.
        result = summarizer(
            full_text,
            max_length=150,
            min_length=50,
            do_sample=False,
            truncation=True,
        )
        return result[0]["summary_text"]
    except Exception as e:
        logger.exception("Summarization failed: %s", e)
        return "Error occurred during summarization."
def answer_question(context, question):
    """Answer a free-form question about the processed transcription.

    Args:
        context: Contents of the "Transcription/Translation" textbox. This is
            normally a JSON list of per-chunk dicts. Input that is not valid
            JSON is used verbatim as the context.
        question: The user's question.

    Returns:
        The model's answer text, or a human-readable message in any of these
        cases: the question or context is missing, the model could not be
        loaded, or generation failed.
    """
    # Validate input before paying for the (large) model load.
    if not question or not str(question).strip():
        return "Please enter your Question"

    try:
        data = json.loads(context)
    except (TypeError, ValueError):
        # Not JSON (or already a Python object): use it as-is.
        data = context

    if isinstance(data, list):
        # NOTE(review): items are assumed to be dicts with a "translated" key;
        # "text" is a hedged fallback for untranslated chunks — confirm.
        pieces = (
            str(item.get("translated") or item.get("text") or "").strip()
            for item in data
            if isinstance(item, dict)
        )
        context_text = " ".join(piece for piece in pieces if piece)
    else:
        context_text = str(data or "").strip()

    if not context_text:
        return "There is no transcription to ask about. Please process an audio file first."

    try:
        qa_pipeline = load_qa_model()
        if qa_pipeline is None:
            return "Q&A model could not be loaded."
        messages = [
            {"role": "system", "content": "You are a helpful assistant who can answer questions based on the given context."},
            {"role": "user", "content": f"Context: {context_text}\n\nQuestion: {question}"},
        ]
        generated = qa_pipeline(messages, max_new_tokens=256)[0]["generated_text"]
        if isinstance(generated, list):
            # Chat-style input: per the transformers docs, the pipeline returns
            # the whole conversation with the new assistant message last.
            # Surface only that reply.
            generated = generated[-1]["content"] if generated else ""
        return generated
    except Exception as e:
        logger.exception("Q&A failed: %s", e)
        return f"Error occurred during Q&A process: {str(e)}"
# ---------------------------------------------------------------------------
# Gradio user interface
# ---------------------------------------------------------------------------
# Evaluated once and reused by the "System Information" panel below.
_cuda_available = torch.cuda.is_available()

with gr.Blocks() as iface:
    gr.Markdown("# Automatic Speech Recognition for Indic Languages")

    # Upper row: audio input and controls (left), transcription results (right).
    with gr.Row():
        with gr.Column():
            audio_input = gr.Audio(type="filepath")
            translate_checkbox = gr.Checkbox(label="Enable Translation")
            process_button = gr.Button("Process Audio")
        with gr.Column():
            full_text_output = gr.Textbox(label="Full Text", lines=5)
            translation_output = gr.Textbox(
                label="Transcription/Translation",
                lines=10,
            )

    # Lower row: summarization (left) and question answering (right).
    with gr.Row():
        with gr.Column():
            summarize_button = gr.Button("Summarize")
            summary_output = gr.Textbox(label="Summary", lines=3)
        with gr.Column():
            question_input = gr.Textbox(
                label="Ask a question about the transcription",
            )
            answer_button = gr.Button("Get Answer")
            answer_output = gr.Textbox(label="Answer", lines=3)

    # Event wiring. The output order below fixes the order of the values
    # process_audio must return.
    process_button.click(
        fn=process_audio,
        inputs=[audio_input, translate_checkbox],
        outputs=[translation_output, full_text_output],
    )
    # Summarization and Q&A both read the "Transcription/Translation" box.
    summarize_button.click(
        fn=summarize_text,
        inputs=[translation_output],
        outputs=[summary_output],
    )
    answer_button.click(
        fn=answer_question,
        inputs=[translation_output, question_input],
        outputs=[answer_output],
    )

    # Static system/feature summary; kept unindented so Markdown does not
    # render it as a code block.
    gr.Markdown(f"""
## System Information
- Device: {"CUDA" if _cuda_available else "CPU"}
- CUDA Available: {"Yes" if _cuda_available else "No"}
## Features
- Automatic language detection
- High-quality transcription using MMS
- Optional translation to English
- Text summarization
- Question answering
""")

if __name__ == "__main__":
    iface.launch(server_port=None)