Spaces:
Runtime error
Runtime error
| import logging | |
| import os | |
| from datetime import datetime | |
| from decimal import Decimal | |
| from typing import List | |
| import boto3 | |
| from boto3.dynamodb.conditions import Attr, Key | |
| from datasets import Dataset | |
# Logging verbosity is read from the LOG_LEVEL environment variable and
# falls back to INFO when that variable is unset.
logging.basicConfig(
    level=os.getenv("LOG_LEVEL", "INFO"),
)

# Module-wide DynamoDB service resource, pinned to us-east-1; the table
# helpers below all go through this object.
dynamodb = boto3.resource(
    "dynamodb",
    region_name="us-east-1",
)
def _create_arena_table():
    """Create the ``oaaic_chatbot_arena`` DynamoDB table.

    The table uses ``arena_battle_id`` (string) as its hash key and defines
    one global secondary index, ``TimestampIndex``, which pairs that same
    hash key with ``timestamp`` (string) as range key and projects all
    attributes.  Table and index are each provisioned with 5 read and
    5 write capacity units.
    """
    battle_id = 'arena_battle_id'
    battle_time = 'timestamp'

    timestamp_index = {
        'IndexName': 'TimestampIndex',
        'KeySchema': [
            {'AttributeName': battle_id, 'KeyType': 'HASH'},
            {'AttributeName': battle_time, 'KeyType': 'RANGE'},
        ],
        'Projection': {'ProjectionType': 'ALL'},
        'ProvisionedThroughput': {
            'ReadCapacityUnits': 5,
            'WriteCapacityUnits': 5,
        },
    }

    dynamodb.create_table(
        TableName='oaaic_chatbot_arena',
        KeySchema=[{'AttributeName': battle_id, 'KeyType': 'HASH'}],
        # Every attribute used in a key schema (table or index) must be
        # declared here with its type.
        AttributeDefinitions=[
            {'AttributeName': battle_id, 'AttributeType': 'S'},
            {'AttributeName': battle_time, 'AttributeType': 'S'},
        ],
        ProvisionedThroughput={
            'ReadCapacityUnits': 5,
            'WriteCapacityUnits': 5,
        },
        GlobalSecondaryIndexes=[timestamp_index],
    )
def _create_elo_scores_table():
    """Create the ``elo_scores`` DynamoDB table.

    The table has a single string hash key, ``chatbot_name``, no sort key
    and no secondary indexes, and is provisioned with 5 read and 5 write
    capacity units.
    """
    partition_key = 'chatbot_name'

    dynamodb.create_table(
        TableName='elo_scores',
        KeySchema=[
            {'AttributeName': partition_key, 'KeyType': 'HASH'},
        ],
        AttributeDefinitions=[
            {'AttributeName': partition_key, 'AttributeType': 'S'},
        ],
        ProvisionedThroughput={
            'ReadCapacityUnits': 5,
            'WriteCapacityUnits': 5,
        },
    )
def _create_elo_logs_table():
    """Create the ``elo_logs`` DynamoDB table.

    Primary key: ``arena_battle_id`` (hash) + ``battle_timestamp`` (range),
    both strings.  A global secondary index, ``AllTimestampIndex``, is keyed
    on the string attribute ``all`` (hash) + ``battle_timestamp`` (range)
    and projects all attributes.  Table and index are each provisioned with
    10 read and 10 write capacity units.
    """
    battle_id = 'arena_battle_id'
    battle_time = 'battle_timestamp'
    bucket = 'all'

    all_timestamp_index = {
        'IndexName': 'AllTimestampIndex',
        'KeySchema': [
            {'AttributeName': bucket, 'KeyType': 'HASH'},
            {'AttributeName': battle_time, 'KeyType': 'RANGE'},
        ],
        'Projection': {'ProjectionType': 'ALL'},
        'ProvisionedThroughput': {
            'ReadCapacityUnits': 10,
            'WriteCapacityUnits': 10,
        },
    }

    dynamodb.create_table(
        TableName='elo_logs',
        KeySchema=[
            {'AttributeName': battle_id, 'KeyType': 'HASH'},
            {'AttributeName': battle_time, 'KeyType': 'RANGE'},
        ],
        AttributeDefinitions=[
            {'AttributeName': battle_id, 'AttributeType': 'S'},
            {'AttributeName': battle_time, 'AttributeType': 'S'},
            {'AttributeName': bucket, 'AttributeType': 'S'},
        ],
        ProvisionedThroughput={
            'ReadCapacityUnits': 10,
            'WriteCapacityUnits': 10,
        },
        GlobalSecondaryIndexes=[all_timestamp_index],
    )
