import difflib
import html
import os
import subprocess
import sys
import traceback
import urllib.request

import gradio as gr
import jiwer
import pyarabic.araby as araby
import torch
from transformers import pipeline, AutoModelForSeq2SeqLM, AutoTokenizer
# ---------- Setup: Clone CATT repo & download diacritization models ----------
CATT_REPO_URL = "https://github.com/abjadai/catt.git"
CATT_FOLDER = "catt"
MODELS_DIR = "models"
ED_URL = "https://github.com/abjadai/catt/releases/download/v2/best_ed_mlm_ns_epoch_178.pt"
EO_URL = "https://github.com/abjadai/catt/releases/download/v2/best_eo_mlm_ns_epoch_193.pt"
os.makedirs(MODELS_DIR, exist_ok=True)


def _ensure_catt_repo():
    """Clone the CATT repository into CATT_FOLDER when missing and make it importable.

    Raises:
        subprocess.CalledProcessError: if ``git clone`` exits with a non-zero status.
        FileNotFoundError: if the ``git`` executable cannot be found.
    """
    if not os.path.isdir(CATT_FOLDER):
        # argv list (no shell) + check=True: a failed clone stops here instead of
        # surfacing later as a confusing ImportError for the CATT modules.
        subprocess.run(["git", "clone", CATT_REPO_URL, CATT_FOLDER], check=True)
    if CATT_FOLDER not in sys.path:
        sys.path.append(CATT_FOLDER)


def _ensure_checkpoints():
    """Download the ED/EO checkpoints into MODELS_DIR unless they already exist.

    Each file is fetched to a temporary ``.part`` path and renamed into place only
    after the download finishes, so an interrupted download is never mistaken for
    a valid checkpoint by the ``isfile`` check on the next start.
    """
    for url in (ED_URL, EO_URL):
        dest = os.path.join(MODELS_DIR, os.path.basename(url))
        if os.path.isfile(dest):
            continue
        partial = dest + ".part"
        try:
            urllib.request.urlretrieve(url, partial)
            os.replace(partial, dest)
        finally:
            # Clean up after a failed download; after a successful os.replace the
            # .part path no longer exists, so this is a no-op.
            if os.path.exists(partial):
                os.remove(partial)


_ensure_catt_repo()
_ensure_checkpoints()
# Import CATT modules
from tashkeel_tokenizer import TashkeelTokenizer
from utils import remove_non_arabic
from ed_pl import TashkeelModel as TashkeelModel_ED
from eo_pl import TashkeelModel as TashkeelModel_EO
# Prepare tokenizer & device
# One character tokenizer instance, handed to both CATT model constructors.
tokenizer = TashkeelTokenizer()
# Use the GPU when torch can see one; otherwise stay on the CPU.
device = "cpu"
if torch.cuda.is_available():
    device = "cuda"
# Load diacritization models
def load_diacritization_models():
    """Build both CATT diacritizers, load their checkpoints and publish them as globals.

    Sets module-level ``model_ed`` (TashkeelModel_ED, n_layers=3, checkpoint from
    ED_URL) and ``model_eo`` (TashkeelModel_EO, n_layers=6, checkpoint from EO_URL).
    Both are switched to eval mode and moved to ``device``.
    """
    global model_ed, model_eo

    def _build(model_cls, n_layers, checkpoint_url):
        # Checkpoint file name mirrors the download step: MODELS_DIR/<basename of URL>.
        ckpt_path = os.path.join(MODELS_DIR, os.path.basename(checkpoint_url))
        net = model_cls(tokenizer, max_seq_len=1024, n_layers=n_layers, learnable_pos_emb=False)
        # NOTE(review): torch.load unpickles the file, and these checkpoints are
        # fetched from a remote URL; consider weights_only=True if the installed
        # torch version supports it.
        state = torch.load(ckpt_path, map_location=device)
        net.load_state_dict(state)
        net.eval().to(device)
        return net

    model_ed = _build(TashkeelModel_ED, 3, ED_URL)
    model_eo = _build(TashkeelModel_EO, 6, EO_URL)


load_diacritization_models()
# ---------- Setup: Arabic syllable transcription pipelines ----------
# Stage 1: speech -> syllable transcript (wav2vec2 ASR pipeline).
ASR_PIPE = pipeline(
    "automatic-speech-recognition",
    model="IbrahimSalah/Arabic_speech_Syllables_recognition_Using_Wav2vec2",
)
# Stage 2: syllable sequence -> text (mT5 seq2seq model and its tokenizer).
_MT5_CHECKPOINT = "IbrahimSalah/Arabic_Syllables_to_text_Converter_Using_MT5"
MT5_MODEL = AutoModelForSeq2SeqLM.from_pretrained(_MT5_CHECKPOINT)
MT5_TOKENIZER = AutoTokenizer.from_pretrained(_MT5_CHECKPOINT)
MT5_MODEL.eval()
# Arabic diacritics set: fatha/damma/kasra, their tanween forms, sukun and shadda.
try:
    DIACRITICS = {
        araby.FATHA, araby.FATHATAN, araby.DAMMA, araby.DAMMATAN,
        araby.KASRA, araby.KASRATAN, araby.SUKUN, araby.SHADDA,
    }
