Spaces: Running on Zero
| import numpy as np | |
| import json | |
| from trueskill import TrueSkill | |
| import paramiko | |
| import io, os | |
| import sys | |
| from serve.constants import SSH_SERVER, SSH_PORT, SSH_USER, SSH_PASSWORD, SSH_SKILL | |
# Shared TrueSkill environment (constructed with the library defaults) used by
# every rating update and deserialization in this module.
trueskill_env = TrueSkill()
# Make the parent directory importable so the `model` package below resolves.
# NOTE(review): '../' is resolved against the process working directory, not
# this file's location — confirm the service is always launched from here.
sys.path.append('../')
from model.models import IMAGE_GENERATION_MODELS
# SSH connection and SFTP session used to read/write the skill file; both start
# unset and are (re)populated by create_ssh_skill_client().
ssh_skill_client = None
sftp_skill_client = None
def create_ssh_skill_client(server, port, user, password):
    """Open (or re-open) the SSH connection and SFTP session for the skill file.

    Any previously created clients are closed first, so reconnecting after a
    dropped connection does not leak the old transport. The module globals
    ``ssh_skill_client`` and ``sftp_skill_client`` are set to ``None`` while
    connecting. They are only assigned the new objects once both the SSH
    connection and the SFTP session have been opened successfully.

    Args:
        server: Hostname or address of the SSH server.
        port: SSH port of the server.
        user: Username to authenticate as.
        password: Password for ``user``.

    Raises:
        Whatever ``paramiko.SSHClient.connect`` or ``open_sftp`` raise. The
        partially opened client is closed before the exception propagates.
    """
    global ssh_skill_client, sftp_skill_client

    # Release the previous connection, if any, so reconnects do not leak it.
    for stale in (sftp_skill_client, ssh_skill_client):
        if stale is not None:
            try:
                stale.close()
            except Exception as e:  # best-effort cleanup of a dead connection
                print(f"Error closing stale SSH/SFTP client: {e}")
    ssh_skill_client = None
    sftp_skill_client = None

    client = paramiko.SSHClient()
    client.load_system_host_keys()
    # NOTE(review): AutoAddPolicy accepts unknown host keys without
    # verification (MITM risk). Kept for compatibility; consider RejectPolicy
    # with a pinned known_hosts entry.
    client.set_missing_host_key_policy(paramiko.AutoAddPolicy())
    try:
        client.connect(server, port=port, username=user, password=password)
        # Send a keepalive packet on the transport every 60 seconds.
        client.get_transport().set_keepalive(60)
        sftp = client.open_sftp()
    except Exception:
        client.close()
        raise

    ssh_skill_client = client
    sftp_skill_client = sftp
def is_connected():
    """Return True if the SSH/SFTP skill clients exist and are usable.

    The connection counts as usable when all of these hold:

    * both module-level clients have been created;
    * the SSH transport exists and is active;
    * a cheap ``listdir('.')`` round-trip over SFTP succeeds.

    Any error from that round-trip is printed and reported as "not connected".
    """
    if ssh_skill_client is None or sftp_skill_client is None:
        return False
    # get_transport() returns None once the client has been closed; treat that
    # as disconnected instead of raising AttributeError on .is_active().
    transport = ssh_skill_client.get_transport()
    if transport is None or not transport.is_active():
        return False
    try:
        sftp_skill_client.listdir('.')
    except Exception as e:
        print(f"Error checking SFTP connection: {e}")
        return False
    return True
def ucb_score(trueskill_diff, t, n, exploration_weight=1.0):
    """Upper-confidence-bound style score.

    The score is
    ``-trueskill_diff + exploration_weight * sqrt(2 * ln(t) / n)``.
    A smaller ``trueskill_diff`` and a smaller ``n`` both raise it. Small
    epsilons keep the division finite when ``n`` is 0. ``ln(t)`` is clamped
    at 0, so ``t < 1`` (e.g. before any comparison) yields no exploration
    bonus instead of NaN. Works element-wise on NumPy arrays.

    Args:
        trueskill_diff: Rating difference of the candidate (scalar or array).
        t: Total number of comparisons made so far.
        n: Comparison count of the candidate being scored (presumably the
            pairwise count — confirm with the caller).
        exploration_weight: Multiplier on the exploration bonus. It defaults
            to 1.0, the value previously hard-coded.

    Returns:
        The UCB score (NumPy scalar or array).
    """
    # ln(t + eps) is negative for t < 1, which made sqrt() return NaN.
    log_t = np.log(np.maximum(t + 1e-5, 1.0))
    exploration_term = np.sqrt((2 * log_t) / (n + 1e-5))
    return -trueskill_diff + exploration_weight * exploration_term
