handle the correct prompting style

app.py CHANGED
@@ -21,12 +21,26 @@ dynamodb = boto3.resource('dynamodb', region_name='us-east-1')
 # Get a reference to the table
 table = dynamodb.Table('oaaic_chatbot_arena')
 
+
+def prompt_instruct(system_msg, history):
+    return system_msg.strip() + "\n" + \
+        "\n".join(["\n".join(["### Instruction: "+item[0], "### Response: "+item[1]])
+                   for item in history])
+
+
+def prompt_chat(system_msg, history):
+    return system_msg.strip() + "\n" + \
+        "\n".join(["\n".join(["USER: "+item[0], "ASSISTANT: "+item[1]])
+                   for item in history])
+
+
 class Pipeline:
     prefer_async = True
 
-    def __init__(self, endpoint_id, name):
+    def __init__(self, endpoint_id, name, prompt_fn):
         self.endpoint_id = endpoint_id
         self.name = name
+        self.prompt_fn = prompt_fn
         self.generation_config = {
             "max_tokens": 1024,
             "top_k": 40,
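The two builders differ only in their role labels. As a quick illustration (the one-turn history below is made up, not from the repo), each renders the same exchange like this:

    history = [("What is 2+2?", "4")]
    system = "- The Assistant is helpful and transparent."

    prompt_chat(system, history)
    # - The Assistant is helpful and transparent.
    # USER: What is 2+2?
    # ASSISTANT: 4

    prompt_instruct(system, history)
    # - The Assistant is helpful and transparent.
    # ### Instruction: What is 2+2?
    # ### Response: 4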
@@ -37,7 +51,7 @@ class Pipeline:
         "seed": -1,
         "batch_size": 8,
         "threads": -1,
-        "stop": ["</s>", "USER:"],
+        "stop": ["</s>", "USER:", "### Instruction:"],
     }
 
     def __call__(self, prompt):
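The extra stop string mirrors the existing "USER:" entry: now that generation can run against either template, a model fed the instruct format could otherwise keep going and fabricate a "### Instruction:" turn of its own before hitting a stop token.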
@@ -79,13 +93,16 @@ class Pipeline:
         # Sleep for 3 seconds between each request
         sleep(3)
 
+    def transform_prompt(self, system_msg, history):
+        return self.prompt_fn(system_msg, history)
+
 
 AVAILABLE_MODELS = {
-    "hermes-13b": "p0zqb2gkcwp0ww",
-    "manticore-13b-chat": "u6tv84bpomhfei",
-    "airoboros-13b": "rglzxnk80660ja",
-    "supercot-13b": "0be7865dwxpwqk",
-    "mpt-7b-instruct": "jpqbvnyluj18b0",
+    "hermes-13b": ("p0zqb2gkcwp0ww", prompt_instruct),
+    "manticore-13b-chat": ("u6tv84bpomhfei", prompt_chat),
+    "airoboros-13b": ("rglzxnk80660ja", prompt_chat),
+    "supercot-13b": ("0be7865dwxpwqk", prompt_instruct),
+    "mpt-7b-instruct": ("jpqbvnyluj18b0", prompt_instruct),
 }
 
 _memoized_models = defaultdict()
@@ -93,7 +110,7 @@ _memoized_models = defaultdict()
 
 def get_model_pipeline(model_name):
     if not _memoized_models.get(model_name):
-        _memoized_models[model_name] = Pipeline(AVAILABLE_MODELS[model_name], model_name)
+        _memoized_models[model_name] = Pipeline(AVAILABLE_MODELS[model_name][0], model_name, AVAILABLE_MODELS[model_name][1])
     return _memoized_models.get(model_name)
 
 start_message = """- The Assistant is helpful and transparent.
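Each registry value is now an (endpoint_id, prompt_fn) tuple, so the factory indexes it twice on one line. A slightly more readable equivalent (a sketch, not what the commit ships) would unpack the tuple first:

    def get_model_pipeline(model_name):
        if model_name not in _memoized_models:
            endpoint_id, prompt_fn = AVAILABLE_MODELS[model_name]
            _memoized_models[model_name] = Pipeline(endpoint_id, model_name, prompt_fn)
        return _memoized_models[model_name]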
@@ -116,20 +133,17 @@ def chat(history1, history2, system_msg):
     history1 = history1 or []
     history2 = history2 or []
 
-
-
-
-
-
-
+    random_battle = random.sample(AVAILABLE_MODELS.keys(), 2)
+    model1 = get_model_pipeline(random_battle[0])
+    model2 = get_model_pipeline(random_battle[1])
+
+    messages1 = model1.transform_prompt(system_msg, history1)
+    messages2 = model2.transform_prompt(system_msg, history2)
 
     # remove last space from assistant, some models output a ZWSP if you leave a space
     messages1 = messages1.rstrip()
     messages2 = messages2.rstrip()
 
-    random_battle = random.sample(AVAILABLE_MODELS.keys(), 2)
-    model1 = get_model_pipeline(random_battle[0])
-    model2 = get_model_pipeline(random_battle[1])
 
     with concurrent.futures.ThreadPoolExecutor(max_workers=2) as executor:
         futures = []
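Net effect in chat(): the two contenders are drawn before any prompt is built, and each formats the shared system message plus its own history with its own template. A condensed sketch of the new flow (the history pair is hypothetical):

    name1, name2 = random.sample(list(AVAILABLE_MODELS), 2)
    model1 = get_model_pipeline(name1)
    model2 = get_model_pipeline(name2)

    messages1 = model1.transform_prompt(start_message, [("Hello", "Hi there!")])
    messages2 = model2.transform_prompt(start_message, [("Hello", "Hi there!")])

One caveat: random.sample over dict keys is deprecated since Python 3.9 and raises a TypeError on 3.11+, so the list(...) wrapper above (not present in the commit itself) sidesteps that.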
@@ -212,14 +226,14 @@ with gr.Blocks() as arena:
         dismiss_reveal = gr.Button(value="Dismiss & Continue", variant="secondary", visible=False).style(full_width=True)
     with gr.Row():
         with gr.Column():
-            rlhf_persona = gr.Textbox(
-                "", label="Persona Tags", interactive=True, visible=True, placeholder="Tell us about how you are judging the quality. ex: #CoT #SFW #NSFW #helpful #ethical #creativity", lines=2)
             message = gr.Textbox(
                 label="What do you want to ask?",
                 placeholder="Ask me anything.",
                 lines=3,
             )
         with gr.Column():
+            rlhf_persona = gr.Textbox(
+                "", label="Persona Tags", interactive=True, visible=True, placeholder="Tell us about how you are judging the quality. ex: #CoT #SFW #NSFW #helpful #ethical #creativity", lines=2)
             system_msg = gr.Textbox(
                 start_message, label="System Message", interactive=True, visible=True, placeholder="system prompt", lines=8)
 