Spaces:
Running
Running
Commit
·
2249567
1
Parent(s):
a2a66cd
modify elo calculation
Browse files
app.py
CHANGED
|
@@ -662,62 +662,111 @@ def generate_tts():
|
|
| 662 |
)
|
| 663 |
# --- End Cache Check ---
|
| 664 |
|
| 665 |
-
# --- Cache Miss:
|
| 666 |
-
|
| 667 |
-
|
| 668 |
-
|
| 669 |
-
|
| 670 |
-
|
| 671 |
-
|
| 672 |
-
|
| 673 |
-
|
| 674 |
-
|
| 675 |
-
|
| 676 |
-
|
| 677 |
-
|
| 678 |
-
|
| 679 |
-
|
| 680 |
-
|
| 681 |
-
|
| 682 |
-
|
| 683 |
-
|
| 684 |
-
|
| 685 |
-
|
| 686 |
-
|
| 687 |
-
|
| 688 |
-
|
| 689 |
-
|
| 690 |
-
|
| 691 |
-
|
| 692 |
-
|
| 693 |
-
|
| 694 |
-
|
| 695 |
-
|
| 696 |
-
|
| 697 |
-
|
| 698 |
-
|
| 699 |
-
|
| 700 |
-
|
| 701 |
-
|
| 702 |
-
|
| 703 |
-
|
| 704 |
-
|
| 705 |
-
|
| 706 |
-
|
| 707 |
-
|
| 708 |
-
|
| 709 |
-
|
| 710 |
-
|
| 711 |
-
|
| 712 |
-
|
| 713 |
-
|
| 714 |
-
|
| 715 |
-
|
| 716 |
-
|
| 717 |
-
|
| 718 |
-
|
| 719 |
-
|
| 720 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 721 |
|
| 722 |
|
| 723 |
@app.route("/api/tts/audio/<session_id>/<model_key>")
|
|
|
|
| 662 |
)
|
| 663 |
# --- End Cache Check ---
|
| 664 |
|
| 665 |
+
# --- Cache Miss: Local File Cache ---
|
| 666 |
+
# 对于预置文本和预置prompt,检查本地缓存
|
| 667 |
+
if text in predefined_texts and prompt_md5 in predefined_prompts.values():
|
| 668 |
+
app.logger.warning(f"TTS Cache MISS for: '{text[:50]}...'. Generating on the fly.")
|
| 669 |
+
available_models = Model.query.filter_by(
|
| 670 |
+
model_type=ModelType.TTS, is_active=True
|
| 671 |
+
).all()
|
| 672 |
+
if len(available_models) < 2:
|
| 673 |
+
return jsonify({"error": "Not enough TTS models available"}), 500
|
| 674 |
+
|
| 675 |
+
# 新增:a和b模型都需通过缓存检测
|
| 676 |
+
candidate_models = available_models.copy()
|
| 677 |
+
valid_models = []
|
| 678 |
+
invalid_models = []
|
| 679 |
+
for model in candidate_models:
|
| 680 |
+
audio_path = find_cached_audio(model.name, text, prompt_audio_path=reference_audio_path)
|
| 681 |
+
if audio_path and os.path.exists(audio_path):
|
| 682 |
+
valid_models.append(model)
|
| 683 |
+
else:
|
| 684 |
+
invalid_models.append(model)
|
| 685 |
+
|
| 686 |
+
if len(valid_models) < 2:
|
| 687 |
+
return jsonify({"error": "Not enough valid TTS model results available"}), 500
|
| 688 |
+
|
| 689 |
+
apply_filter_penalty_and_redistribute(invalid_models, valid_models, penalty_amount=1.0)
|
| 690 |
+
|
| 691 |
+
# 从有结果的模型中随机选择两个
|
| 692 |
+
model_a,model_b = random.sample(valid_models, 2)
|
| 693 |
+
audio_a_path = find_cached_audio(model_a.name, text, prompt_audio_path=reference_audio_path)
|
| 694 |
+
audio_b_path = find_cached_audio(model_b.name, text, prompt_audio_path=reference_audio_path)
|
| 695 |
+
|
| 696 |
+
session_id = str(uuid.uuid4())
|
| 697 |
+
app.tts_sessions[session_id] = {
|
| 698 |
+
"model_a": model_a.id,
|
| 699 |
+
"model_b": model_b.id,
|
| 700 |
+
"audio_a": audio_a_path,
|
| 701 |
+
"audio_b": audio_b_path,
|
| 702 |
+
"text": text,
|
| 703 |
+
"created_at": datetime.utcnow(),
|
| 704 |
+
"expires_at": datetime.utcnow() + timedelta(minutes=30),
|
| 705 |
+
"voted": False,
|
| 706 |
+
}
|
| 707 |
+
# 清理临时参考音频文件
|
| 708 |
+
if reference_audio_path and os.path.exists(reference_audio_path):
|
| 709 |
+
os.remove(reference_audio_path)
|
| 710 |
+
return jsonify({
|
| 711 |
+
"session_id": session_id,
|
| 712 |
+
"audio_a": f"/api/tts/audio/{session_id}/a",
|
| 713 |
+
"audio_b": f"/api/tts/audio/{session_id}/b",
|
| 714 |
+
"expires_in": 1800,
|
| 715 |
+
"cache_hit": True,
|
| 716 |
+
})
|
| 717 |
+
# --- End Cache Miss ---
|
| 718 |
+
else:
|
| 719 |
+
app.logger.warning(f"TTS Cache MISS for: '{text[:50]}...'. Generating on the fly.")
|
| 720 |
+
available_models = Model.query.filter_by(
|
| 721 |
+
model_type=ModelType.TTS, is_active=True
|
| 722 |
+
).all()
|
| 723 |
+
if len(available_models) < 2:
|
| 724 |
+
return jsonify({"error": "Not enough TTS models available"}), 500
|
| 725 |
+
|
| 726 |
+
# Get two random models with weighted selection
|
| 727 |
+
models = get_weighted_random_models(available_models, 2, ModelType.TTS)
|
| 728 |
+
|
| 729 |
+
# Generate audio concurrently using a local executor for clarity within the request
|
| 730 |
+
with ThreadPoolExecutor(max_workers=2, thread_name_prefix='AudioGen') as audio_executor:
|
| 731 |
+
future_a = audio_executor.submit(generate_and_save_tts, text, models[0].id, RUNTIME_CACHE_DIR,
|
| 732 |
+
prompt_audio_path=reference_audio_path)
|
| 733 |
+
future_b = audio_executor.submit(generate_and_save_tts, text, models[1].id, RUNTIME_CACHE_DIR,
|
| 734 |
+
prompt_audio_path=reference_audio_path)
|
| 735 |
+
|
| 736 |
+
timeout_seconds = 120
|
| 737 |
+
audio_a_path, ref_a = future_a.result(timeout=timeout_seconds)
|
| 738 |
+
audio_b_path, ref_b = future_b.result(timeout=timeout_seconds)
|
| 739 |
+
|
| 740 |
+
if not audio_a_path or not audio_b_path:
|
| 741 |
+
return jsonify({"error": "Failed to generate TTS audio"}), 500
|
| 742 |
+
|
| 743 |
+
session_id = str(uuid.uuid4())
|
| 744 |
+
app.tts_sessions[session_id] = {
|
| 745 |
+
"model_a": models[0].id,
|
| 746 |
+
"model_b": models[1].id,
|
| 747 |
+
"audio_a": audio_a_path,
|
| 748 |
+
"audio_b": audio_b_path,
|
| 749 |
+
"text": text,
|
| 750 |
+
"created_at": datetime.utcnow(),
|
| 751 |
+
"expires_at": datetime.utcnow() + timedelta(minutes=30),
|
| 752 |
+
"voted": False,
|
| 753 |
+
}
|
| 754 |
+
|
| 755 |
+
# Clean up temporary reference audio file if it was provided
|
| 756 |
+
if reference_audio_path and os.path.exists(reference_audio_path):
|
| 757 |
+
os.remove(reference_audio_path)
|
| 758 |
+
|
| 759 |
+
# Return response with session ID and audio URLs
|
| 760 |
+
return jsonify(
|
| 761 |
+
{
|
| 762 |
+
"session_id": session_id,
|
| 763 |
+
"audio_a": f"/api/tts/audio/{session_id}/a",
|
| 764 |
+
"audio_b": f"/api/tts/audio/{session_id}/b",
|
| 765 |
+
"expires_in": 1800, # 30 minutes in seconds
|
| 766 |
+
"cache_hit": False,
|
| 767 |
+
}
|
| 768 |
+
)
|
| 769 |
+
|
| 770 |
|
| 771 |
|
| 772 |
@app.route("/api/tts/audio/<session_id>/<model_key>")
|
models.py
CHANGED
|
@@ -84,6 +84,7 @@ class EloHistory(db.Model):
|
|
| 84 |
model_id = db.Column(db.String(100), db.ForeignKey("model.id"), nullable=False)
|
| 85 |
timestamp = db.Column(db.DateTime, default=datetime.utcnow)
|
| 86 |
elo_score = db.Column(db.Float, nullable=False)
|
|
|
|
| 87 |
vote_id = db.Column(db.Integer, db.ForeignKey("vote.id"), nullable=True)
|
| 88 |
model_type = db.Column(db.String(20), nullable=False) # 'tts' or 'conversational'
|
| 89 |
|
|
@@ -130,18 +131,18 @@ def record_vote(user_id, text, chosen_model_id, rejected_model_id, model_type):
|
|
| 130 |
db.session.rollback()
|
| 131 |
return None, "One or both models not found for the specified model type"
|
| 132 |
|
| 133 |
-
|
| 134 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 135 |
|
| 136 |
-
|
| 137 |
-
|
| 138 |
-
chosen_model.current_elo, rejected_model.current_elo, k_factor_winner, k_factor_loser
|
| 139 |
)
|
| 140 |
|
| 141 |
-
# new_chosen_elo, new_rejected_elo = calculate_elo_change(
|
| 142 |
-
# chosen_model.current_elo, rejected_model.current_elo
|
| 143 |
-
# )
|
| 144 |
-
|
| 145 |
# Update model stats
|
| 146 |
chosen_model.current_elo = new_chosen_elo
|
| 147 |
chosen_model.win_count += 1
|
|
@@ -535,32 +536,69 @@ def toggle_user_leaderboard_visibility(user_id):
|
|
| 535 |
return user.show_in_leaderboard
|
| 536 |
|
| 537 |
|
| 538 |
-
def get_dynamic_k_factor(match_count):
|
| 539 |
-
"""
|
| 540 |
-
使用连续衰减函数动态计算K因子。
|
| 541 |
-
K因子会从一个最大值平滑地衰减到一个最小值。
|
| 542 |
-
|
| 543 |
-
Args:
|
| 544 |
-
match_count (int): 模型的总比赛次数。
|
| 545 |
-
|
| 546 |
-
Returns:
|
| 547 |
-
float: 计算出的K因子。
|
| 548 |
-
"""
|
| 549 |
k_max = 40 # 新模型的最大K因子
|
| 550 |
k_min = 10 # 成熟模型的最小K因子
|
| 551 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
| 552 |
|
| 553 |
-
#
|
| 554 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
| 555 |
|
| 556 |
return k_factor
|
| 557 |
|
| 558 |
-
def
|
| 559 |
-
"""
|
| 560 |
-
|
| 561 |
-
|
| 562 |
|
| 563 |
-
|
| 564 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 565 |
|
| 566 |
-
|
|
|
|
|
|
| 84 |
model_id = db.Column(db.String(100), db.ForeignKey("model.id"), nullable=False)
|
| 85 |
timestamp = db.Column(db.DateTime, default=datetime.utcnow)
|
| 86 |
elo_score = db.Column(db.Float, nullable=False)
|
| 87 |
+
by_system = db.Column(db.Boolean, default=False) # Whether this is a penalty or reward change
|
| 88 |
vote_id = db.Column(db.Integer, db.ForeignKey("vote.id"), nullable=True)
|
| 89 |
model_type = db.Column(db.String(20), nullable=False) # 'tts' or 'conversational'
|
| 90 |
|
|
|
|
| 131 |
db.session.rollback()
|
| 132 |
return None, "One or both models not found for the specified model type"
|
| 133 |
|
| 134 |
+
# --- ELO 计算逻辑与 test_elo.py 保持一致 ---
|
| 135 |
+
# a. 计算双方的基础动态K因子
|
| 136 |
+
max_match = max(chosen_model.match_count, rejected_model.match_count, 10)
|
| 137 |
+
k_winner_base = get_dynamic_k_factor(chosen_model.match_count, max_match)
|
| 138 |
+
k_loser_base = get_dynamic_k_factor(rejected_model.match_count, max_match)
|
| 139 |
+
# b. 取平均K因子
|
| 140 |
+
base_k = (k_winner_base + k_loser_base) / 2.0
|
| 141 |
|
| 142 |
+
new_chosen_elo, new_rejected_elo = calculate_elo_change(
|
| 143 |
+
chosen_model.current_elo, rejected_model.current_elo, k_factor=base_k
|
|
|
|
| 144 |
)
|
| 145 |
|
|
|
|
|
|
|
|
|
|
|
|
|
| 146 |
# Update model stats
|
| 147 |
chosen_model.current_elo = new_chosen_elo
|
| 148 |
chosen_model.win_count += 1
|
|
|
|
| 536 |
return user.show_in_leaderboard
|
| 537 |
|
| 538 |
|
| 539 |
+
def get_dynamic_k_factor(match_count, max_match_count):
    """Return an ELO K-factor that decays as a model accumulates matches.

    The factor follows an exponential decay over the model's *relative*
    progress (``match_count / max_match_count``):

        K = k_min + (k_max - k_min) * e^(-decay_factor * relative_progress)

    so a model with no matches gets ``k_max`` and the value approaches
    ``k_min`` as the relative progress grows.

    Args:
        match_count: Number of matches the model has played.
        max_match_count: Reference match count used to normalise
            ``match_count`` into a relative progress value.

    Returns:
        The K-factor. ``k_max`` is returned unchanged when
        ``max_match_count`` is not positive.
    """
    k_max = 40  # K-factor for a brand-new model
    k_min = 10  # K-factor floor for a mature model
    decay_factor = 5.0  # controls how quickly K falls from k_max towards k_min

    # Avoid division by zero. A negative denominator is treated the same way,
    # because it would flip the sign of the exponent and push K above k_max.
    if max_match_count <= 0:
        return k_max

    # Relative progress; clamped at 0 so a negative match_count cannot yield
    # a K-factor larger than k_max.
    relative_progress = max(match_count / max_match_count, 0.0)

    k_factor = k_min + (k_max - k_min) * math.exp(-decay_factor * relative_progress)

    return k_factor
|
| 556 |
|
| 557 |
+
def apply_filter_penalty_and_redistribute(unavailable_models, available_models, penalty_amount=1.0):
    """Penalise unavailable models and hand the deducted ELO to available ones.

    Each model in ``unavailable_models`` loses ``penalty_amount`` ELO. The
    total deducted is split evenly across ``available_models``, so the sum of
    all adjustments is zero. Every adjustment is recorded as an ``EloHistory``
    row with ``by_system=True`` and no ``vote_id``, and all changes are
    committed in a single transaction.

    Nothing happens when either list is empty: there is either nobody to
    penalise or nobody to receive the points.

    NOTE(review): a model present in both lists receives both adjustments;
    callers are presumably expected to pass disjoint lists.

    Args:
        unavailable_models (list[Model]): Models that were filtered out.
        available_models (list[Model]): Models that are currently usable.
        penalty_amount (float): ELO deducted from each unavailable model.

    Raises:
        Exception: Whatever the database layer raises while adding or
            committing; the session is rolled back before re-raising.
    """
    if not unavailable_models or not available_models:
        return

    # 1. Work out how much each available model receives.
    total_penalty = len(unavailable_models) * penalty_amount
    reward_per_model = total_penalty / len(available_models)

    # 2. Penalties first, then rewards (same order as the history rows).
    adjustments = [(model, -penalty_amount) for model in unavailable_models]
    adjustments += [(model, reward_per_model) for model in available_models]

    try:
        for model, delta in adjustments:
            model.current_elo = model.current_elo + delta
            # System-initiated change: no vote is associated with it.
            db.session.add(
                EloHistory(
                    model_id=model.id,
                    elo_score=model.current_elo,
                    vote_id=None,
                    by_system=True,
                    model_type=model.model_type,
                )
            )

        # 3. Persist every adjustment atomically.
        db.session.commit()
    except Exception:
        # Do not leave the session holding a half-applied redistribution.
        db.session.rollback()
        raise
|