def update_trueskill(ratings, ranks):
    """Run one TrueSkill match update using the module-level environment.

    Args:
        ratings: Rating groups, one sequence of ``Rating`` objects per team.
        ranks: Rank of each group, in the same order as ``ratings``. Per the
            TrueSkill convention, a lower rank is a better placement.

    Returns:
        The updated rating groups returned by ``TrueSkill.rate``.
    """
    return trueskill_env.rate(ratings, ranks=ranks)
def serialize_rating(rating):
    """Convert a rating into a JSON-friendly dict holding its ``mu`` and ``sigma``."""
    return dict(mu=rating.mu, sigma=rating.sigma)
def deserialize_rating(rating_dict):
    """Rebuild a TrueSkill ``Rating`` from a ``{'mu': ..., 'sigma': ...}`` dict.

    This is the inverse of ``serialize_rating``. It raises ``KeyError`` if
    either key is missing.
    """
    mu = rating_dict['mu']
    sigma = rating_dict['sigma']
    return trueskill_env.Rating(mu=mu, sigma=sigma)
def save_json_via_sftp(ratings, comparison_counts, total_comparisons):
    """Write the skill state as JSON to the remote ``SSH_SKILL`` path.

    Reconnects first when ``is_connected()`` reports the session unusable.

    Args:
        ratings: Iterable of TrueSkill ratings; each is stored as a
            ``{'mu': ..., 'sigma': ...}`` dict.
        comparison_counts: NumPy array of comparison counts (converted with
            ``.tolist()``).
        total_comparisons: Overall comparison counter, stored as-is.
    """
    connected = is_connected()
    if not connected:
        create_ssh_skill_client(SSH_SERVER, SSH_PORT, SSH_USER, SSH_PASSWORD)

    payload = json.dumps({
        'ratings': list(map(serialize_rating, ratings)),
        'comparison_counts': comparison_counts.tolist(),
        'total_comparisons': total_comparisons,
    })

    # NOTE(review): the file is overwritten in place, so an interrupted write
    # can leave truncated JSON behind; consider temp-file + rename.
    with sftp_skill_client.open(SSH_SKILL, 'w') as remote_file:
        remote_file.write(payload)
def load_json_via_sftp():
    """Read the skill state JSON from the remote ``SSH_SKILL`` path.

    Reconnects first when ``is_connected()`` reports the session unusable.

    Returns:
        Tuple ``(ratings, comparison_counts, total_comparisons)``, where:

        * ``ratings`` is a list of TrueSkill ``Rating`` objects;
        * ``comparison_counts`` is a NumPy array;
        * ``total_comparisons`` is taken verbatim from the JSON.
    """
    if not is_connected():
        create_ssh_skill_client(SSH_SERVER, SSH_PORT, SSH_USER, SSH_PASSWORD)

    with sftp_skill_client.open(SSH_SKILL, 'r') as remote_file:
        state = json.load(remote_file)

    ratings = list(map(deserialize_rating, state['ratings']))
    return ratings, np.array(state['comparison_counts']), state['total_comparisons']
def update_skill(rank, model_names, k_group=4):
    """Record one ranked comparison between image-generation models.

    The steps are:

    1. Load the stored skill state over SFTP.
    2. Apply a single TrueSkill update in which every model is a one-player
       team.
    3. Increment the symmetric pairwise comparison counts for every pair in
       the group.
    4. Increment the total counter once per call.
    5. Write the state back.
    6. Remove the group's first model from ``RunningPivot.running_pivot`` if
       it is present there.

    Args:
        rank: Ranks passed to TrueSkill, one per entry of ``model_names``
            (lower is better per TrueSkill convention).
        model_names: Names of the compared models. Each must appear in
            ``IMAGE_GENERATION_MODELS``; ``list.index`` raises ``ValueError``
            otherwise.
        k_group: Unused; kept so existing callers keep working.
    """
    from itertools import combinations

    ratings, comparison_counts, total_comparisons = load_json_via_sftp()

    # Translate model names into indices of the stored ratings / count matrix.
    group = [IMAGE_GENERATION_MODELS.index(model_name) for model_name in model_names]
    print(group)

    # TrueSkill rates teams; every model is its own one-player team here.
    group_ratings = [[ratings[player]] for player in group]
    updated_group_ratings = update_trueskill(group_ratings, ranks=rank)
    for player, team in zip(group, updated_group_ratings):
        ratings[player] = team[0]

    # Every pair in the group was compared once more; keep the matrix symmetric.
    for a, b in combinations(group, 2):
        comparison_counts[a, b] += 1
        comparison_counts[b, a] += 1

    total_comparisons += 1
    save_json_via_sftp(ratings, comparison_counts, total_comparisons)

    # Imported here rather than at module level — presumably to avoid a
    # circular import with model.matchmaker; confirm before moving it.
    from model.matchmaker import RunningPivot
    if group[0] in RunningPivot.running_pivot:
        RunningPivot.running_pivot.remove(group[0])