"""Gradio app: Arabic diacritization (CATT) and read-aloud assessment.

On import this module clones the CATT repository, downloads its two
checkpoints, loads the diacritization models and the syllable-based speech
recognition pipeline, and builds the Gradio UI in ``app``.
"""

import difflib
import html
import os
import subprocess
import sys
import urllib.request
from typing import Optional, Tuple

import gradio as gr
import jiwer
import pyarabic.araby as araby
import torch
from transformers import pipeline, AutoModelForSeq2SeqLM, AutoTokenizer

# ---------- Setup: Clone CATT repo & download diacritization models ----------
CATT_REPO_URL = "https://github.com/abjadai/catt.git"
CATT_FOLDER = "catt"
MODELS_DIR = "models"
ED_URL = "https://github.com/abjadai/catt/releases/download/v2/best_ed_mlm_ns_epoch_178.pt"
EO_URL = "https://github.com/abjadai/catt/releases/download/v2/best_eo_mlm_ns_epoch_193.pt"


def _download_if_missing(url: str, dest: str) -> None:
    """Download *url* to *dest* unless *dest* already exists.

    The payload is written to ``dest + ".part"`` and renamed only after the
    transfer finishes, so an interrupted download never leaves a truncated
    file that a later run would mistake for a complete checkpoint.
    """
    if os.path.isfile(dest):
        return
    partial = dest + ".part"
    try:
        urllib.request.urlretrieve(url, partial)
        os.replace(partial, dest)
    finally:
        # Only still present when the download or the rename failed.
        if os.path.exists(partial):
            os.remove(partial)


os.makedirs(MODELS_DIR, exist_ok=True)

# Clone if needed. An argument list (no shell) plus check=True makes a failed
# clone raise here instead of surfacing later as an ImportError.
if not os.path.isdir(CATT_FOLDER):
    subprocess.run(["git", "clone", CATT_REPO_URL, CATT_FOLDER], check=True)
if CATT_FOLDER not in sys.path:
    sys.path.append(CATT_FOLDER)

# Download checkpoints
for _url in (ED_URL, EO_URL):
    _download_if_missing(_url, os.path.join(MODELS_DIR, os.path.basename(_url)))

# Import CATT modules (only importable once the repo is on sys.path)
from tashkeel_tokenizer import TashkeelTokenizer  # noqa: E402
from utils import remove_non_arabic  # noqa: E402
from ed_pl import TashkeelModel as TashkeelModel_ED  # noqa: E402
from eo_pl import TashkeelModel as TashkeelModel_EO  # noqa: E402

# Prepare tokenizer & device
tokenizer = TashkeelTokenizer()
device = "cuda" if torch.cuda.is_available() else "cpu"


def _load_checkpoint(model, url: str):
    """Load the checkpoint downloaded from *url* into *model*.

    The model is switched to eval mode and moved to ``device`` before being
    returned.
    """
    checkpoint_path = os.path.join(MODELS_DIR, os.path.basename(url))
    model.load_state_dict(torch.load(checkpoint_path, map_location=device))
    model.eval().to(device)
    return model


# Load diacritization models
def load_diacritization_models():
    """Build both CATT models and publish them as ``model_ed`` / ``model_eo``."""
    global model_ed, model_eo
    max_seq_len = 1024
    model_ed = _load_checkpoint(
        TashkeelModel_ED(tokenizer, max_seq_len=max_seq_len, n_layers=3, learnable_pos_emb=False),
        ED_URL,
    )
    model_eo = _load_checkpoint(
        TashkeelModel_EO(tokenizer, max_seq_len=max_seq_len, n_layers=6, learnable_pos_emb=False),
        EO_URL,
    )


load_diacritization_models()

# ---------- Setup: Arabic syllable transcription pipelines ----------
ASR_PIPE = pipeline(
    "automatic-speech-recognition",
    model="IbrahimSalah/Arabic_speech_Syllables_recognition_Using_Wav2vec2",
)
MT5_MODEL = AutoModelForSeq2SeqLM.from_pretrained(
    "IbrahimSalah/Arabic_Syllables_to_text_Converter_Using_MT5"
)
MT5_TOKENIZER = AutoTokenizer.from_pretrained(
    "IbrahimSalah/Arabic_Syllables_to_text_Converter_Using_MT5"
)
MT5_MODEL.eval()

