|
|
from transformers import pipeline, AutoTokenizer, AutoModelForTokenClassification
|
|
|
import gradio as gr
|
|
|
import numpy as np
|
|
|
import scipy.io.wavfile
|
|
|
import tempfile
|
|
|
import os
|
|
|
from transformers import VitsModel, AutoTokenizer
|
|
|
import torch
|
|
|
import re
|
|
|
import traceback
|
|
|
|
|
|
print("Starting application...")

# Shared model handles. load_models() assigns them when the module is
# imported and resets them to None if a model fails to load, which is why
# the functions that use them check for None first.
punct_pipe = model = tokenizer = None
|
|
|
|
|
|
PUNCTUATION_MODEL_ID = "oliverguhr/fullstop-punctuation-multilang-large"
TTS_MODEL_ID = "facebook/mms-tts-kmr-script_latin"


def load_models(punctuation_model_id=PUNCTUATION_MODEL_ID, tts_model_id=TTS_MODEL_ID):
    """Load the punctuation pipeline and the TTS model/tokenizer into module globals.

    Each model is loaded independently and best-effort. On failure the error
    is printed and the corresponding globals are set to ``None``, so startup
    continues instead of crashing.

    Args:
        punctuation_model_id: Model id (or local path) of the
            token-classification model wrapped in ``punct_pipe``.
        tts_model_id: Model id (or local path) of the VITS text-to-speech
            checkpoint. The same id is used to load its tokenizer.
    """
    global punct_pipe, model, tokenizer

    print("Loading punctuation model...")
    # NOTE(review): nothing in this module appears to call restore_punctuation(),
    # so this download may be unnecessary -- confirm before removing it.
    try:
        punct_tokenizer = AutoTokenizer.from_pretrained(punctuation_model_id)
        punct_model = AutoModelForTokenClassification.from_pretrained(punctuation_model_id)
        punct_pipe = pipeline(
            "token-classification",
            model=punct_model,
            tokenizer=punct_tokenizer,
            aggregation_strategy="simple",
        )
        print("✓ Punctuation model loaded successfully")
    except Exception as e:
        print(f"✗ Error loading punctuation model: {e}")
        punct_pipe = None

    print("Loading TTS model...")
    try:
        # Load into locals first so the globals are only ever published as a
        # matching pair: both set, or both None.
        tts_model = VitsModel.from_pretrained(tts_model_id)
        tts_tokenizer = AutoTokenizer.from_pretrained(tts_model_id)
    except Exception as e:
        print(f"✗ Error loading TTS model: {e}")
        model = None
        tokenizer = None
    else:
        model, tokenizer = tts_model, tts_tokenizer
        print("✓ TTS model loaded successfully")


# Load eagerly at import time so that handling a request never triggers
# model loading.
load_models()
|
|
|
|
|
|
|
|
|
# Kurmanji words for the numbers the text front-end spells out before TTS.
num2word = {
    "0": "sifir", "1": "yek", "2": "du", "3": "sê", "4": "çar", "5": "pênc",
    "6": "şeş", "7": "heft", "8": "heşt", "9": "neh", "10": "deh"
}


def replace_numbers_with_words(text):
    """Replace standalone integers in *text* with their Kurmanji words.

    A "standalone" number is a run of digits bounded by word boundaries, so
    ``"5km"`` and ``"a3"`` are left untouched. Before the lookup the digits
    are canonicalised:

    - any Unicode decimal digit (e.g. Arabic-Indic ``٣``) becomes its ASCII
      value;
    - leading zeros are dropped, so ``"05"`` reads as ``"pênc"``.

    Numbers with no entry in ``num2word`` (currently anything above 10) are
    returned unchanged.

    NOTE(review): digits are presumably absent from the TTS vocabulary, so
    unmapped numbers may be silently dropped from the audio -- consider
    extending ``num2word``.

    Args:
        text: Input string.

    Returns:
        *text* with every known number replaced by its word.
    """
    def repl(match):
        num = match.group()
        # In a str pattern \d matches any Unicode decimal digit, and int()
        # accepts each of those. Converting one character at a time also
        # sidesteps int()'s digit-count limit on very long digit runs.
        key = "".join(str(int(ch)) for ch in num).lstrip("0") or "0"
        return num2word.get(key, num)

    return re.sub(r'\b\d+\b', repl, text)
|
|
|
|
|
|
# Suffix appended after a word group for each label the punctuation model can
# emit. The oliverguhr fullstop-punctuation models label their classes with the
# punctuation character itself ("0" meaning "no punctuation"). The spelled-out
# "PERIOD"/"COMMA" labels are still accepted for models that use those names.
_PUNCT_SUFFIX = {
    ".": ".", ",": ",", "?": "?", ":": ":", "-": "-",
    "PERIOD": ".", "COMMA": ",",
}


