Tijs Zwinkels
commited on
Commit
·
f0a24cd
1
Parent(s):
3696fef
Make --vad work with --backend openai-api
Browse files- whisper_online.py +22 -16
whisper_online.py
CHANGED
|
@@ -162,7 +162,7 @@ class OpenaiApiASR(ASRBase):
|
|
| 162 |
|
| 163 |
self.load_model()
|
| 164 |
|
| 165 |
-
self.
|
| 166 |
|
| 167 |
# reset the task in set_translate_task
|
| 168 |
self.task = "transcribe"
|
|
@@ -175,21 +175,27 @@ class OpenaiApiASR(ASRBase):
|
|
| 175 |
|
| 176 |
|
| 177 |
def ts_words(self, segments):
|
| 178 |
-
|
| 179 |
-
|
| 180 |
-
|
| 181 |
-
|
| 182 |
-
|
| 183 |
-
|
| 184 |
|
| 185 |
-
|
| 186 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 187 |
|
| 188 |
return o
|
| 189 |
|
| 190 |
|
| 191 |
def segments_end_ts(self, res):
|
| 192 |
-
return [s["end"] for s in res]
|
| 193 |
|
| 194 |
def transcribe(self, audio_data, prompt=None, *args, **kwargs):
|
| 195 |
# Write the audio data to a buffer
|
|
@@ -205,7 +211,7 @@ class OpenaiApiASR(ASRBase):
|
|
| 205 |
"file": buffer,
|
| 206 |
"response_format": self.response_format,
|
| 207 |
"temperature": self.temperature,
|
| 208 |
-
"timestamp_granularities": ["word"]
|
| 209 |
}
|
| 210 |
if self.task != "translate" and self.language:
|
| 211 |
params["language"] = self.language
|
|
@@ -221,10 +227,10 @@ class OpenaiApiASR(ASRBase):
|
|
| 221 |
transcript = proc.create(**params)
|
| 222 |
print(f"OpenAI API processed accumulated {self.transcribed_seconds} seconds",file=self.logfile)
|
| 223 |
|
| 224 |
-
return transcript
|
| 225 |
|
| 226 |
def use_vad(self):
|
| 227 |
-
self.
|
| 228 |
|
| 229 |
def set_translate_task(self):
|
| 230 |
self.task = "translate"
|
|
@@ -592,9 +598,9 @@ if __name__ == "__main__":
|
|
| 592 |
e = time.time()
|
| 593 |
print(f"done. It took {round(e-t,2)} seconds.",file=logfile)
|
| 594 |
|
| 595 |
-
|
| 596 |
-
|
| 597 |
-
|
| 598 |
|
| 599 |
if args.task == "translate":
|
| 600 |
asr.set_translate_task()
|
|
|
|
| 162 |
|
| 163 |
self.load_model()
|
| 164 |
|
| 165 |
+
self.use_vad_opt = False
|
| 166 |
|
| 167 |
# reset the task in set_translate_task
|
| 168 |
self.task = "transcribe"
|
|
|
|
| 175 |
|
| 176 |
|
| 177 |
def ts_words(self, segments):
|
| 178 |
+
no_speech_segments = []
|
| 179 |
+
if self.use_vad_opt:
|
| 180 |
+
for segment in segments.segments:
|
| 181 |
+
# TODO: threshold can be set from outside
|
| 182 |
+
if segment["no_speech_prob"] > 0.8:
|
| 183 |
+
no_speech_segments.append((segment.get("start"), segment.get("end")))
|
| 184 |
|
| 185 |
+
o = []
|
| 186 |
+
for word in segments.words:
|
| 187 |
+
start = word.get("start")
|
| 188 |
+
end = word.get("end")
|
| 189 |
+
if any(s[0] <= start <= s[1] for s in no_speech_segments):
|
| 190 |
+
# print("Skipping word", word.get("word"), "because it's in a no-speech segment")
|
| 191 |
+
continue
|
| 192 |
+
o.append((start, end, word.get("word")))
|
| 193 |
|
| 194 |
return o
|
| 195 |
|
| 196 |
|
| 197 |
def segments_end_ts(self, res):
|
| 198 |
+
return [s["end"] for s in res.words]
|
| 199 |
|
| 200 |
def transcribe(self, audio_data, prompt=None, *args, **kwargs):
|
| 201 |
# Write the audio data to a buffer
|
|
|
|
| 211 |
"file": buffer,
|
| 212 |
"response_format": self.response_format,
|
| 213 |
"temperature": self.temperature,
|
| 214 |
+
"timestamp_granularities": ["word", "segment"]
|
| 215 |
}
|
| 216 |
if self.task != "translate" and self.language:
|
| 217 |
params["language"] = self.language
|
|
|
|
| 227 |
transcript = proc.create(**params)
|
| 228 |
print(f"OpenAI API processed accumulated {self.transcribed_seconds} seconds",file=self.logfile)
|
| 229 |
|
| 230 |
+
return transcript
|
| 231 |
|
| 232 |
def use_vad(self):
|
| 233 |
+
self.use_vad_opt = True
|
| 234 |
|
| 235 |
def set_translate_task(self):
|
| 236 |
self.task = "translate"
|
|
|
|
| 598 |
e = time.time()
|
| 599 |
print(f"done. It took {round(e-t,2)} seconds.",file=logfile)
|
| 600 |
|
| 601 |
+
if args.vad:
|
| 602 |
+
print("setting VAD filter",file=logfile)
|
| 603 |
+
asr.use_vad()
|
| 604 |
|
| 605 |
if args.task == "translate":
|
| 606 |
asr.set_translate_task()
|