arena / chat_state.py
terryyz
init
fe3f5b0
'''
Chat State and Logging
'''
import json
import os
from typing import Any, Literal, Optional
from conversation import Conversation
import datetime
import uuid
LOG_DIR = os.environ.get("LOGDIR", "./logs")
'''
Default output directory for log files.
Taken from the ``LOGDIR`` environment variable, falling back to ``./logs``.
'''
class ModelChatState:
'''
The state of a chat with a model.
'''
is_vision: bool
'''
Whether the model is vision based.
'''
conv: Conversation
'''
The conversation
'''
conv_id: str
'''
Unique identifier for the model conversation.
Unique per chat per model.
'''
chat_session_id: str
'''
Unique identifier for the chat session.
Unique per chat. The two battle models share the same chat session id.
'''
skip_next: bool
'''
Flag to indicate skipping the next operation.
'''
model_name: str
'''
Name of the model being used.
'''
oai_thread_id: Optional[str]
'''
Identifier for the OpenAI thread.
'''
has_csam_image: bool
'''
Indicates if a CSAM image has been uploaded.
'''
regen_support: bool
'''
Indicates if regeneration is supported for the model.
'''
chat_start_time: datetime.datetime
'''
Chat start time.
'''
chat_mode: Literal['battle_anony', 'battle_named', 'direct']
'''
Chat mode.
'''
curr_response_type: Literal['chat_multi', 'chat_single', 'regenerate_multi', 'regenerate_single'] | None
'''
Current response type. Used for logging.
'''
@staticmethod
def create_chat_session_id() -> str:
'''
Create a new chat session id.
'''
return uuid.uuid4().hex
@staticmethod
def create_battle_chat_states(
model_name_1: str, model_name_2: str,
chat_mode: Literal['battle_anony', 'battle_named'],
is_vision: bool,
) -> tuple['ModelChatState', 'ModelChatState']:
'''
Create two chat states for a battle.
'''
chat_session_id = ModelChatState.create_chat_session_id()
return (
ModelChatState(model_name_1, chat_mode,
is_vision=is_vision,
chat_session_id=chat_session_id),
ModelChatState(model_name_2, chat_mode,
is_vision=is_vision,
chat_session_id=chat_session_id),
)
def __init__(self,
model_name: str,
chat_mode: Literal['battle_anony', 'battle_named', 'direct'],
is_vision: bool,
chat_session_id: str | None = None,
):
from fastchat.model.model_adapter import get_conversation_template
self.conv = get_conversation_template(model_name)
self.conv_id = uuid.uuid4().hex
# if no chat session id is provided, use the conversation id
self.chat_session_id = chat_session_id if chat_session_id else self.conv_id
self.chat_start_time = datetime.datetime.now()
self.chat_mode = chat_mode
self.skip_next = False
self.model_name = model_name
self.oai_thread_id = None
self.is_vision = is_vision
# NOTE(chris): This could be sort of a hack since it assumes the user only uploads one image. If they can upload multiple, we should store a list of image hashes.
self.has_csam_image = False
self.regen_support = True
if "browsing" in model_name:
self.regen_support = False
self.init_system_prompt(self.conv, is_vision)
def init_system_prompt(self, conv, is_vision):
system_prompt = conv.get_system_message(is_vision)
if len(system_prompt) == 0:
return
current_date = datetime.datetime.now().strftime("%Y-%m-%d")
system_prompt = system_prompt.replace("{{currentDateTime}}", current_date)
current_date_v2 = datetime.datetime.now().strftime("%d %b %Y")
system_prompt = system_prompt.replace("{{currentDateTimev2}}", current_date_v2)
current_date_v3 = datetime.datetime.now().strftime("%B %Y")
system_prompt = system_prompt.replace("{{currentDateTimev3}}", current_date_v3)
conv.set_system_message(system_prompt)
def set_response_type(
self,
response_type: Literal['chat_multi', 'chat_single', 'regenerate_multi', 'regenerate_single']
):
'''
Set the response type for the chat state.
'''
self.curr_response_type = response_type
def to_gradio_chatbot(self):
'''
Convert to a Gradio chatbot.
'''
return self.conv.to_gradio_chatbot()
def get_conv_log_filepath(self, path_prefix: str):
'''
Get the filepath for the conversation log.
Expected directory structure:
softwarearenlog/
└── YEAR_MONTH_DAY/
β”œβ”€β”€ conv_logs/
└── sandbox_logs/
'''
date_str = self.chat_start_time.strftime('%Y_%m_%d')
filepath = os.path.join(
path_prefix,
date_str,
'conv_logs',
self.chat_mode,
f"conv-log-{self.chat_session_id}.json"
)
return filepath
def to_dict(self):
base = self.conv.to_dict()
base.update(
{
"chat_session_id": self.chat_session_id,
"conv_id": self.conv_id,
"chat_mode": self.chat_mode,
"chat_start_time": self.chat_start_time,
"model_name": self.model_name,
}
)
if self.is_vision:
base.update({"has_csam_image": self.has_csam_image})
return base
def generate_vote_record(
self,
vote_type: str,
ip: str
) -> dict[str, Any]:
'''
Generate a vote record for telemertry.
'''
data = {
"tstamp": round(datetime.datetime.now().timestamp(), 4),
"type": vote_type,
"model": self.model_name,
"state": self.to_dict(),
"ip": ip,
}
return data
def generate_response_record(
self,
gen_params: dict[str, Any],
start_ts: float,
end_ts: float,
ip: str
) -> dict[str, Any]:
'''
Generate a vote record for telemertry.
'''
data = {
"tstamp": round(datetime.datetime.now().timestamp(), 4),
"type": self.curr_response_type,
"model": self.model_name,
"start_ts": round(start_ts, 4),
"end_ts": round(end_ts, 4),
"gen_params": gen_params,
"state": self.to_dict(),
"ip": ip,
}
return data
def save_log_to_local(
    log_data: dict[str, Any],
    log_path: str,
    write_mode: Literal['overwrite', 'append'] = 'append'
):
    '''
    Save the log locally as one JSON object per line.

    Args:
        log_data: The record to write. Values that are not JSON serializable
            are converted with ``str``.
        log_path: Destination file. Missing parent directories are created.
        write_mode: ``'overwrite'`` truncates the file first; any other value
            appends to it.
    '''
    log_json = json.dumps(log_data, default=str)

    # os.path.dirname() is '' for a bare filename, and os.makedirs('') raises,
    # so only create the parent directory when there is one.
    log_dir = os.path.dirname(log_path)
    if log_dir:
        os.makedirs(log_dir, exist_ok=True)

    file_mode = 'w' if write_mode == 'overwrite' else 'a'
    # Explicit encoding so the log format does not depend on the platform locale.
    with open(log_path, file_mode, encoding='utf-8') as fout:
        fout.write(log_json + "\n")