File size: 7,726 Bytes
2b1b4a0
 
 
 
647d7b9
 
2b1b4a0
17ddd97
2b1b4a0
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
31c4539
2b1b4a0
18d0c35
 
17ddd97
2b1b4a0
 
 
 
 
a325594
 
 
2b1b4a0
 
a325594
2b1b4a0
 
a325594
2b1b4a0
a325594
 
 
2b1b4a0
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
31c4539
 
2b1b4a0
 
18d0c35
31c4539
2b1b4a0
 
 
 
 
 
 
 
31c4539
 
2b1b4a0
 
 
 
31c4539
18d0c35
2b1b4a0
31c4539
2b1b4a0
 
31c4539
2b1b4a0
31c4539
2b1b4a0
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
647d7b9
2b1b4a0
 
 
e69f273
 
2b1b4a0
 
31c4539
647d7b9
 
2b1b4a0
 
 
e69f273
2b1b4a0
e69f273
647d7b9
 
2b1b4a0
 
 
 
 
 
 
 
 
 
 
 
 
 
31c4539
647d7b9
2b1b4a0
0570125
370cc4e
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
import difflib
import html
import os
import subprocess
import sys
import urllib.request

import gradio as gr
import jiwer
import pyarabic.araby as araby
import torch
from transformers import pipeline, AutoModelForSeq2SeqLM, AutoTokenizer

# ---------- Setup: Clone CATT repo & download diacritization models ----------
CATT_REPO_URL = "https://github.com/abjadai/catt.git"
CATT_FOLDER = "catt"
MODELS_DIR = "models"
ED_URL = "https://github.com/abjadai/catt/releases/download/v2/best_ed_mlm_ns_epoch_178.pt"
EO_URL = "https://github.com/abjadai/catt/releases/download/v2/best_eo_mlm_ns_epoch_193.pt"

os.makedirs(MODELS_DIR, exist_ok=True)

# Clone if needed. The target directory is passed explicitly so the checkout
# always lands in CATT_FOLDER, and check=True raises CalledProcessError on a
# failed clone instead of letting the later CATT imports fail obscurely.
if not os.path.isdir(CATT_FOLDER):
    subprocess.run(["git", "clone", CATT_REPO_URL, CATT_FOLDER], check=True)
if CATT_FOLDER not in sys.path:
    sys.path.append(CATT_FOLDER)

# Download checkpoints. Each file is fetched to a temporary ".part" path and
# only renamed into place once complete, so an interrupted download can never
# leave a truncated checkpoint that passes the isfile() check on the next run.
for url in (ED_URL, EO_URL):
    fname = os.path.basename(url)
    dest = os.path.join(MODELS_DIR, fname)
    if not os.path.isfile(dest):
        tmp_dest = dest + ".part"
        try:
            urllib.request.urlretrieve(url, tmp_dest)
            os.replace(tmp_dest, dest)
        finally:
            # After a successful rename the temp file is gone; this only
            # cleans up after a failed or interrupted transfer.
            if os.path.exists(tmp_dest):
                os.remove(tmp_dest)

# Import CATT modules
from tashkeel_tokenizer import TashkeelTokenizer
from utils import remove_non_arabic
from ed_pl import TashkeelModel as TashkeelModel_ED
from eo_pl import TashkeelModel as TashkeelModel_EO

# Prepare tokenizer & device
# The tokenizer instance is handed to the diacritization models below; the
# device string is "cuda" whenever torch reports an available GPU, else "cpu".
tokenizer = TashkeelTokenizer()
if torch.cuda.is_available():
    device = "cuda"
else:
    device = "cpu"