except AttributeError:
    # Only a missing pyarabic constant can fail above. Fall back to the same
    # eight marks spelled as raw code points (U+064B..U+0652). The handler is
    # narrowed from a bare `except:` so that unrelated errors are not hidden.
    DIACRITICS = {'\u064B', '\u064C', '\u064D', '\u064E', '\u064F', '\u0650', '\u0651', '\u0652'}
# ---------- Core Functions ----------
def diacritize_text(model_type, input_text):
    """Add diacritics to Arabic text with the selected CATT model.

    Args:
        model_type: "Encoder-Decoder" selects ``model_ed``; any other value
            selects ``model_eo``.
        input_text: Raw user text. It is stripped and passed through
            ``remove_non_arabic`` before diacritization.

    Returns:
        A pair ``(display_text, state_text)``. On success both hold the
        diacritized text. When nothing remains after cleaning, the first is a
        prompt message and the second is "".
    """
    cleaned = remove_non_arabic(input_text.strip())
    if not cleaned:
        return "Please enter some Arabic text.", ""
    model = model_ed if model_type == "Encoder-Decoder" else model_eo
    predictions = model.do_tashkeel_batch([cleaned], batch_size=16, verbose=False)
    diacritized = predictions[0] if predictions else ""
    return diacritized, diacritized
def get_and_process_syllables(audio_path):
    """Transcribe ``audio_path`` to syllables, then convert those syllables to text.

    Returns:
        ``(text, seq)``: the mT5 output cut at its first "." and the
        "|"-delimited syllable sequence that was fed to mT5.
    """
    syllables = ASR_PIPE(audio_path)["text"]
    # mT5 input format: every space becomes "|", with a leading "|" and a trailing ".".
    seq = f"|{syllables.replace(' ', '|')}."
    encoded = MT5_TOKENIZER.encode(seq, return_tensors="pt")
    generation_kwargs = dict(
        max_length=100,
        early_stopping=True,
        pad_token_id=MT5_TOKENIZER.pad_token_id,
        bos_token_id=MT5_TOKENIZER.bos_token_id,
        eos_token_id=MT5_TOKENIZER.eos_token_id,
    )
    generated = MT5_MODEL.generate(encoded, **generation_kwargs)
    # Skip the first generated id (presumably the decoder start token) before decoding.
    decoded = MT5_TOKENIZER.decode(generated[0][1:], skip_special_tokens=True)
    # Keep only the part before the first "." (mirrors the "." appended to the input).
    text = decoded.partition(".")[0]
    return text, seq
def get_diacritics_sequence(txt):
    """Return the diacritic marks of ``txt`` in order, separated by single spaces.

    Every character not in DIACRITICS (letters, spaces, punctuation) is dropped.
    """
    return " ".join(filter(DIACRITICS.__contains__, txt))
def calculate_metrics(ref, hyp):
    """Score a hypothesis against a reference.

    Returns:
        ``(wer, der, cer)``, each rounded to 4 decimals.

        - WER and CER come from jiwer.
        - DER is jiwer's WER computed over the space-joined diacritic sequences
          of both texts.
        - A blank reference short-circuits: (0, 0, 0) if the hypothesis is also
          blank, otherwise (1, 1, 1).
    """
    if not ref.strip():
        if not hyp.strip():
            return 0.0, 0.0, 0.0
        return 1.0, 1.0, 1.0

    word_err = jiwer.wer(ref, hyp)

    ref_marks = get_diacritics_sequence(ref)
    hyp_marks = get_diacritics_sequence(hyp)
    if ref_marks:
        diac_err = jiwer.wer(ref_marks, hyp_marks)
    elif hyp_marks:
        # Reference has no diacritics but the hypothesis does: count as fully wrong.
        diac_err = 1.0
    else:
        diac_err = 0.0

    char_err = jiwer.cer(ref, hyp)
    return round(word_err, 4), round(diac_err, 4), round(char_err, 4)
