Spaces:
Running
Running
Commit
·
8ed90ba
1
Parent(s):
3fb07cc
(wip)update file cache pipeline
Browse files
app.py
CHANGED
|
@@ -4,6 +4,8 @@ from apscheduler.schedulers.background import BackgroundScheduler
|
|
| 4 |
from concurrent.futures import ThreadPoolExecutor
|
| 5 |
from datetime import datetime
|
| 6 |
import threading # Added for locking
|
|
|
|
|
|
|
| 7 |
from sqlalchemy import or_ # Added for vote counting query
|
| 8 |
import hashlib
|
| 9 |
|
|
@@ -42,7 +44,6 @@ from flask import (
|
|
| 42 |
redirect,
|
| 43 |
url_for,
|
| 44 |
session,
|
| 45 |
-
abort,
|
| 46 |
)
|
| 47 |
from flask_login import LoginManager, current_user
|
| 48 |
from models import *
|
|
@@ -61,8 +62,6 @@ import json
|
|
| 61 |
from datetime import datetime, timedelta
|
| 62 |
from flask_migrate import Migrate
|
| 63 |
import requests
|
| 64 |
-
import functools
|
| 65 |
-
import time # Added for potential retries
|
| 66 |
|
| 67 |
# Load environment variables
|
| 68 |
if not IS_SPACES:
|
|
@@ -118,6 +117,7 @@ TTS_CACHE_SIZE = int(os.getenv("TTS_CACHE_SIZE", "10"))
|
|
| 118 |
CACHE_AUDIO_SUBDIR = "cache"
|
| 119 |
tts_cache = {} # sentence -> {model_a, model_b, audio_a, audio_b, created_at}
|
| 120 |
tts_cache_lock = threading.Lock()
|
|
|
|
| 121 |
SMOOTHING_FACTOR_MODEL_SELECTION = 500 # For weighted random model selection
|
| 122 |
# Increased max_workers to 8 for concurrent generation/refill
|
| 123 |
cache_executor = ThreadPoolExecutor(max_workers=8, thread_name_prefix='CacheReplacer')
|
|
@@ -371,6 +371,7 @@ with open("init_sentences.txt", "r") as f:
|
|
| 371 |
initial_sentences = random.sample(all_harvard_sentences,
|
| 372 |
min(len(all_harvard_sentences), 500)) # Limit initial pass for template
|
| 373 |
|
|
|
|
| 374 |
@app.route("/")
|
| 375 |
def arena():
|
| 376 |
# Pass a subset of sentences for the random button fallback
|
|
@@ -616,6 +617,11 @@ def generate_tts():
|
|
| 616 |
if not text or len(text) > 1000:
|
| 617 |
return jsonify({"error": "Invalid or too long text"}), 400
|
| 618 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 619 |
# --- Cache Check ---
|
| 620 |
cache_hit = False
|
| 621 |
session_data_from_cache = None
|
|
@@ -662,7 +668,31 @@ def generate_tts():
|
|
| 662 |
return jsonify({"error": "Not enough TTS models available"}), 500
|
| 663 |
|
| 664 |
selected_models = get_weighted_random_models(available_models, 2, ModelType.TTS)
|
|
|
|
|
|
|
|
|
|
| 665 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 666 |
try:
|
| 667 |
audio_files = []
|
| 668 |
model_ids = []
|
|
@@ -716,15 +746,16 @@ def generate_tts():
|
|
| 716 |
|
| 717 |
# Check if text and prompt are in predefined libraries
|
| 718 |
if text in predefined_texts and prompt_md5 in predefined_prompts.values():
|
| 719 |
-
|
| 720 |
-
|
| 721 |
-
|
| 722 |
-
|
|
|
|
| 723 |
|
| 724 |
-
|
| 725 |
-
|
| 726 |
-
|
| 727 |
-
|
| 728 |
|
| 729 |
# Return audio file paths and session
|
| 730 |
return jsonify(
|
|
@@ -1120,98 +1151,105 @@ def setup_periodic_tasks():
|
|
| 1120 |
同步缓存音频到HF dataset并从HF下载更新的缓存��频。
|
| 1121 |
"""
|
| 1122 |
os.makedirs(PRELOAD_CACHE_DIR, exist_ok=True)
|
| 1123 |
-
|
| 1124 |
-
|
| 1125 |
-
|
| 1126 |
-
# 获取带有 etag 的文件列表
|
| 1127 |
-
files_info = api.list_repo_files(repo_id=REFERENCE_AUDIO_DATASET, repo_type="dataset", expand=True)
|
| 1128 |
-
# 只处理cache_audios/下的wav文件
|
| 1129 |
-
wav_files = [f for f in files_info if
|
| 1130 |
-
f["rfilename"].startswith(CACHE_AUDIO_PATTERN) and f["rfilename"].endswith(".wav")]
|
| 1131 |
-
|
| 1132 |
-
# 获取本地已有文件名及hash集合
|
| 1133 |
-
local_hashes = {}
|
| 1134 |
-
for root, _, filenames in os.walk(PRELOAD_CACHE_DIR):
|
| 1135 |
-
for fname in filenames:
|
| 1136 |
-
if fname.endswith(".wav"):
|
| 1137 |
-
rel_path = os.path.relpath(os.path.join(root, fname), PRELOAD_CACHE_DIR)
|
| 1138 |
-
remote_path = os.path.join(CACHE_AUDIO_PATTERN, rel_path)
|
| 1139 |
-
local_file_path = os.path.join(root, fname)
|
| 1140 |
-
# 计算本地文件md5
|
| 1141 |
-
try:
|
| 1142 |
-
with open(local_file_path, 'rb') as f:
|
| 1143 |
-
md5 = hashlib.md5(f.read()).hexdigest()
|
| 1144 |
-
local_hashes[remote_path] = md5
|
| 1145 |
-
except Exception:
|
| 1146 |
-
continue
|
| 1147 |
-
|
| 1148 |
-
download_count = 0
|
| 1149 |
-
for f in wav_files:
|
| 1150 |
-
remote_path = f["rfilename"]
|
| 1151 |
-
etag = f.get("lfs", {}).get("oid") or f.get("etag") # 优先lfs oid, 其次etag
|
| 1152 |
-
local_md5 = local_hashes.get(remote_path)
|
| 1153 |
-
|
| 1154 |
-
# 如果远端etag为32位md5且与本地一致,跳过下载
|
| 1155 |
-
if etag and len(etag) == 32 and local_md5 == etag:
|
| 1156 |
-
continue
|
| 1157 |
-
|
| 1158 |
-
# 下载文件
|
| 1159 |
-
local_path = hf_hub_download(
|
| 1160 |
-
repo_id=REFERENCE_AUDIO_DATASET,
|
| 1161 |
-
filename=remote_path,
|
| 1162 |
-
repo_type="dataset",
|
| 1163 |
-
local_dir=PRELOAD_CACHE_DIR,
|
| 1164 |
-
token=os.getenv("HF_TOKEN"),
|
| 1165 |
-
force_download=True if local_md5 else False
|
| 1166 |
-
)
|
| 1167 |
-
print(f"Downloaded cache audio: {local_path}")
|
| 1168 |
-
download_count += 1
|
| 1169 |
|
| 1170 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1171 |
|
| 1172 |
-
|
| 1173 |
-
for root, _, files in os.walk(PRELOAD_CACHE_DIR):
|
| 1174 |
-
for file in files:
|
| 1175 |
-
if file.endswith('.wav'):
|
| 1176 |
-
local_path = os.path.join(root, file)
|
| 1177 |
-
rel_path = os.path.relpath(local_path, PRELOAD_CACHE_DIR)
|
| 1178 |
-
remote_path = os.path.join(CACHE_AUDIO_PATTERN, rel_path)
|
| 1179 |
|
| 1180 |
-
|
| 1181 |
-
|
| 1182 |
-
|
| 1183 |
-
|
|
|
|
|
|
|
|
|
|
| 1184 |
|
| 1185 |
-
# 尝试获取远程文件信息
|
| 1186 |
try:
|
| 1187 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1188 |
repo_id=REFERENCE_AUDIO_DATASET,
|
| 1189 |
repo_type="dataset",
|
| 1190 |
-
path
|
| 1191 |
-
|
| 1192 |
-
|
| 1193 |
-
|
| 1194 |
-
|
| 1195 |
-
continue
|
| 1196 |
-
except Exception:
|
| 1197 |
-
pass
|
| 1198 |
|
| 1199 |
-
|
| 1200 |
-
|
| 1201 |
-
|
| 1202 |
-
path_or_fileobj=local_path,
|
| 1203 |
-
path_in_repo=remote_path,
|
| 1204 |
-
repo_id=REFERENCE_AUDIO_DATASET,
|
| 1205 |
-
repo_type="dataset",
|
| 1206 |
-
commit_message=f"Upload preload cache file: {os.path.basename(file)}"
|
| 1207 |
-
)
|
| 1208 |
-
app.logger.info(f"Successfully uploaded {remote_path}")
|
| 1209 |
-
except Exception as e:
|
| 1210 |
-
app.logger.error(f"Error uploading {remote_path}: {str(e)}")
|
| 1211 |
-
|
| 1212 |
-
except Exception as e:
|
| 1213 |
-
print(f"Error syncing cache audios with HF: {e}")
|
| 1214 |
-
app.logger.error(f"Error syncing cache audios with HF: {e}")
|
| 1215 |
|
| 1216 |
# Schedule periodic tasks
|
| 1217 |
scheduler = BackgroundScheduler()
|
|
@@ -1354,9 +1392,6 @@ def get_tts_cache_key(model_name, text, prompt_audio_path):
|
|
| 1354 |
return hashlib.md5(key_str.encode('utf-8')).hexdigest()
|
| 1355 |
|
| 1356 |
|
| 1357 |
-
|
| 1358 |
-
|
| 1359 |
-
|
| 1360 |
if __name__ == "__main__":
|
| 1361 |
with app.app_context():
|
| 1362 |
# Ensure ./instance and ./votes directories exist
|
|
|
|
| 4 |
from concurrent.futures import ThreadPoolExecutor
|
| 5 |
from datetime import datetime
|
| 6 |
import threading # Added for locking
|
| 7 |
+
|
| 8 |
+
from huggingface_hub.hf_api import RepoFile
|
| 9 |
from sqlalchemy import or_ # Added for vote counting query
|
| 10 |
import hashlib
|
| 11 |
|
|
|
|
| 44 |
redirect,
|
| 45 |
url_for,
|
| 46 |
session,
|
|
|
|
| 47 |
)
|
| 48 |
from flask_login import LoginManager, current_user
|
| 49 |
from models import *
|
|
|
|
| 62 |
from datetime import datetime, timedelta
|
| 63 |
from flask_migrate import Migrate
|
| 64 |
import requests
|
|
|
|
|
|
|
| 65 |
|
| 66 |
# Load environment variables
|
| 67 |
if not IS_SPACES:
|
|
|
|
| 117 |
CACHE_AUDIO_SUBDIR = "cache"
|
| 118 |
tts_cache = {} # sentence -> {model_a, model_b, audio_a, audio_b, created_at}
|
| 119 |
tts_cache_lock = threading.Lock()
|
| 120 |
+
preload_cache_lock = threading.Lock()
|
| 121 |
SMOOTHING_FACTOR_MODEL_SELECTION = 500 # For weighted random model selection
|
| 122 |
# Increased max_workers to 8 for concurrent generation/refill
|
| 123 |
cache_executor = ThreadPoolExecutor(max_workers=8, thread_name_prefix='CacheReplacer')
|
|
|
|
| 371 |
initial_sentences = random.sample(all_harvard_sentences,
|
| 372 |
min(len(all_harvard_sentences), 500)) # Limit initial pass for template
|
| 373 |
|
| 374 |
+
|
| 375 |
@app.route("/")
|
| 376 |
def arena():
|
| 377 |
# Pass a subset of sentences for the random button fallback
|
|
|
|
| 617 |
if not text or len(text) > 1000:
|
| 618 |
return jsonify({"error": "Invalid or too long text"}), 400
|
| 619 |
|
| 620 |
+
prompt_md5 = ''
|
| 621 |
+
if reference_audio_path and os.path.exists(reference_audio_path):
|
| 622 |
+
with open(reference_audio_path, 'rb') as f:
|
| 623 |
+
prompt_md5 = hashlib.md5(f.read()).hexdigest()
|
| 624 |
+
|
| 625 |
# --- Cache Check ---
|
| 626 |
cache_hit = False
|
| 627 |
session_data_from_cache = None
|
|
|
|
| 668 |
return jsonify({"error": "Not enough TTS models available"}), 500
|
| 669 |
|
| 670 |
selected_models = get_weighted_random_models(available_models, 2, ModelType.TTS)
|
| 671 |
+
# 尝试从持久化缓存中查找两个模型的音频
|
| 672 |
+
audio_a_path = find_cached_audio(str(selected_models[0].id), text, reference_audio_path)
|
| 673 |
+
audio_b_path = find_cached_audio(str(selected_models[1].id), text, reference_audio_path)
|
| 674 |
|
| 675 |
+
if audio_a_path and audio_b_path:
|
| 676 |
+
app.logger.info(f"Persistent Cache HIT for: '{text[:50]}...'. Using files directly.")
|
| 677 |
+
session_id = str(uuid.uuid4())
|
| 678 |
+
app.tts_sessions[session_id] = {
|
| 679 |
+
"model_a": selected_models[0].id,
|
| 680 |
+
"model_b": selected_models[1].id,
|
| 681 |
+
"audio_a": audio_a_path,
|
| 682 |
+
"audio_b": audio_b_path,
|
| 683 |
+
"text": text,
|
| 684 |
+
"created_at": datetime.utcnow(),
|
| 685 |
+
"expires_at": datetime.utcnow() + timedelta(minutes=30),
|
| 686 |
+
"voted": False,
|
| 687 |
+
}
|
| 688 |
+
return jsonify({
|
| 689 |
+
"session_id": session_id,
|
| 690 |
+
"audio_a": f"/api/tts/audio/{session_id}/a",
|
| 691 |
+
"audio_b": f"/api/tts/audio/{session_id}/b",
|
| 692 |
+
"expires_in": 1800,
|
| 693 |
+
"cache_hit": True, # 可以认为这也是一种缓存命中
|
| 694 |
+
})
|
| 695 |
+
# --- 持久化缓存检查结束 ---
|
| 696 |
try:
|
| 697 |
audio_files = []
|
| 698 |
model_ids = []
|
|
|
|
| 746 |
|
| 747 |
# Check if text and prompt are in predefined libraries
|
| 748 |
if text in predefined_texts and prompt_md5 in predefined_prompts.values():
|
| 749 |
+
with preload_cache_lock:
|
| 750 |
+
preload_key = get_tts_cache_key(str(model_ids[0]), text, reference_audio_path)
|
| 751 |
+
preload_path = os.path.join(PRELOAD_CACHE_DIR, f"{preload_key}.wav")
|
| 752 |
+
shutil.copy(audio_files[0], preload_path)
|
| 753 |
+
app.logger.info(f"Preloaded cache audio saved: {preload_path}")
|
| 754 |
|
| 755 |
+
preload_key = get_tts_cache_key(str(model_ids[1]), text, reference_audio_path)
|
| 756 |
+
preload_path = os.path.join(PRELOAD_CACHE_DIR, f"{preload_key}.wav")
|
| 757 |
+
shutil.copy(audio_files[1], preload_path)
|
| 758 |
+
app.logger.info(f"Preloaded cache audio saved: {preload_path}")
|
| 759 |
|
| 760 |
# Return audio file paths and session
|
| 761 |
return jsonify(
|
|
|
|
| 1151 |
同步缓存音频到HF dataset并从HF下载更新的缓存��频。
|
| 1152 |
"""
|
| 1153 |
os.makedirs(PRELOAD_CACHE_DIR, exist_ok=True)
|
| 1154 |
+
with preload_cache_lock:
|
| 1155 |
+
try:
|
| 1156 |
+
api = HfApi(token=os.getenv("HF_TOKEN"))
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1157 |
|
| 1158 |
+
# 获取带有 etag 的文件列表
|
| 1159 |
+
files_info = [
|
| 1160 |
+
f
|
| 1161 |
+
for f in api.list_repo_tree(
|
| 1162 |
+
repo_id=REFERENCE_AUDIO_DATASET, path_in_repo=CACHE_AUDIO_PATTERN.strip("/"), recursive=True,
|
| 1163 |
+
repo_type="dataset", expand=True
|
| 1164 |
+
)
|
| 1165 |
+
if isinstance(f, RepoFile)
|
| 1166 |
+
]
|
| 1167 |
+
# 只处理cache_audios/下的wav文件
|
| 1168 |
+
wav_files = [f for f in files_info if
|
| 1169 |
+
f.path.endswith(".wav")]
|
| 1170 |
+
|
| 1171 |
+
# 获取本地已有文件名及hash集合
|
| 1172 |
+
local_hashes = {}
|
| 1173 |
+
for root, _, filenames in os.walk(PRELOAD_CACHE_DIR):
|
| 1174 |
+
for fname in filenames:
|
| 1175 |
+
if fname.endswith(".wav"):
|
| 1176 |
+
rel_path = os.path.relpath(os.path.join(root, fname), PRELOAD_CACHE_DIR)
|
| 1177 |
+
remote_path = os.path.join(CACHE_AUDIO_PATTERN, rel_path)
|
| 1178 |
+
local_file_path = os.path.join(root, fname)
|
| 1179 |
+
# 计算本地文件md5
|
| 1180 |
+
try:
|
| 1181 |
+
with open(local_file_path, 'rb') as f:
|
| 1182 |
+
file_hash = hashlib.sha256(f.read()).hexdigest()
|
| 1183 |
+
local_hashes[remote_path] = file_hash
|
| 1184 |
+
except Exception:
|
| 1185 |
+
continue
|
| 1186 |
+
|
| 1187 |
+
download_count = 0
|
| 1188 |
+
for f in wav_files:
|
| 1189 |
+
remote_path = f.path
|
| 1190 |
+
etag = f.lfs.sha256 if f.lfs else None
|
| 1191 |
+
local_hash = local_hashes.get(remote_path)
|
| 1192 |
+
|
| 1193 |
+
# 如果远端etag为32位md5且与本地一致,跳过下载
|
| 1194 |
+
if local_hash == etag:
|
| 1195 |
+
continue
|
| 1196 |
+
|
| 1197 |
+
# 下载文件
|
| 1198 |
+
local_path = hf_hub_download(
|
| 1199 |
+
repo_id=REFERENCE_AUDIO_DATASET,
|
| 1200 |
+
filename=remote_path,
|
| 1201 |
+
repo_type="dataset",
|
| 1202 |
+
local_dir=PRELOAD_CACHE_DIR,
|
| 1203 |
+
token=os.getenv("HF_TOKEN"),
|
| 1204 |
+
force_download=True if local_hash else False
|
| 1205 |
+
)
|
| 1206 |
+
print(f"Downloaded cache audio: {local_path}")
|
| 1207 |
+
download_count += 1
|
| 1208 |
|
| 1209 |
+
print(f"Downloaded {download_count} new/updated cache audios from HF to {PRELOAD_CACHE_DIR}")
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1210 |
|
| 1211 |
+
# 上传本地文件到HF dataset
|
| 1212 |
+
for root, _, files in os.walk(PRELOAD_CACHE_DIR):
|
| 1213 |
+
for file in files:
|
| 1214 |
+
if file.endswith('.wav'):
|
| 1215 |
+
local_path = os.path.join(root, file)
|
| 1216 |
+
rel_path = os.path.relpath(local_path, PRELOAD_CACHE_DIR)
|
| 1217 |
+
remote_path = os.path.join(CACHE_AUDIO_PATTERN, rel_path)
|
| 1218 |
|
|
|
|
| 1219 |
try:
|
| 1220 |
+
# 计算本地文件MD5,用于检查是否需要上传
|
| 1221 |
+
with open(local_path, 'rb') as f:
|
| 1222 |
+
file_hash = hashlib.sha256(f.read()).hexdigest()
|
| 1223 |
+
|
| 1224 |
+
# 尝试获取远程文件信息
|
| 1225 |
+
try:
|
| 1226 |
+
remote_info = api.get_paths_info(
|
| 1227 |
+
repo_id=REFERENCE_AUDIO_DATASET,
|
| 1228 |
+
repo_type="dataset",
|
| 1229 |
+
path=[remote_path],expand=True)
|
| 1230 |
+
remote_etag = remote_info[0].lfs.sha256 if remote_info and remote_info[0].lfs else None
|
| 1231 |
+
# 如果远程文件存在且hash相同,则跳过
|
| 1232 |
+
if remote_etag and remote_etag == file_hash:
|
| 1233 |
+
app.logger.debug(f"Skipping upload for {remote_path}: file unchanged")
|
| 1234 |
+
continue
|
| 1235 |
+
except Exception as e:
|
| 1236 |
+
app.logger.warning(f"Could not get remote info for {remote_path}: {str(e)}")
|
| 1237 |
+
# 上传文件
|
| 1238 |
+
app.logger.info(f"Uploading preload cache file: {remote_path}")
|
| 1239 |
+
api.upload_file(
|
| 1240 |
+
path_or_fileobj=local_path,
|
| 1241 |
+
path_in_repo=remote_path,
|
| 1242 |
repo_id=REFERENCE_AUDIO_DATASET,
|
| 1243 |
repo_type="dataset",
|
| 1244 |
+
commit_message=f"Upload preload cache file: {os.path.basename(file)}"
|
| 1245 |
+
)
|
| 1246 |
+
app.logger.info(f"Successfully uploaded {remote_path}")
|
| 1247 |
+
except Exception as e:
|
| 1248 |
+
app.logger.error(f"Error uploading {remote_path}: {str(e)}")
|
|
|
|
|
|
|
|
|
|
| 1249 |
|
| 1250 |
+
except Exception as e:
|
| 1251 |
+
print(f"Error syncing cache audios with HF: {e}")
|
| 1252 |
+
app.logger.error(f"Error syncing cache audios with HF: {e}")
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1253 |
|
| 1254 |
# Schedule periodic tasks
|
| 1255 |
scheduler = BackgroundScheduler()
|
|
|
|
| 1392 |
return hashlib.md5(key_str.encode('utf-8')).hexdigest()
|
| 1393 |
|
| 1394 |
|
|
|
|
|
|
|
|
|
|
| 1395 |
if __name__ == "__main__":
|
| 1396 |
with app.app_context():
|
| 1397 |
# Ensure ./instance and ./votes directories exist
|