# Load diacritization models
def load_diacritization_models():
    """Build both CATT models and publish them as module globals.

    Sets ``model_ed`` (3 layers, weights from the ED_URL checkpoint) and
    ``model_eo`` (6 layers, weights from the EO_URL checkpoint). Each model
    is put in eval mode and moved to ``device``.
    """
    global model_ed, model_eo

    def _restore(model_cls, checkpoint_url, depth):
        # Checkpoints live in MODELS_DIR under the URL's file name.
        net = model_cls(tokenizer, max_seq_len=1024, n_layers=depth, learnable_pos_emb=False)
        weights_path = os.path.join(MODELS_DIR, os.path.basename(checkpoint_url))
        net.load_state_dict(torch.load(weights_path, map_location=device))
        net.eval().to(device)
        return net

    model_ed = _restore(TashkeelModel_ED, ED_URL, 3)
    model_eo = _restore(TashkeelModel_EO, EO_URL, 6)


load_diacritization_models()

# ---------- Setup: Arabic syllable transcription pipelines ----------
# Model identifiers: one for the speech recognition pipeline, one shared by
# the MT5 seq2seq model and its tokenizer.
_ASR_MODEL_ID = "IbrahimSalah/Arabic_speech_Syllables_recognition_Using_Wav2vec2"
_MT5_MODEL_ID = "IbrahimSalah/Arabic_Syllables_to_text_Converter_Using_MT5"

ASR_PIPE = pipeline("automatic-speech-recognition", model=_ASR_MODEL_ID)
MT5_MODEL = AutoModelForSeq2SeqLM.from_pretrained(_MT5_MODEL_ID)
MT5_TOKENIZER = AutoTokenizer.from_pretrained(_MT5_MODEL_ID)
# Inference only: switch the seq2seq model to eval mode once at load time.
MT5_MODEL.eval()

# Arabic diacritics set
# Prefer pyarabic's named constants. Fall back to the literal code points
# U+064B..U+0652 only when those names cannot be resolved (a missing attribute
# on the araby module, or the module name itself being unavailable). The
# previous bare ``except:`` also swallowed KeyboardInterrupt/SystemExit and
# would have hidden any unrelated failure.
try:
    DIACRITICS = {
        araby.FATHA, araby.FATHATAN, araby.DAMMA, araby.DAMMATAN,
        araby.KASRA, araby.KASRATAN, araby.SUKUN, araby.SHADDA,
    }
except (AttributeError, NameError):
    DIACRITICS = {'\u064B', '\u064C', '\u064D', '\u064E', '\u064F', '\u0650', '\u0651', '\u0652'}

# ---------- Core Functions ----------
def diacritize_text(model_type, input_text):
    """
    Diacritize Arabic text with the selected CATT model.

    Args:
        model_type: "Encoder-Decoder" selects ``model_ed``; any other value
            selects ``model_eo``.
        input_text: Raw text from the UI. ``None`` and blank strings are
            rejected without touching the models.

    Returns:
        A 2-tuple. On success both items are the diacritized text: once for
        display, once for state storage. When there is nothing to
        diacritize, the tuple is ("Please enter some Arabic text.", "").
    """
    nothing_to_do = ("Please enter some Arabic text.", "")

    # Guard None/blank up front: calling .strip() on None would raise
    # AttributeError instead of showing the prompt.
    if not input_text or not input_text.strip():
        return nothing_to_do

    text_clean = remove_non_arabic(input_text.strip())
    if not text_clean:
        # Nothing left after cleaning.
        return nothing_to_do

    model = model_ed if model_type == "Encoder-Decoder" else model_eo
    outputs = model.do_tashkeel_batch([text_clean], batch_size=16, verbose=False)
    result = outputs[0] if outputs else ""
    return result, result


def get_and_process_syllables(audio_path):
    """Transcribe audio to a syllable string, then convert it to text.

    Runs ``ASR_PIPE`` on ``audio_path``, rewrites the result as a
    pipe-delimited sequence ending in ".", and feeds that to the MT5 model.

    Returns:
        ``(text, seq)``: ``text`` is the decoded MT5 output up to the first
        ".", and ``seq`` is the pipe-delimited sequence given to MT5.
    """
    syllables = ASR_PIPE(audio_path)["text"]
    seq = "|{}.".format(syllables.replace(" ", "|"))

    encoded = MT5_TOKENIZER.encode(seq, return_tensors="pt")
    generation_kwargs = {
        "max_length": 100,
        "early_stopping": True,
        "pad_token_id": MT5_TOKENIZER.pad_token_id,
        "bos_token_id": MT5_TOKENIZER.bos_token_id,
        "eos_token_id": MT5_TOKENIZER.eos_token_id,
    }
    generated = MT5_MODEL.generate(encoded, **generation_kwargs)

    # The first generated id is dropped before decoding; only the part of the
    # decoded string before the first "." is kept.
    decoded = MT5_TOKENIZER.decode(generated[0][1:], skip_special_tokens=True)
    text = decoded.partition('.')[0]
    return text, seq