def highlight_errors(ref, hyp):
    """Align hypothesis words to reference words and mark the mismatches.

    Both texts are split on whitespace and aligned with difflib.SequenceMatcher.

    Returns:
        ``(html_text, error_words)``:

        - ``html_text`` is the hypothesis words joined by spaces and
          HTML-escaped. Words that replace or are inserted relative to the
          reference are wrapped in a red ``<span>``. It is rendered by a
          ``gr.HTML`` component, hence the escaping.
        - ``error_words`` is a comma-separated, sorted, de-duplicated list of
          every word involved in a replacement, deletion or insertion, from
          both the reference and the hypothesis side.
    """
    mark = '<span style="color:red">{}</span>'.format
    ref_w, hyp_w = ref.split(), hyp.split()
    matcher = difflib.SequenceMatcher(None, ref_w, hyp_w, autojunk=False)
    out_words, errs = [], []
    for tag, i1, i2, j1, j2 in matcher.get_opcodes():
        hyp_chunk = [html.escape(w) for w in hyp_w[j1:j2]]
        if tag == "equal":
            out_words.extend(hyp_chunk)
            continue
        # replace/delete contribute reference words, replace/insert contribute
        # hypothesis words (the other slice is empty for delete/insert).
        errs.extend(ref_w[i1:i2])
        errs.extend(hyp_w[j1:j2])
        out_words.extend(mark(w) for w in hyp_chunk)
    return " ".join(out_words), ", ".join(sorted(set(errs)))
def process_audio_and_compare(audio_path, reference_text):
    """Transcribe a recording and score it against the reference text.

    Args:
        audio_path: Path to the recorded/uploaded audio (falsy when none was given).
        reference_text: Reference text to score against (the UI passes ref_state).

    Returns:
        A 7-tuple matching the UI outputs:
        ``(hypothesis, syllables, wer, der, cer, highlighted_html, error_words)``.
        On invalid input or a transcription failure, the first slot holds an
        "Error: ..." message and the metric slots are None.
    """
    if not audio_path:
        msg = "Error: No audio provided."
        return msg, msg, None, None, None, "", ""
    # A None reference is treated like an empty one instead of crashing on .strip().
    if not reference_text or not reference_text.strip():
        msg = "Error: No reference text."
        return msg, msg, None, None, None, "", ""
    try:
        hyp, syll = get_and_process_syllables(audio_path)
    except Exception as exc:
        # UI boundary: keep the traceback in the server log, and report the
        # failure in the output boxes like the validation errors above.
        traceback.print_exc()
        return f"Error: Transcription failed ({exc}).", "", None, None, None, "", ""
    wer, der, cer = calculate_metrics(reference_text, hyp)
    html_out, errs = highlight_errors(reference_text, hyp)
    return hyp, syll, wer, der, cer, html_out, errs
# ---------- Gradio Interface ----------
with gr.Blocks(theme=gr.themes.Soft()) as app:
    gr.Markdown("""
# Arabic Diacritization & Reading Assessment
1. Enter undiacritized Arabic text → Diacritize.
2. Optionally edit the diacritized result.
3. Record/upload audio → Transcribe & Compare.
""")
    # Reference text used for scoring. It is written by the Diacritize button
    # and by manual edits of the diacritized textbox, and read by
    # Transcribe & Compare.
    ref_state = gr.State("")
    with gr.Row():
        # Left column: text entry and diacritization.
        with gr.Column(scale=1):
            text_in = gr.Textbox(label="Undiacritized Arabic Text", lines=3, text_align="right")
            model_sel = gr.Dropdown(choices=["Encoder-Only", "Encoder-Decoder"], value="Encoder-Only", label="Model")
            diac_btn = gr.Button("Diacritize Text")
            diac_out = gr.Textbox(label="Diacritized Text (Reference)", lines=3, text_align="right", interactive=True)
            diac_btn.click(fn=diacritize_text, inputs=[model_sel, text_in], outputs=[diac_out, ref_state])
            # `.input` fires only when the user edits the box. With `.change`,
            # the "Please enter some Arabic text." prompt written by the button
            # also fired the handler, overwrote the empty reference that
            # diacritize_text had just stored, and was then scored as the text
            # to read.
            diac_out.input(fn=lambda text: text, inputs=diac_out, outputs=ref_state)
        # Right column: audio input, transcription and scores.
        with gr.Column(scale=1):
            audio_in = gr.Audio(label="Record/Upload Audio", type="filepath")
            trans_btn = gr.Button("Transcribe & Compare")
            hyp_out = gr.Textbox(label="Transcript (Hypothesis)", lines=3, text_align="right")
            syl_out = gr.Textbox(label="Transcript Syllables", lines=3, text_align="right")
            wer_n = gr.Number(label="WER", precision=4)
            der_n = gr.Number(label="DER", precision=4)
            cer_n = gr.Number(label="CER", precision=4)
            err_html = gr.HTML(label="Highlighted Errors")
            err_list = gr.Textbox(label="Error Words")
            trans_btn.click(
                fn=process_audio_and_compare,
                inputs=[audio_in, ref_state],
                outputs=[hyp_out, syl_out, wer_n, der_n, cer_n, err_html, err_list],
            )
# Launch
if __name__ == "__main__":
    # NOTE(review): share=True requests a public Gradio share link, so anyone
    # with the URL can use the app. Drop it for private deployments.
    launch_options = {"debug": True, "share": True}
    app.launch(**launch_options)