littlebird13 committed on
Commit c9489e8 · verified · 1 Parent(s): f00c078

Update app.py

Files changed (1)
  1. app.py  +183 -21
app.py CHANGED
@@ -27,6 +27,10 @@ from fastrtc import (
 )
 from gradio.utils import get_space
 from websockets.asyncio.client import connect
+import ssl
+import certifi
+
+import cv2
 
 load_dotenv()
 
@@ -37,6 +41,9 @@ API_KEY = os.environ['API_KEY'] # Set with: export DASHSCOPE_API_KEY=xxx
 API_URL = "wss://dashscope-intl.aliyuncs.com/api-ws/v1/realtime?model=qwen3-livetranslate-flash-realtime"
 VOICES = ["Cherry", "Nofish", "Jada", "Dylan", "Sunny", "Peter", "Kiki", "Eric"]
 
+ssl_context = ssl.create_default_context(cafile=certifi.where())
+# ssl_context = ssl._create_unverified_context()  # disable certificate verification
+
 if not API_KEY:
     raise RuntimeError("Missing DASHSCOPE_API_KEY environment variable.")
 headers = {"Authorization": "Bearer " + API_KEY}
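For context, a minimal standalone sketch of the TLS setup this hunk introduces: the `websockets` asyncio client's `connect()` accepts `additional_headers` and `ssl` keyword arguments (the diff uses both), and a context built from certifi's CA bundle avoids verification failures on hosts with a stale system store. The URL and token below are placeholders, not the app's real endpoint.

```python
import asyncio
import ssl

import certifi
from websockets.asyncio.client import connect


async def main():
    # Trust certifi's CA bundle, as the diff does for the DashScope endpoint.
    ssl_context = ssl.create_default_context(cafile=certifi.where())

    # Placeholder endpoint and credentials, for illustration only.
    url = "wss://example.com/api-ws/v1/realtime"
    headers = {"Authorization": "Bearer <token>"}

    async with connect(url, additional_headers=headers, ssl=ssl_context) as conn:
        await conn.send('{"type": "ping"}')
        print(await conn.recv())


asyncio.run(main())
```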
@@ -77,6 +84,60 @@ class LiveTranslateHandler(AsyncStreamHandler):
         )
         self.connection = None
         self.output_queue = asyncio.Queue()
+        self.video_capture = None  # video capture device
+        self.last_capture_time = 0  # timestamp of the last captured frame
+        self.enable_video = False
+        self.output_queue = asyncio.Queue()
+        self.awaiting_new_message = True
+        self.stable_text = ""  # the black (stable) part
+        self.temp_text = ""  # the gray (tentative) part
+
+    def setup_video(self):
+        """Set up the video capture device."""
+        self.video_capture = cv2.VideoCapture(0)  # open the default camera
+        self.video_capture.set(cv2.CAP_PROP_FRAME_WIDTH, 640)  # set width
+        self.video_capture.set(cv2.CAP_PROP_FRAME_HEIGHT, 480)  # set height
+        self.video_capture.set(cv2.CAP_PROP_FPS, 30)  # set FPS
+
+    def get_video_frame(self) -> bytes | None:
+        """Grab a video frame and return it as resized, JPEG-encoded bytes."""
+        if not self.video_capture:
+            return None
+
+        # Current time
+        current_time = time.time()
+
+        # Capture at most one frame every 0.5 seconds
+        if current_time - self.last_capture_time >= 0.5:
+            self.last_capture_time = current_time
+            ret, frame = self.video_capture.read()  # grab the current frame
+            if ret:
+                # Compress and downscale
+                resized_frame = cv2.resize(frame, (640, 360))  # keep the resolution below 480p
+                # Encode the frame as JPEG
+                _, encoded_image = cv2.imencode('.jpg', resized_frame)
+                return encoded_image.tobytes()
+        return None
+
+    async def send_image_frame(self, image_bytes: bytes, *, event_id: str | None = None):
+        """Send the image data to the server."""
+        if not self.connection:
+            return
+
+        if not image_bytes:
+            raise ValueError("image_bytes must not be empty")
+
+        # Encode as Base64
+        image_b64 = base64.b64encode(image_bytes).decode()
+
+        event = {
+            "event_id": event_id or self.msg_id(),
+            "type": "input_image_buffer.append",
+            "image": image_b64,
+        }
+
+        await self.connection.send(json.dumps(event))
+
 
     def copy(self):
         return LiveTranslateHandler()
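The wire format used by `send_image_frame` above is easy to inspect offline. A small sketch building the same `input_image_buffer.append` event from placeholder JPEG bytes — no camera or connection needed; the `event_id` value is fabricated (the handler derives it from its `msg_id()` helper):

```python
import base64
import json

# Placeholder bytes; in the app these come from get_video_frame().
jpeg_bytes = b"\xff\xd8\xff\xe0fake-jpeg-payload"

event = {
    "event_id": "event_0001",  # fabricated; the handler uses self.msg_id()
    "type": "input_image_buffer.append",
    "image": base64.b64encode(jpeg_bytes).decode(),
}

print(json.dumps(event))
```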
@@ -89,19 +150,22 @@ class LiveTranslateHandler(AsyncStreamHandler):
         try:
             await self.wait_for_args()
             args = self.latest_args
-            src_language_name = args[2] if len(args) > 2 else "English"  # the dropdown now returns full language names
-            target_language_name = args[3] if len(args) > 3 else "Chinese"
+            src_language_name = args[2] if len(args) > 2 else "Chinese"  # the dropdown now returns full language names
+            target_language_name = args[3] if len(args) > 3 else "English"
             src_language_code = LANG_MAP_REVERSE[src_language_name]
             target_language_code = LANG_MAP_REVERSE[target_language_name]
 
-            # src_language = args[2] if len(args) > 2 else "zh"  # new source-language parameter
-            # target_language = args[3] if len(args) > 3 else "en"
             voice_id = args[4] if len(args) > 4 else "Cherry"
+
+            self.enable_video = True if args[5] == "True" else False
+
+            if self.enable_video:
+                self.setup_video()  # initialize the video device
 
             if src_language_code == target_language_code:
                 print(f"⚠️ Source and target languages are the same ({target_language_name}); running in repeat-back mode")
 
-            async with connect(API_URL, additional_headers=headers) as conn:
+            async with connect(API_URL, additional_headers=headers, ssl=ssl_context) as conn:
                 self.client = conn
                 await conn.send(
                     json.dumps(
@@ -123,36 +187,89 @@
                     )
                 self.connection = conn
 
+                # Every response (data) received over the WebSocket is a JSON event describing the progress of the translation task.
                 async for data in self.connection:
                     event = json.loads(data)
                     if "type" not in event:
                         continue
                     event_type = event["type"]
 
-                    if event_type == "response.audio_transcript.delta":
-                        # incremental subtitles
-                        text = event.get("transcript", "")
-                        if text:
-                            await self.output_queue.put(
-                                AdditionalOutputs({"role": "assistant", "content": text})
-                            )
+                    # if event_type == "response.audio_transcript.delta":
+                    #     # incremental subtitles
+                    #     text = event.get("transcript", "")
+                    #     if text:
+                    #         await self.output_queue.put(
+                    #             AdditionalOutputs({"role": "assistant", "content": text, "update": True, "new_message": self.awaiting_new_message
+                    #             })
+                    #         )
+                    #     self.awaiting_new_message = False
 
-                    # elif event_type in ("response.text.text", "response.audio_transcript.text"):
+                    # # intermediate text content
+                    # if event_type in ("response.text.text", "response.audio_transcript.text"):
                     #     # intermediate result + stash (stash is usually the complete-sentence cache)
                     #     stash_text = event.get("stash", "")
                     #     text_field = event.get("text", "")
                     #     if stash_text or text_field:
                     #         await self.output_queue.put(
-                    #             AdditionalOutputs({"role": "assistant", "content": stash_text or text_field})
+                    #             AdditionalOutputs({"role": "assistant", "content": stash_text or text_field, "update": True, "new_message": self.awaiting_new_message})
                     #         )
+                    #     self.awaiting_new_message = False
+
+                    # elif event_type == "response.audio_transcript.done":
+                    #     # final complete sentence
+                    #     transcript = event.get("transcript", "")
+                    #     if transcript:
+                    #         await self.output_queue.put(
+                    #             AdditionalOutputs({"role": "assistant", "content": transcript, "update": True, "new_message": self.awaiting_new_message})
+                    #         )
+                    #     self.awaiting_new_message = True
+
+                    if event_type == "response.audio_transcript.delta":
+                        self.temp_text = event.get("transcript", "")
+                        if self.temp_text:
+                            await self.output_queue.put(
+                                AdditionalOutputs({
+                                    "role": "assistant",
+                                    "content": (self.stable_text, self.temp_text),
+                                    "update": True,
+                                    "new_message": self.awaiting_new_message
+                                })
+                            )
+                        self.awaiting_new_message = False
+
+                    elif event_type in ("response.text.text", "response.audio_transcript.text"):
+                        # update the stable part (stash / text are treated as confirmed)
+                        new_stable = event.get("stash") or event.get("text") or ""
+                        if new_stable:
+                            self.stable_text = f"{self.stable_text}{new_stable}"
+                            self.temp_text = ""  # clear the tentative part
+                            await self.output_queue.put(
+                                AdditionalOutputs({
+                                    "role": "assistant",
+                                    "content": (self.stable_text, self.temp_text),
+                                    "update": True,
+                                    "new_message": self.awaiting_new_message
+                                })
+                            )
+                        self.awaiting_new_message = False
 
                     elif event_type == "response.audio_transcript.done":
-                        # final complete sentence
                         transcript = event.get("transcript", "")
                         if transcript:
+                            self.stable_text = transcript
+                            self.temp_text = ""
                             await self.output_queue.put(
-                                AdditionalOutputs({"role": "assistant", "content": transcript})
+                                AdditionalOutputs({
+                                    "role": "assistant",
+                                    "content": (self.stable_text, self.temp_text),
+                                    "update": True,
+                                    "new_message": self.awaiting_new_message
+                                })
                             )
+                            # start a new chat bubble
+                            self.awaiting_new_message = True
+                            self.stable_text = ""
+                            self.temp_text = ""
 
                     elif event_type == "response.audio.delta":
                         audio_b64 = event.get("delta", "")
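A minimal sketch of the two-tier subtitle state machine the new branches implement, driven by a canned event sequence instead of a live socket (the transcripts are fabricated): delta events only refresh the gray tentative text, text/stash events append to the black stable text, and a done event replaces both with the final sentence.

```python
# Fabricated event stream mimicking the server's transcript events.
events = [
    {"type": "response.audio_transcript.delta", "transcript": "Hel"},
    {"type": "response.audio_transcript.delta", "transcript": "Hello, wor"},
    {"type": "response.text.text", "text": "Hello, "},
    {"type": "response.audio_transcript.delta", "transcript": "world"},
    {"type": "response.audio_transcript.done", "transcript": "Hello, world!"},
]

stable_text, temp_text = "", ""
for event in events:
    kind = event["type"]
    if kind == "response.audio_transcript.delta":
        temp_text = event.get("transcript", "")  # tentative (gray) text
    elif kind in ("response.text.text", "response.audio_transcript.text"):
        # stash/text are treated as confirmed; they grow the stable part
        stable_text += event.get("stash") or event.get("text") or ""
        temp_text = ""
    elif kind == "response.audio_transcript.done":
        stable_text, temp_text = event.get("transcript", ""), ""
    print(f"stable={stable_text!r}  temp={temp_text!r}")
```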
@@ -183,10 +300,22 @@ class LiveTranslateHandler(AsyncStreamHandler):
             )
         )
 
+        # Video path
+        if self.enable_video:
+            image_frame = self.get_video_frame()
+            if image_frame:
+                await self.send_image_frame(image_frame)
+
+
     async def emit(self) -> tuple[int, np.ndarray] | AdditionalOutputs | None:
         return await wait_for_item(self.output_queue)
 
     async def shutdown(self) -> None:
+        """Close the connection and clean up resources."""
+        if self.video_capture:
+            self.video_capture.release()  # release the video device
+            self.video_capture = None
+
         if self.connection:
             await self.connection.close()
             self.connection = None
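Note the design of the video path: it runs on every pass through the enclosing handler method, and the 0.5 s gate inside `get_video_frame` does the actual rate limiting. A camera-free sketch of that gate under those assumptions (the stub class and byte payload are illustrative):

```python
import time


class ThrottledSource:
    """Stub mirroring get_video_frame()'s 0.5 s gate, without a camera."""

    def __init__(self, interval: float = 0.5):
        self.interval = interval
        self.last = 0.0

    def frame(self) -> bytes | None:
        now = time.time()
        if now - self.last < self.interval:
            return None  # called too soon; the caller simply skips the send
        self.last = now
        return b"jpeg-bytes"  # placeholder for an encoded frame


source = ThrottledSource()
sent = 0
for _ in range(50):  # stand-in for 50 rapid calls into the handler
    time.sleep(0.05)
    if source.frame():
        sent += 1
print(sent)  # roughly 5 frames over ~2.5 s, despite 50 calls
```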
@@ -197,7 +326,32 @@ class LiveTranslateHandler(AsyncStreamHandler):
 
 
 def update_chatbot(chatbot: list[dict], response: dict):
-    chatbot.append(response)
+    is_update = response.pop("update", False)
+    new_message_flag = response.pop("new_message", False)
+    content_tuple = response["content"]
+
+    # Assemble HTML: black stable text + gray tentative text
+    stable_html = f"<span style='color:black'>{content_tuple[0]}</span>"
+    temp_html = f"<span style='color:gray'>{content_tuple[1]}</span>"
+    html_content = stable_html + temp_html
+
+    if is_update:
+        if new_message_flag or not chatbot:
+            chatbot.append({
+                "role": "assistant",
+                "content": html_content
+            })
+        else:
+            if chatbot[-1]["role"] == "assistant":
+                chatbot[-1]["content"] = html_content
+            else:
+                chatbot.append({
+                    "role": "assistant",
+                    "content": html_content
+                })
+    else:
+        chatbot.append(response)
+
     return chatbot
 
 
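A small driver for the new `update_chatbot`, showing how the `update`/`new_message` flags choose between opening a bubble and editing the last one in place (the response dicts are fabricated for illustration):

```python
chatbot: list[dict] = []

# The first delta opens a new assistant bubble...
update_chatbot(chatbot, {
    "role": "assistant",
    "content": ("", "Hel"),
    "update": True,
    "new_message": True,
})

# ...later deltas rewrite that same bubble in place.
update_chatbot(chatbot, {
    "role": "assistant",
    "content": ("Hello, ", "wor"),
    "update": True,
    "new_message": False,
})

print(len(chatbot))            # 1 -- still a single bubble
print(chatbot[-1]["content"])  # black "Hello, " span + gray "wor" span
```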
@@ -215,17 +369,24 @@ language = gr.Dropdown(
     label="Target Language"
 )
 voice = gr.Dropdown(choices=VOICES, value=VOICES[0], type="value", label="Voice")
+video_flag = gr.Dropdown(
+    choices=["True", "False"],
+    value="False",
+    label="Use Video"
+)
+
 latest_message = gr.Textbox(type="text", visible=False)
 
 # Optional: temporarily disable the TURN configuration for testing
 rtc_config = get_cloudflare_turn_credentials_async if get_space() else None
 # rtc_config = None  # uncomment to disable TURN for testing
 
+
 stream = Stream(
     LiveTranslateHandler(),
     mode="send-receive",
     modality="audio",
-    additional_inputs=[src_language, language, voice, chatbot],  # add src_language
+    additional_inputs=[src_language, language, voice, video_flag, chatbot],
     additional_outputs=[chatbot],
     additional_outputs_handler=update_chatbot,
     rtc_configuration=rtc_config,
@@ -234,6 +395,7 @@ stream = Stream(
 )
 
 
+
 app = FastAPI()
 
 stream.mount(app)
@@ -271,10 +433,10 @@ if __name__ == "__main__":
     import os
 
     if (mode := os.getenv("MODE")) == "UI":
-        stream.ui.launch(server_port=7860)
+        stream.ui.launch(server_port=7862)
    elif mode == "PHONE":
-        stream.fastphone(host="0.0.0.0", port=7860)
+        stream.fastphone(host="0.0.0.0", port=7862)
     else:
         import uvicorn
 
-        uvicorn.run(app, host="0.0.0.0", port=7860)
+        uvicorn.run(app, host="0.0.0.0", port=7862)
 