import os
import gc
import sys
import time
import torch
import spaces
import torchaudio
import numpy as np
from scipy.signal import resample
from pyannote.audio import Pipeline
from dotenv import load_dotenv

load_dotenv()

from difflib import SequenceMatcher
from transformers import (
    Wav2Vec2ForSequenceClassification,
    AutoFeatureExtractor,
    Wav2Vec2ForCTC,
    AutoProcessor,
    AutoTokenizer,
    AutoModelForSeq2SeqLM,
)

import logging

logging.basicConfig(
    level=logging.INFO,
    format='%(asctime)s - %(name)s - %(levelname)s - %(message)s',
    handlers=[logging.StreamHandler(sys.stdout)]
)
logger = logging.getLogger(__name__)


class ChunkedTranscriber:
    def __init__(self, chunk_size=5, overlap=1, sample_rate=16000):
        self.chunk_size = chunk_size
        self.overlap = overlap
        self.sample_rate = sample_rate
        self.previous_text = ""
        self.previous_lang = None
        self.speaker_diarization_pipeline = self.load_speaker_diarization_pipeline()

    def load_speaker_diarization_pipeline(self):
        """
        Load the pre-trained speaker diarization pipeline from pyannote-audio.
        """
        pipeline = Pipeline.from_pretrained(
            "pyannote/speaker-diarization", use_auth_token=os.getenv("HF_TOKEN")
        )
        return pipeline

    def diarize_audio(self, audio_path):
        """
        Perform speaker diarization on the input audio.
        """
        diarization_result = self.speaker_diarization_pipeline({"uri": "audio", "audio": audio_path})
        return diarization_result

    def load_lid_mms(self):
        model_id = "facebook/mms-lid-256"
        processor = AutoFeatureExtractor.from_pretrained(model_id)
        model = Wav2Vec2ForSequenceClassification.from_pretrained(model_id)
        return processor, model

    def language_identification(self, model, processor, chunk, device="cuda"):
        inputs = processor(chunk, sampling_rate=16_000, return_tensors="pt")
        model.to(device)
        inputs = inputs.to(device)

        with torch.no_grad():
            outputs = model(**inputs).logits

        lang_id = torch.argmax(outputs, dim=-1)[0].item()
        detected_lang = model.config.id2label[lang_id]

        # Release GPU memory held by this model before the next chunk
        del model
        del inputs
        torch.cuda.empty_cache()
        gc.collect()
        return detected_lang

    def load_mms(self):
        model_id = "facebook/mms-1b-all"
        processor = AutoProcessor.from_pretrained(model_id)
        model = Wav2Vec2ForCTC.from_pretrained(model_id)
        return model, processor

    def mms_transcription(self, model, processor, chunk, device="cuda"):
        inputs = processor(chunk, sampling_rate=16_000, return_tensors="pt")
        model.to(device)
        inputs = inputs.to(device)

        with torch.no_grad():
            outputs = model(**inputs).logits

        ids = torch.argmax(outputs, dim=-1)[0]
        transcription = processor.decode(ids)

        del model
        del inputs
        torch.cuda.empty_cache()
        gc.collect()
        return transcription

    def load_T2T_translation_model(self):
        model_id = "facebook/nllb-200-distilled-600M"
        tokenizer = AutoTokenizer.from_pretrained(model_id)
        model = AutoModelForSeq2SeqLM.from_pretrained(model_id)
        return model, tokenizer

    def text2text_translation(self, translation_model, translation_tokenizer, transcript, device="cuda"):
        tokenized_inputs = translation_tokenizer(transcript, return_tensors='pt')
        translation_model.to(device)
        tokenized_inputs = tokenized_inputs.to(device)

        # Force English as the target language for NLLB
        translated_tokens = translation_model.generate(
            **tokenized_inputs,
            forced_bos_token_id=translation_tokenizer.convert_tokens_to_ids("eng_Latn"),
            max_length=100
        )

        del translation_model
        del tokenized_inputs
        torch.cuda.empty_cache()
        gc.collect()
        return translation_tokenizer.batch_decode(translated_tokens, skip_special_tokens=True)[0]

    def preprocess_audio(self, audio):
        """
        Create overlapping chunks with improved timing logic.
        """
        chunk_samples = int(self.chunk_size * self.sample_rate)
        overlap_samples = int(self.overlap * self.sample_rate)

        chunks_with_times = []
        start_idx = 0

        while start_idx < len(audio):
            end_idx = min(start_idx + chunk_samples, len(audio))

            # Pad the first chunk with a second of leading silence
            if start_idx == 0:
                chunk = audio[start_idx:end_idx]
                padding = torch.zeros(int(1 * self.sample_rate))
                chunk = torch.cat([padding, chunk])
            else:
                # Include overlap from the previous chunk
                actual_start = max(0, start_idx - overlap_samples)
                chunk = audio[actual_start:end_idx]

            # Pad if necessary
            if len(chunk) < chunk_samples:
                chunk = torch.nn.functional.pad(chunk, (0, chunk_samples - len(chunk)))

            # Adjust time ranges to account for overlaps
            chunk_start_time = max(0, (start_idx / self.sample_rate) - self.overlap)
            chunk_end_time = min((end_idx / self.sample_rate) + self.overlap, len(audio) / self.sample_rate)

            chunks_with_times.append({
                'chunk': chunk,
                'start_time': start_idx / self.sample_rate,
                'end_time': end_idx / self.sample_rate,
                'transcribe_start': chunk_start_time,
                'transcribe_end': chunk_end_time
            })

            # Advance by chunk size minus overlap for better continuity
            start_idx += (chunk_samples - overlap_samples)

        return chunks_with_times
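
    # Worked example of the chunking arithmetic above (illustrative, using the
    # default chunk_size=5, overlap=1, sample_rate=16000): chunk_samples is
    # 80000 and overlap_samples is 16000, so start_idx advances by 64000
    # samples (4 s) per iteration. A 10 s clip yields chunks starting at
    # 0 s, 4 s, and 8 s, and each chunk after the first is extended 1 s
    # backwards to include the tail of its predecessor.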

    def merge_close_segments(self, results):
        """
        Merge segments that are close in time and share the same language.
        """
        if not results:
            return results

        merged = []
        current = results[0]

        for next_segment in results[1:]:
            # Skip empty segments
            if not next_segment['text'].strip():
                continue

            # If segments are in the same language and close in time
            if (current['detected_language'] == next_segment['detected_language'] and
                    abs(next_segment['start_time'] - current['end_time']) <= self.overlap):
                # Merge the segments
                current['text'] = current['text'] + ' ' + next_segment['text']
                current['end_time'] = next_segment['end_time']
                if 'translated' in current and 'translated' in next_segment:
                    current['translated'] = current['translated'] + ' ' + next_segment['translated']
            else:
                if current['text'].strip():  # Only keep non-empty segments
                    merged.append(current)
                current = next_segment

        if current['text'].strip():  # Add the last segment if non-empty
            merged.append(current)

        return merged
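
    # Example of the merge rule above (illustrative values): two results such as
    #   {'text': 'hello there', 'detected_language': 'eng', 'start_time': 0.0, 'end_time': 5.0}
    #   {'text': 'how are you', 'detected_language': 'eng', 'start_time': 5.5, 'end_time': 9.0}
    # share a language and their 0.5 s gap is within self.overlap (1 s), so
    # they collapse into one 0.0-9.0 segment with text 'hello there how are you'.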

    def clean_overlapping_text(self, current_text, prev_text, current_lang, prev_lang, min_overlap=3):
        """
        Language-aware overlap cleaning with better sentence-boundary handling.
        """
        if not prev_text or not current_text:
            return current_text

        # If the languages differ, don't try to merge
        if prev_lang and current_lang and prev_lang != current_lang:
            return current_text

        # Split into words
        prev_words = prev_text.split()
        curr_words = current_text.split()

        if len(prev_words) < 2 or len(curr_words) < 2:
            return current_text

        # Find matching sequences between prev_text and current_text
        matcher = SequenceMatcher(None, prev_words, curr_words)
        matches = list(matcher.get_matching_blocks())

        # Look for the largest overlap anchored at the start of the current text
        best_overlap = 0
        for match in matches:
            if match.b == 0 and match.size >= min_overlap and match.size > best_overlap:
                best_overlap = match.size

        if best_overlap > 0:
            # Remove the overlapping prefix while preserving sentence integrity
            cleaned_words = curr_words[best_overlap:]
            if not cleaned_words:  # Everything was overlapping
                return ""
            return ' '.join(cleaned_words).strip()

        return current_text
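
    # Example of the overlap cleaning above (illustrative strings): with
    # prev_text 'we went to the market today' and current_text
    # 'the market today was crowded', SequenceMatcher reports a match of
    # size 3 ('the market today') at b == 0, so best_overlap is 3 and the
    # method returns 'was crowded'.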

    def process_chunk(self, chunk_data, mms_model, mms_processor, translation_model=None, translation_tokenizer=None):
        """
        Process a single chunk with improved language handling.
        """
        device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
        try:
            # Language detection
            lid_processor, lid_model = self.load_lid_mms()
            lid_lang = self.language_identification(lid_model, lid_processor, chunk_data['chunk'])

            # Configure the processor and model for the detected language
            mms_processor.tokenizer.set_target_lang(lid_lang)
            mms_model.load_adapter(lid_lang)

            # Transcribe
            inputs = mms_processor(chunk_data['chunk'], sampling_rate=self.sample_rate, return_tensors="pt")
            inputs = inputs.to(device)
            mms_model = mms_model.to(device)

            with torch.no_grad():
                outputs = mms_model(**inputs).logits

            ids = torch.argmax(outputs, dim=-1)[0]
            transcription = mms_processor.decode(ids)

            # Clean overlapping text with language awareness
            cleaned_transcription = self.clean_overlapping_text(
                transcription,
                self.previous_text,
                lid_lang,
                self.previous_lang,
                min_overlap=3
            )

            # Update previous state
            self.previous_text = transcription
            self.previous_lang = lid_lang

            if not cleaned_transcription.strip():
                return None

            result = {
                'start_time': chunk_data['start_time'],
                'end_time': chunk_data['end_time'],
                'text': cleaned_transcription,
                'detected_language': lid_lang
            }

            # Handle translation
            if translation_model and translation_tokenizer and cleaned_transcription.strip():
                translation = self.text2text_translation(
                    translation_model,
                    translation_tokenizer,
                    cleaned_transcription
                )
                result['translated'] = translation

            return result

        except Exception as e:
            logger.error(f"Error processing chunk: {e}")
            return None
        finally:
            torch.cuda.empty_cache()
            gc.collect()

    def translate_text(self, text, translation_model, translation_tokenizer, device):
        """
        Translate cleaned text using the provided translation model.
        """
        tokenized_inputs = translation_tokenizer(text, return_tensors='pt')
        tokenized_inputs = tokenized_inputs.to(device)
        translation_model = translation_model.to(device)

        translated_tokens = translation_model.generate(
            **tokenized_inputs,
            forced_bos_token_id=translation_tokenizer.convert_tokens_to_ids("eng_Latn"),
            max_length=100
        )
        translation = translation_tokenizer.batch_decode(
            translated_tokens,
            skip_special_tokens=True
        )[0]

        del translation_model
        del tokenized_inputs
        torch.cuda.empty_cache()
        gc.collect()
        return translation

    def transcribe_audio(self, audio_path, translate=False):
        """
        Main transcription function with improved segment merging.
        """
        # Perform speaker diarization
        diarization_result = self.diarize_audio(audio_path)

        # Extract speaker segments
        speaker_segments = []
        for turn, _, speaker in diarization_result.itertracks(yield_label=True):
            speaker_segments.append({
                'start_time': turn.start,
                'end_time': turn.end,
                'speaker': speaker
            })

        audio = self.load_audio(audio_path)
        chunks = self.preprocess_audio(audio)
        mms_model, mms_processor = self.load_mms()

        translation_model, translation_tokenizer = None, None
        if translate:
            translation_model, translation_tokenizer = self.load_T2T_translation_model()

        # Process chunks
        results = []
        for chunk_data in chunks:
            result = self.process_chunk(
                chunk_data,
                mms_model,
                mms_processor,
                translation_model,
                translation_tokenizer
            )
            if result:
                # Attach the speaker whose diarization turn contains this chunk's start
                for segment in speaker_segments:
                    if int(segment['start_time']) <= int(chunk_data['start_time']) < int(segment['end_time']):
                        result['speaker'] = segment['speaker']
                        break
                results.append(result)

        # Merge close segments and clean up
        merged_results = self.merge_close_segments(results)

        _translation = ""
        _output = ""
        for res in merged_results:
            # Guard against segments without a speaker match or translation
            speaker = res.get('speaker', 'SPEAKER_00')
            translated = res.get('translated', '')
            _translation += translated + ' '
            _output += (
                f"{res['start_time']:.2f}-{res['end_time']:.2f} - "
                f"Speaker: {speaker.split('_')[1]} - "
                f"Language: {res['detected_language']}\n"
                f" Text: {res['text']}\n"
                f" Translation: {translated}\n\n"
            )

        _translation = _translation.strip()
        logger.info(f"\n\n TRANSLATION: {_translation}")
        return _translation, _output

    def load_audio(self, audio_path):
        """
        Load and preprocess an audio file.
        """
        waveform, sample_rate = torchaudio.load(audio_path)

        # Convert to mono if stereo
        if waveform.shape[0] > 1:
            waveform = torch.mean(waveform, dim=0)
        else:
            waveform = waveform.squeeze(0)

        # Resample if necessary
        if sample_rate != self.sample_rate:
            resampler = torchaudio.transforms.Resample(
                orig_freq=sample_rate,
                new_freq=self.sample_rate
            )
            waveform = resampler(waveform)

        return waveform.float()
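

# Minimal usage sketch (illustrative, not the Space's Gradio entry point;
# "sample.wav" is a placeholder path). Assumes HF_TOKEN is available in the
# environment or a .env file, since the pyannote pipeline is gated.
if __name__ == "__main__":
    transcriber = ChunkedTranscriber(chunk_size=5, overlap=1)
    translation, output = transcriber.transcribe_audio("sample.wav", translate=True)
    print(output)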