def get_unprocessed_battles(last_processed_timestamp):
    """Return all arena battles recorded after ``last_processed_timestamp``.

    Performs a filtered scan of the ``oaaic_chatbot_arena`` table, keeping
    items whose ``timestamp`` attribute is strictly greater than the given
    value.

    A scan returns its results one page at a time, and the filter is applied
    only after a page has been read, so a single call can return few (or
    zero) matches even though more exist further on.  The scan is therefore
    repeated, following ``LastEvaluatedKey``, until the table is exhausted.

    Args:
        last_processed_timestamp: Timestamp value to compare against; only
            battles with a later ``timestamp`` are returned.

    Returns:
        A list of matching items.  Scan results carry no ordering guarantee,
        so callers that need temporal order must sort the list themselves.
    """
    table = dynamodb.Table('oaaic_chatbot_arena')
    scan_kwargs = {
        'FilterExpression': Attr('timestamp').gt(last_processed_timestamp),
    }

    battles = []
    while True:
        response = table.scan(**scan_kwargs)
        battles.extend(response.get('Items', []))

        # Absent key means the last page has been read.
        last_key = response.get('LastEvaluatedKey')
        if not last_key:
            return battles
        scan_kwargs['ExclusiveStartKey'] = last_key
def calculate_elo(rating1, rating2, result, K=32):
    """Compute both players' updated Elo ratings after one match.

    Args:
        rating1: Current rating of the first player (anything ``float()``
            accepts).
        rating2: Current rating of the second player.
        result: Score of the first player: 1 for a win, 0.5 for a draw,
            0 for a loss.  The second player's score is ``1 - result``.
        K: Multiplier applied to the difference between actual and expected
            score.

    Returns:
        A ``(new_rating1, new_rating2)`` tuple of ``Decimal`` values, each
        quantized to two decimal places.
    """
    def _two_places(value):
        # Fixed two-decimal representation of a float rating.
        return Decimal(value).quantize(Decimal('0.00'))

    first = float(rating1)
    second = float(rating2)

    # Expected score of each side under the logistic Elo model.
    expected_first = 1.0 / (1.0 + 10.0 ** ((second - first) / 400.0))
    expected_second = 1.0 - expected_first

    actual_second = 1.0 - result
    updated_first = first + K * (result - expected_first)
    updated_second = second + K * (actual_second - expected_second)

    return _two_places(updated_first), _two_places(updated_second)
def get_last_processed_timestamp():
    """Return the ``battle_timestamp`` of the most recent ``elo_logs`` entry.

    Queries the ``AllTimestampIndex`` index for items whose ``all``
    attribute equals ``'ALL'``, in descending range-key order and limited to
    one item, so the single item returned is the newest one.

    Returns:
        The newest ``battle_timestamp``, or ``'1970-01-01T00:00:00'`` when
        the query yields no items.
    """
    newest = dynamodb.Table('elo_logs').query(
        IndexName='AllTimestampIndex',
        KeyConditionExpression=Key('all').eq('ALL'),
        ScanIndexForward=False,  # descending by battle_timestamp
        Limit=1,
    )['Items']

    if newest:
        return newest[0]['battle_timestamp']

    # Nothing logged yet: fall back to a default timestamp.
    return '1970-01-01T00:00:00'
def log_elo_update(arena_battle_id, battle_timestamp, new_rating1, new_rating2):
    """Write one rating-update record to the ``elo_logs`` table.

    Args:
        arena_battle_id: Identifier of the battle that was processed.
        battle_timestamp: Timestamp of the battle itself.
        new_rating1: Updated rating stored as ``new_rating1``.
        new_rating2: Updated rating stored as ``new_rating2``.

    The record also gets a ``log_timestamp`` (the current local time in ISO
    format) and the constant ``all`` = ``'ALL'``, which is the value
    ``get_last_processed_timestamp`` queries for on ``AllTimestampIndex``.
    """
    logs = dynamodb.Table('elo_logs')

    entry = {
        'arena_battle_id': arena_battle_id,
        'battle_timestamp': battle_timestamp,
        # When the log row was written, as opposed to when the battle ran.
        'log_timestamp': datetime.now().isoformat(),
        'new_rating1': new_rating1,
        'new_rating2': new_rating2,
        'all': 'ALL',
    }
    logs.put_item(Item=entry)
def get_elo_score(chatbot_name, elo_scores):
    """Look up a chatbot's current Elo score.

    Args:
        chatbot_name: Name of the chatbot.
        elo_scores: Mapping of already-known scores, consulted first.

    Returns:
        The value in ``elo_scores`` when the name is present there;
        otherwise the ``elo_score`` stored in the ``elo_scores`` table, or
        1500 when the table has no item for that chatbot.
    """
    if chatbot_name not in elo_scores:
        lookup = dynamodb.Table('elo_scores').get_item(
            Key={'chatbot_name': chatbot_name},
        )
        # A chatbot with no stored item starts at the default rating.
        return lookup['Item']['elo_score'] if 'Item' in lookup else 1500

    return elo_scores[chatbot_name]