# Arabic diacritics set
try:
    DIACRITICS = {
        araby.FATHA, araby.FATHATAN, araby.DAMMA, araby.DAMMATAN,
        araby.KASRA, araby.KASRATAN, araby.SUKUN, araby.SHADDA,
    }
except AttributeError:
    # Fallback if the installed pyarabic lacks one of the constants above.
    DIACRITICS = {'\u064B', '\u064C', '\u064D', '\u064E', '\u064F', '\u0650', '\u0651', '\u0652'}

# Markup wrapped around every hypothesis word that differs from the reference.
_ERROR_SPAN = '<span style="color: red; font-weight: bold;">{}</span>'


# ---------- Core Functions ----------
def diacritize_text(model_type: str, input_text: str) -> Tuple[str, str]:
    """
    Returns the diacritized text twice: once for display, once for state storage.

    ``model_type`` selects the encoder-decoder model when it equals
    ``"Encoder-Decoder"``; any other value uses the encoder-only model.
    When nothing is left after ``remove_non_arabic``, a prompt message is
    returned for display together with an empty state value.
    """
    text_clean = remove_non_arabic((input_text or "").strip())
    if not text_clean:
        return "Please enter some Arabic text.", ""

    model = model_ed if model_type == "Encoder-Decoder" else model_eo
    outputs = model.do_tashkeel_batch([text_clean], batch_size=16, verbose=False)
    result = outputs[0] if outputs else ""
    return result, result


def get_and_process_syllables(audio_path: str) -> Tuple[str, str]:
    """Transcribe *audio_path* and return ``(text, syllable_sequence)``.

    The ASR output has its spaces replaced by ``|`` and is framed as
    ``"|...."`` before being passed to the MT5 model; the decoded output is
    cut at its first ``.``.
    """
    # ASR -> syllable sequence -> MT5 conversion
    clip = ASR_PIPE(audio_path)["text"]
    seq = "|" + clip.replace(" ", "|") + "."
    input_ids = MT5_TOKENIZER.encode(seq, return_tensors="pt")
    out_ids = MT5_MODEL.generate(
        input_ids,
        max_length=100,
        early_stopping=True,
        pad_token_id=MT5_TOKENIZER.pad_token_id,
        bos_token_id=MT5_TOKENIZER.bos_token_id,
        eos_token_id=MT5_TOKENIZER.eos_token_id,
    )
    # The first generated token is dropped before decoding.
    decoded = MT5_TOKENIZER.decode(out_ids[0][1:], skip_special_tokens=True)
    return decoded.split('.')[0], seq


def get_diacritics_sequence(txt: str) -> str:
    """Return the diacritic marks of *txt*, in order, separated by spaces."""
    return ' '.join(c for c in txt if c in DIACRITICS)


def calculate_metrics(ref: str, hyp: str) -> Tuple[float, float, float]:
    """Return ``(WER, DER, CER)`` of *hyp* against *ref*, rounded to 4 places.

    DER is the word error rate computed over the space-separated diacritic
    sequences of both strings. Blank inputs are special-cased because there
    is no reference to score against: both blank gives all zeros, a blank
    reference alone gives all ones.
    """
    ref_blank, hyp_blank = not ref.strip(), not hyp.strip()
    if ref_blank and hyp_blank:
        return 0.0, 0.0, 0.0
    if ref_blank:
        return 1.0, 1.0, 1.0

    wer = jiwer.wer(ref, hyp)
    cer = jiwer.cer(ref, hyp)

    ref_d, hyp_d = get_diacritics_sequence(ref), get_diacritics_sequence(hyp)
    if not ref_d:
        # No reference diacritics: perfect only if the hypothesis has none.
        der = 0.0 if not hyp_d else 1.0
    else:
        der = jiwer.wer(ref_d, hyp_d)
    return round(wer, 4), round(der, 4), round(cer, 4)


