Commit b9a6dd9 · committed by Siddhant
Parent(s): 58f82d5

Update demo

Files changed:
- app.py +856 -492
- pyscripts/utils/dialog_eval/ASR_WER.py +165 -0
- pyscripts/utils/dialog_eval/LLM_Metrics.py +245 -0
- pyscripts/utils/dialog_eval/TTS_intelligibility.py +169 -0
- pyscripts/utils/dialog_eval/TTS_speech_quality.py +98 -0
- pyscripts/utils/dialog_eval/__pycache__/ASR_WER.cpython-39.pyc +0 -0
- pyscripts/utils/dialog_eval/__pycache__/LLM_Metrics.cpython-39.pyc +0 -0
- pyscripts/utils/dialog_eval/__pycache__/TTS_intelligibility.cpython-39.pyc +0 -0
- pyscripts/utils/dialog_eval/__pycache__/TTS_speech_quality.cpython-39.pyc +0 -0
- pyscripts/utils/dialog_eval/__pycache__/human_feedback.cpython-39.pyc +0 -0
- pyscripts/utils/dialog_eval/__pycache__/vert.cpython-39.pyc +0 -0
- pyscripts/utils/dialog_eval/human_feedback.py +242 -0
- pyscripts/utils/dialog_eval/vert.py +299 -0
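
The diff below rewrites app.py to pull its evaluation helpers from the new pyscripts/utils/dialog_eval modules instead of espnet2.sds.eval. As orientation, here is a minimal usage sketch of two of those helpers in isolation; the function names and argument shapes are taken from the diff, while the dummy audio, an installed Versa backend with its judge ASR models, and the modules being importable from the Space's working directory are assumptions made for illustration only.

import numpy as np

from pyscripts.utils.dialog_eval.ASR_WER import handle_espnet_ASR_WER
from pyscripts.utils.dialog_eval.LLM_Metrics import perplexity, vert

# Assumed stand-in inputs: one second of silent 16 kHz audio and short texts.
sr = 16000
dummy_audio = np.zeros(sr, dtype=np.int16)
llm_output = "Hello, how can I help you today?"
asr_transcript = "hello how can i help you today"

# WER/CER of the cascaded pipeline's ASR transcript against judge ASR systems
# (requires Versa, as the new ASR_WER.py documents).
print(handle_espnet_ASR_WER((sr, dummy_audio), asr_transcript))

# Text dialog metrics: app.py concatenates the per-metric report strings.
print(perplexity(llm_output) + vert([llm_output]))
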
app.py CHANGED
|
@@ -5,347 +5,382 @@ except ImportError:
|
|
| 5 |
with open('versa.sh', 'rb') as file:
|
| 6 |
script = file.read()
|
| 7 |
rc = call(script, shell=True)
|
|
|
|
| 8 |
import os
|
| 9 |
import shutil
|
| 10 |
-
|
| 11 |
-
from
|
| 12 |
-
|
| 13 |
-
|
| 14 |
-
|
| 15 |
-
|
| 16 |
-
from espnet2.sds.llm.hugging_face_llm import HuggingFaceLLM
|
| 17 |
-
from espnet2.sds.vad.webrtc_vad import WebrtcVADModel
|
| 18 |
-
from espnet2.sds.eval.TTS_intelligibility import handle_espnet_TTS_intelligibility
|
| 19 |
-
from espnet2.sds.eval.ASR_WER import handle_espnet_ASR_WER
|
| 20 |
-
from espnet2.sds.eval.TTS_speech_quality import TTS_psuedomos
|
| 21 |
-
from espnet2.sds.eval.LLM_Metrics import perplexity, vert, bert_score, DialoGPT_perplexity
|
| 22 |
-
from espnet2.sds.utils.chat import Chat
|
| 23 |
-
from espnet2.sds.end_to_end.mini_omni_e2e import MiniOmniE2EModel
|
| 24 |
-
import argparse
|
| 25 |
import torch
|
|
| 26 |
|
| 27 |
access_token = os.environ.get("HF_TOKEN")
|
| 28 |
ASR_name="pyf98/owsm_ctc_v3.1_1B"
|
| 29 |
LLM_name="meta-llama/Llama-3.2-1B-Instruct"
|
| 30 |
TTS_name="kan-bayashi/ljspeech_vits"
|
| 31 |
-
ASR_options="pyf98/owsm_ctc_v3.1_1B,espnet/owsm_ctc_v3.2_ft_1B,espnet/owsm_v3.1_ebf,librispeech_asr,whisper".split(",")
|
| 32 |
LLM_options="meta-llama/Llama-3.2-1B-Instruct,HuggingFaceTB/SmolLM2-1.7B-Instruct".split(",")
|
| 33 |
TTS_options="kan-bayashi/ljspeech_vits,kan-bayashi/libritts_xvector_vits,kan-bayashi/vctk_multi_spk_vits,ChatTTS".split(",")
|
| 34 |
Eval_options="Latency,TTS Intelligibility,TTS Speech Quality,ASR WER,Text Dialog Metrics"
|
| 35 |
upload_to_hub=None
|
|
| 36 |
ASR_curr_name=None
|
| 37 |
LLM_curr_name=None
|
| 38 |
TTS_curr_name=None
|
| 39 |
-
# def read_args():
|
| 40 |
-
# global access_token
|
| 41 |
-
# global ASR_name
|
| 42 |
-
# global LLM_name
|
| 43 |
-
# global TTS_name
|
| 44 |
-
# global ASR_options
|
| 45 |
-
# global LLM_options
|
| 46 |
-
# global TTS_options
|
| 47 |
-
# global Eval_options
|
| 48 |
-
# global upload_to_hub
|
| 49 |
-
# parser = argparse.ArgumentParser(description="Run the app with HF_TOKEN as a command-line argument.")
|
| 50 |
-
# parser.add_argument("--HF_TOKEN", required=True, help="Provide the Hugging Face token.")
|
| 51 |
-
# parser.add_argument("--asr_options", required=True, help="Provide the possible ASR options available to user.")
|
| 52 |
-
# parser.add_argument("--llm_options", required=True, help="Provide the possible LLM options available to user.")
|
| 53 |
-
# parser.add_argument("--tts_options", required=True, help="Provide the possible TTS options available to user.")
|
| 54 |
-
# parser.add_argument("--eval_options", required=True, help="Provide the possible automatic evaluation metrics available to user.")
|
| 55 |
-
# parser.add_argument("--default_asr_model", required=False, default="pyf98/owsm_ctc_v3.1_1B", help="Provide the default ASR model.")
|
| 56 |
-
# parser.add_argument("--default_llm_model", required=False, default="meta-llama/Llama-3.2-1B-Instruct", help="Provide the default ASR model.")
|
| 57 |
-
# parser.add_argument("--default_tts_model", required=False, default="kan-bayashi/ljspeech_vits", help="Provide the default ASR model.")
|
| 58 |
-
# parser.add_argument("--upload_to_hub", required=False, default=None, help="Hugging Face dataset to upload user data")
|
| 59 |
-
# args = parser.parse_args()
|
| 60 |
-
# access_token=args.HF_TOKEN
|
| 61 |
-
# ASR_name=args.default_asr_model
|
| 62 |
-
# LLM_name=args.default_llm_model
|
| 63 |
-
# TTS_name=args.default_tts_model
|
| 64 |
-
# ASR_options=args.asr_options.split(",")
|
| 65 |
-
# LLM_options=args.llm_options.split(",")
|
| 66 |
-
# TTS_options=args.tts_options.split(",")
|
| 67 |
-
# Eval_options=args.eval_options.split(",")
|
| 68 |
-
# upload_to_hub=args.upload_to_hub
|
| 69 |
-
|
| 70 |
-
# read_args()
|
| 71 |
-
from huggingface_hub import HfApi
|
| 72 |
|
| 73 |
-
|
| 74 |
-
|
| 75 |
-
|
| 76 |
-
import gradio as gr
|
| 77 |
-
|
| 78 |
-
|
| 79 |
-
import numpy as np
|
| 80 |
-
|
| 81 |
-
chat = Chat(2)
|
| 82 |
-
chat.init_chat({"role": "system", "content": "You are a helpful and friendly AI assistant. The user is talking to you with their voice and you should respond in a conversational style. You are polite, respectful, and aim to provide concise and complete responses of less than 15 words."})
|
| 83 |
-
user_role = "user"
|
| 84 |
|
| 85 |
-
|
| 86 |
-
|
| 87 |
-
|
| 88 |
-
client=None
|
| 89 |
-
|
| 90 |
-
latency_ASR=0.0
|
| 91 |
-
latency_LM=0.0
|
| 92 |
-
latency_TTS=0.0
|
| 93 |
-
|
| 94 |
-
text_str=""
|
| 95 |
-
asr_output_str=""
|
| 96 |
-
vad_output=None
|
| 97 |
audio_output = None
|
| 98 |
audio_output1 = None
|
| 99 |
-
LLM_response_arr=[]
|
| 100 |
-
total_response_arr=[]
|
| 101 |
-
|
| 102 |
-
|
| 103 |
-
|
| 104 |
-
if TTS_curr_name is not None:
|
| 105 |
-
if option==TTS_curr_name:
|
| 106 |
-
return
|
| 107 |
-
yield gr.Textbox(visible=False),gr.Textbox(visible=False),gr.Audio(visible=False)
|
| 108 |
-
global text2speech
|
| 109 |
-
TTS_curr_name=option
|
| 110 |
-
tag = option
|
| 111 |
-
if tag=="ChatTTS":
|
| 112 |
-
text2speech = ChatTTSModel()
|
| 113 |
-
else:
|
| 114 |
-
text2speech = ESPnetTTSModel(tag)
|
| 115 |
-
text2speech.warmup()
|
| 116 |
-
yield gr.Textbox(visible=True),gr.Textbox(visible=True),gr.Audio(visible=True)
|
| 117 |
-
|
| 118 |
-
def handle_LLM_selection(option):
|
| 119 |
-
global LLM_curr_name
|
| 120 |
-
if LLM_curr_name is not None:
|
| 121 |
-
if option==LLM_curr_name:
|
| 122 |
-
return
|
| 123 |
-
yield gr.Textbox(visible=False),gr.Textbox(visible=False),gr.Audio(visible=False)
|
| 124 |
-
global LM_pipe
|
| 125 |
-
LLM_curr_name=option
|
| 126 |
-
LM_pipe = HuggingFaceLLM(access_token=access_token,tag = option)
|
| 127 |
-
LM_pipe.warmup()
|
| 128 |
-
yield gr.Textbox(visible=True),gr.Textbox(visible=True),gr.Audio(visible=True)
|
| 129 |
-
|
| 130 |
-
def handle_ASR_selection(option):
|
| 131 |
-
global ASR_curr_name
|
| 132 |
-
if option=="librispeech_asr":
|
| 133 |
-
option="espnet/simpleoier_librispeech_asr_train_asr_conformer7_wavlm_large_raw_en_bpe5000_sp"
|
| 134 |
-
if ASR_curr_name is not None:
|
| 135 |
-
if option==ASR_curr_name:
|
| 136 |
-
return
|
| 137 |
-
yield gr.Textbox(visible=False),gr.Textbox(visible=False),gr.Audio(visible=False)
|
| 138 |
-
global s2t
|
| 139 |
-
ASR_curr_name=option
|
| 140 |
-
if option=="espnet/owsm_v3.1_ebf":
|
| 141 |
-
s2t = OWSMModel()
|
| 142 |
-
elif option=="espnet/simpleoier_librispeech_asr_train_asr_conformer7_wavlm_large_raw_en_bpe5000_sp":
|
| 143 |
-
s2t = ESPnetASRModel(tag=option)
|
| 144 |
-
elif option=="whisper":
|
| 145 |
-
s2t = WhisperASRModel()
|
| 146 |
-
else:
|
| 147 |
-
s2t = OWSMCTCModel(tag=option)
|
| 148 |
|
| 149 |
-
|
| 150 |
-
|
|
|
|
| 151 |
|
| 152 |
-
def handle_eval_selection(
|
|
|
| 153 |
global LLM_response_arr
|
| 154 |
global total_response_arr
|
| 155 |
-
yield (option,gr.Textbox(visible=True))
|
| 156 |
-
if option=="Latency":
|
| 157 |
-
text=
|
| 158 |
-
|
| 159 |
-
|
| 160 |
-
|
| 161 |
-
|
| 162 |
-
yield (None,
|
| 163 |
-
elif option=="
|
| 164 |
-
yield (None,
|
| 165 |
-
elif option=="
|
| 166 |
-
yield (None,
|
| 167 |
-
|
| 168 |
-
|
|
|
| 169 |
global LLM_response_arr
|
| 170 |
global total_response_arr
|
| 171 |
-
yield (option,gr.Textbox(visible=True))
|
| 172 |
-
if option=="Latency":
|
| 173 |
-
text=f"Total Latency: {latency_TTS:.2f}"
|
| 174 |
-
yield (None,text)
|
| 175 |
-
elif option=="TTS Intelligibility":
|
| 176 |
-
yield (None,handle_espnet_TTS_intelligibility(TTS_audio_output,LLM_Output))
|
| 177 |
-
elif option=="TTS Speech Quality":
|
| 178 |
-
yield (None,TTS_psuedomos(TTS_audio_output))
|
| 179 |
-
elif option=="Text Dialog Metrics":
|
| 180 |
-
yield (None,perplexity(LLM_Output.replace("\n"," "))+vert(LLM_response_arr))
|
| 181 |
-
|
| 182 |
-
|
| 183 |
-
global client
|
| 184 |
-
global LM_pipe
|
| 185 |
-
global s2t
|
| 186 |
-
global text2speech
|
| 187 |
-
yield (gr.Radio(visible=False),gr.Radio(visible=False),gr.Radio(visible=False),gr.Radio(visible=False), gr.Textbox(visible=False),gr.Textbox(visible=False),gr.Audio(visible=False),gr.Radio(visible=False),gr.Radio(visible=False))
|
| 188 |
-
if option=="Cascaded":
|
| 189 |
-
client=None
|
| 190 |
-
for _ in handle_selection(TTS_radio):
|
| 191 |
-
continue
|
| 192 |
-
for _ in handle_ASR_selection(ASR_radio):
|
| 193 |
-
continue
|
| 194 |
-
for _ in handle_LLM_selection(LLM_radio):
|
| 195 |
-
continue
|
| 196 |
-
yield (gr.Radio(visible=True),gr.Radio(visible=True),gr.Radio(visible=True),gr.Radio(visible=False),gr.Textbox(visible=True),gr.Textbox(visible=True),gr.Audio(visible=True),gr.Radio(visible=True, interactive=True),gr.Radio(visible=False))
|
| 197 |
else:
|
| 198 |
-
|
| 199 |
-
|
| 200 |
-
LM_pipe=None
|
| 201 |
-
global ASR_curr_name
|
| 202 |
-
global LLM_curr_name
|
| 203 |
-
global TTS_curr_name
|
| 204 |
-
ASR_curr_name=None
|
| 205 |
-
LLM_curr_name=None
|
| 206 |
-
TTS_curr_name=None
|
| 207 |
-
handle_E2E_selection()
|
| 208 |
-
yield (gr.Radio(visible=False),gr.Radio(visible=False),gr.Radio(visible=False),gr.Radio(visible=True),gr.Textbox(visible=True),gr.Textbox(visible=True),gr.Audio(visible=True),gr.Radio(visible=False),gr.Radio(visible=True, interactive=True))
|
| 209 |
-
|
| 210 |
-
|
| 211 |
-
def handle_E2E_selection():
|
| 212 |
-
global client
|
| 213 |
-
if client is None:
|
| 214 |
-
client = MiniOmniE2EModel()
|
| 215 |
-
client.warmup()
|
| 216 |
|
| 217 |
def start_warmup():
|
| 218 |
-
|
| 219 |
-
|
| 220 |
-
|
| 221 |
-
|
| 222 |
-
|
| 223 |
-
|
| 224 |
-
|
| 225 |
-
|
| 226 |
-
|
| 227 |
-
|
| 228 |
-
|
| 229 |
-
|
| 230 |
-
|
| 231 |
-
|
| 232 |
-
|
| 233 |
-
|
| 234 |
-
|
| 235 |
-
|
| 236 |
-
|
| 237 |
-
|
| 238 |
-
|
| 239 |
-
|
|
|
| 240 |
continue
|
| 241 |
-
for _ in handle_ASR_selection(ASR_name):
|
| 242 |
continue
|
| 243 |
-
for _ in handle_LLM_selection(LLM_name):
|
| 244 |
continue
|
| 245 |
-
dummy_input =
|
|
|
|
| 246 |
(3000),
|
| 247 |
dtype=getattr(torch, "float16"),
|
| 248 |
device="cpu",
|
| 249 |
-
|
| 250 |
-
|
|
|
|
|
|
|
|
|
|
| 251 |
for opt in Eval_options:
|
| 252 |
handle_eval_selection(opt, dummy_input, dummy_text, dummy_input, dummy_text)
|
| 253 |
|
| 254 |
-
start_warmup()
|
| 255 |
-
vad_model=WebrtcVADModel()
|
| 256 |
|
| 257 |
-
callback = gr.CSVLogger()
|
| 258 |
-
start_record_time=None
|
| 259 |
-
enable_btn = gr.Button(interactive=True, visible=True)
|
| 260 |
-
disable_btn = gr.Button(interactive=False, visible=False)
|
| 261 |
def flash_buttons():
|
|
|
|
|
|
|
|
|
|
| 262 |
btn_updates = (enable_btn,) * 8
|
| 263 |
-
|
| 264 |
-
|
| 265 |
-
|
| 266 |
-
|
| 267 |
-
|
| 268 |
-
|
| 269 |
-
|
| 270 |
-
|
| 271 |
-
|
| 272 |
-
|
| 273 |
-
|
| 274 |
-
|
| 275 |
-
|
| 276 |
-
return ip
|
| 277 |
-
|
| 278 |
-
|
| 279 |
-
def vote_last_response(vote_type, request: gr.Request):
|
| 280 |
-
with open("save_dict.json", "a") as fout:
|
| 281 |
-
data = {
|
| 282 |
-
"tstamp": round(time.time(), 4),
|
| 283 |
-
"type": vote_type,
|
| 284 |
-
"ip": get_ip(request),
|
| 285 |
-
}
|
| 286 |
-
fout.write(json.dumps(data) + "\n")
|
| 287 |
-
|
| 288 |
-
|
| 289 |
-
def natural_vote1_last_response(
|
| 290 |
-
request: gr.Request
|
| 291 |
):
|
| 292 |
-
|
| 293 |
-
|
| 294 |
-
|
| 295 |
-
|
| 296 |
-
|
| 297 |
-
|
| 298 |
-
|
| 299 |
-
|
| 300 |
-
|
| 301 |
-
|
| 302 |
-
|
| 303 |
-
|
| 304 |
-
|
| 305 |
-
|
| 306 |
-
|
| 307 |
-
|
| 308 |
-
|
| 309 |
-
|
| 310 |
-
|
| 311 |
-
|
| 312 |
-
|
| 313 |
-
|
| 314 |
-
|
| 315 |
-
|
| 316 |
-
|
| 317 |
-
|
| 318 |
-
|
| 319 |
-
|
| 320 |
-
|
| 321 |
-
|
| 322 |
-
|
| 323 |
-
|
| 324 |
-
|
| 325 |
-
|
| 326 |
-
|
| 327 |
-
ip_address1=get_ip(request)
|
| 328 |
-
print(f"Partially Relevant (voted). ip: {ip_address1}")
|
| 329 |
-
return ("Partially Relevant",ip_address1,)+(disable_btn,) * 4
|
| 330 |
-
|
| 331 |
-
def relevant_vote3_last_response(
|
| 332 |
-
request: gr.Request
|
| 333 |
-
):
|
| 334 |
-
ip_address1=get_ip(request)
|
| 335 |
-
print(f"Slightly Irrelevant (voted). ip: {ip_address1}")
|
| 336 |
-
return ("Slightly Irrelevant",ip_address1,)+(disable_btn,) * 4
|
| 337 |
-
|
| 338 |
-
def relevant_vote4_last_response(
|
| 339 |
-
request: gr.Request
|
| 340 |
-
):
|
| 341 |
-
ip_address1=get_ip(request)
|
| 342 |
-
print(f"Completely Irrelevant (voted). ip: {ip_address1}")
|
| 343 |
-
return ("Completely Irrelevant",ip_address1,)+(disable_btn,) * 4
|
| 344 |
-
|
| 345 |
-
import json
|
| 346 |
-
import time
|
| 347 |
-
|
| 348 |
-
def transcribe(stream, new_chunk, TTS_option, ASR_option, LLM_option, type_option):
|
| 349 |
sr, y = new_chunk
|
| 350 |
global text_str
|
| 351 |
global chat
|
|
@@ -364,219 +399,548 @@ def transcribe(stream, new_chunk, TTS_option, ASR_option, LLM_option, type_optio
|
|
| 364 |
global total_response_arr
|
| 365 |
if stream is None:
|
| 366 |
# Handle user refresh
|
| 367 |
-
|
| 368 |
-
|
|
|
| 369 |
gr.Info("The models are being reloaded due to a browser refresh.")
|
| 370 |
-
yield (stream,asr_output_box,text_box,audio_box,gr.Audio(visible=False))
|
| 371 |
-
stream=y
|
| 372 |
-
|
| 373 |
-
text_str=""
|
| 374 |
audio_output = None
|
| 375 |
audio_output1 = None
|
| 376 |
else:
|
| 377 |
-
stream=np.concatenate((stream,y))
|
| 378 |
-
|
| 379 |
-
|
| 380 |
-
|
| 381 |
-
|
| 382 |
-
|
| 383 |
-
|
| 384 |
-
|
| 385 |
-
|
| 386 |
-
|
| 387 |
-
|
| 388 |
-
|
| 389 |
-
|
| 390 |
-
|
| 391 |
-
|
| 392 |
-
|
| 393 |
-
|
| 394 |
-
|
| 395 |
-
|
| 396 |
-
|
| 397 |
-
|
| 398 |
-
|
| 399 |
-
|
| 400 |
-
|
| 401 |
-
|
| 402 |
-
|
| 403 |
-
|
| 404 |
-
|
| 405 |
-
|
| 406 |
-
|
| 407 |
-
|
| 408 |
-
|
| 409 |
-
chat.append({"role": user_role, "content": prompt})
|
| 410 |
-
chat_messages = chat.to_list()
|
| 411 |
-
generated_text = LM_pipe(chat_messages)
|
| 412 |
-
start_TTS_time=time.time()
|
| 413 |
-
latency_LM=(start_TTS_time - start_LM_time)
|
| 414 |
-
|
| 415 |
-
chat.append({"role": "assistant", "content": generated_text})
|
| 416 |
-
text_str=generated_text
|
| 417 |
-
audio_output=text2speech(text_str)
|
| 418 |
-
latency_TTS=(time.time() - start_TTS_time)
|
| 419 |
-
audio_output1=(orig_sr,stream)
|
| 420 |
-
stream=y
|
| 421 |
-
LLM_response_arr.append(text_str.replace("\n"," "))
|
| 422 |
-
total_response_arr.append(text_str.replace("\n"," "))
|
| 423 |
-
text_str1=text_str
|
| 424 |
-
if ((text_str!="") and (start_record_time is None)):
|
| 425 |
-
start_record_time=time.time()
|
| 426 |
elif start_record_time is not None:
|
| 427 |
-
current_record_time=time.time()
|
| 428 |
-
if current_record_time-start_record_time>300:
|
| 429 |
-
gr.Info(
|
| 430 |
-
|
|
|
|
|
| 431 |
if upload_to_hub is not None:
|
| 432 |
api.upload_folder(
|
| 433 |
folder_path="flagged_data_points",
|
| 434 |
-
path_in_repo="checkpoint_"+str(start_record_time),
|
| 435 |
repo_id=upload_to_hub,
|
| 436 |
repo_type="dataset",
|
| 437 |
token=access_token,
|
| 438 |
)
|
| 439 |
-
chat.buffer=[
|
| 440 |
-
text_str=""
|
| 441 |
audio_output = None
|
| 442 |
audio_output1 = None
|
| 443 |
asr_output_str = ""
|
| 444 |
start_record_time = None
|
| 445 |
-
LLM_response_arr=[]
|
| 446 |
-
total_response_arr=[]
|
| 447 |
-
shutil.rmtree(
|
| 448 |
os.mkdir("flagged_data_points")
|
| 449 |
-
yield (stream,asr_output_str,text_str1, audio_output, audio_output1)
|
| 450 |
-
yield stream,gr.Textbox(visible=True),gr.Textbox(visible=True),gr.Audio(
|
| 451 |
-
|
| 452 |
-
|
| 453 |
|
|
|
|
| 454 |
|
|
| 455 |
with gr.Blocks(
|
| 456 |
-
|
| 457 |
-
|
| 458 |
-
|
| 459 |
-
|
| 460 |
-
|
| 461 |
-
|
| 462 |
-
|
| 463 |
-
|
| 464 |
-
|
| 465 |
-
|
| 466 |
-
|
| 467 |
-
|
| 468 |
-
|
| 469 |
-
|
| 470 |
-
|
| 471 |
-
|
| 472 |
-
|
| 473 |
-
|
| 474 |
-
|
| 475 |
-
|
| 476 |
-
|
| 477 |
-
|
| 478 |
-
|
| 479 |
-
|
| 480 |
-
|
| 481 |
-
|
| 482 |
-
|
| 483 |
-
|
| 484 |
-
|
| 485 |
-
|
| 486 |
-
|
| 487 |
-
|
| 488 |
-
|
| 489 |
-
|
| 490 |
-
|
| 491 |
-
|
| 492 |
-
|
| 493 |
-
|
| 494 |
-
|
| 495 |
-
)
|
| 496 |
-
with gr.Row():
|
| 497 |
-
natural_btn1 = gr.Button(
|
| 498 |
-
value="Very Natural", visible=False, interactive=False, scale=1
|
| 499 |
-
)
|
| 500 |
-
natural_btn2 = gr.Button(
|
| 501 |
-
value="Somewhat Awkward", visible=False, interactive=False, scale=1
|
| 502 |
-
)
|
| 503 |
-
natural_btn3 = gr.Button(value="Very Awkward", visible=False, interactive=False, scale=1)
|
| 504 |
-
natural_btn4 = gr.Button(
|
| 505 |
-
value="Unnatural", visible=False, interactive=False, scale=1
|
| 506 |
-
)
|
| 507 |
-
with gr.Row():
|
| 508 |
-
relevant_btn1 = gr.Button(
|
| 509 |
-
value="Highly Relevant", visible=False, interactive=False, scale=1
|
| 510 |
-
)
|
| 511 |
-
relevant_btn2 = gr.Button(
|
| 512 |
-
value="Partially Relevant", visible=False, interactive=False, scale=1
|
| 513 |
-
)
|
| 514 |
-
relevant_btn3 = gr.Button(value="Slightly Irrelevant", visible=False, interactive=False, scale=1)
|
| 515 |
-
relevant_btn4 = gr.Button(
|
| 516 |
-
value= "Completely Irrelevant", visible=False, interactive=False, scale=1
|
| 517 |
-
)
|
| 518 |
-
with gr.Column(scale=1):
|
| 519 |
-
output_audio = gr.Audio(label="Output", interactive=False, autoplay=True, visible=True)
|
| 520 |
-
output_audio1 = gr.Audio(label="Output1", autoplay=False, visible=False)
|
| 521 |
-
output_asr_text = gr.Textbox(label="ASR output", interactive=False)
|
| 522 |
-
output_text = gr.Textbox(label="LLM output", interactive=False)
|
| 523 |
-
eval_radio = gr.Radio(
|
| 524 |
-
choices=["Latency", "TTS Intelligibility", "TTS Speech Quality", "ASR WER","Text Dialog Metrics"],
|
| 525 |
-
label="Choose Evaluation metrics:",
|
| 526 |
)
|
| 527 |
-
|
| 528 |
-
|
| 529 |
-
|
|
|
| 530 |
visible=False,
|
| 531 |
)
|
| 532 |
-
|
| 533 |
-
|
| 534 |
-
|
| 535 |
-
|
| 536 |
-
|
| 537 |
-
|
| 538 |
-
|
| 539 |
-
|
| 540 |
-
|
| 541 |
-
|
| 542 |
-
|
| 543 |
-
|
| 544 |
-
|
| 545 |
-
|
| 546 |
-
|
| 547 |
-
|
| 548 |
-
|
| 549 |
-
|
| 550 |
-
|
| 551 |
-
|
| 552 |
-
|
| 553 |
-
|
| 554 |
-
|
| 555 |
-
|
| 556 |
-
|
| 557 |
-
|
| 558 |
-
|
| 559 |
-
|
| 560 |
-
|
| 561 |
-
|
| 562 |
-
|
| 563 |
-
|
| 564 |
-
|
| 565 |
-
|
| 566 |
-
|
| 567 |
-
|
| 568 |
-
|
| 569 |
-
|
| 570 |
-
|
| 571 |
-
|
| 572 |
-
|
| 573 |
-
|
| 574 |
-
|
| 575 |
-
|
| 576 |
-
|
| 577 |
-
|
| 578 |
-
|
| 579 |
-
|
| 580 |
-
|
| 581 |
-
|
|
|
|
|
|
|
|
| 582 |
demo.launch(share=True)
|
|
|
|
|
|
| 5 |
with open('versa.sh', 'rb') as file:
|
| 6 |
script = file.read()
|
| 7 |
rc = call(script, shell=True)
|
| 8 |
+
|
| 9 |
import os
|
| 10 |
import shutil
|
| 11 |
+
import time
|
| 12 |
+
from typing import Generator, Optional, Tuple
|
| 13 |
+
|
| 14 |
+
import gradio as gr
|
| 15 |
+
import nltk
|
| 16 |
+
import numpy as np
|
|
|
| 17 |
import torch
|
| 18 |
+
from huggingface_hub import HfApi
|
| 19 |
+
from pyscripts.utils.dialog_eval.ASR_WER import handle_espnet_ASR_WER
|
| 20 |
+
from pyscripts.utils.dialog_eval.human_feedback import (
|
| 21 |
+
natural_vote1_last_response,
|
| 22 |
+
natural_vote2_last_response,
|
| 23 |
+
natural_vote3_last_response,
|
| 24 |
+
natural_vote4_last_response,
|
| 25 |
+
relevant_vote1_last_response,
|
| 26 |
+
relevant_vote2_last_response,
|
| 27 |
+
relevant_vote3_last_response,
|
| 28 |
+
relevant_vote4_last_response,
|
| 29 |
+
)
|
| 30 |
+
from pyscripts.utils.dialog_eval.LLM_Metrics import (
|
| 31 |
+
DialoGPT_perplexity,
|
| 32 |
+
bert_score,
|
| 33 |
+
perplexity,
|
| 34 |
+
vert,
|
| 35 |
+
)
|
| 36 |
+
from pyscripts.utils.dialog_eval.TTS_intelligibility import (
|
| 37 |
+
handle_espnet_TTS_intelligibility,
|
| 38 |
+
)
|
| 39 |
+
from pyscripts.utils.dialog_eval.TTS_speech_quality import TTS_psuedomos
|
| 40 |
+
|
| 41 |
+
from espnet2.sds.espnet_model import ESPnetSDSModelInterface
|
| 42 |
+
|
| 43 |
+
# ------------------------
|
| 44 |
+
# Hyperparameters
|
| 45 |
+
# ------------------------
|
| 46 |
|
| 47 |
access_token = os.environ.get("HF_TOKEN")
|
| 48 |
ASR_name="pyf98/owsm_ctc_v3.1_1B"
|
| 49 |
LLM_name="meta-llama/Llama-3.2-1B-Instruct"
|
| 50 |
TTS_name="kan-bayashi/ljspeech_vits"
|
| 51 |
+
ASR_options="pyf98/owsm_ctc_v3.1_1B,espnet/owsm_ctc_v3.2_ft_1B,espnet/owsm_v3.1_ebf,librispeech_asr,whisper-large".split(",")
|
| 52 |
LLM_options="meta-llama/Llama-3.2-1B-Instruct,HuggingFaceTB/SmolLM2-1.7B-Instruct".split(",")
|
| 53 |
TTS_options="kan-bayashi/ljspeech_vits,kan-bayashi/libritts_xvector_vits,kan-bayashi/vctk_multi_spk_vits,ChatTTS".split(",")
|
| 54 |
Eval_options="Latency,TTS Intelligibility,TTS Speech Quality,ASR WER,Text Dialog Metrics"
|
| 55 |
upload_to_hub=None
|
| 56 |
+
dialogue_model = ESPnetSDSModelInterface(
|
| 57 |
+
ASR_name, LLM_name, TTS_name, "Cascaded", access_token
|
| 58 |
+
)
|
| 59 |
ASR_curr_name=None
|
| 60 |
LLM_curr_name=None
|
| 61 |
TTS_curr_name=None
|
|
|
| 62 |
|
| 63 |
+
latency_ASR = 0.0
|
| 64 |
+
latency_LM = 0.0
|
| 65 |
+
latency_TTS = 0.0
|
|
|
| 66 |
|
| 67 |
+
text_str = ""
|
| 68 |
+
asr_output_str = ""
|
| 69 |
+
vad_output = None
|
|
|
| 70 |
audio_output = None
|
| 71 |
audio_output1 = None
|
| 72 |
+
LLM_response_arr = []
|
| 73 |
+
total_response_arr = []
|
| 74 |
+
callback = gr.CSVLogger()
|
| 75 |
+
start_record_time = None
|
| 76 |
+
enable_btn = gr.Button(interactive=True, visible=True)
|
|
|
| 77 |
|
| 78 |
+
# ------------------------
|
| 79 |
+
# Function Definitions
|
| 80 |
+
# ------------------------
|
| 81 |
|
| 82 |
+
def handle_eval_selection(
|
| 83 |
+
option: str,
|
| 84 |
+
TTS_audio_output: str,
|
| 85 |
+
LLM_Output: str,
|
| 86 |
+
ASR_audio_output: str,
|
| 87 |
+
ASR_transcript: str,
|
| 88 |
+
):
|
| 89 |
+
"""
|
| 90 |
+
Handles the evaluation of a selected metric based on
|
| 91 |
+
user input and provided outputs.
|
| 92 |
+
|
| 93 |
+
This function evaluates different aspects of a
|
| 94 |
+
cascaded conversational AI pipeline, such as:
|
| 95 |
+
Latency, TTS intelligibility, TTS speech quality,
|
| 96 |
+
ASR WER, and text dialog metrics.
|
| 97 |
+
It is designed to integrate with Gradio via
|
| 98 |
+
multiple yield statements,
|
| 99 |
+
allowing updates to be displayed in real time.
|
| 100 |
+
|
| 101 |
+
Parameters:
|
| 102 |
+
----------
|
| 103 |
+
option : str
|
| 104 |
+
The evaluation metric selected by the user.
|
| 105 |
+
Supported options include:
|
| 106 |
+
- "Latency"
|
| 107 |
+
- "TTS Intelligibility"
|
| 108 |
+
- "TTS Speech Quality"
|
| 109 |
+
- "ASR WER"
|
| 110 |
+
- "Text Dialog Metrics"
|
| 111 |
+
TTS_audio_output : np.ndarray
|
| 112 |
+
The audio output generated by the TTS module for evaluation.
|
| 113 |
+
LLM_Output : str
|
| 114 |
+
The text output generated by the LLM module for evaluation.
|
| 115 |
+
ASR_audio_output : np.ndarray
|
| 116 |
+
The audio input/output used for ASR evaluation.
|
| 117 |
+
ASR_transcript : str
|
| 118 |
+
The transcript generated by the ASR module for evaluation.
|
| 119 |
+
|
| 120 |
+
Returns:
|
| 121 |
+
-------
|
| 122 |
+
str
|
| 123 |
+
A string representation of the evaluation results.
|
| 124 |
+
The specific result depends on the selected evaluation metric:
|
| 125 |
+
- "Latency": Latencies of ASR, LLM, and TTS modules.
|
| 126 |
+
- "TTS Intelligibility": A range of scores indicating how intelligible
|
| 127 |
+
the TTS audio output is based on different reference ASR models.
|
| 128 |
+
- "TTS Speech Quality": A range of scores representing the
|
| 129 |
+
speech quality of the TTS audio output.
|
| 130 |
+
- "ASR WER": The Word Error Rate (WER) of the ASR output
|
| 131 |
+
based on different judge ASR models.
|
| 132 |
+
- "Text Dialog Metrics": A combination of perplexity,
|
| 133 |
+
diversity metrics, and relevance scores for the dialog.
|
| 134 |
+
|
| 135 |
+
Raises:
|
| 136 |
+
------
|
| 137 |
+
ValueError
|
| 138 |
+
If the `option` parameter does not match any supported evaluation metric.
|
| 139 |
+
|
| 140 |
+
Example:
|
| 141 |
+
-------
|
| 142 |
+
>>> result = handle_eval_selection(
|
| 143 |
+
option="Latency",
|
| 144 |
+
TTS_audio_output=audio_array,
|
| 145 |
+
LLM_Output="Generated response",
|
| 146 |
+
ASR_audio_output=audio_input,
|
| 147 |
+
ASR_transcript="Expected transcript"
|
| 148 |
+
)
|
| 149 |
+
>>> print(result)
|
| 150 |
+
"ASR Latency: 0.14
|
| 151 |
+
LLM Latency: 0.42
|
| 152 |
+
TTS Latency: 0.21"
|
| 153 |
+
"""
|
| 154 |
global LLM_response_arr
|
| 155 |
global total_response_arr
|
| 156 |
+
yield (option, gr.Textbox(visible=True))
|
| 157 |
+
if option == "Latency":
|
| 158 |
+
text = (
|
| 159 |
+
f"ASR Latency: {latency_ASR:.2f}\n"
|
| 160 |
+
f"LLM Latency: {latency_LM:.2f}\n"
|
| 161 |
+
f"TTS Latency: {latency_TTS:.2f}"
|
| 162 |
+
)
|
| 163 |
+
yield (None, text)
|
| 164 |
+
elif option == "TTS Intelligibility":
|
| 165 |
+
yield (None, handle_espnet_TTS_intelligibility(TTS_audio_output, LLM_Output))
|
| 166 |
+
elif option == "TTS Speech Quality":
|
| 167 |
+
yield (None, TTS_psuedomos(TTS_audio_output))
|
| 168 |
+
elif option == "ASR WER":
|
| 169 |
+
yield (None, handle_espnet_ASR_WER(ASR_audio_output, ASR_transcript))
|
| 170 |
+
elif option == "Text Dialog Metrics":
|
| 171 |
+
yield (
|
| 172 |
+
None,
|
| 173 |
+
perplexity(LLM_Output.replace("\n", " "))
|
| 174 |
+
+ vert(LLM_response_arr)
|
| 175 |
+
+ bert_score(total_response_arr)
|
| 176 |
+
+ DialoGPT_perplexity(
|
| 177 |
+
ASR_transcript.replace("\n", " "), LLM_Output.replace("\n", " ")
|
| 178 |
+
),
|
| 179 |
+
)
|
| 180 |
+
elif option is None:
|
| 181 |
+
return
|
| 182 |
+
else:
|
| 183 |
+
raise ValueError(f"Unknown option: {option}")
|
| 184 |
+
|
| 185 |
+
|
| 186 |
+
def handle_eval_selection_E2E(
|
| 187 |
+
option: str,
|
| 188 |
+
TTS_audio_output: str,
|
| 189 |
+
LLM_Output: str,
|
| 190 |
+
):
|
| 191 |
+
"""
|
| 192 |
+
Handles the evaluation of a selected metric based on user input
|
| 193 |
+
and provided outputs.
|
| 194 |
+
|
| 195 |
+
This function evaluates different aspects of an E2E
|
| 196 |
+
conversational AI model, such as:
|
| 197 |
+
Latency, TTS intelligibility, TTS speech quality, and
|
| 198 |
+
text dialog metrics.
|
| 199 |
+
It is designed to integrate with Gradio via
|
| 200 |
+
multiple yield statements,
|
| 201 |
+
allowing updates to be displayed in real time.
|
| 202 |
+
|
| 203 |
+
Parameters:
|
| 204 |
+
----------
|
| 205 |
+
option : str
|
| 206 |
+
The evaluation metric selected by the user.
|
| 207 |
+
Supported options include:
|
| 208 |
+
- "Latency"
|
| 209 |
+
- "TTS Intelligibility"
|
| 210 |
+
- "TTS Speech Quality"
|
| 211 |
+
- "Text Dialog Metrics"
|
| 212 |
+
TTS_audio_output : np.ndarray
|
| 213 |
+
The audio output generated by the TTS module for evaluation.
|
| 214 |
+
LLM_Output : str
|
| 215 |
+
The text output generated by the LLM module for evaluation.
|
| 216 |
+
|
| 217 |
+
Returns:
|
| 218 |
+
-------
|
| 219 |
+
str
|
| 220 |
+
A string representation of the evaluation results.
|
| 221 |
+
The specific result depends on the selected evaluation metric:
|
| 222 |
+
- "Latency": Latency of the entire system.
|
| 223 |
+
- "TTS Intelligibility": A range of scores indicating how intelligible the
|
| 224 |
+
TTS audio output is based on different reference ASR models.
|
| 225 |
+
- "TTS Speech Quality": A range of scores representing the
|
| 226 |
+
speech quality of the TTS audio output.
|
| 227 |
+
- "Text Dialog Metrics": A combination of perplexity and
|
| 228 |
+
diversity metrics for the dialog.
|
| 229 |
+
|
| 230 |
+
Raises:
|
| 231 |
+
------
|
| 232 |
+
ValueError
|
| 233 |
+
If the `option` parameter does not match any supported evaluation metric.
|
| 234 |
+
|
| 235 |
+
Example:
|
| 236 |
+
-------
|
| 237 |
+
>>> result = handle_eval_selection(
|
| 238 |
+
option="Latency",
|
| 239 |
+
TTS_audio_output=audio_array,
|
| 240 |
+
LLM_Output="Generated response",
|
| 241 |
+
)
|
| 242 |
+
>>> print(result)
|
| 243 |
+
"Total Latency: 2.34"
|
| 244 |
+
"""
|
| 245 |
global LLM_response_arr
|
| 246 |
global total_response_arr
|
| 247 |
+
yield (option, gr.Textbox(visible=True))
|
| 248 |
+
if option == "Latency":
|
| 249 |
+
text = f"Total Latency: {latency_TTS:.2f}"
|
| 250 |
+
yield (None, text)
|
| 251 |
+
elif option == "TTS Intelligibility":
|
| 252 |
+
yield (None, handle_espnet_TTS_intelligibility(TTS_audio_output, LLM_Output))
|
| 253 |
+
elif option == "TTS Speech Quality":
|
| 254 |
+
yield (None, TTS_psuedomos(TTS_audio_output))
|
| 255 |
+
elif option == "Text Dialog Metrics":
|
| 256 |
+
yield (None, perplexity(LLM_Output.replace("\n", " ")) + vert(LLM_response_arr))
|
| 257 |
+
elif option is None:
|
| 258 |
+
return
|
|
|
| 259 |
else:
|
| 260 |
+
raise ValueError(f"Unknown option: {option}")
|
| 261 |
+
|
|
|
| 262 |
|
| 263 |
def start_warmup():
|
| 264 |
+
"""
|
| 265 |
+
Initializes and warms up the dialogue and evaluation model.
|
| 266 |
+
|
| 267 |
+
This function is designed to ensure that all
|
| 268 |
+
components of the dialogue model are pre-loaded
|
| 269 |
+
and ready for execution, avoiding delays during runtime.
|
| 270 |
+
"""
|
| 271 |
+
global dialogue_model
|
| 272 |
+
global ASR_options
|
| 273 |
+
global LLM_options
|
| 274 |
+
global TTS_options
|
| 275 |
+
global ASR_name
|
| 276 |
+
global LLM_name
|
| 277 |
+
global TTS_name
|
| 278 |
+
for opt_count in range(len(ASR_options)):
|
| 279 |
+
opt = ASR_options[opt_count]
|
| 280 |
+
try:
|
| 281 |
+
for _ in dialogue_model.handle_ASR_selection(opt):
|
| 282 |
+
continue
|
| 283 |
+
except Exception:
|
| 284 |
+
print("Removing " + opt + " from ASR options since it cannot be loaded.")
|
| 285 |
+
ASR_options = ASR_options[:opt_count] + ASR_options[(opt_count + 1) :]
|
| 286 |
+
if opt == ASR_name:
|
| 287 |
+
ASR_name = ASR_options[0]
|
| 288 |
+
for opt_count in range(len(LLM_options)):
|
| 289 |
+
opt = LLM_options[opt_count]
|
| 290 |
+
try:
|
| 291 |
+
for _ in dialogue_model.handle_LLM_selection(opt):
|
| 292 |
+
continue
|
| 293 |
+
except Exception:
|
| 294 |
+
print("Removing " + opt + " from LLM options since it cannot be loaded.")
|
| 295 |
+
LLM_options = LLM_options[:opt_count] + LLM_options[(opt_count + 1) :]
|
| 296 |
+
if opt == LLM_name:
|
| 297 |
+
LLM_name = LLM_options[0]
|
| 298 |
+
for opt_count in range(len(TTS_options)):
|
| 299 |
+
opt = TTS_options[opt_count]
|
| 300 |
+
try:
|
| 301 |
+
for _ in dialogue_model.handle_TTS_selection(opt):
|
| 302 |
+
continue
|
| 303 |
+
except Exception:
|
| 304 |
+
print("Removing " + opt + " from TTS options since it cannot be loaded.")
|
| 305 |
+
TTS_options = TTS_options[:opt_count] + TTS_options[(opt_count + 1) :]
|
| 306 |
+
if opt == TTS_name:
|
| 307 |
+
TTS_name = TTS_options[0]
|
| 308 |
+
dialogue_model.handle_E2E_selection()
|
| 309 |
+
dialogue_model.client = None
|
| 310 |
+
for _ in dialogue_model.handle_TTS_selection(TTS_name):
|
| 311 |
continue
|
| 312 |
+
for _ in dialogue_model.handle_ASR_selection(ASR_name):
|
| 313 |
continue
|
| 314 |
+
for _ in dialogue_model.handle_LLM_selection(LLM_name):
|
| 315 |
continue
|
| 316 |
+
dummy_input = (
|
| 317 |
+
torch.randn(
|
| 318 |
(3000),
|
| 319 |
dtype=getattr(torch, "float16"),
|
| 320 |
device="cpu",
|
| 321 |
+
)
|
| 322 |
+
.cpu()
|
| 323 |
+
.numpy()
|
| 324 |
+
)
|
| 325 |
+
dummy_text = "This is dummy text"
|
| 326 |
for opt in Eval_options:
|
| 327 |
handle_eval_selection(opt, dummy_input, dummy_text, dummy_input, dummy_text)
|
| 328 |
|
|
|
|
|
|
|
| 329 |
|
|
|
|
|
|
|
|
|
|
|
|
|
| 330 |
def flash_buttons():
|
| 331 |
+
"""
|
| 332 |
+
Enables human feedback buttons after displaying system output.
|
| 333 |
+
"""
|
| 334 |
btn_updates = (enable_btn,) * 8
|
| 335 |
+
yield (
|
| 336 |
+
"",
|
| 337 |
+
"",
|
| 338 |
+
) + btn_updates
|
| 339 |
+
|
| 340 |
+
|
| 341 |
+
def transcribe(
|
| 342 |
+
stream: np.ndarray,
|
| 343 |
+
new_chunk: Tuple[int, np.ndarray],
|
| 344 |
+
TTS_option: str,
|
| 345 |
+
ASR_option: str,
|
| 346 |
+
LLM_option: str,
|
| 347 |
+
type_option: str,
|
|
|
|
|
| 348 |
):
|
| 349 |
+
"""
|
| 350 |
+
Processes and transcribes an audio stream in real-time.
|
| 351 |
+
|
| 352 |
+
This function handles the transcription of audio input
|
| 353 |
+
and its transformation through a cascaded
|
| 354 |
+
or E2E conversational AI system.
|
| 355 |
+
It dynamically updates the transcription, text generation,
|
| 356 |
+
and synthesized speech output, while managing global states and latencies.
|
| 357 |
+
|
| 358 |
+
Args:
|
| 359 |
+
stream: The current audio stream buffer.
|
| 360 |
+
`None` if the stream is being reset (e.g., after user refresh).
|
| 361 |
+
new_chunk: A tuple containing:
|
| 362 |
+
- `sr`: Sample rate of the new audio chunk.
|
| 363 |
+
- `y`: New audio data chunk.
|
| 364 |
+
TTS_option: Selected TTS model option.
|
| 365 |
+
ASR_option: Selected ASR model option.
|
| 366 |
+
LLM_option: Selected LLM model option.
|
| 367 |
+
type_option: Type of system ("Cascaded" or "E2E").
|
| 368 |
+
|
| 369 |
+
Yields:
|
| 370 |
+
Tuple[Optional[np.ndarray], Optional[str], Optional[str],
|
| 371 |
+
Optional[Tuple[int, np.ndarray]], Optional[Tuple[int, np.ndarray]]]:
|
| 372 |
+
A tuple containing:
|
| 373 |
+
- Updated stream buffer.
|
| 374 |
+
- ASR output text.
|
| 375 |
+
- Generated LLM output text.
|
| 376 |
+
- Audio output as a tuple of sample rate and audio waveform.
|
| 377 |
+
- User input audio as a tuple of sample rate and audio waveform.
|
| 378 |
+
|
| 379 |
+
Notes:
|
| 380 |
+
- Resets the session if the transcription exceeds 5 minutes.
|
| 381 |
+
- Updates the Gradio interface elements dynamically.
|
| 382 |
+
- Manages latencies.
|
| 383 |
+
"""
|
|
|
| 384 |
sr, y = new_chunk
|
| 385 |
global text_str
|
| 386 |
global chat
|
|
|
|
| 399 |
global total_response_arr
|
| 400 |
if stream is None:
|
| 401 |
# Handle user refresh
|
| 402 |
+
for (
|
| 403 |
+
_,
|
| 404 |
+
_,
|
| 405 |
+
_,
|
| 406 |
+
_,
|
| 407 |
+
asr_output_box,
|
| 408 |
+
text_box,
|
| 409 |
+
audio_box,
|
| 410 |
+
_,
|
| 411 |
+
_,
|
| 412 |
+
) in dialogue_model.handle_type_selection(
|
| 413 |
+
type_option, TTS_option, ASR_option, LLM_option
|
| 414 |
+
):
|
| 415 |
gr.Info("The models are being reloaded due to a browser refresh.")
|
| 416 |
+
yield (stream, asr_output_box, text_box, audio_box, gr.Audio(visible=False))
|
| 417 |
+
stream = y
|
| 418 |
+
text_str = ""
|
|
|
|
| 419 |
audio_output = None
|
| 420 |
audio_output1 = None
|
| 421 |
else:
|
| 422 |
+
stream = np.concatenate((stream, y))
|
| 423 |
+
(
|
| 424 |
+
asr_output_str,
|
| 425 |
+
text_str,
|
| 426 |
+
audio_output,
|
| 427 |
+
audio_output1,
|
| 428 |
+
latency_ASR,
|
| 429 |
+
latency_LM,
|
| 430 |
+
latency_TTS,
|
| 431 |
+
stream,
|
| 432 |
+
change,
|
| 433 |
+
) = dialogue_model(
|
| 434 |
+
y,
|
| 435 |
+
sr,
|
| 436 |
+
stream,
|
| 437 |
+
asr_output_str,
|
| 438 |
+
text_str,
|
| 439 |
+
audio_output,
|
| 440 |
+
audio_output1,
|
| 441 |
+
latency_ASR,
|
| 442 |
+
latency_LM,
|
| 443 |
+
latency_TTS,
|
| 444 |
+
)
|
| 445 |
+
text_str1 = text_str
|
| 446 |
+
if change:
|
| 447 |
+
print("Output changed")
|
| 448 |
+
if asr_output_str != "":
|
| 449 |
+
total_response_arr.append(asr_output_str.replace("\n", " "))
|
| 450 |
+
LLM_response_arr.append(text_str.replace("\n", " "))
|
| 451 |
+
total_response_arr.append(text_str.replace("\n", " "))
|
| 452 |
+
if (text_str != "") and (start_record_time is None):
|
| 453 |
+
start_record_time = time.time()
|
|
|
| 454 |
elif start_record_time is not None:
|
| 455 |
+
current_record_time = time.time()
|
| 456 |
+
if current_record_time - start_record_time > 300:
|
| 457 |
+
gr.Info(
|
| 458 |
+
"Conversations are limited to 5 minutes. "
|
| 459 |
+
"The session will restart in approximately 60 seconds. "
|
| 460 |
+
"Please wait for the demo to reset. "
|
| 461 |
+
"Close this message once you have read it.",
|
| 462 |
+
duration=None,
|
| 463 |
+
)
|
| 464 |
+
yield stream, gr.Textbox(visible=False), gr.Textbox(
|
| 465 |
+
visible=False
|
| 466 |
+
), gr.Audio(visible=False), gr.Audio(visible=False)
|
| 467 |
if upload_to_hub is not None:
|
| 468 |
api.upload_folder(
|
| 469 |
folder_path="flagged_data_points",
|
| 470 |
+
path_in_repo="checkpoint_" + str(start_record_time),
|
| 471 |
repo_id=upload_to_hub,
|
| 472 |
repo_type="dataset",
|
| 473 |
token=access_token,
|
| 474 |
)
|
| 475 |
+
dialogue_model.chat.buffer = []
|
| 476 |
+
text_str = ""
|
| 477 |
audio_output = None
|
| 478 |
audio_output1 = None
|
| 479 |
asr_output_str = ""
|
| 480 |
start_record_time = None
|
| 481 |
+
LLM_response_arr = []
|
| 482 |
+
total_response_arr = []
|
| 483 |
+
shutil.rmtree("flagged_data_points")
|
| 484 |
os.mkdir("flagged_data_points")
|
| 485 |
+
yield (stream, asr_output_str, text_str1, audio_output, audio_output1)
|
| 486 |
+
yield stream, gr.Textbox(visible=True), gr.Textbox(visible=True), gr.Audio(
|
| 487 |
+
visible=True
|
| 488 |
+
), gr.Audio(visible=False)
|
| 489 |
|
| 490 |
+
yield (stream, asr_output_str, text_str1, audio_output, audio_output1)
|
| 491 |
|
| 492 |
+
|
| 493 |
+
# ------------------------
|
| 494 |
+
# Executable Script
|
| 495 |
+
# ------------------------
|
| 496 |
+
api = HfApi()
|
| 497 |
+
nltk.download("averaged_perceptron_tagger_eng")
|
| 498 |
+
start_warmup()
|
| 499 |
with gr.Blocks(
|
| 500 |
+
title="E2E Spoken Dialog System",
|
| 501 |
+
) as demo:
|
| 502 |
+
with gr.Row():
|
| 503 |
+
gr.Markdown(
|
| 504 |
+
"""
|
| 505 |
+
## ESPnet-SDS
|
| 506 |
+
Welcome to our unified web interface for various cascaded and
|
| 507 |
+
E2E spoken dialogue systems built using ESPnet-SDS toolkit,
|
| 508 |
+
supporting real-time automated evaluation metrics, and
|
| 509 |
+
human-in-the-loop feedback collection.
|
| 510 |
+
|
| 511 |
+
For more details on how to use the app, refer to the [README]
|
| 512 |
+
(https://github.com/siddhu001/espnet/tree/sds_demo_recipe/egs2/TEMPLATE/sds1#how-to-use).
|
| 513 |
+
"""
|
| 514 |
+
)
|
| 515 |
+
with gr.Row():
|
| 516 |
+
with gr.Column(scale=1):
|
| 517 |
+
user_audio = gr.Audio(
|
| 518 |
+
sources=["microphone"],
|
| 519 |
+
streaming=True,
|
| 520 |
+
waveform_options=gr.WaveformOptions(sample_rate=16000),
|
| 521 |
+
)
|
| 522 |
+
with gr.Row():
|
| 523 |
+
type_radio = gr.Radio(
|
| 524 |
+
choices=["Cascaded", "E2E"],
|
| 525 |
+
label="Choose type of Spoken Dialog:",
|
| 526 |
+
value="Cascaded",
|
| 527 |
+
)
|
| 528 |
+
with gr.Row():
|
| 529 |
+
ASR_radio = gr.Radio(
|
| 530 |
+
choices=ASR_options,
|
| 531 |
+
label="Choose ASR:",
|
| 532 |
+
value=ASR_name,
|
| 533 |
+
)
|
| 534 |
+
with gr.Row():
|
| 535 |
+
LLM_radio = gr.Radio(
|
| 536 |
+
choices=LLM_options,
|
| 537 |
+
label="Choose LLM:",
|
| 538 |
+
value=LLM_name,
|
|
|
| 539 |
)
|
| 540 |
+
with gr.Row():
|
| 541 |
+
radio = gr.Radio(
|
| 542 |
+
choices=TTS_options,
|
| 543 |
+
label="Choose TTS:",
|
| 544 |
+
value=TTS_name,
|
| 545 |
+
)
|
| 546 |
+
with gr.Row():
|
| 547 |
+
E2Eradio = gr.Radio(
|
| 548 |
+
choices=["mini-omni"],
|
| 549 |
+
label="Choose E2E model:",
|
| 550 |
+
value="mini-omni",
|
| 551 |
visible=False,
|
| 552 |
)
|
| 553 |
+
with gr.Row():
|
| 554 |
+
feedback_btn = gr.Button(
|
| 555 |
+
value=(
|
| 556 |
+
"Please provide your feedback "
|
| 557 |
+
"after each system response below."
|
| 558 |
+
),
|
| 559 |
+
visible=True,
|
| 560 |
+
interactive=False,
|
| 561 |
+
elem_id="button",
|
| 562 |
+
)
|
| 563 |
+
with gr.Row():
|
| 564 |
+
natural_btn1 = gr.Button(
|
| 565 |
+
value="Very Natural", visible=False, interactive=False, scale=1
|
| 566 |
+
)
|
| 567 |
+
natural_btn2 = gr.Button(
|
| 568 |
+
value="Somewhat Awkward", visible=False, interactive=False, scale=1
|
| 569 |
+
)
|
| 570 |
+
natural_btn3 = gr.Button(
|
| 571 |
+
value="Very Awkward", visible=False, interactive=False, scale=1
|
| 572 |
+
)
|
| 573 |
+
natural_btn4 = gr.Button(
|
| 574 |
+
value="Unnatural", visible=False, interactive=False, scale=1
|
| 575 |
+
)
|
| 576 |
+
with gr.Row():
|
| 577 |
+
relevant_btn1 = gr.Button(
|
| 578 |
+
value="Highly Relevant", visible=False, interactive=False, scale=1
|
| 579 |
+
)
|
| 580 |
+
relevant_btn2 = gr.Button(
|
| 581 |
+
value="Partially Relevant",
|
| 582 |
+
visible=False,
|
| 583 |
+
interactive=False,
|
| 584 |
+
scale=1,
|
| 585 |
+
)
|
| 586 |
+
relevant_btn3 = gr.Button(
|
| 587 |
+
value="Slightly Irrelevant",
|
| 588 |
+
visible=False,
|
| 589 |
+
interactive=False,
|
| 590 |
+
scale=1,
|
| 591 |
+
)
|
| 592 |
+
relevant_btn4 = gr.Button(
|
| 593 |
+
value="Completely Irrelevant",
|
| 594 |
+
visible=False,
|
| 595 |
+
interactive=False,
|
| 596 |
+
scale=1,
|
| 597 |
+
)
|
| 598 |
+
with gr.Column(scale=1):
|
| 599 |
+
output_audio = gr.Audio(label="Output", autoplay=True, visible=True)
|
| 600 |
+
output_audio1 = gr.Audio(label="Output1", autoplay=False, visible=False)
|
| 601 |
+
output_asr_text = gr.Textbox(label="ASR output")
|
| 602 |
+
output_text = gr.Textbox(label="LLM output")
|
| 603 |
+
eval_radio = gr.Radio(
|
| 604 |
+
choices=[
|
| 605 |
+
"Latency",
|
| 606 |
+
"TTS Intelligibility",
|
| 607 |
+
"TTS Speech Quality",
|
| 608 |
+
"ASR WER",
|
| 609 |
+
"Text Dialog Metrics",
|
| 610 |
+
],
|
| 611 |
+
label="Choose Evaluation metrics:",
|
| 612 |
+
)
|
| 613 |
+
eval_radio_E2E = gr.Radio(
|
| 614 |
+
choices=[
|
| 615 |
+
"Latency",
|
| 616 |
+
"TTS Intelligibility",
|
| 617 |
+
"TTS Speech Quality",
|
| 618 |
+
"Text Dialog Metrics",
|
| 619 |
+
],
|
| 620 |
+
label="Choose Evaluation metrics:",
|
| 621 |
+
visible=False,
|
| 622 |
+
)
|
| 623 |
+
output_eval_text = gr.Textbox(label="Evaluation Results")
|
| 624 |
+
state = gr.State()
|
| 625 |
+
with gr.Row():
|
| 626 |
+
privacy_text = gr.Textbox(
|
| 627 |
+
label="Privacy Notice",
|
| 628 |
+
interactive=False,
|
| 629 |
+
value=(
|
| 630 |
+
"By using this demo, you acknowledge that"
|
| 631 |
+
"interactions with this dialog system are collected "
|
| 632 |
+
"for research and improvement purposes. The data "
|
| 633 |
+
"will only be used to enhance the performance and "
|
| 634 |
+
"understanding of the system. If you have any "
|
| 635 |
+
"concerns about data collection, please discontinue "
|
| 636 |
+
"use."
|
| 637 |
+
),
|
| 638 |
+
)
|
| 639 |
+
|
| 640 |
+
btn_list = [
|
| 641 |
+
natural_btn1,
|
| 642 |
+
natural_btn2,
|
| 643 |
+
natural_btn3,
|
| 644 |
+
natural_btn4,
|
| 645 |
+
relevant_btn1,
|
| 646 |
+
relevant_btn2,
|
| 647 |
+
relevant_btn3,
|
| 648 |
+
relevant_btn4,
|
| 649 |
+
]
|
| 650 |
+
natural_btn_list = [
|
| 651 |
+
natural_btn1,
|
| 652 |
+
natural_btn2,
|
| 653 |
+
natural_btn3,
|
| 654 |
+
natural_btn4,
|
| 655 |
+
]
|
| 656 |
+
relevant_btn_list = [
|
| 657 |
+
relevant_btn1,
|
| 658 |
+
relevant_btn2,
|
| 659 |
+
relevant_btn3,
|
| 660 |
+
relevant_btn4,
|
| 661 |
+
]
|
| 662 |
+
natural_response = gr.Textbox(
|
| 663 |
+
label="natural_response", visible=False, interactive=False
|
| 664 |
+
)
|
| 665 |
+
diversity_response = gr.Textbox(
|
| 666 |
+
label="diversity_response", visible=False, interactive=False
|
| 667 |
+
)
|
| 668 |
+
ip_address = gr.Textbox(label="ip_address", visible=False, interactive=False)
|
| 669 |
+
callback.setup(
|
| 670 |
+
[
|
| 671 |
+
user_audio,
|
| 672 |
+
output_asr_text,
|
| 673 |
+
output_text,
|
| 674 |
+
output_audio,
|
| 675 |
+
output_audio1,
|
| 676 |
+
type_radio,
|
| 677 |
+
ASR_radio,
|
| 678 |
+
LLM_radio,
|
| 679 |
+
radio,
|
| 680 |
+
E2Eradio,
|
| 681 |
+
natural_response,
|
| 682 |
+
diversity_response,
|
| 683 |
+
ip_address,
|
| 684 |
+
],
|
| 685 |
+
"flagged_data_points",
|
| 686 |
+
)
|
| 687 |
+
user_audio.stream(
|
| 688 |
+
transcribe,
|
| 689 |
+
inputs=[state, user_audio, radio, ASR_radio, LLM_radio, type_radio],
|
| 690 |
+
outputs=[state, output_asr_text, output_text, output_audio, output_audio1],
|
| 691 |
+
).then(
|
| 692 |
+
lambda *args: callback.flag(list(args)), [user_audio], None, preprocess=False
|
| 693 |
+
)
|
| 694 |
+
radio.change(
|
| 695 |
+
fn=dialogue_model.handle_TTS_selection,
|
| 696 |
+
inputs=[radio],
|
| 697 |
+
outputs=[output_asr_text, output_text, output_audio],
|
| 698 |
+
)
|
| 699 |
+
LLM_radio.change(
|
| 700 |
+
fn=dialogue_model.handle_LLM_selection,
|
| 701 |
+
inputs=[LLM_radio],
|
| 702 |
+
outputs=[output_asr_text, output_text, output_audio],
|
| 703 |
+
)
|
| 704 |
+
ASR_radio.change(
|
| 705 |
+
fn=dialogue_model.handle_ASR_selection,
|
| 706 |
+
inputs=[ASR_radio],
|
| 707 |
+
outputs=[output_asr_text, output_text, output_audio],
|
| 708 |
+
)
|
| 709 |
+
eval_radio.change(
|
| 710 |
+
fn=handle_eval_selection,
|
| 711 |
+
inputs=[eval_radio, output_audio, output_text, output_audio1, output_asr_text],
|
| 712 |
+
outputs=[eval_radio, output_eval_text],
|
| 713 |
+
)
|
| 714 |
+
eval_radio_E2E.change(
|
| 715 |
+
fn=handle_eval_selection_E2E,
|
| 716 |
+
inputs=[eval_radio_E2E, output_audio, output_text],
|
| 717 |
+
outputs=[eval_radio_E2E, output_eval_text],
|
| 718 |
+
)
|
| 719 |
+
type_radio.change(
|
| 720 |
+
fn=dialogue_model.handle_type_selection,
|
| 721 |
+
inputs=[type_radio, radio, ASR_radio, LLM_radio],
|
| 722 |
+
outputs=[
|
| 723 |
+
radio,
|
| 724 |
+
ASR_radio,
|
| 725 |
+
LLM_radio,
|
| 726 |
+
E2Eradio,
|
| 727 |
+
output_asr_text,
|
| 728 |
+
output_text,
|
| 729 |
+
output_audio,
|
| 730 |
+
eval_radio,
|
| 731 |
+
eval_radio_E2E,
|
| 732 |
+
],
|
| 733 |
+
)
|
| 734 |
+
output_audio.play(
|
| 735 |
+
flash_buttons, [], [natural_response, diversity_response] + btn_list
|
| 736 |
+
).then(
|
| 737 |
+
lambda *args: callback.flag(list(args)),
|
| 738 |
+
[
|
| 739 |
+
user_audio,
|
| 740 |
+
output_asr_text,
|
| 741 |
+
output_text,
|
| 742 |
+
output_audio,
|
| 743 |
+
output_audio1,
|
| 744 |
+
type_radio,
|
| 745 |
+
ASR_radio,
|
| 746 |
+
LLM_radio,
|
| 747 |
+
radio,
|
| 748 |
+
E2Eradio,
|
| 749 |
+
],
|
| 750 |
+
None,
|
| 751 |
+
preprocess=False,
|
| 752 |
+
)
|
| 753 |
+
natural_btn1.click(
|
| 754 |
+
natural_vote1_last_response,
|
| 755 |
+
[],
|
| 756 |
+
[natural_response, ip_address] + natural_btn_list,
|
| 757 |
+
).then(
|
| 758 |
+
lambda *args: callback.flag(list(args)),
|
| 759 |
+
[
|
| 760 |
+
user_audio,
|
| 761 |
+
output_asr_text,
|
| 762 |
+
output_text,
|
| 763 |
+
output_audio,
|
| 764 |
+
output_audio1,
|
| 765 |
+
type_radio,
|
| 766 |
+
ASR_radio,
|
| 767 |
+
LLM_radio,
|
| 768 |
+
radio,
|
| 769 |
+
E2Eradio,
|
| 770 |
+
natural_response,
|
| 771 |
+
diversity_response,
|
| 772 |
+
ip_address,
|
| 773 |
+
],
|
| 774 |
+
None,
|
| 775 |
+
preprocess=False,
|
| 776 |
+
)
|
| 777 |
+
natural_btn2.click(
|
| 778 |
+
natural_vote2_last_response,
|
| 779 |
+
[],
|
| 780 |
+
[natural_response, ip_address] + natural_btn_list,
|
| 781 |
+
).then(
|
| 782 |
+
lambda *args: callback.flag(list(args)),
|
| 783 |
+
[
|
| 784 |
+
user_audio,
|
| 785 |
+
output_asr_text,
|
| 786 |
+
output_text,
|
| 787 |
+
output_audio,
|
| 788 |
+
output_audio1,
|
| 789 |
+
type_radio,
|
| 790 |
+
ASR_radio,
|
| 791 |
+
LLM_radio,
|
| 792 |
+
radio,
|
| 793 |
+
E2Eradio,
|
| 794 |
+
natural_response,
|
| 795 |
+
diversity_response,
|
| 796 |
+
ip_address,
|
| 797 |
+
],
|
| 798 |
+
None,
|
| 799 |
+
preprocess=False,
|
| 800 |
+
)
|
| 801 |
+
natural_btn3.click(
|
| 802 |
+
natural_vote3_last_response,
|
| 803 |
+
[],
|
| 804 |
+
[natural_response, ip_address] + natural_btn_list,
|
| 805 |
+
).then(
|
| 806 |
+
lambda *args: callback.flag(list(args)),
|
| 807 |
+
[
|
| 808 |
+
user_audio,
|
| 809 |
+
output_asr_text,
|
| 810 |
+
output_text,
|
| 811 |
+
output_audio,
|
| 812 |
+
output_audio1,
|
| 813 |
+
type_radio,
|
| 814 |
+
ASR_radio,
|
| 815 |
+
LLM_radio,
|
| 816 |
+
radio,
|
| 817 |
+
E2Eradio,
|
| 818 |
+
natural_response,
|
| 819 |
+
diversity_response,
|
| 820 |
+
ip_address,
|
| 821 |
+
],
|
| 822 |
+
None,
|
| 823 |
+
preprocess=False,
|
| 824 |
+
)
|
| 825 |
+
natural_btn4.click(
|
| 826 |
+
natural_vote4_last_response,
|
| 827 |
+
[],
|
| 828 |
+
[natural_response, ip_address] + natural_btn_list,
|
| 829 |
+
).then(
|
| 830 |
+
lambda *args: callback.flag(list(args)),
|
| 831 |
+
[
|
| 832 |
+
user_audio,
|
| 833 |
+
output_asr_text,
|
| 834 |
+
output_text,
|
| 835 |
+
output_audio,
|
| 836 |
+
output_audio1,
|
| 837 |
+
type_radio,
|
| 838 |
+
ASR_radio,
|
| 839 |
+
LLM_radio,
|
| 840 |
+
radio,
|
| 841 |
+
E2Eradio,
|
| 842 |
+
natural_response,
|
| 843 |
+
diversity_response,
|
| 844 |
+
ip_address,
|
| 845 |
+
],
|
| 846 |
+
None,
|
| 847 |
+
preprocess=False,
|
| 848 |
+
)
|
| 849 |
+
relevant_btn1.click(
|
| 850 |
+
relevant_vote1_last_response,
|
| 851 |
+
[],
|
| 852 |
+
[diversity_response, ip_address] + relevant_btn_list,
|
| 853 |
+
).then(
|
| 854 |
+
lambda *args: callback.flag(list(args)),
|
| 855 |
+
[
|
| 856 |
+
user_audio,
|
| 857 |
+
output_asr_text,
|
| 858 |
+
output_text,
|
| 859 |
+
output_audio,
|
| 860 |
+
output_audio1,
|
| 861 |
+
type_radio,
|
| 862 |
+
ASR_radio,
|
| 863 |
+
LLM_radio,
|
| 864 |
+
radio,
|
| 865 |
+
E2Eradio,
|
| 866 |
+
natural_response,
|
| 867 |
+
diversity_response,
|
| 868 |
+
ip_address,
|
| 869 |
+
],
|
| 870 |
+
None,
|
| 871 |
+
preprocess=False,
|
| 872 |
+
)
|
| 873 |
+
relevant_btn2.click(
|
| 874 |
+
relevant_vote2_last_response,
|
| 875 |
+
[],
|
| 876 |
+
[diversity_response, ip_address] + relevant_btn_list,
|
| 877 |
+
).then(
|
| 878 |
+
lambda *args: callback.flag(list(args)),
|
| 879 |
+
[
|
| 880 |
+
user_audio,
|
| 881 |
+
output_asr_text,
|
| 882 |
+
output_text,
|
| 883 |
+
output_audio,
|
| 884 |
+
output_audio1,
|
| 885 |
+
type_radio,
|
| 886 |
+
ASR_radio,
|
| 887 |
+
LLM_radio,
|
| 888 |
+
radio,
|
| 889 |
+
E2Eradio,
|
| 890 |
+
natural_response,
|
| 891 |
+
diversity_response,
|
| 892 |
+
ip_address,
|
| 893 |
+
],
|
| 894 |
+
None,
|
| 895 |
+
preprocess=False,
|
| 896 |
+
)
|
| 897 |
+
relevant_btn3.click(
|
| 898 |
+
relevant_vote3_last_response,
|
| 899 |
+
[],
|
| 900 |
+
[diversity_response, ip_address] + relevant_btn_list,
|
| 901 |
+
).then(
|
| 902 |
+
lambda *args: callback.flag(list(args)),
|
| 903 |
+
[
|
| 904 |
+
user_audio,
|
| 905 |
+
output_asr_text,
|
| 906 |
+
output_text,
|
| 907 |
+
output_audio,
|
| 908 |
+
output_audio1,
|
| 909 |
+
type_radio,
|
| 910 |
+
ASR_radio,
|
| 911 |
+
LLM_radio,
|
| 912 |
+
radio,
|
| 913 |
+
E2Eradio,
|
| 914 |
+
natural_response,
|
| 915 |
+
diversity_response,
|
| 916 |
+
ip_address,
|
| 917 |
+
],
|
| 918 |
+
None,
|
| 919 |
+
preprocess=False,
|
| 920 |
+
)
|
| 921 |
+
relevant_btn4.click(
|
| 922 |
+
relevant_vote4_last_response,
|
| 923 |
+
[],
|
| 924 |
+
[diversity_response, ip_address] + relevant_btn_list,
|
| 925 |
+
).then(
|
| 926 |
+
lambda *args: callback.flag(list(args)),
|
| 927 |
+
[
|
| 928 |
+
user_audio,
|
| 929 |
+
output_asr_text,
|
| 930 |
+
output_text,
|
| 931 |
+
output_audio,
|
| 932 |
+
output_audio1,
|
| 933 |
+
type_radio,
|
| 934 |
+
ASR_radio,
|
| 935 |
+
LLM_radio,
|
| 936 |
+
radio,
|
| 937 |
+
E2Eradio,
|
| 938 |
+
natural_response,
|
| 939 |
+
diversity_response,
|
| 940 |
+
ip_address,
|
| 941 |
+
],
|
| 942 |
+
None,
|
| 943 |
+
preprocess=False,
|
| 944 |
+
)
|
| 945 |
demo.launch(share=True)
|
| 946 |
+
|
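
Before the new helper files, a minimal, self-contained sketch of the Gradio streaming-and-flagging pattern that the rewritten app.py above relies on: microphone chunks stream into a handler and every result is logged with gr.CSVLogger. The echo handler, component set, and flagging directory here are placeholders for illustration, not the demo itself.

import gradio as gr
import numpy as np

logger = gr.CSVLogger()

def echo_chunk(state, new_chunk):
    # Accumulate streamed microphone chunks, mirroring transcribe() above.
    sr, y = new_chunk
    state = y if state is None else np.concatenate((state, y))
    return state, f"received {len(state)} samples at {sr} Hz"

with gr.Blocks() as demo:
    mic = gr.Audio(sources=["microphone"], streaming=True)
    status = gr.Textbox(label="status")
    state = gr.State()
    logger.setup([mic, status], "flagged_data_points")
    mic.stream(echo_chunk, inputs=[state, mic], outputs=[state, status]).then(
        lambda *args: logger.flag(list(args)), [mic, status], None, preprocess=False
    )

if __name__ == "__main__":
    demo.launch()
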
pyscripts/utils/dialog_eval/ASR_WER.py ADDED
|
@@ -0,0 +1,165 @@
|
|
|
| 1 |
+
from typing import Tuple
|
| 2 |
+
|
| 3 |
+
import numpy as np
|
| 4 |
+
|
| 5 |
+
from espnet2.sds.utils.utils import int2float
|
| 6 |
+
|
| 7 |
+
|
| 8 |
+
def handle_espnet_ASR_WER(
|
| 9 |
+
ASR_audio_output: Tuple[int, np.ndarray], ASR_transcript: str
|
| 10 |
+
) -> str:
|
| 11 |
+
"""
|
| 12 |
+
Compute and return Word Error Rate (WER) and Character Error Rate (CER) metrics
|
| 13 |
+
for multiple judge ASR systems (ESPnet, OWSM, Whisper) using the Versa library.
|
| 14 |
+
|
| 15 |
+
This function performs the following:
|
| 16 |
+
1. Imports necessary metrics and setup functions from Versa.
|
| 17 |
+
2. Prepares configuration arguments for each ASR system (ESPnet, OWSM, Whisper).
|
| 18 |
+
3. Runs the Levenshtein-based WER/CER calculations.
|
| 19 |
+
4. Returns a formatted string summarizing WER and CER
|
| 20 |
+
results for reference produced by each ASR system.
|
| 21 |
+
|
| 22 |
+
Args:
|
| 23 |
+
ASR_audio_output (tuple):
|
| 24 |
+
A tuple where:
|
| 25 |
+
- The first element is the frame rate.
|
| 26 |
+
- The second element is the audio signal (NumPy array).
|
| 27 |
+
ASR_transcript (str):
|
| 28 |
+
The transcript produced by the ASR model in the cascaded
|
| 29 |
+
conversational AI pipeline.
|
| 30 |
+
|
| 31 |
+
Returns:
|
| 32 |
+
str:
|
| 33 |
+
A formatted string showing the WER and CER percentages
|
| 34 |
+
for ESPnet, OWSM, and Whisper. Example output:
|
| 35 |
+
|
| 36 |
+
"ESPnet WER: 10.50
|
| 37 |
+
ESPnet CER: 7.20
|
| 38 |
+
OWSM WER: 11.30
|
| 39 |
+
OWSM CER: 8.00
|
| 40 |
+
Whisper WER: 9.25
|
| 41 |
+
Whisper CER: 6.50"
|
| 42 |
+
|
| 43 |
+
Raises:
|
| 44 |
+
ImportError:
|
| 45 |
+
If Versa is not installed or cannot be imported.
|
| 46 |
+
|
| 47 |
+
Example:
|
| 48 |
+
>>> asr_audio_output = (16000, audio_array)
|
| 49 |
+
>>> asr_transcript = "This is the ASR transcript."
|
| 50 |
+
>>> result = handle_espnet_ASR_WER(asr_audio_output, asr_transcript)
|
| 51 |
+
>>> print(result)
|
| 52 |
+
"ESPnet WER: 10.50
|
| 53 |
+
ESPnet CER: 7.20
|
| 54 |
+
OWSM WER: 11.30
|
| 55 |
+
OWSM CER: 8.00
|
| 56 |
+
Whisper WER: 9.25
|
| 57 |
+
Whisper CER: 6.50"
|
| 58 |
+
"""
|
| 59 |
+
try:
|
| 60 |
+
from versa import (
|
| 61 |
+
espnet_levenshtein_metric,
|
| 62 |
+
espnet_wer_setup,
|
| 63 |
+
owsm_levenshtein_metric,
|
| 64 |
+
owsm_wer_setup,
|
| 65 |
+
whisper_levenshtein_metric,
|
| 66 |
+
whisper_wer_setup,
|
| 67 |
+
)
|
| 68 |
+
except Exception as e:
|
| 69 |
+
print("Error: Versa is not properly installed.")
|
| 70 |
+
raise e
|
| 71 |
+
score_modules_espnet = {
|
| 72 |
+
"module": espnet_levenshtein_metric,
|
| 73 |
+
"args": espnet_wer_setup(
|
| 74 |
+
model_tag="default",
|
| 75 |
+
beam_size=1,
|
| 76 |
+
text_cleaner="whisper_en",
|
| 77 |
+
use_gpu=True,
|
| 78 |
+
),
|
| 79 |
+
}
|
| 80 |
+
dict1 = score_modules_espnet["module"](
|
| 81 |
+
score_modules_espnet["args"],
|
| 82 |
+
int2float(ASR_audio_output[1]),
|
| 83 |
+
ASR_transcript,
|
| 84 |
+
ASR_audio_output[0],
|
| 85 |
+
)
|
| 86 |
+
espnet_wer = (
|
| 87 |
+
dict1["espnet_wer_delete"]
|
| 88 |
+
+ dict1["espnet_wer_insert"]
|
| 89 |
+
+ dict1["espnet_wer_replace"]
|
| 90 |
+
) / (
|
| 91 |
+
dict1["espnet_wer_insert"]
|
| 92 |
+
+ dict1["espnet_wer_replace"]
|
| 93 |
+
+ dict1["espnet_wer_equal"]
|
| 94 |
+
)
|
| 95 |
+
espnet_cer = (
|
| 96 |
+
dict1["espnet_cer_delete"]
|
| 97 |
+
+ dict1["espnet_cer_insert"]
|
| 98 |
+
+ dict1["espnet_cer_replace"]
|
| 99 |
+
) / (
|
| 100 |
+
dict1["espnet_cer_insert"]
|
| 101 |
+
+ dict1["espnet_cer_replace"]
|
| 102 |
+
+ dict1["espnet_cer_equal"]
|
| 103 |
+
)
|
| 104 |
+
score_modules_owsm = {
|
| 105 |
+
"module": owsm_levenshtein_metric,
|
| 106 |
+
"args": owsm_wer_setup(
|
| 107 |
+
model_tag="default",
|
| 108 |
+
beam_size=1,
|
| 109 |
+
text_cleaner="whisper_en",
|
| 110 |
+
use_gpu=True,
|
| 111 |
+
),
|
| 112 |
+
}
|
| 113 |
+
dict1 = score_modules_owsm["module"](
|
| 114 |
+
score_modules_owsm["args"],
|
| 115 |
+
int2float(ASR_audio_output[1]),
|
| 116 |
+
ASR_transcript,
|
| 117 |
+
ASR_audio_output[0],
|
| 118 |
+
)
|
| 119 |
+
owsm_wer = (
|
| 120 |
+
dict1["owsm_wer_delete"] + dict1["owsm_wer_insert"] + dict1["owsm_wer_replace"]
|
| 121 |
+
) / (dict1["owsm_wer_insert"] + dict1["owsm_wer_replace"] + dict1["owsm_wer_equal"])
|
| 122 |
+
owsm_cer = (
|
| 123 |
+
dict1["owsm_cer_delete"] + dict1["owsm_cer_insert"] + dict1["owsm_cer_replace"]
|
| 124 |
+
) / (dict1["owsm_cer_insert"] + dict1["owsm_cer_replace"] + dict1["owsm_cer_equal"])
|
| 125 |
+
score_modules_whisper = {
|
| 126 |
+
"module": whisper_levenshtein_metric,
|
| 127 |
+
"args": whisper_wer_setup(
|
| 128 |
+
model_tag="default",
|
| 129 |
+
beam_size=1,
|
| 130 |
+
text_cleaner="whisper_en",
|
| 131 |
+
use_gpu=True,
|
| 132 |
+
),
|
| 133 |
+
}
|
| 134 |
+
dict1 = score_modules_whisper["module"](
|
| 135 |
+
score_modules_whisper["args"],
|
| 136 |
+
int2float(ASR_audio_output[1]),
|
| 137 |
+
ASR_transcript,
|
| 138 |
+
ASR_audio_output[0],
|
| 139 |
+
)
|
| 140 |
+
whisper_wer = (
|
| 141 |
+
dict1["whisper_wer_delete"]
|
| 142 |
+
+ dict1["whisper_wer_insert"]
|
| 143 |
+
+ dict1["whisper_wer_replace"]
|
| 144 |
+
) / (
|
| 145 |
+
dict1["whisper_wer_insert"]
|
| 146 |
+
+ dict1["whisper_wer_replace"]
|
| 147 |
+
+ dict1["whisper_wer_equal"]
|
| 148 |
+
)
|
| 149 |
+
whisper_cer = (
|
| 150 |
+
dict1["whisper_cer_delete"]
|
| 151 |
+
+ dict1["whisper_cer_insert"]
|
| 152 |
+
+ dict1["whisper_cer_replace"]
|
| 153 |
+
) / (
|
| 154 |
+
dict1["whisper_cer_insert"]
|
| 155 |
+
+ dict1["whisper_cer_replace"]
|
| 156 |
+
+ dict1["whisper_cer_equal"]
|
| 157 |
+
)
|
| 158 |
+
return (
|
| 159 |
+
f"ESPnet WER: {espnet_wer*100:.2f}\n"
|
| 160 |
+
f"ESPnet CER: {espnet_cer*100:.2f}\n"
|
| 161 |
+
f"OWSM WER: {owsm_wer*100:.2f}\n"
|
| 162 |
+
f"OWSM CER: {owsm_cer*100:.2f}\n"
|
| 163 |
+
f"Whisper WER: {whisper_wer*100:.2f}\n"
|
| 164 |
+
f"Whisper CER: {whisper_cer*100:.2f}"
|
| 165 |
+
)
|
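A hypothetical call sketch for the function above (the waveform and transcript are placeholders; Versa downloads the judge ASR models on first use and, with use_gpu=True, expects a CUDA device):

import numpy as np
from pyscripts.utils.dialog_eval.ASR_WER import handle_espnet_ASR_WER

sr = 16000
audio = (0.01 * 32767 * np.random.randn(sr)).astype(np.int16)  # 1 s of placeholder audio
print(handle_espnet_ASR_WER((sr, audio), "this is the asr transcript"))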
pyscripts/utils/dialog_eval/LLM_Metrics.py
ADDED
@@ -0,0 +1,245 @@
from multiprocessing import Pool
from typing import List

import numpy as np
import torch
from pyscripts.utils.dialog_eval.vert import (
    get_auto_bleu2_geometric,
    get_self_bleu2_geometric,
    run_f,
)
from scipy.stats import gmean
from sklearn.metrics.pairwise import cosine_similarity
from transformers import AutoModel, AutoModelForCausalLM, AutoTokenizer


def perplexity(LLM_Output: str, model_id: str = "gpt2") -> str:
    """
    Compute the perplexity of the given text using a specified model from the
    `evaluate` library (default: GPT-2).

    Args:
        LLM_Output (str):
            The text (string) for which perplexity is to be computed.
        model_id (str, optional):
            The identifier of the model to use for computing
            perplexity. Defaults to "gpt2".

    Returns:
        str:
            A formatted string showing the perplexity of the
            provided text(s), for example:
            "Perplexity: 45.23\n"

    Raises:
        ImportError:
            If the `evaluate` library is not installed or cannot be imported.

    Example:
        >>> text = "Hello world, this is a test."
        >>> result = perplexity(text, model_id="gpt2")
        >>> print(result)
        "Perplexity: 27.34\n"
    """
    try:
        import evaluate
    except Exception as e:
        print("Error: evaluate is not properly installed.")
        raise e
    perplexity = evaluate.load("perplexity", module_type="metric")
    results = perplexity.compute(model_id=model_id, predictions=[LLM_Output])
    return f"Perplexity: {results['mean_perplexity']:.2f}\n"


def vert(LLM_response_arr: List[str]) -> str:
    """
    Calculate and return Self BLEU-2, Auto BLEU-2 and VERT-2
    metrics for a list of LLM responses.

    Args:
        LLM_response_arr (List[str]):
            A list of responses (strings) generated by the language
            model acting as text dialog response generator.

    Returns:
        str:
            A formatted string that includes each computed metric and the final
            VERT value, for example:

            "Self-BLEU2-geometric: 42.13
            Auto-BLEU2-geometric: 38.94
            VERT: 40.5
            "

    Example:
        >>> # Suppose we have the following LLM responses:
        >>> responses = ["Hello world", "Foo bar", "Lorem ipsum dolor sit amet"]
        >>> result = vert(responses)
        >>> print(result)
        "Self-BLEU2-geometric: 42.13
        Auto-BLEU2-geometric: 38.94
        VERT: 40.5
        "
    """
    terms = [x.strip().split() for x in LLM_response_arr]

    tasks = [
        ("Self-BLEU2-geometric", get_self_bleu2_geometric),
        ("Auto-BLEU2-geometric", get_auto_bleu2_geometric),
    ]
    n_processes = min(16, len(tasks))
    with Pool(n_processes) as pool:
        metrics = pool.map(run_f, [(t[1], terms) for t in tasks])
    metric_arr = []
    str1 = ""
    for (metric_name, _), metric in zip(tasks, metrics):
        metric, sem = np.mean(metric), np.std(metric) / np.sqrt(len(metric))

        metric, sem = [round(100 * x, 2) for x in [metric, sem]]
        metric_arr.append(metric)

        str1 += f"{metric_name}: {metric}\n"
    str1 += f"VERT: {round(gmean(metric_arr), 2)}\n"
    return str1


def bert_score(
    total_response_arr: List[str], bert_model_name: str = "bert-base-uncased"
) -> str:
    """
    Compute a cosine similarity score between the concatenated context
    (all but the last element) and the final response (last element)
    using a BERT-based model. This serves as a simplified measure of how
    closely the response aligns with the preceding context semantically.

    Args:
        total_response_arr (List[str]):
            A list of strings. The last element represents the response,
            while all other elements are treated as the context.
        bert_model_name (str, optional):
            The name or path of the BERT model to use (from the Hugging Face Model Hub).
            Defaults to "bert-base-uncased".

    Returns:
        str:
            A string containing the cosine similarity
            (as a percentage) followed by a newline.
            For example:
            "Cosine Similarity: 85.67\n"

    Example:
        >>> total_responses = [
        ...     "User: Hi, how are you?",
        ...     "Assistant: I'm good! How can I help you today?",
        ...     "User: Can you tell me a joke?",
        ...     "Assistant: Sure! Here's one: Why did the chicken join a band?"
        ... ]
        >>> result = bert_score(total_responses, bert_model_name="bert-base-uncased")
        >>> print(result)
        "Cosine Similarity: 75.89\n"
    """

    def cosine_similarity_context_response(context, response, model, tokenizer):
        # Tokenize and encode both context and response
        context_inputs = tokenizer(context, return_tensors="pt", truncation=True)
        response_inputs = tokenizer(response, return_tensors="pt", truncation=True)
        for k in context_inputs:
            context_inputs[k] = context_inputs[k].cuda()
        for k in response_inputs:
            response_inputs[k] = response_inputs[k].cuda()

        # Get embeddings from the model
        with torch.no_grad():
            context_embedding = model(**context_inputs).last_hidden_state.mean(dim=1)
            response_embedding = model(**response_inputs).last_hidden_state.mean(dim=1)

        # Compute cosine similarity
        similarity = cosine_similarity(
            context_embedding.cpu().numpy(), response_embedding.cpu().numpy()
        )
        return similarity[0][0]

    bert_model = AutoModel.from_pretrained(bert_model_name).cuda()
    bert_tokenizer = AutoTokenizer.from_pretrained(bert_model_name)
    similarity = cosine_similarity_context_response(
        " ".join(total_response_arr[:-1]),
        total_response_arr[-1],
        bert_model,
        bert_tokenizer,
    )
    return f"Cosine Similarity: {similarity*100:.2f}" + "\n"


def DialoGPT_perplexity(
    user_utterance: str,
    response: str,
    dialog_model_name: str = "microsoft/DialoGPT-medium",
) -> str:
    """
    Compute the perplexity of a response given a user utterance using a pre-trained
    DialoGPT model. The function loads DialoGPT (medium by default)
    from the Hugging Face Model Hub, then calculates the perplexity
    for the (context + response) sequence.

    Args:
        user_utterance (str):
            The user utterance preceding the model's response.
        response (str):
            The generated response whose perplexity needs to be evaluated.

    Returns:
        str:
            A formatted string containing the DialoGPT perplexity score. For example:
            "DialoGPT Perplexity: 25.67\n"

    Example:
        >>> user_text = "Hi, how are you today?"
        >>> system_response = "I'm good, thank you! How can I help you?"
        >>> result = DialoGPT_perplexity(user_text, system_response)
        >>> print(result)
        "DialoGPT Perplexity: 31.45\n"
    """

    def evaluate_response_with_dialoGPT(context, response, model, tokenizer):
        """
        Evaluate the appropriateness of a response based on the
        given context using DialoGPT.

        Args:
            context (str): The dialogue context (previous conversation).
            response (str): The generated response to evaluate.
            model: Pre-trained DialoGPT model.
            tokenizer: Corresponding tokenizer for the DialoGPT model.

        Returns:
            float: Perplexity score of the response given the context.
        """
        model.eval()

        # Combine context and response as input
        input_text = context + tokenizer.eos_token + response + tokenizer.eos_token
        inputs = tokenizer(input_text, return_tensors="pt", truncation=True)
        inputs["input_ids"] = inputs["input_ids"].cuda()
        inputs["attention_mask"] = inputs["attention_mask"].cuda()
        # import pdb;pdb.set_trace()

        # Compute model outputs and loss
        with torch.no_grad():
            outputs = model(**inputs, labels=inputs["input_ids"].cuda())
            loss = outputs.loss

        # Calculate perplexity
        perplexity = torch.exp(loss)
        return perplexity.cpu().item()

    # Load DialoGPT model and tokenizer
    model_name = dialog_model_name
    model = AutoModelForCausalLM.from_pretrained(model_name).cuda()
    tokenizer = AutoTokenizer.from_pretrained(model_name)
    perplexity = evaluate_response_with_dialoGPT(
        user_utterance, response, model, tokenizer
    )
    return f"DialoGPT Perplexity: {perplexity:.2f}" + "\n"
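Illustrative usage of the four text-dialog metrics above (model weights are fetched from the Hugging Face Hub; bert_score and DialoGPT_perplexity call .cuda(), so a GPU is assumed, and the printed numbers are only examples):

from pyscripts.utils.dialog_eval.LLM_Metrics import (
    DialoGPT_perplexity,
    bert_score,
    perplexity,
    vert,
)

print(perplexity("The quick brown fox jumps over the lazy dog."))
print(vert(["How can I help you?", "Sure, here is an example.", "Let me check that for you."]))
print(bert_score(["User: Tell me a joke.", "Why did the chicken cross the road?"]))
print(DialoGPT_perplexity("Tell me a joke.", "Why did the chicken cross the road?"))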
pyscripts/utils/dialog_eval/TTS_intelligibility.py
ADDED
@@ -0,0 +1,169 @@
from typing import Tuple

import numpy as np

from espnet2.sds.utils.utils import int2float


def handle_espnet_TTS_intelligibility(
    TTS_audio_output: Tuple[int, np.ndarray], LLM_Output: str
) -> str:
    """
    Compute and return Word Error Rate (WER) and Character Error Rate (CER) metrics
    for multiple ASR systems (ESPnet, OWSM, Whisper) using the Versa library.

    This function:
    1. Imports the necessary metrics and setup functions from Versa.
    2. Prepares configuration arguments for each ASR system (ESPnet, OWSM, Whisper).
    3. Runs the Levenshtein-based WER/CER calculations on the provided TTS audio.
    4. Returns a formatted string summarizing WER and CER results for hypotheses
       produced by each ASR system when transcribing the TTS audio, using
       the LLM output as the reference text.

    Args:
        TTS_audio_output (Tuple[int, np.ndarray]):
            A tuple consisting of:
            - The first element (int): the frame rate of the audio.
            - The second element (np.ndarray):
              the audio signal (e.g., a NumPy array).
        LLM_Output (str):
            The reference text generated by the LLM, which serves as the ground truth
            for evaluating the TTS audio.

    Returns:
        str:
            A formatted string showing the WER and CER percentages
            for ESPnet, OWSM, and Whisper.
            Example:

            ESPnet WER: 10.50
            ESPnet CER: 7.20
            OWSM WER: 11.30
            OWSM CER: 8.00
            Whisper WER: 9.25
            Whisper CER: 6.50

    Raises:
        ImportError:
            If the Versa library is not installed or cannot be imported.

    Example:
        >>> tts_audio_output = (16000, audio_array)
        >>> llm_output = "This is the reference text for evaluation."
        >>> result = handle_espnet_TTS_intelligibility(tts_audio_output, llm_output)
        >>> print(result)
        ESPnet WER: 10.50
        ESPnet CER: 7.20
        OWSM WER: 11.30
        OWSM CER: 8.00
        Whisper WER: 9.25
        Whisper CER: 6.50
    """
    try:
        from versa import (
            espnet_levenshtein_metric,
            espnet_wer_setup,
            owsm_levenshtein_metric,
            owsm_wer_setup,
            whisper_levenshtein_metric,
            whisper_wer_setup,
        )
    except Exception as e:
        print("Error: Versa is not properly installed.")
        raise e
    score_modules_espnet = {
        "module": espnet_levenshtein_metric,
        "args": espnet_wer_setup(
            model_tag="default",
            beam_size=1,
            text_cleaner="whisper_en",
            use_gpu=True,
        ),
    }
    dict1 = score_modules_espnet["module"](
        score_modules_espnet["args"],
        int2float(TTS_audio_output[1]),
        LLM_Output,
        TTS_audio_output[0],
    )
    espnet_wer = (
        dict1["espnet_wer_delete"]
        + dict1["espnet_wer_insert"]
        + dict1["espnet_wer_replace"]
    ) / (
        dict1["espnet_wer_delete"]
        + dict1["espnet_wer_replace"]
        + dict1["espnet_wer_equal"]
    )
    espnet_cer = (
        dict1["espnet_cer_delete"]
        + dict1["espnet_cer_insert"]
        + dict1["espnet_cer_replace"]
    ) / (
        dict1["espnet_cer_delete"]
        + dict1["espnet_cer_replace"]
        + dict1["espnet_cer_equal"]
    )
    score_modules_owsm = {
        "module": owsm_levenshtein_metric,
        "args": owsm_wer_setup(
            model_tag="default",
            beam_size=1,
            text_cleaner="whisper_en",
            use_gpu=True,
        ),
    }
    dict1 = score_modules_owsm["module"](
        score_modules_owsm["args"],
        int2float(TTS_audio_output[1]),
        LLM_Output,
        TTS_audio_output[0],
    )
    owsm_wer = (
        dict1["owsm_wer_delete"] + dict1["owsm_wer_insert"] + dict1["owsm_wer_replace"]
    ) / (dict1["owsm_wer_delete"] + dict1["owsm_wer_replace"] + dict1["owsm_wer_equal"])
    owsm_cer = (
        dict1["owsm_cer_delete"] + dict1["owsm_cer_insert"] + dict1["owsm_cer_replace"]
    ) / (dict1["owsm_cer_delete"] + dict1["owsm_cer_replace"] + dict1["owsm_cer_equal"])
    score_modules_whisper = {
        "module": whisper_levenshtein_metric,
        "args": whisper_wer_setup(
            model_tag="default",
            beam_size=1,
            text_cleaner="whisper_en",
            use_gpu=True,
        ),
    }
    dict1 = score_modules_whisper["module"](
        score_modules_whisper["args"],
        int2float(TTS_audio_output[1]),
        LLM_Output,
        TTS_audio_output[0],
    )
    whisper_wer = (
        dict1["whisper_wer_delete"]
        + dict1["whisper_wer_insert"]
        + dict1["whisper_wer_replace"]
    ) / (
        dict1["whisper_wer_delete"]
        + dict1["whisper_wer_replace"]
        + dict1["whisper_wer_equal"]
    )
    whisper_cer = (
        dict1["whisper_cer_delete"]
        + dict1["whisper_cer_insert"]
        + dict1["whisper_cer_replace"]
    ) / (
        dict1["whisper_cer_delete"]
        + dict1["whisper_cer_replace"]
        + dict1["whisper_cer_equal"]
    )
    return (
        f"ESPnet WER: {espnet_wer*100:.2f}\n"
        f"ESPnet CER: {espnet_cer*100:.2f}\n"
        f"OWSM WER: {owsm_wer*100:.2f}\n"
        f"OWSM CER: {owsm_cer*100:.2f}\n"
        f"Whisper WER: {whisper_wer*100:.2f}\n"
        f"Whisper CER: {whisper_cer*100:.2f}"
    )
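The WER/CER assembly in this file reduces to errors divided by reference length, where the reference length is the sum of the delete, replace, and equal counts. A toy recomputation with made-up edit counts:

ops = {"delete": 1, "insert": 2, "replace": 3, "equal": 44}
wer = (ops["delete"] + ops["insert"] + ops["replace"]) / (
    ops["delete"] + ops["replace"] + ops["equal"]
)
print(f"WER: {wer * 100:.2f}")  # 6 errors over 48 reference words -> 12.50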
pyscripts/utils/dialog_eval/TTS_speech_quality.py
ADDED
@@ -0,0 +1,98 @@
from typing import Tuple

import numpy as np

from espnet2.sds.utils.utils import int2float


def TTS_psuedomos(TTS_audio_output: Tuple[int, np.ndarray]) -> str:
    """
    Compute and return speech quality metrics
    for the given synthesized audio output
    using the Versa library.

    Args:
        TTS_audio_output (Tuple[int, np.ndarray]):
            A tuple containing:
            - The first element (int): The frame rate of the audio.
            - The second element (np.ndarray): The audio signal,
              typically a NumPy array.

    Returns:
        str:
            A formatted string containing each metric name
            and its corresponding score, for example:

            utmos: 3.54
            dnsmos: 3.47
            plcmos: 3.62
            sheet_ssqa: 4.03

    Raises:
        ImportError:
            If the Versa library is not installed or cannot be imported.

    Example:
        >>> tts_audio_output = (16000, audio_array)
        >>> result = TTS_psuedomos(tts_audio_output)
        >>> print(result)
        utmos: 3.54
        dnsmos: 3.47
        plcmos: 3.62
        sheet_ssqa: 4.03
    """
    try:
        from versa import (
            pseudo_mos_metric,
            pseudo_mos_setup,
            sheet_ssqa,
            sheet_ssqa_setup,
        )
    except Exception as e:
        print("Error: Versa is not properly installed.")
        raise e

    predictor_dict, predictor_fs = pseudo_mos_setup(
        use_gpu=True,
        predictor_types=["utmos", "dnsmos", "plcmos"],
        predictor_args={
            "utmos": {"fs": 16000},
            "dnsmos": {"fs": 16000},
            "plcmos": {"fs": 16000},
        },
    )
    score_modules = {
        "module": pseudo_mos_metric,
        "args": {
            "predictor_dict": predictor_dict,
            "predictor_fs": predictor_fs,
            "use_gpu": True,
        },
    }
    dict1 = score_modules["module"](
        int2float(TTS_audio_output[1]),
        TTS_audio_output[0],
        **score_modules["args"],
    )
    str1 = ""
    for k in dict1:
        str1 = str1 + f"{k}: {dict1[k]:.2f}\n"
    sheet_model = sheet_ssqa_setup(
        model_tag="default",
        model_path=None,
        model_config=None,
        use_gpu=True,
    )
    score_modules = {
        "module": sheet_ssqa,
        "args": {"model": sheet_model, "use_gpu": True},
    }
    dict1 = score_modules["module"](
        score_modules["args"]["model"],
        int2float(TTS_audio_output[1]),
        TTS_audio_output[0],
        use_gpu=score_modules["args"]["use_gpu"],
    )
    for k in dict1:
        str1 = str1 + f"{k}: {dict1[k]:.2f}\n"
    return str1
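Sketch of a direct call (assumes Versa's pseudo-MOS and Sheet-SSQA models are installed and a GPU is available; the sine tone is only a stand-in for real TTS output):

import numpy as np
from pyscripts.utils.dialog_eval.TTS_speech_quality import TTS_psuedomos

sr = 16000
tone = (0.3 * 32767 * np.sin(2 * np.pi * 220 * np.arange(sr) / sr)).astype(np.int16)
print(TTS_psuedomos((sr, tone)))  # one "metric: score" line per predictor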
pyscripts/utils/dialog_eval/__pycache__/ASR_WER.cpython-39.pyc
ADDED
Binary file (4.12 kB)

pyscripts/utils/dialog_eval/__pycache__/LLM_Metrics.cpython-39.pyc
ADDED
Binary file (8.51 kB)

pyscripts/utils/dialog_eval/__pycache__/TTS_intelligibility.cpython-39.pyc
ADDED
Binary file (4.34 kB)

pyscripts/utils/dialog_eval/__pycache__/TTS_speech_quality.cpython-39.pyc
ADDED
Binary file (2.39 kB)

pyscripts/utils/dialog_eval/__pycache__/human_feedback.cpython-39.pyc
ADDED
Binary file (7.34 kB)

pyscripts/utils/dialog_eval/__pycache__/vert.cpython-39.pyc
ADDED
Binary file (9.13 kB)
pyscripts/utils/dialog_eval/human_feedback.py
ADDED
@@ -0,0 +1,242 @@
import gradio as gr

disable_btn = gr.Button(interactive=False, visible=False)


def get_ip(request: gr.Request) -> str:
    """
    Retrieve the IP address from an incoming HTTP request.

    Args:
        request (gr.Request):
            The incoming HTTP request from which the IP address will be extracted.

    Returns:
        str:
            The IP address as a string.
    """
    if "cf-connecting-ip" in request.headers:
        ip = request.headers["cf-connecting-ip"]
    elif "x-forwarded-for" in request.headers:
        ip = request.headers["x-forwarded-for"]
        if "," in ip:
            ip = ip.split(",")[0]
    else:
        ip = request.client.host
    return ip


def natural_vote1_last_response(request: gr.Request):
    """
    Handle a user vote for naturalness as "Very Natural".

    Args:
        request (gr.Request):
            The Gradio request object providing access to HTTP headers and metadata.

    Returns:
        tuple:
            A tuple containing:
            ("Very Natural", <ip_address>, (disable_btn,) * 4)

            - "Very Natural": The selected vote or label.
            - <ip_address>: The IP address of the client retrieved from the request.
            - disable_btn: An object repeated four times,
              to disable natural vote buttons.
    """
    ip_address1 = get_ip(request)
    print(f"Very Natural (voted). ip: {ip_address1}")
    return (
        "Very Natural",
        ip_address1,
    ) + (disable_btn,) * 4


def natural_vote2_last_response(request: gr.Request):
    """
    Handle a user vote for naturalness as "Somewhat Awkward".

    Args:
        request (gr.Request):
            The Gradio request object providing access to HTTP headers and metadata.

    Returns:
        tuple:
            A tuple containing:
            ("Somewhat Awkward", <ip_address>, (disable_btn,) * 4)

            - "Somewhat Awkward": The selected vote or label.
            - <ip_address>: The IP address of the client retrieved from the request.
            - disable_btn: An object repeated four times,
              to disable natural vote buttons.
    """
    ip_address1 = get_ip(request)
    print(f"Somewhat Awkward (voted). ip: {ip_address1}")
    return (
        "Somewhat Awkward",
        ip_address1,
    ) + (disable_btn,) * 4


def natural_vote3_last_response(request: gr.Request):
    """
    Handle a user vote for naturalness as "Very Awkward".

    Args:
        request (gr.Request):
            The Gradio request object providing access to HTTP headers and metadata.

    Returns:
        tuple:
            A tuple containing:
            ("Very Awkward", <ip_address>, (disable_btn,) * 4)

            - "Very Awkward": The selected vote or label.
            - <ip_address>: The IP address of the client retrieved from the request.
            - disable_btn: An object repeated four times,
              to disable natural vote buttons.
    """
    ip_address1 = get_ip(request)
    print(f"Very Awkward (voted). ip: {ip_address1}")
    return (
        "Very Awkward",
        ip_address1,
    ) + (disable_btn,) * 4


def natural_vote4_last_response(request: gr.Request):
    """
    Handle a user vote for naturalness as "Unnatural".

    Args:
        request (gr.Request):
            The Gradio request object providing access to HTTP headers and metadata.

    Returns:
        tuple:
            A tuple containing:
            ("Unnatural", <ip_address>, (disable_btn,) * 4)

            - "Unnatural": The selected vote or label.
            - <ip_address>: The IP address of the client retrieved from the request.
            - disable_btn: An object repeated four times,
              to disable natural vote buttons.
    """
    ip_address1 = get_ip(request)
    print(f"Unnatural (voted). ip: {ip_address1}")
    return (
        "Unnatural",
        ip_address1,
    ) + (disable_btn,) * 4


def relevant_vote1_last_response(request: gr.Request):
    """
    Handle a user vote for relevance as "Highly Relevant".

    Args:
        request (gr.Request):
            The Gradio request object providing access to HTTP headers and metadata.

    Returns:
        tuple:
            A tuple containing:
            ("Highly Relevant", <ip_address>, (disable_btn,) * 4)

            - "Highly Relevant": The selected vote or label.
            - <ip_address>: The IP address of the client retrieved from the request.
            - disable_btn: An object repeated four times,
              to disable relevance vote buttons.
    """
    ip_address1 = get_ip(request)
    print(f"Highly Relevant (voted). ip: {ip_address1}")
    return (
        "Highly Relevant",
        ip_address1,
    ) + (disable_btn,) * 4


def relevant_vote2_last_response(request: gr.Request):
    """
    Handle a user vote for relevance as "Partially Relevant".

    Args:
        request (gr.Request):
            The Gradio request object providing access to HTTP headers and metadata.

    Returns:
        tuple:
            A tuple containing:
            ("Partially Relevant", <ip_address>, (disable_btn,) * 4)

            - "Partially Relevant": The selected vote or label.
            - <ip_address>: The IP address of the client retrieved from the request.
            - disable_btn: An object repeated four times,
              to disable relevance vote buttons.
    """
    ip_address1 = get_ip(request)
    print(f"Partially Relevant (voted). ip: {ip_address1}")
    return (
        "Partially Relevant",
        ip_address1,
    ) + (disable_btn,) * 4


def relevant_vote3_last_response(request: gr.Request):
    """
    Handle a user vote for relevance as "Slightly Irrelevant".

    Args:
        request (gr.Request):
            The Gradio request object providing access to HTTP headers and metadata.

    Returns:
        tuple:
            A tuple containing:
            ("Slightly Irrelevant", <ip_address>, (disable_btn,) * 4)

            - "Slightly Irrelevant": The selected vote or label.
            - <ip_address>: The IP address of the client retrieved from the request.
            - disable_btn: An object repeated four times,
              to disable relevance vote buttons.
    """
    ip_address1 = get_ip(request)
    print(f"Slightly Irrelevant (voted). ip: {ip_address1}")
    return (
        "Slightly Irrelevant",
        ip_address1,
    ) + (disable_btn,) * 4


def relevant_vote4_last_response(request: gr.Request):
    """
    Handle a user vote for relevance as "Completely Irrelevant".

    Args:
        request (gr.Request):
            The Gradio request object providing access to HTTP headers and metadata.

    Returns:
        tuple:
            A tuple containing:
            ("Completely Irrelevant", <ip_address>, (disable_btn,) * 4)

            - "Completely Irrelevant": The selected vote or label.
            - <ip_address>: The IP address of the client retrieved from the request.
            - disable_btn: An object repeated four times,
              to disable relevance vote buttons.
    """
    ip_address1 = get_ip(request)
    print(f"Completely Irrelevant (voted). ip: {ip_address1}")
    return (
        "Completely Irrelevant",
        ip_address1,
    ) + (disable_btn,) * 4
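Each handler returns (label, ip, disabled button x 4), which app.py maps onto the feedback textbox, the hidden IP field, and the four vote buttons. A quick check of that contract without starting a server (_FakeRequest is a made-up stand-in for gr.Request, not part of the demo):

from pyscripts.utils.dialog_eval.human_feedback import natural_vote1_last_response


class _FakeRequest:
    headers = {"x-forwarded-for": "203.0.113.7, 10.0.0.1"}

    class client:
        host = "127.0.0.1"


label, ip, *buttons = natural_vote1_last_response(_FakeRequest())
print(label, ip, len(buttons))  # Very Natural 203.0.113.7 4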
pyscripts/utils/dialog_eval/vert.py
ADDED
@@ -0,0 +1,299 @@
# Copyright (c) Facebook, Inc. and its affiliates.
#
# This source code is licensed under the MIT license found in the
# LICENSE file in the root directory of this source tree.

import math
import sys
import warnings
from collections import Counter
from fractions import Fraction

import nltk
import numpy as np
from nltk.translate.bleu_score import (
    SmoothingFunction,
    brevity_penalty,
    closest_ref_length,
    modified_precision,
)


def corpus_bleu(
    list_of_references,
    hypotheses,
    weights=(0.25, 0.25, 0.25, 0.25),
    smoothing_function=None,
    auto_reweigh=False,
    averaging_mode="geometric",
    no_length_penalty=False,
):
    """
    Calculate a single corpus-level BLEU score (aka. system-level BLEU) for all
    the hypotheses and their respective references.

    Instead of averaging the sentence level BLEU scores (i.e. macro-average
    precision), the original BLEU metric (Papineni et al. 2002) accounts for
    the micro-average precision (i.e. summing the numerators and denominators
    for each hypothesis-reference(s) pairs before the division).

    >>> hyp1 = ['It', 'is', 'a', 'guide', 'to', 'action', 'which',
    ...         'ensures', 'that', 'the', 'military', 'always',
    ...         'obeys', 'the', 'commands', 'of', 'the', 'party']
    >>> ref1a = ['It', 'is', 'a', 'guide', 'to', 'action', 'that',
    ...          'ensures', 'that', 'the', 'military', 'will', 'forever',
    ...          'heed', 'Party', 'commands']
    >>> ref1b = ['It', 'is', 'the', 'guiding', 'principle', 'which',
    ...          'guarantees', 'the', 'military', 'forces', 'always',
    ...          'being', 'under', 'the', 'command', 'of', 'the', 'Party']
    >>> ref1c = ['It', 'is', 'the', 'practical', 'guide', 'for', 'the',
    ...          'army', 'always', 'to', 'heed', 'the', 'directions',
    ...          'of', 'the', 'party']

    >>> hyp2 = ['he', 'read', 'the', 'book', 'because', 'he', 'was',
    ...         'interested', 'in', 'world', 'history']
    >>> ref2a = ['he', 'was', 'interested', 'in', 'world', 'history',
    ...          'because', 'he', 'read', 'the', 'book']

    >>> list_of_references = [[ref1a, ref1b, ref1c], [ref2a]]
    >>> hypotheses = [hyp1, hyp2]
    >>> corpus_bleu(list_of_references, hypotheses) # doctest: +ELLIPSIS
    0.5920...

    The example below shows that corpus_bleu() is different from averaging
    sentence_bleu() for hypotheses

    >>> score1 = sentence_bleu([ref1a, ref1b, ref1c], hyp1)
    >>> score2 = sentence_bleu([ref2a], hyp2)
    >>> (score1 + score2) / 2 # doctest: +ELLIPSIS
    0.6223...

    :param list_of_references: a corpus of lists of reference
        sentences, w.r.t. hypotheses
    :type list_of_references: list(list(list(str)))
    :param hypotheses: a list of hypothesis sentences
    :type hypotheses: list(list(str))
    :param weights: weights for unigrams, bigrams, trigrams and so on
    :type weights: list(float)
    :param smoothing_function:
    :type smoothing_function: SmoothingFunction
    :param auto_reweigh: Option to re-normalize the weights uniformly.
    :type auto_reweigh: bool
    :return: The corpus-level BLEU score.
    :rtype: float
    """
    # Before proceeding to compute BLEU, perform sanity checks.

    p_numerators = Counter()  # Key = ngram order, and value = no. of ngram matches.
    p_denominators = Counter()  # Key = ngram order, and value = no. of ngram in ref.
    hyp_lengths, ref_lengths = 0, 0

    assert len(list_of_references) == len(hypotheses), (
        "The number of hypotheses and their reference(s) should be the " "same "
    )

    # Iterate through each hypothesis and their corresponding references.
    for references, hypothesis in zip(list_of_references, hypotheses):
        # For each order of ngram, calculate the numerator and
        # denominator for the corpus-level modified precision.
        for i, _ in enumerate(weights, start=1):
            p_i = modified_precision(references, hypothesis, i)
            p_numerators[i] += p_i.numerator
            p_denominators[i] += p_i.denominator

        # Calculate the hypothesis length and the closest reference length.
        # Add them to the corpus-level hypothesis and reference counts.
        hyp_len = len(hypothesis)
        hyp_lengths += hyp_len
        ref_lengths += closest_ref_length(references, hyp_len)

    # Calculate corpus-level brevity penalty.
    if no_length_penalty and averaging_mode == "geometric":
        bp = 1.0
    elif no_length_penalty and averaging_mode == "arithmetic":
        bp = 0.0
    else:
        assert not no_length_penalty
        assert (
            averaging_mode != "arithmetic"
        ), "Not sure how to apply length penalty in arithmetic mode"
        bp = brevity_penalty(ref_lengths, hyp_lengths)

    # Uniformly re-weighting based on maximum hypothesis lengths if largest
    # order of n-grams < 4 and weights is set at default.
    if auto_reweigh:
        if hyp_lengths < 4 and weights == (0.25, 0.25, 0.25, 0.25):
            weights = (1 / hyp_lengths,) * hyp_lengths

    # Collects the various precision values for the different ngram orders.
    p_n = [
        Fraction(p_numerators[i], p_denominators[i], _normalize=False)
        for i, _ in enumerate(weights, start=1)
    ]

    # Returns 0 if there's no matching n-grams
    # We only need to check for p_numerators[1] == 0, since if there's
    # no unigrams, there won't be any higher order ngrams.
    if p_numerators[1] == 0:
        return 0

    # If there's no smoothing, use method0 from the SmoothingFunction class.
    if not smoothing_function:
        smoothing_function = SmoothingFunction().method0
    # Smooth the modified precision.
    # Note: smoothing_function() may convert values into floats;
    # it tries to retain the Fraction object as much as the
    # smoothing method allows.
    p_n = smoothing_function(
        p_n, references=references, hypothesis=hypothesis, hyp_len=hyp_lengths
    )

    if averaging_mode == "geometric":
        s = (w_i * math.log(p_i) for w_i, p_i in zip(weights, p_n))
        s = bp * math.exp(math.fsum(s))
    elif averaging_mode == "arithmetic":
        s = (w_i * p_i for w_i, p_i in zip(weights, p_n))
        s = math.fsum(s)

    return s


def sentence_bleu(
    references,
    hypothesis,
    weights=(0.25, 0.25, 0.25, 0.25),
    smoothing_function=None,
    auto_reweigh=False,
    averaging_mode="geometric",
    no_length_penalty=False,
):
    return corpus_bleu(
        [references],
        [hypothesis],
        weights,
        smoothing_function,
        auto_reweigh,
        averaging_mode,
        no_length_penalty,
    )


def get_target_sequences(manifest, ground_truth, to_take=1000):
    import json
    import pathlib

    with open(ground_truth, "r") as fin:
        original_continuations = json.loads(fin.read())

    sequence2length = [(k, v[0]) for k, v in original_continuations.items()]
    assert all(float(v) >= 6.0 for (_, v) in sequence2length)  # 6 seconds

    sequence2length.sort(key=lambda x: x[1])
    to_take_sequences = set(v[0] for v in sequence2length[:to_take])
    to_take_ids = []

    with open(manifest, "r") as f:
        f.readline()

        for i, line in enumerate(f.readlines()):
            seq_id = line.split()[0]
            seq_id = pathlib.Path(seq_id).name.split("__")[0]

            if seq_id in to_take_sequences:
                to_take_ids.append(i)

    print(f"Took {len(to_take_ids)} ids")
    return set(to_take_ids)


def get_self_bleu(utterances, averaging_mode, weights):
    self_bleu = []

    for i in range(len(utterances)):
        hypo = utterances[i]
        rest = utterances[:i] + utterances[i + 1 :]

        self_bleu.append(
            sentence_bleu(
                rest,
                hypo,
                weights,
                no_length_penalty=True,
                averaging_mode=averaging_mode,
            )
        )

    return self_bleu


def get_self_bleu2_arithmetic(utterances):
    weights = (0.5, 0.5)  # equal weight for unigrams and bigrams
    return get_self_bleu(utterances, averaging_mode="arithmetic", weights=weights)


def get_self_bleu2_geometric(utterances):
    weights = (0.5, 0.5)
    return get_self_bleu(utterances, averaging_mode="geometric", weights=weights)


def get_auto_bleu2_arithmetic(utterances):
    weights = (0.5, 0.5)
    return [auto_bleu(u, mean_mode="arithmetic", weights=weights) for u in utterances]


def get_auto_bleu2_geometric(utterances):
    weights = (0.5, 0.5)
    return [auto_bleu(u, mean_mode="geometric", weights=weights) for u in utterances]


def get_auto_bleu3_geometric(utterances):
    weights = (1.0 / 3, 1.0 / 3, 1.0 / 3)
    return [auto_bleu(u, mean_mode="geometric", weights=weights) for u in utterances]


def get_auto_bleu3_arithmetic(utterances):
    weights = (1.0 / 3, 1.0 / 3, 1.0 / 3)
    return [auto_bleu(u, mean_mode="arithmetic", weights=weights) for u in utterances]


def get_self_bleu3_arithmetic(utterances):
    weights = (1.0 / 3, 1.0 / 3, 1.0 / 3)
    return get_self_bleu(utterances, averaging_mode="arithmetic", weights=weights)


def get_self_bleu3_geometric(utterances):
    weights = (1.0 / 3, 1.0 / 3, 1.0 / 3)
    return get_self_bleu(utterances, averaging_mode="geometric", weights=weights)


def auto_bleu(sentence, weights, mean_mode="arithmetic"):
    if len(sentence) <= 1:
        return 0

    N = len(weights)

    bleu_n = np.zeros([N])
    for n in range(N):
        targ_ngrams = list(nltk.ngrams(sentence, n + 1))
        for p in range(len(targ_ngrams)):
            left = sentence[:p]
            right = sentence[(p + n + 1) :]
            rest_ngrams = list(nltk.ngrams(left, n + 1)) + list(
                nltk.ngrams(right, n + 1)
            )
            # compute the nb of matching ngrams
            bleu_n[n] += targ_ngrams[p] in rest_ngrams
        bleu_n[n] /= len(targ_ngrams)  # average them to get a proportion

    weights = np.array(weights)
    if mean_mode == "arithmetic":
        return (bleu_n * weights).sum()
    elif mean_mode == "geometric":
        return (bleu_n**weights).prod()
    else:
        raise ValueError(f"Unknown aggregation mode {mean_mode}")


def run_f(task_params):
    f, terms = task_params
    return f(terms)
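A toy run of the diversity helpers above, assuming the repo root is on PYTHONPATH (numbers are illustrative only; auto_bleu rewards n-grams that repeat inside a single utterance, while get_self_bleu2_geometric scores each utterance against the others):

from pyscripts.utils.dialog_eval.vert import auto_bleu, get_self_bleu2_geometric

utterances = [
    "the cat sat on the mat".split(),
    "a cat sat on a mat".split(),
    "dogs sleep on the mat".split(),
]
print(auto_bleu(utterances[0], weights=(0.5, 0.5), mean_mode="arithmetic"))
print(get_self_bleu2_geometric(utterances))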