def update_elo_score(chatbot_name, new_elo_score):
    """Store a chatbot's Elo score in the ``elo_scores`` table.

    Args:
        chatbot_name: Key of the item to write.
        new_elo_score: New score; converted via ``str`` to ``Decimal``
            before being stored.

    ``put_item`` writes the whole item, so an existing entry for the chatbot
    is replaced and a missing one is created.
    """
    scores = dynamodb.Table('elo_scores')

    record = {
        'chatbot_name': chatbot_name,
        'elo_score': Decimal(str(new_elo_score)),
    }
    scores.put_item(Item=record)
def get_elo_scores():
    """Return every item stored in the ``elo_scores`` table.

    A scan returns its results one page at a time, so the scan is repeated,
    following ``LastEvaluatedKey``, until all pages have been read.  A
    single call would silently drop everything after the first page.

    Returns:
        A list of the table's items, in no guaranteed order.
    """
    table = dynamodb.Table('elo_scores')

    items = []
    scan_kwargs = {}
    while True:
        response = table.scan(**scan_kwargs)
        items.extend(response.get('Items', []))

        # Absent key means the last page has been read.
        last_key = response.get('LastEvaluatedKey')
        if not last_key:
            return items
        scan_kwargs['ExclusiveStartKey'] = last_key
def _backfill_logs():
    """Set ``all`` = ``'ALL'`` on every item in the ``elo_logs`` table.

    ``all`` is the hash key of the ``AllTimestampIndex`` index, so items
    that lack it do not appear in that index; this stamps it onto every
    existing item.

    The scan is repeated, following ``LastEvaluatedKey``, until all pages
    have been read.  A single scan call would backfill only the first page.
    """
    table = dynamodb.Table('elo_logs')

    scan_kwargs = {}
    while True:
        response = table.scan(**scan_kwargs)

        for item in response.get('Items', []):
            table.update_item(
                Key={
                    'arena_battle_id': item['arena_battle_id'],
                    'battle_timestamp': item['battle_timestamp'],
                },
                # The attribute name is passed through a placeholder
                # rather than written literally in the expression.
                UpdateExpression="SET #all = :value",
                ExpressionAttributeNames={'#all': 'all'},
                ExpressionAttributeValues={':value': 'ALL'},
            )

        # Absent key means the last page has been read.
        last_key = response.get('LastEvaluatedKey')
        if not last_key:
            break
        scan_kwargs['ExclusiveStartKey'] = last_key
def main():
    """Process new arena battles, update Elo scores and publish them.

    Steps:
      1. Fetch the battles newer than the last logged battle timestamp and
         order them by ``timestamp``.
      2. For each battle whose ``label`` is -1, 0, 1 or 2, recompute both
         chatbots' ratings, log the update and store the new scores.
         Battles with any other label are printed but otherwise skipped.
      3. Reload all stored scores, convert them to ``float`` and print them.
      4. If at least one battle was fetched, push the scores to the
         Hugging Face Hub dataset.
    """
    # Battle label -> Elo result from the first chatbot's point of view:
    # 1 means it won, 0.5 a draw, 0 a loss.  Labels -1 and 0 both count as
    # a draw; label 2 counts as a loss for the first chatbot.
    result_for_label = {-1: 0.5, 0: 0.5, 1: 1, 2: 0}

    cutoff = get_last_processed_timestamp()
    battles = sorted(
        get_unprocessed_battles(cutoff),
        key=lambda entry: entry['timestamp'],
    )

    # Scores seen during this run, so each chatbot is fetched at most once.
    elo_scores = {}

    for battle in battles:
        print(repr(battle))

        label = battle['label']
        if label not in result_for_label:
            continue
        elo_result = result_for_label[label]

        first = battle['choice1_name']
        second = battle['choice2_name']
        for chatbot_name in (first, second):
            if chatbot_name not in elo_scores:
                elo_scores[chatbot_name] = get_elo_score(chatbot_name, elo_scores)

        new_rating1, new_rating2 = calculate_elo(elo_scores[first], elo_scores[second], elo_result)
        logging.info(f"{first}: {elo_scores[first]} -> {new_rating1} | {second}: {elo_scores[second]} -> {new_rating2}")

        # Later battles in this run must see the updated ratings.
        elo_scores[first] = new_rating1
        elo_scores[second] = new_rating2

        log_elo_update(battle['arena_battle_id'], battle['timestamp'], new_rating1, new_rating2)
        update_elo_score(first, new_rating1)
        update_elo_score(second, new_rating2)

    leaderboard = get_elo_scores()
    for row in leaderboard:
        # Stored scores are converted to plain floats before publishing.
        row["elo_score"] = float(row["elo_score"])
    print(leaderboard)

    if battles:
        # Convert the data into a format suitable for Hugging Face Dataset
        elo_dataset = Dataset.from_list(leaderboard)
        elo_dataset.push_to_hub("openaccess-ai-collective/chatbot-arena-elo-scores", private=False)
# Run the Elo update job only when this file is executed as a script,
# not when it is imported as a module.
if __name__ == "__main__":
    main()