def highlight_errors(ref: str, hyp: str) -> Tuple[str, str]:
    """Diff *hyp* against *ref* word by word.

    Returns ``(html, errors)``. ``html`` is the hypothesis with every word
    that was substituted or inserted relative to the reference wrapped in a
    highlighting ``<span>``; all words are HTML-escaped. ``errors`` is a
    comma-separated, sorted, de-duplicated list of every reference and
    hypothesis word involved in a mismatch (including reference words that
    are missing from the hypothesis).
    """
    ref_words, hyp_words = ref.split(), hyp.split()
    matcher = difflib.SequenceMatcher(None, ref_words, hyp_words, autojunk=False)
    rendered, errors = [], []
    for tag, i1, i2, j1, j2 in matcher.get_opcodes():
        ref_chunk, hyp_chunk = ref_words[i1:i2], hyp_words[j1:j2]
        if tag == 'equal':
            rendered.extend(html.escape(w) for w in hyp_chunk)
            continue
        # replace / delete / insert: one of the chunks may be empty.
        rendered.extend(_ERROR_SPAN.format(html.escape(w)) for w in hyp_chunk)
        errors.extend(ref_chunk + hyp_chunk)
    return ' '.join(rendered), ', '.join(sorted(set(errors)))


def process_audio_and_compare(audio_path: Optional[str], reference_text: Optional[str]):
    """Transcribe the audio and score it against *reference_text*.

    Returns a 7-tuple matching the UI outputs:
    ``(hypothesis, syllables, wer, der, cer, highlighted_html, error_words)``.
    When an input is missing, the first two items carry an error message and
    the metrics are ``None``.
    """
    if not audio_path:
        msg = "Error: No audio provided."
        return msg, msg, None, None, None, "", ""
    if not reference_text or not reference_text.strip():
        msg = "Error: No reference text."
        return msg, msg, None, None, None, "", ""

    hyp, syll = get_and_process_syllables(audio_path)
    if hyp.startswith("Error"):
        return hyp, syll, None, None, None, "", ""

    wer, der, cer = calculate_metrics(reference_text, hyp)
    html_out, errs = highlight_errors(reference_text, hyp)
    return hyp, syll, wer, der, cer, html_out, errs


# ---------- Gradio Interface ----------
with gr.Blocks(theme=gr.themes.Soft()) as app:
    gr.Markdown("""
    # Arabic Diacritization & Reading Assessment
    1. Enter undiacritized Arabic text → Diacritize.
    2. Optionally edit the diacritized result.
    3. Record/upload audio → Transcribe & Compare.
    """)
    ref_state = gr.State("")

    with gr.Row():
        with gr.Column(scale=1):
            text_in = gr.Textbox(label="Undiacritized Arabic Text", lines=3, text_align="right")
            model_sel = gr.Dropdown(
                choices=["Encoder-Only", "Encoder-Decoder"],
                value="Encoder-Only",
                label="Model",
            )
            diac_btn = gr.Button("Diacritize Text")
            diac_out = gr.Textbox(
                label="Diacritized Text (Reference)",
                lines=3,
                text_align="right",
                interactive=True,
            )
            diac_btn.click(
                fn=diacritize_text,
                inputs=[model_sel, text_in],
                outputs=[diac_out, ref_state],
            )
            # Keep the stored reference in sync with manual edits.
            diac_out.change(fn=lambda text: text, inputs=diac_out, outputs=ref_state)

        with gr.Column(scale=1):
            audio_in = gr.Audio(label="Record/Upload Audio", type="filepath")
            trans_btn = gr.Button("Transcribe & Compare")
            hyp_out = gr.Textbox(label="Transcript (Hypothesis)", lines=3, text_align="right")
            syl_out = gr.Textbox(label="Transcript Syllables", lines=3, text_align="right")
            wer_n = gr.Number(label="WER", precision=4)
            der_n = gr.Number(label="DER", precision=4)
            cer_n = gr.Number(label="CER", precision=4)
            err_html = gr.HTML(label="Highlighted Errors")
            err_list = gr.Textbox(label="Error Words")
            trans_btn.click(
                fn=process_audio_and_compare,
                inputs=[audio_in, ref_state],
                outputs=[hyp_out, syl_out, wer_n, der_n, cer_n, err_html, err_list],
            )

# Launch
if __name__ == "__main__":
    app.launch(debug=True, share=True)