import gc
import torch
import torchaudio
import numpy as np
from transformers import (
    Wav2Vec2ForSequenceClassification,
    AutoFeatureExtractor,
    Wav2Vec2ForCTC,
    AutoProcessor,
    AutoTokenizer,
    AutoModelForSeq2SeqLM
)
import spaces
import logging
from difflib import SequenceMatcher

logging.basicConfig(level=logging.INFO, format='%(asctime)s - %(name)s - %(levelname)s - %(message)s')
logger = logging.getLogger(__name__)


class AudioProcessor:
    def __init__(self, chunk_size=5, overlap=1, sample_rate=16000):
        self.chunk_size = chunk_size
        self.overlap = overlap
        self.sample_rate = sample_rate
        self.previous_text = ""
        self.previous_lang = None
        self.device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

    def load_models(self):
        """Load all required models"""
        logger.info("Loading MMS models...")

        # Language identification model
        lid_processor = AutoFeatureExtractor.from_pretrained("facebook/mms-lid-256")
        lid_model = Wav2Vec2ForSequenceClassification.from_pretrained("facebook/mms-lid-256")

        # Transcription model
        mms_processor = AutoProcessor.from_pretrained("facebook/mms-1b-all")
        mms_model = Wav2Vec2ForCTC.from_pretrained("facebook/mms-1b-all")

        # Translation model
        translation_tokenizer = AutoTokenizer.from_pretrained("facebook/nllb-200-distilled-600M")
        translation_model = AutoModelForSeq2SeqLM.from_pretrained("facebook/nllb-200-distilled-600M")

        return {
            'lid': (lid_model, lid_processor),
            'mms': (mms_model, mms_processor),
            'translation': (translation_model, translation_tokenizer)
        }

    def identify_language(self, audio_chunk, models):
        """Identify language of audio chunk"""
        lid_model, lid_processor = models['lid']
        inputs = lid_processor(audio_chunk, sampling_rate=16000, return_tensors="pt")

        lid_model.to(self.device)
        with torch.no_grad():
            outputs = lid_model(inputs.input_values.to(self.device)).logits

        lang_id = torch.argmax(outputs, dim=-1)[0].item()
        detected_lang = lid_model.config.id2label[lang_id]
        return detected_lang
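
    # Note: facebook/mms-lid-256 emits ISO 639-3 codes (e.g. "eng", "hin"),
    # the same code format the MMS ASR adapters expect, so the detected label
    # can be passed straight to transcribe_chunk below.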

    def transcribe_chunk(self, audio_chunk, language, models):
        """Transcribe audio chunk"""
        mms_model, mms_processor = models['mms']
        mms_processor.tokenizer.set_target_lang(language)
        mms_model.load_adapter(language)
        mms_model.to(self.device)

        inputs = mms_processor(audio_chunk, sampling_rate=16000, return_tensors="pt")
        with torch.no_grad():
            outputs = mms_model(inputs.input_values.to(self.device)).logits

        ids = torch.argmax(outputs, dim=-1)[0]
        transcription = mms_processor.decode(ids)
        return transcription

    def translate_text(self, text, models):
        """Translate text to English"""
        translation_model, translation_tokenizer = models['translation']

        inputs = translation_tokenizer(text, return_tensors="pt")
        inputs = inputs.to(self.device)
        translation_model.to(self.device)

        with torch.no_grad():
            outputs = translation_model.generate(
                **inputs,
                forced_bos_token_id=translation_tokenizer.convert_tokens_to_ids("eng_Latn"),
                max_length=100
            )

        translation = translation_tokenizer.batch_decode(outputs, skip_special_tokens=True)[0]
        return translation
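
    # Note: the NLLB tokenizer's source language is left at its default here.
    # Mapping the detected MMS code to an NLLB code (e.g. "hin" -> "hin_Deva")
    # and setting tokenizer.src_lang before encoding would likely improve
    # translation quality; that mapping is not implemented in this Space.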

    def preprocess_audio(self, audio):
        """
        Create overlapping chunks with improved timing logic
        """
        chunk_samples = int(self.chunk_size * self.sample_rate)
        overlap_samples = int(self.overlap * self.sample_rate)

        chunks_with_times = []
        start_idx = 0
        while start_idx < len(audio):
            end_idx = min(start_idx + chunk_samples, len(audio))

            # Add padding for first chunk
            if start_idx == 0:
                chunk = audio[start_idx:end_idx]
                padding = torch.zeros(int(1 * self.sample_rate))
                chunk = torch.cat([padding, chunk])
            else:
                # Include overlap from previous chunk
                actual_start = max(0, start_idx - overlap_samples)
                chunk = audio[actual_start:end_idx]

            # Pad if necessary
            if len(chunk) < chunk_samples:
                chunk = torch.nn.functional.pad(chunk, (0, chunk_samples - len(chunk)))

            # Adjust time ranges to account for overlaps
            chunk_start_time = max(0, (start_idx / self.sample_rate) - self.overlap)
            chunk_end_time = min((end_idx / self.sample_rate) + self.overlap, len(audio) / self.sample_rate)

            chunks_with_times.append({
                'chunk': chunk,
                'start_time': start_idx / self.sample_rate,
                'end_time': end_idx / self.sample_rate,
                'transcribe_start': chunk_start_time,
                'transcribe_end': chunk_end_time
            })

            # Move to next chunk with smaller step size for better continuity
            start_idx += (chunk_samples - overlap_samples)

        return chunks_with_times
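
    # With the defaults (chunk_size=5, overlap=1, 16 kHz), the stride between
    # chunks is 4 s, so for a 12 s clip preprocess_audio yields chunks covering
    # roughly 0-5 s, 4-9 s and 8-12 s (illustrative figures, not from the
    # original Space).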

    def process_audio(self, audio_path, translate=False):
        """Main processing function"""
        try:
            # Load audio
            waveform, sample_rate = torchaudio.load(audio_path)
            if waveform.shape[0] > 1:
                waveform = torch.mean(waveform, dim=0)
            else:
                waveform = waveform.squeeze(0)

            # Resample if necessary
            if sample_rate != self.sample_rate:
                resampler = torchaudio.transforms.Resample(
                    orig_freq=sample_rate,
                    new_freq=self.sample_rate
                )
                waveform = resampler(waveform)

            # Load models
            models = self.load_models()

            # Process in chunks
            chunk_samples = int(self.chunk_size * self.sample_rate)
            overlap_samples = int(self.overlap * self.sample_rate)

            segments = []
            language_segments = []

            for i in range(0, len(waveform), chunk_samples - overlap_samples):
                chunk = waveform[i:i + chunk_samples]
                actual_len = len(chunk)  # length before padding, for accurate timestamps
                if actual_len < chunk_samples:
                    chunk = torch.nn.functional.pad(chunk, (0, chunk_samples - actual_len))

                # Process chunk
                start_time = i / self.sample_rate
                end_time = (i + actual_len) / self.sample_rate

                # Identify language
                language = self.identify_language(chunk, models)

                # Record language segment
                language_segments.append({
                    "language": language,
                    "start": start_time,
                    "end": end_time
                })

                # Transcribe
                transcription = self.transcribe_chunk(chunk, language, models)

                segment = {
                    "start": start_time,
                    "end": end_time,
                    "language": language,
                    "text": transcription,
                    "speaker": "Speaker"  # Simple speaker assignment
                }

                if translate:
                    translation = self.translate_text(transcription, models)
                    segment["translated"] = translation

                segments.append(segment)

                # Clean up GPU memory
                torch.cuda.empty_cache()
                gc.collect()

            # Merge nearby segments
            merged_segments = self.merge_segments(segments)
            return language_segments, merged_segments

        except Exception as e:
            logger.error(f"Error processing audio: {str(e)}")
            raise

    def merge_segments(self, segments, time_threshold=0.5, similarity_threshold=0.7):
        """Merge similar nearby segments"""
        if not segments:
            return segments

        merged = []
        current = segments[0]

        for next_segment in segments[1:]:
            if (next_segment['start'] - current['end'] <= time_threshold and
                    current['language'] == next_segment['language']):
                # Check text similarity
                matcher = SequenceMatcher(None, current['text'], next_segment['text'])
                similarity = matcher.ratio()

                if similarity > similarity_threshold:
                    # Merge segments
                    current['end'] = next_segment['end']
                    current['text'] = current['text'] + ' ' + next_segment['text']
                    if 'translated' in current and 'translated' in next_segment:
                        current['translated'] = current['translated'] + ' ' + next_segment['translated']
                else:
                    merged.append(current)
                    current = next_segment
            else:
                merged.append(current)
                current = next_segment

        merged.append(current)
        return merged
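

# Example usage (a minimal sketch, not part of the original Space): assumes a
# local audio recording at the hypothetical path "sample.wav".
if __name__ == "__main__":
    processor = AudioProcessor(chunk_size=5, overlap=1, sample_rate=16000)
    language_segments, segments = processor.process_audio("sample.wav", translate=True)
    for seg in segments:
        line = f"[{seg['start']:.1f}-{seg['end']:.1f}s] ({seg['language']}) {seg['text']}"
        if "translated" in seg:
            line += f" -> {seg['translated']}"
        print(line)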