def get_diacritics_sequence(txt):
    """Return the characters of *txt* that are in DIACRITICS, in order, joined by single spaces."""
    marks = []
    for ch in txt:
        if ch in DIACRITICS:
            marks.append(ch)
    return ' '.join(marks)


def calculate_metrics(ref, hyp):
    """Score a hypothesis against a reference.

    Returns:
        ``(wer, der, cer)``. With a blank reference the result is all 0.0
        when the hypothesis is blank too, otherwise all 1.0. In every other
        case the three values are rounded to 4 decimals; DER is the jiwer
        word error rate computed over the space-separated diacritic
        sequences of both texts.
    """
    if not ref.strip():
        # No reference to compare against: perfect only if nothing was said.
        if not hyp.strip():
            return 0.0, 0.0, 0.0
        return 1.0, 1.0, 1.0

    word_error = jiwer.wer(ref, hyp)

    ref_marks = get_diacritics_sequence(ref)
    hyp_marks = get_diacritics_sequence(hyp)
    if ref_marks:
        diacritic_error = jiwer.wer(ref_marks, hyp_marks)
    elif hyp_marks:
        # Reference has no diacritics but the hypothesis does.
        diacritic_error = 1.0
    else:
        diacritic_error = 0.0

    char_error = jiwer.cer(ref, hyp)
    return round(word_error, 4), round(diacritic_error, 4), round(char_error, 4)


def highlight_errors(ref, hyp):
    """Diff the hypothesis against the reference word by word.

    Args:
        ref: Reference text.
        hyp: Hypothesis text.

    Returns:
        ``(html_text, error_words)``. ``html_text`` is the hypothesis with
        replaced words wrapped in a red ``<mark>`` and inserted words in a
        green ``<mark>``; words only present in the reference are not
        rendered. ``error_words`` is a comma-separated, sorted, de-duplicated
        list of every word involved in a replace, delete or insert (plain
        text, not escaped).
    """
    ref_w, hyp_w = ref.split(), hyp.split()
    matcher = difflib.SequenceMatcher(None, ref_w, hyp_w, autojunk=False)
    out_words, errs = [], []
    for tag, i1, i2, j1, j2 in matcher.get_opcodes():
        # Every hypothesis word ends up inside HTML markup, so escape it
        # first; otherwise "<", "&" or quotes in the text would break (or
        # inject into) the rendered output.
        shown = [html.escape(w) for w in hyp_w[j1:j2]]
        if tag == 'equal':
            out_words.extend(shown)
        elif tag == 'replace':
            out_words.extend(
                f"<mark style='background-color:#ffcccb;'>{w}</mark>" for w in shown
            )
            errs.extend(ref_w[i1:i2] + hyp_w[j1:j2])
        elif tag == 'delete':
            errs.extend(ref_w[i1:i2])
        elif tag == 'insert':
            out_words.extend(
                f"<mark style='background-color:#ccffcc;'>{w}</mark>" for w in shown
            )
            errs.extend(hyp_w[j1:j2])
    return ' '.join(out_words), ', '.join(sorted(set(errs)))


