Spaces:
Runtime error
Runtime error
Update app.py
Browse files
app.py
CHANGED
|
@@ -232,17 +232,17 @@ def bot(history):
|
|
| 232 |
our_chatbot.conversation.append_message(our_chatbot.conversation.roles[1], None)
|
| 233 |
prompt = our_chatbot.conversation.get_prompt()
|
| 234 |
|
| 235 |
-
|
| 236 |
-
|
| 237 |
-
# prompt, our_chatbot.tokenizer, IMAGE_TOKEN_INDEX, return_tensors="pt"
|
| 238 |
-
# )
|
| 239 |
-
# .unsqueeze(0)
|
| 240 |
-
# .to(our_chatbot.model.device)
|
| 241 |
-
# )
|
| 242 |
-
input_ids = tokenizer_image_token(
|
| 243 |
prompt, our_chatbot.tokenizer, IMAGE_TOKEN_INDEX, return_tensors="pt"
|
| 244 |
-
)
|
| 245 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 246 |
stop_str = (
|
| 247 |
our_chatbot.conversation.sep
|
| 248 |
if our_chatbot.conversation.sep_style != SeparatorStyle.TWO
|
|
@@ -252,58 +252,58 @@ def bot(history):
|
|
| 252 |
stopping_criteria = KeywordsStoppingCriteria(
|
| 253 |
keywords, our_chatbot.tokenizer, input_ids
|
| 254 |
)
|
| 255 |
-
|
| 256 |
-
# our_chatbot.tokenizer, skip_prompt=True, skip_special_tokens=True
|
| 257 |
-
# )
|
| 258 |
-
streamer = TextIteratorStreamer(
|
| 259 |
our_chatbot.tokenizer, skip_prompt=True, skip_special_tokens=True
|
| 260 |
)
|
|
|
|
|
|
|
|
|
|
| 261 |
print(our_chatbot.model.device)
|
| 262 |
print(input_ids.device)
|
| 263 |
print(image_tensor.device)
|
| 264 |
# import pdb;pdb.set_trace()
|
| 265 |
-
|
| 266 |
-
|
| 267 |
-
|
| 268 |
-
|
| 269 |
-
|
| 270 |
-
|
| 271 |
-
|
| 272 |
-
|
| 273 |
-
|
| 274 |
-
|
| 275 |
-
|
| 276 |
-
|
| 277 |
-
|
| 278 |
-
|
| 279 |
-
|
| 280 |
-
|
| 281 |
-
|
| 282 |
-
|
| 283 |
-
|
| 284 |
-
|
| 285 |
-
generate_kwargs = dict(
|
| 286 |
-
|
| 287 |
-
|
| 288 |
-
|
| 289 |
-
|
| 290 |
-
|
| 291 |
-
|
| 292 |
-
|
| 293 |
-
|
| 294 |
-
|
| 295 |
-
)
|
| 296 |
|
| 297 |
-
t = Thread(target=our_chatbot.model.generate, kwargs=generate_kwargs)
|
| 298 |
-
t.start()
|
| 299 |
|
| 300 |
-
outputs = []
|
| 301 |
-
for text in streamer:
|
| 302 |
-
|
| 303 |
-
|
| 304 |
|
| 305 |
-
our_chatbot.conversation.messages[-1][-1] = "".join(outputs)
|
| 306 |
-
history[-1] = [text, "".join(outputs)]
|
| 307 |
|
| 308 |
|
| 309 |
txt = gr.Textbox(
|
|
|
|
| 232 |
our_chatbot.conversation.append_message(our_chatbot.conversation.roles[1], None)
|
| 233 |
prompt = our_chatbot.conversation.get_prompt()
|
| 234 |
|
| 235 |
+
input_ids = (
|
| 236 |
+
tokenizer_image_token(
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 237 |
prompt, our_chatbot.tokenizer, IMAGE_TOKEN_INDEX, return_tensors="pt"
|
| 238 |
+
)
|
| 239 |
+
.unsqueeze(0)
|
| 240 |
+
.to(our_chatbot.model.device)
|
| 241 |
+
)
|
| 242 |
+
# input_ids = tokenizer_image_token(
|
| 243 |
+
# prompt, our_chatbot.tokenizer, IMAGE_TOKEN_INDEX, return_tensors="pt"
|
| 244 |
+
# ).unsqueeze(0).to(our_chatbot.model.device)
|
| 245 |
+
# print("### input_id",input_ids)
|
| 246 |
stop_str = (
|
| 247 |
our_chatbot.conversation.sep
|
| 248 |
if our_chatbot.conversation.sep_style != SeparatorStyle.TWO
|
|
|
|
| 252 |
stopping_criteria = KeywordsStoppingCriteria(
|
| 253 |
keywords, our_chatbot.tokenizer, input_ids
|
| 254 |
)
|
| 255 |
+
streamer = TextStreamer(
|
|
|
|
|
|
|
|
|
|
| 256 |
our_chatbot.tokenizer, skip_prompt=True, skip_special_tokens=True
|
| 257 |
)
|
| 258 |
+
# streamer = TextIteratorStreamer(
|
| 259 |
+
# our_chatbot.tokenizer, skip_prompt=True, skip_special_tokens=True
|
| 260 |
+
# )
|
| 261 |
print(our_chatbot.model.device)
|
| 262 |
print(input_ids.device)
|
| 263 |
print(image_tensor.device)
|
| 264 |
# import pdb;pdb.set_trace()
|
| 265 |
+
with torch.inference_mode():
|
| 266 |
+
output_ids = our_chatbot.model.generate(
|
| 267 |
+
input_ids,
|
| 268 |
+
images=image_tensor,
|
| 269 |
+
do_sample=True,
|
| 270 |
+
temperature=0.2,
|
| 271 |
+
max_new_tokens=1024,
|
| 272 |
+
streamer=streamer,
|
| 273 |
+
use_cache=False,
|
| 274 |
+
stopping_criteria=[stopping_criteria],
|
| 275 |
+
)
|
| 276 |
+
|
| 277 |
+
outputs = our_chatbot.tokenizer.decode(output_ids[0]).strip()
|
| 278 |
+
if outputs.endswith(stop_str):
|
| 279 |
+
outputs = outputs[: -len(stop_str)]
|
| 280 |
+
our_chatbot.conversation.messages[-1][-1] = outputs
|
| 281 |
+
|
| 282 |
+
history[-1] = [text, outputs]
|
| 283 |
+
|
| 284 |
+
return history
|
| 285 |
+
# generate_kwargs = dict(
|
| 286 |
+
# inputs=input_ids,
|
| 287 |
+
# streamer=streamer,
|
| 288 |
+
# images=image_tensor,
|
| 289 |
+
# max_new_tokens=1024,
|
| 290 |
+
# do_sample=True,
|
| 291 |
+
# temperature=0.2,
|
| 292 |
+
# num_beams=1,
|
| 293 |
+
# use_cache=False,
|
| 294 |
+
# stopping_criteria=[stopping_criteria],
|
| 295 |
+
# )
|
| 296 |
|
| 297 |
+
# t = Thread(target=our_chatbot.model.generate, kwargs=generate_kwargs)
|
| 298 |
+
# t.start()
|
| 299 |
|
| 300 |
+
# outputs = []
|
| 301 |
+
# for text in streamer:
|
| 302 |
+
# outputs.append(text)
|
| 303 |
+
# yield "".join(outputs)
|
| 304 |
|
| 305 |
+
# our_chatbot.conversation.messages[-1][-1] = "".join(outputs)
|
| 306 |
+
# history[-1] = [text, "".join(outputs)]
|
| 307 |
|
| 308 |
|
| 309 |
txt = gr.Textbox(
|