def restore_punctuation(text):
    """Insert the punctuation predicted by ``punct_pipe`` into *text*.

    Runs the token-classification pipeline and rebuilds the text from its
    aggregated word groups. After each group it appends the punctuation mark
    its label stands for (see ``_PUNCT_SUFFIX``); unknown labels add nothing.

    Best-effort: *text* is returned unchanged when the pipeline is not
    loaded, when the input is blank, or when the pipeline raises (the error
    is printed).

    Args:
        text: Unpunctuated input text.

    Returns:
        The punctuated text with word groups separated by single spaces.
    """
    if punct_pipe is None:
        print("Punctuation model not available, skipping...")
        return text

    # Nothing to punctuate; skip the model call entirely.
    if not text or not text.strip():
        return text

    try:
        results = punct_pipe(text)
        pieces = []
        for group in results:
            word = group["word"].strip()
            # aggregation_strategy="simple" yields "entity_group"; fall back to
            # the per-token "entity" key in case aggregation is turned off.
            label = group.get("entity_group", group.get("entity", ""))
            piece = word + _PUNCT_SUFFIX.get(label, "")
            if piece:
                pieces.append(piece)
        return " ".join(pieces).strip()
    except Exception as e:
        print(f"Punctuation error: {e}")
        return text
|
|
|
|
|
|
def text_to_speech(text):
    """Synthesise Kurmanji speech for *text* and return the path to a WAV file.

    Steps:

    1. Strip the input.
    2. Spell out small numbers with ``replace_numbers_with_words``.
    3. Tokenize with the global ``tokenizer``.
    4. Run the global VITS ``model`` under ``torch.no_grad()``.
    5. Write the squeezed waveform to a fresh ``.wav`` temp file at
       ``model.config.sampling_rate``.

    Progress is printed at every step.

    Args:
        text: Text typed by the user (may be ``None`` or blank).

    Returns:
        The temp-file path on success. ``None`` when the input is blank, when
        the model/tokenizer globals are not loaded, or when any step raises
        (the traceback is printed in that case).
    """
    print("=== TTS Function Called ===")
    print(f"Input text: '{text}'")

    try:
        # Pick the first reason synthesis cannot proceed, if any.
        if not text or text.strip() == "":
            problem = "Please enter some text"
        elif model is None or tokenizer is None:
            problem = "TTS model not loaded properly"
        else:
            problem = None
        if problem is not None:
            print(f"Error: {problem}")
            return None

        print("Processing text...")
        cleaned = replace_numbers_with_words(text.strip())
        print(f"Processed text: '{cleaned}'")

        print("Tokenizing...")
        encoded = tokenizer(cleaned, return_tensors="pt")
        print(f"Tokenized successfully, input_ids shape: {encoded['input_ids'].shape}")

        print("Generating audio...")
        with torch.no_grad():
            audio = model(**encoded).waveform
        print(f"Audio generated, shape: {audio.shape}")

        samples = audio.squeeze().numpy()
        print(f"Waveform shape: {samples.shape}")

        print("Saving audio file...")
        # delete=False keeps the file after the handle closes so the caller
        # can read it by path. NOTE(review): nothing here removes it later,
        # so these files accumulate in the temp directory.
        with tempfile.NamedTemporaryFile(suffix=".wav", delete=False) as handle:
            wav_path = handle.name
        scipy.io.wavfile.write(
            wav_path,
            rate=model.config.sampling_rate,
            data=samples,
        )

        print(f"✓ Audio saved to: {wav_path}")
        print("=== TTS Function Completed Successfully ===")
        return wav_path

    except Exception as exc:
        message = f"Error in TTS: {str(exc)}"
        print(f"✗ {message}")
        print("Full traceback:")
        traceback.print_exc()
        return None
|
|
|
|
|
|
|
|
|
def test_function(text):
    """Echo *text* back with a prefix.

    A debug helper; nothing else in this file references it.

    Args:
        text: Any value; it is interpolated into an f-string.

    Returns:
        ``"You entered: <text>"``.
    """
    print(f"Test function called with: {text}")
    return f"You entered: {text}"


# The ``test_`` prefix matches pytest's default collection pattern. A test
# module that star-imports this file would therefore collect this helper as a
# test and error, because pytest would look for a fixture named ``text``.
# Opt it out of collection.
test_function.__test__ = False
|
|
|
|
|
|
|
|
|
print("Creating Gradio interface...")

# Single-input / single-output UI around text_to_speech, which returns a WAV
# file path (or None on failure).
interface = gr.Interface(
    fn=text_to_speech,
    inputs=gr.Textbox(
        label="Enter Kurmanji Text",
        placeholder="e.g. Silav! Ez bi xêr im.",
        lines=2,
        value="",
    ),
    outputs=gr.Audio(label="Generated Speech"),
    title="Kurmanji Text-to-Speech",
    description="Enter Kurmanji Kurdish text to convert to speech.",
    examples=[
        ["Silav"],
        ["Ez bi xêr im"],
        ["Spas"],
    ],
    # Do not pre-compute example outputs at startup.
    cache_examples=False,
    flagging_mode="never",
)


if __name__ == "__main__":
    # Announce the launch only when actually launching. Importing this module
    # builds the interface but does not serve it.
    print("Launching interface...")
    interface.launch(
        debug=True,
        share=False,
        # NOTE(review): text_to_speech catches every exception and returns
        # None, so this flag cannot surface its errors in the UI.
        show_error=True,
        # Presumably SPACE_ID marks a Hugging Face Space, where the server
        # must listen on all interfaces; otherwise stay on localhost.
        server_name="0.0.0.0" if "SPACE_ID" in os.environ else "127.0.0.1",
        server_port=7860,
    )