def process_audio_and_compare(audio_path, reference_text):
    """Transcribe an audio file and score it against the reference text.

    Args:
        audio_path: Path of the recorded/uploaded audio; falsy means none.
        reference_text: Reference text; ``None`` or blank is rejected.

    Returns:
        A 7-tuple ``(hyp, syll, wer, der, cer, html_out, errs)``. On a
        validation failure the first two items carry the same error message,
        the three metrics are ``None`` and the last two items are "".
    """
    def _failure(message):
        # Same shape as a successful result so every output slot gets a value.
        return message, message, None, None, None, "", ""

    if not audio_path:
        return _failure("Error: No audio provided.")
    # Check for None before .strip(): a missing reference should produce the
    # error message, not an AttributeError.
    if not reference_text or not reference_text.strip():
        return _failure("Error: No reference text.")

    hyp, syll = get_and_process_syllables(audio_path)
    # A transcript beginning with "Error" is treated as a failure report and
    # is passed through without scoring.
    if hyp.startswith("Error"):
        return hyp, syll, None, None, None, "", ""

    wer, der, cer = calculate_metrics(reference_text, hyp)
    html_out, errs = highlight_errors(reference_text, hyp)
    return hyp, syll, wer, der, cer, html_out, errs

# ---------- Gradio Interface ----------
# Two-column layout: the left column produces the (editable) diacritized
# reference, the right column transcribes audio and scores it against that
# reference. Component creation order inside each context manager determines
# the on-screen order, so statements here should not be reordered casually.
with gr.Blocks(theme=gr.themes.Soft()) as app:
    gr.Markdown("""
    # Arabic Diacritization & Reading Assessment
    1. Enter undiacritized Arabic text → Diacritize.
    2. Optionally edit the diacritized result.
    3. Record/upload audio → Transcribe & Compare.
    """)
    # Holds the current reference text; the compare step reads it from here
    # rather than directly from the textbox.
    ref_state = gr.State("")

    with gr.Row():
        # Left column: text input, model choice, diacritized reference.
        with gr.Column(scale=1):
            text_in = gr.Textbox(label="Undiacritized Arabic Text", lines=3, text_align="right")
            model_sel = gr.Dropdown(choices=["Encoder-Only","Encoder-Decoder"], value="Encoder-Only", label="Model")
            diac_btn = gr.Button("Diacritize Text")
            diac_out = gr.Textbox(label="Diacritized Text (Reference)", lines=3, text_align="right", interactive=True)
            # The click fills both the visible textbox and the stored reference
            # (diacritize_text returns two values for exactly this reason).
            diac_btn.click(fn=diacritize_text, inputs=[model_sel, text_in], outputs=[diac_out, ref_state])
            # Mirror textbox edits into the stored reference.
            # NOTE(review): if .change also fires on programmatic updates, the
            # "Please enter some Arabic text." message that diacritize_text
            # writes to diac_out would be copied into ref_state — confirm.
            diac_out.change(fn=lambda text: text, inputs=diac_out, outputs=ref_state)

        # Right column: audio input, transcript, metrics and error views.
        with gr.Column(scale=1):
            audio_in = gr.Audio(label="Record/Upload Audio", type="filepath")
            trans_btn = gr.Button("Transcribe & Compare")
            hyp_out = gr.Textbox(label="Transcript (Hypothesis)", lines=3, text_align="right")
            syl_out = gr.Textbox(label="Transcript Syllables", lines=3, text_align="right")
            wer_n = gr.Number(label="WER", precision=4)
            der_n = gr.Number(label="DER", precision=4)
            cer_n = gr.Number(label="CER", precision=4)
            err_html = gr.HTML(label="Highlighted Errors")
            err_list = gr.Textbox(label="Error Words")

            # The outputs list matches, in order, the 7-tuple returned by
            # process_audio_and_compare.
            trans_btn.click(
                fn=process_audio_and_compare,
                inputs=[audio_in, ref_state],
                outputs=[hyp_out, syl_out, wer_n, der_n, cer_n, err_html, err_list]
            )

# Launch
# Only start the server when this file is run as a script, not when imported.
# NOTE(review): share=True presumably requests a public share link and
# debug=True a blocking debug session (per Gradio's launch() docs) — confirm
# both are wanted outside local development.
if __name__ == "__main__":
    app.launch(debug=True, share=True)